use super::{Middleware, Next};
use crate::{HttpEntity, Request, Response};
use cookie::{Cookie, CookieJar};
use std::pin::Pin;
/// Middleware that parses incoming `Cookie` request headers into a [`CookieJar`]
/// and, after the handler runs, emits `Set-Cookie` headers for every change
/// recorded in the response's jar.
///
/// Build it with `CookieMiddleware::new()` or `CookieMiddleware::default()`.
#[derive(Clone, Default)]
pub struct CookieMiddleware {
    // Private placeholder field: keeps the type from being built with a struct
    // literal outside this module.
    _v: (),
}
/// Typed access to the [`CookieJar`] stored in a request's or response's
/// extension map.
///
/// The trait is sealed via the private `sealed::Sealed` supertrait. Only the
/// `Request` and `Response` impls in this module exist, and no other code can add more.
pub trait CookieExt: self::sealed::Sealed + Sized {
    #[doc(hidden)]
    fn extensions(&self) -> &http::Extensions;
    #[doc(hidden)]
    fn extensions_mut(&mut self) -> &mut http::Extensions;

    /// Returns the attached cookie jar, or `None` if no jar has been inserted.
    fn cookies(&self) -> Option<&CookieJar> {
        self.extensions().get::<CookieJar>()
    }

    /// Returns the cookie jar mutably, inserting an empty jar first if none is
    /// attached yet.
    fn cookies_mut(&mut self) -> &mut CookieJar {
        // Check-then-insert instead of an early `return` out of `get_mut`. Returning
        // a conditional mutable borrow and then borrowing again on the other path
        // is rejected by the current borrow checker.
        if self.extensions().get::<CookieJar>().is_none() {
            self.extensions_mut().insert(CookieJar::new());
        }
        self.extensions_mut()
            .get_mut::<CookieJar>()
            .expect("a CookieJar is always present here: it was inserted above if missing")
    }

    /// Attaches `jar`, replacing any jar that was already attached, and returns
    /// `self` for chaining.
    #[must_use]
    fn with_cookies(mut self, jar: CookieJar) -> Self {
        self.extensions_mut().insert(jar);
        self
    }

    /// Looks up the value of the cookie called `name`.
    ///
    /// Returns `None` when no jar is attached or the jar has no such cookie.
    fn cookie(&self, name: &str) -> Option<&str> {
        self.cookies()
            .and_then(|c| c.get(name))
            .map(cookie::Cookie::value)
    }

    /// Adds `cookie` to the jar, creating the jar first if necessary.
    fn add_cookie(&mut self, cookie: Cookie<'static>) {
        self.cookies_mut().add(cookie);
    }

    /// Builder-style variant of [`CookieExt::add_cookie`].
    #[must_use]
    fn with_cookie(mut self, cookie: Cookie<'static>) -> Self {
        self.add_cookie(cookie);
        self
    }
}
impl self::sealed::Sealed for Request {}
impl self::sealed::Sealed for Response {}
/// Gives the `CookieExt` default methods access to the request's extension map.
impl CookieExt for Request {
    fn extensions(&self) -> &http::Extensions {
        // Method-call resolution prefers inherent methods over trait methods.
        // This is therefore presumably the request type's own inherent
        // `extensions()`, not a recursive call into this trait method.
        // NOTE(review): do not rewrite as `CookieExt::extensions(self)` or
        // `<Self as CookieExt>::extensions(self)`. If `Request` ever loses its
        // inherent method, this body would recurse forever, so confirm when
        // `Request`'s definition changes.
        self.extensions()
    }
    fn extensions_mut(&mut self) -> &mut http::Extensions {
        // Same as above: relies on the inherent `extensions_mut()` taking precedence.
        self.extensions_mut()
    }
}
/// Gives the `CookieExt` default methods access to the response's extension map.
impl CookieExt for Response {
    fn extensions(&self) -> &http::Extensions {
        // Method-call resolution prefers inherent methods over trait methods.
        // This is therefore presumably the response type's own inherent
        // `extensions()`, not a recursive call into this trait method.
        // NOTE(review): do not rewrite as `CookieExt::extensions(self)`. If
        // `Response` ever loses its inherent method, this body would recurse
        // forever.
        self.extensions()
    }
    fn extensions_mut(&mut self) -> &mut http::Extensions {
        // Same as above: relies on the inherent `extensions_mut()` taking precedence.
        self.extensions_mut()
    }
}
/// Holds the supertrait that seals `CookieExt`.
///
/// `Sealed` must be `pub` so it can appear in `CookieExt`'s bounds. The module
/// itself is private, so code outside the enclosing module cannot name the trait
/// and therefore cannot implement `CookieExt` for other types.
mod sealed {
    pub trait Sealed {}
}
/// Prints just the type name: the only field is a private placeholder that
/// carries no information.
impl std::fmt::Debug for CookieMiddleware {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // An empty `debug_struct(..).finish()` writes exactly the name, in both
        // normal and alternate (`{:#?}`) mode, so write it directly.
        f.write_str("CookieMiddleware")
    }
}
impl CookieMiddleware {
    /// Creates the cookie middleware.
    ///
    /// Identical to `CookieMiddleware::default()`. The type has no configuration.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }
}
#[async_trait]
impl Middleware for CookieMiddleware {
async fn apply(
self: Pin<&Self>,
mut request: Request,
next: Next<'_>,
) -> Result<Response, anyhow::Error> {
let jar = request
.headers()
.get_all("Cookie")
.into_iter()
.filter_map(|h| h.to_str().ok())
.flat_map(|h| h.split(';'))
.filter_map(|h| Cookie::parse_encoded(h).ok())
.map(cookie::Cookie::into_owned)
.fold(CookieJar::new(), |mut jar, cookie| {
jar.add_original(cookie);
jar
});
request.extensions_mut().insert(jar);
let mut response = next.apply(request).await?;
let result_jar = response.extensions_mut().remove::<CookieJar>();
if let Some(jar) = result_jar {
let headers = response.headers_mut();
for cookie in jar.delta() {
if let Ok(cookie) = cookie.encoded().to_string().try_into() {
headers.append("Set-Cookie", cookie);
}
}
}
Ok(response)
}
}