1use async_trait::async_trait;
2use bytes::Bytes;
3use http::{Request, Response, StatusCode};
4use http_body::Body;
5use http_body_util::{BodyExt, Full};
6use hyper::body::Incoming;
7use serde::de::DeserializeOwned;
8use std::collections::HashMap;
9
10#[cfg(feature = "validation")]
11use std::collections::BTreeMap;
12#[cfg(feature = "validation")]
13use validator::{Validate, ValidationErrors, ValidationErrorsKind};
14
15use crate::ingress::PathParams;
16
17#[cfg(feature = "multer")]
18pub mod multipart;
19#[cfg(feature = "multer")]
20pub use multipart::Multipart;
21
/// Default maximum request body size in bytes (1 MiB) used by the [`Json`] extractor.
pub const DEFAULT_BODY_LIMIT: usize = 1 << 20;
23
/// Errors produced while extracting typed data from an HTTP request.
///
/// The `Display` text of each variant (the `#[error]` attribute) is what
/// [`ExtractError::into_http_response`] sends as the body of non-validation errors;
/// see [`ExtractError::status_code`] for the HTTP status mapping.
#[derive(Debug, thiserror::Error, PartialEq, Eq)]
pub enum ExtractError {
    /// The request body was larger than allowed; `limit` is the maximum and `actual` the
    /// size that was observed.
    #[error("request body exceeds limit {limit} bytes (actual: {actual})")]
    BodyTooLarge { limit: usize, actual: usize },
    /// The underlying body stream returned an error while being read.
    #[error("failed to read request body: {0}")]
    BodyRead(String),
    /// The body could not be deserialized as JSON into the target type.
    #[error("invalid JSON body: {0}")]
    InvalidJson(String),
    /// The URL query string could not be deserialized into the target type.
    #[error("invalid query string: {0}")]
    InvalidQuery(String),
    /// No `PathParams` extension was present on the request.
    /// NOTE(review): presumably means the router did not insert params — confirm.
    #[error("missing path params in request extensions")]
    MissingPathParams,
    /// Path params could not be deserialized into the target type.
    #[error("invalid path params: {0}")]
    InvalidPath(String),
    /// Path params could not be re-encoded as a urlencoded string prior to deserialization.
    #[error("failed to encode path params: {0}")]
    PathEncode(String),
    /// A required header was absent; carries the header name.
    #[error("missing header: {0}")]
    MissingHeader(String),
    /// A header value could not be converted to text.
    #[error("invalid header value: {0}")]
    InvalidHeader(String),
    /// The deserialized payload failed `validator` checks; carries the structured JSON body.
    #[cfg(feature = "validation")]
    #[error("validation failed")]
    ValidationFailed(ValidationErrorBody),
    /// Multipart parsing failed (presumably raised by the `multipart` submodule — not visible here).
    #[cfg(feature = "multer")]
    #[error("multipart parsing error: {0}")]
    MultipartError(String),
}
51
/// JSON body returned for [`ExtractError::ValidationFailed`].
///
/// `fields` maps a field path to its failure messages. Nested structs use dotted paths
/// (`parent.child`), list items use indices (`items[0]`), and schema-level failures reported
/// by `validator` appear under `__all__`. Each message is the validator `code`, optionally
/// followed by `: <message>` when a custom message was attached.
#[cfg(feature = "validation")]
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize)]
pub struct ValidationErrorBody {
    /// Machine-readable error tag (`"validation_failed"` when built by this module).
    pub error: &'static str,
    /// Human-readable summary of the failure.
    pub message: &'static str,
    /// Per-field failure messages; `BTreeMap` keeps the serialized key order deterministic.
    pub fields: BTreeMap<String, Vec<String>>,
}
59
60impl ExtractError {
61 pub fn status_code(&self) -> StatusCode {
62 #[cfg(feature = "validation")]
63 {
64 if matches!(self, Self::ValidationFailed(_)) {
65 return StatusCode::UNPROCESSABLE_ENTITY;
66 }
67 }
68
69 StatusCode::BAD_REQUEST
70 }
71
72 pub fn into_http_response(&self) -> Response<Full<Bytes>> {
73 #[cfg(feature = "validation")]
74 if let Self::ValidationFailed(body) = self {
75 let payload = serde_json::to_vec(body).unwrap_or_else(|_| {
76 br#"{"error":"validation_failed","message":"request validation failed"}"#.to_vec()
77 });
78 return Response::builder()
79 .status(self.status_code())
80 .header(http::header::CONTENT_TYPE, "application/json")
81 .body(Full::new(Bytes::from(payload)))
82 .expect("validation response builder should be infallible");
83 }
84
85 Response::builder()
86 .status(self.status_code())
87 .body(Full::new(Bytes::from(self.to_string())))
88 .expect("extract error response builder should be infallible")
89 }
90}
91
/// Types that can be extracted from an HTTP request.
///
/// Extractors receive the request mutably so body-consuming extractors (such as [`Json`])
/// can read the body in place. Extractors that only inspect the URI or extensions (such as
/// [`Query`] and [`Path`]) leave the request untouched, so several can run on the same request.
///
/// `B` defaults to hyper's [`Incoming`] body; any body yielding [`Bytes`] frames works.
#[async_trait]
pub trait FromRequest<B = Incoming>: Sized
where
    B: Body<Data = Bytes> + Send + Unpin + 'static,
    B::Error: std::fmt::Display + Send + Sync + 'static,
{
    /// Extracts `Self` from `req`, returning an [`ExtractError`] on failure.
    async fn from_request(req: &mut Request<B>) -> Result<Self, ExtractError>;
}
100
/// Extractor that reads the request body as JSON and deserializes it into `T`.
///
/// With the `validation` feature enabled, `T` must also implement `validator::Validate`
/// and is validated after deserialization.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Json<T>(pub T);

