mod accept;
mod addr;
#[cfg(feature = "compression")]
mod compress;
#[cfg(feature = "cookie")]
#[cfg_attr(docsrs, doc(cfg(feature = "cookie")))]
pub mod cookie;
mod data;
mod form;
mod json;
#[cfg(feature = "multipart")]
mod multipart;
mod path;
mod query;
mod real_ip;
mod redirect;
#[cfg(feature = "sse")]
#[cfg_attr(docsrs, doc(cfg(feature = "sse")))]
pub mod sse;
#[cfg(feature = "static-files")]
mod static_file;
#[cfg(feature = "tempfile")]
mod tempfile;
#[cfg(feature = "xml")]
mod xml;
#[cfg(feature = "yaml")]
mod yaml;
#[doc(inline)]
pub use headers;
#[cfg(feature = "csrf")]
mod csrf;
mod typed_header;
#[cfg(feature = "websocket")]
#[cfg_attr(docsrs, doc(cfg(feature = "websocket")))]
pub mod websocket;
use std::{convert::Infallible, fmt::Debug, future::Future};
#[cfg(feature = "compression")]
pub use async_compression::Level as CompressionLevel;
use bytes::Bytes;
use futures_util::FutureExt;
use http::header;
#[cfg(feature = "compression")]
pub use self::compress::{Compress, CompressionAlgo};
#[cfg(feature = "csrf")]
pub use self::csrf::{CsrfToken, CsrfVerifier};
#[cfg(feature = "multipart")]
pub use self::multipart::{Field, Multipart};
pub(crate) use self::path::PathDeserializer;
#[cfg(feature = "static-files")]
pub use self::static_file::{StaticFileRequest, StaticFileResponse};
#[cfg(feature = "tempfile")]
pub use self::tempfile::TempFile;
#[cfg(feature = "xml")]
pub use self::xml::Xml;
#[cfg(feature = "yaml")]
pub use self::yaml::Yaml;
pub use self::{
accept::Accept,
addr::{LocalAddr, RemoteAddr},
data::Data,
form::Form,
json::Json,
path::Path,
query::Query,
real_ip::RealIp,
redirect::Redirect,
typed_header::TypedHeader,
};
use crate::{
body::Body,
error::{ReadBodyError, Result},
http::{
HeaderValue, Method, StatusCode, Uri, Version,
header::{HeaderMap, HeaderName},
},
request::Request,
response::Response,
};
/// The body of a request, which can be taken out at most once.
///
/// Extractors that need the raw body call [`RequestBody::take`]; every later
/// call finds the slot empty and fails.
#[derive(Default)]
pub struct RequestBody(Option<Body>);

impl RequestBody {
    /// Wraps `body` so it is handed to the first caller of [`take`](Self::take).
    pub fn new(body: Body) -> Self {
        RequestBody(Some(body))
    }

    /// Moves the body out, leaving this wrapper empty.
    ///
    /// # Errors
    ///
    /// Returns [`ReadBodyError::BodyHasBeenTaken`] if the body was already
    /// taken, or if this value came from [`Default`] (which starts empty).
    pub fn take(&mut self) -> Result<Body, ReadBodyError> {
        match self.0.take() {
            Some(taken) => Ok(taken),
            None => Err(ReadBodyError::BodyHasBeenTaken),
        }
    }

    /// Returns `true` while the body is still available.
    #[inline]
    pub fn is_some(&self) -> bool {
        !self.is_none()
    }

    /// Returns `true` once the body has been taken (or was never set).
    #[inline]
    pub fn is_none(&self) -> bool {
        self.0.is_none()
    }
}
pub trait FromRequest<'a>: Sized {
fn from_request(
req: &'a Request,
body: &mut RequestBody,
) -> impl Future<Output = Result<Self>> + Send;
fn from_request_without_body(req: &'a Request) -> impl Future<Output = Result<Self>> + Send {
async move {
Self::from_request(req, &mut Default::default())
.boxed()
.await
}
}
}
/// Types that can be turned into a [`Response`].
///
/// Besides the required [`into_response`](IntoResponse::into_response), the
/// trait provides combinators that wrap `self` and adjust the final response.
pub trait IntoResponse: Send {
    /// Consumes `self` and produces a response.
    fn into_response(self) -> Response;

    /// Wraps `self` so that one extra header is appended to its response.
    ///
    /// If either `key` or `value` fails to convert, no header is added.
    fn with_header<K, V>(self, key: K, value: V) -> WithHeader<Self>
    where
        K: TryInto<HeaderName>,
        V: TryInto<HeaderValue>,
        Self: Sized,
    {
        let header = match (key.try_into(), value.try_into()) {
            (Ok(name), Ok(val)) => Some((name, val)),
            _ => None,
        };
        WithHeader { inner: self, header }
    }

