forge_codegen/
parser.rs

1//! Rust source code parser for extracting FORGE schema definitions.
2//!
3//! This module parses Rust source files to extract model, enum, and function
4//! definitions without requiring compilation.
5
6use std::path::Path;
7
8use forge_core::schema::{
9    EnumDef, EnumVariant, FieldDef, FunctionArg, FunctionDef, FunctionKind, RustType,
10    SchemaRegistry, TableDef,
11};
12use syn::{Attribute, Expr, Fields, FnArg, Lit, Meta, Pat, ReturnType};
13use walkdir::WalkDir;
14
15use crate::Error;
16
17/// Parse all Rust source files in a directory and extract schema definitions.
18pub fn parse_project(src_dir: &Path) -> Result<SchemaRegistry, Error> {
19    let registry = SchemaRegistry::new();
20
21    for entry in WalkDir::new(src_dir)
22        .into_iter()
23        .filter_map(|e| e.ok())
24        .filter(|e| e.path().extension().map(|ext| ext == "rs").unwrap_or(false))
25    {
26        let content = std::fs::read_to_string(entry.path())?;
27        if let Err(e) = parse_file(&content, &registry) {
28            tracing::debug!(file = ?entry.path(), error = %e, "Failed to parse file");
29        }
30    }
31
32    Ok(registry)
33}
34
35/// Parse a single Rust source file and extract schema definitions.
36fn parse_file(content: &str, registry: &SchemaRegistry) -> Result<(), Error> {
37    let file = syn::parse_file(content).map_err(|e| Error::Template(e.to_string()))?;
38
39    for item in file.items {
40        match item {
41            syn::Item::Struct(item_struct) => {
42                if has_forge_model_attr(&item_struct.attrs) {
43                    if let Some(table) = parse_model(&item_struct) {
44                        registry.register_table(table);
45                    }
46                }
47            }
48            syn::Item::Enum(item_enum) => {
49                if has_forge_enum_attr(&item_enum.attrs) {
50                    if let Some(enum_def) = parse_enum(&item_enum) {
51                        registry.register_enum(enum_def);
52                    }
53                }
54            }
55            syn::Item::Fn(item_fn) => {
56                if let Some(func) = parse_function(&item_fn) {
57                    registry.register_function(func);
58                }
59            }
60            _ => {}
61        }
62    }
63
64    Ok(())
65}
66
67/// Check if attributes contain #[forge::model] or #[model].
68fn has_forge_model_attr(attrs: &[Attribute]) -> bool {
69    attrs.iter().any(|attr| {
70        let path = attr.path();
71        path.is_ident("model")
72            || path.segments.len() == 2
73                && path.segments[0].ident == "forge"
74                && path.segments[1].ident == "model"
75    })
76}
77
78/// Check if attributes contain #[forge_enum] or #[forge::enum_type].
79fn has_forge_enum_attr(attrs: &[Attribute]) -> bool {
80    attrs.iter().any(|attr| {
81        let path = attr.path();
82        path.is_ident("forge_enum")
83            || path.is_ident("enum_type")
84            || path.segments.len() == 2
85                && path.segments[0].ident == "forge"
86                && path.segments[1].ident == "enum_type"
87    })
88}
89
90/// Parse a struct with #[model] attribute into a TableDef.
91fn parse_model(item: &syn::ItemStruct) -> Option<TableDef> {
92    let struct_name = item.ident.to_string();
93    let table_name = get_table_name_from_attrs(&item.attrs).unwrap_or_else(|| {
94        let snake = to_snake_case(&struct_name);
95        pluralize(&snake)
96    });
97
98    let mut table = TableDef::new(&table_name, &struct_name);
99
100    // Extract documentation
101    table.doc = get_doc_comment(&item.attrs);
102
103    // Extract fields
104    if let Fields::Named(fields) = &item.fields {
105        for field in &fields.named {
106            if let Some(field_name) = &field.ident {
107                let field_def = parse_field(field_name.to_string(), &field.ty, &field.attrs);
108                table.fields.push(field_def);
109            }
110        }
111    }
112
113    Some(table)
114}
115
116/// Parse a field definition.
117fn parse_field(name: String, ty: &syn::Type, attrs: &[Attribute]) -> FieldDef {
118    let rust_type = type_to_rust_type(ty);
119    let mut field = FieldDef::new(&name, rust_type);
120    field.column_name = to_snake_case(&name);
121    field.doc = get_doc_comment(attrs);
122
123    // Parse field attributes
124    for attr in attrs {
125        let path = attr.path();
126        if path.is_ident("id") {
127            field
128                .attributes
129                .push(forge_core::schema::FieldAttribute::Id);
130        } else if path.is_ident("indexed") {
131            field
132                .attributes
133                .push(forge_core::schema::FieldAttribute::Indexed);
134        } else if path.is_ident("unique") {
135            field
136                .attributes
137                .push(forge_core::schema::FieldAttribute::Unique);
138        } else if path.is_ident("encrypted") {
139            field
140                .attributes
141                .push(forge_core::schema::FieldAttribute::Encrypted);
142        } else if path.is_ident("updated_at") {
143            field
144                .attributes
145                .push(forge_core::schema::FieldAttribute::UpdatedAt);
146        } else if path.is_ident("default") {
147            if let Some(value) = get_attribute_string_value(attr) {
148                field.default = Some(value);
149            }
150        }
151    }
152
153    field
154}
155
156/// Parse an enum with #[forge_enum] attribute into an EnumDef.
157fn parse_enum(item: &syn::ItemEnum) -> Option<EnumDef> {
158    let enum_name = item.ident.to_string();
159    let mut enum_def = EnumDef::new(&enum_name);
160    enum_def.doc = get_doc_comment(&item.attrs);
161
162    for variant in &item.variants {
163        let variant_name = variant.ident.to_string();
164        let mut enum_variant = EnumVariant::new(&variant_name);
165        enum_variant.doc = get_doc_comment(&variant.attrs);
166
167        // Check for explicit value
168        if let Some((_, Expr::Lit(lit))) = &variant.discriminant {
169            if let Lit::Int(int_lit) = &lit.lit {
170                if let Ok(value) = int_lit.base10_parse::<i32>() {
171                    enum_variant.int_value = Some(value);
172                }
173            }
174        }
175
176        enum_def.variants.push(enum_variant);
177    }
178
179    Some(enum_def)
180}
181
182/// Parse a function with #[query], #[mutation], or #[action] attribute.
183fn parse_function(item: &syn::ItemFn) -> Option<FunctionDef> {
184    let kind = get_function_kind(&item.attrs)?;
185    let func_name = item.sig.ident.to_string();
186
187    // Get return type
188    let return_type = match &item.sig.output {
189        ReturnType::Default => RustType::Custom("()".to_string()),
190        ReturnType::Type(_, ty) => extract_result_type(ty),
191    };
192
193    let mut func = FunctionDef::new(&func_name, kind, return_type);
194    func.doc = get_doc_comment(&item.attrs);
195    func.is_async = item.sig.asyncness.is_some();
196
197    // Parse arguments (skip first arg which is usually context)
198    let mut skip_first = true;
199    for arg in &item.sig.inputs {
200        if let FnArg::Typed(pat_type) = arg {
201            // Skip context argument (usually first)
202            if skip_first {
203                skip_first = false;
204                // Check if it's a context type
205                let type_str = quote::quote!(#pat_type.ty).to_string();
206                if type_str.contains("Context")
207                    || type_str.contains("QueryContext")
208                    || type_str.contains("MutationContext")
209                    || type_str.contains("ActionContext")
210                {
211                    continue;
212                }
213            }
214
215            // Extract argument name
216            if let Pat::Ident(pat_ident) = &*pat_type.pat {
217                let arg_name = pat_ident.ident.to_string();
218                let arg_type = type_to_rust_type(&pat_type.ty);
219                func.args.push(FunctionArg::new(arg_name, arg_type));
220            }
221        }
222    }
223
224    Some(func)
225}
226
227/// Get the function kind from attributes.
228fn get_function_kind(attrs: &[Attribute]) -> Option<FunctionKind> {
229    for attr in attrs {
230        let path = attr.path();
231        let segments: Vec<_> = path.segments.iter().map(|s| s.ident.to_string()).collect();
232
233        // Check for #[forge::X] or #[X] patterns
234        let kind_str = if segments.len() == 2 && segments[0] == "forge" {
235            Some(segments[1].as_str())
236        } else if segments.len() == 1 {
237            Some(segments[0].as_str())
238        } else {
239            None
240        };
241
242        if let Some(kind) = kind_str {
243            match kind {
244                "query" => return Some(FunctionKind::Query),
245                "mutation" => return Some(FunctionKind::Mutation),
246                "action" => return Some(FunctionKind::Action),
247                "job" => return Some(FunctionKind::Job),
248                "cron" => return Some(FunctionKind::Cron),
249                "workflow" => return Some(FunctionKind::Workflow),
250                _ => {}
251            }
252        }
253    }
254    None
255}
256
257/// Extract the inner type from Result<T, E>.
258fn extract_result_type(ty: &syn::Type) -> RustType {
259    let type_str = quote::quote!(#ty).to_string().replace(' ', "");
260
261    // Check for Result<T, _>
262    if let Some(rest) = type_str.strip_prefix("Result<") {
263        // Find the inner type before the comma or angle bracket
264        let mut depth = 0;
265        let mut end_idx = 0;
266        for (i, c) in rest.chars().enumerate() {
267            match c {
268                '<' => depth += 1,
269                '>' => {
270                    if depth == 0 {
271                        end_idx = i;
272                        break;
273                    }
274                    depth -= 1;
275                }
276                ',' if depth == 0 => {
277                    end_idx = i;
278                    break;
279                }
280                _ => {}
281            }
282        }
283        let inner = &rest[..end_idx];
284        return type_to_rust_type(
285            &syn::parse_str(inner)
286                .unwrap_or_else(|_| syn::parse_str::<syn::Type>("String").unwrap()),
287        );
288    }
289
290    type_to_rust_type(ty)
291}
292
293/// Convert a syn::Type to RustType.
294fn type_to_rust_type(ty: &syn::Type) -> RustType {
295    let type_str = quote::quote!(#ty).to_string().replace(' ', "");
296
297    // Handle common types
298    match type_str.as_str() {
299        "String" | "&str" => RustType::String,
300        "i32" => RustType::I32,
301        "i64" => RustType::I64,
302        "f32" => RustType::F32,
303        "f64" => RustType::F64,
304        "bool" => RustType::Bool,
305        "Uuid" | "uuid::Uuid" => RustType::Uuid,
306        "DateTime<Utc>" | "chrono::DateTime<Utc>" | "chrono::DateTime<chrono::Utc>" => {
307            RustType::DateTime
308        }
309        "NaiveDate" | "chrono::NaiveDate" => RustType::Date,
310        "NaiveTime" | "chrono::NaiveTime" => RustType::Custom("NaiveTime".to_string()),
311        "serde_json::Value" | "Value" => RustType::Json,
312        "Vec<u8>" => RustType::Bytes,
313        _ => {
314            // Handle Option<T>
315            if let Some(inner) = type_str
316                .strip_prefix("Option<")
317                .and_then(|s| s.strip_suffix('>'))
318            {
319                let inner_type = match inner {
320                    "String" => RustType::String,
321                    "i32" => RustType::I32,
322                    "i64" => RustType::I64,
323                    "f64" => RustType::F64,
324                    "bool" => RustType::Bool,
325                    "Uuid" => RustType::Uuid,
326                    _ => RustType::String, // Default fallback
327                };
328                return RustType::Option(Box::new(inner_type));
329            }
330
331            // Handle Vec<T>
332            if let Some(inner) = type_str
333                .strip_prefix("Vec<")
334                .and_then(|s| s.strip_suffix('>'))
335            {
336                let inner_type = match inner {
337                    "String" => RustType::String,
338                    "i32" => RustType::I32,
339                    "u8" => return RustType::Bytes,
340                    _ => RustType::String,
341                };
342                return RustType::Vec(Box::new(inner_type));
343            }
344
345            // Default to custom type
346            RustType::Custom(type_str)
347        }
348    }
349}
350
351/// Get #[table(name = "...")] value from attributes.
352fn get_table_name_from_attrs(attrs: &[Attribute]) -> Option<String> {
353    for attr in attrs {
354        if attr.path().is_ident("table") {
355            if let Meta::List(list) = &attr.meta {
356                let tokens = list.tokens.to_string();
357                if let Some(value) = extract_name_value(&tokens) {
358                    return Some(value);
359                }
360            }
361        }
362    }
363    None
364}
365
366/// Get string value from attribute like #[attr = "value"].
367fn get_attribute_string_value(attr: &Attribute) -> Option<String> {
368    if let Meta::NameValue(nv) = &attr.meta {
369        if let Expr::Lit(lit) = &nv.value {
370            if let Lit::Str(s) = &lit.lit {
371                return Some(s.value());
372            }
373        }
374    }
375    None
376}
377
378/// Get documentation comment from attributes.
379fn get_doc_comment(attrs: &[Attribute]) -> Option<String> {
380    let docs: Vec<String> = attrs
381        .iter()
382        .filter_map(|attr| {
383            if attr.path().is_ident("doc") {
384                get_attribute_string_value(attr)
385            } else {
386                None
387            }
388        })
389        .collect();
390
391    if docs.is_empty() {
392        None
393    } else {
394        Some(
395            docs.into_iter()
396                .map(|s| s.trim().to_string())
397                .collect::<Vec<_>>()
398                .join("\n"),
399        )
400    }
401}
402
/// Extract the value of the `name` key from attribute tokens such as
/// `name = "value"`.
///
/// The input is scanned as a comma-separated list of `key = "value"` pairs, so
/// `name` is found wherever it appears (`schema = "public", name = "users"`)
/// and trailing pairs do not leak into the result. The text between the quotes
/// is returned verbatim (escape sequences are not decoded). Returns `None`
/// when there is no `name` key or its value is not a terminated quoted string.
fn extract_name_value(s: &str) -> Option<String> {
    let mut rest = s;
    loop {
        let (raw_key, after_eq) = rest.split_once('=')?;
        // Drop the separator left over from the previous pair.
        let key = raw_key.trim().trim_start_matches(',').trim();
        let value = after_eq.trim_start();

        let body = match value.strip_prefix('"') {
            Some(body) => body,
            None => {
                // Unquoted value: nothing to extract; move on to the next pair.
                let (_, tail) = value.split_once(',')?;
                rest = tail;
                continue;
            }
        };

        // Find the closing quote, stepping over backslash-escaped characters.
        let mut escaped = false;
        let mut closing = None;
        for (idx, ch) in body.char_indices() {
            if escaped {
                escaped = false;
            } else if ch == '\\' {
                escaped = true;
            } else if ch == '"' {
                closing = Some(idx);
                break;
            }
        }
        let closing = closing?;

        if key == "name" {
            return Some(body[..closing].to_string());
        }
        // `"` is one byte, so `closing + 1` is a valid boundary.
        rest = &body[closing + 1..];
    }
}
414
/// Convert a CamelCase or mixedCase identifier to snake_case.
///
/// Every uppercase character is lowercased and, unless it is the first
/// character, preceded by an underscore. Consecutive capitals are therefore
/// split individually (`"ID"` becomes `"i_d"`).
fn to_snake_case(s: &str) -> String {
    s.chars()
        .enumerate()
        .fold(String::new(), |mut snake, (position, ch)| {
            if !ch.is_uppercase() {
                snake.push(ch);
                return snake;
            }
            if position > 0 {
                snake.push('_');
            }
            // Keep only the first character of the lowercase mapping.
            snake.extend(ch.to_lowercase().take(1));
            snake
        })
}
430
/// Pluralize an English noun with a few simple suffix rules.
///
/// Words ending in `s`, `sh`, `ch`, `x` or `z` take `es`; a trailing `y` that
/// does not follow `a`, `e`, `o` or `u` becomes `ies`; everything else takes `s`.
fn pluralize(s: &str) -> String {
    const ES_ENDINGS: [&str; 5] = ["s", "sh", "ch", "x", "z"];
    const PLAIN_Y_ENDINGS: [&str; 4] = ["ay", "ey", "oy", "uy"];

    if ES_ENDINGS.iter().any(|suffix| s.ends_with(*suffix)) {
        return format!("{}es", s);
    }

    match s.strip_suffix('y') {
        Some(stem) if !PLAIN_Y_ENDINGS.iter().any(|suffix| s.ends_with(*suffix)) => {
            format!("{}ies", stem)
        }
        _ => format!("{}s", s),
    }
}
450
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse `source` into a fresh registry, panicking if the source is invalid.
    fn registry_for(source: &str) -> SchemaRegistry {
        let registry = SchemaRegistry::new();
        parse_file(source, &registry).unwrap();
        registry
    }

    // A `#[model]` struct is registered under its pluralized snake_case name.
    #[test]
    fn test_parse_model_source() {
        let registry = registry_for(
            r#"
            #[model]
            struct User {
                #[id]
                id: Uuid,
                email: String,
                name: Option<String>,
                #[indexed]
                created_at: DateTime<Utc>,
            }
        "#,
        );

        let table = registry.get_table("users").unwrap();
        assert_eq!(table.struct_name, "User");
        assert_eq!(table.fields.len(), 4);
    }

    // A `#[forge_enum]` enum is registered with all of its variants.
    #[test]
    fn test_parse_enum_source() {
        let registry = registry_for(
            r#"
            #[forge_enum]
            enum ProjectStatus {
                Draft,
                Active,
                Completed,
            }
        "#,
        );

        let enum_def = registry.get_enum("ProjectStatus").unwrap();
        assert_eq!(enum_def.variants.len(), 3);
    }

    // Each capital gets its own underscore, including runs of capitals.
    #[test]
    fn test_to_snake_case() {
        let cases = [
            ("UserProfile", "user_profile"),
            ("ID", "i_d"),
            ("createdAt", "created_at"),
        ];
        for (input, expected) in cases {
            assert_eq!(to_snake_case(input), expected);
        }
    }

    // One case per pluralization rule.
    #[test]
    fn test_pluralize() {
        let cases = [
            ("user", "users"),
            ("category", "categories"),
            ("box", "boxes"),
            ("address", "addresses"),
        ];
        for (input, expected) in cases {
            assert_eq!(pluralize(input), expected);
        }
    }

    // `#[query]` functions are registered with their kind and async flag.
    #[test]
    fn test_parse_query_function() {
        let registry = registry_for(
            r#"
            #[query]
            async fn get_user(ctx: QueryContext, id: Uuid) -> Result<User> {
                todo!()
            }
        "#,
        );

        let func = registry.get_function("get_user").unwrap();
        assert_eq!(func.name, "get_user");
        assert_eq!(func.kind, FunctionKind::Query);
        assert!(func.is_async);
    }

    // The leading context argument is dropped; the rest keep their order.
    #[test]
    fn test_parse_mutation_function() {
        let registry = registry_for(
            r#"
            #[mutation]
            async fn create_user(ctx: MutationContext, name: String, email: String) -> Result<User> {
                todo!()
            }
        "#,
        );

        let func = registry.get_function("create_user").unwrap();
        assert_eq!(func.name, "create_user");
        assert_eq!(func.kind, FunctionKind::Mutation);
        assert_eq!(func.args.len(), 2);
        assert_eq!(func.args[0].name, "name");
        assert_eq!(func.args[1].name, "email");
    }

    // `#[action]` functions are registered with the action kind.
    #[test]
    fn test_parse_action_function() {
        let registry = registry_for(
            r#"
            #[action]
            async fn send_notification(ctx: ActionContext, message: String) -> Result<()> {
                todo!()
            }
        "#,
        );

        let func = registry.get_function("send_notification").unwrap();
        assert_eq!(func.name, "send_notification");
        assert_eq!(func.kind, FunctionKind::Action);
    }
}