use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use bytes::Bytes;
use http::{Method, Response, StatusCode};
use http_body_util::Full;
use crate::app_state::AppState;
/// Body type of every response produced by this module: one in-memory
/// [`Bytes`] chunk wrapped in `http_body_util::Full`.
pub type Body = Full<Bytes>;

/// Boxed, `Send` future that resolves to a handler's response.
pub type HandlerFuture = Pin<Box<dyn Future<Output = Response<Body>> + Send + 'static>>;

/// Type-erased, shareable request handler.
///
/// Called with the application state, the request head (any buffered body is in
/// its extensions as a `RequestBody`) and the path parameters captured by the
/// router as `(name, value)` pairs.
pub type HandlerFn = Arc<
    dyn Fn(AppState, http::request::Parts, Vec<(String, String)>) -> HandlerFuture
        + Send
        + Sync
        + 'static,
>;
/// Method-aware request router.
///
/// Keeps one `matchit` path tree per HTTP method plus the application state
/// that is handed to every handler. Build it with the chaining methods, then
/// freeze it with [`Router::into_shared`].
pub struct Router {
    trees: HashMap<Method, matchit::Router<HandlerFn>>,
    state: AppState,
}

impl Router {
    /// Creates a router with no routes that will pass `state` to its handlers.
    #[must_use]
    pub fn new(state: AppState) -> Self {
        let trees = HashMap::new();
        Self { state, trees }
    }

    /// Registers `handler` for `method` requests whose path matches `path`.
    ///
    /// # Errors
    ///
    /// Returns the `matchit` insertion error when `matchit` rejects `path`
    /// (for example because it conflicts with a route already registered for
    /// the same method). The router is consumed in that case.
    pub fn route(
        mut self,
        method: Method,
        path: &str,
        handler: HandlerFn,
    ) -> Result<Self, matchit::InsertError> {
        let tree = self.trees.entry(method).or_default();
        tree.insert(path, handler)?;
        Ok(self)
    }

    /// Shorthand for [`Router::route`] with `GET`.
    ///
    /// # Errors
    ///
    /// Same as [`Router::route`].
    pub fn get(self, path: &str, handler: HandlerFn) -> Result<Self, matchit::InsertError> {
        Self::route(self, Method::GET, path, handler)
    }

    /// Shorthand for [`Router::route`] with `POST`.
    ///
    /// # Errors
    ///
    /// Same as [`Router::route`].
    pub fn post(self, path: &str, handler: HandlerFn) -> Result<Self, matchit::InsertError> {
        Self::route(self, Method::POST, path, handler)
    }

    /// Shorthand for [`Router::route`] with `PUT`.
    ///
    /// # Errors
    ///
    /// Same as [`Router::route`].
    pub fn put(self, path: &str, handler: HandlerFn) -> Result<Self, matchit::InsertError> {
        Self::route(self, Method::PUT, path, handler)
    }

    /// Shorthand for [`Router::route`] with `DELETE`.
    ///
    /// # Errors
    ///
    /// Same as [`Router::route`].
    pub fn delete(self, path: &str, handler: HandlerFn) -> Result<Self, matchit::InsertError> {
        Self::route(self, Method::DELETE, path, handler)
    }

