Skip to main content

forge_codegen/
parser.rs

1//! Rust source code parser for extracting FORGE schema definitions.
2//!
3//! Parses Rust source files using `syn` to extract model, enum, and function
4//! definitions without requiring compilation.
5//!
6//! Key design decisions:
7//! - Context arguments are detected structurally (type ends with "Context"),
8//!   not by string-searching the entire token stream.
9//! - Unparseable inner types become `RustType::Custom(original_string)` instead
10//!   of silently falling back to `String`.
11//! - `NaiveTime` correctly maps to `RustType::LocalTime`.
12
13use std::path::{Path, PathBuf};
14
15use forge_core::schema::{
16    EnumDef, EnumVariant, FieldDef, FunctionArg, FunctionDef, FunctionKind, RustType,
17    SchemaRegistry, TableDef,
18};
19use forge_core::util::to_snake_case;
20use quote::ToTokens;
21use syn::{Attribute, Expr, Fields, FnArg, Lit, Meta, Pat, ReturnType};
22
23use crate::Error;
24
/// Recursively append every file with an `rs` extension found under `dir` to `out`.
///
/// Directories that cannot be listed and entries that cannot be read are
/// skipped without reporting an error.
fn collect_rs_files(dir: &Path, out: &mut Vec<PathBuf>) {
    let Ok(listing) = std::fs::read_dir(dir) else {
        return;
    };

    for path in listing.filter_map(Result::ok).map(|entry| entry.path()) {
        if path.is_dir() {
            collect_rs_files(&path, out);
            continue;
        }
        if path.extension().is_some_and(|ext| ext == "rs") {
            out.push(path);
        }
    }
}
39
40/// Parse all Rust source files in a directory and extract schema definitions.
41pub fn parse_project(src_dir: &Path) -> Result<SchemaRegistry, Error> {
42    let registry = SchemaRegistry::new();
43
44    let mut files = Vec::new();
45    collect_rs_files(src_dir, &mut files);
46    files.sort();
47
48    for path in &files {
49        let content = std::fs::read_to_string(path)?;
50        if let Err(e) = parse_file(&content, &registry) {
51            tracing::debug!(file = ?path, error = %e, "Failed to parse file");
52        }
53    }
54
55    Ok(registry)
56}
57
58/// Parse a single Rust source file and extract schema definitions.
59fn parse_file(content: &str, registry: &SchemaRegistry) -> Result<(), Error> {
60    let file = syn::parse_file(content).map_err(|e| Error::Parse {
61        file: String::new(),
62        message: e.to_string(),
63    })?;
64
65    for item in file.items {
66        match item {
67            syn::Item::Struct(item_struct) => {
68                if has_forge_attr(&item_struct.attrs, "model") {
69                    if let Some(table) = parse_model(&item_struct) {
70                        registry.register_table(table);
71                    }
72                } else if has_serde_derive(&item_struct.attrs)
73                    && let Some(table) = parse_dto_struct(&item_struct)
74                {
75                    registry.register_table(table);
76                }
77            }
78            syn::Item::Enum(item_enum) => {
79                if (has_forge_enum_attr(&item_enum.attrs) || has_serde_derive(&item_enum.attrs))
80                    && let Some(enum_def) = parse_enum(&item_enum)
81                {
82                    registry.register_enum(enum_def);
83                }
84            }
85            syn::Item::Fn(item_fn) => {
86                if let Some(func) = parse_function(&item_fn) {
87                    registry.register_function(func);
88                }
89            }
90            _ => {}
91        }
92    }
93
94    Ok(())
95}
96
97// ---------------------------------------------------------------------------
98// Attribute detection
99// ---------------------------------------------------------------------------
100
101/// Check if attributes contain `#[forge::name]` or `#[name]`.
102fn has_forge_attr(attrs: &[Attribute], name: &str) -> bool {
103    attrs.iter().any(|attr| {
104        let path = attr.path();
105        path.is_ident(name)
106            || matches!(
107                (path.segments.first(), path.segments.get(1), path.segments.get(2)),
108                (Some(first), Some(second), None)
109                    if first.ident == "forge" && second.ident == name
110            )
111    })
112}
113
114/// Check if attributes contain `#[forge_enum]`, `#[enum_type]`, or `#[forge::enum_type]`.
115fn has_forge_enum_attr(attrs: &[Attribute]) -> bool {
116    attrs.iter().any(|attr| {
117        let path = attr.path();
118        path.is_ident("forge_enum")
119            || path.is_ident("enum_type")
120            || matches!(
121                (path.segments.first(), path.segments.get(1), path.segments.get(2)),
122                (Some(first), Some(second), None)
123                    if first.ident == "forge"
124                        && (second.ident == "enum_type" || second.ident == "forge_enum")
125            )
126    })
127}
128
129/// Check if attributes contain `#[derive(...Serialize...)]` or `#[derive(...Deserialize...)]`.
130fn has_serde_derive(attrs: &[Attribute]) -> bool {
131    attrs.iter().any(|attr| {
132        if !attr.path().is_ident("derive") {
133            return false;
134        }
135        let tokens = attr.meta.to_token_stream().to_string();
136        tokens.contains("Serialize") || tokens.contains("Deserialize")
137    })
138}
139
140// ---------------------------------------------------------------------------
141// Struct/model parsing
142// ---------------------------------------------------------------------------
143
144/// Parse a DTO struct (with Serialize/Deserialize) into a TableDef.
145fn parse_dto_struct(item: &syn::ItemStruct) -> Option<TableDef> {
146    let struct_name = item.ident.to_string();
147    let mut table = TableDef::new(&struct_name, &struct_name);
148    table.is_dto = true;
149    table.doc = get_doc_comment(&item.attrs);
150
151    if let Fields::Named(fields) = &item.fields {
152        for field in &fields.named {
153            if let Some(field_name) = &field.ident {
154                table
155                    .fields
156                    .push(parse_field(field_name.to_string(), &field.ty, &field.attrs));
157            }
158        }
159    }
160
161    Some(table)
162}
163
164/// Parse a struct with `#[model]` attribute into a TableDef.
165fn parse_model(item: &syn::ItemStruct) -> Option<TableDef> {
166    let struct_name = item.ident.to_string();
167    let table_name = get_table_name_from_attrs(&item.attrs).unwrap_or_else(|| {
168        let snake = to_snake_case(&struct_name);
169        pluralize(&snake)
170    });
171
172    let mut table = TableDef::new(&table_name, &struct_name);
173    table.doc = get_doc_comment(&item.attrs);
174
175    if let Fields::Named(fields) = &item.fields {
176        for field in &fields.named {
177            if let Some(field_name) = &field.ident {
178                table
179                    .fields
180                    .push(parse_field(field_name.to_string(), &field.ty, &field.attrs));
181            }
182        }
183    }
184
185    Some(table)
186}
187
188fn parse_field(name: String, ty: &syn::Type, attrs: &[Attribute]) -> FieldDef {
189    let rust_type = type_to_rust_type(ty);
190    let mut field = FieldDef::new(&name, rust_type);
191    field.column_name = to_snake_case(&name);
192    field.doc = get_doc_comment(attrs);
193    field
194}
195
196// ---------------------------------------------------------------------------
197// Enum parsing
198// ---------------------------------------------------------------------------
199
200fn parse_enum(item: &syn::ItemEnum) -> Option<EnumDef> {
201    let enum_name = item.ident.to_string();
202    let mut enum_def = EnumDef::new(&enum_name);
203    enum_def.doc = get_doc_comment(&item.attrs);
204
205    for variant in &item.variants {
206        let variant_name = variant.ident.to_string();
207        let mut enum_variant = EnumVariant::new(&variant_name);
208        enum_variant.doc = get_doc_comment(&variant.attrs);
209
210        if let Some((_, Expr::Lit(lit))) = &variant.discriminant
211            && let Lit::Int(int_lit) = &lit.lit
212            && let Ok(value) = int_lit.base10_parse::<i32>()
213        {
214            enum_variant.int_value = Some(value);
215        }
216
217        enum_def.variants.push(enum_variant);
218    }
219
220    Some(enum_def)
221}
222
223// ---------------------------------------------------------------------------
224// Function parsing
225// ---------------------------------------------------------------------------
226
227/// Parse a function with a forge decorator attribute.
228fn parse_function(item: &syn::ItemFn) -> Option<FunctionDef> {
229    let kind = get_function_kind(&item.attrs)?;
230    let func_name = item.sig.ident.to_string();
231
232    let return_type = match &item.sig.output {
233        ReturnType::Default => RustType::Custom("()".to_string()),
234        ReturnType::Type(_, ty) => extract_result_type(ty),
235    };
236
237    let mut func = FunctionDef::new(&func_name, kind, return_type);
238    func.doc = get_doc_comment(&item.attrs);
239    func.is_async = item.sig.asyncness.is_some();
240
241    // Parse arguments, skipping the context parameter.
242    let mut is_first = true;
243    for arg in &item.sig.inputs {
244        if let FnArg::Typed(pat_type) = arg {
245            if is_first {
246                is_first = false;
247                if is_context_type(&pat_type.ty) {
248                    continue;
249                }
250            }
251
252            if let Pat::Ident(pat_ident) = &*pat_type.pat {
253                let arg_name = pat_ident.ident.to_string();
254                let arg_type = type_to_rust_type(&pat_type.ty);
255                func.args.push(FunctionArg::new(arg_name, arg_type));
256            }
257        }
258    }
259
260    Some(func)
261}
262
263/// Check if a type is a Forge context type by examining the base type name.
264///
265/// Handles references (`&Context`, `&mut Context`) and qualified paths
266/// (`forge::QueryContext`). Only matches types whose final segment ends
267/// with "Context" — won't match `ContextManager` or `NoContextHere`.
268fn is_context_type(ty: &syn::Type) -> bool {
269    // Get the type string, stripping whitespace for uniform matching.
270    let type_str = ty.to_token_stream().to_string().replace(' ', "");
271
272    // Strip leading references: &, &mut
273    let base = type_str.trim_start_matches('&').trim_start_matches("mut");
274
275    // Get the final path segment (after any :: qualifiers).
276    let final_segment = base.rsplit("::").next().unwrap_or(base);
277
278    final_segment.ends_with("Context")
279}
280
281fn get_function_kind(attrs: &[Attribute]) -> Option<FunctionKind> {
282    for attr in attrs {
283        let path = attr.path();
284        let segments: Vec<_> = path.segments.iter().map(|s| s.ident.to_string()).collect();
285
286        let kind_str = match segments.as_slice() {
287            [forge, kind] if forge == "forge" => Some(kind.as_str()),
288            [kind] => Some(kind.as_str()),
289            _ => None,
290        };
291
292        if let Some(kind) = kind_str {
293            match kind {
294                "query" => return Some(FunctionKind::Query),
295                "mutation" => return Some(FunctionKind::Mutation),
296                "job" => return Some(FunctionKind::Job),
297                "cron" => return Some(FunctionKind::Cron),
298                "workflow" => return Some(FunctionKind::Workflow),
299                _ => {}
300            }
301        }
302    }
303    None
304}
305
306// ---------------------------------------------------------------------------
307// Type conversion
308// ---------------------------------------------------------------------------
309
310/// Extract the inner type from `Result<T, E>`.
311fn extract_result_type(ty: &syn::Type) -> RustType {
312    let type_str = quote::quote!(#ty).to_string().replace(' ', "");
313
314    if let Some(rest) = type_str.strip_prefix("Result<") {
315        // Find the inner type (T) before the comma or closing bracket.
316        let mut depth = 0;
317        let mut end_idx = 0;
318        for (i, c) in rest.chars().enumerate() {
319            match c {
320                '<' => depth += 1,
321                '>' => {
322                    if depth == 0 {
323                        end_idx = i;
324                        break;
325                    }
326                    depth -= 1;
327                }
328                ',' if depth == 0 => {
329                    end_idx = i;
330                    break;
331                }
332                _ => {}
333            }
334        }
335        let inner = &rest[..end_idx];
336        return match syn::parse_str::<syn::Type>(inner) {
337            Ok(inner_ty) => type_to_rust_type(&inner_ty),
338            Err(_) => {
339                tracing::warn!(
340                    "Could not parse Result inner type '{}', treating as custom type",
341                    inner
342                );
343                RustType::Custom(inner.to_string())
344            }
345        };
346    }
347
348    type_to_rust_type(ty)
349}
350
351/// Convert a `syn::Type` to `RustType`.
352fn type_to_rust_type(ty: &syn::Type) -> RustType {
353    let type_str = quote::quote!(#ty).to_string().replace(' ', "");
354
355    match type_str.as_str() {
356        "String" | "&str" => RustType::String,
357        "i32" => RustType::I32,
358        "i64" => RustType::I64,
359        "f32" => RustType::F32,
360        "f64" => RustType::F64,
361        "bool" => RustType::Bool,
362        "Uuid" | "uuid::Uuid" => RustType::Uuid,
363        "DateTime<Utc>" | "chrono::DateTime<Utc>" | "chrono::DateTime<chrono::Utc>" => {
364            RustType::DateTime
365        }
366        "NaiveDate" | "chrono::NaiveDate" => RustType::Date,
367        "NaiveTime" | "chrono::NaiveTime" => RustType::LocalTime,
368        "serde_json::Value" | "Value" => RustType::Json,
369        "Vec<u8>" => RustType::Bytes,
370        _ => parse_generic_or_custom(&type_str),
371    }
372}
373
374/// Handle generic types (`Option<T>`, `Vec<T>`) and custom types.
375fn parse_generic_or_custom(type_str: &str) -> RustType {
376    // Option<T>
377    if let Some(inner) = type_str
378        .strip_prefix("Option<")
379        .and_then(|s| s.strip_suffix('>'))
380    {
381        let inner_type = parse_inner_type(inner);
382        return RustType::Option(Box::new(inner_type));
383    }
384
385    // Vec<T>
386    if let Some(inner) = type_str
387        .strip_prefix("Vec<")
388        .and_then(|s| s.strip_suffix('>'))
389    {
390        if inner == "u8" {
391            return RustType::Bytes;
392        }
393        let inner_type = parse_inner_type(inner);
394        return RustType::Vec(Box::new(inner_type));
395    }
396
397    // Everything else is a custom type.
398    RustType::Custom(type_str.to_string())
399}
400
401/// Parse an inner type string, falling back to Custom on failure.
402fn parse_inner_type(inner: &str) -> RustType {
403    match syn::parse_str::<syn::Type>(inner) {
404        Ok(inner_ty) => type_to_rust_type(&inner_ty),
405        Err(_) => {
406            tracing::warn!(
407                "Could not parse inner type '{}', treating as custom type",
408                inner
409            );
410            RustType::Custom(inner.to_string())
411        }
412    }
413}
414
415// ---------------------------------------------------------------------------
416// Attribute value helpers
417// ---------------------------------------------------------------------------
418
419/// Get `#[table(name = "...")]` value from attributes.
420fn get_table_name_from_attrs(attrs: &[Attribute]) -> Option<String> {
421    for attr in attrs {
422        if attr.path().is_ident("table")
423            && let Meta::List(list) = &attr.meta
424        {
425            let tokens = list.tokens.to_string();
426            if let Some(value) = extract_name_value(&tokens) {
427                return Some(value);
428            }
429        }
430    }
431    None
432}
433
434/// Get string value from attribute like `#[attr = "value"]`.
435fn get_attribute_string_value(attr: &Attribute) -> Option<String> {
436    if let Meta::NameValue(nv) = &attr.meta
437        && let Expr::Lit(lit) = &nv.value
438        && let Lit::Str(s) = &lit.lit
439    {
440        return Some(s.value());
441    }
442    None
443}
444
445/// Get documentation comment from attributes.
446fn get_doc_comment(attrs: &[Attribute]) -> Option<String> {
447    let docs: Vec<String> = attrs
448        .iter()
449        .filter_map(|attr| {
450            if attr.path().is_ident("doc") {
451                get_attribute_string_value(attr)
452            } else {
453                None
454            }
455        })
456        .collect();
457
458    if docs.is_empty() {
459        None
460    } else {
461        Some(
462            docs.into_iter()
463                .map(|s| s.trim().to_string())
464                .collect::<Vec<_>>()
465                .join("\n"),
466        )
467    }
468}
469
/// Split attribute argument text on commas that are not inside a string literal.
fn split_attr_args(s: &str) -> Vec<&str> {
    let mut parts = Vec::new();
    let mut start = 0;
    let mut in_string = false;
    let mut escaped = false;

    for (idx, ch) in s.char_indices() {
        if in_string {
            if escaped {
                escaped = false;
            } else if ch == '\\' {
                escaped = true;
            } else if ch == '"' {
                in_string = false;
            }
        } else if ch == '"' {
            in_string = true;
        } else if ch == ',' {
            parts.push(&s[start..idx]);
            // ',' is one byte, so the next argument starts right after it.
            start = idx + 1;
        }
    }
    parts.push(&s[start..]);
    parts
}

/// Extract the value of the `name` key from `name = "value"` style arguments.
///
/// `s` is the rendered argument text of an attribute and may hold several
/// comma-separated `key = "value"` pairs (e.g. `schema = "public", name = "users"`).
/// The first pair whose key is exactly `name` and whose value is wrapped in
/// double quotes wins; the quotes are removed but escape sequences inside the
/// literal are returned as written. Returns `None` when no such pair exists.
fn extract_name_value(s: &str) -> Option<String> {
    split_attr_args(s).into_iter().find_map(|arg| {
        let (key, value) = arg.split_once('=')?;
        if key.trim() != "name" {
            return None;
        }
        value
            .trim()
            .strip_prefix('"')
            .and_then(|rest| rest.strip_suffix('"'))
            .map(str::to_string)
    })
}
480
481// ---------------------------------------------------------------------------
482// Pluralization
483// ---------------------------------------------------------------------------
484
/// Naive English pluralization used to derive table names.
///
/// - words ending in `s`, `sh`, `ch`, `x` or `z` get `es`;
/// - words ending in `y` not preceded by `a`, `e`, `o` or `u` swap the `y` for `ies`;
/// - everything else gets a plain `s`.
fn pluralize(s: &str) -> String {
    const ES_ENDINGS: [&str; 5] = ["s", "sh", "ch", "x", "z"];
    const VOWEL_Y_ENDINGS: [&str; 4] = ["ay", "ey", "oy", "uy"];

    if ES_ENDINGS.iter().any(|ending| s.ends_with(*ending)) {
        return format!("{s}es");
    }

    match s.strip_suffix('y') {
        Some(stem) if !VOWEL_Y_ENDINGS.iter().any(|ending| s.ends_with(*ending)) => {
            format!("{stem}ies")
        }
        _ => format!("{s}s"),
    }
}
504
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse `source` into a fresh registry, panicking with `failure_msg` if parsing fails.
    fn registry_from(source: &str, failure_msg: &str) -> SchemaRegistry {
        let registry = SchemaRegistry::new();
        parse_file(source, &registry).expect(failure_msg);
        registry
    }

    #[test]
    fn test_parse_model_source() {
        let registry = registry_from(
            r#"
            #[model]
            struct User {
                #[id]
                id: Uuid,
                email: String,
                name: Option<String>,
                #[indexed]
                created_at: DateTime<Utc>,
            }
        "#,
            "model source should parse",
        );

        // The table name is the pluralized snake_case struct name.
        let table = registry
            .get_table("users")
            .expect("users table should be registered");
        assert_eq!(table.struct_name, "User");
        assert_eq!(table.fields.len(), 4);
    }

    #[test]
    fn test_parse_enum_source() {
        let registry = registry_from(
            r#"
            #[forge_enum]
            enum ProjectStatus {
                Draft,
                Active,
                Completed,
            }
        "#,
            "enum source should parse",
        );

        let enum_def = registry
            .get_enum("ProjectStatus")
            .expect("ProjectStatus enum should be registered");
        assert_eq!(enum_def.variants.len(), 3);
    }

    #[test]
    fn test_to_snake_case() {
        assert_eq!(to_snake_case("UserProfile"), "user_profile");
        assert_eq!(to_snake_case("ID"), "i_d");
        assert_eq!(to_snake_case("createdAt"), "created_at");
    }

    #[test]
    fn test_pluralize() {
        assert_eq!(pluralize("user"), "users");
        assert_eq!(pluralize("category"), "categories");
        assert_eq!(pluralize("box"), "boxes");
        assert_eq!(pluralize("address"), "addresses");
    }

    #[test]
    fn test_parse_query_function() {
        let registry = registry_from(
            r#"
            #[query]
            async fn get_user(ctx: QueryContext, id: Uuid) -> Result<User> {
                todo!()
            }
        "#,
            "query function should parse",
        );

        let func = registry
            .get_function("get_user")
            .expect("get_user function should be registered");
        assert_eq!(func.name, "get_user");
        assert_eq!(func.kind, FunctionKind::Query);
        assert!(func.is_async);
    }

    #[test]
    fn test_parse_mutation_function() {
        let registry = registry_from(
            r#"
            #[mutation]
            async fn create_user(ctx: MutationContext, name: String, email: String) -> Result<User> {
                todo!()
            }
        "#,
            "mutation function should parse",
        );

        let func = registry
            .get_function("create_user")
            .expect("create_user function should be registered");
        assert_eq!(func.name, "create_user");
        assert_eq!(func.kind, FunctionKind::Mutation);

        // The leading context parameter is dropped; only `name` and `email` remain.
        assert_eq!(func.args.len(), 2);
        assert_eq!(
            func.args.first().expect("name arg should exist").name,
            "name"
        );
        assert_eq!(
            func.args.get(1).expect("email arg should exist").name,
            "email"
        );
    }

    #[test]
    fn test_context_detection_structural() {
        // A qualified type whose last segment ends with "Context" is detected.
        let registry = registry_from(
            r#"
            #[query]
            async fn test(ctx: forge::QueryContext, id: Uuid) -> Result<User> {
                todo!()
            }
        "#,
            "context query should parse",
        );

        let func = registry
            .get_function("test")
            .expect("test function should be registered");
        assert_eq!(func.args.len(), 1); // Only `id`, context was skipped.
        assert_eq!(func.args.first().expect("id arg should exist").name, "id");
    }

    #[test]
    fn test_context_detection_does_not_match_other_types() {
        // A first parameter whose type does NOT end with "Context" is kept.
        let registry = registry_from(
            r#"
            #[query]
            async fn test(data: ContextManager, id: Uuid) -> Result<User> {
                todo!()
            }
        "#,
            "non-context query should parse",
        );

        let func = registry
            .get_function("test")
            .expect("test function should be registered");
        // "ContextManager" ends with "Manager", not "Context", so both args kept.
        assert_eq!(func.args.len(), 2);
    }

    #[test]
    fn test_naive_time_maps_to_local_time() {
        let registry = registry_from(
            r#"
            #[derive(Serialize, Deserialize)]
            struct Schedule {
                start_time: NaiveTime,
            }
        "#,
            "schedule DTO should parse",
        );

        // DTO tables are registered under the struct name itself.
        let table = registry
            .get_table("Schedule")
            .expect("Schedule table should be registered");
        let start_time = table
            .fields
            .first()
            .expect("start_time field should exist");
        assert_eq!(start_time.rust_type, RustType::LocalTime);
    }
}