Skip to main content

pyro_macro/
struct_doc.rs

1//! File-level schema builder.
2//!
//! Parses every `#[config]` / `#[magma]` struct in a capability source file
//! into a name-keyed registry of struct definitions, then resolves field types
//! against that registry so that nested user-defined structs produce correct
//! `Group(fields)` instead of opaque empty groups.
7//!
8//! # Usage
9//!
10//! ```ignore
11//! let builder = SchemaBuilder::from_file(&syn_file);
12//! let pyro_type = builder.resolve_type(&syn_ty);
13//! let schema    = builder.schema_for("MyStruct").unwrap();
14//! ```
15
16use std::borrow::Cow;
17use std::collections::HashMap;
18
19use crate::utils::has_attr;
20use pyro_spec::{PrimitiveDataType, PyroField, PyroSchema, PyroType};
21use syn::{Attribute, Expr, Fields, Lit, Meta};
22
23// =============================================================================
24// SchemaBuilder
25// =============================================================================
26
27/// A file-level registry of struct schemas.
28///
29/// Built from a parsed `syn::File`, it maps every struct name to its fields
30/// (name, syn::Type, doc-string). The registry is then used to resolve
31/// `syn::Type` → `PyroType` with full knowledge of sibling structs.
32pub struct SchemaBuilder {
33    /// Struct name → list of (field_name, field_type, field_doc).
34    structs: HashMap<String, StructEntry>,
35    /// Foreign struct name → PyroSchema.
36    foreign_structs: HashMap<String, PyroSchema<'static>>,
37}
38
39struct StructEntry {
40    doc: Option<String>,
41    fields: Vec<FieldEntry>,
42}
43
44struct FieldEntry {
45    name: String,
46    ty: syn::Type,
47    doc: Option<String>,
48}
49
50impl SchemaBuilder {
51    // -----------------------------------------------------------------
52    // Construction
53    // -----------------------------------------------------------------
54
55    /// Scan every `Item::Struct` in the file and register it.
56    pub fn from_file(file: &syn::File) -> Self {
57        let mut structs = HashMap::new();
58        for item in &file.items {
59            if let syn::Item::Struct(s) = item {
60                if !(has_attr(&s.attrs, "config") || has_attr(&s.attrs, "magma")) {
61                    continue;
62                }
63                let name = s.ident.to_string();
64                let doc = extract_doc_string(&s.attrs);
65                let fields = Self::collect_fields(&s.fields);
66                structs.insert(name, StructEntry { doc, fields });
67            }
68        }
69        Self {
70            structs,
71            foreign_structs: HashMap::new(),
72        }
73    }
74
75    pub fn with_foreign_specs(
76        mut self,
77        dep_interfaces: &[pyro_spec::InterfaceSpec<'static>],
78    ) -> Self {
79        for spec in dep_interfaces {
80            for (struct_name, schema) in &spec.structs {
81                self.foreign_structs.insert(struct_name.to_string(), schema.clone());
82            }
83        }
84        self
85    }
86
87    pub fn struct_names(&self) -> Vec<String> {
88        self.structs.keys().cloned().collect()
89    }
90
91    fn collect_fields(fields: &Fields) -> Vec<FieldEntry> {
92        match fields {
93            Fields::Named(named) => named
94                .named
95                .iter()
96                .map(|f| FieldEntry {
97                    name: f.ident.as_ref().unwrap().to_string(),
98                    ty: f.ty.clone(),
99                    doc: extract_doc_string(&f.attrs),
100                })
101                .collect(),
102            Fields::Unnamed(unnamed) => unnamed
103                .unnamed
104                .iter()
105                .enumerate()
106                .map(|(i, f)| FieldEntry {
107                    name: i.to_string(),
108                    ty: f.ty.clone(),
109                    doc: extract_doc_string(&f.attrs),
110                })
111                .collect(),
112            Fields::Unit => vec![],
113        }
114    }
115
116    // -----------------------------------------------------------------
117    // Resolution
118    // -----------------------------------------------------------------
119
120    /// Build a `PyroSchema` for a struct that is in the registry.
121    pub fn schema_for(&self, struct_name: &str) -> Option<PyroSchema<'static>> {
122        let entry = self.structs.get(struct_name)?;
123        let mut visited = Vec::new();
124        let fields = self.resolve_fields_inner(&entry.fields, &mut visited);
125        let mut schema = PyroSchema::new(fields);
126        if let Some(d) = &entry.doc {
127            schema = schema.add_docstring(Cow::Owned(d.clone()));
128        }
129        Some(schema)
130    }
131
132    /// Resolve a `syn::Type` to a `PyroType`, expanding known struct names
133    /// into full `Group(fields)`.
134    pub fn resolve_type(&self, ty: &syn::Type) -> PyroType<'static> {
135        self.resolve_type_inner(ty, &mut Vec::new())
136    }
137
138    /// Check whether a `syn::Type` is `Option<_>`.
139    pub fn is_option(ty: &syn::Type) -> bool {
140        is_option_type(ty)
141    }
142
143    // -----------------------------------------------------------------
144    // Internals
145    // -----------------------------------------------------------------
146
147    fn resolve_fields_inner(
148        &self,
149        fields: &[FieldEntry],
150        visited: &mut Vec<String>,
151    ) -> Vec<PyroField<'static>> {
152        fields
153            .iter()
154            .map(|f| {
155                let data_type = self.resolve_type_inner(&f.ty, visited);
156                let nullable = is_option_type(&f.ty);
157                let mut field = PyroField::new(Cow::Owned(f.name.clone()), data_type, nullable);
158                if let Some(doc) = &f.doc {
159                    field = field.add_docstring(Cow::Owned(doc.clone()));
160                }
161                field
162            })
163            .collect()
164    }
165
166    /// Core resolver. `visited` tracks struct names on the current path to
167    /// break infinite recursion from cyclic types (which shouldn't occur in
168    /// practice but we guard against it).
169    fn resolve_type_inner(&self, ty: &syn::Type, visited: &mut Vec<String>) -> PyroType<'static> {
170        match ty {
171            syn::Type::Path(type_path) => {
172                let segment = match type_path.path.segments.last() {
173                    Some(s) => s,
174                    None => return PyroType::Null,
175                };
176                let ident_str = segment.ident.to_string();
177
178                match ident_str.as_str() {
179                    // --- Primitives ---
180                    "bool" => PyroType::PrimitiveScalar(PrimitiveDataType::Bool),
181                    "u8" => PyroType::PrimitiveScalar(PrimitiveDataType::U8),
182                    "u16" => PyroType::PrimitiveScalar(PrimitiveDataType::U16),
183                    "u32" => PyroType::PrimitiveScalar(PrimitiveDataType::U32),
184                    "u64" => PyroType::PrimitiveScalar(PrimitiveDataType::U64),
185                    "i8" => PyroType::PrimitiveScalar(PrimitiveDataType::I8),
186                    "i16" => PyroType::PrimitiveScalar(PrimitiveDataType::I16),
187                    "i32" => PyroType::PrimitiveScalar(PrimitiveDataType::I32),
188                    "i64" => PyroType::PrimitiveScalar(PrimitiveDataType::I64),
189                    "f16" => PyroType::PrimitiveScalar(PrimitiveDataType::F16),
190                    "f32" => PyroType::PrimitiveScalar(PrimitiveDataType::F32),
191                    "f64" => PyroType::PrimitiveScalar(PrimitiveDataType::F64),
192
193                    // --- Strings ---
194                    "String" | "str" => PyroType::Str,
195
196                    // --- Bytes ---
197                    "Bytes" => PyroType::PrimitiveList(PrimitiveDataType::U8),
198
199                    // --- Option<T> ---
200                    "Option" => {
201                        if let Some(inner) = extract_single_generic_arg(segment) {
202                            self.resolve_type_inner(inner, visited)
203                        } else {
204                            PyroType::Null
205                        }
206                    }
207
208                    // --- Vec<T> ---
209                    "Vec" => {
210                        if let Some(inner) = extract_single_generic_arg(segment) {
211                            let inner_pyro = self.resolve_type_inner(inner, visited);
212                            match &inner_pyro {
213                                PyroType::PrimitiveScalar(p) => PyroType::PrimitiveList(*p),
214                                _ => PyroType::List(Box::new(inner_pyro), false),
215                            }
216                        } else {
217                            PyroType::Null
218                        }
219                    }
220
221                    // --- HashMap / BTreeMap ---
222                    "HashMap" | "BTreeMap" => {
223                        if let Some((k, v)) = extract_two_generic_args(segment) {
224                            PyroType::Map {
225                                key: Box::new(self.resolve_type_inner(k, visited)),
226                                value: Box::new(self.resolve_type_inner(v, visited)),
227                            }
228                        } else {
229                            PyroType::Null
230                        }
231                    }
232
233                    // --- Result<T, E> ---
234                    "Result" => {
235                        if let Some((ok, _err)) = extract_two_generic_args(segment) {
236                            self.resolve_type_inner(ok, visited)
237                        } else {
238                            PyroType::Null
239                        }
240                    }
241
242                    // --- DateTime ---
243                    "DateTime" => PyroType::Timestamp,
244
245                    // --- User-defined struct (look up in registry) ---
246                    other => {
247                        let mut entry_opt = self.structs.get(other);
248                        let mut resolved_name = other;
249                        if entry_opt.is_none() && other.ends_with("Ref") {
250                            let stripped = &other[..other.len() - 3];
251                            if let Some(entry) = self.structs.get(stripped) {
252                                entry_opt = Some(entry);
253                                resolved_name = stripped;
254                            }
255                        }
256
257                        if let Some(entry) = entry_opt {
258                            if visited.contains(&resolved_name.to_string()) {
259                                // Cycle guard — return empty group
260                                return PyroType::Group(Cow::Owned(vec![]));
261                            }
262                            visited.push(resolved_name.to_string());
263                            let fields = self.resolve_fields_inner(&entry.fields, visited);
264                            visited.pop();
265                            PyroType::Group(Cow::Owned(fields))
266                        } else {
267                            // Try foreign structs
268                            let base_name = if other.ends_with("Ref") {
269                                &other[..other.len() - 3]
270                            } else {
271                                other
272                            };
273
274                            if let Some(schema) = self.foreign_structs.get(base_name) {
275                                return PyroType::Group(Cow::Owned(
276                                    schema.fields.iter().map(|f| f.clone().into_owned()).collect(),
277                                ));
278                            }
279
280                            // Unknown struct — opaque group
281                            PyroType::Group(Cow::Owned(vec![]))
282                        }
283                    }
284                }
285            }
286            syn::Type::Reference(r) => self.resolve_type_inner(&r.elem, visited),
287            syn::Type::Tuple(t) if t.elems.is_empty() => PyroType::Null,
288            _ => PyroType::Null,
289        }
290    }
291}
292
293// =============================================================================
294// Helpers (private)
295// =============================================================================
296
297fn is_option_type(ty: &syn::Type) -> bool {
298    if let syn::Type::Path(type_path) = ty
299        && let Some(seg) = type_path.path.segments.last()
300    {
301        return seg.ident == "Option";
302    }
303    false
304}
305
306fn extract_single_generic_arg(segment: &syn::PathSegment) -> Option<&syn::Type> {
307    if let syn::PathArguments::AngleBracketed(args) = &segment.arguments
308        && let Some(syn::GenericArgument::Type(ty)) = args.args.first()
309    {
310        return Some(ty);
311    }
312    None
313}
314
315fn extract_two_generic_args(segment: &syn::PathSegment) -> Option<(&syn::Type, &syn::Type)> {
316    if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
317        let mut iter = args.args.iter();
318        if let (Some(syn::GenericArgument::Type(a)), Some(syn::GenericArgument::Type(b))) =
319            (iter.next(), iter.next())
320        {
321            return Some((a, b));
322        }
323    }
324    None
325}
326
327fn extract_doc_string(attrs: &[Attribute]) -> Option<String> {
328    let mut lines = Vec::new();
329    for attr in attrs {
330        if attr.path().is_ident("doc")
331            && let Meta::NameValue(nv) = &attr.meta
332            && let Expr::Lit(expr_lit) = &nv.value
333            && let Lit::Str(lit_str) = &expr_lit.lit
334        {
335            lines.push(lit_str.value().trim().to_string());
336        }
337    }
338    if lines.is_empty() {
339        None
340    } else {
341        Some(lines.join("\n"))
342    }
343}
344
345// =============================================================================
346// Tests
347// =============================================================================
348
#[cfg(test)]
mod tests {
    use super::*;
    use quote::quote;
    use syn::parse2;

