#![doc = include_str!("../README.md")]
use axum_core::extract::{FromRequestParts, Request};
use axum_core::response::Response;
use cookie_rs::{Cookie, CookieJar};
use http::header::{COOKIE, SET_COOKIE};
use http::request::Parts;
use http::{HeaderValue, StatusCode};
use std::collections::BTreeSet;
use std::future::Future;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll};
use tower_layer::Layer;
use tower_service::Service;
/// Everything exported by the `cookie_rs` crate, re-exported as `crate::cookie`.
pub mod cookie {
    pub use ::cookie_rs::*;
}
pub mod prelude {
pub use crate::CookieLayer;
pub use crate::CookieManager;
pub use cookie_rs::prelude::*;
}
/// Per-request cookie jar, created by [`CookieMiddleware`] and handed to handlers as an
/// extractor.
///
/// Cloning is cheap and every clone shares the same jar through an `Arc<Mutex<_>>`.
/// Changes a handler makes are therefore visible to the middleware when it builds the
/// response's `Set-Cookie` headers.
#[derive(Clone)]
pub struct CookieManager {
/// Shared, lock-protected jar; all clones of this manager point at the same one.
jar: Arc<Mutex<CookieJar<'static>>>,
}
/// Accessors for the shared jar.
///
/// Every method takes the jar's mutex for the duration of a single call and panics if
/// that mutex has been poisoned.
impl CookieManager {
    /// Wraps `jar` so that clones of the returned manager all share it.
    pub fn new(jar: CookieJar<'static>) -> Self {
        let shared = Arc::new(Mutex::new(jar));
        Self { jar: shared }
    }

    /// Inserts `cookie` into the jar via `CookieJar::add`.
    pub fn add<C: Into<Cookie<'static>>>(&self, cookie: C) {
        self.jar.lock().unwrap().add(cookie);
    }

    /// Alias for [`add`](Self::add).
    pub fn set<C: Into<Cookie<'static>>>(&self, cookie: C) {
        self.add(cookie);
    }

    /// Forwards an owned copy of `name` to `CookieJar::remove`.
    // NOTE(review): whether this also emits a removal `Set-Cookie` depends on cookie_rs.
    pub fn remove(&self, name: &str) {
        self.jar.lock().unwrap().remove(name.to_owned());
    }

    /// Returns a clone of the cookie called `name`, or `None` if the jar has no such cookie.
    pub fn get(&self, name: &str) -> Option<Cookie<'static>> {
        self.jar.lock().unwrap().get(name).cloned()
    }

    /// Returns an owned snapshot of the jar's cookies as a `BTreeSet`.
    pub fn cookie(&self) -> BTreeSet<Cookie<'static>> {
        self.jar.lock().unwrap().cookie().into_iter().cloned().collect()
    }

    /// Returns the strings produced by `CookieJar::as_header_values`.
    ///
    /// The middleware appends each one to the response as a `Set-Cookie` header.
    pub fn as_header_value(&self) -> Vec<String> {
        self.jar.lock().unwrap().as_header_values()
    }
}
/// Extracts the [`CookieManager`] that [`CookieMiddleware`] stored in the request extensions.
///
/// # Errors
///
/// - The rejection stored by the middleware, when the request's `Cookie` header could not be
///   read or parsed (`400 Bad Request`).
/// - `500 Internal Server Error` with `"CookieLayer is not initialized"`, when no
///   middleware ran for this request.
impl<S> FromRequestParts<S> for CookieManager {
    type Rejection = (StatusCode, String);

