use std::fs;
use std::path::Path;
use proc_macro2::Span;
use quote::ToTokens;
use syn::{
Attribute, Fields, GenericParam, Generics, ImplItem, Item, ItemEnum, ItemImpl, ItemStruct,
Type, Visibility as SynVisibility,
};
use super::error::{ParseError, ParseResult};
use super::types::{
CodeBase, EnumInfo, EnumVariant, FieldInfo, FileInfo, MethodInfo, ParameterInfo, Receiver,
StructInfo, Visibility,
};
/// Extracts struct, enum and inherent-method metadata from Rust source code.
///
/// Build one with [`RustParser::new`] (or `Default`), optionally adjust it with
/// [`RustParser::include_private`], then call [`RustParser::parse_file`] or
/// [`RustParser::parse_source`].
#[derive(Debug, Clone)]
pub struct RustParser {
    /// When `false`, structs, enums and methods that carry no visibility
    /// modifier (i.e. private items) are left out of the parse results.
    include_private: bool,
}
impl Default for RustParser {
fn default() -> Self {
Self::new()
}
}
impl RustParser {
/// Creates a parser that reports private items as well as public ones.
pub fn new() -> Self {
    let include_private = true;
    Self { include_private }
}
/// Builder-style setter: consumes the parser and returns it with the
/// private-item switch set to `include`.
///
/// Passing `false` makes the parser skip structs, enums and methods whose
/// visibility is private.
pub fn include_private(self, include: bool) -> Self {
    let mut parser = self;
    parser.include_private = include;
    parser
}
pub fn parse_file<P: AsRef<Path>>(&self, path: P) -> ParseResult<CodeBase> {
let path = path.as_ref();
let content = fs::read_to_string(path).map_err(|e| ParseError::FileRead {
path: path.to_path_buf(),
source: e,
})?;
self.parse_source(&content, path.to_string_lossy().to_string())
}
pub fn parse_source(&self, source: &str, file_path: String) -> ParseResult<CodeBase> {
let syntax = syn::parse_file(source).map_err(|e| ParseError::SyntaxError {
path: file_path.clone().into(),
message: e.to_string(),
})?;
let mut codebase = CodeBase::new();
let mut structs: Vec<StructInfo> = Vec::new();
let mut enums: Vec<EnumInfo> = Vec::new();
for item in &syntax.items {
match item {
Item::Struct(item_struct) => {
if let Some(struct_info) = self.parse_struct(item_struct, &file_path) {
structs.push(struct_info);
}
}
Item::Enum(item_enum) => {
if let Some(enum_info) = self.parse_enum(item_enum, &file_path) {
enums.push(enum_info);
}
}
_ => {}
}
}
for item in &syntax.items {
if let Item::Impl(item_impl) = item {
self.process_impl_block(item_impl, &file_path, &mut structs);
}
}
codebase.files.push(FileInfo {
path: file_path,
enum_count: enums.len(),
struct_count: structs.len(),
});
codebase.enums = enums;
codebase.structs = structs;
Ok(codebase)
}
/// Converts a `syn` struct item into a `StructInfo`.
///
/// Returns `None` when the struct is private and private items are excluded.
/// The `methods` list starts empty; it is filled in later from `impl` blocks.
fn parse_struct(&self, item: &ItemStruct, file_path: &str) -> Option<StructInfo> {
    let visibility = self.parse_visibility(&item.vis);
    if visibility == Visibility::Private && !self.include_private {
        return None;
    }

    let (derives, attributes) = self.extract_attributes(&item.attrs);
    let info = StructInfo {
        name: item.ident.to_string(),
        doc_comment: self.extract_doc_comment(&item.attrs),
        file_path: file_path.to_string(),
        line_number: self.get_line_number(item.ident.span()),
        visibility,
        generics: self.parse_generics(&item.generics),
        derives,
        attributes,
        fields: self.parse_fields(&item.fields),
        methods: Vec::new(),
    };
    Some(info)
}
/// Converts a `syn` enum item into an `EnumInfo`, including every variant.
///
/// Returns `None` when the enum is private and private items are excluded.
/// Explicit discriminants (`Variant = expr`) are stored as the token text of
/// the expression.
fn parse_enum(&self, item: &ItemEnum, file_path: &str) -> Option<EnumInfo> {
    let visibility = self.parse_visibility(&item.vis);
    let hidden = visibility == Visibility::Private && !self.include_private;
    if hidden {
        return None;
    }

    let mut variants: Vec<EnumVariant> = Vec::with_capacity(item.variants.len());
    for variant in &item.variants {
        let discriminant = variant
            .discriminant
            .as_ref()
            .map(|(_eq, value)| value.to_token_stream().to_string());
        variants.push(EnumVariant {
            name: variant.ident.to_string(),
            doc_comment: self.extract_doc_comment(&variant.attrs),
            fields: self.parse_fields(&variant.fields),
            discriminant,
        });
    }

    let (derives, attributes) = self.extract_attributes(&item.attrs);
    Some(EnumInfo {
        name: item.ident.to_string(),
        doc_comment: self.extract_doc_comment(&item.attrs),
        file_path: file_path.to_string(),
        line_number: self.get_line_number(item.ident.span()),
        visibility,
        generics: self.parse_generics(&item.generics),
        derives,
        attributes,
        variants,
    })
}
/// Attaches the methods of one inherent `impl` block to the struct it targets.
///
/// The block is ignored when it is a trait implementation, when its self type
/// is not a path type, or when no entry in `structs` has the same name as the
/// last path segment of the self type (so `impl` blocks for enums are not
/// recorded). Private methods are skipped when private items are excluded.
///
/// * `item_impl` - the `impl` block to inspect.
/// * `file_path` - recorded on every `MethodInfo` that is produced.
/// * `structs` - already-parsed structs; the matching one receives the methods.
fn process_impl_block(
    &self,
    item_impl: &ItemImpl,
    file_path: &str,
    structs: &mut [StructInfo],
) {
    // Only inherent impls contribute methods; trait impls are not recorded.
    if item_impl.trait_.is_some() {
        return;
    }

    // `impl module::Foo<T>` is matched by its final segment, `Foo`.
    let Type::Path(type_path) = &*item_impl.self_ty else {
        return;
    };
    let Some(type_name) = type_path
        .path
        .segments
        .last()
        .map(|seg| seg.ident.to_string())
    else {
        return;
    };
    let Some(struct_info) = structs.iter_mut().find(|s| s.name == type_name) else {
        return;
    };

    for item in &item_impl.items {
        // Associated consts, types and macros are not methods.
        let ImplItem::Fn(method) = item else {
            continue;
        };

        let visibility = self.parse_visibility(&method.vis);
        if !self.include_private && visibility == Visibility::Private {
            continue;
        }

        let doc_comment = self.extract_doc_comment(&method.attrs);
        let generics = self.parse_generics(&method.sig.generics);

        // Classify by the receiver's type rather than by the `&` shorthand
        // token: syn fills in `ty` for `&self` / `&mut self` as well, so this
        // also handles the typed forms `self: &Self` and `self: &mut Self`.
        // Any other receiver type (`self`, `self: Box<Self>`, ...) is by value.
        let receiver = method.sig.receiver().map(|recv| match &*recv.ty {
            Type::Reference(reference) if reference.mutability.is_some() => Receiver::RefMut,
            Type::Reference(_) => Receiver::Ref,
            _ => Receiver::Value,
        });

        // The receiver is reported separately, so only typed arguments are
        // listed as parameters.
        let parameters: Vec<ParameterInfo> = method
            .sig
            .inputs
            .iter()
            .filter_map(|arg| match arg {
                syn::FnArg::Typed(pat_type) => Some(ParameterInfo {
                    name: pat_type.pat.to_token_stream().to_string(),
                    ty: pat_type.ty.to_token_stream().to_string(),
                }),
                syn::FnArg::Receiver(_) => None,
            })
            .collect();

        // A missing `-> T` is recorded as `None` rather than `"()"`.
        let return_type = match &method.sig.output {
            syn::ReturnType::Default => None,
            syn::ReturnType::Type(_, ty) => Some(ty.to_token_stream().to_string()),
        };

        struct_info.methods.push(MethodInfo {
            name: method.sig.ident.to_string(),
            doc_comment,
            file_path: file_path.to_string(),
            line_number: self.get_line_number(method.sig.ident.span()),
            visibility,
            is_async: method.sig.asyncness.is_some(),
            is_const: method.sig.constness.is_some(),
            is_unsafe: method.sig.unsafety.is_some(),
            generics,
            parameters,
            return_type,
            receiver,
        });
    }
}
/// Maps a `syn` visibility onto the crate's own `Visibility` enum.
///
/// * `pub` becomes `Visibility::Public`.
/// * `pub(crate)` and `pub(super)` become `Visibility::Crate` and
///   `Visibility::Super`.
/// * Any other restriction (e.g. `pub(in crate::a::b)`) becomes
///   `Visibility::Restricted` holding the path written without spaces.
/// * No modifier becomes `Visibility::Private`.
fn parse_visibility(&self, vis: &SynVisibility) -> Visibility {
    match vis {
        SynVisibility::Public(_) => Visibility::Public,
        SynVisibility::Restricted(restricted) => {
            // Token streams render `a::b` as `a :: b`; a path contains no
            // meaningful whitespace, so drop it to get the path as written.
            let path: String = restricted
                .path
                .to_token_stream()
                .to_string()
                .split_whitespace()
                .collect();
            if path == "crate" {
                Visibility::Crate
            } else if path == "super" {
                Visibility::Super
            } else {
                Visibility::Restricted(path)
            }
        }
        SynVisibility::Inherited => Visibility::Private,
    }
}
/// Collects the text of all `#[doc = "..."]` attributes (which is how `///`
/// comments appear in the syntax tree) into a single string.
///
/// Each attribute's text is trimmed and the pieces are joined with `\n`.
/// Returns `None` when there is no doc attribute with a string literal value.
fn extract_doc_comment(&self, attrs: &[Attribute]) -> Option<String> {
    let mut lines: Vec<String> = Vec::new();
    for attr in attrs {
        if !attr.path().is_ident("doc") {
            continue;
        }
        // Only the `#[doc = "literal"]` form carries text; other forms such
        // as `#[doc(hidden)]` are ignored.
        let syn::Meta::NameValue(name_value) = &attr.meta else {
            continue;
        };
        let syn::Expr::Lit(literal) = &name_value.value else {
            continue;
        };
        let syn::Lit::Str(text) = &literal.lit else {
            continue;
        };
        lines.push(text.value().trim().to_string());
    }

    match lines.is_empty() {
        true => None,
        false => Some(lines.join("\n")),
    }
}
/// Splits an item's attributes into derived trait names and everything else.
///
/// Returns `(derives, other_attrs)`:
/// * `derives` - one entry per trait listed in any `#[derive(...)]`, written
///   without spaces (e.g. `Debug`, `serde::Serialize`).
/// * `other_attrs` - the token text of every remaining attribute.
///
/// `#[doc]` attributes are excluded from both lists; they are handled by
/// `extract_doc_comment`.
fn extract_attributes(&self, attrs: &[Attribute]) -> (Vec<String>, Vec<String>) {
    let mut derives = Vec::new();
    let mut other_attrs = Vec::new();

    for attr in attrs {
        let path = attr.path();
        if path.is_ident("doc") {
            continue;
        }
        if !path.is_ident("derive") {
            other_attrs.push(attr.to_token_stream().to_string());
            continue;
        }

        // A `derive` that is not written as a list is silently ignored.
        let Ok(list) = attr.meta.require_list() else {
            continue;
        };
        for entry in list.tokens.to_string().split(',') {
            // Token streams print paths as `serde :: Serialize`; derive
            // entries are plain paths, so all whitespace can be removed.
            let name: String = entry.split_whitespace().collect();
            if !name.is_empty() {
                derives.push(name);
            }
        }
    }

    (derives, other_attrs)
}
/// Lists the names of an item's generic parameters in declaration order.
///
/// Type parameters yield their identifier, lifetimes yield the lifetime
/// (e.g. `'a`), and const parameters yield `const NAME`. Bounds and defaults
/// are not included.
fn parse_generics(&self, generics: &Generics) -> Vec<String> {
    let mut names = Vec::with_capacity(generics.params.len());
    for param in &generics.params {
        let name = match param {
            GenericParam::Lifetime(def) => def.lifetime.to_string(),
            GenericParam::Type(def) => def.ident.to_string(),
            GenericParam::Const(def) => format!("const {}", def.ident),
        };
        names.push(name);
    }
    names
}
fn parse_fields(&self, fields: &Fields) -> Vec<FieldInfo> {
match fields {
Fields::Named(named) => named
.named
.iter()
.map(|f| {
let doc_comment = self.extract_doc_comment(&f.attrs);
let (_, attributes) = self.extract_attributes(&f.attrs);
FieldInfo {
name: f.ident.as_ref().map(|i| i.to_string()),
ty: f.ty.to_token_stream().to_string(),
doc_comment,
visibility: self.parse_visibility(&f.vis),
attributes,
}
})
.collect(),
Fields::Unnamed(unnamed) => unnamed
.unnamed
.iter()
.enumerate()
.map(|(idx, f)| {
let doc_comment = self.extract_doc_comment(&f.attrs);
let (_, attributes) = self.extract_attributes(&f.attrs);
FieldInfo {
name: Some(format!("{}", idx)),
ty: f.ty.to_token_stream().to_string(),
doc_comment,
visibility: self.parse_visibility(&f.vis),
attributes,
}
})
.collect(),
Fields::Unit => Vec::new(),
}
}
/// Returns the line on which `span` starts.
///
/// NOTE(review): `proc_macro2` only reports meaningful positions outside of a
/// procedural macro when built with its `span-locations` feature — confirm
/// that the feature is enabled in `Cargo.toml`.
fn get_line_number(&self, span: Span) -> usize {
    let start = span.start();
    start.line
}
}
#[cfg(test)]
mod tests {
use super::*;
/// A documented public struct is parsed with its name, doc comment,
/// visibility and both fields in declaration order.
#[test]
fn test_parse_simple_struct() {
    let source = r#"
/// A simple test struct.
pub struct TestStruct {
/// The name field.
pub name: String,
/// The age field.
age: u32,
}
"#;
    let codebase = RustParser::new()
        .parse_source(source, "test.rs".to_string())
        .unwrap();
    assert_eq!(codebase.structs.len(), 1);

    let parsed = &codebase.structs[0];
    assert_eq!(parsed.name, "TestStruct");

    let doc = parsed.doc_comment.as_ref().unwrap();
    assert!(doc.contains("simple test struct"));
    assert_eq!(parsed.visibility, Visibility::Public);

    // Comparing the whole list checks both the field count and the order.
    let field_names: Vec<Option<String>> =
        parsed.fields.iter().map(|field| field.name.clone()).collect();
    assert_eq!(
        field_names,
        vec![Some("name".to_string()), Some("age".to_string())]
    );
}
/// An enum with unit, tuple and struct variants is parsed together with its
/// derive list.
#[test]
fn test_parse_enum() {
    let source = r#"
/// Status enum.
#[derive(Debug, Clone)]
pub enum Status {
/// Active status.
Active,
/// Inactive with reason.
Inactive(String),
/// Custom status.
Custom { code: u32, message: String },
}
"#;
    let codebase = RustParser::new()
        .parse_source(source, "test.rs".to_string())
        .unwrap();
    assert_eq!(codebase.enums.len(), 1);

    let status = &codebase.enums[0];
    assert_eq!(status.name, "Status");
    for expected in ["Debug", "Clone"] {
        assert!(status.derives.contains(&expected.to_string()));
    }

    // Comparing the whole list checks both the variant count and the order.
    let variant_names: Vec<&str> = status
        .variants
        .iter()
        .map(|variant| variant.name.as_str())
        .collect();
    assert_eq!(variant_names, ["Active", "Inactive", "Custom"]);
}
/// Methods from an inherent `impl` block are attached to their struct, with
/// the receiver kind, parameters and return type recorded.
#[test]
fn test_parse_methods() {
    let source = r#"
pub struct Calculator {
value: i32,
}
impl Calculator {
/// Creates a new calculator.
pub fn new() -> Self {
Self { value: 0 }
}
/// Adds a value.
pub fn add(&mut self, n: i32) {
self.value += n;
}
/// Gets the current value.
pub fn value(&self) -> i32 {
self.value
}
}
"#;
    let codebase = RustParser::new()
        .parse_source(source, "test.rs".to_string())
        .unwrap();
    assert_eq!(codebase.structs.len(), 1);

    let calculator = &codebase.structs[0];
    assert_eq!(calculator.methods.len(), 3);

    // Looks a method up by name, failing the test if it is missing.
    let method = |name: &str| {
        calculator
            .methods
            .iter()
            .find(|candidate| candidate.name == name)
            .unwrap()
    };

    let constructor = method("new");
    assert!(constructor.receiver.is_none());
    assert!(constructor.return_type.is_some());

    let adder = method("add");
    assert!(matches!(adder.receiver, Some(Receiver::RefMut)));
    assert_eq!(adder.parameters.len(), 1);

    let getter = method("value");
    assert!(matches!(getter.receiver, Some(Receiver::Ref)));
}
/// Type parameters of a generic struct are listed in declaration order.
#[test]
fn test_parse_generics() {
    let source = r#"
pub struct Container<T, U> {
first: T,
second: U,
}
"#;
    let codebase = RustParser::new()
        .parse_source(source, "test.rs".to_string())
        .unwrap();
    assert_eq!(codebase.structs.len(), 1);

    let container = &codebase.structs[0];
    assert_eq!(container.generics, ["T", "U"]);
}
/// With private items excluded, only the `pub` struct is reported.
#[test]
fn test_exclude_private() {
    let source = r#"
pub struct Public {}
struct Private {}
"#;
    let public_only = RustParser::new().include_private(false);
    let codebase = public_only
        .parse_source(source, "test.rs".to_string())
        .unwrap();

    // Comparing the whole list checks both the count and the surviving name.
    let names: Vec<&str> = codebase
        .structs
        .iter()
        .map(|parsed| parsed.name.as_str())
        .collect();
    assert_eq!(names, ["Public"]);
}
}