Skip to main content

abu_skill/
loader.rs

1use std::{collections::HashMap, path::{Path, PathBuf}, sync::OnceLock};
2use crate::{Skill, SkillError, SkillFrontmatter, SkillResult};
3use regex::Regex;
4use walkdir::{WalkDir, DirEntry};
5use tracing::debug;
6
7pub struct SkillLoader {
8    pub dir: PathBuf,
9    pub skills: HashMap<String, Skill>,
10}
11
12impl SkillLoader {
13    pub fn load(skill_dir: impl Into<PathBuf>) -> SkillResult<Self> {
14        let skill_dir: PathBuf = skill_dir.into();
15        debug!("load skills from {}", skill_dir.display());
16        let skills = Self::load_skills(&skill_dir)?;
17        Ok(Self { dir: skill_dir, skills })
18    }
19
20    pub fn get_descriptions(&self) -> String {
21        if self.skills.is_empty() {
22            "(no skills available)".to_string()
23        } else {
24            let skill_descriptions = self.skills.iter()
25                .map(|(name, skill)| {
26                    format!("  - {}: {}", name, skill.frontmatter.description)
27                })
28                .collect::<Vec<_>>()
29                .join("\n");
30            
31            format!("Use load_skill to access full content of one skill.\nHere are all available skills for you:\n{}", skill_descriptions)
32        }
33    }
34
35    pub fn get_content(&self, name: &str) -> Option<&str> {
36        self.skills
37            .get(name)
38            .map(|skill| skill.body.as_str())
39    }
40
41    fn load_skills(skill_dir: &Path) -> SkillResult<HashMap<String, Skill>> {
42        WalkDir::new(skill_dir).min_depth(1).max_depth(1)
43            .into_iter()
44            .filter_map(Result::ok)
45            .filter_map(Self::check_skill_path)
46            .map(|(name, path)| {
47                debug!("load skill {} from {}", name, path.display());
48                Self::load_skill(&path).map(|skill| (name, skill))
49            })
50            .collect()
51    }
52
53    fn load_skill(path: &Path) -> SkillResult<Skill> {
54        let content = std::fs::read_to_string(path)?;
55        Self::parse_skill(&content)
56    }
57
58    fn parse_skill(content: &str) -> SkillResult<Skill> {
59        static RE: OnceLock<Regex> = OnceLock::new();
60        let re = RE.get_or_init(|| { Regex::new(r"(?s)^---\r?\n(.*?)\r?\n---\r?\n(.*)").expect("Invalid Regex") });
61
62        if let Some(caps) = re.captures(&content) {
63            let yaml_str = caps.get(1).map_or("", |m| m.as_str());
64            let body_str = caps.get(2).map_or("", |m| m.as_str());
65            let frontmatter: SkillFrontmatter = serde_yaml::from_str(yaml_str)?;
66            Ok(Skill {
67                frontmatter,
68                body: body_str.trim().to_string(), 
69            })
70        } else {
71            return Err(SkillError::InvalidFrontmatter {
72                message: "missing opening frontmatter delimiter (`---`)".to_string(),
73            });
74        }
75    }
76
77    fn check_skill_path(entry: DirEntry) -> Option<(String, PathBuf)> {
78        // type 1: .md file
79        if entry.file_type().is_file() && entry.path().extension().is_some_and(|ext| ext == "md") {
80            Some((entry.file_name().to_str().unwrap().to_string(), entry.path().to_owned()))
81        } else if entry.file_type().is_dir() {
82            let md_path = entry.path().join("SKILL.md");
83            // type 2: SKILL.md in sub dir
84            if md_path.exists() {
85                Some((entry.file_name().to_str().unwrap().to_string(), md_path))
86            } else {
87                None
88            }
89        } else {
90            None
91        }
92    } 
93}
94
#[cfg(test)]
mod tests {
    use super::*;

    /// Skill document carrying only the required frontmatter fields.
    const MINIMAL_SKILL: &str = r#"---
name: repo_search
description: Search the codebase quickly
---
Use ripgrep first.
"#;

    /// Skill document exercising the optional frontmatter fields as well.
    const FULL_SPEC_SKILL: &str = r#"---
name: full_spec_agent
description: An agent with everything
compatibility: "Requires Python 3.10+"
allowed-tools:
  - tool1
references:
  - ref1
metadata:
  custom_key: custom_value
  version: "1.2.3"
  license: MIT
---
Body content.
"#;

    #[test]
    fn test_loader() {
        // Relative to the package root, which is the working directory under `cargo test`.
        let skills_dir = "./skills";
        let loaded = SkillLoader::load(skills_dir).expect("load");
        assert_eq!(loaded.skills.len(), 2);
    }

    #[test]
    fn test_parses_valid_skill() {
        let parsed = SkillLoader::parse_skill(MINIMAL_SKILL).unwrap();
        let meta = &parsed.frontmatter;
        assert_eq!(meta.name, "repo_search");
        assert_eq!(meta.description, "Search the codebase quickly");
        assert!(parsed.body.contains("Use ripgrep first."));
    }

    #[test]
    fn test_parses_skill_with_full_spec() {
        let parsed = SkillLoader::parse_skill(FULL_SPEC_SKILL).unwrap();
        let meta = &parsed.frontmatter;
        assert_eq!(meta.name, "full_spec_agent");
        assert_eq!(meta.compatibility, Some("Requires Python 3.10+".to_string()));
        assert_eq!(meta.allowed_tools, vec!["tool1"]);
        let custom = meta.metadata.get("custom_key").and_then(|v| v.as_str());
        assert_eq!(custom, Some("custom_value"));
    }
}