Skip to main content

mcp_memory/actions/
code.rs

1//! MCP tool handlers for the tree-sitter code knowledge graph.
2//!
3//! These map parsed code symbols (see [`crate::code`]) onto the existing
4//! entity/relation graph so the regular search/traversal tools work on code,
5//! and expose four code-focused tools: `code_index`, `code_outline`,
6//! `code_search`, `code_get_symbol`.
7//!
8//! Symbols are stored as entities named `relpath::symbol` with type
9//! `code:<kind>`; metadata (file, line range, signature, doc) lives in
10//! observations. Edges: `defines` (file→symbol), `calls`/`references`
11//! (caller→callee, resolved only when the callee name is unambiguous).
12
13use std::collections::{HashMap, HashSet};
14use std::path::Path;
15use std::sync::atomic::{AtomicUsize, Ordering};
16use std::time::{SystemTime, UNIX_EPOCH};
17
18use serde_json::{Value, json};
19
20use crate::code::{self, Def};
21use crate::errors::{MCSError, Result};
22use crate::kg::GraphHandle;
23use crate::types::{Entity, Relation};
24
/// Upper bound on the number of files handled by one `code_index` call.
const MAX_INDEX_FILES: usize = 50_000;
/// Upper bound on symbols recorded for a single file (guards against
/// pathological or generated sources).
const MAX_SYMBOLS_PER_FILE: usize = 5_000;
/// Buffered graph writes are flushed once this many items have accumulated.
const WRITE_BATCH: usize = 1_000;
/// Rows returned by `code_search` when the caller supplies no `limit`.
const DEFAULT_SEARCH_LIMIT: usize = 20;
/// Largest `limit` honoured by `code_search`; bigger requests are clamped.
const MAX_SEARCH_LIMIT: usize = 200;
/// Upper bound on the callers / callees listed by `code_get_symbol`.
const MAX_EDGES_RETURNED: usize = 200;
36
/// Wrap a text payload in the `content` envelope returned by the handlers in
/// this file: `{ "content": [{ "type": "text", "text": <payload> }] }`.
macro_rules! text_content {
    ($payload:expr) => {
        json!({
            "content": [
                { "type": "text", "text": $payload }
            ]
        })
    };
}
42
43fn to_json(v: &impl serde::Serialize) -> Result<Value> {
44    let text = serde_json::to_string(v).map_err(MCSError::JsonError)?;
45    Ok(text_content!(text))
46}
47
/// Build the forward-slash path string used as a `code:file` entity name.
///
/// An absolute `p` that lies under `base` has the `base` prefix removed; any
/// other path (relative, or absolute but outside `base`) is kept as given.
/// Backslashes are rewritten to `/` in the resulting string.
fn rel_path(p: &Path, base: &Path) -> String {
    let trimmed = match p.strip_prefix(base) {
        // Only absolute inputs are re-rooted; relative ones pass through untouched.
        Ok(inner) if p.is_absolute() => inner,
        _ => p,
    };
    trimmed.to_string_lossy().replace('\\', "/")
}
57
58/// Read a single-valued `key: value` observation off an entity.
59fn obs_val<'a>(entity: &'a Entity, key: &str) -> Option<&'a str> {
60    let prefix = format!("{key}: ");
61    entity
62        .observations
63        .iter()
64        .find_map(|o| o.strip_prefix(&prefix))
65}
66
67/// Strip the `code:` prefix from an entity type for display.
68fn kind_of(entity: &Entity) -> &str {
69    entity.entity_type.strip_prefix("code:").unwrap_or(&entity.entity_type)
70}
71
72fn is_code_entity(entity: &Entity) -> bool {
73    entity.entity_type.starts_with("code:")
74}
75
76/// Compact, location-focused view of a code symbol entity.
77fn symbol_row(entity: &Entity) -> Value {
78    json!({
79        "name": entity.name,
80        "kind": kind_of(entity),
81        "file": obs_val(entity, "file"),
82        "lines": obs_val(entity, "lines"),
83        "lang": obs_val(entity, "lang"),
84        "signature": obs_val(entity, "signature"),
85        "doc": obs_val(entity, "doc"),
86    })
87}
88
89// ---------------------------------------------------------------------------
90// code_index
91// ---------------------------------------------------------------------------
92
93/// Parsed symbols for one file, with qualified names already assigned.
94struct FileWork {
95    rel: String,
96    lang: &'static str,
97    hash: String,
98    /// Whether a `code:file` entity already existed (drives purge-skip).
99    existed: bool,
100    named: Vec<(Def, String)>,
101    refs: Vec<code::Ref>,
102}
103
104/// Outcome of processing a single path during the parallel parse phase.
105enum Outcome {
106    Indexed(Box<FileWork>),
107    Skipped,
108    Failed,
109    Unsupported,
110}
111
112/// Read + hash + (incrementally) parse one file. CPU-bound and independent per
113/// file, so this runs on the parse thread pool. Reads use the graph's
114/// concurrent read pool; no writes happen here.
115fn parse_one(kg: &GraphHandle, path: &Path, base: &Path, force: bool) -> Outcome {
116    let Some(lang) = code::detect(path) else {
117        return Outcome::Unsupported;
118    };
119    let rel = rel_path(path, base);
120    let Ok(bytes) = std::fs::read(path) else {
121        return Outcome::Failed;
122    };
123    let hash = code::hash_bytes(&bytes);
124
125    let existing = kg.get_entity(&rel).ok().flatten();
126    let existed = existing.is_some();
127    // Incremental: skip unchanged files (matching stored hash).
128    if !force
129        && let Some(e) = &existing
130        && obs_val(e, "hash") == Some(hash.as_str())
131    {
132        return Outcome::Skipped;
133    }
134
135    let parsed = code::parse_source(lang, &bytes);
136    let mut seen: HashSet<String> = HashSet::new();
137    let mut named: Vec<(Def, String)> = Vec::with_capacity(parsed.defs.len());
138    for d in parsed.defs.into_iter().take(MAX_SYMBOLS_PER_FILE) {
139        let mut q = format!("{rel}::{}", d.name);
140        if !seen.insert(q.clone()) {
141            q = format!("{q}::L{}", d.line_start);
142            seen.insert(q.clone());
143        }
144        named.push((d, q));
145    }
146
147    Outcome::Indexed(Box::new(FileWork {
148        rel,
149        lang: lang.name(),
150        hash,
151        existed,
152        named,
153        refs: parsed.refs,
154    }))
155}
156
157pub fn handle_code_index(kg: &GraphHandle, args: Option<&Value>) -> Result<Value> {
158    let params = args.ok_or_else(|| MCSError::InvalidParams("Missing parameters".into()))?;
159    let path = params
160        .get("path")
161        .and_then(|v| v.as_str())
162        .ok_or_else(|| MCSError::InvalidParams("Missing 'path' parameter".into()))?;
163    let force = params.get("force").and_then(|v| v.as_bool()).unwrap_or(false);
164
165    let root = Path::new(path);
166    if !root.exists() {
167        return Err(MCSError::InvalidParams(format!("Path not found: {path}")));
168    }
169    let base = std::env::current_dir().unwrap_or_else(|_| Path::new(".").to_path_buf());
170
171    let mut files = code::walk(root, code::MAX_FILE_BYTES);
172    files.truncate(MAX_INDEX_FILES);
173
174    let now = SystemTime::now()
175        .duration_since(UNIX_EPOCH)
176        .map(|d| d.as_secs())
177        .unwrap_or(0);
178
179    // Parse phase (parallel): read + hash + parse each file across the CPU
180    // cores. Files are independent and parsing is the dominant cost; reads use
181    // the concurrent read pool. The single-writer graph mutations stay serial
182    // in the merge phase below.
183    let n = files.len();
184    let n_threads = std::thread::available_parallelism()
185        .map(|t| t.get())
186        .unwrap_or(4)
187        .min(n.max(1));
188    let next = AtomicUsize::new(0);
189    let buckets: Vec<Vec<Outcome>> = std::thread::scope(|scope| {
190        let handles: Vec<_> = (0..n_threads)
191            .map(|_| {
192                scope.spawn(|| {
193                    let mut local = Vec::new();
194                    loop {
195                        let i = next.fetch_add(1, Ordering::Relaxed);
196                        if i >= n {
197                            break;
198                        }
199                        local.push(parse_one(kg, &files[i], &base, force));
200                    }
201                    local
202                })
203            })
204            .collect();
205        handles.into_iter().map(|h| h.join().unwrap()).collect()
206    });
207
208    // Merge phase (serial): tally outcomes and build the global symbol index
209    // (bare name -> qualified names) used to resolve unambiguous call edges.
210    let mut work: Vec<FileWork> = Vec::new();
211    let mut def_index: HashMap<String, Vec<String>> = HashMap::new();
212    let mut files_indexed = 0usize;
213    let mut files_skipped = 0usize;
214    let mut files_failed = 0usize;
215    for outcome in buckets.into_iter().flatten() {
216        match outcome {
217            Outcome::Indexed(fw) => {
218                for (d, q) in &fw.named {
219                    def_index.entry(d.name.clone()).or_default().push(q.clone());
220                }
221                work.push(*fw);
222                files_indexed += 1;
223            }
224            Outcome::Skipped => files_skipped += 1,
225            Outcome::Failed => files_failed += 1,
226            Outcome::Unsupported => {}
227        }
228    }
229
230    // Write phase (serial, single writer). Streamed in `WRITE_BATCH` chunks so
231    // the transient entity/relation buffers stay bounded regardless of repo
232    // size; the parsed `work` is the only large allocation. Entities are written
233    // in full *before* any relation, since relations resolve their endpoints by
234    // name and would silently drop against a not-yet-written entity.
235
236    // Pass 1: purge changed files and write all entities.
237    let mut ebuf: Vec<Entity> = Vec::with_capacity(WRITE_BATCH);
238    let mut symbols = 0usize;
239    for fw in &work {
240        if fw.existed {
241            kg.code_purge_file(&fw.rel)?;
242        }
243        ebuf.push(Entity {
244            name: fw.rel.clone(),
245            entity_type: "code:file".into(),
246            observations: vec![
247                format!("lang: {}", fw.lang),
248                format!("hash: {}", fw.hash),
249                format!("symbols: {}", fw.named.len()),
250                format!("indexed_at: {now}"),
251            ],
252        });
253        for (d, q) in &fw.named {
254            let mut obs = vec![
255                format!("kind: {}", d.kind),
256                format!("lang: {}", fw.lang),
257                format!("file: {}", fw.rel),
258                format!("lines: {}-{}", d.line_start, d.line_end),
259                format!("signature: {}", d.signature),
260            ];
261            if let Some(doc) = &d.doc {
262                obs.push(format!("doc: {doc}"));
263            }
264            ebuf.push(Entity {
265                name: q.clone(),
266                entity_type: format!("code:{}", d.kind),
267                observations: obs,
268            });
269            symbols += 1;
270        }
271        if ebuf.len() >= WRITE_BATCH {
272            kg.upsert_entities(&ebuf)?;
273            ebuf.clear();
274        }
275    }
276    if !ebuf.is_empty() {
277        kg.upsert_entities(&ebuf)?;
278    }
279
280    // Pass 2: write `defines` edges and unambiguously-resolved call edges.
281    let mut rbuf: Vec<Relation> = Vec::with_capacity(WRITE_BATCH);
282    let mut rel_seen: HashSet<(String, String, &'static str)> = HashSet::new();
283    let mut relation_count = 0usize;
284    for fw in &work {
285        for (_, q) in &fw.named {
286            rbuf.push(Relation {
287                from: fw.rel.clone(),
288                to: q.clone(),
289                relation_type: "defines".into(),
290            });
291            relation_count += 1;
292        }
293        for r in &fw.refs {
294            let Some(targets) = def_index.get(&r.name) else { continue };
295            if targets.len() != 1 {
296                continue; // ambiguous or unresolved — drop (no false edges)
297            }
298            let callee = &targets[0];
299            let caller = enclosing(&fw.named, r.line).unwrap_or(&fw.rel).to_string();
300            if &caller == callee {
301                continue;
302            }
303            let rtype: &'static str = if r.kind == "call" { "calls" } else { "references" };
304            if !rel_seen.insert((caller.clone(), callee.clone(), rtype)) {
305                continue;
306            }
307            rbuf.push(Relation {
308                from: caller,
309                to: callee.clone(),
310                relation_type: rtype.into(),
311            });
312            relation_count += 1;
313        }
314        if rbuf.len() >= WRITE_BATCH {
315            kg.create_relations(&rbuf)?;
316            rbuf.clear();
317        }
318    }
319    if !rbuf.is_empty() {
320        kg.create_relations(&rbuf)?;
321    }
322
323    to_json(&json!({
324        "files_indexed": files_indexed,
325        "files_skipped": files_skipped,
326        "files_failed": files_failed,
327        "symbols": symbols,
328        "relations": relation_count,
329    }))
330}
331
332/// Smallest-span definition whose line range encloses `line`, if any.
333fn enclosing(named: &[(Def, String)], line: usize) -> Option<&str> {
334    named
335        .iter()
336        .filter(|(d, _)| d.line_start <= line && line <= d.line_end)
337        .min_by_key(|(d, _)| d.line_end - d.line_start)
338        .map(|(_, q)| q.as_str())
339}
340
341// ---------------------------------------------------------------------------
342// code_outline
343// ---------------------------------------------------------------------------
344
345pub fn handle_code_outline(kg: &GraphHandle, args: Option<&Value>) -> Result<Value> {
346    let params = args.ok_or_else(|| MCSError::InvalidParams("Missing parameters".into()))?;
347    let file = params
348        .get("file")
349        .and_then(|v| v.as_str())
350        .ok_or_else(|| MCSError::InvalidParams("Missing 'file' parameter".into()))?;
351    let file = file.replace('\\', "/");
352
353    let defines = kg.search_relations(Some(&file), None, Some("defines"));
354    let names: Vec<String> = defines.into_iter().map(|r| r.to).collect();
355    if names.is_empty() {
356        return to_json(&json!({
357            "file": file,
358            "symbols": [],
359            "note": "no symbols indexed for this file; run code_index first",
360        }));
361    }
362    let mut rows: Vec<Value> = kg
363        .batch_get_entities(&names)
364        .into_iter()
365        .flatten()
366        .map(|e| symbol_row(&e))
367        .collect();
368    // Order by starting line for a readable outline.
369    rows.sort_by_key(|r| {
370        r.get("lines")
371            .and_then(|v| v.as_str())
372            .and_then(|s| s.split('-').next())
373            .and_then(|s| s.parse::<u64>().ok())
374            .unwrap_or(0)
375    });
376
377    to_json(&json!({ "file": file, "symbols": rows }))
378}
379
380// ---------------------------------------------------------------------------
381// code_search
382// ---------------------------------------------------------------------------
383
384pub fn handle_code_search(kg: &GraphHandle, args: Option<&Value>) -> Result<Value> {
385    let params = args.ok_or_else(|| MCSError::InvalidParams("Missing parameters".into()))?;
386    let query = params
387        .get("query")
388        .and_then(|v| v.as_str())
389        .ok_or_else(|| MCSError::InvalidParams("Missing 'query' parameter".into()))?;
390    let kind = params.get("kind").and_then(|v| v.as_str()).filter(|s| !s.is_empty());
391    let lang = params.get("lang").and_then(|v| v.as_str()).filter(|s| !s.is_empty());
392    let limit = params
393        .get("limit")
394        .and_then(|v| v.as_u64())
395        .map(|n| n as usize)
396        .unwrap_or(DEFAULT_SEARCH_LIMIT)
397        .clamp(1, MAX_SEARCH_LIMIT);
398
399    // Over-fetch then narrow to code entities (search has a single-type filter).
400    let raw = kg.search_nodes_filtered(query, None, 0, limit.saturating_mul(5).min(1000));
401    let rows: Vec<Value> = raw
402        .into_iter()
403        .filter(is_code_entity)
404        .filter(|e| e.entity_type != "code:file")
405        .filter(|e| kind.is_none_or(|k| kind_of(e) == k))
406        .filter(|e| lang.is_none_or(|l| obs_val(e, "lang") == Some(l)))
407        .take(limit)
408        .map(|e| symbol_row(&e))
409        .collect();
410
411    to_json(&json!({ "results": rows }))
412}
413
414// ---------------------------------------------------------------------------
415// code_get_symbol
416// ---------------------------------------------------------------------------
417
418pub fn handle_code_get_symbol(kg: &GraphHandle, args: Option<&Value>) -> Result<Value> {
419    let params = args.ok_or_else(|| MCSError::InvalidParams("Missing parameters".into()))?;
420    let name = params
421        .get("name")
422        .and_then(|v| v.as_str())
423        .ok_or_else(|| MCSError::InvalidParams("Missing 'name' parameter".into()))?;
424
425    // Resolve: exact qualified name first, else fuzzy by bare name.
426    let mut matches: Vec<Entity> = Vec::new();
427    if let Ok(Some(e)) = kg.get_entity(name)
428        && is_code_entity(&e)
429    {
430        matches.push(e);
431    }
432    if matches.is_empty() {
433        let suffix = format!("::{name}");
434        matches = kg
435            .search_nodes_filtered(name, None, 0, 200)
436            .into_iter()
437            .filter(is_code_entity)
438            .filter(|e| e.name == name || e.name.ends_with(&suffix))
439            .take(10)
440            .collect();
441    }
442    if matches.is_empty() {
443        return Err(MCSError::InvalidParams(format!(
444            "No code symbol matching '{name}' (run code_index first?)"
445        )));
446    }
447
448    let edge_types = ["calls", "references"];
449    let results: Vec<Value> = matches
450        .iter()
451        .map(|e| {
452            let mut callers: Vec<String> = Vec::new();
453            let mut callees: Vec<String> = Vec::new();
454            for t in edge_types {
455                for r in kg.search_relations(None, Some(&e.name), Some(t)) {
456                    callers.push(r.from);
457                }
458                for r in kg.search_relations(Some(&e.name), None, Some(t)) {
459                    callees.push(r.to);
460                }
461            }
462            callers.truncate(MAX_EDGES_RETURNED);
463            callees.truncate(MAX_EDGES_RETURNED);
464            let mut row = symbol_row(e);
465            row["callers"] = json!(callers);
466            row["callees"] = json!(callees);
467            row
468        })
469        .collect();
470
471    if results.len() == 1 {
472        to_json(&results.into_iter().next().unwrap())
473    } else {
474        to_json(&json!({ "matches": results }))
475    }
476}