use proc_macro2::TokenStream;
use quote::quote;
use super::parse::{EntityDef, FieldDef};
use crate::utils::marker;
/// Generates every DTO struct for `entity`.
///
/// The output is the create request, the update request and the response
/// struct, concatenated in that order. Any of the three may be empty when the
/// entity exposes no fields for it.
pub fn generate(entity: &EntityDef) -> TokenStream {
    let mut output = generate_create_dto(entity);
    output.extend(generate_update_dto(entity));
    output.extend(generate_response_dto(entity));
    output
}
/// Emits the request struct used to create an entity.
///
/// Returns an empty stream when `entity.create_fields()` is empty. Otherwise
/// the struct is named via `ident_with("Create", "Request")`, reuses the
/// entity's visibility, and contains one `pub` member per create field with
/// the field's original type.
fn generate_create_dto(entity: &EntityDef) -> TokenStream {
    let create_fields = entity.create_fields();
    if create_fields.is_empty() {
        return TokenStream::new();
    }

    let members: Vec<TokenStream> = create_fields
        .iter()
        .map(|field| {
            let ident = field.name();
            let ty = field.ty();
            // The field type is emitted as-is, so rules are not wrapped in `inner(...)`.
            let validation = garde_attr(field, 0);
            quote! {
                #validation
                pub #ident: #ty
            }
        })
        .collect();

    let struct_name = entity.ident_with("Create", "Request");
    let visibility = &entity.vis;
    let generated_marker = marker::generated();
    let derives = dto_extra_derives();

    quote! {
        #generated_marker
        #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
        #derives
        #visibility struct #struct_name { #(#members),* }
    }
}
/// Returns the feature-dependent `#[derive(...)]` attributes for request DTOs.
///
/// * `api` adds `utoipa::ToSchema`.
/// * `validate` adds `validator::Validate`.
/// * `garde` adds `garde::Validate`, but only when `validate` is disabled;
///   with both features on, `validate` takes precedence.
fn dto_extra_derives() -> TokenStream {
    let schema_derive = if cfg!(feature = "api") {
        quote! { #[derive(utoipa::ToSchema)] }
    } else {
        TokenStream::new()
    };

    let validation_derive = match (cfg!(feature = "validate"), cfg!(feature = "garde")) {
        (true, _) => quote! { #[derive(validator::Validate)] },
        (false, true) => quote! { #[derive(garde::Validate)] },
        (false, false) => TokenStream::new(),
    };

    quote! { #schema_derive #validation_derive }
}
/// Builds the `#[garde(...)]` attribute for one DTO field.
///
/// The rule list is derived from `field.validation`: length bounds, numeric
/// range bounds, `email`, `url` and `pattern`, in that order. A field with no
/// rules gets `#[garde(skip)]`. Otherwise the comma-separated rules are
/// wrapped in `option_depth` nested `inner(...)` layers.
///
/// Returns an empty stream unless the `garde` feature is enabled and the
/// `validate` feature is disabled, matching the precedence used when choosing
/// the derive.
///
/// # Panics
///
/// Panics if the assembled rule list cannot be parsed as Rust tokens.
fn garde_attr(field: &FieldDef, option_depth: usize) -> TokenStream {
    // Bail out before doing any work when garde attributes are not wanted.
    if !cfg!(feature = "garde") || cfg!(feature = "validate") {
        return TokenStream::new();
    }

    let v = &field.validation;
    let mut rules: Vec<String> = Vec::new();

    match (v.min_length, v.max_length) {
        (Some(min), Some(max)) => rules.push(format!("length(min = {min}, max = {max})")),
        (Some(min), None) => rules.push(format!("length(min = {min})")),
        (None, Some(max)) => rules.push(format!("length(max = {max})")),
        (None, None) => {}
    }
    match (v.minimum, v.maximum) {
        (Some(min), Some(max)) => rules.push(format!("range(min = {min}, max = {max})")),
        (Some(min), None) => rules.push(format!("range(min = {min})")),
        (None, Some(max)) => rules.push(format!("range(max = {max})")),
        (None, None) => {}
    }
    if v.email {
        rules.push("email".to_string());
    }
    if v.url {
        rules.push("url".to_string());
    }
    if let Some(pattern) = &v.pattern {
        // Debug-format the pattern so backslashes and quotes are escaped.
        // Splicing it raw between quotes turns a regex such as `^\d+$` into an
        // invalid string literal and makes the token parse below panic.
        rules.push(format!("pattern({:?})", pattern.to_string()));
    }

    if rules.is_empty() {
        return quote! { #[garde(skip)] };
    }

    let mut body = rules.join(", ");
    for _ in 0..option_depth {
        body = format!("inner({body})");
    }

    let tokens: TokenStream = body.parse().expect("garde rules are valid tokens");
    quote! { #[garde(#tokens)] }
}
/// Emits the request struct used to partially update an entity.
///
/// Returns an empty stream when `entity.update_fields()` is empty. Every
/// update field is wrapped in an extra `Option`. Fields that are already
/// optional additionally get a serde attribute routing them through
/// `::entity_core::serde_helpers::double_option`, and their validation rules
/// are nested two `inner(...)` levels deep instead of one.
///
/// When the entity has a version field, a trailing `expected_version` member
/// of that field's type is appended. It carries `#[garde(skip)]` when garde
/// validation is active.
fn generate_update_dto(entity: &EntityDef) -> TokenStream {
    let update_fields = entity.update_fields();
    if update_fields.is_empty() {
        return TokenStream::new();
    }

    let members = update_fields.iter().map(|field| {
        let ident = field.name();
        let ty = field.ty();

        // Pick the serde attribute and the `inner(...)` nesting depth together:
        // already-optional fields become `Option<Option<_>>` in the DTO.
        let (serde_attr, depth) = if field.is_option() {
            (
                quote! {
                    #[serde(
                        default,
                        skip_serializing_if = "Option::is_none",
                        with = "::entity_core::serde_helpers::double_option"
                    )]
                },
                2,
            )
        } else {
            (TokenStream::new(), 1)
        };
        let validation = garde_attr(field, depth);

        quote! {
            #serde_attr
            #validation
            pub #ident: Option<#ty>
        }
    });

    // Same precedence as the derive selection: garde only when `validate` is off.
    let garde_active = cfg!(feature = "garde") && !cfg!(feature = "validate");
    let version_member = match entity.version_field() {
        Some(version) => {
            let version_ty = version.ty();
            let skip_attr = if garde_active {
                quote! { #[garde(skip)] }
            } else {
                TokenStream::new()
            };
            quote! {
                #skip_attr
                pub expected_version: #version_ty,
            }
        }
        None => TokenStream::new(),
    };

    let struct_name = entity.ident_with("Update", "Request");
    let visibility = &entity.vis;
    let generated_marker = marker::generated();
    let derives = dto_extra_derives();

    quote! {
        #generated_marker
        #[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
        #derives
        #visibility struct #struct_name {
            #(#members,)*
            #version_member
        }
    }
}
/// Emits the struct returned to API consumers.
///
/// Returns an empty stream when `entity.response_fields()` is empty. The
/// struct is named via `ident_with("", "Response")` and mirrors each response
/// field as a `pub` member of the same type. Only the `utoipa::ToSchema`
/// derive is added (under the `api` feature); no validation derive or field
/// attributes are emitted here.
fn generate_response_dto(entity: &EntityDef) -> TokenStream {
    let response_fields = entity.response_fields();
    if response_fields.is_empty() {
        return TokenStream::new();
    }

    let members: Vec<TokenStream> = response_fields
        .iter()
        .map(|field| {
            let ident = field.name();
            let ty = field.ty();
            quote! { pub #ident: #ty }
        })
        .collect();

    let schema_derive = match cfg!(feature = "api") {
        true => quote! { #[derive(utoipa::ToSchema)] },
        false => TokenStream::new(),
    };

    let struct_name = entity.ident_with("", "Response");
    let visibility = &entity.vis;
    let generated_marker = marker::generated();

    quote! {
        #generated_marker
        #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
        #schema_derive
        #visibility struct #struct_name { #(#members),* }
    }
}
/// Tests for garde attribute generation; compiled only when `garde` is the
/// active validation backend (`validate` disabled).
#[cfg(all(test, feature = "garde", not(feature = "validate")))]
mod garde_tests {
    use quote::quote;
    use syn::DeriveInput;

    use super::*;

    /// Runs the generator over a sample `User` entity and renders the result.
    ///
    /// The entity has a length-constrained `name`, an email-constrained
    /// `email`, and an unconstrained optional `bio`.
    fn generated_code() -> String {
        let source = quote! {
            #[entity(table = "users")]
            pub struct User {
                #[id]
                pub id: uuid::Uuid,
                #[field(create, update, response)]
                #[validate(length(min = 1, max = 64))]
                pub name: String,
                #[field(create, response)]
                #[validate(email)]
                pub email: String,
                #[field(create, update, response)]
                pub bio: Option<String>,
            }
        };
        let derive_input: DeriveInput = syn::parse2(source).expect("test entity must parse");
        let entity =
            EntityDef::from_derive_input(&derive_input).expect("test entity must be valid");
        generate(&entity).to_string()
    }

    #[test]
    fn create_dto_translates_constraints() {
        let rendered = generated_code();
        assert!(rendered.contains("garde (length (min = 1 , max = 64))"));
        assert!(rendered.contains("garde (email)"));
        assert!(rendered.contains("derive (garde :: Validate)"));
    }

    #[test]
    fn unconstrained_fields_get_skip() {
        let rendered = generated_code();
        assert!(rendered.contains("garde (skip)"));
    }

    #[test]
    fn update_dto_wraps_rules_in_inner() {
        let rendered = generated_code();
        assert!(rendered.contains("inner (length (min = 1 , max = 64))"));
    }

    #[test]
    fn garde_derive_emitted_without_cfg_wrapper() {
        let rendered = generated_code();
        assert!(rendered.contains("derive (garde :: Validate)"));
        // The derive must be emitted directly, not behind a `cfg_attr`.
        assert!(!rendered.contains("cfg_attr"));
    }
}
/// Tests for the case where both validation backends are enabled at once.
#[cfg(all(test, feature = "garde", feature = "validate"))]
mod garde_precedence_tests {
    use syn::DeriveInput;

    use super::*;

    /// Builds a minimal entity with a single email-validated create field.
    fn email_entity() -> EntityDef {
        let derive_input: DeriveInput = syn::parse_quote! {
            #[entity(table = "users")]
            pub struct User {
                #[id]
                pub id: uuid::Uuid,
                #[field(create, response)]
                #[validate(email)]
                pub email: String,
            }
        };
        EntityDef::from_derive_input(&derive_input).unwrap()
    }

    #[test]
    fn validate_wins_over_garde() {
        let rendered = generate(&email_entity()).to_string();
        // `validate` takes precedence: no garde derive or attribute may appear.
        assert!(rendered.contains("validator :: Validate"));
        assert!(!rendered.contains("garde"));
    }
}