use parking_lot::RwLock;
use std::any::Any;
use std::collections::HashMap;
use std::sync::Arc;
type SendSyncAnyMap = std::collections::HashMap<std::any::TypeId, ContextType>;
/// Per-request state made available to server-side code.
///
/// Every field sits behind an `Arc`, so cloning the context yields another handle to the
/// same underlying request parts, response parts, and stored values.
#[derive(Clone)]
pub struct DioxusServerContext {
    /// Type-keyed values and factories registered on this context.
    shared_context: Arc<RwLock<SendSyncAnyMap>>,
    /// Response parts handed out by `response_parts` / `response_parts_mut`.
    response_parts: Arc<RwLock<http::response::Parts>>,
    /// Request parts handed out by `request_parts` / `request_parts_mut`.
    pub(crate) parts: Arc<RwLock<http::request::Parts>>,
}
/// One entry of the server context map: either a stored value or a factory that builds a
/// value each time it is looked up.
enum ContextType {
    /// Closure invoked on every lookup to produce a boxed value.
    Factory(Box<dyn Fn() -> Box<dyn Any> + Send + Sync>),
    /// A stored value; lookups hand out clones of it.
    Value(Box<dyn Any + Send + Sync>),
}

impl ContextType {
    /// Produces an owned `T` from this entry.
    ///
    /// A `Value` is cloned out of storage; a `Factory` is invoked and its result unboxed.
    /// Returns `None` when the entry does not hold (or produce) a `T`.
    fn downcast<T: Clone + 'static>(&self) -> Option<T> {
        match self {
            Self::Factory(make) => match make().downcast::<T>() {
                Ok(boxed) => Some(*boxed),
                Err(_) => None,
            },
            Self::Value(stored) => stored.downcast_ref::<T>().cloned(),
        }
    }
}
#[allow(clippy::derivable_impls)]
impl Default for DioxusServerContext {
fn default() -> Self {
Self {
shared_context: std::sync::Arc::new(RwLock::new(HashMap::new())),
response_parts: std::sync::Arc::new(RwLock::new(
http::response::Response::new(()).into_parts().0,
)),
parts: std::sync::Arc::new(RwLock::new(http::request::Request::new(()).into_parts().0)),
}
}
}
mod server_fn_impl {
use super::*;
use parking_lot::{RwLockReadGuard, RwLockWriteGuard};
use std::any::{Any, TypeId};
impl DioxusServerContext {
pub fn new(parts: http::request::Parts) -> Self {
Self {
parts: Arc::new(RwLock::new(parts)),
shared_context: Arc::new(RwLock::new(SendSyncAnyMap::new())),
response_parts: std::sync::Arc::new(RwLock::new(
http::response::Response::new(()).into_parts().0,
)),
}
}
#[allow(unused)]
pub(crate) fn from_shared_parts(parts: Arc<RwLock<http::request::Parts>>) -> Self {
Self {
parts,
shared_context: Arc::new(RwLock::new(SendSyncAnyMap::new())),
response_parts: std::sync::Arc::new(RwLock::new(
http::response::Response::new(()).into_parts().0,
)),
}
}
pub fn get<T: Any + Send + Sync + Clone + 'static>(&self) -> Option<T> {
self.shared_context
.read()
.get(&TypeId::of::<T>())
.map(|v| v.downcast::<T>().unwrap())
}
pub fn insert<T: Any + Send + Sync + 'static>(&self, value: T) {
self.insert_any(Box::new(value));
}
pub fn insert_any(&self, value: Box<dyn Any + Send + Sync + 'static>) {
self.shared_context
.write()
.insert((*value).type_id(), ContextType::Value(value));
}
pub fn insert_factory<F, T>(&self, value: F)
where
F: Fn() -> T + Send + Sync + 'static,
T: 'static,
{
self.shared_context.write().insert(
TypeId::of::<T>(),
ContextType::Factory(Box::new(move || Box::new(value()))),
);
}
pub fn insert_boxed_factory(&self, value: Box<dyn Fn() -> Box<dyn Any> + Send + Sync>) {
self.shared_context
.write()
.insert((*value()).type_id(), ContextType::Factory(value));
}
#[doc = include_str!("../docs/request_origin.md")]
pub fn response_parts(&self) -> RwLockReadGuard<'_, http::response::Parts> {
self.response_parts.read()
}
#[doc = include_str!("../docs/request_origin.md")]
pub fn response_parts_mut(&self) -> RwLockWriteGuard<'_, http::response::Parts> {
self.response_parts.write()
}
#[doc = include_str!("../docs/request_origin.md")]
pub fn request_parts(&self) -> parking_lot::RwLockReadGuard<'_, http::request::Parts> {
self.parts.read()
}
#[doc = include_str!("../docs/request_origin.md")]
pub fn request_parts_mut(&self) -> parking_lot::RwLockWriteGuard<'_, http::request::Parts> {
self.parts.write()
}
#[doc = include_str!("../docs/request_origin.md")]
pub async fn extract<M, T: FromServerContext<M>>(&self) -> Result<T, T::Rejection> {
T::from_request(self).await
}
}
}
// A boxed factory registered on the context can be fetched back by the type it produces.
#[test]
fn server_context_as_any_map() {
    let (request_parts, _body) = http::Request::new(()).into_parts();
    let context = DioxusServerContext::new(request_parts);

    context.insert_boxed_factory(Box::new(|| Box::new(1234u32)));

    let produced = context.get::<u32>().unwrap();
    assert_eq!(produced, 1234u32);
}
std::thread_local! {
pub(crate) static SERVER_CONTEXT: std::cell::RefCell<Box<DioxusServerContext>> = Default::default();
}
pub fn server_context() -> DioxusServerContext {
SERVER_CONTEXT.with(|ctx| *ctx.borrow().clone())
}
/// Builds an `E` from the server context installed on the current thread, using `E`'s
/// [`FromServerContext`] implementation.
pub async fn extract<E: FromServerContext<I>, I>() -> Result<E, E::Rejection> {
    let context = server_context();
    E::from_request(&context).await
}
/// Runs `f` with `context` installed as this thread's server context and returns its result.
///
/// The context that was installed before the call is put back once `f` finishes. This also
/// happens when `f` panics, so an unwinding call cannot leave `context` installed on the
/// thread for unrelated code to pick up.
pub fn with_server_context<O>(context: DioxusServerContext, f: impl FnOnce() -> O) -> O {
    /// Puts the saved context back into `SERVER_CONTEXT` when dropped.
    struct RestorePrevious(Option<Box<DioxusServerContext>>);

    impl Drop for RestorePrevious {
        fn drop(&mut self) {
            if let Some(previous) = self.0.take() {
                SERVER_CONTEXT.with(|ctx| ctx.replace(previous));
            }
        }
    }

    let previous = SERVER_CONTEXT.with(|ctx| ctx.replace(Box::new(context)));
    // Dropped after `f()` has produced its value, or while unwinding out of `f`.
    let _restore = RestorePrevious(Some(previous));
    f()
}
/// A future wrapper that installs a [`DioxusServerContext`] as the current thread's server
/// context around every poll of the wrapped future.
#[pin_project::pin_project]
pub struct ProvideServerContext<F: std::future::Future> {
    /// Context installed while `f` is being polled.
    context: DioxusServerContext,
    /// The wrapped future, polled through the pin projection.
    #[pin]
    f: F,
}
impl<F: std::future::Future> ProvideServerContext<F> {
    /// Wraps `f` so that `context` is the installed server context whenever `f` is polled.
    pub fn new(f: F, context: DioxusServerContext) -> Self {
        Self { context, f }
    }
}
impl<F: std::future::Future> std::future::Future for ProvideServerContext<F> {
type Output = F::Output;
fn poll(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Self::Output> {
let this = self.project();
let context = this.context.clone();
with_server_context(context, || this.f.poll(cx))
}
}
/// A type that can be built asynchronously from a [`DioxusServerContext`].
///
/// The `I` parameter (default `()`) tells implementations apart: in this file `FromContext`
/// implements the trait with the default, and the axum blanket impl uses `Axum`.
#[async_trait::async_trait]
pub trait FromServerContext<I = ()>: Sized {
    /// Error returned when the value cannot be built from the context.
    type Rejection;
    /// Attempts to build `Self` from `req`.
    async fn from_request(req: &DioxusServerContext) -> Result<Self, Self::Rejection>;
}
/// Error reporting that no value of type `T` was found in the server context.
pub struct NotFoundInServerContext<T: 'static>(std::marker::PhantomData<T>);

impl<T: 'static> std::fmt::Debug for NotFoundInServerContext<T> {
    /// Formats exactly like the `Display` implementation.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        std::fmt::Display::fmt(self, f)
    }
}

