1use std::borrow::Cow;
17use std::collections::HashMap;
18
19use crate::utils::has_attr;
20use pyro_spec::{PrimitiveDataType, PyroField, PyroSchema, PyroType};
21use syn::{Attribute, Expr, Fields, Lit, Meta};
22
23pub struct SchemaBuilder {
33 structs: HashMap<String, StructEntry>,
35}
36
37struct StructEntry {
38 doc: Option<String>,
39 fields: Vec<FieldEntry>,
40}
41
42struct FieldEntry {
43 name: String,
44 ty: syn::Type,
45 doc: Option<String>,
46}
47
48impl SchemaBuilder {
49 pub fn from_file(file: &syn::File) -> Self {
55 let mut structs = HashMap::new();
56 for item in &file.items {
57 if let syn::Item::Struct(s) = item {
58 if !(has_attr(&s.attrs, "config") || has_attr(&s.attrs, "magma")) {
59 continue;
60 }
61 let name = s.ident.to_string();
62 let doc = extract_doc_string(&s.attrs);
63 let fields = Self::collect_fields(&s.fields);
64 structs.insert(name, StructEntry { doc, fields });
65 }
66 }
67 Self { structs }
68 }
69
70 fn collect_fields(fields: &Fields) -> Vec<FieldEntry> {
71 match fields {
72 Fields::Named(named) => named
73 .named
74 .iter()
75 .map(|f| FieldEntry {
76 name: f.ident.as_ref().unwrap().to_string(),
77 ty: f.ty.clone(),
78 doc: extract_doc_string(&f.attrs),
79 })
80 .collect(),
81 Fields::Unnamed(unnamed) => unnamed
82 .unnamed
83 .iter()
84 .enumerate()
85 .map(|(i, f)| FieldEntry {
86 name: i.to_string(),
87 ty: f.ty.clone(),
88 doc: extract_doc_string(&f.attrs),
89 })
90 .collect(),
91 Fields::Unit => vec![],
92 }
93 }
94
95 pub fn schema_for(&self, struct_name: &str) -> Option<PyroSchema<'static>> {
101 let entry = self.structs.get(struct_name)?;
102 let mut visited = Vec::new();
103 let fields = self.resolve_fields_inner(&entry.fields, &mut visited);
104 let mut schema = PyroSchema::new(fields);
105 if let Some(d) = &entry.doc {
106 schema = schema.add_docstring(Cow::Owned(d.clone()));
107 }
108 Some(schema)
109 }
110
111 pub fn resolve_type(&self, ty: &syn::Type) -> PyroType<'static> {
114 self.resolve_type_inner(ty, &mut Vec::new())
115 }
116
117 pub fn is_option(ty: &syn::Type) -> bool {
119 is_option_type(ty)
120 }
121
122 fn resolve_fields_inner(
127 &self,
128 fields: &[FieldEntry],
129 visited: &mut Vec<String>,
130 ) -> Vec<PyroField<'static>> {
131 fields
132 .iter()
133 .map(|f| {
134 let data_type = self.resolve_type_inner(&f.ty, visited);
135 let nullable = is_option_type(&f.ty);
136 let mut field = PyroField::new(Cow::Owned(f.name.clone()), data_type, nullable);
137 if let Some(doc) = &f.doc {
138 field = field.add_docstring(Cow::Owned(doc.clone()));
139 }
140 field
141 })
142 .collect()
143 }
144
145 fn resolve_type_inner(&self, ty: &syn::Type, visited: &mut Vec<String>) -> PyroType<'static> {
149 match ty {
150 syn::Type::Path(type_path) => {
151 let segment = match type_path.path.segments.last() {
152 Some(s) => s,
153 None => return PyroType::Null,
154 };
155 let ident_str = segment.ident.to_string();
156
157 match ident_str.as_str() {
158 "bool" => PyroType::PrimitiveScalar(PrimitiveDataType::Bool),
160 "u8" => PyroType::PrimitiveScalar(PrimitiveDataType::U8),
161 "u16" => PyroType::PrimitiveScalar(PrimitiveDataType::U16),
162 "u32" => PyroType::PrimitiveScalar(PrimitiveDataType::U32),
163 "u64" => PyroType::PrimitiveScalar(PrimitiveDataType::U64),
164 "i8" => PyroType::PrimitiveScalar(PrimitiveDataType::I8),
165 "i16" => PyroType::PrimitiveScalar(PrimitiveDataType::I16),
166 "i32" => PyroType::PrimitiveScalar(PrimitiveDataType::I32),
167 "i64" => PyroType::PrimitiveScalar(PrimitiveDataType::I64),
168 "f16" => PyroType::PrimitiveScalar(PrimitiveDataType::F16),
169 "f32" => PyroType::PrimitiveScalar(PrimitiveDataType::F32),
170 "f64" => PyroType::PrimitiveScalar(PrimitiveDataType::F64),
171
172 "String" | "str" => PyroType::Str,
174
175 "Bytes" => PyroType::PrimitiveList(PrimitiveDataType::U8),
177
178 "Option" => {
180 if let Some(inner) = extract_single_generic_arg(segment) {
181 self.resolve_type_inner(inner, visited)
182 } else {
183 PyroType::Null
184 }
185 }
186
187 "Vec" => {
189 if let Some(inner) = extract_single_generic_arg(segment) {
190 let inner_pyro = self.resolve_type_inner(inner, visited);
191 match &inner_pyro {
192 PyroType::PrimitiveScalar(p) => PyroType::PrimitiveList(*p),
193 _ => PyroType::List(Box::new(inner_pyro), false),
194 }
195 } else {
196 PyroType::Null
197 }
198 }
199
200 "HashMap" | "BTreeMap" => {
202 if let Some((k, v)) = extract_two_generic_args(segment) {
203 PyroType::Map {
204 key: Box::new(self.resolve_type_inner(k, visited)),
205 value: Box::new(self.resolve_type_inner(v, visited)),
206 }
207 } else {
208 PyroType::Null
209 }
210 }
211
212 "Result" => {
214 if let Some((ok, _err)) = extract_two_generic_args(segment) {
215 self.resolve_type_inner(ok, visited)
216 } else {
217 PyroType::Null
218 }
219 }
220
221 "DateTime" => PyroType::Timestamp,
223
224 other => {
226 if visited.contains(&other.to_string()) {
227 return PyroType::Group(Cow::Owned(vec![]));
229 }
230 if let Some(entry) = self.structs.get(other) {
231 visited.push(other.to_string());
232 let fields = self.resolve_fields_inner(&entry.fields, visited);
233 visited.pop();
234 PyroType::Group(Cow::Owned(fields))
235 } else {
236 PyroType::Group(Cow::Owned(vec![]))
238 }
239 }
240 }
241 }
242 syn::Type::Reference(r) => self.resolve_type_inner(&r.elem, visited),
243 syn::Type::Tuple(t) if t.elems.is_empty() => PyroType::Null,
244 _ => PyroType::Null,
245 }
246 }
247}
248
249fn is_option_type(ty: &syn::Type) -> bool {
254 if let syn::Type::Path(type_path) = ty {
255 if let Some(seg) = type_path.path.segments.last() {
256 return seg.ident == "Option";
257 }
258 }
259 false
260}
261
262fn extract_single_generic_arg(segment: &syn::PathSegment) -> Option<&syn::Type> {
263 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
264 if let Some(syn::GenericArgument::Type(ty)) = args.args.first() {
265 return Some(ty);
266 }
267 }
268 None
269}
270
271fn extract_two_generic_args(segment: &syn::PathSegment) -> Option<(&syn::Type, &syn::Type)> {
272 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
273 let mut iter = args.args.iter();
274 if let (Some(syn::GenericArgument::Type(a)), Some(syn::GenericArgument::Type(b))) =
275 (iter.next(), iter.next())
276 {
277 return Some((a, b));
278 }
279 }
280 None
281}
282
283fn extract_doc_string(attrs: &[Attribute]) -> Option<String> {
284 let mut lines = Vec::new();
285 for attr in attrs {
286 if attr.path().is_ident("doc") {
287 if let Meta::NameValue(nv) = &attr.meta {
288 if let Expr::Lit(expr_lit) = &nv.value {
289 if let Lit::Str(lit_str) = &expr_lit.lit {
290 lines.push(lit_str.value().trim().to_string());
291 }
292 }
293 }
294 }
295 }
296 if lines.is_empty() {
297 None
298 } else {
299 Some(lines.join("\n"))
300 }
301}
302
#[cfg(test)]
mod tests {
    // Unit tests for `SchemaBuilder`, driven by token streams built with `quote!`.
    use super::*;
    use quote::quote;
    use syn::parse2;

