Skip to main content

oxios_kernel/tools/
memory_tools.rs

1//! Memory tools for cross-session agent memory.
2//!
3//! Provides three tools:
4//! - `memory_write` — write a memory entry
5//! - `memory_read` — read/list memory entries
6//! - `memory_search` — search memory entries by content or tags
7
8use async_trait::async_trait;
9use std::sync::Arc;
10
11use chrono::Utc;
12use oxi_sdk::{AgentTool, AgentToolResult, ToolContext};
13use serde_json::{Value, json};
14
15use crate::memory::{MemoryEntry, MemoryManager, MemoryType};
16
17/// Tool for writing memory entries that persist across sessions.
18pub struct MemoryWriteTool {
19    memory_manager: Arc<MemoryManager>,
20}
21
22impl MemoryWriteTool {
23    /// Create a new MemoryWriteTool.
24    pub fn new(memory_manager: Arc<MemoryManager>) -> Self {
25        Self { memory_manager }
26    }
27
28    /// Create a `MemoryWriteTool` from a [`KernelHandle`].
29    ///
30    /// Extracts the memory manager from the kernel's agent facade.
31    pub fn from_kernel(kernel: &crate::kernel_handle::KernelHandle) -> Self {
32        Self::new(kernel.agents.memory_manager().clone())
33    }
34}
35
36impl std::fmt::Debug for MemoryWriteTool {
37    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
38        f.debug_struct("MemoryWriteTool").finish()
39    }
40}
41
42#[async_trait]
43
44impl AgentTool for MemoryWriteTool {
45    fn name(&self) -> &str {
46        "memory_write"
47    }
48
49    fn label(&self) -> &str {
50        "Memory Write"
51    }
52
53    fn description(&self) -> &str {
54        "Store a recallable agent memory — facts about the user, behavioral patterns, session observations, preference corrections. Internal to the agent. Persisted across sessions via SQLite + HNSW vector index."
55    }
56
57    fn parameters_schema(&self) -> Value {
58        json!({
59            "type": "object",
60            "properties": {
61                "content": {
62                    "type": "string",
63                    "description": "The memory content to store"
64                },
65                "memory_type": {
66                    "type": "string",
67                    "enum": ["fact", "episode"],
68                    "description": "The type of memory entry"
69                },
70                "tags": {
71                    "type": "array",
72                    "items": { "type": "string" },
73                    "description": "Optional tags for categorization"
74                },
75                "importance": {
76                    "type": "number",
77                    "description": "Importance score 0.0-1.0 (default 0.5)"
78                }
79            },
80            "required": ["content", "memory_type"]
81        })
82    }
83
84    async fn execute(
85        &self,
86        _tool_call_id: &str,
87        params: Value,
88        _signal: Option<tokio::sync::oneshot::Receiver<()>>,
89        _ctx: &ToolContext,
90    ) -> Result<AgentToolResult, oxi_sdk::ToolError> {
91        let content = params["content"].as_str().unwrap_or("").to_string();
92        if content.is_empty() {
93            return Ok(AgentToolResult::error("content is required"));
94        }
95
96        let memory_type_str = params["memory_type"].as_str().unwrap_or("fact");
97        let memory_type = match memory_type_str {
98            "fact" => MemoryType::Fact,
99            "episode" => MemoryType::Episode,
100            "knowledge" => MemoryType::Knowledge,
101            _ => {
102                return Ok(AgentToolResult::error(format!(
103                    "Invalid memory_type '{memory_type_str}'. Must be one of: fact, episode, knowledge"
104                )));
105            }
106        };
107
108        let tags: Vec<String> = params["tags"]
109            .as_array()
110            .map(|arr| {
111                arr.iter()
112                    .filter_map(|v| v.as_str().map(String::from))
113                    .collect()
114            })
115            .unwrap_or_default();
116
117        let importance = params["importance"].as_f64().unwrap_or(0.5) as f32;
118
119        let now = Utc::now();
120        let entry = MemoryEntry {
121            id: uuid::Uuid::new_v4().to_string(),
122            memory_type,
123            tier: memory_type.initial_tier(),
124            content: content.clone(),
125            content_hash: crate::memory::content_hash(&content),
126            source: "agent".to_string(),
127            session_id: None,
128            tags: tags.clone(),
129            importance: importance.clamp(0.0, 1.0),
130            pinned: false,
131            protection: crate::memory::ProtectionLevel::None,
132            auto_classified: false,
133            session_appearances: 0,
134            user_corrected: false,
135            seen_in_sessions: vec![],
136            created_at: now,
137            accessed_at: now,
138            modified_at: now,
139            access_count: 0,
140            decay_score: 1.0,
141            compaction_level: 0,
142            compacted_from: vec![],
143            related_ids: vec![],
144            contradicts: None,
145        };
146        let entry_id = entry.id.clone();
147
148        match self.memory_manager.remember(entry).await {
149            Ok(_) => Ok(AgentToolResult::success(format!(
150                "Memory entry saved (id: {entry_id}, type: {memory_type_str})",
151            ))),
152            Err(e) => Ok(AgentToolResult::error(format!(
153                "Failed to write memory: {e}"
154            ))),
155        }
156    }
157}
158
159/// Tool for reading memory entries.
160pub struct MemoryReadTool {
161    memory_manager: Arc<MemoryManager>,
162}
163
164impl MemoryReadTool {
165    /// Create a new MemoryReadTool.
166    pub fn new(memory_manager: Arc<MemoryManager>) -> Self {
167        Self { memory_manager }
168    }
169
170    /// Create a `MemoryReadTool` from a [`KernelHandle`].
171    ///
172    /// Extracts the memory manager from the kernel's agent facade.
173    pub fn from_kernel(kernel: &crate::kernel_handle::KernelHandle) -> Self {
174        Self::new(kernel.agents.memory_manager().clone())
175    }
176}
177
178impl std::fmt::Debug for MemoryReadTool {
179    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
180        f.debug_struct("MemoryReadTool").finish()
181    }
182}
183
184#[async_trait]
185
186impl AgentTool for MemoryReadTool {
187    fn name(&self) -> &str {
188        "memory_read"
189    }
190
191    fn label(&self) -> &str {
192        "Memory Read"
193    }
194
195    fn description(&self) -> &str {
196        "Read memory entries. Provide 'id' and 'memory_type' to read a specific entry, or just 'memory_type' to list entries of that type."
197    }
198
199    fn parameters_schema(&self) -> Value {
200        json!({
201            "type": "object",
202            "properties": {
203                "id": {
204                    "type": "string",
205                    "description": "Optional specific memory entry ID to read."
206                },
207                "memory_type": {
208                    "type": "string",
209                    "enum": ["fact", "episode", "knowledge"],
210                    "description": "Type of memory to list (required if no id provided)"
211                },
212                "limit": {
213                    "type": "integer",
214                    "description": "Max entries to return when listing (default 10)"
215                }
216            }
217        })
218    }
219
220    async fn execute(
221        &self,
222        _tool_call_id: &str,
223        params: Value,
224        _signal: Option<tokio::sync::oneshot::Receiver<()>>,
225        _ctx: &ToolContext,
226    ) -> Result<AgentToolResult, oxi_sdk::ToolError> {
227        let limit = params["limit"].as_u64().unwrap_or(10) as usize;
228
229        if let Some(id) = params["id"].as_str() {
230            // Need memory_type to look up by ID
231            let memory_type_str = params["memory_type"].as_str().unwrap_or("fact");
232            let memory_type = parse_memory_type(memory_type_str);
233
234            match self.memory_manager.get(id, memory_type).await {
235                Ok(Some(entry)) => {
236                    let output = format!(
237                        "ID: {}\nType: {}\nSource: {}\nTags: {}\nImportance: {:.2}\nCreated: {}\n\n{}",
238                        entry.id,
239                        entry.memory_type.label(),
240                        entry.source,
241                        entry.tags.join(", "),
242                        entry.importance,
243                        entry.created_at,
244                        entry.content,
245                    );
246                    Ok(AgentToolResult::success(&output))
247                }
248                Ok(None) => Ok(AgentToolResult::error(format!(
249                    "Memory entry '{id}' not found"
250                ))),
251                Err(e) => Ok(AgentToolResult::error(format!(
252                    "Failed to read memory: {e}"
253                ))),
254            }
255        } else {
256            // List entries by type
257            let memory_type_str = params["memory_type"].as_str().unwrap_or("fact");
258            let memory_type = parse_memory_type(memory_type_str);
259
260            match self.memory_manager.list(memory_type, limit).await {
261                Ok(entries) => {
262                    if entries.is_empty() {
263                        return Ok(AgentToolResult::success(format!(
264                            "No {memory_type_str} memory entries found.",
265                        )));
266                    }
267                    let mut output =
268                        format!("Found {} {} entries:\n\n", entries.len(), memory_type_str,);
269                    for entry in &entries {
270                        let preview = truncate_str(&entry.content, 100);
271                        output.push_str(&format!(
272                            "- [{}] {} (id: {}…, tags: {})\n",
273                            entry.memory_type.label(),
274                            preview,
275                            &entry.id[..8.min(entry.id.len())],
276                            entry.tags.join(", "),
277                        ));
278                    }
279                    Ok(AgentToolResult::success(&output))
280                }
281                Err(e) => Ok(AgentToolResult::error(format!(
282                    "Failed to list memory: {e}"
283                ))),
284            }
285        }
286    }
287}
288
289/// Tool for searching memory entries by content or tags.
290pub struct MemorySearchTool {
291    memory_manager: Arc<MemoryManager>,
292}
293
294impl MemorySearchTool {
295    /// Create a new MemorySearchTool.
296    pub fn new(memory_manager: Arc<MemoryManager>) -> Self {
297        Self { memory_manager }
298    }
299
300    /// Create a `MemorySearchTool` from a [`KernelHandle`].
301    ///
302    /// Extracts the memory manager from the kernel's agent facade.
303    pub fn from_kernel(kernel: &crate::kernel_handle::KernelHandle) -> Self {
304        Self::new(kernel.agents.memory_manager().clone())
305    }
306}
307
308impl std::fmt::Debug for MemorySearchTool {
309    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
310        f.debug_struct("MemorySearchTool").finish()
311    }
312}
313
314#[async_trait]
315
316impl AgentTool for MemorySearchTool {
317    fn name(&self) -> &str {
318        "memory_search"
319    }
320
321    fn label(&self) -> &str {
322        "Memory Search"
323    }
324
325    fn description(&self) -> &str {
326        "Search memory entries by keyword query. Optionally filter by memory type."
327    }
328
329    fn parameters_schema(&self) -> Value {
330        json!({
331            "type": "object",
332            "properties": {
333                "query": {
334                    "type": "string",
335                    "description": "Text to search for in memory content"
336                },
337                "memory_type": {
338                    "type": "string",
339                    "enum": ["fact", "episode", "knowledge", "conversation", "session"],
340                    "description": "Optional memory type to filter by"
341                },
342                "limit": {
343                    "type": "integer",
344                    "description": "Max results (default 10)"
345                }
346            },
347            "required": ["query"]
348        })
349    }
350
351    async fn execute(
352        &self,
353        _tool_call_id: &str,
354        params: Value,
355        _signal: Option<tokio::sync::oneshot::Receiver<()>>,
356        _ctx: &ToolContext,
357    ) -> Result<AgentToolResult, oxi_sdk::ToolError> {
358        let query = params["query"].as_str().unwrap_or("");
359        if query.is_empty() {
360            return Ok(AgentToolResult::error("query is required"));
361        }
362
363        let limit = params["limit"].as_u64().unwrap_or(10) as usize;
364
365        let memory_type = params["memory_type"].as_str().map(parse_memory_type);
366
367        match self.memory_manager.search(query, memory_type, limit).await {
368            Ok(entries) => {
369                if entries.is_empty() {
370                    return Ok(AgentToolResult::success(
371                        "No matching memory entries found.",
372                    ));
373                }
374                let mut output = format!("Found {} matching entries:\n\n", entries.len());
375                for entry in &entries {
376                    let preview = truncate_str(&entry.content, 100);
377                    output.push_str(&format!(
378                        "- [{}] {} (id: {}…, importance: {:.2}, tags: {})\n",
379                        entry.memory_type.label(),
380                        preview,
381                        &entry.id[..8.min(entry.id.len())],
382                        entry.importance,
383                        entry.tags.join(", "),
384                    ));
385                }
386                Ok(AgentToolResult::success(&output))
387            }
388            Err(e) => Ok(AgentToolResult::error(format!(
389                "Failed to search memory: {e}"
390            ))),
391        }
392    }
393}
394
395/// Parse a memory type string, defaulting to Fact.
396fn parse_memory_type(s: &str) -> MemoryType {
397    match s {
398        "conversation" => MemoryType::Conversation,
399        "session" => MemoryType::Session,
400        "fact" => MemoryType::Fact,
401        "episode" => MemoryType::Episode,
402        "knowledge" => MemoryType::Knowledge,
403        "skill" => MemoryType::Skill,
404        "preference" => MemoryType::Preference,
405        "decision" => MemoryType::Decision,
406        "user_profile" | "profile" => MemoryType::UserProfile,
407        _ => MemoryType::Fact,
408    }
409}
410
/// Truncate `s` to at most `max_bytes` **bytes** (not characters).
///
/// If the limit falls inside a multi-byte UTF-8 character, the cut backs off
/// to the preceding character boundary. The result is therefore always a
/// valid prefix of `s`, possibly shorter than `max_bytes`. Returns `s`
/// unchanged when it already fits.
fn truncate_str(s: &str, max_bytes: usize) -> &str {
    if s.len() <= max_bytes {
        return s;
    }
    // Every index here is < s.len(), and index 0 is always a char boundary,
    // so the search always finds an end point.
    let end = (0..=max_bytes)
        .rev()
        .find(|&i| s.is_char_boundary(i))
        .unwrap_or(0);
    &s[..end]
}
423
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a `MemoryManager` backed by a state store in a fresh, uniquely
    /// named directory under the system temp dir.
    fn make_test_mm() -> std::sync::Arc<crate::memory::MemoryManager> {
        let store_dir =
            std::env::temp_dir().join(format!("test-memory-{}", uuid::Uuid::new_v4()));
        let store = crate::state_store::StateStore::new(store_dir).expect("test state store");
        let manager = crate::memory::MemoryManager::new(std::sync::Arc::new(store));
        std::sync::Arc::new(manager)
    }

    #[test]
    fn test_truncate_str_ascii() {
        let cases: [(&str, usize, &str); 3] =
            [("hello world", 5, "hello"), ("hello", 10, "hello"), ("", 5, "")];
        for &(input, limit, expected) in &cases {
            assert_eq!(truncate_str(input, limit), expected);
        }
    }

    #[test]
    fn test_truncate_str_utf8_korean() {
        // Each Korean syllable below is 3 bytes in UTF-8; 15 bytes in total.
        let korean = "안녕하세요";
        // A limit on a boundary keeps exactly that prefix (6 bytes = 2 syllables).
        assert_eq!(truncate_str(korean, 6), "안녕");
        // A limit inside a syllable backs off to the previous boundary (7 -> 6).
        assert_eq!(truncate_str(korean, 7), "안녕");
        // A limit equal to the byte length returns the whole string.
        assert_eq!(truncate_str(korean, 15), "안녕하세요");
    }

    #[test]
    fn test_truncate_str_mixed() {
        // "Hi " is 3 ASCII bytes followed by two 3-byte syllables: 9 bytes.
        let mixed = "Hi 안녕";
        // 3 lands exactly after the space; 4 lands inside '안' and backs off to 3.
        for limit in [3, 4] {
            assert_eq!(truncate_str(mixed, limit), "Hi ");
        }
    }

    #[test]
    fn test_parse_memory_type() {
        let cases = [
            ("fact", MemoryType::Fact),
            ("episode", MemoryType::Episode),
            ("knowledge", MemoryType::Knowledge),
            ("conversation", MemoryType::Conversation),
            ("session", MemoryType::Session),
            // Unrecognised labels fall back to Fact.
            ("unknown", MemoryType::Fact),
        ];
        for (label, expected) in &cases {
            assert_eq!(
                std::mem::discriminant(&parse_memory_type(label)),
                std::mem::discriminant(expected),
                "unexpected memory type for label {:?}",
                label
            );
        }
    }

    #[test]
    fn test_memory_write_tool_schema() {
        let tool = MemoryWriteTool::new(make_test_mm());
        assert_eq!(tool.name(), "memory_write");
        assert!(tool.parameters_schema()["required"].is_array());
    }

    #[test]
    fn test_memory_read_tool_schema() {
        let tool = MemoryReadTool::new(make_test_mm());
        assert_eq!(tool.name(), "memory_read");
    }

    #[test]
    fn test_memory_search_tool_schema() {
        let tool = MemorySearchTool::new(make_test_mm());
        assert_eq!(tool.name(), "memory_search");
        assert!(tool.parameters_schema()["required"].is_array());
    }
}