Skip to main content

sediment/mcp/
tools.rs

1//! MCP Tool definitions for Sediment
2//!
3//! 4 tools: store, recall, list, forget
4
5use std::sync::Arc;
6
7use serde::Deserialize;
8use serde_json::{Value, json};
9
10use crate::access::AccessTracker;
11use crate::consolidation::{ConsolidationQueue, spawn_consolidation};
12use crate::db::score_with_decay;
13use crate::graph::GraphStore;
14use crate::item::{Item, ItemFilters};
15use crate::retry::{RetryConfig, with_retry};
16use crate::{Database, ListScope, StoreScope};
17
18use super::protocol::{CallToolResult, Tool};
19use super::server::ServerContext;
20
21/// Spawn a background task with panic logging. If the task panics, the panic
22/// is caught and logged as an error instead of silently disappearing.
23fn spawn_logged(name: &'static str, fut: impl std::future::Future<Output = ()> + Send + 'static) {
24    tokio::spawn(async move {
25        let result = tokio::task::spawn(fut).await;
26        if let Err(e) = result {
27            tracing::error!("Background task '{}' panicked: {:?}", name, e);
28        }
29    });
30}
31
32/// Get all available tools (4 total)
33pub fn get_tools() -> Vec<Tool> {
34    let store_schema = {
35        #[allow(unused_mut)]
36        let mut props = json!({
37            "content": {
38                "type": "string",
39                "description": "The content to store"
40            },
41            "scope": {
42                "type": "string",
43                "enum": ["project", "global"],
44                "default": "project",
45                "description": "Where to store: 'project' (current project) or 'global' (all projects)"
46            }
47        });
48
49        #[cfg(feature = "bench")]
50        {
51            props.as_object_mut().unwrap().insert(
52                "created_at".to_string(),
53                json!({
54                    "type": "number",
55                    "description": "Override creation timestamp (Unix seconds). Benchmark builds only."
56                }),
57            );
58        }
59
60        json!({
61            "type": "object",
62            "properties": props,
63            "required": ["content"]
64        })
65    };
66
67    vec![
68        Tool {
69            name: "store".to_string(),
70            description: "Store content for later retrieval. Use for preferences, facts, reference material, docs, or any information worth remembering. Long content is automatically chunked for better search.".to_string(),
71            input_schema: store_schema,
72        },
73        Tool {
74            name: "recall".to_string(),
75            description: "Search stored content by semantic similarity. Returns matching items with relevant excerpts for chunked content.".to_string(),
76            input_schema: json!({
77                "type": "object",
78                "properties": {
79                    "query": {
80                        "type": "string",
81                        "description": "What to search for (semantic search)"
82                    },
83                    "limit": {
84                        "type": "number",
85                        "default": 5,
86                        "description": "Maximum number of results"
87                    }
88                },
89                "required": ["query"]
90            }),
91        },
92        Tool {
93            name: "list".to_string(),
94            description: "List stored items.".to_string(),
95            input_schema: json!({
96                "type": "object",
97                "properties": {
98                    "limit": {
99                        "type": "number",
100                        "default": 10,
101                        "description": "Maximum number of results"
102                    },
103                    "scope": {
104                        "type": "string",
105                        "enum": ["project", "global", "all"],
106                        "default": "project",
107                        "description": "Which items to list: 'project', 'global', or 'all'"
108                    }
109                }
110            }),
111        },
112        Tool {
113            name: "forget".to_string(),
114            description: "Delete a stored item by its ID.".to_string(),
115            input_schema: json!({
116                "type": "object",
117                "properties": {
118                    "id": {
119                        "type": "string",
120                        "description": "The item ID to delete"
121                    }
122                },
123                "required": ["id"]
124            }),
125        },
126    ]
127}
128
129// ========== Parameter Structs ==========
130
131#[derive(Debug, Deserialize)]
132pub struct StoreParams {
133    pub content: String,
134    #[serde(default)]
135    pub scope: Option<String>,
136    /// Override creation timestamp (Unix seconds). Benchmark builds only.
137    #[cfg(feature = "bench")]
138    #[serde(default)]
139    pub created_at: Option<i64>,
140}
141
142#[derive(Debug, Deserialize)]
143pub struct RecallParams {
144    pub query: String,
145    #[serde(default)]
146    pub limit: Option<usize>,
147}
148
149#[derive(Debug, Deserialize)]
150pub struct ListParams {
151    #[serde(default)]
152    pub limit: Option<usize>,
153    #[serde(default)]
154    pub scope: Option<String>,
155}
156
/// Arguments accepted by the `forget` tool.
#[derive(Debug, Deserialize)]
pub struct ForgetParams {
    /// ID of the item to delete (required).
    pub id: String,
}
161
162// ========== Recall Configuration ==========
163
/// Controls which graph and scoring features are enabled during recall.
/// Used by benchmarks to measure the impact of individual features.
///
/// Derives `Debug`, `Clone`, `PartialEq` and `Eq` so configurations can be
/// logged, duplicated and compared.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RecallConfig {
    /// Create missing graph nodes for search hits.
    pub enable_graph_backfill: bool,
    /// Add graph neighbours of the top hits to the output.
    pub enable_graph_expansion: bool,
    /// Add co-access suggestions to the output.
    pub enable_co_access: bool,
    /// Re-score hits with decay and trust bonus, then re-sort.
    pub enable_decay_scoring: bool,
    /// Toggle for background work; not consulted by the code in this view.
    pub enable_background_tasks: bool,
}
173
174impl Default for RecallConfig {
175    fn default() -> Self {
176        Self {
177            enable_graph_backfill: true,
178            enable_graph_expansion: true,
179            enable_co_access: true,
180            enable_decay_scoring: true,
181            enable_background_tasks: true,
182        }
183    }
184}
185
/// Result of a recall pipeline execution (for benchmark consumption).
pub struct RecallResult {
    /// Search hits; re-scored and re-sorted when decay scoring is enabled.
    pub results: Vec<crate::item::SearchResult>,
    /// JSON entries for graph neighbours of the top hits that were not
    /// themselves hits.
    pub graph_expanded: Vec<Value>,
    /// JSON entries for co-accessed items that were not themselves hits.
    pub suggested: Vec<Value>,
    /// Raw (pre-decay/trust) similarity scores, keyed by item ID.
    /// Only populated when decay scoring is enabled.
    pub raw_similarities: std::collections::HashMap<String, f32>,
}
194
195// ========== Tool Execution ==========
196
197pub async fn execute_tool(ctx: &ServerContext, name: &str, args: Option<Value>) -> CallToolResult {
198    let config = RetryConfig::default();
199    let args_for_retry = args.clone();
200
201    let result = with_retry(&config, || {
202        let ctx_ref = ctx;
203        let name_ref = name;
204        let args_clone = args_for_retry.clone();
205
206        async move {
207            // Open fresh connection with shared embedder
208            let mut db = Database::open_with_embedder(
209                &ctx_ref.db_path,
210                ctx_ref.project_id.clone(),
211                ctx_ref.embedder.clone(),
212            )
213            .await
214            .map_err(|e| sanitize_err("Failed to open database", e))?;
215
216            // Open access tracker
217            let tracker = AccessTracker::open(&ctx_ref.access_db_path)
218                .map_err(|e| sanitize_err("Failed to open access tracker", e))?;
219
220            // Open graph store (shares access.db)
221            let graph = GraphStore::open(&ctx_ref.access_db_path)
222                .map_err(|e| sanitize_err("Failed to open graph store", e))?;
223
224            let result = match name_ref {
225                "store" => execute_store(&mut db, &tracker, &graph, ctx_ref, args_clone).await,
226                "recall" => execute_recall(&mut db, &tracker, &graph, ctx_ref, args_clone).await,
227                "list" => execute_list(&mut db, args_clone).await,
228                "forget" => execute_forget(&mut db, &graph, ctx_ref, args_clone).await,
229                _ => return Ok(CallToolResult::error(format!("Unknown tool: {}", name_ref))),
230            };
231
232            if result.is_error.unwrap_or(false)
233                && let Some(content) = result.content.first()
234                && is_retryable_error(&content.text)
235            {
236                return Err(content.text.clone());
237            }
238
239            Ok(result)
240        }
241    })
242    .await;
243
244    match result {
245        Ok(call_result) => call_result,
246        Err(e) => {
247            tracing::error!("Operation failed after retries: {}", e);
248            CallToolResult::error("Operation failed after retries")
249        }
250    }
251}
252
/// Report whether an error message looks like a transient failure worth
/// retrying.
///
/// The check is a case-insensitive substring match against a fixed list of
/// patterns. The patterns are stored already lowercased so that only the
/// message is lowercased per call.
fn is_retryable_error(error_msg: &str) -> bool {
    // Must stay lowercase: they are compared against the lowercased message.
    const RETRYABLE_PATTERNS: [&str; 8] = [
        "connection",
        "timeout",
        "temporarily unavailable",
        "resource busy",
        "lock",
        "i/o error",
        "failed to open",
        "failed to connect",
    ];

    let lowered = error_msg.to_lowercase();
    RETRYABLE_PATTERNS
        .iter()
        .any(|pattern| lowered.contains(pattern))
}
270
271// ========== Tool Implementations ==========
272
273async fn execute_store(
274    db: &mut Database,
275    _tracker: &AccessTracker,
276    graph: &GraphStore,
277    ctx: &ServerContext,
278    args: Option<Value>,
279) -> CallToolResult {
280    let params: StoreParams = match args {
281        Some(v) => match serde_json::from_value(v) {
282            Ok(p) => p,
283            Err(e) => {
284                tracing::debug!("Parameter validation failed: {}", e);
285                return CallToolResult::error("Invalid parameters");
286            }
287        },
288        None => return CallToolResult::error("Missing parameters"),
289    };
290
291    // Reject oversized content to prevent OOM during embedding/chunking.
292    // Intentionally byte-based (not char-based): memory allocation is proportional
293    // to byte length, so this is the correct metric for OOM prevention.
294    const MAX_CONTENT_BYTES: usize = 1_000_000;
295    if params.content.len() > MAX_CONTENT_BYTES {
296        return CallToolResult::error(format!(
297            "Content too large: {} bytes (max {} bytes)",
298            params.content.len(),
299            MAX_CONTENT_BYTES
300        ));
301    }
302
303    // Parse scope
304    let scope = params
305        .scope
306        .as_deref()
307        .map(|s| s.parse::<StoreScope>())
308        .transpose();
309
310    let scope = match scope {
311        Ok(s) => s.unwrap_or(StoreScope::Project),
312        Err(e) => return CallToolResult::error(e),
313    };
314
315    // Build item
316    let mut item = Item::new(&params.content);
317
318    // Override created_at if provided (benchmark builds only)
319    #[cfg(feature = "bench")]
320    if let Some(ts) = params.created_at {
321        if let Some(dt) = chrono::DateTime::from_timestamp(ts, 0) {
322            item = item.with_created_at(dt);
323        }
324    }
325
326    // Set project_id based on scope
327    if scope == StoreScope::Project
328        && let Some(project_id) = db.project_id()
329    {
330        item = item.with_project_id(project_id);
331    }
332
333    match db.store_item(item).await {
334        Ok(store_result) => {
335            let new_id = store_result.id.clone();
336
337            // Create graph node
338            let now = chrono::Utc::now().timestamp();
339            let project_id = db.project_id().map(|s| s.to_string());
340            if let Err(e) = graph.add_node(&new_id, project_id.as_deref(), now) {
341                tracing::warn!("graph add_node failed: {}", e);
342            }
343
344            // Enqueue consolidation candidates from conflicts
345            if !store_result.potential_conflicts.is_empty()
346                && let Ok(queue) = ConsolidationQueue::open(&ctx.access_db_path)
347            {
348                for conflict in &store_result.potential_conflicts {
349                    if let Err(e) = queue.enqueue(&new_id, &conflict.id, conflict.similarity as f64)
350                    {
351                        tracing::warn!("enqueue consolidation failed: {}", e);
352                    }
353                }
354            }
355
356            let mut result = json!({
357                "success": true,
358                "id": new_id,
359                "message": format!("Stored in {} scope", scope)
360            });
361
362            if !store_result.potential_conflicts.is_empty() {
363                let conflicts: Vec<Value> = store_result
364                    .potential_conflicts
365                    .iter()
366                    .map(|c| {
367                        json!({
368                            "id": c.id,
369                            "content": c.content,
370                            "similarity": format!("{:.2}", c.similarity)
371                        })
372                    })
373                    .collect();
374                result["potential_conflicts"] = json!(conflicts);
375            }
376
377            CallToolResult::success(
378                serde_json::to_string_pretty(&result)
379                    .unwrap_or_else(|e| format!("{{\"error\": \"serialization failed: {}\"}}", e)),
380            )
381        }
382        Err(e) => sanitized_error("Failed to store item", e),
383    }
384}
385
386/// Core recall pipeline, extracted for benchmarking.
387///
388/// Performs: vector search, optional decay scoring, optional graph backfill,
389/// optional 1-hop graph expansion, and optional co-access suggestions.
390pub async fn recall_pipeline(
391    db: &mut Database,
392    tracker: &AccessTracker,
393    graph: &GraphStore,
394    query: &str,
395    limit: usize,
396    filters: ItemFilters,
397    config: &RecallConfig,
398) -> std::result::Result<RecallResult, String> {
399    let mut results = db
400        .search_items(query, limit, filters)
401        .await
402        .map_err(|e| format!("Search failed: {}", e))?;
403
404    if results.is_empty() {
405        return Ok(RecallResult {
406            results: Vec::new(),
407            graph_expanded: Vec::new(),
408            suggested: Vec::new(),
409            raw_similarities: std::collections::HashMap::new(),
410        });
411    }
412
413    // Lazy graph backfill (uses project_id from SearchResult, no extra queries)
414    if config.enable_graph_backfill {
415        for result in &results {
416            if let Err(e) = graph.ensure_node_exists(
417                &result.id,
418                result.project_id.as_deref(),
419                result.created_at.timestamp(),
420            ) {
421                tracing::warn!("ensure_node_exists failed: {}", e);
422            }
423        }
424    }
425
426    // Decay scoring — preserve raw similarity for transparency
427    let mut raw_similarities: std::collections::HashMap<String, f32> =
428        std::collections::HashMap::new();
429    if config.enable_decay_scoring {
430        let item_ids: Vec<&str> = results.iter().map(|r| r.id.as_str()).collect();
431        let access_records = tracker.get_accesses(&item_ids).unwrap_or_default();
432        let validation_counts = tracker.get_validation_counts(&item_ids).unwrap_or_default();
433        let edge_counts = graph.get_edge_counts(&item_ids).unwrap_or_default();
434        let now = chrono::Utc::now().timestamp();
435
436        for result in &mut results {
437            raw_similarities.insert(result.id.clone(), result.similarity);
438
439            let created_at = result.created_at.timestamp();
440            let (access_count, last_accessed) = match access_records.get(&result.id) {
441                Some(rec) => (rec.access_count, Some(rec.last_accessed_at)),
442                None => (0, None),
443            };
444
445            let base_score = score_with_decay(
446                result.similarity,
447                now,
448                created_at,
449                access_count,
450                last_accessed,
451            );
452
453            let validation_count = validation_counts.get(&result.id).copied().unwrap_or(0);
454            let edge_count = edge_counts.get(&result.id).copied().unwrap_or(0);
455            let trust_bonus =
456                1.0 + 0.05 * (1.0 + validation_count as f64).ln() as f32 + 0.02 * edge_count as f32;
457
458            result.similarity = (base_score * trust_bonus).min(1.0);
459        }
460
461        results.sort_by(|a, b| {
462            b.similarity
463                .partial_cmp(&a.similarity)
464                .unwrap_or(std::cmp::Ordering::Equal)
465        });
466    }
467
468    // Record access
469    for result in &results {
470        let created_at = result.created_at.timestamp();
471        if let Err(e) = tracker.record_access(&result.id, created_at) {
472            tracing::warn!("record_access failed: {}", e);
473        }
474    }
475
476    // Graph expansion
477    let existing_ids: std::collections::HashSet<String> =
478        results.iter().map(|r| r.id.clone()).collect();
479
480    let mut graph_expanded = Vec::new();
481    if config.enable_graph_expansion {
482        let top_ids: Vec<&str> = results.iter().take(5).map(|r| r.id.as_str()).collect();
483        if let Ok(neighbors) = graph.get_neighbors(&top_ids, 0.5) {
484            // Collect neighbor IDs not already in results, then batch fetch
485            let neighbor_info: Vec<(String, String)> = neighbors
486                .into_iter()
487                .filter(|(id, _, _)| !existing_ids.contains(id))
488                .map(|(id, rel_type, _)| (id, rel_type))
489                .collect();
490
491            let neighbor_ids: Vec<&str> = neighbor_info.iter().map(|(id, _)| id.as_str()).collect();
492            if let Ok(items) = db.get_items_batch(&neighbor_ids).await {
493                let item_map: std::collections::HashMap<&str, &Item> =
494                    items.iter().map(|item| (item.id.as_str(), item)).collect();
495
496                for (neighbor_id, rel_type) in &neighbor_info {
497                    if let Some(item) = item_map.get(neighbor_id.as_str()) {
498                        let sr = crate::item::SearchResult::from_item(item, 0.05);
499                        let mut entry = json!({
500                            "id": sr.id,
501                            "similarity": "graph",
502                            "created": sr.created_at.to_rfc3339(),
503                            "graph_expanded": true,
504                            "rel_type": rel_type,
505                        });
506                        // Only include content for same-project or global items
507                        let same_project = match (db.project_id(), item.project_id.as_deref()) {
508                            (Some(current), Some(item_pid)) => current == item_pid,
509                            (_, None) => true,
510                            _ => false,
511                        };
512                        if same_project {
513                            entry["content"] = json!(sr.content);
514                        } else {
515                            entry["cross_project"] = json!(true);
516                        }
517                        graph_expanded.push(entry);
518                    }
519                }
520            }
521        }
522    }
523
524    // Co-access suggestions (batch fetch)
525    let mut suggested = Vec::new();
526    if config.enable_co_access {
527        let top3_ids: Vec<&str> = results.iter().take(3).map(|r| r.id.as_str()).collect();
528        if let Ok(co_accessed) = graph.get_co_accessed(&top3_ids, 3) {
529            let co_info: Vec<(String, i64)> = co_accessed
530                .into_iter()
531                .filter(|(id, _)| !existing_ids.contains(id))
532                .collect();
533
534            let co_ids: Vec<&str> = co_info.iter().map(|(id, _)| id.as_str()).collect();
535            if let Ok(items) = db.get_items_batch(&co_ids).await {
536                let item_map: std::collections::HashMap<&str, &Item> =
537                    items.iter().map(|item| (item.id.as_str(), item)).collect();
538
539                for (co_id, co_count) in &co_info {
540                    if let Some(item) = item_map.get(co_id.as_str()) {
541                        let same_project = match (db.project_id(), item.project_id.as_deref()) {
542                            (Some(current), Some(item_pid)) => current == item_pid,
543                            (_, None) => true,
544                            _ => false,
545                        };
546                        let mut entry = json!({
547                            "id": item.id,
548                            "reason": format!("frequently recalled with result (co-accessed {} times)", co_count),
549                        });
550                        if same_project {
551                            entry["content"] = json!(truncate(&item.content, 100));
552                        } else {
553                            entry["cross_project"] = json!(true);
554                        }
555                        suggested.push(entry);
556                    }
557                }
558            }
559        }
560    }
561
562    Ok(RecallResult {
563        results,
564        graph_expanded,
565        suggested,
566        raw_similarities,
567    })
568}
569
570async fn execute_recall(
571    db: &mut Database,
572    tracker: &AccessTracker,
573    graph: &GraphStore,
574    ctx: &ServerContext,
575    args: Option<Value>,
576) -> CallToolResult {
577    let params: RecallParams = match args {
578        Some(v) => match serde_json::from_value(v) {
579            Ok(p) => p,
580            Err(e) => {
581                tracing::debug!("Parameter validation failed: {}", e);
582                return CallToolResult::error("Invalid parameters");
583            }
584        },
585        None => return CallToolResult::error("Missing parameters"),
586    };
587
588    // Reject oversized queries to prevent OOM during tokenization.
589    // The model truncates to 512 tokens (~2KB of English text), so anything
590    // beyond 10KB is wasted processing.
591    const MAX_QUERY_BYTES: usize = 10_000;
592    if params.query.len() > MAX_QUERY_BYTES {
593        return CallToolResult::error(format!(
594            "Query too large: {} bytes (max {} bytes)",
595            params.query.len(),
596            MAX_QUERY_BYTES
597        ));
598    }
599
600    let limit = params.limit.unwrap_or(5).min(100);
601
602    let filters = ItemFilters::new();
603
604    let config = RecallConfig::default();
605
606    let recall_result =
607        match recall_pipeline(db, tracker, graph, &params.query, limit, filters, &config).await {
608            Ok(r) => r,
609            Err(e) => {
610                tracing::error!("Recall failed: {}", e);
611                return CallToolResult::error("Search failed");
612            }
613        };
614
615    if recall_result.results.is_empty() {
616        return CallToolResult::success("No items found matching your query.");
617    }
618
619    let results = &recall_result.results;
620
621    // Batch-fetch neighbors for all result IDs
622    let all_result_ids: Vec<&str> = results.iter().map(|r| r.id.as_str()).collect();
623    let neighbors_map = graph
624        .get_neighbors_mapped(&all_result_ids, 0.5)
625        .unwrap_or_default();
626
627    let formatted: Vec<Value> = results
628        .iter()
629        .map(|r| {
630            let mut obj = json!({
631                "id": r.id,
632                "content": r.content,
633                "similarity": format!("{:.2}", r.similarity),
634                "created": r.created_at.to_rfc3339(),
635            });
636
637            // Include raw (pre-decay) similarity when decay scoring was applied
638            if let Some(&raw_sim) = recall_result.raw_similarities.get(&r.id)
639                && (raw_sim - r.similarity).abs() > 0.001
640            {
641                obj["raw_similarity"] = json!(format!("{:.2}", raw_sim));
642            }
643
644            if let Some(ref excerpt) = r.relevant_excerpt {
645                obj["relevant_excerpt"] = json!(excerpt);
646            }
647
648            // Cross-project flag
649            if let Some(ref current_pid) = ctx.project_id
650                && let Some(ref item_pid) = r.project_id
651                && item_pid != current_pid
652            {
653                obj["cross_project"] = json!(true);
654            }
655
656            // Related IDs from graph (batch lookup)
657            if let Some(related) = neighbors_map.get(&r.id)
658                && !related.is_empty()
659            {
660                obj["related_ids"] = json!(related);
661            }
662
663            obj
664        })
665        .collect();
666
667    let mut result_json = json!({
668        "count": results.len(),
669        "results": formatted
670    });
671
672    if !recall_result.graph_expanded.is_empty() {
673        result_json["graph_expanded"] = json!(recall_result.graph_expanded);
674    }
675
676    if !recall_result.suggested.is_empty() {
677        result_json["suggested"] = json!(recall_result.suggested);
678    }
679
680    // Fire-and-forget: background consolidation (Phase 2b)
681    spawn_consolidation(
682        Arc::new(ctx.db_path.clone()),
683        Arc::new(ctx.access_db_path.clone()),
684        ctx.project_id.clone(),
685        ctx.embedder.clone(),
686        ctx.consolidation_semaphore.clone(),
687    );
688
689    // Fire-and-forget: co-access recording (Phase 3a)
690    let result_ids: Vec<String> = results.iter().map(|r| r.id.clone()).collect();
691    let access_db_path = ctx.access_db_path.clone();
692    spawn_logged("co_access", async move {
693        if let Ok(g) = GraphStore::open(&access_db_path) {
694            if let Err(e) = g.record_co_access(&result_ids) {
695                tracing::warn!("record_co_access failed: {}", e);
696            }
697        } else {
698            tracing::warn!("co_access: failed to open graph store");
699        }
700    });
701
702    // Periodic maintenance: every 10th recall
703    let run_count = ctx
704        .recall_count
705        .fetch_add(1, std::sync::atomic::Ordering::AcqRel);
706    if run_count % 10 == 9 {
707        // Clustering
708        let access_db_path = ctx.access_db_path.clone();
709        spawn_logged("clustering", async move {
710            if let Ok(g) = GraphStore::open(&access_db_path)
711                && let Ok(clusters) = g.detect_clusters()
712            {
713                for (a, b, c) in &clusters {
714                    let label = format!("cluster-{}", &a[..8.min(a.len())]);
715                    if let Err(e) = g.add_related_edge(a, b, 0.8, &label) {
716                        tracing::warn!("cluster add_related_edge failed: {}", e);
717                    }
718                    if let Err(e) = g.add_related_edge(b, c, 0.8, &label) {
719                        tracing::warn!("cluster add_related_edge failed: {}", e);
720                    }
721                    if let Err(e) = g.add_related_edge(a, c, 0.8, &label) {
722                        tracing::warn!("cluster add_related_edge failed: {}", e);
723                    }
724                }
725                if !clusters.is_empty() {
726                    tracing::info!("Detected {} clusters", clusters.len());
727                }
728            }
729        });
730
731        // Consolidation queue cleanup: purge old processed entries
732        let access_db_path2 = ctx.access_db_path.clone();
733        spawn_logged("consolidation_cleanup", async move {
734            if let Ok(q) = crate::consolidation::ConsolidationQueue::open(&access_db_path2)
735                && let Err(e) = q.cleanup_processed()
736            {
737                tracing::warn!("consolidation queue cleanup failed: {}", e);
738            }
739        });
740    }
741
742    CallToolResult::success(
743        serde_json::to_string_pretty(&result_json)
744            .unwrap_or_else(|e| format!("{{\"error\": \"serialization failed: {}\"}}", e)),
745    )
746}
747
748async fn execute_list(db: &mut Database, args: Option<Value>) -> CallToolResult {
749    let params: ListParams =
750        args.and_then(|v| serde_json::from_value(v).ok())
751            .unwrap_or(ListParams {
752                limit: None,
753                scope: None,
754            });
755
756    let limit = params.limit.unwrap_or(10).min(100);
757
758    let filters = ItemFilters::new();
759
760    let scope = params
761        .scope
762        .as_deref()
763        .map(|s| s.parse::<ListScope>())
764        .transpose();
765
766    let scope = match scope {
767        Ok(s) => s.unwrap_or(ListScope::Project),
768        Err(e) => return CallToolResult::error(e),
769    };
770
771    match db.list_items(filters, Some(limit), scope).await {
772        Ok(items) => {
773            if items.is_empty() {
774                CallToolResult::success("No items stored yet.")
775            } else {
776                let formatted: Vec<Value> = items
777                    .iter()
778                    .map(|item| {
779                        let content_preview = truncate(&item.content, 100);
780                        let mut obj = json!({
781                            "id": item.id,
782                            "content": content_preview,
783                            "created": item.created_at.to_rfc3339(),
784                        });
785
786                        if item.is_chunked {
787                            obj["chunked"] = json!(true);
788                        }
789
790                        obj
791                    })
792                    .collect();
793
794                let result = json!({
795                    "count": items.len(),
796                    "items": formatted
797                });
798
799                CallToolResult::success(
800                    serde_json::to_string_pretty(&result).unwrap_or_else(|e| {
801                        format!("{{\"error\": \"serialization failed: {}\"}}", e)
802                    }),
803                )
804            }
805        }
806        Err(e) => sanitized_error("Failed to list items", e),
807    }
808}
809
810async fn execute_forget(
811    db: &mut Database,
812    graph: &GraphStore,
813    ctx: &ServerContext,
814    args: Option<Value>,
815) -> CallToolResult {
816    let params: ForgetParams = match args {
817        Some(v) => match serde_json::from_value(v) {
818            Ok(p) => p,
819            Err(e) => {
820                tracing::debug!("Parameter validation failed: {}", e);
821                return CallToolResult::error("Invalid parameters");
822            }
823        },
824        None => return CallToolResult::error("Missing parameters"),
825    };
826
827    // Access control: verify the item belongs to the current project (or is global)
828    if let Some(ref current_pid) = ctx.project_id {
829        match db.get_item(&params.id).await {
830            Ok(Some(item)) => {
831                if let Some(ref item_pid) = item.project_id
832                    && item_pid != current_pid
833                {
834                    return CallToolResult::error(format!(
835                        "Cannot delete item {} from a different project",
836                        params.id
837                    ));
838                }
839            }
840            Ok(None) => return CallToolResult::error(format!("Item not found: {}", params.id)),
841            Err(e) => {
842                return sanitized_error("Failed to look up item", e);
843            }
844        }
845    }
846
847    match db.delete_item(&params.id).await {
848        Ok(true) => {
849            // Remove from graph
850            if let Err(e) = graph.remove_node(&params.id) {
851                tracing::warn!("remove_node failed: {}", e);
852            }
853
854            let result = json!({
855                "success": true,
856                "message": format!("Deleted item: {}", params.id)
857            });
858            CallToolResult::success(
859                serde_json::to_string_pretty(&result)
860                    .unwrap_or_else(|e| format!("{{\"error\": \"serialization failed: {}\"}}", e)),
861            )
862        }
863        Ok(false) => CallToolResult::error(format!("Item not found: {}", params.id)),
864        Err(e) => sanitized_error("Failed to delete item", e),
865    }
866}
867
868// ========== Utilities ==========
869
870/// Log a detailed internal error and return a sanitized message to the MCP client.
871/// This prevents leaking file paths, database internals, or OS details.
872fn sanitized_error(context: &str, err: impl std::fmt::Display) -> CallToolResult {
873    tracing::error!("{}: {}", context, err);
874    CallToolResult::error(context.to_string())
875}
876
877/// Like `sanitized_error` but returns a String for use inside `map_err` chains.
878fn sanitize_err(context: &str, err: impl std::fmt::Display) -> String {
879    tracing::error!("{}: {}", context, err);
880    context.to_string()
881}
882
/// Shorten `s` to at most `max_len` characters (Unicode scalar values, not bytes).
///
/// * Strings that already fit are returned unchanged.
/// * Longer strings are cut to `max_len - 3` characters followed by `"..."`.
/// * When `max_len <= 3` there is no room for the ellipsis, so the first
///   `max_len` characters are returned without a suffix.
///
/// Cuts always land on a character boundary, so multi-byte input never panics.
/// Only the first `max_len + 1` characters are inspected, which keeps the cost
/// independent of the total input length.
fn truncate(s: &str, max_len: usize) -> String {
    // The byte length is an upper bound on the character count, so this cheap
    // check settles most short inputs. Otherwise the string fits exactly when
    // there is no character at index `max_len`.
    if s.len() <= max_len || s.char_indices().nth(max_len).is_none() {
        return s.to_string();
    }

    // Not enough room for a "..." suffix when max_len <= 3; just keep max_len chars.
    let (keep, suffix) = if max_len <= 3 {
        (max_len, "")
    } else {
        (max_len - 3, "...")
    };

    // `keep <= max_len` and the string has more than `max_len` chars, so the
    // index exists; the fallback only guards against future edits.
    let cut = s.char_indices().nth(keep).map_or(s.len(), |(i, _)| i);
    format!("{}{}", &s[..cut], suffix)
}
898
#[cfg(test)]
mod tests {
    use super::*;

