Skip to main content

cersei_tools/
skill_tool.rs

1//! Skill tool: load and execute skill prompt templates.
2//! Supports:
3//! - `skill="list"` — list all available skills
4//! - `skill="<name>" args="<arguments>"` — load and expand a skill
5
6use super::*;
7use crate::skills::discovery;
8use serde::Deserialize;
9
/// Tool that lists available skills (prompt templates) and loads/expands them
/// on request.
///
/// Skills are discovered relative to `project_root` (when unset, the working
/// directory of the calling context is used) plus every entry in `extra_paths`.
#[derive(Debug, Clone)]
pub struct SkillTool {
    /// Project root for skill discovery; `None` means "use the context's working dir".
    project_root: Option<std::path::PathBuf>,
    /// Extra directories to search for skills.
    extra_paths: Vec<std::path::PathBuf>,
}
16
17impl SkillTool {
18    pub fn new() -> Self {
19        Self {
20            project_root: None,
21            extra_paths: Vec::new(),
22        }
23    }
24
25    pub fn with_project_root(mut self, root: impl Into<std::path::PathBuf>) -> Self {
26        self.project_root = Some(root.into());
27        self
28    }
29
30    pub fn with_extra_path(mut self, path: impl Into<std::path::PathBuf>) -> Self {
31        self.extra_paths.push(path.into());
32        self
33    }
34
35    pub fn with_extra_paths(mut self, paths: Vec<std::path::PathBuf>) -> Self {
36        self.extra_paths.extend(paths);
37        self
38    }
39}
40
41impl Default for SkillTool {
42    fn default() -> Self {
43        Self::new()
44    }
45}
46
47#[async_trait]
48impl Tool for SkillTool {
49    fn name(&self) -> &str {
50        "Skill"
51    }
52
53    fn description(&self) -> &str {
54        "Load and execute a skill (prompt template). Use skill='list' to see available skills."
55    }
56
57    fn permission_level(&self) -> PermissionLevel {
58        PermissionLevel::None
59    }
60    fn category(&self) -> ToolCategory {
61        ToolCategory::Custom
62    }
63
64    fn input_schema(&self) -> Value {
65        serde_json::json!({
66            "type": "object",
67            "properties": {
68                "skill": {
69                    "type": "string",
70                    "description": "Skill name, or 'list' to show available skills"
71                },
72                "args": {
73                    "type": "string",
74                    "description": "Arguments to pass to the skill (replaces $ARGUMENTS)"
75                }
76            },
77            "required": ["skill"]
78        })
79    }
80
81    async fn execute(&self, input: Value, ctx: &ToolContext) -> ToolResult {
82        #[derive(Deserialize)]
83        struct Input {
84            skill: String,
85            args: Option<String>,
86        }
87
88        let input: Input = match serde_json::from_value(input) {
89            Ok(i) => i,
90            Err(e) => return ToolResult::error(format!("Invalid input: {}", e)),
91        };
92
93        // List mode
94        if input.skill == "list" {
95            let project_root = self
96                .project_root
97                .as_deref()
98                .or_else(|| Some(ctx.working_dir.as_path()));
99            let skills = discovery::discover_all(project_root, &self.extra_paths);
100            return ToolResult::success(discovery::format_skill_list(&skills));
101        }
102
103        // Load mode
104        let project_root = self
105            .project_root
106            .as_deref()
107            .or_else(|| Some(ctx.working_dir.as_path()));
108        let loaded = discovery::load_skill(&input.skill, project_root, &self.extra_paths);
109
110        match loaded {
111            Some(skill) => {
112                let expanded = skill.expand(input.args.as_deref());
113
114                // Include metadata in the result
115                let mut meta = serde_json::json!({
116                    "skill_name": skill.meta.name,
117                    "format": format!("{:?}", skill.meta.format),
118                    "bundled": skill.meta.bundled,
119                });
120                if let Some(tools) = &skill.meta.allowed_tools {
121                    meta["allowed_tools"] = serde_json::json!(tools);
122                }
123
124                ToolResult::success(expanded).with_metadata(meta)
125            }
126            None => {
127                // Suggest similar skills
128                let all = discovery::discover_all(project_root, &self.extra_paths);
129                let suggestions: Vec<&str> = all
130                    .iter()
131                    .filter(|s| s.name.contains(&input.skill) || input.skill.contains(&s.name))
132                    .map(|s| s.name.as_str())
133                    .take(5)
134                    .collect();
135
136                let mut msg = format!("Skill '{}' not found.", input.skill);
137                if !suggestions.is_empty() {
138                    msg.push_str(&format!("\n\nDid you mean: {}?", suggestions.join(", ")));
139                }
140                msg.push_str("\n\nUse skill='list' to see all available skills.");
141                ToolResult::error(msg)
142            }
143        }
144    }
145}
146
#[cfg(test)]
mod tests {
    use super::*;
    use crate::permissions::AllowAll;
    use std::sync::Arc;

