use anyhow::{Context, Result};
use streaming_iterator::StreamingIterator;
use tree_sitter::{Parser, Query, QueryCursor};
use crate::models::{Language, SearchResult, Span, SymbolKind};
use crate::parsers::{DependencyExtractor, ImportInfo};
pub fn parse(path: &str, source: &str) -> Result<Vec<SearchResult>> {
let mut parser = Parser::new();
let language = tree_sitter_rust::LANGUAGE;
parser
.set_language(&language.into())
.context("Failed to set Rust language")?;
let tree = parser
.parse(source, None)
.context("Failed to parse Rust source")?;
let root_node = tree.root_node();
let mut symbols = Vec::new();
symbols.extend(extract_functions(source, &root_node)?);
symbols.extend(extract_structs(source, &root_node)?);
symbols.extend(extract_enums(source, &root_node)?);
symbols.extend(extract_traits(source, &root_node)?);
symbols.extend(extract_impls(source, &root_node)?);
symbols.extend(extract_constants(source, &root_node)?);
symbols.extend(extract_statics(source, &root_node)?);
symbols.extend(extract_local_variables(source, &root_node)?);
symbols.extend(extract_modules(source, &root_node)?);
symbols.extend(extract_type_aliases(source, &root_node)?);
symbols.extend(extract_macros(source, &root_node)?);
symbols.extend(extract_attributes(source, &root_node)?);
for symbol in &mut symbols {
symbol.path = path.to_string();
symbol.lang = Language::Rust;
}
Ok(symbols)
}
fn extract_functions(source: &str, root: &tree_sitter::Node) -> Result<Vec<SearchResult>> {
let language = tree_sitter_rust::LANGUAGE;
let query_str = r#"
(function_item
name: (identifier) @name) @function
"#;
let query = Query::new(&language.into(), query_str)
.context("Failed to create function query")?;
extract_symbols(source, root, &query, SymbolKind::Function, None)
}
fn extract_structs(source: &str, root: &tree_sitter::Node) -> Result<Vec<SearchResult>> {
let language = tree_sitter_rust::LANGUAGE;
let query_str = r#"
(struct_item
name: (type_identifier) @name) @struct
"#;
let query = Query::new(&language.into(), query_str)
.context("Failed to create struct query")?;
extract_symbols(source, root, &query, SymbolKind::Struct, None)
}
fn extract_enums(source: &str, root: &tree_sitter::Node) -> Result<Vec<SearchResult>> {
let language = tree_sitter_rust::LANGUAGE;
let query_str = r#"
(enum_item
name: (type_identifier) @name) @enum
"#;
let query = Query::new(&language.into(), query_str)
.context("Failed to create enum query")?;
extract_symbols(source, root, &query, SymbolKind::Enum, None)
}
fn extract_traits(source: &str, root: &tree_sitter::Node) -> Result<Vec<SearchResult>> {
let language = tree_sitter_rust::LANGUAGE;
let query_str = r#"
(trait_item
name: (type_identifier) @name) @trait
"#;
let query = Query::new(&language.into(), query_str)
.context("Failed to create trait query")?;
extract_symbols(source, root, &query, SymbolKind::Trait, None)
}
fn extract_impls(source: &str, root: &tree_sitter::Node) -> Result<Vec<SearchResult>> {
let language = tree_sitter_rust::LANGUAGE;
let query_str = r#"
(impl_item
type: (type_identifier) @impl_name
body: (declaration_list
(function_item
name: (identifier) @method_name))) @impl
"#;
let query = Query::new(&language.into(), query_str)
.context("Failed to create impl query")?;
let mut cursor = QueryCursor::new();
let mut matches = cursor.matches(&query, *root, source.as_bytes());
let mut symbols = Vec::new();
while let Some(match_) = matches.next() {
let mut impl_name = None;
let mut method_name = None;
let mut method_node = None;
for capture in match_.captures {
let capture_name: &str = &query.capture_names()[capture.index as usize];
match capture_name {
"impl_name" => {
impl_name = Some(capture.node.utf8_text(source.as_bytes()).unwrap_or("").to_string());
}
"method_name" => {
method_name = Some(capture.node.utf8_text(source.as_bytes()).unwrap_or("").to_string());
let mut current = capture.node;
while let Some(parent) = current.parent() {
if parent.kind() == "function_item" {
method_node = Some(parent);
break;
}
current = parent;
}
}
_ => {}
}
}
if let (Some(impl_name), Some(method_name), Some(node)) = (impl_name, method_name, method_node) {
let scope = format!("impl {}", impl_name);
let span = node_to_span(&node);
let preview = extract_preview(source, &span);
symbols.push(SearchResult::new(
String::new(), Language::Rust,
SymbolKind::Method,
Some(method_name),
span,
Some(scope),
preview,
));
}
}
Ok(symbols)
}
fn extract_constants(source: &str, root: &tree_sitter::Node) -> Result<Vec<SearchResult>> {
let language = tree_sitter_rust::LANGUAGE;
let query_str = r#"
(const_item
name: (identifier) @name) @const
"#;
let query = Query::new(&language.into(), query_str)
.context("Failed to create const query")?;
extract_symbols(source, root, &query, SymbolKind::Constant, None)
}
fn extract_statics(source: &str, root: &tree_sitter::Node) -> Result<Vec<SearchResult>> {
let language = tree_sitter_rust::LANGUAGE;
let query_str = r#"
(static_item
name: (identifier) @name) @static
"#;
let query = Query::new(&language.into(), query_str)
.context("Failed to create static query")?;
extract_symbols(source, root, &query, SymbolKind::Variable, None)
}
fn extract_local_variables(source: &str, root: &tree_sitter::Node) -> Result<Vec<SearchResult>> {
let language = tree_sitter_rust::LANGUAGE;
let query_str = r#"
(let_declaration
pattern: (identifier) @name) @let
"#;
let query = Query::new(&language.into(), query_str)
.context("Failed to create let declaration query")?;
extract_symbols(source, root, &query, SymbolKind::Variable, None)
}
fn extract_modules(source: &str, root: &tree_sitter::Node) -> Result<Vec<SearchResult>> {
let language = tree_sitter_rust::LANGUAGE;
let query_str = r#"
(mod_item
name: (identifier) @name) @module
"#;
let query = Query::new(&language.into(), query_str)
.context("Failed to create module query")?;
extract_symbols(source, root, &query, SymbolKind::Module, None)
}
fn extract_type_aliases(source: &str, root: &tree_sitter::Node) -> Result<Vec<SearchResult>> {
let language = tree_sitter_rust::LANGUAGE;
let query_str = r#"
(type_item
name: (type_identifier) @name) @type
"#;
let query = Query::new(&language.into(), query_str)
.context("Failed to create type query")?;
extract_symbols(source, root, &query, SymbolKind::Type, None)
}
fn extract_macros(source: &str, root: &tree_sitter::Node) -> Result<Vec<SearchResult>> {
let language = tree_sitter_rust::LANGUAGE;
let query_str = r#"
(macro_definition
name: (identifier) @name) @macro
"#;
let query = Query::new(&language.into(), query_str)
.context("Failed to create macro query")?;
extract_symbols(source, root, &query, SymbolKind::Macro, None)
}
fn extract_attributes(source: &str, root: &tree_sitter::Node) -> Result<Vec<SearchResult>> {
let language = tree_sitter_rust::LANGUAGE;
let mut symbols = Vec::new();
let func_query_str = r#"
(function_item
name: (identifier) @name) @function
"#;
let func_query = Query::new(&language.into(), func_query_str)
.context("Failed to create function query")?;
let mut cursor = QueryCursor::new();
let mut matches = cursor.matches(&func_query, *root, source.as_bytes());
while let Some(match_) = matches.next() {
let mut name = None;
let mut func_node = None;
for capture in match_.captures {
let capture_name: &str = &func_query.capture_names()[capture.index as usize];
match capture_name {
"name" => {
name = Some(capture.node.utf8_text(source.as_bytes()).unwrap_or("").to_string());
}
"function" => {
func_node = Some(capture.node);
}
_ => {}
}
}
if let (Some(name), Some(func_node)) = (name, func_node) {
let mut has_proc_macro_attr = false;
if let Some(parent) = func_node.parent() {
let mut func_index = None;
for i in 0..parent.child_count() {
if let Some(child) = parent.child(i) {
if child.id() == func_node.id() {
func_index = Some(i);
break;
}
}
}
if let Some(func_idx) = func_index {
for i in (0..func_idx).rev() {
if let Some(child) = parent.child(i) {
if child.kind() == "attribute_item" {
let attr_text = child.utf8_text(source.as_bytes()).unwrap_or("");
if attr_text.contains("proc_macro_attribute") {
has_proc_macro_attr = true;
}
} else if !child.kind().contains("comment") && child.kind() != "line_comment" {
break;
}
}
}
}
}
if has_proc_macro_attr {
let span = node_to_span(&func_node);
let preview = extract_preview(source, &span);
symbols.push(SearchResult::new(
String::new(),
Language::Rust,
SymbolKind::Attribute,
Some(name),
span,
None,
preview,
));
}
}
}
let attr_query_str = r#"
(attribute_item
(attribute
(identifier) @attr_name)) @attr
"#;
let attr_query = Query::new(&language.into(), attr_query_str)
.context("Failed to create attribute use query")?;
let mut cursor = QueryCursor::new();
let mut matches = cursor.matches(&attr_query, *root, source.as_bytes());
while let Some(match_) = matches.next() {
let mut attr_name = None;
let mut attr_node = None;
for capture in match_.captures {
let capture_name: &str = &attr_query.capture_names()[capture.index as usize];
match capture_name {
"attr_name" => {
attr_name = Some(capture.node.utf8_text(source.as_bytes()).unwrap_or("").to_string());
}
"attr" => {
attr_node = Some(capture.node);
}
_ => {}
}
}
if let (Some(name), Some(node)) = (attr_name, attr_node) {
let span = node_to_span(&node);
let preview = extract_preview(source, &span);
symbols.push(SearchResult::new(
String::new(),
Language::Rust,
SymbolKind::Attribute,
Some(name),
span,
None,
preview,
));
}
}
Ok(symbols)
}
fn extract_symbols(
source: &str,
root: &tree_sitter::Node,
query: &Query,
kind: SymbolKind,
scope: Option<String>,
) -> Result<Vec<SearchResult>> {
let mut cursor = QueryCursor::new();
let mut matches = cursor.matches(query, *root, source.as_bytes());
let mut symbols = Vec::new();
while let Some(match_) = matches.next() {
let mut name = None;
let mut full_node = None;
for capture in match_.captures {
let capture_name: &str = &query.capture_names()[capture.index as usize];
if capture_name == "name" {
name = Some(capture.node.utf8_text(source.as_bytes()).unwrap_or("").to_string());
} else {
full_node = Some(capture.node);
}
}
if let (Some(name), Some(node)) = (name, full_node) {
let span = node_to_span(&node);
let preview = extract_preview(source, &span);
symbols.push(SearchResult::new(
String::new(), Language::Rust,
kind.clone(),
Some(name),
span,
scope.clone(),
preview,
));
}
}
Ok(symbols)
}
fn node_to_span(node: &tree_sitter::Node) -> Span {
let start = node.start_position();
let end = node.end_position();
Span::new(
start.row + 1, start.column,
end.row + 1,
end.column,
)
}
fn extract_preview(source: &str, span: &Span) -> String {
let lines: Vec<&str> = source.lines().collect();
let start_idx = (span.start_line - 1) as usize; let end_idx = (start_idx + 7).min(lines.len());
lines[start_idx..end_idx].join("\n")
}
pub struct RustDependencyExtractor;
impl DependencyExtractor for RustDependencyExtractor {
fn extract_dependencies(source: &str) -> Result<Vec<ImportInfo>> {
let mut parser = Parser::new();
let language = tree_sitter_rust::LANGUAGE;
parser
.set_language(&language.into())
.context("Failed to set Rust language")?;
let tree = parser
.parse(source, None)
.context("Failed to parse Rust source")?;
let root_node = tree.root_node();
let mut imports = Vec::new();
imports.extend(extract_use_declarations(source, &root_node)?);
imports.extend(extract_mod_items(source, &root_node)?);
imports.extend(extract_extern_crates(source, &root_node)?);
Ok(imports)
}
}
fn extract_use_declarations(source: &str, root: &tree_sitter::Node) -> Result<Vec<ImportInfo>> {
let language = tree_sitter_rust::LANGUAGE;
let query_str = r#"
(use_declaration) @use
"#;
let query = Query::new(&language.into(), query_str)
.context("Failed to create use declaration query")?;
let mut cursor = QueryCursor::new();
let mut matches = cursor.matches(&query, *root, source.as_bytes());
let mut imports = Vec::new();
while let Some(match_) = matches.next() {
for capture in match_.captures {
let node = capture.node;
let text = node.utf8_text(source.as_bytes()).unwrap_or("");
let line_number = node.start_position().row + 1;
let path_info = parse_rust_use_declaration(text);
for (path, symbols) in path_info {
let import_type = classify_rust_import(&path);
imports.push(ImportInfo {
imported_path: path,
import_type,
line_number,
imported_symbols: symbols,
});
}
}
}
Ok(imports)
}
fn extract_mod_items(source: &str, root: &tree_sitter::Node) -> Result<Vec<ImportInfo>> {
let language = tree_sitter_rust::LANGUAGE;
let query_str = r#"
(mod_item
name: (identifier) @name) @mod
"#;
let query = Query::new(&language.into(), query_str)
.context("Failed to create mod item query")?;
let mut cursor = QueryCursor::new();
let mut matches = cursor.matches(&query, *root, source.as_bytes());
let mut imports = Vec::new();
while let Some(match_) = matches.next() {
let mut name = None;
let mut mod_node = None;
for capture in match_.captures {
let capture_name: &str = &query.capture_names()[capture.index as usize];
match capture_name {
"name" => {
name = Some(capture.node.utf8_text(source.as_bytes()).unwrap_or("").to_string());
}
"mod" => {
mod_node = Some(capture.node);
}
_ => {}
}
}
if let (Some(name), Some(node)) = (name, mod_node) {
let has_body = node.child_by_field_name("body").is_some();
if !has_body {
let line_number = node.start_position().row + 1;
imports.push(ImportInfo {
imported_path: name,
import_type: crate::models::ImportType::Internal,
line_number,
imported_symbols: None,
});
}
}
}
Ok(imports)
}
fn extract_extern_crates(source: &str, root: &tree_sitter::Node) -> Result<Vec<ImportInfo>> {
let language = tree_sitter_rust::LANGUAGE;
let query_str = r#"
(extern_crate_declaration
name: (identifier) @name) @extern
"#;
let query = Query::new(&language.into(), query_str)
.context("Failed to create extern crate query")?;
let mut cursor = QueryCursor::new();
let mut matches = cursor.matches(&query, *root, source.as_bytes());
let mut imports = Vec::new();
while let Some(match_) = matches.next() {
let mut name = None;
let mut extern_node = None;
for capture in match_.captures {
let capture_name: &str = &query.capture_names()[capture.index as usize];
match capture_name {
"name" => {
name = Some(capture.node.utf8_text(source.as_bytes()).unwrap_or("").to_string());
}
"extern" => {
extern_node = Some(capture.node);
}
_ => {}
}
}
if let (Some(name), Some(node)) = (name, extern_node) {
let line_number = node.start_position().row + 1;
let import_type = classify_rust_import(&name);
imports.push(ImportInfo {
imported_path: name,
import_type,
line_number,
imported_symbols: None,
});
}
}
Ok(imports)
}
fn classify_rust_import(path: &str) -> crate::models::ImportType {
use crate::models::ImportType;
if path.starts_with("std::") || path.starts_with("core::") || path.starts_with("alloc::") {
ImportType::Stdlib
} else if path.starts_with("crate::") || path.starts_with("super::") || path.starts_with("self::") {
ImportType::Internal
} else {
ImportType::External
}
}
fn parse_rust_use_declaration(text: &str) -> Vec<(String, Option<Vec<String>>)> {
let text = text.trim()
.strip_prefix("pub(crate)").unwrap_or(text)
.trim()
.strip_prefix("pub(super)").unwrap_or(text)
.trim()
.strip_prefix("pub").unwrap_or(text)
.trim()
.strip_prefix("use").unwrap_or(text)
.trim()
.strip_suffix(";").unwrap_or(text)
.trim();
if text.contains('{') {
if let Some(idx) = text.find('{') {
let base_path = text[..idx].trim_end_matches("::").to_string();
if let Some(end) = text.find('}') {
let symbols_str = &text[idx + 1..end];
let symbols: Vec<String> = symbols_str
.split(',')
.map(|s| {
let trimmed = s.trim();
if let Some(as_idx) = trimmed.find(" as ") {
trimmed[..as_idx].trim().to_string()
} else {
trimmed.to_string()
}
})
.filter(|s| !s.is_empty() && s != "*")
.collect();
if !symbols.is_empty() {
return vec![(base_path, Some(symbols))];
}
}
}
}
let path = if let Some(as_idx) = text.find(" as ") {
text[..as_idx].trim().to_string()
} else {
text.to_string()
};
vec![(path, None)]
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_parse_function() {
let source = r#"
fn hello_world() {
println!("Hello, world!");
}
"#;
let symbols = parse("test.rs", source).unwrap();
assert_eq!(symbols.len(), 1);
assert_eq!(symbols[0].symbol.as_deref(), Some("hello_world"));
assert!(matches!(symbols[0].kind, SymbolKind::Function));
}
#[test]
fn test_parse_struct() {
let source = r#"
struct User {
name: String,
age: u32,
}
"#;
let symbols = parse("test.rs", source).unwrap();
assert_eq!(symbols.len(), 1);
assert_eq!(symbols[0].symbol.as_deref(), Some("User"));
assert!(matches!(symbols[0].kind, SymbolKind::Struct));
}
#[test]
fn test_parse_impl() {
let source = r#"
struct User {
name: String,
}
impl User {
fn new(name: String) -> Self {
User { name }
}
fn get_name(&self) -> &str {
&self.name
}
}
"#;
let symbols = parse("test.rs", source).unwrap();
assert!(symbols.len() >= 3);
let method_symbols: Vec<_> = symbols.iter()
.filter(|s| matches!(s.kind, SymbolKind::Method))
.collect();
assert_eq!(method_symbols.len(), 2);
assert!(method_symbols.iter().any(|s| s.symbol.as_deref() == Some("new")));
assert!(method_symbols.iter().any(|s| s.symbol.as_deref() == Some("get_name")));
}
#[test]
fn test_parse_enum() {
let source = r#"
enum Status {
Active,
Inactive,
}
"#;
let symbols = parse("test.rs", source).unwrap();
assert_eq!(symbols.len(), 1);
assert_eq!(symbols[0].symbol.as_deref(), Some("Status"));
assert!(matches!(symbols[0].kind, SymbolKind::Enum));
}
#[test]
fn test_parse_trait() {
let source = r#"
trait Drawable {
fn draw(&self);
}
"#;
let symbols = parse("test.rs", source).unwrap();
assert_eq!(symbols.len(), 1);
assert_eq!(symbols[0].symbol.as_deref(), Some("Drawable"));
assert!(matches!(symbols[0].kind, SymbolKind::Trait));
}
#[test]
fn test_parse_multiple_symbols() {
let source = r#"
const MAX_SIZE: usize = 100;
struct Config {
size: usize,
}
fn create_config() -> Config {
Config { size: MAX_SIZE }
}
"#;
let symbols = parse("test.rs", source).unwrap();
assert_eq!(symbols.len(), 3);
let kinds: Vec<&SymbolKind> = symbols.iter().map(|s| &s.kind).collect();
assert!(kinds.contains(&&SymbolKind::Constant));
assert!(kinds.contains(&&SymbolKind::Struct));
assert!(kinds.contains(&&SymbolKind::Function));
}
#[test]
fn test_local_variables_included() {
let source = r#"
fn calculate(input: i32) -> i32 {
let local_var = input * 2;
let result = local_var + 10;
result
}
struct Calculator;
impl Calculator {
fn compute(&self, value: i32) -> i32 {
let temp = value * 3;
let mut final_value = temp + 5;
final_value += 1;
final_value
}
}
"#;
let symbols = parse("test.rs", source).unwrap();
let variables: Vec<_> = symbols.iter()
.filter(|s| matches!(s.kind, SymbolKind::Variable))
.collect();
assert!(variables.iter().any(|v| v.symbol.as_deref() == Some("local_var")));
assert!(variables.iter().any(|v| v.symbol.as_deref() == Some("result")));
assert!(variables.iter().any(|v| v.symbol.as_deref() == Some("temp")));
assert!(variables.iter().any(|v| v.symbol.as_deref() == Some("final_value")));
}
#[test]
fn test_static_variables() {
let source = r#"
static GLOBAL_COUNTER: std::sync::atomic::AtomicUsize = std::sync::atomic::AtomicUsize::new(0);
static mut MUTABLE_GLOBAL: i32 = 0;
const MAX_SIZE: usize = 100;
fn increment() {
GLOBAL_COUNTER.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
}
"#;
let symbols = parse("test.rs", source).unwrap();
let statics: Vec<_> = symbols.iter()
.filter(|s| matches!(s.kind, SymbolKind::Variable))
.collect();
let constants: Vec<_> = symbols.iter()
.filter(|s| matches!(s.kind, SymbolKind::Constant))
.collect();
assert!(statics.iter().any(|v| v.symbol.as_deref() == Some("GLOBAL_COUNTER")));
assert!(statics.iter().any(|v| v.symbol.as_deref() == Some("MUTABLE_GLOBAL")));
assert!(constants.iter().any(|c| c.symbol.as_deref() == Some("MAX_SIZE")));
}
#[test]
fn test_macros() {
let source = r#"
macro_rules! say_hello {
() => {
println!("Hello!");
};
}
macro_rules! vec_of_strings {
($($x:expr),*) => {
vec![$($x.to_string()),*]
};
}
fn main() {
say_hello!();
}
"#;
let symbols = parse("test.rs", source).unwrap();
let macros: Vec<_> = symbols.iter()
.filter(|s| matches!(s.kind, SymbolKind::Macro))
.collect();
assert!(macros.iter().any(|m| m.symbol.as_deref() == Some("say_hello")));
assert!(macros.iter().any(|m| m.symbol.as_deref() == Some("vec_of_strings")));
assert_eq!(macros.len(), 2);
}
#[test]
fn test_attribute_proc_macros() {
let source = r#"
use proc_macro::TokenStream;
#[proc_macro_attribute]
pub fn test(attr: TokenStream, item: TokenStream) -> TokenStream {
item
}
#[proc_macro_attribute]
pub fn route(attr: TokenStream, item: TokenStream) -> TokenStream {
item
}
// Regular function - should NOT be captured
pub fn helper() {}
"#;
let symbols = parse("test.rs", source).unwrap();
let attributes: Vec<_> = symbols.iter()
.filter(|s| matches!(s.kind, SymbolKind::Attribute))
.collect();
assert!(attributes.iter().any(|a| a.symbol.as_deref() == Some("test")));
assert!(attributes.iter().any(|a| a.symbol.as_deref() == Some("route")));
assert!(!attributes.iter().any(|a| a.symbol.as_deref() == Some("helper")));
assert_eq!(attributes.len(), 4);
}
#[test]
fn test_attribute_uses() {
let source = r#"
#[test]
fn test_something() {
assert_eq!(1, 1);
}
#[test]
#[should_panic]
fn test_panic() {
panic!("expected");
}
#[derive(Debug, Clone)]
struct MyStruct {
field: i32
}
#[cfg(test)]
mod tests {
#[test]
fn nested_test() {}
}
"#;
let symbols = parse("test.rs", source).unwrap();
let attributes: Vec<_> = symbols.iter()
.filter(|s| matches!(s.kind, SymbolKind::Attribute))
.collect();
assert!(attributes.iter().any(|a| a.symbol.as_deref() == Some("test")));
assert!(attributes.iter().any(|a| a.symbol.as_deref() == Some("should_panic")));
assert!(attributes.iter().any(|a| a.symbol.as_deref() == Some("derive")));
assert!(attributes.iter().any(|a| a.symbol.as_deref() == Some("cfg")));
assert_eq!(attributes.len(), 6);
}
#[test]
fn test_extract_dependencies_use_declarations() {
let source = r#"
use std::collections::HashMap;
use crate::models::{Language, SearchResult};
use super::utils;
use anyhow::Result;
"#;
let deps = RustDependencyExtractor::extract_dependencies(source).unwrap();
assert_eq!(deps.len(), 4);
let std_import = deps.iter().find(|d| d.imported_path == "std::collections::HashMap").unwrap();
assert!(matches!(std_import.import_type, crate::models::ImportType::Stdlib));
let crate_import = deps.iter().find(|d| d.imported_path == "crate::models").unwrap();
assert!(matches!(crate_import.import_type, crate::models::ImportType::Internal));
assert!(crate_import.imported_symbols.is_some());
let symbols = crate_import.imported_symbols.as_ref().unwrap();
assert_eq!(symbols.len(), 2);
assert!(symbols.contains(&"Language".to_string()));
assert!(symbols.contains(&"SearchResult".to_string()));
let super_import = deps.iter().find(|d| d.imported_path == "super::utils").unwrap();
assert!(matches!(super_import.import_type, crate::models::ImportType::Internal));
let external_import = deps.iter().find(|d| d.imported_path == "anyhow::Result").unwrap();
assert!(matches!(external_import.import_type, crate::models::ImportType::External));
}
#[test]
fn test_extract_dependencies_mod_declarations() {
let source = r#"
mod parser;
mod utils;
mod inline {
fn test() {}
}
"#;
let deps = RustDependencyExtractor::extract_dependencies(source).unwrap();
assert_eq!(deps.len(), 2);
assert!(deps.iter().any(|d| d.imported_path == "parser"));
assert!(deps.iter().any(|d| d.imported_path == "utils"));
assert!(deps.iter().all(|d| matches!(d.import_type, crate::models::ImportType::Internal)));
}
#[test]
fn test_extract_dependencies_extern_crate() {
let source = r#"
extern crate serde;
extern crate serde_json;
"#;
let deps = RustDependencyExtractor::extract_dependencies(source).unwrap();
assert_eq!(deps.len(), 2);
assert!(deps.iter().any(|d| d.imported_path == "serde"));
assert!(deps.iter().any(|d| d.imported_path == "serde_json"));
assert!(deps.iter().all(|d| matches!(d.import_type, crate::models::ImportType::External)));
}
#[test]
fn test_parse_use_with_aliases() {
let source = r#"
use std::io::Result as IoResult;
use std::collections::{HashMap as Map, HashSet};
"#;
let deps = RustDependencyExtractor::extract_dependencies(source).unwrap();
let io_import = deps.iter().find(|d| d.imported_path == "std::io::Result").unwrap();
assert!(matches!(io_import.import_type, crate::models::ImportType::Stdlib));
let collections_import = deps.iter().find(|d| d.imported_path == "std::collections").unwrap();
let symbols = collections_import.imported_symbols.as_ref().unwrap();
assert_eq!(symbols.len(), 2);
assert!(symbols.contains(&"HashMap".to_string()));
assert!(symbols.contains(&"HashSet".to_string()));
}
#[test]
fn test_classify_rust_imports() {
use crate::models::ImportType;
assert!(matches!(classify_rust_import("std::collections::HashMap"), ImportType::Stdlib));
assert!(matches!(classify_rust_import("core::ptr"), ImportType::Stdlib));
assert!(matches!(classify_rust_import("alloc::vec::Vec"), ImportType::Stdlib));
assert!(matches!(classify_rust_import("crate::models::Language"), ImportType::Internal));
assert!(matches!(classify_rust_import("super::utils"), ImportType::Internal));
assert!(matches!(classify_rust_import("self::helper"), ImportType::Internal));
assert!(matches!(classify_rust_import("serde::Serialize"), ImportType::External));
assert!(matches!(classify_rust_import("anyhow::Result"), ImportType::External));
assert!(matches!(classify_rust_import("tokio::runtime"), ImportType::External));
}
}
fn find_crate_root(start_path: &str) -> Option<String> {
let path = std::path::Path::new(start_path);
let mut current = path.parent()?;
loop {
let cargo_toml = current.join("Cargo.toml");
if cargo_toml.exists() {
return Some(current.to_string_lossy().to_string());
}
if current.ends_with("src") {
if let Some(parent) = current.parent() {
return Some(parent.to_string_lossy().to_string());
}
}
current = match current.parent() {
Some(p) if p.as_os_str().is_empty() => return None,
Some(p) => p,
None => return None,
};
}
}
pub fn resolve_rust_use_to_path(
import_path: &str,
current_file_path: Option<&str>,
_project_root: Option<&str>,
) -> Option<String> {
if !import_path.starts_with("crate::")
&& !import_path.starts_with("super::")
&& !import_path.starts_with("self::") {
if import_path.contains("::") {
return None; }
}
let current_file = current_file_path?;
let current_path = std::path::Path::new(current_file);
let crate_root = find_crate_root(current_file)?;
let crate_root_path = std::path::Path::new(&crate_root);
if import_path.starts_with("crate::") {
let module_path = import_path.strip_prefix("crate::").unwrap();
let parts: Vec<&str> = module_path.split("::").collect();
let src_root = crate_root_path.join("src");
resolve_rust_module_path(&src_root, &parts)
} else if import_path.starts_with("super::") {
let module_path = import_path.strip_prefix("super::").unwrap();
let parts: Vec<&str> = module_path.split("::").collect();
let current_dir = if current_path.file_name().unwrap() == "mod.rs" {
current_path.parent()?.parent()?
} else {
current_path.parent()?
};
resolve_rust_module_path(current_dir, &parts)
} else if import_path.starts_with("self::") {
let module_path = import_path.strip_prefix("self::").unwrap();
let parts: Vec<&str> = module_path.split("::").collect();
let current_dir = if current_path.file_name().unwrap() == "mod.rs" {
current_path.parent()?
} else {
current_path.parent()?
};
resolve_rust_module_path(current_dir, &parts)
} else {
let current_dir = current_path.parent()?;
let module_file = current_dir.join(format!("{}.rs", import_path));
let module_dir = current_dir.join(import_path).join("mod.rs");
if module_file.exists() {
Some(module_file.to_string_lossy().to_string())
} else if module_dir.exists() {
Some(module_dir.to_string_lossy().to_string())
} else {
Some(module_file.to_string_lossy().to_string())
}
}
}
fn resolve_rust_module_path(base_dir: &std::path::Path, parts: &[&str]) -> Option<String> {
if parts.is_empty() {
return None;
}
let mut current_path = base_dir.to_path_buf();
for (i, part) in parts.iter().enumerate() {
if i == parts.len() - 1 {
let file_path = current_path.join(format!("{}.rs", part));
let mod_path = current_path.join(part).join("mod.rs");
log::trace!("Checking Rust module path: {}", file_path.display());
log::trace!("Checking Rust module path: {}", mod_path.display());
if file_path.exists() {
return Some(file_path.to_string_lossy().to_string());
} else if mod_path.exists() {
return Some(mod_path.to_string_lossy().to_string());
} else {
return Some(file_path.to_string_lossy().to_string());
}
} else {
current_path = current_path.join(part);
}
}
None
}
#[cfg(test)]
mod path_resolution_tests {
use super::*;
#[test]
fn test_resolve_crate_import() {
let result = resolve_rust_use_to_path(
"crate::models",
Some("/home/user/project/src/main.rs"),
Some("/home/user/project"),
);
assert!(result.is_some());
let path = result.unwrap();
assert!(path.contains("models.rs") || path.contains("models/mod.rs"));
}
#[test]
fn test_resolve_super_import() {
let result = resolve_rust_use_to_path(
"super::utils",
Some("/home/user/project/src/commands/index.rs"),
Some("/home/user/project"),
);
assert!(result.is_some());
let path = result.unwrap();
assert!(path.contains("src") && path.contains("utils.rs"));
}
#[test]
fn test_resolve_self_import() {
let result = resolve_rust_use_to_path(
"self::helper",
Some("/home/user/project/src/models/mod.rs"),
Some("/home/user/project"),
);
assert!(result.is_some());
let path = result.unwrap();
assert!(path.contains("models") && path.contains("helper.rs"));
}
#[test]
fn test_resolve_mod_declaration() {
let result = resolve_rust_use_to_path(
"parser",
Some("/home/user/project/src/main.rs"),
Some("/home/user/project"),
);
assert!(result.is_some());
let path = result.unwrap();
assert!(path.contains("parser.rs"));
}
#[test]
fn test_resolve_nested_crate_import() {
let result = resolve_rust_use_to_path(
"crate::models::language",
Some("/home/user/project/src/main.rs"),
Some("/home/user/project"),
);
assert!(result.is_some());
let path = result.unwrap();
assert!(path.contains("models") && (path.contains("language.rs") || path.contains("language/mod.rs")));
}
#[test]
fn test_external_import_not_supported() {
let result = resolve_rust_use_to_path(
"anyhow::Result",
Some("/home/user/project/src/main.rs"),
Some("/home/user/project"),
);
assert!(result.is_none());
}
#[test]
fn test_stdlib_import_not_supported() {
let result = resolve_rust_use_to_path(
"std::collections::HashMap",
Some("/home/user/project/src/main.rs"),
Some("/home/user/project"),
);
assert!(result.is_none());
}
#[test]
fn test_resolve_without_current_file() {
let result = resolve_rust_use_to_path(
"crate::models",
None,
Some("/home/user/project"),
);
assert!(result.is_none());
}
}