// roa_core/context.rs
1mod storage;
2
3use std::any::Any;
4use std::borrow::Cow;
5use std::net::SocketAddr;
6use std::ops::{Deref, DerefMut};
7use std::sync::Arc;
8
9use http::header::AsHeaderName;
10use http::{Method, StatusCode, Uri, Version};
11pub use storage::Variable;
12use storage::{Storage, Value};
13
14use crate::{status, Executor, Request, Response};
15
16/// A structure to share request, response and other data between middlewares.
17///
18/// ### Example
19///
20/// ```rust
21/// use roa_core::{App, Context, Next, Result};
22/// use tracing::info;
23/// use tokio::fs::File;
24///
25/// let app = App::new().gate(gate).end(end);
26/// async fn gate(ctx: &mut Context, next: Next<'_>) -> Result {
27/// info!("{} {}", ctx.method(), ctx.uri());
28/// next.await
29/// }
30///
31/// async fn end(ctx: &mut Context) -> Result {
32/// ctx.resp.write_reader(File::open("assets/welcome.html").await?);
33/// Ok(())
34/// }
35/// ```
36pub struct Context<S = ()> {
37 /// The request, to read http method, uri, version, headers and body.
38 pub req: Request,
39
40 /// The response, to set http status, version, headers and body.
41 pub resp: Response,
42
43 /// The executor, to spawn futures or blocking works.
44 pub exec: Executor,
45
46 /// Socket addr of last client or proxy.
47 pub remote_addr: SocketAddr,
48
49 storage: Storage,
50 state: S,
51}
52
53impl<S> Context<S> {
54 /// Construct a context from a request, an app and a addr_stream.
55 #[inline]
56 pub(crate) fn new(request: Request, state: S, exec: Executor, remote_addr: SocketAddr) -> Self {
57 Self {
58 req: request,
59 resp: Response::default(),
60 state,
61 exec,
62 storage: Storage::default(),
63 remote_addr,
64 }
65 }
66
67 /// Clone URI.
68 ///
69 /// ### Example
70 /// ```rust
71 /// use roa_core::{App, Context, Result};
72 ///
73 /// let app = App::new().end(get);
74 ///
75 /// async fn get(ctx: &mut Context) -> Result {
76 /// assert_eq!("/", ctx.uri().to_string());
77 /// Ok(())
78 /// }
79 /// ```
80 #[inline]
81 pub fn uri(&self) -> &Uri {
82 &self.req.uri
83 }
84
85 /// Clone request::method.
86 ///
87 /// ### Example
88 /// ```rust
89 /// use roa_core::{App, Context, Result};
90 /// use roa_core::http::Method;
91 ///
92 /// let app = App::new().end(get);
93 ///
94 /// async fn get(ctx: &mut Context) -> Result {
95 /// assert_eq!(Method::GET, ctx.method());
96 /// Ok(())
97 /// }
98 /// ```
99 #[inline]
100 pub fn method(&self) -> &Method {
101 &self.req.method
102 }
103
104 /// Search for a header value and try to get its string reference.
105 ///
106 /// ### Example
107 /// ```rust
108 /// use roa_core::{App, Context, Result};
109 /// use roa_core::http::header::CONTENT_TYPE;
110 ///
111 /// let app = App::new().end(get);
112 ///
113 /// async fn get(ctx: &mut Context) -> Result {
114 /// assert_eq!(
115 /// Some("text/plain"),
116 /// ctx.get(CONTENT_TYPE),
117 /// );
118 /// Ok(())
119 /// }
120 /// ```
121 #[inline]
122 pub fn get(&self, name: impl AsHeaderName) -> Option<&str> {
123 self.req
124 .headers
125 .get(name)
126 .and_then(|value| value.to_str().ok())
127 }
128
129 /// Search for a header value and get its string reference.
130 ///
131 /// Otherwise return a 400 BAD REQUEST.
132 ///
133 /// ### Example
134 /// ```rust
135 /// use roa_core::{App, Context, Result};
136 /// use roa_core::http::header::CONTENT_TYPE;
137 ///
138 /// let app = App::new().end(get);
139 ///
140 /// async fn get(ctx: &mut Context) -> Result {
141 /// assert_eq!(
142 /// "text/plain",
143 /// ctx.must_get(CONTENT_TYPE)?,
144 /// );
145 /// Ok(())
146 /// }
147 /// ```
148 #[inline]
149 pub fn must_get(&self, name: impl AsHeaderName) -> crate::Result<&str> {
150 let value = self
151 .req
152 .headers
153 .get(name)
154 .ok_or_else(|| status!(StatusCode::BAD_REQUEST))?;
155 value
156 .to_str()
157 .map_err(|err| status!(StatusCode::BAD_REQUEST, err))
158 }
159 /// Clone response::status.
160 ///
161 /// ### Example
162 /// ```rust
163 /// use roa_core::{App, Context, Result};
164 /// use roa_core::http::StatusCode;
165 ///
166 /// let app = App::new().end(get);
167 ///
168 /// async fn get(ctx: &mut Context) -> Result {
169 /// assert_eq!(StatusCode::OK, ctx.status());
170 /// Ok(())
171 /// }
172 /// ```
173 #[inline]
174 pub fn status(&self) -> StatusCode {
175 self.resp.status
176 }
177
178 /// Clone request::version.
179 ///
180 /// ### Example
181 /// ```rust
182 /// use roa_core::{App, Context, Result};
183 /// use roa_core::http::Version;
184 ///
185 /// let app = App::new().end(get);
186 ///
187 /// async fn get(ctx: &mut Context) -> Result {
188 /// assert_eq!(Version::HTTP_11, ctx.version());
189 /// Ok(())
190 /// }
191 /// ```
192 #[inline]
193 pub fn version(&self) -> Version {
194 self.req.version
195 }
196
197 /// Store key-value pair in specific scope.
198 ///
199 /// ### Example
200 /// ```rust
201 /// use roa_core::{App, Context, Result, Next};
202 ///
203 /// struct Scope;
204 /// struct AnotherScope;
205 ///
206 /// async fn gate(ctx: &mut Context, next: Next<'_>) -> Result {
207 /// ctx.store_scoped(Scope, "id", "1".to_string());
208 /// next.await
209 /// }
210 ///
211 /// async fn end(ctx: &mut Context) -> Result {
212 /// assert_eq!(1, ctx.load_scoped::<Scope, String>("id").unwrap().parse::<i32>()?);
213 /// assert!(ctx.load_scoped::<AnotherScope, String>("id").is_none());
214 /// Ok(())
215 /// }
216 ///
217 /// let app = App::new().gate(gate).end(end);
218 /// ```
219 #[inline]
220 pub fn store_scoped<SC, K, V>(&mut self, scope: SC, key: K, value: V) -> Option<Arc<V>>
221 where
222 SC: Any,
223 K: Into<Cow<'static, str>>,
224 V: Value,
225 {
226 self.storage.insert(scope, key, value)
227 }
228
229 /// Store key-value pair in public scope.
230 ///
231 /// ### Example
232 /// ```rust
233 /// use roa_core::{App, Context, Result, Next};
234 ///
235 /// async fn gate(ctx: &mut Context, next: Next<'_>) -> Result {
236 /// ctx.store("id", "1".to_string());
237 /// next.await
238 /// }
239 ///
240 /// async fn end(ctx: &mut Context) -> Result {
241 /// assert_eq!(1, ctx.load::<String>("id").unwrap().parse::<i32>()?);
242 /// Ok(())
243 /// }
244 ///
245 /// let app = App::new().gate(gate).end(end);
246 /// ```
247 #[inline]
248 pub fn store<K, V>(&mut self, key: K, value: V) -> Option<Arc<V>>
249 where
250 K: Into<Cow<'static, str>>,
251 V: Value,
252 {
253 self.store_scoped(PublicScope, key, value)
254 }
255
256 /// Search for value by key in specific scope.
257 ///
258 /// ### Example
259 ///
260 /// ```rust
261 /// use roa_core::{App, Context, Result, Next};
262 ///
263 /// struct Scope;
264 ///
265 /// async fn gate(ctx: &mut Context, next: Next<'_>) -> Result {
266 /// ctx.store_scoped(Scope, "id", "1".to_owned());
267 /// next.await
268 /// }
269 ///
270 /// async fn end(ctx: &mut Context) -> Result {
271 /// assert_eq!(1, ctx.load_scoped::<Scope, String>("id").unwrap().parse::<i32>()?);
272 /// Ok(())
273 /// }
274 ///
275 /// let app = App::new().gate(gate).end(end);
276 /// ```
277 #[inline]
278 pub fn load_scoped<'a, SC, V>(&self, key: &'a str) -> Option<Variable<'a, V>>
279 where
280 SC: Any,
281 V: Value,
282 {
283 self.storage.get::<SC, V>(key)
284 }
285
286 /// Search for value by key in public scope.
287 ///
288 /// ### Example
289 /// ```rust
290 /// use roa_core::{App, Context, Result, Next};
291 ///
292 /// async fn gate(ctx: &mut Context, next: Next<'_>) -> Result {
293 /// ctx.store("id", "1".to_string());
294 /// next.await
295 /// }
296 ///
297 /// async fn end(ctx: &mut Context) -> Result {
298 /// assert_eq!(1, ctx.load::<String>("id").unwrap().parse::<i32>()?);
299 /// Ok(())
300 /// }
301 ///
302 /// let app = App::new().gate(gate).end(end);
303 /// ```
304 #[inline]
305 pub fn load<'a, V>(&self, key: &'a str) -> Option<Variable<'a, V>>
306 where
307 V: Value,
308 {
309 self.load_scoped::<PublicScope, V>(key)
310 }
311}
312
/// Zero-sized marker type naming the scope used by `Context::store` and `Context::load`.
///
/// It is not exported from this module, so code outside it cannot name this scope
/// with `store_scoped`/`load_scoped`.
struct PublicScope;
315
316impl<S> Deref for Context<S> {
317 type Target = S;
318 #[inline]
319 fn deref(&self) -> &Self::Target {
320 &self.state
321 }
322}
323
324impl<S> DerefMut for Context<S> {
325 #[inline]
326 fn deref_mut(&mut self) -> &mut Self::Target {
327 &mut self.state
328 }
329}
330
331impl<S: Clone> Clone for Context<S> {
332 #[inline]
333 fn clone(&self) -> Self {
334 Self {
335 req: Request::default(),
336 resp: Response::new(),
337 state: self.state.clone(),
338 exec: self.exec.clone(),
339 storage: self.storage.clone(),
340 remote_addr: self.remote_addr,
341 }
342 }
343}
344
#[cfg(all(test, feature = "runtime"))]
mod tests_with_runtime {
    use std::error::Error;