    /// Parse `tokens` as a whole file and build a registry from it.
    fn builder_from_tokens(tokens: proc_macro2::TokenStream) -> SchemaBuilder {
        let file: syn::File = parse2(tokens).unwrap();
        SchemaBuilder::from_file(&file)
    }

    // --- Primitive resolution ---

    #[test]
    fn test_resolve_primitives() {
        let builder = builder_from_tokens(quote! {});

        let ty: syn::Type = parse2(quote!(u32)).unwrap();
        assert_eq!(
            builder.resolve_type(&ty),
            PyroType::PrimitiveScalar(PrimitiveDataType::U32)
        );

        let ty: syn::Type = parse2(quote!(String)).unwrap();
        assert_eq!(builder.resolve_type(&ty), PyroType::Str);

        let ty: syn::Type = parse2(quote!(f64)).unwrap();
        assert_eq!(
            builder.resolve_type(&ty),
            PyroType::PrimitiveScalar(PrimitiveDataType::F64)
        );
    }

    // --- Vec / Option ---

    #[test]
    fn test_resolve_vec_and_option() {
        let builder = builder_from_tokens(quote! {});

        let ty: syn::Type = parse2(quote!(Vec<u8>)).unwrap();
        assert_eq!(
            builder.resolve_type(&ty),
            PyroType::PrimitiveList(PrimitiveDataType::U8)
        );

        let ty: syn::Type = parse2(quote!(Vec<String>)).unwrap();
        assert_eq!(
            builder.resolve_type(&ty),
            PyroType::List(Box::new(PyroType::Str), false)
        );

        let ty: syn::Type = parse2(quote!(Option<i32>)).unwrap();
        assert_eq!(
            builder.resolve_type(&ty),
            PyroType::PrimitiveScalar(PrimitiveDataType::I32)
        );
        assert!(SchemaBuilder::is_option(&ty));
    }

