use crate::error::{Result, SpliceError};
use crate::graph::CodeGraph;
use crate::ingest::imports::extract_rust_imports;
use crate::ingest::rust::{extract_rust_symbols, RustSymbol, RustSymbolKind, Visibility};
use crate::resolve::references::{Reference, ReferenceContext, ReferenceSet, SymbolDefinition};
use ropey::Rope;
use std::path::{Path, PathBuf};
pub fn find_rust_references(
_graph: &CodeGraph,
file_path: &Path,
symbol_name: &str,
symbol_kind: Option<RustSymbolKind>,
) -> Result<ReferenceSet> {
let source = std::fs::read(file_path)?;
let rope = Rope::from_str(std::str::from_utf8(&source)?);
let symbols = extract_rust_symbols(file_path, &source)?;
let target_symbol = symbols
.iter()
.find(|s| s.name == symbol_name && symbol_kind.is_none_or(|k| s.kind == k))
.ok_or_else(|| SpliceError::symbol_not_found(symbol_name, Some(file_path)))?;
let same_file_refs = find_same_file_references(&source, &rope, target_symbol, file_path)?;
let (cross_file_refs, has_glob_ambiguity) = if target_symbol.visibility != Visibility::Private {
find_cross_file_references(file_path, target_symbol)?
} else {
(Vec::new(), false)
};
let mut all_refs = same_file_refs;
all_refs.extend(cross_file_refs);
all_refs.sort_by_key(|r| std::cmp::Reverse(r.byte_start));
Ok(ReferenceSet {
references: all_refs,
definition: SymbolDefinition {
name: target_symbol.name.clone(),
kind: target_symbol.kind,
file_path: file_path.to_str().unwrap_or("").to_string(),
byte_start: target_symbol.byte_start,
byte_end: target_symbol.byte_end,
is_public: target_symbol.visibility != Visibility::Private,
},
has_glob_ambiguity,
})
}
/// A local name together with the byte offset from which it is considered in effect.
#[derive(Debug, Clone)]
struct ScopedSymbol {
    name: String,
    declaration_pos: usize,
}

/// A half-open byte range `[start, end)` of the source and the names declared in it.
#[derive(Debug, Clone)]
struct Scope {
    start: usize,
    end: usize,
    symbols: Vec<ScopedSymbol>,
    // Index of the enclosing scope. It is stored but never read in this file,
    // hence the lint allowance.
    #[allow(dead_code)]
    parent: Option<usize>,
}

/// Flat list of scopes; the index returned by `add_scope` is the scope's handle.
#[derive(Debug, Clone)]
struct ScopeMap {
    scopes: Vec<Scope>,
}

impl ScopeMap {
    /// Creates an empty map with no scopes.
    fn new() -> Self {
        ScopeMap { scopes: Vec::new() }
    }

    /// Registers the scope `[start, end)` and returns its index.
    fn add_scope(&mut self, start: usize, end: usize, parent: Option<usize>) -> usize {
        self.scopes.push(Scope {
            start,
            end,
            symbols: Vec::new(),
            parent,
        });
        self.scopes.len() - 1
    }

    /// Declares `name` in scope `scope_idx`, effective from `declaration_pos`.
    /// An unknown `scope_idx` is ignored.
    fn add_symbol(&mut self, scope_idx: usize, name: String, declaration_pos: usize) {
        let Some(scope) = self.scopes.get_mut(scope_idx) else {
            return;
        };
        scope.symbols.push(ScopedSymbol {
            name,
            declaration_pos,
        });
    }

