use std::collections::HashMap;
use ic_http_certification::{HeaderField, Method};
/// Error returned by [`RouteContext::json`] when the request body cannot be decoded.
#[derive(Debug)]
pub enum JsonBodyError {
    /// The body bytes are not valid UTF-8.
    Utf8(std::str::Utf8Error),
    /// The body is valid UTF-8 but `serde_json` failed to deserialize it.
    Json(serde_json::Error),
}

impl std::fmt::Display for JsonBodyError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            JsonBodyError::Utf8(ref e) => write!(f, "body is not valid UTF-8: {e}"),
            JsonBodyError::Json(ref e) => write!(f, "JSON deserialization failed: {e}"),
        }
    }
}

impl std::error::Error for JsonBodyError {
    /// Exposes the wrapped UTF-8 or `serde_json` error as the cause.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        let cause: &(dyn std::error::Error + 'static) = match *self {
            JsonBodyError::Utf8(ref e) => e,
            JsonBodyError::Json(ref e) => e,
        };
        Some(cause)
    }
}
/// Error returned by [`RouteContext::form`] when the request body cannot be decoded.
#[derive(Debug)]
pub enum FormBodyError {
    /// The body bytes are not valid UTF-8.
    Utf8(std::str::Utf8Error),
    /// The body is valid UTF-8 but `serde_urlencoded` failed to deserialize it.
    Deserialize(serde_urlencoded::de::Error),
}

impl std::fmt::Display for FormBodyError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            FormBodyError::Utf8(ref e) => write!(f, "body is not valid UTF-8: {e}"),
            FormBodyError::Deserialize(ref e) => write!(f, "form deserialization failed: {e}"),
        }
    }
}

impl std::error::Error for FormBodyError {
    /// Exposes the wrapped UTF-8 or `serde_urlencoded` error as the cause.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        let cause: &(dyn std::error::Error + 'static) = match *self {
            FormBodyError::Utf8(ref e) => e,
            FormBodyError::Deserialize(ref e) => e,
        };
        Some(cause)
    }
}
/// Decoded query-string parameters, keyed by parameter name.
///
/// Being a plain map, it holds at most one value per key.
pub type QueryParams = std::collections::HashMap<String, String>;
/// Per-request data handed to a route handler.
///
/// `P` is the type of `params` and `S` the type of `search`; `S` defaults to `()`.
pub struct RouteContext<P, S = ()> {
    /// Route parameters. NOTE(review): presumably extracted from the matched path by the
    /// router — confirm against the caller.
    pub params: P,
    /// Typed search parameters. NOTE(review): presumably produced via
    /// `deserialize_search_params` — confirm against the caller.
    pub search: S,
    /// Query-string key/value pairs.
    pub query: QueryParams,
    /// HTTP method of the request.
    pub method: Method,
    /// Request headers as `(name, value)` pairs; looked up with [`RouteContext::header`].
    pub headers: Vec<HeaderField>,
    /// Raw request body bytes.
    pub body: Vec<u8>,
    /// The request URL.
    pub url: String,
    /// NOTE(review): presumably the path portion captured by a wildcard route segment,
    /// if any — confirm against the router.
    pub wildcard: Option<String>,
}
impl<P, S> RouteContext<P, S> {
    /// Returns the value of the first header whose name equals `name`, ignoring ASCII case.
    ///
    /// Returns `None` when no header matches.
    pub fn header(&self, name: &str) -> Option<&str> {
        for (key, value) in &self.headers {
            if key.eq_ignore_ascii_case(name) {
                return Some(value.as_str());
            }
        }
        None
    }

    /// Borrows the request body as UTF-8 text.
    ///
    /// # Errors
    /// Returns the [`std::str::Utf8Error`] from validation if the body is not valid UTF-8.
    pub fn body_to_str(&self) -> Result<&str, std::str::Utf8Error> {
        std::str::from_utf8(self.body.as_slice())
    }

    /// Deserializes the request body as JSON into `T`.
    ///
    /// # Errors
    /// - [`JsonBodyError::Utf8`] if the body is not valid UTF-8.
    /// - [`JsonBodyError::Json`] if `serde_json` rejects the text (an empty body included).
    pub fn json<T: serde::de::DeserializeOwned>(&self) -> Result<T, JsonBodyError> {
        let text = self.body_to_str().map_err(JsonBodyError::Utf8)?;
        serde_json::from_str::<T>(text).map_err(JsonBodyError::Json)
    }

