1use std::borrow::Cow;
17use std::collections::HashMap;
18
19use crate::utils::has_attr;
20use pyro_spec::{PrimitiveDataType, PyroField, PyroSchema, PyroType};
21use syn::{Attribute, Expr, Fields, Lit, Meta};
22
/// Resolves Rust types found in a parsed source file into `PyroType`s and
/// `PyroSchema`s.
///
/// Populated by `from_file` and optionally extended with schemas from other
/// interfaces through `with_foreign_specs`.
pub struct SchemaBuilder {
    /// Structs collected from the parsed file, keyed by struct name.
    structs: HashMap<String, StructEntry>,
    /// Schemas registered through `with_foreign_specs`, keyed by struct name.
    /// Consulted only when a type name is not found in `structs`.
    foreign_structs: HashMap<String, PyroSchema<'static>>,
}
38
/// What the builder records about one struct collected from the parsed file.
struct StructEntry {
    /// Text of the struct's `#[doc]` attributes, if it has any.
    doc: Option<String>,
    /// The struct's fields, in declaration order.
    fields: Vec<FieldEntry>,
}
43
/// What the builder records about one field of a collected struct.
struct FieldEntry {
    /// Field identifier, or the positional index (as text) for tuple-struct fields.
    name: String,
    /// The field's declared type, kept as syntax and resolved on demand.
    ty: syn::Type,
    /// Text of the field's `#[doc]` attributes, if it has any.
    doc: Option<String>,
}
49
50impl SchemaBuilder {
51 pub fn from_file(file: &syn::File) -> Self {
57 let mut structs = HashMap::new();
58 for item in &file.items {
59 if let syn::Item::Struct(s) = item {
60 if !(has_attr(&s.attrs, "config") || has_attr(&s.attrs, "magma")) {
61 continue;
62 }
63 let name = s.ident.to_string();
64 let doc = extract_doc_string(&s.attrs);
65 let fields = Self::collect_fields(&s.fields);
66 structs.insert(name, StructEntry { doc, fields });
67 }
68 }
69 Self {
70 structs,
71 foreign_structs: HashMap::new(),
72 }
73 }
74
75 pub fn with_foreign_specs(
76 mut self,
77 dep_interfaces: &[pyro_spec::InterfaceSpec<'static>],
78 ) -> Self {
79 for spec in dep_interfaces {
80 for (struct_name, schema) in &spec.structs {
81 self.foreign_structs.insert(struct_name.to_string(), schema.clone());
82 }
83 }
84 self
85 }
86
87 pub fn struct_names(&self) -> Vec<String> {
88 self.structs.keys().cloned().collect()
89 }
90
91 fn collect_fields(fields: &Fields) -> Vec<FieldEntry> {
92 match fields {
93 Fields::Named(named) => named
94 .named
95 .iter()
96 .map(|f| FieldEntry {
97 name: f.ident.as_ref().unwrap().to_string(),
98 ty: f.ty.clone(),
99 doc: extract_doc_string(&f.attrs),
100 })
101 .collect(),
102 Fields::Unnamed(unnamed) => unnamed
103 .unnamed
104 .iter()
105 .enumerate()
106 .map(|(i, f)| FieldEntry {
107 name: i.to_string(),
108 ty: f.ty.clone(),
109 doc: extract_doc_string(&f.attrs),
110 })
111 .collect(),
112 Fields::Unit => vec![],
113 }
114 }
115
116 pub fn schema_for(&self, struct_name: &str) -> Option<PyroSchema<'static>> {
122 let entry = self.structs.get(struct_name)?;
123 let mut visited = Vec::new();
124 let fields = self.resolve_fields_inner(&entry.fields, &mut visited);
125 let mut schema = PyroSchema::new(fields);
126 if let Some(d) = &entry.doc {
127 schema = schema.add_docstring(Cow::Owned(d.clone()));
128 }
129 Some(schema)
130 }
131
132 pub fn resolve_type(&self, ty: &syn::Type) -> PyroType<'static> {
135 self.resolve_type_inner(ty, &mut Vec::new())
136 }
137
138 pub fn is_option(ty: &syn::Type) -> bool {
140 is_option_type(ty)
141 }
142
143 fn resolve_fields_inner(
148 &self,
149 fields: &[FieldEntry],
150 visited: &mut Vec<String>,
151 ) -> Vec<PyroField<'static>> {
152 fields
153 .iter()
154 .map(|f| {
155 let data_type = self.resolve_type_inner(&f.ty, visited);
156 let nullable = is_option_type(&f.ty);
157 let mut field = PyroField::new(Cow::Owned(f.name.clone()), data_type, nullable);
158 if let Some(doc) = &f.doc {
159 field = field.add_docstring(Cow::Owned(doc.clone()));
160 }
161 field
162 })
163 .collect()
164 }
165
166 fn resolve_type_inner(&self, ty: &syn::Type, visited: &mut Vec<String>) -> PyroType<'static> {
170 match ty {
171 syn::Type::Path(type_path) => {
172 let segment = match type_path.path.segments.last() {
173 Some(s) => s,
174 None => return PyroType::Null,
175 };
176 let ident_str = segment.ident.to_string();
177
178 match ident_str.as_str() {
179 "bool" => PyroType::PrimitiveScalar(PrimitiveDataType::Bool),
181 "u8" => PyroType::PrimitiveScalar(PrimitiveDataType::U8),
182 "u16" => PyroType::PrimitiveScalar(PrimitiveDataType::U16),
183 "u32" => PyroType::PrimitiveScalar(PrimitiveDataType::U32),
184 "u64" => PyroType::PrimitiveScalar(PrimitiveDataType::U64),
185 "i8" => PyroType::PrimitiveScalar(PrimitiveDataType::I8),
186 "i16" => PyroType::PrimitiveScalar(PrimitiveDataType::I16),
187 "i32" => PyroType::PrimitiveScalar(PrimitiveDataType::I32),
188 "i64" => PyroType::PrimitiveScalar(PrimitiveDataType::I64),
189 "f16" => PyroType::PrimitiveScalar(PrimitiveDataType::F16),
190 "f32" => PyroType::PrimitiveScalar(PrimitiveDataType::F32),
191 "f64" => PyroType::PrimitiveScalar(PrimitiveDataType::F64),
192
193 "String" | "str" => PyroType::Str,
195
196 "Bytes" => PyroType::PrimitiveList(PrimitiveDataType::U8),
198
199 "Option" => {
201 if let Some(inner) = extract_single_generic_arg(segment) {
202 self.resolve_type_inner(inner, visited)
203 } else {
204 PyroType::Null
205 }
206 }
207
208 "Vec" => {
210 if let Some(inner) = extract_single_generic_arg(segment) {
211 let inner_pyro = self.resolve_type_inner(inner, visited);
212 match &inner_pyro {
213 PyroType::PrimitiveScalar(p) => PyroType::PrimitiveList(*p),
214 _ => PyroType::List(Box::new(inner_pyro), false),
215 }
216 } else {
217 PyroType::Null
218 }
219 }
220
221 "HashMap" | "BTreeMap" => {
223 if let Some((k, v)) = extract_two_generic_args(segment) {
224 PyroType::Map {
225 key: Box::new(self.resolve_type_inner(k, visited)),
226 value: Box::new(self.resolve_type_inner(v, visited)),
227 }
228 } else {
229 PyroType::Null
230 }
231 }
232
233 "Result" => {
235 if let Some((ok, _err)) = extract_two_generic_args(segment) {
236 self.resolve_type_inner(ok, visited)
237 } else {
238 PyroType::Null
239 }
240 }
241
242 "DateTime" => PyroType::Timestamp,
244
245 other => {
247 let mut entry_opt = self.structs.get(other);
248 let mut resolved_name = other;
249 if entry_opt.is_none() && other.ends_with("Ref") {
250 let stripped = &other[..other.len() - 3];
251 if let Some(entry) = self.structs.get(stripped) {
252 entry_opt = Some(entry);
253 resolved_name = stripped;
254 }
255 }
256
257 if let Some(entry) = entry_opt {
258 if visited.contains(&resolved_name.to_string()) {
259 return PyroType::Group(Cow::Owned(vec![]));
261 }
262 visited.push(resolved_name.to_string());
263 let fields = self.resolve_fields_inner(&entry.fields, visited);
264 visited.pop();
265 PyroType::Group(Cow::Owned(fields))
266 } else {
267 let base_name = if other.ends_with("Ref") {
269 &other[..other.len() - 3]
270 } else {
271 other
272 };
273
274 if let Some(schema) = self.foreign_structs.get(base_name) {
275 return PyroType::Group(Cow::Owned(
276 schema.fields.iter().map(|f| f.clone().into_owned()).collect(),
277 ));
278 }
279
280 PyroType::Group(Cow::Owned(vec![]))
282 }
283 }
284 }
285 }
286 syn::Type::Reference(r) => self.resolve_type_inner(&r.elem, visited),
287 syn::Type::Tuple(t) if t.elems.is_empty() => PyroType::Null,
288 _ => PyroType::Null,
289 }
290 }
291}
292
293fn is_option_type(ty: &syn::Type) -> bool {
298 if let syn::Type::Path(type_path) = ty
299 && let Some(seg) = type_path.path.segments.last()
300 {
301 return seg.ident == "Option";
302 }
303 false
304}
305
306fn extract_single_generic_arg(segment: &syn::PathSegment) -> Option<&syn::Type> {
307 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments
308 && let Some(syn::GenericArgument::Type(ty)) = args.args.first()
309 {
310 return Some(ty);
311 }
312 None
313}
314
315fn extract_two_generic_args(segment: &syn::PathSegment) -> Option<(&syn::Type, &syn::Type)> {
316 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
317 let mut iter = args.args.iter();
318 if let (Some(syn::GenericArgument::Type(a)), Some(syn::GenericArgument::Type(b))) =
319 (iter.next(), iter.next())
320 {
321 return Some((a, b));
322 }
323 }
324 None
325}
326
327fn extract_doc_string(attrs: &[Attribute]) -> Option<String> {
328 let mut lines = Vec::new();
329 for attr in attrs {
330 if attr.path().is_ident("doc")
331 && let Meta::NameValue(nv) = &attr.meta
332 && let Expr::Lit(expr_lit) = &nv.value
333 && let Lit::Str(lit_str) = &expr_lit.lit
334 {
335 lines.push(lit_str.value().trim().to_string());
336 }
337 }
338 if lines.is_empty() {
339 None
340 } else {
341 Some(lines.join("\n"))
342 }
343}
344
#[cfg(test)]
mod tests {
    use super::*;
    use quote::quote;
    use syn::parse2;

