use std::future::Future;
use std::sync::Arc;
use cot::Error;
use cot::error::ErrorRepr;
use cot::request::{PathParams, Request};
use http::request::Parts;
use serde::de::DeserializeOwned;
use crate::Method;
use crate::auth::Auth;
use crate::form::{Form, FormResult};
#[cfg(feature = "json")]
use crate::json::Json;
use crate::request::RequestExt;
use crate::router::Urls;
use crate::session::Session;
/// Trait for types that can be constructed from an entire [`Request`].
///
/// The request is passed by value, so an implementor is free to consume it
/// (including its body). [`FromRequestParts`] is the counterpart that only
/// receives the request head ([`Parts`]).
pub trait FromRequest: Sized {
    /// Extracts `Self` from `request`.
    ///
    /// The returned future must be `Send`.
    ///
    /// # Errors
    ///
    /// Returns an error when extraction fails; the exact conditions are
    /// specific to each implementor.
    fn from_request(request: Request) -> impl Future<Output = cot::Result<Self>> + Send;
}
/// A whole [`Request`] is trivially an extractor: it is handed back unchanged.
impl FromRequest for Request {
    async fn from_request(request: Request) -> cot::Result<Self> {
        // Nothing to parse or validate — the input already is the output.
        let unchanged: Self = request;
        Ok(unchanged)
    }
}
/// Trait for types that can be extracted from the head of a request.
///
/// Implementors receive a mutable borrow of the request's [`Parts`] (method,
/// URI, headers, extensions, …) rather than the whole [`Request`], so the
/// body is not available to them.
pub trait FromRequestParts: Sized {
    /// Extracts `Self` from the request head `parts`.
    ///
    /// The returned future must be `Send`.
    ///
    /// # Errors
    ///
    /// Returns an error when extraction fails; the exact conditions are
    /// specific to each implementor.
    fn from_request_parts(parts: &mut Parts) -> impl Future<Output = cot::Result<Self>> + Send;
}
impl FromRequestParts for Urls {
async fn from_request_parts(parts: &mut Parts) -> cot::Result<Self> {
Ok(Self::from_parts(parts))
}
}
/// Extractor that deserializes the path parameters captured for the current
/// route into a value of type `D`.
///
/// The wrapped value is public, so it can be destructured directly:
/// `let Path(params) = ...;`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct Path<D>(pub D);
impl<D: DeserializeOwned> FromRequestParts for Path<D> {
async fn from_request_parts(parts: &mut Parts) -> cot::Result<Self> {
let params = parts
.extensions
.get::<PathParams>()
.expect("PathParams extension missing")
.parse()
.map_err(|error| Error::new(ErrorRepr::PathParametersParse(error)))?;
Ok(Self(params))
}
}
/// Extractor that deserializes the URL query string into a value of type `T`.
///
/// The wrapped value is public, so it can be destructured directly:
/// `let UrlQuery(query) = ...;`.
#[derive(Debug, Default, Clone, Copy)]
pub struct UrlQuery<T>(pub T);
// The `DeserializeOwned` bound is declared once, in the `where` clause; the
// previous version repeated it inline as well (Clippy `multiple_bound_locations`).
impl<D> FromRequestParts for UrlQuery<D>
where
    D: DeserializeOwned,
{
    /// Deserializes the request URI's query string into `D`.
    ///
    /// A URI without a query string is treated as having an empty one, so
    /// types that can be built from no fields still extract successfully.
    /// The query is split into `application/x-www-form-urlencoded` pairs and
    /// fed to `serde_html_form`.
    ///
    /// # Errors
    ///
    /// Returns an error wrapping `ErrorRepr::QueryParametersParse` if the
    /// query cannot be deserialized into `D`.
    async fn from_request_parts(parts: &mut Parts) -> cot::Result<Self> {
        let query = parts.uri.query().unwrap_or_default();
        let deserializer =
            serde_html_form::Deserializer::new(form_urlencoded::parse(query.as_bytes()));
        // `serde_path_to_error` records which field failed, for better diagnostics.
        let value = serde_path_to_error::deserialize(deserializer)
            .map_err(|error| Error::new(ErrorRepr::QueryParametersParse(error)))?;
        Ok(Self(value))
    }
}
#[cfg(feature = "json")]
impl<D: DeserializeOwned> FromRequest for Json<D> {
    /// Reads the request body and deserializes it as JSON into `D`.
    ///
    /// # Errors
    ///
    /// Returns an error:
    /// - if `expect_content_type` rejects the request's content type for
    ///   `cot::headers::JSON_CONTENT_TYPE`;
    /// - if reading the body fails;
    /// - if the body cannot be deserialized into `D` (wrapped in
    ///   `ErrorRepr::Json`, produced through `serde_path_to_error`).
    async fn from_request(mut request: Request) -> cot::Result<Self> {
        request.expect_content_type(cot::headers::JSON_CONTENT_TYPE)?;

        // Move the body out of the request, leaving a default body behind.
        let taken_body = std::mem::take(request.body_mut());
        let raw_json = taken_body.into_bytes().await?;

        // NOTE(review): unlike `serde_json::from_slice`, this never calls
        // `Deserializer::end`, so trailing non-whitespace after the first JSON
        // value does not appear to be rejected — confirm whether that is intended.
        let mut json_deserializer = serde_json::Deserializer::from_slice(&raw_json);
        let payload = serde_path_to_error::deserialize(&mut json_deserializer)
            .map_err(|json_error| Error::new(ErrorRepr::Json(json_error)))?;

        Ok(Self(payload))
    }
}
/// Extractor that parses the request as the form `F`.
///
/// It holds the [`FormResult`] produced by `F::from_request`, so the handler
/// can inspect the form outcome itself.
#[derive(Debug)]
pub struct RequestForm<F: Form>(pub FormResult<F>);
impl<F: Form> FromRequest for RequestForm<F> {
    /// Runs `F::from_request` on the request and wraps its form result.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `F::from_request`.
    async fn from_request(mut request: Request) -> cot::Result<Self> {
        let form_result = F::from_request(&mut request).await?;
        Ok(Self(form_result))
    }
}
/// Extractor that yields a shared handle to the request's database.
///
/// Only available with the `db` feature.
#[cfg(feature = "db")]
#[derive(Debug)]
pub struct RequestDb(pub Arc<crate::db::Database>);
#[cfg(feature = "db")]
impl FromRequestParts for RequestDb {
    /// Clones the database handle returned by `parts.db()`.
    async fn from_request_parts(parts: &mut Parts) -> cot::Result<Self> {
        let database = parts.db().clone();
        Ok(Self(database))
    }
}
/// Extractor giving handlers read access to the project's static files.
///
/// Cloning is cheap: only the shared [`Arc`] is cloned.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StaticFiles {
    /// The shared static-files store taken from the request extensions.
    inner: Arc<crate::static_files::StaticFiles>,
}
impl StaticFiles {
    /// Looks up `path` in the underlying static-files store and returns the
    /// string it maps to (via `path_for`), intended for use as the file's URL.
    ///
    /// # Errors
    ///
    /// Returns [`StaticFilesGetError::NotFound`] carrying `path` when the
    /// store has no entry for it.
    pub fn url_for(&self, path: &str) -> Result<&str, StaticFilesGetError> {
        match self.inner.path_for(path) {
            Some(url) => Ok(url),
            None => Err(StaticFilesGetError::NotFound {
                path: path.to_owned(),
            }),
        }
    }
}
/// Error returned by [`StaticFiles::url_for`].
#[derive(Debug, Clone, PartialEq, Eq, Hash, thiserror::Error)]
#[non_exhaustive]
pub enum StaticFilesGetError {
    /// No static file is registered under the requested path.
    #[error("static file `{path}` not found")]
    NotFound {
        /// The path that was looked up.
        path: String,
    },
}
impl FromRequestParts for StaticFiles {
    /// Clones the shared static-files store out of the request extensions.
    ///
    /// # Panics
    ///
    /// Panics if the request extensions hold no
    /// `Arc<crate::static_files::StaticFiles>`; the panic message points at
    /// a missing `StaticFilesMiddleware`.
    async fn from_request_parts(parts: &mut Parts) -> cot::Result<Self> {
        let inner = parts
            .extensions
            .get::<Arc<crate::static_files::StaticFiles>>()
            .map(Arc::clone)
            .expect("StaticFilesMiddleware not enabled for the route/project");
        Ok(Self { inner })
    }
}
/// Extracts the HTTP method of the request.
impl FromRequestParts for Method {
    async fn from_request_parts(parts: &mut Parts) -> cot::Result<Self> {
        let method = &parts.method;
        Ok(method.clone())
    }
}
/// Extracts a clone of the [`Session`] obtained from the request extensions
/// through `Session::from_extensions`.
impl FromRequestParts for Session {
    async fn from_request_parts(parts: &mut Parts) -> cot::Result<Self> {
        let session = Session::from_extensions(&parts.extensions);
        Ok(session.clone())
    }
}
impl FromRequestParts for Auth {
    /// Clones the [`Auth`] value stored in the request extensions.
    ///
    /// # Panics
    ///
    /// Panics if the request extensions hold no [`Auth`]; the panic message
    /// points at a missing `AuthMiddleware`.
    async fn from_request_parts(parts: &mut Parts) -> cot::Result<Self> {
        let stored: &Auth = parts
            .extensions
            .get()
            .expect("AuthMiddleware not enabled for the route/project");
        Ok(stored.clone())
    }
}
#[cfg(test)]
mod tests {
    use serde::Deserialize;

