use dashmap::DashMap;
use std::path::Path;
use std::sync::Arc;
use tokio::fs;
use tracing::{debug, warn};
use crate::ast::skill_def::resolve_skill_path;
use crate::error::NikaError;
/// Loads skill files from disk and injects their contents ahead of a base prompt.
///
/// File contents are cached in memory as `Arc<str>`, keyed by the resolved path,
/// so repeated loads of the same skill return a cheap reference-counted clone
/// instead of re-reading the file. The cache is a concurrent `DashMap`, which lets
/// a single injector be shared between tasks (for example behind an `Arc`).
#[derive(Debug)]
pub struct SkillInjector {
    /// Skill contents keyed by the resolved skill path, rendered with `to_string_lossy`.
    cache: DashMap<String, Arc<str>>,
}
impl SkillInjector {
    /// Creates an injector with an empty skill cache.
    pub fn new() -> Self {
        Self {
            cache: DashMap::new(),
        }
    }

    /// Loads the text of a single skill file, serving it from the cache when possible.
    ///
    /// `skill_path` is resolved against `base_dir` with [`resolve_skill_path`]. The
    /// resolved path (rendered via `to_string_lossy`) is the cache key.
    ///
    /// When several tasks miss the cache for the same file at the same time, each of
    /// them reads the file. The first one to insert wins, and every caller receives
    /// that same shared `Arc`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by [`resolve_skill_path`]. Returns
    /// [`NikaError::SkillLoadError`] when the resolved file cannot be read as a string.
    pub async fn load_skill(
        &self,
        skill_path: &str,
        base_dir: &Path,
    ) -> Result<Arc<str>, NikaError> {
        let resolved_path = resolve_skill_path(skill_path, base_dir)?;
        let cache_key = resolved_path.to_string_lossy().to_string();

        // The read guard returned by `get` lives only for this `if let` statement,
        // so no shard lock is held across the `.await` below.
        if let Some(cached) = self.cache.get(&cache_key) {
            debug!(skill_path = %skill_path, "Skill loaded from cache");
            return Ok(Arc::clone(cached.value()));
        }

        let content = fs::read_to_string(&resolved_path)
            .await
            .map_err(|e| NikaError::SkillLoadError {
                skill: skill_path.to_string(),
                reason: format!("Failed to read file '{}': {}", resolved_path.display(), e),
            })?;

        // Use the entry API rather than `insert`: if another task cached this file while
        // we were reading it, keep (and return) its `Arc` instead of replacing it, so all
        // callers share one allocation.
        let content: Arc<str> = {
            let entry = self
                .cache
                .entry(cache_key)
                .or_insert_with(|| content.into());
            Arc::clone(entry.value())
        };
        debug!(skill_path = %skill_path, resolved = %resolved_path.display(), "Skill loaded and cached");
        Ok(content)
    }

