Skip to main content

forge_codegen/
parser.rs

1//! Rust source code parser for extracting FORGE schema definitions.
2//!
3//! Parses Rust source files using `syn` to extract model, enum, and function
4//! definitions without requiring compilation.
5//!
6//! Key design decisions:
7//! - Context arguments are matched against `KNOWN_CONTEXT_TYPES`, a fixed list
8//!   of the 8 Forge context types. User-defined types like `AppContext` are not
9//!   accidentally skipped.
10//! - `Result<T, E>` extraction uses `syn::TypePath` recursion, handling
11//!   arbitrarily nested generics.
12//! - Unparseable inner types become `RustType::Custom(original_string)` instead
13//!   of silently falling back to `String`.
14//! - `NaiveTime` correctly maps to `RustType::LocalTime`.
15
16use std::path::{Path, PathBuf};
17
18use forge_core::schema::{
19    EnumDef, EnumVariant, FieldDef, FunctionArg, FunctionDef, FunctionKind, RustType,
20    SchemaRegistry, TableDef,
21};
22use forge_core::util::to_snake_case;
23use std::collections::BTreeMap;
24
25use quote::ToTokens;
26use syn::{Attribute, Expr, Fields, FnArg, Lit, Meta, Pat, ReturnType};
27
28use crate::Error;
29
/// Recursively gather every file with an `rs` extension under `dir` into `out`.
///
/// A directory that cannot be listed is skipped without reporting an error,
/// and individual entries that fail to read are ignored. Results are appended
/// in whatever order the filesystem yields them; callers sort afterwards.
fn collect_rs_files(dir: &Path, out: &mut Vec<PathBuf>) {
    let Ok(listing) = std::fs::read_dir(dir) else {
        return;
    };
    for child in listing.filter_map(Result::ok).map(|entry| entry.path()) {
        if child.is_dir() {
            // Descend before looking at the extension, so a directory named `x.rs` is walked.
            collect_rs_files(&child, out);
            continue;
        }
        if matches!(child.extension(), Some(ext) if ext == "rs") {
            out.push(child);
        }
    }
}
44
/// Schema registry plus file-level parse failures from [`parse_project`].
///
/// Files that fail `syn::parse_file` are recorded instead of silently dropped,
/// so a broken handler source doesn't disappear from the generated bindings.
pub struct ParseOutcome {
    /// Registry populated from every file that parsed successfully.
    pub registry: SchemaRegistry,
    /// One `(path, error message)` entry per file whose parse failed.
    pub parse_failures: Vec<(PathBuf, String)>,
}
53
54pub fn parse_project(src_dir: &Path) -> Result<ParseOutcome, Error> {
55    let registry = SchemaRegistry::new();
56    let mut parse_failures = Vec::new();
57
58    let mut files = Vec::new();
59    collect_rs_files(src_dir, &mut files);
60    files.sort();
61
62    for path in &files {
63        let content = std::fs::read_to_string(path)?;
64        if let Err(e) = parse_file(&content, &registry) {
65            tracing::warn!(file = ?path, error = %e, "failed to parse file; handlers in this file will be missing from generated bindings");
66            parse_failures.push((path.clone(), e.to_string()));
67        }
68    }
69
70    Ok(ParseOutcome {
71        registry,
72        parse_failures,
73    })
74}
75
76/// Validate every function uses types the emitters support.
77///
78/// Surfaces `usize`/`isize` and unsupported integer widths before they fall
79/// through to silently-lossy `number` on the frontend.
80pub fn validate_registry(registry: &SchemaRegistry) -> Result<(), Vec<String>> {
81    let mut errors = Vec::new();
82    for func in registry.all_functions() {
83        for arg in &func.args {
84            collect_unsupported(&arg.rust_type, &func.name, Some(&arg.name), &mut errors);
85        }
86        collect_unsupported(&func.return_type, &func.name, None, &mut errors);
87    }
88    if errors.is_empty() {
89        Ok(())
90    } else {
91        Err(errors)
92    }
93}
94
95fn collect_unsupported(
96    ty: &RustType,
97    func_name: &str,
98    arg_name: Option<&str>,
99    errors: &mut Vec<String>,
100) {
101    match ty {
102        RustType::Option(inner) | RustType::Vec(inner) => {
103            collect_unsupported(inner, func_name, arg_name, errors);
104        }
105        RustType::Custom(name) => {
106            if let Some(reason) = unsupported_type_reason(name) {
107                let location = match arg_name {
108                    Some(arg) => format!("{}.{}", func_name, arg),
109                    None => format!("{}() return type", func_name),
110                };
111                errors.push(format!("{}: {}", location, reason));
112            }
113        }
114        _ => {}
115    }
116}
117
/// Explain why the type called `name` is rejected in handler signatures.
///
/// Returns `None` for any name that is not on the rejection list. The
/// suggested replacements are limited to `i32` and `i64`, so following the
/// advice never leads to another type that this function also rejects.
fn unsupported_type_reason(name: &str) -> Option<String> {
    match name {
        "usize" | "isize" => Some(format!(
            "`{}` is platform-dependent and not portable across the wire. \
             Use `i32` or `i64` instead.",
            name
        )),
        "u8" | "u16" | "u32" | "u64" | "u128" | "i8" | "i16" | "i128" => Some(format!(
            "`{}` is not supported in handler signatures. \
             Use `i32` or `i64` (signed integers) for portability.",
            name
        )),
        _ => None,
    }
}
133
134/// Scan a source directory and return every `(kind, name)` pair that appears in more than one file.
135/// Map key is `"kind:name"`, value is the list of file paths.
136pub fn find_duplicate_handlers(src_dir: &Path) -> Result<BTreeMap<String, Vec<PathBuf>>, Error> {
137    let mut occurrences: BTreeMap<String, Vec<PathBuf>> = BTreeMap::new();
138
139    let mut files = Vec::new();
140    collect_rs_files(src_dir, &mut files);
141    files.sort();
142
143    for path in &files {
144        let content = match std::fs::read_to_string(path) {
145            Ok(c) => c,
146            Err(_) => continue,
147        };
148        let file = match syn::parse_file(&content) {
149            Ok(f) => f,
150            Err(_) => continue,
151        };
152        for item in &file.items {
153            if let syn::Item::Fn(item_fn) = item
154                && let Some(func) = parse_function(item_fn)
155            {
156                let key = format!("{}:{}", func.kind.as_str(), func.name);
157                occurrences.entry(key).or_default().push(path.clone());
158            }
159        }
160    }
161
162    Ok(occurrences
163        .into_iter()
164        .filter(|(_, paths)| paths.len() > 1)
165        .collect())
166}
167
168fn parse_file(content: &str, registry: &SchemaRegistry) -> Result<(), Error> {
169    let file = syn::parse_file(content).map_err(|e| Error::Parse {
170        file: String::new(),
171        message: e.to_string(),
172    })?;
173
174    for item in file.items {
175        match item {
176            syn::Item::Struct(item_struct) => {
177                if has_forge_attr(&item_struct.attrs, "model") {
178                    if let Some(table) = parse_model(&item_struct) {
179                        registry.register_table(table);
180                    }
181                } else if has_serde_derive(&item_struct.attrs)
182                    && let Some(table) = parse_dto_struct(&item_struct)
183                {
184                    registry.register_table(table);
185                }
186            }
187            syn::Item::Enum(item_enum) => {
188                if (has_forge_enum_attr(&item_enum.attrs) || has_serde_derive(&item_enum.attrs))
189                    && let Some(enum_def) = parse_enum(&item_enum)
190                {
191                    registry.register_enum(enum_def);
192                }
193            }
194            syn::Item::Fn(item_fn) => {
195                if let Some(func) = parse_function(&item_fn) {
196                    registry.register_function(func);
197                }
198            }
199            _ => {}
200        }
201    }
202
203    Ok(())
204}
205
206/// Check if attributes contain `#[forge::name]` or `#[name]`.
207fn has_forge_attr(attrs: &[Attribute], name: &str) -> bool {
208    attrs.iter().any(|attr| {
209        let path = attr.path();
210        path.is_ident(name)
211            || matches!(
212                (path.segments.first(), path.segments.get(1), path.segments.get(2)),
213                (Some(first), Some(second), None)
214                    if first.ident == "forge" && second.ident == name
215            )
216    })
217}
218
219/// Check if attributes contain `#[forge_enum]`, `#[enum_type]`, or `#[forge::enum_type]`.
220fn has_forge_enum_attr(attrs: &[Attribute]) -> bool {
221    attrs.iter().any(|attr| {
222        let path = attr.path();
223        path.is_ident("forge_enum")
224            || path.is_ident("enum_type")
225            || matches!(
226                (path.segments.first(), path.segments.get(1), path.segments.get(2)),
227                (Some(first), Some(second), None)
228                    if first.ident == "forge"
229                        && (second.ident == "enum_type" || second.ident == "forge_enum")
230            )
231    })
232}
233
234fn has_serde_derive(attrs: &[Attribute]) -> bool {
235    attrs.iter().any(|attr| {
236        if !attr.path().is_ident("derive") {
237            return false;
238        }
239        let tokens = attr.meta.to_token_stream().to_string();
240        tokens.contains("Serialize") || tokens.contains("Deserialize")
241    })
242}
243
244fn parse_dto_struct(item: &syn::ItemStruct) -> Option<TableDef> {
245    let struct_name = item.ident.to_string();
246    let mut table = TableDef::new(&struct_name, &struct_name);
247    table.is_dto = true;
248    table.doc = get_doc_comment(&item.attrs);
249
250    if let Fields::Named(fields) = &item.fields {
251        for field in &fields.named {
252            if let Some(field_name) = &field.ident {
253                table
254                    .fields
255                    .push(parse_field(field_name.to_string(), &field.ty, &field.attrs));
256            }
257        }
258    }
259
260    Some(table)
261}
262
263fn parse_model(item: &syn::ItemStruct) -> Option<TableDef> {
264    let struct_name = item.ident.to_string();
265    let table_name = get_table_name_from_attrs(&item.attrs).unwrap_or_else(|| {
266        let snake = to_snake_case(&struct_name);
267        pluralize(&snake)
268    });
269
270    let mut table = TableDef::new(&table_name, &struct_name);
271    table.doc = get_doc_comment(&item.attrs);
272
273    if let Fields::Named(fields) = &item.fields {
274        for field in &fields.named {
275            if let Some(field_name) = &field.ident {
276                table
277                    .fields
278                    .push(parse_field(field_name.to_string(), &field.ty, &field.attrs));
279            }
280        }
281    }
282
283    Some(table)
284}
285
286fn parse_field(name: String, ty: &syn::Type, attrs: &[Attribute]) -> FieldDef {
287    let rust_type = type_to_rust_type(ty);
288    let mut field = FieldDef::new(&name, rust_type);
289    field.column_name = to_snake_case(&name);
290    field.doc = get_doc_comment(attrs);
291    field
292}
293
294fn parse_enum(item: &syn::ItemEnum) -> Option<EnumDef> {
295    let enum_name = item.ident.to_string();
296    let mut enum_def = EnumDef::new(&enum_name);
297    enum_def.doc = get_doc_comment(&item.attrs);
298
299    for variant in &item.variants {
300        let variant_name = variant.ident.to_string();
301        let mut enum_variant = EnumVariant::new(&variant_name);
302        enum_variant.doc = get_doc_comment(&variant.attrs);
303
304        if let Some((_, Expr::Lit(lit))) = &variant.discriminant
305            && let Lit::Int(int_lit) = &lit.lit
306            && let Ok(value) = int_lit.base10_parse::<i32>()
307        {
308            enum_variant.int_value = Some(value);
309        }
310
311        enum_def.variants.push(enum_variant);
312    }
313
314    Some(enum_def)
315}
316
317fn parse_function(item: &syn::ItemFn) -> Option<FunctionDef> {
318    let kind = get_function_kind(&item.attrs)?;
319    let func_name = item.sig.ident.to_string();
320
321    let return_type = match &item.sig.output {
322        ReturnType::Default => RustType::Custom("()".to_string()),
323        ReturnType::Type(_, ty) => extract_result_type(ty),
324    };
325
326    let mut func = FunctionDef::new(&func_name, kind, return_type);
327    func.doc = get_doc_comment(&item.attrs);
328    func.is_async = item.sig.asyncness.is_some();
329
330    let mut is_first = true;
331    for arg in &item.sig.inputs {
332        if let FnArg::Typed(pat_type) = arg {
333            if is_first {
334                is_first = false;
335                if is_context_type(&pat_type.ty) {
336                    continue;
337                }
338            }
339
340            if let Pat::Ident(pat_ident) = &*pat_type.pat {
341                let arg_name = pat_ident.ident.to_string();
342                let arg_type = type_to_rust_type(&pat_type.ty);
343                func.args.push(FunctionArg::new(arg_name, arg_type));
344            }
345        }
346    }
347
348    Some(func)
349}
350
/// Known Forge context types. Only these are skipped as the first parameter.
/// User-defined types like `AppContext` or `MyContext` are not matched.
///
/// Entries are compared by exact name against the final path segment of a
/// parameter's type, so both `QueryContext` and `forge::QueryContext` match.
const KNOWN_CONTEXT_TYPES: &[&str] = &[
    "QueryContext",
    "MutationContext",
    "JobContext",
    "CronContext",
    "WorkflowContext",
    "DaemonContext",
    "WebhookContext",
    "McpToolContext",
];
363
364/// Check if a type is a known Forge context type.
365/// Walks `&T`/`&mut T` references and checks the final path segment.
366fn is_context_type(ty: &syn::Type) -> bool {
367    let mut inner = ty;
368    while let syn::Type::Reference(r) = inner {
369        inner = &r.elem;
370    }
371    let syn::Type::Path(type_path) = inner else {
372        return false;
373    };
374    let Some(last) = type_path.path.segments.last() else {
375        return false;
376    };
377    KNOWN_CONTEXT_TYPES.contains(&last.ident.to_string().as_str())
378}
379
380fn get_function_kind(attrs: &[Attribute]) -> Option<FunctionKind> {
381    for attr in attrs {
382        let path = attr.path();
383        let segments: Vec<_> = path.segments.iter().map(|s| s.ident.to_string()).collect();
384
385        let kind_str = match segments.as_slice() {
386            [forge, kind] if forge == "forge" => Some(kind.as_str()),
387            [kind] => Some(kind.as_str()),
388            _ => None,
389        };
390
391        if let Some(kind) = kind_str {
392            match kind {
393                "query" => return Some(FunctionKind::Query),
394                "mutation" => return Some(FunctionKind::Mutation),
395                "job" => return Some(FunctionKind::Job),
396                "cron" => return Some(FunctionKind::Cron),
397                "workflow" => return Some(FunctionKind::Workflow),
398                _ => {}
399            }
400        }
401    }
402    None
403}
404
405/// Extract the inner `T` from `Result<T, E>`.
406fn extract_result_type(ty: &syn::Type) -> RustType {
407    if let syn::Type::Path(type_path) = ty
408        && let Some(seg) = type_path.path.segments.last()
409        && seg.ident == "Result"
410        && let syn::PathArguments::AngleBracketed(args) = &seg.arguments
411        && let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first()
412    {
413        return type_to_rust_type(inner_ty);
414    }
415
416    type_to_rust_type(ty)
417}
418
419fn type_to_rust_type(ty: &syn::Type) -> RustType {
420    match ty {
421        syn::Type::Reference(r) => type_to_rust_type(&r.elem),
422        syn::Type::Path(tp) => path_to_rust_type(tp),
423        _ => RustType::Custom(quote::quote!(#ty).to_string()),
424    }
425}
426
427fn path_to_rust_type(tp: &syn::TypePath) -> RustType {
428    let Some(last) = tp.path.segments.last() else {
429        return RustType::Custom(quote::quote!(#tp).to_string());
430    };
431    let ident = last.ident.to_string();
432
433    match ident.as_str() {
434        "String" | "str" => RustType::String,
435        "i32" => RustType::I32,
436        "i64" => RustType::I64,
437        "f32" => RustType::F32,
438        "f64" => RustType::F64,
439        "bool" => RustType::Bool,
440        "Uuid" => RustType::Uuid,
441        "DateTime" => RustType::Instant,
442        "NaiveDate" => RustType::LocalDate,
443        "NaiveTime" => RustType::LocalTime,
444        "Value" => RustType::Json,
445        "Option" => {
446            let inner = first_generic_arg(last);
447            RustType::Option(Box::new(inner))
448        }
449        "Vec" => {
450            if is_vec_u8(last) {
451                return RustType::Bytes;
452            }
453            let inner = first_generic_arg(last);
454            RustType::Vec(Box::new(inner))
455        }
456        _ => RustType::Custom(ident),
457    }
458}
459
460fn is_vec_u8(seg: &syn::PathSegment) -> bool {
461    if let syn::PathArguments::AngleBracketed(args) = &seg.arguments
462        && let Some(syn::GenericArgument::Type(syn::Type::Path(tp))) = args.args.first()
463        && let Some(s) = tp.path.segments.last()
464    {
465        return s.ident == "u8" && s.arguments.is_empty();
466    }
467    false
468}
469
470fn first_generic_arg(seg: &syn::PathSegment) -> RustType {
471    if let syn::PathArguments::AngleBracketed(args) = &seg.arguments
472        && let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first()
473    {
474        return type_to_rust_type(inner_ty);
475    }
476    RustType::Custom(seg.ident.to_string())
477}
478
479/// Get `#[table(name = "...")]` value from attributes.
480fn get_table_name_from_attrs(attrs: &[Attribute]) -> Option<String> {
481    for attr in attrs {
482        if attr.path().is_ident("table")
483            && let Meta::List(list) = &attr.meta
484        {
485            let tokens = list.tokens.to_string();
486            if let Some(value) = extract_name_value(&tokens) {
487                return Some(value);
488            }
489        }
490    }
491    None
492}
493
494fn get_attribute_string_value(attr: &Attribute) -> Option<String> {
495    if let Meta::NameValue(nv) = &attr.meta
496        && let Expr::Lit(lit) = &nv.value
497        && let Lit::Str(s) = &lit.lit
498    {
499        return Some(s.value());
500    }
501    None
502}
503
504fn get_doc_comment(attrs: &[Attribute]) -> Option<String> {
505    let docs: Vec<String> = attrs
506        .iter()
507        .filter_map(|attr| {
508            if attr.path().is_ident("doc") {
509                get_attribute_string_value(attr)
510            } else {
511                None
512            }
513        })
514        .collect();
515
516    if docs.is_empty() {
517        None
518    } else {
519        Some(
520            docs.into_iter()
521                .map(|s| s.trim().to_string())
522                .collect::<Vec<_>>()
523                .join("\n"),
524        )
525    }
526}
527
/// Extract a quoted value from stringified `key = "value"` attribute arguments.
///
/// The input may hold several comma-separated pairs. The value of the pair
/// whose key is `name` is returned when present; otherwise the value of the
/// first pair is returned, which keeps single-pair inputs working whatever
/// their key. Each value ends at its own closing quote, so text from later
/// pairs never leaks into the result. Escape sequences inside the literal
/// are kept verbatim, not decoded.
///
/// Returns `None` when there is no `=`, or when the first value encountered
/// is not a double-quoted string.
fn extract_name_value(s: &str) -> Option<String> {
    // Byte offset of the first unescaped `"` in `body`, i.e. where the literal ends.
    fn closing_quote(body: &str) -> Option<usize> {
        let mut escaped = false;
        for (idx, ch) in body.char_indices() {
            if escaped {
                escaped = false;
            } else if ch == '\\' {
                escaped = true;
            } else if ch == '"' {
                return Some(idx);
            }
        }
        None
    }

    let mut first_value: Option<String> = None;
    let mut rest = s;

    while let Some((key, after_eq)) = rest.split_once('=') {
        let Some(body) = after_eq.trim_start().strip_prefix('"') else {
            break;
        };
        let Some(end) = closing_quote(body) else {
            break;
        };
        // `end` comes from `char_indices`, so it is always a char boundary.
        let (value, tail) = body.split_at(end);

        // For pairs after the first, the key text still carries the separating comma.
        if key.trim().trim_start_matches(',').trim() == "name" {
            return Some(value.to_string());
        }
        first_value.get_or_insert_with(|| value.to_string());

        // `tail` starts at the closing quote; continue scanning after it.
        rest = tail.strip_prefix('"').unwrap_or(tail);
    }

    first_value
}
537
/// Pluralize `s` by delegating to `forge_core::util::pluralize`.
///
/// Local wrapper so that callers in this module can use the short name.
fn pluralize(s: &str) -> String {
    forge_core::util::pluralize(s)
}
541
542#[cfg(test)]
543#[allow(
544    clippy::unwrap_used,
545    clippy::indexing_slicing,
546    clippy::panic,
547    clippy::todo
548)]
549mod tests {
550    use super::*;
551
552    #[test]
553    fn test_parse_model_source() {
554        let source = r#"
555            #[model]
556            struct User {
557                #[id]
558                id: Uuid,
559                email: String,
560                name: Option<String>,
561                #[indexed]
562                created_at: DateTime<Utc>,
563            }
564        "#;
565
566        let registry = SchemaRegistry::new();
567        parse_file(source, &registry).expect("model source should parse");
568
569        let table = registry
570            .get_table("users")
571            .expect("users table should be registered");
572        assert_eq!(table.struct_name, "User");
573        assert_eq!(table.fields.len(), 4);
574    }
575
576    #[test]
577    fn test_parse_enum_source() {
578        let source = r#"
579            #[forge_enum]
580            enum ProjectStatus {
581                Draft,
582                Active,
583                Completed,
584            }
585        "#;
586
587        let registry = SchemaRegistry::new();
588        parse_file(source, &registry).expect("enum source should parse");
589
590        let enum_def = registry
591            .get_enum("ProjectStatus")
592            .expect("ProjectStatus enum should be registered");
593        assert_eq!(enum_def.variants.len(), 3);
594    }
595
596    #[test]
597    fn test_to_snake_case() {
598        assert_eq!(to_snake_case("UserProfile"), "user_profile");
599        assert_eq!(to_snake_case("ID"), "id");
600        assert_eq!(to_snake_case("createdAt"), "created_at");
601    }
602
603    #[test]
604    fn test_pluralize() {
605        assert_eq!(pluralize("user"), "users");
606        assert_eq!(pluralize("category"), "categories");
607        assert_eq!(pluralize("box"), "boxes");
608        assert_eq!(pluralize("address"), "addresses");
609    }
610
611    #[test]
612    fn test_parse_query_function() {
613        let source = r#"
614            #[query]
615            async fn get_user(ctx: QueryContext, id: Uuid) -> Result<User> {
616                todo!()
617            }
618        "#;
619
620        let registry = SchemaRegistry::new();
621        parse_file(source, &registry).expect("query function should parse");
622
623        let func = registry
624            .get_function("get_user")
625            .expect("get_user function should be registered");
626        assert_eq!(func.name, "get_user");
627        assert_eq!(func.kind, FunctionKind::Query);
628        assert!(func.is_async);
629    }
630
631    #[test]
632    fn test_parse_mutation_function() {
633        let source = r#"
634            #[mutation]
635            async fn create_user(ctx: MutationContext, name: String, email: String) -> Result<User> {
636                todo!()
637            }
638        "#;
639
640        let registry = SchemaRegistry::new();
641        parse_file(source, &registry).expect("mutation function should parse");
642
643        let func = registry
644            .get_function("create_user")
645            .expect("create_user function should be registered");
646        assert_eq!(func.name, "create_user");
647        assert_eq!(func.kind, FunctionKind::Mutation);
648        assert_eq!(func.args.len(), 2);
649        assert_eq!(
650            func.args.first().expect("name arg should exist").name,
651            "name"
652        );
653        assert_eq!(
654            func.args.get(1).expect("email arg should exist").name,
655            "email"
656        );
657    }
658
659    #[test]
660    fn test_context_detection_structural() {
661        let source = r#"
662            #[query]
663            async fn test(ctx: forge::QueryContext, id: Uuid) -> Result<User> {
664                todo!()
665            }
666        "#;
667        let registry = SchemaRegistry::new();
668        parse_file(source, &registry).expect("context query should parse");
669        let func = registry
670            .get_function("test")
671            .expect("test function should be registered");
672        assert_eq!(func.args.len(), 1); // context skipped, only `id` remains
673        assert_eq!(func.args.first().expect("id arg should exist").name, "id");
674    }
675
676    #[test]
677    fn test_context_detection_does_not_match_other_types() {
678        let source = r#"
679            #[query]
680            async fn test(data: ContextManager, id: Uuid) -> Result<User> {
681                todo!()
682            }
683        "#;
684        let registry = SchemaRegistry::new();
685        parse_file(source, &registry).expect("non-context query should parse");
686        let func = registry
687            .get_function("test")
688            .expect("test function should be registered");
689        assert_eq!(func.args.len(), 2); // "ContextManager" is not a known context type
690    }
691
692    #[test]
693    fn test_user_defined_context_not_skipped() {
694        let source = r#"
695            #[query]
696            async fn test(ctx: AppContext, id: Uuid) -> Result<User> {
697                todo!()
698            }
699        "#;
700        let registry = SchemaRegistry::new();
701        parse_file(source, &registry).expect("user context should parse");
702        let func = registry
703            .get_function("test")
704            .expect("test function should be registered");
705        assert_eq!(func.args.len(), 2, "AppContext should not be skipped");
706    }
707
708    #[test]
709    fn test_nested_result_type_extraction() {
710        let source = r#"
711            #[query]
712            async fn nested(ctx: QueryContext) -> Result<Vec<Option<User>>> {
713                todo!()
714            }
715        "#;
716        let registry = SchemaRegistry::new();
717        parse_file(source, &registry).expect("nested result should parse");
718        let func = registry
719            .get_function("nested")
720            .expect("nested function should be registered");
721        match &func.return_type {
722            RustType::Vec(inner) => match inner.as_ref() {
723                RustType::Option(inner2) => match inner2.as_ref() {
724                    RustType::Custom(name) => assert_eq!(name, "User"),
725                    other => panic!("Expected Custom(User), got: {other:?}"),
726                },
727                other => panic!("Expected Option, got: {other:?}"),
728            },
729            other => panic!("Expected Vec, got: {other:?}"),
730        }
731    }
732
733    #[test]
734    fn test_naive_time_maps_to_local_time() {
735        let source = r#"
736            #[derive(Serialize, Deserialize)]
737            struct Schedule {
738                start_time: NaiveTime,
739            }
740        "#;
741        let registry = SchemaRegistry::new();
742        parse_file(source, &registry).expect("schedule DTO should parse");
743        let table = registry
744            .get_table("Schedule")
745            .expect("Schedule table should be registered");
746        assert_eq!(
747            table
748                .fields
749                .first()
750                .expect("start_time field should exist")
751                .rust_type,
752            RustType::LocalTime
753        );
754    }
755
756    #[test]
757    fn end_to_end_realistic_schema_pipeline() {
758        let source = r#"
759            use forge::prelude::*;
760
761            #[model]
762            struct User {
763                id: Uuid,
764                email: String,
765                name: Option<String>,
766                role: UserRole,
767                created_at: DateTime<Utc>,
768            }
769
770            #[model]
771            struct Post {
772                id: Uuid,
773                title: String,
774                body: String,
775                author_id: Uuid,
776                published: bool,
777                view_count: i64,
778                created_at: DateTime<Utc>,
779            }
780
781            #[forge_enum]
782            enum UserRole {
783                Admin,
784                Member,
785                Guest,
786            }
787
788            #[derive(Serialize, Deserialize)]
789            struct CreateUserArgs {
790                email: String,
791                name: Option<String>,
792                role: UserRole,
793            }
794
795            #[query]
796            async fn get_users(ctx: QueryContext) -> Result<Vec<User>> {
797                todo!()
798            }
799
800            #[query]
801            async fn get_user(ctx: QueryContext, id: Uuid) -> Result<User> {
802                todo!()
803            }
804
805            #[mutation]
806            async fn create_user(ctx: MutationContext, args: CreateUserArgs) -> Result<User> {
807                todo!()
808            }
809
810            #[mutation]
811            async fn delete_user(ctx: MutationContext, id: Uuid) -> Result<()> {
812                todo!()
813            }
814
815            #[job]
816            async fn send_welcome_email(ctx: JobContext, user_id: Uuid) -> Result<()> {
817                todo!()
818            }
819
820            #[workflow]
821            async fn onboarding(ctx: WorkflowContext, user_id: Uuid) -> Result<String> {
822                todo!()
823            }
824
825            #[cron]
826            async fn daily_cleanup(ctx: CronContext) -> Result<()> {
827                todo!()
828            }
829        "#;
830
831        let registry = SchemaRegistry::new();
832        parse_file(source, &registry).expect("realistic project should parse");
833
834        let users = registry.get_table("users").expect("users table");
835        assert_eq!(users.fields.len(), 5);
836        let posts = registry.get_table("posts").expect("posts table");
837        assert_eq!(posts.fields.len(), 7);
838
839        let role_enum = registry.get_enum("UserRole").expect("UserRole enum");
840        assert_eq!(role_enum.variants.len(), 3);
841
842        let args = registry
843            .get_table("CreateUserArgs")
844            .expect("CreateUserArgs DTO");
845        assert_eq!(args.fields.len(), 3);
846
847        let all_fns = registry.all_functions();
848        assert_eq!(all_fns.len(), 7);
849
850        let queries: Vec<_> = all_fns
851            .iter()
852            .filter(|f| f.kind == FunctionKind::Query)
853            .collect();
854        assert_eq!(queries.len(), 2);
855
856        let mutations: Vec<_> = all_fns
857            .iter()
858            .filter(|f| f.kind == FunctionKind::Mutation)
859            .collect();
860        assert_eq!(mutations.len(), 2);
861
862        let jobs: Vec<_> = all_fns
863            .iter()
864            .filter(|f| f.kind == FunctionKind::Job)
865            .collect();
866        assert_eq!(jobs.len(), 1);
867
868        let workflows: Vec<_> = all_fns
869            .iter()
870            .filter(|f| f.kind == FunctionKind::Workflow)
871            .collect();
872        assert_eq!(workflows.len(), 1);
873
874        let crons: Vec<_> = all_fns
875            .iter()
876            .filter(|f| f.kind == FunctionKind::Cron)
877            .collect();
878        assert_eq!(crons.len(), 1);
879
880        let get_users = registry.get_function("get_users").expect("get_users");
881        assert!(
882            get_users.args.is_empty(),
883            "get_users has no user args (context stripped)"
884        );
885
886        let create_user = registry.get_function("create_user").expect("create_user");
887        assert_eq!(create_user.args.len(), 1, "create_user has one user arg");
888        assert_eq!(create_user.args.first().expect("arg").name, "args");
889
890        let send_email = registry
891            .get_function("send_welcome_email")
892            .expect("send_welcome_email");
893        assert_eq!(send_email.kind, FunctionKind::Job);
894        assert_eq!(send_email.args.len(), 1);
895    }
896
897    #[test]
898    fn binding_set_from_mixed_schema() {
899        use crate::binding::BindingSet;
900
901        let source = r#"
902            #[query]
903            async fn list_items(ctx: QueryContext) -> Result<Vec<Item>> { todo!() }
904
905            #[mutation]
906            async fn add_item(ctx: MutationContext, name: String) -> Result<Item> { todo!() }
907
908            #[job]
909            async fn process_item(ctx: JobContext, id: Uuid) -> Result<()> { todo!() }
910
911            #[cron]
912            async fn cleanup(ctx: CronContext) -> Result<()> { todo!() }
913
914            #[workflow]
915            async fn item_pipeline(ctx: WorkflowContext, id: Uuid) -> Result<String> { todo!() }
916        "#;
917
918        let registry = SchemaRegistry::new();
919        parse_file(source, &registry).expect("parse");
920
921        let bindings = BindingSet::from_registry(&registry);
922        assert_eq!(bindings.queries.len(), 1);
923        assert_eq!(bindings.mutations.len(), 1);
924        assert_eq!(bindings.jobs.len(), 1);
925        assert_eq!(bindings.workflows.len(), 1);
926        // crons are not client-callable and must not appear in bindings
927    }
928
929    #[test]
930    fn parse_function_with_multiple_args() {
931        let source = r#"
932            #[mutation]
933            async fn update_user(ctx: MutationContext, id: Uuid, name: String, email: Option<String>) -> Result<User> {
934                todo!()
935            }
936        "#;
937
938        let registry = SchemaRegistry::new();
939        parse_file(source, &registry).expect("parse");
940
941        let func = registry.get_function("update_user").expect("update_user");
942        assert_eq!(func.args.len(), 3);
943        assert_eq!(func.args.first().expect("id").name, "id");
944        assert_eq!(func.args.get(1).expect("name").name, "name");
945        assert_eq!(func.args.get(2).expect("email").name, "email");
946    }
947
948    #[test]
949    fn parse_function_with_vec_return() {
950        let source = r#"
951            #[query]
952            async fn list_posts(ctx: QueryContext) -> Result<Vec<Post>> {
953                todo!()
954            }
955        "#;
956
957        let registry = SchemaRegistry::new();
958        parse_file(source, &registry).expect("parse");
959
960        let func = registry.get_function("list_posts").expect("list_posts");
961        match &func.return_type {
962            RustType::Vec(inner) => match inner.as_ref() {
963                RustType::Custom(name) => assert_eq!(name, "Post"),
964                other => panic!("Expected Custom(Post), got: {other:?}"),
965            },
966            other => panic!("Expected Vec, got: {other:?}"),
967        }
968    }
969
970    fn parse_type(s: &str) -> RustType {
971        let ty: syn::Type = syn::parse_str(s).expect("valid type");
972        type_to_rust_type(&ty)
973    }
974
975    #[test]
976    fn type_to_rust_type_primitives() {
977        assert_eq!(parse_type("String"), RustType::String);
978        assert_eq!(parse_type("&str"), RustType::String);
979        assert_eq!(parse_type("i32"), RustType::I32);
980        assert_eq!(parse_type("i64"), RustType::I64);
981        assert_eq!(parse_type("f32"), RustType::F32);
982        assert_eq!(parse_type("f64"), RustType::F64);
983        assert_eq!(parse_type("bool"), RustType::Bool);
984    }
985
986    #[test]
987    fn type_to_rust_type_qualified_paths() {
988        assert_eq!(parse_type("Uuid"), RustType::Uuid);
989        assert_eq!(parse_type("uuid::Uuid"), RustType::Uuid);
990        assert_eq!(parse_type("DateTime<Utc>"), RustType::Instant);
991        assert_eq!(parse_type("chrono::DateTime<Utc>"), RustType::Instant);
992        assert_eq!(
993            parse_type("chrono::DateTime<chrono::Utc>"),
994            RustType::Instant
995        );
996        assert_eq!(parse_type("NaiveDate"), RustType::LocalDate);
997        assert_eq!(parse_type("chrono::NaiveDate"), RustType::LocalDate);
998        assert_eq!(parse_type("NaiveTime"), RustType::LocalTime);
999        assert_eq!(parse_type("chrono::NaiveTime"), RustType::LocalTime);
1000        assert_eq!(parse_type("serde_json::Value"), RustType::Json);
1001        assert_eq!(parse_type("Value"), RustType::Json);
1002    }
1003
1004    #[test]
1005    fn type_to_rust_type_containers() {
1006        assert_eq!(parse_type("Vec<u8>"), RustType::Bytes);
1007        assert_eq!(
1008            parse_type("Vec<String>"),
1009            RustType::Vec(Box::new(RustType::String))
1010        );
1011        assert_eq!(
1012            parse_type("Option<i32>"),
1013            RustType::Option(Box::new(RustType::I32))
1014        );
1015        assert_eq!(
1016            parse_type("Option<Vec<String>>"),
1017            RustType::Option(Box::new(RustType::Vec(Box::new(RustType::String))))
1018        );
1019    }
1020
1021    #[test]
1022    fn type_to_rust_type_std_qualified_vec() {
1023        assert_eq!(
1024            parse_type("std::vec::Vec<i32>"),
1025            RustType::Vec(Box::new(RustType::I32))
1026        );
1027        assert_eq!(
1028            parse_type("std::option::Option<String>"),
1029            RustType::Option(Box::new(RustType::String))
1030        );
1031    }
1032
1033    #[test]
1034    fn type_to_rust_type_custom() {
1035        assert_eq!(
1036            parse_type("MyStruct"),
1037            RustType::Custom("MyStruct".to_string())
1038        );
1039    }
1040}