use std::future::Future;
use std::sync::Arc;
use serde::de::DeserializeOwned;
use crate::auth::Auth;
use crate::form::{Form, FormResult};
#[cfg(feature = "json")]
use crate::json::Json;
use crate::request::{InvalidContentType, PathParams, Request, RequestExt, RequestHead};
use crate::router::Urls;
use crate::session::Session;
use crate::{Body, Method};
/// Trait for types that can be extracted from a complete request.
///
/// An implementation receives a borrowed [`RequestHead`] together with the
/// owned request [`Body`] and produces `Self` asynchronously.
pub trait FromRequest: Sized {
/// Extracts `Self` from the request head and body.
///
/// The returned future is required to be [`Send`].
///
/// # Errors
///
/// Returns an error if the value cannot be extracted from the request.
fn from_request(
head: &RequestHead,
body: Body,
) -> impl Future<Output = cot::Result<Self>> + Send;
}
impl FromRequest for Request {
    /// Rebuilds the whole request from a copy of the head and the given body.
    ///
    /// The head is cloned because only a shared reference to it is available.
    async fn from_request(head: &RequestHead, body: Body) -> cot::Result<Self> {
        let owned_head = head.clone();
        let request = Request::from_parts(owned_head, body);
        Ok(request)
    }
}
/// Trait for types that can be extracted from the request head alone.
///
/// Unlike [`FromRequest`], an implementation only gets a shared reference to
/// the [`RequestHead`] and never sees the request body.
pub trait FromRequestHead: Sized {
/// Extracts `Self` from the request head.
///
/// The returned future is required to be [`Send`].
///
/// # Errors
///
/// Returns an error if the value cannot be extracted from the request head.
fn from_request_head(head: &RequestHead) -> impl Future<Output = cot::Result<Self>> + Send;
}
impl FromRequestHead for Urls {
async fn from_request_head(head: &RequestHead) -> cot::Result<Self> {
Ok(Self::from_parts(head))
}
}
/// Extractor holding the request's path parameters deserialized into `D`.
///
/// The wrapped value is public, so the extractor can be destructured directly
/// as `Path(value)`.
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Path<D>(pub D);
impl<D> FromRequestHead for Path<D>
where
    D: DeserializeOwned,
{
    /// Deserializes the [`PathParams`] stored in the request extensions into `D`.
    ///
    /// # Errors
    ///
    /// Returns an error if the path parameters cannot be parsed into `D`.
    ///
    /// # Panics
    ///
    /// Panics if the [`PathParams`] extension is not present on the request.
    async fn from_request_head(head: &RequestHead) -> cot::Result<Self> {
        let path_params = head
            .extensions
            .get::<PathParams>()
            .expect("PathParams extension missing");
        let parsed = path_params.parse()?;
        Ok(Self(parsed))
    }
}
/// Extractor holding the URL query string deserialized into `T`.
///
/// The wrapped value is public, so the extractor can be destructured directly
/// as `UrlQuery(value)`.
#[derive(Debug, Clone, Copy, Default)]
pub struct UrlQuery<T>(pub T);
// The `DeserializeOwned` bound is stated once, in the `where` clause; the
// original repeated it both inline and in the `where` clause.
impl<D> FromRequestHead for UrlQuery<D>
where
    D: DeserializeOwned,
{
    /// Deserializes the URL query string of the request into `D`.
    ///
    /// A URI without a query component is handled as an empty query string.
    ///
    /// # Errors
    ///
    /// Returns a `QueryParametersParseError` (converted into a `cot` error)
    /// if the query string cannot be deserialized into `D`.
    async fn from_request_head(head: &RequestHead) -> cot::Result<Self> {
        let raw_query = head.uri.query().unwrap_or_default();
        let pairs = form_urlencoded::parse(raw_query.as_bytes());
        let deserializer = serde_html_form::Deserializer::new(pairs);
        // `serde_path_to_error` records which field failed to deserialize.
        let parsed =
            serde_path_to_error::deserialize(deserializer).map_err(QueryParametersParseError)?;
        Ok(Self(parsed))
    }
}
/// Error produced when the URL query string cannot be deserialized.
///
/// Wraps the `serde_path_to_error` error, whose message is included in this
/// error's `Display` output.
#[derive(Debug, thiserror::Error)]
#[error("could not parse query parameters: {0}")]
struct QueryParametersParseError(serde_path_to_error::Error<serde::de::value::Error>);
// NOTE(review): presumably this makes the error convertible into a `cot`
// error reported as HTTP 400 Bad Request — confirm against the macro.
impl_into_cot_error!(QueryParametersParseError, BAD_REQUEST);
#[cfg(feature = "json")]
impl<D: DeserializeOwned> FromRequest for Json<D> {
    /// Deserializes the request body as JSON into `D`.
    ///
    /// The request must carry a `Content-Type` header whose media type is
    /// the JSON content type. Media type parameters (for example
    /// `; charset=utf-8`) are ignored and the comparison is ASCII
    /// case-insensitive, so `application/json; charset=utf-8` is accepted.
    ///
    /// # Errors
    ///
    /// Returns an [`InvalidContentType`] error if the `Content-Type` header is
    /// missing or names a different media type, an error if the body cannot
    /// be read, and a `JsonDeserializeError` if the body is not valid JSON
    /// for `D`.
    async fn from_request(head: &RequestHead, body: Body) -> cot::Result<Self> {
        let content_type = head
            .headers
            .get(http::header::CONTENT_TYPE)
            .map_or("".into(), |value| String::from_utf8_lossy(value.as_bytes()));

        // Compare only the "type/subtype" part of the header: anything after
        // the first `;` is a media type parameter, and media types are
        // case-insensitive.
        let media_type = content_type
            .split(';')
            .next()
            .unwrap_or_default()
            .trim();
        if !media_type.eq_ignore_ascii_case(cot::headers::JSON_CONTENT_TYPE) {
            return Err(InvalidContentType {
                expected: cot::headers::JSON_CONTENT_TYPE,
                // Report the header exactly as it was received.
                actual: content_type.into_owned(),
            }
            .into());
        }

        let bytes = body.into_bytes().await?;
        let deserializer = &mut serde_json::Deserializer::from_slice(&bytes);
        // `serde_path_to_error` records which field failed to deserialize.
        let result =
            serde_path_to_error::deserialize(deserializer).map_err(JsonDeserializeError)?;
        Ok(Self(result))
    }
}
/// Error produced when a request body cannot be deserialized as JSON.
///
/// Wraps the `serde_path_to_error` error, whose message is included in this
/// error's `Display` output.
#[cfg(feature = "json")]
#[derive(Debug, thiserror::Error)]
#[error("JSON deserialization error: {0}")]
struct JsonDeserializeError(serde_path_to_error::Error<serde_json::Error>);
// NOTE(review): presumably this makes the error convertible into a `cot`
// error reported as HTTP 400 Bad Request — confirm against the macro.
#[cfg(feature = "json")]
impl_into_cot_error!(JsonDeserializeError, BAD_REQUEST);
/// Extractor holding the outcome of building the form `F` from a request.
///
/// The wrapped [`FormResult`] is public, so the extractor can be destructured
/// directly as `RequestForm(result)`.
#[derive(Debug)]
pub struct RequestForm<F: Form>(pub FormResult<F>);
impl<F> FromRequest for RequestForm<F>
where
    F: Form,
{
    /// Reassembles the request and lets the form type `F` build itself from it.
    ///
    /// # Errors
    ///
    /// Returns an error if `F::from_request` fails.
    async fn from_request(head: &RequestHead, body: Body) -> cot::Result<Self> {
        // The form API works on a whole request, so rebuild one from the parts.
        let mut request = Request::from_parts(head.clone(), body);
        let form_result = F::from_request(&mut request).await?;
        Ok(Self(form_result))
    }
}
#[cfg(feature = "db")]
impl FromRequestHead for crate::db::Database {
    /// Returns a clone of the database handle found in the request context.
    async fn from_request_head(head: &RequestHead) -> cot::Result<Self> {
        let context = head.context();
        Ok(context.database().clone())
    }
}
#[cfg(feature = "cache")]
impl FromRequestHead for crate::cache::Cache {
    /// Returns a clone of the cache handle found in the request context.
    async fn from_request_head(head: &RequestHead) -> cot::Result<Self> {
        let context = head.context();
        Ok(context.cache().clone())
    }
}
#[cfg(feature = "email")]
impl FromRequestHead for crate::email::Email {
    /// Returns a clone of the email service found in the request context.
    async fn from_request_head(head: &RequestHead) -> cot::Result<Self> {
        let context = head.context();
        Ok(context.email().clone())
    }
}
/// Extractor giving request handlers access to the project's static files.
///
/// Cloning is cheap: the underlying static files registry is shared through
/// an [`Arc`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StaticFiles {
// Shared handle to the static files registry taken from the request
// extensions.
inner: Arc<crate::static_files::StaticFiles>,
}
impl StaticFiles {
    /// Looks up the URL of the static file identified by `path`.
    ///
    /// # Errors
    ///
    /// Returns [`StaticFilesGetError::NotFound`] carrying the requested path
    /// if the underlying registry has no entry for it.
    pub fn url_for(&self, path: &str) -> Result<&str, StaticFilesGetError> {
        match self.inner.path_for(path) {
            Some(url) => Ok(url),
            None => Err(StaticFilesGetError::NotFound {
                path: path.to_owned(),
            }),
        }
    }
}
// Common prefix shared by the messages of `StaticFilesGetError` variants.
const ERROR_PREFIX: &str = "could not get URL for a static file:";
/// Errors that can occur when looking up the URL of a static file.
#[derive(Debug, Clone, PartialEq, Eq, Hash, thiserror::Error)]
#[non_exhaustive]
pub enum StaticFilesGetError {
/// No static file is registered under the requested path.
#[error("{ERROR_PREFIX} static file `{path}` not found")]
#[non_exhaustive]
NotFound {
/// The path that was requested.
path: String,
},
}
// NOTE(review): presumably this makes the error convertible into a `cot`
// error with the macro's default status — confirm against the macro.
impl_into_cot_error!(StaticFilesGetError);
impl FromRequestHead for StaticFiles {
    /// Wraps the shared static files registry stored in the request
    /// extensions.
    ///
    /// # Panics
    ///
    /// Panics if the registry is not present in the request extensions.
    async fn from_request_head(head: &RequestHead) -> cot::Result<Self> {
        let shared = head
            .extensions
            .get::<Arc<crate::static_files::StaticFiles>>()
            .expect("StaticFilesMiddleware not enabled for the route/project");
        // Only the reference count is bumped; the registry itself is shared.
        Ok(Self {
            inner: Arc::clone(shared),
        })
    }
}
impl FromRequestHead for RequestHead {
    /// Returns an owned copy of the request head.
    async fn from_request_head(head: &RequestHead) -> cot::Result<Self> {
        let owned_head = head.clone();
        Ok(owned_head)
    }
}
impl FromRequestHead for Method {
    /// Returns a copy of the request's HTTP method.
    async fn from_request_head(head: &RequestHead) -> cot::Result<Self> {
        let method = &head.method;
        Ok(method.clone())
    }
}
impl FromRequestHead for Session {
    /// Returns a clone of the session obtained from the request extensions.
    async fn from_request_head(head: &RequestHead) -> cot::Result<Self> {
        let session = Session::from_extensions(&head.extensions);
        Ok(session.clone())
    }
}
impl FromRequestHead for Auth {
    /// Returns a clone of the [`Auth`] object stored in the request
    /// extensions.
    ///
    /// # Panics
    ///
    /// Panics if no [`Auth`] object is present in the request extensions.
    async fn from_request_head(head: &RequestHead) -> cot::Result<Self> {
        let stored = head.extensions.get::<Auth>();
        Ok(stored
            .cloned()
            .expect("AuthMiddleware not enabled for the route/project"))
    }
}
pub use cot_macros::FromRequestHead;
use crate::error::error_impl::impl_into_cot_error;
#[cfg(test)]
mod tests {
    use serde::Deserialize;

