use std::cell::{Cell, RefCell};
use std::future::Future;
use std::pin::Pin;
use std::rc::Rc;
use std::task::{Context, Poll};
use serde::{de::DeserializeOwned, Serialize};
use wasm_bindgen::prelude::*;
use wasm_bindgen_futures::JsFuture;
use web_sys::{AbortSignal, Request, RequestInit, Response};
use crate::server::{Result as ServerResult, ServerError};
/// A single outgoing server-function request, as seen by the fetch middleware chain.
///
/// Middleware receives this by value from [`FetchNext::run`] and may rewrite any public
/// field before forwarding it further down the chain.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub struct FetchRequest {
    /// Target URL of the request.
    pub url: String,
    /// HTTP method; `call_with_options` sets `"POST"`.
    pub method: String,
    /// Request body; `call_with_options` fills it with the JSON-serialized arguments.
    pub body: String,
    /// Header name/value pairs, applied in order when the request is sent.
    ///
    /// Prefer [`FetchRequest::set_header`] over pushing directly, so that a header name
    /// (compared ASCII case-insensitively) never appears more than once.
    pub headers: Vec<(String, String)>,
    /// Optional signal forwarded to the browser fetch so the request can be aborted.
    pub abort_signal: Option<AbortSignal>,
    /// Set from `FetchOptions::replay_safe`; read through [`FetchRequest::is_replay_safe`].
    pub(crate) replay_safe: bool,
}
impl FetchRequest {
    /// Sets header `name` to `value`, replacing any existing header with the same name.
    ///
    /// Names are matched ASCII case-insensitively. If a matching header exists, its first
    /// occurrence keeps its original spelling and receives `value`, and every later
    /// case-insensitive duplicate is removed. Otherwise the pair is appended.
    pub fn set_header(&mut self, name: impl Into<String>, value: impl Into<String>) {
        let name = name.into();
        let value = value.into();
        let Some(first) = self
            .headers
            .iter()
            .position(|(k, _)| k.eq_ignore_ascii_case(&name))
        else {
            self.headers.push((name, value));
            return;
        };
        self.headers[first].1 = value;
        // Headers are applied in order when the request is sent, and setting a header
        // overwrites it. A stale duplicate later in the list (e.g. one pushed directly onto
        // `headers`) would therefore silently override the value just set, so drop it.
        let mut index = 0;
        self.headers.retain(|(k, _)| {
            let keep = index <= first || !k.eq_ignore_ascii_case(&name);
            index += 1;
            keep
        });
    }

