use proc_macro::TokenStream;
use proc_macro2::TokenStream as TokenStream2;
use quote::quote;
use syn::{Attribute, Block, Expr, ExprLit, Lit, LitInt, LitStr};
/// Returns `true` for header names whose values must be appended rather than
/// replaced when several are configured on one handler.
///
/// The name is compared case-sensitively; callers pass already-validated
/// lowercase names. Currently only `set-cookie` qualifies.
fn is_multi_value_header(name: &str) -> bool {
    name == "set-cookie"
}
/// Attribute-macro body that ignores its arguments and returns the annotated
/// item unchanged.
///
/// NOTE(review): presumably registered for the helper attributes
/// (`#[http_code]`, `#[response_header]`, `#[redirect]`) so they compile when
/// not consumed by a handler macro. Confirm against the
/// `#[proc_macro_attribute]` registrations.
pub(crate) fn passthrough(_args: TokenStream, input: TokenStream) -> TokenStream {
    std::convert::identity(input)
}
/// Response-shaping attributes collected from one handler by
/// `take_response_shapers`.
///
/// `Default` yields "no shaping at all" (see `is_empty`).
#[derive(Default)]
pub(crate) struct ResponseShapers {
    // Status from `#[http_code(N)]`; validated to fit a u16 in 100..=999.
    pub http_code: Option<LitInt>,
    // `(name, value)` pairs from `#[response_header]`, in attribute order.
    // Names are validated lowercase tokens; values are validated visible ASCII.
    pub headers: Vec<(LitStr, LitStr)>,
    // Parsed `#[redirect(...)]`. This is mutually exclusive with `http_code`.
    pub redirect: Option<RedirectSpec>,
}
/// A parsed `#[redirect("url"[, status])]` attribute.
pub(crate) struct RedirectSpec {
    // Target URL, validated as non-space printable ASCII; sent as `Location`.
    pub url: LitStr,
    // Optional status in 300..=399; code generation falls back to 307 when absent.
    pub code: Option<LitInt>,
    // The original attribute, kept so later diagnostics can point at it.
    pub attr: Attribute,
}
impl ResponseShapers {
    /// Returns `true` when the handler carried no shaping attribute, i.e. no
    /// status override, no extra headers and no redirect.
    pub fn is_empty(&self) -> bool {
        let has_any =
            self.http_code.is_some() || !self.headers.is_empty() || self.redirect.is_some();
        !has_any
    }
}
/// Removes the response-shaping attributes (`#[http_code]`,
/// `#[response_header]`, `#[redirect]`) from `attrs`, validates them and
/// returns what was found.
///
/// `body` is the handler body. It is only inspected to require that
/// `#[redirect]` handlers have an empty body, because the generated redirect
/// response never calls the method.
///
/// # Errors
///
/// Returns a spanned `syn::Error` when any of the following holds:
/// - `#[http_code]` or `#[redirect]` appears more than once.
/// - `#[http_code]` is not an unsuffixed (or `u16`) integer in `100..=999`.
/// - A `#[response_header]` name or value fails validation.
/// - A single-valued header name is given more than once.
/// - `#[redirect]` is combined with `#[http_code]` or a `location` header.
/// - A `#[redirect]` handler has a non-empty body.
pub(crate) fn take_response_shapers(
    attrs: &mut Vec<Attribute>,
    body: &Block,
) -> syn::Result<ResponseShapers> {
    let mut out = ResponseShapers::default();

    // `#[http_code(N)]`: at most once per handler.
    while let Some(idx) = attrs.iter().position(|a| a.path().is_ident("http_code")) {
        if out.http_code.is_some() {
            return Err(syn::Error::new_spanned(
                &attrs[idx],
                "`#[http_code]` is allowed at most once per handler",
            ));
        }
        let attr = attrs.remove(idx);
        let lit = attr.parse_args::<LitInt>()?;
        // The literal is spliced verbatim into `StatusCode::from_u16(#lit)`.
        // `base10_parse` ignores suffixes, so a literal like `200i32` would pass
        // here and only fail later as a type error in generated code.
        if !matches!(lit.suffix(), "" | "u16") {
            return Err(syn::Error::new_spanned(
                &lit,
                "`#[http_code]` expects an unsuffixed or `u16` integer literal",
            ));
        }
        let n: u16 = lit.base10_parse().map_err(|e| {
            syn::Error::new_spanned(&lit, format!("`#[http_code]` expects a u16: {e}"))
        })?;
        // The generated `StatusCode::from_u16(..).expect(..)` relies on this check.
        if !(100..=999).contains(&n) {
            return Err(syn::Error::new_spanned(
                &lit,
                "`#[http_code]` expects a status in 100..=999",
            ));
        }
        out.http_code = Some(lit);
    }

    // `#[response_header("name", "value")]`: repeatable, kept in source order.
    while let Some(idx) = attrs
        .iter()
        .position(|a| a.path().is_ident("response_header"))
    {
        let attr = attrs.remove(idx);
        let (name, value) = parse_header_args(&attr)?;
        validate_header_name(&name)?;
        validate_header_value(&value)?;
        // Single-valued headers are emitted with `insert`, so a repeated name
        // would silently discard the earlier value. Names are validated
        // lowercase at this point, so plain equality is sufficient.
        let name_str = name.value();
        if !is_multi_value_header(&name_str)
            && out.headers.iter().any(|(n, _)| n.value() == name_str)
        {
            return Err(syn::Error::new_spanned(
                &name,
                format!(
                    "`#[response_header(\"{name_str}\", …)]` is given more than once; \
                     later values would overwrite earlier ones — combine them into \
                     a single comma-separated value"
                ),
            ));
        }
        out.headers.push((name, value));
    }

    // `#[redirect("url"[, status])]`: at most once per handler.
    while let Some(idx) = attrs.iter().position(|a| a.path().is_ident("redirect")) {
        if out.redirect.is_some() {
            return Err(syn::Error::new_spanned(
                &attrs[idx],
                "`#[redirect]` is allowed at most once per handler",
            ));
        }
        let attr = attrs.remove(idx);
        let spec = parse_redirect_args(&attr)?;
        out.redirect = Some(spec);
    }

    // Cross-attribute rules.
    if let (Some(_), Some(r)) = (&out.http_code, &out.redirect) {
        return Err(syn::Error::new_spanned(
            &r.attr,
            "`#[redirect]` and `#[http_code]` are mutually exclusive — \
`#[redirect]` sets the status itself",
        ));
    }
    if out.redirect.is_some()
        && let Some((name_lit, _)) = out
            .headers
            .iter()
            .find(|(n, _)| n.value().eq_ignore_ascii_case("location"))
    {
        return Err(syn::Error::new_spanned(
            name_lit,
            "`#[response_header(\"location\", …)]` cannot be combined with \
`#[redirect]` — the redirect URL already sets the Location header",
        ));
    }
    if let Some(spec) = &out.redirect
        && !body.stmts.is_empty()
    {
        let url = spec.url.value();
        return Err(syn::Error::new_spanned(
            body,
            format!(
                "`#[redirect({url:?})]` handlers must have an empty body — \
the method is not called, only the redirect URL is sent. \
Move side-effect work into a service the user is redirected to, \
or drop the body to opt in."
            ),
        ));
    }
    Ok(out)
}
fn parse_header_args(attr: &Attribute) -> syn::Result<(LitStr, LitStr)> {
use syn::Token;
use syn::punctuated::Punctuated;
let list: Punctuated<LitStr, Token![,]> = attr.parse_args_with(Punctuated::parse_terminated)?;
let mut iter = list.into_iter();
let name = iter.next().ok_or_else(|| {
syn::Error::new_spanned(
attr,
"`#[response_header]` expects two string literals: `name, value`",
)
})?;
let value = iter.next().ok_or_else(|| {
syn::Error::new_spanned(
attr,
"`#[response_header]` expects two string literals: `name, value`",
)
})?;
if iter.next().is_some() {
return Err(syn::Error::new_spanned(
attr,
"`#[response_header]` accepts exactly two arguments: `name, value`",
));
}
Ok((name, value))
}
fn validate_header_name(lit: &LitStr) -> syn::Result<()> {
let s = lit.value();
if s.is_empty() {
return Err(syn::Error::new_spanned(
lit,
"`#[response_header]` header name cannot be empty",
));
}
for c in s.bytes() {
let ok = matches!(c,
b'a'..=b'z' | b'0'..=b'9' | b'-' | b'_'
);
if !ok {
return Err(syn::Error::new_spanned(
lit,
"`#[response_header]` header name must be lowercase ASCII \
(a-z, 0-9, `-`, `_`)",
));
}
}
Ok(())
}
/// Checks that a `#[redirect]` URL can be sent verbatim as a `Location` header.
///
/// Every byte must be printable, non-space ASCII (0x21..=0x7e); anything else
/// must be percent-encoded by the author (RFC 3986). An empty URL is also
/// rejected. An empty `Location` resolves to the current request URL, which
/// would typically send the client into a redirect loop.
///
/// # Errors
///
/// Returns an error spanning `lit` naming the first offending byte, or stating
/// that the URL is empty.
fn validate_redirect_url(lit: &LitStr) -> syn::Result<()> {
    let url = lit.value();
    if url.is_empty() {
        return Err(syn::Error::new_spanned(
            lit,
            "`#[redirect]` URL cannot be empty",
        ));
    }
    if let Some(b) = url.bytes().find(|b| !(0x21..=0x7e).contains(b)) {
        return Err(syn::Error::new_spanned(
            lit,
            format!(
                "`#[redirect]` URL contains a non-printable-ASCII byte \
0x{b:02x}; percent-encode it or use ASCII (RFC 3986)"
            ),
        ));
    }
    Ok(())
}
fn validate_header_value(lit: &LitStr) -> syn::Result<()> {
let s = lit.value();
for c in s.bytes() {
let ok = c == b'\t' || (0x20..=0x7e).contains(&c);
if !ok {
return Err(syn::Error::new_spanned(
lit,
"`#[response_header]` header value must be printable ASCII \
(no CR/LF, no control bytes)",
));
}
}
Ok(())
}
/// Parses `#[redirect("url")]` or `#[redirect("url", status)]`.
///
/// The URL must be a string literal that passes `validate_redirect_url`. The
/// optional status must be an unsuffixed (or `u16`) integer literal in
/// `300..=399`.
///
/// # Errors
///
/// Returns a spanned `syn::Error` in these cases:
/// - The arguments are missing or are not literals of the expected kind.
/// - The URL is invalid.
/// - The status is out of range or carries a non-`u16` suffix.
/// - More than two arguments are given.
fn parse_redirect_args(attr: &Attribute) -> syn::Result<RedirectSpec> {
    use syn::Token;
    use syn::punctuated::Punctuated;
    let list: Punctuated<Expr, Token![,]> = attr.parse_args_with(Punctuated::parse_terminated)?;
    let mut iter = list.into_iter();
    let url_expr = iter.next().ok_or_else(|| {
        syn::Error::new_spanned(
            attr,
            "`#[redirect]` expects a URL literal: `#[redirect(\"…\")]` \
or `#[redirect(\"…\", 301)]`",
        )
    })?;
    let url = match url_expr {
        Expr::Lit(ExprLit {
            lit: Lit::Str(s), ..
        }) => s,
        other => {
            return Err(syn::Error::new_spanned(
                other,
                "`#[redirect]` URL must be a string literal",
            ));
        }
    };
    validate_redirect_url(&url)?;
    let code = match iter.next() {
        None => None,
        Some(expr) => {
            let lit = match expr {
                Expr::Lit(ExprLit {
                    lit: Lit::Int(i), ..
                }) => i,
                other => {
                    return Err(syn::Error::new_spanned(
                        other,
                        "`#[redirect]` status code must be an integer literal",
                    ));
                }
            };
            // The literal is spliced verbatim into `StatusCode::from_u16(#lit)`.
            // A foreign suffix (`301i32`) would otherwise surface as a type error
            // in generated code, since `base10_parse` ignores suffixes.
            if !matches!(lit.suffix(), "" | "u16") {
                return Err(syn::Error::new_spanned(
                    &lit,
                    "`#[redirect]` status must be an unsuffixed or `u16` integer literal",
                ));
            }
            let n: u16 = lit.base10_parse().map_err(|e| {
                syn::Error::new_spanned(&lit, format!("`#[redirect]` status not a u16: {e}"))
            })?;
            if !(300..=399).contains(&n) {
                return Err(syn::Error::new_spanned(
                    &lit,
                    "`#[redirect]` status must be in 300..=399",
                ));
            }
            Some(lit)
        }
    };
    if iter.next().is_some() {
        return Err(syn::Error::new_spanned(
            attr,
            "`#[redirect]` accepts at most two arguments: `url[, status]`",
        ));
    }
    Ok(RedirectSpec {
        url,
        code,
        attr: attr.clone(),
    })
}
pub(crate) fn apply_response_shapers(
shapers: &ResponseShapers,
call_expr: TokenStream2,
wrapper_args: &[syn::Ident],
returns_result: bool,
) -> TokenStream2 {
if let Some(redirect) = &shapers.redirect {
let url = &redirect.url;
let status_lit = match &redirect.code {
Some(lit) => quote! { #lit },
None => quote! { 307u16 },
};
let header_writes = headers_tokens(&shapers.headers);
return quote! {
{
let _ = (#(&#wrapper_args,)*);
let mut __response: ::poem::Response =
::poem::Response::builder()
.status(
::poem::http::StatusCode::from_u16(#status_lit)
.expect("redirect status validated at compile time"),
)
.header(::poem::http::header::LOCATION, #url)
.finish();
#header_writes
::poem::Result::<::poem::Response>::Ok(__response)
}
};
}
let status_apply = match &shapers.http_code {
Some(lit) => quote! {
__response.set_status(
::poem::http::StatusCode::from_u16(#lit)
.expect("status validated at compile time"),
);
},
None => quote! {},
};
let header_writes = headers_tokens(&shapers.headers);
let unwrap_ok = if returns_result {
quote! {
let __ok = match __out {
::core::result::Result::Ok(v) => v,
::core::result::Result::Err(e) => {
return ::core::result::Result::Err(::core::convert::From::from(e));
}
};
}
} else {
quote! { let __ok = __out; }
};
quote! {
{
let __out = #call_expr;
#unwrap_ok
let mut __response: ::poem::Response =
::poem::IntoResponse::into_response(__ok);
#status_apply
#header_writes
::poem::Result::<::poem::Response>::Ok(__response)
}
}
}
fn headers_tokens(headers: &[(LitStr, LitStr)]) -> TokenStream2 {
if headers.is_empty() {
return quote! {};
}
let writes = headers.iter().map(|(name, value)| {
let method = if is_multi_value_header(&name.value()) {
quote! { append }
} else {
quote! { insert }
};
quote! {
__response.headers_mut().#method(
::poem::http::HeaderName::from_static(#name),
::poem::http::HeaderValue::from_static(#value),
);
}
});
quote! { #(#writes)* }
}