Skip to main content

oxios_kernel/tools/
memory_tools.rs

//! Memory tools for cross-session agent memory.
//!
//! Provides three tools:
//! - `memory_write` — write a memory entry
//! - `memory_read` — read/list memory entries
//! - `memory_search` — search memory entries by content or tags

8use std::sync::Arc;
9
10use async_trait::async_trait;
11use chrono::Utc;
12use oxi_sdk::{AgentTool, AgentToolResult, ToolContext};
13use serde_json::{json, Value};
14
15use crate::memory::{MemoryEntry, MemoryManager, MemoryType};
16
17/// Tool for writing memory entries that persist across sessions.
18pub struct MemoryWriteTool {
19    memory_manager: Arc<MemoryManager>,
20}
21
22impl MemoryWriteTool {
23    /// Create a new MemoryWriteTool.
24    pub fn new(memory_manager: Arc<MemoryManager>) -> Self {
25        Self { memory_manager }
26    }
27
28    /// Create a `MemoryWriteTool` from a [`KernelHandle`].
29    ///
30    /// Extracts the memory manager from the kernel's agent facade.
31    pub fn from_kernel(kernel: &crate::kernel_handle::KernelHandle) -> Self {
32        Self::new(kernel.agents.memory_manager().clone())
33    }
34}
35
36impl std::fmt::Debug for MemoryWriteTool {
37    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
38        f.debug_struct("MemoryWriteTool").finish()
39    }
40}
41
42#[async_trait]
43impl AgentTool for MemoryWriteTool {
44    fn name(&self) -> &str {
45        "memory_write"
46    }
47
48    fn label(&self) -> &str {
49        "Memory Write"
50    }
51
52    fn description(&self) -> &str {
53        "Write a memory entry that persists across sessions. Use this to save important facts, episodes, or knowledge for future reference."
54    }
55
56    fn parameters_schema(&self) -> Value {
57        json!({
58            "type": "object",
59            "properties": {
60                "content": {
61                    "type": "string",
62                    "description": "The memory content to store"
63                },
64                "memory_type": {
65                    "type": "string",
66                    "enum": ["fact", "episode", "knowledge"],
67                    "description": "The type of memory entry"
68                },
69                "tags": {
70                    "type": "array",
71                    "items": { "type": "string" },
72                    "description": "Optional tags for categorization"
73                },
74                "importance": {
75                    "type": "number",
76                    "description": "Importance score 0.0-1.0 (default 0.5)"
77                }
78            },
79            "required": ["content", "memory_type"]
80        })
81    }
82
83    async fn execute(
84        &self,
85        _tool_call_id: &str,
86        params: Value,
87        _signal: Option<tokio::sync::oneshot::Receiver<()>>,
88        _ctx: &ToolContext,
89    ) -> Result<AgentToolResult, oxi_sdk::ToolError> {
90        let content = params["content"].as_str().unwrap_or("").to_string();
91        if content.is_empty() {
92            return Ok(AgentToolResult::error("content is required"));
93        }
94
95        let memory_type_str = params["memory_type"].as_str().unwrap_or("fact");
96        let memory_type = match memory_type_str {
97            "fact" => MemoryType::Fact,
98            "episode" => MemoryType::Episode,
99            "knowledge" => MemoryType::Knowledge,
100            _ => {
101                return Ok(AgentToolResult::error(format!(
102                    "Invalid memory_type '{}'. Must be one of: fact, episode, knowledge",
103                    memory_type_str
104                )))
105            }
106        };
107
108        let tags: Vec<String> = params["tags"]
109            .as_array()
110            .map(|arr| {
111                arr.iter()
112                    .filter_map(|v| v.as_str().map(String::from))
113                    .collect()
114            })
115            .unwrap_or_default();
116
117        let importance = params["importance"].as_f64().unwrap_or(0.5) as f32;
118
119        let now = Utc::now();
120        let entry = MemoryEntry {
121            id: uuid::Uuid::new_v4().to_string(),
122            memory_type,
123            content,
124            source: "agent".to_string(),
125            session_id: None,
126            tags,
127            importance: importance.clamp(0.0, 1.0),
128            created_at: now,
129            accessed_at: now,
130            access_count: 0,
131        };
132        let entry_id = entry.id.clone();
133
134        match self.memory_manager.remember(entry).await {
135            Ok(_) => Ok(AgentToolResult::success(format!(
136                "Memory entry saved (id: {}, type: {})",
137                entry_id, memory_type_str,
138            ))),
139            Err(e) => Ok(AgentToolResult::error(format!(
140                "Failed to write memory: {e}"
141            ))),
142        }
143    }
144}
145
146/// Tool for reading memory entries.
147pub struct MemoryReadTool {
148    memory_manager: Arc<MemoryManager>,
149}
150
151impl MemoryReadTool {
152    /// Create a new MemoryReadTool.
153    pub fn new(memory_manager: Arc<MemoryManager>) -> Self {
154        Self { memory_manager }
155    }
156
157    /// Create a `MemoryReadTool` from a [`KernelHandle`].
158    ///
159    /// Extracts the memory manager from the kernel's agent facade.
160    pub fn from_kernel(kernel: &crate::kernel_handle::KernelHandle) -> Self {
161        Self::new(kernel.agents.memory_manager().clone())
162    }
163}
164
165impl std::fmt::Debug for MemoryReadTool {
166    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
167        f.debug_struct("MemoryReadTool").finish()
168    }
169}
170
171#[async_trait]
172impl AgentTool for MemoryReadTool {
173    fn name(&self) -> &str {
174        "memory_read"
175    }
176
177    fn label(&self) -> &str {
178        "Memory Read"
179    }
180
181    fn description(&self) -> &str {
182        "Read memory entries. Provide 'id' and 'memory_type' to read a specific entry, or just 'memory_type' to list entries of that type."
183    }
184
185    fn parameters_schema(&self) -> Value {
186        json!({
187            "type": "object",
188            "properties": {
189                "id": {
190                    "type": "string",
191                    "description": "Optional specific memory entry ID to read."
192                },
193                "memory_type": {
194                    "type": "string",
195                    "enum": ["fact", "episode", "knowledge"],
196                    "description": "Type of memory to list (required if no id provided)"
197                },
198                "limit": {
199                    "type": "integer",
200                    "description": "Max entries to return when listing (default 10)"
201                }
202            }
203        })
204    }
205
206    async fn execute(
207        &self,
208        _tool_call_id: &str,
209        params: Value,
210        _signal: Option<tokio::sync::oneshot::Receiver<()>>,
211        _ctx: &ToolContext,
212    ) -> Result<AgentToolResult, oxi_sdk::ToolError> {
213        let limit = params["limit"].as_u64().unwrap_or(10) as usize;
214
215        if let Some(id) = params["id"].as_str() {
216            // Need memory_type to look up by ID
217            let memory_type_str = params["memory_type"].as_str().unwrap_or("fact");
218            let memory_type = parse_memory_type(memory_type_str);
219
220            match self.memory_manager.get(id, memory_type).await {
221                Ok(Some(entry)) => {
222                    let output = format!(
223                        "ID: {}\nType: {}\nSource: {}\nTags: {}\nImportance: {:.2}\nCreated: {}\n\n{}",
224                        entry.id,
225                        entry.memory_type.label(),
226                        entry.source,
227                        entry.tags.join(", "),
228                        entry.importance,
229                        entry.created_at,
230                        entry.content,
231                    );
232                    Ok(AgentToolResult::success(&output))
233                }
234                Ok(None) => Ok(AgentToolResult::error(format!(
235                    "Memory entry '{}' not found",
236                    id
237                ))),
238                Err(e) => Ok(AgentToolResult::error(format!(
239                    "Failed to read memory: {e}"
240                ))),
241            }
242        } else {
243            // List entries by type
244            let memory_type_str = params["memory_type"].as_str().unwrap_or("fact");
245            let memory_type = parse_memory_type(memory_type_str);
246
247            match self.memory_manager.list(memory_type, limit).await {
248                Ok(entries) => {
249                    if entries.is_empty() {
250                        return Ok(AgentToolResult::success(format!(
251                            "No {} memory entries found.",
252                            memory_type_str,
253                        )));
254                    }
255                    let mut output =
256                        format!("Found {} {} entries:\n\n", entries.len(), memory_type_str,);
257                    for entry in &entries {
258                        let preview = truncate_str(&entry.content, 100);
259                        output.push_str(&format!(
260                            "- [{}] {} (id: {}…, tags: {})\n",
261                            entry.memory_type.label(),
262                            preview,
263                            &entry.id[..8.min(entry.id.len())],
264                            entry.tags.join(", "),
265                        ));
266                    }
267                    Ok(AgentToolResult::success(&output))
268                }
269                Err(e) => Ok(AgentToolResult::error(format!(
270                    "Failed to list memory: {e}"
271                ))),
272            }
273        }
274    }
275}
276
277/// Tool for searching memory entries by content or tags.
278pub struct MemorySearchTool {
279    memory_manager: Arc<MemoryManager>,
280}
281
282impl MemorySearchTool {
283    /// Create a new MemorySearchTool.
284    pub fn new(memory_manager: Arc<MemoryManager>) -> Self {
285        Self { memory_manager }
286    }
287
288    /// Create a `MemorySearchTool` from a [`KernelHandle`].
289    ///
290    /// Extracts the memory manager from the kernel's agent facade.
291    pub fn from_kernel(kernel: &crate::kernel_handle::KernelHandle) -> Self {
292        Self::new(kernel.agents.memory_manager().clone())
293    }
294}
295
296impl std::fmt::Debug for MemorySearchTool {
297    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
298        f.debug_struct("MemorySearchTool").finish()
299    }
300}
301
302#[async_trait]
303impl AgentTool for MemorySearchTool {
304    fn name(&self) -> &str {
305        "memory_search"
306    }
307
308    fn label(&self) -> &str {
309        "Memory Search"
310    }
311
312    fn description(&self) -> &str {
313        "Search memory entries by keyword query. Optionally filter by memory type."
314    }
315
316    fn parameters_schema(&self) -> Value {
317        json!({
318            "type": "object",
319            "properties": {
320                "query": {
321                    "type": "string",
322                    "description": "Text to search for in memory content"
323                },
324                "memory_type": {
325                    "type": "string",
326                    "enum": ["fact", "episode", "knowledge", "conversation", "session"],
327                    "description": "Optional memory type to filter by"
328                },
329                "limit": {
330                    "type": "integer",
331                    "description": "Max results (default 10)"
332                }
333            },
334            "required": ["query"]
335        })
336    }
337
338    async fn execute(
339        &self,
340        _tool_call_id: &str,
341        params: Value,
342        _signal: Option<tokio::sync::oneshot::Receiver<()>>,
343        _ctx: &ToolContext,
344    ) -> Result<AgentToolResult, oxi_sdk::ToolError> {
345        let query = params["query"].as_str().unwrap_or("");
346        if query.is_empty() {
347            return Ok(AgentToolResult::error("query is required"));
348        }
349
350        let limit = params["limit"].as_u64().unwrap_or(10) as usize;
351
352        let memory_type = params["memory_type"].as_str().map(parse_memory_type);
353
354        match self.memory_manager.search(query, memory_type, limit).await {
355            Ok(entries) => {
356                if entries.is_empty() {
357                    return Ok(AgentToolResult::success(
358                        "No matching memory entries found.",
359                    ));
360                }
361                let mut output = format!("Found {} matching entries:\n\n", entries.len());
362                for entry in &entries {
363                    let preview = truncate_str(&entry.content, 100);
364                    output.push_str(&format!(
365                        "- [{}] {} (id: {}…, importance: {:.2}, tags: {})\n",
366                        entry.memory_type.label(),
367                        preview,
368                        &entry.id[..8.min(entry.id.len())],
369                        entry.importance,
370                        entry.tags.join(", "),
371                    ));
372                }
373                Ok(AgentToolResult::success(&output))
374            }
375            Err(e) => Ok(AgentToolResult::error(format!(
376                "Failed to search memory: {e}"
377            ))),
378        }
379    }
380}
381
382/// Parse a memory type string, defaulting to Fact.
383fn parse_memory_type(s: &str) -> MemoryType {
384    match s {
385        "conversation" => MemoryType::Conversation,
386        "session" => MemoryType::Session,
387        "fact" => MemoryType::Fact,
388        "episode" => MemoryType::Episode,
389        "knowledge" => MemoryType::Knowledge,
390        _ => MemoryType::Fact,
391    }
392}
393
/// Truncates `s` to at most `max_bytes` bytes without splitting a UTF-8 character.
///
/// The limit is measured in bytes (as `str::len` reports), not in `char`s. If
/// the limit falls inside a multi-byte character, the cut moves back to the
/// start of that character, so the result can be shorter than `max_bytes`.
/// Strings that already fit are returned unchanged.
fn truncate_str(s: &str, max_bytes: usize) -> &str {
    if s.len() <= max_bytes {
        return s;
    }
    // Walk back to the nearest char boundary at or below the limit. Index 0 is
    // always a boundary, so the fallback is never actually used.
    let boundary = (0..=max_bytes)
        .rev()
        .find(|&i| s.is_char_boundary(i))
        .unwrap_or(0);
    &s[..boundary]
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_truncate_str_ascii() {
        assert_eq!(truncate_str("hello world", 5), "hello");
        assert_eq!(truncate_str("hello", 10), "hello");
        assert_eq!(truncate_str("", 5), "");
    }

    #[test]
    fn test_truncate_str_utf8_korean() {
        // Each Korean character is 3 bytes in UTF-8.
        let korean = "안녕하세요"; // 15 bytes
        assert_eq!(truncate_str(korean, 6), "안녕"); // 6 bytes = 2 chars
        assert_eq!(truncate_str(korean, 7), "안녕"); // 7 bytes splits char → back to 6
        assert_eq!(truncate_str(korean, 15), "안녕하세요");
    }

    #[test]
    fn test_truncate_str_mixed() {
        let mixed = "Hi 안녕"; // 2 + 1 + 6 = 9 bytes
        assert_eq!(truncate_str(mixed, 3), "Hi ");
        assert_eq!(truncate_str(mixed, 4), "Hi "); // 4 splits 안 → back to 3
    }

    #[test]
    fn test_truncate_str_zero_limit() {
        assert_eq!(truncate_str("abc", 0), "");
        assert_eq!(truncate_str("안녕", 2), ""); // limit inside the first char
    }

    #[test]
    fn test_parse_memory_type() {
        assert!(matches!(parse_memory_type("fact"), MemoryType::Fact));
        assert!(matches!(parse_memory_type("episode"), MemoryType::Episode));
        assert!(matches!(
            parse_memory_type("knowledge"),
            MemoryType::Knowledge
        ));
        assert!(matches!(
            parse_memory_type("conversation"),
            MemoryType::Conversation
        ));
        assert!(matches!(parse_memory_type("session"), MemoryType::Session));
        assert!(matches!(parse_memory_type("unknown"), MemoryType::Fact));
    }

    /// Deletes a test's temporary state directory when dropped, so each test
    /// run does not leave a `test-memory-*` directory behind in the temp dir.
    struct TempDirGuard(std::path::PathBuf);

    impl Drop for TempDirGuard {
        fn drop(&mut self) {
            // Best effort: the directory may never have been created.
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    /// Builds a `MemoryManager` over a fresh state store in a unique temp dir.
    ///
    /// The returned guard removes that directory when dropped. Bind it to a
    /// named variable (not `_`) so it lives until the end of the test.
    fn make_test_mm() -> (std::sync::Arc<crate::memory::MemoryManager>, TempDirGuard) {
        let dir = std::env::temp_dir().join(format!("test-memory-{}", uuid::Uuid::new_v4()));
        let guard = TempDirGuard(dir.clone());
        let state_store = std::sync::Arc::new(
            crate::state_store::StateStore::new(dir).expect("test state store"),
        );
        (
            std::sync::Arc::new(crate::memory::MemoryManager::new(state_store)),
            guard,
        )
    }

    /// Returns true when `schema["required"]` lists `field`.
    fn requires(schema: &Value, field: &str) -> bool {
        schema["required"]
            .as_array()
            .map(|fields| fields.iter().any(|v| v.as_str() == Some(field)))
            .unwrap_or(false)
    }

    #[test]
    fn test_memory_write_tool_schema() {
        let (mm, _dir) = make_test_mm();
        let tool = MemoryWriteTool::new(mm);
        assert_eq!(tool.name(), "memory_write");
        let schema = tool.parameters_schema();
        assert!(schema["required"].is_array());
        assert!(requires(&schema, "content"));
        assert!(requires(&schema, "memory_type"));
    }

    #[test]
    fn test_memory_read_tool_schema() {
        let (mm, _dir) = make_test_mm();
        let tool = MemoryReadTool::new(mm);
        assert_eq!(tool.name(), "memory_read");
        let schema = tool.parameters_schema();
        assert!(schema["properties"]["id"].is_object());
        assert!(schema["properties"]["memory_type"].is_object());
    }

    #[test]
    fn test_memory_search_tool_schema() {
        let (mm, _dir) = make_test_mm();
        let tool = MemorySearchTool::new(mm);
        assert_eq!(tool.name(), "memory_search");
        let schema = tool.parameters_schema();
        assert!(schema["required"].is_array());
        assert!(requires(&schema, "query"));
    }
}