use convert_case::{Case, Casing};
use proc_macro2::TokenStream;
use quote::{format_ident, quote};
use crate::entity::parse::{CommandDef, CommandKindHint, EntityDef};
/// Generates the HTTP handler functions for every command declared on `entity`.
///
/// Each command definition is turned into one handler by [`generate_handler`],
/// and the resulting token streams are concatenated in declaration order.
/// An entity without commands produces an empty token stream.
pub fn generate(entity: &EntityDef) -> TokenStream {
    // Collecting an empty iterator yields an empty `TokenStream`, so the
    // "no commands" case needs no separate branch.
    entity
        .command_defs()
        .iter()
        .map(|command| generate_handler(entity, command))
        .collect()
}
/// Generates the axum handler, including its `#[utoipa::path]` annotation,
/// for a single command of `entity`.
///
/// The security requirement of the operation is resolved in this order:
///
/// 1. a command reported as public by `cmd.is_public()` gets no requirement;
/// 2. a command-level scheme (`cmd.security()`) is used when present;
/// 3. a command listed as public in the API configuration gets no requirement;
/// 4. the API-level scheme (`api_config.security`) is used when present;
/// 5. otherwise no requirement is emitted.
///
/// Secured operations additionally document a `401 Unauthorized` response.
/// When the API configuration is marked deprecated, `deprecated = true` is
/// appended to the annotation.
///
/// The function body itself is produced by [`generate_handler_with_id`] or
/// [`generate_handler_without_id`], depending on `cmd.requires_id`.
fn generate_handler(entity: &EntityDef, cmd: &CommandDef) -> TokenStream {
    let entity_name = entity.name();
    let entity_name_str = entity.name_str();
    let api_config = entity.api_config();

    let handler_name = handler_function_name(entity, cmd);
    let handler_method = cmd.handler_method_name();
    let command_struct = cmd.struct_name(&entity_name_str);
    let handler_trait = format_ident!("{}CommandHandler", entity_name);

    let path = build_path(entity, cmd);
    let http_method_ident = format_ident!("{}", http_method_for_command(cmd));
    let tag = api_config.tag_or_default(&entity_name_str);

    let security_attr = if cmd.is_public() {
        quote! {}
    } else if let Some(cmd_security) = cmd.security() {
        security_requirement(security_scheme_name(cmd_security))
    } else if api_config.is_public_command(&cmd.name.to_string()) {
        quote! {}
    } else if let Some(security) = &api_config.security {
        security_requirement(security_scheme_name(security))
    } else {
        quote! {}
    };

    let (response_type, response_body) = response_type_for_command(entity, cmd);

    // Optional fragments of the annotation. Each one carries its own comma so
    // that the single template below stays well-formed whether or not the
    // fragment is present.
    let (unauthorized_response, security_clause) = if security_attr.is_empty() {
        (quote! {}, quote! {})
    } else {
        (
            quote! { (status = 401, description = "Unauthorized"), },
            quote! { , #security_attr }
        )
    };
    let deprecated_attr = if api_config.is_deprecated() {
        quote! { , deprecated = true }
    } else {
        quote! {}
    };

    let utoipa_attr = quote! {
        #[utoipa::path(
            #http_method_ident,
            path = #path,
            tag = #tag,
            request_body = #command_struct,
            responses(
                (status = 200, body = #response_body, description = "Success"),
                (status = 400, description = "Validation error"),
                #unauthorized_response
                (status = 500, description = "Internal server error")
            )
            #security_clause
            #deprecated_attr
        )]
    };

    if cmd.requires_id {
        generate_handler_with_id(
            entity,
            cmd,
            &handler_name,
            &handler_method,
            &command_struct,
            &handler_trait,
            &response_type,
            &utoipa_attr
        )
    } else {
        generate_handler_without_id(
            entity,
            cmd,
            &handler_name,
            &handler_method,
            &command_struct,
            &handler_trait,
            &response_type,
            &utoipa_attr
        )
    }
}

