use std::path::Path;
use tree_sitter::Node;
use super::error::{AstQueryError, Result};
use super::types::{Context, ContextItem, ContextKind, ContextualMatch, ContextualMatchLocation};
use crate::graph::unified::build::StagingGraph;
use crate::graph::unified::concurrent::CodeGraph;
use crate::plugin::PluginManager;
/// Extracts the enclosing-scope context of definitions found in source files.
///
/// The language of each file is resolved through the wrapped [`PluginManager`].
pub struct ContextExtractor {
// Looked up per file to find the plugin that can parse and graph it.
plugin_manager: PluginManager,
}
impl ContextExtractor {
#[must_use]
pub fn new() -> Self {
Self::with_plugin_manager(PluginManager::new())
}
/// Creates an extractor that resolves language plugins through `plugin_manager`.
#[must_use]
pub fn with_plugin_manager(plugin_manager: PluginManager) -> Self {
Self { plugin_manager }
}
/// Extracts a [`ContextualMatch`] for every context-bearing definition in the file at `path`.
///
/// The file is parsed once by the plugin selected for `path`. The plugin's graph builder
/// fills a temporary [`CodeGraph`], and every graph node whose kind maps to a
/// [`ContextKind`] is resolved back to its AST node so its enclosing scopes can be
/// collected. Graph nodes that cannot be located in the AST are skipped.
///
/// # Errors
///
/// Returns an error when the file cannot be read, and
/// [`AstQueryError::ContextExtraction`] when no plugin or graph builder is available for
/// `path`, when parsing or graph construction fails, or when staged data cannot be
/// committed to the graph.
#[allow(clippy::too_many_lines)]
pub fn extract_from_file(&self, path: &Path) -> Result<Vec<ContextualMatch>> {
    // Shared formatter for the per-stage failures below; every message names the file.
    let stage_error = |stage: &str, cause: &dyn std::fmt::Display| {
        AstQueryError::ContextExtraction(format!(
            "Failed to {stage} {}: {cause}",
            path.display()
        ))
    };

    let plugin = self.plugin_manager.plugin_for_path(path).ok_or_else(|| {
        AstQueryError::ContextExtraction(format!(
            "No plugin found for path: {}",
            path.display()
        ))
    })?;
    let raw_content = std::fs::read(path)?;
    let lang_name = plugin.metadata().id;

    // Parse once; the same tree serves graph construction and the context lookup below.
    let (prepared_content, tree) = plugin
        .prepare_ast(&raw_content)
        .map_err(|e| AstQueryError::ContextExtraction(format!("Failed to parse AST: {e:?}")))?;
    let parse_content = prepared_content.as_ref();

    let builder = plugin.graph_builder().ok_or_else(|| {
        AstQueryError::ContextExtraction(format!("No graph builder registered for {lang_name}"))
    })?;

    // Stage the file's graph, then commit it into a throwaway graph. The order
    // (file id -> strings -> string remap -> nodes) is kept exactly as staged data
    // is rewritten in place by each step.
    let mut staging = StagingGraph::new();
    builder
        .build_graph(&tree, parse_content, path, &mut staging)
        .map_err(|e| stage_error("build graph for", &e))?;
    staging.attach_body_hashes(&raw_content);

    let mut graph = CodeGraph::new();
    let file_id = graph
        .files_mut()
        .register_with_language(path, Some(builder.language()))
        .map_err(|e| stage_error("register file", &e))?;
    staging.apply_file_id(file_id);
    let string_remap = staging
        .commit_strings(graph.strings_mut())
        .map_err(|e| stage_error("commit strings for", &e))?;
    staging
        .apply_string_remap(&string_remap)
        .map_err(|e| stage_error("remap strings for", &e))?;
    let _node_id_map = staging
        .commit_nodes(graph.nodes_mut())
        .map_err(|e| stage_error("commit nodes for", &e))?;

    // The tree was parsed from the prepared buffer, so its byte ranges and positions
    // index that buffer. Slicing names out of the raw file bytes would yield wrong text
    // (or an out-of-range slice) whenever the plugin rewrites the content before parsing.
    // NOTE(review): lossy decoding can still shift offsets if the buffer holds invalid
    // UTF-8 -- confirm inputs are expected to be valid UTF-8.
    let content_str = String::from_utf8_lossy(parse_content);
    let root_node = tree.root_node();
    let mut contextual_matches = Vec::new();

    for (_, entry) in graph.nodes().iter() {
        // Skip entries superseded during unification, kinds without a context mapping,
        // and entries without a usable (1-based) start line.
        if entry.is_unified_loser() {
            continue;
        }
        if ContextKind::from_node_kind(entry.kind).is_none() {
            continue;
        }
        if entry.start_line == 0 {
            continue;
        }

        let mut node =
            Self::find_defining_node(root_node, entry.start_line, entry.start_column, lang_name);
        // Fallback: the "columns" may really be byte offsets into the file.
        if node.is_none()
            && Self::looks_like_byte_span(
                entry.start_line,
                entry.end_line,
                entry.start_column,
                entry.end_column,
                &content_str,
            )
        {
            node = Self::find_defining_node_by_bytes(
                root_node,
                entry.start_column as usize,
                entry.end_column as usize,
                lang_name,
            );
        }
        let Some(node) = node else {
            continue;
        };

        let semantic_context = Self::build_context(&node, &content_str, lang_name);
        let match_name = semantic_context.immediate.name.clone();
        let location = ContextualMatchLocation::new(
            path.to_path_buf(),
            entry.start_line,
            entry.start_column,
            entry.end_line,
            entry.end_column,
        );
        contextual_matches.push(ContextualMatch::new(
            match_name,
            location,
            semantic_context,
            lang_name.to_string(),
        ));
    }
    Ok(contextual_matches)
}
/// Finds the innermost named scope under `root` that contains the given position.
///
/// `line` is 1-based; `column` is compared directly against tree-sitter columns.
/// Returns `None` when no named scope covers the position.
fn find_defining_node<'a>(
    root: Node<'a>,
    line: u32,
    column: u32,
    lang_name: &str,
) -> Option<Node<'a>> {
    let mut walker = root.walk();
    Self::find_defining_node_recursive(root, line, column, lang_name, &mut walker)
}
/// Finds the closest named scope enclosing the byte range `start..end`.
///
/// Starts at the smallest descendant of `root` spanning the range and climbs towards
/// the root until a named scope for `lang_name` is met. Returns `None` if the range
/// has no descendant or no ancestor qualifies.
fn find_defining_node_by_bytes<'a>(
    root: Node<'a>,
    start: usize,
    end: usize,
    lang_name: &str,
) -> Option<Node<'a>> {
    let innermost = root.descendant_for_byte_range(start, end)?;
    // The chain yields the node itself first, then each successive parent.
    std::iter::successors(Some(innermost), |n| n.parent())
        .find(|n| Self::is_named_scope(n, lang_name))
}
/// Heuristic: do the "columns" of a span look like byte offsets into the file?
///
/// Only spans reported entirely on line 1 are considered. Such a span is treated as a
/// byte span when either column lies beyond the length (in bytes) of the first line of
/// `source`, which no real column on that line could do.
fn looks_like_byte_span(
    start_line: u32,
    end_line: u32,
    start_column: u32,
    end_column: u32,
    source: &str,
) -> bool {
    if (start_line, end_line) != (1, 1) {
        return false;
    }
    let first_line_len = match source.lines().next() {
        Some(first) => first.len(),
        None => 0,
    };
    [start_column, end_column]
        .into_iter()
        .any(|col| col as usize > first_line_len)
}
/// Depth-first search for the innermost named scope containing (`line`, `column`).
///
/// `line` is 1-based, so tree-sitter rows are shifted by one before comparison. A hit
/// in any child wins over `node` itself; `node` is returned only when it contains the
/// position, no child produced a match, and it is a named scope for `lang_name`.
fn find_defining_node_recursive<'a>(
    node: Node<'a>,
    line: u32,
    column: u32,
    lang_name: &str,
    cursor: &mut tree_sitter::TreeCursor<'a>,
) -> Option<Node<'a>> {
    // Saturating conversions keep oversized rows/columns from wrapping.
    let to_line = |row: usize| u32::try_from(row).unwrap_or(u32::MAX).saturating_add(1);
    let to_col = |col: usize| u32::try_from(col).unwrap_or(u32::MAX);

    let (first, last) = (node.start_position(), node.end_position());
    let (first_line, last_line) = (to_line(first.row), to_line(last.row));
    if line < first_line || line > last_line {
        return None;
    }

    // Columns only constrain the position on the node's boundary lines.
    let past_start = line != first_line || column >= to_col(first.column);
    let before_end = line != last_line || column <= to_col(last.column);
    if !(past_start && before_end) {
        return None;
    }

    // Children are collected up front because the cursor is reused by the recursion.
    // Children that end before the target line cannot contain it and are dropped.
    let candidates: Vec<Node<'a>> = node
        .children(cursor)
        .filter(|child| to_line(child.end_position().row) >= line)
        .collect();

    candidates
        .into_iter()
        .find_map(|child| {
            Self::find_defining_node_recursive(child, line, column, lang_name, &mut *cursor)
        })
        .or_else(|| Self::is_named_scope(&node, lang_name).then_some(node))
}
fn build_context(node: &Node, source: &str, lang_name: &str) -> Context {
let source_bytes = source.as_bytes();
let immediate = Self::node_to_context_item(node, source_bytes, lang_name);
let mut parent = None;
let mut ancestors = Vec::new();
let mut current = node.parent();
while let Some(node) = current {
if Self::is_named_scope(&node, lang_name) {
let item = Self::node_to_context_item(&node, source_bytes, lang_name);
if parent.is_none() {
parent = Some(item);
} else {
ancestors.push(item);
}
}
current = node.parent();
}
Context::new(immediate, parent, ancestors)
}
/// Converts an AST node into a [`ContextItem`].
///
/// The name falls back to `<anonymous>` when none can be extracted. Lines are reported
/// 1-based (tree-sitter rows are 0-based); byte offsets are passed through unchanged.
fn node_to_context_item(node: &Node, source_bytes: &[u8], lang_name: &str) -> ContextItem {
    // Saturate instead of wrapping if a row does not fit in `u32`.
    let one_based = |row: usize| u32::try_from(row).unwrap_or(u32::MAX).saturating_add(1);

    let name = Self::extract_name(node, source_bytes, lang_name)
        .unwrap_or_else(|| String::from("<anonymous>"));

    ContextItem::new(
        name,
        Self::node_to_context_kind(node, lang_name),
        one_based(node.start_position().row),
        one_based(node.end_position().row),
        node.start_byte(),
        node.end_byte(),
    )
}
/// Reports whether `node` is a definition kind treated as a named scope in `lang_name`.
///
/// Root-level kinds (`source_file`, `program`, `module`) never qualify, and languages
/// without a table entry have no named scopes.
fn is_named_scope(node: &Node, lang_name: &str) -> bool {
    let scope_kinds: &[&str] = match lang_name {
        "rust" => &[
            "function_item",
            "impl_item",
            "trait_item",
            "struct_item",
            "enum_item",
            "mod_item",
        ],
        "javascript" | "typescript" => &[
            "function_declaration",
            "method_definition",
            "class_declaration",
            "lexical_declaration",
        ],
        "python" => &["function_definition", "class_definition"],
        "go" => &[
            "function_declaration",
            "method_declaration",
            "type_declaration",
        ],
        _ => &[],
    };

    let kind = node.kind();
    let is_root_kind = matches!(kind, "source_file" | "program" | "module");
    !is_root_kind && scope_kinds.contains(&kind)
}
/// Returns the AST node kinds that can carry the name of a definition in `lang_name`.
///
/// The list is matched against the *direct* children of a scope node, so it must cover
/// every kind a grammar uses for a definition's name token. Unknown languages yield an
/// empty slice.
fn identifier_kinds(lang_name: &str) -> &'static [&'static str] {
    match lang_name {
        "rust" => &["identifier", "type_identifier"],
        "javascript" => &["identifier", "property_identifier"],
        // The TypeScript grammar names classes with `type_identifier`, not `identifier`;
        // without it every TypeScript class is reported as anonymous.
        "typescript" => &["identifier", "property_identifier", "type_identifier"],
        "python" => &["identifier"],
        // Go method names are `field_identifier` nodes (only free functions use
        // `identifier`); without it every Go method is reported as anonymous.
        "go" => &["identifier", "field_identifier"],
        _ => &[],
    }
}
/// Returns the text of the first direct child of `node` whose kind is an identifier
/// kind for `lang_name`.
///
/// Yields `None` when the language has no identifier kinds, when no child matches, or
/// when the matching child's text is not valid UTF-8. Only the first matching child is
/// consulted.
fn extract_name(node: &Node, source_bytes: &[u8], lang_name: &str) -> Option<String> {
    let name_kinds = Self::identifier_kinds(lang_name);
    if name_kinds.is_empty() {
        return None;
    }

    let mut walker = node.walk();
    for child in node.children(&mut walker) {
        if name_kinds.contains(&child.kind()) {
            return child.utf8_text(source_bytes).ok().map(str::to_owned);
        }
    }
    None
}
fn node_to_context_kind(node: &Node, lang_name: &str) -> ContextKind {
if lang_name == "rust" && node.kind() == "function_item" {
let mut current = node.parent();
while let Some(parent) = current {
if matches!(parent.kind(), "impl_item" | "trait_item") {
return ContextKind::Method;
}
current = parent.parent();
}
}
Self::node_kind_to_context_kind(node.kind(), lang_name)
}
/// Maps a (language, node kind) pair to a [`ContextKind`].
///
/// Any pair without an explicit entry -- including every kind of an unknown language --
/// is classified as [`ContextKind::Function`].
fn node_kind_to_context_kind(node_kind: &str, lang_name: &str) -> ContextKind {
    match (lang_name, node_kind) {
        ("rust", "impl_item") => ContextKind::Impl,
        ("rust", "trait_item") => ContextKind::Trait,
        ("rust", "struct_item") => ContextKind::Struct,
        ("rust", "enum_item") => ContextKind::Enum,
        ("rust", "mod_item") => ContextKind::Module,
        ("rust", "const_item") => ContextKind::Constant,
        ("rust", "static_item") => ContextKind::Variable,
        ("rust", "type_item") => ContextKind::TypeAlias,

        ("javascript" | "typescript", "method_definition") => ContextKind::Method,
        ("javascript" | "typescript", "class_declaration") => ContextKind::Class,
        ("javascript" | "typescript", "lexical_declaration" | "variable_declaration") => {
            ContextKind::Variable
        }

        ("python", "class_definition") => ContextKind::Class,

        ("go", "method_declaration") => ContextKind::Method,
        ("go", "type_declaration") => ContextKind::Struct,

        _ => ContextKind::Function,
    }
}
/// Collects contextual matches from every file under `root`, in traversal order.
///
/// The walk is recursive with link following disabled. Extraction is best-effort:
/// directory entries that cannot be read and files for which [`Self::extract_from_file`]
/// fails (for example, because no plugin handles them) are skipped silently.
///
/// # Errors
///
/// The current implementation always returns `Ok`; the `Result` return type is kept
/// for interface stability.
pub fn extract_from_directory(&self, root: &Path) -> Result<Vec<ContextualMatch>> {
    let all_matches: Vec<ContextualMatch> = walkdir::WalkDir::new(root)
        .follow_links(false)
        .into_iter()
        .filter_map(std::result::Result::ok)
        .filter(|entry| entry.path().is_file())
        .filter_map(|entry| self.extract_from_file(entry.path()).ok())
        .flatten()
        .collect();
    Ok(all_matches)
}
}
impl Default for ContextExtractor {
fn default() -> Self {
Self::new()
}
}
#[cfg(all(test, feature = "context-tests"))]
mod tests {
use super::*;
use std::fs;
use tempfile::TempDir;
// Builds the plugin manager shared by the tests below, using the crate's
// test-support factory for built-in plugins.
fn create_test_plugin_manager() -> crate::plugin::PluginManager {
crate::test_support::plugin_factory::with_builtin_plugins()
}
// Rust fixture with a free function, a struct and an impl method.
// Requires at least two matches and that `top_level` is found at depth 1 with path
// `top_level`; `method` is only checked (depth >= 1, has a parent) if it is found.
#[test]
#[ignore = "Plugins not available in unit tests (dev-dependencies). Move to integration tests if needed."]
fn test_extract_rust_function_context() {
let dir = TempDir::new().unwrap();
let file_path = dir.path().join("test.rs");
fs::write(
&file_path,
r#"
fn top_level() {
println!("hello");
}
struct MyStruct {
value: i32,
}
impl MyStruct {
fn method(&self) -> i32 {
self.value
}
}
"#,
)
.unwrap();
let manager = create_test_plugin_manager();
let extractor = ContextExtractor::with_plugin_manager(manager);
let matches = extractor.extract_from_file(&file_path).unwrap();
assert!(
matches.len() >= 2,
"Expected at least 2 matches, found {}",
matches.len()
);
let top_level = matches.iter().find(|m| m.name == "top_level");
assert!(top_level.is_some(), "Should find top_level function");
if let Some(m) = top_level {
assert_eq!(m.context.depth(), 1, "top_level should be at depth 1");
assert_eq!(m.context.path(), "top_level");
}
// Not asserted to exist: the checks below only run when `method` was extracted.
let method = matches.iter().find(|m| m.name == "method");
if let Some(m) = method {
assert!(m.context.depth() >= 1, "method should have depth >= 1");
assert!(m.context.parent.is_some(), "method should have a parent");
}
}
// Rust fixture with a method inside an impl inside a module.
// If `deeply_nested` is extracted, it must have a parent and depth >= 1; the test
// does not fail when the method is missing.
#[test]
#[ignore = "Plugins not available in unit tests (dev-dependencies). Move to integration tests if needed."]
fn test_extract_nested_context() {
let dir = TempDir::new().unwrap();
let file_path = dir.path().join("test.rs");
fs::write(
&file_path,
r"
mod outer {
struct Inner {
value: i32,
}
impl Inner {
fn deeply_nested(&self) {
// nested function
}
}
}
",
)
.unwrap();
let manager = create_test_plugin_manager();
let extractor = ContextExtractor::with_plugin_manager(manager);
let matches = extractor.extract_from_file(&file_path).unwrap();
let method = matches.iter().find(|m| m.name == "deeply_nested");
if let Some(m) = method {
assert!(m.context.parent.is_some(), "Should have parent");
assert!(m.context.depth() >= 1, "Should have depth >= 1");
}
}
// JavaScript fixture with a top-level function and a class with two methods.
// Requires at least two matches and that `MyClass` is found; `topLevel` is checked
// for depth 1 only if it is found.
#[test]
#[ignore = "JavaScript plugin not registered in test helper"]
fn test_extract_javascript_class() {
let dir = TempDir::new().unwrap();
let file_path = dir.path().join("test.js");
fs::write(
&file_path,
r#"
function topLevel() {
console.log("hello");
}
class MyClass {
constructor(name) {
this.name = name;
}
greet() {
console.log("Hello " + this.name);
}
}
"#,
)
.unwrap();
let manager = create_test_plugin_manager();
let extractor = ContextExtractor::with_plugin_manager(manager);
let matches = extractor.extract_from_file(&file_path).unwrap();
assert!(matches.len() >= 2, "Should find at least 2 matches");
let top_fn = matches.iter().find(|m| m.name == "topLevel");
if let Some(m) = top_fn {
assert_eq!(m.context.depth(), 1, "topLevel should be at depth 1");
}
let class = matches.iter().find(|m| m.name == "MyClass");
assert!(class.is_some(), "Should find MyClass");
}
// Python fixture with a top-level function and a class method.
// Requires at least two matches; `top_level` (depth 1) and `method` (depth >= 2,
// has a parent) are only checked if they are found.
#[test]
#[ignore = "Python plugin not registered in test helper"]
fn test_extract_python_context() {
let dir = TempDir::new().unwrap();
let file_path = dir.path().join("test.py");
fs::write(
&file_path,
r#"
def top_level():
print("hello")
class MyClass:
def method(self):
return 42
"#,
)
.unwrap();
let manager = create_test_plugin_manager();
let extractor = ContextExtractor::with_plugin_manager(manager);
let matches = extractor.extract_from_file(&file_path).unwrap();
assert!(matches.len() >= 2, "Should find at least 2 matches");
let top_fn = matches.iter().find(|m| m.name == "top_level");
if let Some(m) = top_fn {
assert_eq!(m.context.depth(), 1);
}
let method = matches.iter().find(|m| m.name == "method");
if let Some(m) = method {
assert!(m.context.depth() >= 2);
assert!(m.context.parent.is_some());
}
}
// An empty Rust file must extract successfully and produce no matches.
#[test]
#[ignore = "Plugins not available in unit tests (dev-dependencies). Move to integration tests if needed."]
fn test_empty_file() {
let dir = TempDir::new().unwrap();
let file_path = dir.path().join("empty.rs");
fs::write(&file_path, "").unwrap();
let manager = create_test_plugin_manager();
let extractor = ContextExtractor::with_plugin_manager(manager);
let matches = extractor.extract_from_file(&file_path).unwrap();
assert_eq!(matches.len(), 0);
}
// A function whose whole definition sits on one line must be found, at depth 1 and
// with path `single_line`.
#[test]
#[ignore = "Plugins not available in unit tests (dev-dependencies). Move to integration tests if needed."]
fn test_position_matching_single_line_function() {
let dir = TempDir::new().unwrap();
let file_path = dir.path().join("test.rs");
fs::write(
&file_path,
r#"
fn single_line() { println!("hello"); }
"#,
)
.unwrap();
let manager = create_test_plugin_manager();
let extractor = ContextExtractor::with_plugin_manager(manager);
let matches = extractor.extract_from_file(&file_path).unwrap();
let func = matches.iter().find(|m| m.name == "single_line");
assert!(func.is_some(), "Should find single-line function");
if let Some(m) = func {
assert_eq!(m.context.depth(), 1);
assert_eq!(m.context.path(), "single_line");
}
}
// A function spanning several lines must be found at depth 1, and its reported end
// line must be more than one line past its start line.
#[test]
#[ignore = "Plugins not available in unit tests (dev-dependencies). Move to integration tests if needed."]
fn test_position_matching_multiline_function() {
let dir = TempDir::new().unwrap();
let file_path = dir.path().join("test.rs");
fs::write(
&file_path,
r"
fn multiline() {
let x = 1;
let y = 2;
x + y
}
",
)
.unwrap();
let manager = create_test_plugin_manager();
let extractor = ContextExtractor::with_plugin_manager(manager);
let matches = extractor.extract_from_file(&file_path).unwrap();
let func = matches.iter().find(|m| m.name == "multiline");
assert!(func.is_some(), "Should find multi-line function");
if let Some(m) = func {
assert_eq!(m.context.depth(), 1);
assert!(m.end_line > m.start_line + 1);
}
}
// A method nested as module -> impl -> method must be found with depth 3, and its
// parent must be named `Inner`.
#[test]
#[ignore = "Plugins not available in unit tests (dev-dependencies). Move to integration tests if needed."]
fn test_position_matching_nested_structures() {
let dir = TempDir::new().unwrap();
let file_path = dir.path().join("test.rs");
fs::write(
&file_path,
r"
mod outer {
struct Inner {
field: i32,
}
impl Inner {
fn method(&self) -> i32 {
self.field
}
}
}
",
)
.unwrap();
let manager = create_test_plugin_manager();
let extractor = ContextExtractor::with_plugin_manager(manager);
let matches = extractor.extract_from_file(&file_path).unwrap();
let method = matches.iter().find(|m| m.name == "method");
assert!(method.is_some(), "Should find nested method");
if let Some(m) = method {
assert_eq!(m.context.depth(), 3, "Method should have depth 3");
assert!(m.context.parent.is_some(), "Method should have parent");
if let Some(parent) = &m.context.parent {
assert_eq!(parent.name, "Inner", "Method parent should be Inner impl");
}
}
}
// Line and doc comments preceding a function must not prevent it from being found
// at depth 1.
#[test]
#[ignore = "Plugins not available in unit tests (dev-dependencies). Move to integration tests if needed."]
fn test_position_matching_with_comments() {
let dir = TempDir::new().unwrap();
let file_path = dir.path().join("test.rs");
fs::write(
&file_path,
r#"
// This is a comment
/// Documentation comment
fn documented_function() {
// Internal comment
println!("test");
}
"#,
)
.unwrap();
let manager = create_test_plugin_manager();
let extractor = ContextExtractor::with_plugin_manager(manager);
let matches = extractor.extract_from_file(&file_path).unwrap();
let func = matches.iter().find(|m| m.name == "documented_function");
assert!(func.is_some(), "Should find function with comments");
if let Some(m) = func {
assert_eq!(m.context.depth(), 1);
}
}
// A struct followed by its impl: both `Container` and the method `new` must be found.
// `new` must have depth 2; its parent, when present, must be named `Container`.
#[test]
#[ignore = "Plugins not available in unit tests (dev-dependencies). Move to integration tests if needed."]
fn test_position_matching_edge_positions() {
let dir = TempDir::new().unwrap();
let file_path = dir.path().join("test.rs");
fs::write(
&file_path,
r"
struct Container {
value: i32,
}
impl Container {
fn new(val: i32) -> Self {
Self { value: val }
}
}
",
)
.unwrap();
let manager = create_test_plugin_manager();
let extractor = ContextExtractor::with_plugin_manager(manager);
let matches = extractor.extract_from_file(&file_path).unwrap();
let container = matches.iter().find(|m| m.name == "Container");
let new_method = matches.iter().find(|m| m.name == "new");
assert!(container.is_some(), "Should find Container struct");
assert!(new_method.is_some(), "Should find new method");
if let Some(m) = new_method {
assert_eq!(m.context.depth(), 2, "Method should have depth 2");
if let Some(parent) = &m.context.parent {
assert_eq!(
parent.name, "Container",
"Method parent should be Container impl"
);
}
}
}
}