1use async_trait::async_trait;
4use serde_json::json;
5
6use rig_compose::{Evidence, InvestigationContext, KernelError, Skill, SkillOutcome, ToolRegistry};
7
8use crate::memory::{MemoryLookupHit, memory_lookup_trace_envelope};
9
/// Stateless skill that lowers investigation confidence when the context
/// signals that observed behaviour falls within the entity's known baseline.
///
/// The type carries no data, so it is trivially copyable and printable.
#[derive(Debug, Clone, Copy, Default)]
pub struct BaselineCompareSkill;
15
16#[async_trait]
17impl Skill for BaselineCompareSkill {
18 fn id(&self) -> &str {
19 "general.baseline_compare"
20 }
21 fn description(&self) -> &str {
22 "Suppresses confidence when observed behaviour is within the entity's known baseline."
23 }
24 fn applies(&self, ctx: &InvestigationContext) -> bool {
25 ctx.has_signal("baseline.available") && ctx.has_signal("baseline.within")
26 }
27 async fn execute(
28 &self,
29 ctx: &mut InvestigationContext,
30 _tools: &ToolRegistry,
31 ) -> Result<SkillOutcome, KernelError> {
32 ctx.evidence
33 .push(Evidence::new(self.id(), "baseline.suppress"));
34 Ok(SkillOutcome::default().with_delta(-0.2))
35 }
36}
37
/// Configuration for the memory-pivot skill, which queries the
/// `memory.lookup` tool for entries related to the entity under investigation.
#[derive(Debug, Clone, PartialEq)]
pub struct MemoryPivotSkill {
    /// Minimum context confidence (inclusive) at which the skill applies.
    pub min_confidence: f32,
    /// Value sent as `k` in the `memory.lookup` request.
    pub k: usize,
}

impl Default for MemoryPivotSkill {
    /// Builds the default configuration: `min_confidence = 0.4`, `k = 3`.
    fn default() -> Self {
        Self {
            min_confidence: 0.4,
            k: 3,
        }
    }
}
56
57#[async_trait]
58impl Skill for MemoryPivotSkill {
59 fn id(&self) -> &str {
60 "general.memory_pivot"
61 }
62 fn description(&self) -> &str {
63 "Retrieves similar episodes from memory once confidence is non-trivial."
64 }
65 fn applies(&self, ctx: &InvestigationContext) -> bool {
66 ctx.confidence >= self.min_confidence && !ctx.entity_id.is_empty()
67 }
68 async fn execute(
69 &self,
70 ctx: &mut InvestigationContext,
71 tools: &ToolRegistry,
72 ) -> Result<SkillOutcome, KernelError> {
73 let Ok(tool) = tools.get("memory.lookup") else {
74 return Ok(SkillOutcome::noop());
75 };
76 let v = tool
77 .invoke(json!({"query": ctx.entity_id, "k": self.k}))
78 .await?;
79
80 let hits_array = v.get("hits").and_then(|h| h.as_array()).cloned();
85 let typed_hits: Vec<MemoryLookupHit> = hits_array
86 .as_ref()
87 .and_then(|arr| serde_json::from_value(json!(arr)).ok())
88 .unwrap_or_default();
89
90 if let Some(arr) = hits_array.as_ref()
91 && let Some(hit) = arr.first()
92 {
93 ctx.evidence
94 .push(Evidence::new(self.id(), "memory.hit").with_detail(hit.clone()));
95 }
96
97 if let Some(arr) = hits_array.as_ref()
98 && (typed_hits.len() == arr.len())
99 {
100 let envelope =
101 memory_lookup_trace_envelope(&ctx.entity_id, self.k, &typed_hits, None, None);
102 ctx.evidence
103 .push(Evidence::new(self.id(), "memory.trace").with_detail(envelope.to_value()));
104 }
105
106 Ok(SkillOutcome::noop())
107 }
108}
109
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;
    use rig_compose::{LocalTool, Tool, ToolSchema};

    /// With both baseline signals present, the skill yields a negative delta.
    #[tokio::test]
    async fn baseline_compare_suppresses_when_within() {
        let skill = BaselineCompareSkill;
        let reg = ToolRegistry::new();
        let mut ctx = InvestigationContext::new("a", "p")
            .with_signal("baseline.available")
            .with_signal("baseline.within");
        ctx.confidence = 0.5;
        let outcome = skill.execute(&mut ctx, &reg).await.unwrap();
        assert!(outcome.confidence_delta < 0.0);
    }

    /// With no `memory.lookup` tool registered, the skill is a no-op and
    /// records no evidence.
    #[tokio::test]
    async fn memory_pivot_skipped_without_tool_authorisation() {
        let skill = MemoryPivotSkill::default();
        let reg = ToolRegistry::new();
        let mut ctx = InvestigationContext::new("e", "p");
        ctx.confidence = 0.6;
        let outcome = skill.execute(&mut ctx, &reg).await.unwrap();
        assert_eq!(outcome.confidence_delta, 0.0);
        assert!(ctx.evidence.is_empty());
    }

    /// A single well-formed hit produces a `memory.hit` entry followed by a
    /// `memory.trace` entry summarising the lookup.
    #[tokio::test]
    async fn memory_pivot_records_top_hit() {
        let skill = MemoryPivotSkill::default();
        let reg = ToolRegistry::new();
        let schema = ToolSchema {
            name: "memory.lookup".into(),
            description: "stub".into(),
            args_schema: json!({}),
            result_schema: json!({}),
        };
        let stub: Arc<dyn Tool> = Arc::new(LocalTool::new(schema, |_v| async {
            Ok(json!({"hits": [{"score": 0.9, "summary": "match", "key": "k"}]}))
        }));
        reg.register(stub);
        let mut ctx = InvestigationContext::new("e", "p");
        ctx.confidence = 0.6;
        skill.execute(&mut ctx, &reg).await.unwrap();
        assert_eq!(ctx.evidence.len(), 2);
        // Order matters: the raw top hit first, then the trace envelope.
        assert_eq!(ctx.evidence[0].label, "memory.hit");
        assert_eq!(ctx.evidence[1].label, "memory.trace");
        let trace = &ctx.evidence[1].detail;
        assert_eq!(trace["resource"], "memory");
        assert_eq!(trace["operation"], "lookup");
        assert_eq!(trace["output_summary"]["hit_count"], 1);
        assert_eq!(trace["output_summary"]["top_key"], "k");
    }

    /// An empty `hits` array records only a `memory.trace` entry, marked
    /// with the `no_hits` reason.
    #[tokio::test]
    async fn memory_pivot_emits_no_hits_trace_when_empty() {
        let skill = MemoryPivotSkill::default();
        let reg = ToolRegistry::new();
        let schema = ToolSchema {
            name: "memory.lookup".into(),
            description: "stub".into(),
            args_schema: json!({}),
            result_schema: json!({}),
        };
        let stub: Arc<dyn Tool> = Arc::new(LocalTool::new(schema, |_v| async {
            Ok(json!({"hits": []}))
        }));
        reg.register(stub);
        let mut ctx = InvestigationContext::new("nothing", "p");
        ctx.confidence = 0.6;
        skill.execute(&mut ctx, &reg).await.unwrap();
        assert_eq!(ctx.evidence.len(), 1);
        // No top hit exists, so the trace is the only evidence entry.
        assert_eq!(ctx.evidence[0].label, "memory.trace");
        let trace = &ctx.evidence[0].detail;
        assert_eq!(trace["output_summary"]["hit_count"], 0);
        assert_eq!(trace["reason"], "no_hits");
    }
}