Skip to main content

pyro_macro/
struct_doc.rs

//! File-level schema builder.
//!
//! Parses the structs annotated with `config` or `magma` in a capability
//! source file to build a registry of `name → PyroSchema`, then resolves field
//! types against that registry so that nested user-defined structs produce
//! correct `Group(fields)` instead of opaque empty groups.
//!
//! # Usage
//!
//! ```ignore
//! let builder = SchemaBuilder::from_file(&syn_file);
//! let pyro_type = builder.resolve_type(&syn_ty);
//! let schema    = builder.schema_for("MyStruct").unwrap();
//! ```
15
16use std::borrow::Cow;
17use std::collections::HashMap;
18
19use crate::utils::has_attr;
20use pyro_spec::{PrimitiveDataType, PyroField, PyroSchema, PyroType};
21use syn::{Attribute, Expr, Fields, Lit, Meta};
22
23// =============================================================================
24// SchemaBuilder
25// =============================================================================
26
27/// A file-level registry of struct schemas.
28///
29/// Built from a parsed `syn::File`, it maps every struct name to its fields
30/// (name, syn::Type, doc-string). The registry is then used to resolve
31/// `syn::Type` → `PyroType` with full knowledge of sibling structs.
32pub struct SchemaBuilder {
33    /// Struct name → list of (field_name, field_type, field_doc).
34    structs: HashMap<String, StructEntry>,
35}
36
37struct StructEntry {
38    doc: Option<String>,
39    fields: Vec<FieldEntry>,
40}
41
42struct FieldEntry {
43    name: String,
44    ty: syn::Type,
45    doc: Option<String>,
46}
47
48impl SchemaBuilder {
49    // -----------------------------------------------------------------
50    // Construction
51    // -----------------------------------------------------------------
52
53    /// Scan every `Item::Struct` in the file and register it.
54    pub fn from_file(file: &syn::File) -> Self {
55        let mut structs = HashMap::new();
56        for item in &file.items {
57            if let syn::Item::Struct(s) = item {
58                if !(has_attr(&s.attrs, "config") || has_attr(&s.attrs, "magma")) {
59                    continue;
60                }
61                let name = s.ident.to_string();
62                let doc = extract_doc_string(&s.attrs);
63                let fields = Self::collect_fields(&s.fields);
64                structs.insert(name, StructEntry { doc, fields });
65            }
66        }
67        Self { structs }
68    }
69
70    fn collect_fields(fields: &Fields) -> Vec<FieldEntry> {
71        match fields {
72            Fields::Named(named) => named
73                .named
74                .iter()
75                .map(|f| FieldEntry {
76                    name: f.ident.as_ref().unwrap().to_string(),
77                    ty: f.ty.clone(),
78                    doc: extract_doc_string(&f.attrs),
79                })
80                .collect(),
81            Fields::Unnamed(unnamed) => unnamed
82                .unnamed
83                .iter()
84                .enumerate()
85                .map(|(i, f)| FieldEntry {
86                    name: i.to_string(),
87                    ty: f.ty.clone(),
88                    doc: extract_doc_string(&f.attrs),
89                })
90                .collect(),
91            Fields::Unit => vec![],
92        }
93    }
94
95    // -----------------------------------------------------------------
96    // Resolution
97    // -----------------------------------------------------------------
98
99    /// Build a `PyroSchema` for a struct that is in the registry.
100    pub fn schema_for(&self, struct_name: &str) -> Option<PyroSchema<'static>> {
101        let entry = self.structs.get(struct_name)?;
102        let mut visited = Vec::new();
103        let fields = self.resolve_fields_inner(&entry.fields, &mut visited);
104        let mut schema = PyroSchema::new(fields);
105        if let Some(d) = &entry.doc {
106            schema = schema.add_docstring(Cow::Owned(d.clone()));
107        }
108        Some(schema)
109    }
110
111    /// Resolve a `syn::Type` to a `PyroType`, expanding known struct names
112    /// into full `Group(fields)`.
113    pub fn resolve_type(&self, ty: &syn::Type) -> PyroType<'static> {
114        self.resolve_type_inner(ty, &mut Vec::new())
115    }
116
117    /// Check whether a `syn::Type` is `Option<_>`.
118    pub fn is_option(ty: &syn::Type) -> bool {
119        is_option_type(ty)
120    }
121
122    // -----------------------------------------------------------------
123    // Internals
124    // -----------------------------------------------------------------
125
126    fn resolve_fields_inner(
127        &self,
128        fields: &[FieldEntry],
129        visited: &mut Vec<String>,
130    ) -> Vec<PyroField<'static>> {
131        fields
132            .iter()
133            .map(|f| {
134                let data_type = self.resolve_type_inner(&f.ty, visited);
135                let nullable = is_option_type(&f.ty);
136                let mut field = PyroField::new(Cow::Owned(f.name.clone()), data_type, nullable);
137                if let Some(doc) = &f.doc {
138                    field = field.add_docstring(Cow::Owned(doc.clone()));
139                }
140                field
141            })
142            .collect()
143    }
144
145    /// Core resolver. `visited` tracks struct names on the current path to
146    /// break infinite recursion from cyclic types (which shouldn't occur in
147    /// practice but we guard against it).
148    fn resolve_type_inner(&self, ty: &syn::Type, visited: &mut Vec<String>) -> PyroType<'static> {
149        match ty {
150            syn::Type::Path(type_path) => {
151                let segment = match type_path.path.segments.last() {
152                    Some(s) => s,
153                    None => return PyroType::Null,
154                };
155                let ident_str = segment.ident.to_string();
156
157                match ident_str.as_str() {
158                    // --- Primitives ---
159                    "bool" => PyroType::PrimitiveScalar(PrimitiveDataType::Bool),
160                    "u8" => PyroType::PrimitiveScalar(PrimitiveDataType::U8),
161                    "u16" => PyroType::PrimitiveScalar(PrimitiveDataType::U16),
162                    "u32" => PyroType::PrimitiveScalar(PrimitiveDataType::U32),
163                    "u64" => PyroType::PrimitiveScalar(PrimitiveDataType::U64),
164                    "i8" => PyroType::PrimitiveScalar(PrimitiveDataType::I8),
165                    "i16" => PyroType::PrimitiveScalar(PrimitiveDataType::I16),
166                    "i32" => PyroType::PrimitiveScalar(PrimitiveDataType::I32),
167                    "i64" => PyroType::PrimitiveScalar(PrimitiveDataType::I64),
168                    "f16" => PyroType::PrimitiveScalar(PrimitiveDataType::F16),
169                    "f32" => PyroType::PrimitiveScalar(PrimitiveDataType::F32),
170                    "f64" => PyroType::PrimitiveScalar(PrimitiveDataType::F64),
171
172                    // --- Strings ---
173                    "String" | "str" => PyroType::Str,
174
175                    // --- Bytes ---
176                    "Bytes" => PyroType::PrimitiveList(PrimitiveDataType::U8),
177
178                    // --- Option<T> ---
179                    "Option" => {
180                        if let Some(inner) = extract_single_generic_arg(segment) {
181                            self.resolve_type_inner(inner, visited)
182                        } else {
183                            PyroType::Null
184                        }
185                    }
186
187                    // --- Vec<T> ---
188                    "Vec" => {
189                        if let Some(inner) = extract_single_generic_arg(segment) {
190                            let inner_pyro = self.resolve_type_inner(inner, visited);
191                            match &inner_pyro {
192                                PyroType::PrimitiveScalar(p) => PyroType::PrimitiveList(*p),
193                                _ => PyroType::List(Box::new(inner_pyro), false),
194                            }
195                        } else {
196                            PyroType::Null
197                        }
198                    }
199
200                    // --- HashMap / BTreeMap ---
201                    "HashMap" | "BTreeMap" => {
202                        if let Some((k, v)) = extract_two_generic_args(segment) {
203                            PyroType::Map {
204                                key: Box::new(self.resolve_type_inner(k, visited)),
205                                value: Box::new(self.resolve_type_inner(v, visited)),
206                            }
207                        } else {
208                            PyroType::Null
209                        }
210                    }
211
212                    // --- Result<T, E> ---
213                    "Result" => {
214                        if let Some((ok, _err)) = extract_two_generic_args(segment) {
215                            self.resolve_type_inner(ok, visited)
216                        } else {
217                            PyroType::Null
218                        }
219                    }
220
221                    // --- DateTime ---
222                    "DateTime" => PyroType::Timestamp,
223
224                    // --- User-defined struct (look up in registry) ---
225                    other => {
226                        if visited.contains(&other.to_string()) {
227                            // Cycle guard — return empty group
228                            return PyroType::Group(Cow::Owned(vec![]));
229                        }
230                        if let Some(entry) = self.structs.get(other) {
231                            visited.push(other.to_string());
232                            let fields = self.resolve_fields_inner(&entry.fields, visited);
233                            visited.pop();
234                            PyroType::Group(Cow::Owned(fields))
235                        } else {
236                            // Unknown struct — opaque group
237                            PyroType::Group(Cow::Owned(vec![]))
238                        }
239                    }
240                }
241            }
242            syn::Type::Reference(r) => self.resolve_type_inner(&r.elem, visited),
243            syn::Type::Tuple(t) if t.elems.is_empty() => PyroType::Null,
244            _ => PyroType::Null,
245        }
246    }
247}
248
249// =============================================================================
250// Helpers (private)
251// =============================================================================
252
253fn is_option_type(ty: &syn::Type) -> bool {
254    if let syn::Type::Path(type_path) = ty {
255        if let Some(seg) = type_path.path.segments.last() {
256            return seg.ident == "Option";
257        }
258    }
259    false
260}
261
262fn extract_single_generic_arg(segment: &syn::PathSegment) -> Option<&syn::Type> {
263    if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
264        if let Some(syn::GenericArgument::Type(ty)) = args.args.first() {
265            return Some(ty);
266        }
267    }
268    None
269}
270
271fn extract_two_generic_args(segment: &syn::PathSegment) -> Option<(&syn::Type, &syn::Type)> {
272    if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
273        let mut iter = args.args.iter();
274        if let (Some(syn::GenericArgument::Type(a)), Some(syn::GenericArgument::Type(b))) =
275            (iter.next(), iter.next())
276        {
277            return Some((a, b));
278        }
279    }
280    None
281}
282
283fn extract_doc_string(attrs: &[Attribute]) -> Option<String> {
284    let mut lines = Vec::new();
285    for attr in attrs {
286        if attr.path().is_ident("doc") {
287            if let Meta::NameValue(nv) = &attr.meta {
288                if let Expr::Lit(expr_lit) = &nv.value {
289                    if let Lit::Str(lit_str) = &expr_lit.lit {
290                        lines.push(lit_str.value().trim().to_string());
291                    }
292                }
293            }
294        }
295    }
296    if lines.is_empty() {
297        None
298    } else {
299        Some(lines.join("\n"))
300    }
301}
302
303// =============================================================================
304// Tests
305// =============================================================================
306
#[cfg(test)]
mod tests {
    use super::*;
    use quote::quote;
    use syn::parse2;

