use std::future::Future;
use std::sync::Arc;
use http::Extensions;
use indexmap::IndexMap;
use crate::error::error_impl::impl_into_cot_error;
use crate::request::extractors::FromRequestHead;
use crate::router::Router;
use crate::{Body, Result};
pub mod extractors;
mod path_params_deserializer;
pub type Request = http::Request<Body>;
pub type RequestHead = http::request::Parts;
// Sealing module: `Sealed` is `pub` but lives in a private module, so code
// outside this crate cannot name it and therefore cannot implement
// `RequestExt`, which requires `private::Sealed` as a supertrait.
mod private {
    /// Marker supertrait for `RequestExt`; implemented in the parent module
    /// for `Request` and `RequestHead`.
    pub trait Sealed {}
}
/// Convenience accessors for data the framework attaches to a request.
///
/// Implemented for both [`Request`] and [`RequestHead`]. The trait is sealed
/// (it requires the crate-private `Sealed` supertrait), so it cannot be
/// implemented outside this crate.
pub trait RequestExt: private::Sealed {
    /// Runs the [`FromRequestHead`] extractor `E` against this request's head.
    ///
    /// Once the returned future completes, the request is still available to
    /// the caller, so further extractors can be run on it.
    ///
    /// # Errors
    ///
    /// Returns whatever error `E::from_request_head` produces.
    fn extract_from_head<E>(&mut self) -> impl Future<Output = Result<E>> + Send
    where
        E: FromRequestHead + 'static;

    /// Returns the [`ProjectContext`](crate::ProjectContext) stored in the
    /// request's extensions.
    ///
    /// # Panics
    ///
    /// Panics if the extensions do not contain an `Arc<ProjectContext>`.
    #[must_use]
    fn context(&self) -> &crate::ProjectContext;

    /// Returns the project configuration held by [`Self::context`].
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`Self::context`].
    #[must_use]
    fn project_config(&self) -> &crate::config::ProjectConfig;

    /// Returns the router held by [`Self::context`].
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`Self::context`].
    #[must_use]
    fn router(&self) -> &Arc<Router>;

    /// Returns the app name stored in the request's extensions, or `None` if
    /// none has been set.
    #[must_use]
    fn app_name(&self) -> Option<&str>;

    /// Returns the route name stored in the request's extensions, or `None`
    /// if none has been set.
    #[must_use]
    fn route_name(&self) -> Option<&str>;

    /// Returns the path parameters stored in the request's extensions.
    ///
    /// # Panics
    ///
    /// Panics if no [`PathParams`] extension is present.
    #[must_use]
    fn path_params(&self) -> &PathParams;

    /// Returns a mutable reference to the path parameters, inserting an empty
    /// [`PathParams`] first if none is present.
    #[must_use]
    fn path_params_mut(&mut self) -> &mut PathParams;

    /// Returns the database handle held by [`Self::context`].
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`Self::context`].
    #[cfg(feature = "db")]
    #[must_use]
    #[deprecated(
        since = "0.5.0",
        note = "use request extractors (`FromRequestHead`) instead"
    )]
    fn db(&self) -> &crate::db::Database;

    /// Returns the raw `Content-Type` header value, if present.
    #[must_use]
    fn content_type(&self) -> Option<&http::HeaderValue>;

    /// Checks that the `Content-Type` header is exactly `expected`.
    ///
    /// The comparison is an exact string match after lossy UTF-8 decoding.
    /// Media-type parameters (e.g. `; charset=utf-8`) and differences in
    /// letter case therefore cause a mismatch. A missing header is compared
    /// as the empty string.
    ///
    /// # Errors
    ///
    /// Returns an `InvalidContentType` error (converted into the crate error
    /// with `BAD_REQUEST`) when the header does not match.
    fn expect_content_type(&mut self, expected: &'static str) -> Result<()> {
        let content_type = self
            .content_type()
            .map_or("".into(), |value| String::from_utf8_lossy(value.as_bytes()));
        if content_type == expected {
            Ok(())
        } else {
            Err(InvalidContentType {
                expected,
                actual: content_type.into_owned(),
            }
            .into())
        }
    }

    /// Returns the request's extensions map (internal helper).
    #[doc(hidden)]
    fn extensions(&self) -> &Extensions;
}
impl private::Sealed for Request {}
impl RequestExt for Request {
async fn extract_from_head<E>(&mut self) -> Result<E>
where
E: FromRequestHead + 'static,
{
let request = std::mem::take(self);
let (head, body) = request.into_parts();
let result = E::from_request_head(&head).await;
*self = Request::from_parts(head, body);
result
}
#[track_caller]
fn context(&self) -> &crate::ProjectContext {
self.extensions()
.get::<Arc<crate::ProjectContext>>()
.expect("AppContext extension missing")
}
fn project_config(&self) -> &crate::config::ProjectConfig {
self.context().config()
}
fn router(&self) -> &Arc<Router> {
self.context().router()
}
fn app_name(&self) -> Option<&str> {
self.extensions()
.get::<AppName>()
.map(|AppName(name)| name.as_str())
}
fn route_name(&self) -> Option<&str> {
self.extensions()
.get::<RouteName>()
.map(|RouteName(name)| name.as_str())
}
#[track_caller]
fn path_params(&self) -> &PathParams {
self.extensions()
.get::<PathParams>()
.expect("PathParams extension missing")
}
fn path_params_mut(&mut self) -> &mut PathParams {
self.extensions_mut().get_or_insert_default::<PathParams>()
}
#[cfg(feature = "db")]
fn db(&self) -> &crate::db::Database {
self.context().database()
}
fn content_type(&self) -> Option<&http::HeaderValue> {
self.headers().get(http::header::CONTENT_TYPE)
}
fn extensions(&self) -> &Extensions {
self.extensions()
}
}
// `RequestHead` is one of the types permitted to implement the sealed `RequestExt`.
impl private::Sealed for RequestHead {}

