1use crate::context::RequestContext;
7use crate::error::{HttpError, ValidationError, ValidationErrors};
8use crate::multipart;
9use crate::request::{Body, Request, RequestBodyStreamError};
10use crate::response::IntoResponse;
11use serde::de::{
12 self, DeserializeOwned, Deserializer, IntoDeserializer, MapAccess, SeqAccess, Visitor,
13};
14use std::fmt;
15use std::future::Future;
16use std::ops::{Deref, DerefMut};
17use std::task::Context;
18
19async fn collect_body_limited(
20 ctx: &RequestContext,
21 body: Body,
22 limit: usize,
23) -> Result<Vec<u8>, RequestBodyStreamError> {
24 match body {
25 Body::Empty => Ok(Vec::new()),
26 Body::Bytes(b) => {
27 if b.len() > limit {
28 Err(RequestBodyStreamError::TooLarge {
29 received: b.len(),
30 max: limit,
31 })
32 } else {
33 Ok(b)
34 }
35 }
36 Body::Stream {
37 stream,
38 content_length,
39 } => {
40 let mut stream = stream.into_inner().unwrap_or_else(|e| e.into_inner());
41 if let Some(n) = content_length {
42 if n > limit {
43 return Err(RequestBodyStreamError::TooLarge {
44 received: n,
45 max: limit,
46 });
47 }
48 }
49
50 let mut out = Vec::with_capacity(content_length.unwrap_or(0).min(limit));
52 let mut seen = 0usize;
53 loop {
54 let next =
55 std::future::poll_fn(|cx: &mut Context<'_>| stream.as_mut().poll_next(cx))
56 .await;
57 let Some(chunk) = next else {
58 break;
59 };
60 let chunk = chunk?;
61 seen = seen.saturating_add(chunk.len());
62 if seen > limit {
63 return Err(RequestBodyStreamError::TooLarge {
64 received: seen,
65 max: limit,
66 });
67 }
68 out.extend_from_slice(&chunk);
69 let _ = ctx.checkpoint();
70 }
71 Ok(out)
72 }
73 }
74}
75
76async fn parse_multipart_limited(
77 ctx: &RequestContext,
78 body: Body,
79 limit: usize,
80 parser: &multipart::MultipartParser,
81) -> Result<Vec<multipart::Part>, MultipartExtractError> {
82 fn map_parser_error(err: multipart::MultipartError) -> MultipartExtractError {
83 match err {
84 multipart::MultipartError::FileTooLarge { size, max }
85 | multipart::MultipartError::TotalTooLarge { size, max } => {
86 MultipartExtractError::PayloadTooLarge { size, limit: max }
87 }
88 multipart::MultipartError::Io { detail } => {
89 MultipartExtractError::ReadError { message: detail }
90 }
91 other => MultipartExtractError::BadRequest {
92 message: other.to_string(),
93 },
94 }
95 }
96
97 fn map_stream_error(err: RequestBodyStreamError) -> MultipartExtractError {
98 match err {
99 RequestBodyStreamError::TooLarge { received, max } => {
100 MultipartExtractError::PayloadTooLarge {
101 size: received,
102 limit: max,
103 }
104 }
105 RequestBodyStreamError::ConnectionClosed => MultipartExtractError::BadRequest {
106 message: RequestBodyStreamError::ConnectionClosed.to_string(),
107 },
108 RequestBodyStreamError::Io(message) => MultipartExtractError::ReadError { message },
109 }
110 }
111
112 match body {
113 Body::Empty => parser.parse(&[]).map_err(map_parser_error),
114 Body::Bytes(bytes) => {
115 if bytes.len() > limit {
116 return Err(MultipartExtractError::PayloadTooLarge {
117 size: bytes.len(),
118 limit,
119 });
120 }
121 parser.parse(&bytes).map_err(map_parser_error)
122 }
123 Body::Stream {
124 stream,
125 content_length,
126 } => {
127 let mut stream = stream.into_inner().unwrap_or_else(|e| e.into_inner());
128
129 if let Some(n) = content_length {
130 if n > limit {
131 return Err(MultipartExtractError::PayloadTooLarge { size: n, limit });
132 }
133 }
134
135 let mut state = multipart::MultipartStreamState::default();
136 let mut buffer = Vec::new();
137 let mut parts = Vec::new();
138 let mut seen = 0usize;
139
140 loop {
141 let next =
142 std::future::poll_fn(|cx: &mut Context<'_>| stream.as_mut().poll_next(cx))
143 .await;
144 let Some(chunk) = next else {
145 break;
146 };
147 let chunk = chunk.map_err(map_stream_error)?;
148
149 seen = seen.saturating_add(chunk.len());
150 if seen > limit {
151 return Err(MultipartExtractError::PayloadTooLarge { size: seen, limit });
152 }
153
154 buffer.extend_from_slice(&chunk);
155 let mut newly_parsed = parser
156 .parse_incremental(&mut buffer, &mut state, false)
157 .map_err(map_parser_error)?;
158 parts.append(&mut newly_parsed);
159 let _ = ctx.checkpoint();
160 }
161
162 let mut tail = parser
163 .parse_incremental(&mut buffer, &mut state, true)
164 .map_err(map_parser_error)?;
165 parts.append(&mut tail);
166
167 if !state.is_done() {
168 return Err(MultipartExtractError::BadRequest {
169 message: RequestBodyStreamError::ConnectionClosed.to_string(),
170 });
171 }
172
173 Ok(parts)
174 }
175 }
176}
177
178pub trait FromRequest: Sized {
208 type Error: IntoResponse;
210
211 fn from_request(
218 ctx: &RequestContext,
219 req: &mut Request,
220 ) -> impl Future<Output = Result<Self, Self::Error>> + Send;
221}
222
/// Failure modes of the `multipart/form-data` extractor.
#[derive(Debug)]
pub enum MultipartExtractError {
    /// The `Content-Type` was not `multipart/form-data`; `actual` holds the received value,
    /// or `None` when no usable header was present.
    UnsupportedMediaType { actual: Option<String> },
    /// The request was malformed; `message` is shown to the client verbatim.
    BadRequest { message: String },
    /// The body (or a part of it) exceeded a size limit; both values are in bytes.
    PayloadTooLarge { size: usize, limit: usize },
    /// Reading the request body failed.
    ReadError { message: String },
}

impl fmt::Display for MultipartExtractError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::UnsupportedMediaType { actual: Some(ct) } => {
                write!(f, "Expected Content-Type: multipart/form-data, got: {ct}")
            }
            Self::UnsupportedMediaType { actual: None } => {
                f.write_str("Missing Content-Type header, expected multipart/form-data")
            }
            Self::BadRequest { message } => f.write_str(message),
            Self::PayloadTooLarge { size, limit } => write!(
                f,
                "Request body too large: {size} bytes exceeds {limit} byte limit"
            ),
            Self::ReadError { message } => write!(f, "Failed to read request body: {message}"),
        }
    }
}

impl std::error::Error for MultipartExtractError {}
264
265impl IntoResponse for MultipartExtractError {
266 fn into_response(self) -> crate::response::Response {
267 use crate::response::{Response, ResponseBody, StatusCode};
268
269 let (status, detail) = match self {
270 Self::UnsupportedMediaType { actual: _ } => {
271 (StatusCode::UNSUPPORTED_MEDIA_TYPE, self.to_string())
272 }
273 Self::BadRequest { message } => (StatusCode::BAD_REQUEST, message),
274 Self::PayloadTooLarge { .. } => (StatusCode::PAYLOAD_TOO_LARGE, self.to_string()),
275 Self::ReadError { .. } => (StatusCode::BAD_REQUEST, self.to_string()),
276 };
277
278 let body = serde_json::json!({ "detail": detail });
279 Response::with_status(status)
280 .header("content-type", b"application/json".to_vec())
281 .body(ResponseBody::Bytes(body.to_string().into_bytes()))
282 }
283}
284
285impl FromRequest for multipart::MultipartForm {
286 type Error = MultipartExtractError;
287
288 async fn from_request(ctx: &RequestContext, req: &mut Request) -> Result<Self, Self::Error> {
289 let _ = ctx.checkpoint();
290
291 let content_type = req
292 .headers()
293 .get("content-type")
294 .and_then(|v| std::str::from_utf8(v).ok());
295 let Some(ct) = content_type else {
296 return Err(MultipartExtractError::UnsupportedMediaType { actual: None });
297 };
298
299 let ct = ct.trim();
300 let main = ct.split(';').next().unwrap_or("").trim();
301 if !main.eq_ignore_ascii_case("multipart/form-data") {
302 return Err(MultipartExtractError::UnsupportedMediaType {
303 actual: Some(ct.to_string()),
304 });
305 }
306
307 let boundary =
308 multipart::parse_boundary(ct).map_err(|e| MultipartExtractError::BadRequest {
309 message: e.to_string(),
310 })?;
311
312 let multipart_config = multipart::MultipartConfig::default();
313 let limit = multipart_config.get_max_total_size();
314 let spool_threshold = multipart_config.get_spool_threshold();
315 let parser = multipart::MultipartParser::new(&boundary, multipart_config);
316 let parts = parse_multipart_limited(ctx, req.take_body(), limit, &parser).await?;
317
318 Ok(multipart::MultipartForm::from_parts_with_spool_threshold(
319 parts,
320 spool_threshold,
321 ))
322 }
323}
324
#[cfg(test)]
mod multipart_extractor_tests {
    use super::*;
    use crate::request::Method;

    /// Boundary used by every well-formed fixture in this module.
    const BOUNDARY: &str = "----boundary";

    fn test_context() -> RequestContext {
        RequestContext::new(asupersync::Cx::for_testing(), 12345)
    }

    /// Builds `POST /upload` with a multipart/form-data content type for [`BOUNDARY`].
    fn upload_request(body: Body) -> Request {
        let mut req = Request::new(Method::Post, "/upload");
        req.headers_mut().insert(
            "content-type",
            format!("multipart/form-data; boundary={BOUNDARY}").into_bytes(),
        );
        req.set_body(body);
        req
    }

    /// Runs the `MultipartForm` extractor to completion on the current thread.
    fn extract(req: &mut Request) -> Result<multipart::MultipartForm, MultipartExtractError> {
        let ctx = test_context();
        futures_executor::block_on(multipart::MultipartForm::from_request(&ctx, req))
    }

    #[test]
    fn multipart_extract_success() {
        let payload = concat!(
            "------boundary\r\n",
            "Content-Disposition: form-data; name=\"field1\"\r\n",
            "\r\n",
            "value1\r\n",
            "------boundary\r\n",
            "Content-Disposition: form-data; name=\"file\"; filename=\"test.txt\"\r\n",
            "Content-Type: text/plain\r\n",
            "\r\n",
            "Hello\r\n",
            "------boundary--\r\n"
        );
        let mut req = upload_request(Body::Bytes(payload.as_bytes().to_vec()));

        let form = extract(&mut req).expect("multipart parse");

        assert_eq!(form.get_field("field1"), Some("value1"));
        let upload = form.get_file("file").expect("file");
        assert_eq!(upload.filename, "test.txt");
        assert_eq!(upload.content_type, "text/plain");
        assert_eq!(upload.bytes().expect("read upload bytes"), b"Hello".to_vec());
    }

    #[test]
    fn multipart_extract_wrong_content_type() {
        let mut req = Request::new(Method::Post, "/upload");
        req.headers_mut()
            .insert("content-type", b"application/json".to_vec());
        req.set_body(Body::Bytes(b"{}".to_vec()));

        let err = extract(&mut req).unwrap_err();

        assert!(matches!(
            err,
            MultipartExtractError::UnsupportedMediaType { actual: Some(_) }
        ));
    }

    #[test]
    fn multipart_extract_missing_boundary_is_bad_request() {
        let mut req = Request::new(Method::Post, "/upload");
        req.headers_mut()
            .insert("content-type", b"multipart/form-data".to_vec());
        req.set_body(Body::Bytes(Vec::new()));

        let err = extract(&mut req).unwrap_err();

        assert!(matches!(err, MultipartExtractError::BadRequest { .. }));
    }

    #[test]
    fn multipart_extract_streaming_body() {
        let raw = concat!(
            "------boundary\r\n",
            "Content-Disposition: form-data; name=\"field1\"\r\n",
            "\r\n",
            "value1\r\n",
            "------boundary\r\n",
            "Content-Disposition: form-data; name=\"file\"; filename=\"test.txt\"\r\n",
            "Content-Type: text/plain\r\n",
            "\r\n",
            "Hello stream\r\n",
            "------boundary--\r\n"
        )
        .as_bytes();

        // Small, odd-sized chunks force boundaries to straddle chunk edges.
        let chunks: Vec<Result<Vec<u8>, RequestBodyStreamError>> =
            raw.chunks(7).map(|piece| Ok(piece.to_vec())).collect();
        let mut req = upload_request(Body::streaming(asupersync::stream::iter(chunks)));

        let form = extract(&mut req).expect("multipart parse");

        assert_eq!(form.get_field("field1"), Some("value1"));
        let upload = form.get_file("file").expect("file");
        assert_eq!(
            upload.bytes().expect("read upload bytes"),
            b"Hello stream".to_vec()
        );
    }

    #[test]
    fn multipart_extract_file_too_large_maps_to_payload_too_large() {
        let mut raw = Vec::new();
        raw.extend_from_slice(format!("--{BOUNDARY}\r\n").as_bytes());
        raw.extend_from_slice(
            b"Content-Disposition: form-data; name=\"file\"; filename=\"big.bin\"\r\n",
        );
        raw.extend_from_slice(b"Content-Type: application/octet-stream\r\n\r\n");
        // One byte more than the default per-file limit.
        raw.resize(raw.len() + multipart::DEFAULT_MAX_FILE_SIZE + 1, b'a');
        raw.extend_from_slice(b"\r\n");
        raw.extend_from_slice(format!("--{BOUNDARY}--\r\n").as_bytes());
        let mut req = upload_request(Body::Bytes(raw));

        let err = extract(&mut req).unwrap_err();

        let MultipartExtractError::PayloadTooLarge { size, limit } = err else {
            panic!("expected PayloadTooLarge, got {err:?}");
        };
        assert!(size > limit);
        assert_eq!(limit, multipart::DEFAULT_MAX_FILE_SIZE);
    }
}
476
477impl<T: FromRequest> FromRequest for Option<T> {
479 type Error = std::convert::Infallible;
480
481 async fn from_request(ctx: &RequestContext, req: &mut Request) -> Result<Self, Self::Error> {
482 Ok(T::from_request(ctx, req).await.ok())
483 }
484}
485
486impl FromRequest for RequestContext {
488 type Error = std::convert::Infallible;
489
490 async fn from_request(ctx: &RequestContext, _req: &mut Request) -> Result<Self, Self::Error> {
491 Ok(ctx.clone())
492 }
493}
494
/// Default maximum JSON request body size: 1 MiB (1,048,576 bytes).
pub const DEFAULT_JSON_LIMIT: usize = 1 << 20;

/// Builder-style configuration for JSON body extraction.
///
/// Holds a body size limit (defaulting to [`DEFAULT_JSON_LIMIT`]) and an optional content type.
#[derive(Debug, Clone)]
pub struct JsonConfig {
    /// Maximum body size in bytes.
    limit: usize,
    /// Content type set through [`JsonConfig::content_type`], if any.
    content_type: Option<String>,
}

impl Default for JsonConfig {
    fn default() -> Self {
        JsonConfig {
            content_type: None,
            limit: DEFAULT_JSON_LIMIT,
        }
    }
}

impl JsonConfig {
    /// Creates a configuration with default values.
    #[must_use]
    pub fn new() -> Self {
        Default::default()
    }

    /// Returns this configuration with the body size limit replaced by `limit` (bytes).
    #[must_use]
    pub fn limit(self, limit: usize) -> Self {
        Self { limit, ..self }
    }

    /// Returns this configuration with the content type set to `content_type`.
    #[must_use]
    pub fn content_type(self, content_type: impl Into<String>) -> Self {
        Self {
            content_type: Some(content_type.into()),
            ..self
        }
    }

    /// Returns the configured body size limit in bytes.
    #[must_use]
    pub fn get_limit(&self) -> usize {
        self.limit
    }
}
548
/// Extractor/wrapper for a JSON-deserialized request body of type `T`.
///
/// Dereferences to the inner value; use [`Json::into_inner`] (or destructure `Json(value)`)
/// to take ownership.
#[derive(Debug, Clone, Copy, Default)]
pub struct Json<T>(pub T);

impl<T> Json<T> {
    /// Unwraps and returns the inner value.
    pub fn into_inner(self) -> T {
        let Json(inner) = self;
        inner
    }
}

impl<T> Deref for Json<T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        let Json(inner) = self;
        inner
    }
}

impl<T> DerefMut for Json<T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        let Json(inner) = self;
        inner
    }
}
598
/// Failure modes of the [`Json`] extractor.
#[derive(Debug)]
pub enum JsonExtractError {
    /// The `Content-Type` was not a JSON media type.
    UnsupportedMediaType {
        /// The received header value, or `None` when no usable header was present.
        actual: Option<String>,
    },
    /// The body exceeded the size limit.
    PayloadTooLarge {
        /// Observed (or declared) size in bytes.
        size: usize,
        /// Limit that was exceeded, in bytes.
        limit: usize,
    },
    /// Reading the body failed.
    ReadError {
        /// Description of the read failure.
        message: String,
    },
    /// The body could not be deserialized into the target type.
    DeserializeError {
        /// Parser-provided description.
        message: String,
        /// Line of the failure, when known.
        line: Option<usize>,
        /// Column of the failure, when known.
        column: Option<usize>,
    },
}

impl std::fmt::Display for JsonExtractError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::UnsupportedMediaType { actual: Some(ct) } => {
                write!(f, "Expected Content-Type: application/json, got: {ct}")
            }
            Self::UnsupportedMediaType { actual: None } => {
                f.write_str("Missing Content-Type header, expected application/json")
            }
            Self::PayloadTooLarge { size, limit } => write!(
                f,
                "Request body too large: {size} bytes exceeds {limit} byte limit"
            ),
            Self::ReadError { message } => write!(f, "Failed to read request body: {message}"),
            // A position is only printed when both coordinates are known.
            Self::DeserializeError {
                message,
                line: Some(l),
                column: Some(c),
            } => write!(f, "JSON parse error at line {l}, column {c}: {message}"),
            Self::DeserializeError { message, .. } => write!(f, "JSON parse error: {message}"),
        }
    }
}

impl std::error::Error for JsonExtractError {}
663
664impl IntoResponse for JsonExtractError {
665 fn into_response(self) -> crate::response::Response {
666 match self {
667 Self::UnsupportedMediaType { actual } => {
668 let detail = if let Some(ct) = actual {
669 format!("Expected Content-Type: application/json, got: {ct}")
670 } else {
671 "Missing Content-Type header, expected application/json".to_string()
672 };
673 HttpError::unsupported_media_type()
674 .with_detail(detail)
675 .into_response()
676 }
677 Self::PayloadTooLarge { size, limit } => HttpError::payload_too_large()
678 .with_detail(format!(
679 "Request body too large: {size} bytes exceeds {limit} byte limit"
680 ))
681 .into_response(),
682 Self::ReadError { message } => HttpError::bad_request()
683 .with_detail(format!("Failed to read request body: {message}"))
684 .into_response(),
685 Self::DeserializeError {
686 message,
687 line,
688 column,
689 } => {
690 let msg = if let (Some(l), Some(c)) = (line, column) {
692 format!("JSON parse error at line {l}, column {c}: {message}")
693 } else {
694 format!("JSON parse error: {message}")
695 };
696 ValidationErrors::single(ValidationError::json_invalid(
697 crate::error::loc::body(),
698 msg,
699 ))
700 .into_response()
701 }
702 }
703 }
704}
705
706impl<T: DeserializeOwned> FromRequest for Json<T> {
707 type Error = JsonExtractError;
708
709 async fn from_request(ctx: &RequestContext, req: &mut Request) -> Result<Self, Self::Error> {
710 let _ = ctx.checkpoint();
712
713 let content_type = req
715 .headers()
716 .get("content-type")
717 .and_then(|v| std::str::from_utf8(v).ok());
718
719 let is_json = content_type.is_some_and(|ct| {
720 let ct_lower = ct.to_ascii_lowercase();
721 let base_type = ct_lower.split(';').next().unwrap_or("").trim();
724 base_type == "application/json"
725 || (base_type.starts_with("application/") && base_type.ends_with("+json"))
726 });
727
728 if !is_json {
729 return Err(JsonExtractError::UnsupportedMediaType {
730 actual: content_type.map(String::from),
731 });
732 }
733
734 let body = req.take_body();
736 let limit = DEFAULT_JSON_LIMIT;
737 let bytes = collect_body_limited(ctx, body, limit)
738 .await
739 .map_err(|e| match e {
740 RequestBodyStreamError::TooLarge { received, .. } => {
741 JsonExtractError::PayloadTooLarge {
742 size: received,
743 limit,
744 }
745 }
746 other => JsonExtractError::ReadError {
747 message: other.to_string(),
748 },
749 })?;
750
751 let _ = ctx.checkpoint();
753
754 serde_json::from_slice(&bytes)
756 .map(Json)
757 .map_err(|e| JsonExtractError::DeserializeError {
758 message: e.to_string(),
759 line: Some(e.line()),
760 column: Some(e.column()),
761 })
762 }
763}
764
#[cfg(test)]
mod tests {
    use super::*;
    use crate::request::Method;
    use serde::Deserialize;

    #[derive(Deserialize, Debug, PartialEq)]
    struct NamedValue {
        name: String,
        value: i32,
    }

    // The field only has to deserialize; it is never read in the assertions.
    #[derive(Deserialize)]
    struct NameOnly {
        #[allow(dead_code)]
        name: String,
    }

    #[derive(Deserialize, PartialEq, Debug)]
    struct ValueOnly {
        value: i32,
    }

    fn test_context() -> RequestContext {
        RequestContext::new(asupersync::Cx::for_testing(), 12345)
    }

    /// Builds `POST /test` with an optional `content-type` header and the given body.
    fn request_with(content_type: Option<&[u8]>, body: &[u8]) -> Request {
        let mut req = Request::new(Method::Post, "/test");
        if let Some(ct) = content_type {
            req.headers_mut().insert("content-type", ct.to_vec());
        }
        req.set_body(Body::Bytes(body.to_vec()));
        req
    }

    /// Builds an `application/json` request carrying `body`.
    fn json_request(body: &str) -> Request {
        request_with(Some(b"application/json".as_slice()), body.as_bytes())
    }

    /// Runs the `Json<T>` extractor to completion on the current thread.
    fn extract<T: DeserializeOwned>(req: &mut Request) -> Result<Json<T>, JsonExtractError> {
        let ctx = test_context();
        futures_executor::block_on(Json::<T>::from_request(&ctx, req))
    }

    #[test]
    fn json_config_defaults() {
        assert_eq!(JsonConfig::default().get_limit(), DEFAULT_JSON_LIMIT);
    }

    #[test]
    fn json_config_custom() {
        assert_eq!(JsonConfig::new().limit(1024).get_limit(), 1024);
    }

    #[test]
    fn json_deref() {
        let wrapped = Json(42i32);
        assert_eq!(*wrapped, 42);
    }

    #[test]
    fn json_into_inner() {
        let wrapped = Json("hello".to_string());
        assert_eq!(wrapped.into_inner(), "hello");
    }

    #[test]
    fn json_extract_success() {
        let mut req = json_request(r#"{"name": "test", "value": 42}"#);

        let Json(payload) = extract::<NamedValue>(&mut req).unwrap();

        assert_eq!(payload.name, "test");
        assert_eq!(payload.value, 42);
    }

    #[test]
    fn json_extract_wrong_content_type() {
        let mut req = request_with(Some(b"text/plain".as_slice()), b"{}");

        let outcome = extract::<NameOnly>(&mut req);

        assert!(matches!(
            outcome,
            Err(JsonExtractError::UnsupportedMediaType { actual: Some(ct) })
                if ct == "text/plain"
        ));
    }

    #[test]
    fn json_extract_missing_content_type() {
        let mut req = request_with(None, b"{}");

        let outcome = extract::<NameOnly>(&mut req);

        assert!(matches!(
            outcome,
            Err(JsonExtractError::UnsupportedMediaType { actual: None })
        ));
    }

    #[test]
    fn json_extract_invalid_json() {
        let mut req = json_request(r#"{"name": invalid}"#);

        let outcome = extract::<NameOnly>(&mut req);

        assert!(matches!(
            outcome,
            Err(JsonExtractError::DeserializeError { .. })
        ));
    }

    #[test]
    fn json_extract_application_json_charset() {
        let mut req = request_with(
            Some(b"application/json; charset=utf-8".as_slice()),
            b"{\"value\": 123}",
        );

        let Json(payload) = extract::<ValueOnly>(&mut req).unwrap();

        assert_eq!(payload.value, 123);
    }

    #[test]
    fn json_extract_vendor_json() {
        let mut req = request_with(
            Some(b"application/vnd.api+json".as_slice()),
            b"{\"value\": 456}",
        );

        let Json(payload) = extract::<ValueOnly>(&mut req).unwrap();

        assert_eq!(payload.value, 456);
    }

    #[test]
    fn json_error_display() {
        let media = JsonExtractError::UnsupportedMediaType {
            actual: Some("text/html".to_string()),
        }
        .to_string();
        assert!(media.contains("text/html"));

        let too_large = JsonExtractError::PayloadTooLarge {
            size: 2000,
            limit: 1000,
        }
        .to_string();
        assert!(too_large.contains("2000"));
        assert!(too_large.contains("1000"));

        let parse = JsonExtractError::DeserializeError {
            message: "unexpected token".to_string(),
            line: Some(1),
            column: Some(10),
        }
        .to_string();
        assert!(parse.contains("line 1"));
        assert!(parse.contains("column 10"));
    }
}
957
/// Path parameters captured for a request, kept as ordered `(name, value)` pairs.
#[derive(Debug, Clone, Default)]
pub struct PathParams(pub Vec<(String, String)>);

impl PathParams {
    /// Creates an empty parameter set.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Wraps an existing list of `(name, value)` pairs, preserving their order.
    #[must_use]
    pub fn from_pairs(pairs: Vec<(String, String)>) -> Self {
        PathParams(pairs)
    }

    /// Returns the value of the first parameter called `name`, if any.
    #[must_use]
    pub fn get(&self, name: &str) -> Option<&str> {
        self.0
            .iter()
            .find_map(|(key, value)| (key == name).then_some(value.as_str()))
    }

    /// Borrows all pairs in insertion order.
    #[must_use]
    pub fn as_slice(&self) -> &[(String, String)] {
        self.0.as_slice()
    }

    /// Returns `true` when no parameters were captured.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.as_slice().is_empty()
    }

    /// Returns the number of captured parameters.
    #[must_use]
    pub fn len(&self) -> usize {
        self.as_slice().len()
    }
}
1015
/// Extractor/wrapper for path parameters deserialized into `T`.
///
/// Dereferences to the inner value; use [`Path::into_inner`] (or destructure `Path(value)`)
/// to take ownership.
#[derive(Debug, Clone, Copy, Default)]
pub struct Path<T>(pub T);

impl<T> Path<T> {
    /// Unwraps and returns the inner value.
    pub fn into_inner(self) -> T {
        let Path(inner) = self;
        inner
    }
}

impl<T> Deref for Path<T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        let Path(inner) = self;
        inner
    }
}

impl<T> DerefMut for Path<T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        let Path(inner) = self;
        inner
    }
}
1088
/// Failure modes of the [`Path`] extractor and its deserializer.
#[derive(Debug)]
pub enum PathExtractError {
    /// No `PathParams` extension was attached to the request.
    MissingPathParams,
    /// A required parameter was absent.
    MissingParam {
        /// Name of the missing parameter.
        name: String,
    },
    /// A parameter's text could not be converted to the requested type.
    InvalidValue {
        /// Parameter name.
        name: String,
        /// Raw text that failed to convert.
        value: String,
        /// Human-readable name of the expected type.
        expected: &'static str,
        /// Underlying conversion error.
        message: String,
    },
    /// Any other deserialization failure.
    DeserializeError {
        /// Description of the failure.
        message: String,
    },
}

impl fmt::Display for PathExtractError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::MissingPathParams => f.write_str("Path parameters not available in request"),
            Self::MissingParam { name } => write!(f, "Missing path parameter: {name}"),
            Self::InvalidValue {
                name,
                value,
                expected,
                message,
            } => write!(
                f,
                "Invalid value for path parameter '{name}': expected {expected}, got '{value}': {message}"
            ),
            Self::DeserializeError { message } => {
                write!(f, "Path deserialization error: {message}")
            }
        }
    }
}

