1use super::{Tool, ToolResult};
7use anyhow::Result;
8use async_trait::async_trait;
9use chrono::{DateTime, Utc};
10use serde::{Deserialize, Serialize};
11use serde_json::{Value, json};
12use std::collections::HashMap;
13use std::path::PathBuf;
14use tokio::fs;
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct MemoryEntry {
19 pub id: String,
21 pub content: String,
23 pub tags: Vec<String>,
25 pub created_at: DateTime<Utc>,
27 pub accessed_at: DateTime<Utc>,
29 pub access_count: u64,
31 pub scope: Option<String>,
33 pub source: Option<String>,
35 pub importance: u8,
37}
38
39impl MemoryEntry {
40 pub fn new(content: impl Into<String>, tags: Vec<String>) -> Self {
41 let now = Utc::now();
42 Self {
43 id: uuid::Uuid::new_v4().to_string(),
44 content: content.into(),
45 tags,
46 created_at: now,
47 accessed_at: now,
48 access_count: 0,
49 scope: None,
50 source: None,
51 importance: 3, }
53 }
54
55 pub fn with_scope(mut self, scope: impl Into<String>) -> Self {
56 self.scope = Some(scope.into());
57 self
58 }
59
60 pub fn with_source(mut self, source: impl Into<String>) -> Self {
61 self.source = Some(source.into());
62 self
63 }
64
65 pub fn with_importance(mut self, importance: u8) -> Self {
66 self.importance = importance.min(5);
67 self
68 }
69
70 pub fn touch(&mut self) {
71 self.accessed_at = Utc::now();
72 self.access_count += 1;
73 }
74}
75
76#[derive(Debug, Clone, Serialize, Deserialize, Default)]
78pub struct MemoryStore {
79 entries: HashMap<String, MemoryEntry>,
80}
81
82impl MemoryStore {
83 pub fn default_path() -> std::path::PathBuf {
85 directories::ProjectDirs::from("com", "codetether", "codetether")
86 .map(|p| p.data_dir().join("memory.json"))
87 .unwrap_or_else(|| PathBuf::from(".codetether/memory.json"))
88 }
89
90 pub async fn load() -> Result<Self> {
92 let path = Self::default_path();
93 if !path.exists() {
94 return Ok(Self::default());
95 }
96 let content = fs::read_to_string(&path).await?;
97 let store: MemoryStore = serde_json::from_str(&content)?;
98 Ok(store)
99 }
100
101 pub async fn save(&self) -> Result<()> {
103 let path = Self::default_path();
104 if let Some(parent) = path.parent() {
105 fs::create_dir_all(parent).await?;
106 }
107 let content = serde_json::to_string_pretty(self)?;
108 fs::write(&path, content).await?;
109 Ok(())
110 }
111
112 pub fn add(&mut self, entry: MemoryEntry) -> String {
114 let id = entry.id.clone();
115 self.entries.insert(id.clone(), entry);
116 id
117 }
118
119 pub fn get(&mut self, id: &str) -> Option<MemoryEntry> {
121 let entry = self.entries.get_mut(id)?;
122 entry.touch();
123 Some(entry.clone())
124 }
125
126 pub fn search(
128 &mut self,
129 query: Option<&str>,
130 tags: Option<&[String]>,
131 limit: usize,
132 ) -> Vec<MemoryEntry> {
133 let mut results: Vec<MemoryEntry> = self
134 .entries
135 .values_mut()
136 .filter(|entry| {
137 if let Some(search_tags) = tags {
139 if !search_tags.is_empty()
140 && !search_tags.iter().any(|t| entry.tags.contains(t))
141 {
142 return false;
143 }
144 }
145
146 if let Some(q) = query {
148 let q_lower = q.to_lowercase();
149 let matches_content = entry.content.to_lowercase().contains(&q_lower);
150 let matches_tags = entry
151 .tags
152 .iter()
153 .any(|t| t.to_lowercase().contains(&q_lower));
154 if !matches_content && !matches_tags {
155 return false;
156 }
157 }
158
159 true
160 })
161 .map(|e| {
162 e.touch();
163 e.clone()
164 })
165 .collect();
166
167 results.sort_by(|a, b| {
169 b.importance
170 .cmp(&a.importance)
171 .then_with(|| b.access_count.cmp(&a.access_count))
172 });
173
174 results.truncate(limit);
175 results
176 }
177
178 pub fn all_tags(&self) -> HashMap<String, u64> {
180 let mut tags: HashMap<String, u64> = HashMap::new();
181 for entry in self.entries.values() {
182 for tag in &entry.tags {
183 *tags.entry(tag.clone()).or_insert(0) += 1;
184 }
185 }
186 tags
187 }
188
189 pub fn delete(&mut self, id: &str) -> bool {
191 self.entries.remove(id).is_some()
192 }
193
194 pub fn stats(&self) -> MemoryStats {
196 let total = self.entries.len();
197 let total_accesses: u64 = self.entries.values().map(|e| e.access_count).sum();
198 let tags = self.all_tags();
199 MemoryStats {
200 total_entries: total,
201 total_accesses,
202 unique_tags: tags.len(),
203 tags,
204 }
205 }
206}
207
208#[derive(Debug, Clone, Serialize, Deserialize)]
209pub struct MemoryStats {
210 pub total_entries: usize,
211 pub total_accesses: u64,
212 pub unique_tags: usize,
213 pub tags: HashMap<String, u64>,
214}
215
216pub struct MemoryTool {
218 store: tokio::sync::Mutex<MemoryStore>,
219}
220
221impl Default for MemoryTool {
222 fn default() -> Self {
223 Self::new()
224 }
225}
226
227impl MemoryTool {
228 pub fn new() -> Self {
229 Self {
230 store: tokio::sync::Mutex::new(MemoryStore::default()),
231 }
232 }
233
234 pub async fn init(&self) -> Result<()> {
236 let mut store = self.store.lock().await;
237 *store = MemoryStore::load().await?;
238 Ok(())
239 }
240
241 pub async fn persist(&self) -> Result<()> {
243 let store = self.store.lock().await;
244 store.save().await
245 }
246}
247
248#[async_trait]
249impl Tool for MemoryTool {
250 fn id(&self) -> &str {
251 "memory"
252 }
253
254 fn name(&self) -> &str {
255 "Memory"
256 }
257
258 fn description(&self) -> &str {
259 "Store and retrieve persistent knowledge across sessions. Use 'save' to capture important insights, 'search' to find relevant memories, 'list' to see all entries, 'tags' to see available categories, or 'delete' to remove an entry."
260 }
261
262 fn parameters(&self) -> Value {
263 json!({
264 "type": "object",
265 "properties": {
266 "action": {
267 "type": "string",
268 "description": "Action to perform: 'save' (store new memory), 'search' (find memories), 'get' (retrieve specific memory), 'list' (show recent), 'tags' (show categories), 'delete' (remove), 'stats' (show statistics)",
269 "enum": ["save", "search", "get", "list", "tags", "delete", "stats"]
270 },
271 "content": {
272 "type": "string",
273 "description": "Memory content to save (required for 'save' action)"
274 },
275 "tags": {
276 "type": "array",
277 "items": {"type": "string"},
278 "description": "Tags for categorization (optional for 'save')"
279 },
280 "query": {
281 "type": "string",
282 "description": "Search query (for 'search' action)"
283 },
284 "scope": {
285 "type": "string",
286 "description": "Project/context scope (optional for 'save')"
287 },
288 "importance": {
289 "type": "integer",
290 "description": "Importance level 1-5 (optional for 'save', default 3)"
291 },
292 "id": {
293 "type": "string",
294 "description": "Memory ID (required for 'get' and 'delete')"
295 },
296 "limit": {
297 "type": "integer",
298 "description": "Maximum results to return (default 10, for 'search' and 'list')"
299 }
300 },
301 "required": ["action"]
302 })
303 }
304
305 async fn execute(&self, args: Value) -> Result<ToolResult> {
306 self.init().await.ok();
308
309 let action = args["action"]
310 .as_str()
311 .ok_or_else(|| anyhow::anyhow!("action is required"))?;
312
313 match action {
314 "save" => self.execute_save(args).await,
315 "search" => self.execute_search(args).await,
316 "get" => self.execute_get(args).await,
317 "list" => self.execute_list(args).await,
318 "tags" => self.execute_tags(args).await,
319 "delete" => self.execute_delete(args).await,
320 "stats" => self.execute_stats(args).await,
321 _ => Ok(ToolResult::error(format!(
322 "Unknown action: {}. Use 'save', 'search', 'get', 'list', 'tags', 'delete', or 'stats'.",
323 action
324 ))),
325 }
326 }
327}
328
329impl MemoryTool {
330 async fn execute_save(&self, args: Value) -> Result<ToolResult> {
331 let content = args["content"]
332 .as_str()
333 .ok_or_else(|| anyhow::anyhow!("content is required for 'save' action"))?;
334
335 let tags: Vec<String> = args["tags"]
336 .as_array()
337 .map(|arr| {
338 arr.iter()
339 .filter_map(|v| v.as_str().map(String::from))
340 .collect()
341 })
342 .unwrap_or_default();
343
344 let scope = args["scope"].as_str().map(String::from);
345 let importance = args["importance"].as_u64().map(|v| v as u8).unwrap_or(3);
346
347 let mut entry = MemoryEntry::new(content, tags).with_importance(importance);
348
349 if let Some(s) = scope {
350 entry = entry.with_scope(s);
351 }
352
353 let id = {
354 let mut store = self.store.lock().await;
355 store.add(entry)
356 };
357
358 self.persist().await?;
360
361 Ok(ToolResult::success(format!(
362 "Memory saved with ID: {}\nImportance: {}/5",
363 id, importance
364 )))
365 }
366
367 async fn execute_search(&self, args: Value) -> Result<ToolResult> {
368 let query = args["query"].as_str();
369 let tags: Option<Vec<String>> = args["tags"].as_array().map(|arr| {
370 arr.iter()
371 .filter_map(|v| v.as_str().map(String::from))
372 .collect()
373 });
374 let limit = args["limit"].as_u64().map(|v| v as usize).unwrap_or(10);
375
376 let tags_ref = tags.as_ref().map(|v| v.as_slice());
377
378 let results = {
379 let mut store = self.store.lock().await;
380 store.search(query, tags_ref, limit)
381 };
382
383 if results.is_empty() {
384 return Ok(ToolResult::success(
385 "No memories found matching your criteria.".to_string(),
386 ));
387 }
388
389 let output = results
390 .iter()
391 .enumerate()
392 .map(|(i, m)| {
393 format!(
394 "{}. [{}] {} - {}\n Tags: {}\n Created: {}",
395 i + 1,
396 m.id.chars().take(8).collect::<String>(),
397 m.content.chars().take(80).collect::<String>()
398 + if m.content.len() > 80 { "..." } else { "" },
399 format!("accessed {} times", m.access_count),
400 m.tags.join(", "),
401 m.created_at.format("%Y-%m-%d %H:%M")
402 )
403 })
404 .collect::<Vec<_>>()
405 .join("\n\n");
406
407 Ok(ToolResult::success(format!(
408 "Found {} memories:\n\n{}",
409 results.len(),
410 output
411 )))
412 }
413
414 async fn execute_get(&self, args: Value) -> Result<ToolResult> {
415 let id = args["id"]
416 .as_str()
417 .ok_or_else(|| anyhow::anyhow!("id is required for 'get' action"))?;
418
419 let entry = {
420 let mut store = self.store.lock().await;
421 store.get(id).map(|e| e.clone())
422 };
423
424 match entry {
425 Some(e) => {
426 self.persist().await?;
428
429 Ok(ToolResult::success(format!(
430 "Memory ID: {}\nImportance: {}/5\nTags: {}\nCreated: {}\nAccessed: {} times\n\n{}",
431 e.id,
432 e.importance,
433 e.tags.join(", "),
434 e.created_at.format("%Y-%m-%d %H:%M:%S"),
435 e.access_count,
436 e.content
437 )))
438 }
439 None => Ok(ToolResult::error(format!("Memory not found: {}", id))),
440 }
441 }
442
443 async fn execute_list(&self, args: Value) -> Result<ToolResult> {
444 let limit = args["limit"].as_u64().map(|v| v as usize).unwrap_or(10);
445
446 let results = {
447 let mut store = self.store.lock().await;
448 store.search(None, None, limit)
449 };
450
451 if results.is_empty() {
452 return Ok(ToolResult::success(
453 "No memories stored yet. Use 'save' to add your first memory.".to_string(),
454 ));
455 }
456
457 let output = results
458 .iter()
459 .enumerate()
460 .map(|(i, m)| {
461 format!(
462 "{}. [{}] {} (importance: {}/5, accessed: {}x)",
463 i + 1,
464 m.id.chars().take(8).collect::<String>(),
465 m.content.chars().take(60).collect::<String>()
466 + if m.content.len() > 60 { "..." } else { "" },
467 m.importance,
468 m.access_count
469 )
470 })
471 .collect::<Vec<_>>()
472 .join("\n");
473
474 Ok(ToolResult::success(format!(
475 "Recent memories:\n\n{}",
476 output
477 )))
478 }
479
480 async fn execute_tags(&self, _args: Value) -> Result<ToolResult> {
481 let tags = {
482 let store = self.store.lock().await;
483 store.all_tags()
484 };
485
486 if tags.is_empty() {
487 return Ok(ToolResult::success(
488 "No tags yet. Add tags when saving memories.".to_string(),
489 ));
490 }
491
492 let mut sorted: Vec<_> = tags.iter().collect();
493 sorted.sort_by(|a, b| b.1.cmp(a.1));
494
495 let output = sorted
496 .iter()
497 .map(|(tag, count)| format!(" {} ({} memories)", tag, count))
498 .collect::<Vec<_>>()
499 .join("\n");
500
501 Ok(ToolResult::success(format!(
502 "Available tags:\n\n{}",
503 output
504 )))
505 }
506
507 async fn execute_delete(&self, args: Value) -> Result<ToolResult> {
508 let id = args["id"]
509 .as_str()
510 .ok_or_else(|| anyhow::anyhow!("id is required for 'delete' action"))?;
511
512 let deleted = {
513 let mut store = self.store.lock().await;
514 store.delete(id)
515 };
516
517 if deleted {
518 self.persist().await?;
519 Ok(ToolResult::success(format!("Memory deleted: {}", id)))
520 } else {
521 Ok(ToolResult::error(format!("Memory not found: {}", id)))
522 }
523 }
524
525 async fn execute_stats(&self, _args: Value) -> Result<ToolResult> {
526 let stats = {
527 let store = self.store.lock().await;
528 store.stats()
529 };
530
531 let tags_output = if stats.tags.is_empty() {
532 "None".to_string()
533 } else {
534 let mut sorted: Vec<_> = stats.tags.iter().collect();
535 sorted.sort_by(|a, b| b.1.cmp(a.1));
536 sorted
537 .iter()
538 .take(10)
539 .map(|(t, c)| format!(" {}: {}", t, c))
540 .collect::<Vec<_>>()
541 .join("\n")
542 };
543
544 Ok(ToolResult::success(format!(
545 "Memory Statistics:\n\n\
546 Total entries: {}\n\
547 Total accesses: {}\n\
548 Unique tags: {}\n\n\
549 Top tags:\n{}",
550 stats.total_entries, stats.total_accesses, stats.unique_tags, tags_output
551 )))
552 }
553}
554
#[cfg(test)]
mod tests {
    use super::*;

