1use crate::error::SkillResult;
2use crate::index::load_skill_index;
3use crate::model::{SelectionPolicy, SkillIndex, SkillMatch};
4use crate::select::select_skills;
5use adk_core::{Content, Part};
6use adk_plugin::{Plugin, PluginConfig, PluginManager};
7use std::path::Path;
8use std::sync::Arc;
9
10#[derive(Debug, Clone)]
11pub struct SkillInjectorConfig {
12 pub policy: SelectionPolicy,
13 pub max_injected_chars: usize,
14}
15
16impl Default for SkillInjectorConfig {
17 fn default() -> Self {
18 Self { policy: SelectionPolicy::default(), max_injected_chars: 2000 }
19 }
20}
21
22#[derive(Debug, Clone)]
23pub struct SkillInjector {
24 index: Arc<SkillIndex>,
25 config: SkillInjectorConfig,
26}
27
28impl SkillInjector {
29 pub fn from_root(root: impl AsRef<Path>, config: SkillInjectorConfig) -> SkillResult<Self> {
30 let index = load_skill_index(root)?;
31 Ok(Self { index: Arc::new(index), config })
32 }
33
34 pub fn from_index(index: SkillIndex, config: SkillInjectorConfig) -> Self {
35 Self { index: Arc::new(index), config }
36 }
37
38 pub fn index(&self) -> &SkillIndex {
39 self.index.as_ref()
40 }
41
42 pub fn policy(&self) -> &SelectionPolicy {
43 &self.config.policy
44 }
45
46 pub fn max_injected_chars(&self) -> usize {
47 self.config.max_injected_chars
48 }
49
50 pub fn build_plugin(&self, name: impl Into<String>) -> Plugin {
51 let plugin_name = name.into();
52 let index = self.index.clone();
53 let policy = self.config.policy.clone();
54 let max_injected_chars = self.config.max_injected_chars;
55
56 Plugin::new(PluginConfig {
57 name: plugin_name,
58 on_user_message: Some(Box::new(move |_ctx, mut content| {
59 let index = index.clone();
60 let policy = policy.clone();
61 Box::pin(async move {
62 let injected = apply_skill_injection(
63 &mut content,
64 index.as_ref(),
65 &policy,
66 max_injected_chars,
67 );
68 Ok(if injected.is_some() { Some(content) } else { None })
69 })
70 })),
71 ..Default::default()
72 })
73 }
74
75 pub fn build_plugin_manager(&self, name: impl Into<String>) -> PluginManager {
76 PluginManager::new(vec![self.build_plugin(name)])
77 }
78}
79
80pub fn select_skill_prompt_block(
86 index: &SkillIndex,
87 query: &str,
88 policy: &SelectionPolicy,
89 max_injected_chars: usize,
90) -> Option<(SkillMatch, String)> {
91 let top = select_skills(index, query, policy).into_iter().next()?;
92 let matched = index.find_by_id(&top.skill.id)?;
93 let prompt_block = matched.engineer_prompt_block(max_injected_chars);
94 Some((top, prompt_block))
95}
96
97pub fn apply_skill_injection(
104 content: &mut Content,
105 index: &SkillIndex,
106 policy: &SelectionPolicy,
107 max_injected_chars: usize,
108) -> Option<SkillMatch> {
109 if content.role != "user" || index.is_empty() {
110 return None;
111 }
112
113 let original_text = extract_text(content);
114 if original_text.trim().is_empty() {
115 return None;
116 }
117
118 let (top, prompt_block) =
119 select_skill_prompt_block(index, &original_text, policy, max_injected_chars)?;
120 let injected_text = format!("{prompt_block}\n\n{original_text}");
121
122 if let Some(Part::Text { text }) =
123 content.parts.iter_mut().find(|part| matches!(part, Part::Text { .. }))
124 {
125 *text = injected_text;
126 } else {
127 content.parts.insert(0, Part::Text { text: injected_text });
128 }
129
130 Some(top)
131}
132
133fn extract_text(content: &Content) -> String {
134 content
135 .parts
136 .iter()
137 .filter_map(|p| match p {
138 Part::Text { text } => Some(text.as_str()),
139 _ => None,
140 })
141 .collect::<Vec<_>>()
142 .join("\n")
143}
144
#[cfg(test)]
mod tests {
    // The glob also brings in the parent's `load_skill_index`, `Path` and `SkillIndex`.
    use super::*;
    use std::fs;

    /// Writes a single `search` skill under `<root>/.skills` and returns the loaded index.
    fn search_skill_index(root: &Path) -> SkillIndex {
        fs::create_dir_all(root.join(".skills")).unwrap();
        fs::write(
            root.join(".skills/search.md"),
            "---\nname: search\ndescription: Search code\n---\nUse rg first.",
        )
        .unwrap();
        load_skill_index(root).unwrap()
    }

    /// Policy used by every test: keep the single best match above a low threshold.
    fn test_policy() -> SelectionPolicy {
        SelectionPolicy { top_k: 1, min_score: 0.1, ..SelectionPolicy::default() }
    }

    #[test]
    fn injects_top_skill_into_user_message() {
        let temp = tempfile::tempdir().unwrap();
        let index = search_skill_index(temp.path());
        let policy = test_policy();

        let mut content = Content::new("user").with_text("Please search this repository quickly");
        let matched = apply_skill_injection(&mut content, &index, &policy, 1000);

        assert!(matched.is_some());
        let injected = content.parts[0].text().unwrap();
        assert!(injected.contains("[skill:search]"));
        assert!(injected.contains("Use rg first."));
    }

    #[test]
    fn leaves_non_user_messages_untouched() {
        let temp = tempfile::tempdir().unwrap();
        let index = search_skill_index(temp.path());
        let policy = test_policy();

        let mut content = Content::new("model").with_text("Please search this repository quickly");
        let matched = apply_skill_injection(&mut content, &index, &policy, 1000);

        assert!(matched.is_none());
        assert_eq!(content.parts[0].text().unwrap(), "Please search this repository quickly");
    }

    #[test]
    fn skips_blank_user_messages() {
        let temp = tempfile::tempdir().unwrap();
        let index = search_skill_index(temp.path());
        let policy = test_policy();

        let mut content = Content::new("user").with_text("   ");
        let matched = apply_skill_injection(&mut content, &index, &policy, 1000);

        assert!(matched.is_none());
        assert_eq!(content.parts[0].text().unwrap(), "   ");
    }
}