    /// Permissive context rooted at the system temp directory.
    fn test_ctx() -> ToolContext {
        ToolContext {
            working_dir: std::env::temp_dir(),
            session_id: "skill-test".into(),
            permissions: Arc::new(AllowAll),
            cost_tracker: Arc::new(CostTracker::new()),
            mcp_manager: None,
            extensions: Extensions::default(),
        }
    }

    /// Same as `test_ctx`, but with `dir` as the working directory.
    fn ctx_in(dir: &std::path::Path) -> ToolContext {
        ToolContext {
            working_dir: dir.to_path_buf(),
            ..test_ctx()
        }
    }

    #[tokio::test]
    async fn test_skill_list() {
        let tool = SkillTool::new();
        let r = tool
            .execute(serde_json::json!({"skill": "list"}), &test_ctx())
            .await;
        assert!(!r.is_error);
        assert!(r.content.contains("Available skills:"));
        assert!(r.content.contains("simplify"));
        assert!(r.content.contains("[bundled]"));
    }

    #[tokio::test]
    async fn test_skill_load_bundled() {
        let tool = SkillTool::new();
        let r = tool
            .execute(
                serde_json::json!({
                    "skill": "debug",
                    "args": "the login page crashes"
                }),
                &test_ctx(),
            )
            .await;
        assert!(!r.is_error);
        assert!(r.content.contains("the login page crashes"));
        assert!(!r.content.contains("$ARGUMENTS"));
        let meta = r
            .metadata
            .as_ref()
            .expect("a successful load attaches metadata");
        assert_eq!(meta["bundled"], true);
    }

    #[tokio::test]
    async fn test_skill_load_by_alias() {
        let tool = SkillTool::new();
        let r = tool
            .execute(
                serde_json::json!({"skill": "diagnose", "args": "memory leak"}),
                &test_ctx(),
            )
            .await;
        assert!(!r.is_error);
        assert!(r.content.contains("memory leak"));
    }

    #[tokio::test]
    async fn test_skill_not_found() {
        let tool = SkillTool::new();
        let r = tool
            .execute(serde_json::json!({"skill": "nonexistent"}), &test_ctx())
            .await;
        assert!(r.is_error);
        assert!(r.content.contains("not found"));
    }

    #[tokio::test]
    async fn test_skill_load_from_disk() {
        let tmp = tempfile::tempdir().unwrap();
        let cmd_dir = tmp.path().join(".claude/commands");
        std::fs::create_dir_all(&cmd_dir).unwrap();
        std::fs::write(
            cmd_dir.join("my-deploy.md"),
            "---\ndescription: Deploy the app\n---\n\nDeploy $ARGUMENTS to production.",
        )
        .unwrap();

        let tool = SkillTool::new().with_project_root(tmp.path());
        let ctx = ctx_in(tmp.path());

        // The on-disk command must show up in the listing.
        let r = tool
            .execute(serde_json::json!({"skill": "list"}), &ctx)
            .await;
        assert!(!r.is_error);
        assert!(r.content.contains("my-deploy"));

        // ...and load/expand with the supplied arguments.
        let r = tool
            .execute(
                serde_json::json!({"skill": "my-deploy", "args": "v2.0"}),
                &ctx,
            )
            .await;
        assert!(!r.is_error);
        assert!(r.content.contains("Deploy v2.0 to production"));
        let meta = r
            .metadata
            .as_ref()
            .expect("a successful load attaches metadata");
        assert_eq!(meta["format"], "Commands");
    }

    #[tokio::test]
    async fn test_skill_skills_format() {
        let tmp = tempfile::tempdir().unwrap();
        let skill_dir = tmp.path().join(".claude/skills/aws-deploy");
        std::fs::create_dir_all(&skill_dir).unwrap();
        std::fs::write(
            skill_dir.join("SKILL.md"),
            "---\nname: aws-deploy\ndescription: Deploy to AWS\n---\n\n# AWS Deploy\n\nUse CDK to deploy.",
        )
        .unwrap();

        let tool = SkillTool::new().with_project_root(tmp.path());
        let ctx = ctx_in(tmp.path());

        let r = tool
            .execute(serde_json::json!({"skill": "aws-deploy"}), &ctx)
            .await;
        assert!(!r.is_error);
        assert!(r.content.contains("CDK"));
        let meta = r
            .metadata
            .as_ref()
            .expect("a successful load attaches metadata");
        assert_eq!(meta["format"], "Skills");
    }

    #[tokio::test]
    async fn test_real_user_skills() {
        // Smoke-test compatibility with whatever lives in ~/.claude/commands/.
        // The contents differ per machine, so only environment-independent
        // properties are asserted.
        let tool = SkillTool::new();
        let r = tool
            .execute(serde_json::json!({"skill": "list"}), &test_ctx())
            .await;
        // Bundled skills are always present.
        assert!(r.content.contains("simplify"));

        // Load "design" only if this machine happens to have one.
        let r = tool
            .execute(serde_json::json!({"skill": "design"}), &test_ctx())
            .await;
        if !r.is_error {
            println!("Loaded real user skill 'design': {} chars", r.content.len());
            assert!(!r.content.trim().is_empty());
        }
    }
}