Skip to main content

rig_resources/
skills.rs

1//! Prebuilt domain-neutral skills.
2
3use async_trait::async_trait;
4use serde_json::json;
5
6use rig_compose::{Evidence, InvestigationContext, KernelError, Skill, SkillOutcome, ToolRegistry};
7
8use crate::memory::{MemoryLookupHit, memory_lookup_trace_envelope};
9
/// `general.baseline_compare` — suppresses confidence when behaviour falls
/// inside the entity's known baseline.
///
/// Conservative by design: the skill only applies when *both* the
/// `baseline.available` and `baseline.within` signals are present; if either
/// is missing it is a no-op. When it applies it records a
/// `baseline.suppress` evidence entry and lowers confidence by a fixed `0.2`.
#[derive(Debug, Default, Clone, Copy)]
pub struct BaselineCompareSkill;
15
16#[async_trait]
17impl Skill for BaselineCompareSkill {
18    fn id(&self) -> &str {
19        "general.baseline_compare"
20    }
21    fn description(&self) -> &str {
22        "Suppresses confidence when observed behaviour is within the entity's known baseline."
23    }
24    fn applies(&self, ctx: &InvestigationContext) -> bool {
25        ctx.has_signal("baseline.available") && ctx.has_signal("baseline.within")
26    }
27    async fn execute(
28        &self,
29        ctx: &mut InvestigationContext,
30        _tools: &ToolRegistry,
31    ) -> Result<SkillOutcome, KernelError> {
32        ctx.evidence
33            .push(Evidence::new(self.id(), "baseline.suppress"));
34        Ok(SkillOutcome::default().with_delta(-0.2))
35    }
36}
37
/// `general.memory_pivot` — calls `memory.lookup` once confidence has
/// crossed `min_confidence` (and the entity id is non-empty).
///
/// Records the top hit as evidence and, when the tool's response decodes as
/// typed memory hits, a trace envelope as well. It never adjusts confidence
/// on its own: memory is context, not a verdict.
#[derive(Debug, Clone)]
pub struct MemoryPivotSkill {
    /// Minimum investigation confidence (inclusive) at which the skill applies.
    pub min_confidence: f32,
    /// Number of hits requested from `memory.lookup` (sent as its `k` argument).
    pub k: usize,
}

