Skip to main content

oxios_kernel/tools/
memory_tools.rs

//! Memory tools for cross-session agent memory.
//!
//! Provides three tools:
//! - `memory_write` — write a memory entry
//! - `memory_read` — read/list memory entries
//! - `memory_search` — search memory entries by content or tags
7
8use std::sync::Arc;
9
10use async_trait::async_trait;
11use chrono::Utc;
12use oxi_sdk::{AgentTool, AgentToolResult, ToolContext};
13use serde_json::{json, Value};
14
15use crate::memory::{MemoryEntry, MemoryManager, MemoryType};
16
17/// Tool for writing memory entries that persist across sessions.
18pub struct MemoryWriteTool {
19    memory_manager: Arc<MemoryManager>,
20}
21
22impl MemoryWriteTool {
23    /// Create a new MemoryWriteTool.
24    pub fn new(memory_manager: Arc<MemoryManager>) -> Self {
25        Self { memory_manager }
26    }
27
28    /// Create a `MemoryWriteTool` from a [`KernelHandle`].
29    ///
30    /// Extracts the memory manager from the kernel's agent facade.
31    pub fn from_kernel(kernel: &crate::kernel_handle::KernelHandle) -> Self {
32        Self::new(kernel.agents.memory_manager().clone())
33    }
34}
35
36impl std::fmt::Debug for MemoryWriteTool {
37    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
38        f.debug_struct("MemoryWriteTool").finish()
39    }
40}
41
42#[async_trait]
43impl AgentTool for MemoryWriteTool {
44    fn name(&self) -> &str {
45        "memory_write"
46    }
47
48    fn label(&self) -> &str {
49        "Memory Write"
50    }
51
52    fn description(&self) -> &str {
53        "Write a memory entry that persists across sessions. Use this to save important facts, episodes, or knowledge for future reference."
54    }
55
56    fn parameters_schema(&self) -> Value {
57        json!({
58            "type": "object",
59            "properties": {
60                "content": {
61                    "type": "string",
62                    "description": "The memory content to store"
63                },
64                "memory_type": {
65                    "type": "string",
66                    "enum": ["fact", "episode", "knowledge"],
67                    "description": "The type of memory entry"
68                },
69                "tags": {
70                    "type": "array",
71                    "items": { "type": "string" },
72                    "description": "Optional tags for categorization"
73                },
74                "importance": {
75                    "type": "number",
76                    "description": "Importance score 0.0-1.0 (default 0.5)"
77                }
78            },
79            "required": ["content", "memory_type"]
80        })
81    }
82
83    async fn execute(
84        &self,
85        _tool_call_id: &str,
86        params: Value,
87        _signal: Option<tokio::sync::oneshot::Receiver<()>>,
88        _ctx: &ToolContext,
89    ) -> Result<AgentToolResult, oxi_sdk::ToolError> {
90        let content = params["content"].as_str().unwrap_or("").to_string();
91        if content.is_empty() {
92            return Ok(AgentToolResult::error("content is required"));
93        }
94
95        let memory_type_str = params["memory_type"].as_str().unwrap_or("fact");
96        let memory_type = match memory_type_str {
97            "fact" => MemoryType::Fact,
98            "episode" => MemoryType::Episode,
99            "knowledge" => MemoryType::Knowledge,
100            _ => {
101                return Ok(AgentToolResult::error(format!(
102                "Invalid memory_type '{memory_type_str}'. Must be one of: fact, episode, knowledge"
103            )))
104            }
105        };
106
107        let tags: Vec<String> = params["tags"]
108            .as_array()
109            .map(|arr| {
110                arr.iter()
111                    .filter_map(|v| v.as_str().map(String::from))
112                    .collect()
113            })
114            .unwrap_or_default();
115
116        let importance = params["importance"].as_f64().unwrap_or(0.5) as f32;
117
118        let now = Utc::now();
119        let entry = MemoryEntry {
120            id: uuid::Uuid::new_v4().to_string(),
121            memory_type,
122            tier: memory_type.initial_tier(),
123            content: content.clone(),
124            content_hash: crate::memory::content_hash(&content),
125            source: "agent".to_string(),
126            session_id: None,
127            tags: tags.clone(),
128            importance: importance.clamp(0.0, 1.0),
129            pinned: false,
130            protection: crate::memory::ProtectionLevel::None,
131            auto_classified: false,
132            session_appearances: 0,
133            user_corrected: false,
134            seen_in_sessions: vec![],
135            created_at: now,
136            accessed_at: now,
137            modified_at: now,
138            access_count: 0,
139            decay_score: 1.0,
140            compaction_level: 0,
141            compacted_from: vec![],
142            related_ids: vec![],
143            contradicts: None,
144        };
145        let entry_id = entry.id.clone();
146
147        match self.memory_manager.remember(entry).await {
148            Ok(_) => Ok(AgentToolResult::success(format!(
149                "Memory entry saved (id: {entry_id}, type: {memory_type_str})",
150            ))),
151            Err(e) => Ok(AgentToolResult::error(format!(
152                "Failed to write memory: {e}"
153            ))),
154        }
155    }
156}
157
158/// Tool for reading memory entries.
159pub struct MemoryReadTool {
160    memory_manager: Arc<MemoryManager>,
161}
162
163impl MemoryReadTool {
164    /// Create a new MemoryReadTool.
165    pub fn new(memory_manager: Arc<MemoryManager>) -> Self {
166        Self { memory_manager }
167    }
168
169    /// Create a `MemoryReadTool` from a [`KernelHandle`].
170    ///
171    /// Extracts the memory manager from the kernel's agent facade.
172    pub fn from_kernel(kernel: &crate::kernel_handle::KernelHandle) -> Self {
173        Self::new(kernel.agents.memory_manager().clone())
174    }
175}
176
177impl std::fmt::Debug for MemoryReadTool {
178    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
179        f.debug_struct("MemoryReadTool").finish()
180    }
181}
182
183#[async_trait]
184impl AgentTool for MemoryReadTool {
185    fn name(&self) -> &str {
186        "memory_read"
187    }
188
189    fn label(&self) -> &str {
190        "Memory Read"
191    }
192
193    fn description(&self) -> &str {
194        "Read memory entries. Provide 'id' and 'memory_type' to read a specific entry, or just 'memory_type' to list entries of that type."
195    }
196
197    fn parameters_schema(&self) -> Value {
198        json!({
199            "type": "object",
200            "properties": {
201                "id": {
202                    "type": "string",
203                    "description": "Optional specific memory entry ID to read."
204                },
205                "memory_type": {
206                    "type": "string",
207                    "enum": ["fact", "episode", "knowledge"],
208                    "description": "Type of memory to list (required if no id provided)"
209                },
210                "limit": {
211                    "type": "integer",
212                    "description": "Max entries to return when listing (default 10)"
213                }
214            }
215        })
216    }
217
218    async fn execute(
219        &self,
220        _tool_call_id: &str,
221        params: Value,
222        _signal: Option<tokio::sync::oneshot::Receiver<()>>,
223        _ctx: &ToolContext,
224    ) -> Result<AgentToolResult, oxi_sdk::ToolError> {
225        let limit = params["limit"].as_u64().unwrap_or(10) as usize;
226
227        if let Some(id) = params["id"].as_str() {
228            // Need memory_type to look up by ID
229            let memory_type_str = params["memory_type"].as_str().unwrap_or("fact");
230            let memory_type = parse_memory_type(memory_type_str);
231
232            match self.memory_manager.get(id, memory_type).await {
233                Ok(Some(entry)) => {
234                    let output = format!(
235                        "ID: {}\nType: {}\nSource: {}\nTags: {}\nImportance: {:.2}\nCreated: {}\n\n{}",
236                        entry.id,
237                        entry.memory_type.label(),
238                        entry.source,
239                        entry.tags.join(", "),
240                        entry.importance,
241                        entry.created_at,
242                        entry.content,
243                    );
244                    Ok(AgentToolResult::success(&output))
245                }
246                Ok(None) => Ok(AgentToolResult::error(format!(
247                    "Memory entry '{id}' not found"
248                ))),
249                Err(e) => Ok(AgentToolResult::error(format!(
250                    "Failed to read memory: {e}"
251                ))),
252            }
253        } else {
254            // List entries by type
255            let memory_type_str = params["memory_type"].as_str().unwrap_or("fact");
256            let memory_type = parse_memory_type(memory_type_str);
257
258            match self.memory_manager.list(memory_type, limit).await {
259                Ok(entries) => {
260                    if entries.is_empty() {
261                        return Ok(AgentToolResult::success(format!(
262                            "No {memory_type_str} memory entries found.",
263                        )));
264                    }
265                    let mut output =
266                        format!("Found {} {} entries:\n\n", entries.len(), memory_type_str,);
267                    for entry in &entries {
268                        let preview = truncate_str(&entry.content, 100);
269                        output.push_str(&format!(
270                            "- [{}] {} (id: {}…, tags: {})\n",
271                            entry.memory_type.label(),
272                            preview,
273                            &entry.id[..8.min(entry.id.len())],
274                            entry.tags.join(", "),
275                        ));
276                    }
277                    Ok(AgentToolResult::success(&output))
278                }
279                Err(e) => Ok(AgentToolResult::error(format!(
280                    "Failed to list memory: {e}"
281                ))),
282            }
283        }
284    }
285}
286
287/// Tool for searching memory entries by content or tags.
288pub struct MemorySearchTool {
289    memory_manager: Arc<MemoryManager>,
290}
291
292impl MemorySearchTool {
293    /// Create a new MemorySearchTool.
294    pub fn new(memory_manager: Arc<MemoryManager>) -> Self {
295        Self { memory_manager }
296    }
297
298    /// Create a `MemorySearchTool` from a [`KernelHandle`].
299    ///
300    /// Extracts the memory manager from the kernel's agent facade.
301    pub fn from_kernel(kernel: &crate::kernel_handle::KernelHandle) -> Self {
302        Self::new(kernel.agents.memory_manager().clone())
303    }
304}
305
306impl std::fmt::Debug for MemorySearchTool {
307    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
308        f.debug_struct("MemorySearchTool").finish()
309    }
310}
311
312#[async_trait]
313impl AgentTool for MemorySearchTool {
314    fn name(&self) -> &str {
315        "memory_search"
316    }
317
318    fn label(&self) -> &str {
319        "Memory Search"
320    }
321
322    fn description(&self) -> &str {
323        "Search memory entries by keyword query. Optionally filter by memory type."
324    }
325
326    fn parameters_schema(&self) -> Value {
327        json!({
328            "type": "object",
329            "properties": {
330                "query": {
331                    "type": "string",
332                    "description": "Text to search for in memory content"
333                },
334                "memory_type": {
335                    "type": "string",
336                    "enum": ["fact", "episode", "knowledge", "conversation", "session"],
337                    "description": "Optional memory type to filter by"
338                },
339                "limit": {
340                    "type": "integer",
341                    "description": "Max results (default 10)"
342                }
343            },
344            "required": ["query"]
345        })
346    }
347
348    async fn execute(
349        &self,
350        _tool_call_id: &str,
351        params: Value,
352        _signal: Option<tokio::sync::oneshot::Receiver<()>>,
353        _ctx: &ToolContext,
354    ) -> Result<AgentToolResult, oxi_sdk::ToolError> {
355        let query = params["query"].as_str().unwrap_or("");
356        if query.is_empty() {
357            return Ok(AgentToolResult::error("query is required"));
358        }
359
360        let limit = params["limit"].as_u64().unwrap_or(10) as usize;
361
362        let memory_type = params["memory_type"].as_str().map(parse_memory_type);
363
364        match self.memory_manager.search(query, memory_type, limit).await {
365            Ok(entries) => {
366                if entries.is_empty() {
367                    return Ok(AgentToolResult::success(
368                        "No matching memory entries found.",
369                    ));
370                }
371                let mut output = format!("Found {} matching entries:\n\n", entries.len());
372                for entry in &entries {
373                    let preview = truncate_str(&entry.content, 100);
374                    output.push_str(&format!(
375                        "- [{}] {} (id: {}…, importance: {:.2}, tags: {})\n",
376                        entry.memory_type.label(),
377                        preview,
378                        &entry.id[..8.min(entry.id.len())],
379                        entry.importance,
380                        entry.tags.join(", "),
381                    ));
382                }
383                Ok(AgentToolResult::success(&output))
384            }
385            Err(e) => Ok(AgentToolResult::error(format!(
386                "Failed to search memory: {e}"
387            ))),
388        }
389    }
390}
391
392/// Parse a memory type string, defaulting to Fact.
393fn parse_memory_type(s: &str) -> MemoryType {
394    match s {
395        "conversation" => MemoryType::Conversation,
396        "session" => MemoryType::Session,
397        "fact" => MemoryType::Fact,
398        "episode" => MemoryType::Episode,
399        "knowledge" => MemoryType::Knowledge,
400        "skill" => MemoryType::Skill,
401        "preference" => MemoryType::Preference,
402        "decision" => MemoryType::Decision,
403        "user_profile" | "profile" => MemoryType::UserProfile,
404        _ => MemoryType::Fact,
405    }
406}
407
/// Truncate a string to at most `max_chars` **bytes** without splitting a
/// UTF-8 character.
///
/// Despite the parameter name, the limit is measured in bytes of the UTF-8
/// encoding, not in characters: a 3-byte character only fits if all three of
/// its bytes are within the limit. When `max_chars` falls inside a multi-byte
/// character, the cut moves back to the start of that character, so the result
/// may be shorter than `max_chars` bytes.
///
/// Returns `s` unchanged when it is already `max_chars` bytes or shorter.
fn truncate_str(s: &str, max_chars: usize) -> &str {
    if s.len() <= max_chars {
        return s;
    }
    // Walk back from the limit to the nearest char boundary. Index 0 is always
    // a boundary, so the search cannot come up empty.
    let cut = (0..=max_chars)
        .rev()
        .find(|&idx| s.is_char_boundary(idx))
        .unwrap_or(0);
    &s[..cut]
}
420
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a `MemoryManager` over a state store rooted in a uniquely named
    /// directory under the system temp dir.
    fn make_test_mm() -> std::sync::Arc<crate::memory::MemoryManager> {
        let unique = uuid::Uuid::new_v4();
        let dir = std::env::temp_dir().join(format!("test-memory-{}", unique));
        let store = crate::state_store::StateStore::new(dir).expect("test state store");
        let manager = crate::memory::MemoryManager::new(std::sync::Arc::new(store));
        std::sync::Arc::new(manager)
    }

    #[test]
    fn test_truncate_str_ascii() {
        let cases = [("hello world", 5, "hello"), ("hello", 10, "hello"), ("", 5, "")];
        for (input, limit, expected) in cases {
            assert_eq!(truncate_str(input, limit), expected);
        }
    }

    #[test]
    fn test_truncate_str_utf8_korean() {
        // Each Korean character is 3 bytes in UTF-8, so the input is 15 bytes.
        let korean = "안녕하세요";
        let cases = [
            (6, "안녕"),        // exactly two characters
            (7, "안녕"),        // limit lands mid-character, falls back to 6
            (15, "안녕하세요"), // whole string fits
        ];
        for (limit, expected) in cases {
            assert_eq!(truncate_str(korean, limit), expected);
        }
    }

    #[test]
    fn test_truncate_str_mixed() {
        // "Hi " is 3 bytes, followed by two 3-byte characters: 9 bytes total.
        let mixed = "Hi 안녕";
        for limit in [3, 4] {
            // A limit of 4 lands inside the first Korean character and falls back to 3.
            assert_eq!(truncate_str(mixed, limit), "Hi ");
        }
    }

    #[test]
    fn test_parse_memory_type() {
        let fact = parse_memory_type("fact");
        assert!(matches!(fact, MemoryType::Fact));

        let episode = parse_memory_type("episode");
        assert!(matches!(episode, MemoryType::Episode));

        let knowledge = parse_memory_type("knowledge");
        assert!(matches!(knowledge, MemoryType::Knowledge));

        let conversation = parse_memory_type("conversation");
        assert!(matches!(conversation, MemoryType::Conversation));

        let session = parse_memory_type("session");
        assert!(matches!(session, MemoryType::Session));

        // Unrecognized names fall back to Fact.
        let fallback = parse_memory_type("unknown");
        assert!(matches!(fallback, MemoryType::Fact));
    }

    #[test]
    fn test_memory_write_tool_schema() {
        let tool = MemoryWriteTool::new(make_test_mm());
        assert_eq!(tool.name(), "memory_write");
        assert!(tool.parameters_schema()["required"].is_array());
    }

    #[test]
    fn test_memory_read_tool_schema() {
        let tool = MemoryReadTool::new(make_test_mm());
        assert_eq!(tool.name(), "memory_read");
    }

    #[test]
    fn test_memory_search_tool_schema() {
        let tool = MemorySearchTool::new(make_test_mm());
        assert_eq!(tool.name(), "memory_search");
        assert!(tool.parameters_schema()["required"].is_array());
    }
}