    /// Parses `tokens` as a whole source file and builds a `SchemaBuilder` from it.
    fn builder_from_tokens(tokens: proc_macro2::TokenStream) -> SchemaBuilder {
        let file: syn::File = parse2(tokens).unwrap();
        SchemaBuilder::from_file(&file)
    }

    /// Primitive and string type names resolve without any collected structs.
    #[test]
    fn test_resolve_primitives() {
        let builder = builder_from_tokens(quote! {});

        let ty: syn::Type = parse2(quote!(u32)).unwrap();
        assert_eq!(
            builder.resolve_type(&ty),
            PyroType::PrimitiveScalar(PrimitiveDataType::U32)
        );

        let ty: syn::Type = parse2(quote!(String)).unwrap();
        assert_eq!(builder.resolve_type(&ty), PyroType::Str);

        let ty: syn::Type = parse2(quote!(f64)).unwrap();
        assert_eq!(
            builder.resolve_type(&ty),
            PyroType::PrimitiveScalar(PrimitiveDataType::F64)
        );
    }

    /// `Vec` of a primitive becomes a primitive list, `Vec` of anything else a
    /// generic list, and `Option<T>` resolves to `T` while being reported as optional.
    #[test]
    fn test_resolve_vec_and_option() {
        let builder = builder_from_tokens(quote! {});

        let ty: syn::Type = parse2(quote!(Vec<u8>)).unwrap();
        assert_eq!(
            builder.resolve_type(&ty),
            PyroType::PrimitiveList(PrimitiveDataType::U8)
        );

        let ty: syn::Type = parse2(quote!(Vec<String>)).unwrap();
        assert_eq!(
            builder.resolve_type(&ty),
            PyroType::List(Box::new(PyroType::Str), false)
        );

        let ty: syn::Type = parse2(quote!(Option<i32>)).unwrap();
        assert_eq!(
            builder.resolve_type(&ty),
            PyroType::PrimitiveScalar(PrimitiveDataType::I32)
        );
        assert!(SchemaBuilder::is_option(&ty));
    }

