Skip to main content

cersei_tools/tool_primitives/
code_intel.rs

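//! Code intelligence via tree-sitter: extract imports and top-level symbols from source files,
//! and rank a project's files by heuristic importance (entry-point names, path hints, import
//! relationships, symbol counts).
//!
//! Supports: Rust, TypeScript/JavaScript, Python, Go.
//! Used to decide which files to read first when analyzing a codebase.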
5
6use std::collections::{HashMap, HashSet};
7use std::path::{Path, PathBuf};
8use tree_sitter::{Parser, Query, QueryCursor};
9
10/// A file's extracted metadata.
11#[derive(Debug, Clone, Default)]
12pub struct FileIntel {
13    pub path: PathBuf,
14    pub language: Language,
15    pub imports: Vec<String>,
16    pub symbols: Vec<Symbol>,
17}
18
19/// A symbol extracted from source code.
20#[derive(Debug, Clone)]
21pub struct Symbol {
22    pub name: String,
23    pub kind: SymbolKind,
24    pub line: usize,
25}
26
/// Category of an extracted symbol.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SymbolKind {
    Function,
    Struct,
    Class,
    Interface,
    Enum,
    Module,
    Type,
    Constant,
}

impl SymbolKind {
    /// Short keyword-style label used when rendering a symbol (e.g. `"fn"`, `"struct"`).
    pub fn label(&self) -> &'static str {
        let text = match *self {
            SymbolKind::Function => "fn",
            SymbolKind::Struct => "struct",
            SymbolKind::Class => "class",
            SymbolKind::Interface => "interface",
            SymbolKind::Enum => "enum",
            SymbolKind::Module => "mod",
            SymbolKind::Type => "type",
            SymbolKind::Constant => "const",
        };
        text
    }
}
53
/// Source language of an analyzed file.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum Language {
    Rust,
    TypeScript,
    JavaScript,
    Python,
    Go,
    #[default]
    Unknown,
}

impl Language {
    /// Map a file extension (without the leading dot; case-sensitive) to a [`Language`].
    ///
    /// Unrecognized extensions map to [`Language::Unknown`].
    pub fn from_extension(ext: &str) -> Self {
        match ext {
            "rs" => Self::Rust,
            // `.mts` / `.cts` are TypeScript's ES-module / CommonJS file variants.
            "ts" | "tsx" | "mts" | "cts" => Self::TypeScript,
            "js" | "jsx" | "mjs" | "cjs" => Self::JavaScript,
            "py" | "pyi" => Self::Python,
            "go" => Self::Go,
            _ => Self::Unknown,
        }
    }
}
77
78/// Extract imports and symbols from a source file.
79pub fn analyze_file(path: &Path, source: &str) -> Option<FileIntel> {
80    let ext = path.extension()?.to_str()?;
81    let lang = Language::from_extension(ext);
82    if lang == Language::Unknown {
83        return None;
84    }
85
86    let mut parser = Parser::new();
87    let ts_lang = match lang {
88        Language::Rust => tree_sitter_rust::LANGUAGE.into(),
89        Language::TypeScript | Language::JavaScript => {
90            tree_sitter_typescript::LANGUAGE_TYPESCRIPT.into()
91        }
92        Language::Python => tree_sitter_python::LANGUAGE.into(),
93        Language::Go => tree_sitter_go::LANGUAGE.into(),
94        Language::Unknown => return None,
95    };
96    parser.set_language(&ts_lang).ok()?;
97    let tree = parser.parse(source, None)?;
98    let root = tree.root_node();
99    let bytes = source.as_bytes();
100
101    let mut imports = Vec::new();
102    let mut symbols = Vec::new();
103
104    // Walk AST and extract imports + symbols
105    let mut stack = vec![root];
106    while let Some(node) = stack.pop() {
107        let kind = node.kind();
108
109        match lang {
110            Language::Rust => match kind {
111                "use_declaration" => {
112                    if let Ok(text) = node.utf8_text(bytes) {
113                        imports.push(text.trim().to_string());
114                    }
115                }
116                "function_item" => {
117                    if let Some(name) = node.child_by_field_name("name") {
118                        if let Ok(n) = name.utf8_text(bytes) {
119                            symbols.push(Symbol {
120                                name: n.to_string(),
121                                kind: SymbolKind::Function,
122                                line: node.start_position().row + 1,
123                            });
124                        }
125                    }
126                }
127                "struct_item" => {
128                    if let Some(name) = node.child_by_field_name("name") {
129                        if let Ok(n) = name.utf8_text(bytes) {
130                            symbols.push(Symbol {
131                                name: n.to_string(),
132                                kind: SymbolKind::Struct,
133                                line: node.start_position().row + 1,
134                            });
135                        }
136                    }
137                }
138                "enum_item" => {
139                    if let Some(name) = node.child_by_field_name("name") {
140                        if let Ok(n) = name.utf8_text(bytes) {
141                            symbols.push(Symbol {
142                                name: n.to_string(),
143                                kind: SymbolKind::Enum,
144                                line: node.start_position().row + 1,
145                            });
146                        }
147                    }
148                }
149                "mod_item" => {
150                    if let Some(name) = node.child_by_field_name("name") {
151                        if let Ok(n) = name.utf8_text(bytes) {
152                            symbols.push(Symbol {
153                                name: n.to_string(),
154                                kind: SymbolKind::Module,
155                                line: node.start_position().row + 1,
156                            });
157                        }
158                    }
159                }
160                "trait_item" => {
161                    if let Some(name) = node.child_by_field_name("name") {
162                        if let Ok(n) = name.utf8_text(bytes) {
163                            symbols.push(Symbol {
164                                name: n.to_string(),
165                                kind: SymbolKind::Interface,
166                                line: node.start_position().row + 1,
167                            });
168                        }
169                    }
170                }
171                "type_item" => {
172                    if let Some(name) = node.child_by_field_name("name") {
173                        if let Ok(n) = name.utf8_text(bytes) {
174                            symbols.push(Symbol {
175                                name: n.to_string(),
176                                kind: SymbolKind::Type,
177                                line: node.start_position().row + 1,
178                            });
179                        }
180                    }
181                }
182                _ => {}
183            },
184            Language::TypeScript | Language::JavaScript => match kind {
185                "import_statement" => {
186                    if let Some(source_node) = node.child_by_field_name("source") {
187                        if let Ok(text) = source_node.utf8_text(bytes) {
188                            imports.push(text.trim_matches(|c| c == '"' || c == '\'').to_string());
189                        }
190                    }
191                }
192                "function_declaration" => {
193                    if let Some(name) = node.child_by_field_name("name") {
194                        if let Ok(n) = name.utf8_text(bytes) {
195                            symbols.push(Symbol {
196                                name: n.to_string(),
197                                kind: SymbolKind::Function,
198                                line: node.start_position().row + 1,
199                            });
200                        }
201                    }
202                }
203                "class_declaration" => {
204                    if let Some(name) = node.child_by_field_name("name") {
205                        if let Ok(n) = name.utf8_text(bytes) {
206                            symbols.push(Symbol {
207                                name: n.to_string(),
208                                kind: SymbolKind::Class,
209                                line: node.start_position().row + 1,
210                            });
211                        }
212                    }
213                }
214                "interface_declaration" => {
215                    if let Some(name) = node.child_by_field_name("name") {
216                        if let Ok(n) = name.utf8_text(bytes) {
217                            symbols.push(Symbol {
218                                name: n.to_string(),
219                                kind: SymbolKind::Interface,
220                                line: node.start_position().row + 1,
221                            });
222                        }
223                    }
224                }
225                "type_alias_declaration" => {
226                    if let Some(name) = node.child_by_field_name("name") {
227                        if let Ok(n) = name.utf8_text(bytes) {
228                            symbols.push(Symbol {
229                                name: n.to_string(),
230                                kind: SymbolKind::Type,
231                                line: node.start_position().row + 1,
232                            });
233                        }
234                    }
235                }
236                "enum_declaration" => {
237                    if let Some(name) = node.child_by_field_name("name") {
238                        if let Ok(n) = name.utf8_text(bytes) {
239                            symbols.push(Symbol {
240                                name: n.to_string(),
241                                kind: SymbolKind::Enum,
242                                line: node.start_position().row + 1,
243                            });
244                        }
245                    }
246                }
247                "export_statement" => {
248                    // Also extract exported declarations
249                    if let Some(decl) = node.child_by_field_name("declaration") {
250                        stack.push(decl);
251                    }
252                }
253                _ => {}
254            },
255            Language::Python => match kind {
256                "import_statement" | "import_from_statement" => {
257                    if let Ok(text) = node.utf8_text(bytes) {
258                        imports.push(text.trim().to_string());
259                    }
260                }
261                "function_definition" => {
262                    if let Some(name) = node.child_by_field_name("name") {
263                        if let Ok(n) = name.utf8_text(bytes) {
264                            symbols.push(Symbol {
265                                name: n.to_string(),
266                                kind: SymbolKind::Function,
267                                line: node.start_position().row + 1,
268                            });
269                        }
270                    }
271                }
272                "class_definition" => {
273                    if let Some(name) = node.child_by_field_name("name") {
274                        if let Ok(n) = name.utf8_text(bytes) {
275                            symbols.push(Symbol {
276                                name: n.to_string(),
277                                kind: SymbolKind::Class,
278                                line: node.start_position().row + 1,
279                            });
280                        }
281                    }
282                }
283                _ => {}
284            },
285            Language::Go => match kind {
286                "import_declaration" => {
287                    if let Ok(text) = node.utf8_text(bytes) {
288                        imports.push(text.trim().to_string());
289                    }
290                }
291                "function_declaration" => {
292                    if let Some(name) = node.child_by_field_name("name") {
293                        if let Ok(n) = name.utf8_text(bytes) {
294                            symbols.push(Symbol {
295                                name: n.to_string(),
296                                kind: SymbolKind::Function,
297                                line: node.start_position().row + 1,
298                            });
299                        }
300                    }
301                }
302                "method_declaration" => {
303                    if let Some(name) = node.child_by_field_name("name") {
304                        if let Ok(n) = name.utf8_text(bytes) {
305                            symbols.push(Symbol {
306                                name: n.to_string(),
307                                kind: SymbolKind::Function,
308                                line: node.start_position().row + 1,
309                            });
310                        }
311                    }
312                }
313                "type_declaration" => {
314                    // Type declarations contain type_spec children
315                }
316                "type_spec" => {
317                    if let Some(name) = node.child_by_field_name("name") {
318                        if let Ok(n) = name.utf8_text(bytes) {
319                            let sk = if node
320                                .child_by_field_name("type")
321                                .map(|t| t.kind() == "struct_type")
322                                .unwrap_or(false)
323                            {
324                                SymbolKind::Struct
325                            } else if node
326                                .child_by_field_name("type")
327                                .map(|t| t.kind() == "interface_type")
328                                .unwrap_or(false)
329                            {
330                                SymbolKind::Interface
331                            } else {
332                                SymbolKind::Type
333                            };
334                            symbols.push(Symbol {
335                                name: n.to_string(),
336                                kind: sk,
337                                line: node.start_position().row + 1,
338                            });
339                        }
340                    }
341                }
342                _ => {}
343            },
344            Language::Unknown => {}
345        }
346
347        // Push children for traversal (only top-level for performance)
348        if node.child_count() > 0 && is_container_node(kind) {
349            for i in 0..node.child_count() {
350                if let Some(child) = node.child(i) {
351                    stack.push(child);
352                }
353            }
354        }
355    }
356
357    Some(FileIntel {
358        path: path.to_path_buf(),
359        language: lang,
360        imports,
361        symbols,
362    })
363}
364
/// Whether the AST walk should descend into a node of this kind.
///
/// Only structural wrappers are listed: file roots, declaration/statement lists, export
/// wrappers, Go `type` groups, and Rust `impl` blocks, whose bodies hold methods. Nodes such as
/// function definitions are absent, so their bodies are never reached through them.
fn is_container_node(kind: &str) -> bool {
    const CONTAINER_KINDS: &[&str] = &[
        "source_file",
        "program",
        "module",
        "declaration_list",
        "block",
        "statement_block",
        "export_statement",
        "type_declaration",
        "impl_item",
    ];
    CONTAINER_KINDS.iter().any(|candidate| *candidate == kind)
}
380
381/// Scan a project directory and build a dependency-ordered list of important files.
382/// Returns files sorted by importance: entry points first, then most-imported files.
383pub fn scan_project(root: &Path, max_files: usize) -> Vec<FileIntel> {
384    let files = discover_source_files(root, 200);
385    if files.is_empty() {
386        return vec![];
387    }
388
389    let mut intels: Vec<FileIntel> = Vec::new();
390    let mut import_counts: HashMap<String, usize> = HashMap::new();
391
392    for file_path in &files {
393        if let Ok(source) = std::fs::read_to_string(file_path) {
394            // Limit parsing to first 500 lines for performance
395            let truncated: String = source.lines().take(500).collect::<Vec<_>>().join("\n");
396            if let Some(intel) = analyze_file(file_path, &truncated) {
397                // Count how often each file is imported
398                for imp in &intel.imports {
399                    *import_counts.entry(imp.clone()).or_insert(0) += 1;
400                }
401                intels.push(intel);
402            }
403        }
404    }
405
406    // Score files by importance
407    let mut scored: Vec<(usize, &FileIntel)> = intels
408        .iter()
409        .map(|intel| {
410            let mut score = 0usize;
411            let path_str = intel.path.display().to_string();
412
413            // Entry points get highest score
414            let filename = intel
415                .path
416                .file_name()
417                .and_then(|f| f.to_str())
418                .unwrap_or("");
419            if matches!(
420                filename,
421                "main.rs"
422                    | "lib.rs"
423                    | "mod.rs"
424                    | "index.ts"
425                    | "index.tsx"
426                    | "App.tsx"
427                    | "App.ts"
428                    | "main.ts"
429                    | "main.tsx"
430                    | "main.py"
431                    | "__init__.py"
432                    | "main.go"
433                    | "app.go"
434            ) {
435                score += 100;
436            }
437
438            // Config files
439            if matches!(
440                filename,
441                "package.json"
442                    | "Cargo.toml"
443                    | "tsconfig.json"
444                    | "pyproject.toml"
445                    | "go.mod"
446                    | "vite.config.ts"
447            ) {
448                score += 80;
449            }
450
451            // Store/state files (key architectural files)
452            if path_str.contains("store")
453                || path_str.contains("state")
454                || path_str.contains("context")
455                || path_str.contains("reducer")
456            {
457                score += 60;
458            }
459
460            // Type definition files
461            if path_str.contains("types")
462                || path_str.contains("interfaces")
463                || filename.ends_with(".d.ts")
464            {
465                score += 40;
466            }
467
468            // Files that are imported by many others
469            for imp in &intel.imports {
470                if let Some(count) = import_counts.get(imp) {
471                    score += count * 5;
472                }
473            }
474
475            // Files with many symbols are more important
476            score += intel.symbols.len() * 3;
477
478            score
479        })
480        .enumerate()
481        .map(|(i, score)| (score, &intels[i]))
482        .collect();
483
484    scored.sort_by(|a, b| b.0.cmp(&a.0));
485
486    scored
487        .into_iter()
488        .take(max_files)
489        .map(|(_, intel)| intel.clone())
490        .collect()
491}
492
493/// Discover source files in a project (respects .gitignore via git ls-files).
494fn discover_source_files(root: &Path, max: usize) -> Vec<PathBuf> {
495    use std::process::Command;
496
497    // Try git ls-files first
498    let output = Command::new("git")
499        .args(["ls-files", "--cached", "--others", "--exclude-standard"])
500        .current_dir(root)
501        .output()
502        .ok();
503
504    let files: Vec<PathBuf> = if let Some(out) = output {
505        if out.status.success() {
506            String::from_utf8_lossy(&out.stdout)
507                .lines()
508                .filter(|l| {
509                    let ext = l.rsplit('.').next().unwrap_or("");
510                    matches!(
511                        ext,
512                        "rs" | "ts" | "tsx" | "js" | "jsx" | "py" | "go" | "mjs" | "cjs" | "mts"
513                    )
514                })
515                .take(max)
516                .map(|l| root.join(l))
517                .collect()
518        } else {
519            vec![]
520        }
521    } else {
522        vec![]
523    };
524
525    if files.is_empty() {
526        // Fallback: walkdir
527        walkdir_source_files(root, max)
528    } else {
529        files
530    }
531}
532
533fn walkdir_source_files(root: &Path, max: usize) -> Vec<PathBuf> {
534    let excluded = [
535        "node_modules",
536        "target",
537        ".git",
538        "__pycache__",
539        "venv",
540        ".venv",
541        "dist",
542        "build",
543    ];
544    let mut files = Vec::new();
545
546    fn walk(dir: &Path, excluded: &[&str], files: &mut Vec<PathBuf>, max: usize) {
547        if files.len() >= max {
548            return;
549        }
550        let entries = match std::fs::read_dir(dir) {
551            Ok(e) => e,
552            Err(_) => return,
553        };
554        for entry in entries.flatten() {
555            if files.len() >= max {
556                return;
557            }
558            let name = entry.file_name().to_string_lossy().to_string();
559            if name.starts_with('.') || excluded.contains(&name.as_str()) {
560                continue;
561            }
562            let path = entry.path();
563            if path.is_file() {
564                let ext = path.extension().and_then(|e| e.to_str()).unwrap_or("");
565                if matches!(ext, "rs" | "ts" | "tsx" | "js" | "jsx" | "py" | "go") {
566                    files.push(path);
567                }
568            } else if path.is_dir() {
569                walk(&path, excluded, files, max);
570            }
571        }
572    }
573
574    walk(root, &excluded, &mut files, max);
575    files
576}
577
578/// Format a project scan as a concise summary for injection into the system prompt.
579pub fn format_project_intel(intels: &[FileIntel]) -> String {
580    let mut out = String::new();
581
582    for intel in intels {
583        let rel_path = intel
584            .path
585            .file_name()
586            .and_then(|f| f.to_str())
587            .unwrap_or("?");
588
589        // Format: path (lang) — symbols: fn foo, struct Bar; imports: ...
590        let symbols_str: Vec<String> = intel
591            .symbols
592            .iter()
593            .take(8)
594            .map(|s| format!("{} {}", s.kind.label(), s.name))
595            .collect();
596
597        let imports_str: Vec<String> = intel.imports.iter().take(5).cloned().collect();
598
599        out.push_str(&format!("• {} — ", intel.path.display()));
600        if !symbols_str.is_empty() {
601            out.push_str(&symbols_str.join(", "));
602        }
603        if !imports_str.is_empty() {
604            if !symbols_str.is_empty() {
605                out.push_str(" | imports: ");
606            }
607            out.push_str(&imports_str.join(", "));
608        }
609        out.push('\n');
610    }
611
612    out
613}
614
#[cfg(test)]
mod tests {
    use super::*;

    /// Whether `intel` reports a symbol called `name` with the given kind.
    fn has_symbol(intel: &FileIntel, name: &str, kind: SymbolKind) -> bool {
        intel
            .symbols
            .iter()
            .any(|symbol| symbol.name == name && symbol.kind == kind)
    }

    #[test]
    fn test_analyze_rust_file() {
        let source = r#"
use std::collections::HashMap;
use serde::Serialize;

pub struct Config {
    pub name: String,
}

pub fn load_config() -> Config {
    Config { name: "test".into() }
}

enum Mode { Fast, Slow }
"#;
        let intel = analyze_file(Path::new("test.rs"), source).expect("Rust source should parse");

        assert_eq!(intel.language, Language::Rust);
        assert!(intel.imports.len() >= 2);
        assert!(has_symbol(&intel, "Config", SymbolKind::Struct));
        assert!(has_symbol(&intel, "load_config", SymbolKind::Function));
    }

    #[test]
    fn test_analyze_typescript_file() {
        let source = r#"
import { useState } from "react";
import { create } from "zustand";

interface AppState {
    count: number;
}

function increment() {}

class App {}

export type Config = { name: string };
"#;
        let intel =
            analyze_file(Path::new("test.ts"), source).expect("TypeScript source should parse");

        assert_eq!(intel.language, Language::TypeScript);
        assert!(intel.imports.iter().any(|imp| imp.contains("react")));
        assert!(has_symbol(&intel, "AppState", SymbolKind::Interface));
        assert!(has_symbol(&intel, "increment", SymbolKind::Function));
    }

    #[test]
    fn test_analyze_python_file() {
        let source = r#"
import os
from pathlib import Path

class MyModel:
    pass

def train():
    pass
"#;
        let intel =
            analyze_file(Path::new("test.py"), source).expect("Python source should parse");

        assert_eq!(intel.language, Language::Python);
        assert!(intel.imports.len() >= 2);
        assert!(has_symbol(&intel, "MyModel", SymbolKind::Class));
        assert!(has_symbol(&intel, "train", SymbolKind::Function));
    }

    #[test]
    fn test_language_detection() {
        let cases = [
            ("rs", Language::Rust),
            ("tsx", Language::TypeScript),
            ("py", Language::Python),
            ("go", Language::Go),
            ("md", Language::Unknown),
        ];
        for (ext, expected) in cases {
            assert_eq!(Language::from_extension(ext), expected);
        }
    }
}