pub mod types;
pub mod languages;
pub mod index;
pub use types::*;
pub use index::ProjectIndex;
use languages::{extract_symbols, ts_language};
/// Parses `source` as `lang` and collects its symbols and imports.
///
/// Returns `None` when no tree-sitter grammar is available for `lang`, when
/// the parser refuses that grammar, or when tree-sitter yields no tree.
pub fn parse(source: &str, lang: Language) -> Option<ParsedFile> {
    let bytes = source.as_bytes();

    let grammar = ts_language(lang)?;
    let mut ts_parser = tree_sitter::Parser::new();
    if ts_parser.set_language(&grammar).is_err() {
        return None;
    }

    // `None` = no previous tree to reuse, i.e. a full (non-incremental) parse.
    let syntax_tree = ts_parser.parse(bytes, None)?;
    let (symbols, imports) = extract_symbols(lang, &syntax_tree, bytes);

    Some(ParsedFile {
        language: lang,
        symbols,
        imports,
    })
}
/// Parses `source`, choosing the language from `filename`.
///
/// Returns `None` when `Language::from_filename` does not recognise the file
/// name, or when [`parse`] itself fails.
pub fn parse_file(source: &str, filename: &str) -> Option<ParsedFile> {
    Language::from_filename(filename).and_then(|detected| parse(source, detected))
}
/// Returns the text of `source` covered by `symbol`'s byte span.
///
/// An empty string is returned when the span does not describe a valid slice
/// of `source`: the end lies past the end of the text, the start is greater
/// than the end, or either offset falls inside a multi-byte UTF-8 character.
/// This makes the function safe to call with a span that was computed for a
/// different revision of the text.
pub fn extract_source(symbol: &Symbol, source: &str) -> String {
    // `str::get` performs the range and char-boundary checks that direct
    // indexing would turn into a panic.
    source
        .get(symbol.span.start_byte..symbol.span.end_byte)
        .map(str::to_owned)
        .unwrap_or_default()
}
/// Builds a [`SymbolContext`] for `symbol`: its own source text plus every
/// textual reference to its name elsewhere in the same file.
///
/// A reference is an occurrence of `symbol.name` in `source` that
/// * starts outside the symbol's own byte span, and
/// * is not directly preceded or followed by an identifier byte
///   (see `is_ident_char`), so `foo` does not match inside `foo_bar`.
///
/// The search is purely textual; occurrences inside comments or string
/// literals are reported too. Boundary checks are byte-based, so a neighbouring
/// non-ASCII character is always treated as a non-identifier character.
///
/// Lines and columns in the returned spans are zero-based; columns are byte
/// offsets from the start of the line.
///
/// `_parsed` is currently unused and kept for interface stability.
pub fn extract_context(symbol: &Symbol, _parsed: &ParsedFile, source: &str) -> SymbolContext {
    let sym_source = extract_source(symbol, source);
    let name = symbol.name.as_str();
    let mut references = Vec::new();

    // An empty name would match at every offset (and walk past the end of the
    // text), so it simply has no references.
    if let Some(first_char) = name.chars().next() {
        // Advancing by the width of the name's first character keeps
        // `search_from` on a UTF-8 boundary while still visiting overlapping
        // occurrences, exactly like a one-byte step does for ASCII names.
        let step = first_char.len_utf8();
        let source_bytes = source.as_bytes();
        let own_span = symbol.span.start_byte..symbol.span.end_byte;

        // Running line bookkeeping. Matches are visited in increasing offset
        // order, so newlines only need to be counted once over the whole file
        // instead of from the start of the text for every match.
        let mut counted_to = 0usize;
        let mut line = 0u32;
        let mut line_start = 0usize;

        // Only relevant for the (unusual) case of a name that spans lines.
        let name_newlines = name.matches('\n').count() as u32;
        let name_last_newline = name.rfind('\n');

        let mut search_from = 0usize;
        while let Some(pos) = source[search_from..].find(name) {
            let abs_pos = search_from + pos;
            let end_pos = abs_pos + name.len();
            search_from = abs_pos + step;

            // Occurrences inside the definition itself are not references.
            if own_span.contains(&abs_pos) {
                continue;
            }

            // Start/end of file count as a non-identifier neighbour.
            let before = abs_pos.checked_sub(1).map_or(b' ', |i| source_bytes[i]);
            let after = source_bytes.get(end_pos).copied().unwrap_or(b' ');
            if is_ident_char(before) || is_ident_char(after) {
                continue;
            }

            // Bring the line counter up to the start of this match.
            for (offset, _) in source[counted_to..abs_pos].match_indices('\n') {
                line += 1;
                line_start = counted_to + offset + 1;
            }
            counted_to = abs_pos;

            // The matched text equals `name`, so the end position follows from
            // the newlines inside the name itself.
            let end_line_start = name_last_newline.map_or(line_start, |p| abs_pos + p + 1);

            references.push(Span {
                start_byte: abs_pos,
                end_byte: end_pos,
                start_line: line,
                start_col: (abs_pos - line_start) as u32,
                end_line: line + name_newlines,
                end_col: (end_pos - end_line_start) as u32,
            });
        }
    }

    SymbolContext {
        symbol: symbol.clone(),
        source: sym_source,
        references_in_file: references,
    }
}
/// Concatenates the source text of several symbols into one compact listing.
///
/// Symbols are emitted in ascending order of their starting byte offset
/// (ties keep the order given in `symbols`), and consecutive snippets are
/// separated by a `// ...` marker line surrounded by blank lines. An empty
/// `symbols` slice yields an empty string.
pub fn extract_symbols_source(symbols: &[&Symbol], source: &str) -> String {
    let mut ordered: Vec<&Symbol> = symbols.to_vec();
    // Stable sort, so symbols sharing a start offset keep their input order.
    ordered.sort_by_key(|sym| sym.span.start_byte);

    let snippets: Vec<String> = ordered
        .iter()
        .map(|sym| extract_source(sym, source))
        .collect();

    snippets.join("\n\n// ...\n\n")
}
/// Computes the symbol-level differences between two parsed versions of a file.
///
/// Symbols are matched by `(name, kind)`. For every key:
/// * present in both versions: reported as `SymbolChange::Modified` when the
///   signature or the start/end byte offsets differ, with `signature_changed`
///   telling the two cases apart; identical symbols are not reported.
/// * present only in `old`: reported as `SymbolChange::Removed`.
/// * present only in `new`: reported as `SymbolChange::Added`.
///
/// When several symbols of one version share the same `(name, kind)`, only the
/// last one returned by `all_symbols()` takes part in the comparison.
///
/// The result is deterministic: modified/removed entries come first, in the
/// order `old.all_symbols()` yields their keys, followed by added entries in
/// the order `new.all_symbols()` yields theirs.
pub fn diff_symbols(old: &ParsedFile, new: &ParsedFile) -> Vec<SymbolChange> {
    use std::collections::{HashMap, HashSet};

    let old_syms = old.all_symbols();
    let new_syms = new.all_symbols();

    let old_map: HashMap<(&str, SymbolKind), &Symbol> = old_syms
        .iter()
        .map(|s| ((s.name.as_str(), s.kind), *s))
        .collect();
    let new_map: HashMap<(&str, SymbolKind), &Symbol> = new_syms
        .iter()
        .map(|s| ((s.name.as_str(), s.kind), *s))
        .collect();

    let mut changes = Vec::new();

    // Walk the symbol lists rather than the maps: hash-map iteration order is
    // arbitrary, which would make the output order differ from run to run.
    // The `seen` sets ensure each key is still reported at most once.
    let mut seen_old: HashSet<(&str, SymbolKind)> = HashSet::with_capacity(old_map.len());
    for sym in old_syms.iter() {
        let key = (sym.name.as_str(), sym.kind);
        if !seen_old.insert(key) {
            continue;
        }
        // The map entry is the representative for this key.
        let old_sym: &Symbol = old_map[&key];

        match new_map.get(&key) {
            Some(&new_sym) => {
                let signature_changed = old_sym.signature != new_sym.signature;
                let span_changed = old_sym.span.start_byte != new_sym.span.start_byte
                    || old_sym.span.end_byte != new_sym.span.end_byte;
                if signature_changed || span_changed {
                    changes.push(SymbolChange::Modified {
                        old: old_sym.clone(),
                        new: new_sym.clone(),
                        signature_changed,
                    });
                }
            }
            None => changes.push(SymbolChange::Removed(old_sym.clone())),
        }
    }

    let mut seen_new: HashSet<(&str, SymbolKind)> = HashSet::with_capacity(new_map.len());
    for sym in new_syms.iter() {
        let key = (sym.name.as_str(), sym.kind);
        if seen_new.insert(key) && !old_map.contains_key(&key) {
            let new_sym: &Symbol = new_map[&key];
            changes.push(SymbolChange::Added(new_sym.clone()));
        }
    }

    changes
}
/// Returns `true` when `b` is an ASCII letter, an ASCII digit, or `_`.
///
/// Every byte outside the ASCII range yields `false`.
fn is_ident_char(b: u8) -> bool {
    matches!(b, b'_' | b'0'..=b'9' | b'a'..=b'z' | b'A'..=b'Z')
}
// Unit tests for the public helpers of this module: parsing, language
// detection by file name, symbol diffing, reference extraction and compact
// source extraction.
#[cfg(test)]
mod tests {
use super::*;
// Parses a small Rust snippet and checks the extracted struct, impl block
// (with its methods as children), free function and import.
#[test]
fn test_parse_rust() {
let source = r#"
use std::collections::HashMap;
/// A cool struct.
pub struct Foo {
bar: String,
}
impl Foo {
/// Create a new Foo.
pub fn new(bar: String) -> Self {
Self { bar }
}
pub fn bar(&self) -> &str {
&self.bar
}
}
fn helper() -> bool {
true
}
"#;
let parsed = parse(source, Language::Rust).unwrap();
assert!(!parsed.symbols.is_empty());
assert!(!parsed.imports.is_empty());
// `Foo` must be reported as a struct.
let foo_structs = parsed.find_symbol("Foo");
assert!(foo_structs.iter().any(|s| s.kind == SymbolKind::Struct));
// Exactly one top-level impl named `Foo`, holding both methods, `new` first.
let foo_impls: Vec<_> = parsed.symbols.iter()
.filter(|s| s.kind == SymbolKind::Impl && s.name == "Foo")
.collect();
assert_eq!(foo_impls.len(), 1);
assert_eq!(foo_impls[0].children.len(), 2);
assert_eq!(foo_impls[0].children[0].name, "new");
assert_eq!(foo_impls[0].children[0].kind, SymbolKind::Method);
// The free function is found once and classified as a function.
let helpers = parsed.find_symbol("helper");
assert_eq!(helpers.len(), 1);
assert_eq!(helpers[0].kind, SymbolKind::Function);
// The single `use` line is the only import and its path mentions `HashMap`.
assert_eq!(parsed.imports.len(), 1);
assert!(parsed.imports[0].path.contains("HashMap"));
}
// Parses a small Python snippet and checks the class (with two child
// symbols), the standalone function and the module-level constant.
#[test]
fn test_parse_python() {
let source = r#"
import os
from pathlib import Path
MAX_SIZE = 1024
class Parser:
"""A code parser."""
def __init__(self, lang: str):
self.lang = lang
def parse(self, source: str) -> dict:
return {}
def standalone():
pass
"#;
let parsed = parse(source, Language::Python).unwrap();
assert!(!parsed.imports.is_empty());
// One class named `Parser` with two children.
let classes = parsed.find_symbol("Parser");
assert_eq!(classes.len(), 1);
assert_eq!(classes[0].kind, SymbolKind::Class);
assert_eq!(classes[0].children.len(), 2);
let funcs = parsed.find_symbol("standalone");
assert_eq!(funcs.len(), 1);
// The upper-case module-level assignment is reported as a constant.
let consts = parsed.find_symbol("MAX_SIZE");
assert_eq!(consts.len(), 1);
assert_eq!(consts[0].kind, SymbolKind::Const);
}
// `parse_file` picks Rust for a `.rs` name and returns `None` for an
// unrecognised extension.
#[test]
fn test_parse_file_detection() {
let source = "fn main() {}";
let parsed = parse_file(source, "main.rs").unwrap();
assert_eq!(parsed.language, Language::Rust);
assert!(parse_file(source, "unknown.xyz").is_none());
}
// Diffing two versions reports `foo` (return type changed) as modified,
// `bar` as removed and `baz` as added. Only presence is asserted, not order.
#[test]
fn test_diff_symbols() {
let old_source = r#"
fn foo() -> i32 { 1 }
fn bar() -> i32 { 2 }
"#;
let new_source = r#"
fn foo() -> i64 { 1 }
fn baz() -> i32 { 3 }
"#;
let old = parse(old_source, Language::Rust).unwrap();
let new = parse(new_source, Language::Rust).unwrap();
let changes = diff_symbols(&old, &new);
assert!(changes.iter().any(|c| matches!(c, SymbolChange::Modified { old, .. } if old.name == "foo")));
assert!(changes.iter().any(|c| matches!(c, SymbolChange::Removed(s) if s.name == "bar")));
assert!(changes.iter().any(|c| matches!(c, SymbolChange::Added(s) if s.name == "baz")));
}
// The context of `helper` contains its own source text and at least one
// reference from elsewhere in the file (the call inside `main`).
#[test]
fn test_extract_context() {
let source = r#"
fn helper() -> bool { true }
fn main() {
let x = helper();
println!("{}", x);
}
"#;
let parsed = parse(source, Language::Rust).unwrap();
let helper = &parsed.find_symbol("helper")[0];
let ctx = extract_context(helper, &parsed, source);
assert!(!ctx.source.is_empty());
assert!(ctx.source.contains("fn helper()"));
assert!(!ctx.references_in_file.is_empty());
}
// Extracting `a` and `c` yields both bodies joined by the `// ...` marker,
// with the unselected `b` left out.
#[test]
fn test_extract_symbols_source() {
let source = r#"
fn a() { 1 }
fn b() { 2 }
fn c() { 3 }
"#;
let parsed = parse(source, Language::Rust).unwrap();
let a = &parsed.find_symbol("a")[0];
let c = &parsed.find_symbol("c")[0];
let compact = extract_symbols_source(&[a, c], source);
assert!(compact.contains("fn a()"));
assert!(compact.contains("fn c()"));
assert!(compact.contains("// ..."));
assert!(!compact.contains("fn b()"));
}
}