1mod git_scope;
7mod scope;
8
9use super::{Tool, ToolResult};
10use anyhow::Result;
11use async_trait::async_trait;
12use chrono::{DateTime, Utc};
13use serde::{Deserialize, Serialize};
14use serde_json::{Value, json};
15use std::collections::HashMap;
16use std::path::PathBuf;
17use tokio::fs;
18
/// A single stored memory together with its bookkeeping metadata.
///
/// Entries are (de)serialized with serde and are held by [`MemoryStore`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryEntry {
    /// Unique identifier; [`MemoryEntry::new`] fills it with a v4 UUID string.
    pub id: String,
    /// The remembered text.
    pub content: String,
    /// Free-form labels; searches match on them exactly (tag filter) or by
    /// case-insensitive substring (text query).
    pub tags: Vec<String>,
    /// UTC timestamp taken when the entry was constructed.
    pub created_at: DateTime<Utc>,
    /// UTC timestamp of the most recent [`MemoryEntry::touch`].
    pub accessed_at: DateTime<Utc>,
    /// Number of times [`MemoryEntry::touch`] has been called on this entry.
    pub access_count: u64,
    /// Optional scope label; scoped searches require an exact match.
    pub scope: Option<String>,
    /// Optional origin label set through [`MemoryEntry::with_source`].
    pub source: Option<String>,
    /// Importance score; defaults to 3 and is capped at 5 by
    /// [`MemoryEntry::with_importance`].
    pub importance: u8,
}
41
42impl MemoryEntry {
43 pub fn new(content: impl Into<String>, tags: Vec<String>) -> Self {
44 let now = Utc::now();
45 Self {
46 id: uuid::Uuid::new_v4().to_string(),
47 content: content.into(),
48 tags,
49 created_at: now,
50 accessed_at: now,
51 access_count: 0,
52 scope: None,
53 source: None,
54 importance: 3, }
56 }
57
58 pub fn with_scope(mut self, scope: impl Into<String>) -> Self {
59 self.scope = Some(scope.into());
60 self
61 }
62
63 #[allow(dead_code)]
65 pub fn with_source(mut self, source: impl Into<String>) -> Self {
66 self.source = Some(source.into());
67 self
68 }
69
70 pub fn with_importance(mut self, importance: u8) -> Self {
71 self.importance = importance.min(5);
72 self
73 }
74
75 pub fn touch(&mut self) {
76 self.accessed_at = Utc::now();
77 self.access_count += 1;
78 }
79}
80
/// Collection of [`MemoryEntry`] values keyed by entry id.
///
/// The whole store is serialized to / deserialized from JSON by
/// [`MemoryStore::save`] and [`MemoryStore::load`]. `Default` yields an empty
/// store.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct MemoryStore {
    /// Entries indexed by their own `id` (see [`MemoryStore::add`]).
    entries: HashMap<String, MemoryEntry>,
}
86
87impl MemoryStore {
88 pub fn default_path() -> std::path::PathBuf {
90 crate::config::Config::data_dir()
91 .map(|p| p.join("memory.json"))
92 .unwrap_or_else(|| PathBuf::from(".codetether-agent/memory.json"))
93 }
94
95 pub async fn load() -> Result<Self> {
97 let path = Self::default_path();
98 if !path.exists() {
99 return Ok(Self::default());
100 }
101 let content = fs::read_to_string(&path).await?;
102 let store: MemoryStore = serde_json::from_str(&content)?;
103 Ok(store)
104 }
105
106 pub async fn save(&self) -> Result<()> {
108 let path = Self::default_path();
109 if let Some(parent) = path.parent() {
110 fs::create_dir_all(parent).await?;
111 }
112 let content = serde_json::to_string_pretty(self)?;
113 fs::write(&path, content).await?;
114 Ok(())
115 }
116
117 pub fn add(&mut self, entry: MemoryEntry) -> String {
119 let id = entry.id.clone();
120 self.entries.insert(id.clone(), entry);
121 id
122 }
123
124 pub fn get(&mut self, id: &str) -> Option<MemoryEntry> {
126 let id = self.resolve_id(id)?;
127 let entry = self.entries.get_mut(&id)?;
128 entry.touch();
129 Some(entry.clone())
130 }
131
132 fn resolve_id(&self, id: &str) -> Option<String> {
133 if self.entries.contains_key(id) {
134 return Some(id.to_string());
135 }
136 let mut matches = self.entries.keys().filter(|key| key.starts_with(id));
137 let found = matches.next()?.clone();
138 matches.next().is_none().then_some(found)
139 }
140
141 pub fn search(
143 &mut self,
144 query: Option<&str>,
145 tags: Option<&[String]>,
146 scope: Option<&str>,
147 limit: usize,
148 ) -> Vec<MemoryEntry> {
149 let mut results: Vec<MemoryEntry> = self
150 .entries
151 .values_mut()
152 .filter(|entry| {
153 if let Some(search_tags) = tags
155 && !search_tags.is_empty()
156 && !search_tags.iter().any(|t| entry.tags.contains(t))
157 {
158 return false;
159 }
160 if let Some(scope) = scope
161 && entry.scope.as_deref() != Some(scope)
162 {
163 return false;
164 }
165
166 if let Some(q) = query {
168 let q_lower = q.to_lowercase();
169 let matches_content = entry.content.to_lowercase().contains(&q_lower);
170 let matches_tags = entry
171 .tags
172 .iter()
173 .any(|t| t.to_lowercase().contains(&q_lower));
174 if !matches_content && !matches_tags {
175 return false;
176 }
177 }
178
179 true
180 })
181 .map(|e| {
182 e.touch();
183 e.clone()
184 })
185 .collect();
186
187 results.sort_by(|a, b| {
189 b.importance
190 .cmp(&a.importance)
191 .then_with(|| b.access_count.cmp(&a.access_count))
192 });
193
194 results.truncate(limit);
195 results
196 }
197
198 pub fn all_tags(&self) -> HashMap<String, u64> {
200 let mut tags: HashMap<String, u64> = HashMap::new();
201 for entry in self.entries.values() {
202 for tag in &entry.tags {
203 *tags.entry(tag.clone()).or_insert(0) += 1;
204 }
205 }
206 tags
207 }
208
209 pub fn delete(&mut self, id: &str) -> bool {
211 self.resolve_id(id)
212 .and_then(|id| self.entries.remove(&id))
213 .is_some()
214 }
215
216 pub fn stats(&self) -> MemoryStats {
218 let total = self.entries.len();
219 let total_accesses: u64 = self.entries.values().map(|e| e.access_count).sum();
220 let tags = self.all_tags();
221 MemoryStats {
222 total_entries: total,
223 total_accesses,
224 unique_tags: tags.len(),
225 tags,
226 }
227 }
228}
229
/// Aggregate figures about a [`MemoryStore`], as produced by
/// [`MemoryStore::stats`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryStats {
    /// Number of entries in the store.
    pub total_entries: usize,
    /// Sum of `access_count` over all entries.
    pub total_accesses: u64,
    /// Number of distinct tags (the size of `tags`).
    pub unique_tags: usize,
    /// Occurrence count per tag across all entries.
    pub tags: HashMap<String, u64>,
}
237
/// Tool front-end around a [`MemoryStore`].
pub struct MemoryTool {
    /// The store, guarded by an async mutex so it can be mutated through
    /// `&self`.
    store: tokio::sync::Mutex<MemoryStore>,
    /// Set to `true` by `init` once a load from disk has been attempted.
    initialized: std::sync::atomic::AtomicBool,
}
243
244impl Default for MemoryTool {
245 fn default() -> Self {
246 Self::new()
247 }
248}
249
250impl MemoryTool {
251 pub fn new() -> Self {
252 Self {
253 store: tokio::sync::Mutex::new(MemoryStore::default()),
254 initialized: std::sync::atomic::AtomicBool::new(false),
255 }
256 }
257
258 pub async fn init(&self) -> Result<()> {
260 use std::sync::atomic::Ordering;
261
262 if self.initialized.load(Ordering::SeqCst) {
263 return Ok(());
264 }
265
266 let mut store = self.store.lock().await;
267 if let Ok(loaded) = MemoryStore::load().await {
268 *store = loaded;
269 }
270 self.initialized.store(true, Ordering::SeqCst);
271 Ok(())
272 }
273
274 pub async fn persist(&self) -> Result<()> {
276 let store = self.store.lock().await;
277 store.save().await
278 }
279}
280
281#[async_trait]
282impl Tool for MemoryTool {
283 fn id(&self) -> &str {
284 "memory"
285 }
286
287 fn name(&self) -> &str {
288 "Memory"
289 }
290
291 fn description(&self) -> &str {
292 "Store and retrieve persistent knowledge across sessions. Use 'save' to capture important insights, 'search' to find relevant memories, 'list' to see all entries, 'tags' to see available categories, or 'delete' to remove an entry."
293 }
294
295 fn parameters(&self) -> Value {
296 json!({
297 "type": "object",
298 "properties": {
299 "action": {
300 "type": "string",
301 "description": "Action to perform: 'save' (store new memory), 'search' (find memories), 'get' (retrieve specific memory), 'list' (show recent), 'tags' (show categories), 'delete' (remove), 'stats' (show statistics)",
302 "enum": ["save", "search", "get", "list", "tags", "delete", "stats"]
303 },
304 "content": {
305 "type": "string",
306 "description": "Memory content to save (required for 'save' action)"
307 },
308 "tags": {
309 "type": "array",
310 "items": {"type": "string"},
311 "description": "Tags for categorization (optional for 'save')"
312 },
313 "query": {
314 "type": "string",
315 "description": "Search query (for 'search' action)"
316 },
317 "scope": {
318 "type": "string",
319 "description": "Project/context scope. Defaults to a stable git scope, not the transient worktree path. Use 'all' for unscoped search/list."
320 },
321 "importance": {
322 "type": "integer",
323 "description": "Importance level 1-5 (optional for 'save', default 3)"
324 },
325 "id": {
326 "type": "string",
327 "description": "Memory ID (required for 'get' and 'delete')"
328 },
329 "limit": {
330 "type": "integer",
331 "description": "Maximum results to return (default 10, for 'search' and 'list')"
332 }
333 },
334 "required": ["action"]
335 })
336 }
337
338 async fn execute(&self, args: Value) -> Result<ToolResult> {
339 let needs_init = {
342 let store = self.store.lock().await;
343 store.entries.is_empty()
344 };
345
346 if needs_init {
347 self.init().await.ok();
348 }
349
350 let action = args["action"]
351 .as_str()
352 .ok_or_else(|| anyhow::anyhow!("action is required"))?;
353
354 match action {
355 "save" => self.execute_save(args).await,
356 "search" => self.execute_search(args).await,
357 "get" => self.execute_get(args).await,
358 "list" => self.execute_list(args).await,
359 "tags" => self.execute_tags(args).await,
360 "delete" => self.execute_delete(args).await,
361 "stats" => self.execute_stats(args).await,
362 _ => Ok(ToolResult::error(format!(
363 "Unknown action: {}. Use 'save', 'search', 'get', 'list', 'tags', 'delete', or 'stats'.",
364 action
365 ))),
366 }
367 }
368}
369
370impl MemoryTool {
371 async fn execute_save(&self, args: Value) -> Result<ToolResult> {
372 let content = args["content"]
373 .as_str()
374 .ok_or_else(|| anyhow::anyhow!("content is required for 'save' action"))?;
375
376 let tags: Vec<String> = args["tags"]
377 .as_array()
378 .map(|arr| {
379 arr.iter()
380 .filter_map(|v| v.as_str().map(String::from))
381 .collect()
382 })
383 .unwrap_or_default();
384
385 let scope = scope::save(&args);
386 let importance = args["importance"].as_u64().map(|v| v as u8).unwrap_or(3);
387
388 let mut entry = MemoryEntry::new(content, tags).with_importance(importance);
389
390 if let Some(s) = scope {
391 entry = entry.with_scope(s);
392 }
393
394 let id = {
395 let mut store = self.store.lock().await;
396 store.add(entry)
397 };
398
399 self.persist().await?;
401
402 Ok(ToolResult::success(format!(
403 "Memory saved with ID: {}\nImportance: {}/5",
404 id, importance
405 )))
406 }
407
408 async fn execute_search(&self, args: Value) -> Result<ToolResult> {
409 let query = args["query"].as_str();
410 let tags: Option<Vec<String>> = args["tags"].as_array().map(|arr| {
411 arr.iter()
412 .filter_map(|v| v.as_str().map(String::from))
413 .collect()
414 });
415 let limit = args["limit"].as_u64().map(|v| v as usize).unwrap_or(10);
416 let scope = scope::search(&args);
417
418 let tags_ref = tags.as_deref();
419
420 let results = {
421 let mut store = self.store.lock().await;
422 store.search(query, tags_ref, scope.as_deref(), limit)
423 };
424
425 if results.is_empty() {
426 return Ok(ToolResult::success(
427 "No memories found matching your criteria.".to_string(),
428 ));
429 }
430
431 let output = results
432 .iter()
433 .enumerate()
434 .map(|(i, m)| {
435 format!(
436 "{}. [{}] {} - {}\n Tags: {}\n Created: {}",
437 i + 1,
438 m.id,
439 m.content.chars().take(80).collect::<String>()
440 + if m.content.len() > 80 { "..." } else { "" },
441 format!("accessed {} times", m.access_count),
442 m.tags.join(", "),
443 m.created_at.format("%Y-%m-%d %H:%M")
444 )
445 })
446 .collect::<Vec<_>>()
447 .join("\n\n");
448
449 Ok(ToolResult::success(format!(
450 "Found {} memories:\n\n{}",
451 results.len(),
452 output
453 )))
454 }
455
456 async fn execute_get(&self, args: Value) -> Result<ToolResult> {
457 let id = args["id"]
458 .as_str()
459 .ok_or_else(|| anyhow::anyhow!("id is required for 'get' action"))?;
460
461 let entry = {
462 let mut store = self.store.lock().await;
463 store.get(id)
464 };
465
466 match entry {
467 Some(e) => {
468 self.persist().await?;
470
471 Ok(ToolResult::success(format!(
472 "Memory ID: {}\nImportance: {}/5\nTags: {}\nCreated: {}\nAccessed: {} times\n\n{}",
473 e.id,
474 e.importance,
475 e.tags.join(", "),
476 e.created_at.format("%Y-%m-%d %H:%M:%S"),
477 e.access_count,
478 e.content
479 )))
480 }
481 None => Ok(ToolResult::error(format!("Memory not found: {}", id))),
482 }
483 }
484
485 async fn execute_list(&self, args: Value) -> Result<ToolResult> {
486 let limit = args["limit"].as_u64().map(|v| v as usize).unwrap_or(10);
487 let scope = scope::search(&args);
488
489 let results = {
490 let mut store = self.store.lock().await;
491 store.search(None, None, scope.as_deref(), limit)
492 };
493
494 if results.is_empty() {
495 return Ok(ToolResult::success(
496 "No memories stored yet. Use 'save' to add your first memory.".to_string(),
497 ));
498 }
499
500 let output = results
501 .iter()
502 .enumerate()
503 .map(|(i, m)| {
504 format!(
505 "{}. [{}] {} (importance: {}/5, accessed: {}x)",
506 i + 1,
507 m.id,
508 m.content.chars().take(60).collect::<String>()
509 + if m.content.len() > 60 { "..." } else { "" },
510 m.importance,
511 m.access_count
512 )
513 })
514 .collect::<Vec<_>>()
515 .join("\n");
516
517 Ok(ToolResult::success(format!(
518 "Recent memories:\n\n{}",
519 output
520 )))
521 }
522
523 async fn execute_tags(&self, _args: Value) -> Result<ToolResult> {
524 let tags = {
525 let store = self.store.lock().await;
526 store.all_tags()
527 };
528
529 if tags.is_empty() {
530 return Ok(ToolResult::success(
531 "No tags yet. Add tags when saving memories.".to_string(),
532 ));
533 }
534
535 let mut sorted: Vec<_> = tags.iter().collect();
536 sorted.sort_by(|a, b| b.1.cmp(a.1));
537
538 let output = sorted
539 .iter()
540 .map(|(tag, count)| format!(" {} ({} memories)", tag, count))
541 .collect::<Vec<_>>()
542 .join("\n");
543
544 Ok(ToolResult::success(format!(
545 "Available tags:\n\n{}",
546 output
547 )))
548 }
549
550 async fn execute_delete(&self, args: Value) -> Result<ToolResult> {
551 let id = args["id"]
552 .as_str()
553 .ok_or_else(|| anyhow::anyhow!("id is required for 'delete' action"))?;
554
555 let deleted = {
556 let mut store = self.store.lock().await;
557 store.delete(id)
558 };
559
560 if deleted {
561 self.persist().await?;
562 Ok(ToolResult::success(format!("Memory deleted: {}", id)))
563 } else {
564 Ok(ToolResult::error(format!("Memory not found: {}", id)))
565 }
566 }
567
568 async fn execute_stats(&self, _args: Value) -> Result<ToolResult> {
569 let stats = {
570 let store = self.store.lock().await;
571 store.stats()
572 };
573
574 let tags_output = if stats.tags.is_empty() {
575 "None".to_string()
576 } else {
577 let mut sorted: Vec<_> = stats.tags.iter().collect();
578 sorted.sort_by(|a, b| b.1.cmp(a.1));
579 sorted
580 .iter()
581 .take(10)
582 .map(|(t, c)| format!(" {}: {}", t, c))
583 .collect::<Vec<_>>()
584 .join("\n")
585 };
586
587 Ok(ToolResult::success(format!(
588 "Memory Statistics:\n\n\
589 Total entries: {}\n\
590 Total accesses: {}\n\
591 Unique tags: {}\n\n\
592 Top tags:\n{}",
593 stats.total_entries, stats.total_accesses, stats.unique_tags, tags_output
594 )))
595 }
596}
597
598#[cfg(test)]
599mod tests;