Skip to main content

sediment/mcp/
tools.rs

1//! MCP Tool definitions for Sediment
2//!
3//! 5 tools: store, recall, list, forget, connections
4
5use std::sync::Arc;
6
7use chrono::DateTime;
8use serde::Deserialize;
9use serde_json::{Value, json};
10
11use crate::access::AccessTracker;
12use crate::consolidation::{ConsolidationQueue, spawn_consolidation};
13use crate::db::score_with_decay;
14use crate::graph::GraphStore;
15use crate::item::{Item, ItemFilters};
16use crate::retry::{RetryConfig, with_retry};
17use crate::{Database, ListScope, StoreScope};
18
19use super::protocol::{CallToolResult, Tool};
20use super::server::ServerContext;
21
22/// Spawn a background task with panic logging. If the task panics, the panic
23/// is caught and logged as an error instead of silently disappearing.
24fn spawn_logged(name: &'static str, fut: impl std::future::Future<Output = ()> + Send + 'static) {
25    tokio::spawn(async move {
26        let result = tokio::task::spawn(fut).await;
27        if let Err(e) = result {
28            tracing::error!("Background task '{}' panicked: {:?}", name, e);
29        }
30    });
31}
32
33/// Get all available tools (5 total)
34pub fn get_tools() -> Vec<Tool> {
35    vec![
36        Tool {
37            name: "store".to_string(),
38            description: "Store content for later retrieval. Use for preferences, facts, reference material, docs, or any information worth remembering. Long content is automatically chunked for better search.".to_string(),
39            input_schema: json!({
40                "type": "object",
41                "properties": {
42                    "content": {
43                        "type": "string",
44                        "description": "The content to store"
45                    },
46                    "title": {
47                        "type": "string",
48                        "description": "Optional title (recommended for long content)"
49                    },
50                    "tags": {
51                        "type": "array",
52                        "items": { "type": "string" },
53                        "description": "Tags for categorization"
54                    },
55                    "source": {
56                        "type": "string",
57                        "description": "Source attribution (e.g., URL, file path, 'conversation')"
58                    },
59                    "metadata": {
60                        "type": "object",
61                        "description": "Custom JSON metadata"
62                    },
63                    "expires_at": {
64                        "type": "string",
65                        "description": "ISO datetime when this should expire (optional)"
66                    },
67                    "scope": {
68                        "type": "string",
69                        "enum": ["project", "global"],
70                        "default": "project",
71                        "description": "Where to store: 'project' (current project) or 'global' (all projects)"
72                    },
73                    "replace": {
74                        "type": "string",
75                        "description": "ID of an existing item to replace (stores new item first, then deletes old)"
76                    },
77                    "related": {
78                        "type": "array",
79                        "items": { "type": "string" },
80                        "description": "IDs of related items to link in the knowledge graph"
81                    }
82                },
83                "required": ["content"]
84            }),
85        },
86        Tool {
87            name: "recall".to_string(),
88            description: "Search stored content by semantic similarity. Returns matching items with relevant excerpts for chunked content.".to_string(),
89            input_schema: json!({
90                "type": "object",
91                "properties": {
92                    "query": {
93                        "type": "string",
94                        "description": "What to search for (semantic search)"
95                    },
96                    "limit": {
97                        "type": "number",
98                        "default": 5,
99                        "description": "Maximum number of results"
100                    },
101                    "tags": {
102                        "type": "array",
103                        "items": { "type": "string" },
104                        "description": "Filter by tags (any match)"
105                    },
106                    "min_similarity": {
107                        "type": "number",
108                        "default": 0.3,
109                        "description": "Minimum similarity threshold (0.0-1.0). Lower values return more results."
110                    }
111                },
112                "required": ["query"]
113            }),
114        },
115        Tool {
116            name: "list".to_string(),
117            description: "List stored items with optional filtering.".to_string(),
118            input_schema: json!({
119                "type": "object",
120                "properties": {
121                    "tags": {
122                        "type": "array",
123                        "items": { "type": "string" },
124                        "description": "Filter by tags"
125                    },
126                    "limit": {
127                        "type": "number",
128                        "default": 10,
129                        "description": "Maximum number of results"
130                    },
131                    "scope": {
132                        "type": "string",
133                        "enum": ["project", "global", "all"],
134                        "default": "project",
135                        "description": "Which items to list: 'project', 'global', or 'all'"
136                    }
137                }
138            }),
139        },
140        Tool {
141            name: "forget".to_string(),
142            description: "Delete a stored item by its ID.".to_string(),
143            input_schema: json!({
144                "type": "object",
145                "properties": {
146                    "id": {
147                        "type": "string",
148                        "description": "The item ID to delete"
149                    }
150                },
151                "required": ["id"]
152            }),
153        },
154        Tool {
155            name: "connections".to_string(),
156            description: "Show the relationship graph for a stored item. Returns all connections including related items, superseded items, and frequently co-accessed items.".to_string(),
157            input_schema: json!({
158                "type": "object",
159                "properties": {
160                    "id": {
161                        "type": "string",
162                        "description": "The item ID to show connections for"
163                    }
164                },
165                "required": ["id"]
166            }),
167        },
168    ]
169}
170
171// ========== Parameter Structs ==========
172
173#[derive(Debug, Deserialize)]
174pub struct StoreParams {
175    pub content: String,
176    #[serde(default)]
177    pub title: Option<String>,
178    #[serde(default)]
179    pub tags: Option<Vec<String>>,
180    #[serde(default)]
181    pub source: Option<String>,
182    #[serde(default)]
183    pub metadata: Option<Value>,
184    #[serde(default)]
185    pub expires_at: Option<String>,
186    #[serde(default)]
187    pub scope: Option<String>,
188    #[serde(default)]
189    pub replace: Option<String>,
190    #[serde(default)]
191    pub related: Option<Vec<String>>,
192}
193
194#[derive(Debug, Deserialize)]
195pub struct RecallParams {
196    pub query: String,
197    #[serde(default)]
198    pub limit: Option<usize>,
199    #[serde(default)]
200    pub tags: Option<Vec<String>>,
201    #[serde(default)]
202    pub min_similarity: Option<f32>,
203}
204
205#[derive(Debug, Deserialize)]
206pub struct ListParams {
207    #[serde(default)]
208    pub tags: Option<Vec<String>>,
209    #[serde(default)]
210    pub limit: Option<usize>,
211    #[serde(default)]
212    pub scope: Option<String>,
213}
214
215#[derive(Debug, Deserialize)]
216pub struct ForgetParams {
217    pub id: String,
218}
219
220#[derive(Debug, Deserialize)]
221pub struct ConnectionsParams {
222    pub id: String,
223}
224
225// ========== Recall Configuration ==========
226
/// Controls which graph and scoring features are enabled during recall.
/// Used by benchmarks to measure the impact of individual features.
///
/// [`Default`] enables every feature. The type is a plain set of flags, so it
/// is `Copy` and can be printed or compared when reporting benchmark runs.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RecallConfig {
    /// Lazily create graph nodes for search results that lack one.
    pub enable_graph_backfill: bool,
    /// Add 1-hop graph neighbours of the top results to the response.
    pub enable_graph_expansion: bool,
    /// Suggest items frequently co-accessed with the top results.
    pub enable_co_access: bool,
    /// Re-rank results using time decay, access history and a trust bonus.
    pub enable_decay_scoring: bool,
    /// Not read by `recall_pipeline` itself. Presumably it gates background
    /// work spawned by callers (TODO confirm against `execute_recall`).
    pub enable_background_tasks: bool,
}

