1use std::sync::Arc;
10
11use async_trait::async_trait;
12use serde::{Deserialize, Serialize};
13use serde_json::{Value, json};
14use thiserror::Error;
15
16use rig_compose::{KernelError, Tool, ToolSchema};
17
18use crate::trace::ResourceTraceEnvelope;
19
/// `reason` attached by [`memory_lookup_trace_envelope`] when a lookup produced no hits.
pub const TRACE_REASON_NO_HITS: &str = "no_hits";
/// `reason` value for a failed backend lookup.
///
/// NOTE(review): not referenced inside this module; presumably callers pass it to
/// `ResourceTraceEnvelope::with_reason` themselves when the store errors — confirm.
pub const TRACE_REASON_BACKEND_ERROR: &str = "backend_error";

/// `resource` identifier stamped on every memory-lookup trace envelope.
const TRACE_RESOURCE: &str = "memory";
/// `operation` identifier stamped on every memory-lookup trace envelope.
const TRACE_OPERATION: &str = "lookup";
/// `trace_kind` identifier stamped on every memory-lookup trace envelope.
const TRACE_KIND: &str = "memory_lookup";
29
/// Errors a [`MemoryLookupStore`] implementation can report from `lookup`.
#[derive(Debug, Error)]
pub enum MemoryLookupError {
    /// The backing store failed; the payload is embedded verbatim in the `Display` message
    /// (`"memory lookup backend error: <payload>"`).
    #[error("memory lookup backend error: {0}")]
    Backend(String),
}
37
38#[derive(Debug, Clone, Serialize, Deserialize)]
40pub struct MemoryLookupHit {
41 pub score: f32,
43 pub summary: String,
45 #[serde(skip_serializing_if = "Option::is_none")]
47 pub key: Option<String>,
48 #[serde(default, skip_serializing_if = "Option::is_none")]
50 pub source_uri: Option<String>,
51 #[serde(default, skip_serializing_if = "Option::is_none")]
53 pub principal: Option<String>,
54 #[serde(default, skip_serializing_if = "Option::is_none")]
57 pub scope: Option<String>,
58 #[serde(default, skip_serializing_if = "Option::is_none")]
60 pub recorded_at_millis: Option<i64>,
61 #[serde(default, skip_serializing_if = "Value::is_null")]
63 pub metadata: Value,
64}
65
66impl MemoryLookupHit {
67 pub fn new(score: f32, summary: impl Into<String>) -> Self {
69 Self {
70 score,
71 summary: summary.into(),
72 key: None,
73 source_uri: None,
74 principal: None,
75 scope: None,
76 recorded_at_millis: None,
77 metadata: Value::Null,
78 }
79 }
80
81 pub fn with_key(mut self, key: impl Into<String>) -> Self {
83 self.key = Some(key.into());
84 self
85 }
86
87 pub fn with_source_uri(mut self, source_uri: impl Into<String>) -> Self {
89 self.source_uri = Some(source_uri.into());
90 self
91 }
92
93 pub fn with_principal(mut self, principal: impl Into<String>) -> Self {
95 self.principal = Some(principal.into());
96 self
97 }
98
99 pub fn with_scope(mut self, scope: impl Into<String>) -> Self {
101 self.scope = Some(scope.into());
102 self
103 }
104
105 pub fn with_recorded_at_millis(mut self, recorded_at_millis: i64) -> Self {
107 self.recorded_at_millis = Some(recorded_at_millis);
108 self
109 }
110
111 pub fn with_metadata(mut self, metadata: Value) -> Self {
113 self.metadata = metadata;
114 self
115 }
116}
117
/// Backend abstraction queried by [`MemoryLookupTool`] for memory episodes similar to a query.
///
/// Implementations must be `Send + Sync` because the tool shares them behind an `Arc`.
#[async_trait]
pub trait MemoryLookupStore: Send + Sync {
    /// Looks up episodes similar to `query`, requesting at most `k` of them.
    ///
    /// `memory_lookup_trace_envelope` treats the first returned hit as the "top" hit, so
    /// implementations presumably return hits best-first — confirm per backend.
    ///
    /// # Errors
    ///
    /// Returns [`MemoryLookupError`] when the backend fails.
    async fn lookup(
        &self,
        query: &str,
        k: usize,
    ) -> Result<Vec<MemoryLookupHit>, MemoryLookupError>;
}
128
129pub struct MemoryLookupTool {
131 store: Arc<dyn MemoryLookupStore>,
132}
133
134impl MemoryLookupTool {
135 pub const NAME: &'static str = "memory.lookup";
137
138 pub fn new(store: Arc<dyn MemoryLookupStore>) -> Self {
140 Self { store }
141 }
142
143 pub fn arc(store: Arc<dyn MemoryLookupStore>) -> Arc<dyn Tool> {
145 Arc::new(Self::new(store))
146 }
147}
148
/// Arguments accepted by the `memory.lookup` tool, deserialised from the invocation JSON.
#[derive(Deserialize)]
struct LookupArgs {
    /// Query text forwarded to the store.
    query: String,
    /// Maximum number of hits requested; `default_k()` is used when the field is omitted.
    #[serde(default = "default_k")]
    k: usize,
}
155
/// Number of hits requested when a caller omits `k`.
///
/// Keep in sync with the `"default": 3` advertised in the tool's argument schema.
const fn default_k() -> usize {
    3
}
159
160#[async_trait]
161impl Tool for MemoryLookupTool {
162 fn schema(&self) -> ToolSchema {
163 ToolSchema {
164 name: Self::NAME.into(),
165 description: "Retrieve up to k similar memory episodes for a query.".into(),
166 args_schema: json!({
167 "type": "object",
168 "required": ["query"],
169 "properties": {
170 "query": {"type": "string"},
171 "k": {"type": "integer", "minimum": 1, "default": 3}
172 }
173 }),
174 result_schema: json!({
175 "type": "object",
176 "properties": {
177 "hits": {
178 "type": "array",
179 "items": {
180 "type": "object",
181 "properties": {
182 "score": {"type": "number"},
183 "summary": {"type": "string"},
184 "key": {"type": "string"},
185 "source_uri": {"type": "string"},
186 "principal": {"type": "string"},
187 "scope": {"type": "string"},
188 "recorded_at_millis": {"type": "integer"},
189 "metadata": {"type": "object"}
190 }
191 }
192 }
193 }
194 }),
195 }
196 }
197
198 fn name(&self) -> rig_compose::tool::ToolName {
199 Self::NAME.to_string()
200 }
201
202 async fn invoke(&self, args: Value) -> Result<Value, KernelError> {
203 let parsed: LookupArgs = serde_json::from_value(args)?;
204 if parsed.k == 0 {
205 return Err(KernelError::InvalidArgument(
206 "memory.lookup requires k >= 1".into(),
207 ));
208 }
209 let hits = self
210 .store
211 .lookup(&parsed.query, parsed.k)
212 .await
213 .map_err(|err| KernelError::ToolFailed(err.to_string()))?;
214 Ok(json!({ "hits": hits }))
215 }
216}
217
218#[must_use]
241pub fn memory_lookup_trace_envelope(
242 query: &str,
243 k: usize,
244 hits: &[MemoryLookupHit],
245 principal: Option<&str>,
246 scope: Option<&str>,
247) -> ResourceTraceEnvelope {
248 let mut input = json!({
249 "query": query,
250 "k": k,
251 });
252 if let Some(map) = input.as_object_mut() {
253 if let Some(principal) = principal {
254 map.insert("principal".into(), Value::String(principal.to_string()));
255 }
256 if let Some(scope) = scope {
257 map.insert("scope".into(), Value::String(scope.to_string()));
258 }
259 }
260
261 let mut output = json!({
262 "hit_count": hits.len(),
263 });
264 if let (Some(top), Some(map)) = (hits.first(), output.as_object_mut()) {
265 if let Some(score) = serde_json::Number::from_f64(top.score as f64) {
266 map.insert("top_score".into(), Value::Number(score));
267 }
268 if let Some(key) = &top.key {
269 map.insert("top_key".into(), Value::String(key.clone()));
270 }
271 }
272
273 let mut envelope = ResourceTraceEnvelope::new(TRACE_RESOURCE, TRACE_OPERATION, TRACE_KIND)
274 .with_input_summary(input)
275 .with_output_summary(output);
276
277 if hits.is_empty() {
278 envelope = envelope.with_reason(TRACE_REASON_NO_HITS);
279 }
280
281 if let Some(top) = hits.first() {
282 let mut metadata = serde_json::Map::new();
283 if let Some(source_uri) = &top.source_uri {
284 metadata.insert("source_uri".into(), Value::String(source_uri.clone()));
285 }
286 if let Some(recorded_at_millis) = top.recorded_at_millis {
287 metadata.insert(
288 "recorded_at_millis".into(),
289 Value::Number(serde_json::Number::from(recorded_at_millis)),
290 );
291 }
292 if let Some(top_principal) = &top.principal
293 && principal.is_none_or(|p| p != top_principal)
294 {
295 metadata.insert("top_principal".into(), Value::String(top_principal.clone()));
296 }
297 if let Some(top_scope) = &top.scope
298 && scope.is_none_or(|s| s != top_scope)
299 {
300 metadata.insert("top_scope".into(), Value::String(top_scope.clone()));
301 }
302 if !metadata.is_empty() {
303 envelope = envelope.with_metadata(Value::Object(metadata));
304 }
305 }
306
307 envelope
308}
309
#[cfg(test)]
mod tests {
    use super::*;

