Skip to main content

codetether_agent/tool/
memory.rs

1//! Memory tool: Persistent knowledge capture and retrieval
2//!
3//! Allows agents to store important insights, learnings, and decisions
4//! that persist across sessions for future reference.
5
6use super::{Tool, ToolResult};
7use anyhow::Result;
8use async_trait::async_trait;
9use chrono::{DateTime, Utc};
10use serde::{Deserialize, Serialize};
11use serde_json::{Value, json};
12use std::collections::HashMap;
13use std::path::PathBuf;
14use tokio::fs;
15
/// A single memory entry.
///
/// Entries are held in a [`MemoryStore`] and persisted with it as JSON (serde).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryEntry {
    /// Unique identifier (a random v4 UUID string when built with `MemoryEntry::new`)
    pub id: String,
    /// The memory content
    pub content: String,
    /// Tags for categorization/search
    pub tags: Vec<String>,
    /// When this memory was created
    pub created_at: DateTime<Utc>,
    /// When this memory was last accessed (refreshed by `touch`)
    pub accessed_at: DateTime<Utc>,
    /// How many times this memory has been accessed (incremented by `touch`)
    pub access_count: u64,
    /// Optional project/context scope
    pub scope: Option<String>,
    /// Source of the memory (session_id, tool, etc.)
    pub source: Option<String>,
    /// Importance level, intended range 1-5 (`MemoryEntry::new` defaults to 3)
    pub importance: u8,
}
38
39impl MemoryEntry {
40    pub fn new(content: impl Into<String>, tags: Vec<String>) -> Self {
41        let now = Utc::now();
42        Self {
43            id: uuid::Uuid::new_v4().to_string(),
44            content: content.into(),
45            tags,
46            created_at: now,
47            accessed_at: now,
48            access_count: 0,
49            scope: None,
50            source: None,
51            importance: 3, // default medium importance
52        }
53    }
54
55    pub fn with_scope(mut self, scope: impl Into<String>) -> Self {
56        self.scope = Some(scope.into());
57        self
58    }
59
60    /// Set the source of this memory entry
61    #[allow(dead_code)]
62    pub fn with_source(mut self, source: impl Into<String>) -> Self {
63        self.source = Some(source.into());
64        self
65    }
66
67    pub fn with_importance(mut self, importance: u8) -> Self {
68        self.importance = importance.min(5);
69        self
70    }
71
72    pub fn touch(&mut self) {
73        self.accessed_at = Utc::now();
74        self.access_count += 1;
75    }
76}
77
/// Memory store for persistence.
///
/// The whole struct is (de)serialized as a single JSON document
/// (see `MemoryStore::default_path`, `MemoryStore::load`, `MemoryStore::save`).
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct MemoryStore {
    /// All entries, keyed by `MemoryEntry::id`.
    entries: HashMap<String, MemoryEntry>,
}
83
84impl MemoryStore {
85    /// Get the default memory file path
86    pub fn default_path() -> std::path::PathBuf {
87        crate::config::Config::data_dir()
88            .map(|p| p.join("memory.json"))
89            .unwrap_or_else(|| PathBuf::from(".codetether-agent/memory.json"))
90    }
91
92    /// Load from disk
93    pub async fn load() -> Result<Self> {
94        let path = Self::default_path();
95        if !path.exists() {
96            return Ok(Self::default());
97        }
98        let content = fs::read_to_string(&path).await?;
99        let store: MemoryStore = serde_json::from_str(&content)?;
100        Ok(store)
101    }
102
103    /// Save to disk
104    pub async fn save(&self) -> Result<()> {
105        let path = Self::default_path();
106        if let Some(parent) = path.parent() {
107            fs::create_dir_all(parent).await?;
108        }
109        let content = serde_json::to_string_pretty(self)?;
110        fs::write(&path, content).await?;
111        Ok(())
112    }
113
114    /// Add a new memory
115    pub fn add(&mut self, entry: MemoryEntry) -> String {
116        let id = entry.id.clone();
117        self.entries.insert(id.clone(), entry);
118        id
119    }
120
121    /// Get a memory by ID
122    pub fn get(&mut self, id: &str) -> Option<MemoryEntry> {
123        let entry = self.entries.get_mut(id)?;
124        entry.touch();
125        Some(entry.clone())
126    }
127
128    /// Search memories by query or tags
129    pub fn search(
130        &mut self,
131        query: Option<&str>,
132        tags: Option<&[String]>,
133        limit: usize,
134    ) -> Vec<MemoryEntry> {
135        let mut results: Vec<MemoryEntry> = self
136            .entries
137            .values_mut()
138            .filter(|entry| {
139                // Filter by tags if provided
140                if let Some(search_tags) = tags
141                    && !search_tags.is_empty()
142                    && !search_tags.iter().any(|t| entry.tags.contains(t))
143                {
144                    return false;
145                }
146
147                // Filter by query if provided
148                if let Some(q) = query {
149                    let q_lower = q.to_lowercase();
150                    let matches_content = entry.content.to_lowercase().contains(&q_lower);
151                    let matches_tags = entry
152                        .tags
153                        .iter()
154                        .any(|t| t.to_lowercase().contains(&q_lower));
155                    if !matches_content && !matches_tags {
156                        return false;
157                    }
158                }
159
160                true
161            })
162            .map(|e| {
163                e.touch();
164                e.clone()
165            })
166            .collect();
167
168        // Sort by importance (descending) then access_count (descending)
169        results.sort_by(|a, b| {
170            b.importance
171                .cmp(&a.importance)
172                .then_with(|| b.access_count.cmp(&a.access_count))
173        });
174
175        results.truncate(limit);
176        results
177    }
178
179    /// Get all tags with counts
180    pub fn all_tags(&self) -> HashMap<String, u64> {
181        let mut tags: HashMap<String, u64> = HashMap::new();
182        for entry in self.entries.values() {
183            for tag in &entry.tags {
184                *tags.entry(tag.clone()).or_insert(0) += 1;
185            }
186        }
187        tags
188    }
189
190    /// Delete a memory
191    pub fn delete(&mut self, id: &str) -> bool {
192        self.entries.remove(id).is_some()
193    }
194
195    /// Get statistics
196    pub fn stats(&self) -> MemoryStats {
197        let total = self.entries.len();
198        let total_accesses: u64 = self.entries.values().map(|e| e.access_count).sum();
199        let tags = self.all_tags();
200        MemoryStats {
201            total_entries: total,
202            total_accesses,
203            unique_tags: tags.len(),
204            tags,
205        }
206    }
207}
208
/// Aggregate statistics over a [`MemoryStore`], as produced by `MemoryStore::stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryStats {
    /// Number of stored entries.
    pub total_entries: usize,
    /// Sum of `access_count` over all entries.
    pub total_accesses: u64,
    /// Number of distinct tags (the size of `tags`).
    pub unique_tags: usize,
    /// How many times each tag appears across all entries.
    pub tags: HashMap<String, u64>,
}
216
/// Memory Tool - Store and retrieve persistent knowledge
///
/// Wraps a [`MemoryStore`] behind an async mutex. The store starts empty and is
/// populated from disk by `MemoryTool::init`.
pub struct MemoryTool {
    /// In-memory copy of the store, guarded by a `tokio` mutex so it can be held
    /// across the `await` in `persist`.
    store: tokio::sync::Mutex<MemoryStore>,
    /// Set to `true` by `init`; checked so the on-disk store is not reloaded.
    initialized: std::sync::atomic::AtomicBool,
}
222
223impl Default for MemoryTool {
224    fn default() -> Self {
225        Self::new()
226    }
227}
228
229impl MemoryTool {
230    pub fn new() -> Self {
231        Self {
232            store: tokio::sync::Mutex::new(MemoryStore::default()),
233            initialized: std::sync::atomic::AtomicBool::new(false),
234        }
235    }
236
237    /// Initialize store from disk (once)
238    pub async fn init(&self) -> Result<()> {
239        use std::sync::atomic::Ordering;
240
241        if self.initialized.load(Ordering::SeqCst) {
242            return Ok(());
243        }
244
245        let mut store = self.store.lock().await;
246        if let Ok(loaded) = MemoryStore::load().await {
247            *store = loaded;
248        }
249        self.initialized.store(true, Ordering::SeqCst);
250        Ok(())
251    }
252
253    /// Persist store to disk
254    pub async fn persist(&self) -> Result<()> {
255        let store = self.store.lock().await;
256        store.save().await
257    }
258}
259
260#[async_trait]
261impl Tool for MemoryTool {
262    fn id(&self) -> &str {
263        "memory"
264    }
265
266    fn name(&self) -> &str {
267        "Memory"
268    }
269
270    fn description(&self) -> &str {
271        "Store and retrieve persistent knowledge across sessions. Use 'save' to capture important insights, 'search' to find relevant memories, 'list' to see all entries, 'tags' to see available categories, or 'delete' to remove an entry."
272    }
273
274    fn parameters(&self) -> Value {
275        json!({
276            "type": "object",
277            "properties": {
278                "action": {
279                    "type": "string",
280                    "description": "Action to perform: 'save' (store new memory), 'search' (find memories), 'get' (retrieve specific memory), 'list' (show recent), 'tags' (show categories), 'delete' (remove), 'stats' (show statistics)",
281                    "enum": ["save", "search", "get", "list", "tags", "delete", "stats"]
282                },
283                "content": {
284                    "type": "string",
285                    "description": "Memory content to save (required for 'save' action)"
286                },
287                "tags": {
288                    "type": "array",
289                    "items": {"type": "string"},
290                    "description": "Tags for categorization (optional for 'save')"
291                },
292                "query": {
293                    "type": "string",
294                    "description": "Search query (for 'search' action)"
295                },
296                "scope": {
297                    "type": "string",
298                    "description": "Project/context scope (optional for 'save')"
299                },
300                "importance": {
301                    "type": "integer",
302                    "description": "Importance level 1-5 (optional for 'save', default 3)"
303                },
304                "id": {
305                    "type": "string",
306                    "description": "Memory ID (required for 'get' and 'delete')"
307                },
308                "limit": {
309                    "type": "integer",
310                    "description": "Maximum results to return (default 10, for 'search' and 'list')"
311                }
312            },
313            "required": ["action"]
314        })
315    }
316
317    async fn execute(&self, args: Value) -> Result<ToolResult> {
318        // Initialize store from disk if not already loaded
319        // Use a flag to avoid reloading on every call
320        let needs_init = {
321            let store = self.store.lock().await;
322            store.entries.is_empty()
323        };
324
325        if needs_init {
326            self.init().await.ok();
327        }
328
329        let action = args["action"]
330            .as_str()
331            .ok_or_else(|| anyhow::anyhow!("action is required"))?;
332
333        match action {
334            "save" => self.execute_save(args).await,
335            "search" => self.execute_search(args).await,
336            "get" => self.execute_get(args).await,
337            "list" => self.execute_list(args).await,
338            "tags" => self.execute_tags(args).await,
339            "delete" => self.execute_delete(args).await,
340            "stats" => self.execute_stats(args).await,
341            _ => Ok(ToolResult::error(format!(
342                "Unknown action: {}. Use 'save', 'search', 'get', 'list', 'tags', 'delete', or 'stats'.",
343                action
344            ))),
345        }
346    }
347}
348
349impl MemoryTool {
350    async fn execute_save(&self, args: Value) -> Result<ToolResult> {
351        let content = args["content"]
352            .as_str()
353            .ok_or_else(|| anyhow::anyhow!("content is required for 'save' action"))?;
354
355        let tags: Vec<String> = args["tags"]
356            .as_array()
357            .map(|arr| {
358                arr.iter()
359                    .filter_map(|v| v.as_str().map(String::from))
360                    .collect()
361            })
362            .unwrap_or_default();
363
364        let scope = args["scope"].as_str().map(String::from);
365        let importance = args["importance"].as_u64().map(|v| v as u8).unwrap_or(3);
366
367        let mut entry = MemoryEntry::new(content, tags).with_importance(importance);
368
369        if let Some(s) = scope {
370            entry = entry.with_scope(s);
371        }
372
373        let id = {
374            let mut store = self.store.lock().await;
375            store.add(entry)
376        };
377
378        // Persist to disk
379        self.persist().await?;
380
381        Ok(ToolResult::success(format!(
382            "Memory saved with ID: {}\nImportance: {}/5",
383            id, importance
384        )))
385    }
386
387    async fn execute_search(&self, args: Value) -> Result<ToolResult> {
388        let query = args["query"].as_str();
389        let tags: Option<Vec<String>> = args["tags"].as_array().map(|arr| {
390            arr.iter()
391                .filter_map(|v| v.as_str().map(String::from))
392                .collect()
393        });
394        let limit = args["limit"].as_u64().map(|v| v as usize).unwrap_or(10);
395
396        let tags_ref = tags.as_deref();
397
398        let results = {
399            let mut store = self.store.lock().await;
400            store.search(query, tags_ref, limit)
401        };
402
403        if results.is_empty() {
404            return Ok(ToolResult::success(
405                "No memories found matching your criteria.".to_string(),
406            ));
407        }
408
409        let output = results
410            .iter()
411            .enumerate()
412            .map(|(i, m)| {
413                format!(
414                    "{}. [{}] {} - {}\n   Tags: {}\n   Created: {}",
415                    i + 1,
416                    m.id.chars().take(8).collect::<String>(),
417                    m.content.chars().take(80).collect::<String>()
418                        + if m.content.len() > 80 { "..." } else { "" },
419                    format!("accessed {} times", m.access_count),
420                    m.tags.join(", "),
421                    m.created_at.format("%Y-%m-%d %H:%M")
422                )
423            })
424            .collect::<Vec<_>>()
425            .join("\n\n");
426
427        Ok(ToolResult::success(format!(
428            "Found {} memories:\n\n{}",
429            results.len(),
430            output
431        )))
432    }
433
434    async fn execute_get(&self, args: Value) -> Result<ToolResult> {
435        let id = args["id"]
436            .as_str()
437            .ok_or_else(|| anyhow::anyhow!("id is required for 'get' action"))?;
438
439        let entry = {
440            let mut store = self.store.lock().await;
441            store.get(id)
442        };
443
444        match entry {
445            Some(e) => {
446                // Persist the updated access count
447                self.persist().await?;
448
449                Ok(ToolResult::success(format!(
450                    "Memory ID: {}\nImportance: {}/5\nTags: {}\nCreated: {}\nAccessed: {} times\n\n{}",
451                    e.id,
452                    e.importance,
453                    e.tags.join(", "),
454                    e.created_at.format("%Y-%m-%d %H:%M:%S"),
455                    e.access_count,
456                    e.content
457                )))
458            }
459            None => Ok(ToolResult::error(format!("Memory not found: {}", id))),
460        }
461    }
462
463    async fn execute_list(&self, args: Value) -> Result<ToolResult> {
464        let limit = args["limit"].as_u64().map(|v| v as usize).unwrap_or(10);
465
466        let results = {
467            let mut store = self.store.lock().await;
468            store.search(None, None, limit)
469        };
470
471        if results.is_empty() {
472            return Ok(ToolResult::success(
473                "No memories stored yet. Use 'save' to add your first memory.".to_string(),
474            ));
475        }
476
477        let output = results
478            .iter()
479            .enumerate()
480            .map(|(i, m)| {
481                format!(
482                    "{}. [{}] {} (importance: {}/5, accessed: {}x)",
483                    i + 1,
484                    m.id.chars().take(8).collect::<String>(),
485                    m.content.chars().take(60).collect::<String>()
486                        + if m.content.len() > 60 { "..." } else { "" },
487                    m.importance,
488                    m.access_count
489                )
490            })
491            .collect::<Vec<_>>()
492            .join("\n");
493
494        Ok(ToolResult::success(format!(
495            "Recent memories:\n\n{}",
496            output
497        )))
498    }
499
500    async fn execute_tags(&self, _args: Value) -> Result<ToolResult> {
501        let tags = {
502            let store = self.store.lock().await;
503            store.all_tags()
504        };
505
506        if tags.is_empty() {
507            return Ok(ToolResult::success(
508                "No tags yet. Add tags when saving memories.".to_string(),
509            ));
510        }
511
512        let mut sorted: Vec<_> = tags.iter().collect();
513        sorted.sort_by(|a, b| b.1.cmp(a.1));
514
515        let output = sorted
516            .iter()
517            .map(|(tag, count)| format!("  {} ({} memories)", tag, count))
518            .collect::<Vec<_>>()
519            .join("\n");
520
521        Ok(ToolResult::success(format!(
522            "Available tags:\n\n{}",
523            output
524        )))
525    }
526
527    async fn execute_delete(&self, args: Value) -> Result<ToolResult> {
528        let id = args["id"]
529            .as_str()
530            .ok_or_else(|| anyhow::anyhow!("id is required for 'delete' action"))?;
531
532        let deleted = {
533            let mut store = self.store.lock().await;
534            store.delete(id)
535        };
536
537        if deleted {
538            self.persist().await?;
539            Ok(ToolResult::success(format!("Memory deleted: {}", id)))
540        } else {
541            Ok(ToolResult::error(format!("Memory not found: {}", id)))
542        }
543    }
544
545    async fn execute_stats(&self, _args: Value) -> Result<ToolResult> {
546        let stats = {
547            let store = self.store.lock().await;
548            store.stats()
549        };
550
551        let tags_output = if stats.tags.is_empty() {
552            "None".to_string()
553        } else {
554            let mut sorted: Vec<_> = stats.tags.iter().collect();
555            sorted.sort_by(|a, b| b.1.cmp(a.1));
556            sorted
557                .iter()
558                .take(10)
559                .map(|(t, c)| format!("  {}: {}", t, c))
560                .collect::<Vec<_>>()
561                .join("\n")
562        };
563
564        Ok(ToolResult::success(format!(
565            "Memory Statistics:\n\n\
566             Total entries: {}\n\
567             Total accesses: {}\n\
568             Unique tags: {}\n\n\
569             Top tags:\n{}",
570            stats.total_entries, stats.total_accesses, stats.unique_tags, tags_output
571        )))
572    }
573}
574
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::Ordering;

    // NOTE(review): marking the tool as initialized only stops these tests from
    // *loading* the on-disk store. 'save' and 'get' still call `persist()`, which
    // writes to `MemoryStore::default_path()`. That is presumably the developer's
    // real data directory, so its contents get replaced with test data.
    // TODO: make the storage path injectable so tests can use a temp directory.

    /// Build a tool whose store starts empty and is never loaded from disk.
    fn isolated_tool() -> MemoryTool {
        let tool = MemoryTool::new();
        tool.initialized.store(true, Ordering::SeqCst);
        tool
    }

    /// Extract the full memory ID from the first line of a 'save' result.
    fn saved_id(output: &str) -> String {
        output
            .lines()
            .next()
            .and_then(|line| line.strip_prefix("Memory saved with ID: "))
            .expect("save output should start with the new memory ID")
            .to_string()
    }

    #[tokio::test]
    async fn test_memory_save_and_get() {
        let tool = isolated_tool();

        // Save a memory
        let result = tool
            .execute(json!({
                "action": "save",
                "content": "Test memory content",
                "tags": ["test", "example"],
                "importance": 4
            }))
            .await
            .unwrap();

        assert!(result.success);
        let id = saved_id(&result.output);

        // Get it back by ID
        let result = tool
            .execute(json!({
                "action": "get",
                "id": id
            }))
            .await
            .unwrap();

        assert!(result.success);
        assert!(result.output.contains("Test memory content"));
        assert!(result.output.contains("Importance: 4/5"));
        assert!(result.output.contains("Tags: test, example"));

        // List memories
        let result = tool
            .execute(json!({
                "action": "list",
                "limit": 5
            }))
            .await
            .unwrap();

        assert!(result.success);
        assert!(result.output.contains("Test memory content"));

        // Get stats
        let result = tool
            .execute(json!({
                "action": "stats"
            }))
            .await
            .unwrap();

        assert!(result.success);
        assert!(result.output.contains("Total entries: 1"));
    }

    #[tokio::test]
    async fn test_memory_search() {
        let tool = isolated_tool();

        // Save with specific tags
        tool.execute(json!({
            "action": "save",
            "content": "Rust programming insights",
            "tags": ["rust", "programming"]
        }))
        .await
        .unwrap();

        tool.execute(json!({
            "action": "save",
            "content": "Python tips",
            "tags": ["python", "programming"]
        }))
        .await
        .unwrap();

        // Search by tag
        let result = tool
            .execute(json!({
                "action": "search",
                "tags": ["rust"]
            }))
            .await
            .unwrap();

        assert!(result.success);
        assert!(result.output.contains("Rust"));
        assert!(!result.output.contains("Python"));
    }
}