    /// Wraps `self` so that its response gets the given `Content-Type`,
    /// replacing any existing one. An unconvertible value is ignored.
    fn with_content_type<V>(self, content_type: V) -> WithContentType<Self>
    where
        V: TryInto<HeaderValue>,
        Self: Sized,
    {
        let content_type = content_type.try_into().ok();
        WithContentType { inner: self, content_type }
    }

    /// Wraps `self` so that its response status is overridden with `status`.
    fn with_status(self, status: StatusCode) -> WithStatus<Self>
    where
        Self: Sized,
    {
        WithStatus { status, inner: self }
    }

    /// Wraps `self` so that its response body is replaced with `body`.
    fn with_body(self, body: impl Into<Body>) -> WithBody<Self>
    where
        Self: Sized,
    {
        let body: Body = body.into();
        WithBody { inner: self, body }
    }
}
/// `Infallible` has no values, so `into_response` can never be called.
impl IntoResponse for Infallible {
    fn into_response(self) -> Response {
        // An empty match on an uninhabited type is checked by the compiler,
        // unlike `unreachable!()`, which would only panic at runtime.
        match self {}
    }
}
/// Response wrapper returned by [`IntoResponse::with_header`].
///
/// Appends one header to the wrapped response.
pub struct WithHeader<T> {
    inner: T,
    // `None` when the name or value given to `with_header` failed to convert;
    // the response is then left unchanged.
    header: Option<(HeaderName, HeaderValue)>,
}

impl<T: IntoResponse> IntoResponse for WithHeader<T> {
    fn into_response(self) -> Response {
        let mut resp = self.inner.into_response();
        // `self` is consumed, so move the pair out instead of cloning the value.
        // `append` (not `insert`) keeps any existing values for the same name.
        if let Some((key, value)) = self.header {
            resp.headers_mut().append(key, value);
        }
        resp
    }
}
/// Response wrapper returned by [`IntoResponse::with_content_type`].
///
/// Sets `Content-Type` on the wrapped response, replacing any previous value.
pub struct WithContentType<T> {
    inner: T,
    // `None` when the value passed to `with_content_type` failed to convert.
    content_type: Option<HeaderValue>,
}

impl<T: IntoResponse> IntoResponse for WithContentType<T> {
    fn into_response(self) -> Response {
        let WithContentType { inner, content_type } = self;
        let mut response = inner.into_response();
        if let Some(value) = content_type {
            response.headers_mut().insert(header::CONTENT_TYPE, value);
        }
        response
    }
}
/// Response wrapper returned by [`IntoResponse::with_status`].
///
/// Overrides the status code of the wrapped response.
pub struct WithStatus<T> {
    inner: T,
    status: StatusCode,
}

impl<T: IntoResponse> IntoResponse for WithStatus<T> {
    fn into_response(self) -> Response {
        let WithStatus { inner, status } = self;
        let mut response = inner.into_response();
        response.set_status(status);
        response
    }
}
/// Response wrapper returned by [`IntoResponse::with_body`].
///
/// Replaces the body of the wrapped response.
pub struct WithBody<T> {
    inner: T,
    body: Body,
}

