use std::sync::Arc;
use cot_core::error::impl_into_cot_error;
pub use cot_core::request::extractors::FromRequest;
pub use cot_core::request::extractors::FromRequestHead;
#[doc(inline)]
pub use cot_core::request::extractors::{Path, UrlQuery};
use crate::Body;
use crate::auth::Auth;
use crate::form::{Form, FormResult};
use crate::request::{Request, RequestExt, RequestHead};
use crate::router::Urls;
use crate::session::Session;
impl FromRequestHead for Urls {
async fn from_request_head(head: &RequestHead) -> cot::Result<Self> {
Ok(Self::from_parts(head))
}
}
/// Extractor that parses the request body into the form type `F`.
///
/// The wrapped [`FormResult`] is whatever [`Form::from_request`] produced for
/// this request.
#[derive(Debug)]
pub struct RequestForm<F: Form>(pub FormResult<F>);

impl<F: Form> FromRequest for RequestForm<F> {
    /// Parses the body as an `F` form.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by [`Form::from_request`].
    async fn from_request(head: &RequestHead, body: Body) -> cot::Result<Self> {
        // `Form::from_request` operates on a whole `Request`, so one is rebuilt
        // from a copy of the (borrowed) head and the owned body.
        let mut request = Request::from_parts(head.clone(), body);
        let form_result = F::from_request(&mut request).await?;
        Ok(Self(form_result))
    }
}
/// Extracts the project's database handle from the request context.
#[cfg(feature = "db")]
impl FromRequestHead for crate::db::Database {
    async fn from_request_head(head: &RequestHead) -> cot::Result<Self> {
        let context = head.context();
        Ok(context.database().clone())
    }
}
/// Extracts the project's cache handle from the request context.
#[cfg(feature = "cache")]
impl FromRequestHead for crate::cache::Cache {
    async fn from_request_head(head: &RequestHead) -> cot::Result<Self> {
        let context = head.context();
        Ok(context.cache().clone())
    }
}
/// Extracts the project's email service from the request context.
#[cfg(feature = "email")]
impl FromRequestHead for crate::email::Email {
    async fn from_request_head(head: &RequestHead) -> cot::Result<Self> {
        let context = head.context();
        Ok(context.email().clone())
    }
}
/// Request-extractable handle to the project's static files.
///
/// Wraps the shared [`crate::static_files::StaticFiles`] in an [`Arc`], so
/// cloning this handle only bumps a reference count.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StaticFiles {
    inner: Arc<crate::static_files::StaticFiles>,
}

impl StaticFiles {
    /// Returns the URL for the static file at `path`, as resolved by the
    /// underlying static files registry.
    ///
    /// # Errors
    ///
    /// Returns [`StaticFilesGetError::NotFound`] if the registry has no entry
    /// for `path`.
    pub fn url_for(&self, path: &str) -> Result<&str, StaticFilesGetError> {
        match self.inner.path_for(path) {
            Some(url) => Ok(url),
            None => Err(StaticFilesGetError::NotFound {
                path: path.to_owned(),
            }),
        }
    }
}
/// Prefix shared by every [`StaticFilesGetError`] message.
const ERROR_PREFIX: &str = "could not get URL for a static file:";
/// Error returned by [`StaticFiles::url_for`].
#[derive(Debug, Clone, PartialEq, Eq, Hash, thiserror::Error)]
#[non_exhaustive]
pub enum StaticFilesGetError {
    /// The requested static file could not be found.
    #[error("{ERROR_PREFIX} static file `{path}` not found")]
    #[non_exhaustive]
    NotFound {
        /// The path that was looked up.
        path: String,
    },
}
// NOTE(review): presumably provides the conversion into cot's error type so
// this error can be propagated with `?` in handlers — confirm in cot_core.
impl_into_cot_error!(StaticFilesGetError);
/// Extracts the [`StaticFiles`] handle stored in the request extensions.
///
/// # Panics
///
/// Panics if the request extensions contain no static files registry, which is
/// the case when `StaticFilesMiddleware` is not enabled for the route/project.
impl FromRequestHead for StaticFiles {
    async fn from_request_head(head: &RequestHead) -> cot::Result<Self> {
        let inner = head
            .extensions
            .get::<Arc<crate::static_files::StaticFiles>>()
            .map(Arc::clone)
            .expect("StaticFilesMiddleware not enabled for the route/project");
        Ok(Self { inner })
    }
}
/// Extracts the current [`Session`] from the request extensions.
///
/// NOTE(review): behavior when no session is present in the extensions is
/// determined by [`Session::from_extensions`] — confirm whether it panics.
impl FromRequestHead for Session {
    async fn from_request_head(head: &RequestHead) -> cot::Result<Self> {
        let session = Session::from_extensions(&head.extensions);
        Ok(session.clone())
    }
}
/// Extracts the [`Auth`] object stored in the request extensions.
///
/// # Panics
///
/// Panics if the request extensions contain no [`Auth`], which is the case
/// when `AuthMiddleware` is not enabled for the route/project.
impl FromRequestHead for Auth {
    async fn from_request_head(head: &RequestHead) -> cot::Result<Self> {
        let stored = head.extensions.get::<Auth>().cloned();
        Ok(stored.expect("AuthMiddleware not enabled for the route/project"))
    }
}
#[cfg(test)]
mod tests {
    use cot_core::Method;
    use cot_core::html::Html;

