1use std::sync::Arc;
10
11use async_trait::async_trait;
12use serde::{Deserialize, Serialize};
13use serde_json::{Value, json};
14use thiserror::Error;
15
16use rig_compose::{KernelError, Tool, ToolSchema};
17
/// Error reported by a [`MemoryLookupStore`] when a lookup cannot be served.
///
/// `MemoryLookupTool::invoke` converts this into `KernelError::ToolFailed`
/// using its `Display` text.
#[derive(Debug, Error)]
pub enum MemoryLookupError {
    /// Failure inside the storage backend. The payload is embedded verbatim
    /// in the `Display` message ("memory lookup backend error: …").
    #[error("memory lookup backend error: {0}")]
    Backend(String),
}
25
/// One result produced by a [`MemoryLookupStore`] lookup.
///
/// Instances are serialized unchanged into the `hits` array returned by the
/// `memory.lookup` tool. Field order and serde attributes define that wire
/// format, so keep them stable.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryLookupHit {
    /// Similarity score assigned by the store.
    /// NOTE(review): range and ordering (higher = more similar?) are
    /// store-defined — confirm with implementations.
    pub score: f32,
    /// Text summary of the matched memory episode.
    pub summary: String,
    /// Optional key of the matched entry (e.g. an episode id). It is omitted
    /// from the serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub key: Option<String>,
    /// Arbitrary JSON attached by the store. It defaults to `null` when absent
    /// during deserialization and is omitted from the serialized output when `null`.
    #[serde(default, skip_serializing_if = "Value::is_null")]
    pub metadata: Value,
}
40
41impl MemoryLookupHit {
42 pub fn new(score: f32, summary: impl Into<String>) -> Self {
44 Self {
45 score,
46 summary: summary.into(),
47 key: None,
48 metadata: Value::Null,
49 }
50 }
51
52 pub fn with_key(mut self, key: impl Into<String>) -> Self {
54 self.key = Some(key.into());
55 self
56 }
57
58 pub fn with_metadata(mut self, metadata: Value) -> Self {
60 self.metadata = metadata;
61 self
62 }
63}
64
/// Backend queried by [`MemoryLookupTool`] to find memory entries similar to a query.
///
/// The tool holds implementations as `Arc<dyn MemoryLookupStore>`, hence the
/// `Send + Sync` bound.
#[async_trait]
pub trait MemoryLookupStore: Send + Sync {
    /// Looks up entries similar to `query`, requesting `k` hits.
    ///
    /// `MemoryLookupTool` only calls this with `k >= 1`.
    /// NOTE(review): implementations are expected to return at most `k` hits —
    /// the trait itself does not enforce this.
    ///
    /// # Errors
    /// Returns [`MemoryLookupError`] when the backend fails.
    async fn lookup(
        &self,
        query: &str,
        k: usize,
    ) -> Result<Vec<MemoryLookupHit>, MemoryLookupError>;
}
75
/// [`Tool`] implementation named `memory.lookup` that forwards queries to a
/// [`MemoryLookupStore`] and returns its hits as JSON.
pub struct MemoryLookupTool {
    /// Shared store queried on every invocation.
    store: Arc<dyn MemoryLookupStore>,
}
80
81impl MemoryLookupTool {
82 pub const NAME: &'static str = "memory.lookup";
84
85 pub fn new(store: Arc<dyn MemoryLookupStore>) -> Self {
87 Self { store }
88 }
89
90 pub fn arc(store: Arc<dyn MemoryLookupStore>) -> Arc<dyn Tool> {
92 Arc::new(Self::new(store))
93 }
94}
95
/// Arguments of a `memory.lookup` call, deserialized from the tool-call JSON.
#[derive(Deserialize)]
struct LookupArgs {
    /// Query string forwarded to the store.
    query: String,
    /// Requested number of hits. It falls back to `default_k()` when omitted;
    /// `invoke` rejects `0`.
    #[serde(default = "default_k")]
    k: usize,
}
102
/// Number of hits requested when the caller omits `k`.
fn default_k() -> usize {
    // Must agree with the `"default": 3` advertised for `k` in the args schema.
    const DEFAULT_HITS: usize = 3;
    DEFAULT_HITS
}
106
107#[async_trait]
108impl Tool for MemoryLookupTool {
109 fn schema(&self) -> ToolSchema {
110 ToolSchema {
111 name: Self::NAME.into(),
112 description: "Retrieve up to k similar memory episodes for a query.".into(),
113 args_schema: json!({
114 "type": "object",
115 "required": ["query"],
116 "properties": {
117 "query": {"type": "string"},
118 "k": {"type": "integer", "minimum": 1, "default": 3}
119 }
120 }),
121 result_schema: json!({
122 "type": "object",
123 "properties": {
124 "hits": {
125 "type": "array",
126 "items": {
127 "type": "object",
128 "properties": {
129 "score": {"type": "number"},
130 "summary": {"type": "string"},
131 "key": {"type": "string"},
132 "metadata": {"type": "object"}
133 }
134 }
135 }
136 }
137 }),
138 }
139 }
140
141 fn name(&self) -> rig_compose::tool::ToolName {
142 Self::NAME.to_string()
143 }
144
145 async fn invoke(&self, args: Value) -> Result<Value, KernelError> {
146 let parsed: LookupArgs = serde_json::from_value(args)?;
147 if parsed.k == 0 {
148 return Err(KernelError::InvalidArgument(
149 "memory.lookup requires k >= 1".into(),
150 ));
151 }
152 let hits = self
153 .store
154 .lookup(&parsed.query, parsed.k)
155 .await
156 .map_err(|err| KernelError::ToolFailed(err.to_string()))?;
157 Ok(json!({ "hits": hits }))
158 }
159}
160
#[cfg(test)]
mod tests {
    use super::*;

    /// Store stub that always yields one canned hit echoing the query,
    /// limited to `k` results.
    struct OneHitStore;

    #[async_trait]
    impl MemoryLookupStore for OneHitStore {
        async fn lookup(
            &self,
            query: &str,
            k: usize,
        ) -> Result<Vec<MemoryLookupHit>, MemoryLookupError> {
            let hit = MemoryLookupHit::new(0.9, format!("matched {query}"))
                .with_key("ep-1")
                .with_metadata(json!({"rank": 1}));
            Ok(std::iter::once(hit).take(k).collect())
        }
    }

    /// Builds the tool under test on top of the stub store.
    fn tool_under_test() -> MemoryLookupTool {
        MemoryLookupTool::new(Arc::new(OneHitStore))
    }

    #[tokio::test]
    async fn lookup_tool_returns_hits() {
        let out = tool_under_test()
            .invoke(json!({"query": "beacon", "k": 1}))
            .await
            .expect("lookup with k = 1 should succeed");
        let first = &out["hits"][0];
        let score = first["score"].as_f64().expect("score should be numeric");
        // f32 -> f64 widening of 0.9 is not exact; compare with a tolerance.
        assert!((score - 0.9).abs() < 1e-6);
        assert_eq!(first["key"], "ep-1");
    }

    #[tokio::test]
    async fn lookup_tool_rejects_zero_k() {
        let result = tool_under_test()
            .invoke(json!({"query": "beacon", "k": 0}))
            .await;
        assert!(matches!(result, Err(KernelError::InvalidArgument(_))));
    }
}