    /// Freezes the routing table behind an `Arc` so it can be cloned cheaply
    /// into every connection.
    #[must_use]
    pub fn into_shared(self) -> SharedRouter {
        let inner = Arc::new(self);
        SharedRouter { inner }
    }
}
#[derive(Clone)]
pub struct SharedRouter {
inner: Arc<Router>,
}
impl SharedRouter {
pub async fn handle(&self, req: http::Request<hyper::body::Incoming>) -> Response<Body> {
let (mut parts, body) = req.into_parts();
if matches!(
parts.method,
Method::POST | Method::PUT | Method::PATCH | Method::DELETE
) {
use http_body_util::BodyExt;
let body_bytes = body
.collect()
.await
.map(|c| c.to_bytes())
.unwrap_or_default();
parts.extensions.insert(RequestBody(body_bytes));
}
self.handle_parts(parts).await
}
pub async fn handle_parts(&self, parts: http::request::Parts) -> Response<Body> {
let method = parts.method.clone();
let path = parts.uri.path().to_owned();
if let Some(tree) = self.inner.trees.get(&method)
&& let Ok(matched) = tree.at(&path)
{
let params: Vec<(String, String)> = matched
.params
.iter()
.map(|(k, v)| (k.to_owned(), v.to_owned()))
.collect();
let handler = matched.value.clone();
return handler(self.inner.state.clone(), parts, params).await;
}
for (m, tree) in &self.inner.trees {
if *m != method && tree.at(&path).is_ok() {
return Response::builder()
.status(StatusCode::METHOD_NOT_ALLOWED)
.header("content-type", "text/plain; charset=utf-8")
.body(Body::from("Method Not Allowed"))
.unwrap_or_default();
}
}
Response::builder()
.status(StatusCode::NOT_FOUND)
.header("content-type", "text/plain; charset=utf-8")
.body(Body::from("Not Found"))
.unwrap_or_default()
}
}
/// Wraps an async function or closure into a type-erased [`HandlerFn`].
///
/// Each call of the returned handler invokes `f` and boxes the future it
/// produces, so handlers with different concrete future types can share one
/// routing table.
#[must_use]
pub fn handler_fn<F, Fut>(f: F) -> HandlerFn
where
    F: Fn(AppState, http::request::Parts, Vec<(String, String)>) -> Fut + Send + Sync + 'static,
    Fut: Future<Output = Response<Body>> + Send + 'static,
{
    Arc::new(
        move |state: AppState, parts: http::request::Parts, params: Vec<(String, String)>| {
            let pending = f(state, parts, params);
            Box::pin(pending)
        },
    )
}
/// Fully buffered request body.
///
/// `SharedRouter::handle` inserts it into the request extensions for `POST`,
/// `PUT`, `PATCH` and `DELETE` requests. `parse_json_body` and
/// `parse_form_body` read it back from there.
#[derive(Debug, Clone)]
pub struct RequestBody(pub Bytes);
/// Looks up a path parameter captured by the router.
///
/// Returns an owned copy of the value of the first pair whose name equals
/// `name`, or `None` if no such parameter was captured.
#[must_use]
pub fn path_param(params: &[(String, String)], name: &str) -> Option<String> {
    for (key, value) in params {
        if key == name {
            return Some(value.clone());
        }
    }
    None
}
/// Deserializes the URI's query string into `T`.
///
/// Both names and values are percent-decoded (`+` becomes a space), the same
/// way `parse_form_body` decodes form fields. Segments without `=` are ignored.
/// When a name repeats, the last occurrence wins.
///
/// Every value is handed to serde as a JSON string. Fields of `T` should
/// therefore be string-like (`String`, `Option<String>`, ...) or use a
/// deserializer that accepts strings.
///
/// # Errors
///
/// Returns a human-readable message when the decoded pairs cannot be
/// deserialized into `T`.
pub fn parse_query<T: serde::de::DeserializeOwned>(uri: &http::Uri) -> Result<T, String> {
    let map: serde_json::Map<String, serde_json::Value> = uri
        .query()
        .unwrap_or_default()
        .split('&')
        // Empty segments (`a=1&&b=2`) contain no `=` and are dropped here too.
        .filter_map(|pair| pair.split_once('='))
        .map(|(k, v)| {
            (
                urlencoding_decode(k),
                serde_json::Value::String(urlencoding_decode(v)),
            )
        })
        .collect();
    serde_json::from_value(serde_json::Value::Object(map))
        .map_err(|e| format!("query parse error: {e}"))
}
pub fn parse_json_body<T: serde::de::DeserializeOwned>(
parts: &http::request::Parts,
) -> Result<T, String> {
let body = parts
.extensions
.get::<RequestBody>()
.ok_or_else(|| "No request body".to_owned())?;
serde_json::from_slice(&body.0).map_err(|e| format!("Invalid JSON: {e}"))
}
pub fn parse_form_body(parts: &http::request::Parts) -> Result<Vec<(String, String)>, String> {
let body = parts
.extensions
.get::<RequestBody>()
.ok_or_else(|| "No request body".to_owned())?;
let s = std::str::from_utf8(&body.0).map_err(|e| format!("invalid UTF-8: {e}"))?;
Ok(s.split('&')
.filter(|p| !p.is_empty())
.map(|pair| {
let (k, v) = pair.split_once('=').unwrap_or((pair, ""));
(urlencoding_decode(k), urlencoding_decode(v))
})
.collect())
}
/// Returns the value of form field `name`, if present.
///
/// When the field appears more than once, the last occurrence wins.
#[must_use]
pub fn form_field<'a>(form: &'a [(String, String)], name: &str) -> Option<&'a str> {
    for (key, value) in form.iter().rev() {
        if key == name {
            return Some(value.as_str());
        }
    }
    None
}
/// Reports whether checkbox `name` was submitted.
///
/// Only the presence of the field counts; its value (even empty) is ignored.
#[must_use]
pub fn form_checkbox(form: &[(String, String)], name: &str) -> bool {
    form.iter().map(|(key, _)| key).any(|key| key == name)
}
/// Percent-decodes `input` using `application/x-www-form-urlencoded` rules.
///
/// See [`urlencoding_decode`] for the exact behaviour.
#[must_use]
pub fn decode_percent(input: &str) -> String {
    urlencoding_decode(input)
}

