// roa_core/context.rs

1mod storage;
2
3use std::any::Any;
4use std::borrow::Cow;
5use std::net::SocketAddr;
6use std::ops::{Deref, DerefMut};
7use std::sync::Arc;
8
9use http::header::AsHeaderName;
10use http::{Method, StatusCode, Uri, Version};
11pub use storage::Variable;
12use storage::{Storage, Value};
13
14use crate::{status, Executor, Request, Response};
15
16/// A structure to share request, response and other data between middlewares.
17///
18/// ### Example
19///
20/// ```rust
21/// use roa_core::{App, Context, Next, Result};
22/// use tracing::info;
23/// use tokio::fs::File;
24///
25/// let app = App::new().gate(gate).end(end);
26/// async fn gate(ctx: &mut Context, next: Next<'_>) -> Result {
27///     info!("{} {}", ctx.method(), ctx.uri());
28///     next.await
29/// }
30///
31/// async fn end(ctx: &mut Context) -> Result {
32///     ctx.resp.write_reader(File::open("assets/welcome.html").await?);
33///     Ok(())
34/// }
35/// ```
36pub struct Context<S = ()> {
37    /// The request, to read http method, uri, version, headers and body.
38    pub req: Request,
39
40    /// The response, to set http status, version, headers and body.
41    pub resp: Response,
42
43    /// The executor, to spawn futures or blocking works.
44    pub exec: Executor,
45
46    /// Socket addr of last client or proxy.
47    pub remote_addr: SocketAddr,
48
49    storage: Storage,
50    state: S,
51}
52
53impl<S> Context<S> {
54    /// Construct a context from a request, an app and a addr_stream.
55    #[inline]
56    pub(crate) fn new(request: Request, state: S, exec: Executor, remote_addr: SocketAddr) -> Self {
57        Self {
58            req: request,
59            resp: Response::default(),
60            state,
61            exec,
62            storage: Storage::default(),
63            remote_addr,
64        }
65    }
66
67    /// Clone URI.
68    ///
69    /// ### Example
70    /// ```rust
71    /// use roa_core::{App, Context, Result};
72    ///
73    /// let app = App::new().end(get);
74    ///
75    /// async fn get(ctx: &mut Context) -> Result {
76    ///     assert_eq!("/", ctx.uri().to_string());
77    ///     Ok(())
78    /// }
79    /// ```
80    #[inline]
81    pub fn uri(&self) -> &Uri {
82        &self.req.uri
83    }
84
85    /// Clone request::method.
86    ///
87    /// ### Example
88    /// ```rust
89    /// use roa_core::{App, Context, Result};
90    /// use roa_core::http::Method;
91    ///
92    /// let app = App::new().end(get);
93    ///
94    /// async fn get(ctx: &mut Context) -> Result {
95    ///     assert_eq!(Method::GET, ctx.method());
96    ///     Ok(())
97    /// }
98    /// ```
99    #[inline]
100    pub fn method(&self) -> &Method {
101        &self.req.method
102    }
103
104    /// Search for a header value and try to get its string reference.
105    ///
106    /// ### Example
107    /// ```rust
108    /// use roa_core::{App, Context, Result};
109    /// use roa_core::http::header::CONTENT_TYPE;
110    ///
111    /// let app = App::new().end(get);
112    ///
113    /// async fn get(ctx: &mut Context) -> Result {
114    ///     assert_eq!(
115    ///         Some("text/plain"),
116    ///         ctx.get(CONTENT_TYPE),
117    ///     );
118    ///     Ok(())
119    /// }
120    /// ```
121    #[inline]
122    pub fn get(&self, name: impl AsHeaderName) -> Option<&str> {
123        self.req
124            .headers
125            .get(name)
126            .and_then(|value| value.to_str().ok())
127    }
128
129    /// Search for a header value and get its string reference.
130    ///
131    /// Otherwise return a 400 BAD REQUEST.
132    ///
133    /// ### Example
134    /// ```rust
135    /// use roa_core::{App, Context, Result};
136    /// use roa_core::http::header::CONTENT_TYPE;
137    ///
138    /// let app = App::new().end(get);
139    ///
140    /// async fn get(ctx: &mut Context) -> Result {
141    ///     assert_eq!(
142    ///         "text/plain",
143    ///         ctx.must_get(CONTENT_TYPE)?,
144    ///     );
145    ///     Ok(())
146    /// }
147    /// ```
148    #[inline]
149    pub fn must_get(&self, name: impl AsHeaderName) -> crate::Result<&str> {
150        let value = self
151            .req
152            .headers
153            .get(name)
154            .ok_or_else(|| status!(StatusCode::BAD_REQUEST))?;
155        value
156            .to_str()
157            .map_err(|err| status!(StatusCode::BAD_REQUEST, err))
158    }
159    /// Clone response::status.
160    ///
161    /// ### Example
162    /// ```rust
163    /// use roa_core::{App, Context, Result};
164    /// use roa_core::http::StatusCode;
165    ///
166    /// let app = App::new().end(get);
167    ///
168    /// async fn get(ctx: &mut Context) -> Result {
169    ///     assert_eq!(StatusCode::OK, ctx.status());
170    ///     Ok(())
171    /// }
172    /// ```
173    #[inline]
174    pub fn status(&self) -> StatusCode {
175        self.resp.status
176    }
177
178    /// Clone request::version.
179    ///
180    /// ### Example
181    /// ```rust
182    /// use roa_core::{App, Context, Result};
183    /// use roa_core::http::Version;
184    ///
185    /// let app = App::new().end(get);
186    ///
187    /// async fn get(ctx: &mut Context) -> Result {
188    ///     assert_eq!(Version::HTTP_11, ctx.version());
189    ///     Ok(())
190    /// }
191    /// ```
192    #[inline]
193    pub fn version(&self) -> Version {
194        self.req.version
195    }
196
197    /// Store key-value pair in specific scope.
198    ///
199    /// ### Example
200    /// ```rust
201    /// use roa_core::{App, Context, Result, Next};
202    ///
203    /// struct Scope;
204    /// struct AnotherScope;
205    ///
206    /// async fn gate(ctx: &mut Context, next: Next<'_>) -> Result {
207    ///     ctx.store_scoped(Scope, "id", "1".to_string());
208    ///     next.await
209    /// }
210    ///
211    /// async fn end(ctx: &mut Context) -> Result {
212    ///     assert_eq!(1, ctx.load_scoped::<Scope, String>("id").unwrap().parse::<i32>()?);
213    ///     assert!(ctx.load_scoped::<AnotherScope, String>("id").is_none());
214    ///     Ok(())
215    /// }
216    ///
217    /// let app = App::new().gate(gate).end(end);
218    /// ```
219    #[inline]
220    pub fn store_scoped<SC, K, V>(&mut self, scope: SC, key: K, value: V) -> Option<Arc<V>>
221    where
222        SC: Any,
223        K: Into<Cow<'static, str>>,
224        V: Value,
225    {
226        self.storage.insert(scope, key, value)
227    }
228
229    /// Store key-value pair in public scope.
230    ///
231    /// ### Example
232    /// ```rust
233    /// use roa_core::{App, Context, Result, Next};
234    ///
235    /// async fn gate(ctx: &mut Context, next: Next<'_>) -> Result {
236    ///     ctx.store("id", "1".to_string());
237    ///     next.await
238    /// }
239    ///
240    /// async fn end(ctx: &mut Context) -> Result {
241    ///     assert_eq!(1, ctx.load::<String>("id").unwrap().parse::<i32>()?);
242    ///     Ok(())
243    /// }
244    ///
245    /// let app = App::new().gate(gate).end(end);
246    /// ```
247    #[inline]
248    pub fn store<K, V>(&mut self, key: K, value: V) -> Option<Arc<V>>
249    where
250        K: Into<Cow<'static, str>>,
251        V: Value,
252    {
253        self.store_scoped(PublicScope, key, value)
254    }
255
256    /// Search for value by key in specific scope.
257    ///
258    /// ### Example
259    ///
260    /// ```rust
261    /// use roa_core::{App, Context, Result, Next};
262    ///
263    /// struct Scope;
264    ///
265    /// async fn gate(ctx: &mut Context, next: Next<'_>) -> Result {
266    ///     ctx.store_scoped(Scope, "id", "1".to_owned());
267    ///     next.await
268    /// }
269    ///
270    /// async fn end(ctx: &mut Context) -> Result {
271    ///     assert_eq!(1, ctx.load_scoped::<Scope, String>("id").unwrap().parse::<i32>()?);
272    ///     Ok(())
273    /// }
274    ///
275    /// let app = App::new().gate(gate).end(end);
276    /// ```
277    #[inline]
278    pub fn load_scoped<'a, SC, V>(&self, key: &'a str) -> Option<Variable<'a, V>>
279    where
280        SC: Any,
281        V: Value,
282    {
283        self.storage.get::<SC, V>(key)
284    }
285
286    /// Search for value by key in public scope.
287    ///
288    /// ### Example
289    /// ```rust
290    /// use roa_core::{App, Context, Result, Next};
291    ///
292    /// async fn gate(ctx: &mut Context, next: Next<'_>) -> Result {
293    ///     ctx.store("id", "1".to_string());
294    ///     next.await
295    /// }
296    ///
297    /// async fn end(ctx: &mut Context) -> Result {
298    ///     assert_eq!(1, ctx.load::<String>("id").unwrap().parse::<i32>()?);
299    ///     Ok(())
300    /// }
301    ///
302    /// let app = App::new().gate(gate).end(end);
303    /// ```
304    #[inline]
305    pub fn load<'a, V>(&self, key: &'a str) -> Option<Variable<'a, V>>
306    where
307        V: Value,
308    {
309        self.load_scoped::<PublicScope, V>(key)
310    }
311}
312
/// Zero-sized marker naming the default storage scope, the one that
/// `Context::store` and `Context::load` read and write.
struct PublicScope;
315
316impl<S> Deref for Context<S> {
317    type Target = S;
318    #[inline]
319    fn deref(&self) -> &Self::Target {
320        &self.state
321    }
322}
323
324impl<S> DerefMut for Context<S> {
325    #[inline]
326    fn deref_mut(&mut self) -> &mut Self::Target {
327        &mut self.state
328    }
329}
330
331impl<S: Clone> Clone for Context<S> {
332    #[inline]
333    fn clone(&self) -> Self {
334        Self {
335            req: Request::default(),
336            resp: Response::new(),
337            state: self.state.clone(),
338            exec: self.exec.clone(),
339            storage: self.storage.clone(),
340            remote_addr: self.remote_addr,
341        }
342    }
343}
344
#[cfg(all(test, feature = "runtime"))]
mod tests_with_runtime {
    use std::error::Error;