    /// Parse `tokens` as a whole source file and index its annotated structs.
    fn builder_from_tokens(tokens: proc_macro2::TokenStream) -> SchemaBuilder {
        let file: syn::File = parse2(tokens).unwrap();
        SchemaBuilder::from_file(&file)
    }

    /// Scalar primitives and `String` map to their dedicated schema types.
    #[test]
    fn test_resolve_primitives() {
        let builder = builder_from_tokens(quote! {});

        let ty: syn::Type = parse2(quote!(u32)).unwrap();
        assert_eq!(
            builder.resolve_type(&ty),
            PyroType::PrimitiveScalar(PrimitiveDataType::U32)
        );

        let ty: syn::Type = parse2(quote!(String)).unwrap();
        assert_eq!(builder.resolve_type(&ty), PyroType::Str);

        let ty: syn::Type = parse2(quote!(f64)).unwrap();
        assert_eq!(
            builder.resolve_type(&ty),
            PyroType::PrimitiveScalar(PrimitiveDataType::F64)
        );
    }

    /// `Vec` of a primitive collapses to `PrimitiveList`, other element types become `List`,
    /// and `Option<T>` resolves to `T` while still being recognised as optional.
    #[test]
    fn test_resolve_vec_and_option() {
        let builder = builder_from_tokens(quote! {});

        let ty: syn::Type = parse2(quote!(Vec<u8>)).unwrap();
        assert_eq!(
            builder.resolve_type(&ty),
            PyroType::PrimitiveList(PrimitiveDataType::U8)
        );

        let ty: syn::Type = parse2(quote!(Vec<String>)).unwrap();
        assert_eq!(
            builder.resolve_type(&ty),
            PyroType::List(Box::new(PyroType::Str), false)
        );

        let ty: syn::Type = parse2(quote!(Option<i32>)).unwrap();
        assert_eq!(
            builder.resolve_type(&ty),
            PyroType::PrimitiveScalar(PrimitiveDataType::I32)
        );
        assert!(SchemaBuilder::is_option(&ty));
    }

