use std::collections::HashMap;
use std::io;
use bytes::Bytes;
use hyper::Method;
use serde::Deserialize;
use serde_json::json;
use uuid::Uuid;
use hyperlite::{
get_extension, parse_json_body, path_param, path_params, query_params, BoxError, PathParams,
};
mod test_helpers;
use test_helpers::*;
/// Downcasts an extractor error to [`io::Error`] and asserts its kind.
///
/// Returns the downcast error so the caller can make further assertions,
/// typically on its message.
///
/// # Panics
///
/// Panics if `err` is not an `io::Error`, or if its kind differs from `kind`.
/// `#[track_caller]` attributes either panic to the calling test line rather
/// than to this helper, so the failure report points at the real assertion.
#[track_caller]
fn expect_io_error_kind(err: BoxError, kind: io::ErrorKind) -> Box<io::Error> {
    let io_err = err
        .downcast::<io::Error>()
        .expect("expected io::Error from extractor error");
    assert_eq!(
        io_err.kind(),
        kind,
        "unexpected io::ErrorKind (error message: {io_err})"
    );
    io_err
}
/// Request body used by the `parse_json_body` tests.
#[derive(PartialEq, Debug, Deserialize)]
struct CreateUser {
    /// Required string field.
    name: String,
    /// Required unsigned integer field.
    age: u32,
}
/// Outer wrapper used to check that nested JSON objects deserialize.
#[derive(PartialEq, Debug, Deserialize)]
struct NestedPayload {
    /// The object found under the `"info"` key.
    info: InnerPayload,
}
/// Inner object of `NestedPayload`.
#[derive(PartialEq, Debug, Deserialize)]
struct InnerPayload {
    /// Required title string.
    title: String,
    /// Tag list; an absent `"tags"` key falls back to an empty `Vec`.
    #[serde(default)]
    tags: Vec<String>,
}
/// Query-string shape with two required parameters, shared by several
/// `query_params` tests.
#[derive(PartialEq, Debug, Deserialize)]
struct QueryStruct {
    /// Required `q` parameter.
    q: String,
    /// Required `limit` parameter; must parse as an unsigned integer.
    limit: u32,
}
/// Query-string shape in which every parameter is optional.
#[derive(PartialEq, Debug, Deserialize)]
struct OptionalQuery {
    /// `q` parameter; `None` when absent.
    q: Option<String>,
    /// `page` parameter; falls back to `Default` (`None`) when absent.
    #[serde(default)]
    page: Option<u32>,
}
/// Happy path: a JSON object with the expected fields deserializes into
/// `CreateUser`.
#[tokio::test]
async fn test_parse_json_body_success() {
    let body = json!({ "name": "Alice", "age": 30 });
    let req = build_json_request(Method::POST, "/users", body);

    let user = parse_json_body::<CreateUser>(req).await.unwrap();

    assert_eq!(user.name, "Alice");
    assert_eq!(user.age, 30);
}
/// A `Content-Type` that carries a `charset` parameter is still accepted as
/// JSON.
#[tokio::test]
async fn test_parse_json_body_with_charset() {
    use hyper::header::{HeaderValue, CONTENT_TYPE};

    let body = json!({ "name": "Bob", "age": 25 });
    let mut req = build_json_request(Method::POST, "/users", body);
    // Replace whatever Content-Type the helper set with a parameterised one.
    let content_type = HeaderValue::from_static("application/json; charset=utf-8");
    req.headers_mut().insert(CONTENT_TYPE, content_type);

    let user = parse_json_body::<CreateUser>(req).await.unwrap();
    assert_eq!(user.name, "Bob");
}
/// A request with no `Content-Type` header is refused as `InvalidInput`, and
/// the error message names the header.
#[tokio::test]
async fn test_parse_json_body_missing_content_type() {
    let body = json_body(json!({ "name": "Alice" }));
    let req = build_request(Method::POST, "/users", body);

    let err = parse_json_body::<CreateUser>(req).await.unwrap_err();
    let message = expect_io_error_kind(err, io::ErrorKind::InvalidInput).to_string();

    assert!(message.contains("Content-Type"));
}
/// A non-JSON media type (`text/plain`) is rejected as `InvalidInput`.
#[tokio::test]
async fn test_parse_json_body_wrong_content_type() {
    use hyper::header::{HeaderValue, CONTENT_TYPE};

    let body = json_body(json!({ "name": "Alice" }));
    let headers = vec![(CONTENT_TYPE, HeaderValue::from_static("text/plain"))];
    let req = build_request_with_headers(Method::POST, "/users", body, headers);

    let err = parse_json_body::<CreateUser>(req).await.unwrap_err();
    let message = expect_io_error_kind(err, io::ErrorKind::InvalidInput).to_string();

    assert!(message.contains("Content-Type"));
}
/// Syntactically malformed JSON is reported as `InvalidData`.
#[tokio::test]
async fn test_parse_json_body_invalid_json() {
    use hyper::header::{HeaderValue, CONTENT_TYPE};

    let malformed = bytes_body(Bytes::from_static(b"{invalid}"));
    let headers = vec![(CONTENT_TYPE, HeaderValue::from_static("application/json"))];
    let req = build_request_with_headers(Method::POST, "/users", malformed, headers);

    let err = parse_json_body::<CreateUser>(req).await.unwrap_err();
    let message = expect_io_error_kind(err, io::ErrorKind::InvalidData).to_string();

    assert!(message.contains("Invalid JSON payload"));
}
/// A body one byte larger than 1 MiB is rejected as `InvalidData` with the
/// size-limit message.
#[tokio::test]
async fn test_parse_json_body_size_limit() {
    use hyper::header::{HeaderValue, CONTENT_LENGTH, CONTENT_TYPE};

    // 1 MiB + 1 byte: the smallest body that must exceed the limit.
    let target_size = 1_048_577usize;
    let head: &[u8] = b"{\"data\":\"";
    let tail: &[u8] = b"\"}";
    let filler = target_size - head.len() - tail.len();

    // Valid JSON of exactly `target_size` bytes: {"data":"aaaa…"}
    let payload: Vec<u8> = head
        .iter()
        .copied()
        .chain(std::iter::repeat_n(b'a', filler))
        .chain(tail.iter().copied())
        .collect();
    assert_eq!(payload.len(), target_size);

    // Must be computed before `payload` is moved into the body below.
    let content_length = HeaderValue::from(payload.len());
    let req = build_request_with_headers(
        Method::POST,
        "/users",
        bytes_body(Bytes::from(payload)),
        vec![
            (CONTENT_TYPE, HeaderValue::from_static("application/json")),
            (CONTENT_LENGTH, content_length),
        ],
    );

    let err = parse_json_body::<CreateUser>(req).await.unwrap_err();
    let message = expect_io_error_kind(err, io::ErrorKind::InvalidData).to_string();
    assert!(message.contains("Request body exceeds 1MB limit"));
}
/// An empty body sent with a JSON content type fails as invalid JSON rather
/// than producing a value.
#[tokio::test]
async fn test_parse_json_body_empty() {
    use hyper::header::{HeaderValue, CONTENT_TYPE};

    let headers = vec![(CONTENT_TYPE, HeaderValue::from_static("application/json"))];
    let empty = bytes_body(Bytes::new());
    let req = build_request_with_headers(Method::POST, "/users", empty, headers);

    let err = parse_json_body::<CreateUser>(req).await.unwrap_err();
    let message = expect_io_error_kind(err, io::ErrorKind::InvalidData).to_string();

    assert!(message.contains("Invalid JSON payload"));
}
#[tokio::test]
async fn test_parse_json_body_nested_objects() {
let request = build_json_request(
Method::POST,
"/nested",
json!({ "info": { "title": "Post", "tags": ["a", "b"] } }),
);
let payload = parse_json_body::<NestedPayload>(request).await.unwrap();
assert_eq!(payload.info.title, "Post");
assert_eq!(payload.info.tags.len(), 2);
}
/// A top-level JSON array can be extracted directly into a `Vec`.
#[tokio::test]
async fn test_parse_json_body_arrays() {
    let body = json!(["one", "two", "three"]);
    let req = build_json_request(Method::POST, "/tags", body);

    let tags = parse_json_body::<Vec<String>>(req).await.unwrap();
    assert_eq!(tags.len(), 3);
}
/// Required parameters present in the query string are decoded and converted
/// to their target types.
#[test]
fn test_query_params_success() {
    let request = build_request(Method::GET, "/search?q=test&limit=10", empty_body());
    // Reuse the shared `QueryStruct` instead of an identical local copy.
    let params = query_params::<QueryStruct>(&request).unwrap();
    // Compare the whole value so no field can go unchecked.
    assert_eq!(
        params,
        QueryStruct {
            q: "test".into(),
            limit: 10,
        }
    );
}
/// Several independent parameters each map onto their own field.
#[test]
fn test_query_params_multiple_values() {
    #[derive(Deserialize)]
    struct Triple {
        a: String,
        b: String,
        c: String,
    }

    let request = build_request(Method::GET, "/path?a=1&b=2&c=3", empty_body());
    let Triple { a, b, c } = query_params::<Triple>(&request).unwrap();

    assert_eq!(a, "1");
    assert_eq!(b, "2");
    assert_eq!(c, "3");
}
/// Supplied optional parameters parse, and omitted ones come back as `None`.
#[test]
fn test_query_params_optional_fields() {
    let request = build_request(Method::GET, "/search?page=5", empty_body());
    let OptionalQuery { q, page } = query_params::<OptionalQuery>(&request).unwrap();

    assert_eq!(page, Some(5));
    assert!(q.is_none());
}
/// A query string that lacks the required `q`/`limit` fields is rejected as
/// `InvalidInput`.
#[test]
fn test_query_params_missing_required() {
    let request = build_request(Method::GET, "/search?page=5", empty_body());

    let err = query_params::<QueryStruct>(&request).unwrap_err();
    let message = expect_io_error_kind(err, io::ErrorKind::InvalidInput).to_string();

    assert!(message.contains("Invalid query parameters"));
}
/// A non-numeric value for the `u32` `limit` field is rejected as
/// `InvalidInput`.
#[test]
fn test_query_params_invalid_type() {
    let request = build_request(Method::GET, "/search?limit=abc&q=foo", empty_body());

    let err = query_params::<QueryStruct>(&request).unwrap_err();
    let message = expect_io_error_kind(err, io::ErrorKind::InvalidInput).to_string();

    assert!(message.contains("Invalid query parameters"));
}
/// A URI with no `?` at all produces a distinct "No query parameters" error.
#[test]
fn test_query_params_no_query_string() {
    let request = build_request(Method::GET, "/search", empty_body());

    let err = query_params::<QueryStruct>(&request).unwrap_err();
    let message = expect_io_error_kind(err, io::ErrorKind::InvalidInput).to_string();

    assert!(message.contains("No query parameters"));
}
/// A bare trailing `?` is accepted (unlike a URI with no `?` at all), and
/// defaulted fields keep their `Default` values.
#[test]
fn test_query_params_empty_query() {
    #[derive(Deserialize)]
    struct Flags {
        #[serde(default)]
        flag: bool,
    }

    let request = build_request(Method::GET, "/search?", empty_body());
    let Flags { flag } = query_params::<Flags>(&request).unwrap();

    assert!(!flag);
}
/// Percent-encoded spaces (`%20`) are decoded in parameter values.
#[test]
fn test_query_params_url_encoded() {
    let request = build_request(Method::GET, "/search?q=hello%20world", empty_body());
    let decoded = query_params::<OptionalQuery>(&request).unwrap().q;
    assert_eq!(decoded.as_deref(), Some("hello world"));
}
/// Percent-encoded reserved characters (`%40` → `@`) are decoded.
#[test]
fn test_query_params_special_chars() {
    let request = build_request(Method::GET, "/search?q=name%40example.com", empty_body());
    let decoded = query_params::<OptionalQuery>(&request).unwrap().q;
    assert_eq!(decoded.as_deref(), Some("name@example.com"));
}
/// `path_params` returns exactly the map stored in the `PathParams` extension.
#[test]
fn test_path_params_success() {
    let mut expected = HashMap::new();
    expected.insert("id".into(), "123".into());

    let mut req = build_request(Method::GET, "/users/123", empty_body());
    req.extensions_mut().insert(PathParams(expected.clone()));

    assert_eq!(path_params(&req).unwrap(), expected);
}
/// Multiple captured segments are all returned by `path_params`.
#[test]
fn test_path_params_multiple() {
    let mut expected = HashMap::new();
    expected.insert("user_id".into(), "10".into());
    expected.insert("post_id".into(), "5".into());

    let mut req = build_request(Method::GET, "/users/10/posts/5", empty_body());
    req.extensions_mut().insert(PathParams(expected.clone()));

    assert_eq!(path_params(&req).unwrap(), expected);
}
/// An attached but empty `PathParams` extension (a route with no
/// placeholders) yields an empty map.
///
/// The extension is inserted explicitly so this case differs from
/// `test_path_params_missing_extension`, where none is attached.
#[test]
fn test_path_params_empty() {
    let mut request = build_request(Method::GET, "/health", empty_body());
    request.extensions_mut().insert(PathParams(HashMap::new()));
    let params = path_params(&request).unwrap();
    assert!(params.is_empty());
}
/// With no `PathParams` extension attached, `path_params` falls back to an
/// empty map instead of failing.
#[test]
fn test_path_params_missing_extension() {
    let req = build_request(Method::GET, "/health", empty_body());
    let result = path_params(&req);
    assert!(result.unwrap().is_empty());
}
/// A captured parameter can be extracted as a `String`.
#[test]
fn test_path_param_success() {
    let mut params = HashMap::new();
    params.insert("id".into(), "abc".into());

    let mut req = build_request(Method::GET, "/users/abc", empty_body());
    req.extensions_mut().insert(PathParams(params));

    let id = path_param::<String>(&req, "id").unwrap();
    assert_eq!(id, "abc");
}
/// A parameter holding a UUID string parses back into the same `Uuid`.
#[test]
fn test_path_param_uuid() {
    let id = Uuid::new_v4();

    let mut params = HashMap::new();
    params.insert("id".into(), id.to_string());

    let mut req = build_request(Method::GET, "/users", empty_body());
    req.extensions_mut().insert(PathParams(params));

    assert_eq!(path_param::<Uuid>(&req, "id").unwrap(), id);
}
/// A numeric parameter can be extracted as a signed integer.
#[test]
fn test_path_param_integer() {
    let mut params = HashMap::new();
    params.insert("id".into(), "42".into());

    let mut req = build_request(Method::GET, "/users", empty_body());
    req.extensions_mut().insert(PathParams(params));

    assert_eq!(path_param::<i32>(&req, "id").unwrap(), 42);
}
/// Asking for a parameter that was never captured yields `InvalidInput`.
#[test]
fn test_path_param_missing() {
    let req = build_request(Method::GET, "/users", empty_body());

    let err = path_param::<String>(&req, "id").unwrap_err();
    let message = expect_io_error_kind(err, io::ErrorKind::InvalidInput).to_string();

    assert!(message.contains("Path parameter"));
}
/// A value that cannot be parsed into the requested type (`"abc"` as `u32`)
/// yields `InvalidInput`.
#[test]
fn test_path_param_invalid_type() {
    let mut params = HashMap::new();
    params.insert("id".into(), "abc".into());

    let mut req = build_request(Method::GET, "/users", empty_body());
    req.extensions_mut().insert(PathParams(params));

    let err = path_param::<u32>(&req, "id").unwrap_err();
    let message = expect_io_error_kind(err, io::ErrorKind::InvalidInput).to_string();
    assert!(message.contains("Invalid path parameter"));
}
/// An empty captured value cannot be parsed as a number and yields
/// `InvalidInput`.
#[test]
fn test_path_param_empty_value() {
    let mut params = HashMap::new();
    params.insert("id".into(), "".into());

    let mut req = build_request(Method::GET, "/users", empty_body());
    req.extensions_mut().insert(PathParams(params));

    let err = path_param::<usize>(&req, "id").unwrap_err();
    let message = expect_io_error_kind(err, io::ErrorKind::InvalidInput).to_string();
    assert!(message.contains("Invalid path parameter"));
}
/// A primitive value stored in the request extensions is returned by type.
#[test]
fn test_get_extension_success() {
    let mut req = build_request(Method::GET, "/", empty_body());
    req.extensions_mut().insert(42u32);

    assert_eq!(get_extension::<u32>(&req).unwrap(), 42);
}
/// A `Uuid` stored as an extension round-trips unchanged.
#[test]
fn test_get_extension_uuid() {
    let id = Uuid::new_v4();
    let mut req = build_request(Method::GET, "/", empty_body());
    req.extensions_mut().insert(id);

    assert_eq!(get_extension::<Uuid>(&req).unwrap(), id);
}
/// An owned `String` extension is returned with the same contents.
#[test]
fn test_get_extension_string() {
    let mut req = build_request(Method::GET, "/", empty_body());
    req.extensions_mut().insert(String::from("hello"));

    let greeting = get_extension::<String>(&req).unwrap();
    assert_eq!(greeting, "hello");
}
/// A user-defined type can be stored in and retrieved from the extensions.
#[test]
fn test_get_extension_custom_type() {
    #[derive(Clone, PartialEq, Debug)]
    struct Marker(u32);

    let mut req = build_request(Method::GET, "/", empty_body());
    req.extensions_mut().insert(Marker(7));

    let Marker(inner) = get_extension::<Marker>(&req).unwrap();
    assert_eq!(inner, 7);
}
/// Looking up a type that was never inserted yields `NotFound`.
#[test]
fn test_get_extension_missing() {
    let req = build_request(Method::GET, "/", empty_body());

    let err = get_extension::<String>(&req).unwrap_err();
    let message = expect_io_error_kind(err, io::ErrorKind::NotFound).to_string();

    assert!(message.contains("Extension of type"));
}
/// Extensions are keyed by type: a stored `String` does not satisfy a `u32`
/// lookup, which yields `NotFound`.
#[test]
fn test_get_extension_wrong_type() {
    let mut req = build_request(Method::GET, "/", empty_body());
    req.extensions_mut().insert(String::from("hello"));

    let err = get_extension::<u32>(&req).unwrap_err();
    let message = expect_io_error_kind(err, io::ErrorKind::NotFound).to_string();
    assert!(message.contains("Extension of type"));
}
/// Query parameters can be read by reference before `parse_json_body`
/// consumes the request.
#[tokio::test]
async fn test_extract_json_and_query() {
    let body = json!({"name": "Alice", "age": 30});
    let req = build_json_request(Method::POST, "/search?q=test&limit=5", body);

    // Borrowing extractors must run first: `parse_json_body` takes ownership.
    let query = query_params::<QueryStruct>(&req).unwrap();
    assert_eq!(query.limit, 5);

    let user = parse_json_body::<CreateUser>(req).await.unwrap();
    assert_eq!(user.name, "Alice");
}
/// Path parameters and typed extensions live side by side on one request, and
/// both can be read.
#[tokio::test]
async fn test_extract_path_and_extension() {
    let id = Uuid::new_v4();
    let mut req = build_json_request(Method::GET, "/users", json!({"id": id}));

    let mut params = HashMap::new();
    params.insert("id".into(), id.to_string());

    let extensions = req.extensions_mut();
    extensions.insert(PathParams(params));
    extensions.insert(id);

    let from_path = path_param::<Uuid>(&req, "id").unwrap();
    let from_extension = get_extension::<Uuid>(&req).unwrap();
    assert_eq!(from_path, id);
    assert_eq!(from_extension, id);
}
/// All four extractors compose on a single request: query, path and extension
/// by reference, then the body last because it consumes the request.
#[tokio::test]
async fn test_extract_all() {
    #[derive(Deserialize, PartialEq, Debug)]
    struct NameOnly {
        name: String,
    }

    let mut req = build_json_request(
        Method::POST,
        "/combo?q=test&limit=2",
        json!({"name": "Alice"}),
    );

    let mut path = HashMap::new();
    path.insert("id".into(), "42".into());
    {
        let ext = req.extensions_mut();
        ext.insert(PathParams(path));
        ext.insert(String::from("ext"));
    }

    let query = query_params::<QueryStruct>(&req).unwrap();
    let path_id = path_param::<i32>(&req, "id").unwrap();
    let extension = get_extension::<String>(&req).unwrap();
    let body = parse_json_body::<NameOnly>(req).await.unwrap();

    assert_eq!(query.limit, 2);
    assert_eq!(path_id, 42);
    assert_eq!(extension, "ext");
    assert_eq!(body.name, "Alice");
}
/// The error produced when `Content-Type` is absent mentions the header by
/// name, so clients can tell what to fix.
#[tokio::test]
async fn test_parse_json_body_error_message() {
    // One statement per line (rustfmt); previously two were fused on one line.
    let request = build_request(Method::POST, "/fail", json_body(json!({"a": 1})));
    let err = parse_json_body::<CreateUser>(request).await.unwrap_err();
    let io_err = expect_io_error_kind(err, io::ErrorKind::InvalidInput);
    assert!(io_err.to_string().contains("Content-Type"));
}
/// A query that omits `q` and has a non-numeric `limit` surfaces the generic
/// "Invalid query parameters" message.
#[test]
fn test_query_params_error_message() {
    let req = build_request(Method::GET, "/search?limit=abc", empty_body());

    let err = query_params::<QueryStruct>(&req).unwrap_err();
    let message = expect_io_error_kind(err, io::ErrorKind::InvalidInput).to_string();

    assert!(message.contains("Invalid query parameters"));
}
/// The missing-parameter error message refers to the path parameter.
#[test]
fn test_path_param_error_message() {
    let req = build_request(Method::GET, "/users", empty_body());

    let err = path_param::<u32>(&req, "missing").unwrap_err();
    let message = expect_io_error_kind(err, io::ErrorKind::InvalidInput).to_string();

    assert!(message.contains("Path parameter"));
}
/// The missing-extension error message refers to the extension type.
#[test]
fn test_get_extension_error_message() {
    let req = build_request(Method::GET, "/", empty_body());

    let err = get_extension::<u64>(&req).unwrap_err();
    let message = expect_io_error_kind(err, io::ErrorKind::NotFound).to_string();

    assert!(message.contains("Extension of type"));
}