dioxus_fullstack/server_context.rs
1use parking_lot::RwLock;
2use std::any::Any;
3use std::collections::HashMap;
4use std::sync::Arc;
5
6type SendSyncAnyMap = std::collections::HashMap<std::any::TypeId, ContextType>;
7
8/// A shared context for server functions that contains information about the request and middleware state.
9///
10/// You should not construct this directly inside components or server functions. Instead use [`server_context()`] to get the server context from the current request.
11///
12/// # Example
13///
14/// ```rust, no_run
15/// # use dioxus::prelude::*;
16/// #[server]
17/// async fn read_headers() -> Result<(), ServerFnError> {
18/// let server_context = server_context();
19/// let headers: http::HeaderMap = server_context.extract().await?;
20/// println!("{:?}", headers);
21/// Ok(())
22/// }
23/// ```
24#[derive(Clone)]
25pub struct DioxusServerContext {
26 shared_context: std::sync::Arc<RwLock<SendSyncAnyMap>>,
27 response_parts: std::sync::Arc<RwLock<http::response::Parts>>,
28 pub(crate) parts: Arc<RwLock<http::request::Parts>>,
29}
30
/// A single entry in the shared server context.
enum ContextType {
    /// A producer called on every lookup. Its output is a plain `Box<dyn Any>`, with no
    /// `Send`/`Sync` requirement on the produced value.
    Factory(Box<dyn Fn() -> Box<dyn Any> + Send + Sync>),
    /// A stored value that is cloned out on every lookup.
    Value(Box<dyn Any + Send + Sync>),
}

