use std::borrow::Cow;
use std::sync::Arc;
use async_trait::async_trait;
use bytes::Bytes;
use indexmap::IndexMap;
pub use path_params_deserializer::PathParamsDeserializerError;
use tower_sessions::Session;
#[cfg(feature = "db")]
use crate::db::Database;
use crate::error::ErrorRepr;
use crate::headers::FORM_CONTENT_TYPE;
#[cfg(feature = "json")]
use crate::headers::JSON_CONTENT_TYPE;
use crate::router::Router;
use crate::{Body, Result};
mod path_params_deserializer;
pub type Request = http::Request<Body>;
// Sealing module. `Sealed` is a supertrait of `RequestExt`, and because this
// module is private, code outside the crate cannot name `Sealed` and therefore
// cannot implement `RequestExt` for its own types.
mod private {
    // Must stay `pub` (not `pub(crate)`): it appears as a bound on a public trait.
    pub trait Sealed {}
}
/// Extension methods for [`Request`] providing access to framework-level
/// state (project context, router, path parameters, session, database) and
/// helpers for reading the request body.
///
/// This trait is sealed: it can only be implemented inside this crate.
#[async_trait]
pub trait RequestExt: private::Sealed {
    /// Returns the project context attached to this request.
    ///
    /// # Panics
    ///
    /// Panics if the project context has not been attached to the request.
    #[must_use]
    fn context(&self) -> &crate::ProjectContext;

    /// Returns the project configuration from the project context.
    #[must_use]
    fn project_config(&self) -> &crate::config::ProjectConfig;

    /// Returns the project router from the project context.
    #[must_use]
    fn router(&self) -> &Router;

    /// Returns the name of the app associated with this request, if any.
    #[must_use]
    fn app_name(&self) -> Option<&str>;

    /// Returns the name of the matched route, if any.
    #[must_use]
    fn route_name(&self) -> Option<&str>;

    /// Returns the path parameters of this request.
    ///
    /// # Panics
    ///
    /// Panics if no path parameters have been attached to the request.
    #[must_use]
    fn path_params(&self) -> &PathParams;

    /// Returns a mutable reference to the path parameters, first inserting an
    /// empty set if none is attached yet.
    #[must_use]
    fn path_params_mut(&mut self) -> &mut PathParams;

    /// Returns the database handle from the project context.
    #[cfg(feature = "db")]
    #[must_use]
    fn db(&self) -> &Database;

    /// Returns the session of this request.
    ///
    /// # Panics
    ///
    /// Panics if no session has been attached to the request (for example,
    /// because the session middleware is not installed).
    #[must_use]
    fn session(&self) -> &Session;

    /// Returns a mutable reference to the session of this request.
    ///
    /// # Panics
    ///
    /// Panics if no session has been attached to the request (for example,
    /// because the session middleware is not installed).
    #[must_use]
    fn session_mut(&mut self) -> &mut Session;

    /// Returns the raw form data of this request.
    ///
    /// For `GET` and `HEAD` requests this is the URL query string (empty if
    /// there is none). For other methods, the `Content-Type` is checked
    /// against the form content type and the request body is read. The body
    /// is moved out of the request, so it can only be consumed once.
    ///
    /// # Errors
    ///
    /// Returns an error if the content type is not a form content type or
    /// if reading the body fails.
    async fn form_data(&mut self) -> Result<Bytes>;

    /// Reads the request body and deserializes it from JSON into `T`.
    ///
    /// The `Content-Type` is checked against the JSON content type first. The
    /// body is moved out of the request, so it can only be consumed once.
    ///
    /// # Errors
    ///
    /// Returns an error if the content type is not JSON, if reading the body
    /// fails, or if the body cannot be deserialized into `T`.
    #[cfg(feature = "json")]
    async fn json<T: serde::de::DeserializeOwned>(&mut self) -> Result<T>;

    /// Returns the value of the `Content-Type` header, if present.
    #[must_use]
    fn content_type(&self) -> Option<&http::HeaderValue>;

    /// Checks that the request's `Content-Type` header matches `expected`.
    ///
    /// # Errors
    ///
    /// Returns an error carrying both the expected and the actual content
    /// type if the header is missing or does not match.
    fn expect_content_type(&mut self, expected: &'static str) -> Result<()>;
}
impl private::Sealed for Request {}
#[async_trait]
impl RequestExt for Request {
    fn context(&self) -> &crate::ProjectContext {
        // The context is stored behind an `Arc`; the returned `&Arc<_>`
        // deref-coerces to `&ProjectContext`.
        self.extensions()
            .get::<Arc<crate::ProjectContext>>()
            .expect("AppContext extension missing")
    }

