Skip to main content

ai_agent/utils/hooks/
skill_improvement.rs

1// Source: ~/claudecode/openclaudecode/src/utils/hooks/skillImprovement.ts
2#![allow(dead_code)]
3
4use std::collections::HashMap;
5use std::sync::{Arc, Mutex};
6
7use crate::types::Message;
8use crate::utils::hooks::api_query_hook_helper::{
9    ApiQueryHookConfig, ReplHookContext, create_api_query_hook,
10};
11use crate::utils::hooks::post_sampling_hooks::register_post_sampling_hook;
12
/// Minimum number of additional user messages that must have arrived since
/// the previous analysis before the skill-improvement hook runs again.
const TURN_BATCH_SIZE: usize = 5;
15
16/// A skill update suggestion
17#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
18pub struct SkillUpdate {
19    pub section: String,
20    pub change: String,
21    pub reason: String,
22}
23
24/// Skill improvement suggestion
25#[derive(Debug, Clone)]
26pub struct SkillImprovementSuggestion {
27    pub skill_name: String,
28    pub updates: Vec<SkillUpdate>,
29}
30
/// Bookkeeping that lets the hook analyze each part of the conversation once.
struct SkillImprovementState {
    /// Number of user messages seen the last time the hook decided to run.
    last_analyzed_count: usize,
    /// Length of the message list the last time a prompt was built; messages
    /// from this index onward have not been analyzed yet.
    last_analyzed_index: usize,
}
36
37lazy_static::lazy_static! {
38    static ref SKILL_IMPROVEMENT_STATE: Arc<Mutex<SkillImprovementState>> = Arc::new(Mutex::new(
39        SkillImprovementState {
40            last_analyzed_count: 0,
41            last_analyzed_index: 0,
42        }
43    ));
44}
45
46/// Find the project skill (simplified)
47fn find_project_skill() -> Option<ProjectSkillInfo> {
48    // In the TS version, this calls getInvokedSkillsForAgent
49    // and looks for skills starting with "projectSettings:"
50    None
51}
52
/// Description of a project skill as seen by the improvement hook.
struct ProjectSkillInfo {
    /// Name reported in analytics and debug logs.
    skill_name: String,
    /// Location of the skill; not read anywhere in this file.
    skill_path: String,
    /// Full text of the skill definition, embedded into the analysis prompt.
    content: String,
}
59
60/// Format recent messages for the skill improvement prompt
61fn format_recent_messages(messages: &[Message]) -> String {
62    messages
63        .iter()
64        .filter(|m| m.is_user() || m.is_assistant())
65        .map(|m| {
66            let role = if m.is_user() { "User" } else { "Assistant" };
67            let content = m.content.chars().take(500).collect::<String>();
68            format!("{}: {}", role, content)
69        })
70        .collect::<Vec<_>>()
71        .join("\n\n")
72}
73
74/// Count messages matching a predicate
75fn count_messages<F>(messages: &[Message], predicate: F) -> usize
76where
77    F: Fn(&Message) -> bool,
78{
79    messages.iter().filter(|m| predicate(m)).count()
80}
81
82/// Create the skill improvement hook
83fn create_skill_improvement_hook() -> Arc<
84    dyn Fn(ReplHookContext) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>
85        + Send
86        + Sync,
87> {
88    let config: ApiQueryHookConfig<Vec<SkillUpdate>> = ApiQueryHookConfig {
89        name: "skill_improvement".to_string(),
90        should_run: Box::new(|context| {
91            let query_source = context.query_source.clone();
92            let messages = context.messages.clone();
93            Box::pin(async move {
94                // Only run for main REPL thread
95                if query_source
96                    .as_ref()
97                    .map(|s| s != "repl_main_thread")
98                    .unwrap_or(true)
99                {
100                    return false;
101                }
102
103                // Only run if there's a project skill
104                if find_project_skill().is_none() {
105                    return false;
106                }
107
108                // Only run every TURN_BATCH_SIZE user messages
109                let mut state = SKILL_IMPROVEMENT_STATE.lock().unwrap();
110                let user_count = count_messages(&messages, |m| m.is_user());
111                if user_count - state.last_analyzed_count < TURN_BATCH_SIZE {
112                    return false;
113                }
114
115                state.last_analyzed_count = user_count;
116                true
117            })
118        }),
119        build_messages: Box::new(|context| {
120            let project_skill = match find_project_skill() {
121                Some(s) => s,
122                None => return Vec::new(),
123            };
124
125            let mut state = SKILL_IMPROVEMENT_STATE.lock().unwrap();
126            // Only analyze messages since the last check
127            let new_messages = context.messages[state.last_analyzed_index..].to_vec();
128            state.last_analyzed_index = context.messages.len();
129
130            let formatted = format_recent_messages(&new_messages);
131
132            let prompt = format!(
133                r#"You are analyzing a conversation where a user is executing a skill (a repeatable process).
134Your job: identify if the user's recent messages contain preferences, requests, or corrections that should be permanently added to the skill definition for future runs.
135
136<skill_definition>
137{}
138</skill_definition>
139
140<recent_messages>
141{}
142</recent_messages>
143
144Look for:
145- Requests to add, change, or remove steps: "can you also ask me X", "please do Y too", "don't do Z"
146- Preferences about how steps should work: "ask me about energy levels", "note the time", "use a casual tone"
147- Corrections: "no, do X instead", "always use Y", "make sure to..."
148
149Ignore:
150- Routine conversation that doesn't generalize (one-time answers, chitchat)
151- Things the skill already does
152
153Output a JSON array inside <updates> tags. Each item: {{"section": "which step/section to modify or 'new step'", "change": "what to add/modify", "reason": "which user message prompted this"}}.
154Output <updates>[]</updates> if no updates are needed."#,
155                project_skill.content, formatted
156            );
157
158            vec![Message {
159                role: crate::types::api_types::MessageRole::User,
160                content: prompt,
161                attachments: None,
162                tool_call_id: None,
163                tool_calls: None,
164                is_error: None,
165                is_meta: None,
166                is_api_error_message: None,
167                error_details: None,
168                uuid: None,
169            }]
170        }),
171        system_prompt: None,
172        use_tools: Some(false),
173        parse_response: Box::new(|content, _context| {
174            // Extract content between <updates> tags
175            if let Some(updates_str) = extract_tag(content, "updates") {
176                match serde_json::from_str::<Vec<SkillUpdate>>(&updates_str) {
177                    Ok(updates) => updates,
178                    Err(_) => Vec::new(),
179                }
180            } else {
181                Vec::new()
182            }
183        }),
184        log_result: Box::new(|result, context| {
185            if let crate::utils::hooks::api_query_hook_helper::ApiQueryResult::Success {
186                result: updates,
187                uuid,
188                ..
189            } = result
190            {
191                if !updates.is_empty() {
192                    let project_skill = find_project_skill();
193                    let skill_name = project_skill
194                        .as_ref()
195                        .map(|s| s.skill_name.clone())
196                        .unwrap_or_else(|| "unknown".to_string());
197
198                    log_event(
199                        "tengu_skill_improvement_detected",
200                        &serde_json::json!({
201                            "updateCount": updates.len(),
202                            "uuid": uuid,
203                            "skill_name": skill_name,
204                        }),
205                    );
206
207                    // Update app state with suggestion
208                    // This would set context.tool_use_context.setAppState
209                    log::debug!(
210                        "Skill improvement detected for '{}': {} updates",
211                        skill_name,
212                        updates.len()
213                    );
214                }
215            }
216        }),
217        get_model: Box::new(|_context| get_small_fast_model()),
218    };
219
220    let boxed_hook = create_api_query_hook(config);
221    Arc::from(boxed_hook)
222}
223
224/// Initialize skill improvement hook
225pub fn init_skill_improvement() {
226    // Check feature flags (simplified - would use GrowthBook in production)
227    let skill_improvement_enabled = true; // feature('SKILL_IMPROVEMENT')
228    let copper_panda_enabled = false; // getFeatureValue_CACHED_MAY_BE_STALE('tengu_copper_panda', false)
229
230    if skill_improvement_enabled && copper_panda_enabled {
231        let hook = create_skill_improvement_hook();
232        register_post_sampling_hook(hook);
233    }
234}
235
236/// Apply skill improvements by calling a side-channel LLM to rewrite the skill file.
237/// Fire-and-forget - does not block the main conversation.
238pub async fn apply_skill_improvement(skill_name: &str, updates: &[SkillUpdate]) {
239    if skill_name.is_empty() {
240        return;
241    }
242
243    // Skills live at .claude/skills/<name>/SKILL.md relative to CWD
244    let cwd = std::env::current_dir().unwrap_or_default();
245    let file_path = cwd
246        .join(".claude")
247        .join("skills")
248        .join(skill_name)
249        .join("SKILL.md");
250
251    let current_content = match tokio::fs::read_to_string(&file_path).await {
252        Ok(content) => content,
253        Err(_) => {
254            log::error!("Failed to read skill file for improvement: {:?}", file_path);
255            return;
256        }
257    };
258
259    let update_list: String = updates
260        .iter()
261        .map(|u| format!("- {}: {}", u.section, u.change))
262        .collect::<Vec<_>>()
263        .join("\n");
264
265    let prompt = format!(
266        r#"You are editing a skill definition file. Apply the following improvements to the skill.
267
268<current_skill_file>
269{}
270</current_skill_file>
271
272<improvements>
273{}
274</improvements>
275
276Rules:
277- Integrate the improvements naturally into the existing structure
278- Preserve frontmatter (--- block) exactly as-is
279- Preserve the overall format and style
280- Do not remove existing content unless an improvement explicitly replaces it
281- Output the complete updated file inside <updated_file> tags"#,
282        current_content, update_list
283    );
284
285    // This would call the LLM to apply the improvements
286    // For now, just log
287    log::debug!(
288        "Would apply skill improvements for '{}': {}",
289        skill_name,
290        update_list
291    );
292}
293
/// Return the text between the first `<tag_name>` and the next `</tag_name>`.
///
/// Yields `None` when the opening tag is absent or when no closing tag
/// follows it. Tags with attributes are not recognized.
fn extract_tag(content: &str, tag_name: &str) -> Option<String> {
    let opening = format!("<{}>", tag_name);
    let closing = format!("</{}>", tag_name);

    // Drop everything up to and including the first opening tag...
    let (_, after_open) = content.split_once(opening.as_str())?;
    // ...then keep what precedes the first closing tag after it.
    let (inner, _) = after_open.split_once(closing.as_str())?;

    Some(inner.to_string())
}
307
/// Identifier of the model the hook uses for its analysis query.
fn get_small_fast_model() -> String {
    String::from("claude-3-haiku-20240307")
}
312
313/// Log event for analytics (simplified)
314fn log_event(event_name: &str, metadata: &serde_json::Value) {
315    log::debug!("Analytics event: {} - {:?}", event_name, metadata);
316}
317
318/// Message extension methods
319impl Message {
320    fn is_user(&self) -> bool {
321        matches!(self.role, crate::types::api_types::MessageRole::User)
322    }
323
324    fn is_assistant(&self) -> bool {
325        matches!(self.role, crate::types::api_types::MessageRole::Assistant)
326    }
327}
328
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::api_types::MessageRole;

    /// Build a message with an explicit role so the tests do not depend on
    /// whatever role `Message::default()` happens to produce.
    fn message(role: MessageRole, content: &str) -> Message {
        Message {
            role,
            content: content.to_string(),
            ..Default::default()
        }
    }

    #[test]
    fn test_extract_tag() {
        let content = "Some text <updates>[{\"section\": \"test\", \"change\": \"add\", \"reason\": \"because\"}]</updates> more text";
        assert_eq!(
            extract_tag(content, "updates"),
            Some(
                "[{\"section\": \"test\", \"change\": \"add\", \"reason\": \"because\"}]"
                    .to_string()
            )
        );
    }

    #[test]
    fn test_extract_tag_empty() {
        let content = "<updates>[]</updates>";
        let result = extract_tag(content, "updates");
        assert_eq!(result, Some("[]".to_string()));
    }

    #[test]
    fn test_extract_tag_not_found() {
        let content = "No tags here";
        let result = extract_tag(content, "updates");
        assert!(result.is_none());
    }

    #[test]
    fn test_extract_tag_unclosed() {
        // An opening tag without a matching closing tag yields nothing.
        assert!(extract_tag("<updates>[1]", "updates").is_none());
    }

    #[test]
    fn test_format_recent_messages() {
        let messages = vec![
            message(MessageRole::User, "Hello"),
            message(MessageRole::Assistant, "Hi there"),
        ];
        let result = format_recent_messages(&messages);
        assert_eq!(result, "User: Hello\n\nAssistant: Hi there");
    }

    #[test]
    fn test_format_recent_messages_truncates_long_content() {
        let long_text = "a".repeat(600);
        let messages = vec![message(MessageRole::User, &long_text)];
        let result = format_recent_messages(&messages);
        // Content is cut off after 500 characters.
        assert_eq!(result, format!("User: {}", "a".repeat(500)));
    }

    #[test]
    fn test_count_messages() {
        let messages = vec![
            message(MessageRole::User, "msg1"),
            message(MessageRole::Assistant, "msg2"),
            message(MessageRole::User, "msg3"),
        ];
        assert_eq!(count_messages(&messages, |_| true), 3);
        assert_eq!(count_messages(&messages, |m| m.is_user()), 2);
        assert_eq!(count_messages(&messages, |m| m.is_assistant()), 1);
        assert_eq!(count_messages(&[], |_| true), 0);
    }
}