    /// A field whose type is another collected struct resolves to a group of
    /// that struct's fields.
    #[test]
    fn test_resolve_nested_struct() {
        let builder = builder_from_tokens(quote! {
            #[config]
            struct Foo {
                woobie: String,
            }

            #[config]
            struct Bar {
                doobie: Foo,
            }
        });

        let ty_foo: syn::Type = parse2(quote!(Foo)).unwrap();
        assert_eq!(
            builder.resolve_type(&ty_foo),
            PyroType::Group(Cow::Owned(vec![PyroField::new(
                Cow::Borrowed("woobie"),
                PyroType::Str,
                false,
            )]))
        );

        let schema = builder.schema_for("Bar").unwrap();
        assert_eq!(schema.fields.len(), 1);

        let doobie = &schema.fields()[0];
        assert_eq!(doobie.name(), "doobie");
        match &doobie.data_type {
            PyroType::Group(inner_fields) => {
                assert_eq!(inner_fields.len(), 1);
                assert_eq!(inner_fields[0].name(), "woobie");
                assert_eq!(inner_fields[0].data_type, PyroType::Str);
            }
            other => panic!("expected Group, got {:?}", other),
        }
    }

    /// Struct references are expanded through several levels of nesting.
    #[test]
    fn test_resolve_deeply_nested() {
        let builder = builder_from_tokens(quote! {
            #[config]
            struct A {
                x: i32,
            }
            #[config]
            struct B {
                a: A,
                name: String,
            }
            #[config]
            struct C {
                b: B,
                flag: bool,
            }
        });

        let schema_c = builder.schema_for("C").unwrap();
        assert_eq!(schema_c.fields.len(), 2);

        let b_field = &schema_c.fields()[0];
        assert_eq!(b_field.name(), "b");
        match &b_field.data_type {
            PyroType::Group(b_fields) => {
                assert_eq!(b_fields.len(), 2);
                assert_eq!(b_fields[0].name(), "a");
                match &b_fields[0].data_type {
                    PyroType::Group(a_fields) => {
                        assert_eq!(a_fields.len(), 1);
                        assert_eq!(a_fields[0].name(), "x");
                        assert_eq!(
                            a_fields[0].data_type,
                            PyroType::PrimitiveScalar(PrimitiveDataType::I32)
                        );
                    }
                    other => panic!("expected Group for A, got {:?}", other),
                }
                assert_eq!(b_fields[1].name(), "name");
                assert_eq!(b_fields[1].data_type, PyroType::Str);
            }
            other => panic!("expected Group for B, got {:?}", other),
        }

        let flag_field = &schema_c.fields()[1];
        assert_eq!(flag_field.name(), "flag");
        assert_eq!(
            flag_field.data_type,
            PyroType::PrimitiveScalar(PrimitiveDataType::Bool)
        );
    }