    // --- Option detection looks only at the outermost type ---

    #[test]
    fn test_is_option_outer_type_only() {
        let ty: syn::Type = parse2(quote!(std::option::Option<u8>)).unwrap();
        assert!(SchemaBuilder::is_option(&ty));

        let ty: syn::Type = parse2(quote!(Vec<Option<u8>>)).unwrap();
        assert!(!SchemaBuilder::is_option(&ty));

        let ty: syn::Type = parse2(quote!(u8)).unwrap();
        assert!(!SchemaBuilder::is_option(&ty));
    }

    // --- Result / references / unit ---

    #[test]
    fn test_resolve_result_reference_and_unit() {
        let builder = builder_from_tokens(quote! {});

        let ty: syn::Type = parse2(quote!(Result<u64, String>)).unwrap();
        assert_eq!(
            builder.resolve_type(&ty),
            PyroType::PrimitiveScalar(PrimitiveDataType::U64)
        );

        let ty: syn::Type = parse2(quote!(&str)).unwrap();
        assert_eq!(builder.resolve_type(&ty), PyroType::Str);

        let ty: syn::Type = parse2(quote!(())).unwrap();
        assert_eq!(builder.resolve_type(&ty), PyroType::Null);
    }

    // --- Nested struct resolution (the whole point) ---