    use std::path::PathBuf;
    use std::sync::Mutex;
    use tokio::sync::Semaphore;

    // ========== Unit Tests ==========

    /// `truncate` with limits too small to fit the "..." suffix.
    #[test]
    fn test_truncate_small_max_len() {
        // Bug #25: truncate should not panic when max_len < 3
        assert_eq!(truncate("hello", 0), "");
        assert_eq!(truncate("hello", 1), "h");
        assert_eq!(truncate("hello", 2), "he");
        assert_eq!(truncate("hello", 3), "hel");
        assert_eq!(truncate("hi", 3), "hi"); // shorter than max, no truncation
        assert_eq!(truncate("hello", 5), "hello");
        assert_eq!(truncate("hello!", 5), "he...");
    }

    /// `truncate` counts characters, not bytes, and cuts on character boundaries.
    #[test]
    fn test_truncate_unicode() {
        assert_eq!(truncate("héllo wörld", 5), "hé...");
        assert_eq!(truncate("日本語テスト", 4), "日...");
        assert_eq!(truncate("日本語", 3), "日本語"); // exactly max_len chars, no truncation
    }

    // ========== Integration Tests ==========

    /// Create a ServerContext with temp dirs for integration testing.
    ///
    /// The returned `TempDir` must be kept alive for the duration of the test;
    /// dropping it removes the directories the context points at.
    async fn setup_test_context() -> (ServerContext, tempfile::TempDir) {
        let tmp = tempfile::TempDir::new().unwrap();
        let db_path = tmp.path().join("data");
        let access_db_path = tmp.path().join("access.db");

        let embedder = Arc::new(crate::Embedder::new().unwrap());
        let project_id = Some("test-project-00000001".to_string());

        let ctx = ServerContext {
            db_path,
            access_db_path,
            project_id,
            embedder,
            cwd: PathBuf::from("."),
            consolidation_semaphore: Arc::new(Semaphore::new(1)),
            recall_count: std::sync::atomic::AtomicU64::new(0),
            rate_limit: Mutex::new(super::super::server::RateLimitState {
                window_start_ms: 0,
                count: 0,
            }),
        };

        (ctx, tmp)
    }