    /// Builds a prompt consisting of the named skills followed by `base_prompt`.
    ///
    /// Each skill in `skill_names` is looked up in `skills_map` (skill name → skill file
    /// path) and loaded via [`Self::load_skill`] relative to `base_dir`. Each successfully
    /// loaded skill becomes a section `"# Skill: {name}\n\n{content}"`, with trailing
    /// whitespace trimmed from the content. Sections appear in `skill_names` order.
    ///
    /// A non-empty `base_prompt` is appended after all sections. All parts are joined
    /// with a single `"\n"`.
    ///
    /// If `skill_names` is empty, this returns `base_prompt` unchanged, or `""` when it
    /// is `None`.
    ///
    /// # Errors
    ///
    /// Returns [`NikaError::SkillLoadError`] if a requested skill name is missing from
    /// `skills_map`. The error message lists the available names in sorted order.
    ///
    /// Skills that are present in the map but fail to load are *not* errors. They are
    /// logged at `warn` level and skipped.
    pub async fn inject(
        &self,
        base_prompt: Option<&str>,
        skill_names: &[&str],
        skills_map: &std::collections::HashMap<String, String>,
        base_dir: &Path,
    ) -> Result<String, NikaError> {
        if skill_names.is_empty() {
            return Ok(base_prompt.unwrap_or_default().to_string());
        }

        let mut parts: Vec<String> = Vec::with_capacity(skill_names.len() + 1);
        for skill_name in skill_names {
            let skill_path = skills_map.get(*skill_name).ok_or_else(|| {
                // HashMap iteration order is unspecified; sort so the message is stable
                // across runs and easy to scan.
                let mut available: Vec<&String> = skills_map.keys().collect();
                available.sort_unstable();
                NikaError::SkillLoadError {
                    skill: skill_name.to_string(),
                    reason: format!(
                        "Skill '{}' not found in workflow skills: block. Available: {:?}",
                        skill_name, available
                    ),
                }
            })?;

            match self.load_skill(skill_path, base_dir).await {
                Ok(content) => {
                    parts.push(format!(
                        "# Skill: {}\n\n{}",
                        skill_name,
                        content.as_ref().trim_end()
                    ));
                }
                Err(e) => {
                    // Best-effort: a skill that cannot be loaded is skipped, not fatal.
                    warn!(skill = %skill_name, error = %e, "Failed to load skill, skipping");
                }
            }
        }

        if let Some(base) = base_prompt {
            if !base.is_empty() {
                parts.push(base.to_string());
            }
        }
        Ok(parts.join("\n"))
    }

    /// Removes every cached skill so that subsequent loads re-read from disk.
    pub fn clear_cache(&self) {
        self.cache.clear();
        debug!("Skill cache cleared");
    }

    /// Returns the number of distinct resolved skill paths currently cached.
    pub fn cache_size(&self) -> usize {
        self.cache.len()
    }

    /// Reports whether the skill at `skill_path` (resolved against `base_dir`) is cached.
    ///
    /// Returns `false` when the path cannot be resolved.
    pub fn is_cached(&self, skill_path: &str, base_dir: &Path) -> bool {
        match resolve_skill_path(skill_path, base_dir) {
            Ok(resolved) => {
                let cache_key = resolved.to_string_lossy().to_string();
                self.cache.contains_key(&cache_key)
            }
            Err(_) => false,
        }
    }
}
impl Default for SkillInjector {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use tempfile::TempDir;
    use tokio::fs::write;

    /// Creates a temp dir containing `skills/seo.skill.md` and `skills/brand.skill.md`.
    ///
    /// Also returns a skills map whose paths are relative to that dir. The `TempDir` is
    /// returned so the caller keeps it alive for the duration of the test.
    async fn setup_test_skills() -> (TempDir, HashMap<String, String>) {
        let temp_dir = TempDir::new().unwrap();
        let skills_dir = temp_dir.path().join("skills");
        tokio::fs::create_dir_all(&skills_dir).await.unwrap();

        let seo_path = skills_dir.join("seo.skill.md");
        write(
            &seo_path,
            "# SEO Writer\n\nYou are an expert SEO content writer.\n",
        )
        .await
        .unwrap();

        let brand_path = skills_dir.join("brand.skill.md");
        write(
            &brand_path,
            "# Brand Voice\n\nMaintain a friendly, professional tone.\n",
        )
        .await
        .unwrap();

        let mut skills_map = HashMap::new();
        skills_map.insert("seo".to_string(), "./skills/seo.skill.md".to_string());
        skills_map.insert("brand".to_string(), "./skills/brand.skill.md".to_string());
        (temp_dir, skills_map)
    }

    #[tokio::test]
    async fn test_load_skill_success() {
        let (temp_dir, skills_map) = setup_test_skills().await;
        let injector = SkillInjector::new();
        let content = injector
            .load_skill(skills_map.get("seo").unwrap(), temp_dir.path())
            .await
            .unwrap();
        assert!(content.contains("SEO Writer"));
        assert!(content.contains("expert SEO content writer"));
    }

