use std::path::Path;
use tree_sitter::{Node, Parser};
use crate::semantic::adapter::LanguageAdapter;
use crate::semantic::common::{node_text, signature_up_to_body};
use crate::semantic::types::{ByteRange, ExtractedFile, Import, ImportKind, Symbol, SymbolKind};
/// Language adapter that extracts symbols and imports from Rust source text.
///
/// The type is a stateless unit struct: it carries no configuration, so it is
/// cheap to copy and can be created with `RustAdapter` or `RustAdapter::default()`.
#[derive(Debug, Clone, Copy, Default)]
pub struct RustAdapter;
impl RustAdapter {
/// Returns the signature string for `node`.
///
/// Thin wrapper that forwards to `signature_up_to_body`; presumably that
/// helper yields the item's text up to (not including) its body — verify
/// against `crate::semantic::common`.
fn signature(&self, node: Node, src: &[u8]) -> String {
    signature_up_to_body(node, src)
}
fn is_exported(&self, n: Node) -> bool {
for i in 0..n.named_child_count() {
if let Some(c) = n.named_child(i)
&& c.kind() == "visibility_modifier"
{
return true;
}
}
false
}
fn ident_child<'a>(&self, n: Node<'a>, s: &'a [u8]) -> Option<String> {
for i in 0..n.named_child_count() {
let c = n.named_child(i)?;
if matches!(c.kind(), "identifier" | "type_identifier") {
return Some(node_text(c, s).to_string());
}
}
None
}
/// Resolves a type node to a bare type name.
///
/// * `type_identifier` — its own text.
/// * `generic_type` / `scoped_type_identifier` — the first named child (in
///   source order) that itself resolves to a name, searched recursively.
/// * anything else — `None`.
// `self` is only needed to make the recursive call, which is what the lint flags.
#[allow(clippy::only_used_in_recursion)]
fn type_leaf_name(&self, n: Node, s: &[u8]) -> Option<String> {
    let kind = n.kind();
    if kind == "type_identifier" {
        return Some(node_text(n, s).to_string());
    }
    if kind != "generic_type" && kind != "scoped_type_identifier" {
        return None;
    }
    (0..n.named_child_count())
        .filter_map(|idx| n.named_child(idx))
        .find_map(|child| self.type_leaf_name(child, s))
}
/// Appends a `SymbolKind::Function` symbol for the function node `n`.
///
/// Nothing is recorded when no identifier child can be found on the node.
fn handle_function(&self, n: Node, s: &[u8], symbols: &mut Vec<Symbol>) {
    if let Some(name) = self.ident_child(n, s) {
        let symbol = Symbol {
            kind: SymbolKind::Function,
            name,
            is_exported: self.is_exported(n),
            signature: self.signature(n, s),
            range: ByteRange::from(n),
            parent_class: None,
        };
        symbols.push(symbol);
    }
}
/// Appends a `SymbolKind::Class` symbol for a struct, enum or union node.
///
/// The signature is the first line of the node's source text (empty when the
/// text has no lines). Nodes without an identifier child are skipped.
fn handle_struct_or_enum(&self, n: Node, s: &[u8], symbols: &mut Vec<Symbol>) {
    if let Some(name) = self.ident_child(n, s) {
        let first_line = node_text(n, s)
            .lines()
            .next()
            .map(str::to_owned)
            .unwrap_or_default();
        symbols.push(Symbol {
            kind: SymbolKind::Class,
            name,
            is_exported: self.is_exported(n),
            signature: first_line,
            range: ByteRange::from(n),
            parent_class: None,
        });
    }
}
fn handle_trait(&self, n: Node, s: &[u8], symbols: &mut Vec<Symbol>) {
let Some(trait_name) = self.ident_child(n, s) else {
return;
};
symbols.push(Symbol {
kind: SymbolKind::Interface,
is_exported: self.is_exported(n),
name: trait_name.clone(),
range: ByteRange::from(n),
signature: node_text(n, s).lines().next().unwrap_or("").to_string(),
parent_class: None,
});
for i in 0..n.named_child_count() {
let Some(c) = n.named_child(i) else { continue };
if c.kind() != "declaration_list" {
continue;
}
for j in 0..c.named_child_count() {
let Some(m) = c.named_child(j) else { continue };
let mname = match m.kind() {
"function_item" | "function_signature_item" => self.ident_child(m, s),
_ => None,
};
if let Some(mname) = mname {
symbols.push(Symbol {
kind: SymbolKind::Method,
is_exported: true,
name: mname,
range: ByteRange::from(m),
signature: self.signature(m, s),
parent_class: Some(trait_name.clone()),
});
}
}
}
}
fn handle_impl(&self, n: Node, s: &[u8], symbols: &mut Vec<Symbol>) {
let receiving = n
.child_by_field_name("type")
.and_then(|t| self.type_leaf_name(t, s))
.or_else(|| {
let mut last: Option<String> = None;
for i in 0..n.named_child_count() {
if let Some(c) = n.named_child(i)
&& c.kind() == "type_identifier"
{
last = Some(node_text(c, s).to_string());
}
}
last
});
let Some(receiving) = receiving else {
return;
};
for i in 0..n.named_child_count() {
let Some(c) = n.named_child(i) else { continue };
if c.kind() != "declaration_list" {
continue;
}
for j in 0..c.named_child_count() {
let Some(m) = c.named_child(j) else { continue };
if m.kind() != "function_item" {
continue;
}
if let Some(mname) = self.ident_child(m, s) {
symbols.push(Symbol {
kind: SymbolKind::Method,
is_exported: self.is_exported(m),
name: mname,
range: ByteRange::from(m),
signature: self.signature(m, s),
parent_class: Some(receiving.clone()),
});
}
}
}
}
/// Appends a `SymbolKind::TypeAlias` symbol for the type-alias node `n`.
///
/// The signature is the first line of the node's text; a node without an
/// identifier child is ignored.
fn handle_type_alias(&self, n: Node, s: &[u8], symbols: &mut Vec<Symbol>) {
    if let Some(name) = self.ident_child(n, s) {
        let first_line = node_text(n, s)
            .lines()
            .next()
            .map(str::to_owned)
            .unwrap_or_default();
        symbols.push(Symbol {
            kind: SymbolKind::TypeAlias,
            name,
            is_exported: self.is_exported(n),
            signature: first_line,
            range: ByteRange::from(n),
            parent_class: None,
        });
    }
}
fn handle_const_or_static(&self, n: Node, s: &[u8], symbols: &mut Vec<Symbol>) {
let mut name: Option<String> = None;
for i in 0..n.named_child_count() {
if let Some(c) = n.named_child(i)
&& c.kind() == "identifier"
{
name = Some(node_text(c, s).to_string());
break;
}
}
let Some(name) = name else { return };
symbols.push(Symbol {
kind: SymbolKind::Variable,
is_exported: self.is_exported(n),
name,
range: ByteRange::from(n),
signature: node_text(n, s).lines().next().unwrap_or("").to_string(),
parent_class: None,
});
}
fn handle_mod(&self, n: Node, s: &[u8], symbols: &mut Vec<Symbol>, imports: &mut Vec<Import>) {
let Some(name) = self.ident_child(n, s) else {
return;
};
symbols.push(Symbol {
kind: SymbolKind::Class,
is_exported: self.is_exported(n),
name,
range: ByteRange::from(n),
signature: node_text(n, s).lines().next().unwrap_or("").to_string(),
parent_class: None,
});
for i in 0..n.named_child_count() {
let Some(c) = n.named_child(i) else { continue };
if c.kind() != "declaration_list" {
continue;
}
for j in 0..c.named_child_count() {
let Some(item) = c.named_child(j) else {
continue;
};
match item.kind() {
"function_item" => self.handle_function(item, s, symbols),
"struct_item" | "enum_item" | "union_item" => {
self.handle_struct_or_enum(item, s, symbols);
}
"trait_item" => self.handle_trait(item, s, symbols),
"impl_item" => self.handle_impl(item, s, symbols),
"type_item" => self.handle_type_alias(item, s, symbols),
"const_item" | "static_item" => self.handle_const_or_static(item, s, symbols),
"use_declaration" => self.handle_use(item, s, imports),
"mod_item" => self.handle_mod(item, s, symbols, imports),
"macro_definition" => self.handle_macro(item, s, symbols),
"foreign_mod_item" => self.handle_extern_block(item, s, symbols),
_ => {}
}
}
}
}
fn handle_macro(&self, n: Node, s: &[u8], symbols: &mut Vec<Symbol>) {
for i in 0..n.named_child_count() {
if let Some(c) = n.named_child(i)
&& c.kind() == "identifier"
{
let name = node_text(c, s).to_string();
symbols.push(Symbol {
kind: SymbolKind::Function,
is_exported: self.is_exported(n),
name,
range: ByteRange::from(n),
signature: node_text(n, s).lines().next().unwrap_or("").to_string(),
parent_class: None,
});
return;
}
}
}
fn handle_extern_block(&self, n: Node, s: &[u8], symbols: &mut Vec<Symbol>) {
for i in 0..n.named_child_count() {
let Some(c) = n.named_child(i) else { continue };
if c.kind() != "declaration_list" {
continue;
}
for j in 0..c.named_child_count() {
let Some(item) = c.named_child(j) else {
continue;
};
match item.kind() {
"function_signature_item" => {
if let Some(name) = self.ident_child(item, s) {
symbols.push(Symbol {
kind: SymbolKind::Function,
is_exported: true,
name,
range: ByteRange::from(item),
signature: self.signature(item, s),
parent_class: None,
});
}
}
"static_item" => self.handle_const_or_static(item, s, symbols),
"type_item" => self.handle_type_alias(item, s, symbols),
_ => {}
}
}
}
}
fn handle_use(&self, n: Node, s: &[u8], imports: &mut Vec<Import>) {
for i in 0..n.named_child_count() {
let Some(c) = n.named_child(i) else { continue };
match c.kind() {
"scoped_identifier" | "identifier" | "use_list" | "use_as_clause" => {
let path = node_text(c, s).to_string();
imports.push(Import {
names: vec![path.clone()],
source: path,
kind: ImportKind::Qualified,
});
break;
}
_ => {}
}
}
}
}
impl LanguageAdapter for RustAdapter {
/// File extensions (including the leading dot) handled by this adapter.
fn extensions(&self) -> &[&str] {
    const RUST_EXTENSIONS: &[&str] = &[".rs"];
    RUST_EXTENSIONS
}
fn extract(&self, file_path: &Path, source: &str) -> Result<ExtractedFile, String> {
let lang: tree_sitter::Language = tree_sitter_rust::LANGUAGE.into();
let mut parser = Parser::new();
parser
.set_language(&lang)
.map_err(|e| format!("Failed to set language: {e}"))?;
let tree = parser.parse(source, None).ok_or("Failed to parse source")?;
let root = tree.root_node();
let bytes = source.as_bytes();
let mut symbols = Vec::new();
let mut imports = Vec::new();
let mut warnings = Vec::new();
if root.has_error() {
warnings.push("tree-sitter reported syntax errors".to_string());
}
for i in 0..root.named_child_count() {
let Some(c) = root.named_child(i) else {
continue;
};
match c.kind() {
"function_item" => self.handle_function(c, bytes, &mut symbols),
"struct_item" | "enum_item" | "union_item" => {
self.handle_struct_or_enum(c, bytes, &mut symbols);
}
"trait_item" => self.handle_trait(c, bytes, &mut symbols),
"impl_item" => self.handle_impl(c, bytes, &mut symbols),
"type_item" => self.handle_type_alias(c, bytes, &mut symbols),
"const_item" | "static_item" => self.handle_const_or_static(c, bytes, &mut symbols),
"use_declaration" => self.handle_use(c, bytes, &mut imports),
"mod_item" => self.handle_mod(c, bytes, &mut symbols, &mut imports),
"macro_definition" => self.handle_macro(c, bytes, &mut symbols),
"foreign_mod_item" => self.handle_extern_block(c, bytes, &mut symbols),
_ => {}
}
}
let exports: Vec<String> = symbols
.iter()
.filter(|s| s.is_exported)
.map(|s| s.name.clone())
.collect();
Ok(ExtractedFile {
file_path: file_path.to_path_buf(),
symbols,
imports,
exports,
warnings,
mtime: std::time::SystemTime::now(),
size: 0,
head_hash: 0,
})
}
/// Collects the names of callees found in `source`, restricted to `range`.
///
/// The work is delegated to `run_callee_query` with a tree-sitter query that
/// captures, as `@callee`:
/// * plain calls — `helper()`
/// * method calls — `value.method()` (the method name)
/// * path-qualified calls — `Type::new()` / `module::func()` (the last segment)
/// * macro invocations — `println!(..)`
/// * path-qualified macro invocations — `log::info!(..)` (the last segment)
///
/// The two path-qualified patterns are additions: without them associated
/// functions and namespaced macros never appeared in the callee list. The
/// original three patterns are unchanged.
///
/// `_file_path` is accepted for interface compatibility and is not used.
///
/// # Errors
///
/// Propagates whatever error `run_callee_query` reports.
fn find_callees_in_range(
    &self,
    source: &str,
    _file_path: &Path,
    range: ByteRange,
) -> Result<Vec<String>, String> {
    let lang: tree_sitter::Language = tree_sitter_rust::LANGUAGE.into();
    let query_str = r#"
        (call_expression function: (identifier) @callee)
        (call_expression function: (field_expression field: (field_identifier) @callee))
        (call_expression function: (scoped_identifier name: (identifier) @callee))
        (macro_invocation macro: (identifier) @callee)
        (macro_invocation macro: (scoped_identifier name: (identifier) @callee))
    "#;
    crate::semantic::common::run_callee_query(&lang, query_str, source, range)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an owned path from a string literal.
    fn pb(n: &str) -> std::path::PathBuf {
        std::path::PathBuf::from(n)
    }

    /// Runs the adapter over `src` as if it were the file `x.rs`.
    fn parse(src: &str) -> ExtractedFile {
        RustAdapter.extract(&pb("x.rs"), src).unwrap()
    }

    /// Looks up the first symbol called `name`, panicking when it is absent.
    fn sym<'f>(file: &'f ExtractedFile, name: &str) -> &'f Symbol {
        file.symbols.iter().find(|s| s.name == name).unwrap()
    }

    /// True when some import's `source` contains `needle`.
    fn imports_from(file: &ExtractedFile, needle: &str) -> bool {
        file.imports.iter().any(|i| i.source.contains(needle))
    }

    #[test]
    fn extracts_pub_fn_as_exported_and_private_fn_not() {
        let file = parse("pub fn a() {}\nfn b() {}\n");
        let public = sym(&file, "a");
        let private = sym(&file, "b");
        assert!(public.is_exported);
        assert!(!private.is_exported);
        assert!(matches!(public.kind, SymbolKind::Function));
    }

    #[test]
    fn extracts_struct_enum_as_class() {
        let file = parse("pub struct Foo { name: String }\npub enum Bar { A, B }\n");
        for expected in ["Foo", "Bar"] {
            assert!(
                file.symbols
                    .iter()
                    .any(|s| s.name == expected && matches!(s.kind, SymbolKind::Class))
            );
        }
    }

    #[test]
    fn extracts_trait_with_method_signatures() {
        let file = parse(
            "pub trait Greeter {\n fn greet(&self) -> String;\n fn default_greet(&self) -> String { \"hi\".to_string() }\n}\n",
        );
        assert!(matches!(sym(&file, "Greeter").kind, SymbolKind::Interface));
        assert_eq!(sym(&file, "greet").parent_class.as_deref(), Some("Greeter"));
        assert_eq!(
            sym(&file, "default_greet").parent_class.as_deref(),
            Some("Greeter")
        );
    }

    #[test]
    fn impl_methods_attach_to_receiving_type() {
        let file = parse(
            "pub struct Foo;\nimpl Greeter for Foo {\n fn greet(&self) -> String { String::new() }\n}\n",
        );
        let method = sym(&file, "greet");
        assert!(matches!(method.kind, SymbolKind::Method));
        assert_eq!(method.parent_class.as_deref(), Some("Foo"));
    }

    #[test]
    fn extracts_type_alias() {
        let file = parse("pub type Id = u64;\n");
        let alias = sym(&file, "Id");
        assert!(matches!(alias.kind, SymbolKind::TypeAlias));
        assert!(alias.is_exported);
    }

    #[test]
    fn extracts_const_and_static_as_variable() {
        let file = parse("pub const MAX: u32 = 42;\nstatic GLOBAL: i32 = 0;\n");
        let max = sym(&file, "MAX");
        let global = sym(&file, "GLOBAL");
        assert!(matches!(max.kind, SymbolKind::Variable));
        assert!(max.is_exported);
        assert!(!global.is_exported);
    }

    #[test]
    fn extracts_use_imports() {
        let file = parse("use std::sync::Arc;\nuse crate::foo::Bar;\n");
        assert!(imports_from(&file, "std::sync::Arc"));
        assert!(imports_from(&file, "crate::foo::Bar"));
    }

    #[test]
    fn find_callees_captures_direct_method_and_macro() {
        let src = "pub fn run() { helper(); foo.bar(); println!(\"x\"); }\nfn helper() {}\n";
        let file = parse(src);
        let run = sym(&file, "run");
        let callees = RustAdapter
            .find_callees_in_range(src, &pb("x.rs"), run.range)
            .unwrap();
        for expected in ["helper", "bar", "println"] {
            assert!(callees.contains(&expected.to_string()));
        }
    }

    #[test]
    fn extracts_inline_module_and_its_items() {
        let file =
            parse("pub mod inner {\n pub fn deep() -> u32 { 42 }\n pub struct Held;\n}\n");
        let module = sym(&file, "inner");
        assert!(matches!(module.kind, SymbolKind::Class));
        assert!(module.is_exported);
        assert!(
            file.symbols
                .iter()
                .any(|s| s.name == "deep" && matches!(s.kind, SymbolKind::Function))
        );
        assert!(
            file.symbols
                .iter()
                .any(|s| s.name == "Held" && matches!(s.kind, SymbolKind::Class))
        );
    }

    #[test]
    fn extracts_extern_block_signatures() {
        let file = parse(
            "extern \"C\" {\n fn foreign_fn(x: i32) -> i32;\n static FOREIGN_GLOBAL: i32;\n}\n",
        );
        assert!(matches!(sym(&file, "foreign_fn").kind, SymbolKind::Function));
        assert!(matches!(
            sym(&file, "FOREIGN_GLOBAL").kind,
            SymbolKind::Variable
        ));
    }

    #[test]
    fn extracts_macro_rules() {
        let file = parse("macro_rules! my_mac { ($x:expr) => { $x + 1 }; }\n");
        assert!(matches!(sym(&file, "my_mac").kind, SymbolKind::Function));
    }

    #[test]
    fn impl_for_generic_type_uses_base_name() {
        let file = parse(
            "pub struct Bag<T>(T);\nimpl<T: Clone> AsRef<T> for Bag<T> { fn as_ref(&self) -> &T { &self.0 } }\n",
        );
        assert_eq!(sym(&file, "as_ref").parent_class.as_deref(), Some("Bag"));
    }
}