    /// Store `content` with default parameters and assert that the call succeeded.
    ///
    /// Used for test setup so that a failing store is reported at its source
    /// instead of as a confusing failure in a later assertion.
    async fn store_ok(ctx: &ServerContext, content: &str) -> CallToolResult {
        let result = execute_tool(ctx, "store", Some(json!({ "content": content }))).await;
        assert!(
            result.is_error.is_none(),
            "Store should succeed: {:?}",
            result.content
        );
        result
    }

    /// A stored item can be found again through `recall`.
    #[tokio::test]
    #[ignore] // requires model download
    async fn test_store_and_recall_roundtrip() {
        let (ctx, _tmp) = setup_test_context().await;

        // Store an item
        store_ok(&ctx, "Rust is a systems programming language").await;

        // Recall by query
        let recall_result = execute_tool(
            &ctx,
            "recall",
            Some(json!({ "query": "systems programming language" })),
        )
        .await;
        assert!(recall_result.is_error.is_none(), "Recall should succeed");

        let text = &recall_result.content[0].text;
        assert!(
            text.contains("Rust is a systems programming language"),
            "Recall should return stored content, got: {}",
            text
        );
    }

    /// `list` reports every item stored in the project.
    #[tokio::test]
    #[ignore] // requires model download
    async fn test_store_and_list() {
        let (ctx, _tmp) = setup_test_context().await;

        // Store 2 items
        store_ok(&ctx, "First item for listing").await;
        store_ok(&ctx, "Second item for listing").await;

        // List items
        let list_result = execute_tool(&ctx, "list", Some(json!({ "scope": "project" }))).await;
        assert!(list_result.is_error.is_none(), "List should succeed");

        let text = &list_result.content[0].text;
        let parsed: Value = serde_json::from_str(text).unwrap();
        assert_eq!(parsed["count"], 2, "Should list 2 items");
    }

