1use crate::dep::DepResolver;
5use crate::error::{Error, Result};
6use crate::response::Json;
7use bytes::Bytes;
8use serde::de::DeserializeOwned;
9use std::future::Future;
10
/// Type-erased streaming request body: yields `Bytes` data and reports read failures as a
/// boxed `Send + Sync` error, which `map_stream_error` turns into a client-facing `Error`.
pub(crate) type StreamLane =
    http_body_util::combinators::UnsyncBoxBody<Bytes, Box<dyn std::error::Error + Send + Sync>>;
16
/// Where a request body currently lives.
pub(crate) enum BodyLane {
    /// The whole body, in memory. `RequestCtx::drain_body` also moves a drained stream here.
    Buffered(Bytes),
    /// A body still to be read incrementally. The slot becomes `None` once
    /// `RequestCtx::drain_body` has taken the stream; a successful drain then switches the
    /// lane to `Buffered`.
    Stream(Option<StreamLane>),
}
25
/// Socket address of the connected peer.
///
/// `RequestCtx::peer_addr` looks this value up in the request extensions.
#[derive(Debug, Clone, Copy)]
pub struct ClientAddr(pub std::net::SocketAddr);
31
/// Per-request state handed to every [`FromRequest`] extractor.
pub struct RequestCtx {
    /// Method, URI, headers, extensions, etc. The body is held separately in `body`.
    pub(crate) parts: http::request::Parts,
    /// The request body, either buffered or still streaming; read it through `drain_body`.
    pub(crate) body: BodyLane,
    /// Captured path parameters as `(name, raw value)` pairs, in capture order
    /// (root to leaf, per the tuple `Path` tests).
    pub(crate) params: Vec<(String, String)>,
    /// The dependency resolver for this request.
    pub(crate) deps: DepResolver,
    /// When set, request-bound extractors refuse with `Error::task_context()`.
    pub(crate) is_task: bool,
}
44
45impl RequestCtx {
46 pub(crate) fn new(parts: http::request::Parts, body: Bytes, deps: DepResolver) -> Self {
50 Self::with_lane(parts, BodyLane::Buffered(body), deps)
51 }
52
53 pub(crate) fn with_lane(
56 parts: http::request::Parts,
57 body: BodyLane,
58 deps: DepResolver,
59 ) -> Self {
60 Self {
61 parts,
62 body,
63 params: Vec::new(),
64 deps,
65 is_task: false,
66 }
67 }
68
69 pub(crate) async fn drain_body(&mut self) -> Result<Bytes> {
73 match &mut self.body {
74 BodyLane::Buffered(bytes) => Ok(bytes.clone()),
75 BodyLane::Stream(slot) => {
76 let stream = slot
80 .take()
81 .ok_or_else(|| Error::internal("request body was already consumed"))?;
82 use http_body_util::BodyExt;
83 let collected = stream.collect().await.map_err(map_stream_error)?;
84 let bytes = collected.to_bytes();
85 self.body = BodyLane::Buffered(bytes.clone());
86 Ok(bytes)
87 }
88 }
89 }
90
91 pub fn method(&self) -> &http::Method {
92 &self.parts.method
93 }
94 pub fn uri(&self) -> &http::Uri {
95 &self.parts.uri
96 }
97 pub fn headers(&self) -> &http::HeaderMap {
98 &self.parts.headers
99 }
100
101 pub fn peer_addr(&self) -> Option<std::net::SocketAddr> {
106 self.parts.extensions.get::<ClientAddr>().map(|c| c.0)
107 }
108}
109
110pub(crate) fn map_stream_error(e: Box<dyn std::error::Error + Send + Sync>) -> Error {
114 if e.downcast_ref::<http_body_util::LengthLimitError>()
115 .is_some()
116 {
117 return Error::payload_too_large();
118 }
119 if e.downcast_ref::<crate::serve::RecvTimeout>().is_some() {
120 return Error::new(
121 http::StatusCode::REQUEST_TIMEOUT,
122 "JC0408",
123 "timed out reading the request body",
124 );
125 }
126 Error::bad_request("request body failed mid-read")
127}
128
/// Builds a handler argument from the request context.
///
/// Extractors receive `&mut RequestCtx` so that body extractors can drain (and cache) the
/// body. The returned future must be `Send`.
pub trait FromRequest: Sized + Send {
    /// Extracts `Self` from `ctx`, or returns the [`Error`] describing why extraction failed.
    fn from_request(ctx: &mut RequestCtx) -> impl Future<Output = Result<Self>> + Send;
}
134
/// Extractor for captured path parameters.
///
/// A single `Path<T>` binds the last (leaf-most) capture. `Path<(A, B)>` and
/// `Path<(A, B, C)>` bind the first two or three captures in capture order.
pub struct Path<T>(pub T);
141
/// Holds the supertrait that seals [`PathParam`].
///
/// It is `pub` but `#[doc(hidden)]` because the exported `path_param!` macro must name
/// `sealed::Sealed` from downstream crates.
#[doc(hidden)]
pub mod sealed {
    /// Marker supertrait; only types registered via the path-param macros implement it.
    pub trait Sealed {}
}
149
/// A type that can be parsed from the raw text of one captured path parameter.
///
/// The trait is sealed through `sealed::Sealed`. Built-in types are registered with
/// `impl_path_param!`, and user types with the exported `path_param!` macro.
pub trait PathParam: sealed::Sealed + Sized + Send {
    /// Parses `raw`, the captured text of the parameter named `name`.
    ///
    /// # Errors
    ///
    /// Fails when `raw` is not a valid `Self`. The macro-generated impls report this with
    /// `Error::bad_request`, naming the parameter.
    fn parse_param(name: &str, raw: &str) -> Result<Self>;
}
157
// Implements `sealed::Sealed` + `PathParam` for the listed built-in types.
//
// This forwards to the exported `path_param!`, so built-in and user-registered parameter
// types share one implementation (parse via `FromStr`, `Error::bad_request` on failure) and
// cannot drift apart.
macro_rules! impl_path_param {
    ($($t:ty),* $(,)?) => {
        $crate::path_param!($($t),*);
    };
}
170
/// Registers types as path parameters so they can be extracted with `Path`.
///
/// Each listed type must implement `FromStr` with an error type that implements `Display`.
/// A failed parse becomes `Error::bad_request`, with a message that names the parameter and
/// includes the parse error.
///
/// ```ignore
/// struct LeadId(i64);
/// impl std::str::FromStr for LeadId { /* ... */ }
/// path_param!(LeadId);
/// ```
#[macro_export]
macro_rules! path_param {
    ($($t:ty),* $(,)?) => {$(
        impl $crate::extract::sealed::Sealed for $t {}
        impl $crate::extract::PathParam for $t {
            fn parse_param(name: &str, raw: &str) -> $crate::Result<Self> {
                // Fully-qualified paths keep the expansion correct even if the call site
                // shadows `format!` or `FromStr`.
                <$t as ::core::str::FromStr>::from_str(raw).map_err(|e| {
                    $crate::Error::bad_request(::std::format!(
                        "invalid path parameter `{name}`: {e}"
                    ))
                })
            }
        }
    )*};
}
198impl_path_param!(
199 i8, i16, i32, i64, i128, isize, u8, u16, u32, u64, u128, usize, f32, f64, bool, char, String,
200);
201
202impl<T: PathParam> FromRequest for Path<T> {
203 async fn from_request(ctx: &mut RequestCtx) -> Result<Self> {
204 if ctx.is_task {
205 return Err(Error::task_context());
206 }
207 let (name, raw) = ctx
211 .params
212 .last()
213 .ok_or_else(|| Error::internal("route has no path parameters"))?;
214 T::parse_param(name, raw).map(Path)
215 }
216}
217
218impl<A: PathParam, B: PathParam> FromRequest for Path<(A, B)> {
219 async fn from_request(ctx: &mut RequestCtx) -> Result<Self> {
220 if ctx.is_task {
221 return Err(Error::task_context());
222 }
223 let [a, b] = take_params::<2>(ctx)?;
224 Ok(Path((
225 A::parse_param(&a.0, &a.1)?,
226 B::parse_param(&b.0, &b.1)?,
227 )))
228 }
229}
230
231impl<A: PathParam, B: PathParam, C: PathParam> FromRequest for Path<(A, B, C)> {
232 async fn from_request(ctx: &mut RequestCtx) -> Result<Self> {
233 if ctx.is_task {
234 return Err(Error::task_context());
235 }
236 let [a, b, c] = take_params::<3>(ctx)?;
237 Ok(Path((
238 A::parse_param(&a.0, &a.1)?,
239 B::parse_param(&b.0, &b.1)?,
240 C::parse_param(&c.0, &c.1)?,
241 )))
242 }
243}
244
245fn take_params<const N: usize>(ctx: &RequestCtx) -> Result<[(String, String); N]> {
248 if ctx.params.len() < N {
249 return Err(Error::internal(format!(
250 "route captures {} path parameter(s) but the handler expects {N}",
251 ctx.params.len()
252 )));
253 }
254 Ok(std::array::from_fn(|i| ctx.params[i].clone()))
255}
256
/// Extractor that deserializes the URI query string into `T`.
pub struct Query<T>(pub T);
259
260impl<T: DeserializeOwned + Send> FromRequest for Query<T> {
261 async fn from_request(ctx: &mut RequestCtx) -> Result<Self> {
262 if ctx.is_task {
263 return Err(Error::task_context());
264 }
265 let q = ctx.parts.uri.query().unwrap_or("");
266 serde_urlencoded::from_str::<T>(q)
267 .map(Query)
268 .map_err(|e| Error::bad_request(format!("invalid query string: {e}")))
269 }
270}
271
272impl<T: DeserializeOwned + Send> FromRequest for Json<T> {
273 async fn from_request(ctx: &mut RequestCtx) -> Result<Self> {
274 if ctx.is_task {
275 return Err(Error::task_context());
276 }
277 let body = ctx.drain_body().await?;
278 serde_json::from_slice::<T>(&body)
279 .map(Json)
280 .map_err(|e| Error::unprocessable(format!("invalid JSON body: {e}")))
281 }
282}
283
/// Extractor holding a copy of the request's header map.
pub struct Headers(pub(crate) http::HeaderMap);
286
287impl Headers {
288 pub fn get(&self, name: &str) -> Option<&str> {
290 self.0.get(name).and_then(|v| v.to_str().ok())
291 }
292}
293
294impl FromRequest for Headers {
295 async fn from_request(ctx: &mut RequestCtx) -> Result<Self> {
296 if ctx.is_task {
297 return Err(Error::task_context());
298 }
299 Ok(Headers(ctx.headers().clone()))
300 }
301}
302
/// Extractor yielding the complete, unparsed request body bytes.
pub struct RawBody(pub Bytes);
308
309impl FromRequest for RawBody {
310 async fn from_request(ctx: &mut RequestCtx) -> Result<Self> {
311 if ctx.is_task {
312 return Err(Error::task_context());
313 }
314 Ok(RawBody(ctx.drain_body().await?))
315 }
316}
317
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dep::DepEnv;
    use std::sync::Arc;

    // Builds a GET context for `uri` whose body is `body`, already buffered in memory.
    fn ctx(uri: &str, body: &str) -> RequestCtx {
        let req = http::Request::builder()
            .method(http::Method::GET)
            .uri(uri)
            .body(())
            .unwrap();
        let (parts, ()) = req.into_parts();
        RequestCtx::new(
            parts,
            Bytes::from(body.to_string()),
            DepResolver::new(Arc::new(DepEnv::default()), Default::default()),
        )
    }

    #[tokio::test]
    async fn peer_addr_is_none_without_a_socket_and_readable_when_set() {
        let mut c = ctx("/x", "");
        assert!(c.peer_addr().is_none());
        let addr: std::net::SocketAddr = "203.0.113.7:5000".parse().unwrap();
        c.parts.extensions.insert(crate::extract::ClientAddr(addr));
        assert_eq!(c.peer_addr(), Some(addr));
    }

    #[tokio::test]
    async fn path_extracts_typed_param() {
        let mut c = ctx("/todos/42", "");
        c.params.push(("id".into(), "42".into()));
        let Path(id): Path<i64> = Path::<i64>::from_request(&mut c).await.unwrap();
        assert_eq!(id, 42);
    }

    #[tokio::test]
    async fn path_with_wrong_type_is_400() {
        let mut c = ctx("/todos/abc", "");
        c.params.push(("id".into(), "abc".into()));
        let err = Path::<i64>::from_request(&mut c).await.err().unwrap();
        assert_eq!(err.code(), "JC0400");
    }

    #[tokio::test]
    async fn path_missing_param_is_500() {
        // No captures are pushed, so the extractor has nothing to bind.
        let mut c = ctx("/todos", "");
        let err = Path::<i64>::from_request(&mut c).await.err().unwrap();
        assert_eq!(err.code(), "JC0500");
    }

    #[tokio::test]
    async fn query_deserializes_struct() {
        #[derive(serde::Deserialize)]
        struct Page {
            limit: u32,
            offset: u32,
        }
        let mut c = ctx("/todos?limit=10&offset=20", "");
        let Query(p): Query<Page> = Query::from_request(&mut c).await.unwrap();
        assert_eq!((p.limit, p.offset), (10, 20));
    }

    #[tokio::test]
    async fn single_path_param_binds_the_leaf_segment() {
        use crate::prelude::*;
        async fn show(Path(id): Path<i64>) -> Result<Json<i64>> {
            Ok(Json(id))
        }
        let t = App::new()
            .mount(
                "/ws/{ws}",
                Module::new("leads").route("/leads/{id}", get(show)),
            )
            .into_test();
        assert_eq!(
            t.get("/ws/7/leads/42").await.json::<i64>(),
            42,
            "leaf param, not mount param"
        );
    }

    #[tokio::test]
    async fn tuples_still_read_root_to_leaf() {
        use crate::prelude::*;
        async fn pair(Path((ws, id)): Path<(i64, i64)>) -> Result<Json<(i64, i64)>> {
            Ok(Json((ws, id)))
        }
        let t = App::new()
            .mount(
                "/ws/{ws}",
                Module::new("leads").route("/leads/{id}", get(pair)),
            )
            .into_test();
        assert_eq!(t.get("/ws/7/leads/42").await.json::<(i64, i64)>(), (7, 42));
    }

    #[tokio::test]
    async fn path_param_macro_admits_custom_newtypes() {
        use crate::prelude::*;
        #[derive(Debug)]
        struct LeadId(i64);
        impl std::str::FromStr for LeadId {
            type Err = std::num::ParseIntError;
            fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
                Ok(LeadId(s.parse()?))
            }
        }
        crate::path_param!(LeadId);
        async fn show(Path(id): Path<LeadId>) -> Result<Json<i64>> {
            Ok(Json(id.0))
        }
        let t = App::new().route("/leads/{id}", get(show)).into_test();
        assert_eq!(t.get("/leads/42").await.json::<i64>(), 42);
    }

    #[tokio::test]
    async fn raw_body_yields_exact_bytes_and_coexists_with_headers() {
        use crate::prelude::*;
        async fn verify(headers: Headers, body: RawBody) -> Result<Json<(usize, bool)>> {
            let signed = headers.get("x-signature").is_some();
            Ok(Json((body.0.len(), signed)))
        }
        let t = App::new().route("/hook", post(verify)).into_test();
        let res = t
            .post_bytes_with("/hook", b"{\"raw\": 1}", &[("x-signature", "abc")])
            .await;
        assert_eq!(res.status().as_u16(), 200);
        assert_eq!(res.json::<(usize, bool)>(), (10, true));
    }

    #[tokio::test]
    async fn raw_body_drains_a_stream_route_transparently() {
        use crate::prelude::*;
        async fn len(body: RawBody) -> Result<Json<usize>> {
            Ok(Json(body.0.len()))
        }
        let t = App::new().route("/up", post(len).stream_body()).into_test();
        let payload = vec![b'x'; 100];
        let res = t.post_bytes("/up", &payload).await;
        assert_eq!(res.json::<usize>(), 100);
    }

    #[tokio::test]
    async fn json_body_deserializes_and_bad_json_is_422() {
        #[derive(serde::Deserialize)]
        struct NewTodo {
            title: String,
        }
        let mut c = ctx("/todos", r#"{"title":"x"}"#);
        let Json(t): Json<NewTodo> = Json::from_request(&mut c).await.unwrap();
        assert_eq!(t.title, "x");

        let mut bad = ctx("/todos", r#"{"title":"#);
        let err = Json::<NewTodo>::from_request(&mut bad).await.err().unwrap();
        assert_eq!(err.code(), "JC0422");
    }

    // Builds a context whose body is a streamed lane over `body`. With `Some(limit)` the
    // stream is wrapped in `http_body_util::Limited`, so reading past `limit` bytes fails.
    fn stream_ctx(body: &[u8], limit: Option<usize>) -> RequestCtx {
        use http_body_util::BodyExt;
        use http_body_util::combinators::UnsyncBoxBody;
        let req = http::Request::builder().uri("/up").body(()).unwrap();
        let (parts, ()) = req.into_parts();
        let bytes = Bytes::copy_from_slice(body);
        let lane: StreamLane = match limit {
            Some(limit) => {
                let limited = http_body_util::Limited::new(
                    http_body_util::Full::<Bytes>::new(bytes).map_err(
                        |never| -> Box<dyn std::error::Error + Send + Sync> { match never {} },
                    ),
                    limit,
                );
                UnsyncBoxBody::new(limited.map_err(Into::into))
            }
            None => {
                let full = http_body_util::Full::<Bytes>::new(bytes);
                UnsyncBoxBody::new(full.map_err(
                    |never| -> Box<dyn std::error::Error + Send + Sync> { match never {} },
                ))
            }
        };
        RequestCtx::with_lane(
            parts,
            BodyLane::Stream(Some(lane)),
            DepResolver::new(Arc::new(DepEnv::default()), Default::default()),
        )
    }

    #[tokio::test]
    async fn stream_routes_deliver_the_body_and_enforce_the_limit() {
        use crate::prelude::*;
        async fn echo(Json(v): Json<serde_json::Value>) -> Result<Json<serde_json::Value>> {
            Ok(Json(v))
        }
        let t = App::new()
            .route("/up", post(echo).stream_body().body_limit(64))
            .into_test();
        // Small body: fits under the 64-byte limit.
        let res = t.post_json("/up", &serde_json::json!({"k": "v"})).await;
        assert_eq!(res.status().as_u16(), 200);
        // Oversized body: must be rejected with 413.
        let big = serde_json::json!({"k": "x".repeat(200)});
        let res = t.post_json("/up", &big).await;
        assert_eq!(res.status().as_u16(), 413, "body: {}", res.text());
    }

    #[tokio::test]
    async fn drain_body_twice_caches_the_stream_bytes() {
        let mut c = stream_ctx(br#"{"k":"v"}"#, None);
        let first = c.drain_body().await.unwrap();
        assert_eq!(first, Bytes::from_static(br#"{"k":"v"}"#));
        let second = c.drain_body().await.unwrap();
        assert_eq!(second, first, "second drain returns the cached bytes");
    }

    #[tokio::test]
    async fn stream_lane_over_limit_maps_to_413() {
        let mut c = stream_ctx(&[b'x'; 200], Some(64));
        let err = c.drain_body().await.err().unwrap();
        assert_eq!(err.code(), "JC0413");
    }

    #[tokio::test]
    async fn limit_trips_through_the_timed_recv_wrapper_still_map_to_413() {
        use crate::serve::TimedRecvBody;
        use http_body_util::BodyExt;
        use http_body_util::combinators::UnsyncBoxBody;
        use std::time::Duration;

        let req = http::Request::builder().uri("/up").body(()).unwrap();
        let (parts, ()) = req.into_parts();
        let over_limit_body = http_body_util::Full::<Bytes>::new(Bytes::from_static(&[b'x'; 200]))
            .map_err(|never| -> Box<dyn std::error::Error + Send + Sync> { match never {} });
        // The limit error is raised inside `Limited` and must survive the timing wrapper.
        let lane: StreamLane = UnsyncBoxBody::new(TimedRecvBody::new(
            http_body_util::Limited::new(over_limit_body, 64),
            Duration::from_secs(5),
        ));
        let mut c = RequestCtx::with_lane(
            parts,
            BodyLane::Stream(Some(lane)),
            DepResolver::new(Arc::new(DepEnv::default()), Default::default()),
        );
        let err = c.drain_body().await.err().unwrap();
        assert_eq!(err.code(), "JC0413");
        assert_eq!(err.status().as_u16(), 413);
    }
}