1use super::{Tool, ToolResult};
7use anyhow::Result;
8use async_trait::async_trait;
9use chrono::{DateTime, Utc};
10use serde::{Deserialize, Serialize};
11use serde_json::{Value, json};
12use std::collections::HashMap;
13use std::path::PathBuf;
14use tokio::fs;
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct MemoryEntry {
19 pub id: String,
21 pub content: String,
23 pub tags: Vec<String>,
25 pub created_at: DateTime<Utc>,
27 pub accessed_at: DateTime<Utc>,
29 pub access_count: u64,
31 pub scope: Option<String>,
33 pub source: Option<String>,
35 pub importance: u8,
37}
38
39impl MemoryEntry {
40 pub fn new(content: impl Into<String>, tags: Vec<String>) -> Self {
41 let now = Utc::now();
42 Self {
43 id: uuid::Uuid::new_v4().to_string(),
44 content: content.into(),
45 tags,
46 created_at: now,
47 accessed_at: now,
48 access_count: 0,
49 scope: None,
50 source: None,
51 importance: 3, }
53 }
54
55 pub fn with_scope(mut self, scope: impl Into<String>) -> Self {
56 self.scope = Some(scope.into());
57 self
58 }
59
60 pub fn with_source(mut self, source: impl Into<String>) -> Self {
61 self.source = Some(source.into());
62 self
63 }
64
65 pub fn with_importance(mut self, importance: u8) -> Self {
66 self.importance = importance.min(5);
67 self
68 }
69
70 pub fn touch(&mut self) {
71 self.accessed_at = Utc::now();
72 self.access_count += 1;
73 }
74}
75
76#[derive(Debug, Clone, Serialize, Deserialize, Default)]
78pub struct MemoryStore {
79 entries: HashMap<String, MemoryEntry>,
80}
81
82impl MemoryStore {
83 pub fn default_path() -> std::path::PathBuf {
85 directories::ProjectDirs::from("com", "codetether", "codetether")
86 .map(|p| p.data_dir().join("memory.json"))
87 .unwrap_or_else(|| PathBuf::from(".codetether/memory.json"))
88 }
89
90 pub async fn load() -> Result<Self> {
92 let path = Self::default_path();
93 if !path.exists() {
94 return Ok(Self::default());
95 }
96 let content = fs::read_to_string(&path).await?;
97 let store: MemoryStore = serde_json::from_str(&content)?;
98 Ok(store)
99 }
100
101 pub async fn save(&self) -> Result<()> {
103 let path = Self::default_path();
104 if let Some(parent) = path.parent() {
105 fs::create_dir_all(parent).await?;
106 }
107 let content = serde_json::to_string_pretty(self)?;
108 fs::write(&path, content).await?;
109 Ok(())
110 }
111
112 pub fn add(&mut self, entry: MemoryEntry) -> String {
114 let id = entry.id.clone();
115 self.entries.insert(id.clone(), entry);
116 id
117 }
118
119 pub fn get(&mut self, id: &str) -> Option<MemoryEntry> {
121 let entry = self.entries.get_mut(id)?;
122 entry.touch();
123 Some(entry.clone())
124 }
125
126 pub fn search(
128 &mut self,
129 query: Option<&str>,
130 tags: Option<&[String]>,
131 limit: usize,
132 ) -> Vec<MemoryEntry> {
133 let mut results: Vec<MemoryEntry> = self
134 .entries
135 .values_mut()
136 .filter(|entry| {
137 if let Some(search_tags) = tags {
139 if !search_tags.is_empty()
140 && !search_tags.iter().any(|t| entry.tags.contains(t))
141 {
142 return false;
143 }
144 }
145
146 if let Some(q) = query {
148 let q_lower = q.to_lowercase();
149 let matches_content = entry.content.to_lowercase().contains(&q_lower);
150 let matches_tags = entry
151 .tags
152 .iter()
153 .any(|t| t.to_lowercase().contains(&q_lower));
154 if !matches_content && !matches_tags {
155 return false;
156 }
157 }
158
159 true
160 })
161 .map(|e| {
162 e.touch();
163 e.clone()
164 })
165 .collect();
166
167 results.sort_by(|a, b| {
169 b.importance
170 .cmp(&a.importance)
171 .then_with(|| b.access_count.cmp(&a.access_count))
172 });
173
174 results.truncate(limit);
175 results
176 }
177
178 pub fn all_tags(&self) -> HashMap<String, u64> {
180 let mut tags: HashMap<String, u64> = HashMap::new();
181 for entry in self.entries.values() {
182 for tag in &entry.tags {
183 *tags.entry(tag.clone()).or_insert(0) += 1;
184 }
185 }
186 tags
187 }
188
189 pub fn delete(&mut self, id: &str) -> bool {
191 self.entries.remove(id).is_some()
192 }
193
194 pub fn stats(&self) -> MemoryStats {
196 let total = self.entries.len();
197 let total_accesses: u64 = self.entries.values().map(|e| e.access_count).sum();
198 let tags = self.all_tags();
199 MemoryStats {
200 total_entries: total,
201 total_accesses,
202 unique_tags: tags.len(),
203 tags,
204 }
205 }
206}
207
208#[derive(Debug, Clone, Serialize, Deserialize)]
209pub struct MemoryStats {
210 pub total_entries: usize,
211 pub total_accesses: u64,
212 pub unique_tags: usize,
213 pub tags: HashMap<String, u64>,
214}
215
216pub struct MemoryTool {
218 store: tokio::sync::Mutex<MemoryStore>,
219 initialized: std::sync::atomic::AtomicBool,
220}
221
222impl Default for MemoryTool {
223 fn default() -> Self {
224 Self::new()
225 }
226}
227
228impl MemoryTool {
229 pub fn new() -> Self {
230 Self {
231 store: tokio::sync::Mutex::new(MemoryStore::default()),
232 initialized: std::sync::atomic::AtomicBool::new(false),
233 }
234 }
235
236 pub async fn init(&self) -> Result<()> {
238 use std::sync::atomic::Ordering;
239
240 if self.initialized.load(Ordering::SeqCst) {
241 return Ok(());
242 }
243
244 let mut store = self.store.lock().await;
245 if let Ok(loaded) = MemoryStore::load().await {
246 *store = loaded;
247 }
248 self.initialized.store(true, Ordering::SeqCst);
249 Ok(())
250 }
251
252 pub async fn persist(&self) -> Result<()> {
254 let store = self.store.lock().await;
255 store.save().await
256 }
257}
258
259#[async_trait]
260impl Tool for MemoryTool {
261 fn id(&self) -> &str {
262 "memory"
263 }
264
265 fn name(&self) -> &str {
266 "Memory"
267 }
268
269 fn description(&self) -> &str {
270 "Store and retrieve persistent knowledge across sessions. Use 'save' to capture important insights, 'search' to find relevant memories, 'list' to see all entries, 'tags' to see available categories, or 'delete' to remove an entry."
271 }
272
273 fn parameters(&self) -> Value {
274 json!({
275 "type": "object",
276 "properties": {
277 "action": {
278 "type": "string",
279 "description": "Action to perform: 'save' (store new memory), 'search' (find memories), 'get' (retrieve specific memory), 'list' (show recent), 'tags' (show categories), 'delete' (remove), 'stats' (show statistics)",
280 "enum": ["save", "search", "get", "list", "tags", "delete", "stats"]
281 },
282 "content": {
283 "type": "string",
284 "description": "Memory content to save (required for 'save' action)"
285 },
286 "tags": {
287 "type": "array",
288 "items": {"type": "string"},
289 "description": "Tags for categorization (optional for 'save')"
290 },
291 "query": {
292 "type": "string",
293 "description": "Search query (for 'search' action)"
294 },
295 "scope": {
296 "type": "string",
297 "description": "Project/context scope (optional for 'save')"
298 },
299 "importance": {
300 "type": "integer",
301 "description": "Importance level 1-5 (optional for 'save', default 3)"
302 },
303 "id": {
304 "type": "string",
305 "description": "Memory ID (required for 'get' and 'delete')"
306 },
307 "limit": {
308 "type": "integer",
309 "description": "Maximum results to return (default 10, for 'search' and 'list')"
310 }
311 },
312 "required": ["action"]
313 })
314 }
315
316 async fn execute(&self, args: Value) -> Result<ToolResult> {
317 let needs_init = {
320 let store = self.store.lock().await;
321 store.entries.is_empty()
322 };
323
324 if needs_init {
325 self.init().await.ok();
326 }
327
328 let action = args["action"]
329 .as_str()
330 .ok_or_else(|| anyhow::anyhow!("action is required"))?;
331
332 match action {
333 "save" => self.execute_save(args).await,
334 "search" => self.execute_search(args).await,
335 "get" => self.execute_get(args).await,
336 "list" => self.execute_list(args).await,
337 "tags" => self.execute_tags(args).await,
338 "delete" => self.execute_delete(args).await,
339 "stats" => self.execute_stats(args).await,
340 _ => Ok(ToolResult::error(format!(
341 "Unknown action: {}. Use 'save', 'search', 'get', 'list', 'tags', 'delete', or 'stats'.",
342 action
343 ))),
344 }
345 }
346}
347
348impl MemoryTool {
349 async fn execute_save(&self, args: Value) -> Result<ToolResult> {
350 let content = args["content"]
351 .as_str()
352 .ok_or_else(|| anyhow::anyhow!("content is required for 'save' action"))?;
353
354 let tags: Vec<String> = args["tags"]
355 .as_array()
356 .map(|arr| {
357 arr.iter()
358 .filter_map(|v| v.as_str().map(String::from))
359 .collect()
360 })
361 .unwrap_or_default();
362
363 let scope = args["scope"].as_str().map(String::from);
364 let importance = args["importance"].as_u64().map(|v| v as u8).unwrap_or(3);
365
366 let mut entry = MemoryEntry::new(content, tags).with_importance(importance);
367
368 if let Some(s) = scope {
369 entry = entry.with_scope(s);
370 }
371
372 let id = {
373 let mut store = self.store.lock().await;
374 store.add(entry)
375 };
376
377 self.persist().await?;
379
380 Ok(ToolResult::success(format!(
381 "Memory saved with ID: {}\nImportance: {}/5",
382 id, importance
383 )))
384 }
385
386 async fn execute_search(&self, args: Value) -> Result<ToolResult> {
387 let query = args["query"].as_str();
388 let tags: Option<Vec<String>> = args["tags"].as_array().map(|arr| {
389 arr.iter()
390 .filter_map(|v| v.as_str().map(String::from))
391 .collect()
392 });
393 let limit = args["limit"].as_u64().map(|v| v as usize).unwrap_or(10);
394
395 let tags_ref = tags.as_ref().map(|v| v.as_slice());
396
397 let results = {
398 let mut store = self.store.lock().await;
399 store.search(query, tags_ref, limit)
400 };
401
402 if results.is_empty() {
403 return Ok(ToolResult::success(
404 "No memories found matching your criteria.".to_string(),
405 ));
406 }
407
408 let output = results
409 .iter()
410 .enumerate()
411 .map(|(i, m)| {
412 format!(
413 "{}. [{}] {} - {}\n Tags: {}\n Created: {}",
414 i + 1,
415 m.id.chars().take(8).collect::<String>(),
416 m.content.chars().take(80).collect::<String>()
417 + if m.content.len() > 80 { "..." } else { "" },
418 format!("accessed {} times", m.access_count),
419 m.tags.join(", "),
420 m.created_at.format("%Y-%m-%d %H:%M")
421 )
422 })
423 .collect::<Vec<_>>()
424 .join("\n\n");
425
426 Ok(ToolResult::success(format!(
427 "Found {} memories:\n\n{}",
428 results.len(),
429 output
430 )))
431 }
432
433 async fn execute_get(&self, args: Value) -> Result<ToolResult> {
434 let id = args["id"]
435 .as_str()
436 .ok_or_else(|| anyhow::anyhow!("id is required for 'get' action"))?;
437
438 let entry = {
439 let mut store = self.store.lock().await;
440 store.get(id).map(|e| e.clone())
441 };
442
443 match entry {
444 Some(e) => {
445 self.persist().await?;
447
448 Ok(ToolResult::success(format!(
449 "Memory ID: {}\nImportance: {}/5\nTags: {}\nCreated: {}\nAccessed: {} times\n\n{}",
450 e.id,
451 e.importance,
452 e.tags.join(", "),
453 e.created_at.format("%Y-%m-%d %H:%M:%S"),
454 e.access_count,
455 e.content
456 )))
457 }
458 None => Ok(ToolResult::error(format!("Memory not found: {}", id))),
459 }
460 }
461
462 async fn execute_list(&self, args: Value) -> Result<ToolResult> {
463 let limit = args["limit"].as_u64().map(|v| v as usize).unwrap_or(10);
464
465 let results = {
466 let mut store = self.store.lock().await;
467 store.search(None, None, limit)
468 };
469
470 if results.is_empty() {
471 return Ok(ToolResult::success(
472 "No memories stored yet. Use 'save' to add your first memory.".to_string(),
473 ));
474 }
475
476 let output = results
477 .iter()
478 .enumerate()
479 .map(|(i, m)| {
480 format!(
481 "{}. [{}] {} (importance: {}/5, accessed: {}x)",
482 i + 1,
483 m.id.chars().take(8).collect::<String>(),
484 m.content.chars().take(60).collect::<String>()
485 + if m.content.len() > 60 { "..." } else { "" },
486 m.importance,
487 m.access_count
488 )
489 })
490 .collect::<Vec<_>>()
491 .join("\n");
492
493 Ok(ToolResult::success(format!(
494 "Recent memories:\n\n{}",
495 output
496 )))
497 }
498
499 async fn execute_tags(&self, _args: Value) -> Result<ToolResult> {
500 let tags = {
501 let store = self.store.lock().await;
502 store.all_tags()
503 };
504
505 if tags.is_empty() {
506 return Ok(ToolResult::success(
507 "No tags yet. Add tags when saving memories.".to_string(),
508 ));
509 }
510
511 let mut sorted: Vec<_> = tags.iter().collect();
512 sorted.sort_by(|a, b| b.1.cmp(a.1));
513
514 let output = sorted
515 .iter()
516 .map(|(tag, count)| format!(" {} ({} memories)", tag, count))
517 .collect::<Vec<_>>()
518 .join("\n");
519
520 Ok(ToolResult::success(format!(
521 "Available tags:\n\n{}",
522 output
523 )))
524 }
525
526 async fn execute_delete(&self, args: Value) -> Result<ToolResult> {
527 let id = args["id"]
528 .as_str()
529 .ok_or_else(|| anyhow::anyhow!("id is required for 'delete' action"))?;
530
531 let deleted = {
532 let mut store = self.store.lock().await;
533 store.delete(id)
534 };
535
536 if deleted {
537 self.persist().await?;
538 Ok(ToolResult::success(format!("Memory deleted: {}", id)))
539 } else {
540 Ok(ToolResult::error(format!("Memory not found: {}", id)))
541 }
542 }
543
544 async fn execute_stats(&self, _args: Value) -> Result<ToolResult> {
545 let stats = {
546 let store = self.store.lock().await;
547 store.stats()
548 };
549
550 let tags_output = if stats.tags.is_empty() {
551 "None".to_string()
552 } else {
553 let mut sorted: Vec<_> = stats.tags.iter().collect();
554 sorted.sort_by(|a, b| b.1.cmp(a.1));
555 sorted
556 .iter()
557 .take(10)
558 .map(|(t, c)| format!(" {}: {}", t, c))
559 .collect::<Vec<_>>()
560 .join("\n")
561 };
562
563 Ok(ToolResult::success(format!(
564 "Memory Statistics:\n\n\
565 Total entries: {}\n\
566 Total accesses: {}\n\
567 Unique tags: {}\n\n\
568 Top tags:\n{}",
569 stats.total_entries, stats.total_accesses, stats.unique_tags, tags_output
570 )))
571 }
572}
573
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::Ordering;

    // NOTE(review): forcing `initialized` skips loading from disk, but every
    // `save`/`get`/`delete` still calls `persist()`, which writes to
    // `MemoryStore::default_path()`, i.e. the real user data location. Running
    // these tests therefore replaces any existing memory file. Consider making
    // the storage path injectable for tests.

    /// A tool whose in-memory store starts empty and never loads from disk.
    fn fresh_tool() -> MemoryTool {
        let tool = MemoryTool::new();
        tool.initialized.store(true, Ordering::SeqCst);
        tool
    }

    /// Extracts the full memory ID from the output of a successful `save`.
    fn saved_id(output: &str) -> String {
        output
            .strip_prefix("Memory saved with ID: ")
            .and_then(|rest| rest.lines().next())
            .expect("save output should start with the new memory ID")
            .to_string()
    }

    #[tokio::test]
    async fn test_memory_save_and_get() {
        let tool = fresh_tool();

        let result = tool
            .execute(json!({
                "action": "save",
                "content": "Test memory content",
                "tags": ["test", "example"],
                "importance": 4
            }))
            .await
            .unwrap();

        assert!(result.success);
        assert!(result.output.contains("Importance: 4/5"));
        let id = saved_id(&result.output);

        // Round-trip through `get`: the first retrieval is the first access.
        let result = tool
            .execute(json!({
                "action": "get",
                "id": id
            }))
            .await
            .unwrap();

        assert!(result.success);
        assert!(result.output.contains("Test memory content"));
        assert!(result.output.contains("Tags: test, example"));
        assert!(result.output.contains("Accessed: 1 times"));

        let result = tool
            .execute(json!({
                "action": "list",
                "limit": 5
            }))
            .await
            .unwrap();

        assert!(result.success);
        assert!(result.output.contains("Test memory content"));

        let result = tool
            .execute(json!({
                "action": "stats"
            }))
            .await
            .unwrap();

        assert!(result.success);
        assert!(result.output.contains("Total entries: 1"));
    }

    #[tokio::test]
    async fn test_memory_search() {
        let tool = fresh_tool();

        tool.execute(json!({
            "action": "save",
            "content": "Rust programming insights",
            "tags": ["rust", "programming"]
        }))
        .await
        .unwrap();

        tool.execute(json!({
            "action": "save",
            "content": "Python tips",
            "tags": ["python", "programming"]
        }))
        .await
        .unwrap();

        let result = tool
            .execute(json!({
                "action": "search",
                "tags": ["rust"]
            }))
            .await
            .unwrap();

        assert!(result.success);
        assert!(result.output.contains("Rust"));
        assert!(!result.output.contains("Python"));

        // Free-text queries are case-insensitive.
        let result = tool
            .execute(json!({
                "action": "search",
                "query": "PYTHON"
            }))
            .await
            .unwrap();

        assert!(result.success);
        assert!(result.output.contains("Python tips"));
        assert!(!result.output.contains("Rust"));
    }

    #[tokio::test]
    async fn test_memory_delete() {
        let tool = fresh_tool();

        let result = tool
            .execute(json!({
                "action": "save",
                "content": "Temporary note"
            }))
            .await
            .unwrap();
        let id = saved_id(&result.output);

        let result = tool
            .execute(json!({
                "action": "delete",
                "id": id
            }))
            .await
            .unwrap();

        assert!(result.success);
        assert!(result.output.contains("Memory deleted"));

        let result = tool
            .execute(json!({
                "action": "stats"
            }))
            .await
            .unwrap();

        assert!(result.success);
        assert!(result.output.contains("Total entries: 0"));
    }
}