    use super::*;
    use crate::request::extractors::FromRequest;
    use crate::reverse;
    use crate::router::{Route, Router};
    use crate::test::TestRequestBuilder;

    /// `Urls` extracted from the head can reverse a route registered on the router.
    #[cot::test]
    async fn urls_extraction() {
        async fn handler() -> Html {
            Html::new("")
        }

        let router = Router::with_urls([Route::with_handler_and_name(
            "/test/",
            handler,
            "test_route",
        )]);
        let mut request = TestRequestBuilder::get("/test/").router(router).build();

        let urls: Urls = request.extract_from_head().await.unwrap();
        assert!(reverse!(urls, "test_route").is_ok());
    }

    /// The HTTP method is extractable from the head.
    #[cot::test]
    async fn method_extraction() {
        let mut request = TestRequestBuilder::get("/test/").build();

        let method: Method = request.extract_from_head().await.unwrap();
        assert_eq!(method, Method::GET);
    }

    /// `RequestForm` parses URL-encoded form data from the body.
    #[cot::test]
    async fn request_form() {
        #[derive(Debug, PartialEq, Eq, Form)]
        struct MyForm {
            hello: String,
            foo: String,
        }

        let request = TestRequestBuilder::post("/")
            .form_data(&[("hello", "world"), ("foo", "bar")])
            .build();
        let (head, body) = request.into_parts();

        let RequestForm(form_result): RequestForm<MyForm> =
            RequestForm::from_request(&head, body).await.unwrap();
        assert_eq!(
            form_result.unwrap(),
            MyForm {
                hello: "world".to_string(),
                foo: "bar".to_string(),
            }
        );
    }

    /// The database attached to the request context is extractable.
    #[cfg(feature = "db")]
    #[cot::test]
    #[cfg_attr(
        miri,
        ignore = "unsupported operation: can't call foreign function `sqlite3_open_v2` on OS `linux`"
    )]
    async fn request_db() {
        let db = crate::test::TestDatabase::new_sqlite().await.unwrap();
        let mut test_request = TestRequestBuilder::get("/").database(db.database()).build();

        let extracted_db: crate::db::Database = test_request.extract_from_head().await.unwrap();
        extracted_db.close().await.unwrap();
    }

    /// The cache attached to the request context is extractable.
    #[cfg(feature = "cache")]
    #[cot::test]
    async fn request_cache() {
        let mut request = TestRequestBuilder::get("/").build();

        // `unwrap` (rather than `assert!(.is_ok())`) surfaces the underlying
        // error in the failure message.
        let _cache: crate::cache::Cache = request.extract_from_head().await.unwrap();
    }

    /// The email service attached to the request context is extractable.
    #[cfg(feature = "email")]
    #[cot::test]
    async fn request_email() {
        let mut request = TestRequestBuilder::get("/").build();

        // `unwrap` (rather than `assert!(.is_ok())`) surfaces the underlying
        // error in the failure message.
        let _email: crate::email::Email = request.extract_from_head().await.unwrap();
    }
}