use proc_macro::TokenStream;
use proc_macro2::TokenStream as TokenStream2;
use quote::{format_ident, quote};
use syn::parse::{Parse, ParseStream};
use syn::{
parse_macro_input, Attribute, Expr, Fields, GenericArgument, Ident, ItemStruct, LitInt, LitStr,
PathArguments, Token, Type,
};
/// Options accepted inside the `#[api_model(...)]` attribute itself.
#[derive(Default)]
struct ContainerArgs {
    /// Forwarded as `#[serde(rename_all = ...)]` on the generated struct.
    rename_all: Option<LitStr>,
}
/// Parses a comma-separated list of `key = value` options.
///
/// A trailing comma is accepted. An unrecognised key yields
/// "unknown api_model option" pointing at the offending identifier.
impl Parse for ContainerArgs {
    fn parse(input: ParseStream) -> syn::Result<Self> {
        let mut parsed = Self::default();
        loop {
            if input.is_empty() {
                return Ok(parsed);
            }
            let option = input.parse::<Ident>()?;
            input.parse::<Token![=]>()?;
            if option != "rename_all" {
                return Err(syn::Error::new(
                    option.span(),
                    format!("unknown api_model option `{option}`"),
                ));
            }
            parsed.rename_all = Some(input.parse()?);
            // The separator is optional only after the final option.
            if input.is_empty() {
                return Ok(parsed);
            }
            input.parse::<Token![,]>()?;
        }
    }
}
/// Constraints and metadata accepted by a field-level `#[field(...)]` attribute.
#[derive(Default)]
struct FieldArgs {
    /// Lower bound for the `length(...)` rule (garde and schemars).
    min_length: Option<LitInt>,
    /// Upper bound for the `length(...)` rule (garde and schemars).
    max_length: Option<LitInt>,
    /// Inclusive lower bound, emitted as `range(min = ...)`.
    ge: Option<Expr>,
    /// Inclusive upper bound, emitted as `range(max = ...)`.
    le: Option<Expr>,
    /// Exclusive lower bound, checked by a generated custom validator.
    gt: Option<Expr>,
    /// Exclusive upper bound, checked by a generated custom validator.
    lt: Option<Expr>,
    /// Schema `title`.
    title: Option<LitStr>,
    /// Schema `description`.
    description: Option<LitStr>,
    /// User validators, each emitted as a garde `custom(...)` rule; repeatable.
    custom: Vec<Expr>,
    /// Bare flag: emit garde `dive` for nested validation.
    nested: bool,
    /// Bare flag: emit `#[serde(default)]`.
    default: bool,
}
/// Parses a comma-separated list of bare flags (`nested`, `default`) and
/// `key = value` options. A trailing comma is accepted. An unrecognised key
/// yields "unknown field constraint".
impl Parse for FieldArgs {
    fn parse(input: ParseStream) -> syn::Result<Self> {
        let mut parsed = Self::default();
        while !input.is_empty() {
            let key = input.parse::<Ident>()?;
            let key_name = key.to_string();
            if key_name == "nested" {
                parsed.nested = true;
            } else if key_name == "default" {
                parsed.default = true;
            } else {
                // Every non-flag option is written as `key = value`.
                input.parse::<Token![=]>()?;
                match key_name.as_str() {
                    "min_length" => parsed.min_length = Some(input.parse()?),
                    "max_length" => parsed.max_length = Some(input.parse()?),
                    "ge" => parsed.ge = Some(input.parse()?),
                    "le" => parsed.le = Some(input.parse()?),
                    "gt" => parsed.gt = Some(input.parse()?),
                    "lt" => parsed.lt = Some(input.parse()?),
                    "title" => parsed.title = Some(input.parse()?),
                    "description" => parsed.description = Some(input.parse()?),
                    "custom" => parsed.custom.push(input.parse()?),
                    unknown => {
                        return Err(syn::Error::new(
                            key.span(),
                            format!("unknown field constraint `{unknown}`"),
                        ));
                    }
                }
            }
            if input.is_empty() {
                break;
            }
            input.parse::<Token![,]>()?;
        }
        Ok(parsed)
    }
}
pub fn expand(attr: TokenStream, item: TokenStream) -> TokenStream {
let container = parse_macro_input!(attr as ContainerArgs);
let item = parse_macro_input!(item as ItemStruct);
match expand_struct(container, item) {
Ok(tokens) => tokens.into(),
Err(error) => error.to_compile_error().into(),
}
}
fn expand_struct(container: ContainerArgs, item: ItemStruct) -> syn::Result<TokenStream2> {
let fields = match &item.fields {
Fields::Named(named) => &named.named,
_ => {
return Err(syn::Error::new_spanned(
&item,
"#[api_model] supports only structs with named fields",
));
}
};
let struct_ident = &item.ident;
let vis = &item.vis;
let generics = &item.generics;
let struct_attrs = &item.attrs;
let mut field_tokens = Vec::new();
let mut extra_fns = Vec::new();
for field in fields {
let field_ident = field.ident.as_ref().expect("named field");
let field_ty = &field.ty;
let mut field_args = FieldArgs::default();
let mut preserved: Vec<&Attribute> = Vec::new();
for attr in &field.attrs {
if attr.path().is_ident("field") {
field_args = attr.parse_args()?;
} else {
preserved.push(attr);
}
}
let mut garde_rules: Vec<TokenStream2> = Vec::new();
let mut schemars_rules: Vec<TokenStream2> = Vec::new();
if field_args.min_length.is_some() || field_args.max_length.is_some() {
let parts = bound_parts(
field_args.min_length.as_ref().map(|l| quote!(#l)),
field_args.max_length.as_ref().map(|l| quote!(#l)),
);
garde_rules.push(quote!(length(#parts)));
schemars_rules.push(quote!(length(#parts)));
}
if field_args.ge.is_some() || field_args.le.is_some() {
let parts = bound_parts(
field_args.ge.as_ref().map(|e| coerce_bound(e, field_ty)),
field_args.le.as_ref().map(|e| coerce_bound(e, field_ty)),
);
garde_rules.push(quote!(range(#parts)));
schemars_rules.push(quote!(range(#parts)));
}
if let Some(bound) = &field_args.gt {
let (check_fn, call) =
exclusive_check(struct_ident, field_ident, "gt", bound, field_ty);
extra_fns.push(check_fn);
garde_rules.push(quote!(custom(#call)));
schemars_rules.push(quote!(extend("exclusiveMinimum" = #bound)));
}
if let Some(bound) = &field_args.lt {
let (check_fn, call) =
exclusive_check(struct_ident, field_ident, "lt", bound, field_ty);
extra_fns.push(check_fn);
garde_rules.push(quote!(custom(#call)));
schemars_rules.push(quote!(extend("exclusiveMaximum" = #bound)));
}
for custom in &field_args.custom {
garde_rules.push(quote!(custom(#custom)));
}
if field_args.nested {
garde_rules.push(quote!(dive));
}
if let Some(title) = &field_args.title {
schemars_rules.push(quote!(title = #title));
}
if let Some(description) = &field_args.description {
schemars_rules.push(quote!(description = #description));
}
let garde_attr = if garde_rules.is_empty() {
quote!(#[garde(skip)])
} else {
quote!(#[garde(#(#garde_rules),*)])
};
let schemars_attr = if schemars_rules.is_empty() {
quote!()
} else {
quote!(#[schemars(#(#schemars_rules),*)])
};
let serde_attr = if field_args.default {
quote!(#[serde(default)])
} else {
quote!()
};
let field_vis = &field.vis;
field_tokens.push(quote! {
#(#preserved)*
#serde_attr
#garde_attr
#schemars_attr
#field_vis #field_ident: #field_ty,
});
}
let rename_attr = container
.rename_all
.map(|rename| quote!(#[serde(rename_all = #rename)]));
Ok(quote! {
#(#struct_attrs)*
#[derive(
::core::fmt::Debug,
::core::clone::Clone,
::tork::__serde::Serialize,
::tork::__serde::Deserialize,
::tork::__garde::Validate,
::tork::__schemars::JsonSchema,
)]
#[serde(crate = "::tork::__serde")]
#rename_attr
#[schemars(crate = "::tork::__schemars")]
#vis struct #struct_ident #generics {
#(#field_tokens)*
}
#(#extra_fns)*
})
}
/// Renders the `min = ..., max = ...` argument list shared by the `length`
/// and `range` rules.
///
/// Absent sides are omitted, and the result is empty when both are `None`.
pub(crate) fn bound_parts(min: Option<TokenStream2>, max: Option<TokenStream2>) -> TokenStream2 {
    let lower = min.map(|value| quote!(min = #value));
    let upper = max.map(|value| quote!(max = #value));
    let present = lower.into_iter().chain(upper);
    quote!(#(#present),*)
}
/// Builds a garde custom validator enforcing an exclusive bound (`gt` / `lt`).
///
/// Returns the function definition tokens and its identifier,
/// `__tork_<snake_struct>_<field>_<kind>`. The caller references that
/// identifier from a garde `custom(...)` rule.
///
/// Behaviour of the generated function:
/// - `kind` is expected to be `"gt"` or `"lt"`; any value other than `"gt"`
///   produces a less-than check.
/// - The bound is cast with `as` to the field type (or the `T` of an
///   `Option<T>` field) before comparison.
/// - For `Option<T>` fields, `None` always passes.
///
/// NOTE(review): the generated function is a free, non-generic function
/// taking `&FieldTy`. A `gt`/`lt` constraint on a field whose type mentions
/// one of the struct's type parameters will not compile.
pub(crate) fn exclusive_check(
    struct_ident: &Ident,
    field_ident: &Ident,
    kind: &str,
    bound: &Expr,
    field_ty: &Type,
) -> (TokenStream2, Ident) {
    let fn_ident = format_ident!(
        "__tork_{}_{}_{}",
        to_snake(&struct_ident.to_string()),
        field_ident,
        kind
    );
    let (op, word): (TokenStream2, &str) = if kind == "gt" {
        (quote!(>), "greater than")
    } else {
        (quote!(<), "less than")
    };
    let message = format!("must be {word} {}", quote!(#bound));
    let inner_ty = option_inner(field_ty);
    let compare_ty = inner_ty.unwrap_or(field_ty);
    // Parenthesize the bound before casting: `as` binds tighter than binary
    // operators, so `MIN + 1 as u32` would cast only the `1`.
    let compare = quote! {
        if *value #op ((#bound) as #compare_ty) {
            ::core::result::Result::Ok(())
        } else {
            ::core::result::Result::Err(::tork::__garde::Error::new(#message))
        }
    };
    let body = match inner_ty {
        Some(_) => quote! {
            match value {
                ::core::option::Option::Some(value) => {
                    #compare
                }
                ::core::option::Option::None => ::core::result::Result::Ok(()),
            }
        },
        None => compare,
    };
    let check_fn = quote! {
        #[doc(hidden)]
        fn #fn_ident(
            value: &#field_ty,
            _ctx: &(),
        ) -> ::core::result::Result<(), ::tork::__garde::Error> {
            #body
        }
    };
    (check_fn, fn_ident)
}
/// Renders a `ge`/`le` bound for use in a `range(...)` rule.
///
/// When the field (or the `T` of an `Option<T>` field) is `f32`/`f64`, integer
/// literal bounds are re-emitted as float literals (`0` → `0.0`, `-1` → `-1.0`).
/// This keeps the generated range rule's literal type in line with the float
/// field. The literal may be negated or wrapped in parentheses or an invisible
/// group. Any other expression is emitted unchanged.
pub(crate) fn coerce_bound(expr: &Expr, field_ty: &Type) -> TokenStream2 {
    let inner = option_inner(field_ty).unwrap_or(field_ty);
    if is_float_type(inner) {
        if let Some(float) = int_literal_as_float(expr) {
            return float;
        }
    }
    quote!(#expr)
}

/// Rewrites an integer literal expression (possibly negated, parenthesized,
/// or grouped) as float-literal tokens; returns `None` for anything else.
fn int_literal_as_float(expr: &Expr) -> Option<TokenStream2> {
    match expr {
        Expr::Lit(syn::ExprLit {
            lit: syn::Lit::Int(int),
            ..
        }) => format!("{}.0", int.base10_digits()).parse().ok(),
        // `-1` parses as unary negation applied to the literal `1`, not as a
        // negative literal, so it has to be unwrapped explicitly.
        Expr::Unary(syn::ExprUnary {
            op: syn::UnOp::Neg(_),
            expr: operand,
            ..
        }) => int_literal_as_float(operand).map(|float| quote!(-#float)),
        // Invisible groups appear when the bound comes from a `macro_rules!`
        // `$e:expr` fragment.
        Expr::Paren(syn::ExprParen { expr: inner, .. })
        | Expr::Group(syn::ExprGroup { expr: inner, .. }) => int_literal_as_float(inner),
        _ => None,
    }
}
/// Reports whether `ty` is written as the bare path `f32` or `f64`.
///
/// The check is by name only; qualified forms such as `core::primitive::f64`
/// are not recognised.
fn is_float_type(ty: &Type) -> bool {
    let Type::Path(type_path) = ty else {
        return false;
    };
    ["f32", "f64"]
        .iter()
        .any(|float_name| type_path.path.is_ident(float_name))
}
/// Returns the `T` of an `Option<T>` type, or `None` for any other type.
///
/// Only the last path segment is inspected, so `Option<T>` and
/// `std::option::Option<T>` both match. A user type named `Option` also
/// matches, since the check is by name only. The first type argument in the
/// angle brackets is returned.
fn option_inner(ty: &Type) -> Option<&Type> {
    let last_segment = match ty {
        Type::Path(type_path) => type_path.path.segments.last()?,
        _ => return None,
    };
    if last_segment.ident != "Option" {
        return None;
    }
    match &last_segment.arguments {
        PathArguments::AngleBracketed(angle) => angle.args.iter().find_map(|generic_arg| {
            if let GenericArgument::Type(inner) = generic_arg {
                Some(inner)
            } else {
                None
            }
        }),
        _ => None,
    }
}
/// Converts an identifier to snake case for use in generated function names.
///
/// Every uppercase character is lowercased and, unless it is the first
/// character, preceded by `_`. Acronyms are therefore split per letter
/// (`HTTPServer` → `h_t_t_p_server`). Other characters are copied as-is.
pub(crate) fn to_snake(input: &str) -> String {
    input.chars().enumerate().fold(
        String::with_capacity(input.len()),
        |mut snake, (position, ch)| {
            if ch.is_uppercase() {
                if position > 0 {
                    snake.push('_');
                }
                snake.extend(ch.to_lowercase());
            } else {
                snake.push(ch);
            }
            snake
        },
    )
}
// Unit tests for the attribute parsers, the token-building helpers, and the
// struct expansion. They inspect parsed values and generated token strings;
// they do not compile the expanded output.
#[cfg(test)]
mod tests {
    use super::*;
    use syn::parse_quote;
    // Every recognised container and field option lands in its slot. The bare
    // `nested` / `default` flags need no `= value`.
    #[test]
    fn container_args_and_field_args_parse_known_keys() {
        let container: ContainerArgs = parse_quote!(rename_all = "camelCase");
        assert_eq!(container.rename_all.unwrap().value(), "camelCase");
        let fields: FieldArgs = parse_quote!(
            min_length = 1,
            max_length = 3,
            ge = 1,
            le = 9,
            gt = 2,
            lt = 8,
            title = "Title",
            description = "Desc",
            custom = validate_name,
            nested,
            default
        );
        assert!(fields.min_length.is_some());
        assert!(fields.max_length.is_some());
        assert!(fields.ge.is_some());
        assert!(fields.le.is_some());
        assert!(fields.gt.is_some());
        assert!(fields.lt.is_some());
        assert_eq!(fields.title.unwrap().value(), "Title");
        assert_eq!(fields.description.unwrap().value(), "Desc");
        assert_eq!(fields.custom.len(), 1);
        assert!(fields.nested);
        assert!(fields.default);
    }
    // Unknown keys surface the error messages defined in the two `Parse` impls.
    #[test]
    fn parse_rejects_unknown_container_and_field_options() {
        let err = match syn::parse2::<ContainerArgs>(quote!(unknown = "x")) {
            Ok(_) => panic!("expected parse error"),
            Err(err) => err,
        };
        assert!(err.to_string().contains("unknown api_model option"));
        let err = match syn::parse2::<FieldArgs>(quote!(unknown = 1)) {
            Ok(_) => panic!("expected parse error"),
            Err(err) => err,
        };
        assert!(err.to_string().contains("unknown field constraint"));
    }
    // Pins the exact rendering of the small helpers. `bound_parts` omits
    // absent sides; `coerce_bound` turns integer literals into float literals
    // only for float fields. `to_snake` splits before every uppercase letter,
    // including inside acronyms.
    #[test]
    fn helper_functions_cover_bounds_and_identifiers() {
        assert_eq!(
            bound_parts(Some(quote!(1)), Some(quote!(9))).to_string(),
            "min = 1 , max = 9"
        );
        assert_eq!(bound_parts(Some(quote!(1)), None).to_string(), "min = 1");
        assert_eq!(bound_parts(None, Some(quote!(9))).to_string(), "max = 9");
        assert_eq!(bound_parts(None, None).to_string(), "");
        let field_ty: Type = parse_quote!(f64);
        assert_eq!(coerce_bound(&parse_quote!(0), &field_ty).to_string(), "0.0");
        assert_eq!(
            coerce_bound(&parse_quote!(2.5), &field_ty).to_string(),
            "2.5"
        );
        let option_ty: Type = parse_quote!(Option<u32>);
        let plain_ty: Type = parse_quote!(String);
        assert!(option_inner(&option_ty).is_some());
        assert!(option_inner(&plain_ty).is_none());
        assert!(is_float_type(&parse_quote!(f32)));
        assert!(!is_float_type(&parse_quote!(u32)));
        assert_eq!(to_snake("HTTPServer"), "h_t_t_p_server");
    }
    // Generated validator names embed the snake-cased struct name, the field
    // name and the kind, for both plain and `Option` field types.
    #[test]
    fn exclusive_check_handles_option_and_plain_values() {
        let struct_ident = parse_quote!(Model);
        let field_ident = parse_quote!(count);
        let bound: Expr = parse_quote!(3);
        let plain_ty: Type = parse_quote!(u32);
        let opt_ty: Type = parse_quote!(Option<u32>);
        let (_, plain_fn) = exclusive_check(&struct_ident, &field_ident, "gt", &bound, &plain_ty);
        assert!(plain_fn.to_string().contains("model_count_gt"));
        let (_, opt_fn) = exclusive_check(&struct_ident, &field_ident, "lt", &bound, &opt_ty);
        assert!(opt_fn.to_string().contains("model_count_lt"));
    }
    // Smoke test over a full expansion: the output mentions every attribute
    // kind the field options map to. `raw` has no `#[field]`, so it is the
    // source of `skip`.
    #[test]
    fn expand_struct_handles_named_fields_and_constraints() {
        let container: ContainerArgs = parse_quote!(rename_all = "camelCase");
        let item: ItemStruct = parse_quote! {
            #[doc = "model"]
            pub struct User {
                #[field(min_length = 1, max_length = 3, default)]
                pub name: String,
                #[field(ge = 0, le = 10, gt = 0, lt = 10, title = "Score", description = "Score field")]
                pub score: Option<f64>,
                #[field(custom = validate_name, nested)]
                pub child: Option<Child>,
                pub raw: String,
            }
        };
        let tokens = expand_struct(container, item).unwrap().to_string();
        assert!(tokens.contains("rename_all"));
        assert!(tokens.contains("camelCase"));
        assert!(tokens.contains("skip"));
        assert!(tokens.contains("length"));
        assert!(tokens.contains("range"));
        assert!(tokens.contains("exclusiveMinimum"));
        assert!(tokens.contains("exclusiveMaximum"));
        assert!(tokens.contains("title"));
        assert!(tokens.contains("description"));
        assert!(tokens.contains("default"));
        assert!(tokens.contains("validate_name"));
        assert!(tokens.contains("dive"));
        assert!(tokens.contains("fn __tork_user_score_gt"));
    }
}