Skip to main content

rig_resources/
memory.rs

1//! Memory lookup tool contract.
2//!
3//! [`MemoryPivotSkill`](crate::MemoryPivotSkill) calls a tool named
4//! `memory.lookup`. This module supplies the canonical tool and a small
5//! backend trait so stores such as `rig-memvid`, test fakes, or
6//! application-specific episode stores can expose the same lookup shape
7//! without depending on each other.
8
9use std::sync::Arc;
10
11use async_trait::async_trait;
12use serde::{Deserialize, Serialize};
13use serde_json::{Value, json};
14use thiserror::Error;
15
16use rig_compose::{KernelError, Tool, ToolSchema};
17
18/// Error returned by a [`MemoryLookupStore`].
19#[derive(Debug, Error)]
20pub enum MemoryLookupError {
21    /// The backing memory store failed.
22    #[error("memory lookup backend error: {0}")]
23    Backend(String),
24}
25
26/// One hit returned by a [`MemoryLookupStore`].
27#[derive(Debug, Clone, Serialize, Deserialize)]
28pub struct MemoryLookupHit {
29    /// Retrieval score in `[0, 1]`; higher is more similar.
30    pub score: f32,
31    /// Short text summary suitable for evidence display.
32    pub summary: String,
33    /// Optional stable store key, frame id, or episode id.
34    #[serde(skip_serializing_if = "Option::is_none")]
35    pub key: Option<String>,
36    /// Optional URI or locator for the backing memory source.
37    #[serde(default, skip_serializing_if = "Option::is_none")]
38    pub source_uri: Option<String>,
39    /// Optional principal, actor, tenant, or subject associated with the hit.
40    #[serde(default, skip_serializing_if = "Option::is_none")]
41    pub principal: Option<String>,
42    /// Optional caller-defined lookup scope such as tenant, workspace, or
43    /// profile.
44    #[serde(default, skip_serializing_if = "Option::is_none")]
45    pub scope: Option<String>,
46    /// Optional milliseconds since the Unix epoch when the source was recorded.
47    #[serde(default, skip_serializing_if = "Option::is_none")]
48    pub recorded_at_millis: Option<i64>,
49    /// Optional store-specific metadata.
50    #[serde(default, skip_serializing_if = "Value::is_null")]
51    pub metadata: Value,
52}
53
54impl MemoryLookupHit {
55    /// Create a hit with no key or metadata.
56    pub fn new(score: f32, summary: impl Into<String>) -> Self {
57        Self {
58            score,
59            summary: summary.into(),
60            key: None,
61            source_uri: None,
62            principal: None,
63            scope: None,
64            recorded_at_millis: None,
65            metadata: Value::Null,
66        }
67    }
68
69    /// Attach a stable storage key.
70    pub fn with_key(mut self, key: impl Into<String>) -> Self {
71        self.key = Some(key.into());
72        self
73    }
74
75    /// Attach a source URI or locator.
76    pub fn with_source_uri(mut self, source_uri: impl Into<String>) -> Self {
77        self.source_uri = Some(source_uri.into());
78        self
79    }
80
81    /// Attach the principal, actor, tenant, or subject associated with the hit.
82    pub fn with_principal(mut self, principal: impl Into<String>) -> Self {
83        self.principal = Some(principal.into());
84        self
85    }
86
87    /// Attach the caller-defined lookup scope.
88    pub fn with_scope(mut self, scope: impl Into<String>) -> Self {
89        self.scope = Some(scope.into());
90        self
91    }
92
93    /// Attach the source record timestamp in milliseconds since the Unix epoch.
94    pub fn with_recorded_at_millis(mut self, recorded_at_millis: i64) -> Self {
95        self.recorded_at_millis = Some(recorded_at_millis);
96        self
97    }
98
99    /// Attach store-specific metadata.
100    pub fn with_metadata(mut self, metadata: Value) -> Self {
101        self.metadata = metadata;
102        self
103    }
104}
105
106/// Backend contract for the canonical `memory.lookup` tool.
107#[async_trait]
108pub trait MemoryLookupStore: Send + Sync {
109    /// Return up to `k` hits most relevant to `query`.
110    async fn lookup(
111        &self,
112        query: &str,
113        k: usize,
114    ) -> Result<Vec<MemoryLookupHit>, MemoryLookupError>;
115}
116
117/// `memory.lookup` — reusable kernel tool for semantic or lexical memory pivots.
118pub struct MemoryLookupTool {
119    store: Arc<dyn MemoryLookupStore>,
120}
121
122impl MemoryLookupTool {
123    /// Stable tool name consumed by [`crate::MemoryPivotSkill`].
124    pub const NAME: &'static str = "memory.lookup";
125
126    /// Create a lookup tool backed by `store`.
127    pub fn new(store: Arc<dyn MemoryLookupStore>) -> Self {
128        Self { store }
129    }
130
131    /// Create the tool behind an [`Arc`] for registration in a `ToolRegistry`.
132    pub fn arc(store: Arc<dyn MemoryLookupStore>) -> Arc<dyn Tool> {
133        Arc::new(Self::new(store))
134    }
135}
136
137#[derive(Deserialize)]
138struct LookupArgs {
139    query: String,
140    #[serde(default = "default_k")]
141    k: usize,
142}
143
144fn default_k() -> usize {
145    3
146}
147
148#[async_trait]
149impl Tool for MemoryLookupTool {
150    fn schema(&self) -> ToolSchema {
151        ToolSchema {
152            name: Self::NAME.into(),
153            description: "Retrieve up to k similar memory episodes for a query.".into(),
154            args_schema: json!({
155                "type": "object",
156                "required": ["query"],
157                "properties": {
158                    "query": {"type": "string"},
159                    "k": {"type": "integer", "minimum": 1, "default": 3}
160                }
161            }),
162            result_schema: json!({
163                "type": "object",
164                "properties": {
165                    "hits": {
166                        "type": "array",
167                        "items": {
168                            "type": "object",
169                            "properties": {
170                                "score": {"type": "number"},
171                                "summary": {"type": "string"},
172                                "key": {"type": "string"},
173                                "source_uri": {"type": "string"},
174                                "principal": {"type": "string"},
175                                "scope": {"type": "string"},
176                                "recorded_at_millis": {"type": "integer"},
177                                "metadata": {"type": "object"}
178                            }
179                        }
180                    }
181                }
182            }),
183        }
184    }
185
186    fn name(&self) -> rig_compose::tool::ToolName {
187        Self::NAME.to_string()
188    }
189
190    async fn invoke(&self, args: Value) -> Result<Value, KernelError> {
191        let parsed: LookupArgs = serde_json::from_value(args)?;
192        if parsed.k == 0 {
193            return Err(KernelError::InvalidArgument(
194                "memory.lookup requires k >= 1".into(),
195            ));
196        }
197        let hits = self
198            .store
199            .lookup(&parsed.query, parsed.k)
200            .await
201            .map_err(|err| KernelError::ToolFailed(err.to_string()))?;
202        Ok(json!({ "hits": hits }))
203    }
204}
205
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Store returning a single fixed hit, capped at `k`.
    struct StubMemory;

    #[async_trait]
    impl MemoryLookupStore for StubMemory {
        async fn lookup(
            &self,
            query: &str,
            k: usize,
        ) -> Result<Vec<MemoryLookupHit>, MemoryLookupError> {
            Ok(vec![
                MemoryLookupHit::new(0.9, format!("matched {query}"))
                    .with_key("ep-1")
                    .with_metadata(json!({"rank": 1})),
            ]
            .into_iter()
            .take(k)
            .collect())
        }
    }

    /// Store that records the `k` it was asked for and returns no hits.
    #[derive(Default)]
    struct RecordingMemory {
        last_k: AtomicUsize,
    }

    #[async_trait]
    impl MemoryLookupStore for RecordingMemory {
        async fn lookup(
            &self,
            _query: &str,
            k: usize,
        ) -> Result<Vec<MemoryLookupHit>, MemoryLookupError> {
            self.last_k.store(k, Ordering::SeqCst);
            Ok(Vec::new())
        }
    }

    #[tokio::test]
    async fn lookup_tool_returns_hits() {
        let tool = MemoryLookupTool::new(Arc::new(StubMemory));
        let out = tool
            .invoke(json!({"query": "beacon", "k": 1}))
            .await
            .unwrap();
        let score = out["hits"][0]["score"].as_f64().unwrap();
        assert!((score - 0.9).abs() < 1e-6);
        assert_eq!(out["hits"][0]["key"], "ep-1");
    }

    #[tokio::test]
    async fn lookup_tool_defaults_k_to_three() {
        let store = Arc::new(RecordingMemory::default());
        let tool = MemoryLookupTool::new(store.clone());
        let out = tool.invoke(json!({"query": "beacon"})).await.unwrap();
        assert_eq!(store.last_k.load(Ordering::SeqCst), 3);
        assert!(out["hits"].as_array().unwrap().is_empty());
    }

    #[tokio::test]
    async fn lookup_tool_rejects_missing_query() {
        let tool = MemoryLookupTool::new(Arc::new(StubMemory));
        assert!(tool.invoke(json!({"k": 1})).await.is_err());
    }

    #[test]
    fn lookup_hit_serializes_shared_metadata() {
        let hit = MemoryLookupHit::new(0.75, "matched episode")
            .with_key("ep-7")
            .with_source_uri("memory://episode/7")
            .with_principal("alice")
            .with_scope("workspace")
            .with_recorded_at_millis(1_700_000_000_000);

        let json = serde_json::to_value(hit).unwrap();

        assert_eq!(json["key"], "ep-7");
        assert_eq!(json["source_uri"], "memory://episode/7");
        assert_eq!(json["principal"], "alice");
        assert_eq!(json["scope"], "workspace");
        assert_eq!(json["recorded_at_millis"], 1_700_000_000_000_i64);
    }

    #[test]
    fn lookup_hit_omits_unset_optional_fields() {
        let json = serde_json::to_value(MemoryLookupHit::new(0.5, "bare")).unwrap();
        let mut keys: Vec<&str> = json
            .as_object()
            .unwrap()
            .keys()
            .map(String::as_str)
            .collect();
        keys.sort_unstable();
        assert_eq!(keys, ["score", "summary"]);
    }

    #[test]
    fn lookup_hit_deserializes_with_only_required_fields() {
        let hit: MemoryLookupHit =
            serde_json::from_value(json!({"score": 0.5, "summary": "bare"})).unwrap();
        assert_eq!(hit.summary, "bare");
        assert!(hit.key.is_none());
        assert!(hit.source_uri.is_none());
        assert!(hit.recorded_at_millis.is_none());
        assert!(hit.metadata.is_null());
    }

    #[tokio::test]
    async fn lookup_tool_rejects_zero_k() {
        let tool = MemoryLookupTool::new(Arc::new(StubMemory));
        let err = tool
            .invoke(json!({"query": "beacon", "k": 0}))
            .await
            .unwrap_err();
        assert!(matches!(err, KernelError::InvalidArgument(_)));
    }
}