Skip to main content

codetether_agent/tool/
memory.rs

1//! Memory tool: Persistent knowledge capture and retrieval
2//!
3//! Allows agents to store important insights, learnings, and decisions
4//! that persist across sessions for future reference.
5
6use super::{Tool, ToolResult};
7use anyhow::Result;
8use async_trait::async_trait;
9use chrono::{DateTime, Utc};
10use serde::{Deserialize, Serialize};
11use serde_json::{Value, json};
12use std::collections::HashMap;
13use std::path::PathBuf;
14use tokio::fs;
15
16/// A single memory entry
17#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct MemoryEntry {
19    /// Unique identifier
20    pub id: String,
21    /// The memory content
22    pub content: String,
23    /// Tags for categorization/search
24    pub tags: Vec<String>,
25    /// When this memory was created
26    pub created_at: DateTime<Utc>,
27    /// When this memory was last accessed
28    pub accessed_at: DateTime<Utc>,
29    /// How many times this memory has been accessed
30    pub access_count: u64,
31    /// Optional project/context scope
32    pub scope: Option<String>,
33    /// Source of the memory (session_id, tool, etc.)
34    pub source: Option<String>,
35    /// Importance level (1-5)
36    pub importance: u8,
37}
38
39impl MemoryEntry {
40    pub fn new(content: impl Into<String>, tags: Vec<String>) -> Self {
41        let now = Utc::now();
42        Self {
43            id: uuid::Uuid::new_v4().to_string(),
44            content: content.into(),
45            tags,
46            created_at: now,
47            accessed_at: now,
48            access_count: 0,
49            scope: None,
50            source: None,
51            importance: 3, // default medium importance
52        }
53    }
54
55    pub fn with_scope(mut self, scope: impl Into<String>) -> Self {
56        self.scope = Some(scope.into());
57        self
58    }
59
60    pub fn with_source(mut self, source: impl Into<String>) -> Self {
61        self.source = Some(source.into());
62        self
63    }
64
65    pub fn with_importance(mut self, importance: u8) -> Self {
66        self.importance = importance.min(5);
67        self
68    }
69
70    pub fn touch(&mut self) {
71        self.accessed_at = Utc::now();
72        self.access_count += 1;
73    }
74}
75
76/// Memory store for persistence
77#[derive(Debug, Clone, Serialize, Deserialize, Default)]
78pub struct MemoryStore {
79    entries: HashMap<String, MemoryEntry>,
80}
81
82impl MemoryStore {
83    /// Get the default memory file path
84    pub fn default_path() -> std::path::PathBuf {
85        directories::ProjectDirs::from("com", "codetether", "codetether")
86            .map(|p| p.data_dir().join("memory.json"))
87            .unwrap_or_else(|| PathBuf::from(".codetether/memory.json"))
88    }
89
90    /// Load from disk
91    pub async fn load() -> Result<Self> {
92        let path = Self::default_path();
93        if !path.exists() {
94            return Ok(Self::default());
95        }
96        let content = fs::read_to_string(&path).await?;
97        let store: MemoryStore = serde_json::from_str(&content)?;
98        Ok(store)
99    }
100
101    /// Save to disk
102    pub async fn save(&self) -> Result<()> {
103        let path = Self::default_path();
104        if let Some(parent) = path.parent() {
105            fs::create_dir_all(parent).await?;
106        }
107        let content = serde_json::to_string_pretty(self)?;
108        fs::write(&path, content).await?;
109        Ok(())
110    }
111
112    /// Add a new memory
113    pub fn add(&mut self, entry: MemoryEntry) -> String {
114        let id = entry.id.clone();
115        self.entries.insert(id.clone(), entry);
116        id
117    }
118
119    /// Get a memory by ID
120    pub fn get(&mut self, id: &str) -> Option<MemoryEntry> {
121        let entry = self.entries.get_mut(id)?;
122        entry.touch();
123        Some(entry.clone())
124    }
125
126    /// Search memories by query or tags
127    pub fn search(
128        &mut self,
129        query: Option<&str>,
130        tags: Option<&[String]>,
131        limit: usize,
132    ) -> Vec<MemoryEntry> {
133        let mut results: Vec<MemoryEntry> = self
134            .entries
135            .values_mut()
136            .filter(|entry| {
137                // Filter by tags if provided
138                if let Some(search_tags) = tags {
139                    if !search_tags.is_empty()
140                        && !search_tags.iter().any(|t| entry.tags.contains(t))
141                    {
142                        return false;
143                    }
144                }
145
146                // Filter by query if provided
147                if let Some(q) = query {
148                    let q_lower = q.to_lowercase();
149                    let matches_content = entry.content.to_lowercase().contains(&q_lower);
150                    let matches_tags = entry
151                        .tags
152                        .iter()
153                        .any(|t| t.to_lowercase().contains(&q_lower));
154                    if !matches_content && !matches_tags {
155                        return false;
156                    }
157                }
158
159                true
160            })
161            .map(|e| {
162                e.touch();
163                e.clone()
164            })
165            .collect();
166
167        // Sort by importance (descending) then access_count (descending)
168        results.sort_by(|a, b| {
169            b.importance
170                .cmp(&a.importance)
171                .then_with(|| b.access_count.cmp(&a.access_count))
172        });
173
174        results.truncate(limit);
175        results
176    }
177
178    /// Get all tags with counts
179    pub fn all_tags(&self) -> HashMap<String, u64> {
180        let mut tags: HashMap<String, u64> = HashMap::new();
181        for entry in self.entries.values() {
182            for tag in &entry.tags {
183                *tags.entry(tag.clone()).or_insert(0) += 1;
184            }
185        }
186        tags
187    }
188
189    /// Delete a memory
190    pub fn delete(&mut self, id: &str) -> bool {
191        self.entries.remove(id).is_some()
192    }
193
194    /// Get statistics
195    pub fn stats(&self) -> MemoryStats {
196        let total = self.entries.len();
197        let total_accesses: u64 = self.entries.values().map(|e| e.access_count).sum();
198        let tags = self.all_tags();
199        MemoryStats {
200            total_entries: total,
201            total_accesses,
202            unique_tags: tags.len(),
203            tags,
204        }
205    }
206}
207
208#[derive(Debug, Clone, Serialize, Deserialize)]
209pub struct MemoryStats {
210    pub total_entries: usize,
211    pub total_accesses: u64,
212    pub unique_tags: usize,
213    pub tags: HashMap<String, u64>,
214}
215
216/// Memory Tool - Store and retrieve persistent knowledge
217pub struct MemoryTool {
218    store: tokio::sync::Mutex<MemoryStore>,
219    initialized: std::sync::atomic::AtomicBool,
220}
221
222impl Default for MemoryTool {
223    fn default() -> Self {
224        Self::new()
225    }
226}
227
228impl MemoryTool {
229    pub fn new() -> Self {
230        Self {
231            store: tokio::sync::Mutex::new(MemoryStore::default()),
232            initialized: std::sync::atomic::AtomicBool::new(false),
233        }
234    }
235
236    /// Initialize store from disk (once)
237    pub async fn init(&self) -> Result<()> {
238        use std::sync::atomic::Ordering;
239
240        if self.initialized.load(Ordering::SeqCst) {
241            return Ok(());
242        }
243
244        let mut store = self.store.lock().await;
245        if let Ok(loaded) = MemoryStore::load().await {
246            *store = loaded;
247        }
248        self.initialized.store(true, Ordering::SeqCst);
249        Ok(())
250    }
251
252    /// Persist store to disk
253    pub async fn persist(&self) -> Result<()> {
254        let store = self.store.lock().await;
255        store.save().await
256    }
257}
258
259#[async_trait]
260impl Tool for MemoryTool {
261    fn id(&self) -> &str {
262        "memory"
263    }
264
265    fn name(&self) -> &str {
266        "Memory"
267    }
268
269    fn description(&self) -> &str {
270        "Store and retrieve persistent knowledge across sessions. Use 'save' to capture important insights, 'search' to find relevant memories, 'list' to see all entries, 'tags' to see available categories, or 'delete' to remove an entry."
271    }
272
273    fn parameters(&self) -> Value {
274        json!({
275            "type": "object",
276            "properties": {
277                "action": {
278                    "type": "string",
279                    "description": "Action to perform: 'save' (store new memory), 'search' (find memories), 'get' (retrieve specific memory), 'list' (show recent), 'tags' (show categories), 'delete' (remove), 'stats' (show statistics)",
280                    "enum": ["save", "search", "get", "list", "tags", "delete", "stats"]
281                },
282                "content": {
283                    "type": "string",
284                    "description": "Memory content to save (required for 'save' action)"
285                },
286                "tags": {
287                    "type": "array",
288                    "items": {"type": "string"},
289                    "description": "Tags for categorization (optional for 'save')"
290                },
291                "query": {
292                    "type": "string",
293                    "description": "Search query (for 'search' action)"
294                },
295                "scope": {
296                    "type": "string",
297                    "description": "Project/context scope (optional for 'save')"
298                },
299                "importance": {
300                    "type": "integer",
301                    "description": "Importance level 1-5 (optional for 'save', default 3)"
302                },
303                "id": {
304                    "type": "string",
305                    "description": "Memory ID (required for 'get' and 'delete')"
306                },
307                "limit": {
308                    "type": "integer",
309                    "description": "Maximum results to return (default 10, for 'search' and 'list')"
310                }
311            },
312            "required": ["action"]
313        })
314    }
315
316    async fn execute(&self, args: Value) -> Result<ToolResult> {
317        // Initialize store from disk if not already loaded
318        // Use a flag to avoid reloading on every call
319        let needs_init = {
320            let store = self.store.lock().await;
321            store.entries.is_empty()
322        };
323
324        if needs_init {
325            self.init().await.ok();
326        }
327
328        let action = args["action"]
329            .as_str()
330            .ok_or_else(|| anyhow::anyhow!("action is required"))?;
331
332        match action {
333            "save" => self.execute_save(args).await,
334            "search" => self.execute_search(args).await,
335            "get" => self.execute_get(args).await,
336            "list" => self.execute_list(args).await,
337            "tags" => self.execute_tags(args).await,
338            "delete" => self.execute_delete(args).await,
339            "stats" => self.execute_stats(args).await,
340            _ => Ok(ToolResult::error(format!(
341                "Unknown action: {}. Use 'save', 'search', 'get', 'list', 'tags', 'delete', or 'stats'.",
342                action
343            ))),
344        }
345    }
346}
347
348impl MemoryTool {
349    async fn execute_save(&self, args: Value) -> Result<ToolResult> {
350        let content = args["content"]
351            .as_str()
352            .ok_or_else(|| anyhow::anyhow!("content is required for 'save' action"))?;
353
354        let tags: Vec<String> = args["tags"]
355            .as_array()
356            .map(|arr| {
357                arr.iter()
358                    .filter_map(|v| v.as_str().map(String::from))
359                    .collect()
360            })
361            .unwrap_or_default();
362
363        let scope = args["scope"].as_str().map(String::from);
364        let importance = args["importance"].as_u64().map(|v| v as u8).unwrap_or(3);
365
366        let mut entry = MemoryEntry::new(content, tags).with_importance(importance);
367
368        if let Some(s) = scope {
369            entry = entry.with_scope(s);
370        }
371
372        let id = {
373            let mut store = self.store.lock().await;
374            store.add(entry)
375        };
376
377        // Persist to disk
378        self.persist().await?;
379
380        Ok(ToolResult::success(format!(
381            "Memory saved with ID: {}\nImportance: {}/5",
382            id, importance
383        )))
384    }
385
386    async fn execute_search(&self, args: Value) -> Result<ToolResult> {
387        let query = args["query"].as_str();
388        let tags: Option<Vec<String>> = args["tags"].as_array().map(|arr| {
389            arr.iter()
390                .filter_map(|v| v.as_str().map(String::from))
391                .collect()
392        });
393        let limit = args["limit"].as_u64().map(|v| v as usize).unwrap_or(10);
394
395        let tags_ref = tags.as_ref().map(|v| v.as_slice());
396
397        let results = {
398            let mut store = self.store.lock().await;
399            store.search(query, tags_ref, limit)
400        };
401
402        if results.is_empty() {
403            return Ok(ToolResult::success(
404                "No memories found matching your criteria.".to_string(),
405            ));
406        }
407
408        let output = results
409            .iter()
410            .enumerate()
411            .map(|(i, m)| {
412                format!(
413                    "{}. [{}] {} - {}\n   Tags: {}\n   Created: {}",
414                    i + 1,
415                    m.id.chars().take(8).collect::<String>(),
416                    m.content.chars().take(80).collect::<String>()
417                        + if m.content.len() > 80 { "..." } else { "" },
418                    format!("accessed {} times", m.access_count),
419                    m.tags.join(", "),
420                    m.created_at.format("%Y-%m-%d %H:%M")
421                )
422            })
423            .collect::<Vec<_>>()
424            .join("\n\n");
425
426        Ok(ToolResult::success(format!(
427            "Found {} memories:\n\n{}",
428            results.len(),
429            output
430        )))
431    }
432
433    async fn execute_get(&self, args: Value) -> Result<ToolResult> {
434        let id = args["id"]
435            .as_str()
436            .ok_or_else(|| anyhow::anyhow!("id is required for 'get' action"))?;
437
438        let entry = {
439            let mut store = self.store.lock().await;
440            store.get(id).map(|e| e.clone())
441        };
442
443        match entry {
444            Some(e) => {
445                // Persist the updated access count
446                self.persist().await?;
447
448                Ok(ToolResult::success(format!(
449                    "Memory ID: {}\nImportance: {}/5\nTags: {}\nCreated: {}\nAccessed: {} times\n\n{}",
450                    e.id,
451                    e.importance,
452                    e.tags.join(", "),
453                    e.created_at.format("%Y-%m-%d %H:%M:%S"),
454                    e.access_count,
455                    e.content
456                )))
457            }
458            None => Ok(ToolResult::error(format!("Memory not found: {}", id))),
459        }
460    }
461
462    async fn execute_list(&self, args: Value) -> Result<ToolResult> {
463        let limit = args["limit"].as_u64().map(|v| v as usize).unwrap_or(10);
464
465        let results = {
466            let mut store = self.store.lock().await;
467            store.search(None, None, limit)
468        };
469
470        if results.is_empty() {
471            return Ok(ToolResult::success(
472                "No memories stored yet. Use 'save' to add your first memory.".to_string(),
473            ));
474        }
475
476        let output = results
477            .iter()
478            .enumerate()
479            .map(|(i, m)| {
480                format!(
481                    "{}. [{}] {} (importance: {}/5, accessed: {}x)",
482                    i + 1,
483                    m.id.chars().take(8).collect::<String>(),
484                    m.content.chars().take(60).collect::<String>()
485                        + if m.content.len() > 60 { "..." } else { "" },
486                    m.importance,
487                    m.access_count
488                )
489            })
490            .collect::<Vec<_>>()
491            .join("\n");
492
493        Ok(ToolResult::success(format!(
494            "Recent memories:\n\n{}",
495            output
496        )))
497    }
498
499    async fn execute_tags(&self, _args: Value) -> Result<ToolResult> {
500        let tags = {
501            let store = self.store.lock().await;
502            store.all_tags()
503        };
504
505        if tags.is_empty() {
506            return Ok(ToolResult::success(
507                "No tags yet. Add tags when saving memories.".to_string(),
508            ));
509        }
510
511        let mut sorted: Vec<_> = tags.iter().collect();
512        sorted.sort_by(|a, b| b.1.cmp(a.1));
513
514        let output = sorted
515            .iter()
516            .map(|(tag, count)| format!("  {} ({} memories)", tag, count))
517            .collect::<Vec<_>>()
518            .join("\n");
519
520        Ok(ToolResult::success(format!(
521            "Available tags:\n\n{}",
522            output
523        )))
524    }
525
526    async fn execute_delete(&self, args: Value) -> Result<ToolResult> {
527        let id = args["id"]
528            .as_str()
529            .ok_or_else(|| anyhow::anyhow!("id is required for 'delete' action"))?;
530
531        let deleted = {
532            let mut store = self.store.lock().await;
533            store.delete(id)
534        };
535
536        if deleted {
537            self.persist().await?;
538            Ok(ToolResult::success(format!("Memory deleted: {}", id)))
539        } else {
540            Ok(ToolResult::error(format!("Memory not found: {}", id)))
541        }
542    }
543
544    async fn execute_stats(&self, _args: Value) -> Result<ToolResult> {
545        let stats = {
546            let store = self.store.lock().await;
547            store.stats()
548        };
549
550        let tags_output = if stats.tags.is_empty() {
551            "None".to_string()
552        } else {
553            let mut sorted: Vec<_> = stats.tags.iter().collect();
554            sorted.sort_by(|a, b| b.1.cmp(a.1));
555            sorted
556                .iter()
557                .take(10)
558                .map(|(t, c)| format!("  {}: {}", t, c))
559                .collect::<Vec<_>>()
560                .join("\n")
561        };
562
563        Ok(ToolResult::success(format!(
564            "Memory Statistics:\n\n\
565             Total entries: {}\n\
566             Total accesses: {}\n\
567             Unique tags: {}\n\n\
568             Top tags:\n{}",
569            stats.total_entries, stats.total_accesses, stats.unique_tags, tags_output
570        )))
571    }
572}
573
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::Ordering;

    /// Saves one memory, then checks that `list` shows it and `stats` counts it.
    ///
    /// NOTE: the `save` action persists through `MemoryStore::default_path()`,
    /// so this test still writes the real memory file even though loading is
    /// skipped.
    #[tokio::test]
    async fn test_memory_save_and_get() {
        let tool = MemoryTool::new();
        // Pretend the store is already loaded so nothing is read from disk.
        tool.initialized.store(true, Ordering::SeqCst);

        // Store a single entry.
        let save_args = json!({
            "action": "save",
            "content": "Test memory content",
            "tags": ["test", "example"],
            "importance": 4
        });
        let saved = tool.execute(save_args).await.unwrap();
        assert!(saved.success);

        // The entry shows up in the listing.
        let list_args = json!({
            "action": "list",
            "limit": 5
        });
        let listed = tool.execute(list_args).await.unwrap();
        assert!(listed.success);
        assert!(listed.output.contains("Test memory content"));

        // The statistics count exactly that one entry.
        let stats_args = json!({
            "action": "stats"
        });
        let stats = tool.execute(stats_args).await.unwrap();
        assert!(stats.success);
        assert!(stats.output.contains("Total entries: 1"));
    }

    /// Saves two memories with different tags and checks that a tag search
    /// returns only the matching one.
    ///
    /// NOTE: the `save` action persists through `MemoryStore::default_path()`,
    /// so this test still writes the real memory file even though loading is
    /// skipped.
    #[tokio::test]
    async fn test_memory_search() {
        let tool = MemoryTool::new();
        // Pretend the store is already loaded so nothing is read from disk.
        tool.initialized.store(true, Ordering::SeqCst);

        // Two entries sharing one tag and differing in the other.
        let rust_entry = json!({
            "action": "save",
            "content": "Rust programming insights",
            "tags": ["rust", "programming"]
        });
        tool.execute(rust_entry).await.unwrap();

        let python_entry = json!({
            "action": "save",
            "content": "Python tips",
            "tags": ["python", "programming"]
        });
        tool.execute(python_entry).await.unwrap();

        // Filtering on the "rust" tag must exclude the Python entry.
        let search_args = json!({
            "action": "search",
            "tags": ["rust"]
        });
        let found = tool.execute(search_args).await.unwrap();

        assert!(found.success);
        assert!(found.output.contains("Rust"));
        assert!(!found.output.contains("Python"));
    }
}
658}