forge_codegen/
parser.rs

//! Rust source code parser for extracting FORGE schema definitions.
//!
//! This module parses Rust source files with `syn` to extract model structs,
//! serde DTO structs, enums, and annotated function definitions without
//! requiring compilation.
5
6use std::path::Path;
7
8use forge_core::schema::{
9    EnumDef, EnumVariant, FieldDef, FunctionArg, FunctionDef, FunctionKind, RustType,
10    SchemaRegistry, TableDef,
11};
12use quote::ToTokens;
13use syn::{Attribute, Expr, Fields, FnArg, Lit, Meta, Pat, ReturnType};
14use walkdir::WalkDir;
15
16use crate::Error;
17
18/// Parse all Rust source files in a directory and extract schema definitions.
19pub fn parse_project(src_dir: &Path) -> Result<SchemaRegistry, Error> {
20    let registry = SchemaRegistry::new();
21
22    for entry in WalkDir::new(src_dir)
23        .into_iter()
24        .filter_map(|e| e.ok())
25        .filter(|e| e.path().extension().map(|ext| ext == "rs").unwrap_or(false))
26    {
27        let content = std::fs::read_to_string(entry.path())?;
28        if let Err(e) = parse_file(&content, &registry) {
29            tracing::debug!(file = ?entry.path(), error = %e, "Failed to parse file");
30        }
31    }
32
33    Ok(registry)
34}
35
36/// Parse a single Rust source file and extract schema definitions.
37fn parse_file(content: &str, registry: &SchemaRegistry) -> Result<(), Error> {
38    let file = syn::parse_file(content).map_err(|e| Error::Template(e.to_string()))?;
39
40    for item in file.items {
41        match item {
42            syn::Item::Struct(item_struct) => {
43                if has_forge_model_attr(&item_struct.attrs) {
44                    if let Some(table) = parse_model(&item_struct) {
45                        registry.register_table(table);
46                    }
47                } else if has_serde_derive(&item_struct.attrs) {
48                    // Parse DTO structs (those with Serialize/Deserialize)
49                    if let Some(table) = parse_dto_struct(&item_struct) {
50                        registry.register_table(table);
51                    }
52                }
53            }
54            syn::Item::Enum(item_enum) => {
55                if has_forge_enum_attr(&item_enum.attrs) {
56                    if let Some(enum_def) = parse_enum(&item_enum) {
57                        registry.register_enum(enum_def);
58                    }
59                } else if has_serde_derive(&item_enum.attrs) {
60                    // Parse enums with Serialize/Deserialize
61                    if let Some(enum_def) = parse_enum(&item_enum) {
62                        registry.register_enum(enum_def);
63                    }
64                }
65            }
66            syn::Item::Fn(item_fn) => {
67                if let Some(func) = parse_function(&item_fn) {
68                    registry.register_function(func);
69                }
70            }
71            _ => {}
72        }
73    }
74
75    Ok(())
76}
77
78/// Check if attributes contain #[forge::model] or #[model].
79fn has_forge_model_attr(attrs: &[Attribute]) -> bool {
80    attrs.iter().any(|attr| {
81        let path = attr.path();
82        path.is_ident("model")
83            || path.segments.len() == 2
84                && path.segments[0].ident == "forge"
85                && path.segments[1].ident == "model"
86    })
87}
88
89/// Check if attributes contain #[forge_enum] or #[forge::enum_type].
90fn has_forge_enum_attr(attrs: &[Attribute]) -> bool {
91    attrs.iter().any(|attr| {
92        let path = attr.path();
93        path.is_ident("forge_enum")
94            || path.is_ident("enum_type")
95            || path.segments.len() == 2
96                && path.segments[0].ident == "forge"
97                && path.segments[1].ident == "enum_type"
98    })
99}
100
101/// Check if attributes contain #[derive(...Serialize...)] or #[derive(...Deserialize...)].
102fn has_serde_derive(attrs: &[Attribute]) -> bool {
103    attrs.iter().any(|attr| {
104        if !attr.path().is_ident("derive") {
105            return false;
106        }
107        let tokens = attr.meta.to_token_stream().to_string();
108        tokens.contains("Serialize") || tokens.contains("Deserialize")
109    })
110}
111
112/// Parse a DTO struct (with Serialize/Deserialize) into a TableDef.
113fn parse_dto_struct(item: &syn::ItemStruct) -> Option<TableDef> {
114    let struct_name = item.ident.to_string();
115
116    // Use struct name as table name (DTOs don't have SQL tables)
117    let mut table = TableDef::new(&struct_name, &struct_name);
118
119    // Mark as DTO (not a database table)
120    table.is_dto = true;
121
122    // Extract documentation
123    table.doc = get_doc_comment(&item.attrs);
124
125    // Extract fields
126    if let Fields::Named(fields) = &item.fields {
127        for field in &fields.named {
128            if let Some(field_name) = &field.ident {
129                let field_def = parse_field(field_name.to_string(), &field.ty, &field.attrs);
130                table.fields.push(field_def);
131            }
132        }
133    }
134
135    Some(table)
136}
137
138/// Parse a struct with #[model] attribute into a TableDef.
139fn parse_model(item: &syn::ItemStruct) -> Option<TableDef> {
140    let struct_name = item.ident.to_string();
141    let table_name = get_table_name_from_attrs(&item.attrs).unwrap_or_else(|| {
142        let snake = to_snake_case(&struct_name);
143        pluralize(&snake)
144    });
145
146    let mut table = TableDef::new(&table_name, &struct_name);
147
148    // Extract documentation
149    table.doc = get_doc_comment(&item.attrs);
150
151    // Extract fields
152    if let Fields::Named(fields) = &item.fields {
153        for field in &fields.named {
154            if let Some(field_name) = &field.ident {
155                let field_def = parse_field(field_name.to_string(), &field.ty, &field.attrs);
156                table.fields.push(field_def);
157            }
158        }
159    }
160
161    Some(table)
162}
163
164/// Parse a field definition.
165fn parse_field(name: String, ty: &syn::Type, attrs: &[Attribute]) -> FieldDef {
166    let rust_type = type_to_rust_type(ty);
167    let mut field = FieldDef::new(&name, rust_type);
168    field.column_name = to_snake_case(&name);
169    field.doc = get_doc_comment(attrs);
170    field
171}
172
173/// Parse an enum with #[forge_enum] attribute into an EnumDef.
174fn parse_enum(item: &syn::ItemEnum) -> Option<EnumDef> {
175    let enum_name = item.ident.to_string();
176    let mut enum_def = EnumDef::new(&enum_name);
177    enum_def.doc = get_doc_comment(&item.attrs);
178
179    for variant in &item.variants {
180        let variant_name = variant.ident.to_string();
181        let mut enum_variant = EnumVariant::new(&variant_name);
182        enum_variant.doc = get_doc_comment(&variant.attrs);
183
184        // Check for explicit value
185        if let Some((_, Expr::Lit(lit))) = &variant.discriminant {
186            if let Lit::Int(int_lit) = &lit.lit {
187                if let Ok(value) = int_lit.base10_parse::<i32>() {
188                    enum_variant.int_value = Some(value);
189                }
190            }
191        }
192
193        enum_def.variants.push(enum_variant);
194    }
195
196    Some(enum_def)
197}
198
199/// Parse a function with #[query], #[mutation], or #[action] attribute.
200fn parse_function(item: &syn::ItemFn) -> Option<FunctionDef> {
201    let kind = get_function_kind(&item.attrs)?;
202    let func_name = item.sig.ident.to_string();
203
204    // Get return type
205    let return_type = match &item.sig.output {
206        ReturnType::Default => RustType::Custom("()".to_string()),
207        ReturnType::Type(_, ty) => extract_result_type(ty),
208    };
209
210    let mut func = FunctionDef::new(&func_name, kind, return_type);
211    func.doc = get_doc_comment(&item.attrs);
212    func.is_async = item.sig.asyncness.is_some();
213
214    // Parse arguments (skip first arg which is usually context)
215    let mut skip_first = true;
216    for arg in &item.sig.inputs {
217        if let FnArg::Typed(pat_type) = arg {
218            // Skip context argument (usually first)
219            if skip_first {
220                skip_first = false;
221                // Check if it's a context type
222                let type_str = quote::quote!(#pat_type.ty).to_string();
223                if type_str.contains("Context")
224                    || type_str.contains("QueryContext")
225                    || type_str.contains("MutationContext")
226                    || type_str.contains("ActionContext")
227                {
228                    continue;
229                }
230            }
231
232            // Extract argument name
233            if let Pat::Ident(pat_ident) = &*pat_type.pat {
234                let arg_name = pat_ident.ident.to_string();
235                let arg_type = type_to_rust_type(&pat_type.ty);
236                func.args.push(FunctionArg::new(arg_name, arg_type));
237            }
238        }
239    }
240
241    Some(func)
242}
243
244/// Get the function kind from attributes.
245fn get_function_kind(attrs: &[Attribute]) -> Option<FunctionKind> {
246    for attr in attrs {
247        let path = attr.path();
248        let segments: Vec<_> = path.segments.iter().map(|s| s.ident.to_string()).collect();
249
250        // Check for #[forge::X] or #[X] patterns
251        let kind_str = if segments.len() == 2 && segments[0] == "forge" {
252            Some(segments[1].as_str())
253        } else if segments.len() == 1 {
254            Some(segments[0].as_str())
255        } else {
256            None
257        };
258
259        if let Some(kind) = kind_str {
260            match kind {
261                "query" => return Some(FunctionKind::Query),
262                "mutation" => return Some(FunctionKind::Mutation),
263                "action" => return Some(FunctionKind::Action),
264                "job" => return Some(FunctionKind::Job),
265                "cron" => return Some(FunctionKind::Cron),
266                "workflow" => return Some(FunctionKind::Workflow),
267                _ => {}
268            }
269        }
270    }
271    None
272}
273
274/// Extract the inner type from Result<T, E>.
275fn extract_result_type(ty: &syn::Type) -> RustType {
276    let type_str = quote::quote!(#ty).to_string().replace(' ', "");
277
278    // Check for Result<T, _>
279    if let Some(rest) = type_str.strip_prefix("Result<") {
280        // Find the inner type before the comma or angle bracket
281        let mut depth = 0;
282        let mut end_idx = 0;
283        for (i, c) in rest.chars().enumerate() {
284            match c {
285                '<' => depth += 1,
286                '>' => {
287                    if depth == 0 {
288                        end_idx = i;
289                        break;
290                    }
291                    depth -= 1;
292                }
293                ',' if depth == 0 => {
294                    end_idx = i;
295                    break;
296                }
297                _ => {}
298            }
299        }
300        let inner = &rest[..end_idx];
301        return type_to_rust_type(
302            &syn::parse_str(inner)
303                .unwrap_or_else(|_| syn::parse_str::<syn::Type>("String").unwrap()),
304        );
305    }
306
307    type_to_rust_type(ty)
308}
309
310/// Convert a syn::Type to RustType.
311fn type_to_rust_type(ty: &syn::Type) -> RustType {
312    let type_str = quote::quote!(#ty).to_string().replace(' ', "");
313
314    // Handle common types
315    match type_str.as_str() {
316        "String" | "&str" => RustType::String,
317        "i32" => RustType::I32,
318        "i64" => RustType::I64,
319        "f32" => RustType::F32,
320        "f64" => RustType::F64,
321        "bool" => RustType::Bool,
322        "Uuid" | "uuid::Uuid" => RustType::Uuid,
323        "DateTime<Utc>" | "chrono::DateTime<Utc>" | "chrono::DateTime<chrono::Utc>" => {
324            RustType::DateTime
325        }
326        "NaiveDate" | "chrono::NaiveDate" => RustType::Date,
327        "NaiveTime" | "chrono::NaiveTime" => RustType::Custom("NaiveTime".to_string()),
328        "serde_json::Value" | "Value" => RustType::Json,
329        "Vec<u8>" => RustType::Bytes,
330        _ => {
331            // Handle Option<T>
332            if let Some(inner) = type_str
333                .strip_prefix("Option<")
334                .and_then(|s| s.strip_suffix('>'))
335            {
336                let inner_type = match inner {
337                    "String" => RustType::String,
338                    "i32" => RustType::I32,
339                    "i64" => RustType::I64,
340                    "f64" => RustType::F64,
341                    "bool" => RustType::Bool,
342                    "Uuid" => RustType::Uuid,
343                    _ => RustType::String, // Default fallback
344                };
345                return RustType::Option(Box::new(inner_type));
346            }
347
348            // Handle Vec<T>
349            if let Some(inner) = type_str
350                .strip_prefix("Vec<")
351                .and_then(|s| s.strip_suffix('>'))
352            {
353                let inner_type = match inner {
354                    "String" => RustType::String,
355                    "i32" => RustType::I32,
356                    "u8" => return RustType::Bytes,
357                    _ => RustType::String,
358                };
359                return RustType::Vec(Box::new(inner_type));
360            }
361
362            // Default to custom type
363            RustType::Custom(type_str)
364        }
365    }
366}
367
368/// Get #[table(name = "...")] value from attributes.
369fn get_table_name_from_attrs(attrs: &[Attribute]) -> Option<String> {
370    for attr in attrs {
371        if attr.path().is_ident("table") {
372            if let Meta::List(list) = &attr.meta {
373                let tokens = list.tokens.to_string();
374                if let Some(value) = extract_name_value(&tokens) {
375                    return Some(value);
376                }
377            }
378        }
379    }
380    None
381}
382
383/// Get string value from attribute like #[attr = "value"].
384fn get_attribute_string_value(attr: &Attribute) -> Option<String> {
385    if let Meta::NameValue(nv) = &attr.meta {
386        if let Expr::Lit(lit) = &nv.value {
387            if let Lit::Str(s) = &lit.lit {
388                return Some(s.value());
389            }
390        }
391    }
392    None
393}
394
395/// Get documentation comment from attributes.
396fn get_doc_comment(attrs: &[Attribute]) -> Option<String> {
397    let docs: Vec<String> = attrs
398        .iter()
399        .filter_map(|attr| {
400            if attr.path().is_ident("doc") {
401                get_attribute_string_value(attr)
402            } else {
403                None
404            }
405        })
406        .collect();
407
408    if docs.is_empty() {
409        None
410    } else {
411        Some(
412            docs.into_iter()
413                .map(|s| s.trim().to_string())
414                .collect::<Vec<_>>()
415                .join("\n"),
416        )
417    }
418}
419
/// Extract the value of the `name` key from stringified attribute arguments
/// such as `name = "users"` or `schema = "app" , name = "users"`.
///
/// Entries are separated by commas outside string literals. The first entry
/// whose key is exactly `name` and whose value is a double-quoted string wins.
/// Its contents are returned verbatim, without decoding escape sequences.
/// Returns `None` when no such entry exists.
fn extract_name_value(s: &str) -> Option<String> {
    split_outside_quotes(s, ',').into_iter().find_map(|entry| {
        // Split at the first `=`: the key never contains one, the value may.
        let (key, value) = entry.split_once('=')?;
        if key.trim() != "name" {
            return None;
        }
        value
            .trim()
            .strip_prefix('"')
            .and_then(|rest| rest.strip_suffix('"'))
            .map(str::to_string)
    })
}

/// Split `s` on `separator`, ignoring separators inside double-quoted string
/// literals (backslash escapes inside a literal are honoured).
fn split_outside_quotes(s: &str, separator: char) -> Vec<&str> {
    let mut pieces = Vec::new();
    let mut start = 0;
    let mut in_string = false;
    let mut escaped = false;

    for (idx, c) in s.char_indices() {
        if in_string {
            if escaped {
                escaped = false;
            } else if c == '\\' {
                escaped = true;
            } else if c == '"' {
                in_string = false;
            }
        } else if c == '"' {
            in_string = true;
        } else if c == separator {
            pieces.push(&s[start..idx]);
            start = idx + c.len_utf8();
        }
    }
    pieces.push(&s[start..]);
    pieces
}
431
/// Convert an identifier to snake_case.
///
/// Every uppercase character after the first position is preceded by `_`, and
/// every uppercase character is lowercased. Consecutive capitals are split
/// individually, so `"ID"` becomes `"i_d"`.
fn to_snake_case(s: &str) -> String {
    let mut result = String::with_capacity(s.len());
    for (i, c) in s.chars().enumerate() {
        if c.is_uppercase() {
            if i > 0 {
                result.push('_');
            }
            // Some characters lowercase to several chars (e.g. 'İ' -> "i\u{307}");
            // keep all of them rather than only the first.
            result.extend(c.to_lowercase());
        } else {
            result.push(c);
        }
    }
    result
}
447
/// Apply simple English pluralization rules to a lowercase noun.
///
/// - words ending in `s`, `sh`, `ch`, `x` or `z` gain `es`;
/// - words ending in a `y` not preceded by `a`, `e`, `o` or `u` swap the `y`
///   for `ies`;
/// - everything else gains `s`.
fn pluralize(s: &str) -> String {
    const ES_SUFFIXES: [&str; 5] = ["s", "sh", "ch", "x", "z"];
    const VOWEL_Y_SUFFIXES: [&str; 4] = ["ay", "ey", "oy", "uy"];

    if ES_SUFFIXES.iter().any(|suffix| s.ends_with(suffix)) {
        return format!("{}es", s);
    }

    match s.strip_suffix('y') {
        Some(stem) if !VOWEL_Y_SUFFIXES.iter().any(|suffix| s.ends_with(suffix)) => {
            format!("{}ies", stem)
        }
        _ => format!("{}s", s),
    }
}
467
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse `source` into a fresh registry; panics if it is not valid Rust.
    fn parse_source(source: &str) -> SchemaRegistry {
        let registry = SchemaRegistry::new();
        parse_file(source, &registry).unwrap();
        registry
    }

    #[test]
    fn test_parse_model_source() {
        let source = r#"
            #[model]
            struct User {
                #[id]
                id: Uuid,
                email: String,
                name: Option<String>,
                #[indexed]
                created_at: DateTime<Utc>,
            }
        "#;

        let registry = parse_source(source);
        let users = registry.get_table("users").unwrap();
        assert_eq!(users.struct_name, "User");
        assert_eq!(users.fields.len(), 4);
    }

    #[test]
    fn test_parse_enum_source() {
        let source = r#"
            #[forge_enum]
            enum ProjectStatus {
                Draft,
                Active,
                Completed,
            }
        "#;

        let registry = parse_source(source);
        let status = registry.get_enum("ProjectStatus").unwrap();
        assert_eq!(status.variants.len(), 3);
    }

    #[test]
    fn test_to_snake_case() {
        let cases = [
            ("UserProfile", "user_profile"),
            ("ID", "i_d"),
            ("createdAt", "created_at"),
        ];
        for &(input, expected) in cases.iter() {
            assert_eq!(to_snake_case(input), expected);
        }
    }

    #[test]
    fn test_pluralize() {
        let cases = [
            ("user", "users"),
            ("category", "categories"),
            ("box", "boxes"),
            ("address", "addresses"),
        ];
        for &(singular, plural) in cases.iter() {
            assert_eq!(pluralize(singular), plural);
        }
    }

    #[test]
    fn test_parse_query_function() {
        let source = r#"
            #[query]
            async fn get_user(ctx: QueryContext, id: Uuid) -> Result<User> {
                todo!()
            }
        "#;

        let registry = parse_source(source);
        let query = registry.get_function("get_user").unwrap();
        assert_eq!(query.name, "get_user");
        assert_eq!(query.kind, FunctionKind::Query);
        assert!(query.is_async);
    }

    #[test]
    fn test_parse_mutation_function() {
        let source = r#"
            #[mutation]
            async fn create_user(ctx: MutationContext, name: String, email: String) -> Result<User> {
                todo!()
            }
        "#;

        let registry = parse_source(source);
        let mutation = registry.get_function("create_user").unwrap();
        assert_eq!(mutation.name, "create_user");
        assert_eq!(mutation.kind, FunctionKind::Mutation);
        assert_eq!(mutation.args.len(), 2);
        assert_eq!(mutation.args[0].name, "name");
        assert_eq!(mutation.args[1].name, "email");
    }

    #[test]
    fn test_parse_action_function() {
        let source = r#"
            #[action]
            async fn send_notification(ctx: ActionContext, message: String) -> Result<()> {
                todo!()
            }
        "#;

        let registry = parse_source(source);
        let action = registry.get_function("send_notification").unwrap();
        assert_eq!(action.name, "send_notification");
        assert_eq!(action.kind, FunctionKind::Action);
    }
}