impl RequestExt for RequestHead {
    async fn extract_from_head<E>(&mut self) -> Result<E>
    where
        E: FromRequestHead + 'static,
    {
        // The head is already at hand, so it can be passed by reference directly.
        E::from_request_head(self).await
    }

    // Methods that can panic (directly or through `context()`) carry
    // `#[track_caller]`, so a missing-extension panic reports the caller's
    // location rather than a line in this impl.
    #[track_caller]
    fn context(&self) -> &crate::ProjectContext {
        self.extensions
            .get::<Arc<crate::ProjectContext>>()
            .expect("AppContext extension missing")
    }

    #[track_caller]
    fn project_config(&self) -> &crate::config::ProjectConfig {
        self.context().config()
    }

    #[track_caller]
    fn router(&self) -> &Arc<Router> {
        self.context().router()
    }

    fn app_name(&self) -> Option<&str> {
        self.extensions
            .get::<AppName>()
            .map(|AppName(name)| name.as_str())
    }

    fn route_name(&self) -> Option<&str> {
        self.extensions
            .get::<RouteName>()
            .map(|RouteName(name)| name.as_str())
    }

    #[track_caller]
    fn path_params(&self) -> &PathParams {
        self.extensions
            .get::<PathParams>()
            .expect("PathParams extension missing")
    }

    fn path_params_mut(&mut self) -> &mut PathParams {
        self.extensions.get_or_insert_default::<PathParams>()
    }

    #[cfg(feature = "db")]
    #[track_caller]
    fn db(&self) -> &crate::db::Database {
        self.context().database()
    }

    fn content_type(&self) -> Option<&http::HeaderValue> {
        self.headers.get(http::header::CONTENT_TYPE)
    }

