1use async_trait::async_trait;
4use serde_json::json;
5
6use rig_compose::{Evidence, InvestigationContext, KernelError, Skill, SkillOutcome, ToolRegistry};
7
8use crate::memory::{MemoryLookupHit, memory_lookup_trace_envelope};
9
/// Skill that lowers investigation confidence when the observed behaviour is
/// inside the entity's known baseline.
///
/// It applies only when both the `baseline.available` and `baseline.within`
/// signals are present, and then records `baseline.suppress` evidence together
/// with a fixed `-0.2` confidence delta.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct BaselineCompareSkill;
15
16#[async_trait]
17impl Skill for BaselineCompareSkill {
18 fn id(&self) -> &str {
19 "general.baseline_compare"
20 }
21 fn description(&self) -> &str {
22 "Suppresses confidence when observed behaviour is within the entity's known baseline."
23 }
24 fn applies(&self, ctx: &InvestigationContext) -> bool {
25 ctx.has_signal("baseline.available") && ctx.has_signal("baseline.within")
26 }
27 async fn execute(
28 &self,
29 ctx: &mut InvestigationContext,
30 _tools: &ToolRegistry,
31 ) -> Result<SkillOutcome, KernelError> {
32 ctx.evidence
33 .push(Evidence::new(self.id(), "baseline.suppress"));
34 Ok(SkillOutcome::default().with_delta(-0.2))
35 }
36}
37
/// Skill that pivots into memory: once the investigation is confident enough,
/// it asks the `memory.lookup` tool for similar episodes about the entity.
#[derive(Debug, Clone, PartialEq)]
pub struct MemoryPivotSkill {
    /// Minimum investigation confidence (inclusive) at which the skill applies.
    pub min_confidence: f32,
    /// Number of hits requested from `memory.lookup` (sent as its `k` argument).
    pub k: usize,
}
45
46impl Default for MemoryPivotSkill {
47 fn default() -> Self {
48 Self {
49 min_confidence: 0.4,
50 k: 3,
51 }
52 }
53}
54
55#[async_trait]
56impl Skill for MemoryPivotSkill {
57 fn id(&self) -> &str {
58 "general.memory_pivot"
59 }
60 fn description(&self) -> &str {
61 "Retrieves similar episodes from memory once confidence is non-trivial."
62 }
63 fn applies(&self, ctx: &InvestigationContext) -> bool {
64 ctx.confidence >= self.min_confidence && !ctx.entity_id.is_empty()
65 }
66 async fn execute(
67 &self,
68 ctx: &mut InvestigationContext,
69 tools: &ToolRegistry,
70 ) -> Result<SkillOutcome, KernelError> {
71 let Ok(tool) = tools.get("memory.lookup") else {
72 return Ok(SkillOutcome::noop());
73 };
74 let v = tool
75 .invoke(json!({"query": ctx.entity_id, "k": self.k}))
76 .await?;
77
78 let hits_array = v.get("hits").and_then(|h| h.as_array()).cloned();
83 let typed_hits: Vec<MemoryLookupHit> = hits_array
84 .as_ref()
85 .and_then(|arr| serde_json::from_value(json!(arr)).ok())
86 .unwrap_or_default();
87
88 if let Some(arr) = hits_array.as_ref()
89 && let Some(hit) = arr.first()
90 {
91 ctx.evidence
92 .push(Evidence::new(self.id(), "memory.hit").with_detail(hit.clone()));
93 }
94
95 if let Some(arr) = hits_array.as_ref()
96 && (typed_hits.len() == arr.len())
97 {
98 let envelope =
99 memory_lookup_trace_envelope(&ctx.entity_id, self.k, &typed_hits, None, None);
100 ctx.evidence
101 .push(Evidence::new(self.id(), "memory.trace").with_detail(envelope.to_value()));
102 }
103
104 Ok(SkillOutcome::noop())
105 }
106}
107
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;
    use rig_compose::{LocalTool, Tool, ToolSchema};

    /// Builds a registry whose `memory.lookup` tool always answers with `response`.
    fn registry_with_lookup(response: serde_json::Value) -> ToolRegistry {
        let reg = ToolRegistry::new();
        let schema = ToolSchema {
            name: "memory.lookup".into(),
            description: "stub".into(),
            args_schema: json!({}),
            result_schema: json!({}),
        };
        let stub: Arc<dyn Tool> = Arc::new(LocalTool::new(schema, move |_v| {
            let body = response.clone();
            async move { Ok(body) }
        }));
        reg.register(stub);
        reg
    }

    #[tokio::test]
    async fn baseline_compare_suppresses_when_within() {
        let skill = BaselineCompareSkill;
        let reg = ToolRegistry::new();
        let mut ctx = InvestigationContext::new("a", "p")
            .with_signal("baseline.available")
            .with_signal("baseline.within");
        ctx.confidence = 0.5;
        let outcome = skill.execute(&mut ctx, &reg).await.unwrap();
        assert!(outcome.confidence_delta < 0.0);
        assert_eq!(ctx.evidence.len(), 1);
        assert_eq!(ctx.evidence[0].label, "baseline.suppress");
    }

    #[test]
    fn baseline_compare_requires_both_signals() {
        let skill = BaselineCompareSkill;
        let partial = InvestigationContext::new("a", "p").with_signal("baseline.available");
        assert!(!skill.applies(&partial));
        let full = InvestigationContext::new("a", "p")
            .with_signal("baseline.available")
            .with_signal("baseline.within");
        assert!(skill.applies(&full));
    }

    #[test]
    fn memory_pivot_applies_from_threshold() {
        let skill = MemoryPivotSkill::default();
        let mut ctx = InvestigationContext::new("e", "p");
        ctx.confidence = 0.3;
        assert!(!skill.applies(&ctx));
        ctx.confidence = 0.6;
        assert!(skill.applies(&ctx));
    }

    #[tokio::test]
    async fn memory_pivot_skipped_without_tool_authorisation() {
        let skill = MemoryPivotSkill::default();
        let reg = ToolRegistry::new();
        let mut ctx = InvestigationContext::new("e", "p");
        ctx.confidence = 0.6;
        let outcome = skill.execute(&mut ctx, &reg).await.unwrap();
        assert_eq!(outcome.confidence_delta, 0.0);
        assert!(ctx.evidence.is_empty());
    }

    #[tokio::test]
    async fn memory_pivot_records_top_hit() {
        let skill = MemoryPivotSkill::default();
        let reg =
            registry_with_lookup(json!({"hits": [{"score": 0.9, "summary": "match", "key": "k"}]}));
        let mut ctx = InvestigationContext::new("e", "p");
        ctx.confidence = 0.6;
        skill.execute(&mut ctx, &reg).await.unwrap();
        // One raw top-hit record followed by the typed trace envelope.
        assert_eq!(ctx.evidence.len(), 2);
        assert_eq!(ctx.evidence[0].label, "memory.hit");
        assert_eq!(ctx.evidence[1].label, "memory.trace");
        let trace = &ctx.evidence[1].detail;
        assert_eq!(trace["resource"], "memory");
        assert_eq!(trace["operation"], "lookup");
        assert_eq!(trace["output_summary"]["hit_count"], 1);
        assert_eq!(trace["output_summary"]["top_key"], "k");
    }

    #[tokio::test]
    async fn memory_pivot_emits_no_hits_trace_when_empty() {
        let skill = MemoryPivotSkill::default();
        let reg = registry_with_lookup(json!({"hits": []}));
        let mut ctx = InvestigationContext::new("nothing", "p");
        ctx.confidence = 0.6;
        skill.execute(&mut ctx, &reg).await.unwrap();
        // No top hit to record, but an empty lookup is still traced.
        assert_eq!(ctx.evidence.len(), 1);
        assert_eq!(ctx.evidence[0].label, "memory.trace");
        let trace = &ctx.evidence[0].detail;
        assert_eq!(trace["output_summary"]["hit_count"], 0);
        assert_eq!(trace["reason"], "no_hits");
    }

    #[tokio::test]
    async fn memory_pivot_records_nothing_without_hits_array() {
        let skill = MemoryPivotSkill::default();
        let reg = registry_with_lookup(json!({}));
        let mut ctx = InvestigationContext::new("e", "p");
        ctx.confidence = 0.6;
        let outcome = skill.execute(&mut ctx, &reg).await.unwrap();
        assert_eq!(outcome.confidence_delta, 0.0);
        assert!(ctx.evidence.is_empty());
    }

    #[tokio::test]
    async fn memory_pivot_skips_trace_for_malformed_hits() {
        let skill = MemoryPivotSkill::default();
        let reg = registry_with_lookup(json!({"hits": [42]}));
        let mut ctx = InvestigationContext::new("e", "p");
        ctx.confidence = 0.6;
        skill.execute(&mut ctx, &reg).await.unwrap();
        // The raw top hit is still recorded; the typed trace is skipped.
        assert_eq!(ctx.evidence.len(), 1);
        assert_eq!(ctx.evidence[0].label, "memory.hit");
    }
}