    // These tests deliberately avoid `MemoryTool::execute` and the persisting
    // handlers (`save`/`get`/`delete` success paths). Those read and write
    // `MemoryStore::default_path()`, i.e. the user's real memory file. That
    // would pollute user data, make results depend on previous runs, and race
    // between tests running in parallel. Instead, entries are inserted straight
    // into the tool's in-memory store.

    /// Builds a tool whose in-memory store holds `entries`; nothing touches disk.
    async fn tool_with(entries: Vec<MemoryEntry>) -> MemoryTool {
        let tool = MemoryTool::new();
        {
            let mut store = tool.store.lock().await;
            for entry in entries {
                store.add(entry);
            }
        }
        tool
    }

    #[tokio::test]
    async fn test_memory_list_and_stats() {
        let tool = tool_with(vec![
            MemoryEntry::new(
                "Test memory content",
                vec!["test".into(), "example".into()],
            )
            .with_importance(4),
        ])
        .await;

        let result = tool.execute_list(json!({ "limit": 5 })).await.unwrap();
        assert!(result.success);
        assert!(result.output.contains("Test memory content"));

        let result = tool.execute_stats(json!({})).await.unwrap();
        assert!(result.success);
        assert!(result.output.contains("Total entries: 1"));
    }

    #[tokio::test]
    async fn test_memory_search() {
        let tool = tool_with(vec![
            MemoryEntry::new(
                "Rust programming insights",
                vec!["rust".into(), "programming".into()],
            ),
            MemoryEntry::new("Python tips", vec!["python".into(), "programming".into()]),
        ])
        .await;

        let result = tool
            .execute_search(json!({ "tags": ["rust"] }))
            .await
            .unwrap();

        assert!(result.success);
        assert!(result.output.contains("Rust"));
        assert!(!result.output.contains("Python"));
    }

