use streaming_iterator::StreamingIterator;
use tree_sitter::{Node, Query, QueryCursor, QueryMatch};
use super::{ExtractError, FileMapL1, Implementation, Import, SCHEMA_VER, Symbol, SymbolKind};
use crate::lang::{
LangId, ParseOutcome, parse_with_default_timeout, try_get_combined_l1_query, with_parser,
};
/// Parses `source` as `lang` (using the default parse timeout) and extracts its L1 file map.
///
/// # Errors
///
/// * [`ExtractError::ParseFailure`] when parsing reports `ParseOutcome::Failed`.
/// * [`ExtractError::ParseTimeout`] carrying `crate::lang::DEFAULT_PARSE_TIMEOUT` when parsing
///   reports `ParseOutcome::TimedOut`.
/// * Any error returned by `with_parser` (converted with `?`) or by [`extract_l1_from_tree`].
pub fn extract_l1(lang: LangId, source: &[u8]) -> Result<FileMapL1, ExtractError> {
    let parsed = with_parser(lang, |parser| parse_with_default_timeout(parser, source))?;
    let tree = match parsed {
        ParseOutcome::Ok(tree) => tree,
        ParseOutcome::Failed => return Err(ExtractError::ParseFailure),
        ParseOutcome::TimedOut => {
            let limit = crate::lang::DEFAULT_PARSE_TIMEOUT;
            return Err(ExtractError::ParseTimeout(limit));
        }
    };
    extract_l1_from_tree(lang, &tree, source)
}
/// Builds the L1 file map for an already-parsed `tree` of `source`.
///
/// `had_errors` mirrors `root.has_error()`. `error_count` is the number of ERROR/MISSING
/// nodes, and it is only computed (via a full walk) when the root reports an error; clean
/// trees get `0` without walking.
///
/// # Errors
///
/// Propagates any error from running the combined L1 query.
pub(crate) fn extract_l1_from_tree(
    lang: LangId,
    tree: &tree_sitter::Tree,
    source: &[u8],
) -> Result<FileMapL1, ExtractError> {
    let root = tree.root_node();
    let had_errors = root.has_error();
    // Skip the full-tree walk entirely for clean parses.
    let error_count = if had_errors { count_error_nodes(root) } else { 0 };

    let (symbols, imports, implementations) = run_combined(lang, root, source)?;

    Ok(FileMapL1 {
        schema_ver: SCHEMA_VER,
        language: lang.to_string(),
        size_bytes: source.len() as u64,
        had_errors,
        error_count,
        symbols,
        imports,
        implementations,
    })
}
/// Counts the nodes under `root` (including `root` itself) for which `is_error()` or
/// `is_missing()` is true.
///
/// Uses an explicit stack (no recursion) and one reusable tree cursor. The count saturates
/// at `u32::MAX`.
fn count_error_nodes(root: Node) -> u32 {
    let mut walker = root.walk();
    let mut pending: Vec<Node> = vec![root];
    let mut total: u32 = 0;
    while let Some(current) = pending.pop() {
        let is_problem = current.is_error() || current.is_missing();
        total = total.saturating_add(u32::from(is_problem));
        pending.extend(current.children(&mut walker));
    }
    total
}
/// Symbols, imports and implementations extracted from one tree, in that order.
type CombinedL1 = (Vec<Symbol>, Vec<Import>, Vec<Implementation>);

