#![allow(dead_code)]
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use crate::types::Message;
use crate::utils::hooks::api_query_hook_helper::{
ApiQueryHookConfig, ReplHookContext, create_api_query_hook,
};
use crate::utils::hooks::post_sampling_hooks::register_post_sampling_hook;
/// Minimum number of user messages that must have accumulated since the last
/// analysis before the skill-improvement hook agrees to run again.
const TURN_BATCH_SIZE: usize = 5;
/// One proposed change to a skill definition, as returned by the analysis model
/// in the JSON array inside `<updates>` tags.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct SkillUpdate {
/// Which step/section of the skill to modify (or "new step").
pub section: String,
/// What to add or modify.
pub change: String,
/// Which user message prompted the change.
pub reason: String,
}
/// A set of proposed updates together with the name of the skill they target.
#[derive(Debug, Clone)]
pub struct SkillImprovementSuggestion {
/// Name of the skill the updates apply to.
pub skill_name: String,
/// The individual proposed changes.
pub updates: Vec<SkillUpdate>,
}
/// Cursors remembering how much of the conversation has already been analyzed.
struct SkillImprovementState {
// Number of user messages present when `should_run` last approved an analysis.
last_analyzed_count: usize,
// Length of the message list when `build_messages` last ran; messages from
// this index onward are treated as new on the next run.
last_analyzed_index: usize,
}
// Process-wide state for the skill-improvement hook, initialized on first
// access with both cursors at zero. It is locked by the hook's `should_run`
// and `build_messages` callbacks.
lazy_static::lazy_static! {
static ref SKILL_IMPROVEMENT_STATE: Arc<Mutex<SkillImprovementState>> = Arc::new(Mutex::new(
SkillImprovementState {
last_analyzed_count: 0,
last_analyzed_index: 0,
}
));
}
/// Looks up the project skill the hook should analyze.
///
/// Currently a stub that always returns `None`; as a consequence the hook's
/// `should_run` callback never approves a run and `build_messages` produces no
/// messages.
fn find_project_skill() -> Option<ProjectSkillInfo> {
None
}
/// Description of a project skill as returned by `find_project_skill`.
struct ProjectSkillInfo {
// Name reported in the analytics event and debug log when updates are found.
skill_name: String,
// Not read anywhere in this file; presumably the location of the skill file
// — TODO confirm against the real lookup once it exists.
skill_path: String,
// Skill definition text embedded in the analysis prompt.
content: String,
}
/// Renders user and assistant messages as a plain-text transcript.
///
/// Messages with any other role are skipped. Each remaining message becomes
/// `"<Role>: <content>"`, where the content is cut to its first 500 characters
/// (counted as `char`s, so multi-byte text is never split). Entries are
/// separated by a blank line.
fn format_recent_messages(messages: &[Message]) -> String {
    let mut entries: Vec<String> = Vec::new();
    for message in messages {
        let speaker = if message.is_user() {
            "User"
        } else if message.is_assistant() {
            "Assistant"
        } else {
            // Neither side of the dialogue: leave it out of the transcript.
            continue;
        };
        let excerpt: String = message.content.chars().take(500).collect();
        entries.push(format!("{}: {}", speaker, excerpt));
    }
    entries.join("\n\n")
}
/// Returns how many of `messages` satisfy `predicate`.
fn count_messages<F>(messages: &[Message], predicate: F) -> usize
where
    F: Fn(&Message) -> bool,
{
    let mut matching = 0usize;
    for message in messages {
        if predicate(message) {
            matching += 1;
        }
    }
    matching
}
/// Builds the skill-improvement hook.
///
/// The hook is an `ApiQueryHookConfig` handed to `create_api_query_hook` and
/// returned as a shared closure. Its callbacks behave as follows:
///
/// * `should_run` — approves a run only when the query source is
///   `"repl_main_thread"`, a project skill is found, and at least
///   [`TURN_BATCH_SIZE`] user messages have arrived since the last approved run.
/// * `build_messages` — produces a single user message whose prompt embeds the
///   skill definition and the messages added since the previous call.
/// * `parse_response` — reads the JSON array inside `<updates>` tags; a missing
///   tag or malformed JSON yields an empty list.
/// * `log_result` — on success with at least one update, emits the
///   `tengu_skill_improvement_detected` event and a debug log line.
/// * `get_model` — selects the small fast model.
///
/// Both cursors in `SKILL_IMPROVEMENT_STATE` are pulled back whenever the
/// message list is shorter than it was at the last analysis, so a shrinking
/// history can neither underflow the turn counter nor index past the end.
fn create_skill_improvement_hook() -> Arc<
    dyn Fn(ReplHookContext) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>
        + Send
        + Sync,
