// aster/skills/registry.rs

//! Skill Registry
//!
//! Manages skill discovery, registration, and lookup.

use super::loader::{load_skills_from_directory, load_skills_from_plugin_cache};
use super::types::{InvokedSkill, SkillDefinition, SkillSource};
use std::collections::HashMap;
use std::fmt::Write as _;
use std::path::{Path, PathBuf};
use std::sync::{Arc, RwLock};
use std::time::{SystemTime, UNIX_EPOCH};

12/// Skill registry for managing all available skills
13#[derive(Debug, Default)]
14pub struct SkillRegistry {
15    /// Registered skills by name
16    skills: HashMap<String, SkillDefinition>,
17    /// Invoked skills history
18    invoked: HashMap<String, InvokedSkill>,
19    /// Whether skills have been loaded
20    loaded: bool,
21}
22
23impl SkillRegistry {
24    /// Create a new empty registry
25    pub fn new() -> Self {
26        Self::default()
27    }
28
29    /// Check if skills have been loaded
30    pub fn is_loaded(&self) -> bool {
31        self.loaded
32    }
33
34    /// Get all registered skills
35    pub fn get_all(&self) -> Vec<&SkillDefinition> {
36        self.skills.values().collect()
37    }
38
39    /// Get skill count
40    pub fn len(&self) -> usize {
41        self.skills.len()
42    }
43
44    /// Check if registry is empty
45    pub fn is_empty(&self) -> bool {
46        self.skills.is_empty()
47    }
48
49    /// Register a skill
50    pub fn register(&mut self, skill: SkillDefinition) {
51        self.skills.insert(skill.skill_name.clone(), skill);
52    }
53
54    /// Unregister a skill by name
55    pub fn unregister(&mut self, skill_name: &str) -> Option<SkillDefinition> {
56        self.skills.remove(skill_name)
57    }
58
59    /// Find a skill by name (supports namespace lookup)
60    pub fn find(&self, skill_input: &str) -> Option<&SkillDefinition> {
61        // 1. Exact match
62        if let Some(skill) = self.skills.get(skill_input) {
63            return Some(skill);
64        }
65
66        // 2. If no namespace, try to find first matching short name
67        if !skill_input.contains(':') {
68            for skill in self.skills.values() {
69                if skill.short_name() == skill_input {
70                    return Some(skill);
71                }
72            }
73        }
74
75        None
76    }
77
78    /// Get skills by source
79    pub fn get_by_source(&self, source: SkillSource) -> Vec<&SkillDefinition> {
80        self.skills
81            .values()
82            .filter(|s| s.source == source)
83            .collect()
84    }
85
86    /// Get user-invocable skills
87    pub fn get_user_invocable(&self) -> Vec<&SkillDefinition> {
88        self.skills.values().filter(|s| s.user_invocable).collect()
89    }
90
91    /// Record an invoked skill
92    pub fn record_invoked(&mut self, skill_name: &str, skill_path: &Path, content: &str) {
93        let timestamp = SystemTime::now()
94            .duration_since(UNIX_EPOCH)
95            .map(|d| d.as_secs())
96            .unwrap_or(0);
97
98        self.invoked.insert(
99            skill_name.to_string(),
100            InvokedSkill {
101                skill_name: skill_name.to_string(),
102                skill_path: skill_path.to_path_buf(),
103                content: content.to_string(),
104                invoked_at: timestamp,
105            },
106        );
107    }
108
109    /// Get invoked skills
110    pub fn get_invoked(&self) -> &HashMap<String, InvokedSkill> {
111        &self.invoked
112    }
113
114    /// Clear invoked skills history
115    pub fn clear_invoked(&mut self) {
116        self.invoked.clear();
117    }
118
119    /// Clear all skills and reset loaded state
120    pub fn clear(&mut self) {
121        self.skills.clear();
122        self.invoked.clear();
123        self.loaded = false;
124    }
125
126    /// Get default skill directories
127    pub fn get_default_directories() -> Vec<(PathBuf, SkillSource)> {
128        let mut dirs = Vec::new();
129
130        // User-level directories
131        if let Some(home) = dirs::home_dir() {
132            dirs.push((home.join(".claude/skills"), SkillSource::User));
133        }
134
135        // Project-level directories
136        if let Ok(cwd) = std::env::current_dir() {
137            dirs.push((cwd.join(".claude/skills"), SkillSource::Project));
138        }
139
140        dirs
141    }
142
143    /// Initialize and load all skills
144    ///
145    /// Loading order (later overrides earlier):
146    /// 1. Plugin skills (lowest priority)
147    /// 2. User skills (~/.claude/skills/)
148    /// 3. Project skills (.claude/skills/) (highest priority)
149    pub fn initialize(&mut self) {
150        if self.loaded {
151            return;
152        }
153
154        self.skills.clear();
155
156        // 1. Load plugin skills (lowest priority)
157        for skill in load_skills_from_plugin_cache() {
158            self.skills.insert(skill.skill_name.clone(), skill);
159        }
160
161        // 2. Load from default directories
162        for (dir, source) in Self::get_default_directories() {
163            for skill in load_skills_from_directory(&dir, source) {
164                self.skills.insert(skill.skill_name.clone(), skill);
165            }
166        }
167
168        self.loaded = true;
169    }
170
171    /// Reload skills (clear and reinitialize)
172    pub fn reload(&mut self) {
173        self.loaded = false;
174        self.initialize();
175    }
176
177    /// Generate instructions for available skills
178    pub fn generate_instructions(&self) -> String {
179        if self.skills.is_empty() {
180            return String::new();
181        }
182
183        let mut instructions =
184            String::from("You have these skills at your disposal. Use them when relevant:\n\n");
185
186        let mut skill_list: Vec<_> = self.skills.values().collect();
187        skill_list.sort_by_key(|s| &s.skill_name);
188
189        for skill in skill_list {
190            instructions.push_str(&format!("- {}: {}\n", skill.skill_name, skill.description));
191        }
192
193        instructions
194    }
195}
196
197/// Thread-safe shared skill registry
198pub type SharedSkillRegistry = Arc<RwLock<SkillRegistry>>;
199
200/// Create a new shared skill registry
201pub fn new_shared_registry() -> SharedSkillRegistry {
202    Arc::new(RwLock::new(SkillRegistry::new()))
203}
204
205/// Global skill registry instance
206static GLOBAL_REGISTRY: std::sync::OnceLock<SharedSkillRegistry> = std::sync::OnceLock::new();
207
208/// Get the global skill registry
209pub fn global_registry() -> &'static SharedSkillRegistry {
210    GLOBAL_REGISTRY.get_or_init(|| {
211        let registry = new_shared_registry();
212        if let Ok(mut r) = registry.write() {
213            r.initialize();
214        }
215        registry
216    })
217}
218
#[cfg(test)]
mod tests {
    use super::*;
    use crate::skills::types::SkillExecutionMode;

