Skip to main content

codetether_agent/tool/
memory.rs

//! Memory tool: persistent knowledge capture and retrieval.
//!
//! Allows agents to store important insights, learnings, and decisions
//! that persist across sessions for future reference.
5
6mod git_scope;
7mod scope;
8
9use super::{Tool, ToolResult};
10use anyhow::Result;
11use async_trait::async_trait;
12use chrono::{DateTime, Utc};
13use serde::{Deserialize, Serialize};
14use serde_json::{Value, json};
15use std::collections::HashMap;
16use std::path::PathBuf;
17use tokio::fs;
18
19/// A single memory entry
20#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct MemoryEntry {
22    /// Unique identifier
23    pub id: String,
24    /// The memory content
25    pub content: String,
26    /// Tags for categorization/search
27    pub tags: Vec<String>,
28    /// When this memory was created
29    pub created_at: DateTime<Utc>,
30    /// When this memory was last accessed
31    pub accessed_at: DateTime<Utc>,
32    /// How many times this memory has been accessed
33    pub access_count: u64,
34    /// Optional project/context scope
35    pub scope: Option<String>,
36    /// Source of the memory (session_id, tool, etc.)
37    pub source: Option<String>,
38    /// Importance level (1-5)
39    pub importance: u8,
40}
41
42impl MemoryEntry {
43    pub fn new(content: impl Into<String>, tags: Vec<String>) -> Self {
44        let now = Utc::now();
45        Self {
46            id: uuid::Uuid::new_v4().to_string(),
47            content: content.into(),
48            tags,
49            created_at: now,
50            accessed_at: now,
51            access_count: 0,
52            scope: None,
53            source: None,
54            importance: 3, // default medium importance
55        }
56    }
57
58    pub fn with_scope(mut self, scope: impl Into<String>) -> Self {
59        self.scope = Some(scope.into());
60        self
61    }
62
63    /// Set the source of this memory entry
64    #[allow(dead_code)]
65    pub fn with_source(mut self, source: impl Into<String>) -> Self {
66        self.source = Some(source.into());
67        self
68    }
69
70    pub fn with_importance(mut self, importance: u8) -> Self {
71        self.importance = importance.min(5);
72        self
73    }
74
75    pub fn touch(&mut self) {
76        self.accessed_at = Utc::now();
77        self.access_count += 1;
78    }
79}
80
81/// Memory store for persistence
82#[derive(Debug, Clone, Serialize, Deserialize, Default)]
83pub struct MemoryStore {
84    entries: HashMap<String, MemoryEntry>,
85}
86
87impl MemoryStore {
88    /// Get the default memory file path
89    pub fn default_path() -> std::path::PathBuf {
90        crate::config::Config::data_dir()
91            .map(|p| p.join("memory.json"))
92            .unwrap_or_else(|| PathBuf::from(".codetether-agent/memory.json"))
93    }
94
95    /// Load from disk
96    pub async fn load() -> Result<Self> {
97        let path = Self::default_path();
98        if !path.exists() {
99            return Ok(Self::default());
100        }
101        let content = fs::read_to_string(&path).await?;
102        let store: MemoryStore = serde_json::from_str(&content)?;
103        Ok(store)
104    }
105
106    /// Save to disk
107    pub async fn save(&self) -> Result<()> {
108        let path = Self::default_path();
109        if let Some(parent) = path.parent() {
110            fs::create_dir_all(parent).await?;
111        }
112        let content = serde_json::to_string_pretty(self)?;
113        fs::write(&path, content).await?;
114        Ok(())
115    }
116
117    /// Add a new memory
118    pub fn add(&mut self, entry: MemoryEntry) -> String {
119        let id = entry.id.clone();
120        self.entries.insert(id.clone(), entry);
121        id
122    }
123
124    /// Get a memory by ID
125    pub fn get(&mut self, id: &str) -> Option<MemoryEntry> {
126        let id = self.resolve_id(id)?;
127        let entry = self.entries.get_mut(&id)?;
128        entry.touch();
129        Some(entry.clone())
130    }
131
132    fn resolve_id(&self, id: &str) -> Option<String> {
133        if self.entries.contains_key(id) {
134            return Some(id.to_string());
135        }
136        let mut matches = self.entries.keys().filter(|key| key.starts_with(id));
137        let found = matches.next()?.clone();
138        matches.next().is_none().then_some(found)
139    }
140
141    /// Search memories by query or tags
142    pub fn search(
143        &mut self,
144        query: Option<&str>,
145        tags: Option<&[String]>,
146        scope: Option<&str>,
147        limit: usize,
148    ) -> Vec<MemoryEntry> {
149        let mut results: Vec<MemoryEntry> = self
150            .entries
151            .values_mut()
152            .filter(|entry| {
153                // Filter by tags if provided
154                if let Some(search_tags) = tags
155                    && !search_tags.is_empty()
156                    && !search_tags.iter().any(|t| entry.tags.contains(t))
157                {
158                    return false;
159                }
160                if let Some(scope) = scope
161                    && entry.scope.as_deref() != Some(scope)
162                {
163                    return false;
164                }
165
166                // Filter by query if provided
167                if let Some(q) = query {
168                    let q_lower = q.to_lowercase();
169                    let matches_content = entry.content.to_lowercase().contains(&q_lower);
170                    let matches_tags = entry
171                        .tags
172                        .iter()
173                        .any(|t| t.to_lowercase().contains(&q_lower));
174                    if !matches_content && !matches_tags {
175                        return false;
176                    }
177                }
178
179                true
180            })
181            .map(|e| {
182                e.touch();
183                e.clone()
184            })
185            .collect();
186
187        // Sort by importance (descending) then access_count (descending)
188        results.sort_by(|a, b| {
189            b.importance
190                .cmp(&a.importance)
191                .then_with(|| b.access_count.cmp(&a.access_count))
192        });
193
194        results.truncate(limit);
195        results
196    }
197
198    /// Get all tags with counts
199    pub fn all_tags(&self) -> HashMap<String, u64> {
200        let mut tags: HashMap<String, u64> = HashMap::new();
201        for entry in self.entries.values() {
202            for tag in &entry.tags {
203                *tags.entry(tag.clone()).or_insert(0) += 1;
204            }
205        }
206        tags
207    }
208
209    /// Delete a memory
210    pub fn delete(&mut self, id: &str) -> bool {
211        self.resolve_id(id)
212            .and_then(|id| self.entries.remove(&id))
213            .is_some()
214    }
215
216    /// Get statistics
217    pub fn stats(&self) -> MemoryStats {
218        let total = self.entries.len();
219        let total_accesses: u64 = self.entries.values().map(|e| e.access_count).sum();
220        let tags = self.all_tags();
221        MemoryStats {
222            total_entries: total,
223            total_accesses,
224            unique_tags: tags.len(),
225            tags,
226        }
227    }
228}
229
230#[derive(Debug, Clone, Serialize, Deserialize)]
231pub struct MemoryStats {
232    pub total_entries: usize,
233    pub total_accesses: u64,
234    pub unique_tags: usize,
235    pub tags: HashMap<String, u64>,
236}
237
238/// Memory Tool - Store and retrieve persistent knowledge
239pub struct MemoryTool {
240    store: tokio::sync::Mutex<MemoryStore>,
241    initialized: std::sync::atomic::AtomicBool,
242}
243
244impl Default for MemoryTool {
245    fn default() -> Self {
246        Self::new()
247    }
248}
249
250impl MemoryTool {
251    pub fn new() -> Self {
252        Self {
253            store: tokio::sync::Mutex::new(MemoryStore::default()),
254            initialized: std::sync::atomic::AtomicBool::new(false),
255        }
256    }
257
258    /// Initialize store from disk (once)
259    pub async fn init(&self) -> Result<()> {
260        use std::sync::atomic::Ordering;
261
262        if self.initialized.load(Ordering::SeqCst) {
263            return Ok(());
264        }
265
266        let mut store = self.store.lock().await;
267        if let Ok(loaded) = MemoryStore::load().await {
268            *store = loaded;
269        }
270        self.initialized.store(true, Ordering::SeqCst);
271        Ok(())
272    }
273
274    /// Persist store to disk
275    pub async fn persist(&self) -> Result<()> {
276        let store = self.store.lock().await;
277        store.save().await
278    }
279}
280
281#[async_trait]
282impl Tool for MemoryTool {
283    fn id(&self) -> &str {
284        "memory"
285    }
286
287    fn name(&self) -> &str {
288        "Memory"
289    }
290
291    fn description(&self) -> &str {
292        "Store and retrieve persistent knowledge across sessions. Use 'save' to capture important insights, 'search' to find relevant memories, 'list' to see all entries, 'tags' to see available categories, or 'delete' to remove an entry."
293    }
294
295    fn parameters(&self) -> Value {
296        json!({
297            "type": "object",
298            "properties": {
299                "action": {
300                    "type": "string",
301                    "description": "Action to perform: 'save' (store new memory), 'search' (find memories), 'get' (retrieve specific memory), 'list' (show recent), 'tags' (show categories), 'delete' (remove), 'stats' (show statistics)",
302                    "enum": ["save", "search", "get", "list", "tags", "delete", "stats"]
303                },
304                "content": {
305                    "type": "string",
306                    "description": "Memory content to save (required for 'save' action)"
307                },
308                "tags": {
309                    "type": "array",
310                    "items": {"type": "string"},
311                    "description": "Tags for categorization (optional for 'save')"
312                },
313                "query": {
314                    "type": "string",
315                    "description": "Search query (for 'search' action)"
316                },
317                "scope": {
318                    "type": "string",
319                    "description": "Project/context scope. Defaults to a stable git scope, not the transient worktree path. Use 'all' for unscoped search/list."
320                },
321                "importance": {
322                    "type": "integer",
323                    "description": "Importance level 1-5 (optional for 'save', default 3)"
324                },
325                "id": {
326                    "type": "string",
327                    "description": "Memory ID (required for 'get' and 'delete')"
328                },
329                "limit": {
330                    "type": "integer",
331                    "description": "Maximum results to return (default 10, for 'search' and 'list')"
332                }
333            },
334            "required": ["action"]
335        })
336    }
337
338    async fn execute(&self, args: Value) -> Result<ToolResult> {
339        // Initialize store from disk if not already loaded
340        // Use a flag to avoid reloading on every call
341        let needs_init = {
342            let store = self.store.lock().await;
343            store.entries.is_empty()
344        };
345
346        if needs_init {
347            self.init().await.ok();
348        }
349
350        let action = args["action"]
351            .as_str()
352            .ok_or_else(|| anyhow::anyhow!("action is required"))?;
353
354        match action {
355            "save" => self.execute_save(args).await,
356            "search" => self.execute_search(args).await,
357            "get" => self.execute_get(args).await,
358            "list" => self.execute_list(args).await,
359            "tags" => self.execute_tags(args).await,
360            "delete" => self.execute_delete(args).await,
361            "stats" => self.execute_stats(args).await,
362            _ => Ok(ToolResult::error(format!(
363                "Unknown action: {}. Use 'save', 'search', 'get', 'list', 'tags', 'delete', or 'stats'.",
364                action
365            ))),
366        }
367    }
368}
369
370impl MemoryTool {
371    async fn execute_save(&self, args: Value) -> Result<ToolResult> {
372        let content = args["content"]
373            .as_str()
374            .ok_or_else(|| anyhow::anyhow!("content is required for 'save' action"))?;
375
376        let tags: Vec<String> = args["tags"]
377            .as_array()
378            .map(|arr| {
379                arr.iter()
380                    .filter_map(|v| v.as_str().map(String::from))
381                    .collect()
382            })
383            .unwrap_or_default();
384
385        let scope = scope::save(&args);
386        let importance = args["importance"].as_u64().map(|v| v as u8).unwrap_or(3);
387
388        let mut entry = MemoryEntry::new(content, tags).with_importance(importance);
389
390        if let Some(s) = scope {
391            entry = entry.with_scope(s);
392        }
393
394        let id = {
395            let mut store = self.store.lock().await;
396            store.add(entry)
397        };
398
399        // Persist to disk
400        self.persist().await?;
401
402        Ok(ToolResult::success(format!(
403            "Memory saved with ID: {}\nImportance: {}/5",
404            id, importance
405        )))
406    }
407
408    async fn execute_search(&self, args: Value) -> Result<ToolResult> {
409        let query = args["query"].as_str();
410        let tags: Option<Vec<String>> = args["tags"].as_array().map(|arr| {
411            arr.iter()
412                .filter_map(|v| v.as_str().map(String::from))
413                .collect()
414        });
415        let limit = args["limit"].as_u64().map(|v| v as usize).unwrap_or(10);
416        let scope = scope::search(&args);
417
418        let tags_ref = tags.as_deref();
419
420        let results = {
421            let mut store = self.store.lock().await;
422            store.search(query, tags_ref, scope.as_deref(), limit)
423        };
424
425        if results.is_empty() {
426            return Ok(ToolResult::success(
427                "No memories found matching your criteria.".to_string(),
428            ));
429        }
430
431        let output = results
432            .iter()
433            .enumerate()
434            .map(|(i, m)| {
435                format!(
436                    "{}. [{}] {} - {}\n   Tags: {}\n   Created: {}",
437                    i + 1,
438                    m.id,
439                    m.content.chars().take(80).collect::<String>()
440                        + if m.content.len() > 80 { "..." } else { "" },
441                    format!("accessed {} times", m.access_count),
442                    m.tags.join(", "),
443                    m.created_at.format("%Y-%m-%d %H:%M")
444                )
445            })
446            .collect::<Vec<_>>()
447            .join("\n\n");
448
449        Ok(ToolResult::success(format!(
450            "Found {} memories:\n\n{}",
451            results.len(),
452            output
453        )))
454    }
455
456    async fn execute_get(&self, args: Value) -> Result<ToolResult> {
457        let id = args["id"]
458            .as_str()
459            .ok_or_else(|| anyhow::anyhow!("id is required for 'get' action"))?;
460
461        let entry = {
462            let mut store = self.store.lock().await;
463            store.get(id)
464        };
465
466        match entry {
467            Some(e) => {
468                // Persist the updated access count
469                self.persist().await?;
470
471                Ok(ToolResult::success(format!(
472                    "Memory ID: {}\nImportance: {}/5\nTags: {}\nCreated: {}\nAccessed: {} times\n\n{}",
473                    e.id,
474                    e.importance,
475                    e.tags.join(", "),
476                    e.created_at.format("%Y-%m-%d %H:%M:%S"),
477                    e.access_count,
478                    e.content
479                )))
480            }
481            None => Ok(ToolResult::error(format!("Memory not found: {}", id))),
482        }
483    }
484
485    async fn execute_list(&self, args: Value) -> Result<ToolResult> {
486        let limit = args["limit"].as_u64().map(|v| v as usize).unwrap_or(10);
487        let scope = scope::search(&args);
488
489        let results = {
490            let mut store = self.store.lock().await;
491            store.search(None, None, scope.as_deref(), limit)
492        };
493
494        if results.is_empty() {
495            return Ok(ToolResult::success(
496                "No memories stored yet. Use 'save' to add your first memory.".to_string(),
497            ));
498        }
499
500        let output = results
501            .iter()
502            .enumerate()
503            .map(|(i, m)| {
504                format!(
505                    "{}. [{}] {} (importance: {}/5, accessed: {}x)",
506                    i + 1,
507                    m.id,
508                    m.content.chars().take(60).collect::<String>()
509                        + if m.content.len() > 60 { "..." } else { "" },
510                    m.importance,
511                    m.access_count
512                )
513            })
514            .collect::<Vec<_>>()
515            .join("\n");
516
517        Ok(ToolResult::success(format!(
518            "Recent memories:\n\n{}",
519            output
520        )))
521    }
522
523    async fn execute_tags(&self, _args: Value) -> Result<ToolResult> {
524        let tags = {
525            let store = self.store.lock().await;
526            store.all_tags()
527        };
528
529        if tags.is_empty() {
530            return Ok(ToolResult::success(
531                "No tags yet. Add tags when saving memories.".to_string(),
532            ));
533        }
534
535        let mut sorted: Vec<_> = tags.iter().collect();
536        sorted.sort_by(|a, b| b.1.cmp(a.1));
537
538        let output = sorted
539            .iter()
540            .map(|(tag, count)| format!("  {} ({} memories)", tag, count))
541            .collect::<Vec<_>>()
542            .join("\n");
543
544        Ok(ToolResult::success(format!(
545            "Available tags:\n\n{}",
546            output
547        )))
548    }
549
550    async fn execute_delete(&self, args: Value) -> Result<ToolResult> {
551        let id = args["id"]
552            .as_str()
553            .ok_or_else(|| anyhow::anyhow!("id is required for 'delete' action"))?;
554
555        let deleted = {
556            let mut store = self.store.lock().await;
557            store.delete(id)
558        };
559
560        if deleted {
561            self.persist().await?;
562            Ok(ToolResult::success(format!("Memory deleted: {}", id)))
563        } else {
564            Ok(ToolResult::error(format!("Memory not found: {}", id)))
565        }
566    }
567
568    async fn execute_stats(&self, _args: Value) -> Result<ToolResult> {
569        let stats = {
570            let store = self.store.lock().await;
571            store.stats()
572        };
573
574        let tags_output = if stats.tags.is_empty() {
575            "None".to_string()
576        } else {
577            let mut sorted: Vec<_> = stats.tags.iter().collect();
578            sorted.sort_by(|a, b| b.1.cmp(a.1));
579            sorted
580                .iter()
581                .take(10)
582                .map(|(t, c)| format!("  {}: {}", t, c))
583                .collect::<Vec<_>>()
584                .join("\n")
585        };
586
587        Ok(ToolResult::success(format!(
588            "Memory Statistics:\n\n\
589             Total entries: {}\n\
590             Total accesses: {}\n\
591             Unique tags: {}\n\n\
592             Top tags:\n{}",
593            stats.total_entries, stats.total_accesses, stats.unique_tags, tags_output
594        )))
595    }
596}
597
598#[cfg(test)]
599mod tests;