use proc_macro::TokenStream;
use quote::quote;
use syn::{
parse::{Parse, ParseStream},
parse_macro_input, Expr, Ident, ItemFn, LitStr, Token,
};
/// How the request payload (if any) is handed to the route handler.
///
/// At most one body mode may be selected per `#[mechanism(...)]` attribute;
/// the argument parser rejects a second one.
enum BodyMode {
/// No body-mode keyword was given; no body-extraction call is emitted.
None,
/// Selected by the `json` keyword; emits `.json::<T>()` on the route builder.
Json,
/// Selected by the `query` keyword; emits `.query::<T>()` on the route builder.
Query,
/// Selected by `encrypted(<key>)`; the stored expression is the key argument
/// forwarded to `.encryption::<T>(<key>)`.
Encrypted(Expr),
/// Selected by `encrypted_query(<key>)`; the stored expression is the key
/// argument forwarded to `.encrypted_query::<T>(<key>)`.
EncryptedQuery(Expr),
}
/// Parsed arguments of the `#[mechanism(server, METHOD, "/path", ...)]` attribute.
struct MechanismArgs {
/// Identifier of the value on which `.mechanism(...)` is called in the expansion.
server: Ident,
/// HTTP method identifier as written by the user (lower-cased and validated
/// during expansion, not during parsing).
method: Ident,
/// Route path string literal, passed to `ServerMechanism::<method>(<path>)`.
path: LitStr,
/// Expression given via `state(<expr>)`, forwarded to `.state(<expr>)`;
/// `None` when the keyword is absent.
state: Option<Expr>,
/// Body mode chosen by `json`, `query`, `encrypted(...)` or
/// `encrypted_query(...)`; `BodyMode::None` when none was given.
body_mode: BodyMode,
}
impl Parse for MechanismArgs {
fn parse(input: ParseStream) -> syn::Result<Self> {
let server: Ident = input.parse()?;
input.parse::<Token![,]>()?;
let method: Ident = input.parse()?;
input.parse::<Token![,]>()?;
let path: LitStr = input.parse()?;
let mut state: Option<Expr> = None;
let mut body_mode = BodyMode::None;
while input.peek(Token![,]) {
input.parse::<Token![,]>()?;
if input.is_empty() {
break;
}
let kw: Ident = input.parse()?;
match kw.to_string().as_str() {
"json" => {
if !matches!(body_mode, BodyMode::None) {
return Err(syn::Error::new(
kw.span(),
"#[mechanism]: only one of `json`, `query`, \
`encrypted(…)`, or `encrypted_query(…)` may be \
specified per route",
));
}
body_mode = BodyMode::Json;
}
"query" => {
if !matches!(body_mode, BodyMode::None) {
return Err(syn::Error::new(
kw.span(),
"#[mechanism]: only one of `json`, `query`, \
`encrypted(…)`, or `encrypted_query(…)` may be \
specified per route",
));
}
body_mode = BodyMode::Query;
}
"state" => {
if state.is_some() {
return Err(syn::Error::new(
kw.span(),
"#[mechanism]: `state(…)` may only be specified once",
));
}
let content;
syn::parenthesized!(content in input);
state = Some(content.parse::<Expr>()?);
}
"encrypted" => {
if !matches!(body_mode, BodyMode::None) {
return Err(syn::Error::new(
kw.span(),
"#[mechanism]: only one of `json`, `query`, \
`encrypted(…)`, or `encrypted_query(…)` may be \
specified per route",
));
}
let content;
syn::parenthesized!(content in input);
body_mode = BodyMode::Encrypted(content.parse::<Expr>()?);
}
"encrypted_query" => {
if !matches!(body_mode, BodyMode::None) {
return Err(syn::Error::new(
kw.span(),
"#[mechanism]: only one of `json`, `query`, \
`encrypted(…)`, or `encrypted_query(…)` may be \
specified per route",
));
}
let content;
syn::parenthesized!(content in input);
body_mode = BodyMode::EncryptedQuery(content.parse::<Expr>()?);
}
other => {
return Err(syn::Error::new(
kw.span(),
format!(
"#[mechanism]: unknown keyword `{other}`. \
Valid keywords: json, query, state(<expr>), \
encrypted(<key>), encrypted_query(<key>)"
),
));
}
}
}
Ok(MechanismArgs { server, method, path, state, body_mode })
}
}
fn extract_pat_ty<'a>(
arg: &'a syn::FnArg,
position: &str,
) -> syn::Result<(&'a syn::Pat, &'a syn::Type)> {
match arg {
syn::FnArg::Typed(pt) => Ok((&pt.pat, &pt.ty)),
syn::FnArg::Receiver(r) => Err(syn::Error::new_spanned(
&r.self_token,
format!(
"#[mechanism]: unexpected `self` in the {position} parameter position"
),
)),
}
}
/// Expands `#[mechanism(server, METHOD, "/path", ...)]` applied to a function.
///
/// The annotated function is replaced by a single statement of the form
/// `server.mechanism(<route>);`, where `<route>` is a
/// `toolkit_zero::socket::server::ServerMechanism` builder chain whose final
/// `.onconnect(...)` closure carries the function's body. Only the function's
/// leading parameters (pattern and type) and its block are used in the output.
///
/// Parameter requirements by mode:
///
/// * no `state`, no body mode: no parameters are read;
/// * exactly one of `state(..)` or a body mode: the first parameter is used;
/// * `state(..)` plus a body mode: the first parameter is the state, the
///   second is the body/query value.
///
/// Any failure (invalid arguments, unsupported HTTP method, too few
/// parameters, a `self` receiver in a used position) is returned as a
/// `compile_error!` token stream rather than a panic.
pub fn expand(attr: TokenStream, item: TokenStream) -> TokenStream {
    let args = parse_macro_input!(attr as MechanismArgs);
    let func = parse_macro_input!(item as ItemFn);

    // Validate the method name *before* turning it into an identifier:
    // `Ident::new` panics on strings that are not plain identifiers (e.g. the
    // `r#get` produced by a raw identifier), which would surface as an opaque
    // "proc macro panicked" instead of the diagnostic below.
    let method_str = args.method.to_string().to_lowercase();
    if !matches!(
        method_str.as_str(),
        "get" | "post" | "put" | "delete" | "patch" | "head" | "options"
    ) {
        return syn::Error::new(
            args.method.span(),
            format!(
                "#[mechanism]: `{method_str}` is not a valid HTTP method. \
                 Expected GET, POST, PUT, DELETE, PATCH, HEAD, or OPTIONS."
            ),
        )
        .to_compile_error()
        .into();
    }
    // Safe: `method_str` is now one of the fixed lowercase names above.
    let method_ident = Ident::new(&method_str, args.method.span());

    let server = &args.server;
    let path = &args.path;
    let body = &func.block;
    let params: Vec<&syn::FnArg> = func.sig.inputs.iter().collect();

    // Early-return a compile error with a literal message at the given span.
    macro_rules! bail {
        ($span:expr, $msg:literal) => {{
            return syn::Error::new($span, $msg)
                .to_compile_error()
                .into();
        }};
    }

    // Fetch `(pattern, type)` of the parameter at `$index`, early-returning a
    // compile error if that parameter is a `self` receiver. Callers must have
    // checked `params.len()` first.
    macro_rules! param {
        ($index:expr, $position:expr) => {
            match extract_pat_ty(params[$index], $position) {
                Ok(pat_ty) => pat_ty,
                Err(err) => return err.to_compile_error().into(),
            }
        };
    }

    // Common prefix of every generated builder chain.
    let route = quote! {
        toolkit_zero::socket::server::ServerMechanism::#method_ident(#path)
    };

    let route_expr = match (&args.state, &args.body_mode) {
        (None, BodyMode::None) => {
            quote! {
                #route
                    .onconnect(|| async move #body)
            }
        }
        (None, BodyMode::Json) => {
            if params.is_empty() {
                bail!(
                    func.sig.ident.span(),
                    "#[mechanism]: `json` mode requires one function parameter — `(body: YourType)`"
                );
            }
            let (name, ty) = param!(0, "body");
            quote! {
                #route
                    .json::<#ty>()
                    .onconnect(|#name: #ty| async move #body)
            }
        }
        (None, BodyMode::Query) => {
            if params.is_empty() {
                bail!(
                    func.sig.ident.span(),
                    "#[mechanism]: `query` mode requires one function parameter — `(params: YourType)`"
                );
            }
            let (name, ty) = param!(0, "query");
            quote! {
                #route
                    .query::<#ty>()
                    .onconnect(|#name: #ty| async move #body)
            }
        }
        (None, BodyMode::Encrypted(key_expr)) => {
            if params.is_empty() {
                bail!(
                    func.sig.ident.span(),
                    "#[mechanism]: `encrypted(key)` mode requires one function parameter — `(body: YourType)`"
                );
            }
            let (name, ty) = param!(0, "body");
            quote! {
                #route
                    .encryption::<#ty>(#key_expr)
                    .onconnect(|#name: #ty| async move #body)
            }
        }
        (None, BodyMode::EncryptedQuery(key_expr)) => {
            if params.is_empty() {
                bail!(
                    func.sig.ident.span(),
                    "#[mechanism]: `encrypted_query(key)` mode requires one function parameter — `(params: YourType)`"
                );
            }
            let (name, ty) = param!(0, "params");
            quote! {
                #route
                    .encrypted_query::<#ty>(#key_expr)
                    .onconnect(|#name: #ty| async move #body)
            }
        }
        (Some(state_expr), BodyMode::None) => {
            if params.is_empty() {
                bail!(
                    func.sig.ident.span(),
                    "#[mechanism]: `state(expr)` mode requires one function parameter — `(state: YourStateType)`"
                );
            }
            let (name, ty) = param!(0, "state");
            quote! {
                #route
                    .state(#state_expr)
                    .onconnect(|#name: #ty| async move #body)
            }
        }
        (Some(state_expr), BodyMode::Json) => {
            if params.len() < 2 {
                bail!(
                    func.sig.ident.span(),
                    "#[mechanism]: `state(expr), json` mode requires two function parameters — `(state: S, body: T)`"
                );
            }
            let (s_name, s_ty) = param!(0, "state");
            let (b_name, b_ty) = param!(1, "body");
            quote! {
                #route
                    .state(#state_expr)
                    .json::<#b_ty>()
                    .onconnect(|#s_name: #s_ty, #b_name: #b_ty| async move #body)
            }
        }
        (Some(state_expr), BodyMode::Query) => {
            if params.len() < 2 {
                bail!(
                    func.sig.ident.span(),
                    "#[mechanism]: `state(expr), query` mode requires two function parameters — `(state: S, params: T)`"
                );
            }
            let (s_name, s_ty) = param!(0, "state");
            let (q_name, q_ty) = param!(1, "query");
            quote! {
                #route
                    .state(#state_expr)
                    .query::<#q_ty>()
                    .onconnect(|#s_name: #s_ty, #q_name: #q_ty| async move #body)
            }
        }
        (Some(state_expr), BodyMode::Encrypted(key_expr)) => {
            if params.len() < 2 {
                bail!(
                    func.sig.ident.span(),
                    "#[mechanism]: `state(expr), encrypted(key)` mode requires two function parameters — `(state: S, body: T)`"
                );
            }
            let (s_name, s_ty) = param!(0, "state");
            let (b_name, b_ty) = param!(1, "body");
            quote! {
                #route
                    .state(#state_expr)
                    .encryption::<#b_ty>(#key_expr)
                    .onconnect(|#s_name: #s_ty, #b_name: #b_ty| async move #body)
            }
        }
        (Some(state_expr), BodyMode::EncryptedQuery(key_expr)) => {
            if params.len() < 2 {
                bail!(
                    func.sig.ident.span(),
                    "#[mechanism]: `state(expr), encrypted_query(key)` mode requires two function parameters — `(state: S, params: T)`"
                );
            }
            let (s_name, s_ty) = param!(0, "state");
            let (q_name, q_ty) = param!(1, "query");
            quote! {
                #route
                    .state(#state_expr)
                    .encrypted_query::<#q_ty>(#key_expr)
                    .onconnect(|#s_name: #s_ty, #q_name: #q_ty| async move #body)
            }
        }
    };

    quote! {
        #server.mechanism(#route_expr);
    }
    .into()
}