    /// Returns `true` if some scope containing `byte_offset` declares `name` at or
    /// before `byte_offset`.
    fn is_shadowed_at(&self, name: &str, byte_offset: usize) -> bool {
        self.scopes
            .iter()
            .filter(|scope| (scope.start..scope.end).contains(&byte_offset))
            .flat_map(|scope| scope.symbols.iter())
            .any(|symbol| symbol.name == name && symbol.declaration_pos <= byte_offset)
    }
}
/// Parses `source` as Rust and builds the `ScopeMap` used for shadowing checks.
///
/// Scope 0 spans the whole source; nested scopes are added by `build_scopes_recursive`.
///
/// # Errors
///
/// `SpliceError::Parse` if tree-sitter cannot be configured for Rust or returns no tree.
/// No file path is known here, so the error carries the placeholder path `<source>`.
fn build_scope_map(source: &[u8]) -> Result<ScopeMap> {
    let parse_error = |message: String| SpliceError::Parse {
        file: PathBuf::from("<source>"),
        message,
    };

    let mut parser = tree_sitter::Parser::new();
    parser
        .set_language(&tree_sitter_rust::language())
        .map_err(|e| parse_error(format!("Failed to set Rust language: {:?}", e)))?;
    let tree = parser
        .parse(source, None)
        .ok_or_else(|| parse_error("Parse failed - no tree returned".to_string()))?;

    let mut scope_map = ScopeMap::new();
    let file_scope = scope_map.add_scope(0, source.len(), None);
    build_scopes_recursive(tree.root_node(), source, &mut scope_map, file_scope);
    Ok(scope_map)
}
/// Walks the syntax tree below `node` and records scopes plus the local names that can
/// shadow the symbol being searched for.
///
/// * `function_item`:
///   - Opens a scope from its parameter list to the end of its body and binds the
///     parameters there, so the parameter declarations are not mistaken for references.
///   - A function declared directly inside a `block` (nested in another scope) is also
///     bound in that enclosing scope, from its own start. Functions in `impl`/`mod`/trait
///     bodies are not, since they do not shadow names of a surrounding block.
/// * `closure_expression`: opens a scope over the whole closure and binds its parameters.
/// * `let_declaration`:
///   - The bound name takes effect only after the whole statement, so in
///     `let helper = helper();` the initializer still refers to the outer `helper`.
///   - The pattern itself gets its own small scope so the declared name is not
///     reported as a reference.
/// * `match_arm`: opens a scope over the arm; pattern bindings live only there.
/// * `block`: opens a scope over the block.
///
/// Every other node is traversed within `current_scope`.
fn build_scopes_recursive(
    node: tree_sitter::Node,
    source: &[u8],
    scope_map: &mut ScopeMap,
    current_scope: usize,
) {
    /// Recurses into every child of `parent` with `scope` as the enclosing scope.
    fn walk_children(
        parent: tree_sitter::Node,
        source: &[u8],
        scope_map: &mut ScopeMap,
        scope: usize,
    ) {
        let mut cursor = parent.walk();
        for child in parent.children(&mut cursor) {
            build_scopes_recursive(child, source, scope_map, scope);
        }
    }

    match node.kind() {
        "function_item" => {
            if let Some(body) = node.child_by_field_name("body") {
                let params = node.child_by_field_name("parameters");
                let scope_start = params.map_or(body.start_byte(), |p| p.start_byte());
                let scope_idx =
                    scope_map.add_scope(scope_start, body.end_byte(), Some(current_scope));

                let declared_in_block = node.parent().is_some_and(|p| p.kind() == "block");
                if current_scope > 0 && declared_in_block {
                    if let Some(name) = node
                        .child_by_field_name("name")
                        .and_then(|n| n.utf8_text(source).ok())
                    {
                        scope_map.add_symbol(current_scope, name.to_string(), node.start_byte());
                    }
                }

                if let Some(params) = params {
                    let names =
                        extract_param_names(params, source, &mut std::collections::HashSet::new());
                    for (i, name) in names.into_iter().enumerate() {
                        // Each earlier parameter plus its comma takes at least two bytes,
                        // so the i-th parameter starts after `scope_start + i`.
                        scope_map.add_symbol(scope_idx, name, scope_start + i);
                    }
                }

                walk_children(body, source, scope_map, scope_idx);
                return;
            }
        }
        "closure_expression" => {
            let scope_idx =
                scope_map.add_scope(node.start_byte(), node.end_byte(), Some(current_scope));
            if let Some(params) = node.child_by_field_name("parameters") {
                let names =
                    extract_param_names(params, source, &mut std::collections::HashSet::new());
                for (i, name) in names.into_iter().enumerate() {
                    scope_map.add_symbol(scope_idx, name, node.start_byte() + i);
                }
            }
            walk_children(node, source, scope_map, scope_idx);
            return;
        }
        "let_declaration" => {
            if let Some(name) = extract_let_binding_name(node, source) {
                if let Some(pattern) = node.child_by_field_name("pattern") {
                    // The declared name inside the pattern is a binding, not a reference.
                    let pattern_scope = scope_map.add_scope(
                        pattern.start_byte(),
                        pattern.end_byte(),
                        Some(current_scope),
                    );
                    scope_map.add_symbol(pattern_scope, name.clone(), pattern.start_byte());
                }
                // The binding is visible only after the statement (type and initializer
                // still see the outer name).
                scope_map.add_symbol(current_scope, name, node.end_byte());
            }
        }
        "match_arm" => {
            // Pattern bindings are visible in this arm's pattern, guard and body only.
            let arm_scope =
                scope_map.add_scope(node.start_byte(), node.end_byte(), Some(current_scope));
            if let Some(pattern) = node.child_by_field_name("pattern") {
                for binding in extract_pattern_bindings(pattern, source) {
                    scope_map.add_symbol(arm_scope, binding, node.start_byte());
                }
            }
            walk_children(node, source, scope_map, arm_scope);
            return;
        }
        "block" => {
            let scope_idx =
                scope_map.add_scope(node.start_byte(), node.end_byte(), Some(current_scope));
            walk_children(node, source, scope_map, scope_idx);
            return;
        }
        _ => {}
    }

    walk_children(node, source, scope_map, current_scope);
}
fn extract_param_names(
node: tree_sitter::Node,
source: &[u8],
_seen: &mut std::collections::HashSet<String>,
) -> Vec<String> {
let mut names = Vec::new();
let mut cursor = node.walk();
for child in node.children(&mut cursor) {
match child.kind() {
"parameter" => {
if let Some(name_node) = child.child_by_field_name("name") {
if let Ok(name) = name_node.utf8_text(source) {
names.push(name.to_string());
}
} else {
let mut inner_cursor = child.walk();
for inner_child in child.children(&mut inner_cursor) {
if inner_child.kind() == "identifier" {
if let Ok(name) = inner_child.utf8_text(source) {
names.push(name.to_string());
break;
}
}
}
}
}
"," => continue,
_ => {}
}
}
names
}
/// Returns the first name bound by a `let_declaration`'s `pattern`, if any.
///
/// A plain identifier pattern is returned directly. Otherwise the pattern's direct
/// children are scanned for the first `identifier`, skipping the pattern's `type` child.
/// In `let Some(x) = …` or `let Wrapper(x) = …` that child names the constructor,
/// not a binding.
///
/// Only one name is returned even if the pattern binds several (e.g. `let (a, b)`
/// yields `a`).
fn extract_let_binding_name(node: tree_sitter::Node, source: &[u8]) -> Option<String> {
    let pattern = node.child_by_field_name("pattern")?;
    if pattern.kind() == "identifier" {
        return pattern.utf8_text(source).ok().map(str::to_string);
    }

    let constructor = pattern.child_by_field_name("type").map(|ty| ty.id());
    let mut cursor = pattern.walk();
    let binding = pattern
        .children(&mut cursor)
        .filter(|child| child.kind() == "identifier" && Some(child.id()) != constructor)
        .find_map(|child| child.utf8_text(source).ok().map(str::to_string));
    binding
}
/// Returns every name bound by a pattern; used for `match` arm patterns.
///
/// Identifiers and shorthand struct fields (`Point { x, .. }`) count as bindings, at any
/// nesting depth. The following are skipped:
/// - qualified paths (`scoped_identifier` / `scoped_type_identifier`);
/// - the `type` child of tuple-struct and struct patterns (`Some` in `Some(x)` is a
///   constructor);
/// - the `condition` child (a match arm's `if` guard), which is an expression rather
///   than part of the pattern.
///
/// A bare identifier is still treated as a binding even if it actually names a unit
/// variant or constant (e.g. `None`); that cannot be decided syntactically.
fn extract_pattern_bindings(node: tree_sitter::Node, source: &[u8]) -> Vec<String> {
    let mut bindings = Vec::new();
    match node.kind() {
        "identifier" | "shorthand_field_identifier" => {
            if let Ok(name) = node.utf8_text(source) {
                bindings.push(name.to_string());
            }
        }
        "scoped_identifier" | "scoped_type_identifier" => {}
        _ => {
            let constructor = node.child_by_field_name("type").map(|child| child.id());
            let guard = node.child_by_field_name("condition").map(|child| child.id());
            let mut cursor = node.walk();
            for child in node.children(&mut cursor) {
                let id = Some(child.id());
                if id == constructor || id == guard {
                    continue;
                }
                bindings.extend(extract_pattern_bindings(child, source));
            }
        }
    }
    bindings
}
/// Collects references to `target_symbol` inside its own defining file.
///
/// The definition's own name is not a reference, so it is removed:
/// - When the symbol's byte range is exactly a `function_item`, only that function's
///   name token is dropped. Recursive calls inside the body are genuine references
///   that a rename must update.
/// - For any other definition, every reference lying inside the definition span is
///   dropped.
///
/// # Errors
///
/// - `SpliceError::Parse` (tagged with `file_path`) if tree-sitter cannot be configured
///   or returns no tree.
/// - Errors from `build_scope_map` are propagated.
fn find_same_file_references(
    source: &[u8],
    rope: &Rope,
    target_symbol: &RustSymbol,
    file_path: &Path,
) -> Result<Vec<Reference>> {
    let mut references = Vec::new();
    let scope_map = build_scope_map(source)?;
    let mut parser = tree_sitter::Parser::new();
    parser
        .set_language(&tree_sitter_rust::language())
        .map_err(|e| SpliceError::Parse {
            file: file_path.to_path_buf(),
            message: format!("Failed to set Rust language: {:?}", e),
        })?;
    let tree = parser
        .parse(source, None)
        .ok_or_else(|| SpliceError::Parse {
            file: file_path.to_path_buf(),
            message: "Parse failed - no tree returned".to_string(),
        })?;
    find_identifiers_recursive(
        tree.root_node(),
        source,
        rope,
        target_symbol,
        file_path,
        &scope_map,
        &mut references,
    );

    // Name token of the definition, when the definition is exactly a function item.
    let function_name_span = tree
        .root_node()
        .descendant_for_byte_range(target_symbol.byte_start, target_symbol.byte_end)
        .filter(|definition| {
            definition.kind() == "function_item"
                && definition.start_byte() == target_symbol.byte_start
                && definition.end_byte() == target_symbol.byte_end
        })
        .and_then(|definition| definition.child_by_field_name("name"))
        .map(|name| (name.start_byte(), name.end_byte()));

    match function_name_span {
        Some((name_start, name_end)) => {
            references.retain(|r| !(r.byte_start == name_start && r.byte_end == name_end));
        }
        None => {
            references.retain(|r| {
                !(r.byte_start >= target_symbol.byte_start && r.byte_end <= target_symbol.byte_end)
            });
        }
    }
    Ok(references)
}
/// One `pub use` re-export of a named item, as recorded by `build_reexport_map`.
///
/// Only `reexporting_module` is read in this file; the other fields are not read
/// anywhere here, hence the `dead_code` allowances.
#[derive(Clone, Debug)]
struct Reexport {
    /// Module path (as produced by `module_path_from_file`) of the file containing the
    /// `pub use`.
    reexporting_module: String,
    /// Name the item is re-exported under; `build_reexport_map` sets it equal to
    /// `replaced_name`.
    #[allow(dead_code)]
    reexported_name: String,
    /// Module path written in the `use` statement, joined with `::`.
    #[allow(dead_code)]
    replaced_module: String,
    /// Name of the item in `replaced_module`.
    #[allow(dead_code)]
    replaced_name: String,
}
/// Indexes every named `pub use` re-export found in `rust_files`.
///
/// The map key is `(source module, item name)`, where the source module is the
/// `use` path joined with `::`. Each value lists the modules that re-export that item.
/// Glob re-exports (`*`) are not recorded.
///
/// Files are processed best-effort and skipped entirely in these cases:
/// - they cannot be read;
/// - their imports cannot be extracted;
/// - they cannot be mapped to a module path.
///
/// This function therefore never returns `Err`.
fn build_reexport_map(
    workspace_root: &Path,
    rust_files: &[PathBuf],
) -> Result<std::collections::HashMap<(String, String), Vec<Reexport>>> {
    let mut reexports_by_source: std::collections::HashMap<(String, String), Vec<Reexport>> =
        std::collections::HashMap::new();

    for file_path in rust_files {
        let Ok(source) = std::fs::read(file_path) else {
            continue;
        };
        let Ok(imports) = extract_rust_imports(file_path, &source) else {
            continue;
        };
        let Ok(module_path) = module_path_from_file(workspace_root, file_path) else {
            continue;
        };

        for import in imports.into_iter().filter(|import| import.is_reexport) {
            let source_module = import.path.join("::");
            for name in import
                .imported_names
                .iter()
                .filter(|name| name.as_str() != "*")
            {
                reexports_by_source
                    .entry((source_module.clone(), name.clone()))
                    .or_default()
                    .push(Reexport {
                        reexporting_module: module_path.clone(),
                        reexported_name: name.clone(),
                        replaced_module: source_module.clone(),
                        replaced_name: name.clone(),
                    });
            }
        }
    }
    Ok(reexports_by_source)
}
/// Maps a `.rs` file to the module path it defines, e.g. `src/foo/bar.rs` → `crate::foo::bar`.
///
/// Rules, applied to the path relative to `workspace_root`, working on whole path
/// components:
/// - The `.rs` extension is removed from the last segment.
/// - `lib.rs` / `main.rs` at the workspace root or directly in `src/` are the crate
///   root, `crate`.
/// - A trailing `mod` segment (`foo/mod.rs`) is dropped, because that file defines `foo`.
/// - A leading `src` or `lib` directory is replaced by `crate`. A leading `crate`
///   directory is kept. Any other path is prefixed with `crate`.
///
/// # Errors
///
/// `SpliceError::Other` if `file_path` is not under `workspace_root` or contains a
/// non-UTF-8 component.
fn module_path_from_file(workspace_root: &Path, file_path: &Path) -> Result<String> {
    let relative = file_path
        .strip_prefix(workspace_root)
        .map_err(|_| SpliceError::Other("File not in workspace".to_string()))?;

    // Component-wise processing: substring replacement on the joined path (e.g. removing
    // "::mod") would also corrupt segments such as `modules` or `mod_utils`.
    let mut segments = relative
        .components()
        .map(|component| component.as_os_str().to_str().map(str::to_string))
        .collect::<Option<Vec<String>>>()
        .ok_or_else(|| SpliceError::Other("Invalid UTF-8 in path".to_string()))?;

    if let Some(last) = segments.last_mut() {
        if last.ends_with(".rs") {
            last.truncate(last.len() - ".rs".len());
        }
    }

    let is_crate_root = {
        let parts: Vec<&str> = segments.iter().map(String::as_str).collect();
        matches!(
            parts.as_slice(),
            ["lib"] | ["main"] | ["src", "lib"] | ["src", "main"]
        )
    };
    if is_crate_root {
        return Ok("crate".to_string());
    }

    if segments.len() > 1 && segments.last().is_some_and(|last| last == "mod") {
        segments.pop();
    }

    let leading = segments.first().cloned().unwrap_or_default();
    let has_more = segments.len() > 1;
    match (leading.as_str(), has_more) {
        ("crate", true) => {}
        ("src" | "lib", true) => segments[0] = "crate".to_string(),
        _ => segments.insert(0, "crate".to_string()),
    }
    Ok(segments.join("::"))
}
/// Returns `true` if `module_path` re-exports `target_symbol` from `target_module`.
///
/// `reexport_map` is keyed by `(source module, item name)` and lists the re-exporting
/// modules.
fn module_reexports_symbol(
    module_path: &str,
    target_module: &str,
    target_symbol: &str,
    reexport_map: &std::collections::HashMap<(String, String), Vec<Reexport>>,
) -> bool {
    let lookup_key = (target_module.to_owned(), target_symbol.to_owned());
    match reexport_map.get(&lookup_key) {
        Some(entries) => entries
            .iter()
            .any(|entry| entry.reexporting_module == module_path),
        None => false,
    }
}
fn find_cross_file_references(
definition_file: &Path,
target_symbol: &RustSymbol,
) -> Result<(Vec<Reference>, bool)> {
let mut all_references = Vec::new();
let mut has_glob_ambiguity = false;
let workspace_root = find_workspace_root(definition_file)?;
let rust_files = find_all_rust_files(&workspace_root)?;
let reexport_map = match build_reexport_map(&workspace_root, &rust_files) {
Ok(m) => m,
Err(e) => {
eprintln!("Warning: failed to build re-export map: {}", e);
std::collections::HashMap::new()
}
};
let target_module = &target_symbol.module_path;
for file_path in rust_files {
if file_path == definition_file {
continue;
}
let source = match std::fs::read(&file_path) {
Ok(s) => s,
Err(_) => continue, };
let imports = match extract_rust_imports(&file_path, &source) {
Ok(i) => i,
Err(_) => continue, };
let (matches, has_glob) =
import_matches_module(&imports, target_module, &target_symbol.name);
let matches_reexport =
check_reexport_matches(&imports, target_module, &target_symbol.name, &reexport_map);
if has_glob {
has_glob_ambiguity = true;
}
if matches || matches_reexport {
let rope = Rope::from_str(std::str::from_utf8(&source)?);
let refs = find_references_in_file(&source, &rope, target_symbol, &file_path)?;
all_references.extend(refs);
}
}
Ok((all_references, has_glob_ambiguity))
}
/// Returns `true` if some import pulls `target_symbol` in through a re-export.
///
/// Each import of `*`, or of `target_symbol` by name, has its module path (joined with
/// `::`) checked with `module_reexports_symbol`.
fn check_reexport_matches(
    imports: &[crate::ingest::imports::ImportFact],
    target_module: &str,
    target_symbol: &str,
    reexport_map: &std::collections::HashMap<(String, String), Vec<Reexport>>,
) -> bool {
    imports.iter().any(|import| {
        let imported_module = import.path.join("::");
        import.imported_names.iter().any(|name| {
            (name == "*" || name == target_symbol)
                && module_reexports_symbol(
                    &imported_module,
                    target_module,
                    target_symbol,
                    reexport_map,
                )
        })
    })
}
/// Returns the closest directory containing a `Cargo.toml`.
///
/// The search starts at the parent of `start_path` and walks upwards.
///
/// # Errors
///
/// `SpliceError::Other` in two cases:
/// - `start_path` has no parent;
/// - no ancestor directory contains a `Cargo.toml`.
fn find_workspace_root(start_path: &Path) -> Result<PathBuf> {
    let start_dir = start_path
        .parent()
        .ok_or_else(|| SpliceError::Other("Cannot determine workspace root".to_string()))?;

    start_dir
        .ancestors()
        .find(|dir| dir.join("Cargo.toml").exists())
        .map(Path::to_path_buf)
        .ok_or_else(|| {
            SpliceError::Other("Cargo.toml not found in any parent directory".to_string())
        })
}
/// Recursively lists every `.rs` file below `workspace_root`.
///
/// Skipping rules:
/// - Subdirectories named `target`, or whose name starts with `.` (e.g. `.git`), are
///   skipped.
/// - The workspace root itself is always scanned, even if its own name matches those
///   rules.
/// - Directory symlinks are followed, but each real directory is visited once, which
///   prevents infinite recursion on symlink cycles and duplicate results.
/// - Unreadable directories and entries are skipped.
///
/// Never returns `Err` at present.
fn find_all_rust_files(workspace_root: &Path) -> Result<Vec<PathBuf>> {
    /// `true` for build output (`target`) and hidden directories.
    fn is_excluded_dir(dir: &Path) -> bool {
        dir.file_name()
            .and_then(|name| name.to_str())
            .is_some_and(|name| name == "target" || name.starts_with('.'))
    }

    /// Appends the `.rs` files below `dir` to `rust_files`.
    fn visit_dirs(
        dir: &Path,
        rust_files: &mut Vec<PathBuf>,
        visited: &mut std::collections::HashSet<PathBuf>,
    ) -> Result<()> {
        let identity = dir.canonicalize().unwrap_or_else(|_| dir.to_path_buf());
        if !visited.insert(identity) {
            return Ok(());
        }
        let Ok(entries) = std::fs::read_dir(dir) else {
            return Ok(());
        };
        for entry in entries.flatten() {
            let path = entry.path();
            if path.is_dir() {
                if !is_excluded_dir(&path) {
                    visit_dirs(&path, rust_files, visited)?;
                }
            } else if path.extension().and_then(|ext| ext.to_str()) == Some("rs") {
                rust_files.push(path);
            }
        }
        Ok(())
    }

    let mut rust_files = Vec::new();
    let mut visited = std::collections::HashSet::new();
    visit_dirs(workspace_root, &mut rust_files, &mut visited)?;
    Ok(rust_files)
}
/// Checks a file's imports against the target symbol.
///
/// Returns `(matches, has_glob)`:
/// - `matches` is `true` if a glob import, or an import naming `target_symbol_name`,
///   has a path related to `target_module` (see `import_path_matches_target`).
/// - `has_glob` is `true` if any import is a glob import, regardless of its path.
fn import_matches_module(
    imports: &[crate::ingest::imports::ImportFact],
    target_module: &str,
    target_symbol_name: &str,
) -> (bool, bool) {
    let has_glob = imports.iter().any(|import| import.is_glob);
    let matches = imports.iter().any(|import| {
        let brings_in_name = import.is_glob
            || import
                .imported_names
                .iter()
                .any(|name| name == target_symbol_name);
        brings_in_name && import_path_matches_target(&import.path.join("::"), target_module)
    });
    (matches, has_glob)
}
/// Returns `true` if `import_path` and `target_module` are related.
///
/// They are related when they are equal, or when either is a `::`-separated prefix of
/// the other. A partial segment does not count: `crate::util` does not match
/// `crate::utils`.
fn import_path_matches_target(import_path: &str, target_module: &str) -> bool {
    // `outer` is a path prefix of `inner` when `inner` continues with `::` right after it.
    let is_path_prefix = |outer: &str, inner: &str| {
        inner
            .strip_prefix(outer)
            .is_some_and(|rest| rest.starts_with("::"))
    };
    import_path == target_module
        || is_path_prefix(import_path, target_module)
        || is_path_prefix(target_module, import_path)
}
/// Collects every reference to `target_symbol` in `source`.
///
/// It applies the same shadowing analysis as the same-file search but filters nothing
/// out afterwards. It is used by the cross-file search for files other than the
/// definition's.
///
/// # Errors
///
/// - `SpliceError::Parse` (tagged with `file_path`) if tree-sitter cannot be configured
///   or returns no tree.
/// - Errors from `build_scope_map` are propagated.
fn find_references_in_file(
    source: &[u8],
    rope: &Rope,
    target_symbol: &RustSymbol,
    file_path: &Path,
) -> Result<Vec<Reference>> {
    let scope_map = build_scope_map(source)?;
    let parse_error = |message: String| SpliceError::Parse {
        file: file_path.to_path_buf(),
        message,
    };

    let mut parser = tree_sitter::Parser::new();
    parser
        .set_language(&tree_sitter_rust::language())
        .map_err(|e| parse_error(format!("Failed to set Rust language: {:?}", e)))?;
    let tree = parser
        .parse(source, None)
        .ok_or_else(|| parse_error("Parse failed - no tree returned".to_string()))?;

    let mut found = Vec::new();
    find_identifiers_recursive(
        tree.root_node(),
        source,
        rope,
        target_symbol,
        file_path,
        &scope_map,
        &mut found,
    );
    Ok(found)
}
fn find_identifiers_recursive(
node: tree_sitter::Node,
source: &[u8],
rope: &Rope,
target_symbol: &RustSymbol,
file_path: &Path,
scope_map: &ScopeMap,
references: &mut Vec<Reference>,
) {
let kind = node.kind();
match kind {
"identifier" => {
let parent = node.parent();
if let Some(p) = parent {
if p.kind() == "call_expression" {
let mut cursor = node.walk();
for child in node.children(&mut cursor) {
find_identifiers_recursive(
child,
source,
rope,
target_symbol,
file_path,
scope_map,
references,
);
}
return;
}
}
if let Ok(text) = node.utf8_text(source) {
if text == target_symbol.name {
if scope_map.is_shadowed_at(&target_symbol.name, node.start_byte()) {
let mut cursor = node.walk();
for child in node.children(&mut cursor) {
find_identifiers_recursive(
child,
source,
rope,
target_symbol,
file_path,
scope_map,
references,
);
}
return;
}
let context = extract_context(node, source);
let start_char = rope.byte_to_char(node.start_byte());
let line = rope.char_to_line(start_char);
let line_byte = rope.line_to_byte(line);
let col = node.start_byte() - line_byte;
references.push(Reference {
file_path: file_path.to_str().unwrap_or("").to_string(),
byte_start: node.start_byte(),
byte_end: node.end_byte(),
line: line + 1,
column: col,
context,
match_id: None,
});
}
}
}
"scoped_identifier" | "scoped_type_identifier" => {
if let Ok(text) = node.utf8_text(source) {
if text.ends_with(&format!("::{}", target_symbol.name)) {
let context = extract_context(node, source);
let start_char = rope.byte_to_char(node.start_byte());
let line = rope.char_to_line(start_char);
let line_byte = rope.line_to_byte(line);
let col = node.start_byte() - line_byte;
references.push(Reference {
file_path: file_path.to_str().unwrap_or("").to_string(),
byte_start: node.start_byte(),
byte_end: node.end_byte(),
line: line + 1,
column: col,
context,
match_id: None,
});
}
}
}
"call_expression" => {
if let Some(func) = node.child_by_field_name("function") {
let func_kind = func.kind();
if func_kind == "identifier"
|| func_kind == "scoped_identifier"
|| func_kind == "field_expression"
{
if let Ok(text) = func.utf8_text(source) {
let matches = if func_kind == "identifier" {
text == target_symbol.name
} else if func_kind == "field_expression" {
if let Some(field) = func.child_by_field_name("field") {
if let Ok(field_text) = field.utf8_text(source) {
field_text == target_symbol.name
} else {
false
}
} else {
false
}
} else {
text.ends_with(&format!("::{}", target_symbol.name))
};
if matches && target_symbol.kind == RustSymbolKind::Function {
if func_kind == "identifier"
&& scope_map.is_shadowed_at(&target_symbol.name, func.start_byte())
{
} else {
let context = ReferenceContext::FunctionCall {
is_qualified: func_kind == "scoped_identifier"
|| func_kind == "field_expression",
};
let (start, end) = if func_kind == "field_expression" {
if let Some(field) = func.child_by_field_name("field") {
(field.start_byte(), field.end_byte())
} else {
(func.start_byte(), func.end_byte())
}
} else {
(func.start_byte(), func.end_byte())
};
let start_char = rope.byte_to_char(start);
let line = rope.char_to_line(start_char);
let line_byte = rope.line_to_byte(line);
let col = start - line_byte;
references.push(Reference {
file_path: file_path.to_str().unwrap_or("").to_string(),
byte_start: start,
byte_end: end,
line: line + 1,
column: col,
context,
match_id: None,
});
}
}
}
}
}
let mut cursor = node.walk();
for child in node.children(&mut cursor) {
find_identifiers_recursive(
child,
source,
rope,
target_symbol,
file_path,
scope_map,
references,
);
}
}
_ => {
let mut cursor = node.walk();
for child in node.children(&mut cursor) {
find_identifiers_recursive(
child,
source,
rope,
target_symbol,
file_path,
scope_map,
references,
);
}
}
}
}
/// Classifies a matched node by the kind of its parent node.
///
/// - `call_expression` → `FunctionCall` (qualified when `node` is a `scoped_identifier`)
/// - `use_declaration` → `ImportStatement`
/// - `field_expression` → `FieldAccess`
/// - `type_identifier`, `generic_type`, `type_arguments` → `TypeReference`
/// - anything else, or no parent → `Identifier`
///
/// `_source` is unused.
fn extract_context(node: tree_sitter::Node, _source: &[u8]) -> ReferenceContext {
    let Some(parent) = node.parent() else {
        return ReferenceContext::Identifier;
    };
    match parent.kind() {
        "call_expression" => {
            let is_qualified = node.kind() == "scoped_identifier";
            ReferenceContext::FunctionCall { is_qualified }
        }
        "use_declaration" => ReferenceContext::ImportStatement,
        "field_expression" => ReferenceContext::FieldAccess,
        // NOTE(review): `type_identifier` is presumably a leaf token in tree-sitter-rust,
        // so it can never be a parent; kept for parity. Verify against the grammar.
        "type_identifier" | "generic_type" | "type_arguments" => ReferenceContext::TypeReference,
        _ => ReferenceContext::Identifier,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::{NamedTempFile, TempDir};

    /// Opens a `CodeGraph` backed by a database in a fresh temporary directory.
    ///
    /// The directory is intentionally leaked with `mem::forget` so it is not deleted
    /// while the returned graph might still use it.
    fn create_test_graph() -> CodeGraph {
        let db_dir = TempDir::new().unwrap();
        let db_path = db_dir.path().join("test_graph.db");
        let graph = CodeGraph::open(&db_path).unwrap();
        std::mem::forget(db_dir);
        graph
    }

    /// Writes `source` to a temporary file and finds references to the function `name`
    /// defined in it. The file is kept alive for the duration of the search.
    fn function_references(source: &str, name: &str) -> Result<ReferenceSet> {
        let mut file = NamedTempFile::new().unwrap();
        write!(file, "{}", source).unwrap();
        let graph = create_test_graph();
        find_rust_references(&graph, file.path(), name, Some(RustSymbolKind::Function))
    }

    #[test]
    fn test_find_same_file_function_references() {
        let source = r#"
fn helper() -> i32 {
42
}
fn main() {
let x = helper();
let y = helper();
println!("{}", helper());
}
"#;
        let refs = function_references(source, "helper").unwrap();
        assert_eq!(refs.references.len(), 3);
    }

    #[test]
    fn test_qualified_path_references() {
        let source = r#"
fn helper() -> i32 {
42
}
fn main() {
let x = helper(); // Unqualified
let y = crate::helper(); // Qualified - but this won't resolve in same file
}
"#;
        let refs = function_references(source, "helper").unwrap();
        assert!(!refs.references.is_empty());
    }

    #[test]
    fn test_no_references_to_symbol() {
        let source = r#"
fn unused() -> i32 {
42
}
fn main() {
println!("Hello");
}
"#;
        let refs = function_references(source, "unused").unwrap();
        assert_eq!(refs.references.len(), 0);
    }

    #[test]
    fn test_symbol_not_found() {
        assert!(function_references("fn main() {}", "nonexistent").is_err());
    }

    #[test]
    fn test_import_path_matches_target() {
        assert!(import_path_matches_target("crate::utils", "crate::utils"));
        assert!(import_path_matches_target(
            "crate::utils",
            "crate::utils::helpers"
        ));
        assert!(import_path_matches_target(
            "crate::utils",
            "crate::utils::helper"
        ));
        assert!(!import_path_matches_target("crate::utils", "crate::other"));
        assert!(!import_path_matches_target(
            "crate::utils::helpers",
            "crate::other"
        ));
    }

    #[test]
    fn test_find_workspace_root() {
        use std::fs;

        let temp_dir = TempDir::new().unwrap();
        let workspace = temp_dir.path();
        fs::write(workspace.join("Cargo.toml"), "[package]\nname = \"test\"\n").unwrap();

        let src_dir = workspace.join("src");
        fs::create_dir_all(&src_dir).unwrap();
        let main_rs = src_dir.join("main.rs");
        fs::write(&main_rs, "fn main() {}").unwrap();

        assert_eq!(find_workspace_root(&main_rs).unwrap(), workspace);
    }

    #[test]
    fn test_shadowing_by_local_function() {
        let source = r#"
fn helper() -> i32 {
42
}
fn main() {
let x = helper(); // Should find this (references top-level helper)
fn helper() -> i32 { // Local function shadows the top-level one
99
}
let y = helper(); // Should NOT find this (references local helper)
}
"#;
        let refs = function_references(source, "helper").unwrap();
        assert_eq!(refs.references.len(), 1);
    }

    #[test]
    fn test_shadowing_by_closure_parameter() {
        let source = r#"
fn helper() -> i32 {
42
}
fn main() {
let x = helper(); // Should find this
let f = |helper: i32| helper + 1; // 'helper' here is a parameter, not a call
let y = f(10);
let z = helper(); // Should find this too
}
"#;
        let refs = function_references(source, "helper").unwrap();
        assert_eq!(refs.references.len(), 2);
    }

    #[test]
    fn test_nested_scope_shadowing() {
        let source = r#"
fn helper() -> i32 {
42
}
fn main() {
let x = helper(); // Should find
{
let y = helper(); // Should find (still in scope)
fn helper() -> i32 { // Shadows only within this block
99
}
let z = helper(); // Should NOT find (shadowed)
}
let w = helper(); // Should find (outside shadowing scope)
}
"#;
        let refs = function_references(source, "helper").unwrap();
        assert_eq!(refs.references.len(), 3);
    }
}