1use anyhow::{Context, Result};
6use proc_macro2::{Ident, Span, TokenStream};
7use quote::quote;
8use std::fs;
9use std::path::{Path, PathBuf};
10
11use crate::schema::{ResolvedSchema, SchemaRegistry};
12
/// Generates Rust source modules from the schemas held in a [`SchemaRegistry`].
pub struct CodeGenerator {
    /// Schema definitions that drive code generation.
    registry: SchemaRegistry,
    /// Directory into which the generated module files are written.
    output_dir: PathBuf,
}
18
19impl CodeGenerator {
20 pub fn new(registry: SchemaRegistry, output_dir: impl AsRef<Path>) -> Self {
22 Self {
23 registry,
24 output_dir: output_dir.as_ref().to_path_buf(),
25 }
26 }
27
28 pub fn generate_all(&self) -> Result<()> {
30 fs::create_dir_all(&self.output_dir).context(format!(
32 "Failed to create output directory: {:?}",
33 self.output_dir
34 ))?;
35
36 println!("\nGenerating code...");
37
38 let entities_code = self.generate_entity_structs()?;
40 self.write_module("entities.rs", entities_code)?;
41 println!(" ✓ entities.rs");
42
43 let enum_code = self.generate_ftm_entity_enum()?;
45 self.write_module("ftm_entity.rs", enum_code)?;
46 println!(" ✓ ftm_entity.rs");
47
48 let traits_code = self.generate_traits()?;
50 self.write_module("traits.rs", traits_code)?;
51 println!(" ✓ traits.rs");
52
53 let trait_impls_code = self.generate_trait_implementations()?;
55 self.write_module("trait_impls.rs", trait_impls_code)?;
56 println!(" ✓ trait_impls.rs");
57
58 let mod_code = self.generate_mod_file();
60 self.write_module("mod.rs", mod_code)?;
61 println!(" ✓ mod.rs");
62
63 Ok(())
64 }
65
66 fn generate_entity_structs(&self) -> Result<TokenStream> {
68 let mut structs = Vec::new();
69
70 for schema_name in self.registry.schema_names() {
71 let resolved = self.registry.resolve_inheritance(&schema_name)?;
72
73 if resolved.is_abstract() {
75 continue;
76 }
77
78 let struct_code = self.generate_entity_struct(&resolved)?;
79 structs.push(struct_code);
80 }
81
82 Ok(quote! {
83 #![allow(missing_docs)]
85
86 use serde::{Deserialize, Serialize};
87
88 #[cfg(feature = "builder")] use bon::Builder;
89
90 fn deserialize_f64_vec<'de, D>(deserializer: D) -> Result<Vec<f64>, D::Error>
93 where
94 D: serde::Deserializer<'de>,
95 {
96 Vec::<serde_json::Value>::deserialize(deserializer)?
97 .into_iter()
98 .map(|v| match v {
99 serde_json::Value::Number(n) => {
100 n.as_f64().ok_or_else(|| serde::de::Error::custom("number out of f64 range"))
101 }
102 serde_json::Value::String(s) => {
103 s.parse::<f64>().map_err(serde::de::Error::custom)
104 }
105 other => Err(serde::de::Error::custom(
106 format!("expected number or numeric string, got {other}")
107 )),
108 })
109 .collect()
110 }
111
112 fn deserialize_opt_f64_vec<'de, D>(deserializer: D) -> Result<Option<Vec<f64>>, D::Error>
116 where
117 D: serde::Deserializer<'de>,
118 {
119 deserialize_f64_vec(deserializer).map(Some)
120 }
121
122 #(#structs)*
123 })
124 }
125
126 fn generate_entity_struct(&self, schema: &ResolvedSchema) -> Result<TokenStream> {
128 let struct_name = Ident::new(&schema.name, Span::call_site());
129 let label = schema.label().unwrap_or(&schema.name);
130 let doc_comment = format!("FTM Schema: {}", label);
131 let schema_name_str = &schema.name;
132
133 let mut fields = Vec::new();
135
136 fields.push(quote! {
138 pub id: String
139 });
140
141 let schema_lit = proc_macro2::Literal::string(schema_name_str);
145 fields.push(quote! {
146 #[cfg_attr(feature = "builder", builder(default = #schema_lit.to_string()))]
147 pub schema: String
148 });
149
150 let mut property_names: Vec<_> = schema.all_properties.keys().collect();
152 property_names.sort();
153
154 for prop_name in &property_names {
155 let property = &schema.all_properties[*prop_name];
156 let field_name = self.property_to_field_name(prop_name);
157
158 let prop_type = property.type_.as_deref().unwrap_or("string");
160
161 let is_required = schema.all_required.contains(*prop_name);
162 let field_type = self.map_property_type(prop_type, is_required);
163
164 let field_doc = if let Some(label) = &property.label {
165 format!("Property: {}", label)
166 } else {
167 format!("Property: {}", prop_name)
168 };
169
170 let serde_attr = match (prop_type, is_required) {
173 ("number", true) => {
175 quote! { #[serde(deserialize_with = "deserialize_f64_vec", default)] }
176 }
177 ("number", false) => {
179 quote! { #[serde(skip_serializing_if = "Option::is_none", deserialize_with = "deserialize_opt_f64_vec", default)] }
180 }
181 (_, true) => quote! { #[serde(default)] },
185 (_, false) => quote! { #[serde(skip_serializing_if = "Option::is_none")] },
186 };
187
188 fields.push(quote! {
189 #[doc = #field_doc]
190 #serde_attr
191 pub #field_name: #field_type
192 });
193 }
194
195 let mut field_inits = vec![
197 quote! { id: id.into() },
198 quote! { schema: #schema_name_str.to_string() },
199 ];
200
201 for prop_name in &property_names {
203 let property = &schema.all_properties[*prop_name];
204 let field_name = self.property_to_field_name(prop_name);
205
206 let prop_type = property.type_.as_deref().unwrap_or("string");
207
208 let is_required = schema.all_required.contains(*prop_name);
209
210 let init_value = if is_required {
211 match prop_type {
213 "json" => quote! { serde_json::Value::Object(serde_json::Map::new()) },
214 _ => quote! { Vec::new() },
215 }
216 } else {
217 quote! { None }
219 };
220
221 field_inits.push(quote! { #field_name: #init_value });
222 }
223
224 Ok(quote! {
225 #[doc = #doc_comment]
226 #[derive(Debug, Clone, Serialize, Deserialize)]
227 #[cfg_attr(feature = "builder", derive(Builder))]
228 #[serde(rename_all = "camelCase")]
229 pub struct #struct_name {
230 #(#fields),*
231 }
232
233 impl #struct_name {
234 #[deprecated(note = "Use the builder() method instead to ensure required fields are set")]
236 pub fn new(id: impl Into<String>) -> Self {
237 Self {
238 #(#field_inits),*
239 }
240 }
241
242 pub fn schema_name() -> &'static str {
244 #schema_name_str
245 }
246 }
247 })
248 }
249
    /// Builds the token stream for the `FtmEntity` enum module.
    ///
    /// For every non-abstract schema this emits an enum variant wrapping the
    /// struct of the same name, a match arm for `schema()` and `id()`, a
    /// dispatch arm for `from_ftm_json`, and a `From<Struct> for FtmEntity`
    /// impl. `TryFrom<String>` and `TryFrom<&str>` impls that forward to
    /// `from_ftm_json` are emitted as well.
    ///
    /// # Errors
    ///
    /// Propagates failures from inheritance resolution.
    fn generate_ftm_entity_enum(&self) -> Result<TokenStream> {
        // One entry per concrete schema in each of these, kept in lock-step.
        let mut variants = Vec::new();
        let mut match_schema_arms = Vec::new();
        let mut match_id_arms = Vec::new();
        let mut dispatch_arms = Vec::new();
        let mut from_impls = Vec::new();

        for schema_name in self.registry.schema_names() {
            let resolved = self.registry.resolve_inheritance(&schema_name)?;

            // Abstract schemas get no enum variant.
            if resolved.is_abstract() {
                continue;
            }

            // The variant and the wrapped struct share the schema's name.
            let variant_name = Ident::new(&schema_name, Span::call_site());
            let type_name = Ident::new(&schema_name, Span::call_site());

            variants.push(quote! {
                #variant_name(#type_name)
            });

            // `#schema_name` interpolates as a string literal here.
            match_schema_arms.push(quote! {
                FtmEntity::#variant_name(_) => #schema_name
            });

            match_id_arms.push(quote! {
                FtmEntity::#variant_name(entity) => &entity.id
            });

            dispatch_arms.push(quote! {
                #schema_name => Ok(FtmEntity::#variant_name(serde_json::from_value(value)?))
            });

            from_impls.push(quote! {
                impl From<#type_name> for FtmEntity {
                    fn from(entity: #type_name) -> Self {
                        FtmEntity::#variant_name(entity)
                    }
                }
            });
        }

        // NOTE: only `//` comments are used inside `quote!`; `///` would be
        // emitted into the generated file as `#[doc]` attributes.
        Ok(quote! {
            #![allow(missing_docs)]

            use super::entities::*;
            use serde::{Deserialize, Serialize};
            use serde_json::Value;

            #[derive(Debug, Clone, Serialize, Deserialize)]
            #[serde(untagged)]
            #[allow(clippy::large_enum_variant)]
            pub enum FtmEntity {
                #(#variants),*
            }

            impl FtmEntity {
                pub fn schema(&self) -> &str {
                    match self {
                        #(#match_schema_arms),*
                    }
                }

                pub fn id(&self) -> &str {
                    match self {
                        #(#match_id_arms),*
                    }
                }

                pub fn from_ftm_json(json_str: &str) -> Result<Self, serde_json::Error> {
                    let mut value: Value = serde_json::from_str(json_str)?;

                    // Hoist the entries of a nested "properties" object to the
                    // top level; the "properties" key itself is removed.
                    if let Some(obj) = value.as_object_mut()
                        && let Some(properties) = obj.remove("properties")
                        && let Some(props_obj) = properties.as_object()
                    {
                        for (key, val) in props_obj {
                            obj.insert(key.clone(), val.clone());
                        }
                    }

                    // A missing or non-string "schema" becomes "" and therefore
                    // falls through to the unknown-schema error below.
                    let schema = value
                        .get("schema")
                        .and_then(|v| v.as_str())
                        .unwrap_or("");

                    match schema {
                        #(#dispatch_arms,)*
                        _ => Err(serde::de::Error::custom(
                            format!("unknown FTM schema: {schema:?}")
                        )),
                    }
                }
            }

            impl TryFrom<String> for FtmEntity {
                type Error = serde_json::Error;

                fn try_from(s: String) -> Result<Self, Self::Error> {
                    Self::from_ftm_json(&s)
                }
            }

            impl TryFrom<&str> for FtmEntity {
                type Error = serde_json::Error;

                fn try_from(s: &str) -> Result<Self, Self::Error> {
                    Self::from_ftm_json(s)
                }
            }

            #(#from_impls)*
        })
    }
387
    /// Builds the token stream for `mod.rs`: declares the four generated
    /// submodules and re-exports the entities, `FtmEntity`, and the traits.
    fn generate_mod_file(&self) -> TokenStream {
        quote! {
            #![allow(missing_docs)]

            pub mod entities;
            pub mod ftm_entity;
            pub mod trait_impls;
            pub mod traits;

            // `trait_impls` is declared but not re-exported.
            pub use entities::*;
            pub use ftm_entity::FtmEntity;
            pub use traits::*;
        }
    }
404
405 fn generate_traits(&self) -> Result<TokenStream> {
407 let mut traits = Vec::new();
408
409 for schema_name in self.registry.schema_names() {
410 let schema = self
411 .registry
412 .get(&schema_name)
413 .context(format!("Schema not found: {}", schema_name))?;
414
415 if !schema.abstract_.unwrap_or(false) {
417 continue;
418 }
419
420 let trait_code = self.generate_trait(&schema_name, schema)?;
421 traits.push(trait_code);
422 }
423
424 Ok(quote! {
425 #![allow(missing_docs)]
427
428 #(#traits)*
434 })
435 }
436
437 fn generate_trait(
439 &self,
440 schema_name: &str,
441 schema: &crate::schema::FtmSchema,
442 ) -> Result<TokenStream> {
443 let trait_name = Ident::new(schema_name, Span::call_site());
444 let doc_comment = format!(
445 "Trait for FTM schema: {}",
446 schema.label.as_deref().unwrap_or(schema_name)
447 );
448
449 let parent_traits: Vec<TokenStream> = if let Some(extends) = &schema.extends {
451 extends
452 .iter()
453 .map(|parent| {
454 let parent_ident = Ident::new(parent, Span::call_site());
455 quote! { #parent_ident }
456 })
457 .collect()
458 } else {
459 vec![]
460 };
461
462 let trait_bounds = if parent_traits.is_empty() {
463 quote! {}
464 } else {
465 quote! { : #(#parent_traits)+* }
466 };
467
468 let mut methods = Vec::new();
470
471 methods.push(quote! {
473 fn id(&self) -> &str;
475 });
476
477 methods.push(quote! {
478 fn schema(&self) -> &str;
480 });
481
482 let mut property_names: Vec<_> = schema.properties.keys().collect();
484 property_names.sort();
485
486 for prop_name in property_names {
487 let property = &schema.properties[prop_name];
488 let method_name = self.property_to_field_name(prop_name);
489
490 let prop_type = property.type_.as_deref().unwrap_or("string");
491
492 let return_type = match prop_type {
493 "number" => quote! { Option<&[f64]> },
494 "json" => quote! { Option<&serde_json::Value> },
495 _ => quote! { Option<&[String]> },
496 };
497
498 let method_doc = if let Some(label) = &property.label {
499 format!("Get {} property", label)
500 } else {
501 format!("Get {} property", prop_name)
502 };
503
504 methods.push(quote! {
505 #[doc = #method_doc]
506 fn #method_name(&self) -> #return_type;
507 });
508 }
509
510 Ok(quote! {
511 #[doc = #doc_comment]
512 pub trait #trait_name #trait_bounds {
513 #(#methods)*
514 }
515 })
516 }
517
518 fn generate_trait_implementations(&self) -> Result<TokenStream> {
520 let mut impls = Vec::new();
521
522 for schema_name in self.registry.schema_names() {
523 let resolved = self.registry.resolve_inheritance(&schema_name)?;
524
525 if resolved.is_abstract() {
527 continue;
528 }
529
530 let impl_code = self.generate_trait_impls_for_entity(&resolved)?;
531 impls.extend(impl_code);
532 }
533
534 Ok(quote! {
535 #![allow(missing_docs)]
537
538 use super::entities::*;
539 use super::traits::*;
540
541 #(#impls)*
542 })
543 }
544
545 fn generate_trait_impls_for_entity(&self, schema: &ResolvedSchema) -> Result<Vec<TokenStream>> {
547 let mut impls = Vec::new();
548 let struct_name = Ident::new(&schema.name, Span::call_site());
549
550 let parent_schemas = self.get_all_parent_schemas(&schema.name)?;
552
553 for parent_name in parent_schemas {
555 let parent_schema = self
556 .registry
557 .get(&parent_name)
558 .context(format!("Parent schema not found: {}", parent_name))?;
559
560 if !parent_schema.abstract_.unwrap_or(false) {
562 continue;
563 }
564
565 let trait_name = Ident::new(&parent_name, Span::call_site());
566 let mut methods = Vec::new();
567
568 methods.push(quote! {
570 fn id(&self) -> &str {
571 &self.id
572 }
573 });
574
575 methods.push(quote! {
576 fn schema(&self) -> &str {
577 &self.schema
578 }
579 });
580
581 let mut property_names: Vec<_> = parent_schema.properties.keys().collect();
583 property_names.sort();
584
585 for prop_name in property_names {
586 let property = &parent_schema.properties[prop_name];
587 let method_name = self.property_to_field_name(prop_name);
588 let field_name = self.property_to_field_name(prop_name);
589
590 let prop_type = property.type_.as_deref().unwrap_or("string");
591
592 let is_required = schema.all_required.contains(prop_name);
594
595 let method_impl = if is_required {
596 match prop_type {
598 "number" => quote! {
599 fn #method_name(&self) -> Option<&[f64]> {
600 Some(&self.#field_name)
601 }
602 },
603 "json" => quote! {
604 fn #method_name(&self) -> Option<&serde_json::Value> {
605 Some(&self.#field_name)
606 }
607 },
608 _ => quote! {
609 fn #method_name(&self) -> Option<&[String]> {
610 Some(&self.#field_name)
611 }
612 },
613 }
614 } else {
615 match prop_type {
617 "number" => quote! {
618 fn #method_name(&self) -> Option<&[f64]> {
619 self.#field_name.as_deref()
620 }
621 },
622 "json" => quote! {
623 fn #method_name(&self) -> Option<&serde_json::Value> {
624 self.#field_name.as_ref()
625 }
626 },
627 _ => quote! {
628 fn #method_name(&self) -> Option<&[String]> {
629 self.#field_name.as_deref()
630 }
631 },
632 }
633 };
634
635 methods.push(method_impl);
636 }
637
638 impls.push(quote! {
639 impl #trait_name for #struct_name {
640 #(#methods)*
641 }
642 });
643 }
644
645 Ok(impls)
646 }
647
648 fn get_all_parent_schemas(&self, schema_name: &str) -> Result<Vec<String>> {
650 let mut parents_set = std::collections::HashSet::new();
651 let mut visited = std::collections::HashSet::new();
652 self.collect_parents_recursive(schema_name, &mut parents_set, &mut visited)?;
653
654 let mut parents: Vec<String> = parents_set.into_iter().collect();
656 parents.sort(); Ok(parents)
658 }
659
660 fn collect_parents_recursive(
662 &self,
663 schema_name: &str,
664 parents: &mut std::collections::HashSet<String>,
665 visited: &mut std::collections::HashSet<String>,
666 ) -> Result<()> {
667 if visited.contains(schema_name) {
668 return Ok(());
669 }
670 visited.insert(schema_name.to_string());
671
672 let schema = self
673 .registry
674 .get(schema_name)
675 .context(format!("Schema not found: {}", schema_name))?;
676
677 if let Some(extends) = &schema.extends {
678 for parent_name in extends {
679 parents.insert(parent_name.clone());
680 self.collect_parents_recursive(parent_name, parents, visited)?;
681 }
682 }
683
684 Ok(())
685 }
686
687 fn map_property_type(&self, ftm_type: &str, is_required: bool) -> TokenStream {
689 if is_required {
690 match ftm_type {
692 "number" => quote! { Vec<f64> },
693 "date" => quote! { Vec<String> },
694 "json" => quote! { serde_json::Value },
695 _ => quote! { Vec<String> },
696 }
697 } else {
698 match ftm_type {
700 "number" => quote! { Option<Vec<f64>> },
701 "date" => quote! { Option<Vec<String>> },
702 "json" => quote! { Option<serde_json::Value> },
703 _ => quote! { Option<Vec<String>> },
704 }
705 }
706 }
707
708 fn property_to_field_name(&self, prop_name: &str) -> Ident {
710 let snake_case = self.to_snake_case(prop_name);
712
713 let field_name = match snake_case.as_str() {
715 "type" => "type_".to_string(),
716 "match" => "match_".to_string(),
717 "ref" => "ref_".to_string(),
718 _ => snake_case,
719 };
720
721 Ident::new(&field_name, Span::call_site())
722 }
723
724 fn to_snake_case(&self, s: &str) -> String {
726 if s.to_uppercase() == s && s.len() <= 3 {
728 return s.to_lowercase();
730 }
731
732 let mut result = String::new();
733 let mut prev_is_upper = false;
734
735 for (i, ch) in s.chars().enumerate() {
736 if ch.is_uppercase() {
737 if i > 0 && !prev_is_upper {
738 result.push('_');
739 }
740 result.push(ch.to_lowercase().next().unwrap());
741 prev_is_upper = true;
742 } else {
743 result.push(ch);
744 prev_is_upper = false;
745 }
746 }
747
748 result
749 }
750
751 fn write_module(&self, filename: &str, tokens: TokenStream) -> Result<()> {
753 let path = self.output_dir.join(filename);
754
755 let content = match syn::parse2(tokens.clone()) {
759 Ok(syntax_tree) => prettyplease::unparse(&syntax_tree),
760 Err(_) => {
761 let raw = tokens.to_string();
763 fs::write(&path, &raw).context(format!("Failed to write file: {:?}", path))?;
764
765 let _result = std::process::Command::new("rustfmt").arg(&path).output();
767
768 return fs::read_to_string(&path)
770 .context("Failed to read formatted file")
771 .map(|_| ());
772 }
773 };
774
775 fs::write(&path, content).context(format!("Failed to write file: {:?}", path))?;
776
777 let _result = std::process::Command::new("rustfmt").arg(&path).output();
779
780 Ok(())
781 }
782}
783
784#[cfg(test)]
785mod tests {
786 use super::*;
787 use crate::{generated::Person, schema::SchemaRegistry};
788 use std::io::Write;
789 use tempfile::TempDir;
790
791 fn create_test_schema(dir: &std::path::Path, name: &str, yaml: &str) {
792 let path = dir.join(format!("{}.yml", name));
793 let mut file = fs::File::create(path).unwrap();
794 file.write_all(yaml.as_bytes()).unwrap();
795 }
796
    /// End-to-end smoke test: generates code for an abstract `Thing` and a
    /// concrete `Person` that extends it, then checks that all five output
    /// files were created.
    #[test]
    fn test_code_generation() {
        let temp_dir = TempDir::new().unwrap();

        // Abstract base schema.
        create_test_schema(
            temp_dir.path(),
            "Thing",
            r#"
label: Thing
abstract: true
properties:
  name:
    label: Name
    type: name
"#,
        );

        // Concrete schema extending the base.
        create_test_schema(
            temp_dir.path(),
            "Person",
            r#"
label: Person
extends:
  - Thing
properties:
  firstName:
    label: First Name
    type: name
"#,
        );

        let registry = SchemaRegistry::load_from_cache(temp_dir.path()).unwrap();
        let output_dir = temp_dir.path().join("generated");
        let codegen = CodeGenerator::new(registry, &output_dir);

        let result = codegen.generate_all();
        assert!(result.is_ok(), "Code generation failed: {:?}", result);

        // Only file existence is checked here, not file contents.
        assert!(output_dir.join("mod.rs").exists());
        assert!(output_dir.join("entities.rs").exists());
        assert!(output_dir.join("ftm_entity.rs").exists());
        assert!(output_dir.join("traits.rs").exists());
        assert!(output_dir.join("trait_impls.rs").exists());
    }
842
843 #[test]
844 fn test_snake_case_conversion() {
845 let temp_dir = TempDir::new().unwrap();
846
847 create_test_schema(
848 temp_dir.path(),
849 "Thing",
850 r#"
851label: Thing
852properties: {}
853"#,
854 );
855
856 let registry = SchemaRegistry::load_from_cache(temp_dir.path()).unwrap();
857 let codegen = CodeGenerator::new(registry, "/tmp/test");
858
859 assert_eq!(codegen.to_snake_case("firstName"), "first_name");
860 assert_eq!(codegen.to_snake_case("birthDate"), "birth_date");
861 assert_eq!(codegen.to_snake_case("name"), "name");
862 assert_eq!(codegen.to_snake_case("ID"), "id");
863 assert_eq!(codegen.to_snake_case("API"), "api");
864 }
865
    /// Generates code for a three-level hierarchy (`Thing` -> `LegalEntity` ->
    /// `Person` / `Company`) and checks the generated traits, trait impls and
    /// structs by searching the output files for expected snippets.
    #[test]
    fn test_trait_generation() {
        let temp_dir = TempDir::new().unwrap();

        // Abstract root schema.
        create_test_schema(
            temp_dir.path(),
            "Thing",
            r#"
label: Thing
abstract: true
properties:
  name:
    label: Name
    type: name
  description:
    label: Description
    type: text
"#,
        );

        // Abstract intermediate schema.
        create_test_schema(
            temp_dir.path(),
            "LegalEntity",
            r#"
label: Legal Entity
abstract: true
extends:
  - Thing
properties:
  country:
    label: Country
    type: country
"#,
        );

        // Two concrete schemas sharing the same abstract parent.
        create_test_schema(
            temp_dir.path(),
            "Person",
            r#"
label: Person
extends:
  - LegalEntity
properties:
  firstName:
    label: First Name
    type: name
"#,
        );

        create_test_schema(
            temp_dir.path(),
            "Company",
            r#"
label: Company
extends:
  - LegalEntity
properties:
  registrationNumber:
    label: Registration Number
    type: identifier
"#,
        );

        let registry = SchemaRegistry::load_from_cache(temp_dir.path()).unwrap();
        let output_dir = temp_dir.path().join("generated");
        let codegen = CodeGenerator::new(registry, &output_dir);

        let result = codegen.generate_all();
        assert!(result.is_ok(), "Code generation failed: {:?}", result);

        // Both abstract schemas must appear as traits with their own getters.
        let traits_content = fs::read_to_string(output_dir.join("traits.rs")).unwrap();
        assert!(traits_content.contains("pub trait Thing"));
        assert!(traits_content.contains("pub trait LegalEntity"));
        assert!(traits_content.contains("fn name(&self)"));
        assert!(traits_content.contains("fn country(&self)"));

        // Each concrete schema implements every abstract ancestor, including
        // the transitive one.
        let trait_impls_content = fs::read_to_string(output_dir.join("trait_impls.rs")).unwrap();
        assert!(trait_impls_content.contains("impl Thing for Person"));
        assert!(trait_impls_content.contains("impl LegalEntity for Person"));
        assert!(trait_impls_content.contains("impl Thing for Company"));
        assert!(trait_impls_content.contains("impl LegalEntity for Company"));

        // Inherited, non-required properties show up as optional struct fields.
        let entities_content = fs::read_to_string(output_dir.join("entities.rs")).unwrap();
        assert!(entities_content.contains("pub struct Person"));
        assert!(entities_content.contains("pub struct Company"));
        assert!(entities_content.contains("pub name: Option<Vec<String>>"));
        assert!(entities_content.contains("pub country: Option<Vec<String>>"));
    }
960
961 #[test]
962 fn test_builder() {
963 let _person = Person::builder()
964 .name(vec!["Huh".to_string()])
965 .height(vec![123.45]);
966 }
967}