1use std::sync::Arc;
9
10use async_trait::async_trait;
11use chrono::Utc;
12use oxi_sdk::{AgentTool, AgentToolResult, ToolContext};
13use serde_json::{json, Value};
14
15use crate::memory::{MemoryEntry, MemoryManager, MemoryType};
16
17pub struct MemoryWriteTool {
19 memory_manager: Arc<MemoryManager>,
20}
21
22impl MemoryWriteTool {
23 pub fn new(memory_manager: Arc<MemoryManager>) -> Self {
25 Self { memory_manager }
26 }
27
28 pub fn from_kernel(kernel: &crate::kernel_handle::KernelHandle) -> Self {
32 Self::new(kernel.agents.memory_manager().clone())
33 }
34}
35
36impl std::fmt::Debug for MemoryWriteTool {
37 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
38 f.debug_struct("MemoryWriteTool").finish()
39 }
40}
41
42#[async_trait]
43impl AgentTool for MemoryWriteTool {
44 fn name(&self) -> &str {
45 "memory_write"
46 }
47
48 fn label(&self) -> &str {
49 "Memory Write"
50 }
51
52 fn description(&self) -> &str {
53 "Write a memory entry that persists across sessions. Use this to save important facts, episodes, or knowledge for future reference."
54 }
55
56 fn parameters_schema(&self) -> Value {
57 json!({
58 "type": "object",
59 "properties": {
60 "content": {
61 "type": "string",
62 "description": "The memory content to store"
63 },
64 "memory_type": {
65 "type": "string",
66 "enum": ["fact", "episode", "knowledge"],
67 "description": "The type of memory entry"
68 },
69 "tags": {
70 "type": "array",
71 "items": { "type": "string" },
72 "description": "Optional tags for categorization"
73 },
74 "importance": {
75 "type": "number",
76 "description": "Importance score 0.0-1.0 (default 0.5)"
77 }
78 },
79 "required": ["content", "memory_type"]
80 })
81 }
82
83 async fn execute(
84 &self,
85 _tool_call_id: &str,
86 params: Value,
87 _signal: Option<tokio::sync::oneshot::Receiver<()>>,
88 _ctx: &ToolContext,
89 ) -> Result<AgentToolResult, oxi_sdk::ToolError> {
90 let content = params["content"].as_str().unwrap_or("").to_string();
91 if content.is_empty() {
92 return Ok(AgentToolResult::error("content is required"));
93 }
94
95 let memory_type_str = params["memory_type"].as_str().unwrap_or("fact");
96 let memory_type = match memory_type_str {
97 "fact" => MemoryType::Fact,
98 "episode" => MemoryType::Episode,
99 "knowledge" => MemoryType::Knowledge,
100 _ => {
101 return Ok(AgentToolResult::error(format!(
102 "Invalid memory_type '{memory_type_str}'. Must be one of: fact, episode, knowledge"
103 )))
104 }
105 };
106
107 let tags: Vec<String> = params["tags"]
108 .as_array()
109 .map(|arr| {
110 arr.iter()
111 .filter_map(|v| v.as_str().map(String::from))
112 .collect()
113 })
114 .unwrap_or_default();
115
116 let importance = params["importance"].as_f64().unwrap_or(0.5) as f32;
117
118 let now = Utc::now();
119 let entry = MemoryEntry {
120 id: uuid::Uuid::new_v4().to_string(),
121 memory_type,
122 tier: memory_type.initial_tier(),
123 content: content.clone(),
124 content_hash: crate::memory::content_hash(&content),
125 source: "agent".to_string(),
126 session_id: None,
127 tags: tags.clone(),
128 importance: importance.clamp(0.0, 1.0),
129 pinned: false,
130 protection: crate::memory::ProtectionLevel::None,
131 auto_classified: false,
132 session_appearances: 0,
133 user_corrected: false,
134 seen_in_sessions: vec![],
135 created_at: now,
136 accessed_at: now,
137 modified_at: now,
138 access_count: 0,
139 decay_score: 1.0,
140 compaction_level: 0,
141 compacted_from: vec![],
142 related_ids: vec![],
143 contradicts: None,
144 };
145 let entry_id = entry.id.clone();
146
147 match self.memory_manager.remember(entry).await {
148 Ok(_) => Ok(AgentToolResult::success(format!(
149 "Memory entry saved (id: {entry_id}, type: {memory_type_str})",
150 ))),
151 Err(e) => Ok(AgentToolResult::error(format!(
152 "Failed to write memory: {e}"
153 ))),
154 }
155 }
156}
157
158pub struct MemoryReadTool {
160 memory_manager: Arc<MemoryManager>,
161}
162
163impl MemoryReadTool {
164 pub fn new(memory_manager: Arc<MemoryManager>) -> Self {
166 Self { memory_manager }
167 }
168
169 pub fn from_kernel(kernel: &crate::kernel_handle::KernelHandle) -> Self {
173 Self::new(kernel.agents.memory_manager().clone())
174 }
175}
176
177impl std::fmt::Debug for MemoryReadTool {
178 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
179 f.debug_struct("MemoryReadTool").finish()
180 }
181}
182
183#[async_trait]
184impl AgentTool for MemoryReadTool {
185 fn name(&self) -> &str {
186 "memory_read"
187 }
188
189 fn label(&self) -> &str {
190 "Memory Read"
191 }
192
193 fn description(&self) -> &str {
194 "Read memory entries. Provide 'id' and 'memory_type' to read a specific entry, or just 'memory_type' to list entries of that type."
195 }
196
197 fn parameters_schema(&self) -> Value {
198 json!({
199 "type": "object",
200 "properties": {
201 "id": {
202 "type": "string",
203 "description": "Optional specific memory entry ID to read."
204 },
205 "memory_type": {
206 "type": "string",
207 "enum": ["fact", "episode", "knowledge"],
208 "description": "Type of memory to list (required if no id provided)"
209 },
210 "limit": {
211 "type": "integer",
212 "description": "Max entries to return when listing (default 10)"
213 }
214 }
215 })
216 }
217
218 async fn execute(
219 &self,
220 _tool_call_id: &str,
221 params: Value,
222 _signal: Option<tokio::sync::oneshot::Receiver<()>>,
223 _ctx: &ToolContext,
224 ) -> Result<AgentToolResult, oxi_sdk::ToolError> {
225 let limit = params["limit"].as_u64().unwrap_or(10) as usize;
226
227 if let Some(id) = params["id"].as_str() {
228 let memory_type_str = params["memory_type"].as_str().unwrap_or("fact");
230 let memory_type = parse_memory_type(memory_type_str);
231
232 match self.memory_manager.get(id, memory_type).await {
233 Ok(Some(entry)) => {
234 let output = format!(
235 "ID: {}\nType: {}\nSource: {}\nTags: {}\nImportance: {:.2}\nCreated: {}\n\n{}",
236 entry.id,
237 entry.memory_type.label(),
238 entry.source,
239 entry.tags.join(", "),
240 entry.importance,
241 entry.created_at,
242 entry.content,
243 );
244 Ok(AgentToolResult::success(&output))
245 }
246 Ok(None) => Ok(AgentToolResult::error(format!(
247 "Memory entry '{id}' not found"
248 ))),
249 Err(e) => Ok(AgentToolResult::error(format!(
250 "Failed to read memory: {e}"
251 ))),
252 }
253 } else {
254 let memory_type_str = params["memory_type"].as_str().unwrap_or("fact");
256 let memory_type = parse_memory_type(memory_type_str);
257
258 match self.memory_manager.list(memory_type, limit).await {
259 Ok(entries) => {
260 if entries.is_empty() {
261 return Ok(AgentToolResult::success(format!(
262 "No {memory_type_str} memory entries found.",
263 )));
264 }
265 let mut output =
266 format!("Found {} {} entries:\n\n", entries.len(), memory_type_str,);
267 for entry in &entries {
268 let preview = truncate_str(&entry.content, 100);
269 output.push_str(&format!(
270 "- [{}] {} (id: {}…, tags: {})\n",
271 entry.memory_type.label(),
272 preview,
273 &entry.id[..8.min(entry.id.len())],
274 entry.tags.join(", "),
275 ));
276 }
277 Ok(AgentToolResult::success(&output))
278 }
279 Err(e) => Ok(AgentToolResult::error(format!(
280 "Failed to list memory: {e}"
281 ))),
282 }
283 }
284 }
285}
286
287pub struct MemorySearchTool {
289 memory_manager: Arc<MemoryManager>,
290}
291
292impl MemorySearchTool {
293 pub fn new(memory_manager: Arc<MemoryManager>) -> Self {
295 Self { memory_manager }
296 }
297
298 pub fn from_kernel(kernel: &crate::kernel_handle::KernelHandle) -> Self {
302 Self::new(kernel.agents.memory_manager().clone())
303 }
304}
305
306impl std::fmt::Debug for MemorySearchTool {
307 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
308 f.debug_struct("MemorySearchTool").finish()
309 }
310}
311
312#[async_trait]
313impl AgentTool for MemorySearchTool {
314 fn name(&self) -> &str {
315 "memory_search"
316 }
317
318 fn label(&self) -> &str {
319 "Memory Search"
320 }
321
322 fn description(&self) -> &str {
323 "Search memory entries by keyword query. Optionally filter by memory type."
324 }
325
326 fn parameters_schema(&self) -> Value {
327 json!({
328 "type": "object",
329 "properties": {
330 "query": {
331 "type": "string",
332 "description": "Text to search for in memory content"
333 },
334 "memory_type": {
335 "type": "string",
336 "enum": ["fact", "episode", "knowledge", "conversation", "session"],
337 "description": "Optional memory type to filter by"
338 },
339 "limit": {
340 "type": "integer",
341 "description": "Max results (default 10)"
342 }
343 },
344 "required": ["query"]
345 })
346 }
347
348 async fn execute(
349 &self,
350 _tool_call_id: &str,
351 params: Value,
352 _signal: Option<tokio::sync::oneshot::Receiver<()>>,
353 _ctx: &ToolContext,
354 ) -> Result<AgentToolResult, oxi_sdk::ToolError> {
355 let query = params["query"].as_str().unwrap_or("");
356 if query.is_empty() {
357 return Ok(AgentToolResult::error("query is required"));
358 }
359
360 let limit = params["limit"].as_u64().unwrap_or(10) as usize;
361
362 let memory_type = params["memory_type"].as_str().map(parse_memory_type);
363
364 match self.memory_manager.search(query, memory_type, limit).await {
365 Ok(entries) => {
366 if entries.is_empty() {
367 return Ok(AgentToolResult::success(
368 "No matching memory entries found.",
369 ));
370 }
371 let mut output = format!("Found {} matching entries:\n\n", entries.len());
372 for entry in &entries {
373 let preview = truncate_str(&entry.content, 100);
374 output.push_str(&format!(
375 "- [{}] {} (id: {}…, importance: {:.2}, tags: {})\n",
376 entry.memory_type.label(),
377 preview,
378 &entry.id[..8.min(entry.id.len())],
379 entry.importance,
380 entry.tags.join(", "),
381 ));
382 }
383 Ok(AgentToolResult::success(&output))
384 }
385 Err(e) => Ok(AgentToolResult::error(format!(
386 "Failed to search memory: {e}"
387 ))),
388 }
389 }
390}
391
392fn parse_memory_type(s: &str) -> MemoryType {
394 match s {
395 "conversation" => MemoryType::Conversation,
396 "session" => MemoryType::Session,
397 "fact" => MemoryType::Fact,
398 "episode" => MemoryType::Episode,
399 "knowledge" => MemoryType::Knowledge,
400 "skill" => MemoryType::Skill,
401 "preference" => MemoryType::Preference,
402 "decision" => MemoryType::Decision,
403 "user_profile" | "profile" => MemoryType::UserProfile,
404 _ => MemoryType::Fact,
405 }
406}
407
/// Returns the longest prefix of `s` that is at most `max_chars` bytes long and
/// still ends on a UTF-8 character boundary.
///
/// Despite the parameter name, the limit is a byte budget (`str::len`), not a
/// character count. A multi-byte character that would straddle the limit is
/// dropped entirely, so slicing never panics. Strings already within the budget
/// are returned unchanged.
fn truncate_str(s: &str, max_chars: usize) -> &str {
    if s.len() > max_chars {
        // Walk backwards from the limit to the nearest boundary; index 0 is
        // always a boundary, so the search cannot come up empty.
        let cut = (0..=max_chars)
            .rev()
            .find(|&idx| s.is_char_boundary(idx))
            .unwrap_or(0);
        &s[..cut]
    } else {
        s
    }
}
420
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_truncate_str_ascii() {
        assert_eq!(truncate_str("hello world", 5), "hello");
        assert_eq!(truncate_str("hello", 10), "hello");
        assert_eq!(truncate_str("", 5), "");
    }

    #[test]
    fn test_truncate_str_utf8_korean() {
        // Each Hangul syllable is 3 bytes in UTF-8; the limit is a byte budget.
        let korean = "안녕하세요";
        assert_eq!(truncate_str(korean, 6), "안녕");
        // A limit that lands inside a character backs off to the previous boundary.
        assert_eq!(truncate_str(korean, 7), "안녕");
        assert_eq!(truncate_str(korean, 15), "안녕하세요");
    }

    #[test]
    fn test_truncate_str_mixed() {
        let mixed = "Hi 안녕";
        assert_eq!(truncate_str(mixed, 3), "Hi ");
        assert_eq!(truncate_str(mixed, 4), "Hi ");
    }

    #[test]
    fn test_truncate_str_zero_and_exact_limits() {
        assert_eq!(truncate_str("abc", 0), "");
        assert_eq!(truncate_str("abc", 3), "abc");
        // Budget smaller than the first character yields an empty prefix.
        assert_eq!(truncate_str("안", 2), "");
    }

    #[test]
    fn test_parse_memory_type() {
        assert!(matches!(parse_memory_type("fact"), MemoryType::Fact));
        assert!(matches!(parse_memory_type("episode"), MemoryType::Episode));
        assert!(matches!(
            parse_memory_type("knowledge"),
            MemoryType::Knowledge
        ));
        assert!(matches!(
            parse_memory_type("conversation"),
            MemoryType::Conversation
        ));
        assert!(matches!(parse_memory_type("session"), MemoryType::Session));
        assert!(matches!(parse_memory_type("unknown"), MemoryType::Fact));
    }

    #[test]
    fn test_parse_memory_type_aliases_and_fallback() {
        assert!(matches!(parse_memory_type("skill"), MemoryType::Skill));
        assert!(matches!(
            parse_memory_type("preference"),
            MemoryType::Preference
        ));
        assert!(matches!(parse_memory_type("decision"), MemoryType::Decision));
        assert!(matches!(
            parse_memory_type("user_profile"),
            MemoryType::UserProfile
        ));
        assert!(matches!(
            parse_memory_type("profile"),
            MemoryType::UserProfile
        ));
        // Matching is case-sensitive; unrecognised names fall back to Fact.
        assert!(matches!(parse_memory_type("Episode"), MemoryType::Fact));
        assert!(matches!(parse_memory_type(""), MemoryType::Fact));
    }

    fn make_test_mm() -> std::sync::Arc<crate::memory::MemoryManager> {
        let dir = std::env::temp_dir().join(format!("test-memory-{}", uuid::Uuid::new_v4()));
        let state_store = std::sync::Arc::new(
            crate::state_store::StateStore::new(dir).expect("test state store"),
        );
        std::sync::Arc::new(crate::memory::MemoryManager::new(state_store))
    }

    #[test]
    fn test_memory_write_tool_schema() {
        let mm = make_test_mm();
        let tool = MemoryWriteTool::new(mm);
        assert_eq!(tool.name(), "memory_write");
        let schema = tool.parameters_schema();
        assert_eq!(schema["required"], json!(["content", "memory_type"]));
    }

    #[test]
    fn test_memory_read_tool_schema() {
        let mm = make_test_mm();
        let tool = MemoryReadTool::new(mm);
        assert_eq!(tool.name(), "memory_read");
        // Every argument of the read tool is optional.
        let schema = tool.parameters_schema();
        assert!(schema.get("required").is_none());
    }

    #[test]
    fn test_memory_search_tool_schema() {
        let mm = make_test_mm();
        let tool = MemorySearchTool::new(mm);
        assert_eq!(tool.name(), "memory_search");
        let schema = tool.parameters_schema();
        assert_eq!(schema["required"], json!(["query"]));
    }
}