    /// Parse `tokens` as a whole file and build the registry from it.
    fn builder_from_tokens(tokens: proc_macro2::TokenStream) -> SchemaBuilder {
        let file: syn::File = syn::parse2(tokens).unwrap();
        SchemaBuilder::from_file(&file)
    }

    /// Parse `tokens` as a single type.
    fn ty_from_tokens(tokens: proc_macro2::TokenStream) -> syn::Type {
        parse2(tokens).unwrap()
    }

    /// Shorthand for a primitive scalar `PyroType`.
    fn scalar(primitive: PrimitiveDataType) -> PyroType<'static> {
        PyroType::PrimitiveScalar(primitive)
    }

    // --- Primitive resolution ---

    #[test]
    fn test_resolve_primitives() {
        let builder = builder_from_tokens(quote! {});

        assert_eq!(
            builder.resolve_type(&ty_from_tokens(quote!(u32))),
            scalar(PrimitiveDataType::U32)
        );
        assert_eq!(
            builder.resolve_type(&ty_from_tokens(quote!(String))),
            PyroType::Str
        );
        assert_eq!(
            builder.resolve_type(&ty_from_tokens(quote!(f64))),
            scalar(PrimitiveDataType::F64)
        );
    }

    // --- Vec / Option ---

    #[test]
    fn test_resolve_vec_and_option() {
        let builder = builder_from_tokens(quote! {});

        let bytes = ty_from_tokens(quote!(Vec<u8>));
        assert_eq!(
            builder.resolve_type(&bytes),
            PyroType::PrimitiveList(PrimitiveDataType::U8)
        );

        let strings = ty_from_tokens(quote!(Vec<String>));
        assert_eq!(
            builder.resolve_type(&strings),
            PyroType::List(Box::new(PyroType::Str), false)
        );

        let optional = ty_from_tokens(quote!(Option<i32>));
        assert_eq!(
            builder.resolve_type(&optional),
            scalar(PrimitiveDataType::I32)
        );
        assert!(SchemaBuilder::is_option(&optional));
    }

    // --- Nested struct resolution (the whole point) ---

    #[test]
    fn test_resolve_nested_struct() {
        let builder = builder_from_tokens(quote! {
            #[config]
            struct Foo {
                woobie: String,
            }

            #[config]
            struct Bar {
                doobie: Foo,
            }
        });

        // Foo on its own is a Group holding a single Str field.
        let expected_foo = PyroType::Group(Cow::Owned(vec![PyroField::new(
            Cow::Borrowed("woobie"),
            PyroType::Str,
            false,
        )]));
        assert_eq!(
            builder.resolve_type(&ty_from_tokens(quote!(Foo))),
            expected_foo
        );

        // Bar has one field whose type is the expanded Foo group.
        let schema = builder.schema_for("Bar").unwrap();
        assert_eq!(schema.fields.len(), 1);

        let doobie = &schema.fields()[0];
        assert_eq!(doobie.name(), "doobie");

        let inner_fields = match &doobie.data_type {
            PyroType::Group(inner_fields) => inner_fields,
            other => panic!("expected Group, got {:?}", other),
        };
        assert_eq!(inner_fields.len(), 1);
        assert_eq!(inner_fields[0].name(), "woobie");
        assert_eq!(inner_fields[0].data_type, PyroType::Str);
    }

    // --- Deeply nested ---

    #[test]
    fn test_resolve_deeply_nested() {
        let builder = builder_from_tokens(quote! {
            #[config]
            struct A {
                x: i32,
            }
            #[config]
            struct B {
                a: A,
                name: String,
            }
            #[config]
            struct C {
                b: B,
                flag: bool,
            }
        });

        let schema_c = builder.schema_for("C").unwrap();
        assert_eq!(schema_c.fields.len(), 2);

        // `b` must expand to Group([ a: Group([x: I32]), name: Str ]).
        let b_field = &schema_c.fields()[0];
        assert_eq!(b_field.name(), "b");
        let b_fields = match &b_field.data_type {
            PyroType::Group(b_fields) => b_fields,
            other => panic!("expected Group for B, got {:?}", other),
        };
        assert_eq!(b_fields.len(), 2);

        assert_eq!(b_fields[0].name(), "a");
        let a_fields = match &b_fields[0].data_type {
            PyroType::Group(a_fields) => a_fields,
            other => panic!("expected Group for A, got {:?}", other),
        };
        assert_eq!(a_fields.len(), 1);
        assert_eq!(a_fields[0].name(), "x");
        assert_eq!(a_fields[0].data_type, scalar(PrimitiveDataType::I32));

        assert_eq!(b_fields[1].name(), "name");
        assert_eq!(b_fields[1].data_type, PyroType::Str);

        // `flag` is a plain bool scalar.
        let flag_field = &schema_c.fields()[1];
        assert_eq!(flag_field.name(), "flag");
        assert_eq!(flag_field.data_type, scalar(PrimitiveDataType::Bool));
    }

    // --- Vec of struct ---

    #[test]
    fn test_resolve_vec_of_struct() {
        let builder = builder_from_tokens(quote! {
            #[config]
            struct Item {
                value: f32,
            }
            #[config]
            struct Container {
                items: Vec<Item>,
            }
        });

        let schema = builder.schema_for("Container").unwrap();
        let items_field = &schema.fields()[0];
        assert_eq!(items_field.name(), "items");

        let (inner, nullable) = match &items_field.data_type {
            PyroType::List(inner, nullable) => (inner, nullable),
            other => panic!("expected List, got {:?}", other),
        };
        assert!(!nullable);

        let fields = match inner.as_ref() {
            PyroType::Group(fields) => fields,
            other => panic!("expected Group inside List, got {:?}", other),
        };
        assert_eq!(fields.len(), 1);
        assert_eq!(fields[0].name(), "value");
        assert_eq!(fields[0].data_type, scalar(PrimitiveDataType::F32));
    }

    // --- Doc strings preserved ---

    #[test]
    fn test_doc_strings_preserved() {
        let builder = builder_from_tokens(quote! {
            /// This is Foo
            #[config]
            struct Foo {
                /// The id
                id: u32,
                name: String,
            }
        });

        let schema = builder.schema_for("Foo").unwrap();
        assert_eq!(schema.documentation.as_deref(), Some("This is Foo"));
        assert_eq!(schema.fields.len(), 2);

        let id_field = &schema.fields()[0];
        let name_field = &schema.fields()[1];
        assert_eq!(id_field.documentation.as_deref(), Some("The id"));
        assert!(name_field.documentation.is_none());
    }

    // --- Unknown struct falls back to empty group ---

    #[test]
    fn test_unknown_struct_empty_group() {
        let builder = builder_from_tokens(quote! {
            #[config]
            struct Wrapper {
                inner: SomeExternalThing,
            }
        });

        let schema = builder.schema_for("Wrapper").unwrap();
        let inner = &schema.fields()[0];
        assert_eq!(inner.data_type, PyroType::Group(Cow::Owned(vec![])));
    }

    // --- Cycle guard ---

    #[test]
    fn test_cycle_guard() {
        // Contrived: struct A has a field of type A. Must not overflow the stack.
        let builder = builder_from_tokens(quote! {
            #[config]
            struct A {
                next: A,
            }
        });

        let schema = builder.schema_for("A").unwrap();
        assert_eq!(schema.fields().len(), 1);
        let next_field = &schema.fields()[0];
        assert_eq!(next_field.name(), "next");

        // `next` expands once into A's own field list; the self-reference
        // inside that expansion is where the guard substitutes Group([]).
        let inner_fields = match &next_field.data_type {
            PyroType::Group(inner_fields) => inner_fields,
            other => panic!("expected Group for A's next field, got {:?}", other),
        };
        assert_eq!(inner_fields.len(), 1);
        assert_eq!(inner_fields[0].name(), "next");
        assert_eq!(
            inner_fields[0].data_type,
            PyroType::Group(Cow::Owned(vec![]))
        );
    }

    // --- Map of struct ---

    #[test]
    fn test_resolve_map_of_struct() {
        let builder = builder_from_tokens(quote! {
            #[config]
            struct Config {
                key: String,
            }
            #[config]
            struct Registry {
                entries: HashMap<String, Config>,
            }
        });

        let schema = builder.schema_for("Registry").unwrap();
        let entries = &schema.fields()[0];

        let (key, value) = match &entries.data_type {
            PyroType::Map { key, value } => (key, value),
            other => panic!("expected Map, got {:?}", other),
        };
        assert_eq!(key.as_ref(), &PyroType::Str);

        let fields = match value.as_ref() {
            PyroType::Group(fields) => fields,
            other => panic!("expected Group for Config, got {:?}", other),
        };
        assert_eq!(fields.len(), 1);
        assert_eq!(fields[0].name(), "key");
    }
}