use crate::codegen::SchemaRegistry;
use crate::schema::{LiteralValue, Schema, SchemaKind};
use handlebars::Handlebars;
use serde::Serialize;
use std::collections::HashMap;
/// Generates Rust source code for schemas by rendering Handlebars templates.
pub struct RustGenerator {
// Template registry; `new` registers the `module`, `struct` and `enum` templates in it.
registry: Handlebars<'static>,
}
impl RustGenerator {
pub fn new() -> Self {
let mut registry = Handlebars::new();
registry
.register_template_string("module", MODULE_TEMPLATE)
.unwrap();
registry
.register_template_string("struct", STRUCT_TEMPLATE)
.unwrap();
registry
.register_template_string("enum", ENUM_TEMPLATE)
.unwrap();
Self { registry }
}
pub fn generate(&self, name: &str, schema: &Schema) -> Result<String, crate::Error> {
let context = SchemaContext::from_schema(name, schema);
match &schema.kind {
SchemaKind::Enum { values } => {
let ctx = EnumContext {
name: name.to_string(),
values: values.clone(),
};
Ok(self.registry.render("enum", &ctx)?)
}
SchemaKind::Object { .. } => Ok(self.registry.render("struct", &context)?),
SchemaKind::Named { schema, .. } => self.generate(name, schema),
_ => {
let rust_type = schema_to_rust_type(schema, &HashMap::new());
Ok(format!("pub type {} = {};\n", name, rust_type))
}
}
}
pub fn generate_module(&self, registry: &SchemaRegistry) -> Result<String, crate::Error> {
let mut rendered: Vec<String> = Vec::new();
for (name, schema) in registry.schemas() {
rendered.push(self.generate(name, schema)?);
}
let mut output = String::new();
output.push_str("// Auto-generated by typebox-rs. DO NOT EDIT.\n\n");
output.push_str("use serde::{Deserialize, Serialize};\n\n");
for code in rendered {
output.push_str(&code);
output.push('\n');
}
Ok(output)
}
}
impl Default for RustGenerator {
fn default() -> Self {
Self::new()
}
}
/// Template context describing one object schema; serialized and passed to the
/// `struct` template by `RustGenerator::generate`.
#[derive(Serialize)]
struct SchemaContext {
// Name of the generated Rust struct.
name: String,
// Description copied from the schema. STRUCT_TEMPLATE only reads `description`
// inside its `properties` loop, so this top-level value is not rendered by it.
description: Option<String>,
// One entry per object property, in the order the schema's properties are iterated.
properties: Vec<PropertyContext>,
// Left empty by `from_schema`; not referenced by the templates in this file.
type_refs: HashMap<String, String>,
}
/// Template context for a single struct field.
#[derive(Serialize)]
struct PropertyContext {
// Property name exactly as it appears in the schema.
name: String,
// Field identifier emitted into the Rust source (the result of `format_ident`).
rust_name: String,
// Rust type expression for the field, before any `Option<...>` wrapping.
rust_type: String,
// True when the property is not listed in the object's `required` set; the struct
// template then wraps the type in `Option<...>` and adds `skip_serializing_if`.
optional: bool,
// When present, rendered by the struct template as a `///` line above the field.
description: Option<String>,
// `has_default` / `default_value` are set to `false` / `None` by `from_schema`
// and are not referenced by the templates in this file.
has_default: bool,
default_value: Option<String>,
}
/// Template context for the `enum` template.
#[derive(Serialize)]
struct EnumContext {
// Name of the generated Rust enum.
name: String,
// Enum values; each one is rendered both as the variant identifier and as its
// `#[serde(rename = "...")]` string.
values: Vec<String>,
}
impl SchemaContext {
    /// Builds the `struct` template context for `schema` under the type name `name`.
    ///
    /// For object schemas every property becomes a [`PropertyContext`]: the field
    /// identifier is passed through `format_ident`, the type comes from
    /// `schema_to_rust_type`, and a property is optional when it is not listed in
    /// `required`. For any other schema kind `properties` is left empty.
    fn from_schema(name: &str, schema: &Schema) -> Self {
        // No reference aliases are known here; one empty map is shared by all
        // properties instead of allocating a fresh one per property.
        let refs = HashMap::new();
        let mut properties = Vec::new();
        if let SchemaKind::Object {
            properties: props,
            required,
            ..
        } = &schema.kind
        {
            for (prop_name, prop_schema) in props {
                // The struct template prints the description on a single `///` line,
                // so whitespace (including newlines) is collapsed to keep a
                // multi-line description from spilling out of the doc comment.
                let description = prop_schema
                    .description
                    .as_deref()
                    .map(|text| text.split_whitespace().collect::<Vec<_>>().join(" "));
                properties.push(PropertyContext {
                    name: prop_name.clone(),
                    rust_name: format_ident(prop_name),
                    rust_type: schema_to_rust_type(prop_schema, &refs),
                    optional: !required.contains(prop_name),
                    description,
                    // Default values are not generated yet.
                    has_default: false,
                    default_value: None,
                });
            }
        }
        Self {
            name: name.to_string(),
            description: schema.description.clone(),
            properties,
            type_refs: HashMap::new(),
        }
    }
}
/// Maps a schema to the Rust type expression used for it in generated code.
///
/// `refs` maps reference names (with any `#/definitions/` prefix removed) to
/// Rust type names; a reference that is not in the map uses the stripped name
/// itself.
///
/// Notable mappings:
/// * A two-member union where exactly one member is null becomes `Option<T>`.
/// * A one-element tuple is written `(T,)`, because `(T)` is just `T` in Rust.
/// * Literals map to the Rust type of the literal with the value kept in a
///   trailing comment.
/// * Inline objects, `Any` and `Unknown` fall back to `serde_json::Value`.
fn schema_to_rust_type(schema: &Schema, refs: &HashMap<String, String>) -> String {
    match &schema.kind {
        SchemaKind::Null => "()".to_string(),
        SchemaKind::Bool => "bool".to_string(),
        SchemaKind::Int8 { .. } => "i8".to_string(),
        SchemaKind::Int16 { .. } => "i16".to_string(),
        SchemaKind::Int32 { .. } => "i32".to_string(),
        SchemaKind::Int64 { .. } => "i64".to_string(),
        SchemaKind::UInt8 { .. } => "u8".to_string(),
        SchemaKind::UInt16 { .. } => "u16".to_string(),
        SchemaKind::UInt32 { .. } => "u32".to_string(),
        SchemaKind::UInt64 { .. } => "u64".to_string(),
        SchemaKind::Float32 { .. } => "f32".to_string(),
        SchemaKind::Float64 { .. } => "f64".to_string(),
        SchemaKind::String { .. } => "String".to_string(),
        SchemaKind::Bytes { .. } => "Vec<u8>".to_string(),
        SchemaKind::Array { items, .. } => {
            format!("Vec<{}>", schema_to_rust_type(items, refs))
        }
        SchemaKind::Tuple { items } => {
            let types: Vec<_> = items.iter().map(|s| schema_to_rust_type(s, refs)).collect();
            match types.as_slice() {
                // A single-element tuple needs the trailing comma; `(T)` would be
                // parsed as the parenthesized type `T`.
                [only] => format!("({},)", only),
                _ => format!("({})", types.join(", ")),
            }
        }
        SchemaKind::Object { .. } => "serde_json::Value".to_string(),
        SchemaKind::Union { any_of } => {
            if any_of.len() == 2 {
                let is_optional = any_of.iter().any(|s| matches!(&s.kind, SchemaKind::Null));
                if is_optional {
                    let non_null: Vec<_> = any_of
                        .iter()
                        .filter(|s| !matches!(&s.kind, SchemaKind::Null))
                        .collect();
                    if non_null.len() == 1 {
                        return format!("Option<{}>", schema_to_rust_type(non_null[0], refs));
                    }
                }
            }
            // NOTE(review): `(A | B)` is not valid Rust type syntax; general unions
            // would need a generated enum. Output kept unchanged here.
            let types: Vec<_> = any_of
                .iter()
                .map(|s| schema_to_rust_type(s, refs))
                .collect();
            format!("({})", types.join(" | "))
        }
        SchemaKind::Literal { value } => match value {
            // `*/` inside the literal would terminate the comment early, so it is
            // broken up before being embedded.
            LiteralValue::String(s) => {
                format!("&'static str /* \"{}\" */", s.replace("*/", "* /"))
            }
            // A literal value is not a type; emit the value's Rust type and keep
            // the value itself as a comment, like the string case above.
            // NOTE(review): assumes `Number` holds an integer that fits `i64` and
            // `Float` an `f64` — confirm against `LiteralValue`.
            LiteralValue::Number(n) => format!("i64 /* {} */", n),
            LiteralValue::Float(f) => format!("f64 /* {} */", f),
            LiteralValue::Boolean(b) => format!("bool /* {} */", b),
            LiteralValue::Null => "()".to_string(),
        },
        SchemaKind::Enum { .. } => "String".to_string(),
        SchemaKind::Ref { reference } => {
            let name = reference
                .strip_prefix("#/definitions/")
                .unwrap_or(reference);
            refs.get(name).cloned().unwrap_or_else(|| name.to_string())
        }
        SchemaKind::Named { name, .. } => name.clone(),
        SchemaKind::Function {
            parameters,
            returns,
        } => {
            let params: Vec<_> = parameters
                .iter()
                .map(|s| schema_to_rust_type(s, refs))
                .collect();
            let ret = schema_to_rust_type(returns, refs);
            format!("fn({}) -> {}", params.join(", "), ret)
        }
        SchemaKind::Void => "()".to_string(),
        // NOTE(review): `!` is only stable in function return position.
        SchemaKind::Never => "!".to_string(),
        SchemaKind::Any => "serde_json::Value".to_string(),
        SchemaKind::Unknown => "serde_json::Value".to_string(),
        SchemaKind::Undefined => "()".to_string(),
        SchemaKind::Recursive { schema } => schema_to_rust_type(schema, refs),
        SchemaKind::Intersect { all_of } => {
            // NOTE(review): `A + B` is only meaningful for trait bounds, not for
            // concrete types. Output kept unchanged here.
            let types: Vec<_> = all_of
                .iter()
                .map(|s| schema_to_rust_type(s, refs))
                .collect();
            types.join(" + ")
        }
    }
}
/// Turns a schema property name into an identifier usable as a Rust field name.
///
/// Names that collide with a Rust keyword are returned as raw identifiers
/// (`type` becomes `r#type`); every other name is returned unchanged.
///
/// `self`, `Self`, `super` and `crate` cannot be written as raw identifiers,
/// so they are returned unchanged as well; a schema using one of them as a
/// property name still needs manual handling.
fn format_ident(name: &str) -> String {
    // Strict and reserved keywords across editions (including `gen`, reserved
    // since edition 2024), minus the four that are not allowed as raw identifiers.
    const RAW_ESCAPABLE_KEYWORDS: &[&str] = &[
        "abstract", "as", "async", "await", "become", "box", "break", "const", "continue",
        "do", "dyn", "else", "enum", "extern", "false", "final", "fn", "for", "gen", "if",
        "impl", "in", "let", "loop", "macro", "match", "mod", "move", "mut", "override",
        "priv", "pub", "ref", "return", "static", "struct", "trait", "true", "try", "type",
        "typeof", "unsafe", "unsized", "use", "virtual", "where", "while", "yield",
    ];
    if RAW_ESCAPABLE_KEYWORDS.contains(&name) {
        format!("r#{}", name)
    } else {
        name.to_string()
    }
}
/// Handlebars template for a whole module, registered under the name `module`.
///
/// Expects a `preamble` value and a `schemas` list; each entry of `schemas` is
/// inserted with triple-stash (`{{{this}}}`), i.e. without escaping.
/// `RustGenerator::generate_module` assembles its output by hand and does not
/// render this template.
const MODULE_TEMPLATE: &str = r#"{{preamble}}
use serde::{Deserialize, Serialize};
{{#each schemas}}
{{{this}}}
{{/each}}
"#;
/// Handlebars template for a serde-derived struct, registered under the name `struct`.
///
/// Rendered with a `SchemaContext`. For each property it emits an optional `///`
/// description line, `#[serde(skip_serializing_if = "Option::is_none")]` plus an
/// `Option<...>` wrapper when the property is optional, and the `pub` field itself.
/// Values are inserted with double-stash, so they pass through whatever escape
/// function the rendering registry is configured with.
const STRUCT_TEMPLATE: &str = r#"#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct {{name}} {
{{#each properties}}
{{#if description}}/// {{description}}
{{/if}}{{#if optional}}#[serde(skip_serializing_if = "Option::is_none")]
{{/if}}pub {{rust_name}}: {{#if optional}}Option<{{/if}}{{rust_type}}{{#if optional}}>{{/if}},
{{/each}}
}
"#;
/// Handlebars template for a serde-derived enum, registered under the name `enum`.
///
/// Rendered with an `EnumContext`. Every value is emitted twice: as the
/// `#[serde(rename = "...")]` string and, verbatim, as the variant identifier.
// NOTE(review): a value that is not a valid Rust identifier (spaces, dashes, a
// leading digit) is still emitted as the variant name and would not compile.
const ENUM_TEMPLATE: &str = r#"#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum {{name}} {
{{#each values}}
#[serde(rename = "{{this}}")]
{{this}},
{{/each}}
}
"#;
#[cfg(test)]
mod tests {
    use super::*;
    use crate::builder::SchemaBuilder;

    /// An object schema becomes a struct; optional fields are wrapped in `Option`.
    #[test]
    fn test_generate_struct() {
        let generator = RustGenerator::new();
        let schema = SchemaBuilder::object()
            .field("id", SchemaBuilder::int64())
            .field("name", SchemaBuilder::string().build())
            .optional_field("email", SchemaBuilder::string().build())
            .build();
        let output = generator.generate("Person", &schema).unwrap();
        assert!(output.contains("pub struct Person"));
        assert!(output.contains("pub id: i64"));
        assert!(output.contains("pub name: String"));
        assert!(output.contains("pub email: Option<String>"));
    }

    /// An enum schema becomes a Rust enum with one variant per value.
    #[test]
    fn test_generate_enum() {
        let generator = RustGenerator::new();
        let schema = SchemaBuilder::enum_values(vec!["Red", "Green", "Blue"]);
        let output = generator.generate("Color", &schema).unwrap();
        assert!(output.contains("pub enum Color"));
        assert!(output.contains("Red"));
        assert!(output.contains("Green"));
        assert!(output.contains("Blue"));
    }

    /// A module contains the serde import followed by every registered schema.
    #[test]
    fn test_generate_module() {
        let generator = RustGenerator::new();
        let mut registry = SchemaRegistry::new();
        registry.register(
            "Person",
            SchemaBuilder::object()
                .field("id", SchemaBuilder::int64())
                .field("name", SchemaBuilder::string().build())
                .build(),
        );
        // The registry is passed by reference.
        let output = generator.generate_module(&registry).unwrap();
        assert!(output.contains("use serde::{Deserialize, Serialize}"));
        assert!(output.contains("pub struct Person"));
    }

    /// A function schema falls through to the type-alias path.
    #[test]
    fn test_generate_function_type() {
        let generator = RustGenerator::new();
        let schema = SchemaBuilder::function(
            vec![SchemaBuilder::int64(), SchemaBuilder::string().build()],
            SchemaBuilder::void(),
        );
        let output = generator.generate("Callback", &schema).unwrap();
        assert!(output.contains("pub type Callback"));
        assert!(output.contains("fn"));
    }
}