use anyhow::{Context, Result, bail};
use async_trait::async_trait;
use std::path::{Path, PathBuf};
use super::{Skill, parser::parse_skill_file};
#[async_trait]
pub trait SkillLoader: Send + Sync {
async fn load(&self, name: &str) -> Result<Skill>;
async fn list(&self) -> Result<Vec<String>>;
async fn exists(&self, name: &str) -> bool {
self.load(name).await.is_ok()
}
}
pub struct FileSkillLoader {
base_path: PathBuf,
}
impl FileSkillLoader {
#[must_use]
pub fn new(base_path: impl Into<PathBuf>) -> Self {
Self {
base_path: base_path.into(),
}
}
#[must_use]
pub fn base_path(&self) -> &Path {
&self.base_path
}
fn skill_path(&self, name: &str) -> Result<PathBuf> {
if name.contains('/') || name.contains('\\') || name.contains("..") || name.contains('\0') {
bail!("Invalid skill name: must not contain path separators, '..', or null bytes");
}
Ok(self.base_path.join(format!("{name}.md")))
}
}
#[async_trait]
impl SkillLoader for FileSkillLoader {
async fn load(&self, name: &str) -> Result<Skill> {
let path = self.skill_path(name)?;
if !path.exists() {
bail!("Skill file not found: {}", path.display());
}
let content = tokio::fs::read_to_string(&path)
.await
.with_context(|| format!("Failed to read skill file: {}", path.display()))?;
let skill = parse_skill_file(&content)
.with_context(|| format!("Failed to parse skill file: {}", path.display()))?;
if skill.name != name {
log::warn!(
"Skill name '{}' in file doesn't match filename '{}'",
skill.name,
name
);
}
Ok(skill)
}
async fn list(&self) -> Result<Vec<String>> {
if !self.base_path.exists() {
return Ok(Vec::new());
}
let mut entries = tokio::fs::read_dir(&self.base_path)
.await
.with_context(|| {
format!(
"Failed to read skills directory: {}",
self.base_path.display()
)
})?;
let mut skills = Vec::new();
while let Some(entry) = entries.next_entry().await? {
let path = entry.path();
if path.extension().is_some_and(|ext| ext == "md")
&& let Some(name) = path.file_stem().and_then(|s| s.to_str())
{
skills.push(name.to_string());
}
}
skills.sort();
Ok(skills)
}
}
#[derive(Default)]
pub struct InMemorySkillLoader {
skills: std::sync::RwLock<std::collections::HashMap<String, Skill>>,
}
impl InMemorySkillLoader {
#[must_use]
pub fn new() -> Self {
Self::default()
}
pub fn add(&self, skill: Skill) -> Result<()> {
self.skills
.write()
.ok()
.context("lock poisoned")?
.insert(skill.name.clone(), skill);
Ok(())
}
pub fn remove(&self, name: &str) -> Result<Option<Skill>> {
let mut skills = self.skills.write().ok().context("lock poisoned")?;
Ok(skills.remove(name))
}
}
#[async_trait]
impl SkillLoader for InMemorySkillLoader {
async fn load(&self, name: &str) -> Result<Skill> {
let skills = self.skills.read().ok().context("lock poisoned")?;
skills
.get(name)
.cloned()
.with_context(|| format!("Skill not found: {name}"))
}
async fn list(&self) -> Result<Vec<String>> {
let mut names: Vec<_> = self
.skills
.read()
.ok()
.context("lock poisoned")?
.keys()
.cloned()
.collect();
names.sort();
Ok(names)
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::io::Write;
use tempfile::TempDir;
#[tokio::test]
async fn test_file_loader_load() -> Result<()> {
let dir = TempDir::new()?;
let skill_path = dir.path().join("test-skill.md");
let mut file = std::fs::File::create(&skill_path)?;
writeln!(
file,
"---
name: test-skill
description: A test skill
---
You are a test assistant."
)?;
let loader = FileSkillLoader::new(dir.path());
let skill = loader.load("test-skill").await?;
assert_eq!(skill.name, "test-skill");
assert_eq!(skill.description, "A test skill");
assert!(skill.system_prompt.contains("test assistant"));
Ok(())
}
#[tokio::test]
async fn test_file_loader_load_not_found() {
let dir = TempDir::new().unwrap();
let loader = FileSkillLoader::new(dir.path());
let result = loader.load("nonexistent").await;
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("not found"));
}
#[tokio::test]
async fn test_file_loader_list() -> Result<()> {
let dir = TempDir::new()?;
for name in ["alpha", "beta", "gamma"] {
let path = dir.path().join(format!("{name}.md"));
let mut file = std::fs::File::create(&path)?;
writeln!(
file,
"---
name: {name}
---
Content"
)?;
}
let _ = std::fs::File::create(dir.path().join("readme.txt"))?;
let loader = FileSkillLoader::new(dir.path());
let skills = loader.list().await?;
assert_eq!(skills, vec!["alpha", "beta", "gamma"]);
Ok(())
}
#[tokio::test]
async fn test_file_loader_list_empty_dir() -> Result<()> {
let dir = TempDir::new()?;
let loader = FileSkillLoader::new(dir.path());
let skills = loader.list().await?;
assert!(skills.is_empty());
Ok(())
}
#[tokio::test]
async fn test_file_loader_list_nonexistent_dir() -> Result<()> {
let loader = FileSkillLoader::new("/nonexistent/path");
let skills = loader.list().await?;
assert!(skills.is_empty());
Ok(())
}
#[tokio::test]
async fn test_file_loader_exists() -> Result<()> {
let dir = TempDir::new()?;
let skill_path = dir.path().join("exists.md");
let mut file = std::fs::File::create(&skill_path)?;
writeln!(
file,
"---
name: exists
---
Content"
)?;
let loader = FileSkillLoader::new(dir.path());
assert!(loader.exists("exists").await);
assert!(!loader.exists("not-exists").await);
Ok(())
}
#[tokio::test]
async fn test_in_memory_loader() -> Result<()> {
let loader = InMemorySkillLoader::new();
loader.add(Skill::new("skill1", "Prompt 1").with_description("First skill"))?;
loader.add(Skill::new("skill2", "Prompt 2").with_description("Second skill"))?;
let first = loader.load("skill1").await?;
assert_eq!(first.name, "skill1");
assert_eq!(first.description, "First skill");
let skill_names = loader.list().await?;
assert_eq!(skill_names, vec!["skill1", "skill2"]);
loader.remove("skill1")?;
assert!(!loader.exists("skill1").await);
Ok(())
}
#[tokio::test]
async fn test_in_memory_loader_not_found() {
let loader = InMemorySkillLoader::new();
let result = loader.load("nonexistent").await;
assert!(result.is_err());
}
#[tokio::test]
async fn test_file_loader_blocks_path_traversal() -> Result<()> {
let dir = TempDir::new()?;
let loader = FileSkillLoader::new(dir.path());
let traversal_names = [
"../etc/passwd",
"..\\windows\\system32",
"foo/../bar",
"foo/bar",
"foo\\bar",
"skill\0name",
];
for name in &traversal_names {
let result = loader.load(name).await;
assert!(result.is_err(), "Expected error for skill name: {name}");
assert!(
result
.unwrap_err()
.to_string()
.contains("Invalid skill name"),
"Expected 'Invalid skill name' error for: {name}"
);
}
Ok(())
}
}