#![forbid(unsafe_code)]
#![deny(missing_docs)]
#[doc(hidden)]
pub use const_format;
use proc_macro2::TokenStream;
use quote::TokenStreamExt;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
pub use server_fn_macro_default::server;
use std::{future::Future, pin::Pin, str::FromStr, sync::Arc};
use syn::parse_quote;
use thiserror::Error;
#[doc(hidden)]
pub use xxhash_rust;
/// A registry of server functions, looked up by URL, for functions that take a
/// context argument of type `T`.
///
/// Every item is an associated function (there is no `self` receiver), so the
/// implementing type decides where the registered functions are stored.
pub trait ServerFunctionRegistry<T> {
/// The error returned when a server function cannot be registered.
type Error: std::error::Error;
/// Registers `server_function` under `url`, together with the `encoding`
/// it uses.
///
/// # Errors
/// Returns `Self::Error` when the implementation rejects the registration.
fn register(
url: &'static str,
server_function: Arc<ServerFnTraitObj<T>>,
encoding: Encoding,
) -> Result<(), Self::Error>;
/// Returns the server function (trait object plus encoding) for `url`, or
/// `None` if there is none.
fn get(url: &str) -> Option<ServerFunction<T>>;
/// Returns only the callable trait object for `url`, or `None` if there is none.
fn get_trait_obj(url: &str) -> Option<Arc<ServerFnTraitObj<T>>>;
/// Returns only the encoding for `url`, or `None` if there is none.
fn get_encoding(url: &str) -> Option<Encoding>;
/// Returns the paths known to the registry.
fn paths_registered() -> Vec<&'static str>;
}
/// A server function as handed out by a [`ServerFunctionRegistry`]: the
/// type-erased callable paired with the encoding it uses.
#[derive(Clone)]
pub struct ServerFunction<T> {
/// The type-erased, shareable callable that runs the server function.
pub trait_obj: Arc<ServerFnTraitObj<T>>,
/// The encoding associated with this server function.
pub encoding: Encoding,
}
/// The type-erased form of a server function.
///
/// It is a `Send + Sync` callable that receives a context value of type `T`
/// and the encoded argument bytes, and returns a boxed future resolving to
/// either a [`Payload`] or a [`ServerFnError`].
pub type ServerFnTraitObj<T> = dyn Fn(
T,
&[u8],
) -> Pin<Box<dyn Future<Output = Result<Payload, ServerFnError>>>>
+ Send
+ Sync;
/// The encoded result of running a server function.
#[derive(Debug)]
pub enum Payload {
/// Raw bytes. `ServerFn::register_in` produces this for CBOR-encoded results.
Binary(Vec<u8>),
/// A string body. `ServerFn::register_in` produces this (holding the
/// JSON-serialized result) for the `Url` and `GetJSON` encodings.
Url(String),
/// A JSON string body.
// NOTE(review): nothing in this file constructs this variant — confirm
// which integrations, if any, rely on it.
Json(String),
}
/// Looks up the server function registered in `R` under `path`.
///
/// Returns `None` when the registry has no entry for that path. This is a
/// direct forward to [`ServerFunctionRegistry::get`].
#[cfg(any(feature = "ssr", doc))]
pub fn server_fn_by_path<T: 'static, R: ServerFunctionRegistry<T>>(
    path: &str,
) -> Option<ServerFunction<T>> {
    <R as ServerFunctionRegistry<T>>::get(path)
}
/// Looks up only the callable trait object registered in `R` under `path`.
///
/// Returns `None` when the registry has no entry for that path. This is a
/// direct forward to [`ServerFunctionRegistry::get_trait_obj`].
#[cfg(any(feature = "ssr", doc))]
pub fn server_fn_trait_obj_by_path<T: 'static, R: ServerFunctionRegistry<T>>(
    path: &str,
) -> Option<Arc<ServerFnTraitObj<T>>> {
    <R as ServerFunctionRegistry<T>>::get_trait_obj(path)
}
/// Looks up the encoding of the server function registered in `R` under `path`.
///
/// Returns `None` when the registry has no entry for that path. This is a
/// direct forward to [`ServerFunctionRegistry::get_encoding`].
#[cfg(any(feature = "ssr", doc))]
pub fn server_fn_encoding_by_path<T: 'static, R: ServerFunctionRegistry<T>>(
    path: &str,
) -> Option<Encoding> {
    <R as ServerFunctionRegistry<T>>::get_encoding(path)
}
/// Lists every path currently registered in `R`.
///
/// This is a direct forward to [`ServerFunctionRegistry::paths_registered`].
#[cfg(any(feature = "ssr", doc))]
pub fn server_fns_by_path<T: 'static, R: ServerFunctionRegistry<T>>() -> Vec<&'static str>
{
    <R as ServerFunctionRegistry<T>>::paths_registered()
}
/// How a server function's arguments are sent to the server and how its
/// result is sent back.
///
/// All variants except [`Encoding::Cbor`] send the arguments as a
/// URL-encoded string; they differ in HTTP method and in result format.
#[derive(Debug, Clone, Default, PartialEq)]
pub enum Encoding {
    /// Arguments are CBOR-encoded and sent in the body of a `POST` request;
    /// the result is CBOR-encoded.
    Cbor,
    /// Arguments are URL-encoded and sent in the body of a `POST` request;
    /// the result is serialized as JSON. This is the default.
    #[default]
    Url,
    /// Arguments are URL-encoded into the query string of a `GET` request;
    /// the result is serialized as JSON.
    GetJSON,
    /// Arguments are URL-encoded into the query string of a `GET` request;
    /// the result is CBOR-encoded.
    GetCBOR,
}