    /// Parses the body as `&`-separated `key=value` pairs into a map.
    ///
    /// Delegates to [`parse_form_body`], so it never fails: invalid UTF-8 is replaced
    /// lossily and pairs without `=` are dropped.
    pub fn form_data(&self) -> HashMap<String, String> {
        parse_form_body(self.body.as_slice())
    }

    /// Deserializes the body as URL-encoded form data into `T` via `serde_urlencoded`.
    ///
    /// # Errors
    /// - [`FormBodyError::Utf8`] if the body is not valid UTF-8.
    /// - [`FormBodyError::Deserialize`] if `serde_urlencoded` rejects it (e.g. a required
    ///   field is missing).
    pub fn form<T: serde::de::DeserializeOwned>(&self) -> Result<T, FormBodyError> {
        match self.body_to_str() {
            Ok(text) => serde_urlencoded::from_str(text).map_err(FormBodyError::Deserialize),
            Err(e) => Err(FormBodyError::Utf8(e)),
        }
    }
}
/// Extracts the query-string parameters from `url`.
///
/// `url` may be a full URL, a path with a query (`/p?a=1`), or a bare query (`?a=1`).
/// Everything from the first `#` onward is the fragment and is ignored, so a `?` that
/// appears inside the fragment does not start a query.
///
/// Other parsing rules:
/// - Empty segments and segments without `=` are skipped.
/// - Only the first `=` splits the key from the value.
/// - Keys and values are decoded with [`url_decode`].
/// - If a key repeats, the last occurrence wins.
///
/// Returns an empty map when there is no query string.
pub fn parse_query(url: &str) -> QueryParams {
    // Strip the fragment *before* looking for `?`: the query component precedes the
    // fragment (RFC 3986 §3), so `/path#frag?x=1` has no query at all.
    let without_fragment = url.split_once('#').map_or(url, |(before, _)| before);
    let query_str = match without_fragment.split_once('?') {
        Some((_, q)) => q,
        None => return QueryParams::new(),
    };
    query_str
        .split('&')
        .filter(|s| !s.is_empty())
        .filter_map(|pair| {
            let (key, value) = pair.split_once('=')?;
            Some((url_decode(key).into_owned(), url_decode(value).into_owned()))
        })
        .collect()
}
/// Percent-decodes `input`, also treating `+` as a space.
///
/// Returns `Cow::Borrowed(input)` unchanged when it contains neither `%` nor `+`.
/// Otherwise it decodes into a new `String`. A `%2B` escape decodes to a literal `+`,
/// not a space.
///
/// Malformed escapes: a `%` not followed by two hex digits is emitted as a literal `%`.
/// The (up to two) bytes read while looking for those digits are discarded. For example,
/// `"no%encoding"` decodes to `"no%coding"` and `"abc%4"` to `"abc%"`.
///
/// NOTE(review): the WHATWG URL percent-decode algorithm keeps those bytes instead. The
/// current behaviour is pinned by this module's tests, so changing it needs a
/// coordinated test update.
///
/// If the decoded bytes are not valid UTF-8, each invalid sequence is replaced with
/// U+FFFD, so this function never fails.
pub fn url_decode(input: &str) -> std::borrow::Cow<'_, str> {
    use std::borrow::Cow;

    let needs_decoding = input.bytes().any(|byte| matches!(byte, b'%' | b'+'));
    if !needs_decoding {
        return Cow::Borrowed(input);
    }

    let mut out: Vec<u8> = Vec::with_capacity(input.len());
    let mut rest = input.bytes();
    while let Some(byte) = rest.next() {
        let decoded = match byte {
            b'+' => b' ',
            b'%' => {
                // Both candidate digits are always consumed, even when the first is
                // already invalid — this is what drops bytes on malformed escapes.
                let high = rest.next().and_then(hex_val);
                let low = rest.next().and_then(hex_val);
                high.zip(low).map_or(b'%', |(h, l)| (h << 4) | l)
            }
            other => other,
        };
        out.push(decoded);
    }

    String::from_utf8(out).map_or_else(
        |err| Cow::Owned(String::from_utf8_lossy(&err.into_bytes()).into_owned()),
        Cow::Owned,
    )
}
/// Deserializes a query string into `S` using `serde_urlencoded`.
///
/// A single leading `?` is optional, so both `"?page=1"` and `"page=1"` are accepted.
///
/// This function never fails. If deserialization fails for any reason, the whole value
/// falls back to `S::default()`, not just the offending field. For example, a
/// non-numeric value for an integer field discards every other field too.
pub fn deserialize_search_params<S>(query_str: &str) -> S
where
    S: serde::de::DeserializeOwned + Default,
{
    let query = match query_str.strip_prefix('?') {
        Some(stripped) => stripped,
        None => query_str,
    };
    match serde_urlencoded::from_str::<S>(query) {
        Ok(parsed) => parsed,
        // Deliberately lenient: bad search params should not fail the request.
        Err(_) => S::default(),
    }
}
/// Parses a request body of `&`-separated `key=value` pairs into a map.
///
/// Parsing rules:
/// - Invalid UTF-8 in `body` is replaced with U+FFFD before parsing.
/// - Segments without `=` (including empty segments) are ignored.
/// - Only the first `=` splits the key from the value.
/// - Keys and values are decoded with [`url_decode`].
/// - When a key repeats, the last occurrence wins.
pub fn parse_form_body(body: &[u8]) -> HashMap<String, String> {
    let text = String::from_utf8_lossy(body);
    let mut fields = HashMap::new();
    for segment in text.split('&') {
        // An empty segment has no `=`, so this single check skips both cases.
        if let Some((raw_key, raw_value)) = segment.split_once('=') {
            fields.insert(
                url_decode(raw_key).into_owned(),
                url_decode(raw_value).into_owned(),
            );
        }
    }
    fields
}
/// Returns the value (0–15) of an ASCII hexadecimal digit, or `None` for any other byte.
///
/// Accepts `0`-`9`, `a`-`f` and `A`-`F`.
fn hex_val(b: u8) -> Option<u8> {
    // `to_digit(16)` accepts exactly the ASCII hex digits and yields 0..=15, which always
    // fits in a `u8`.
    char::from(b)
        .to_digit(16)
        .and_then(|digit| u8::try_from(digit).ok())
}
#[cfg(test)]
mod tests {
    //! Unit tests for query/form parsing, URL decoding and `RouteContext` helpers.

