use crate::error::{Error, Result};
use crate::types::{OxiditeRequest, OxiditeResponse as Response};
use serde::de::DeserializeOwned;
use http_body_util::BodyExt;
/// Extractor for path parameters, read from the [`PathParams`] request extension and
/// deserialized into `T`.
pub struct Path<T>(pub T);
/// Extractor for the URL query string, deserialized into `T` with `serde_urlencoded`.
pub struct Query<T>(pub T);
/// Extractor for a JSON request body deserialized into `T`.
///
/// Can also be turned into a response body with `Json::into_response` when `T: Serialize`.
pub struct Json<T>(pub T);
/// Types that can be constructed from an incoming request.
///
/// The request is passed as `&mut` so that extractors which read the body (`Json`, `Form`,
/// `Body`) can drain it through `body_mut()`. The returned future must be `Send`.
///
/// NOTE(review): because body extractors drain the body, presumably only one of them can
/// succeed per request — confirm how the handler layer orders extractors.
pub trait FromRequest: Sized {
/// Attempts to build `Self` from `req`, returning an `Error` when the request does not match.
fn from_request(req: &mut OxiditeRequest) -> impl std::future::Future<Output = Result<Self>> + Send;
}
impl<T: DeserializeOwned + Send> FromRequest for Path<T> {
async fn from_request(req: &mut OxiditeRequest) -> Result<Self> {
req.extensions()
.get::<PathParams>()
.ok_or_else(|| Error::BadRequest("No path parameters found".to_string()))
.and_then(|params| {
serde_json::from_value(params.0.clone())
.map(Path)
.map_err(|e| Error::BadRequest(format!("Invalid path parameters: {}", e)))
})
}
}
impl<T: DeserializeOwned + Send> FromRequest for Query<T> {
async fn from_request(req: &mut OxiditeRequest) -> Result<Self> {
let query = req.uri().query().unwrap_or("");
serde_urlencoded::from_str(query)
.map(Query)
.map_err(|e| Error::BadRequest(format!("Invalid query parameters: {}", e)))
}
}
impl<T: DeserializeOwned + Send> FromRequest for Json<T> {
    /// Buffers the whole request body and deserializes it as JSON into `T`.
    ///
    /// The `Content-Type` header is not inspected.
    ///
    /// # Errors
    ///
    /// - `Error::InternalServerError` if reading the body fails.
    /// - `Error::BadRequest` if the body is not valid JSON for `T`.
    async fn from_request(req: &mut OxiditeRequest) -> Result<Self> {
        use http_body_util::BodyExt;

        // `to_bytes` yields one contiguous buffer holding every body frame, which lets us use
        // `serde_json::from_slice` (faster than `from_reader` over a `Buf` reader).
        let bytes = req
            .body_mut()
            .collect()
            .await
            .map_err(|e| Error::InternalServerError(format!("Failed to read body: {}", e)))?
            .to_bytes();
        serde_json::from_slice(&bytes)
            .map(Json)
            .map_err(|e| Error::BadRequest(format!("Invalid JSON: {}", e)))
    }
}
/// Path parameters held as a JSON value in the request extensions; read by the `Path`
/// extractor.
///
/// NOTE(review): presumably inserted by the router after route matching — the producer is
/// not visible in this file.
#[derive(Clone)]
pub struct PathParams(pub serde_json::Value);
impl<T: serde::Serialize> Json<T> {
pub fn into_response(self) -> Result<http_body_util::Full<bytes::Bytes>> {
let body = serde_json::to_vec(&self.0)
.map_err(|e| Error::InternalServerError(format!("Failed to serialize JSON: {}", e)))?;
Ok(http_body_util::Full::new(bytes::Bytes::from(body)))
}
}
/// Extractor for shared application state of type `T`.
///
/// The value is cloned out of the extensions on every extraction, so `T` should be cheap to
/// clone (for example an `Arc`).
pub struct State<T>(pub T);