impl std::error::Error for PathExtractError {}
1146
1147impl IntoResponse for PathExtractError {
1148 fn into_response(self) -> crate::response::Response {
1149 match self {
1150 Self::MissingPathParams => {
1151 HttpError::internal()
1153 .with_detail("Path parameters not available")
1154 .into_response()
1155 }
1156 Self::MissingParam { name } => ValidationErrors::single(
1157 ValidationError::missing(crate::error::loc::path(&name))
1158 .with_msg("Path parameter is required"),
1159 )
1160 .into_response(),
1161 Self::InvalidValue {
1162 name,
1163 value,
1164 expected,
1165 message,
1166 } => ValidationErrors::single(
1167 ValidationError::type_error(crate::error::loc::path(&name), &expected)
1168 .with_msg(format!("Expected {expected}: {message}"))
1169 .with_input(serde_json::Value::String(value)),
1170 )
1171 .into_response(),
1172 Self::DeserializeError { message } => ValidationErrors::single(
1173 ValidationError::new(
1174 crate::error::error_types::VALUE_ERROR,
1175 vec![crate::error::LocItem::field("path")],
1176 )
1177 .with_msg(message),
1178 )
1179 .into_response(),
1180 }
1181 }
1182}
1183
1184impl<T: DeserializeOwned> FromRequest for Path<T> {
1185 type Error = PathExtractError;
1186
1187 async fn from_request(_ctx: &RequestContext, req: &mut Request) -> Result<Self, Self::Error> {
1188 let params = req
1190 .get_extension::<PathParams>()
1191 .ok_or(PathExtractError::MissingPathParams)?
1192 .clone();
1193
1194 let value = T::deserialize(PathDeserializer::new(¶ms))?;
1196
1197 Ok(Path(value))
1198 }
1199}
1200
1201struct PathDeserializer<'de> {
1212 params: &'de PathParams,
1213}
1214
1215impl<'de> PathDeserializer<'de> {
1216 fn new(params: &'de PathParams) -> Self {
1217 Self { params }
1218 }
1219}
1220
1221impl<'de> Deserializer<'de> for PathDeserializer<'de> {
1222 type Error = PathExtractError;
1223
1224 fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
1225 where
1226 V: Visitor<'de>,
1227 {
1228 self.deserialize_map(visitor)
1230 }
1231
1232 fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value, Self::Error>
1233 where
1234 V: Visitor<'de>,
1235 {
1236 let value = self.get_single_value()?;
1237 let b = value
1238 .parse::<bool>()
1239 .map_err(|_| PathExtractError::InvalidValue {
1240 name: self.get_first_name(),
1241 value: value.to_string(),
1242 expected: "boolean",
1243 message: "expected 'true' or 'false'".to_string(),
1244 })?;
1245 visitor.visit_bool(b)
1246 }
1247
1248 fn deserialize_i8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
1249 where
1250 V: Visitor<'de>,
1251 {
1252 let value = self.get_single_value()?;
1253 let n = value
1254 .parse::<i8>()
1255 .map_err(|e| PathExtractError::InvalidValue {
1256 name: self.get_first_name(),
1257 value: value.to_string(),
1258 expected: "i8",
1259 message: e.to_string(),
1260 })?;
1261 visitor.visit_i8(n)
1262 }
1263
1264 fn deserialize_i16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
1265 where
1266 V: Visitor<'de>,
1267 {
1268 let value = self.get_single_value()?;
1269 let n = value
1270 .parse::<i16>()
1271 .map_err(|e| PathExtractError::InvalidValue {
1272 name: self.get_first_name(),
1273 value: value.to_string(),
1274 expected: "i16",
1275 message: e.to_string(),
1276 })?;
1277 visitor.visit_i16(n)
1278 }
1279
1280 fn deserialize_i32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
1281 where
1282 V: Visitor<'de>,
1283 {
1284 let value = self.get_single_value()?;
1285 let n = value
1286 .parse::<i32>()
1287 .map_err(|e| PathExtractError::InvalidValue {
1288 name: self.get_first_name(),
1289 value: value.to_string(),
1290 expected: "i32",
1291 message: e.to_string(),
1292 })?;
1293 visitor.visit_i32(n)
1294 }
1295
1296 fn deserialize_i64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
1297 where
1298 V: Visitor<'de>,
1299 {
1300 let value = self.get_single_value()?;
1301 let n = value
1302 .parse::<i64>()
1303 .map_err(|e| PathExtractError::InvalidValue {
1304 name: self.get_first_name(),
1305 value: value.to_string(),
1306 expected: "i64",
1307 message: e.to_string(),
1308 })?;
1309 visitor.visit_i64(n)
1310 }
1311
1312 fn deserialize_i128<V>(self, visitor: V) -> Result<V::Value, Self::Error>
1313 where
1314 V: Visitor<'de>,
1315 {
1316 let value = self.get_single_value()?;
1317 let n = value
1318 .parse::<i128>()
1319 .map_err(|e| PathExtractError::InvalidValue {
1320 name: self.get_first_name(),
1321 value: value.to_string(),
1322 expected: "i128",
1323 message: e.to_string(),
1324 })?;
1325 visitor.visit_i128(n)
1326 }
1327
1328 fn deserialize_u8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
1329 where
1330 V: Visitor<'de>,
1331 {
1332 let value = self.get_single_value()?;
1333 let n = value
1334 .parse::<u8>()
1335 .map_err(|e| PathExtractError::InvalidValue {
1336 name: self.get_first_name(),
1337 value: value.to_string(),
1338 expected: "u8",
1339 message: e.to_string(),
1340 })?;
1341 visitor.visit_u8(n)
1342 }
1343
1344 fn deserialize_u16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
1345 where
1346 V: Visitor<'de>,
1347 {
1348 let value = self.get_single_value()?;
1349 let n = value
1350 .parse::<u16>()
1351 .map_err(|e| PathExtractError::InvalidValue {
1352 name: self.get_first_name(),
1353 value: value.to_string(),
1354 expected: "u16",
1355 message: e.to_string(),
1356 })?;
1357 visitor.visit_u16(n)
1358 }
1359
1360 fn deserialize_u32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
1361 where
1362 V: Visitor<'de>,
1363 {
1364 let value = self.get_single_value()?;
1365 let n = value
1366 .parse::<u32>()
1367 .map_err(|e| PathExtractError::InvalidValue {
1368 name: self.get_first_name(),
1369 value: value.to_string(),
1370 expected: "u32",
1371 message: e.to_string(),
1372 })?;
1373 visitor.visit_u32(n)
1374 }
1375
1376 fn deserialize_u64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
1377 where
1378 V: Visitor<'de>,
1379 {
1380 let value = self.get_single_value()?;
1381 let n = value
1382 .parse::<u64>()
1383 .map_err(|e| PathExtractError::InvalidValue {
1384 name: self.get_first_name(),
1385 value: value.to_string(),
1386 expected: "u64",
1387 message: e.to_string(),
1388 })?;
1389 visitor.visit_u64(n)
1390 }
1391
1392 fn deserialize_u128<V>(self, visitor: V) -> Result<V::Value, Self::Error>
1393 where
1394 V: Visitor<'de>,
1395 {
1396 let value = self.get_single_value()?;
1397 let n = value
1398 .parse::<u128>()
1399 .map_err(|e| PathExtractError::InvalidValue {
1400 name: self.get_first_name(),
1401 value: value.to_string(),
1402 expected: "u128",
1403 message: e.to_string(),
1404 })?;
1405 visitor.visit_u128(n)
1406 }
1407
1408 fn deserialize_f32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
1409 where
1410 V: Visitor<'de>,
1411 {
1412 let value = self.get_single_value()?;
1413 let n = value
1414 .parse::<f32>()
1415 .map_err(|e| PathExtractError::InvalidValue {
1416 name: self.get_first_name(),
1417 value: value.to_string(),
1418 expected: "f32",
1419 message: e.to_string(),
1420 })?;
1421 visitor.visit_f32(n)
1422 }
1423
1424 fn deserialize_f64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
1425 where
1426 V: Visitor<'de>,
1427 {
1428 let value = self.get_single_value()?;
1429 let n = value
1430 .parse::<f64>()
1431 .map_err(|e| PathExtractError::InvalidValue {
1432 name: self.get_first_name(),
1433 value: value.to_string(),
1434 expected: "f64",
1435 message: e.to_string(),
1436 })?;
1437 visitor.visit_f64(n)
1438 }
1439
1440 fn deserialize_char<V>(self, visitor: V) -> Result<V::Value, Self::Error>
1441 where
1442 V: Visitor<'de>,
1443 {
1444 let value = self.get_single_value()?;
1445 let mut chars = value.chars();
1446 let c = chars.next().ok_or_else(|| PathExtractError::InvalidValue {
1447 name: self.get_first_name(),
1448 value: value.to_string(),
1449 expected: "char",
1450 message: "empty string".to_string(),
1451 })?;
1452 if chars.next().is_some() {
1453 return Err(PathExtractError::InvalidValue {
1454 name: self.get_first_name(),
1455 value: value.to_string(),
1456 expected: "char",
1457 message: "expected single character".to_string(),
1458 });
1459 }
1460 visitor.visit_char(c)
1461 }
1462
1463 fn deserialize_str<V>(self, visitor: V) -> Result<V::Value, Self::Error>
1464 where
1465 V: Visitor<'de>,
1466 {
1467 let value = self.get_single_value()?;
1468 visitor.visit_str(value)
1469 }
1470
1471 fn deserialize_string<V>(self, visitor: V) -> Result<V::Value, Self::Error>
1472 where
1473 V: Visitor<'de>,
1474 {
1475 let value = self.get_single_value()?;
1476 visitor.visit_string(value.to_string())
1477 }
1478
1479 fn deserialize_bytes<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
1480 where
1481 V: Visitor<'de>,
1482 {
1483 Err(PathExtractError::DeserializeError {
1484 message: "bytes deserialization not supported for path parameters".to_string(),
1485 })
1486 }
1487
1488 fn deserialize_byte_buf<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
1489 where
1490 V: Visitor<'de>,
1491 {
1492 Err(PathExtractError::DeserializeError {
1493 message: "byte_buf deserialization not supported for path parameters".to_string(),
1494 })
1495 }
1496
1497 fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>
1498 where
1499 V: Visitor<'de>,
1500 {
1501 visitor.visit_some(self)
1503 }
1504
1505 fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value, Self::Error>
1506 where
1507 V: Visitor<'de>,
1508 {
1509 visitor.visit_unit()
1510 }
1511
1512 fn deserialize_unit_struct<V>(
1513 self,
1514 _name: &'static str,
1515 visitor: V,
1516 ) -> Result<V::Value, Self::Error>
1517 where
1518 V: Visitor<'de>,
1519 {
1520 visitor.visit_unit()
1521 }
1522
1523 fn deserialize_newtype_struct<V>(
1524 self,
1525 _name: &'static str,
1526 visitor: V,
1527 ) -> Result<V::Value, Self::Error>
1528 where
1529 V: Visitor<'de>,
1530 {
1531 visitor.visit_newtype_struct(self)
1532 }
1533
1534 fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value, Self::Error>
1535 where
1536 V: Visitor<'de>,
1537 {
1538 visitor.visit_seq(PathSeqAccess::new(self.params))
1539 }
1540
1541 fn deserialize_tuple<V>(self, _len: usize, visitor: V) -> Result<V::Value, Self::Error>
1542 where
1543 V: Visitor<'de>,
1544 {
1545 visitor.visit_seq(PathSeqAccess::new(self.params))
1546 }
1547
1548 fn deserialize_tuple_struct<V>(
1549 self,
1550 _name: &'static str,
1551 _len: usize,
1552 visitor: V,
1553 ) -> Result<V::Value, Self::Error>
1554 where
1555 V: Visitor<'de>,
1556 {
1557 visitor.visit_seq(PathSeqAccess::new(self.params))
1558 }
1559
1560 fn deserialize_map<V>(self, visitor: V) -> Result<V::Value, Self::Error>
1561 where
1562 V: Visitor<'de>,
1563 {
1564 visitor.visit_map(PathMapAccess::new(self.params))
1565 }
1566
1567 fn deserialize_struct<V>(
1568 self,
1569 _name: &'static str,
1570 _fields: &'static [&'static str],
1571 visitor: V,
1572 ) -> Result<V::Value, Self::Error>
1573 where
1574 V: Visitor<'de>,
1575 {
1576 visitor.visit_map(PathMapAccess::new(self.params))
1577 }
1578
1579 fn deserialize_enum<V>(
1580 self,
1581 _name: &'static str,
1582 _variants: &'static [&'static str],
1583 visitor: V,
1584 ) -> Result<V::Value, Self::Error>
1585 where
1586 V: Visitor<'de>,
1587 {
1588 let value = self.get_single_value()?;
1589 visitor.visit_enum(value.into_deserializer())
1590 }
1591
1592 fn deserialize_identifier<V>(self, visitor: V) -> Result<V::Value, Self::Error>
1593 where
1594 V: Visitor<'de>,
1595 {
1596 self.deserialize_str(visitor)
1597 }
1598
1599 fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
1600 where
1601 V: Visitor<'de>,
1602 {
1603 visitor.visit_unit()
1604 }
1605}
1606
1607impl PathDeserializer<'_> {
1608 fn get_single_value(&self) -> Result<&str, PathExtractError> {
1609 self.params
1610 .0
1611 .first()
1612 .map(|(_, v)| v.as_str())
1613 .ok_or_else(|| PathExtractError::DeserializeError {
1614 message: "no path parameters available".to_string(),
1615 })
1616 }
1617
1618 fn get_first_name(&self) -> String {
1619 self.params
1620 .0
1621 .first()
1622 .map_or_else(|| "unknown".to_string(), |(n, _)| n.clone())
1623 }
1624}
1625
1626impl de::Error for PathExtractError {
1627 fn custom<T: fmt::Display>(msg: T) -> Self {
1628 PathExtractError::DeserializeError {
1629 message: msg.to_string(),
1630 }
1631 }
1632}
1633
/// Serde `SeqAccess` over the captured path parameters, yielded in stored order.
struct PathSeqAccess<'de> {
    /// The captured parameters being walked.
    params: &'de PathParams,
    /// Position of the next parameter to yield.
    index: usize,
}
1639
1640impl<'de> PathSeqAccess<'de> {
1641 fn new(params: &'de PathParams) -> Self {
1642 Self { params, index: 0 }
1643 }
1644}
1645
1646impl<'de> SeqAccess<'de> for PathSeqAccess<'de> {
1647 type Error = PathExtractError;
1648
1649 fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, Self::Error>
1650 where
1651 T: de::DeserializeSeed<'de>,
1652 {
1653 if self.index >= self.params.0.len() {
1654 return Ok(None);
1655 }
1656
1657 let (name, value) = &self.params.0[self.index];
1658 self.index += 1;
1659
1660 seed.deserialize(PathValueDeserializer::new(name, value))
1661 .map(Some)
1662 }
1663}
1664
/// Serde `MapAccess` yielding each captured path parameter as a `name -> value` entry.
struct PathMapAccess<'de> {
    /// The captured parameters being walked.
    params: &'de PathParams,
    /// Position of the current entry; advanced after its value is read.
    index: usize,
}
1670
1671impl<'de> PathMapAccess<'de> {
1672 fn new(params: &'de PathParams) -> Self {
1673 Self { params, index: 0 }
1674 }
1675}
1676
1677impl<'de> MapAccess<'de> for PathMapAccess<'de> {
1678 type Error = PathExtractError;
1679
1680 fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>, Self::Error>
1681 where
1682 K: de::DeserializeSeed<'de>,
1683 {
1684 if self.index >= self.params.0.len() {
1685 return Ok(None);
1686 }
1687
1688 let (name, _) = &self.params.0[self.index];
1689 seed.deserialize(name.as_str().into_deserializer())
1690 .map(Some)
1691 }
1692
1693 fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value, Self::Error>
1694 where
1695 V: de::DeserializeSeed<'de>,
1696 {
1697 let (name, value) = &self.params.0[self.index];
1698 self.index += 1;
1699
1700 seed.deserialize(PathValueDeserializer::new(name, value))
1701 }
1702}
1703
/// Deserializer for one path parameter's raw value.
struct PathValueDeserializer<'de> {
    /// Parameter name; only used when building `InvalidValue` errors.
    name: &'de str,
    /// Raw parameter value that all `deserialize_*` methods read.
    value: &'de str,
}
1709
1710impl<'de> PathValueDeserializer<'de> {
1711 fn new(name: &'de str, value: &'de str) -> Self {
1712 Self { name, value }
1713 }
1714}
1715
1716impl<'de> Deserializer<'de> for PathValueDeserializer<'de> {
1717 type Error = PathExtractError;
1718
1719 fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
1720 where
1721 V: Visitor<'de>,
1722 {
1723 visitor.visit_str(self.value)
1725 }
1726
1727 fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value, Self::Error>
1728 where
1729 V: Visitor<'de>,
1730 {
1731 let b = self
1732 .value
1733 .parse::<bool>()
1734 .map_err(|_| PathExtractError::InvalidValue {
1735 name: self.name.to_string(),
1736 value: self.value.to_string(),
1737 expected: "boolean",
1738 message: "expected 'true' or 'false'".to_string(),
1739 })?;
1740 visitor.visit_bool(b)
1741 }
1742
1743 fn deserialize_i8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
1744 where
1745 V: Visitor<'de>,
1746 {
1747 let n = self
1748 .value
1749 .parse::<i8>()
1750 .map_err(|e| PathExtractError::InvalidValue {
1751 name: self.name.to_string(),
1752 value: self.value.to_string(),
1753 expected: "i8",
1754 message: e.to_string(),
1755 })?;
1756 visitor.visit_i8(n)
1757 }
1758
1759 fn deserialize_i16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
1760 where
1761 V: Visitor<'de>,
1762 {
1763 let n = self
1764 .value
1765 .parse::<i16>()
1766 .map_err(|e| PathExtractError::InvalidValue {
1767 name: self.name.to_string(),
1768 value: self.value.to_string(),
1769 expected: "i16",
1770 message: e.to_string(),
1771 })?;
1772 visitor.visit_i16(n)
1773 }
1774
1775 fn deserialize_i32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
1776 where
1777 V: Visitor<'de>,
1778 {
1779 let n = self
1780 .value
1781 .parse::<i32>()
1782 .map_err(|e| PathExtractError::InvalidValue {
1783 name: self.name.to_string(),
1784 value: self.value.to_string(),
1785 expected: "i32",
1786 message: e.to_string(),
1787 })?;
1788 visitor.visit_i32(n)
1789 }
1790
1791 fn deserialize_i64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
1792 where
1793 V: Visitor<'de>,
1794 {
1795 let n = self
1796 .value
1797 .parse::<i64>()
1798 .map_err(|e| PathExtractError::InvalidValue {
1799 name: self.name.to_string(),
1800 value: self.value.to_string(),
1801 expected: "i64",
1802 message: e.to_string(),
1803 })?;
1804 visitor.visit_i64(n)
1805 }
1806
1807 fn deserialize_i128<V>(self, visitor: V) -> Result<V::Value, Self::Error>
1808 where
1809 V: Visitor<'de>,
1810 {
1811 let n = self
1812 .value
1813 .parse::<i128>()
1814 .map_err(|e| PathExtractError::InvalidValue {
1815 name: self.name.to_string(),
1816 value: self.value.to_string(),
1817 expected: "i128",
1818 message: e.to_string(),
1819 })?;
1820 visitor.visit_i128(n)
1821 }
1822
1823 fn deserialize_u8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
1824 where
1825 V: Visitor<'de>,
1826 {
1827 let n = self
1828 .value
1829 .parse::<u8>()
1830 .map_err(|e| PathExtractError::InvalidValue {
1831 name: self.name.to_string(),
1832 value: self.value.to_string(),
1833 expected: "u8",
1834 message: e.to_string(),
1835 })?;
1836 visitor.visit_u8(n)
1837 }
1838
1839 fn deserialize_u16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
1840 where
1841 V: Visitor<'de>,
1842 {
1843 let n = self
1844 .value
1845 .parse::<u16>()
1846 .map_err(|e| PathExtractError::InvalidValue {
1847 name: self.name.to_string(),
1848 value: self.value.to_string(),
1849 expected: "u16",
1850 message: e.to_string(),
1851 })?;
1852 visitor.visit_u16(n)
1853 }
1854
1855 fn deserialize_u32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
1856 where
1857 V: Visitor<'de>,
1858 {
1859 let n = self
1860 .value
1861 .parse::<u32>()
1862 .map_err(|e| PathExtractError::InvalidValue {
1863 name: self.name.to_string(),
1864 value: self.value.to_string(),
1865 expected: "u32",
1866 message: e.to_string(),
1867 })?;
1868 visitor.visit_u32(n)
1869 }
1870
1871 fn deserialize_u64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
1872 where
1873 V: Visitor<'de>,
1874 {
1875 let n = self
1876 .value
1877 .parse::<u64>()
1878 .map_err(|e| PathExtractError::InvalidValue {
1879 name: self.name.to_string(),
1880 value: self.value.to_string(),
1881 expected: "u64",
1882 message: e.to_string(),
1883 })?;
1884 visitor.visit_u64(n)
1885 }
1886
1887 fn deserialize_u128<V>(self, visitor: V) -> Result<V::Value, Self::Error>
1888 where
1889 V: Visitor<'de>,
1890 {
1891 let n = self
1892 .value
1893 .parse::<u128>()
1894 .map_err(|e| PathExtractError::InvalidValue {
1895 name: self.name.to_string(),
1896 value: self.value.to_string(),
1897 expected: "u128",
1898 message: e.to_string(),
1899 })?;
1900 visitor.visit_u128(n)
1901 }
1902
1903 fn deserialize_f32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
1904 where
1905 V: Visitor<'de>,
1906 {
1907 let n = self
1908 .value
1909 .parse::<f32>()
1910 .map_err(|e| PathExtractError::InvalidValue {
1911 name: self.name.to_string(),
1912 value: self.value.to_string(),
1913 expected: "f32",
1914 message: e.to_string(),
1915 })?;
1916 visitor.visit_f32(n)
1917 }
1918
1919 fn deserialize_f64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
1920 where
1921 V: Visitor<'de>,
1922 {
1923 let n = self
1924 .value
1925 .parse::<f64>()
1926 .map_err(|e| PathExtractError::InvalidValue {
1927 name: self.name.to_string(),
1928 value: self.value.to_string(),
1929 expected: "f64",
1930 message: e.to_string(),
1931 })?;
1932 visitor.visit_f64(n)
1933 }
1934
1935 fn deserialize_char<V>(self, visitor: V) -> Result<V::Value, Self::Error>
1936 where
1937 V: Visitor<'de>,
1938 {
1939 let mut chars = self.value.chars();
1940 let c = chars.next().ok_or_else(|| PathExtractError::InvalidValue {
1941 name: self.name.to_string(),
1942 value: self.value.to_string(),
1943 expected: "char",
1944 message: "empty string".to_string(),
1945 })?;
1946 if chars.next().is_some() {
1947 return Err(PathExtractError::InvalidValue {
1948 name: self.name.to_string(),
1949 value: self.value.to_string(),
1950 expected: "char",
1951 message: "expected single character".to_string(),
1952 });
1953 }
1954 visitor.visit_char(c)
1955 }
1956
1957 fn deserialize_str<V>(self, visitor: V) -> Result<V::Value, Self::Error>
1958 where
1959 V: Visitor<'de>,
1960 {
1961 visitor.visit_str(self.value)
1962 }
1963
1964 fn deserialize_string<V>(self, visitor: V) -> Result<V::Value, Self::Error>
1965 where
1966 V: Visitor<'de>,
1967 {
1968 visitor.visit_string(self.value.to_string())
1969 }
1970
1971 fn deserialize_bytes<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
1972 where
1973 V: Visitor<'de>,
1974 {
1975 Err(PathExtractError::DeserializeError {
1976 message: "bytes deserialization not supported for path parameters".to_string(),
1977 })
1978 }
1979
1980 fn deserialize_byte_buf<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
1981 where
1982 V: Visitor<'de>,
1983 {
1984 Err(PathExtractError::DeserializeError {
1985 message: "byte_buf deserialization not supported for path parameters".to_string(),
1986 })
1987 }
1988
1989 fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>
1990 where
1991 V: Visitor<'de>,
1992 {
1993 visitor.visit_some(self)
1994 }
1995
1996 fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value, Self::Error>
1997 where
1998 V: Visitor<'de>,
1999 {
2000 visitor.visit_unit()
2001 }
2002
2003 fn deserialize_unit_struct<V>(
2004 self,
2005 _name: &'static str,
2006 visitor: V,
2007 ) -> Result<V::Value, Self::Error>
2008 where
2009 V: Visitor<'de>,
2010 {
2011 visitor.visit_unit()
2012 }
2013
2014 fn deserialize_newtype_struct<V>(
2015 self,
2016 _name: &'static str,
2017 visitor: V,
2018 ) -> Result<V::Value, Self::Error>
2019 where
2020 V: Visitor<'de>,
2021 {
2022 visitor.visit_newtype_struct(self)
2023 }
2024
2025 fn deserialize_seq<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
2026 where
2027 V: Visitor<'de>,
2028 {
2029 Err(PathExtractError::DeserializeError {
2030 message: "sequence deserialization not supported for single path parameter".to_string(),
2031 })
2032 }
2033
2034 fn deserialize_tuple<V>(self, _len: usize, _visitor: V) -> Result<V::Value, Self::Error>
2035 where
2036 V: Visitor<'de>,
2037 {
2038 Err(PathExtractError::DeserializeError {
2039 message: "tuple deserialization not supported for single path parameter".to_string(),
2040 })
2041 }
2042
2043 fn deserialize_tuple_struct<V>(
2044 self,
2045 _name: &'static str,
2046 _len: usize,
2047 _visitor: V,
2048 ) -> Result<V::Value, Self::Error>
2049 where
2050 V: Visitor<'de>,
2051 {
2052 Err(PathExtractError::DeserializeError {
2053 message: "tuple struct deserialization not supported for single path parameter"
2054 .to_string(),
2055 })
2056 }
2057
2058 fn deserialize_map<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
2059 where
2060 V: Visitor<'de>,
2061 {
2062 Err(PathExtractError::DeserializeError {
2063 message: "map deserialization not supported for single path parameter".to_string(),
2064 })
2065 }
2066
2067 fn deserialize_struct<V>(
2068 self,
2069 _name: &'static str,
2070 _fields: &'static [&'static str],
2071 _visitor: V,
2072 ) -> Result<V::Value, Self::Error>
2073 where
2074 V: Visitor<'de>,
2075 {
2076 Err(PathExtractError::DeserializeError {
2077 message: "struct deserialization not supported for single path parameter".to_string(),
2078 })
2079 }
2080
2081 fn deserialize_enum<V>(
2082 self,
2083 _name: &'static str,
2084 _variants: &'static [&'static str],
2085 visitor: V,
2086 ) -> Result<V::Value, Self::Error>
2087 where
2088 V: Visitor<'de>,
2089 {
2090 visitor.visit_enum(self.value.into_deserializer())
2091 }
2092
2093 fn deserialize_identifier<V>(self, visitor: V) -> Result<V::Value, Self::Error>
2094 where
2095 V: Visitor<'de>,
2096 {
2097 visitor.visit_str(self.value)
2098 }
2099
2100 fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
2101 where
2102 V: Visitor<'de>,
2103 {
2104 visitor.visit_unit()
2105 }
2106}
2107
/// Extractor that deserializes the request's query string into `T`.
///
/// Dereferences to the inner `T`; use `into_inner` to take ownership.
#[derive(Debug, Clone, Copy, Default)]
pub struct Query<T>(pub T);
2156
2157impl<T> Query<T> {
2158 pub fn new(value: T) -> Self {
2160 Self(value)
2161 }
2162
2163 pub fn into_inner(self) -> T {
2165 self.0
2166 }
2167}
2168
2169impl<T> Deref for Query<T> {
2170 type Target = T;
2171
2172 fn deref(&self) -> &Self::Target {
2173 &self.0
2174 }
2175}
2176
2177impl<T> DerefMut for Query<T> {
2178 fn deref_mut(&mut self) -> &mut Self::Target {
2179 &mut self.0
2180 }
2181}
2182
/// Failure to extract query parameters into the requested type.
#[derive(Debug)]
pub enum QueryExtractError {
    /// A required query parameter was absent.
    MissingParam { name: String },
    /// A query parameter was present but could not be parsed as `expected`.
    InvalidValue {
        /// Parameter name (the scalar paths in `QueryDeserializer` use `"value"`).
        name: String,
        /// The raw value that failed to parse.
        value: String,
        /// Short name of the expected type, e.g. `"u64"`.
        expected: &'static str,
        /// Parser error text or validation message.
        message: String,
    },
    /// Any other deserialization failure, including serde `custom` errors.
    DeserializeError { message: String },
}
2198
2199impl fmt::Display for QueryExtractError {
2200 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
2201 match self {
2202 Self::MissingParam { name } => {
2203 write!(f, "Missing required query parameter: {}", name)
2204 }
2205 Self::InvalidValue {
2206 name,
2207 value,
2208 expected,
2209 message,
2210 } => {
2211 write!(
2212 f,
2213 "Invalid value '{}' for query parameter '{}' (expected {}): {}",
2214 value, name, expected, message
2215 )
2216 }
2217 Self::DeserializeError { message } => {
2218 write!(f, "Query deserialization error: {}", message)
2219 }
2220 }
2221 }
2222}
2223
// Standard error marker: messages come from the `Display` impl; no `source()` is reported.
impl std::error::Error for QueryExtractError {}
2225
2226impl de::Error for QueryExtractError {
2227 fn custom<T: fmt::Display>(msg: T) -> Self {
2228 Self::DeserializeError {
2229 message: msg.to_string(),
2230 }
2231 }
2232}
2233
2234impl IntoResponse for QueryExtractError {
2235 fn into_response(self) -> crate::response::Response {
2236 match self {
2237 Self::MissingParam { name } => ValidationErrors::single(
2238 ValidationError::missing(crate::error::loc::query(&name))
2239 .with_msg("Query parameter is required"),
2240 )
2241 .into_response(),
2242 Self::InvalidValue {
2243 name,
2244 value,
2245 expected,
2246 message,
2247 } => ValidationErrors::single(
2248 ValidationError::type_error(crate::error::loc::query(&name), &expected)
2249 .with_msg(format!("Expected {expected}: {message}"))
2250 .with_input(serde_json::Value::String(value)),
2251 )
2252 .into_response(),
2253 Self::DeserializeError { message } => ValidationErrors::single(
2254 ValidationError::new(
2255 crate::error::error_types::VALUE_ERROR,
2256 vec![crate::error::LocItem::field("query")],
2257 )
2258 .with_msg(message),
2259 )
2260 .into_response(),
2261 }
2262 }
2263}
2264
/// Decoded query-string pairs in their original order; repeated keys are kept.
#[derive(Debug, Clone, Default)]
pub struct QueryParams {
    /// `(key, value)` pairs, already percent-decoded when built by `parse`.
    params: Vec<(String, String)>,
}
2274
2275impl QueryParams {
2276 pub fn new() -> Self {
2278 Self { params: Vec::new() }
2279 }
2280
2281 pub fn from_pairs(pairs: Vec<(String, String)>) -> Self {
2283 Self { params: pairs }
2284 }
2285
2286 pub fn parse(query: &str) -> Self {
2288 let pairs: Vec<(String, String)> = query
2289 .split('&')
2290 .filter(|s| !s.is_empty())
2291 .map(|pair| {
2292 if let Some(eq_pos) = pair.find('=') {
2293 let key = &pair[..eq_pos];
2294 let value = &pair[eq_pos + 1..];
2295 (
2296 percent_decode(key).into_owned(),
2297 percent_decode(value).into_owned(),
2298 )
2299 } else {
2300 (percent_decode(pair).into_owned(), String::new())
2302 }
2303 })
2304 .collect();
2305 Self { params: pairs }
2306 }
2307
2308 pub fn get(&self, key: &str) -> Option<&str> {
2310 self.params
2311 .iter()
2312 .find(|(k, _)| k == key)
2313 .map(|(_, v)| v.as_str())
2314 }
2315
2316 pub fn get_all(&self, key: &str) -> Vec<&str> {
2318 self.params
2319 .iter()
2320 .filter(|(k, _)| k == key)
2321 .map(|(_, v)| v.as_str())
2322 .collect()
2323 }
2324
2325 pub fn contains(&self, key: &str) -> bool {
2327 self.params.iter().any(|(k, _)| k == key)
2328 }
2329
2330 pub fn pairs(&self) -> &[(String, String)] {
2332 &self.params
2333 }
2334
2335 pub fn keys(&self) -> impl Iterator<Item = &str> {
2337 let mut seen = std::collections::HashSet::new();
2338 self.params.iter().filter_map(move |(k, _)| {
2339 if seen.insert(k.as_str()) {
2340 Some(k.as_str())
2341 } else {
2342 None
2343 }
2344 })
2345 }
2346
2347 pub fn len(&self) -> usize {
2349 self.params.len()
2350 }
2351
2352 pub fn is_empty(&self) -> bool {
2354 self.params.is_empty()
2355 }
2356}
2357
2358fn percent_decode(s: &str) -> std::borrow::Cow<'_, str> {
2363 use std::borrow::Cow;
2364
2365 if !s.contains('%') && !s.contains('+') {
2367 return Cow::Borrowed(s);
2368 }
2369
2370 let mut result = Vec::with_capacity(s.len());
2371 let bytes = s.as_bytes();
2372 let mut i = 0;
2373
2374 while i < bytes.len() {
2375 match bytes[i] {
2376 b'%' if i + 2 < bytes.len() => {
2377 if let (Some(hi), Some(lo)) = (hex_digit(bytes[i + 1]), hex_digit(bytes[i + 2])) {
2379 result.push(hi << 4 | lo);
2380 i += 3;
2381 } else {
2382 result.push(b'%');
2384 i += 1;
2385 }
2386 }
2387 b'+' => {
2388 result.push(b' ');
2390 i += 1;
2391 }
2392 b => {
2393 result.push(b);
2394 i += 1;
2395 }
2396 }
2397 }
2398
2399 Cow::Owned(String::from_utf8_lossy(&result).into_owned())
2400}
2401
/// Value (`0..=15`) of an ASCII hexadecimal digit (`0-9`, `a-f`, `A-F`), or `None`.
fn hex_digit(b: u8) -> Option<u8> {
    // `to_digit(16)` accepts exactly these ASCII digits. Bytes >= 0x80 map to Latin-1 code
    // points, which it rejects. The result is < 16, so narrowing to `u8` is lossless.
    char::from(b).to_digit(16).map(|digit| digit as u8)
}
2411
2412impl<T: DeserializeOwned> FromRequest for Query<T> {
2413 type Error = QueryExtractError;
2414
2415 async fn from_request(_ctx: &RequestContext, req: &mut Request) -> Result<Self, Self::Error> {
2416 let params = match req.get_extension::<QueryParams>() {
2418 Some(p) => p.clone(),
2419 None => {
2420 let query_str = req.query().unwrap_or("");
2422 QueryParams::parse(query_str)
2423 }
2424 };
2425
2426 let value = T::deserialize(QueryDeserializer::new(¶ms))?;
2428
2429 Ok(Query(value))
2430 }
2431}
2432
/// Page number used by the default `PaginationConfig` when `page` is absent.
pub const DEFAULT_PAGE: u64 = 1;
/// Page size used by the default `PaginationConfig` when `per_page` is absent.
pub const DEFAULT_PER_PAGE: u64 = 20;
/// Largest `per_page` accepted by the default `PaginationConfig`; larger values are rejected.
pub const MAX_PER_PAGE: u64 = 100;
2443
/// Pagination defaults and limits.
///
/// The `Pagination` extractor reads this from the request extensions and falls back to
/// `PaginationConfig::default()` when none is present.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PaginationConfig {
    /// Page number used when `page` is absent from the query.
    default_page: u64,
    /// Page size used when `per_page` is absent from the query.
    default_per_page: u64,
    /// Requests whose `per_page` exceeds this are rejected.
    max_per_page: u64,
}
2451
2452impl Default for PaginationConfig {
2453 fn default() -> Self {
2454 Self {
2455 default_page: DEFAULT_PAGE,
2456 default_per_page: DEFAULT_PER_PAGE,
2457 max_per_page: MAX_PER_PAGE,
2458 }
2459 }
2460}
2461
2462impl PaginationConfig {
2463 #[must_use]
2465 pub fn new() -> Self {
2466 Self::default()
2467 }
2468
2469 #[must_use]
2471 pub fn default_page(mut self, page: u64) -> Self {
2472 self.default_page = page;
2473 self
2474 }
2475
2476 #[must_use]
2478 pub fn default_per_page(mut self, per_page: u64) -> Self {
2479 self.default_per_page = per_page;
2480 self
2481 }
2482
2483 #[must_use]
2485 pub fn max_per_page(mut self, max: u64) -> Self {
2486 self.max_per_page = max;
2487 self
2488 }
2489}
2490
/// Raw `page` / `per_page` query parameters, before defaults and validation are applied.
#[derive(serde::Deserialize)]
struct PaginationParams {
    page: Option<u64>,
    per_page: Option<u64>,
}
2496
/// Validated pagination request: `page >= 1` and `1 <= per_page <= max_per_page` when
/// produced by the extractor.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Pagination {
    page: u64,
    per_page: u64,
}
2508
2509impl Pagination {
2510 #[must_use]
2511 pub fn page(&self) -> u64 {
2512 self.page
2513 }
2514
2515 #[must_use]
2516 pub fn per_page(&self) -> u64 {
2517 self.per_page
2518 }
2519
2520 #[must_use]
2522 pub fn limit(&self) -> u64 {
2523 self.per_page
2524 }
2525
2526 #[must_use]
2528 pub fn offset(&self) -> u64 {
2529 self.page.saturating_sub(1).saturating_mul(self.per_page)
2530 }
2531}
2532
2533impl FromRequest for Pagination {
2534 type Error = QueryExtractError;
2535
2536 async fn from_request(ctx: &RequestContext, req: &mut Request) -> Result<Self, Self::Error> {
2537 let config = req
2539 .get_extension::<PaginationConfig>()
2540 .copied()
2541 .unwrap_or_default();
2542
2543 let Query(params) = Query::<PaginationParams>::from_request(ctx, req).await?;
2544
2545 let page = params.page.unwrap_or(config.default_page);
2546 if page == 0 {
2547 return Err(QueryExtractError::InvalidValue {
2548 name: "page".to_string(),
2549 value: "0".to_string(),
2550 expected: "u64",
2551 message: "must be >= 1".to_string(),
2552 });
2553 }
2554
2555 let per_page = params.per_page.unwrap_or(config.default_per_page);
2556 if per_page == 0 {
2557 return Err(QueryExtractError::InvalidValue {
2558 name: "per_page".to_string(),
2559 value: "0".to_string(),
2560 expected: "u64",
2561 message: "must be >= 1".to_string(),
2562 });
2563 }
2564 if per_page > config.max_per_page {
2565 return Err(QueryExtractError::InvalidValue {
2566 name: "per_page".to_string(),
2567 value: per_page.to_string(),
2568 expected: "u64",
2569 message: format!("must be <= {}", config.max_per_page),
2570 });
2571 }
2572
2573 Ok(Self { page, per_page })
2574 }
2575}
2576
/// Serializable page of results with totals for building pagination responses.
#[derive(Debug, Clone, serde::Serialize)]
pub struct Page<T> {
    /// Items on this page.
    pub items: Vec<T>,
    /// Total number of items across all pages.
    pub total: u64,
    /// Page number of this page.
    pub page: u64,
    /// Page size used to compute `total_pages`.
    pub per_page: u64,
    /// `ceil(total / per_page)` as computed by `Page::new` (0 when `per_page` is 0).
    pub total_pages: u64,
}
2586
2587impl<T> Page<T> {
2588 #[must_use]
2589 pub fn new(items: Vec<T>, total: u64, page: u64, per_page: u64) -> Self {
2590 let total_pages = if per_page == 0 {
2591 0
2592 } else {
2593 total.div_ceil(per_page)
2594 };
2595 Self {
2596 items,
2597 total,
2598 page,
2599 per_page,
2600 total_pages,
2601 }
2602 }
2603}
2604
/// Serde deserializer over parsed query parameters.
///
/// Struct/map targets see every pair. Scalar targets read a single value and report errors
/// under the parameter name `"value"`.
struct QueryDeserializer<'de> {
    params: &'de QueryParams,
}
2619
2620impl<'de> QueryDeserializer<'de> {
2621 fn new(params: &'de QueryParams) -> Self {
2622 Self { params }
2623 }
2624}
2625
2626impl<'de> Deserializer<'de> for QueryDeserializer<'de> {
2627 type Error = QueryExtractError;
2628
2629 fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
2630 where
2631 V: Visitor<'de>,
2632 {
2633 self.deserialize_map(visitor)
2635 }
2636
2637 fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value, Self::Error>
2638 where
2639 V: Visitor<'de>,
2640 {
2641 let value = self
2642 .params
2643 .pairs()
2644 .first()
2645 .map(|(_, v)| v.as_str())
2646 .ok_or_else(|| QueryExtractError::MissingParam {
2647 name: "value".to_string(),
2648 })?;
2649
2650 let b = parse_bool(value).map_err(|msg| QueryExtractError::InvalidValue {
2651 name: "value".to_string(),
2652 value: value.to_string(),
2653 expected: "bool",
2654 message: msg,
2655 })?;
2656 visitor.visit_bool(b)
2657 }
2658
2659 fn deserialize_i8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
2660 where
2661 V: Visitor<'de>,
2662 {
2663 let value = self.get_single_value()?;
2664 let n = value
2665 .parse::<i8>()
2666 .map_err(|e| QueryExtractError::InvalidValue {
2667 name: "value".to_string(),
2668 value: value.to_string(),
2669 expected: "i8",
2670 message: e.to_string(),
2671 })?;
2672 visitor.visit_i8(n)
2673 }
2674
2675 fn deserialize_i16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
2676 where
2677 V: Visitor<'de>,
2678 {
2679 let value = self.get_single_value()?;
2680 let n = value
2681 .parse::<i16>()
2682 .map_err(|e| QueryExtractError::InvalidValue {
2683 name: "value".to_string(),
2684 value: value.to_string(),
2685 expected: "i16",
2686 message: e.to_string(),
2687 })?;
2688 visitor.visit_i16(n)
2689 }
2690
2691 fn deserialize_i32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
2692 where
2693 V: Visitor<'de>,
2694 {
2695 let value = self.get_single_value()?;
2696 let n = value
2697 .parse::<i32>()
2698 .map_err(|e| QueryExtractError::InvalidValue {
2699 name: "value".to_string(),
2700 value: value.to_string(),
2701 expected: "i32",
2702 message: e.to_string(),
2703 })?;
2704 visitor.visit_i32(n)
2705 }
2706
2707 fn deserialize_i64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
2708 where
2709 V: Visitor<'de>,
2710 {
2711 let value = self.get_single_value()?;
2712 let n = value
2713 .parse::<i64>()
2714 .map_err(|e| QueryExtractError::InvalidValue {
2715 name: "value".to_string(),
2716 value: value.to_string(),
2717 expected: "i64",
2718 message: e.to_string(),
2719 })?;
2720 visitor.visit_i64(n)
2721 }
2722
2723 fn deserialize_u8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
2724 where
2725 V: Visitor<'de>,
2726 {
2727 let value = self.get_single_value()?;
2728 let n = value
2729 .parse::<u8>()
2730 .map_err(|e| QueryExtractError::InvalidValue {
2731 name: "value".to_string(),
2732 value: value.to_string(),
2733 expected: "u8",
2734 message: e.to_string(),
2735 })?;
2736 visitor.visit_u8(n)
2737 }
2738
2739 fn deserialize_u16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
2740 where
2741 V: Visitor<'de>,
2742 {
2743 let value = self.get_single_value()?;
2744 let n = value
2745 .parse::<u16>()
2746 .map_err(|e| QueryExtractError::InvalidValue {
2747 name: "value".to_string(),
2748 value: value.to_string(),
2749 expected: "u16",
2750 message: e.to_string(),
2751 })?;
2752 visitor.visit_u16(n)
2753 }
2754
2755 fn deserialize_u32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
2756 where
2757 V: Visitor<'de>,
2758 {
2759 let value = self.get_single_value()?;
2760 let n = value
2761 .parse::<u32>()
2762 .map_err(|e| QueryExtractError::InvalidValue {
2763 name: "value".to_string(),
2764 value: value.to_string(),
2765 expected: "u32",
2766 message: e.to_string(),
2767 })?;
2768 visitor.visit_u32(n)
2769 }
2770
2771 fn deserialize_u64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
2772 where
2773 V: Visitor<'de>,
2774 {
2775 let value = self.get_single_value()?;
2776 let n = value
2777 .parse::<u64>()
2778 .map_err(|e| QueryExtractError::InvalidValue {
2779 name: "value".to_string(),
2780 value: value.to_string(),
2781 expected: "u64",
2782 message: e.to_string(),
2783 })?;
2784 visitor.visit_u64(n)
2785 }
2786
2787 fn deserialize_f32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
2788 where
2789 V: Visitor<'de>,
2790 {
2791 let value = self.get_single_value()?;
2792 let n = value
2793 .parse::<f32>()
2794 .map_err(|e| QueryExtractError::InvalidValue {
2795 name: "value".to_string(),
2796 value: value.to_string(),
2797 expected: "f32",
2798 message: e.to_string(),
2799 })?;
2800 visitor.visit_f32(n)
2801 }
2802
2803 fn deserialize_f64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
2804 where
2805 V: Visitor<'de>,
2806 {
2807 let value = self.get_single_value()?;
2808 let n = value
2809 .parse::<f64>()
2810 .map_err(|e| QueryExtractError::InvalidValue {
2811 name: "value".to_string(),
2812 value: value.to_string(),
2813 expected: "f64",
2814 message: e.to_string(),
2815 })?;
2816 visitor.visit_f64(n)
2817 }
2818
2819 fn deserialize_char<V>(self, visitor: V) -> Result<V::Value, Self::Error>
2820 where
2821 V: Visitor<'de>,
2822 {
2823 let value = self.get_single_value()?;
2824 let mut chars = value.chars();
2825 match (chars.next(), chars.next()) {
2826 (Some(c), None) => visitor.visit_char(c),
2827 _ => Err(QueryExtractError::InvalidValue {
2828 name: "value".to_string(),
2829 value: value.to_string(),
2830 expected: "char",
2831 message: "expected single character".to_string(),
2832 }),
2833 }
2834 }
2835
2836 fn deserialize_str<V>(self, visitor: V) -> Result<V::Value, Self::Error>
2837 where
2838 V: Visitor<'de>,
2839 {
2840 let value = self.get_single_value()?;
2841 visitor.visit_str(value)
2842 }
2843
2844 fn deserialize_string<V>(self, visitor: V) -> Result<V::Value, Self::Error>
2845 where
2846 V: Visitor<'de>,
2847 {
2848 let value = self.get_single_value()?;
2849 visitor.visit_string(value.to_owned())
2850 }
2851
2852 fn deserialize_bytes<V>(self, visitor: V) -> Result<V::Value, Self::Error>
2853 where
2854 V: Visitor<'de>,
2855 {
2856 let value = self.get_single_value()?;
2857 visitor.visit_bytes(value.as_bytes())
2858 }
2859
2860 fn deserialize_byte_buf<V>(self, visitor: V) -> Result<V::Value, Self::Error>
2861 where
2862 V: Visitor<'de>,
2863 {
2864 let value = self.get_single_value()?;
2865 visitor.visit_byte_buf(value.as_bytes().to_vec())
2866 }
2867
2868 fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>
2869 where
2870 V: Visitor<'de>,
2871 {
2872 if self.params.is_empty() {
2874 visitor.visit_none()
2875 } else {
2876 visitor.visit_some(self)
2877 }
2878 }
2879
2880 fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value, Self::Error>
2881 where
2882 V: Visitor<'de>,
2883 {
2884 visitor.visit_unit()
2885 }
2886
2887 fn deserialize_unit_struct<V>(
2888 self,
2889 _name: &'static str,
2890 visitor: V,
2891 ) -> Result<V::Value, Self::Error>
2892 where
2893 V: Visitor<'de>,
2894 {
2895 visitor.visit_unit()
2896 }
2897
2898 fn deserialize_newtype_struct<V>(
2899 self,
2900 _name: &'static str,
2901 visitor: V,
2902 ) -> Result<V::Value, Self::Error>
2903 where
2904 V: Visitor<'de>,
2905 {
2906 visitor.visit_newtype_struct(self)
2907 }
2908
2909 fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value, Self::Error>
2910 where
2911 V: Visitor<'de>,
2912 {
2913 let values: Vec<&str> = self
2915 .params
2916 .pairs()
2917 .iter()
2918 .map(|(_, v)| v.as_str())
2919 .collect();
2920 visitor.visit_seq(QuerySeqAccess::new(values))
2921 }
2922
2923 fn deserialize_tuple<V>(self, _len: usize, visitor: V) -> Result<V::Value, Self::Error>
2924 where
2925 V: Visitor<'de>,
2926 {
2927 let values: Vec<&str> = self
2929 .params
2930 .pairs()
2931 .iter()
2932 .map(|(_, v)| v.as_str())
2933 .collect();
2934 visitor.visit_seq(QuerySeqAccess::new(values))
2935 }
2936
2937 fn deserialize_tuple_struct<V>(
2938 self,
2939 _name: &'static str,
2940 _len: usize,
2941 visitor: V,
2942 ) -> Result<V::Value, Self::Error>
2943 where
2944 V: Visitor<'de>,
2945 {
2946 let values: Vec<&str> = self
2947 .params
2948 .pairs()
2949 .iter()
2950 .map(|(_, v)| v.as_str())
2951 .collect();
2952 visitor.visit_seq(QuerySeqAccess::new(values))
2953 }
2954
2955 fn deserialize_map<V>(self, visitor: V) -> Result<V::Value, Self::Error>
2956 where
2957 V: Visitor<'de>,
2958 {
2959 visitor.visit_map(QueryMapAccess::new(self.params))
2960 }
2961
2962 fn deserialize_struct<V>(
2963 self,
2964 _name: &'static str,
2965 _fields: &'static [&'static str],
2966 visitor: V,
2967 ) -> Result<V::Value, Self::Error>
2968 where
2969 V: Visitor<'de>,
2970 {
2971 self.deserialize_map(visitor)
2972 }
2973
2974 fn deserialize_enum<V>(
2975 self,
2976 _name: &'static str,
2977 _variants: &'static [&'static str],
2978 visitor: V,
2979 ) -> Result<V::Value, Self::Error>
2980 where
2981 V: Visitor<'de>,
2982 {
2983 let value = self.get_single_value()?;
2985 visitor.visit_enum(value.into_deserializer())
2986 }
2987
2988 fn deserialize_identifier<V>(self, visitor: V) -> Result<V::Value, Self::Error>
2989 where
2990 V: Visitor<'de>,
2991 {
2992 let value = self.get_single_value()?;
2993 visitor.visit_str(value)
2994 }
2995
2996 fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
2997 where
2998 V: Visitor<'de>,
2999 {
3000 visitor.visit_unit()
3001 }
3002}
3003
3004impl<'de> QueryDeserializer<'de> {
3005 fn get_single_value(&self) -> Result<&'de str, QueryExtractError> {
3006 self.params
3007 .pairs()
3008 .first()
3009 .map(|(_, v)| v.as_str())
3010 .ok_or_else(|| QueryExtractError::MissingParam {
3011 name: "value".to_string(),
3012 })
3013 }
3014}
3015
/// Parses a query-string boolean.
///
/// Matching is ASCII case-insensitive:
/// * `true`, `1`, `yes`, `on` give `true`.
/// * `false`, `0`, `no`, `off`, and the empty string give `false`.
///
/// # Errors
/// Returns a human-readable message for any other input.
fn parse_bool(s: &str) -> Result<bool, String> {
    // Compare in place rather than allocating a lowercased copy of the
    // (request-supplied) input on every call. All accepted words are ASCII,
    // so ASCII case folding matches exactly what full Unicode lowercasing did.
    const TRUTHY: [&str; 4] = ["true", "1", "yes", "on"];
    const FALSY: [&str; 5] = ["false", "0", "no", "off", ""];

    if TRUTHY.iter().any(|word| s.eq_ignore_ascii_case(word)) {
        Ok(true)
    } else if FALSY.iter().any(|word| s.eq_ignore_ascii_case(word)) {
        Ok(false)
    } else {
        Err(format!("cannot parse '{}' as boolean", s))
    }
}
3024
3025struct QuerySeqAccess<'de> {
3027 values: Vec<&'de str>,
3028 index: usize,
3029}
3030
3031impl<'de> QuerySeqAccess<'de> {
3032 fn new(values: Vec<&'de str>) -> Self {
3033 Self { values, index: 0 }
3034 }
3035}
3036
3037impl<'de> SeqAccess<'de> for QuerySeqAccess<'de> {
3038 type Error = QueryExtractError;
3039
3040 fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, Self::Error>
3041 where
3042 T: de::DeserializeSeed<'de>,
3043 {
3044 if self.index >= self.values.len() {
3045 return Ok(None);
3046 }
3047
3048 let value = self.values[self.index];
3049 self.index += 1;
3050
3051 seed.deserialize(QueryValueDeserializer::new(value, None))
3052 .map(Some)
3053 }
3054
3055 fn size_hint(&self) -> Option<usize> {
3056 Some(self.values.len() - self.index)
3057 }
3058}
3059
3060struct QueryMapAccess<'de> {
3062 params: &'de QueryParams,
3063 keys: Vec<&'de str>,
3064 index: usize,
3065}
3066
3067impl<'de> QueryMapAccess<'de> {
3068 fn new(params: &'de QueryParams) -> Self {
3069 let keys: Vec<&str> = params.keys().collect();
3070 Self {
3071 params,
3072 keys,
3073 index: 0,
3074 }
3075 }
3076}
3077
3078impl<'de> MapAccess<'de> for QueryMapAccess<'de> {
3079 type Error = QueryExtractError;
3080
3081 fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>, Self::Error>
3082 where
3083 K: de::DeserializeSeed<'de>,
3084 {
3085 if self.index >= self.keys.len() {
3086 return Ok(None);
3087 }
3088
3089 let key = self.keys[self.index];
3090 seed.deserialize(key.into_deserializer()).map(Some)
3091 }
3092
3093 fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value, Self::Error>
3094 where
3095 V: de::DeserializeSeed<'de>,
3096 {
3097 let key = self.keys[self.index];
3098 self.index += 1;
3099
3100 let values = self.params.get_all(key);
3102
3103 seed.deserialize(QueryFieldDeserializer::new(key, values))
3104 }
3105}
3106
3107struct QueryValueDeserializer<'de> {
3109 value: &'de str,
3110 name: Option<&'de str>,
3111}
3112
3113impl<'de> QueryValueDeserializer<'de> {
3114 fn new(value: &'de str, name: Option<&'de str>) -> Self {
3115 Self { value, name }
3116 }
3117
3118 fn field_name(&self) -> String {
3119 self.name.unwrap_or("value").to_string()
3120 }
3121}
3122
3123impl<'de> Deserializer<'de> for QueryValueDeserializer<'de> {
3124 type Error = QueryExtractError;
3125
3126 fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3127 where
3128 V: Visitor<'de>,
3129 {
3130 visitor.visit_str(self.value)
3131 }
3132
3133 fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3134 where
3135 V: Visitor<'de>,
3136 {
3137 let b = parse_bool(self.value).map_err(|msg| QueryExtractError::InvalidValue {
3138 name: self.field_name(),
3139 value: self.value.to_string(),
3140 expected: "bool",
3141 message: msg,
3142 })?;
3143 visitor.visit_bool(b)
3144 }
3145
3146 fn deserialize_i8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3147 where
3148 V: Visitor<'de>,
3149 {
3150 let n = self
3151 .value
3152 .parse::<i8>()
3153 .map_err(|e| QueryExtractError::InvalidValue {
3154 name: self.field_name(),
3155 value: self.value.to_string(),
3156 expected: "i8",
3157 message: e.to_string(),
3158 })?;
3159 visitor.visit_i8(n)
3160 }
3161
3162 fn deserialize_i16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3163 where
3164 V: Visitor<'de>,
3165 {
3166 let n = self
3167 .value
3168 .parse::<i16>()
3169 .map_err(|e| QueryExtractError::InvalidValue {
3170 name: self.field_name(),
3171 value: self.value.to_string(),
3172 expected: "i16",
3173 message: e.to_string(),
3174 })?;
3175 visitor.visit_i16(n)
3176 }
3177
3178 fn deserialize_i32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3179 where
3180 V: Visitor<'de>,
3181 {
3182 let n = self
3183 .value
3184 .parse::<i32>()
3185 .map_err(|e| QueryExtractError::InvalidValue {
3186 name: self.field_name(),
3187 value: self.value.to_string(),
3188 expected: "i32",
3189 message: e.to_string(),
3190 })?;
3191 visitor.visit_i32(n)
3192 }
3193
3194 fn deserialize_i64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3195 where
3196 V: Visitor<'de>,
3197 {
3198 let n = self
3199 .value
3200 .parse::<i64>()
3201 .map_err(|e| QueryExtractError::InvalidValue {
3202 name: self.field_name(),
3203 value: self.value.to_string(),
3204 expected: "i64",
3205 message: e.to_string(),
3206 })?;
3207 visitor.visit_i64(n)
3208 }
3209
3210 fn deserialize_u8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3211 where
3212 V: Visitor<'de>,
3213 {
3214 let n = self
3215 .value
3216 .parse::<u8>()
3217 .map_err(|e| QueryExtractError::InvalidValue {
3218 name: self.field_name(),
3219 value: self.value.to_string(),
3220 expected: "u8",
3221 message: e.to_string(),
3222 })?;
3223 visitor.visit_u8(n)
3224 }
3225
3226 fn deserialize_u16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3227 where
3228 V: Visitor<'de>,
3229 {
3230 let n = self
3231 .value
3232 .parse::<u16>()
3233 .map_err(|e| QueryExtractError::InvalidValue {
3234 name: self.field_name(),
3235 value: self.value.to_string(),
3236 expected: "u16",
3237 message: e.to_string(),
3238 })?;
3239 visitor.visit_u16(n)
3240 }
3241
3242 fn deserialize_u32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3243 where
3244 V: Visitor<'de>,
3245 {
3246 let n = self
3247 .value
3248 .parse::<u32>()
3249 .map_err(|e| QueryExtractError::InvalidValue {
3250 name: self.field_name(),
3251 value: self.value.to_string(),
3252 expected: "u32",
3253 message: e.to_string(),
3254 })?;
3255 visitor.visit_u32(n)
3256 }
3257
3258 fn deserialize_u64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3259 where
3260 V: Visitor<'de>,
3261 {
3262 let n = self
3263 .value
3264 .parse::<u64>()
3265 .map_err(|e| QueryExtractError::InvalidValue {
3266 name: self.field_name(),
3267 value: self.value.to_string(),
3268 expected: "u64",
3269 message: e.to_string(),
3270 })?;
3271 visitor.visit_u64(n)
3272 }
3273
3274 fn deserialize_f32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3275 where
3276 V: Visitor<'de>,
3277 {
3278 let n = self
3279 .value
3280 .parse::<f32>()
3281 .map_err(|e| QueryExtractError::InvalidValue {
3282 name: self.field_name(),
3283 value: self.value.to_string(),
3284 expected: "f32",
3285 message: e.to_string(),
3286 })?;
3287 visitor.visit_f32(n)
3288 }
3289
3290 fn deserialize_f64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3291 where
3292 V: Visitor<'de>,
3293 {
3294 let n = self
3295 .value
3296 .parse::<f64>()
3297 .map_err(|e| QueryExtractError::InvalidValue {
3298 name: self.field_name(),
3299 value: self.value.to_string(),
3300 expected: "f64",
3301 message: e.to_string(),
3302 })?;
3303 visitor.visit_f64(n)
3304 }
3305
3306 fn deserialize_char<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3307 where
3308 V: Visitor<'de>,
3309 {
3310 let mut chars = self.value.chars();
3311 match (chars.next(), chars.next()) {
3312 (Some(c), None) => visitor.visit_char(c),
3313 _ => Err(QueryExtractError::InvalidValue {
3314 name: self.field_name(),
3315 value: self.value.to_string(),
3316 expected: "char",
3317 message: "expected single character".to_string(),
3318 }),
3319 }
3320 }
3321
3322 fn deserialize_str<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3323 where
3324 V: Visitor<'de>,
3325 {
3326 visitor.visit_str(self.value)
3327 }
3328
3329 fn deserialize_string<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3330 where
3331 V: Visitor<'de>,
3332 {
3333 visitor.visit_string(self.value.to_owned())
3334 }
3335
3336 fn deserialize_bytes<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3337 where
3338 V: Visitor<'de>,
3339 {
3340 visitor.visit_bytes(self.value.as_bytes())
3341 }
3342
3343 fn deserialize_byte_buf<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3344 where
3345 V: Visitor<'de>,
3346 {
3347 visitor.visit_byte_buf(self.value.as_bytes().to_vec())
3348 }
3349
3350 fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3351 where
3352 V: Visitor<'de>,
3353 {
3354 if self.value.is_empty() {
3355 visitor.visit_none()
3356 } else {
3357 visitor.visit_some(self)
3358 }
3359 }
3360
3361 fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3362 where
3363 V: Visitor<'de>,
3364 {
3365 visitor.visit_unit()
3366 }
3367
3368 fn deserialize_unit_struct<V>(
3369 self,
3370 _name: &'static str,
3371 visitor: V,
3372 ) -> Result<V::Value, Self::Error>
3373 where
3374 V: Visitor<'de>,
3375 {
3376 visitor.visit_unit()
3377 }
3378
3379 fn deserialize_newtype_struct<V>(
3380 self,
3381 _name: &'static str,
3382 visitor: V,
3383 ) -> Result<V::Value, Self::Error>
3384 where
3385 V: Visitor<'de>,
3386 {
3387 visitor.visit_newtype_struct(self)
3388 }
3389
3390 fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3391 where
3392 V: Visitor<'de>,
3393 {
3394 visitor.visit_seq(QuerySeqAccess::new(vec![self.value]))
3396 }
3397
3398 fn deserialize_tuple<V>(self, _len: usize, visitor: V) -> Result<V::Value, Self::Error>
3399 where
3400 V: Visitor<'de>,
3401 {
3402 visitor.visit_seq(QuerySeqAccess::new(vec![self.value]))
3403 }
3404
3405 fn deserialize_tuple_struct<V>(
3406 self,
3407 _name: &'static str,
3408 _len: usize,
3409 visitor: V,
3410 ) -> Result<V::Value, Self::Error>
3411 where
3412 V: Visitor<'de>,
3413 {
3414 visitor.visit_seq(QuerySeqAccess::new(vec![self.value]))
3415 }
3416
3417 fn deserialize_map<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
3418 where
3419 V: Visitor<'de>,
3420 {
3421 Err(QueryExtractError::DeserializeError {
3423 message: "cannot deserialize single value as map".to_string(),
3424 })
3425 }
3426
3427 fn deserialize_struct<V>(
3428 self,
3429 _name: &'static str,
3430 _fields: &'static [&'static str],
3431 visitor: V,
3432 ) -> Result<V::Value, Self::Error>
3433 where
3434 V: Visitor<'de>,
3435 {
3436 self.deserialize_map(visitor)
3437 }
3438
3439 fn deserialize_enum<V>(
3440 self,
3441 _name: &'static str,
3442 _variants: &'static [&'static str],
3443 visitor: V,
3444 ) -> Result<V::Value, Self::Error>
3445 where
3446 V: Visitor<'de>,
3447 {
3448 visitor.visit_enum(self.value.into_deserializer())
3449 }
3450
3451 fn deserialize_identifier<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3452 where
3453 V: Visitor<'de>,
3454 {
3455 visitor.visit_str(self.value)
3456 }
3457
3458 fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3459 where
3460 V: Visitor<'de>,
3461 {
3462 visitor.visit_unit()
3463 }
3464}
3465
3466struct QueryFieldDeserializer<'de> {
3470 name: &'de str,
3471 values: Vec<&'de str>,
3472}
3473
3474impl<'de> QueryFieldDeserializer<'de> {
3475 fn new(name: &'de str, values: Vec<&'de str>) -> Self {
3476 Self { name, values }
3477 }
3478}
3479
3480impl<'de> Deserializer<'de> for QueryFieldDeserializer<'de> {
3481 type Error = QueryExtractError;
3482
3483 fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3484 where
3485 V: Visitor<'de>,
3486 {
3487 if let Some(value) = self.values.first() {
3489 visitor.visit_str(value)
3490 } else {
3491 visitor.visit_none()
3492 }
3493 }
3494
3495 fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3496 where
3497 V: Visitor<'de>,
3498 {
3499 let value = self
3500 .values
3501 .first()
3502 .ok_or_else(|| QueryExtractError::MissingParam {
3503 name: self.name.to_string(),
3504 })?;
3505 let b = parse_bool(value).map_err(|msg| QueryExtractError::InvalidValue {
3506 name: self.name.to_string(),
3507 value: (*value).to_string(),
3508 expected: "bool",
3509 message: msg,
3510 })?;
3511 visitor.visit_bool(b)
3512 }
3513
3514 fn deserialize_i8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3515 where
3516 V: Visitor<'de>,
3517 {
3518 let value = self
3519 .values
3520 .first()
3521 .ok_or_else(|| QueryExtractError::MissingParam {
3522 name: self.name.to_string(),
3523 })?;
3524 let n = value
3525 .parse::<i8>()
3526 .map_err(|e| QueryExtractError::InvalidValue {
3527 name: self.name.to_string(),
3528 value: (*value).to_string(),
3529 expected: "i8",
3530 message: e.to_string(),
3531 })?;
3532 visitor.visit_i8(n)
3533 }
3534
3535 fn deserialize_i16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3536 where
3537 V: Visitor<'de>,
3538 {
3539 let value = self
3540 .values
3541 .first()
3542 .ok_or_else(|| QueryExtractError::MissingParam {
3543 name: self.name.to_string(),
3544 })?;
3545 let n = value
3546 .parse::<i16>()
3547 .map_err(|e| QueryExtractError::InvalidValue {
3548 name: self.name.to_string(),
3549 value: (*value).to_string(),
3550 expected: "i16",
3551 message: e.to_string(),
3552 })?;
3553 visitor.visit_i16(n)
3554 }
3555
3556 fn deserialize_i32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3557 where
3558 V: Visitor<'de>,
3559 {
3560 let value = self
3561 .values
3562 .first()
3563 .ok_or_else(|| QueryExtractError::MissingParam {
3564 name: self.name.to_string(),
3565 })?;
3566 let n = value
3567 .parse::<i32>()
3568 .map_err(|e| QueryExtractError::InvalidValue {
3569 name: self.name.to_string(),
3570 value: (*value).to_string(),
3571 expected: "i32",
3572 message: e.to_string(),
3573 })?;
3574 visitor.visit_i32(n)
3575 }
3576
3577 fn deserialize_i64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3578 where
3579 V: Visitor<'de>,
3580 {
3581 let value = self
3582 .values
3583 .first()
3584 .ok_or_else(|| QueryExtractError::MissingParam {
3585 name: self.name.to_string(),
3586 })?;
3587 let n = value
3588 .parse::<i64>()
3589 .map_err(|e| QueryExtractError::InvalidValue {
3590 name: self.name.to_string(),
3591 value: (*value).to_string(),
3592 expected: "i64",
3593 message: e.to_string(),
3594 })?;
3595 visitor.visit_i64(n)
3596 }
3597
3598 fn deserialize_u8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3599 where
3600 V: Visitor<'de>,
3601 {
3602 let value = self
3603 .values
3604 .first()
3605 .ok_or_else(|| QueryExtractError::MissingParam {
3606 name: self.name.to_string(),
3607 })?;
3608 let n = value
3609 .parse::<u8>()
3610 .map_err(|e| QueryExtractError::InvalidValue {
3611 name: self.name.to_string(),
3612 value: (*value).to_string(),
3613 expected: "u8",
3614 message: e.to_string(),
3615 })?;
3616 visitor.visit_u8(n)
3617 }
3618
3619 fn deserialize_u16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3620 where
3621 V: Visitor<'de>,
3622 {
3623 let value = self
3624 .values
3625 .first()
3626 .ok_or_else(|| QueryExtractError::MissingParam {
3627 name: self.name.to_string(),
3628 })?;
3629 let n = value
3630 .parse::<u16>()
3631 .map_err(|e| QueryExtractError::InvalidValue {
3632 name: self.name.to_string(),
3633 value: (*value).to_string(),
3634 expected: "u16",
3635 message: e.to_string(),
3636 })?;
3637 visitor.visit_u16(n)
3638 }
3639
3640 fn deserialize_u32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3641 where
3642 V: Visitor<'de>,
3643 {
3644 let value = self
3645 .values
3646 .first()
3647 .ok_or_else(|| QueryExtractError::MissingParam {
3648 name: self.name.to_string(),
3649 })?;
3650 let n = value
3651 .parse::<u32>()
3652 .map_err(|e| QueryExtractError::InvalidValue {
3653 name: self.name.to_string(),
3654 value: (*value).to_string(),
3655 expected: "u32",
3656 message: e.to_string(),
3657 })?;
3658 visitor.visit_u32(n)
3659 }
3660
3661 fn deserialize_u64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3662 where
3663 V: Visitor<'de>,
3664 {
3665 let value = self
3666 .values
3667 .first()
3668 .ok_or_else(|| QueryExtractError::MissingParam {
3669 name: self.name.to_string(),
3670 })?;
3671 let n = value
3672 .parse::<u64>()
3673 .map_err(|e| QueryExtractError::InvalidValue {
3674 name: self.name.to_string(),
3675 value: (*value).to_string(),
3676 expected: "u64",
3677 message: e.to_string(),
3678 })?;
3679 visitor.visit_u64(n)
3680 }
3681
3682 fn deserialize_f32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3683 where
3684 V: Visitor<'de>,
3685 {
3686 let value = self
3687 .values
3688 .first()
3689 .ok_or_else(|| QueryExtractError::MissingParam {
3690 name: self.name.to_string(),
3691 })?;
3692 let n = value
3693 .parse::<f32>()
3694 .map_err(|e| QueryExtractError::InvalidValue {
3695 name: self.name.to_string(),
3696 value: (*value).to_string(),
3697 expected: "f32",
3698 message: e.to_string(),
3699 })?;
3700 visitor.visit_f32(n)
3701 }
3702
3703 fn deserialize_f64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3704 where
3705 V: Visitor<'de>,
3706 {
3707 let value = self
3708 .values
3709 .first()
3710 .ok_or_else(|| QueryExtractError::MissingParam {
3711 name: self.name.to_string(),
3712 })?;
3713 let n = value
3714 .parse::<f64>()
3715 .map_err(|e| QueryExtractError::InvalidValue {
3716 name: self.name.to_string(),
3717 value: (*value).to_string(),
3718 expected: "f64",
3719 message: e.to_string(),
3720 })?;
3721 visitor.visit_f64(n)
3722 }
3723
3724 fn deserialize_char<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3725 where
3726 V: Visitor<'de>,
3727 {
3728 let value = self
3729 .values
3730 .first()
3731 .ok_or_else(|| QueryExtractError::MissingParam {
3732 name: self.name.to_string(),
3733 })?;
3734 let mut chars = value.chars();
3735 match (chars.next(), chars.next()) {
3736 (Some(c), None) => visitor.visit_char(c),
3737 _ => Err(QueryExtractError::InvalidValue {
3738 name: self.name.to_string(),
3739 value: (*value).to_string(),
3740 expected: "char",
3741 message: "expected single character".to_string(),
3742 }),
3743 }
3744 }
3745
3746 fn deserialize_str<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3747 where
3748 V: Visitor<'de>,
3749 {
3750 let value = self
3751 .values
3752 .first()
3753 .ok_or_else(|| QueryExtractError::MissingParam {
3754 name: self.name.to_string(),
3755 })?;
3756 visitor.visit_str(value)
3757 }
3758
3759 fn deserialize_string<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3760 where
3761 V: Visitor<'de>,
3762 {
3763 let value = self
3764 .values
3765 .first()
3766 .ok_or_else(|| QueryExtractError::MissingParam {
3767 name: self.name.to_string(),
3768 })?;
3769 visitor.visit_string((*value).to_owned())
3770 }
3771
3772 fn deserialize_bytes<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3773 where
3774 V: Visitor<'de>,
3775 {
3776 let value = self
3777 .values
3778 .first()
3779 .ok_or_else(|| QueryExtractError::MissingParam {
3780 name: self.name.to_string(),
3781 })?;
3782 visitor.visit_bytes(value.as_bytes())
3783 }
3784
3785 fn deserialize_byte_buf<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3786 where
3787 V: Visitor<'de>,
3788 {
3789 let value = self
3790 .values
3791 .first()
3792 .ok_or_else(|| QueryExtractError::MissingParam {
3793 name: self.name.to_string(),
3794 })?;
3795 visitor.visit_byte_buf(value.as_bytes().to_vec())
3796 }
3797
3798 fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3799 where
3800 V: Visitor<'de>,
3801 {
3802 if self.values.is_empty() {
3803 visitor.visit_none()
3804 } else {
3805 visitor.visit_some(self)
3806 }
3807 }
3808
3809 fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3810 where
3811 V: Visitor<'de>,
3812 {
3813 visitor.visit_unit()
3814 }
3815
3816 fn deserialize_unit_struct<V>(
3817 self,
3818 _name: &'static str,
3819 visitor: V,
3820 ) -> Result<V::Value, Self::Error>
3821 where
3822 V: Visitor<'de>,
3823 {
3824 visitor.visit_unit()
3825 }
3826
3827 fn deserialize_newtype_struct<V>(
3828 self,
3829 _name: &'static str,
3830 visitor: V,
3831 ) -> Result<V::Value, Self::Error>
3832 where
3833 V: Visitor<'de>,
3834 {
3835 visitor.visit_newtype_struct(self)
3836 }
3837
3838 fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3839 where
3840 V: Visitor<'de>,
3841 {
3842 visitor.visit_seq(QuerySeqAccess::new(self.values))
3844 }
3845
3846 fn deserialize_tuple<V>(self, _len: usize, visitor: V) -> Result<V::Value, Self::Error>
3847 where
3848 V: Visitor<'de>,
3849 {
3850 visitor.visit_seq(QuerySeqAccess::new(self.values))
3851 }
3852
3853 fn deserialize_tuple_struct<V>(
3854 self,
3855 _name: &'static str,
3856 _len: usize,
3857 visitor: V,
3858 ) -> Result<V::Value, Self::Error>
3859 where
3860 V: Visitor<'de>,
3861 {
3862 visitor.visit_seq(QuerySeqAccess::new(self.values))
3863 }
3864
3865 fn deserialize_map<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
3866 where
3867 V: Visitor<'de>,
3868 {
3869 Err(QueryExtractError::DeserializeError {
3870 message: "cannot deserialize query field as map".to_string(),
3871 })
3872 }
3873
3874 fn deserialize_struct<V>(
3875 self,
3876 _name: &'static str,
3877 _fields: &'static [&'static str],
3878 visitor: V,
3879 ) -> Result<V::Value, Self::Error>
3880 where
3881 V: Visitor<'de>,
3882 {
3883 self.deserialize_map(visitor)
3884 }
3885
3886 fn deserialize_enum<V>(
3887 self,
3888 _name: &'static str,
3889 _variants: &'static [&'static str],
3890 visitor: V,
3891 ) -> Result<V::Value, Self::Error>
3892 where
3893 V: Visitor<'de>,
3894 {
3895 let value = self
3896 .values
3897 .first()
3898 .ok_or_else(|| QueryExtractError::MissingParam {
3899 name: self.name.to_string(),
3900 })?;
3901 visitor.visit_enum((*value).into_deserializer())
3902 }
3903
3904 fn deserialize_identifier<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3905 where
3906 V: Visitor<'de>,
3907 {
3908 let value = self
3909 .values
3910 .first()
3911 .ok_or_else(|| QueryExtractError::MissingParam {
3912 name: self.name.to_string(),
3913 })?;
3914 visitor.visit_str(value)
3915 }
3916
3917 fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3918 where
3919 V: Visitor<'de>,
3920 {
3921 visitor.visit_unit()
3922 }
3923}
3924
/// Type-keyed container for shared application state.
///
/// At most one value is stored per type (keyed by `TypeId`). The map sits
/// behind an `Arc`, so cloning an `AppState` shares the map rather than
/// copying it.
#[derive(Debug, Default, Clone)]
pub struct AppState {
    inner: std::sync::Arc<
        std::collections::HashMap<
            std::any::TypeId,
            std::sync::Arc<dyn std::any::Any + Send + Sync>,
        >,
    >,
}

