1use std::collections::HashMap;
2
3use ic_http_certification::{HeaderField, Method};
4
5#[derive(Debug)]
7pub enum JsonBodyError {
8 Utf8(std::str::Utf8Error),
10 Json(serde_json::Error),
12}
13
14impl std::fmt::Display for JsonBodyError {
15 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
16 match self {
17 Self::Utf8(e) => write!(f, "body is not valid UTF-8: {e}"),
18 Self::Json(e) => write!(f, "JSON deserialization failed: {e}"),
19 }
20 }
21}
22
23impl std::error::Error for JsonBodyError {
24 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
25 match self {
26 Self::Utf8(e) => Some(e),
27 Self::Json(e) => Some(e),
28 }
29 }
30}
31
32#[derive(Debug)]
34pub enum FormBodyError {
35 Utf8(std::str::Utf8Error),
37 Deserialize(serde_urlencoded::de::Error),
39}
40
41impl std::fmt::Display for FormBodyError {
42 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
43 match self {
44 Self::Utf8(e) => write!(f, "body is not valid UTF-8: {e}"),
45 Self::Deserialize(e) => write!(f, "form deserialization failed: {e}"),
46 }
47 }
48}
49
50impl std::error::Error for FormBodyError {
51 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
52 match self {
53 Self::Utf8(e) => Some(e),
54 Self::Deserialize(e) => Some(e),
55 }
56 }
57}
58
/// Decoded query-string parameters: a map from parameter name to value,
/// both as owned strings. This is the return type of [`parse_query`].
pub type QueryParams = HashMap<String, String>;
61
62pub struct RouteContext<P, S = ()> {
111 pub params: P,
113 pub search: S,
118 pub query: QueryParams,
120 pub method: Method,
122 pub headers: Vec<HeaderField>,
124 pub body: Vec<u8>,
126 pub url: String,
128 pub wildcard: Option<String>,
134}
135
136impl<P, S> RouteContext<P, S> {
137 pub fn header(&self, name: &str) -> Option<&str> {
148 self.headers
149 .iter()
150 .find(|(k, _)| k.eq_ignore_ascii_case(name))
151 .map(|(_, v)| v.as_str())
152 }
153
154 pub fn body_to_str(&self) -> Result<&str, std::str::Utf8Error> {
168 std::str::from_utf8(&self.body)
169 }
170
171 pub fn json<T: serde::de::DeserializeOwned>(&self) -> Result<T, JsonBodyError> {
192 let text = std::str::from_utf8(&self.body).map_err(JsonBodyError::Utf8)?;
193 serde_json::from_str(text).map_err(JsonBodyError::Json)
194 }
195
196 pub fn form_data(&self) -> HashMap<String, String> {
209 parse_form_body(&self.body)
210 }
211
212 pub fn form<T: serde::de::DeserializeOwned>(&self) -> Result<T, FormBodyError> {
234 let text = std::str::from_utf8(&self.body).map_err(FormBodyError::Utf8)?;
235 serde_urlencoded::from_str(text).map_err(FormBodyError::Deserialize)
236 }
237}
238
239pub fn parse_query(url: &str) -> QueryParams {
246 let query_str = match url.split_once('?') {
247 Some((_, q)) => q,
248 None => return QueryParams::new(),
249 };
250
251 let query_str = query_str.split_once('#').map_or(query_str, |(q, _)| q);
253
254 query_str
255 .split('&')
256 .filter(|s| !s.is_empty())
257 .filter_map(|pair| {
258 let (key, value) = pair.split_once('=')?;
259 Some((url_decode(key).into_owned(), url_decode(value).into_owned()))
260 })
261 .collect()
262}
263
264pub fn url_decode(input: &str) -> std::borrow::Cow<'_, str> {
283 if !input.contains('%') && !input.contains('+') {
284 return std::borrow::Cow::Borrowed(input);
285 }
286
287 let mut bytes = Vec::with_capacity(input.len());
288 let mut chars = input.bytes();
289 while let Some(b) = chars.next() {
290 match b {
291 b'+' => bytes.push(b' '),
292 b'%' => {
293 let hi = chars.next().and_then(hex_val);
294 let lo = chars.next().and_then(hex_val);
295 match (hi, lo) {
296 (Some(h), Some(l)) => bytes.push(h << 4 | l),
297 _ => {
298 bytes.push(b'%');
300 }
301 }
302 }
303 _ => bytes.push(b),
304 }
305 }
306
307 String::from_utf8(bytes)
308 .map(std::borrow::Cow::Owned)
309 .unwrap_or_else(|e| {
310 std::borrow::Cow::Owned(String::from_utf8_lossy(e.as_bytes()).into_owned())
311 })
312}
313
314pub fn deserialize_search_params<S>(query_str: &str) -> S
323where
324 S: serde::de::DeserializeOwned + Default,
325{
326 let qs = query_str.strip_prefix('?').unwrap_or(query_str);
329 serde_urlencoded::from_str(qs).unwrap_or_default()
330}
331
332pub fn parse_form_body(body: &[u8]) -> HashMap<String, String> {
351 let input = String::from_utf8_lossy(body);
352 input
353 .split('&')
354 .filter(|s| !s.is_empty())
355 .filter_map(|pair| {
356 let (key, value) = pair.split_once('=')?;
357 Some((url_decode(key).into_owned(), url_decode(value).into_owned()))
358 })
359 .collect()
360}
361
/// Returns the numeric value (0–15) of an ASCII hex digit, accepting both
/// upper- and lower-case letters, or `None` for any other byte.
fn hex_val(b: u8) -> Option<u8> {
    // `to_digit(16)` yields 0..=15, so the narrowing cast is lossless.
    char::from(b).to_digit(16).map(|digit| digit as u8)
}
370
#[cfg(test)]
mod tests {
    use super::*;
    use std::borrow::Cow;

