1use crate::context::RequestContext;
7use crate::error::{HttpError, ValidationError, ValidationErrors};
8use crate::request::{Body, Request};
9use crate::response::{IntoResponse, Response, ResponseBody};
10use serde::de::{
11 self, DeserializeOwned, Deserializer, IntoDeserializer, MapAccess, SeqAccess, Visitor,
12};
13use std::fmt;
14use std::future::Future;
15use std::ops::{Deref, DerefMut};
16
17pub trait FromRequest: Sized {
47 type Error: IntoResponse;
49
50 fn from_request(
57 ctx: &RequestContext,
58 req: &mut Request,
59 ) -> impl Future<Output = Result<Self, Self::Error>> + Send;
60}
61
62impl<T: FromRequest> FromRequest for Option<T> {
64 type Error = std::convert::Infallible;
65
66 async fn from_request(ctx: &RequestContext, req: &mut Request) -> Result<Self, Self::Error> {
67 Ok(T::from_request(ctx, req).await.ok())
68 }
69}
70
71impl FromRequest for RequestContext {
73 type Error = std::convert::Infallible;
74
75 async fn from_request(ctx: &RequestContext, _req: &mut Request) -> Result<Self, Self::Error> {
76 Ok(ctx.clone())
77 }
78}
79
/// Default maximum JSON body size: 1 MiB (1,048,576 bytes).
pub const DEFAULT_JSON_LIMIT: usize = 1 << 20;

/// Configuration for JSON extraction.
///
/// Holds a body-size limit and an optional content type, both set through the
/// chained builder methods.
///
/// NOTE(review): the `Json` extractor in this module takes its limit from
/// `RequestContext::max_body_size()` and never reads this type. Confirm where
/// `JsonConfig` is consumed.
#[derive(Debug, Clone)]
pub struct JsonConfig {
    /// Maximum body size in bytes.
    limit: usize,
    /// Content type configured via [`JsonConfig::content_type`]; `None` by default.
    content_type: Option<String>,
}

impl Default for JsonConfig {
    /// Limit [`DEFAULT_JSON_LIMIT`], no content type.
    fn default() -> Self {
        JsonConfig {
            content_type: None,
            limit: DEFAULT_JSON_LIMIT,
        }
    }
}

impl JsonConfig {
    /// Creates a configuration with default values.
    #[must_use]
    pub fn new() -> Self {
        <Self as Default>::default()
    }

    /// Returns the configuration with its size limit replaced by `limit` bytes.
    #[must_use]
    pub fn limit(self, limit: usize) -> Self {
        Self { limit, ..self }
    }

    /// Returns the configuration with its content type set.
    #[must_use]
    pub fn content_type(self, content_type: impl Into<String>) -> Self {
        Self {
            content_type: Some(content_type.into()),
            ..self
        }
    }

    /// The configured size limit in bytes.
    #[must_use]
    pub fn get_limit(&self) -> usize {
        self.limit
    }
}
143
/// JSON wrapper around a value of type `T`.
///
/// Used as an extractor (see its `FromRequest` impl) and as a response (see
/// its `IntoResponse` impl). Dereferences to the wrapped value.
#[derive(Debug, Clone, Copy, Default)]
pub struct Json<T>(pub T);

impl<T> Json<T> {
    /// Consumes the wrapper and returns the inner value.
    pub fn into_inner(self) -> T {
        let Json(inner) = self;
        inner
    }
}

impl<T> Deref for Json<T> {
    type Target = T;

    fn deref(&self) -> &T {
        let Json(inner) = self;
        inner
    }
}

impl<T> DerefMut for Json<T> {
    fn deref_mut(&mut self) -> &mut T {
        let Json(inner) = self;
        inner
    }
}
193
194impl<T: serde::Serialize> IntoResponse for Json<T> {
195 fn into_response(self) -> Response {
196 match serde_json::to_vec(&self.0) {
197 Ok(bytes) => Response::ok()
198 .header("content-type", b"application/json".to_vec())
199 .body(ResponseBody::Bytes(bytes)),
200 Err(e) => {
201 crate::error::ResponseValidationError::serialization_failed(e.to_string())
204 .into_response()
205 }
206 }
207 }
208}
209
/// Reasons the JSON extractor rejects a request.
#[derive(Debug)]
pub enum JsonExtractError {
    /// `Content-Type` is missing or is not a JSON media type.
    UnsupportedMediaType {
        /// The received `Content-Type`, if present and valid UTF-8.
        actual: Option<String>,
    },
    /// The body is larger than the allowed limit.
    PayloadTooLarge {
        /// Body size in bytes.
        size: usize,
        /// Maximum allowed size in bytes.
        limit: usize,
    },
    /// The body could not be deserialized into the target type.
    DeserializeError {
        /// Parser error message.
        message: String,
        /// Line of the error, when known.
        line: Option<usize>,
        /// Column of the error, when known.
        column: Option<usize>,
    },
    /// The body is a stream; only buffered bodies are supported.
    StreamingNotSupported,
}

impl std::fmt::Display for JsonExtractError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::UnsupportedMediaType { actual: Some(ct) } => {
                write!(f, "Expected Content-Type: application/json, got: {ct}")
            }
            Self::UnsupportedMediaType { actual: None } => {
                f.write_str("Missing Content-Type header, expected application/json")
            }
            Self::PayloadTooLarge { size, limit } => write!(
                f,
                "Request body too large: {size} bytes exceeds {limit} byte limit"
            ),
            // Position is only printed when both line and column are known.
            Self::DeserializeError {
                message,
                line: Some(l),
                column: Some(c),
            } => write!(f, "JSON parse error at line {l}, column {c}: {message}"),
            Self::DeserializeError { message, .. } => write!(f, "JSON parse error: {message}"),
            Self::StreamingNotSupported => {
                f.write_str("Streaming request bodies are not supported for JSON extraction")
            }
        }
    }
}

impl std::error::Error for JsonExtractError {}
276
277impl IntoResponse for JsonExtractError {
278 fn into_response(self) -> crate::response::Response {
279 match self {
280 Self::UnsupportedMediaType { actual } => {
281 let detail = if let Some(ct) = actual {
282 format!("Expected Content-Type: application/json, got: {ct}")
283 } else {
284 "Missing Content-Type header, expected application/json".to_string()
285 };
286 HttpError::unsupported_media_type()
287 .with_detail(detail)
288 .into_response()
289 }
290 Self::PayloadTooLarge { size, limit } => HttpError::payload_too_large()
291 .with_detail(format!(
292 "Request body too large: {size} bytes exceeds {limit} byte limit"
293 ))
294 .into_response(),
295 Self::DeserializeError {
296 message,
297 line,
298 column,
299 } => {
300 let msg = if let (Some(l), Some(c)) = (line, column) {
302 format!("JSON parse error at line {l}, column {c}: {message}")
303 } else {
304 format!("JSON parse error: {message}")
305 };
306 ValidationErrors::single(ValidationError::json_invalid(
307 crate::error::loc::body(),
308 msg,
309 ))
310 .into_response()
311 }
312 Self::StreamingNotSupported => HttpError::bad_request()
313 .with_detail("Streaming request bodies are not supported for JSON extraction")
314 .into_response(),
315 }
316 }
317}
318
319impl<T: DeserializeOwned> FromRequest for Json<T> {
320 type Error = JsonExtractError;
321
322 async fn from_request(ctx: &RequestContext, req: &mut Request) -> Result<Self, Self::Error> {
323 let content_type = req
325 .headers()
326 .get("content-type")
327 .and_then(|v| std::str::from_utf8(v).ok());
328
329 let is_json = content_type.is_some_and(|ct| {
330 ct.get(..16)
332 .is_some_and(|prefix| prefix.eq_ignore_ascii_case("application/json"))
333 || (ct
334 .get(..12)
335 .is_some_and(|p| p.eq_ignore_ascii_case("application/"))
336 && ct
337 .as_bytes()
338 .windows(5)
339 .any(|w| w.eq_ignore_ascii_case(b"+json")))
340 });
341
342 if !is_json {
343 return Err(JsonExtractError::UnsupportedMediaType {
344 actual: content_type.map(String::from),
345 });
346 }
347
348 let _ = ctx.checkpoint();
350
351 let body = req.take_body();
353 let bytes = match body {
354 Body::Empty => Vec::new(),
355 Body::Bytes(b) => b,
356 Body::Stream(_) => {
357 return Err(JsonExtractError::StreamingNotSupported);
359 }
360 };
361
362 let limit = ctx.max_body_size();
366 if bytes.len() > limit {
367 return Err(JsonExtractError::PayloadTooLarge {
368 size: bytes.len(),
369 limit,
370 });
371 }
372
373 let _ = ctx.checkpoint();
375
376 let value =
380 serde_json::from_slice(&bytes).map_err(|e| JsonExtractError::DeserializeError {
381 message: e.to_string(),
382 line: Some(e.line()),
383 column: Some(e.column()),
384 })?;
385
386 let _ = ctx.checkpoint();
388
389 Ok(Json(value))
390 }
391}
392
#[cfg(test)]
mod tests {
    use super::*;
    use crate::request::Method;
    use serde::Deserialize;

    /// Payload with a string and an integer field.
    #[derive(Deserialize, Debug, PartialEq)]
    struct NamedValue {
        name: String,
        value: i32,
    }

    /// Payload whose contents are never inspected; only the extraction outcome is.
    #[derive(Deserialize)]
    struct NameOnly {
        #[allow(dead_code)] // field exists only to shape deserialization
        name: String,
    }

    /// Payload with a single integer field.
    #[derive(Deserialize, PartialEq, Debug)]
    struct ValueOnly {
        value: i32,
    }

    fn test_context() -> RequestContext {
        RequestContext::new(asupersync::Cx::for_testing(), 12345)
    }

    /// POST `/test` with an optional `content-type` header and the given body.
    fn post_with(content_type: Option<&[u8]>, body: &[u8]) -> Request {
        let mut req = Request::new(Method::Post, "/test");
        if let Some(ct) = content_type {
            req.headers_mut().insert("content-type", ct.to_vec());
        }
        req.set_body(Body::Bytes(body.to_vec()));
        req
    }

    /// POST `/test` with `content-type: application/json`.
    fn json_request(body: &str) -> Request {
        post_with(Some(b"application/json".as_slice()), body.as_bytes())
    }

    /// Runs the `Json<T>` extractor to completion against `req`.
    fn extract<T: DeserializeOwned>(req: &mut Request) -> Result<Json<T>, JsonExtractError> {
        let ctx = test_context();
        futures_executor::block_on(Json::<T>::from_request(&ctx, req))
    }

    #[test]
    fn json_config_defaults() {
        assert_eq!(JsonConfig::default().get_limit(), DEFAULT_JSON_LIMIT);
    }

    #[test]
    fn json_config_custom() {
        assert_eq!(JsonConfig::new().limit(1024).get_limit(), 1024);
    }

    #[test]
    fn json_deref() {
        assert_eq!(*Json(42i32), 42);
    }

    #[test]
    fn json_into_inner() {
        assert_eq!(Json("hello".to_string()).into_inner(), "hello");
    }

    #[test]
    fn json_extract_success() {
        let mut req = json_request(r#"{"name": "test", "value": 42}"#);
        let Json(payload) = extract::<NamedValue>(&mut req).unwrap();
        assert_eq!(
            payload,
            NamedValue {
                name: "test".to_string(),
                value: 42
            }
        );
    }

    #[test]
    fn json_extract_wrong_content_type() {
        let mut req = post_with(Some(b"text/plain".as_slice()), b"{}");
        let outcome = extract::<NameOnly>(&mut req);
        assert!(matches!(
            outcome,
            Err(JsonExtractError::UnsupportedMediaType { actual: Some(ct) })
                if ct == "text/plain"
        ));
    }

    #[test]
    fn json_extract_missing_content_type() {
        let mut req = post_with(None, b"{}");
        let outcome = extract::<NameOnly>(&mut req);
        assert!(matches!(
            outcome,
            Err(JsonExtractError::UnsupportedMediaType { actual: None })
        ));
    }

    #[test]
    fn json_extract_invalid_json() {
        let mut req = json_request(r#"{"name": invalid}"#);
        let outcome = extract::<NameOnly>(&mut req);
        assert!(matches!(
            outcome,
            Err(JsonExtractError::DeserializeError { .. })
        ));
    }

    #[test]
    fn json_extract_application_json_charset() {
        let mut req = post_with(
            Some(b"application/json; charset=utf-8".as_slice()),
            b"{\"value\": 123}",
        );
        let Json(payload) = extract::<ValueOnly>(&mut req).unwrap();
        assert_eq!(payload.value, 123);
    }

    #[test]
    fn json_extract_vendor_json() {
        // Structured-syntax suffix types (`+json`) must be accepted.
        let mut req = post_with(
            Some(b"application/vnd.api+json".as_slice()),
            b"{\"value\": 456}",
        );
        let Json(payload) = extract::<ValueOnly>(&mut req).unwrap();
        assert_eq!(payload.value, 456);
    }

    #[test]
    fn json_error_display() {
        let media = JsonExtractError::UnsupportedMediaType {
            actual: Some("text/html".to_string()),
        }
        .to_string();
        assert!(media.contains("text/html"));

        let too_large = JsonExtractError::PayloadTooLarge {
            size: 2000,
            limit: 1000,
        }
        .to_string();
        assert!(too_large.contains("2000"));
        assert!(too_large.contains("1000"));

        let parse = JsonExtractError::DeserializeError {
            message: "unexpected token".to_string(),
            line: Some(1),
            column: Some(10),
        }
        .to_string();
        assert!(parse.contains("line 1"));
        assert!(parse.contains("column 10"));
    }
}
585
/// Default maximum form body size: 1 MiB (1,048,576 bytes).
pub const DEFAULT_FORM_LIMIT: usize = 1 << 20;

/// Configuration for `application/x-www-form-urlencoded` extraction.
///
/// NOTE(review): the `Form` extractor in this module takes its limit from
/// `RequestContext::max_body_size()` and never reads this type. Confirm where
/// `FormConfig` is consumed.
#[derive(Debug, Clone)]
pub struct FormConfig {
    /// Maximum body size in bytes.
    limit: usize,
}

impl Default for FormConfig {
    /// Limit [`DEFAULT_FORM_LIMIT`].
    fn default() -> Self {
        FormConfig {
            limit: DEFAULT_FORM_LIMIT,
        }
    }
}

impl FormConfig {
    /// Creates a configuration with default values.
    #[must_use]
    pub fn new() -> Self {
        <Self as Default>::default()
    }

    /// Returns the configuration with its size limit replaced by `limit` bytes.
    #[must_use]
    pub fn limit(self, limit: usize) -> Self {
        let mut updated = self;
        updated.limit = limit;
        updated
    }

    /// The configured size limit in bytes.
    #[must_use]
    pub fn get_limit(&self) -> usize {
        self.limit
    }
}
624
/// Wrapper for a value deserialized from an
/// `application/x-www-form-urlencoded` body. Dereferences to the inner value.
#[derive(Debug, Clone, Copy, Default)]
pub struct Form<T>(pub T);

impl<T> Form<T> {
    /// Wraps `value`.
    pub fn new(value: T) -> Self {
        Form(value)
    }

    /// Consumes the wrapper and returns the inner value.
    pub fn into_inner(self) -> T {
        let Form(inner) = self;
        inner
    }
}

impl<T> Deref for Form<T> {
    type Target = T;

    fn deref(&self) -> &T {
        let Form(inner) = self;
        inner
    }
}

impl<T> DerefMut for Form<T> {
    fn deref_mut(&mut self) -> &mut T {
        let Form(inner) = self;
        inner
    }
}
651
/// Reasons the form extractor rejects a request.
#[derive(Debug)]
pub enum FormExtractError {
    /// `Content-Type` is missing or not `application/x-www-form-urlencoded`.
    UnsupportedMediaType { actual: Option<String> },
    /// The body is larger than the allowed limit (sizes in bytes).
    PayloadTooLarge { size: usize, limit: usize },
    /// The decoded fields could not be deserialized into the target type.
    DeserializeError { message: String },
    /// The body is a stream; only buffered bodies are supported.
    StreamingNotSupported,
    /// The body is not valid UTF-8.
    InvalidUtf8,
}

impl std::fmt::Display for FormExtractError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::UnsupportedMediaType { actual: Some(ct) } => {
                write!(f, "Expected application/x-www-form-urlencoded, got: {ct}")
            }
            Self::UnsupportedMediaType { actual: None } => {
                f.write_str("Missing Content-Type header")
            }
            Self::PayloadTooLarge { size, limit } => write!(f, "Body too large: {size} > {limit}"),
            Self::DeserializeError { message } => write!(f, "Form error: {message}"),
            Self::StreamingNotSupported => f.write_str("Streaming not supported"),
            Self::InvalidUtf8 => f.write_str("Invalid UTF-8"),
        }
    }
}

impl std::error::Error for FormExtractError {}
683
684impl IntoResponse for FormExtractError {
685 fn into_response(self) -> Response {
686 match &self {
687 FormExtractError::UnsupportedMediaType { .. } => {
688 HttpError::unsupported_media_type().into_response()
689 }
690 FormExtractError::PayloadTooLarge { size, limit } => HttpError::payload_too_large()
691 .with_detail(format!("Body {size} > {limit}"))
692 .into_response(),
693 FormExtractError::DeserializeError { message } => {
694 use crate::error::error_types;
695 ValidationErrors::single(
696 ValidationError::new(
697 error_types::VALUE_ERROR,
698 vec![crate::error::LocItem::field("body")],
699 )
700 .with_msg(message.clone()),
701 )
702 .into_response()
703 }
704 FormExtractError::StreamingNotSupported => HttpError::bad_request().into_response(),
705 FormExtractError::InvalidUtf8 => HttpError::bad_request()
706 .with_detail("Invalid UTF-8")
707 .into_response(),
708 }
709 }
710}
711
712impl<T: DeserializeOwned> FromRequest for Form<T> {
713 type Error = FormExtractError;
714
715 async fn from_request(ctx: &RequestContext, req: &mut Request) -> Result<Self, Self::Error> {
716 let ct = req
717 .headers()
718 .get("content-type")
719 .and_then(|v| std::str::from_utf8(v).ok());
720 let is_form = ct.is_some_and(|c| {
721 c.to_ascii_lowercase()
722 .starts_with("application/x-www-form-urlencoded")
723 });
724 if !is_form {
725 return Err(FormExtractError::UnsupportedMediaType {
726 actual: ct.map(String::from),
727 });
728 }
729 let _ = ctx.checkpoint();
730 let body = req.take_body();
731 let bytes = match body {
732 Body::Empty => Vec::new(),
733 Body::Bytes(b) => b,
734 Body::Stream(_) => return Err(FormExtractError::StreamingNotSupported),
735 };
736 let limit = ctx.max_body_size();
737 if bytes.len() > limit {
738 return Err(FormExtractError::PayloadTooLarge {
739 size: bytes.len(),
740 limit,
741 });
742 }
743 let _ = ctx.checkpoint();
744 let body_str = std::str::from_utf8(&bytes).map_err(|_| FormExtractError::InvalidUtf8)?;
745 let params = QueryParams::parse(body_str);
746 let value = T::deserialize(QueryDeserializer::new(¶ms)).map_err(|e| {
747 FormExtractError::DeserializeError {
748 message: e.to_string(),
749 }
750 })?;
751 let _ = ctx.checkpoint();
752 Ok(Form(value))
753 }
754}
755
#[cfg(test)]
mod form_tests {
    use super::*;
    use crate::request::Method;
    use serde::Deserialize;

    fn test_context() -> RequestContext {
        RequestContext::new(asupersync::Cx::for_testing(), 12345)
    }

    /// POST `/test` with the given `content-type` header and body.
    fn post(content_type: &[u8], body: &[u8]) -> Request {
        let mut req = Request::new(Method::Post, "/test");
        req.headers_mut()
            .insert("content-type", content_type.to_vec());
        req.set_body(Body::Bytes(body.to_vec()));
        req
    }

    /// POST `/test` with `content-type: application/x-www-form-urlencoded`.
    fn form_request(body: &str) -> Request {
        post(b"application/x-www-form-urlencoded", body.as_bytes())
    }

    /// Runs the `Form<T>` extractor to completion against `req`.
    fn extract<T: DeserializeOwned>(req: &mut Request) -> Result<Form<T>, FormExtractError> {
        let ctx = test_context();
        futures_executor::block_on(Form::<T>::from_request(&ctx, req))
    }

    #[test]
    fn form_extract_success() {
        #[derive(Deserialize, Debug, PartialEq)]
        struct Login {
            username: String,
            password: String,
        }

        let mut req = form_request("username=alice&password=secret");
        let Form(login) = extract::<Login>(&mut req).unwrap();
        assert_eq!(
            login,
            Login {
                username: "alice".to_string(),
                password: "secret".to_string(),
            }
        );
    }

    #[test]
    fn form_wrong_content_type() {
        #[derive(Deserialize)]
        struct XOnly {
            #[allow(dead_code)] // only the extraction outcome is asserted
            x: String,
        }

        let mut req = post(b"application/json", b"x=1");
        let outcome = extract::<XOnly>(&mut req);
        assert!(matches!(
            outcome,
            Err(FormExtractError::UnsupportedMediaType { .. })
        ));
    }
}
812
/// Default maximum raw body size: 2 MiB (2,097,152 bytes).
pub const DEFAULT_RAW_BODY_LIMIT: usize = 2 << 20;

/// Configuration for raw body extraction (`Bytes` / `StringBody`).
///
/// The `Bytes` extractor reads this from the request's extensions when
/// present and otherwise falls back to [`DEFAULT_RAW_BODY_LIMIT`].
#[derive(Debug, Clone)]
pub struct RawBodyConfig {
    /// Maximum body size in bytes.
    limit: usize,
}

impl Default for RawBodyConfig {
    /// Limit [`DEFAULT_RAW_BODY_LIMIT`].
    fn default() -> Self {
        RawBodyConfig {
            limit: DEFAULT_RAW_BODY_LIMIT,
        }
    }
}

impl RawBodyConfig {
    /// Creates a configuration with default values.
    #[must_use]
    pub fn new() -> Self {
        <Self as Default>::default()
    }

    /// Returns the configuration with its size limit replaced by `size` bytes.
    #[must_use]
    pub fn limit(self, size: usize) -> Self {
        let mut updated = self;
        updated.limit = size;
        updated
    }

    /// The configured size limit in bytes.
    #[must_use]
    pub fn get_limit(&self) -> usize {
        self.limit
    }
}
855
/// Reasons the raw body extractors (`Bytes`, `StringBody`) reject a request.
#[derive(Debug)]
pub enum RawBodyError {
    /// The body is larger than the allowed limit (sizes in bytes).
    PayloadTooLarge { size: usize, limit: usize },
    /// The body is a stream; only buffered bodies are supported.
    StreamingNotSupported,
    /// The body is not valid UTF-8 (only produced by `StringBody`).
    InvalidUtf8,
}

impl std::fmt::Display for RawBodyError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::PayloadTooLarge { size, limit } => write!(
                f,
                "Payload too large: {size} bytes exceeds limit of {limit}"
            ),
            Self::StreamingNotSupported => {
                f.write_str("Streaming body not supported for raw extraction")
            }
            Self::InvalidUtf8 => f.write_str("Body is not valid UTF-8"),
        }
    }
}

impl std::error::Error for RawBodyError {}
885
886impl IntoResponse for RawBodyError {
887 fn into_response(self) -> Response {
888 match &self {
889 RawBodyError::PayloadTooLarge { size, limit } => HttpError::payload_too_large()
890 .with_detail(format!("Body {size} bytes > {limit} limit"))
891 .into_response(),
892 RawBodyError::StreamingNotSupported => HttpError::bad_request()
893 .with_detail("Streaming body not supported")
894 .into_response(),
895 RawBodyError::InvalidUtf8 => HttpError::bad_request()
896 .with_detail("Body is not valid UTF-8")
897 .into_response(),
898 }
899 }
900}
901
/// The raw request body as owned bytes. Dereferences to `[u8]`.
#[derive(Debug, Clone)]
pub struct Bytes(pub Vec<u8>);

impl Bytes {
    /// Wraps `data`.
    #[must_use]
    pub fn new(data: Vec<u8>) -> Self {
        Bytes(data)
    }

    /// Number of bytes.
    #[must_use]
    pub fn len(&self) -> usize {
        self.as_slice().len()
    }

    /// True when there are no bytes.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Borrows the bytes as a slice.
    #[must_use]
    pub fn as_slice(&self) -> &[u8] {
        self.0.as_slice()
    }

    /// Consumes the wrapper and returns the underlying vector.
    #[must_use]
    pub fn into_inner(self) -> Vec<u8> {
        let Bytes(data) = self;
        data
    }
}

impl AsRef<[u8]> for Bytes {
    fn as_ref(&self) -> &[u8] {
        self.as_slice()
    }
}

impl std::ops::Deref for Bytes {
    type Target = [u8];

    fn deref(&self) -> &[u8] {
        self.as_slice()
    }
}

impl From<Vec<u8>> for Bytes {
    fn from(data: Vec<u8>) -> Self {
        Bytes::new(data)
    }
}

impl From<Bytes> for Vec<u8> {
    fn from(bytes: Bytes) -> Self {
        bytes.into_inner()
    }
}
976
977impl FromRequest for Bytes {
978 type Error = RawBodyError;
979
980 async fn from_request(ctx: &RequestContext, req: &mut Request) -> Result<Self, Self::Error> {
981 let _ = ctx.checkpoint();
982
983 let body = req.take_body();
984 let bytes = match body {
985 Body::Empty => Vec::new(),
986 Body::Bytes(b) => b,
987 Body::Stream(_) => return Err(RawBodyError::StreamingNotSupported),
988 };
989
990 let limit = req
992 .get_extension::<RawBodyConfig>()
993 .map(|c| c.limit)
994 .unwrap_or(DEFAULT_RAW_BODY_LIMIT);
995
996 if bytes.len() > limit {
997 return Err(RawBodyError::PayloadTooLarge {
998 size: bytes.len(),
999 limit,
1000 });
1001 }
1002
1003 let _ = ctx.checkpoint();
1004 Ok(Bytes(bytes))
1005 }
1006}
1007
/// The request body as an owned UTF-8 string. Dereferences to `str`.
#[derive(Debug, Clone)]
pub struct StringBody(pub String);

impl StringBody {
    /// Wraps `data`.
    #[must_use]
    pub fn new(data: String) -> Self {
        StringBody(data)
    }

    /// Length in bytes.
    #[must_use]
    pub fn len(&self) -> usize {
        self.as_str().len()
    }

    /// True when the string is empty.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Borrows the text.
    #[must_use]
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }

    /// Consumes the wrapper and returns the underlying string.
    #[must_use]
    pub fn into_inner(self) -> String {
        let StringBody(text) = self;
        text
    }
}

impl AsRef<str> for StringBody {
    fn as_ref(&self) -> &str {
        self.as_str()
    }
}

impl std::ops::Deref for StringBody {
    type Target = str;

    fn deref(&self) -> &str {
        self.as_str()
    }
}

impl From<String> for StringBody {
    fn from(data: String) -> Self {
        StringBody::new(data)
    }
}

impl From<StringBody> for String {
    fn from(text: StringBody) -> Self {
        text.into_inner()
    }
}
1082
1083impl FromRequest for StringBody {
1084 type Error = RawBodyError;
1085
1086 async fn from_request(ctx: &RequestContext, req: &mut Request) -> Result<Self, Self::Error> {
1087 let bytes = Bytes::from_request(ctx, req).await?;
1088
1089 let text = String::from_utf8(bytes.into_inner()).map_err(|_| RawBodyError::InvalidUtf8)?;
1090
1091 Ok(StringBody(text))
1092 }
1093}
1094
#[cfg(test)]
mod raw_body_tests {
    use super::*;
    use crate::request::Method;

    fn test_context() -> RequestContext {
        RequestContext::new(asupersync::Cx::for_testing(), 1)
    }

    /// POST to `path` carrying `body`.
    fn post(path: &'static str, body: Body) -> Request {
        let mut req = Request::new(Method::Post, path);
        req.set_body(body);
        req
    }

    /// Runs the `Bytes` extractor to completion.
    fn bytes_of(req: &mut Request) -> Result<Bytes, RawBodyError> {
        let ctx = test_context();
        futures_executor::block_on(Bytes::from_request(&ctx, req))
    }

    /// Runs the `StringBody` extractor to completion.
    fn text_of(req: &mut Request) -> Result<StringBody, RawBodyError> {
        let ctx = test_context();
        futures_executor::block_on(StringBody::from_request(&ctx, req))
    }

    #[test]
    fn test_bytes_extract_success() {
        let mut req = post("/upload", Body::Bytes(b"hello world".to_vec()));
        let bytes = bytes_of(&mut req).unwrap();
        assert_eq!(bytes.as_slice(), b"hello world");
        assert_eq!(bytes.len(), 11);
    }

    #[test]
    fn test_bytes_extract_empty() {
        let mut req = post("/upload", Body::Empty);
        assert!(bytes_of(&mut req).unwrap().is_empty());
    }

    #[test]
    fn test_bytes_size_limit() {
        let oversized = vec![0u8; DEFAULT_RAW_BODY_LIMIT + 1];
        let mut req = post("/upload", Body::Bytes(oversized));
        assert!(matches!(
            bytes_of(&mut req),
            Err(RawBodyError::PayloadTooLarge { .. })
        ));
    }

    #[test]
    fn test_bytes_custom_limit() {
        let mut req = Request::new(Method::Post, "/upload");
        req.insert_extension(RawBodyConfig::new().limit(100));
        req.set_body(Body::Bytes(vec![0u8; 150]));
        assert!(matches!(
            bytes_of(&mut req),
            Err(RawBodyError::PayloadTooLarge {
                size: 150,
                limit: 100
            })
        ));
    }

    #[test]
    fn test_bytes_deref() {
        assert_eq!(&*Bytes::new(b"test".to_vec()), b"test");
    }

    #[test]
    fn test_bytes_from_vec() {
        let bytes: Bytes = vec![1, 2, 3].into();
        assert_eq!(bytes.as_slice(), &[1, 2, 3]);
    }

    #[test]
    fn test_string_body_extract_success() {
        let mut req = post("/text", Body::Bytes(b"hello world".to_vec()));
        let text = text_of(&mut req).unwrap();
        assert_eq!(text.as_str(), "hello world");
        assert_eq!(text.len(), 11);
    }

    #[test]
    fn test_string_body_extract_empty() {
        let mut req = post("/text", Body::Empty);
        assert!(text_of(&mut req).unwrap().is_empty());
    }

    #[test]
    fn test_string_body_invalid_utf8() {
        // 0xff / 0xfe can never appear in well-formed UTF-8.
        let mut req = post("/text", Body::Bytes(vec![0xff, 0xfe, 0x00, 0x01]));
        assert!(matches!(text_of(&mut req), Err(RawBodyError::InvalidUtf8)));
    }

    #[test]
    fn test_string_body_deref() {
        assert_eq!(&*StringBody::new("hello".to_string()), "hello");
    }

    #[test]
    fn test_string_body_from_string() {
        let text: StringBody = "test".to_string().into();
        assert_eq!(text.as_str(), "test");
    }

    #[test]
    fn test_string_body_unicode() {
        let sample = "こんにちは世界 🌍";
        let mut req = post("/text", Body::Bytes(sample.as_bytes().to_vec()));
        assert_eq!(text_of(&mut req).unwrap().as_str(), sample);
    }
}
1224
/// Default per-file size limit for multipart uploads: 10 MiB.
pub const DEFAULT_MULTIPART_FILE_SIZE: usize = 10 << 20;

/// Default limit on the total multipart body size: 50 MiB.
pub const DEFAULT_MULTIPART_TOTAL_SIZE: usize = 50 << 20;

/// Default maximum number of multipart fields.
pub const DEFAULT_MULTIPART_MAX_FIELDS: usize = 100;

/// Limits applied to `multipart/form-data` extraction.
#[derive(Debug, Clone)]
pub struct MultipartConfig {
    /// Maximum size of a single file, in bytes.
    max_file_size: usize,
    /// Maximum size of the whole body, in bytes.
    max_total_size: usize,
    /// Maximum number of parts.
    max_fields: usize,
}

impl Default for MultipartConfig {
    /// Uses the `DEFAULT_MULTIPART_*` constants.
    fn default() -> Self {
        MultipartConfig {
            max_fields: DEFAULT_MULTIPART_MAX_FIELDS,
            max_total_size: DEFAULT_MULTIPART_TOTAL_SIZE,
            max_file_size: DEFAULT_MULTIPART_FILE_SIZE,
        }
    }
}

impl MultipartConfig {
    /// Creates a configuration with default limits.
    #[must_use]
    pub fn new() -> Self {
        <Self as Default>::default()
    }

    /// Returns the configuration with the per-file limit set to `size` bytes.
    #[must_use]
    pub fn max_file_size(self, size: usize) -> Self {
        Self {
            max_file_size: size,
            ..self
        }
    }

    /// Returns the configuration with the total-size limit set to `size` bytes.
    #[must_use]
    pub fn max_total_size(self, size: usize) -> Self {
        Self {
            max_total_size: size,
            ..self
        }
    }

    /// Returns the configuration with the field-count limit set to `count`.
    #[must_use]
    pub fn max_fields(self, count: usize) -> Self {
        Self {
            max_fields: count,
            ..self
        }
    }

    /// The per-file limit in bytes.
    #[must_use]
    pub fn get_max_file_size(&self) -> usize {
        self.max_file_size
    }

    /// The total-size limit in bytes.
    #[must_use]
    pub fn get_max_total_size(&self) -> usize {
        self.max_total_size
    }

    /// The maximum number of fields.
    #[must_use]
    pub fn get_max_fields(&self) -> usize {
        self.max_fields
    }
}
1313
/// A file part from a multipart upload, with its metadata and contents.
#[derive(Debug, Clone)]
pub struct UploadedFile {
    /// Form field name the file was sent under.
    field_name: String,
    /// Client-supplied file name.
    filename: String,
    /// Content type of the part.
    content_type: String,
    /// File contents.
    data: Vec<u8>,
}

impl UploadedFile {
    /// Builds an uploaded file from its parts.
    #[must_use]
    pub fn new(field_name: String, filename: String, content_type: String, data: Vec<u8>) -> Self {
        UploadedFile {
            field_name,
            filename,
            content_type,
            data,
        }
    }

    /// Form field name.
    #[must_use]
    pub fn field_name(&self) -> &str {
        self.field_name.as_str()
    }

    /// Client-supplied file name (not sanitized here).
    #[must_use]
    pub fn filename(&self) -> &str {
        self.filename.as_str()
    }

    /// Content type of the part.
    #[must_use]
    pub fn content_type(&self) -> &str {
        self.content_type.as_str()
    }

    /// Borrows the contents.
    #[must_use]
    pub fn data(&self) -> &[u8] {
        self.data.as_slice()
    }

    /// Consumes the file and returns its contents.
    #[must_use]
    pub fn into_data(self) -> Vec<u8> {
        let UploadedFile { data, .. } = self;
        data
    }

    /// Size of the contents in bytes.
    #[must_use]
    pub fn size(&self) -> usize {
        self.data().len()
    }

    /// Text after the last `.` in the file name.
    ///
    /// Returns `None` when the name has no `.` or ends with one. A leading-dot
    /// name such as `.hidden` yields `Some("hidden")`.
    #[must_use]
    pub fn extension(&self) -> Option<&str> {
        self.filename
            .rsplit_once('.')
            .map(|(_, ext)| ext)
            .filter(|ext| !ext.is_empty())
    }

    /// Contents as `&str`, or `None` if they are not valid UTF-8.
    #[must_use]
    pub fn text(&self) -> Option<&str> {
        match std::str::from_utf8(self.data()) {
            Ok(text) => Some(text),
            Err(_) => None,
        }
    }
}
1402
/// Reasons multipart extraction fails.
#[derive(Debug)]
pub enum MultipartExtractError {
    /// `Content-Type` is missing or not `multipart/form-data`.
    UnsupportedMediaType { actual: Option<String> },
    /// The multipart `Content-Type` has no `boundary` parameter.
    MissingBoundary,
    /// A single file exceeded its size limit (bytes).
    FileTooLarge { size: usize, limit: usize },
    /// The whole upload exceeded its size limit (bytes).
    TotalTooLarge { size: usize, limit: usize },
    /// More fields than allowed.
    TooManyFields { count: usize, limit: usize },
    /// The body is not well-formed multipart data.
    InvalidFormat { detail: String },
    /// The body is a stream; only buffered bodies are supported.
    StreamingNotSupported,
    /// A required file field was absent.
    FileNotFound { field_name: String },
}

impl std::fmt::Display for MultipartExtractError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::UnsupportedMediaType { actual: Some(ct) } => {
                write!(f, "Expected multipart/form-data, got: {ct}")
            }
            Self::UnsupportedMediaType { actual: None } => {
                f.write_str("Expected multipart/form-data, got empty Content-Type")
            }
            Self::MissingBoundary => f.write_str("Missing boundary in multipart Content-Type"),
            Self::FileTooLarge { size, limit } => {
                write!(f, "File too large: {size} bytes exceeds limit of {limit}")
            }
            Self::TotalTooLarge { size, limit } => write!(
                f,
                "Total upload too large: {size} bytes exceeds limit of {limit}"
            ),
            Self::TooManyFields { count, limit } => {
                write!(f, "Too many fields: {count} exceeds limit of {limit}")
            }
            Self::InvalidFormat { detail } => write!(f, "Invalid multipart format: {detail}"),
            Self::StreamingNotSupported => {
                f.write_str("Streaming body not supported for multipart extraction")
            }
            Self::FileNotFound { field_name } => {
                write!(f, "No file found with field name '{field_name}'")
            }
        }
    }
}

impl std::error::Error for MultipartExtractError {}
1461
1462impl IntoResponse for MultipartExtractError {
1463 fn into_response(self) -> Response {
1464 match &self {
1465 MultipartExtractError::UnsupportedMediaType { .. } => {
1466 HttpError::unsupported_media_type().into_response()
1467 }
1468 MultipartExtractError::MissingBoundary => HttpError::bad_request()
1469 .with_detail("Missing boundary in multipart Content-Type")
1470 .into_response(),
1471 MultipartExtractError::FileTooLarge { size, limit } => HttpError::payload_too_large()
1472 .with_detail(format!("File {size} bytes > {limit} limit"))
1473 .into_response(),
1474 MultipartExtractError::TotalTooLarge { size, limit } => HttpError::payload_too_large()
1475 .with_detail(format!("Total {size} bytes > {limit} limit"))
1476 .into_response(),
1477 MultipartExtractError::TooManyFields { count, limit } => HttpError::bad_request()
1478 .with_detail(format!("Too many fields: {count} > {limit}"))
1479 .into_response(),
1480 MultipartExtractError::InvalidFormat { detail } => HttpError::bad_request()
1481 .with_detail(format!("Invalid multipart: {detail}"))
1482 .into_response(),
1483 MultipartExtractError::StreamingNotSupported => HttpError::bad_request()
1484 .with_detail("Streaming body not supported")
1485 .into_response(),
1486 MultipartExtractError::FileNotFound { field_name } => {
1487 use crate::error::error_types;
1488 ValidationErrors::single(
1489 ValidationError::new(
1490 error_types::VALUE_ERROR,
1491 vec![crate::error::LocItem::field(field_name)],
1492 )
1493 .with_msg(format!("Required file '{field_name}' not found")),
1494 )
1495 .into_response()
1496 }
1497 }
1498 }
1499}
1500
/// A fully buffered and parsed `multipart/form-data` body.
///
/// Parts are kept in the order they were supplied to [`Multipart::from_parts`]; when
/// produced by the extractor, that is the order they appeared in the request body.
/// A part with a `filename` is treated as a file upload, and a part without one as a
/// text field.
#[derive(Debug, Clone)]
pub struct Multipart {
    parts: Vec<MultipartPart>,
}
1520
/// One part of a `multipart/form-data` body.
#[derive(Debug, Clone)]
pub struct MultipartPart {
    /// Field name, from the `name` parameter of the part's `Content-Disposition` header.
    pub name: String,
    /// Filename from `Content-Disposition`, if present. `Some` marks the part as a file upload.
    pub filename: Option<String>,
    /// The part's own `Content-Type` header value, if one was sent.
    pub content_type: Option<String>,
    /// Raw part content, excluding the CRLF that precedes the next boundary delimiter.
    pub data: Vec<u8>,
}
1533
1534impl Multipart {
1535 #[must_use]
1537 pub fn from_parts(parts: Vec<MultipartPart>) -> Self {
1538 Self { parts }
1539 }
1540
1541 #[must_use]
1543 pub fn parts(&self) -> &[MultipartPart] {
1544 &self.parts
1545 }
1546
1547 #[must_use]
1549 pub fn get_field(&self, name: &str) -> Option<&str> {
1550 self.parts
1551 .iter()
1552 .find(|p| p.name == name && p.filename.is_none())
1553 .and_then(|p| std::str::from_utf8(&p.data).ok())
1554 }
1555
1556 #[must_use]
1558 pub fn get_file(&self, name: &str) -> Option<UploadedFile> {
1559 self.parts
1560 .iter()
1561 .find(|p| p.name == name && p.filename.is_some())
1562 .map(|p| {
1563 UploadedFile::new(
1564 p.name.clone(),
1565 p.filename.clone().unwrap_or_default(),
1566 p.content_type
1567 .clone()
1568 .unwrap_or_else(|| "application/octet-stream".to_string()),
1569 p.data.clone(),
1570 )
1571 })
1572 }
1573
1574 #[must_use]
1576 pub fn files(&self) -> Vec<UploadedFile> {
1577 self.parts
1578 .iter()
1579 .filter(|p| p.filename.is_some())
1580 .map(|p| {
1581 UploadedFile::new(
1582 p.name.clone(),
1583 p.filename.clone().unwrap_or_default(),
1584 p.content_type
1585 .clone()
1586 .unwrap_or_else(|| "application/octet-stream".to_string()),
1587 p.data.clone(),
1588 )
1589 })
1590 .collect()
1591 }
1592
1593 #[must_use]
1595 pub fn get_files(&self, name: &str) -> Vec<UploadedFile> {
1596 self.parts
1597 .iter()
1598 .filter(|p| p.name == name && p.filename.is_some())
1599 .map(|p| {
1600 UploadedFile::new(
1601 p.name.clone(),
1602 p.filename.clone().unwrap_or_default(),
1603 p.content_type
1604 .clone()
1605 .unwrap_or_else(|| "application/octet-stream".to_string()),
1606 p.data.clone(),
1607 )
1608 })
1609 .collect()
1610 }
1611
1612 #[must_use]
1614 pub fn fields(&self) -> Vec<(&str, &str)> {
1615 self.parts
1616 .iter()
1617 .filter(|p| p.filename.is_none())
1618 .filter_map(|p| Some((p.name.as_str(), std::str::from_utf8(&p.data).ok()?)))
1619 .collect()
1620 }
1621
1622 #[must_use]
1624 pub fn has_field(&self, name: &str) -> bool {
1625 self.parts.iter().any(|p| p.name == name)
1626 }
1627
1628 #[must_use]
1630 pub fn len(&self) -> usize {
1631 self.parts.len()
1632 }
1633
1634 #[must_use]
1636 pub fn is_empty(&self) -> bool {
1637 self.parts.is_empty()
1638 }
1639}
1640
1641impl FromRequest for Multipart {
1642 type Error = MultipartExtractError;
1643
1644 async fn from_request(ctx: &RequestContext, req: &mut Request) -> Result<Self, Self::Error> {
1645 let content_type = req
1647 .headers()
1648 .get("content-type")
1649 .and_then(|v| std::str::from_utf8(v).ok())
1650 .map(String::from);
1651
1652 let ct = content_type
1653 .as_deref()
1654 .ok_or(MultipartExtractError::UnsupportedMediaType { actual: None })?;
1655
1656 if !ct.to_ascii_lowercase().starts_with("multipart/form-data") {
1657 return Err(MultipartExtractError::UnsupportedMediaType {
1658 actual: Some(ct.to_string()),
1659 });
1660 }
1661
1662 let boundary = parse_multipart_boundary(ct)?;
1664
1665 let _ = ctx.checkpoint();
1666
1667 let body = req.take_body();
1669 let bytes = match body {
1670 Body::Empty => Vec::new(),
1671 Body::Bytes(b) => b,
1672 Body::Stream(_) => return Err(MultipartExtractError::StreamingNotSupported),
1673 };
1674
1675 let config = req
1677 .get_extension::<MultipartConfig>()
1678 .cloned()
1679 .unwrap_or_default();
1680
1681 let _ = ctx.checkpoint();
1682
1683 let parts = parse_multipart_body(&bytes, &boundary, &config)?;
1685
1686 let _ = ctx.checkpoint();
1687
1688 Ok(Multipart::from_parts(parts))
1689 }
1690}
1691
/// Extractor for a single uploaded file from a `multipart/form-data` body.
///
/// The file is looked up under the field name `"file"`, unless a [`FileConfig`] request
/// extension specifies another name.
#[derive(Debug, Clone)]
pub struct File(pub UploadedFile);
1709
1710impl File {
1711 #[must_use]
1713 pub fn into_inner(self) -> UploadedFile {
1714 self.0
1715 }
1716
1717 #[must_use]
1719 pub fn inner(&self) -> &UploadedFile {
1720 &self.0
1721 }
1722}
1723
1724impl std::ops::Deref for File {
1725 type Target = UploadedFile;
1726
1727 fn deref(&self) -> &Self::Target {
1728 &self.0
1729 }
1730}
1731
/// Configuration for the [`File`] extractor, read from the request's extensions.
#[derive(Debug, Clone)]
pub struct FileConfig {
    /// Multipart field name under which the extractor looks for the file.
    field_name: String,
}
1737
1738impl Default for FileConfig {
1739 fn default() -> Self {
1740 Self {
1741 field_name: "file".to_string(),
1742 }
1743 }
1744}
1745
1746impl FileConfig {
1747 #[must_use]
1749 pub fn new(field_name: impl Into<String>) -> Self {
1750 Self {
1751 field_name: field_name.into(),
1752 }
1753 }
1754
1755 #[must_use]
1757 pub fn field_name(&self) -> &str {
1758 &self.field_name
1759 }
1760}
1761
1762impl FromRequest for File {
1763 type Error = MultipartExtractError;
1764
1765 async fn from_request(ctx: &RequestContext, req: &mut Request) -> Result<Self, Self::Error> {
1766 let field_name = req
1767 .get_extension::<FileConfig>()
1768 .map(|c| c.field_name.clone())
1769 .unwrap_or_else(|| "file".to_string());
1770
1771 let multipart = Multipart::from_request(ctx, req).await?;
1772
1773 let file = multipart
1774 .get_file(&field_name)
1775 .ok_or(MultipartExtractError::FileNotFound { field_name })?;
1776
1777 Ok(File(file))
1778 }
1779}
1780
/// Maximum accepted length of a multipart boundary, in bytes.
///
/// RFC 2046 limits boundaries to 70 characters. `parse_multipart_boundary` rejects longer
/// values with `InvalidFormat`.
const MAX_MULTIPART_BOUNDARY_LEN: usize = 70;
1786
1787fn parse_multipart_boundary(content_type: &str) -> Result<String, MultipartExtractError> {
1789 for part in content_type.split(';') {
1790 let part = part.trim();
1791 if let Some(boundary) = part
1792 .strip_prefix("boundary=")
1793 .or_else(|| part.strip_prefix("BOUNDARY="))
1794 {
1795 let boundary = boundary.trim_matches('"').trim_matches('\'');
1796 if boundary.is_empty() {
1797 return Err(MultipartExtractError::MissingBoundary);
1798 }
1799 if boundary.len() > MAX_MULTIPART_BOUNDARY_LEN {
1801 return Err(MultipartExtractError::InvalidFormat {
1802 detail: format!(
1803 "boundary too long: {} chars (max {})",
1804 boundary.len(),
1805 MAX_MULTIPART_BOUNDARY_LEN
1806 ),
1807 });
1808 }
1809 return Ok(boundary.to_string());
1810 }
1811 }
1812 Err(MultipartExtractError::MissingBoundary)
1813}
1814
1815fn parse_multipart_body(
1817 body: &[u8],
1818 boundary: &str,
1819 config: &MultipartConfig,
1820) -> Result<Vec<MultipartPart>, MultipartExtractError> {
1821 let boundary_bytes = format!("--{boundary}").into_bytes();
1822 let mut parts = Vec::new();
1823 let mut total_size = 0usize;
1824 let mut pos = 0;
1825
1826 pos = find_bytes(body, &boundary_bytes, pos).ok_or_else(|| {
1828 MultipartExtractError::InvalidFormat {
1829 detail: "no boundary found".to_string(),
1830 }
1831 })?;
1832
1833 loop {
1834 if parts.len() >= config.max_fields {
1836 return Err(MultipartExtractError::TooManyFields {
1837 count: parts.len() + 1,
1838 limit: config.max_fields,
1839 });
1840 }
1841
1842 let boundary_end = pos + boundary_bytes.len();
1844 if boundary_end + 2 <= body.len() && body[boundary_end..boundary_end + 2] == *b"--" {
1845 break;
1846 }
1847
1848 pos = boundary_end;
1850 if pos + 2 > body.len() {
1851 return Err(MultipartExtractError::InvalidFormat {
1852 detail: "unexpected end after boundary".to_string(),
1853 });
1854 }
1855 if body[pos..pos + 2] != *b"\r\n" {
1856 return Err(MultipartExtractError::InvalidFormat {
1857 detail: "expected CRLF after boundary".to_string(),
1858 });
1859 }
1860 pos += 2;
1861
1862 let mut name = None;
1864 let mut filename = None;
1865 let mut content_type = None;
1866
1867 loop {
1868 let line_end =
1869 find_crlf(body, pos).ok_or_else(|| MultipartExtractError::InvalidFormat {
1870 detail: "unterminated headers".to_string(),
1871 })?;
1872
1873 let line = &body[pos..line_end];
1874 if line.is_empty() {
1875 pos = line_end + 2;
1876 break;
1877 }
1878
1879 if let Ok(line_str) = std::str::from_utf8(line) {
1880 if let Some((header_name, header_value)) = line_str.split_once(':') {
1881 let header_name = header_name.trim().to_ascii_lowercase();
1882 let header_value = header_value.trim();
1883
1884 if header_name == "content-disposition" {
1885 (name, filename) = parse_content_disposition_header(header_value);
1886 } else if header_name == "content-type" {
1887 content_type = Some(header_value.to_string());
1888 }
1889 }
1890 }
1891
1892 pos = line_end + 2;
1893 }
1894
1895 let name = name.ok_or_else(|| MultipartExtractError::InvalidFormat {
1896 detail: "missing Content-Disposition name".to_string(),
1897 })?;
1898
1899 let data_end = find_bytes(body, &boundary_bytes, pos).ok_or_else(|| {
1901 MultipartExtractError::InvalidFormat {
1902 detail: "missing closing boundary".to_string(),
1903 }
1904 })?;
1905
1906 let data = if data_end >= 2 && body[data_end - 2..data_end] == *b"\r\n" {
1908 &body[pos..data_end - 2]
1909 } else {
1910 &body[pos..data_end]
1911 };
1912
1913 if filename.is_some() && data.len() > config.max_file_size {
1915 return Err(MultipartExtractError::FileTooLarge {
1916 size: data.len(),
1917 limit: config.max_file_size,
1918 });
1919 }
1920
1921 total_size += data.len();
1922 if total_size > config.max_total_size {
1923 return Err(MultipartExtractError::TotalTooLarge {
1924 size: total_size,
1925 limit: config.max_total_size,
1926 });
1927 }
1928
1929 parts.push(MultipartPart {
1930 name,
1931 filename,
1932 content_type,
1933 data: data.to_vec(),
1934 });
1935
1936 pos = data_end;
1937 }
1938
1939 Ok(parts)
1940}
1941
1942fn find_bytes(data: &[u8], needle: &[u8], start: usize) -> Option<usize> {
1947 if needle.is_empty() {
1948 return Some(start);
1949 }
1950 if start >= data.len() {
1951 return None;
1952 }
1953 let first_byte = needle[0];
1956 let search_slice = &data[start..];
1957
1958 if needle.len() == 1 {
1959 return memchr::memchr(first_byte, search_slice).map(|pos| pos + start);
1961 }
1962
1963 let mut search_offset = 0;
1965 while let Some(pos) = memchr::memchr(first_byte, &search_slice[search_offset..]) {
1966 let abs_pos = start + search_offset + pos;
1967 if abs_pos + needle.len() > data.len() {
1968 return None;
1969 }
1970 if data[abs_pos..].starts_with(needle) {
1971 return Some(abs_pos);
1972 }
1973 search_offset += pos + 1;
1974 }
1975 None
1976}
1977
/// Returns the absolute index of the `\r` in the first `\r\n` pair found in `data` at or
/// after `start`.
///
/// Returns `None` if there is no such pair, including when `start` is at or beyond the
/// end of `data`.
fn find_crlf(data: &[u8], start: usize) -> Option<usize> {
    data.get(start..)?
        .windows(2)
        .position(|pair| pair == b"\r\n")
        .map(|offset| start + offset)
}
2001
2002fn parse_content_disposition_header(value: &str) -> (Option<String>, Option<String>) {
2004 let mut name = None;
2005 let mut filename = None;
2006
2007 for part in value.split(';') {
2008 let part = part.trim();
2009 if let Some(n) = part
2010 .strip_prefix("name=")
2011 .or_else(|| part.strip_prefix("NAME="))
2012 {
2013 name = Some(unquote_param(n));
2014 } else if let Some(f) = part
2015 .strip_prefix("filename=")
2016 .or_else(|| part.strip_prefix("FILENAME="))
2017 {
2018 filename = Some(unquote_param(f));
2019 }
2020 }
2021
2022 (name, filename)
2023}
2024
/// Trims whitespace, then strips one pair of matching surrounding quotes (`"…"` or `'…'`)
/// from a header parameter value.
///
/// Values not wrapped in a matching pair, including a lone quote character, are returned
/// trimmed but otherwise unchanged. Escape sequences inside the quotes are not processed.
fn unquote_param(s: &str) -> String {
    let trimmed = s.trim();
    ['"', '\'']
        .iter()
        .find_map(|&quote| trimmed.strip_prefix(quote)?.strip_suffix(quote))
        .unwrap_or(trimmed)
        .to_string()
}
2040
// Unit tests for multipart handling: boundary extraction, body parsing, size and
// field-count limits, the `Multipart` / `UploadedFile` accessors, the extractor's
// content-type check, and quote stripping in header parameters.
#[cfg(test)]
mod multipart_tests {
    use super::*;
    use crate::RequestContext;
    use crate::request::Method;
    use asupersync::Cx;

    // Builds a minimal `RequestContext` for driving extractors in tests.
    fn test_context() -> RequestContext {
        RequestContext::new(Cx::for_testing(), 1)
    }

    // --- Boundary extraction from the Content-Type header ---

    #[test]
    fn test_parse_boundary() {
        let ct = "multipart/form-data; boundary=----WebKit";
        let boundary = parse_multipart_boundary(ct).unwrap();
        assert_eq!(boundary, "----WebKit");
    }

    #[test]
    fn test_parse_boundary_quoted() {
        let ct = r#"multipart/form-data; boundary="simple""#;
        let boundary = parse_multipart_boundary(ct).unwrap();
        assert_eq!(boundary, "simple");
    }

    #[test]
    fn test_parse_boundary_missing() {
        let ct = "multipart/form-data";
        let result = parse_multipart_boundary(ct);
        assert!(matches!(
            result,
            Err(MultipartExtractError::MissingBoundary)
        ));
    }

    #[test]
    fn test_parse_boundary_too_long() {
        // Well over the 70-character limit.
        let long_boundary = "x".repeat(100);
        let ct = format!("multipart/form-data; boundary={long_boundary}");
        let result = parse_multipart_boundary(&ct);
        assert!(
            matches!(result, Err(MultipartExtractError::InvalidFormat { .. })),
            "Expected InvalidFormat for boundary > 70 chars"
        );
    }

    #[test]
    fn test_parse_boundary_max_length() {
        // Exactly at the limit must still be accepted.
        let boundary = "x".repeat(70);
        let ct = format!("multipart/form-data; boundary={boundary}");
        let result = parse_multipart_boundary(&ct);
        assert!(result.is_ok(), "70-char boundary should be accepted");
        assert_eq!(result.unwrap(), boundary);
    }

    // --- Body parsing ---
    // The delimiter in a body is `--` followed by the boundary, so the boundary
    // "----boundary" appears in the bodies below as "------boundary".

    #[test]
    fn test_parse_simple_form() {
        let boundary = "----boundary";
        let body = concat!(
            "------boundary\r\n",
            "Content-Disposition: form-data; name=\"field1\"\r\n",
            "\r\n",
            "value1\r\n",
            "------boundary\r\n",
            "Content-Disposition: form-data; name=\"field2\"\r\n",
            "\r\n",
            "value2\r\n",
            "------boundary--\r\n"
        );

        let config = MultipartConfig::default();
        let parts = parse_multipart_body(body.as_bytes(), boundary, &config).unwrap();

        assert_eq!(parts.len(), 2);
        assert_eq!(parts[0].name, "field1");
        assert_eq!(std::str::from_utf8(&parts[0].data).unwrap(), "value1");
        assert_eq!(parts[1].name, "field2");
        assert_eq!(std::str::from_utf8(&parts[1].data).unwrap(), "value2");
    }

    #[test]
    fn test_parse_file_upload() {
        let boundary = "----boundary";
        let body = concat!(
            "------boundary\r\n",
            "Content-Disposition: form-data; name=\"file\"; filename=\"test.txt\"\r\n",
            "Content-Type: text/plain\r\n",
            "\r\n",
            "Hello!\r\n",
            "------boundary--\r\n"
        );

        let config = MultipartConfig::default();
        let parts = parse_multipart_body(body.as_bytes(), boundary, &config).unwrap();

        assert_eq!(parts.len(), 1);
        assert_eq!(parts[0].name, "file");
        assert_eq!(parts[0].filename, Some("test.txt".to_string()));
        assert_eq!(parts[0].content_type, Some("text/plain".to_string()));
        assert_eq!(std::str::from_utf8(&parts[0].data).unwrap(), "Hello!");
    }

    // Parses a mixed form, then checks the `Multipart` field and file accessors on it.
    #[test]
    fn test_multipart_extractor() {
        let boundary = "----boundary";
        let body = concat!(
            "------boundary\r\n",
            "Content-Disposition: form-data; name=\"name\"\r\n",
            "\r\n",
            "John\r\n",
            "------boundary\r\n",
            "Content-Disposition: form-data; name=\"avatar\"; filename=\"pic.jpg\"\r\n",
            "Content-Type: image/jpeg\r\n",
            "\r\n",
            "JPEG\r\n",
            "------boundary--\r\n"
        );

        let config = MultipartConfig::default();
        let parts = parse_multipart_body(body.as_bytes(), boundary, &config).unwrap();
        let form = Multipart::from_parts(parts);

        assert_eq!(form.get_field("name"), Some("John"));
        let file = form.get_file("avatar").unwrap();
        assert_eq!(file.filename(), "pic.jpg");
        assert_eq!(file.content_type(), "image/jpeg");
    }

    // --- Limits ---

    // A 1000-byte file part against a 100-byte per-file limit.
    #[test]
    fn test_file_size_limit() {
        let boundary = "----boundary";
        let large = "x".repeat(1000);
        let body = format!(
            "------boundary\r\n\
             Content-Disposition: form-data; name=\"file\"; filename=\"big.txt\"\r\n\
             \r\n\
             {}\r\n\
             ------boundary--\r\n",
            large
        );

        let config = MultipartConfig::default().max_file_size(100);
        let result = parse_multipart_body(body.as_bytes(), boundary, &config);

        assert!(matches!(
            result,
            Err(MultipartExtractError::FileTooLarge { .. })
        ));
    }

    // Two 500-byte files: each is under the per-file limit, but together they exceed
    // the 800-byte total limit.
    #[test]
    fn test_total_size_limit() {
        let boundary = "----boundary";
        let data = "x".repeat(500);
        let body = format!(
            "------boundary\r\n\
             Content-Disposition: form-data; name=\"f1\"; filename=\"a.txt\"\r\n\
             \r\n\
             {}\r\n\
             ------boundary\r\n\
             Content-Disposition: form-data; name=\"f2\"; filename=\"b.txt\"\r\n\
             \r\n\
             {}\r\n\
             ------boundary--\r\n",
            data, data
        );

        let config = MultipartConfig::default()
            .max_file_size(1000)
            .max_total_size(800);
        let result = parse_multipart_body(body.as_bytes(), boundary, &config);

        assert!(matches!(
            result,
            Err(MultipartExtractError::TotalTooLarge { .. })
        ));
    }

    // Five text fields against a limit of three.
    #[test]
    fn test_field_count_limit() {
        let boundary = "----boundary";
        let mut body = String::new();
        for i in 0..5 {
            body.push_str(&format!(
                "------boundary\r\n\
                 Content-Disposition: form-data; name=\"f{}\"\r\n\
                 \r\n\
                 v{}\r\n",
                i, i
            ));
        }
        body.push_str("------boundary--\r\n");

        let config = MultipartConfig::default().max_fields(3);
        let result = parse_multipart_body(body.as_bytes(), boundary, &config);

        assert!(matches!(
            result,
            Err(MultipartExtractError::TooManyFields { .. })
        ));
    }

    // --- UploadedFile and extractor behaviour ---

    #[test]
    fn test_uploaded_file_extension() {
        let file = UploadedFile::new(
            "doc".to_string(),
            "report.pdf".to_string(),
            "application/pdf".to_string(),
            vec![],
        );
        assert_eq!(file.extension(), Some("pdf"));

        let no_ext = UploadedFile::new(
            "doc".to_string(),
            "README".to_string(),
            "text/plain".to_string(),
            vec![],
        );
        assert_eq!(no_ext.extension(), None);
    }

    // A JSON request must be rejected before any body parsing happens.
    #[test]
    fn test_multipart_from_request_wrong_content_type() {
        let ctx = test_context();
        let mut req = Request::new(Method::Post, "/upload");
        req.headers_mut()
            .insert("content-type", b"application/json".to_vec());
        req.set_body(Body::Bytes(b"{}".to_vec()));

        let result = futures_executor::block_on(Multipart::from_request(&ctx, &mut req));
        assert!(matches!(
            result,
            Err(MultipartExtractError::UnsupportedMediaType { .. })
        ));
    }

    #[test]
    fn test_file_extractor() {
        let boundary = "----boundary";
        let body = concat!(
            "------boundary\r\n",
            "Content-Disposition: form-data; name=\"file\"; filename=\"doc.pdf\"\r\n",
            "Content-Type: application/pdf\r\n",
            "\r\n",
            "PDF content\r\n",
            "------boundary--\r\n"
        );

        let config = MultipartConfig::default();
        let parts = parse_multipart_body(body.as_bytes(), boundary, &config).unwrap();
        let form = Multipart::from_parts(parts);

        let file = form.get_file("file").unwrap();
        assert_eq!(file.filename(), "doc.pdf");
        assert_eq!(file.content_type(), "application/pdf");
        assert_eq!(file.text(), Some("PDF content"));
    }

    // Several files may share one field name; `get_files` returns them all in order.
    #[test]
    fn test_multiple_files() {
        let boundary = "----boundary";
        let body = concat!(
            "------boundary\r\n",
            "Content-Disposition: form-data; name=\"files\"; filename=\"a.txt\"\r\n",
            "\r\n",
            "file a\r\n",
            "------boundary\r\n",
            "Content-Disposition: form-data; name=\"files\"; filename=\"b.txt\"\r\n",
            "\r\n",
            "file b\r\n",
            "------boundary--\r\n"
        );

        let config = MultipartConfig::default();
        let parts = parse_multipart_body(body.as_bytes(), boundary, &config).unwrap();
        let form = Multipart::from_parts(parts);

        let files = form.get_files("files");
        assert_eq!(files.len(), 2);
        assert_eq!(files[0].filename(), "a.txt");
        assert_eq!(files[1].filename(), "b.txt");
    }

    // --- unquote_param ---

    #[test]
    fn test_unquote_param_normal_quoted() {
        assert_eq!(unquote_param("\"hello\""), "hello");
        assert_eq!(unquote_param("'hello'"), "hello");
    }

    #[test]
    fn test_unquote_param_empty_quotes() {
        assert_eq!(unquote_param("\"\""), "");
        assert_eq!(unquote_param("''"), "");
    }

    // A lone quote character is not a quoted string and must not be sliced.
    #[test]
    fn test_unquote_param_single_char_no_panic() {
        assert_eq!(unquote_param("\""), "\"");
        assert_eq!(unquote_param("'"), "'");
    }

    #[test]
    fn test_unquote_param_unquoted() {
        assert_eq!(unquote_param("hello"), "hello");
        assert_eq!(unquote_param(""), "");
    }

    // Opening and closing quotes must be the same character.
    #[test]
    fn test_unquote_param_mismatched_quotes() {
        assert_eq!(unquote_param("\"hello'"), "\"hello'");
        assert_eq!(unquote_param("'hello\""), "'hello\"");
    }

    // Surrounding whitespace is trimmed before the quotes are checked.
    #[test]
    fn test_unquote_param_whitespace() {
        assert_eq!(unquote_param(" \"hello\" "), "hello");
        assert_eq!(unquote_param(" 'hello' "), "hello");
    }
}
2370
/// Named path parameters for a request, stored as `(name, value)` pairs.
///
/// The [`Path`] extractor reads this from the request's extensions. Lookups by name
/// return the first pair with a matching name.
#[derive(Debug, Clone, Default)]
pub struct PathParams(pub Vec<(String, String)>);
2387
2388impl PathParams {
2389 #[must_use]
2391 pub fn new() -> Self {
2392 Self(Vec::new())
2393 }
2394
2395 #[must_use]
2397 pub fn from_pairs(pairs: Vec<(String, String)>) -> Self {
2398 Self(pairs)
2399 }
2400
2401 #[must_use]
2403 pub fn get(&self, name: &str) -> Option<&str> {
2404 self.0
2405 .iter()
2406 .find(|(n, _)| n == name)
2407 .map(|(_, v)| v.as_str())
2408 }
2409
2410 #[must_use]
2412 pub fn as_slice(&self) -> &[(String, String)] {
2413 &self.0
2414 }
2415
2416 #[must_use]
2418 pub fn is_empty(&self) -> bool {
2419 self.0.is_empty()
2420 }
2421
2422 #[must_use]
2424 pub fn len(&self) -> usize {
2425 self.0.len()
2426 }
2427}
2428
/// Extractor that deserializes the request's path parameters into `T`.
///
/// The parameters come from the [`PathParams`] request extension. If it is missing,
/// extraction fails with `PathExtractError::MissingPathParams`.
#[derive(Debug, Clone, Copy, Default)]
pub struct Path<T>(pub T);
2480
2481impl<T> Path<T> {
2482 pub fn into_inner(self) -> T {
2484 self.0
2485 }
2486}
2487
2488impl<T> Deref for Path<T> {
2489 type Target = T;
2490
2491 fn deref(&self) -> &Self::Target {
2492 &self.0
2493 }
2494}
2495
2496impl<T> DerefMut for Path<T> {
2497 fn deref_mut(&mut self) -> &mut Self::Target {
2498 &mut self.0
2499 }
2500}
2501
/// Errors produced by the [`Path`] extractor and its deserializer.
#[derive(Debug)]
pub enum PathExtractError {
    /// The request has no [`PathParams`] extension.
    MissingPathParams,
    /// A named path parameter was expected but not found.
    MissingParam {
        /// Name of the missing parameter.
        name: String,
    },
    /// A parameter's value could not be parsed as the expected type.
    InvalidValue {
        /// Parameter name.
        name: String,
        /// The raw value that failed to parse.
        value: String,
        /// Human-readable name of the expected type (e.g. `"i32"`, `"boolean"`).
        expected: &'static str,
        /// Details from the parser, or a short explanation.
        message: String,
    },
    /// Any other deserialization failure, such as an unsupported target type
    /// (e.g. bytes).
    DeserializeError {
        /// Description of the failure.
        message: String,
    },
}
2530
2531impl fmt::Display for PathExtractError {
2532 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
2533 match self {
2534 Self::MissingPathParams => {
2535 write!(f, "Path parameters not available in request")
2536 }
2537 Self::MissingParam { name } => {
2538 write!(f, "Missing path parameter: {name}")
2539 }
2540 Self::InvalidValue {
2541 name,
2542 value,
2543 expected,
2544 message,
2545 } => {
2546 write!(
2547 f,
2548 "Invalid value for path parameter '{name}': expected {expected}, got '{value}': {message}"
2549 )
2550 }
2551 Self::DeserializeError { message } => {
2552 write!(f, "Path deserialization error: {message}")
2553 }
2554 }
2555 }
2556}
2557
/// `PathExtractError` is a standard error type; its message is the `Display` output.
impl std::error::Error for PathExtractError {}
2559
2560impl IntoResponse for PathExtractError {
2561 fn into_response(self) -> crate::response::Response {
2562 match self {
2563 Self::MissingPathParams => {
2564 HttpError::internal()
2566 .with_detail("Path parameters not available")
2567 .into_response()
2568 }
2569 Self::MissingParam { name } => ValidationErrors::single(
2570 ValidationError::missing(crate::error::loc::path(&name))
2571 .with_msg("Path parameter is required"),
2572 )
2573 .into_response(),
2574 Self::InvalidValue {
2575 name,
2576 value,
2577 expected,
2578 message,
2579 } => ValidationErrors::single(
2580 ValidationError::type_error(crate::error::loc::path(&name), &expected)
2581 .with_msg(format!("Expected {expected}: {message}"))
2582 .with_input(serde_json::Value::String(value)),
2583 )
2584 .into_response(),
2585 Self::DeserializeError { message } => ValidationErrors::single(
2586 ValidationError::new(
2587 crate::error::error_types::VALUE_ERROR,
2588 vec![crate::error::LocItem::field("path")],
2589 )
2590 .with_msg(message),
2591 )
2592 .into_response(),
2593 }
2594 }
2595}
2596
2597impl<T: DeserializeOwned> FromRequest for Path<T> {
2598 type Error = PathExtractError;
2599
2600 async fn from_request(_ctx: &RequestContext, req: &mut Request) -> Result<Self, Self::Error> {
2601 let params = req
2603 .get_extension::<PathParams>()
2604 .ok_or(PathExtractError::MissingPathParams)?
2605 .clone();
2606
2607 let value = T::deserialize(PathDeserializer::new(¶ms))?;
2609
2610 Ok(Path(value))
2611 }
2612}
2613
/// A serde [`Deserializer`] over borrowed [`PathParams`].
///
/// Maps and `deserialize_any` go through `PathMapAccess`, and sequences and tuples through
/// `PathSeqAccess`. Scalar types parse the string returned by `get_single_value`, which is
/// defined elsewhere; presumably it requires exactly one parameter (confirm).
struct PathDeserializer<'de> {
    /// The parameters being deserialized.
    params: &'de PathParams,
}
2627
2628impl<'de> PathDeserializer<'de> {
2629 fn new(params: &'de PathParams) -> Self {
2630 Self { params }
2631 }
2632}
2633
2634impl<'de> Deserializer<'de> for PathDeserializer<'de> {
2635 type Error = PathExtractError;
2636
2637 fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
2638 where
2639 V: Visitor<'de>,
2640 {
2641 self.deserialize_map(visitor)
2643 }
2644
2645 fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value, Self::Error>
2646 where
2647 V: Visitor<'de>,
2648 {
2649 let value = self.get_single_value()?;
2650 let b = value
2651 .parse::<bool>()
2652 .map_err(|_| PathExtractError::InvalidValue {
2653 name: self.get_first_name(),
2654 value: value.to_string(),
2655 expected: "boolean",
2656 message: "expected 'true' or 'false'".to_string(),
2657 })?;
2658 visitor.visit_bool(b)
2659 }
2660
2661 fn deserialize_i8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
2662 where
2663 V: Visitor<'de>,
2664 {
2665 let value = self.get_single_value()?;
2666 let n = value
2667 .parse::<i8>()
2668 .map_err(|e| PathExtractError::InvalidValue {
2669 name: self.get_first_name(),
2670 value: value.to_string(),
2671 expected: "i8",
2672 message: e.to_string(),
2673 })?;
2674 visitor.visit_i8(n)
2675 }
2676
2677 fn deserialize_i16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
2678 where
2679 V: Visitor<'de>,
2680 {
2681 let value = self.get_single_value()?;
2682 let n = value
2683 .parse::<i16>()
2684 .map_err(|e| PathExtractError::InvalidValue {
2685 name: self.get_first_name(),
2686 value: value.to_string(),
2687 expected: "i16",
2688 message: e.to_string(),
2689 })?;
2690 visitor.visit_i16(n)
2691 }
2692
2693 fn deserialize_i32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
2694 where
2695 V: Visitor<'de>,
2696 {
2697 let value = self.get_single_value()?;
2698 let n = value
2699 .parse::<i32>()
2700 .map_err(|e| PathExtractError::InvalidValue {
2701 name: self.get_first_name(),
2702 value: value.to_string(),
2703 expected: "i32",
2704 message: e.to_string(),
2705 })?;
2706 visitor.visit_i32(n)
2707 }
2708
2709 fn deserialize_i64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
2710 where
2711 V: Visitor<'de>,
2712 {
2713 let value = self.get_single_value()?;
2714 let n = value
2715 .parse::<i64>()
2716 .map_err(|e| PathExtractError::InvalidValue {
2717 name: self.get_first_name(),
2718 value: value.to_string(),
2719 expected: "i64",
2720 message: e.to_string(),
2721 })?;
2722 visitor.visit_i64(n)
2723 }
2724
2725 fn deserialize_i128<V>(self, visitor: V) -> Result<V::Value, Self::Error>
2726 where
2727 V: Visitor<'de>,
2728 {
2729 let value = self.get_single_value()?;
2730 let n = value
2731 .parse::<i128>()
2732 .map_err(|e| PathExtractError::InvalidValue {
2733 name: self.get_first_name(),
2734 value: value.to_string(),
2735 expected: "i128",
2736 message: e.to_string(),
2737 })?;
2738 visitor.visit_i128(n)
2739 }
2740
2741 fn deserialize_u8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
2742 where
2743 V: Visitor<'de>,
2744 {
2745 let value = self.get_single_value()?;
2746 let n = value
2747 .parse::<u8>()
2748 .map_err(|e| PathExtractError::InvalidValue {
2749 name: self.get_first_name(),
2750 value: value.to_string(),
2751 expected: "u8",
2752 message: e.to_string(),
2753 })?;
2754 visitor.visit_u8(n)
2755 }
2756
2757 fn deserialize_u16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
2758 where
2759 V: Visitor<'de>,
2760 {
2761 let value = self.get_single_value()?;
2762 let n = value
2763 .parse::<u16>()
2764 .map_err(|e| PathExtractError::InvalidValue {
2765 name: self.get_first_name(),
2766 value: value.to_string(),
2767 expected: "u16",
2768 message: e.to_string(),
2769 })?;
2770 visitor.visit_u16(n)
2771 }
2772
2773 fn deserialize_u32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
2774 where
2775 V: Visitor<'de>,
2776 {
2777 let value = self.get_single_value()?;
2778 let n = value
2779 .parse::<u32>()
2780 .map_err(|e| PathExtractError::InvalidValue {
2781 name: self.get_first_name(),
2782 value: value.to_string(),
2783 expected: "u32",
2784 message: e.to_string(),
2785 })?;
2786 visitor.visit_u32(n)
2787 }
2788
2789 fn deserialize_u64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
2790 where
2791 V: Visitor<'de>,
2792 {
2793 let value = self.get_single_value()?;
2794 let n = value
2795 .parse::<u64>()
2796 .map_err(|e| PathExtractError::InvalidValue {
2797 name: self.get_first_name(),
2798 value: value.to_string(),
2799 expected: "u64",
2800 message: e.to_string(),
2801 })?;
2802 visitor.visit_u64(n)
2803 }
2804
2805 fn deserialize_u128<V>(self, visitor: V) -> Result<V::Value, Self::Error>
2806 where
2807 V: Visitor<'de>,
2808 {
2809 let value = self.get_single_value()?;
2810 let n = value
2811 .parse::<u128>()
2812 .map_err(|e| PathExtractError::InvalidValue {
2813 name: self.get_first_name(),
2814 value: value.to_string(),
2815 expected: "u128",
2816 message: e.to_string(),
2817 })?;
2818 visitor.visit_u128(n)
2819 }
2820
2821 fn deserialize_f32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
2822 where
2823 V: Visitor<'de>,
2824 {
2825 let value = self.get_single_value()?;
2826 let n = value
2827 .parse::<f32>()
2828 .map_err(|e| PathExtractError::InvalidValue {
2829 name: self.get_first_name(),
2830 value: value.to_string(),
2831 expected: "f32",
2832 message: e.to_string(),
2833 })?;
2834 visitor.visit_f32(n)
2835 }
2836
2837 fn deserialize_f64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
2838 where
2839 V: Visitor<'de>,
2840 {
2841 let value = self.get_single_value()?;
2842 let n = value
2843 .parse::<f64>()
2844 .map_err(|e| PathExtractError::InvalidValue {
2845 name: self.get_first_name(),
2846 value: value.to_string(),
2847 expected: "f64",
2848 message: e.to_string(),
2849 })?;
2850 visitor.visit_f64(n)
2851 }
2852
2853 fn deserialize_char<V>(self, visitor: V) -> Result<V::Value, Self::Error>
2854 where
2855 V: Visitor<'de>,
2856 {
2857 let value = self.get_single_value()?;
2858 let mut chars = value.chars();
2859 let c = chars.next().ok_or_else(|| PathExtractError::InvalidValue {
2860 name: self.get_first_name(),
2861 value: value.to_string(),
2862 expected: "char",
2863 message: "empty string".to_string(),
2864 })?;
2865 if chars.next().is_some() {
2866 return Err(PathExtractError::InvalidValue {
2867 name: self.get_first_name(),
2868 value: value.to_string(),
2869 expected: "char",
2870 message: "expected single character".to_string(),
2871 });
2872 }
2873 visitor.visit_char(c)
2874 }
2875
2876 fn deserialize_str<V>(self, visitor: V) -> Result<V::Value, Self::Error>
2877 where
2878 V: Visitor<'de>,
2879 {
2880 let value = self.get_single_value()?;
2881 visitor.visit_str(value)
2882 }
2883
2884 fn deserialize_string<V>(self, visitor: V) -> Result<V::Value, Self::Error>
2885 where
2886 V: Visitor<'de>,
2887 {
2888 let value = self.get_single_value()?;
2889 visitor.visit_string(value.to_string())
2890 }
2891
2892 fn deserialize_bytes<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
2893 where
2894 V: Visitor<'de>,
2895 {
2896 Err(PathExtractError::DeserializeError {
2897 message: "bytes deserialization not supported for path parameters".to_string(),
2898 })
2899 }
2900
2901 fn deserialize_byte_buf<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
2902 where
2903 V: Visitor<'de>,
2904 {
2905 Err(PathExtractError::DeserializeError {
2906 message: "byte_buf deserialization not supported for path parameters".to_string(),
2907 })
2908 }
2909
2910 fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>
2911 where
2912 V: Visitor<'de>,
2913 {
2914 visitor.visit_some(self)
2916 }
2917
2918 fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value, Self::Error>
2919 where
2920 V: Visitor<'de>,
2921 {
2922 visitor.visit_unit()
2923 }
2924
2925 fn deserialize_unit_struct<V>(
2926 self,
2927 _name: &'static str,
2928 visitor: V,
2929 ) -> Result<V::Value, Self::Error>
2930 where
2931 V: Visitor<'de>,
2932 {
2933 visitor.visit_unit()
2934 }
2935
2936 fn deserialize_newtype_struct<V>(
2937 self,
2938 _name: &'static str,
2939 visitor: V,
2940 ) -> Result<V::Value, Self::Error>
2941 where
2942 V: Visitor<'de>,
2943 {
2944 visitor.visit_newtype_struct(self)
2945 }
2946
2947 fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value, Self::Error>
2948 where
2949 V: Visitor<'de>,
2950 {
2951 visitor.visit_seq(PathSeqAccess::new(self.params))
2952 }
2953
2954 fn deserialize_tuple<V>(self, _len: usize, visitor: V) -> Result<V::Value, Self::Error>
2955 where
2956 V: Visitor<'de>,
2957 {
2958 visitor.visit_seq(PathSeqAccess::new(self.params))
2959 }
2960
2961 fn deserialize_tuple_struct<V>(
2962 self,
2963 _name: &'static str,
2964 _len: usize,
2965 visitor: V,
2966 ) -> Result<V::Value, Self::Error>
2967 where
2968 V: Visitor<'de>,
2969 {
2970 visitor.visit_seq(PathSeqAccess::new(self.params))
2971 }
2972
2973 fn deserialize_map<V>(self, visitor: V) -> Result<V::Value, Self::Error>
2974 where
2975 V: Visitor<'de>,
2976 {
2977 visitor.visit_map(PathMapAccess::new(self.params))
2978 }
2979
2980 fn deserialize_struct<V>(
2981 self,
2982 _name: &'static str,
2983 _fields: &'static [&'static str],
2984 visitor: V,
2985 ) -> Result<V::Value, Self::Error>
2986 where
2987 V: Visitor<'de>,
2988 {
2989 visitor.visit_map(PathMapAccess::new(self.params))
2990 }
2991
2992 fn deserialize_enum<V>(
2993 self,
2994 _name: &'static str,
2995 _variants: &'static [&'static str],
2996 visitor: V,
2997 ) -> Result<V::Value, Self::Error>
2998 where
2999 V: Visitor<'de>,
3000 {
3001 let value = self.get_single_value()?;
3002 visitor.visit_enum(value.into_deserializer())
3003 }
3004
3005 fn deserialize_identifier<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3006 where
3007 V: Visitor<'de>,
3008 {
3009 self.deserialize_str(visitor)
3010 }
3011
3012 fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3013 where
3014 V: Visitor<'de>,
3015 {
3016 visitor.visit_unit()
3017 }
3018}
3019
3020impl PathDeserializer<'_> {
3021 fn get_single_value(&self) -> Result<&str, PathExtractError> {
3022 self.params
3023 .0
3024 .first()
3025 .map(|(_, v)| v.as_str())
3026 .ok_or_else(|| PathExtractError::DeserializeError {
3027 message: "no path parameters available".to_string(),
3028 })
3029 }
3030
3031 fn get_first_name(&self) -> String {
3032 self.params
3033 .0
3034 .first()
3035 .map_or_else(|| "unknown".to_string(), |(n, _)| n.clone())
3036 }
3037}
3038
3039impl de::Error for PathExtractError {
3040 fn custom<T: fmt::Display>(msg: T) -> Self {
3041 PathExtractError::DeserializeError {
3042 message: msg.to_string(),
3043 }
3044 }
3045}
3046
3047struct PathSeqAccess<'de> {
3049 params: &'de PathParams,
3050 index: usize,
3051}
3052
3053impl<'de> PathSeqAccess<'de> {
3054 fn new(params: &'de PathParams) -> Self {
3055 Self { params, index: 0 }
3056 }
3057}
3058
3059impl<'de> SeqAccess<'de> for PathSeqAccess<'de> {
3060 type Error = PathExtractError;
3061
3062 fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, Self::Error>
3063 where
3064 T: de::DeserializeSeed<'de>,
3065 {
3066 if self.index >= self.params.0.len() {
3067 return Ok(None);
3068 }
3069
3070 let (name, value) = &self.params.0[self.index];
3071 self.index += 1;
3072
3073 seed.deserialize(PathValueDeserializer::new(name, value))
3074 .map(Some)
3075 }
3076}
3077
3078struct PathMapAccess<'de> {
3080 params: &'de PathParams,
3081 index: usize,
3082}
3083
3084impl<'de> PathMapAccess<'de> {
3085 fn new(params: &'de PathParams) -> Self {
3086 Self { params, index: 0 }
3087 }
3088}
3089
3090impl<'de> MapAccess<'de> for PathMapAccess<'de> {
3091 type Error = PathExtractError;
3092
3093 fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>, Self::Error>
3094 where
3095 K: de::DeserializeSeed<'de>,
3096 {
3097 if self.index >= self.params.0.len() {
3098 return Ok(None);
3099 }
3100
3101 let (name, _) = &self.params.0[self.index];
3102 seed.deserialize(name.as_str().into_deserializer())
3103 .map(Some)
3104 }
3105
3106 fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value, Self::Error>
3107 where
3108 V: de::DeserializeSeed<'de>,
3109 {
3110 let (name, value) = &self.params.0[self.index];
3111 self.index += 1;
3112
3113 seed.deserialize(PathValueDeserializer::new(name, value))
3114 }
3115}
3116
/// Deserializer for exactly one named path parameter value.
struct PathValueDeserializer<'de> {
    /// Parameter name, reported in errors.
    name: &'de str,
    /// Raw captured text of the parameter.
    value: &'de str,
}

impl<'de> PathValueDeserializer<'de> {
    /// Pairs a parameter name with its raw value.
    fn new(name: &'de str, value: &'de str) -> Self {
        PathValueDeserializer { value, name }
    }
}
3128
3129impl<'de> Deserializer<'de> for PathValueDeserializer<'de> {
3130 type Error = PathExtractError;
3131
3132 fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3133 where
3134 V: Visitor<'de>,
3135 {
3136 visitor.visit_str(self.value)
3138 }
3139
3140 fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3141 where
3142 V: Visitor<'de>,
3143 {
3144 let b = self
3145 .value
3146 .parse::<bool>()
3147 .map_err(|_| PathExtractError::InvalidValue {
3148 name: self.name.to_string(),
3149 value: self.value.to_string(),
3150 expected: "boolean",
3151 message: "expected 'true' or 'false'".to_string(),
3152 })?;
3153 visitor.visit_bool(b)
3154 }
3155
3156 fn deserialize_i8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3157 where
3158 V: Visitor<'de>,
3159 {
3160 let n = self
3161 .value
3162 .parse::<i8>()
3163 .map_err(|e| PathExtractError::InvalidValue {
3164 name: self.name.to_string(),
3165 value: self.value.to_string(),
3166 expected: "i8",
3167 message: e.to_string(),
3168 })?;
3169 visitor.visit_i8(n)
3170 }
3171
3172 fn deserialize_i16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3173 where
3174 V: Visitor<'de>,
3175 {
3176 let n = self
3177 .value
3178 .parse::<i16>()
3179 .map_err(|e| PathExtractError::InvalidValue {
3180 name: self.name.to_string(),
3181 value: self.value.to_string(),
3182 expected: "i16",
3183 message: e.to_string(),
3184 })?;
3185 visitor.visit_i16(n)
3186 }
3187
3188 fn deserialize_i32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3189 where
3190 V: Visitor<'de>,
3191 {
3192 let n = self
3193 .value
3194 .parse::<i32>()
3195 .map_err(|e| PathExtractError::InvalidValue {
3196 name: self.name.to_string(),
3197 value: self.value.to_string(),
3198 expected: "i32",
3199 message: e.to_string(),
3200 })?;
3201 visitor.visit_i32(n)
3202 }
3203
3204 fn deserialize_i64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3205 where
3206 V: Visitor<'de>,
3207 {
3208 let n = self
3209 .value
3210 .parse::<i64>()
3211 .map_err(|e| PathExtractError::InvalidValue {
3212 name: self.name.to_string(),
3213 value: self.value.to_string(),
3214 expected: "i64",
3215 message: e.to_string(),
3216 })?;
3217 visitor.visit_i64(n)
3218 }
3219
3220 fn deserialize_i128<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3221 where
3222 V: Visitor<'de>,
3223 {
3224 let n = self
3225 .value
3226 .parse::<i128>()
3227 .map_err(|e| PathExtractError::InvalidValue {
3228 name: self.name.to_string(),
3229 value: self.value.to_string(),
3230 expected: "i128",
3231 message: e.to_string(),
3232 })?;
3233 visitor.visit_i128(n)
3234 }
3235
3236 fn deserialize_u8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3237 where
3238 V: Visitor<'de>,
3239 {
3240 let n = self
3241 .value
3242 .parse::<u8>()
3243 .map_err(|e| PathExtractError::InvalidValue {
3244 name: self.name.to_string(),
3245 value: self.value.to_string(),
3246 expected: "u8",
3247 message: e.to_string(),
3248 })?;
3249 visitor.visit_u8(n)
3250 }
3251
3252 fn deserialize_u16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3253 where
3254 V: Visitor<'de>,
3255 {
3256 let n = self
3257 .value
3258 .parse::<u16>()
3259 .map_err(|e| PathExtractError::InvalidValue {
3260 name: self.name.to_string(),
3261 value: self.value.to_string(),
3262 expected: "u16",
3263 message: e.to_string(),
3264 })?;
3265 visitor.visit_u16(n)
3266 }
3267
3268 fn deserialize_u32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3269 where
3270 V: Visitor<'de>,
3271 {
3272 let n = self
3273 .value
3274 .parse::<u32>()
3275 .map_err(|e| PathExtractError::InvalidValue {
3276 name: self.name.to_string(),
3277 value: self.value.to_string(),
3278 expected: "u32",
3279 message: e.to_string(),
3280 })?;
3281 visitor.visit_u32(n)
3282 }
3283
3284 fn deserialize_u64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3285 where
3286 V: Visitor<'de>,
3287 {
3288 let n = self
3289 .value
3290 .parse::<u64>()
3291 .map_err(|e| PathExtractError::InvalidValue {
3292 name: self.name.to_string(),
3293 value: self.value.to_string(),
3294 expected: "u64",
3295 message: e.to_string(),
3296 })?;
3297 visitor.visit_u64(n)
3298 }
3299
3300 fn deserialize_u128<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3301 where
3302 V: Visitor<'de>,
3303 {
3304 let n = self
3305 .value
3306 .parse::<u128>()
3307 .map_err(|e| PathExtractError::InvalidValue {
3308 name: self.name.to_string(),
3309 value: self.value.to_string(),
3310 expected: "u128",
3311 message: e.to_string(),
3312 })?;
3313 visitor.visit_u128(n)
3314 }
3315
3316 fn deserialize_f32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3317 where
3318 V: Visitor<'de>,
3319 {
3320 let n = self
3321 .value
3322 .parse::<f32>()
3323 .map_err(|e| PathExtractError::InvalidValue {
3324 name: self.name.to_string(),
3325 value: self.value.to_string(),
3326 expected: "f32",
3327 message: e.to_string(),
3328 })?;
3329 visitor.visit_f32(n)
3330 }
3331
3332 fn deserialize_f64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3333 where
3334 V: Visitor<'de>,
3335 {
3336 let n = self
3337 .value
3338 .parse::<f64>()
3339 .map_err(|e| PathExtractError::InvalidValue {
3340 name: self.name.to_string(),
3341 value: self.value.to_string(),
3342 expected: "f64",
3343 message: e.to_string(),
3344 })?;
3345 visitor.visit_f64(n)
3346 }
3347
3348 fn deserialize_char<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3349 where
3350 V: Visitor<'de>,
3351 {
3352 let mut chars = self.value.chars();
3353 let c = chars.next().ok_or_else(|| PathExtractError::InvalidValue {
3354 name: self.name.to_string(),
3355 value: self.value.to_string(),
3356 expected: "char",
3357 message: "empty string".to_string(),
3358 })?;
3359 if chars.next().is_some() {
3360 return Err(PathExtractError::InvalidValue {
3361 name: self.name.to_string(),
3362 value: self.value.to_string(),
3363 expected: "char",
3364 message: "expected single character".to_string(),
3365 });
3366 }
3367 visitor.visit_char(c)
3368 }
3369
3370 fn deserialize_str<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3371 where
3372 V: Visitor<'de>,
3373 {
3374 visitor.visit_str(self.value)
3375 }
3376
3377 fn deserialize_string<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3378 where
3379 V: Visitor<'de>,
3380 {
3381 visitor.visit_string(self.value.to_string())
3382 }
3383
3384 fn deserialize_bytes<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
3385 where
3386 V: Visitor<'de>,
3387 {
3388 Err(PathExtractError::DeserializeError {
3389 message: "bytes deserialization not supported for path parameters".to_string(),
3390 })
3391 }
3392
3393 fn deserialize_byte_buf<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
3394 where
3395 V: Visitor<'de>,
3396 {
3397 Err(PathExtractError::DeserializeError {
3398 message: "byte_buf deserialization not supported for path parameters".to_string(),
3399 })
3400 }
3401
3402 fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3403 where
3404 V: Visitor<'de>,
3405 {
3406 visitor.visit_some(self)
3407 }
3408
3409 fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3410 where
3411 V: Visitor<'de>,
3412 {
3413 visitor.visit_unit()
3414 }
3415
3416 fn deserialize_unit_struct<V>(
3417 self,
3418 _name: &'static str,
3419 visitor: V,
3420 ) -> Result<V::Value, Self::Error>
3421 where
3422 V: Visitor<'de>,
3423 {
3424 visitor.visit_unit()
3425 }
3426
3427 fn deserialize_newtype_struct<V>(
3428 self,
3429 _name: &'static str,
3430 visitor: V,
3431 ) -> Result<V::Value, Self::Error>
3432 where
3433 V: Visitor<'de>,
3434 {
3435 visitor.visit_newtype_struct(self)
3436 }
3437
3438 fn deserialize_seq<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
3439 where
3440 V: Visitor<'de>,
3441 {
3442 Err(PathExtractError::DeserializeError {
3443 message: "sequence deserialization not supported for single path parameter".to_string(),
3444 })
3445 }
3446
3447 fn deserialize_tuple<V>(self, _len: usize, _visitor: V) -> Result<V::Value, Self::Error>
3448 where
3449 V: Visitor<'de>,
3450 {
3451 Err(PathExtractError::DeserializeError {
3452 message: "tuple deserialization not supported for single path parameter".to_string(),
3453 })
3454 }
3455
3456 fn deserialize_tuple_struct<V>(
3457 self,
3458 _name: &'static str,
3459 _len: usize,
3460 _visitor: V,
3461 ) -> Result<V::Value, Self::Error>
3462 where
3463 V: Visitor<'de>,
3464 {
3465 Err(PathExtractError::DeserializeError {
3466 message: "tuple struct deserialization not supported for single path parameter"
3467 .to_string(),
3468 })
3469 }
3470
3471 fn deserialize_map<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
3472 where
3473 V: Visitor<'de>,
3474 {
3475 Err(PathExtractError::DeserializeError {
3476 message: "map deserialization not supported for single path parameter".to_string(),
3477 })
3478 }
3479
3480 fn deserialize_struct<V>(
3481 self,
3482 _name: &'static str,
3483 _fields: &'static [&'static str],
3484 _visitor: V,
3485 ) -> Result<V::Value, Self::Error>
3486 where
3487 V: Visitor<'de>,
3488 {
3489 Err(PathExtractError::DeserializeError {
3490 message: "struct deserialization not supported for single path parameter".to_string(),
3491 })
3492 }
3493
3494 fn deserialize_enum<V>(
3495 self,
3496 _name: &'static str,
3497 _variants: &'static [&'static str],
3498 visitor: V,
3499 ) -> Result<V::Value, Self::Error>
3500 where
3501 V: Visitor<'de>,
3502 {
3503 visitor.visit_enum(self.value.into_deserializer())
3504 }
3505
3506 fn deserialize_identifier<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3507 where
3508 V: Visitor<'de>,
3509 {
3510 visitor.visit_str(self.value)
3511 }
3512
3513 fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3514 where
3515 V: Visitor<'de>,
3516 {
3517 visitor.visit_unit()
3518 }
3519}
3520
/// Extractor that deserializes the request's query string into `T`.
///
/// The wrapper dereferences to `T`, so fields can be read directly.
#[derive(Debug, Clone, Copy, Default)]
pub struct Query<T>(pub T);

impl<T> Query<T> {
    /// Wraps an already-built value.
    pub fn new(value: T) -> Self {
        Query(value)
    }

    /// Unwraps and returns the inner value.
    pub fn into_inner(self) -> T {
        let Query(inner) = self;
        inner
    }
}

impl<T> Deref for Query<T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        let Query(inner) = self;
        inner
    }
}

impl<T> DerefMut for Query<T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        let Query(inner) = self;
        inner
    }
}
3595
/// Error produced while extracting query parameters for [`Query`].
///
/// Rendered to clients through its `IntoResponse` impl as a validation error.
#[derive(Debug)]
pub enum QueryExtractError {
    /// A required value was absent. The query deserializer uses the
    /// placeholder name `"value"` when no parameters exist at all.
    MissingParam { name: String },
    /// A value was present but could not be converted to the target type.
    InvalidValue {
        /// Parameter name (or `"value"` for whole-query scalar targets).
        name: String,
        /// The raw, already percent-decoded text that failed to parse.
        value: String,
        /// Short label of the expected type, e.g. `"u32"` or `"bool"`.
        expected: &'static str,
        /// Human-readable reason, typically the parse error's `Display` text.
        message: String,
    },
    /// Any other failure, including messages raised by serde through
    /// `de::Error::custom`.
    DeserializeError { message: String },
}
3611
3612impl fmt::Display for QueryExtractError {
3613 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
3614 match self {
3615 Self::MissingParam { name } => {
3616 write!(f, "Missing required query parameter: {}", name)
3617 }
3618 Self::InvalidValue {
3619 name,
3620 value,
3621 expected,
3622 message,
3623 } => {
3624 write!(
3625 f,
3626 "Invalid value '{}' for query parameter '{}' (expected {}): {}",
3627 value, name, expected, message
3628 )
3629 }
3630 Self::DeserializeError { message } => {
3631 write!(f, "Query deserialization error: {}", message)
3632 }
3633 }
3634 }
3635}
3636
3637impl std::error::Error for QueryExtractError {}
3638
3639impl de::Error for QueryExtractError {
3640 fn custom<T: fmt::Display>(msg: T) -> Self {
3641 Self::DeserializeError {
3642 message: msg.to_string(),
3643 }
3644 }
3645}
3646
/// Converts a query extraction failure into a validation-error response
/// containing exactly one `ValidationError`.
impl IntoResponse for QueryExtractError {
    fn into_response(self) -> crate::response::Response {
        match self {
            // Missing parameter: "missing" error located at the query parameter.
            Self::MissingParam { name } => ValidationErrors::single(
                ValidationError::missing(crate::error::loc::query(&name))
                    .with_msg("Query parameter is required"),
            )
            .into_response(),
            // Type mismatch: report the expected type and echo the offending
            // raw value back as the error's input.
            Self::InvalidValue {
                name,
                value,
                expected,
                message,
            } => ValidationErrors::single(
                ValidationError::type_error(crate::error::loc::query(&name), &expected)
                    .with_msg(format!("Expected {expected}: {message}"))
                    .with_input(serde_json::Value::String(value)),
            )
            .into_response(),
            // Anything else has no specific parameter, so it is located at the
            // generic "query" field with a VALUE_ERROR type.
            Self::DeserializeError { message } => ValidationErrors::single(
                ValidationError::new(
                    crate::error::error_types::VALUE_ERROR,
                    vec![crate::error::LocItem::field("query")],
                )
                .with_msg(message),
            )
            .into_response(),
        }
    }
}
3677
3678#[derive(Debug, Clone, Default)]
3683pub struct QueryParams {
3684 params: Vec<(String, String)>,
3686}
3687
3688impl QueryParams {
3689 pub fn new() -> Self {
3691 Self { params: Vec::new() }
3692 }
3693
3694 pub fn from_pairs(pairs: Vec<(String, String)>) -> Self {
3696 Self { params: pairs }
3697 }
3698
3699 pub fn parse(query: &str) -> Self {
3701 let pairs: Vec<(String, String)> = query
3702 .split('&')
3703 .filter(|s| !s.is_empty())
3704 .map(|pair| {
3705 if let Some(eq_pos) = pair.find('=') {
3706 let key = &pair[..eq_pos];
3707 let value = &pair[eq_pos + 1..];
3708 (
3709 percent_decode(key).into_owned(),
3710 percent_decode(value).into_owned(),
3711 )
3712 } else {
3713 (percent_decode(pair).into_owned(), String::new())
3715 }
3716 })
3717 .collect();
3718 Self { params: pairs }
3719 }
3720
3721 pub fn get(&self, key: &str) -> Option<&str> {
3723 self.params
3724 .iter()
3725 .find(|(k, _)| k == key)
3726 .map(|(_, v)| v.as_str())
3727 }
3728
3729 pub fn get_all(&self, key: &str) -> Vec<&str> {
3731 self.params
3732 .iter()
3733 .filter(|(k, _)| k == key)
3734 .map(|(_, v)| v.as_str())
3735 .collect()
3736 }
3737
3738 pub fn contains(&self, key: &str) -> bool {
3740 self.params.iter().any(|(k, _)| k == key)
3741 }
3742
3743 pub fn pairs(&self) -> &[(String, String)] {
3745 &self.params
3746 }
3747
3748 pub fn keys(&self) -> impl Iterator<Item = &str> {
3750 let mut seen = std::collections::HashSet::new();
3751 self.params.iter().filter_map(move |(k, _)| {
3752 if seen.insert(k.as_str()) {
3753 Some(k.as_str())
3754 } else {
3755 None
3756 }
3757 })
3758 }
3759
3760 pub fn len(&self) -> usize {
3762 self.params.len()
3763 }
3764
3765 pub fn is_empty(&self) -> bool {
3767 self.params.is_empty()
3768 }
3769}
3770
/// Decodes `application/x-www-form-urlencoded` text: `%XX` escapes become the
/// corresponding byte and `+` becomes a space.
///
/// A `%` not followed by two hex digits is kept literally. Input with neither
/// `%` nor `+` is returned borrowed, without allocation. Decoded bytes that are
/// not valid UTF-8 are replaced with U+FFFD.
fn percent_decode(s: &str) -> std::borrow::Cow<'_, str> {
    use std::borrow::Cow;

    if !s.contains(['%', '+']) {
        return Cow::Borrowed(s);
    }

    let bytes = s.as_bytes();
    let mut decoded = Vec::with_capacity(bytes.len());
    let mut i = 0;

    while i < bytes.len() {
        match bytes[i] {
            // `i + 2 < len` guarantees both hex positions exist.
            b'%' if i + 2 < bytes.len() => {
                match (hex_digit(bytes[i + 1]), hex_digit(bytes[i + 2])) {
                    (Some(hi), Some(lo)) => {
                        decoded.push((hi << 4) | lo);
                        i += 3;
                    }
                    _ => {
                        // Not a valid escape: keep the '%' and continue after it.
                        decoded.push(b'%');
                        i += 1;
                    }
                }
            }
            b'+' => {
                decoded.push(b' ');
                i += 1;
            }
            other => {
                decoded.push(other);
                i += 1;
            }
        }
    }

    // Reuse the buffer when it is valid UTF-8; only copy when replacement
    // characters have to be inserted.
    let text = String::from_utf8(decoded)
        .unwrap_or_else(|err| String::from_utf8_lossy(err.as_bytes()).into_owned());
    Cow::Owned(text)
}

/// Value of an ASCII hex digit (`0-9`, `a-f`, `A-F`), or `None` for any other byte.
fn hex_digit(b: u8) -> Option<u8> {
    match b {
        b'0'..=b'9' => Some(b - b'0'),
        b'a'..=b'f' => Some(b - b'a' + 10),
        b'A'..=b'F' => Some(b - b'A' + 10),
        _ => None,
    }
}
3824
3825impl<T: DeserializeOwned> FromRequest for Query<T> {
3826 type Error = QueryExtractError;
3827
3828 async fn from_request(_ctx: &RequestContext, req: &mut Request) -> Result<Self, Self::Error> {
3829 let params = match req.get_extension::<QueryParams>() {
3831 Some(p) => p.clone(),
3832 None => {
3833 let query_str = req.query().unwrap_or("");
3835 QueryParams::parse(query_str)
3836 }
3837 };
3838
3839 let value = T::deserialize(QueryDeserializer::new(¶ms))?;
3841
3842 Ok(Query(value))
3843 }
3844}
3845
3846struct QueryDeserializer<'de> {
3858 params: &'de QueryParams,
3859}
3860
3861impl<'de> QueryDeserializer<'de> {
3862 fn new(params: &'de QueryParams) -> Self {
3863 Self { params }
3864 }
3865}
3866
3867impl<'de> Deserializer<'de> for QueryDeserializer<'de> {
3868 type Error = QueryExtractError;
3869
3870 fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3871 where
3872 V: Visitor<'de>,
3873 {
3874 self.deserialize_map(visitor)
3876 }
3877
3878 fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3879 where
3880 V: Visitor<'de>,
3881 {
3882 let value = self
3883 .params
3884 .pairs()
3885 .first()
3886 .map(|(_, v)| v.as_str())
3887 .ok_or_else(|| QueryExtractError::MissingParam {
3888 name: "value".to_string(),
3889 })?;
3890
3891 let b = parse_bool(value).map_err(|msg| QueryExtractError::InvalidValue {
3892 name: "value".to_string(),
3893 value: value.to_string(),
3894 expected: "bool",
3895 message: msg,
3896 })?;
3897 visitor.visit_bool(b)
3898 }
3899
3900 fn deserialize_i8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3901 where
3902 V: Visitor<'de>,
3903 {
3904 let value = self.get_single_value()?;
3905 let n = value
3906 .parse::<i8>()
3907 .map_err(|e| QueryExtractError::InvalidValue {
3908 name: "value".to_string(),
3909 value: value.to_string(),
3910 expected: "i8",
3911 message: e.to_string(),
3912 })?;
3913 visitor.visit_i8(n)
3914 }
3915
3916 fn deserialize_i16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3917 where
3918 V: Visitor<'de>,
3919 {
3920 let value = self.get_single_value()?;
3921 let n = value
3922 .parse::<i16>()
3923 .map_err(|e| QueryExtractError::InvalidValue {
3924 name: "value".to_string(),
3925 value: value.to_string(),
3926 expected: "i16",
3927 message: e.to_string(),
3928 })?;
3929 visitor.visit_i16(n)
3930 }
3931
3932 fn deserialize_i32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3933 where
3934 V: Visitor<'de>,
3935 {
3936 let value = self.get_single_value()?;
3937 let n = value
3938 .parse::<i32>()
3939 .map_err(|e| QueryExtractError::InvalidValue {
3940 name: "value".to_string(),
3941 value: value.to_string(),
3942 expected: "i32",
3943 message: e.to_string(),
3944 })?;
3945 visitor.visit_i32(n)
3946 }
3947
3948 fn deserialize_i64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3949 where
3950 V: Visitor<'de>,
3951 {
3952 let value = self.get_single_value()?;
3953 let n = value
3954 .parse::<i64>()
3955 .map_err(|e| QueryExtractError::InvalidValue {
3956 name: "value".to_string(),
3957 value: value.to_string(),
3958 expected: "i64",
3959 message: e.to_string(),
3960 })?;
3961 visitor.visit_i64(n)
3962 }
3963
3964 fn deserialize_u8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3965 where
3966 V: Visitor<'de>,
3967 {
3968 let value = self.get_single_value()?;
3969 let n = value
3970 .parse::<u8>()
3971 .map_err(|e| QueryExtractError::InvalidValue {
3972 name: "value".to_string(),
3973 value: value.to_string(),
3974 expected: "u8",
3975 message: e.to_string(),
3976 })?;
3977 visitor.visit_u8(n)
3978 }
3979
3980 fn deserialize_u16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3981 where
3982 V: Visitor<'de>,
3983 {
3984 let value = self.get_single_value()?;
3985 let n = value
3986 .parse::<u16>()
3987 .map_err(|e| QueryExtractError::InvalidValue {
3988 name: "value".to_string(),
3989 value: value.to_string(),
3990 expected: "u16",
3991 message: e.to_string(),
3992 })?;
3993 visitor.visit_u16(n)
3994 }
3995
3996 fn deserialize_u32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
3997 where
3998 V: Visitor<'de>,
3999 {
4000 let value = self.get_single_value()?;
4001 let n = value
4002 .parse::<u32>()
4003 .map_err(|e| QueryExtractError::InvalidValue {
4004 name: "value".to_string(),
4005 value: value.to_string(),
4006 expected: "u32",
4007 message: e.to_string(),
4008 })?;
4009 visitor.visit_u32(n)
4010 }
4011
4012 fn deserialize_u64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4013 where
4014 V: Visitor<'de>,
4015 {
4016 let value = self.get_single_value()?;
4017 let n = value
4018 .parse::<u64>()
4019 .map_err(|e| QueryExtractError::InvalidValue {
4020 name: "value".to_string(),
4021 value: value.to_string(),
4022 expected: "u64",
4023 message: e.to_string(),
4024 })?;
4025 visitor.visit_u64(n)
4026 }
4027
4028 fn deserialize_f32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4029 where
4030 V: Visitor<'de>,
4031 {
4032 let value = self.get_single_value()?;
4033 let n = value
4034 .parse::<f32>()
4035 .map_err(|e| QueryExtractError::InvalidValue {
4036 name: "value".to_string(),
4037 value: value.to_string(),
4038 expected: "f32",
4039 message: e.to_string(),
4040 })?;
4041 visitor.visit_f32(n)
4042 }
4043
4044 fn deserialize_f64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4045 where
4046 V: Visitor<'de>,
4047 {
4048 let value = self.get_single_value()?;
4049 let n = value
4050 .parse::<f64>()
4051 .map_err(|e| QueryExtractError::InvalidValue {
4052 name: "value".to_string(),
4053 value: value.to_string(),
4054 expected: "f64",
4055 message: e.to_string(),
4056 })?;
4057 visitor.visit_f64(n)
4058 }
4059
4060 fn deserialize_char<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4061 where
4062 V: Visitor<'de>,
4063 {
4064 let value = self.get_single_value()?;
4065 let mut chars = value.chars();
4066 match (chars.next(), chars.next()) {
4067 (Some(c), None) => visitor.visit_char(c),
4068 _ => Err(QueryExtractError::InvalidValue {
4069 name: "value".to_string(),
4070 value: value.to_string(),
4071 expected: "char",
4072 message: "expected single character".to_string(),
4073 }),
4074 }
4075 }
4076
4077 fn deserialize_str<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4078 where
4079 V: Visitor<'de>,
4080 {
4081 let value = self.get_single_value()?;
4082 visitor.visit_str(value)
4083 }
4084
4085 fn deserialize_string<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4086 where
4087 V: Visitor<'de>,
4088 {
4089 let value = self.get_single_value()?;
4090 visitor.visit_string(value.to_owned())
4091 }
4092
4093 fn deserialize_bytes<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4094 where
4095 V: Visitor<'de>,
4096 {
4097 let value = self.get_single_value()?;
4098 visitor.visit_bytes(value.as_bytes())
4099 }
4100
4101 fn deserialize_byte_buf<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4102 where
4103 V: Visitor<'de>,
4104 {
4105 let value = self.get_single_value()?;
4106 visitor.visit_byte_buf(value.as_bytes().to_vec())
4107 }
4108
4109 fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4110 where
4111 V: Visitor<'de>,
4112 {
4113 if self.params.is_empty() {
4115 visitor.visit_none()
4116 } else {
4117 visitor.visit_some(self)
4118 }
4119 }
4120
4121 fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4122 where
4123 V: Visitor<'de>,
4124 {
4125 visitor.visit_unit()
4126 }
4127
4128 fn deserialize_unit_struct<V>(
4129 self,
4130 _name: &'static str,
4131 visitor: V,
4132 ) -> Result<V::Value, Self::Error>
4133 where
4134 V: Visitor<'de>,
4135 {
4136 visitor.visit_unit()
4137 }
4138
4139 fn deserialize_newtype_struct<V>(
4140 self,
4141 _name: &'static str,
4142 visitor: V,
4143 ) -> Result<V::Value, Self::Error>
4144 where
4145 V: Visitor<'de>,
4146 {
4147 visitor.visit_newtype_struct(self)
4148 }
4149
4150 fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4151 where
4152 V: Visitor<'de>,
4153 {
4154 let values: Vec<&str> = self
4156 .params
4157 .pairs()
4158 .iter()
4159 .map(|(_, v)| v.as_str())
4160 .collect();
4161 visitor.visit_seq(QuerySeqAccess::new(values))
4162 }
4163
4164 fn deserialize_tuple<V>(self, _len: usize, visitor: V) -> Result<V::Value, Self::Error>
4165 where
4166 V: Visitor<'de>,
4167 {
4168 let values: Vec<&str> = self
4170 .params
4171 .pairs()
4172 .iter()
4173 .map(|(_, v)| v.as_str())
4174 .collect();
4175 visitor.visit_seq(QuerySeqAccess::new(values))
4176 }
4177
4178 fn deserialize_tuple_struct<V>(
4179 self,
4180 _name: &'static str,
4181 _len: usize,
4182 visitor: V,
4183 ) -> Result<V::Value, Self::Error>
4184 where
4185 V: Visitor<'de>,
4186 {
4187 let values: Vec<&str> = self
4188 .params
4189 .pairs()
4190 .iter()
4191 .map(|(_, v)| v.as_str())
4192 .collect();
4193 visitor.visit_seq(QuerySeqAccess::new(values))
4194 }
4195
4196 fn deserialize_map<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4197 where
4198 V: Visitor<'de>,
4199 {
4200 visitor.visit_map(QueryMapAccess::new(self.params))
4201 }
4202
4203 fn deserialize_struct<V>(
4204 self,
4205 _name: &'static str,
4206 _fields: &'static [&'static str],
4207 visitor: V,
4208 ) -> Result<V::Value, Self::Error>
4209 where
4210 V: Visitor<'de>,
4211 {
4212 self.deserialize_map(visitor)
4213 }
4214
4215 fn deserialize_enum<V>(
4216 self,
4217 _name: &'static str,
4218 _variants: &'static [&'static str],
4219 visitor: V,
4220 ) -> Result<V::Value, Self::Error>
4221 where
4222 V: Visitor<'de>,
4223 {
4224 let value = self.get_single_value()?;
4226 visitor.visit_enum(value.into_deserializer())
4227 }
4228
4229 fn deserialize_identifier<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4230 where
4231 V: Visitor<'de>,
4232 {
4233 let value = self.get_single_value()?;
4234 visitor.visit_str(value)
4235 }
4236
4237 fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4238 where
4239 V: Visitor<'de>,
4240 {
4241 visitor.visit_unit()
4242 }
4243}
4244
4245impl<'de> QueryDeserializer<'de> {
4246 fn get_single_value(&self) -> Result<&'de str, QueryExtractError> {
4247 self.params
4248 .pairs()
4249 .first()
4250 .map(|(_, v)| v.as_str())
4251 .ok_or_else(|| QueryExtractError::MissingParam {
4252 name: "value".to_string(),
4253 })
4254 }
4255}
4256
/// Lenient boolean parsing for query values (case-insensitive).
///
/// `true`/`1`/`yes`/`on` map to `true`, and `false`/`0`/`no`/`off` or an empty
/// string map to `false`. Anything else is an error message that quotes the
/// original input.
fn parse_bool(s: &str) -> Result<bool, String> {
    let lowered = s.to_lowercase();
    if matches!(lowered.as_str(), "true" | "1" | "yes" | "on") {
        Ok(true)
    } else if matches!(lowered.as_str(), "false" | "0" | "no" | "off" | "") {
        Ok(false)
    } else {
        Err(format!("cannot parse '{s}' as boolean"))
    }
}
4265
/// Sequence accessor over a pre-collected list of query values.
struct QuerySeqAccess<'de> {
    /// Values to yield, in query order.
    values: Vec<&'de str>,
    /// Position of the next value to yield.
    index: usize,
}

impl<'de> QuerySeqAccess<'de> {
    /// Starts iteration at the first value.
    fn new(values: Vec<&'de str>) -> Self {
        QuerySeqAccess { index: 0, values }
    }
}
4277
4278impl<'de> SeqAccess<'de> for QuerySeqAccess<'de> {
4279 type Error = QueryExtractError;
4280
4281 fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, Self::Error>
4282 where
4283 T: de::DeserializeSeed<'de>,
4284 {
4285 if self.index >= self.values.len() {
4286 return Ok(None);
4287 }
4288
4289 let value = self.values[self.index];
4290 self.index += 1;
4291
4292 seed.deserialize(QueryValueDeserializer::new(value, None))
4293 .map(Some)
4294 }
4295
4296 fn size_hint(&self) -> Option<usize> {
4297 Some(self.values.len() - self.index)
4298 }
4299}
4300
4301struct QueryMapAccess<'de> {
4303 params: &'de QueryParams,
4304 keys: Vec<&'de str>,
4305 index: usize,
4306}
4307
4308impl<'de> QueryMapAccess<'de> {
4309 fn new(params: &'de QueryParams) -> Self {
4310 let keys: Vec<&str> = params.keys().collect();
4311 Self {
4312 params,
4313 keys,
4314 index: 0,
4315 }
4316 }
4317}
4318
4319impl<'de> MapAccess<'de> for QueryMapAccess<'de> {
4320 type Error = QueryExtractError;
4321
4322 fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>, Self::Error>
4323 where
4324 K: de::DeserializeSeed<'de>,
4325 {
4326 if self.index >= self.keys.len() {
4327 return Ok(None);
4328 }
4329
4330 let key = self.keys[self.index];
4331 seed.deserialize(key.into_deserializer()).map(Some)
4332 }
4333
4334 fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value, Self::Error>
4335 where
4336 V: de::DeserializeSeed<'de>,
4337 {
4338 let key = self.keys[self.index];
4339 self.index += 1;
4340
4341 let values = self.params.get_all(key);
4343
4344 seed.deserialize(QueryFieldDeserializer::new(key, values))
4345 }
4346}
4347
4348struct QueryValueDeserializer<'de> {
4350 value: &'de str,
4351 name: Option<&'de str>,
4352}
4353
4354impl<'de> QueryValueDeserializer<'de> {
4355 fn new(value: &'de str, name: Option<&'de str>) -> Self {
4356 Self { value, name }
4357 }
4358
4359 fn field_name(&self) -> String {
4360 self.name.unwrap_or("value").to_string()
4361 }
4362}
4363
4364impl<'de> Deserializer<'de> for QueryValueDeserializer<'de> {
4365 type Error = QueryExtractError;
4366
4367 fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4368 where
4369 V: Visitor<'de>,
4370 {
4371 visitor.visit_str(self.value)
4372 }
4373
4374 fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4375 where
4376 V: Visitor<'de>,
4377 {
4378 let b = parse_bool(self.value).map_err(|msg| QueryExtractError::InvalidValue {
4379 name: self.field_name(),
4380 value: self.value.to_string(),
4381 expected: "bool",
4382 message: msg,
4383 })?;
4384 visitor.visit_bool(b)
4385 }
4386
4387 fn deserialize_i8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4388 where
4389 V: Visitor<'de>,
4390 {
4391 let n = self
4392 .value
4393 .parse::<i8>()
4394 .map_err(|e| QueryExtractError::InvalidValue {
4395 name: self.field_name(),
4396 value: self.value.to_string(),
4397 expected: "i8",
4398 message: e.to_string(),
4399 })?;
4400 visitor.visit_i8(n)
4401 }
4402
4403 fn deserialize_i16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4404 where
4405 V: Visitor<'de>,
4406 {
4407 let n = self
4408 .value
4409 .parse::<i16>()
4410 .map_err(|e| QueryExtractError::InvalidValue {
4411 name: self.field_name(),
4412 value: self.value.to_string(),
4413 expected: "i16",
4414 message: e.to_string(),
4415 })?;
4416 visitor.visit_i16(n)
4417 }
4418
4419 fn deserialize_i32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4420 where
4421 V: Visitor<'de>,
4422 {
4423 let n = self
4424 .value
4425 .parse::<i32>()
4426 .map_err(|e| QueryExtractError::InvalidValue {
4427 name: self.field_name(),
4428 value: self.value.to_string(),
4429 expected: "i32",
4430 message: e.to_string(),
4431 })?;
4432 visitor.visit_i32(n)
4433 }
4434
4435 fn deserialize_i64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4436 where
4437 V: Visitor<'de>,
4438 {
4439 let n = self
4440 .value
4441 .parse::<i64>()
4442 .map_err(|e| QueryExtractError::InvalidValue {
4443 name: self.field_name(),
4444 value: self.value.to_string(),
4445 expected: "i64",
4446 message: e.to_string(),
4447 })?;
4448 visitor.visit_i64(n)
4449 }
4450
4451 fn deserialize_u8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4452 where
4453 V: Visitor<'de>,
4454 {
4455 let n = self
4456 .value
4457 .parse::<u8>()
4458 .map_err(|e| QueryExtractError::InvalidValue {
4459 name: self.field_name(),
4460 value: self.value.to_string(),
4461 expected: "u8",
4462 message: e.to_string(),
4463 })?;
4464 visitor.visit_u8(n)
4465 }
4466
4467 fn deserialize_u16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4468 where
4469 V: Visitor<'de>,
4470 {
4471 let n = self
4472 .value
4473 .parse::<u16>()
4474 .map_err(|e| QueryExtractError::InvalidValue {
4475 name: self.field_name(),
4476 value: self.value.to_string(),
4477 expected: "u16",
4478 message: e.to_string(),
4479 })?;
4480 visitor.visit_u16(n)
4481 }
4482
4483 fn deserialize_u32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4484 where
4485 V: Visitor<'de>,
4486 {
4487 let n = self
4488 .value
4489 .parse::<u32>()
4490 .map_err(|e| QueryExtractError::InvalidValue {
4491 name: self.field_name(),
4492 value: self.value.to_string(),
4493 expected: "u32",
4494 message: e.to_string(),
4495 })?;
4496 visitor.visit_u32(n)
4497 }
4498
4499 fn deserialize_u64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4500 where
4501 V: Visitor<'de>,
4502 {
4503 let n = self
4504 .value
4505 .parse::<u64>()
4506 .map_err(|e| QueryExtractError::InvalidValue {
4507 name: self.field_name(),
4508 value: self.value.to_string(),
4509 expected: "u64",
4510 message: e.to_string(),
4511 })?;
4512 visitor.visit_u64(n)
4513 }
4514
4515 fn deserialize_f32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4516 where
4517 V: Visitor<'de>,
4518 {
4519 let n = self
4520 .value
4521 .parse::<f32>()
4522 .map_err(|e| QueryExtractError::InvalidValue {
4523 name: self.field_name(),
4524 value: self.value.to_string(),
4525 expected: "f32",
4526 message: e.to_string(),
4527 })?;
4528 visitor.visit_f32(n)
4529 }
4530
4531 fn deserialize_f64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4532 where
4533 V: Visitor<'de>,
4534 {
4535 let n = self
4536 .value
4537 .parse::<f64>()
4538 .map_err(|e| QueryExtractError::InvalidValue {
4539 name: self.field_name(),
4540 value: self.value.to_string(),
4541 expected: "f64",
4542 message: e.to_string(),
4543 })?;
4544 visitor.visit_f64(n)
4545 }
4546
4547 fn deserialize_char<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4548 where
4549 V: Visitor<'de>,
4550 {
4551 let mut chars = self.value.chars();
4552 match (chars.next(), chars.next()) {
4553 (Some(c), None) => visitor.visit_char(c),
4554 _ => Err(QueryExtractError::InvalidValue {
4555 name: self.field_name(),
4556 value: self.value.to_string(),
4557 expected: "char",
4558 message: "expected single character".to_string(),
4559 }),
4560 }
4561 }
4562
4563 fn deserialize_str<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4564 where
4565 V: Visitor<'de>,
4566 {
4567 visitor.visit_str(self.value)
4568 }
4569
4570 fn deserialize_string<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4571 where
4572 V: Visitor<'de>,
4573 {
4574 visitor.visit_string(self.value.to_owned())
4575 }
4576
4577 fn deserialize_bytes<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4578 where
4579 V: Visitor<'de>,
4580 {
4581 visitor.visit_bytes(self.value.as_bytes())
4582 }
4583
4584 fn deserialize_byte_buf<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4585 where
4586 V: Visitor<'de>,
4587 {
4588 visitor.visit_byte_buf(self.value.as_bytes().to_vec())
4589 }
4590
4591 fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4592 where
4593 V: Visitor<'de>,
4594 {
4595 if self.value.is_empty() {
4596 visitor.visit_none()
4597 } else {
4598 visitor.visit_some(self)
4599 }
4600 }
4601
4602 fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4603 where
4604 V: Visitor<'de>,
4605 {
4606 visitor.visit_unit()
4607 }
4608
4609 fn deserialize_unit_struct<V>(
4610 self,
4611 _name: &'static str,
4612 visitor: V,
4613 ) -> Result<V::Value, Self::Error>
4614 where
4615 V: Visitor<'de>,
4616 {
4617 visitor.visit_unit()
4618 }
4619
4620 fn deserialize_newtype_struct<V>(
4621 self,
4622 _name: &'static str,
4623 visitor: V,
4624 ) -> Result<V::Value, Self::Error>
4625 where
4626 V: Visitor<'de>,
4627 {
4628 visitor.visit_newtype_struct(self)
4629 }
4630
4631 fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4632 where
4633 V: Visitor<'de>,
4634 {
4635 visitor.visit_seq(QuerySeqAccess::new(vec![self.value]))
4637 }
4638
4639 fn deserialize_tuple<V>(self, _len: usize, visitor: V) -> Result<V::Value, Self::Error>
4640 where
4641 V: Visitor<'de>,
4642 {
4643 visitor.visit_seq(QuerySeqAccess::new(vec![self.value]))
4644 }
4645
4646 fn deserialize_tuple_struct<V>(
4647 self,
4648 _name: &'static str,
4649 _len: usize,
4650 visitor: V,
4651 ) -> Result<V::Value, Self::Error>
4652 where
4653 V: Visitor<'de>,
4654 {
4655 visitor.visit_seq(QuerySeqAccess::new(vec![self.value]))
4656 }
4657
4658 fn deserialize_map<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
4659 where
4660 V: Visitor<'de>,
4661 {
4662 Err(QueryExtractError::DeserializeError {
4664 message: "cannot deserialize single value as map".to_string(),
4665 })
4666 }
4667
4668 fn deserialize_struct<V>(
4669 self,
4670 _name: &'static str,
4671 _fields: &'static [&'static str],
4672 visitor: V,
4673 ) -> Result<V::Value, Self::Error>
4674 where
4675 V: Visitor<'de>,
4676 {
4677 self.deserialize_map(visitor)
4678 }
4679
4680 fn deserialize_enum<V>(
4681 self,
4682 _name: &'static str,
4683 _variants: &'static [&'static str],
4684 visitor: V,
4685 ) -> Result<V::Value, Self::Error>
4686 where
4687 V: Visitor<'de>,
4688 {
4689 visitor.visit_enum(self.value.into_deserializer())
4690 }
4691
4692 fn deserialize_identifier<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4693 where
4694 V: Visitor<'de>,
4695 {
4696 visitor.visit_str(self.value)
4697 }
4698
4699 fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4700 where
4701 V: Visitor<'de>,
4702 {
4703 visitor.visit_unit()
4704 }
4705}
4706
4707struct QueryFieldDeserializer<'de> {
4711 name: &'de str,
4712 values: Vec<&'de str>,
4713}
4714
4715impl<'de> QueryFieldDeserializer<'de> {
4716 fn new(name: &'de str, values: Vec<&'de str>) -> Self {
4717 Self { name, values }
4718 }
4719}
4720
4721impl<'de> Deserializer<'de> for QueryFieldDeserializer<'de> {
4722 type Error = QueryExtractError;
4723
4724 fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4725 where
4726 V: Visitor<'de>,
4727 {
4728 if let Some(value) = self.values.first() {
4730 visitor.visit_str(value)
4731 } else {
4732 visitor.visit_none()
4733 }
4734 }
4735
4736 fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4737 where
4738 V: Visitor<'de>,
4739 {
4740 let value = self
4741 .values
4742 .first()
4743 .ok_or_else(|| QueryExtractError::MissingParam {
4744 name: self.name.to_string(),
4745 })?;
4746 let b = parse_bool(value).map_err(|msg| QueryExtractError::InvalidValue {
4747 name: self.name.to_string(),
4748 value: (*value).to_string(),
4749 expected: "bool",
4750 message: msg,
4751 })?;
4752 visitor.visit_bool(b)
4753 }
4754
4755 fn deserialize_i8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4756 where
4757 V: Visitor<'de>,
4758 {
4759 let value = self
4760 .values
4761 .first()
4762 .ok_or_else(|| QueryExtractError::MissingParam {
4763 name: self.name.to_string(),
4764 })?;
4765 let n = value
4766 .parse::<i8>()
4767 .map_err(|e| QueryExtractError::InvalidValue {
4768 name: self.name.to_string(),
4769 value: (*value).to_string(),
4770 expected: "i8",
4771 message: e.to_string(),
4772 })?;
4773 visitor.visit_i8(n)
4774 }
4775
4776 fn deserialize_i16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4777 where
4778 V: Visitor<'de>,
4779 {
4780 let value = self
4781 .values
4782 .first()
4783 .ok_or_else(|| QueryExtractError::MissingParam {
4784 name: self.name.to_string(),
4785 })?;
4786 let n = value
4787 .parse::<i16>()
4788 .map_err(|e| QueryExtractError::InvalidValue {
4789 name: self.name.to_string(),
4790 value: (*value).to_string(),
4791 expected: "i16",
4792 message: e.to_string(),
4793 })?;
4794 visitor.visit_i16(n)
4795 }
4796
4797 fn deserialize_i32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4798 where
4799 V: Visitor<'de>,
4800 {
4801 let value = self
4802 .values
4803 .first()
4804 .ok_or_else(|| QueryExtractError::MissingParam {
4805 name: self.name.to_string(),
4806 })?;
4807 let n = value
4808 .parse::<i32>()
4809 .map_err(|e| QueryExtractError::InvalidValue {
4810 name: self.name.to_string(),
4811 value: (*value).to_string(),
4812 expected: "i32",
4813 message: e.to_string(),
4814 })?;
4815 visitor.visit_i32(n)
4816 }
4817
4818 fn deserialize_i64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4819 where
4820 V: Visitor<'de>,
4821 {
4822 let value = self
4823 .values
4824 .first()
4825 .ok_or_else(|| QueryExtractError::MissingParam {
4826 name: self.name.to_string(),
4827 })?;
4828 let n = value
4829 .parse::<i64>()
4830 .map_err(|e| QueryExtractError::InvalidValue {
4831 name: self.name.to_string(),
4832 value: (*value).to_string(),
4833 expected: "i64",
4834 message: e.to_string(),
4835 })?;
4836 visitor.visit_i64(n)
4837 }
4838
4839 fn deserialize_u8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4840 where
4841 V: Visitor<'de>,
4842 {
4843 let value = self
4844 .values
4845 .first()
4846 .ok_or_else(|| QueryExtractError::MissingParam {
4847 name: self.name.to_string(),
4848 })?;
4849 let n = value
4850 .parse::<u8>()
4851 .map_err(|e| QueryExtractError::InvalidValue {
4852 name: self.name.to_string(),
4853 value: (*value).to_string(),
4854 expected: "u8",
4855 message: e.to_string(),
4856 })?;
4857 visitor.visit_u8(n)
4858 }
4859
4860 fn deserialize_u16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4861 where
4862 V: Visitor<'de>,
4863 {
4864 let value = self
4865 .values
4866 .first()
4867 .ok_or_else(|| QueryExtractError::MissingParam {
4868 name: self.name.to_string(),
4869 })?;
4870 let n = value
4871 .parse::<u16>()
4872 .map_err(|e| QueryExtractError::InvalidValue {
4873 name: self.name.to_string(),
4874 value: (*value).to_string(),
4875 expected: "u16",
4876 message: e.to_string(),
4877 })?;
4878 visitor.visit_u16(n)
4879 }
4880
4881 fn deserialize_u32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4882 where
4883 V: Visitor<'de>,
4884 {
4885 let value = self
4886 .values
4887 .first()
4888 .ok_or_else(|| QueryExtractError::MissingParam {
4889 name: self.name.to_string(),
4890 })?;
4891 let n = value
4892 .parse::<u32>()
4893 .map_err(|e| QueryExtractError::InvalidValue {
4894 name: self.name.to_string(),
4895 value: (*value).to_string(),
4896 expected: "u32",
4897 message: e.to_string(),
4898 })?;
4899 visitor.visit_u32(n)
4900 }
4901
4902 fn deserialize_u64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4903 where
4904 V: Visitor<'de>,
4905 {
4906 let value = self
4907 .values
4908 .first()
4909 .ok_or_else(|| QueryExtractError::MissingParam {
4910 name: self.name.to_string(),
4911 })?;
4912 let n = value
4913 .parse::<u64>()
4914 .map_err(|e| QueryExtractError::InvalidValue {
4915 name: self.name.to_string(),
4916 value: (*value).to_string(),
4917 expected: "u64",
4918 message: e.to_string(),
4919 })?;
4920 visitor.visit_u64(n)
4921 }
4922
4923 fn deserialize_f32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4924 where
4925 V: Visitor<'de>,
4926 {
4927 let value = self
4928 .values
4929 .first()
4930 .ok_or_else(|| QueryExtractError::MissingParam {
4931 name: self.name.to_string(),
4932 })?;
4933 let n = value
4934 .parse::<f32>()
4935 .map_err(|e| QueryExtractError::InvalidValue {
4936 name: self.name.to_string(),
4937 value: (*value).to_string(),
4938 expected: "f32",
4939 message: e.to_string(),
4940 })?;
4941 visitor.visit_f32(n)
4942 }
4943
4944 fn deserialize_f64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4945 where
4946 V: Visitor<'de>,
4947 {
4948 let value = self
4949 .values
4950 .first()
4951 .ok_or_else(|| QueryExtractError::MissingParam {
4952 name: self.name.to_string(),
4953 })?;
4954 let n = value
4955 .parse::<f64>()
4956 .map_err(|e| QueryExtractError::InvalidValue {
4957 name: self.name.to_string(),
4958 value: (*value).to_string(),
4959 expected: "f64",
4960 message: e.to_string(),
4961 })?;
4962 visitor.visit_f64(n)
4963 }
4964
4965 fn deserialize_char<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4966 where
4967 V: Visitor<'de>,
4968 {
4969 let value = self
4970 .values
4971 .first()
4972 .ok_or_else(|| QueryExtractError::MissingParam {
4973 name: self.name.to_string(),
4974 })?;
4975 let mut chars = value.chars();
4976 match (chars.next(), chars.next()) {
4977 (Some(c), None) => visitor.visit_char(c),
4978 _ => Err(QueryExtractError::InvalidValue {
4979 name: self.name.to_string(),
4980 value: (*value).to_string(),
4981 expected: "char",
4982 message: "expected single character".to_string(),
4983 }),
4984 }
4985 }
4986
4987 fn deserialize_str<V>(self, visitor: V) -> Result<V::Value, Self::Error>
4988 where
4989 V: Visitor<'de>,
4990 {
4991 let value = self
4992 .values
4993 .first()
4994 .ok_or_else(|| QueryExtractError::MissingParam {
4995 name: self.name.to_string(),
4996 })?;
4997 visitor.visit_str(value)
4998 }
4999
5000 fn deserialize_string<V>(self, visitor: V) -> Result<V::Value, Self::Error>
5001 where
5002 V: Visitor<'de>,
5003 {
5004 let value = self
5005 .values
5006 .first()
5007 .ok_or_else(|| QueryExtractError::MissingParam {
5008 name: self.name.to_string(),
5009 })?;
5010 visitor.visit_string((*value).to_owned())
5011 }
5012
5013 fn deserialize_bytes<V>(self, visitor: V) -> Result<V::Value, Self::Error>
5014 where
5015 V: Visitor<'de>,
5016 {
5017 let value = self
5018 .values
5019 .first()
5020 .ok_or_else(|| QueryExtractError::MissingParam {
5021 name: self.name.to_string(),
5022 })?;
5023 visitor.visit_bytes(value.as_bytes())
5024 }
5025
5026 fn deserialize_byte_buf<V>(self, visitor: V) -> Result<V::Value, Self::Error>
5027 where
5028 V: Visitor<'de>,
5029 {
5030 let value = self
5031 .values
5032 .first()
5033 .ok_or_else(|| QueryExtractError::MissingParam {
5034 name: self.name.to_string(),
5035 })?;
5036 visitor.visit_byte_buf(value.as_bytes().to_vec())
5037 }
5038
5039 fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>
5040 where
5041 V: Visitor<'de>,
5042 {
5043 if self.values.is_empty() {
5044 visitor.visit_none()
5045 } else {
5046 visitor.visit_some(self)
5047 }
5048 }
5049
5050 fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value, Self::Error>
5051 where
5052 V: Visitor<'de>,
5053 {
5054 visitor.visit_unit()
5055 }
5056
5057 fn deserialize_unit_struct<V>(
5058 self,
5059 _name: &'static str,
5060 visitor: V,
5061 ) -> Result<V::Value, Self::Error>
5062 where
5063 V: Visitor<'de>,
5064 {
5065 visitor.visit_unit()
5066 }
5067
5068 fn deserialize_newtype_struct<V>(
5069 self,
5070 _name: &'static str,
5071 visitor: V,
5072 ) -> Result<V::Value, Self::Error>
5073 where
5074 V: Visitor<'de>,
5075 {
5076 visitor.visit_newtype_struct(self)
5077 }
5078
5079 fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value, Self::Error>
5080 where
5081 V: Visitor<'de>,
5082 {
5083 visitor.visit_seq(QuerySeqAccess::new(self.values))
5085 }
5086
5087 fn deserialize_tuple<V>(self, _len: usize, visitor: V) -> Result<V::Value, Self::Error>
5088 where
5089 V: Visitor<'de>,
5090 {
5091 visitor.visit_seq(QuerySeqAccess::new(self.values))
5092 }
5093
5094 fn deserialize_tuple_struct<V>(
5095 self,
5096 _name: &'static str,
5097 _len: usize,
5098 visitor: V,
5099 ) -> Result<V::Value, Self::Error>
5100 where
5101 V: Visitor<'de>,
5102 {
5103 visitor.visit_seq(QuerySeqAccess::new(self.values))
5104 }
5105
5106 fn deserialize_map<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
5107 where
5108 V: Visitor<'de>,
5109 {
5110 Err(QueryExtractError::DeserializeError {
5111 message: "cannot deserialize query field as map".to_string(),
5112 })
5113 }
5114
5115 fn deserialize_struct<V>(
5116 self,
5117 _name: &'static str,
5118 _fields: &'static [&'static str],
5119 visitor: V,
5120 ) -> Result<V::Value, Self::Error>
5121 where
5122 V: Visitor<'de>,
5123 {
5124 self.deserialize_map(visitor)
5125 }
5126
5127 fn deserialize_enum<V>(
5128 self,
5129 _name: &'static str,
5130 _variants: &'static [&'static str],
5131 visitor: V,
5132 ) -> Result<V::Value, Self::Error>
5133 where
5134 V: Visitor<'de>,
5135 {
5136 let value = self
5137 .values
5138 .first()
5139 .ok_or_else(|| QueryExtractError::MissingParam {
5140 name: self.name.to_string(),
5141 })?;
5142 visitor.visit_enum((*value).into_deserializer())
5143 }
5144
5145 fn deserialize_identifier<V>(self, visitor: V) -> Result<V::Value, Self::Error>
5146 where
5147 V: Visitor<'de>,
5148 {
5149 let value = self
5150 .values
5151 .first()
5152 .ok_or_else(|| QueryExtractError::MissingParam {
5153 name: self.name.to_string(),
5154 })?;
5155 visitor.visit_str(value)
5156 }
5157
5158 fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
5159 where
5160 V: Visitor<'de>,
5161 {
5162 visitor.visit_unit()
5163 }
5164}
5165
/// Type-keyed container for shared application state.
///
/// Values are stored under their `TypeId`, so at most one value per type is kept.
/// Cloning an `AppState` clones only the outer `Arc`, so clones share the same map.
#[derive(Debug, Default, Clone)]
pub struct AppState {
    inner: std::sync::Arc<
        std::collections::HashMap<
            std::any::TypeId,
            std::sync::Arc<dyn std::any::Any + Send + Sync>,
        >,
    >,
}

impl AppState {
    /// Creates a container with no stored values.
    #[must_use]
    pub fn new() -> Self {
        let empty = std::collections::HashMap::new();
        Self {
            inner: std::sync::Arc::new(empty),
        }
    }

    /// Returns this state with `value` stored under `T`, replacing any previous `T`.
    ///
    /// Sole ownership of the map means it is updated in place. Otherwise the map is
    /// cloned first, so other clones of this `AppState` are unaffected. Stored values are
    /// shared through their own `Arc`s and are never deep-copied.
    #[must_use]
    pub fn with<T: Send + Sync + 'static>(self, value: T) -> Self {
        let mut inner = self.inner;
        std::sync::Arc::make_mut(&mut inner)
            .insert(std::any::TypeId::of::<T>(), std::sync::Arc::new(value));
        Self { inner }
    }

    /// Borrows the stored value of type `T`, or returns `None` if none was added.
    #[must_use]
    pub fn get<T: Send + Sync + 'static>(&self) -> Option<&T> {
        let entry = self.inner.get(&std::any::TypeId::of::<T>())?;
        entry.downcast_ref::<T>()
    }

    /// Reports whether a value of type `T` has been stored.
    #[must_use]
    pub fn contains<T: Send + Sync + 'static>(&self) -> bool {
        let key = std::any::TypeId::of::<T>();
        self.inner.contains_key(&key)
    }

    /// Number of distinct types currently stored.
    #[must_use]
    pub fn len(&self) -> usize {
        self.inner.as_ref().len()
    }

    /// Reports whether no values are stored.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.inner.as_ref().is_empty()
    }
}
5257
/// Extractor wrapper holding a value taken from the request's [`AppState`].
///
/// `State` is a transparent newtype: the inner value is reachable through `.0`,
/// [`State::into_inner`], or (mutable) deref.
#[derive(Debug, Clone)]
pub struct State<T>(pub T);

impl<T> State<T> {
    /// Unwraps the extractor, returning the owned inner value.
    pub fn into_inner(self) -> T {
        let State(inner) = self;
        inner
    }
}

impl<T> Deref for State<T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        let State(inner) = self;
        inner
    }
}

impl<T> DerefMut for State<T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        let State(inner) = self;
        inner
    }
}
5309
/// Reasons the [`State`] extractor can fail.
#[derive(Debug)]
pub enum StateExtractError {
    /// The request carries no [`AppState`] extension.
    MissingAppState,
    /// An [`AppState`] is attached but holds no value of the requested type.
    MissingStateType {
        /// Name of the requested type; the extractor fills this from `std::any::type_name`.
        type_name: &'static str,
    },
}

impl std::fmt::Display for StateExtractError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            Self::MissingAppState => f.write_str("Application state not configured in request"),
            Self::MissingStateType { type_name } => {
                write!(f, "State type not found: {}", type_name)
            }
        }
    }
}

impl std::error::Error for StateExtractError {}
5340
5341impl IntoResponse for StateExtractError {
5342 fn into_response(self) -> crate::response::Response {
5343 HttpError::internal()
5345 .with_detail(self.to_string())
5346 .into_response()
5347 }
5348}
5349
5350impl<T> FromRequest for State<T>
5351where
5352 T: Clone + Send + Sync + 'static,
5353{
5354 type Error = StateExtractError;
5355
5356 async fn from_request(_ctx: &RequestContext, req: &mut Request) -> Result<Self, Self::Error> {
5357 let app_state = req
5359 .get_extension::<AppState>()
5360 .ok_or(StateExtractError::MissingAppState)?;
5361
5362 let value = app_state
5364 .get::<T>()
5365 .ok_or(StateExtractError::MissingStateType {
5366 type_name: std::any::type_name::<T>(),
5367 })?;
5368
5369 Ok(State(value.clone()))
5370 }
5371}
5372
5373#[cfg(test)]
5374mod state_tests {
5375 use super::*;
5376 use crate::request::Method;
5377
5378 fn test_context() -> RequestContext {
5379 let cx = asupersync::Cx::for_testing();
5380 RequestContext::new(cx, 12345)
5381 }
5382
5383 #[derive(Clone, Debug, PartialEq)]
5384 struct DatabasePool {
5385 connection_string: String,
5386 }
5387
5388 #[derive(Clone, Debug, PartialEq)]
5389 struct AppConfig {
5390 debug: bool,
5391 port: u16,
5392 }
5393
5394 #[test]
5395 fn app_state_new_is_empty() {
5396 let state = AppState::new();
5397 assert!(state.is_empty());
5398 assert_eq!(state.len(), 0);
5399 }
5400
5401 #[test]
5402 fn app_state_with_single_type() {
5403 let db = DatabasePool {
5404 connection_string: "postgres://localhost".into(),
5405 };
5406 let state = AppState::new().with(db.clone());
5407
5408 assert!(!state.is_empty());
5409 assert_eq!(state.len(), 1);
5410 assert!(state.contains::<DatabasePool>());
5411 assert_eq!(state.get::<DatabasePool>(), Some(&db));
5412 }
5413
5414 #[test]
5415 fn app_state_with_multiple_types() {
5416 let db = DatabasePool {
5417 connection_string: "postgres://localhost".into(),
5418 };
5419 let config = AppConfig {
5420 debug: true,
5421 port: 8080,
5422 };
5423
5424 let state = AppState::new().with(db.clone()).with(config.clone());
5425
5426 assert_eq!(state.len(), 2);
5427 assert_eq!(state.get::<DatabasePool>(), Some(&db));
5428 assert_eq!(state.get::<AppConfig>(), Some(&config));
5429 }
5430
5431 #[test]
5432 fn app_state_get_missing_type() {
5433 let state = AppState::new().with(42i32);
5434 assert!(state.get::<String>().is_none());
5435 assert!(!state.contains::<String>());
5436 }
5437
5438 #[test]
5439 fn state_deref() {
5440 let state = State(42i32);
5441 assert_eq!(*state, 42);
5442 }
5443
5444 #[test]
5445 fn state_into_inner() {
5446 let state = State("hello".to_string());
5447 assert_eq!(state.into_inner(), "hello");
5448 }
5449
5450 #[test]
5451 fn state_extract_success() {
5452 let ctx = test_context();
5453 let db = DatabasePool {
5454 connection_string: "postgres://localhost".into(),
5455 };
5456 let app_state = AppState::new().with(db.clone());
5457
5458 let mut req = Request::new(Method::Get, "/test");
5459 req.insert_extension(app_state);
5460
5461 let result =
5462 futures_executor::block_on(State::<DatabasePool>::from_request(&ctx, &mut req));
5463 let State(extracted) = result.unwrap();
5464 assert_eq!(extracted, db);
5465 }
5466
5467 #[test]
5468 fn state_extract_multiple_types() {
5469 let ctx = test_context();
5470 let db = DatabasePool {
5471 connection_string: "postgres://localhost".into(),
5472 };
5473 let config = AppConfig {
5474 debug: true,
5475 port: 8080,
5476 };
5477 let app_state = AppState::new().with(db.clone()).with(config.clone());
5478
5479 let mut req = Request::new(Method::Get, "/test");
5480 req.insert_extension(app_state);
5481
5482 let result =
5484 futures_executor::block_on(State::<DatabasePool>::from_request(&ctx, &mut req));
5485 let State(extracted_db) = result.unwrap();
5486 assert_eq!(extracted_db, db);
5487
5488 let result = futures_executor::block_on(State::<AppConfig>::from_request(&ctx, &mut req));
5490 let State(extracted_config) = result.unwrap();
5491 assert_eq!(extracted_config, config);
5492 }
5493
5494 #[test]
5495 fn state_extract_missing_app_state() {
5496 let ctx = test_context();
5497 let mut req = Request::new(Method::Get, "/test");
5498 let result =
5501 futures_executor::block_on(State::<DatabasePool>::from_request(&ctx, &mut req));
5502 assert!(matches!(result, Err(StateExtractError::MissingAppState)));
5503 }
5504
5505 #[test]
5506 fn state_extract_missing_type() {
5507 let ctx = test_context();
5508 let app_state = AppState::new().with(42i32);
5509
5510 let mut req = Request::new(Method::Get, "/test");
5511 req.insert_extension(app_state);
5512
5513 let result =
5514 futures_executor::block_on(State::<DatabasePool>::from_request(&ctx, &mut req));
5515 assert!(matches!(
5516 result,
5517 Err(StateExtractError::MissingStateType { .. })
5518 ));
5519 }
5520
5521 #[test]
5522 fn state_error_display() {
5523 let err = StateExtractError::MissingAppState;
5524 assert!(err.to_string().contains("not configured"));
5525
5526 let err = StateExtractError::MissingStateType {
5527 type_name: "DatabasePool",
5528 };
5529 assert!(err.to_string().contains("DatabasePool"));
5530 }
5531
5532 #[test]
5533 fn app_state_clone() {
5534 let db = DatabasePool {
5535 connection_string: "postgres://localhost".into(),
5536 };
5537 let state1 = AppState::new().with(db.clone());
5538 let state2 = state1.clone();
5539
5540 assert_eq!(state2.get::<DatabasePool>(), Some(&db));
5541 }
5542
5543 #[test]
5544 fn state_with_arc() {
5545 use std::sync::Arc;
5546
5547 let ctx = test_context();
5548 let db = Arc::new(DatabasePool {
5549 connection_string: "postgres://localhost".into(),
5550 });
5551 let app_state = AppState::new().with(db.clone());
5552
5553 let mut req = Request::new(Method::Get, "/test");
5554 req.insert_extension(app_state);
5555
5556 let result =
5557 futures_executor::block_on(State::<Arc<DatabasePool>>::from_request(&ctx, &mut req));
5558 let State(extracted) = result.unwrap();
5559 assert_eq!(extracted.connection_string, "postgres://localhost");
5560 }
5561
5562 #[test]
5567 fn atomic_counter_fetch_add_concurrent() {
5568 use std::sync::Arc;
5570 use std::sync::atomic::{AtomicUsize, Ordering};
5571 use std::thread;
5572
5573 const NUM_THREADS: usize = 100;
5574 const INCREMENTS_PER_THREAD: usize = 1000;
5575
5576 let counter = Arc::new(AtomicUsize::new(0));
5577 let app_state = AppState::new().with(counter.clone());
5578
5579 let handles: Vec<_> = (0..NUM_THREADS)
5580 .map(|_| {
5581 let state = app_state.clone();
5582 thread::spawn(move || {
5583 let counter = state.get::<Arc<AtomicUsize>>().expect("Counter not found");
5585 for _ in 0..INCREMENTS_PER_THREAD {
5586 counter.fetch_add(1, Ordering::SeqCst);
5587 }
5588 })
5589 })
5590 .collect();
5591
5592 for handle in handles {
5594 handle.join().expect("Thread panicked");
5595 }
5596
5597 let final_value = counter.load(Ordering::SeqCst);
5599 let expected = NUM_THREADS * INCREMENTS_PER_THREAD;
5600 assert_eq!(
5601 final_value, expected,
5602 "Lost increments: expected {expected}, got {final_value}"
5603 );
5604 }
5605
5606 #[test]
5607 fn atomic_compare_and_swap_concurrent() {
5608 use std::sync::Arc;
5610 use std::sync::atomic::{AtomicUsize, Ordering};
5611 use std::thread;
5612
5613 const NUM_THREADS: usize = 50;
5614 const CAS_ATTEMPTS_PER_THREAD: usize = 100;
5615
5616 let counter = Arc::new(AtomicUsize::new(0));
5618 let success_count = Arc::new(AtomicUsize::new(0));
5619
5620 let handles: Vec<_> = (0..NUM_THREADS)
5621 .map(|_| {
5622 let counter = counter.clone();
5623 let success_count = success_count.clone();
5624 thread::spawn(move || {
5625 for _ in 0..CAS_ATTEMPTS_PER_THREAD {
5626 let mut current = counter.load(Ordering::SeqCst);
5628 loop {
5629 match counter.compare_exchange(
5630 current,
5631 current + 1,
5632 Ordering::SeqCst,
5633 Ordering::SeqCst,
5634 ) {
5635 Ok(_) => {
5636 success_count.fetch_add(1, Ordering::SeqCst);
5637 break;
5638 }
5639 Err(actual) => {
5640 current = actual;
5642 }
5643 }
5644 }
5645 }
5646 })
5647 })
5648 .collect();
5649
5650 for handle in handles {
5651 handle.join().expect("Thread panicked");
5652 }
5653
5654 let final_counter = counter.load(Ordering::SeqCst);
5655 let total_successes = success_count.load(Ordering::SeqCst);
5656 let expected = NUM_THREADS * CAS_ATTEMPTS_PER_THREAD;
5657
5658 assert_eq!(total_successes, expected);
5660 assert_eq!(final_counter, expected);
5662 }
5663
#[test]
fn atomic_state_concurrent_reads() {
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::thread;

    const NUM_READERS: usize = 100;
    const READS_PER_THREAD: usize = 1000;
    const INITIAL_VALUE: usize = 42;

    let counter = Arc::new(AtomicUsize::new(INITIAL_VALUE));
    // The state stores a second handle to the very same atomic.
    let app_state = AppState::new().with(Arc::clone(&counter));

    let readers: Vec<_> = (0..NUM_READERS)
        .map(|_| {
            let state = app_state.clone();
            thread::spawn(move || {
                let shared = state.get::<Arc<AtomicUsize>>().expect("Counter not found");
                (0..READS_PER_THREAD).for_each(|_| {
                    assert_eq!(
                        shared.load(Ordering::SeqCst),
                        INITIAL_VALUE,
                        "Value corrupted during read"
                    );
                });
            })
        })
        .collect();

    for reader in readers {
        reader.join().expect("Thread panicked");
    }

    // Readers only load, so the value must be untouched afterwards.
    assert_eq!(counter.load(Ordering::SeqCst), INITIAL_VALUE);
}
5698
#[test]
fn atomic_rate_limiter_pattern() {
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::thread;

    const MAX_REQUESTS: usize = 100;
    const NUM_CLIENTS: usize = 50;
    const REQUESTS_PER_CLIENT: usize = 5;

    /// Lock-free limiter that admits at most `max_count` acquisitions in total.
    #[derive(Clone)]
    struct RateLimiter {
        current_count: Arc<AtomicUsize>,
        max_count: usize,
    }

    impl RateLimiter {
        fn new(max: usize) -> Self {
            Self {
                current_count: Arc::new(AtomicUsize::new(0)),
                max_count: max,
            }
        }

        /// Atomically increments the count unless it has already reached the cap.
        fn try_acquire(&self) -> bool {
            // `fetch_update` runs the CAS retry loop; returning `None` from the
            // closure aborts without touching the counter.
            self.current_count
                .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |current| {
                    (current < self.max_count).then_some(current + 1)
                })
                .is_ok()
        }

        fn count(&self) -> usize {
            self.current_count.load(Ordering::SeqCst)
        }
    }

    let limiter = RateLimiter::new(MAX_REQUESTS);
    let app_state = AppState::new().with(limiter.clone());
    let allowed_count = Arc::new(AtomicUsize::new(0));
    let denied_count = Arc::new(AtomicUsize::new(0));

    let handles: Vec<_> = (0..NUM_CLIENTS)
        .map(|_| {
            let state = app_state.clone();
            let allowed = allowed_count.clone();
            let denied = denied_count.clone();
            thread::spawn(move || {
                let limiter = state.get::<RateLimiter>().expect("Limiter not found");
                for _ in 0..REQUESTS_PER_CLIENT {
                    if limiter.try_acquire() {
                        allowed.fetch_add(1, Ordering::SeqCst);
                    } else {
                        denied.fetch_add(1, Ordering::SeqCst);
                    }
                }
            })
        })
        .collect();

    for handle in handles {
        handle.join().expect("Thread panicked");
    }

    let total_allowed = allowed_count.load(Ordering::SeqCst);
    let total_denied = denied_count.load(Ordering::SeqCst);
    let total_requests = NUM_CLIENTS * REQUESTS_PER_CLIENT;

    assert_eq!(total_allowed, MAX_REQUESTS, "Allowed count mismatch");
    assert_eq!(
        total_denied,
        total_requests - MAX_REQUESTS,
        "Denied count mismatch"
    );
    // Demand (NUM_CLIENTS * REQUESTS_PER_CLIENT) exceeds capacity, so the
    // limiter must finish exactly at its cap, not merely at or below it.
    assert_eq!(
        limiter.count(),
        MAX_REQUESTS,
        "Rate limiter count should equal max"
    );
}
5790
#[test]
fn atomic_concurrent_queue_pattern() {
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::thread;

    const NUM_PRODUCERS: usize = 10;
    const ITEMS_PER_PRODUCER: usize = 100;

    // Only the tail is exercised: each `fetch_add` reserves a unique slot index.
    let tail = Arc::new(AtomicUsize::new(0));
    let produced_ids = Arc::new(parking_lot::Mutex::new(Vec::new()));

    let handles: Vec<_> = (0..NUM_PRODUCERS)
        .map(|_| {
            let tail = tail.clone();
            let ids = produced_ids.clone();
            thread::spawn(move || {
                for _ in 0..ITEMS_PER_PRODUCER {
                    let slot = tail.fetch_add(1, Ordering::SeqCst);
                    ids.lock().push(slot);
                }
            })
        })
        .collect();

    for handle in handles {
        handle.join().expect("Thread panicked");
    }

    let expected_count = NUM_PRODUCERS * ITEMS_PER_PRODUCER;
    let mut sorted = produced_ids.lock().clone();
    assert_eq!(sorted.len(), expected_count);

    // Sorted slots must be exactly 0..expected_count: no gaps, no duplicates.
    sorted.sort_unstable();
    for (i, &id) in sorted.iter().enumerate() {
        assert_eq!(id, i, "Slot {i} missing or duplicated");
    }

    assert_eq!(tail.load(Ordering::SeqCst), expected_count);
}
5840
#[test]
fn concurrent_reads_basic_consistency() {
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::thread;

    const NUM_READERS: usize = 100;
    const READS_PER_THREAD: usize = 100;

    let config_value = "production";
    let port_value = 8080u16;
    let enabled_value = true;

    let app_state = AppState::new()
        .with(config_value.to_string())
        .with(port_value)
        .with(enabled_value);

    let error_count = Arc::new(AtomicUsize::new(0));

    let handles: Vec<_> = (0..NUM_READERS)
        .map(|_| {
            let state = app_state.clone();
            let errors = Arc::clone(&error_count);
            thread::spawn(move || {
                for _ in 0..READS_PER_THREAD {
                    // Three independent checks; each mismatch counts once.
                    let mismatches = [
                        state.get::<String>().map(String::as_str) != Some(config_value),
                        state.get::<u16>() != Some(&port_value),
                        state.get::<bool>() != Some(&enabled_value),
                    ];
                    for _ in mismatches.iter().filter(|&&bad| bad) {
                        errors.fetch_add(1, Ordering::SeqCst);
                    }
                }
            })
        })
        .collect();

    for handle in handles {
        handle.join().expect("Thread panicked");
    }

    assert_eq!(
        error_count.load(Ordering::SeqCst),
        0,
        "Some reads returned inconsistent values"
    );
}
5902
#[test]
fn concurrent_reads_varying_payload_sizes() {
    use std::sync::Arc;
    use std::thread;

    const NUM_READERS: usize = 50;
    const READS_PER_THREAD: usize = 50;

    let small: i32 = 42;
    let medium: String = "a".repeat(1000);
    let large: Vec<u8> = vec![0u8; 10_000];

    let app_state = AppState::new()
        .with(small)
        .with(medium.clone())
        .with(large.clone());

    let error_count = Arc::new(std::sync::atomic::AtomicUsize::new(0));

    let handles: Vec<_> = (0..NUM_READERS)
        .map(|_| {
            let state = app_state.clone();
            let expected_medium = medium.clone();
            let expected_large = large.clone();
            let errors = error_count.clone();
            thread::spawn(move || {
                for _ in 0..READS_PER_THREAD {
                    if state.get::<i32>() != Some(&small) {
                        errors.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
                    }
                    // Compare the full value so that a *missing* String is an
                    // error too (a length check on `Some` alone would skip it).
                    if state.get::<String>() != Some(&expected_medium) {
                        errors.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
                    }
                    if state.get::<Vec<u8>>() != Some(&expected_large) {
                        errors.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
                    }
                }
            })
        })
        .collect();

    for handle in handles {
        handle.join().expect("Thread panicked");
    }

    assert_eq!(
        error_count.load(std::sync::atomic::Ordering::SeqCst),
        0,
        "Payload size affected read consistency"
    );
}
5959
#[test]
fn concurrent_reads_nested_structures() {
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::thread;

    const NUM_READERS: usize = 50;

    #[derive(Clone, Debug, PartialEq)]
    struct OuterConfig {
        inner: InnerConfig,
        name: String,
    }

    #[derive(Clone, Debug, PartialEq)]
    struct InnerConfig {
        values: Vec<i32>,
        enabled: bool,
    }

    let nested = OuterConfig {
        inner: InnerConfig {
            values: vec![1, 2, 3, 4, 5],
            enabled: true,
        },
        name: "nested_test".to_string(),
    };

    let app_state = AppState::new().with(nested.clone());
    let error_count = Arc::new(AtomicUsize::new(0));

    let handles: Vec<_> = (0..NUM_READERS)
        .map(|_| {
            let state = app_state.clone();
            let expected = nested.clone();
            let errors = Arc::clone(&error_count);
            thread::spawn(move || {
                for _ in 0..100 {
                    let read = state.get::<OuterConfig>();
                    let whole_mismatch = read != Some(&expected);
                    // Only a present value can have a wrong inner length.
                    let inner_len_bad = read.is_some_and(|outer| outer.inner.values.len() != 5);
                    if whole_mismatch {
                        errors.fetch_add(1, Ordering::SeqCst);
                    }
                    if inner_len_bad {
                        errors.fetch_add(1, Ordering::SeqCst);
                    }
                }
            })
        })
        .collect();

    for handle in handles {
        handle.join().expect("Thread panicked");
    }

    assert_eq!(
        error_count.load(Ordering::SeqCst),
        0,
        "Nested structure reads inconsistent"
    );
}
6023
#[test]
fn concurrent_reads_with_arc_rwlock_pattern() {
    use parking_lot::RwLock;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::thread;

    const NUM_READERS: usize = 80;
    const NUM_WRITERS: usize = 20;
    const OPS_PER_THREAD: usize = 100;
    const TOTAL_WRITES: usize = NUM_WRITERS * OPS_PER_THREAD;

    /// Clonable handle whose clones share one lock-protected vector.
    #[derive(Clone)]
    struct MutableState {
        data: Arc<RwLock<Vec<i32>>>,
    }

    impl MutableState {
        fn new() -> Self {
            Self {
                data: Arc::new(RwLock::new(Vec::new())),
            }
        }

        fn push(&self, value: i32) {
            self.data.write().push(value);
        }

        fn len(&self) -> usize {
            self.data.read().len()
        }
    }

    let mutable_state = MutableState::new();
    let app_state = AppState::new().with(mutable_state.clone());
    let read_count = Arc::new(AtomicUsize::new(0));
    let write_count = Arc::new(AtomicUsize::new(0));

    let reader_handles: Vec<_> = (0..NUM_READERS)
        .map(|_| {
            let state = app_state.clone();
            let reads = read_count.clone();
            thread::spawn(move || {
                for _ in 0..OPS_PER_THREAD {
                    let ms = state.get::<MutableState>().expect("State not found");
                    // A reader can never observe more items than will ever be written.
                    assert!(
                        ms.len() <= TOTAL_WRITES,
                        "Reader observed more items than were written"
                    );
                    reads.fetch_add(1, Ordering::SeqCst);
                }
            })
        })
        .collect();

    let writer_handles: Vec<_> = (0..NUM_WRITERS)
        .map(|i| {
            let state = app_state.clone();
            let writes = write_count.clone();
            thread::spawn(move || {
                for j in 0..OPS_PER_THREAD {
                    let ms = state.get::<MutableState>().expect("State not found");
                    // Checked conversion instead of a silently wrapping `as` cast.
                    let value = i32::try_from(i * OPS_PER_THREAD + j).expect("value fits in i32");
                    ms.push(value);
                    writes.fetch_add(1, Ordering::SeqCst);
                }
            })
        })
        .collect();

    for handle in reader_handles {
        handle.join().expect("Reader thread panicked");
    }
    for handle in writer_handles {
        handle.join().expect("Writer thread panicked");
    }

    let total_reads = read_count.load(Ordering::SeqCst);
    let total_writes = write_count.load(Ordering::SeqCst);

    assert_eq!(total_reads, NUM_READERS * OPS_PER_THREAD);
    assert_eq!(total_writes, TOTAL_WRITES);

    // The handle kept here shares its Arc with the one stored in the state.
    assert_eq!(mutable_state.len(), TOTAL_WRITES);
}
6110}
6111
/// Owned snapshot of request metadata: method, path, query string and headers.
///
/// Its `FromRequest` impl copies every value out of the request into owned
/// storage, so a `RequestRef` does not borrow from the request it came from.
#[derive(Debug, Clone)]
pub struct RequestRef {
    /// HTTP method of the request.
    method: crate::request::Method,
    /// Request path as reported by `Request::path`.
    path: String,
    /// Raw query string, if the request had one.
    query: Option<String>,
    /// Header `(name, value)` pairs, in the order `Request::headers` yielded them.
    headers: Vec<(String, Vec<u8>)>,
}
6137
6138impl RequestRef {
6139 #[must_use]
6141 pub fn method(&self) -> crate::request::Method {
6142 self.method
6143 }
6144
6145 #[must_use]
6147 pub fn path(&self) -> &str {
6148 &self.path
6149 }
6150
6151 #[must_use]
6153 pub fn query(&self) -> Option<&str> {
6154 self.query.as_deref()
6155 }
6156
6157 #[must_use]
6159 pub fn header(&self, name: &str) -> Option<&[u8]> {
6160 let name_lower = name.to_ascii_lowercase();
6161 self.headers
6162 .iter()
6163 .find(|(n, _)| n.to_ascii_lowercase() == name_lower)
6164 .map(|(_, v)| v.as_slice())
6165 }
6166
6167 pub fn headers(&self) -> impl Iterator<Item = (&str, &[u8])> {
6169 self.headers.iter().map(|(n, v)| (n.as_str(), v.as_slice()))
6170 }
6171}
6172
6173impl FromRequest for RequestRef {
6174 type Error = std::convert::Infallible;
6175
6176 async fn from_request(_ctx: &RequestContext, req: &mut Request) -> Result<Self, Self::Error> {
6177 Ok(RequestRef {
6178 method: req.method(),
6179 path: req.path().to_string(),
6180 query: req.query().map(String::from),
6181 headers: req
6182 .headers()
6183 .iter()
6184 .map(|(name, value)| (name.to_string(), value.to_vec()))
6185 .collect(),
6186 })
6187 }
6188}
6189
/// Pending changes to be applied to an outgoing response by `apply`.
///
/// Changes are applied in this order: headers, then `Set-Cookie` for each
/// cookie, then an expiring `Set-Cookie` for each deleted name.
#[derive(Debug, Clone, Default)]
pub struct ResponseMutations {
    /// Extra headers as `(name, raw value bytes)`, appended in insertion order.
    pub headers: Vec<(String, Vec<u8>)>,
    /// Cookies to emit, each rendered via `Cookie::to_header_value`.
    pub cookies: Vec<Cookie>,
    /// Cookie names to expire with `Max-Age=0; Path=/`.
    pub delete_cookies: Vec<String>,
}
6217
6218impl ResponseMutations {
6219 #[must_use]
6221 pub fn new() -> Self {
6222 Self::default()
6223 }
6224
6225 pub fn add_header(&mut self, name: impl Into<String>, value: impl Into<Vec<u8>>) {
6227 self.headers.push((name.into(), value.into()));
6228 }
6229
6230 pub fn add_cookie(&mut self, cookie: Cookie) {
6232 self.cookies.push(cookie);
6233 }
6234
6235 pub fn remove_cookie(&mut self, name: impl Into<String>) {
6237 self.delete_cookies.push(name.into());
6238 }
6239
6240 #[must_use]
6242 pub fn apply(self, mut response: crate::response::Response) -> crate::response::Response {
6243 for (name, value) in self.headers {
6245 response = response.header(name, value);
6246 }
6247
6248 for cookie in self.cookies {
6250 response = response.header("Set-Cookie", cookie.to_header_value().into_bytes());
6251 }
6252
6253 for name in self.delete_cookies {
6255 let sanitized_name = sanitize_cookie_token(&name);
6257 let delete_cookie = format!("{}=; Max-Age=0; Path=/", sanitized_name);
6258 response = response.header("Set-Cookie", delete_cookie.into_bytes());
6259 }
6260
6261 response
6262 }
6263}
6264
/// Reduces `name` to characters acceptable in a cookie-name token.
///
/// Drops non-ASCII characters, ASCII control characters, and the separators
/// space, `"`, `(`, `)`, `,`, `/`, `:`, `;`, `<`, `=`, `>`, `?`, `@`, `[`,
/// `\`, `]`, `{`, `}`. Everything else is kept in order.
fn sanitize_cookie_token(name: &str) -> String {
    name.chars()
        .filter(|&ch| ch.is_ascii() && !ch.is_ascii_control())
        .filter(|&ch| {
            !matches!(
                ch,
                ' ' | '"' | '(' | ')' | ',' | '/' | ':' | ';' | '<' | '=' | '>' | '?' | '@'
                    | '[' | '\\' | ']' | '{' | '}'
            )
        })
        .collect()
}
6301
/// Reduces `value` to characters acceptable in a cookie value.
///
/// Drops non-ASCII characters, ASCII control characters, space, `"`, `,`,
/// `;` and `\`. Characters such as `=`, `/` and `:` are kept.
fn sanitize_cookie_value(value: &str) -> String {
    let mut cleaned = String::with_capacity(value.len());
    for ch in value.chars() {
        let forbidden = !ch.is_ascii()
            || ch.is_ascii_control()
            || matches!(ch, ' ' | '"' | ',' | ';' | '\\');
        if !forbidden {
            cleaned.push(ch);
        }
    }
    cleaned
}
6320
/// Strips characters that may not appear in a `Path`/`Domain` attribute value.
///
/// RFC 6265 §4.1.1 defines these values as "any CHAR except CTLs or `;`".
/// Every control character is removed, not only CR, LF and NUL. Stray
/// controls such as VT, FF, DEL or C1 codes would otherwise produce a
/// malformed header. `;` is removed so the value cannot start a new attribute.
/// Non-control, non-ASCII characters are left untouched.
fn sanitize_cookie_attr(attr: &str) -> String {
    attr.chars()
        .filter(|&c| c != ';' && !c.is_control())
        .collect()
}
6329
/// A cookie to be sent in a `Set-Cookie` response header.
///
/// Rendered by `to_header_value`, which runs the name, value, path and domain
/// through sanitizers before joining the attributes with `"; "`.
#[derive(Debug, Clone)]
pub struct Cookie {
    /// Cookie name; filtered by `sanitize_cookie_token` when rendered.
    pub name: String,
    /// Cookie value; filtered by `sanitize_cookie_value` when rendered.
    pub value: String,
    /// `Max-Age` in seconds; emitted only when `Some`.
    pub max_age: Option<i64>,
    /// `Path` attribute; emitted only when `Some`.
    pub path: Option<String>,
    /// `Domain` attribute; emitted only when `Some`.
    pub domain: Option<String>,
    /// Emits the `Secure` flag when `true`.
    pub secure: bool,
    /// Emits the `HttpOnly` flag when `true`.
    pub http_only: bool,
    /// `SameSite` attribute; emitted only when `Some`.
    pub same_site: Option<SameSite>,
}
6350
6351impl Cookie {
6352 #[must_use]
6354 pub fn new(name: impl Into<String>, value: impl Into<String>) -> Self {
6355 Self {
6356 name: name.into(),
6357 value: value.into(),
6358 max_age: None,
6359 path: None,
6360 domain: None,
6361 secure: false,
6362 http_only: false,
6363 same_site: None,
6364 }
6365 }
6366
6367 #[must_use]
6369 pub fn max_age(mut self, seconds: i64) -> Self {
6370 self.max_age = Some(seconds);
6371 self
6372 }
6373
6374 #[must_use]
6376 pub fn path(mut self, path: impl Into<String>) -> Self {
6377 self.path = Some(path.into());
6378 self
6379 }
6380
6381 #[must_use]
6383 pub fn domain(mut self, domain: impl Into<String>) -> Self {
6384 self.domain = Some(domain.into());
6385 self
6386 }
6387
6388 #[must_use]
6390 pub fn secure(mut self, secure: bool) -> Self {
6391 self.secure = secure;
6392 self
6393 }
6394
6395 #[must_use]
6397 pub fn http_only(mut self, http_only: bool) -> Self {
6398 self.http_only = http_only;
6399 self
6400 }
6401
6402 #[must_use]
6404 pub fn same_site(mut self, same_site: SameSite) -> Self {
6405 self.same_site = Some(same_site);
6406 self
6407 }
6408
6409 #[must_use]
6417 pub fn to_header_value(&self) -> String {
6418 let sanitized_name = sanitize_cookie_token(&self.name);
6421
6422 let sanitized_value = sanitize_cookie_value(&self.value);
6425
6426 let mut parts = vec![format!("{}={}", sanitized_name, sanitized_value)];
6427
6428 if let Some(max_age) = self.max_age {
6429 parts.push(format!("Max-Age={}", max_age));
6430 }
6431 if let Some(ref path) = self.path {
6432 let sanitized_path = sanitize_cookie_attr(path);
6434 parts.push(format!("Path={}", sanitized_path));
6435 }
6436 if let Some(ref domain) = self.domain {
6437 let sanitized_domain = sanitize_cookie_attr(domain);
6439 parts.push(format!("Domain={}", sanitized_domain));
6440 }
6441 if self.secure {
6442 parts.push("Secure".to_string());
6443 }
6444 if self.http_only {
6445 parts.push("HttpOnly".to_string());
6446 }
6447 if let Some(ref same_site) = self.same_site {
6448 parts.push(format!("SameSite={}", same_site.as_str()));
6449 }
6450
6451 parts.join("; ")
6452 }
6453
6454 #[must_use]
6484 pub fn session(name: impl Into<String>, value: impl Into<String>, production: bool) -> Self {
6485 Self::new(name, value)
6486 .http_only(true)
6487 .secure(production)
6488 .same_site(SameSite::Lax)
6489 .path("/")
6490 }
6491
6492 #[must_use]
6517 pub fn auth(name: impl Into<String>, value: impl Into<String>, production: bool) -> Self {
6518 Self::new(name, value)
6519 .http_only(true)
6520 .secure(production)
6521 .same_site(SameSite::Strict)
6522 .path("/")
6523 }
6524
6525 #[must_use]
6550 pub fn csrf(name: impl Into<String>, value: impl Into<String>, production: bool) -> Self {
6551 Self::new(name, value)
6552 .http_only(false)
6553 .secure(production)
6554 .same_site(SameSite::Strict)
6555 .path("/")
6556 }
6557
6558 #[must_use]
6584 pub fn host_prefixed(name: impl Into<String>, value: impl Into<String>) -> Self {
6585 let prefixed_name = format!("__Host-{}", name.into());
6586 Self::new(prefixed_name, value).secure(true).path("/")
6587 }
6588
6589 #[must_use]
6612 pub fn secure_prefixed(name: impl Into<String>, value: impl Into<String>) -> Self {
6613 let prefixed_name = format!("__Secure-{}", name.into());
6614 Self::new(prefixed_name, value).secure(true)
6615 }
6616
6617 pub fn validate_prefix(&self) -> Result<(), CookiePrefixError> {
6640 if self.name.starts_with("__Host-") {
6641 if !self.secure {
6642 return Err(CookiePrefixError::HostRequiresSecure);
6643 }
6644 if self.domain.is_some() {
6645 return Err(CookiePrefixError::HostCannotHaveDomain);
6646 }
6647 if self.path.as_deref() != Some("/") {
6648 return Err(CookiePrefixError::HostRequiresRootPath);
6649 }
6650 } else if self.name.starts_with("__Secure-") && !self.secure {
6651 return Err(CookiePrefixError::SecureRequiresSecure);
6652 }
6653 Ok(())
6654 }
6655
6656 #[must_use]
6658 pub fn has_security_prefix(&self) -> bool {
6659 self.name.starts_with("__Host-") || self.name.starts_with("__Secure-")
6660 }
6661
6662 #[must_use]
6664 pub fn prefix(&self) -> Option<CookiePrefix> {
6665 if self.name.starts_with("__Host-") {
6666 Some(CookiePrefix::Host)
6667 } else if self.name.starts_with("__Secure-") {
6668 Some(CookiePrefix::Secure)
6669 } else {
6670 None
6671 }
6672 }
6673}
6674
/// Security prefix recognized at the start of a cookie name.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CookiePrefix {
    /// `__Host-`: `Cookie::validate_prefix` requires `Secure`, no `Domain`, and `Path=/`.
    Host,
    /// `__Secure-`: `Cookie::validate_prefix` requires `Secure`.
    Secure,
}
6691
6692impl CookiePrefix {
6693 #[must_use]
6695 pub const fn as_str(&self) -> &'static str {
6696 match self {
6697 Self::Host => "__Host-",
6698 Self::Secure => "__Secure-",
6699 }
6700 }
6701}
6702
/// Reason a cookie violates the rules of its name's security prefix.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CookiePrefixError {
    /// `__Host-` cookie without the `Secure` flag.
    HostRequiresSecure,
    /// `__Host-` cookie with a `Domain` attribute.
    HostCannotHaveDomain,
    /// `__Host-` cookie whose `Path` is not exactly `/`.
    HostRequiresRootPath,
    /// `__Secure-` cookie without the `Secure` flag.
    SecureRequiresSecure,
}
6715
6716impl std::fmt::Display for CookiePrefixError {
6717 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
6718 match self {
6719 Self::HostRequiresSecure => {
6720 write!(f, "__Host- prefix requires Secure flag to be true")
6721 }
6722 Self::HostCannotHaveDomain => {
6723 write!(f, "__Host- prefix cannot have a Domain attribute")
6724 }
6725 Self::HostRequiresRootPath => {
6726 write!(f, "__Host- prefix requires Path=\"/\"")
6727 }
6728 Self::SecureRequiresSecure => {
6729 write!(f, "__Secure- prefix requires Secure flag to be true")
6730 }
6731 }
6732 }
6733}
6734
// Marker impl: the message comes from `Display`; there is no underlying `source()`.
impl std::error::Error for CookiePrefixError {}
6736
/// Value of a cookie's `SameSite` attribute.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SameSite {
    /// Rendered as `SameSite=Strict`.
    Strict,
    /// Rendered as `SameSite=Lax`.
    Lax,
    /// Rendered as `SameSite=None`.
    None,
}
6747
6748impl SameSite {
6749 #[must_use]
6751 pub const fn as_str(&self) -> &'static str {
6752 match self {
6753 Self::Strict => "Strict",
6754 Self::Lax => "Lax",
6755 Self::None => "None",
6756 }
6757 }
6758}
6759
/// Cookies parsed from a request's `Cookie` header, keyed by exact name.
///
/// When a name appears more than once, the later pair overwrites the earlier
/// one (plain `HashMap::insert` in `from_header`).
#[derive(Debug, Clone, Default)]
pub struct RequestCookies {
    /// Cookie name → value, both trimmed of surrounding whitespace.
    cookies: std::collections::HashMap<String, String>,
}
6785
6786impl RequestCookies {
6787 #[must_use]
6789 pub fn new() -> Self {
6790 Self::default()
6791 }
6792
6793 #[must_use]
6795 pub fn from_header(header_value: &str) -> Self {
6796 let mut cookies = std::collections::HashMap::new();
6797
6798 for pair in header_value.split(';') {
6800 let pair = pair.trim();
6801 if let Some((name, value)) = pair.split_once('=') {
6802 let name = name.trim().to_string();
6803 let value = value.trim().to_string();
6804 if !name.is_empty() {
6805 cookies.insert(name, value);
6806 }
6807 }
6808 }
6809
6810 Self { cookies }
6811 }
6812
6813 #[must_use]
6815 pub fn get(&self, name: &str) -> Option<&str> {
6816 self.cookies.get(name).map(String::as_str)
6817 }
6818
6819 #[must_use]
6821 pub fn contains(&self, name: &str) -> bool {
6822 self.cookies.contains_key(name)
6823 }
6824
6825 #[must_use]
6827 pub fn len(&self) -> usize {
6828 self.cookies.len()
6829 }
6830
6831 #[must_use]
6833 pub fn is_empty(&self) -> bool {
6834 self.cookies.is_empty()
6835 }
6836
6837 pub fn iter(&self) -> impl Iterator<Item = (&str, &str)> {
6839 self.cookies.iter().map(|(k, v)| (k.as_str(), v.as_str()))
6840 }
6841
6842 pub fn names(&self) -> impl Iterator<Item = &str> {
6844 self.cookies.keys().map(String::as_str)
6845 }
6846}
6847
6848impl FromRequest for RequestCookies {
6849 type Error = std::convert::Infallible;
6850
6851 async fn from_request(_ctx: &RequestContext, req: &mut Request) -> Result<Self, Self::Error> {
6852 let cookies = req
6853 .headers()
6854 .get("cookie")
6855 .and_then(|v| std::str::from_utf8(v).ok())
6856 .map(Self::from_header)
6857 .unwrap_or_default();
6858
6859 Ok(cookies)
6860 }
6861}
6862
/// A single request cookie whose name is fixed by the marker type `T`
/// (see `CookieName`).
pub struct RequestCookie<T> {
    /// Raw cookie value as found in the `Cookie` header.
    value: String,
    /// Ties the value to its marker type without storing a `T`.
    _marker: std::marker::PhantomData<T>,
}

// Hand-written instead of `#[derive]`: the derive would require `T: Clone` /
// `T: Debug` even though `T` is only a phantom marker. Marker types such as
// `SessionIdCookie` would then make `RequestCookie<_>` neither cloneable nor
// debuggable.
impl<T> Clone for RequestCookie<T> {
    fn clone(&self) -> Self {
        Self {
            value: self.value.clone(),
            _marker: std::marker::PhantomData,
        }
    }
}

impl<T> std::fmt::Debug for RequestCookie<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RequestCookie")
            .field("value", &self.value)
            .field("_marker", &self._marker)
            .finish()
    }
}
6896
6897impl<T> RequestCookie<T> {
6898 #[must_use]
6900 pub fn new(value: impl Into<String>) -> Self {
6901 Self {
6902 value: value.into(),
6903 _marker: std::marker::PhantomData,
6904 }
6905 }
6906
6907 #[must_use]
6909 pub fn value(&self) -> &str {
6910 &self.value
6911 }
6912
6913 #[must_use]
6915 pub fn into_value(self) -> String {
6916 self.value
6917 }
6918}
6919
6920impl<T> Deref for RequestCookie<T> {
6921 type Target = str;
6922
6923 fn deref(&self) -> &Self::Target {
6924 &self.value
6925 }
6926}
6927
6928impl<T> AsRef<str> for RequestCookie<T> {
6929 fn as_ref(&self) -> &str {
6930 &self.value
6931 }
6932}
6933
/// Associates a compile-time cookie name with a marker type.
///
/// The `FromRequest` impl for `RequestCookie<T>` looks the cookie up by `T::NAME`.
pub trait CookieName {
    /// Exact cookie name to look up. It is used as a `HashMap` key, so the
    /// match is case-sensitive.
    const NAME: &'static str;
}
6950
/// Error produced when a typed cookie cannot be extracted.
///
/// Its `IntoResponse` impl converts it into a validation-error response.
#[derive(Debug)]
pub enum CookieExtractError {
    /// The request carried no cookie with this name.
    NotFound {
        /// Name of the missing cookie.
        name: &'static str,
    },
    /// The cookie was present but its value could not be interpreted.
    InvalidValue {
        /// Name of the offending cookie.
        name: &'static str,
        /// Raw value as received.
        value: String,
        /// Human-readable description of the expected format.
        expected: &'static str,
    },
}
6969
6970impl fmt::Display for CookieExtractError {
6971 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
6972 match self {
6973 Self::NotFound { name } => {
6974 write!(f, "Cookie '{}' not found", name)
6975 }
6976 Self::InvalidValue {
6977 name,
6978 value,
6979 expected,
6980 } => {
6981 write!(
6982 f,
6983 "Invalid cookie '{}' value '{}': expected {}",
6984 name, value, expected
6985 )
6986 }
6987 }
6988 }
6989}
6990
// Marker impl: the message comes from `Display`; there is no underlying `source()`.
impl std::error::Error for CookieExtractError {}
6992
6993impl IntoResponse for CookieExtractError {
6994 fn into_response(self) -> crate::response::Response {
6995 match self {
6996 Self::NotFound { name } => ValidationErrors::single(
6997 ValidationError::missing(crate::error::loc::cookie(name))
6998 .with_msg("Cookie is required"),
6999 )
7000 .into_response(),
7001 Self::InvalidValue {
7002 name,
7003 value,
7004 expected,
7005 } => ValidationErrors::single(
7006 ValidationError::type_error(crate::error::loc::cookie(name), expected)
7007 .with_msg(format!("Expected {expected}"))
7008 .with_input(serde_json::Value::String(value)),
7009 )
7010 .into_response(),
7011 }
7012 }
7013}
7014
7015impl<T: CookieName> FromRequest for RequestCookie<T> {
7016 type Error = CookieExtractError;
7017
7018 async fn from_request(_ctx: &RequestContext, req: &mut Request) -> Result<Self, Self::Error> {
7019 let cookies = req
7020 .headers()
7021 .get("cookie")
7022 .and_then(|v| std::str::from_utf8(v).ok())
7023 .map(RequestCookies::from_header)
7024 .unwrap_or_default();
7025
7026 cookies
7027 .get(T::NAME)
7028 .map(|v| RequestCookie::new(v))
7029 .ok_or(CookieExtractError::NotFound { name: T::NAME })
7030 }
7031}
7032
7033pub struct SessionIdCookie;
7037impl CookieName for SessionIdCookie {
7038 const NAME: &'static str = "session_id";
7039}
7040
7041pub struct CsrfTokenCookie;
7043impl CookieName for CsrfTokenCookie {
7044 const NAME: &'static str = "csrf_token";
7045}
7046
/// Mutable handle onto a [`ResponseMutations`] set, with helpers to queue
/// headers, cookies to set, and cookies to delete.
pub struct ResponseMut<'a> {
    /// The mutation set every helper appends to.
    mutations: &'a mut ResponseMutations,
}
7071
7072impl<'a> ResponseMut<'a> {
7073 pub fn header(&mut self, name: impl Into<String>, value: impl Into<Vec<u8>>) {
7075 self.mutations.add_header(name, value);
7076 }
7077
7078 pub fn cookie(&mut self, name: impl Into<String>, value: impl Into<String>) {
7080 self.mutations.add_cookie(Cookie::new(name, value));
7081 }
7082
7083 pub fn set_cookie(&mut self, cookie: Cookie) {
7085 self.mutations.add_cookie(cookie);
7086 }
7087
7088 pub fn delete_cookie(&mut self, name: impl Into<String>) {
7090 self.mutations.remove_cookie(name);
7091 }
7092}
7093
7094impl<'a> std::fmt::Debug for ResponseMut<'a> {
7095 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
7096 f.debug_struct("ResponseMut")
7097 .field("mutations", &self.mutations)
7098 .finish()
7099 }
7100}
7101
7102impl FromRequest for ResponseMutations {
7107 type Error = std::convert::Infallible;
7108
7109 async fn from_request(_ctx: &RequestContext, req: &mut Request) -> Result<Self, Self::Error> {
7110 if let Some(mutations) = req.get_extension::<ResponseMutations>() {
7112 Ok(mutations.clone())
7113 } else {
7114 let mutations = ResponseMutations::new();
7115 req.insert_extension(mutations.clone());
7116 Ok(mutations)
7117 }
7118 }
7119}
7120
7121use std::sync::Arc;
7126
/// Extracts a readable message from a panic payload returned by `catch_unwind`.
///
/// `panic!("literal")` yields a `&'static str` payload and formatted panics
/// yield a `String`. Any other payload type maps to `"unknown panic"`.
///
/// The parameter is deliberately `&Box<dyn Any + Send>`. Coercing a boxed
/// payload to `&dyn Any` would unsize the `Box` itself, and every downcast
/// would then miss. The explicit `&**` below derefs to the real payload.
fn format_panic_message(panic_info: &Box<dyn std::any::Any + Send>) -> String {
    let payload: &(dyn std::any::Any + Send) = &**panic_info;
    match payload.downcast_ref::<&str>() {
        Some(message) => (*message).to_string(),
        None => payload
            .downcast_ref::<String>()
            .cloned()
            .unwrap_or_else(|| "unknown panic".to_string()),
    }
}
7140
/// A deferred unit of work: a boxed, `Send`, one-shot closure that, when
/// called, returns the pinned, boxed `Send` future to await.
pub type BackgroundTask =
    Box<dyn FnOnce() -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>> + Send>;
7144
/// Shared queue of [`BackgroundTask`]s behind a `parking_lot` mutex.
///
/// Cloning is cheap and yields a handle to the *same* queue because the `Arc`
/// is shared. This is how the request extension and every extracted
/// `BackgroundTasks` see one queue.
#[derive(Default, Clone)]
pub struct BackgroundTasksInner {
    /// Tasks in insertion order.
    inner: Arc<parking_lot::Mutex<Vec<BackgroundTask>>>,
}
7152
7153impl BackgroundTasksInner {
7154 #[must_use]
7156 pub fn new() -> Self {
7157 Self {
7158 inner: Arc::new(parking_lot::Mutex::new(Vec::new())),
7159 }
7160 }
7161
7162 pub fn push(&self, task: BackgroundTask) {
7164 self.inner.lock().push(task);
7165 }
7166
7167 pub fn take(&self) -> Vec<BackgroundTask> {
7169 std::mem::take(&mut *self.inner.lock())
7170 }
7171
7172 #[must_use]
7174 pub fn len(&self) -> usize {
7175 self.inner.lock().len()
7176 }
7177
7178 #[must_use]
7180 pub fn is_empty(&self) -> bool {
7181 self.inner.lock().is_empty()
7182 }
7183}
7184
7185impl std::fmt::Debug for BackgroundTasksInner {
7186 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
7187 f.debug_struct("BackgroundTasksInner")
7188 .field("task_count", &self.len())
7189 .finish()
7190 }
7191}
7192
/// Handle for queuing work that is later run by one of the `execute_*` methods.
///
/// Clones share the same underlying [`BackgroundTasksInner`] queue.
#[derive(Clone)]
pub struct BackgroundTasks {
    /// Shared task queue.
    inner: BackgroundTasksInner,
}
7225
7226impl Default for BackgroundTasks {
7227 fn default() -> Self {
7228 Self::new()
7229 }
7230}
7231
7232impl BackgroundTasks {
7233 #[must_use]
7235 pub fn new() -> Self {
7236 Self {
7237 inner: BackgroundTasksInner::new(),
7238 }
7239 }
7240
7241 #[must_use]
7243 pub(crate) fn from_inner(inner: BackgroundTasksInner) -> Self {
7244 Self { inner }
7245 }
7246
7247 pub fn add_task<F, Fut>(&mut self, task: F)
7251 where
7252 F: FnOnce() -> Fut + Send + 'static,
7253 Fut: std::future::Future<Output = ()> + Send + 'static,
7254 {
7255 self.inner.push(Box::new(move || Box::pin(task())));
7256 }
7257
7258 pub fn add_sync_task<F>(&mut self, task: F)
7262 where
7263 F: FnOnce() + Send + 'static,
7264 {
7265 self.inner.push(Box::new(move || {
7266 Box::pin(async move {
7267 task();
7268 })
7269 }));
7270 }
7271
7272 pub fn take_tasks(&mut self) -> Vec<BackgroundTask> {
7274 self.inner.take()
7275 }
7276
7277 #[must_use]
7279 pub fn is_empty(&self) -> bool {
7280 self.inner.is_empty()
7281 }
7282
7283 #[must_use]
7285 pub fn len(&self) -> usize {
7286 self.inner.len()
7287 }
7288
7289 pub async fn execute_all(mut self) {
7313 for task in self.take_tasks() {
7314 let future = task();
7315 future.await;
7316 }
7317 }
7318
7319 pub async fn execute_with_context(mut self, ctx: &RequestContext) {
7352 let tasks = self.take_tasks();
7353 let task_count = tasks.len();
7354 let mut executed_count = 0;
7355
7356 for (index, task) in tasks.into_iter().enumerate() {
7357 if ctx.is_cancelled() {
7359 let remaining = task_count - index;
7360 if remaining > 0 {
7361 ctx.trace(&format!(
7362 "BackgroundTasks: Cancellation requested, skipping {} remaining task(s)",
7363 remaining
7364 ));
7365 }
7366 break;
7367 }
7368
7369 let future = task();
7371 future.await;
7372 executed_count += 1;
7373 }
7374
7375 if task_count > 0 {
7377 ctx.trace(&format!(
7378 "BackgroundTasks: Executed {}/{} tasks",
7379 executed_count, task_count
7380 ));
7381 }
7382 }
7383
7384 pub async fn execute_with_panic_isolation(mut self) {
7410 let tasks = self.take_tasks();
7411 let task_count = tasks.len();
7412 let mut success_count = 0;
7413 let mut panic_count = 0;
7414
7415 for (index, task) in tasks.into_iter().enumerate() {
7416 let future_result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| task()));
7418
7419 match future_result {
7420 Ok(future) => {
7421 future.await;
7424 success_count += 1;
7425 }
7426 Err(panic_info) => {
7427 panic_count += 1;
7429 let panic_msg = format_panic_message(&panic_info);
7430 eprintln!(
7431 "[BackgroundTasks] Task {}/{} panicked: {}",
7432 index + 1,
7433 task_count,
7434 panic_msg
7435 );
7436 }
7437 }
7438 }
7439
7440 if panic_count > 0 {
7441 eprintln!(
7442 "[BackgroundTasks] Completed with {}/{} successful, {} panicked",
7443 success_count, task_count, panic_count
7444 );
7445 }
7446 }
7447
7448 pub fn into_inner(self) -> BackgroundTasksInner {
7450 self.inner
7451 }
7452}
7453
7454impl std::fmt::Debug for BackgroundTasks {
7455 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
7456 f.debug_struct("BackgroundTasks")
7457 .field("task_count", &self.len())
7458 .finish()
7459 }
7460}
7461
7462impl FromRequest for BackgroundTasks {
7463 type Error = std::convert::Infallible;
7464
7465 async fn from_request(_ctx: &RequestContext, req: &mut Request) -> Result<Self, Self::Error> {
7466 if let Some(inner) = req.get_extension::<BackgroundTasksInner>() {
7468 Ok(BackgroundTasks::from_inner(inner.clone()))
7469 } else {
7470 let inner = BackgroundTasksInner::new();
7471 req.insert_extension(inner.clone());
7472 Ok(BackgroundTasks::from_inner(inner))
7473 }
7474 }
7475}
7476
// Tests for the request/response helper types: `RequestRef`, `Cookie`
// (header formatting, injection sanitization, security prefixes, preset
// constructors), `ResponseMutations`, `RequestCookies`, and the typed
// `RequestCookie<N>` extractor.
#[cfg(test)]
mod special_extractor_tests {
    use super::*;
    use crate::request::Method;

    /// Builds a request context backed by a testing `Cx`.
    fn test_context() -> RequestContext {
        let cx = asupersync::Cx::for_testing();
        RequestContext::new(cx, 12345)
    }

    #[test]
    fn request_ref_extracts_metadata() {
        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/users/42");
        req.set_query(Some("page=1".to_string()));
        req.headers_mut()
            .insert("content-type", b"application/json".to_vec());

        let result = futures_executor::block_on(RequestRef::from_request(&ctx, &mut req));
        let req_ref = result.unwrap();

        assert_eq!(req_ref.method(), Method::Get);
        assert_eq!(req_ref.path(), "/users/42");
        assert_eq!(req_ref.query(), Some("page=1"));
        assert_eq!(
            req_ref.header("content-type"),
            Some(b"application/json".as_slice())
        );
    }

    #[test]
    fn request_ref_header_case_insensitive() {
        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/");
        req.headers_mut()
            .insert("X-Custom-Header", b"value".to_vec());

        let result = futures_executor::block_on(RequestRef::from_request(&ctx, &mut req));
        let req_ref = result.unwrap();

        // The header was inserted in mixed case; both lookups must find it.
        assert_eq!(req_ref.header("x-custom-header"), Some(b"value".as_slice()));
        assert_eq!(req_ref.header("X-CUSTOM-HEADER"), Some(b"value".as_slice()));
    }

    #[test]
    fn cookie_to_header_value_simple() {
        let cookie = Cookie::new("session", "abc123");
        assert_eq!(cookie.to_header_value(), "session=abc123");
    }

    #[test]
    fn cookie_to_header_value_with_attributes() {
        let cookie = Cookie::new("session", "abc123")
            .max_age(3600)
            .path("/")
            .secure(true)
            .http_only(true)
            .same_site(SameSite::Strict);

        let header = cookie.to_header_value();
        assert!(header.contains("session=abc123"));
        assert!(header.contains("Max-Age=3600"));
        assert!(header.contains("Path=/"));
        assert!(header.contains("Secure"));
        assert!(header.contains("HttpOnly"));
        assert!(header.contains("SameSite=Strict"));
    }

    #[test]
    fn response_mutations_apply_headers() {
        let mut mutations = ResponseMutations::new();
        mutations.add_header("X-Custom", "value");
        mutations.add_header("X-Another", "other");

        let response = crate::response::Response::ok();
        let response = mutations.apply(response);

        let headers = response.headers();
        assert!(
            headers
                .iter()
                .any(|(n, v)| n == "X-Custom" && v == b"value")
        );
        assert!(
            headers
                .iter()
                .any(|(n, v)| n == "X-Another" && v == b"other")
        );
    }

    #[test]
    fn response_mutations_apply_cookies() {
        let mut mutations = ResponseMutations::new();
        mutations.add_cookie(Cookie::new("session", "abc").http_only(true));

        let response = crate::response::Response::ok();
        let response = mutations.apply(response);

        let headers = response.headers();
        let cookie_header = headers
            .iter()
            .find(|(n, _)| n == "Set-Cookie")
            .map(|(_, v)| String::from_utf8_lossy(v).to_string())
            .expect("response should carry a Set-Cookie header");
        assert!(cookie_header.contains("session=abc"));
    }

    #[test]
    fn response_mutations_delete_cookie() {
        let mut mutations = ResponseMutations::new();
        mutations.remove_cookie("session");

        let response = crate::response::Response::ok();
        let response = mutations.apply(response);

        let headers = response.headers();
        let cookie_header = headers
            .iter()
            .find(|(n, _)| n == "Set-Cookie")
            .map(|(_, v)| String::from_utf8_lossy(v).to_string())
            .expect("removing a cookie should emit a Set-Cookie header");
        // Deletion is expressed as an expiring cookie with the same name.
        assert!(cookie_header.contains("session="));
        assert!(cookie_header.contains("Max-Age=0"));
    }

    #[test]
    fn response_mutations_extract() {
        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/");

        let result = futures_executor::block_on(ResponseMutations::from_request(&ctx, &mut req));
        let mutations = result.unwrap();
        assert!(mutations.headers.is_empty());
        assert!(mutations.cookies.is_empty());
    }

    // ---- Cookie attribute-injection sanitization ----

    /// A `;` in the value must not be able to start a new cookie attribute.
    #[test]
    fn cookie_sanitizes_semicolon_injection_in_value() {
        let cookie = Cookie::new("session", "abc; Domain=.evil.com");
        let header = cookie.to_header_value();
        assert_eq!(header, "session=abcDomain=.evil.com");
        assert!(!header.contains("; Domain"));
    }

    /// A `;` in the name must not be able to inject attributes either.
    #[test]
    fn cookie_sanitizes_semicolon_injection_in_name() {
        let cookie = Cookie::new("session; HttpOnly", "value");
        let header = cookie.to_header_value();
        assert!(!header.starts_with("session; "));
        assert!(header.starts_with("sessionHttpOnly="));
    }

    #[test]
    fn cookie_sanitizes_path_injection() {
        let cookie = Cookie::new("session", "abc").path("/; HttpOnly; Domain=.evil.com");
        let header = cookie.to_header_value();
        assert!(!header.contains("; Domain"));
        assert!(!header.contains("; HttpOnly"));
        assert!(header.contains("Path=/"));
    }

    #[test]
    fn cookie_sanitizes_domain_injection() {
        let cookie = Cookie::new("session", "abc").domain(".example.com; HttpOnly=false");
        let header = cookie.to_header_value();
        assert!(!header.contains("; HttpOnly=false"));
        // Only the `;` is stripped from the domain; the space survives.
        assert!(header.contains("Domain=.example.com HttpOnly=false"));
    }

    /// CR/LF (header splitting) and spaces are stripped from the value.
    #[test]
    fn cookie_sanitizes_control_characters() {
        let cookie = Cookie::new("session", "abc\r\nSet-Cookie: evil=value");
        let header = cookie.to_header_value();
        assert!(!header.contains("\r"));
        assert!(!header.contains("\n"));
        assert!(!header.contains(" "));
        assert!(header.contains("session=abcSet-Cookie:evil=value"));
    }

    #[test]
    fn delete_cookie_sanitizes_name() {
        let mut mutations = ResponseMutations::new();
        mutations.remove_cookie("session; Domain=.evil.com");

        let response = crate::response::Response::ok();
        let response = mutations.apply(response);

        let headers = response.headers();
        let cookie_header = headers
            .iter()
            .find(|(n, _)| n == "Set-Cookie")
            .map(|(_, v)| String::from_utf8_lossy(v).to_string())
            .expect("removing a cookie should emit a Set-Cookie header");
        assert!(!cookie_header.contains("; Domain"));
    }

    // ---- Preset cookie constructors ----

    #[test]
    fn session_cookie_production() {
        let cookie = Cookie::session("session_id", "abc123", true);
        assert_eq!(cookie.name, "session_id");
        assert_eq!(cookie.value, "abc123");
        assert!(cookie.http_only);
        assert!(cookie.secure);
        assert_eq!(cookie.same_site, Some(SameSite::Lax));
        assert_eq!(cookie.path, Some("/".to_string()));
    }

    #[test]
    fn session_cookie_development() {
        let cookie = Cookie::session("session_id", "abc123", false);
        assert!(cookie.http_only);
        assert!(!cookie.secure);
        assert_eq!(cookie.same_site, Some(SameSite::Lax));
    }

    #[test]
    fn auth_cookie_production() {
        let cookie = Cookie::auth("auth_token", "jwt_token", true);
        assert_eq!(cookie.name, "auth_token");
        assert!(cookie.http_only);
        assert!(cookie.secure);
        assert_eq!(cookie.same_site, Some(SameSite::Strict));
        assert_eq!(cookie.path, Some("/".to_string()));
    }

    #[test]
    fn csrf_cookie_is_readable_by_js() {
        let cookie = Cookie::csrf("csrf_token", "random_value", true);
        assert_eq!(cookie.name, "csrf_token");
        assert!(!cookie.http_only);
        assert!(cookie.secure);
        assert_eq!(cookie.same_site, Some(SameSite::Strict));
    }

    // ---- `__Host-` / `__Secure-` prefixes ----

    #[test]
    fn host_prefixed_cookie() {
        let cookie = Cookie::host_prefixed("session", "abc123");
        assert_eq!(cookie.name, "__Host-session");
        assert!(cookie.secure);
        assert_eq!(cookie.path, Some("/".to_string()));
        assert!(cookie.domain.is_none());
        assert!(cookie.validate_prefix().is_ok());
    }

    #[test]
    fn host_prefixed_cookie_validation_fails_without_secure() {
        let cookie = Cookie::new("__Host-session", "abc123")
            .path("/")
            .secure(false);
        assert_eq!(
            cookie.validate_prefix(),
            Err(CookiePrefixError::HostRequiresSecure)
        );
    }

    #[test]
    fn host_prefixed_cookie_validation_fails_with_domain() {
        let cookie = Cookie::new("__Host-session", "abc123")
            .path("/")
            .secure(true)
            .domain("example.com");
        assert_eq!(
            cookie.validate_prefix(),
            Err(CookiePrefixError::HostCannotHaveDomain)
        );
    }

    #[test]
    fn host_prefixed_cookie_validation_fails_without_root_path() {
        let cookie = Cookie::new("__Host-session", "abc123")
            .path("/api")
            .secure(true);
        assert_eq!(
            cookie.validate_prefix(),
            Err(CookiePrefixError::HostRequiresRootPath)
        );
    }

    #[test]
    fn secure_prefixed_cookie() {
        let cookie = Cookie::secure_prefixed("token", "abc123");
        assert_eq!(cookie.name, "__Secure-token");
        assert!(cookie.secure);
        // Unlike `__Host-`, `__Secure-` allows a Domain and a non-root Path.
        let cookie = cookie.domain("example.com").path("/api");
        assert!(cookie.validate_prefix().is_ok());
    }

    #[test]
    fn secure_prefixed_cookie_validation_fails_without_secure() {
        let cookie = Cookie::new("__Secure-token", "abc123").secure(false);
        assert_eq!(
            cookie.validate_prefix(),
            Err(CookiePrefixError::SecureRequiresSecure)
        );
    }

    #[test]
    fn cookie_prefix_detection() {
        let host_cookie = Cookie::host_prefixed("session", "abc");
        assert!(host_cookie.has_security_prefix());
        assert_eq!(host_cookie.prefix(), Some(CookiePrefix::Host));

        let secure_cookie = Cookie::secure_prefixed("token", "abc");
        assert!(secure_cookie.has_security_prefix());
        assert_eq!(secure_cookie.prefix(), Some(CookiePrefix::Secure));

        let normal_cookie = Cookie::new("regular", "abc");
        assert!(!normal_cookie.has_security_prefix());
        assert_eq!(normal_cookie.prefix(), None);
    }

    #[test]
    fn cookie_prefix_as_str() {
        assert_eq!(CookiePrefix::Host.as_str(), "__Host-");
        assert_eq!(CookiePrefix::Secure.as_str(), "__Secure-");
    }

    #[test]
    fn cookie_prefix_error_display() {
        assert_eq!(
            CookiePrefixError::HostRequiresSecure.to_string(),
            "__Host- prefix requires Secure flag to be true"
        );
        assert_eq!(
            CookiePrefixError::HostCannotHaveDomain.to_string(),
            "__Host- prefix cannot have a Domain attribute"
        );
        assert_eq!(
            CookiePrefixError::HostRequiresRootPath.to_string(),
            "__Host- prefix requires Path=\"/\""
        );
        assert_eq!(
            CookiePrefixError::SecureRequiresSecure.to_string(),
            "__Secure- prefix requires Secure flag to be true"
        );
    }

    #[test]
    fn session_cookie_header_format() {
        let cookie = Cookie::session("sid", "abc", true);
        let header = cookie.to_header_value();
        assert!(header.contains("sid=abc"));
        assert!(header.contains("HttpOnly"));
        assert!(header.contains("Secure"));
        assert!(header.contains("SameSite=Lax"));
        assert!(header.contains("Path=/"));
    }

    #[test]
    fn host_prefixed_cookie_header_format() {
        let cookie = Cookie::host_prefixed("session", "abc")
            .http_only(true)
            .same_site(SameSite::Strict);
        let header = cookie.to_header_value();
        assert!(header.contains("__Host-session=abc"));
        assert!(header.contains("Secure"));
        assert!(header.contains("Path=/"));
        assert!(header.contains("HttpOnly"));
        assert!(header.contains("SameSite=Strict"));
    }

    // ---- Incoming `Cookie` header parsing ----

    #[test]
    fn request_cookies_parses_single_cookie() {
        let cookies = RequestCookies::from_header("session_id=abc123");
        assert_eq!(cookies.len(), 1);
        assert_eq!(cookies.get("session_id"), Some("abc123"));
    }

    #[test]
    fn request_cookies_parses_multiple_cookies() {
        let cookies = RequestCookies::from_header("session_id=abc123; user=bob; theme=dark");
        assert_eq!(cookies.len(), 3);
        assert_eq!(cookies.get("session_id"), Some("abc123"));
        assert_eq!(cookies.get("user"), Some("bob"));
        assert_eq!(cookies.get("theme"), Some("dark"));
    }

    #[test]
    fn request_cookies_handles_whitespace() {
        let cookies = RequestCookies::from_header(" session_id = abc123 ; user=bob ");
        assert_eq!(cookies.get("session_id"), Some("abc123"));
        assert_eq!(cookies.get("user"), Some("bob"));
    }

    #[test]
    fn request_cookies_handles_empty_header() {
        let cookies = RequestCookies::from_header("");
        assert!(cookies.is_empty());
    }

    #[test]
    fn request_cookies_handles_malformed_pairs() {
        // A segment without `=` is skipped rather than failing the whole parse.
        let cookies = RequestCookies::from_header("valid=value; malformed; another=good");
        assert_eq!(cookies.len(), 2);
        assert_eq!(cookies.get("valid"), Some("value"));
        assert_eq!(cookies.get("another"), Some("good"));
        assert!(!cookies.contains("malformed"));
    }

    #[test]
    fn request_cookies_contains_check() {
        let cookies = RequestCookies::from_header("session=abc");
        assert!(cookies.contains("session"));
        assert!(!cookies.contains("missing"));
    }

    #[test]
    fn request_cookies_iter() {
        let cookies = RequestCookies::from_header("a=1; b=2");
        let pairs: Vec<_> = cookies.iter().collect();
        assert_eq!(pairs.len(), 2);
        assert!(pairs.contains(&("a", "1")));
        assert!(pairs.contains(&("b", "2")));
    }

    #[test]
    fn request_cookies_from_request() {
        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/");
        req.headers_mut()
            .insert("cookie", b"session=xyz; user=alice".to_vec());

        let result = futures_executor::block_on(RequestCookies::from_request(&ctx, &mut req));
        let cookies = result.unwrap();
        assert_eq!(cookies.get("session"), Some("xyz"));
        assert_eq!(cookies.get("user"), Some("alice"));
    }

    #[test]
    fn request_cookies_from_request_no_cookie_header() {
        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/");

        let result = futures_executor::block_on(RequestCookies::from_request(&ctx, &mut req));
        let cookies = result.unwrap();
        assert!(cookies.is_empty());
    }

    // ---- Typed single-cookie extractor ----

    #[test]
    fn request_cookie_extractor_found() {
        #[derive(Debug)]
        struct TestCookie;
        impl CookieName for TestCookie {
            const NAME: &'static str = "test_cookie";
        }

        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/");
        req.headers_mut()
            .insert("cookie", b"test_cookie=hello_world".to_vec());

        let result =
            futures_executor::block_on(RequestCookie::<TestCookie>::from_request(&ctx, &mut req));
        let cookie = result.unwrap();
        assert_eq!(cookie.value(), "hello_world");
    }

    #[test]
    fn request_cookie_extractor_not_found() {
        #[derive(Debug)]
        struct MissingCookie;
        impl CookieName for MissingCookie {
            const NAME: &'static str = "missing";
        }

        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/");
        req.headers_mut().insert("cookie", b"other=value".to_vec());

        let result = futures_executor::block_on(RequestCookie::<MissingCookie>::from_request(
            &ctx, &mut req,
        ));
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(
            err,
            CookieExtractError::NotFound { name: "missing" }
        ));
    }

    #[test]
    fn request_cookie_deref() {
        #[derive(Debug)]
        struct TestCookie;
        impl CookieName for TestCookie {
            const NAME: &'static str = "test";
        }

        let cookie = RequestCookie::<TestCookie>::new("test_value");
        // Both `Deref` and `AsRef` expose the raw cookie value.
        assert_eq!(&*cookie, "test_value");
        assert_eq!(cookie.as_ref(), "test_value");
    }

    #[test]
    fn session_id_cookie_marker() {
        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/");
        req.headers_mut()
            .insert("cookie", b"session_id=sess123".to_vec());

        let result = futures_executor::block_on(RequestCookie::<SessionIdCookie>::from_request(
            &ctx, &mut req,
        ));
        let cookie = result.unwrap();
        assert_eq!(cookie.value(), "sess123");
    }

    #[test]
    fn csrf_token_cookie_marker() {
        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/");
        req.headers_mut()
            .insert("cookie", b"csrf_token=csrf_abc".to_vec());

        let result = futures_executor::block_on(RequestCookie::<CsrfTokenCookie>::from_request(
            &ctx, &mut req,
        ));
        let cookie = result.unwrap();
        assert_eq!(cookie.value(), "csrf_abc");
    }
}
8032
// Tests for `BackgroundTasks` / `BackgroundTasksInner`: queueing, extraction
// from a request, the execution helpers, panic-message formatting, and the
// "queued during the handler, run only when explicitly executed" timing.
#[cfg(test)]
mod background_tasks_tests {
    use super::*;
    use crate::request::Method;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};

    /// Builds a request context backed by a testing `Cx`.
    fn test_context() -> RequestContext {
        let cx = asupersync::Cx::for_testing();
        RequestContext::new(cx, 12345)
    }

    #[test]
    fn background_tasks_inner_new_is_empty() {
        let inner = BackgroundTasksInner::new();
        assert!(inner.take().is_empty());
    }

    #[test]
    fn background_tasks_inner_push_and_take() {
        let inner = BackgroundTasksInner::new();

        let executed = Arc::new(AtomicBool::new(false));
        let executed_clone = executed.clone();

        // Raw queue entries are closures that build a boxed, pinned future.
        inner.push(Box::new(move || {
            Box::pin(async move {
                executed_clone.store(true, Ordering::SeqCst);
            })
        }));

        let tasks = inner.take();
        assert_eq!(tasks.len(), 1);

        // Nothing runs until the closure is called and its future is driven.
        let task_fn = tasks.into_iter().next().unwrap();
        let future = task_fn();
        futures_executor::block_on(future);
        assert!(executed.load(Ordering::SeqCst));
    }

    #[test]
    fn background_tasks_inner_take_empties_queue() {
        let inner = BackgroundTasksInner::new();

        inner.push(Box::new(|| Box::pin(async {})));
        inner.push(Box::new(|| Box::pin(async {})));

        let tasks = inner.take();
        assert_eq!(tasks.len(), 2);

        // A second take sees the drained queue.
        let tasks = inner.take();
        assert!(tasks.is_empty());
    }

    #[test]
    fn background_tasks_add_async_task() {
        let mut tasks = BackgroundTasks::new();

        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = counter.clone();

        tasks.add_task(move || async move {
            counter_clone.fetch_add(1, Ordering::SeqCst);
        });

        let queued = tasks.take_tasks();
        assert_eq!(queued.len(), 1);

        let task_fn = queued.into_iter().next().unwrap();
        futures_executor::block_on(task_fn());
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn background_tasks_add_sync_task() {
        let mut tasks = BackgroundTasks::new();

        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = counter.clone();

        tasks.add_sync_task(move || {
            counter_clone.fetch_add(10, Ordering::SeqCst);
        });

        let queued = tasks.take_tasks();
        assert_eq!(queued.len(), 1);

        let task_fn = queued.into_iter().next().unwrap();
        futures_executor::block_on(task_fn());
        assert_eq!(counter.load(Ordering::SeqCst), 10);
    }

    #[test]
    fn background_tasks_multiple_tasks_execute_in_order() {
        let mut tasks = BackgroundTasks::new();

        let order = Arc::new(parking_lot::Mutex::new(Vec::new()));
        let order1 = order.clone();
        let order2 = order.clone();
        let order3 = order.clone();

        tasks.add_task(move || async move {
            order1.lock().push(1);
        });
        tasks.add_task(move || async move {
            order2.lock().push(2);
        });
        tasks.add_task(move || async move {
            order3.lock().push(3);
        });

        let queued = tasks.take_tasks();
        assert_eq!(queued.len(), 3);

        // `take_tasks` must preserve insertion (FIFO) order.
        for task_fn in queued {
            futures_executor::block_on(task_fn());
        }

        assert_eq!(*order.lock(), vec![1, 2, 3]);
    }

    #[test]
    fn background_tasks_from_request_creates_new() {
        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/");

        let tasks = futures_executor::block_on(BackgroundTasks::from_request(&ctx, &mut req))
            .expect("extraction should succeed");

        // A request without an existing queue yields a fresh, empty one.
        let inner_tasks = tasks.into_inner().take();
        assert!(inner_tasks.is_empty());
    }

    #[test]
    fn background_tasks_from_request_shares_inner() {
        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/");

        // First extraction creates the queue and stores it on the request.
        let mut tasks1 = futures_executor::block_on(BackgroundTasks::from_request(&ctx, &mut req))
            .expect("extraction should succeed");

        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = counter.clone();
        tasks1.add_task(move || async move {
            counter_clone.store(42, Ordering::SeqCst);
        });

        // Second extraction on the same request must see the same queue.
        let tasks2 = futures_executor::block_on(BackgroundTasks::from_request(&ctx, &mut req))
            .expect("extraction should succeed");

        let queued = tasks2.into_inner().take();
        assert_eq!(queued.len(), 1);

        let task_fn = queued.into_iter().next().unwrap();
        futures_executor::block_on(task_fn());
        assert_eq!(counter.load(Ordering::SeqCst), 42);
    }

    #[test]
    fn background_tasks_debug_shows_task_count() {
        let mut tasks = BackgroundTasks::new();

        assert_eq!(format!("{tasks:?}"), "BackgroundTasks { task_count: 0 }");

        tasks.add_task(|| async {});
        tasks.add_task(|| async {});

        // The count must track the queue, not just be present in the output.
        assert_eq!(format!("{tasks:?}"), "BackgroundTasks { task_count: 2 }");
    }

    /// Clones of `BackgroundTasksInner` are handles to one shared queue: a
    /// push through either clone is visible to `take` on the other.
    #[test]
    fn background_tasks_inner_clones_share_queue() {
        let inner = BackgroundTasksInner::new();
        let inner_clone = inner.clone();

        let counter = Arc::new(AtomicU32::new(0));
        let counter1 = counter.clone();
        let counter2 = counter.clone();

        inner.push(Box::new(move || {
            Box::pin(async move {
                counter1.fetch_add(1, Ordering::SeqCst);
            })
        }));
        inner_clone.push(Box::new(move || {
            Box::pin(async move {
                counter2.fetch_add(10, Ordering::SeqCst);
            })
        }));

        // Taking from the original also yields the task pushed via the clone.
        let tasks = inner.take();
        assert_eq!(tasks.len(), 2);

        for task_fn in tasks {
            futures_executor::block_on(task_fn());
        }

        assert_eq!(counter.load(Ordering::SeqCst), 11);
    }

    #[test]
    fn background_tasks_into_inner_conversion() {
        let mut tasks = BackgroundTasks::new();

        tasks.add_task(|| async {});
        tasks.add_task(|| async {});

        let inner = tasks.into_inner();
        let queued = inner.take();
        assert_eq!(queued.len(), 2);
    }

    #[test]
    fn background_tasks_is_empty_and_len() {
        let mut tasks = BackgroundTasks::new();

        assert!(tasks.is_empty());
        assert_eq!(tasks.len(), 0);

        tasks.add_task(|| async {});
        assert!(!tasks.is_empty());
        assert_eq!(tasks.len(), 1);

        tasks.add_task(|| async {});
        assert_eq!(tasks.len(), 2);
    }

    #[test]
    fn background_tasks_inner_len_and_is_empty() {
        let inner = BackgroundTasksInner::new();

        assert!(inner.is_empty());
        assert_eq!(inner.len(), 0);

        inner.push(Box::new(|| Box::pin(async {})));
        assert!(!inner.is_empty());
        assert_eq!(inner.len(), 1);

        inner.push(Box::new(|| Box::pin(async {})));
        assert_eq!(inner.len(), 2);

        // Draining resets both counters.
        let _ = inner.take();
        assert!(inner.is_empty());
        assert_eq!(inner.len(), 0);
    }

    #[test]
    fn background_tasks_execute_all_runs_all_tasks() {
        let mut tasks = BackgroundTasks::new();

        let counter = Arc::new(AtomicU32::new(0));
        let c1 = counter.clone();
        let c2 = counter.clone();
        let c3 = counter.clone();

        tasks.add_task(move || async move {
            c1.fetch_add(1, Ordering::SeqCst);
        });
        tasks.add_task(move || async move {
            c2.fetch_add(10, Ordering::SeqCst);
        });
        tasks.add_task(move || async move {
            c3.fetch_add(100, Ordering::SeqCst);
        });

        futures_executor::block_on(tasks.execute_all());
        assert_eq!(counter.load(Ordering::SeqCst), 111);
    }

    /// With a live (never cancelled) context, `execute_with_context` runs
    /// every queued task. Cancellation itself is not exercised here.
    #[test]
    fn background_tasks_execute_with_context_runs_all_when_not_cancelled() {
        let ctx = test_context();
        let mut tasks = BackgroundTasks::new();

        let counter = Arc::new(AtomicU32::new(0));
        let c1 = counter.clone();
        let c2 = counter.clone();

        tasks.add_task(move || async move {
            c1.fetch_add(1, Ordering::SeqCst);
        });
        tasks.add_task(move || async move {
            c2.fetch_add(10, Ordering::SeqCst);
        });

        futures_executor::block_on(tasks.execute_with_context(&ctx));
        assert_eq!(counter.load(Ordering::SeqCst), 11);
    }

    #[test]
    fn background_tasks_execute_with_panic_isolation_handles_closure_panic() {
        let mut tasks = BackgroundTasks::new();

        let counter = Arc::new(AtomicU32::new(0));
        let c1 = counter.clone();
        let c2 = counter.clone();

        // Task 1: succeeds.
        tasks.add_task(move || async move {
            c1.fetch_add(1, Ordering::SeqCst);
        });

        // Task 2: its closure panics before producing a future. Pushed on the
        // raw queue because `add_task` would wrap it in an `async` block.
        tasks.inner.push(Box::new(|| {
            panic!("intentional test panic");
        }));

        // Task 3: must still run despite task 2 panicking.
        tasks.add_task(move || async move {
            c2.fetch_add(100, Ordering::SeqCst);
        });

        futures_executor::block_on(tasks.execute_with_panic_isolation());

        // Tasks 1 and 3 both ran.
        assert_eq!(counter.load(Ordering::SeqCst), 101);
    }

    #[test]
    fn format_panic_message_extracts_str() {
        let panic_info: Box<dyn std::any::Any + Send> = Box::new("test panic message");
        let msg = super::format_panic_message(&panic_info);
        assert_eq!(msg, "test panic message");
    }

    #[test]
    fn format_panic_message_extracts_string() {
        let panic_info: Box<dyn std::any::Any + Send> = Box::new(String::from("string panic"));
        let msg = super::format_panic_message(&panic_info);
        assert_eq!(msg, "string panic");
    }

    #[test]
    fn format_panic_message_handles_unknown() {
        let panic_info: Box<dyn std::any::Any + Send> = Box::new(42i32);
        let msg = super::format_panic_message(&panic_info);
        assert_eq!(msg, "unknown panic");
    }

    // ---- Timing: tasks queued by a handler only run when executed later ----

    #[test]
    fn background_tasks_timing_single_task_after_response() {
        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/");

        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = counter.clone();

        // Handler phase: extract and queue a task.
        let mut tasks = futures_executor::block_on(BackgroundTasks::from_request(&ctx, &mut req))
            .expect("extraction should succeed");

        tasks.add_task(move || async move {
            counter_clone.store(42, Ordering::SeqCst);
        });

        assert_eq!(
            counter.load(Ordering::SeqCst),
            0,
            "task should not run before take_tasks"
        );

        // Post-response phase: recover the queue from the request extension.
        let taken_tasks = req
            .get_extension::<BackgroundTasksInner>()
            .map(|inner| BackgroundTasks::from_inner(inner.clone()))
            .expect("tasks should be in extension");

        assert_eq!(
            counter.load(Ordering::SeqCst),
            0,
            "task should not run before execute_all"
        );

        futures_executor::block_on(taken_tasks.execute_all());

        assert_eq!(
            counter.load(Ordering::SeqCst),
            42,
            "task should have executed"
        );
    }

    #[test]
    fn background_tasks_timing_multiple_tasks_in_order() {
        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/");

        let execution_order = Arc::new(parking_lot::Mutex::new(Vec::new()));
        let order1 = execution_order.clone();
        let order2 = execution_order.clone();
        let order3 = execution_order.clone();

        let mut tasks = futures_executor::block_on(BackgroundTasks::from_request(&ctx, &mut req))
            .expect("extraction should succeed");

        tasks.add_task(move || async move {
            order1.lock().push(1);
        });
        tasks.add_task(move || async move {
            order2.lock().push(2);
        });
        tasks.add_task(move || async move {
            order3.lock().push(3);
        });

        assert!(
            execution_order.lock().is_empty(),
            "no tasks should run during response building"
        );

        let taken_tasks = req
            .get_extension::<BackgroundTasksInner>()
            .map(|inner| BackgroundTasks::from_inner(inner.clone()))
            .expect("tasks should be in extension");

        futures_executor::block_on(taken_tasks.execute_all());

        assert_eq!(
            *execution_order.lock(),
            vec![1, 2, 3],
            "tasks should execute in queue order"
        );
    }

    /// A second batch created after the first has been executed also runs.
    /// Tasks here do not enqueue further tasks from inside their own bodies.
    #[test]
    fn background_tasks_timing_tasks_can_spawn_more_tasks() {
        let mut tasks = BackgroundTasks::new();

        let counter = Arc::new(AtomicU32::new(0));
        let c1 = counter.clone();
        let c2 = counter.clone();
        let c3 = counter.clone();

        tasks.add_task(move || async move {
            c1.fetch_add(1, Ordering::SeqCst);
        });
        tasks.add_task(move || async move {
            c2.fetch_add(10, Ordering::SeqCst);
        });

        futures_executor::block_on(tasks.execute_all());
        assert_eq!(
            counter.load(Ordering::SeqCst),
            11,
            "first batch should complete"
        );

        let mut more_tasks = BackgroundTasks::new();
        more_tasks.add_task(move || async move {
            c3.fetch_add(100, Ordering::SeqCst);
        });

        futures_executor::block_on(more_tasks.execute_all());
        assert_eq!(
            counter.load(Ordering::SeqCst),
            111,
            "spawned tasks should also run"
        );
    }

    #[test]
    fn background_tasks_timing_independent_requests() {
        let ctx1 = test_context();
        let ctx2 = test_context();

        let mut req1 = Request::new(Method::Get, "/request1");
        let mut req2 = Request::new(Method::Get, "/request2");

        let counter1 = Arc::new(AtomicU32::new(0));
        let counter2 = Arc::new(AtomicU32::new(0));
        let c1 = counter1.clone();
        let c2 = counter2.clone();

        // Each request gets its own queue via its own extensions.
        let mut tasks1 =
            futures_executor::block_on(BackgroundTasks::from_request(&ctx1, &mut req1))
                .expect("extraction should succeed");
        tasks1.add_task(move || async move {
            c1.store(100, Ordering::SeqCst);
        });

        let mut tasks2 =
            futures_executor::block_on(BackgroundTasks::from_request(&ctx2, &mut req2))
                .expect("extraction should succeed");
        tasks2.add_task(move || async move {
            c2.store(200, Ordering::SeqCst);
        });

        let taken1 = req1
            .get_extension::<BackgroundTasksInner>()
            .map(|inner| BackgroundTasks::from_inner(inner.clone()))
            .expect("tasks should be in extension");
        futures_executor::block_on(taken1.execute_all());

        assert_eq!(
            counter1.load(Ordering::SeqCst),
            100,
            "request 1 task should run"
        );
        assert_eq!(
            counter2.load(Ordering::SeqCst),
            0,
            "request 2 task should not run yet"
        );

        let taken2 = req2
            .get_extension::<BackgroundTasksInner>()
            .map(|inner| BackgroundTasks::from_inner(inner.clone()))
            .expect("tasks should be in extension");
        futures_executor::block_on(taken2.execute_all());

        assert_eq!(
            counter1.load(Ordering::SeqCst),
            100,
            "request 1 task unchanged"
        );
        assert_eq!(
            counter2.load(Ordering::SeqCst),
            200,
            "request 2 task should run"
        );
    }

    /// Queuing tasks on one request does not run them before a later request
    /// is handled. Both queues stay pending until explicitly executed.
    #[test]
    fn background_tasks_timing_nonblocking_next_request() {
        let ctx = test_context();
        let mut req1 = Request::new(Method::Get, "/first");
        let mut req2 = Request::new(Method::Get, "/second");

        let req1_done = Arc::new(AtomicBool::new(false));
        let req2_done = Arc::new(AtomicBool::new(false));
        let r1 = req1_done.clone();
        let r2 = req2_done.clone();

        let mut tasks1 = futures_executor::block_on(BackgroundTasks::from_request(&ctx, &mut req1))
            .expect("extraction should succeed");
        tasks1.add_task(move || async move {
            r1.store(true, Ordering::SeqCst);
        });

        let mut tasks2 = futures_executor::block_on(BackgroundTasks::from_request(&ctx, &mut req2))
            .expect("extraction should succeed");
        tasks2.add_task(move || async move {
            r2.store(true, Ordering::SeqCst);
        });

        assert!(
            !req1_done.load(Ordering::SeqCst),
            "req1 tasks not yet executed"
        );
        assert!(
            !req2_done.load(Ordering::SeqCst),
            "req2 tasks not yet executed"
        );

        let taken1 = req1
            .get_extension::<BackgroundTasksInner>()
            .map(|inner| BackgroundTasks::from_inner(inner.clone()))
            .expect("tasks should be in extension");
        let taken2 = req2
            .get_extension::<BackgroundTasksInner>()
            .map(|inner| BackgroundTasks::from_inner(inner.clone()))
            .expect("tasks should be in extension");

        futures_executor::block_on(taken1.execute_all());
        futures_executor::block_on(taken2.execute_all());

        assert!(
            req1_done.load(Ordering::SeqCst),
            "req1 tasks should be done"
        );
        assert!(
            req2_done.load(Ordering::SeqCst),
            "req2 tasks should be done"
        );
    }
}
8673
/// A header value of type `T` kept together with the name of the header it came from.
///
/// Dereferences to the wrapped value.
#[derive(Debug, Clone)]
pub struct Header<T> {
    /// The header value.
    pub value: T,
    /// The header name.
    pub name: String,
}

impl<T> Header<T> {
    /// Pairs `value` with the header `name`.
    #[must_use]
    pub fn new(name: impl Into<String>, value: T) -> Self {
        let name: String = name.into();
        Self { name, value }
    }

    /// Discards the name and returns the value.
    #[must_use]
    pub fn into_inner(self) -> T {
        let Self { value, .. } = self;
        value
    }
}

impl<T> Deref for Header<T> {
    type Target = T;

    fn deref(&self) -> &T {
        let Self { value, .. } = self;
        value
    }
}

impl<T> DerefMut for Header<T> {
    fn deref_mut(&mut self) -> &mut T {
        let Self { value, .. } = self;
        value
    }
}
8743
/// Converts a `snake_case` identifier into `Header-Case`.
///
/// The input is split on `_`. The first character of each segment is upper-cased
/// and the segments are joined with `-`. Only the first character of a segment
/// changes; the rest is kept verbatim (`"content_TYPE"` → `"Content-TYPE"`).
/// Empty segments are preserved (`"a__b"` → `"A--B"`).
///
/// ```text
/// x_request_id -> X-Request-Id
/// ```
#[must_use]
pub fn snake_to_header_case(name: &str) -> String {
    let mut out = String::with_capacity(name.len());
    for (idx, segment) in name.split('_').enumerate() {
        if idx > 0 {
            out.push('-');
        }
        let mut chars = segment.chars();
        if let Some(head) = chars.next() {
            // `to_uppercase` may expand to several chars (e.g. 'ß' -> "SS").
            out.extend(head.to_uppercase());
            out.push_str(chars.as_str());
        }
    }
    out
}
8767
/// Reasons a typed header could not be extracted from a request.
#[derive(Debug)]
pub enum HeaderExtractError {
    /// The header was not present on the request.
    MissingHeader {
        /// Name of the header that was looked up.
        name: String,
    },
    /// The header's raw bytes were not valid UTF-8.
    InvalidUtf8 {
        /// Name of the offending header.
        name: String,
    },
    /// The header text could not be converted into the requested type.
    ParseError {
        /// Name of the offending header.
        name: String,
        /// The raw header text that failed to parse.
        value: String,
        /// Name of the target type.
        expected: &'static str,
        /// Parser-supplied explanation of the failure.
        message: String,
    },
}

impl std::fmt::Display for HeaderExtractError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::MissingHeader { name: header } => {
                write!(f, "Missing required header: {header}")
            }
            Self::InvalidUtf8 { name: header } => {
                write!(f, "Header '{header}' contains invalid UTF-8")
            }
            Self::ParseError {
                name: header,
                value: raw,
                expected: target,
                message: reason,
            } => write!(
                f,
                "Failed to parse header '{header}' value '{raw}' as {target}: {reason}"
            ),
        }
    }
}

impl std::error::Error for HeaderExtractError {}
8819
8820impl IntoResponse for HeaderExtractError {
8821 fn into_response(self) -> crate::response::Response {
8822 let error = match &self {
8824 HeaderExtractError::MissingHeader { name } => {
8825 ValidationError::missing(crate::error::loc::header(name))
8826 .with_msg(format!("Missing required header: {name}"))
8827 }
8828 HeaderExtractError::InvalidUtf8 { name } => {
8829 ValidationError::type_error(crate::error::loc::header(name), "string")
8830 .with_msg(format!("Header '{name}' contains invalid UTF-8"))
8831 }
8832 HeaderExtractError::ParseError {
8833 name,
8834 value,
8835 expected,
8836 message,
8837 } => ValidationError::type_error(crate::error::loc::header(name), expected)
8838 .with_msg(format!("Failed to parse as {expected}: {message}"))
8839 .with_input(serde_json::Value::String(value.clone())),
8840 };
8841 ValidationErrors::single(error).into_response()
8842 }
8843}
8844
/// Conversion from a header's textual value into a typed value.
pub trait FromHeaderValue: Sized {
    /// Parses `value`, returning a human-readable message on failure.
    fn from_header_value(value: &str) -> Result<Self, String>;

    /// Short name of the target type, used in error messages.
    fn type_name() -> &'static str;
}

impl FromHeaderValue for String {
    fn from_header_value(value: &str) -> Result<Self, String> {
        Ok(String::from(value))
    }

    fn type_name() -> &'static str {
        "String"
    }
}

// Integer impls use `str::parse`. On failure the message is the std
// `ParseIntError` text, and the reported type name is the Rust type name.
macro_rules! impl_from_header_value_for_int {
    ($($int:ident),+ $(,)?) => {
        $(
            impl FromHeaderValue for $int {
                fn from_header_value(value: &str) -> Result<Self, String> {
                    value.parse::<$int>().map_err(|err| err.to_string())
                }

                fn type_name() -> &'static str {
                    stringify!($int)
                }
            }
        )+
    };
}

impl_from_header_value_for_int!(i32, i64, u32, u64);

impl FromHeaderValue for bool {
    /// Accepts `true`/`1`/`yes`/`on` and `false`/`0`/`no`/`off`, ignoring ASCII case.
    fn from_header_value(value: &str) -> Result<Self, String> {
        const TRUTHY: [&str; 4] = ["true", "1", "yes", "on"];
        const FALSY: [&str; 4] = ["false", "0", "no", "off"];

        let is_one_of = |words: &[&str]| words.iter().any(|w| value.eq_ignore_ascii_case(w));
        if is_one_of(&TRUTHY) {
            Ok(true)
        } else if is_one_of(&FALSY) {
            Ok(false)
        } else {
            Err(format!("invalid boolean: {value}"))
        }
    }

    fn type_name() -> &'static str {
        "bool"
    }
}
8917
/// A header looked up by a compile-time name and parsed into `T`.
///
/// `N` is a [`HeaderName`] marker type that supplies the header name. Dereferences
/// to the parsed value.
#[derive(Debug, Clone)]
pub struct NamedHeader<T, N> {
    /// The parsed header value.
    pub value: T,
    /// Ties the marker `N` to this type without storing a value of it.
    _marker: std::marker::PhantomData<N>,
}

/// Associates a marker type with the name of an HTTP header.
pub trait HeaderName {
    /// The header name looked up in the request.
    const NAME: &'static str;
}

impl<T, N> NamedHeader<T, N> {
    /// Wraps an already-parsed value.
    #[must_use]
    pub fn new(value: T) -> Self {
        Self {
            _marker: std::marker::PhantomData,
            value,
        }
    }

    /// Returns the parsed value, discarding the marker.
    #[must_use]
    pub fn into_inner(self) -> T {
        let Self { value, .. } = self;
        value
    }
}

impl<T, N> Deref for NamedHeader<T, N> {
    type Target = T;

    fn deref(&self) -> &T {
        let Self { value, .. } = self;
        value
    }
}

impl<T, N> DerefMut for NamedHeader<T, N> {
    fn deref_mut(&mut self) -> &mut T {
        let Self { value, .. } = self;
        value
    }
}
8982
8983impl<T, N> FromRequest for NamedHeader<T, N>
8984where
8985 T: FromHeaderValue + Send + Sync + 'static,
8986 N: HeaderName + Send + Sync + 'static,
8987{
8988 type Error = HeaderExtractError;
8989
8990 async fn from_request(_ctx: &RequestContext, req: &mut Request) -> Result<Self, Self::Error> {
8991 let header_name = N::NAME;
8992
8993 let value_bytes =
8994 req.headers()
8995 .get(header_name)
8996 .ok_or_else(|| HeaderExtractError::MissingHeader {
8997 name: header_name.to_string(),
8998 })?;
8999
9000 let value_str =
9001 std::str::from_utf8(value_bytes).map_err(|_| HeaderExtractError::InvalidUtf8 {
9002 name: header_name.to_string(),
9003 })?;
9004
9005 let value =
9006 T::from_header_value(value_str).map_err(|message| HeaderExtractError::ParseError {
9007 name: header_name.to_string(),
9008 value: value_str.to_string(),
9009 expected: T::type_name(),
9010 message,
9011 })?;
9012
9013 Ok(NamedHeader::new(value))
9014 }
9015}
9016
9017pub struct Authorization;
9020impl HeaderName for Authorization {
9021 const NAME: &'static str = "authorization";
9022}
9023
9024pub struct ContentType;
9026impl HeaderName for ContentType {
9027 const NAME: &'static str = "content-type";
9028}
9029
9030pub struct Accept;
9032impl HeaderName for Accept {
9033 const NAME: &'static str = "accept";
9034}
9035
9036pub struct XRequestId;
9038impl HeaderName for XRequestId {
9039 const NAME: &'static str = "x-request-id";
9040}
9041
9042pub struct UserAgent;
9044impl HeaderName for UserAgent {
9045 const NAME: &'static str = "user-agent";
9046}
9047
9048pub struct Host;
9050impl HeaderName for Host {
9051 const NAME: &'static str = "host";
9052}
9053
/// A bearer token taken from an `Authorization: Bearer <token>` header.
///
/// Dereferences to the token text.
#[derive(Debug, Clone)]
pub struct OAuth2PasswordBearer {
    /// The token text. When built by the extractor, the scheme and surrounding
    /// whitespace have been stripped.
    pub token: String,
}

impl OAuth2PasswordBearer {
    /// Wraps an already-extracted token.
    #[must_use]
    pub fn new(token: impl Into<String>) -> Self {
        let token = token.into();
        Self { token }
    }

    /// Borrows the token text.
    #[must_use]
    pub fn token(&self) -> &str {
        self.token.as_str()
    }

    /// Consumes the wrapper and returns the owned token.
    #[must_use]
    pub fn into_token(self) -> String {
        let Self { token } = self;
        token
    }
}

impl Deref for OAuth2PasswordBearer {
    type Target = str;

    fn deref(&self) -> &str {
        self.token()
    }
}
9131
/// Configuration describing an OAuth2 password-flow bearer scheme.
///
/// NOTE(review): the consumer of this config is not visible here; presumably it
/// feeds security-scheme documentation (e.g. OpenAPI). Confirm.
#[derive(Debug, Clone)]
pub struct OAuth2PasswordBearerConfig {
    /// Token endpoint URL (`/token` by default).
    pub token_url: String,
    /// Optional refresh endpoint URL.
    pub refresh_url: Option<String>,
    /// Scope name -> human-readable description.
    pub scopes: std::collections::HashMap<String, String>,
    /// Optional override for the scheme name.
    pub scheme_name: Option<String>,
    /// Optional free-form description.
    pub description: Option<String>,
    /// Defaults to `true`. Presumably mirrors FastAPI's `auto_error`; confirm with the consumer.
    pub auto_error: bool,
}

impl Default for OAuth2PasswordBearerConfig {
    fn default() -> Self {
        Self::new("/token")
    }
}

impl OAuth2PasswordBearerConfig {
    /// Creates a config for `token_url` with every optional setting unset,
    /// no scopes and `auto_error` enabled.
    #[must_use]
    pub fn new(token_url: impl Into<String>) -> Self {
        Self {
            token_url: token_url.into(),
            refresh_url: None,
            scopes: std::collections::HashMap::new(),
            scheme_name: None,
            description: None,
            auto_error: true,
        }
    }

    /// Sets the refresh endpoint URL.
    #[must_use]
    pub fn with_refresh_url(self, url: impl Into<String>) -> Self {
        Self {
            refresh_url: Some(url.into()),
            ..self
        }
    }

    /// Adds (or replaces the description of) one scope.
    #[must_use]
    pub fn with_scope(mut self, scope: impl Into<String>, description: impl Into<String>) -> Self {
        let (scope, description) = (scope.into(), description.into());
        self.scopes.insert(scope, description);
        self
    }

    /// Overrides the scheme name.
    #[must_use]
    pub fn with_scheme_name(self, name: impl Into<String>) -> Self {
        Self {
            scheme_name: Some(name.into()),
            ..self
        }
    }

    /// Sets the scheme description.
    #[must_use]
    pub fn with_description(self, desc: impl Into<String>) -> Self {
        Self {
            description: Some(desc.into()),
            ..self
        }
    }

    /// Sets the `auto_error` flag.
    #[must_use]
    pub fn with_auto_error(self, auto_error: bool) -> Self {
        Self { auto_error, ..self }
    }
}
9210
/// Why a bearer token could not be extracted from the `Authorization` header.
#[derive(Debug, Clone)]
pub struct OAuth2BearerError {
    /// The specific failure.
    pub kind: OAuth2BearerErrorKind,
}

/// Categories of [`OAuth2BearerError`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OAuth2BearerErrorKind {
    /// No `Authorization` header was present.
    MissingHeader,
    /// The header was not valid UTF-8 or did not use the `Bearer` scheme.
    InvalidScheme,
    /// The `Bearer` scheme was present but the token was empty after trimming.
    EmptyToken,
}

impl OAuth2BearerError {
    /// Error for a request without an `Authorization` header.
    #[must_use]
    pub fn missing_header() -> Self {
        let kind = OAuth2BearerErrorKind::MissingHeader;
        Self { kind }
    }

    /// Error for an `Authorization` header that does not carry a Bearer credential.
    #[must_use]
    pub fn invalid_scheme() -> Self {
        let kind = OAuth2BearerErrorKind::InvalidScheme;
        Self { kind }
    }

    /// Error for a Bearer credential whose token is empty.
    #[must_use]
    pub fn empty_token() -> Self {
        let kind = OAuth2BearerErrorKind::EmptyToken;
        Self { kind }
    }
}

impl fmt::Display for OAuth2BearerError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let text = match self.kind {
            OAuth2BearerErrorKind::MissingHeader => "Missing Authorization header",
            OAuth2BearerErrorKind::InvalidScheme => "Authorization header must use Bearer scheme",
            OAuth2BearerErrorKind::EmptyToken => "Bearer token is empty",
        };
        f.write_str(text)
    }
}
9270
9271impl IntoResponse for OAuth2BearerError {
9272 fn into_response(self) -> crate::response::Response {
9273 use crate::response::{Response, ResponseBody, StatusCode};
9274
9275 let message = match self.kind {
9276 OAuth2BearerErrorKind::MissingHeader => "Not authenticated",
9277 OAuth2BearerErrorKind::InvalidScheme => "Invalid authentication credentials",
9278 OAuth2BearerErrorKind::EmptyToken => "Invalid authentication credentials",
9279 };
9280
9281 let body = serde_json::json!({
9282 "detail": message
9283 });
9284
9285 Response::with_status(StatusCode::UNAUTHORIZED)
9286 .header("www-authenticate", b"Bearer".to_vec())
9287 .header("content-type", b"application/json".to_vec())
9288 .body(ResponseBody::Bytes(body.to_string().into_bytes()))
9289 }
9290}
9291
9292impl FromRequest for OAuth2PasswordBearer {
9293 type Error = OAuth2BearerError;
9294
9295 async fn from_request(_ctx: &RequestContext, req: &mut Request) -> Result<Self, Self::Error> {
9296 let auth_header = req
9298 .headers()
9299 .get("authorization")
9300 .ok_or_else(OAuth2BearerError::missing_header)?;
9301
9302 let auth_str =
9304 std::str::from_utf8(auth_header).map_err(|_| OAuth2BearerError::invalid_scheme())?;
9305
9306 const BEARER_PREFIX: &str = "Bearer ";
9308 const BEARER_PREFIX_LOWER: &str = "bearer ";
9309
9310 let token = if auth_str.starts_with(BEARER_PREFIX) {
9311 &auth_str[BEARER_PREFIX.len()..]
9312 } else if auth_str.starts_with(BEARER_PREFIX_LOWER) {
9313 &auth_str[BEARER_PREFIX_LOWER.len()..]
9314 } else {
9315 return Err(OAuth2BearerError::invalid_scheme());
9316 };
9317
9318 let token = token.trim();
9320 if token.is_empty() {
9321 return Err(OAuth2BearerError::empty_token());
9322 }
9323
9324 Ok(OAuth2PasswordBearer::new(token))
9325 }
9326}
9327
/// Failures produced while extracting an OAuth2 password-flow form.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum OAuth2PasswordFormError {
    /// `Content-Type` was absent or was not a URL-encoded form.
    UnsupportedMediaType {
        /// The `Content-Type` that was sent, if present and valid UTF-8.
        actual: Option<String>,
    },
    /// The body exceeded the configured size limit.
    PayloadTooLarge {
        /// Actual body size in bytes.
        size: usize,
        /// Maximum allowed size in bytes.
        limit: usize,
    },
    /// The `username` field was absent.
    MissingUsername,
    /// The `password` field was absent.
    MissingPassword,
    /// `grant_type` was present but not `password` (strict extractor only).
    InvalidGrantType {
        /// The value that was sent.
        actual: String,
    },
    /// `grant_type` was absent (strict extractor only).
    MissingGrantType,
    /// The body was not valid UTF-8.
    InvalidUtf8,
    /// The body was a stream rather than buffered bytes.
    StreamingNotSupported,
}

impl std::fmt::Display for OAuth2PasswordFormError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::UnsupportedMediaType { actual: Some(ct) } => {
                write!(f, "Expected application/x-www-form-urlencoded, got: {ct}")
            }
            Self::UnsupportedMediaType { actual: None } => {
                f.write_str("Missing Content-Type header")
            }
            Self::PayloadTooLarge { size, limit } => {
                write!(f, "Body too large: {size} > {limit}")
            }
            Self::MissingUsername => f.write_str("Missing required field: username"),
            Self::MissingPassword => f.write_str("Missing required field: password"),
            Self::MissingGrantType => f.write_str("Missing required field: grant_type"),
            Self::InvalidGrantType { actual } => {
                write!(f, "grant_type must be \"password\", got: \"{actual}\"")
            }
            Self::InvalidUtf8 => f.write_str("Invalid UTF-8 in form body"),
            Self::StreamingNotSupported => f.write_str("Streaming bodies not supported"),
        }
    }
}

impl std::error::Error for OAuth2PasswordFormError {}
9390
9391impl IntoResponse for OAuth2PasswordFormError {
9392 fn into_response(self) -> Response {
9393 match &self {
9394 OAuth2PasswordFormError::UnsupportedMediaType { .. } => {
9395 HttpError::unsupported_media_type().into_response()
9396 }
9397 OAuth2PasswordFormError::PayloadTooLarge { size, limit } => {
9398 HttpError::payload_too_large()
9399 .with_detail(format!("Body {size} > {limit}"))
9400 .into_response()
9401 }
9402 OAuth2PasswordFormError::MissingUsername => ValidationErrors::single(
9403 ValidationError::new(
9404 crate::error::error_types::MISSING,
9405 vec![
9406 crate::error::LocItem::field("body"),
9407 crate::error::LocItem::field("username"),
9408 ],
9409 )
9410 .with_msg("Field required".to_string()),
9411 )
9412 .into_response(),
9413 OAuth2PasswordFormError::MissingPassword => ValidationErrors::single(
9414 ValidationError::new(
9415 crate::error::error_types::MISSING,
9416 vec![
9417 crate::error::LocItem::field("body"),
9418 crate::error::LocItem::field("password"),
9419 ],
9420 )
9421 .with_msg("Field required".to_string()),
9422 )
9423 .into_response(),
9424 OAuth2PasswordFormError::InvalidGrantType { actual } => ValidationErrors::single(
9425 ValidationError::new(
9426 crate::error::error_types::VALUE_ERROR,
9427 vec![
9428 crate::error::LocItem::field("body"),
9429 crate::error::LocItem::field("grant_type"),
9430 ],
9431 )
9432 .with_msg(format!(
9433 "grant_type must be \"password\", got: \"{actual}\""
9434 )),
9435 )
9436 .into_response(),
9437 OAuth2PasswordFormError::MissingGrantType => ValidationErrors::single(
9438 ValidationError::new(
9439 crate::error::error_types::MISSING,
9440 vec![
9441 crate::error::LocItem::field("body"),
9442 crate::error::LocItem::field("grant_type"),
9443 ],
9444 )
9445 .with_msg("Field required".to_string()),
9446 )
9447 .into_response(),
9448 OAuth2PasswordFormError::InvalidUtf8 => HttpError::bad_request()
9449 .with_detail("Invalid UTF-8")
9450 .into_response(),
9451 OAuth2PasswordFormError::StreamingNotSupported => {
9452 HttpError::bad_request().into_response()
9453 }
9454 }
9455 }
9456}
9457
/// Form data for the OAuth2 "password" grant, sent as
/// `application/x-www-form-urlencoded`.
///
/// The extractor for this type is lenient: only `username` and `password` are
/// required, and `grant_type` is captured but not validated.
/// Use [`OAuth2PasswordRequestFormStrict`] to require `grant_type=password`.
#[derive(Debug, Clone)]
pub struct OAuth2PasswordRequestForm {
    /// The `grant_type` field, if the client sent one.
    pub grant_type: Option<String>,
    /// The required `username` field.
    pub username: String,
    /// The required `password` field.
    pub password: String,
    /// The raw space-separated `scope` field. Empty when the client omitted it.
    pub scope: String,
    /// The optional `client_id` field.
    pub client_id: Option<String>,
    /// The optional `client_secret` field.
    pub client_secret: Option<String>,
}

impl OAuth2PasswordRequestForm {
    /// Splits [`scope`](Self::scope) into individual scope names.
    ///
    /// Scope tokens are whitespace-separated (RFC 6749 §3.3). Repeated, leading or
    /// trailing whitespace never produces empty entries, which matches FastAPI's
    /// `scope.split()`. An empty or all-whitespace `scope` yields an empty `Vec`.
    #[must_use]
    pub fn scopes(&self) -> Vec<String> {
        self.scope.split_whitespace().map(String::from).collect()
    }
}
9517
9518fn extract_form_body(
9520 ctx: &RequestContext,
9521 req: &mut Request,
9522) -> Result<QueryParams, OAuth2PasswordFormError> {
9523 let ct = req
9525 .headers()
9526 .get("content-type")
9527 .and_then(|v| std::str::from_utf8(v).ok());
9528 let is_form = ct.is_some_and(|c| {
9529 c.to_ascii_lowercase()
9530 .starts_with("application/x-www-form-urlencoded")
9531 });
9532 if !is_form {
9533 return Err(OAuth2PasswordFormError::UnsupportedMediaType {
9534 actual: ct.map(String::from),
9535 });
9536 }
9537
9538 let body = req.take_body();
9540 let bytes = match body {
9541 Body::Empty => Vec::new(),
9542 Body::Bytes(b) => b,
9543 Body::Stream(_) => return Err(OAuth2PasswordFormError::StreamingNotSupported),
9544 };
9545
9546 let limit = ctx.max_body_size();
9548 if bytes.len() > limit {
9549 return Err(OAuth2PasswordFormError::PayloadTooLarge {
9550 size: bytes.len(),
9551 limit,
9552 });
9553 }
9554
9555 let body_str = std::str::from_utf8(&bytes).map_err(|_| OAuth2PasswordFormError::InvalidUtf8)?;
9557 Ok(QueryParams::parse(body_str))
9558}
9559
9560impl FromRequest for OAuth2PasswordRequestForm {
9561 type Error = OAuth2PasswordFormError;
9562
9563 async fn from_request(ctx: &RequestContext, req: &mut Request) -> Result<Self, Self::Error> {
9564 let _ = ctx.checkpoint();
9565 let params = extract_form_body(ctx, req)?;
9566 let _ = ctx.checkpoint();
9567
9568 let username = params
9569 .get("username")
9570 .ok_or(OAuth2PasswordFormError::MissingUsername)?
9571 .to_string();
9572
9573 let password = params
9574 .get("password")
9575 .ok_or(OAuth2PasswordFormError::MissingPassword)?
9576 .to_string();
9577
9578 let grant_type = params.get("grant_type").map(String::from);
9579 let scope = params.get("scope").map(String::from).unwrap_or_default();
9580 let client_id = params.get("client_id").map(String::from);
9581 let client_secret = params.get("client_secret").map(String::from);
9582
9583 let _ = ctx.checkpoint();
9584 Ok(OAuth2PasswordRequestForm {
9585 grant_type,
9586 username,
9587 password,
9588 scope,
9589 client_id,
9590 client_secret,
9591 })
9592 }
9593}
9594
9595#[derive(Debug, Clone)]
9614pub struct OAuth2PasswordRequestFormStrict {
9615 pub form: OAuth2PasswordRequestForm,
9617}
9618
9619impl OAuth2PasswordRequestFormStrict {
9620 #[must_use]
9622 pub fn inner(&self) -> &OAuth2PasswordRequestForm {
9623 &self.form
9624 }
9625
9626 #[must_use]
9628 pub fn into_inner(self) -> OAuth2PasswordRequestForm {
9629 self.form
9630 }
9631}
9632
9633impl std::ops::Deref for OAuth2PasswordRequestFormStrict {
9634 type Target = OAuth2PasswordRequestForm;
9635
9636 fn deref(&self) -> &Self::Target {
9637 &self.form
9638 }
9639}
9640
9641impl FromRequest for OAuth2PasswordRequestFormStrict {
9642 type Error = OAuth2PasswordFormError;
9643
9644 async fn from_request(ctx: &RequestContext, req: &mut Request) -> Result<Self, Self::Error> {
9645 let _ = ctx.checkpoint();
9646 let params = extract_form_body(ctx, req)?;
9647 let _ = ctx.checkpoint();
9648
9649 let grant_type_value = params
9650 .get("grant_type")
9651 .ok_or(OAuth2PasswordFormError::MissingGrantType)?;
9652
9653 if grant_type_value != "password" {
9654 return Err(OAuth2PasswordFormError::InvalidGrantType {
9655 actual: grant_type_value.to_string(),
9656 });
9657 }
9658
9659 let username = params
9660 .get("username")
9661 .ok_or(OAuth2PasswordFormError::MissingUsername)?
9662 .to_string();
9663
9664 let password = params
9665 .get("password")
9666 .ok_or(OAuth2PasswordFormError::MissingPassword)?
9667 .to_string();
9668
9669 let scope = params.get("scope").map(String::from).unwrap_or_default();
9670 let client_id = params.get("client_id").map(String::from);
9671 let client_secret = params.get("client_secret").map(String::from);
9672
9673 let _ = ctx.checkpoint();
9674 Ok(OAuth2PasswordRequestFormStrict {
9675 form: OAuth2PasswordRequestForm {
9676 grant_type: Some("password".to_string()),
9677 username,
9678 password,
9679 scope,
9680 client_id,
9681 client_secret,
9682 },
9683 })
9684 }
9685}
9686
#[cfg(test)]
mod oauth2_password_form_tests {
    use super::*;
    use crate::request::Method;

    fn test_context() -> RequestContext {
        let cx = asupersync::Cx::for_testing();
        RequestContext::new(cx, 12345)
    }

    /// Builds a `POST /token` request carrying `body` labelled as a URL-encoded form.
    fn form_request(body: &str) -> Request {
        let mut req = Request::new(Method::Post, "/token");
        req.headers_mut().insert(
            "content-type",
            b"application/x-www-form-urlencoded".to_vec(),
        );
        req.set_body(Body::Bytes(body.as_bytes().to_vec()));
        req
    }

    /// Runs the lenient extractor against an arbitrary request.
    fn run_lenient(
        mut req: Request,
    ) -> Result<OAuth2PasswordRequestForm, OAuth2PasswordFormError> {
        let ctx = test_context();
        futures_executor::block_on(OAuth2PasswordRequestForm::from_request(&ctx, &mut req))
    }

    /// Runs the lenient extractor against a well-formed form request carrying `body`.
    fn lenient(body: &str) -> Result<OAuth2PasswordRequestForm, OAuth2PasswordFormError> {
        run_lenient(form_request(body))
    }

    /// Runs the strict extractor against a well-formed form request carrying `body`.
    fn strict(
        body: &str,
    ) -> Result<OAuth2PasswordRequestFormStrict, OAuth2PasswordFormError> {
        let ctx = test_context();
        let mut req = form_request(body);
        futures_executor::block_on(OAuth2PasswordRequestFormStrict::from_request(
            &ctx, &mut req,
        ))
    }

    // ---- lenient extractor ----

    #[test]
    fn basic_form_extraction() {
        let form = lenient("username=alice&password=secret123").unwrap();

        assert_eq!(form.username, "alice");
        assert_eq!(form.password, "secret123");
        assert_eq!(form.grant_type, None);
        assert!(form.scope.is_empty());
        assert_eq!(form.client_id, None);
        assert_eq!(form.client_secret, None);
    }

    #[test]
    fn full_form_extraction() {
        let form = lenient(
            "grant_type=password&username=bob&password=s3cret&scope=read+write&client_id=myapp&client_secret=appsecret",
        )
        .unwrap();

        assert_eq!(form.grant_type.as_deref(), Some("password"));
        assert_eq!(form.username, "bob");
        assert_eq!(form.password, "s3cret");
        assert_eq!(form.scope, "read write");
        assert_eq!(form.client_id.as_deref(), Some("myapp"));
        assert_eq!(form.client_secret.as_deref(), Some("appsecret"));
    }

    #[test]
    fn scopes_parsing() {
        let form = lenient("username=u&password=p&scope=read+write+admin").unwrap();
        assert_eq!(form.scopes(), vec!["read", "write", "admin"]);
    }

    #[test]
    fn empty_scope_returns_empty_vec() {
        let form = lenient("username=u&password=p").unwrap();
        assert!(form.scopes().is_empty());
    }

    #[test]
    fn missing_username_error() {
        let err = lenient("password=secret").unwrap_err();

        assert_eq!(err, OAuth2PasswordFormError::MissingUsername);
        assert!(err.to_string().contains("username"));
    }

    #[test]
    fn missing_password_error() {
        let err = lenient("username=alice").unwrap_err();

        assert_eq!(err, OAuth2PasswordFormError::MissingPassword);
        assert!(err.to_string().contains("password"));
    }

    #[test]
    fn wrong_content_type_error() {
        let mut req = Request::new(Method::Post, "/token");
        req.headers_mut()
            .insert("content-type", b"application/json".to_vec());
        req.set_body(Body::Bytes(b"username=a&password=b".to_vec()));

        match run_lenient(req).unwrap_err() {
            OAuth2PasswordFormError::UnsupportedMediaType { actual } => {
                assert_eq!(actual.as_deref(), Some("application/json"));
            }
            other => panic!("Expected UnsupportedMediaType, got: {other:?}"),
        }
    }

    #[test]
    fn missing_content_type_error() {
        let mut req = Request::new(Method::Post, "/token");
        req.set_body(Body::Bytes(b"username=a&password=b".to_vec()));

        match run_lenient(req).unwrap_err() {
            OAuth2PasswordFormError::UnsupportedMediaType { actual } => {
                assert!(actual.is_none());
            }
            other => panic!("Expected UnsupportedMediaType, got: {other:?}"),
        }
    }

    #[test]
    fn url_encoded_values() {
        let form = lenient("username=user%40example.com&password=p%26ss%3Dword").unwrap();

        assert_eq!(form.username, "user@example.com");
        assert_eq!(form.password, "p&ss=word");
    }

    #[test]
    fn plus_decoded_as_space_in_scope() {
        let form = lenient("username=u&password=p&scope=read+write+admin").unwrap();
        assert_eq!(form.scope, "read write admin");
    }

    #[test]
    fn empty_body_returns_missing_username() {
        let err = lenient("").unwrap_err();
        assert_eq!(err, OAuth2PasswordFormError::MissingUsername);
    }

    // ---- error type ----

    #[test]
    fn streaming_not_supported_error_type() {
        let err = OAuth2PasswordFormError::StreamingNotSupported;
        assert!(err.to_string().contains("Streaming"));

        let resp = err.into_response();
        assert_eq!(resp.status().as_u16(), 400);
    }

    #[test]
    fn error_display_messages() {
        let cases = [
            (OAuth2PasswordFormError::MissingUsername, "username"),
            (OAuth2PasswordFormError::MissingPassword, "password"),
            (OAuth2PasswordFormError::InvalidUtf8, "UTF-8"),
            (OAuth2PasswordFormError::StreamingNotSupported, "Streaming"),
            (
                OAuth2PasswordFormError::InvalidGrantType {
                    actual: "code".to_string(),
                },
                "code",
            ),
            (
                OAuth2PasswordFormError::PayloadTooLarge {
                    size: 100,
                    limit: 50,
                },
                "100",
            ),
        ];

        for (err, needle) in cases {
            let text = err.to_string();
            assert!(text.contains(needle), "{text:?} should mention {needle:?}");
        }
    }

    #[test]
    fn error_into_response_status_codes() {
        let cases = [
            (OAuth2PasswordFormError::MissingUsername, 422),
            (OAuth2PasswordFormError::MissingPassword, 422),
            (
                OAuth2PasswordFormError::UnsupportedMediaType { actual: None },
                415,
            ),
            (
                OAuth2PasswordFormError::PayloadTooLarge {
                    size: 100,
                    limit: 50,
                },
                413,
            ),
            (
                OAuth2PasswordFormError::InvalidGrantType {
                    actual: "code".to_string(),
                },
                422,
            ),
        ];

        for (err, expected) in cases {
            let label = format!("{err:?}");
            let status = err.into_response().status().as_u16();
            assert_eq!(status, expected, "unexpected status for {label}");
        }
    }

    // ---- strict extractor ----

    #[test]
    fn strict_accepts_password_grant_type() {
        let form = strict("grant_type=password&username=alice&password=secret").unwrap();

        assert_eq!(form.form.grant_type.as_deref(), Some("password"));
        assert_eq!(form.username, "alice");
        assert_eq!(form.password, "secret");
    }

    #[test]
    fn strict_rejects_missing_grant_type() {
        let err = strict("username=alice&password=secret").unwrap_err();
        assert_eq!(err, OAuth2PasswordFormError::MissingGrantType);
    }

    #[test]
    fn strict_rejects_wrong_grant_type() {
        let err =
            strict("grant_type=authorization_code&username=alice&password=secret").unwrap_err();

        match err {
            OAuth2PasswordFormError::InvalidGrantType { actual } => {
                assert_eq!(actual, "authorization_code");
            }
            other => panic!("Expected InvalidGrantType, got: {other:?}"),
        }
    }

    #[test]
    fn strict_with_all_fields() {
        let form = strict(
            "grant_type=password&username=bob&password=pw&scope=read+write&client_id=app&client_secret=sec",
        )
        .unwrap();

        assert_eq!(form.username, "bob");
        assert_eq!(form.password, "pw");
        assert_eq!(form.scope, "read write");
        assert_eq!(form.client_id.as_deref(), Some("app"));
        assert_eq!(form.client_secret.as_deref(), Some("sec"));
        assert_eq!(form.scopes(), vec!["read", "write"]);
    }

    #[test]
    fn strict_deref_to_inner() {
        let strict = strict("grant_type=password&username=alice&password=pw").unwrap();

        // Field access goes through `Deref` to the wrapped form.
        let _: &str = &strict.username;
        let _: &str = &strict.password;
        assert_eq!(strict.inner().username, "alice");
    }

    #[test]
    fn strict_into_inner() {
        let form = strict("grant_type=password&username=alice&password=pw")
            .unwrap()
            .into_inner();

        assert_eq!(form.username, "alice");
        assert_eq!(form.grant_type.as_deref(), Some("password"));
    }

    #[test]
    fn strict_missing_username_after_grant_type() {
        let err = strict("grant_type=password&password=secret").unwrap_err();
        assert_eq!(err, OAuth2PasswordFormError::MissingUsername);
    }

    #[test]
    fn strict_missing_password_after_grant_type() {
        let err = strict("grant_type=password&username=alice").unwrap_err();
        assert_eq!(err, OAuth2PasswordFormError::MissingPassword);
    }
}
10046
/// Settings describing an OAuth2 authorization-code security scheme.
///
/// Holds the endpoint URLs and the scope descriptions for the scheme.
/// NOTE(review): presumably consumed when generating OpenAPI security metadata;
/// no consumer is visible in this section, so confirm.
#[derive(Clone, Debug)]
pub struct OAuth2AuthorizationCodeBearerConfig {
    /// URL of the authorization endpoint.
    pub authorization_url: String,
    /// URL of the token endpoint.
    pub token_url: String,
    /// Optional token-refresh URL (`None` unless set).
    pub refresh_url: Option<String>,
    /// Scope name mapped to its human-readable description.
    pub scopes: std::collections::HashMap<String, String>,
    /// Optional scheme name override (`None` unless set).
    pub scheme_name: Option<String>,
    /// Optional scheme description (`None` unless set).
    pub description: Option<String>,
    /// Auto-error flag, `true` by default.
    /// NOTE(review): by analogy with FastAPI this presumably decides whether missing
    /// credentials produce an error; the extractor in this section does not read it.
    pub auto_error: bool,
}

impl OAuth2AuthorizationCodeBearerConfig {
    /// Creates a config from the two required endpoint URLs.
    ///
    /// All optional fields start as `None`, no scopes are registered, and
    /// `auto_error` is `true`.
    #[must_use]
    pub fn new(authorization_url: impl Into<String>, token_url: impl Into<String>) -> Self {
        let authorization_url = authorization_url.into();
        let token_url = token_url.into();
        Self {
            authorization_url,
            token_url,
            refresh_url: None,
            scopes: std::collections::HashMap::new(),
            scheme_name: None,
            description: None,
            auto_error: true,
        }
    }

    /// Sets the token-refresh URL.
    #[must_use]
    pub fn with_refresh_url(self, url: impl Into<String>) -> Self {
        Self {
            refresh_url: Some(url.into()),
            ..self
        }
    }

    /// Registers a scope and its description.
    ///
    /// Registering the same scope again replaces its description (`HashMap::insert`).
    #[must_use]
    pub fn with_scope(mut self, scope: impl Into<String>, description: impl Into<String>) -> Self {
        let (name, text) = (scope.into(), description.into());
        self.scopes.insert(name, text);
        self
    }

    /// Overrides the scheme name.
    #[must_use]
    pub fn with_scheme_name(self, name: impl Into<String>) -> Self {
        Self {
            scheme_name: Some(name.into()),
            ..self
        }
    }

    /// Sets the scheme description.
    #[must_use]
    pub fn with_description(self, desc: impl Into<String>) -> Self {
        Self {
            description: Some(desc.into()),
            ..self
        }
    }

    /// Sets the auto-error flag.
    #[must_use]
    pub fn with_auto_error(self, auto_error: bool) -> Self {
        Self { auto_error, ..self }
    }
}
10139
/// A bearer access token for the OAuth2 authorization-code flow.
///
/// Derefs to `str`, so it can be passed wherever a `&str` token is expected.
#[derive(Clone, Debug)]
pub struct OAuth2AuthorizationCodeBearer {
    /// The token value. The `FromRequest` impl stores it trimmed and without the
    /// scheme prefix.
    pub token: String,
}

impl OAuth2AuthorizationCodeBearer {
    /// Wraps an already-extracted token value.
    #[must_use]
    pub fn new(token: impl Into<String>) -> Self {
        let token = token.into();
        Self { token }
    }

    /// Borrows the token value.
    #[must_use]
    pub fn token(&self) -> &str {
        self.token.as_str()
    }

    /// Consumes the wrapper and returns the owned token.
    #[must_use]
    pub fn into_token(self) -> String {
        let Self { token } = self;
        token
    }
}

impl Deref for OAuth2AuthorizationCodeBearer {
    type Target = str;

    fn deref(&self) -> &Self::Target {
        self.token()
    }
}
10192
10193impl FromRequest for OAuth2AuthorizationCodeBearer {
10194 type Error = OAuth2BearerError;
10195
10196 async fn from_request(_ctx: &RequestContext, req: &mut Request) -> Result<Self, Self::Error> {
10197 let auth_header = req
10200 .headers()
10201 .get("authorization")
10202 .ok_or_else(OAuth2BearerError::missing_header)?;
10203
10204 let auth_str =
10205 std::str::from_utf8(auth_header).map_err(|_| OAuth2BearerError::invalid_scheme())?;
10206
10207 const BEARER_PREFIX: &str = "Bearer ";
10208 const BEARER_PREFIX_LOWER: &str = "bearer ";
10209
10210 let token = if auth_str.starts_with(BEARER_PREFIX) {
10211 &auth_str[BEARER_PREFIX.len()..]
10212 } else if auth_str.starts_with(BEARER_PREFIX_LOWER) {
10213 &auth_str[BEARER_PREFIX_LOWER.len()..]
10214 } else {
10215 return Err(OAuth2BearerError::invalid_scheme());
10216 };
10217
10218 let token = token.trim();
10219 if token.is_empty() {
10220 return Err(OAuth2BearerError::empty_token());
10221 }
10222
10223 Ok(OAuth2AuthorizationCodeBearer::new(token))
10224 }
10225}
10226
#[cfg(test)]
mod oauth2_authcode_bearer_tests {
    use super::*;
    use crate::request::Method;

    fn test_context() -> RequestContext {
        let cx = asupersync::Cx::for_testing();
        RequestContext::new(cx, 12345)
    }

    #[test]
    fn config_new() {
        let config = OAuth2AuthorizationCodeBearerConfig::new(
            "https://auth.example.com/authorize",
            "https://auth.example.com/token",
        );
        assert_eq!(
            config.authorization_url,
            "https://auth.example.com/authorize"
        );
        assert_eq!(config.token_url, "https://auth.example.com/token");
        assert!(config.refresh_url.is_none());
        assert!(config.scopes.is_empty());
        assert!(config.scheme_name.is_none());
        assert!(config.description.is_none());
        assert!(config.auto_error);
    }

    #[test]
    fn config_builder() {
        let config = OAuth2AuthorizationCodeBearerConfig::new(
            "https://auth.example.com/authorize",
            "https://auth.example.com/token",
        )
        .with_refresh_url("https://auth.example.com/refresh")
        .with_scope("read", "Read access")
        .with_scope("write", "Write access")
        .with_scheme_name("MyAuth")
        .with_description("Authorization code flow")
        .with_auto_error(false);

        assert_eq!(
            config.refresh_url.as_deref(),
            Some("https://auth.example.com/refresh")
        );
        assert_eq!(config.scopes.len(), 2);
        assert_eq!(config.scopes.get("read").unwrap(), "Read access");
        assert_eq!(config.scopes.get("write").unwrap(), "Write access");
        assert_eq!(config.scheme_name.as_deref(), Some("MyAuth"));
        assert_eq!(
            config.description.as_deref(),
            Some("Authorization code flow")
        );
        assert!(!config.auto_error);
    }

    #[test]
    fn extracts_bearer_token() {
        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/protected");
        req.headers_mut()
            .insert("authorization", b"Bearer my-access-token".to_vec());

        let result =
            futures_executor::block_on(OAuth2AuthorizationCodeBearer::from_request(&ctx, &mut req))
                .unwrap();

        assert_eq!(result.token(), "my-access-token");
        assert_eq!(result.into_token(), "my-access-token");
    }

    #[test]
    fn extracts_lowercase_bearer() {
        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/protected");
        req.headers_mut()
            .insert("authorization", b"bearer my-token".to_vec());

        let result =
            futures_executor::block_on(OAuth2AuthorizationCodeBearer::from_request(&ctx, &mut req))
                .unwrap();

        assert_eq!(result.token(), "my-token");
    }

    #[test]
    fn missing_header_returns_401() {
        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/protected");

        let err =
            futures_executor::block_on(OAuth2AuthorizationCodeBearer::from_request(&ctx, &mut req))
                .unwrap_err();

        assert_eq!(err.kind, OAuth2BearerErrorKind::MissingHeader);

        let resp = err.into_response();
        assert_eq!(resp.status().as_u16(), 401);
    }

    #[test]
    fn invalid_scheme_returns_401() {
        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/protected");
        req.headers_mut()
            .insert("authorization", b"Basic dXNlcjpwYXNz".to_vec());

        let err =
            futures_executor::block_on(OAuth2AuthorizationCodeBearer::from_request(&ctx, &mut req))
                .unwrap_err();

        assert_eq!(err.kind, OAuth2BearerErrorKind::InvalidScheme);
    }

    #[test]
    fn non_utf8_header_returns_invalid_scheme() {
        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/protected");
        req.headers_mut()
            .insert("authorization", b"Bearer \xff\xfe".to_vec());

        let err =
            futures_executor::block_on(OAuth2AuthorizationCodeBearer::from_request(&ctx, &mut req))
                .unwrap_err();

        assert_eq!(err.kind, OAuth2BearerErrorKind::InvalidScheme);
    }

    #[test]
    fn scheme_without_separator_returns_invalid_scheme() {
        // Neither value has a space after the scheme, so there is no token part.
        for value in [&b"Bearer"[..], &b"Bearertoken"[..]] {
            let ctx = test_context();
            let mut req = Request::new(Method::Get, "/protected");
            req.headers_mut().insert("authorization", value.to_vec());

            let err = futures_executor::block_on(OAuth2AuthorizationCodeBearer::from_request(
                &ctx, &mut req,
            ))
            .unwrap_err();

            assert_eq!(err.kind, OAuth2BearerErrorKind::InvalidScheme);
        }
    }

    #[test]
    fn empty_token_returns_error() {
        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/protected");
        req.headers_mut()
            .insert("authorization", b"Bearer ".to_vec());

        let err =
            futures_executor::block_on(OAuth2AuthorizationCodeBearer::from_request(&ctx, &mut req))
                .unwrap_err();

        assert_eq!(err.kind, OAuth2BearerErrorKind::EmptyToken);
    }

    #[test]
    fn whitespace_only_token_returns_error() {
        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/protected");
        // Escapes make the whitespace-only token explicit. A bare "Bearer " would
        // duplicate `empty_token_returns_error`.
        req.headers_mut()
            .insert("authorization", b"Bearer \t  \t".to_vec());

        let err =
            futures_executor::block_on(OAuth2AuthorizationCodeBearer::from_request(&ctx, &mut req))
                .unwrap_err();

        assert_eq!(err.kind, OAuth2BearerErrorKind::EmptyToken);
    }

    #[test]
    fn token_trimmed() {
        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/protected");
        // Whitespace on both sides of the token must be stripped.
        req.headers_mut()
            .insert("authorization", b"Bearer \t my-token \t".to_vec());

        let result =
            futures_executor::block_on(OAuth2AuthorizationCodeBearer::from_request(&ctx, &mut req))
                .unwrap();

        assert_eq!(result.token(), "my-token");
    }

    #[test]
    fn deref_to_str() {
        let bearer = OAuth2AuthorizationCodeBearer::new("abc123");
        let s: &str = &bearer;
        assert_eq!(s, "abc123");
    }

    #[test]
    fn new_constructor() {
        let bearer = OAuth2AuthorizationCodeBearer::new("token-value");
        assert_eq!(bearer.token, "token-value");
        assert_eq!(bearer.token(), "token-value");
    }

    #[test]
    fn www_authenticate_header_on_error() {
        let err = OAuth2BearerError::missing_header();
        let resp = err.into_response();

        let has_www_auth = resp
            .headers()
            .iter()
            .any(|(n, v)| n.eq_ignore_ascii_case("www-authenticate") && v == b"Bearer");
        assert!(
            has_www_auth,
            "Response should have WWW-Authenticate: Bearer header"
        );
    }
}
10416
/// The OAuth2 scopes associated with the current route, in first-seen order.
///
/// Scope names are unique and non-empty. The space-separated form returned by
/// [`SecurityScopes::scope_str`] (the OAuth2 `scope` format, RFC 6749 §3.3) is
/// cached so it can be borrowed without allocating.
#[derive(Debug, Clone)]
pub struct SecurityScopes {
    /// Unique, non-empty scope names in first-seen order.
    scopes: Vec<String>,
    /// `scopes` joined with single spaces; rebuilt by every constructor and mutator.
    scope_str: String,
}

impl SecurityScopes {
    /// Creates an empty scope set.
    #[must_use]
    pub fn new() -> Self {
        Self {
            scopes: Vec::new(),
            scope_str: String::new(),
        }
    }

    /// Builds a scope set from scope names.
    ///
    /// Duplicates are removed, and the first occurrence keeps its position.
    /// Empty names are skipped, because an OAuth2 scope token has at least one
    /// character (RFC 6749 §3.3). Keeping `""` would also leave stray spaces in
    /// [`SecurityScopes::scope_str`] and make a set with no real scopes report
    /// itself as non-empty.
    #[must_use]
    pub fn from_scopes(scopes: impl IntoIterator<Item = impl Into<String>>) -> Self {
        let mut seen = std::collections::HashSet::new();
        let mut unique = Vec::new();
        for scope in scopes {
            let scope: String = scope.into();
            if !scope.is_empty() && seen.insert(scope.clone()) {
                unique.push(scope);
            }
        }
        let scope_str = unique.join(" ");
        Self {
            scopes: unique,
            scope_str,
        }
    }

    /// Parses a space-delimited scope string such as `"read write"`.
    ///
    /// Leading, trailing and repeated spaces are ignored, and duplicates are removed.
    #[must_use]
    pub fn from_scope_str(scope_str: &str) -> Self {
        Self::from_scopes(scope_str.split(' ').filter(|s| !s.is_empty()))
    }

    /// Returns the scopes in first-seen order.
    #[must_use]
    pub fn scopes(&self) -> &[String] {
        &self.scopes
    }

    /// Returns the scopes joined with single spaces (empty when there are none).
    #[must_use]
    pub fn scope_str(&self) -> &str {
        &self.scope_str
    }

    /// Returns `true` if `scope` is one of the scopes (exact, case-sensitive match).
    #[must_use]
    pub fn contains(&self, scope: &str) -> bool {
        self.scopes.iter().any(|s| s == scope)
    }

    /// Returns `true` when there are no scopes.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.scopes.is_empty()
    }

    /// Returns the number of scopes.
    #[must_use]
    pub fn len(&self) -> usize {
        self.scopes.len()
    }

    /// Appends the scopes of `other` that are not already present, keeping their
    /// order, and refreshes the cached scope string.
    pub fn merge(&mut self, other: &SecurityScopes) {
        // Each scope is added to `seen` as it is appended, so the result stays
        // unique without relying on `other` being deduplicated.
        let mut seen: std::collections::HashSet<String> = self.scopes.iter().cloned().collect();
        for scope in &other.scopes {
            if !scope.is_empty() && seen.insert(scope.clone()) {
                self.scopes.push(scope.clone());
            }
        }
        self.scope_str = self.scopes.join(" ");
    }

    /// Returns a new set equal to `self` merged with `other`; `self` is unchanged.
    #[must_use]
    pub fn merged(&self, other: &SecurityScopes) -> Self {
        let mut result = self.clone();
        result.merge(other);
        result
    }
}

impl Default for SecurityScopes {
    /// Equivalent to [`SecurityScopes::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl std::fmt::Display for SecurityScopes {
    /// Formats as the space-separated scope string.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.scope_str)
    }
}
10561
/// Error type declared for [`SecurityScopes`] extraction.
///
/// NOTE(review): the `FromRequest` impl for `SecurityScopes` in this section
/// always succeeds (it falls back to an empty set), so it never produces this value.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct SecurityScopesError;

impl std::fmt::Display for SecurityScopesError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("No security scopes configured for this route")
    }
}

impl std::error::Error for SecurityScopesError {}
10576
10577impl IntoResponse for SecurityScopesError {
10578 fn into_response(self) -> Response {
10579 HttpError::new(crate::response::StatusCode::INTERNAL_SERVER_ERROR)
10580 .with_detail("Security scopes not configured")
10581 .into_response()
10582 }
10583}
10584
10585impl FromRequest for SecurityScopes {
10586 type Error = SecurityScopesError;
10587
10588 async fn from_request(_ctx: &RequestContext, req: &mut Request) -> Result<Self, Self::Error> {
10589 if let Some(scopes) = req.get_extension::<SecurityScopes>() {
10592 Ok(scopes.clone())
10593 } else {
10594 Ok(SecurityScopes::new())
10597 }
10598 }
10599}
10600
#[cfg(test)]
mod security_scopes_tests {
    use super::*;
    use crate::request::Method;

    fn test_context() -> RequestContext {
        let cx = asupersync::Cx::for_testing();
        RequestContext::new(cx, 12345)
    }

    #[test]
    fn new_is_empty() {
        let scopes = SecurityScopes::new();
        assert!(scopes.is_empty());
        assert_eq!(scopes.len(), 0);
        assert_eq!(scopes.scope_str(), "");
        assert!(scopes.scopes().is_empty());
    }

    #[test]
    fn default_is_empty() {
        let scopes = SecurityScopes::default();
        assert!(scopes.is_empty());
    }

    #[test]
    fn from_scopes_preserves_order() {
        let scopes = SecurityScopes::from_scopes(["read", "write", "admin"]);
        assert_eq!(scopes.scopes(), &["read", "write", "admin"]);
        assert_eq!(scopes.scope_str(), "read write admin");
        assert_eq!(scopes.len(), 3);
    }

    #[test]
    fn from_scopes_deduplicates() {
        let scopes = SecurityScopes::from_scopes(["read", "write", "read", "admin", "write"]);
        assert_eq!(scopes.scopes(), &["read", "write", "admin"]);
        assert_eq!(scopes.scope_str(), "read write admin");
    }

    #[test]
    fn from_scope_str() {
        let scopes = SecurityScopes::from_scope_str("read write admin");
        assert_eq!(scopes.scopes(), &["read", "write", "admin"]);
        assert_eq!(scopes.scope_str(), "read write admin");
    }

    #[test]
    fn from_scope_str_deduplicates() {
        let scopes = SecurityScopes::from_scope_str("read write read admin");
        assert_eq!(scopes.scopes(), &["read", "write", "admin"]);
    }

    #[test]
    fn from_scope_str_ignores_empty_segments() {
        // Build the input from parts so the leading, repeated and trailing
        // separators (which yield empty segments) are unambiguous.
        let input = ["", "read", "", "write", "", "", "admin", ""].join(" ");
        assert!(input.starts_with(' ') && input.ends_with(' ') && input.contains("  "));

        let scopes = SecurityScopes::from_scope_str(&input);
        assert_eq!(scopes.scopes(), &["read", "write", "admin"]);
        assert_eq!(scopes.scope_str(), "read write admin");
    }

    #[test]
    fn from_empty_scope_str() {
        let scopes = SecurityScopes::from_scope_str("");
        assert!(scopes.is_empty());
        assert_eq!(scopes.scope_str(), "");
    }

    #[test]
    fn contains_scope() {
        let scopes = SecurityScopes::from_scopes(["read", "write"]);
        assert!(scopes.contains("read"));
        assert!(scopes.contains("write"));
        assert!(!scopes.contains("admin"));
    }

    #[test]
    fn display_format() {
        let scopes = SecurityScopes::from_scopes(["read", "write"]);
        assert_eq!(format!("{scopes}"), "read write");

        let empty = SecurityScopes::new();
        assert_eq!(format!("{empty}"), "");
    }

    #[test]
    fn merge_appends_new_scopes() {
        let mut base = SecurityScopes::from_scopes(["read"]);
        let other = SecurityScopes::from_scopes(["write", "admin"]);
        base.merge(&other);

        assert_eq!(base.scopes(), &["read", "write", "admin"]);
        assert_eq!(base.scope_str(), "read write admin");
    }

    #[test]
    fn merge_deduplicates() {
        let mut base = SecurityScopes::from_scopes(["read", "write"]);
        let other = SecurityScopes::from_scopes(["write", "admin"]);
        base.merge(&other);

        assert_eq!(base.scopes(), &["read", "write", "admin"]);
    }

    #[test]
    fn merge_empty_into_nonempty() {
        let mut base = SecurityScopes::from_scopes(["read"]);
        let other = SecurityScopes::new();
        base.merge(&other);

        assert_eq!(base.scopes(), &["read"]);
    }

    #[test]
    fn merge_nonempty_into_empty() {
        let mut base = SecurityScopes::new();
        let other = SecurityScopes::from_scopes(["read", "write"]);
        base.merge(&other);

        assert_eq!(base.scopes(), &["read", "write"]);
    }

    #[test]
    fn merged_returns_new_instance() {
        let base = SecurityScopes::from_scopes(["read"]);
        let other = SecurityScopes::from_scopes(["write"]);
        let combined = base.merged(&other);

        assert_eq!(combined.scopes(), &["read", "write"]);
        assert_eq!(base.scopes(), &["read"]);
    }

    #[test]
    fn extract_with_extension() {
        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/protected");
        req.insert_extension(SecurityScopes::from_scopes(["admin", "users:read"]));

        let scopes =
            futures_executor::block_on(SecurityScopes::from_request(&ctx, &mut req)).unwrap();
        assert_eq!(scopes.scopes(), &["admin", "users:read"]);
        assert_eq!(scopes.scope_str(), "admin users:read");
    }

    #[test]
    fn extract_without_extension_returns_empty() {
        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/public");

        let scopes =
            futures_executor::block_on(SecurityScopes::from_request(&ctx, &mut req)).unwrap();
        assert!(scopes.is_empty());
    }

    #[test]
    fn error_display() {
        let err = SecurityScopesError;
        assert!(err.to_string().contains("security scopes"));
    }

    #[test]
    fn error_into_response_is_500() {
        let resp = SecurityScopesError.into_response();
        assert_eq!(resp.status().as_u16(), 500);
    }
}
10776
/// A bearer token taken from an `Authorization: Bearer <token>` header.
///
/// Derefs to `str` and implements `AsRef<str>` for the token value.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct BearerToken {
    /// Token value. The `FromRequest` impl stores it trimmed and without the scheme.
    token: String,
}

impl BearerToken {
    /// Wraps a token value.
    #[must_use]
    pub fn new(token: impl Into<String>) -> Self {
        let token = token.into();
        Self { token }
    }

    /// Borrows the token value.
    #[must_use]
    pub fn token(&self) -> &str {
        self.token.as_str()
    }

    /// Consumes the wrapper, returning the owned token.
    #[must_use]
    pub fn into_token(self) -> String {
        let Self { token } = self;
        token
    }
}

impl Deref for BearerToken {
    type Target = str;

    fn deref(&self) -> &Self::Target {
        self.token()
    }
}

impl AsRef<str> for BearerToken {
    fn as_ref(&self) -> &str {
        self.token.as_str()
    }
}
10873
/// Reasons the [`BearerToken`] extractor can reject a request.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum BearerTokenError {
    /// No `authorization` header was present.
    MissingHeader,
    /// The header was not UTF-8 or did not use the Bearer scheme.
    InvalidScheme,
    /// The Bearer scheme was present but the token was empty after trimming.
    EmptyToken,
}

impl BearerTokenError {
    /// Constructs [`BearerTokenError::MissingHeader`].
    #[must_use]
    pub fn missing_header() -> Self {
        BearerTokenError::MissingHeader
    }

    /// Constructs [`BearerTokenError::InvalidScheme`].
    #[must_use]
    pub fn invalid_scheme() -> Self {
        BearerTokenError::InvalidScheme
    }

    /// Constructs [`BearerTokenError::EmptyToken`].
    #[must_use]
    pub fn empty_token() -> Self {
        BearerTokenError::EmptyToken
    }

    /// Client-facing detail message.
    ///
    /// A missing header reads as "not authenticated". Both malformed cases share
    /// one generic message, so the response does not reveal which check failed.
    #[must_use]
    pub fn detail(&self) -> &'static str {
        match self {
            BearerTokenError::MissingHeader => "Not authenticated",
            BearerTokenError::InvalidScheme | BearerTokenError::EmptyToken => {
                "Invalid authentication credentials"
            }
        }
    }
}

impl fmt::Display for BearerTokenError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            BearerTokenError::MissingHeader => "Missing Authorization header",
            BearerTokenError::InvalidScheme => "Authorization header must use Bearer scheme",
            BearerTokenError::EmptyToken => "Bearer token is empty",
        };
        f.write_str(message)
    }
}

impl std::error::Error for BearerTokenError {}
10926
10927impl IntoResponse for BearerTokenError {
10928 fn into_response(self) -> crate::response::Response {
10929 use crate::response::{Response, ResponseBody, StatusCode};
10930
10931 let body = serde_json::json!({
10932 "detail": self.detail()
10933 });
10934
10935 Response::with_status(StatusCode::UNAUTHORIZED)
10936 .header("www-authenticate", b"Bearer".to_vec())
10937 .header("content-type", b"application/json".to_vec())
10938 .body(ResponseBody::Bytes(body.to_string().into_bytes()))
10939 }
10940}
10941
10942impl FromRequest for BearerToken {
10943 type Error = BearerTokenError;
10944
10945 async fn from_request(_ctx: &RequestContext, req: &mut Request) -> Result<Self, Self::Error> {
10946 let auth_header = req
10948 .headers()
10949 .get("authorization")
10950 .ok_or(BearerTokenError::MissingHeader)?;
10951
10952 let auth_str =
10954 std::str::from_utf8(auth_header).map_err(|_| BearerTokenError::InvalidScheme)?;
10955
10956 const BEARER_PREFIX: &str = "Bearer ";
10958 const BEARER_PREFIX_LOWER: &str = "bearer ";
10959
10960 let token = if auth_str.starts_with(BEARER_PREFIX) {
10961 &auth_str[BEARER_PREFIX.len()..]
10962 } else if auth_str.starts_with(BEARER_PREFIX_LOWER) {
10963 &auth_str[BEARER_PREFIX_LOWER.len()..]
10964 } else {
10965 return Err(BearerTokenError::InvalidScheme);
10966 };
10967
10968 let token = token.trim();
10970 if token.is_empty() {
10971 return Err(BearerTokenError::EmptyToken);
10972 }
10973
10974 Ok(BearerToken::new(token))
10975 }
10976}
10977
/// Header that [`ApiKeyHeader`] reads when no [`ApiKeyHeaderConfig`] is attached.
pub const DEFAULT_API_KEY_HEADER: &str = "x-api-key";

/// Configuration for the [`ApiKeyHeader`] extractor.
///
/// Attach it as a request extension to change which header carries the key.
#[derive(Clone, Debug)]
pub struct ApiKeyHeaderConfig {
    /// Name of the header that carries the API key.
    header_name: String,
}

impl Default for ApiKeyHeaderConfig {
    /// Uses [`DEFAULT_API_KEY_HEADER`].
    fn default() -> Self {
        Self {
            header_name: String::from(DEFAULT_API_KEY_HEADER),
        }
    }
}

impl ApiKeyHeaderConfig {
    /// Equivalent to [`ApiKeyHeaderConfig::default`].
    #[must_use]
    pub fn new() -> Self {
        Default::default()
    }

    /// Sets the header name to read the key from.
    #[must_use]
    pub fn header_name(self, name: impl Into<String>) -> Self {
        let mut config = self;
        config.header_name = name.into();
        config
    }

    /// Returns the configured header name.
    #[must_use]
    pub fn get_header_name(&self) -> &str {
        self.header_name.as_str()
    }
}
11020
11021#[derive(Debug, Clone, PartialEq, Eq)]
11061pub struct ApiKeyHeader {
11062 key: String,
11064 header_name: String,
11066}
11067
11068impl ApiKeyHeader {
11069 #[must_use]
11071 pub fn new(key: impl Into<String>) -> Self {
11072 Self {
11073 key: key.into(),
11074 header_name: DEFAULT_API_KEY_HEADER.to_string(),
11075 }
11076 }
11077
11078 #[must_use]
11080 pub fn with_header_name(key: impl Into<String>, header_name: impl Into<String>) -> Self {
11081 Self {
11082 key: key.into(),
11083 header_name: header_name.into(),
11084 }
11085 }
11086
11087 #[must_use]
11089 pub fn key(&self) -> &str {
11090 &self.key
11091 }
11092
11093 #[must_use]
11095 pub fn header_name(&self) -> &str {
11096 &self.header_name
11097 }
11098
11099 #[must_use]
11101 pub fn into_key(self) -> String {
11102 self.key
11103 }
11104}
11105
11106impl Deref for ApiKeyHeader {
11107 type Target = str;
11108
11109 fn deref(&self) -> &Self::Target {
11110 &self.key
11111 }
11112}
11113
11114impl AsRef<str> for ApiKeyHeader {
11115 fn as_ref(&self) -> &str {
11116 &self.key
11117 }
11118}
11119
11120impl SecureCompare for ApiKeyHeader {
11122 fn secure_eq(&self, other: &str) -> bool {
11123 constant_time_str_eq(&self.key, other)
11124 }
11125
11126 fn secure_eq_bytes(&self, other: &[u8]) -> bool {
11127 constant_time_eq(self.key.as_bytes(), other)
11128 }
11129}
11130
/// Reasons the [`ApiKeyHeader`] extractor can reject a request.
///
/// Every variant records the header name that was consulted.
#[derive(Clone, Debug)]
pub enum ApiKeyHeaderError {
    /// The header was not present.
    MissingHeader {
        /// Header that was looked up.
        header_name: String,
    },
    /// The header was present but empty after trimming.
    EmptyKey {
        /// Header that was looked up.
        header_name: String,
    },
    /// The header value was not valid UTF-8.
    InvalidUtf8 {
        /// Header that was looked up.
        header_name: String,
    },
}

impl ApiKeyHeaderError {
    /// Constructs [`ApiKeyHeaderError::MissingHeader`].
    #[must_use]
    pub fn missing_header(header_name: impl Into<String>) -> Self {
        let header_name = header_name.into();
        ApiKeyHeaderError::MissingHeader { header_name }
    }

    /// Constructs [`ApiKeyHeaderError::EmptyKey`].
    #[must_use]
    pub fn empty_key(header_name: impl Into<String>) -> Self {
        let header_name = header_name.into();
        ApiKeyHeaderError::EmptyKey { header_name }
    }

    /// Constructs [`ApiKeyHeaderError::InvalidUtf8`].
    #[must_use]
    pub fn invalid_utf8(header_name: impl Into<String>) -> Self {
        let header_name = header_name.into();
        ApiKeyHeaderError::InvalidUtf8 { header_name }
    }

    /// Client-facing detail message (used as the JSON `detail` in responses).
    #[must_use]
    pub fn detail(&self) -> String {
        match self {
            ApiKeyHeaderError::MissingHeader { header_name } => format!("Missing required header: {header_name}"),
            ApiKeyHeaderError::EmptyKey { header_name } => format!("Empty API key in header: {header_name}"),
            ApiKeyHeaderError::InvalidUtf8 { header_name } => format!("Invalid API key encoding in header: {header_name}"),
        }
    }
}

impl fmt::Display for ApiKeyHeaderError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            ApiKeyHeaderError::MissingHeader { header_name } => write!(f, "Missing API key header: {header_name}"),
            ApiKeyHeaderError::EmptyKey { header_name } => write!(f, "Empty API key in header: {header_name}"),
            ApiKeyHeaderError::InvalidUtf8 { header_name } => write!(f, "Invalid UTF-8 in header: {header_name}"),
        }
    }
}

impl std::error::Error for ApiKeyHeaderError {}
11210
11211impl IntoResponse for ApiKeyHeaderError {
11212 fn into_response(self) -> crate::response::Response {
11213 use crate::response::{Response, ResponseBody, StatusCode};
11214
11215 let body = serde_json::json!({
11216 "detail": self.detail()
11217 });
11218
11219 Response::with_status(StatusCode::UNAUTHORIZED)
11220 .header("content-type", b"application/json".to_vec())
11221 .body(ResponseBody::Bytes(body.to_string().into_bytes()))
11222 }
11223}
11224
11225impl FromRequest for ApiKeyHeader {
11226 type Error = ApiKeyHeaderError;
11227
11228 async fn from_request(_ctx: &RequestContext, req: &mut Request) -> Result<Self, Self::Error> {
11229 let header_name = req.get_extension::<ApiKeyHeaderConfig>().map_or_else(
11231 || DEFAULT_API_KEY_HEADER.to_string(),
11232 |c| c.get_header_name().to_string(),
11233 );
11234
11235 let key_header = req
11237 .headers()
11238 .get(&header_name)
11239 .ok_or_else(|| ApiKeyHeaderError::missing_header(&header_name))?;
11240
11241 let key_str = std::str::from_utf8(key_header)
11243 .map_err(|_| ApiKeyHeaderError::invalid_utf8(&header_name))?;
11244
11245 let key = key_str.trim();
11247 if key.is_empty() {
11248 return Err(ApiKeyHeaderError::empty_key(&header_name));
11249 }
11250
11251 Ok(ApiKeyHeader::with_header_name(key, header_name))
11252 }
11253}
11254
/// Query parameter that [`ApiKeyQuery`] reads when no [`ApiKeyQueryConfig`] is attached.
pub const DEFAULT_API_KEY_QUERY_PARAM: &str = "api_key";

/// Configuration for the [`ApiKeyQuery`] extractor.
///
/// Attach it as a request extension to change which query parameter carries the key.
#[derive(Clone, Debug)]
pub struct ApiKeyQueryConfig {
    /// Name of the query parameter that carries the API key.
    param_name: String,
}

impl Default for ApiKeyQueryConfig {
    /// Uses [`DEFAULT_API_KEY_QUERY_PARAM`].
    fn default() -> Self {
        Self {
            param_name: String::from(DEFAULT_API_KEY_QUERY_PARAM),
        }
    }
}

impl ApiKeyQueryConfig {
    /// Equivalent to [`ApiKeyQueryConfig::default`].
    #[must_use]
    pub fn new() -> Self {
        Default::default()
    }

    /// Sets the query parameter to read the key from.
    #[must_use]
    pub fn param_name(self, name: impl Into<String>) -> Self {
        let mut config = self;
        config.param_name = name.into();
        config
    }

    /// Returns the configured parameter name.
    #[must_use]
    pub fn get_param_name(&self) -> &str {
        self.param_name.as_str()
    }
}
11297
11298#[derive(Debug, Clone, PartialEq, Eq)]
11352pub struct ApiKeyQuery {
11353 key: String,
11355 param_name: String,
11357}
11358
11359impl ApiKeyQuery {
11360 #[must_use]
11362 pub fn new(key: impl Into<String>) -> Self {
11363 Self {
11364 key: key.into(),
11365 param_name: DEFAULT_API_KEY_QUERY_PARAM.to_string(),
11366 }
11367 }
11368
11369 #[must_use]
11371 pub fn with_param_name(key: impl Into<String>, param_name: impl Into<String>) -> Self {
11372 Self {
11373 key: key.into(),
11374 param_name: param_name.into(),
11375 }
11376 }
11377
11378 #[must_use]
11380 pub fn key(&self) -> &str {
11381 &self.key
11382 }
11383
11384 #[must_use]
11386 pub fn param_name(&self) -> &str {
11387 &self.param_name
11388 }
11389
11390 #[must_use]
11392 pub fn into_key(self) -> String {
11393 self.key
11394 }
11395}
11396
11397impl Deref for ApiKeyQuery {
11398 type Target = str;
11399
11400 fn deref(&self) -> &Self::Target {
11401 &self.key
11402 }
11403}
11404
11405impl AsRef<str> for ApiKeyQuery {
11406 fn as_ref(&self) -> &str {
11407 &self.key
11408 }
11409}
11410
11411impl SecureCompare for ApiKeyQuery {
11413 fn secure_eq(&self, other: &str) -> bool {
11414 constant_time_str_eq(&self.key, other)
11415 }
11416
11417 fn secure_eq_bytes(&self, other: &[u8]) -> bool {
11418 constant_time_eq(self.key.as_bytes(), other)
11419 }
11420}
11421
/// Reasons the [`ApiKeyQuery`] extractor can reject a request.
///
/// Both variants record the query parameter name that was consulted.
#[derive(Clone, Debug)]
pub enum ApiKeyQueryError {
    /// The query parameter was not present.
    MissingParam {
        /// Parameter that was looked up.
        param_name: String,
    },
    /// The parameter was present but empty after trimming.
    EmptyKey {
        /// Parameter that was looked up.
        param_name: String,
    },
}

impl ApiKeyQueryError {
    /// Constructs [`ApiKeyQueryError::MissingParam`].
    #[must_use]
    pub fn missing_param(param_name: impl Into<String>) -> Self {
        let param_name = param_name.into();
        ApiKeyQueryError::MissingParam { param_name }
    }

    /// Constructs [`ApiKeyQueryError::EmptyKey`].
    #[must_use]
    pub fn empty_key(param_name: impl Into<String>) -> Self {
        let param_name = param_name.into();
        ApiKeyQueryError::EmptyKey { param_name }
    }

    /// Client-facing detail message (used as the JSON `detail` in responses).
    #[must_use]
    pub fn detail(&self) -> String {
        match self {
            ApiKeyQueryError::MissingParam { param_name } => format!("API key required. Include '{param_name}' query parameter."),
            ApiKeyQueryError::EmptyKey { param_name } => format!("API key cannot be empty. Provide a value for '{param_name}'."),
        }
    }
}

impl fmt::Display for ApiKeyQueryError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            ApiKeyQueryError::MissingParam { param_name } => write!(f, "Missing API key query parameter: {param_name}"),
            ApiKeyQueryError::EmptyKey { param_name } => write!(f, "Empty API key in query parameter: {param_name}"),
        }
    }
}

impl std::error::Error for ApiKeyQueryError {}
11482
11483impl IntoResponse for ApiKeyQueryError {
11484 fn into_response(self) -> crate::response::Response {
11485 use crate::response::{Response, ResponseBody, StatusCode};
11486
11487 let body = serde_json::json!({
11488 "detail": self.detail()
11489 });
11490
11491 Response::with_status(StatusCode::UNAUTHORIZED)
11492 .header("content-type", b"application/json".to_vec())
11493 .body(ResponseBody::Bytes(body.to_string().into_bytes()))
11494 }
11495}
11496
11497impl FromRequest for ApiKeyQuery {
11498 type Error = ApiKeyQueryError;
11499
11500 async fn from_request(_ctx: &RequestContext, req: &mut Request) -> Result<Self, Self::Error> {
11501 let param_name = req.get_extension::<ApiKeyQueryConfig>().map_or_else(
11503 || DEFAULT_API_KEY_QUERY_PARAM.to_string(),
11504 |c| c.get_param_name().to_string(),
11505 );
11506
11507 let query_params = req.query().map(QueryParams::parse).unwrap_or_default();
11509
11510 let key_value = query_params
11512 .get(¶m_name)
11513 .ok_or_else(|| ApiKeyQueryError::missing_param(¶m_name))?;
11514
11515 let key = key_value.trim();
11517 if key.is_empty() {
11518 return Err(ApiKeyQueryError::empty_key(¶m_name));
11519 }
11520
11521 Ok(ApiKeyQuery::with_param_name(key, param_name))
11522 }
11523}
11524
/// Default cookie name for [`ApiKeyCookie`] (used by `ApiKeyCookie::new` and
/// `ApiKeyCookieConfig::default`).
pub const DEFAULT_API_KEY_COOKIE: &str = "api_key";

/// Configuration for the [`ApiKeyCookie`] extractor.
///
/// NOTE(review): presumably read from request extensions by the cookie extractor,
/// like the header and query variants. That extractor is not visible here.
#[derive(Clone, Debug)]
pub struct ApiKeyCookieConfig {
    /// Name of the cookie that carries the API key.
    cookie_name: String,
}

impl Default for ApiKeyCookieConfig {
    /// Uses [`DEFAULT_API_KEY_COOKIE`].
    fn default() -> Self {
        Self {
            cookie_name: String::from(DEFAULT_API_KEY_COOKIE),
        }
    }
}

impl ApiKeyCookieConfig {
    /// Equivalent to [`ApiKeyCookieConfig::default`].
    #[must_use]
    pub fn new() -> Self {
        Default::default()
    }

    /// Sets the cookie name to read the key from.
    #[must_use]
    pub fn cookie_name(self, name: impl Into<String>) -> Self {
        let mut config = self;
        config.cookie_name = name.into();
        config
    }

    /// Returns the configured cookie name.
    #[must_use]
    pub fn get_cookie_name(&self) -> &str {
        self.cookie_name.as_str()
    }
}
11567
11568#[derive(Debug, Clone, PartialEq, Eq)]
11620pub struct ApiKeyCookie {
11621 key: String,
11623 cookie_name: String,
11625}
11626
11627impl ApiKeyCookie {
11628 #[must_use]
11630 pub fn new(key: impl Into<String>) -> Self {
11631 Self {
11632 key: key.into(),
11633 cookie_name: DEFAULT_API_KEY_COOKIE.to_string(),
11634 }
11635 }
11636
11637 #[must_use]
11639 pub fn with_cookie_name(key: impl Into<String>, cookie_name: impl Into<String>) -> Self {
11640 Self {
11641 key: key.into(),
11642 cookie_name: cookie_name.into(),
11643 }
11644 }
11645
11646 #[must_use]
11648 pub fn key(&self) -> &str {
11649 &self.key
11650 }
11651
11652 #[must_use]
11654 pub fn cookie_name(&self) -> &str {
11655 &self.cookie_name
11656 }
11657
11658 #[must_use]
11660 pub fn into_key(self) -> String {
11661 self.key
11662 }
11663}
11664
11665impl Deref for ApiKeyCookie {
11666 type Target = str;
11667
11668 fn deref(&self) -> &Self::Target {
11669 &self.key
11670 }
11671}
11672
11673impl AsRef<str> for ApiKeyCookie {
11674 fn as_ref(&self) -> &str {
11675 &self.key
11676 }
11677}
11678
11679impl SecureCompare for ApiKeyCookie {
11681 fn secure_eq(&self, other: &str) -> bool {
11682 constant_time_str_eq(&self.key, other)
11683 }
11684
11685 fn secure_eq_bytes(&self, other: &[u8]) -> bool {
11686 constant_time_eq(self.key.as_bytes(), other)
11687 }
11688}
11689
/// Reasons the `ApiKeyCookie` extractor can reject a request.
#[derive(Debug, Clone)]
pub enum ApiKeyCookieError {
    /// The expected cookie was not present on the request.
    MissingCookie {
        /// Name of the cookie that was looked up.
        cookie_name: String,
    },
    /// The cookie was present but its value was empty.
    EmptyKey {
        /// Name of the cookie that was looked up.
        cookie_name: String,
    },
}

impl ApiKeyCookieError {
    /// Builds a [`ApiKeyCookieError::MissingCookie`].
    #[must_use]
    pub fn missing_cookie(cookie_name: impl Into<String>) -> Self {
        let cookie_name = cookie_name.into();
        Self::MissingCookie { cookie_name }
    }

    /// Builds a [`ApiKeyCookieError::EmptyKey`].
    #[must_use]
    pub fn empty_key(cookie_name: impl Into<String>) -> Self {
        let cookie_name = cookie_name.into();
        Self::EmptyKey { cookie_name }
    }

    /// Client-facing message placed in the JSON `detail` field.
    #[must_use]
    pub fn detail(&self) -> String {
        match self {
            Self::MissingCookie { cookie_name } => format!(
                "API key required. Include '{}' cookie.",
                cookie_name
            ),
            Self::EmptyKey { cookie_name } => format!(
                "API key cannot be empty. Provide a value for '{}' cookie.",
                cookie_name
            ),
        }
    }
}

impl fmt::Display for ApiKeyCookieError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let (prefix, cookie_name) = match self {
            Self::MissingCookie { cookie_name } => ("Missing API key cookie: ", cookie_name),
            Self::EmptyKey { cookie_name } => ("Empty API key in cookie: ", cookie_name),
        };
        f.write_str(prefix)?;
        f.write_str(cookie_name)
    }
}

impl std::error::Error for ApiKeyCookieError {}
11750
11751impl IntoResponse for ApiKeyCookieError {
11752 fn into_response(self) -> crate::response::Response {
11753 use crate::response::{Response, ResponseBody, StatusCode};
11754
11755 let body = serde_json::json!({
11756 "detail": self.detail()
11757 });
11758
11759 Response::with_status(StatusCode::UNAUTHORIZED)
11760 .header("content-type", b"application/json".to_vec())
11761 .body(ResponseBody::Bytes(body.to_string().into_bytes()))
11762 }
11763}
11764
11765impl FromRequest for ApiKeyCookie {
11766 type Error = ApiKeyCookieError;
11767
11768 async fn from_request(_ctx: &RequestContext, req: &mut Request) -> Result<Self, Self::Error> {
11769 let cookie_name = req.get_extension::<ApiKeyCookieConfig>().map_or_else(
11771 || DEFAULT_API_KEY_COOKIE.to_string(),
11772 |c| c.get_cookie_name().to_string(),
11773 );
11774
11775 let cookies = req
11777 .headers()
11778 .get("cookie")
11779 .and_then(|v| std::str::from_utf8(v).ok())
11780 .map(RequestCookies::from_header)
11781 .unwrap_or_default();
11782
11783 let key_value = cookies
11785 .get(&cookie_name)
11786 .ok_or_else(|| ApiKeyCookieError::missing_cookie(&cookie_name))?;
11787
11788 let key = key_value.trim();
11790 if key.is_empty() {
11791 return Err(ApiKeyCookieError::empty_key(&cookie_name));
11792 }
11793
11794 Ok(ApiKeyCookie::with_cookie_name(key, cookie_name))
11795 }
11796}
11797
/// Credentials extracted from an `Authorization: Basic ...` header.
///
/// `Debug` is implemented by hand so the plaintext password is never written to logs.
#[derive(Clone, PartialEq, Eq)]
pub struct BasicAuth {
    /// The user-id part (everything before the first `:`).
    username: String,
    /// The password part (everything after the first `:`, may itself contain `:`).
    password: String,
}

impl fmt::Debug for BasicAuth {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("BasicAuth")
            .field("username", &self.username)
            .field("password", &"[REDACTED]")
            .finish()
    }
}

impl BasicAuth {
    /// Creates credentials from a username and password.
    #[must_use]
    pub fn new(username: impl Into<String>, password: impl Into<String>) -> Self {
        Self {
            username: username.into(),
            password: password.into(),
        }
    }

    /// Returns the username.
    #[must_use]
    pub fn username(&self) -> &str {
        &self.username
    }

    /// Returns the plaintext password. Compare it with a constant-time helper.
    #[must_use]
    pub fn password(&self) -> &str {
        &self.password
    }

    /// Consumes the credentials, returning `(username, password)`.
    #[must_use]
    pub fn into_credentials(self) -> (String, String) {
        (self.username, self.password)
    }
}
11880
/// Reasons the `BasicAuth` extractor can reject a request.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BasicAuthError {
    /// No `Authorization` header was sent.
    MissingHeader,
    /// The header did not use the `Basic` scheme.
    InvalidScheme,
    /// The credential payload was not valid base64.
    InvalidBase64,
    /// The decoded credentials had no `:` separator.
    MissingColon,
    /// The header or the decoded credentials were not valid UTF-8.
    InvalidUtf8,
}

impl BasicAuthError {
    /// Returns [`BasicAuthError::MissingHeader`].
    #[must_use]
    pub fn missing_header() -> Self {
        BasicAuthError::MissingHeader
    }

    /// Returns [`BasicAuthError::InvalidScheme`].
    #[must_use]
    pub fn invalid_scheme() -> Self {
        BasicAuthError::InvalidScheme
    }

    /// Returns [`BasicAuthError::InvalidBase64`].
    #[must_use]
    pub fn invalid_base64() -> Self {
        BasicAuthError::InvalidBase64
    }

    /// Returns [`BasicAuthError::MissingColon`].
    #[must_use]
    pub fn missing_colon() -> Self {
        BasicAuthError::MissingColon
    }

    /// Returns [`BasicAuthError::InvalidUtf8`].
    #[must_use]
    pub fn invalid_utf8() -> Self {
        BasicAuthError::InvalidUtf8
    }

    /// Client-facing message. Every malformed-credentials case shares one message so the
    /// response does not reveal which parsing step failed.
    #[must_use]
    pub fn detail(&self) -> &'static str {
        match self {
            Self::MissingHeader => "Not authenticated",
            Self::InvalidScheme | Self::InvalidBase64 | Self::MissingColon | Self::InvalidUtf8 => {
                "Invalid authentication credentials"
            }
        }
    }
}

impl fmt::Display for BasicAuthError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            Self::MissingHeader => "Missing Authorization header",
            Self::InvalidScheme => "Authorization header must use Basic scheme",
            Self::InvalidBase64 => "Invalid base64 encoding in credentials",
            Self::MissingColon => "Credentials must contain username:password",
            Self::InvalidUtf8 => "Credentials contain invalid UTF-8",
        };
        f.write_str(message)
    }
}

impl std::error::Error for BasicAuthError {}
11953
11954impl IntoResponse for BasicAuthError {
11955 fn into_response(self) -> crate::response::Response {
11956 use crate::response::{Response, ResponseBody, StatusCode};
11957
11958 let body = serde_json::json!({
11959 "detail": self.detail()
11960 });
11961
11962 Response::with_status(StatusCode::UNAUTHORIZED)
11963 .header("www-authenticate", b"Basic".to_vec())
11964 .header("content-type", b"application/json".to_vec())
11965 .body(ResponseBody::Bytes(body.to_string().into_bytes()))
11966 }
11967}
11968
11969fn decode_base64(input: &str) -> Result<Vec<u8>, BasicAuthError> {
11974 const INVALID: u8 = 0xFF;
11975 const DECODE_TABLE: [u8; 256] = {
11976 let mut table = [INVALID; 256];
11977 let mut i = 0u8;
11978 while i < 26 {
11980 table[(b'A' + i) as usize] = i;
11981 i += 1;
11982 }
11983 i = 0;
11985 while i < 26 {
11986 table[(b'a' + i) as usize] = 26 + i;
11987 i += 1;
11988 }
11989 i = 0;
11991 while i < 10 {
11992 table[(b'0' + i) as usize] = 52 + i;
11993 i += 1;
11994 }
11995 table[b'+' as usize] = 62;
11997 table[b'/' as usize] = 63;
11998 table
11999 };
12000
12001 let input = input.trim_end_matches('=').trim();
12003 if input.is_empty() {
12004 return Ok(Vec::new());
12005 }
12006
12007 let mut output = Vec::with_capacity((input.len() * 3) / 4);
12008 let mut buffer: u32 = 0;
12009 let mut bits_collected: u8 = 0;
12010
12011 for byte in input.bytes() {
12012 let value = DECODE_TABLE[byte as usize];
12013 if value == INVALID {
12014 return Err(BasicAuthError::InvalidBase64);
12015 }
12016
12017 buffer = (buffer << 6) | u32::from(value);
12018 bits_collected += 6;
12019
12020 if bits_collected >= 8 {
12021 bits_collected -= 8;
12022 output.push((buffer >> bits_collected) as u8);
12023 buffer &= (1 << bits_collected) - 1;
12024 }
12025 }
12026
12027 Ok(output)
12028}
12029
12030impl FromRequest for BasicAuth {
12031 type Error = BasicAuthError;
12032
12033 async fn from_request(_ctx: &RequestContext, req: &mut Request) -> Result<Self, Self::Error> {
12034 let auth_header = req
12036 .headers()
12037 .get("authorization")
12038 .ok_or(BasicAuthError::MissingHeader)?;
12039
12040 let auth_str = std::str::from_utf8(auth_header).map_err(|_| BasicAuthError::InvalidUtf8)?;
12042
12043 const BASIC_PREFIX: &str = "Basic ";
12045 const BASIC_PREFIX_LOWER: &str = "basic ";
12046
12047 let encoded = if auth_str.starts_with(BASIC_PREFIX) {
12048 &auth_str[BASIC_PREFIX.len()..]
12049 } else if auth_str.starts_with(BASIC_PREFIX_LOWER) {
12050 &auth_str[BASIC_PREFIX_LOWER.len()..]
12051 } else {
12052 return Err(BasicAuthError::InvalidScheme);
12053 };
12054
12055 let decoded_bytes = decode_base64(encoded.trim())?;
12057
12058 let decoded = String::from_utf8(decoded_bytes).map_err(|_| BasicAuthError::InvalidUtf8)?;
12060
12061 let colon_pos = decoded.find(':').ok_or(BasicAuthError::MissingColon)?;
12063 let (username, password_with_colon) = decoded.split_at(colon_pos);
12064 let password = &password_with_colon[1..]; Ok(BasicAuth::new(username, password))
12067 }
12068}
12069
/// Raw credentials from an `Authorization: Digest ...` header (everything after the scheme).
#[derive(Debug, Clone)]
pub struct DigestAuth {
    /// Comma-separated auth-params, e.g. `username="Mufasa", realm="...", nc=00000001`.
    credentials: String,
}

impl DigestAuth {
    /// Wraps an already-stripped credentials string.
    #[must_use]
    pub fn new(credentials: impl Into<String>) -> Self {
        Self {
            credentials: credentials.into(),
        }
    }

    /// Returns the raw credentials string.
    #[must_use]
    pub fn credentials(&self) -> &str {
        &self.credentials
    }

    /// Looks up an auth-param by name (matched ASCII case-insensitively, per RFC 7235).
    ///
    /// The credentials are tokenised as `name=value` pairs separated by commas. A value is
    /// either a token, which is trimmed and ends at the next `,`, or a quoted-string.
    /// Because quoted strings are skipped as a unit, text such as `realm="a, nonce=x"` can no
    /// longer be mistaken for a `nonce` parameter. For quoted values the raw contents between
    /// the quotes are returned, so backslash escapes are kept as-is. Returns `None` if the
    /// parameter is absent or a quoted value is never closed.
    pub fn param(&self, key: &str) -> Option<&str> {
        let mut rest = self.credentials.as_str();

        loop {
            // Skip separators between auth-params.
            rest = rest.trim_start_matches(|c: char| c == ',' || c == ' ' || c == '\t');
            if rest.is_empty() {
                return None;
            }

            // A ',' before any '=' means a bare token without a value; skip it.
            let delim = rest.find(|c: char| c == '=' || c == ',')?;
            if rest.as_bytes()[delim] == b',' {
                rest = &rest[delim..];
                continue;
            }

            let name = rest[..delim].trim();
            let after_eq = rest[delim + 1..].trim_start_matches(|c: char| c == ' ' || c == '\t');

            let (value, remainder) = if let Some(quoted) = after_eq.strip_prefix('"') {
                let end = Self::closing_quote(quoted)?;
                (&quoted[..end], &quoted[end + 1..])
            } else {
                let end = after_eq.find(',').unwrap_or(after_eq.len());
                (after_eq[..end].trim(), &after_eq[end..])
            };

            if name.eq_ignore_ascii_case(key) {
                return Some(value);
            }
            rest = remainder;
        }
    }

    /// Byte index of the closing `"` of a quoted-string body, honouring `\` escapes.
    fn closing_quote(body: &str) -> Option<usize> {
        let bytes = body.as_bytes();
        let mut i = 0;
        while i < bytes.len() {
            match bytes[i] {
                // Skip the escaped byte; '"' is ASCII, so a match is always a char boundary.
                b'\\' => i += 2,
                b'"' => return Some(i),
                _ => i += 1,
            }
        }
        None
    }
}
12147
/// Reasons the `DigestAuth` extractor can reject a request.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DigestAuthError {
    /// No `Authorization` header was sent.
    MissingHeader,
    /// The header did not use the `Digest` scheme.
    InvalidScheme,
    /// The header value was not valid UTF-8.
    InvalidUtf8,
}

impl DigestAuthError {
    /// Client-facing message; malformed headers share one generic message.
    #[must_use]
    pub fn detail(&self) -> &'static str {
        if let Self::MissingHeader = self {
            return "Not authenticated";
        }
        "Invalid authentication credentials"
    }
}

impl fmt::Display for DigestAuthError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(match self {
            Self::MissingHeader => "Missing Authorization header",
            Self::InvalidScheme => "Authorization header must use Digest scheme",
            Self::InvalidUtf8 => "Authorization header contains invalid UTF-8",
        })
    }
}

impl std::error::Error for DigestAuthError {}
12182
12183impl IntoResponse for DigestAuthError {
12184 fn into_response(self) -> crate::response::Response {
12185 use crate::response::{Response, ResponseBody, StatusCode};
12186
12187 let body = serde_json::json!({
12188 "detail": self.detail()
12189 });
12190
12191 Response::with_status(StatusCode::UNAUTHORIZED)
12192 .header("www-authenticate", b"Digest".to_vec())
12193 .header("content-type", b"application/json".to_vec())
12194 .body(ResponseBody::Bytes(body.to_string().into_bytes()))
12195 }
12196}
12197
12198impl FromRequest for DigestAuth {
12199 type Error = DigestAuthError;
12200
12201 async fn from_request(_ctx: &RequestContext, req: &mut Request) -> Result<Self, Self::Error> {
12202 let auth_header = req
12203 .headers()
12204 .get("authorization")
12205 .ok_or(DigestAuthError::MissingHeader)?;
12206
12207 let auth_str =
12208 std::str::from_utf8(auth_header).map_err(|_| DigestAuthError::InvalidUtf8)?;
12209
12210 const DIGEST_PREFIX: &str = "Digest ";
12211 const DIGEST_PREFIX_LOWER: &str = "digest ";
12212
12213 let credentials = if auth_str.starts_with(DIGEST_PREFIX) {
12214 &auth_str[DIGEST_PREFIX.len()..]
12215 } else if auth_str.starts_with(DIGEST_PREFIX_LOWER) {
12216 &auth_str[DIGEST_PREFIX_LOWER.len()..]
12217 } else {
12218 return Err(DigestAuthError::InvalidScheme);
12219 };
12220
12221 Ok(DigestAuth::new(credentials.trim()))
12222 }
12223}
12224
/// Compares two byte slices without short-circuiting on the first differing byte.
///
/// Slices of different lengths return `false` immediately, so the *length* of a secret is
/// not hidden; only its contents are. Each accumulation step goes through
/// [`std::hint::black_box`] to discourage the optimizer from turning the fold back into an
/// early-exit loop. This is a best-effort measure, not a formal constant-time guarantee.
#[must_use]
pub fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }

    // OR together the XOR of every byte pair: the result is zero only if all pairs match.
    let diff = a
        .iter()
        .zip(b)
        .fold(0u8, |acc, (x, y)| std::hint::black_box(acc | (x ^ y)));

    diff == 0
}

/// String convenience wrapper over [`constant_time_eq`] that compares UTF-8 bytes.
#[must_use]
#[inline]
pub fn constant_time_str_eq(a: &str, b: &str) -> bool {
    constant_time_eq(a.as_bytes(), b.as_bytes())
}
12327
12328pub trait SecureCompare {
12344 fn secure_eq(&self, other: &str) -> bool;
12349
12350 fn secure_eq_bytes(&self, other: &[u8]) -> bool;
12352}
12353
12354impl SecureCompare for BearerToken {
12355 #[inline]
12356 fn secure_eq(&self, other: &str) -> bool {
12357 constant_time_str_eq(self.token(), other)
12358 }
12359
12360 #[inline]
12361 fn secure_eq_bytes(&self, other: &[u8]) -> bool {
12362 constant_time_eq(self.token().as_bytes(), other)
12363 }
12364}
12365
12366impl SecureCompare for str {
12367 #[inline]
12368 fn secure_eq(&self, other: &str) -> bool {
12369 constant_time_str_eq(self, other)
12370 }
12371
12372 #[inline]
12373 fn secure_eq_bytes(&self, other: &[u8]) -> bool {
12374 constant_time_eq(self.as_bytes(), other)
12375 }
12376}
12377
12378impl SecureCompare for String {
12379 #[inline]
12380 fn secure_eq(&self, other: &str) -> bool {
12381 constant_time_str_eq(self, other)
12382 }
12383
12384 #[inline]
12385 fn secure_eq_bytes(&self, other: &[u8]) -> bool {
12386 constant_time_eq(self.as_bytes(), other)
12387 }
12388}
12389
12390impl SecureCompare for [u8] {
12391 #[inline]
12392 fn secure_eq(&self, other: &str) -> bool {
12393 constant_time_eq(self, other.as_bytes())
12394 }
12395
12396 #[inline]
12397 fn secure_eq_bytes(&self, other: &[u8]) -> bool {
12398 constant_time_eq(self, other)
12399 }
12400}
12401
12402impl<const N: usize> SecureCompare for [u8; N] {
12403 #[inline]
12404 fn secure_eq(&self, other: &str) -> bool {
12405 constant_time_eq(self, other.as_bytes())
12406 }
12407
12408 #[inline]
12409 fn secure_eq_bytes(&self, other: &[u8]) -> bool {
12410 constant_time_eq(self, other)
12411 }
12412}
12413
12414impl SecureCompare for Vec<u8> {
12415 #[inline]
12416 fn secure_eq(&self, other: &str) -> bool {
12417 constant_time_eq(self, other.as_bytes())
12418 }
12419
12420 #[inline]
12421 fn secure_eq_bytes(&self, other: &[u8]) -> bool {
12422 constant_time_eq(self, other)
12423 }
12424}
12425
/// Page number used when a request does not specify one (pages are 1-based).
pub const DEFAULT_PAGE: u64 = 1;
/// Page size used when a request does not specify one.
pub const DEFAULT_PER_PAGE: u64 = 20;
/// Largest page size the pagination helpers clamp to.
pub const MAX_PER_PAGE: u64 = 100;

/// Page/offset pagination parameters.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Pagination {
    /// 1-based page number.
    page: u64,
    /// Items per page.
    per_page: u64,
    /// Explicit item offset, overriding the page-derived offset when set.
    offset: Option<u64>,
}

impl Default for Pagination {
    fn default() -> Self {
        Self {
            offset: None,
            page: DEFAULT_PAGE,
            per_page: DEFAULT_PER_PAGE,
        }
    }
}
12507
12508impl Pagination {
12509 #[must_use]
12511 pub fn new(page: u64, per_page: u64) -> Self {
12512 Self {
12513 page: page.max(1),
12514 per_page: per_page.clamp(1, MAX_PER_PAGE),
12515 offset: None,
12516 }
12517 }
12518
12519 #[must_use]
12521 pub fn from_offset(offset: u64, limit: u64) -> Self {
12522 Self {
12523 page: (offset / limit.max(1)) + 1,
12524 per_page: limit.clamp(1, MAX_PER_PAGE),
12525 offset: Some(offset),
12526 }
12527 }
12528
12529 #[must_use]
12531 pub fn page(&self) -> u64 {
12532 self.page
12533 }
12534
12535 #[must_use]
12537 pub fn per_page(&self) -> u64 {
12538 self.per_page
12539 }
12540
12541 #[must_use]
12543 pub fn limit(&self) -> u64 {
12544 self.per_page
12545 }
12546
12547 #[must_use]
12552 pub fn offset(&self) -> u64 {
12553 self.offset
12554 .unwrap_or_else(|| (self.page.saturating_sub(1)) * self.per_page)
12555 }
12556
12557 #[must_use]
12559 pub fn total_pages(&self, total_items: u64) -> u64 {
12560 if self.per_page == 0 {
12561 return 0;
12562 }
12563 total_items.div_ceil(self.per_page)
12564 }
12565
12566 #[must_use]
12568 pub fn has_next(&self, total_items: u64) -> bool {
12569 self.page < self.total_pages(total_items)
12570 }
12571
12572 #[must_use]
12574 pub fn has_prev(&self) -> bool {
12575 self.page > 1
12576 }
12577
12578 #[must_use]
12586 pub fn paginate<T>(self, items: Vec<T>, total: u64, base_url: &str) -> Page<T> {
12587 Page::new(items, total, self, base_url.to_string())
12588 }
12589}
12590
12591impl FromRequest for Pagination {
12592 type Error = std::convert::Infallible;
12593
12594 async fn from_request(_ctx: &RequestContext, req: &mut Request) -> Result<Self, Self::Error> {
12595 let query = req
12596 .get_extension::<QueryParams>()
12597 .cloned()
12598 .unwrap_or_default();
12599
12600 let page = query
12602 .get("page")
12603 .and_then(|v: &str| v.parse::<u64>().ok())
12604 .unwrap_or(DEFAULT_PAGE)
12605 .max(1);
12606
12607 let per_page = query
12609 .get("per_page")
12610 .or_else(|| query.get("limit"))
12611 .and_then(|v: &str| v.parse::<u64>().ok())
12612 .unwrap_or(DEFAULT_PER_PAGE)
12613 .clamp(1, MAX_PER_PAGE);
12614
12615 let offset = query
12617 .get("offset")
12618 .and_then(|v: &str| v.parse::<u64>().ok());
12619
12620 Ok(Pagination {
12621 page,
12622 per_page,
12623 offset,
12624 })
12625 }
12626}
12627
12628#[derive(Debug, Clone, Copy)]
12632pub struct PaginationConfig {
12633 pub default_per_page: u64,
12635 pub max_per_page: u64,
12637 pub default_page: u64,
12639}
12640
12641impl Default for PaginationConfig {
12642 fn default() -> Self {
12643 Self {
12644 default_per_page: DEFAULT_PER_PAGE,
12645 max_per_page: MAX_PER_PAGE,
12646 default_page: DEFAULT_PAGE,
12647 }
12648 }
12649}
12650
12651impl PaginationConfig {
12652 #[must_use]
12654 pub fn new() -> Self {
12655 Self::default()
12656 }
12657
12658 #[must_use]
12660 pub fn default_per_page(mut self, value: u64) -> Self {
12661 self.default_per_page = value;
12662 self
12663 }
12664
12665 #[must_use]
12667 pub fn max_per_page(mut self, value: u64) -> Self {
12668 self.max_per_page = value;
12669 self
12670 }
12671
12672 #[must_use]
12674 pub fn default_page(mut self, value: u64) -> Self {
12675 self.default_page = value;
12676 self
12677 }
12678}
12679
12680#[derive(Debug, Clone)]
12722pub struct Page<T> {
12723 pub items: Vec<T>,
12725 pub total: u64,
12727 pub page: u64,
12729 pub per_page: u64,
12731 pub pages: u64,
12733 base_url: String,
12735}
12736
12737impl<T> Page<T> {
12738 #[must_use]
12740 pub fn new(items: Vec<T>, total: u64, pagination: Pagination, base_url: String) -> Self {
12741 let pages = pagination.total_pages(total);
12742 Self {
12743 items,
12744 total,
12745 page: pagination.page(),
12746 per_page: pagination.per_page(),
12747 pages,
12748 base_url,
12749 }
12750 }
12751
12752 #[must_use]
12754 pub fn with_values(
12755 items: Vec<T>,
12756 total: u64,
12757 page: u64,
12758 per_page: u64,
12759 base_url: impl Into<String>,
12760 ) -> Self {
12761 let pages = if per_page > 0 {
12762 total.div_ceil(per_page)
12763 } else {
12764 0
12765 };
12766 Self {
12767 items,
12768 total,
12769 page,
12770 per_page,
12771 pages,
12772 base_url: base_url.into(),
12773 }
12774 }
12775
12776 #[must_use]
12778 pub fn len(&self) -> usize {
12779 self.items.len()
12780 }
12781
12782 #[must_use]
12784 pub fn is_empty(&self) -> bool {
12785 self.items.is_empty()
12786 }
12787
12788 #[must_use]
12790 pub fn has_next(&self) -> bool {
12791 self.page < self.pages
12792 }
12793
12794 #[must_use]
12796 pub fn has_prev(&self) -> bool {
12797 self.page > 1
12798 }
12799
12800 #[must_use]
12808 pub fn link_header(&self) -> String {
12809 let mut links = Vec::with_capacity(4);
12810
12811 links.push(format!(
12813 "<{}?page=1&per_page={}>; rel=\"first\"",
12814 self.base_url, self.per_page
12815 ));
12816
12817 if self.has_prev() {
12819 links.push(format!(
12820 "<{}?page={}&per_page={}>; rel=\"prev\"",
12821 self.base_url,
12822 self.page - 1,
12823 self.per_page
12824 ));
12825 }
12826
12827 if self.has_next() {
12829 links.push(format!(
12830 "<{}?page={}&per_page={}>; rel=\"next\"",
12831 self.base_url,
12832 self.page + 1,
12833 self.per_page
12834 ));
12835 }
12836
12837 links.push(format!(
12839 "<{}?page={}&per_page={}>; rel=\"last\"",
12840 self.base_url, self.pages, self.per_page
12841 ));
12842
12843 links.join(", ")
12844 }
12845
12846 pub fn map<U, F>(self, f: F) -> Page<U>
12848 where
12849 F: FnMut(T) -> U,
12850 {
12851 Page {
12852 items: self.items.into_iter().map(f).collect(),
12853 total: self.total,
12854 page: self.page,
12855 per_page: self.per_page,
12856 pages: self.pages,
12857 base_url: self.base_url,
12858 }
12859 }
12860}
12861
12862#[derive(serde::Serialize)]
12864struct PageJson<'a, T: serde::Serialize> {
12865 items: &'a Vec<T>,
12866 total: u64,
12867 page: u64,
12868 per_page: u64,
12869 pages: u64,
12870}
12871
12872impl<T: serde::Serialize> IntoResponse for Page<T> {
12873 fn into_response(self) -> crate::response::Response {
12874 let json_body = PageJson {
12875 items: &self.items,
12876 total: self.total,
12877 page: self.page,
12878 per_page: self.per_page,
12879 pages: self.pages,
12880 };
12881
12882 let Ok(body_bytes) = serde_json::to_vec(&json_body) else {
12884 return crate::response::Response::with_status(
12886 crate::response::StatusCode::INTERNAL_SERVER_ERROR,
12887 )
12888 .header("content-type", b"application/json".to_vec())
12889 .body(crate::response::ResponseBody::Bytes(
12890 b"{\"error\":\"Serialization failed\"}".to_vec(),
12891 ));
12892 };
12893
12894 let link_header = self.link_header();
12896
12897 crate::response::Response::ok()
12898 .header("content-type", b"application/json".to_vec())
12899 .header("link", link_header.into_bytes())
12900 .header("x-total-count", self.total.to_string().into_bytes())
12901 .header("x-page", self.page.to_string().into_bytes())
12902 .header("x-per-page", self.per_page.to_string().into_bytes())
12903 .header("x-total-pages", self.pages.to_string().into_bytes())
12904 .body(crate::response::ResponseBody::Bytes(body_bytes))
12905 }
12906}
12907
/// An ordered list of typed values collected for a header.
///
/// `N` is a marker type parameter that is only stored as `PhantomData`. Presumably it
/// identifies the header name at the type level (TODO confirm against the constructing code).
#[derive(Debug, Clone)]
pub struct HeaderValues<T, N> {
    /// The collected values, in order.
    pub values: Vec<T>,
    _marker: std::marker::PhantomData<N>,
}

impl<T, N> HeaderValues<T, N> {
    /// Wraps an existing vector of values.
    #[must_use]
    pub fn new(values: Vec<T>) -> Self {
        let _marker = std::marker::PhantomData;
        Self { values, _marker }
    }

    /// Whether no values were collected.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Number of collected values.
    #[must_use]
    pub fn len(&self) -> usize {
        Vec::len(&self.values)
    }
}

impl<T, N> Deref for HeaderValues<T, N> {
    type Target = Vec<T>;

    fn deref(&self) -> &Vec<T> {
        &self.values
    }
}
12948
/// A media type such as `text/html; charset=utf-8`.
///
/// Type, subtype and parameter names are stored lowercased. Parameter values keep their
/// case but lose surrounding whitespace and `"` characters.
#[derive(Debug, Clone, PartialEq)]
pub struct MediaType {
    /// Top-level type, e.g. `text` (or `*`).
    pub typ: String,
    /// Subtype, e.g. `html` (or `*`).
    pub subtype: String,
    /// Parameters in header order, excluding any `q` weight.
    pub params: Vec<(String, String)>,
}

impl MediaType {
    /// Parses `type/subtype[; name=value]*`.
    ///
    /// Returns `None` if there is no `/` or either side is empty. Parameters without `=`
    /// are ignored, and a `q` parameter is dropped.
    pub fn parse(s: &str) -> Option<Self> {
        let trimmed = s.trim();
        let (essence, param_text) = match trimmed.split_once(';') {
            Some((head, tail)) => (head, Some(tail)),
            None => (trimmed, None),
        };

        let (raw_type, raw_subtype) = essence.split_once('/')?;
        let typ = raw_type.trim().to_ascii_lowercase();
        let subtype = raw_subtype.trim().to_ascii_lowercase();
        if typ.is_empty() || subtype.is_empty() {
            return None;
        }

        let params = param_text
            .into_iter()
            .flat_map(|text| text.split(';'))
            .map(str::trim)
            .filter(|segment| !segment.is_empty())
            .filter_map(|segment| segment.split_once('='))
            .map(|(name, value)| {
                (
                    name.trim().to_ascii_lowercase(),
                    value.trim().trim_matches('"').to_string(),
                )
            })
            .filter(|(name, _)| name != "q")
            .collect();

        Some(Self {
            typ,
            subtype,
            params,
        })
    }

    /// Creates a parameterless media type; both parts are lowercased.
    #[must_use]
    pub fn new(typ: impl Into<String>, subtype: impl Into<String>) -> Self {
        let typ: String = typ.into();
        let subtype: String = subtype.into();
        Self {
            typ: typ.to_ascii_lowercase(),
            subtype: subtype.to_ascii_lowercase(),
            params: Vec::new(),
        }
    }

    /// Whether `self` falls within the (possibly wildcarded) range `other`.
    ///
    /// Only `other` may contain `*`; parameters are not compared.
    #[must_use]
    pub fn matches(&self, other: &MediaType) -> bool {
        let type_ok = other.typ == "*" || other.typ == self.typ;
        type_ok && (other.subtype == "*" || other.subtype == self.subtype)
    }

    /// Returns `type/subtype` without parameters.
    #[must_use]
    pub fn essence(&self) -> String {
        let mut out = String::with_capacity(self.typ.len() + 1 + self.subtype.len());
        out.push_str(&self.typ);
        out.push('/');
        out.push_str(&self.subtype);
        out
    }

    /// Looks up a parameter value by name (the name is lowercased before comparing).
    #[must_use]
    pub fn param(&self, name: &str) -> Option<&str> {
        let wanted = name.to_ascii_lowercase();
        self.params
            .iter()
            .find_map(|(key, value)| (*key == wanted).then_some(value.as_str()))
    }
}

impl fmt::Display for MediaType {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.typ)?;
        f.write_str("/")?;
        f.write_str(&self.subtype)?;
        self.params
            .iter()
            .try_for_each(|(key, value)| write!(f, "; {}={}", key, value))
    }
}
13075
13076#[derive(Debug, Clone)]
13078pub struct AcceptItem {
13079 pub media_type: MediaType,
13081 pub quality: f32,
13083}
13084
13085impl AcceptItem {
13086 pub fn parse(s: &str) -> Option<Self> {
13096 let s = s.trim();
13097 let mut quality = 1.0f32;
13098
13099 let media_str = if let Some(q_pos) = s.to_ascii_lowercase().find(";q=") {
13101 let after_q = &s[q_pos + 3..];
13102 let q_end = after_q.find(';').unwrap_or(after_q.len());
13103 let q_str = &after_q[..q_end];
13104 if let Ok(q) = q_str.trim().parse::<f32>() {
13105 quality = q.clamp(0.0, 1.0);
13106 }
13107 let before = &s[..q_pos];
13109 let after = if q_end < after_q.len() {
13110 &after_q[q_end..]
13111 } else {
13112 ""
13113 };
13114 format!("{}{}", before, after)
13115 } else {
13116 s.to_string()
13117 };
13118
13119 let media_type = MediaType::parse(&media_str)?;
13120 Some(Self {
13121 media_type,
13122 quality,
13123 })
13124 }
13125}
13126
13127impl PartialEq for AcceptItem {
13128 fn eq(&self, other: &Self) -> bool {
13129 self.media_type == other.media_type && (self.quality - other.quality).abs() < f32::EPSILON
13130 }
13131}
13132
13133#[derive(Debug, Clone)]
13155pub struct AcceptHeader {
13156 pub items: Vec<AcceptItem>,
13158}
13159
13160impl AcceptHeader {
13161 #[must_use]
13172 pub fn parse(s: &str) -> Self {
13173 let mut items: Vec<AcceptItem> = s.split(',').filter_map(AcceptItem::parse).collect();
13174
13175 items.sort_by(|a, b| {
13177 let q_cmp = b
13179 .quality
13180 .partial_cmp(&a.quality)
13181 .unwrap_or(std::cmp::Ordering::Equal);
13182 if q_cmp != std::cmp::Ordering::Equal {
13183 return q_cmp;
13184 }
13185 let a_wildcards =
13187 u8::from(a.media_type.typ == "*") + u8::from(a.media_type.subtype == "*");
13188 let b_wildcards =
13189 u8::from(b.media_type.typ == "*") + u8::from(b.media_type.subtype == "*");
13190 a_wildcards.cmp(&b_wildcards)
13191 });
13192
13193 Self { items }
13194 }
13195
13196 #[must_use]
13198 pub fn any() -> Self {
13199 Self {
13200 items: vec![AcceptItem {
13201 media_type: MediaType::new("*", "*"),
13202 quality: 1.0,
13203 }],
13204 }
13205 }
13206
13207 #[must_use]
13209 pub fn accepts(&self, media_type: &str) -> bool {
13210 if self.items.is_empty() {
13211 return true; }
13213
13214 let Some(mt) = MediaType::parse(media_type) else {
13215 return false;
13216 };
13217
13218 self.items
13219 .iter()
13220 .any(|item| item.quality > 0.0 && mt.matches(&item.media_type))
13221 }
13222
13223 #[must_use]
13228 pub fn prefers(&self, media_type: &str) -> bool {
13229 let Some(mt) = MediaType::parse(media_type) else {
13230 return false;
13231 };
13232
13233 self.items
13234 .first()
13235 .map(|item| mt.matches(&item.media_type))
13236 .unwrap_or(true)
13237 }
13238
13239 #[must_use]
13243 pub fn quality_of(&self, media_type: &str) -> f32 {
13244 if self.items.is_empty() {
13245 return 1.0; }
13247
13248 let Some(mt) = MediaType::parse(media_type) else {
13249 return 0.0;
13250 };
13251
13252 self.items
13253 .iter()
13254 .find(|item| mt.matches(&item.media_type))
13255 .map(|item| item.quality)
13256 .unwrap_or(0.0)
13257 }
13258
13259 #[must_use]
13273 pub fn negotiate<'a>(&self, available: &[&'a str]) -> Option<&'a str> {
13274 if self.items.is_empty() {
13275 return available.first().copied();
13276 }
13277
13278 let mut scored: Vec<(&str, f32, usize)> = available
13280 .iter()
13281 .enumerate()
13282 .filter_map(|(idx, &media_type)| {
13283 let q = self.quality_of(media_type);
13284 if q > 0.0 {
13285 Some((media_type, q, idx))
13286 } else {
13287 None
13288 }
13289 })
13290 .collect();
13291
13292 scored.sort_by(|a, b| {
13294 b.1.partial_cmp(&a.1)
13295 .unwrap_or(std::cmp::Ordering::Equal)
13296 .then_with(|| a.2.cmp(&b.2))
13297 });
13298
13299 scored.first().map(|(mt, _, _)| *mt)
13300 }
13301
13302 #[must_use]
13304 pub fn is_empty(&self) -> bool {
13305 self.items.is_empty()
13306 }
13307}
13308
13309impl Default for AcceptHeader {
13310 fn default() -> Self {
13311 Self::any()
13312 }
13313}
13314
13315impl FromRequest for AcceptHeader {
13316 type Error = std::convert::Infallible;
13317
13318 async fn from_request(_ctx: &RequestContext, req: &mut Request) -> Result<Self, Self::Error> {
13319 let header = req
13320 .headers()
13321 .get("accept")
13322 .and_then(|v| std::str::from_utf8(v).ok())
13323 .map(Self::parse)
13324 .unwrap_or_else(Self::any);
13325 Ok(header)
13326 }
13327}
13328
/// One coding from an `Accept-Encoding` header with its quality weight.
#[derive(Debug, Clone)]
pub struct AcceptEncodingItem {
    /// Lowercased content-coding name (`gzip`, `br`, `*`, ...).
    pub encoding: String,
    /// Quality weight in `0.0..=1.0`; `1.0` when no valid `q` is given.
    pub quality: f32,
}

impl AcceptEncodingItem {
    /// Parses one element such as `gzip;q=0.8` or `br ; q=0.5`.
    ///
    /// The coding name is the text before the first `;`. The first `q` parameter (name matched
    /// case-insensitively, whitespace allowed) sets the quality, clamped to `0..=1`; invalid
    /// or NaN values keep `1.0`. Returns `None` when the coding name is empty.
    pub fn parse(s: &str) -> Option<Self> {
        let mut segments = s.split(';');
        let encoding = segments.next()?.trim().to_ascii_lowercase();
        if encoding.is_empty() {
            return None;
        }

        let quality = segments
            .filter_map(|segment| segment.split_once('='))
            .find(|(name, _)| name.trim().eq_ignore_ascii_case("q"))
            .and_then(|(_, value)| value.trim().parse::<f32>().ok())
            .filter(|q| !q.is_nan())
            .map_or(1.0, |q| q.clamp(0.0, 1.0));

        Some(Self { encoding, quality })
    }
}

/// A parsed `Accept-Encoding` header.
#[derive(Debug, Clone, Default)]
pub struct AcceptEncodingHeader {
    /// Codings in descending quality order when built via `parse`.
    pub items: Vec<AcceptEncodingItem>,
}

impl AcceptEncodingHeader {
    /// Parses a full header value; invalid elements are skipped.
    #[must_use]
    pub fn parse(s: &str) -> Self {
        let mut items: Vec<AcceptEncodingItem> =
            s.split(',').filter_map(AcceptEncodingItem::parse).collect();

        items.sort_by(|a, b| {
            b.quality
                .partial_cmp(&a.quality)
                .unwrap_or(std::cmp::Ordering::Equal)
        });

        Self { items }
    }

    /// Quality for `encoding`. An explicit entry takes precedence over `*`, so
    /// `gzip;q=0, *` rejects gzip (RFC 9110 §12.5.3). Returns `None` if neither is listed.
    fn quality_for(&self, encoding: &str) -> Option<f32> {
        let encoding = encoding.to_ascii_lowercase();
        self.items
            .iter()
            .find(|item| item.encoding == encoding)
            .or_else(|| self.items.iter().find(|item| item.encoding == "*"))
            .map(|item| item.quality)
    }

    /// Whether `encoding` (case-insensitive) is acceptable with a quality above zero.
    /// An empty header accepts nothing.
    #[must_use]
    pub fn accepts(&self, encoding: &str) -> bool {
        self.quality_for(encoding).is_some_and(|q| q > 0.0)
    }

    /// Picks the acceptable entry of `available` with the highest quality.
    ///
    /// Ties go to the earlier entry in `available`. An empty header picks the first entry.
    #[must_use]
    pub fn negotiate<'a>(&self, available: &[&'a str]) -> Option<&'a str> {
        if self.items.is_empty() {
            return available.first().copied();
        }

        let mut best: Option<(&'a str, f32)> = None;

        for &encoding in available {
            let Some(q) = self.quality_for(encoding) else {
                continue;
            };
            if q <= 0.0 {
                continue;
            }
            match best {
                Some((_, best_q)) if q <= best_q => {}
                _ => best = Some((encoding, q)),
            }
        }

        best.map(|(e, _)| e)
    }
}
13434
13435impl FromRequest for AcceptEncodingHeader {
13436 type Error = std::convert::Infallible;
13437
13438 async fn from_request(_ctx: &RequestContext, req: &mut Request) -> Result<Self, Self::Error> {
13439 let header = req
13440 .headers()
13441 .get("accept-encoding")
13442 .and_then(|v| std::str::from_utf8(v).ok())
13443 .map(Self::parse)
13444 .unwrap_or_default();
13445 Ok(header)
13446 }
13447}
13448
/// A single language-range entry from an `Accept-Language` header.
#[derive(Debug, Clone)]
pub struct AcceptLanguageItem {
    /// The language range as sent by the client, trimmed with case preserved
    /// (e.g. `en-US` or `*`).
    pub language: String,
    /// The client's weight for this range, in `0.0..=1.0` (defaults to `1.0`).
    pub quality: f32,
}

impl AcceptLanguageItem {
    /// Parses one comma-separated element of an `Accept-Language` header, such as `fr`,
    /// `en-GB;q=0.8` or `de; q=0.5`.
    ///
    /// * The language range (text before the first `;`) is trimmed. Its case is preserved.
    /// * Whitespace around `;` and `=` is allowed, and the parameter name `q` is matched
    ///   case-insensitively. Any other parameter is ignored rather than being glued onto
    ///   the range.
    /// * A missing, unparseable or `NaN` weight falls back to `1.0`. Any other weight is
    ///   clamped into `0.0..=1.0`.
    ///
    /// Returns `None` when the language range is empty.
    pub fn parse(s: &str) -> Option<Self> {
        let mut parts = s.split(';');
        // `split` always yields at least one piece, even for an empty input.
        let language = parts.next().unwrap_or("").trim().to_string();
        if language.is_empty() {
            return None;
        }

        let quality = parts
            .filter_map(|param| param.split_once('='))
            .find(|(name, _)| name.trim().eq_ignore_ascii_case("q"))
            .and_then(|(_, value)| value.trim().parse::<f32>().ok())
            // `f32::from_str` accepts "NaN", and clamping NaN would keep it NaN.
            .filter(|q| !q.is_nan())
            .map_or(1.0, |q| q.clamp(0.0, 1.0));

        Some(Self { language, quality })
    }
}

/// A parsed `Accept-Language` header.
///
/// `items` is ordered by descending quality. Entries with equal quality keep the order in
/// which the client listed them. The `Default` value has no items.
#[derive(Debug, Clone, Default)]
pub struct AcceptLanguageHeader {
    /// Parsed entries, highest quality first.
    pub items: Vec<AcceptLanguageItem>,
}

impl AcceptLanguageHeader {
    /// Parses a raw `Accept-Language` value and drops empty entries.
    #[must_use]
    pub fn parse(s: &str) -> Self {
        let mut items: Vec<AcceptLanguageItem> =
            s.split(',').filter_map(AcceptLanguageItem::parse).collect();

        // `sort_by` is stable, so ties keep the client's listing order.
        items.sort_by(|a, b| {
            b.quality
                .partial_cmp(&a.quality)
                .unwrap_or(std::cmp::Ordering::Equal)
        });

        Self { items }
    }

    /// Returns `true` if some entry with a positive weight matches `language`.
    ///
    /// Matching is ASCII case-insensitive and lenient. An entry matches when it:
    /// * equals `language`;
    /// * is `*`;
    /// * is a prefix of `language` (`en` matches `en-US`); or
    /// * has `language` as a prefix (`en-US` matches `en`).
    ///
    /// A header with no entries returns `false`.
    #[must_use]
    pub fn accepts(&self, language: &str) -> bool {
        let lang_lower = language.to_ascii_lowercase();
        self.items.iter().any(|item| {
            if item.quality <= 0.0 {
                return false;
            }
            let item_lower = item.language.to_ascii_lowercase();
            item_lower == lang_lower
                || item_lower == "*"
                || lang_lower.starts_with(&format!("{}-", item_lower))
                || item_lower.starts_with(&format!("{}-", lang_lower))
        })
    }

    /// Picks the entry of `available` that the client weights highest.
    ///
    /// For each candidate, the first matching entry in `items` decides its quality. That
    /// is the highest-weighted match, using the same rules as `accepts`, with zero-weight
    /// entries skipped. The winner is chosen as follows:
    /// * a higher quality wins;
    /// * on equal quality, a candidate whose deciding entry matched exactly beats one
    ///   matched only by prefix or wildcard;
    /// * otherwise the candidate listed first in `available` wins.
    ///
    /// If the header has no entries, the first element of `available` is returned.
    #[must_use]
    pub fn negotiate<'a>(&self, available: &[&'a str]) -> Option<&'a str> {
        if self.items.is_empty() {
            return available.first().copied();
        }

        // (candidate, quality of its deciding entry, whether that entry matched exactly)
        let mut best: Option<(&'a str, f32, bool)> = None;
        for &lang in available {
            let lang_lower = lang.to_ascii_lowercase();
            for item in &self.items {
                if item.quality <= 0.0 {
                    continue;
                }
                let item_lower = item.language.to_ascii_lowercase();

                let (matches, exact) = if item_lower == lang_lower {
                    (true, true)
                } else if item_lower == "*"
                    || lang_lower.starts_with(&format!("{}-", item_lower))
                    || item_lower.starts_with(&format!("{}-", lang_lower))
                {
                    (true, false)
                } else {
                    (false, false)
                };

                if matches {
                    match best {
                        None => best = Some((lang, item.quality, exact)),
                        Some((_, q, e))
                            if item.quality > q
                                || ((item.quality - q).abs() < f32::EPSILON && exact && !e) =>
                        {
                            best = Some((lang, item.quality, exact));
                        }
                        _ => {}
                    }
                    // Only the first (highest-weighted) matching entry counts per candidate.
                    break;
                }
            }
        }

        best.map(|(l, _, _)| l)
    }
}
13580
13581impl FromRequest for AcceptLanguageHeader {
13582 type Error = std::convert::Infallible;
13583
13584 async fn from_request(_ctx: &RequestContext, req: &mut Request) -> Result<Self, Self::Error> {
13585 let header = req
13586 .headers()
13587 .get("accept-language")
13588 .and_then(|v| std::str::from_utf8(v).ok())
13589 .map(Self::parse)
13590 .unwrap_or_default();
13591 Ok(header)
13592 }
13593}
13594
/// Error reported when none of the server's representations satisfies what the client
/// asked for during content negotiation.
#[derive(Debug, Clone)]
pub struct NotAcceptableError {
    /// The values the client requested.
    pub requested: Vec<String>,
    /// The values the server is able to produce.
    pub available: Vec<String>,
}

impl NotAcceptableError {
    /// Builds the error from the requested and the available values.
    #[must_use]
    pub fn new(requested: Vec<String>, available: Vec<String>) -> Self {
        Self { requested, available }
    }
}

impl fmt::Display for NotAcceptableError {
    /// Renders both lists comma-separated inside square brackets.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let wanted = self.requested.join(", ");
        let offered = self.available.join(", ");
        write!(
            f,
            "Not Acceptable: requested [{}], available [{}]",
            wanted, offered
        )
    }
}

impl std::error::Error for NotAcceptableError {}
13627
13628impl IntoResponse for NotAcceptableError {
13629 fn into_response(self) -> Response {
13630 Response::with_status(crate::response::StatusCode::NOT_ACCEPTABLE)
13631 .header("content-type", b"application/json".to_vec())
13632 .body(ResponseBody::Bytes(
13633 serde_json::json!({
13634 "error": "Not Acceptable",
13635 "message": self.to_string(),
13636 "requested": self.requested,
13637 "available": self.available,
13638 })
13639 .to_string()
13640 .into_bytes(),
13641 ))
13642 }
13643}
13644
/// Builder for the value of a `Vary` response header.
///
/// Names are emitted in insertion order, separated by `", "`. HTTP field names are
/// case-insensitive, so duplicates are detected ignoring ASCII case and the spelling added
/// first is kept.
#[derive(Debug, Clone, Default)]
pub struct VaryBuilder {
    headers: Vec<String>,
}

impl VaryBuilder {
    /// Creates an empty builder.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Adds `Accept`. Does nothing if it is already present.
    #[must_use]
    pub fn accept(self) -> Self {
        self.header("Accept")
    }

    /// Adds `Accept-Encoding`. Does nothing if it is already present.
    #[must_use]
    pub fn accept_encoding(self) -> Self {
        self.header("Accept-Encoding")
    }

    /// Adds `Accept-Language`. Does nothing if it is already present.
    #[must_use]
    pub fn accept_language(self) -> Self {
        self.header("Accept-Language")
    }

    /// Adds an arbitrary header name.
    ///
    /// Surrounding whitespace is trimmed. Empty names are ignored, since they would produce
    /// a malformed list such as `Accept, `. A name already present in any ASCII case is
    /// skipped.
    #[must_use]
    pub fn header(mut self, name: impl Into<String>) -> Self {
        let name: String = name.into();
        let name = name.trim();
        let duplicate = self
            .headers
            .iter()
            .any(|existing| existing.eq_ignore_ascii_case(name));
        if !name.is_empty() && !duplicate {
            self.headers.push(name.to_owned());
        }
        self
    }

    /// Returns the comma-separated header value (empty if nothing was added).
    #[must_use]
    pub fn build(&self) -> String {
        self.headers.join(", ")
    }

    /// Returns `true` if no header names have been added.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.headers.is_empty()
    }
}
13707
#[cfg(test)]
mod content_negotiation_tests {
    use super::*;

    /// Approximate equality for `f32` quality values.
    fn approx(actual: f32, expected: f32) -> bool {
        (actual - expected).abs() < f32::EPSILON
    }

    #[test]
    fn media_type_parse_simple() {
        let parsed = MediaType::parse("text/html").unwrap();
        assert_eq!(parsed.typ, "text");
        assert_eq!(parsed.subtype, "html");
        assert!(parsed.params.is_empty());
    }

    #[test]
    fn media_type_parse_with_params() {
        let parsed = MediaType::parse("text/html; charset=utf-8").unwrap();
        assert_eq!(parsed.typ, "text");
        assert_eq!(parsed.subtype, "html");
        assert_eq!(parsed.param("charset"), Some("utf-8"));
    }

    #[test]
    fn media_type_parse_case_insensitive() {
        // Upper-case input is expected to come back lower-cased.
        let parsed = MediaType::parse("TEXT/HTML").unwrap();
        assert_eq!(parsed.typ, "text");
        assert_eq!(parsed.subtype, "html");
    }

    #[test]
    fn media_type_matches_wildcard() {
        let html = MediaType::new("text", "html");
        let text_wildcard = MediaType::new("text", "*");
        let full_wildcard = MediaType::new("*", "*");

        assert!(html.matches(&text_wildcard));
        assert!(html.matches(&full_wildcard));
        assert!(html.matches(&html));
    }

    #[test]
    fn accept_item_parse_with_quality() {
        let entry = AcceptItem::parse("text/html;q=0.9").unwrap();
        assert_eq!(entry.media_type.typ, "text");
        assert_eq!(entry.media_type.subtype, "html");
        assert!(approx(entry.quality, 0.9));
    }

    #[test]
    fn accept_item_parse_default_quality() {
        let entry = AcceptItem::parse("application/json").unwrap();
        assert!(approx(entry.quality, 1.0));
    }

    #[test]
    fn accept_header_parse_multiple() {
        let header = AcceptHeader::parse("text/html, application/json;q=0.9, */*;q=0.1");
        assert_eq!(header.items.len(), 3);
        // The items are expected in exactly this order.
        for (entry, subtype) in header.items.iter().zip(["html", "json", "*"]) {
            assert_eq!(entry.media_type.subtype, subtype);
        }
    }

    #[test]
    fn accept_header_prefers() {
        let header = AcceptHeader::parse("text/html, application/json;q=0.9");
        assert!(header.prefers("text/html"));
        assert!(!header.prefers("application/json"));
    }

    #[test]
    fn accept_header_accepts() {
        let header = AcceptHeader::parse("text/html, application/json;q=0.9");
        for listed in ["text/html", "application/json"] {
            assert!(header.accepts(listed));
        }
        assert!(!header.accepts("image/png"));
    }

    #[test]
    fn accept_header_negotiate() {
        let header = AcceptHeader::parse("text/html, application/json;q=0.9");
        let offered = ["application/json", "text/html", "text/plain"];
        assert_eq!(header.negotiate(&offered), Some("text/html"));
    }

    #[test]
    fn accept_header_negotiate_returns_best_available() {
        let header = AcceptHeader::parse("application/xml, application/json;q=0.9");
        let offered = ["application/json", "text/plain"];
        assert_eq!(header.negotiate(&offered), Some("application/json"));
    }

    #[test]
    fn accept_header_quality_of() {
        let header = AcceptHeader::parse("text/html, application/json;q=0.9, */*;q=0.1");
        let expectations = [
            ("text/html", 1.0),
            ("application/json", 0.9),
            ("image/png", 0.1),
        ];
        for (media, weight) in expectations {
            assert!(approx(header.quality_of(media), weight));
        }
    }

    #[test]
    // An empty header must report exactly full quality, so the exact float comparison is
    // intentional.
    #[allow(clippy::float_cmp)]
    fn accept_header_empty_accepts_all() {
        let header = AcceptHeader::parse("");
        assert!(header.accepts("anything/here"));
        assert_eq!(header.quality_of("text/html"), 1.0);
    }

    #[test]
    fn accept_encoding_parse() {
        let header = AcceptEncodingHeader::parse("gzip, deflate, br;q=0.8");
        assert_eq!(header.items.len(), 3);
        for coding in ["gzip", "br"] {
            assert!(header.accepts(coding));
        }
    }

    #[test]
    fn accept_encoding_negotiate() {
        let header = AcceptEncodingHeader::parse("gzip;q=0.9, br");
        let offered = ["gzip", "br", "identity"];
        assert_eq!(header.negotiate(&offered), Some("br"));
    }

    #[test]
    fn accept_language_parse() {
        let header = AcceptLanguageHeader::parse("en-US, en;q=0.9, fr;q=0.8");
        assert_eq!(header.items.len(), 3);
        for tag in ["en-US", "en", "fr"] {
            assert!(header.accepts(tag));
        }
    }

    #[test]
    fn accept_language_negotiate() {
        let header = AcceptLanguageHeader::parse("fr, en;q=0.9");
        let offered = ["en", "de", "fr"];
        assert_eq!(header.negotiate(&offered), Some("fr"));
    }

    #[test]
    fn accept_language_prefix_match() {
        let header = AcceptLanguageHeader::parse("en");
        for regional in ["en-US", "en-GB"] {
            assert!(header.accepts(regional));
        }
    }

    #[test]
    fn vary_builder() {
        let value = VaryBuilder::new().accept().accept_encoding().build();
        assert_eq!(value, "Accept, Accept-Encoding");
    }

    #[test]
    fn vary_builder_no_duplicates() {
        let value = VaryBuilder::new().accept().accept().build();
        assert_eq!(value, "Accept");
    }

    #[test]
    fn not_acceptable_error_response() {
        let requested = vec!["image/png".to_string()];
        let available = vec!["application/json".to_string(), "text/html".to_string()];
        let response = NotAcceptableError::new(requested, available).into_response();
        assert_eq!(
            response.status(),
            crate::response::StatusCode::NOT_ACCEPTABLE
        );
    }
}
13877
#[cfg(test)]
mod header_tests {
    use super::*;
    use crate::request::Method;

    /// Builds a request context for driving extractors in tests.
    fn test_context() -> RequestContext {
        RequestContext::new(asupersync::Cx::for_testing(), 12345)
    }

    /// Builds a `GET /test` request that carries a single header.
    fn request_with(name: &'static str, value: &[u8]) -> Request {
        let mut req = Request::new(Method::Get, "/test");
        req.headers_mut().insert(name, value.to_vec());
        req
    }

    #[test]
    fn snake_to_header_case_simple() {
        let cases = [
            ("authorization", "Authorization"),
            ("content_type", "Content-Type"),
            ("x_request_id", "X-Request-Id"),
            ("accept", "Accept"),
        ];
        for (snake, header) in cases {
            assert_eq!(snake_to_header_case(snake), header);
        }
    }

    #[test]
    fn snake_to_header_case_edge_cases() {
        for (snake, header) in [("", ""), ("a", "A"), ("a_b_c", "A-B-C")] {
            assert_eq!(snake_to_header_case(snake), header);
        }
    }

    #[test]
    fn header_deref() {
        let wrapped = Header::new("test", "value".to_string());
        assert_eq!(*wrapped, "value");
    }

    #[test]
    fn header_into_inner() {
        let wrapped = Header::new("test", 42i32);
        assert_eq!(wrapped.into_inner(), 42);
    }

    #[test]
    fn from_header_value_string() {
        let parsed = String::from_header_value("test value");
        assert_eq!(parsed.unwrap(), "test value");
    }

    #[test]
    fn from_header_value_i32() {
        for (raw, expected) in [("42", 42), ("-1", -1)] {
            assert_eq!(i32::from_header_value(raw).unwrap(), expected);
        }
        assert!(i32::from_header_value("abc").is_err());
    }

    #[test]
    fn from_header_value_bool() {
        for truthy in ["true", "1", "yes"] {
            assert!(bool::from_header_value(truthy).unwrap());
        }
        for falsy in ["false", "0", "no"] {
            assert!(!bool::from_header_value(falsy).unwrap());
        }
        assert!(bool::from_header_value("maybe").is_err());
    }

    #[test]
    fn named_header_extract_success() {
        let ctx = test_context();
        let mut req = request_with("authorization", b"Bearer token123");

        let extracted = futures_executor::block_on(
            NamedHeader::<String, Authorization>::from_request(&ctx, &mut req),
        )
        .unwrap();
        assert_eq!(extracted.value, "Bearer token123");
    }

    #[test]
    fn named_header_extract_i32() {
        let ctx = test_context();
        let mut req = request_with("x-request-id", b"12345");

        let extracted = futures_executor::block_on(
            NamedHeader::<i32, XRequestId>::from_request(&ctx, &mut req),
        )
        .unwrap();
        assert_eq!(extracted.value, 12345);
    }

    #[test]
    fn named_header_missing() {
        let ctx = test_context();
        // No authorization header is set on this request.
        let mut req = Request::new(Method::Get, "/test");

        let outcome = futures_executor::block_on(
            NamedHeader::<String, Authorization>::from_request(&ctx, &mut req),
        );
        assert!(matches!(
            outcome,
            Err(HeaderExtractError::MissingHeader { .. })
        ));
    }

    #[test]
    fn named_header_parse_error() {
        let ctx = test_context();
        let mut req = request_with("x-request-id", b"not-a-number");

        let outcome = futures_executor::block_on(
            NamedHeader::<i32, XRequestId>::from_request(&ctx, &mut req),
        );
        assert!(matches!(outcome, Err(HeaderExtractError::ParseError { .. })));
    }

    #[test]
    fn header_error_display() {
        let missing = HeaderExtractError::MissingHeader {
            name: "Authorization".to_string(),
        };
        assert!(missing.to_string().contains("Authorization"));

        let bad_value = HeaderExtractError::ParseError {
            name: "X-Count".to_string(),
            value: "abc".to_string(),
            expected: "i32",
            message: "invalid digit".to_string(),
        };
        let rendered = bad_value.to_string();
        assert!(rendered.contains("X-Count"));
        assert!(rendered.contains("abc"));
    }

    #[test]
    fn optional_header_some() {
        let ctx = test_context();
        let mut req = request_with("authorization", b"Bearer token");

        let maybe = futures_executor::block_on(
            Option::<NamedHeader<String, Authorization>>::from_request(&ctx, &mut req),
        )
        .unwrap();
        assert!(maybe.is_some());
        assert_eq!(maybe.unwrap().value, "Bearer token");
    }

    #[test]
    fn optional_header_none() {
        let ctx = test_context();
        // No authorization header is set on this request.
        let mut req = Request::new(Method::Get, "/test");

        let maybe = futures_executor::block_on(
            Option::<NamedHeader<String, Authorization>>::from_request(&ctx, &mut req),
        )
        .unwrap();
        assert!(maybe.is_none());
    }
}
14039
#[cfg(test)]
mod oauth2_tests {
    use super::*;
    use crate::request::Method;
    use crate::response::IntoResponse;

    /// Builds a request context for driving extractors in tests.
    fn test_context() -> RequestContext {
        let cx = asupersync::Cx::for_testing();
        RequestContext::new(cx, 12345)
    }

    #[test]
    fn oauth2_extract_valid_bearer_token() {
        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/api/protected");
        req.headers_mut()
            .insert("authorization", b"Bearer mytoken123".to_vec());

        let result = futures_executor::block_on(OAuth2PasswordBearer::from_request(&ctx, &mut req));
        let bearer = result.unwrap();
        assert_eq!(bearer.token(), "mytoken123");
        // The bearer also derefs to the token.
        assert_eq!(&*bearer, "mytoken123");
    }

    #[test]
    fn oauth2_extract_bearer_lowercase() {
        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/api/protected");
        req.headers_mut()
            .insert("authorization", b"bearer lowercase_token".to_vec());

        let result = futures_executor::block_on(OAuth2PasswordBearer::from_request(&ctx, &mut req));
        let bearer = result.unwrap();
        assert_eq!(bearer.token(), "lowercase_token");
    }

    #[test]
    fn oauth2_missing_header() {
        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/api/protected");
        let result = futures_executor::block_on(OAuth2PasswordBearer::from_request(&ctx, &mut req));
        let err = result.unwrap_err();
        assert_eq!(err.kind, OAuth2BearerErrorKind::MissingHeader);
    }

    #[test]
    fn oauth2_wrong_scheme() {
        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/api/protected");
        req.headers_mut()
            .insert("authorization", b"Basic dXNlcjpwYXNz".to_vec());

        let result = futures_executor::block_on(OAuth2PasswordBearer::from_request(&ctx, &mut req));
        let err = result.unwrap_err();
        assert_eq!(err.kind, OAuth2BearerErrorKind::InvalidScheme);
    }

    #[test]
    fn oauth2_empty_token() {
        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/api/protected");
        req.headers_mut()
            .insert("authorization", b"Bearer ".to_vec());

        let result = futures_executor::block_on(OAuth2PasswordBearer::from_request(&ctx, &mut req));
        let err = result.unwrap_err();
        assert_eq!(err.kind, OAuth2BearerErrorKind::EmptyToken);
    }

    #[test]
    fn oauth2_whitespace_only_token() {
        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/api/protected");
        // Several spaces after the scheme. A single space would just repeat
        // `oauth2_empty_token`; this checks that a whitespace-only token is rejected.
        req.headers_mut()
            .insert("authorization", b"Bearer     ".to_vec());

        let result = futures_executor::block_on(OAuth2PasswordBearer::from_request(&ctx, &mut req));
        let err = result.unwrap_err();
        assert_eq!(err.kind, OAuth2BearerErrorKind::EmptyToken);
    }

    #[test]
    fn oauth2_token_with_spaces_trimmed() {
        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/api/protected");
        req.headers_mut()
            .insert("authorization", b"Bearer spaced_token ".to_vec());

        let result = futures_executor::block_on(OAuth2PasswordBearer::from_request(&ctx, &mut req));
        let bearer = result.unwrap();
        assert_eq!(bearer.token(), "spaced_token");
    }

    #[test]
    fn oauth2_optional_extraction_some() {
        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/api/maybe-protected");
        req.headers_mut()
            .insert("authorization", b"Bearer optional_token".to_vec());

        let result = futures_executor::block_on(Option::<OAuth2PasswordBearer>::from_request(
            &ctx, &mut req,
        ));
        let opt = result.unwrap();
        assert!(opt.is_some());
        assert_eq!(opt.unwrap().token(), "optional_token");
    }

    #[test]
    fn oauth2_optional_extraction_none() {
        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/api/maybe-protected");
        let result = futures_executor::block_on(Option::<OAuth2PasswordBearer>::from_request(
            &ctx, &mut req,
        ));
        let opt = result.unwrap();
        assert!(opt.is_none());
    }

    #[test]
    fn oauth2_error_response_401() {
        let err = OAuth2BearerError::missing_header();
        let response = err.into_response();
        assert_eq!(response.status().as_u16(), 401);
    }

    #[test]
    fn oauth2_error_response_has_www_authenticate() {
        let err = OAuth2BearerError::missing_header();
        let response = err.into_response();

        let www_auth = response
            .headers()
            .iter()
            .find(|(name, _)| name == "www-authenticate")
            .map(|(_, value)| String::from_utf8_lossy(value).to_string());

        assert_eq!(www_auth, Some("Bearer".to_string()));
    }

    #[test]
    fn oauth2_error_display() {
        assert!(
            OAuth2BearerError::missing_header()
                .to_string()
                .contains("Missing")
        );
        assert!(
            OAuth2BearerError::invalid_scheme()
                .to_string()
                .contains("Bearer")
        );
        assert!(
            OAuth2BearerError::empty_token()
                .to_string()
                .contains("empty")
        );
    }

    #[test]
    fn oauth2_config_builder() {
        let config = OAuth2PasswordBearerConfig::new("/auth/token")
            .with_refresh_url("/auth/refresh")
            .with_scope("read", "Read access")
            .with_scope("write", "Write access")
            .with_scheme_name("MyOAuth2")
            .with_description("Custom OAuth2 scheme")
            .with_auto_error(false);

        assert_eq!(config.token_url, "/auth/token");
        assert_eq!(config.refresh_url, Some("/auth/refresh".to_string()));
        assert_eq!(config.scopes.len(), 2);
        assert_eq!(config.scopes.get("read"), Some(&"Read access".to_string()));
        assert_eq!(config.scheme_name, Some("MyOAuth2".to_string()));
        assert!(!config.auto_error);
    }

    #[test]
    fn oauth2_password_bearer_accessors() {
        let bearer = OAuth2PasswordBearer::new("test_token");
        assert_eq!(bearer.token(), "test_token");
        assert_eq!(bearer.into_token(), "test_token");
    }

    #[test]
    fn oauth2_error_response_json_body_format() {
        let err = OAuth2BearerError::missing_header();
        let response = err.into_response();

        let body = match response.body_ref() {
            crate::response::ResponseBody::Bytes(b) => String::from_utf8_lossy(b).to_string(),
            _ => panic!("Expected Bytes body"),
        };

        let json: serde_json::Value =
            serde_json::from_str(&body).expect("Body should be valid JSON");
        assert!(
            json.get("detail").is_some(),
            "Response should have 'detail' field"
        );
        assert_eq!(json["detail"], "Not authenticated");
    }

    #[test]
    fn oauth2_error_invalid_scheme_json_body() {
        let err = OAuth2BearerError::invalid_scheme();
        let response = err.into_response();

        let body = match response.body_ref() {
            crate::response::ResponseBody::Bytes(b) => String::from_utf8_lossy(b).to_string(),
            _ => panic!("Expected Bytes body"),
        };

        let json: serde_json::Value =
            serde_json::from_str(&body).expect("Body should be valid JSON");
        assert_eq!(json["detail"], "Invalid authentication credentials");
    }

    #[test]
    fn oauth2_error_empty_token_json_body() {
        let err = OAuth2BearerError::empty_token();
        let response = err.into_response();

        let body = match response.body_ref() {
            crate::response::ResponseBody::Bytes(b) => String::from_utf8_lossy(b).to_string(),
            _ => panic!("Expected Bytes body"),
        };

        let json: serde_json::Value =
            serde_json::from_str(&body).expect("Body should be valid JSON");
        assert_eq!(json["detail"], "Invalid authentication credentials");
    }

    #[test]
    fn oauth2_error_response_content_type_json() {
        let err = OAuth2BearerError::missing_header();
        let response = err.into_response();

        let content_type = response
            .headers()
            .iter()
            .find(|(name, _)| name == "content-type")
            .map(|(_, value)| String::from_utf8_lossy(value).to_string());

        assert_eq!(content_type, Some("application/json".to_string()));
    }

    #[test]
    fn oauth2_extract_token_with_special_characters() {
        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/api/protected");
        // A JWT-shaped token containing dots, dashes and underscores.
        req.headers_mut()
            .insert("authorization", b"Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.dozjgNryP4J3jVmNHl0w5N_XgL0n3I9PlFUP0THsR8U".to_vec());

        let result = futures_executor::block_on(OAuth2PasswordBearer::from_request(&ctx, &mut req));
        let bearer = result.unwrap();
        assert!(bearer.token().contains("eyJ"));
        assert!(bearer.token().contains('.'));
    }

    #[test]
    fn oauth2_extract_token_with_unicode() {
        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/api/protected");
        // Non-ASCII but valid UTF-8 token bytes.
        req.headers_mut().insert(
            "authorization",
            "Bearer tökën_with_ünïcödë".as_bytes().to_vec(),
        );

        let result = futures_executor::block_on(OAuth2PasswordBearer::from_request(&ctx, &mut req));
        let bearer = result.unwrap();
        assert_eq!(bearer.token(), "tökën_with_ünïcödë");
    }

    #[test]
    fn oauth2_invalid_utf8_in_token() {
        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/api/protected");
        // The bytes spell "Bearer " followed by 0xFF 0xFE, which is not valid UTF-8.
        req.headers_mut().insert(
            "authorization",
            vec![66, 101, 97, 114, 101, 114, 32, 0xFF, 0xFE],
        );

        let result = futures_executor::block_on(OAuth2PasswordBearer::from_request(&ctx, &mut req));
        assert!(result.is_err());
        assert_eq!(
            result.unwrap_err().kind,
            OAuth2BearerErrorKind::InvalidScheme
        );
    }

    #[test]
    fn oauth2_only_bearer_prefix_no_space() {
        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/api/protected");
        // No separator between the scheme and the token.
        req.headers_mut()
            .insert("authorization", b"Bearertoken".to_vec());

        let result = futures_executor::block_on(OAuth2PasswordBearer::from_request(&ctx, &mut req));
        let err = result.unwrap_err();
        assert_eq!(err.kind, OAuth2BearerErrorKind::InvalidScheme);
    }

    #[test]
    fn oauth2_mixed_case_bearer() {
        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/api/protected");
        req.headers_mut()
            .insert("authorization", b"BEARER uppercase_token".to_vec());

        let result = futures_executor::block_on(OAuth2PasswordBearer::from_request(&ctx, &mut req));
        let err = result.unwrap_err();
        // Only "Bearer" and "bearer" are accepted as the scheme; "BEARER" is rejected.
        assert_eq!(err.kind, OAuth2BearerErrorKind::InvalidScheme);
    }

    #[test]
    fn oauth2_extract_very_long_token() {
        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/api/protected");
        let long_token = "x".repeat(4096);
        req.headers_mut()
            .insert("authorization", format!("Bearer {long_token}").into_bytes());

        let result = futures_executor::block_on(OAuth2PasswordBearer::from_request(&ctx, &mut req));
        let bearer = result.unwrap();
        assert_eq!(bearer.token().len(), 4096);
    }

    #[test]
    fn oauth2_config_default_values() {
        let config = OAuth2PasswordBearerConfig::default();

        assert_eq!(config.token_url, "/token");
        assert!(config.refresh_url.is_none());
        assert!(config.scopes.is_empty());
        assert!(config.scheme_name.is_none());
        assert!(config.description.is_none());
        assert!(config.auto_error);
    }

    #[test]
    fn oauth2_error_kind_equality() {
        assert_eq!(
            OAuth2BearerErrorKind::MissingHeader,
            OAuth2BearerErrorKind::MissingHeader
        );
        assert_eq!(
            OAuth2BearerErrorKind::InvalidScheme,
            OAuth2BearerErrorKind::InvalidScheme
        );
        assert_eq!(
            OAuth2BearerErrorKind::EmptyToken,
            OAuth2BearerErrorKind::EmptyToken
        );
        assert_ne!(
            OAuth2BearerErrorKind::MissingHeader,
            OAuth2BearerErrorKind::InvalidScheme
        );
    }

    #[test]
    fn oauth2_error_debug_format() {
        let err = OAuth2BearerError::missing_header();
        let debug_str = format!("{:?}", err);
        assert!(debug_str.contains("MissingHeader"));
    }

    #[test]
    fn oauth2_bearer_clone() {
        let bearer = OAuth2PasswordBearer::new("cloneable_token");
        let cloned = bearer.clone();
        assert_eq!(bearer.token(), cloned.token());
    }

    #[test]
    fn oauth2_config_clone() {
        let config =
            OAuth2PasswordBearerConfig::new("/auth/token").with_scope("admin", "Admin access");
        let cloned = config.clone();
        assert_eq!(config.token_url, cloned.token_url);
        assert_eq!(config.scopes.len(), cloned.scopes.len());
    }

    #[test]
    fn oauth2_all_error_responses_are_401() {
        let errors = [
            OAuth2BearerError::missing_header(),
            OAuth2BearerError::invalid_scheme(),
            OAuth2BearerError::empty_token(),
        ];

        for err in errors {
            let response = err.into_response();
            assert_eq!(
                response.status().as_u16(),
                401,
                "All OAuth2 errors should be 401"
            );
        }
    }

    #[test]
    fn oauth2_all_error_responses_have_www_authenticate() {
        let errors = [
            OAuth2BearerError::missing_header(),
            OAuth2BearerError::invalid_scheme(),
            OAuth2BearerError::empty_token(),
        ];

        for err in errors {
            let response = err.into_response();
            let has_www_auth = response
                .headers()
                .iter()
                .any(|(name, value)| name == "www-authenticate" && value == b"Bearer");
            assert!(
                has_www_auth,
                "All OAuth2 errors should have WWW-Authenticate: Bearer"
            );
        }
    }
}
14483
#[cfg(test)]
mod bearer_token_tests {
    // Tests for the `BearerToken` extractor and the HTTP responses built from `BearerTokenError`.
    use super::*;
    use crate::request::Method;
    use crate::response::IntoResponse;

    /// Builds a `RequestContext` from the testing `Cx` with the fixed numeric value 12345.
    fn test_context() -> RequestContext {
        let cx = asupersync::Cx::for_testing();
        RequestContext::new(cx, 12345)
    }

    #[test]
    fn bearer_token_extract_valid_token() {
        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/api/protected");
        req.headers_mut()
            .insert("authorization", b"Bearer mytoken123".to_vec());

        let result = futures_executor::block_on(BearerToken::from_request(&ctx, &mut req));
        let token = result.unwrap();
        assert_eq!(token.token(), "mytoken123");
        // Deref and AsRef must expose the same token text as `token()`.
        assert_eq!(&*token, "mytoken123");
        assert_eq!(token.as_ref(), "mytoken123");
    }

    #[test]
    fn bearer_token_extract_lowercase_bearer() {
        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/api/protected");
        req.headers_mut()
            .insert("authorization", b"bearer lowercase_token".to_vec());

        let result = futures_executor::block_on(BearerToken::from_request(&ctx, &mut req));
        let token = result.unwrap();
        assert_eq!(token.token(), "lowercase_token");
    }

    #[test]
    fn bearer_token_missing_header() {
        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/api/protected");

        let result = futures_executor::block_on(BearerToken::from_request(&ctx, &mut req));
        let err = result.unwrap_err();
        assert_eq!(err, BearerTokenError::MissingHeader);
    }

    #[test]
    fn bearer_token_wrong_scheme() {
        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/api/protected");
        req.headers_mut()
            .insert("authorization", b"Basic dXNlcjpwYXNz".to_vec());

        let result = futures_executor::block_on(BearerToken::from_request(&ctx, &mut req));
        let err = result.unwrap_err();
        assert_eq!(err, BearerTokenError::InvalidScheme);
    }

    #[test]
    fn bearer_token_empty_token() {
        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/api/protected");
        // Exactly one separator space and nothing after it.
        req.headers_mut()
            .insert("authorization", b"Bearer ".to_vec());

        let result = futures_executor::block_on(BearerToken::from_request(&ctx, &mut req));
        let err = result.unwrap_err();
        assert_eq!(err, BearerTokenError::EmptyToken);
    }

    #[test]
    fn bearer_token_whitespace_only_token() {
        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/api/protected");
        // Several spaces after the scheme, so this differs from `bearer_token_empty_token`
        // (which sends a single separator space) and exercises whitespace-only tokens.
        req.headers_mut()
            .insert("authorization", b"Bearer    ".to_vec());

        let result = futures_executor::block_on(BearerToken::from_request(&ctx, &mut req));
        let err = result.unwrap_err();
        assert_eq!(err, BearerTokenError::EmptyToken);
    }

    #[test]
    fn bearer_token_with_spaces_trimmed() {
        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/api/protected");
        req.headers_mut()
            .insert("authorization", b"Bearer spaced_token ".to_vec());

        let result = futures_executor::block_on(BearerToken::from_request(&ctx, &mut req));
        let token = result.unwrap();
        assert_eq!(token.token(), "spaced_token");
    }

    #[test]
    fn bearer_token_optional_some() {
        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/api/protected");
        req.headers_mut()
            .insert("authorization", b"Bearer optional_token".to_vec());

        let result =
            futures_executor::block_on(Option::<BearerToken>::from_request(&ctx, &mut req));
        let maybe_token = result.unwrap();
        assert!(maybe_token.is_some());
        assert_eq!(maybe_token.unwrap().token(), "optional_token");
    }

    #[test]
    fn bearer_token_optional_none() {
        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/api/protected");

        // `Option<T>` converts the inner extractor's error into `None` instead of failing.
        let result =
            futures_executor::block_on(Option::<BearerToken>::from_request(&ctx, &mut req));
        let maybe_token = result.unwrap();
        assert!(maybe_token.is_none());
    }

    #[test]
    fn bearer_token_error_response_401() {
        let err = BearerTokenError::missing_header();
        let response = err.into_response();
        assert_eq!(response.status().as_u16(), 401);
    }

    #[test]
    fn bearer_token_error_has_www_authenticate() {
        let err = BearerTokenError::missing_header();
        let response = err.into_response();

        let has_www_auth = response
            .headers()
            .iter()
            .any(|(name, value)| name == "www-authenticate" && value == b"Bearer");
        assert!(has_www_auth);
    }

    #[test]
    fn bearer_token_error_display() {
        assert_eq!(
            BearerTokenError::missing_header().to_string(),
            "Missing Authorization header"
        );
        assert_eq!(
            BearerTokenError::invalid_scheme().to_string(),
            "Authorization header must use Bearer scheme"
        );
        assert_eq!(
            BearerTokenError::empty_token().to_string(),
            "Bearer token is empty"
        );
    }

    #[test]
    fn bearer_token_error_detail() {
        assert_eq!(
            BearerTokenError::MissingHeader.detail(),
            "Not authenticated"
        );
        assert_eq!(
            BearerTokenError::InvalidScheme.detail(),
            "Invalid authentication credentials"
        );
        assert_eq!(
            BearerTokenError::EmptyToken.detail(),
            "Invalid authentication credentials"
        );
    }

    #[test]
    fn bearer_token_new_and_accessors() {
        let token = BearerToken::new("test_token");
        assert_eq!(token.token(), "test_token");
        assert_eq!(token.clone().into_token(), "test_token");
    }

    #[test]
    fn bearer_token_error_response_json_body() {
        let err = BearerTokenError::missing_header();
        let response = err.into_response();

        let body_str = match response.body_ref() {
            crate::response::ResponseBody::Bytes(b) => String::from_utf8_lossy(b).to_string(),
            _ => panic!("Expected Bytes body"),
        };
        let body: serde_json::Value = serde_json::from_str(&body_str).unwrap();

        assert_eq!(body["detail"], "Not authenticated");
    }

    #[test]
    fn bearer_token_error_content_type_json() {
        let err = BearerTokenError::missing_header();
        let response = err.into_response();

        let has_json_content_type = response
            .headers()
            .iter()
            .any(|(name, value)| name == "content-type" && value == b"application/json");
        assert!(has_json_content_type);
    }

    #[test]
    fn bearer_token_special_characters() {
        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/api/protected");
        let special_token = "abc123!@#$%^&*()_+-=[]{}|;':\",./<>?";
        req.headers_mut().insert(
            "authorization",
            format!("Bearer {special_token}").into_bytes(),
        );

        let result = futures_executor::block_on(BearerToken::from_request(&ctx, &mut req));
        let token = result.unwrap();
        assert_eq!(token.token(), special_token);
    }

    #[test]
    fn bearer_token_very_long_token() {
        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/api/protected");
        let long_token = "a".repeat(10000);
        req.headers_mut().insert(
            "authorization",
            format!("Bearer {long_token}").into_bytes(),
        );

        let result = futures_executor::block_on(BearerToken::from_request(&ctx, &mut req));
        let token = result.unwrap();
        assert_eq!(token.token(), long_token);
    }

    #[test]
    fn bearer_token_invalid_utf8() {
        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/api/protected");
        // A valid "Bearer " prefix followed by bytes that are not valid UTF-8.
        let mut header = b"Bearer ".to_vec();
        header.extend_from_slice(&[0xFF, 0xFE]);
        req.headers_mut().insert("authorization", header);

        // A non-UTF-8 header value is reported as an invalid scheme.
        let result = futures_executor::block_on(BearerToken::from_request(&ctx, &mut req));
        let err = result.unwrap_err();
        assert_eq!(err, BearerTokenError::InvalidScheme);
    }

    #[test]
    fn bearer_token_only_bearer_no_space() {
        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/api/protected");
        // Scheme name with no separator space at all.
        req.headers_mut()
            .insert("authorization", b"Bearer".to_vec());

        let result = futures_executor::block_on(BearerToken::from_request(&ctx, &mut req));
        let err = result.unwrap_err();
        assert_eq!(err, BearerTokenError::InvalidScheme);
    }

    #[test]
    fn bearer_token_mixed_case_bearer() {
        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/api/protected");
        // NOTE(review): RFC 7235 §2.1 says auth-scheme names are case-insensitive.
        // This test pins the current stricter behavior ("Bearer"/"bearer" only).
        req.headers_mut()
            .insert("authorization", b"BEARER token".to_vec());

        let result = futures_executor::block_on(BearerToken::from_request(&ctx, &mut req));
        let err = result.unwrap_err();
        assert_eq!(err, BearerTokenError::InvalidScheme);
    }

    #[test]
    fn bearer_token_all_errors_are_401() {
        let errors = [
            BearerTokenError::missing_header(),
            BearerTokenError::invalid_scheme(),
            BearerTokenError::empty_token(),
        ];

        for err in errors {
            let response = err.into_response();
            assert_eq!(
                response.status().as_u16(),
                401,
                "All BearerToken errors should be 401"
            );
        }
    }

    #[test]
    fn bearer_token_all_errors_have_www_authenticate() {
        let errors = [
            BearerTokenError::missing_header(),
            BearerTokenError::invalid_scheme(),
            BearerTokenError::empty_token(),
        ];

        for err in errors {
            let response = err.into_response();
            let has_www_auth = response
                .headers()
                .iter()
                .any(|(name, value)| name == "www-authenticate" && value == b"Bearer");
            assert!(
                has_www_auth,
                "All BearerToken errors should have WWW-Authenticate: Bearer"
            );
        }
    }

    #[test]
    fn bearer_token_equality() {
        let token1 = BearerToken::new("same_token");
        let token2 = BearerToken::new("same_token");
        let token3 = BearerToken::new("different_token");

        assert_eq!(token1, token2);
        assert_ne!(token1, token3);
    }

    #[test]
    fn bearer_token_error_equality() {
        assert_eq!(
            BearerTokenError::MissingHeader,
            BearerTokenError::MissingHeader
        );
        assert_eq!(
            BearerTokenError::InvalidScheme,
            BearerTokenError::InvalidScheme
        );
        assert_eq!(BearerTokenError::EmptyToken, BearerTokenError::EmptyToken);
        assert_ne!(
            BearerTokenError::MissingHeader,
            BearerTokenError::InvalidScheme
        );
    }

    #[test]
    fn bearer_token_debug() {
        let token = BearerToken::new("debug_token");
        let debug_str = format!("{token:?}");
        assert!(debug_str.contains("debug_token"));
    }

    #[test]
    fn bearer_token_clone() {
        let token = BearerToken::new("cloneable");
        let cloned = token.clone();
        assert_eq!(token, cloned);
    }
}
14841
#[cfg(test)]
mod api_key_header_tests {
    // Tests for the `ApiKeyHeader` extractor, `ApiKeyHeaderConfig`, and `ApiKeyHeaderError`.
    use super::*;
    use crate::request::Method;
    use crate::response::IntoResponse;

    /// Builds a `RequestContext` from the testing `Cx` with the fixed numeric value 54321.
    fn test_context() -> RequestContext {
        let cx = asupersync::Cx::for_testing();
        RequestContext::new(cx, 54321)
    }

    /// Runs the `ApiKeyHeader` extractor on `req` with a fresh test context.
    fn extract(req: &mut Request) -> Result<ApiKeyHeader, ApiKeyHeaderError> {
        let ctx = test_context();
        futures_executor::block_on(ApiKeyHeader::from_request(&ctx, req))
    }

    #[test]
    fn api_key_header_extraction_default() {
        let mut req = Request::new(Method::Get, "/api/protected");
        req.headers_mut()
            .insert("x-api-key", b"test_api_key_123".to_vec());

        let api_key = extract(&mut req).unwrap();
        assert_eq!(api_key.key(), "test_api_key_123");
        assert_eq!(api_key.header_name(), "x-api-key");
    }

    #[test]
    fn api_key_header_missing() {
        let mut req = Request::new(Method::Get, "/api/protected");

        // `unwrap_err` already fails the test on `Ok`, so no separate `is_err` check is needed.
        let err = extract(&mut req).unwrap_err();
        assert!(matches!(err, ApiKeyHeaderError::MissingHeader { .. }));
    }

    #[test]
    fn api_key_header_empty() {
        let mut req = Request::new(Method::Get, "/api/protected");
        req.headers_mut().insert("x-api-key", b"".to_vec());

        let err = extract(&mut req).unwrap_err();
        assert!(matches!(err, ApiKeyHeaderError::EmptyKey { .. }));
    }

    #[test]
    fn api_key_header_whitespace_only() {
        let mut req = Request::new(Method::Get, "/api/protected");
        req.headers_mut().insert("x-api-key", b" ".to_vec());

        let err = extract(&mut req).unwrap_err();
        assert!(matches!(err, ApiKeyHeaderError::EmptyKey { .. }));
    }

    #[test]
    fn api_key_header_trims_whitespace() {
        let mut req = Request::new(Method::Get, "/api/protected");
        req.headers_mut()
            .insert("x-api-key", b" my_key_123 ".to_vec());

        let api_key = extract(&mut req).unwrap();
        assert_eq!(api_key.key(), "my_key_123");
    }

    #[test]
    fn api_key_header_custom_header_name() {
        let mut req = Request::new(Method::Get, "/api/protected");
        req.headers_mut()
            .insert("authorization", b"custom_key".to_vec());
        // A config stored as a request extension overrides the default header name.
        req.insert_extension(ApiKeyHeaderConfig::new().header_name("authorization"));

        let api_key = extract(&mut req).unwrap();
        assert_eq!(api_key.key(), "custom_key");
        assert_eq!(api_key.header_name(), "authorization");
    }

    #[test]
    fn api_key_header_invalid_utf8() {
        let mut req = Request::new(Method::Get, "/api/protected");
        // 0xFF / 0xFE can never appear in UTF-8.
        req.headers_mut()
            .insert("x-api-key", vec![0xFF, 0xFE, 0x00, 0x01]);

        let err = extract(&mut req).unwrap_err();
        assert!(matches!(err, ApiKeyHeaderError::InvalidUtf8 { .. }));
    }

    #[test]
    fn api_key_header_error_response_401() {
        let err = ApiKeyHeaderError::missing_header("x-api-key");
        let response = err.into_response();
        assert_eq!(response.status().as_u16(), 401);
    }

    #[test]
    fn api_key_header_error_response_json() {
        let err = ApiKeyHeaderError::missing_header("x-api-key");
        let response = err.into_response();

        let has_json_content_type = response
            .headers()
            .iter()
            .any(|(name, value)| name == "content-type" && value == b"application/json");
        assert!(has_json_content_type);
    }

    #[test]
    fn api_key_header_secure_compare() {
        let api_key = ApiKeyHeader::new("secret_key_123");

        assert!(api_key.secure_eq("secret_key_123"));
        // Same length, differs only in the last character.
        assert!(!api_key.secure_eq("secret_key_124"));
        // Different length.
        assert!(!api_key.secure_eq("wrong"));

        assert!(api_key.secure_eq_bytes(b"secret_key_123"));
        assert!(!api_key.secure_eq_bytes(b"secret_key_124"));
    }

    #[test]
    fn api_key_header_deref_and_as_ref() {
        let api_key = ApiKeyHeader::new("deref_test");

        // Deref coercion to `&str`.
        let s: &str = &api_key;
        assert_eq!(s, "deref_test");

        // Explicit `AsRef<str>`.
        let s: &str = api_key.as_ref();
        assert_eq!(s, "deref_test");
    }

    #[test]
    fn api_key_header_config_defaults() {
        let config = ApiKeyHeaderConfig::default();
        assert_eq!(config.get_header_name(), DEFAULT_API_KEY_HEADER);
    }

    #[test]
    fn api_key_header_error_display() {
        let err = ApiKeyHeaderError::missing_header("x-api-key");
        assert!(err.to_string().contains("x-api-key"));

        let err = ApiKeyHeaderError::empty_key("x-api-key");
        assert!(err.to_string().contains("Empty"));

        let err = ApiKeyHeaderError::invalid_utf8("x-api-key");
        assert!(err.to_string().contains("Invalid UTF-8"));
    }

    #[test]
    fn api_key_header_equality() {
        let key1 = ApiKeyHeader::new("same_key");
        let key2 = ApiKeyHeader::new("same_key");
        let key3 = ApiKeyHeader::new("different_key");

        assert_eq!(key1, key2);
        assert_ne!(key1, key3);
    }
}
15016
#[cfg(test)]
mod api_key_query_tests {
    // Tests for the `ApiKeyQuery` extractor, `ApiKeyQueryConfig`, and `ApiKeyQueryError`.
    use super::*;
    use crate::request::Method;
    use crate::response::IntoResponse;

    /// Builds a `RequestContext` from the testing `Cx` with the fixed numeric value 99999.
    fn test_context() -> RequestContext {
        let cx = asupersync::Cx::for_testing();
        RequestContext::new(cx, 99999)
    }

    /// Runs the `ApiKeyQuery` extractor on `req` with a fresh test context.
    fn extract(req: &mut Request) -> Result<ApiKeyQuery, ApiKeyQueryError> {
        let ctx = test_context();
        futures_executor::block_on(ApiKeyQuery::from_request(&ctx, req))
    }

    #[test]
    fn api_key_query_basic_extraction() {
        let mut req = Request::new(Method::Get, "/api/webhook");
        req.set_query(Some("api_key=test_key_123".to_string()));

        let api_key = extract(&mut req).unwrap();
        assert_eq!(api_key.key(), "test_key_123");
        assert_eq!(api_key.param_name(), "api_key");
    }

    #[test]
    fn api_key_query_missing() {
        let mut req = Request::new(Method::Get, "/api/webhook");

        // `unwrap_err` already fails the test on `Ok`, so no separate `is_err` check is needed.
        let err = extract(&mut req).unwrap_err();
        assert!(matches!(err, ApiKeyQueryError::MissingParam { .. }));
    }

    #[test]
    fn api_key_query_empty_query_string() {
        let mut req = Request::new(Method::Get, "/api/webhook");
        req.set_query(Some(String::new()));

        let err = extract(&mut req).unwrap_err();
        assert!(matches!(err, ApiKeyQueryError::MissingParam { .. }));
    }

    #[test]
    fn api_key_query_param_missing_but_others_present() {
        let mut req = Request::new(Method::Get, "/api/webhook");
        req.set_query(Some("other_param=value".to_string()));

        let err = extract(&mut req).unwrap_err();
        assert!(matches!(err, ApiKeyQueryError::MissingParam { .. }));
    }

    #[test]
    fn api_key_query_empty_value() {
        let mut req = Request::new(Method::Get, "/api/webhook");
        req.set_query(Some("api_key=".to_string()));

        let err = extract(&mut req).unwrap_err();
        assert!(matches!(err, ApiKeyQueryError::EmptyKey { .. }));
    }

    #[test]
    fn api_key_query_whitespace_only() {
        let mut req = Request::new(Method::Get, "/api/webhook");
        req.set_query(Some("api_key= ".to_string()));

        let err = extract(&mut req).unwrap_err();
        assert!(matches!(err, ApiKeyQueryError::EmptyKey { .. }));
    }

    #[test]
    fn api_key_query_trims_whitespace() {
        let mut req = Request::new(Method::Get, "/api/webhook");
        req.set_query(Some("api_key= my_key_123 ".to_string()));

        let api_key = extract(&mut req).unwrap();
        assert_eq!(api_key.key(), "my_key_123");
    }

    #[test]
    fn api_key_query_custom_param_name() {
        let mut req = Request::new(Method::Get, "/api/webhook");
        req.set_query(Some("token=custom_key".to_string()));
        // A config stored as a request extension overrides the default parameter name.
        req.insert_extension(ApiKeyQueryConfig::new().param_name("token"));

        let api_key = extract(&mut req).unwrap();
        assert_eq!(api_key.key(), "custom_key");
        assert_eq!(api_key.param_name(), "token");
    }

    #[test]
    fn api_key_query_with_other_params() {
        let mut req = Request::new(Method::Get, "/api/webhook");
        req.set_query(Some(
            "callback=https://example.com&api_key=webhook_key&format=json".to_string(),
        ));

        let api_key = extract(&mut req).unwrap();
        assert_eq!(api_key.key(), "webhook_key");
    }

    #[test]
    fn api_key_query_url_encoded_value() {
        let mut req = Request::new(Method::Get, "/api/webhook");
        // `%2B` decodes to '+' and `%20` decodes to a space.
        req.set_query(Some("api_key=key%2Bwith%20spaces".to_string()));

        let api_key = extract(&mut req).unwrap();
        assert_eq!(api_key.key(), "key+with spaces");
    }

    #[test]
    fn api_key_query_error_response_401() {
        let err = ApiKeyQueryError::missing_param("api_key");
        let response = err.into_response();
        assert_eq!(response.status().as_u16(), 401);
    }

    #[test]
    fn api_key_query_error_response_json() {
        let err = ApiKeyQueryError::missing_param("api_key");
        let response = err.into_response();

        let has_json_content_type = response
            .headers()
            .iter()
            .any(|(n, v)| n == "content-type" && v.starts_with(b"application/json"));
        assert!(has_json_content_type);
    }

    #[test]
    fn api_key_query_secure_compare() {
        let api_key = ApiKeyQuery::new("secret_key_123");

        assert!(api_key.secure_eq("secret_key_123"));
        // Same length, differs only in the last character.
        assert!(!api_key.secure_eq("secret_key_124"));
        // Different length.
        assert!(!api_key.secure_eq("wrong"));

        assert!(api_key.secure_eq_bytes(b"secret_key_123"));
        assert!(!api_key.secure_eq_bytes(b"secret_key_124"));
    }

    #[test]
    fn api_key_query_deref_and_as_ref() {
        let api_key = ApiKeyQuery::new("deref_test");

        // Deref coercion to `&str`.
        let s: &str = &api_key;
        assert_eq!(s, "deref_test");

        // Explicit `AsRef<str>`.
        let s: &str = api_key.as_ref();
        assert_eq!(s, "deref_test");
    }

    #[test]
    fn api_key_query_config_defaults() {
        let config = ApiKeyQueryConfig::default();
        assert_eq!(config.get_param_name(), DEFAULT_API_KEY_QUERY_PARAM);
    }

    #[test]
    fn api_key_query_error_display() {
        let err = ApiKeyQueryError::missing_param("api_key");
        assert!(err.to_string().contains("api_key"));

        let err = ApiKeyQueryError::empty_key("api_key");
        assert!(err.to_string().contains("Empty"));
    }

    #[test]
    fn api_key_query_equality() {
        let key1 = ApiKeyQuery::new("same_key");
        let key2 = ApiKeyQuery::new("same_key");
        let key3 = ApiKeyQuery::new("different_key");

        assert_eq!(key1, key2);
        assert_ne!(key1, key3);
    }
}
15220
#[cfg(test)]
mod api_key_cookie_tests {
    // Tests for the `ApiKeyCookie` extractor, `ApiKeyCookieConfig`, and `ApiKeyCookieError`.
    use super::*;
    use crate::request::Method;
    use crate::response::IntoResponse;

    /// Builds a `RequestContext` from the testing `Cx` with the fixed numeric value 77777.
    fn test_context() -> RequestContext {
        let cx = asupersync::Cx::for_testing();
        RequestContext::new(cx, 77777)
    }

    /// Runs the `ApiKeyCookie` extractor on `req` with a fresh test context.
    fn extract(req: &mut Request) -> Result<ApiKeyCookie, ApiKeyCookieError> {
        let ctx = test_context();
        futures_executor::block_on(ApiKeyCookie::from_request(&ctx, req))
    }

    #[test]
    fn api_key_cookie_basic_extraction() {
        let mut req = Request::new(Method::Get, "/api/protected");
        req.headers_mut()
            .insert("cookie", b"api_key=test_key_123".to_vec());

        let api_key = extract(&mut req).unwrap();
        assert_eq!(api_key.key(), "test_key_123");
        assert_eq!(api_key.cookie_name(), "api_key");
    }

    #[test]
    fn api_key_cookie_missing_header() {
        let mut req = Request::new(Method::Get, "/api/protected");

        // `unwrap_err` already fails the test on `Ok`, so no separate `is_err` check is needed.
        let err = extract(&mut req).unwrap_err();
        assert!(matches!(err, ApiKeyCookieError::MissingCookie { .. }));
    }

    #[test]
    fn api_key_cookie_other_cookies_present() {
        let mut req = Request::new(Method::Get, "/api/protected");
        req.headers_mut()
            .insert("cookie", b"session_id=abc123; theme=dark".to_vec());

        let err = extract(&mut req).unwrap_err();
        assert!(matches!(err, ApiKeyCookieError::MissingCookie { .. }));
    }

    #[test]
    fn api_key_cookie_empty_value() {
        let mut req = Request::new(Method::Get, "/api/protected");
        req.headers_mut().insert("cookie", b"api_key=".to_vec());

        let err = extract(&mut req).unwrap_err();
        assert!(matches!(err, ApiKeyCookieError::EmptyKey { .. }));
    }

    #[test]
    fn api_key_cookie_whitespace_only() {
        let mut req = Request::new(Method::Get, "/api/protected");
        req.headers_mut().insert("cookie", b"api_key= ".to_vec());

        let err = extract(&mut req).unwrap_err();
        assert!(matches!(err, ApiKeyCookieError::EmptyKey { .. }));
    }

    #[test]
    fn api_key_cookie_trims_whitespace() {
        let mut req = Request::new(Method::Get, "/api/protected");
        req.headers_mut()
            .insert("cookie", b"api_key= my_key_123 ".to_vec());

        let api_key = extract(&mut req).unwrap();
        assert_eq!(api_key.key(), "my_key_123");
    }

    #[test]
    fn api_key_cookie_custom_name() {
        let mut req = Request::new(Method::Get, "/api/protected");
        req.headers_mut()
            .insert("cookie", b"auth_token=custom_key".to_vec());
        // A config stored as a request extension overrides the default cookie name.
        req.insert_extension(ApiKeyCookieConfig::new().cookie_name("auth_token"));

        let api_key = extract(&mut req).unwrap();
        assert_eq!(api_key.key(), "custom_key");
        assert_eq!(api_key.cookie_name(), "auth_token");
    }

    #[test]
    fn api_key_cookie_with_multiple_cookies() {
        let mut req = Request::new(Method::Get, "/api/protected");
        req.headers_mut().insert(
            "cookie",
            b"session_id=sess123; api_key=my_api_key; theme=dark".to_vec(),
        );

        let api_key = extract(&mut req).unwrap();
        assert_eq!(api_key.key(), "my_api_key");
    }

    #[test]
    fn api_key_cookie_error_response_401() {
        let err = ApiKeyCookieError::missing_cookie("api_key");
        let response = err.into_response();
        assert_eq!(response.status().as_u16(), 401);
    }

    #[test]
    fn api_key_cookie_error_response_json() {
        let err = ApiKeyCookieError::missing_cookie("api_key");
        let response = err.into_response();

        let has_json_content_type = response
            .headers()
            .iter()
            .any(|(n, v)| n == "content-type" && v.starts_with(b"application/json"));
        assert!(has_json_content_type);
    }

    #[test]
    fn api_key_cookie_secure_compare() {
        let api_key = ApiKeyCookie::new("secret_key_123");

        assert!(api_key.secure_eq("secret_key_123"));
        // Same length, differs only in the last character.
        assert!(!api_key.secure_eq("secret_key_124"));
        // Different length.
        assert!(!api_key.secure_eq("wrong"));

        assert!(api_key.secure_eq_bytes(b"secret_key_123"));
        assert!(!api_key.secure_eq_bytes(b"secret_key_124"));
    }

    #[test]
    fn api_key_cookie_deref_and_as_ref() {
        let api_key = ApiKeyCookie::new("deref_test");

        // Deref coercion to `&str`.
        let s: &str = &api_key;
        assert_eq!(s, "deref_test");

        // Explicit `AsRef<str>`.
        let s: &str = api_key.as_ref();
        assert_eq!(s, "deref_test");
    }

    #[test]
    fn api_key_cookie_config_defaults() {
        let config = ApiKeyCookieConfig::default();
        assert_eq!(config.get_cookie_name(), DEFAULT_API_KEY_COOKIE);
    }

    #[test]
    fn api_key_cookie_error_display() {
        let err = ApiKeyCookieError::missing_cookie("api_key");
        assert!(err.to_string().contains("api_key"));

        let err = ApiKeyCookieError::empty_key("api_key");
        assert!(err.to_string().contains("Empty"));
    }

    #[test]
    fn api_key_cookie_equality() {
        let key1 = ApiKeyCookie::new("same_key");
        let key2 = ApiKeyCookie::new("same_key");
        let key3 = ApiKeyCookie::new("different_key");

        assert_eq!(key1, key2);
        assert_ne!(key1, key3);
    }
}
15406
15407#[cfg(test)]
15408mod basic_auth_tests {
15409 use super::*;
15410 use crate::request::Method;
15411 use crate::response::IntoResponse;
15412
15413 fn test_context() -> RequestContext {
15414 let cx = asupersync::Cx::for_testing();
15415 RequestContext::new(cx, 12345)
15416 }
15417
15418 fn encode_basic_auth(username: &str, password: &str) -> String {
15420 const ALPHABET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
15422 let input = format!("{username}:{password}");
15423 let bytes = input.as_bytes();
15424 let mut output = String::new();
15425
15426 for chunk in bytes.chunks(3) {
15427 let mut n: u32 = 0;
15428 for (i, &byte) in chunk.iter().enumerate() {
15429 n |= u32::from(byte) << (16 - 8 * i);
15430 }
15431
15432 let chars = match chunk.len() {
15433 3 => 4,
15434 2 => 3,
15435 1 => 2,
15436 _ => unreachable!(),
15437 };
15438
15439 for i in 0..chars {
15440 let idx = ((n >> (18 - 6 * i)) & 0x3F) as usize;
15441 output.push(ALPHABET[idx] as char);
15442 }
15443
15444 for _ in chars..4 {
15446 output.push('=');
15447 }
15448 }
15449
15450 output
15451 }
15452
15453 #[test]
15454 fn basic_auth_extract_valid_credentials() {
15455 let ctx = test_context();
15456 let mut req = Request::new(Method::Get, "/api/protected");
15457 let encoded = encode_basic_auth("alice", "secret123");
15458 req.headers_mut()
15459 .insert("authorization", format!("Basic {encoded}").into_bytes());
15460
15461 let result = futures_executor::block_on(BasicAuth::from_request(&ctx, &mut req));
15462 let auth = result.unwrap();
15463 assert_eq!(auth.username(), "alice");
15464 assert_eq!(auth.password(), "secret123");
15465 }
15466
15467 #[test]
15468 fn basic_auth_extract_lowercase_basic() {
15469 let ctx = test_context();
15470 let mut req = Request::new(Method::Get, "/api/protected");
15471 let encoded = encode_basic_auth("bob", "pass");
15472 req.headers_mut()
15473 .insert("authorization", format!("basic {encoded}").into_bytes());
15474
15475 let result = futures_executor::block_on(BasicAuth::from_request(&ctx, &mut req));
15476 let auth = result.unwrap();
15477 assert_eq!(auth.username(), "bob");
15478 assert_eq!(auth.password(), "pass");
15479 }
15480
15481 #[test]
15482 fn basic_auth_missing_header() {
15483 let ctx = test_context();
15484 let mut req = Request::new(Method::Get, "/api/protected");
15485 let result = futures_executor::block_on(BasicAuth::from_request(&ctx, &mut req));
15488 let err = result.unwrap_err();
15489 assert_eq!(err, BasicAuthError::MissingHeader);
15490 }
15491
15492 #[test]
15493 fn basic_auth_wrong_scheme() {
15494 let ctx = test_context();
15495 let mut req = Request::new(Method::Get, "/api/protected");
15496 req.headers_mut()
15497 .insert("authorization", b"Bearer sometoken".to_vec());
15498
15499 let result = futures_executor::block_on(BasicAuth::from_request(&ctx, &mut req));
15500 let err = result.unwrap_err();
15501 assert_eq!(err, BasicAuthError::InvalidScheme);
15502 }
15503
15504 #[test]
15505 fn basic_auth_invalid_base64() {
15506 let ctx = test_context();
15507 let mut req = Request::new(Method::Get, "/api/protected");
15508 req.headers_mut()
15509 .insert("authorization", b"Basic !!!invalid!!!".to_vec());
15510
15511 let result = futures_executor::block_on(BasicAuth::from_request(&ctx, &mut req));
15512 let err = result.unwrap_err();
15513 assert_eq!(err, BasicAuthError::InvalidBase64);
15514 }
15515
15516 #[test]
15517 fn basic_auth_missing_colon() {
15518 let ctx = test_context();
15519 let mut req = Request::new(Method::Get, "/api/protected");
15520 req.headers_mut()
15522 .insert("authorization", b"Basic bm9jb2xvbg==".to_vec());
15523
15524 let result = futures_executor::block_on(BasicAuth::from_request(&ctx, &mut req));
15525 let err = result.unwrap_err();
15526 assert_eq!(err, BasicAuthError::MissingColon);
15527 }
15528
15529 #[test]
15530 fn basic_auth_empty_username() {
15531 let ctx = test_context();
15532 let mut req = Request::new(Method::Get, "/api/protected");
15533 let encoded = encode_basic_auth("", "password");
15534 req.headers_mut()
15535 .insert("authorization", format!("Basic {encoded}").into_bytes());
15536
15537 let result = futures_executor::block_on(BasicAuth::from_request(&ctx, &mut req));
15538 let auth = result.unwrap();
15539 assert_eq!(auth.username(), "");
15540 assert_eq!(auth.password(), "password");
15541 }
15542
15543 #[test]
15544 fn basic_auth_empty_password() {
15545 let ctx = test_context();
15546 let mut req = Request::new(Method::Get, "/api/protected");
15547 let encoded = encode_basic_auth("user", "");
15548 req.headers_mut()
15549 .insert("authorization", format!("Basic {encoded}").into_bytes());
15550
15551 let result = futures_executor::block_on(BasicAuth::from_request(&ctx, &mut req));
15552 let auth = result.unwrap();
15553 assert_eq!(auth.username(), "user");
15554 assert_eq!(auth.password(), "");
15555 }
15556
15557 #[test]
15558 fn basic_auth_password_with_colons() {
15559 let ctx = test_context();
15560 let mut req = Request::new(Method::Get, "/api/protected");
15561 let encoded = encode_basic_auth("user", "pass:word:with:colons");
15563 req.headers_mut()
15564 .insert("authorization", format!("Basic {encoded}").into_bytes());
15565
15566 let result = futures_executor::block_on(BasicAuth::from_request(&ctx, &mut req));
15567 let auth = result.unwrap();
15568 assert_eq!(auth.username(), "user");
15569 assert_eq!(auth.password(), "pass:word:with:colons");
15570 }
15571
15572 #[test]
15573 fn basic_auth_optional_some() {
15574 let ctx = test_context();
15575 let mut req = Request::new(Method::Get, "/api/protected");
15576 let encoded = encode_basic_auth("optional", "user");
15577 req.headers_mut()
15578 .insert("authorization", format!("Basic {encoded}").into_bytes());
15579
15580 let result = futures_executor::block_on(Option::<BasicAuth>::from_request(&ctx, &mut req));
15581 let maybe_auth = result.unwrap();
15582 assert!(maybe_auth.is_some());
15583 assert_eq!(maybe_auth.unwrap().username(), "optional");
15584 }
15585
15586 #[test]
15587 fn basic_auth_optional_none() {
15588 let ctx = test_context();
15589 let mut req = Request::new(Method::Get, "/api/protected");
15590 let result = futures_executor::block_on(Option::<BasicAuth>::from_request(&ctx, &mut req));
15593 let maybe_auth = result.unwrap();
15594 assert!(maybe_auth.is_none());
15595 }
15596
15597 #[test]
15598 fn basic_auth_error_response_401() {
15599 let err = BasicAuthError::missing_header();
15600 let response = err.into_response();
15601 assert_eq!(response.status().as_u16(), 401);
15602 }
15603
15604 #[test]
15605 fn basic_auth_error_has_www_authenticate() {
15606 let err = BasicAuthError::missing_header();
15607 let response = err.into_response();
15608
15609 let has_www_auth = response
15610 .headers()
15611 .iter()
15612 .any(|(name, value)| name == "www-authenticate" && value == b"Basic");
15613 assert!(has_www_auth);
15614 }
15615
15616 #[test]
15617 fn basic_auth_error_display() {
15618 assert_eq!(
15619 BasicAuthError::missing_header().to_string(),
15620 "Missing Authorization header"
15621 );
15622 assert_eq!(
15623 BasicAuthError::invalid_scheme().to_string(),
15624 "Authorization header must use Basic scheme"
15625 );
15626 assert_eq!(
15627 BasicAuthError::invalid_base64().to_string(),
15628 "Invalid base64 encoding in credentials"
15629 );
15630 assert_eq!(
15631 BasicAuthError::missing_colon().to_string(),
15632 "Credentials must contain username:password"
15633 );
15634 assert_eq!(
15635 BasicAuthError::invalid_utf8().to_string(),
15636 "Credentials contain invalid UTF-8"
15637 );
15638 }
15639
15640 #[test]
15641 fn basic_auth_error_detail() {
15642 assert_eq!(BasicAuthError::MissingHeader.detail(), "Not authenticated");
15643 assert_eq!(
15644 BasicAuthError::InvalidScheme.detail(),
15645 "Invalid authentication credentials"
15646 );
15647 assert_eq!(
15648 BasicAuthError::InvalidBase64.detail(),
15649 "Invalid authentication credentials"
15650 );
15651 assert_eq!(
15652 BasicAuthError::MissingColon.detail(),
15653 "Invalid authentication credentials"
15654 );
15655 assert_eq!(
15656 BasicAuthError::InvalidUtf8.detail(),
15657 "Invalid authentication credentials"
15658 );
15659 }
15660
15661 #[test]
15662 fn basic_auth_new_and_accessors() {
15663 let auth = BasicAuth::new("testuser", "testpass");
15664 assert_eq!(auth.username(), "testuser");
15665 assert_eq!(auth.password(), "testpass");
15666 let (user, pass) = auth.into_credentials();
15667 assert_eq!(user, "testuser");
15668 assert_eq!(pass, "testpass");
15669 }
15670
15671 #[test]
15672 fn basic_auth_error_response_json_body() {
15673 let err = BasicAuthError::missing_header();
15674 let response = err.into_response();
15675
15676 let body_str = match response.body_ref() {
15677 crate::response::ResponseBody::Bytes(b) => String::from_utf8_lossy(b).to_string(),
15678 _ => panic!("Expected Bytes body"),
15679 };
15680 let body: serde_json::Value = serde_json::from_str(&body_str).unwrap();
15681
15682 assert_eq!(body["detail"], "Not authenticated");
15683 }
15684
15685 #[test]
15686 fn basic_auth_error_content_type_json() {
15687 let err = BasicAuthError::missing_header();
15688 let response = err.into_response();
15689
15690 let has_json_content_type = response
15691 .headers()
15692 .iter()
15693 .any(|(name, value)| name == "content-type" && value == b"application/json");
15694 assert!(has_json_content_type);
15695 }
15696
15697 #[test]
15698 fn basic_auth_all_errors_return_401() {
15699 let errors = [
15700 BasicAuthError::missing_header(),
15701 BasicAuthError::invalid_scheme(),
15702 BasicAuthError::invalid_base64(),
15703 BasicAuthError::missing_colon(),
15704 BasicAuthError::invalid_utf8(),
15705 ];
15706
15707 for err in errors {
15708 let response = err.into_response();
15709 assert_eq!(
15710 response.status().as_u16(),
15711 401,
15712 "All BasicAuth errors should be 401"
15713 );
15714 }
15715 }
15716
15717 #[test]
15718 fn basic_auth_all_errors_have_www_authenticate() {
15719 let errors = [
15720 BasicAuthError::missing_header(),
15721 BasicAuthError::invalid_scheme(),
15722 BasicAuthError::invalid_base64(),
15723 BasicAuthError::missing_colon(),
15724 BasicAuthError::invalid_utf8(),
15725 ];
15726
15727 for err in errors {
15728 let response = err.into_response();
15729 let has_www_auth = response
15730 .headers()
15731 .iter()
15732 .any(|(name, value)| name == "www-authenticate" && value == b"Basic");
15733 assert!(
15734 has_www_auth,
15735 "All BasicAuth errors should have WWW-Authenticate: Basic"
15736 );
15737 }
15738 }
15739
15740 #[test]
15741 fn basic_auth_eq_and_clone() {
15742 let auth1 = BasicAuth::new("user", "pass");
15743 let auth2 = BasicAuth::new("user", "pass");
15744 let auth3 = BasicAuth::new("other", "pass");
15745
15746 assert_eq!(auth1, auth2);
15747 assert_ne!(auth1, auth3);
15748
15749 let cloned = auth1.clone();
15750 assert_eq!(auth1, cloned);
15751 }
15752
15753 #[test]
15754 fn basic_auth_error_eq() {
15755 assert_eq!(BasicAuthError::MissingHeader, BasicAuthError::MissingHeader);
15756 assert_eq!(BasicAuthError::InvalidScheme, BasicAuthError::InvalidScheme);
15757 assert_eq!(BasicAuthError::InvalidBase64, BasicAuthError::InvalidBase64);
15758 assert_eq!(BasicAuthError::MissingColon, BasicAuthError::MissingColon);
15759 assert_eq!(BasicAuthError::InvalidUtf8, BasicAuthError::InvalidUtf8);
15760 assert_ne!(BasicAuthError::MissingHeader, BasicAuthError::InvalidScheme);
15761 }
15762
15763 #[test]
15764 fn basic_auth_debug() {
15765 let auth = BasicAuth::new("debug_user", "debug_pass");
15766 let debug_str = format!("{auth:?}");
15767 assert!(debug_str.contains("debug_user"));
15768 assert!(debug_str.contains("debug_pass"));
15769 }
15770
15771 #[test]
15773 fn decode_base64_valid() {
15774 let result = decode_base64("dXNlcjpwYXNz").unwrap();
15776 assert_eq!(String::from_utf8(result).unwrap(), "user:pass");
15777 }
15778
15779 #[test]
15780 fn decode_base64_with_padding() {
15781 let result = decode_base64("YQ==").unwrap();
15783 assert_eq!(String::from_utf8(result).unwrap(), "a");
15784
15785 let result = decode_base64("YWI=").unwrap();
15787 assert_eq!(String::from_utf8(result).unwrap(), "ab");
15788 }
15789
15790 #[test]
15791 fn decode_base64_without_padding() {
15792 let result = decode_base64("YQ").unwrap();
15794 assert_eq!(String::from_utf8(result).unwrap(), "a");
15795
15796 let result = decode_base64("YWI").unwrap();
15797 assert_eq!(String::from_utf8(result).unwrap(), "ab");
15798 }
15799
15800 #[test]
15801 fn decode_base64_empty() {
15802 let result = decode_base64("").unwrap();
15803 assert!(result.is_empty());
15804 }
15805
15806 #[test]
15807 fn decode_base64_invalid_char() {
15808 let result = decode_base64("abc!def");
15809 assert!(result.is_err());
15810 }
15811
15812 #[test]
15813 fn decode_base64_complex_password() {
15814 let encoded = encode_basic_auth("admin", "p@$$w0rd!123");
15817 let result = decode_base64(&encoded).unwrap();
15819 assert_eq!(String::from_utf8(result).unwrap(), "admin:p@$$w0rd!123");
15820 }
15821}
15822
#[cfg(test)]
mod secure_compare_tests {
    use super::*;

    // ------------------------------------------------------------------
    // `constant_time_eq` on raw byte slices
    // ------------------------------------------------------------------

    #[test]
    fn constant_time_eq_equal_slices() {
        let pairs: [(&[u8], &[u8]); 4] = [
            (b"secret", b"secret"),
            (b"", b""),
            (b"a", b"a"),
            (
                b"a_very_long_secret_token_12345",
                b"a_very_long_secret_token_12345",
            ),
        ];
        for (lhs, rhs) in pairs {
            assert!(constant_time_eq(lhs, rhs));
        }
    }

    #[test]
    fn constant_time_eq_different_slices() {
        let pairs: [(&[u8], &[u8]); 3] = [
            (b"secret", b"secreT"),
            (b"aaaaaa", b"aaaaab"),
            (b"a", b"b"),
        ];
        for (lhs, rhs) in pairs {
            assert!(!constant_time_eq(lhs, rhs));
        }
    }

    #[test]
    fn constant_time_eq_different_lengths() {
        let pairs: [(&[u8], &[u8]); 3] = [(b"short", b"longer"), (b"", b"a"), (b"abc", b"ab")];
        for (lhs, rhs) in pairs {
            assert!(!constant_time_eq(lhs, rhs));
        }
    }

    #[test]
    fn constant_time_eq_binary_data() {
        let reference = [0u8, 1, 2, 3, 255, 254, 253];
        let same = [0u8, 1, 2, 3, 255, 254, 253];
        let tail_differs = [0u8, 1, 2, 3, 255, 254, 252];

        assert!(constant_time_eq(&reference, &same));
        assert!(!constant_time_eq(&reference, &tail_differs));
    }

    #[test]
    fn constant_time_eq_all_zeros() {
        let zeros = [0u8; 32];
        let more_zeros = [0u8; 32];
        let mut last_set = [0u8; 32];
        last_set[31] = 1;

        assert!(constant_time_eq(&zeros, &more_zeros));
        assert!(!constant_time_eq(&zeros, &last_set));
    }

    #[test]
    fn constant_time_eq_all_ones() {
        let ones = [0xFFu8; 16];
        let more_ones = [0xFFu8; 16];
        let mut first_cleared = [0xFFu8; 16];
        first_cleared[0] = 0xFE;

        assert!(constant_time_eq(&ones, &more_ones));
        assert!(!constant_time_eq(&ones, &first_cleared));
    }

    // ------------------------------------------------------------------
    // `constant_time_str_eq` on string slices
    // ------------------------------------------------------------------

    #[test]
    fn constant_time_str_eq_equal() {
        for (lhs, rhs) in [("password123", "password123"), ("", ""), ("🔐", "🔐")] {
            assert!(constant_time_str_eq(lhs, rhs));
        }
    }

    #[test]
    fn constant_time_str_eq_different() {
        for (lhs, rhs) in [
            ("password123", "password124"),
            ("case", "CASE"),
            ("🔐", "🔑"),
        ] {
            assert!(!constant_time_str_eq(lhs, rhs));
        }
    }

    #[test]
    fn constant_time_str_eq_unicode() {
        assert!(constant_time_str_eq("日本語", "日本語"));
        for (lhs, rhs) in [("日本語", "日本话"), ("café", "cafe")] {
            assert!(!constant_time_str_eq(lhs, rhs));
        }
    }

    // ------------------------------------------------------------------
    // `secure_eq` / `secure_eq_bytes` on concrete types
    // ------------------------------------------------------------------

    #[test]
    fn bearer_token_secure_eq() {
        let bearer = BearerToken::new("my_secret_token");

        assert!(bearer.secure_eq("my_secret_token"));
        // A single-character case change must not match.
        assert!(!bearer.secure_eq("my_secret_Token"));
        assert!(!bearer.secure_eq("wrong_token"));
    }

    #[test]
    fn bearer_token_secure_eq_bytes() {
        let bearer = BearerToken::new("api_key_123");

        assert!(bearer.secure_eq_bytes(b"api_key_123"));
        assert!(!bearer.secure_eq_bytes(b"api_key_124"));
    }

    #[test]
    fn str_secure_eq() {
        let stored: &str = "hunter2";

        assert!(stored.secure_eq("hunter2"));
        assert!(!stored.secure_eq("hunter3"));
    }

    #[test]
    fn string_secure_eq() {
        let stored = "password".to_owned();

        assert!(stored.secure_eq("password"));
        assert!(!stored.secure_eq("passwor"));
    }

    #[test]
    fn string_secure_eq_bytes() {
        let stored = "binary_safe".to_owned();

        assert!(stored.secure_eq_bytes(b"binary_safe"));
        assert!(!stored.secure_eq_bytes(b"binary_Safe"));
    }

    #[test]
    fn byte_slice_secure_eq() {
        let mac: &[u8] = &[0xDE, 0xAD, 0xBE, 0xEF];

        assert!(mac.secure_eq_bytes(&[0xDE, 0xAD, 0xBE, 0xEF]));
        assert!(!mac.secure_eq_bytes(&[0xDE, 0xAD, 0xBE, 0xEE]));
    }

    #[test]
    fn byte_array_secure_eq() {
        let key: [u8; 4] = [1, 2, 3, 4];

        assert!(key.secure_eq_bytes(&[1, 2, 3, 4]));
        assert!(!key.secure_eq_bytes(&[1, 2, 3, 5]));
    }

    #[test]
    fn vec_secure_eq() {
        // Bytes 0x41 0x42 0x43 spell "ABC".
        let raw: Vec<u8> = b"ABC".to_vec();

        assert!(raw.secure_eq("ABC"));
        assert!(!raw.secure_eq("ABD"));
    }

    // ------------------------------------------------------------------
    // Edge cases and position-of-difference checks
    // ------------------------------------------------------------------

    #[test]
    fn secure_compare_empty_values() {
        assert!(constant_time_eq(b"", b""));
        assert!(!constant_time_eq(b"", b"x"));
        assert!(constant_time_str_eq("", ""));
        assert!(!constant_time_str_eq("", "x"));
    }

    #[test]
    fn secure_compare_single_bit_difference() {
        // Only the least significant bit differs.
        let all_set = [0b1111_1111u8];
        let low_bit_clear = [0b1111_1110u8];

        assert!(!constant_time_eq(&all_set, &low_bit_clear));
    }

    #[test]
    fn secure_compare_first_byte_differs() {
        assert!(!constant_time_eq(b"Xsecret", b"Ysecret"));
    }

    #[test]
    fn secure_compare_last_byte_differs() {
        assert!(!constant_time_eq(b"secretX", b"secretY"));
    }

    #[test]
    fn secure_compare_middle_byte_differs() {
        assert!(!constant_time_eq(b"secXet", b"secYet"));
    }

    #[test]
    fn bearer_token_integration_with_secure_compare() {
        let presented = BearerToken::new("real_api_token_xyz789");

        let expected = "real_api_token_xyz789";
        assert!(presented.secure_eq(expected));

        let forged = "fake_api_token_abc123";
        assert!(!presented.secure_eq(forged));
    }

    #[test]
    fn deref_with_secure_compare() {
        let bearer = BearerToken::new("my_token");
        // `BearerToken` derefs to `str`, so string comparisons are available.
        let as_text: &str = &bearer;
        assert!(as_text.secure_eq("my_token"));
    }

    #[test]
    fn algorithm_processes_all_bytes() {
        // A difference at either end of the input must be detected.
        let head_x = b"Xsecret_token";
        let head_y = b"Ysecret_token";
        assert!(!constant_time_eq(head_x, head_y));

        let tail_x = b"secret_tokenX";
        let tail_y = b"secret_tokenY";
        assert!(!constant_time_eq(tail_x, tail_y));
    }
}
16090
#[cfg(test)]
mod pagination_tests {
    use super::*;
    use crate::request::Method;
    use crate::response::IntoResponse;

    /// Builds a `RequestContext` backed by a testing `Cx` and a fixed request id.
    fn test_context() -> RequestContext {
        let cx = asupersync::Cx::for_testing();
        RequestContext::new(cx, 12345)
    }

    // ------------------------------------------------------------------
    // `Pagination` construction and arithmetic
    // ------------------------------------------------------------------

    #[test]
    fn pagination_default_values() {
        let p = Pagination::default();
        assert_eq!(p.page(), DEFAULT_PAGE);
        assert_eq!(p.per_page(), DEFAULT_PER_PAGE);
        assert_eq!(p.limit(), DEFAULT_PER_PAGE);
        assert_eq!(p.offset(), 0);
    }

    #[test]
    fn pagination_new() {
        let p = Pagination::new(3, 50);
        assert_eq!(p.page(), 3);
        assert_eq!(p.per_page(), 50);
        // Page 3 at 50 per page skips the first two pages (100 items).
        assert_eq!(p.offset(), 100);
    }

    #[test]
    fn pagination_new_clamps_per_page() {
        // A page size of zero is raised to 1...
        let p = Pagination::new(1, 0);
        assert_eq!(p.per_page(), 1);

        // ...and an oversized one is capped at MAX_PER_PAGE.
        let p = Pagination::new(1, 1000);
        assert_eq!(p.per_page(), MAX_PER_PAGE);
    }

    #[test]
    fn pagination_new_clamps_page() {
        // Page 0 is raised to page 1.
        let p = Pagination::new(0, 20);
        assert_eq!(p.page(), 1);
    }

    #[test]
    fn pagination_from_offset() {
        let p = Pagination::from_offset(40, 20);
        assert_eq!(p.offset(), 40);
        assert_eq!(p.per_page(), 20);
        // Offset 40 at 20 per page lands on the third page.
        assert_eq!(p.page(), 3);
    }

    #[test]
    fn pagination_total_pages() {
        let p = Pagination::new(1, 10);
        assert_eq!(p.total_pages(0), 0);
        assert_eq!(p.total_pages(10), 1);
        assert_eq!(p.total_pages(11), 2);
        assert_eq!(p.total_pages(100), 10);
    }

    #[test]
    fn pagination_has_next_prev() {
        let p = Pagination::new(1, 10);
        assert!(!p.has_prev());
        assert!(p.has_next(100));

        let p = Pagination::new(5, 10);
        assert!(p.has_prev());
        assert!(p.has_next(100));

        let p = Pagination::new(10, 10);
        assert!(p.has_prev());
        assert!(!p.has_next(100));
    }

    // ------------------------------------------------------------------
    // `Pagination` extractor (reads the `QueryParams` extension)
    // ------------------------------------------------------------------

    #[test]
    fn pagination_extractor_default_params() {
        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/items");

        let p = futures_executor::block_on(Pagination::from_request(&ctx, &mut req)).unwrap();
        assert_eq!(p.page(), DEFAULT_PAGE);
        assert_eq!(p.per_page(), DEFAULT_PER_PAGE);
    }

    #[test]
    fn pagination_extractor_page_param() {
        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/items?page=5");
        req.insert_extension(QueryParams::parse("page=5"));

        let p = futures_executor::block_on(Pagination::from_request(&ctx, &mut req)).unwrap();
        assert_eq!(p.page(), 5);
        assert_eq!(p.per_page(), DEFAULT_PER_PAGE);
    }

    #[test]
    fn pagination_extractor_per_page_param() {
        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/items?per_page=50");
        req.insert_extension(QueryParams::parse("per_page=50"));

        let p = futures_executor::block_on(Pagination::from_request(&ctx, &mut req)).unwrap();
        assert_eq!(p.page(), DEFAULT_PAGE);
        assert_eq!(p.per_page(), 50);
    }

    #[test]
    fn pagination_extractor_limit_alias() {
        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/items?limit=25");
        req.insert_extension(QueryParams::parse("limit=25"));

        let p = futures_executor::block_on(Pagination::from_request(&ctx, &mut req)).unwrap();
        assert_eq!(p.per_page(), 25);
    }

    #[test]
    fn pagination_extractor_offset_param() {
        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/items?offset=40&limit=10");
        req.insert_extension(QueryParams::parse("offset=40&limit=10"));

        let p = futures_executor::block_on(Pagination::from_request(&ctx, &mut req)).unwrap();
        assert_eq!(p.offset(), 40);
        assert_eq!(p.per_page(), 10);
    }

    #[test]
    fn pagination_extractor_clamps_max_per_page() {
        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/items?per_page=1000");
        req.insert_extension(QueryParams::parse("per_page=1000"));

        let p = futures_executor::block_on(Pagination::from_request(&ctx, &mut req)).unwrap();
        assert_eq!(p.per_page(), MAX_PER_PAGE);
    }

    #[test]
    fn pagination_extractor_invalid_page_uses_default() {
        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/items?page=abc");
        req.insert_extension(QueryParams::parse("page=abc"));

        let p = futures_executor::block_on(Pagination::from_request(&ctx, &mut req)).unwrap();
        assert_eq!(p.page(), DEFAULT_PAGE);
    }

    // ------------------------------------------------------------------
    // `Page` container
    // ------------------------------------------------------------------

    #[test]
    fn page_new() {
        let items = vec!["a", "b", "c"];
        let pagination = Pagination::new(2, 10);
        let page = Page::new(items.clone(), 100, pagination, "/items".to_string());

        assert_eq!(page.items, items);
        assert_eq!(page.total, 100);
        assert_eq!(page.page, 2);
        assert_eq!(page.per_page, 10);
        assert_eq!(page.pages, 10);
    }

    #[test]
    fn page_with_values() {
        let items = vec![1, 2, 3];
        let page = Page::with_values(items.clone(), 50, 3, 10, "/users");

        assert_eq!(page.items, items);
        assert_eq!(page.total, 50);
        assert_eq!(page.page, 3);
        assert_eq!(page.per_page, 10);
        assert_eq!(page.pages, 5);
    }

    #[test]
    fn page_len_is_empty() {
        let page: Page<i32> = Page::with_values(vec![], 0, 1, 10, "/items");
        assert!(page.is_empty());
        assert_eq!(page.len(), 0);

        let page = Page::with_values(vec![1, 2, 3], 100, 1, 10, "/items");
        assert!(!page.is_empty());
        assert_eq!(page.len(), 3);
    }

    #[test]
    fn page_has_next_prev() {
        // First of 10 pages.
        let page = Page::with_values(vec![1, 2, 3], 100, 1, 10, "/items");
        assert!(!page.has_prev());
        assert!(page.has_next());

        // A middle page.
        let page = Page::with_values(vec![1, 2, 3], 100, 5, 10, "/items");
        assert!(page.has_prev());
        assert!(page.has_next());

        // Last of 10 pages.
        let page = Page::with_values(vec![1, 2, 3], 100, 10, 10, "/items");
        assert!(page.has_prev());
        assert!(!page.has_next());
    }

    #[test]
    fn page_map() {
        let page = Page::with_values(vec![1, 2, 3], 100, 1, 10, "/items");
        let mapped = page.map(|n| n * 2);

        assert_eq!(mapped.items, vec![2, 4, 6]);
        assert_eq!(mapped.total, 100);
        assert_eq!(mapped.page, 1);
    }

    // ------------------------------------------------------------------
    // `Link` header generation
    // ------------------------------------------------------------------

    #[test]
    fn page_link_header_first_page() {
        let page = Page::with_values(vec![1, 2, 3], 100, 1, 10, "/items");
        let link = page.link_header();

        assert!(link.contains("rel=\"first\""));
        assert!(link.contains("rel=\"last\""));
        assert!(link.contains("rel=\"next\""));
        // No previous page exists on page 1.
        assert!(!link.contains("rel=\"prev\""));
        // NOTE(review): `contains("page=1")` is also satisfied by "page=10",
        // so it does not by itself prove a first-page link is present.
        assert!(link.contains("page=1"));
        assert!(link.contains("page=2"));
        assert!(link.contains("page=10"));
    }

    #[test]
    fn page_link_header_middle_page() {
        let page = Page::with_values(vec![1, 2, 3], 100, 5, 10, "/items");
        let link = page.link_header();

        assert!(link.contains("rel=\"first\""));
        assert!(link.contains("rel=\"last\""));
        assert!(link.contains("rel=\"next\""));
        assert!(link.contains("rel=\"prev\""));
        // Neighbours of page 5.
        assert!(link.contains("page=4"));
        assert!(link.contains("page=6"));
    }

    #[test]
    fn page_link_header_last_page() {
        let page = Page::with_values(vec![1, 2, 3], 100, 10, 10, "/items");
        let link = page.link_header();

        assert!(link.contains("rel=\"first\""));
        assert!(link.contains("rel=\"last\""));
        // No next page exists on the last page.
        assert!(!link.contains("rel=\"next\""));
        assert!(link.contains("rel=\"prev\""));
        assert!(link.contains("page=9"));
    }

    #[test]
    fn page_link_header_single_page() {
        let page = Page::with_values(vec![1, 2, 3], 3, 1, 10, "/items");
        let link = page.link_header();

        assert!(link.contains("rel=\"first\""));
        assert!(link.contains("rel=\"last\""));
        // A single page has neither neighbour.
        assert!(!link.contains("rel=\"next\""));
        assert!(!link.contains("rel=\"prev\""));
    }

    // ------------------------------------------------------------------
    // `Page` as a response
    // ------------------------------------------------------------------

    #[test]
    fn page_into_response_status_ok() {
        let page = Page::with_values(vec![1, 2, 3], 100, 1, 10, "/items");
        let response = page.into_response();

        assert_eq!(response.status().as_u16(), 200);
    }

    #[test]
    fn page_into_response_content_type() {
        let page = Page::with_values(vec![1, 2, 3], 100, 1, 10, "/items");
        let response = page.into_response();

        let content_type = response
            .headers()
            .iter()
            .find(|(name, _)| name == "content-type");
        assert!(content_type.is_some());
        assert_eq!(content_type.unwrap().1, b"application/json");
    }

    #[test]
    fn page_into_response_has_link_header() {
        let page = Page::with_values(vec![1, 2, 3], 100, 1, 10, "/items");
        let response = page.into_response();

        let link_header = response.headers().iter().find(|(name, _)| name == "link");
        assert!(link_header.is_some());

        let link_value = String::from_utf8_lossy(&link_header.unwrap().1);
        assert!(link_value.contains("rel=\"first\""));
    }

    #[test]
    fn page_into_response_has_pagination_headers() {
        let page = Page::with_values(vec![1, 2, 3], 100, 2, 10, "/items");
        let response = page.into_response();

        // Looks up a header by exact (lowercase) name and decodes it lossily.
        let get_header = |name: &str| {
            response
                .headers()
                .iter()
                .find(|(n, _)| n == name)
                .map(|(_, v)| String::from_utf8_lossy(v).to_string())
        };

        assert_eq!(get_header("x-total-count"), Some("100".to_string()));
        assert_eq!(get_header("x-page"), Some("2".to_string()));
        assert_eq!(get_header("x-per-page"), Some("10".to_string()));
        assert_eq!(get_header("x-total-pages"), Some("10".to_string()));
    }

    #[test]
    fn page_into_response_json_body() {
        let page = Page::with_values(vec!["a", "b", "c"], 100, 2, 10, "/items");
        let response = page.into_response();

        let body_str = match response.body_ref() {
            crate::response::ResponseBody::Bytes(b) => String::from_utf8_lossy(b).to_string(),
            _ => panic!("Expected bytes body"),
        };

        let json: serde_json::Value = serde_json::from_str(&body_str).unwrap();
        assert_eq!(json["items"], serde_json::json!(["a", "b", "c"]));
        assert_eq!(json["total"], 100);
        assert_eq!(json["page"], 2);
        assert_eq!(json["per_page"], 10);
        assert_eq!(json["pages"], 10);
    }

    // ------------------------------------------------------------------
    // `PaginationConfig`
    // ------------------------------------------------------------------

    #[test]
    fn pagination_config_default() {
        let config = PaginationConfig::default();
        assert_eq!(config.default_per_page, DEFAULT_PER_PAGE);
        assert_eq!(config.max_per_page, MAX_PER_PAGE);
        assert_eq!(config.default_page, DEFAULT_PAGE);
    }

    #[test]
    fn pagination_config_builder() {
        let config = PaginationConfig::new()
            .default_per_page(50)
            .max_per_page(200)
            .default_page(1);

        assert_eq!(config.default_per_page, 50);
        assert_eq!(config.max_per_page, 200);
        assert_eq!(config.default_page, 1);
    }

    // ------------------------------------------------------------------
    // Helpers and derived traits
    // ------------------------------------------------------------------

    #[test]
    fn pagination_paginate_helper() {
        let pagination = Pagination::new(2, 10);
        let items = vec!["item1", "item2", "item3"];

        let page = pagination.paginate(items.clone(), 100, "/api/items");

        assert_eq!(page.items, items);
        assert_eq!(page.total, 100);
        assert_eq!(page.page, 2);
        assert_eq!(page.per_page, 10);
        assert_eq!(page.pages, 10);
    }

    #[test]
    fn pagination_equality() {
        let p1 = Pagination::new(2, 10);
        let p2 = Pagination::new(2, 10);
        let p3 = Pagination::new(3, 10);

        assert_eq!(p1, p2);
        assert_ne!(p1, p3);
    }

    #[test]
    fn pagination_copy_clone() {
        let original = Pagination::new(2, 10);
        // `Pagination` is `Copy`: `original` remains usable after this assignment.
        let copied = original;
        // Plain assignment only exercises `Copy`; call `Clone` explicitly so the
        // impl this test is named after is actually covered.
        #[allow(clippy::clone_on_copy)]
        let cloned = original.clone();

        assert_eq!(original, copied);
        assert_eq!(original, cloned);
    }
}
16510
16511#[cfg(test)]
16512mod path_tests {
16513 use super::*;
16514 use crate::request::Method;
16515 use serde::Deserialize;
16516
16517 fn test_context() -> RequestContext {
16519 let cx = asupersync::Cx::for_testing();
16520 RequestContext::new(cx, 12345)
16521 }
16522
16523 fn request_with_params(params: Vec<(&str, &str)>) -> Request {
16525 let mut req = Request::new(Method::Get, "/test");
16526 let path_params = PathParams::from_pairs(
16527 params
16528 .into_iter()
16529 .map(|(k, v)| (k.to_string(), v.to_string()))
16530 .collect(),
16531 );
16532 req.insert_extension(path_params);
16533 req
16534 }
16535
16536 #[test]
16537 fn path_params_get() {
16538 let params = PathParams::from_pairs(vec![("id".to_string(), "42".to_string())]);
16539 assert_eq!(params.get("id"), Some("42"));
16540 assert_eq!(params.get("unknown"), None);
16541 }
16542
16543 #[test]
16544 fn path_params_len() {
16545 let params = PathParams::new();
16546 assert!(params.is_empty());
16547 assert_eq!(params.len(), 0);
16548
16549 let params = PathParams::from_pairs(vec![
16550 ("a".to_string(), "1".to_string()),
16551 ("b".to_string(), "2".to_string()),
16552 ]);
16553 assert!(!params.is_empty());
16554 assert_eq!(params.len(), 2);
16555 }
16556
16557 #[test]
16558 fn path_extract_single_i64() {
16559 let ctx = test_context();
16560 let mut req = request_with_params(vec![("id", "42")]);
16561
16562 let result = futures_executor::block_on(Path::<i64>::from_request(&ctx, &mut req));
16563 let Path(id) = result.unwrap();
16564 assert_eq!(id, 42);
16565 }
16566
16567 #[test]
16568 fn path_extract_single_string() {
16569 let ctx = test_context();
16570 let mut req = request_with_params(vec![("name", "alice")]);
16571
16572 let result = futures_executor::block_on(Path::<String>::from_request(&ctx, &mut req));
16573 let Path(name) = result.unwrap();
16574 assert_eq!(name, "alice");
16575 }
16576
16577 #[test]
16578 fn path_extract_single_u32() {
16579 let ctx = test_context();
16580 let mut req = request_with_params(vec![("count", "100")]);
16581
16582 let result = futures_executor::block_on(Path::<u32>::from_request(&ctx, &mut req));
16583 let Path(count) = result.unwrap();
16584 assert_eq!(count, 100);
16585 }
16586
16587 #[test]
16588 fn path_extract_tuple() {
16589 let ctx = test_context();
16590 let mut req = request_with_params(vec![("user_id", "42"), ("post_id", "99")]);
16591
16592 let result = futures_executor::block_on(Path::<(i64, i64)>::from_request(&ctx, &mut req));
16593 let Path((user_id, post_id)) = result.unwrap();
16594 assert_eq!(user_id, 42);
16595 assert_eq!(post_id, 99);
16596 }
16597
16598 #[test]
16599 fn path_extract_tuple_mixed_types() {
16600 let ctx = test_context();
16601 let mut req = request_with_params(vec![("name", "alice"), ("id", "123")]);
16602
16603 let result =
16604 futures_executor::block_on(Path::<(String, i64)>::from_request(&ctx, &mut req));
16605 let Path((name, id)) = result.unwrap();
16606 assert_eq!(name, "alice");
16607 assert_eq!(id, 123);
16608 }
16609
16610 #[test]
16611 fn path_extract_struct() {
16612 #[derive(Deserialize, Debug, PartialEq)]
16613 struct UserPath {
16614 user_id: i64,
16615 post_id: i64,
16616 }
16617
16618 let ctx = test_context();
16619 let mut req = request_with_params(vec![("user_id", "42"), ("post_id", "99")]);
16620
16621 let result = futures_executor::block_on(Path::<UserPath>::from_request(&ctx, &mut req));
16622 let Path(path) = result.unwrap();
16623 assert_eq!(path.user_id, 42);
16624 assert_eq!(path.post_id, 99);
16625 }
16626
16627 #[test]
16628 fn path_extract_missing_params() {
16629 let ctx = test_context();
16630 let mut req = Request::new(Method::Get, "/test");
16631 let result = futures_executor::block_on(Path::<i64>::from_request(&ctx, &mut req));
16634 assert!(matches!(result, Err(PathExtractError::MissingPathParams)));
16635 }
16636
16637 #[test]
16638 fn path_extract_invalid_type() {
16639 let ctx = test_context();
16640 let mut req = request_with_params(vec![("id", "not_a_number")]);
16641
16642 let result = futures_executor::block_on(Path::<i64>::from_request(&ctx, &mut req));
16643 assert!(matches!(
16644 result,
16645 Err(PathExtractError::InvalidValue { name, .. }) if name == "id"
16646 ));
16647 }
16648
16649 #[test]
16650 fn path_extract_negative_for_unsigned() {
16651 let ctx = test_context();
16652 let mut req = request_with_params(vec![("count", "-5")]);
16653
16654 let result = futures_executor::block_on(Path::<u32>::from_request(&ctx, &mut req));
16655 assert!(matches!(result, Err(PathExtractError::InvalidValue { .. })));
16656 }
16657
16658 #[test]
16659 fn path_extract_f64() {
16660 let ctx = test_context();
16661 let mut req = request_with_params(vec![("price", "19.99")]);
16662
16663 let result = futures_executor::block_on(Path::<f64>::from_request(&ctx, &mut req));
16664 let Path(price) = result.unwrap();
16665 assert!((price - 19.99).abs() < 0.001);
16666 }
16667
16668 #[test]
16669 fn path_deref() {
16670 let path = Path(42i64);
16671 assert_eq!(*path, 42);
16672 }
16673
16674 #[test]
16675 fn path_into_inner() {
16676 let path = Path("hello".to_string());
16677 assert_eq!(path.into_inner(), "hello");
16678 }
16679
16680 #[test]
16681 fn path_error_display() {
16682 let err = PathExtractError::MissingPathParams;
16683 assert!(err.to_string().contains("not available"));
16684
16685 let err = PathExtractError::MissingParam {
16686 name: "user_id".to_string(),
16687 };
16688 assert!(err.to_string().contains("user_id"));
16689
16690 let err = PathExtractError::InvalidValue {
16691 name: "id".to_string(),
16692 value: "abc".to_string(),
16693 expected: "i64",
16694 message: "invalid digit".to_string(),
16695 };
16696 assert!(err.to_string().contains("id"));
16697 assert!(err.to_string().contains("abc"));
16698 assert!(err.to_string().contains("i64"));
16699 }
16700
16701 #[test]
16702 fn path_extract_bool() {
16703 let ctx = test_context();
16704 let mut req = request_with_params(vec![("active", "true")]);
16705
16706 let result = futures_executor::block_on(Path::<bool>::from_request(&ctx, &mut req));
16707 let Path(active) = result.unwrap();
16708 assert!(active);
16709 }
16710
16711 #[test]
16712 fn path_extract_char() {
16713 let ctx = test_context();
16714 let mut req = request_with_params(vec![("letter", "A")]);
16715
16716 let result = futures_executor::block_on(Path::<char>::from_request(&ctx, &mut req));
16717 let Path(letter) = result.unwrap();
16718 assert_eq!(letter, 'A');
16719 }
16720}
16721
#[cfg(test)]
mod query_tests {
    //! Tests for `QueryParams` parsing and the `Query<T>` extractor.

    use super::*;
    use crate::request::Method;
    use serde::Deserialize;

    /// Builds a `RequestContext` from a testing `Cx` with a fixed request id.
    fn test_context() -> RequestContext {
        let cx = asupersync::Cx::for_testing();
        RequestContext::new(cx, 12345)
    }

    /// Builds a `GET /test` request carrying `query` as its raw query string.
    fn request_with_query(query: &str) -> Request {
        let mut req = Request::new(Method::Get, "/test");
        req.set_query(Some(query.to_string()));
        req
    }

    #[test]
    fn query_params_parse() {
        let params = QueryParams::parse("a=1&b=2&c=3");
        assert_eq!(params.get("a"), Some("1"));
        assert_eq!(params.get("b"), Some("2"));
        assert_eq!(params.get("c"), Some("3"));
        assert_eq!(params.get("d"), None);
    }

    #[test]
    fn query_params_multi_value() {
        let params = QueryParams::parse("tag=rust&tag=web&tag=api");
        // `get` yields the first value; `get_all` yields every value in order.
        assert_eq!(params.get("tag"), Some("rust"));
        assert_eq!(params.get_all("tag"), vec!["rust", "web", "api"]);
    }

    #[test]
    fn query_params_percent_decode() {
        let params = QueryParams::parse("msg=hello%20world&name=caf%C3%A9");
        assert_eq!(params.get("msg"), Some("hello world"));
        assert_eq!(params.get("name"), Some("café"));
    }

    #[test]
    fn query_params_plus_as_space() {
        let params = QueryParams::parse("msg=hello+world");
        assert_eq!(params.get("msg"), Some("hello world"));
    }

    #[test]
    fn query_params_empty_value() {
        // A key with no `=` is present with an empty value.
        let params = QueryParams::parse("flag&name=alice");
        assert!(params.contains("flag"));
        assert_eq!(params.get("flag"), Some(""));
        assert_eq!(params.get("name"), Some("alice"));
    }

    #[test]
    fn query_extract_struct() {
        #[derive(Deserialize, Debug, PartialEq)]
        struct SearchParams {
            q: String,
            page: i32,
        }

        let ctx = test_context();
        let mut req = request_with_query("q=rust&page=5");

        let result =
            futures_executor::block_on(Query::<SearchParams>::from_request(&ctx, &mut req));
        let Query(params) = result.unwrap();
        assert_eq!(params.q, "rust");
        assert_eq!(params.page, 5);
    }

    #[test]
    fn query_extract_optional_field() {
        #[derive(Deserialize, Debug)]
        struct Params {
            required: String,
            optional: Option<i32>,
        }

        let ctx = test_context();

        // Optional field supplied.
        let mut req = request_with_query("required=hello&optional=42");
        let result = futures_executor::block_on(Query::<Params>::from_request(&ctx, &mut req));
        let Query(params) = result.unwrap();
        assert_eq!(params.required, "hello");
        assert_eq!(params.optional, Some(42));

        // Optional field omitted.
        let mut req = request_with_query("required=hello");
        let result = futures_executor::block_on(Query::<Params>::from_request(&ctx, &mut req));
        let Query(params) = result.unwrap();
        assert_eq!(params.required, "hello");
        assert_eq!(params.optional, None);
    }

    #[test]
    fn query_extract_multi_value() {
        #[derive(Deserialize, Debug)]
        struct Params {
            tags: Vec<String>,
        }

        let ctx = test_context();
        let mut req = request_with_query("tags=rust&tags=web&tags=api");

        let result = futures_executor::block_on(Query::<Params>::from_request(&ctx, &mut req));
        let Query(params) = result.unwrap();
        assert_eq!(params.tags, vec!["rust", "web", "api"]);
    }

    #[test]
    fn query_extract_default_value() {
        #[derive(Deserialize, Debug)]
        struct Params {
            name: String,
            #[serde(default)]
            limit: i32,
        }

        let ctx = test_context();
        let mut req = request_with_query("name=test");

        let result = futures_executor::block_on(Query::<Params>::from_request(&ctx, &mut req));
        let Query(params) = result.unwrap();
        assert_eq!(params.name, "test");
        // `#[serde(default)]` falls back to `i32::default()` when `limit` is absent.
        assert_eq!(params.limit, 0);
    }

    #[test]
    fn query_extract_bool() {
        #[derive(Deserialize, Debug)]
        struct Params {
            active: bool,
            archived: bool,
        }

        let ctx = test_context();
        let mut req = request_with_query("active=true&archived=false");

        let result = futures_executor::block_on(Query::<Params>::from_request(&ctx, &mut req));
        let Query(params) = result.unwrap();
        assert!(params.active);
        assert!(!params.archived);
    }

    #[test]
    fn query_extract_bool_variants() {
        #[derive(Deserialize, Debug)]
        struct Params {
            a: bool,
            b: bool,
            c: bool,
        }

        let ctx = test_context();
        // "1", "yes" and "on" are all expected to deserialize as `true`.
        let mut req = request_with_query("a=1&b=yes&c=on");

        let result = futures_executor::block_on(Query::<Params>::from_request(&ctx, &mut req));
        let Query(params) = result.unwrap();
        assert!(params.a);
        assert!(params.b);
        assert!(params.c);
    }

    #[test]
    fn query_extract_missing_required_fails() {
        #[derive(Deserialize, Debug)]
        // The field is only written by serde; the test never reads it back.
        #[allow(dead_code)]
        struct Params {
            required: String,
        }

        let ctx = test_context();
        let mut req = request_with_query("other=value");

        let result = futures_executor::block_on(Query::<Params>::from_request(&ctx, &mut req));
        assert!(result.is_err());
    }

    #[test]
    fn query_extract_invalid_type_fails() {
        #[derive(Deserialize, Debug)]
        // The field is only written by serde; the test never reads it back.
        #[allow(dead_code)]
        struct Params {
            count: i32,
        }

        let ctx = test_context();
        let mut req = request_with_query("count=not_a_number");

        let result = futures_executor::block_on(Query::<Params>::from_request(&ctx, &mut req));
        assert!(result.is_err());
    }

    #[test]
    fn query_extract_empty_query() {
        #[derive(Deserialize, Debug, Default)]
        struct Params {
            #[serde(default)]
            name: String,
        }

        let ctx = test_context();
        let mut req = request_with_query("");

        let result = futures_executor::block_on(Query::<Params>::from_request(&ctx, &mut req));
        let Query(params) = result.unwrap();
        assert_eq!(params.name, "");
    }

    #[test]
    fn query_extract_float() {
        #[derive(Deserialize, Debug)]
        struct Params {
            price: f64,
        }

        let ctx = test_context();
        let mut req = request_with_query("price=29.99");

        let result = futures_executor::block_on(Query::<Params>::from_request(&ctx, &mut req));
        let Query(params) = result.unwrap();
        assert!((params.price - 29.99).abs() < 0.001);
    }

    #[test]
    fn query_deref() {
        #[derive(Deserialize, Debug)]
        struct Params {
            name: String,
        }

        let query = Query(Params {
            name: "test".to_string(),
        });
        // Field access goes through `Deref` to the inner `Params`.
        assert_eq!(query.name, "test");
    }

    #[test]
    fn query_into_inner() {
        #[derive(Deserialize, Debug, PartialEq)]
        struct Params {
            value: i32,
        }

        let query = Query(Params { value: 42 });
        assert_eq!(query.into_inner(), Params { value: 42 });
    }

    #[test]
    fn query_error_display() {
        let err = QueryExtractError::MissingParam {
            name: "user_id".to_string(),
        };
        assert!(err.to_string().contains("user_id"));

        let err = QueryExtractError::InvalidValue {
            name: "count".to_string(),
            value: "abc".to_string(),
            expected: "i32",
            message: "invalid digit".to_string(),
        };
        assert!(err.to_string().contains("count"));
        assert!(err.to_string().contains("abc"));
        assert!(err.to_string().contains("i32"));
    }

    #[test]
    fn query_params_keys() {
        let params = QueryParams::parse("a=1&b=2&a=3&c=4");
        let keys: Vec<&str> = params.keys().collect();
        // Repeated keys appear once, in first-seen order.
        assert_eq!(keys, vec!["a", "b", "c"]);
    }

    #[test]
    fn query_params_len() {
        let params = QueryParams::parse("a=1&b=2&c=3");
        assert_eq!(params.len(), 3);
        assert!(!params.is_empty());

        let empty = QueryParams::new();
        assert_eq!(empty.len(), 0);
        assert!(empty.is_empty());
    }
}
17011
#[cfg(test)]
mod optional_tests {
    //! Tests for the `Option<T>` extractor: a successful inner extraction yields
    //! `Some`, and any inner extraction error yields `None`.

    use super::*;
    use crate::request::Method;

    /// Builds a `RequestContext` from a testing `Cx` with a fixed request id.
    fn test_context() -> RequestContext {
        let cx = asupersync::Cx::for_testing();
        RequestContext::new(cx, 99999)
    }

    #[test]
    fn optional_json_present_valid() {
        use serde::Deserialize;

        #[derive(Deserialize, PartialEq, Debug)]
        struct Data {
            value: i32,
        }

        let ctx = test_context();
        let mut req = Request::new(Method::Post, "/test");
        req.headers_mut()
            .insert("content-type", b"application/json".to_vec());
        req.set_body(Body::Bytes(b"{\"value\": 42}".to_vec()));

        let result = futures_executor::block_on(Option::<Json<Data>>::from_request(&ctx, &mut req));
        let Some(Json(data)) = result.unwrap() else {
            panic!("Expected Some");
        };
        assert_eq!(data.value, 42);
    }

    #[test]
    fn optional_json_invalid_content_type_returns_none() {
        use serde::Deserialize;

        #[derive(Deserialize)]
        // The field is only written by serde; the test never reads it back.
        #[allow(dead_code)]
        struct Data {
            value: i32,
        }

        let ctx = test_context();
        let mut req = Request::new(Method::Post, "/test");
        req.headers_mut()
            .insert("content-type", b"text/plain".to_vec());
        req.set_body(Body::Bytes(b"{\"value\": 42}".to_vec()));

        let result = futures_executor::block_on(Option::<Json<Data>>::from_request(&ctx, &mut req));
        assert!(result.unwrap().is_none());
    }

    #[test]
    fn optional_json_missing_body_returns_none() {
        use serde::Deserialize;

        #[derive(Deserialize)]
        // The field is only written by serde; the test never reads it back.
        #[allow(dead_code)]
        struct Data {
            value: i32,
        }

        let ctx = test_context();
        let mut req = Request::new(Method::Post, "/test");
        req.headers_mut()
            .insert("content-type", b"application/json".to_vec());
        // No body is set on this request.
        let result = futures_executor::block_on(Option::<Json<Data>>::from_request(&ctx, &mut req));
        assert!(result.unwrap().is_none());
    }

    #[test]
    fn optional_json_malformed_returns_none() {
        use serde::Deserialize;

        #[derive(Deserialize)]
        // The field is only written by serde; the test never reads it back.
        #[allow(dead_code)]
        struct Data {
            value: i32,
        }

        let ctx = test_context();
        let mut req = Request::new(Method::Post, "/test");
        req.headers_mut()
            .insert("content-type", b"application/json".to_vec());
        req.set_body(Body::Bytes(b"{ not valid json }".to_vec()));

        let result = futures_executor::block_on(Option::<Json<Data>>::from_request(&ctx, &mut req));
        assert!(result.unwrap().is_none());
    }

    #[test]
    fn optional_path_present_valid() {
        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/users/42");
        req.insert_extension(PathParams::from_pairs(vec![(
            "id".to_string(),
            "42".to_string(),
        )]));

        let result = futures_executor::block_on(Option::<Path<i64>>::from_request(&ctx, &mut req));
        let Some(Path(id)) = result.unwrap() else {
            panic!("Expected Some");
        };
        assert_eq!(id, 42);
    }

    #[test]
    fn optional_path_missing_params_returns_none() {
        let ctx = test_context();
        // No `PathParams` extension is attached to this request.
        let mut req = Request::new(Method::Get, "/users/42");
        let result = futures_executor::block_on(Option::<Path<i64>>::from_request(&ctx, &mut req));
        assert!(result.unwrap().is_none());
    }

    #[test]
    fn optional_path_invalid_type_returns_none() {
        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/users/abc");
        req.insert_extension(PathParams::from_pairs(vec![(
            "id".to_string(),
            "abc".to_string(),
        )]));

        let result = futures_executor::block_on(Option::<Path<i64>>::from_request(&ctx, &mut req));
        assert!(result.unwrap().is_none());
    }

    #[test]
    fn optional_query_present_valid() {
        use serde::Deserialize;

        #[derive(Deserialize, PartialEq, Debug)]
        struct Params {
            page: i32,
        }

        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/items");
        req.set_query(Some("page=5".to_string()));

        let result =
            futures_executor::block_on(Option::<Query<Params>>::from_request(&ctx, &mut req));
        let Some(Query(params)) = result.unwrap() else {
            panic!("Expected Some");
        };
        assert_eq!(params.page, 5);
    }

    #[test]
    fn optional_query_missing_returns_none() {
        use serde::Deserialize;

        #[derive(Deserialize)]
        // The field is only written by serde; the test never reads it back.
        #[allow(dead_code)]
        struct Params {
            required: String,
        }

        let ctx = test_context();
        // No query string is set, so the required field cannot be filled.
        let mut req = Request::new(Method::Get, "/items");
        let result =
            futures_executor::block_on(Option::<Query<Params>>::from_request(&ctx, &mut req));
        assert!(result.unwrap().is_none());
    }

    #[test]
    fn optional_query_invalid_type_returns_none() {
        use serde::Deserialize;

        #[derive(Deserialize)]
        // The field is only written by serde; the test never reads it back.
        #[allow(dead_code)]
        struct Params {
            page: i32,
        }

        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/items");
        req.set_query(Some("page=abc".to_string()));

        let result =
            futures_executor::block_on(Option::<Query<Params>>::from_request(&ctx, &mut req));
        assert!(result.unwrap().is_none());
    }

    #[test]
    fn optional_state_present() {
        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/");
        let app_state = AppState::new().with(42i32);
        req.insert_extension(app_state);

        let result = futures_executor::block_on(Option::<State<i32>>::from_request(&ctx, &mut req));
        let Some(State(val)) = result.unwrap() else {
            panic!("Expected Some");
        };
        assert_eq!(val, 42);
    }

    #[test]
    fn optional_state_missing_returns_none() {
        let ctx = test_context();
        // No `AppState` extension is attached to this request.
        let mut req = Request::new(Method::Get, "/");
        let result = futures_executor::block_on(Option::<State<i32>>::from_request(&ctx, &mut req));
        assert!(result.unwrap().is_none());
    }

    #[test]
    fn optional_state_wrong_type_returns_none() {
        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/");
        // The state holds a `String`, but the handler asks for an `i32`.
        let app_state = AppState::new().with("string".to_string());
        req.insert_extension(app_state);

        let result = futures_executor::block_on(Option::<State<i32>>::from_request(&ctx, &mut req));
        assert!(result.unwrap().is_none());
    }
}
17250
#[cfg(test)]
mod combination_tests {
    //! Tests that run several extractors in sequence against the same request.

    use super::*;
    use crate::request::Method;

    /// Builds a `RequestContext` from a testing `Cx` with a fixed request id.
    fn test_context() -> RequestContext {
        let cx = asupersync::Cx::for_testing();
        RequestContext::new(cx, 88888)
    }

    #[test]
    fn path_and_query_together() {
        use serde::Deserialize;

        // Named `ListParams` (not `QueryParams`) so it does not shadow the
        // `QueryParams` type brought into scope by `use super::*`.
        #[derive(Deserialize, PartialEq, Debug)]
        struct ListParams {
            limit: i32,
        }

        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/users/42");
        req.insert_extension(PathParams::from_pairs(vec![(
            "id".to_string(),
            "42".to_string(),
        )]));
        req.set_query(Some("limit=10".to_string()));

        // Path first...
        let path_result = futures_executor::block_on(Path::<i64>::from_request(&ctx, &mut req));
        let Path(user_id) = path_result.unwrap();
        assert_eq!(user_id, 42);

        // ...then the query string from the same request.
        let query_result =
            futures_executor::block_on(Query::<ListParams>::from_request(&ctx, &mut req));
        let Query(params) = query_result.unwrap();
        assert_eq!(params.limit, 10);
    }

    #[test]
    fn json_body_and_path() {
        use serde::Deserialize;

        #[derive(Deserialize, PartialEq, Debug)]
        struct CreateItem {
            name: String,
        }

        let ctx = test_context();
        let mut req = Request::new(Method::Post, "/categories/5/items");
        req.headers_mut()
            .insert("content-type", b"application/json".to_vec());
        req.set_body(Body::Bytes(b"{\"name\": \"Widget\"}".to_vec()));
        req.insert_extension(PathParams::from_pairs(vec![(
            "cat_id".to_string(),
            "5".to_string(),
        )]));

        // Path extraction runs before the body is consumed.
        let path_result = futures_executor::block_on(Path::<i64>::from_request(&ctx, &mut req));
        let Path(cat_id) = path_result.unwrap();
        assert_eq!(cat_id, 5);

        let json_result =
            futures_executor::block_on(Json::<CreateItem>::from_request(&ctx, &mut req));
        let Json(item) = json_result.unwrap();
        assert_eq!(item.name, "Widget");
    }

    #[test]
    fn state_and_query() {
        use serde::Deserialize;

        #[derive(Deserialize, PartialEq, Debug)]
        struct SearchParams {
            q: String,
        }

        #[derive(Clone, PartialEq, Debug)]
        struct Config {
            max_results: i32,
        }

        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/search");
        req.set_query(Some("q=hello".to_string()));
        let app_state = AppState::new().with(Config { max_results: 100 });
        req.insert_extension(app_state);

        let state_result =
            futures_executor::block_on(State::<Config>::from_request(&ctx, &mut req));
        let State(config) = state_result.unwrap();
        assert_eq!(config.max_results, 100);

        let query_result =
            futures_executor::block_on(Query::<SearchParams>::from_request(&ctx, &mut req));
        let Query(params) = query_result.unwrap();
        assert_eq!(params.q, "hello");
    }

    #[test]
    fn multiple_path_params_with_struct() {
        use serde::Deserialize;

        #[derive(Deserialize, PartialEq, Debug)]
        struct CommentPath {
            post_id: i64,
            comment_id: i64,
        }

        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/posts/123/comments/456");
        req.insert_extension(PathParams::from_pairs(vec![
            ("post_id".to_string(), "123".to_string()),
            ("comment_id".to_string(), "456".to_string()),
        ]));

        let result = futures_executor::block_on(Path::<CommentPath>::from_request(&ctx, &mut req));
        let Path(path) = result.unwrap();
        assert_eq!(path.post_id, 123);
        assert_eq!(path.comment_id, 456);
    }

    #[test]
    fn optional_mixed_with_required() {
        use serde::Deserialize;

        #[derive(Deserialize, PartialEq, Debug)]
        struct OptionalParams {
            page: Option<i32>,
        }

        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/users/42");
        req.insert_extension(PathParams::from_pairs(vec![(
            "id".to_string(),
            "42".to_string(),
        )]));

        let path_result = futures_executor::block_on(Path::<i64>::from_request(&ctx, &mut req));
        let Path(id) = path_result.unwrap();
        assert_eq!(id, 42);

        // No query string is set, so the optional field stays `None`.
        let query_result =
            futures_executor::block_on(Query::<OptionalParams>::from_request(&ctx, &mut req));
        let Query(params) = query_result.unwrap();
        assert_eq!(params.page, None);
    }

    #[test]
    fn request_context_extraction() {
        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/");

        let result = futures_executor::block_on(RequestContext::from_request(&ctx, &mut req));
        let extracted_ctx = result.unwrap();
        assert_eq!(extracted_ctx.request_id(), ctx.request_id());
    }

    #[test]
    fn triple_extraction_path_query_state() {
        use serde::Deserialize;

        #[derive(Deserialize, PartialEq, Debug)]
        struct QueryFilter {
            status: String,
        }

        #[derive(Clone)]
        struct DbPool {
            connection_count: i32,
        }

        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/projects/99/tasks");
        req.insert_extension(PathParams::from_pairs(vec![(
            "project_id".to_string(),
            "99".to_string(),
        )]));
        req.set_query(Some("status=active".to_string()));
        let app_state = AppState::new().with(DbPool {
            connection_count: 10,
        });
        req.insert_extension(app_state);

        let Path(project_id): Path<i32> =
            futures_executor::block_on(Path::<i32>::from_request(&ctx, &mut req)).unwrap();
        assert_eq!(project_id, 99);

        let Query(filter): Query<QueryFilter> =
            futures_executor::block_on(Query::<QueryFilter>::from_request(&ctx, &mut req)).unwrap();
        assert_eq!(filter.status, "active");

        let State(pool): State<DbPool> =
            futures_executor::block_on(State::<DbPool>::from_request(&ctx, &mut req)).unwrap();
        assert_eq!(pool.connection_count, 10);
    }
}
17461
17462#[cfg(test)]
17467mod edge_case_tests {
17468 use super::*;
17469 use crate::request::Method;
17470
17471 fn test_context() -> RequestContext {
17472 let cx = asupersync::Cx::for_testing();
17473 RequestContext::new(cx, 77777)
17474 }
17475
17476 #[test]
17479 fn json_with_unicode() {
17480 use serde::Deserialize;
17481
17482 #[derive(Deserialize, PartialEq, Debug)]
17483 struct Data {
17484 name: String,
17485 emoji: String,
17486 }
17487
17488 let ctx = test_context();
17489 let mut req = Request::new(Method::Post, "/test");
17490 req.headers_mut()
17491 .insert("content-type", b"application/json".to_vec());
17492 req.set_body(Body::Bytes(
17493 r#"{"name": "日本語", "emoji": "🎉🚀"}"#.as_bytes().to_vec(),
17494 ));
17495
17496 let result = futures_executor::block_on(Json::<Data>::from_request(&ctx, &mut req));
17497 let Json(data) = result.unwrap();
17498 assert_eq!(data.name, "日本語");
17499 assert_eq!(data.emoji, "🎉🚀");
17500 }
17501
17502 #[test]
17503 fn query_with_unicode_percent_encoded() {
17504 use serde::Deserialize;
17505
17506 #[derive(Deserialize, PartialEq, Debug)]
17507 struct Search {
17508 q: String,
17509 }
17510
17511 let ctx = test_context();
17512 let mut req = Request::new(Method::Get, "/search");
17513 req.set_query(Some(
17515 "q=%E3%81%93%E3%82%93%E3%81%AB%E3%81%A1%E3%81%AF".to_string(),
17516 ));
17517
17518 let result = futures_executor::block_on(Query::<Search>::from_request(&ctx, &mut req));
17519 let Query(search) = result.unwrap();
17520 assert_eq!(search.q, "こんにちは");
17521 }
17522
17523 #[test]
17524 fn path_with_unicode() {
17525 let ctx = test_context();
17526 let mut req = Request::new(Method::Get, "/users/用户123");
17527 req.insert_extension(PathParams::from_pairs(vec![(
17528 "name".to_string(),
17529 "用户123".to_string(),
17530 )]));
17531
17532 let result = futures_executor::block_on(Path::<String>::from_request(&ctx, &mut req));
17533 let Path(name) = result.unwrap();
17534 assert_eq!(name, "用户123");
17535 }
17536
17537 #[test]
17540 fn path_max_i64() {
17541 let ctx = test_context();
17542 let mut req = Request::new(Method::Get, "/items/9223372036854775807");
17543 req.insert_extension(PathParams::from_pairs(vec![(
17544 "id".to_string(),
17545 "9223372036854775807".to_string(),
17546 )]));
17547
17548 let result = futures_executor::block_on(Path::<i64>::from_request(&ctx, &mut req));
17549 let Path(id) = result.unwrap();
17550 assert_eq!(id, i64::MAX);
17551 }
17552
17553 #[test]
17554 fn path_min_i64() {
17555 let ctx = test_context();
17556 let mut req = Request::new(Method::Get, "/items/-9223372036854775808");
17557 req.insert_extension(PathParams::from_pairs(vec![(
17558 "id".to_string(),
17559 "-9223372036854775808".to_string(),
17560 )]));
17561
17562 let result = futures_executor::block_on(Path::<i64>::from_request(&ctx, &mut req));
17563 let Path(id) = result.unwrap();
17564 assert_eq!(id, i64::MIN);
17565 }
17566
17567 #[test]
17568 fn path_overflow_i64_fails() {
17569 let ctx = test_context();
17570 let mut req = Request::new(Method::Get, "/items/9223372036854775808");
17571 req.insert_extension(PathParams::from_pairs(vec![(
17572 "id".to_string(),
17573 "9223372036854775808".to_string(), )]));
17575
17576 let result = futures_executor::block_on(Path::<i64>::from_request(&ctx, &mut req));
17577 assert!(result.is_err());
17578 }
17579
17580 #[test]
17581 fn query_with_empty_value() {
17582 use serde::Deserialize;
17583
17584 #[derive(Deserialize, PartialEq, Debug)]
17585 struct Params {
17586 key: String,
17587 }
17588
17589 let ctx = test_context();
17590 let mut req = Request::new(Method::Get, "/test");
17591 req.set_query(Some("key=".to_string()));
17592
17593 let result = futures_executor::block_on(Query::<Params>::from_request(&ctx, &mut req));
17594 let Query(params) = result.unwrap();
17595 assert_eq!(params.key, "");
17596 }
17597
17598 #[test]
17599 fn query_with_only_key_no_equals() {
17600 use serde::Deserialize;
17601
17602 #[derive(Deserialize, PartialEq, Debug)]
17603 struct Params {
17604 flag: Option<String>,
17605 }
17606
17607 let ctx = test_context();
17608 let mut req = Request::new(Method::Get, "/test");
17609 req.set_query(Some("flag".to_string()));
17610
17611 let result = futures_executor::block_on(Query::<Params>::from_request(&ctx, &mut req));
17612 let Query(params) = result.unwrap();
17613 assert_eq!(params.flag, Some(String::new()));
17615 }
17616
17617 #[test]
17618 fn json_empty_object() {
17619 use serde::Deserialize;
17620
17621 #[derive(Deserialize, PartialEq, Debug)]
17622 struct Empty {}
17623
17624 let ctx = test_context();
17625 let mut req = Request::new(Method::Post, "/test");
17626 req.headers_mut()
17627 .insert("content-type", b"application/json".to_vec());
17628 req.set_body(Body::Bytes(b"{}".to_vec()));
17629
17630 let result = futures_executor::block_on(Json::<Empty>::from_request(&ctx, &mut req));
17631 assert!(result.is_ok());
17632 }
17633
17634 #[test]
17635 fn json_with_null_field() {
17636 use serde::Deserialize;
17637
17638 #[derive(Deserialize, PartialEq, Debug)]
17639 struct Data {
17640 value: Option<i32>,
17641 }
17642
17643 let ctx = test_context();
17644 let mut req = Request::new(Method::Post, "/test");
17645 req.headers_mut()
17646 .insert("content-type", b"application/json".to_vec());
17647 req.set_body(Body::Bytes(b"{\"value\": null}".to_vec()));
17648
17649 let result = futures_executor::block_on(Json::<Data>::from_request(&ctx, &mut req));
17650 let Json(data) = result.unwrap();
17651 assert_eq!(data.value, None);
17652 }
17653
17654 #[test]
17655 fn json_with_nested_objects() {
17656 use serde::Deserialize;
17657
17658 #[derive(Deserialize, PartialEq, Debug)]
17659 struct Address {
17660 city: String,
17661 zip: String,
17662 }
17663
17664 #[derive(Deserialize, PartialEq, Debug)]
17665 struct User {
17666 name: String,
17667 address: Address,
17668 }
17669
17670 let ctx = test_context();
17671 let mut req = Request::new(Method::Post, "/test");
17672 req.headers_mut()
17673 .insert("content-type", b"application/json".to_vec());
17674 req.set_body(Body::Bytes(
17675 b"{\"name\": \"Alice\", \"address\": {\"city\": \"NYC\", \"zip\": \"10001\"}}".to_vec(),
17676 ));
17677
17678 let result = futures_executor::block_on(Json::<User>::from_request(&ctx, &mut req));
17679 let Json(user) = result.unwrap();
17680 assert_eq!(user.name, "Alice");
17681 assert_eq!(user.address.city, "NYC");
17682 assert_eq!(user.address.zip, "10001");
17683 }
17684
17685 #[test]
17686 fn json_with_array() {
17687 use serde::Deserialize;
17688
17689 #[derive(Deserialize, PartialEq, Debug)]
17690 struct Data {
17691 items: Vec<i32>,
17692 }
17693
17694 let ctx = test_context();
17695 let mut req = Request::new(Method::Post, "/test");
17696 req.headers_mut()
17697 .insert("content-type", b"application/json".to_vec());
17698 req.set_body(Body::Bytes(b"{\"items\": [1, 2, 3, 4, 5]}".to_vec()));
17699
17700 let result = futures_executor::block_on(Json::<Data>::from_request(&ctx, &mut req));
17701 let Json(data) = result.unwrap();
17702 assert_eq!(data.items, vec![1, 2, 3, 4, 5]);
17703 }
17704
17705 #[test]
17706 fn path_with_special_chars() {
17707 let ctx = test_context();
17708 let mut req = Request::new(Method::Get, "/files/my-file_v2.txt");
17709 req.insert_extension(PathParams::from_pairs(vec![(
17710 "filename".to_string(),
17711 "my-file_v2.txt".to_string(),
17712 )]));
17713
17714 let result = futures_executor::block_on(Path::<String>::from_request(&ctx, &mut req));
17715 let Path(filename) = result.unwrap();
17716 assert_eq!(filename, "my-file_v2.txt");
17717 }
17718
17719 #[test]
17720 fn query_with_special_chars_encoded() {
17721 use serde::Deserialize;
17722
17723 #[derive(Deserialize, PartialEq, Debug)]
17724 struct Params {
17725 value: String,
17726 }
17727
17728 let ctx = test_context();
17729 let mut req = Request::new(Method::Get, "/test");
17730 req.set_query(Some("value=hello%20world%20%26%20more".to_string()));
17732
17733 let result = futures_executor::block_on(Query::<Params>::from_request(&ctx, &mut req));
17734 let Query(params) = result.unwrap();
17735 assert_eq!(params.value, "hello world & more");
17736 }
17737
17738 #[test]
17739 fn query_multiple_values_same_key() {
17740 use serde::Deserialize;
17741
17742 #[derive(Deserialize, PartialEq, Debug)]
17743 struct Params {
17744 tags: Vec<String>,
17745 }
17746
17747 let ctx = test_context();
17748 let mut req = Request::new(Method::Get, "/test");
17749 req.set_query(Some("tags=a&tags=b&tags=c".to_string()));
17750
17751 let result = futures_executor::block_on(Query::<Params>::from_request(&ctx, &mut req));
17752 let Query(params) = result.unwrap();
17753 assert_eq!(params.tags, vec!["a", "b", "c"]);
17754 }
17755
17756 #[test]
17757 fn path_empty_string() {
17758 let ctx = test_context();
17759 let mut req = Request::new(Method::Get, "/items//details");
17760 req.insert_extension(PathParams::from_pairs(vec![(
17761 "id".to_string(),
17762 String::new(),
17763 )]));
17764
17765 let result = futures_executor::block_on(Path::<String>::from_request(&ctx, &mut req));
17766 let Path(id) = result.unwrap();
17767 assert_eq!(id, "");
17768 }
17769
17770 #[test]
17771 fn json_with_escaped_quotes() {
17772 use serde::Deserialize;
17773
17774 #[derive(Deserialize, PartialEq, Debug)]
17775 struct Data {
17776 message: String,
17777 }
17778
17779 let ctx = test_context();
17780 let mut req = Request::new(Method::Post, "/test");
17781 req.headers_mut()
17782 .insert("content-type", b"application/json".to_vec());
17783 req.set_body(Body::Bytes(
17784 b"{\"message\": \"He said \\\"hello\\\"\"}".to_vec(),
17785 ));
17786
17787 let result = futures_executor::block_on(Json::<Data>::from_request(&ctx, &mut req));
17788 let Json(data) = result.unwrap();
17789 assert_eq!(data.message, "He said \"hello\"");
17790 }
17791
17792 #[test]
17793 fn query_with_plus_as_space() {
17794 use serde::Deserialize;
17795
17796 #[derive(Deserialize, PartialEq, Debug)]
17797 struct Params {
17798 q: String,
17799 }
17800
17801 let ctx = test_context();
17802 let mut req = Request::new(Method::Get, "/search");
17803 req.set_query(Some("q=hello+world".to_string()));
17804
17805 let result = futures_executor::block_on(Query::<Params>::from_request(&ctx, &mut req));
17806 let Query(params) = result.unwrap();
17807 assert_eq!(params.q, "hello world");
17808 }
17809}
17810
/// Security-oriented extractor tests.
///
/// These cover how extractors treat hostile or unusual input: oversized
/// payloads, injection-looking strings, traversal sequences, NUL bytes and
/// content-type variants. The assertions show that extractors pass such values
/// through verbatim (or reject them with a typed error) rather than sanitizing
/// them. The `json_into_response_*` tests check `Json` response serialization.
#[cfg(test)]
mod security_tests {
    use super::*;
    use crate::request::Method;

    /// Builds a fresh request context for these tests.
    fn test_context() -> RequestContext {
        let cx = asupersync::Cx::for_testing();
        RequestContext::new(cx, 66666)
    }

    #[test]
    fn json_payload_size_limit() {
        use serde::Deserialize;

        #[derive(Deserialize)]
        #[allow(dead_code)] // field is never read; only the rejection matters
        struct Data {
            content: String,
        }

        let ctx = test_context();
        let mut req = Request::new(Method::Post, "/test");
        req.headers_mut()
            .insert("content-type", b"application/json".to_vec());

        // The string payload alone is larger than DEFAULT_JSON_LIMIT.
        let large_content = "x".repeat(DEFAULT_JSON_LIMIT + 100);
        let body = format!("{{\"content\": \"{large_content}\"}}");
        req.set_body(Body::Bytes(body.into_bytes()));

        let result = futures_executor::block_on(Json::<Data>::from_request(&ctx, &mut req));
        assert!(matches!(
            result,
            Err(JsonExtractError::PayloadTooLarge { .. })
        ));
    }

    #[test]
    fn json_deeply_nested_object() {
        use serde::Deserialize;

        #[derive(Deserialize)]
        struct Level1 {
            l2: Level2,
        }
        #[derive(Deserialize)]
        struct Level2 {
            l3: Level3,
        }
        #[derive(Deserialize)]
        struct Level3 {
            l4: Level4,
        }
        #[derive(Deserialize)]
        struct Level4 {
            value: i32,
        }

        let ctx = test_context();
        let mut req = Request::new(Method::Post, "/test");
        req.headers_mut()
            .insert("content-type", b"application/json".to_vec());
        req.set_body(Body::Bytes(
            b"{\"l2\":{\"l3\":{\"l4\":{\"value\":42}}}}".to_vec(),
        ));

        let result = futures_executor::block_on(Json::<Level1>::from_request(&ctx, &mut req));
        // Check the innermost value, not just `is_ok()`, so a mis-parse of the
        // nesting cannot pass silently.
        let Json(root) = result.unwrap();
        assert_eq!(root.l2.l3.l4.value, 42);
    }

    #[test]
    fn query_injection_attempt_escaped() {
        use serde::Deserialize;

        #[derive(Deserialize, PartialEq, Debug)]
        struct Params {
            name: String,
        }

        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/test");
        req.set_query(Some(
            "name=Robert%27%3B%20DROP%20TABLE%20users%3B--".to_string(),
        ));

        let result = futures_executor::block_on(Query::<Params>::from_request(&ctx, &mut req));
        let Query(params) = result.unwrap();
        // The extractor only percent-decodes; it does not sanitize. Handlers must
        // still use parameterized queries.
        assert_eq!(params.name, "Robert'; DROP TABLE users;--");
    }

    #[test]
    fn path_traversal_attempt() {
        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/files/../../../etc/passwd");
        req.insert_extension(PathParams::from_pairs(vec![(
            "path".to_string(),
            "../../../etc/passwd".to_string(),
        )]));

        let result = futures_executor::block_on(Path::<String>::from_request(&ctx, &mut req));
        let Path(path) = result.unwrap();
        // Traversal sequences are passed through unchanged. Confining file
        // access is the handler's responsibility.
        assert_eq!(path, "../../../etc/passwd");
    }

    #[test]
    fn json_with_script_tag_xss() {
        use serde::Deserialize;

        #[derive(Deserialize, PartialEq, Debug)]
        struct Data {
            comment: String,
        }

        let ctx = test_context();
        let mut req = Request::new(Method::Post, "/test");
        req.headers_mut()
            .insert("content-type", b"application/json".to_vec());
        req.set_body(Body::Bytes(
            b"{\"comment\": \"<script>alert('xss')</script>\"}".to_vec(),
        ));

        let result = futures_executor::block_on(Json::<Data>::from_request(&ctx, &mut req));
        let Json(data) = result.unwrap();
        // Markup is preserved verbatim. Output escaping belongs at render time.
        assert_eq!(data.comment, "<script>alert('xss')</script>");
    }

    #[test]
    fn json_content_type_case_insensitive() {
        use serde::Deserialize;

        #[derive(Deserialize, PartialEq, Debug)]
        struct Data {
            value: i32,
        }

        // Media type and subtype names are case-insensitive, so every casing of
        // application/json must be accepted.
        for content_type in &[
            "APPLICATION/JSON",
            "Application/Json",
            "application/JSON",
            "APPLICATION/json",
        ] {
            let ctx = test_context();
            let mut req = Request::new(Method::Post, "/test");
            req.headers_mut()
                .insert("content-type", content_type.as_bytes().to_vec());
            req.set_body(Body::Bytes(b"{\"value\": 42}".to_vec()));

            let result = futures_executor::block_on(Json::<Data>::from_request(&ctx, &mut req));
            assert!(
                result.is_ok(),
                "Failed for content-type: {}: {:?}",
                content_type,
                result.err()
            );
        }
    }

    #[test]
    fn json_wrong_content_type_variants() {
        use serde::Deserialize;

        #[derive(Deserialize)]
        #[allow(dead_code)] // field is never read; only the rejection matters
        struct Data {
            value: i32,
        }

        // Near-miss JSON media types must still be rejected, even with a valid
        // JSON body.
        for content_type in &[
            "text/json",
            "text/plain",
            "application/xml",
            "application/x-json",
        ] {
            let ctx = test_context();
            let mut req = Request::new(Method::Post, "/test");
            req.headers_mut()
                .insert("content-type", content_type.as_bytes().to_vec());
            req.set_body(Body::Bytes(b"{\"value\": 42}".to_vec()));

            let result = futures_executor::block_on(Json::<Data>::from_request(&ctx, &mut req));
            assert!(
                matches!(result, Err(JsonExtractError::UnsupportedMediaType { .. })),
                "Should reject content-type: {}",
                content_type
            );
        }
    }

    #[test]
    fn query_null_byte_handling() {
        use serde::Deserialize;

        #[derive(Deserialize, PartialEq, Debug)]
        struct Params {
            name: String,
        }

        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/test");
        req.set_query(Some("name=test%00value".to_string()));

        let result = futures_executor::block_on(Query::<Params>::from_request(&ctx, &mut req));
        let Query(params) = result.unwrap();
        // %00 decodes to an embedded NUL that is kept, not truncated or stripped.
        assert_eq!(params.name, "test\0value");
    }

    #[test]
    fn path_with_null_byte() {
        let ctx = test_context();
        let mut req = Request::new(Method::Get, "/files/test");
        req.insert_extension(PathParams::from_pairs(vec![(
            "filename".to_string(),
            "test\0.txt".to_string(),
        )]));

        let result = futures_executor::block_on(Path::<String>::from_request(&ctx, &mut req));
        let Path(filename) = result.unwrap();
        // The NUL byte passes through unchanged.
        assert_eq!(filename, "test\0.txt");
    }

    #[test]
    fn json_number_precision() {
        use serde::Deserialize;

        #[derive(Deserialize, PartialEq, Debug)]
        struct Data {
            big_int: i64,
            float_val: f64,
        }

        let ctx = test_context();
        let mut req = Request::new(Method::Post, "/test");
        req.headers_mut()
            .insert("content-type", b"application/json".to_vec());
        // 2^53 + 1 is the smallest positive integer an f64 cannot represent
        // exactly, so this fails if the integer is ever routed through a float.
        req.set_body(Body::Bytes(
            b"{\"big_int\": 9007199254740993, \"float_val\": 3.141592653589793}".to_vec(),
        ));

        let result = futures_executor::block_on(Json::<Data>::from_request(&ctx, &mut req));
        let Json(data) = result.unwrap();
        assert_eq!(data.big_int, 9_007_199_254_740_993_i64);
        assert!((data.float_val - std::f64::consts::PI).abs() < 1e-7);
    }

    #[test]
    fn json_into_response_serializes_struct() {
        use serde::Serialize;

        #[derive(Serialize)]
        struct User {
            name: String,
            age: u32,
        }

        let user = User {
            name: "Alice".to_string(),
            age: 30,
        };
        let json = Json(user);
        let response = json.into_response();

        assert_eq!(response.status().as_u16(), 200);

        // The header name is compared exactly (lowercase) against what the
        // response stores.
        let content_type = response
            .headers()
            .iter()
            .find(|(name, _)| name == "content-type")
            .map(|(_, value)| String::from_utf8_lossy(value).to_string());
        assert_eq!(content_type, Some("application/json".to_string()));

        if let ResponseBody::Bytes(bytes) = response.body_ref() {
            // Round-trip through serde_json rather than comparing raw text, so
            // field order does not matter.
            let parsed: serde_json::Value = serde_json::from_slice(bytes).unwrap();
            assert_eq!(parsed["name"], "Alice");
            assert_eq!(parsed["age"], 30);
        } else {
            panic!("Expected Bytes body");
        }
    }

    #[test]
    fn json_into_response_serializes_primitive() {
        let json = Json(42i32);
        let response = json.into_response();

        assert_eq!(response.status().as_u16(), 200);

        if let ResponseBody::Bytes(bytes) = response.body_ref() {
            let parsed: i32 = serde_json::from_slice(bytes).unwrap();
            assert_eq!(parsed, 42);
        } else {
            panic!("Expected Bytes body");
        }
    }

    #[test]
    fn json_into_response_serializes_array() {
        let json = Json(vec!["a", "b", "c"]);
        let response = json.into_response();

        assert_eq!(response.status().as_u16(), 200);

        if let ResponseBody::Bytes(bytes) = response.body_ref() {
            let parsed: Vec<String> = serde_json::from_slice(bytes).unwrap();
            assert_eq!(parsed, vec!["a", "b", "c"]);
        } else {
            panic!("Expected Bytes body");
        }
    }

    #[test]
    fn json_into_response_serializes_hashmap() {
        use std::collections::HashMap;

        let mut map = HashMap::new();
        map.insert("key1", "value1");
        map.insert("key2", "value2");

        let json = Json(map);
        let response = json.into_response();

        assert_eq!(response.status().as_u16(), 200);

        if let ResponseBody::Bytes(bytes) = response.body_ref() {
            // HashMap iteration order is unspecified, so look keys up instead of
            // comparing serialized text.
            let parsed: HashMap<String, String> = serde_json::from_slice(bytes).unwrap();
            assert_eq!(parsed.get("key1"), Some(&"value1".to_string()));
            assert_eq!(parsed.get("key2"), Some(&"value2".to_string()));
        } else {
            panic!("Expected Bytes body");
        }
    }

    #[test]
    fn json_into_response_handles_null() {
        let json = Json(Option::<String>::None);
        let response = json.into_response();

        assert_eq!(response.status().as_u16(), 200);

        if let ResponseBody::Bytes(bytes) = response.body_ref() {
            // `None` must serialize to the JSON literal `null`, not an empty body.
            let content = String::from_utf8_lossy(bytes);
            assert_eq!(content, "null");
        } else {
            panic!("Expected Bytes body");
        }
    }
}
18176
18177#[cfg(test)]
18182mod body_size_limit_tests {
18183 use super::*;
18184 use crate::request::{Body, Method};
18185 use crate::response::{ResponseBody, StatusCode};
18186
18187 fn test_context() -> RequestContext {
18188 let cx = asupersync::Cx::for_testing();
18189 RequestContext::new(cx, 1)
18190 }
18191
18192 fn test_context_with_limit(limit: usize) -> RequestContext {
18193 let cx = asupersync::Cx::for_testing();
18194 RequestContext::with_body_limit(cx, 1, limit)
18195 }
18196
18197 #[test]
18200 fn default_constants_match_expected_values() {
18201 assert_eq!(DEFAULT_JSON_LIMIT, 1024 * 1024); assert_eq!(DEFAULT_FORM_LIMIT, 1024 * 1024); assert_eq!(DEFAULT_RAW_BODY_LIMIT, 2 * 1024 * 1024); assert_eq!(crate::DEFAULT_MAX_BODY_SIZE, 1024 * 1024); }
18206
18207 #[test]
18210 fn json_body_under_limit_accepted() {
18211 use serde::Deserialize;
18212
18213 #[derive(Deserialize, Debug)]
18214 struct Msg {
18215 text: String,
18216 }
18217
18218 let ctx = test_context();
18219 let mut req = Request::new(Method::Post, "/api");
18220 req.headers_mut()
18221 .insert("content-type", b"application/json".to_vec());
18222 req.set_body(Body::Bytes(b"{\"text\":\"hello\"}".to_vec()));
18223
18224 let result = futures_executor::block_on(Json::<Msg>::from_request(&ctx, &mut req));
18225 assert!(result.is_ok());
18226 assert_eq!(result.unwrap().0.text, "hello");
18227 }
18228
18229 #[test]
18230 fn json_body_over_default_limit_rejected() {
18231 use serde::Deserialize;
18232
18233 #[derive(Deserialize)]
18234 #[allow(dead_code)]
18235 struct Data {
18236 content: String,
18237 }
18238
18239 let ctx = test_context();
18240 let mut req = Request::new(Method::Post, "/api");
18241 req.headers_mut()
18242 .insert("content-type", b"application/json".to_vec());
18243
18244 let large = "x".repeat(crate::DEFAULT_MAX_BODY_SIZE + 1);
18246 let body = format!("{{\"content\":\"{}\"}}", large);
18247 req.set_body(Body::Bytes(body.into_bytes()));
18248
18249 let result = futures_executor::block_on(Json::<Data>::from_request(&ctx, &mut req));
18250 assert!(matches!(
18251 result,
18252 Err(JsonExtractError::PayloadTooLarge { .. })
18253 ));
18254 }
18255
18256 #[test]
18257 fn json_body_exactly_at_limit_accepted() {
18258 use serde::Deserialize;
18259
18260 #[derive(Deserialize)]
18261 #[allow(dead_code)]
18262 struct Data {
18263 content: String,
18264 }
18265
18266 let ctx = test_context_with_limit(100);
18268 let mut req = Request::new(Method::Post, "/api");
18269 req.headers_mut()
18270 .insert("content-type", b"application/json".to_vec());
18271
18272 let prefix = b"{\"content\":\"";
18275 let suffix = b"\"}";
18276 let content_len = 100 - prefix.len() - suffix.len();
18277 let content: String = "a".repeat(content_len);
18278 let body = format!("{{\"content\":\"{}\"}}", content);
18279 assert_eq!(body.len(), 100);
18280
18281 req.set_body(Body::Bytes(body.into_bytes()));
18282 let result = futures_executor::block_on(Json::<Data>::from_request(&ctx, &mut req));
18283 assert!(result.is_ok(), "Body exactly at limit should be accepted");
18284 }
18285
18286 #[test]
18287 fn json_body_one_byte_over_limit_rejected() {
18288 use serde::Deserialize;
18289
18290 #[derive(Deserialize, Debug)]
18291 #[allow(dead_code)]
18292 struct Data {
18293 content: String,
18294 }
18295
18296 let ctx = test_context_with_limit(100);
18297 let mut req = Request::new(Method::Post, "/api");
18298 req.headers_mut()
18299 .insert("content-type", b"application/json".to_vec());
18300
18301 let prefix = b"{\"content\":\"";
18303 let suffix = b"\"}";
18304 let content_len = 101 - prefix.len() - suffix.len();
18305 let content: String = "a".repeat(content_len);
18306 let body = format!("{{\"content\":\"{}\"}}", content);
18307 assert_eq!(body.len(), 101);
18308
18309 req.set_body(Body::Bytes(body.into_bytes()));
18310 let result = futures_executor::block_on(Json::<Data>::from_request(&ctx, &mut req));
18311 match result {
18312 Err(JsonExtractError::PayloadTooLarge { size, limit }) => {
18313 assert_eq!(size, 101);
18314 assert_eq!(limit, 100);
18315 }
18316 other => panic!("Expected PayloadTooLarge, got {:?}", other),
18317 }
18318 }
18319
18320 #[test]
18321 fn json_custom_body_limit_via_context() {
18322 use serde::Deserialize;
18323
18324 #[derive(Deserialize)]
18325 #[allow(dead_code)]
18326 struct Data {
18327 val: String,
18328 }
18329
18330 let ctx = test_context_with_limit(50);
18332 let mut req = Request::new(Method::Post, "/api");
18333 req.headers_mut()
18334 .insert("content-type", b"application/json".to_vec());
18335
18336 let padding = "x".repeat(60);
18338 let body = format!("{{\"val\":\"{}\"}}", padding);
18339 assert!(
18340 body.len() > 50,
18341 "Body is {} bytes, expected > 50",
18342 body.len()
18343 );
18344 req.set_body(Body::Bytes(body.into_bytes()));
18345
18346 let result = futures_executor::block_on(Json::<Data>::from_request(&ctx, &mut req));
18347 assert!(matches!(
18348 result,
18349 Err(JsonExtractError::PayloadTooLarge { .. })
18350 ));
18351 }
18352
18353 #[test]
18354 fn json_large_custom_limit_accepts_big_body() {
18355 use serde::Deserialize;
18356
18357 #[derive(Deserialize, Debug)]
18358 #[allow(dead_code)]
18359 struct Data {
18360 content: String,
18361 }
18362
18363 let ctx = test_context_with_limit(10 * 1024 * 1024);
18365 let mut req = Request::new(Method::Post, "/api");
18366 req.headers_mut()
18367 .insert("content-type", b"application/json".to_vec());
18368
18369 let large = "x".repeat(2 * 1024 * 1024);
18371 let body = format!("{{\"content\":\"{}\"}}", large);
18372 req.set_body(Body::Bytes(body.into_bytes()));
18373
18374 let result = futures_executor::block_on(Json::<Data>::from_request(&ctx, &mut req));
18375 assert!(result.is_ok(), "Body under custom limit should be accepted");
18376 }
18377
18378 #[test]
18379 fn json_empty_body_accepted() {
18380 use serde::Deserialize;
18381
18382 #[derive(Deserialize, Debug)]
18383 #[allow(dead_code)]
18384 struct Data {
18385 #[serde(default)]
18386 val: i32,
18387 }
18388
18389 let ctx = test_context_with_limit(10);
18390 let mut req = Request::new(Method::Post, "/api");
18391 req.headers_mut()
18392 .insert("content-type", b"application/json".to_vec());
18393 req.set_body(Body::Empty);
18394
18395 let result = futures_executor::block_on(Json::<Data>::from_request(&ctx, &mut req));
18398 match result {
18400 Err(JsonExtractError::DeserializeError { .. }) => {} Err(JsonExtractError::PayloadTooLarge { .. }) => {
18402 panic!("Empty body should not trigger size limit")
18403 }
18404 Ok(_) => {} other => panic!("Unexpected result: {:?}", other),
18406 }
18407 }
18408
18409 #[test]
18410 fn json_payload_too_large_error_response_is_413() {
18411 let err = JsonExtractError::PayloadTooLarge {
18412 size: 2_000_000,
18413 limit: 1_000_000,
18414 };
18415 let response = err.into_response();
18416 assert_eq!(response.status(), StatusCode::PAYLOAD_TOO_LARGE);
18417 assert_eq!(response.status().as_u16(), 413);
18418 }
18419
18420 #[test]
18423 fn form_body_under_limit_accepted() {
18424 use serde::Deserialize;
18425
18426 #[derive(Deserialize, Debug)]
18427 struct Login {
18428 user: String,
18429 }
18430
18431 let ctx = test_context();
18432 let mut req = Request::new(Method::Post, "/login");
18433 req.headers_mut().insert(
18434 "content-type",
18435 b"application/x-www-form-urlencoded".to_vec(),
18436 );
18437 req.set_body(Body::Bytes(b"user=alice".to_vec()));
18438
18439 let result = futures_executor::block_on(Form::<Login>::from_request(&ctx, &mut req));
18440 assert!(result.is_ok());
18441 assert_eq!(result.unwrap().0.user, "alice");
18442 }
18443
18444 #[test]
18445 fn form_body_over_limit_rejected() {
18446 use serde::Deserialize;
18447
18448 #[derive(Deserialize)]
18449 #[allow(dead_code)]
18450 struct Data {
18451 field: String,
18452 }
18453
18454 let ctx = test_context(); let mut req = Request::new(Method::Post, "/submit");
18456 req.headers_mut().insert(
18457 "content-type",
18458 b"application/x-www-form-urlencoded".to_vec(),
18459 );
18460
18461 let large_value = "x".repeat(crate::DEFAULT_MAX_BODY_SIZE + 1);
18463 let body = format!("field={}", large_value);
18464 req.set_body(Body::Bytes(body.into_bytes()));
18465
18466 let result = futures_executor::block_on(Form::<Data>::from_request(&ctx, &mut req));
18467 assert!(matches!(
18468 result,
18469 Err(FormExtractError::PayloadTooLarge { .. })
18470 ));
18471 }
18472
18473 #[test]
18474 fn form_custom_limit_via_context() {
18475 use serde::Deserialize;
18476
18477 #[derive(Deserialize, Debug)]
18478 #[allow(dead_code)]
18479 struct Data {
18480 field: String,
18481 }
18482
18483 let ctx = test_context_with_limit(20);
18484 let mut req = Request::new(Method::Post, "/submit");
18485 req.headers_mut().insert(
18486 "content-type",
18487 b"application/x-www-form-urlencoded".to_vec(),
18488 );
18489
18490 let body = "field=abcdefghijklmnopqrstuv";
18492 assert!(body.len() > 20);
18493 req.set_body(Body::Bytes(body.as_bytes().to_vec()));
18494
18495 let result = futures_executor::block_on(Form::<Data>::from_request(&ctx, &mut req));
18496 match result {
18497 Err(FormExtractError::PayloadTooLarge { size, limit }) => {
18498 assert_eq!(limit, 20);
18499 assert!(size > 20);
18500 }
18501 other => panic!("Expected PayloadTooLarge, got {:?}", other),
18502 }
18503 }
18504
18505 #[test]
18506 fn form_payload_too_large_error_response_is_413() {
18507 let err = FormExtractError::PayloadTooLarge {
18508 size: 2_000_000,
18509 limit: 1_000_000,
18510 };
18511 let response = err.into_response();
18512 assert_eq!(response.status(), StatusCode::PAYLOAD_TOO_LARGE);
18513 }
18514
18515 #[test]
18518 fn bytes_body_under_limit_accepted() {
18519 let ctx = test_context();
18520 let mut req = Request::new(Method::Post, "/upload");
18521 req.set_body(Body::Bytes(b"small payload".to_vec()));
18522
18523 let result = futures_executor::block_on(Bytes::from_request(&ctx, &mut req));
18524 assert!(result.is_ok());
18525 assert_eq!(result.unwrap().as_slice(), b"small payload");
18526 }
18527
18528 #[test]
18529 fn bytes_body_over_default_limit_rejected() {
18530 let ctx = test_context();
18531 let mut req = Request::new(Method::Post, "/upload");
18532 let large_body = vec![0u8; DEFAULT_RAW_BODY_LIMIT + 1];
18533 req.set_body(Body::Bytes(large_body));
18534
18535 let result = futures_executor::block_on(Bytes::from_request(&ctx, &mut req));
18536 match result {
18537 Err(RawBodyError::PayloadTooLarge { size, limit }) => {
18538 assert_eq!(size, DEFAULT_RAW_BODY_LIMIT + 1);
18539 assert_eq!(limit, DEFAULT_RAW_BODY_LIMIT);
18540 }
18541 other => panic!("Expected PayloadTooLarge, got {:?}", other),
18542 }
18543 }
18544
18545 #[test]
18546 fn bytes_custom_limit_via_extension() {
18547 let ctx = test_context();
18548 let mut req = Request::new(Method::Post, "/upload");
18549 req.insert_extension(RawBodyConfig::new().limit(50));
18550 req.set_body(Body::Bytes(vec![0u8; 80]));
18551
18552 let result = futures_executor::block_on(Bytes::from_request(&ctx, &mut req));
18553 match result {
18554 Err(RawBodyError::PayloadTooLarge { size, limit }) => {
18555 assert_eq!(size, 80);
18556 assert_eq!(limit, 50);
18557 }
18558 other => panic!("Expected PayloadTooLarge, got {:?}", other),
18559 }
18560 }
18561
18562 #[test]
18563 fn bytes_custom_limit_accepts_body_under() {
18564 let ctx = test_context();
18565 let mut req = Request::new(Method::Post, "/upload");
18566 req.insert_extension(RawBodyConfig::new().limit(200));
18567 req.set_body(Body::Bytes(vec![0u8; 100]));
18568
18569 let result = futures_executor::block_on(Bytes::from_request(&ctx, &mut req));
18570 assert!(result.is_ok());
18571 assert_eq!(result.unwrap().len(), 100);
18572 }
18573
18574 #[test]
18575 fn bytes_empty_body_always_accepted() {
18576 let ctx = test_context();
18577 let mut req = Request::new(Method::Post, "/upload");
18578 req.insert_extension(RawBodyConfig::new().limit(0));
18579 req.set_body(Body::Empty);
18580
18581 let result = futures_executor::block_on(Bytes::from_request(&ctx, &mut req));
18582 assert!(result.is_ok());
18583 assert!(result.unwrap().is_empty());
18584 }
18585
18586 #[test]
18587 fn bytes_payload_too_large_error_response_is_413() {
18588 let err = RawBodyError::PayloadTooLarge {
18589 size: 5_000_000,
18590 limit: 2_000_000,
18591 };
18592 let response = err.into_response();
18593 assert_eq!(response.status(), StatusCode::PAYLOAD_TOO_LARGE);
18594 }
18595
18596 #[test]
18599 fn string_body_over_limit_rejected() {
18600 let ctx = test_context();
18601 let mut req = Request::new(Method::Post, "/text");
18602 req.insert_extension(RawBodyConfig::new().limit(10));
18603 req.set_body(Body::Bytes(b"this is longer than ten bytes".to_vec()));
18604
18605 let result = futures_executor::block_on(StringBody::from_request(&ctx, &mut req));
18606 assert!(matches!(result, Err(RawBodyError::PayloadTooLarge { .. })));
18607 }
18608
18609 #[test]
18610 fn string_body_under_limit_accepted() {
18611 let ctx = test_context();
18612 let mut req = Request::new(Method::Post, "/text");
18613 req.insert_extension(RawBodyConfig::new().limit(100));
18614 req.set_body(Body::Bytes(b"short".to_vec()));
18615
18616 let result = futures_executor::block_on(StringBody::from_request(&ctx, &mut req));
18617 assert!(result.is_ok());
18618 assert_eq!(result.unwrap().as_str(), "short");
18619 }
18620
18621 #[test]
18624 fn raw_body_config_default() {
18625 let config = RawBodyConfig::default();
18626 assert_eq!(config.get_limit(), DEFAULT_RAW_BODY_LIMIT);
18627 assert_eq!(config.get_limit(), 2 * 1024 * 1024);
18628 }
18629
18630 #[test]
18631 fn raw_body_config_builder() {
18632 let config = RawBodyConfig::new().limit(500);
18633 assert_eq!(config.get_limit(), 500);
18634 }
18635
18636 #[test]
18639 fn json_config_default() {
18640 let config = JsonConfig::default();
18641 assert_eq!(config.get_limit(), DEFAULT_JSON_LIMIT);
18642 }
18643
18644 #[test]
18645 fn json_config_builder() {
18646 let config = JsonConfig::new().limit(2048);
18647 assert_eq!(config.get_limit(), 2048);
18648 }
18649
18650 #[test]
18653 fn form_config_default() {
18654 let config = FormConfig::default();
18655 assert_eq!(config.get_limit(), DEFAULT_FORM_LIMIT);
18656 }
18657
18658 #[test]
18659 fn form_config_builder() {
18660 let config = FormConfig::new().limit(4096);
18661 assert_eq!(config.get_limit(), 4096);
18662 }
18663
18664 #[test]
18667 fn json_uses_context_body_limit_not_json_config() {
18668 use serde::Deserialize;
18670
18671 #[derive(Deserialize)]
18672 #[allow(dead_code)]
18673 struct Data {
18674 val: String,
18675 }
18676
18677 let ctx = test_context_with_limit(50);
18679 let mut req = Request::new(Method::Post, "/api");
18680 req.headers_mut()
18681 .insert("content-type", b"application/json".to_vec());
18682
18683 let body = "{\"val\":\"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa\"}";
18685 assert!(body.len() > 50);
18686 req.set_body(Body::Bytes(body.as_bytes().to_vec()));
18687
18688 let result = futures_executor::block_on(Json::<Data>::from_request(&ctx, &mut req));
18689 assert!(
18690 matches!(result, Err(JsonExtractError::PayloadTooLarge { .. })),
18691 "Json should use context body limit (50), not default"
18692 );
18693 }
18694
18695 #[test]
18696 fn form_uses_context_body_limit() {
18697 use serde::Deserialize;
18698
18699 #[derive(Deserialize)]
18700 #[allow(dead_code)]
18701 struct Data {
18702 val: String,
18703 }
18704
18705 let ctx = test_context_with_limit(30);
18707 let mut req = Request::new(Method::Post, "/form");
18708 req.headers_mut().insert(
18709 "content-type",
18710 b"application/x-www-form-urlencoded".to_vec(),
18711 );
18712
18713 let body = "val=aaaaaaaaaaaaaaaaaaaaaaaaaaaaaa";
18715 assert!(body.len() > 30);
18716 req.set_body(Body::Bytes(body.as_bytes().to_vec()));
18717
18718 let result = futures_executor::block_on(Form::<Data>::from_request(&ctx, &mut req));
18719 assert!(
18720 matches!(result, Err(FormExtractError::PayloadTooLarge { .. })),
18721 "Form should use context body limit"
18722 );
18723 }
18724
18725 #[test]
18726 fn bytes_uses_extension_config_not_context_limit() {
18727 let ctx = test_context_with_limit(10); let mut req = Request::new(Method::Post, "/upload");
18730 req.insert_extension(RawBodyConfig::new().limit(1000));
18732 req.set_body(Body::Bytes(vec![0u8; 500]));
18733
18734 let result = futures_executor::block_on(Bytes::from_request(&ctx, &mut req));
18735 assert!(
18737 result.is_ok(),
18738 "Bytes should use RawBodyConfig from extension, not context limit"
18739 );
18740 }
18741
18742 #[test]
18743 fn bytes_without_config_uses_default_raw_limit() {
18744 let ctx = test_context_with_limit(10); let mut req = Request::new(Method::Post, "/upload");
18747 req.set_body(Body::Bytes(vec![0u8; 500]));
18749
18750 let result = futures_executor::block_on(Bytes::from_request(&ctx, &mut req));
18751 assert!(result.is_ok());
18753 }
18754
18755 #[test]
18758 fn json_error_response_body_contains_size_info() {
18759 let err = JsonExtractError::PayloadTooLarge {
18760 size: 2_500_000,
18761 limit: 1_048_576,
18762 };
18763 let response = err.into_response();
18764 if let ResponseBody::Bytes(bytes) = response.body_ref() {
18765 let body = std::str::from_utf8(bytes).unwrap();
18766 assert!(
18767 body.contains("2500000"),
18768 "Error should contain actual size, got: {}",
18769 body
18770 );
18771 assert!(
18772 body.contains("1048576"),
18773 "Error should contain limit, got: {}",
18774 body
18775 );
18776 } else {
18777 panic!("Expected Bytes body");
18778 }
18779 }
18780
18781 #[test]
18782 fn form_error_response_body_contains_size_info() {
18783 let err = FormExtractError::PayloadTooLarge {
18784 size: 3_000_000,
18785 limit: 1_048_576,
18786 };
18787 let response = err.into_response();
18788 if let ResponseBody::Bytes(bytes) = response.body_ref() {
18789 let body = std::str::from_utf8(bytes).unwrap();
18790 assert!(
18791 body.contains("3000000"),
18792 "Error should contain actual size, got: {}",
18793 body
18794 );
18795 assert!(
18796 body.contains("1048576"),
18797 "Error should contain limit, got: {}",
18798 body
18799 );
18800 } else {
18801 panic!("Expected Bytes body");
18802 }
18803 }
18804
18805 #[test]
18806 fn raw_body_error_response_contains_size_info() {
18807 let err = RawBodyError::PayloadTooLarge {
18808 size: 5_000_000,
18809 limit: 2_097_152,
18810 };
18811 let response = err.into_response();
18812 if let ResponseBody::Bytes(bytes) = response.body_ref() {
18813 let body = std::str::from_utf8(bytes).unwrap();
18814 assert!(
18815 body.contains("5000000"),
18816 "Error should contain actual size, got: {}",
18817 body
18818 );
18819 assert!(
18820 body.contains("2097152"),
18821 "Error should contain limit, got: {}",
18822 body
18823 );
18824 } else {
18825 panic!("Expected Bytes body");
18826 }
18827 }
18828
18829 #[test]
18832 fn json_streaming_body_rejected() {
18833 use serde::Deserialize;
18834
18835 #[derive(Deserialize)]
18836 #[allow(dead_code)]
18837 struct Data {
18838 val: i32,
18839 }
18840
18841 let ctx = test_context();
18842 let mut req = Request::new(Method::Post, "/api");
18843 req.headers_mut()
18844 .insert("content-type", b"application/json".to_vec());
18845
18846 let stream = asupersync::stream::iter(
18847 vec![Ok(b"chunk".to_vec())]
18848 .into_iter()
18849 .map(|r: Result<Vec<u8>, crate::request::RequestBodyStreamError>| r),
18850 );
18851 req.set_body(Body::streaming(stream));
18852
18853 let result = futures_executor::block_on(Json::<Data>::from_request(&ctx, &mut req));
18854 assert!(matches!(
18855 result,
18856 Err(JsonExtractError::StreamingNotSupported)
18857 ));
18858 }
18859
18860 #[test]
18861 fn form_streaming_body_rejected() {
18862 use serde::Deserialize;
18863
18864 #[derive(Deserialize)]
18865 #[allow(dead_code)]
18866 struct Data {
18867 field: String,
18868 }
18869
18870 let ctx = test_context();
18871 let mut req = Request::new(Method::Post, "/form");
18872 req.headers_mut().insert(
18873 "content-type",
18874 b"application/x-www-form-urlencoded".to_vec(),
18875 );
18876
18877 let stream = asupersync::stream::iter(
18878 vec![Ok(b"chunk".to_vec())]
18879 .into_iter()
18880 .map(|r: Result<Vec<u8>, crate::request::RequestBodyStreamError>| r),
18881 );
18882 req.set_body(Body::streaming(stream));
18883
18884 let result = futures_executor::block_on(Form::<Data>::from_request(&ctx, &mut req));
18885 assert!(matches!(
18886 result,
18887 Err(FormExtractError::StreamingNotSupported)
18888 ));
18889 }
18890
18891 #[test]
18892 fn bytes_streaming_body_rejected() {
18893 let ctx = test_context();
18894 let mut req = Request::new(Method::Post, "/upload");
18895
18896 let stream = asupersync::stream::iter(
18897 vec![Ok(b"chunk".to_vec())]
18898 .into_iter()
18899 .map(|r: Result<Vec<u8>, crate::request::RequestBodyStreamError>| r),
18900 );
18901 req.set_body(Body::streaming(stream));
18902
18903 let result = futures_executor::block_on(Bytes::from_request(&ctx, &mut req));
18904 assert!(matches!(result, Err(RawBodyError::StreamingNotSupported)));
18905 }
18906
18907 #[test]
18912 fn digest_auth_extraction() {
18913 let ctx = test_context();
18914 let mut req = Request::new(Method::Get, "/protected");
18915 req.headers_mut().insert(
18916 "authorization",
18917 b"Digest username=\"alice\", realm=\"test\", nonce=\"abc123\"".to_vec(),
18918 );
18919 let result = futures_executor::block_on(DigestAuth::from_request(&ctx, &mut req));
18920 let auth = result.unwrap();
18921 assert!(auth.credentials().contains("username=\"alice\""));
18922 }
18923
18924 #[test]
18925 fn digest_auth_param_extraction() {
18926 let auth = DigestAuth::new("username=\"alice\", realm=\"test\", nonce=\"abc123\"");
18927 assert_eq!(auth.param("username"), Some("alice"));
18928 assert_eq!(auth.param("realm"), Some("test"));
18929 assert_eq!(auth.param("nonce"), Some("abc123"));
18930 assert_eq!(auth.param("nonexistent"), None);
18931 }
18932
18933 #[test]
18934 fn digest_auth_param_no_substring_match() {
18935 let auth = DigestAuth::new("username=\"alice\", realm=\"test\", nonce=\"abc123\"");
18937
18938 assert_eq!(auth.param("name"), None);
18940
18941 assert_eq!(auth.param("realm"), Some("test"));
18943
18944 assert_eq!(auth.param("e"), None); assert_eq!(auth.param("c"), None); }
18948
18949 #[test]
18950 fn digest_auth_param_unquoted_values() {
18951 let auth = DigestAuth::new("qop=auth, nc=00000001, cnonce=\"xyz\"");
18953 assert_eq!(auth.param("qop"), Some("auth"));
18954 assert_eq!(auth.param("nc"), Some("00000001"));
18955 assert_eq!(auth.param("cnonce"), Some("xyz"));
18956
18957 assert_eq!(auth.param("c"), None);
18959 }
18960
18961 #[test]
18962 fn digest_auth_param_at_start() {
18963 let auth = DigestAuth::new("realm=\"test\", username=\"bob\"");
18965 assert_eq!(auth.param("realm"), Some("test"));
18966 assert_eq!(auth.param("username"), Some("bob"));
18967 }
18968
18969 #[test]
18970 fn digest_auth_missing_header() {
18971 let ctx = test_context();
18972 let mut req = Request::new(Method::Get, "/protected");
18973 let result = futures_executor::block_on(DigestAuth::from_request(&ctx, &mut req));
18974 assert!(matches!(result, Err(DigestAuthError::MissingHeader)));
18975 }
18976
18977 #[test]
18978 fn digest_auth_wrong_scheme() {
18979 let ctx = test_context();
18980 let mut req = Request::new(Method::Get, "/protected");
18981 req.headers_mut()
18982 .insert("authorization", b"Bearer token123".to_vec());
18983 let result = futures_executor::block_on(DigestAuth::from_request(&ctx, &mut req));
18984 assert!(matches!(result, Err(DigestAuthError::InvalidScheme)));
18985 }
18986
18987 #[test]
18988 fn digest_auth_error_response_401() {
18989 let resp = DigestAuthError::MissingHeader.into_response();
18990 assert_eq!(resp.status().as_u16(), 401);
18991 let has_www_auth = resp
18992 .headers()
18993 .iter()
18994 .any(|(n, v)| n == "www-authenticate" && v == b"Digest");
18995 assert!(has_www_auth);
18996 }
18997
18998 #[test]
18999 fn digest_auth_case_insensitive() {
19000 let ctx = test_context();
19001 let mut req = Request::new(Method::Get, "/protected");
19002 req.headers_mut()
19003 .insert("authorization", b"digest username=\"bob\"".to_vec());
19004 let result = futures_executor::block_on(DigestAuth::from_request(&ctx, &mut req));
19005 assert!(result.is_ok());
19006 }
19007}