    use super::*;

    #[test]
    fn parse_query_basic() {
        let q = parse_query("http://example.com/path?page=3&filter=active");
        assert_eq!(q.get("page").unwrap(), "3");
        assert_eq!(q.get("filter").unwrap(), "active");
    }

    #[test]
    fn parse_query_empty_url() {
        let q = parse_query("");
        assert!(q.is_empty());
    }

    #[test]
    fn parse_query_no_query_string() {
        let q = parse_query("/path/to/resource");
        assert!(q.is_empty());
    }

    #[test]
    fn parse_query_empty_query_string() {
        let q = parse_query("/path?");
        assert!(q.is_empty());
    }

    #[test]
    fn parse_query_with_fragment() {
        let q = parse_query("/path?page=1#section");
        assert_eq!(q.get("page").unwrap(), "1");
        assert_eq!(q.len(), 1);
    }

    #[test]
    fn parse_query_url_encoded_values() {
        let q = parse_query("/search?q=hello+world&name=foo%20bar");
        assert_eq!(q.get("q").unwrap(), "hello world");
        assert_eq!(q.get("name").unwrap(), "foo bar");
    }

    #[test]
    fn parse_query_skips_malformed_pairs() {
        let q = parse_query("/path?good=yes&bad&also=fine");
        assert_eq!(q.get("good").unwrap(), "yes");
        assert_eq!(q.get("also").unwrap(), "fine");
        assert_eq!(q.len(), 2);
    }

    #[test]
    fn parse_query_empty_value() {
        let q = parse_query("/path?key=");
        assert_eq!(q.get("key").unwrap(), "");
    }

    #[test]
    fn parse_query_multiple_equals() {
        let q = parse_query("/path?expr=a=b");
        assert_eq!(q.get("expr").unwrap(), "a=b");
    }

    #[test]
    fn parse_query_bare_query_string() {
        let q = parse_query("?page=3&filter=active");
        assert_eq!(q.get("page").unwrap(), "3");
        assert_eq!(q.get("filter").unwrap(), "active");
    }

