Skip to main content

codetether_agent/tool/
memory.rs

1//! Memory tool: Persistent knowledge capture and retrieval
2//!
3//! Allows agents to store important insights, learnings, and decisions
4//! that persist across sessions for future reference.
5
6use super::{Tool, ToolResult};
7use anyhow::Result;
8use async_trait::async_trait;
9use chrono::{DateTime, Utc};
10use serde::{Deserialize, Serialize};
11use serde_json::{Value, json};
12use std::collections::HashMap;
13use std::path::PathBuf;
14use tokio::fs;
15
/// A single memory entry as stored in the persisted memory file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryEntry {
    /// Unique identifier; [`MemoryEntry::new`] assigns a random v4 UUID string.
    pub id: String,
    /// The memory content
    pub content: String,
    /// Free-form tags used for categorization and matched by tag/query search.
    pub tags: Vec<String>,
    /// When this memory was created (UTC).
    pub created_at: DateTime<Utc>,
    /// When this memory was last accessed (UTC); refreshed by [`MemoryEntry::touch`].
    pub accessed_at: DateTime<Utc>,
    /// How many times this memory has been accessed; incremented by [`MemoryEntry::touch`].
    pub access_count: u64,
    /// Optional project/context scope.
    pub scope: Option<String>,
    /// Source of the memory (session_id, tool, etc.)
    pub source: Option<String>,
    /// Importance level, nominally 1-5 (higher sorts first in search; `new` defaults to 3).
    pub importance: u8,
}
38
39impl MemoryEntry {
40    pub fn new(content: impl Into<String>, tags: Vec<String>) -> Self {
41        let now = Utc::now();
42        Self {
43            id: uuid::Uuid::new_v4().to_string(),
44            content: content.into(),
45            tags,
46            created_at: now,
47            accessed_at: now,
48            access_count: 0,
49            scope: None,
50            source: None,
51            importance: 3, // default medium importance
52        }
53    }
54
55    pub fn with_scope(mut self, scope: impl Into<String>) -> Self {
56        self.scope = Some(scope.into());
57        self
58    }
59
60    pub fn with_source(mut self, source: impl Into<String>) -> Self {
61        self.source = Some(source.into());
62        self
63    }
64
65    pub fn with_importance(mut self, importance: u8) -> Self {
66        self.importance = importance.min(5);
67        self
68    }
69
70    pub fn touch(&mut self) {
71        self.accessed_at = Utc::now();
72        self.access_count += 1;
73    }
74}
75
/// Memory store for persistence: the full set of entries, serialized as one JSON document.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct MemoryStore {
    /// Entries keyed by their `id`.
    entries: HashMap<String, MemoryEntry>,
}
81
82impl MemoryStore {
83    /// Get the default memory file path
84    pub fn default_path() -> std::path::PathBuf {
85        directories::ProjectDirs::from("com", "codetether", "codetether")
86            .map(|p| p.data_dir().join("memory.json"))
87            .unwrap_or_else(|| PathBuf::from(".codetether/memory.json"))
88    }
89
90    /// Load from disk
91    pub async fn load() -> Result<Self> {
92        let path = Self::default_path();
93        if !path.exists() {
94            return Ok(Self::default());
95        }
96        let content = fs::read_to_string(&path).await?;
97        let store: MemoryStore = serde_json::from_str(&content)?;
98        Ok(store)
99    }
100
101    /// Save to disk
102    pub async fn save(&self) -> Result<()> {
103        let path = Self::default_path();
104        if let Some(parent) = path.parent() {
105            fs::create_dir_all(parent).await?;
106        }
107        let content = serde_json::to_string_pretty(self)?;
108        fs::write(&path, content).await?;
109        Ok(())
110    }
111
112    /// Add a new memory
113    pub fn add(&mut self, entry: MemoryEntry) -> String {
114        let id = entry.id.clone();
115        self.entries.insert(id.clone(), entry);
116        id
117    }
118
119    /// Get a memory by ID
120    pub fn get(&mut self, id: &str) -> Option<MemoryEntry> {
121        let entry = self.entries.get_mut(id)?;
122        entry.touch();
123        Some(entry.clone())
124    }
125
126    /// Search memories by query or tags
127    pub fn search(
128        &mut self,
129        query: Option<&str>,
130        tags: Option<&[String]>,
131        limit: usize,
132    ) -> Vec<MemoryEntry> {
133        let mut results: Vec<MemoryEntry> = self
134            .entries
135            .values_mut()
136            .filter(|entry| {
137                // Filter by tags if provided
138                if let Some(search_tags) = tags {
139                    if !search_tags.is_empty()
140                        && !search_tags.iter().any(|t| entry.tags.contains(t))
141                    {
142                        return false;
143                    }
144                }
145
146                // Filter by query if provided
147                if let Some(q) = query {
148                    let q_lower = q.to_lowercase();
149                    let matches_content = entry.content.to_lowercase().contains(&q_lower);
150                    let matches_tags = entry
151                        .tags
152                        .iter()
153                        .any(|t| t.to_lowercase().contains(&q_lower));
154                    if !matches_content && !matches_tags {
155                        return false;
156                    }
157                }
158
159                true
160            })
161            .map(|e| {
162                e.touch();
163                e.clone()
164            })
165            .collect();
166
167        // Sort by importance (descending) then access_count (descending)
168        results.sort_by(|a, b| {
169            b.importance
170                .cmp(&a.importance)
171                .then_with(|| b.access_count.cmp(&a.access_count))
172        });
173
174        results.truncate(limit);
175        results
176    }
177
178    /// Get all tags with counts
179    pub fn all_tags(&self) -> HashMap<String, u64> {
180        let mut tags: HashMap<String, u64> = HashMap::new();
181        for entry in self.entries.values() {
182            for tag in &entry.tags {
183                *tags.entry(tag.clone()).or_insert(0) += 1;
184            }
185        }
186        tags
187    }
188
189    /// Delete a memory
190    pub fn delete(&mut self, id: &str) -> bool {
191        self.entries.remove(id).is_some()
192    }
193
194    /// Get statistics
195    pub fn stats(&self) -> MemoryStats {
196        let total = self.entries.len();
197        let total_accesses: u64 = self.entries.values().map(|e| e.access_count).sum();
198        let tags = self.all_tags();
199        MemoryStats {
200            total_entries: total,
201            total_accesses,
202            unique_tags: tags.len(),
203            tags,
204        }
205    }
206}
207
/// Aggregate statistics over a memory store, as returned by `MemoryStore::stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryStats {
    /// Number of stored entries.
    pub total_entries: usize,
    /// Sum of `access_count` over all entries.
    pub total_accesses: u64,
    /// Number of distinct tags (the length of `tags`).
    pub unique_tags: usize,
    /// Occurrence count of each tag across all entries.
    pub tags: HashMap<String, u64>,
}
215
/// Memory Tool - Store and retrieve persistent knowledge
///
/// Keeps an in-memory copy of the memory store behind an async (`tokio`) mutex.
pub struct MemoryTool {
    /// In-memory copy of the persisted store.
    store: tokio::sync::Mutex<MemoryStore>,
}
220
221impl Default for MemoryTool {
222    fn default() -> Self {
223        Self::new()
224    }
225}
226
227impl MemoryTool {
228    pub fn new() -> Self {
229        Self {
230            store: tokio::sync::Mutex::new(MemoryStore::default()),
231        }
232    }
233
234    /// Initialize store from disk
235    pub async fn init(&self) -> Result<()> {
236        let mut store = self.store.lock().await;
237        *store = MemoryStore::load().await?;
238        Ok(())
239    }
240
241    /// Persist store to disk
242    pub async fn persist(&self) -> Result<()> {
243        let store = self.store.lock().await;
244        store.save().await
245    }
246}
247
248#[async_trait]
249impl Tool for MemoryTool {
250    fn id(&self) -> &str {
251        "memory"
252    }
253
254    fn name(&self) -> &str {
255        "Memory"
256    }
257
258    fn description(&self) -> &str {
259        "Store and retrieve persistent knowledge across sessions. Use 'save' to capture important insights, 'search' to find relevant memories, 'list' to see all entries, 'tags' to see available categories, or 'delete' to remove an entry."
260    }
261
262    fn parameters(&self) -> Value {
263        json!({
264            "type": "object",
265            "properties": {
266                "action": {
267                    "type": "string",
268                    "description": "Action to perform: 'save' (store new memory), 'search' (find memories), 'get' (retrieve specific memory), 'list' (show recent), 'tags' (show categories), 'delete' (remove), 'stats' (show statistics)",
269                    "enum": ["save", "search", "get", "list", "tags", "delete", "stats"]
270                },
271                "content": {
272                    "type": "string",
273                    "description": "Memory content to save (required for 'save' action)"
274                },
275                "tags": {
276                    "type": "array",
277                    "items": {"type": "string"},
278                    "description": "Tags for categorization (optional for 'save')"
279                },
280                "query": {
281                    "type": "string",
282                    "description": "Search query (for 'search' action)"
283                },
284                "scope": {
285                    "type": "string",
286                    "description": "Project/context scope (optional for 'save')"
287                },
288                "importance": {
289                    "type": "integer",
290                    "description": "Importance level 1-5 (optional for 'save', default 3)"
291                },
292                "id": {
293                    "type": "string",
294                    "description": "Memory ID (required for 'get' and 'delete')"
295                },
296                "limit": {
297                    "type": "integer",
298                    "description": "Maximum results to return (default 10, for 'search' and 'list')"
299                }
300            },
301            "required": ["action"]
302        })
303    }
304
305    async fn execute(&self, args: Value) -> Result<ToolResult> {
306        // Ensure store is loaded
307        self.init().await.ok();
308
309        let action = args["action"]
310            .as_str()
311            .ok_or_else(|| anyhow::anyhow!("action is required"))?;
312
313        match action {
314            "save" => self.execute_save(args).await,
315            "search" => self.execute_search(args).await,
316            "get" => self.execute_get(args).await,
317            "list" => self.execute_list(args).await,
318            "tags" => self.execute_tags(args).await,
319            "delete" => self.execute_delete(args).await,
320            "stats" => self.execute_stats(args).await,
321            _ => Ok(ToolResult::error(format!(
322                "Unknown action: {}. Use 'save', 'search', 'get', 'list', 'tags', 'delete', or 'stats'.",
323                action
324            ))),
325        }
326    }
327}
328
329impl MemoryTool {
330    async fn execute_save(&self, args: Value) -> Result<ToolResult> {
331        let content = args["content"]
332            .as_str()
333            .ok_or_else(|| anyhow::anyhow!("content is required for 'save' action"))?;
334
335        let tags: Vec<String> = args["tags"]
336            .as_array()
337            .map(|arr| {
338                arr.iter()
339                    .filter_map(|v| v.as_str().map(String::from))
340                    .collect()
341            })
342            .unwrap_or_default();
343
344        let scope = args["scope"].as_str().map(String::from);
345        let importance = args["importance"].as_u64().map(|v| v as u8).unwrap_or(3);
346
347        let mut entry = MemoryEntry::new(content, tags).with_importance(importance);
348
349        if let Some(s) = scope {
350            entry = entry.with_scope(s);
351        }
352
353        let id = {
354            let mut store = self.store.lock().await;
355            store.add(entry)
356        };
357
358        // Persist to disk
359        self.persist().await?;
360
361        Ok(ToolResult::success(format!(
362            "Memory saved with ID: {}\nImportance: {}/5",
363            id, importance
364        )))
365    }
366
367    async fn execute_search(&self, args: Value) -> Result<ToolResult> {
368        let query = args["query"].as_str();
369        let tags: Option<Vec<String>> = args["tags"].as_array().map(|arr| {
370            arr.iter()
371                .filter_map(|v| v.as_str().map(String::from))
372                .collect()
373        });
374        let limit = args["limit"].as_u64().map(|v| v as usize).unwrap_or(10);
375
376        let tags_ref = tags.as_ref().map(|v| v.as_slice());
377
378        let results = {
379            let mut store = self.store.lock().await;
380            store.search(query, tags_ref, limit)
381        };
382
383        if results.is_empty() {
384            return Ok(ToolResult::success(
385                "No memories found matching your criteria.".to_string(),
386            ));
387        }
388
389        let output = results
390            .iter()
391            .enumerate()
392            .map(|(i, m)| {
393                format!(
394                    "{}. [{}] {} - {}\n   Tags: {}\n   Created: {}",
395                    i + 1,
396                    m.id.chars().take(8).collect::<String>(),
397                    m.content.chars().take(80).collect::<String>()
398                        + if m.content.len() > 80 { "..." } else { "" },
399                    format!("accessed {} times", m.access_count),
400                    m.tags.join(", "),
401                    m.created_at.format("%Y-%m-%d %H:%M")
402                )
403            })
404            .collect::<Vec<_>>()
405            .join("\n\n");
406
407        Ok(ToolResult::success(format!(
408            "Found {} memories:\n\n{}",
409            results.len(),
410            output
411        )))
412    }
413
414    async fn execute_get(&self, args: Value) -> Result<ToolResult> {
415        let id = args["id"]
416            .as_str()
417            .ok_or_else(|| anyhow::anyhow!("id is required for 'get' action"))?;
418
419        let entry = {
420            let mut store = self.store.lock().await;
421            store.get(id).map(|e| e.clone())
422        };
423
424        match entry {
425            Some(e) => {
426                // Persist the updated access count
427                self.persist().await?;
428
429                Ok(ToolResult::success(format!(
430                    "Memory ID: {}\nImportance: {}/5\nTags: {}\nCreated: {}\nAccessed: {} times\n\n{}",
431                    e.id,
432                    e.importance,
433                    e.tags.join(", "),
434                    e.created_at.format("%Y-%m-%d %H:%M:%S"),
435                    e.access_count,
436                    e.content
437                )))
438            }
439            None => Ok(ToolResult::error(format!("Memory not found: {}", id))),
440        }
441    }
442
443    async fn execute_list(&self, args: Value) -> Result<ToolResult> {
444        let limit = args["limit"].as_u64().map(|v| v as usize).unwrap_or(10);
445
446        let results = {
447            let mut store = self.store.lock().await;
448            store.search(None, None, limit)
449        };
450
451        if results.is_empty() {
452            return Ok(ToolResult::success(
453                "No memories stored yet. Use 'save' to add your first memory.".to_string(),
454            ));
455        }
456
457        let output = results
458            .iter()
459            .enumerate()
460            .map(|(i, m)| {
461                format!(
462                    "{}. [{}] {} (importance: {}/5, accessed: {}x)",
463                    i + 1,
464                    m.id.chars().take(8).collect::<String>(),
465                    m.content.chars().take(60).collect::<String>()
466                        + if m.content.len() > 60 { "..." } else { "" },
467                    m.importance,
468                    m.access_count
469                )
470            })
471            .collect::<Vec<_>>()
472            .join("\n");
473
474        Ok(ToolResult::success(format!(
475            "Recent memories:\n\n{}",
476            output
477        )))
478    }
479
480    async fn execute_tags(&self, _args: Value) -> Result<ToolResult> {
481        let tags = {
482            let store = self.store.lock().await;
483            store.all_tags()
484        };
485
486        if tags.is_empty() {
487            return Ok(ToolResult::success(
488                "No tags yet. Add tags when saving memories.".to_string(),
489            ));
490        }
491
492        let mut sorted: Vec<_> = tags.iter().collect();
493        sorted.sort_by(|a, b| b.1.cmp(a.1));
494
495        let output = sorted
496            .iter()
497            .map(|(tag, count)| format!("  {} ({} memories)", tag, count))
498            .collect::<Vec<_>>()
499            .join("\n");
500
501        Ok(ToolResult::success(format!(
502            "Available tags:\n\n{}",
503            output
504        )))
505    }
506
507    async fn execute_delete(&self, args: Value) -> Result<ToolResult> {
508        let id = args["id"]
509            .as_str()
510            .ok_or_else(|| anyhow::anyhow!("id is required for 'delete' action"))?;
511
512        let deleted = {
513            let mut store = self.store.lock().await;
514            store.delete(id)
515        };
516
517        if deleted {
518            self.persist().await?;
519            Ok(ToolResult::success(format!("Memory deleted: {}", id)))
520        } else {
521            Ok(ToolResult::error(format!("Memory not found: {}", id)))
522        }
523    }
524
525    async fn execute_stats(&self, _args: Value) -> Result<ToolResult> {
526        let stats = {
527            let store = self.store.lock().await;
528            store.stats()
529        };
530
531        let tags_output = if stats.tags.is_empty() {
532            "None".to_string()
533        } else {
534            let mut sorted: Vec<_> = stats.tags.iter().collect();
535            sorted.sort_by(|a, b| b.1.cmp(a.1));
536            sorted
537                .iter()
538                .take(10)
539                .map(|(t, c)| format!("  {}: {}", t, c))
540                .collect::<Vec<_>>()
541                .join("\n")
542        };
543
544        Ok(ToolResult::success(format!(
545            "Memory Statistics:\n\n\
546             Total entries: {}\n\
547             Total accesses: {}\n\
548             Unique tags: {}\n\n\
549             Top tags:\n{}",
550            stats.total_entries, stats.total_accesses, stats.unique_tags, tags_output
551        )))
552    }
553}
554
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::OnceLock;

    // These tests go through `MemoryTool::execute`, which reloads and rewrites the real store
    // file at `MemoryStore::default_path()`. That file may already hold entries from earlier
    // runs or real use, and it is shared between tests that run in parallel. The tests
    // therefore:
    // - serialize on a shared lock, so one test's reload cannot drop the other's entry;
    // - use unique content and tags, and compare counts before/after instead of assuming an
    //   empty store;
    // - delete what they create.

    /// Lock serializing the tests that touch the on-disk store.
    fn disk_lock() -> &'static tokio::sync::Mutex<()> {
        static LOCK: OnceLock<tokio::sync::Mutex<()>> = OnceLock::new();
        LOCK.get_or_init(|| tokio::sync::Mutex::new(()))
    }

    /// Runs one action and asserts it succeeded.
    async fn run(tool: &MemoryTool, args: Value) -> ToolResult {
        let result = tool.execute(args).await.expect("execute should not error");
        assert!(result.success, "action failed: {}", result.output);
        result
    }

    /// Extracts the full id from a successful 'save' result.
    fn saved_id(output: &str) -> String {
        output
            .lines()
            .find_map(|line| line.strip_prefix("Memory saved with ID: "))
            .expect("save output should contain the new id")
            .to_string()
    }

    /// Parses the "Total entries" figure from a 'stats' result.
    fn total_entries(output: &str) -> usize {
        output
            .lines()
            .find_map(|line| line.trim().strip_prefix("Total entries: "))
            .and_then(|n| n.trim().parse().ok())
            .expect("stats output should report total entries")
    }

    #[tokio::test]
    async fn test_memory_save_and_get() {
        let _guard = disk_lock().lock().await;
        let tool = MemoryTool::new();
        let content = format!("Test memory content {}", uuid::Uuid::new_v4());

        let before = total_entries(&run(&tool, json!({ "action": "stats" })).await.output);

        let saved = run(
            &tool,
            json!({
                "action": "save",
                "content": content.as_str(),
                "tags": ["test", "example"],
                "importance": 4
            }),
        )
        .await;
        assert!(saved.output.contains("Importance: 4/5"));
        let id = saved_id(&saved.output);

        let fetched = run(&tool, json!({ "action": "get", "id": id.as_str() })).await;
        assert!(fetched.output.contains(content.as_str()));
        assert!(fetched.output.contains("Importance: 4/5"));

        let after = total_entries(&run(&tool, json!({ "action": "stats" })).await.output);
        assert_eq!(after, before + 1);

        run(&tool, json!({ "action": "delete", "id": id.as_str() })).await;
    }

    #[tokio::test]
    async fn test_memory_search() {
        let _guard = disk_lock().lock().await;
        let tool = MemoryTool::new();
        let run_id = uuid::Uuid::new_v4();
        let rust_tag = format!("rust-{run_id}");
        let python_tag = format!("python-{run_id}");
        let shared_tag = format!("programming-{run_id}");

        let rust_id = saved_id(
            &run(
                &tool,
                json!({
                    "action": "save",
                    "content": format!("Rust programming insights {run_id}"),
                    "tags": [rust_tag.as_str(), shared_tag.as_str()]
                }),
            )
            .await
            .output,
        );
        let python_id = saved_id(
            &run(
                &tool,
                json!({
                    "action": "save",
                    "content": format!("Python tips {run_id}"),
                    "tags": [python_tag.as_str(), shared_tag.as_str()]
                }),
            )
            .await
            .output,
        );

        // Search by a tag only the Rust entry carries.
        let result = run(&tool, json!({ "action": "search", "tags": [rust_tag.as_str()] })).await;
        assert!(result.output.contains("Found 1 memories"));
        assert!(result.output.contains("Rust programming insights"));
        assert!(!result.output.contains("Python tips"));

        // Search by the shared tag finds both.
        let result = run(&tool, json!({ "action": "search", "tags": [shared_tag.as_str()] })).await;
        assert!(result.output.contains("Found 2 memories"));
        assert!(result.output.contains("Rust programming insights"));
        assert!(result.output.contains("Python tips"));

        run(&tool, json!({ "action": "delete", "id": rust_id.as_str() })).await;
        run(&tool, json!({ "action": "delete", "id": python_id.as_str() })).await;
    }
}
634}