    /// Test store that always produces one canned hit echoing the query, truncated to `k`.
    struct StubMemory;

    #[async_trait]
    impl MemoryLookupStore for StubMemory {
        async fn lookup(
            &self,
            query: &str,
            k: usize,
        ) -> Result<Vec<MemoryLookupHit>, MemoryLookupError> {
            Ok(vec![
                MemoryLookupHit::new(0.9, format!("matched {query}"))
                    .with_key("ep-1")
                    .with_metadata(json!({"rank": 1})),
            ]
            .into_iter()
            .take(k)
            .collect())
        }
    }

    // The tool returns the store's hits under `"hits"` with their serialised fields.
    #[tokio::test]
    async fn lookup_tool_returns_hits() {
        let tool = MemoryLookupTool::new(Arc::new(StubMemory));
        let out = tool
            .invoke(json!({"query": "beacon", "k": 1}))
            .await
            .unwrap();
        // `score` is an f32 widened to f64 in JSON, so compare with a tolerance.
        let score = out["hits"][0]["score"].as_f64().unwrap();
        assert!((score - 0.9).abs() < 1e-6);
        assert_eq!(out["hits"][0]["key"], "ep-1");
    }

    // Every optional field set through the builders appears in the serialised hit.
    #[test]
    fn lookup_hit_serializes_shared_metadata() {
        let hit = MemoryLookupHit::new(0.75, "matched episode")
            .with_key("ep-7")
            .with_source_uri("memory://episode/7")
            .with_principal("alice")
            .with_scope("workspace")
            .with_recorded_at_millis(1_700_000_000_000);

        let json = serde_json::to_value(hit).unwrap();

        assert_eq!(json["key"], "ep-7");
        assert_eq!(json["source_uri"], "memory://episode/7");
        assert_eq!(json["principal"], "alice");
        assert_eq!(json["scope"], "workspace");
        assert_eq!(json["recorded_at_millis"], 1_700_000_000_000_i64);
    }

