1use crate::error::SkillResult;
2use crate::index::load_skill_index;
3use crate::model::{SelectionPolicy, SkillIndex, SkillMatch};
4use crate::select::select_skills;
5use adk_core::{Content, Part};
6use adk_plugin::{Plugin, PluginConfig, PluginManager};
7use std::path::Path;
8use std::sync::Arc;
9
10#[derive(Debug, Clone)]
11pub struct SkillInjectorConfig {
12 pub policy: SelectionPolicy,
13 pub max_injected_chars: usize,
14}
15
16impl Default for SkillInjectorConfig {
17 fn default() -> Self {
18 Self { policy: SelectionPolicy::default(), max_injected_chars: 2000 }
19 }
20}
21
22#[derive(Debug, Clone)]
23pub struct SkillInjector {
24 index: Arc<SkillIndex>,
25 config: SkillInjectorConfig,
26}
27
28impl SkillInjector {
29 pub fn from_root(root: impl AsRef<Path>, config: SkillInjectorConfig) -> SkillResult<Self> {
30 let index = load_skill_index(root)?;
31 Ok(Self { index: Arc::new(index), config })
32 }
33
34 pub fn from_index(index: SkillIndex, config: SkillInjectorConfig) -> Self {
35 Self { index: Arc::new(index), config }
36 }
37
38 pub fn index(&self) -> &SkillIndex {
39 self.index.as_ref()
40 }
41
42 pub fn policy(&self) -> &SelectionPolicy {
43 &self.config.policy
44 }
45
46 pub fn max_injected_chars(&self) -> usize {
47 self.config.max_injected_chars
48 }
49
50 pub fn build_plugin(&self, name: impl Into<String>) -> Plugin {
51 let plugin_name = name.into();
52 let index = self.index.clone();
53 let policy = self.config.policy.clone();
54 let max_injected_chars = self.config.max_injected_chars;
55
56 Plugin::new(PluginConfig {
57 name: plugin_name,
58 on_user_message: Some(Box::new(move |_ctx, mut content| {
59 let index = index.clone();
60 let policy = policy.clone();
61 Box::pin(async move {
62 let injected = apply_skill_injection(
63 &mut content,
64 index.as_ref(),
65 &policy,
66 max_injected_chars,
67 );
68 Ok(if injected.is_some() { Some(content) } else { None })
69 })
70 })),
71 ..Default::default()
72 })
73 }
74
75 pub fn build_plugin_manager(&self, name: impl Into<String>) -> PluginManager {
76 PluginManager::new(vec![self.build_plugin(name)])
77 }
78}
79
80pub fn select_skill_prompt_block(
81 index: &SkillIndex,
82 query: &str,
83 policy: &SelectionPolicy,
84 max_injected_chars: usize,
85) -> Option<(SkillMatch, String)> {
86 let top = select_skills(index, query, policy).into_iter().next()?;
87 let matched = index.skills().iter().find(|s| s.id == top.skill.id)?;
88 let mut skill_body = matched.body.clone();
89 if skill_body.chars().count() > max_injected_chars {
90 skill_body = skill_body.chars().take(max_injected_chars).collect();
91 }
92
93 let prompt_block = format!("[skill:{}]\n{}\n[/skill]", matched.name, skill_body);
94 Some((top, prompt_block))
95}
96
97pub fn apply_skill_injection(
98 content: &mut Content,
99 index: &SkillIndex,
100 policy: &SelectionPolicy,
101 max_injected_chars: usize,
102) -> Option<SkillMatch> {
103 if content.role != "user" || index.is_empty() {
104 return None;
105 }
106
107 let original_text = extract_text(content);
108 if original_text.trim().is_empty() {
109 return None;
110 }
111
112 let (top, prompt_block) =
113 select_skill_prompt_block(index, &original_text, policy, max_injected_chars)?;
114 let injected_text = format!("{prompt_block}\n\n{original_text}");
115
116 if let Some(Part::Text { text }) =
117 content.parts.iter_mut().find(|part| matches!(part, Part::Text { .. }))
118 {
119 *text = injected_text;
120 } else {
121 content.parts.insert(0, Part::Text { text: injected_text });
122 }
123
124 Some(top)
125}
126
127fn extract_text(content: &Content) -> String {
128 content
129 .parts
130 .iter()
131 .filter_map(|p| match p {
132 Part::Text { text } => Some(text.as_str()),
133 _ => None,
134 })
135 .collect::<Vec<_>>()
136 .join("\n")
137}
138
#[cfg(test)]
mod tests {
    use super::*;
    use crate::index::load_skill_index;
    use std::fs;

    const QUERY: &str = "Please search this repository quickly";

    /// Writes a single `search` skill under `<root>/.skills/` and loads the index from `root`.
    fn search_skill_index(root: &Path) -> SkillIndex {
        fs::create_dir_all(root.join(".skills")).unwrap();
        fs::write(
            root.join(".skills/search.md"),
            "---\nname: search\ndescription: Search code\n---\nUse rg first.",
        )
        .unwrap();
        load_skill_index(root).unwrap()
    }

    /// Policy that keeps only the single best match above a low score threshold.
    fn top_one_policy() -> SelectionPolicy {
        SelectionPolicy { top_k: 1, min_score: 0.1, ..SelectionPolicy::default() }
    }

    #[test]
    fn injects_top_skill_into_user_message() {
        let temp = tempfile::tempdir().unwrap();
        let index = search_skill_index(temp.path());

        let mut content = Content::new("user").with_text(QUERY);
        let matched = apply_skill_injection(&mut content, &index, &top_one_policy(), 1000);

        assert!(matched.is_some());
        let injected = content.parts[0].text().unwrap();
        assert!(injected.contains("[skill:search]"));
        assert!(injected.contains("Use rg first."));
    }

    #[test]
    fn leaves_non_user_messages_untouched() {
        let temp = tempfile::tempdir().unwrap();
        let index = search_skill_index(temp.path());

        let mut content = Content::new("model").with_text(QUERY);
        let matched = apply_skill_injection(&mut content, &index, &top_one_policy(), 1000);

        assert!(matched.is_none());
        assert_eq!(content.parts[0].text().unwrap(), QUERY);
    }

    #[test]
    fn skips_blank_user_messages() {
        let temp = tempfile::tempdir().unwrap();
        let index = search_skill_index(temp.path());

        let mut content = Content::new("user").with_text("   ");
        let matched = apply_skill_injection(&mut content, &index, &top_one_policy(), 1000);

        assert!(matched.is_none());
    }

    #[test]
    fn truncates_skill_body_to_max_injected_chars() {
        let temp = tempfile::tempdir().unwrap();
        let index = search_skill_index(temp.path());

        let mut content = Content::new("user").with_text(QUERY);
        let matched = apply_skill_injection(&mut content, &index, &top_one_policy(), 3);

        assert!(matched.is_some());
        let injected = content.parts[0].text().unwrap();
        assert!(injected.contains("[skill:search]"));
        assert!(injected.contains("[/skill]"));
        assert!(!injected.contains("rg first"));
    }
}