Skip to main content

adk_skill/
injector.rs

1use crate::error::SkillResult;
2use crate::index::load_skill_index;
3use crate::model::{SelectionPolicy, SkillIndex, SkillMatch};
4use crate::select::select_skills;
5use adk_core::{Content, Part};
6use adk_plugin::{Plugin, PluginConfig, PluginManager};
7use std::path::Path;
8use std::sync::Arc;
9
10#[derive(Debug, Clone)]
11pub struct SkillInjectorConfig {
12    pub policy: SelectionPolicy,
13    pub max_injected_chars: usize,
14}
15
16impl Default for SkillInjectorConfig {
17    fn default() -> Self {
18        Self { policy: SelectionPolicy::default(), max_injected_chars: 2000 }
19    }
20}
21
22#[derive(Debug, Clone)]
23pub struct SkillInjector {
24    index: Arc<SkillIndex>,
25    config: SkillInjectorConfig,
26}
27
28impl SkillInjector {
29    pub fn from_root(root: impl AsRef<Path>, config: SkillInjectorConfig) -> SkillResult<Self> {
30        let index = load_skill_index(root)?;
31        Ok(Self { index: Arc::new(index), config })
32    }
33
34    pub fn from_index(index: SkillIndex, config: SkillInjectorConfig) -> Self {
35        Self { index: Arc::new(index), config }
36    }
37
38    pub fn index(&self) -> &SkillIndex {
39        self.index.as_ref()
40    }
41
42    pub fn policy(&self) -> &SelectionPolicy {
43        &self.config.policy
44    }
45
46    pub fn max_injected_chars(&self) -> usize {
47        self.config.max_injected_chars
48    }
49
50    pub fn build_plugin(&self, name: impl Into<String>) -> Plugin {
51        let plugin_name = name.into();
52        let index = self.index.clone();
53        let policy = self.config.policy.clone();
54        let max_injected_chars = self.config.max_injected_chars;
55
56        Plugin::new(PluginConfig {
57            name: plugin_name,
58            on_user_message: Some(Box::new(move |_ctx, mut content| {
59                let index = index.clone();
60                let policy = policy.clone();
61                Box::pin(async move {
62                    let injected = apply_skill_injection(
63                        &mut content,
64                        index.as_ref(),
65                        &policy,
66                        max_injected_chars,
67                    );
68                    Ok(if injected.is_some() { Some(content) } else { None })
69                })
70            })),
71            ..Default::default()
72        })
73    }
74
75    pub fn build_plugin_manager(&self, name: impl Into<String>) -> PluginManager {
76        PluginManager::new(vec![self.build_plugin(name)])
77    }
78}
79
80/// Selects the top-scoring skill for a query and returns its formatted prompt block.
81///
82/// Returns `None` if no skill meets the selection criteria. The prompt block is
83/// truncated to `max_injected_chars` and wrapped in a `[skill:<name>]` section
84/// suitable for prepending to a user message.
85pub fn select_skill_prompt_block(
86    index: &SkillIndex,
87    query: &str,
88    policy: &SelectionPolicy,
89    max_injected_chars: usize,
90) -> Option<(SkillMatch, String)> {
91    let top = select_skills(index, query, policy).into_iter().next()?;
92    let matched = index.find_by_id(&top.skill.id)?;
93    let prompt_block = matched.engineer_prompt_block(max_injected_chars);
94    Some((top, prompt_block))
95}
96
97/// Injects the best-matching skill prompt into a user [`Content`] message.
98///
99/// Extracts the text from `content`, selects the top skill from `index`, and
100/// prepends its prompt block to the first text part. Returns the matched skill
101/// on success, or `None` if the content is not a user message, the index is
102/// empty, or no skill meets the selection threshold.
103pub fn apply_skill_injection(
104    content: &mut Content,
105    index: &SkillIndex,
106    policy: &SelectionPolicy,
107    max_injected_chars: usize,
108) -> Option<SkillMatch> {
109    if content.role != "user" || index.is_empty() {
110        return None;
111    }
112
113    let original_text = extract_text(content);
114    if original_text.trim().is_empty() {
115        return None;
116    }
117
118    let (top, prompt_block) =
119        select_skill_prompt_block(index, &original_text, policy, max_injected_chars)?;
120    let injected_text = format!("{prompt_block}\n\n{original_text}");
121
122    if let Some(Part::Text { text }) =
123        content.parts.iter_mut().find(|part| matches!(part, Part::Text { .. }))
124    {
125        *text = injected_text;
126    } else {
127        content.parts.insert(0, Part::Text { text: injected_text });
128    }
129
130    Some(top)
131}
132
133fn extract_text(content: &Content) -> String {
134    content
135        .parts
136        .iter()
137        .filter_map(|p| match p {
138            Part::Text { text } => Some(text.as_str()),
139            _ => None,
140        })
141        .collect::<Vec<_>>()
142        .join("\n")
143}
144
#[cfg(test)]
mod tests {
    use super::*;
    use crate::index::load_skill_index;
    use std::fs;

    /// A user message mentioning "search" should pick the lone `search` skill
    /// and carry its prompt block in the first message part.
    #[test]
    fn injects_top_skill_into_user_message() {
        let temp = tempfile::tempdir().unwrap();
        let skills_dir = temp.path().join(".skills");
        fs::create_dir_all(&skills_dir).unwrap();
        fs::write(
            skills_dir.join("search.md"),
            "---\nname: search\ndescription: Search code\n---\nUse rg first.",
        )
        .unwrap();

        let index = load_skill_index(temp.path()).unwrap();
        let policy = SelectionPolicy { top_k: 1, min_score: 0.1, ..SelectionPolicy::default() };
        let mut content = Content::new("user").with_text("Please search this repository quickly");

        let matched = apply_skill_injection(&mut content, &index, &policy, 1000);
        assert!(matched.is_some());

        let injected = content.parts[0].text().unwrap();
        for needle in ["[skill:search]", "Use rg first."] {
            assert!(injected.contains(needle));
        }
    }
}