dioxus_fullstack/server_context.rs

1use parking_lot::RwLock;
2use std::any::Any;
3use std::collections::HashMap;
4use std::sync::Arc;
5
6type SendSyncAnyMap = std::collections::HashMap<std::any::TypeId, ContextType>;
7
8/// A shared context for server functions that contains information about the request and middleware state.
9///
10/// You should not construct this directly inside components or server functions. Instead use [`server_context()`] to get the server context from the current request.
11///
12/// # Example
13///
14/// ```rust, no_run
15/// # use dioxus::prelude::*;
16/// #[server]
17/// async fn read_headers() -> Result<(), ServerFnError> {
18///     let server_context = server_context();
19///     let headers: http::HeaderMap = server_context.extract().await?;
20///     println!("{:?}", headers);
21///     Ok(())
22/// }
23/// ```
24#[derive(Clone)]
25pub struct DioxusServerContext {
26    shared_context: std::sync::Arc<RwLock<SendSyncAnyMap>>,
27    response_parts: std::sync::Arc<RwLock<http::response::Parts>>,
28    pub(crate) parts: Arc<RwLock<http::request::Parts>>,
29}
30
/// One entry in the server context's type-keyed map.
enum ContextType {
    /// A factory invoked on every read to build a fresh value. The produced value is not
    /// required to be `Send` or `Sync`; only the factory itself is.
    Factory(Box<dyn Fn() -> Box<dyn Any> + Send + Sync>),
    /// A stored value, cloned out on every read.
    Value(Box<dyn Any + Send + Sync>),
}
35
36impl ContextType {
37    fn downcast<T: Clone + 'static>(&self) -> Option<T> {
38        match self {
39            ContextType::Value(value) => value.downcast_ref::<T>().cloned(),
40            ContextType::Factory(factory) => factory().downcast::<T>().ok().map(|v| *v),
41        }
42    }
43}
44
45#[allow(clippy::derivable_impls)]
46impl Default for DioxusServerContext {
47    fn default() -> Self {
48        Self {
49            shared_context: std::sync::Arc::new(RwLock::new(HashMap::new())),
50            response_parts: std::sync::Arc::new(RwLock::new(
51                http::response::Response::new(()).into_parts().0,
52            )),
53            parts: std::sync::Arc::new(RwLock::new(http::request::Request::new(()).into_parts().0)),
54        }
55    }
56}
57
58mod server_fn_impl {
59    use super::*;
60    use parking_lot::{RwLockReadGuard, RwLockWriteGuard};
61    use std::any::{Any, TypeId};
62
63    impl DioxusServerContext {
64        /// Create a new server context from a request
65        pub fn new(parts: http::request::Parts) -> Self {
66            Self {
67                parts: Arc::new(RwLock::new(parts)),
68                shared_context: Arc::new(RwLock::new(SendSyncAnyMap::new())),
69                response_parts: std::sync::Arc::new(RwLock::new(
70                    http::response::Response::new(()).into_parts().0,
71                )),
72            }
73        }
74
75        /// Create a server context from a shared parts
76        #[allow(unused)]
77        pub(crate) fn from_shared_parts(parts: Arc<RwLock<http::request::Parts>>) -> Self {
78            Self {
79                parts,
80                shared_context: Arc::new(RwLock::new(SendSyncAnyMap::new())),
81                response_parts: std::sync::Arc::new(RwLock::new(
82                    http::response::Response::new(()).into_parts().0,
83                )),
84            }
85        }
86
87        /// Clone a value from the shared server context. If you are using [`DioxusRouterExt`](crate::prelude::DioxusRouterExt), any values you insert into
88        /// the launch context will also be available in the server context.
89        ///
90        /// Example:
91        /// ```rust, no_run
92        /// use dioxus::prelude::*;
93        ///
94        /// LaunchBuilder::new()
95        ///     // You can provide context to your whole app (including server functions) with the `with_context` method on the launch builder
96        ///     .with_context(server_only! {
97        ///         1234567890u32
98        ///     })
99        ///     .launch(app);
100        ///
101        /// #[server]
102        /// async fn read_context() -> Result<u32, ServerFnError> {
103        ///     // You can extract values from the server context with the `extract` function
104        ///     let FromContext(value) = extract().await?;
105        ///     Ok(value)
106        /// }
107        ///
108        /// fn app() -> Element {
109        ///     let future = use_resource(read_context);
110        ///     rsx! {
111        ///         h1 { "{future:?}" }
112        ///     }
113        /// }
114        /// ```
115        pub fn get<T: Any + Send + Sync + Clone + 'static>(&self) -> Option<T> {
116            self.shared_context
117                .read()
118                .get(&TypeId::of::<T>())
119                .map(|v| v.downcast::<T>().unwrap())
120        }
121
122        /// Insert a value into the shared server context
123        pub fn insert<T: Any + Send + Sync + 'static>(&self, value: T) {
124            self.insert_any(Box::new(value));
125        }
126
127        /// Insert a boxed `Any` value into the shared server context
128        pub fn insert_any(&self, value: Box<dyn Any + Send + Sync + 'static>) {
129            self.shared_context
130                .write()
131                .insert((*value).type_id(), ContextType::Value(value));
132        }
133
134        /// Insert a factory that creates a non-sync value for the shared server context
135        pub fn insert_factory<F, T>(&self, value: F)
136        where
137            F: Fn() -> T + Send + Sync + 'static,
138            T: 'static,
139        {
140            self.shared_context.write().insert(
141                TypeId::of::<T>(),
142                ContextType::Factory(Box::new(move || Box::new(value()))),
143            );
144        }
145
146        /// Insert a boxed factory that creates a non-sync value for the shared server context
147        pub fn insert_boxed_factory(&self, value: Box<dyn Fn() -> Box<dyn Any> + Send + Sync>) {
148            self.shared_context
149                .write()
150                .insert((*value()).type_id(), ContextType::Factory(value));
151        }
152
153        /// Get the response parts from the server context
154        ///
155        #[doc = include_str!("../docs/request_origin.md")]
156        ///
157        /// # Example
158        ///
159        /// ```rust, no_run
160        /// # use dioxus::prelude::*;
161        /// #[server]
162        /// async fn set_headers() -> Result<(), ServerFnError> {
163        ///     let server_context = server_context();
164        ///     let response_parts = server_context.response_parts();
165        ///     let cookies = response_parts
166        ///         .headers
167        ///         .get("Cookie")
168        ///         .ok_or_else(|| ServerFnError::new("failed to find Cookie header in the response"))?;
169        ///     println!("{:?}", cookies);
170        ///     Ok(())
171        /// }
172        /// ```
173        pub fn response_parts(&self) -> RwLockReadGuard<'_, http::response::Parts> {
174            self.response_parts.read()
175        }
176
177        /// Get the response parts from the server context
178        ///
179        #[doc = include_str!("../docs/request_origin.md")]
180        ///
181        /// # Example
182        ///
183        /// ```rust, no_run
184        /// # use dioxus::prelude::*;
185        /// #[server]
186        /// async fn set_headers() -> Result<(), ServerFnError> {
187        ///     let server_context = server_context();
188        ///     server_context.response_parts_mut()
189        ///         .headers
190        ///         .insert("Cookie", http::HeaderValue::from_static("dioxus=fullstack"));
191        ///     Ok(())
192        /// }
193        /// ```
194        pub fn response_parts_mut(&self) -> RwLockWriteGuard<'_, http::response::Parts> {
195            self.response_parts.write()
196        }
197
198        /// Get the request parts
199        ///
200        #[doc = include_str!("../docs/request_origin.md")]
201        ///
202        /// # Example
203        ///
204        /// ```rust, no_run
205        /// # use dioxus::prelude::*;
206        /// #[server]
207        /// async fn read_headers() -> Result<(), ServerFnError> {
208        ///     let server_context = server_context();
209        ///     let request_parts = server_context.request_parts();
210        ///     let id: &i32 = request_parts
211        ///         .extensions
212        ///         .get()
213        ///         .ok_or_else(|| ServerFnError::new("failed to find i32 extension in the request"))?;
214        ///     println!("{:?}", id);
215        ///     Ok(())
216        /// }
217        /// ```
218        pub fn request_parts(&self) -> parking_lot::RwLockReadGuard<'_, http::request::Parts> {
219            self.parts.read()
220        }
221
222        /// Get the request parts mutably
223        ///
224        #[doc = include_str!("../docs/request_origin.md")]
225        ///
226        /// # Example
227        ///
228        /// ```rust, no_run
229        /// # use dioxus::prelude::*;
230        /// #[server]
231        /// async fn read_headers() -> Result<(), ServerFnError> {
232        ///     let server_context = server_context();
233        ///     let id: i32 = server_context.request_parts_mut()
234        ///         .extensions
235        ///         .remove()
236        ///         .ok_or_else(|| ServerFnError::new("failed to find i32 extension in the request"))?;
237        ///     println!("{:?}", id);
238        ///     Ok(())
239        /// }
240        /// ```
241        pub fn request_parts_mut(&self) -> parking_lot::RwLockWriteGuard<'_, http::request::Parts> {
242            self.parts.write()
243        }
244
245        /// Extract part of the request.
246        ///
247        #[doc = include_str!("../docs/request_origin.md")]
248        ///
249        /// # Example
250        ///
251        /// ```rust, no_run
252        /// # use dioxus::prelude::*;
253        /// #[server]
254        /// async fn read_headers() -> Result<(), ServerFnError> {
255        ///     let server_context = server_context();
256        ///     let headers: http::HeaderMap = server_context.extract().await?;
257        ///     println!("{:?}", headers);
258        ///     Ok(())
259        /// }
260        /// ```
261        pub async fn extract<M, T: FromServerContext<M>>(&self) -> Result<T, T::Rejection> {
262            T::from_request(self).await
263        }
264    }
265}
266
/// A boxed factory registered at runtime must be retrievable by the type it produces.
#[test]
fn server_context_as_any_map() {
    let (request_parts, _body) = http::Request::new(()).into_parts();
    let ctx = DioxusServerContext::new(request_parts);
    ctx.insert_boxed_factory(Box::new(|| Box::new(1234u32)));
    let fetched: u32 = ctx.get::<u32>().unwrap();
    assert_eq!(fetched, 1234u32);
}
274
275std::thread_local! {
276    pub(crate) static SERVER_CONTEXT: std::cell::RefCell<Box<DioxusServerContext>> = Default::default();
277}
278
279/// Get information about the current server request.
280///
281/// This function will only provide the current server context if it is called from a server function or on the server rendering a request.
282pub fn server_context() -> DioxusServerContext {
283    SERVER_CONTEXT.with(|ctx| *ctx.borrow().clone())
284}
285
286/// Extract some part from the current server request.
287///
288/// This function will only provide the current server context if it is called from a server function or on the server rendering a request.
289pub async fn extract<E: FromServerContext<I>, I>() -> Result<E, E::Rejection> {
290    E::from_request(&server_context()).await
291}
292
293/// Run a function inside of the server context.
294pub fn with_server_context<O>(context: DioxusServerContext, f: impl FnOnce() -> O) -> O {
295    // before polling the future, we need to set the context
296    let prev_context = SERVER_CONTEXT.with(|ctx| ctx.replace(Box::new(context)));
297    // poll the future, which may call server_context()
298    let result = f();
299    // after polling the future, we need to restore the context
300    SERVER_CONTEXT.with(|ctx| ctx.replace(prev_context));
301    result
302}
303
304/// A future that provides the server context to the inner future
305#[pin_project::pin_project]
306pub struct ProvideServerContext<F: std::future::Future> {
307    context: DioxusServerContext,
308    #[pin]
309    f: F,
310}
311
312impl<F: std::future::Future> ProvideServerContext<F> {
313    /// Create a new future that provides the server context to the inner future
314    pub fn new(f: F, context: DioxusServerContext) -> Self {
315        Self { f, context }
316    }
317}
318
319impl<F: std::future::Future> std::future::Future for ProvideServerContext<F> {
320    type Output = F::Output;
321
322    fn poll(
323        self: std::pin::Pin<&mut Self>,
324        cx: &mut std::task::Context<'_>,
325    ) -> std::task::Poll<Self::Output> {
326        let this = self.project();
327        let context = this.context.clone();
328        with_server_context(context, || this.f.poll(cx))
329    }
330}
331
332/// A trait for extracting types from the server context
333#[async_trait::async_trait]
334pub trait FromServerContext<I = ()>: Sized {
335    /// The error type returned when extraction fails. This type must implement `std::error::Error`.
336    type Rejection;
337
338    /// Extract this type from the server context.
339    async fn from_request(req: &DioxusServerContext) -> Result<Self, Self::Rejection>;
340}
341
/// Error returned when no value of type `T` has been provided to the server context.
///
/// Zero-sized: `T` is carried only in the type, so the error message can name it.
pub struct NotFoundInServerContext<T: 'static>(std::marker::PhantomData<T>);
344
345impl<T: 'static> std::fmt::Debug for NotFoundInServerContext<T> {
346    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
347        let type_name = std::any::type_name::<T>();
348        write!(f, "`{type_name}` not found in server context")
349    }
350}
351
352impl<T: 'static> std::fmt::Display for NotFoundInServerContext<T> {
353    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
354        let type_name = std::any::type_name::<T>();
355        write!(f, "`{type_name}` not found in server context")
356    }
357}
358
359impl<T: 'static> std::error::Error for NotFoundInServerContext<T> {}
360
/// Extractor for a value provided through the launch builder context or [`DioxusServerContext::insert`]
///
/// Extraction clones the stored value (hence the `Clone` bound) and fails with
/// [`NotFoundInServerContext`] when no value of type `T` was provided.
///
/// Example:
/// ```rust, no_run
/// use dioxus::prelude::*;
///
/// dioxus::LaunchBuilder::new()
///     // You can provide context to your whole app (including server functions) with the `with_context` method on the launch builder
///     .with_context(server_only! {
///         1234567890u32
///     })
///     .launch(app);
///
/// #[server]
/// async fn read_context() -> Result<u32, ServerFnError> {
///     // You can extract values from the server context with the `extract` function
///     let FromContext(value) = extract().await?;
///     Ok(value)
/// }
///
/// fn app() -> Element {
///     let future = use_resource(read_context);
///     rsx! {
///         h1 { "{future:?}" }
///     }
/// }
/// ```
pub struct FromContext<T: Send + Sync + Clone + 'static>(pub T);
389
390#[async_trait::async_trait]
391impl<T: Send + Sync + Clone + 'static> FromServerContext for FromContext<T> {
392    type Rejection = NotFoundInServerContext<T>;
393
394    async fn from_request(req: &DioxusServerContext) -> Result<Self, Self::Rejection> {
395        Ok(Self(req.get::<T>().ok_or({
396            NotFoundInServerContext::<T>(std::marker::PhantomData::<T>)
397        })?))
398    }
399}
400
/// An adapter for axum extractors for the server context
///
/// Marker type: every `axum::extract::FromRequestParts<()>` type implements
/// `FromServerContext<Axum>`, so it can be pulled out of the server context with `extract`.
#[cfg(feature = "axum")]
#[cfg_attr(docsrs, doc(cfg(feature = "axum")))]
pub struct Axum;
405
// Blanket adapter: any axum extractor that can run on request parts with `()` state can be
// extracted from the server context.
#[cfg(feature = "axum")]
#[async_trait::async_trait]
impl<I: axum::extract::FromRequestParts<()>> FromServerContext<Axum> for I {
    // The axum extractor's own rejection type is forwarded unchanged.
    type Rejection = I::Rejection;

    // NOTE(review): presumably this `allow` exists to silence clippy's `await_holding_lock`.
    // The parking_lot write guard on the request parts is held across the `.await` below.
    // While the extractor is pending, any other access to this context's request parts (e.g.
    // `request_parts()` from another task sharing the context) blocks its thread. Confirm that
    // the extractors used here never actually suspend, or move the parts out of the lock
    // before awaiting.
    #[allow(clippy::all)]
    async fn from_request(req: &DioxusServerContext) -> Result<Self, Self::Rejection> {
        // `from_request_parts` needs `&mut Parts`, hence the write guard.
        let mut lock = req.request_parts_mut();
        // Extractors are run with unit `()` state.
        I::from_request_parts(&mut lock, &()).await
    }
}