impl<T> Json<T> {
    /// Unwraps the extractor, yielding the deserialized value.
    pub fn into_inner(self) -> T {
        let Json(inner) = self;
        inner
    }
}
109
/// Extractor that deserializes the URL query string (`application/x-www-form-urlencoded`)
/// into `T`; a missing query string is treated as empty.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Query<T>(pub T);

impl<T> Query<T> {
    /// Unwraps the extractor, yielding the deserialized query value.
    pub fn into_inner(self) -> T {
        let Query(inner) = self;
        inner
    }
}
118
/// Extractor that deserializes route path parameters (read from the request's
/// `PathParams` extension) into `T`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Path<T>(pub T);

impl<T> Path<T> {
    /// Unwraps the extractor, yielding the deserialized path parameters.
    pub fn into_inner(self) -> T {
        let Path(inner) = self;
        inner
    }
}
127
128#[derive(Debug, Clone, PartialEq, Eq)]
138pub struct Header(pub String);
139
140impl Header {
141 pub fn from_parts(name: &str, parts: &http::request::Parts) -> Result<Self, ExtractError> {
143 let value = parts
144 .headers
145 .get(name)
146 .ok_or_else(|| ExtractError::MissingHeader(name.to_string()))?;
147 let s = value
148 .to_str()
149 .map_err(|e| ExtractError::InvalidHeader(e.to_string()))?;
150 Ok(Header(s.to_string()))
151 }
152
153 pub fn into_inner(self) -> String {
155 self.0
156 }
157}
158
/// Request cookies keyed by name, built by [`CookieJar::from_parts`].
///
/// Values are stored after surrounding double quotes are removed and `%XX` escapes are decoded.
#[derive(Debug, Clone, Default)]
pub struct CookieJar {
    // cookie name -> decoded cookie value
    cookies: HashMap<String, String>,
}
175
/// Returns `true` when `name` is a non-empty HTTP `token`, the grammar RFC 6265 requires for
/// a cookie-name: ASCII letters, digits, and the symbols ``!#$%&'*+-.^_`|~``.
fn is_valid_cookie_name(name: &str) -> bool {
    // Token punctuation allowed in addition to ASCII alphanumerics.
    const TOKEN_SYMBOLS: &[u8] = b"!#$%&'*+-.^_`|~";

    if name.is_empty() {
        return false;
    }
    name.bytes()
        .all(|byte| byte.is_ascii_alphanumeric() || TOKEN_SYMBOLS.contains(&byte))
}
188
/// Strips one pair of surrounding double quotes from a cookie value.
///
/// Values that are not wrapped in quotes at both ends (including a lone `"`) are returned
/// unchanged.
fn unquote_cookie_value(value: &str) -> &str {
    value
        .strip_prefix('"')
        .and_then(|rest| rest.strip_suffix('"'))
        .unwrap_or(value)
}
197
/// Decodes `%XX` escapes in a cookie value.
///
/// A `%` that is not followed by two hex digits is kept literally. The decoded bytes are
/// interpreted as UTF-8, and invalid sequences are replaced with U+FFFD.
fn percent_decode_cookie(input: &str) -> String {
    let src = input.as_bytes();
    let mut decoded = Vec::with_capacity(src.len());
    let mut pos = 0;

    while let Some(&byte) = src.get(pos) {
        // A `%` that has two more bytes after it may be an escape sequence.
        let escaped = match (byte, src.get(pos + 1..pos + 3)) {
            (b'%', Some(&[hi, lo])) => hex_value(hi).zip(hex_value(lo)),
            _ => None,
        };

        if let Some((hi, lo)) = escaped {
            decoded.push((hi << 4) | lo);
            pos += 3;
        } else {
            decoded.push(byte);
            pos += 1;
        }
    }

    String::from_utf8_lossy(&decoded).into_owned()
}