    use http::{HeaderValue, StatusCode, Version};

    use crate::{App, Context, Next, Request, Status};

    #[tokio::test]
    async fn status_and_version() -> Result<(), Box<dyn Error>> {
        async fn test(ctx: &mut Context) -> Result<(), Status> {
            assert_eq!(Version::HTTP_11, ctx.version());
            assert_eq!(StatusCode::OK, ctx.status());
            Ok(())
        }
        let service = App::new().end(test).http_service();
        let resp = service.serve(Request::default()).await;
        assert_eq!(StatusCode::OK, resp.status);
        Ok(())
    }

    #[derive(Clone)]
    struct State {
        data: usize,
    }

    #[tokio::test]
    async fn state() -> Result<(), Box<dyn Error>> {
        async fn gate(ctx: &mut Context<State>, next: Next<'_>) -> Result<(), Status> {
            ctx.data = 1;
            next.await
        }

        async fn test(ctx: &mut Context<State>) -> Result<(), Status> {
            assert_eq!(1, ctx.data);
            Ok(())
        }
        // Start from 0 so the assertion in `test` only passes if the write made by
        // `gate` through `DerefMut` is visible to the next middleware.
        let service = App::state(State { data: 0 })
            .gate(gate)
            .end(test)
            .http_service();
        let resp = service.serve(Request::default()).await;
        assert_eq!(StatusCode::OK, resp.status);
        Ok(())
    }

