use std::cell::RefCell;
use std::collections::{BTreeSet, HashMap};
use std::ops::Bound;
use std::path::Path;
use std::sync::Arc;
use once_cell::sync::Lazy;
use parking_lot::RwLock;
use tree_sitter::{Language as TSLanguage, Node, Parser, Query, QueryCursor, QueryMatch, Tree};
use crate::ast::types::{ClassInfo, FunctionInfo, ImportInfo, ModuleInfo};
use crate::error::{Result, BrrrError};
use crate::lang::{Language, LanguageRegistry};
use crate::util::format_query_error;
/// Key of a cached query: `(language name, query kind)`, e.g. `("python", "function")`.
type QueryCacheKey = (&'static str, &'static str);

/// Process-wide cache of compiled tree-sitter queries, shared by every thread.
///
/// Compiling a query is done once per key; lookups take the read lock only.
static QUERY_CACHE: Lazy<RwLock<HashMap<QueryCacheKey, Arc<Query>>>> =
    Lazy::new(|| RwLock::new(HashMap::default()));
/// Returns the compiled query for `(lang_name, query_kind)`, compiling and caching it on
/// first use.
///
/// The cache is keyed only by `(lang_name, query_kind)`: `query_str` is compiled only when no
/// entry exists yet, so callers must always pass the same query text for a given key.
///
/// Compilation happens outside any lock so a slow compile never blocks readers. If another
/// thread inserts the same key while this one is compiling, the first inserted query wins and
/// is returned to both callers, so every caller shares a single `Query` instance.
///
/// # Errors
///
/// Returns [`BrrrError::TreeSitter`] when `query_str` does not compile for `ts_lang`.
fn get_cached_query(
    ts_lang: &TSLanguage,
    lang_name: &'static str,
    query_kind: &'static str,
    query_str: &str,
) -> Result<Arc<Query>> {
    let key = (lang_name, query_kind);

    // Fast path: the read guard is a temporary and is released at the end of this statement.
    let cached = QUERY_CACHE.read().get(&key).cloned();
    if let Some(query) = cached {
        return Ok(query);
    }

    let query = Query::new(ts_lang, query_str).map_err(|e| {
        BrrrError::TreeSitter(format_query_error(lang_name, query_kind, query_str, &e))
    })?;

    let mut cache = QUERY_CACHE.write();
    // Keep whichever query was inserted first and hand that one out, so a racing thread's
    // freshly compiled copy is discarded instead of leaking out as a non-cached instance.
    let shared = cache.entry(key).or_insert_with(|| Arc::new(query));
    Ok(Arc::clone(shared))
}
/// Removes every compiled query from the process-wide cache; later lookups recompile.
// `dead_code` is allowed in case nothing outside tests calls this helper.
#[allow(dead_code)]
pub fn clear_query_cache() {
    QUERY_CACHE.write().clear();
}
/// Returns how many compiled queries are currently held in the process-wide cache.
// `dead_code` is allowed in case nothing outside tests calls this helper.
#[allow(dead_code)]
pub fn query_cache_stats() -> usize {
    let cache = QUERY_CACHE.read();
    cache.len()
}
thread_local! {
    /// Per-thread pool of reusable parsers keyed by `Language::name()`.
    ///
    /// Each thread owns its own pool (a `RefCell`, not a lock), so parsers are never shared
    /// across threads.
    static PARSER_CACHE: RefCell<HashMap<&'static str, Parser>> =
        RefCell::new(HashMap::default());
}

/// Maximum number of parsers a single thread keeps in `PARSER_CACHE`; parsers returned while
/// the pool is full are dropped.
const MAX_CACHED_PARSERS: usize = 16;
/// Handle that borrows a parser from the thread-local `PARSER_CACHE` and puts it back when
/// dropped.
struct CachedParser {
    /// `Some` for the whole lifetime of the handle; `Drop` moves it back into the cache.
    parser: Option<Parser>,
    /// Key under which the parser is returned to the cache.
    lang_name: &'static str,
}

impl CachedParser {
    /// Takes this thread's cached parser for `lang` (resetting any leftover parse state), or
    /// creates a new one with `lang.parser()` when none is cached.
    ///
    /// # Errors
    ///
    /// Propagates the error from `lang.parser()` when a new parser has to be created.
    fn take(lang: &dyn Language) -> Result<Self> {
        let lang_name = lang.name();
        let cached = PARSER_CACHE.with(|cache| cache.borrow_mut().remove(lang_name));
        let parser = match cached {
            Some(mut parser) => {
                parser.reset();
                parser
            }
            None => lang.parser()?,
        };
        Ok(Self {
            parser: Some(parser),
            lang_name,
        })
    }

    /// Mutable access to the borrowed parser.
    ///
    /// # Panics
    ///
    /// Only if the parser was already moved out, which happens solely inside `Drop`.
    fn get_mut(&mut self) -> &mut Parser {
        self.parser.as_mut().expect("Parser already consumed")
    }
}

