Skip to main content

rig_resources/
memory.rs

1//! Memory lookup tool contract.
2//!
3//! [`MemoryPivotSkill`](crate::MemoryPivotSkill) calls a tool named
4//! `memory.lookup`. This module supplies the canonical tool and a small
5//! backend trait so stores such as `rig-memvid`, test fakes, or
6//! application-specific episode stores can expose the same lookup shape
7//! without depending on each other.
8
9use std::sync::Arc;
10
11use async_trait::async_trait;
12use serde::{Deserialize, Serialize};
13use serde_json::{Value, json};
14use thiserror::Error;
15
16use rig_compose::{KernelError, Tool, ToolSchema};
17
18/// Error returned by a [`MemoryLookupStore`].
19#[derive(Debug, Error)]
20pub enum MemoryLookupError {
21    /// The backing memory store failed.
22    #[error("memory lookup backend error: {0}")]
23    Backend(String),
24}
25
26/// One hit returned by a [`MemoryLookupStore`].
27#[derive(Debug, Clone, Serialize, Deserialize)]
28pub struct MemoryLookupHit {
29    /// Retrieval score in `[0, 1]`; higher is more similar.
30    pub score: f32,
31    /// Short text summary suitable for evidence display.
32    pub summary: String,
33    /// Optional stable store key, frame id, or episode id.
34    #[serde(skip_serializing_if = "Option::is_none")]
35    pub key: Option<String>,
36    /// Optional store-specific metadata.
37    #[serde(default, skip_serializing_if = "Value::is_null")]
38    pub metadata: Value,
39}
40
41impl MemoryLookupHit {
42    /// Create a hit with no key or metadata.
43    pub fn new(score: f32, summary: impl Into<String>) -> Self {
44        Self {
45            score,
46            summary: summary.into(),
47            key: None,
48            metadata: Value::Null,
49        }
50    }
51
52    /// Attach a stable storage key.
53    pub fn with_key(mut self, key: impl Into<String>) -> Self {
54        self.key = Some(key.into());
55        self
56    }
57
58    /// Attach store-specific metadata.
59    pub fn with_metadata(mut self, metadata: Value) -> Self {
60        self.metadata = metadata;
61        self
62    }
63}
64
65/// Backend contract for the canonical `memory.lookup` tool.
66#[async_trait]
67pub trait MemoryLookupStore: Send + Sync {
68    /// Return up to `k` hits most relevant to `query`.
69    async fn lookup(
70        &self,
71        query: &str,
72        k: usize,
73    ) -> Result<Vec<MemoryLookupHit>, MemoryLookupError>;
74}
75
76/// `memory.lookup` — reusable kernel tool for semantic or lexical memory pivots.
77pub struct MemoryLookupTool {
78    store: Arc<dyn MemoryLookupStore>,
79}
80
81impl MemoryLookupTool {
82    /// Stable tool name consumed by [`crate::MemoryPivotSkill`].
83    pub const NAME: &'static str = "memory.lookup";
84
85    /// Create a lookup tool backed by `store`.
86    pub fn new(store: Arc<dyn MemoryLookupStore>) -> Self {
87        Self { store }
88    }
89
90    /// Create the tool behind an [`Arc`] for registration in a `ToolRegistry`.
91    pub fn arc(store: Arc<dyn MemoryLookupStore>) -> Arc<dyn Tool> {
92        Arc::new(Self::new(store))
93    }
94}
95
96#[derive(Deserialize)]
97struct LookupArgs {
98    query: String,
99    #[serde(default = "default_k")]
100    k: usize,
101}
102
/// Number of hits requested when the caller omits `k`.
fn default_k() -> usize {
    const DEFAULT_HITS: usize = 3;
    DEFAULT_HITS
}
106
107#[async_trait]
108impl Tool for MemoryLookupTool {
109    fn schema(&self) -> ToolSchema {
110        ToolSchema {
111            name: Self::NAME.into(),
112            description: "Retrieve up to k similar memory episodes for a query.".into(),
113            args_schema: json!({
114                "type": "object",
115                "required": ["query"],
116                "properties": {
117                    "query": {"type": "string"},
118                    "k": {"type": "integer", "minimum": 1, "default": 3}
119                }
120            }),
121            result_schema: json!({
122                "type": "object",
123                "properties": {
124                    "hits": {
125                        "type": "array",
126                        "items": {
127                            "type": "object",
128                            "properties": {
129                                "score": {"type": "number"},
130                                "summary": {"type": "string"},
131                                "key": {"type": "string"},
132                                "metadata": {"type": "object"}
133                            }
134                        }
135                    }
136                }
137            }),
138        }
139    }
140
141    fn name(&self) -> rig_compose::tool::ToolName {
142        Self::NAME.to_string()
143    }
144
145    async fn invoke(&self, args: Value) -> Result<Value, KernelError> {
146        let parsed: LookupArgs = serde_json::from_value(args)?;
147        if parsed.k == 0 {
148            return Err(KernelError::InvalidArgument(
149                "memory.lookup requires k >= 1".into(),
150            ));
151        }
152        let hits = self
153            .store
154            .lookup(&parsed.query, parsed.k)
155            .await
156            .map_err(|err| KernelError::ToolFailed(err.to_string()))?;
157        Ok(json!({ "hits": hits }))
158    }
159}
160
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns a single fixed hit, truncated to `k`.
    struct StubMemory;

    #[async_trait]
    impl MemoryLookupStore for StubMemory {
        async fn lookup(
            &self,
            query: &str,
            k: usize,
        ) -> Result<Vec<MemoryLookupHit>, MemoryLookupError> {
            Ok(vec![
                MemoryLookupHit::new(0.9, format!("matched {query}"))
                    .with_key("ep-1")
                    .with_metadata(json!({"rank": 1})),
            ]
            .into_iter()
            .take(k)
            .collect())
        }
    }

    /// Echoes the `k` it received so tests can observe argument defaulting.
    struct EchoKMemory;

    #[async_trait]
    impl MemoryLookupStore for EchoKMemory {
        async fn lookup(
            &self,
            _query: &str,
            k: usize,
        ) -> Result<Vec<MemoryLookupHit>, MemoryLookupError> {
            Ok(vec![MemoryLookupHit::new(1.0, format!("k={k}"))])
        }
    }

    /// Always fails, to exercise backend-error mapping.
    struct FailingMemory;

    #[async_trait]
    impl MemoryLookupStore for FailingMemory {
        async fn lookup(
            &self,
            _query: &str,
            _k: usize,
        ) -> Result<Vec<MemoryLookupHit>, MemoryLookupError> {
            Err(MemoryLookupError::Backend("offline".into()))
        }
    }

    #[tokio::test]
    async fn lookup_tool_returns_hits() {
        let tool = MemoryLookupTool::new(Arc::new(StubMemory));
        let out = tool
            .invoke(json!({"query": "beacon", "k": 1}))
            .await
            .unwrap();
        let score = out["hits"][0]["score"].as_f64().unwrap();
        assert!((score - 0.9).abs() < 1e-6);
        assert_eq!(out["hits"][0]["key"], "ep-1");
    }

    #[tokio::test]
    async fn lookup_tool_rejects_zero_k() {
        let tool = MemoryLookupTool::new(Arc::new(StubMemory));
        let err = tool
            .invoke(json!({"query": "beacon", "k": 0}))
            .await
            .unwrap_err();
        assert!(matches!(err, KernelError::InvalidArgument(_)));
    }

    #[tokio::test]
    async fn lookup_tool_defaults_k_to_three() {
        let tool = MemoryLookupTool::new(Arc::new(EchoKMemory));
        let out = tool.invoke(json!({"query": "beacon"})).await.unwrap();
        assert_eq!(out["hits"][0]["summary"], "k=3");
    }

    #[tokio::test]
    async fn lookup_tool_maps_backend_errors() {
        let tool = MemoryLookupTool::new(Arc::new(FailingMemory));
        let err = tool
            .invoke(json!({"query": "beacon"}))
            .await
            .unwrap_err();
        assert!(matches!(
            err,
            KernelError::ToolFailed(ref msg) if msg.contains("offline")
        ));
    }

    #[tokio::test]
    async fn lookup_tool_rejects_missing_query() {
        let tool = MemoryLookupTool::new(Arc::new(StubMemory));
        assert!(tool.invoke(json!({"k": 1})).await.is_err());
    }

    #[test]
    fn hit_omits_unset_optional_fields() {
        let value = serde_json::to_value(MemoryLookupHit::new(0.5, "plain")).unwrap();
        assert_eq!(value["summary"], "plain");
        assert!(value.get("key").is_none());
        assert!(value.get("metadata").is_none());
    }
}