1use super::feedback::SkillScorer;
7use super::validator::SkillValidator;
8use super::Skill;
9use anyhow::Context;
10use std::collections::HashMap;
11use std::path::Path;
12use std::sync::{Arc, RwLock};
13
14pub struct SkillRegistry {
20 skills: Arc<RwLock<HashMap<String, Arc<Skill>>>>,
21 validator: Arc<RwLock<Option<Arc<dyn SkillValidator>>>>,
22 scorer: Arc<RwLock<Option<Arc<dyn SkillScorer>>>>,
23}
24
25impl SkillRegistry {
26 pub fn new() -> Self {
28 Self {
29 skills: Arc::new(RwLock::new(HashMap::new())),
30 validator: Arc::new(RwLock::new(None)),
31 scorer: Arc::new(RwLock::new(None)),
32 }
33 }
34
35 pub fn with_builtins() -> Self {
37 let registry = Self::new();
38 for skill in super::builtin::builtin_skills() {
39 registry.register_unchecked(skill);
41 }
42 registry
43 }
44
45 pub fn fork(&self) -> Self {
51 let skills = self.skills.read().unwrap().clone();
52 Self {
53 skills: Arc::new(RwLock::new(skills)),
54 validator: Arc::new(RwLock::new(None)),
55 scorer: Arc::new(RwLock::new(None)),
56 }
57 }
58
59 pub fn set_validator(&self, validator: Arc<dyn SkillValidator>) {
61 *self.validator.write().unwrap() = Some(validator);
62 }
63
64 pub fn set_scorer(&self, scorer: Arc<dyn SkillScorer>) {
66 *self.scorer.write().unwrap() = Some(scorer);
67 }
68
69 pub fn scorer(&self) -> Option<Arc<dyn SkillScorer>> {
71 self.scorer.read().unwrap().clone()
72 }
73
74 pub fn register(
79 &self,
80 skill: Arc<Skill>,
81 ) -> Result<(), super::validator::SkillValidationError> {
82 if let Some(ref validator) = *self.validator.read().unwrap() {
84 validator.validate(&skill)?;
85 }
86 self.register_unchecked(skill);
87 Ok(())
88 }
89
90 pub fn register_unchecked(&self, skill: Arc<Skill>) {
92 let mut skills = self.skills.write().unwrap();
93 skills.insert(skill.name.clone(), skill);
94 }
95
96 pub fn get(&self, name: &str) -> Option<Arc<Skill>> {
98 let skills = self.skills.read().unwrap();
99 skills.get(name).cloned()
100 }
101
102 pub fn list(&self) -> Vec<String> {
104 let skills = self.skills.read().unwrap();
105 skills.keys().cloned().collect()
106 }
107
108 pub fn all(&self) -> Vec<Arc<Skill>> {
110 let skills = self.skills.read().unwrap();
111 skills.values().cloned().collect()
112 }
113
114 pub fn load_from_dir(&self, dir: impl AsRef<Path>) -> anyhow::Result<usize> {
119 let dir = dir.as_ref();
120
121 if !dir.exists() {
122 return Ok(0);
123 }
124
125 if !dir.is_dir() {
126 anyhow::bail!("Path is not a directory: {}", dir.display());
127 }
128
129 let mut loaded = 0;
130
131 for entry in std::fs::read_dir(dir)
132 .with_context(|| format!("Failed to read directory: {}", dir.display()))?
133 {
134 let entry = entry?;
135 let path = entry.path();
136
137 if path.extension().and_then(|s| s.to_str()) != Some("md") {
139 continue;
140 }
141
142 match Skill::from_file(&path) {
144 Ok(skill) => {
145 let skill = Arc::new(skill);
146 match self.register(skill) {
147 Ok(()) => loaded += 1,
148 Err(e) => {
149 tracing::warn!("Skill validation failed for {}: {}", path.display(), e);
150 }
151 }
152 }
153 Err(e) => {
154 tracing::debug!("Skipped {}: {}", path.display(), e);
156 }
157 }
158 }
159
160 Ok(loaded)
161 }
162
163 pub fn load_from_file(&self, path: impl AsRef<Path>) -> anyhow::Result<Arc<Skill>> {
165 let skill = Skill::from_file(path)?;
166 let skill = Arc::new(skill);
167 self.register(skill.clone())
168 .map_err(|e| anyhow::anyhow!("Skill validation failed: {}", e))?;
169 Ok(skill)
170 }
171
172 pub fn remove(&self, name: &str) -> Option<Arc<Skill>> {
174 let mut skills = self.skills.write().unwrap();
175 skills.remove(name)
176 }
177
178 pub fn clear(&self) {
180 let mut skills = self.skills.write().unwrap();
181 skills.clear();
182 }
183
184 pub fn len(&self) -> usize {
186 let skills = self.skills.read().unwrap();
187 skills.len()
188 }
189
190 pub fn is_empty(&self) -> bool {
192 self.len() == 0
193 }
194
195 pub fn by_kind(&self, kind: super::SkillKind) -> Vec<Arc<Skill>> {
197 let skills = self.skills.read().unwrap();
198 skills
199 .values()
200 .filter(|s| s.kind == kind)
201 .cloned()
202 .collect()
203 }
204
205 pub fn by_tag(&self, tag: &str) -> Vec<Arc<Skill>> {
207 let skills = self.skills.read().unwrap();
208 skills
209 .values()
210 .filter(|s| s.tags.iter().any(|t| t == tag))
211 .cloned()
212 .collect()
213 }
214
215 pub fn personas(&self) -> Vec<Arc<Skill>> {
220 self.by_kind(super::SkillKind::Persona)
221 }
222
223 pub fn to_system_prompt(&self) -> String {
233 let skills = self.skills.read().unwrap();
234 let scorer = self.scorer.read().unwrap();
235
236 let instruction_skills: Vec<_> = skills
237 .values()
238 .filter(|s| s.kind == super::SkillKind::Instruction)
239 .filter(|s| match scorer.as_ref() {
240 Some(sc) => !sc.should_disable(&s.name),
241 None => true,
242 })
243 .collect();
244
245 if instruction_skills.is_empty() {
246 return String::new();
247 }
248
249 let mut prompt = String::from("# Available Skills\n\nThe following skills are available. Their full instructions will be provided when relevant.\n\n");
250 for skill in &instruction_skills {
251 prompt.push_str(&format!("- **{}**: {}\n", skill.name, skill.description));
252 }
253 prompt
254 }
255
256 pub fn match_skills(&self, user_input: &str) -> String {
261 let skills = self.skills.read().unwrap();
262 let scorer = self.scorer.read().unwrap();
263 let input_lower = user_input.to_lowercase();
264
265 let matched: Vec<_> = skills
266 .values()
267 .filter(|s| s.kind == super::SkillKind::Instruction)
268 .filter(|s| match scorer.as_ref() {
269 Some(sc) => !sc.should_disable(&s.name),
270 None => true,
271 })
272 .filter(|s| {
273 input_lower.contains(&s.name.to_lowercase())
275 || s.tags
276 .iter()
277 .any(|t| input_lower.contains(&t.to_lowercase()))
278 || input_lower.contains(
279 s.description
280 .to_lowercase()
281 .split_whitespace()
282 .next()
283 .unwrap_or(""),
284 )
285 })
286 .collect();
287
288 if matched.is_empty() {
289 return String::new();
290 }
291
292 let mut out = String::from("# Skill Instructions\n\n");
293 for skill in matched {
294 out.push_str(&skill.to_system_prompt());
295 out.push_str("\n\n---\n\n");
296 }
297 out
298 }
299}
300
301impl Default for SkillRegistry {
302 fn default() -> Self {
303 Self::new()
304 }
305}
306
#[cfg(test)]
mod tests {
    use super::*;
    use crate::skills::SkillKind;
    use tempfile::TempDir;