> {
    let config: ApiQueryHookConfig<Vec<SkillUpdate>> = ApiQueryHookConfig {
        name: "skill_improvement".to_string(),
        should_run: Box::new(|context| {
            let query_source = context.query_source.clone();
            // Only two numbers are needed inside the future, so compute them
            // here instead of cloning the whole message history.
            let user_count = count_messages(&context.messages, |m| m.is_user());
            let message_count = context.messages.len();
            Box::pin(async move {
                if query_source
                    .as_ref()
                    .map_or(true, |s| s != "repl_main_thread")
                {
                    return false;
                }
                if find_project_skill().is_none() {
                    return false;
                }
                // The state is two plain counters, so it is still usable even
                // if another holder of the lock panicked.
                let mut state = SKILL_IMPROVEMENT_STATE
                    .lock()
                    .unwrap_or_else(|poisoned| poisoned.into_inner());
                // If the history is shorter than at the last analysis, restart
                // both cursors from its current size. Without this the
                // subtraction below would underflow.
                if user_count < state.last_analyzed_count {
                    state.last_analyzed_count = user_count;
                }
                if message_count < state.last_analyzed_index {
                    state.last_analyzed_index = message_count;
                }
                if user_count - state.last_analyzed_count < TURN_BATCH_SIZE {
                    return false;
                }
                state.last_analyzed_count = user_count;
                true
            })
        }),
        build_messages: Box::new(|context| {
            let project_skill = match find_project_skill() {
                Some(s) => s,
                None => return Vec::new(),
            };
            // Advance the cursor and release the lock before doing any
            // formatting work.
            let start = {
                let mut state = SKILL_IMPROVEMENT_STATE
                    .lock()
                    .unwrap_or_else(|poisoned| poisoned.into_inner());
                let total = context.messages.len();
                // Clamp so a stale index can never slice past the end.
                let start = state.last_analyzed_index.min(total);
                state.last_analyzed_index = total;
                start
            };
            let formatted = format_recent_messages(&context.messages[start..]);
            let prompt = format!(
                r#"You are analyzing a conversation where a user is executing a skill (a repeatable process).
Your job: identify if the user's recent messages contain preferences, requests, or corrections that should be permanently added to the skill definition for future runs.
<skill_definition>
{}
</skill_definition>
<recent_messages>
{}
</recent_messages>
Look for:
- Requests to add, change, or remove steps: "can you also ask me X", "please do Y too", "don't do Z"
- Preferences about how steps should work: "ask me about energy levels", "note the time", "use a casual tone"
- Corrections: "no, do X instead", "always use Y", "make sure to..."
Ignore:
- Routine conversation that doesn't generalize (one-time answers, chitchat)
- Things the skill already does
Output a JSON array inside <updates> tags. Each item: {{"section": "which step/section to modify or 'new step'", "change": "what to add/modify", "reason": "which user message prompted this"}}.
Output <updates>[]</updates> if no updates are needed."#,
                project_skill.content, formatted
            );
            vec![Message {
                role: crate::types::api_types::MessageRole::User,
                content: prompt,
                attachments: None,
                tool_call_id: None,
                tool_calls: None,
                is_error: None,
                is_meta: None,
                is_api_error_message: None,
                error_details: None,
                uuid: None,
            }]
        }),
        system_prompt: None,
        use_tools: Some(false),
        parse_response: Box::new(|content, _context| {
            if let Some(updates_str) = extract_tag(content, "updates") {
                match serde_json::from_str::<Vec<SkillUpdate>>(&updates_str) {
                    Ok(updates) => updates,
                    Err(err) => {
                        // Malformed model output is treated as "no updates",
                        // but leave a trace so it can be diagnosed.
                        log::debug!("Failed to parse skill improvement updates: {}", err);
                        Vec::new()
                    }
                }
            } else {
                Vec::new()
            }
        }),
        log_result: Box::new(|result, _context| {
            if let crate::utils::hooks::api_query_hook_helper::ApiQueryResult::Success {
                result: updates,
                uuid,
                ..
            } = result
            {
                if !updates.is_empty() {
                    let project_skill = find_project_skill();
                    let skill_name = project_skill
                        .as_ref()
                        .map(|s| s.skill_name.clone())
                        .unwrap_or_else(|| "unknown".to_string());
                    log_event(
                        "tengu_skill_improvement_detected",
                        &serde_json::json!({
                            "updateCount": updates.len(),
                            "uuid": uuid,
                            "skill_name": skill_name,
                        }),
                    );
                    log::debug!(
                        "Skill improvement detected for '{}': {} updates",
                        skill_name,
                        updates.len()
                    );
                }
            }
        }),
        get_model: Box::new(|_context| get_small_fast_model()),
    };
    let boxed_hook = create_api_query_hook(config);
    Arc::from(boxed_hook)
}
/// Registers the skill-improvement hook as a post-sampling hook when both
/// feature switches are on.
///
/// The switches are hard-coded here (`skill_improvement_enabled = true`,
/// `copper_panda_enabled = false`), so at present the hook is never registered.
pub fn init_skill_improvement() {
    let skill_improvement_enabled = true;
    let copper_panda_enabled = false;
    if !(skill_improvement_enabled && copper_panda_enabled) {
        return;
    }
    register_post_sampling_hook(create_skill_improvement_hook());
}
pub async fn apply_skill_improvement(skill_name: &str, updates: &[SkillUpdate]) {
if skill_name.is_empty() {
return;
}
let cwd = std::env::current_dir().unwrap_or_default();
let file_path = cwd
.join(".claude")
.join("skills")
.join(skill_name)
.join("SKILL.md");
let current_content = match tokio::fs::read_to_string(&file_path).await {
Ok(content) => content,
Err(_) => {
log::error!("Failed to read skill file for improvement: {:?}", file_path);
return;
}
};
let update_list: String = updates
.iter()
.map(|u| format!("- {}: {}", u.section, u.change))
.collect::<Vec<_>>()
.join("\n");
let prompt = format!(
r#"You are editing a skill definition file. Apply the following improvements to the skill.
<current_skill_file>
{}
</current_skill_file>
<improvements>
{}
</improvements>
Rules:
- Integrate the improvements naturally into the existing structure
- Preserve frontmatter (--- block) exactly as-is
- Preserve the overall format and style
- Do not remove existing content unless an improvement explicitly replaces it
- Output the complete updated file inside <updated_file> tags"#,
current_content, update_list
);
log::debug!(
"Would apply skill improvements for '{}': {}",
skill_name,
update_list
);
}
/// Returns the text between the first `<tag_name>` and the next
/// `</tag_name>` that follows it, or `None` when either tag is missing.
///
/// The tags must match exactly (no attributes); the enclosed text is returned
/// untrimmed.
fn extract_tag(content: &str, tag_name: &str) -> Option<String> {
    let opening = format!("<{}>", tag_name);
    let closing = format!("</{}>", tag_name);
    // Everything after the first opening tag...
    let (_, after_open) = content.split_once(&opening)?;
    // ...up to the first closing tag that follows it.
    let (inner, _) = after_open.split_once(&closing)?;
    Some(inner.to_string())
}
/// Returns the identifier of the model used for the skill-improvement query.
fn get_small_fast_model() -> String {
    String::from("claude-3-haiku-20240307")
}
/// Records an analytics event.
///
/// In this file the event is only written to the debug log, with the metadata
/// rendered through its `Debug` representation.
fn log_event(event_name: &str, metadata: &serde_json::Value) {
log::debug!("Analytics event: {} - {:?}", event_name, metadata);
}
/// Role predicates used by the transcript helpers in this module.
impl Message {
    /// True when the message's role is `MessageRole::User`.
    fn is_user(&self) -> bool {
        use crate::types::api_types::MessageRole;
        matches!(self.role, MessageRole::User)
    }