impl AppState {
    /// Creates an empty state container.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns a state that also holds `value`, replacing any existing value
    /// of type `T`.
    ///
    /// The map is copied only if it is currently shared with another
    /// `AppState`; otherwise it is updated in place.
    #[must_use]
    pub fn with<T: Send + Sync + 'static>(mut self, value: T) -> Self {
        std::sync::Arc::make_mut(&mut self.inner)
            .insert(std::any::TypeId::of::<T>(), std::sync::Arc::new(value));
        self
    }

    /// Borrows the stored value of type `T`, if any.
    #[must_use]
    pub fn get<T: Send + Sync + 'static>(&self) -> Option<&T> {
        let entry = self.inner.get(&std::any::TypeId::of::<T>())?;
        entry.downcast_ref::<T>()
    }

    /// Whether a value of type `T` is stored.
    #[must_use]
    pub fn contains<T: Send + Sync + 'static>(&self) -> bool {
        let key = std::any::TypeId::of::<T>();
        self.inner.contains_key(&key)
    }

    /// Number of distinct types stored.
    #[must_use]
    pub fn len(&self) -> usize {
        self.inner.len()
    }

    /// Whether no values are stored.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
4016
/// Extractor wrapper for a value of type `T` taken from the request's
/// `AppState`.
///
/// Dereferences to the wrapped value.
#[derive(Debug, Clone)]
pub struct State<T>(pub T);