/// Builds the `security(...)` argument of `#[utoipa::path]` for one scheme.
///
/// utoipa expects a list of parenthesised requirements, so the single
/// requirement is emitted as `security(("scheme" = []))`.
fn security_requirement(scheme_name: &str) -> TokenStream {
    quote! { security((#scheme_name = [])) }
}
/// Emits the handler function for a command that carries no entity id in the
/// URL.
///
/// The generated function:
/// - takes the command handler from an `Extension<Arc<H>>`,
/// - reads the command from the JSON request body,
/// - builds the context with `H::Context::default()`,
/// - calls `handler_method` and wraps a successful result in `Json`,
/// - maps every error to `500 Internal Server Error`.
///
/// `utoipa_attr` is placed verbatim above the generated function.
// The handler is assembled from many independently computed fragments, hence
// the long parameter list.
#[allow(clippy::too_many_arguments)]
fn generate_handler_without_id(
    entity: &EntityDef,
    cmd: &CommandDef,
    handler_name: &syn::Ident,
    handler_method: &syn::Ident,
    command_struct: &syn::Ident,
    handler_trait: &syn::Ident,
    response_type: &TokenStream,
    utoipa_attr: &TokenStream
) -> TokenStream {
    let visibility = &entity.vis;
    let summary = format!(
        "HTTP handler for {} command.\n\nGenerated by entity-derive.",
        cmd.name
    );

    quote! {
        #[doc = #summary]
        #utoipa_attr
        #visibility async fn #handler_name<H>(
            axum::extract::Extension(handler): axum::extract::Extension<std::sync::Arc<H>>,
            axum::extract::Json(cmd): axum::extract::Json<#command_struct>,
        ) -> Result<axum::response::Json<#response_type>, axum::http::StatusCode>
        where
            H: #handler_trait + 'static,
            H::Context: Default,
        {
            let ctx = H::Context::default();
            match handler.#handler_method(cmd, &ctx).await {
                Ok(result) => Ok(axum::response::Json(result)),
                Err(_) => Err(axum::http::StatusCode::INTERNAL_SERVER_ERROR),
            }
        }
    }
}
/// Emits the handler function for a command that targets an entity by id.
///
/// Compared to [`generate_handler_without_id`], the generated function also
/// extracts the id from the URL path (typed as the entity's id field) and
/// writes it into `cmd.id` before dispatching, so the path value replaces
/// whatever id the JSON body carried.
///
/// A successful result is wrapped in `Json`; every error is mapped to
/// `500 Internal Server Error`. `utoipa_attr` is placed verbatim above the
/// generated function.
// The handler is assembled from many independently computed fragments, hence
// the long parameter list.
#[allow(clippy::too_many_arguments)]
fn generate_handler_with_id(
    entity: &EntityDef,
    cmd: &CommandDef,
    handler_name: &syn::Ident,
    handler_method: &syn::Ident,
    command_struct: &syn::Ident,
    handler_trait: &syn::Ident,
    response_type: &TokenStream,
    utoipa_attr: &TokenStream
) -> TokenStream {
    let visibility = &entity.vis;
    let path_id_type = &entity.id_field().ty;
    let summary = format!(
        "HTTP handler for {} command.\n\nGenerated by entity-derive.",
        cmd.name
    );

    quote! {
        #[doc = #summary]
        #utoipa_attr
        #visibility async fn #handler_name<H>(
            axum::extract::Extension(handler): axum::extract::Extension<std::sync::Arc<H>>,
            axum::extract::Path(id): axum::extract::Path<#path_id_type>,
            axum::extract::Json(mut cmd): axum::extract::Json<#command_struct>,
        ) -> Result<axum::response::Json<#response_type>, axum::http::StatusCode>
        where
            H: #handler_trait + 'static,
            H::Context: Default,
        {
            cmd.id = id;
            let ctx = H::Context::default();
            match handler.#handler_method(cmd, &ctx).await {
                Ok(result) => Ok(axum::response::Json(result)),
                Err(_) => Err(axum::http::StatusCode::INTERNAL_SERVER_ERROR),
            }
        }
    }
}
/// Derives the handler function identifier as `<command>_<entity>`, with both
/// parts converted to snake_case (e.g. `UpdateEmail` on `User` becomes
/// `update_email_user`).
fn handler_function_name(entity: &EntityDef, cmd: &CommandDef) -> syn::Ident {
    let command_part = cmd.name.to_string().to_case(Case::Snake);
    let entity_part = entity.name_str().to_case(Case::Snake);
    format_ident!("{}_{}", command_part, entity_part)
}
/// Builds the route path of a command handler.
///
/// The path is `<prefix>/<entity>/<command>`, where `<prefix>` comes from the
/// API configuration's `full_path_prefix()` and the entity and command names
/// are kebab-cased. Commands that require an id get an `{id}` segment between
/// the entity and the command: `<prefix>/<entity>/{id}/<command>`.
fn build_path(entity: &EntityDef, cmd: &CommandDef) -> String {
    let prefix = entity.api_config().full_path_prefix();
    let entity_segment = entity.name_str().to_case(Case::Kebab);
    let command_segment = cmd.name.to_string().to_case(Case::Kebab);
    let id_segment = if cmd.requires_id { "/{id}" } else { "" };

    format!(
        "{}/{}{}/{}",
        prefix, entity_segment, id_segment, command_segment
    )
}
/// Maps a command kind to the lowercase HTTP method name used for its route:
/// `put` for updates, `delete` for deletes, and `post` for creates and custom
/// commands.
fn http_method_for_command(cmd: &CommandDef) -> &'static str {
    match cmd.kind {
        CommandKindHint::Update => "put",
        CommandKindHint::Delete => "delete",
        CommandKindHint::Create | CommandKindHint::Custom => "post"
    }
}
/// Translates a configured security scheme into the scheme name referenced
/// from the generated OpenAPI annotation.
///
/// `"api_key"`, `"admin"` and `"oauth2"` have dedicated names; every other
/// value, including `"bearer"`, resolves to `"bearer_auth"`.
fn security_scheme_name(scheme: &str) -> &'static str {
    match scheme {
        "api_key" => "api_key",
        "admin" => "admin_auth",
        "oauth2" => "oauth2",
        // Bearer authentication doubles as the fallback for unknown schemes.
        _ => "bearer_auth"
    }
}
/// Determines the response type of a command handler.
///
/// Returns a pair `(response_type, response_body)`: the first is used as the
/// handler's return payload type, the second as the documented response body.
/// Both currently hold the same tokens:
///
/// - the command's explicit `result_type` when one is set;
/// - `()` for delete commands;
/// - the entity type for every other command kind.
fn response_type_for_command(entity: &EntityDef, cmd: &CommandDef) -> (TokenStream, TokenStream) {
    let entity_name = entity.name();

    let tokens = match &cmd.result_type {
        Some(explicit) => quote! { #explicit },
        None => match cmd.kind {
            CommandKindHint::Delete => quote! { () },
            CommandKindHint::Create | CommandKindHint::Update | CommandKindHint::Custom => {
                quote! { #entity_name }
            }
        }
    };

    (tokens.clone(), tokens)
}
#[cfg(test)]
mod tests {
    use proc_macro2::Span;
    use syn::Ident;

    use super::*;
    use crate::entity::parse::{CommandDef, CommandKindHint, CommandSource, EntityDef};

    // ---------------------------------------------------------------------
    // Command fixtures
    // ---------------------------------------------------------------------

    /// Command fixture with an explicit (possibly absent) security override.
    fn create_command_with_security(
        name: &str,
        requires_id: bool,
        kind: CommandKindHint,
        security: Option<String>
    ) -> CommandDef {
        CommandDef {
            name: Ident::new(name, Span::call_site()),
            source: CommandSource::Create,
            requires_id,
            result_type: None,
            kind,
            security
        }
    }

    /// Plain command fixture: no custom result type, no security override.
    fn create_test_command(name: &str, requires_id: bool, kind: CommandKindHint) -> CommandDef {
        create_command_with_security(name, requires_id, kind, None)
    }

    /// Command fixture without id that returns a custom result type.
    fn create_command_with_result(
        name: &str,
        kind: CommandKindHint,
        result_type: syn::Type
    ) -> CommandDef {
        CommandDef {
            result_type: Some(result_type),
            ..create_test_command(name, false, kind)
        }
    }

    // ---------------------------------------------------------------------
    // Entity fixtures
    // ---------------------------------------------------------------------

    /// Parses a derive input into an entity definition, panicking on failure.
    fn parse_entity(input: syn::DeriveInput) -> EntityDef {
        EntityDef::from_derive_input(&input).unwrap()
    }

    /// `User` entity with a single `Register` command.
    fn register_entity() -> EntityDef {
        parse_entity(syn::parse_quote! {
            #[entity(table = "users", commands, api(tag = "Users"))]
            #[command(Register)]
            pub struct User {
                #[id]
                pub id: uuid::Uuid,
            }
        })
    }

    /// `User` entity with a single `UpdateEmail` command.
    fn update_email_entity() -> EntityDef {
        parse_entity(syn::parse_quote! {
            #[entity(table = "users", commands, api(tag = "Users"))]
            #[command(UpdateEmail: email)]
            pub struct User {
                #[id]
                pub id: uuid::Uuid,
            }
        })
    }

    /// `User` entity with a `Register` command and API-level bearer security.
    fn bearer_secured_entity() -> EntityDef {
        parse_entity(syn::parse_quote! {
            #[entity(table = "users", commands, api(tag = "Users", security = "bearer"))]
            #[command(Register)]
            pub struct User {
                #[id]
                pub id: uuid::Uuid,
            }
        })
    }

    // ---------------------------------------------------------------------
    // http_method_for_command
    // ---------------------------------------------------------------------

    #[test]
    fn http_method_create() {
        let cmd = create_test_command("Register", false, CommandKindHint::Create);
        assert_eq!(http_method_for_command(&cmd), "post");
    }

    #[test]
    fn http_method_update() {
        let cmd = create_test_command("Update", true, CommandKindHint::Update);
        assert_eq!(http_method_for_command(&cmd), "put");
    }

    #[test]
    fn http_method_delete() {
        let cmd = create_test_command("Delete", true, CommandKindHint::Delete);
        assert_eq!(http_method_for_command(&cmd), "delete");
    }

    #[test]
    fn http_method_custom() {
        let cmd = create_test_command("Transfer", false, CommandKindHint::Custom);
        assert_eq!(http_method_for_command(&cmd), "post");
    }

    // ---------------------------------------------------------------------
    // security_scheme_name
    // ---------------------------------------------------------------------

    #[test]
    fn security_scheme_bearer() {
        assert_eq!(security_scheme_name("bearer"), "bearer_auth");
    }

    #[test]
    fn security_scheme_api_key() {
        assert_eq!(security_scheme_name("api_key"), "api_key");
    }

    #[test]
    fn security_scheme_admin() {
        assert_eq!(security_scheme_name("admin"), "admin_auth");
    }

    #[test]
    fn security_scheme_oauth2() {
        assert_eq!(security_scheme_name("oauth2"), "oauth2");
    }

    #[test]
    fn security_scheme_unknown_defaults_to_bearer() {
        assert_eq!(security_scheme_name("unknown"), "bearer_auth");
    }

    // ---------------------------------------------------------------------
    // handler_function_name
    // ---------------------------------------------------------------------

    #[test]
    fn handler_function_name_simple() {
        let entity = register_entity();
        let cmd = create_test_command("Register", false, CommandKindHint::Create);
        let name = handler_function_name(&entity, &cmd);
        assert_eq!(name.to_string(), "register_user");
    }

    #[test]
    fn handler_function_name_camel_case() {
        let entity = update_email_entity();
        let cmd = create_test_command("UpdateEmail", true, CommandKindHint::Update);
        let name = handler_function_name(&entity, &cmd);
        assert_eq!(name.to_string(), "update_email_user");
    }

    // ---------------------------------------------------------------------
    // build_path
    // ---------------------------------------------------------------------

    #[test]
    fn build_path_without_id() {
        let entity = register_entity();
        let cmd = create_test_command("Register", false, CommandKindHint::Create);
        let path = build_path(&entity, &cmd);
        assert_eq!(path, "/user/register");
    }

    #[test]
    fn build_path_with_id() {
        let entity = update_email_entity();
        let cmd = create_test_command("UpdateEmail", true, CommandKindHint::Update);
        let path = build_path(&entity, &cmd);
        assert_eq!(path, "/user/{id}/update-email");
    }

    #[test]
    fn build_path_with_prefix() {
        let entity = parse_entity(syn::parse_quote! {
            #[entity(table = "users", commands, api(tag = "Users", path_prefix = "/api/v1"))]
            #[command(Register)]
            pub struct User {
                #[id]
                pub id: uuid::Uuid,
            }
        });
        let cmd = create_test_command("Register", false, CommandKindHint::Create);
        let path = build_path(&entity, &cmd);
        assert_eq!(path, "/api/v1/user/register");
    }

    // ---------------------------------------------------------------------
    // response_type_for_command
    // ---------------------------------------------------------------------

    #[test]
    fn response_type_delete() {
        let entity = parse_entity(syn::parse_quote! {
            #[entity(table = "users", commands, api(tag = "Users"))]
            #[command(Delete, requires_id, kind = "delete")]
            pub struct User {
                #[id]
                pub id: uuid::Uuid,
            }
        });
        let cmd = create_test_command("Delete", true, CommandKindHint::Delete);
        let (resp_type, _) = response_type_for_command(&entity, &cmd);
        assert_eq!(resp_type.to_string(), "()");
    }

    #[test]
    fn response_type_create() {
        let entity = register_entity();
        let cmd = create_test_command("Register", false, CommandKindHint::Create);
        let (resp_type, _) = response_type_for_command(&entity, &cmd);
        assert_eq!(resp_type.to_string(), "User");
    }

    #[test]
    fn response_type_custom_result() {
        let entity = register_entity();
        let result_type: syn::Type = syn::parse_quote!(CustomResult);
        let cmd = create_command_with_result("Transfer", CommandKindHint::Custom, result_type);
        let (resp_type, _) = response_type_for_command(&entity, &cmd);
        assert_eq!(resp_type.to_string(), "CustomResult");
    }

    // ---------------------------------------------------------------------
    // generate / generate_handler
    // ---------------------------------------------------------------------

    #[test]
    fn generate_empty_for_no_commands() {
        let entity = parse_entity(syn::parse_quote! {
            #[entity(table = "users", api(tag = "Users"))]
            pub struct User {
                #[id]
                pub id: uuid::Uuid,
            }
        });
        let output = generate(&entity);
        assert!(output.is_empty());
    }

    #[test]
    fn generate_handler_without_id() {
        let entity = register_entity();
        let cmd = create_test_command("Register", false, CommandKindHint::Create);
        let output_str = generate_handler(&entity, &cmd).to_string();
        assert!(output_str.contains("register_user"));
        assert!(output_str.contains("UserCommandHandler"));
    }

    #[test]
    fn generate_handler_with_id() {
        let entity = update_email_entity();
        let cmd = create_test_command("UpdateEmail", true, CommandKindHint::Update);
        let output_str = generate_handler(&entity, &cmd).to_string();
        assert!(output_str.contains("update_email_user"));
        assert!(output_str.contains("Path"));
        assert!(output_str.contains("cmd . id = id"));
    }

    #[test]
    fn generate_handler_with_security() {
        let entity = bearer_secured_entity();
        let cmd = create_test_command("Register", false, CommandKindHint::Create);
        let output_str = generate_handler(&entity, &cmd).to_string();
        assert!(output_str.contains("security"));
        assert!(output_str.contains("bearer_auth"));
    }

    #[test]
    fn generate_handler_public_command() {
        let entity = bearer_secured_entity();
        let cmd = create_command_with_security(
            "Register",
            false,
            CommandKindHint::Create,
            Some("none".to_string())
        );
        let output_str = generate_handler(&entity, &cmd).to_string();
        assert!(!output_str.contains("security"));
    }

    #[test]
    fn generate_handler_command_level_security() {
        let entity = parse_entity(syn::parse_quote! {
            #[entity(table = "users", commands, api(tag = "Users"))]
            #[command(AdminDelete, requires_id, security = "admin")]
            pub struct User {
                #[id]
                pub id: uuid::Uuid,
            }
        });
        let cmd = create_command_with_security(
            "AdminDelete",
            true,
            CommandKindHint::Delete,
            Some("admin".to_string())
        );
        let output_str = generate_handler(&entity, &cmd).to_string();
        assert!(output_str.contains("admin_auth"));
    }

    #[test]
    fn generate_handler_deprecated() {
        let entity = parse_entity(syn::parse_quote! {
            #[entity(table = "users", commands, api(tag = "Users", deprecated_in = "2.0"))]
            #[command(Register)]
            pub struct User {
                #[id]
                pub id: uuid::Uuid,
            }
        });
        let cmd = create_test_command("Register", false, CommandKindHint::Create);
        let output_str = generate_handler(&entity, &cmd).to_string();
        assert!(output_str.contains("deprecated = true"));
    }

    #[test]
    fn generate_all_handlers() {
        let entity = parse_entity(syn::parse_quote! {
            #[entity(table = "users", commands, api(tag = "Users"))]
            #[command(Register)]
            #[command(UpdateEmail: email)]
            pub struct User {
                #[id]
                pub id: uuid::Uuid,
            }
        });
        let output_str = generate(&entity).to_string();
        assert!(output_str.contains("register_user"));
        assert!(output_str.contains("update_email_user"));
    }
}