    use super::*;
    use crate::html::Html;
    // `Json` is only in scope in the parent module when the `json` feature is
    // enabled, so its import must be gated like the tests that use it;
    // otherwise this module fails to compile with that feature turned off.
    #[cfg(feature = "json")]
    use crate::request::extractors::Json;
    use crate::request::extractors::{FromRequest, Path, UrlQuery};
    use crate::router::{Route, Router, Urls};
    use crate::test::TestRequestBuilder;
    use crate::{Body, reverse};

    #[cfg(feature = "json")]
    #[cot::test]
    async fn json() {
        let request = http::Request::builder()
            .method(http::Method::POST)
            .header(http::header::CONTENT_TYPE, cot::headers::JSON_CONTENT_TYPE)
            .body(Body::fixed(r#"{"hello":"world"}"#))
            .unwrap();

        let Json(data): Json<serde_json::Value> = Json::from_request(request).await.unwrap();
        assert_eq!(data, serde_json::json!({"hello": "world"}));
    }

    #[cfg(feature = "json")]
    #[cot::test]
    async fn json_empty() {
        #[derive(Debug, Deserialize, PartialEq, Eq)]
        struct TestData {}

        let request = http::Request::builder()
            .method(http::Method::POST)
            .header(http::header::CONTENT_TYPE, cot::headers::JSON_CONTENT_TYPE)
            .body(Body::fixed("{}"))
            .unwrap();

        let Json(data): Json<TestData> = Json::from_request(request).await.unwrap();
        assert_eq!(data, TestData {});
    }

    #[cfg(feature = "json")]
    #[cot::test]
    async fn json_struct() {
        #[derive(Debug, Deserialize, PartialEq, Eq)]
        struct TestDataInner {
            hello: String,
        }

        #[derive(Debug, Deserialize, PartialEq, Eq)]
        struct TestData {
            inner: TestDataInner,
        }

        let request = http::Request::builder()
            .method(http::Method::POST)
            .header(http::header::CONTENT_TYPE, cot::headers::JSON_CONTENT_TYPE)
            .body(Body::fixed(r#"{"inner":{"hello":"world"}}"#))
            .unwrap();

        let Json(data): Json<TestData> = Json::from_request(request).await.unwrap();
        assert_eq!(
            data,
            TestData {
                inner: TestDataInner {
                    hello: "world".to_string(),
                }
            }
        );
    }

    #[cot::test]
    async fn path_extraction() {
        #[derive(Deserialize, Debug, PartialEq)]
        struct TestParams {
            id: i32,
            name: String,
        }

        let (mut parts, _body) = Request::new(Body::empty()).into_parts();

        let mut params = PathParams::new();
        params.insert("id".to_string(), "42".to_string());
        params.insert("name".to_string(), "test".to_string());
        parts.extensions.insert(params);

        let Path(extracted): Path<TestParams> = Path::from_request_parts(&mut parts).await.unwrap();
        let expected = TestParams {
            id: 42,
            name: "test".to_string(),
        };

        assert_eq!(extracted, expected);
    }

    #[cot::test]
    async fn url_query_extraction() {
        #[derive(Deserialize, Debug, PartialEq)]
        struct QueryParams {
            page: i32,
            filter: String,
        }

        let (mut parts, _body) = Request::new(Body::empty()).into_parts();
        parts.uri = "https://example.com/?page=2&filter=active".parse().unwrap();

        let UrlQuery(query): UrlQuery<QueryParams> =
            UrlQuery::from_request_parts(&mut parts).await.unwrap();

        assert_eq!(query.page, 2);
        assert_eq!(query.filter, "active");
    }

    #[cot::test]
    async fn url_query_empty() {
        #[derive(Deserialize, Debug, PartialEq)]
        struct EmptyParams {}

        let (mut parts, _body) = Request::new(Body::empty()).into_parts();
        parts.uri = "https://example.com/".parse().unwrap();

        // A URI without a query string must still deserialize into a
        // field-less type; check the actual value instead of a tautological
        // pattern match.
        let UrlQuery(params): UrlQuery<EmptyParams> =
            UrlQuery::from_request_parts(&mut parts).await.unwrap();
        assert_eq!(params, EmptyParams {});
    }

    #[cfg(feature = "json")]
    #[cot::test]
    async fn json_invalid_content_type() {
        let request = http::Request::builder()
            .method(http::Method::POST)
            .header(http::header::CONTENT_TYPE, "text/plain")
            .body(Body::fixed(r#"{"hello":"world"}"#))
            .unwrap();

        let result = Json::<serde_json::Value>::from_request(request).await;
        assert!(result.is_err());
    }

    #[cot::test]
    async fn request_form() {
        #[derive(Debug, PartialEq, Eq, Form)]
        struct MyForm {
            hello: String,
            foo: String,
        }

        let request = TestRequestBuilder::post("/")
            .form_data(&[("hello", "world"), ("foo", "bar")])
            .build();

        let RequestForm(form_result): RequestForm<MyForm> =
            RequestForm::from_request(request).await.unwrap();

        assert_eq!(
            form_result.unwrap(),
            MyForm {
                hello: "world".to_string(),
                foo: "bar".to_string(),
            }
        );
    }

    #[cot::test]
    async fn urls_extraction() {
        async fn handler() -> Html {
            Html::new("")
        }

        let router = Router::with_urls([Route::with_handler_and_name(
            "/test/",
            handler,
            "test_route",
        )]);

        let mut request = TestRequestBuilder::get("/test/").router(router).build();

        let urls: Urls = request.extract_parts().await.unwrap();

        assert!(reverse!(urls, "test_route").is_ok());
    }

    #[cot::test]
    async fn method_extraction() {
        let mut request = TestRequestBuilder::get("/test/").build();

        let method: Method = request.extract_parts().await.unwrap();

        assert_eq!(method, Method::GET);
    }

    #[cfg(feature = "db")]
    #[cot::test]
    #[cfg_attr(miri, ignore)]
    async fn request_db() {
        let db = crate::test::TestDatabase::new_sqlite().await.unwrap();
        let mut test_request = TestRequestBuilder::get("/").database(db.database()).build();

        let RequestDb(extracted_db) = test_request.extract_parts().await.unwrap();

        // check that we have a connection to the database
        extracted_db.close().await.unwrap();
    }
}