axum_cookie/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use axum_core::extract::{FromRequestParts, Request};
4use axum_core::response::Response;
5use cookie_rs::{Cookie, CookieJar};
6use http::header::{COOKIE, SET_COOKIE};
7use http::request::Parts;
8use http::{HeaderValue, StatusCode};
9use std::collections::BTreeSet;
10use std::future::Future;
11use std::pin::Pin;
12use std::sync::{Arc, Mutex};
13use std::task::{Context, Poll};
14use tower_layer::Layer;
15use tower_service::Service;
16
17pub mod cookie {
18    pub use cookie_rs::*;
19}
20
21pub mod prelude {
22    pub use crate::CookieLayer;
23    pub use crate::CookieManager;
24    pub use cookie_rs::prelude::*;
25}
26
27/// Manages cookies using a thread-safe `CookieJar`.
28/// This struct provides methods to add, remove, and retrieve cookies,
29/// as well as generate `Set-Cookie` headers for HTTP responses.
30#[derive(Clone)]
31pub struct CookieManager {
32    jar: Arc<Mutex<CookieJar<'static>>>,
33}
34
35impl CookieManager {
36    /// Creates a new instance of `CookieManager` with the specified cookie jar.
37    ///
38    /// # Arguments
39    /// * `jar` - The initial cookie jar to manage cookies.
40    pub fn new(jar: CookieJar<'static>) -> Self {
41        Self {
42            jar: Arc::new(Mutex::new(jar)),
43        }
44    }
45
46    /// Adds a cookie to the jar.
47    ///
48    /// # Arguments
49    /// * `cookie` - The cookie to add to the jar.
50    pub fn add<C: Into<Cookie<'static>>>(&self, cookie: C) {
51        let mut jar = self.jar.lock().unwrap();
52
53        jar.add(cookie);
54    }
55
56    /// Adds a cookie to the jar.
57    ///
58    /// # Arguments
59    /// * `cookie` - The cookie to add to the jar.
60    ///
61    /// > alias for `CookieManager::add`
62    pub fn set<C: Into<Cookie<'static>>>(&self, cookie: C) {
63        self.add(cookie);
64    }
65
66    /// Removes a cookie from the jar by its name.
67    ///
68    /// # Arguments
69    /// * `name` - The name of the cookie to remove.
70    pub fn remove(&self, name: &str) {
71        let mut jar = self.jar.lock().unwrap();
72
73        jar.remove(name.to_owned());
74    }
75
76    /// Retrieves a cookie from the jar by its name.
77    ///
78    /// # Arguments
79    /// * `name` - The name of the cookie to retrieve.
80    ///
81    /// # Returns
82    /// * `Option<Cookie<'static>>` - The cookie if found, otherwise `None`.
83    pub fn get(&self, name: &str) -> Option<Cookie<'static>> {
84        let jar = self.jar.lock().unwrap();
85
86        jar.get(name).cloned()
87    }
88
89    /// Returns all cookies in the jar as a set.
90    ///
91    /// # Returns
92    /// * `BTreeSet<Cookie<'static>>` - A set of all cookies currently in the jar.
93    pub fn cookie(&self) -> BTreeSet<Cookie<'static>> {
94        let jar = self.jar.lock().unwrap();
95
96        jar.cookie().into_iter().cloned().collect()
97    }
98
99    /// Generates `Set-Cookie` header value for all cookies in the jar.
100    ///
101    /// # Returns
102    /// * `Vec<String>` - A vector of `Set-Cookie` header string value.
103    pub fn as_header_value(&self) -> Vec<String> {
104        let jar = self.jar.lock().unwrap();
105
106        jar.as_header_values()
107    }
108}
109
110impl<S> FromRequestParts<S> for CookieManager {
111    type Rejection = (StatusCode, String);
112
113    fn from_request_parts(
114        parts: &mut Parts,
115        _: &S,
116    ) -> impl Future<Output = Result<Self, Self::Rejection>> + Send {
117        Box::pin(async move {
118            parts
119                .extensions
120                .get::<Result<Self, Self::Rejection>>()
121                .cloned()
122                .ok_or((
123                    StatusCode::INTERNAL_SERVER_ERROR,
124                    "CookieLayer is not initialized".to_string(),
125                ))?
126        })
127    }
128}
129
/// A middleware layer for processing cookies.
///
/// Wrapping a service with this layer produces a `CookieMiddleware`.
/// `CookieLayer::default()` selects lenient parsing; [`CookieLayer::strict`]
/// selects strict parsing.
#[derive(Clone, Debug, Default)]
pub struct CookieLayer {
    /// Parsing mode handed to the middleware: `true` makes it call
    /// `CookieJar::parse_strict`, `false` makes it call `CookieJar::parse`.
    strict: bool,
}

impl CookieLayer {
    /// Creates a layer with strict cookie parsing enabled.
    #[must_use]
    pub fn strict() -> Self {
        Self { strict: true }
    }
}
143
144impl<S> Layer<S> for CookieLayer {
145    type Service = CookieMiddleware<S>;
146
147    fn layer(&self, inner: S) -> Self::Service {
148        CookieMiddleware {
149            strict: self.strict,
150            inner,
151        }
152    }
153}
154
/// Middleware for handling HTTP requests and responses with cookies.
///
/// It parses the request's `Cookie` headers into a cookie manager, stores the
/// outcome in the request extensions, and appends `Set-Cookie` headers to the
/// response produced by the wrapped service.
#[derive(Clone, Debug)]
pub struct CookieMiddleware<S> {
    /// Selects `CookieJar::parse_strict` (`true`) or `CookieJar::parse` (`false`).
    strict: bool,
    /// The wrapped service that receives the request.
    inner: S,
}
162
163impl<S, ReqBody> Service<Request<ReqBody>> for CookieMiddleware<S>
164where
165    S: Service<Request<ReqBody>, Response = Response<ReqBody>> + Send + 'static,
166    S::Future: Send + 'static,
167    ReqBody: Send + 'static,
168{
169    type Response = S::Response;
170    type Error = S::Error;
171    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
172
173    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
174        self.inner.poll_ready(cx)
175    }
176
177    fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
178        let cookie = req
179            .headers()
180            .get_all(COOKIE)
181            .iter()
182            .map(|h| h.to_str())
183            .collect::<Result<Box<[_]>, _>>()
184            .map(|c| c.join(";"));
185
186        let manager = cookie
187            .map(|cookie| {
188                match self.strict {
189                    false => CookieJar::parse(cookie),
190                    true => CookieJar::parse_strict(cookie),
191                }
192                .map(CookieManager::new)
193                .map_err(|e| (StatusCode::BAD_REQUEST, e.to_string()))
194            })
195            .map_err(|e| (StatusCode::BAD_REQUEST, e.to_string()))
196            .and_then(|inner| inner);
197
198        req.extensions_mut().insert(manager.clone());
199
200        let fut = self.inner.call(req);
201
202        Box::pin(async move {
203            let mut response = fut.await?;
204
205            if let Ok(manager) = manager {
206                for cookie in manager.as_header_value() {
207                    response
208                        .headers_mut()
209                        .append(SET_COOKIE, HeaderValue::from_str(&cookie).unwrap());
210                }
211            }
212
213            Ok(response)
214        })
215    }
216}