impl FromStr for Encoding {
    type Err = ();

    /// Parses an encoding name, ignoring ASCII case.
    ///
    /// Accepted names are `Url`, `Cbor`, `GetCbor` and `GetJson` in any
    /// capitalisation, so both the historical spellings (`"URL"`,
    /// `"GetCbor"`, `"GetJson"`) and the variant names themselves (`"Url"`,
    /// `"GetCBOR"`, `"GetJSON"`) are recognised.
    ///
    /// # Errors
    /// Returns `Err(())` for any other input.
    fn from_str(input: &str) -> Result<Encoding, Self::Err> {
        const NAMES: [(&str, Encoding); 4] = [
            ("Url", Encoding::Url),
            ("Cbor", Encoding::Cbor),
            ("GetCbor", Encoding::GetCBOR),
            ("GetJson", Encoding::GetJSON),
        ];
        NAMES
            .iter()
            .find(|(name, _)| input.eq_ignore_ascii_case(name))
            .map(|(_, encoding)| encoding.clone())
            .ok_or(())
    }
}
impl quote::ToTokens for Encoding {
    /// Appends the path `Encoding::<Variant>` for this value to `tokens`, so
    /// an `Encoding` can be interpolated into generated code.
    fn to_tokens(&self, tokens: &mut TokenStream) {
        let variant: syn::Ident = match *self {
            Encoding::Cbor => parse_quote!(Cbor),
            Encoding::Url => parse_quote!(Url),
            Encoding::GetJSON => parse_quote!(GetJSON),
            Encoding::GetCBOR => parse_quote!(GetCBOR),
        };
        // `Encoding::Variant` is a path, not a single identifier. Parsing it
        // as `syn::Ident` (as was done previously) makes `parse_quote!` panic
        // on the trailing `::Variant`, so the tokens are emitted directly.
        tokens.append_all(quote::quote!(Encoding::#variant));
    }
}
/// A server function whose arguments are carried by the implementing type.
///
/// The implementor must be serializable and deserializable because its value
/// is what gets encoded and sent between client and server. `T` is the type
/// of the context value passed to the function when it runs.
pub trait ServerFn<T: 'static>
where
    Self: Serialize + DeserializeOwned + Sized + 'static,
{
    /// The value produced by a successful call.
    type Output: Serialize;

    /// The URL prefix for this server function.
    // NOTE(review): not used anywhere in this file — presumably combined with
    // `url()` by the macro or server integrations; confirm.
    fn prefix() -> &'static str;

    /// The URL this server function is registered under (see `register_in`).
    fn url() -> &'static str;

    /// The encoding used to decode the arguments and encode the result.
    fn encoding() -> Encoding;

    /// Runs the server function with the given context value.
    ///
    /// Only available with the `ssr` feature (or when building docs).
    #[cfg(any(feature = "ssr", doc))]
    fn call_fn(
        self,
        cx: T,
    ) -> Pin<Box<dyn Future<Output = Result<Self::Output, ServerFnError>>>>;

    /// Client-side counterpart of `call_fn`.
    ///
    /// Only available without the `ssr` feature (or when building docs).
    // NOTE(review): presumably performs the HTTP call to the server — confirm
    // against the macro-generated implementation.
    #[cfg(any(not(feature = "ssr"), doc))]
    fn call_fn_client(
        self,
        cx: T,
    ) -> Pin<Box<dyn Future<Output = Result<Self::Output, ServerFnError>>>>;

    /// Registers this server function in the registry `R` under `Self::url()`.
    ///
    /// The registered handler decodes the argument bytes according to
    /// `Self::encoding()`, awaits `call_fn`, and encodes the result: CBOR
    /// bytes as `Payload::Binary` for `Cbor`/`GetCBOR`, otherwise a JSON
    /// string as `Payload::Url`.
    ///
    /// # Errors
    /// Returns `ServerFnError::Registration` when `R::register` fails.
    #[cfg(any(feature = "ssr", doc))]
    fn register_in<R: ServerFunctionRegistry<T>>() -> Result<(), ServerFnError>
    {
        let handler = Arc::new(|cx: T, data: &[u8]| {
            // Decode before building the future so the future owns its input
            // and never borrows `data`.
            let decoded: Result<Self, ServerFnError> = match Self::encoding() {
                Encoding::Cbor => ciborium::de::from_reader(data)
                    .map_err(|e| ServerFnError::Deserialization(e.to_string())),
                Encoding::Url | Encoding::GetJSON | Encoding::GetCBOR => {
                    serde_qs::from_bytes(data).map_err(|e| {
                        ServerFnError::Deserialization(e.to_string())
                    })
                }
            };

            Box::pin(async move {
                let args = decoded?;
                let output = args.call_fn(cx).await?;

                let payload = match Self::encoding() {
                    Encoding::Cbor | Encoding::GetCBOR => {
                        let mut bytes: Vec<u8> = Vec::new();
                        ciborium::ser::into_writer(&output, &mut bytes)
                            .map_err(|e| {
                                ServerFnError::Serialization(e.to_string())
                            })?;
                        Payload::Binary(bytes)
                    }
                    Encoding::Url | Encoding::GetJSON => {
                        let json =
                            serde_json::to_string(&output).map_err(|e| {
                                ServerFnError::Serialization(e.to_string())
                            })?;
                        Payload::Url(json)
                    }
                };

                Ok::<Payload, ServerFnError>(payload)
            })
                as Pin<Box<dyn Future<Output = Result<Payload, ServerFnError>>>>
        });

        R::register(Self::url(), handler, Self::encoding())
            .map_err(|e| ServerFnError::Registration(e.to_string()))
    }
}
/// The errors that can occur while registering or calling a server function.
///
/// The type is serializable, and `call_server_fn` tries to parse a 5xx
/// response body as one of these values.
#[derive(Error, Debug, Clone, Serialize, Deserialize)]
pub enum ServerFnError {
/// The registry rejected the server function (see `ServerFn::register_in`).
#[error("error while trying to register the server function: {0}")]
Registration(String),
/// The HTTP request to the server could not be sent or completed.
#[error("error reaching server to call server function: {0}")]
Request(String),
/// The server answered with a 5xx status whose body was not a serialized
/// `ServerFnError`; carries the status text.
#[error("error running server function: {0}")]
ServerError(String),
/// Decoding failed. In this file this is used both for argument bytes on
/// the server and for the response on the client.
#[error("error deserializing server function results {0}")]
Deserialization(String),
/// Encoding failed. In this file this is used both for arguments on the
/// client and for the result on the server.
#[error("error serializing server function arguments {0}")]
Serialization(String),
/// The arguments could not be deserialized.
// NOTE(review): not constructed in this file — confirm where it is produced.
#[error("error deserializing server function arguments {0}")]
Args(String),
/// A required argument was missing.
// NOTE(review): not constructed in this file — confirm where it is produced.
#[error("missing argument {0}")]
MissingArg(String),
}
/// Calls a server function over HTTP from the client and decodes its result.
///
/// `args` is serialized according to `enc`, sent to `url`, and the response
/// body is deserialized into `T`:
///
/// * `Url` / `Cbor` send a `POST` request with the encoded arguments as body.
/// * `GetJSON` / `GetCBOR` send a `GET` request with the URL-encoded
///   arguments appended as a query string.
/// * `Cbor` / `GetCBOR` responses are decoded as CBOR, all others as JSON.
///
/// On `wasm32` the request goes through `gloo_net`. On other targets it goes
/// through the shared `CLIENT`, and `url` is prefixed with the root URL from
/// `get_server_url()`.
///
/// # Errors
/// * `ServerFnError::Serialization` if the arguments cannot be encoded.
/// * `ServerFnError::Request` if sending the request fails.
/// * For a 5xx response: the `ServerFnError` parsed from the JSON body, or
///   `ServerFnError::ServerError` with the status text if parsing fails.
/// * `ServerFnError::Deserialization` if the response body cannot be read
///   or decoded.
///
/// # Panics
/// On non-`wasm32` targets, panics via `get_server_url()` if no root URL has
/// been set.
#[cfg(not(feature = "ssr"))]
pub async fn call_server_fn<T, C: 'static>(
url: &str,
args: impl ServerFn<C>,
enc: Encoding,
) -> Result<T, ServerFnError>
where
T: serde::Serialize + serde::de::DeserializeOwned + Sized,
{
use ciborium::ser::into_writer;
use serde_json::Deserializer as JSONDeserializer;
// Outside the browser there is no implicit origin, so the configured root
// URL is prepended. This shadows the `url` parameter.
#[cfg(not(target_arch = "wasm32"))]
let url = format!("{}{}", get_server_url(), url);
// Local request-body representation; shadows the crate-level `Payload`
// inside this function.
#[derive(Debug)]
enum Payload {
Binary(Vec<u8>),
Url(String),
}
// Only `Cbor` produces a binary body; every other encoding URL-encodes the
// arguments.
let args_encoded = match &enc {
Encoding::Url | Encoding::GetJSON | Encoding::GetCBOR => Payload::Url(
serde_qs::to_string(&args)
.map_err(|e| ServerFnError::Serialization(e.to_string()))?,
),
Encoding::Cbor => {
let mut buffer: Vec<u8> = Vec::new();
into_writer(&args, &mut buffer)
.map_err(|e| ServerFnError::Serialization(e.to_string()))?;
Payload::Binary(buffer)
}
};
// Content-Type describes how the arguments were encoded.
let content_type_header = match &enc {
Encoding::Url | Encoding::GetJSON | Encoding::GetCBOR => {
"application/x-www-form-urlencoded"
}
Encoding::Cbor => "application/cbor",
};
// Accept describes the response format requested for this encoding.
let accept_header = match &enc {
Encoding::Url | Encoding::GetJSON => {
"application/x-www-form-urlencoded"
}
Encoding::Cbor | Encoding::GetCBOR => "application/cbor",
};
// Browser path: send the request with gloo_net.
#[cfg(target_arch = "wasm32")]
let resp = match &enc {
Encoding::Url | Encoding::Cbor => match args_encoded {
Payload::Binary(b) => {
let slice_ref: &[u8] = &b;
let js_array = js_sys::Uint8Array::from(slice_ref).buffer();
gloo_net::http::Request::post(url)
.header("Content-Type", content_type_header)
.header("Accept", accept_header)
.body(js_array)
.send()
.await
.map_err(|e| ServerFnError::Request(e.to_string()))?
}
Payload::Url(s) => gloo_net::http::Request::post(url)
.header("Content-Type", content_type_header)
.header("Accept", accept_header)
.body(s)
.send()
.await
.map_err(|e| ServerFnError::Request(e.to_string()))?,
},
Encoding::GetCBOR | Encoding::GetJSON => match args_encoded {
// The GET encodings always produce `Payload::Url` above, so this arm
// is not expected to be reached.
Payload::Binary(_) => panic!(
"Binary data cannot be transferred via GET request in a query \
string. Please try using the CBOR encoding."
),
Payload::Url(s) => {
let full_url = format!("{url}?{s}");
gloo_net::http::Request::get(&full_url)
.header("Content-Type", content_type_header)
.header("Accept", accept_header)
.send()
.await
.map_err(|e| ServerFnError::Request(e.to_string()))?
}
},
};
// Non-browser path: send the request with the shared reqwest client.
#[cfg(not(target_arch = "wasm32"))]
let resp = match &enc {
Encoding::Url | Encoding::Cbor => match args_encoded {
Payload::Binary(b) => CLIENT
.post(url)
.header("Content-Type", content_type_header)
.header("Accept", accept_header)
.body(b)
.send()
.await
.map_err(|e| ServerFnError::Request(e.to_string()))?,
Payload::Url(s) => CLIENT
.post(url)
.header("Content-Type", content_type_header)
.header("Accept", accept_header)
.body(s)
.send()
.await
.map_err(|e| ServerFnError::Request(e.to_string()))?,
},
Encoding::GetJSON | Encoding::GetCBOR => match args_encoded {
// The GET encodings always produce `Payload::Url` above, so this arm
// is not expected to be reached.
Payload::Binary(_) => panic!(
"Binary data cannot be transferred via GET request in a query \
string. Please try using the CBOR encoding."
),
Payload::Url(s) => {
let full_url = format!("{url}?{s}");
CLIENT
.get(full_url)
.header("Content-Type", content_type_header)
.header("Accept", accept_header)
.send()
.await
.map_err(|e| ServerFnError::Request(e.to_string()))?
}
},
};
let status = resp.status();
// Convert to a plain `u16` so the range check below can be shared by both
// backends.
#[cfg(not(target_arch = "wasm32"))]
let status = status.as_u16();
if (500..=599).contains(&status) {
// A body that cannot be read is treated as empty, which then falls back
// to `ServerError` below.
let text = resp.text().await.unwrap_or_default();
#[cfg(target_arch = "wasm32")]
let status_text = resp.status_text();
#[cfg(not(target_arch = "wasm32"))]
let status_text = status.to_string();
// Prefer the structured error sent by the server; otherwise report the
// status text.
return Err(serde_json::from_str(&text)
.unwrap_or(ServerFnError::ServerError(status_text)));
}
// Statuses outside 500..=599 fall through and the body is decoded as a
// result.
if (enc == Encoding::Cbor) || (enc == Encoding::GetCBOR) {
#[cfg(target_arch = "wasm32")]
let binary = resp
.binary()
.await
.map_err(|e| ServerFnError::Deserialization(e.to_string()))?;
#[cfg(target_arch = "wasm32")]
let binary = binary.as_slice();
#[cfg(not(target_arch = "wasm32"))]
let binary = resp
.bytes()
.await
.map_err(|e| ServerFnError::Deserialization(e.to_string()))?;
#[cfg(not(target_arch = "wasm32"))]
let binary = binary.as_ref();
ciborium::de::from_reader(binary)
.map_err(|e| ServerFnError::Deserialization(e.to_string()))
} else {
let text = resp
.text()
.await
.map_err(|e| ServerFnError::Deserialization(e.to_string()))?;
let mut deserializer = JSONDeserializer::from_str(&text);
T::deserialize(&mut deserializer)
.map_err(|e| ServerFnError::Deserialization(e.to_string()))
}
}
/// HTTP client shared by all server-function calls made outside the browser.
///
/// It is created lazily on first use.
#[cfg(any(all(not(feature = "ssr"), not(target_arch = "wasm32")), doc))]
static CLIENT: once_cell::sync::Lazy<reqwest::Client> =
    once_cell::sync::Lazy::new(reqwest::Client::new);
/// Root URL of the server, written once by `set_server_url` and read by
/// `get_server_url`. It is empty until `set_server_url` is called.
#[cfg(any(all(not(feature = "ssr"), not(target_arch = "wasm32")), doc))]
static ROOT_URL: once_cell::sync::OnceCell<&'static str> =
once_cell::sync::OnceCell::new();
/// Sets the root URL that is prepended to server-function URLs when calling
/// them from a non-`wasm32` client.
///
/// This must be called before the first server function call on such
/// targets.
///
/// # Panics
/// Panics if the root URL has already been set. The URL can only be set once.
#[cfg(any(all(not(feature = "ssr"), not(target_arch = "wasm32")), doc))]
pub fn set_server_url(url: &'static str) {
    // `set` hands the rejected value back as the error, so the panic message
    // also shows which URL was refused.
    ROOT_URL
        .set(url)
        .expect("the server URL was already set; set_server_url may only be called once");
}
/// Returns the root URL configured through `set_server_url`.
///
/// # Panics
/// Panics if `set_server_url` has not been called yet.
#[cfg(all(not(feature = "ssr"), not(target_arch = "wasm32")))]
fn get_server_url() -> &'static str {
    // The message names the real setter; it previously pointed at a
    // non-existent `set_root_url`.
    ROOT_URL
        .get()
        .expect("Call set_server_url before calling a server function.")
}