    use super::*;
    use crate::html::Html;
    // `Json` only exists when the `json` feature is enabled, so it has to be
    // imported conditionally; an unconditional import makes the whole test
    // module fail to compile when the feature is disabled.
    #[cfg(feature = "json")]
    use crate::request::extractors::Json;
    use crate::request::extractors::{FromRequest, Path, UrlQuery};
    use crate::router::{Route, Router, Urls};
    use crate::test::TestRequestBuilder;
    use crate::{Body, reverse};

    /// A JSON body with the JSON content type deserializes into a generic
    /// JSON value.
    #[cfg(feature = "json")]
    #[cot::test]
    async fn json() {
        let request = http::Request::builder()
            .method(http::Method::POST)
            .header(http::header::CONTENT_TYPE, cot::headers::JSON_CONTENT_TYPE)
            .body(Body::fixed(r#"{"hello":"world"}"#))
            .unwrap();

        let (head, body) = request.into_parts();
        let Json(data): Json<serde_json::Value> = Json::from_request(&head, body).await.unwrap();

        assert_eq!(data, serde_json::json!({"hello": "world"}));
    }

    /// An empty JSON object deserializes into a struct without fields.
    #[cfg(feature = "json")]
    #[cot::test]
    async fn json_empty() {
        #[derive(Debug, Deserialize, PartialEq, Eq)]
        struct TestData {}

        let request = http::Request::builder()
            .method(http::Method::POST)
            .header(http::header::CONTENT_TYPE, cot::headers::JSON_CONTENT_TYPE)
            .body(Body::fixed("{}"))
            .unwrap();

        let (head, body) = request.into_parts();
        let Json(data): Json<TestData> = Json::from_request(&head, body).await.unwrap();

        assert_eq!(data, TestData {});
    }

    /// Nested JSON objects deserialize into nested structs.
    #[cfg(feature = "json")]
    #[cot::test]
    async fn json_struct() {
        #[derive(Debug, Deserialize, PartialEq, Eq)]
        struct TestDataInner {
            hello: String,
        }

        #[derive(Debug, Deserialize, PartialEq, Eq)]
        struct TestData {
            inner: TestDataInner,
        }

        let request = http::Request::builder()
            .method(http::Method::POST)
            .header(http::header::CONTENT_TYPE, cot::headers::JSON_CONTENT_TYPE)
            .body(Body::fixed(r#"{"inner":{"hello":"world"}}"#))
            .unwrap();

        let (head, body) = request.into_parts();
        let Json(data): Json<TestData> = Json::from_request(&head, body).await.unwrap();

        assert_eq!(
            data,
            TestData {
                inner: TestDataInner {
                    hello: "world".to_string(),
                }
            }
        );
    }

    /// Path parameters stored in the extensions are parsed into a typed
    /// struct.
    #[cot::test]
    async fn path_extraction() {
        #[derive(Deserialize, Debug, PartialEq)]
        struct TestParams {
            id: i32,
            name: String,
        }

        let (mut head, _body) = Request::new(Body::empty()).into_parts();

        let mut params = PathParams::new();
        params.insert("id".to_string(), "42".to_string());
        params.insert("name".to_string(), "test".to_string());
        head.extensions.insert(params);

        let Path(extracted): Path<TestParams> = Path::from_request_head(&head).await.unwrap();
        let expected = TestParams {
            id: 42,
            name: "test".to_string(),
        };

        assert_eq!(extracted, expected);
    }

    /// Query string pairs are parsed into a typed struct.
    #[cot::test]
    async fn url_query_extraction() {
        #[derive(Deserialize, Debug, PartialEq)]
        struct QueryParams {
            page: i32,
            filter: String,
        }

        let (mut head, _body) = Request::new(Body::empty()).into_parts();
        head.uri = "https://example.com/?page=2&filter=active".parse().unwrap();

        let UrlQuery(query): UrlQuery<QueryParams> =
            UrlQuery::from_request_head(&head).await.unwrap();

        assert_eq!(query.page, 2);
        assert_eq!(query.filter, "active");
    }

    /// A URI without a query string still yields a value for a struct
    /// without fields.
    #[cot::test]
    async fn url_query_empty() {
        #[derive(Deserialize, Debug, PartialEq)]
        struct EmptyParams {}

        let (mut head, _body) = Request::new(Body::empty()).into_parts();
        head.uri = "https://example.com/".parse().unwrap();

        let result: UrlQuery<EmptyParams> = UrlQuery::from_request_head(&head).await.unwrap();

        assert!(matches!(result, UrlQuery(_)));
    }

    /// A non-JSON content type is rejected even if the body is valid JSON.
    #[cfg(feature = "json")]
    #[cot::test]
    async fn json_invalid_content_type() {
        let request = http::Request::builder()
            .method(http::Method::POST)
            .header(http::header::CONTENT_TYPE, "text/plain")
            .body(Body::fixed(r#"{"hello":"world"}"#))
            .unwrap();

        let (head, body) = request.into_parts();
        let result = Json::<serde_json::Value>::from_request(&head, body).await;

        assert!(result.is_err());
    }

    /// Form data posted with the request is turned into the form struct.
    #[cot::test]
    async fn request_form() {
        #[derive(Debug, PartialEq, Eq, Form)]
        struct MyForm {
            hello: String,
            foo: String,
        }

        let request = TestRequestBuilder::post("/")
            .form_data(&[("hello", "world"), ("foo", "bar")])
            .build();

        let (head, body) = request.into_parts();
        let RequestForm(form_result): RequestForm<MyForm> =
            RequestForm::from_request(&head, body).await.unwrap();

        assert_eq!(
            form_result.unwrap(),
            MyForm {
                hello: "world".to_string(),
                foo: "bar".to_string(),
            }
        );
    }

    /// The extracted `Urls` can reverse a named route of the request's
    /// router.
    #[cot::test]
    async fn urls_extraction() {
        async fn handler() -> Html {
            Html::new("")
        }

        let router = Router::with_urls([Route::with_handler_and_name(
            "/test/",
            handler,
            "test_route",
        )]);

        let mut request = TestRequestBuilder::get("/test/").router(router).build();
        let urls: Urls = request.extract_from_head().await.unwrap();

        assert!(reverse!(urls, "test_route").is_ok());
    }

    /// The HTTP method of the request is extracted.
    #[cot::test]
    async fn method_extraction() {
        let mut request = TestRequestBuilder::get("/test/").build();

        let method: Method = request.extract_from_head().await.unwrap();

        assert_eq!(method, Method::GET);
    }

    /// The database configured on the request can be extracted and closed.
    #[cfg(feature = "db")]
    #[cot::test]
    #[cfg_attr(
        miri,
        ignore = "unsupported operation: can't call foreign function `sqlite3_open_v2` on OS `linux`"
    )]
    async fn request_db() {
        let db = crate::test::TestDatabase::new_sqlite().await.unwrap();
        let mut test_request = TestRequestBuilder::get("/").database(db.database()).build();

        let extracted_db: crate::db::Database = test_request.extract_from_head().await.unwrap();

        // Close the connection explicitly so the test cleans up after itself.
        extracted_db.close().await.unwrap();
    }

    /// The cache can be extracted from a default test request.
    #[cfg(feature = "cache")]
    #[cot::test]
    async fn request_cache() {
        let mut request_builder = TestRequestBuilder::get("/");
        let mut request = request_builder.build();

        let extracted_cache = request.extract_from_head::<crate::cache::Cache>().await;

        assert!(extracted_cache.is_ok());
    }

    /// The email service can be extracted from a default test request.
    #[cfg(feature = "email")]
    #[cot::test]
    async fn request_email() {
        let mut request_builder = TestRequestBuilder::get("/");
        let mut request = request_builder.build();

        let email_service = request.extract_from_head::<crate::email::Email>().await;

        assert!(email_service.is_ok());
    }
}