    /// Build a minimal user-invocable skill whose full name is `"{source}:{name}"`.
    fn create_test_skill(name: &str, source: SkillSource) -> SkillDefinition {
        SkillDefinition {
            skill_name: format!("{}:{}", source, name),
            display_name: name.to_string(),
            description: format!("Test skill: {}", name),
            has_user_specified_description: true,
            markdown_content: "# Content".to_string(),
            allowed_tools: None,
            argument_hint: None,
            when_to_use: None,
            version: None,
            model: None,
            disable_model_invocation: false,
            user_invocable: true,
            source,
            base_dir: PathBuf::from("/test"),
            file_path: PathBuf::from("/test/SKILL.md"),
            supporting_files: vec![],
            execution_mode: SkillExecutionMode::default(),
            provider: None,
            workflow: None,
        }
    }

    #[test]
    fn test_registry_new() {
        let registry = SkillRegistry::new();
        assert!(!registry.is_loaded());
        assert!(registry.is_empty());
    }

    #[test]
    fn test_registry_register_and_find() {
        let mut registry = SkillRegistry::new();
        registry.register(create_test_skill("my-skill", SkillSource::User));

        assert_eq!(registry.len(), 1);
        // Manual registration does not mark the registry as loaded.
        assert!(!registry.is_loaded());

        // Exact match on the full name
        let found = registry
            .find("user:my-skill")
            .expect("full name should resolve");
        assert_eq!(found.display_name, "my-skill");

        // Short-name match resolves to the only registered skill
        let found = registry.find("my-skill").expect("short name should resolve");
        assert_eq!(found.skill_name, "user:my-skill");

        // Unknown names, and namespaced names without an exact match, resolve to nothing
        assert!(registry.find("missing").is_none());
        assert!(registry.find("other:my-skill").is_none());
    }

