1use std::sync::Arc;
10
11use async_trait::async_trait;
12use serde::{Deserialize, Serialize};
13use serde_json::{Value, json};
14use thiserror::Error;
15
16use rig_compose::{KernelError, Tool, ToolSchema};
17
18#[derive(Debug, Error)]
20pub enum MemoryLookupError {
21 #[error("memory lookup backend error: {0}")]
23 Backend(String),
24}
25
26#[derive(Debug, Clone, Serialize, Deserialize)]
28pub struct MemoryLookupHit {
29 pub score: f32,
31 pub summary: String,
33 #[serde(skip_serializing_if = "Option::is_none")]
35 pub key: Option<String>,
36 #[serde(default, skip_serializing_if = "Option::is_none")]
38 pub source_uri: Option<String>,
39 #[serde(default, skip_serializing_if = "Option::is_none")]
41 pub principal: Option<String>,
42 #[serde(default, skip_serializing_if = "Option::is_none")]
45 pub scope: Option<String>,
46 #[serde(default, skip_serializing_if = "Option::is_none")]
48 pub recorded_at_millis: Option<i64>,
49 #[serde(default, skip_serializing_if = "Value::is_null")]
51 pub metadata: Value,
52}
53
54impl MemoryLookupHit {
55 pub fn new(score: f32, summary: impl Into<String>) -> Self {
57 Self {
58 score,
59 summary: summary.into(),
60 key: None,
61 source_uri: None,
62 principal: None,
63 scope: None,
64 recorded_at_millis: None,
65 metadata: Value::Null,
66 }
67 }
68
69 pub fn with_key(mut self, key: impl Into<String>) -> Self {
71 self.key = Some(key.into());
72 self
73 }
74
75 pub fn with_source_uri(mut self, source_uri: impl Into<String>) -> Self {
77 self.source_uri = Some(source_uri.into());
78 self
79 }
80
81 pub fn with_principal(mut self, principal: impl Into<String>) -> Self {
83 self.principal = Some(principal.into());
84 self
85 }
86
87 pub fn with_scope(mut self, scope: impl Into<String>) -> Self {
89 self.scope = Some(scope.into());
90 self
91 }
92
93 pub fn with_recorded_at_millis(mut self, recorded_at_millis: i64) -> Self {
95 self.recorded_at_millis = Some(recorded_at_millis);
96 self
97 }
98
99 pub fn with_metadata(mut self, metadata: Value) -> Self {
101 self.metadata = metadata;
102 self
103 }
104}
105
106#[async_trait]
108pub trait MemoryLookupStore: Send + Sync {
109 async fn lookup(
111 &self,
112 query: &str,
113 k: usize,
114 ) -> Result<Vec<MemoryLookupHit>, MemoryLookupError>;
115}
116
117pub struct MemoryLookupTool {
119 store: Arc<dyn MemoryLookupStore>,
120}
121
122impl MemoryLookupTool {
123 pub const NAME: &'static str = "memory.lookup";
125
126 pub fn new(store: Arc<dyn MemoryLookupStore>) -> Self {
128 Self { store }
129 }
130
131 pub fn arc(store: Arc<dyn MemoryLookupStore>) -> Arc<dyn Tool> {
133 Arc::new(Self::new(store))
134 }
135}
136
137#[derive(Deserialize)]
138struct LookupArgs {
139 query: String,
140 #[serde(default = "default_k")]
141 k: usize,
142}
143
/// Number of hits requested when the caller omits `k`.
fn default_k() -> usize {
    const DEFAULT_HITS: usize = 3;
    DEFAULT_HITS
}
147
148#[async_trait]
149impl Tool for MemoryLookupTool {
150 fn schema(&self) -> ToolSchema {
151 ToolSchema {
152 name: Self::NAME.into(),
153 description: "Retrieve up to k similar memory episodes for a query.".into(),
154 args_schema: json!({
155 "type": "object",
156 "required": ["query"],
157 "properties": {
158 "query": {"type": "string"},
159 "k": {"type": "integer", "minimum": 1, "default": 3}
160 }
161 }),
162 result_schema: json!({
163 "type": "object",
164 "properties": {
165 "hits": {
166 "type": "array",
167 "items": {
168 "type": "object",
169 "properties": {
170 "score": {"type": "number"},
171 "summary": {"type": "string"},
172 "key": {"type": "string"},
173 "source_uri": {"type": "string"},
174 "principal": {"type": "string"},
175 "scope": {"type": "string"},
176 "recorded_at_millis": {"type": "integer"},
177 "metadata": {"type": "object"}
178 }
179 }
180 }
181 }
182 }),
183 }
184 }
185
186 fn name(&self) -> rig_compose::tool::ToolName {
187 Self::NAME.to_string()
188 }
189
190 async fn invoke(&self, args: Value) -> Result<Value, KernelError> {
191 let parsed: LookupArgs = serde_json::from_value(args)?;
192 if parsed.k == 0 {
193 return Err(KernelError::InvalidArgument(
194 "memory.lookup requires k >= 1".into(),
195 ));
196 }
197 let hits = self
198 .store
199 .lookup(&parsed.query, parsed.k)
200 .await
201 .map_err(|err| KernelError::ToolFailed(err.to_string()))?;
202 Ok(json!({ "hits": hits }))
203 }
204}
205
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicUsize, Ordering};

    use super::*;

    /// Store that yields a single canned hit echoing the query.
    struct StubMemory;

    #[async_trait]
    impl MemoryLookupStore for StubMemory {
        async fn lookup(
            &self,
            query: &str,
            k: usize,
        ) -> Result<Vec<MemoryLookupHit>, MemoryLookupError> {
            Ok(vec![
                MemoryLookupHit::new(0.9, format!("matched {query}"))
                    .with_key("ep-1")
                    .with_metadata(json!({"rank": 1})),
            ]
            .into_iter()
            .take(k)
            .collect())
        }
    }

    /// Store that records the `k` it was asked for and returns no hits.
    struct RecordingMemory {
        seen_k: AtomicUsize,
    }

    #[async_trait]
    impl MemoryLookupStore for RecordingMemory {
        async fn lookup(
            &self,
            _query: &str,
            k: usize,
        ) -> Result<Vec<MemoryLookupHit>, MemoryLookupError> {
            self.seen_k.store(k, Ordering::SeqCst);
            Ok(Vec::new())
        }
    }

    #[tokio::test]
    async fn lookup_tool_returns_hits() {
        let tool = MemoryLookupTool::new(Arc::new(StubMemory));
        let out = tool
            .invoke(json!({"query": "beacon", "k": 1}))
            .await
            .unwrap();
        let score = out["hits"][0]["score"].as_f64().unwrap();
        assert!((score - 0.9).abs() < 1e-6);
        assert_eq!(out["hits"][0]["key"], "ep-1");
    }

    #[tokio::test]
    async fn lookup_tool_defaults_k_when_omitted() {
        let store = Arc::new(RecordingMemory {
            seen_k: AtomicUsize::new(0),
        });
        let tool = MemoryLookupTool::new(Arc::clone(&store));
        let out = tool.invoke(json!({"query": "beacon"})).await.unwrap();
        assert_eq!(store.seen_k.load(Ordering::SeqCst), default_k());
        assert_eq!(out["hits"], json!([]));
    }

    #[test]
    fn lookup_hit_serializes_shared_metadata() {
        let hit = MemoryLookupHit::new(0.75, "matched episode")
            .with_key("ep-7")
            .with_source_uri("memory://episode/7")
            .with_principal("alice")
            .with_scope("workspace")
            .with_recorded_at_millis(1_700_000_000_000);

        let json = serde_json::to_value(hit).unwrap();

        assert_eq!(json["key"], "ep-7");
        assert_eq!(json["source_uri"], "memory://episode/7");
        assert_eq!(json["principal"], "alice");
        assert_eq!(json["scope"], "workspace");
        assert_eq!(json["recorded_at_millis"], 1_700_000_000_000_i64);
    }

    #[test]
    fn lookup_hit_omits_unset_fields() {
        let json = serde_json::to_value(MemoryLookupHit::new(0.5, "bare")).unwrap();
        let fields = json.as_object().unwrap();
        // Only the two mandatory fields survive; `None`s and `Null` metadata are skipped.
        assert_eq!(fields.len(), 2);
        assert!(fields.contains_key("score"));
        assert_eq!(fields["summary"], "bare");
    }

    #[test]
    fn lookup_hit_deserializes_minimal_payload() {
        let hit: MemoryLookupHit =
            serde_json::from_value(json!({"score": 0.5, "summary": "bare"})).unwrap();
        assert_eq!(hit.summary, "bare");
        assert!(hit.key.is_none());
        assert!(hit.recorded_at_millis.is_none());
        assert!(hit.metadata.is_null());
    }

    #[tokio::test]
    async fn lookup_tool_rejects_zero_k() {
        let tool = MemoryLookupTool::new(Arc::new(StubMemory));
        let err = tool
            .invoke(json!({"query": "beacon", "k": 0}))
            .await
            .unwrap_err();
        assert!(matches!(err, KernelError::InvalidArgument(_)));
    }
}