1use std::sync::Arc;
9
10use async_trait::async_trait;
11use chrono::Utc;
12use oxi_sdk::{AgentTool, AgentToolResult, ToolContext};
13use serde_json::{json, Value};
14
15use crate::memory::{MemoryEntry, MemoryManager, MemoryType};
16
17pub struct MemoryWriteTool {
19 memory_manager: Arc<MemoryManager>,
20}
21
22impl MemoryWriteTool {
23 pub fn new(memory_manager: Arc<MemoryManager>) -> Self {
25 Self { memory_manager }
26 }
27
28 pub fn from_kernel(kernel: &crate::kernel_handle::KernelHandle) -> Self {
32 Self::new(kernel.agents.memory_manager().clone())
33 }
34}
35
36impl std::fmt::Debug for MemoryWriteTool {
37 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
38 f.debug_struct("MemoryWriteTool").finish()
39 }
40}
41
42#[async_trait]
43impl AgentTool for MemoryWriteTool {
44 fn name(&self) -> &str {
45 "memory_write"
46 }
47
48 fn label(&self) -> &str {
49 "Memory Write"
50 }
51
52 fn description(&self) -> &str {
53 "Write a memory entry that persists across sessions. Use this to save important facts, episodes, or knowledge for future reference."
54 }
55
56 fn parameters_schema(&self) -> Value {
57 json!({
58 "type": "object",
59 "properties": {
60 "content": {
61 "type": "string",
62 "description": "The memory content to store"
63 },
64 "memory_type": {
65 "type": "string",
66 "enum": ["fact", "episode", "knowledge"],
67 "description": "The type of memory entry"
68 },
69 "tags": {
70 "type": "array",
71 "items": { "type": "string" },
72 "description": "Optional tags for categorization"
73 },
74 "importance": {
75 "type": "number",
76 "description": "Importance score 0.0-1.0 (default 0.5)"
77 }
78 },
79 "required": ["content", "memory_type"]
80 })
81 }
82
83 async fn execute(
84 &self,
85 _tool_call_id: &str,
86 params: Value,
87 _signal: Option<tokio::sync::oneshot::Receiver<()>>,
88 _ctx: &ToolContext,
89 ) -> Result<AgentToolResult, oxi_sdk::ToolError> {
90 let content = params["content"].as_str().unwrap_or("").to_string();
91 if content.is_empty() {
92 return Ok(AgentToolResult::error("content is required"));
93 }
94
95 let memory_type_str = params["memory_type"].as_str().unwrap_or("fact");
96 let memory_type = match memory_type_str {
97 "fact" => MemoryType::Fact,
98 "episode" => MemoryType::Episode,
99 "knowledge" => MemoryType::Knowledge,
100 _ => {
101 return Ok(AgentToolResult::error(format!(
102 "Invalid memory_type '{}'. Must be one of: fact, episode, knowledge",
103 memory_type_str
104 )))
105 }
106 };
107
108 let tags: Vec<String> = params["tags"]
109 .as_array()
110 .map(|arr| {
111 arr.iter()
112 .filter_map(|v| v.as_str().map(String::from))
113 .collect()
114 })
115 .unwrap_or_default();
116
117 let importance = params["importance"].as_f64().unwrap_or(0.5) as f32;
118
119 let now = Utc::now();
120 let entry = MemoryEntry {
121 id: uuid::Uuid::new_v4().to_string(),
122 memory_type,
123 content,
124 source: "agent".to_string(),
125 session_id: None,
126 tags,
127 importance: importance.clamp(0.0, 1.0),
128 created_at: now,
129 accessed_at: now,
130 access_count: 0,
131 };
132 let entry_id = entry.id.clone();
133
134 match self.memory_manager.remember(entry).await {
135 Ok(_) => Ok(AgentToolResult::success(format!(
136 "Memory entry saved (id: {}, type: {})",
137 entry_id, memory_type_str,
138 ))),
139 Err(e) => Ok(AgentToolResult::error(format!(
140 "Failed to write memory: {e}"
141 ))),
142 }
143 }
144}
145
146pub struct MemoryReadTool {
148 memory_manager: Arc<MemoryManager>,
149}
150
151impl MemoryReadTool {
152 pub fn new(memory_manager: Arc<MemoryManager>) -> Self {
154 Self { memory_manager }
155 }
156
157 pub fn from_kernel(kernel: &crate::kernel_handle::KernelHandle) -> Self {
161 Self::new(kernel.agents.memory_manager().clone())
162 }
163}
164
165impl std::fmt::Debug for MemoryReadTool {
166 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
167 f.debug_struct("MemoryReadTool").finish()
168 }
169}
170
171#[async_trait]
172impl AgentTool for MemoryReadTool {
173 fn name(&self) -> &str {
174 "memory_read"
175 }
176
177 fn label(&self) -> &str {
178 "Memory Read"
179 }
180
181 fn description(&self) -> &str {
182 "Read memory entries. Provide 'id' and 'memory_type' to read a specific entry, or just 'memory_type' to list entries of that type."
183 }
184
185 fn parameters_schema(&self) -> Value {
186 json!({
187 "type": "object",
188 "properties": {
189 "id": {
190 "type": "string",
191 "description": "Optional specific memory entry ID to read."
192 },
193 "memory_type": {
194 "type": "string",
195 "enum": ["fact", "episode", "knowledge"],
196 "description": "Type of memory to list (required if no id provided)"
197 },
198 "limit": {
199 "type": "integer",
200 "description": "Max entries to return when listing (default 10)"
201 }
202 }
203 })
204 }
205
206 async fn execute(
207 &self,
208 _tool_call_id: &str,
209 params: Value,
210 _signal: Option<tokio::sync::oneshot::Receiver<()>>,
211 _ctx: &ToolContext,
212 ) -> Result<AgentToolResult, oxi_sdk::ToolError> {
213 let limit = params["limit"].as_u64().unwrap_or(10) as usize;
214
215 if let Some(id) = params["id"].as_str() {
216 let memory_type_str = params["memory_type"].as_str().unwrap_or("fact");
218 let memory_type = parse_memory_type(memory_type_str);
219
220 match self.memory_manager.get(id, memory_type).await {
221 Ok(Some(entry)) => {
222 let output = format!(
223 "ID: {}\nType: {}\nSource: {}\nTags: {}\nImportance: {:.2}\nCreated: {}\n\n{}",
224 entry.id,
225 entry.memory_type.label(),
226 entry.source,
227 entry.tags.join(", "),
228 entry.importance,
229 entry.created_at,
230 entry.content,
231 );
232 Ok(AgentToolResult::success(&output))
233 }
234 Ok(None) => Ok(AgentToolResult::error(format!(
235 "Memory entry '{}' not found",
236 id
237 ))),
238 Err(e) => Ok(AgentToolResult::error(format!(
239 "Failed to read memory: {e}"
240 ))),
241 }
242 } else {
243 let memory_type_str = params["memory_type"].as_str().unwrap_or("fact");
245 let memory_type = parse_memory_type(memory_type_str);
246
247 match self.memory_manager.list(memory_type, limit).await {
248 Ok(entries) => {
249 if entries.is_empty() {
250 return Ok(AgentToolResult::success(format!(
251 "No {} memory entries found.",
252 memory_type_str,
253 )));
254 }
255 let mut output =
256 format!("Found {} {} entries:\n\n", entries.len(), memory_type_str,);
257 for entry in &entries {
258 let preview = truncate_str(&entry.content, 100);
259 output.push_str(&format!(
260 "- [{}] {} (id: {}…, tags: {})\n",
261 entry.memory_type.label(),
262 preview,
263 &entry.id[..8.min(entry.id.len())],
264 entry.tags.join(", "),
265 ));
266 }
267 Ok(AgentToolResult::success(&output))
268 }
269 Err(e) => Ok(AgentToolResult::error(format!(
270 "Failed to list memory: {e}"
271 ))),
272 }
273 }
274 }
275}
276
277pub struct MemorySearchTool {
279 memory_manager: Arc<MemoryManager>,
280}
281
282impl MemorySearchTool {
283 pub fn new(memory_manager: Arc<MemoryManager>) -> Self {
285 Self { memory_manager }
286 }
287
288 pub fn from_kernel(kernel: &crate::kernel_handle::KernelHandle) -> Self {
292 Self::new(kernel.agents.memory_manager().clone())
293 }
294}
295
296impl std::fmt::Debug for MemorySearchTool {
297 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
298 f.debug_struct("MemorySearchTool").finish()
299 }
300}
301
302#[async_trait]
303impl AgentTool for MemorySearchTool {
304 fn name(&self) -> &str {
305 "memory_search"
306 }
307
308 fn label(&self) -> &str {
309 "Memory Search"
310 }
311
312 fn description(&self) -> &str {
313 "Search memory entries by keyword query. Optionally filter by memory type."
314 }
315
316 fn parameters_schema(&self) -> Value {
317 json!({
318 "type": "object",
319 "properties": {
320 "query": {
321 "type": "string",
322 "description": "Text to search for in memory content"
323 },
324 "memory_type": {
325 "type": "string",
326 "enum": ["fact", "episode", "knowledge", "conversation", "session"],
327 "description": "Optional memory type to filter by"
328 },
329 "limit": {
330 "type": "integer",
331 "description": "Max results (default 10)"
332 }
333 },
334 "required": ["query"]
335 })
336 }
337
338 async fn execute(
339 &self,
340 _tool_call_id: &str,
341 params: Value,
342 _signal: Option<tokio::sync::oneshot::Receiver<()>>,
343 _ctx: &ToolContext,
344 ) -> Result<AgentToolResult, oxi_sdk::ToolError> {
345 let query = params["query"].as_str().unwrap_or("");
346 if query.is_empty() {
347 return Ok(AgentToolResult::error("query is required"));
348 }
349
350 let limit = params["limit"].as_u64().unwrap_or(10) as usize;
351
352 let memory_type = params["memory_type"].as_str().map(parse_memory_type);
353
354 match self.memory_manager.search(query, memory_type, limit).await {
355 Ok(entries) => {
356 if entries.is_empty() {
357 return Ok(AgentToolResult::success(
358 "No matching memory entries found.",
359 ));
360 }
361 let mut output = format!("Found {} matching entries:\n\n", entries.len());
362 for entry in &entries {
363 let preview = truncate_str(&entry.content, 100);
364 output.push_str(&format!(
365 "- [{}] {} (id: {}…, importance: {:.2}, tags: {})\n",
366 entry.memory_type.label(),
367 preview,
368 &entry.id[..8.min(entry.id.len())],
369 entry.importance,
370 entry.tags.join(", "),
371 ));
372 }
373 Ok(AgentToolResult::success(&output))
374 }
375 Err(e) => Ok(AgentToolResult::error(format!(
376 "Failed to search memory: {e}"
377 ))),
378 }
379 }
380}
381
382fn parse_memory_type(s: &str) -> MemoryType {
384 match s {
385 "conversation" => MemoryType::Conversation,
386 "session" => MemoryType::Session,
387 "fact" => MemoryType::Fact,
388 "episode" => MemoryType::Episode,
389 "knowledge" => MemoryType::Knowledge,
390 _ => MemoryType::Fact,
391 }
392}
393
/// Returns the longest prefix of `s` that is at most `max_bytes` bytes long
/// and still ends on a UTF-8 character boundary.
///
/// The limit is measured in bytes, not characters: a 3-byte character that
/// would straddle the limit is dropped entirely, never split.
fn truncate_str(s: &str, max_bytes: usize) -> &str {
    if s.len() <= max_bytes {
        s
    } else {
        // Index 0 is always a char boundary, so the search cannot come up empty;
        // at most three candidates are rejected before a boundary is found.
        let cut = (0..=max_bytes)
            .rev()
            .find(|&idx| s.is_char_boundary(idx))
            .unwrap_or(0);
        &s[..cut]
    }
}
406
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_truncate_str_ascii() {
        assert_eq!(truncate_str("hello world", 5), "hello");
        assert_eq!(truncate_str("hello", 10), "hello");
        assert_eq!(truncate_str("", 5), "");
    }

    #[test]
    fn test_truncate_str_utf8_korean() {
        // Each Hangul syllable is 3 bytes in UTF-8; the limit is in bytes.
        let korean = "안녕하세요";
        assert_eq!(truncate_str(korean, 6), "안녕");
        assert_eq!(truncate_str(korean, 7), "안녕");
        assert_eq!(truncate_str(korean, 15), "안녕하세요");
    }

    #[test]
    fn test_truncate_str_mixed() {
        let mixed = "Hi 안녕";
        assert_eq!(truncate_str(mixed, 3), "Hi ");
        assert_eq!(truncate_str(mixed, 4), "Hi ");
    }

    #[test]
    fn test_truncate_str_zero_and_sub_char_limits() {
        assert_eq!(truncate_str("hello", 0), "");
        // A limit smaller than the first character yields an empty prefix.
        assert_eq!(truncate_str("안녕", 2), "");
    }

    #[test]
    fn test_parse_memory_type() {
        assert!(matches!(parse_memory_type("fact"), MemoryType::Fact));
        assert!(matches!(parse_memory_type("episode"), MemoryType::Episode));
        assert!(matches!(
            parse_memory_type("knowledge"),
            MemoryType::Knowledge
        ));
        assert!(matches!(
            parse_memory_type("conversation"),
            MemoryType::Conversation
        ));
        assert!(matches!(parse_memory_type("session"), MemoryType::Session));
        assert!(matches!(parse_memory_type("unknown"), MemoryType::Fact));
    }

    /// Builds a memory manager backed by a fresh, uniquely named state-store
    /// directory under the system temp dir.
    // NOTE(review): the directory is never removed, so each test run leaves
    // it behind in the temp dir.
    fn make_test_mm() -> std::sync::Arc<crate::memory::MemoryManager> {
        let dir = std::env::temp_dir().join(format!("test-memory-{}", uuid::Uuid::new_v4()));
        let state_store = std::sync::Arc::new(
            crate::state_store::StateStore::new(dir).expect("test state store"),
        );
        std::sync::Arc::new(crate::memory::MemoryManager::new(state_store))
    }

    #[test]
    fn test_memory_write_tool_schema() {
        let mm = make_test_mm();
        let tool = MemoryWriteTool::new(mm);
        assert_eq!(tool.name(), "memory_write");
        let schema = tool.parameters_schema();
        assert_eq!(schema["required"], json!(["content", "memory_type"]));
    }

    #[test]
    fn test_memory_read_tool_schema() {
        let mm = make_test_mm();
        let tool = MemoryReadTool::new(mm);
        assert_eq!(tool.name(), "memory_read");
        // Both `id` and `memory_type` are optional, so nothing is required.
        assert!(tool.parameters_schema()["required"].is_null());
    }

    #[test]
    fn test_memory_search_tool_schema() {
        let mm = make_test_mm();
        let tool = MemorySearchTool::new(mm);
        assert_eq!(tool.name(), "memory_search");
        let schema = tool.parameters_schema();
        assert_eq!(schema["required"], json!(["query"]));
    }
}