    /// A field whose type is another annotated struct is expanded into a `Group`.
    #[test]
    fn test_resolve_nested_struct() {
        let builder = builder_from_tokens(quote! {
            #[config]
            struct Foo {
                woobie: String,
            }

            #[config]
            struct Bar {
                doobie: Foo,
            }
        });

        let ty_foo: syn::Type = parse2(quote!(Foo)).unwrap();
        assert_eq!(
            builder.resolve_type(&ty_foo),
            PyroType::Group(Cow::Owned(vec![PyroField::new(
                Cow::Borrowed("woobie"),
                PyroType::Str,
                false,
            )]))
        );

        let schema = builder.schema_for("Bar").unwrap();
        assert_eq!(schema.fields().len(), 1);

        let doobie = &schema.fields()[0];
        assert_eq!(doobie.name(), "doobie");
        match &doobie.data_type {
            PyroType::Group(inner_fields) => {
                assert_eq!(inner_fields.len(), 1);
                assert_eq!(inner_fields[0].name(), "woobie");
                assert_eq!(inner_fields[0].data_type, PyroType::Str);
            }
            other => panic!("expected Group, got {:?}", other),
        }
    }

    /// Expansion recurses through several levels of annotated structs (C -> B -> A).
    #[test]
    fn test_resolve_deeply_nested() {
        let builder = builder_from_tokens(quote! {
            #[config]
            struct A {
                x: i32,
            }
            #[config]
            struct B {
                a: A,
                name: String,
            }
            #[config]
            struct C {
                b: B,
                flag: bool,
            }
        });

        let schema_c = builder.schema_for("C").unwrap();
        assert_eq!(schema_c.fields().len(), 2);

        // C.b is B, which itself contains A.
        let b_field = &schema_c.fields()[0];
        assert_eq!(b_field.name(), "b");
        match &b_field.data_type {
            PyroType::Group(b_fields) => {
                assert_eq!(b_fields.len(), 2);
                assert_eq!(b_fields[0].name(), "a");
                match &b_fields[0].data_type {
                    PyroType::Group(a_fields) => {
                        assert_eq!(a_fields.len(), 1);
                        assert_eq!(a_fields[0].name(), "x");
                        assert_eq!(
                            a_fields[0].data_type,
                            PyroType::PrimitiveScalar(PrimitiveDataType::I32)
                        );
                    }
                    other => panic!("expected Group for A, got {:?}", other),
                }
                assert_eq!(b_fields[1].name(), "name");
                assert_eq!(b_fields[1].data_type, PyroType::Str);
            }
            other => panic!("expected Group for B, got {:?}", other),
        }

        // C.flag is a plain primitive next to the nested group.
        let flag_field = &schema_c.fields()[1];
        assert_eq!(flag_field.name(), "flag");
        assert_eq!(
            flag_field.data_type,
            PyroType::PrimitiveScalar(PrimitiveDataType::Bool)
        );
    }

