use similarity_core::language_parser::{
GenericFunctionDef, GenericTypeDef, Language, LanguageParser,
};
use similarity_core::tree::TreeNode;
use std::error::Error;
use std::rc::Rc;
use tree_sitter::{Node, Parser};
pub struct RustParser {
parser: Parser,
}
impl RustParser {
pub fn new() -> Result<Self, Box<dyn Error>> {
let mut parser = Parser::new();
parser.set_language(&tree_sitter_rust::LANGUAGE.into())?;
Ok(RustParser { parser })
}
fn extract_functions_from_node<'a>(
&self,
node: Node<'a>,
source: &'a str,
functions: &mut Vec<GenericFunctionDef>,
skip_test: bool,
) {
match node.kind() {
"function_item" => {
if skip_test && self.is_test_function(node, source) {
return;
}
if let Some(func_def) = self.extract_function_definition(node, source) {
functions.push(func_def);
}
}
"impl_item" => {
for child in node.children(&mut node.walk()) {
if child.kind() == "declaration_list" {
for method in child.children(&mut child.walk()) {
if method.kind() == "function_item" {
if skip_test && self.is_test_function(method, source) {
continue;
}
if let Some(func_def) =
self.extract_function_definition(method, source)
{
functions.push(func_def);
}
}
}
}
}
}
_ => {
for child in node.children(&mut node.walk()) {
self.extract_functions_from_node(child, source, functions, skip_test);
}
}
}
}
fn is_test_function(&self, node: Node, source: &str) -> bool {
if let Some(prev_sibling) = node.prev_sibling() {
if prev_sibling.kind() == "attribute_item" {
let attr_text = &source[prev_sibling.byte_range().start..prev_sibling.byte_range().end];
if attr_text.contains("test") {
return true;
}
}
}
for child in node.children(&mut node.walk()) {
if child.kind() == "identifier" {
let name = &source[child.byte_range().start..child.byte_range().end];
if name.starts_with("test_") {
return true;
}
break;
}
}
false
}
pub fn is_test_function_by_name(&self, name: &str, _source: &str) -> bool {
name.starts_with("test_")
}
pub fn extract_functions_with_skip_test(
&mut self,
source: &str,
_filename: &str,
skip_test: bool,
) -> Result<Vec<GenericFunctionDef>, Box<dyn Error>> {
let tree = self.parser.parse(source, None).ok_or("Failed to parse source")?;
let root_node = tree.root_node();
let mut functions = Vec::new();
self.extract_functions_from_node(root_node, source, &mut functions, skip_test);
Ok(functions)
}
fn extract_function_definition(&self, node: Node, source: &str) -> Option<GenericFunctionDef> {
let mut name = String::new();
let mut is_async = false;
let mut is_method = false;
let mut class_name: Option<String> = None;
let mut parameters = Vec::new();
let mut body_start_line = 0;
let mut body_end_line = 0;
for child in node.children(&mut node.walk()) {
if child.kind() == "async" {
is_async = true;
}
}
if let Some(parent) = node.parent() {
if parent.kind() == "declaration_list" {
if let Some(impl_node) = parent.parent() {
if impl_node.kind() == "impl_item" {
is_method = true;
for child in impl_node.children(&mut impl_node.walk()) {
if child.kind() == "type_identifier" {
class_name = Some(
source[child.byte_range().start..child.byte_range().end]
.to_string(),
);
break;
}
}
}
}
}
}
for child in node.children(&mut node.walk()) {
match child.kind() {
"identifier" => {
if name.is_empty() {
name = source[child.byte_range().start..child.byte_range().end].to_string();
}
}
"parameters" => {
for param in child.children(&mut child.walk()) {
if param.kind() == "parameter" || param.kind() == "self_parameter" {
if let Some(pattern) = param.child_by_field_name("pattern") {
parameters.push(
source[pattern.byte_range().start..pattern.byte_range().end]
.to_string(),
);
} else if param.kind() == "self_parameter" {
parameters.push("self".to_string());
}
}
}
}
"block" => {
let block_text = &source[child.byte_range().start..child.byte_range().end];
if let Some(open_pos) = block_text.find('{') {
if let Some(close_pos) = block_text.rfind('}') {
let inner_content = &block_text[open_pos + 1..close_pos].trim();
let _lines_before_block =
source[..child.byte_range().start].lines().count();
let lines_before_content =
source[..child.byte_range().start + open_pos + 1].lines().count();
body_start_line = (lines_before_content + 1) as u32;
let content_lines = inner_content.lines().count();
body_end_line =
body_start_line + content_lines.saturating_sub(1) as u32;
}
}
if body_start_line == 0 {
body_start_line = (child.start_position().row + 1) as u32;
body_end_line = (child.end_position().row + 1) as u32;
}
}
_ => {}
}
}
if !name.is_empty() {
Some(GenericFunctionDef {
name,
start_line: (node.start_position().row + 1) as u32,
end_line: (node.end_position().row + 1) as u32,
body_start_line,
body_end_line,
is_async,
is_generator: false, is_method,
class_name,
decorators: Vec::new(), parameters,
})
} else {
None
}
}
#[allow(clippy::only_used_in_recursion)]
fn convert_node_to_tree(&self, node: Node, source: &str) -> Rc<TreeNode> {
let label = node.kind().to_string();
let value = match node.kind() {
"identifier" | "string_literal" | "char_literal" | "integer_literal"
| "float_literal" | "true" | "false" | "+" | "-" | "*" | "/" | "%" | "==" | "!="
| "<" | ">" | "<=" | ">=" | "&&" | "||" | "!" | "&" | "|" | "^" | "<<" | ">>"
| "type_identifier" | "field_identifier" => {
source[node.byte_range().start..node.byte_range().end].to_string()
}
_ => String::new(),
};
let mut tree_node = TreeNode::new(label, value, 0);
for child in node.children(&mut node.walk()) {
if !child.is_extra() {
tree_node.add_child(self.convert_node_to_tree(child, source));
}
}
Rc::new(tree_node)
}
fn extract_types_from_node<'a>(
&self,
node: Node<'a>,
source: &'a str,
types: &mut Vec<GenericTypeDef>,
) {
match node.kind() {
"struct_item" => {
if let Some(type_def) = self.extract_struct_definition(node, source) {
types.push(type_def);
}
}
"enum_item" => {
if let Some(type_def) = self.extract_enum_definition(node, source) {
types.push(type_def);
}
}
"type_alias" => {
if let Some(type_def) = self.extract_type_alias(node, source) {
types.push(type_def);
}
}
_ => {
for child in node.children(&mut node.walk()) {
self.extract_types_from_node(child, source, types);
}
}
}
}
fn extract_struct_definition(&self, node: Node, source: &str) -> Option<GenericTypeDef> {
let mut name = String::new();
let mut fields = Vec::new();
for child in node.children(&mut node.walk()) {
match child.kind() {
"type_identifier" => {
if name.is_empty() {
name = source[child.byte_range().start..child.byte_range().end].to_string();
}
}
"field_declaration_list" => {
for field in child.children(&mut child.walk()) {
if field.kind() == "field_declaration" {
if let Some(field_name) = field.child_by_field_name("name") {
let field_name_str = source
[field_name.byte_range().start..field_name.byte_range().end]
.to_string();
fields.push(field_name_str);
}
}
}
}
_ => {}
}
}
if !name.is_empty() {
Some(GenericTypeDef {
name,
kind: "struct".to_string(),
start_line: (node.start_position().row + 1) as u32,
end_line: (node.end_position().row + 1) as u32,
fields,
})
} else {
None
}
}
fn extract_enum_definition(&self, node: Node, source: &str) -> Option<GenericTypeDef> {
let mut name = String::new();
let mut variants = Vec::new();
for child in node.children(&mut node.walk()) {
match child.kind() {
"type_identifier" => {
if name.is_empty() {
name = source[child.byte_range().start..child.byte_range().end].to_string();
}
}
"enum_variant_list" => {
for variant in child.children(&mut child.walk()) {
if variant.kind() == "enum_variant" {
if let Some(variant_name) = variant.child_by_field_name("name") {
let variant_name_str = source[variant_name.byte_range().start
..variant_name.byte_range().end]
.to_string();
variants.push(variant_name_str);
}
}
}
}
_ => {}
}
}
if !name.is_empty() {
Some(GenericTypeDef {
name,
kind: "enum".to_string(),
start_line: (node.start_position().row + 1) as u32,
end_line: (node.end_position().row + 1) as u32,
fields: variants,
})
} else {
None
}
}
fn extract_type_alias(&self, node: Node, source: &str) -> Option<GenericTypeDef> {
let mut name = String::new();
for child in node.children(&mut node.walk()) {
if child.kind() == "type_identifier" && name.is_empty() {
name = source[child.byte_range().start..child.byte_range().end].to_string();
break;
}
}
if !name.is_empty() {
Some(GenericTypeDef {
name,
kind: "type_alias".to_string(),
start_line: (node.start_position().row + 1) as u32,
end_line: (node.end_position().row + 1) as u32,
fields: Vec::new(),
})
} else {
None
}
}
}
fn find_first_function(node: Node) -> Option<Node> {
if node.kind() == "function_item" {
return Some(node);
}
for child in node.children(&mut node.walk()) {
if let Some(func) = find_first_function(child) {
return Some(func);
}
}
None
}
impl LanguageParser for RustParser {
fn parse(&mut self, source: &str, filename: &str) -> Result<Rc<TreeNode>, Box<dyn Error>> {
let wrapped_source = if source.trim_start() != source || !source.starts_with("fn ") {
format!("fn __dummy() {{ {} }}", source)
} else {
source.to_string()
};
let tree = self
.parser
.parse(&wrapped_source, None)
.ok_or_else(|| format!("Failed to parse {}", filename))?;
let root_node = tree.root_node();
if wrapped_source != source {
if let Some(func_node) = find_first_function(root_node) {
for child in func_node.children(&mut func_node.walk()) {
if child.kind() == "block" {
let mut block_children = Vec::new();
for block_child in child.children(&mut child.walk()) {
if block_child.kind() != "{" && block_child.kind() != "}" {
block_children
.push(self.convert_node_to_tree(block_child, &wrapped_source));
}
}
let mut root = TreeNode::new("block_content".to_string(), String::new(), 0);
for child in block_children {
root.add_child(child);
}
return Ok(Rc::new(root));
}
}
}
}
Ok(self.convert_node_to_tree(root_node, &wrapped_source))
}
fn extract_functions(
&mut self,
source: &str,
_filename: &str,
) -> Result<Vec<GenericFunctionDef>, Box<dyn Error>> {
let tree = self.parser.parse(source, None).ok_or("Failed to parse source")?;
let root_node = tree.root_node();
let mut functions = Vec::new();
self.extract_functions_from_node(root_node, source, &mut functions, false);
Ok(functions)
}
fn extract_types(
&mut self,
source: &str,
_filename: &str,
) -> Result<Vec<GenericTypeDef>, Box<dyn Error>> {
let tree = self.parser.parse(source, None).ok_or("Failed to parse source")?;
let root_node = tree.root_node();
let mut types = Vec::new();
self.extract_types_from_node(root_node, source, &mut types);
Ok(types)
}
fn language(&self) -> Language {
Language::Rust
}
}
impl Default for RustParser {
fn default() -> Self {
Self::new().expect("Failed to create Rust parser")
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_rust_functions() {
let mut parser = RustParser::new().unwrap();
let source = r#"
fn main() {
println!("Hello, world!");
}
async fn fetch_data(url: &str) -> Result<String, Error> {
let response = reqwest::get(url).await?;
response.text().await
}
impl MyStruct {
fn new() -> Self {
MyStruct { value: 0 }
}
fn get_value(&self) -> i32 {
self.value
}
}
"#;
let functions = parser.extract_functions(source, "test.rs").unwrap();
assert_eq!(functions.len(), 4);
assert_eq!(functions[0].name, "main");
assert!(!functions[0].is_async);
assert!(!functions[0].is_method);
assert_eq!(functions[1].name, "fetch_data");
assert!(!functions[1].is_method);
assert_eq!(functions[2].name, "new");
assert!(functions[2].is_method);
assert_eq!(functions[2].class_name, Some("MyStruct".to_string()));
assert_eq!(functions[3].name, "get_value");
assert!(functions[3].is_method);
assert_eq!(functions[3].parameters, vec!["self"]);
}
#[test]
fn test_rust_types() {
let mut parser = RustParser::new().unwrap();
let source = r#"
struct Point {
x: f64,
y: f64,
}
enum Color {
Red,
Green,
Blue,
RGB(u8, u8, u8),
}
type Distance = f64;
"#;
let types = parser.extract_types(source, "test.rs").unwrap();
assert!(types.len() >= 2);
assert_eq!(types[0].name, "Point");
assert_eq!(types[0].kind, "struct");
assert_eq!(types[0].fields, vec!["x", "y"]);
assert_eq!(types[1].name, "Color");
assert_eq!(types[1].kind, "enum");
assert_eq!(types[1].fields, vec!["Red", "Green", "Blue", "RGB"]);
}
}