// rig_resources/skills.rs

//! Prebuilt domain-neutral skills.
2
3use async_trait::async_trait;
4use serde_json::json;
5
6use rig_compose::{Evidence, InvestigationContext, KernelError, Skill, SkillOutcome, ToolRegistry};
7
/// `general.baseline_compare` — suppresses confidence when behaviour falls
/// inside the entity's known baseline.
///
/// Conservative by design: the skill only applies (see `Skill::applies`)
/// when *both* the `baseline.available` and `baseline.within` signals are
/// present. When it runs it records `baseline.suppress` evidence and lowers
/// confidence by 0.2.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct BaselineCompareSkill;
13
14#[async_trait]
15impl Skill for BaselineCompareSkill {
16    fn id(&self) -> &str {
17        "general.baseline_compare"
18    }
19    fn description(&self) -> &str {
20        "Suppresses confidence when observed behaviour is within the entity's known baseline."
21    }
22    fn applies(&self, ctx: &InvestigationContext) -> bool {
23        ctx.has_signal("baseline.available") && ctx.has_signal("baseline.within")
24    }
25    async fn execute(
26        &self,
27        ctx: &mut InvestigationContext,
28        _tools: &ToolRegistry,
29    ) -> Result<SkillOutcome, KernelError> {
30        ctx.evidence
31            .push(Evidence::new(self.id(), "baseline.suppress"));
32        Ok(SkillOutcome::default().with_delta(-0.2))
33    }
34}
35
/// `general.memory_pivot` — calls `memory.lookup` once confidence has
/// reached `min_confidence`. Records the first returned hit (presumably the
/// best match — ordering is up to the tool) as evidence; never adjusts
/// confidence on its own (memory is context, not a verdict).
#[derive(Debug, Clone, PartialEq)]
pub struct MemoryPivotSkill {
    /// Confidence at or above which the skill applies (inclusive, `>=`).
    pub min_confidence: f32,
    /// Number of hits requested from `memory.lookup`, sent as its `k` argument.
    pub k: usize,
}

impl Default for MemoryPivotSkill {
    /// `min_confidence = 0.4`, `k = 3`.
    fn default() -> Self {
        Self {
            min_confidence: 0.4,
            k: 3,
        }
    }
}
52
53#[async_trait]
54impl Skill for MemoryPivotSkill {
55    fn id(&self) -> &str {
56        "general.memory_pivot"
57    }
58    fn description(&self) -> &str {
59        "Retrieves similar episodes from memory once confidence is non-trivial."
60    }
61    fn applies(&self, ctx: &InvestigationContext) -> bool {
62        ctx.confidence >= self.min_confidence && !ctx.entity_id.is_empty()
63    }
64    async fn execute(
65        &self,
66        ctx: &mut InvestigationContext,
67        tools: &ToolRegistry,
68    ) -> Result<SkillOutcome, KernelError> {
69        let Ok(tool) = tools.get("memory.lookup") else {
70            return Ok(SkillOutcome::noop());
71        };
72        let v = tool
73            .invoke(json!({"query": ctx.entity_id, "k": self.k}))
74            .await?;
75        let top = v
76            .get("hits")
77            .and_then(|h| h.as_array())
78            .and_then(|a| a.first())
79            .cloned();
80        if let Some(hit) = top {
81            ctx.evidence
82                .push(Evidence::new(self.id(), "memory.hit").with_detail(hit));
83        }
84        Ok(SkillOutcome::noop())
85    }
86}
87
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;
    use rig_compose::{LocalTool, Tool, ToolSchema};

    #[tokio::test]
    async fn baseline_compare_suppresses_when_within() {
        let skill = BaselineCompareSkill;
        let reg = ToolRegistry::new();
        let mut ctx = InvestigationContext::new("a", "p")
            .with_signal("baseline.available")
            .with_signal("baseline.within");
        ctx.confidence = 0.5;
        let outcome = skill.execute(&mut ctx, &reg).await.unwrap();
        assert!(outcome.confidence_delta < 0.0);
        assert_eq!(ctx.evidence.len(), 1);
        assert_eq!(ctx.evidence[0].label, "baseline.suppress");
    }

    /// `applies` is the gate for suppression: both baseline signals are required.
    #[test]
    fn baseline_compare_requires_both_signals() {
        let skill = BaselineCompareSkill;
        assert!(!skill.applies(&InvestigationContext::new("a", "p")));
        assert!(!skill.applies(
            &InvestigationContext::new("a", "p").with_signal("baseline.available")
        ));
        assert!(!skill.applies(
            &InvestigationContext::new("a", "p").with_signal("baseline.within")
        ));
        assert!(skill.applies(
            &InvestigationContext::new("a", "p")
                .with_signal("baseline.available")
                .with_signal("baseline.within")
        ));
    }

    /// Below `min_confidence` the memory pivot must not fire.
    #[test]
    fn memory_pivot_does_not_apply_below_threshold() {
        let skill = MemoryPivotSkill::default();
        let mut ctx = InvestigationContext::new("e", "p");
        ctx.confidence = 0.3;
        assert!(!skill.applies(&ctx));
    }

    #[tokio::test]
    async fn memory_pivot_skipped_without_tool_authorisation() {
        let skill = MemoryPivotSkill::default();
        let reg = ToolRegistry::new();
        let mut ctx = InvestigationContext::new("e", "p");
        ctx.confidence = 0.6;
        let outcome = skill.execute(&mut ctx, &reg).await.unwrap();
        assert_eq!(outcome.confidence_delta, 0.0);
        assert!(ctx.evidence.is_empty());
    }

    #[tokio::test]
    async fn memory_pivot_records_top_hit() {
        let skill = MemoryPivotSkill::default();
        let reg = ToolRegistry::new();
        let schema = ToolSchema {
            name: "memory.lookup".into(),
            description: "stub".into(),
            args_schema: json!({}),
            result_schema: json!({}),
        };
        let stub: Arc<dyn Tool> = Arc::new(LocalTool::new(schema, |_v| async {
            Ok(json!({"hits": [{"score": 0.9, "summary": "match", "episode_key": "k"}]}))
        }));
        reg.register(stub);
        let mut ctx = InvestigationContext::new("e", "p");
        ctx.confidence = 0.6;
        let outcome = skill.execute(&mut ctx, &reg).await.unwrap();
        // Memory is context, not a verdict: confidence must be left untouched.
        assert_eq!(outcome.confidence_delta, 0.0);
        assert_eq!(ctx.evidence.len(), 1);
        assert_eq!(ctx.evidence[0].label, "memory.hit");
    }
}