    fn project_config(&self) -> &crate::config::ProjectConfig {
        self.context().config()
    }

    fn router(&self) -> &Router {
        self.context().router()
    }

    fn app_name(&self) -> Option<&str> {
        self.extensions()
            .get::<AppName>()
            .map(|AppName(name)| name.as_str())
    }

    fn route_name(&self) -> Option<&str> {
        self.extensions()
            .get::<RouteName>()
            .map(|RouteName(name)| name.as_str())
    }

    fn path_params(&self) -> &PathParams {
        self.extensions()
            .get::<PathParams>()
            .expect("PathParams extension missing")
    }

    fn path_params_mut(&mut self) -> &mut PathParams {
        // Lazily create an empty set so callers can always insert parameters.
        self.extensions_mut().get_or_insert_default::<PathParams>()
    }

    #[cfg(feature = "db")]
    fn db(&self) -> &Database {
        self.context().database()
    }

    fn session(&self) -> &Session {
        self.extensions()
            .get::<Session>()
            .expect("Session extension missing. Did you forget to add the SessionMiddleware?")
    }

    fn session_mut(&mut self) -> &mut Session {
        self.extensions_mut()
            .get_mut::<Session>()
            .expect("Session extension missing. Did you forget to add the SessionMiddleware?")
    }

    async fn form_data(&mut self) -> Result<Bytes> {
        if self.method() == http::Method::GET || self.method() == http::Method::HEAD {
            // Bodyless methods submit form fields through the query string.
            if let Some(query) = self.uri().query() {
                return Ok(Bytes::copy_from_slice(query.as_bytes()));
            }
            Ok(Bytes::new())
        } else {
            self.expect_content_type(FORM_CONTENT_TYPE)?;
            // Move the body out (leaving a default body behind) so it can be
            // consumed without cloning.
            let body = std::mem::take(self.body_mut());
            let bytes = body.into_bytes().await?;
            Ok(bytes)
        }
    }

    #[cfg(feature = "json")]
    async fn json<T: serde::de::DeserializeOwned>(&mut self) -> Result<T> {
        self.expect_content_type(JSON_CONTENT_TYPE)?;
        let body = std::mem::take(self.body_mut());
        let bytes = body.into_bytes().await?;
        Ok(serde_json::from_slice(&bytes)?)
    }

    fn content_type(&self) -> Option<&http::HeaderValue> {
        self.headers().get(http::header::CONTENT_TYPE)
    }

    fn expect_content_type(&mut self, expected: &'static str) -> Result<()> {
        // Returns the `type/subtype` part of a media type, i.e. everything
        // before the first `;` (the parameters), with surrounding whitespace
        // removed.
        fn essence(media_type: &str) -> &str {
            media_type
                .split_once(';')
                .map_or(media_type, |(essence, _params)| essence)
                .trim()
        }

        // A missing header is treated as an empty content type, which never
        // matches and therefore produces an error below.
        let content_type = self
            .content_type()
            .map_or("".into(), |value| String::from_utf8_lossy(value.as_bytes()));

        // Media types may carry parameters (e.g. `application/json;
        // charset=utf-8`), and their type/subtype are case-insensitive
        // (RFC 9110, section 8.3.1). Comparing the raw header byte-for-byte
        // would reject such valid requests, so only the essences are
        // compared, ignoring ASCII case.
        if essence(&content_type).eq_ignore_ascii_case(essence(expected)) {
            Ok(())
        } else {
            Err(ErrorRepr::InvalidContentType {
                expected,
                actual: content_type.into_owned(),
            }
            .into())
        }
    }
}
/// Newtype around the name of an app, looked up in the request extensions
/// by `RequestExt::app_name`.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[repr(transparent)]
pub(crate) struct AppName(pub(crate) String);
/// Newtype around the name of a route, looked up in the request extensions
/// by `RequestExt::route_name`.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[repr(transparent)]
pub(crate) struct RouteName(pub(crate) String);
/// An ordered collection of named path parameters with string values.
///
/// Entries keep their insertion order, so they can be looked up either by
/// name or by position.
#[derive(Clone, Debug)]
pub struct PathParams {
    // Parameter name -> raw string value, in insertion order.
    params: IndexMap<String, String>,
}
impl Default for PathParams {
fn default() -> Self {
Self::new()
}
}
impl PathParams {
    /// Creates an empty set of path parameters.
    #[must_use]
    pub fn new() -> Self {
        let params = IndexMap::new();
        Self { params }
    }

