use std::path::Path;
use crate::{
parser::{Language, ParsedFile},
symbol_extraction::find_definitions,
};
/// A symbol definition matched by name, as returned by [`resolve_all_symbols`].
///
/// Values are produced by `find_definitions` (defined outside this file); this
/// module only filters them by `parent_name` and reads the line span.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ResolvedSymbol {
/// Name of the symbol.
pub name: String,
/// First line of the definition.
// NOTE(review): presumably 1-based like `Definition::start_line`; the value is
// computed by `find_definitions`, which is not visible here — confirm.
pub start_line: u32,
/// Last line of the definition.
// NOTE(review): presumably inclusive — confirm against `find_definitions`.
pub end_line: u32,
/// Name of the enclosing scope, if any. The resolvers in this file compare it
/// against the part of a `Parent::name` query that precedes the last `::`.
pub parent_name: Option<String>,
}
/// Errors returned by the symbol extraction and resolution functions in this file.
#[derive(Debug, thiserror::Error)]
pub enum SymbolResolveError {
/// No parser is available for the language derived from the file path.
/// Carries the path's extension, or `"<none>"` when the path has no extension.
#[error("unsupported file extension: {0}")]
UnsupportedLanguage(String),
/// The source was not valid UTF-8, or `ParsedFile::parse` returned `None`.
#[error("failed to parse source file")]
ParseFailed,
/// No definition matched the requested symbol. Carries the symbol string
/// exactly as the caller passed it (including any `Parent::` qualifier).
#[error("symbol not found: {0}")]
SymbolNotFound(String),
}
/// Coarse category assigned to each [`Definition`] by `walk_definitions`,
/// based on the tree-sitter node kind of the defining node.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DefinitionKind {
/// `struct_item` nodes, and `type_spec` nodes whose type is a `struct_type`.
Type,
/// `trait_item` nodes.
Trait,
/// `class_definition`, `class_declaration`, `struct_specifier` and
/// `class_specifier` nodes.
Class,
/// `interface_declaration` nodes, and `type_spec` nodes whose type is an
/// `interface_type`.
Interface,
/// `type_item` and `type_alias_declaration` nodes, and any `type_spec` that
/// is neither a struct nor an interface.
TypeAlias,
/// `enum_item`, `enum_declaration` and `enum_specifier` nodes.
EnumDef,
/// `const_item` / `static_item` nodes, and variable declarators whose
/// initializer is not a function expression.
ConstDecl,
/// `mod_item` and `namespace_definition` nodes.
Module,
/// Function-like nodes (functions, methods, constructors) and variable
/// declarators initialized with an arrow function or function expression.
Function,
/// NOTE(review): not produced by any code visible in this file — presumably
/// reserved for callers elsewhere; confirm before relying on it.
Other,
}
/// One definition discovered by [`extract_definitions`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Definition {
/// Text of the defining node's `name` field (never empty).
pub name: String,
/// Category derived from the tree-sitter node kind.
pub kind: DefinitionKind,
/// 1-based line on which the recorded node starts (tree-sitter row + 1).
pub start_line: u32,
/// 1-based line on which the recorded node ends (tree-sitter row + 1), inclusive.
pub end_line: u32,
/// Enclosing scope name when one was tracked during the walk: the type of an
/// enclosing `impl` block, the name of an enclosing class, or the receiver
/// type of a method declaration. `None` otherwise.
pub parent_name: Option<String>,
}
/// Parses `source` with the grammar selected from `path` and returns every
/// definition found while walking the syntax tree.
///
/// # Errors
///
/// * [`SymbolResolveError::UnsupportedLanguage`] when no parser is available
///   for the language derived from `path`; the payload is the file extension,
///   or `"<none>"` when there is none.
/// * [`SymbolResolveError::ParseFailed`] when `source` is not valid UTF-8 or
///   `ParsedFile::parse` yields no tree.
pub fn extract_definitions(
    source: &[u8],
    path: &Path,
) -> Result<Vec<Definition>, SymbolResolveError> {
    let language = Language::from_path(path);

    // Reject languages without a parser before doing any other work.
    if language.parser_handle().is_none() {
        let extension = match path.extension() {
            Some(ext) => ext.to_string_lossy().into_owned(),
            None => "<none>".to_string(),
        };
        return Err(SymbolResolveError::UnsupportedLanguage(extension));
    }

    let Ok(source_text) = std::str::from_utf8(source) else {
        return Err(SymbolResolveError::ParseFailed);
    };
    let Some(parsed) = ParsedFile::parse(source_text, language) else {
        return Err(SymbolResolveError::ParseFailed);
    };

    // Top-level nodes have no enclosing scope, hence the `None` parent.
    let mut definitions = Vec::new();
    walk_definitions(&parsed.root_node(), source, None, &mut definitions);
    Ok(definitions)
}
/// Returns the slice of `source` covered by `node` as text.
///
/// Yields an empty string when those bytes are not valid UTF-8.
fn node_text<'a>(node: &tree_sitter::Node, source: &'a [u8]) -> &'a str {
    let bytes = &source[node.byte_range()];
    match std::str::from_utf8(bytes) {
        Ok(text) => text,
        Err(_) => "",
    }
}
/// Appends a [`Definition`] for `node` to `out`, named by the node's `name` field.
///
/// Nothing is recorded when the node has no `name` field or its text is empty.
/// The span is the 1-based, inclusive line range of `node` itself.
fn push_named_definition(
    node: &tree_sitter::Node,
    source: &[u8],
    dk: DefinitionKind,
    parent: Option<&str>,
    out: &mut Vec<Definition>,
) {
    let Some(name_node) = node.child_by_field_name("name") else {
        return;
    };
    let name = node_text(&name_node, source);
    if name.is_empty() {
        return;
    }

    // tree-sitter rows are 0-based; reported lines are 1-based.
    let start_line = node.start_position().row as u32 + 1;
    let end_line = node.end_position().row as u32 + 1;
    out.push(Definition {
        name: name.to_string(),
        kind: dk,
        start_line,
        end_line,
        parent_name: parent.map(String::from),
    });
}
fn walk_definitions(
node: &tree_sitter::Node,
source: &[u8],
current_parent: Option<&str>,
out: &mut Vec<Definition>,
) {
let kind = node.kind();
match kind {
"function_item" => {
push_named_definition(node, source, DefinitionKind::Function, current_parent, out)
}
"struct_item" => {
push_named_definition(node, source, DefinitionKind::Type, current_parent, out)
}
"enum_item" => {
push_named_definition(node, source, DefinitionKind::EnumDef, current_parent, out)
}
"trait_item" => {
push_named_definition(node, source, DefinitionKind::Trait, current_parent, out)
}
"type_item" => {
push_named_definition(node, source, DefinitionKind::TypeAlias, current_parent, out)
}
"const_item" | "static_item" => {
push_named_definition(node, source, DefinitionKind::ConstDecl, current_parent, out)
}
"mod_item" => {
push_named_definition(node, source, DefinitionKind::Module, current_parent, out)
}
"impl_item" => {
let parent_name = extract_rust_impl_type_name(node, source);
let parent = parent_name.as_deref();
let mut cursor = node.walk();
for child in node.children(&mut cursor) {
walk_definitions(&child, source, parent, out);
}
return;
}
"function_definition" => {
push_named_definition(node, source, DefinitionKind::Function, current_parent, out)
}
"class_definition" => {
let class_name = node
.child_by_field_name("name")
.map(|n| node_text(&n, source).to_string());
if let Some(ref name) = class_name
&& !name.is_empty()
{
out.push(Definition {
name: name.clone(),
kind: DefinitionKind::Class,
start_line: node.start_position().row as u32 + 1,
end_line: node.end_position().row as u32 + 1,
parent_name: current_parent.map(String::from),
});
}
let parent = class_name.as_deref();
let mut cursor = node.walk();
for child in node.children(&mut cursor) {
walk_definitions(&child, source, parent, out);
}
return;
}
"function_declaration" => {
push_named_definition(node, source, DefinitionKind::Function, current_parent, out)
}
"method_declaration" => {
if let Some(name_node) = node.child_by_field_name("name") {
let name = node_text(&name_node, source).to_string();
if !name.is_empty() {
let receiver = extract_go_receiver_type(node, source);
out.push(Definition {
name,
kind: DefinitionKind::Function,
start_line: node.start_position().row as u32 + 1,
end_line: node.end_position().row as u32 + 1,
parent_name: receiver.or_else(|| current_parent.map(String::from)),
});
}
}
}
"type_declaration" => {
let mut cursor = node.walk();
for child in node.children(&mut cursor) {
if child.kind() == "type_spec"
&& let Some(name_node) = child.child_by_field_name("name")
{
let name = node_text(&name_node, source).to_string();
if name.is_empty() {
continue;
}
let dk = match child.child_by_field_name("type").map(|t| t.kind()) {
Some("interface_type") => DefinitionKind::Interface,
Some("struct_type") => DefinitionKind::Type,
_ => DefinitionKind::TypeAlias,
};
out.push(Definition {
name,
kind: dk,
start_line: child.start_position().row as u32 + 1,
end_line: child.end_position().row as u32 + 1,
parent_name: current_parent.map(String::from),
});
}
}
}
"method_definition" => {
push_named_definition(node, source, DefinitionKind::Function, current_parent, out)
}
"class_declaration" => {
let class_name = node
.child_by_field_name("name")
.map(|n| node_text(&n, source).to_string());
if let Some(ref name) = class_name
&& !name.is_empty()
{
out.push(Definition {
name: name.clone(),
kind: DefinitionKind::Class,
start_line: node.start_position().row as u32 + 1,
end_line: node.end_position().row as u32 + 1,
parent_name: current_parent.map(String::from),
});
}
let parent = class_name.as_deref();
let mut cursor = node.walk();
for child in node.children(&mut cursor) {
walk_definitions(&child, source, parent, out);
}
return;
}
"interface_declaration" => {
push_named_definition(node, source, DefinitionKind::Interface, current_parent, out)
}
"type_alias_declaration" => {
push_named_definition(node, source, DefinitionKind::TypeAlias, current_parent, out)
}
"enum_declaration" => {
push_named_definition(node, source, DefinitionKind::EnumDef, current_parent, out)
}
"lexical_declaration" | "variable_declaration" => {
let mut cursor = node.walk();
for child in node.children(&mut cursor) {
if child.kind() == "variable_declarator"
&& let Some(name_node) = child.child_by_field_name("name")
{
let name = node_text(&name_node, source).to_string();
if name.is_empty() {
continue;
}
if let Some(value_node) = child.child_by_field_name("value") {
let vkind = value_node.kind();
let dk = if vkind == "arrow_function"
|| vkind == "function"
|| vkind == "function_expression"
{
DefinitionKind::Function
} else {
DefinitionKind::ConstDecl
};
out.push(Definition {
name,
kind: dk,
start_line: node.start_position().row as u32 + 1,
end_line: node.end_position().row as u32 + 1,
parent_name: current_parent.map(String::from),
});
}
}
}
}
"struct_specifier" | "class_specifier" => {
push_named_definition(node, source, DefinitionKind::Class, current_parent, out)
}
"namespace_definition" => {
push_named_definition(node, source, DefinitionKind::Module, current_parent, out)
}
"enum_specifier" => {
push_named_definition(node, source, DefinitionKind::EnumDef, current_parent, out)
}
"constructor_declaration" => {
push_named_definition(node, source, DefinitionKind::Function, current_parent, out)
}
_ => {}
}
let mut cursor = node.walk();
for child in node.children(&mut cursor) {
walk_definitions(&child, source, current_parent, out);
}
}
/// Returns the name of the type an `impl` node applies to, taken from the
/// node's `type` field, or `None` when that field is absent.
fn extract_rust_impl_type_name(node: &tree_sitter::Node, source: &[u8]) -> Option<String> {
    node.child_by_field_name("type")
        .map(|implemented| extract_type_identifier(&implemented, source))
}
/// Reduces a type node to the bare identifier that names the type.
///
/// * `type_identifier` / `identifier`: the node text itself.
/// * `scoped_type_identifier` (`path::Name`): the text of the `name` field,
///   i.e. the final segment. Scanning for the first identifier child would
///   return the leading *path* segment for `foo::Bar`.
/// * `generic_type` (`Base<Args>`): the identifier of the `type` field,
///   resolved recursively so that `foo::Bar<T>` also yields `Bar`.
/// * anything else, or a node missing the expected field: the full node text.
fn extract_type_identifier(node: &tree_sitter::Node, source: &[u8]) -> String {
    match node.kind() {
        "type_identifier" | "identifier" => node_text(node, source).to_string(),
        "scoped_type_identifier" => match node.child_by_field_name("name") {
            Some(name_node) => node_text(&name_node, source).to_string(),
            None => node_text(node, source).to_string(),
        },
        "generic_type" => match node.child_by_field_name("type") {
            // The base may itself be scoped, so resolve it the same way.
            Some(base) => extract_type_identifier(&base, source),
            None => node_text(node, source).to_string(),
        },
        _ => node_text(node, source).to_string(),
    }
}
/// Returns the receiver type of a method declaration node.
///
/// Uses the first `parameter_declaration` child of the `receiver` field that
/// has a `type` field, and strips any leading `*` from that type's text.
/// Returns `None` when there is no receiver or no such parameter.
fn extract_go_receiver_type(node: &tree_sitter::Node, source: &[u8]) -> Option<String> {
    let receiver_list = node.child_by_field_name("receiver")?;
    let mut cursor = receiver_list.walk();
    let receiver_type = receiver_list
        .children(&mut cursor)
        .filter(|param| param.kind() == "parameter_declaration")
        .find_map(|param| param.child_by_field_name("type"))?;
    let type_text = node_text(&receiver_type, source);
    Some(type_text.trim_start_matches('*').to_string())
}
/// Resolves `symbol` in `source` to the `(start_line, end_line)` of one definition.
///
/// `symbol` may be a bare name or `Parent::name`; it is split at the *last*
/// `::`. With a qualifier, the first definition whose parent equals the
/// qualifier wins; if none matches, the first definition with that name is
/// used anyway. Without a qualifier the first definition is used.
///
/// # Errors
///
/// * [`SymbolResolveError::UnsupportedLanguage`] when no parser is available
///   for the language derived from `path`.
/// * [`SymbolResolveError::ParseFailed`] when `source` is not valid UTF-8 or
///   cannot be parsed.
/// * [`SymbolResolveError::SymbolNotFound`] when no definition has the
///   requested name.
pub fn resolve_symbol_lines(
    source: &[u8],
    path: &Path,
    symbol: &str,
) -> Result<(u32, u32), SymbolResolveError> {
    let language = Language::from_path(path);
    if language.parser_handle().is_none() {
        let extension = match path.extension() {
            Some(ext) => ext.to_string_lossy().into_owned(),
            None => "<none>".to_string(),
        };
        return Err(SymbolResolveError::UnsupportedLanguage(extension));
    }

    let Ok(source_text) = std::str::from_utf8(source) else {
        return Err(SymbolResolveError::ParseFailed);
    };
    let Some(parsed) = ParsedFile::parse(source_text, language) else {
        return Err(SymbolResolveError::ParseFailed);
    };

    let (parent_filter, target_name) = match symbol.rsplit_once("::") {
        Some((qualifier, name)) => (Some(qualifier), name),
        None => (None, symbol),
    };

    let definitions = find_definitions(&parsed.root_node(), source, target_name);

    // Prefer a parent-qualified match, then fall back to the first candidate.
    let chosen = parent_filter
        .and_then(|qualifier| {
            definitions
                .iter()
                .find(|d| d.parent_name.as_deref() == Some(qualifier))
        })
        .or_else(|| definitions.first());

    chosen
        .map(|sym| (sym.start_line, sym.end_line))
        .ok_or_else(|| SymbolResolveError::SymbolNotFound(symbol.to_string()))
}
/// Resolves `symbol` in `source` to every matching definition.
///
/// `symbol` may be a bare name or `Parent::name`; it is split at the *last*
/// `::`. With a qualifier, only definitions whose parent equals the qualifier
/// are returned, which may be an empty list. Without one, all definitions
/// with that name are returned.
///
/// # Errors
///
/// * [`SymbolResolveError::UnsupportedLanguage`] when no parser is available
///   for the language derived from `path`.
/// * [`SymbolResolveError::ParseFailed`] when `source` is not valid UTF-8 or
///   cannot be parsed.
pub fn resolve_all_symbols(
    source: &[u8],
    path: &Path,
    symbol: &str,
) -> Result<Vec<ResolvedSymbol>, SymbolResolveError> {
    let language = Language::from_path(path);
    if language.parser_handle().is_none() {
        let extension = match path.extension() {
            Some(ext) => ext.to_string_lossy().into_owned(),
            None => "<none>".to_string(),
        };
        return Err(SymbolResolveError::UnsupportedLanguage(extension));
    }

    let source_text = match std::str::from_utf8(source) {
        Ok(text) => text,
        Err(_) => return Err(SymbolResolveError::ParseFailed),
    };
    let parsed = match ParsedFile::parse(source_text, language) {
        Some(file) => file,
        None => return Err(SymbolResolveError::ParseFailed),
    };

    let (parent_filter, target_name) = match symbol.rsplit_once("::") {
        Some((qualifier, name)) => (Some(qualifier), name),
        None => (None, symbol),
    };

    let definitions = find_definitions(&parsed.root_node(), source, target_name);

    let Some(qualifier) = parent_filter else {
        return Ok(definitions);
    };
    let scoped: Vec<ResolvedSymbol> = definitions
        .into_iter()
        .filter(|d| d.parent_name.as_deref() == Some(qualifier))
        .collect();
    Ok(scoped)
}
/// Returns the bytes of lines `start..=end` of `source`, using 1-based line
/// numbers and keeping each line's trailing `\n`.
///
/// * Lines are delimited by `\n`; a final line without a newline still counts.
/// * If `start` is past the last line, the result is empty. This includes the
///   position immediately after a trailing newline, which is not a line.
/// * If `end` is past the last line, everything from `start` to the end of
///   `source` is returned.
/// * `start == 0` is treated as line 1.
/// * If `end < start`, the single line `start` is returned.
pub fn extract_line_range(source: &[u8], start: u32, end: u32) -> Vec<u8> {
    let first = start.max(1);
    // Always yield at least the start line, even when `end` precedes it.
    let last = end.max(first);

    let lines_to_skip = (first - 1) as usize;
    let lines_to_take = ((last - first) as usize).saturating_add(1);

    let mut extracted = Vec::new();
    for line in source
        .split_inclusive(|&byte| byte == b'\n')
        .skip(lines_to_skip)
        .take(lines_to_take)
    {
        extracted.extend_from_slice(line);
    }
    extracted
}
// Unit tests for symbol resolution, definition extraction and line-range
// slicing. The asserted line numbers depend on the exact bytes of each
// raw-string fixture, so the fixtures must not be reformatted.
#[cfg(test)]
mod tests {
use super::*;
// A top-level Rust function is resolved by its bare name.
#[test]
fn resolve_rust_fn_main() {
let source = br#"
fn helper() -> bool {
true
}
fn main() {
println!("hello");
let x = 1;
}
fn after() {}
"#;
let path = Path::new("test.rs");
let (start, end) = resolve_symbol_lines(source, path, "main").unwrap();
assert_eq!(start, 6);
assert_eq!(end, 9);
}
// A `Type::method` query resolves a method inside the matching `impl` block.
#[test]
fn resolve_rust_qualified_impl_method() {
let source = br#"
struct Repository {
path: String,
}
impl Repository {
pub fn open(path: &str) -> Self {
Repository {
path: path.to_string(),
}
}
pub fn close(&self) {}
}
impl Default for Repository {
fn default() -> Self {
Repository::open(".")
}
}
"#;
let path = Path::new("repo.rs");
let (start, end) = resolve_symbol_lines(source, path, "Repository::open").unwrap();
assert_eq!(start, 7);
assert_eq!(end, 11);
}
// A Rust struct is resolved by name; the span runs through its closing brace.
#[test]
fn resolve_rust_struct() {
let source = br#"
pub struct Config {
pub name: String,
pub value: u32,
}
"#;
let path = Path::new("config.rs");
let (start, end) = resolve_symbol_lines(source, path, "Config").unwrap();
assert_eq!(start, 2);
assert_eq!(end, 5);
}
// A Python function (`.py` path) is resolved by its bare name.
#[test]
fn resolve_python_function() {
let source = br#"
def helper():
pass
def process_data(items):
result = []
for item in items:
result.append(item * 2)
return result
def cleanup():
pass
"#;
let path = Path::new("main.py");
let (start, end) = resolve_symbol_lines(source, path, "process_data").unwrap();
assert_eq!(start, 5);
assert_eq!(end, 9);
}
// A `Class::method` query resolves a Python method.
#[test]
fn resolve_python_class_method() {
let source = br#"
class Repository:
def __init__(self, path):
self.path = path
def open(self):
return True
"#;
let path = Path::new("repo.py");
let (start, end) = resolve_symbol_lines(source, path, "Repository::open").unwrap();
assert_eq!(start, 6);
assert_eq!(end, 7);
}
// A Go function is resolved by name; only built with the `lang-go` feature.
#[test]
#[cfg(feature = "lang-go")]
fn resolve_go_function() {
let source = br#"package main
func helper() bool {
return true
}
func processData(items []int) []int {
result := make([]int, 0)
for _, item := range items {
result = append(result, item*2)
}
return result
}
"#;
let path = Path::new("main.go");
let (start, end) = resolve_symbol_lines(source, path, "processData").unwrap();
assert_eq!(start, 7);
assert_eq!(end, 13);
}
// An unknown symbol name yields `SymbolNotFound`.
#[test]
fn resolve_symbol_not_found() {
let source = br#"
fn main() {}
"#;
let path = Path::new("test.rs");
let err = resolve_symbol_lines(source, path, "nonexistent").unwrap_err();
assert!(matches!(err, SymbolResolveError::SymbolNotFound(_)));
}
// A path with an unrecognised extension yields `UnsupportedLanguage`.
#[test]
fn resolve_unsupported_extension() {
let source = b"some content";
let path = Path::new("test.xyz");
let err = resolve_symbol_lines(source, path, "main").unwrap_err();
assert!(matches!(err, SymbolResolveError::UnsupportedLanguage(_)));
}
// A multi-line range is returned with every line's trailing newline.
#[test]
fn extract_line_range_basic() {
let source = b"line 1\nline 2\nline 3\nline 4\nline 5\n";
let result = extract_line_range(source, 2, 4);
assert_eq!(result, b"line 2\nline 3\nline 4\n");
}
// `start == end` returns exactly one line.
#[test]
fn extract_line_range_single_line() {
let source = b"line 1\nline 2\nline 3\n";
let result = extract_line_range(source, 2, 2);
assert_eq!(result, b"line 2\n");
}
// A JavaScript `function` declaration is resolved by name.
#[test]
fn resolve_js_function_declaration() {
let source = br#"
function helper() {
return true;
}
function processData(items) {
return items.map(x => x * 2);
}
"#;
let path = Path::new("main.js");
let (start, end) = resolve_symbol_lines(source, path, "processData").unwrap();
assert_eq!(start, 6);
assert_eq!(end, 8);
}
// A `const` bound to an arrow function is resolved by the constant's name.
#[test]
fn resolve_js_arrow_function_const() {
let source = br#"
const helper = () => true;
const processData = (items) => {
return items.map(x => x * 2);
};
"#;
let path = Path::new("utils.js");
let (start, end) = resolve_symbol_lines(source, path, "processData").unwrap();
assert_eq!(start, 4);
assert_eq!(end, 6);
}
// An arrow function stored as an object-literal property is resolved by the
// property name. Only loose bounds on the span are asserted.
#[test]
fn resolve_typescript_object_literal_property_arrow_function() {
let source = br#"
export const db = {
query: async (sql: string) => {
return [];
},
insert: async (table: string, data: Record<string, any>) => {
const keys = Object.keys(data);
return keys;
},
};
"#;
let path = Path::new("db.ts");
let (start, end) = resolve_symbol_lines(source, path, "insert").unwrap();
assert!((5..=7).contains(&start), "got start={start}");
assert!(end > start && end <= 10, "got end={end}");
}
// A TypeScript function with type annotations is resolved by name.
#[test]
fn resolve_typescript_function() {
let source = br#"
function helper(): boolean {
return true;
}
function processData(items: number[]): number[] {
return items.map(x => x * 2);
}
"#;
let path = Path::new("main.ts");
let (start, end) = resolve_symbol_lines(source, path, "processData").unwrap();
assert_eq!(start, 6);
assert_eq!(end, 8);
}
// An unqualified query returns every same-named method, in source order,
// each tagged with the type of its own `impl` block.
#[test]
fn resolve_all_returns_multiple_matches() {
let source = br#"
impl Foo {
fn do_thing(&self) {}
}
impl Bar {
fn do_thing(&self) {}
}
"#;
let path = Path::new("test.rs");
let results = resolve_all_symbols(source, path, "do_thing").unwrap();
assert_eq!(results.len(), 2);
assert_eq!(results[0].parent_name.as_deref(), Some("Foo"));
assert_eq!(results[1].parent_name.as_deref(), Some("Bar"));
}
// `extract_definitions` classifies Rust items by kind, reports their line
// spans, and sets a parent only for the method inside the `impl` block
// (items nested in the `mod` are expected to have no parent).
#[test]
fn extract_definitions_reports_rust_taxonomy_parent_scopes_and_ranges() {
let source = br#"const LIMIT: usize = 10;
pub mod outer {
pub struct Widget {
pub id: u64,
}
pub enum Mode {
Fast,
Slow,
}
pub trait Runner {
fn run(&self);
}
pub type WidgetResult<T> = Result<T, Error>;
impl Widget {
pub fn build(id: u64) -> Self {
Self { id }
}
}
}
"#;
let defs = extract_definitions(source, Path::new("lib.rs")).unwrap();
assert_definition(&defs, "LIMIT", DefinitionKind::ConstDecl, 1, 1, None);
assert_definition(&defs, "outer", DefinitionKind::Module, 2, 23, None);
assert_definition(&defs, "Widget", DefinitionKind::Type, 3, 5, None);
assert_definition(&defs, "Mode", DefinitionKind::EnumDef, 7, 10, None);
assert_definition(&defs, "Runner", DefinitionKind::Trait, 12, 14, None);
assert_definition(
&defs,
"WidgetResult",
DefinitionKind::TypeAlias,
16,
16,
None,
);
assert_definition(
&defs,
"build",
DefinitionKind::Function,
19,
21,
Some("Widget"),
);
}
// `extract_definitions` classifies TypeScript declarations by kind; the class
// method gets the class as parent, and `const` bindings are `Function` or
// `ConstDecl` depending on their initializer.
#[test]
fn extract_definitions_reports_typescript_taxonomy_parent_scopes_and_ranges() {
let source = br#"interface Service {
run(): void;
}
type Handler = (value: string) => void;
enum Status {
Ready,
Done,
}
class Controller {
start(): void {
handle("start");
}
}
export const handle = (value: string): void => {
console.log(value);
};
export const settings = { retry: 2 };
"#;
let defs = extract_definitions(source, Path::new("controller.ts")).unwrap();
assert_definition(&defs, "Service", DefinitionKind::Interface, 1, 3, None);
assert_definition(&defs, "Handler", DefinitionKind::TypeAlias, 5, 5, None);
assert_definition(&defs, "Status", DefinitionKind::EnumDef, 7, 10, None);
assert_definition(&defs, "Controller", DefinitionKind::Class, 12, 16, None);
assert_definition(
&defs,
"start",
DefinitionKind::Function,
13,
15,
Some("Controller"),
);
assert_definition(&defs, "handle", DefinitionKind::Function, 18, 20, None);
assert_definition(&defs, "settings", DefinitionKind::ConstDecl, 22, 22, None);
}
// Syntactically broken input is expected to surface as `ParseFailed`.
#[test]
fn extract_definitions_rejects_parse_error_trees() {
let err =
extract_definitions(b"fn broken( -> usize { 1 }", Path::new("broken.rs")).unwrap_err();
assert!(matches!(err, SymbolResolveError::ParseFailed));
}
// Test helper: asserts that `defs` contains at least one definition matching
// all of name, kind, line span and parent; on failure the message lists
// every extracted definition.
fn assert_definition(
defs: &[Definition],
name: &str,
kind: DefinitionKind,
start_line: u32,
end_line: u32,
parent_name: Option<&str>,
) {
assert!(
defs.iter().any(|def| {
def.name == name
&& def.kind == kind
&& def.start_line == start_line
&& def.end_line == end_line
&& def.parent_name.as_deref() == parent_name
}),
"expected {name:?} {kind:?} lines {start_line}-{end_line} parent {parent_name:?}, got: {defs:?}"
);
}
}