    /// Names of every skill returned by `builtin::builtin_skills()`.
    const BUILTIN_NAMES: [&str; 7] = [
        "code-search",
        "code-review",
        "explain-code",
        "find-bugs",
        "builtin-tools",
        "delegate-task",
        "find-skills",
    ];

    /// Instruction-kind skill with no tool restrictions, tags, or version.
    fn instruction_skill(name: &str, description: &str, content: &str) -> Skill {
        Skill {
            name: name.to_string(),
            description: description.to_string(),
            allowed_tools: None,
            disable_model_invocation: false,
            kind: SkillKind::Instruction,
            content: content.to_string(),
            tags: vec![],
            version: None,
        }
    }

    /// Writes each entry of `lines` followed by a newline into `path`.
    fn write_lines(path: &Path, lines: &[&str]) -> std::io::Result<()> {
        let text: String = lines.iter().map(|line| format!("{}\n", line)).collect();
        std::fs::write(path, text)
    }

    #[test]
    fn test_new_registry() {
        let empty = SkillRegistry::new();
        assert_eq!(empty.len(), 0);
        assert!(empty.is_empty());
    }

    #[test]
    fn test_with_builtins() {
        let builtins = SkillRegistry::with_builtins();
        assert_eq!(builtins.len(), 7, "Expected 7 built-in skills");
        assert!(!builtins.is_empty());

        for name in BUILTIN_NAMES.iter() {
            assert!(builtins.get(name).is_some());
        }
    }