/// Runs the combined L1 query for `lang` over `root` and sorts each match into symbols,
/// imports or implementations.
///
/// A match is classified by the prefix (text before the first `.`) of its *first* capture's
/// name: `symbol`, `import` or `impl`. Matches with no captures are skipped. Any other prefix
/// trips a `debug_assert!`, which is compiled out in release builds, where such matches are
/// ignored. Symbols are passed through `dedupe_symbols` before returning.
///
/// When `try_get_combined_l1_query` yields `None` (presumably no combined query exists for
/// this language), all three collections come back empty.
///
/// # Errors
///
/// Propagates any error from `try_get_combined_l1_query`.
fn run_combined(
    lang: LangId,
    root: tree_sitter::Node,
    source: &[u8],
) -> Result<CombinedL1, ExtractError> {
    let Some(query) = try_get_combined_l1_query(lang)? else {
        return Ok(Default::default());
    };

    let mut symbols = Vec::new();
    let mut imports = Vec::new();
    let mut implementations = Vec::new();

    let mut cursor = QueryCursor::new();
    let mut stream = cursor.matches(&query, root, source);
    while let Some(m) = stream.next() {
        let Some(lead) = m.captures.first() else {
            continue;
        };
        let first_name = capture_name(&query, lead.index);
        match first_name.split_once('.').map(|(prefix, _)| prefix) {
            Some("symbol") => symbols.extend(build_symbol(&query, m, source)),
            Some("import") => imports.extend(build_import(&query, m, source)),
            Some("impl") => implementations.extend(build_implementation(&query, m, source)),
            _ => debug_assert!(
                false,
                "unexpected capture prefix in combined L1 query: {first_name}"
            ),
        }
    }

    Ok((dedupe_symbols(symbols), imports, implementations))
}
/// Collapses symbols that share the same `(start_byte, name)` key, keeping first-seen order.
///
/// When a later duplicate arrives:
/// * if its `kind.specificity()` is strictly greater than the kept symbol's, its kind
///   replaces the kept kind, and its signature replaces the kept one when it has one;
/// * its decorators are appended to the kept symbol's list, skipping ones already present.
fn dedupe_symbols(syms: Vec<Symbol>) -> Vec<Symbol> {
    let mut unique: Vec<Symbol> = Vec::with_capacity(syms.len());
    let mut slot_of: ahash::AHashMap<(u32, String), usize> =
        ahash::AHashMap::with_capacity(syms.len());

    for sym in syms {
        match slot_of.entry((sym.start_byte, sym.name.clone())) {
            std::collections::hash_map::Entry::Vacant(vacant) => {
                vacant.insert(unique.len());
                unique.push(sym);
            }
            std::collections::hash_map::Entry::Occupied(occupied) => {
                let kept = &mut unique[*occupied.get()];
                if sym.kind.specificity() > kept.kind.specificity() {
                    kept.kind = sym.kind;
                    if let Some(sig) = sym.signature {
                        kept.signature = Some(sig);
                    }
                }
                for decorator in sym.decorators {
                    if !kept.decorators.contains(&decorator) {
                        kept.decorators.push(decorator);
                    }
                }
            }
        }
    }

    unique
}
/// Returns the name of capture number `index` in `q` (e.g. `"symbol.name"`).
///
/// # Panics
///
/// Panics if `index` is out of range for the query's capture list.
fn capture_name(q: &Query, index: u32) -> &str {
    let names = q.capture_names();
    names[index as usize]
}
/// Assembles a [`Symbol`] from one `symbol.*` query match.
///
/// Captures are handled by name:
/// * `symbol.name` provides the name. The match yields `None` if the last such capture is
///   missing or not valid UTF-8.
/// * `symbol.decorator` text is trimmed and collected; empty results are skipped.
/// * Any other `symbol.<suffix>` capture sets the kind via
///   `SymbolKind::from_capture_suffix(suffix)` and supplies the byte span, start row/column and
///   signature (via `signature_slice`). A `Method` is re-classified by `detect_accessor` when
///   that returns an accessor kind.
///
/// Without such a kind capture, the kind is `SymbolKind::Unknown` and the span/position fields
/// are left at `0`.
fn build_symbol(q: &Query, m: &QueryMatch, source: &[u8]) -> Option<Symbol> {
    let mut name: Option<String> = None;
    let mut kind: Option<SymbolKind> = None;
    let (mut start_byte, mut end_byte) = (0u32, 0u32);
    let (mut start_row, mut start_col) = (0u32, 0u32);
    let mut signature: Option<String> = None;
    let mut decorators: Vec<String> = Vec::new();

    for cap in m.captures {
        let node = cap.node;
        match capture_name(q, cap.index) {
            "symbol.name" => name = node.utf8_text(source).ok().map(str::to_string),
            "symbol.decorator" => {
                let text = node.utf8_text(source).map_or("", str::trim);
                if !text.is_empty() {
                    decorators.push(text.to_owned());
                }
            }
            other => {
                let Some(suffix) = other.strip_prefix("symbol.") else {
                    continue;
                };
                let mut this_kind = SymbolKind::from_capture_suffix(suffix);
                start_byte = node.start_byte() as u32;
                end_byte = node.end_byte() as u32;
                let pos = node.start_position();
                start_row = pos.row as u32;
                start_col = pos.column as u32;
                if let Ok(text) = node.utf8_text(source) {
                    signature = signature_slice(text);
                    // Only plain methods are candidates for `get`/`set` accessor promotion.
                    if matches!(this_kind, SymbolKind::Method)
                        && let Some(promoted) = detect_accessor(text)
                    {
                        this_kind = promoted;
                    }
                }
                kind = Some(this_kind);
            }
        }
    }

    Some(Symbol {
        name: name?,
        kind: kind.unwrap_or(SymbolKind::Unknown),
        start_byte,
        end_byte,
        start_row,
        start_col,
        signature,
        decorators,
    })
}
/// Detects whether a method's source text declares a `get` or `set` accessor.
///
/// Only the first 8 whitespace-separated tokens are examined. Leading modifier keywords
/// (`static`, `public`, `private`, `protected`, `readonly`, `override`, `async`) are skipped.
/// The first other token decides: exactly `get` gives `Getter`, exactly `set` gives `Setter`,
/// and anything else (or no such token) gives `None`.
fn detect_accessor(slice: &str) -> Option<SymbolKind> {
    const MODIFIERS: &[&str] = &[
        "static",
        "public",
        "private",
        "protected",
        "readonly",
        "override",
        "async",
    ];
    slice
        .split_whitespace()
        .take(8)
        .find(|tok| !MODIFIERS.contains(tok))
        .and_then(|tok| match tok {
            "get" => Some(SymbolKind::Getter),
            "set" => Some(SymbolKind::Setter),
            _ => None,
        })
}
/// Derives a one-line "signature" from a definition's source text.
///
/// The header ends at the first `{` or `;` that lies outside every `(...)` / `[...]` group.
/// This keeps parameter lists containing inline object types or destructuring
/// (`function f({ a }: T) {`) and array types (`fn f(b: [u8; 4]) {`) intact instead of
/// cutting them off mid-list.
///
/// If no `{`/`;` is found at bracket depth zero (for example, a stray `(` inside a string
/// literal left the count unbalanced), the first `{` or `;` anywhere is used instead. If there
/// is no terminator at all, the whole text is used. Whitespace runs in the header are collapsed
/// to single spaces.
///
/// Returns `None` when the resulting header is empty.
fn signature_slice(text: &str) -> Option<String> {
    let mut depth = 0usize;
    // First `{`/`;` seen at any depth; used when brackets never return to depth zero.
    let mut first_terminator: Option<usize> = None;
    let mut header_end: Option<usize> = None;

    // Scanning bytes is safe: every delimiter is ASCII, so each match index is a char boundary.
    for (i, b) in text.bytes().enumerate() {
        match b {
            b'(' | b'[' => depth += 1,
            b')' | b']' => depth = depth.saturating_sub(1),
            b'{' | b';' => {
                if depth == 0 {
                    header_end = Some(i);
                    break;
                }
                first_terminator.get_or_insert(i);
            }
            _ => {}
        }
    }

    let end = header_end.or(first_terminator).unwrap_or(text.len());
    let collapsed = text[..end].split_whitespace().collect::<Vec<_>>().join(" ");
    (!collapsed.is_empty()).then_some(collapsed)
}
/// Assembles an [`Implementation`] from one `impl.*` query match.
///
/// Requires an `impl.trait_name` capture with valid UTF-8 text. The implementing type comes
/// from `impl.implementor` when captured. Otherwise it is looked up with
/// `implementor_from_ancestor`, starting from the `impl.range` node, or the trait-name node
/// when no range was captured. The reported position is the start of that same anchor node.
/// Returns `None` if any required piece cannot be found.
fn build_implementation(q: &Query, m: &QueryMatch, source: &[u8]) -> Option<Implementation> {
    // Owned UTF-8 text of `node`, or `None` if its bytes are not valid UTF-8.
    fn owned_text(node: Node, src: &[u8]) -> Option<String> {
        node.utf8_text(src).ok().map(str::to_string)
    }

    let mut trait_name = None;
    let mut trait_node = None;
    let mut impl_type = None;
    let mut range_node = None;
    for cap in m.captures {
        match capture_name(q, cap.index) {
            "impl.trait_name" => {
                trait_name = owned_text(cap.node, source);
                trait_node = Some(cap.node);
            }
            "impl.implementor" => impl_type = owned_text(cap.node, source),
            "impl.range" => range_node = Some(cap.node),
            _ => {}
        }
    }

    let trait_name = trait_name?;
    // Prefer the whole captured range as the anchor; fall back to the trait-name node.
    let anchor = range_node.or(trait_node);
    let impl_type = match impl_type {
        Some(explicit) => explicit,
        None => implementor_from_ancestor(anchor?, source)?,
    };
    let anchor = anchor?;
    let pos = anchor.start_position();
    Some(Implementation {
        trait_name,
        impl_type,
        start_byte: anchor.start_byte() as u32,
        start_row: pos.row as u32,
        start_col: pos.column as u32,
    })
}
fn implementor_from_ancestor(node: Node, source: &[u8]) -> Option<String> {
fn field_text<'a>(parent: Node<'a>, field: &str, src: &'a [u8]) -> Option<&'a str> {
let n = parent.child_by_field_name(field)?;
let t = n.utf8_text(src).ok()?;
if t.is_empty() { None } else { Some(t) }
}
let mut current = node;
for _ in 0..8 {
let parent = current.parent()?;
if let Some(text) = field_text(parent, "name", source) {
return Some(text.to_string());
}
if let Some(type_node) = parent.child_by_field_name("type") {
let leaf_text = (type_node.child_count() == 0)
.then(|| type_node.utf8_text(source).ok())
.flatten()
.filter(|t| !t.is_empty());
if let Some(text) = leaf_text {
return Some(text.to_string());
}
}
current = parent;
}
None
}
/// Assembles an [`Import`] from one `import.*` query match.
///
/// The `import.range` capture is required and supplies the raw text and byte span. The
/// optional `import.module` capture supplies `module`, which is `None` if absent or not valid
/// UTF-8. When a capture name repeats within the match, the last occurrence is used.
fn build_import(q: &Query, m: &QueryMatch, source: &[u8]) -> Option<Import> {
    // Node of the last capture in this match whose name equals `wanted`.
    let last_capture = |wanted: &str| {
        m.captures
            .iter()
            .rev()
            .find(|cap| capture_name(q, cap.index) == wanted)
            .map(|cap| cap.node)
    };

    let module = last_capture("import.module")
        .and_then(|node| node.utf8_text(source).ok())
        .map(str::to_string);
    let span = last_capture("import.range")?;
    let raw = span.utf8_text(source).ok()?.to_string();
    Some(Import {
        module,
        raw,
        start_byte: span.start_byte() as u32,
        end_byte: span.end_byte() as u32,
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// An `impl Trait for Type` block must produce an `Implementation` linking
    /// `Drawable` to `Beta`.
    #[test]
    fn extract_implementation_rust_trait_impl() {
        let src = br#"
trait Drawable {
fn draw(&self);
}
struct Beta {
x: i32,
}
impl Drawable for Beta {
fn draw(&self) {}
}
"#;
        let impls = extract_l1("rust", src).expect("extract").implementations;
        assert!(
            !impls.is_empty(),
            "expected at least one Implementation; got none"
        );
        let has_drawable_for_beta = impls
            .iter()
            .any(|imp| imp.trait_name == "Drawable" && imp.impl_type == "Beta");
        assert!(
            has_drawable_for_beta,
            "expected Implementation {{ trait_name: \"Drawable\", impl_type: \"Beta\" }}; got {impls:?}"
        );
    }

    /// Clean Rust source: top-level fn/struct/const are found, the `use` is an import,
    /// and no errors are reported.
    #[test]
    fn extract_basic_rust() {
        let src = br#"
pub fn hello() {}
pub struct Foo {
x: i32,
}
use std::collections::HashMap;
const N: u32 = 42;
"#;
        let map = extract_l1("rust", src).expect("extract");
        for expected in ["hello", "Foo", "N"] {
            assert!(map.symbols.iter().any(|s| s.name == expected));
        }
        assert!(!map.imports.is_empty(), "expected at least one import");
        assert!(!map.had_errors, "clean source must not flag errors");
        assert_eq!(map.error_count, 0);
    }

    /// Broken source still extracts: errors are flagged and counted, and at least one
    /// well-formed sibling function survives.
    #[test]
    fn extract_recovers_from_syntax_errors() {
        let src = br#"
pub fn good_one() {}
pub fn broken( {
let x = ;
}
pub fn good_two() {}
"#;
        let map = extract_l1("rust", src).expect("extract should not fail on partial parse");
        assert!(
            map.had_errors,
            "had_errors should be true for syntax errors"
        );
        assert!(
            map.error_count > 0,
            "error_count should be > 0; got {}",
            map.error_count
        );
        let names: Vec<&str> = map.symbols.iter().map(|s| s.name.as_str()).collect();
        let recovered = ["good_one", "good_two"].iter().any(|n| names.contains(n));
        assert!(
            recovered,
            "at least one well-formed sibling symbol should be recovered; got {names:?}"
        );
    }
}