    /// Storing the same content twice flags a potential conflict on the second store.
    #[tokio::test]
    #[ignore] // requires model download
    async fn test_store_conflict_detection() {
        let (ctx, _tmp) = setup_test_context().await;

        // Store first item
        store_ok(&ctx, "The quick brown fox jumps over the lazy dog").await;

        // Store nearly identical item
        let result = store_ok(&ctx, "The quick brown fox jumps over the lazy dog").await;

        let text = &result.content[0].text;
        let parsed: Value = serde_json::from_str(text).unwrap();
        assert!(
            parsed.get("potential_conflicts").is_some(),
            "Should detect conflict for near-duplicate content, got: {}",
            text
        );
    }

    /// After `forget`, the item no longer shows up in `list`.
    #[tokio::test]
    #[ignore] // requires model download
    async fn test_forget_removes_item() {
        let (ctx, _tmp) = setup_test_context().await;

        // Store an item
        let store_result = store_ok(&ctx, "Item to be forgotten").await;
        let text = &store_result.content[0].text;
        let parsed: Value = serde_json::from_str(text).unwrap();
        let item_id = parsed["id"].as_str().unwrap().to_string();

        // Forget it
        let forget_result = execute_tool(&ctx, "forget", Some(json!({ "id": item_id }))).await;
        assert!(forget_result.is_error.is_none(), "Forget should succeed");

        // List should be empty
        let list_result = execute_tool(&ctx, "list", Some(json!({ "scope": "project" }))).await;
        let text = &list_result.content[0].text;
        assert!(
            text.contains("No items stored yet"),
            "Should have no items after forget, got: {}",
            text
        );
    }