impl<T: IntoResponse> IntoResponse for WithBody<T> {
    fn into_response(self) -> Response {
        let WithBody { inner, body } = self;
        let mut response = inner.into_response();
        response.set_body(body);
        response
    }
}
/// A [`Response`] is already a response; it is returned unchanged.
impl IntoResponse for Response {
    fn into_response(self) -> Self {
        self
    }
}
impl IntoResponse for String {
fn into_response(self) -> Response {
Response::builder()
.content_type("text/plain; charset=utf-8")
.body(self)
}
}
impl IntoResponse for &'static str {
fn into_response(self) -> Response {
Response::builder()
.content_type("text/plain; charset=utf-8")
.body(self)
}
}
impl IntoResponse for &'static [u8] {
fn into_response(self) -> Response {
Response::builder()
.content_type("application/octet-stream")
.body(self)
}
}
impl IntoResponse for Bytes {
fn into_response(self) -> Response {
Response::builder()
.content_type("application/octet-stream")
.body(self)
}
}
impl IntoResponse for Vec<u8> {
fn into_response(self) -> Response {
Response::builder()
.content_type("application/octet-stream")
.body(self)
}
}
/// Responds with an empty body.
impl IntoResponse for () {
    fn into_response(self) -> Response {
        let empty = Body::empty();
        Response::builder().body(empty)
    }
}
/// Responds with the given body and no explicit content type.
impl IntoResponse for Body {
    fn into_response(self) -> Response {
        let builder = Response::builder();
        builder.body(self)
    }
}
/// Responds with the given status code; the response is built with `finish()`.
impl IntoResponse for StatusCode {
    fn into_response(self) -> Response {
        let builder = Response::builder().status(self);
        builder.finish()
    }
}
/// Builds `T`'s response, then overrides its status code.
impl<T: IntoResponse> IntoResponse for (StatusCode, T) {
    fn into_response(self) -> Response {
        let (status, inner) = self;
        let mut response = inner.into_response();
        response.set_status(status);
        response
    }
}
/// Builds `T`'s response, overrides its status code, and merges the given
/// headers into it via `HeaderMap::extend`.
impl<T: IntoResponse> IntoResponse for (StatusCode, HeaderMap, T) {
    fn into_response(self) -> Response {
        let (status, headers, inner) = self;
        let mut response = inner.into_response();
        response.set_status(status);
        response.headers_mut().extend(headers);
        response
    }
}
/// Builds `T`'s response and merges the given headers into it via
/// `HeaderMap::extend`.
impl<T: IntoResponse> IntoResponse for (HeaderMap, T) {
    fn into_response(self) -> Response {
        let (headers, inner) = self;
        let mut response = inner.into_response();
        response.headers_mut().extend(headers);
        response
    }
}
/// Wraps markup so it is sent with `Content-Type: text/html; charset=utf-8`.
#[derive(Debug, Clone, Eq, PartialEq, Default)]
pub struct Html<T>(pub T);