    #[test]
    fn test_resolve_nested_struct() {
        let builder = builder_from_tokens(quote! {
            #[config]
            struct Foo {
                woobie: String,
            }

            #[config]
            struct Bar {
                doobie: Foo,
            }
        });

        // Foo resolves to Group with one Str field
        let ty_foo: syn::Type = parse2(quote!(Foo)).unwrap();
        assert_eq!(
            builder.resolve_type(&ty_foo),
            PyroType::Group(Cow::Owned(vec![PyroField::new(
                Cow::Borrowed("woobie"),
                PyroType::Str,
                false,
            )]))
        );

        // Bar resolves to Group with one Group field (the Foo)
        let schema = builder.schema_for("Bar").unwrap();
        assert_eq!(schema.fields.len(), 1);

        let doobie = &schema.fields()[0];
        assert_eq!(doobie.name(), "doobie");
        match &doobie.data_type {
            PyroType::Group(inner_fields) => {
                assert_eq!(inner_fields.len(), 1);
                assert_eq!(inner_fields[0].name(), "woobie");
                assert_eq!(inner_fields[0].data_type, PyroType::Str);
            }
            other => panic!("expected Group, got {:?}", other),
        }
    }

    // --- `Ref` suffix falls back to the base struct ---

    #[test]
    fn test_ref_suffix_resolves_to_base_struct() {
        let builder = builder_from_tokens(quote! {
            #[config]
            struct Foo {
                id: u64,
            }
        });

        let ty: syn::Type = parse2(quote!(FooRef)).unwrap();
        assert_eq!(
            builder.resolve_type(&ty),
            PyroType::Group(Cow::Owned(vec![PyroField::new(
                Cow::Borrowed("id"),
                PyroType::PrimitiveScalar(PrimitiveDataType::U64),
                false,
            )]))
        );
    }

