use crate::parser::language::{
Export, Import, LanguageSupport, ParseResult, Symbol, SymbolKind, Visibility,
};
use tree_sitter::Language as TsLanguage;
pub struct RustLanguage;
impl RustLanguage {
fn is_public(node: &tree_sitter::Node, source: &[u8]) -> bool {
let mut cursor = node.walk();
for child in node.children(&mut cursor) {
if child.kind() == "visibility_modifier" {
let text = child.utf8_text(source).unwrap_or("");
return text.starts_with("pub");
}
}
false
}
fn node_text<'a>(node: &tree_sitter::Node, source: &'a [u8]) -> &'a str {
node.utf8_text(source).unwrap_or("")
}
fn extract_name(node: &tree_sitter::Node, source: &[u8]) -> String {
let mut cursor = node.walk();
for child in node.children(&mut cursor) {
if child.kind() == "identifier" || child.kind() == "type_identifier" {
return Self::node_text(&child, source).to_string();
}
}
String::new()
}
fn extract_fn_signature(node: &tree_sitter::Node, source: &[u8]) -> String {
let full_text = Self::node_text(node, source);
let mut cursor = node.walk();
for child in node.children(&mut cursor) {
if child.kind() == "block" {
let body_start = child.start_byte() - node.start_byte();
return full_text[..body_start].trim().to_string();
}
}
full_text.trim().to_string()
}
fn extract_fn_body(node: &tree_sitter::Node, source: &[u8]) -> String {
let mut cursor = node.walk();
for child in node.children(&mut cursor) {
if child.kind() == "block" {
let text = &source[child.start_byte()..child.end_byte()];
return String::from_utf8_lossy(text).into_owned();
}
}
String::new()
}
fn first_line(node: &tree_sitter::Node, source: &[u8]) -> String {
let text = Self::node_text(node, source);
text.lines().next().unwrap_or("").trim().to_string()
}
fn type_signature(node: &tree_sitter::Node, source: &[u8]) -> String {
let full_text = Self::node_text(node, source);
full_text.lines().next().unwrap_or("").trim().to_string()
}
fn extract_use_import(node: &tree_sitter::Node, source: &[u8]) -> Option<Import> {
let use_text = Self::node_text(node, source);
let inner = use_text
.trim_start_matches("use")
.trim()
.trim_end_matches(';')
.trim();
if inner.is_empty() {
return None;
}
if inner.ends_with("::*") {
let source_path = inner.trim_end_matches("::*").to_string();
return Some(Import {
source: source_path,
names: vec!["*".to_string()],
});
}
if let Some(brace_start) = inner.rfind("::{") {
let source_path = inner[..brace_start].to_string();
let names_str = &inner[brace_start + 3..];
let names_str = names_str.trim_end_matches('}');
let names: Vec<String> = names_str
.split(',')
.map(|s| s.trim().to_string())
.filter(|s| !s.is_empty())
.collect();
return Some(Import {
source: source_path,
names,
});
}
if let Some(sep) = inner.rfind("::") {
let source_path = inner[..sep].to_string();
let name = inner[sep + 2..].to_string();
Some(Import {
source: source_path,
names: vec![name],
})
} else {
Some(Import {
source: String::new(),
names: vec![inner.to_string()],
})
}
}
fn extract_impl_methods(node: &tree_sitter::Node, source: &[u8]) -> Vec<Symbol> {
let mut methods = Vec::new();
let mut cursor = node.walk();
for child in node.children(&mut cursor) {
if child.kind() == "declaration_list" {
let mut inner_cursor = child.walk();
for item in child.children(&mut inner_cursor) {
if item.kind() == "function_item" {
let name = Self::extract_name(&item, source);
let visibility = if Self::is_public(&item, source) {
Visibility::Public
} else {
Visibility::Private
};
let signature = Self::extract_fn_signature(&item, source);
let body = Self::extract_fn_body(&item, source);
let start_line = item.start_position().row + 1;
let end_line = item.end_position().row + 1;
methods.push(Symbol {
name,
kind: SymbolKind::Method,
visibility,
signature,
body,
start_line,
end_line,
});
}
}
}
}
methods
}
}
impl LanguageSupport for RustLanguage {
fn ts_language(&self) -> TsLanguage {
tree_sitter_rust::LANGUAGE.into()
}
fn name(&self) -> &str {
"rust"
}
fn extract(&self, source: &str, tree: &tree_sitter::Tree) -> ParseResult {
let source_bytes = source.as_bytes();
let root = tree.root_node();
let mut symbols: Vec<Symbol> = Vec::new();
let mut imports: Vec<Import> = Vec::new();
let mut exports: Vec<Export> = Vec::new();
let mut cursor = root.walk();
for node in root.children(&mut cursor) {
match node.kind() {
"function_item" => {
let name = Self::extract_name(&node, source_bytes);
let is_pub = Self::is_public(&node, source_bytes);
let visibility = if is_pub {
Visibility::Public
} else {
Visibility::Private
};
let signature = Self::extract_fn_signature(&node, source_bytes);
let body = Self::extract_fn_body(&node, source_bytes);
let start_line = node.start_position().row + 1;
let end_line = node.end_position().row + 1;
if is_pub {
exports.push(Export {
name: name.clone(),
kind: SymbolKind::Function,
});
}
symbols.push(Symbol {
name,
kind: SymbolKind::Function,
visibility,
signature,
body,
start_line,
end_line,
});
}
"struct_item" => {
let name = Self::extract_name(&node, source_bytes);
let is_pub = Self::is_public(&node, source_bytes);
let visibility = if is_pub {
Visibility::Public
} else {
Visibility::Private
};
let signature = Self::type_signature(&node, source_bytes);
let body = Self::node_text(&node, source_bytes).to_string();
let start_line = node.start_position().row + 1;
let end_line = node.end_position().row + 1;
if is_pub {
exports.push(Export {
name: name.clone(),
kind: SymbolKind::Struct,
});
}
symbols.push(Symbol {
name,
kind: SymbolKind::Struct,
visibility,
signature,
body,
start_line,
end_line,
});
}
"enum_item" => {
let name = Self::extract_name(&node, source_bytes);
let is_pub = Self::is_public(&node, source_bytes);
let visibility = if is_pub {
Visibility::Public
} else {
Visibility::Private
};
let signature = Self::type_signature(&node, source_bytes);
let body = Self::node_text(&node, source_bytes).to_string();
let start_line = node.start_position().row + 1;
let end_line = node.end_position().row + 1;
if is_pub {
exports.push(Export {
name: name.clone(),
kind: SymbolKind::Enum,
});
}
symbols.push(Symbol {
name,
kind: SymbolKind::Enum,
visibility,
signature,
body,
start_line,
end_line,
});
}
"trait_item" => {
let name = Self::extract_name(&node, source_bytes);
let is_pub = Self::is_public(&node, source_bytes);
let visibility = if is_pub {
Visibility::Public
} else {
Visibility::Private
};
let signature = Self::first_line(&node, source_bytes);
let body = Self::node_text(&node, source_bytes).to_string();
let start_line = node.start_position().row + 1;
let end_line = node.end_position().row + 1;
if is_pub {
exports.push(Export {
name: name.clone(),
kind: SymbolKind::Trait,
});
}
symbols.push(Symbol {
name,
kind: SymbolKind::Trait,
visibility,
signature,
body,
start_line,
end_line,
});
}
"impl_item" => {
let methods = Self::extract_impl_methods(&node, source_bytes);
for method in &methods {
if method.visibility == Visibility::Public {
exports.push(Export {
name: method.name.clone(),
kind: SymbolKind::Method,
});
}
}
symbols.extend(methods);
}
"use_declaration" => {
if let Some(import) = Self::extract_use_import(&node, source_bytes) {
imports.push(import);
}
}
_ => {}
}
}
ParseResult {
symbols,
imports,
exports,
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::parser::language::{SymbolKind, Visibility};
fn make_parser() -> tree_sitter::Parser {
let mut parser = tree_sitter::Parser::new();
parser
.set_language(&tree_sitter_rust::LANGUAGE.into())
.expect("failed to set language");
parser
}
#[test]
fn test_extract_public_function() {
let source = r#"
pub fn greet(name: &str) -> String {
format!("Hello, {}!", name)
}
"#;
let mut parser = make_parser();
let tree = parser.parse(source, None).expect("parse failed");
let lang = RustLanguage;
let result = lang.extract(source, &tree);
assert_eq!(result.symbols.len(), 1);
let sym = &result.symbols[0];
assert_eq!(sym.name, "greet");
assert_eq!(sym.kind, SymbolKind::Function);
assert_eq!(sym.visibility, Visibility::Public);
assert!(
sym.signature.contains("pub fn greet"),
"signature: {}",
sym.signature
);
assert!(sym.body.contains("format!"), "body: {}", sym.body);
assert_eq!(result.exports.len(), 1);
assert_eq!(result.exports[0].name, "greet");
}
#[test]
fn test_extract_private_function() {
let source = r#"
fn helper(x: i32) -> i32 {
x * 2
}
"#;
let mut parser = make_parser();
let tree = parser.parse(source, None).expect("parse failed");
let lang = RustLanguage;
let result = lang.extract(source, &tree);
assert_eq!(result.symbols.len(), 1);
let sym = &result.symbols[0];
assert_eq!(sym.name, "helper");
assert_eq!(sym.kind, SymbolKind::Function);
assert_eq!(sym.visibility, Visibility::Private);
assert!(
sym.signature.contains("fn helper"),
"signature: {}",
sym.signature
);
assert!(sym.body.contains("x * 2"), "body: {}", sym.body);
assert!(
result.exports.is_empty(),
"private function should not be exported"
);
}
#[test]
fn test_extract_struct() {
let source = r#"
pub struct Point {
pub x: f64,
pub y: f64,
}
"#;
let mut parser = make_parser();
let tree = parser.parse(source, None).expect("parse failed");
let lang = RustLanguage;
let result = lang.extract(source, &tree);
assert_eq!(result.symbols.len(), 1);
let sym = &result.symbols[0];
assert_eq!(sym.name, "Point");
assert_eq!(sym.kind, SymbolKind::Struct);
assert_eq!(sym.visibility, Visibility::Public);
assert!(
sym.signature.contains("pub struct Point"),
"signature: {}",
sym.signature
);
assert_eq!(result.exports.len(), 1);
assert_eq!(result.exports[0].name, "Point");
assert_eq!(result.exports[0].kind, SymbolKind::Struct);
}
#[test]
fn test_extract_enum() {
let source = r#"
pub enum Direction {
North,
South,
East,
West,
}
"#;
let mut parser = make_parser();
let tree = parser.parse(source, None).expect("parse failed");
let lang = RustLanguage;
let result = lang.extract(source, &tree);
assert_eq!(result.symbols.len(), 1);
let sym = &result.symbols[0];
assert_eq!(sym.name, "Direction");
assert_eq!(sym.kind, SymbolKind::Enum);
assert_eq!(sym.visibility, Visibility::Public);
assert!(
sym.signature.contains("pub enum Direction"),
"signature: {}",
sym.signature
);
assert_eq!(result.exports.len(), 1);
assert_eq!(result.exports[0].name, "Direction");
assert_eq!(result.exports[0].kind, SymbolKind::Enum);
}
#[test]
fn test_extract_trait() {
let source = r#"
pub trait Animal {
fn name(&self) -> &str;
fn sound(&self) -> &str;
}
"#;
let mut parser = make_parser();
let tree = parser.parse(source, None).expect("parse failed");
let lang = RustLanguage;
let result = lang.extract(source, &tree);
assert_eq!(result.symbols.len(), 1);
let sym = &result.symbols[0];
assert_eq!(sym.name, "Animal");
assert_eq!(sym.kind, SymbolKind::Trait);
assert_eq!(sym.visibility, Visibility::Public);
assert!(
sym.signature.contains("pub trait Animal"),
"signature: {}",
sym.signature
);
assert_eq!(result.exports.len(), 1);
assert_eq!(result.exports[0].name, "Animal");
assert_eq!(result.exports[0].kind, SymbolKind::Trait);
}
#[test]
fn test_extract_use_import() {
let source = r#"
use std::collections::HashMap;
use std::io::{Read, Write};
"#;
let mut parser = make_parser();
let tree = parser.parse(source, None).expect("parse failed");
let lang = RustLanguage;
let result = lang.extract(source, &tree);
assert_eq!(result.imports.len(), 2);
let first = &result.imports[0];
assert_eq!(first.source, "std::collections");
assert_eq!(first.names, vec!["HashMap"]);
let second = &result.imports[1];
assert_eq!(second.source, "std::io");
assert!(
second.names.contains(&"Read".to_string()),
"names: {:?}",
second.names
);
assert!(
second.names.contains(&"Write".to_string()),
"names: {:?}",
second.names
);
}
#[test]
fn test_extract_impl_methods() {
let source = r#"
struct Counter {
count: u32,
}
impl Counter {
pub fn increment(&mut self) {
self.count += 1;
}
fn reset(&mut self) {
self.count = 0;
}
}
"#;
let mut parser = make_parser();
let tree = parser.parse(source, None).expect("parse failed");
let lang = RustLanguage;
let result = lang.extract(source, &tree);
let methods: Vec<&Symbol> = result
.symbols
.iter()
.filter(|s| s.kind == SymbolKind::Method)
.collect();
assert_eq!(
methods.len(),
2,
"expected 2 methods, got: {:?}",
methods.iter().map(|m| &m.name).collect::<Vec<_>>()
);
let increment = methods
.iter()
.find(|m| m.name == "increment")
.expect("increment method not found");
assert_eq!(increment.visibility, Visibility::Public);
assert!(
increment.signature.contains("pub fn increment"),
"sig: {}",
increment.signature
);
assert!(
increment.body.contains("self.count += 1"),
"body: {}",
increment.body
);
let reset = methods
.iter()
.find(|m| m.name == "reset")
.expect("reset method not found");
assert_eq!(reset.visibility, Visibility::Private);
assert!(
reset.signature.contains("fn reset"),
"sig: {}",
reset.signature
);
let method_exports: Vec<&Export> = result
.exports
.iter()
.filter(|e| e.kind == SymbolKind::Method)
.collect();
assert_eq!(method_exports.len(), 1);
assert_eq!(method_exports[0].name, "increment");
}
#[test]
fn test_extract_glob_import() {
let source = r#"
use std::collections::*;
"#;
let mut parser = make_parser();
let tree = parser.parse(source, None).expect("parse failed");
let lang = RustLanguage;
let result = lang.extract(source, &tree);
assert_eq!(result.imports.len(), 1);
assert_eq!(result.imports[0].source, "std::collections");
assert!(result.imports[0].names.contains(&"*".to_string()));
}
#[test]
fn test_extract_bare_identifier_import() {
let source = r#"
use HashMap;
"#;
let mut parser = make_parser();
let tree = parser.parse(source, None).expect("parse failed");
let lang = RustLanguage;
let result = lang.extract(source, &tree);
assert_eq!(result.imports.len(), 1);
assert!(result.imports[0].source.is_empty());
assert!(result.imports[0].names.contains(&"HashMap".to_string()));
}
#[test]
fn test_extract_trait_method_no_body() {
let source = r#"
pub trait Serializer {
fn serialize(&self) -> String;
fn deserialize(data: &str) -> Self;
}
"#;
let mut parser = make_parser();
let tree = parser.parse(source, None).expect("parse failed");
let lang = RustLanguage;
let result = lang.extract(source, &tree);
let traits: Vec<_> = result
.symbols
.iter()
.filter(|s| s.kind == SymbolKind::Trait)
.collect();
assert!(!traits.is_empty());
assert!(
traits[0].body.contains("serialize"),
"body: {}",
traits[0].body
);
}
#[test]
fn test_trait_method_no_body() {
let source = r#"pub trait Handler {
fn handle(&self, req: Request) -> Response;
fn name(&self) -> &str;
}
"#;
let mut parser = make_parser();
let tree = parser.parse(source, None).expect("parse failed");
let lang = RustLanguage;
let result = lang.extract(source, &tree);
let traits: Vec<_> = result
.symbols
.iter()
.filter(|s| s.kind == SymbolKind::Trait)
.collect();
assert!(!traits.is_empty());
}
#[test]
fn test_extract_name_fallback() {
let source = "use std::io;\n";
let mut parser = make_parser();
let tree = parser.parse(source, None).expect("parse failed");
let lang = RustLanguage;
let result = lang.extract(source, &tree);
let _ = result;
}
#[test]
fn test_bare_use_import() {
let source = "use serde;\n";
let mut parser = make_parser();
let tree = parser.parse(source, None).expect("parse failed");
let lang = RustLanguage;
let result = lang.extract(source, &tree);
assert!(!result.imports.is_empty());
}
#[test]
fn test_type_alias_parsed() {
let source = "pub type Result<T> = std::result::Result<T, Error>;\n";
let mut parser = make_parser();
let tree = parser.parse(source, None).expect("parse failed");
let lang = RustLanguage;
let result = lang.extract(source, &tree);
let _ = result;
}
#[test]
fn test_private_trait() {
let source = r#"
trait InternalHelper {
fn do_work(&self);
}
"#;
let mut parser = make_parser();
let tree = parser.parse(source, None).expect("parse failed");
let lang = RustLanguage;
let result = lang.extract(source, &tree);
let traits: Vec<_> = result
.symbols
.iter()
.filter(|s| s.kind == SymbolKind::Trait)
.collect();
assert_eq!(traits.len(), 1);
assert_eq!(traits[0].name, "InternalHelper");
assert_eq!(traits[0].visibility, Visibility::Private);
assert!(result.exports.iter().all(|e| e.name != "InternalHelper"));
}
#[test]
fn test_private_enum() {
let source = "enum InternalState {\n Active,\n Inactive,\n}\n";
let mut parser = make_parser();
let tree = parser.parse(source, None).expect("parse failed");
let lang = RustLanguage;
let result = lang.extract(source, &tree);
let enums: Vec<_> = result
.symbols
.iter()
.filter(|s| s.kind == SymbolKind::Enum)
.collect();
assert_eq!(enums.len(), 1);
assert_eq!(enums[0].name, "InternalState");
assert_eq!(enums[0].visibility, Visibility::Private);
assert!(result.exports.is_empty());
}
#[test]
fn test_trait_method_without_body() {
let source = "pub trait Greeter {\n fn greet(&self) -> String;\n}\n";
let mut parser = make_parser();
let tree = parser.parse(source, None).expect("parse failed");
let lang = RustLanguage;
let result = lang.extract(source, &tree);
let traits: Vec<_> = result
.symbols
.iter()
.filter(|s| s.kind == SymbolKind::Trait)
.collect();
assert!(!traits.is_empty());
}
}