1use super::{Tool, ToolResult};
7use anyhow::Result;
8use async_trait::async_trait;
9use chrono::{DateTime, Utc};
10use serde::{Deserialize, Serialize};
11use serde_json::{Value, json};
12use std::collections::HashMap;
13use std::path::PathBuf;
14use tokio::fs;
15
/// A single persisted memory: free-form text plus metadata used for filtering
/// (`tags`, `scope`) and ranking (`importance`, `access_count`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryEntry {
    /// Unique identifier; `MemoryEntry::new` fills it with a random UUID v4 string.
    pub id: String,
    /// The remembered text.
    pub content: String,
    /// Category labels. Tag filters match them exactly; text queries match them
    /// as case-insensitive substrings.
    pub tags: Vec<String>,
    /// Creation time (UTC).
    pub created_at: DateTime<Utc>,
    /// Set to the creation time, then refreshed by `touch` on every read (UTC).
    pub accessed_at: DateTime<Utc>,
    /// Number of reads recorded via `touch`.
    pub access_count: u64,
    /// Optional project/context scope the memory belongs to.
    pub scope: Option<String>,
    /// Optional origin of the memory (`with_source` is currently unused).
    pub source: Option<String>,
    /// Ranking priority, higher first. `new` defaults it to 3 and
    /// `with_importance` caps it at 5.
    pub importance: u8,
}
38
39impl MemoryEntry {
40 pub fn new(content: impl Into<String>, tags: Vec<String>) -> Self {
41 let now = Utc::now();
42 Self {
43 id: uuid::Uuid::new_v4().to_string(),
44 content: content.into(),
45 tags,
46 created_at: now,
47 accessed_at: now,
48 access_count: 0,
49 scope: None,
50 source: None,
51 importance: 3, }
53 }
54
55 pub fn with_scope(mut self, scope: impl Into<String>) -> Self {
56 self.scope = Some(scope.into());
57 self
58 }
59
60 #[allow(dead_code)]
62 pub fn with_source(mut self, source: impl Into<String>) -> Self {
63 self.source = Some(source.into());
64 self
65 }
66
67 pub fn with_importance(mut self, importance: u8) -> Self {
68 self.importance = importance.min(5);
69 self
70 }
71
72 pub fn touch(&mut self) {
73 self.accessed_at = Utc::now();
74 self.access_count += 1;
75 }
76}
77
/// In-memory collection of [`MemoryEntry`] values, serialised to and from a
/// single JSON file by `MemoryStore::load` / `MemoryStore::save`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct MemoryStore {
    /// Entries keyed by their `MemoryEntry::id`.
    entries: HashMap<String, MemoryEntry>,
}
83
84impl MemoryStore {
85 pub fn default_path() -> std::path::PathBuf {
87 crate::config::Config::data_dir()
88 .map(|p| p.join("memory.json"))
89 .unwrap_or_else(|| PathBuf::from(".codetether-agent/memory.json"))
90 }
91
92 pub async fn load() -> Result<Self> {
94 let path = Self::default_path();
95 if !path.exists() {
96 return Ok(Self::default());
97 }
98 let content = fs::read_to_string(&path).await?;
99 let store: MemoryStore = serde_json::from_str(&content)?;
100 Ok(store)
101 }
102
103 pub async fn save(&self) -> Result<()> {
105 let path = Self::default_path();
106 if let Some(parent) = path.parent() {
107 fs::create_dir_all(parent).await?;
108 }
109 let content = serde_json::to_string_pretty(self)?;
110 fs::write(&path, content).await?;
111 Ok(())
112 }
113
114 pub fn add(&mut self, entry: MemoryEntry) -> String {
116 let id = entry.id.clone();
117 self.entries.insert(id.clone(), entry);
118 id
119 }
120
121 pub fn get(&mut self, id: &str) -> Option<MemoryEntry> {
123 let entry = self.entries.get_mut(id)?;
124 entry.touch();
125 Some(entry.clone())
126 }
127
128 pub fn search(
130 &mut self,
131 query: Option<&str>,
132 tags: Option<&[String]>,
133 limit: usize,
134 ) -> Vec<MemoryEntry> {
135 let mut results: Vec<MemoryEntry> = self
136 .entries
137 .values_mut()
138 .filter(|entry| {
139 if let Some(search_tags) = tags
141 && !search_tags.is_empty()
142 && !search_tags.iter().any(|t| entry.tags.contains(t))
143 {
144 return false;
145 }
146
147 if let Some(q) = query {
149 let q_lower = q.to_lowercase();
150 let matches_content = entry.content.to_lowercase().contains(&q_lower);
151 let matches_tags = entry
152 .tags
153 .iter()
154 .any(|t| t.to_lowercase().contains(&q_lower));
155 if !matches_content && !matches_tags {
156 return false;
157 }
158 }
159
160 true
161 })
162 .map(|e| {
163 e.touch();
164 e.clone()
165 })
166 .collect();
167
168 results.sort_by(|a, b| {
170 b.importance
171 .cmp(&a.importance)
172 .then_with(|| b.access_count.cmp(&a.access_count))
173 });
174
175 results.truncate(limit);
176 results
177 }
178
179 pub fn all_tags(&self) -> HashMap<String, u64> {
181 let mut tags: HashMap<String, u64> = HashMap::new();
182 for entry in self.entries.values() {
183 for tag in &entry.tags {
184 *tags.entry(tag.clone()).or_insert(0) += 1;
185 }
186 }
187 tags
188 }
189
190 pub fn delete(&mut self, id: &str) -> bool {
192 self.entries.remove(id).is_some()
193 }
194
195 pub fn stats(&self) -> MemoryStats {
197 let total = self.entries.len();
198 let total_accesses: u64 = self.entries.values().map(|e| e.access_count).sum();
199 let tags = self.all_tags();
200 MemoryStats {
201 total_entries: total,
202 total_accesses,
203 unique_tags: tags.len(),
204 tags,
205 }
206 }
207}
208
/// Aggregate figures produced by `MemoryStore::stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryStats {
    /// Number of stored entries.
    pub total_entries: usize,
    /// Sum of `access_count` over all entries.
    pub total_accesses: u64,
    /// Number of distinct tags (equal to `tags.len()`).
    pub unique_tags: usize,
    /// For each tag, how many entries carry it.
    pub tags: HashMap<String, u64>,
}
216
/// Agent tool exposing a persistent [`MemoryStore`] through the `memory`
/// action-based interface.
pub struct MemoryTool {
    /// The in-memory store. An async mutex is used because the guard is held
    /// across `.await` while saving.
    store: tokio::sync::Mutex<MemoryStore>,
    /// Set once loading the on-disk store has been attempted, whether or not
    /// the load succeeded.
    initialized: std::sync::atomic::AtomicBool,
}
222
223impl Default for MemoryTool {
224 fn default() -> Self {
225 Self::new()
226 }
227}
228
229impl MemoryTool {
230 pub fn new() -> Self {
231 Self {
232 store: tokio::sync::Mutex::new(MemoryStore::default()),
233 initialized: std::sync::atomic::AtomicBool::new(false),
234 }
235 }
236
237 pub async fn init(&self) -> Result<()> {
239 use std::sync::atomic::Ordering;
240
241 if self.initialized.load(Ordering::SeqCst) {
242 return Ok(());
243 }
244
245 let mut store = self.store.lock().await;
246 if let Ok(loaded) = MemoryStore::load().await {
247 *store = loaded;
248 }
249 self.initialized.store(true, Ordering::SeqCst);
250 Ok(())
251 }
252
253 pub async fn persist(&self) -> Result<()> {
255 let store = self.store.lock().await;
256 store.save().await
257 }
258}
259
260#[async_trait]
261impl Tool for MemoryTool {
262 fn id(&self) -> &str {
263 "memory"
264 }
265
266 fn name(&self) -> &str {
267 "Memory"
268 }
269
270 fn description(&self) -> &str {
271 "Store and retrieve persistent knowledge across sessions. Use 'save' to capture important insights, 'search' to find relevant memories, 'list' to see all entries, 'tags' to see available categories, or 'delete' to remove an entry."
272 }
273
274 fn parameters(&self) -> Value {
275 json!({
276 "type": "object",
277 "properties": {
278 "action": {
279 "type": "string",
280 "description": "Action to perform: 'save' (store new memory), 'search' (find memories), 'get' (retrieve specific memory), 'list' (show recent), 'tags' (show categories), 'delete' (remove), 'stats' (show statistics)",
281 "enum": ["save", "search", "get", "list", "tags", "delete", "stats"]
282 },
283 "content": {
284 "type": "string",
285 "description": "Memory content to save (required for 'save' action)"
286 },
287 "tags": {
288 "type": "array",
289 "items": {"type": "string"},
290 "description": "Tags for categorization (optional for 'save')"
291 },
292 "query": {
293 "type": "string",
294 "description": "Search query (for 'search' action)"
295 },
296 "scope": {
297 "type": "string",
298 "description": "Project/context scope (optional for 'save')"
299 },
300 "importance": {
301 "type": "integer",
302 "description": "Importance level 1-5 (optional for 'save', default 3)"
303 },
304 "id": {
305 "type": "string",
306 "description": "Memory ID (required for 'get' and 'delete')"
307 },
308 "limit": {
309 "type": "integer",
310 "description": "Maximum results to return (default 10, for 'search' and 'list')"
311 }
312 },
313 "required": ["action"]
314 })
315 }
316
317 async fn execute(&self, args: Value) -> Result<ToolResult> {
318 let needs_init = {
321 let store = self.store.lock().await;
322 store.entries.is_empty()
323 };
324
325 if needs_init {
326 self.init().await.ok();
327 }
328
329 let action = args["action"]
330 .as_str()
331 .ok_or_else(|| anyhow::anyhow!("action is required"))?;
332
333 match action {
334 "save" => self.execute_save(args).await,
335 "search" => self.execute_search(args).await,
336 "get" => self.execute_get(args).await,
337 "list" => self.execute_list(args).await,
338 "tags" => self.execute_tags(args).await,
339 "delete" => self.execute_delete(args).await,
340 "stats" => self.execute_stats(args).await,
341 _ => Ok(ToolResult::error(format!(
342 "Unknown action: {}. Use 'save', 'search', 'get', 'list', 'tags', 'delete', or 'stats'.",
343 action
344 ))),
345 }
346 }
347}
348
349impl MemoryTool {
350 async fn execute_save(&self, args: Value) -> Result<ToolResult> {
351 let content = args["content"]
352 .as_str()
353 .ok_or_else(|| anyhow::anyhow!("content is required for 'save' action"))?;
354
355 let tags: Vec<String> = args["tags"]
356 .as_array()
357 .map(|arr| {
358 arr.iter()
359 .filter_map(|v| v.as_str().map(String::from))
360 .collect()
361 })
362 .unwrap_or_default();
363
364 let scope = args["scope"].as_str().map(String::from);
365 let importance = args["importance"].as_u64().map(|v| v as u8).unwrap_or(3);
366
367 let mut entry = MemoryEntry::new(content, tags).with_importance(importance);
368
369 if let Some(s) = scope {
370 entry = entry.with_scope(s);
371 }
372
373 let id = {
374 let mut store = self.store.lock().await;
375 store.add(entry)
376 };
377
378 self.persist().await?;
380
381 Ok(ToolResult::success(format!(
382 "Memory saved with ID: {}\nImportance: {}/5",
383 id, importance
384 )))
385 }
386
387 async fn execute_search(&self, args: Value) -> Result<ToolResult> {
388 let query = args["query"].as_str();
389 let tags: Option<Vec<String>> = args["tags"].as_array().map(|arr| {
390 arr.iter()
391 .filter_map(|v| v.as_str().map(String::from))
392 .collect()
393 });
394 let limit = args["limit"].as_u64().map(|v| v as usize).unwrap_or(10);
395
396 let tags_ref = tags.as_deref();
397
398 let results = {
399 let mut store = self.store.lock().await;
400 store.search(query, tags_ref, limit)
401 };
402
403 if results.is_empty() {
404 return Ok(ToolResult::success(
405 "No memories found matching your criteria.".to_string(),
406 ));
407 }
408
409 let output = results
410 .iter()
411 .enumerate()
412 .map(|(i, m)| {
413 format!(
414 "{}. [{}] {} - {}\n Tags: {}\n Created: {}",
415 i + 1,
416 m.id.chars().take(8).collect::<String>(),
417 m.content.chars().take(80).collect::<String>()
418 + if m.content.len() > 80 { "..." } else { "" },
419 format!("accessed {} times", m.access_count),
420 m.tags.join(", "),
421 m.created_at.format("%Y-%m-%d %H:%M")
422 )
423 })
424 .collect::<Vec<_>>()
425 .join("\n\n");
426
427 Ok(ToolResult::success(format!(
428 "Found {} memories:\n\n{}",
429 results.len(),
430 output
431 )))
432 }
433
434 async fn execute_get(&self, args: Value) -> Result<ToolResult> {
435 let id = args["id"]
436 .as_str()
437 .ok_or_else(|| anyhow::anyhow!("id is required for 'get' action"))?;
438
439 let entry = {
440 let mut store = self.store.lock().await;
441 store.get(id)
442 };
443
444 match entry {
445 Some(e) => {
446 self.persist().await?;
448
449 Ok(ToolResult::success(format!(
450 "Memory ID: {}\nImportance: {}/5\nTags: {}\nCreated: {}\nAccessed: {} times\n\n{}",
451 e.id,
452 e.importance,
453 e.tags.join(", "),
454 e.created_at.format("%Y-%m-%d %H:%M:%S"),
455 e.access_count,
456 e.content
457 )))
458 }
459 None => Ok(ToolResult::error(format!("Memory not found: {}", id))),
460 }
461 }
462
463 async fn execute_list(&self, args: Value) -> Result<ToolResult> {
464 let limit = args["limit"].as_u64().map(|v| v as usize).unwrap_or(10);
465
466 let results = {
467 let mut store = self.store.lock().await;
468 store.search(None, None, limit)
469 };
470
471 if results.is_empty() {
472 return Ok(ToolResult::success(
473 "No memories stored yet. Use 'save' to add your first memory.".to_string(),
474 ));
475 }
476
477 let output = results
478 .iter()
479 .enumerate()
480 .map(|(i, m)| {
481 format!(
482 "{}. [{}] {} (importance: {}/5, accessed: {}x)",
483 i + 1,
484 m.id.chars().take(8).collect::<String>(),
485 m.content.chars().take(60).collect::<String>()
486 + if m.content.len() > 60 { "..." } else { "" },
487 m.importance,
488 m.access_count
489 )
490 })
491 .collect::<Vec<_>>()
492 .join("\n");
493
494 Ok(ToolResult::success(format!(
495 "Recent memories:\n\n{}",
496 output
497 )))
498 }
499
500 async fn execute_tags(&self, _args: Value) -> Result<ToolResult> {
501 let tags = {
502 let store = self.store.lock().await;
503 store.all_tags()
504 };
505
506 if tags.is_empty() {
507 return Ok(ToolResult::success(
508 "No tags yet. Add tags when saving memories.".to_string(),
509 ));
510 }
511
512 let mut sorted: Vec<_> = tags.iter().collect();
513 sorted.sort_by(|a, b| b.1.cmp(a.1));
514
515 let output = sorted
516 .iter()
517 .map(|(tag, count)| format!(" {} ({} memories)", tag, count))
518 .collect::<Vec<_>>()
519 .join("\n");
520
521 Ok(ToolResult::success(format!(
522 "Available tags:\n\n{}",
523 output
524 )))
525 }
526
527 async fn execute_delete(&self, args: Value) -> Result<ToolResult> {
528 let id = args["id"]
529 .as_str()
530 .ok_or_else(|| anyhow::anyhow!("id is required for 'delete' action"))?;
531
532 let deleted = {
533 let mut store = self.store.lock().await;
534 store.delete(id)
535 };
536
537 if deleted {
538 self.persist().await?;
539 Ok(ToolResult::success(format!("Memory deleted: {}", id)))
540 } else {
541 Ok(ToolResult::error(format!("Memory not found: {}", id)))
542 }
543 }
544
545 async fn execute_stats(&self, _args: Value) -> Result<ToolResult> {
546 let stats = {
547 let store = self.store.lock().await;
548 store.stats()
549 };
550
551 let tags_output = if stats.tags.is_empty() {
552 "None".to_string()
553 } else {
554 let mut sorted: Vec<_> = stats.tags.iter().collect();
555 sorted.sort_by(|a, b| b.1.cmp(a.1));
556 sorted
557 .iter()
558 .take(10)
559 .map(|(t, c)| format!(" {}: {}", t, c))
560 .collect::<Vec<_>>()
561 .join("\n")
562 };
563
564 Ok(ToolResult::success(format!(
565 "Memory Statistics:\n\n\
566 Total entries: {}\n\
567 Total accesses: {}\n\
568 Unique tags: {}\n\n\
569 Top tags:\n{}",
570 stats.total_entries, stats.total_accesses, stats.unique_tags, tags_output
571 )))
572 }
573}
574
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::Ordering;

    /// Builds a tool whose store already holds `entries`.
    ///
    /// The tool is marked initialised so it never loads the real on-disk store.
    /// Entries are inserted directly rather than through the `save` action,
    /// because `save` calls `persist`, which would overwrite the developer's real
    /// `memory.json` in the data directory with test data.
    async fn seeded_tool(entries: Vec<MemoryEntry>) -> MemoryTool {
        let tool = MemoryTool::new();
        tool.initialized.store(true, Ordering::SeqCst);
        {
            let mut store = tool.store.lock().await;
            for entry in entries {
                store.add(entry);
            }
        }
        tool
    }

    /// Converts string literals into an owned tag list.
    fn tag_list(names: &[&str]) -> Vec<String> {
        names.iter().map(|s| s.to_string()).collect()
    }

    #[tokio::test]
    async fn test_memory_list_and_stats() {
        let tool = seeded_tool(vec![
            MemoryEntry::new("Test memory content", tag_list(&["test", "example"]))
                .with_importance(4),
        ])
        .await;

        let result = tool
            .execute(json!({
                "action": "list",
                "limit": 5
            }))
            .await
            .unwrap();
        assert!(result.success);
        assert!(result.output.contains("Test memory content"));

        let result = tool
            .execute(json!({
                "action": "stats"
            }))
            .await
            .unwrap();
        assert!(result.success);
        assert!(result.output.contains("Total entries: 1"));
    }

    #[tokio::test]
    async fn test_memory_search() {
        let tool = seeded_tool(vec![
            MemoryEntry::new("Rust programming insights", tag_list(&["rust", "programming"])),
            MemoryEntry::new("Python tips", tag_list(&["python", "programming"])),
        ])
        .await;

        let result = tool
            .execute(json!({
                "action": "search",
                "tags": ["rust"]
            }))
            .await
            .unwrap();

        assert!(result.success);
        assert!(result.output.contains("Rust"));
        assert!(!result.output.contains("Python"));
    }

    #[tokio::test]
    async fn test_missing_action_is_an_error() {
        let tool = seeded_tool(vec![MemoryEntry::new("anything", Vec::new())]).await;
        assert!(tool.execute(json!({})).await.is_err());
    }

    #[test]
    fn test_store_get_touches_entry() {
        let mut store = MemoryStore::default();
        let id = store.add(MemoryEntry::new("remember me", Vec::new()));

        let first = store.get(&id).expect("entry was just added");
        assert_eq!(first.access_count, 1);
        let second = store.get(&id).expect("entry was just added");
        assert_eq!(second.access_count, 2);

        assert!(store.get("no-such-id").is_none());
    }
}