impl Drop for CachedParser {
    fn drop(&mut self) {
        let Some(parser) = self.parser.take() else {
            return;
        };
        // `try_with` instead of `with`: if this handle is dropped while the thread's locals
        // are being torn down, `with` would panic inside `drop` (which can abort). In that
        // case the closure is not run and the parser is simply dropped.
        let _ = PARSER_CACHE.try_with(|cache| {
            let mut cache = cache.borrow_mut();
            if cache.len() < MAX_CACHED_PARSERS {
                cache.insert(self.lang_name, parser);
            }
        });
    }
}
/// Drops every parser cached on the *calling* thread; other threads' pools are untouched.
// `dead_code` is allowed in case nothing outside tests calls this helper.
#[allow(dead_code)]
pub fn clear_parser_cache() {
    PARSER_CACHE.with(|cache| {
        let mut parsers = cache.borrow_mut();
        parsers.clear();
    });
}
/// Returns how many parsers are cached on the calling thread.
// `dead_code` is allowed in case nothing outside tests calls this helper.
#[allow(dead_code)]
pub fn parser_cache_stats() -> usize {
    PARSER_CACHE.with(|cache| {
        let parsers = cache.borrow();
        parsers.len()
    })
}
/// Ordered set of byte offsets answering "is there a recorded offset within `tolerance` of
/// this one?" in `O(log n)`.
struct PositionSet {
    /// Recorded byte offsets, kept sorted for range queries.
    positions: BTreeSet<usize>,
    /// Maximum distance (inclusive, in either direction) at which two offsets count as equal.
    tolerance: usize,
}

impl PositionSet {
    /// Creates an empty set whose duplicate window is `tolerance` bytes on each side.
    fn with_tolerance(tolerance: usize) -> Self {
        Self {
            tolerance,
            positions: BTreeSet::new(),
        }
    }

    /// Returns `true` when some recorded offset lies in `[pos - tolerance, pos + tolerance]`
    /// (both ends saturated at the `usize` limits).
    fn is_duplicate(&self, pos: usize) -> bool {
        let window_start = pos.saturating_sub(self.tolerance);
        let window_end = pos.saturating_add(self.tolerance);
        // The smallest recorded offset >= window_start is the only candidate that matters:
        // the window is non-empty iff that candidate does not overshoot window_end.
        self.positions
            .range((Bound::Included(window_start), Bound::Unbounded))
            .next()
            .map_or(false, |&nearest| nearest <= window_end)
    }

    /// Records `pos`; inserting an already-present offset is a no-op.
    fn insert(&mut self, pos: usize) {
        self.positions.insert(pos);
    }
}
/// Node kinds treated as function-like definitions when a query match has no `@function`
/// capture (see `get_function_node_from_match`).
///
/// Each kind appears once. Several grammars share kind names, so a kind listed under one
/// language may also serve another (noted inline).
const FUNCTION_NODE_KINDS: &[&str] = &[
    // Python (`function_definition` is also the C/C++ kind).
    "function_definition",
    "decorated_definition",
    // JavaScript / TypeScript (`function_declaration` is also the Go kind).
    "function_declaration",
    "method_definition",
    "arrow_function",
    "function_expression",
    "generator_function_declaration",
    "function_signature",
    "ambient_declaration",
    // Go and Java methods.
    "method_declaration",
    // Rust.
    "function_item",
    "function_signature_item",
    "macro_definition",
    // C / C++.
    "declaration",
    "template_declaration",
    "preproc_def",
    "preproc_function_def",
    "type_definition",
    // Java.
    "constructor_declaration",
];
/// Node kinds treated as class-like definitions (classes, structs, interfaces, enums, modules,
/// type aliases, ...) when a query match has no `@class` capture (see
/// `get_class_node_from_match`).
///
/// Each kind appears once. Several grammars share kind names, so a kind listed under one
/// language may also serve another (noted inline).
const CLASS_NODE_KINDS: &[&str] = &[
    // Python.
    "class_definition",
    "decorated_definition",
    // JavaScript / TypeScript (the three `*_declaration` kinds below are also Java's).
    "class_declaration",
    "interface_declaration",
    "enum_declaration",
    "abstract_class_declaration",
    "class",
    "type_alias_declaration",
    "module",
    // Go.
    "type_declaration",
    // Rust.
    "struct_item",
    "union_item",
    "impl_item",
    "trait_item",
    "enum_item",
    "const_item",
    "static_item",
    "type_item",
    "mod_item",
    "foreign_mod_item",
    "extern_crate_declaration",
    // C / C++.
    "struct_specifier",
    "enum_specifier",
    "union_specifier",
    "class_specifier",
    "namespace_definition",
    "type_definition",
    "preproc_ifdef",
    "preproc_if",
    // Java-only kinds.
    "record_declaration",
    "annotation_type_declaration",
];
/// Picks the node a query match describes.
///
/// Resolution order:
/// 1. the capture named `capture_name`, if the query defines that capture and this match
///    populated it;
/// 2. otherwise the first capture whose node kind is listed in `fallback_kinds`;
/// 3. otherwise the match's first capture, if it has any.
fn select_node_from_match<'tree>(
    match_: &QueryMatch<'_, 'tree>,
    query: &Query,
    capture_name: &str,
    fallback_kinds: &[&str],
) -> Option<Node<'tree>> {
    let named = query
        .capture_index_for_name(capture_name)
        .and_then(|idx| match_.captures.iter().find(|c| c.index == idx));
    named
        .or_else(|| {
            match_
                .captures
                .iter()
                .find(|c| fallback_kinds.contains(&c.node.kind()))
        })
        .or_else(|| match_.captures.first())
        .map(|c| c.node)
}