    #[tokio::test]
    async fn test_missing_required_arguments_are_errors() {
        let tool = MemoryTool::new();
        // Both fail on argument validation before anything is persisted.
        assert!(tool.execute_save(json!({})).await.is_err());
        assert!(tool.execute_get(json!({})).await.is_err());
    }

    #[test]
    fn test_store_search_ranks_by_importance_and_respects_limit() {
        let mut store = MemoryStore::default();
        store.add(MemoryEntry::new("alpha", vec![]).with_importance(5));
        store.add(MemoryEntry::new("Alpha beta", vec![]).with_importance(2));
        store.add(MemoryEntry::new("gamma", vec!["ALPHA-tag".into()]));
        store.add(MemoryEntry::new("delta", vec![]));

        // Case-insensitive match on content or tags.
        let hits = store.search(Some("ALPHA"), None, 10);
        assert_eq!(hits.len(), 3);
        assert_eq!(hits[0].content, "alpha");

        let hits = store.search(Some("alpha"), None, 1);
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].importance, 5);
    }

    #[test]
    fn test_importance_is_capped_at_five() {
        assert_eq!(MemoryEntry::new("x", vec![]).with_importance(9).importance, 5);
        assert_eq!(MemoryEntry::new("x", vec![]).importance, 3);
    }
}