forge_codegen/
parser.rs

1//! Rust source code parser for extracting FORGE schema definitions.
2//!
3//! This module parses Rust source files to extract model, enum, and function
4//! definitions without requiring compilation.
5
6use std::path::Path;
7
8use forge_core::schema::{
9    EnumDef, EnumVariant, FieldDef, FunctionArg, FunctionDef, FunctionKind, RustType,
10    SchemaRegistry, TableDef,
11};
12use quote::ToTokens;
13use syn::{Attribute, Expr, Fields, FnArg, Lit, Meta, Pat, ReturnType};
14use walkdir::WalkDir;
15
16use crate::Error;
17
18/// Parse all Rust source files in a directory and extract schema definitions.
19pub fn parse_project(src_dir: &Path) -> Result<SchemaRegistry, Error> {
20    let registry = SchemaRegistry::new();
21
22    let mut files: Vec<_> = WalkDir::new(src_dir)
23        .into_iter()
24        .filter_map(|e| e.ok())
25        .filter(|e| e.path().extension().map(|ext| ext == "rs").unwrap_or(false))
26        .collect();
27    files.sort_by(|a, b| a.path().cmp(b.path()));
28
29    for entry in files {
30        let content = std::fs::read_to_string(entry.path())?;
31        if let Err(e) = parse_file(&content, &registry) {
32            tracing::debug!(file = ?entry.path(), error = %e, "Failed to parse file");
33        }
34    }
35
36    Ok(registry)
37}
38
39/// Parse a single Rust source file and extract schema definitions.
40fn parse_file(content: &str, registry: &SchemaRegistry) -> Result<(), Error> {
41    let file = syn::parse_file(content).map_err(|e| Error::Template(e.to_string()))?;
42
43    for item in file.items {
44        match item {
45            syn::Item::Struct(item_struct) => {
46                if has_forge_model_attr(&item_struct.attrs) {
47                    if let Some(table) = parse_model(&item_struct) {
48                        registry.register_table(table);
49                    }
50                } else if has_serde_derive(&item_struct.attrs) {
51                    // Parse DTO structs (those with Serialize/Deserialize)
52                    if let Some(table) = parse_dto_struct(&item_struct) {
53                        registry.register_table(table);
54                    }
55                }
56            }
57            syn::Item::Enum(item_enum) => {
58                if has_forge_enum_attr(&item_enum.attrs) {
59                    if let Some(enum_def) = parse_enum(&item_enum) {
60                        registry.register_enum(enum_def);
61                    }
62                } else if has_serde_derive(&item_enum.attrs) {
63                    // Parse enums with Serialize/Deserialize
64                    if let Some(enum_def) = parse_enum(&item_enum) {
65                        registry.register_enum(enum_def);
66                    }
67                }
68            }
69            syn::Item::Fn(item_fn) => {
70                if let Some(func) = parse_function(&item_fn) {
71                    registry.register_function(func);
72                }
73            }
74            _ => {}
75        }
76    }
77
78    Ok(())
79}
80
81/// Check if attributes contain #[forge::model] or #[model].
82fn has_forge_model_attr(attrs: &[Attribute]) -> bool {
83    attrs.iter().any(|attr| {
84        let path = attr.path();
85        path.is_ident("model")
86            || path.segments.len() == 2
87                && path.segments[0].ident == "forge"
88                && path.segments[1].ident == "model"
89    })
90}
91
92/// Check if attributes contain #[forge_enum] or #[forge::enum_type].
93fn has_forge_enum_attr(attrs: &[Attribute]) -> bool {
94    attrs.iter().any(|attr| {
95        let path = attr.path();
96        path.is_ident("forge_enum")
97            || path.is_ident("enum_type")
98            || path.segments.len() == 2
99                && path.segments[0].ident == "forge"
100                && path.segments[1].ident == "enum_type"
101    })
102}
103
104/// Check if attributes contain #[derive(...Serialize...)] or #[derive(...Deserialize...)].
105fn has_serde_derive(attrs: &[Attribute]) -> bool {
106    attrs.iter().any(|attr| {
107        if !attr.path().is_ident("derive") {
108            return false;
109        }
110        let tokens = attr.meta.to_token_stream().to_string();
111        tokens.contains("Serialize") || tokens.contains("Deserialize")
112    })
113}
114
115/// Parse a DTO struct (with Serialize/Deserialize) into a TableDef.
116fn parse_dto_struct(item: &syn::ItemStruct) -> Option<TableDef> {
117    let struct_name = item.ident.to_string();
118
119    // Use struct name as table name (DTOs don't have SQL tables)
120    let mut table = TableDef::new(&struct_name, &struct_name);
121
122    // Mark as DTO (not a database table)
123    table.is_dto = true;
124
125    // Extract documentation
126    table.doc = get_doc_comment(&item.attrs);
127
128    // Extract fields
129    if let Fields::Named(fields) = &item.fields {
130        for field in &fields.named {
131            if let Some(field_name) = &field.ident {
132                let field_def = parse_field(field_name.to_string(), &field.ty, &field.attrs);
133                table.fields.push(field_def);
134            }
135        }
136    }
137
138    Some(table)
139}
140
141/// Parse a struct with #[model] attribute into a TableDef.
142fn parse_model(item: &syn::ItemStruct) -> Option<TableDef> {
143    let struct_name = item.ident.to_string();
144    let table_name = get_table_name_from_attrs(&item.attrs).unwrap_or_else(|| {
145        let snake = to_snake_case(&struct_name);
146        pluralize(&snake)
147    });
148
149    let mut table = TableDef::new(&table_name, &struct_name);
150
151    // Extract documentation
152    table.doc = get_doc_comment(&item.attrs);
153
154    // Extract fields
155    if let Fields::Named(fields) = &item.fields {
156        for field in &fields.named {
157            if let Some(field_name) = &field.ident {
158                let field_def = parse_field(field_name.to_string(), &field.ty, &field.attrs);
159                table.fields.push(field_def);
160            }
161        }
162    }
163
164    Some(table)
165}
166
167/// Parse a field definition.
168fn parse_field(name: String, ty: &syn::Type, attrs: &[Attribute]) -> FieldDef {
169    let rust_type = type_to_rust_type(ty);
170    let mut field = FieldDef::new(&name, rust_type);
171    field.column_name = to_snake_case(&name);
172    field.doc = get_doc_comment(attrs);
173    field
174}
175
176/// Parse an enum with #[forge_enum] attribute into an EnumDef.
177fn parse_enum(item: &syn::ItemEnum) -> Option<EnumDef> {
178    let enum_name = item.ident.to_string();
179    let mut enum_def = EnumDef::new(&enum_name);
180    enum_def.doc = get_doc_comment(&item.attrs);
181
182    for variant in &item.variants {
183        let variant_name = variant.ident.to_string();
184        let mut enum_variant = EnumVariant::new(&variant_name);
185        enum_variant.doc = get_doc_comment(&variant.attrs);
186
187        // Check for explicit value
188        if let Some((_, Expr::Lit(lit))) = &variant.discriminant {
189            if let Lit::Int(int_lit) = &lit.lit {
190                if let Ok(value) = int_lit.base10_parse::<i32>() {
191                    enum_variant.int_value = Some(value);
192                }
193            }
194        }
195
196        enum_def.variants.push(enum_variant);
197    }
198
199    Some(enum_def)
200}
201
202/// Parse a function with #[query] or #[mutation] attribute.
203fn parse_function(item: &syn::ItemFn) -> Option<FunctionDef> {
204    let kind = get_function_kind(&item.attrs)?;
205    let func_name = item.sig.ident.to_string();
206
207    // Get return type
208    let return_type = match &item.sig.output {
209        ReturnType::Default => RustType::Custom("()".to_string()),
210        ReturnType::Type(_, ty) => extract_result_type(ty),
211    };
212
213    let mut func = FunctionDef::new(&func_name, kind, return_type);
214    func.doc = get_doc_comment(&item.attrs);
215    func.is_async = item.sig.asyncness.is_some();
216
217    // Parse arguments (skip first arg which is usually context)
218    let mut skip_first = true;
219    for arg in &item.sig.inputs {
220        if let FnArg::Typed(pat_type) = arg {
221            // Skip context argument (usually first)
222            if skip_first {
223                skip_first = false;
224                // Check if it's a context type
225                let type_str = quote::quote!(#pat_type.ty).to_string();
226                if type_str.contains("Context")
227                    || type_str.contains("QueryContext")
228                    || type_str.contains("MutationContext")
229                {
230                    continue;
231                }
232            }
233
234            // Extract argument name
235            if let Pat::Ident(pat_ident) = &*pat_type.pat {
236                let arg_name = pat_ident.ident.to_string();
237                let arg_type = type_to_rust_type(&pat_type.ty);
238                func.args.push(FunctionArg::new(arg_name, arg_type));
239            }
240        }
241    }
242
243    Some(func)
244}
245
246/// Get the function kind from attributes.
247fn get_function_kind(attrs: &[Attribute]) -> Option<FunctionKind> {
248    for attr in attrs {
249        let path = attr.path();
250        let segments: Vec<_> = path.segments.iter().map(|s| s.ident.to_string()).collect();
251
252        // Check for #[forge::X] or #[X] patterns
253        let kind_str = if segments.len() == 2 && segments[0] == "forge" {
254            Some(segments[1].as_str())
255        } else if segments.len() == 1 {
256            Some(segments[0].as_str())
257        } else {
258            None
259        };
260
261        if let Some(kind) = kind_str {
262            match kind {
263                "query" => return Some(FunctionKind::Query),
264                "mutation" => return Some(FunctionKind::Mutation),
265                "job" => return Some(FunctionKind::Job),
266                "cron" => return Some(FunctionKind::Cron),
267                "workflow" => return Some(FunctionKind::Workflow),
268                _ => {}
269            }
270        }
271    }
272    None
273}
274
275/// Extract the inner type from Result<T, E>.
276fn extract_result_type(ty: &syn::Type) -> RustType {
277    let type_str = quote::quote!(#ty).to_string().replace(' ', "");
278
279    // Check for Result<T, _>
280    if let Some(rest) = type_str.strip_prefix("Result<") {
281        // Find the inner type before the comma or angle bracket
282        let mut depth = 0;
283        let mut end_idx = 0;
284        for (i, c) in rest.chars().enumerate() {
285            match c {
286                '<' => depth += 1,
287                '>' => {
288                    if depth == 0 {
289                        end_idx = i;
290                        break;
291                    }
292                    depth -= 1;
293                }
294                ',' if depth == 0 => {
295                    end_idx = i;
296                    break;
297                }
298                _ => {}
299            }
300        }
301        let inner = &rest[..end_idx];
302        return type_to_rust_type(
303            &syn::parse_str(inner)
304                .unwrap_or_else(|_| syn::parse_str::<syn::Type>("String").unwrap()),
305        );
306    }
307
308    type_to_rust_type(ty)
309}
310
311/// Convert a syn::Type to RustType.
312fn type_to_rust_type(ty: &syn::Type) -> RustType {
313    let type_str = quote::quote!(#ty).to_string().replace(' ', "");
314
315    // Handle common types
316    match type_str.as_str() {
317        "String" | "&str" => RustType::String,
318        "i32" => RustType::I32,
319        "i64" => RustType::I64,
320        "f32" => RustType::F32,
321        "f64" => RustType::F64,
322        "bool" => RustType::Bool,
323        "Uuid" | "uuid::Uuid" => RustType::Uuid,
324        "DateTime<Utc>" | "chrono::DateTime<Utc>" | "chrono::DateTime<chrono::Utc>" => {
325            RustType::DateTime
326        }
327        "NaiveDate" | "chrono::NaiveDate" => RustType::Date,
328        "NaiveTime" | "chrono::NaiveTime" => RustType::Custom("NaiveTime".to_string()),
329        "serde_json::Value" | "Value" => RustType::Json,
330        "Vec<u8>" => RustType::Bytes,
331        _ => {
332            // Handle Option<T>
333            if let Some(inner) = type_str
334                .strip_prefix("Option<")
335                .and_then(|s| s.strip_suffix('>'))
336            {
337                let inner_type = match inner {
338                    "String" => RustType::String,
339                    "i32" => RustType::I32,
340                    "i64" => RustType::I64,
341                    "f64" => RustType::F64,
342                    "bool" => RustType::Bool,
343                    "Uuid" => RustType::Uuid,
344                    _ => RustType::Custom(inner.to_string()),
345                };
346                return RustType::Option(Box::new(inner_type));
347            }
348
349            // Handle Vec<T>
350            if let Some(inner) = type_str
351                .strip_prefix("Vec<")
352                .and_then(|s| s.strip_suffix('>'))
353            {
354                let inner_type = match inner {
355                    "String" => RustType::String,
356                    "i32" => RustType::I32,
357                    "u8" => return RustType::Bytes,
358                    _ => RustType::Custom(inner.to_string()),
359                };
360                return RustType::Vec(Box::new(inner_type));
361            }
362
363            // Default to custom type
364            RustType::Custom(type_str)
365        }
366    }
367}
368
369/// Get #[table(name = "...")] value from attributes.
370fn get_table_name_from_attrs(attrs: &[Attribute]) -> Option<String> {
371    for attr in attrs {
372        if attr.path().is_ident("table") {
373            if let Meta::List(list) = &attr.meta {
374                let tokens = list.tokens.to_string();
375                if let Some(value) = extract_name_value(&tokens) {
376                    return Some(value);
377                }
378            }
379        }
380    }
381    None
382}
383
384/// Get string value from attribute like #[attr = "value"].
385fn get_attribute_string_value(attr: &Attribute) -> Option<String> {
386    if let Meta::NameValue(nv) = &attr.meta {
387        if let Expr::Lit(lit) = &nv.value {
388            if let Lit::Str(s) = &lit.lit {
389                return Some(s.value());
390            }
391        }
392    }
393    None
394}
395
396/// Get documentation comment from attributes.
397fn get_doc_comment(attrs: &[Attribute]) -> Option<String> {
398    let docs: Vec<String> = attrs
399        .iter()
400        .filter_map(|attr| {
401            if attr.path().is_ident("doc") {
402                get_attribute_string_value(attr)
403            } else {
404                None
405            }
406        })
407        .collect();
408
409    if docs.is_empty() {
410        None
411    } else {
412        Some(
413            docs.into_iter()
414                .map(|s| s.trim().to_string())
415                .collect::<Vec<_>>()
416                .join("\n"),
417        )
418    }
419}
420
/// Extract the value of the `name` key from attribute tokens such as
/// `name = "users"` or `name = "users" , schema = "public"`.
///
/// The input is split on commas into `key = value` pairs. The first pair whose
/// key is exactly `name` and whose value is a double-quoted string wins. The
/// surrounding quotes are removed; escape sequences are not processed. Returns
/// `None` when no such pair exists.
fn extract_name_value(s: &str) -> Option<String> {
    // Checking the key keeps `#[table(schema = "x")]` from being read as a table
    // name. Splitting on commas keeps a trailing `, key = "..."` out of the value.
    s.split(',').find_map(|pair| {
        let (key, value) = pair.split_once('=')?;
        if key.trim() != "name" {
            return None;
        }
        value
            .trim()
            .strip_prefix('"')?
            .strip_suffix('"')
            .map(str::to_string)
    })
}
432
/// Convert a CamelCase / camelCase identifier to snake_case.
///
/// Every uppercase character except a leading one is preceded by `_`, so
/// acronyms are split per letter (`ID` -> `i_d`). Lowercasing uses the full
/// Unicode mapping, which can expand one character into several.
fn to_snake_case(s: &str) -> String {
    let mut result = String::with_capacity(s.len());
    for (i, c) in s.chars().enumerate() {
        if c.is_uppercase() {
            if i > 0 {
                result.push('_');
            }
            // `to_lowercase` may yield more than one char (e.g. 'İ' -> "i\u{307}").
            // Keeping only the first would silently corrupt the identifier.
            result.extend(c.to_lowercase());
        } else {
            result.push(c);
        }
    }
    result
}
448
/// Naively pluralize an English noun.
///
/// - Words ending in `s`, `sh`, `ch`, `x` or `z` get `es`.
/// - Words ending in a consonant + `y` swap the `y` for `ies`.
/// - Everything else gets `s`.
fn pluralize(s: &str) -> String {
    const SIBILANT_ENDINGS: [&str; 5] = ["s", "sh", "ch", "x", "z"];
    const VOWEL_Y_ENDINGS: [&str; 4] = ["ay", "ey", "oy", "uy"];

    if SIBILANT_ENDINGS.iter().any(|ending| s.ends_with(ending)) {
        return format!("{}es", s);
    }

    match s.strip_suffix('y') {
        Some(stem) if !VOWEL_Y_ENDINGS.iter().any(|ending| s.ends_with(ending)) => {
            format!("{}ies", stem)
        }
        _ => format!("{}s", s),
    }
}
468
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_model_source() {
        let source = r#"
            #[model]
            struct User {
                #[id]
                id: Uuid,
                email: String,
                name: Option<String>,
                #[indexed]
                created_at: DateTime<Utc>,
            }
        "#;

        let registry = SchemaRegistry::new();
        parse_file(source, &registry).unwrap();

        let table = registry.get_table("users").unwrap();
        assert_eq!(table.struct_name, "User");
        assert_eq!(table.fields.len(), 4);
    }

    #[test]
    fn test_parse_model_custom_table_name() {
        let source = r#"
            #[model]
            #[table(name = "people")]
            struct Person {
                id: Uuid,
            }
        "#;

        let registry = SchemaRegistry::new();
        parse_file(source, &registry).unwrap();

        let table = registry.get_table("people").unwrap();
        assert_eq!(table.struct_name, "Person");
        assert_eq!(table.fields.len(), 1);
    }

    #[test]
    fn test_parse_dto_struct() {
        let source = r#"
            #[derive(Debug, Serialize, Deserialize)]
            struct CreateUserInput {
                name: String,
                email: Option<String>,
            }
        "#;

        let registry = SchemaRegistry::new();
        parse_file(source, &registry).unwrap();

        // DTOs are registered under their struct name.
        let table = registry.get_table("CreateUserInput").unwrap();
        assert!(table.is_dto);
        assert_eq!(table.struct_name, "CreateUserInput");
        assert_eq!(table.fields.len(), 2);
    }

    #[test]
    fn test_parse_enum_source() {
        let source = r#"
            #[forge_enum]
            enum ProjectStatus {
                Draft,
                Active,
                Completed,
            }
        "#;

        let registry = SchemaRegistry::new();
        parse_file(source, &registry).unwrap();

        let enum_def = registry.get_enum("ProjectStatus").unwrap();
        assert_eq!(enum_def.variants.len(), 3);
    }

    #[test]
    fn test_parse_enum_explicit_discriminants() {
        let source = r#"
            #[forge_enum]
            enum Priority {
                Low = 1,
                High = 10,
            }
        "#;

        let registry = SchemaRegistry::new();
        parse_file(source, &registry).unwrap();

        let enum_def = registry.get_enum("Priority").unwrap();
        assert_eq!(enum_def.variants.len(), 2);
        assert_eq!(enum_def.variants[0].int_value, Some(1));
        assert_eq!(enum_def.variants[1].int_value, Some(10));
    }

    #[test]
    fn test_to_snake_case() {
        assert_eq!(to_snake_case("UserProfile"), "user_profile");
        assert_eq!(to_snake_case("ID"), "i_d");
        assert_eq!(to_snake_case("createdAt"), "created_at");
    }

    #[test]
    fn test_pluralize() {
        assert_eq!(pluralize("user"), "users");
        assert_eq!(pluralize("category"), "categories");
        assert_eq!(pluralize("box"), "boxes");
        assert_eq!(pluralize("address"), "addresses");
    }

    #[test]
    fn test_parse_query_function() {
        let source = r#"
            #[query]
            async fn get_user(ctx: QueryContext, id: Uuid) -> Result<User> {
                todo!()
            }
        "#;

        let registry = SchemaRegistry::new();
        parse_file(source, &registry).unwrap();

        let func = registry.get_function("get_user").unwrap();
        assert_eq!(func.name, "get_user");
        assert_eq!(func.kind, FunctionKind::Query);
        assert!(func.is_async);
        // The leading context parameter must not be exposed as an argument.
        assert_eq!(func.args.len(), 1);
        assert_eq!(func.args[0].name, "id");
    }

    #[test]
    fn test_parse_mutation_function() {
        let source = r#"
            #[mutation]
            async fn create_user(ctx: MutationContext, name: String, email: String) -> Result<User> {
                todo!()
            }
        "#;

        let registry = SchemaRegistry::new();
        parse_file(source, &registry).unwrap();

        let func = registry.get_function("create_user").unwrap();
        assert_eq!(func.name, "create_user");
        assert_eq!(func.kind, FunctionKind::Mutation);
        assert_eq!(func.args.len(), 2);
        assert_eq!(func.args[0].name, "name");
        assert_eq!(func.args[1].name, "email");
    }

    #[test]
    fn test_parse_forge_namespaced_function() {
        let source = r#"
            #[forge::job]
            fn send_email(ctx: JobContext, to: String) -> Result<()> {
                todo!()
            }
        "#;

        let registry = SchemaRegistry::new();
        parse_file(source, &registry).unwrap();

        let func = registry.get_function("send_email").unwrap();
        assert_eq!(func.kind, FunctionKind::Job);
        assert!(!func.is_async);
        assert_eq!(func.args.len(), 1);
        assert_eq!(func.args[0].name, "to");
    }
}