    /// `Vec<Struct>` becomes a `List` wrapping the struct's `Group`.
    #[test]
    fn test_resolve_vec_of_struct() {
        let builder = builder_from_tokens(quote! {
            #[config]
            struct Item {
                value: f32,
            }
            #[config]
            struct Container {
                items: Vec<Item>,
            }
        });

        let schema = builder.schema_for("Container").unwrap();
        let items_field = &schema.fields()[0];
        assert_eq!(items_field.name(), "items");

        match &items_field.data_type {
            PyroType::List(inner, nullable) => {
                assert!(!nullable);
                match inner.as_ref() {
                    PyroType::Group(fields) => {
                        assert_eq!(fields.len(), 1);
                        assert_eq!(fields[0].name(), "value");
                        assert_eq!(
                            fields[0].data_type,
                            PyroType::PrimitiveScalar(PrimitiveDataType::F32)
                        );
                    }
                    other => panic!("expected Group inside List, got {:?}", other),
                }
            }
            other => panic!("expected List, got {:?}", other),
        }
    }

    /// `///` comments on a struct and its fields end up as schema/field documentation.
    #[test]
    fn test_doc_strings_preserved() {
        let builder = builder_from_tokens(quote! {
            #[config]
            /// This is Foo
            struct Foo {
                /// The id
                id: u32,
                name: String,
            }
        });

        let schema = builder.schema_for("Foo").unwrap();
        assert_eq!(schema.documentation.as_deref(), Some("This is Foo"));
        assert_eq!(schema.fields().len(), 2);
        assert_eq!(schema.fields()[0].documentation.as_deref(), Some("The id"));
        assert!(schema.fields()[1].documentation.is_none());
    }

    /// A type that is not an annotated struct of the file resolves to an empty `Group`.
    #[test]
    fn test_unknown_struct_empty_group() {
        let builder = builder_from_tokens(quote! {
            #[config]
            struct Wrapper {
                inner: SomeExternalThing,
            }
        });

        let schema = builder.schema_for("Wrapper").unwrap();
        let inner = &schema.fields()[0];
        assert_eq!(inner.data_type, PyroType::Group(Cow::Owned(vec![])));
    }