/// Returns the numeric value of an ASCII hex digit (`0-9`, `a-f`, `A-F`), or `None`.
fn hex_value(b: u8) -> Option<u8> {
    // `to_digit` only recognises ASCII digits/letters, so non-ASCII bytes map to `None`.
    char::from(b)
        .to_digit(16)
        .and_then(|digit| u8::try_from(digit).ok())
}
225
226impl CookieJar {
227 pub fn from_parts(parts: &http::request::Parts) -> Self {
233 let mut cookies = HashMap::new();
234 if let Some(header) = parts.headers.get(http::header::COOKIE) {
235 if let Ok(value) = header.to_str() {
236 for pair in value.split(';') {
237 let pair = pair.trim();
238 if let Some((key, val)) = pair.split_once('=') {
239 let name = key.trim();
240 if !is_valid_cookie_name(name) {
241 tracing::warn!(
242 cookie_name = name,
243 "skipping cookie with invalid name"
244 );
245 continue;
246 }
247 let val = unquote_cookie_value(val.trim());
248 cookies.insert(name.to_string(), percent_decode_cookie(val));
249 }
250 }
251 }
252 }
253 CookieJar { cookies }
254 }
255
256 pub fn get(&self, name: &str) -> Option<&str> {
258 self.cookies.get(name).map(|s| s.as_str())
259 }
260
261 pub fn contains(&self, name: &str) -> bool {
263 self.cookies.contains_key(name)
264 }
265
266 pub fn iter(&self) -> impl Iterator<Item = (&str, &str)> {
268 self.cookies.iter().map(|(k, v)| (k.as_str(), v.as_str()))
269 }
270}
271
272#[async_trait]
273#[cfg(not(feature = "validation"))]
274impl<T, B> FromRequest<B> for Json<T>
275where
276 T: DeserializeOwned + Send + 'static,
277 B: Body<Data = Bytes> + Send + Unpin + 'static,
278 B::Error: std::fmt::Display + Send + Sync + 'static,
279{
280 async fn from_request(req: &mut Request<B>) -> Result<Self, ExtractError> {
281 let bytes = read_body_limited(req, DEFAULT_BODY_LIMIT).await?;
282 let value = parse_json_bytes(&bytes)?;
283 Ok(Json(value))
284 }
285}
286
/// Reads the body (capped at [`DEFAULT_BODY_LIMIT`]), deserializes it as JSON, then runs
/// `validator` checks; validation failures become [`ExtractError::ValidationFailed`].
#[cfg(feature = "validation")]
#[async_trait]
impl<T, B> FromRequest<B> for Json<T>
where
    T: DeserializeOwned + Send + Validate + 'static,
    B: Body<Data = Bytes> + Send + Unpin + 'static,
    B::Error: std::fmt::Display + Send + Sync + 'static,
{
    async fn from_request(req: &mut Request<B>) -> Result<Self, ExtractError> {
        let raw = read_body_limited(req, DEFAULT_BODY_LIMIT).await?;
        let payload: T = parse_json_bytes(&raw)?;
        validate_payload(&payload).map(|()| Json(payload))
    }
}
303
304#[async_trait]
305impl<T, B> FromRequest<B> for Query<T>
306where
307 T: DeserializeOwned + Send + 'static,
308 B: Body<Data = Bytes> + Send + Unpin + 'static,
309 B::Error: std::fmt::Display + Send + Sync + 'static,
310{
311 async fn from_request(req: &mut Request<B>) -> Result<Self, ExtractError> {
312 let value = parse_query_str(req.uri().query().unwrap_or(""))?;
313 Ok(Query(value))
314 }
315}
316
317#[async_trait]
318impl<T, B> FromRequest<B> for Path<T>
319where
320 T: DeserializeOwned + Send + 'static,
321 B: Body<Data = Bytes> + Send + Unpin + 'static,
322 B::Error: std::fmt::Display + Send + Sync + 'static,
323{
324 async fn from_request(req: &mut Request<B>) -> Result<Self, ExtractError> {
325 let params = req
326 .extensions()
327 .get::<PathParams>()
328 .ok_or(ExtractError::MissingPathParams)?;
329 let value = parse_path_map(params.as_map())?;
330 Ok(Path(value))
331 }
332}
333
334async fn read_body_limited<B>(req: &mut Request<B>, limit: usize) -> Result<Bytes, ExtractError>
335where
336 B: Body<Data = Bytes> + Send + Unpin + 'static,
337 B::Error: std::fmt::Display + Send + Sync + 'static,
338{
339 let body = req
340 .body_mut()
341 .collect()
342 .await
343 .map_err(|error| ExtractError::BodyRead(error.to_string()))?
344 .to_bytes();
345
346 if body.len() > limit {
347 return Err(ExtractError::BodyTooLarge {
348 limit,
349 actual: body.len(),
350 });
351 }
352
353 Ok(body)
354}
355
356fn parse_json_bytes<T>(bytes: &[u8]) -> Result<T, ExtractError>
357where
358 T: DeserializeOwned,
359{
360 serde_json::from_slice(bytes).map_err(|error| ExtractError::InvalidJson(error.to_string()))
361}
362
363fn parse_query_str<T>(query: &str) -> Result<T, ExtractError>
364where
365 T: DeserializeOwned,
366{
367 serde_urlencoded::from_str(query).map_err(|error| ExtractError::InvalidQuery(error.to_string()))
368}
369
370fn parse_path_map<T>(params: &HashMap<String, String>) -> Result<T, ExtractError>
371where
372 T: DeserializeOwned,
373{
374 let encoded = serde_urlencoded::to_string(params)
375 .map_err(|error| ExtractError::PathEncode(error.to_string()))?;
376 serde_urlencoded::from_str(&encoded)
377 .map_err(|error| ExtractError::InvalidPath(error.to_string()))
378}
379
/// Runs `validator` checks on `value`, converting failures into
/// [`ExtractError::ValidationFailed`] with a structured body.
#[cfg(feature = "validation")]
fn validate_payload<T>(value: &T) -> Result<(), ExtractError>
where
    T: Validate,
{
    match value.validate() {
        Ok(()) => Ok(()),
        Err(errors) => Err(ExtractError::ValidationFailed(validation_error_body(
            &errors,
        ))),
    }
}
389
/// Flattens `validator` errors into the JSON body sent for validation failures.
#[cfg(feature = "validation")]
fn validation_error_body(errors: &ValidationErrors) -> ValidationErrorBody {
    let mut collected: BTreeMap<String, Vec<String>> = BTreeMap::new();
    collect_validation_errors("", errors, &mut collected);

    ValidationErrorBody {
        fields: collected,
        error: "validation_failed",
        message: "request validation failed",
    }
}
401
/// Recursively walks `errors`, appending one message per failure to `fields`.
///
/// Paths are built from `prefix`: nested structs add `.field`, and list items add `[index]`.
/// Each message is the failure `code`, followed by `: <message>` when one is set.
#[cfg(feature = "validation")]
fn collect_validation_errors(
    prefix: &str,
    errors: &ValidationErrors,
    fields: &mut BTreeMap<String, Vec<String>>,
) {
    for (field, kind) in errors.errors() {
        let path = match prefix {
            "" => field.to_string(),
            parent => format!("{parent}.{field}"),
        };

        match kind {
            ValidationErrorsKind::Field(failures) => {
                let messages = fields.entry(path).or_default();
                messages.extend(failures.iter().map(|failure| match &failure.message {
                    Some(message) => format!("{}: {}", failure.code, message),
                    None => failure.code.to_string(),
                }));
            }
            ValidationErrorsKind::Struct(nested) => {
                collect_validation_errors(&path, nested, fields);
            }
            ValidationErrorsKind::List(items) => {
                for (index, nested) in items {
                    collect_validation_errors(&format!("{path}[{index}]"), nested, fields);
                }
            }
        }
    }
}
439
// Unit tests for the extractors, the error-response mapping, and the cookie jar.
#[cfg(test)]
mod tests {
    use super::*;
    use serde::Deserialize;
    #[cfg(feature = "validation")]
    use validator::{Validate, ValidationErrors};