    // --- Only annotated structs are registered ---

    #[test]
    fn test_only_annotated_structs_registered() {
        let builder = builder_from_tokens(quote! {
            #[config]
            struct Configured {
                a: u8,
            }
            #[magma]
            struct Magma {
                b: u8,
            }
            struct Plain {
                c: u8,
            }
        });

        let mut names = builder.struct_names();
        names.sort();
        assert_eq!(names, vec!["Configured".to_string(), "Magma".to_string()]);
        assert!(builder.schema_for("Plain").is_none());
    }

    // --- Tuple structs use positional field names ---

    #[test]
    fn test_tuple_struct_field_names() {
        let builder = builder_from_tokens(quote! {
            #[config]
            struct Pair(u32, String);
        });

        let schema = builder.schema_for("Pair").unwrap();
        assert_eq!(schema.fields().len(), 2);
        assert_eq!(schema.fields()[0].name(), "0");
        assert_eq!(
            schema.fields()[0].data_type,
            PyroType::PrimitiveScalar(PrimitiveDataType::U32)
        );
        assert_eq!(schema.fields()[1].name(), "1");
        assert_eq!(schema.fields()[1].data_type, PyroType::Str);
    }

    // --- Deeply nested ---

    #[test]
    fn test_resolve_deeply_nested() {
        let builder = builder_from_tokens(quote! {
            #[config]
            struct A {
                x: i32,
            }
            #[config]
            struct B {
                a: A,
                name: String,
            }
            #[config]
            struct C {
                b: B,
                flag: bool,
            }
        });

        let schema_c = builder.schema_for("C").unwrap();
        assert_eq!(schema_c.fields.len(), 2);

        // b field should be Group([ Group([x:I32]), name:Str ])
        let b_field = &schema_c.fields()[0];
        assert_eq!(b_field.name(), "b");
        match &b_field.data_type {
            PyroType::Group(b_fields) => {
                assert_eq!(b_fields.len(), 2);
                assert_eq!(b_fields[0].name(), "a");
                match &b_fields[0].data_type {
                    PyroType::Group(a_fields) => {
                        assert_eq!(a_fields.len(), 1);
                        assert_eq!(a_fields[0].name(), "x");
                        assert_eq!(
                            a_fields[0].data_type,
                            PyroType::PrimitiveScalar(PrimitiveDataType::I32)
                        );
                    }
                    other => panic!("expected Group for A, got {:?}", other),
                }
                assert_eq!(b_fields[1].name(), "name");
                assert_eq!(b_fields[1].data_type, PyroType::Str);
            }
            other => panic!("expected Group for B, got {:?}", other),
        }

        // flag field
        let flag_field = &schema_c.fields()[1];
        assert_eq!(flag_field.name(), "flag");
        assert_eq!(
            flag_field.data_type,
            PyroType::PrimitiveScalar(PrimitiveDataType::Bool)
        );
    }

