Skip to main content

adk_skill/
injector.rs

1use crate::error::SkillResult;
2use crate::index::load_skill_index;
3use crate::model::{SelectionPolicy, SkillIndex, SkillMatch};
4use crate::select::select_skills;
5use adk_core::{Content, Part};
6use adk_plugin::{Plugin, PluginConfig, PluginManager};
7use std::path::Path;
8use std::sync::Arc;
9
10#[derive(Debug, Clone)]
11pub struct SkillInjectorConfig {
12    pub policy: SelectionPolicy,
13    pub max_injected_chars: usize,
14}
15
16impl Default for SkillInjectorConfig {
17    fn default() -> Self {
18        Self { policy: SelectionPolicy::default(), max_injected_chars: 2000 }
19    }
20}
21
22#[derive(Debug, Clone)]
23pub struct SkillInjector {
24    index: Arc<SkillIndex>,
25    config: SkillInjectorConfig,
26}
27
28impl SkillInjector {
29    pub fn from_root(root: impl AsRef<Path>, config: SkillInjectorConfig) -> SkillResult<Self> {
30        let index = load_skill_index(root)?;
31        Ok(Self { index: Arc::new(index), config })
32    }
33
34    pub fn from_index(index: SkillIndex, config: SkillInjectorConfig) -> Self {
35        Self { index: Arc::new(index), config }
36    }
37
38    pub fn index(&self) -> &SkillIndex {
39        self.index.as_ref()
40    }
41
42    pub fn policy(&self) -> &SelectionPolicy {
43        &self.config.policy
44    }
45
46    pub fn max_injected_chars(&self) -> usize {
47        self.config.max_injected_chars
48    }
49
50    pub fn build_plugin(&self, name: impl Into<String>) -> Plugin {
51        let plugin_name = name.into();
52        let index = self.index.clone();
53        let policy = self.config.policy.clone();
54        let max_injected_chars = self.config.max_injected_chars;
55
56        Plugin::new(PluginConfig {
57            name: plugin_name,
58            on_user_message: Some(Box::new(move |_ctx, mut content| {
59                let index = index.clone();
60                let policy = policy.clone();
61                Box::pin(async move {
62                    let injected = apply_skill_injection(
63                        &mut content,
64                        index.as_ref(),
65                        &policy,
66                        max_injected_chars,
67                    );
68                    Ok(if injected.is_some() { Some(content) } else { None })
69                })
70            })),
71            ..Default::default()
72        })
73    }
74
75    pub fn build_plugin_manager(&self, name: impl Into<String>) -> PluginManager {
76        PluginManager::new(vec![self.build_plugin(name)])
77    }
78}
79
80pub fn select_skill_prompt_block(
81    index: &SkillIndex,
82    query: &str,
83    policy: &SelectionPolicy,
84    max_injected_chars: usize,
85) -> Option<(SkillMatch, String)> {
86    let top = select_skills(index, query, policy).into_iter().next()?;
87    let matched = index.skills().iter().find(|s| s.id == top.skill.id)?;
88    let mut skill_body = matched.body.clone();
89    if skill_body.chars().count() > max_injected_chars {
90        skill_body = skill_body.chars().take(max_injected_chars).collect();
91    }
92
93    let prompt_block = format!("[skill:{}]\n{}\n[/skill]", matched.name, skill_body);
94    Some((top, prompt_block))
95}
96
97pub fn apply_skill_injection(
98    content: &mut Content,
99    index: &SkillIndex,
100    policy: &SelectionPolicy,
101    max_injected_chars: usize,
102) -> Option<SkillMatch> {
103    if content.role != "user" || index.is_empty() {
104        return None;
105    }
106
107    let original_text = extract_text(content);
108    if original_text.trim().is_empty() {
109        return None;
110    }
111
112    let (top, prompt_block) =
113        select_skill_prompt_block(index, &original_text, policy, max_injected_chars)?;
114    let injected_text = format!("{prompt_block}\n\n{original_text}");
115
116    if let Some(Part::Text { text }) =
117        content.parts.iter_mut().find(|part| matches!(part, Part::Text { .. }))
118    {
119        *text = injected_text;
120    } else {
121        content.parts.insert(0, Part::Text { text: injected_text });
122    }
123
124    Some(top)
125}
126
127fn extract_text(content: &Content) -> String {
128    content
129        .parts
130        .iter()
131        .filter_map(|p| match p {
132            Part::Text { text } => Some(text.as_str()),
133            _ => None,
134        })
135        .collect::<Vec<_>>()
136        .join("\n")
137}
138
#[cfg(test)]
mod tests {
    use super::*;
    use crate::index::load_skill_index;
    use std::fs;

    /// Skill file with front matter (`name`, `description`) followed by the skill body.
    const SEARCH_SKILL: &str = "---\nname: search\ndescription: Search code\n---\nUse rg first.";

    /// A user message mentioning "search" should have the `search` skill block, including its
    /// body, injected into its first text part.
    #[test]
    fn injects_top_skill_into_user_message() {
        // Keep the guard alive for the whole test so the directory is not removed early.
        let tmp = tempfile::tempdir().unwrap();
        let skills_root = tmp.path();
        fs::create_dir_all(skills_root.join(".skills")).unwrap();
        fs::write(skills_root.join(".skills/search.md"), SEARCH_SKILL).unwrap();

        let index = load_skill_index(skills_root).unwrap();
        let policy = SelectionPolicy { top_k: 1, min_score: 0.1, ..Default::default() };
        let mut message = Content::new("user").with_text("Please search this repository quickly");

        let selected = apply_skill_injection(&mut message, &index, &policy, 1000);
        assert!(selected.is_some());

        let first_text = message.parts[0].text().unwrap();
        for expected in ["[skill:search]", "Use rg first."] {
            assert!(first_text.contains(expected));
        }
    }
}
168}