    #[test]
    fn test_register_and_get() {
        let registry = SkillRegistry::new();
        let skill = Arc::new(instruction_skill("test-skill", "A test skill", "Test content"));

        registry.register(Arc::clone(&skill)).unwrap();

        assert_eq!(registry.len(), 1);
        assert_eq!(registry.get("test-skill").unwrap().name, "test-skill");
    }

    #[test]
    fn test_list() {
        let names = SkillRegistry::with_builtins().list();

        assert_eq!(names.len(), 7, "Expected 7 built-in skills");
        for expected in ["code-search", "code-review", "builtin-tools", "delegate-task", "find-skills"].iter() {
            assert!(names.iter().any(|n| n == expected));
        }
    }

    #[test]
    fn test_remove() {
        let registry = SkillRegistry::with_builtins();
        assert_eq!(registry.len(), 7);

        assert!(registry.remove("code-search").is_some());
        assert_eq!(registry.len(), 6);
        assert!(registry.get("code-search").is_none());
    }

    #[test]
    fn test_clear() {
        let registry = SkillRegistry::with_builtins();
        assert_eq!(registry.len(), 7);

        registry.clear();
        assert_eq!(registry.len(), 0);
        assert!(registry.is_empty());
    }

    #[test]
    fn test_by_kind() {
        let registry = SkillRegistry::with_builtins();

        assert_eq!(
            registry.by_kind(SkillKind::Instruction).len(),
            7,
            "Expected 7 instruction skills (4 code assistance + 3 tool documentation)"
        );
        assert_eq!(registry.by_kind(SkillKind::Tool).len(), 0);
    }

    #[test]
    fn test_by_tag() {
        let registry = SkillRegistry::with_builtins();

        let search = registry.by_tag("search");
        assert_eq!(search.len(), 1);
        assert_eq!(search[0].name, "code-search");

        let security = registry.by_tag("security");
        assert_eq!(security.len(), 1);
        assert_eq!(security[0].name, "find-bugs");
    }

    #[test]
    fn test_load_from_dir() -> anyhow::Result<()> {
        let temp_dir = TempDir::new()?;
        let root = temp_dir.path();

        write_lines(
            &root.join("test-skill.md"),
            &[
                "---",
                "name: test-skill",
                "description: A test skill",
                "kind: instruction",
                "---",
                "# Test Skill",
                "This is a test skill.",
            ],
        )?;
        // A markdown file without front matter and a non-markdown file must both be ignored.
        std::fs::write(root.join("README.md"), "# README\nNot a skill")?;
        std::fs::write(root.join("notes.txt"), "Some notes")?;

        let registry = SkillRegistry::new();
        let count = registry.load_from_dir(root)?;

        assert_eq!(count, 1);
        assert_eq!(registry.len(), 1);
        assert!(registry.get("test-skill").is_some());
        Ok(())
    }

    #[test]
    fn test_load_from_file() -> anyhow::Result<()> {
        let temp_dir = TempDir::new()?;
        let path = temp_dir.path().join("my-skill.md");
        write_lines(
            &path,
            &["---", "name: my-skill", "description: My custom skill", "---", "# My Skill"],
        )?;

        let registry = SkillRegistry::new();
        let loaded = registry.load_from_file(&path)?;

        assert_eq!(loaded.name, "my-skill");
        assert_eq!(registry.len(), 1);
        Ok(())
    }