    /// Adds a parameter. If `name` is already present, its value is overwritten.
    pub fn insert(&mut self, name: String, value: String) {
        // Any previously stored value is intentionally discarded.
        self.params.insert(name, value);
    }

    /// Returns an iterator over `(name, value)` pairs in insertion order.
    pub fn iter(&self) -> impl Iterator<Item = (&str, &str)> {
        self.params
            .iter()
            .map(|(key, val)| (key.as_str(), val.as_str()))
    }

    /// Returns the number of stored parameters.
    #[must_use]
    pub fn len(&self) -> usize {
        self.params.len()
    }

    /// Returns `true` if no parameters are stored.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.params.is_empty()
    }

    /// Looks up a parameter's value by its name.
    #[must_use]
    pub fn get(&self, name: &str) -> Option<&str> {
        let value = self.params.get(name)?;
        Some(value.as_str())
    }

    /// Returns the value of the parameter at position `index`.
    #[must_use]
    pub fn get_index(&self, index: usize) -> Option<&str> {
        let (_, value) = self.params.get_index(index)?;
        Some(value.as_str())
    }

    /// Returns the name of the parameter at position `index`.
    #[must_use]
    pub fn key_at_index(&self, index: usize) -> Option<&str> {
        let (key, _) = self.params.get_index(index)?;
        Some(key.as_str())
    }

    /// Deserializes the path parameters into `T`.
    ///
    /// # Errors
    ///
    /// Returns a [`PathParamsDeserializerError`] if the parameters cannot be
    /// deserialized into `T`.
    pub fn parse<'de, T: serde::Deserialize<'de>>(
        &'de self,
    ) -> std::result::Result<T, PathParamsDeserializerError> {
        let deserializer = path_params_deserializer::PathParamsDeserializer::new(self);
        T::deserialize(deserializer)
    }
}
/// Parses `application/x-www-form-urlencoded` bytes into `(key, value)` pairs.
///
/// Each item is a `Cow`, so it borrows from `bytes` when possible.
pub(crate) fn query_pairs(bytes: &Bytes) -> impl Iterator<Item = (Cow<'_, str>, Cow<'_, str>)> {
    let raw: &[u8] = bytes;
    form_urlencoded::parse(raw)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[cot::test]
    async fn form_data() {
        let body = Body::fixed("hello=world");
        let mut request = http::Request::builder()
            .method(http::Method::POST)
            .header(http::header::CONTENT_TYPE, FORM_CONTENT_TYPE)
            .body(body)
            .unwrap();

        let form = request.form_data().await.unwrap();

        assert_eq!(form, Bytes::from_static(b"hello=world"));
    }

    #[cfg(feature = "json")]
    #[cot::test]
    async fn json() {
        let payload = r#"{"hello":"world"}"#;
        let mut request = http::Request::builder()
            .method(http::Method::POST)
            .header(http::header::CONTENT_TYPE, JSON_CONTENT_TYPE)
            .body(Body::fixed(payload))
            .unwrap();

        let parsed: serde_json::Value = request.json().await.unwrap();

        let expected = serde_json::json!({ "hello": "world" });
        assert_eq!(parsed, expected);
    }

    #[test]
    fn path_params() {
        let mut params = PathParams::new();
        params.insert(String::from("name"), String::from("world"));

        assert_eq!(params.get("name"), Some("world"));
        assert_eq!(params.get("missing"), None);
    }

    #[test]
    fn path_params_parse() {
        #[derive(Debug, PartialEq, Eq, serde::Deserialize)]
        struct Params {
            hello: String,
            foo: String,
        }

        let mut path_params = PathParams::new();
        for (name, value) in [("hello", "world"), ("foo", "bar")] {
            path_params.insert(name.to_owned(), value.to_owned());
        }

        let parsed: Params = path_params.parse().unwrap();

        let expected = Params {
            hello: "world".to_owned(),
            foo: "bar".to_owned(),
        };
        assert_eq!(parsed, expected);
    }

    #[test]
    fn create_query_pairs() {
        let input = Bytes::from_static(b"hello=world&foo=bar");

        let pairs = query_pairs(&input).collect::<Vec<_>>();

        let expected = vec![
            (Cow::from("hello"), Cow::from("world")),
            (Cow::from("foo"), Cow::from("bar")),
        ];
        assert_eq!(pairs, expected);
    }
}