impl Default for MemoryPivotSkill {
    /// Applies from confidence `0.4` and requests the top `3` hits.
    fn default() -> Self {
        Self {
            min_confidence: 0.4,
            k: 3,
        }
    }
}
54
55#[async_trait]
56impl Skill for MemoryPivotSkill {
57    fn id(&self) -> &str {
58        "general.memory_pivot"
59    }
60    fn description(&self) -> &str {
61        "Retrieves similar episodes from memory once confidence is non-trivial."
62    }
63    fn applies(&self, ctx: &InvestigationContext) -> bool {
64        ctx.confidence >= self.min_confidence && !ctx.entity_id.is_empty()
65    }
66    async fn execute(
67        &self,
68        ctx: &mut InvestigationContext,
69        tools: &ToolRegistry,
70    ) -> Result<SkillOutcome, KernelError> {
71        let Ok(tool) = tools.get("memory.lookup") else {
72            return Ok(SkillOutcome::noop());
73        };
74        let v = tool
75            .invoke(json!({"query": ctx.entity_id, "k": self.k}))
76            .await?;
77
78        // Decode typed hits when the tool conforms to MemoryLookupTool's
79        // schema. Stores that emit a different shape get the legacy
80        // raw-JSON evidence path without the trace envelope; this keeps
81        // the skill backward-compatible with non-canonical memory tools.
82        let hits_array = v.get("hits").and_then(|h| h.as_array()).cloned();
83        let typed_hits: Vec<MemoryLookupHit> = hits_array
84            .as_ref()
85            .and_then(|arr| serde_json::from_value(json!(arr)).ok())
86            .unwrap_or_default();
87
88        if let Some(arr) = hits_array.as_ref()
89            && let Some(hit) = arr.first()
90        {
91            ctx.evidence
92                .push(Evidence::new(self.id(), "memory.hit").with_detail(hit.clone()));
93        }
94
95        if let Some(arr) = hits_array.as_ref()
96            && (typed_hits.len() == arr.len())
97        {
98            let envelope =
99                memory_lookup_trace_envelope(&ctx.entity_id, self.k, &typed_hits, None, None);
100            ctx.evidence
101                .push(Evidence::new(self.id(), "memory.trace").with_detail(envelope.to_value()));
102        }
103
104        Ok(SkillOutcome::noop())
105    }
106}
107
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;
    use rig_compose::{LocalTool, Tool, ToolSchema};

    /// Registers a `memory.lookup` stub on `reg` that always answers with
    /// `response`, regardless of the arguments it is invoked with.
    fn register_memory_stub(reg: &ToolRegistry, response: serde_json::Value) {
        let schema = ToolSchema {
            name: "memory.lookup".into(),
            description: "stub".into(),
            args_schema: json!({}),
            result_schema: json!({}),
        };
        let stub: Arc<dyn Tool> = Arc::new(LocalTool::new(schema, move |_v| {
            let response = response.clone();
            async move { Ok(response) }
        }));
        reg.register(stub);
    }

    #[tokio::test]
    async fn baseline_compare_suppresses_when_within() {
        let skill = BaselineCompareSkill;
        let reg = ToolRegistry::new();
        let mut ctx = InvestigationContext::new("a", "p")
            .with_signal("baseline.available")
            .with_signal("baseline.within");
        ctx.confidence = 0.5;
        let outcome = skill.execute(&mut ctx, &reg).await.unwrap();
        assert!(outcome.confidence_delta < 0.0);
        assert_eq!(ctx.evidence.len(), 1);
        assert_eq!(ctx.evidence[0].label, "baseline.suppress");
    }

    #[tokio::test]
    async fn memory_pivot_skipped_without_tool_authorisation() {
        let skill = MemoryPivotSkill::default();
        let reg = ToolRegistry::new();
        let mut ctx = InvestigationContext::new("e", "p");
        ctx.confidence = 0.6;
        let outcome = skill.execute(&mut ctx, &reg).await.unwrap();
        assert_eq!(outcome.confidence_delta, 0.0);
        assert!(ctx.evidence.is_empty());
    }

    #[tokio::test]
    async fn memory_pivot_records_top_hit() {
        let skill = MemoryPivotSkill::default();
        let reg = ToolRegistry::new();
        register_memory_stub(
            &reg,
            json!({"hits": [{"score": 0.9, "summary": "match", "key": "k"}]}),
        );
        let mut ctx = InvestigationContext::new("e", "p");
        ctx.confidence = 0.6;
        skill.execute(&mut ctx, &reg).await.unwrap();
        // memory.hit (raw top JSON) + memory.trace (trace envelope)
        assert_eq!(ctx.evidence.len(), 2);
        assert_eq!(ctx.evidence[0].label, "memory.hit");
        assert_eq!(ctx.evidence[0].detail["key"], "k");
        assert_eq!(ctx.evidence[1].label, "memory.trace");
        let trace = &ctx.evidence[1].detail;
        assert_eq!(trace["resource"], "memory");
        assert_eq!(trace["operation"], "lookup");
        assert_eq!(trace["output_summary"]["hit_count"], 1);
        assert_eq!(trace["output_summary"]["top_key"], "k");
    }

    #[tokio::test]
    async fn memory_pivot_emits_no_hits_trace_when_empty() {
        let skill = MemoryPivotSkill::default();
        let reg = ToolRegistry::new();
        register_memory_stub(&reg, json!({"hits": []}));
        let mut ctx = InvestigationContext::new("nothing", "p");
        ctx.confidence = 0.6;
        skill.execute(&mut ctx, &reg).await.unwrap();
        // Only memory.trace — no memory.hit when the array is empty.
        assert_eq!(ctx.evidence.len(), 1);
        assert_eq!(ctx.evidence[0].label, "memory.trace");
        let trace = &ctx.evidence[0].detail;
        assert_eq!(trace["output_summary"]["hit_count"], 0);
        assert_eq!(trace["reason"], "no_hits");
    }

    #[tokio::test]
    async fn memory_pivot_skips_trace_for_non_canonical_hits() {
        let skill = MemoryPivotSkill::default();
        let reg = ToolRegistry::new();
        // A bare number cannot decode as a typed memory hit, so only the
        // legacy raw-JSON evidence should be recorded.
        register_memory_stub(&reg, json!({"hits": [42]}));
        let mut ctx = InvestigationContext::new("e", "p");
        ctx.confidence = 0.6;
        let outcome = skill.execute(&mut ctx, &reg).await.unwrap();
        assert_eq!(outcome.confidence_delta, 0.0);
        assert_eq!(ctx.evidence.len(), 1);
        assert_eq!(ctx.evidence[0].label, "memory.hit");
        assert_eq!(ctx.evidence[0].detail, json!(42));
    }

    #[tokio::test]
    async fn memory_pivot_ignores_response_without_hits() {
        let skill = MemoryPivotSkill::default();
        let reg = ToolRegistry::new();
        register_memory_stub(&reg, json!({"results": []}));
        let mut ctx = InvestigationContext::new("e", "p");
        ctx.confidence = 0.6;
        let outcome = skill.execute(&mut ctx, &reg).await.unwrap();
        assert_eq!(outcome.confidence_delta, 0.0);
        assert!(ctx.evidence.is_empty());
    }
}