/// Percent-decodes `input` using `application/x-www-form-urlencoded` rules.
///
/// * `%XY`, where both `X` and `Y` are hex digits, becomes the byte `0xXY`.
/// * `+` becomes a space.
/// * A `%` that is not followed by two hex digits (truncated or malformed
///   escape such as `100%off` or `%zz`) is copied through literally. This
///   follows the WHATWG URL percent-decode algorithm. Such escapes are not
///   decoded into an arbitrary byte.
///
/// The decoded bytes are interpreted as UTF-8, and invalid sequences are
/// replaced with U+FFFD.
fn urlencoding_decode(input: &str) -> String {
    let bytes = input.as_bytes();
    let mut out: Vec<u8> = Vec::with_capacity(bytes.len());
    let mut i = 0;
    while i < bytes.len() {
        match bytes[i] {
            b'%' => {
                // `get` yields `None` when fewer than two bytes follow the `%`.
                let decoded = bytes
                    .get(i + 1..i + 3)
                    .and_then(|pair| Some((hex_val(pair[0])? << 4) | hex_val(pair[1])?));
                if let Some(byte) = decoded {
                    out.push(byte);
                    i += 3;
                } else {
                    out.push(b'%');
                    i += 1;
                }
            }
            b'+' => {
                out.push(b' ');
                i += 1;
            }
            other => {
                out.push(other);
                i += 1;
            }
        }
    }
    String::from_utf8_lossy(&out).into_owned()
}