impl<T> State<T> {
    /// Consumes the wrapper and returns the wrapped value.
    pub fn into_inner(self) -> T {
        let State(inner) = self;
        inner
    }
}

impl<T> Deref for State<T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        let State(inner) = self;
        inner
    }
}

impl<T> DerefMut for State<T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        let State(inner) = self;
        inner
    }
}
4068
/// Failure modes when extracting `State<T>` from a request.
#[derive(Debug)]
pub enum StateExtractError {
    /// The request carried no `AppState` extension.
    MissingAppState,
    /// An `AppState` was present but held no value of the requested type.
    MissingStateType {
        /// `std::any::type_name` of the requested type.
        type_name: &'static str,
    },
}

impl std::fmt::Display for StateExtractError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            Self::MissingAppState => f.write_str("Application state not configured in request"),
            Self::MissingStateType { type_name } => {
                write!(f, "State type not found: {type_name}")
            }
        }
    }
}

impl std::error::Error for StateExtractError {}
4099
4100impl IntoResponse for StateExtractError {
4101 fn into_response(self) -> crate::response::Response {
4102 HttpError::internal()
4104 .with_detail(self.to_string())
4105 .into_response()
4106 }
4107}
4108
4109impl<T> FromRequest for State<T>
4110where
4111 T: Clone + Send + Sync + 'static,
4112{
4113 type Error = StateExtractError;
4114
4115 async fn from_request(_ctx: &RequestContext, req: &mut Request) -> Result<Self, Self::Error> {
4116 let app_state = req
4118 .get_extension::<AppState>()
4119 .ok_or(StateExtractError::MissingAppState)?;
4120
4121 let value = app_state
4123 .get::<T>()
4124 .ok_or(StateExtractError::MissingStateType {
4125 type_name: std::any::type_name::<T>(),
4126 })?;
4127
4128 Ok(State(value.clone()))
4129 }
4130}
4131
4132#[cfg(test)]
4133mod state_tests {
4134 use super::*;
4135 use crate::request::Method;
4136
4137 fn test_context() -> RequestContext {
4138 let cx = asupersync::Cx::for_testing();
4139 RequestContext::new(cx, 12345)
4140 }
4141
4142 #[derive(Clone, Debug, PartialEq)]
4143 struct DatabasePool {
4144 connection_string: String,
4145 }
4146
4147 #[derive(Clone, Debug, PartialEq)]
4148 struct AppConfig {
4149 debug: bool,
4150 port: u16,
4151 }
4152
4153 #[test]
4154 fn app_state_new_is_empty() {
4155 let state = AppState::new();
4156 assert!(state.is_empty());
4157 assert_eq!(state.len(), 0);
4158 }
4159
4160 #[test]
4161 fn app_state_with_single_type() {
4162 let db = DatabasePool {
4163 connection_string: "postgres://localhost".into(),
4164 };
4165 let state = AppState::new().with(db.clone());
4166
4167 assert!(!state.is_empty());
4168 assert_eq!(state.len(), 1);
4169 assert!(state.contains::<DatabasePool>());
4170 assert_eq!(state.get::<DatabasePool>(), Some(&db));
4171 }
4172
4173 #[test]
4174 fn app_state_with_multiple_types() {
4175 let db = DatabasePool {
4176 connection_string: "postgres://localhost".into(),
4177 };
4178 let config = AppConfig {
4179 debug: true,
4180 port: 8080,
4181 };
4182
4183 let state = AppState::new().with(db.clone()).with(config.clone());
4184
4185 assert_eq!(state.len(), 2);
4186 assert_eq!(state.get::<DatabasePool>(), Some(&db));
4187 assert_eq!(state.get::<AppConfig>(), Some(&config));
4188 }
4189
4190 #[test]
4191 fn app_state_get_missing_type() {
4192 let state = AppState::new().with(42i32);
4193 assert!(state.get::<String>().is_none());
4194 assert!(!state.contains::<String>());
4195 }
4196
4197 #[test]
4198 fn state_deref() {
4199 let state = State(42i32);
4200 assert_eq!(*state, 42);
4201 }
4202
4203 #[test]
4204 fn state_into_inner() {
4205 let state = State("hello".to_string());
4206 assert_eq!(state.into_inner(), "hello");
4207 }
4208
4209 #[test]
4210 fn state_extract_success() {
4211 let ctx = test_context();
4212 let db = DatabasePool {
4213 connection_string: "postgres://localhost".into(),
4214 };
4215 let app_state = AppState::new().with(db.clone());
4216
4217 let mut req = Request::new(Method::Get, "/test");
4218 req.insert_extension(app_state);
4219
4220 let result =
4221 futures_executor::block_on(State::<DatabasePool>::from_request(&ctx, &mut req));
4222 let State(extracted) = result.unwrap();
4223 assert_eq!(extracted, db);
4224 }
4225
4226 #[test]
4227 fn state_extract_multiple_types() {
4228 let ctx = test_context();
4229 let db = DatabasePool {
4230 connection_string: "postgres://localhost".into(),
4231 };
4232 let config = AppConfig {
4233 debug: true,
4234 port: 8080,
4235 };
4236 let app_state = AppState::new().with(db.clone()).with(config.clone());
4237
4238 let mut req = Request::new(Method::Get, "/test");
4239 req.insert_extension(app_state);
4240
4241 let result =
4243 futures_executor::block_on(State::<DatabasePool>::from_request(&ctx, &mut req));
4244 let State(extracted_db) = result.unwrap();
4245 assert_eq!(extracted_db, db);
4246
4247 let result = futures_executor::block_on(State::<AppConfig>::from_request(&ctx, &mut req));
4249 let State(extracted_config) = result.unwrap();
4250 assert_eq!(extracted_config, config);
4251 }
4252
4253 #[test]
4254 fn state_extract_missing_app_state() {
4255 let ctx = test_context();
4256 let mut req = Request::new(Method::Get, "/test");
4257 let result =
4260 futures_executor::block_on(State::<DatabasePool>::from_request(&ctx, &mut req));
4261 assert!(matches!(result, Err(StateExtractError::MissingAppState)));
4262 }
4263
4264 #[test]
4265 fn state_extract_missing_type() {
4266 let ctx = test_context();
4267 let app_state = AppState::new().with(42i32);
4268
4269 let mut req = Request::new(Method::Get, "/test");
4270 req.insert_extension(app_state);
4271
4272 let result =
4273 futures_executor::block_on(State::<DatabasePool>::from_request(&ctx, &mut req));
4274 assert!(matches!(
4275 result,
4276 Err(StateExtractError::MissingStateType { .. })
4277 ));
4278 }
4279
4280 #[test]
4281 fn state_error_display() {
4282 let err = StateExtractError::MissingAppState;
4283 assert!(err.to_string().contains("not configured"));
4284
4285 let err = StateExtractError::MissingStateType {
4286 type_name: "DatabasePool",
4287 };
4288 assert!(err.to_string().contains("DatabasePool"));
4289 }
4290
4291 #[test]
4292 fn app_state_clone() {
4293 let db = DatabasePool {
4294 connection_string: "postgres://localhost".into(),
4295 };
4296 let state1 = AppState::new().with(db.clone());
4297 let state2 = state1.clone();
4298
4299 assert_eq!(state2.get::<DatabasePool>(), Some(&db));
4300 }
4301
4302 #[test]
4303 fn state_with_arc() {
4304 use std::sync::Arc;
4305
4306 let ctx = test_context();
4307 let db = Arc::new(DatabasePool {
4308 connection_string: "postgres://localhost".into(),
4309 });
4310 let app_state = AppState::new().with(db.clone());
4311
4312 let mut req = Request::new(Method::Get, "/test");
4313 req.insert_extension(app_state);
4314
4315 let result =
4316 futures_executor::block_on(State::<Arc<DatabasePool>>::from_request(&ctx, &mut req));
4317 let State(extracted) = result.unwrap();
4318 assert_eq!(extracted.connection_string, "postgres://localhost");
4319 }
4320}
4321
/// A typed header value paired with the header name it was read from.
#[derive(Debug, Clone)]
pub struct Header<T> {
    /// The parsed header value.
    pub value: T,
    /// The header name.
    pub name: String,
}

impl<T> Header<T> {
    /// Pairs `value` with the header `name`.
    #[must_use]
    pub fn new(name: impl Into<String>, value: T) -> Self {
        let name: String = name.into();
        Self { value, name }
    }

    /// Discards the name and returns the value.
    #[must_use]
    pub fn into_inner(self) -> T {
        let Self { value, .. } = self;
        value
    }
}

impl<T> Deref for Header<T> {
    type Target = T;

    fn deref(&self) -> &T {
        let Self { value, .. } = self;
        value
    }
}

impl<T> DerefMut for Header<T> {
    fn deref_mut(&mut self) -> &mut T {
        let Self { value, .. } = self;
        value
    }
}
4391
/// Converts a `snake_case` identifier to `Header-Case`.
///
/// Splits on `_`, upper-cases the first character of each piece, leaves the rest
/// untouched and joins with `-`. For example, `content_type` becomes `Content-Type`.
/// Empty pieces (from leading, trailing or doubled `_`) stay empty, so `a__b` becomes `A--B`.
#[must_use]
pub fn snake_to_header_case(name: &str) -> String {
    let mut out = String::with_capacity(name.len());
    for (idx, segment) in name.split('_').enumerate() {
        if idx > 0 {
            out.push('-');
        }
        let mut rest = segment.chars();
        if let Some(head) = rest.next() {
            // `to_uppercase` may yield several chars (e.g. 'ß' -> "SS").
            out.extend(head.to_uppercase());
            out.push_str(rest.as_str());
        }
    }
    out
}
4415
/// Error produced when a typed header cannot be extracted from a request.
#[derive(Debug)]
pub enum HeaderExtractError {
    /// The header was not present.
    MissingHeader {
        /// Header name that was looked up.
        name: String,
    },
    /// The header value was not valid UTF-8.
    InvalidUtf8 {
        /// Header name whose value failed to decode.
        name: String,
    },
    /// The header value could not be converted to the requested type.
    ParseError {
        /// Header name.
        name: String,
        /// Raw header value as received.
        value: String,
        /// Name of the target type.
        expected: &'static str,
        /// Message returned by the parser.
        message: String,
    },
}

impl std::fmt::Display for HeaderExtractError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        use HeaderExtractError as E;

        match self {
            E::MissingHeader { name } => write!(f, "Missing required header: {name}"),
            E::InvalidUtf8 { name } => write!(f, "Header '{name}' contains invalid UTF-8"),
            E::ParseError {
                name,
                value,
                expected,
                message,
            } => write!(
                f,
                "Failed to parse header '{name}' value '{value}' as {expected}: {message}"
            ),
        }
    }
}

