rocket_community/
state.rs

1use std::any::type_name;
2use std::fmt;
3use std::ops::Deref;
4
5use ref_cast::RefCast;
6
7use crate::http::Status;
8use crate::outcome::Outcome;
9use crate::request::{self, FromRequest, Request};
10use crate::{Ignite, Phase, Rocket, Sentinel};
11
12/// Request guard to retrieve managed state.
13///
/// A reference `&State<T>` type is a request guard which retrieves the managed
/// state for some type `T`. A value for the given type must previously
16/// have been registered to be managed by Rocket via [`Rocket::manage()`]. The
17/// type being managed must be thread safe and sendable across thread
18/// boundaries as multiple handlers in multiple threads may be accessing the
19/// value at once. In other words, it must implement [`Send`] + [`Sync`] +
20/// `'static`.
21///
22/// # Example
23///
24/// Imagine you have some configuration struct of the type `MyConfig` that you'd
25/// like to initialize at start-up and later access it in several handlers. The
26/// following example does just this:
27///
28/// ```rust,no_run
29/// # #[macro_use] extern crate rocket_community as rocket;
30/// use rocket::State;
31///
32/// // In a real application, this would likely be more complex.
33/// struct MyConfig {
34///     user_val: String
35/// }
36///
37/// #[get("/")]
38/// fn index(state: &State<MyConfig>) -> String {
39///     format!("The config value is: {}", state.user_val)
40/// }
41///
42/// #[get("/raw")]
43/// fn raw_config_value(state: &State<MyConfig>) -> &str {
44///     &state.user_val
45/// }
46///
47/// #[launch]
48/// fn rocket() -> _ {
49///     rocket::build()
50///         .mount("/", routes![index, raw_config_value])
51///         .manage(MyConfig { user_val: "user input".to_string() })
52/// }
53/// ```
54///
55/// # Within Request Guards
56///
57/// Because `State` is itself a request guard, managed state can be retrieved
58/// from another request guard's implementation using either
59/// [`Request::guard()`] or [`Rocket::state()`]. In the following code example,
60/// the `Item` request guard retrieves `MyConfig` from managed state:
61///
62/// ```rust
63/// extern crate rocket_community as rocket;
64///
65/// use rocket::State;
66/// use rocket::request::{self, Request, FromRequest};
67/// use rocket::outcome::IntoOutcome;
68/// use rocket::http::Status;
69///
70/// # struct MyConfig { user_val: String };
71/// struct Item<'r>(&'r str);
72///
73/// #[rocket::async_trait]
74/// impl<'r> FromRequest<'r> for Item<'r> {
75///     type Error = ();
76///
77///     async fn from_request(request: &'r Request<'_>) -> request::Outcome<Self, ()> {
78///         // Using `State` as a request guard. Use `inner()` to get an `'r`.
79///         let outcome = request.guard::<&State<MyConfig>>().await
80///             .map(|my_config| Item(&my_config.user_val));
81///
82///         // Or alternatively, using `Rocket::state()`:
83///         let outcome = request.rocket().state::<MyConfig>()
84///             .map(|my_config| Item(&my_config.user_val))
85///             .or_forward(Status::InternalServerError);
86///
87///         outcome
88///     }
89/// }
90/// ```
91///
92/// # Testing with `State`
93///
94/// When unit testing your application, you may find it necessary to manually
95/// construct a type of `State` to pass to your functions. To do so, use the
96/// [`State::get()`] static method or the `From<&T>` implementation:
97///
98/// ```rust
99/// # #[macro_use] extern crate rocket_community as rocket;
100/// use rocket::State;
101///
102/// struct MyManagedState(usize);
103///
104/// #[get("/")]
105/// fn handler(state: &State<MyManagedState>) -> String {
106///     state.0.to_string()
107/// }
108///
109/// let mut rocket = rocket::build().manage(MyManagedState(127));
110/// let state = State::get(&rocket).expect("managed `MyManagedState`");
111/// assert_eq!(handler(state), "127");
112///
113/// let managed = MyManagedState(77);
114/// assert_eq!(handler(State::from(&managed)), "77");
115/// ```
// `repr(transparent)` gives `State<T>` the same layout as its single `T`
// field; the derived `RefCast` impl depends on this to view a `&T` as a
// `&State<T>` (see the `State::ref_cast` calls in this file).
#[repr(transparent)]
#[derive(RefCast, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct State<T: Send + Sync + 'static>(T);
119
120impl<T: Send + Sync + 'static> State<T> {
121    /// Returns the managed state value in `rocket` for the type `T` if it is
122    /// being managed by `rocket`. Otherwise, returns `None`.
123    ///
124    /// # Example
125    ///
126    /// ```rust
127    /// extern crate rocket_community as rocket;
128    ///
129    /// use rocket::State;
130    ///
131    /// #[derive(Debug, PartialEq)]
132    /// struct Managed(usize);
133    ///
134    /// #[derive(Debug, PartialEq)]
135    /// struct Unmanaged(usize);
136    ///
137    /// let rocket = rocket::build().manage(Managed(7));
138    ///
139    /// let state: Option<&State<Managed>> = State::get(&rocket);
140    /// assert_eq!(state.map(|s| s.inner()), Some(&Managed(7)));
141    ///
142    /// let state: Option<&State<Unmanaged>> = State::get(&rocket);
143    /// assert_eq!(state, None);
144    /// ```
145    #[inline(always)]
146    pub fn get<P: Phase>(rocket: &Rocket<P>) -> Option<&State<T>> {
147        rocket.state::<T>().map(State::ref_cast)
148    }
149
150    /// This exists because `State::from()` would otherwise be nothing. But we
151    /// want `State::from(&foo)` to give us `<&State>::from(&foo)`. Here it is.
152    #[doc(hidden)]
153    #[inline(always)]
154    pub fn from(value: &T) -> &State<T> {
155        State::ref_cast(value)
156    }
157
158    /// Borrow the inner value.
159    ///
160    /// Using this method is typically unnecessary as `State` implements
161    /// [`Deref`] with a [`Deref::Target`] of `T`. This means Rocket will
162    /// automatically coerce a `State<T>` to an `&T` as required. This method
163    /// should only be used when a longer lifetime is required.
164    ///
165    /// # Example
166    ///
167    /// ```rust
168    /// extern crate rocket_community as rocket;
169    ///
170    /// use rocket::State;
171    ///
172    /// #[derive(Clone)]
173    /// struct MyConfig {
174    ///     user_val: String
175    /// }
176    ///
177    /// fn handler1<'r>(config: &State<MyConfig>) -> String {
178    ///     let config = config.inner().clone();
179    ///     config.user_val
180    /// }
181    ///
182    /// // Use the `Deref` implementation which coerces implicitly
183    /// fn handler2(config: &State<MyConfig>) -> String {
184    ///     config.user_val.clone()
185    /// }
186    /// ```
187    #[inline(always)]
188    pub fn inner(&self) -> &T {
189        &self.0
190    }
191}
192
193impl<'r, T: Send + Sync + 'static> From<&'r T> for &'r State<T> {
194    #[inline(always)]
195    fn from(reference: &'r T) -> Self {
196        State::ref_cast(reference)
197    }
198}
199
200#[crate::async_trait]
201impl<'r, T: Send + Sync + 'static> FromRequest<'r> for &'r State<T> {
202    type Error = ();
203
204    #[inline(always)]
205    async fn from_request(req: &'r Request<'_>) -> request::Outcome<Self, ()> {
206        match State::get(req.rocket()) {
207            Some(state) => Outcome::Success(state),
208            None => {
209                error!(
210                    type_name = type_name::<T>(),
211                    "retrieving unmanaged state\n\
212                state must be managed via `rocket.manage()`"
213                );
214
215                Outcome::Error((Status::InternalServerError, ()))
216            }
217        }
218    }
219}
220
221impl<T: Send + Sync + 'static> Sentinel for &State<T> {
222    fn abort(rocket: &Rocket<Ignite>) -> bool {
223        if rocket.state::<T>().is_none() {
224            error!(
225                type_name = type_name::<T>(),
226                "unmanaged state detected\n\
227                ensure type is being managed via `rocket.manage()`"
228            );
229
230            return true;
231        }
232
233        false
234    }
235}
236
237impl<T: Send + Sync + fmt::Display + 'static> fmt::Display for State<T> {
238    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
239        self.0.fmt(f)
240    }
241}
242
243impl<T: Send + Sync + 'static> Deref for State<T> {
244    type Target = T;
245
246    #[inline(always)]
247    fn deref(&self) -> &T {
248        &self.0
249    }
250}