Skip to main content

adk_skill/
injector.rs

1use crate::error::SkillResult;
2use crate::index::{load_skill_index, load_skill_index_with_extras};
3use crate::model::{SelectionPolicy, SkillIndex, SkillMatch};
4use crate::select::select_skills;
5use adk_core::{Content, Part};
6use adk_plugin::{Plugin, PluginConfig, PluginManager};
7use std::path::{Path, PathBuf};
8use std::sync::Arc;
9
10#[derive(Debug, Clone)]
11pub struct SkillInjectorConfig {
12    pub policy: SelectionPolicy,
13    pub max_injected_chars: usize,
14    /// Optional global skills directory (e.g. `~/.config/adk/skills/`).
15    /// Skills here are included in the index but project-local skills
16    /// take precedence when names collide.
17    pub global_skills_dir: Option<PathBuf>,
18    /// Additional directories to scan for skills.
19    pub extra_paths: Vec<PathBuf>,
20}
21
22impl Default for SkillInjectorConfig {
23    fn default() -> Self {
24        Self {
25            policy: SelectionPolicy::default(),
26            max_injected_chars: 2000,
27            global_skills_dir: None,
28            extra_paths: Vec::new(),
29        }
30    }
31}
32
33#[derive(Debug, Clone)]
34pub struct SkillInjector {
35    index: Arc<SkillIndex>,
36    config: SkillInjectorConfig,
37}
38
39impl SkillInjector {
40    pub fn from_root(root: impl AsRef<Path>, config: SkillInjectorConfig) -> SkillResult<Self> {
41        let mut extra_dirs: Vec<PathBuf> = config.extra_paths.clone();
42        if let Some(ref global) = config.global_skills_dir {
43            extra_dirs.push(global.clone());
44        }
45        let index = if extra_dirs.is_empty() {
46            load_skill_index(root)?
47        } else {
48            load_skill_index_with_extras(root, &extra_dirs)?
49        };
50        Ok(Self { index: Arc::new(index), config })
51    }
52
53    pub fn from_index(index: SkillIndex, config: SkillInjectorConfig) -> Self {
54        Self { index: Arc::new(index), config }
55    }
56
57    pub fn index(&self) -> &SkillIndex {
58        self.index.as_ref()
59    }
60
61    pub fn policy(&self) -> &SelectionPolicy {
62        &self.config.policy
63    }
64
65    pub fn max_injected_chars(&self) -> usize {
66        self.config.max_injected_chars
67    }
68
69    pub fn build_plugin(&self, name: impl Into<String>) -> Plugin {
70        let plugin_name = name.into();
71        let index = self.index.clone();
72        let policy = self.config.policy.clone();
73        let max_injected_chars = self.config.max_injected_chars;
74
75        Plugin::new(PluginConfig {
76            name: plugin_name,
77            on_user_message: Some(Box::new(move |_ctx, mut content| {
78                let index = index.clone();
79                let policy = policy.clone();
80                Box::pin(async move {
81                    let injected = apply_skill_injection(
82                        &mut content,
83                        index.as_ref(),
84                        &policy,
85                        max_injected_chars,
86                    );
87                    Ok(if injected.is_some() { Some(content) } else { None })
88                })
89            })),
90            ..Default::default()
91        })
92    }
93
94    pub fn build_plugin_manager(&self, name: impl Into<String>) -> PluginManager {
95        PluginManager::new(vec![self.build_plugin(name)])
96    }
97}
98
99/// Selects the top-scoring skill for a query and returns its formatted prompt block.
100///
101/// Returns `None` if no skill meets the selection criteria. The prompt block is
102/// truncated to `max_injected_chars` and wrapped in a `[skill:<name>]` section
103/// suitable for prepending to a user message.
104pub fn select_skill_prompt_block(
105    index: &SkillIndex,
106    query: &str,
107    policy: &SelectionPolicy,
108    max_injected_chars: usize,
109) -> Option<(SkillMatch, String)> {
110    let top = select_skills(index, query, policy).into_iter().next()?;
111    let matched = index.find_by_id(&top.skill.id)?;
112    let prompt_block = matched.engineer_prompt_block(max_injected_chars);
113    Some((top, prompt_block))
114}
115
116/// Injects the best-matching skill prompt into a user [`Content`] message.
117///
118/// Extracts the text from `content`, selects the top skill from `index`, and
119/// prepends its prompt block to the first text part. Returns the matched skill
120/// on success, or `None` if the content is not a user message, the index is
121/// empty, or no skill meets the selection threshold.
122pub fn apply_skill_injection(
123    content: &mut Content,
124    index: &SkillIndex,
125    policy: &SelectionPolicy,
126    max_injected_chars: usize,
127) -> Option<SkillMatch> {
128    if content.role != "user" || index.is_empty() {
129        return None;
130    }
131
132    let original_text = extract_text(content);
133    if original_text.trim().is_empty() {
134        return None;
135    }
136
137    let (top, prompt_block) =
138        select_skill_prompt_block(index, &original_text, policy, max_injected_chars)?;
139    let injected_text = format!("{prompt_block}\n\n{original_text}");
140
141    if let Some(Part::Text { text }) =
142        content.parts.iter_mut().find(|part| matches!(part, Part::Text { .. }))
143    {
144        *text = injected_text;
145    } else {
146        content.parts.insert(0, Part::Text { text: injected_text });
147    }
148
149    Some(top)
150}
151
152fn extract_text(content: &Content) -> String {
153    content
154        .parts
155        .iter()
156        .filter_map(|p| match p {
157            Part::Text { text } => Some(text.as_str()),
158            _ => None,
159        })
160        .collect::<Vec<_>>()
161        .join("\n")
162}
163
#[cfg(test)]
mod tests {
    use super::*;
    use crate::index::load_skill_index;
    use std::fs;

    /// A single `search` skill on disk should be selected for a query that
    /// mentions searching, and its prompt block should land in the first part.
    #[test]
    fn injects_top_skill_into_user_message() {
        let workspace = tempfile::tempdir().unwrap();
        let skills_dir = workspace.path().join(".skills");
        fs::create_dir_all(&skills_dir).unwrap();
        fs::write(
            skills_dir.join("search.md"),
            "---\nname: search\ndescription: Search code\n---\nUse rg first.",
        )
        .unwrap();

        let index = load_skill_index(workspace.path()).unwrap();
        let policy = SelectionPolicy { top_k: 1, min_score: 0.1, ..SelectionPolicy::default() };
        let mut message = Content::new("user").with_text("Please search this repository quickly");

        let matched = apply_skill_injection(&mut message, &index, &policy, 1000);
        assert!(matched.is_some());

        let rewritten = message.parts[0].text().unwrap();
        for needle in ["[skill:search]", "Use rg first."] {
            assert!(rewritten.contains(needle), "expected {needle:?} in {rewritten:?}");
        }
    }
}