    /// A self-referential struct is expanded once, then cut off with an empty `Group`.
    #[test]
    fn test_cycle_guard() {
        let builder = builder_from_tokens(quote! {
            // Directly self-referential; would recurse forever without the guard.
            #[config]
            struct A {
                next: A,
            }
        });

        let schema = builder.schema_for("A").unwrap();
        assert_eq!(schema.fields().len(), 1);
        let next_field = &schema.fields()[0];
        assert_eq!(next_field.name(), "next");

        match &next_field.data_type {
            // The root struct is not marked as visited, so `next` is expanded one level.
            PyroType::Group(inner_fields) => {
                assert_eq!(inner_fields.len(), 1);
                assert_eq!(inner_fields[0].name(), "next");
                // The second occurrence of `A` hits the guard.
                assert_eq!(
                    inner_fields[0].data_type,
                    PyroType::Group(Cow::Owned(vec![]))
                );
            }
            other => panic!("expected Group for A's next field, got {:?}", other),
        }
    }

    /// `HashMap<String, Struct>` becomes a `Map` with a string key and a `Group` value.
    #[test]
    fn test_resolve_map_of_struct() {
        let builder = builder_from_tokens(quote! {
            #[config]
            struct Config {
                key: String,
            }
            #[config]
            struct Registry {
                entries: HashMap<String, Config>,
            }
        });

        let schema = builder.schema_for("Registry").unwrap();
        let entries = &schema.fields()[0];

        match &entries.data_type {
            PyroType::Map { key, value } => {
                assert_eq!(key.as_ref(), &PyroType::Str);
                match value.as_ref() {
                    PyroType::Group(fields) => {
                        assert_eq!(fields.len(), 1);
                        assert_eq!(fields[0].name(), "key");
                    }
                    other => panic!("expected Group for Config, got {:?}", other),
                }
            }
            other => panic!("expected Map, got {:?}", other),
        }
    }

    /// Tuple-struct fields are named by their zero-based position.
    #[test]
    fn test_tuple_struct_fields_are_positional() {
        let builder = builder_from_tokens(quote! {
            #[config]
            struct Pair(u32, String);
        });

        let schema = builder.schema_for("Pair").unwrap();
        assert_eq!(schema.fields().len(), 2);
        assert_eq!(schema.fields()[0].name(), "0");
        assert_eq!(
            schema.fields()[0].data_type,
            PyroType::PrimitiveScalar(PrimitiveDataType::U32)
        );
        assert_eq!(schema.fields()[1].name(), "1");
        assert_eq!(schema.fields()[1].data_type, PyroType::Str);
    }

    /// Only `#[config]` / `#[magma]` structs are indexed.
    #[test]
    fn test_unannotated_struct_is_ignored() {
        let builder = builder_from_tokens(quote! {
            struct Plain {
                x: u32,
            }
            #[magma]
            struct Tagged {
                y: bool,
            }
        });

        assert!(builder.schema_for("Plain").is_none());
        assert!(builder.schema_for("Tagged").is_some());
    }

    /// References, `Result`, `Bytes` and the unit type map as documented.
    #[test]
    fn test_resolve_wrappers_and_special_types() {
        let builder = builder_from_tokens(quote! {});

        let ty: syn::Type = parse2(quote!(&'static str)).unwrap();
        assert_eq!(builder.resolve_type(&ty), PyroType::Str);

        let ty: syn::Type = parse2(quote!(Result<u64, String>)).unwrap();
        assert_eq!(
            builder.resolve_type(&ty),
            PyroType::PrimitiveScalar(PrimitiveDataType::U64)
        );

        let ty: syn::Type = parse2(quote!(Bytes)).unwrap();
        assert_eq!(
            builder.resolve_type(&ty),
            PyroType::PrimitiveList(PrimitiveDataType::U8)
        );

        let ty: syn::Type = parse2(quote!(())).unwrap();
        assert_eq!(builder.resolve_type(&ty), PyroType::Null);
    }
}