axum_extra/extract/cookie/
mod.rs

1//! Cookie parsing and cookie jar management.
2//!
3//! See [`CookieJar`], [`SignedCookieJar`], and [`PrivateCookieJar`] for more details.
4
5use axum::{
6    extract::FromRequestParts,
7    response::{IntoResponse, IntoResponseParts, Response, ResponseParts},
8};
9use http::{
10    header::{COOKIE, SET_COOKIE},
11    request::Parts,
12    HeaderMap,
13};
14use std::convert::Infallible;
15
16#[cfg(feature = "cookie-private")]
17mod private;
18#[cfg(feature = "cookie-signed")]
19mod signed;
20
21#[cfg(feature = "cookie-private")]
22pub use self::private::PrivateCookieJar;
23#[cfg(feature = "cookie-signed")]
24pub use self::signed::SignedCookieJar;
25
26pub use cookie::{Cookie, Expiration, SameSite};
27
28#[cfg(any(feature = "cookie-signed", feature = "cookie-private"))]
29pub use cookie::Key;
30
31/// Extractor that grabs cookies from the request and manages the jar.
32///
33/// Note that methods like [`CookieJar::add`], [`CookieJar::remove`], etc updates the [`CookieJar`]
34/// and returns it. This value _must_ be returned from the handler as part of the response for the
35/// changes to be propagated.
36///
37/// # Example
38///
39/// ```rust
40/// use axum::{
41///     Router,
42///     routing::{post, get},
43///     response::{IntoResponse, Redirect},
44///     http::StatusCode,
45/// };
46/// use axum_extra::{
47///     TypedHeader,
48///     headers::authorization::{Authorization, Bearer},
49///     extract::cookie::{CookieJar, Cookie},
50/// };
51///
52/// async fn create_session(
53///     TypedHeader(auth): TypedHeader<Authorization<Bearer>>,
54///     jar: CookieJar,
55/// ) -> Result<(CookieJar, Redirect), StatusCode> {
56///     if let Some(session_id) = authorize_and_create_session(auth.token()).await {
57///         Ok((
58///             // the updated jar must be returned for the changes
59///             // to be included in the response
60///             jar.add(Cookie::new("session_id", session_id)),
61///             Redirect::to("/me"),
62///         ))
63///     } else {
64///         Err(StatusCode::UNAUTHORIZED)
65///     }
66/// }
67///
68/// async fn me(jar: CookieJar) -> Result<(), StatusCode> {
69///     if let Some(session_id) = jar.get("session_id") {
70///         // fetch and render user...
71///         # Ok(())
72///     } else {
73///         Err(StatusCode::UNAUTHORIZED)
74///     }
75/// }
76///
77/// async fn authorize_and_create_session(token: &str) -> Option<String> {
78///     // authorize the user and create a session...
79///     # todo!()
80/// }
81///
82/// let app = Router::new()
83///     .route("/sessions", post(create_session))
84///     .route("/me", get(me));
85/// # let app: Router = app;
86/// ```
87#[must_use = "`CookieJar` should be returned as part of a `Response`, otherwise it does nothing."]
88#[derive(Debug, Default, Clone)]
89pub struct CookieJar {
90    jar: cookie::CookieJar,
91}
92
93impl<S> FromRequestParts<S> for CookieJar
94where
95    S: Send + Sync,
96{
97    type Rejection = Infallible;
98
99    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
100        Ok(Self::from_headers(&parts.headers))
101    }
102}
103
104fn cookies_from_request(headers: &HeaderMap) -> impl Iterator<Item = Cookie<'static>> + '_ {
105    headers
106        .get_all(COOKIE)
107        .into_iter()
108        .filter_map(|value| value.to_str().ok())
109        .flat_map(|value| value.split(';'))
110        .filter_map(|cookie| Cookie::parse_encoded(cookie.to_owned()).ok())
111}
112
113impl CookieJar {
114    /// Create a new `CookieJar` from a map of request headers.
115    ///
116    /// The cookies in `headers` will be added to the jar.
117    ///
118    /// This is intended to be used in middleware and other places where it might be difficult to
119    /// run extractors. Normally you should create `CookieJar`s through [`FromRequestParts`].
120    ///
121    /// [`FromRequestParts`]: axum::extract::FromRequestParts
122    pub fn from_headers(headers: &HeaderMap) -> Self {
123        let mut jar = cookie::CookieJar::new();
124        for cookie in cookies_from_request(headers) {
125            jar.add_original(cookie);
126        }
127        Self { jar }
128    }
129
130    /// Create a new empty `CookieJar`.
131    ///
132    /// This is intended to be used in middleware and other places where it might be difficult to
133    /// run extractors. Normally you should create `CookieJar`s through [`FromRequestParts`].
134    ///
135    /// If you need a jar that contains the headers from a request use `impl From<&HeaderMap> for
136    /// CookieJar`.
137    ///
138    /// [`FromRequestParts`]: axum::extract::FromRequestParts
139    pub fn new() -> Self {
140        Self::default()
141    }
142
143    /// Get a cookie from the jar.
144    ///
145    /// # Example
146    ///
147    /// ```rust
148    /// use axum_extra::extract::cookie::CookieJar;
149    /// use axum::response::IntoResponse;
150    ///
151    /// async fn handle(jar: CookieJar) {
152    ///     let value: Option<String> = jar
153    ///         .get("foo")
154    ///         .map(|cookie| cookie.value().to_owned());
155    /// }
156    /// ```
157    #[must_use]
158    pub fn get(&self, name: &str) -> Option<&Cookie<'static>> {
159        self.jar.get(name)
160    }
161
162    /// Remove a cookie from the jar.
163    ///
164    /// # Example
165    ///
166    /// ```rust
167    /// use axum_extra::extract::cookie::{CookieJar, Cookie};
168    /// use axum::response::IntoResponse;
169    ///
170    /// async fn handle(jar: CookieJar) -> CookieJar {
171    ///     jar.remove(Cookie::from("foo"))
172    /// }
173    /// ```
174    pub fn remove<C: Into<Cookie<'static>>>(mut self, cookie: C) -> Self {
175        self.jar.remove(cookie);
176        self
177    }
178
179    /// Add a cookie to the jar.
180    ///
181    /// The value will automatically be percent-encoded.
182    ///
183    /// # Example
184    ///
185    /// ```rust
186    /// use axum_extra::extract::cookie::{CookieJar, Cookie};
187    /// use axum::response::IntoResponse;
188    ///
189    /// async fn handle(jar: CookieJar) -> CookieJar {
190    ///     jar.add(Cookie::new("foo", "bar"))
191    /// }
192    /// ```
193    #[allow(clippy::should_implement_trait)]
194    pub fn add<C: Into<Cookie<'static>>>(mut self, cookie: C) -> Self {
195        self.jar.add(cookie);
196        self
197    }
198
199    /// Get an iterator over all cookies in the jar.
200    pub fn iter(&self) -> impl Iterator<Item = &'_ Cookie<'static>> {
201        self.jar.iter()
202    }
203}
204
205impl IntoResponseParts for CookieJar {
206    type Error = Infallible;
207
208    fn into_response_parts(self, mut res: ResponseParts) -> Result<ResponseParts, Self::Error> {
209        set_cookies(self.jar, res.headers_mut());
210        Ok(res)
211    }
212}
213
214impl IntoResponse for CookieJar {
215    fn into_response(self) -> Response {
216        (self, ()).into_response()
217    }
218}
219
220fn set_cookies(jar: cookie::CookieJar, headers: &mut HeaderMap) {
221    for cookie in jar.delta() {
222        if let Ok(header_value) = cookie.encoded().to_string().parse() {
223            headers.append(SET_COOKIE, header_value);
224        }
225    }
226
227    // we don't need to call `jar.reset_delta()` because `into_response_parts` consumes the cookie
228    // jar so it cannot be called multiple times.
229}
230
#[cfg(test)]
mod tests {
    use super::*;
    #[cfg(any(feature = "cookie-signed", feature = "cookie-private"))]
    use axum::extract::FromRef;
    use axum::{body::Body, http::Request, routing::get, Router};
    use http_body_util::BodyExt;
    use tower::ServiceExt;