    /// Returns whether this request was created with replay-safety enabled.
    pub fn is_replay_safe(&self) -> bool {
        self.replay_safe
    }
}
/// Raw HTTP response passed back up the middleware chain. It is produced by the terminal
/// browser fetch (`perform_fetch`); since its fields are public, a middleware may also
/// construct one itself.
#[derive(Clone, Debug)]
pub struct FetchResponse {
/// HTTP status code; `call_with_options` treats anything outside `200..300` as an error.
pub status: u16,
/// Response body text; `call_with_options` decodes it as a JSON-encoded `ServerResult`.
pub body: String,
}
/// Handle to the remainder of the middleware chain, handed to each [`FetchMiddleware`].
///
/// It holds the position of the next middleware plus a shared (`Rc`) list of the chain, so
/// cloning it is cheap.
#[derive(Clone)]
pub struct FetchNext {
    index: usize,
    middlewares: Rc<Vec<Rc<dyn FetchMiddleware>>>,
}
impl FetchNext {
    /// Forwards `request` to the next middleware. Once every middleware has run, it performs
    /// the real network request instead.
    pub fn run(
        self,
        request: FetchRequest,
    ) -> Pin<Box<dyn Future<Output = Result<FetchResponse, ServerError>> + 'static>> {
        let Some(current) = self.middlewares.get(self.index).cloned() else {
            // Past the last middleware: the chain ends in the actual fetch.
            return Box::pin(perform_fetch(request));
        };
        let rest = FetchNext {
            index: self.index + 1,
            middlewares: self.middlewares,
        };
        current.call(request, rest)
    }
}
/// Boxed `'static` future returned by [`FetchMiddleware::call`]. It is the same type that
/// [`FetchNext::run`] returns.
pub type FetchMiddlewareFuture =
Pin<Box<dyn Future<Output = Result<FetchResponse, ServerError>> + 'static>>;
/// Shared list of installed middleware, in installation order.
type MiddlewareChain = Rc<Vec<Rc<dyn FetchMiddleware>>>;
/// A hook in the fetch pipeline that can inspect or rewrite a [`FetchRequest`].
///
/// Implementations typically continue down the chain with [`FetchNext::run`]. Closures of
/// the form `Fn(FetchRequest, FetchNext) -> Fut` (closure and `Fut` both `'static`) implement
/// this trait through a blanket impl.
pub trait FetchMiddleware: 'static {
/// Handles `request`; call `next.run(request)` to continue down the chain.
fn call(&self, request: FetchRequest, next: FetchNext) -> FetchMiddlewareFuture;
}
/// Lets any `'static` closure that returns a `'static` response future act as middleware.
impl<F, Fut> FetchMiddleware for F
where
    F: Fn(FetchRequest, FetchNext) -> Fut + 'static,
    Fut: Future<Output = Result<FetchResponse, ServerError>> + 'static,
{
    fn call(&self, request: FetchRequest, next: FetchNext) -> FetchMiddlewareFuture {
        // Invoke the closure, then erase its concrete future type behind a pinned box.
        let pending: Fut = (self)(request, next);
        Box::pin(pending)
    }
}
// Per-thread state for the fetch middleware chain and the active abort signal.
thread_local! {
// Middleware registered by `install_middleware`, in installation order.
static MIDDLEWARES: RefCell<Vec<Rc<dyn FetchMiddleware>>> = const { RefCell::new(Vec::new()) };
// Once `true`, `install_middleware` panics. Set by `freeze_middleware_chain` and `clear_and_freeze`.
static FROZEN: Cell<bool> = const { Cell::new(false) };
// Chain captured when it is frozen; `snapshot_chain` returns this once it is set.
static MIDDLEWARE_SNAPSHOT: RefCell<Option<MiddlewareChain>> =
const { RefCell::new(None) };
// Signal installed by `enter_abort_signal` (e.g. while an `AbortSignalFuture` is being
// polled); read by `current_abort_signal`.
static ACTIVE_ABORT_SIGNAL: RefCell<Option<AbortSignal>> =
const { RefCell::new(None) };
}
/// Guard returned by [`enter_abort_signal`]. When dropped, it reinstates whichever abort
/// signal was active before it was created.
struct AbortSignalScope {
    previous: Option<AbortSignal>,
}
impl Drop for AbortSignalScope {
    fn drop(&mut self) {
        let restored = self.previous.take();
        ACTIVE_ABORT_SIGNAL.with(|slot| slot.replace(restored));
    }
}
/// Makes `signal` the active abort signal on this thread until the returned guard drops.
/// Passing `None` clears the active signal for that duration.
fn enter_abort_signal(signal: Option<AbortSignal>) -> AbortSignalScope {
    AbortSignalScope {
        previous: ACTIVE_ABORT_SIGNAL.with(|slot| slot.replace(signal)),
    }
}
/// Returns a clone of the abort signal currently active on this thread, if any.
fn current_abort_signal() -> Option<AbortSignal> {
    ACTIVE_ABORT_SIGNAL.with(|slot| slot.borrow().as_ref().cloned())
}
/// Future adapter that makes `signal` the active abort signal while `future` is polled.
///
/// `call_with_options` picks up the active signal (through `FetchOptions::with_active_context`)
/// when the caller did not pass one explicitly.
pub(crate) struct AbortSignalFuture<F> {
    signal: Option<AbortSignal>,
    future: Pin<Box<F>>,
}
// The wrapped future is already pinned on the heap, so moving the adapter itself is fine.
impl<F> Unpin for AbortSignalFuture<F> {}
impl<F: Future> Future for AbortSignalFuture<F> {
    type Output = F::Output;
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let AbortSignalFuture { signal, future } = Pin::into_inner(self);
        // Held for the whole poll; dropping it restores the outer signal.
        let _active = enter_abort_signal(signal.clone());
        future.as_mut().poll(cx)
    }
}
/// Wraps `future` so that code polled inside it sees `signal` as the active abort signal.
pub(crate) fn with_abort_signal_future<F: Future>(
    signal: Option<AbortSignal>,
    future: F,
) -> AbortSignalFuture<F> {
    let future = Box::pin(future);
    AbortSignalFuture { signal, future }
}
/// Appends `middleware` to this thread's fetch middleware chain.
///
/// # Panics
///
/// Panics if the chain has already been frozen, for example by the first
/// `call_with_options`, by [`freeze_middleware_chain`], or by `clear_and_freeze`.
pub fn install_middleware<M: FetchMiddleware>(middleware: M) {
    let frozen = FROZEN.with(Cell::get);
    assert!(
        !frozen,
        "fetch::install_middleware called after the chain was frozen. \
         Middleware install must run before the first App::run or \
         fetch::call. Move the install into your plugin's `install` \
         function so it executes before App::run."
    );
    let entry: Rc<dyn FetchMiddleware> = Rc::new(middleware);
    MIDDLEWARES.with(|list| list.borrow_mut().push(entry));
}
/// Freezes the middleware chain and captures the snapshot used by later requests.
///
/// If the chain is already frozen, this does nothing.
#[doc(hidden)]
pub fn freeze_middleware_chain() {
    let already_frozen = FROZEN.with(|flag| flag.replace(true));
    if already_frozen {
        return;
    }
    let chain: MiddlewareChain = MIDDLEWARES.with(|list| Rc::new(list.borrow().clone()));
    MIDDLEWARE_SNAPSHOT.with(|slot| *slot.borrow_mut() = Some(chain));
}
/// Test helper: removes all middleware, unfreezes the chain and discards the snapshot.
#[doc(hidden)]
pub fn __reset_middleware_chain_for_test() {
    MIDDLEWARES.with(|list| list.borrow_mut().clear());
    FROZEN.with(|flag| flag.set(false));
    MIDDLEWARE_SNAPSHOT.with(|slot| slot.replace(None));
}
/// Removes every installed middleware and freezes the chain as empty.
pub(crate) fn clear_and_freeze() {
    MIDDLEWARES.with(|list| list.borrow_mut().clear());
    FROZEN.with(|flag| flag.set(true));
    let empty: MiddlewareChain = Rc::new(Vec::new());
    MIDDLEWARE_SNAPSHOT.with(|slot| slot.replace(Some(empty)));
}
/// Returns the chain a request should run through: the frozen snapshot if one exists,
/// otherwise a fresh copy of the currently installed middleware.
fn snapshot_chain() -> MiddlewareChain {
    MIDDLEWARE_SNAPSHOT
        .with(|slot| slot.borrow().clone())
        .unwrap_or_else(|| MIDDLEWARES.with(|list| Rc::new(list.borrow().clone())))
}
/// Calls the server function at `url` with `args`, using [`FetchOptions::default`].
///
/// This is shorthand for [`call_with_options`]. As there, when no abort signal is given,
/// the one active in the enclosing abort-signal scope (if any) is used.
pub async fn call<A, R>(url: &str, args: &A) -> ServerResult<R>
where
    A: Serialize,
    R: DeserializeOwned,
{
    let defaults = FetchOptions::default();
    call_with_options(url, args, defaults).await
}
/// Per-call options for `call_with_options`, built with the chainable setters below.
#[derive(Clone, Debug, Default)]
#[non_exhaustive]
pub struct FetchOptions {
    /// Signal used to abort the request. When `None`, the active abort signal (if any) is
    /// filled in at call time.
    pub(crate) abort_signal: Option<AbortSignal>,
    /// Whether the request is flagged as replay-safe (visible to middleware through
    /// `FetchRequest::is_replay_safe`).
    pub(crate) replay_safe: bool,
}
impl FetchOptions {
    /// Returns these options with the abort signal set to `signal` (`None` clears it).
    pub fn abort_signal(self, signal: Option<AbortSignal>) -> Self {
        Self {
            abort_signal: signal,
            ..self
        }
    }
    /// Returns these options with the replay-safe flag set to `enabled`.
    pub fn replay_safe(self, enabled: bool) -> Self {
        Self {
            replay_safe: enabled,
            ..self
        }
    }
    /// Fills in the currently active abort signal when none was set explicitly.
    fn with_active_context(self) -> Self {
        match self.abort_signal {
            Some(_) => self,
            None => Self {
                abort_signal: current_abort_signal(),
                ..self
            },
        }
    }
}
/// Calls the server function at `url` with `args` and decodes the JSON-encoded
/// `ServerResult<R>` response.
///
/// `args` is serialized to JSON and sent as a `POST` body through the fetch middleware
/// chain, which is frozen on first use. If `options` carries no abort signal, the signal
/// active in the enclosing abort-signal scope (if any) is used. Registered server-function
/// client hooks receive a started event and then a completed or failed event.
///
/// # Errors
///
/// Returns [`ServerError::Network`] in three cases: the arguments fail to serialize, the
/// response status is outside `200..300`, or the body fails to parse. Errors coming out of
/// the middleware chain are returned unchanged, as is an error decoded from a successful
/// response body.
pub async fn call_with_options<A, R>(url: &str, args: &A, options: FetchOptions) -> ServerResult<R>
where
    A: Serialize,
    R: DeserializeOwned,
{
    freeze_middleware_chain();
    let options = options.with_active_context();
    let observation = FetchObservation::new(url);

    let body = serde_json::to_string(args).map_err(|err| {
        observation.failed("serialize");
        ServerError::Network(format!("serialize args: {err}"))
    })?;
    let request = FetchRequest {
        url: url.to_string(),
        method: "POST".to_string(),
        body,
        headers: vec![("content-type".to_string(), "application/json".to_string())],
        abort_signal: options.abort_signal,
        replay_safe: options.replay_safe,
    };

    let chain = FetchNext {
        index: 0,
        middlewares: snapshot_chain(),
    };
    let response = chain.run(request).await.map_err(|err| {
        observation.failed(server_error_kind(&err));
        err
    })?;

    let status = response.status;
    if !(200..300).contains(&status) {
        observation.failed("http_status");
        return Err(ServerError::Network(format!("HTTP {status}")));
    }

    let outer: ServerResult<R> = serde_json::from_str(&response.body).map_err(|err| {
        observation.failed("parse_response");
        ServerError::Network(format!("parse response: {err}"))
    })?;
    if let Err(err) = &outer {
        observation.failed(server_error_kind(err));
    } else {
        observation.completed(status);
    }
    outer
}
/// Like [`call`], but flags the request as replay-safe so middleware can see it through
/// [`FetchRequest::is_replay_safe`].
pub async fn call_replay_safe<A, R>(url: &str, args: &A) -> ServerResult<R>
where
    A: Serialize,
    R: DeserializeOwned,
{
    let options = FetchOptions::default().replay_safe(true);
    call_with_options(url, args, options).await
}
/// Terminal step of the middleware chain. Sends `request` with the browser fetch API and
/// returns the status code and body text.
///
/// # Errors
///
/// Every failure is reported as [`ServerError::Network`]. Failures include:
/// - a header name or value rejected by `Headers::set`;
/// - a request that cannot be constructed;
/// - no `window` object being available;
/// - a rejected fetch, e.g. a network failure or the abort signal firing;
/// - a body that cannot be read as text.
///
/// A non-2xx status is *not* an error at this level.
async fn perform_fetch(request: FetchRequest) -> Result<FetchResponse, ServerError> {
    let init = RequestInit::new();
    init.set_method(&request.method);
    init.set_body(&JsValue::from_str(&request.body));
    init.set_signal(request.abort_signal.as_ref());
    let headers =
        web_sys::Headers::new().map_err(|e| ServerError::Network(format!("headers: {e:?}")))?;
    for (name, value) in &request.headers {
        // `Headers::set` throws for an invalid header name or value. Report it rather than
        // silently sending the request without that header (e.g. without an authorization
        // header a middleware added). Only the name goes into the error, since the value
        // may be a credential.
        headers
            .set(name, value)
            .map_err(|e| ServerError::Network(format!("header {name}: {e:?}")))?;
    }
    init.set_headers(&headers);
    let req = Request::new_with_str_and_init(&request.url, &init)
        .map_err(|e| ServerError::Network(format!("build request: {e:?}")))?;
    let win =
        web_sys::window().ok_or_else(|| ServerError::Network("no window available".to_string()))?;
    let resp_js = JsFuture::from(win.fetch_with_request(&req))
        .await
        .map_err(|e| ServerError::Network(format!("fetch failed: {e:?}")))?;
    let resp: Response = resp_js
        .dyn_into()
        .map_err(|_| ServerError::Network("fetch returned non-Response".into()))?;
    let status = resp.status();
    let text_js = JsFuture::from(
        resp.text()
            .map_err(|e| ServerError::Network(format!("read body: {e:?}")))?,
    )
    .await
    .map_err(|e| ServerError::Network(format!("read body: {e:?}")))?;
    let body = text_js
        .as_string()
        .ok_or_else(|| ServerError::Network("body was not a string".into()))?;
    Ok(FetchResponse { status, body })
}
/// Reports the lifecycle of one server-function call to the client hooks.
///
/// If no client hooks are registered at construction time, both fields stay `None` and
/// every reporting method does nothing.
struct FetchObservation {
    route: Option<String>,
    start_ms: Option<f64>,
}
impl FetchObservation {
    /// If client hooks are registered, emits `ServerFunctionClientStarted` for `url` and
    /// records the start time. The route has its query string and fragment stripped.
    fn new(url: &str) -> Self {
        if crate::plugin::has_server_function_client_hooks() {
            let route = public_url_path(url);
            crate::plugin::emit(crate::plugin::ServerFunctionClientStarted {
                route: route.clone(),
            });
            Self {
                route: Some(route),
                start_ms: Some(js_sys::Date::now()),
            }
        } else {
            Self {
                route: None,
                start_ms: None,
            }
        }
    }
    /// Emits `ServerFunctionClientCompleted` with `status_code` and the elapsed time.
    fn completed(&self, status_code: u16) {
        if let Some(route) = &self.route {
            crate::plugin::emit(crate::plugin::ServerFunctionClientCompleted {
                route: route.clone(),
                duration_ms: self.elapsed_ms(),
                status_code,
            });
        }
    }
    /// Emits `ServerFunctionClientFailed` with `error_kind` and the elapsed time.
    fn failed(&self, error_kind: &'static str) {
        if let Some(route) = &self.route {
            crate::plugin::emit(crate::plugin::ServerFunctionClientFailed {
                route: route.clone(),
                duration_ms: self.elapsed_ms(),
                error_kind,
            });
        }
    }
    /// Milliseconds since construction. Returns `0.0` when no start time was recorded, or
    /// when the measured difference is negative or not finite.
    fn elapsed_ms(&self) -> f64 {
        self.start_ms
            .map(|start| js_sys::Date::now() - start)
            .filter(|elapsed| elapsed.is_finite() && *elapsed >= 0.0)
            .unwrap_or(0.0)
    }
}
/// Maps each [`ServerError`] variant to the `error_kind` label reported to client hooks.
fn server_error_kind(err: &ServerError) -> &'static str {
    match *err {
        ServerError::App(..) => "app",
        ServerError::Unauthorized(..) => "unauthorized",
        ServerError::Forbidden(..) => "forbidden",
        ServerError::BadRequest(..) => "bad_request",
        ServerError::Network(..) => "network",
    }
}
/// Returns `url` cut off before its first `?` or `#`, i.e. without query string or
/// fragment. This keeps query parameters out of the route reported to client hooks.
fn public_url_path(url: &str) -> String {
    let end = url.find(|c: char| c == '?' || c == '#').unwrap_or(url.len());
    url[..end].to_string()
}
// Unit tests for the middleware-chain bookkeeping and the pure helpers in this module.
// The chain state lives in thread-locals, so every test that touches it calls `reset()`
// first instead of relying on each test getting a fresh thread.
#[cfg(test)]
mod tests {
use super::*;
// Restores the thread-local middleware state to empty, unfrozen and without a snapshot.
fn reset() {
__reset_middleware_chain_for_test();
}
// Two installs before freezing are both visible through `snapshot_chain`
// (only the count is asserted, not the order).
#[test]
fn install_records_middleware_in_order() {
reset();
install_middleware(|req: FetchRequest, next: FetchNext| async move { next.run(req).await });
install_middleware(|req: FetchRequest, next: FetchNext| async move { next.run(req).await });
let snapshot = snapshot_chain();
assert_eq!(snapshot.len(), 2);
}
// Installing after the chain is frozen must panic with the documented message.
#[test]
#[should_panic(expected = "called after the chain was frozen")]
fn install_panics_after_freeze() {
reset();
freeze_middleware_chain();
install_middleware(|req: FetchRequest, next: FetchNext| async move { next.run(req).await });
}
// Freezing twice is allowed and leaves the chain frozen.
#[test]
fn freeze_is_idempotent() {
reset();
freeze_middleware_chain();
freeze_middleware_chain();
assert!(FROZEN.with(|cell| cell.get()));
}
// `set_header` matches names case-insensitively and keeps the existing spelling of the name.
#[test]
fn fetch_request_set_header_replaces_case_insensitively() {
let mut req = FetchRequest {
url: "/x".into(),
method: "POST".into(),
body: "".into(),
headers: vec![("Content-Type".into(), "text/plain".into())],
abort_signal: None,
replay_safe: false,
};
req.set_header("content-type", "application/json");
assert_eq!(req.headers.len(), 1);
assert_eq!(req.headers[0].1, "application/json");
assert_eq!(req.headers[0].0, "Content-Type");
}
// `set_header` appends the pair when no header of that name exists yet.
#[test]
fn fetch_request_set_header_appends_when_missing() {
let mut req = FetchRequest {
url: "/x".into(),
method: "POST".into(),
body: "".into(),
headers: vec![],
abort_signal: None,
replay_safe: false,
};
req.set_header("authorization", "Bearer t");
assert_eq!(
req.headers,
vec![("authorization".into(), "Bearer t".into())]
);
}
// `clear_and_freeze` discards installed middleware and keeps the chain frozen;
// the trailing `reset()` leaves clean state behind for any later test on this thread.
#[test]
fn clear_and_freeze_drops_middlewares_and_locks_chain() {
reset();
install_middleware(|req: FetchRequest, next: FetchNext| async move { next.run(req).await });
assert_eq!(snapshot_chain().len(), 1);
clear_and_freeze();
assert_eq!(snapshot_chain().len(), 0, "middlewares should be dropped");
assert!(FROZEN.with(|cell| cell.get()), "chain must stay frozen");
reset();
}
// Default options carry no abort signal and are not replay-safe.
#[test]
fn fetch_options_default_to_not_replay_safe() {
let options = FetchOptions::default();
assert!(!options.replay_safe);
assert!(options.abort_signal.is_none());
}
// Observed routes must not leak query strings or fragments.
#[test]
fn strips_query_and_fragment_from_observed_urls() {
assert_eq!(public_url_path("/api/search?q=secret#frag"), "/api/search");
assert_eq!(public_url_path("/api/save"), "/api/save");
}
}