    #[test]
    fn deserialize_search_params_valid() {
        #[derive(serde::Deserialize, Default, Debug)]
        struct Sp {
            page: Option<u32>,
            filter: Option<String>,
        }
        let sp: Sp = deserialize_search_params("page=3&filter=active");
        assert_eq!(sp.page, Some(3));
        assert_eq!(sp.filter.as_deref(), Some("active"));
    }

    #[test]
    fn deserialize_search_params_type_mismatch_falls_back() {
        #[derive(serde::Deserialize, Default, Debug)]
        struct Sp {
            page: Option<u32>,
            filter: Option<String>,
        }
        // One bad field resets the *whole* struct to its default.
        let sp: Sp = deserialize_search_params("page=abc&filter=active");
        assert_eq!(sp.page, None);
        assert_eq!(sp.filter, None);
    }

    #[test]
    fn deserialize_search_params_empty_string() {
        #[derive(serde::Deserialize, Default, Debug)]
        struct Sp {
            page: Option<u32>,
        }
        let sp: Sp = deserialize_search_params("");
        assert_eq!(sp.page, None);
    }

    #[test]
    fn deserialize_search_params_missing_fields_default_to_none() {
        #[derive(serde::Deserialize, Default, Debug)]
        struct Sp {
            page: Option<u32>,
            filter: Option<String>,
            limit: Option<u32>,
        }
        let sp: Sp = deserialize_search_params("page=5");
        assert_eq!(sp.page, Some(5));
        assert_eq!(sp.filter, None);
        assert_eq!(sp.limit, None);
    }

    #[test]
    fn deserialize_search_params_with_leading_question_mark() {
        #[derive(serde::Deserialize, Default, Debug)]
        struct Sp {
            page: Option<u32>,
        }
        let sp: Sp = deserialize_search_params("?page=7");
        assert_eq!(sp.page, Some(7));
    }

    #[test]
    fn deserialize_search_params_malformed_encoding_does_not_panic() {
        #[derive(serde::Deserialize, Default, Debug)]
        struct Sp {
            q: Option<String>,
        }
        let sp: Sp = deserialize_search_params("q=%ZZ");
        let _ = sp.q;
    }

    #[test]
    fn url_decode_percent_encoding() {
        assert_eq!(url_decode("hello%20world"), "hello world");
    }

    #[test]
    fn url_decode_plus_as_space() {
        assert_eq!(url_decode("a+b"), "a b");
    }

    #[test]
    fn url_decode_mixed_case_hex() {
        assert_eq!(url_decode("%2f%2F"), "//");
    }

    #[test]
    fn url_decode_malformed_escape_drops_consumed_bytes() {
        // A `%` without two hex digits is kept, but the two bytes read while looking
        // for the digits (`e`, `n`) are discarded.
        assert_eq!(url_decode("no%encoding"), "no%coding");
    }

    #[test]
    fn url_decode_plain_passthrough() {
        let result = url_decode("plain");
        assert_eq!(result, "plain");
        assert!(matches!(result, std::borrow::Cow::Borrowed(_)));
    }

    #[test]
    fn url_decode_invalid_utf8_returns_valid_string() {
        let result = url_decode("%FF%FE");
        assert!(!result.is_empty());
        assert!(result.contains('\u{FFFD}'));
    }

    #[test]
    fn url_decode_trailing_percent() {
        assert_eq!(url_decode("abc%"), "abc%");
    }

    #[test]
    fn url_decode_only_percent() {
        assert_eq!(url_decode("%"), "%");
    }

    #[test]
    fn url_decode_percent_one_hex_then_eof() {
        // The lone hex digit is consumed and dropped along with the failed escape.
        assert_eq!(url_decode("abc%4"), "abc%");
    }

    #[test]
    fn url_decode_null_byte() {
        let result = url_decode("%00");
        assert_eq!(result, "\0");
        assert_eq!(result.len(), 1);
    }

    #[test]
    fn url_decode_double_encoded() {
        assert_eq!(url_decode("%2520"), "%20");
    }

    #[test]
    fn url_decode_empty_string() {
        let result = url_decode("");
        assert_eq!(result, "");
        assert!(matches!(result, std::borrow::Cow::Borrowed(_)));
    }