    // Target type for query-string parsing tests.
    #[derive(Debug, Deserialize, PartialEq, Eq)]
    struct QueryPayload {
        page: u32,
        size: u32,
    }

    // Target type for path-param parsing tests; `id` exercises string -> number coercion.
    #[derive(Debug, Deserialize, PartialEq, Eq)]
    struct PathPayload {
        id: u64,
        slug: String,
    }

    // JSON target; derives `Validate` (with no rules) only when the `validation` feature
    // makes `Json<T>` require it.
    #[derive(Debug, Deserialize, PartialEq, Eq)]
    #[cfg_attr(feature = "validation", derive(Validate))]
    struct JsonPayload {
        id: u32,
        name: String,
    }

    // Field-level validation rules with custom messages.
    #[cfg(feature = "validation")]
    #[derive(Debug, Deserialize, Validate)]
    struct ValidatedPayload {
        #[validate(length(min = 3, message = "name too short"))]
        name: String,
        #[validate(range(min = 1, message = "age must be >= 1"))]
        age: u8,
    }

    // Schema-level (whole-struct) validation via a custom function.
    #[cfg(feature = "validation")]
    #[derive(Debug, Deserialize, Validate)]
    #[validate(schema(function = "validate_password_confirmation"))]
    struct SignupPayload {
        #[validate(email(message = "email format invalid"))]
        email: String,
        password: String,
        confirm_password: String,
    }