    use http::{HeaderValue, StatusCode, Version};

    use crate::{App, Context, Next, Request, Status};

    /// A default request is HTTP/1.1 and the response status starts as 200 OK.
    #[tokio::test]
    async fn status_and_version() -> Result<(), Box<dyn Error>> {
        async fn test(ctx: &mut Context) -> Result<(), Status> {
            assert_eq!(Version::HTTP_11, ctx.version());
            assert_eq!(StatusCode::OK, ctx.status());
            Ok(())
        }
        let service = App::new().end(test).http_service();
        let resp = service.serve(Request::default()).await;
        // The endpoint returned `Ok`, so the served status must still be 200.
        assert_eq!(StatusCode::OK, resp.status);
        Ok(())
    }

    #[derive(Clone)]
    struct State {
        data: usize,
    }

    /// A state mutation made by a gate (through `DerefMut`) is visible to the endpoint.
    #[tokio::test]
    async fn state() -> Result<(), Box<dyn Error>> {
        async fn gate(ctx: &mut Context<State>, next: Next<'_>) -> Result<(), Status> {
            ctx.data = 1;
            next.await
        }

        async fn test(ctx: &mut Context<State>) -> Result<(), Status> {
            assert_eq!(1, ctx.data);
            Ok(())
        }
        // Start from a value different from what the gate writes. Otherwise the
        // assertion in `test` would pass even if the gate's write were lost.
        let service = App::state(State { data: 0 })
            .gate(gate)
            .end(test)
            .http_service();
        let resp = service.serve(Request::default()).await;
        assert_eq!(StatusCode::OK, resp.status);
        Ok(())
    }

    /// `must_get` returns present headers and turns a missing header into 400.
    #[tokio::test]
    async fn must_get() -> Result<(), Box<dyn Error>> {
        use http::header::{CONTENT_TYPE, HOST};
        async fn test(ctx: &mut Context) -> Result<(), Status> {
            assert_eq!(Ok("github.com"), ctx.must_get(HOST));
            ctx.must_get(CONTENT_TYPE)?;
            unreachable!()
        }
        let service = App::new().end(test).http_service();
        let mut req = Request::default();
        req.headers
            .insert(HOST, HeaderValue::from_static("github.com"));
        let resp = service.serve(req).await;
        assert_eq!(StatusCode::BAD_REQUEST, resp.status);
        Ok(())
    }
}