    /// Generate an end-to-end test named `$name` that sets, reads back, and removes a cookie
    /// named `key` through handlers that extract the jar type `$jar`.
    macro_rules! cookie_test {
        ($name:ident, $jar:ty) => {
            #[tokio::test]
            async fn $name() {
                async fn set_cookie(jar: $jar) -> impl IntoResponse {
                    jar.add(Cookie::new("key", "value"))
                }

                async fn get_cookie(jar: $jar) -> impl IntoResponse {
                    jar.get("key").unwrap().value().to_owned()
                }

                async fn remove_cookie(jar: $jar) -> impl IntoResponse {
                    jar.remove(Cookie::from("key"))
                }

                let app = Router::new()
                    .route("/set", get(set_cookie))
                    .route("/get", get(get_cookie))
                    .route("/remove", get(remove_cookie))
                    .with_state(test_state());

                let res = app
                    .clone()
                    .oneshot(Request::builder().uri("/set").body(Body::empty()).unwrap())
                    .await
                    .unwrap();
                let cookie_value = res.headers()["set-cookie"].to_str().unwrap();

                let res = app
                    .clone()
                    .oneshot(
                        Request::builder()
                            .uri("/get")
                            .header("cookie", cookie_value)
                            .body(Body::empty())
                            .unwrap(),
                    )
                    .await
                    .unwrap();
                let body = body_text(res).await;
                assert_eq!(body, "value");

                let res = app
                    .clone()
                    .oneshot(
                        Request::builder()
                            .uri("/remove")
                            .header("cookie", cookie_value)
                            .body(Body::empty())
                            .unwrap(),
                    )
                    .await
                    .unwrap();
                assert!(res.headers()["set-cookie"]
                    .to_str()
                    .unwrap()
                    .contains("key=;"));
            }
        };
    }

