use proc_macro::TokenStream;
use proc_macro_crate::{crate_name, FoundCrate};
use quote::quote;
use std::collections::HashSet;
use syn::{
parse_macro_input, Attribute, Data, DeriveInput, Expr, Fields, FnArg, GenericArgument, ItemFn,
Lit, LitStr, Meta, PathArguments, ReturnType, Type,
};
mod api_error;
mod derive_schema;
/// Resolves the absolute path under which the `rustapi-rs` facade crate is reachable from
/// the crate currently being expanded.
///
/// If the invoking crate depends on `rustapi-rs` under some (possibly renamed) package name,
/// that name with `-` replaced by `_` is used. Otherwise, or when expanding inside
/// `rustapi-rs` itself, the path falls back to `::rustapi_rs`.
fn get_rustapi_path() -> proc_macro2::TokenStream {
    let lookup = crate_name("rustapi-rs").or_else(|_| crate_name("rustapi_rs"));
    match lookup {
        Ok(FoundCrate::Name(package)) => {
            let crate_ident = syn::Ident::new(
                &package.replace('-', "_"),
                proc_macro2::Span::call_site(),
            );
            quote! { ::#crate_ident }
        }
        Ok(FoundCrate::Itself) | Err(_) => quote! { ::rustapi_rs },
    }
}
/// `#[derive(Schema)]`: parses the annotated item and hands it to
/// `derive_schema::expand_derive_schema`, which produces the actual implementation.
#[proc_macro_derive(Schema, attributes(schema))]
pub fn derive_schema(input: TokenStream) -> TokenStream {
    let derive_input = parse_macro_input!(input as DeriveInput);
    let generated = derive_schema::expand_derive_schema(derive_input);
    generated.into()
}
/// `#[rustapi_rs::schema]`: re-emits a non-generic struct or enum unchanged and registers
/// it in the `AUTO_SCHEMAS` `linkme` distributed slice.
///
/// The registration is a static function pointer that calls
/// `OpenApiSpec::register_in_place::<Type>()`. Generic types are rejected with a compile
/// error (the registrar would need concrete type arguments), as are items other than
/// structs and enums.
#[proc_macro_attribute]
pub fn schema(_attr: TokenStream, item: TokenStream) -> TokenStream {
    let input = parse_macro_input!(item as syn::Item);
    let rustapi_path = get_rustapi_path();
    let (ident, generics) = match &input {
        syn::Item::Struct(s) => (&s.ident, &s.generics),
        syn::Item::Enum(e) => (&e.ident, &e.generics),
        _ => {
            return syn::Error::new_spanned(
                &input,
                "#[rustapi_rs::schema] can only be used on structs or enums",
            )
            .to_compile_error()
            .into();
        }
    };
    if !generics.params.is_empty() {
        return syn::Error::new_spanned(
            generics,
            "#[rustapi_rs::schema] does not support generic types",
        )
        .to_compile_error()
        .into();
    }
    // A raw identifier (`r#Foo`) keeps its `r#` prefix in `to_string()`; gluing that into a
    // new identifier makes `Ident::new` panic, so derive the registrar name from the bare name.
    let type_name = ident.to_string();
    let bare_name = type_name.strip_prefix("r#").unwrap_or(&type_name);
    let registrar_ident = syn::Ident::new(
        &format!("__RUSTAPI_AUTO_SCHEMA_{}", bare_name),
        proc_macro2::Span::call_site(),
    );
    let expanded = quote! {
        #input
        #[allow(non_upper_case_globals)]
        #[#rustapi_path::__private::linkme::distributed_slice(#rustapi_path::__private::AUTO_SCHEMAS)]
        #[linkme(crate = #rustapi_path::__private::linkme)]
        static #registrar_ident: fn(&mut #rustapi_path::__private::openapi::OpenApiSpec) =
            |spec: &mut #rustapi_path::__private::openapi::OpenApiSpec| {
                spec.register_in_place::<#ident>();
            };
    };
    debug_output("schema", &expanded);
    expanded.into()
}
/// Collects the types that should be registered as OpenAPI schemas from a handler
/// argument or return type, appending them to `out`.
///
/// * `Json<T>`, `ValidatedJson<T>`, `Created<T>` and `WithStatus<T, ..>` contribute their
///   first type argument (recursively, with leaves allowed).
/// * `Option<T>` and `Result<T, ..>` are looked through without changing `allow_leaf`.
/// * References, parenthesised types and the invisible groups that `macro_rules!` wraps
///   around `$t:ty` fragments are looked through as well.
/// * Any other path type is pushed only when `allow_leaf` is `true`, i.e. when it sits
///   inside one of the wrappers above. Top-level extractors such as `State<..>` are ignored.
fn extract_schema_types(ty: &Type, out: &mut Vec<Type>, allow_leaf: bool) {
    match ty {
        Type::Reference(r) => extract_schema_types(&r.elem, out, allow_leaf),
        // `(T)` and a None-delimited group around `T` denote the same type as `T`.
        Type::Paren(p) => extract_schema_types(&p.elem, out, allow_leaf),
        Type::Group(g) => extract_schema_types(&g.elem, out, allow_leaf),
        Type::Path(tp) => {
            let Some(seg) = tp.path.segments.last() else {
                return;
            };
            let first_generic = match &seg.arguments {
                PathArguments::AngleBracketed(args) => match args.args.first() {
                    Some(GenericArgument::Type(inner)) => Some(inner),
                    _ => None,
                },
                _ => None,
            };
            match seg.ident.to_string().as_str() {
                "Json" | "ValidatedJson" | "Created" | "WithStatus" => {
                    if let Some(inner) = first_generic {
                        extract_schema_types(inner, out, true);
                    }
                }
                "Option" | "Result" => {
                    if let Some(inner) = first_generic {
                        extract_schema_types(inner, out, allow_leaf);
                    }
                }
                _ => {
                    if allow_leaf {
                        out.push(ty.clone());
                    }
                }
            }
        }
        _ => {}
    }
}
/// Gathers every schema type referenced by a handler's typed arguments and return type.
///
/// Types are discovered by `extract_schema_types` with leaves disallowed at the top level,
/// then de-duplicated by their token-string rendering. The first occurrence of each type
/// is kept, in discovery order (arguments first, then the return type).
fn collect_handler_schema_types(input: &ItemFn) -> Vec<Type> {
    let mut candidates: Vec<Type> = Vec::new();
    let argument_types = input.sig.inputs.iter().filter_map(|arg| match arg {
        FnArg::Typed(pat_ty) => Some(&*pat_ty.ty),
        FnArg::Receiver(_) => None,
    });
    for arg_ty in argument_types {
        extract_schema_types(arg_ty, &mut candidates, false);
    }
    if let ReturnType::Type(_, ret_ty) = &input.sig.output {
        extract_schema_types(ret_ty, &mut candidates, false);
    }
    let mut seen_keys = HashSet::<String>::new();
    candidates.retain(|candidate| seen_keys.insert(quote!(#candidate).to_string()));
    candidates
}
/// Derives `(name, schema_type)` path-parameter pairs from `Path<T>` extractor arguments.
///
/// An argument contributes a pair only when its type's last segment is `Path`, its first
/// type argument maps to a scalar schema via `map_type_to_schema`, and its pattern yields a
/// binding name via `extract_param_name` (e.g. `Path(id): Path<i64>` → `("id", "integer")`).
fn collect_path_params(input: &ItemFn) -> Vec<(String, String)> {
    input
        .sig
        .inputs
        .iter()
        .filter_map(|arg| {
            let FnArg::Typed(pat_ty) = arg else {
                return None;
            };
            let Type::Path(type_path) = &*pat_ty.ty else {
                return None;
            };
            let last = type_path.path.segments.last()?;
            if last.ident != "Path" {
                return None;
            }
            let PathArguments::AngleBracketed(generic_args) = &last.arguments else {
                return None;
            };
            let Some(GenericArgument::Type(inner_ty)) = generic_args.args.first() else {
                return None;
            };
            let schema_type = map_type_to_schema(inner_ty)?;
            let binding = extract_param_name(&pat_ty.pat)?;
            Some((binding, schema_type))
        })
        .collect()
}
/// Returns the binding name of an extractor pattern.
///
/// A plain identifier pattern (`id`) yields its name; a tuple-struct pattern (`Path(id)`)
/// yields the name found in its first element, recursively. Any other pattern yields `None`.
fn extract_param_name(pat: &syn::Pat) -> Option<String> {
    if let syn::Pat::Ident(binding) = pat {
        return Some(binding.ident.to_string());
    }
    if let syn::Pat::TupleStruct(tuple_struct) = pat {
        return tuple_struct.elems.first().and_then(extract_param_name);
    }
    None
}
/// Maps a scalar Rust type to the schema-type string passed to `Route::param`.
///
/// Only the last path segment is inspected: `Uuid` → `"uuid"`, `String`/`str` →
/// `"string"`, `bool` → `"boolean"`, fixed/pointer-sized integers → `"integer"`, and
/// `f32`/`f64` → `"number"`. Anything else (including non-path types) yields `None`.
fn map_type_to_schema(ty: &Type) -> Option<String> {
    let Type::Path(type_path) = ty else {
        return None;
    };
    let last = type_path.path.segments.last()?;
    let schema = match last.ident.to_string().as_str() {
        "Uuid" => "uuid",
        "String" | "str" => "string",
        "bool" => "boolean",
        "i8" | "i16" | "i32" | "i64" | "isize" | "u8" | "u16" | "u32" | "u64" | "usize" => {
            "integer"
        }
        "f32" | "f64" => "number",
        _ => return None,
    };
    Some(schema.to_owned())
}
/// Reports whether macro-expansion debugging is switched on.
///
/// True when the `RUSTAPI_DEBUG` environment variable is `1` or `true` (ASCII
/// case-insensitive). Any other value, or an unset or non-UTF-8 variable, means false.
fn is_debug_enabled() -> bool {
    match std::env::var("RUSTAPI_DEBUG") {
        Ok(value) => value == "1" || value.eq_ignore_ascii_case("true"),
        Err(_) => false,
    }
}
/// Prints a macro's generated tokens to stderr, framed by `name`, when
/// `is_debug_enabled()` reports that `RUSTAPI_DEBUG` is on. Otherwise it does nothing.
fn debug_output(name: &str, tokens: &proc_macro2::TokenStream) {
    if !is_debug_enabled() {
        return;
    }
    eprintln!("\n=== RUSTAPI_DEBUG: {} ===", name);
    eprintln!("{}", tokens);
    eprintln!("=== END {} ===\n", name);
}
fn validate_path_syntax(path: &str, span: proc_macro2::Span) -> Result<(), syn::Error> {
if !path.starts_with('/') {
return Err(syn::Error::new(
span,
format!("route path must start with '/', got: \"{}\"", path),
));
}
if path.contains("//") {
return Err(syn::Error::new(
span,
format!(
"route path contains empty segment (double slash): \"{}\"",
path
),
));
}
let mut brace_depth = 0;
let mut param_start = None;
for (i, ch) in path.char_indices() {
match ch {
'{' => {
if brace_depth > 0 {
return Err(syn::Error::new(
span,
format!(
"nested braces are not allowed in route path at position {}: \"{}\"",
i, path
),
));
}
brace_depth += 1;
param_start = Some(i);
}
'}' => {
if brace_depth == 0 {
return Err(syn::Error::new(
span,
format!(
"unmatched closing brace '}}' at position {} in route path: \"{}\"",
i, path
),
));
}
brace_depth -= 1;
if let Some(start) = param_start {
let param_name = &path[start + 1..i];
if param_name.is_empty() {
return Err(syn::Error::new(
span,
format!(
"empty parameter name '{{}}' at position {} in route path: \"{}\"",
start, path
),
));
}
if !param_name.chars().all(|c| c.is_alphanumeric() || c == '_') {
return Err(syn::Error::new(
span,
format!(
"invalid parameter name '{{{}}}' at position {} - parameter names must contain only alphanumeric characters and underscores: \"{}\"",
param_name, start, path
),
));
}
if param_name
.chars()
.next()
.map(|c| c.is_ascii_digit())
.unwrap_or(false)
{
return Err(syn::Error::new(
span,
format!(
"parameter name '{{{}}}' cannot start with a digit at position {}: \"{}\"",
param_name, start, path
),
));
}
}
param_start = None;
}
_ if brace_depth == 0 => {
if !ch.is_alphanumeric() && !"-_./*".contains(ch) {
return Err(syn::Error::new(
span,
format!(
"invalid character '{}' at position {} in route path: \"{}\"",
ch, i, path
),
));
}
}
_ => {}
}
}
if brace_depth > 0 {
return Err(syn::Error::new(
span,
format!(
"unclosed brace '{{' in route path (missing closing '}}'): \"{}\"",
path
),
));
}
Ok(())
}
#[proc_macro_attribute]
pub fn main(_attr: TokenStream, item: TokenStream) -> TokenStream {
let input = parse_macro_input!(item as ItemFn);
let attrs = &input.attrs;
let vis = &input.vis;
let sig = &input.sig;
let block = &input.block;
let expanded = quote! {
#(#attrs)*
#[::tokio::main]
#vis #sig {
#block
}
};
debug_output("main", &expanded);
TokenStream::from(expanded)
}
/// Returns `true` when `ty` is an extractor that consumes the request body: `Json`,
/// `Body`, `ValidatedJson`, `AsyncValidatedJson` or `Multipart`, matched by the last path
/// segment.
///
/// Parenthesised types and the invisible groups that `macro_rules!` wraps around `$t:ty`
/// fragments are looked through. Without that, handlers generated by declarative macros
/// would silently bypass the extractor-order check.
fn is_body_consuming_type(ty: &Type) -> bool {
    match ty {
        Type::Group(group) => is_body_consuming_type(&group.elem),
        Type::Paren(paren) => is_body_consuming_type(&paren.elem),
        Type::Path(tp) => tp.path.segments.last().map_or(false, |seg| {
            matches!(
                seg.ident.to_string().as_str(),
                "Json" | "Body" | "ValidatedJson" | "AsyncValidatedJson" | "Multipart"
            )
        }),
        _ => false,
    }
}
/// Rejects handlers whose body-consuming extractors are misplaced.
///
/// Among the typed (non-`self`) arguments:
/// * the first body-consuming extractor (see `is_body_consuming_type`) must not come before
///   any non-body extractor, otherwise the error points at that body extractor;
/// * at most one body-consuming extractor may appear, otherwise the error points at the
///   second one.
///
/// Handlers without body extractors always pass.
fn validate_extractor_order(input: &ItemFn) -> Result<(), syn::Error> {
    let params: Vec<_> = input
        .sig
        .inputs
        .iter()
        .filter_map(|arg| match arg {
            FnArg::Typed(pat_ty) => Some(pat_ty),
            FnArg::Receiver(_) => None,
        })
        .collect();
    let body_indices: Vec<usize> = params
        .iter()
        .enumerate()
        .filter(|(_, p)| is_body_consuming_type(&p.ty))
        .map(|(i, _)| i)
        .collect();
    let Some(&first_body_idx) = body_indices.first() else {
        return Ok(());
    };
    let last_non_body = params
        .iter()
        .enumerate()
        .filter(|(_, p)| !is_body_consuming_type(&p.ty))
        .map(|(i, _)| i)
        .max();
    if matches!(last_non_body, Some(last_non_body_idx) if first_body_idx < last_non_body_idx) {
        let offending_param = &params[first_body_idx];
        let ty_name = quote!(#offending_param).to_string();
        return Err(syn::Error::new_spanned(
            &offending_param.ty,
            format!(
                "Body-consuming extractor must be the LAST parameter.\n\
\n\
Found `{}` before non-body extractor(s).\n\
\n\
Body extractors (Json, Body, ValidatedJson, AsyncValidatedJson, Multipart) \
consume the request body, which can only be read once. Place them after all \
non-body extractors (State, Path, Query, Headers, etc.).\n\
\n\
Example:\n\
\x20 async fn handler(\n\
\x20 State(db): State<AppState>, // non-body: OK first\n\
\x20 Path(id): Path<i64>, // non-body: OK second\n\
\x20 Json(body): Json<CreateUser>, // body: MUST be last\n\
\x20 ) -> Result<Json<User>> {{ ... }}",
                ty_name,
            ),
        ));
    }
    if let Some(&second_body_idx) = body_indices.get(1) {
        let second_body_param = &params[second_body_idx];
        return Err(syn::Error::new_spanned(
            &second_body_param.ty,
            "Multiple body-consuming extractors detected.\n\
\n\
Only ONE body-consuming extractor (Json, Body, ValidatedJson, AsyncValidatedJson, \
Multipart) is allowed per handler, because the request body can only be consumed once.\n\
\n\
Remove the extra body extractor or combine the data into a single type.",
        ));
    }
    Ok(())
}
/// Shared expansion behind `#[get]`, `#[post]`, `#[put]`, `#[patch]` and `#[delete]`.
///
/// `attr` must be a string-literal route path and `item` the handler function. The output
/// contains:
///
/// * the handler itself, re-emitted with its complete signature (including `where` clauses
///   and `const`/`unsafe`/ABI qualifiers);
/// * a hidden `<handler>_route()` function that builds a `Route` via the rustapi `*_route`
///   helper for `method` (unknown methods fall back to `get_route`). OpenAPI metadata is
///   chained on from `Path<T>` extractors and from `#[tag]`, `#[summary]`,
///   `#[description]`, `#[param]` and `#[errors]` attributes still present on the handler;
/// * a `linkme` registration of that function in the `AUTO_ROUTES` distributed slice;
/// * a hidden function registering every type found by `collect_handler_schema_types`,
///   plus its `AUTO_SCHEMAS` slice entry.
///
/// Route-path syntax errors and misordered body extractors become compile errors.
fn generate_route_handler(method: &str, attr: TokenStream, item: TokenStream) -> TokenStream {
    let path = parse_macro_input!(attr as LitStr);
    let input = parse_macro_input!(item as ItemFn);
    let rustapi_path = get_rustapi_path();
    let fn_name = &input.sig.ident;
    let fn_vis = &input.vis;
    let fn_attrs = &input.attrs;
    let fn_sig = &input.sig;
    let fn_block = &input.block;
    let schema_types = collect_handler_schema_types(&input);
    let path_value = path.value();
    if let Err(err) = validate_path_syntax(&path_value, path.span()) {
        return err.to_compile_error().into();
    }
    if let Err(err) = validate_extractor_order(&input) {
        return err.to_compile_error().into();
    }

    // Helper items are named after the handler. A raw identifier (`r#type`) keeps its `r#`
    // prefix in `to_string()`, and `Ident::new("r#type_route", ..)` panics, so the derived
    // names are built from the bare name.
    let fn_name_string = fn_name.to_string();
    let base_name = fn_name_string
        .strip_prefix("r#")
        .unwrap_or(&fn_name_string);
    let derived = |name: String| syn::Ident::new(&name, fn_name.span());
    let route_fn_name = derived(format!("{}_route", base_name));
    let auto_route_name = derived(format!("__AUTO_ROUTE_{}", base_name));
    let schema_reg_fn_name = derived(format!("__{}_register_schemas", base_name));
    let auto_schema_name = derived(format!("__AUTO_SCHEMA_{}", base_name));

    let route_helper = match method {
        "GET" => quote!(#rustapi_path::get_route),
        "POST" => quote!(#rustapi_path::post_route),
        "PUT" => quote!(#rustapi_path::put_route),
        "PATCH" => quote!(#rustapi_path::patch_route),
        "DELETE" => quote!(#rustapi_path::delete_route),
        _ => quote!(#rustapi_path::get_route),
    };

    let mut chained_calls = proc_macro2::TokenStream::new();
    for (name, schema) in collect_path_params(&input) {
        chained_calls.extend(quote! { .param(#name, #schema) });
    }
    for attr in fn_attrs {
        chained_calls.extend(route_metadata_from_attr(attr));
    }

    let expanded = quote! {
        #(#fn_attrs)*
        #fn_vis #fn_sig #fn_block
        #[doc(hidden)]
        #fn_vis fn #route_fn_name() -> #rustapi_path::Route {
            #route_helper(#path_value, #fn_name)
            #chained_calls
        }
        #[doc(hidden)]
        #[allow(non_upper_case_globals)]
        #[#rustapi_path::__private::linkme::distributed_slice(#rustapi_path::__private::AUTO_ROUTES)]
        #[linkme(crate = #rustapi_path::__private::linkme)]
        static #auto_route_name: fn() -> #rustapi_path::Route = #route_fn_name;
        #[doc(hidden)]
        #[allow(non_snake_case)]
        fn #schema_reg_fn_name(spec: &mut #rustapi_path::__private::openapi::OpenApiSpec) {
            #( spec.register_in_place::<#schema_types>(); )*
        }
        #[doc(hidden)]
        #[allow(non_upper_case_globals)]
        #[#rustapi_path::__private::linkme::distributed_slice(#rustapi_path::__private::AUTO_SCHEMAS)]
        #[linkme(crate = #rustapi_path::__private::linkme)]
        static #auto_schema_name: fn(&mut #rustapi_path::__private::openapi::OpenApiSpec) = #schema_reg_fn_name;
    };
    debug_output(&format!("{} {}", method, path_value), &expanded);
    TokenStream::from(expanded)
}

/// Builder calls (`.tag(..)`, `.summary(..)`, `.description(..)`, `.param(..)`,
/// `.error_response(..)`) contributed by one handler attribute, matched by the last
/// segment of its path. Unknown attributes, or arguments that do not parse, contribute
/// nothing.
fn route_metadata_from_attr(attr: &Attribute) -> proc_macro2::TokenStream {
    let Some(last) = attr.path().segments.last() else {
        return proc_macro2::TokenStream::new();
    };
    let kind = last.ident.to_string();
    match kind.as_str() {
        "tag" | "summary" | "description" => match attr.parse_args::<LitStr>() {
            Ok(lit) => {
                let builder = syn::Ident::new(&kind, proc_macro2::Span::call_site());
                let val = lit.value();
                quote! { .#builder(#val) }
            }
            Err(_) => proc_macro2::TokenStream::new(),
        },
        "param" => route_param_call(attr),
        "errors" => route_error_calls(attr),
        _ => proc_macro2::TokenStream::new(),
    }
}

/// Translates `#[param(name, schema = "..")]` or `#[param(name = "..")]` into a
/// `.param(name, schema)` call. The first bare path or unrecognised key names the
/// parameter; `schema`/`type` (or the naming key's own string value) supplies the schema.
/// Both must be present for a call to be emitted.
fn route_param_call(attr: &Attribute) -> proc_macro2::TokenStream {
    let Ok(param_args) = attr.parse_args_with(
        syn::punctuated::Punctuated::<Meta, syn::Token![,]>::parse_terminated,
    ) else {
        return proc_macro2::TokenStream::new();
    };
    let string_value = |expr: &Expr| match expr {
        Expr::Lit(syn::ExprLit {
            lit: Lit::Str(s), ..
        }) => Some(s.value()),
        _ => None,
    };
    let mut param_name: Option<String> = None;
    let mut param_schema: Option<String> = None;
    for meta in param_args {
        match &meta {
            Meta::Path(path) => {
                if param_name.is_none() {
                    if let Some(ident) = path.get_ident() {
                        param_name = Some(ident.to_string());
                    }
                }
            }
            Meta::NameValue(nv) => {
                let Some(key) = nv.path.get_ident().map(|i| i.to_string()) else {
                    continue;
                };
                if key == "schema" || key == "type" {
                    if let Some(s) = string_value(&nv.value) {
                        param_schema = Some(s);
                    }
                } else if param_name.is_none() {
                    param_name = Some(key);
                    if let Some(s) = string_value(&nv.value) {
                        param_schema = Some(s);
                    }
                }
            }
            Meta::List(_) => {}
        }
    }
    match (param_name, param_schema) {
        (Some(pname), Some(pschema)) => quote! { .param(#pname, #pschema) },
        _ => proc_macro2::TokenStream::new(),
    }
}

/// Translates `#[errors(..)]` into `.error_response(status, description)` calls.
///
/// Two forms are recognised: identifier keys (`not_found = "..."`, parsed as `Meta`) and
/// numeric status codes (`404 = "..."`, which are not valid `Meta` and are therefore
/// scanned token by token). String descriptions are unescaped via `syn::Lit`.
fn route_error_calls(attr: &Attribute) -> proc_macro2::TokenStream {
    let mut calls = proc_macro2::TokenStream::new();
    if let Ok(error_args) = attr.parse_args_with(
        syn::punctuated::Punctuated::<Meta, syn::Token![,]>::parse_terminated,
    ) {
        for meta in error_args {
            let Meta::NameValue(nv) = &meta else {
                continue;
            };
            let Some(status_key) = nv.path.get_ident().map(|i| i.to_string()) else {
                continue;
            };
            if let Expr::Lit(syn::ExprLit {
                lit: Lit::Str(s), ..
            }) = &nv.value
            {
                let desc = s.value();
                calls.extend(quote! { .error_response(#status_key, #desc) });
            }
        }
    }
    if let Ok(ts) = attr.parse_args::<proc_macro2::TokenStream>() {
        let tokens: Vec<proc_macro2::TokenTree> = ts.into_iter().collect();
        let mut i = 0;
        while i < tokens.len() {
            if let [proc_macro2::TokenTree::Literal(code), proc_macro2::TokenTree::Punct(eq), proc_macro2::TokenTree::Literal(desc_lit), ..] =
                &tokens[i..]
            {
                if eq.as_char() == '=' {
                    if let Ok(status_code) = code.to_string().parse::<u16>() {
                        // Unescape properly (`\"`, `\n`, raw strings); keep the old
                        // quote-trimming only for non-string literals.
                        let desc = match Lit::new(desc_lit.clone()) {
                            Lit::Str(s) => s.value(),
                            _ => desc_lit.to_string().trim_matches('"').to_string(),
                        };
                        calls.extend(quote! { .error_response(#status_code, #desc) });
                        i += 3;
                        if matches!(tokens.get(i), Some(proc_macro2::TokenTree::Punct(p)) if p.as_char() == ',')
                        {
                            i += 1;
                        }
                        continue;
                    }
                }
            }
            i += 1;
        }
    }
    calls
}
/// `#[get("/path")]`: registers the annotated function as a `GET` route.
#[proc_macro_attribute]
pub fn get(route: TokenStream, handler: TokenStream) -> TokenStream {
    generate_route_handler("GET", route, handler)
}

/// `#[post("/path")]`: registers the annotated function as a `POST` route.
#[proc_macro_attribute]
pub fn post(route: TokenStream, handler: TokenStream) -> TokenStream {
    generate_route_handler("POST", route, handler)
}

/// `#[put("/path")]`: registers the annotated function as a `PUT` route.
#[proc_macro_attribute]
pub fn put(route: TokenStream, handler: TokenStream) -> TokenStream {
    generate_route_handler("PUT", route, handler)
}

/// `#[patch("/path")]`: registers the annotated function as a `PATCH` route.
#[proc_macro_attribute]
pub fn patch(route: TokenStream, handler: TokenStream) -> TokenStream {
    generate_route_handler("PATCH", route, handler)
}

/// `#[delete("/path")]`: registers the annotated function as a `DELETE` route.
#[proc_macro_attribute]
pub fn delete(route: TokenStream, handler: TokenStream) -> TokenStream {
    generate_route_handler("DELETE", route, handler)
}
#[proc_macro_attribute]
pub fn tag(attr: TokenStream, item: TokenStream) -> TokenStream {
let tag = parse_macro_input!(attr as LitStr);
let input = parse_macro_input!(item as ItemFn);
let attrs = &input.attrs;
let vis = &input.vis;
let sig = &input.sig;
let block = &input.block;
let tag_value = tag.value();
let expanded = quote! {
#[doc = concat!("**Tag:** ", #tag_value)]
#(#attrs)*
#vis #sig #block
};
TokenStream::from(expanded)
}
#[proc_macro_attribute]
pub fn summary(attr: TokenStream, item: TokenStream) -> TokenStream {
let summary = parse_macro_input!(attr as LitStr);
let input = parse_macro_input!(item as ItemFn);
let attrs = &input.attrs;
let vis = &input.vis;
let sig = &input.sig;
let block = &input.block;
let summary_value = summary.value();
let expanded = quote! {
#[doc = #summary_value]
#(#attrs)*
#vis #sig #block
};
TokenStream::from(expanded)
}
#[proc_macro_attribute]
pub fn description(attr: TokenStream, item: TokenStream) -> TokenStream {
let desc = parse_macro_input!(attr as LitStr);
let input = parse_macro_input!(item as ItemFn);
let attrs = &input.attrs;
let vis = &input.vis;
let sig = &input.sig;
let block = &input.block;
let desc_value = desc.value();
let expanded = quote! {
#[doc = ""]
#[doc = #desc_value]
#(#attrs)*
#vis #sig #block
};
TokenStream::from(expanded)
}
/// `#[param(..)]`: marker attribute that returns the item untouched. When it is still
/// present on a handler annotated with a route attribute, the route macro reads it.
#[proc_macro_attribute]
pub fn param(_args: TokenStream, annotated: TokenStream) -> TokenStream {
    annotated
}
/// `#[errors(..)]`: marker attribute that returns the item untouched. When it is still
/// present on a handler annotated with a route attribute, the route macro reads it.
#[proc_macro_attribute]
pub fn errors(_args: TokenStream, annotated: TokenStream) -> TokenStream {
    annotated
}
/// One rule parsed from a field's `#[validate(...)]` attribute by `parse_validate_meta`.
#[derive(Debug)]
struct ValidationRuleInfo {
    /// Rule name taken from the attribute path, e.g. `email`, `length`, `async_unique`.
    rule_type: String,
    /// `key = value` arguments rendered as strings. Bare flags (e.g. `v4`) are stored as
    /// `(flag, "true")`, and the `rule = value` form stores `(rule, value)`.
    params: Vec<(String, String)>,
    /// Custom error message from `message = "..."`, if given.
    message: Option<String>,
    /// Validation groups from `group = "..."` and/or `groups = [...]`. When empty, the
    /// generated check is guarded by `true` (it is not group-restricted).
    groups: Vec<String>,
}
/// Collects every rule declared in the `#[validate(...)]` attributes of a field.
///
/// Each attribute's arguments are parsed as a comma-separated list of `Meta` items (a
/// single item is just a one-element list). Every item that `parse_validate_meta` accepts
/// becomes a rule. Attributes whose arguments do not parse, and all non-`validate`
/// attributes, are ignored.
fn parse_validate_attrs(attrs: &[Attribute]) -> Vec<ValidationRuleInfo> {
    attrs
        .iter()
        .filter(|attr| attr.path().is_ident("validate"))
        .flat_map(|attr| {
            attr.parse_args_with(
                syn::punctuated::Punctuated::<Meta, syn::Token![,]>::parse_terminated,
            )
            .map(|metas| metas.into_iter().collect::<Vec<Meta>>())
            .unwrap_or_default()
        })
        .filter_map(|meta| parse_validate_meta(&meta))
        .collect()
}
/// Converts one entry of a `#[validate(...)]` list into a `ValidationRuleInfo`.
///
/// * `rule` (bare path): a rule with no parameters.
/// * `rule(k = v, flag, message = "..", group = "..", groups = [..])`: `message`, `group`
///   and `groups` fill the dedicated fields, other `key = value` pairs become params, and
///   bare flags become `(flag, "true")`. Values `expr_to_string` cannot render are skipped.
///   If the argument list does not parse, the rule is returned without parameters.
/// * `rule = value`: a rule whose only param is `(rule, value)`.
///
/// Returns `None` when the rule name, or any nested `key`, is not a single identifier, or
/// when the value of the `rule = value` form cannot be rendered.
fn parse_validate_meta(meta: &Meta) -> Option<ValidationRuleInfo> {
    let list = match meta {
        Meta::Path(path) => {
            return Some(ValidationRuleInfo {
                rule_type: path.get_ident()?.to_string(),
                params: Vec::new(),
                message: None,
                groups: Vec::new(),
            });
        }
        Meta::NameValue(nv) => {
            let rule_type = nv.path.get_ident()?.to_string();
            let value = expr_to_string(&nv.value)?;
            return Some(ValidationRuleInfo {
                params: vec![(rule_type.clone(), value)],
                rule_type,
                message: None,
                groups: Vec::new(),
            });
        }
        Meta::List(list) => list,
    };

    let mut info = ValidationRuleInfo {
        rule_type: list.path.get_ident()?.to_string(),
        params: Vec::new(),
        message: None,
        groups: Vec::new(),
    };
    let Ok(entries) = list.parse_args_with(
        syn::punctuated::Punctuated::<Meta, syn::Token![,]>::parse_terminated,
    ) else {
        return Some(info);
    };
    for entry in entries {
        match &entry {
            Meta::NameValue(nv) => {
                // A non-identifier key discards the whole rule.
                let key = nv.path.get_ident()?.to_string();
                if key == "groups" {
                    info.groups.extend(expr_to_string_vec(&nv.value));
                } else if let Some(value) = expr_to_string(&nv.value) {
                    if key == "message" {
                        info.message = Some(value);
                    } else if key == "group" {
                        info.groups.push(value);
                    } else {
                        info.params.push((key, value));
                    }
                }
            }
            Meta::Path(path) => {
                if let Some(flag) = path.get_ident() {
                    info.params.push((flag.to_string(), "true".to_string()));
                }
            }
            Meta::List(_) => {}
        }
    }
    Some(info)
}
fn expr_to_string(expr: &Expr) -> Option<String> {
match expr {
Expr::Lit(lit) => match &lit.lit {
Lit::Str(s) => Some(s.value()),
Lit::Int(i) => Some(i.base10_digits().to_string()),
Lit::Float(f) => Some(f.base10_digits().to_string()),
Lit::Bool(b) => Some(b.value.to_string()),
_ => None,
},
_ => None,
}
}
fn expr_to_string_vec(expr: &Expr) -> Vec<String> {
match expr {
Expr::Array(arr) => {
let mut result = Vec::new();
for elem in &arr.elems {
if let Some(s) = expr_to_string(elem) {
result.push(s);
}
}
result
}
_ => {
if let Some(s) = expr_to_string(expr) {
vec![s]
} else {
Vec::new()
}
}
}
}
/// Resolves the path to the validation crate used by generated `Validate` impls.
///
/// Preference order:
/// 1. `<rustapi-rs>::__private::validate`, if the `rustapi-rs` facade is a dependency (or
///    is the current crate);
/// 2. a direct `rustapi-validate` dependency (`crate` when expanding inside it);
/// 3. `::rustapi_validate` as a last resort.
fn get_validate_path() -> proc_macro2::TokenStream {
    // Absolute path to a dependency, normalising `-` to `_`.
    fn crate_root(package: &str) -> proc_macro2::TokenStream {
        let ident = syn::Ident::new(&package.replace('-', "_"), proc_macro2::Span::call_site());
        quote! { ::#ident }
    }

    if let Ok(facade) = crate_name("rustapi-rs").or_else(|_| crate_name("rustapi_rs")) {
        let root = match facade {
            FoundCrate::Itself => quote! { ::rustapi_rs },
            FoundCrate::Name(package) => crate_root(&package),
        };
        return quote! { #root::__private::validate };
    }
    match crate_name("rustapi-validate").or_else(|_| crate_name("rustapi_validate")) {
        Ok(FoundCrate::Itself) => quote! { crate },
        Ok(FoundCrate::Name(package)) => crate_root(&package),
        Err(_) => quote! { ::rustapi_validate },
    }
}
/// Resolves the path to the rustapi core crate for generated code.
///
/// Preference order:
/// 1. `<rustapi-rs>::__private::core`, if the `rustapi-rs` facade is a dependency (or is
///    the current crate);
/// 2. a direct `rustapi-core` dependency (`crate` when expanding inside it);
/// 3. `::rustapi_core` as a last resort.
fn get_core_path() -> proc_macro2::TokenStream {
    // Absolute path to a dependency, normalising `-` to `_`.
    fn crate_root(package: &str) -> proc_macro2::TokenStream {
        let ident = syn::Ident::new(&package.replace('-', "_"), proc_macro2::Span::call_site());
        quote! { ::#ident }
    }

    if let Ok(facade) = crate_name("rustapi-rs").or_else(|_| crate_name("rustapi_rs")) {
        let root = match facade {
            FoundCrate::Itself => quote! { ::rustapi_rs },
            FoundCrate::Name(package) => crate_root(&package),
        };
        return quote! { #root::__private::core };
    }
    match crate_name("rustapi-core").or_else(|_| crate_name("rustapi_core")) {
        Ok(FoundCrate::Itself) => quote! { crate },
        Ok(FoundCrate::Name(package)) => crate_root(&package),
        Err(_) => quote! { ::rustapi_core },
    }
}
/// Resolves the path to the `async_trait` crate used on generated `AsyncValidate` impls.
///
/// Preference order:
/// 1. `<rustapi-rs>::__private::async_trait`, if the `rustapi-rs` facade is a dependency
///    (or is the current crate);
/// 2. a direct `async-trait` dependency (`crate` when expanding inside it);
/// 3. `::async_trait` as a last resort.
fn get_async_trait_path() -> proc_macro2::TokenStream {
    // Absolute path to a dependency, normalising `-` to `_`.
    fn crate_root(package: &str) -> proc_macro2::TokenStream {
        let ident = syn::Ident::new(&package.replace('-', "_"), proc_macro2::Span::call_site());
        quote! { ::#ident }
    }

    if let Ok(facade) = crate_name("rustapi-rs").or_else(|_| crate_name("rustapi_rs")) {
        let root = match facade {
            FoundCrate::Itself => quote! { ::rustapi_rs },
            FoundCrate::Name(package) => crate_root(&package),
        };
        return quote! { #root::__private::async_trait };
    }
    match crate_name("async-trait").or_else(|_| crate_name("async_trait")) {
        Ok(FoundCrate::Itself) => quote! { crate },
        Ok(FoundCrate::Name(package)) => crate_root(&package),
        Err(_) => quote! { ::async_trait },
    }
}
fn generate_rule_validation(
field_name: &str,
_field_type: &Type,
rule: &ValidationRuleInfo,
validate_path: &proc_macro2::TokenStream,
) -> proc_macro2::TokenStream {
let field_ident = syn::Ident::new(field_name, proc_macro2::Span::call_site());
let field_name_str = field_name;
let group_check = if rule.groups.is_empty() {
quote! { true }
} else {
let group_names = rule.groups.iter().map(|g| g.as_str());
quote! {
{
let rule_groups = [#(#validate_path::v2::ValidationGroup::from(#group_names)),*];
rule_groups.iter().any(|g| g.matches(&group))
}
}
};
let validation_logic = match rule.rule_type.as_str() {
"email" => {
let message = rule
.message
.as_ref()
.map(|m| quote! { .with_message(#m) })
.unwrap_or_default();
quote! {
{
let rule = #validate_path::v2::EmailRule::new() #message;
if let Err(e) = #validate_path::v2::ValidationRule::validate(&rule, &self.#field_ident) {
errors.add(#field_name_str, e);
}
}
}
}
"length" => {
let min = rule
.params
.iter()
.find(|(k, _)| k == "min")
.and_then(|(_, v)| v.parse::<usize>().ok());
let max = rule
.params
.iter()
.find(|(k, _)| k == "max")
.and_then(|(_, v)| v.parse::<usize>().ok());
let message = rule
.message
.as_ref()
.map(|m| quote! { .with_message(#m) })
.unwrap_or_default();
let rule_creation = match (min, max) {
(Some(min), Some(max)) => {
quote! { #validate_path::v2::LengthRule::new(#min, #max) }
}
(Some(min), None) => quote! { #validate_path::v2::LengthRule::min(#min) },
(None, Some(max)) => quote! { #validate_path::v2::LengthRule::max(#max) },
(None, None) => quote! { #validate_path::v2::LengthRule::new(0, usize::MAX) },
};
quote! {
{
let rule = #rule_creation #message;
if let Err(e) = #validate_path::v2::ValidationRule::validate(&rule, &self.#field_ident) {
errors.add(#field_name_str, e);
}
}
}
}
"range" => {
let min = rule
.params
.iter()
.find(|(k, _)| k == "min")
.map(|(_, v)| v.clone());
let max = rule
.params
.iter()
.find(|(k, _)| k == "max")
.map(|(_, v)| v.clone());
let message = rule
.message
.as_ref()
.map(|m| quote! { .with_message(#m) })
.unwrap_or_default();
let rule_creation = match (min, max) {
(Some(min), Some(max)) => {
let min_lit: proc_macro2::TokenStream = min.parse().unwrap();
let max_lit: proc_macro2::TokenStream = max.parse().unwrap();
quote! { #validate_path::v2::RangeRule::new(#min_lit, #max_lit) }
}
(Some(min), None) => {
let min_lit: proc_macro2::TokenStream = min.parse().unwrap();
quote! { #validate_path::v2::RangeRule::min(#min_lit) }
}
(None, Some(max)) => {
let max_lit: proc_macro2::TokenStream = max.parse().unwrap();
quote! { #validate_path::v2::RangeRule::max(#max_lit) }
}
(None, None) => {
return quote! {};
}
};
quote! {
{
let rule = #rule_creation #message;
if let Err(e) = #validate_path::v2::ValidationRule::validate(&rule, &self.#field_ident) {
errors.add(#field_name_str, e);
}
}
}
}
"regex" => {
let pattern = rule
.params
.iter()
.find(|(k, _)| k == "regex" || k == "pattern")
.map(|(_, v)| v.clone())
.unwrap_or_default();
let message = rule
.message
.as_ref()
.map(|m| quote! { .with_message(#m) })
.unwrap_or_default();
quote! {
{
let rule = #validate_path::v2::RegexRule::new(#pattern) #message;
if let Err(e) = #validate_path::v2::ValidationRule::validate(&rule, &self.#field_ident) {
errors.add(#field_name_str, e);
}
}
}
}
"url" => {
let message = rule
.message
.as_ref()
.map(|m| quote! { .with_message(#m) })
.unwrap_or_default();
quote! {
{
let rule = #validate_path::v2::UrlRule::new() #message;
if let Err(e) = #validate_path::v2::ValidationRule::validate(&rule, &self.#field_ident) {
errors.add(#field_name_str, e);
}
}
}
}
"required" => {
let message = rule
.message
.as_ref()
.map(|m| quote! { .with_message(#m) })
.unwrap_or_default();
quote! {
{
let rule = #validate_path::v2::RequiredRule::new() #message;
if let Err(e) = #validate_path::v2::ValidationRule::validate(&rule, &self.#field_ident) {
errors.add(#field_name_str, e);
}
}
}
}
"credit_card" => {
let message = rule
.message
.as_ref()
.map(|m| quote! { .with_message(#m) })
.unwrap_or_default();
quote! {
{
let rule = #validate_path::v2::CreditCardRule::new() #message;
if let Err(e) = #validate_path::v2::ValidationRule::validate(&rule, &self.#field_ident) {
errors.add(#field_name_str, e);
}
}
}
}
"ip" => {
let v4 = rule.params.iter().any(|(k, _)| k == "v4");
let v6 = rule.params.iter().any(|(k, _)| k == "v6");
let rule_creation = if v4 && !v6 {
quote! { #validate_path::v2::IpRule::v4() }
} else if !v4 && v6 {
quote! { #validate_path::v2::IpRule::v6() }
} else {
quote! { #validate_path::v2::IpRule::new() }
};
let message = rule
.message
.as_ref()
.map(|m| quote! { .with_message(#m) })
.unwrap_or_default();
quote! {
{
let rule = #rule_creation #message;
if let Err(e) = #validate_path::v2::ValidationRule::validate(&rule, &self.#field_ident) {
errors.add(#field_name_str, e);
}
}
}
}
"phone" => {
let message = rule
.message
.as_ref()
.map(|m| quote! { .with_message(#m) })
.unwrap_or_default();
quote! {
{
let rule = #validate_path::v2::PhoneRule::new() #message;
if let Err(e) = #validate_path::v2::ValidationRule::validate(&rule, &self.#field_ident) {
errors.add(#field_name_str, e);
}
}
}
}
"contains" => {
let needle = rule
.params
.iter()
.find(|(k, _)| k == "needle")
.map(|(_, v)| v.clone())
.unwrap_or_default();
let message = rule
.message
.as_ref()
.map(|m| quote! { .with_message(#m) })
.unwrap_or_default();
quote! {
{
let rule = #validate_path::v2::ContainsRule::new(#needle) #message;
if let Err(e) = #validate_path::v2::ValidationRule::validate(&rule, &self.#field_ident) {
errors.add(#field_name_str, e);
}
}
}
}
_ => {
quote! {}
}
};
quote! {
if #group_check {
#validation_logic
}
}
}
fn generate_async_rule_validation(
field_name: &str,
rule: &ValidationRuleInfo,
validate_path: &proc_macro2::TokenStream,
) -> proc_macro2::TokenStream {
let field_ident = syn::Ident::new(field_name, proc_macro2::Span::call_site());
let field_name_str = field_name;
let group_check = if rule.groups.is_empty() {
quote! { true }
} else {
let group_names = rule.groups.iter().map(|g| g.as_str());
quote! {
{
let rule_groups = [#(#validate_path::v2::ValidationGroup::from(#group_names)),*];
rule_groups.iter().any(|g| g.matches(&group))
}
}
};
let validation_logic = match rule.rule_type.as_str() {
"async_unique" => {
let table = rule
.params
.iter()
.find(|(k, _)| k == "table")
.map(|(_, v)| v.clone())
.unwrap_or_default();
let column = rule
.params
.iter()
.find(|(k, _)| k == "column")
.map(|(_, v)| v.clone())
.unwrap_or_default();
let message = rule
.message
.as_ref()
.map(|m| quote! { .with_message(#m) })
.unwrap_or_default();
quote! {
{
let rule = #validate_path::v2::AsyncUniqueRule::new(#table, #column) #message;
if let Err(e) = #validate_path::v2::AsyncValidationRule::validate_async(&rule, &self.#field_ident, ctx).await {
errors.add(#field_name_str, e);
}
}
}
}
"async_exists" => {
let table = rule
.params
.iter()
.find(|(k, _)| k == "table")
.map(|(_, v)| v.clone())
.unwrap_or_default();
let column = rule
.params
.iter()
.find(|(k, _)| k == "column")
.map(|(_, v)| v.clone())
.unwrap_or_default();
let message = rule
.message
.as_ref()
.map(|m| quote! { .with_message(#m) })
.unwrap_or_default();
quote! {
{
let rule = #validate_path::v2::AsyncExistsRule::new(#table, #column) #message;
if let Err(e) = #validate_path::v2::AsyncValidationRule::validate_async(&rule, &self.#field_ident, ctx).await {
errors.add(#field_name_str, e);
}
}
}
}
"async_api" => {
let endpoint = rule
.params
.iter()
.find(|(k, _)| k == "endpoint")
.map(|(_, v)| v.clone())
.unwrap_or_default();
let message = rule
.message
.as_ref()
.map(|m| quote! { .with_message(#m) })
.unwrap_or_default();
quote! {
{
let rule = #validate_path::v2::AsyncApiRule::new(#endpoint) #message;
if let Err(e) = #validate_path::v2::AsyncValidationRule::validate_async(&rule, &self.#field_ident, ctx).await {
errors.add(#field_name_str, e);
}
}
}
}
"custom_async" => {
let function_path = rule
.params
.iter()
.find(|(k, _)| k == "custom_async" || k == "function")
.map(|(_, v)| v.clone())
.unwrap_or_default();
if function_path.is_empty() {
quote! {}
} else {
let func: syn::Path = syn::parse_str(&function_path).unwrap();
let message_handling = if let Some(msg) = &rule.message {
quote! {
let e = #validate_path::v2::RuleError::new("custom_async", #msg);
errors.add(#field_name_str, e);
}
} else {
quote! {
errors.add(#field_name_str, e);
}
};
quote! {
{
if let Err(e) = #func(&self.#field_ident, ctx).await {
#message_handling
}
}
}
}
}
_ => {
quote! {}
}
};
quote! {
if #group_check {
#validation_logic
}
}
}
/// Returns `true` for rules whose generated check must be awaited (`async_unique`,
/// `async_exists`, `async_api`, `custom_async`). These are routed to
/// `generate_async_rule_validation` instead of the synchronous generator.
fn is_async_rule(rule: &ValidationRuleInfo) -> bool {
    const ASYNC_RULES: [&str; 4] = ["async_unique", "async_exists", "async_api", "custom_async"];
    ASYNC_RULES.contains(&rule.rule_type.as_str())
}
/// Derives `Validate`, `AsyncValidate`, and `Validatable` for a struct with
/// named fields, driven by the `#[validate(...)]` attributes on its fields.
///
/// Every rule parsed by `parse_validate_attrs` is routed with
/// `is_async_rule`:
/// - async rules become statements in
///   `AsyncValidate::validate_async_with_group`;
/// - all other rules become statements in `Validate::validate_with_group`.
///
/// When the struct has no async rules, a no-op `AsyncValidate` impl that
/// returns `Ok(())` is emitted instead. `Validatable::do_validate` runs
/// `Validate::validate` and converts any errors with
/// `validation::convert_v2_errors`.
///
/// Tuple structs, unit structs, enums, and unions produce a compile error.
#[proc_macro_derive(Validate, attributes(validate))]
pub fn derive_validate(input: TokenStream) -> TokenStream {
    let input = parse_macro_input!(input as DeriveInput);
    let name = &input.ident;
    let generics = &input.generics;
    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
    let fields = match &input.data {
        Data::Struct(data) => match &data.fields {
            Fields::Named(fields) => &fields.named,
            _ => {
                return syn::Error::new_spanned(
                    &input,
                    "Validate can only be derived for structs with named fields",
                )
                .to_compile_error()
                .into();
            }
        },
        _ => {
            return syn::Error::new_spanned(&input, "Validate can only be derived for structs")
                .to_compile_error()
                .into();
        }
    };
    let validate_path = get_validate_path();
    let core_path = get_core_path();
    let async_trait_path = get_async_trait_path();

    let mut sync_validations = Vec::new();
    let mut async_validations = Vec::new();
    for field in fields {
        let field_name = field
            .ident
            .as_ref()
            .expect("Fields::Named guarantees every field has an identifier")
            .to_string();
        let rules = parse_validate_attrs(&field.attrs);
        for rule in &rules {
            if is_async_rule(rule) {
                async_validations.push(generate_async_rule_validation(
                    &field_name,
                    rule,
                    &validate_path,
                ));
            } else {
                sync_validations.push(generate_rule_validation(
                    &field_name,
                    &field.ty,
                    rule,
                    &validate_path,
                ));
            }
        }
    }

    // `errors` is only mutated (and `group` / `ctx` presumably only read)
    // inside the per-rule snippets. With zero snippets, or with snippets that
    // expand to nothing, the generated method would otherwise raise
    // `unused_mut` / `unused_variables` warnings in the user's crate.
    let validate_impl = quote! {
        impl #impl_generics #validate_path::v2::Validate for #name #ty_generics #where_clause {
            #[allow(unused_mut, unused_variables)]
            fn validate_with_group(&self, group: #validate_path::v2::ValidationGroup) -> Result<(), #validate_path::v2::ValidationErrors> {
                let mut errors = #validate_path::v2::ValidationErrors::new();
                #(#sync_validations)*
                errors.into_result()
            }
        }
    };

    // Each async rule contributes exactly one snippet, so an empty list means
    // the struct declared no async rules at all.
    let async_validate_impl = if async_validations.is_empty() {
        quote! {
            #[#async_trait_path::async_trait]
            impl #impl_generics #validate_path::v2::AsyncValidate for #name #ty_generics #where_clause {
                async fn validate_async_with_group(&self, _ctx: &#validate_path::v2::ValidationContext, _group: #validate_path::v2::ValidationGroup) -> Result<(), #validate_path::v2::ValidationErrors> {
                    Ok(())
                }
            }
        }
    } else {
        quote! {
            #[#async_trait_path::async_trait]
            impl #impl_generics #validate_path::v2::AsyncValidate for #name #ty_generics #where_clause {
                #[allow(unused_mut, unused_variables)]
                async fn validate_async_with_group(&self, ctx: &#validate_path::v2::ValidationContext, group: #validate_path::v2::ValidationGroup) -> Result<(), #validate_path::v2::ValidationErrors> {
                    let mut errors = #validate_path::v2::ValidationErrors::new();
                    #(#async_validations)*
                    errors.into_result()
                }
            }
        }
    };

    let validatable_impl = quote! {
        impl #impl_generics #core_path::validation::Validatable for #name #ty_generics #where_clause {
            fn do_validate(&self) -> Result<(), #core_path::ApiError> {
                match #validate_path::v2::Validate::validate(self) {
                    Ok(_) => Ok(()),
                    Err(e) => Err(#core_path::validation::convert_v2_errors(e)),
                }
            }
        }
    };

    let expanded = quote! {
        #validate_impl
        #async_validate_impl
        #validatable_impl
    };
    debug_output("Validate derive", &expanded);
    TokenStream::from(expanded)
}
#[proc_macro_derive(ApiError, attributes(error))]
pub fn derive_api_error(input: TokenStream) -> TokenStream {
api_error::expand_derive_api_error(input)
}
#[proc_macro_derive(TypedPath, attributes(typed_path))]
pub fn derive_typed_path(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as DeriveInput);
let name = &input.ident;
let generics = &input.generics;
let rustapi_path = get_rustapi_path();
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
let mut path_str = None;
for attr in &input.attrs {
if attr.path().is_ident("typed_path") {
if let Ok(lit) = attr.parse_args::<LitStr>() {
path_str = Some(lit.value());
}
}
}
let path = match path_str {
Some(p) => p,
None => {
return syn::Error::new_spanned(
&input,
"#[derive(TypedPath)] requires a #[typed_path(\"...\")] attribute",
)
.to_compile_error()
.into();
}
};
if let Err(err) = validate_path_syntax(&path, proc_macro2::Span::call_site()) {
return err.to_compile_error().into();
}
let mut format_string = String::new();
let mut format_args = Vec::new();
let mut chars = path.chars().peekable();
while let Some(ch) = chars.next() {
if ch == '{' {
let mut param_name = String::new();
while let Some(&c) = chars.peek() {
if c == '}' {
chars.next(); break;
}
param_name.push(chars.next().unwrap());
}
if param_name.is_empty() {
return syn::Error::new_spanned(
&input,
"Empty path parameter not allowed in typed_path",
)
.to_compile_error()
.into();
}
format_string.push_str("{}");
let ident = syn::Ident::new(¶m_name, proc_macro2::Span::call_site());
format_args.push(quote! { self.#ident });
} else {
format_string.push(ch);
}
}
let expanded = quote! {
impl #impl_generics #rustapi_path::prelude::TypedPath for #name #ty_generics #where_clause {
const PATH: &'static str = #path;
fn to_uri(&self) -> String {
format!(#format_string, #(#format_args),*)
}
}
};
debug_output("TypedPath derive", &expanded);
TokenStream::from(expanded)
}