use super::trait_def::Driver;
use crate::core::ir::{
Blueprint, Class, Documentation, Element, Function, Method, Parameter, Property, Signature,
Visibility,
};
use crate::drivers::LanguageTerminology;
use anyhow::{Context, Result, anyhow};
use std::path::PathBuf;
use tree_sitter::{Node, Parser};
/// Front-end [`Driver`] that turns Rust source text into a [`Blueprint`].
///
/// Rust is handled purely as an input language: `parse` is supported, while
/// `generate` always returns an error.
pub struct RustDriver;

impl Driver for RustDriver {
    /// Returns `LanguageTerminology::rust()`.
    fn terminology(&self) -> LanguageTerminology {
        LanguageTerminology::rust()
    }

    /// Parses `content` as Rust source code into a [`Blueprint`].
    fn parse(&self, content: &str) -> Result<Blueprint> {
        parse_rust_code(content)
    }

    /// Always fails: emitting Rust from a [`Blueprint`] is not supported.
    fn generate(&self, _blueprint: &Blueprint) -> Result<String> {
        Err(anyhow!(
            "RustDriver::generate: Not implemented yet (Rust is a source, not a target)"
        ))
    }
}
/// Parses Rust source text into a [`Blueprint`].
///
/// Only the direct children of the syntax tree's root are inspected:
/// * `use` declarations contribute to `dependencies` (sorted and de-duplicated);
/// * `struct` items become [`Element::Class`] entries;
/// * free `fn` items become [`Element::Function`] entries;
/// * methods from `impl` blocks are attached to the struct of the same name.
///
/// Items whose extraction fails are skipped (best effort), so one malformed
/// item does not abort the whole parse. `source_path` is always `unknown.rs`
/// because only the text is available here.
///
/// # Errors
/// Returns an error if the Rust grammar cannot be loaded into the parser or if
/// tree-sitter produces no tree.
fn parse_rust_code(source: &str) -> Result<Blueprint> {
    let mut parser = Parser::new();
    let language = tree_sitter_rust::language();
    parser
        .set_language(language)
        .context("Failed to set Rust language")?;
    let tree = parser
        .parse(source, None)
        .ok_or_else(|| anyhow!("Failed to parse Rust code"))?;

    let mut elements = Vec::new();
    let mut dependencies = Vec::new();
    // Rust allows an `impl` block to precede the struct it implements, so the
    // methods are buffered here and attached only once every struct is known.
    let mut pending_methods = Vec::new();

    let mut cursor = tree.walk();
    for child in tree.root_node().children(&mut cursor) {
        match child.kind() {
            "use_declaration" => {
                if let Ok(deps) = extract_imports(&child, source) {
                    dependencies.extend(deps);
                }
            }
            "struct_item" => {
                if let Ok(class) = extract_struct(&child, source) {
                    elements.push(Element::Class(class));
                }
            }
            "function_item" => {
                if let Ok(function) = extract_function(&child, source) {
                    elements.push(Element::Function(function));
                }
            }
            "impl_item" => {
                if let Ok(methods) = extract_impl_methods(&child, source)
                    && let Some(class_name) = get_impl_struct_name(&child, source)
                {
                    pending_methods.push((class_name, methods));
                }
            }
            _ => {}
        }
    }

    // Attach buffered methods in source order; impls for types that are not
    // structs defined in this file have no class to attach to and are dropped.
    for (class_name, methods) in pending_methods {
        let target = elements
            .iter_mut()
            .find(|e| matches!(e, Element::Class(c) if c.name == class_name));
        if let Some(Element::Class(class)) = target {
            class.methods.extend(methods);
        }
    }

    dependencies.sort();
    dependencies.dedup();
    Ok(Blueprint {
        source_path: PathBuf::from("unknown.rs"),
        language: "rust".to_string(),
        elements,
        dependencies,
    })
}
fn extract_imports(node: &Node, source: &str) -> Result<Vec<String>, String> {
let mut imports = Vec::new();
let text = node
.utf8_text(source.as_bytes())
.map(|s| s.to_string())
.unwrap_or_default();
if let Some(use_path) = text
.strip_prefix("use ")
.map(|s| s.trim_end_matches(';').trim())
{
if let Some(root) = use_path.split("::").next() {
if root != "crate" && root != "self" {
imports.push(root.to_string());
} else if root == "crate" {
if let Some(submodule) = use_path.split("::").nth(1)
&& !submodule.is_empty()
{
imports.push(submodule.to_string());
}
}
}
}
Ok(imports)
}
/// Builds an IR [`Class`] from a `struct_item` node.
///
/// The struct's name is the first direct `type_identifier` child. Methods start
/// empty; they are attached later from `impl` blocks.
///
/// # Errors
/// Returns an error when no non-empty struct name can be read.
fn extract_struct(node: &Node, source: &str) -> Result<Class> {
    let mut cursor = node.walk();
    let Some(name) = node
        .children(&mut cursor)
        .find(|child| child.kind() == "type_identifier")
        .map(|ident| {
            ident
                .utf8_text(source.as_bytes())
                .map(str::to_string)
                .unwrap_or_default()
        })
        .filter(|found| !found.is_empty())
    else {
        return Err(anyhow!("Could not find struct name"));
    };

    Ok(Class {
        visibility: extract_visibility(node, source),
        properties: extract_struct_fields(node, source),
        documentation: extract_doc_comment(node, source),
        methods: Vec::new(),
        name,
    })
}
/// Builds an IR [`Function`] from a free-standing `function_item` node.
///
/// # Errors
/// Propagates the error from reading the function's `name`, or from building
/// its signature.
fn extract_function(node: &Node, source: &str) -> Result<Function> {
    let name = get_node_text(node, source, "name")?;
    let signature = extract_function_signature(node, source)?;
    Ok(Function {
        visibility: extract_visibility(node, source),
        documentation: extract_doc_comment(node, source),
        name,
        signature,
    })
}
/// Collects the methods declared in an `impl_item` node.
///
/// Associated functions live in the impl's `body` field (a `declaration_list`),
/// not directly under the `impl_item`, so that list is what gets scanned.
/// A method is marked `is_static` when it has no `self` receiver
/// (e.g. `fn new() -> Self`). Functions whose name cannot be read are skipped.
///
/// # Errors
/// Currently never fails; the `Result` is kept for the caller's benefit.
fn extract_impl_methods(node: &Node, source: &str) -> Result<Vec<Method>> {
    let mut methods = Vec::new();
    // An impl without a body has nothing to collect.
    let Some(body) = node.child_by_field_name("body") else {
        return Ok(methods);
    };
    let mut cursor = body.walk();
    for child in body.children(&mut cursor) {
        if child.kind() != "function_item" {
            continue;
        }
        // The method name is the `identifier` stored in the `name` field.
        let Some(name) = child
            .child_by_field_name("name")
            .and_then(|ident| ident.utf8_text(source.as_bytes()).ok())
            .map(|text| text.trim().to_string())
        else {
            continue;
        };
        let Ok(signature) = extract_function_signature(&child, source) else {
            continue;
        };
        methods.push(Method {
            name,
            visibility: extract_visibility(&child, source),
            is_static: !has_self_receiver(&child),
            signature,
            documentation: extract_doc_comment(&child, source),
        });
    }
    Ok(methods)
}

