Skip to main content

rig_resources/
memory.rs

1//! Memory lookup tool contract.
2//!
3//! [`MemoryPivotSkill`](crate::MemoryPivotSkill) calls a tool named
4//! `memory.lookup`. This module supplies the canonical tool and a small
5//! backend trait so stores such as `rig-memvid`, test fakes, or
6//! application-specific episode stores can expose the same lookup shape
7//! without depending on each other.
8
9use std::sync::Arc;
10
11use async_trait::async_trait;
12use serde::{Deserialize, Serialize};
13use serde_json::{Value, json};
14use thiserror::Error;
15
16use rig_compose::{KernelError, Tool, ToolSchema};
17
18use crate::trace::ResourceTraceEnvelope;
19
/// `resource` field stamped on every memory lookup trace envelope.
const TRACE_RESOURCE: &str = "memory";
/// `operation` field stamped on every memory lookup trace envelope.
const TRACE_OPERATION: &str = "lookup";
/// `trace_kind` field stamped on every memory lookup trace envelope.
const TRACE_KIND: &str = "memory_lookup";

/// Reason code attached to the [`ResourceTraceEnvelope`] built by
/// [`memory_lookup_trace_envelope`] when a lookup returned zero hits.
pub const TRACE_REASON_NO_HITS: &str = "no_hits";
/// Reason code for traces of lookups whose backing [`MemoryLookupStore`]
/// failed.
///
/// [`memory_lookup_trace_envelope`] only describes completed lookups and
/// never sets this code itself; callers recording a failed lookup attach it
/// explicitly.
pub const TRACE_REASON_BACKEND_ERROR: &str = "backend_error";
29
30/// Error returned by a [`MemoryLookupStore`].
31#[derive(Debug, Error)]
32pub enum MemoryLookupError {
33    /// The backing memory store failed.
34    #[error("memory lookup backend error: {0}")]
35    Backend(String),
36}
37
38/// One hit returned by a [`MemoryLookupStore`].
39#[derive(Debug, Clone, Serialize, Deserialize)]
40pub struct MemoryLookupHit {
41    /// Retrieval score in `[0, 1]`; higher is more similar.
42    pub score: f32,
43    /// Short text summary suitable for evidence display.
44    pub summary: String,
45    /// Optional stable store key, frame id, or episode id.
46    #[serde(skip_serializing_if = "Option::is_none")]
47    pub key: Option<String>,
48    /// Optional URI or locator for the backing memory source.
49    #[serde(default, skip_serializing_if = "Option::is_none")]
50    pub source_uri: Option<String>,
51    /// Optional principal, actor, tenant, or subject associated with the hit.
52    #[serde(default, skip_serializing_if = "Option::is_none")]
53    pub principal: Option<String>,
54    /// Optional caller-defined lookup scope such as tenant, workspace, or
55    /// profile.
56    #[serde(default, skip_serializing_if = "Option::is_none")]
57    pub scope: Option<String>,
58    /// Optional milliseconds since the Unix epoch when the source was recorded.
59    #[serde(default, skip_serializing_if = "Option::is_none")]
60    pub recorded_at_millis: Option<i64>,
61    /// Optional store-specific metadata.
62    #[serde(default, skip_serializing_if = "Value::is_null")]
63    pub metadata: Value,
64}
65
66impl MemoryLookupHit {
67    /// Create a hit with no key or metadata.
68    pub fn new(score: f32, summary: impl Into<String>) -> Self {
69        Self {
70            score,
71            summary: summary.into(),
72            key: None,
73            source_uri: None,
74            principal: None,
75            scope: None,
76            recorded_at_millis: None,
77            metadata: Value::Null,
78        }
79    }
80
81    /// Attach a stable storage key.
82    pub fn with_key(mut self, key: impl Into<String>) -> Self {
83        self.key = Some(key.into());
84        self
85    }
86
87    /// Attach a source URI or locator.
88    pub fn with_source_uri(mut self, source_uri: impl Into<String>) -> Self {
89        self.source_uri = Some(source_uri.into());
90        self
91    }
92
93    /// Attach the principal, actor, tenant, or subject associated with the hit.
94    pub fn with_principal(mut self, principal: impl Into<String>) -> Self {
95        self.principal = Some(principal.into());
96        self
97    }
98
99    /// Attach the caller-defined lookup scope.
100    pub fn with_scope(mut self, scope: impl Into<String>) -> Self {
101        self.scope = Some(scope.into());
102        self
103    }
104
105    /// Attach the source record timestamp in milliseconds since the Unix epoch.
106    pub fn with_recorded_at_millis(mut self, recorded_at_millis: i64) -> Self {
107        self.recorded_at_millis = Some(recorded_at_millis);
108        self
109    }
110
111    /// Attach store-specific metadata.
112    pub fn with_metadata(mut self, metadata: Value) -> Self {
113        self.metadata = metadata;
114        self
115    }
116}
117
118/// Backend contract for the canonical `memory.lookup` tool.
119#[async_trait]
120pub trait MemoryLookupStore: Send + Sync {
121    /// Return up to `k` hits most relevant to `query`.
122    async fn lookup(
123        &self,
124        query: &str,
125        k: usize,
126    ) -> Result<Vec<MemoryLookupHit>, MemoryLookupError>;
127}
128
129/// `memory.lookup` — reusable kernel tool for semantic or lexical memory pivots.
130pub struct MemoryLookupTool {
131    store: Arc<dyn MemoryLookupStore>,
132}
133
134impl MemoryLookupTool {
135    /// Stable tool name consumed by [`crate::MemoryPivotSkill`].
136    pub const NAME: &'static str = "memory.lookup";
137
138    /// Create a lookup tool backed by `store`.
139    pub fn new(store: Arc<dyn MemoryLookupStore>) -> Self {
140        Self { store }
141    }
142
143    /// Create the tool behind an [`Arc`] for registration in a `ToolRegistry`.
144    pub fn arc(store: Arc<dyn MemoryLookupStore>) -> Arc<dyn Tool> {
145        Arc::new(Self::new(store))
146    }
147}
148
149#[derive(Deserialize)]
150struct LookupArgs {
151    query: String,
152    #[serde(default = "default_k")]
153    k: usize,
154}
155
/// Number of hits requested when a `memory.lookup` call omits `k`.
fn default_k() -> usize {
    const DEFAULT_K: usize = 3;
    DEFAULT_K
}
159
160#[async_trait]
161impl Tool for MemoryLookupTool {
162    fn schema(&self) -> ToolSchema {
163        ToolSchema {
164            name: Self::NAME.into(),
165            description: "Retrieve up to k similar memory episodes for a query.".into(),
166            args_schema: json!({
167                "type": "object",
168                "required": ["query"],
169                "properties": {
170                    "query": {"type": "string"},
171                    "k": {"type": "integer", "minimum": 1, "default": 3}
172                }
173            }),
174            result_schema: json!({
175                "type": "object",
176                "properties": {
177                    "hits": {
178                        "type": "array",
179                        "items": {
180                            "type": "object",
181                            "properties": {
182                                "score": {"type": "number"},
183                                "summary": {"type": "string"},
184                                "key": {"type": "string"},
185                                "source_uri": {"type": "string"},
186                                "principal": {"type": "string"},
187                                "scope": {"type": "string"},
188                                "recorded_at_millis": {"type": "integer"},
189                                "metadata": {"type": "object"}
190                            }
191                        }
192                    }
193                }
194            }),
195        }
196    }
197
198    fn name(&self) -> rig_compose::tool::ToolName {
199        Self::NAME.to_string()
200    }
201
202    async fn invoke(&self, args: Value) -> Result<Value, KernelError> {
203        let parsed: LookupArgs = serde_json::from_value(args)?;
204        if parsed.k == 0 {
205            return Err(KernelError::InvalidArgument(
206                "memory.lookup requires k >= 1".into(),
207            ));
208        }
209        let hits = self
210            .store
211            .lookup(&parsed.query, parsed.k)
212            .await
213            .map_err(|err| KernelError::ToolFailed(err.to_string()))?;
214        Ok(json!({ "hits": hits }))
215    }
216}
217
218/// Build a [`ResourceTraceEnvelope`] describing a single `memory.lookup`
219/// invocation.
220///
221/// This complements [`crate::memory_hit_to_context_item`] (the prompt-side
222/// projection) by giving observability and audit consumers a trace-side
223/// record of the query, scope, hit count, and top match. The envelope is
224/// shaped to mirror [`crate::security_finding_trace_envelope`] so the same
225/// downstream pipelines can route both kinds without bespoke shapes.
226///
227/// `principal` and `scope` are optional caller-provided context (typically
228/// the calling agent's tenant or workspace, which the store may or may not
229/// have echoed back on each hit). When `hits` is empty the envelope carries
230/// the [`TRACE_REASON_NO_HITS`] reason code.
231///
232/// ```no_run
233/// use rig_resources::{MemoryLookupHit, memory_lookup_trace_envelope};
234///
235/// let hits = vec![MemoryLookupHit::new(0.82, "matched episode").with_key("ep-7")];
236/// let envelope = memory_lookup_trace_envelope("beacon", 3, &hits, Some("alice"), None);
237/// assert_eq!(envelope.resource, "memory");
238/// assert_eq!(envelope.output_summary["hit_count"], 1);
239/// ```
240#[must_use]
241pub fn memory_lookup_trace_envelope(
242    query: &str,
243    k: usize,
244    hits: &[MemoryLookupHit],
245    principal: Option<&str>,
246    scope: Option<&str>,
247) -> ResourceTraceEnvelope {
248    let mut input = json!({
249        "query": query,
250        "k": k,
251    });
252    if let Some(map) = input.as_object_mut() {
253        if let Some(principal) = principal {
254            map.insert("principal".into(), Value::String(principal.to_string()));
255        }
256        if let Some(scope) = scope {
257            map.insert("scope".into(), Value::String(scope.to_string()));
258        }
259    }
260
261    let mut output = json!({
262        "hit_count": hits.len(),
263    });
264    if let (Some(top), Some(map)) = (hits.first(), output.as_object_mut()) {
265        if let Some(score) = serde_json::Number::from_f64(top.score as f64) {
266            map.insert("top_score".into(), Value::Number(score));
267        }
268        if let Some(key) = &top.key {
269            map.insert("top_key".into(), Value::String(key.clone()));
270        }
271    }
272
273    let mut envelope = ResourceTraceEnvelope::new(TRACE_RESOURCE, TRACE_OPERATION, TRACE_KIND)
274        .with_input_summary(input)
275        .with_output_summary(output);
276
277    if hits.is_empty() {
278        envelope = envelope.with_reason(TRACE_REASON_NO_HITS);
279    }
280
281    if let Some(top) = hits.first() {
282        let mut metadata = serde_json::Map::new();
283        if let Some(source_uri) = &top.source_uri {
284            metadata.insert("source_uri".into(), Value::String(source_uri.clone()));
285        }
286        if let Some(recorded_at_millis) = top.recorded_at_millis {
287            metadata.insert(
288                "recorded_at_millis".into(),
289                Value::Number(serde_json::Number::from(recorded_at_millis)),
290            );
291        }
292        if let Some(top_principal) = &top.principal
293            && principal.is_none_or(|p| p != top_principal)
294        {
295            metadata.insert("top_principal".into(), Value::String(top_principal.clone()));
296        }
297        if let Some(top_scope) = &top.scope
298            && scope.is_none_or(|s| s != top_scope)
299        {
300            metadata.insert("top_scope".into(), Value::String(top_scope.clone()));
301        }
302        if !metadata.is_empty() {
303            envelope = envelope.with_metadata(Value::Object(metadata));
304        }
305    }
306
307    envelope
308}
309
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicUsize, Ordering};

    use super::*;

    /// Returns one canned hit, honouring `k`.
    struct StubMemory;

    #[async_trait]
    impl MemoryLookupStore for StubMemory {
        async fn lookup(
            &self,
            query: &str,
            k: usize,
        ) -> Result<Vec<MemoryLookupHit>, MemoryLookupError> {
            Ok(vec![
                MemoryLookupHit::new(0.9, format!("matched {query}"))
                    .with_key("ep-1")
                    .with_metadata(json!({"rank": 1})),
            ]
            .into_iter()
            .take(k)
            .collect())
        }
    }

    /// Records the `k` it was called with and returns no hits.
    struct RecordingMemory {
        last_k: AtomicUsize,
    }

    #[async_trait]
    impl MemoryLookupStore for RecordingMemory {
        async fn lookup(
            &self,
            _query: &str,
            k: usize,
        ) -> Result<Vec<MemoryLookupHit>, MemoryLookupError> {
            self.last_k.store(k, Ordering::SeqCst);
            Ok(Vec::new())
        }
    }

    /// Always fails, to exercise backend error mapping.
    struct FailingMemory;

    #[async_trait]
    impl MemoryLookupStore for FailingMemory {
        async fn lookup(
            &self,
            _query: &str,
            _k: usize,
        ) -> Result<Vec<MemoryLookupHit>, MemoryLookupError> {
            Err(MemoryLookupError::Backend("disk offline".into()))
        }
    }

    #[tokio::test]
    async fn lookup_tool_returns_hits() {
        let tool = MemoryLookupTool::new(Arc::new(StubMemory));
        let out = tool
            .invoke(json!({"query": "beacon", "k": 1}))
            .await
            .unwrap();
        let score = out["hits"][0]["score"].as_f64().unwrap();
        assert!((score - 0.9).abs() < 1e-6);
        assert_eq!(out["hits"][0]["key"], "ep-1");
    }

    #[tokio::test]
    async fn lookup_tool_defaults_k_to_three_when_omitted() {
        let store = Arc::new(RecordingMemory {
            last_k: AtomicUsize::new(0),
        });
        let tool = MemoryLookupTool::new(store.clone());
        let out = tool.invoke(json!({"query": "beacon"})).await.unwrap();
        assert_eq!(store.last_k.load(Ordering::SeqCst), 3);
        assert_eq!(out["hits"], json!([]));
    }

    #[tokio::test]
    async fn lookup_tool_maps_backend_errors_to_tool_failed() {
        let tool = MemoryLookupTool::new(Arc::new(FailingMemory));
        let err = tool
            .invoke(json!({"query": "beacon"}))
            .await
            .unwrap_err();
        assert!(
            matches!(&err, KernelError::ToolFailed(msg) if msg.contains("disk offline")),
            "unexpected error: {err:?}"
        );
    }

    #[test]
    fn lookup_hit_serializes_shared_metadata() {
        let hit = MemoryLookupHit::new(0.75, "matched episode")
            .with_key("ep-7")
            .with_source_uri("memory://episode/7")
            .with_principal("alice")
            .with_scope("workspace")
            .with_recorded_at_millis(1_700_000_000_000);

        let json = serde_json::to_value(hit).unwrap();

        assert_eq!(json["key"], "ep-7");
        assert_eq!(json["source_uri"], "memory://episode/7");
        assert_eq!(json["principal"], "alice");
        assert_eq!(json["scope"], "workspace");
        assert_eq!(json["recorded_at_millis"], 1_700_000_000_000_i64);
    }

    #[test]
    fn lookup_hit_round_trips_with_only_required_fields() {
        let json = serde_json::to_value(MemoryLookupHit::new(0.5, "bare")).unwrap();
        // Unset optional fields and null metadata are omitted entirely.
        assert_eq!(json, json!({"score": 0.5, "summary": "bare"}));

        let back: MemoryLookupHit = serde_json::from_value(json).unwrap();
        assert!(back.key.is_none());
        assert!(back.source_uri.is_none());
        assert!(back.principal.is_none());
        assert!(back.scope.is_none());
        assert!(back.recorded_at_millis.is_none());
        assert!(back.metadata.is_null());
    }

    #[tokio::test]
    async fn lookup_tool_rejects_zero_k() {
        let tool = MemoryLookupTool::new(Arc::new(StubMemory));
        let err = tool
            .invoke(json!({"query": "beacon", "k": 0}))
            .await
            .unwrap_err();
        assert!(matches!(err, KernelError::InvalidArgument(_)));
    }

    #[test]
    fn trace_envelope_summarises_hits_and_metadata() {
        let hits = vec![
            MemoryLookupHit::new(0.91, "top hit")
                .with_key("ep-1")
                .with_source_uri("memory://ep/1")
                .with_recorded_at_millis(1_700_000_000_000)
                .with_principal("alice")
                .with_scope("workspace"),
            MemoryLookupHit::new(0.42, "runner up").with_key("ep-2"),
        ];

        let envelope =
            memory_lookup_trace_envelope("beacon", 3, &hits, Some("alice"), Some("workspace"));

        assert_eq!(envelope.version, ResourceTraceEnvelope::VERSION);
        assert_eq!(envelope.resource, "memory");
        assert_eq!(envelope.operation, "lookup");
        assert_eq!(envelope.trace_kind, "memory_lookup");
        assert_eq!(envelope.input_summary["query"], "beacon");
        assert_eq!(envelope.input_summary["k"], 3);
        assert_eq!(envelope.input_summary["principal"], "alice");
        assert_eq!(envelope.input_summary["scope"], "workspace");
        assert_eq!(envelope.output_summary["hit_count"], 2);
        let top_score = envelope.output_summary["top_score"].as_f64().unwrap();
        assert!((top_score - 0.91).abs() < 1e-6);
        assert_eq!(envelope.output_summary["top_key"], "ep-1");
        assert!(envelope.reason.is_none());
        assert_eq!(envelope.metadata["source_uri"], "memory://ep/1");
        assert_eq!(
            envelope.metadata["recorded_at_millis"],
            1_700_000_000_000_i64
        );
        // Caller-supplied principal/scope matches the top hit, so they are
        // not echoed into metadata.
        assert!(envelope.metadata.get("top_principal").is_none());
        assert!(envelope.metadata.get("top_scope").is_none());
    }

    #[test]
    fn trace_envelope_emits_no_hits_reason_when_empty() {
        let envelope = memory_lookup_trace_envelope("nothing", 5, &[], None, None);
        assert_eq!(envelope.output_summary["hit_count"], 0);
        assert!(envelope.output_summary.get("top_score").is_none());
        assert_eq!(envelope.reason.as_deref(), Some(TRACE_REASON_NO_HITS));
        assert!(envelope.metadata.is_null());
        assert!(envelope.input_summary.get("principal").is_none());
    }

    #[test]
    fn trace_envelope_records_mismatched_top_principal_scope() {
        let hits = vec![
            MemoryLookupHit::new(0.5, "cross-tenant")
                .with_key("ep-9")
                .with_principal("bob")
                .with_scope("other"),
        ];

        let envelope =
            memory_lookup_trace_envelope("q", 1, &hits, Some("alice"), Some("workspace"));

        assert_eq!(envelope.metadata["top_principal"], "bob");
        assert_eq!(envelope.metadata["top_scope"], "other");
    }

    #[test]
    fn trace_envelope_records_top_principal_when_caller_supplies_none() {
        let hits = vec![MemoryLookupHit::new(0.5, "tagged").with_principal("bob")];

        let envelope = memory_lookup_trace_envelope("q", 1, &hits, None, None);

        assert_eq!(envelope.metadata["top_principal"], "bob");
        assert!(envelope.metadata.get("top_scope").is_none());
        assert!(envelope.input_summary.get("principal").is_none());
    }

    #[test]
    fn trace_envelope_leaves_metadata_unset_without_provenance() {
        let hits = vec![MemoryLookupHit::new(0.3, "bare").with_key("ep-3")];

        let envelope = memory_lookup_trace_envelope("q", 2, &hits, None, None);

        assert!(envelope.reason.is_none());
        assert_eq!(envelope.output_summary["top_key"], "ep-3");
        assert!(envelope.metadata.is_null());
    }
}