use crate::app::extract_app_meta;
use crate::server_attrs::{has_server_hidden, has_server_skip, validate_server_attrs};
use proc_macro2::TokenStream as TokenStream2;
use quote::{format_ident, quote};
use server_less_parse::{MethodInfo, ParamInfo, extract_methods, get_impl_name, partition_methods};
use server_less_rpc::{self, AsyncHandling};
use syn::{ItemImpl, Token, parse::Parse};
use crate::context::{has_qualified_special_param, should_inject_special_param};
/// Returns `true` when `ty` is a path type in which a `server_less` segment is
/// immediately followed by a `WsSender` segment, e.g. `server_less::WsSender`
/// or `::server_less::WsSender`.
///
/// Only the segment identifiers are compared; any other kind of type yields
/// `false`.
fn is_qualified_ws_sender(ty: &syn::Type) -> bool {
    let syn::Type::Path(type_path) = ty else {
        return false;
    };
    let segments = &type_path.path.segments;
    // Pair each segment with its successor. Paths with fewer than two
    // segments produce no pairs and therefore never match.
    segments
        .iter()
        .zip(segments.iter().skip(1))
        .any(|(head, tail)| head.ident == "server_less" && tail.ident == "WsSender")
}
fn is_bare_ws_sender(ty: &syn::Type) -> bool {
if let syn::Type::Path(type_path) = ty
&& type_path.path.segments.len() == 1
{
return type_path.path.segments[0].ident == "WsSender";
}
false
}
/// Decides whether a parameter of type `ty` is filled with the connection's
/// `WsSender` rather than being read from the request arguments.
///
/// Delegates to `should_inject_special_param`, supplying the qualified and
/// bare `WsSender` matchers; `has_qualified` is forwarded unchanged.
fn should_inject_ws_sender(ty: &syn::Type, has_qualified: bool) -> bool {
    should_inject_special_param(
        ty,
        is_qualified_ws_sender,
        is_bare_ws_sender,
        has_qualified,
    )
}
/// Reports the result of `has_qualified_special_param` for `methods`, using
/// the `server_less::WsSender` matcher as the predicate.
///
/// The returned flag is later passed to `should_inject_ws_sender`.
fn has_qualified_ws_sender(methods: &[MethodInfo]) -> bool {
    let matcher = is_qualified_ws_sender;
    has_qualified_special_param(methods, matcher)
}
/// Splits a method's parameters into three groups: the injected `Context`
/// parameter, the injected `WsSender` parameter, and every remaining
/// parameter (in declaration order).
///
/// A parameter is tested for `Context` injection first; only if that fails is
/// it tested for `WsSender` injection.
///
/// # Errors
///
/// Returns a `syn::Error` spanned on the offending parameter type when a
/// method declares more than one `Context` or more than one `WsSender`
/// parameter.
fn partition_ws_params(
    params: &[ParamInfo],
    has_qualified_sender: bool,
) -> syn::Result<(Option<&ParamInfo>, Option<&ParamInfo>, Vec<&ParamInfo>)> {
    let mut context_param: Option<&ParamInfo> = None;
    let mut sender_param: Option<&ParamInfo> = None;
    let mut other_params = Vec::new();
    for param in params {
        if crate::context::should_inject_context(&param.ty, params) {
            if context_param.is_some() {
                return Err(syn::Error::new_spanned(
                    &param.ty,
                    "only one Context parameter allowed per method\n\
                     \n\
                     Hint: server_less::Context is automatically injected from request metadata.\n\
                     Remove the duplicate Context parameter.",
                ));
            }
            context_param = Some(param);
        } else if should_inject_ws_sender(&param.ty, has_qualified_sender) {
            if sender_param.is_some() {
                return Err(syn::Error::new_spanned(
                    &param.ty,
                    "only one WsSender parameter allowed per method\n\
                     \n\
                     Hint: server_less::WsSender is automatically injected for each WebSocket connection.\n\
                     Remove the duplicate WsSender parameter.",
                ));
            }
            sender_param = Some(param);
        } else {
            other_params.push(param);
        }
    }
    Ok((context_param, sender_param, other_params))
}
/// Builds the `(parameter index, expression)` injections used when a method is
/// dispatched through the mount path.
///
/// Every `Context` parameter is paired with a freshly constructed
/// `::server_less::Context::new()`. If any parameter would require a
/// `WsSender`, the whole result is `None`, because no sender expression is
/// produced here.
fn build_mount_injections(
    params: &[ParamInfo],
    has_qualified_sender: bool,
) -> Option<Vec<(usize, proc_macro2::TokenStream)>> {
    params
        .iter()
        .enumerate()
        .filter_map(|(index, param)| {
            if crate::context::should_inject_context(&param.ty, params) {
                Some(Some((index, quote! { ::server_less::Context::new() })))
            } else if should_inject_ws_sender(&param.ty, has_qualified_sender) {
                // An inner `None` makes the `Option<Vec<_>>` collect stop and
                // yield `None` for the whole method.
                Some(None)
            } else {
                // Ordinary parameters need no injection and are skipped.
                None
            }
        })
        .collect()
}
/// Arguments accepted by the WebSocket attribute macro.
#[derive(Default)]
pub(crate) struct WsArgs {
/// Route for the WebSocket endpoint, taken from `path = "..."`.
/// `expand_ws` falls back to `"/ws"` when this is `None`.
pub path: Option<String>,
}
impl Parse for WsArgs {
    /// Parses a sequence of `key = value` pairs, each optionally followed by a
    /// comma. The only recognised key is `path`, whose value must be a string
    /// literal.
    ///
    /// # Errors
    ///
    /// Fails when a key is not an identifier, `=` is missing, the `path`
    /// value is not a string literal, or the key is unknown (the message
    /// then includes a "did you mean" suggestion when one is available).
    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
        let mut parsed = WsArgs::default();
        while !input.is_empty() {
            let key_ident: syn::Ident = input.parse()?;
            input.parse::<Token![=]>()?;
            let key = key_ident.to_string();
            if key == "path" {
                let route = input.parse::<syn::LitStr>()?;
                parsed.path = Some(route.value());
            } else {
                const VALID: &[&str] = &["path"];
                let suggestion = crate::did_you_mean(key.as_str(), VALID)
                    .map_or_else(String::new, |s| format!(" — did you mean `{s}`?"));
                return Err(syn::Error::new(
                    key_ident.span(),
                    format!("unknown argument `{key}`{suggestion}. Valid arguments: path"),
                ));
            }
            // Consume a separating comma if one is present.
            input.parse::<Option<Token![,]>>()?;
        }
        Ok(parsed)
    }
}
/// Expands the WebSocket attribute applied to an `impl` block into generated code.
///
/// The emitted tokens contain, in order:
/// - the original impl block, only when `crate::is_protocol_impl_emitter(&impl_block, "ws")` is true;
/// - an `impl ::server_less::WsMount` that forwards to the `ws_mount_dispatch*_inner` helpers;
/// - an inherent impl with `ws_methods`, `ws_handle_message`, `ws_handle_message_async`,
///   `__format_ws_response`, `ws_dispatch`, `ws_dispatch_async`, the mount dispatch helpers,
///   `ws_router` and `ws_openapi_paths`;
/// - two free async functions: the upgrade handler and the per-connection receive loop.
///
/// `args.path` selects the route and defaults to `"/ws"`.
///
/// # Errors
///
/// Propagates errors from `reject_generic_impl`, `get_impl_name`, `extract_methods`,
/// `validate_server_attrs`, and from generating leaf or mount dispatch code.
pub(crate) fn expand_ws(args: WsArgs, mut impl_block: ItemImpl) -> syn::Result<TokenStream2> {
// Bail out early if `reject_generic_impl` refuses this impl block.
crate::reject_generic_impl(&impl_block)?;
// The returned metadata is unused here; the call receives `&mut impl_block.attrs`.
// NOTE(review): presumably kept so app-level attributes are stripped — confirm against `extract_app_meta`.
let _app_meta = extract_app_meta(&mut impl_block.attrs);
let struct_name = get_impl_name(&impl_block)?;
let (impl_generics, _ty_generics, where_clause) = impl_block.generics.split_for_impl();
let self_ty = &impl_block.self_ty;
let methods = extract_methods(&impl_block)?;
let has_qualified_sender = has_qualified_ws_sender(&methods);
// Route of the WebSocket endpoint; "/ws" when no `path = "..."` was given.
let path = args.path.unwrap_or_else(|| "/ws".to_string());
for m in &methods {
validate_server_attrs(m)?;
}
// Split the methods into leaf methods, static mounts and slug mounts,
// using `has_server_skip` as the predicate handed to `partition_methods`.
let partitioned = partition_methods(&methods, has_server_skip);
// Leaf methods that are not hidden. Only these are advertised (method list, docs, OpenAPI);
// the dispatch arms below are built from *all* leaf methods, so hidden ones stay callable.
let visible_leaf: Vec<_> = partitioned
.leaf
.iter()
.copied()
.filter(|m| !has_server_hidden(m))
.collect();
// One match arm per leaf method for the sync dispatcher, each prefixed with
// the method's own `cfg` attributes.
let dispatch_arms_sync: Vec<_> = partitioned
.leaf
.iter()
.map(|m| {
let arm = generate_dispatch_arm_sync(m, has_qualified_sender)?;
let cfg_attrs = &m.cfg_attrs;
Ok(quote! {
#(#cfg_attrs)*
#arm
})
})
.collect::<syn::Result<Vec<_>>>()?;
// Same as above, for the async dispatcher.
let dispatch_arms_async: Vec<_> = partitioned
.leaf
.iter()
.map(|m| {
let arm = generate_dispatch_arm_async(m, has_qualified_sender)?;
let cfg_attrs = &m.cfg_attrs;
Ok(quote! {
#(#cfg_attrs)*
#arm
})
})
.collect::<syn::Result<Vec<_>>>()?;
// Wire names of the visible leaf methods.
let method_names: Vec<_> = visible_leaf
.iter()
.map(|m| m.wire_name_or(|n| n))
.collect();
// Markdown bullet per visible method, with its doc text when it has one.
let ws_method_doc_entries: Vec<String> = visible_leaf
.iter()
.map(|m| {
let name = m.wire_name_or(|n| n);
match &m.docs {
Some(doc) => format!("- `{name}` — {doc}"),
None => format!("- `{name}`"),
}
})
.collect();
let has_ws_mounts =
!partitioned.static_mounts.is_empty() || !partitioned.slug_mounts.is_empty();
// Doc string attached to the generated `ws_methods`; the `# Methods` section is
// omitted only when there are neither visible methods nor mounts.
let ws_methods_doc = if ws_method_doc_entries.is_empty() && !has_ws_mounts {
"Get available WebSocket JSON-RPC method names.".to_string()
} else {
let mount_note = if has_ws_mounts {
"\n\nAlso includes methods from mounted sub-services."
} else {
""
};
format!(
"Get available WebSocket JSON-RPC method names.\n\n# Methods\n\n{}{}",
ws_method_doc_entries.join("\n"),
mount_note
)
};
// Doc string attached to the generated `ws_router`; the count covers visible leaf methods only.
let ws_router_doc = format!(
"Create an axum Router with WebSocket endpoint at `{}`.\n\n\
Exposes {} method{}.",
path,
method_names.len(),
if method_names.len() == 1 { "" } else { "s" }
);
// Prefix-matching arms that forward to mounted sub-services: static mounts first,
// then slug mounts. Built once for the sync and once for the async dispatcher.
let mount_dispatch_sync: Vec<_> = partitioned
.static_mounts
.iter()
.map(|m| generate_ws_static_mount_dispatch(m, false))
.chain(
partitioned
.slug_mounts
.iter()
.map(|m| generate_ws_slug_mount_dispatch(m, false)),
)
.collect::<syn::Result<Vec<_>>>()?;
let mount_dispatch_async: Vec<_> = partitioned
.static_mounts
.iter()
.map(|m| generate_ws_static_mount_dispatch(m, true))
.chain(
partitioned
.slug_mounts
.iter()
.map(|m| generate_ws_slug_mount_dispatch(m, true)),
)
.collect::<syn::Result<Vec<_>>>()?;
// Statements that append the prefixed child method names to `names` inside `ws_methods`.
let mount_method_names: Vec<_> = partitioned
.static_mounts
.iter()
.chain(partitioned.slug_mounts.iter())
.map(|m| generate_ws_mount_method_names(m))
.collect::<syn::Result<Vec<_>>>()?;
// True when at least one leaf method takes an injected Context or WsSender.
// A partition error counts as `false` here; the same error has already been
// propagated by `?` while building the dispatch arms above.
let uses_injected_params = partitioned.leaf.iter().any(|m| {
partition_ws_params(&m.params, has_qualified_sender)
.map(|(ctx, sender, _)| ctx.is_some() || sender.is_some())
.unwrap_or(false)
});
// Signatures of the generated dispatchers and the matching call expressions.
// With injection, `__ctx` and `__sender` are threaded through as extra parameters.
// The dispatchers name their payload `args`; the callers pass their local `params`.
let (dispatch_sig_sync, dispatch_sig_async, dispatch_call_sync, dispatch_call_async) =
if uses_injected_params {
(
quote! {
fn ws_dispatch(
&self,
__ctx: ::server_less::Context,
__sender: ::server_less::WsSender,
method: &str,
args: ::server_less::serde_json::Value,
) -> ::std::result::Result<::server_less::serde_json::Value, String>
},
quote! {
async fn ws_dispatch_async(
&self,
__ctx: ::server_less::Context,
__sender: ::server_less::WsSender,
method: &str,
args: ::server_less::serde_json::Value,
) -> ::std::result::Result<::server_less::serde_json::Value, String>
},
quote! { self.ws_dispatch(__ctx, __sender, method, params) },
quote! { self.ws_dispatch_async(__ctx, __sender, method, params).await },
)
} else {
(
quote! {
fn ws_dispatch(
&self,
method: &str,
args: ::server_less::serde_json::Value,
) -> ::std::result::Result<::server_less::serde_json::Value, String>
},
quote! {
async fn ws_dispatch_async(
&self,
method: &str,
args: ::server_less::serde_json::Value,
) -> ::std::result::Result<::server_less::serde_json::Value, String>
},
quote! { self.ws_dispatch(method, params) },
quote! { self.ws_dispatch_async(method, params).await },
)
};
// Call made by the connection loop for each text frame; context and sender are
// cloned per message when they are injected.
let message_handler_call = if uses_injected_params {
quote! { state.ws_handle_message_async(__ctx.clone(), __sender.clone(), &text).await }
} else {
quote! { state.ws_handle_message_async(&text).await }
};
// Lower-cased struct name (no underscores are inserted), used to derive the
// names of the two generated free functions.
let struct_name_snake = struct_name.to_string().to_lowercase();
let handler_name = format_ident!("__server_less_ws_handler_{}", struct_name_snake);
let connection_fn_name = format_ident!("__server_less_ws_connection_{}", struct_name_snake);
// Signatures of the public message handlers, with or without the injected parameters.
let (handle_sig_sync, handle_sig_async) = if uses_injected_params {
(
quote! {
pub fn ws_handle_message(
&self,
__ctx: ::server_less::Context,
__sender: ::server_less::WsSender,
message: &str,
) -> ::std::result::Result<String, String>
},
quote! {
pub async fn ws_handle_message_async(
&self,
__ctx: ::server_less::Context,
__sender: ::server_less::WsSender,
message: &str,
) -> ::std::result::Result<String, String>
},
)
} else {
(
quote! {
pub fn ws_handle_message(
&self,
message: &str,
) -> ::std::result::Result<String, String>
},
quote! {
pub async fn ws_handle_message_async(
&self,
message: &str,
) -> ::std::result::Result<String, String>
},
)
};
// Without injection the handlers receive no `__ctx`, so a fresh local one is
// declared in their bodies; with injection `__ctx` is already a parameter.
let ctx_creation = if uses_injected_params {
quote! {}
} else {
quote! { let __ctx = ::server_less::Context::new(); }
};
// Arms for the `WsMount` dispatch path. Methods for which `build_mount_injections`
// returns `None` (those needing a WsSender) are left out; Context parameters get a
// fresh `Context::new()`.
let mount_trait_dispatch_sync: Vec<_> = partitioned
.leaf
.iter()
.filter_map(|m| {
let injections =
build_mount_injections(&m.params, has_qualified_sender)?;
let arm = server_less_rpc::generate_dispatch_arm_with_injections(
m,
None,
AsyncHandling::Error,
&injections,
);
let cfg_attrs = &m.cfg_attrs;
Some(quote! {
#(#cfg_attrs)*
#arm
})
})
.collect();
let mount_trait_dispatch_async: Vec<_> = partitioned
.leaf
.iter()
.filter_map(|m| {
let injections =
build_mount_injections(&m.params, has_qualified_sender)?;
let arm = server_less_rpc::generate_dispatch_arm_with_injections(
m,
None,
AsyncHandling::Await,
&injections,
);
let cfg_attrs = &m.cfg_attrs;
Some(quote! {
#(#cfg_attrs)*
#arm
})
})
.collect();
// Code run in the upgrade handler to build the connection's Context. With injection,
// every header whose value passes `to_str()` is copied in and `x-request-id` is also
// recorded via `set_request_id`; otherwise an empty Context is created.
let ctx_init_code = if uses_injected_params {
quote! {
let mut __ctx = ::server_less::Context::new();
for (name, value) in __context_headers.iter() {
if let Ok(value_str) = value.to_str() {
__ctx.set(name.as_str(), value_str);
}
}
if let Some(request_id) = __context_headers.get("x-request-id")
.and_then(|v| v.to_str().ok())
{
__ctx.set_request_id(request_id);
}
}
} else {
quote! {
let __ctx = ::server_less::Context::new();
}
};
// Re-emit the user's impl block only when `is_protocol_impl_emitter` says this
// macro is the one responsible for it.
let maybe_impl = if crate::is_protocol_impl_emitter(&impl_block, "ws") {
quote! { #impl_block }
} else {
quote! {}
};
// Everything below is the generated output.
Ok(quote! {
#maybe_impl
// `WsMount` impl: forwards to the inner mount dispatch helpers defined further down.
impl #impl_generics ::server_less::WsMount for #self_ty #where_clause {
fn ws_mount_methods() -> Vec<String> {
Self::ws_methods()
}
fn ws_mount_dispatch(
&self,
method: &str,
params: ::server_less::serde_json::Value,
) -> ::std::result::Result<::server_less::serde_json::Value, String> {
self.ws_mount_dispatch_inner(method, params)
}
async fn ws_mount_dispatch_async(
&self,
method: &str,
params: ::server_less::serde_json::Value,
) -> ::std::result::Result<::server_less::serde_json::Value, String> {
self.ws_mount_dispatch_async_inner(method, params).await
}
}
impl #impl_generics #self_ty #where_clause {
#[doc = #ws_methods_doc]
pub fn ws_methods() -> Vec<String> {
let mut names: Vec<String> = vec![#(#method_names.to_string()),*];
#(#mount_method_names)*
names
}
// Sync handler: parse the message, read `method` / `params` / `id`, dispatch, format.
// `params` defaults to an empty JSON object when absent.
#handle_sig_sync {
#ctx_creation
let parsed: ::server_less::serde_json::Value = ::server_less::serde_json::from_str(message)
.map_err(|e| format!("Invalid JSON: {}", e))?;
let method = parsed.get("method")
.and_then(|v| v.as_str())
.ok_or_else(|| "Missing 'method' field".to_string())?;
let params = parsed.get("params")
.cloned()
.unwrap_or(::server_less::serde_json::json!({}));
let id = parsed.get("id").cloned();
let result = #dispatch_call_sync;
Self::__format_ws_response(result, id)
}
// Async handler: same steps as the sync one, awaiting the async dispatcher.
#handle_sig_async {
#ctx_creation
let parsed: ::server_less::serde_json::Value = ::server_less::serde_json::from_str(message)
.map_err(|e| format!("Invalid JSON: {}", e))?;
let method = parsed.get("method")
.and_then(|v| v.as_str())
.ok_or_else(|| "Missing 'method' field".to_string())?;
let params = parsed.get("params")
.cloned()
.unwrap_or(::server_less::serde_json::json!({}));
let id = parsed.get("id").cloned();
let result = #dispatch_call_async;
Self::__format_ws_response(result, id)
}
// Wraps a dispatch result as `{"result": ...}` or `{"error": {"message": ...}}`,
// echoing the request `id` when one was supplied, and serializes it to a string.
fn __format_ws_response(
result: ::std::result::Result<::server_less::serde_json::Value, String>,
id: Option<::server_less::serde_json::Value>,
) -> ::std::result::Result<String, String> {
let response = match result {
Ok(value) => {
let mut resp = ::server_less::serde_json::json!({
"result": value
});
if let Some(id) = id {
resp.as_object_mut()
.expect("BUG: json!({}) must produce an Object")
.insert("id".to_string(), id);
}
resp
}
Err(err) => {
let mut resp = ::server_less::serde_json::json!({
"error": {
"message": err
}
});
if let Some(id) = id {
resp.as_object_mut()
.expect("BUG: json!({}) must produce an Object")
.insert("id".to_string(), id);
}
resp
}
};
::server_less::serde_json::to_string(&response)
.map_err(|e| format!("Serialization error: {}", e))
}
// Dispatchers: leaf arms are tried first, then mount prefixes, then the fallback error.
#dispatch_sig_sync {
match method {
#(#dispatch_arms_sync)*
#(#mount_dispatch_sync)*
_ => Err(format!("Unknown method: {}", method)),
}
}
#dispatch_sig_async {
match method {
#(#dispatch_arms_async)*
#(#mount_dispatch_async)*
_ => Err(format!("Unknown method: {}", method)),
}
}
// Dispatchers used when this type is mounted under a parent; they take no
// context or sender parameters.
fn ws_mount_dispatch_inner(
&self,
method: &str,
args: ::server_less::serde_json::Value,
) -> ::std::result::Result<::server_less::serde_json::Value, String> {
match method {
#(#mount_trait_dispatch_sync)*
#(#mount_dispatch_sync)*
_ => Err(format!("Unknown method: {}", method)),
}
}
async fn ws_mount_dispatch_async_inner(
&self,
method: &str,
args: ::server_less::serde_json::Value,
) -> ::std::result::Result<::server_less::serde_json::Value, String> {
match method {
#(#mount_trait_dispatch_async)*
#(#mount_dispatch_async)*
_ => Err(format!("Unknown method: {}", method)),
}
}
#[doc = #ws_router_doc]
pub fn ws_router(self) -> ::server_less::axum::Router
where
Self: Clone + Send + Sync + 'static,
{
let state = ::std::sync::Arc::new(self);
::server_less::axum::Router::new()
.route(#path, ::server_less::axum::routing::get(#handler_name))
.with_state(state)
}
// Describes the endpoint as a single GET operation that lists the visible
// leaf method names; the request/response examples are fixed literals.
pub fn ws_openapi_paths() -> ::std::vec::Vec<::server_less::OpenApiPath> {
let methods: Vec<&str> = vec![#(#method_names),*];
let methods_desc = methods.join(", ");
vec![
::server_less::OpenApiPath {
path: #path.to_string(),
method: "get".to_string(),
operation: ::server_less::OpenApiOperation {
summary: Some(format!("WebSocket endpoint (methods: {})", methods_desc)),
description: None,
operation_id: Some("websocket".to_string()),
tags: vec!["websocket".to_string()],
deprecated: false,
parameters: vec![],
request_body: None,
responses: {
let mut r = ::server_less::serde_json::Map::new();
r.insert("101".to_string(), ::server_less::serde_json::json!({
"description": "Switching Protocols - WebSocket upgrade successful"
}));
r
},
extra: {
let mut e = ::server_less::serde_json::Map::new();
e.insert("x-websocket-protocol".to_string(), ::server_less::serde_json::json!({
"format": "JSON-RPC style",
"methods": methods,
"request_example": {
"method": "echo",
"params": {"message": "hello"},
"id": 1
},
"response_example": {
"result": "Echo: hello",
"id": 1
}
}));
e
},
},
}
]
}
}
// Upgrade handler: builds the connection Context, then hands the socket to the
// connection function once the upgrade completes.
async fn #handler_name(
ws: ::server_less::axum::extract::WebSocketUpgrade,
state_extractor: ::server_less::axum::extract::State<::std::sync::Arc<#self_ty>>,
__context_headers: ::server_less::axum::http::HeaderMap,
) -> impl ::server_less::axum::response::IntoResponse {
let state = state_extractor.0;
#ctx_init_code
ws.on_upgrade(move |socket| async move {
#connection_fn_name(socket, state, __ctx).await
})
}
// Per-connection loop: each text frame is handled and answered through `__sender`.
async fn #connection_fn_name(
socket: ::server_less::axum::extract::ws::WebSocket,
state: ::std::sync::Arc<#self_ty>,
__ctx: ::server_less::Context,
) {
use ::futures::stream::StreamExt;
use ::futures::sink::SinkExt;
let (sender, mut receiver) = socket.split();
let __sender = ::server_less::WsSender::new(sender);
while let Some(msg) = receiver.next().await {
match msg {
Ok(::server_less::axum::extract::ws::Message::Text(text)) => {
let response = #message_handler_call;
// A handler-level failure (e.g. invalid JSON) is still reported to
// the client as an error object.
let reply = match response {
Ok(json) => json,
Err(err) => ::server_less::serde_json::json!({
"error": {"message": err}
}).to_string(),
};
if __sender.send(reply).await.is_err() {
break;
}
}
Ok(::server_less::axum::extract::ws::Message::Close(_)) => break,
// Other message kinds are ignored; a receive error ends the connection.
Ok(_) => {} Err(_) => break,
}
}
}
})
}
fn generate_dispatch_arm_sync(
method: &MethodInfo,
has_qualified_sender: bool,
) -> syn::Result<TokenStream2> {
generate_dispatch_arm_with_injected_params(
method,
has_qualified_sender,
AsyncHandling::Error,
)
}
/// Generates the match arm for `method` in the asynchronous dispatcher.
///
/// Forwards to `generate_dispatch_arm_with_injected_params` with
/// `AsyncHandling::Await`.
fn generate_dispatch_arm_async(
    method: &MethodInfo,
    has_qualified_sender: bool,
) -> syn::Result<TokenStream2> {
    let handling = AsyncHandling::Await;
    generate_dispatch_arm_with_injected_params(method, has_qualified_sender, handling)
}
/// Generates one dispatcher match arm for `method`, supplying `__ctx.clone()`
/// and `__sender.clone()` for parameters that are injected rather than read
/// from the request arguments.
///
/// - If the method has neither a `Context` nor a `WsSender` parameter, the
///   arm comes straight from `server_less_rpc::generate_dispatch_arm`.
/// - If the method is async or returns a stream and `async_handling` is
///   `AsyncHandling::Error`, the arm still extracts the regular parameters
///   but then returns an error instead of calling the method.
/// - Otherwise the arm extracts the regular parameters, calls the method with
///   arguments in declaration order, and produces the JSON response.
///
/// # Errors
///
/// Propagates the error from `partition_ws_params` when the method declares
/// duplicate `Context` or `WsSender` parameters.
fn generate_dispatch_arm_with_injected_params(
    method: &MethodInfo,
    has_qualified_sender: bool,
    async_handling: AsyncHandling,
) -> syn::Result<TokenStream2> {
    let method_name_str = method.wire_name_or(|n| n);
    let (context_param, sender_param, regular_params) =
        partition_ws_params(&method.params, has_qualified_sender)?;
    if context_param.is_none() && sender_param.is_none() {
        return Ok(server_less_rpc::generate_dispatch_arm(
            method,
            None,
            async_handling,
        ));
    }
    // Both remaining code paths extract the non-injected parameters first.
    let param_extractions = server_less_rpc::generate_param_extractions_for(&regular_params);
    let requires_async = method.is_async || method.return_info.is_stream;
    if requires_async && matches!(async_handling, AsyncHandling::Error) {
        return Ok(quote! {
            #method_name_str => {
                #(#param_extractions)*
                return Err("Async methods and streaming methods not supported in sync context".to_string());
            }
        });
    }
    // One argument expression per declared parameter, in declaration order.
    // The Context check takes precedence over the WsSender check, mirroring
    // `partition_ws_params`.
    let arg_exprs: Vec<TokenStream2> = method
        .params
        .iter()
        .map(|param| {
            if crate::context::should_inject_context(&param.ty, &method.params) {
                quote! { __ctx.clone() }
            } else if should_inject_ws_sender(&param.ty, has_qualified_sender) {
                quote! { __sender.clone() }
            } else {
                let name = &param.name;
                quote! { #name }
            }
        })
        .collect();
    let call = server_less_rpc::generate_method_call_with_args(method, arg_exprs, async_handling);
    let response = server_less_rpc::generate_json_response(method);
    Ok(quote! {
        #method_name_str => {
            #(#param_extractions)*
            #call
            #response
        }
    })
}
fn generate_ws_mount_method_names(method: &MethodInfo) -> syn::Result<TokenStream2> {
let mount_name = method.wire_name_or(|n| n);
let mount_prefix = format!("{}.", mount_name);
let inner_ty = method.return_info.reference_inner.as_ref().ok_or_else(|| {
syn::Error::new_spanned(
&method.method.sig,
"BUG: mount method must have a reference return type (&T)",
)
})?;
Ok(quote! {
{
let child_methods = <#inner_ty as ::server_less::WsMount>::ws_mount_methods();
for child_name in child_methods {
let prefixed = format!("{}{}", #mount_prefix, child_name);
names.push(prefixed);
}
}
})
}
fn generate_ws_static_mount_dispatch(method: &MethodInfo, is_async: bool) -> syn::Result<TokenStream2> {
let mount_name = method.wire_name_or(|n| n);
let mount_prefix = format!("{}.", mount_name);
let method_name = &method.name;
let inner_ty = method.return_info.reference_inner.as_ref().ok_or_else(|| {
syn::Error::new_spanned(
&method.method.sig,
"BUG: mount method must have a reference return type (&T)",
)
})?;
Ok(if is_async {
quote! {
__method if __method.starts_with(#mount_prefix) => {
let __stripped = &__method[#mount_prefix.len()..];
let __delegate = self.#method_name();
<#inner_ty as ::server_less::WsMount>::ws_mount_dispatch_async(__delegate, __stripped, args).await
}
}
} else {
quote! {
__method if __method.starts_with(#mount_prefix) => {
let __stripped = &__method[#mount_prefix.len()..];
let __delegate = self.#method_name();
<#inner_ty as ::server_less::WsMount>::ws_mount_dispatch(__delegate, __stripped, args)
}
}
})
}
fn generate_ws_slug_mount_dispatch(method: &MethodInfo, is_async: bool) -> syn::Result<TokenStream2> {
let mount_name = method.wire_name_or(|n| n);
let mount_prefix = format!("{}.", mount_name);
let method_name = &method.name;
let inner_ty = method.return_info.reference_inner.as_ref().ok_or_else(|| {
syn::Error::new_spanned(
&method.method.sig,
"BUG: mount method must have a reference return type (&T)",
)
})?;
let slug_extractions: Vec<_> = method
.params
.iter()
.map(server_less_rpc::generate_param_extraction)
.collect();
let slug_names: Vec<_> = method.params.iter().map(|p| &p.name).collect();
Ok(if is_async {
quote! {
__method if __method.starts_with(#mount_prefix) => {
let __stripped = &__method[#mount_prefix.len()..];
#(#slug_extractions)*
let __delegate = self.#method_name(#(#slug_names),*);
<#inner_ty as ::server_less::WsMount>::ws_mount_dispatch_async(__delegate, __stripped, args).await
}
}
} else {
quote! {
__method if __method.starts_with(#mount_prefix) => {
let __stripped = &__method[#mount_prefix.len()..];
#(#slug_extractions)*
let __delegate = self.#method_name(#(#slug_names),*);
<#inner_ty as ::server_less::WsMount>::ws_mount_dispatch(__delegate, __stripped, args)
}
}
})
}