    #[tokio::test]
    async fn must_get() -> Result<(), Box<dyn Error>> {
        use http::header::{CONTENT_TYPE, HOST};
        async fn test(ctx: &mut Context) -> Result<(), Status> {
            assert_eq!(Ok("github.com"), ctx.must_get(HOST));
            ctx.must_get(CONTENT_TYPE)?;
            unreachable!()
        }
        let service = App::new().end(test).http_service();
        let mut req = Request::default();
        req.headers
            .insert(HOST, HeaderValue::from_static("github.com"));
        let resp = service.serve(req).await;
        assert_eq!(StatusCode::BAD_REQUEST, resp.status);
        Ok(())
    }

    #[tokio::test]
    async fn must_get_non_visible_ascii() -> Result<(), Box<dyn Error>> {
        use http::header::HOST;
        async fn test(ctx: &mut Context) -> Result<(), Status> {
            // The header is present but `to_str` rejects it, so this must fail.
            ctx.must_get(HOST)?;
            unreachable!()
        }
        let service = App::new().end(test).http_service();
        let mut req = Request::default();
        req.headers.insert(HOST, HeaderValue::from_bytes(b"\xff")?);
        let resp = service.serve(req).await;
        assert_eq!(StatusCode::BAD_REQUEST, resp.status);
        Ok(())
    }
}