use anyhow::{Context, Result};
use proc_macro2::{Ident, Span, TokenStream};
use quote::quote;
use std::fs;
use std::path::{Path, PathBuf};
use crate::schema::{ResolvedSchema, SchemaRegistry};
/// Generates Rust source modules from the schemas held in a [`SchemaRegistry`].
pub struct CodeGenerator {
/// Schemas the code is generated from.
registry: SchemaRegistry,
/// Directory the generated module files are written into.
output_dir: PathBuf,
}
impl CodeGenerator {
pub fn new(registry: SchemaRegistry, output_dir: impl AsRef<Path>) -> Self {
Self {
registry,
output_dir: output_dir.as_ref().to_path_buf(),
}
}
pub fn generate_all(&self) -> Result<()> {
fs::create_dir_all(&self.output_dir).context(format!(
"Failed to create output directory: {:?}",
self.output_dir
))?;
println!("\nGenerating code...");
let entities_code = self.generate_entity_structs()?;
self.write_module("entities.rs", entities_code)?;
println!(" ✓ entities.rs");
let enum_code = self.generate_ftm_entity_enum()?;
self.write_module("ftm_entity.rs", enum_code)?;
println!(" ✓ ftm_entity.rs");
let traits_code = self.generate_traits()?;
self.write_module("traits.rs", traits_code)?;
println!(" ✓ traits.rs");
let trait_impls_code = self.generate_trait_implementations()?;
self.write_module("trait_impls.rs", trait_impls_code)?;
println!(" ✓ trait_impls.rs");
let mod_code = self.generate_mod_file();
self.write_module("mod.rs", mod_code)?;
println!(" ✓ mod.rs");
Ok(())
}
/// Builds the token stream for `entities.rs`.
///
/// The module holds one struct per concrete (non-abstract) schema, in the
/// order `schema_names()` yields them, preceded by the serde helper
/// functions that number-typed fields reference through `deserialize_with`.
///
/// Generated helpers:
/// - `json_value_to_f64` converts one JSON number or numeric string to `f64`.
/// - `deserialize_f64_vec` reads an array of such values.
/// - `deserialize_opt_f64_vec` does the same for optional fields and maps an
///   explicit JSON `null` to `None`, matching how the other `Option` fields
///   of the generated structs behave under serde.
///
/// # Errors
///
/// Propagates failures from inheritance resolution and struct generation.
fn generate_entity_structs(&self) -> Result<TokenStream> {
    let mut structs = Vec::new();
    for schema_name in self.registry.schema_names() {
        let resolved = self.registry.resolve_inheritance(&schema_name)?;
        // Abstract schemas become traits, not structs.
        if !resolved.is_abstract() {
            structs.push(self.generate_entity_struct(&resolved)?);
        }
    }
    Ok(quote! {
        #![allow(missing_docs)]
        use serde::{Deserialize, Serialize};
        #[cfg(feature = "builder")] use bon::Builder;
        fn json_value_to_f64<E>(value: serde_json::Value) -> Result<f64, E>
        where
            E: serde::de::Error,
        {
            match value {
                serde_json::Value::Number(n) => {
                    n.as_f64().ok_or_else(|| E::custom("number out of f64 range"))
                }
                serde_json::Value::String(s) => {
                    s.parse::<f64>().map_err(E::custom)
                }
                other => Err(E::custom(
                    format!("expected number or numeric string, got {other}")
                )),
            }
        }
        fn deserialize_f64_vec<'de, D>(deserializer: D) -> Result<Vec<f64>, D::Error>
        where
            D: serde::Deserializer<'de>,
        {
            Vec::<serde_json::Value>::deserialize(deserializer)?
                .into_iter()
                .map(json_value_to_f64)
                .collect::<Result<Vec<f64>, D::Error>>()
        }
        fn deserialize_opt_f64_vec<'de, D>(deserializer: D) -> Result<Option<Vec<f64>>, D::Error>
        where
            D: serde::Deserializer<'de>,
        {
            Option::<Vec<serde_json::Value>>::deserialize(deserializer)?
                .map(|values| {
                    values
                        .into_iter()
                        .map(json_value_to_f64)
                        .collect::<Result<Vec<f64>, D::Error>>()
                })
                .transpose()
        }
        #(#structs)*
    })
}
/// Builds the struct definition and inherent `impl` for one concrete schema.
///
/// The struct always starts with `id` and `schema`, followed by one field per
/// resolved property in alphabetical order of the property name. Required
/// properties are emitted as plain values, all others as `Option`. The `impl`
/// carries a deprecated `new`, `schema_name`, and `to_ftm_json`, which nests
/// every field except `id` and `schema` under a `properties` key.
fn generate_entity_struct(&self, schema: &ResolvedSchema) -> Result<TokenStream> {
    let struct_name = Ident::new(&schema.name, Span::call_site());
    let schema_name_str = &schema.name;
    let doc_comment = format!("FTM Schema: {}", schema.label().unwrap_or(&schema.name));
    let schema_lit = proc_macro2::Literal::string(schema_name_str);

    // Field declarations and the matching initialisers for `new` are built
    // side by side so they stay in the same order.
    let mut fields = vec![
        quote! {
            pub id: String
        },
        quote! {
            #[cfg_attr(feature = "builder", builder(default = #schema_lit.to_string()))]
            pub schema: String
        },
    ];
    let mut field_inits = vec![
        quote! { id: id.into() },
        quote! { schema: #schema_name_str.to_string() },
    ];

    let mut property_names: Vec<_> = schema.all_properties.keys().collect();
    property_names.sort();

    for prop_name in property_names {
        let property = &schema.all_properties[prop_name];
        let field_name = self.property_to_field_name(prop_name);
        // Properties without an explicit type are treated as strings.
        let prop_type = property.type_.as_deref().unwrap_or("string");
        let is_required = schema.all_required.contains(prop_name);
        let field_type = self.map_property_type(prop_type, is_required);

        let field_doc = match &property.label {
            Some(label) => format!("Property: {}", label),
            None => format!("Property: {}", prop_name),
        };

        let serde_attr = match (prop_type, is_required) {
            ("number", true) => {
                quote! { #[serde(deserialize_with = "deserialize_f64_vec", default)] }
            }
            ("number", false) => {
                quote! { #[serde(skip_serializing_if = "Option::is_none", deserialize_with = "deserialize_opt_f64_vec", default)] }
            }
            (_, true) => quote! { #[serde(default)] },
            (_, false) => quote! { #[serde(skip_serializing_if = "Option::is_none")] },
        };

        fields.push(quote! {
            #[doc = #field_doc]
            #serde_attr
            pub #field_name: #field_type
        });

        let init_value = match (is_required, prop_type) {
            (false, _) => quote! { None },
            (true, "json") => quote! { serde_json::Value::Object(serde_json::Map::new()) },
            (true, _) => quote! { Vec::new() },
        };
        field_inits.push(quote! { #field_name: #init_value });
    }

    Ok(quote! {
        #[doc = #doc_comment]
        #[derive(Debug, Clone, Serialize, Deserialize)]
        #[cfg_attr(feature = "builder", derive(Builder))]
        #[serde(rename_all = "camelCase")]
        pub struct #struct_name {
            #(#fields),*
        }
        impl #struct_name {
            #[deprecated(note = "Use the builder() method instead to ensure required fields are set")]
            pub fn new(id: impl Into<String>) -> Self {
                Self {
                    #(#field_inits),*
                }
            }
            pub fn schema_name() -> &'static str {
                #schema_name_str
            }
            pub fn to_ftm_json(&self) -> Result<String, serde_json::Error> {
                let mut value = serde_json::to_value(self)?;
                if let Some(obj) = value.as_object_mut() {
                    let id = obj.remove("id");
                    let schema = obj.remove("schema");
                    let properties = serde_json::Value::Object(std::mem::take(obj));
                    if let Some(id) = id { obj.insert("id".into(), id); }
                    if let Some(schema) = schema { obj.insert("schema".into(), schema); }
                    obj.insert("properties".into(), properties);
                }
                serde_json::to_string(&value)
            }
        }
    })
}
/// Builds the token stream for `ftm_entity.rs`.
///
/// The module defines the `FtmEntity` enum with one variant per concrete
/// schema (each variant wraps the struct of the same name), accessors for
/// `schema` and `id`, conversion to and from the nested FTM JSON layout, and
/// `From<Struct>` / `TryFrom<String>` / `TryFrom<&str>` implementations.
///
/// In the generated `from_ftm_json`, the entries of the `properties` object
/// are moved up to the top level before dispatching on `schema`. The
/// `properties` map is owned at that point, so its entries are moved with
/// `extend` rather than cloned one by one.
///
/// # Errors
///
/// Propagates failures from inheritance resolution.
fn generate_ftm_entity_enum(&self) -> Result<TokenStream> {
    let mut variants = Vec::new();
    let mut match_schema_arms = Vec::new();
    let mut match_id_arms = Vec::new();
    let mut dispatch_arms = Vec::new();
    let mut from_impls = Vec::new();

    for schema_name in self.registry.schema_names() {
        let resolved = self.registry.resolve_inheritance(&schema_name)?;
        if resolved.is_abstract() {
            continue;
        }
        // The variant and the struct it wraps share the schema's name.
        let variant_name = Ident::new(&schema_name, Span::call_site());
        let type_name = Ident::new(&schema_name, Span::call_site());

        variants.push(quote! {
            #variant_name(#type_name)
        });
        match_schema_arms.push(quote! {
            FtmEntity::#variant_name(_) => #schema_name
        });
        match_id_arms.push(quote! {
            FtmEntity::#variant_name(entity) => &entity.id
        });
        dispatch_arms.push(quote! {
            #schema_name => Ok(FtmEntity::#variant_name(serde_json::from_value(value)?))
        });
        from_impls.push(quote! {
            impl From<#type_name> for FtmEntity {
                fn from(entity: #type_name) -> Self {
                    FtmEntity::#variant_name(entity)
                }
            }
        });
    }

    Ok(quote! {
        #![allow(missing_docs)]
        use super::entities::*;
        use serde::{Deserialize, Serialize};
        use serde_json::Value;
        #[derive(Debug, Clone, Serialize, Deserialize)]
        #[serde(untagged)]
        #[allow(clippy::large_enum_variant)]
        pub enum FtmEntity {
            #(#variants),*
        }
        impl FtmEntity {
            pub fn schema(&self) -> &str {
                match self {
                    #(#match_schema_arms),*
                }
            }
            pub fn id(&self) -> &str {
                match self {
                    #(#match_id_arms),*
                }
            }
            pub fn from_ftm_json(json_str: &str) -> Result<Self, serde_json::Error> {
                let mut value: Value = serde_json::from_str(json_str)?;
                if let Some(obj) = value.as_object_mut() {
                    if let Some(Value::Object(properties)) = obj.remove("properties") {
                        obj.extend(properties);
                    }
                }
                let schema = value
                    .get("schema")
                    .and_then(|v| v.as_str())
                    .unwrap_or("");
                match schema {
                    #(#dispatch_arms,)*
                    _ => Err(serde::de::Error::custom(
                        format!("unknown FTM schema: {schema:?}")
                    )),
                }
            }
            pub fn to_ftm_json(&self) -> Result<String, serde_json::Error> {
                let mut value = serde_json::to_value(self)?;
                if let Some(obj) = value.as_object_mut() {
                    let id = obj.remove("id");
                    let schema = obj.remove("schema");
                    let properties = serde_json::Value::Object(std::mem::take(obj));
                    if let Some(id) = id { obj.insert("id".into(), id); }
                    if let Some(schema) = schema { obj.insert("schema".into(), schema); }
                    obj.insert("properties".into(), properties);
                }
                serde_json::to_string(&value)
            }
        }
        impl TryFrom<String> for FtmEntity {
            type Error = serde_json::Error;
            fn try_from(s: String) -> Result<Self, Self::Error> {
                Self::from_ftm_json(&s)
            }
        }
        impl TryFrom<&str> for FtmEntity {
            type Error = serde_json::Error;
            fn try_from(s: &str) -> Result<Self, Self::Error> {
                Self::from_ftm_json(s)
            }
        }
        #(#from_impls)*
    })
}
/// Builds the token stream for `mod.rs`.
///
/// It declares the four generated submodules and re-exports the entity
/// structs, the `FtmEntity` enum and the traits.
fn generate_mod_file(&self) -> TokenStream {
    let module_root = quote! {
        #![allow(missing_docs)]
        pub mod entities;
        pub mod ftm_entity;
        pub mod trait_impls;
        pub mod traits;
        pub use entities::*;
        pub use ftm_entity::FtmEntity;
        pub use traits::*;
    };
    module_root
}
/// Builds the token stream for `traits.rs`: one trait per abstract schema,
/// in the order `schema_names()` yields them.
///
/// # Errors
///
/// Fails if a name returned by the registry cannot be looked up again, or if
/// generating an individual trait fails.
fn generate_traits(&self) -> Result<TokenStream> {
    let mut traits = Vec::new();
    for schema_name in self.registry.schema_names() {
        let schema = self
            .registry
            .get(&schema_name)
            .with_context(|| format!("Schema not found: {}", schema_name))?;
        // Only schemas explicitly flagged `abstract: true` become traits.
        if matches!(schema.abstract_, Some(true)) {
            traits.push(self.generate_trait(&schema_name, schema)?);
        }
    }
    Ok(quote! {
        #![allow(missing_docs)]
        #(#traits)*
    })
}
/// Builds the trait definition for one abstract schema.
///
/// Every entry in the schema's `extends` list becomes a supertrait. The trait
/// declares `id` and `schema`, then one getter per property declared directly
/// on this schema (inherited properties are not repeated), in alphabetical
/// order of the property name.
fn generate_trait(
    &self,
    schema_name: &str,
    schema: &crate::schema::FtmSchema,
) -> Result<TokenStream> {
    let trait_name = Ident::new(schema_name, Span::call_site());
    let doc_comment = format!(
        "Trait for FTM schema: {}",
        schema.label.as_deref().unwrap_or(schema_name)
    );

    let parent_traits: Vec<TokenStream> = match &schema.extends {
        Some(extends) => extends
            .iter()
            .map(|parent| {
                let parent_ident = Ident::new(parent, Span::call_site());
                quote! { #parent_ident }
            })
            .collect(),
        None => Vec::new(),
    };
    // `: A + B` when there are parents, nothing otherwise.
    let trait_bounds = if parent_traits.is_empty() {
        quote! {}
    } else {
        quote! { : #(#parent_traits)+* }
    };

    let mut methods = vec![
        quote! {
            fn id(&self) -> &str;
        },
        quote! {
            fn schema(&self) -> &str;
        },
    ];

    let mut property_names: Vec<_> = schema.properties.keys().collect();
    property_names.sort();
    for prop_name in property_names {
        let property = &schema.properties[prop_name];
        let method_name = self.property_to_field_name(prop_name);
        let return_type = match property.type_.as_deref().unwrap_or("string") {
            "number" => quote! { Option<&[f64]> },
            "json" => quote! { Option<&serde_json::Value> },
            _ => quote! { Option<&[String]> },
        };
        let method_doc = match &property.label {
            Some(label) => format!("Get {} property", label),
            None => format!("Get {} property", prop_name),
        };
        methods.push(quote! {
            #[doc = #method_doc]
            fn #method_name(&self) -> #return_type;
        });
    }

    Ok(quote! {
        #[doc = #doc_comment]
        pub trait #trait_name #trait_bounds {
            #(#methods)*
        }
    })
}
/// Builds the token stream for `trait_impls.rs`: for every concrete schema,
/// the `impl` blocks produced by `generate_trait_impls_for_entity`.
///
/// # Errors
///
/// Propagates failures from inheritance resolution and from generating the
/// individual `impl` blocks.
fn generate_trait_implementations(&self) -> Result<TokenStream> {
    let mut impls = Vec::new();
    for schema_name in self.registry.schema_names() {
        let resolved = self.registry.resolve_inheritance(&schema_name)?;
        // Abstract schemas have no struct to implement traits for.
        if !resolved.is_abstract() {
            impls.extend(self.generate_trait_impls_for_entity(&resolved)?);
        }
    }
    Ok(quote! {
        #![allow(missing_docs)]
        use super::entities::*;
        use super::traits::*;
        #(#impls)*
    })
}
/// Builds one `impl Trait for Struct` block per abstract ancestor of `schema`.
///
/// Ancestors are visited in sorted name order; non-abstract ancestors are
/// skipped. Each block implements `id`, `schema`, and a getter for every
/// property declared directly on that ancestor. A getter for a required
/// property wraps a borrow of the field in `Some`; otherwise it borrows the
/// optional field's contents.
///
/// # Errors
///
/// Fails if an ancestor name cannot be found in the registry.
fn generate_trait_impls_for_entity(&self, schema: &ResolvedSchema) -> Result<Vec<TokenStream>> {
    let struct_name = Ident::new(&schema.name, Span::call_site());
    let mut impls = Vec::new();

    for parent_name in self.get_all_parent_schemas(&schema.name)? {
        let parent_schema = self
            .registry
            .get(&parent_name)
            .with_context(|| format!("Parent schema not found: {}", parent_name))?;
        if !parent_schema.abstract_.unwrap_or(false) {
            continue;
        }
        let trait_name = Ident::new(&parent_name, Span::call_site());

        let mut methods = vec![
            quote! {
                fn id(&self) -> &str {
                    &self.id
                }
            },
            quote! {
                fn schema(&self) -> &str {
                    &self.schema
                }
            },
        ];

        let mut property_names: Vec<_> = parent_schema.properties.keys().collect();
        property_names.sort();
        for prop_name in property_names {
            let property = &parent_schema.properties[prop_name];
            // The getter and the struct field it reads share one identifier.
            let accessor = self.property_to_field_name(prop_name);
            let prop_type = property.type_.as_deref().unwrap_or("string");
            let is_required = schema.all_required.contains(prop_name);

            let return_type = match prop_type {
                "number" => quote! { Option<&[f64]> },
                "json" => quote! { Option<&serde_json::Value> },
                _ => quote! { Option<&[String]> },
            };
            let body = match (is_required, prop_type) {
                (true, _) => quote! { Some(&self.#accessor) },
                (false, "json") => quote! { self.#accessor.as_ref() },
                (false, _) => quote! { self.#accessor.as_deref() },
            };
            methods.push(quote! {
                fn #accessor(&self) -> #return_type {
                    #body
                }
            });
        }

        impls.push(quote! {
            impl #trait_name for #struct_name {
                #(#methods)*
            }
        });
    }
    Ok(impls)
}
/// Returns the names of all direct and transitive parents of `schema_name`,
/// without duplicates, sorted alphabetically.
///
/// # Errors
///
/// Fails if `schema_name` or any schema reached through `extends` is missing
/// from the registry.
fn get_all_parent_schemas(&self, schema_name: &str) -> Result<Vec<String>> {
    let mut found = std::collections::HashSet::new();
    let mut seen = std::collections::HashSet::new();
    self.collect_parents_recursive(schema_name, &mut found, &mut seen)?;

    // Hash-set iteration order is arbitrary; sort so the output is stable.
    let mut ordered: Vec<String> = found.into_iter().collect();
    ordered.sort();
    Ok(ordered)
}
/// Walks the `extends` chain of `schema_name` depth-first and records every
/// ancestor name in `parents`.
///
/// `visited` holds the schemas already expanded, so a schema reachable along
/// several paths (or through a cycle) is only processed once.
///
/// # Errors
///
/// Fails if a schema being expanded is not present in the registry.
fn collect_parents_recursive(
    &self,
    schema_name: &str,
    parents: &mut std::collections::HashSet<String>,
    visited: &mut std::collections::HashSet<String>,
) -> Result<()> {
    // `insert` returns false when the name was already present.
    if !visited.insert(schema_name.to_string()) {
        return Ok(());
    }
    let schema = self
        .registry
        .get(schema_name)
        .with_context(|| format!("Schema not found: {}", schema_name))?;
    let Some(extends) = &schema.extends else {
        return Ok(());
    };
    for parent_name in extends {
        parents.insert(parent_name.clone());
        self.collect_parents_recursive(parent_name, parents, visited)?;
    }
    Ok(())
}
/// Maps an FTM property type to the Rust type of the generated field.
///
/// `number` becomes `Vec<f64>`, `json` becomes `serde_json::Value`, and every
/// other type (including `date`) becomes `Vec<String>`. When the property is
/// not required the type is wrapped in `Option`.
fn map_property_type(&self, ftm_type: &str, is_required: bool) -> TokenStream {
    match (ftm_type, is_required) {
        ("number", true) => quote! { Vec<f64> },
        ("number", false) => quote! { Option<Vec<f64>> },
        ("json", true) => quote! { serde_json::Value },
        ("json", false) => quote! { Option<serde_json::Value> },
        (_, true) => quote! { Vec<String> },
        (_, false) => quote! { Option<Vec<String>> },
    }
}
/// Converts an FTM property name into the identifier used for the generated
/// struct field and trait getter.
///
/// The name is snake-cased first. If the result is a Rust keyword, which
/// could not be used as a field or method name, an underscore is appended
/// (`type` -> `type_`). All strict and reserved keywords are covered, not
/// just the few that happen to occur in today's schemas.
///
/// # Panics
///
/// `Ident::new` panics if the resulting string is not a valid identifier
/// (for example if it starts with a digit).
fn property_to_field_name(&self, prop_name: &str) -> Ident {
    // Strict and reserved keywords across editions. `gen` is only reserved
    // from the 2024 edition on. `union` and other weak keywords are valid
    // identifiers and are intentionally left alone.
    const RUST_KEYWORDS: &[&str] = &[
        "abstract", "as", "async", "await", "become", "box", "break", "const", "continue",
        "crate", "do", "dyn", "else", "enum", "extern", "false", "final", "fn", "for", "gen",
        "if", "impl", "in", "let", "loop", "macro", "match", "mod", "move", "mut", "override",
        "priv", "pub", "ref", "return", "self", "static", "struct", "super", "trait", "true",
        "try", "type", "typeof", "unsafe", "unsized", "use", "virtual", "where", "while",
        "yield",
    ];
    let mut field_name = self.to_snake_case(prop_name);
    if RUST_KEYWORDS.contains(&field_name.as_str()) {
        field_name.push('_');
    }
    Ident::new(&field_name, Span::call_site())
}
/// Converts a camelCase name to snake_case.
///
/// Names of at most three bytes that contain no lowercase letters (such as
/// `ID` or `API`) are simply lowercased. Otherwise an underscore is inserted
/// before an uppercase character unless it is the first character or
/// directly follows another uppercase character, and each uppercase
/// character is replaced by the first character of its lowercase mapping.
fn to_snake_case(&self, s: &str) -> String {
    if s.len() <= 3 && s.to_uppercase() == s {
        return s.to_lowercase();
    }
    let mut snake = String::with_capacity(s.len() + 4);
    let mut after_upper = false;
    for (index, ch) in s.chars().enumerate() {
        let is_upper = ch.is_uppercase();
        if is_upper {
            if index > 0 && !after_upper {
                snake.push('_');
            }
            // `to_lowercase` always yields at least one char; keep the first.
            snake.extend(ch.to_lowercase().take(1));
        } else {
            snake.push(ch);
        }
        after_upper = is_upper;
    }
    snake
}
fn write_module(&self, filename: &str, tokens: TokenStream) -> Result<()> {
let path = self.output_dir.join(filename);
let content = match syn::parse2(tokens.clone()) {
Ok(syntax_tree) => prettyplease::unparse(&syntax_tree),
Err(_) => {
let raw = tokens.to_string();
fs::write(&path, &raw).context(format!("Failed to write file: {:?}", path))?;
let _result = std::process::Command::new("rustfmt").arg(&path).output();
return fs::read_to_string(&path)
.context("Failed to read formatted file")
.map(|_| ());
}
};
fs::write(&path, content).context(format!("Failed to write file: {:?}", path))?;
let _result = std::process::Command::new("rustfmt").arg(&path).output();
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{generated::Person, schema::SchemaRegistry};
use std::io::Write;
use tempfile::TempDir;
/// Writes `yaml` to `<dir>/<name>.yml`, panicking on any I/O failure.
fn create_test_schema(dir: &std::path::Path, name: &str, yaml: &str) {
    let schema_path = dir.join(format!("{}.yml", name));
    fs::File::create(schema_path)
        .unwrap()
        .write_all(yaml.as_bytes())
        .unwrap();
}
/// Generates code for a minimal abstract/concrete schema pair and checks that
/// all five module files are produced.
#[test]
fn test_code_generation() {
    let temp_dir = TempDir::new().unwrap();
    // Property definitions must be nested under `properties`; at top level
    // they would collide with the schema's own `label` key.
    create_test_schema(
        temp_dir.path(),
        "Thing",
        r#"
label: Thing
abstract: true
properties:
  name:
    label: Name
    type: name
"#,
    );
    create_test_schema(
        temp_dir.path(),
        "Person",
        r#"
label: Person
extends:
  - Thing
properties:
  firstName:
    label: First Name
    type: name
"#,
    );
    let registry = SchemaRegistry::load_from_cache(temp_dir.path()).unwrap();
    let output_dir = temp_dir.path().join("generated");
    let codegen = CodeGenerator::new(registry, &output_dir);
    let result = codegen.generate_all();
    assert!(result.is_ok(), "Code generation failed: {:?}", result);
    assert!(output_dir.join("mod.rs").exists());
    assert!(output_dir.join("entities.rs").exists());
    assert!(output_dir.join("ftm_entity.rs").exists());
    assert!(output_dir.join("traits.rs").exists());
    assert!(output_dir.join("trait_impls.rs").exists());
}
/// Checks `to_snake_case` on camelCase names, an already-lowercase name and
/// short all-uppercase names.
#[test]
fn test_snake_case_conversion() {
    // A registry is only needed to construct the generator; the output
    // directory is never written to.
    let temp_dir = TempDir::new().unwrap();
    create_test_schema(
        temp_dir.path(),
        "Thing",
        r#"
label: Thing
properties: {}
"#,
    );
    let registry = SchemaRegistry::load_from_cache(temp_dir.path()).unwrap();
    let codegen = CodeGenerator::new(registry, "/tmp/test");

    let cases = [
        ("firstName", "first_name"),
        ("birthDate", "birth_date"),
        ("name", "name"),
        ("ID", "id"),
        ("API", "api"),
    ];
    for (input, expected) in cases {
        assert_eq!(codegen.to_snake_case(input), expected);
    }
}
/// Generates code for a two-level abstract hierarchy (`Thing` ->
/// `LegalEntity`) with two concrete leaves and checks the emitted traits,
/// trait implementations and inherited struct fields.
#[test]
fn test_trait_generation() {
    let temp_dir = TempDir::new().unwrap();
    // Property definitions must be nested under `properties` for the
    // `name` / `country` assertions below to be satisfiable.
    create_test_schema(
        temp_dir.path(),
        "Thing",
        r#"
label: Thing
abstract: true
properties:
  name:
    label: Name
    type: name
  description:
    label: Description
    type: text
"#,
    );
    create_test_schema(
        temp_dir.path(),
        "LegalEntity",
        r#"
label: Legal Entity
abstract: true
extends:
  - Thing
properties:
  country:
    label: Country
    type: country
"#,
    );
    create_test_schema(
        temp_dir.path(),
        "Person",
        r#"
label: Person
extends:
  - LegalEntity
properties:
  firstName:
    label: First Name
    type: name
"#,
    );
    create_test_schema(
        temp_dir.path(),
        "Company",
        r#"
label: Company
extends:
  - LegalEntity
properties:
  registrationNumber:
    label: Registration Number
    type: identifier
"#,
    );
    let registry = SchemaRegistry::load_from_cache(temp_dir.path()).unwrap();
    let output_dir = temp_dir.path().join("generated");
    let codegen = CodeGenerator::new(registry, &output_dir);
    let result = codegen.generate_all();
    assert!(result.is_ok(), "Code generation failed: {:?}", result);

    let traits_content = fs::read_to_string(output_dir.join("traits.rs")).unwrap();
    assert!(traits_content.contains("pub trait Thing"));
    assert!(traits_content.contains("pub trait LegalEntity"));
    assert!(traits_content.contains("fn name(&self)"));
    assert!(traits_content.contains("fn country(&self)"));

    let trait_impls_content = fs::read_to_string(output_dir.join("trait_impls.rs")).unwrap();
    assert!(trait_impls_content.contains("impl Thing for Person"));
    assert!(trait_impls_content.contains("impl LegalEntity for Person"));
    assert!(trait_impls_content.contains("impl Thing for Company"));
    assert!(trait_impls_content.contains("impl LegalEntity for Company"));

    let entities_content = fs::read_to_string(output_dir.join("entities.rs")).unwrap();
    assert!(entities_content.contains("pub struct Person"));
    assert!(entities_content.contains("pub struct Company"));
    assert!(entities_content.contains("pub name: Option<Vec<String>>"));
    assert!(entities_content.contains("pub country: Option<Vec<String>>"));
}
/// Compile-time check that the generated `Person` builder exposes `name` and
/// `height` setters with the expected argument types. The builder is never
/// finished, so nothing is asserted at run time.
#[test]
fn test_builder() {
    let names = vec!["Huh".to_string()];
    let heights = vec![123.45];
    let _person = Person::builder().name(names).height(heights);
}
/// Checks that `to_ftm_json` emits exactly the top-level keys `id`,
/// `properties` and `schema`.
#[test]
fn test_to_ftm_json() {
    let person = Person::builder()
        .name(vec!["Hello Sir".into()])
        .id("123".into())
        .build();
    let json = person.to_ftm_json().unwrap();
    let parsed: serde_json::Value = serde_json::from_str(&json).unwrap();
    let object = parsed.as_object().unwrap();
    let keys: Vec<_> = object.keys().collect();
    assert_eq!(keys, vec!["id", "properties", "schema"]);
}
}