Skip to main content

oxi_agent/tools/
memory_retain.rs

1//! `memory_retain` tool — persist a memory item to the backend.
2
3use async_trait::async_trait;
4use serde_json::{Value, json};
5
6use super::{AgentTool, AgentToolResult, ToolContext, ToolError};
7
/// Every `kind` value that [`MemoryRetainTool`] will accept.
///
/// Any other string is rejected before the memory backend is called.
const VALID_KINDS: [&str; 4] = ["fact", "preference", "context", "summary"];
10
/// Tool that persists a memory item (content and kind) to the configured
/// [`MemoryBackend`](crate::tools::MemoryBackend).
///
/// Requires `ctx.memory` to be set; otherwise returns an error. The memory
/// is scoped to the current session (`ctx.session_id`), falling back to
/// `"default"` when no session id is available.
///
/// The optional `importance` argument is range-checked but not stored: the
/// backend's `put` only receives the content, kind and subject.
#[derive(Debug, Clone, Copy, Default)]
pub struct MemoryRetainTool;
18
19#[async_trait]
20impl AgentTool for MemoryRetainTool {
21    fn name(&self) -> &str {
22        "memory_retain"
23    }
24
25    fn label(&self) -> &str {
26        "Memory Retain"
27    }
28
29    fn description(&self) -> &str {
30        "Store a piece of information to long-term memory for later recall. \
31         Use for facts, preferences, context, or summaries worth remembering \
32         across sessions."
33    }
34
35    fn essential(&self) -> bool {
36        false
37    }
38
39    fn parameters_schema(&self) -> Value {
40        json!({
41            "type": "object",
42            "properties": {
43                "content": {
44                    "type": "string",
45                    "description": "The text to remember."
46                },
47                "kind": {
48                    "type": "string",
49                    "enum": ["fact", "preference", "context", "summary"],
50                    "default": "fact",
51                    "description": "Category of the memory."
52                },
53                "importance": {
54                    "type": "number",
55                    "minimum": 0.0,
56                    "maximum": 1.0,
57                    "default": 0.5,
58                    "description": "How important this memory is (0–1)."
59                }
60            },
61            "required": ["content"]
62        })
63    }
64
65    async fn execute(
66        &self,
67        _tool_call_id: &str,
68        params: Value,
69        _signal: Option<tokio::sync::oneshot::Receiver<()>>,
70        ctx: &ToolContext,
71    ) -> Result<AgentToolResult, ToolError> {
72        let backend = ctx.memory.as_ref().ok_or("Memory not configured")?;
73
74        let content = params
75            .get("content")
76            .and_then(|v| v.as_str())
77            .ok_or("Missing required parameter: content")?;
78
79        let kind = params
80            .get("kind")
81            .and_then(|v| v.as_str())
82            .unwrap_or("fact");
83        if !VALID_KINDS.contains(&kind) {
84            return Err(format!(
85                "Invalid kind '{}': expected one of {:?}",
86                kind, VALID_KINDS
87            ));
88        }
89
90        // `importance` is validated for forward-compatibility; the current
91        // `MemoryBackend::put` signature does not persist it.
92        if let Some(importance) = params.get("importance").and_then(|v| v.as_f64())
93            && !(0.0..=1.0).contains(&importance)
94        {
95            return Err(format!(
96                "importance must be between 0 and 1, got {}",
97                importance
98            ));
99        }
100
101        let subject = ctx.session_id.as_deref().unwrap_or("default");
102        backend.put(content, kind, subject).await?;
103
104        Ok(AgentToolResult::success(format!(
105            "Retained [{}] to memory.",
106            kind
107        )))
108    }
109}
110
#[cfg(test)]
mod tests {
    //! Unit tests for [`MemoryRetainTool`] using an in-memory mock backend.
    //!
    //! Besides checking outputs and errors, the rejection tests also assert
    //! that nothing was written to the backend when validation fails.

    use super::*;
    use crate::tools::MemoryBackend;
    use parking_lot::Mutex;
    use std::future::Future;
    use std::pin::Pin;
    use std::sync::Arc;

    /// Records every `put` call; the remaining trait methods are stubbed.
    #[derive(Debug)]
    struct MockMemory {
        /// `(content, kind, subject)` for each `put`, in call order.
        puts: Mutex<Vec<(String, String, String)>>,
    }

    impl MockMemory {
        fn new() -> Self {
            Self {
                puts: Mutex::new(vec![]),
            }
        }
    }