impl std::error::Error for HeaderExtractError {}
4467
4468impl IntoResponse for HeaderExtractError {
4469 fn into_response(self) -> crate::response::Response {
4470 let error = match &self {
4472 HeaderExtractError::MissingHeader { name } => {
4473 ValidationError::missing(crate::error::loc::header(name))
4474 .with_msg(format!("Missing required header: {name}"))
4475 }
4476 HeaderExtractError::InvalidUtf8 { name } => {
4477 ValidationError::type_error(crate::error::loc::header(name), "string")
4478 .with_msg(format!("Header '{name}' contains invalid UTF-8"))
4479 }
4480 HeaderExtractError::ParseError {
4481 name,
4482 value,
4483 expected,
4484 message,
4485 } => ValidationError::type_error(crate::error::loc::header(name), expected)
4486 .with_msg(format!("Failed to parse as {expected}: {message}"))
4487 .with_input(serde_json::Value::String(value.clone())),
4488 };
4489 ValidationErrors::single(error).into_response()
4490 }
4491}
4492
/// Conversion from a raw header (or cookie) string into a typed value.
///
/// Implementors return `Err(message)` when the text cannot be converted; the
/// extractors wrap that message into their own error types.
pub trait FromHeaderValue: Sized {
    /// Parses `value` into `Self`, returning a human-readable message on failure.
    fn from_header_value(value: &str) -> Result<Self, String>;

    /// Name of the target type, used in error messages (e.g. `"i32"`).
    fn type_name() -> &'static str;
}

impl FromHeaderValue for String {
    fn from_header_value(value: &str) -> Result<Self, String> {
        Ok(value.to_string())
    }

    fn type_name() -> &'static str {
        "String"
    }
}

/// Implements [`FromHeaderValue`] for integer types via `str::parse`. The error
/// message is the parser's `Display` text and `type_name` is the type's own name.
macro_rules! impl_from_header_value_for_int {
    ($($ty:ident),* $(,)?) => {
        $(
            impl FromHeaderValue for $ty {
                fn from_header_value(value: &str) -> Result<Self, String> {
                    value.parse().map_err(|e| format!("{e}"))
                }

                fn type_name() -> &'static str {
                    stringify!($ty)
                }
            }
        )*
    };
}

impl_from_header_value_for_int!(i32, i64, u32, u64);

impl FromHeaderValue for bool {
    /// Accepts `true`/`1`/`yes`/`on` and `false`/`0`/`no`/`off`, ASCII case-insensitively.
    fn from_header_value(value: &str) -> Result<Self, String> {
        const TRUTHY: [&str; 4] = ["true", "1", "yes", "on"];
        const FALSY: [&str; 4] = ["false", "0", "no", "off"];

        // Compare case-insensitively in place instead of allocating a lower-cased copy.
        if TRUTHY.iter().any(|word| value.eq_ignore_ascii_case(word)) {
            Ok(true)
        } else if FALSY.iter().any(|word| value.eq_ignore_ascii_case(word)) {
            Ok(false)
        } else {
            Err(format!("invalid boolean: {value}"))
        }
    }

    fn type_name() -> &'static str {
        "bool"
    }
}
4565
/// A header extractor whose header name is fixed at the type level by `N`.
#[derive(Debug, Clone)]
pub struct NamedHeader<T, N> {
    /// The parsed header value.
    pub value: T,
    _marker: std::marker::PhantomData<N>,
}

/// Type-level header name used by [`NamedHeader`].
pub trait HeaderName {
    /// Name of the header looked up in the request (the built-in markers use lower case).
    const NAME: &'static str;
}

impl<T, N> NamedHeader<T, N> {
    /// Wraps an already-parsed header value.
    #[must_use]
    pub fn new(value: T) -> Self {
        let _marker = std::marker::PhantomData;
        Self { value, _marker }
    }

    /// Unwraps the parsed value.
    #[must_use]
    pub fn into_inner(self) -> T {
        let Self { value, .. } = self;
        value
    }
}

impl<T, N> Deref for NamedHeader<T, N> {
    type Target = T;

    fn deref(&self) -> &T {
        let Self { value, .. } = self;
        value
    }
}

impl<T, N> DerefMut for NamedHeader<T, N> {
    fn deref_mut(&mut self) -> &mut T {
        let Self { value, .. } = self;
        value
    }
}
4630
4631impl<T, N> FromRequest for NamedHeader<T, N>
4632where
4633 T: FromHeaderValue + Send + Sync + 'static,
4634 N: HeaderName + Send + Sync + 'static,
4635{
4636 type Error = HeaderExtractError;
4637
4638 async fn from_request(_ctx: &RequestContext, req: &mut Request) -> Result<Self, Self::Error> {
4639 let header_name = N::NAME;
4640
4641 let value_bytes =
4642 req.headers()
4643 .get(header_name)
4644 .ok_or_else(|| HeaderExtractError::MissingHeader {
4645 name: header_name.to_string(),
4646 })?;
4647
4648 let value_str =
4649 std::str::from_utf8(value_bytes).map_err(|_| HeaderExtractError::InvalidUtf8 {
4650 name: header_name.to_string(),
4651 })?;
4652
4653 let value =
4654 T::from_header_value(value_str).map_err(|message| HeaderExtractError::ParseError {
4655 name: header_name.to_string(),
4656 value: value_str.to_string(),
4657 expected: T::type_name(),
4658 message,
4659 })?;
4660
4661 Ok(NamedHeader::new(value))
4662 }
4663}
4664
4665pub struct Authorization;
4668impl HeaderName for Authorization {
4669 const NAME: &'static str = "authorization";
4670}
4671
4672pub struct ContentType;
4674impl HeaderName for ContentType {
4675 const NAME: &'static str = "content-type";
4676}
4677
4678pub struct Accept;
4680impl HeaderName for Accept {
4681 const NAME: &'static str = "accept";
4682}
4683
4684pub struct XRequestId;
4686impl HeaderName for XRequestId {
4687 const NAME: &'static str = "x-request-id";
4688}
4689
4690pub struct UserAgent;
4692impl HeaderName for UserAgent {
4693 const NAME: &'static str = "user-agent";
4694}
4695
4696pub struct Host;
4698impl HeaderName for Host {
4699 const NAME: &'static str = "host";
4700}
4701
/// Bearer token extracted for the OAuth2 password flow.
///
/// `Debug` is implemented by hand and redacts the token. Otherwise,
/// `{:?}`-formatting the extractor, e.g. by tracing instrumentation, would leak
/// the credential into logs.
#[derive(Clone)]
pub struct OAuth2PasswordBearer {
    /// The raw bearer token.
    pub token: String,
}

impl fmt::Debug for OAuth2PasswordBearer {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("OAuth2PasswordBearer")
            .field("token", &"<redacted>")
            .finish()
    }
}

impl OAuth2PasswordBearer {
    /// Wraps an already-extracted token.
    #[must_use]
    pub fn new(token: impl Into<String>) -> Self {
        Self {
            token: token.into(),
        }
    }

    /// Borrows the token.
    #[must_use]
    pub fn token(&self) -> &str {
        &self.token
    }

    /// Consumes the extractor and returns the owned token.
    #[must_use]
    pub fn into_token(self) -> String {
        self.token
    }
}

impl Deref for OAuth2PasswordBearer {
    type Target = str;

    fn deref(&self) -> &Self::Target {
        &self.token
    }
}
4779
/// Configuration describing an OAuth2 password-flow security scheme.
#[derive(Debug, Clone)]
pub struct OAuth2PasswordBearerConfig {
    /// URL of the token endpoint (defaults to `/token`).
    pub token_url: String,
    /// Optional refresh-token endpoint.
    pub refresh_url: Option<String>,
    /// Scope name to description.
    pub scopes: std::collections::HashMap<String, String>,
    /// Optional scheme name override.
    pub scheme_name: Option<String>,
    /// Optional human-readable description.
    pub description: Option<String>,
    /// Whether failures produce an error automatically (defaults to `true`).
    pub auto_error: bool,
}

impl Default for OAuth2PasswordBearerConfig {
    fn default() -> Self {
        Self {
            token_url: String::from("/token"),
            refresh_url: None,
            scopes: std::collections::HashMap::default(),
            scheme_name: None,
            description: None,
            auto_error: true,
        }
    }
}

impl OAuth2PasswordBearerConfig {
    /// Defaults with a custom token URL.
    #[must_use]
    pub fn new(token_url: impl Into<String>) -> Self {
        let mut config = Self::default();
        config.token_url = token_url.into();
        config
    }

    /// Sets the refresh URL.
    #[must_use]
    pub fn with_refresh_url(self, url: impl Into<String>) -> Self {
        Self {
            refresh_url: Some(url.into()),
            ..self
        }
    }

    /// Adds (or replaces) a scope and its description.
    #[must_use]
    pub fn with_scope(mut self, scope: impl Into<String>, description: impl Into<String>) -> Self {
        let (scope, description) = (scope.into(), description.into());
        self.scopes.insert(scope, description);
        self
    }

    /// Sets the scheme name.
    #[must_use]
    pub fn with_scheme_name(self, name: impl Into<String>) -> Self {
        Self {
            scheme_name: Some(name.into()),
            ..self
        }
    }

    /// Sets the description.
    #[must_use]
    pub fn with_description(self, desc: impl Into<String>) -> Self {
        Self {
            description: Some(desc.into()),
            ..self
        }
    }

    /// Sets the `auto_error` flag.
    #[must_use]
    pub fn with_auto_error(self, auto_error: bool) -> Self {
        Self { auto_error, ..self }
    }
}
4858
/// Error returned when an `OAuth2PasswordBearer` cannot be extracted.
#[derive(Debug, Clone)]
pub struct OAuth2BearerError {
    /// Which check failed.
    pub kind: OAuth2BearerErrorKind,
}

/// Reasons OAuth2 bearer extraction can fail.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OAuth2BearerErrorKind {
    /// No `Authorization` header.
    MissingHeader,
    /// Header is not valid UTF-8 or does not use the `Bearer` scheme.
    InvalidScheme,
    /// No token after `Bearer` (the extractor also uses this for over-long tokens).
    EmptyToken,
}

impl OAuth2BearerError {
    /// Error for a missing `Authorization` header.
    #[must_use]
    pub fn missing_header() -> Self {
        let kind = OAuth2BearerErrorKind::MissingHeader;
        Self { kind }
    }

    /// Error for a non-`Bearer` or undecodable header.
    #[must_use]
    pub fn invalid_scheme() -> Self {
        let kind = OAuth2BearerErrorKind::InvalidScheme;
        Self { kind }
    }

    /// Error for a missing token.
    #[must_use]
    pub fn empty_token() -> Self {
        let kind = OAuth2BearerErrorKind::EmptyToken;
        Self { kind }
    }
}

impl fmt::Display for OAuth2BearerError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let text = match self.kind {
            OAuth2BearerErrorKind::MissingHeader => "Missing Authorization header",
            OAuth2BearerErrorKind::InvalidScheme => "Authorization header must use Bearer scheme",
            OAuth2BearerErrorKind::EmptyToken => "Bearer token is empty",
        };
        f.write_str(text)
    }
}
4918
4919impl IntoResponse for OAuth2BearerError {
4920 fn into_response(self) -> crate::response::Response {
4921 use crate::response::{Response, ResponseBody, StatusCode};
4922
4923 let message = match self.kind {
4924 OAuth2BearerErrorKind::MissingHeader => "Not authenticated",
4925 OAuth2BearerErrorKind::InvalidScheme => "Invalid authentication credentials",
4926 OAuth2BearerErrorKind::EmptyToken => "Invalid authentication credentials",
4927 };
4928
4929 let body = serde_json::json!({
4930 "detail": message
4931 });
4932
4933 Response::with_status(StatusCode::UNAUTHORIZED)
4934 .header("www-authenticate", b"Bearer".to_vec())
4935 .header("content-type", b"application/json".to_vec())
4936 .body(ResponseBody::Bytes(body.to_string().into_bytes()))
4937 }
4938}
4939
4940impl FromRequest for OAuth2PasswordBearer {
4941 type Error = OAuth2BearerError;
4942
4943 async fn from_request(_ctx: &RequestContext, req: &mut Request) -> Result<Self, Self::Error> {
4944 let auth_header = req
4946 .headers()
4947 .get("authorization")
4948 .ok_or_else(OAuth2BearerError::missing_header)?;
4949
4950 let auth_str =
4951 std::str::from_utf8(auth_header).map_err(|_| OAuth2BearerError::invalid_scheme())?;
4952
4953 let mut parts = auth_str.split_whitespace();
4954 let scheme = parts.next().ok_or_else(OAuth2BearerError::invalid_scheme)?;
4955 if !scheme.eq_ignore_ascii_case("bearer") {
4956 return Err(OAuth2BearerError::invalid_scheme());
4957 }
4958
4959 let token = parts.next().unwrap_or("");
4960 if token.is_empty() {
4961 return Err(OAuth2BearerError::empty_token());
4962 }
4963
4964 const MAX_TOKEN_LEN: usize = 8192;
4966 if token.len() > MAX_TOKEN_LEN {
4967 return Err(OAuth2BearerError::empty_token());
4968 }
4969
4970 Ok(OAuth2PasswordBearer::new(token))
4971 }
4972}
4973
4974#[derive(Debug, Clone)]
5015pub struct BasicAuth {
5016 pub username: String,
5018 pub password: String,
5020}
5021
5022impl BasicAuth {
5023 #[must_use]
5025 pub fn new(username: impl Into<String>, password: impl Into<String>) -> Self {
5026 Self {
5027 username: username.into(),
5028 password: password.into(),
5029 }
5030 }
5031
5032 #[must_use]
5034 pub fn username(&self) -> &str {
5035 &self.username
5036 }
5037
5038 #[must_use]
5040 pub fn password(&self) -> &str {
5041 &self.password
5042 }
5043
5044 fn decode_credentials(encoded: &str) -> Option<(String, String)> {
5046 let decoded_bytes = base64_decode(encoded)?;
5048 let decoded = std::str::from_utf8(&decoded_bytes).ok()?;
5049
5050 let colon_pos = decoded.find(':')?;
5052 let username = decoded[..colon_pos].to_string();
5053 let password = decoded[colon_pos + 1..].to_string();
5054
5055 Some((username, password))
5056 }
5057}
5058
/// Decodes standard base64 text (RFC 4648 §4, `+`/`/` alphabet).
///
/// Padding is optional. When present it must be well-formed: at most two trailing
/// `=` and a total length that is a multiple of four.
///
/// Returns `None` for:
/// - characters outside the alphabet,
/// - `=` anywhere but the end,
/// - malformed padding,
/// - a data length that cannot encode whole bytes (`len % 4 == 1`).
///
/// Unused low bits in the final group are not checked.
#[allow(clippy::cast_possible_truncation)] // `as u8` deliberately keeps the low 8 bits
fn base64_decode(input: &str) -> Option<Vec<u8>> {
    // Sentinel for bytes outside the base64 alphabet.
    const INVALID: u8 = u8::MAX;
    // Maps every input byte to its 6-bit value, or `INVALID`.
    const DECODE_TABLE: [u8; 256] = {
        let mut table = [INVALID; 256];
        let alphabet = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
        let mut i = 0;
        while i < 64 {
            table[alphabet[i] as usize] = i as u8;
            i += 1;
        }
        table
    };

    let data = input.trim_end_matches('=');
    let padding = input.len() - data.len();
    // A final group of a single symbol carries only 6 bits, so it can never encode a byte.
    if padding > 2 || data.len() % 4 == 1 {
        return None;
    }
    if padding > 0 && input.len() % 4 != 0 {
        return None;
    }

    let mut output = Vec::with_capacity(data.len() * 3 / 4);
    let mut buffer = 0u32;
    let mut bits_collected = 0u32;

    for &byte in data.as_bytes() {
        let value = DECODE_TABLE[usize::from(byte)];
        if value == INVALID {
            return None;
        }
        // Bits shifted past bit 31 are discarded. Only the low
        // `bits_collected` (< 14) bits are ever read back.
        buffer = (buffer << 6) | u32::from(value);
        bits_collected += 6;

        if bits_collected >= 8 {
            bits_collected -= 8;
            output.push((buffer >> bits_collected) as u8);
        }
    }

    Some(output)
}
5097
/// Error returned when `BasicAuth` cannot be extracted.
#[derive(Debug, Clone)]
pub struct BasicAuthError {
    /// Which check failed.
    pub kind: BasicAuthErrorKind,
}

/// Reasons Basic-auth extraction can fail.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BasicAuthErrorKind {
    /// No `Authorization` header.
    MissingHeader,
    /// The scheme is not `Basic`.
    InvalidScheme,
    /// Encoding problem (the extractor uses this when the header is not UTF-8).
    InvalidEncoding,
    /// Credentials missing, too long, not decodable, or without a `:` separator.
    InvalidFormat,
}

impl BasicAuthError {
    /// Error for a missing `Authorization` header.
    #[must_use]
    pub fn missing_header() -> Self {
        let kind = BasicAuthErrorKind::MissingHeader;
        Self { kind }
    }

    /// Error for a non-`Basic` scheme.
    #[must_use]
    pub fn invalid_scheme() -> Self {
        let kind = BasicAuthErrorKind::InvalidScheme;
        Self { kind }
    }

    /// Error for an undecodable header.
    #[must_use]
    pub fn invalid_encoding() -> Self {
        let kind = BasicAuthErrorKind::InvalidEncoding;
        Self { kind }
    }

    /// Error for malformed credentials.
    #[must_use]
    pub fn invalid_format() -> Self {
        let kind = BasicAuthErrorKind::InvalidFormat;
        Self { kind }
    }
}

impl fmt::Display for BasicAuthError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let text = match self.kind {
            BasicAuthErrorKind::MissingHeader => "Missing Authorization header",
            BasicAuthErrorKind::InvalidScheme => "Authorization header must use Basic scheme",
            BasicAuthErrorKind::InvalidEncoding => {
                "Invalid base64 encoding in Authorization header"
            }
            BasicAuthErrorKind::InvalidFormat => "Invalid format in Basic auth credentials",
        };
        f.write_str(text)
    }
}
5170
5171impl IntoResponse for BasicAuthError {
5172 fn into_response(self) -> crate::response::Response {
5173 use crate::response::{Response, ResponseBody, StatusCode};
5174
5175 let message = match self.kind {
5176 BasicAuthErrorKind::MissingHeader => "Not authenticated",
5177 BasicAuthErrorKind::InvalidScheme => "Invalid authentication credentials",
5178 BasicAuthErrorKind::InvalidEncoding => "Invalid authentication credentials",
5179 BasicAuthErrorKind::InvalidFormat => "Invalid authentication credentials",
5180 };
5181
5182 let body = serde_json::json!({
5183 "detail": message
5184 });
5185
5186 Response::with_status(StatusCode::UNAUTHORIZED)
5187 .header("www-authenticate", b"Basic realm=\"api\"".to_vec())
5188 .header("content-type", b"application/json".to_vec())
5189 .body(ResponseBody::Bytes(body.to_string().into_bytes()))
5190 }
5191}
5192
5193impl FromRequest for BasicAuth {
5194 type Error = BasicAuthError;
5195
5196 async fn from_request(_ctx: &RequestContext, req: &mut Request) -> Result<Self, Self::Error> {
5197 let auth_header = req
5199 .headers()
5200 .get("authorization")
5201 .ok_or_else(BasicAuthError::missing_header)?;
5202
5203 let auth_str =
5205 std::str::from_utf8(auth_header).map_err(|_| BasicAuthError::invalid_encoding())?;
5206
5207 let mut parts = auth_str.split_whitespace();
5208 let scheme = parts.next().ok_or_else(BasicAuthError::invalid_scheme)?;
5209 if !scheme.eq_ignore_ascii_case("basic") {
5210 return Err(BasicAuthError::invalid_scheme());
5211 }
5212
5213 let encoded = parts.next().unwrap_or("");
5214 if encoded.is_empty() {
5215 return Err(BasicAuthError::invalid_format());
5216 }
5217
5218 const MAX_ENCODED_LEN: usize = 8192;
5220 if encoded.len() > MAX_ENCODED_LEN {
5221 return Err(BasicAuthError::invalid_format());
5222 }
5223
5224 let (username, password) = BasicAuth::decode_credentials(encoded.trim())
5226 .ok_or_else(BasicAuthError::invalid_format)?;
5227
5228 Ok(BasicAuth::new(username, password))
5229 }
5230}
5231
/// A bearer token taken from `Authorization: Bearer <token>`.
///
/// `Debug` is implemented by hand and redacts the token, so `{:?}`-formatting the
/// extractor (e.g. by tracing instrumentation) never leaks it into logs.
#[derive(Clone)]
pub struct BearerToken {
    token: String,
}

impl fmt::Debug for BearerToken {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("BearerToken")
            .field("token", &"<redacted>")
            .finish()
    }
}

impl BearerToken {
    /// Wraps an already-extracted token.
    #[must_use]
    pub fn new(token: impl Into<String>) -> Self {
        Self {
            token: token.into(),
        }
    }

    /// Borrows the raw token.
    #[must_use]
    pub fn token(&self) -> &str {
        &self.token
    }
}
5260
/// Error returned when a `BearerToken` cannot be extracted.
#[derive(Debug, Clone)]
pub struct BearerTokenError {
    /// Which check failed.
    pub kind: BearerTokenErrorKind,
}

/// Reasons bearer-token extraction can fail.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BearerTokenErrorKind {
    /// No `Authorization` header.
    MissingHeader,
    /// The header is not valid UTF-8.
    InvalidUtf8,
    /// The scheme is not `Bearer`.
    InvalidScheme,
    /// No token after `Bearer` (the extractor also uses this for over-long tokens).
    EmptyToken,
}

impl BearerTokenError {
    /// Error for a missing `Authorization` header.
    #[must_use]
    pub fn missing_header() -> Self {
        let kind = BearerTokenErrorKind::MissingHeader;
        Self { kind }
    }

    /// Error for a non-UTF-8 header.
    #[must_use]
    pub fn invalid_utf8() -> Self {
        let kind = BearerTokenErrorKind::InvalidUtf8;
        Self { kind }
    }

    /// Error for a non-`Bearer` scheme.
    #[must_use]
    pub fn invalid_scheme() -> Self {
        let kind = BearerTokenErrorKind::InvalidScheme;
        Self { kind }
    }

    /// Error for a missing token.
    #[must_use]
    pub fn empty_token() -> Self {
        let kind = BearerTokenErrorKind::EmptyToken;
        Self { kind }
    }
}

impl fmt::Display for BearerTokenError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let text = match self.kind {
            BearerTokenErrorKind::MissingHeader => "Missing Authorization header",
            BearerTokenErrorKind::InvalidUtf8 => "Invalid Authorization header encoding",
            BearerTokenErrorKind::InvalidScheme => "Authorization header must use Bearer scheme",
            BearerTokenErrorKind::EmptyToken => "Bearer token is empty",
        };
        f.write_str(text)
    }
}
5323
5324impl IntoResponse for BearerTokenError {
5325 fn into_response(self) -> crate::response::Response {
5326 use crate::response::{Response, ResponseBody, StatusCode};
5327
5328 let detail = match self.kind {
5330 BearerTokenErrorKind::MissingHeader => "Not authenticated",
5331 BearerTokenErrorKind::InvalidUtf8 => "Invalid authentication credentials",
5332 BearerTokenErrorKind::InvalidScheme => "Invalid authentication credentials",
5333 BearerTokenErrorKind::EmptyToken => "Invalid authentication credentials",
5334 };
5335
5336 let body = serde_json::json!({ "detail": detail });
5337 Response::with_status(StatusCode::UNAUTHORIZED)
5338 .header("www-authenticate", b"Bearer".to_vec())
5339 .header("content-type", b"application/json".to_vec())
5340 .body(ResponseBody::Bytes(body.to_string().into_bytes()))
5341 }
5342}
5343
5344impl FromRequest for BearerToken {
5345 type Error = BearerTokenError;
5346
5347 async fn from_request(_ctx: &RequestContext, req: &mut Request) -> Result<Self, Self::Error> {
5348 let auth_header = req
5349 .headers()
5350 .get("authorization")
5351 .ok_or_else(BearerTokenError::missing_header)?;
5352
5353 let auth_str =
5354 std::str::from_utf8(auth_header).map_err(|_| BearerTokenError::invalid_utf8())?;
5355
5356 let mut parts = auth_str.split_whitespace();
5357 let scheme = parts.next().ok_or_else(BearerTokenError::invalid_scheme)?;
5358 if !scheme.eq_ignore_ascii_case("bearer") {
5359 return Err(BearerTokenError::invalid_scheme());
5360 }
5361
5362 let token = parts.next().unwrap_or("").trim();
5363 if token.is_empty() {
5364 return Err(BearerTokenError::empty_token());
5365 }
5366
5367 const MAX_TOKEN_LEN: usize = 8192;
5369 if token.len() > MAX_TOKEN_LEN {
5370 return Err(BearerTokenError::empty_token());
5371 }
5372
5373 Ok(BearerToken::new(token.to_string()))
5374 }
5375}
5376
/// Where an API key is carried in the request.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ApiKeyLocation {
    /// In a request header.
    Header,
    /// In a query-string parameter.
    Query,
    /// In a cookie.
    Cookie,
}
5391
/// An API key extracted from the request.
///
/// `Debug` is implemented by hand and redacts the key, so `{:?}`-formatting the
/// extractor never writes the secret to logs.
#[derive(Clone)]
pub struct ApiKey {
    /// The raw API key.
    pub key: String,
}

impl fmt::Debug for ApiKey {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("ApiKey")
            .field("key", &"<redacted>")
            .finish()
    }
}

impl ApiKey {
    /// Wraps an already-extracted key.
    #[must_use]
    pub fn new(key: impl Into<String>) -> Self {
        Self { key: key.into() }
    }

    /// Borrows the key.
    #[must_use]
    pub fn key(&self) -> &str {
        &self.key
    }

    /// Consumes the extractor and returns the owned key.
    #[must_use]
    pub fn into_key(self) -> String {
        self.key
    }
}

impl Deref for ApiKey {
    type Target = str;

    fn deref(&self) -> &Self::Target {
        &self.key
    }
}
5454
5455#[derive(Debug, Clone)]
5457pub struct ApiKeyConfig {
5458 pub name: String,
5460 pub location: ApiKeyLocation,
5462 pub description: Option<String>,
5464}
5465
5466impl Default for ApiKeyConfig {
5467 fn default() -> Self {
5468 Self {
5469 name: "X-API-Key".to_string(),
5470 location: ApiKeyLocation::Header,
5471 description: None,
5472 }
5473 }
5474}
5475
5476impl ApiKeyConfig {
5477 #[must_use]
5479 pub fn header(name: impl Into<String>) -> Self {
5480 Self {
5481 name: name.into(),
5482 location: ApiKeyLocation::Header,
5483 description: None,
5484 }
5485 }
5486
5487 #[must_use]
5489 pub fn query(name: impl Into<String>) -> Self {
5490 Self {
5491 name: name.into(),
5492 location: ApiKeyLocation::Query,
5493 description: None,
5494 }
5495 }
5496
5497 #[must_use]
5499 pub fn cookie(name: impl Into<String>) -> Self {
5500 Self {
5501 name: name.into(),
5502 location: ApiKeyLocation::Cookie,
5503 description: None,
5504 }
5505 }
5506
5507 #[must_use]
5509 pub fn with_description(mut self, desc: impl Into<String>) -> Self {
5510 self.description = Some(desc.into());
5511 self
5512 }
5513}
5514
5515#[derive(Debug, Clone)]
5517pub struct ApiKeyError {
5518 pub kind: ApiKeyErrorKind,
5520 pub location: ApiKeyLocation,
5522 pub name: String,
5524}
5525
5526#[derive(Debug, Clone, Copy, PartialEq, Eq)]
5528pub enum ApiKeyErrorKind {
5529 Missing,
5531 Empty,
5533}
5534
5535impl ApiKeyError {
5536 #[must_use]
5538 pub fn missing(location: ApiKeyLocation, name: impl Into<String>) -> Self {
5539 Self {
5540 kind: ApiKeyErrorKind::Missing,
5541 location,
5542 name: name.into(),
5543 }
5544 }
5545
5546 #[must_use]
5548 pub fn empty(location: ApiKeyLocation, name: impl Into<String>) -> Self {
5549 Self {
5550 kind: ApiKeyErrorKind::Empty,
5551 location,
5552 name: name.into(),
5553 }
5554 }
5555}
5556
5557impl fmt::Display for ApiKeyError {
5558 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
5559 let location_name = match self.location {
5560 ApiKeyLocation::Header => "header",
5561 ApiKeyLocation::Query => "query parameter",
5562 ApiKeyLocation::Cookie => "cookie",
5563 };
5564 match self.kind {
5565 ApiKeyErrorKind::Missing => {
5566 write!(f, "Missing API key in {} '{}'", location_name, self.name)
5567 }
5568 ApiKeyErrorKind::Empty => {
5569 write!(f, "Empty API key in {} '{}'", location_name, self.name)
5570 }
5571 }
5572 }
5573}
5574
5575impl IntoResponse for ApiKeyError {
5576 fn into_response(self) -> crate::response::Response {
5577 use crate::response::{Response, ResponseBody, StatusCode};
5578
5579 let body = serde_json::json!({
5580 "detail": "Not authenticated"
5581 });
5582
5583 Response::with_status(StatusCode::UNAUTHORIZED)
5584 .header("content-type", b"application/json".to_vec())
5585 .body(ResponseBody::Bytes(body.to_string().into_bytes()))
5586 }
5587}
5588
5589impl FromRequest for ApiKey {
5590 type Error = ApiKeyError;
5591
5592 async fn from_request(_ctx: &RequestContext, req: &mut Request) -> Result<Self, Self::Error> {
5593 let name = "X-API-Key";
5595 let location = ApiKeyLocation::Header;
5596
5597 let key = req
5598 .headers()
5599 .get("x-api-key")
5600 .and_then(|v| std::str::from_utf8(v).ok())
5601 .map(|s| s.trim().to_string())
5602 .ok_or_else(|| ApiKeyError::missing(location, name))?;
5603
5604 if key.is_empty() {
5605 return Err(ApiKeyError::empty(location, name));
5606 }
5607
5608 Ok(ApiKey::new(key))
5609 }
5610}
5611
/// Type-level cookie name used by [`Cookie`].
pub trait CookieName {
    /// Cookie name, compared exactly against names in the `Cookie` header.
    const NAME: &'static str;
}