    fn from_request_parts(
        parts: &mut Parts,
        _: &S,
    ) -> impl Future<Output = Result<Self, Self::Rejection>> + Send {
        // The lookup is synchronous, so the ready future avoids a heap-allocated `Box::pin`.
        // `unwrap_or_else` builds the error `String` only when the extension is missing.
        let result = parts
            .extensions
            .get::<Result<Self, Self::Rejection>>()
            .cloned()
            .unwrap_or_else(|| {
                Err((
                    StatusCode::INTERNAL_SERVER_ERROR,
                    "CookieLayer is not initialized".to_string(),
                ))
            });
        std::future::ready(result)
    }
}
/// `tower` [`Layer`] that wraps a service in [`CookieMiddleware`].
///
/// `CookieLayer::default()` parses incoming `Cookie` headers with `CookieJar::parse`.
/// [`CookieLayer::strict()`] switches to `CookieJar::parse_strict`.
#[derive(Clone, Debug, Default)]
pub struct CookieLayer {
    /// `true` to parse request cookies with `CookieJar::parse_strict`; `false` by default.
    strict: bool,
}
impl CookieLayer {
pub fn strict() -> Self {
Self { strict: true }
}
}
impl<S> Layer<S> for CookieLayer {
    type Service = CookieMiddleware<S>;

    /// Wraps `inner` in a [`CookieMiddleware`] that copies this layer's parse mode.
    fn layer(&self, inner: S) -> Self::Service {
        let strict = self.strict;
        CookieMiddleware { inner, strict }
    }
}
/// `tower` service produced by [`CookieLayer::layer`].
///
/// It parses request cookies into a [`CookieManager`] and writes the manager's cookies
/// back as `Set-Cookie` response headers.
#[derive(Clone, Debug)]
pub struct CookieMiddleware<S> {
    /// Copied from the layer: `true` selects `CookieJar::parse_strict`.
    strict: bool,
    /// The wrapped service.
    inner: S,
}
/// Handles each request in three steps:
///
/// 1. Joins every `Cookie` header field with `"; "` and parses the result with
///    `CookieJar::parse`, or `CookieJar::parse_strict` in strict mode.
/// 2. Stores the outcome in the request extensions, where the [`CookieManager`] extractor
///    reads it. The outcome is either `Ok(CookieManager)` or
///    `Err((StatusCode::BAD_REQUEST, message))` when a field cannot be read as a string or
///    fails to parse.
/// 3. Once the inner service responds, appends each value from
///    [`CookieManager::as_header_value`] to the response as a `Set-Cookie` header.
///
/// A malformed `Cookie` header does not stop the request: the inner service still runs,
/// and only extracting [`CookieManager`] fails, with the stored rejection.
///
/// The response body type is independent of the request body type.
impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for CookieMiddleware<S>
where
    S: Service<Request<ReqBody>, Response = Response<ResBody>> + Send + 'static,
    S::Future: Send + 'static,
    ReqBody: Send + 'static,
    ResBody: Send + 'static,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
        let strict = self.strict;
        // HTTP/2 lets a client split its cookies across several `Cookie` fields
        // (RFC 9113 §8.2.3). Reading only the first field would silently drop the rest,
        // so rejoin them with the standard "; " separator. No field at all yields "".
        let manager = req
            .headers()
            .get_all(COOKIE)
            .iter()
            .map(|value| value.to_str())
            .collect::<Result<Vec<&str>, _>>()
            .map(|fields| fields.join("; "))
            .map_err(|e| (StatusCode::BAD_REQUEST, e.to_string()))
            .and_then(|cookie| {
                let jar = if strict {
                    CookieJar::parse_strict(cookie)
                } else {
                    CookieJar::parse(cookie)
                };
                jar.map(CookieManager::new)
                    .map_err(|e| (StatusCode::BAD_REQUEST, e.to_string()))
            });
        // The extractor clones this entry. `CookieManager` clones share one jar, so the
        // handler's changes are visible to the response code below.
        req.extensions_mut().insert(manager.clone());
        let fut = self.inner.call(req);
        Box::pin(async move {
            let mut response = fut.await?;
            if let Ok(manager) = manager {
                let headers = response.headers_mut();
                for cookie in manager.as_header_value() {
                    // A cookie containing bytes illegal in a header value (e.g. control
                    // characters) used to panic here. Skip it rather than abort the response.
                    if let Ok(value) = HeaderValue::from_str(&cookie) {
                        headers.append(SET_COOKIE, value);
                    }
                }
            }
            Ok(response)
        })
    }
}