    // ---------------------------------------------------------------
    // parse_query
    // ---------------------------------------------------------------

    #[test]
    fn parse_query_basic() {
        let params = parse_query("http://example.com/path?page=3&filter=active");
        assert_eq!(params["page"], "3");
        assert_eq!(params["filter"], "active");
    }

    #[test]
    fn parse_query_empty_url() {
        assert!(parse_query("").is_empty());
    }

    #[test]
    fn parse_query_no_query_string() {
        assert!(parse_query("/path/to/resource").is_empty());
    }

    #[test]
    fn parse_query_empty_query_string() {
        assert!(parse_query("/path?").is_empty());
    }

    #[test]
    fn parse_query_with_fragment() {
        // The `#section` fragment must not leak into the value.
        let params = parse_query("/path?page=1#section");
        assert_eq!(params["page"], "1");
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn parse_query_url_encoded_values() {
        let params = parse_query("/search?q=hello+world&name=foo%20bar");
        assert_eq!(params["q"], "hello world");
        assert_eq!(params["name"], "foo bar");
    }

    #[test]
    fn parse_query_skips_malformed_pairs() {
        // `bad` has no `=` and is dropped; its neighbours are kept.
        let params = parse_query("/path?good=yes&bad&also=fine");
        assert_eq!(params["good"], "yes");
        assert_eq!(params["also"], "fine");
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn parse_query_empty_value() {
        let params = parse_query("/path?key=");
        assert_eq!(params["key"], "");
    }

    #[test]
    fn parse_query_multiple_equals() {
        // Only the first `=` separates key from value.
        let params = parse_query("/path?expr=a=b");
        assert_eq!(params["expr"], "a=b");
    }

    #[test]
    fn parse_query_bare_query_string() {
        let params = parse_query("?page=3&filter=active");
        assert_eq!(params["page"], "3");
        assert_eq!(params["filter"], "active");
    }

    #[test]
    fn parse_query_empty_string_returns_empty_hashmap() {
        assert!(parse_query("").is_empty());
    }

    // ---------------------------------------------------------------
    // deserialize_search_params
    // ---------------------------------------------------------------

    #[test]
    fn deserialize_search_params_valid() {
        #[derive(serde::Deserialize, Default, Debug)]
        struct Sp {
            page: Option<u32>,
            filter: Option<String>,
        }

        let parsed = deserialize_search_params::<Sp>("page=3&filter=active");
        assert_eq!(parsed.page, Some(3));
        assert_eq!(parsed.filter.as_deref(), Some("active"));
    }

    #[test]
    fn deserialize_search_params_type_mismatch_falls_back() {
        #[derive(serde::Deserialize, Default, Debug)]
        struct Sp {
            page: Option<u32>,
            filter: Option<String>,
        }

        // One bad field makes the whole struct fall back to `Default`.
        let parsed = deserialize_search_params::<Sp>("page=abc&filter=active");
        assert_eq!(parsed.page, None);
        assert_eq!(parsed.filter, None);
    }

    #[test]
    fn deserialize_search_params_empty_string() {
        #[derive(serde::Deserialize, Default, Debug)]
        struct Sp {
            page: Option<u32>,
        }

        let parsed = deserialize_search_params::<Sp>("");
        assert_eq!(parsed.page, None);
    }

    #[test]
    fn deserialize_search_params_missing_fields_default_to_none() {
        #[derive(serde::Deserialize, Default, Debug)]
        struct Sp {
            page: Option<u32>,
            filter: Option<String>,
            limit: Option<u32>,
        }

        let parsed = deserialize_search_params::<Sp>("page=5");
        assert_eq!(parsed.page, Some(5));
        assert_eq!(parsed.filter, None);
        assert_eq!(parsed.limit, None);
    }

    #[test]
    fn deserialize_search_params_with_leading_question_mark() {
        #[derive(serde::Deserialize, Default, Debug)]
        struct Sp {
            page: Option<u32>,
        }

        let parsed = deserialize_search_params::<Sp>("?page=7");
        assert_eq!(parsed.page, Some(7));
    }

    #[test]
    fn deserialize_search_params_malformed_encoding_does_not_panic() {
        #[derive(serde::Deserialize, Default, Debug)]
        struct Sp {
            q: Option<String>,
        }

        // Only the absence of a panic is checked; the value is unspecified.
        let parsed = deserialize_search_params::<Sp>("q=%ZZ");
        let _ = parsed.q;
    }

    // ---------------------------------------------------------------
    // url_decode
    // ---------------------------------------------------------------

    #[test]
    fn url_decode_percent_encoding() {
        assert_eq!(url_decode("hello%20world"), "hello world");
    }

    #[test]
    fn url_decode_plus_as_space() {
        assert_eq!(url_decode("a+b"), "a b");
    }

    #[test]
    fn url_decode_malformed_passthrough() {
        // `%en` is not a valid escape: the `%` is kept, `en` is consumed.
        assert_eq!(url_decode("no%encoding"), "no%coding");
    }

    #[test]
    fn url_decode_plain_passthrough() {
        let decoded = url_decode("plain");
        assert_eq!(decoded, "plain");
        // Nothing to decode, so no allocation should have happened.
        assert!(matches!(decoded, Cow::Borrowed(_)));
    }

    #[test]
    fn url_decode_invalid_utf8_returns_valid_string() {
        let decoded = url_decode("%FF%FE");
        assert!(!decoded.is_empty());
        assert!(decoded.contains('\u{FFFD}'));
    }

    #[test]
    fn url_decode_trailing_percent() {
        assert_eq!(url_decode("abc%"), "abc%");
    }

    #[test]
    fn url_decode_only_percent() {
        assert_eq!(url_decode("%"), "%");
    }

    #[test]
    fn url_decode_percent_one_hex_then_eof() {
        // The lone hex digit after `%` is consumed by the failed escape.
        assert_eq!(url_decode("abc%4"), "abc%");
    }

    #[test]
    fn url_decode_null_byte() {
        let decoded = url_decode("%00");
        assert_eq!(decoded, "\0");
        assert_eq!(decoded.len(), 1);
    }

    #[test]
    fn url_decode_double_encoded() {
        // Decoding is a single pass: `%25` becomes `%`, `20` stays literal.
        assert_eq!(url_decode("%2520"), "%20");
    }

    #[test]
    fn url_decode_empty_string() {
        let decoded = url_decode("");
        assert_eq!(decoded, "");
        assert!(matches!(decoded, Cow::Borrowed(_)));
    }

    // ---------------------------------------------------------------
    // parse_form_body
    // ---------------------------------------------------------------

    #[test]
    fn parse_form_body_basic_pairs() {
        let form = parse_form_body(b"name=Alice&age=30");
        assert_eq!(form["name"], "Alice");
        assert_eq!(form["age"], "30");
    }

    #[test]
    fn parse_form_body_plus_decoding() {
        let form = parse_form_body(b"q=hello+world");
        assert_eq!(form["q"], "hello world");
    }

    #[test]
    fn parse_form_body_empty() {
        assert!(parse_form_body(b"").is_empty());
    }

    #[test]
    fn parse_form_body_encoded_values() {
        let form = parse_form_body(b"key=val%26ue");
        assert_eq!(form["key"], "val&ue");
    }

    // ---------------------------------------------------------------
    // RouteContext helpers
    // ---------------------------------------------------------------

    /// Builds a minimal GET context carrying only the given headers and body.
    fn test_ctx(headers: Vec<(String, String)>, body: Vec<u8>) -> RouteContext<()> {
        RouteContext {
            params: (),
            search: (),
            query: QueryParams::new(),
            method: Method::GET,
            headers,
            body,
            url: String::new(),
            wildcard: None,
        }
    }

    #[test]
    fn header_case_insensitive() {
        let headers = vec![("authorization".to_string(), "Bearer x".to_string())];
        let ctx = test_ctx(headers, Vec::new());
        assert_eq!(ctx.header("Authorization"), Some("Bearer x"));
        assert_eq!(ctx.header("authorization"), Some("Bearer x"));
        assert_eq!(ctx.header("AUTHORIZATION"), Some("Bearer x"));
    }

    #[test]
    fn header_missing() {
        let ctx = test_ctx(Vec::new(), Vec::new());
        assert_eq!(ctx.header("x-missing"), None);
    }

    #[test]
    fn header_first_match_wins() {
        let headers = vec![
            ("x-custom".to_string(), "first".to_string()),
            ("x-custom".to_string(), "second".to_string()),
        ];
        let ctx = test_ctx(headers, Vec::new());
        assert_eq!(ctx.header("x-custom"), Some("first"));
    }

    #[test]
    fn body_to_str_valid_utf8() {
        let ctx = test_ctx(Vec::new(), b"hello".to_vec());
        assert_eq!(ctx.body_to_str(), Ok("hello"));
    }

    #[test]
    fn body_to_str_invalid_utf8() {
        let ctx = test_ctx(Vec::new(), vec![0xff, 0xfe]);
        assert!(ctx.body_to_str().is_err());
    }

    #[test]
    fn body_to_str_empty() {
        let ctx = test_ctx(Vec::new(), Vec::new());
        assert_eq!(ctx.body_to_str(), Ok(""));
    }

    // ---------------------------------------------------------------
    // RouteContext::json
    // ---------------------------------------------------------------

    #[test]
    fn json_valid() {
        #[derive(serde::Deserialize, Debug, PartialEq)]
        struct Item {
            name: String,
        }

        let ctx = test_ctx(Vec::new(), br#"{"name":"test"}"#.to_vec());
        let expected = Item {
            name: "test".to_string(),
        };
        assert_eq!(ctx.json::<Item>().unwrap(), expected);
    }

    #[test]
    fn json_invalid_json() {
        #[derive(serde::Deserialize)]
        struct Item {
            #[allow(dead_code)]
            name: String,
        }

        let ctx = test_ctx(Vec::new(), b"{invalid}".to_vec());
        assert!(matches!(ctx.json::<Item>(), Err(JsonBodyError::Json(_))));
    }

    #[test]
    fn json_invalid_utf8() {
        #[derive(serde::Deserialize)]
        struct Item {
            #[allow(dead_code)]
            name: String,
        }

        let ctx = test_ctx(Vec::new(), vec![0xff, 0xfe]);
        assert!(matches!(ctx.json::<Item>(), Err(JsonBodyError::Utf8(_))));
    }

    #[test]
    fn json_empty_body() {
        #[derive(serde::Deserialize)]
        struct Item {
            #[allow(dead_code)]
            name: String,
        }

        // An empty body is valid UTF-8 but not valid JSON.
        let ctx = test_ctx(Vec::new(), Vec::new());
        assert!(matches!(ctx.json::<Item>(), Err(JsonBodyError::Json(_))));
    }

    // ---------------------------------------------------------------
    // RouteContext::form_data
    // ---------------------------------------------------------------

    #[test]
    fn form_data_basic() {
        let ctx = test_ctx(Vec::new(), b"name=Alice&age=30".to_vec());
        let form = ctx.form_data();
        assert_eq!(form["name"], "Alice");
        assert_eq!(form["age"], "30");
    }

    #[test]
    fn form_data_empty() {
        let ctx = test_ctx(Vec::new(), Vec::new());
        assert!(ctx.form_data().is_empty());
    }

    #[test]
    fn form_data_url_encoded() {
        let ctx = test_ctx(
            Vec::new(),
            b"greeting=hello+world&path=%2Ffoo%2Fbar".to_vec(),
        );
        let form = ctx.form_data();
        assert_eq!(form["greeting"], "hello world");
        assert_eq!(form["path"], "/foo/bar");
    }

    // ---------------------------------------------------------------
    // RouteContext::form
    // ---------------------------------------------------------------

    #[test]
    fn form_valid() {
        #[derive(serde::Deserialize, Debug, PartialEq)]
        struct Comment {
            author: String,
            body: String,
        }

        let ctx = test_ctx(Vec::new(), b"author=Alice&body=hello".to_vec());
        let expected = Comment {
            author: "Alice".to_string(),
            body: "hello".to_string(),
        };
        assert_eq!(ctx.form::<Comment>().unwrap(), expected);
    }

    #[test]
    fn form_missing_field() {
        #[derive(serde::Deserialize)]
        struct Comment {
            #[allow(dead_code)]
            author: String,
            #[allow(dead_code)]
            body: String,
        }

        let ctx = test_ctx(Vec::new(), b"author=Alice".to_vec());
        assert!(matches!(
            ctx.form::<Comment>(),
            Err(FormBodyError::Deserialize(_))
        ));
    }

    #[test]
    fn form_invalid_utf8() {
        #[derive(serde::Deserialize)]
        struct Comment {
            #[allow(dead_code)]
            author: String,
        }

        let ctx = test_ctx(Vec::new(), vec![0xff, 0xfe]);
        assert!(matches!(ctx.form::<Comment>(), Err(FormBodyError::Utf8(_))));
    }

    #[test]
    fn form_empty_body_with_optional_fields() {
        #[derive(serde::Deserialize, Debug, PartialEq)]
        struct Opts {
            name: Option<String>,
        }

        let ctx = test_ctx(Vec::new(), Vec::new());
        assert_eq!(ctx.form::<Opts>().unwrap(), Opts { name: None });
    }
}