1use async_trait::async_trait;
4use serde_json::json;
5
6use rig_compose::{Evidence, InvestigationContext, KernelError, Skill, SkillOutcome, ToolRegistry};
7
/// Skill that lowers investigation confidence when the observed behaviour
/// lies within the entity's known baseline.
///
/// The skill is stateless. It applies only when the context carries both the
/// `baseline.available` and `baseline.within` signals. When executed, it records a
/// `baseline.suppress` evidence item and returns a fixed negative confidence delta.
#[derive(Debug, Clone, Copy, Default)]
pub struct BaselineCompareSkill;
13
14#[async_trait]
15impl Skill for BaselineCompareSkill {
16 fn id(&self) -> &str {
17 "general.baseline_compare"
18 }
19 fn description(&self) -> &str {
20 "Suppresses confidence when observed behaviour is within the entity's known baseline."
21 }
22 fn applies(&self, ctx: &InvestigationContext) -> bool {
23 ctx.has_signal("baseline.available") && ctx.has_signal("baseline.within")
24 }
25 async fn execute(
26 &self,
27 ctx: &mut InvestigationContext,
28 _tools: &ToolRegistry,
29 ) -> Result<SkillOutcome, KernelError> {
30 ctx.evidence
31 .push(Evidence::new(self.id(), "baseline.suppress"));
32 Ok(SkillOutcome::default().with_delta(-0.2))
33 }
34}
35
/// Skill that looks up similar past episodes via the `memory.lookup` tool once
/// the investigation's confidence is high enough. The top hit is recorded as
/// `memory.hit` evidence.
#[derive(Debug, Clone)]
pub struct MemoryPivotSkill {
    /// Minimum `ctx.confidence` (inclusive) at which the skill applies.
    pub min_confidence: f32,
    /// Number of episodes requested from `memory.lookup`, sent as its `k` argument.
    pub k: usize,
}
43
44impl Default for MemoryPivotSkill {
45 fn default() -> Self {
46 Self {
47 min_confidence: 0.4,
48 k: 3,
49 }
50 }
51}
52
53#[async_trait]
54impl Skill for MemoryPivotSkill {
55 fn id(&self) -> &str {
56 "general.memory_pivot"
57 }
58 fn description(&self) -> &str {
59 "Retrieves similar episodes from memory once confidence is non-trivial."
60 }
61 fn applies(&self, ctx: &InvestigationContext) -> bool {
62 ctx.confidence >= self.min_confidence && !ctx.entity_id.is_empty()
63 }
64 async fn execute(
65 &self,
66 ctx: &mut InvestigationContext,
67 tools: &ToolRegistry,
68 ) -> Result<SkillOutcome, KernelError> {
69 let Ok(tool) = tools.get("memory.lookup") else {
70 return Ok(SkillOutcome::noop());
71 };
72 let v = tool
73 .invoke(json!({"query": ctx.entity_id, "k": self.k}))
74 .await?;
75 let top = v
76 .get("hits")
77 .and_then(|h| h.as_array())
78 .and_then(|a| a.first())
79 .cloned();
80 if let Some(hit) = top {
81 ctx.evidence
82 .push(Evidence::new(self.id(), "memory.hit").with_detail(hit));
83 }
84 Ok(SkillOutcome::noop())
85 }
86}
87
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;
    use rig_compose::{LocalTool, Tool, ToolSchema};

    #[tokio::test]
    async fn baseline_compare_suppresses_when_within() {
        let skill = BaselineCompareSkill;
        let reg = ToolRegistry::new();
        let mut ctx = InvestigationContext::new("a", "p")
            .with_signal("baseline.available")
            .with_signal("baseline.within");
        ctx.confidence = 0.5;
        assert!(skill.applies(&ctx));
        let outcome = skill.execute(&mut ctx, &reg).await.unwrap();
        assert!(outcome.confidence_delta < 0.0);
        // The suppression must be explained by exactly one evidence item.
        assert_eq!(ctx.evidence.len(), 1);
        assert_eq!(ctx.evidence[0].label, "baseline.suppress");
    }

    #[test]
    fn baseline_compare_requires_both_signals() {
        let skill = BaselineCompareSkill;
        let only_available =
            InvestigationContext::new("a", "p").with_signal("baseline.available");
        assert!(!skill.applies(&only_available));
    }

    #[test]
    fn memory_pivot_does_not_apply_below_threshold() {
        let skill = MemoryPivotSkill::default();
        let mut ctx = InvestigationContext::new("e", "p");
        ctx.confidence = 0.1;
        assert!(!skill.applies(&ctx));
    }

    #[tokio::test]
    async fn memory_pivot_skipped_without_tool_authorisation() {
        let skill = MemoryPivotSkill::default();
        let reg = ToolRegistry::new();
        let mut ctx = InvestigationContext::new("e", "p");
        ctx.confidence = 0.6;
        let outcome = skill.execute(&mut ctx, &reg).await.unwrap();
        assert_eq!(outcome.confidence_delta, 0.0);
        assert!(ctx.evidence.is_empty());
    }

    #[tokio::test]
    async fn memory_pivot_records_top_hit() {
        let skill = MemoryPivotSkill::default();
        let reg = ToolRegistry::new();
        let schema = ToolSchema {
            name: "memory.lookup".into(),
            description: "stub".into(),
            args_schema: json!({}),
            result_schema: json!({}),
        };
        let stub: Arc<dyn Tool> = Arc::new(LocalTool::new(schema, |_v| async {
            Ok(json!({"hits": [{"score": 0.9, "summary": "match", "episode_key": "k"}]}))
        }));
        reg.register(stub);
        let mut ctx = InvestigationContext::new("e", "p");
        ctx.confidence = 0.6;
        skill.execute(&mut ctx, &reg).await.unwrap();
        assert_eq!(ctx.evidence.len(), 1);
        assert_eq!(ctx.evidence[0].label, "memory.hit");
    }
}