use async_trait::async_trait;
use serde::de::DeserializeOwned;
use std::collections::HashMap;
use crate::{Error, Request, Result};
#[async_trait]
pub trait FromRequest: Sized {
async fn from_request(req: &mut Request) -> Result<Self>;
}
pub struct Path<T>(pub T);
#[async_trait]
impl<T> FromRequest for Path<T>
where
T: DeserializeOwned + Send,
{
async fn from_request(req: &mut Request) -> Result<Self> {
let params_map: HashMap<String, serde_json::Value> = req
.params()
.iter()
.map(|(k, v)| {
let value = if let Ok(num) = v.parse::<i64>() {
serde_json::Value::Number(num.into())
} else if let Ok(num) = v.parse::<u64>() {
serde_json::Value::Number(num.into())
} else if let Ok(num) = v.parse::<f64>() {
serde_json::Number::from_f64(num)
.map(serde_json::Value::Number)
.unwrap_or_else(|| serde_json::Value::String(v.clone()))
} else if let Ok(b) = v.parse::<bool>() {
serde_json::Value::Bool(b)
} else {
serde_json::Value::String(v.clone())
};
(k.clone(), value)
})
.collect();
let value = serde_json::Value::Object(
params_map
.into_iter()
.map(|(k, v)| (k, v))
.collect(),
);
let params: T = serde_json::from_value(value)
.map_err(|e| Error::BadRequest(format!("Failed to parse path parameters: {}", e)))?;
Ok(Path(params))
}
}
pub struct Json<T>(pub T);
#[async_trait]
impl<T> FromRequest for Json<T>
where
T: DeserializeOwned + Send,
{
async fn from_request(req: &mut Request) -> Result<Self> {
let value = req.json().await?;
Ok(Json(value))
}
}
pub struct Query<T>(pub T);
#[async_trait]
impl<T> FromRequest for Query<T>
where
T: DeserializeOwned + Send,
{
async fn from_request(req: &mut Request) -> Result<Self> {
let query_map: HashMap<String, serde_json::Value> = req
.query_params()
.iter()
.map(|(k, v)| {
let value = if let Ok(num) = v.parse::<i64>() {
serde_json::Value::Number(num.into())
} else if let Ok(num) = v.parse::<f64>() {
serde_json::Number::from_f64(num)
.map(serde_json::Value::Number)
.unwrap_or_else(|| serde_json::Value::String(v.clone()))
} else if let Ok(b) = v.parse::<bool>() {
serde_json::Value::Bool(b)
} else {
serde_json::Value::String(v.clone())
};
(k.clone(), value)
})
.collect();
let value = serde_json::Value::Object(
query_map
.into_iter()
.map(|(k, v)| (k, v))
.collect(),
);
let params: T = serde_json::from_value(value)
.map_err(|e| Error::BadRequest(format!("Failed to parse query parameters: {}", e)))?;
Ok(Query(params))
}
}
#[cfg(test)]
mod tests {
use super::*;
use bytes::Bytes;
use http::{HeaderMap, Method, Uri};
use serde::Deserialize;
#[derive(Deserialize, Debug, PartialEq)]
struct TestParams {
id: u32,
name: String,
}
#[tokio::test]
async fn test_path_extractor() {
let mut req = Request::new(
Method::GET,
Uri::from_static("/users/123"),
HeaderMap::new(),
Bytes::new(),
);
req.set_param("id".to_string(), "123".to_string());
req.set_param("name".to_string(), "john".to_string());
let Path(params): Path<TestParams> = Path::from_request(&mut req).await.unwrap();
assert_eq!(params.id, 123);
assert_eq!(params.name, "john");
}
#[tokio::test]
async fn test_json_extractor() {
#[derive(Deserialize, Debug, PartialEq)]
struct User {
name: String,
age: u32,
}
let json_body = r#"{"name":"Alice","age":30}"#;
let mut req = Request::new(
Method::POST,
Uri::from_static("/users"),
HeaderMap::new(),
Bytes::from(json_body),
);
let Json(user): Json<User> = Json::from_request(&mut req).await.unwrap();
assert_eq!(user.name, "Alice");
assert_eq!(user.age, 30);
}
#[tokio::test]
async fn test_query_extractor() {
#[derive(Deserialize, Debug, PartialEq)]
struct Pagination {
page: u32,
limit: u32,
}
let mut req = Request::new(
Method::GET,
Uri::from_static("/users?page=1&limit=10"),
HeaderMap::new(),
Bytes::new(),
);
let Query(params): Query<Pagination> = Query::from_request(&mut req).await.unwrap();
assert_eq!(params.page, 1);
assert_eq!(params.limit, 10);
}
#[tokio::test]
async fn test_path_extractor_with_dash() {
#[derive(Deserialize, Debug, PartialEq)]
struct PathParams {
name: String,
id: u32,
}
let mut req = Request::new(
Method::GET,
Uri::from_static("/users/-/0"),
HeaderMap::new(),
Bytes::new(),
);
req.set_param("name".to_string(), "-".to_string());
req.set_param("id".to_string(), "0".to_string());
let Path(params) = Path::<PathParams>::from_request(&mut req).await.unwrap();
assert_eq!(params.name, "-");
assert_eq!(params.id, 0);
}
}