1#![doc = include_str!("../README.md")]
2
3use axum_core::extract::{FromRequestParts, Request};
4use axum_core::response::Response;
5use cookie_rs::{Cookie, CookieJar};
6use http::header::{COOKIE, SET_COOKIE};
7use http::request::Parts;
8use http::{HeaderValue, StatusCode};
9use std::collections::BTreeSet;
10use std::future::Future;
11use std::pin::Pin;
12use std::sync::{Arc, Mutex};
13use std::task::{Context, Poll};
14use tower_layer::Layer;
15use tower_service::Service;
16
17pub mod cookie {
18 pub use cookie_rs::*;
19}
20
21pub mod prelude {
22 pub use crate::CookieLayer;
23 pub use crate::CookieManager;
24 pub use cookie_rs::prelude::*;
25}
26
27#[derive(Clone)]
31pub struct CookieManager {
32 jar: Arc<Mutex<CookieJar<'static>>>,
33}
34
35impl CookieManager {
36 pub fn new(jar: CookieJar<'static>) -> Self {
41 Self {
42 jar: Arc::new(Mutex::new(jar)),
43 }
44 }
45
46 pub fn add<C: Into<Cookie<'static>>>(&self, cookie: C) {
51 let mut jar = self.jar.lock().unwrap();
52
53 jar.add(cookie);
54 }
55
56 pub fn set<C: Into<Cookie<'static>>>(&self, cookie: C) {
63 self.add(cookie);
64 }
65
66 pub fn remove(&self, name: &str) {
71 let mut jar = self.jar.lock().unwrap();
72
73 jar.remove(name.to_owned());
74 }
75
76 pub fn get(&self, name: &str) -> Option<Cookie<'static>> {
84 let jar = self.jar.lock().unwrap();
85
86 jar.get(name).cloned()
87 }
88
89 pub fn cookie(&self) -> BTreeSet<Cookie<'static>> {
94 let jar = self.jar.lock().unwrap();
95
96 jar.cookie().into_iter().cloned().collect()
97 }
98
99 pub fn as_header_value(&self) -> Vec<String> {
104 let jar = self.jar.lock().unwrap();
105
106 jar.as_header_values()
107 }
108}
109
110impl<S> FromRequestParts<S> for CookieManager {
111 type Rejection = (StatusCode, String);
112
113 fn from_request_parts(
114 parts: &mut Parts,
115 _: &S,
116 ) -> impl Future<Output = Result<Self, Self::Rejection>> + Send {
117 Box::pin(async move {
118 parts
119 .extensions
120 .get::<Result<Self, Self::Rejection>>()
121 .cloned()
122 .ok_or((
123 StatusCode::INTERNAL_SERVER_ERROR,
124 "CookieLayer is not initialized".to_string(),
125 ))?
126 })
127 }
128}
129
/// Tower layer that wraps a service in [`CookieMiddleware`].
///
/// `CookieLayer::default()` parses incoming `Cookie` headers with `CookieJar::parse`;
/// `CookieLayer::strict()` uses `CookieJar::parse_strict` instead.
#[derive(Clone, Debug, Default)]
pub struct CookieLayer {
    /// Whether `Cookie` headers are parsed with `CookieJar::parse_strict`.
    strict: bool,
}
136
137impl CookieLayer {
138 pub fn strict() -> Self {
140 Self { strict: true }
141 }
142}
143
144impl<S> Layer<S> for CookieLayer {
145 type Service = CookieMiddleware<S>;
146
147 fn layer(&self, inner: S) -> Self::Service {
148 CookieMiddleware {
149 strict: self.strict,
150 inner,
151 }
152 }
153}
154
/// Tower service produced by [`CookieLayer`].
///
/// For each request it parses the `Cookie` headers into a [`CookieManager`] and stores the
/// parse result in the request extensions. After the inner service responds, it appends the
/// manager's header values to the response as `Set-Cookie` headers.
#[derive(Clone, Debug)]
pub struct CookieMiddleware<S> {
    /// Whether `Cookie` headers are parsed with `CookieJar::parse_strict`.
    strict: bool,
    /// The wrapped service.
    inner: S,
}
162
163impl<S, ReqBody> Service<Request<ReqBody>> for CookieMiddleware<S>
164where
165 S: Service<Request<ReqBody>, Response = Response<ReqBody>> + Send + 'static,
166 S::Future: Send + 'static,
167 ReqBody: Send + 'static,
168{
169 type Response = S::Response;
170 type Error = S::Error;
171 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
172
173 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
174 self.inner.poll_ready(cx)
175 }
176
177 fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
178 let cookie = req
179 .headers()
180 .get_all(COOKIE)
181 .iter()
182 .map(|h| h.to_str())
183 .collect::<Result<Box<[_]>, _>>()
184 .map(|c| c.join(";"));
185
186 let manager = cookie
187 .map(|cookie| {
188 match self.strict {
189 false => CookieJar::parse(cookie),
190 true => CookieJar::parse_strict(cookie),
191 }
192 .map(CookieManager::new)
193 .map_err(|e| (StatusCode::BAD_REQUEST, e.to_string()))
194 })
195 .map_err(|e| (StatusCode::BAD_REQUEST, e.to_string()))
196 .and_then(|inner| inner);
197
198 req.extensions_mut().insert(manager.clone());
199
200 let fut = self.inner.call(req);
201
202 Box::pin(async move {
203 let mut response = fut.await?;
204
205 if let Ok(manager) = manager {
206 for cookie in manager.as_header_value() {
207 response
208 .headers_mut()
209 .append(SET_COOKIE, HeaderValue::from_str(&cookie).unwrap());
210 }
211 }
212
213 Ok(response)
214 })
215 }
216}