Skip to main content

forge_codegen/
parser.rs

1//! Rust source code parser for extracting FORGE schema definitions.
2//!
3//! This module parses Rust source files to extract model, enum, and function
4//! definitions without requiring compilation.
5
6use std::path::{Path, PathBuf};
7
8use forge_core::schema::{
9    EnumDef, EnumVariant, FieldDef, FunctionArg, FunctionDef, FunctionKind, RustType,
10    SchemaRegistry, TableDef,
11};
12use forge_core::util::to_snake_case;
13use quote::ToTokens;
14use syn::{Attribute, Expr, Fields, FnArg, Lit, Meta, Pat, ReturnType};
15
16use crate::Error;
17
/// Recursively gather every file with an `rs` extension under `dir` into `out`.
///
/// The walk is best-effort: a directory that cannot be listed, or an entry
/// that cannot be read, is skipped without reporting an error.
fn collect_rs_files(dir: &Path, out: &mut Vec<PathBuf>) {
    let Ok(listing) = std::fs::read_dir(dir) else {
        return;
    };

    for candidate in listing.filter_map(Result::ok).map(|entry| entry.path()) {
        if candidate.is_dir() {
            // Descend before looking at the next sibling.
            collect_rs_files(&candidate, out);
            continue;
        }
        if matches!(candidate.extension(), Some(ext) if ext == "rs") {
            out.push(candidate);
        }
    }
}
32
33/// Parse all Rust source files in a directory and extract schema definitions.
34pub fn parse_project(src_dir: &Path) -> Result<SchemaRegistry, Error> {
35    let registry = SchemaRegistry::new();
36
37    let mut files = Vec::new();
38    collect_rs_files(src_dir, &mut files);
39    files.sort();
40
41    for path in &files {
42        let content = std::fs::read_to_string(path)?;
43        if let Err(e) = parse_file(&content, &registry) {
44            tracing::debug!(file = ?path, error = %e, "Failed to parse file");
45        }
46    }
47
48    Ok(registry)
49}
50
51/// Parse a single Rust source file and extract schema definitions.
52fn parse_file(content: &str, registry: &SchemaRegistry) -> Result<(), Error> {
53    let file = syn::parse_file(content).map_err(|e| Error::Template(e.to_string()))?;
54
55    for item in file.items {
56        match item {
57            syn::Item::Struct(item_struct) => {
58                if has_forge_model_attr(&item_struct.attrs) {
59                    if let Some(table) = parse_model(&item_struct) {
60                        registry.register_table(table);
61                    }
62                } else if has_serde_derive(&item_struct.attrs) {
63                    // Parse DTO structs (those with Serialize/Deserialize)
64                    if let Some(table) = parse_dto_struct(&item_struct) {
65                        registry.register_table(table);
66                    }
67                }
68            }
69            syn::Item::Enum(item_enum) => {
70                if has_forge_enum_attr(&item_enum.attrs) {
71                    if let Some(enum_def) = parse_enum(&item_enum) {
72                        registry.register_enum(enum_def);
73                    }
74                } else if has_serde_derive(&item_enum.attrs) {
75                    // Parse enums with Serialize/Deserialize
76                    if let Some(enum_def) = parse_enum(&item_enum) {
77                        registry.register_enum(enum_def);
78                    }
79                }
80            }
81            syn::Item::Fn(item_fn) => {
82                if let Some(func) = parse_function(&item_fn) {
83                    registry.register_function(func);
84                }
85            }
86            _ => {}
87        }
88    }
89
90    Ok(())
91}
92
93/// Check if attributes contain #[forge::model] or #[model].
94fn has_forge_model_attr(attrs: &[Attribute]) -> bool {
95    attrs.iter().any(|attr| {
96        let path = attr.path();
97        path.is_ident("model")
98            || path.segments.len() == 2
99                && path.segments[0].ident == "forge"
100                && path.segments[1].ident == "model"
101    })
102}
103
104/// Check if attributes contain #[forge_enum] or #[forge::enum_type].
105fn has_forge_enum_attr(attrs: &[Attribute]) -> bool {
106    attrs.iter().any(|attr| {
107        let path = attr.path();
108        path.is_ident("forge_enum")
109            || path.is_ident("enum_type")
110            || path.segments.len() == 2
111                && path.segments[0].ident == "forge"
112                && (path.segments[1].ident == "enum_type" || path.segments[1].ident == "forge_enum")
113    })
114}
115
116/// Check if attributes contain #[derive(...Serialize...)] or #[derive(...Deserialize...)].
117fn has_serde_derive(attrs: &[Attribute]) -> bool {
118    attrs.iter().any(|attr| {
119        if !attr.path().is_ident("derive") {
120            return false;
121        }
122        let tokens = attr.meta.to_token_stream().to_string();
123        tokens.contains("Serialize") || tokens.contains("Deserialize")
124    })
125}
126
127/// Parse a DTO struct (with Serialize/Deserialize) into a TableDef.
128fn parse_dto_struct(item: &syn::ItemStruct) -> Option<TableDef> {
129    let struct_name = item.ident.to_string();
130
131    // Use struct name as table name (DTOs don't have SQL tables)
132    let mut table = TableDef::new(&struct_name, &struct_name);
133
134    // Mark as DTO (not a database table)
135    table.is_dto = true;
136
137    // Extract documentation
138    table.doc = get_doc_comment(&item.attrs);
139
140    // Extract fields
141    if let Fields::Named(fields) = &item.fields {
142        for field in &fields.named {
143            if let Some(field_name) = &field.ident {
144                let field_def = parse_field(field_name.to_string(), &field.ty, &field.attrs);
145                table.fields.push(field_def);
146            }
147        }
148    }
149
150    Some(table)
151}
152
153/// Parse a struct with #[model] attribute into a TableDef.
154fn parse_model(item: &syn::ItemStruct) -> Option<TableDef> {
155    let struct_name = item.ident.to_string();
156    let table_name = get_table_name_from_attrs(&item.attrs).unwrap_or_else(|| {
157        let snake = to_snake_case(&struct_name);
158        pluralize(&snake)
159    });
160
161    let mut table = TableDef::new(&table_name, &struct_name);
162
163    // Extract documentation
164    table.doc = get_doc_comment(&item.attrs);
165
166    // Extract fields
167    if let Fields::Named(fields) = &item.fields {
168        for field in &fields.named {
169            if let Some(field_name) = &field.ident {
170                let field_def = parse_field(field_name.to_string(), &field.ty, &field.attrs);
171                table.fields.push(field_def);
172            }
173        }
174    }
175
176    Some(table)
177}
178
179/// Parse a field definition.
180fn parse_field(name: String, ty: &syn::Type, attrs: &[Attribute]) -> FieldDef {
181    let rust_type = type_to_rust_type(ty);
182    let mut field = FieldDef::new(&name, rust_type);
183    field.column_name = to_snake_case(&name);
184    field.doc = get_doc_comment(attrs);
185    field
186}
187
188/// Parse an enum with #[forge_enum] attribute into an EnumDef.
189fn parse_enum(item: &syn::ItemEnum) -> Option<EnumDef> {
190    let enum_name = item.ident.to_string();
191    let mut enum_def = EnumDef::new(&enum_name);
192    enum_def.doc = get_doc_comment(&item.attrs);
193
194    for variant in &item.variants {
195        let variant_name = variant.ident.to_string();
196        let mut enum_variant = EnumVariant::new(&variant_name);
197        enum_variant.doc = get_doc_comment(&variant.attrs);
198
199        // Check for explicit value
200        if let Some((_, Expr::Lit(lit))) = &variant.discriminant
201            && let Lit::Int(int_lit) = &lit.lit
202            && let Ok(value) = int_lit.base10_parse::<i32>()
203        {
204            enum_variant.int_value = Some(value);
205        }
206
207        enum_def.variants.push(enum_variant);
208    }
209
210    Some(enum_def)
211}
212
213/// Parse a function with #[query] or #[mutation] attribute.
214fn parse_function(item: &syn::ItemFn) -> Option<FunctionDef> {
215    let kind = get_function_kind(&item.attrs)?;
216    let func_name = item.sig.ident.to_string();
217
218    // Get return type
219    let return_type = match &item.sig.output {
220        ReturnType::Default => RustType::Custom("()".to_string()),
221        ReturnType::Type(_, ty) => extract_result_type(ty),
222    };
223
224    let mut func = FunctionDef::new(&func_name, kind, return_type);
225    func.doc = get_doc_comment(&item.attrs);
226    func.is_async = item.sig.asyncness.is_some();
227
228    // Parse arguments (skip first arg which is usually context)
229    let mut skip_first = true;
230    for arg in &item.sig.inputs {
231        if let FnArg::Typed(pat_type) = arg {
232            // Skip context argument (usually first)
233            if skip_first {
234                skip_first = false;
235                // Check if it's a context type
236                let type_str = quote::quote!(#pat_type.ty).to_string();
237                if type_str.contains("Context")
238                    || type_str.contains("QueryContext")
239                    || type_str.contains("MutationContext")
240                {
241                    continue;
242                }
243            }
244
245            // Extract argument name
246            if let Pat::Ident(pat_ident) = &*pat_type.pat {
247                let arg_name = pat_ident.ident.to_string();
248                let arg_type = type_to_rust_type(&pat_type.ty);
249                func.args.push(FunctionArg::new(arg_name, arg_type));
250            }
251        }
252    }
253
254    Some(func)
255}
256
257/// Get the function kind from attributes.
258fn get_function_kind(attrs: &[Attribute]) -> Option<FunctionKind> {
259    for attr in attrs {
260        let path = attr.path();
261        let segments: Vec<_> = path.segments.iter().map(|s| s.ident.to_string()).collect();
262
263        // Check for #[forge::X] or #[X] patterns
264        let kind_str = if segments.len() == 2 && segments[0] == "forge" {
265            Some(segments[1].as_str())
266        } else if segments.len() == 1 {
267            Some(segments[0].as_str())
268        } else {
269            None
270        };
271
272        if let Some(kind) = kind_str {
273            match kind {
274                "query" => return Some(FunctionKind::Query),
275                "mutation" => return Some(FunctionKind::Mutation),
276                "job" => return Some(FunctionKind::Job),
277                "cron" => return Some(FunctionKind::Cron),
278                "workflow" => return Some(FunctionKind::Workflow),
279                _ => {}
280            }
281        }
282    }
283    None
284}
285
286/// Extract the inner type from Result<T, E>.
287fn extract_result_type(ty: &syn::Type) -> RustType {
288    let type_str = quote::quote!(#ty).to_string().replace(' ', "");
289
290    // Check for Result<T, _>
291    if let Some(rest) = type_str.strip_prefix("Result<") {
292        // Find the inner type before the comma or angle bracket
293        let mut depth = 0;
294        let mut end_idx = 0;
295        for (i, c) in rest.chars().enumerate() {
296            match c {
297                '<' => depth += 1,
298                '>' => {
299                    if depth == 0 {
300                        end_idx = i;
301                        break;
302                    }
303                    depth -= 1;
304                }
305                ',' if depth == 0 => {
306                    end_idx = i;
307                    break;
308                }
309                _ => {}
310            }
311        }
312        let inner = &rest[..end_idx];
313        return type_to_rust_type(
314            &syn::parse_str(inner)
315                .unwrap_or_else(|_| syn::parse_str::<syn::Type>("String").unwrap()),
316        );
317    }
318
319    type_to_rust_type(ty)
320}
321
322/// Convert a syn::Type to RustType.
323fn type_to_rust_type(ty: &syn::Type) -> RustType {
324    let type_str = quote::quote!(#ty).to_string().replace(' ', "");
325
326    // Handle common types
327    match type_str.as_str() {
328        "String" | "&str" => RustType::String,
329        "i32" => RustType::I32,
330        "i64" => RustType::I64,
331        "f32" => RustType::F32,
332        "f64" => RustType::F64,
333        "bool" => RustType::Bool,
334        "Uuid" | "uuid::Uuid" => RustType::Uuid,
335        "DateTime<Utc>" | "chrono::DateTime<Utc>" | "chrono::DateTime<chrono::Utc>" => {
336            RustType::DateTime
337        }
338        "NaiveDate" | "chrono::NaiveDate" => RustType::Date,
339        "NaiveTime" | "chrono::NaiveTime" => RustType::Custom("NaiveTime".to_string()),
340        "serde_json::Value" | "Value" => RustType::Json,
341        "Vec<u8>" => RustType::Bytes,
342        _ => {
343            // Handle Option<T> by recursively resolving the inner type
344            if let Some(inner) = type_str
345                .strip_prefix("Option<")
346                .and_then(|s| s.strip_suffix('>'))
347            {
348                let inner_ty: syn::Type =
349                    syn::parse_str(inner).unwrap_or_else(|_| syn::parse_str("String").unwrap());
350                return RustType::Option(Box::new(type_to_rust_type(&inner_ty)));
351            }
352
353            // Handle Vec<T> by recursively resolving the inner type
354            if let Some(inner) = type_str
355                .strip_prefix("Vec<")
356                .and_then(|s| s.strip_suffix('>'))
357            {
358                if inner == "u8" {
359                    return RustType::Bytes;
360                }
361                let inner_ty: syn::Type =
362                    syn::parse_str(inner).unwrap_or_else(|_| syn::parse_str("String").unwrap());
363                return RustType::Vec(Box::new(type_to_rust_type(&inner_ty)));
364            }
365
366            // Default to custom type
367            RustType::Custom(type_str)
368        }
369    }
370}
371
372/// Get #[table(name = "...")] value from attributes.
373fn get_table_name_from_attrs(attrs: &[Attribute]) -> Option<String> {
374    for attr in attrs {
375        if attr.path().is_ident("table")
376            && let Meta::List(list) = &attr.meta
377        {
378            let tokens = list.tokens.to_string();
379            if let Some(value) = extract_name_value(&tokens) {
380                return Some(value);
381            }
382        }
383    }
384    None
385}
386
387/// Get string value from attribute like #[attr = "value"].
388fn get_attribute_string_value(attr: &Attribute) -> Option<String> {
389    if let Meta::NameValue(nv) = &attr.meta
390        && let Expr::Lit(lit) = &nv.value
391        && let Lit::Str(s) = &lit.lit
392    {
393        return Some(s.value());
394    }
395    None
396}
397
398/// Get documentation comment from attributes.
399fn get_doc_comment(attrs: &[Attribute]) -> Option<String> {
400    let docs: Vec<String> = attrs
401        .iter()
402        .filter_map(|attr| {
403            if attr.path().is_ident("doc") {
404                get_attribute_string_value(attr)
405            } else {
406                None
407            }
408        })
409        .collect();
410
411    if docs.is_empty() {
412        None
413    } else {
414        Some(
415            docs.into_iter()
416                .map(|s| s.trim().to_string())
417                .collect::<Vec<_>>()
418                .join("\n"),
419        )
420    }
421}
422
/// Extract the quoted value from text shaped like `name = "value"`.
///
/// The value is the string literal that follows the first `=`. It ends at
/// that literal's own closing quote, so further comma-separated pairs
/// (`name = "users", schema = "public"`) or a trailing comma do not leak into
/// the result. Backslash escapes inside the literal are skipped over when
/// looking for the closing quote and are returned verbatim (not unescaped).
///
/// Returns `None` when there is no `=`, when the right-hand side does not
/// start with a double quote, when the literal is unterminated, or when
/// something other than a comma follows the closing quote.
fn extract_name_value(s: &str) -> Option<String> {
    let (_, rhs) = s.split_once('=')?;
    let body = rhs.trim_start().strip_prefix('"')?;

    let mut escaped = false;
    for (idx, ch) in body.char_indices() {
        if escaped {
            // The character after a backslash never terminates the literal.
            escaped = false;
            continue;
        }
        match ch {
            '\\' => escaped = true,
            '"' => {
                // `"` is one byte, so `idx + 1` is a valid boundary.
                let trailing = body[idx + 1..].trim_start();
                if trailing.is_empty() || trailing.starts_with(',') {
                    return Some(body[..idx].to_string());
                }
                return None;
            }
            _ => {}
        }
    }
    None
}
434
/// Naive English pluralization used for default table names.
///
/// - words ending in `s`, `sh`, `ch`, `x` or `z` get `es`;
/// - words ending in `y` not preceded by `a`, `e`, `o` or `u` swap the `y`
///   for `ies`;
/// - everything else gets `s`.
fn pluralize(s: &str) -> String {
    const ES_ENDINGS: [&str; 5] = ["s", "sh", "ch", "x", "z"];
    const PLAIN_Y_ENDINGS: [&str; 4] = ["ay", "ey", "oy", "uy"];

    if ES_ENDINGS.iter().any(|ending| s.ends_with(*ending)) {
        return format!("{}es", s);
    }

    match s.strip_suffix('y') {
        Some(stem) if !PLAIN_Y_ENDINGS.iter().any(|ending| s.ends_with(*ending)) => {
            format!("{}ies", stem)
        }
        _ => format!("{}s", s),
    }
}
454
#[cfg(test)]
mod tests {
    use super::*;

    /// Run `parse_file` over `source` and hand back the populated registry.
    fn registry_for(source: &str) -> SchemaRegistry {
        let registry = SchemaRegistry::new();
        parse_file(source, &registry).unwrap();
        registry
    }

    /// A `#[model]` struct is registered under its pluralized snake_case name.
    #[test]
    fn test_parse_model_source() {
        let registry = registry_for(
            r#"
            #[model]
            struct User {
                #[id]
                id: Uuid,
                email: String,
                name: Option<String>,
                #[indexed]
                created_at: DateTime<Utc>,
            }
        "#,
        );

        let users = registry.get_table("users").unwrap();
        assert_eq!(users.struct_name, "User");
        assert_eq!(users.fields.len(), 4);
    }

    /// A `#[forge_enum]` enum is registered with all of its variants.
    #[test]
    fn test_parse_enum_source() {
        let registry = registry_for(
            r#"
            #[forge_enum]
            enum ProjectStatus {
                Draft,
                Active,
                Completed,
            }
        "#,
        );

        let status = registry.get_enum("ProjectStatus").unwrap();
        assert_eq!(status.variants.len(), 3);
    }

    /// Pins the snake_case conversion this module depends on.
    #[test]
    fn test_to_snake_case() {
        let cases = [
            ("UserProfile", "user_profile"),
            ("ID", "i_d"),
            ("createdAt", "created_at"),
        ];
        for (input, expected) in cases {
            assert_eq!(to_snake_case(input), expected);
        }
    }

    /// Covers the plain, `y` -> `ies`, and `es` pluralization rules.
    #[test]
    fn test_pluralize() {
        let cases = [
            ("user", "users"),
            ("category", "categories"),
            ("box", "boxes"),
            ("address", "addresses"),
        ];
        for (input, expected) in cases {
            assert_eq!(pluralize(input), expected);
        }
    }

    /// A `#[query]` function is registered as an async query.
    #[test]
    fn test_parse_query_function() {
        let registry = registry_for(
            r#"
            #[query]
            async fn get_user(ctx: QueryContext, id: Uuid) -> Result<User> {
                todo!()
            }
        "#,
        );

        let query = registry.get_function("get_user").unwrap();
        assert_eq!(query.name, "get_user");
        assert_eq!(query.kind, FunctionKind::Query);
        assert!(query.is_async);
    }

    /// A `#[mutation]` function drops its context argument and keeps the rest.
    #[test]
    fn test_parse_mutation_function() {
        let registry = registry_for(
            r#"
            #[mutation]
            async fn create_user(ctx: MutationContext, name: String, email: String) -> Result<User> {
                todo!()
            }
        "#,
        );

        let mutation = registry.get_function("create_user").unwrap();
        assert_eq!(mutation.name, "create_user");
        assert_eq!(mutation.kind, FunctionKind::Mutation);
        assert_eq!(mutation.args.len(), 2);
        assert_eq!(mutation.args[0].name, "name");
        assert_eq!(mutation.args[1].name, "email");
    }
}
551}