/// A cookie extractor whose cookie name is fixed at the type level by `N`.
#[derive(Debug, Clone)]
pub struct Cookie<T, N> {
    /// The parsed cookie value.
    pub value: T,
    _marker: std::marker::PhantomData<N>,
}

impl<T, N> Cookie<T, N> {
    /// Wraps an already-parsed cookie value.
    #[must_use]
    pub fn new(value: T) -> Self {
        let _marker = std::marker::PhantomData;
        Self { value, _marker }
    }

    /// Unwraps the parsed value.
    #[must_use]
    pub fn into_inner(self) -> T {
        let Self { value, .. } = self;
        value
    }
}

impl<T, N> Deref for Cookie<T, N> {
    type Target = T;

    fn deref(&self) -> &T {
        let Self { value, .. } = self;
        value
    }
}

impl<T, N> DerefMut for Cookie<T, N> {
    fn deref_mut(&mut self) -> &mut T {
        let Self { value, .. } = self;
        value
    }
}
5695
/// Error returned when a typed cookie cannot be extracted.
#[derive(Debug, Clone)]
pub struct CookieExtractError {
    /// Cookie name that was looked up.
    pub name: String,
    /// Which check failed.
    pub kind: CookieExtractErrorKind,
}

/// Reasons cookie extraction can fail.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CookieExtractErrorKind {
    /// The request had no `Cookie` header.
    NoCookieHeader,
    /// The named cookie was not present.
    NotFound,
    /// The cookie was present with an empty value.
    Empty,
    /// The cookie value could not be parsed into the requested type.
    ParseError,
}

impl CookieExtractError {
    /// Error for a request without a `Cookie` header.
    #[must_use]
    pub fn no_header(name: impl Into<String>) -> Self {
        let name = name.into();
        Self {
            name,
            kind: CookieExtractErrorKind::NoCookieHeader,
        }
    }

    /// Error for an absent cookie.
    #[must_use]
    pub fn not_found(name: impl Into<String>) -> Self {
        let name = name.into();
        Self {
            name,
            kind: CookieExtractErrorKind::NotFound,
        }
    }

    /// Error for an empty cookie value.
    #[must_use]
    pub fn empty(name: impl Into<String>) -> Self {
        let name = name.into();
        Self {
            name,
            kind: CookieExtractErrorKind::Empty,
        }
    }

    /// Error for an unparseable cookie value.
    #[must_use]
    pub fn parse_error(name: impl Into<String>) -> Self {
        let name = name.into();
        Self {
            name,
            kind: CookieExtractErrorKind::ParseError,
        }
    }
}

impl fmt::Display for CookieExtractError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = &self.name;
        match self.kind {
            CookieExtractErrorKind::NoCookieHeader => f.write_str("No Cookie header in request"),
            CookieExtractErrorKind::NotFound => write!(f, "Cookie '{name}' not found"),
            CookieExtractErrorKind::Empty => write!(f, "Cookie '{name}' is empty"),
            CookieExtractErrorKind::ParseError => write!(f, "Failed to parse cookie '{name}'"),
        }
    }
}
5774
5775impl IntoResponse for CookieExtractError {
5776 fn into_response(self) -> crate::response::Response {
5777 use crate::response::{Response, ResponseBody, StatusCode};
5778
5779 let body = serde_json::json!({
5780 "detail": [{
5781 "type": "missing",
5782 "loc": ["cookie", &self.name],
5783 "msg": format!("Cookie '{}' is required", self.name),
5784 }]
5785 });
5786
5787 Response::with_status(StatusCode::UNPROCESSABLE_ENTITY)
5788 .header("content-type", b"application/json".to_vec())
5789 .body(ResponseBody::Bytes(body.to_string().into_bytes()))
5790 }
5791}
5792
/// Splits a `Cookie` header into `(name, value)` pairs.
///
/// Pairs are separated by `;` and split at the first `=`. Both sides are trimmed.
/// Segments without `=` are skipped, and values keep any further `=` characters.
fn parse_cookies(header: &str) -> impl Iterator<Item = (&str, &str)> {
    header
        .split(';')
        .map(str::trim)
        .filter_map(|pair| pair.split_once('='))
        .map(|(name, value)| (name.trim(), value.trim()))
}
5803
5804impl<T, N> FromRequest for Cookie<T, N>
5805where
5806 T: FromHeaderValue + Send + Sync + 'static,
5807 N: CookieName + Send + Sync + 'static,
5808{
5809 type Error = CookieExtractError;
5810
5811 async fn from_request(_ctx: &RequestContext, req: &mut Request) -> Result<Self, Self::Error> {
5812 let cookie_name = N::NAME;
5813
5814 let cookie_header = req
5816 .headers()
5817 .get("cookie")
5818 .ok_or_else(|| CookieExtractError::no_header(cookie_name))?;
5819
5820 let cookie_str = std::str::from_utf8(cookie_header)
5821 .map_err(|_| CookieExtractError::not_found(cookie_name))?;
5822
5823 for (name, value) in parse_cookies(cookie_str) {
5825 if name == cookie_name {
5826 if value.is_empty() {
5827 return Err(CookieExtractError::empty(cookie_name));
5828 }
5829 let parsed = T::from_header_value(value)
5830 .map_err(|_| CookieExtractError::parse_error(cookie_name))?;
5831 return Ok(Cookie::new(parsed));
5832 }
5833 }
5834
5835 Err(CookieExtractError::not_found(cookie_name))
5836 }
5837}
5838
5839pub struct SessionId;
5842impl CookieName for SessionId {
5843 const NAME: &'static str = "session_id";
5844}
5845
5846pub struct CsrfToken;
5848impl CookieName for CsrfToken {
5849 const NAME: &'static str = "csrf_token";
5850}
5851
5852pub struct CsrfTokenCookie;
5854impl CookieName for CsrfTokenCookie {
5855 const NAME: &'static str = "csrf_token";
5856}
5857
/// Extractor that deserializes an `application/x-www-form-urlencoded` request
/// body into `T`.
///
/// The inner value is public, so handlers may destructure it (`Form(data)`)
/// or call `into_inner`.
#[derive(Debug, Clone)]
pub struct Form<T>(pub T);

impl<T> Form<T> {
    /// Wraps an already-deserialized value.
    #[must_use]
    pub fn new(value: T) -> Self {
        Form(value)
    }

    /// Consumes the wrapper, yielding the inner value.
    #[must_use]
    pub fn into_inner(self) -> T {
        let Form(inner) = self;
        inner
    }
}
5912
5913impl<T> Deref for Form<T> {
5914 type Target = T;
5915
5916 fn deref(&self) -> &Self::Target {
5917 &self.0
5918 }
5919}
5920
5921impl<T> DerefMut for Form<T> {
5922 fn deref_mut(&mut self) -> &mut Self::Target {
5923 &mut self.0
5924 }
5925}
5926
/// Error returned when the `Form<T>` extractor fails.
#[derive(Debug)]
pub struct FormExtractError {
    /// The specific reason for the failure.
    pub kind: FormExtractErrorKind,
}
5933
/// The specific reason a `Form<T>` extraction failed.
///
/// Each variant is mapped to a different HTTP status by
/// `FormExtractError`'s `IntoResponse` implementation. The enum is
/// `Clone`/`PartialEq`/`Eq` so errors can be compared directly (e.g. with
/// `assert_eq!` in tests).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FormExtractErrorKind {
    /// The `content-type` header was missing or not
    /// `application/x-www-form-urlencoded`.
    WrongContentType {
        /// The content type the client sent, if present and valid UTF-8.
        actual: Option<String>,
    },
    /// The body could not be read (stream failure or invalid UTF-8).
    ReadError(String),
    /// The body exceeded the size limit.
    PayloadTooLarge {
        /// Bytes received (or declared via `Content-Length`) when the limit
        /// tripped.
        size: usize,
        /// The limit that was exceeded, in bytes.
        limit: usize,
    },
    /// The decoded key/value pairs did not deserialize into the target type.
    DeserializeError(String),
}
5949
5950impl FormExtractError {
5951 #[must_use]
5953 pub fn wrong_content_type(actual: Option<String>) -> Self {
5954 Self {
5955 kind: FormExtractErrorKind::WrongContentType { actual },
5956 }
5957 }
5958
5959 #[must_use]
5961 pub fn read_error(msg: impl Into<String>) -> Self {
5962 Self {
5963 kind: FormExtractErrorKind::ReadError(msg.into()),
5964 }
5965 }
5966
5967 #[must_use]
5969 pub fn payload_too_large(size: usize, limit: usize) -> Self {
5970 Self {
5971 kind: FormExtractErrorKind::PayloadTooLarge { size, limit },
5972 }
5973 }
5974
5975 #[must_use]
5977 pub fn deserialize_error(msg: impl Into<String>) -> Self {
5978 Self {
5979 kind: FormExtractErrorKind::DeserializeError(msg.into()),
5980 }
5981 }
5982}
5983
5984impl fmt::Display for FormExtractError {
5985 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
5986 match &self.kind {
5987 FormExtractErrorKind::WrongContentType { actual } => {
5988 if let Some(ct) = actual {
5989 write!(
5990 f,
5991 "Expected content-type 'application/x-www-form-urlencoded', got '{}'",
5992 ct
5993 )
5994 } else {
5995 write!(
5996 f,
5997 "Expected content-type 'application/x-www-form-urlencoded', none provided"
5998 )
5999 }
6000 }
6001 FormExtractErrorKind::ReadError(msg) => {
6002 write!(f, "Failed to read form body: {}", msg)
6003 }
6004 FormExtractErrorKind::PayloadTooLarge { size, limit } => {
6005 write!(
6006 f,
6007 "Request body too large: {size} bytes exceeds {limit} byte limit"
6008 )
6009 }
6010 FormExtractErrorKind::DeserializeError(msg) => {
6011 write!(f, "Failed to deserialize form data: {}", msg)
6012 }
6013 }
6014 }
6015}
6016
6017impl IntoResponse for FormExtractError {
6018 fn into_response(self) -> crate::response::Response {
6019 use crate::response::{Response, ResponseBody, StatusCode};
6020
6021 let (status, detail) = match &self.kind {
6022 FormExtractErrorKind::WrongContentType { .. } => {
6023 (StatusCode::UNSUPPORTED_MEDIA_TYPE, self.to_string())
6024 }
6025 FormExtractErrorKind::ReadError(_) => (StatusCode::BAD_REQUEST, self.to_string()),
6026 FormExtractErrorKind::PayloadTooLarge { .. } => {
6027 (StatusCode::PAYLOAD_TOO_LARGE, self.to_string())
6028 }
6029 FormExtractErrorKind::DeserializeError(msg) => {
6030 (StatusCode::UNPROCESSABLE_ENTITY, msg.clone())
6031 }
6032 };
6033
6034 let body = serde_json::json!({
6035 "detail": detail
6036 });
6037
6038 Response::with_status(status)
6039 .header("content-type", b"application/json".to_vec())
6040 .body(ResponseBody::Bytes(body.to_string().into_bytes()))
6041 }
6042}
6043
/// Parses an `application/x-www-form-urlencoded` string into decoded
/// `(key, value)` pairs, in input order.
///
/// Follows the WHATWG urlencoded parser:
/// - the input is split on `&`;
/// - empty segments (from `&&`, a leading or trailing `&`, or an empty input)
///   are skipped;
/// - each remaining segment is split at its first `=`, and a missing `=`
///   yields an empty value;
/// - both halves are decoded with `url_decode`.
fn parse_urlencoded(data: &str) -> impl Iterator<Item = (String, String)> + '_ {
    data.split('&')
        .filter(|segment| !segment.is_empty())
        .map(|segment| {
            let (key, value) = segment.split_once('=').unwrap_or((segment, ""));
            (url_decode(key), url_decode(value))
        })
}

/// Decodes a single form-urlencoded component.
///
/// - `+` becomes a space.
/// - `%XX` (two hex digits, either case) becomes the byte `0xXX`.
/// - A `%` not followed by two hex digits is kept literally. The bytes after
///   it are processed normally rather than swallowed, as in the WHATWG
///   percent-decode algorithm.
/// - Decoded bytes that are not valid UTF-8 are replaced with U+FFFD.
fn url_decode(input: &str) -> String {
    // Strict hex-digit parser: unlike `u8::from_str_radix`, it rejects a
    // leading sign such as `+`.
    fn hex_value(byte: u8) -> Option<u8> {
        match byte {
            b'0'..=b'9' => Some(byte - b'0'),
            b'a'..=b'f' => Some(byte - b'a' + 10),
            b'A'..=b'F' => Some(byte - b'A' + 10),
            _ => None,
        }
    }

    let src = input.as_bytes();
    let mut out = Vec::with_capacity(src.len());
    let mut i = 0;
    while i < src.len() {
        match src[i] {
            b'%' => {
                // Peek at the next two bytes without consuming them, so that a
                // malformed escape leaves them to be decoded on their own
                // (e.g. "%%41" -> "%A").
                let decoded = src
                    .get(i + 1..i + 3)
                    .and_then(|pair| Some((hex_value(pair[0])? << 4) | hex_value(pair[1])?));
                if let Some(byte) = decoded {
                    out.push(byte);
                    i += 3;
                    continue;
                }
                out.push(b'%');
            }
            b'+' => out.push(b' '),
            other => out.push(other),
        }
        i += 1;
    }

    String::from_utf8(out).unwrap_or_else(|e| String::from_utf8_lossy(e.as_bytes()).into_owned())
}
6102
6103impl<T> FromRequest for Form<T>
6104where
6105 T: serde::de::DeserializeOwned + Send + Sync + 'static,
6106{
6107 type Error = FormExtractError;
6108
6109 async fn from_request(ctx: &RequestContext, req: &mut Request) -> Result<Self, Self::Error> {
6110 let content_type = req
6112 .headers()
6113 .get("content-type")
6114 .and_then(|v| std::str::from_utf8(v).ok());
6115
6116 let is_form = content_type
6117 .map(|ct| {
6118 ct.starts_with("application/x-www-form-urlencoded")
6119 || ct.starts_with("application/x-www-form-urlencoded;")
6120 })
6121 .unwrap_or(false);
6122
6123 if !is_form {
6124 return Err(FormExtractError::wrong_content_type(
6125 content_type.map(String::from),
6126 ));
6127 }
6128
6129 let limit = DEFAULT_JSON_LIMIT;
6131 let body = collect_body_limited(ctx, req.take_body(), limit)
6132 .await
6133 .map_err(|e| match e {
6134 RequestBodyStreamError::TooLarge { received, .. } => {
6135 FormExtractError::payload_too_large(received, limit)
6136 }
6137 other => FormExtractError::read_error(other.to_string()),
6138 })?;
6139 let body_str = std::str::from_utf8(&body)
6140 .map_err(|e| FormExtractError::read_error(format!("Invalid UTF-8: {}", e)))?;
6141
6142 let pairs: Vec<(String, String)> = parse_urlencoded(body_str).collect();
6144
6145 let mut map = serde_json::Map::new();
6148 for (key, value) in pairs {
6149 if key.ends_with("[]") {
6151 let base_key = &key[..key.len() - 2];
6152 let entry = map
6153 .entry(base_key.to_string())
6154 .or_insert_with(|| serde_json::Value::Array(Vec::new()));
6155 if let serde_json::Value::Array(arr) = entry {
6156 arr.push(serde_json::Value::String(value));
6157 }
6158 } else {
6159 map.insert(key, serde_json::Value::String(value));
6160 }
6161 }
6162
6163 let json_value = serde_json::Value::Object(map);
6164 let result: T = serde_json::from_value(json_value)
6165 .map_err(|e| FormExtractError::deserialize_error(e.to_string()))?;
6166
6167 Ok(Form(result))
6168 }
6169}
6170
/// Holds several parsed header values of type `T`.
///
/// `N` is a type-level marker (presumably the header name, as with
/// `NamedHeader` — the extractor is not in view) and is never stored.
///
/// `Clone` and `Debug` are written by hand so they only require
/// `T: Clone` / `T: Debug`. `#[derive]` would also demand those bounds on the
/// phantom marker `N`, and plain marker structs (such as the cookie-name
/// markers in this file) implement neither.
pub struct HeaderValues<T, N> {
    /// The parsed header values.
    pub values: Vec<T>,
    _marker: std::marker::PhantomData<N>,
}

impl<T: Clone, N> Clone for HeaderValues<T, N> {
    fn clone(&self) -> Self {
        Self {
            values: self.values.clone(),
            _marker: std::marker::PhantomData,
        }
    }
}

impl<T: fmt::Debug, N> fmt::Debug for HeaderValues<T, N> {
    // Same output as the former derived impl.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("HeaderValues")
            .field("values", &self.values)
            .field("_marker", &self._marker)
            .finish()
    }
}
6180
6181impl<T, N> HeaderValues<T, N> {
6182 #[must_use]
6184 pub fn new(values: Vec<T>) -> Self {
6185 Self {
6186 values,
6187 _marker: std::marker::PhantomData,
6188 }
6189 }
6190
6191 #[must_use]
6193 pub fn is_empty(&self) -> bool {
6194 self.values.is_empty()
6195 }
6196
6197 #[must_use]
6199 pub fn len(&self) -> usize {
6200 self.values.len()
6201 }
6202}
6203
6204impl<T, N> Deref for HeaderValues<T, N> {
6205 type Target = Vec<T>;
6206
6207 fn deref(&self) -> &Self::Target {
6208 &self.values
6209 }
6210}
6211
/// Extractor wrapper that first extracts `T`, then validates it.
///
/// `T` must itself implement `FromRequest` and dereference to a type that
/// implements `Validate` (for example `Valid<Form<MyStruct>>`). The extracted
/// `T` is kept unchanged.
#[derive(Debug, Clone, Copy)]
pub struct Valid<T>(pub T);
6248
6249impl<T> Valid<T> {
6250 pub fn into_inner(self) -> T {
6252 self.0
6253 }
6254}
6255
6256impl<T> Deref for Valid<T> {
6257 type Target = T;
6258
6259 fn deref(&self) -> &Self::Target {
6260 &self.0
6261 }
6262}
6263
6264impl<T> DerefMut for Valid<T> {
6265 fn deref_mut(&mut self) -> &mut Self::Target {
6266 &mut self.0
6267 }
6268}
6269
/// Error produced by the `Valid<T>` extractor.
#[derive(Debug)]
pub enum ValidExtractError<E> {
    /// The wrapped extractor failed; carries its error unchanged.
    Extract(E),
    /// Extraction succeeded but `Validate::validate` rejected the value.
    Validation(Box<ValidationErrors>),
}
6278
6279impl<E: std::fmt::Display> std::fmt::Display for ValidExtractError<E> {
6280 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
6281 match self {
6282 Self::Extract(e) => write!(f, "Extraction failed: {e}"),
6283 Self::Validation(e) => write!(f, "{e}"),
6284 }
6285 }
6286}
6287
6288impl<E: std::error::Error + 'static> std::error::Error for ValidExtractError<E> {
6289 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
6290 match self {
6291 Self::Extract(e) => Some(e),
6292 Self::Validation(e) => Some(&**e),
6293 }
6294 }
6295}
6296
6297impl<E: IntoResponse> IntoResponse for ValidExtractError<E> {
6298 fn into_response(self) -> crate::response::Response {
6299 match self {
6300 Self::Extract(e) => e.into_response(),
6301 Self::Validation(e) => (*e).into_response(),
6302 }
6303 }
6304}
6305
6306pub use crate::validation::Validate;
6311
6312impl<T> FromRequest for Valid<T>
6313where
6314 T: FromRequest,
6315 T::Error: IntoResponse,
6316 <T as Deref>::Target: Validate,
6317 T: Deref,
6318{
6319 type Error = ValidExtractError<T::Error>;
6320
6321 async fn from_request(ctx: &RequestContext, req: &mut Request) -> Result<Self, Self::Error> {
6322 let inner = T::from_request(ctx, req)
6324 .await
6325 .map_err(ValidExtractError::Extract)?;
6326
6327 inner.validate().map_err(ValidExtractError::Validation)?;
6329
6330 Ok(Valid(inner))
6331 }
6332}
6333
// Tests for header-name casing (`snake_to_header_case`), `Header`,
// `FromHeaderValue` conversions, and the `NamedHeader` /
// `Option<NamedHeader>` extractors.
#[cfg(test)]
mod header_tests {
    use super::*;
    use crate::request::Method;

    /// Builds a `RequestContext` from a testing `Cx` and a fixed id (12345).
    fn test_context() -> RequestContext {
        let cx = asupersync::Cx::for_testing();
        RequestContext::new(cx, 12345)
    }

    #[test]
    fn snake_to_header_case_simple() {
        assert_eq!(snake_to_header_case("authorization"), "Authorization");
        assert_eq!(snake_to_header_case("content_type"), "Content-Type");
        assert_eq!(snake_to_header_case("x_request_id"), "X-Request-Id");
        assert_eq!(snake_to_header_case("accept"), "Accept");
    }

    #[test]
    fn snake_to_header_case_edge_cases() {
        assert_eq!(snake_to_header_case(""), "");
        assert_eq!(snake_to_header_case("a"), "A");
        assert_eq!(snake_to_header_case("a_b_c"), "A-B-C");
    }

    #[test]
    fn header_deref() {
        let header = Header::new("test", "value".to_string());
        assert_eq!(*header, "value");
    }

    #[test]
    fn header_into_inner() {
        let header = Header::new("test", 42i32);
        assert_eq!(header.into_inner(), 42);
    }

    #[test]
    fn from_header_value_string() {
        let result = String::from_header_value("test value");
        assert_eq!(result.unwrap(), "test value");
    }

    #[test]
    fn from_header_value_i32() {
        assert_eq!(i32::from_header_value("42").unwrap(), 42);
        assert_eq!(i32::from_header_value("-1").unwrap(), -1);
        assert!(i32::from_header_value("abc").is_err());
    }

    /// `bool` accepts true/1/yes and false/0/no; anything else is an error.
    #[test]
    fn from_header_value_bool() {
        assert!(bool::from_header_value("true").unwrap());
        assert!(bool::from_header_value("1").unwrap());
        assert!(bool::from_header_value("yes").unwrap());
        assert!(!bool::from_header_value("false").unwrap());
        assert!(!bool::from_header_value("0").unwrap());
        assert!(!bool::from_header_value("no").unwrap());
        assert!(bool::from_header_value("maybe").is_err());
    }

    #[test]
    fn named_header_extract_success() {
        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/test");
        req.headers_mut()
            .insert("authorization", b"Bearer token123".to_vec());

        let result = futures_executor::block_on(
            NamedHeader::<String, Authorization>::from_request(&ctx, &mut req),
        );
        let header = result.unwrap();
        assert_eq!(header.value, "Bearer token123");
    }

    #[test]
    fn named_header_extract_i32() {
        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/test");
        req.headers_mut().insert("x-request-id", b"12345".to_vec());

        let result = futures_executor::block_on(NamedHeader::<i32, XRequestId>::from_request(
            &ctx, &mut req,
        ));
        let header = result.unwrap();
        assert_eq!(header.value, 12345);
    }

    #[test]
    fn named_header_missing() {
        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/test");
        // No `authorization` header is inserted.
        let result = futures_executor::block_on(
            NamedHeader::<String, Authorization>::from_request(&ctx, &mut req),
        );
        assert!(matches!(
            result,
            Err(HeaderExtractError::MissingHeader { .. })
        ));
    }

    #[test]
    fn named_header_parse_error() {
        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/test");
        req.headers_mut()
            .insert("x-request-id", b"not-a-number".to_vec());

        let result = futures_executor::block_on(NamedHeader::<i32, XRequestId>::from_request(
            &ctx, &mut req,
        ));
        assert!(matches!(result, Err(HeaderExtractError::ParseError { .. })));
    }

    /// Error messages mention the header name (and, for parse errors, the
    /// offending value).
    #[test]
    fn header_error_display() {
        let err = HeaderExtractError::MissingHeader {
            name: "Authorization".to_string(),
        };
        assert!(err.to_string().contains("Authorization"));

        let err = HeaderExtractError::ParseError {
            name: "X-Count".to_string(),
            value: "abc".to_string(),
            expected: "i32",
            message: "invalid digit".to_string(),
        };
        assert!(err.to_string().contains("X-Count"));
        assert!(err.to_string().contains("abc"));
    }

    #[test]
    fn optional_header_some() {
        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/test");
        req.headers_mut()
            .insert("authorization", b"Bearer token".to_vec());

        let result = futures_executor::block_on(
            Option::<NamedHeader<String, Authorization>>::from_request(&ctx, &mut req),
        );
        let opt = result.unwrap();
        assert!(opt.is_some());
        assert_eq!(opt.unwrap().value, "Bearer token");
    }

    /// A missing header yields `Ok(None)` for the optional extractor.
    #[test]
    fn optional_header_none() {
        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/test");
        // No `authorization` header is inserted.
        let result = futures_executor::block_on(
            Option::<NamedHeader<String, Authorization>>::from_request(&ctx, &mut req),
        );
        let opt = result.unwrap();
        assert!(opt.is_none());
    }
}
6495
// Tests for the `OAuth2PasswordBearer` extractor, its error type, and the
// `OAuth2PasswordBearerConfig` builder.
#[cfg(test)]
mod oauth2_tests {
    use super::*;
    use crate::request::Method;
    use crate::response::IntoResponse;

    /// Builds a `RequestContext` from a testing `Cx` and a fixed id (12345).
    fn test_context() -> RequestContext {
        let cx = asupersync::Cx::for_testing();
        RequestContext::new(cx, 12345)
    }

    #[test]
    fn oauth2_extract_valid_bearer_token() {
        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/api/protected");
        req.headers_mut()
            .insert("authorization", b"Bearer mytoken123".to_vec());

        let result = futures_executor::block_on(OAuth2PasswordBearer::from_request(&ctx, &mut req));
        let bearer = result.unwrap();
        assert_eq!(bearer.token(), "mytoken123");
        // The extractor also derefs to the token string.
        assert_eq!(&*bearer, "mytoken123");
    }

    /// The scheme name is matched case-insensitively.
    #[test]
    fn oauth2_extract_bearer_lowercase() {
        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/api/protected");
        req.headers_mut()
            .insert("authorization", b"bearer lowercase_token".to_vec());

        let result = futures_executor::block_on(OAuth2PasswordBearer::from_request(&ctx, &mut req));
        let bearer = result.unwrap();
        assert_eq!(bearer.token(), "lowercase_token");
    }

    #[test]
    fn oauth2_missing_header() {
        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/api/protected");
        // No `authorization` header is inserted.
        let result = futures_executor::block_on(OAuth2PasswordBearer::from_request(&ctx, &mut req));
        let err = result.unwrap_err();
        assert_eq!(err.kind, OAuth2BearerErrorKind::MissingHeader);
    }

    #[test]
    fn oauth2_wrong_scheme() {
        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/api/protected");
        req.headers_mut()
            .insert("authorization", b"Basic dXNlcjpwYXNz".to_vec());

        let result = futures_executor::block_on(OAuth2PasswordBearer::from_request(&ctx, &mut req));
        let err = result.unwrap_err();
        assert_eq!(err.kind, OAuth2BearerErrorKind::InvalidScheme);
    }

    /// The scheme followed by a single space and nothing else.
    #[test]
    fn oauth2_empty_token() {
        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/api/protected");
        req.headers_mut()
            .insert("authorization", b"Bearer ".to_vec());

        let result = futures_executor::block_on(OAuth2PasswordBearer::from_request(&ctx, &mut req));
        let err = result.unwrap_err();
        assert_eq!(err.kind, OAuth2BearerErrorKind::EmptyToken);
    }

    /// The scheme followed by several spaces must also count as an empty
    /// token. This needs more than one space to differ from
    /// `oauth2_empty_token`.
    #[test]
    fn oauth2_whitespace_only_token() {
        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/api/protected");
        req.headers_mut()
            .insert("authorization", b"Bearer    ".to_vec());

        let result = futures_executor::block_on(OAuth2PasswordBearer::from_request(&ctx, &mut req));
        let err = result.unwrap_err();
        assert_eq!(err.kind, OAuth2BearerErrorKind::EmptyToken);
    }

    /// Surrounding whitespace (several spaces on each side) is trimmed from
    /// the token.
    #[test]
    fn oauth2_token_with_spaces_trimmed() {
        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/api/protected");
        req.headers_mut()
            .insert("authorization", b"Bearer   spaced_token   ".to_vec());

        let result = futures_executor::block_on(OAuth2PasswordBearer::from_request(&ctx, &mut req));
        let bearer = result.unwrap();
        assert_eq!(bearer.token(), "spaced_token");
    }

    #[test]
    fn oauth2_optional_extraction_some() {
        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/api/maybe-protected");
        req.headers_mut()
            .insert("authorization", b"Bearer optional_token".to_vec());

        let result = futures_executor::block_on(Option::<OAuth2PasswordBearer>::from_request(
            &ctx, &mut req,
        ));
        let opt = result.unwrap();
        assert!(opt.is_some());
        assert_eq!(opt.unwrap().token(), "optional_token");
    }

