Skip to main content

nika_engine/runtime/
skill_injector.rs

//! Skill Injector - Prepends skill content to agent system prompts
//!
//! The SkillInjector loads skill files (both local paths and `pkg:` URIs) and caches
//! them for efficient reuse. When an agent task specifies skills, the injector
//! prepends the skill content to the agent's system prompt.
//!
//! # Example
//!
//! ```yaml
//! skills:
//!   seo: ./skills/seo-writer.skill.md
//!   brand: pkg:@supernovae/skills@1.0.0/brand.md
//!
//! tasks:
//!   - id: generate
//!     agent:
//!       prompt: "Write content"
//!       skills: [seo, brand]  # Skills injected into system prompt
//! ```
//!
//! # Architecture
//!
//! ```text
//! Workflow start:
//!   1. Parse skills: block → HashMap<alias, path>
//!   2. Resolve paths (local or pkg: URI)
//!
//! Agent task execution:
//!   3. SkillInjector.inject(system_prompt, skill_names, skills_map, base_dir)
//!   4. For each skill: load from cache or read file (load failures are logged and skipped)
//!   5. Prepend skill content to system prompt
//! ```
33
34use dashmap::DashMap;
35use std::path::Path;
36use std::sync::Arc;
37use tokio::fs;
38use tracing::{debug, warn};
39
40use crate::ast::skill_def::resolve_skill_path;
41use crate::error::NikaError;
42
43/// Thread-safe skill content cache
44///
45/// Uses DashMap for concurrent access without external locking.
46/// Keys are cache keys (resolved path string), values are skill content.
47pub struct SkillInjector {
48    /// Cache: resolved_path -> file content
49    cache: DashMap<String, Arc<str>>,
50}
51
52impl SkillInjector {
53    /// Create a new SkillInjector with empty cache
54    pub fn new() -> Self {
55        Self {
56            cache: DashMap::new(),
57        }
58    }
59
60    /// Load a skill file, using cache if available
61    ///
62    /// # Arguments
63    /// * `skill_path` - The skill path from YAML (local path or `pkg:` URI)
64    /// * `base_dir` - Base directory for resolving relative local paths
65    ///
66    /// # Returns
67    /// * `Ok(Arc<str>)` - Skill file content (from cache or freshly loaded)
68    /// * `Err(NikaError::SkillLoadError)` - If file cannot be read
69    ///
70    /// # Example
71    /// ```ignore
72    /// let injector = SkillInjector::new();
73    /// let content = injector.load_skill("./skills/seo.skill.md", Path::new("/project")).await?;
74    /// ```
75    pub async fn load_skill(
76        &self,
77        skill_path: &str,
78        base_dir: &Path,
79    ) -> Result<Arc<str>, NikaError> {
80        // Resolve the skill path (handles both local and pkg: URIs)
81        let resolved_path = resolve_skill_path(skill_path, base_dir)?;
82        let cache_key = resolved_path.to_string_lossy().to_string();
83
84        // Check cache first
85        if let Some(cached) = self.cache.get(&cache_key) {
86            debug!(skill_path = %skill_path, "Skill loaded from cache");
87            return Ok(Arc::clone(&cached));
88        }
89
90        // Load file content
91        let content =
92            fs::read_to_string(&resolved_path)
93                .await
94                .map_err(|e| NikaError::SkillLoadError {
95                    skill: skill_path.to_string(),
96                    reason: format!("Failed to read file '{}': {}", resolved_path.display(), e),
97                })?;
98
99        let content: Arc<str> = content.into();
100
101        // Cache for future use
102        self.cache.insert(cache_key, Arc::clone(&content));
103        debug!(skill_path = %skill_path, resolved = %resolved_path.display(), "Skill loaded and cached");
104
105        Ok(content)
106    }
107
108    /// Inject skills into a system prompt
109    ///
110    /// Loads each referenced skill and prepends it to the base system prompt.
111    /// Skills are separated by newlines and clearly marked with headers.
112    ///
113    /// # Arguments
114    /// * `base_prompt` - The agent's original system prompt (may be None)
115    /// * `skill_names` - List of skill aliases to inject
116    /// * `skills_map` - Workflow-level skills: HashMap<alias, path>
117    /// * `base_dir` - Base directory for resolving relative paths
118    ///
119    /// # Returns
120    /// * `Ok(String)` - Complete system prompt with skills prepended
121    /// * `Err(NikaError)` - If any skill fails to load
122    ///
123    /// # Example
124    /// ```ignore
125    /// let injector = SkillInjector::new();
126    /// let skills_map = [("seo".to_string(), "./skills/seo.md".to_string())].into();
127    /// let prompt = injector.inject(
128    ///     Some("Be helpful"),
129    ///     &["seo"],
130    ///     &skills_map,
131    ///     Path::new("/project"),
132    /// ).await?;
133    /// ```
134    pub async fn inject(
135        &self,
136        base_prompt: Option<&str>,
137        skill_names: &[&str],
138        skills_map: &std::collections::HashMap<String, String>,
139        base_dir: &Path,
140    ) -> Result<String, NikaError> {
141        if skill_names.is_empty() {
142            // No skills to inject - return base prompt or empty string
143            return Ok(base_prompt.unwrap_or_default().to_string());
144        }
145
146        let mut parts: Vec<String> = Vec::with_capacity(skill_names.len() + 1);
147
148        // Load each skill
149        for skill_name in skill_names {
150            // Get path from skills map
151            let skill_path =
152                skills_map
153                    .get(*skill_name)
154                    .ok_or_else(|| NikaError::SkillLoadError {
155                        skill: skill_name.to_string(),
156                        reason: format!(
157                            "Skill '{}' not found in workflow skills: block. Available: {:?}",
158                            skill_name,
159                            skills_map.keys().collect::<Vec<_>>()
160                        ),
161                    })?;
162
163            // Load skill content (uses cache)
164            match self.load_skill(skill_path, base_dir).await {
165                Ok(content) => {
166                    // Add skill with header for clarity
167                    // Trim trailing whitespace from content to prevent double newlines
168                    parts.push(format!(
169                        "# Skill: {}\n\n{}",
170                        skill_name,
171                        content.as_ref().trim_end()
172                    ));
173                }
174                Err(e) => {
175                    // Log warning but continue with other skills
176                    warn!(skill = %skill_name, error = %e, "Failed to load skill, skipping");
177                }
178            }
179        }
180
181        // Add base prompt at the end (if present)
182        if let Some(base) = base_prompt {
183            if !base.is_empty() {
184                parts.push(base.to_string());
185            }
186        }
187
188        Ok(parts.join("\n"))
189    }
190
191    /// Clear the skill cache (useful for testing or hot-reloading)
192    pub fn clear_cache(&self) {
193        self.cache.clear();
194        debug!("Skill cache cleared");
195    }
196
197    /// Get the number of cached skills
198    pub fn cache_size(&self) -> usize {
199        self.cache.len()
200    }
201
202    /// Check if a skill is cached
203    pub fn is_cached(&self, skill_path: &str, base_dir: &Path) -> bool {
204        if let Ok(resolved) = resolve_skill_path(skill_path, base_dir) {
205            let cache_key = resolved.to_string_lossy().to_string();
206            self.cache.contains_key(&cache_key)
207        } else {
208            false
209        }
210    }
211}
212
213impl Default for SkillInjector {
214    fn default() -> Self {
215        Self::new()
216    }
217}
218
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use tempfile::TempDir;
    use tokio::fs::write;

    /// How the `seo` fixture is rendered inside an injected prompt
    /// (header, blank line, content with trailing whitespace trimmed).
    const SEO_SECTION: &str =
        "# Skill: seo\n\n# SEO Writer\n\nYou are an expert SEO content writer.";

    /// How the `brand` fixture is rendered inside an injected prompt.
    const BRAND_SECTION: &str =
        "# Skill: brand\n\n# Brand Voice\n\nMaintain a friendly, professional tone.";

    /// Create a temp dir containing `skills/seo.skill.md` and
    /// `skills/brand.skill.md`, and return it with a skills map (`seo`, `brand`)
    /// whose paths are relative to that dir.
    ///
    /// The returned `TempDir` must be kept alive while the files are used.
    async fn setup_test_skills() -> (TempDir, HashMap<String, String>) {
        let temp_dir = TempDir::new().unwrap();
        let skills_dir = temp_dir.path().join("skills");
        tokio::fs::create_dir_all(&skills_dir).await.unwrap();

        write(
            skills_dir.join("seo.skill.md"),
            "# SEO Writer\n\nYou are an expert SEO content writer.\n",
        )
        .await
        .unwrap();

        write(
            skills_dir.join("brand.skill.md"),
            "# Brand Voice\n\nMaintain a friendly, professional tone.\n",
        )
        .await
        .unwrap();

        let mut skills_map = HashMap::new();
        skills_map.insert("seo".to_string(), "./skills/seo.skill.md".to_string());
        skills_map.insert("brand".to_string(), "./skills/brand.skill.md".to_string());

        (temp_dir, skills_map)
    }

    #[tokio::test]
    async fn test_load_skill_success() {
        let (temp_dir, skills_map) = setup_test_skills().await;
        let injector = SkillInjector::new();

        let content = injector
            .load_skill(skills_map.get("seo").unwrap(), temp_dir.path())
            .await
            .unwrap();

        assert!(content.contains("SEO Writer"));
        assert!(content.contains("expert SEO content writer"));
    }

    #[tokio::test]
    async fn test_load_skill_caching() {
        let (temp_dir, skills_map) = setup_test_skills().await;
        let injector = SkillInjector::new();

        let content1 = injector
            .load_skill(skills_map.get("seo").unwrap(), temp_dir.path())
            .await
            .unwrap();
        assert_eq!(injector.cache_size(), 1);

        // Second load must be served from the cache.
        let content2 = injector
            .load_skill(skills_map.get("seo").unwrap(), temp_dir.path())
            .await
            .unwrap();

        // Same cached instance, not a re-read copy.
        assert!(Arc::ptr_eq(&content1, &content2));
        assert_eq!(injector.cache_size(), 1);
    }

    #[tokio::test]
    async fn test_load_skill_file_not_found() {
        let temp_dir = TempDir::new().unwrap();
        let injector = SkillInjector::new();

        let result = injector
            .load_skill("./nonexistent.skill.md", temp_dir.path())
            .await;

        assert!(matches!(result.unwrap_err(), NikaError::SkillLoadError { .. }));
        // Failed loads must not populate the cache.
        assert_eq!(injector.cache_size(), 0);
    }

    #[tokio::test]
    async fn test_inject_single_skill() {
        let (temp_dir, skills_map) = setup_test_skills().await;
        let injector = SkillInjector::new();

        let result = injector
            .inject(Some("Be helpful"), &["seo"], &skills_map, temp_dir.path())
            .await
            .unwrap();

        assert_eq!(result, format!("{}\nBe helpful", SEO_SECTION));
    }

    #[tokio::test]
    async fn test_inject_multiple_skills() {
        let (temp_dir, skills_map) = setup_test_skills().await;
        let injector = SkillInjector::new();

        let result = injector
            .inject(
                Some("Base prompt"),
                &["seo", "brand"],
                &skills_map,
                temp_dir.path(),
            )
            .await
            .unwrap();

        // Skills in requested order, base prompt at the end.
        assert_eq!(
            result,
            format!("{}\n{}\nBase prompt", SEO_SECTION, BRAND_SECTION)
        );
    }

    #[tokio::test]
    async fn test_inject_respects_skill_order() {
        let (temp_dir, skills_map) = setup_test_skills().await;
        let injector = SkillInjector::new();

        let result = injector
            .inject(None, &["brand", "seo"], &skills_map, temp_dir.path())
            .await
            .unwrap();

        assert_eq!(result, format!("{}\n{}", BRAND_SECTION, SEO_SECTION));
    }

    #[tokio::test]
    async fn test_inject_no_skills() {
        let temp_dir = TempDir::new().unwrap();
        let skills_map = HashMap::new();
        let injector = SkillInjector::new();

        let result = injector
            .inject(Some("Base prompt"), &[], &skills_map, temp_dir.path())
            .await
            .unwrap();

        assert_eq!(result, "Base prompt");
    }

    #[tokio::test]
    async fn test_inject_no_skills_no_base_prompt() {
        let temp_dir = TempDir::new().unwrap();
        let skills_map = HashMap::new();
        let injector = SkillInjector::new();

        let result = injector
            .inject(None, &[], &skills_map, temp_dir.path())
            .await
            .unwrap();

        assert_eq!(result, "");
    }

    #[tokio::test]
    async fn test_inject_no_base_prompt() {
        let (temp_dir, skills_map) = setup_test_skills().await;
        let injector = SkillInjector::new();

        let result = injector
            .inject(None, &["seo"], &skills_map, temp_dir.path())
            .await
            .unwrap();

        assert_eq!(result, SEO_SECTION);
    }

    #[tokio::test]
    async fn test_inject_skill_not_in_map() {
        let (temp_dir, skills_map) = setup_test_skills().await;
        let injector = SkillInjector::new();

        let result = injector
            .inject(Some("Base"), &["nonexistent"], &skills_map, temp_dir.path())
            .await;

        match result.unwrap_err() {
            NikaError::SkillLoadError { skill, reason } => {
                assert_eq!(skill, "nonexistent");
                assert!(reason.contains("not found in workflow skills: block"));
                // The message lists the aliases that do exist.
                assert!(reason.contains("seo"));
                assert!(reason.contains("brand"));
            }
            _ => panic!("Expected SkillLoadError"),
        }
    }

    #[tokio::test]
    async fn test_inject_skips_skill_that_fails_to_load() {
        let (temp_dir, mut skills_map) = setup_test_skills().await;
        skills_map.insert(
            "missing".to_string(),
            "./skills/missing.skill.md".to_string(),
        );
        let injector = SkillInjector::new();

        // A mapped alias whose file cannot be read is logged and skipped,
        // not reported as an error.
        let result = injector
            .inject(Some("Base"), &["missing", "seo"], &skills_map, temp_dir.path())
            .await
            .unwrap();

        assert_eq!(result, format!("{}\nBase", SEO_SECTION));
        assert!(!injector.is_cached("./skills/missing.skill.md", temp_dir.path()));
    }

    #[tokio::test]
    async fn test_clear_cache() {
        let (temp_dir, skills_map) = setup_test_skills().await;
        let injector = SkillInjector::new();

        injector
            .load_skill(skills_map.get("seo").unwrap(), temp_dir.path())
            .await
            .unwrap();
        assert_eq!(injector.cache_size(), 1);

        injector.clear_cache();
        assert_eq!(injector.cache_size(), 0);
    }

    #[tokio::test]
    async fn test_is_cached() {
        let (temp_dir, skills_map) = setup_test_skills().await;
        let injector = SkillInjector::new();

        let skill_path = skills_map.get("seo").unwrap();

        // Not cached initially.
        assert!(!injector.is_cached(skill_path, temp_dir.path()));

        injector
            .load_skill(skill_path, temp_dir.path())
            .await
            .unwrap();

        // Now cached.
        assert!(injector.is_cached(skill_path, temp_dir.path()));
    }

    #[test]
    fn test_default_impl() {
        let injector = SkillInjector::default();
        assert_eq!(injector.cache_size(), 0);
    }

    #[tokio::test]
    async fn test_inject_empty_base_prompt() {
        let (temp_dir, skills_map) = setup_test_skills().await;
        let injector = SkillInjector::new();

        let result = injector
            .inject(Some(""), &["seo"], &skills_map, temp_dir.path())
            .await
            .unwrap();

        // An empty base prompt contributes nothing, not even a trailing separator.
        assert_eq!(result, SEO_SECTION);
    }

    #[tokio::test]
    async fn test_concurrent_loads() {
        let (temp_dir, skills_map) = setup_test_skills().await;
        let injector = Arc::new(SkillInjector::new());

        let skill_path = skills_map.get("seo").unwrap().clone();
        let base_dir = temp_dir.path().to_path_buf();

        let handles: Vec<_> = (0..10)
            .map(|_| {
                let inj = Arc::clone(&injector);
                let path = skill_path.clone();
                let dir = base_dir.clone();
                tokio::spawn(async move { inj.load_skill(&path, &dir).await })
            })
            .collect();

        // Every concurrent load succeeds and sees the same content.
        for handle in handles {
            let content = handle.await.unwrap().unwrap();
            assert!(content.contains("SEO Writer"));
        }

        // Concurrent loads of one path produce a single cache entry.
        assert_eq!(injector.cache_size(), 1);
    }
}