/// Returns the function-like node of a match: the `@function` capture, else the first
/// capture whose kind is in `FUNCTION_NODE_KINDS`, else the first capture.
fn get_function_node_from_match<'tree>(
    match_: &QueryMatch<'_, 'tree>,
    query: &Query,
) -> Option<Node<'tree>> {
    select_node_from_match(match_, query, "function", FUNCTION_NODE_KINDS)
}

/// Returns the class-like node of a match: the `@class` capture, else the first capture
/// whose kind is in `CLASS_NODE_KINDS`, else the first capture.
fn get_class_node_from_match<'tree>(
    match_: &QueryMatch<'_, 'tree>,
    query: &Query,
) -> Option<Node<'tree>> {
    select_node_from_match(match_, query, "class", CLASS_NODE_KINDS)
}
pub struct AstExtractor;
impl AstExtractor {
pub fn extract_file(path: &Path) -> Result<ModuleInfo> {
let registry = LanguageRegistry::global();
let lang = registry.detect_language(path).ok_or_else(|| {
BrrrError::UnsupportedLanguage(
path.extension()
.and_then(|e| e.to_str())
.unwrap_or("unknown")
.to_string(),
)
})?;
let source = std::fs::read(path)
.map_err(|e| BrrrError::io_with_path(e, path))?;
if lang.should_skip_file(path, &source) {
return Err(BrrrError::UnsupportedLanguage(format!(
"File content incompatible with {} parser: {}",
lang.name(),
path.display()
)));
}
let mut cached_parser = CachedParser::take(lang)?;
let tree = cached_parser
.get_mut()
.parse(&source, None)
.ok_or_else(|| BrrrError::Parse {
file: path.display().to_string(),
message: "Failed to parse file".to_string(),
})?;
Self::extract_module(&tree, &source, lang, path)
}
fn extract_module(
tree: &Tree,
source: &[u8],
lang: &dyn Language,
path: &Path,
) -> Result<ModuleInfo> {
let functions = Self::extract_functions(tree, source, lang)?;
let classes = Self::extract_classes(tree, source, lang)?;
let imports = lang.extract_imports(tree, source);
let docstring = lang.extract_module_docstring(tree, source);
Ok(ModuleInfo {
path: path.display().to_string(),
language: lang.name().to_string(),
docstring,
functions,
classes,
imports,
call_graph: None, })
}
fn extract_functions(
tree: &Tree,
source: &[u8],
lang: &dyn Language,
) -> Result<Vec<FunctionInfo>> {
let query_str = lang.function_query();
let ts_lang = tree.language();
let query = get_cached_query(&ts_lang, lang.name(), "function", query_str)?;
let mut cursor = QueryCursor::new();
let mut matches = cursor.matches(&query, tree.root_node(), source);
let mut functions = Vec::new();
let mut seen_positions = PositionSet::with_tolerance(2);
use streaming_iterator::StreamingIterator;
while let Some(match_) = matches.next() {
let node = get_function_node_from_match(match_, &query);
if let Some(node) = node {
let start = node.start_byte();
if seen_positions.is_duplicate(start) {
continue;
}
seen_positions.insert(start);
if let Some(func_info) = lang.extract_function(node, source) {
functions.push(func_info);
}
}
}
functions.sort_by_key(|f| f.line_number);
Ok(functions)
}
fn extract_classes(tree: &Tree, source: &[u8], lang: &dyn Language) -> Result<Vec<ClassInfo>> {
let query_str = lang.class_query();
let ts_lang = tree.language();
let query = get_cached_query(&ts_lang, lang.name(), "class", query_str)?;
let mut cursor = QueryCursor::new();
let mut matches = cursor.matches(&query, tree.root_node(), source);
let mut classes = Vec::new();
let mut seen_positions = PositionSet::with_tolerance(2);
use streaming_iterator::StreamingIterator;
while let Some(match_) = matches.next() {
let node = get_class_node_from_match(match_, &query);
if let Some(node) = node {
let start = node.start_byte();
if seen_positions.is_duplicate(start) {
continue;
}
seen_positions.insert(start);
if let Some(class_info) = lang.extract_class(node, source) {
classes.push(class_info);
}
}
}
classes.sort_by_key(|c| c.line_number);
Ok(classes)
}
#[allow(dead_code)]
pub fn extract_from_source(source: &str, language: &str) -> Result<ModuleInfo> {
let registry = LanguageRegistry::global();
let lang = registry
.get_by_name(language)
.ok_or_else(|| BrrrError::UnsupportedLanguage(language.to_string()))?;
let source_bytes = source.as_bytes();
let mut cached_parser = CachedParser::take(lang)?;
let tree = cached_parser
.get_mut()
.parse(source_bytes, None)
.ok_or_else(|| BrrrError::Parse {
file: "<string>".to_string(),
message: "Failed to parse source".to_string(),
})?;
let functions = Self::extract_functions(&tree, source_bytes, lang)?;
let classes = Self::extract_classes(&tree, source_bytes, lang)?;
let imports = lang.extract_imports(&tree, source_bytes);
let docstring = lang.extract_module_docstring(&tree, source_bytes);
Ok(ModuleInfo {
path: "<string>".to_string(),
language: lang.name().to_string(),
docstring,
functions,
classes,
imports,
call_graph: None, })
}
#[allow(dead_code)]
pub fn find_function(path: &Path, function_name: &str) -> Result<FunctionInfo> {
let module_info = Self::extract_file(path)?;
if let Some(func) = module_info
.functions
.iter()
.find(|f| f.name == function_name)
{
return Ok(func.clone());
}
for class in &module_info.classes {
if let Some(method) = class.methods.iter().find(|m| m.name == function_name) {
return Ok(method.clone());
}
}
Err(BrrrError::FunctionNotFound(function_name.to_string()))
}
#[allow(dead_code)]
pub fn find_class(path: &Path, class_name: &str) -> Result<ClassInfo> {
let module_info = Self::extract_file(path)?;
module_info
.classes
.into_iter()
.find(|c| c.name == class_name)
.ok_or_else(|| BrrrError::ClassNotFound(class_name.to_string()))
}
}
/// Parses the file at `path` and returns only its imports.
///
/// Unlike `AstExtractor::extract_file`, this does not call `should_skip_file`: any file whose
/// path maps to a registered language is parsed.
///
/// # Errors
///
/// - [`BrrrError::UnsupportedLanguage`] (carrying the extension, or `"unknown"`) if no
///   language matches `path`.
/// - The error built by `BrrrError::io_with_path` if the file cannot be read.
/// - [`BrrrError::Parse`] if the parser returns no tree.
pub fn extract_imports(path: &Path) -> Result<Vec<ImportInfo>> {
    let registry = LanguageRegistry::global();
    let Some(lang) = registry.detect_language(path) else {
        let extension = path.extension().and_then(|e| e.to_str()).unwrap_or("unknown");
        return Err(BrrrError::UnsupportedLanguage(extension.to_string()));
    };
    let source = match std::fs::read(path) {
        Ok(bytes) => bytes,
        Err(err) => return Err(BrrrError::io_with_path(err, path)),
    };
    let mut parser_handle = CachedParser::take(lang)?;
    let Some(tree) = parser_handle.get_mut().parse(&source, None) else {
        return Err(BrrrError::Parse {
            file: path.display().to_string(),
            message: "Failed to parse file".to_string(),
        });
    };
    Ok(lang.extract_imports(&tree, &source))
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    fn create_temp_file(content: &str, extension: &str) -> NamedTempFile {
        let mut file = tempfile::Builder::new()
            .suffix(extension)
            .tempfile()
            .unwrap();
        file.write_all(content.as_bytes()).unwrap();
        file
    }

    /// Returns the process-wide cached query for `(lang, kind)`, if it has been compiled.
    ///
    /// `QUERY_CACHE` is shared with tests running concurrently on other threads, so cache
    /// tests compare the identity of specific entries rather than the total cache size.
    fn cached_query(lang: &'static str, kind: &'static str) -> Option<Arc<Query>> {
        QUERY_CACHE.read().get(&(lang, kind)).cloned()
    }

    #[test]
    fn test_extract_python_functions() {
        let source = r#"
def hello(name: str) -> str:
    """Say hello to someone."""
    return f"Hello, {name}!"

async def fetch_data(url: str) -> bytes:
    """Fetch data from URL."""
    pass

class MyClass:
    def method(self, x: int) -> int:
        return x * 2
"#;
        let file = create_temp_file(source, ".py");
        let result = AstExtractor::extract_file(file.path()).unwrap();
        assert_eq!(result.language, "python");
        assert!(result.functions.len() >= 2);
        let hello = result.functions.iter().find(|f| f.name == "hello");
        assert!(hello.is_some());
        let hello = hello.unwrap();
        assert_eq!(hello.return_type, Some("str".to_string()));
        assert!(hello
            .docstring
            .as_ref()
            .map_or(false, |d| d.contains("Say hello")));
        assert!(!hello.is_async);
        let fetch = result.functions.iter().find(|f| f.name == "fetch_data");
        assert!(fetch.is_some());
        assert!(fetch.unwrap().is_async);
        assert_eq!(result.classes.len(), 1);
        assert_eq!(result.classes[0].name, "MyClass");
        assert!(!result.classes[0].methods.is_empty());
    }

    #[test]
    fn test_extract_python_classes() {
        let source = r#"
class Animal:
    """Base class for animals."""

    def __init__(self, name: str):
        self.name = name

    def speak(self) -> str:
        pass

class Dog(Animal):
    """A dog."""

    def speak(self) -> str:
        return "Woof!"
"#;
        let file = create_temp_file(source, ".py");
        let result = AstExtractor::extract_file(file.path()).unwrap();
        assert_eq!(result.classes.len(), 2);
        let animal = result.classes.iter().find(|c| c.name == "Animal").unwrap();
        assert!(animal
            .docstring
            .as_ref()
            .map_or(false, |d| d.contains("Base class")));
        assert!(animal.methods.len() >= 2);
        let dog = result.classes.iter().find(|c| c.name == "Dog").unwrap();
        assert!(dog.bases.contains(&"Animal".to_string()));
    }

    #[test]
    fn test_extract_python_imports() {
        let source = r#"
import os
import sys as system
from pathlib import Path
from collections import defaultdict as dd
from . import local
"#;
        let file = create_temp_file(source, ".py");
        let imports = extract_imports(file.path()).unwrap();
        assert!(imports.len() >= 4);
        let os_import = imports.iter().find(|i| i.module == "os");
        assert!(os_import.is_some());
        assert!(!os_import.unwrap().is_from);
        let sys_import = imports.iter().find(|i| i.module == "sys");
        assert!(sys_import.is_some());
        assert!(sys_import.unwrap().aliases.contains_key("sys"));
        let pathlib_import = imports.iter().find(|i| i.module == "pathlib");
        assert!(pathlib_import.is_some());
        assert!(pathlib_import.unwrap().is_from);
        assert!(pathlib_import.unwrap().names.contains(&"Path".to_string()));
    }

    #[test]
    fn test_extract_typescript_functions() {
        let source = r#"
function greet(name: string): string {
    return "Hello, " + name;
}

async function fetchData(url: string): Promise<Response> {
    return fetch(url);
}

const add = (a: number, b: number): number => a + b;
"#;
        let file = create_temp_file(source, ".ts");
        let result = AstExtractor::extract_file(file.path()).unwrap();
        assert_eq!(result.language, "typescript");
        assert!(result.functions.len() >= 2);
        let greet = result.functions.iter().find(|f| f.name == "greet");
        assert!(greet.is_some());
        assert_eq!(greet.unwrap().return_type, Some("string".to_string()));
        let fetch_data = result.functions.iter().find(|f| f.name == "fetchData");
        assert!(fetch_data.is_some());
        assert!(fetch_data.unwrap().is_async);
    }

    #[test]
    fn test_extract_typescript_classes() {
        let source = r#"
class Animal {
    constructor(public name: string) {}
    speak(): void {
        console.log(this.name);
    }
}

class Dog extends Animal {
    bark(): void {
        console.log("Woof!");
    }
}
"#;
        let file = create_temp_file(source, ".ts");
        let result = AstExtractor::extract_file(file.path()).unwrap();
        assert_eq!(result.classes.len(), 2);
        let animal = result.classes.iter().find(|c| c.name == "Animal");
        assert!(animal.is_some());
        let dog = result.classes.iter().find(|c| c.name == "Dog");
        assert!(dog.is_some());
        assert!(dog.unwrap().bases.contains(&"Animal".to_string()));
    }

    #[test]
    fn test_extract_from_source() {
        let source = r#"
def add(a: int, b: int) -> int:
    return a + b
"#;
        let result = AstExtractor::extract_from_source(source, "python").unwrap();
        assert_eq!(result.language, "python");
        assert_eq!(result.functions.len(), 1);
        assert_eq!(result.functions[0].name, "add");
    }

    #[test]
    fn test_find_function() {
        let source = r#"
def target_function(x: int) -> int:
    return x * 2

def other_function():
    pass
"#;
        let file = create_temp_file(source, ".py");
        let func = AstExtractor::find_function(file.path(), "target_function");
        assert!(func.is_ok());
        assert_eq!(func.unwrap().name, "target_function");
        let not_found = AstExtractor::find_function(file.path(), "nonexistent");
        assert!(not_found.is_err());
    }

    #[test]
    fn test_find_class() {
        let source = r#"
class TargetClass:
    pass

class OtherClass:
    pass
"#;
        let file = create_temp_file(source, ".py");
        let class = AstExtractor::find_class(file.path(), "TargetClass");
        assert!(class.is_ok());
        assert_eq!(class.unwrap().name, "TargetClass");
        let not_found = AstExtractor::find_class(file.path(), "NonexistentClass");
        assert!(not_found.is_err());
        assert!(matches!(not_found, Err(BrrrError::ClassNotFound(_))));
    }

    #[test]
    fn test_unsupported_language() {
        let file = create_temp_file("some content", ".xyz");
        let result = AstExtractor::extract_file(file.path());
        assert!(matches!(result, Err(BrrrError::UnsupportedLanguage(_))));
    }

    #[test]
    fn test_decorated_python_function() {
        let source = r#"
@staticmethod
@cache
def cached_function(x: int) -> int:
    return x * 2
"#;
        let file = create_temp_file(source, ".py");
        let result = AstExtractor::extract_file(file.path()).unwrap();
        assert_eq!(result.functions.len(), 1);
        let func = &result.functions[0];
        assert_eq!(func.name, "cached_function");
        assert!(!func.decorators.is_empty());
    }

    #[test]
    fn test_decorated_python_class() {
        let source = r#"
@dataclass
class Point:
    x: float
    y: float
"#;
        let file = create_temp_file(source, ".py");
        let result = AstExtractor::extract_file(file.path()).unwrap();
        assert_eq!(result.classes.len(), 1);
        let class = &result.classes[0];
        assert_eq!(class.name, "Point");
        assert!(!class.decorators.is_empty());
    }

    #[test]
    fn test_multiple_decorated_functions_no_duplicates() {
        let source = r#"
@decorator1
def func1():
    pass

@decorator2
@decorator3
def func2():
    pass

@contextmanager
def func3():
    yield

def plain_func():
    pass
"#;
        let file = create_temp_file(source, ".py");
        let result = AstExtractor::extract_file(file.path()).unwrap();
        assert_eq!(
            result.functions.len(),
            4,
            "Expected 4 functions, got {}: {:?}",
            result.functions.len(),
            result.functions.iter().map(|f| &f.name).collect::<Vec<_>>()
        );
        let names: Vec<&str> = result.functions.iter().map(|f| f.name.as_str()).collect();
        assert!(names.contains(&"func1"));
        assert!(names.contains(&"func2"));
        assert!(names.contains(&"func3"));
        assert!(names.contains(&"plain_func"));
    }

    #[test]
    fn test_nested_decorated_classes_no_duplicates() {
        let source = r#"
@dataclass
class Point:
    x: float
    y: float

@singleton
@validate
class Config:
    value: str

class PlainClass:
    pass
"#;
        let file = create_temp_file(source, ".py");
        let result = AstExtractor::extract_file(file.path()).unwrap();
        assert_eq!(
            result.classes.len(),
            3,
            "Expected 3 classes, got {}: {:?}",
            result.classes.len(),
            result.classes.iter().map(|c| &c.name).collect::<Vec<_>>()
        );
        let names: Vec<&str> = result.classes.iter().map(|c| c.name.as_str()).collect();
        assert!(names.contains(&"Point"));
        assert!(names.contains(&"Config"));
        assert!(names.contains(&"PlainClass"));
    }

    #[test]
    fn test_overlap_detection_algorithm() {
        fn overlaps(start: usize, end: usize, s: usize, e: usize) -> bool {
            start < e && s < end
        }
        assert!(overlaps(10, 20, 15, 25), "Partial overlap should be detected");
        assert!(overlaps(15, 25, 10, 20), "Partial overlap should be detected (reversed)");
        assert!(overlaps(10, 30, 15, 20), "Containment should be detected");
        assert!(overlaps(15, 20, 10, 30), "Containment should be detected (reversed)");
        assert!(!overlaps(10, 20, 20, 30), "Adjacent intervals should not overlap");
        assert!(!overlaps(10, 20, 25, 30), "Disjoint intervals should not overlap");
        assert!(!overlaps(25, 30, 10, 20), "Disjoint intervals should not overlap (reversed)");
        assert!(overlaps(10, 20, 10, 20), "Same interval should overlap");
        assert!(overlaps(10, 20, 19, 25), "Should overlap when ranges share interior point");
    }

    #[test]
    fn test_position_set_deduplication() {
        let mut set = PositionSet::with_tolerance(2);
        assert!(!set.is_duplicate(100), "Empty set should have no duplicates");
        set.insert(100);
        assert!(set.is_duplicate(100), "Exact position should be duplicate");
        assert!(set.is_duplicate(99), "Position 99 should be duplicate (within tolerance of 100)");
        assert!(set.is_duplicate(101), "Position 101 should be duplicate (within tolerance of 100)");
        assert!(set.is_duplicate(98), "Position 98 should be duplicate (within tolerance of 100)");
        assert!(set.is_duplicate(102), "Position 102 should be duplicate (within tolerance of 100)");
        assert!(!set.is_duplicate(97), "Position 97 should NOT be duplicate (outside tolerance)");
        assert!(!set.is_duplicate(103), "Position 103 should NOT be duplicate (outside tolerance)");
        assert!(!set.is_duplicate(50), "Position 50 should NOT be duplicate");
        assert!(!set.is_duplicate(200), "Position 200 should NOT be duplicate");
        set.insert(500);
        assert!(set.is_duplicate(500), "Position 500 should now be duplicate");
        assert!(set.is_duplicate(498), "Position 498 should be duplicate (within tolerance of 500)");
        assert!(!set.is_duplicate(495), "Position 495 should NOT be duplicate");
        assert!(set.is_duplicate(100), "Position 100 should still be duplicate");
        let mut set2 = PositionSet::with_tolerance(2);
        set2.insert(0);
        assert!(set2.is_duplicate(0), "Position 0 should be duplicate");
        assert!(set2.is_duplicate(1), "Position 1 should be duplicate (within tolerance of 0)");
        assert!(set2.is_duplicate(2), "Position 2 should be duplicate (within tolerance of 0)");
        assert!(!set2.is_duplicate(3), "Position 3 should NOT be duplicate");
        set2.insert(1);
        assert!(set2.is_duplicate(0), "Position 0 should be duplicate");
        assert!(set2.is_duplicate(3), "Position 3 should be duplicate (within tolerance of 1)");
    }

    #[test]
    fn test_position_set_performance_characteristics() {
        let mut set = PositionSet::with_tolerance(2);
        for i in 0..1000 {
            let pos = i * 100;
            assert!(!set.is_duplicate(pos), "Position {} should not be duplicate before insert", pos);
            set.insert(pos);
            assert!(set.is_duplicate(pos), "Position {} should be duplicate after insert", pos);
        }
        for i in 0..1000 {
            let pos = i * 100;
            assert!(set.is_duplicate(pos), "Position {} should be duplicate", pos);
            assert!(set.is_duplicate(pos + 1), "Position {} should be duplicate (tolerance)", pos + 1);
            if i < 999 {
                assert!(
                    !set.is_duplicate(pos + 50),
                    "Position {} should NOT be duplicate (between functions)",
                    pos + 50
                );
            }
        }
    }

    #[test]
    fn test_extract_java_methods_with_fallback() {
        let source = r#"
public class Calculator {
    public int add(int a, int b) {
        return a + b;
    }

    public Calculator() {
        // constructor
    }

    private void helper() {
        // helper method
    }
}
"#;
        let file = create_temp_file(source, ".java");
        let result = AstExtractor::extract_file(file.path()).unwrap();
        assert_eq!(result.classes.len(), 1, "Should extract Calculator class");
        let calc = &result.classes[0];
        assert_eq!(calc.name, "Calculator");
        assert!(
            calc.methods.len() >= 2,
            "Should extract at least 2 methods from Calculator, got {}",
            calc.methods.len()
        );
        let method_names: Vec<&str> = calc.methods.iter().map(|m| m.name.as_str()).collect();
        assert!(
            method_names.contains(&"add"),
            "Should find 'add' method, found: {:?}",
            method_names
        );
    }

    #[test]
    fn test_extract_go_structs_with_fallback() {
        let source = r#"
package main

type Person struct {
    Name string
    Age  int
}

type Speaker interface {
    Speak() string
}
"#;
        let file = create_temp_file(source, ".go");
        let result = AstExtractor::extract_file(file.path()).unwrap();
        assert!(
            result.classes.len() >= 2,
            "Should extract Person struct and Speaker interface, got {}",
            result.classes.len()
        );
        let names: Vec<&str> = result.classes.iter().map(|c| c.name.as_str()).collect();
        assert!(
            names.contains(&"Person"),
            "Should find Person struct, found: {:?}",
            names
        );
        assert!(
            names.contains(&"Speaker"),
            "Should find Speaker interface, found: {:?}",
            names
        );
    }

    #[test]
    fn test_fallback_node_selection_helper_functions() {
        assert!(
            FUNCTION_NODE_KINDS.contains(&"function_definition"),
            "Should contain Python function_definition"
        );
        assert!(
            FUNCTION_NODE_KINDS.contains(&"method_declaration"),
            "Should contain Java method_declaration"
        );
        assert!(
            FUNCTION_NODE_KINDS.contains(&"function_item"),
            "Should contain Rust function_item"
        );
        assert!(
            FUNCTION_NODE_KINDS.contains(&"arrow_function"),
            "Should contain TypeScript arrow_function"
        );
        assert!(
            CLASS_NODE_KINDS.contains(&"class_definition"),
            "Should contain Python class_definition"
        );
        assert!(
            CLASS_NODE_KINDS.contains(&"type_declaration"),
            "Should contain Go type_declaration"
        );
        assert!(
            CLASS_NODE_KINDS.contains(&"struct_specifier"),
            "Should contain C struct_specifier"
        );
        assert!(
            CLASS_NODE_KINDS.contains(&"class_declaration"),
            "Should contain Java/TS class_declaration"
        );
    }

    #[test]
    fn test_query_caching() {
        let source = r#"
def hello():
    pass

class World:
    pass
"#;
        let file = create_temp_file(source, ".py");
        let result = AstExtractor::extract_file(file.path()).unwrap();
        assert!(!result.functions.is_empty(), "Should extract at least one function");
        assert!(!result.classes.is_empty(), "Should extract at least one class");
        let function_query = cached_query("python", "function")
            .expect("Python function query should be cached after extraction");
        let class_query = cached_query("python", "class")
            .expect("Python class query should be cached after extraction");
        let cache_size_after_python = query_cache_stats();
        assert!(
            cache_size_after_python >= 2,
            "Cache should have at least 2 entries (function + class), got {}",
            cache_size_after_python
        );
        let source2 = r#"
def another():
    return 42
"#;
        let file2 = create_temp_file(source2, ".py");
        let result2 = AstExtractor::extract_file(file2.path()).unwrap();
        assert!(!result2.functions.is_empty());
        // Other tests insert into the shared cache concurrently, so the total size is not
        // stable; the entries for this language, however, must be reused as-is.
        assert!(
            Arc::ptr_eq(&function_query, &cached_query("python", "function").unwrap()),
            "Reusing the same language should reuse the cached function query"
        );
        assert!(
            Arc::ptr_eq(&class_query, &cached_query("python", "class").unwrap()),
            "Reusing the same language should reuse the cached class query"
        );
    }

    #[test]
    fn test_query_cache_reuse() {
        let ts_source1 = "function greet(): string { return 'hello'; }";
        let ts_file1 = create_temp_file(ts_source1, ".ts");
        let _ = AstExtractor::extract_file(ts_file1.path()).unwrap();
        let function_query = cached_query("typescript", "function")
            .expect("TypeScript function query should be cached after extraction");
        let class_query = cached_query("typescript", "class")
            .expect("TypeScript class query should be cached after extraction");
        let ts_source2 = "const add = (a: number, b: number) => a + b;";
        let ts_file2 = create_temp_file(ts_source2, ".ts");
        let ts_result = AstExtractor::extract_file(ts_file2.path()).unwrap();
        assert!(!ts_result.functions.is_empty(), "Should extract TypeScript function");
        assert!(
            Arc::ptr_eq(&function_query, &cached_query("typescript", "function").unwrap()),
            "Reprocessing the same language should reuse the cached function query"
        );
        assert!(
            Arc::ptr_eq(&class_query, &cached_query("typescript", "class").unwrap()),
            "Reprocessing the same language should reuse the cached class query"
        );
    }

    #[test]
    fn test_query_cache_thread_safety() {
        use std::thread;
        let handles: Vec<_> = (0..4)
            .map(|i| {
                thread::spawn(move || {
                    let source = format!(
                        r#"
def func_{}():
    pass
"#,
                        i
                    );
                    let file = create_temp_file(&source, ".py");
                    let result = AstExtractor::extract_file(file.path());
                    assert!(result.is_ok(), "Extraction should succeed in thread {}", i);
                })
            })
            .collect();
        for handle in handles {
            handle.join().expect("Thread should complete successfully");
        }
        let cache_size = query_cache_stats();
        assert!(
            cache_size >= 2,
            "Cache should have entries after concurrent access, got {}",
            cache_size
        );
    }

    #[test]
    fn test_parser_caching_basic() {
        let source1 = r#"
def hello():
    pass
"#;
        let source2 = r#"
def world():
    return 42
"#;
        let file1 = create_temp_file(source1, ".py");
        let file2 = create_temp_file(source2, ".py");
        let result1 = AstExtractor::extract_file(file1.path());
        assert!(result1.is_ok(), "First extraction should succeed");
        let result2 = AstExtractor::extract_file(file2.path());
        assert!(result2.is_ok(), "Second extraction should succeed (using cached parser)");
        assert!(!result1.unwrap().functions.is_empty());
        assert!(!result2.unwrap().functions.is_empty());
    }

    #[test]
    fn test_parser_caching_multiple_languages() {
        let py_source = "def hello(): pass";
        let ts_source = "function hello(): void {}";
        let go_source = "package main\nfunc hello() {}";
        let py_file = create_temp_file(py_source, ".py");
        let ts_file = create_temp_file(ts_source, ".ts");
        let go_file = create_temp_file(go_source, ".go");
        let py_result = AstExtractor::extract_file(py_file.path());
        let ts_result = AstExtractor::extract_file(ts_file.path());
        let go_result = AstExtractor::extract_file(go_file.path());
        assert!(py_result.is_ok(), "Python extraction should succeed");
        assert!(ts_result.is_ok(), "TypeScript extraction should succeed");
        assert!(go_result.is_ok(), "Go extraction should succeed");
        assert_eq!(py_result.unwrap().language, "python");
        assert_eq!(ts_result.unwrap().language, "typescript");
        assert_eq!(go_result.unwrap().language, "go");
        let cache_size = parser_cache_stats();
        assert!(
            cache_size >= 3,
            "Cache should have at least 3 parsers (one per language), got {}",
            cache_size
        );
    }

    #[test]
    fn test_parser_caching_extract_from_source() {
        let source1 = "def foo(): pass";
        let source2 = "def bar(): return 1";
        let result1 = AstExtractor::extract_from_source(source1, "python");
        let result2 = AstExtractor::extract_from_source(source2, "python");
        assert!(result1.is_ok(), "First extract_from_source should succeed");
        assert!(result2.is_ok(), "Second extract_from_source should succeed (cached)");
        assert_eq!(result1.unwrap().functions[0].name, "foo");
        assert_eq!(result2.unwrap().functions[0].name, "bar");
    }

    #[test]
    fn test_parser_cache_clear() {
        let source = "def test(): pass";
        let file = create_temp_file(source, ".py");
        let _ = AstExtractor::extract_file(file.path()).unwrap();
        let before_clear = parser_cache_stats();
        assert!(before_clear >= 1, "Cache should have at least 1 parser before clear");
        clear_parser_cache();
        let after_clear = parser_cache_stats();
        assert_eq!(after_clear, 0, "Cache should be empty after clear");
        let source2 = "def another(): pass";
        let file2 = create_temp_file(source2, ".py");
        let result = AstExtractor::extract_file(file2.path());
        assert!(result.is_ok(), "Extraction should work after cache clear");
        let after_extraction = parser_cache_stats();
        assert_eq!(after_extraction, 1, "Cache should have 1 parser after extraction");
    }

    #[test]
    fn test_parser_cache_thread_local() {
        use std::thread;
        clear_parser_cache();
        let source = "def main_thread(): pass";
        let file = create_temp_file(source, ".py");
        let _ = AstExtractor::extract_file(file.path()).unwrap();
        let main_thread_cache = parser_cache_stats();
        assert!(main_thread_cache >= 1, "Main thread cache should have parser");
        let handle = thread::spawn(|| {
            let child_cache_before = parser_cache_stats();
            let source = "def child_thread(): pass";
            let file = create_temp_file(source, ".py");
            let _ = AstExtractor::extract_file(file.path()).unwrap();
            let child_cache_after = parser_cache_stats();
            (child_cache_before, child_cache_after)
        });
        let (child_before, child_after) = handle.join().unwrap();
        assert_eq!(
            child_before, 0,
            "Child thread should start with empty cache"
        );
        assert!(
            child_after >= 1,
            "Child thread should have parser after extraction"
        );
        let main_thread_cache_after = parser_cache_stats();
        assert_eq!(
            main_thread_cache, main_thread_cache_after,
            "Main thread cache should be unchanged by child thread"
        );
    }
}