impl ContextType {
    /// Produce an owned `T` from this entry, or `None` if the entry holds a different type.
    ///
    /// A stored value is cloned. A factory is called and its fresh result is unboxed.
    fn downcast<T: Clone + 'static>(&self) -> Option<T> {
        match self {
            Self::Value(stored) => stored.downcast_ref::<T>().map(T::clone),
            Self::Factory(make) => {
                let produced: Box<dyn Any> = make();
                produced.downcast::<T>().map(|boxed| *boxed).ok()
            }
        }
    }
}
44
45#[allow(clippy::derivable_impls)]
46impl Default for DioxusServerContext {
47 fn default() -> Self {
48 Self {
49 shared_context: std::sync::Arc::new(RwLock::new(HashMap::new())),
50 response_parts: std::sync::Arc::new(RwLock::new(
51 http::response::Response::new(()).into_parts().0,
52 )),
53 parts: std::sync::Arc::new(RwLock::new(http::request::Request::new(()).into_parts().0)),
54 }
55 }
56}
57
58mod server_fn_impl {
59 use super::*;
60 use parking_lot::{RwLockReadGuard, RwLockWriteGuard};
61 use std::any::{Any, TypeId};
62
63 impl DioxusServerContext {
64 /// Create a new server context from a request
65 pub fn new(parts: http::request::Parts) -> Self {
66 Self {
67 parts: Arc::new(RwLock::new(parts)),
68 shared_context: Arc::new(RwLock::new(SendSyncAnyMap::new())),
69 response_parts: std::sync::Arc::new(RwLock::new(
70 http::response::Response::new(()).into_parts().0,
71 )),
72 }
73 }
74
75 /// Create a server context from a shared parts
76 #[allow(unused)]
77 pub(crate) fn from_shared_parts(parts: Arc<RwLock<http::request::Parts>>) -> Self {
78 Self {
79 parts,
80 shared_context: Arc::new(RwLock::new(SendSyncAnyMap::new())),
81 response_parts: std::sync::Arc::new(RwLock::new(
82 http::response::Response::new(()).into_parts().0,
83 )),
84 }
85 }
86
87 /// Clone a value from the shared server context. If you are using [`DioxusRouterExt`](crate::prelude::DioxusRouterExt), any values you insert into
88 /// the launch context will also be available in the server context.
89 ///
90 /// Example:
91 /// ```rust, no_run
92 /// use dioxus::prelude::*;
93 ///
94 /// LaunchBuilder::new()
95 /// // You can provide context to your whole app (including server functions) with the `with_context` method on the launch builder
96 /// .with_context(server_only! {
97 /// 1234567890u32
98 /// })
99 /// .launch(app);
100 ///
101 /// #[server]
102 /// async fn read_context() -> Result<u32, ServerFnError> {
103 /// // You can extract values from the server context with the `extract` function
104 /// let FromContext(value) = extract().await?;
105 /// Ok(value)
106 /// }
107 ///
108 /// fn app() -> Element {
109 /// let future = use_resource(read_context);
110 /// rsx! {
111 /// h1 { "{future:?}" }
112 /// }
113 /// }
114 /// ```
115 pub fn get<T: Any + Send + Sync + Clone + 'static>(&self) -> Option<T> {
116 self.shared_context
117 .read()
118 .get(&TypeId::of::<T>())
119 .map(|v| v.downcast::<T>().unwrap())
120 }
121
122 /// Insert a value into the shared server context
123 pub fn insert<T: Any + Send + Sync + 'static>(&self, value: T) {
124 self.insert_any(Box::new(value));
125 }
126
127 /// Insert a boxed `Any` value into the shared server context
128 pub fn insert_any(&self, value: Box<dyn Any + Send + Sync + 'static>) {
129 self.shared_context
130 .write()
131 .insert((*value).type_id(), ContextType::Value(value));
132 }
133
134 /// Insert a factory that creates a non-sync value for the shared server context
135 pub fn insert_factory<F, T>(&self, value: F)
136 where
137 F: Fn() -> T + Send + Sync + 'static,
138 T: 'static,
139 {
140 self.shared_context.write().insert(
141 TypeId::of::<T>(),
142 ContextType::Factory(Box::new(move || Box::new(value()))),
143 );
144 }
145
146 /// Insert a boxed factory that creates a non-sync value for the shared server context
147 pub fn insert_boxed_factory(&self, value: Box<dyn Fn() -> Box<dyn Any> + Send + Sync>) {
148 self.shared_context
149 .write()
150 .insert((*value()).type_id(), ContextType::Factory(value));
151 }
152
153 /// Get the response parts from the server context
154 ///
155 #[doc = include_str!("../docs/request_origin.md")]
156 ///
157 /// # Example
158 ///
159 /// ```rust, no_run
160 /// # use dioxus::prelude::*;
161 /// #[server]
162 /// async fn set_headers() -> Result<(), ServerFnError> {
163 /// let server_context = server_context();
164 /// let response_parts = server_context.response_parts();
165 /// let cookies = response_parts
166 /// .headers
167 /// .get("Cookie")
168 /// .ok_or_else(|| ServerFnError::new("failed to find Cookie header in the response"))?;
169 /// println!("{:?}", cookies);
170 /// Ok(())
171 /// }
172 /// ```
173 pub fn response_parts(&self) -> RwLockReadGuard<'_, http::response::Parts> {
174 self.response_parts.read()
175 }
176
177 /// Get the response parts from the server context
178 ///
179 #[doc = include_str!("../docs/request_origin.md")]
180 ///
181 /// # Example
182 ///
183 /// ```rust, no_run
184 /// # use dioxus::prelude::*;
185 /// #[server]
186 /// async fn set_headers() -> Result<(), ServerFnError> {
187 /// let server_context = server_context();
188 /// server_context.response_parts_mut()
189 /// .headers
190 /// .insert("Cookie", http::HeaderValue::from_static("dioxus=fullstack"));
191 /// Ok(())
192 /// }
193 /// ```
194 pub fn response_parts_mut(&self) -> RwLockWriteGuard<'_, http::response::Parts> {
195 self.response_parts.write()
196 }
197
198 /// Get the request parts
199 ///
200 #[doc = include_str!("../docs/request_origin.md")]
201 ///
202 /// # Example
203 ///
204 /// ```rust, no_run
205 /// # use dioxus::prelude::*;
206 /// #[server]
207 /// async fn read_headers() -> Result<(), ServerFnError> {
208 /// let server_context = server_context();
209 /// let request_parts = server_context.request_parts();
210 /// let id: &i32 = request_parts
211 /// .extensions
212 /// .get()
213 /// .ok_or_else(|| ServerFnError::new("failed to find i32 extension in the request"))?;
214 /// println!("{:?}", id);
215 /// Ok(())
216 /// }
217 /// ```
218 pub fn request_parts(&self) -> parking_lot::RwLockReadGuard<'_, http::request::Parts> {
219 self.parts.read()
220 }
221
222 /// Get the request parts mutably
223 ///
224 #[doc = include_str!("../docs/request_origin.md")]
225 ///
226 /// # Example
227 ///
228 /// ```rust, no_run
229 /// # use dioxus::prelude::*;
230 /// #[server]
231 /// async fn read_headers() -> Result<(), ServerFnError> {
232 /// let server_context = server_context();
233 /// let id: i32 = server_context.request_parts_mut()
234 /// .extensions
235 /// .remove()
236 /// .ok_or_else(|| ServerFnError::new("failed to find i32 extension in the request"))?;
237 /// println!("{:?}", id);
238 /// Ok(())
239 /// }
240 /// ```
241 pub fn request_parts_mut(&self) -> parking_lot::RwLockWriteGuard<'_, http::request::Parts> {
242 self.parts.write()
243 }
244
245 /// Extract part of the request.
246 ///
247 #[doc = include_str!("../docs/request_origin.md")]
248 ///
249 /// # Example
250 ///
251 /// ```rust, no_run
252 /// # use dioxus::prelude::*;
253 /// #[server]
254 /// async fn read_headers() -> Result<(), ServerFnError> {
255 /// let server_context = server_context();
256 /// let headers: http::HeaderMap = server_context.extract().await?;
257 /// println!("{:?}", headers);
258 /// Ok(())
259 /// }
260 /// ```
261 pub async fn extract<M, T: FromServerContext<M>>(&self) -> Result<T, T::Rejection> {
262 T::from_request(self).await
263 }
264 }
265}
266
// Exercises every way of registering a value in the shared context and reading it back.
#[test]
fn server_context_as_any_map() {
    let parts = http::Request::new(()).into_parts().0;
    let server_context = DioxusServerContext::new(parts);

    // Nothing has been registered yet.
    assert_eq!(server_context.get::<u32>(), None);

    // Boxed factory: keyed by the type the factory returns.
    server_context.insert_boxed_factory(Box::new(|| Box::new(1234u32)));
    assert_eq!(server_context.get::<u32>().unwrap(), 1234u32);

    // Plain value: cloned out on every lookup.
    server_context.insert(String::from("dioxus"));
    assert_eq!(server_context.get::<String>().as_deref(), Some("dioxus"));

    // Typed factory: called on every lookup.
    server_context.insert_factory(|| 42i64);
    assert_eq!(server_context.get::<i64>(), Some(42));

    // Re-inserting a type replaces the previous entry, even one that was a factory.
    server_context.insert(7u32);
    assert_eq!(server_context.get::<u32>(), Some(7));

    // Unrelated types are still absent.
    assert_eq!(server_context.get::<u8>(), None);
}
274
275std::thread_local! {
276 pub(crate) static SERVER_CONTEXT: std::cell::RefCell<Box<DioxusServerContext>> = Default::default();
277}
278
279/// Get information about the current server request.
280///
281/// This function will only provide the current server context if it is called from a server function or on the server rendering a request.
282pub fn server_context() -> DioxusServerContext {
283 SERVER_CONTEXT.with(|ctx| *ctx.borrow().clone())
284}
285
286/// Extract some part from the current server request.
287///
288/// This function will only provide the current server context if it is called from a server function or on the server rendering a request.
289pub async fn extract<E: FromServerContext<I>, I>() -> Result<E, E::Rejection> {
290 E::from_request(&server_context()).await
291}
292
293/// Run a function inside of the server context.
294pub fn with_server_context<O>(context: DioxusServerContext, f: impl FnOnce() -> O) -> O {
295 // before polling the future, we need to set the context
296 let prev_context = SERVER_CONTEXT.with(|ctx| ctx.replace(Box::new(context)));
297 // poll the future, which may call server_context()
298 let result = f();
299 // after polling the future, we need to restore the context
300 SERVER_CONTEXT.with(|ctx| ctx.replace(prev_context));
301 result
302}
303
304/// A future that provides the server context to the inner future
305#[pin_project::pin_project]
306pub struct ProvideServerContext<F: std::future::Future> {
307 context: DioxusServerContext,
308 #[pin]
309 f: F,
310}
311
312impl<F: std::future::Future> ProvideServerContext<F> {
313 /// Create a new future that provides the server context to the inner future
314 pub fn new(f: F, context: DioxusServerContext) -> Self {
315 Self { f, context }
316 }
317}
318
319impl<F: std::future::Future> std::future::Future for ProvideServerContext<F> {
320 type Output = F::Output;
321
322 fn poll(
323 self: std::pin::Pin<&mut Self>,
324 cx: &mut std::task::Context<'_>,
325 ) -> std::task::Poll<Self::Output> {
326 let this = self.project();
327 let context = this.context.clone();
328 with_server_context(context, || this.f.poll(cx))
329 }
330}
331
332/// A trait for extracting types from the server context
333#[async_trait::async_trait]
334pub trait FromServerContext<I = ()>: Sized {
335 /// The error type returned when extraction fails. This type must implement `std::error::Error`.
336 type Rejection;
337
338 /// Extract this type from the server context.
339 async fn from_request(req: &DioxusServerContext) -> Result<Self, Self::Rejection>;
340}
341
/// Error returned when no value of type `T` is present in the server context.
pub struct NotFoundInServerContext<T: 'static>(std::marker::PhantomData<T>);