/// Returns the numeric value of an ASCII hex digit, or `None` for any other byte.
fn hex_val(b: u8) -> Option<u8> {
    match b {
        b'0'..=b'9' => Some(b - b'0'),
        b'a'..=b'f' => Some(b - b'a' + 10),
        b'A'..=b'F' => Some(b - b'A' + 10),
        _ => None,
    }
}
/// Builds an HTML response with the given status.
///
/// `body` is sent as-is (no escaping is performed) with
/// `content-type: text/html; charset=utf-8`.
#[must_use]
pub fn html_response(status: StatusCode, body: impl Into<Bytes>) -> Response<Body> {
    let payload: Bytes = body.into();
    Response::builder()
        .header("content-type", "text/html; charset=utf-8")
        .status(status)
        .body(Body::from(payload))
        .unwrap_or_default()
}
/// Serializes `value` into an `application/json` response with the given status.
///
/// Falls back to the body `{}` if serialization fails.
#[must_use]
pub fn json_response(status: StatusCode, value: &serde_json::Value) -> Response<Body> {
    let text = match serde_json::to_string(value) {
        Ok(serialized) => serialized,
        Err(_) => String::from("{}"),
    };
    let mut builder = Response::builder();
    builder = builder.status(status);
    builder = builder.header("content-type", "application/json; charset=utf-8");
    builder.body(Body::from(Bytes::from(text))).unwrap_or_default()
}
/// Builds a "problem details" response (RFC 9457, formerly RFC 7807).
///
/// The JSON body carries `type`, `title`, `status` (the numeric code of
/// `status`) and `detail`. It is sent as `application/problem+json`, and
/// falls back to `{}` if serialization fails.
#[must_use]
pub fn problem_response(
    status: StatusCode,
    error_type: &str,
    title: &str,
    detail: &str,
) -> Response<Body> {
    let problem = serde_json::json!({
        "type": error_type,
        "title": title,
        "status": status.as_u16(),
        "detail": detail,
    });
    let serialized = match serde_json::to_string(&problem) {
        Ok(text) => text,
        Err(_) => String::from("{}"),
    };
    let payload = Bytes::from(serialized);
    Response::builder()
        .header("content-type", "application/problem+json; charset=utf-8")
        .status(status)
        .body(Body::from(payload))
        .unwrap_or_default()
}
/// Builds a `303 See Other` response redirecting to `location`.
///
/// HTTP header values may only carry visible ASCII (plus space and tab).
/// Every other byte of `location` is therefore percent-encoded before the
/// `Location` header is set; this covers non-ASCII UTF-8 and control
/// characters. Existing `%XX` escapes are left untouched, so already-encoded
/// locations pass through unchanged.
#[must_use]
pub fn redirect(location: &str) -> Response<Body> {
    const HEX: &[u8; 16] = b"0123456789ABCDEF";

    // Without this step a location such as `/docs/Grüß` makes the builder fail,
    // and `unwrap_or_default` would then return an empty `200 OK` instead of
    // a redirect.
    let mut encoded = String::with_capacity(location.len());
    for &byte in location.as_bytes() {
        if byte == b'\t' || (0x20..0x7F).contains(&byte) {
            encoded.push(char::from(byte));
        } else {
            encoded.push('%');
            encoded.push(char::from(HEX[usize::from(byte >> 4)]));
            encoded.push(char::from(HEX[usize::from(byte & 0x0F)]));
        }
    }
    Response::builder()
        .status(StatusCode::SEE_OTHER)
        .header("location", encoded)
        .body(Body::default())
        .unwrap_or_default()
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_path_param_found() {
        let params = vec![("id".to_owned(), "ndaal-sa-2026-001".to_owned())];
        assert_eq!(
            path_param(&params, "id"),
            Some("ndaal-sa-2026-001".to_owned())
        );
    }

    #[test]
    fn test_path_param_not_found() {
        let params: Vec<(String, String)> = vec![];
        assert_eq!(path_param(&params, "id"), None);
    }

    #[test]
    fn test_json_response_content_type() {
        let resp = json_response(StatusCode::OK, &serde_json::json!({"key": "value"}));
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(
            resp.headers().get("content-type").unwrap(),
            "application/json; charset=utf-8"
        );
    }

    #[test]
    fn test_problem_response_content_type() {
        let resp = problem_response(
            StatusCode::NOT_FOUND,
            "https://ndaal.eu/csaf/errors/not-found",
            "Not Found",
            "Document not found",
        );
        assert_eq!(resp.status(), StatusCode::NOT_FOUND);
        assert_eq!(
            resp.headers().get("content-type").unwrap(),
            "application/problem+json; charset=utf-8"
        );
    }

    #[test]
    fn test_html_response() {
        let resp = html_response(StatusCode::OK, "<h1>Test</h1>");
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(
            resp.headers().get("content-type").unwrap(),
            "text/html; charset=utf-8"
        );
    }

    #[test]
    fn test_redirect() {
        let resp = redirect("/csaf");
        assert_eq!(resp.status(), StatusCode::SEE_OTHER);
        assert_eq!(resp.headers().get("location").unwrap(), "/csaf");
    }

    #[test]
    fn test_urlencoding_decode() {
        assert_eq!(urlencoding_decode("hello%20world"), "hello world");
        assert_eq!(urlencoding_decode("a+b"), "a b");
        assert_eq!(urlencoding_decode("%2F"), "/");
        assert_eq!(urlencoding_decode("plain"), "plain");
    }

    #[test]
    fn test_urlencoding_decode_literal_utf8_passthrough() {
        assert_eq!(urlencoding_decode("Ú"), "Ú");
        assert_eq!(urlencoding_decode("Á&Ú"), "Á&Ú");
        assert_eq!(urlencoding_decode("ɉ"), "ɉ");
        assert_eq!(urlencoding_decode("カフェ"), "カフェ");
    }

    #[test]
    fn test_urlencoding_decode_percent_escaped_utf8() {
        assert_eq!(urlencoding_decode("Gr%C3%BC%C3%9F"), "Grüß");
        assert_eq!(urlencoding_decode("caf%C3%A9"), "café");
    }

    #[test]
    fn test_urlencoding_decode_mixed_literal_and_escaped() {
        assert_eq!(urlencoding_decode("Ü%2FÖ"), "Ü/Ö");
        assert_eq!(urlencoding_decode("a+b+%C3%B6"), "a b ö");
    }

    #[test]
    fn test_urlencoding_decode_total_on_truncated_escape() {
        assert_eq!(urlencoding_decode("abc%"), "abc%");
        assert_eq!(urlencoding_decode("abc%2"), "abc%2");
        assert_eq!(urlencoding_decode("%"), "%");
    }

    #[test]
    fn test_form_field_last_value_wins() {
        let form = vec![
            ("tag".to_owned(), "a".to_owned()),
            ("tag".to_owned(), "b".to_owned()),
        ];
        assert_eq!(form_field(&form, "tag"), Some("b"));
        assert_eq!(form_field(&form, "missing"), None);
    }

    #[test]
    fn test_form_checkbox_present_with_empty_value() {
        let form = vec![("agree".to_owned(), String::new())];
        assert!(form_checkbox(&form, "agree"));
        assert!(!form_checkbox(&form, "other"));
    }

    #[test]
    fn test_parse_form_body_decodes_pairs() {
        let (mut parts, ()) = http::Request::new(()).into_parts();
        parts
            .extensions
            .insert(RequestBody(Bytes::from_static(b"a=1&b+c=x%20y&&flag")));
        let form = parse_form_body(&parts).unwrap();
        assert_eq!(
            form,
            vec![
                ("a".to_owned(), "1".to_owned()),
                ("b c".to_owned(), "x y".to_owned()),
                ("flag".to_owned(), String::new()),
            ]
        );
    }

    #[test]
    fn test_parse_json_body_without_body() {
        let (parts, ()) = http::Request::new(()).into_parts();
        assert_eq!(
            parse_json_body::<serde_json::Value>(&parts).unwrap_err(),
            "No request body"
        );
    }
}