Skip to main content

mcp_memory/actions/
code.rs

1//! MCP tool handlers for the tree-sitter code knowledge graph.
2//!
3//! These map parsed code symbols (see [`crate::code`]) onto the existing
4//! entity/relation graph so the regular search/traversal tools work on code,
5//! and expose four code-focused tools: `code_index`, `code_outline`,
6//! `code_search`, `code_get_symbol`.
7//!
8//! Symbols are stored as entities named `relpath::symbol` with type
9//! `code:<kind>`; metadata (file, line range, signature, doc) lives in
10//! observations. Edges: `defines` (file→symbol), `calls`/`references`
11//! (caller→callee, resolved only when the callee name is unambiguous).
12
13use std::collections::{HashMap, HashSet};
14use std::path::Path;
15use std::sync::atomic::{AtomicUsize, Ordering};
16use std::time::{SystemTime, UNIX_EPOCH};
17
18use serde_json::{Value, json};
19
20use crate::code::{self, Def, MAX_SYMBOLS_PER_FILE};
21use crate::errors::{MCSError, Result};
22use crate::kg::GraphHandle;
23use crate::types::{Entity, Relation};
24
/// Cap on files processed in a single `code_index` call; the walked file list
/// is truncated to this length.
const MAX_INDEX_FILES: usize = 100_000;
/// Cap on total symbols accepted across all files in one `code_index` call
/// (prevents OOM on huge repos).
const MAX_TOTAL_SYMBOLS: usize = 5_000_000;
/// Flush threshold for graph writes (keeps each write transaction bounded).
/// Write buffers are checked once per file, so a single flush can exceed this
/// by up to one file's worth of rows.
const WRITE_BATCH: usize = 1_000;
/// Number of result rows `code_search` returns when no usable `limit` is given.
const DEFAULT_SEARCH_LIMIT: usize = 20;
/// Upper bound that a caller-supplied `code_search` `limit` is clamped to.
const MAX_SEARCH_LIMIT: usize = 500;
/// Cap on callers/callees returned by `code_get_symbol`.
const MAX_EDGES_RETURNED: usize = 500;
36
/// Wrap a text payload in the `{ "content": [{ "type": "text", "text": ... }] }`
/// JSON envelope returned by the tool handlers in this file.
// NOTE(review): presumably this is the MCP tool-result shape — confirm against
// the protocol layer that consumes it.
macro_rules! text_content {
    ($text:expr) => {
        json!({ "content": [{ "type": "text", "text": $text }] })
    };
}
42
43fn to_json(v: &impl serde::Serialize) -> Result<Value> {
44    let text = serde_json::to_string(v).map_err(MCSError::JsonError)?;
45    Ok(text_content!(text))
46}
47
/// Build the forward-slash path string used as the `code:file` entity name.
///
/// An absolute `p` that lies under `base` has the `base` prefix removed; any
/// other path (relative, or absolute but outside `base`) is used as given.
/// Backslashes are then rewritten to `/` so names are separator-independent.
fn rel_path(p: &Path, base: &Path) -> String {
    let shown = match p.strip_prefix(base) {
        // Only absolute inputs are ever shortened; relative ones pass through.
        Ok(stripped) if p.is_absolute() => stripped,
        _ => p,
    };
    shown.to_string_lossy().replace('\\', "/")
}
57
58/// Read a single-valued `key: value` observation off an entity.
59fn obs_val<'a>(entity: &'a Entity, key: &str) -> Option<&'a str> {
60    let prefix = format!("{key}: ");
61    entity
62        .observations
63        .iter()
64        .find_map(|o| o.strip_prefix(&prefix))
65}
66
67/// Strip the `code:` prefix from an entity type for display.
68fn kind_of(entity: &Entity) -> &str {
69    entity.entity_type.strip_prefix("code:").unwrap_or(&entity.entity_type)
70}
71
72fn is_code_entity(entity: &Entity) -> bool {
73    entity.entity_type.starts_with("code:")
74}
75
76/// Compact, location-focused view of a code symbol entity.
77fn symbol_row(entity: &Entity) -> Value {
78    json!({
79        "name": entity.name,
80        "kind": kind_of(entity),
81        "file": obs_val(entity, "file"),
82        "lines": obs_val(entity, "lines"),
83        "lang": obs_val(entity, "lang"),
84        "signature": obs_val(entity, "signature"),
85        "doc": obs_val(entity, "doc"),
86    })
87}
88
89// ---------------------------------------------------------------------------
90// code_index
91// ---------------------------------------------------------------------------
92
/// Parsed symbols for one file, with qualified names already assigned.
///
/// Produced by `parse_one` during the parse phase and consumed by the serial
/// merge/write phases of `code_index`.
struct FileWork {
    /// Forward-slash path used as the `code:file` entity name and as the
    /// prefix of every symbol's qualified name.
    rel: String,
    /// Language name, stored in the `lang` observation of the file and its symbols.
    lang: &'static str,
    /// Content hash of the file bytes, stored in the file entity's `hash`
    /// observation and compared on later runs to skip unchanged files.
    hash: String,
    /// Whether a `code:file` entity already existed (drives purge-skip).
    existed: bool,
    /// Each kept definition paired with its qualified entity name
    /// (`<rel>::<symbol>`, with a suffix added on duplicates).
    named: Vec<(Def, String)>,
    /// References found in the file; resolved into `calls`/`references` edges
    /// during the write phase.
    refs: Vec<code::Ref>,
}
103
/// Outcome of processing a single path during the parallel parse phase.
enum Outcome {
    /// The file was parsed; boxed to keep the enum itself small.
    Indexed(Box<FileWork>),
    /// The file was not indexed: either its stored hash matched (unchanged and
    /// `force` not set) or accepting it would exceed `MAX_TOTAL_SYMBOLS`.
    Skipped,
    /// The file could not be read.
    Failed,
    /// No supported language was detected for the path; not tallied in the
    /// `code_index` response.
    Unsupported,
}
111
112/// Read + hash + (incrementally) parse one file. CPU-bound and independent per
113/// file, so this runs on the parse thread pool. Reads use the graph's
114/// concurrent read pool; no writes happen here.
115fn parse_one(kg: &GraphHandle, path: &Path, base: &Path, force: bool,
116             total_symbols: &AtomicUsize) -> Outcome {
117    let Some(lang) = code::detect(path) else {
118        return Outcome::Unsupported;
119    };
120    let rel = rel_path(path, base);
121    let Ok(bytes) = std::fs::read(path) else {
122        return Outcome::Failed;
123    };
124    let hash = code::hash_bytes(&bytes);
125
126    let existing = kg.get_entity(&rel).ok().flatten();
127    let existed = existing.is_some();
128    // Incremental: skip unchanged files (matching stored hash).
129    if !force
130        && let Some(e) = &existing
131        && obs_val(e, "hash") == Some(hash.as_str())
132    {
133        return Outcome::Skipped;
134    }
135
136    let parsed = code::parse_source(lang, &bytes);
137    let mut seen: HashSet<String> = HashSet::new();
138    let mut named: Vec<(Def, String)> = Vec::with_capacity(parsed.defs.len());
139    for d in parsed.defs.into_iter().take(MAX_SYMBOLS_PER_FILE) {
140        let mut q = format!("{rel}::{}", d.name);
141        if !seen.insert(q.clone()) {
142            q = format!("{q}::L{}", d.line_start);
143            seen.insert(q.clone());
144        }
145        named.push((d, q));
146    }
147
148    // Accumulate towards the total symbol cap.
149    let prev = total_symbols.fetch_add(named.len(), Ordering::Relaxed);
150    if prev + named.len() > MAX_TOTAL_SYMBOLS {
151        // Undo — we overshot. Non-atomic for correctness: the caller's cap check
152        // stops new files from being accepted; any surplus is simply ignored in
153        // the merge phase below.
154        return Outcome::Skipped;
155    }
156
157    Outcome::Indexed(Box::new(FileWork {
158        rel,
159        lang: lang.name(),
160        hash,
161        existed,
162        named,
163        refs: parsed.refs,
164    }))
165}
166
167pub fn handle_code_index(kg: &GraphHandle, args: Option<&Value>) -> Result<Value> {
168    let params = args.ok_or_else(|| MCSError::InvalidParams("Missing parameters".into()))?;
169    let path = params
170        .get("path")
171        .and_then(|v| v.as_str())
172        .ok_or_else(|| MCSError::InvalidParams("Missing 'path' parameter".into()))?;
173    let force = params.get("force").and_then(|v| v.as_bool()).unwrap_or(false);
174
175    let root = Path::new(path);
176    if !root.exists() {
177        return Err(MCSError::InvalidParams(format!("Path not found: {path}")));
178    }
179    let base = std::env::current_dir().unwrap_or_else(|_| Path::new(".").to_path_buf());
180
181    let mut files = code::walk(root, code::MAX_FILE_BYTES);
182    files.truncate(MAX_INDEX_FILES);
183
184    let now = SystemTime::now()
185        .duration_since(UNIX_EPOCH)
186        .map(|d| d.as_secs())
187        .unwrap_or(0);
188
189    // Parse phase (parallel): read + hash + parse each file across the CPU
190    // cores. Files are independent and parsing is the dominant cost; reads use
191    // the concurrent read pool. The single-writer graph mutations stay serial
192    // in the merge phase below.
193    let n = files.len();
194    let n_threads = std::thread::available_parallelism()
195        .map(|t| t.get())
196        .unwrap_or(4)
197        .min(n.max(1));
198    let next = AtomicUsize::new(0);
199    let total_symbols = AtomicUsize::new(0);
200    let buckets: Vec<Vec<Outcome>> = std::thread::scope(|scope| {
201        let handles: Vec<_> = (0..n_threads)
202            .map(|_| {
203                scope.spawn(|| {
204                    let mut local = Vec::new();
205                    loop {
206                        let i = next.fetch_add(1, Ordering::Relaxed);
207                        if i >= n {
208                            break;
209                        }
210                        // Pre-check total symbol cap to avoid unnecessary work.
211                        if total_symbols.load(Ordering::Relaxed) >= MAX_TOTAL_SYMBOLS {
212                            continue;
213                        }
214                        local.push(parse_one(kg, &files[i], &base, force, &total_symbols));
215                    }
216                    local
217                })
218            })
219            .collect();
220        handles.into_iter().map(|h| h.join().unwrap()).collect()
221    });
222
223    // Merge phase (serial): tally outcomes and build the global symbol index
224    // (bare name -> qualified names) used to resolve unambiguous call edges.
225    let mut work: Vec<FileWork> = Vec::new();
226    let mut def_index: HashMap<String, Vec<String>> = HashMap::new();
227    let mut files_indexed = 0usize;
228    let mut files_skipped = 0usize;
229    let mut files_failed = 0usize;
230    for outcome in buckets.into_iter().flatten() {
231        match outcome {
232            Outcome::Indexed(fw) => {
233                for (d, q) in &fw.named {
234                    def_index.entry(d.name.clone()).or_default().push(q.clone());
235                }
236                work.push(*fw);
237                files_indexed += 1;
238            }
239            Outcome::Skipped => files_skipped += 1,
240            Outcome::Failed => files_failed += 1,
241            Outcome::Unsupported => {}
242        }
243    }
244
245    // Write phase (serial, single writer). Streamed in `WRITE_BATCH` chunks so
246    // the transient entity/relation buffers stay bounded regardless of repo
247    // size; the parsed `work` is the only large allocation. Entities are written
248    // in full *before* any relation, since relations resolve their endpoints by
249    // name and would silently drop against a not-yet-written entity.
250
251    // Pass 1: purge changed files and write all entities.
252    let mut ebuf: Vec<Entity> = Vec::with_capacity(WRITE_BATCH);
253    let mut symbols = 0usize;
254    for fw in &work {
255        if fw.existed {
256            kg.code_purge_file(&fw.rel)?;
257        }
258        ebuf.push(Entity {
259            name: fw.rel.clone(),
260            entity_type: "code:file".into(),
261            observations: vec![
262                format!("lang: {}", fw.lang),
263                format!("hash: {}", fw.hash),
264                format!("symbols: {}", fw.named.len()),
265                format!("indexed_at: {now}"),
266            ],
267        });
268        for (d, q) in &fw.named {
269            let mut obs = vec![
270                format!("kind: {}", d.kind),
271                format!("lang: {}", fw.lang),
272                format!("file: {}", fw.rel),
273                format!("lines: {}-{}", d.line_start, d.line_end),
274                format!("signature: {}", d.signature),
275            ];
276            if let Some(doc) = &d.doc {
277                obs.push(format!("doc: {doc}"));
278            }
279            ebuf.push(Entity {
280                name: q.clone(),
281                entity_type: format!("code:{}", d.kind),
282                observations: obs,
283            });
284            symbols += 1;
285        }
286        if ebuf.len() >= WRITE_BATCH {
287            kg.upsert_entities(&ebuf)?;
288            ebuf.clear();
289        }
290    }
291    if !ebuf.is_empty() {
292        kg.upsert_entities(&ebuf)?;
293    }
294
295    // Pass 2: write `defines` edges and unambiguously-resolved call edges.
296    let mut rbuf: Vec<Relation> = Vec::with_capacity(WRITE_BATCH);
297    let mut rel_seen: HashSet<(String, String, &'static str)> = HashSet::new();
298    let mut relation_count = 0usize;
299    for fw in &work {
300        for (_, q) in &fw.named {
301            rbuf.push(Relation {
302                from: fw.rel.clone(),
303                to: q.clone(),
304                relation_type: "defines".into(),
305            });
306            relation_count += 1;
307        }
308        for r in &fw.refs {
309            let Some(targets) = def_index.get(&r.name) else { continue };
310            if targets.len() != 1 {
311                continue; // ambiguous or unresolved — drop (no false edges)
312            }
313            let callee = &targets[0];
314            let caller = enclosing(&fw.named, r.line).unwrap_or(&fw.rel).to_string();
315            if &caller == callee {
316                continue;
317            }
318            let rtype: &'static str = if r.kind == "call" { "calls" } else { "references" };
319            if !rel_seen.insert((caller.clone(), callee.clone(), rtype)) {
320                continue;
321            }
322            rbuf.push(Relation {
323                from: caller,
324                to: callee.clone(),
325                relation_type: rtype.into(),
326            });
327            relation_count += 1;
328        }
329        if rbuf.len() >= WRITE_BATCH {
330            kg.create_relations(&rbuf)?;
331            rbuf.clear();
332        }
333    }
334    if !rbuf.is_empty() {
335        kg.create_relations(&rbuf)?;
336    }
337
338    to_json(&json!({
339        "files_indexed": files_indexed,
340        "files_skipped": files_skipped,
341        "files_failed": files_failed,
342        "symbols": symbols,
343        "relations": relation_count,
344    }))
345}
346
347/// Smallest-span definition whose line range encloses `line`, if any.
348fn enclosing(named: &[(Def, String)], line: usize) -> Option<&str> {
349    named
350        .iter()
351        .filter(|(d, _)| d.line_start <= line && line <= d.line_end)
352        .min_by_key(|(d, _)| d.line_end - d.line_start)
353        .map(|(_, q)| q.as_str())
354}
355
356// ---------------------------------------------------------------------------
357// code_outline
358// ---------------------------------------------------------------------------
359
360pub fn handle_code_outline(kg: &GraphHandle, args: Option<&Value>) -> Result<Value> {
361    let params = args.ok_or_else(|| MCSError::InvalidParams("Missing parameters".into()))?;
362    let file = params
363        .get("file")
364        .and_then(|v| v.as_str())
365        .ok_or_else(|| MCSError::InvalidParams("Missing 'file' parameter".into()))?;
366    let file = file.replace('\\', "/");
367
368    let defines = kg.search_relations(Some(&file), None, Some("defines"), Some(MAX_SYMBOLS_PER_FILE));
369    let names: Vec<String> = defines.into_iter().map(|r| r.to).collect();
370    if names.is_empty() {
371        return to_json(&json!({
372            "file": file,
373            "symbols": [],
374            "note": "no symbols indexed for this file; run code_index first",
375        }));
376    }
377    let mut rows: Vec<Value> = kg
378        .batch_get_entities(&names)
379        .into_iter()
380        .flatten()
381        .map(|e| symbol_row(&e))
382        .collect();
383    // Order by starting line for a readable outline.
384    rows.sort_by_key(|r| {
385        r.get("lines")
386            .and_then(|v| v.as_str())
387            .and_then(|s| s.split('-').next())
388            .and_then(|s| s.parse::<u64>().ok())
389            .unwrap_or(0)
390    });
391
392    to_json(&json!({ "file": file, "symbols": rows }))
393}
394
395// ---------------------------------------------------------------------------
396// code_search
397// ---------------------------------------------------------------------------
398
399pub fn handle_code_search(kg: &GraphHandle, args: Option<&Value>) -> Result<Value> {
400    let params = args.ok_or_else(|| MCSError::InvalidParams("Missing parameters".into()))?;
401    let query = params
402        .get("query")
403        .and_then(|v| v.as_str())
404        .ok_or_else(|| MCSError::InvalidParams("Missing 'query' parameter".into()))?;
405    let kind = params.get("kind").and_then(|v| v.as_str()).filter(|s| !s.is_empty());
406    let lang = params.get("lang").and_then(|v| v.as_str()).filter(|s| !s.is_empty());
407    let limit = params
408        .get("limit")
409        .and_then(|v| v.as_u64())
410        .map(|n| n as usize)
411        .unwrap_or(DEFAULT_SEARCH_LIMIT)
412        .clamp(1, MAX_SEARCH_LIMIT);
413
414    // Over-fetch then narrow to code entities (search has a single-type filter).
415    let raw = kg.search_nodes_filtered(query, None, 0, limit.saturating_mul(5).min(1000));
416    let rows: Vec<Value> = raw
417        .into_iter()
418        .filter(is_code_entity)
419        .filter(|e| e.entity_type != "code:file")
420        .filter(|e| kind.is_none_or(|k| kind_of(e) == k))
421        .filter(|e| lang.is_none_or(|l| obs_val(e, "lang") == Some(l)))
422        .take(limit)
423        .map(|e| symbol_row(&e))
424        .collect();
425
426    to_json(&json!({ "results": rows }))
427}
428
429// ---------------------------------------------------------------------------
430// code_get_symbol
431// ---------------------------------------------------------------------------
432
433pub fn handle_code_get_symbol(kg: &GraphHandle, args: Option<&Value>) -> Result<Value> {
434    let params = args.ok_or_else(|| MCSError::InvalidParams("Missing parameters".into()))?;
435    let name = params
436        .get("name")
437        .and_then(|v| v.as_str())
438        .ok_or_else(|| MCSError::InvalidParams("Missing 'name' parameter".into()))?;
439
440    // Resolve: exact qualified name first, else fuzzy by bare name.
441    let mut matches: Vec<Entity> = Vec::new();
442    if let Ok(Some(e)) = kg.get_entity(name)
443        && is_code_entity(&e)
444    {
445        matches.push(e);
446    }
447    if matches.is_empty() {
448        let suffix = format!("::{name}");
449        matches = kg
450            .search_nodes_filtered(name, None, 0, 200)
451            .into_iter()
452            .filter(is_code_entity)
453            .filter(|e| e.name == name || e.name.ends_with(&suffix))
454            .take(10)
455            .collect();
456    }
457    if matches.is_empty() {
458        return Err(MCSError::InvalidParams(format!(
459            "No code symbol matching '{name}' (run code_index first?)"
460        )));
461    }
462
463    let edge_types = ["calls", "references"];
464    let results: Vec<Value> = matches
465        .iter()
466        .map(|e| {
467            let mut callers: Vec<String> = Vec::new();
468            let mut callees: Vec<String> = Vec::new();
469            for t in edge_types {
470                for r in kg.search_relations(None, Some(&e.name), Some(t), Some(MAX_EDGES_RETURNED)) {
471                    callers.push(r.from);
472                }
473                for r in kg.search_relations(Some(&e.name), None, Some(t), Some(MAX_EDGES_RETURNED)) {
474                    callees.push(r.to);
475                }
476            }
477            callers.truncate(MAX_EDGES_RETURNED);
478            callees.truncate(MAX_EDGES_RETURNED);
479            let mut row = symbol_row(e);
480            row["callers"] = json!(callers);
481            row["callees"] = json!(callees);
482            row
483        })
484        .collect();
485
486    if results.len() == 1 {
487        to_json(&results.into_iter().next().unwrap())
488    } else {
489        to_json(&json!({ "matches": results }))
490    }
491}