herolib_code/parser/
rust_parser.rs

1//! Rust source code parser using the `syn` crate.
2//!
//! This module parses Rust source files and extracts information about
//! enums, structs, and the methods defined in inherent `impl` blocks of
//! those structs.
5
6use std::fs;
7use std::path::Path;
8
9use proc_macro2::Span;
10use quote::ToTokens;
11use syn::{
12    Attribute, Fields, GenericParam, Generics, ImplItem, Item, ItemEnum, ItemImpl, ItemStruct,
13    Type, Visibility as SynVisibility,
14};
15
16use super::error::{ParseError, ParseResult};
17use super::types::{
18    CodeBase, EnumInfo, EnumVariant, FieldInfo, FileInfo, MethodInfo, ParameterInfo, Receiver,
19    StructInfo, Visibility,
20};
21
/// Parses Rust source files and extracts code structure information.
///
/// Create one with [`RustParser::new`] and optionally adjust it with the
/// builder-style [`RustParser::include_private`] before parsing.
#[derive(Debug, Clone)]
pub struct RustParser {
    /// When `false`, structs, enums and methods without an explicit
    /// visibility modifier are left out of the parsed output.
    include_private: bool,
}
27
28impl Default for RustParser {
29    fn default() -> Self {
30        Self::new()
31    }
32}
33
34impl RustParser {
35    /// Creates a new RustParser with default settings.
36    pub fn new() -> Self {
37        Self {
38            include_private: true,
39        }
40    }
41
42    /// Sets whether to include private items in the output.
43    pub fn include_private(mut self, include: bool) -> Self {
44        self.include_private = include;
45        self
46    }
47
48    /// Parses a single Rust source file.
49    ///
50    /// # Arguments
51    ///
52    /// * `path` - Path to the Rust source file.
53    ///
54    /// # Returns
55    ///
56    /// A CodeBase containing all parsed elements from the file.
57    pub fn parse_file<P: AsRef<Path>>(&self, path: P) -> ParseResult<CodeBase> {
58        let path = path.as_ref();
59        let content = fs::read_to_string(path).map_err(|e| ParseError::FileRead {
60            path: path.to_path_buf(),
61            source: e,
62        })?;
63
64        self.parse_source(&content, path.to_string_lossy().to_string())
65    }
66
67    /// Parses Rust source code from a string.
68    ///
69    /// # Arguments
70    ///
71    /// * `source` - The Rust source code as a string.
72    /// * `file_path` - The path to use for error messages and file references.
73    ///
74    /// # Returns
75    ///
76    /// A CodeBase containing all parsed elements.
77    pub fn parse_source(&self, source: &str, file_path: String) -> ParseResult<CodeBase> {
78        let syntax = syn::parse_file(source).map_err(|e| ParseError::SyntaxError {
79            path: file_path.clone().into(),
80            message: e.to_string(),
81        })?;
82
83        let mut codebase = CodeBase::new();
84        let mut structs: Vec<StructInfo> = Vec::new();
85        let mut enums: Vec<EnumInfo> = Vec::new();
86
87        // First pass: collect all structs and enums
88        for item in &syntax.items {
89            match item {
90                Item::Struct(item_struct) => {
91                    if let Some(struct_info) = self.parse_struct(item_struct, &file_path) {
92                        structs.push(struct_info);
93                    }
94                }
95                Item::Enum(item_enum) => {
96                    if let Some(enum_info) = self.parse_enum(item_enum, &file_path) {
97                        enums.push(enum_info);
98                    }
99                }
100                _ => {}
101            }
102        }
103
104        // Second pass: collect impl blocks and attach methods to structs
105        for item in &syntax.items {
106            if let Item::Impl(item_impl) = item {
107                self.process_impl_block(item_impl, &file_path, &mut structs);
108            }
109        }
110
111        // Update file info
112        codebase.files.push(FileInfo {
113            path: file_path,
114            enum_count: enums.len(),
115            struct_count: structs.len(),
116        });
117
118        codebase.enums = enums;
119        codebase.structs = structs;
120
121        Ok(codebase)
122    }
123
124    /// Parses a struct item.
125    fn parse_struct(&self, item: &ItemStruct, file_path: &str) -> Option<StructInfo> {
126        let visibility = self.parse_visibility(&item.vis);
127
128        if !self.include_private && visibility == Visibility::Private {
129            return None;
130        }
131
132        let doc_comment = self.extract_doc_comment(&item.attrs);
133        let (derives, attributes) = self.extract_attributes(&item.attrs);
134        let generics = self.parse_generics(&item.generics);
135        let fields = self.parse_fields(&item.fields);
136
137        Some(StructInfo {
138            name: item.ident.to_string(),
139            doc_comment,
140            file_path: file_path.to_string(),
141            line_number: self.get_line_number(item.ident.span()),
142            visibility,
143            generics,
144            derives,
145            attributes,
146            fields,
147            methods: Vec::new(),
148        })
149    }
150
151    /// Parses an enum item.
152    fn parse_enum(&self, item: &ItemEnum, file_path: &str) -> Option<EnumInfo> {
153        let visibility = self.parse_visibility(&item.vis);
154
155        if !self.include_private && visibility == Visibility::Private {
156            return None;
157        }
158
159        let doc_comment = self.extract_doc_comment(&item.attrs);
160        let (derives, attributes) = self.extract_attributes(&item.attrs);
161        let generics = self.parse_generics(&item.generics);
162
163        let variants: Vec<EnumVariant> = item
164            .variants
165            .iter()
166            .map(|v| {
167                let variant_doc = self.extract_doc_comment(&v.attrs);
168                let fields = self.parse_fields(&v.fields);
169                let discriminant = v
170                    .discriminant
171                    .as_ref()
172                    .map(|(_, expr)| expr.to_token_stream().to_string());
173
174                EnumVariant {
175                    name: v.ident.to_string(),
176                    doc_comment: variant_doc,
177                    fields,
178                    discriminant,
179                }
180            })
181            .collect();
182
183        Some(EnumInfo {
184            name: item.ident.to_string(),
185            doc_comment,
186            file_path: file_path.to_string(),
187            line_number: self.get_line_number(item.ident.span()),
188            visibility,
189            generics,
190            derives,
191            attributes,
192            variants,
193        })
194    }
195
196    /// Processes an impl block and attaches methods to the corresponding struct.
197    fn process_impl_block(
198        &self,
199        item_impl: &ItemImpl,
200        file_path: &str,
201        structs: &mut [StructInfo],
202    ) {
203        // Only process inherent impls (not trait impls)
204        if item_impl.trait_.is_some() {
205            return;
206        }
207
208        // Get the type name being implemented
209        let type_name = match &*item_impl.self_ty {
210            Type::Path(type_path) => type_path
211                .path
212                .segments
213                .last()
214                .map(|seg| seg.ident.to_string()),
215            _ => None,
216        };
217
218        let Some(type_name) = type_name else {
219            return;
220        };
221
222        // Find the corresponding struct
223        let Some(struct_info) = structs.iter_mut().find(|s| s.name == type_name) else {
224            return;
225        };
226
227        // Parse methods from the impl block
228        for item in &item_impl.items {
229            if let ImplItem::Fn(method) = item {
230                let visibility = self.parse_visibility(&method.vis);
231
232                if !self.include_private && visibility == Visibility::Private {
233                    continue;
234                }
235
236                let doc_comment = self.extract_doc_comment(&method.attrs);
237                let generics = self.parse_generics(&method.sig.generics);
238
239                // Parse receiver (self parameter)
240                let receiver = method.sig.receiver().map(|recv| {
241                    if recv.reference.is_some() {
242                        if recv.mutability.is_some() {
243                            Receiver::RefMut
244                        } else {
245                            Receiver::Ref
246                        }
247                    } else {
248                        Receiver::Value
249                    }
250                });
251
252                // Parse parameters (excluding self)
253                let parameters: Vec<ParameterInfo> = method
254                    .sig
255                    .inputs
256                    .iter()
257                    .filter_map(|arg| {
258                        if let syn::FnArg::Typed(pat_type) = arg {
259                            let name = pat_type.pat.to_token_stream().to_string();
260                            let ty = pat_type.ty.to_token_stream().to_string();
261                            Some(ParameterInfo { name, ty })
262                        } else {
263                            None
264                        }
265                    })
266                    .collect();
267
268                // Parse return type
269                let return_type = match &method.sig.output {
270                    syn::ReturnType::Default => None,
271                    syn::ReturnType::Type(_, ty) => Some(ty.to_token_stream().to_string()),
272                };
273
274                let method_info = MethodInfo {
275                    name: method.sig.ident.to_string(),
276                    doc_comment,
277                    file_path: file_path.to_string(),
278                    line_number: self.get_line_number(method.sig.ident.span()),
279                    visibility,
280                    is_async: method.sig.asyncness.is_some(),
281                    is_const: method.sig.constness.is_some(),
282                    is_unsafe: method.sig.unsafety.is_some(),
283                    generics,
284                    parameters,
285                    return_type,
286                    receiver,
287                };
288
289                struct_info.methods.push(method_info);
290            }
291        }
292    }
293
294    /// Parses visibility from syn's Visibility type.
295    fn parse_visibility(&self, vis: &SynVisibility) -> Visibility {
296        match vis {
297            SynVisibility::Public(_) => Visibility::Public,
298            SynVisibility::Restricted(restricted) => {
299                let path = restricted.path.to_token_stream().to_string();
300                if path == "crate" {
301                    Visibility::Crate
302                } else if path == "super" {
303                    Visibility::Super
304                } else {
305                    Visibility::Restricted(path)
306                }
307            }
308            SynVisibility::Inherited => Visibility::Private,
309        }
310    }
311
312    /// Extracts doc comments from attributes.
313    fn extract_doc_comment(&self, attrs: &[Attribute]) -> Option<String> {
314        let doc_lines: Vec<String> = attrs
315            .iter()
316            .filter_map(|attr| {
317                if attr.path().is_ident("doc") {
318                    if let syn::Meta::NameValue(meta) = &attr.meta {
319                        if let syn::Expr::Lit(expr_lit) = &meta.value {
320                            if let syn::Lit::Str(lit_str) = &expr_lit.lit {
321                                return Some(lit_str.value().trim().to_string());
322                            }
323                        }
324                    }
325                }
326                None
327            })
328            .collect();
329
330        if doc_lines.is_empty() {
331            None
332        } else {
333            Some(doc_lines.join("\n"))
334        }
335    }
336
337    /// Extracts derive macros and other attributes.
338    fn extract_attributes(&self, attrs: &[Attribute]) -> (Vec<String>, Vec<String>) {
339        let mut derives = Vec::new();
340        let mut other_attrs = Vec::new();
341
342        for attr in attrs {
343            if attr.path().is_ident("doc") {
344                continue; // Skip doc comments
345            }
346
347            if attr.path().is_ident("derive") {
348                // Extract derive macro names
349                if let Ok(meta) = attr.meta.require_list() {
350                    let tokens = meta.tokens.to_string();
351                    // Parse comma-separated derive names
352                    for derive_name in tokens.split(',') {
353                        let name = derive_name.trim();
354                        if !name.is_empty() {
355                            derives.push(name.to_string());
356                        }
357                    }
358                }
359            } else {
360                // Store other attributes as strings
361                other_attrs.push(attr.to_token_stream().to_string());
362            }
363        }
364
365        (derives, other_attrs)
366    }
367
368    /// Parses generic parameters.
369    fn parse_generics(&self, generics: &Generics) -> Vec<String> {
370        generics
371            .params
372            .iter()
373            .map(|param| match param {
374                GenericParam::Type(type_param) => type_param.ident.to_string(),
375                GenericParam::Lifetime(lifetime) => lifetime.lifetime.to_string(),
376                GenericParam::Const(const_param) => {
377                    format!("const {}", const_param.ident)
378                }
379            })
380            .collect()
381    }
382
383    /// Parses struct/enum fields.
384    fn parse_fields(&self, fields: &Fields) -> Vec<FieldInfo> {
385        match fields {
386            Fields::Named(named) => named
387                .named
388                .iter()
389                .map(|f| {
390                    let doc_comment = self.extract_doc_comment(&f.attrs);
391                    let (_, attributes) = self.extract_attributes(&f.attrs);
392
393                    FieldInfo {
394                        name: f.ident.as_ref().map(|i| i.to_string()),
395                        ty: f.ty.to_token_stream().to_string(),
396                        doc_comment,
397                        visibility: self.parse_visibility(&f.vis),
398                        attributes,
399                    }
400                })
401                .collect(),
402            Fields::Unnamed(unnamed) => unnamed
403                .unnamed
404                .iter()
405                .enumerate()
406                .map(|(idx, f)| {
407                    let doc_comment = self.extract_doc_comment(&f.attrs);
408                    let (_, attributes) = self.extract_attributes(&f.attrs);
409
410                    FieldInfo {
411                        name: Some(format!("{}", idx)),
412                        ty: f.ty.to_token_stream().to_string(),
413                        doc_comment,
414                        visibility: self.parse_visibility(&f.vis),
415                        attributes,
416                    }
417                })
418                .collect(),
419            Fields::Unit => Vec::new(),
420        }
421    }
422
423    /// Gets the line number from a span.
424    fn get_line_number(&self, span: Span) -> usize {
425        span.start().line
426    }
427}
428
429#[cfg(test)]
430mod tests {
431    use super::*;
432
433    #[test]
434    fn test_parse_simple_struct() {
435        let source = r#"
436/// A simple test struct.
437pub struct TestStruct {
438    /// The name field.
439    pub name: String,
440    /// The age field.
441    age: u32,
442}
443"#;
444
445        let parser = RustParser::new();
446        let codebase = parser.parse_source(source, "test.rs".to_string()).unwrap();
447
448        assert_eq!(codebase.structs.len(), 1);
449        let s = &codebase.structs[0];
450        assert_eq!(s.name, "TestStruct");
451        assert!(s
452            .doc_comment
453            .as_ref()
454            .unwrap()
455            .contains("simple test struct"));
456        assert_eq!(s.visibility, Visibility::Public);
457        assert_eq!(s.fields.len(), 2);
458        assert_eq!(s.fields[0].name, Some("name".to_string()));
459        assert_eq!(s.fields[1].name, Some("age".to_string()));
460    }
461
462    #[test]
463    fn test_parse_enum() {
464        let source = r#"
465/// Status enum.
466#[derive(Debug, Clone)]
467pub enum Status {
468    /// Active status.
469    Active,
470    /// Inactive with reason.
471    Inactive(String),
472    /// Custom status.
473    Custom { code: u32, message: String },
474}
475"#;
476
477        let parser = RustParser::new();
478        let codebase = parser.parse_source(source, "test.rs".to_string()).unwrap();
479
480        assert_eq!(codebase.enums.len(), 1);
481        let e = &codebase.enums[0];
482        assert_eq!(e.name, "Status");
483        assert!(e.derives.contains(&"Debug".to_string()));
484        assert!(e.derives.contains(&"Clone".to_string()));
485        assert_eq!(e.variants.len(), 3);
486        assert_eq!(e.variants[0].name, "Active");
487        assert_eq!(e.variants[1].name, "Inactive");
488        assert_eq!(e.variants[2].name, "Custom");
489    }
490
491    #[test]
492    fn test_parse_methods() {
493        let source = r#"
494pub struct Calculator {
495    value: i32,
496}
497
498impl Calculator {
499    /// Creates a new calculator.
500    pub fn new() -> Self {
501        Self { value: 0 }
502    }
503
504    /// Adds a value.
505    pub fn add(&mut self, n: i32) {
506        self.value += n;
507    }
508
509    /// Gets the current value.
510    pub fn value(&self) -> i32 {
511        self.value
512    }
513}
514"#;
515
516        let parser = RustParser::new();
517        let codebase = parser.parse_source(source, "test.rs".to_string()).unwrap();
518
519        assert_eq!(codebase.structs.len(), 1);
520        let s = &codebase.structs[0];
521        assert_eq!(s.methods.len(), 3);
522
523        let new_method = s.methods.iter().find(|m| m.name == "new").unwrap();
524        assert!(new_method.receiver.is_none());
525        assert!(new_method.return_type.is_some());
526
527        let add_method = s.methods.iter().find(|m| m.name == "add").unwrap();
528        assert!(matches!(add_method.receiver, Some(Receiver::RefMut)));
529        assert_eq!(add_method.parameters.len(), 1);
530
531        let value_method = s.methods.iter().find(|m| m.name == "value").unwrap();
532        assert!(matches!(value_method.receiver, Some(Receiver::Ref)));
533    }
534
535    #[test]
536    fn test_parse_generics() {
537        let source = r#"
538pub struct Container<T, U> {
539    first: T,
540    second: U,
541}
542"#;
543
544        let parser = RustParser::new();
545        let codebase = parser.parse_source(source, "test.rs".to_string()).unwrap();
546
547        assert_eq!(codebase.structs.len(), 1);
548        let s = &codebase.structs[0];
549        assert_eq!(s.generics, vec!["T", "U"]);
550    }
551
552    #[test]
553    fn test_exclude_private() {
554        let source = r#"
555pub struct Public {}
556struct Private {}
557"#;
558
559        let parser = RustParser::new().include_private(false);
560        let codebase = parser.parse_source(source, "test.rs".to_string()).unwrap();
561
562        assert_eq!(codebase.structs.len(), 1);
563        assert_eq!(codebase.structs[0].name, "Public");
564    }
565}