impl<T: Clone + Send + Sync + 'static> FromRequest for State<T> {
    /// Looks `T` up in the request extensions first, then in a shared
    /// `Arc<RwLock<http::Extensions>>` stored in the request extensions.
    ///
    /// # Errors
    ///
    /// Returns `Error::InternalServerError` naming the type when `T` is found in neither place.
    async fn from_request(req: &mut OxiditeRequest) -> Result<Self> {
        if let Some(state) = req.extensions().get::<T>() {
            return Ok(State(state.clone()));
        }

        if let Some(router_exts) = req
            .extensions()
            .get::<std::sync::Arc<std::sync::RwLock<http::Extensions>>>()
        {
            // This is a pure read. A panic elsewhere while holding the write lock does not make
            // reading unsafe, so recover the guard rather than misreporting "not found".
            let exts = router_exts
                .read()
                .unwrap_or_else(std::sync::PoisonError::into_inner);
            if let Some(state) = exts.get::<T>() {
                return Ok(State(state.clone()));
            }
        }

        Err(Error::InternalServerError(format!(
            "Application state of type {} not found in request or router extensions",
            std::any::type_name::<T>()
        )))
    }
}
/// Extractor for an `application/x-www-form-urlencoded` request body deserialized into `T`.
pub struct Form<T>(pub T);

impl<T: DeserializeOwned + Send> FromRequest for Form<T> {
    /// Checks the content type, buffers the whole body, and deserializes it as URL-encoded
    /// form data.
    ///
    /// # Errors
    ///
    /// - `Error::BadRequest` if the `Content-Type` is not `application/x-www-form-urlencoded`,
    ///   the body is not valid UTF-8, or it cannot be deserialized into `T`.
    /// - `Error::InternalServerError` if reading the body fails.
    async fn from_request(req: &mut OxiditeRequest) -> Result<Self> {
        use http_body_util::BodyExt;

        const FORM_MEDIA_TYPE: &str = "application/x-www-form-urlencoded";

        let content_type = req
            .headers()
            .get("content-type")
            .and_then(|ct| ct.to_str().ok())
            .unwrap_or("");
        // Media types are case-insensitive (RFC 9110 §8.3.1). Parameters such as
        // `; charset=...` may follow, so only the leading media type is compared.
        let is_form = content_type
            .get(..FORM_MEDIA_TYPE.len())
            .is_some_and(|prefix| prefix.eq_ignore_ascii_case(FORM_MEDIA_TYPE));
        if !is_form {
            return Err(Error::BadRequest(
                "Expected application/x-www-form-urlencoded content type".to_string(),
            ));
        }

        // `to_bytes` concatenates every body frame into one buffer. `Buf::chunk` on an
        // aggregated body would expose only the first frame and silently truncate the form.
        let bytes = req
            .body_mut()
            .collect()
            .await
            .map_err(|e| Error::InternalServerError(format!("Failed to read body: {}", e)))?
            .to_bytes();
        let body_str = std::str::from_utf8(&bytes)
            .map_err(|e| Error::BadRequest(format!("Invalid UTF-8 in form data: {}", e)))?;
        serde_urlencoded::from_str(body_str)
            .map(Form)
            .map_err(|e| Error::BadRequest(format!("Invalid form data: {}", e)))
    }
}
/// Cookies sent by the client, keyed by cookie name.
pub struct Cookies {
    // Cookie name -> cookie value.
    cookies: std::collections::HashMap<String, String>,
}

impl Cookies {
    /// Returns the value of the cookie called `name`, if the client sent one.
    pub fn get(&self, name: &str) -> Option<&String> {
        let store = &self.cookies;
        store.get(name)
    }

    /// Returns `true` when a cookie called `name` is present.
    pub fn contains_key(&self, name: &str) -> bool {
        self.get(name).is_some()
    }