    // --- Vec of struct ---

    #[test]
    fn test_resolve_vec_of_struct() {
        let builder = builder_from_tokens(quote! {
            #[config]
            struct Item {
                value: f32,
            }
            #[config]
            struct Container {
                items: Vec<Item>,
            }
        });

        let schema = builder.schema_for("Container").unwrap();
        let items_field = &schema.fields()[0];
        assert_eq!(items_field.name(), "items");

        match &items_field.data_type {
            PyroType::List(inner, nullable) => {
                assert!(!nullable);
                match inner.as_ref() {
                    PyroType::Group(fields) => {
                        assert_eq!(fields.len(), 1);
                        assert_eq!(fields[0].name(), "value");
                        assert_eq!(
                            fields[0].data_type,
                            PyroType::PrimitiveScalar(PrimitiveDataType::F32)
                        );
                    }
                    other => panic!("expected Group inside List, got {:?}", other),
                }
            }
            other => panic!("expected List, got {:?}", other),
        }
    }

    // --- Doc strings preserved ---

    #[test]
    fn test_doc_strings_preserved() {
        let builder = builder_from_tokens(quote! {
            /// This is Foo
            #[config]
            struct Foo {
                /// The id
                id: u32,
                name: String,
            }
        });

        let schema = builder.schema_for("Foo").unwrap();
        assert_eq!(schema.documentation.as_deref(), Some("This is Foo"));
        assert_eq!(schema.fields.len(), 2);
        assert_eq!(schema.fields()[0].documentation.as_deref(), Some("The id"));
        assert!(schema.fields()[1].documentation.is_none());
    }

    // --- Unknown struct falls back to empty group ---

    #[test]
    fn test_unknown_struct_empty_group() {
        let builder = builder_from_tokens(quote! {
            #[config]
            struct Wrapper {
                inner: SomeExternalThing,
            }
        });

        let schema = builder.schema_for("Wrapper").unwrap();
        let inner = &schema.fields()[0];
        assert_eq!(inner.data_type, PyroType::Group(Cow::Owned(vec![])));
    }

    // --- Cycle guard ---

    #[test]
    fn test_cycle_guard() {
        // Contrived: struct A has field of type A. Should not stack overflow.
        let builder = builder_from_tokens(quote! {
            #[config]
            struct A {
                next: A,
            }
        });

        let schema = builder.schema_for("A").unwrap();
        assert_eq!(schema.fields().len(), 1);
        let next_field = &schema.fields()[0];
        assert_eq!(next_field.name(), "next");

        // The `next` field has type A, which is a Group containing one field also named `next`.
        // The inner self-reference is cut off as Group([]) to break the cycle.
        match &next_field.data_type {
            PyroType::Group(inner_fields) => {
                assert_eq!(inner_fields.len(), 1);
                assert_eq!(inner_fields[0].name(), "next");
                // Cycle broken at this depth — the recursive self-ref is an empty group
                assert_eq!(
                    inner_fields[0].data_type,
                    PyroType::Group(Cow::Owned(vec![]))
                );
            }
            other => panic!("expected Group for A's next field, got {:?}", other),
        }
    }

    // --- Map of struct ---

    #[test]
    fn test_resolve_map_of_struct() {
        let builder = builder_from_tokens(quote! {
            #[config]
            struct Config {
                key: String,
            }
            #[config]
            struct Registry {
                entries: HashMap<String, Config>,
            }
        });

        let schema = builder.schema_for("Registry").unwrap();
        let entries = &schema.fields()[0];

        match &entries.data_type {
            PyroType::Map { key, value } => {
                assert_eq!(key.as_ref(), &PyroType::Str);
                match value.as_ref() {
                    PyroType::Group(fields) => {
                        assert_eq!(fields.len(), 1);
                        assert_eq!(fields[0].name(), "key");
                    }
                    other => panic!("expected Group for Config, got {:?}", other),
                }
            }
            other => panic!("expected Map, got {:?}", other),
        }
    }
}
648}