impl<T: 'static> std::fmt::Display for NotFoundInServerContext<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "`{}` not found in server context", std::any::type_name::<T>())
    }
}

impl<T: 'static> std::fmt::Debug for NotFoundInServerContext<T> {
    // Debug output is intentionally the same text as Display.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        std::fmt::Display::fmt(self, f)
    }
}

impl<T: 'static> std::error::Error for NotFoundInServerContext<T> {}
360
/// Extractor that clones a `T` out of the shared server context. The value can come from the
/// launch builder context or from [`DioxusServerContext::insert`].
///
/// Extraction fails with [`NotFoundInServerContext`] when no `T` has been provided.
///
/// Example:
/// ```rust, no_run
/// use dioxus::prelude::*;
///
/// dioxus::LaunchBuilder::new()
///     // You can provide context to your whole app (including server functions) with the `with_context` method on the launch builder
///     .with_context(server_only! {
///         1234567890u32
///     })
///     .launch(app);
///
/// #[server]
/// async fn read_context() -> Result<u32, ServerFnError> {
///     // You can extract values from the server context with the `extract` function
///     let FromContext(value) = extract().await?;
///     Ok(value)
/// }
///
/// fn app() -> Element {
///     let future = use_resource(read_context);
///     rsx! {
///         h1 { "{future:?}" }
///     }
/// }
/// ```
pub struct FromContext<T: Send + Sync + Clone + 'static>(pub T);
389
390#[async_trait::async_trait]
391impl<T: Send + Sync + Clone + 'static> FromServerContext for FromContext<T> {
392 type Rejection = NotFoundInServerContext<T>;
393
394 async fn from_request(req: &DioxusServerContext) -> Result<Self, Self::Rejection> {
395 Ok(Self(req.get::<T>().ok_or({
396 NotFoundInServerContext::<T>(std::marker::PhantomData::<T>)
397 })?))
398 }
399}
400
/// Marker type selecting the axum adapter implementation of [`FromServerContext`]. With it,
/// any `axum::extract::FromRequestParts<()>` extractor can be pulled from the server context.
#[cfg(feature = "axum")]
#[cfg_attr(docsrs, doc(cfg(feature = "axum")))]
pub struct Axum;
405
/// Lets any axum `FromRequestParts<()>` extractor be used with `extract()` on the server
/// context, running it against this context's request parts.
#[cfg(feature = "axum")]
#[async_trait::async_trait]
impl<I: axum::extract::FromRequestParts<()>> FromServerContext<Axum> for I {
    /// Extraction fails with the axum extractor's own rejection type, forwarded unchanged.
    type Rejection = I::Rejection;

    /// Run the axum extractor `I` against the request parts, with `()` as the axum state.
    ///
    /// NOTE(review): the write guard from `request_parts_mut` is held across the `.await`
    /// below. Any other attempt to lock these request parts blocks the calling thread until
    /// this extractor finishes. If that attempt comes from a sibling future polled on the
    /// same thread, for example two `extract` calls joined in one task, it can deadlock.
    /// Holding the guard inside an `async_trait` future also presumably relies on
    /// parking_lot's `send_guard` feature being enabled — TODO confirm.
    // Presumably silences clippy's `await_holding_lock` lint for the guard held across
    // `.await` above — confirm before narrowing this allow.
    #[allow(clippy::all)]
    async fn from_request(req: &DioxusServerContext) -> Result<Self, Self::Rejection> {
        let mut lock = req.request_parts_mut();
        I::from_request_parts(&mut lock, &()).await
    }
}