    /// Iterates over every `(name, value)` pair, in unspecified (hash-map) order.
    pub fn iter(&self) -> impl Iterator<Item = (&String, &String)> {
        let store = &self.cookies;
        store.iter()
    }
}
impl FromRequest for Cookies {
    /// Parses every `Cookie` request header into a name -> value map.
    ///
    /// Each header is split on `;`. Pairs without `=` are skipped. Names and values are
    /// whitespace-trimmed. Headers that are not visible ASCII are ignored. When a name repeats,
    /// the last occurrence wins. This never fails; a request without cookies yields an empty
    /// jar.
    async fn from_request(req: &mut OxiditeRequest) -> Result<Self> {
        // HTTP/2 allows the Cookie header to be split across several field lines
        // (RFC 9113 §8.2.3), so all of them are read rather than only the first.
        let cookies = req
            .headers()
            .get_all(http::header::COOKIE)
            .iter()
            .filter_map(|header| header.to_str().ok())
            .flat_map(|header| header.split(';'))
            .filter_map(|pair| pair.trim().split_once('='))
            .map(|(name, value)| (name.trim().to_string(), value.trim().to_string()))
            .collect();
        Ok(Cookies { cookies })
    }
}
pub struct Body<T>(pub T);
impl FromRequest for Body<String> {
async fn from_request(req: &mut OxiditeRequest) -> Result<Self> {
use http_body_util::BodyExt;
use bytes::Buf;
let body = req.body_mut();
let bytes = body.collect().await
.map_err(|e| Error::InternalServerError(format!("Failed to read body: {}", e)))?
.aggregate();
let body_str = std::str::from_utf8(bytes.chunk())
.map_err(|e| Error::InternalServerError(format!("Invalid UTF-8 in body: {}", e)))?
.to_string();
Ok(Body(body_str))
}
}
impl FromRequest for Body<Vec<u8>> {
async fn from_request(req: &mut OxiditeRequest) -> Result<Self> {
use http_body_util::BodyExt;
let body = req.body_mut();
let bytes = body.collect().await
.map_err(|e| Error::InternalServerError(format!("Failed to read body: {}", e)))?
.to_bytes();
Ok(Body(bytes.to_vec()))
}
}
/// A WebSocket handshake request, carrying the client's `Sec-WebSocket-Key`.
pub struct WebSocketUpgrade {
    /// Value of the client's `Sec-WebSocket-Key` header.
    pub key: String,
}

impl WebSocketUpgrade {
    /// Fixed GUID appended to the client key when deriving `Sec-WebSocket-Accept`
    /// (RFC 6455 §1.3).
    const HANDSHAKE_GUID: &'static [u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

    /// Builds the `101 Switching Protocols` handshake response.
    ///
    /// `Sec-WebSocket-Accept` is `base64(SHA-1(key ++ GUID))`. The body is empty.
    ///
    /// # Panics
    ///
    /// Never in practice. The status is a constant and every header value is either a fixed
    /// ASCII literal or base64 output, all of which are valid header values.
    pub fn response(&self) -> Response {
        use sha1::{Sha1, Digest};
        use base64::{Engine as _, engine::general_purpose};

        let mut hasher = Sha1::new();
        hasher.update(self.key.as_bytes());
        hasher.update(Self::HANDSHAKE_GUID);
        let accept = general_purpose::STANDARD.encode(hasher.finalize());

        let res = http::Response::builder()
            .status(http::StatusCode::SWITCHING_PROTOCOLS)
            .header(http::header::UPGRADE, "websocket")
            .header(http::header::CONNECTION, "upgrade")
            .header(http::header::SEC_WEBSOCKET_ACCEPT, accept)
            .body(crate::types::BoxBody::new(http_body_util::Empty::new().map_err(|e| match e {}).boxed()))
            .expect("101 response with constant status and base64 accept header is always valid");
        Response::new(res)
    }
}
impl FromRequest for WebSocketUpgrade {
async fn from_request(req: &mut OxiditeRequest) -> Result<Self> {
let headers = req.headers();
let upgrade = headers.get(http::header::UPGRADE).and_then(|h| h.to_str().ok());
let _connection = headers.get(http::header::CONNECTION).and_then(|h| h.to_str().ok());
let key = headers.get(http::header::SEC_WEBSOCKET_KEY).and_then(|h| h.to_str().ok());
if upgrade == Some("websocket") && key.is_some() {
Ok(WebSocketUpgrade {
key: key.unwrap().to_string(),
})
} else {
Err(Error::BadRequest("Expected WebSocket upgrade".to_string()))
}
}
}