/// Returns true when a `function_item`'s parameter list contains a `self` receiver.
fn has_self_receiver(function: &Node) -> bool {
    let Some(params) = function.child_by_field_name("parameters") else {
        return false;
    };
    let mut cursor = params.walk();
    let found = params.children(&mut cursor).any(|param| match param.kind() {
        "self_parameter" => true,
        // Typed receivers such as `self: Box<Self>` parse as a plain `parameter`
        // whose pattern is the `self` keyword.
        "parameter" => param
            .child_by_field_name("pattern")
            .is_some_and(|pattern| pattern.kind() == "self"),
        _ => false,
    });
    found
}
/// Returns the name of the type an `impl_item` is for.
///
/// The impl's `type` field is read, so `impl Trait for Foo` yields `Foo`. The
/// first `type_identifier` child would be the trait name `Trait`.
///
/// Generic arguments and module paths are stripped: `impl<T> Foo<T>` and
/// `impl path::Foo` both yield `Foo`. Any other self type (references, tuples,
/// …) yields `None`, since it cannot name a struct.
fn get_impl_struct_name(node: &Node, source: &str) -> Option<String> {
    let mut self_type = node.child_by_field_name("type")?;
    // `Foo<T>` -> `Foo`
    if self_type.kind() == "generic_type" {
        self_type = self_type.child_by_field_name("type")?;
    }
    // `path::Foo` -> `Foo`
    if self_type.kind() == "scoped_type_identifier" {
        self_type = self_type.child_by_field_name("name")?;
    }
    if self_type.kind() != "type_identifier" {
        return None;
    }
    get_node_text(&self_type, source, "").ok()
}
/// Collects the named fields of a `struct_item` as IR properties.
///
/// Named fields live inside the struct's `body` field (a
/// `field_declaration_list`), not directly under the `struct_item`, so the body
/// is what gets scanned.
///
/// Unit structs have no body. Tuple-struct bodies contain no
/// `field_declaration`s. Both therefore yield an empty list. Fields that fail
/// to extract are skipped.
fn extract_struct_fields(node: &Node, source: &str) -> Vec<Property> {
    let mut properties = Vec::new();
    let Some(body) = node.child_by_field_name("body") else {
        return properties;
    };
    let mut cursor = body.walk();
    for child in body.children(&mut cursor) {
        if child.kind() == "field_declaration"
            && let Ok(prop) = extract_field(&child, source)
        {
            properties.push(prop);
        }
    }
    properties
}
/// Builds an IR [`Property`] from a `field_declaration` node.
///
/// The type annotation is optional: if it cannot be read it is `None`.
///
/// # Errors
/// Propagates the error from reading the field's `name`.
fn extract_field(node: &Node, source: &str) -> Result<Property> {
    Ok(Property {
        name: get_node_text(node, source, "name")?,
        type_annotation: get_node_text(node, source, "type").ok(),
        visibility: extract_visibility(node, source),
        documentation: extract_doc_comment(node, source),
    })
}
/// Builds a [`Signature`] for a function node from its `parameters` child and
/// its return type.
///
/// A missing `parameters` child yields an empty parameter list.
///
/// # Errors
/// Currently never fails; the `Result` is kept for the callers' benefit.
fn extract_function_signature(node: &Node, source: &str) -> Result<Signature> {
    let mut cursor = node.walk();
    let parameter_list = node
        .children(&mut cursor)
        .find(|child| child.kind() == "parameters");
    let parameters = match parameter_list {
        Some(list) => extract_parameters(&list, source),
        None => Vec::new(),
    };
    Ok(Signature {
        parameters,
        return_type: extract_return_type(node, source),
    })
}
/// Converts every `parameter` child of a `parameters` node into an IR
/// [`Parameter`].
///
/// Other children (punctuation, `self` receivers) are ignored, as are
/// parameters that fail to extract.
fn extract_parameters(node: &Node, source: &str) -> Vec<Parameter> {
    let mut cursor = node.walk();
    let collected: Vec<Parameter> = node
        .children(&mut cursor)
        .filter(|child| child.kind() == "parameter")
        .filter_map(|child| extract_parameter(&child, source).ok())
        .collect();
    collected
}
/// Builds an IR [`Parameter`] from a single `parameter` node.
///
/// The name is the parameter's `pattern` text. The type annotation is optional.
/// Rust has no default argument values, so `default_value` is always `None`.
///
/// # Errors
/// Propagates the error from reading the `pattern`.
fn extract_parameter(node: &Node, source: &str) -> Result<Parameter> {
    let type_annotation = get_node_text(node, source, "type").ok();
    get_node_text(node, source, "pattern").map(|name| Parameter {
        name,
        type_annotation,
        default_value: None,
    })
}
/// Returns the declared return type of a function node, if any.
///
/// The return type is the sibling immediately following the `->` token.
/// Functions without `->` yield `None`. Text that is not valid UTF-8 yields an
/// empty string.
fn extract_return_type(node: &Node, source: &str) -> Option<String> {
    let mut cursor = node.walk();
    let return_node = node
        .children(&mut cursor)
        .filter(|child| child.kind() == "->")
        .find_map(|arrow| arrow.next_sibling())?;
    let text = return_node.utf8_text(source.as_bytes()).unwrap_or("");
    Some(text.trim().to_string())
}
/// Maps an item's `visibility_modifier` child (if any) onto the IR [`Visibility`].
///
/// | Rust                          | IR          |
/// |-------------------------------|-------------|
/// | no modifier, `pub(self)`      | `Private`   |
/// | `pub(crate)`, `pub(in path)`  | `Internal`  |
/// | `pub(super)`                  | `Protected` |
/// | `pub`                         | `Public`    |
///
/// tree-sitter wraps the `pub` keyword inside a `visibility_modifier` node. The
/// modifier is therefore what must be searched for among the item's children,
/// not a bare `pub` token.
fn extract_visibility(node: &Node, source: &str) -> Visibility {
    let mut cursor = node.walk();
    let Some(modifier) = node
        .children(&mut cursor)
        .find(|child| child.kind() == "visibility_modifier")
    else {
        return Visibility::Private;
    };
    // Strip whitespace so `pub (crate)` and `pub(crate)` compare equal; text
    // that is not valid UTF-8 falls back to plain `pub`.
    let text: String = modifier
        .utf8_text(source.as_bytes())
        .unwrap_or("pub")
        .chars()
        .filter(|c| !c.is_whitespace())
        .collect();
    match text.as_str() {
        // Bare `crate` is the legacy crate-visibility modifier the grammar still accepts.
        "crate" | "pub(crate)" => Visibility::Internal,
        "pub(super)" => Visibility::Protected,
        // `pub(self)` is exactly equivalent to no modifier.
        "pub(self)" => Visibility::Private,
        // `pub(in path)` restricts visibility to a module of this crate.
        restricted if restricted.starts_with("pub(in") => Visibility::Internal,
        _ => Visibility::Public,
    }
}
/// Placeholder for doc-comment extraction.
///
/// Doc comments are not read yet. Every node gets an empty [`Documentation`]
/// with no summary, no description and no examples.
fn extract_doc_comment(_node: &Node, _source: &str) -> Documentation {
    let examples = Vec::new();
    Documentation {
        description: None,
        summary: None,
        examples,
    }
}
/// Returns the trimmed source text of `node` itself or of one of its children.
///
/// `target_kind` selects what to read:
/// * `""` — `node` itself;
/// * a grammar field name such as `"name"`, `"type"` or `"pattern"` — the child
///   stored in that field;
/// * otherwise a node kind such as `"identifier"` — the first direct child of
///   that kind.
///
/// Field names are tried first because the children callers want are reachable
/// by field, not by kind. For example, a function's name is an `identifier`
/// node stored in the `name` field. Matching only on kind made those lookups
/// always fail.
///
/// Text that is not valid UTF-8 yields an empty string.
///
/// # Errors
/// Returns an error when neither a field nor a direct child of the requested
/// kind exists.
fn get_node_text(node: &Node, source: &str, target_kind: &str) -> Result<String> {
    let target = if target_kind.is_empty() {
        Some(*node)
    } else {
        node.child_by_field_name(target_kind).or_else(|| {
            let mut cursor = node.walk();
            let by_kind = node
                .children(&mut cursor)
                .find(|child| child.kind() == target_kind);
            by_kind
        })
    };
    let target = target.ok_or_else(|| anyhow!("Could not find {} in node", target_kind))?;
    Ok(target
        .utf8_text(source.as_bytes())
        .map(|s| s.trim().to_string())
        .unwrap_or_default())
}