    cookie_test!(plaintext_cookies, CookieJar);

    #[cfg(feature = "cookie-signed")]
    cookie_test!(signed_cookies, SignedCookieJar);
    #[cfg(feature = "cookie-signed")]
    cookie_test!(signed_cookies_with_custom_key, SignedCookieJar<CustomKey>);

    #[cfg(feature = "cookie-private")]
    cookie_test!(private_cookies, PrivateCookieJar);
    #[cfg(feature = "cookie-private")]
    cookie_test!(private_cookies_with_custom_key, PrivateCookieJar<CustomKey>);

    // `Key` is only re-exported when signed or private cookies are enabled, so everything that
    // mentions it is gated the same way. Without those features the plaintext test still runs
    // against a unit `AppState`.

    /// Router state holding the keys used by signed/private jars.
    #[cfg(any(feature = "cookie-signed", feature = "cookie-private"))]
    #[derive(Clone)]
    struct AppState {
        key: Key,
        custom_key: CustomKey,
    }

    /// Router state when no key-based jars are compiled in; plaintext jars accept any state.
    #[cfg(not(any(feature = "cookie-signed", feature = "cookie-private")))]
    #[derive(Clone)]
    struct AppState;

    /// Build a fresh state with newly generated keys.
    #[cfg(any(feature = "cookie-signed", feature = "cookie-private"))]
    fn test_state() -> AppState {
        AppState {
            key: Key::generate(),
            custom_key: CustomKey(Key::generate()),
        }
    }

    /// Build the (empty) state used when no key-based jars are compiled in.
    #[cfg(not(any(feature = "cookie-signed", feature = "cookie-private")))]
    fn test_state() -> AppState {
        AppState
    }

    #[cfg(any(feature = "cookie-signed", feature = "cookie-private"))]
    impl FromRef<AppState> for Key {
        fn from_ref(state: &AppState) -> Key {
            state.key.clone()
        }
    }

    #[cfg(any(feature = "cookie-signed", feature = "cookie-private"))]
    impl FromRef<AppState> for CustomKey {
        fn from_ref(state: &AppState) -> CustomKey {
            state.custom_key.clone()
        }
    }

    /// Newtype used to check that jars work with a user-defined key type.
    #[cfg(any(feature = "cookie-signed", feature = "cookie-private"))]
    #[derive(Clone)]
    struct CustomKey(Key);

    #[cfg(any(feature = "cookie-signed", feature = "cookie-private"))]
    impl From<CustomKey> for Key {
        fn from(custom: CustomKey) -> Self {
            custom.0
        }
    }

    #[cfg(feature = "cookie-signed")]
    #[tokio::test]
    async fn signed_cannot_access_invalid_cookies() {
        async fn get_cookie(jar: SignedCookieJar) -> impl IntoResponse {
            format!("{:?}", jar.get("key"))
        }

        let app = Router::new()
            .route("/get", get(get_cookie))
            .with_state(test_state());

        let res = app
            .clone()
            .oneshot(
                Request::builder()
                    .uri("/get")
                    .header("cookie", "key=value")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        let body = body_text(res).await;
        assert_eq!(body, "None");
    }

    /// Collect a response body into a `String`, panicking on I/O or UTF-8 errors.
    async fn body_text<B>(body: B) -> String
    where
        B: axum::body::HttpBody,
        B::Error: std::fmt::Debug,
    {
        let bytes = body.collect().await.unwrap().to_bytes();
        String::from_utf8(bytes.to_vec()).unwrap()
    }
}