    /// `Vec<Struct>` resolves to a list whose element is the struct's group.
    #[test]
    fn test_resolve_vec_of_struct() {
        let builder = builder_from_tokens(quote! {
            #[config]
            struct Item {
                value: f32,
            }
            #[config]
            struct Container {
                items: Vec<Item>,
            }
        });

        let schema = builder.schema_for("Container").unwrap();
        let items_field = &schema.fields()[0];
        assert_eq!(items_field.name(), "items");

        match &items_field.data_type {
            PyroType::List(inner, nullable) => {
                assert!(!nullable);
                match inner.as_ref() {
                    PyroType::Group(fields) => {
                        assert_eq!(fields.len(), 1);
                        assert_eq!(fields[0].name(), "value");
                        assert_eq!(
                            fields[0].data_type,
                            PyroType::PrimitiveScalar(PrimitiveDataType::F32)
                        );
                    }
                    other => panic!("expected Group inside List, got {:?}", other),
                }
            }
            other => panic!("expected List, got {:?}", other),
        }
    }

    /// Doc comments on a struct and on its fields end up in the schema.
    #[test]
    fn test_doc_strings_preserved() {
        // The `///` lines are part of the fixture: `quote!` turns them into
        // `#[doc = "..."]` attributes, which is what the assertions below check.
        let builder = builder_from_tokens(quote! {
            /// This is Foo
            #[config]
            struct Foo {
                /// The id
                id: u32,
                name: String,
            }
        });

        let schema = builder.schema_for("Foo").unwrap();
        assert_eq!(schema.documentation.as_deref(), Some("This is Foo"));
        assert_eq!(schema.fields.len(), 2);
        assert_eq!(schema.fields()[0].documentation.as_deref(), Some("The id"));
        assert!(schema.fields()[1].documentation.is_none());
    }