    #[test]
    fn oauth2_optional_extraction_none() {
        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/api/maybe-protected");
        // No `authorization` header is inserted.
        let result = futures_executor::block_on(Option::<OAuth2PasswordBearer>::from_request(
            &ctx, &mut req,
        ));
        let opt = result.unwrap();
        assert!(opt.is_none());
    }

    #[test]
    fn oauth2_error_response_401() {
        let err = OAuth2BearerError::missing_header();
        let response = err.into_response();
        assert_eq!(response.status().as_u16(), 401);
    }

    #[test]
    fn oauth2_error_response_has_www_authenticate() {
        let err = OAuth2BearerError::missing_header();
        let response = err.into_response();

        let www_auth = response
            .headers()
            .iter()
            .find(|(name, _)| name == "www-authenticate")
            .map(|(_, value)| String::from_utf8_lossy(value).to_string());

        assert_eq!(www_auth, Some("Bearer".to_string()));
    }

    #[test]
    fn oauth2_error_display() {
        assert!(
            OAuth2BearerError::missing_header()
                .to_string()
                .contains("Missing")
        );
        assert!(
            OAuth2BearerError::invalid_scheme()
                .to_string()
                .contains("Bearer")
        );
        assert!(
            OAuth2BearerError::empty_token()
                .to_string()
                .contains("empty")
        );
    }

    #[test]
    fn oauth2_config_builder() {
        let config = OAuth2PasswordBearerConfig::new("/auth/token")
            .with_refresh_url("/auth/refresh")
            .with_scope("read", "Read access")
            .with_scope("write", "Write access")
            .with_scheme_name("MyOAuth2")
            .with_description("Custom OAuth2 scheme")
            .with_auto_error(false);

        assert_eq!(config.token_url, "/auth/token");
        assert_eq!(config.refresh_url, Some("/auth/refresh".to_string()));
        assert_eq!(config.scopes.len(), 2);
        assert_eq!(config.scopes.get("read"), Some(&"Read access".to_string()));
        assert_eq!(config.scheme_name, Some("MyOAuth2".to_string()));
        assert!(!config.auto_error);
    }

    #[test]
    fn oauth2_password_bearer_accessors() {
        let bearer = OAuth2PasswordBearer::new("test_token");
        assert_eq!(bearer.token(), "test_token");
        assert_eq!(bearer.into_token(), "test_token");
    }
}
6684
// Tests for `PathParams` and the `Path<T>` extractor (scalars, tuples,
// structs, and error reporting).
#[cfg(test)]
mod path_tests {
    use super::*;
    use crate::request::Method;
    use serde::Deserialize;

    /// Builds a `RequestContext` from a testing `Cx` and a fixed id (12345).
    fn test_context() -> RequestContext {
        let cx = asupersync::Cx::for_testing();
        RequestContext::new(cx, 12345)
    }

    /// Builds a GET `/test` request whose `PathParams` extension holds the
    /// given `(name, value)` pairs in order.
    fn request_with_params(params: Vec<(&str, &str)>) -> Request {
        let mut req = Request::new(Method::Get, "/test");
        let path_params = PathParams::from_pairs(
            params
                .into_iter()
                .map(|(k, v)| (k.to_string(), v.to_string()))
                .collect(),
        );
        req.insert_extension(path_params);
        req
    }

    #[test]
    fn path_params_get() {
        let params = PathParams::from_pairs(vec![("id".to_string(), "42".to_string())]);
        assert_eq!(params.get("id"), Some("42"));
        assert_eq!(params.get("unknown"), None);
    }

    #[test]
    fn path_params_len() {
        let params = PathParams::new();
        assert!(params.is_empty());
        assert_eq!(params.len(), 0);

        let params = PathParams::from_pairs(vec![
            ("a".to_string(), "1".to_string()),
            ("b".to_string(), "2".to_string()),
        ]);
        assert!(!params.is_empty());
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn path_extract_single_i64() {
        let ctx = test_context();
        let mut req = request_with_params(vec![("id", "42")]);

        let result = futures_executor::block_on(Path::<i64>::from_request(&ctx, &mut req));
        let Path(id) = result.unwrap();
        assert_eq!(id, 42);
    }

    #[test]
    fn path_extract_single_string() {
        let ctx = test_context();
        let mut req = request_with_params(vec![("name", "alice")]);

        let result = futures_executor::block_on(Path::<String>::from_request(&ctx, &mut req));
        let Path(name) = result.unwrap();
        assert_eq!(name, "alice");
    }

    #[test]
    fn path_extract_single_u32() {
        let ctx = test_context();
        let mut req = request_with_params(vec![("count", "100")]);

        let result = futures_executor::block_on(Path::<u32>::from_request(&ctx, &mut req));
        let Path(count) = result.unwrap();
        assert_eq!(count, 100);
    }

    /// Tuple targets are filled from the params in insertion order.
    #[test]
    fn path_extract_tuple() {
        let ctx = test_context();
        let mut req = request_with_params(vec![("user_id", "42"), ("post_id", "99")]);

        let result = futures_executor::block_on(Path::<(i64, i64)>::from_request(&ctx, &mut req));
        let Path((user_id, post_id)) = result.unwrap();
        assert_eq!(user_id, 42);
        assert_eq!(post_id, 99);
    }

    #[test]
    fn path_extract_tuple_mixed_types() {
        let ctx = test_context();
        let mut req = request_with_params(vec![("name", "alice"), ("id", "123")]);

        let result =
            futures_executor::block_on(Path::<(String, i64)>::from_request(&ctx, &mut req));
        let Path((name, id)) = result.unwrap();
        assert_eq!(name, "alice");
        assert_eq!(id, 123);
    }

    /// Struct targets are filled by parameter name.
    #[test]
    fn path_extract_struct() {
        #[derive(Deserialize, Debug, PartialEq)]
        struct UserPath {
            user_id: i64,
            post_id: i64,
        }

        let ctx = test_context();
        let mut req = request_with_params(vec![("user_id", "42"), ("post_id", "99")]);

        let result = futures_executor::block_on(Path::<UserPath>::from_request(&ctx, &mut req));
        let Path(path) = result.unwrap();
        assert_eq!(path.user_id, 42);
        assert_eq!(path.post_id, 99);
    }

    #[test]
    fn path_extract_missing_params() {
        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/test");
        // No `PathParams` extension is inserted on this request.
        let result = futures_executor::block_on(Path::<i64>::from_request(&ctx, &mut req));
        assert!(matches!(result, Err(PathExtractError::MissingPathParams)));
    }

    /// A non-numeric value for an integer target reports `InvalidValue`
    /// naming the offending parameter.
    #[test]
    fn path_extract_invalid_type() {
        let ctx = test_context();
        let mut req = request_with_params(vec![("id", "not_a_number")]);

        let result = futures_executor::block_on(Path::<i64>::from_request(&ctx, &mut req));
        assert!(matches!(
            result,
            Err(PathExtractError::InvalidValue { name, .. }) if name == "id"
        ));
    }

    #[test]
    fn path_extract_negative_for_unsigned() {
        let ctx = test_context();
        let mut req = request_with_params(vec![("count", "-5")]);

        let result = futures_executor::block_on(Path::<u32>::from_request(&ctx, &mut req));
        assert!(matches!(result, Err(PathExtractError::InvalidValue { .. })));
    }

    #[test]
    fn path_extract_f64() {
        let ctx = test_context();
        let mut req = request_with_params(vec![("price", "19.99")]);

        let result = futures_executor::block_on(Path::<f64>::from_request(&ctx, &mut req));
        let Path(price) = result.unwrap();
        assert!((price - 19.99).abs() < 0.001);
    }

    #[test]
    fn path_deref() {
        let path = Path(42i64);
        assert_eq!(*path, 42);
    }

    #[test]
    fn path_into_inner() {
        let path = Path("hello".to_string());
        assert_eq!(path.into_inner(), "hello");
    }

    /// Display output mentions the relevant parameter name, value, and
    /// expected type.
    #[test]
    fn path_error_display() {
        let err = PathExtractError::MissingPathParams;
        assert!(err.to_string().contains("not available"));

        let err = PathExtractError::MissingParam {
            name: "user_id".to_string(),
        };
        assert!(err.to_string().contains("user_id"));

        let err = PathExtractError::InvalidValue {
            name: "id".to_string(),
            value: "abc".to_string(),
            expected: "i64",
            message: "invalid digit".to_string(),
        };
        assert!(err.to_string().contains("id"));
        assert!(err.to_string().contains("abc"));
        assert!(err.to_string().contains("i64"));
    }

    #[test]
    fn path_extract_bool() {
        let ctx = test_context();
        let mut req = request_with_params(vec![("active", "true")]);

        let result = futures_executor::block_on(Path::<bool>::from_request(&ctx, &mut req));
        let Path(active) = result.unwrap();
        assert!(active);
    }

    #[test]
    fn path_extract_char() {
        let ctx = test_context();
        let mut req = request_with_params(vec![("letter", "A")]);

        let result = futures_executor::block_on(Path::<char>::from_request(&ctx, &mut req));
        let Path(letter) = result.unwrap();
        assert_eq!(letter, 'A');
    }
}
6895
6896#[cfg(test)]
6897mod query_tests {
6898 use super::*;
6899 use crate::request::Method;
6900 use serde::Deserialize;
6901
6902 fn test_context() -> RequestContext {
6904 let cx = asupersync::Cx::for_testing();
6905 RequestContext::new(cx, 12345)
6906 }
6907
6908 fn request_with_query(query: &str) -> Request {
6910 let mut req = Request::new(Method::Get, "/test");
6911 req.set_query(Some(query.to_string()));
6912 req
6913 }
6914
6915 #[test]
6916 fn query_params_parse() {
6917 let params = QueryParams::parse("a=1&b=2&c=3");
6918 assert_eq!(params.get("a"), Some("1"));
6919 assert_eq!(params.get("b"), Some("2"));
6920 assert_eq!(params.get("c"), Some("3"));
6921 assert_eq!(params.get("d"), None);
6922 }
6923
6924 #[test]
6925 fn query_params_multi_value() {
6926 let params = QueryParams::parse("tag=rust&tag=web&tag=api");
6927 assert_eq!(params.get("tag"), Some("rust")); assert_eq!(params.get_all("tag"), vec!["rust", "web", "api"]);
6929 }
6930
6931 #[test]
6932 fn query_params_percent_decode() {
6933 let params = QueryParams::parse("msg=hello%20world&name=caf%C3%A9");
6934 assert_eq!(params.get("msg"), Some("hello world"));
6935 assert_eq!(params.get("name"), Some("café"));
6936 }
6937
6938 #[test]
6939 fn query_params_plus_as_space() {
6940 let params = QueryParams::parse("msg=hello+world");
6941 assert_eq!(params.get("msg"), Some("hello world"));
6942 }
6943
6944 #[test]
6945 fn query_params_empty_value() {
6946 let params = QueryParams::parse("flag&name=alice");
6947 assert!(params.contains("flag"));
6948 assert_eq!(params.get("flag"), Some(""));
6949 assert_eq!(params.get("name"), Some("alice"));
6950 }
6951
6952 #[test]
6953 fn query_extract_struct() {
6954 #[derive(Deserialize, Debug, PartialEq)]
6955 struct SearchParams {
6956 q: String,
6957 page: i32,
6958 }
6959
6960 let ctx = test_context();
6961 let mut req = request_with_query("q=rust&page=5");
6962
6963 let result =
6964 futures_executor::block_on(Query::<SearchParams>::from_request(&ctx, &mut req));
6965 let Query(params) = result.unwrap();
6966 assert_eq!(params.q, "rust");
6967 assert_eq!(params.page, 5);
6968 }
6969
6970 #[test]
6971 fn query_extract_optional_field() {
6972 #[derive(Deserialize, Debug)]
6973 struct Params {
6974 required: String,
6975 optional: Option<i32>,
6976 }
6977
6978 let ctx = test_context();
6979
6980 let mut req = request_with_query("required=hello&optional=42");
6982 let result = futures_executor::block_on(Query::<Params>::from_request(&ctx, &mut req));
6983 let Query(params) = result.unwrap();
6984 assert_eq!(params.required, "hello");
6985 assert_eq!(params.optional, Some(42));
6986
6987 let mut req = request_with_query("required=hello");
6989 let result = futures_executor::block_on(Query::<Params>::from_request(&ctx, &mut req));
6990 let Query(params) = result.unwrap();
6991 assert_eq!(params.required, "hello");
6992 assert_eq!(params.optional, None);
6993 }
6994
6995 #[test]
6996 fn query_extract_multi_value() {
6997 #[derive(Deserialize, Debug)]
6998 struct Params {
6999 tags: Vec<String>,
7000 }
7001
7002 let ctx = test_context();
7003 let mut req = request_with_query("tags=rust&tags=web&tags=api");
7004
7005 let result = futures_executor::block_on(Query::<Params>::from_request(&ctx, &mut req));
7006 let Query(params) = result.unwrap();
7007 assert_eq!(params.tags, vec!["rust", "web", "api"]);
7008 }
7009
7010 #[test]
7011 fn query_extract_default_value() {
7012 #[derive(Deserialize, Debug)]
7013 struct Params {
7014 name: String,
7015 #[serde(default)]
7016 limit: i32,
7017 }
7018
7019 let ctx = test_context();
7020 let mut req = request_with_query("name=test");
7021
7022 let result = futures_executor::block_on(Query::<Params>::from_request(&ctx, &mut req));
7023 let Query(params) = result.unwrap();
7024 assert_eq!(params.name, "test");
7025 assert_eq!(params.limit, 0); }
7027
7028 #[test]
7029 fn query_extract_bool() {
7030 #[derive(Deserialize, Debug)]
7031 struct Params {
7032 active: bool,
7033 archived: bool,
7034 }
7035
7036 let ctx = test_context();
7037 let mut req = request_with_query("active=true&archived=false");
7038
7039 let result = futures_executor::block_on(Query::<Params>::from_request(&ctx, &mut req));
7040 let Query(params) = result.unwrap();
7041 assert!(params.active);
7042 assert!(!params.archived);
7043 }
7044
7045 #[test]
7046 fn query_extract_bool_variants() {
7047 #[derive(Deserialize, Debug)]
7048 struct Params {
7049 a: bool,
7050 b: bool,
7051 c: bool,
7052 }
7053
7054 let ctx = test_context();
7055 let mut req = request_with_query("a=1&b=yes&c=on");
7056
7057 let result = futures_executor::block_on(Query::<Params>::from_request(&ctx, &mut req));
7058 let Query(params) = result.unwrap();
7059 assert!(params.a);
7060 assert!(params.b);
7061 assert!(params.c);
7062 }
7063
7064 #[test]
7065 fn query_extract_missing_required_fails() {
7066 #[derive(Deserialize, Debug)]
7067 #[allow(dead_code)]
7068 struct Params {
7069 required: String,
7070 }
7071
7072 let ctx = test_context();
7073 let mut req = request_with_query("other=value");
7074
7075 let result = futures_executor::block_on(Query::<Params>::from_request(&ctx, &mut req));
7076 assert!(result.is_err());
7077 }
7078
7079 #[test]
7080 fn query_extract_invalid_type_fails() {
7081 #[derive(Deserialize, Debug)]
7082 #[allow(dead_code)]
7083 struct Params {
7084 count: i32,
7085 }
7086
7087 let ctx = test_context();
7088 let mut req = request_with_query("count=not_a_number");
7089
7090 let result = futures_executor::block_on(Query::<Params>::from_request(&ctx, &mut req));
7091 assert!(result.is_err());
7092 }
7093
7094 #[test]
7095 fn query_extract_empty_query() {
7096 #[derive(Deserialize, Debug, Default)]
7097 struct Params {
7098 #[serde(default)]
7099 name: String,
7100 }
7101
7102 let ctx = test_context();
7103 let mut req = request_with_query("");
7104
7105 let result = futures_executor::block_on(Query::<Params>::from_request(&ctx, &mut req));
7106 let Query(params) = result.unwrap();
7107 assert_eq!(params.name, "");
7108 }
7109
7110 #[test]
7111 fn query_extract_float() {
7112 #[derive(Deserialize, Debug)]
7113 struct Params {
7114 price: f64,
7115 }
7116
7117 let ctx = test_context();
7118 let mut req = request_with_query("price=29.99");
7119
7120 let result = futures_executor::block_on(Query::<Params>::from_request(&ctx, &mut req));
7121 let Query(params) = result.unwrap();
7122 assert!((params.price - 29.99).abs() < 0.001);
7123 }
7124
7125 #[test]
7126 fn query_deref() {
7127 #[derive(Deserialize, Debug)]
7128 struct Params {
7129 name: String,
7130 }
7131
7132 let query = Query(Params {
7133 name: "test".to_string(),
7134 });
7135 assert_eq!(query.name, "test");
7136 }
7137
7138 #[test]
7139 fn query_into_inner() {
7140 #[derive(Deserialize, Debug, PartialEq)]
7141 struct Params {
7142 value: i32,
7143 }
7144
7145 let query = Query(Params { value: 42 });
7146 assert_eq!(query.into_inner(), Params { value: 42 });
7147 }
7148
7149 #[test]
7150 fn query_error_display() {
7151 let err = QueryExtractError::MissingParam {
7152 name: "user_id".to_string(),
7153 };
7154 assert!(err.to_string().contains("user_id"));
7155
7156 let err = QueryExtractError::InvalidValue {
7157 name: "count".to_string(),
7158 value: "abc".to_string(),
7159 expected: "i32",
7160 message: "invalid digit".to_string(),
7161 };
7162 assert!(err.to_string().contains("count"));
7163 assert!(err.to_string().contains("abc"));
7164 assert!(err.to_string().contains("i32"));
7165 }
7166
7167 #[test]
7168 fn query_params_keys() {
7169 let params = QueryParams::parse("a=1&b=2&a=3&c=4");
7170 let keys: Vec<&str> = params.keys().collect();
7171 assert_eq!(keys, vec!["a", "b", "c"]); }
7173
7174 #[test]
7175 fn query_params_len() {
7176 let params = QueryParams::parse("a=1&b=2&c=3");
7177 assert_eq!(params.len(), 3);
7178 assert!(!params.is_empty());
7179
7180 let empty = QueryParams::new();
7181 assert_eq!(empty.len(), 0);
7182 assert!(empty.is_empty());
7183 }
7184}
7185
7186#[cfg(test)]
7191mod optional_tests {
7192 use super::*;
7193 use crate::request::Method;
7194
7195 fn test_context() -> RequestContext {
7196 let cx = asupersync::Cx::for_testing();
7197 RequestContext::new(cx, 99999)
7198 }
7199
7200 #[test]
7203 fn optional_json_present_valid() {
7204 use serde::Deserialize;
7205
7206 #[derive(Deserialize, PartialEq, Debug)]
7207 struct Data {
7208 value: i32,
7209 }
7210
7211 let ctx = test_context();
7212 let mut req = Request::new(Method::Post, "/test");
7213 req.headers_mut()
7214 .insert("content-type", b"application/json".to_vec());
7215 req.set_body(Body::Bytes(b"{\"value\": 42}".to_vec()));
7216
7217 let result = futures_executor::block_on(Option::<Json<Data>>::from_request(&ctx, &mut req));
7218 let Some(Json(data)) = result.unwrap() else {
7219 panic!("Expected Some");
7220 };
7221 assert_eq!(data.value, 42);
7222 }
7223
7224 #[test]
7225 fn optional_json_invalid_content_type_returns_none() {
7226 use serde::Deserialize;
7227
7228 #[derive(Deserialize)]
7229 #[allow(dead_code)]
7230 struct Data {
7231 value: i32,
7232 }
7233
7234 let ctx = test_context();
7235 let mut req = Request::new(Method::Post, "/test");
7236 req.headers_mut()
7237 .insert("content-type", b"text/plain".to_vec());
7238 req.set_body(Body::Bytes(b"{\"value\": 42}".to_vec()));
7239
7240 let result = futures_executor::block_on(Option::<Json<Data>>::from_request(&ctx, &mut req));
7241 assert!(result.unwrap().is_none());
7242 }
7243
7244 #[test]
7245 fn optional_json_missing_body_returns_none() {
7246 use serde::Deserialize;
7247
7248 #[derive(Deserialize)]
7249 #[allow(dead_code)]
7250 struct Data {
7251 value: i32,
7252 }
7253
7254 let ctx = test_context();
7255 let mut req = Request::new(Method::Post, "/test");
7256 req.headers_mut()
7257 .insert("content-type", b"application/json".to_vec());
7258 let result = futures_executor::block_on(Option::<Json<Data>>::from_request(&ctx, &mut req));
7261 assert!(result.unwrap().is_none());
7263 }
7264
7265 #[test]
7266 fn optional_json_malformed_returns_none() {
7267 use serde::Deserialize;
7268
7269 #[derive(Deserialize)]
7270 #[allow(dead_code)]
7271 struct Data {
7272 value: i32,
7273 }
7274
7275 let ctx = test_context();
7276 let mut req = Request::new(Method::Post, "/test");
7277 req.headers_mut()
7278 .insert("content-type", b"application/json".to_vec());
7279 req.set_body(Body::Bytes(b"{ not valid json }".to_vec()));
7280
7281 let result = futures_executor::block_on(Option::<Json<Data>>::from_request(&ctx, &mut req));
7282 assert!(result.unwrap().is_none());
7283 }
7284
7285 #[test]
7288 fn optional_path_present_valid() {
7289 let ctx = test_context();
7290 let mut req = Request::new(Method::Get, "/users/42");
7291 req.insert_extension(PathParams::from_pairs(vec![(
7292 "id".to_string(),
7293 "42".to_string(),
7294 )]));
7295
7296 let result = futures_executor::block_on(Option::<Path<i64>>::from_request(&ctx, &mut req));
7297 let Some(Path(id)) = result.unwrap() else {
7298 panic!("Expected Some");
7299 };
7300 assert_eq!(id, 42);
7301 }
7302
7303 #[test]
7304 fn optional_path_missing_params_returns_none() {
7305 let ctx = test_context();
7306 let mut req = Request::new(Method::Get, "/users/42");
7307 let result = futures_executor::block_on(Option::<Path<i64>>::from_request(&ctx, &mut req));
7310 assert!(result.unwrap().is_none());
7311 }
7312
7313 #[test]
7314 fn optional_path_invalid_type_returns_none() {
7315 let ctx = test_context();
7316 let mut req = Request::new(Method::Get, "/users/abc");
7317 req.insert_extension(PathParams::from_pairs(vec![(
7318 "id".to_string(),
7319 "abc".to_string(),
7320 )]));
7321
7322 let result = futures_executor::block_on(Option::<Path<i64>>::from_request(&ctx, &mut req));
7323 assert!(result.unwrap().is_none());
7324 }
7325
7326 #[test]
7329 fn optional_query_present_valid() {
7330 use serde::Deserialize;
7331
7332 #[derive(Deserialize, PartialEq, Debug)]
7333 struct Params {
7334 page: i32,
7335 }
7336
7337 let ctx = test_context();
7338 let mut req = Request::new(Method::Get, "/items");
7339 req.set_query(Some("page=5".to_string()));
7340
7341 let result =
7342 futures_executor::block_on(Option::<Query<Params>>::from_request(&ctx, &mut req));
7343 let Some(Query(params)) = result.unwrap() else {
7344 panic!("Expected Some");
7345 };
7346 assert_eq!(params.page, 5);
7347 }
7348
7349 #[test]
7350 fn optional_query_missing_returns_none() {
7351 use serde::Deserialize;
7352
7353 #[derive(Deserialize)]
7354 #[allow(dead_code)]
7355 struct Params {
7356 required: String,
7357 }
7358
7359 let ctx = test_context();
7360 let mut req = Request::new(Method::Get, "/items");
7361 let result =
7364 futures_executor::block_on(Option::<Query<Params>>::from_request(&ctx, &mut req));
7365 assert!(result.unwrap().is_none());
7366 }
7367
7368 #[test]
7369 fn optional_query_invalid_type_returns_none() {
7370 use serde::Deserialize;
7371
7372 #[derive(Deserialize)]
7373 #[allow(dead_code)]
7374 struct Params {
7375 page: i32,
7376 }
7377
7378 let ctx = test_context();
7379 let mut req = Request::new(Method::Get, "/items");
7380 req.set_query(Some("page=abc".to_string()));
7381
7382 let result =
7383 futures_executor::block_on(Option::<Query<Params>>::from_request(&ctx, &mut req));
7384 assert!(result.unwrap().is_none());
7385 }
7386
7387 #[test]
7390 fn optional_state_present() {
7391 let ctx = test_context();
7392 let mut req = Request::new(Method::Get, "/");
7393 let app_state = AppState::new().with(42i32);
7394 req.insert_extension(app_state);
7395
7396 let result = futures_executor::block_on(Option::<State<i32>>::from_request(&ctx, &mut req));
7397 let Some(State(val)) = result.unwrap() else {
7398 panic!("Expected Some");
7399 };
7400 assert_eq!(val, 42);
7401 }
7402
7403 #[test]
7404 fn optional_state_missing_returns_none() {
7405 let ctx = test_context();
7406 let mut req = Request::new(Method::Get, "/");
7407 let result = futures_executor::block_on(Option::<State<i32>>::from_request(&ctx, &mut req));
7410 assert!(result.unwrap().is_none());
7411 }
7412
7413 #[test]
7414 fn optional_state_wrong_type_returns_none() {
7415 let ctx = test_context();
7416 let mut req = Request::new(Method::Get, "/");
7417 let app_state = AppState::new().with("string".to_string()); req.insert_extension(app_state);
7419
7420 let result = futures_executor::block_on(Option::<State<i32>>::from_request(&ctx, &mut req));
7421 assert!(result.unwrap().is_none());
7422 }
7423}
7424
7425#[cfg(test)]
7430mod combination_tests {
7431 use super::*;
7432 use crate::request::Method;
7433
7434 fn test_context() -> RequestContext {
7435 let cx = asupersync::Cx::for_testing();
7436 RequestContext::new(cx, 88888)
7437 }
7438
7439 #[test]
7440 fn path_and_query_together() {
7441 use serde::Deserialize;
7442
7443 #[derive(Deserialize, PartialEq, Debug)]
7444 struct QueryParams {
7445 limit: i32,
7446 }
7447
7448 let ctx = test_context();
7449 let mut req = Request::new(Method::Get, "/users/42");
7450 req.insert_extension(PathParams::from_pairs(vec![(
7451 "id".to_string(),
7452 "42".to_string(),
7453 )]));
7454 req.set_query(Some("limit=10".to_string()));
7455
7456 let path_result = futures_executor::block_on(Path::<i64>::from_request(&ctx, &mut req));
7458 let Path(user_id) = path_result.unwrap();
7459 assert_eq!(user_id, 42);
7460
7461 let query_result =
7463 futures_executor::block_on(Query::<QueryParams>::from_request(&ctx, &mut req));
7464 let Query(params) = query_result.unwrap();
7465 assert_eq!(params.limit, 10);
7466 }
7467
7468 #[test]
7469 fn json_body_and_path() {
7470 use serde::Deserialize;
7471
7472 #[derive(Deserialize, PartialEq, Debug)]
7473 struct CreateItem {
7474 name: String,
7475 }
7476
7477 let ctx = test_context();
7478 let mut req = Request::new(Method::Post, "/categories/5/items");
7479 req.headers_mut()
7480 .insert("content-type", b"application/json".to_vec());
7481 req.set_body(Body::Bytes(b"{\"name\": \"Widget\"}".to_vec()));
7482 req.insert_extension(PathParams::from_pairs(vec![(
7483 "cat_id".to_string(),
7484 "5".to_string(),
7485 )]));
7486
7487 let path_result = futures_executor::block_on(Path::<i64>::from_request(&ctx, &mut req));
7489 let Path(cat_id) = path_result.unwrap();
7490 assert_eq!(cat_id, 5);
7491
7492 let json_result =
7494 futures_executor::block_on(Json::<CreateItem>::from_request(&ctx, &mut req));
7495 let Json(item) = json_result.unwrap();
7496 assert_eq!(item.name, "Widget");
7497 }
7498
7499 #[test]
7500 fn state_and_query() {
7501 use serde::Deserialize;
7502
7503 #[derive(Deserialize, PartialEq, Debug)]
7504 struct SearchParams {
7505 q: String,
7506 }
7507
7508 #[derive(Clone, PartialEq, Debug)]
7509 struct Config {
7510 max_results: i32,
7511 }
7512
7513 let ctx = test_context();
7514 let mut req = Request::new(Method::Get, "/search");
7515 req.set_query(Some("q=hello".to_string()));
7516 let app_state = AppState::new().with(Config { max_results: 100 });
7517 req.insert_extension(app_state);
7518
7519 let state_result =
7521 futures_executor::block_on(State::<Config>::from_request(&ctx, &mut req));
7522 let State(config) = state_result.unwrap();
7523 assert_eq!(config.max_results, 100);
7524
7525 let query_result =
7527 futures_executor::block_on(Query::<SearchParams>::from_request(&ctx, &mut req));
7528 let Query(params) = query_result.unwrap();
7529 assert_eq!(params.q, "hello");
7530 }
7531
7532 #[test]
7533 fn multiple_path_params_with_struct() {
7534 use serde::Deserialize;
7535
7536 #[derive(Deserialize, PartialEq, Debug)]
7537 struct CommentPath {
7538 post_id: i64,
7539 comment_id: i64,
7540 }
7541
7542 let ctx = test_context();
7543 let mut req = Request::new(Method::Get, "/posts/123/comments/456");
7544 req.insert_extension(PathParams::from_pairs(vec![
7545 ("post_id".to_string(), "123".to_string()),
7546 ("comment_id".to_string(), "456".to_string()),
7547 ]));
7548
7549 let result = futures_executor::block_on(Path::<CommentPath>::from_request(&ctx, &mut req));
7550 let Path(path) = result.unwrap();
7551 assert_eq!(path.post_id, 123);
7552 assert_eq!(path.comment_id, 456);
7553 }
7554
7555 #[test]
7556 fn optional_mixed_with_required() {
7557 use serde::Deserialize;
7558
7559 #[derive(Deserialize, PartialEq, Debug)]
7560 struct OptionalParams {
7561 page: Option<i32>,
7562 }
7563
7564 let ctx = test_context();
7565 let mut req = Request::new(Method::Get, "/users/42");
7566 req.insert_extension(PathParams::from_pairs(vec![(
7567 "id".to_string(),
7568 "42".to_string(),
7569 )]));
7570
7571 let path_result = futures_executor::block_on(Path::<i64>::from_request(&ctx, &mut req));
7573 let Path(id) = path_result.unwrap();
7574 assert_eq!(id, 42);
7575
7576 let query_result =
7578 futures_executor::block_on(Query::<OptionalParams>::from_request(&ctx, &mut req));
7579 let Query(params) = query_result.unwrap();
7580 assert_eq!(params.page, None);
7581 }
7582
7583 #[test]
7584 fn request_context_extraction() {
7585 let ctx = test_context();
7586 let mut req = Request::new(Method::Get, "/");
7587
7588 let result = futures_executor::block_on(RequestContext::from_request(&ctx, &mut req));
7589 let extracted_ctx = result.unwrap();
7590 assert_eq!(extracted_ctx.request_id(), ctx.request_id());
7591 }
7592
7593 #[test]
7594 fn triple_extraction_path_query_state() {
7595 use serde::Deserialize;
7596
7597 #[derive(Deserialize, PartialEq, Debug)]
7598 struct QueryFilter {
7599 status: String,
7600 }
7601
7602 #[derive(Clone)]
7603 struct DbPool {
7604 connection_count: i32,
7605 }
7606
7607 let ctx = test_context();
7608 let mut req = Request::new(Method::Get, "/projects/99/tasks");
7609 req.insert_extension(PathParams::from_pairs(vec![(
7610 "project_id".to_string(),
7611 "99".to_string(),
7612 )]));
7613 req.set_query(Some("status=active".to_string()));
7614 let app_state = AppState::new().with(DbPool {
7615 connection_count: 10,
7616 });
7617 req.insert_extension(app_state);
7618
7619 let Path(project_id): Path<i32> =
7621 futures_executor::block_on(Path::<i32>::from_request(&ctx, &mut req)).unwrap();
7622 assert_eq!(project_id, 99);
7623
7624 let Query(filter): Query<QueryFilter> =
7626 futures_executor::block_on(Query::<QueryFilter>::from_request(&ctx, &mut req)).unwrap();
7627 assert_eq!(filter.status, "active");
7628
7629 let State(pool): State<DbPool> =
7631 futures_executor::block_on(State::<DbPool>::from_request(&ctx, &mut req)).unwrap();
7632 assert_eq!(pool.connection_count, 10);
7633 }
7634}
7635
7636#[cfg(test)]
7641mod edge_case_tests {
7642 use super::*;
7643 use crate::request::Method;
7644
7645 fn test_context() -> RequestContext {
7646 let cx = asupersync::Cx::for_testing();
7647 RequestContext::new(cx, 77777)
7648 }
7649
7650 #[test]
7653 fn json_with_unicode() {
7654 use serde::Deserialize;
7655
7656 #[derive(Deserialize, PartialEq, Debug)]
7657 struct Data {
7658 name: String,
7659 emoji: String,
7660 }
7661
7662 let ctx = test_context();
7663 let mut req = Request::new(Method::Post, "/test");
7664 req.headers_mut()
7665 .insert("content-type", b"application/json".to_vec());
7666 req.set_body(Body::Bytes(
7667 r#"{"name": "日本語", "emoji": "🎉🚀"}"#.as_bytes().to_vec(),
7668 ));
7669
7670 let result = futures_executor::block_on(Json::<Data>::from_request(&ctx, &mut req));
7671 let Json(data) = result.unwrap();
7672 assert_eq!(data.name, "日本語");
7673 assert_eq!(data.emoji, "🎉🚀");
7674 }
7675
7676 #[test]
7677 fn query_with_unicode_percent_encoded() {
7678 use serde::Deserialize;
7679
7680 #[derive(Deserialize, PartialEq, Debug)]
7681 struct Search {
7682 q: String,
7683 }
7684
7685 let ctx = test_context();
7686 let mut req = Request::new(Method::Get, "/search");
7687 req.set_query(Some(
7689 "q=%E3%81%93%E3%82%93%E3%81%AB%E3%81%A1%E3%81%AF".to_string(),
7690 ));
7691
7692 let result = futures_executor::block_on(Query::<Search>::from_request(&ctx, &mut req));
7693 let Query(search) = result.unwrap();
7694 assert_eq!(search.q, "こんにちは");
7695 }
7696
7697 #[test]
7698 fn path_with_unicode() {
7699 let ctx = test_context();
7700 let mut req = Request::new(Method::Get, "/users/用户123");
7701 req.insert_extension(PathParams::from_pairs(vec![(
7702 "name".to_string(),
7703 "用户123".to_string(),
7704 )]));
7705
7706 let result = futures_executor::block_on(Path::<String>::from_request(&ctx, &mut req));
7707 let Path(name) = result.unwrap();
7708 assert_eq!(name, "用户123");
7709 }
7710
7711 #[test]
7714 fn path_max_i64() {
7715 let ctx = test_context();
7716 let mut req = Request::new(Method::Get, "/items/9223372036854775807");
7717 req.insert_extension(PathParams::from_pairs(vec![(
7718 "id".to_string(),
7719 "9223372036854775807".to_string(),
7720 )]));
7721
7722 let result = futures_executor::block_on(Path::<i64>::from_request(&ctx, &mut req));
7723 let Path(id) = result.unwrap();
7724 assert_eq!(id, i64::MAX);
7725 }
7726
7727 #[test]
7728 fn path_min_i64() {
7729 let ctx = test_context();
7730 let mut req = Request::new(Method::Get, "/items/-9223372036854775808");
7731 req.insert_extension(PathParams::from_pairs(vec![(
7732 "id".to_string(),
7733 "-9223372036854775808".to_string(),
7734 )]));
7735
7736 let result = futures_executor::block_on(Path::<i64>::from_request(&ctx, &mut req));
7737 let Path(id) = result.unwrap();
7738 assert_eq!(id, i64::MIN);
7739 }
7740
7741 #[test]
7742 fn path_overflow_i64_fails() {
7743 let ctx = test_context();
7744 let mut req = Request::new(Method::Get, "/items/9223372036854775808");
7745 req.insert_extension(PathParams::from_pairs(vec![(
7746 "id".to_string(),
7747 "9223372036854775808".to_string(), )]));
7749
7750 let result = futures_executor::block_on(Path::<i64>::from_request(&ctx, &mut req));
7751 assert!(result.is_err());
7752 }
7753
7754 #[test]
7755 fn query_with_empty_value() {
7756 use serde::Deserialize;
7757
7758 #[derive(Deserialize, PartialEq, Debug)]
7759 struct Params {
7760 key: String,
7761 }
7762
7763 let ctx = test_context();
7764 let mut req = Request::new(Method::Get, "/test");
7765 req.set_query(Some("key=".to_string()));
7766
7767 let result = futures_executor::block_on(Query::<Params>::from_request(&ctx, &mut req));
7768 let Query(params) = result.unwrap();
7769 assert_eq!(params.key, "");
7770 }
7771
7772 #[test]
7773 fn query_with_only_key_no_equals() {
7774 use serde::Deserialize;
7775
7776 #[derive(Deserialize, PartialEq, Debug)]
7777 struct Params {
7778 flag: Option<String>,
7779 }
7780
7781 let ctx = test_context();
7782 let mut req = Request::new(Method::Get, "/test");
7783 req.set_query(Some("flag".to_string()));
7784
7785 let result = futures_executor::block_on(Query::<Params>::from_request(&ctx, &mut req));
7786 let Query(params) = result.unwrap();
7787 assert_eq!(params.flag, Some(String::new()));
7789 }
7790
7791 #[test]
7792 fn json_empty_object() {
7793 use serde::Deserialize;
7794
7795 #[derive(Deserialize, PartialEq, Debug)]
7796 struct Empty {}
7797
7798 let ctx = test_context();
7799 let mut req = Request::new(Method::Post, "/test");
7800 req.headers_mut()
7801 .insert("content-type", b"application/json".to_vec());
7802 req.set_body(Body::Bytes(b"{}".to_vec()));
7803
7804 let result = futures_executor::block_on(Json::<Empty>::from_request(&ctx, &mut req));
7805 assert!(result.is_ok());
7806 }
7807
7808 #[test]
7809 fn json_with_null_field() {
7810 use serde::Deserialize;
7811
7812 #[derive(Deserialize, PartialEq, Debug)]
7813 struct Data {
7814 value: Option<i32>,
7815 }
7816
7817 let ctx = test_context();
7818 let mut req = Request::new(Method::Post, "/test");
7819 req.headers_mut()
7820 .insert("content-type", b"application/json".to_vec());
7821 req.set_body(Body::Bytes(b"{\"value\": null}".to_vec()));
7822
7823 let result = futures_executor::block_on(Json::<Data>::from_request(&ctx, &mut req));
7824 let Json(data) = result.unwrap();
7825 assert_eq!(data.value, None);
7826 }
7827
7828 #[test]
7829 fn json_with_nested_objects() {
7830 use serde::Deserialize;
7831
7832 #[derive(Deserialize, PartialEq, Debug)]
7833 struct Address {
7834 city: String,
7835 zip: String,
7836 }
7837
7838 #[derive(Deserialize, PartialEq, Debug)]
7839 struct User {
7840 name: String,
7841 address: Address,
7842 }
7843
7844 let ctx = test_context();
7845 let mut req = Request::new(Method::Post, "/test");
7846 req.headers_mut()
7847 .insert("content-type", b"application/json".to_vec());
7848 req.set_body(Body::Bytes(
7849 b"{\"name\": \"Alice\", \"address\": {\"city\": \"NYC\", \"zip\": \"10001\"}}".to_vec(),
7850 ));
7851
7852 let result = futures_executor::block_on(Json::<User>::from_request(&ctx, &mut req));
7853 let Json(user) = result.unwrap();
7854 assert_eq!(user.name, "Alice");
7855 assert_eq!(user.address.city, "NYC");
7856 assert_eq!(user.address.zip, "10001");
7857 }
7858
7859 #[test]
7860 fn json_with_array() {
7861 use serde::Deserialize;
7862
7863 #[derive(Deserialize, PartialEq, Debug)]
7864 struct Data {
7865 items: Vec<i32>,
7866 }
7867
7868 let ctx = test_context();
7869 let mut req = Request::new(Method::Post, "/test");
7870 req.headers_mut()
7871 .insert("content-type", b"application/json".to_vec());
7872 req.set_body(Body::Bytes(b"{\"items\": [1, 2, 3, 4, 5]}".to_vec()));
7873
7874 let result = futures_executor::block_on(Json::<Data>::from_request(&ctx, &mut req));
7875 let Json(data) = result.unwrap();
7876 assert_eq!(data.items, vec![1, 2, 3, 4, 5]);
7877 }
7878
7879 #[test]
7880 fn path_with_special_chars() {
7881 let ctx = test_context();
7882 let mut req = Request::new(Method::Get, "/files/my-file_v2.txt");
7883 req.insert_extension(PathParams::from_pairs(vec![(
7884 "filename".to_string(),
7885 "my-file_v2.txt".to_string(),
7886 )]));
7887
7888 let result = futures_executor::block_on(Path::<String>::from_request(&ctx, &mut req));
7889 let Path(filename) = result.unwrap();
7890 assert_eq!(filename, "my-file_v2.txt");
7891 }
7892
7893 #[test]
7894 fn query_with_special_chars_encoded() {
7895 use serde::Deserialize;
7896
7897 #[derive(Deserialize, PartialEq, Debug)]
7898 struct Params {
7899 value: String,
7900 }
7901
7902 let ctx = test_context();
7903 let mut req = Request::new(Method::Get, "/test");
7904 req.set_query(Some("value=hello%20world%20%26%20more".to_string()));
7906
7907 let result = futures_executor::block_on(Query::<Params>::from_request(&ctx, &mut req));
7908 let Query(params) = result.unwrap();
7909 assert_eq!(params.value, "hello world & more");
7910 }
7911
7912 #[test]
7913 fn query_multiple_values_same_key() {
7914 use serde::Deserialize;
7915
7916 #[derive(Deserialize, PartialEq, Debug)]
7917 struct Params {
7918 tags: Vec<String>,
7919 }
7920
7921 let ctx = test_context();
7922 let mut req = Request::new(Method::Get, "/test");
7923 req.set_query(Some("tags=a&tags=b&tags=c".to_string()));
7924
7925 let result = futures_executor::block_on(Query::<Params>::from_request(&ctx, &mut req));
7926 let Query(params) = result.unwrap();
7927 assert_eq!(params.tags, vec!["a", "b", "c"]);
7928 }
7929
7930 #[test]
7931 fn path_empty_string() {
7932 let ctx = test_context();
7933 let mut req = Request::new(Method::Get, "/items//details");
7934 req.insert_extension(PathParams::from_pairs(vec![(
7935 "id".to_string(),
7936 String::new(),
7937 )]));
7938
7939 let result = futures_executor::block_on(Path::<String>::from_request(&ctx, &mut req));
7940 let Path(id) = result.unwrap();
7941 assert_eq!(id, "");
7942 }
7943
7944 #[test]
7945 fn json_with_escaped_quotes() {
7946 use serde::Deserialize;
7947
7948 #[derive(Deserialize, PartialEq, Debug)]
7949 struct Data {
7950 message: String,
7951 }
7952
7953 let ctx = test_context();
7954 let mut req = Request::new(Method::Post, "/test");
7955 req.headers_mut()
7956 .insert("content-type", b"application/json".to_vec());
7957 req.set_body(Body::Bytes(
7958 b"{\"message\": \"He said \\\"hello\\\"\"}".to_vec(),
7959 ));
7960
7961 let result = futures_executor::block_on(Json::<Data>::from_request(&ctx, &mut req));
7962 let Json(data) = result.unwrap();
7963 assert_eq!(data.message, "He said \"hello\"");
7964 }
7965
7966 #[test]
7967 fn query_with_plus_as_space() {
7968 use serde::Deserialize;
7969
7970 #[derive(Deserialize, PartialEq, Debug)]
7971 struct Params {
7972 q: String,
7973 }
7974
7975 let ctx = test_context();
7976 let mut req = Request::new(Method::Get, "/search");
7977 req.set_query(Some("q=hello+world".to_string()));
7978
7979 let result = futures_executor::block_on(Query::<Params>::from_request(&ctx, &mut req));
7980 let Query(params) = result.unwrap();
7981 assert_eq!(params.q, "hello world");
7982 }
7983}
7984
7985#[cfg(test)]
7990mod security_tests {
7991 use super::*;
7992 use crate::request::Method;
7993
7994 fn test_context() -> RequestContext {
7995 let cx = asupersync::Cx::for_testing();
7996 RequestContext::new(cx, 66666)
7997 }
7998
7999 #[test]
8000 fn json_payload_size_limit() {
8001 use serde::Deserialize;
8002
8003 #[derive(Deserialize)]
8004 #[allow(dead_code)]
8005 struct Data {
8006 content: String,
8007 }
8008
8009 let ctx = test_context();
8010 let mut req = Request::new(Method::Post, "/test");
8011 req.headers_mut()
8012 .insert("content-type", b"application/json".to_vec());
8013
8014 let large_content = "x".repeat(DEFAULT_JSON_LIMIT + 100);
8016 let body = format!("{{\"content\": \"{large_content}\"}}");
8017 req.set_body(Body::Bytes(body.into_bytes()));
8018
8019 let result = futures_executor::block_on(Json::<Data>::from_request(&ctx, &mut req));
8020 assert!(matches!(
8021 result,
8022 Err(JsonExtractError::PayloadTooLarge { .. })
8023 ));
8024 }
8025
8026 #[test]
8027 fn json_deeply_nested_object() {
8028 use serde::Deserialize;
8029
8030 #[derive(Deserialize)]
8032 struct Level1 {
8033 #[allow(dead_code)]
8034 l2: Level2,
8035 }
8036 #[derive(Deserialize)]
8037 struct Level2 {
8038 #[allow(dead_code)]
8039 l3: Level3,
8040 }
8041 #[derive(Deserialize)]
8042 struct Level3 {
8043 #[allow(dead_code)]
8044 l4: Level4,
8045 }
8046 #[derive(Deserialize)]
8047 struct Level4 {
8048 #[allow(dead_code)]
8049 value: i32,
8050 }
8051
8052 let ctx = test_context();
8053 let mut req = Request::new(Method::Post, "/test");
8054 req.headers_mut()
8055 .insert("content-type", b"application/json".to_vec());
8056 req.set_body(Body::Bytes(
8057 b"{\"l2\":{\"l3\":{\"l4\":{\"value\":42}}}}".to_vec(),
8058 ));
8059
8060 let result = futures_executor::block_on(Json::<Level1>::from_request(&ctx, &mut req));
8061 assert!(result.is_ok());
8062 }
8063
8064 #[test]
8065 fn query_injection_attempt_escaped() {
8066 use serde::Deserialize;
8067
8068 #[derive(Deserialize, PartialEq, Debug)]
8069 struct Params {
8070 name: String,
8071 }
8072
8073 let ctx = test_context();
8074 let mut req = Request::new(Method::Get, "/test");
8075 req.set_query(Some(
8077 "name=Robert%27%3B%20DROP%20TABLE%20users%3B--".to_string(),
8078 ));
8079
8080 let result = futures_executor::block_on(Query::<Params>::from_request(&ctx, &mut req));
8081 let Query(params) = result.unwrap();
8082 assert_eq!(params.name, "Robert'; DROP TABLE users;--");
8084 }
8085
8086 #[test]
8087 fn path_traversal_attempt() {
8088 let ctx = test_context();
8089 let mut req = Request::new(Method::Get, "/files/../../../etc/passwd");
8090 req.insert_extension(PathParams::from_pairs(vec![(
8091 "path".to_string(),
8092 "../../../etc/passwd".to_string(),
8093 )]));
8094
8095 let result = futures_executor::block_on(Path::<String>::from_request(&ctx, &mut req));
8096 let Path(path) = result.unwrap();
8097 assert_eq!(path, "../../../etc/passwd");
8099 }
8100
8101 #[test]
8102 fn json_with_script_tag_xss() {
8103 use serde::Deserialize;
8104
8105 #[derive(Deserialize, PartialEq, Debug)]
8106 struct Data {
8107 comment: String,
8108 }
8109
8110 let ctx = test_context();
8111 let mut req = Request::new(Method::Post, "/test");
8112 req.headers_mut()
8113 .insert("content-type", b"application/json".to_vec());
8114 req.set_body(Body::Bytes(
8115 b"{\"comment\": \"<script>alert('xss')</script>\"}".to_vec(),
8116 ));
8117
8118 let result = futures_executor::block_on(Json::<Data>::from_request(&ctx, &mut req));
8119 let Json(data) = result.unwrap();
8120 assert_eq!(data.comment, "<script>alert('xss')</script>");
8122 }
8123
8124 #[test]
8125 fn json_content_type_case_insensitive() {
8126 use serde::Deserialize;
8127
8128 #[derive(Deserialize, PartialEq, Debug)]
8129 struct Data {
8130 value: i32,
8131 }
8132
8133 for content_type in &[
8135 "APPLICATION/JSON",
8136 "Application/Json",
8137 "application/JSON",
8138 "APPLICATION/json",
8139 ] {
8140 let ctx = test_context();
8141 let mut req = Request::new(Method::Post, "/test");
8142 req.headers_mut()
8143 .insert("content-type", content_type.as_bytes().to_vec());
8144 req.set_body(Body::Bytes(b"{\"value\": 42}".to_vec()));
8145
8146 let result = futures_executor::block_on(Json::<Data>::from_request(&ctx, &mut req));
8147 assert!(result.is_ok(), "Failed for content-type: {}", content_type);
8148 }
8149 }
8150
8151 #[test]
8152 fn json_wrong_content_type_variants() {
8153 use serde::Deserialize;
8154
8155 #[derive(Deserialize)]
8156 #[allow(dead_code)]
8157 struct Data {
8158 value: i32,
8159 }
8160
8161 for content_type in &[
8163 "text/json",
8164 "text/plain",
8165 "application/xml",
8166 "application/x-json",
8167 ] {
8168 let ctx = test_context();
8169 let mut req = Request::new(Method::Post, "/test");
8170 req.headers_mut()
8171 .insert("content-type", content_type.as_bytes().to_vec());
8172 req.set_body(Body::Bytes(b"{\"value\": 42}".to_vec()));
8173
8174 let result = futures_executor::block_on(Json::<Data>::from_request(&ctx, &mut req));
8175 assert!(
8176 matches!(result, Err(JsonExtractError::UnsupportedMediaType { .. })),
8177 "Should reject content-type: {}",
8178 content_type
8179 );
8180 }
8181 }
8182
8183 #[test]
8184 fn json_content_type_rejects_near_miss_types() {
8185 use serde::Deserialize;
8186
8187 #[derive(Deserialize)]
8188 #[allow(dead_code)]
8189 struct Data {
8190 value: i32,
8191 }
8192
8193 for content_type in &[
8195 "application/jsonl",
8196 "application/json-seq",
8197 "application/json-patch",
8198 "application/jsonlines",
8199 ] {
8200 let ctx = test_context();
8201 let mut req = Request::new(Method::Post, "/test");
8202 req.headers_mut()
8203 .insert("content-type", content_type.as_bytes().to_vec());
8204 req.set_body(Body::Bytes(b"{\"value\": 42}".to_vec()));
8205
8206 let result = futures_executor::block_on(Json::<Data>::from_request(&ctx, &mut req));
8207 assert!(
8208 matches!(result, Err(JsonExtractError::UnsupportedMediaType { .. })),
8209 "Should reject content-type: {}",
8210 content_type
8211 );
8212 }
8213
8214 let ctx = test_context();
8216 let mut req = Request::new(Method::Post, "/test");
8217 req.headers_mut().insert(
8218 "content-type",
8219 b"application/json; charset=utf-8".to_vec(),
8220 );
8221 req.set_body(Body::Bytes(b"{\"value\": 42}".to_vec()));
8222
8223 let result = futures_executor::block_on(Json::<Data>::from_request(&ctx, &mut req));
8224 assert!(
8225 result.is_ok(),
8226 "application/json with charset parameter should be accepted"
8227 );
8228
8229 let ctx = test_context();
8231 let mut req = Request::new(Method::Post, "/test");
8232 req.headers_mut().insert(
8233 "content-type",
8234 b"application/vnd.api+json".to_vec(),
8235 );
8236 req.set_body(Body::Bytes(b"{\"value\": 42}".to_vec()));
8237
8238 let result = futures_executor::block_on(Json::<Data>::from_request(&ctx, &mut req));
8239 assert!(
8240 result.is_ok(),
8241 "application/vnd.api+json should be accepted"
8242 );
8243 }
8244
8245 #[test]
8246 fn form_url_decode_handles_multibyte_utf8() {
8247 let decoded = url_decode("caf%C3%A9");
8249 assert_eq!(decoded, "café");
8250
8251 let decoded = url_decode("%E6%97%A5%E6%9C%AC");
8253 assert_eq!(decoded, "日本");
8254
8255 let decoded = url_decode("hello+w%C3%B6rld");
8257 assert_eq!(decoded, "hello wörld");
8258
8259 let decoded = url_decode("hello+world");
8261 assert_eq!(decoded, "hello world");
8262 }
8263
8264 #[test]
8265 fn query_null_byte_handling() {
8266 use serde::Deserialize;
8267
8268 #[derive(Deserialize, PartialEq, Debug)]
8269 struct Params {
8270 name: String,
8271 }
8272
8273 let ctx = test_context();
8274 let mut req = Request::new(Method::Get, "/test");
8275 req.set_query(Some("name=test%00value".to_string()));
8277
8278 let result = futures_executor::block_on(Query::<Params>::from_request(&ctx, &mut req));
8279 let Query(params) = result.unwrap();
8280 assert_eq!(params.name, "test\0value");
8282 }
8283
8284 #[test]
8285 fn path_with_null_byte() {
8286 let ctx = test_context();
8287 let mut req = Request::new(Method::Get, "/files/test");
8288 req.insert_extension(PathParams::from_pairs(vec![(
8289 "filename".to_string(),
8290 "test\0.txt".to_string(),
8291 )]));
8292
8293 let result = futures_executor::block_on(Path::<String>::from_request(&ctx, &mut req));
8294 let Path(filename) = result.unwrap();
8295 assert_eq!(filename, "test\0.txt");
8296 }
8297
8298 #[test]
8299 fn json_number_precision() {
8300 use serde::Deserialize;
8301
8302 #[derive(Deserialize, PartialEq, Debug)]
8303 struct Data {
8304 big_int: i64,
8305 float_val: f64,
8306 }
8307
8308 let ctx = test_context();
8309 let mut req = Request::new(Method::Post, "/test");
8310 req.headers_mut()
8311 .insert("content-type", b"application/json".to_vec());
8312 req.set_body(Body::Bytes(
8314 b"{\"big_int\": 9007199254740993, \"float_val\": 3.141592653589793}".to_vec(),
8315 ));
8316
8317 let result = futures_executor::block_on(Json::<Data>::from_request(&ctx, &mut req));
8318 let Json(data) = result.unwrap();
8319 assert_eq!(data.big_int, 9007199254740993_i64);
8320 assert!((data.float_val - std::f64::consts::PI).abs() < 0.0000001);
8321 }
8322}
8323
#[cfg(test)]
mod valid_tests {
    use super::*;
    use crate::error::ValidationErrors;
    use crate::request::Method;
    use crate::validation::Validate;