impl<T: 'static> std::fmt::Display for NotFoundInServerContext<T> {
    /// Writes a message naming the missing type, as given by `std::any::type_name`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let type_name = std::any::type_name::<T>();
        f.write_fmt(format_args!("`{type_name}` not found in server context"))
    }
}

impl<T: 'static> std::error::Error for NotFoundInServerContext<T> {}
/// Extractor wrapper around a `T` taken out of the server context.
pub struct FromContext<T: Send + Sync + Clone + 'static>(pub T);
#[async_trait::async_trait]
impl<T: Send + Sync + Clone + 'static> FromServerContext for FromContext<T> {
    type Rejection = NotFoundInServerContext<T>;

    /// Fetches a `T` from `req`, or rejects with [`NotFoundInServerContext`] when `req.get`
    /// returns nothing for `T`.
    async fn from_request(req: &DioxusServerContext) -> Result<Self, Self::Rejection> {
        match req.get::<T>() {
            Some(found) => Ok(Self(found)),
            None => Err(NotFoundInServerContext::<T>(std::marker::PhantomData::<T>)),
        }
    }
}
/// Marker type used as the `I` parameter of `FromServerContext` by the blanket
/// implementation for axum `FromRequestParts` extractors.
#[cfg(feature = "axum")]
#[cfg_attr(docsrs, doc(cfg(feature = "axum")))]
pub struct Axum;
// Lets any axum extractor that works from request parts with `()` state be extracted from a
// `DioxusServerContext`.
#[cfg(feature = "axum")]
#[async_trait::async_trait]
impl<I, R> FromServerContext<Axum> for I
where
    I: axum::extract::FromRequestParts<(), Rejection = R>,
    R: axum::response::IntoResponse + std::error::Error,
{
    type Rejection = R;

    /// Runs the axum extractor against this context's request parts.
    // NOTE(review): the write guard from `request_parts_mut` stays alive across the
    // `.await` below; anything that needs the request-parts lock during extraction will
    // block on it. The blanket lint suppression presumably exists to silence clippy's
    // await-holding-lock warning — confirm before narrowing it.
    #[allow(clippy::all)]
    async fn from_request(req: &DioxusServerContext) -> Result<Self, Self::Rejection> {
        let mut request_parts = req.request_parts_mut();
        I::from_request_parts(&mut request_parts, &()).await
    }
}