Skip to main content

axum_security/cookie/
mod.rs

1mod builder;
2mod expiry;
3mod id;
4mod service;
5mod session;
6mod store;
7
8use std::{borrow::Cow, convert::Infallible, error::Error, sync::Arc};
9
10use axum::{
11    extract::{FromRef, FromRequestParts},
12    http::{HeaderMap, request::Parts},
13};
14pub use builder::CookieSessionBuilder;
15pub use id::SessionId;
16pub use session::CookieSession;
17pub use store::{CookieStore, MemStore};
18
19pub use cookie_monster::{Cookie, CookieBuilder, CookieJar, Expires, SameSite};
20use tokio::task::JoinHandle;
21
22use crate::{
23    cookie::store::{BoxDynError, ErasedStore},
24    utils::utc_now,
25};
26
27pub struct CookieContext<S>(Arc<CookieContextInner<S>>);
28
29struct CookieContextInner<S> {
30    store: ErasedStore<S>,
31    cookie_opts: CookieBuilder,
32    handle: Option<JoinHandle<()>>,
33}
34
35impl CookieContext<()> {
36    pub fn builder() -> CookieSessionBuilder<()> {
37        CookieSessionBuilder::new()
38    }
39}
40
41impl<S: 'static> CookieContext<S> {
42    pub fn get_cookie(&self, session_id: SessionId) -> Cookie {
43        self.0.cookie_opts.clone().value(session_id).build()
44    }
45
46    pub async fn create_session(
47        &self,
48        state: S,
49    ) -> Result<Cookie, Box<dyn Error + Send + 'static>> {
50        let session_id = SessionId::new();
51        tracing::debug!("Storing {session_id:?} in cookie store");
52        let now = utc_now().as_secs();
53        let session = CookieSession::new(session_id.clone(), now, state);
54        self.0.store.store_session(session).await?;
55
56        Ok(self.get_cookie(session_id))
57    }
58
59    pub async fn remove_session_jar(
60        &self,
61        jar: &CookieJar,
62    ) -> Result<Option<CookieSession<S>>, BoxDynError> {
63        let Some(session_id) = self.session_id_from_jar(jar) else {
64            return Ok(None);
65        };
66
67        self.0.store.remove_session(&session_id).await
68    }
69
70    pub async fn remove_session_cookie(
71        &self,
72        cookie: &Cookie,
73    ) -> Result<Option<CookieSession<S>>, BoxDynError> {
74        let session_id = SessionId::from_cookie(cookie);
75        self.remove_session(&session_id).await
76    }
77
78    pub async fn remove_session(
79        &self,
80        session_id: &SessionId,
81    ) -> Result<Option<CookieSession<S>>, BoxDynError> {
82        self.0.store.remove_session(session_id).await
83    }
84
85    pub fn build_cookie(&self, name: impl Into<Cow<'static, str>>) -> CookieBuilder {
86        self.0.cookie_opts.clone().name(name)
87    }
88
89    pub fn cookie_builder(&self) -> &CookieBuilder {
90        &self.0.cookie_opts
91    }
92
93    pub async fn remove_before(&self, deadline: u64) -> Result<(), BoxDynError> {
94        self.0.store.remove_before(deadline).await
95    }
96
97    pub(crate) async fn load_from_headers(
98        &self,
99        headers: &HeaderMap,
100    ) -> Result<Option<CookieSession<S>>, BoxDynError> {
101        let cookies = CookieJar::from_headers(headers);
102
103        self.load_from_jar(&cookies).await
104    }
105
106    pub(crate) async fn load_from_jar(
107        &self,
108        cookies: &CookieJar,
109    ) -> Result<Option<CookieSession<S>>, BoxDynError> {
110        let Some(session_id) = self.session_id_from_jar(cookies) else {
111            return Ok(None);
112        };
113
114        self.0.store.load_session(&session_id).await
115    }
116
117    pub(crate) fn session_id_from_jar(&self, jar: &CookieJar) -> Option<SessionId> {
118        let cookie = jar.get(self.0.cookie_opts.get_name())?;
119
120        Some(SessionId::from_cookie(cookie))
121    }
122
123    pub async fn load_from_cookie(
124        &self,
125        cookie: &Cookie,
126    ) -> Result<Option<CookieSession<S>>, BoxDynError> {
127        let session_id = SessionId::from_cookie(cookie);
128
129        self.0.store.load_session(&session_id).await
130    }
131}
132
133impl<S, U> FromRequestParts<S> for CookieContext<U>
134where
135    CookieContext<U>: FromRef<S>,
136    S: Send + Sync,
137{
138    type Rejection = Infallible;
139
140    async fn from_request_parts(_parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
141        Ok(Self::from_ref(state))
142    }
143}
144
145impl<S> Drop for CookieContextInner<S> {
146    fn drop(&mut self) {
147        // Make sure to cancel the bg task if the cookie context is dropped. This is only
148        // implemented for the Inner type because we don't to cancel the task if a weak reference
149        // is dropped.
150        if let Some(handle) = &self.handle {
151            handle.abort();
152        }
153    }
154}
155impl<S> Clone for CookieContext<S> {
156    fn clone(&self) -> Self {
157        CookieContext(self.0.clone())
158    }
159}