    /// True when the message's role is `MessageRole::Assistant`.
    fn is_assistant(&self) -> bool {
        use crate::types::api_types::MessageRole;
        matches!(self.role, MessageRole::Assistant)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::api_types::MessageRole;

    /// Builds a message with an explicit role so the tests do not depend on
    /// whatever `MessageRole::default()` happens to be.
    fn message(role: MessageRole, content: &str) -> Message {
        Message {
            role,
            content: content.to_string(),
            ..Default::default()
        }
    }

    #[test]
    fn test_extract_tag() {
        let content = "Some text <updates>[{\"section\": \"test\", \"change\": \"add\", \"reason\": \"because\"}]</updates> more text";
        let result = extract_tag(content, "updates");
        assert!(result.is_some());
        let updates = result.unwrap();
        assert!(updates.contains("section"));
    }

    #[test]
    fn test_extract_tag_payload_parses_as_updates() {
        let content = "<updates>[{\"section\": \"test\", \"change\": \"add\", \"reason\": \"because\"}]</updates>";
        let payload = extract_tag(content, "updates").expect("tag is present");
        let updates: Vec<SkillUpdate> =
            serde_json::from_str(&payload).expect("payload is valid JSON");
        assert_eq!(updates.len(), 1);
        assert_eq!(updates[0].section, "test");
        assert_eq!(updates[0].change, "add");
        assert_eq!(updates[0].reason, "because");
    }

    #[test]
    fn test_extract_tag_empty() {
        let content = "<updates>[]</updates>";
        let result = extract_tag(content, "updates");
        assert_eq!(result, Some("[]".to_string()));
    }

    #[test]
    fn test_extract_tag_not_found() {
        let content = "No tags here";
        let result = extract_tag(content, "updates");
        assert!(result.is_none());
    }

    #[test]
    fn test_extract_tag_missing_close() {
        assert!(extract_tag("<updates>[", "updates").is_none());
    }

    #[test]
    fn test_format_recent_messages() {
        let messages = vec![
            message(MessageRole::User, "Hello"),
            message(MessageRole::Assistant, "Hi there"),
        ];
        let result = format_recent_messages(&messages);
        assert_eq!(result, "User: Hello\n\nAssistant: Hi there");
    }

    #[test]
    fn test_format_recent_messages_truncates_long_content() {
        let long = "a".repeat(600);
        let result = format_recent_messages(&[message(MessageRole::User, &long)]);
        assert_eq!(result, format!("User: {}", "a".repeat(500)));
    }

    #[test]
    fn test_format_recent_messages_empty() {
        assert_eq!(format_recent_messages(&[]), "");
    }

    #[test]
    fn test_count_messages() {
        let messages = vec![
            message(MessageRole::User, "msg1"),
            message(MessageRole::Assistant, "msg2"),
            message(MessageRole::User, "msg3"),
        ];
        assert_eq!(count_messages(&messages, |_| true), 3);
        assert_eq!(count_messages(&messages, |m| m.is_user()), 2);
        assert_eq!(count_messages(&messages, |m| m.is_assistant()), 1);
        assert_eq!(count_messages(&[], |_| true), 0);
    }
}