    #[test]
    fn test_registry_unregister() {
        let mut registry = SkillRegistry::new();
        registry.register(create_test_skill("to-remove", SkillSource::User));
        assert_eq!(registry.len(), 1);

        let removed = registry.unregister("user:to-remove");
        assert!(removed.is_some());
        assert_eq!(registry.len(), 0);

        // Removing again is a no-op
        assert!(registry.unregister("user:to-remove").is_none());
    }

    #[test]
    fn test_registry_get_by_source() {
        let mut registry = SkillRegistry::new();

        registry.register(create_test_skill("user-skill", SkillSource::User));
        registry.register(create_test_skill("project-skill", SkillSource::Project));
        registry.register(create_test_skill("plugin-skill", SkillSource::Plugin));
        assert_eq!(registry.len(), 3);

        assert_eq!(registry.get_by_source(SkillSource::User).len(), 1);
        assert_eq!(registry.get_by_source(SkillSource::Project).len(), 1);
        assert_eq!(registry.get_by_source(SkillSource::Plugin).len(), 1);
    }

    #[test]
    fn test_registry_record_invoked() {
        let mut registry = SkillRegistry::new();

        registry.record_invoked("test-skill", Path::new("/test/SKILL.md"), "skill content");

        let invoked = registry.get_invoked();
        assert_eq!(invoked.len(), 1);
        assert!(invoked.contains_key("test-skill"));

        // A second record for the same name replaces the first
        registry.record_invoked("test-skill", Path::new("/test/SKILL.md"), "updated content");
        let invoked = registry.get_invoked();
        assert_eq!(invoked.len(), 1);
        assert_eq!(invoked["test-skill"].content, "updated content");

        registry.clear_invoked();
        assert!(registry.get_invoked().is_empty());
    }

    #[test]
    fn test_registry_generate_instructions() {
        let mut registry = SkillRegistry::new();

        // Empty registry
        assert!(registry.generate_instructions().is_empty());

        // With skills
        registry.register(create_test_skill("alpha", SkillSource::User));
        registry.register(create_test_skill("beta", SkillSource::Project));

        let instructions = registry.generate_instructions();
        assert!(instructions.starts_with("You have these skills at your disposal."));
        assert!(instructions.contains("- user:alpha: Test skill: alpha\n"));
        assert!(instructions.contains("beta"));

        // One entry per skill, emitted in sorted order
        let entries: Vec<&str> = instructions
            .lines()
            .filter(|line| line.starts_with("- "))
            .collect();
        assert_eq!(entries.len(), 2);
        let mut sorted = entries.clone();
        sorted.sort_unstable();
        assert_eq!(entries, sorted);
    }

    #[test]
    fn test_registry_clear() {
        let mut registry = SkillRegistry::new();
        registry.register(create_test_skill("skill", SkillSource::User));
        registry.record_invoked("skill", Path::new("/test"), "content");

        registry.clear();

        assert!(registry.is_empty());
        assert!(registry.get_invoked().is_empty());
        assert!(!registry.is_loaded());
    }

    #[test]
    fn test_shared_registry() {
        let registry = new_shared_registry();

        {
            let mut r = registry.write().unwrap();
            r.register(create_test_skill("shared-skill", SkillSource::User));
        }

        {
            let r = registry.read().unwrap();
            assert_eq!(r.len(), 1);
        }
    }
}