    impl MemoryBackend for MockMemory {
        fn put<'a>(
            &'a self,
            content: &'a str,
            kind: &'a str,
            subject: &'a str,
        ) -> Pin<Box<dyn Future<Output = Result<String, ToolError>> + Send + 'a>> {
            self.puts
                .lock()
                .push((content.into(), kind.into(), subject.into()));
            Box::pin(async move { Ok("mem-1".to_string()) })
        }

        fn search<'a>(
            &'a self,
            _query: &'a str,
            _k: usize,
        ) -> Pin<
            Box<dyn Future<Output = Result<Vec<crate::tools::MemoryItem>, ToolError>> + Send + 'a>,
        > {
            Box::pin(async move { Ok(vec![]) })
        }

        fn list<'a>(
            &'a self,
            _subject: &'a str,
        ) -> Pin<
            Box<dyn Future<Output = Result<Vec<crate::tools::MemoryItem>, ToolError>> + Send + 'a>,
        > {
            Box::pin(async move { Ok(vec![]) })
        }

        fn delete<'a>(
            &'a self,
            _id: &'a str,
        ) -> Pin<Box<dyn Future<Output = Result<(), ToolError>> + Send + 'a>> {
            Box::pin(async move { Ok(()) })
        }
    }

    #[tokio::test]
    async fn retain_calls_put_with_correct_args() {
        let mock = Arc::new(MockMemory::new());
        let ctx = ToolContext::default()
            .with_session("sess-42")
            .with_memory(mock.clone());
        let result = MemoryRetainTool
            .execute(
                "c1",
                json!({"content": "hello", "kind": "fact", "importance": 0.9}),
                None,
                &ctx,
            )
            .await
            .unwrap();
        assert!(result.success);
        assert_eq!(result.output, "Retained [fact] to memory.");
        let puts = mock.puts.lock();
        assert_eq!(puts.len(), 1);
        assert_eq!(puts[0].0, "hello");
        assert_eq!(puts[0].1, "fact");
        assert_eq!(puts[0].2, "sess-42");
    }

    #[tokio::test]
    async fn retain_defaults_kind_to_fact() {
        let mock = Arc::new(MockMemory::new());
        let ctx = ToolContext::default().with_memory(mock.clone());
        let result = MemoryRetainTool
            .execute("c1", json!({"content": "x"}), None, &ctx)
            .await
            .unwrap();
        assert_eq!(result.output, "Retained [fact] to memory.");
        assert_eq!(mock.puts.lock()[0].1, "fact");
    }

    #[tokio::test]
    async fn retain_uses_default_subject_without_session() {
        let mock = Arc::new(MockMemory::new());
        let ctx = ToolContext::default().with_memory(mock.clone());
        MemoryRetainTool
            .execute("c1", json!({"content": "x"}), None, &ctx)
            .await
            .unwrap();
        assert_eq!(mock.puts.lock()[0].2, "default");
    }

    #[tokio::test]
    async fn retain_errors_when_memory_not_configured() {
        let ctx = ToolContext::default();
        let err = MemoryRetainTool
            .execute("c1", json!({"content": "x"}), None, &ctx)
            .await
            .unwrap_err();
        assert_eq!(err, "Memory not configured");
    }

    #[tokio::test]
    async fn retain_rejects_invalid_kind() {
        let mock = Arc::new(MockMemory::new());
        let ctx = ToolContext::default().with_memory(mock.clone());
        let err = MemoryRetainTool
            .execute("c1", json!({"content": "x", "kind": "bogus"}), None, &ctx)
            .await
            .unwrap_err();
        assert!(err.contains("Invalid kind"));
        assert!(
            mock.puts.lock().is_empty(),
            "backend must not be written when kind is invalid"
        );
    }

    #[tokio::test]
    async fn retain_rejects_out_of_range_importance() {
        let mock = Arc::new(MockMemory::new());
        let ctx = ToolContext::default().with_memory(mock.clone());
        let err = MemoryRetainTool
            .execute("c1", json!({"content": "x", "importance": 1.5}), None, &ctx)
            .await
            .unwrap_err();
        assert!(err.contains("importance"));
        assert!(
            mock.puts.lock().is_empty(),
            "backend must not be written when importance is out of range"
        );
    }

    #[tokio::test]
    async fn retain_rejects_negative_importance() {
        let mock = Arc::new(MockMemory::new());
        let ctx = ToolContext::default().with_memory(mock.clone());
        let err = MemoryRetainTool
            .execute("c1", json!({"content": "x", "importance": -0.1}), None, &ctx)
            .await
            .unwrap_err();
        assert!(err.contains("importance"));
        assert!(mock.puts.lock().is_empty());
    }

    #[tokio::test]
    async fn retain_accepts_importance_bounds() {
        // Both ends of the documented [0, 1] range are inclusive.
        for importance in [0.0, 1.0] {
            let mock = Arc::new(MockMemory::new());
            let ctx = ToolContext::default().with_memory(mock.clone());
            let result = MemoryRetainTool
                .execute(
                    "c1",
                    json!({"content": "x", "importance": importance}),
                    None,
                    &ctx,
                )
                .await
                .unwrap();
            assert!(result.success);
            assert_eq!(mock.puts.lock().len(), 1);
        }
    }

    #[tokio::test]
    async fn retain_rejects_missing_content() {
        let mock = Arc::new(MockMemory::new());
        let ctx = ToolContext::default().with_memory(mock.clone());
        let err = MemoryRetainTool
            .execute("c1", json!({"kind": "fact"}), None, &ctx)
            .await
            .unwrap_err();
        assert!(err.contains("content"));
        assert!(
            mock.puts.lock().is_empty(),
            "backend must not be written when content is missing"
        );
    }
}