    fn extensions(&self) -> &Extensions {
        &self.extensions
    }
}
/// Error produced by `RequestExt::expect_content_type` when the request's
/// `Content-Type` header does not equal the expected value.
#[derive(Debug, thiserror::Error)]
#[error("invalid content type; expected `{expected}`, found `{actual}`")]
pub(crate) struct InvalidContentType {
    /// The content type the caller required.
    expected: &'static str,
    /// The header value actually received. As built by `expect_content_type`,
    /// it is lossily decoded as UTF-8 and is empty when the header is absent.
    actual: String,
}
// Makes the error convertible into the crate error type, tagged `BAD_REQUEST`
// (presumably mapped to HTTP 400; see `impl_into_cot_error!`).
impl_into_cot_error!(InvalidContentType, BAD_REQUEST);
/// Request-extension newtype carrying an app name.
///
/// Read by `RequestExt::app_name`. When it is absent from the extensions,
/// `app_name` returns `None`.
#[repr(transparent)]
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub(crate) struct AppName(pub(crate) String);
/// Request-extension newtype carrying a route name.
///
/// Read by `RequestExt::route_name`. When it is absent from the extensions,
/// `route_name` returns `None`.
#[repr(transparent)]
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub(crate) struct RouteName(pub(crate) String);
/// Named path parameters attached to a request (e.g. `id` for a route such
/// as `/{id}/`), mapping parameter names to their raw string values.
///
/// Backed by an `IndexMap`, so parameters can be looked up by name and also
/// by position (see `get_index` / `key_at_index`).
#[derive(Debug, Clone)]
pub struct PathParams {
    // Parameter name -> raw (not yet deserialized) value.
    params: IndexMap<String, String>,
}
impl Default for PathParams {
fn default() -> Self {
Self::new()
}
}
impl PathParams {
    /// Creates an empty set of path parameters.
    #[must_use]
    pub fn new() -> Self {
        let params = IndexMap::new();
        Self { params }
    }

    /// Stores `value` under `name`.
    ///
    /// If `name` is already present its value is replaced. Per
    /// `IndexMap::insert`, the key keeps its original position.
    pub fn insert(&mut self, name: String, value: String) {
        let _previous = self.params.insert(name, value);
    }

    /// Iterates over `(name, value)` pairs in the map's order.
    pub fn iter(&self) -> impl Iterator<Item = (&str, &str)> {
        self.params
            .iter()
            .map(|(key, val)| (key.as_str(), val.as_str()))
    }

    /// Returns the number of stored parameters.
    #[must_use]
    pub fn len(&self) -> usize {
        self.params.len()
    }

    /// Returns `true` if no parameters are stored.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.params.is_empty()
    }

    /// Looks up a parameter's value by name.
    #[must_use]
    pub fn get(&self, name: &str) -> Option<&str> {
        let value = self.params.get(name)?;
        Some(value.as_str())
    }

    /// Returns the value at position `index`, or `None` if out of range.
    #[must_use]
    pub fn get_index(&self, index: usize) -> Option<&str> {
        let (_name, value) = self.params.get_index(index)?;
        Some(value.as_str())
    }

    /// Returns the parameter name at position `index`, or `None` if out of range.
    #[must_use]
    pub fn key_at_index(&self, index: usize) -> Option<&str> {
        let (name, _value) = self.params.get_index(index)?;
        Some(name.as_str())
    }

    /// Deserializes the parameters into `T`.
    ///
    /// # Errors
    ///
    /// Returns [`PathParamsDeserializerError`] if the parameters cannot be
    /// deserialized into `T`. The error is produced through
    /// `serde_path_to_error`, which records the path to the failing field.
    pub fn parse<'de, T>(&'de self) -> std::result::Result<T, PathParamsDeserializerError>
    where
        T: serde::Deserialize<'de>,
    {
        let deserializer = path_params_deserializer::PathParamsDeserializer::new(self);
        serde_path_to_error::deserialize(deserializer).map_err(PathParamsDeserializerError)
    }
}
/// Error returned by [`PathParams::parse`] when the parameters cannot be
/// deserialized into the requested type.
#[derive(Debug, Clone, thiserror::Error)]
#[error("could not parse path parameters: {0}")]
pub struct PathParamsDeserializerError(
    /// Underlying deserializer error, wrapped by `serde_path_to_error` to
    /// record the path to the failing field.
    #[source] serde_path_to_error::Error<path_params_deserializer::PathParamsDeserializerError>,
);
// Makes the error convertible into the crate error type, tagged `BAD_REQUEST`
// (presumably mapped to HTTP 400; see `impl_into_cot_error!`).
impl_into_cot_error!(PathParamsDeserializerError, BAD_REQUEST);
#[cfg(test)]
mod tests {
    use super::*;
    use crate::request::extractors::Path;
    use crate::response::Response;
    use crate::router::{Route, Router};
    use crate::test::TestRequestBuilder;

