use crate::languages::get_language_info;
use crate::types::{
CallInfo, ClassInfo, FunctionInfo, ImplTraitInfo, ImportInfo, ReferenceInfo, ReferenceType,
SemanticAnalysis,
};
use std::cell::RefCell;
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::sync::LazyLock;
use thiserror::Error;
use tracing::instrument;
use tree_sitter::{Node, Parser, Query, QueryCursor, StreamingIterator};
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum ParserError {
#[error("Unsupported language: {0}")]
UnsupportedLanguage(String),
#[error("Failed to parse file: {0}")]
ParseError(String),
#[error("Invalid UTF-8 in file")]
InvalidUtf8,
#[error("Query error: {0}")]
QueryError(String),
}
struct CompiledQueries {
element: Query,
call: Query,
import: Option<Query>,
impl_block: Option<Query>,
reference: Option<Query>,
impl_trait: Option<Query>,
defuse: Option<Query>,
}
#[cfg_attr(coverage_nightly, coverage(off))]
fn build_compiled_queries(
lang_info: &crate::languages::LanguageInfo,
) -> Result<CompiledQueries, ParserError> {
let element = Query::new(&lang_info.language, lang_info.element_query).map_err(|e| {
ParserError::QueryError(format!(
"Failed to compile element query for {}: {}",
lang_info.name, e
))
})?;
let call = Query::new(&lang_info.language, lang_info.call_query).map_err(|e| {
ParserError::QueryError(format!(
"Failed to compile call query for {}: {}",
lang_info.name, e
))
})?;
let import = if let Some(import_query_str) = lang_info.import_query {
Some(
Query::new(&lang_info.language, import_query_str).map_err(|e| {
ParserError::QueryError(format!(
"Failed to compile import query for {}: {}",
lang_info.name, e
))
})?,
)
} else {
None
};
let impl_block = if let Some(impl_query_str) = lang_info.impl_query {
Some(
Query::new(&lang_info.language, impl_query_str).map_err(|e| {
ParserError::QueryError(format!(
"Failed to compile impl query for {}: {}",
lang_info.name, e
))
})?,
)
} else {
None
};
let reference = if let Some(ref_query_str) = lang_info.reference_query {
Some(Query::new(&lang_info.language, ref_query_str).map_err(|e| {
ParserError::QueryError(format!(
"Failed to compile reference query for {}: {}",
lang_info.name, e
))
})?)
} else {
None
};
let impl_trait = if let Some(impl_trait_query_str) = lang_info.impl_trait_query {
Some(
Query::new(&lang_info.language, impl_trait_query_str).map_err(|e| {
ParserError::QueryError(format!(
"Failed to compile impl_trait query for {}: {}",
lang_info.name, e
))
})?,
)
} else {
None
};
let defuse = if let Some(defuse_query_str) = lang_info.defuse_query {
Some(
Query::new(&lang_info.language, defuse_query_str).map_err(|e| {
ParserError::QueryError(format!(
"Failed to compile defuse query for {}: {}",
lang_info.name, e
))
})?,
)
} else {
None
};
Ok(CompiledQueries {
element,
call,
import,
impl_block,
reference,
impl_trait,
defuse,
})
}
#[cfg_attr(coverage_nightly, coverage(off))]
fn init_query_cache() -> HashMap<&'static str, CompiledQueries> {
let mut cache = HashMap::new();
for lang_name in crate::lang::supported_languages() {
if let Some(lang_info) = get_language_info(lang_name) {
match build_compiled_queries(&lang_info) {
Ok(compiled) => {
cache.insert(*lang_name, compiled);
}
Err(e) => {
tracing::error!(
"Failed to compile queries for language {}: {}",
lang_name,
e
);
}
}
}
}
cache
}
static QUERY_CACHE: LazyLock<HashMap<&'static str, CompiledQueries>> =
LazyLock::new(init_query_cache);
fn get_compiled_queries(language: &str) -> Result<&'static CompiledQueries, ParserError> {
QUERY_CACHE
.get(language)
.ok_or_else(|| ParserError::UnsupportedLanguage(language.to_string()))
}
thread_local! {
static PARSER: RefCell<Parser> = RefCell::new(Parser::new());
}
pub struct ElementExtractor;
impl ElementExtractor {
#[instrument(skip_all, fields(language))]
pub fn extract_with_depth(source: &str, language: &str) -> Result<(usize, usize), ParserError> {
let lang_info = get_language_info(language)
.ok_or_else(|| ParserError::UnsupportedLanguage(language.to_string()))?;
let tree = PARSER.with(|p| {
let mut parser = p.borrow_mut();
parser
.set_language(&lang_info.language)
.map_err(|e| ParserError::ParseError(format!("Failed to set language: {e}")))?;
parser
.parse(source, None)
.ok_or_else(|| ParserError::ParseError("Failed to parse".to_string()))
})?;
let compiled = get_compiled_queries(language)?;
let mut cursor = QueryCursor::new();
let mut function_count = 0;
let mut class_count = 0;
let mut matches = cursor.matches(&compiled.element, tree.root_node(), source.as_bytes());
while let Some(mat) = matches.next() {
for capture in mat.captures {
let capture_name = compiled.element.capture_names()[capture.index as usize];
match capture_name {
"function" => function_count += 1,
"class" => class_count += 1,
_ => {}
}
}
}
tracing::debug!(language = %language, functions = function_count, classes = class_count, "parse complete");
Ok((function_count, class_count))
}
}
#[allow(clippy::too_many_lines)] fn extract_imports_from_node(
node: &Node,
source: &str,
prefix: &str,
line: usize,
imports: &mut Vec<ImportInfo>,
) {
match node.kind() {
"identifier" | "self" | "super" | "crate" => {
let name = source[node.start_byte()..node.end_byte()].to_string();
imports.push(ImportInfo {
module: prefix.to_string(),
items: vec![name],
line,
});
}
"scoped_identifier" => {
let item = node
.child_by_field_name("name")
.map(|n| source[n.start_byte()..n.end_byte()].to_string())
.unwrap_or_default();
let module = node.child_by_field_name("path").map_or_else(
|| prefix.to_string(),
|p| {
let path_text = source[p.start_byte()..p.end_byte()].to_string();
if prefix.is_empty() {
path_text
} else {
format!("{prefix}::{path_text}")
}
},
);
if !item.is_empty() {
imports.push(ImportInfo {
module,
items: vec![item],
line,
});
}
}
"scoped_use_list" => {
let new_prefix = node.child_by_field_name("path").map_or_else(
|| prefix.to_string(),
|p| {
let path_text = source[p.start_byte()..p.end_byte()].to_string();
if prefix.is_empty() {
path_text
} else {
format!("{prefix}::{path_text}")
}
},
);
if let Some(list) = node.child_by_field_name("list") {
extract_imports_from_node(&list, source, &new_prefix, line, imports);
}
}
"use_list" => {
let mut cursor = node.walk();
for child in node.children(&mut cursor) {
match child.kind() {
"{" | "}" | "," => {}
_ => extract_imports_from_node(&child, source, prefix, line, imports),
}
}
}
"use_wildcard" => {
let text = source[node.start_byte()..node.end_byte()].to_string();
let module = if let Some(stripped) = text.strip_suffix("::*") {
if prefix.is_empty() {
stripped.to_string()
} else {
format!("{prefix}::{stripped}")
}
} else {
prefix.to_string()
};
imports.push(ImportInfo {
module,
items: vec!["*".to_string()],
line,
});
}
"use_as_clause" => {
let alias = node
.child_by_field_name("alias")
.map(|n| source[n.start_byte()..n.end_byte()].to_string())
.unwrap_or_default();
let module = if let Some(path_node) = node.child_by_field_name("path") {
match path_node.kind() {
"scoped_identifier" => path_node.child_by_field_name("path").map_or_else(
|| prefix.to_string(),
|p| {
let p_text = source[p.start_byte()..p.end_byte()].to_string();
if prefix.is_empty() {
p_text
} else {
format!("{prefix}::{p_text}")
}
},
),
_ => prefix.to_string(),
}
} else {
prefix.to_string()
};
if !alias.is_empty() {
imports.push(ImportInfo {
module,
items: vec![alias],
line,
});
}
}
"import_from_statement" => {
extract_python_import_from(node, source, line, imports);
}
_ => {
let text = source[node.start_byte()..node.end_byte()]
.trim()
.to_string();
if !text.is_empty() {
imports.push(ImportInfo {
module: text,
items: vec![],
line,
});
}
}
}
}
fn extract_import_item_name(child: &Node, source: &str) -> Option<String> {
match child.kind() {
"dotted_name" => {
let name = source[child.start_byte()..child.end_byte()]
.trim()
.to_string();
if name.is_empty() { None } else { Some(name) }
}
"aliased_import" => child.child_by_field_name("name").and_then(|n| {
let name = source[n.start_byte()..n.end_byte()].trim().to_string();
if name.is_empty() { None } else { Some(name) }
}),
_ => None,
}
}
fn collect_import_items(
node: &Node,
source: &str,
is_wildcard: &mut bool,
items: &mut Vec<String>,
) {
if let Some(import_list) = node.child_by_field_name("import_list") {
let mut cursor = import_list.walk();
for child in import_list.named_children(&mut cursor) {
if child.kind() == "wildcard_import" {
*is_wildcard = true;
} else if let Some(name) = extract_import_item_name(&child, source) {
items.push(name);
}
}
return;
}
let mut cursor = node.walk();
let mut first = true;
for child in node.named_children(&mut cursor) {
if first {
first = false;
continue;
}
if child.kind() == "wildcard_import" {
*is_wildcard = true;
} else if let Some(name) = extract_import_item_name(&child, source) {
items.push(name);
}
}
}
fn extract_python_import_from(
node: &Node,
source: &str,
line: usize,
imports: &mut Vec<ImportInfo>,
) {
let module = if let Some(m) = node.child_by_field_name("module_name") {
source[m.start_byte()..m.end_byte()].trim().to_string()
} else if let Some(r) = node.child_by_field_name("relative_import") {
source[r.start_byte()..r.end_byte()].trim().to_string()
} else {
String::new()
};
let mut is_wildcard = false;
let mut items = Vec::new();
collect_import_items(node, source, &mut is_wildcard, &mut items);
if !module.is_empty() {
imports.push(ImportInfo {
module,
items: if is_wildcard {
vec!["*".to_string()]
} else {
items
},
line,
});
}
}
pub struct SemanticExtractor;
impl SemanticExtractor {
#[instrument(skip_all, fields(language))]
pub fn extract(
source: &str,
language: &str,
ast_recursion_limit: Option<usize>,
) -> Result<SemanticAnalysis, ParserError> {
let lang_info = get_language_info(language)
.ok_or_else(|| ParserError::UnsupportedLanguage(language.to_string()))?;
let tree = PARSER.with(|p| {
let mut parser = p.borrow_mut();
parser
.set_language(&lang_info.language)
.map_err(|e| ParserError::ParseError(format!("Failed to set language: {e}")))?;
parser
.parse(source, None)
.ok_or_else(|| ParserError::ParseError("Failed to parse".to_string()))
})?;
let max_depth: Option<u32> = ast_recursion_limit
.filter(|&limit| limit > 0)
.map(|limit| {
u32::try_from(limit).map_err(|_| {
ParserError::ParseError(format!(
"ast_recursion_limit {} exceeds maximum supported value {}",
limit,
u32::MAX
))
})
})
.transpose()?;
let compiled = get_compiled_queries(language)?;
let root = tree.root_node();
let mut functions = Vec::new();
let mut classes = Vec::new();
let mut imports = Vec::new();
let mut references = Vec::new();
let mut call_frequency = HashMap::new();
let mut calls = Vec::new();
Self::extract_elements(
source,
compiled,
root,
max_depth,
&lang_info,
&mut functions,
&mut classes,
);
Self::extract_calls(
source,
compiled,
root,
max_depth,
&mut calls,
&mut call_frequency,
);
Self::extract_imports(source, compiled, root, max_depth, &mut imports);
Self::extract_impl_methods(source, compiled, root, max_depth, &mut classes);
Self::extract_references(source, compiled, root, max_depth, &mut references);
let impl_traits = if language == "rust" {
Self::extract_impl_traits_from_tree(source, compiled, root)
} else {
vec![]
};
tracing::debug!(language = %language, functions = functions.len(), classes = classes.len(), imports = imports.len(), references = references.len(), calls = calls.len(), impl_traits = impl_traits.len(), "extraction complete");
Ok(SemanticAnalysis {
functions,
classes,
imports,
references,
call_frequency,
calls,
impl_traits,
def_use_sites: Vec::new(),
})
}
fn extract_elements(
source: &str,
compiled: &CompiledQueries,
root: Node<'_>,
max_depth: Option<u32>,
lang_info: &crate::languages::LanguageInfo,
functions: &mut Vec<FunctionInfo>,
classes: &mut Vec<ClassInfo>,
) {
let mut cursor = QueryCursor::new();
if let Some(depth) = max_depth {
cursor.set_max_start_depth(Some(depth));
}
let mut matches = cursor.matches(&compiled.element, root, source.as_bytes());
let mut seen_functions = std::collections::HashSet::new();
while let Some(mat) = matches.next() {
let mut func_node: Option<Node> = None;
let mut func_name_text: Option<String> = None;
let mut class_node: Option<Node> = None;
let mut class_name_text: Option<String> = None;
for capture in mat.captures {
let capture_name = compiled.element.capture_names()[capture.index as usize];
let node = capture.node;
match capture_name {
"function" => func_node = Some(node),
"func_name" | "method_name" => {
func_name_text =
Some(source[node.start_byte()..node.end_byte()].to_string());
}
"class" => class_node = Some(node),
"class_name" | "type_name" => {
class_name_text =
Some(source[node.start_byte()..node.end_byte()].to_string());
}
_ => {}
}
}
if let Some(func_node) = func_node {
let parent_is_template = func_node
.parent()
.map(|p| p.kind() == "template_declaration")
.unwrap_or(false);
if func_node.kind() == "function_definition" && parent_is_template {
} else {
let func_def = if func_node.kind() == "template_declaration" {
let mut cursor = func_node.walk();
func_node
.children(&mut cursor)
.find(|n| n.kind() == "function_definition")
.unwrap_or(func_node)
} else {
func_node
};
let name = func_name_text
.or_else(|| {
func_def
.child_by_field_name("name")
.map(|n| source[n.start_byte()..n.end_byte()].to_string())
})
.unwrap_or_default();
let func_key = (name.clone(), func_node.start_position().row);
if !name.is_empty() && seen_functions.insert(func_key) {
let params = func_def
.child_by_field_name("declarator")
.and_then(|d| d.child_by_field_name("parameters"))
.or_else(|| func_def.child_by_field_name("parameters"))
.map(|p| source[p.start_byte()..p.end_byte()].to_string())
.unwrap_or_default();
let return_type = func_def
.child_by_field_name("type")
.or_else(|| func_def.child_by_field_name("return_type"))
.map(|r| source[r.start_byte()..r.end_byte()].to_string());
functions.push(FunctionInfo {
name,
line: func_node.start_position().row + 1,
end_line: func_node.end_position().row + 1,
parameters: if params.is_empty() {
Vec::new()
} else {
vec![params]
},
return_type,
});
}
}
}
if let Some(class_node) = class_node {
let name = class_name_text
.or_else(|| {
class_node
.child_by_field_name("name")
.map(|n| source[n.start_byte()..n.end_byte()].to_string())
})
.unwrap_or_default();
if !name.is_empty() {
let inherits = if let Some(handler) = lang_info.extract_inheritance {
handler(&class_node, source)
} else {
Vec::new()
};
classes.push(ClassInfo {
name,
line: class_node.start_position().row + 1,
end_line: class_node.end_position().row + 1,
methods: Vec::new(),
fields: Vec::new(),
inherits,
});
}
}
}
}
fn enclosing_function_name(mut node: tree_sitter::Node<'_>, source: &str) -> Option<String> {
let mut depth = 0u32;
while let Some(parent) = node.parent() {
depth += 1;
if depth > 64 {
return None;
}
let name_node = match parent.kind() {
"function_item"
| "method_item"
| "function_definition"
| "function_declaration"
| "method_declaration"
| "method_definition" => parent.child_by_field_name("name"),
"subroutine" => {
let mut cursor = parent.walk();
parent
.children(&mut cursor)
.find(|c| c.kind() == "subroutine_statement")
.and_then(|s| s.child_by_field_name("name"))
}
"function" => {
let mut cursor = parent.walk();
parent
.children(&mut cursor)
.find(|c| c.kind() == "function_statement")
.and_then(|s| s.child_by_field_name("name"))
}
_ => {
node = parent;
continue;
}
};
return name_node.map(|n| source[n.start_byte()..n.end_byte()].to_string());
}
None
}
fn extract_calls(
source: &str,
compiled: &CompiledQueries,
root: Node<'_>,
max_depth: Option<u32>,
calls: &mut Vec<CallInfo>,
call_frequency: &mut HashMap<String, usize>,
) {
let mut cursor = QueryCursor::new();
if let Some(depth) = max_depth {
cursor.set_max_start_depth(Some(depth));
}
let mut matches = cursor.matches(&compiled.call, root, source.as_bytes());
while let Some(mat) = matches.next() {
for capture in mat.captures {
let capture_name = compiled.call.capture_names()[capture.index as usize];
if capture_name != "call" {
continue;
}
let node = capture.node;
let call_name = source[node.start_byte()..node.end_byte()].to_string();
*call_frequency.entry(call_name.clone()).or_insert(0) += 1;
let caller = Self::enclosing_function_name(node, source)
.unwrap_or_else(|| "<module>".to_string());
let mut arg_count = None;
let mut arg_node = node;
let mut hop = 0u32;
let mut cap_hit = false;
while let Some(parent) = arg_node.parent() {
hop += 1;
if hop > 16 {
cap_hit = true;
break;
}
if parent.kind() == "call_expression" {
if let Some(args) = parent.child_by_field_name("arguments") {
arg_count = Some(args.named_child_count());
}
break;
}
arg_node = parent;
}
debug_assert!(
!cap_hit,
"extract_calls: parent traversal cap reached (hop > 16)"
);
calls.push(CallInfo {
caller,
callee: call_name,
line: node.start_position().row + 1,
column: node.start_position().column,
arg_count,
});
}
}
}
fn extract_imports(
source: &str,
compiled: &CompiledQueries,
root: Node<'_>,
max_depth: Option<u32>,
imports: &mut Vec<ImportInfo>,
) {
let Some(ref import_query) = compiled.import else {
return;
};
let mut cursor = QueryCursor::new();
if let Some(depth) = max_depth {
cursor.set_max_start_depth(Some(depth));
}
let mut matches = cursor.matches(import_query, root, source.as_bytes());
while let Some(mat) = matches.next() {
for capture in mat.captures {
let capture_name = import_query.capture_names()[capture.index as usize];
if capture_name == "import_path" {
let node = capture.node;
let line = node.start_position().row + 1;
extract_imports_from_node(&node, source, "", line, imports);
}
}
}
}
fn extract_impl_methods(
source: &str,
compiled: &CompiledQueries,
root: Node<'_>,
max_depth: Option<u32>,
classes: &mut [ClassInfo],
) {
let Some(ref impl_query) = compiled.impl_block else {
return;
};
let mut cursor = QueryCursor::new();
if let Some(depth) = max_depth {
cursor.set_max_start_depth(Some(depth));
}
let mut matches = cursor.matches(impl_query, root, source.as_bytes());
while let Some(mat) = matches.next() {
let mut impl_type_name = String::new();
let mut method_name = String::new();
let mut method_line = 0usize;
let mut method_end_line = 0usize;
let mut method_params = String::new();
let mut method_return_type: Option<String> = None;
for capture in mat.captures {
let capture_name = impl_query.capture_names()[capture.index as usize];
let node = capture.node;
match capture_name {
"impl_type" => {
impl_type_name = source[node.start_byte()..node.end_byte()].to_string();
}
"method_name" => {
method_name = source[node.start_byte()..node.end_byte()].to_string();
}
"method_params" => {
method_params = source[node.start_byte()..node.end_byte()].to_string();
}
"method" => {
method_line = node.start_position().row + 1;
method_end_line = node.end_position().row + 1;
method_return_type = node
.child_by_field_name("return_type")
.map(|r| source[r.start_byte()..r.end_byte()].to_string());
}
_ => {}
}
}
if !impl_type_name.is_empty() && !method_name.is_empty() {
let func = FunctionInfo {
name: method_name,
line: method_line,
end_line: method_end_line,
parameters: if method_params.is_empty() {
Vec::new()
} else {
vec![method_params]
},
return_type: method_return_type,
};
if let Some(class) = classes.iter_mut().find(|c| c.name == impl_type_name) {
class.methods.push(func);
}
}
}
}
fn extract_references(
source: &str,
compiled: &CompiledQueries,
root: Node<'_>,
max_depth: Option<u32>,
references: &mut Vec<ReferenceInfo>,
) {
let Some(ref ref_query) = compiled.reference else {
return;
};
let mut cursor = QueryCursor::new();
if let Some(depth) = max_depth {
cursor.set_max_start_depth(Some(depth));
}
let mut seen_refs = std::collections::HashSet::new();
let mut matches = cursor.matches(ref_query, root, source.as_bytes());
while let Some(mat) = matches.next() {
for capture in mat.captures {
let capture_name = ref_query.capture_names()[capture.index as usize];
if capture_name == "type_ref" {
let node = capture.node;
let type_ref = source[node.start_byte()..node.end_byte()].to_string();
if seen_refs.insert(type_ref.clone()) {
references.push(ReferenceInfo {
symbol: type_ref,
reference_type: ReferenceType::Usage,
location: String::new(),
line: node.start_position().row + 1,
});
}
}
}
}
}
fn extract_impl_traits_from_tree(
source: &str,
compiled: &CompiledQueries,
root: Node<'_>,
) -> Vec<ImplTraitInfo> {
let Some(query) = &compiled.impl_trait else {
return vec![];
};
let mut cursor = QueryCursor::new();
let mut matches = cursor.matches(query, root, source.as_bytes());
let mut results = Vec::new();
while let Some(mat) = matches.next() {
let mut trait_name = String::new();
let mut impl_type = String::new();
let mut line = 0usize;
for capture in mat.captures {
let capture_name = query.capture_names()[capture.index as usize];
let node = capture.node;
let text = source[node.start_byte()..node.end_byte()].to_string();
match capture_name {
"trait_name" => {
trait_name = text;
line = node.start_position().row + 1;
}
"impl_type" => {
impl_type = text;
}
_ => {}
}
}
if !trait_name.is_empty() && !impl_type.is_empty() {
results.push(ImplTraitInfo {
trait_name,
impl_type,
path: PathBuf::new(), line,
});
}
}
results
}
fn extract_def_use(
source: &str,
compiled: &CompiledQueries,
root: Node<'_>,
symbol_name: &str,
file_path: &str,
max_depth: Option<u32>,
) -> Vec<crate::types::DefUseSite> {
let Some(ref defuse_query) = compiled.defuse else {
return vec![];
};
let mut cursor = QueryCursor::new();
if let Some(depth) = max_depth {
cursor.set_max_start_depth(Some(depth));
}
let mut matches = cursor.matches(defuse_query, root, source.as_bytes());
let mut sites = Vec::new();
let source_lines: Vec<&str> = source.lines().collect();
let mut write_offsets = std::collections::HashSet::new();
while let Some(mat) = matches.next() {
for capture in mat.captures {
let capture_name = defuse_query.capture_names()[capture.index as usize];
let node = capture.node;
let node_text = node.utf8_text(source.as_bytes()).unwrap_or_default();
if node_text != symbol_name {
continue;
}
let kind = if capture_name.starts_with("write.") {
crate::types::DefUseKind::Write
} else if capture_name.starts_with("read.") {
crate::types::DefUseKind::Read
} else if capture_name.starts_with("writeread.") {
crate::types::DefUseKind::WriteRead
} else {
continue;
};
let byte_offset = node.start_byte();
if kind == crate::types::DefUseKind::Read && write_offsets.contains(&byte_offset) {
continue;
}
if kind != crate::types::DefUseKind::Read {
write_offsets.insert(byte_offset);
}
let line = node.start_position().row + 1;
let snippet = {
let row = node.start_position().row;
let last_line = source_lines.len().saturating_sub(1);
let prev = if row > 0 { row - 1 } else { 0 };
let next = std::cmp::min(row + 1, last_line);
let prev_text = if row == 0 {
""
} else {
source_lines[prev].trim_end()
};
let cur_text = source_lines[row].trim_end();
let next_text = if row >= last_line {
""
} else {
source_lines[next].trim_end()
};
format!("{prev_text}\n{cur_text}\n{next_text}")
};
let enclosing_scope = Self::enclosing_function_name(node, source);
let column = node.start_position().column;
sites.push(crate::types::DefUseSite {
kind,
symbol: node_text.to_string(),
file: file_path.to_string(),
line,
column,
snippet,
enclosing_scope,
});
}
}
sites
}
pub(crate) fn extract_def_use_for_file(
source: &str,
language: &str,
symbol: &str,
file_path: &str,
ast_recursion_limit: Option<usize>,
) -> Vec<crate::types::DefUseSite> {
let Some(lang_info) = crate::languages::get_language_info(language) else {
return vec![];
};
let Ok(compiled) = get_compiled_queries(language) else {
return vec![];
};
if compiled.defuse.is_none() {
return vec![];
}
let tree = match PARSER.with(|p| {
let mut parser = p.borrow_mut();
if parser.set_language(&lang_info.language).is_err() {
return None;
}
parser.parse(source, None)
}) {
Some(t) => t,
None => return vec![],
};
let root = tree.root_node();
let max_depth: Option<u32> = ast_recursion_limit
.filter(|&limit| limit > 0)
.and_then(|limit| u32::try_from(limit).ok());
Self::extract_def_use(source, compiled, root, symbol, file_path, max_depth)
}
}
#[must_use]
pub fn extract_impl_traits(source: &str, path: &Path) -> Vec<ImplTraitInfo> {
let Some(lang_info) = get_language_info("rust") else {
return vec![];
};
let Ok(compiled) = get_compiled_queries("rust") else {
return vec![];
};
let Some(query) = &compiled.impl_trait else {
return vec![];
};
let Some(tree) = PARSER.with(|p| {
let mut parser = p.borrow_mut();
let _ = parser.set_language(&lang_info.language);
parser.parse(source, None)
}) else {
return vec![];
};
let root = tree.root_node();
let mut cursor = QueryCursor::new();
let mut matches = cursor.matches(query, root, source.as_bytes());
let mut results = Vec::new();
while let Some(mat) = matches.next() {
let mut trait_name = String::new();
let mut impl_type = String::new();
let mut line = 0usize;
for capture in mat.captures {
let capture_name = query.capture_names()[capture.index as usize];
let node = capture.node;
let text = source[node.start_byte()..node.end_byte()].to_string();
match capture_name {
"trait_name" => {
trait_name = text;
line = node.start_position().row + 1;
}
"impl_type" => {
impl_type = text;
}
_ => {}
}
}
if !trait_name.is_empty() && !impl_type.is_empty() {
results.push(ImplTraitInfo {
trait_name,
impl_type,
path: path.to_path_buf(),
line,
});
}
}
results
}
pub fn execute_query_impl(
language: &str,
source: &str,
query_str: &str,
) -> Result<Vec<crate::QueryCapture>, ParserError> {
let ts_language = crate::languages::get_ts_language(language)
.ok_or_else(|| ParserError::UnsupportedLanguage(language.to_string()))?;
let mut parser = Parser::new();
parser
.set_language(&ts_language)
.map_err(|e| ParserError::QueryError(e.to_string()))?;
let tree = parser
.parse(source.as_bytes(), None)
.ok_or_else(|| ParserError::QueryError("failed to parse source".to_string()))?;
let query =
Query::new(&ts_language, query_str).map_err(|e| ParserError::QueryError(e.to_string()))?;
let mut cursor = QueryCursor::new();
let source_bytes = source.as_bytes();
let mut captures = Vec::new();
let mut matches = cursor.matches(&query, tree.root_node(), source_bytes);
while let Some(m) = matches.next() {
for cap in m.captures {
let node = cap.node;
let capture_name = query.capture_names()[cap.index as usize].to_string();
let text = node.utf8_text(source_bytes).unwrap_or("").to_string();
captures.push(crate::QueryCapture {
capture_name,
text,
start_line: node.start_position().row,
end_line: node.end_position().row,
start_byte: node.start_byte(),
end_byte: node.end_byte(),
});
}
}
Ok(captures)
}
#[cfg(all(test, feature = "lang-rust"))]
mod tests {
use super::*;
use std::path::Path;
#[test]
fn test_ast_recursion_limit_zero_is_unlimited() {
let source = r#"fn hello() -> u32 { 42 }"#;
let result_none = SemanticExtractor::extract(source, "rust", None);
let result_zero = SemanticExtractor::extract(source, "rust", Some(0));
assert!(result_none.is_ok(), "extract with None failed");
assert!(result_zero.is_ok(), "extract with Some(0) failed");
let analysis_none = result_none.unwrap();
let analysis_zero = result_zero.unwrap();
assert!(
analysis_none.functions.len() >= 1,
"extract with None should find at least one function in the test source"
);
assert_eq!(
analysis_none.functions.len(),
analysis_zero.functions.len(),
"ast_recursion_limit=0 should behave identically to unset (unlimited)"
);
}
#[test]
fn test_rust_use_as_imports() {
let source = "use std::io as stdio;";
let result = SemanticExtractor::extract(source, "rust", None).unwrap();
assert!(
result
.imports
.iter()
.any(|imp| imp.items.iter().any(|i| i == "stdio")),
"expected import alias 'stdio' in {:?}",
result.imports
);
}
#[test]
fn test_rust_use_as_clause_plain_identifier() {
let source = "use io as stdio;";
let result = SemanticExtractor::extract(source, "rust", None).unwrap();
assert!(
result
.imports
.iter()
.any(|imp| imp.items.iter().any(|i| i == "stdio")),
"expected import alias 'stdio' from plain identifier in {:?}",
result.imports
);
}
#[test]
fn test_rust_scoped_use_with_prefix() {
let source = "use std::{io::Read, io::Write};";
let result = SemanticExtractor::extract(source, "rust", None).unwrap();
let items: Vec<String> = result
.imports
.iter()
.filter(|imp| imp.module.starts_with("std::io"))
.flat_map(|imp| imp.items.clone())
.collect();
assert!(
items.contains(&"Read".to_string()) && items.contains(&"Write".to_string()),
"expected 'Read' and 'Write' items under module with std::io, got {:?}",
result.imports
);
}
#[test]
fn test_rust_scoped_use_imports() {
let source = "use std::{fs, io};";
let result = SemanticExtractor::extract(source, "rust", None).unwrap();
let items: Vec<&str> = result
.imports
.iter()
.filter(|imp| imp.module == "std")
.flat_map(|imp| imp.items.iter().map(|s| s.as_str()))
.collect();
assert!(
items.contains(&"fs") && items.contains(&"io"),
"expected 'fs' and 'io' items under module 'std', got {:?}",
items
);
}
#[test]
fn test_rust_wildcard_imports() {
let source = "use std::io::*;";
let result = SemanticExtractor::extract(source, "rust", None).unwrap();
let wildcard = result
.imports
.iter()
.find(|imp| imp.module == "std::io" && imp.items == vec!["*"]);
assert!(
wildcard.is_some(),
"expected wildcard import with module 'std::io', got {:?}",
result.imports
);
}
#[test]
fn test_extract_impl_traits_standalone() {
let source = r#"
struct Foo;
trait Display {}
impl Display for Foo {}
"#;
let results = extract_impl_traits(source, Path::new("test.rs"));
assert_eq!(
results.len(),
1,
"expected one impl trait, got {:?}",
results
);
assert_eq!(results[0].trait_name, "Display");
assert_eq!(results[0].impl_type, "Foo");
}
#[cfg(target_pointer_width = "64")]
#[test]
fn test_ast_recursion_limit_overflow() {
let source = "fn foo() {}";
let big_limit = usize::try_from(u32::MAX).unwrap() + 1;
let result = SemanticExtractor::extract(source, "rust", Some(big_limit));
assert!(
matches!(result, Err(ParserError::ParseError(_))),
"expected ParseError for oversized limit, got {:?}",
result
);
}
#[test]
fn test_ast_recursion_limit_some() {
let source = r#"fn hello() -> u32 { 42 }"#;
let result = SemanticExtractor::extract(source, "rust", Some(5));
assert!(result.is_ok(), "extract with Some(5) failed: {:?}", result);
let analysis = result.unwrap();
assert!(
analysis.functions.len() >= 1,
"expected at least one function with depth limit 5"
);
}
#[test]
fn test_extract_def_use_for_file_finds_write_and_read() {
let source = r#"
fn main() {
let count = 0;
println!("{}", count);
}
"#;
let sites = SemanticExtractor::extract_def_use_for_file(
source,
"rust",
"count",
"src/main.rs",
None,
);
assert!(
!sites.is_empty(),
"expected at least one def-use site for 'count'"
);
let has_write = sites
.iter()
.any(|s| s.kind == crate::types::DefUseKind::Write);
let has_read = sites
.iter()
.any(|s| s.kind == crate::types::DefUseKind::Read);
assert!(has_write, "expected a write site for 'count'");
assert!(has_read, "expected a read site for 'count'");
assert_eq!(sites[0].file, "src/main.rs");
}
#[test]
fn test_extract_def_use_for_file_no_match_returns_empty() {
let source = "fn foo() { let x = 1; }";
let sites = SemanticExtractor::extract_def_use_for_file(
source,
"rust",
"nonexistent_symbol",
"src/lib.rs",
None,
);
assert!(sites.is_empty(), "expected empty for nonexistent symbol");
}
}
#[cfg(all(test, feature = "lang-python"))]
mod tests_python {
use super::*;
#[test]
fn test_python_relative_import() {
let source = "from . import foo\n";
let result = SemanticExtractor::extract(source, "python", None).unwrap();
let relative = result.imports.iter().find(|imp| imp.module.contains("."));
assert!(
relative.is_some(),
"expected relative import in {:?}",
result.imports
);
}
#[test]
fn test_python_aliased_import() {
let source = "from os import path as p\n";
let result = SemanticExtractor::extract(source, "python", None).unwrap();
let path_import = result
.imports
.iter()
.find(|imp| imp.module == "os" && imp.items.iter().any(|i| i == "path"));
assert!(
path_import.is_some(),
"expected import 'path' from module 'os' in {:?}",
result.imports
);
}
}
#[cfg(test)]
mod tests_unsupported {
use super::*;
#[test]
fn test_element_extractor_unsupported_language() {
let result = ElementExtractor::extract_with_depth("x = 1", "cobol");
assert!(
matches!(result, Err(ParserError::UnsupportedLanguage(ref lang)) if lang == "cobol"),
"expected UnsupportedLanguage error, got {:?}",
result
);
}
#[test]
fn test_semantic_extractor_unsupported_language() {
let result = SemanticExtractor::extract("x = 1", "cobol", None);
assert!(
matches!(result, Err(ParserError::UnsupportedLanguage(ref lang)) if lang == "cobol"),
"expected UnsupportedLanguage error, got {:?}",
result
);
}
}