    // Uses a hand-written `Validate` impl instead of the derive.
    #[cfg(feature = "validation")]
    #[derive(Debug, Deserialize)]
    struct ManualValidatedPayload {
        token: String,
    }

    // Schema rule for `SignupPayload`: both password fields must match.
    #[cfg(feature = "validation")]
    fn validate_password_confirmation(
        payload: &SignupPayload,
    ) -> Result<(), validator::ValidationError> {
        if payload.password != payload.confirm_password {
            return Err(validator::ValidationError::new("password_mismatch"));
        }
        Ok(())
    }

    // Manual rule: `token` must start with `tok_`; reported under the `token` field.
    #[cfg(feature = "validation")]
    impl Validate for ManualValidatedPayload {
        fn validate(&self) -> Result<(), ValidationErrors> {
            let mut errors = ValidationErrors::new();
            if !self.token.starts_with("tok_") {
                let mut error = validator::ValidationError::new("token_prefix");
                error.message = Some("token must start with tok_".into());
                errors.add("token", error);
            }

            if errors.errors().is_empty() {
                Ok(())
            } else {
                Err(errors)
            }
        }
    }

    #[test]
    fn parse_query_payload() {
        let payload: QueryPayload = parse_query_str("page=2&size=50").expect("query parse");
        assert_eq!(payload.page, 2);
        assert_eq!(payload.size, 50);
    }

