1#![allow(dead_code)]
3
4use std::collections::HashMap;
5use std::sync::{Arc, Mutex};
6
7use crate::types::Message;
8use crate::utils::hooks::api_query_hook_helper::{
9 ApiQueryHookConfig, ReplHookContext, create_api_query_hook,
10};
11use crate::utils::hooks::post_sampling_hooks::register_post_sampling_hook;
12
/// Minimum number of new user messages that must arrive since the previous
/// analysis before the skill-improvement hook's `should_run` returns `true`.
const TURN_BATCH_SIZE: usize = 5;
15
/// One proposed edit to a skill definition.
///
/// Deserialized from the JSON array the model is asked to emit inside
/// `<updates>` tags; the field descriptions below mirror what that prompt
/// requests (the model's actual output may not follow them exactly).
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct SkillUpdate {
    /// Which step/section to modify, or `'new step'`.
    pub section: String,
    /// What to add or modify.
    pub change: String,
    /// Which user message prompted this update.
    pub reason: String,
}
23
/// Pairs a skill name with the updates proposed for it.
///
/// NOTE(review): not constructed anywhere in this file; presumably consumed by
/// UI/caller code elsewhere — confirm before changing its shape.
#[derive(Debug, Clone)]
pub struct SkillImprovementSuggestion {
    /// Name of the skill the updates apply to.
    pub skill_name: String,
    /// The proposed edits, in the order the model returned them.
    pub updates: Vec<SkillUpdate>,
}
30
/// Cursor state shared between the hook's `should_run` and `build_messages`
/// callbacks so each analysis only looks at new conversation turns.
struct SkillImprovementState {
    /// Number of user messages present when `should_run` last allowed a run.
    last_analyzed_count: usize,
    /// Message index up to which `build_messages` has already fed messages to
    /// the model; the next run starts from here.
    last_analyzed_index: usize,
}
36
37lazy_static::lazy_static! {
38 static ref SKILL_IMPROVEMENT_STATE: Arc<Mutex<SkillImprovementState>> = Arc::new(Mutex::new(
39 SkillImprovementState {
40 last_analyzed_count: 0,
41 last_analyzed_index: 0,
42 }
43 ));
44}
45
46fn find_project_skill() -> Option<ProjectSkillInfo> {
48 None
51}
52
/// Information about the project skill the user is currently executing.
struct ProjectSkillInfo {
    /// Skill name; used in analytics/debug logging of detected improvements.
    skill_name: String,
    /// NOTE(review): not read in this file; presumably the path of the skill's
    /// definition file — confirm.
    skill_path: String,
    /// Full skill definition text, embedded in the analysis prompt inside
    /// `<skill_definition>` tags.
    content: String,
}
59
60fn format_recent_messages(messages: &[Message]) -> String {
62 messages
63 .iter()
64 .filter(|m| m.is_user() || m.is_assistant())
65 .map(|m| {
66 let role = if m.is_user() { "User" } else { "Assistant" };
67 let content = m.content.chars().take(500).collect::<String>();
68 format!("{}: {}", role, content)
69 })
70 .collect::<Vec<_>>()
71 .join("\n\n")
72}
73
74fn count_messages<F>(messages: &[Message], predicate: F) -> usize
76where
77 F: Fn(&Message) -> bool,
78{
79 messages.iter().filter(|m| predicate(m)).count()
80}
81
82fn create_skill_improvement_hook() -> Arc<
84 dyn Fn(ReplHookContext) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>
85 + Send
86 + Sync,
87> {
88 let config: ApiQueryHookConfig<Vec<SkillUpdate>> = ApiQueryHookConfig {
89 name: "skill_improvement".to_string(),
90 should_run: Box::new(|context| {
91 let query_source = context.query_source.clone();
92 let messages = context.messages.clone();
93 Box::pin(async move {
94 if query_source
96 .as_ref()
97 .map(|s| s != "repl_main_thread")
98 .unwrap_or(true)
99 {
100 return false;
101 }
102
103 if find_project_skill().is_none() {
105 return false;
106 }
107
108 let mut state = SKILL_IMPROVEMENT_STATE.lock().unwrap();
110 let user_count = count_messages(&messages, |m| m.is_user());
111 if user_count - state.last_analyzed_count < TURN_BATCH_SIZE {
112 return false;
113 }
114
115 state.last_analyzed_count = user_count;
116 true
117 })
118 }),
119 build_messages: Box::new(|context| {
120 let project_skill = match find_project_skill() {
121 Some(s) => s,
122 None => return Vec::new(),
123 };
124
125 let mut state = SKILL_IMPROVEMENT_STATE.lock().unwrap();
126 let new_messages = context.messages[state.last_analyzed_index..].to_vec();
128 state.last_analyzed_index = context.messages.len();
129
130 let formatted = format_recent_messages(&new_messages);
131
132 let prompt = format!(
133 r#"You are analyzing a conversation where a user is executing a skill (a repeatable process).
134Your job: identify if the user's recent messages contain preferences, requests, or corrections that should be permanently added to the skill definition for future runs.
135
136<skill_definition>
137{}
138</skill_definition>
139
140<recent_messages>
141{}
142</recent_messages>
143
144Look for:
145- Requests to add, change, or remove steps: "can you also ask me X", "please do Y too", "don't do Z"
146- Preferences about how steps should work: "ask me about energy levels", "note the time", "use a casual tone"
147- Corrections: "no, do X instead", "always use Y", "make sure to..."
148
149Ignore:
150- Routine conversation that doesn't generalize (one-time answers, chitchat)
151- Things the skill already does
152
153Output a JSON array inside <updates> tags. Each item: {{"section": "which step/section to modify or 'new step'", "change": "what to add/modify", "reason": "which user message prompted this"}}.
154Output <updates>[]</updates> if no updates are needed."#,
155 project_skill.content, formatted
156 );
157
158 vec![Message {
159 role: crate::types::api_types::MessageRole::User,
160 content: prompt,
161 attachments: None,
162 tool_call_id: None,
163 tool_calls: None,
164 is_error: None,
165 is_meta: None,
166 is_api_error_message: None,
167 error_details: None,
168 uuid: None,
169 }]
170 }),
171 system_prompt: None,
172 use_tools: Some(false),
173 parse_response: Box::new(|content, _context| {
174 if let Some(updates_str) = extract_tag(content, "updates") {
176 match serde_json::from_str::<Vec<SkillUpdate>>(&updates_str) {
177 Ok(updates) => updates,
178 Err(_) => Vec::new(),
179 }
180 } else {
181 Vec::new()
182 }
183 }),
184 log_result: Box::new(|result, context| {
185 if let crate::utils::hooks::api_query_hook_helper::ApiQueryResult::Success {
186 result: updates,
187 uuid,
188 ..
189 } = result
190 {
191 if !updates.is_empty() {
192 let project_skill = find_project_skill();
193 let skill_name = project_skill
194 .as_ref()
195 .map(|s| s.skill_name.clone())
196 .unwrap_or_else(|| "unknown".to_string());
197
198 log_event(
199 "tengu_skill_improvement_detected",
200 &serde_json::json!({
201 "updateCount": updates.len(),
202 "uuid": uuid,
203 "skill_name": skill_name,
204 }),
205 );
206
207 log::debug!(
210 "Skill improvement detected for '{}': {} updates",
211 skill_name,
212 updates.len()
213 );
214 }
215 }
216 }),
217 get_model: Box::new(|_context| get_small_fast_model()),
218 };
219
220 let boxed_hook = create_api_query_hook(config);
221 Arc::from(boxed_hook)
222}
223
224pub fn init_skill_improvement() {
226 let skill_improvement_enabled = true; let copper_panda_enabled = false; if skill_improvement_enabled && copper_panda_enabled {
231 let hook = create_skill_improvement_hook();
232 register_post_sampling_hook(hook);
233 }
234}
235
236pub async fn apply_skill_improvement(skill_name: &str, updates: &[SkillUpdate]) {
239 if skill_name.is_empty() {
240 return;
241 }
242
243 let cwd = std::env::current_dir().unwrap_or_default();
245 let file_path = cwd
246 .join(".claude")
247 .join("skills")
248 .join(skill_name)
249 .join("SKILL.md");
250
251 let current_content = match tokio::fs::read_to_string(&file_path).await {
252 Ok(content) => content,
253 Err(_) => {
254 log::error!("Failed to read skill file for improvement: {:?}", file_path);
255 return;
256 }
257 };
258
259 let update_list: String = updates
260 .iter()
261 .map(|u| format!("- {}: {}", u.section, u.change))
262 .collect::<Vec<_>>()
263 .join("\n");
264
265 let prompt = format!(
266 r#"You are editing a skill definition file. Apply the following improvements to the skill.
267
268<current_skill_file>
269{}
270</current_skill_file>
271
272<improvements>
273{}
274</improvements>
275
276Rules:
277- Integrate the improvements naturally into the existing structure
278- Preserve frontmatter (--- block) exactly as-is
279- Preserve the overall format and style
280- Do not remove existing content unless an improvement explicitly replaces it
281- Output the complete updated file inside <updated_file> tags"#,
282 current_content, update_list
283 );
284
285 log::debug!(
288 "Would apply skill improvements for '{}': {}",
289 skill_name,
290 update_list
291 );
292}
293
/// Returns the text between the first `<tag_name>` and the first
/// `</tag_name>` that follows it, or `None` if either tag is missing.
///
/// Matching is literal: tags with attributes or whitespace are not recognized,
/// and nesting is not tracked.
fn extract_tag(content: &str, tag_name: &str) -> Option<String> {
    let opening = format!("<{}>", tag_name);
    let closing = format!("</{}>", tag_name);

    let (_, after_opening) = content.split_once(opening.as_str())?;
    let (inner, _) = after_opening.split_once(closing.as_str())?;
    Some(inner.to_owned())
}
307
/// Model identifier used for the lightweight skill-analysis query.
fn get_small_fast_model() -> String {
    String::from("claude-3-haiku-20240307")
}
312
313fn log_event(event_name: &str, metadata: &serde_json::Value) {
315 log::debug!("Analytics event: {} - {:?}", event_name, metadata);
316}
317
318impl Message {
320 fn is_user(&self) -> bool {
321 matches!(self.role, crate::types::api_types::MessageRole::User)
322 }
323
324 fn is_assistant(&self) -> bool {
325 matches!(self.role, crate::types::api_types::MessageRole::Assistant)
326 }
327}
328
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::api_types::MessageRole;

    /// Builds a message with an explicit role, so tests do not depend on
    /// whatever `MessageRole::default()` happens to be.
    fn message(role: MessageRole, content: &str) -> Message {
        Message {
            role,
            content: content.to_string(),
            ..Default::default()
        }
    }

    #[test]
    fn test_extract_tag() {
        let content = "Some text <updates>[{\"section\": \"test\", \"change\": \"add\", \"reason\": \"because\"}]</updates> more text";
        let result = extract_tag(content, "updates");
        assert!(result.is_some());
        let updates_json = result.unwrap();
        assert!(updates_json.contains("section"));

        // The extracted payload must round-trip into the typed update list.
        let updates: Vec<SkillUpdate> = serde_json::from_str(&updates_json).unwrap();
        assert_eq!(updates.len(), 1);
        assert_eq!(updates[0].section, "test");
        assert_eq!(updates[0].change, "add");
        assert_eq!(updates[0].reason, "because");
    }

    #[test]
    fn test_extract_tag_empty() {
        let content = "<updates>[]</updates>";
        let result = extract_tag(content, "updates");
        assert_eq!(result, Some("[]".to_string()));
    }

    #[test]
    fn test_extract_tag_not_found() {
        let content = "No tags here";
        let result = extract_tag(content, "updates");
        assert!(result.is_none());
    }

    #[test]
    fn test_extract_tag_unclosed() {
        assert_eq!(extract_tag("<updates>[] trailing", "updates"), None);
    }

    #[test]
    fn test_format_recent_messages() {
        let messages = vec![
            message(MessageRole::User, "Hello"),
            message(MessageRole::Assistant, "Hi there"),
        ];
        let result = format_recent_messages(&messages);
        assert_eq!(result, "User: Hello\n\nAssistant: Hi there");
    }

    #[test]
    fn test_format_recent_messages_truncates_by_chars() {
        // Multi-byte characters: truncation must count chars, not bytes.
        let long = "é".repeat(600);
        let messages = vec![message(MessageRole::User, &long)];
        let result = format_recent_messages(&messages);
        assert_eq!(result, format!("User: {}", "é".repeat(500)));
    }

    #[test]
    fn test_format_recent_messages_empty() {
        assert_eq!(format_recent_messages(&[]), "");
    }

    #[test]
    fn test_count_messages() {
        let messages = vec![
            message(MessageRole::User, "msg1"),
            message(MessageRole::Assistant, "msg2"),
            message(MessageRole::User, "msg3"),
        ];
        assert_eq!(count_messages(&messages, |_| true), 3);
        assert_eq!(count_messages(&messages, |m| m.is_user()), 2);
        assert_eq!(count_messages(&messages, |m| m.is_assistant()), 1);
        assert_eq!(count_messages(&[], |_| true), 0);
    }
}