    #[test]
    fn hex_val_accepts_only_hex_digits() {
        assert_eq!(hex_val(b'0'), Some(0));
        assert_eq!(hex_val(b'9'), Some(9));
        assert_eq!(hex_val(b'a'), Some(10));
        assert_eq!(hex_val(b'F'), Some(15));
        assert_eq!(hex_val(b'g'), None);
        assert_eq!(hex_val(b'G'), None);
        assert_eq!(hex_val(b'/'), None);
        assert_eq!(hex_val(0xff), None);
    }

    #[test]
    fn parse_form_body_basic_pairs() {
        let fields = parse_form_body(b"name=Alice&age=30");
        assert_eq!(fields.get("name").unwrap(), "Alice");
        assert_eq!(fields.get("age").unwrap(), "30");
    }

    #[test]
    fn parse_form_body_plus_decoding() {
        let fields = parse_form_body(b"q=hello+world");
        assert_eq!(fields.get("q").unwrap(), "hello world");
    }

    #[test]
    fn parse_form_body_empty() {
        let fields = parse_form_body(b"");
        assert!(fields.is_empty());
    }

    #[test]
    fn parse_form_body_encoded_values() {
        let fields = parse_form_body(b"key=val%26ue");
        assert_eq!(fields.get("key").unwrap(), "val&ue");
    }

    #[test]
    fn parse_form_body_skips_pairs_without_equals() {
        let fields = parse_form_body(b"a=1&flag&&b=2");
        assert_eq!(fields.len(), 2);
        assert_eq!(fields.get("a").unwrap(), "1");
        assert_eq!(fields.get("b").unwrap(), "2");
    }

    #[test]
    fn parse_form_body_invalid_utf8_is_replaced() {
        let fields = parse_form_body(b"k=\xff");
        assert_eq!(fields.get("k").unwrap(), "\u{FFFD}");
    }

    /// Builds a minimal GET context with the given headers and body.
    fn test_ctx(headers: Vec<(String, String)>, body: Vec<u8>) -> RouteContext<()> {
        RouteContext {
            params: (),
            search: (),
            query: QueryParams::new(),
            method: Method::GET,
            headers,
            body,
            url: String::new(),
            wildcard: None,
        }
    }

    #[test]
    fn header_case_insensitive() {
        let ctx = test_ctx(
            vec![("authorization".to_string(), "Bearer x".to_string())],
            vec![],
        );
        assert_eq!(ctx.header("Authorization"), Some("Bearer x"));
        assert_eq!(ctx.header("authorization"), Some("Bearer x"));
        assert_eq!(ctx.header("AUTHORIZATION"), Some("Bearer x"));
    }

    #[test]
    fn header_missing() {
        let ctx = test_ctx(vec![], vec![]);
        assert_eq!(ctx.header("x-missing"), None);
    }

    #[test]
    fn header_first_match_wins() {
        let ctx = test_ctx(
            vec![
                ("x-custom".to_string(), "first".to_string()),
                ("x-custom".to_string(), "second".to_string()),
            ],
            vec![],
        );
        assert_eq!(ctx.header("x-custom"), Some("first"));
    }

    #[test]
    fn body_to_str_valid_utf8() {
        let ctx = test_ctx(vec![], b"hello".to_vec());
        assert_eq!(ctx.body_to_str(), Ok("hello"));
    }

    #[test]
    fn body_to_str_invalid_utf8() {
        let ctx = test_ctx(vec![], vec![0xff, 0xfe]);
        assert!(ctx.body_to_str().is_err());
    }

    #[test]
    fn body_to_str_empty() {
        let ctx = test_ctx(vec![], vec![]);
        assert_eq!(ctx.body_to_str(), Ok(""));
    }