impl Default for RecallConfig {
    fn default() -> Self {
        Self {
            enable_graph_backfill: true,
            enable_graph_expansion: true,
            enable_co_access: true,
            enable_decay_scoring: true,
            enable_background_tasks: true,
        }
    }
}
248
249/// Result of a recall pipeline execution (for benchmark consumption).
250pub struct RecallResult {
251    pub results: Vec<crate::item::SearchResult>,
252    pub graph_expanded: Vec<Value>,
253    pub suggested: Vec<Value>,
254    /// Raw (pre-decay/trust) similarity scores, keyed by item ID
255    pub raw_similarities: std::collections::HashMap<String, f32>,
256}
257
258// ========== Tool Execution ==========
259
260pub async fn execute_tool(ctx: &ServerContext, name: &str, args: Option<Value>) -> CallToolResult {
261    let config = RetryConfig::default();
262    let args_for_retry = args.clone();
263
264    let result = with_retry(&config, || {
265        let ctx_ref = ctx;
266        let name_ref = name;
267        let args_clone = args_for_retry.clone();
268
269        async move {
270            // Open fresh connection with shared embedder
271            let mut db = Database::open_with_embedder(
272                &ctx_ref.db_path,
273                ctx_ref.project_id.clone(),
274                ctx_ref.embedder.clone(),
275            )
276            .await
277            .map_err(|e| sanitize_err("Failed to open database", e))?;
278
279            // Open access tracker
280            let tracker = AccessTracker::open(&ctx_ref.access_db_path)
281                .map_err(|e| sanitize_err("Failed to open access tracker", e))?;
282
283            // Open graph store (shares access.db)
284            let graph = GraphStore::open(&ctx_ref.access_db_path)
285                .map_err(|e| sanitize_err("Failed to open graph store", e))?;
286
287            let result = match name_ref {
288                "store" => execute_store(&mut db, &tracker, &graph, ctx_ref, args_clone).await,
289                "recall" => execute_recall(&mut db, &tracker, &graph, ctx_ref, args_clone).await,
290                "list" => execute_list(&mut db, args_clone).await,
291                "forget" => execute_forget(&mut db, &graph, ctx_ref, args_clone).await,
292                "connections" => execute_connections(&mut db, &graph, ctx_ref, args_clone).await,
293                _ => return Ok(CallToolResult::error(format!("Unknown tool: {}", name_ref))),
294            };
295
296            if result.is_error.unwrap_or(false)
297                && let Some(content) = result.content.first()
298                && is_retryable_error(&content.text)
299            {
300                return Err(content.text.clone());
301            }
302
303            Ok(result)
304        }
305    })
306    .await;
307
308    match result {
309        Ok(call_result) => call_result,
310        Err(e) => {
311            tracing::error!("Operation failed after retries: {}", e);
312            CallToolResult::error("Operation failed after retries")
313        }
314    }
315}
316
/// Return `true` if a tool error message looks transient and the operation is
/// worth retrying.
///
/// Transient means connection, timeout, locking/busy, I/O or open/connect
/// failures.
///
/// Matching is a case-insensitive substring search and is deliberately broad.
/// For example, "lock" matches "locked" and "deadlock", and also unrelated
/// words such as "block".
fn is_retryable_error(error_msg: &str) -> bool {
    // Stored pre-lowercased so each call lowercases only the message, not
    // every pattern.
    const RETRYABLE_PATTERNS: [&str; 8] = [
        "connection",
        "timeout",
        "temporarily unavailable",
        "resource busy",
        "lock",
        "i/o error",
        "failed to open",
        "failed to connect",
    ];

    let lower = error_msg.to_lowercase();
    RETRYABLE_PATTERNS.iter().any(|&p| lower.contains(p))
}
334
335// ========== Tool Implementations ==========
336
337async fn execute_store(
338    db: &mut Database,
339    tracker: &AccessTracker,
340    graph: &GraphStore,
341    ctx: &ServerContext,
342    args: Option<Value>,
343) -> CallToolResult {
344    let params: StoreParams = match args {
345        Some(v) => match serde_json::from_value(v) {
346            Ok(p) => p,
347            Err(e) => {
348                tracing::debug!("Parameter validation failed: {}", e);
349                return CallToolResult::error("Invalid parameters");
350            }
351        },
352        None => return CallToolResult::error("Missing parameters"),
353    };
354
355    // Reject oversized content to prevent OOM during embedding/chunking.
356    // Intentionally byte-based (not char-based): memory allocation is proportional
357    // to byte length, so this is the correct metric for OOM prevention.
358    const MAX_CONTENT_BYTES: usize = 1_000_000;
359    if params.content.len() > MAX_CONTENT_BYTES {
360        return CallToolResult::error(format!(
361            "Content too large: {} bytes (max {} bytes)",
362            params.content.len(),
363            MAX_CONTENT_BYTES
364        ));
365    }
366
367    // Validate field sizes to prevent abuse
368    const MAX_TITLE_LEN: usize = 1000;
369    const MAX_SOURCE_LEN: usize = 2000;
370    const MAX_TAG_LEN: usize = 200;
371    const MAX_TAG_COUNT: usize = 50;
372    const MAX_METADATA_BYTES: usize = 100_000;
373
374    if let Some(ref title) = params.title
375        && title.len() > MAX_TITLE_LEN
376    {
377        return CallToolResult::error(format!(
378            "Title too large: {} bytes (max {})",
379            title.len(),
380            MAX_TITLE_LEN
381        ));
382    }
383
384    if let Some(ref source) = params.source
385        && source.len() > MAX_SOURCE_LEN
386    {
387        return CallToolResult::error(format!(
388            "Source too large: {} bytes (max {})",
389            source.len(),
390            MAX_SOURCE_LEN
391        ));
392    }
393
394    if let Some(ref tags) = params.tags {
395        if tags.len() > MAX_TAG_COUNT {
396            return CallToolResult::error(format!(
397                "Too many tags: {} (max {})",
398                tags.len(),
399                MAX_TAG_COUNT
400            ));
401        }
402        for tag in tags {
403            if tag.len() > MAX_TAG_LEN {
404                return CallToolResult::error(format!(
405                    "Tag too large: {} bytes (max {})",
406                    tag.len(),
407                    MAX_TAG_LEN
408                ));
409            }
410        }
411    }
412
413    if let Some(ref metadata) = params.metadata {
414        let meta_size = metadata.to_string().len();
415        if meta_size > MAX_METADATA_BYTES {
416            return CallToolResult::error(format!(
417                "Metadata too large: {} bytes (max {})",
418                meta_size, MAX_METADATA_BYTES
419            ));
420        }
421    }
422
423    // Parse scope
424    let scope = params
425        .scope
426        .as_deref()
427        .map(|s| s.parse::<StoreScope>())
428        .transpose();
429
430    let scope = match scope {
431        Ok(s) => s.unwrap_or(StoreScope::Project),
432        Err(e) => return CallToolResult::error(e),
433    };
434
435    // Parse expires_at if provided
436    let expires_at = if let Some(ref exp_str) = params.expires_at {
437        match DateTime::parse_from_rfc3339(exp_str) {
438            Ok(dt) => Some(dt.with_timezone(&chrono::Utc)),
439            Err(e) => return CallToolResult::error(format!("Invalid expires_at: {}", e)),
440        }
441    } else {
442        None
443    };
444
445    // Validate that the item to replace exists and belongs to the current project
446    let replaced_id = if let Some(ref replace_id) = params.replace {
447        match db.get_item(replace_id).await {
448            Ok(Some(item)) => {
449                // Access control: prevent replacing items from other projects
450                if let Some(ref current_pid) = ctx.project_id
451                    && let Some(ref item_pid) = item.project_id
452                    && item_pid != current_pid
453                {
454                    return CallToolResult::error(format!(
455                        "Cannot replace item {} from a different project",
456                        replace_id
457                    ));
458                }
459                Some(replace_id.clone())
460            }
461            Ok(None) => {
462                return CallToolResult::error(format!(
463                    "Cannot replace: item not found: {}",
464                    replace_id
465                ));
466            }
467            Err(e) => {
468                return sanitized_error("Failed to look up item for replace", e);
469            }
470        }
471    } else {
472        None
473    };
474
475    // Build item
476    let mut tags = params.tags.unwrap_or_default();
477    let mut item = Item::new(&params.content).with_tags(tags.clone());
478
479    if let Some(title) = params.title {
480        item = item.with_title(title);
481    }
482
483    if let Some(source) = params.source {
484        item = item.with_source(source);
485    }
486
487    // Build metadata with provenance
488    let mut metadata = params.metadata.unwrap_or(json!({}));
489    if let Some(obj) = metadata.as_object_mut() {
490        let mut provenance = json!({
491            "v": 1,
492            "project_path": ctx.cwd.file_name().map(|n| n.to_string_lossy().into_owned()).unwrap_or_else(|| "unknown".to_string())
493        });
494        if let Some(ref rid) = replaced_id {
495            provenance["supersedes"] = json!(rid);
496        }
497        obj.insert("_provenance".to_string(), provenance);
498    }
499    item = item.with_metadata(metadata);
500
501    if let Some(exp) = expires_at {
502        item = item.with_expires_at(exp);
503    }
504
505    // Set project_id based on scope
506    if scope == StoreScope::Project
507        && let Some(project_id) = db.project_id()
508    {
509        item = item.with_project_id(project_id);
510    }
511
512    // Auto-tag inference (Phase 4a): if no user tags, infer from similar items
513    if tags.is_empty()
514        && let Ok(similar) = db.find_similar_items(&params.content, 0.85, 5).await
515    {
516        let mut tag_counts: std::collections::HashMap<String, usize> =
517            std::collections::HashMap::new();
518        for conflict in &similar {
519            if let Some(similar_item) = db.get_item(&conflict.id).await.ok().flatten() {
520                for tag in &similar_item.tags {
521                    if !tag.starts_with("auto:") {
522                        *tag_counts.entry(tag.clone()).or_insert(0) += 1;
523                    }
524                }
525            }
526        }
527        // If 2+ similar items share a tag, auto-apply it
528        let auto_tags: Vec<String> = tag_counts
529            .into_iter()
530            .filter(|(_, count)| *count >= 2)
531            .map(|(tag, _)| format!("auto:{}", tag))
532            .collect();
533        if !auto_tags.is_empty() {
534            tags = item.tags.clone();
535            tags.extend(auto_tags);
536            item = item.with_tags(tags);
537        }
538    }
539
540    match db.store_item(item).await {
541        Ok(store_result) => {
542            let new_id = store_result.id.clone();
543
544            // Create graph node
545            let now = chrono::Utc::now().timestamp();
546            let project_id = db.project_id().map(|s| s.to_string());
547            if let Err(e) = graph.add_node(&new_id, project_id.as_deref(), now) {
548                tracing::warn!("graph add_node failed: {}", e);
549            }
550
551            // Complete replace: now that the new item is stored, delete the old one.
552            // Intentionally non-atomic (store-before-delete): if the process crashes
553            // between store and delete, both items exist (benign duplication, not data
554            // loss). The consolidation system will detect and merge duplicates.
555            if let Some(ref old_id) = replaced_id {
556                // Record validation on the NEW item (the replacement is a "confirmed" version)
557                let now_ts = chrono::Utc::now().timestamp();
558                if let Err(e) = tracker.record_validation(&new_id, now_ts) {
559                    tracing::warn!("record_validation failed: {}", e);
560                }
561                // Transfer graph edges from old node to new node before removing old node
562                // If this fails, abort the replace to avoid inconsistent state
563                if let Err(e) = graph.transfer_edges(old_id, &new_id) {
564                    tracing::error!("transfer_edges failed, aborting replace: {}", e);
565                    // Compensating action: remove the new item since replace failed
566                    let _ = db.delete_item(&new_id).await;
567                    let _ = graph.remove_node(&new_id);
568                    return CallToolResult::error(
569                        "Replace failed during edge transfer. Original item preserved.",
570                    );
571                }
572                // Create SUPERSEDES edge (non-critical, continue on failure)
573                if let Err(e) = graph.add_supersedes_edge(&new_id, old_id) {
574                    tracing::warn!("add_supersedes_edge failed: {}", e);
575                }
576                // Delete old item from LanceDB
577                if let Err(e) = db.delete_item(old_id).await {
578                    tracing::error!(
579                        "delete_item failed during replace: {}. Old item may remain as duplicate.",
580                        e
581                    );
582                }
583                // Remove old graph node (and its remaining edges)
584                if let Err(e) = graph.remove_node(old_id) {
585                    tracing::warn!("remove_node failed: {}", e);
586                }
587            }
588
589            // Create RELATED edges if specified (only for IDs that exist in LanceDB)
590            if let Some(ref related_ids) = params.related {
591                let rid_refs: Vec<&str> = related_ids.iter().map(|s| s.as_str()).collect();
592                let existing_items = db.get_items_batch(&rid_refs).await.unwrap_or_default();
593                let valid_ids: std::collections::HashSet<&str> =
594                    existing_items.iter().map(|i| i.id.as_str()).collect();
595
596                for rid in related_ids {
597                    if !valid_ids.contains(rid.as_str()) {
598                        tracing::warn!("related ID not found, skipping edge: {}", rid);
599                        continue;
600                    }
601                    // Ensure target node exists in graph before creating edge
602                    let _ = graph.ensure_node_exists(rid, None, now);
603                    if let Err(e) = graph.add_related_edge(&new_id, rid, 1.0, "user_linked") {
604                        tracing::warn!("add_related_edge failed: {}", e);
605                    }
606                }
607            }
608
609            // Enqueue consolidation candidates from conflicts
610            if !store_result.potential_conflicts.is_empty()
611                && let Ok(queue) = ConsolidationQueue::open(&ctx.access_db_path)
612            {
613                for conflict in &store_result.potential_conflicts {
614                    if let Err(e) = queue.enqueue(&new_id, &conflict.id, conflict.similarity as f64)
615                    {
616                        tracing::warn!("enqueue consolidation failed: {}", e);
617                    }
618                }
619            }
620
621            let mut result = json!({
622                "success": true,
623                "id": new_id,
624                "message": format!("Stored in {} scope", scope)
625            });
626
627            if !store_result.potential_conflicts.is_empty() {
628                let conflicts: Vec<Value> = store_result
629                    .potential_conflicts
630                    .iter()
631                    .map(|c| {
632                        json!({
633                            "id": c.id,
634                            "content": c.content,
635                            "similarity": format!("{:.2}", c.similarity)
636                        })
637                    })
638                    .collect();
639                result["potential_conflicts"] = json!(conflicts);
640            }
641
642            CallToolResult::success(
643                serde_json::to_string_pretty(&result)
644                    .unwrap_or_else(|e| format!("{{\"error\": \"serialization failed: {}\"}}", e)),
645            )
646        }
647        Err(e) => sanitized_error("Failed to store item", e),
648    }
649}
650
651/// Core recall pipeline, extracted for benchmarking.
652///
653/// Performs: vector search, optional decay scoring, optional graph backfill,
654/// optional 1-hop graph expansion, and optional co-access suggestions.
655pub async fn recall_pipeline(
656    db: &mut Database,
657    tracker: &AccessTracker,
658    graph: &GraphStore,
659    query: &str,
660    limit: usize,
661    filters: ItemFilters,
662    config: &RecallConfig,
663) -> std::result::Result<RecallResult, String> {
664    let mut results = db
665        .search_items(query, limit, filters)
666        .await
667        .map_err(|e| format!("Search failed: {}", e))?;
668
669    if results.is_empty() {
670        return Ok(RecallResult {
671            results: Vec::new(),
672            graph_expanded: Vec::new(),
673            suggested: Vec::new(),
674            raw_similarities: std::collections::HashMap::new(),
675        });
676    }
677
678    // Lazy graph backfill (uses project_id from SearchResult, no extra queries)
679    if config.enable_graph_backfill {
680        for result in &results {
681            if let Err(e) = graph.ensure_node_exists(
682                &result.id,
683                result.project_id.as_deref(),
684                result.created_at.timestamp(),
685            ) {
686                tracing::warn!("ensure_node_exists failed: {}", e);
687            }
688        }
689    }
690
691    // Decay scoring — preserve raw similarity for transparency
692    let mut raw_similarities: std::collections::HashMap<String, f32> =
693        std::collections::HashMap::new();
694    if config.enable_decay_scoring {
695        let item_ids: Vec<&str> = results.iter().map(|r| r.id.as_str()).collect();
696        let access_records = tracker.get_accesses(&item_ids).unwrap_or_default();
697        let now = chrono::Utc::now().timestamp();
698
699        for result in &mut results {
700            raw_similarities.insert(result.id.clone(), result.similarity);
701
702            let created_at = result.created_at.timestamp();
703            let (access_count, last_accessed) = match access_records.get(&result.id) {
704                Some(rec) => (rec.access_count, Some(rec.last_accessed_at)),
705                None => (0, None),
706            };
707
708            let base_score = score_with_decay(
709                result.similarity,
710                now,
711                created_at,
712                access_count,
713                last_accessed,
714            );
715
716            let validation_count = tracker.get_validation_count(&result.id).unwrap_or(0);
717            let edge_count = graph.get_edge_count(&result.id).unwrap_or(0);
718            let trust_bonus =
719                1.0 + 0.05 * (1.0 + validation_count as f64).ln() as f32 + 0.02 * edge_count as f32;
720
721            result.similarity = (base_score * trust_bonus).min(1.0);
722        }
723
724        results.sort_by(|a, b| {
725            b.similarity
726                .partial_cmp(&a.similarity)
727                .unwrap_or(std::cmp::Ordering::Equal)
728        });
729    }
730
731    // Record access
732    for result in &results {
733        let created_at = result.created_at.timestamp();
734        if let Err(e) = tracker.record_access(&result.id, created_at) {
735            tracing::warn!("record_access failed: {}", e);
736        }
737    }
738
739    // Graph expansion
740    let existing_ids: std::collections::HashSet<String> =
741        results.iter().map(|r| r.id.clone()).collect();
742
743    let mut graph_expanded = Vec::new();
744    if config.enable_graph_expansion {
745        let top_ids: Vec<&str> = results.iter().take(5).map(|r| r.id.as_str()).collect();
746        if let Ok(neighbors) = graph.get_neighbors(&top_ids, 0.5) {
747            // Collect neighbor IDs not already in results, then batch fetch
748            let neighbor_info: Vec<(String, String)> = neighbors
749                .into_iter()
750                .filter(|(id, _, _)| !existing_ids.contains(id))
751                .map(|(id, rel_type, _)| (id, rel_type))
752                .collect();
753
754            let neighbor_ids: Vec<&str> = neighbor_info.iter().map(|(id, _)| id.as_str()).collect();
755            if let Ok(items) = db.get_items_batch(&neighbor_ids).await {
756                let item_map: std::collections::HashMap<&str, &Item> =
757                    items.iter().map(|item| (item.id.as_str(), item)).collect();
758
759                for (neighbor_id, rel_type) in &neighbor_info {
760                    if let Some(item) = item_map.get(neighbor_id.as_str()) {
761                        let sr = crate::item::SearchResult::from_item(item, 0.05);
762                        let mut entry = json!({
763                            "id": sr.id,
764                            "similarity": "graph",
765                            "created": sr.created_at.to_rfc3339(),
766                            "graph_expanded": true,
767                            "rel_type": rel_type,
768                        });
769                        // Only include content for same-project or global items
770                        let same_project = match (db.project_id(), item.project_id.as_deref()) {
771                            (Some(current), Some(item_pid)) => current == item_pid,
772                            (_, None) => true,
773                            _ => false,
774                        };
775                        if same_project {
776                            entry["content"] = json!(sr.content);
777                        } else {
778                            entry["cross_project"] = json!(true);
779                        }
780                        graph_expanded.push(entry);
781                    }
782                }
783            }
784        }
785    }
786
787    // Co-access suggestions (batch fetch)
788    let mut suggested = Vec::new();
789    if config.enable_co_access {
790        let top3_ids: Vec<&str> = results.iter().take(3).map(|r| r.id.as_str()).collect();
791        if let Ok(co_accessed) = graph.get_co_accessed(&top3_ids, 3) {
792            let co_info: Vec<(String, i64)> = co_accessed
793                .into_iter()
794                .filter(|(id, _)| !existing_ids.contains(id))
795                .collect();
796
797            let co_ids: Vec<&str> = co_info.iter().map(|(id, _)| id.as_str()).collect();
798            if let Ok(items) = db.get_items_batch(&co_ids).await {
799                let item_map: std::collections::HashMap<&str, &Item> =
800                    items.iter().map(|item| (item.id.as_str(), item)).collect();
801
802                for (co_id, co_count) in &co_info {
803                    if let Some(item) = item_map.get(co_id.as_str()) {
804                        let same_project = match (db.project_id(), item.project_id.as_deref()) {
805                            (Some(current), Some(item_pid)) => current == item_pid,
806                            (_, None) => true,
807                            _ => false,
808                        };
809                        let mut entry = json!({
810                            "id": item.id,
811                            "reason": format!("frequently recalled with result (co-accessed {} times)", co_count),
812                        });
813                        if same_project {
814                            entry["content"] = json!(truncate(&item.content, 100));
815                        } else {
816                            entry["cross_project"] = json!(true);
817                        }
818                        suggested.push(entry);
819                    }
820                }
821            }
822        }
823    }
824
825    Ok(RecallResult {
826        results,
827        graph_expanded,
828        suggested,
829        raw_similarities,
830    })
831}
832
833async fn execute_recall(
834    db: &mut Database,
835    tracker: &AccessTracker,
836    graph: &GraphStore,
837    ctx: &ServerContext,
838    args: Option<Value>,
839) -> CallToolResult {
840    let params: RecallParams = match args {
841        Some(v) => match serde_json::from_value(v) {
842            Ok(p) => p,
843            Err(e) => {
844                tracing::debug!("Parameter validation failed: {}", e);
845                return CallToolResult::error("Invalid parameters");
846            }
847        },
848        None => return CallToolResult::error("Missing parameters"),
849    };
850
851    // Reject oversized queries to prevent OOM during tokenization
852    const MAX_QUERY_BYTES: usize = 100_000;
853    if params.query.len() > MAX_QUERY_BYTES {
854        return CallToolResult::error(format!(
855            "Query too large: {} bytes (max {} bytes)",
856            params.query.len(),
857            MAX_QUERY_BYTES
858        ));
859    }
860
861    let limit = params.limit.unwrap_or(5).min(100);
862    let min_similarity = params.min_similarity.unwrap_or(0.3);
863
864    let mut filters = ItemFilters::new().with_min_similarity(min_similarity);
865
866    if let Some(tags) = params.tags {
867        filters = filters.with_tags(tags);
868    }
869
870    let config = RecallConfig::default();
871
872    let recall_result =
873        match recall_pipeline(db, tracker, graph, &params.query, limit, filters, &config).await {
874            Ok(r) => r,
875            Err(e) => {
876                tracing::error!("Recall failed: {}", e);
877                return CallToolResult::error("Search failed");
878            }
879        };
880
881    if recall_result.results.is_empty() {
882        return CallToolResult::success("No items found matching your query.");
883    }
884
885    let results = &recall_result.results;
886
887    let formatted: Vec<Value> = results
888        .iter()
889        .map(|r| {
890            let mut obj = json!({
891                "id": r.id,
892                "content": r.content,
893                "similarity": format!("{:.2}", r.similarity),
894                "created": r.created_at.to_rfc3339(),
895            });
896
897            // Include raw (pre-decay) similarity when decay scoring was applied
898            if let Some(&raw_sim) = recall_result.raw_similarities.get(&r.id)
899                && (raw_sim - r.similarity).abs() > 0.001
900            {
901                obj["raw_similarity"] = json!(format!("{:.2}", raw_sim));
902            }
903
904            if let Some(ref excerpt) = r.relevant_excerpt {
905                obj["relevant_excerpt"] = json!(excerpt);
906            }
907            if !r.tags.is_empty() {
908                obj["tags"] = json!(r.tags);
909            }
910            if let Some(ref source) = r.source {
911                obj["source"] = json!(source);
912            }
913
914            // Cross-project flag (Phase 3c) — uses cached project_id/metadata from SearchResult
915            if let Some(ref current_pid) = ctx.project_id
916                && let Some(ref item_pid) = r.project_id
917                && item_pid != current_pid
918            {
919                obj["cross_project"] = json!(true);
920                if let Some(ref meta) = r.metadata
921                    && let Some(prov) = meta.get("_provenance")
922                    && let Some(pp) = prov.get("project_path")
923                {
924                    obj["project_path"] = pp.clone();
925                }
926            }
927
928            // Related IDs from graph (Phase 1d)
929            if let Ok(neighbors) = graph.get_neighbors(&[r.id.as_str()], 0.5) {
930                let related: Vec<String> = neighbors.iter().map(|(id, _, _)| id.clone()).collect();
931                if !related.is_empty() {
932                    obj["related_ids"] = json!(related);
933                }
934            }
935
936            obj
937        })
938        .collect();
939
940    let mut result_json = json!({
941        "count": results.len(),
942        "results": formatted
943    });
944
945    if !recall_result.graph_expanded.is_empty() {
946        result_json["graph_expanded"] = json!(recall_result.graph_expanded);
947    }
948
949    if !recall_result.suggested.is_empty() {
950        result_json["suggested"] = json!(recall_result.suggested);
951    }
952
953    // Fire-and-forget: background consolidation (Phase 2b)
954    spawn_consolidation(
955        Arc::new(ctx.db_path.clone()),
956        Arc::new(ctx.access_db_path.clone()),
957        ctx.project_id.clone(),
958        ctx.embedder.clone(),
959        ctx.consolidation_semaphore.clone(),
960    );
961
962    // Fire-and-forget: co-access recording (Phase 3a)
963    let result_ids: Vec<String> = results.iter().map(|r| r.id.clone()).collect();
964    let access_db_path = ctx.access_db_path.clone();
965    spawn_logged("co_access", async move {
966        if let Ok(g) = GraphStore::open(&access_db_path) {
967            if let Err(e) = g.record_co_access(&result_ids) {
968                tracing::warn!("record_co_access failed: {}", e);
969            }
970        } else {
971            tracing::warn!("co_access: failed to open graph store");
972        }
973    });
974
975    // Periodic maintenance: every 10th recall
976    let run_count = ctx
977        .recall_count
978        .fetch_add(1, std::sync::atomic::Ordering::AcqRel);
979    if run_count % 10 == 9 {
980        // Clustering
981        let access_db_path = ctx.access_db_path.clone();
982        spawn_logged("clustering", async move {
983            if let Ok(g) = GraphStore::open(&access_db_path)
984                && let Ok(clusters) = g.detect_clusters()
985            {
986                for (a, b, c) in &clusters {
987                    let label = format!("cluster-{}", &a[..8.min(a.len())]);
988                    if let Err(e) = g.add_related_edge(a, b, 0.8, &label) {
989                        tracing::warn!("cluster add_related_edge failed: {}", e);
990                    }
991                    if let Err(e) = g.add_related_edge(b, c, 0.8, &label) {
992                        tracing::warn!("cluster add_related_edge failed: {}", e);
993                    }
994                    if let Err(e) = g.add_related_edge(a, c, 0.8, &label) {
995                        tracing::warn!("cluster add_related_edge failed: {}", e);
996                    }
997                }
998                if !clusters.is_empty() {
999                    tracing::info!("Detected {} clusters", clusters.len());
1000                }
1001            }
1002        });
1003
1004        // Expired item cleanup
1005        let db_path = ctx.db_path.clone();
1006        let project_id = ctx.project_id.clone();
1007        let embedder = ctx.embedder.clone();
1008        spawn_logged("cleanup_expired", async move {
1009            match Database::open_with_embedder(&db_path, project_id, embedder).await {
1010                Ok(db) => {
1011                    if let Err(e) = db.cleanup_expired().await {
1012                        tracing::warn!("cleanup_expired failed: {}", e);
1013                    }
1014                }
1015                Err(e) => tracing::warn!("cleanup_expired: failed to open db: {}", e),
1016            }
1017        });
1018
1019        // Consolidation queue cleanup: purge old processed entries
1020        let access_db_path2 = ctx.access_db_path.clone();
1021        spawn_logged("consolidation_cleanup", async move {
1022            if let Ok(q) = crate::consolidation::ConsolidationQueue::open(&access_db_path2)
1023                && let Err(e) = q.cleanup_processed()
1024            {
1025                tracing::warn!("consolidation queue cleanup failed: {}", e);
1026            }
1027        });
1028    }
1029
1030    CallToolResult::success(
1031        serde_json::to_string_pretty(&result_json)
1032            .unwrap_or_else(|e| format!("{{\"error\": \"serialization failed: {}\"}}", e)),
1033    )
1034}
1035
1036async fn execute_list(db: &mut Database, args: Option<Value>) -> CallToolResult {
1037    let params: ListParams =
1038        args.and_then(|v| serde_json::from_value(v).ok())
1039            .unwrap_or(ListParams {
1040                tags: None,
1041                limit: None,
1042                scope: None,
1043            });
1044
1045    let limit = params.limit.unwrap_or(10).min(100);
1046
1047    let mut filters = ItemFilters::new();
1048
1049    if let Some(tags) = params.tags {
1050        filters = filters.with_tags(tags);
1051    }
1052
1053    let scope = params
1054        .scope
1055        .as_deref()
1056        .map(|s| s.parse::<ListScope>())
1057        .transpose();
1058
1059    let scope = match scope {
1060        Ok(s) => s.unwrap_or(ListScope::Project),
1061        Err(e) => return CallToolResult::error(e),
1062    };
1063
1064    match db.list_items(filters, Some(limit), scope).await {
1065        Ok(items) => {
1066            if items.is_empty() {
1067                CallToolResult::success("No items stored yet.")
1068            } else {
1069                let formatted: Vec<Value> = items
1070                    .iter()
1071                    .map(|item| {
1072                        let content_preview = truncate(&item.content, 100);
1073                        let mut obj = json!({
1074                            "id": item.id,
1075                            "content": content_preview,
1076                            "created": item.created_at.to_rfc3339(),
1077                        });
1078
1079                        if let Some(ref title) = item.title {
1080                            obj["title"] = json!(title);
1081                        }
1082                        if !item.tags.is_empty() {
1083                            obj["tags"] = json!(item.tags);
1084                        }
1085                        if item.is_chunked {
1086                            obj["chunked"] = json!(true);
1087                        }
1088
1089                        obj
1090                    })
1091                    .collect();
1092
1093                let result = json!({
1094                    "count": items.len(),
1095                    "items": formatted
1096                });
1097
1098                CallToolResult::success(
1099                    serde_json::to_string_pretty(&result).unwrap_or_else(|e| {
1100                        format!("{{\"error\": \"serialization failed: {}\"}}", e)
1101                    }),
1102                )
1103            }
1104        }
1105        Err(e) => sanitized_error("Failed to list items", e),
1106    }
1107}
1108
1109async fn execute_forget(
1110    db: &mut Database,
1111    graph: &GraphStore,
1112    ctx: &ServerContext,
1113    args: Option<Value>,
1114) -> CallToolResult {
1115    let params: ForgetParams = match args {
1116        Some(v) => match serde_json::from_value(v) {
1117            Ok(p) => p,
1118            Err(e) => {
1119                tracing::debug!("Parameter validation failed: {}", e);
1120                return CallToolResult::error("Invalid parameters");
1121            }
1122        },
1123        None => return CallToolResult::error("Missing parameters"),
1124    };
1125
1126    // Access control: verify the item belongs to the current project (or is global)
1127    if let Some(ref current_pid) = ctx.project_id {
1128        match db.get_item(&params.id).await {
1129            Ok(Some(item)) => {
1130                if let Some(ref item_pid) = item.project_id
1131                    && item_pid != current_pid
1132                {
1133                    return CallToolResult::error(format!(
1134                        "Cannot delete item {} from a different project",
1135                        params.id
1136                    ));
1137                }
1138            }
1139            Ok(None) => return CallToolResult::error(format!("Item not found: {}", params.id)),
1140            Err(e) => {
1141                return sanitized_error("Failed to look up item", e);
1142            }
1143        }
1144    }
1145
1146    match db.delete_item(&params.id).await {
1147        Ok(true) => {
1148            // Remove from graph
1149            if let Err(e) = graph.remove_node(&params.id) {
1150                tracing::warn!("remove_node failed: {}", e);
1151            }
1152
1153            let result = json!({
1154                "success": true,
1155                "message": format!("Deleted item: {}", params.id)
1156            });
1157            CallToolResult::success(
1158                serde_json::to_string_pretty(&result)
1159                    .unwrap_or_else(|e| format!("{{\"error\": \"serialization failed: {}\"}}", e)),
1160            )
1161        }
1162        Ok(false) => CallToolResult::error(format!("Item not found: {}", params.id)),
1163        Err(e) => sanitized_error("Failed to delete item", e),
1164    }
1165}
1166
1167async fn execute_connections(
1168    db: &mut Database,
1169    graph: &GraphStore,
1170    ctx: &ServerContext,
1171    args: Option<Value>,
1172) -> CallToolResult {
1173    let params: ConnectionsParams = match args {
1174        Some(v) => match serde_json::from_value(v) {
1175            Ok(p) => p,
1176            Err(e) => {
1177                tracing::debug!("Parameter validation failed: {}", e);
1178                return CallToolResult::error("Invalid parameters");
1179            }
1180        },
1181        None => return CallToolResult::error("Missing parameters"),
1182    };
1183
1184    // Verify item exists and belongs to the current project
1185    match db.get_item(&params.id).await {
1186        Ok(Some(item)) => {
1187            if let Some(ref current_pid) = ctx.project_id
1188                && let Some(ref item_pid) = item.project_id
1189                && item_pid != current_pid
1190            {
1191                return CallToolResult::error(format!(
1192                    "Cannot view connections for item {} from a different project",
1193                    params.id
1194                ));
1195            }
1196        }
1197        Ok(None) => return CallToolResult::error(format!("Item not found: {}", params.id)),
1198        Err(e) => return sanitized_error("Failed to get item", e),
1199    }
1200
1201    match graph.get_full_connections(&params.id) {
1202        Ok(connections) => {
1203            // Batch fetch all connected items
1204            let target_ids: Vec<&str> = connections.iter().map(|c| c.target_id.as_str()).collect();
1205            let items = db.get_items_batch(&target_ids).await.unwrap_or_default();
1206            let item_map: std::collections::HashMap<&str, &Item> =
1207                items.iter().map(|item| (item.id.as_str(), item)).collect();
1208
1209            let mut conn_json: Vec<Value> = Vec::new();
1210
1211            for conn in &connections {
1212                let mut obj = json!({
1213                    "id": conn.target_id,
1214                    "type": conn.rel_type,
1215                    "strength": conn.strength,
1216                });
1217
1218                if let Some(count) = conn.count {
1219                    obj["count"] = json!(count);
1220                }
1221
1222                // Add content preview from batch, but only if the connected item
1223                // belongs to the same project (or is global). This prevents
1224                // cross-project content leakage via graph edges.
1225                if let Some(item) = item_map.get(conn.target_id.as_str()) {
1226                    let same_project = match (&ctx.project_id, &item.project_id) {
1227                        (Some(current), Some(item_pid)) => current == item_pid,
1228                        (_, None) => true, // Global items are visible to all
1229                        _ => false,
1230                    };
1231                    if same_project {
1232                        obj["content_preview"] = json!(truncate(&item.content, 80));
1233                    } else {
1234                        obj["cross_project"] = json!(true);
1235                    }
1236                }
1237
1238                conn_json.push(obj);
1239            }
1240
1241            let result = json!({
1242                "item_id": params.id,
1243                "connections": conn_json
1244            });
1245
1246            CallToolResult::success(
1247                serde_json::to_string_pretty(&result)
1248                    .unwrap_or_else(|e| format!("{{\"error\": \"serialization failed: {}\"}}", e)),
1249            )
1250        }
1251        Err(e) => sanitized_error("Failed to get connections", e),
1252    }
1253}
1254
1255// ========== Utilities ==========
1256
1257/// Log a detailed internal error and return a sanitized message to the MCP client.
1258/// This prevents leaking file paths, database internals, or OS details.
1259fn sanitized_error(context: &str, err: impl std::fmt::Display) -> CallToolResult {
1260    tracing::error!("{}: {}", context, err);
1261    CallToolResult::error(context.to_string())
1262}
1263
1264/// Like `sanitized_error` but returns a String for use inside `map_err` chains.
1265fn sanitize_err(context: &str, err: impl std::fmt::Display) -> String {
1266    tracing::error!("{}: {}", context, err);
1267    context.to_string()
1268}
1269
/// Shorten `s` to at most `max_len` characters (Unicode scalar values).
///
/// Strings that already fit are returned unchanged. Longer strings keep their
/// first `max_len - 3` characters followed by `"..."`, so the result is exactly
/// `max_len` characters long. When `max_len <= 3` there is no room for the
/// ellipsis and the first `max_len` characters are returned as-is. Cuts always
/// fall on char boundaries, so multi-byte text never panics.
fn truncate(s: &str, max_len: usize) -> String {
    // A char at index `max_len` exists iff `s` is longer than `max_len` chars.
    // Probing for it stops after `max_len + 1` chars instead of counting the
    // whole (potentially very large) string.
    if s.char_indices().nth(max_len).is_none() {
        return s.to_string();
    }
    if max_len <= 3 {
        // Not enough room for "..." suffix; just take max_len chars
        return s.chars().take(max_len).collect();
    }
    // Byte offset of the first char to drop; always present because `s` has
    // more than `max_len` chars.
    let cut = s
        .char_indices()
        .nth(max_len - 3)
        .map_or(s.len(), |(i, _)| i);
    format!("{}...", &s[..cut])
}
1285
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_truncate_small_max_len() {
        // Bug #25: truncate should not panic when max_len < 3
        assert_eq!(truncate("hello", 0), "");
        assert_eq!(truncate("hello", 1), "h");
        assert_eq!(truncate("hello", 2), "he");
        assert_eq!(truncate("hello", 3), "hel");
        assert_eq!(truncate("hi", 3), "hi"); // shorter than max, no truncation
        assert_eq!(truncate("hello", 5), "hello");
        assert_eq!(truncate("hello!", 5), "he...");
    }

    #[test]
    fn test_truncate_unicode() {
        assert_eq!(truncate("héllo wörld", 5), "hé...");
        assert_eq!(truncate("日本語テスト", 4), "日...");
    }

    #[test]
    fn test_truncate_empty_and_exact_length() {
        // Empty input is returned unchanged regardless of max_len.
        assert_eq!(truncate("", 0), "");
        assert_eq!(truncate("", 10), "");
        // Exactly max_len chars: no ellipsis.
        assert_eq!(truncate("abcd", 4), "abcd");
        // One char over: max_len - 3 chars kept plus "...".
        assert_eq!(truncate("abcde", 4), "a...");
    }

    #[test]
    fn test_truncate_multibyte_boundaries() {
        // Multi-byte chars with max_len <= 3 take whole chars, never bytes.
        assert_eq!(truncate("日本語", 3), "日本語");
        assert_eq!(truncate("日本語テ", 3), "日本語");
        // 4-byte chars are cut on char boundaries.
        assert_eq!(truncate("👍👍👍👍👍", 4), "👍...");
        assert_eq!(truncate(&"é".repeat(50), 10), format!("{}...", "é".repeat(7)));
    }

    #[test]
    fn test_truncate_long_input() {
        let long = "x".repeat(10_000);
        let out = truncate(&long, 100);
        assert_eq!(out.chars().count(), 100);
        assert!(out.ends_with("..."));
    }
}