    // `k = 0` is rejected with `InvalidArgument` before the store is queried.
    #[tokio::test]
    async fn lookup_tool_rejects_zero_k() {
        let tool = MemoryLookupTool::new(Arc::new(StubMemory));
        let err = tool
            .invoke(json!({"query": "beacon", "k": 0}))
            .await
            .unwrap_err();
        assert!(matches!(err, KernelError::InvalidArgument(_)));
    }

    // A non-empty lookup fills the input/output summaries and the top-hit metadata.
    #[test]
    fn trace_envelope_summarises_hits_and_metadata() {
        let hits = vec![
            MemoryLookupHit::new(0.91, "top hit")
                .with_key("ep-1")
                .with_source_uri("memory://ep/1")
                .with_recorded_at_millis(1_700_000_000_000)
                .with_principal("alice")
                .with_scope("workspace"),
            MemoryLookupHit::new(0.42, "runner up").with_key("ep-2"),
        ];

        let envelope =
            memory_lookup_trace_envelope("beacon", 3, &hits, Some("alice"), Some("workspace"));

        assert_eq!(envelope.version, ResourceTraceEnvelope::VERSION);
        assert_eq!(envelope.resource, "memory");
        assert_eq!(envelope.operation, "lookup");
        assert_eq!(envelope.trace_kind, "memory_lookup");
        assert_eq!(envelope.input_summary["query"], "beacon");
        assert_eq!(envelope.input_summary["k"], 3);
        assert_eq!(envelope.input_summary["principal"], "alice");
        assert_eq!(envelope.input_summary["scope"], "workspace");
        assert_eq!(envelope.output_summary["hit_count"], 2);
        let top_score = envelope.output_summary["top_score"].as_f64().unwrap();
        assert!((top_score - 0.91).abs() < 1e-6);
        assert_eq!(envelope.output_summary["top_key"], "ep-1");
        assert!(envelope.reason.is_none());
        assert_eq!(envelope.metadata["source_uri"], "memory://ep/1");
        assert_eq!(
            envelope.metadata["recorded_at_millis"],
            1_700_000_000_000_i64
        );
        // The top hit's principal/scope match the caller's, so they are not repeated in
        // metadata.
        assert!(envelope.metadata.get("top_principal").is_none());
        assert!(envelope.metadata.get("top_scope").is_none());
    }

    // An empty result sets the `no_hits` reason and leaves metadata null.
    #[test]
    fn trace_envelope_emits_no_hits_reason_when_empty() {
        let envelope = memory_lookup_trace_envelope("nothing", 5, &[], None, None);
        assert_eq!(envelope.output_summary["hit_count"], 0);
        assert!(envelope.output_summary.get("top_score").is_none());
        assert_eq!(envelope.reason.as_deref(), Some(TRACE_REASON_NO_HITS));
        assert!(envelope.metadata.is_null());
        assert!(envelope.input_summary.get("principal").is_none());
    }

    // A top hit whose principal/scope differ from the caller's is recorded in metadata.
    #[test]
    fn trace_envelope_records_mismatched_top_principal_scope() {
        let hits = vec![
            MemoryLookupHit::new(0.5, "cross-tenant")
                .with_key("ep-9")
                .with_principal("bob")
                .with_scope("other"),
        ];

        let envelope =
            memory_lookup_trace_envelope("q", 1, &hits, Some("alice"), Some("workspace"));

        assert_eq!(envelope.metadata["top_principal"], "bob");
        assert_eq!(envelope.metadata["top_scope"], "other");
    }
}