    #[test]
    fn path_params() {
        let mut path_params = PathParams::new();
        path_params.insert("name".into(), "world".into());
        assert_eq!(path_params.get("name"), Some("world"));
        assert_eq!(path_params.get("missing"), None);
    }

    #[test]
    fn path_params_default_is_empty() {
        let path_params = PathParams::default();
        assert!(path_params.is_empty());
        assert_eq!(path_params.len(), 0);
    }

    // Covers the positional accessors and iteration order, which the other
    // tests do not exercise.
    #[test]
    fn path_params_positional_accessors() {
        let mut path_params = PathParams::new();
        assert!(path_params.is_empty());
        path_params.insert("first".into(), "1".into());
        path_params.insert("second".into(), "2".into());
        assert!(!path_params.is_empty());
        assert_eq!(path_params.len(), 2);
        assert_eq!(path_params.key_at_index(0), Some("first"));
        assert_eq!(path_params.get_index(0), Some("1"));
        assert_eq!(path_params.key_at_index(1), Some("second"));
        assert_eq!(path_params.get_index(1), Some("2"));
        assert_eq!(path_params.key_at_index(2), None);
        assert_eq!(path_params.get_index(2), None);
        assert_eq!(
            path_params.iter().collect::<Vec<_>>(),
            vec![("first", "1"), ("second", "2")]
        );
    }

    // Re-inserting an existing name replaces the value without adding an entry
    // or moving the key.
    #[test]
    fn path_params_reinsert_replaces_value_in_place() {
        let mut path_params = PathParams::new();
        path_params.insert("a".into(), "1".into());
        path_params.insert("b".into(), "2".into());
        path_params.insert("a".into(), "3".into());
        assert_eq!(path_params.len(), 2);
        assert_eq!(path_params.key_at_index(0), Some("a"));
        assert_eq!(path_params.get_index(0), Some("3"));
        assert_eq!(path_params.get("a"), Some("3"));
    }

    #[test]
    fn path_params_parse() {
        #[derive(Debug, PartialEq, Eq, serde::Deserialize)]
        struct Params {
            hello: String,
            foo: String,
        }

        let mut path_params = PathParams::new();
        path_params.insert("hello".into(), "world".into());
        path_params.insert("foo".into(), "bar".into());
        let params: Params = path_params.parse().unwrap();
        assert_eq!(
            params,
            Params {
                hello: "world".to_string(),
                foo: "bar".to_string(),
            }
        );
    }

    #[test]
    fn request_ext_app_name() {
        let mut request = TestRequestBuilder::get("/").build();
        assert_eq!(request.app_name(), None);
        request
            .extensions_mut()
            .insert(AppName("test_app".to_string()));
        assert_eq!(request.app_name(), Some("test_app"));
    }

    #[test]
    fn request_ext_route_name() {
        let mut request = TestRequestBuilder::get("/").build();
        assert_eq!(request.route_name(), None);
        request
            .extensions_mut()
            .insert(RouteName("test_route".to_string()));
        assert_eq!(request.route_name(), Some("test_route"));
    }

    #[test]
    fn request_ext_parts_route_name() {
        let request = TestRequestBuilder::get("/").build();
        let (mut head, _body) = request.into_parts();
        assert_eq!(head.route_name(), None);
        head.extensions.insert(RouteName("test_route".to_string()));
        assert_eq!(head.route_name(), Some("test_route"));
    }

    #[test]
    fn request_ext_path_params() {
        let mut request = TestRequestBuilder::get("/").build();
        let mut params = PathParams::new();
        params.insert("id".to_string(), "42".to_string());
        request.extensions_mut().insert(params);
        assert_eq!(request.path_params().get("id"), Some("42"));
    }

    #[test]
    fn request_ext_path_params_mut() {
        let mut request = TestRequestBuilder::get("/").build();
        request
            .path_params_mut()
            .insert("id".to_string(), "42".to_string());
        assert_eq!(request.path_params().get("id"), Some("42"));
    }