impl<T: Into<String> + Send> IntoResponse for Html<T> {
    fn into_response(self) -> Response {
        let Html(markup) = self;
        let html: String = markup.into();
        Response::builder()
            .content_type("text/html; charset=utf-8")
            .body(html)
    }
}
/// Borrows the whole request; never touches the body.
impl<'a> FromRequest<'a> for &'a Request {
    async fn from_request(req: &'a Request, _body: &mut RequestBody) -> Result<Self> {
        let request: Self = req;
        Ok(request)
    }
}
/// Borrows the request URI; never touches the body.
impl<'a> FromRequest<'a> for &'a Uri {
    async fn from_request(req: &'a Request, _body: &mut RequestBody) -> Result<Self> {
        let uri = req.uri();
        Ok(uri)
    }
}
/// Extracts an owned copy of the request method; never touches the body.
impl<'a> FromRequest<'a> for Method {
    async fn from_request(req: &'a Request, _body: &mut RequestBody) -> Result<Self> {
        let method = req.method().to_owned();
        Ok(method)
    }
}
/// Extracts the HTTP version of the request; never touches the body.
impl<'a> FromRequest<'a> for Version {
    async fn from_request(req: &'a Request, _body: &mut RequestBody) -> Result<Self> {
        let version = req.version();
        Ok(version)
    }
}
/// Borrows the request headers; never touches the body.
impl<'a> FromRequest<'a> for &'a HeaderMap {
    async fn from_request(req: &'a Request, _body: &mut RequestBody) -> Result<Self> {
        let headers = req.headers();
        Ok(headers)
    }
}
/// Takes the raw request body.
///
/// Fails if the body has already been taken by another extractor.
impl<'a> FromRequest<'a> for Body {
    async fn from_request(_req: &'a Request, body: &mut RequestBody) -> Result<Self> {
        let taken = body.take()?;
        Ok(taken)
    }
}
/// Takes the request body, reads it fully, and decodes it as UTF-8.
///
/// Fails if the body was already taken, cannot be read, or is not valid UTF-8
/// (reported as `ReadBodyError::Utf8`).
impl<'a> FromRequest<'a> for String {
    async fn from_request(_req: &'a Request, body: &mut RequestBody) -> Result<Self> {
        let raw = body.take()?;
        let bytes = raw.into_bytes().await?;
        let text = String::from_utf8(bytes.to_vec()).map_err(ReadBodyError::Utf8)?;
        Ok(text)
    }
}
/// Takes the request body and reads it fully into [`Bytes`].
impl<'a> FromRequest<'a> for Bytes {
    async fn from_request(_req: &'a Request, body: &mut RequestBody) -> Result<Self> {
        let raw = body.take()?;
        let bytes = raw.into_bytes().await?;
        Ok(bytes)
    }
}
/// Takes the request body and reads it fully into a `Vec<u8>`.
impl<'a> FromRequest<'a> for Vec<u8> {
    async fn from_request(_req: &'a Request, body: &mut RequestBody) -> Result<Self> {
        let raw = body.take()?;
        let bytes = raw.into_vec().await?;
        Ok(bytes)
    }
}
/// Borrows the remote address stored in the request state.
impl<'a> FromRequest<'a> for &'a RemoteAddr {
    async fn from_request(req: &'a Request, _body: &mut RequestBody) -> Result<Self> {
        let state = req.state();
        Ok(&state.remote_addr)
    }
}
/// Borrows the local address stored in the request state.
impl<'a> FromRequest<'a> for &'a LocalAddr {
    async fn from_request(req: &'a Request, _body: &mut RequestBody) -> Result<Self> {
        let state = req.state();
        Ok(&state.local_addr)
    }
}
/// Optional extractor: never fails; an extraction error from `T` becomes `None`.
impl<'a, T: FromRequest<'a>> FromRequest<'a> for Option<T> {
    async fn from_request(req: &'a Request, body: &mut RequestBody) -> Result<Self> {
        let outcome = T::from_request(req, body).boxed().await;
        Ok(outcome.ok())
    }
}
/// Fallible extractor: never fails itself; `T`'s outcome, error included, is
/// handed to the caller as the extracted value.
impl<'a, T: FromRequest<'a>> FromRequest<'a> for Result<T> {
    async fn from_request(req: &'a Request, body: &mut RequestBody) -> Result<Self> {
        let outcome = T::from_request(req, body).boxed().await;
        Ok(outcome)
    }
}
// Unit tests for the built-in `IntoResponse` and `FromRequest` implementations
// defined in this module.
#[cfg(test)]
mod tests {
use super::*;
use crate::Addr;
// Checks the status, headers and body produced by each built-in
// `IntoResponse` impl and by the `with_*` combinators.
#[tokio::test]
async fn into_response() {
// Text bodies: `String` and `&'static str`.
let resp = "abc".to_string().into_response();
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.into_body().into_string().await.unwrap(), "abc");
let resp = "abc".into_response();
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.into_body().into_string().await.unwrap(), "abc");
// Binary bodies: an array literal (presumably resolved to the `&'static [u8]`
// impl via auto-ref and literal inference — TODO confirm), `Bytes`, `Vec<u8>`.
let resp = [1, 2, 3].into_response();
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.into_body().into_vec().await.unwrap(), &[1, 2, 3]);
let resp = Bytes::from_static(&[1, 2, 3]).into_response();
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.into_body().into_vec().await.unwrap(), &[1, 2, 3]);
let resp = vec![1, 2, 3].into_response();
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.into_body().into_vec().await.unwrap(), &[1, 2, 3]);
// `()` yields an empty body.
let resp = ().into_response();
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.into_body().into_vec().await.unwrap(), &[] as &[u8]);
// `(StatusCode, T)` overrides the status of `T`'s response.
let resp = (StatusCode::BAD_GATEWAY, "abc").into_response();
assert_eq!(resp.status(), StatusCode::BAD_GATEWAY);
assert_eq!(resp.into_body().into_string().await.unwrap(), "abc");
// `(HeaderMap, T)` adds headers while keeping the inner response's status,
// existing headers and body.
let resp = Response::builder()
.status(StatusCode::BAD_GATEWAY)
.header("Value1", "567")
.body("abc");
let mut headers = HeaderMap::new();
headers.append("Value2", HeaderValue::from_static("123"));
let resp = (headers, resp).into_response();
assert_eq!(resp.status(), StatusCode::BAD_GATEWAY);
assert_eq!(
resp.headers().get("Value1"),
Some(&HeaderValue::from_static("567"))
);
assert_eq!(
resp.headers().get("Value2"),
Some(&HeaderValue::from_static("123"))
);
assert_eq!(resp.into_body().into_string().await.unwrap(), "abc");
// `(StatusCode, HeaderMap, T)` overrides the status and adds headers.
let resp = Response::builder()
.status(StatusCode::OK)
.header("Value1", "567")
.body("abc");
let mut headers = HeaderMap::new();
headers.append("Value2", HeaderValue::from_static("123"));
let resp = (StatusCode::BAD_GATEWAY, headers, resp).into_response();
assert_eq!(resp.status(), StatusCode::BAD_GATEWAY);
assert_eq!(
resp.headers().get("Value1"),
Some(&HeaderValue::from_static("567"))
);
assert_eq!(
resp.headers().get("Value2"),
Some(&HeaderValue::from_static("123"))
);
assert_eq!(resp.into_body().into_string().await.unwrap(), "abc");
// A bare `StatusCode` yields that status with an empty body.
let resp = StatusCode::CREATED.into_response();
assert_eq!(resp.status(), StatusCode::CREATED);
assert!(resp.into_body().into_string().await.unwrap().is_empty());
// `Html` sets the HTML content type.
let resp = Html("abc").into_response();
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.content_type(), Some("text/html; charset=utf-8"));
assert_eq!(resp.into_body().into_string().await.unwrap(), "abc");
// `Json` serializes compactly and sets the JSON content type.
let resp = Json(serde_json::json!({ "a": 1, "b": 2})).into_response();
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.content_type(), Some("application/json; charset=utf-8"));
assert_eq!(
resp.into_body().into_string().await.unwrap(),
r#"{"a":1,"b":2}"#
);
// `Xml` is only checked when the `xml` feature is enabled.
#[cfg(feature = "xml")]
{
let resp = Xml(serde_json::json!({"a": 1, "b": 2})).into_response();
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(resp.content_type(), Some("application/xml; charset=utf-8"));
assert_eq!(
resp.into_body().into_string().await.unwrap(),
r#"<root><a>1</a><b>2</b></root>"#
);
}
// Combinators: `with_body` replaces the body and keeps the status.
let resp = StatusCode::CONFLICT.with_body("abc").into_response();
assert_eq!(resp.status(), StatusCode::CONFLICT);
assert_eq!(resp.into_body().into_string().await.unwrap(), "abc");
// Chained `with_header` calls add headers alongside the existing one.
let resp = Response::builder()
.header("Value1", "123")
.finish()
.with_header("Value2", "456")
.with_header("Value3", "789")
.into_response();
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(
resp.headers().get("Value1"),
Some(&HeaderValue::from_static("123"))
);
assert_eq!(
resp.headers().get("Value2"),
Some(&HeaderValue::from_static("456"))
);
assert_eq!(
resp.headers().get("Value3"),
Some(&HeaderValue::from_static("789"))
);
// An outer `with_status` overrides the inner status.
let resp = StatusCode::CONFLICT
.with_status(StatusCode::BAD_GATEWAY)
.into_response();
assert_eq!(resp.status(), StatusCode::BAD_GATEWAY);
assert!(resp.into_body().into_string().await.unwrap().is_empty());
}
// Checks the request-head extractors and the body extractors.
#[tokio::test]
async fn from_request() {
// Builds a fresh test request: HTTP/1.1 DELETE with two headers, a URI,
// body "abc", and custom remote/local addresses in the request state.
fn request() -> Request {
let mut req = Request::builder()
.version(Version::HTTP_11)
.method(Method::DELETE)
.header("Value1", "123")
.header("Value2", "456")
.uri(Uri::from_static("http://example.com/a/b"))
.body("abc");
req.state_mut().remote_addr = RemoteAddr(Addr::custom("test", "example"));
req.state_mut().local_addr = LocalAddr(Addr::custom("test", "example-local"));
req
}
// Head extractors never take the body, so they can all share one split request.
let req = request();
let (req, mut body) = req.split();
assert_eq!(
Version::from_request(&req, &mut body).await.unwrap(),
Version::HTTP_11
);
assert_eq!(
<&HeaderMap>::from_request(&req, &mut body).await.unwrap(),
&{
let mut headers = HeaderMap::new();
headers.append("Value1", HeaderValue::from_static("123"));
headers.append("Value2", HeaderValue::from_static("456"));
headers
}
);
assert_eq!(
<&Uri>::from_request(&req, &mut body).await.unwrap(),
&Uri::from_static("http://example.com/a/b")
);
assert_eq!(
<&RemoteAddr>::from_request(&req, &mut body).await.unwrap(),
&RemoteAddr(Addr::custom("test", "example"))
);
assert_eq!(
<&LocalAddr>::from_request(&req, &mut body).await.unwrap(),
&LocalAddr(Addr::custom("test", "example-local"))
);
assert_eq!(
<Method>::from_request(&req, &mut body).await.unwrap(),
Method::DELETE
);
// Body extractors take the body (it can be taken only once), so each one
// gets a freshly built request.
let req = request();
let (req, mut body) = req.split();
assert_eq!(
String::from_request(&req, &mut body).await.unwrap(),
"abc".to_string()
);
let req = request();
let (req, mut body) = req.split();
assert_eq!(
<Vec<u8>>::from_request(&req, &mut body).await.unwrap(),
b"abc"
);
let req = request();
let (req, mut body) = req.split();
assert_eq!(
Bytes::from_request(&req, &mut body).await.unwrap(),
Bytes::from_static(b"abc")
);
}
}