    /// A type name that matches no known struct resolves to an empty group.
    #[test]
    fn test_unknown_struct_empty_group() {
        let builder = builder_from_tokens(quote! {
            #[config]
            struct Wrapper {
                inner: SomeExternalThing,
            }
        });

        let schema = builder.schema_for("Wrapper").unwrap();
        let inner = &schema.fields()[0];
        assert_eq!(inner.data_type, PyroType::Group(Cow::Owned(vec![])));
    }

    /// A self-referential struct terminates instead of recursing forever.
    #[test]
    fn test_cycle_guard() {
        let builder = builder_from_tokens(quote! {
            #[config]
            struct A {
                next: A,
            }
        });

        let schema = builder.schema_for("A").unwrap();
        assert_eq!(schema.fields().len(), 1);
        let next_field = &schema.fields()[0];
        assert_eq!(next_field.name(), "next");

        match &next_field.data_type {
            // `schema_for` does not mark the root struct as visited, so `A` is
            // expanded once before the guard replaces it with an empty group.
            PyroType::Group(inner_fields) => {
                assert_eq!(inner_fields.len(), 1);
                assert_eq!(inner_fields[0].name(), "next");
                assert_eq!(
                    inner_fields[0].data_type,
                    PyroType::Group(Cow::Owned(vec![]))
                );
            }
            other => panic!("expected Group for A's next field, got {:?}", other),
        }
    }

    /// Map values that are collected structs resolve to that struct's group.
    #[test]
    fn test_resolve_map_of_struct() {
        let builder = builder_from_tokens(quote! {
            #[config]
            struct Config {
                key: String,
            }
            #[config]
            struct Registry {
                entries: HashMap<String, Config>,
            }
        });

        let schema = builder.schema_for("Registry").unwrap();
        let entries = &schema.fields()[0];

        match &entries.data_type {
            PyroType::Map { key, value } => {
                assert_eq!(key.as_ref(), &PyroType::Str);
                match value.as_ref() {
                    PyroType::Group(fields) => {
                        assert_eq!(fields.len(), 1);
                        assert_eq!(fields[0].name(), "key");
                    }
                    other => panic!("expected Group for Config, got {:?}", other),
                }
            }
            other => panic!("expected Map, got {:?}", other),
        }
    }
}