    #[test]
    fn test_to_system_prompt() {
        let prompt = SkillRegistry::with_builtins().to_system_prompt();

        assert!(prompt.contains("# Available Skills"));
        for name in ["code-search", "code-review", "explain-code", "find-bugs"].iter() {
            assert!(prompt.contains(name));
        }
    }

    #[test]
    fn test_load_from_nonexistent_dir() {
        let outcome = SkillRegistry::new().load_from_dir("/nonexistent/path");

        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap(), 0);
    }

    #[test]
    fn test_register_with_validator_rejects_reserved() {
        use crate::skills::validator::DefaultSkillValidator;

        let registry = SkillRegistry::new();
        registry.set_validator(Arc::new(DefaultSkillValidator::default()));

        let impostor = Arc::new(instruction_skill(
            "code-search",
            "Override builtin",
            "Malicious override",
        ));

        assert!(registry.register(impostor).is_err());
        assert_eq!(registry.len(), 0);
    }

    #[test]
    fn test_register_with_validator_accepts_valid() {
        use crate::skills::validator::DefaultSkillValidator;

        let registry = SkillRegistry::new();
        registry.set_validator(Arc::new(DefaultSkillValidator::default()));

        let valid = Arc::new(Skill {
            allowed_tools: Some("read(*), grep(*)".to_string()),
            ..instruction_skill("my-custom-skill", "A valid skill", "Help with code review.")
        });

        assert!(registry.register(valid).is_ok());
        assert_eq!(registry.len(), 1);
    }

    #[test]
    fn test_register_without_validator_accepts_anything() {
        let registry = SkillRegistry::new();
        let reserved_name = Arc::new(instruction_skill("code-search", "test", "test"));

        assert!(registry.register(reserved_name).is_ok());
    }

    #[test]
    fn test_load_from_file_with_validator_rejects() {
        use crate::skills::validator::DefaultSkillValidator;

        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("code-search.md");
        write_lines(
            &path,
            &["---", "name: code-search", "description: Override", "---", "# Override"],
        )
        .unwrap();

        let registry = SkillRegistry::new();
        registry.set_validator(Arc::new(DefaultSkillValidator::default()));

        assert!(registry.load_from_file(&path).is_err());
        assert_eq!(registry.len(), 0);
    }

    #[test]
    fn test_to_system_prompt_skips_disabled_skills() {
        use crate::skills::feedback::{DefaultSkillScorer, SkillFeedback, SkillOutcome};

        let registry = SkillRegistry::new();
        let scorer = Arc::new(DefaultSkillScorer::default());
        registry.set_scorer(scorer.clone());

        registry.register_unchecked(Arc::new(instruction_skill(
            "good-skill",
            "Good",
            "Good instructions",
        )));
        registry.register_unchecked(Arc::new(instruction_skill(
            "bad-skill",
            "Bad",
            "Bad instructions",
        )));

        // Enough failures for the scorer to disable "bad-skill".
        for _ in 0..5 {
            scorer.record(SkillFeedback {
                skill_name: "bad-skill".to_string(),
                outcome: SkillOutcome::Failure,
                score_delta: -1.0,
                reason: "Did not help".to_string(),
                timestamp: 0,
            });
        }

        let prompt = registry.to_system_prompt();
        assert!(prompt.contains("good-skill"));
        assert!(!prompt.contains("bad-skill"));
    }

    #[test]
    fn test_fork_is_independent() {
        let original = SkillRegistry::with_builtins();
        let fork = original.fork();
        assert_eq!(fork.len(), original.len());

        fork.register_unchecked(Arc::new(instruction_skill(
            "session-only",
            "Only in fork",
            "content",
        )));

        assert_eq!(fork.len(), original.len() + 1);
        assert!(fork.get("session-only").is_some());
        assert!(original.get("session-only").is_none());
    }

    #[test]
    fn test_fork_inherits_builtins() {
        let fork = SkillRegistry::with_builtins().fork();
        for name in ["code-search", "code-review", "find-bugs"].iter() {
            assert!(fork.get(name).is_some());
        }
    }
}