mod args;
mod endpoints;
mod openapi;
mod types;
use args::AxumApiArgs;
use endpoints::Endpoints;
use proc_macro::TokenStream;
use quote::quote;
use syn::{DeriveInput, parse_macro_input};
use types::Types;
pub fn generate(attr: TokenStream, item: TokenStream) -> TokenStream {
let args = parse_macro_input!(attr as AxumApiArgs);
let input = parse_macro_input!(item as DeriveInput);
let state_type_name = &args.state_type;
let enable_openapi = args.enable_openapi();
let additional_components = args.components();
let server_url = args.server();
let additional_paths = args.paths();
let struct_name = &input.ident;
let vis = &input.vis;
let types = Types { enable_openapi, vis: vis.clone() };
let endpoints =
Endpoints { enable_openapi, struct_name: struct_name.clone(), vis: vis.clone() };
let api_doc = if enable_openapi {
let servers_attr = server_url
.map(|url| {
quote! {
servers(
(url = #url)
),
}
})
.unwrap_or_default();
quote! {
#[derive(::utoipa::OpenApi)]
#[openapi(
#servers_attr
paths(
create_entity,
list_entities,
get_entity_by_id,
update_entity,
patch_entity_by_id,
remove_entity,
#(#additional_paths),*
),
components(
responses(
OperationResponse,
EntitiesResponse,
ListResponse,
GetEntityResponse,
::stately::ApiError,
),
schemas(
Entity,
StateEntry,
OperationResponse,
EntitiesResponse,
EntitiesMap,
ListResponse,
GetEntityResponse,
::stately::Summary,
::stately::EntityId,
#(#additional_components),*
)
),
tags(
(name = "entity", description = "Entity management endpoints"),
)
)]
}
} else {
quote! {}
};
let expanded = quote! {
#[derive(Clone)]
#api_doc
#vis struct #struct_name {
#vis state: ::std::sync::Arc<::tokio::sync::RwLock<#state_type_name>>,
}
impl #struct_name {
#vis fn new(state: #state_type_name) -> Self {
Self {
state: ::std::sync::Arc::new(::tokio::sync::RwLock::new(state)),
}
}
#vis fn new_from_state(state: ::std::sync::Arc<::tokio::sync::RwLock<#state_type_name>>) -> Self {
Self { state }
}
}
impl #struct_name {
pub fn router<S>(state: S) -> ::axum::Router<S>
where
S: Send + Sync + Clone + 'static,
#struct_name: ::axum::extract::FromRef<S>,
{
::axum::Router::new()
.route(
"/",
::axum::routing::get(get_entities)
.put(create_entity)
.layer(::tower_http::compression::CompressionLayer::new())
)
.route(
"/list",
::axum::routing::get(list_all_entities)
.layer(::tower_http::compression::CompressionLayer::new())
)
.route(
"/list/{type}",
::axum::routing::get(list_entities)
.layer(::tower_http::compression::CompressionLayer::new())
)
.route(
"/{id}",
::axum::routing::get(get_entity_by_id)
.post(update_entity)
.patch(patch_entity_by_id)
)
.route("/{entry}/{id}", ::axum::routing::delete(remove_entity))
.with_state(state)
}
pub fn event_middleware<T>(
event_tx: ::tokio::sync::mpsc::Sender<T>
) -> impl Fn(::axum::http::Request<::axum::body::Body>, ::axum::middleware::Next) -> std::pin::Pin<Box<dyn std::future::Future<Output = ::axum::response::Response> + Send>> + Clone
where
T: From<ResponseEvent> + Send + 'static,
{
move |req: ::axum::http::Request<::axum::body::Body>, next: ::axum::middleware::Next| {
let tx = event_tx.clone();
Box::pin(async move {
let response = next.run(req).await;
if let Some(event) = response.extensions().get::<ResponseEvent>() {
let converted: T = event.clone().into();
let _ = tx.send(converted).await;
}
response
})
}
}
}
#types
#endpoints
};
TokenStream::from(expanded)
}