    #[tokio::test]
    async fn test_load_skill_caching() {
        let (temp_dir, skills_map) = setup_test_skills().await;
        let injector = SkillInjector::new();
        let content1 = injector
            .load_skill(skills_map.get("seo").unwrap(), temp_dir.path())
            .await
            .unwrap();
        assert_eq!(injector.cache_size(), 1);
        let content2 = injector
            .load_skill(skills_map.get("seo").unwrap(), temp_dir.path())
            .await
            .unwrap();
        // A cache hit hands back the same allocation, not a fresh copy.
        assert!(Arc::ptr_eq(&content1, &content2));
    }

    #[tokio::test]
    async fn test_load_skill_file_not_found() {
        let temp_dir = TempDir::new().unwrap();
        let injector = SkillInjector::new();
        let result = injector
            .load_skill("./nonexistent.skill.md", temp_dir.path())
            .await;
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(err, NikaError::SkillLoadError { .. }));
    }

    #[tokio::test]
    async fn test_inject_single_skill() {
        let (temp_dir, skills_map) = setup_test_skills().await;
        let injector = SkillInjector::new();
        let result = injector
            .inject(Some("Be helpful"), &["seo"], &skills_map, temp_dir.path())
            .await
            .unwrap();
        assert!(result.contains("# Skill: seo"));
        assert!(result.contains("SEO Writer"));
        assert!(result.contains("Be helpful"));
    }

    #[tokio::test]
    async fn test_inject_exact_layout() {
        let (temp_dir, skills_map) = setup_test_skills().await;
        let injector = SkillInjector::new();
        let result = injector
            .inject(Some("Be helpful"), &["seo"], &skills_map, temp_dir.path())
            .await
            .unwrap();
        // Section header, blank line, skill body with its trailing newline trimmed,
        // then the base prompt joined with a single newline.
        assert_eq!(
            result,
            "# Skill: seo\n\n# SEO Writer\n\nYou are an expert SEO content writer.\nBe helpful"
        );
    }

    #[tokio::test]
    async fn test_inject_multiple_skills() {
        let (temp_dir, skills_map) = setup_test_skills().await;
        let injector = SkillInjector::new();
        let result = injector
            .inject(
                Some("Base prompt"),
                &["seo", "brand"],
                &skills_map,
                temp_dir.path(),
            )
            .await
            .unwrap();
        assert!(result.contains("# Skill: seo"));
        assert!(result.contains("# Skill: brand"));
        assert!(result.contains("SEO Writer"));
        assert!(result.contains("Brand Voice"));
        assert!(result.contains("Base prompt"));
    }

    #[tokio::test]
    async fn test_inject_preserves_order() {
        let (temp_dir, skills_map) = setup_test_skills().await;
        let injector = SkillInjector::new();
        let result = injector
            .inject(
                Some("Base prompt"),
                &["brand", "seo"],
                &skills_map,
                temp_dir.path(),
            )
            .await
            .unwrap();
        let brand = result.find("# Skill: brand").unwrap();
        let seo = result.find("# Skill: seo").unwrap();
        let base = result.find("Base prompt").unwrap();
        // Skills follow the requested order; the base prompt always comes last.
        assert!(brand < seo);
        assert!(seo < base);
    }

    #[tokio::test]
    async fn test_inject_no_skills() {
        let temp_dir = TempDir::new().unwrap();
        let skills_map = HashMap::new();
        let injector = SkillInjector::new();
        let result = injector
            .inject(Some("Base prompt"), &[], &skills_map, temp_dir.path())
            .await
            .unwrap();
        assert_eq!(result, "Base prompt");
    }

    #[tokio::test]
    async fn test_inject_no_base_prompt() {
        let (temp_dir, skills_map) = setup_test_skills().await;
        let injector = SkillInjector::new();
        let result = injector
            .inject(None, &["seo"], &skills_map, temp_dir.path())
            .await
            .unwrap();
        assert!(result.contains("# Skill: seo"));
        assert!(result.contains("SEO Writer"));
    }

    #[tokio::test]
    async fn test_inject_skill_not_in_map() {
        let (temp_dir, skills_map) = setup_test_skills().await;
        let injector = SkillInjector::new();
        let result = injector
            .inject(Some("Base"), &["nonexistent"], &skills_map, temp_dir.path())
            .await;
        assert!(result.is_err());
        let err = result.unwrap_err();
        if let NikaError::SkillLoadError { skill, reason } = err {
            assert_eq!(skill, "nonexistent");
            assert!(reason.contains("not found in workflow skills: block"));
        } else {
            panic!("Expected SkillLoadError");
        }
    }

    #[tokio::test]
    async fn test_inject_skips_unloadable_skill() {
        let (temp_dir, mut skills_map) = setup_test_skills().await;
        // Declared in the map, but no file exists on disk for it.
        skills_map.insert("ghost".to_string(), "./skills/ghost.skill.md".to_string());
        let injector = SkillInjector::new();
        let result = injector
            .inject(Some("Base"), &["ghost", "seo"], &skills_map, temp_dir.path())
            .await
            .unwrap();
        // The unreadable skill is skipped (logged), not turned into an error.
        assert!(!result.contains("# Skill: ghost"));
        assert!(result.contains("# Skill: seo"));
        assert!(result.ends_with("Base"));
        assert_eq!(injector.cache_size(), 1);
    }

    #[tokio::test]
    async fn test_clear_cache() {
        let (temp_dir, skills_map) = setup_test_skills().await;
        let injector = SkillInjector::new();
        injector
            .load_skill(skills_map.get("seo").unwrap(), temp_dir.path())
            .await
            .unwrap();
        assert_eq!(injector.cache_size(), 1);
        injector.clear_cache();
        assert_eq!(injector.cache_size(), 0);
    }

    #[tokio::test]
    async fn test_is_cached() {
        let (temp_dir, skills_map) = setup_test_skills().await;
        let injector = SkillInjector::new();
        let skill_path = skills_map.get("seo").unwrap();
        assert!(!injector.is_cached(skill_path, temp_dir.path()));
        injector
            .load_skill(skill_path, temp_dir.path())
            .await
            .unwrap();
        assert!(injector.is_cached(skill_path, temp_dir.path()));
    }

    #[test]
    fn test_default_impl() {
        let injector = SkillInjector::default();
        assert_eq!(injector.cache_size(), 0);
    }

    #[tokio::test]
    async fn test_inject_empty_base_prompt() {
        let (temp_dir, skills_map) = setup_test_skills().await;
        let injector = SkillInjector::new();
        let result = injector
            .inject(Some(""), &["seo"], &skills_map, temp_dir.path())
            .await
            .unwrap();
        assert!(result.contains("# Skill: seo"));
        assert!(!result.ends_with("\n\n"));
        // An empty base prompt contributes nothing, not even a separator.
        assert_eq!(
            result,
            "# Skill: seo\n\n# SEO Writer\n\nYou are an expert SEO content writer."
        );
    }

    #[tokio::test]
    async fn test_concurrent_loads() {
        let (temp_dir, skills_map) = setup_test_skills().await;
        let injector = Arc::new(SkillInjector::new());
        let skill_path = skills_map.get("seo").unwrap().clone();
        let base_dir = temp_dir.path().to_path_buf();

        let mut handles = vec![];
        for _ in 0..10 {
            let inj = Arc::clone(&injector);
            let path = skill_path.clone();
            let dir = base_dir.clone();
            handles.push(tokio::spawn(
                async move { inj.load_skill(&path, &dir).await },
            ));
        }
        for handle in handles {
            let result = handle.await.unwrap();
            assert!(result.is_ok());
        }
        assert_eq!(injector.cache_size(), 1);
    }
}