    /// Builds a request context backed by a testing `Cx`.
    fn test_context() -> RequestContext {
        RequestContext::new(asupersync::Cx::for_testing(), 12345)
    }

    /// Test-only rule for `String`: reject empty strings and strings longer
    /// than 100 bytes, each with a single body-located validation error.
    ///
    /// NOTE(review): presumably needed so `Valid<MockExtractor>` (which derefs
    /// to `String`) has something to validate. Confirm against `Valid`'s bounds.
    impl Validate for String {
        fn validate(&self) -> Result<(), Box<ValidationErrors>> {
            let error_type = match self.len() {
                0 => crate::error::error_types::STRING_TOO_SHORT,
                len if len > 100 => crate::error::error_types::STRING_TOO_LONG,
                _ => return Ok(()),
            };
            let mut errors = ValidationErrors::new();
            errors.push(crate::error::ValidationError::new(
                error_type,
                crate::error::loc::body(),
            ));
            Err(Box::new(errors))
        }
    }

    /// Minimal extractor that takes the whole request body as UTF-8 text.
    struct MockExtractor(String);

    impl Deref for MockExtractor {
        type Target = String;

        fn deref(&self) -> &String {
            &self.0
        }
    }

    impl FromRequest for MockExtractor {
        type Error = HttpError;

        async fn from_request(
            _ctx: &RequestContext,
            req: &mut Request,
        ) -> Result<Self, Self::Error> {
            // Non-UTF-8 bodies are reported as a bad request.
            String::from_utf8(req.take_body().into_bytes())
                .map(MockExtractor)
                .map_err(|_| HttpError::bad_request())
        }
    }

    /// Runs `Valid<MockExtractor>` against a fresh POST request carrying `body`.
    fn extract_valid(
        body: Vec<u8>,
    ) -> Result<Valid<MockExtractor>, <Valid<MockExtractor> as FromRequest>::Error> {
        let ctx = test_context();
        let mut req = Request::new(Method::Post, "/test");
        req.set_body(Body::Bytes(body));
        futures_executor::block_on(Valid::<MockExtractor>::from_request(&ctx, &mut req))
    }

    #[test]
    fn valid_deref() {
        let wrapped = Valid(42i32);
        assert_eq!(*wrapped, 42);
    }

    #[test]
    fn valid_into_inner() {
        let wrapped = Valid(String::from("hello"));
        assert_eq!(wrapped.into_inner(), "hello");
    }

    #[test]
    fn valid_extract_and_validate_success() {
        let outcome = extract_valid(b"valid string".to_vec());
        assert!(outcome.is_ok());
        let Valid(MockExtractor(text)) = outcome.unwrap();
        assert_eq!(text, "valid string");
    }

    #[test]
    fn valid_extract_validation_fails_empty() {
        let outcome = extract_valid(Vec::new());
        assert!(matches!(outcome, Err(ValidExtractError::Validation(_))));
    }

    #[test]
    fn valid_extract_validation_fails_too_long() {
        // One byte over the 100-byte limit enforced by the test `Validate` impl.
        let outcome = extract_valid("a".repeat(101).into_bytes());
        assert!(matches!(outcome, Err(ValidExtractError::Validation(_))));
    }

    #[test]
    fn valid_extract_error_display() {
        let extract_err: ValidExtractError<HttpError> =
            ValidExtractError::Extract(HttpError::bad_request());
        assert!(extract_err.to_string().contains("Extraction failed"));

        let validation_err: ValidExtractError<HttpError> =
            ValidExtractError::Validation(Box::new(ValidationErrors::new()));
        assert!(validation_err.to_string().contains("validation error"));
    }
}