use proc_macro2::TokenStream as TokenStream2;
use quote::{format_ident, quote};
use syn::{Ident, Item, ItemFn, Path, Result, Type};
use crate::route::DataKind;
/// Builds the hidden `__apigate_pipeline_<fn>` function for a route and
/// returns the tokens that refer to it.
///
/// When the route has no hooks, no `json`/`query`/`form` data type, no `map`
/// function and no typed path, nothing is generated and the returned tokens
/// are the expression `None`. Otherwise the generated function is pushed onto
/// `generated_items` and the returned tokens are
/// `Some(<ident> as <apigate_path>::PipelineFn)`.
///
/// The generated body runs, in order: the path phase, every hook (each called
/// as `hook(&mut ctx, &mut scope).await?`), and finally the body phase, whose
/// value is the value of the async block.
///
/// # Errors
///
/// Propagates the error from `build_body_phase`, which rejects `map` combined
/// with `multipart` or with no data kind at all.
pub(crate) fn generate_pipeline_wrapper(
    apigate_path: &TokenStream2,
    f: &ItemFn,
    hooks: &[Path],
    data: &DataKind,
    map_fn: Option<&Path>,
    path_type: Option<&Type>,
    generated_items: &mut Vec<Item>,
) -> Result<TokenStream2> {
    let has_body = matches!(
        data,
        DataKind::Json(_) | DataKind::Query(_) | DataKind::Form(_)
    );

    // Nothing to validate, map, extract or hook into: no pipeline function.
    if hooks.is_empty() && !has_body && map_fn.is_none() && path_type.is_none() {
        return Ok(quote!(None));
    }

    let pipeline_ident = format_ident!("__apigate_pipeline_{}", f.sig.ident);
    let path_phase = build_path_phase(path_type);
    // An empty hook list simply expands to no tokens, so no special case is
    // needed for routes without hooks.
    let hook_calls = hooks
        .iter()
        .map(|hook| quote!(#hook(&mut ctx, &mut scope).await?;));
    let body_phase = build_body_phase(apigate_path, data, map_fn, &f.sig.ident)?;

    let item: Item = syn::parse_quote! {
        #[doc(hidden)]
        // Depending on which phases are emitted the body may never touch
        // `ctx` (e.g. JSON validation without `map` only uses `scope`), so
        // keep the resulting unused lints out of the user's crate.
        #[allow(unused_mut, unused_variables)]
        fn #pipeline_ident<'a>(
            mut ctx: #apigate_path::PartsCtx<'a>,
            mut scope: #apigate_path::RequestScope,
        ) -> #apigate_path::PipelineFuture<'a> {
            ::std::boxed::Box::pin(async move {
                #path_phase
                #(#hook_calls)*
                #body_phase
            })
        }
    };
    generated_items.push(item);

    Ok(quote!(Some(#pipeline_ident as #apigate_path::PipelineFn)))
}
fn build_path_phase(path_type: Option<&Type>) -> TokenStream2 {
match path_type {
Some(ty) => quote! {
let __apigate_path_value: #ty = ctx.extract_path::<#ty>().await?;
scope.insert(__apigate_path_value);
},
None => quote!(),
}
}
fn build_body_phase(
apigate_path: &TokenStream2,
data: &DataKind,
map_fn: Option<&Path>,
route_fn_ident: &Ident,
) -> Result<TokenStream2> {
match (map_fn, data) {
(None, DataKind::None | DataKind::Multipart) => {
let take = take_body_expr(apigate_path);
Ok(quote!(#take))
}
(Some(_), DataKind::Multipart) => Err(syn::Error::new_spanned(
route_fn_ident,
"`map` is not supported with `multipart`",
)),
(Some(_), DataKind::None) => Err(syn::Error::new_spanned(
route_fn_ident,
"`map` requires one of `query = T`, `json = T`, or `form = T`",
)),
(map_fn, DataKind::Json(ty)) => Ok(json_phase(apigate_path, ty, map_fn)),
(map_fn, DataKind::Query(ty)) => Ok(query_phase(apigate_path, ty, map_fn)),
(map_fn, DataKind::Form(ty)) => Ok(form_phase(apigate_path, ty, map_fn)),
}
}
/// Emits an expression that calls `scope.take_body()` and converts a missing
/// body into an `ApigateError::internal` error via `ok_or_else`.
///
/// The expression is not terminated with `?` or `;`, so it can serve either
/// as the tail value of the pipeline or as the right-hand side of a `let`.
fn take_body_expr(apigate_path: &TokenStream2) -> TokenStream2 {
    let internal_error = quote!(#apigate_path::ApigateError::internal);
    quote! {
        scope.take_body()
            .ok_or_else(|| #internal_error("request body already consumed"))
    }
}
/// Emits `let body = <take-body expression>?;`, binding the taken request
/// body to the generated local `body` and propagating the error otherwise.
fn take_body_let(apigate_path: &TokenStream2) -> TokenStream2 {
    let body_expr = take_body_expr(apigate_path);
    quote! {
        let body = #body_expr?;
    }
}
/// Emits the statements that read the whole request body into memory.
///
/// The generated code binds `body` (see `take_body_let`), reads it with
/// `axum::body::to_bytes` using `scope.body_limit()` as the limit, and binds
/// the result to the generated local `bytes`. Any error from `to_bytes` is
/// reported as `ApigateError::payload_too_large`.
fn read_body_bytes(apigate_path: &TokenStream2) -> TokenStream2 {
    let bind_body = take_body_let(apigate_path);
    let to_bytes = quote!(#apigate_path::__private::axum::body::to_bytes);
    let too_large = quote!(#apigate_path::ApigateError::payload_too_large);

    quote! {
        #bind_body
        let limit = scope.body_limit();
        let bytes = #to_bytes(body, limit)
            .await
            .map_err(|_| #too_large("request body is too large"))?;
    }
}
fn json_phase(apigate_path: &TokenStream2, ty: &Type, map_fn: Option<&Path>) -> TokenStream2 {
let read = read_body_bytes(apigate_path);
match map_fn {
Some(map_fn) => quote! {
#read
let input: #ty = #apigate_path::__private::serde_json::from_slice(&bytes)
.map_err(|_| #apigate_path::ApigateError::bad_request("invalid json body"))?;
let output = #map_fn(input, &mut ctx, &mut scope).await?;
let new_body = #apigate_path::__private::serde_json::to_vec(&output)
.map_err(|_| #apigate_path::ApigateError::internal("failed to serialize mapped json"))?;
ctx.headers_mut().insert(
#apigate_path::__private::http::header::CONTENT_TYPE,
#apigate_path::__private::http::HeaderValue::from_static("application/json"),
);
ctx.headers_mut().remove(#apigate_path::__private::http::header::CONTENT_LENGTH);
Ok(#apigate_path::__private::axum::body::Body::from(new_body))
},
None => quote! {
#read
let _: #ty = #apigate_path::__private::serde_json::from_slice(&bytes)
.map_err(|_| #apigate_path::ApigateError::bad_request("invalid json body"))?;
Ok(#apigate_path::__private::axum::body::Body::from(bytes))
},
}
}
/// Emits the body phase for a `query = T` route.
///
/// Both variants parse the query string of `ctx.uri()` as `ty` through axum's
/// `Query::try_from_uri`, failing with a `bad_request` error, and finish with
/// the "take the body out of the scope" expression as the resulting value.
///
/// * Without `map_fn` the parsed value is discarded (validation only).
/// * With `map_fn` the parsed value is passed to
///   `map_fn(input, &mut ctx, &mut scope)`, its output is re-encoded with
///   `serde_urlencoded::to_string`, and the request URI's query is replaced
///   by the encoded output (no `?` is appended when the output is empty).
fn query_phase(apigate_path: &TokenStream2, ty: &Type, map_fn: Option<&Path>) -> TokenStream2 {
    let take = take_body_expr(apigate_path);
    match map_fn {
        Some(map_fn) => quote! {
            let input: #ty = #apigate_path::__private::axum::extract::Query::<#ty>::try_from_uri(ctx.uri())
                .map_err(|_| #apigate_path::ApigateError::bad_request("invalid query"))?
                .0;
            let output = #map_fn(input, &mut ctx, &mut scope).await?;
            let encoded = #apigate_path::__private::serde_urlencoded::to_string(&output)
                .map_err(|_| #apigate_path::ApigateError::internal("failed to serialize mapped query"))?;
            let mut path_and_query = ctx.uri().path().to_string();
            if !encoded.is_empty() {
                path_and_query.push('?');
                path_and_query.push_str(&encoded);
            }
            // Swap only the path-and-query component so that a scheme and
            // authority present on the incoming URI survive the rewrite.
            let mut uri_parts = ctx.uri().clone().into_parts();
            uri_parts.path_and_query = Some(
                path_and_query
                    .parse::<#apigate_path::__private::http::uri::PathAndQuery>()
                    .map_err(|_| #apigate_path::ApigateError::internal("failed to rebuild uri"))?,
            );
            *ctx.uri_mut() = #apigate_path::__private::http::Uri::from_parts(uri_parts)
                .map_err(|_| #apigate_path::ApigateError::internal("failed to rebuild uri"))?;
            #take
        },
        None => quote! {
            let _: #ty = #apigate_path::__private::axum::extract::Query::<#ty>::try_from_uri(ctx.uri())
                .map_err(|_| #apigate_path::ApigateError::bad_request("invalid query"))?
                .0;
            #take
        },
    }
}
/// Emits the body phase for a `form = T` route.
///
/// The generated code dispatches on the request method:
///
/// * `GET` and `HEAD` run the query-string branch (`form_get_branch`).
/// * Every other method must carry a `Content-Type` header whose value starts
///   with `application/x-www-form-urlencoded`; otherwise the pipeline returns
///   an `unsupported_media_type` error. A missing or non-UTF-8 header is
///   treated as an empty string. On success the body branch
///   (`form_post_branch`) runs.
fn form_phase(apigate_path: &TokenStream2, ty: &Type, map_fn: Option<&Path>) -> TokenStream2 {
    let http = quote!(#apigate_path::__private::http);
    let unsupported = quote!(#apigate_path::ApigateError::unsupported_media_type);

    let body_expr = take_body_expr(apigate_path);
    let from_query = form_get_branch(apigate_path, ty, map_fn, &body_expr);
    let from_body = form_post_branch(apigate_path, ty, map_fn);

    quote! {
        let method = ctx.method().clone();
        if method == #http::Method::GET
            || method == #http::Method::HEAD
        {
            #from_query
        } else {
            let content_type = ctx
                .headers()
                .get(#http::header::CONTENT_TYPE)
                .and_then(|v| v.to_str().ok())
                .unwrap_or_default();
            if !content_type.starts_with("application/x-www-form-urlencoded") {
                return Err(#unsupported(
                    "expected application/x-www-form-urlencoded",
                ));
            }
            #from_body
        }
    }
}
/// Emits the form branch used for `GET`/`HEAD` requests, where the form data
/// lives in the URI query string.
///
/// Both variants parse `ctx.uri().query()` (empty string when absent) as `ty`
/// with `serde_urlencoded::from_str`, failing with a `bad_request` error, and
/// finish with `take` — the caller-supplied expression that yields the body —
/// as the resulting value.
///
/// * Without `map_fn` the parsed value is discarded (validation only).
/// * With `map_fn` the parsed value is passed to
///   `map_fn(input, &mut ctx, &mut scope)`, its output is re-encoded with
///   `serde_urlencoded::to_string`, and the request URI's query is replaced
///   by the encoded output (no `?` is appended when the output is empty).
fn form_get_branch(
    apigate_path: &TokenStream2,
    ty: &Type,
    map_fn: Option<&Path>,
    take: &TokenStream2,
) -> TokenStream2 {
    match map_fn {
        Some(map_fn) => quote! {
            let raw = ctx.uri().query().unwrap_or_default();
            let input: #ty = #apigate_path::__private::serde_urlencoded::from_str(raw)
                .map_err(|_| #apigate_path::ApigateError::bad_request("invalid form query"))?;
            let output = #map_fn(input, &mut ctx, &mut scope).await?;
            let encoded = #apigate_path::__private::serde_urlencoded::to_string(&output)
                .map_err(|_| #apigate_path::ApigateError::internal("failed to serialize mapped form"))?;
            let mut path_and_query = ctx.uri().path().to_string();
            if !encoded.is_empty() {
                path_and_query.push('?');
                path_and_query.push_str(&encoded);
            }
            // Swap only the path-and-query component so that a scheme and
            // authority present on the incoming URI survive the rewrite.
            let mut uri_parts = ctx.uri().clone().into_parts();
            uri_parts.path_and_query = Some(
                path_and_query
                    .parse::<#apigate_path::__private::http::uri::PathAndQuery>()
                    .map_err(|_| #apigate_path::ApigateError::internal("failed to rebuild uri"))?,
            );
            *ctx.uri_mut() = #apigate_path::__private::http::Uri::from_parts(uri_parts)
                .map_err(|_| #apigate_path::ApigateError::internal("failed to rebuild uri"))?;
            #take
        },
        None => quote! {
            let raw = ctx.uri().query().unwrap_or_default();
            let _: #ty = #apigate_path::__private::serde_urlencoded::from_str(raw)
                .map_err(|_| #apigate_path::ApigateError::bad_request("invalid form query"))?;
            #take
        },
    }
}
fn form_post_branch(apigate_path: &TokenStream2, ty: &Type, map_fn: Option<&Path>) -> TokenStream2 {
let read = read_body_bytes(apigate_path);
match map_fn {
Some(map_fn) => quote! {
#read
let input: #ty = #apigate_path::__private::serde_urlencoded::from_bytes(&bytes)
.map_err(|_| #apigate_path::ApigateError::bad_request("invalid form body"))?;
let output = #map_fn(input, &mut ctx, &mut scope).await?;
let encoded = #apigate_path::__private::serde_urlencoded::to_string(&output)
.map_err(|_| #apigate_path::ApigateError::internal("failed to serialize mapped form"))?;
ctx.headers_mut().insert(
#apigate_path::__private::http::header::CONTENT_TYPE,
#apigate_path::__private::http::HeaderValue::from_static("application/x-www-form-urlencoded"),
);
ctx.headers_mut().remove(#apigate_path::__private::http::header::CONTENT_LENGTH);
Ok(#apigate_path::__private::axum::body::Body::from(encoded))
},
None => quote! {
#read
let _: #ty = #apigate_path::__private::serde_urlencoded::from_bytes(&bytes)
.map_err(|_| #apigate_path::ApigateError::bad_request("invalid form body"))?;
Ok(#apigate_path::__private::axum::body::Body::from(bytes))
},
}
}