    #[test]
    fn request_ext_content_type() {
        let mut request = TestRequestBuilder::get("/").build();
        assert_eq!(request.content_type(), None);
        request.headers_mut().insert(
            http::header::CONTENT_TYPE,
            http::HeaderValue::from_static("text/plain"),
        );
        assert_eq!(
            request.content_type(),
            Some(&http::HeaderValue::from_static("text/plain"))
        );
    }

    #[test]
    fn request_ext_expect_content_type() {
        let mut request = TestRequestBuilder::get("/").build();
        assert!(request.expect_content_type("text/plain").is_err());
        request.headers_mut().insert(
            http::header::CONTENT_TYPE,
            http::HeaderValue::from_static("text/plain"),
        );
        assert!(request.expect_content_type("text/plain").is_ok());
        assert!(request.expect_content_type("application/json").is_err());
    }

    #[cot::test]
    async fn request_ext_extract_from_head() {
        async fn handler(mut request: Request) -> Result<Response> {
            let Path(id): Path<String> = request.extract_from_head().await?;
            assert_eq!(id, "42");
            Ok(Response::new(Body::empty()))
        }

        let router = Router::with_urls([Route::with_handler("/{id}/", handler)]);
        let request = TestRequestBuilder::get("/42/")
            .router(router.clone())
            .build();
        router.handle(request).await.unwrap();
    }

    // `extract_from_head` temporarily splits the request; it must be fully
    // reassembled (URI and extensions intact) afterwards.
    #[cot::test]
    async fn request_ext_extract_from_head_preserves_request() {
        let mut request = TestRequestBuilder::get("/").build();
        request
            .path_params_mut()
            .insert("id".to_string(), "42".to_string());
        let Path(id): Path<String> = request.extract_from_head().await.unwrap();
        assert_eq!(id, "42");
        assert_eq!(request.path_params().get("id"), Some("42"));
        assert_eq!(request.uri().path(), "/");
    }

    #[test]
    fn parts_ext_path_params() {
        let (mut head, _) = Request::new(Body::empty()).into_parts();
        let mut params = PathParams::new();
        params.insert("id".to_string(), "42".to_string());
        head.extensions.insert(params);
        assert_eq!(head.path_params().get("id"), Some("42"));
    }

    #[test]
    fn parts_ext_mutating_path_params() {
        let (mut head, _) = Request::new(Body::empty()).into_parts();
        head.path_params_mut()
            .insert("page".to_string(), "1".to_string());
        assert_eq!(head.path_params().get("page"), Some("1"));
    }

    #[test]
    fn parts_ext_app_name() {
        let (mut head, _) = Request::new(Body::empty()).into_parts();
        head.extensions.insert(AppName("test_app".to_string()));
        assert_eq!(head.app_name(), Some("test_app"));
    }

    #[test]
    fn parts_ext_route_name() {
        let (mut head, _) = Request::new(Body::empty()).into_parts();
        head.extensions.insert(RouteName("test_route".to_string()));
        assert_eq!(head.route_name(), Some("test_route"));
    }

    #[test]
    fn parts_ext_content_type() {
        let (mut head, _) = Request::new(Body::empty()).into_parts();
        head.headers.insert(
            http::header::CONTENT_TYPE,
            http::HeaderValue::from_static("text/plain"),
        );
        assert_eq!(
            head.content_type(),
            Some(&http::HeaderValue::from_static("text/plain"))
        );
    }

    #[test]
    fn parts_ext_expect_content_type() {
        let (mut head, _) = Request::new(Body::empty()).into_parts();
        assert!(head.expect_content_type("text/plain").is_err());
        head.headers.insert(
            http::header::CONTENT_TYPE,
            http::HeaderValue::from_static("text/plain"),
        );
        assert!(head.expect_content_type("text/plain").is_ok());
        assert!(head.expect_content_type("application/json").is_err());
    }

    #[cot::test]
    async fn path_extract_from_head() {
        let (mut head, _) = Request::new(Body::empty()).into_parts();
        let mut params = PathParams::new();
        params.insert("id".to_string(), "42".to_string());
        head.extensions.insert(params);
        let Path(id): Path<String> = head.extract_from_head().await.unwrap();
        assert_eq!(id, "42");
    }
}