    /// `recall` against an empty database succeeds and says nothing was found.
    #[tokio::test]
    #[ignore] // requires model download
    async fn test_recall_empty_db() {
        let (ctx, _tmp) = setup_test_context().await;

        let result = execute_tool(&ctx, "recall", Some(json!({ "query": "anything" }))).await;
        assert!(
            result.is_error.is_none(),
            "Recall on empty DB should not error"
        );

        let text = &result.content[0].text;
        assert!(
            text.contains("No items found"),
            "Should indicate no items found, got: {}",
            text
        );
    }

    /// `store` refuses content above the size limit.
    #[tokio::test]
    #[ignore] // requires model download
    async fn test_store_rejects_oversized_content() {
        let (ctx, _tmp) = setup_test_context().await;

        let large_content = "x".repeat(1_100_000); // >1MB
        let result = execute_tool(&ctx, "store", Some(json!({ "content": large_content }))).await;
        assert_eq!(
            result.is_error,
            Some(true),
            "Should reject oversized content"
        );

        let text = &result.content[0].text;
        assert!(
            text.contains("too large"),
            "Error should mention size, got: {}",
            text
        );
    }

    /// `recall` refuses queries above the size limit.
    #[tokio::test]
    #[ignore] // requires model download
    async fn test_recall_rejects_oversized_query() {
        let (ctx, _tmp) = setup_test_context().await;

        let large_query = "x".repeat(11_000); // >10KB
        let result = execute_tool(&ctx, "recall", Some(json!({ "query": large_query }))).await;
        assert_eq!(
            result.is_error,
            Some(true),
            "Should reject oversized query"
        );

        let text = &result.content[0].text;
        assert!(
            text.contains("too large"),
            "Error should mention size, got: {}",
            text
        );
    }

    /// `store` reports an error when parameters are absent or incomplete.
    #[tokio::test]
    #[ignore] // requires model download
    async fn test_store_missing_params() {
        let (ctx, _tmp) = setup_test_context().await;

        // No params at all
        let result = execute_tool(&ctx, "store", None).await;
        assert_eq!(result.is_error, Some(true), "Should error with no params");

        // Empty object (missing required 'content')
        let result = execute_tool(&ctx, "store", Some(json!({}))).await;
        assert_eq!(
            result.is_error,
            Some(true),
            "Should error with missing content"
        );
    }
}