    #[test]
    fn parse_path_payload() {
        let mut map = HashMap::new();
        map.insert("id".to_string(), "42".to_string());
        map.insert("slug".to_string(), "order-created".to_string());
        let payload: PathPayload = parse_path_map(&map).expect("path parse");
        assert_eq!(payload.id, 42);
        assert_eq!(payload.slug, "order-created");
    }

    #[test]
    fn parse_json_payload() {
        let payload: JsonPayload =
            parse_json_bytes(br#"{"id":7,"name":"ranvier"}"#).expect("json parse");
        assert_eq!(payload.id, 7);
        assert_eq!(payload.name, "ranvier");
    }

    #[test]
    fn extract_error_maps_to_bad_request() {
        let error = ExtractError::InvalidQuery("bad input".to_string());
        assert_eq!(error.status_code(), StatusCode::BAD_REQUEST);
    }

    // `Full` stands in for hyper's `Incoming` body: the trait is generic over `B`.
    #[tokio::test]
    async fn json_from_request_with_full_body() {
        let body = Full::new(Bytes::from_static(br#"{"id":9,"name":"node"}"#));
        let mut req = Request::builder()
            .uri("/orders")
            .body(body)
            .expect("request build");

        let Json(payload): Json<JsonPayload> = Json::from_request(&mut req).await.expect("extract");
        assert_eq!(payload.id, 9);
        assert_eq!(payload.name, "node");
    }

    // Query and Path both run against the same request, which neither consumes.
    #[tokio::test]
    async fn query_and_path_from_request_extensions() {
        let body = Full::new(Bytes::new());
        let mut req = Request::builder()
            .uri("/orders/42?page=3&size=10")
            .body(body)
            .expect("request build");

        // Simulates the router inserting matched path params.
        let mut params = HashMap::new();
        params.insert("id".to_string(), "42".to_string());
        params.insert("slug".to_string(), "created".to_string());
        req.extensions_mut().insert(PathParams::new(params));

        let Query(query): Query<QueryPayload> = Query::from_request(&mut req).await.expect("query");
        let Path(path): Path<PathPayload> = Path::from_request(&mut req).await.expect("path");

        assert_eq!(query.page, 3);
        assert_eq!(query.size, 10);
        assert_eq!(path.id, 42);
        assert_eq!(path.slug, "created");
    }

    // Field-level failures produce a 422 JSON body with per-field messages.
    #[cfg(feature = "validation")]
    #[tokio::test]
    async fn json_validation_rejects_invalid_payload_with_422() {
        let body = Full::new(Bytes::from_static(br#"{"name":"ab","age":0}"#));
        let mut req = Request::builder()
            .uri("/users")
            .body(body)
            .expect("request build");

        let error = Json::<ValidatedPayload>::from_request(&mut req)
            .await
            .expect_err("payload should fail validation");

        assert_eq!(error.status_code(), StatusCode::UNPROCESSABLE_ENTITY);

        let response = error.into_http_response();
        assert_eq!(response.status(), StatusCode::UNPROCESSABLE_ENTITY);
        assert_eq!(
            response.headers().get(http::header::CONTENT_TYPE),
            Some(&http::HeaderValue::from_static("application/json"))
        );

        let body = response.into_body().collect().await.expect("collect body");
        let json: serde_json::Value =
            serde_json::from_slice(&body.to_bytes()).expect("validation json body");
        assert_eq!(json["error"], "validation_failed");
        assert!(
            json["fields"]["name"][0]
                .as_str()
                .expect("name message")
                .contains("name too short")
        );
        assert!(
            json["fields"]["age"][0]
                .as_str()
                .expect("age message")
                .contains("age must be >= 1")
        );
    }

    // Schema-level failures are reported under the `__all__` key.
    #[cfg(feature = "validation")]
    #[tokio::test]
    async fn json_validation_supports_schema_level_rules() {
        let body = Full::new(Bytes::from_static(
            br#"{"email":"user@example.com","password":"secret123","confirm_password":"different"}"#,
        ));
        let mut req = Request::builder()
            .uri("/signup")
            .body(body)
            .expect("request build");

        let error = Json::<SignupPayload>::from_request(&mut req)
            .await
            .expect_err("schema validation should fail");
        assert_eq!(error.status_code(), StatusCode::UNPROCESSABLE_ENTITY);

        let response = error.into_http_response();
        let body = response.into_body().collect().await.expect("collect body");
        let json: serde_json::Value =
            serde_json::from_slice(&body.to_bytes()).expect("validation json body");

        assert_eq!(json["fields"]["__all__"][0], "password_mismatch");
    }

    #[cfg(feature = "validation")]
    #[tokio::test]
    async fn json_validation_accepts_valid_payload() {
        let body = Full::new(Bytes::from_static(br#"{"name":"valid-name","age":20}"#));
        let mut req = Request::builder()
            .uri("/users")
            .body(body)
            .expect("request build");

        let Json(payload): Json<ValidatedPayload> = Json::from_request(&mut req)
            .await
            .expect("validation should pass");
        assert_eq!(payload.name, "valid-name");
        assert_eq!(payload.age, 20);
    }

    // Messages from hand-written impls are rendered as `code: message`.
    #[cfg(feature = "validation")]
    #[tokio::test]
    async fn json_validation_supports_manual_validate_impl_hooks() {
        let body = Full::new(Bytes::from_static(br#"{"token":"invalid"}"#));
        let mut req = Request::builder()
            .uri("/tokens")
            .body(body)
            .expect("request build");

        let error = Json::<ManualValidatedPayload>::from_request(&mut req)
            .await
            .expect_err("manual validation should fail");
        assert_eq!(error.status_code(), StatusCode::UNPROCESSABLE_ENTITY);

        let response = error.into_http_response();
        let body = response.into_body().collect().await.expect("collect body");
        let json: serde_json::Value =
            serde_json::from_slice(&body.to_bytes()).expect("validation json body");

        assert_eq!(
            json["fields"]["token"][0],
            "token_prefix: token must start with tok_"
        );
    }

    // Builds request `Parts` carrying a single `Cookie` header with the given raw value.
    fn make_parts_with_cookie(cookie_value: &str) -> http::request::Parts {
        let (parts, _) = Request::builder()
            .header(http::header::COOKIE, cookie_value)
            .body(())
            .expect("request build")
            .into_parts();
        parts
    }

    #[test]
    fn cookiejar_parses_standard_cookies() {
        let parts = make_parts_with_cookie("session=abc123; lang=en");
        let jar = CookieJar::from_parts(&parts);
        assert_eq!(jar.get("session"), Some("abc123"));
        assert_eq!(jar.get("lang"), Some("en"));
    }

    // Invalid names are dropped while valid neighbours are still parsed.
    #[test]
    fn cookiejar_skips_invalid_names() {
        let parts = make_parts_with_cookie("good=yes; bad name=no; also,bad=no; ok=fine");
        let jar = CookieJar::from_parts(&parts);
        assert_eq!(jar.get("good"), Some("yes"));
        assert_eq!(jar.get("ok"), Some("fine"));
        assert!(jar.get("bad name").is_none());
        assert!(jar.get("also,bad").is_none());
    }

    #[test]
    fn cookiejar_unquotes_values() {
        let parts = make_parts_with_cookie("token=\"quoted_value\"");
        let jar = CookieJar::from_parts(&parts);
        assert_eq!(jar.get("token"), Some("quoted_value"));
    }

    #[test]
    fn cookiejar_percent_decodes_values() {
        let parts = make_parts_with_cookie("msg=hello%20world; path=%2Fapi%2Fv1");
        let jar = CookieJar::from_parts(&parts);
        assert_eq!(jar.get("msg"), Some("hello world"));
        assert_eq!(jar.get("path"), Some("/api/v1"));
    }

    #[test]
    fn cookiejar_handles_empty_header() {
        let parts = make_parts_with_cookie("");
        let jar = CookieJar::from_parts(&parts);
        assert!(jar.get("anything").is_none());
    }

    #[test]
    fn cookiejar_name_validation() {
        assert!(is_valid_cookie_name("session_id"));
        assert!(is_valid_cookie_name("__Host-token"));
        assert!(!is_valid_cookie_name(""));
        assert!(!is_valid_cookie_name("bad name"));
        assert!(!is_valid_cookie_name("bad,name"));
        assert!(!is_valid_cookie_name("bad(name)"));
    }
}