    #[test]
    fn json_valid() {
        #[derive(serde::Deserialize, Debug, PartialEq)]
        struct Item {
            name: String,
        }
        let ctx = test_ctx(vec![], br#"{"name":"test"}"#.to_vec());
        let result: Result<Item, _> = ctx.json();
        assert_eq!(
            result.unwrap(),
            Item {
                name: "test".to_string()
            }
        );
    }

    #[test]
    fn json_invalid_json() {
        #[derive(serde::Deserialize)]
        struct Item {
            #[allow(dead_code)]
            name: String,
        }
        let ctx = test_ctx(vec![], b"{invalid}".to_vec());
        let result: Result<Item, _> = ctx.json();
        assert!(matches!(result, Err(JsonBodyError::Json(_))));
    }

    #[test]
    fn json_invalid_utf8() {
        #[derive(serde::Deserialize)]
        struct Item {
            #[allow(dead_code)]
            name: String,
        }
        let ctx = test_ctx(vec![], vec![0xff, 0xfe]);
        let result: Result<Item, _> = ctx.json();
        assert!(matches!(result, Err(JsonBodyError::Utf8(_))));
    }

    #[test]
    fn json_empty_body() {
        #[derive(serde::Deserialize)]
        struct Item {
            #[allow(dead_code)]
            name: String,
        }
        let ctx = test_ctx(vec![], vec![]);
        let result: Result<Item, _> = ctx.json();
        assert!(matches!(result, Err(JsonBodyError::Json(_))));
    }

    #[test]
    fn json_error_display_and_source() {
        let utf8 = test_ctx(vec![], vec![0xff])
            .json::<serde_json::Value>()
            .unwrap_err();
        assert!(utf8.to_string().starts_with("body is not valid UTF-8: "));
        assert!(std::error::Error::source(&utf8).is_some());

        let json = test_ctx(vec![], b"{invalid}".to_vec())
            .json::<serde_json::Value>()
            .unwrap_err();
        assert!(json.to_string().starts_with("JSON deserialization failed: "));
        assert!(std::error::Error::source(&json).is_some());
    }

    #[test]
    fn form_data_basic() {
        let ctx = test_ctx(vec![], b"name=Alice&age=30".to_vec());
        let fields = ctx.form_data();
        assert_eq!(fields.get("name").unwrap(), "Alice");
        assert_eq!(fields.get("age").unwrap(), "30");
    }

    #[test]
    fn form_data_empty() {
        let ctx = test_ctx(vec![], vec![]);
        let fields = ctx.form_data();
        assert!(fields.is_empty());
    }

    #[test]
    fn form_data_url_encoded() {
        let ctx = test_ctx(vec![], b"greeting=hello+world&path=%2Ffoo%2Fbar".to_vec());
        let fields = ctx.form_data();
        assert_eq!(fields.get("greeting").unwrap(), "hello world");
        assert_eq!(fields.get("path").unwrap(), "/foo/bar");
    }

    #[test]
    fn form_valid() {
        #[derive(serde::Deserialize, Debug, PartialEq)]
        struct Comment {
            author: String,
            body: String,
        }
        let ctx = test_ctx(vec![], b"author=Alice&body=hello".to_vec());
        let result: Result<Comment, _> = ctx.form();
        assert_eq!(
            result.unwrap(),
            Comment {
                author: "Alice".to_string(),
                body: "hello".to_string(),
            }
        );
    }

    #[test]
    fn form_missing_field() {
        #[derive(serde::Deserialize)]
        struct Comment {
            #[allow(dead_code)]
            author: String,
            #[allow(dead_code)]
            body: String,
        }
        let ctx = test_ctx(vec![], b"author=Alice".to_vec());
        let result: Result<Comment, _> = ctx.form();
        assert!(matches!(result, Err(FormBodyError::Deserialize(_))));
    }

    #[test]
    fn form_invalid_utf8() {
        #[derive(serde::Deserialize)]
        struct Comment {
            #[allow(dead_code)]
            author: String,
        }
        let ctx = test_ctx(vec![], vec![0xff, 0xfe]);
        let result: Result<Comment, _> = ctx.form();
        assert!(matches!(result, Err(FormBodyError::Utf8(_))));
    }

    #[test]
    fn form_empty_body_with_optional_fields() {
        #[derive(serde::Deserialize, Debug, PartialEq)]
        struct Opts {
            name: Option<String>,
        }
        let ctx = test_ctx(vec![], vec![]);
        let result: Result<Opts, _> = ctx.form();
        assert_eq!(result.unwrap(), Opts { name: None });
    }

    #[test]
    fn form_error_display_and_source() {
        let utf8 = test_ctx(vec![], vec![0xff])
            .form::<HashMap<String, u32>>()
            .unwrap_err();
        assert!(utf8.to_string().starts_with("body is not valid UTF-8: "));
        assert!(std::error::Error::source(&utf8).is_some());

        let bad = test_ctx(vec![], b"n=abc".to_vec())
            .form::<HashMap<String, u32>>()
            .unwrap_err();
        assert!(bad.to_string().starts_with("form deserialization failed: "));
        assert!(std::error::Error::source(&bad).is_some());
    }
}