Skip to main content

modo_session/
manager.rs

1use crate::middleware::{SessionAction, SessionManagerState};
2use crate::types::{SessionData, SessionId};
3use modo::axum::extract::FromRequestParts;
4use modo::axum::http::request::Parts;
5use modo::{Error, HttpError};
6use serde::Serialize;
7use serde::de::DeserializeOwned;
8use std::sync::Arc;
9
10/// Request-scoped session manager, available as an axum extractor.
11///
12/// Inject `SessionManager` as a handler parameter to read or modify the session
13/// for the current request.  The session middleware must be installed via
14/// [`crate::layer`] for the extractor to work; if the middleware is missing the
15/// extractor returns an internal error.
16///
17/// Changes made through `SessionManager` (authentication, logout, token
18/// rotation, data writes) are applied to the HTTP response cookie automatically
19/// by the middleware after the handler returns.
20pub struct SessionManager {
21    state: Arc<SessionManagerState>,
22}
23
24impl<S: Send + Sync> FromRequestParts<S> for SessionManager {
25    type Rejection = Error;
26
27    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
28        let state = parts
29            .extensions
30            .get::<Arc<SessionManagerState>>()
31            .cloned()
32            .ok_or_else(|| Error::internal("SessionManager requires session middleware"))?;
33
34        Ok(Self { state })
35    }
36}
37
38impl SessionManager {
39    /// Create a session for the given user.
40    /// Destroys any existing session first (fixation prevention).
41    pub async fn authenticate(&self, user_id: &str) -> Result<(), Error> {
42        self.authenticate_with(user_id, serde_json::json!({})).await
43    }
44
45    /// Create a session with custom data attached.
46    ///
47    /// Any existing session is destroyed before the new one is created to
48    /// prevent session-fixation attacks.
49    pub async fn authenticate_with(
50        &self,
51        user_id: &str,
52        data: serde_json::Value,
53    ) -> Result<(), Error> {
54        // Destroy current session (fixation prevention)
55        {
56            let current = self.state.current_session.lock().await;
57            if let Some(ref session) = *current {
58                self.state.store.destroy(&session.id).await.map_err(|e| {
59                    tracing::error!(
60                        session_id = session.id.as_str(),
61                        "Failed to destroy previous session during authentication: {e}"
62                    );
63                    Error::internal(format!("failed to invalidate previous session: {e}"))
64                })?;
65            }
66        }
67
68        let (session_data, token) = self
69            .state
70            .store
71            .create(&self.state.meta, user_id, Some(data))
72            .await?;
73
74        *self.state.current_session.lock().await = Some(session_data);
75        *self.state.action.lock().await = SessionAction::Set(token);
76        Ok(())
77    }
78
79    /// Destroy the current session. Cookie is removed automatically.
80    pub async fn logout(&self) -> Result<(), Error> {
81        {
82            let current = self.state.current_session.lock().await;
83            if let Some(ref session) = *current {
84                self.state.store.destroy(&session.id).await?;
85            }
86        }
87        *self.state.action.lock().await = SessionAction::Remove;
88        *self.state.current_session.lock().await = None;
89        Ok(())
90    }
91
92    /// Destroy ALL sessions for the current user.
93    pub async fn logout_all(&self) -> Result<(), Error> {
94        {
95            let current = self.state.current_session.lock().await;
96            if let Some(ref session) = *current {
97                self.state
98                    .store
99                    .destroy_all_for_user(&session.user_id)
100                    .await?;
101            }
102        }
103        *self.state.action.lock().await = SessionAction::Remove;
104        *self.state.current_session.lock().await = None;
105        Ok(())
106    }
107
108    /// Destroy all sessions for the current user except the current one.
109    pub async fn logout_other(&self) -> Result<(), Error> {
110        let current = self.state.current_session.lock().await;
111        let session = current
112            .as_ref()
113            .ok_or_else(|| Error::from(HttpError::Unauthorized))?;
114        self.state
115            .store
116            .destroy_all_except(&session.user_id, &session.id)
117            .await
118    }
119
120    /// Destroy a specific session by ID (for "manage my devices" UI).
121    /// Only works on sessions owned by the current user.
122    pub async fn revoke(&self, id: &SessionId) -> Result<(), Error> {
123        let current = self.state.current_session.lock().await;
124        let session = current
125            .as_ref()
126            .ok_or_else(|| Error::from(HttpError::Unauthorized))?;
127
128        let target = self
129            .state
130            .store
131            .read(id)
132            .await?
133            .ok_or_else(|| Error::from(HttpError::NotFound))?;
134
135        if target.user_id != session.user_id {
136            return Err(Error::from(HttpError::NotFound));
137        }
138
139        self.state.store.destroy(id).await
140    }
141
142    /// Regenerate the session token without changing the session ID.
143    pub async fn rotate(&self) -> Result<(), Error> {
144        let session_id = {
145            let current = self.state.current_session.lock().await;
146            let session = current
147                .as_ref()
148                .ok_or_else(|| Error::from(HttpError::Unauthorized))?;
149            session.id.clone()
150        };
151
152        let new_token = self.state.store.rotate_token(&session_id).await?;
153
154        {
155            let mut current = self.state.current_session.lock().await;
156            if let Some(ref mut s) = *current {
157                s.token_hash = new_token.hash();
158            }
159        }
160
161        *self.state.action.lock().await = SessionAction::Set(new_token);
162        Ok(())
163    }
164
165    /// Access the current session data (if authenticated).
166    pub async fn current(&self) -> Option<SessionData> {
167        self.state.current_session.lock().await.clone()
168    }
169
170    /// Get the current user ID.
171    pub async fn user_id(&self) -> Option<String> {
172        self.state
173            .current_session
174            .lock()
175            .await
176            .as_ref()
177            .map(|s| s.user_id.clone())
178    }
179
180    /// Check if a session is active.
181    pub async fn is_authenticated(&self) -> bool {
182        self.state.current_session.lock().await.is_some()
183    }
184
185    /// List all active sessions for the authenticated user.
186    pub async fn list_my_sessions(&self) -> Result<Vec<SessionData>, Error> {
187        let user_id = {
188            let current = self.state.current_session.lock().await;
189            let session = current
190                .as_ref()
191                .ok_or_else(|| Error::from(HttpError::Unauthorized))?;
192            session.user_id.clone()
193        };
194        self.state.store.list_for_user(&user_id).await
195    }
196
197    /// Get a typed value from the session data by key.
198    pub async fn get<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>, Error> {
199        let current = self.state.current_session.lock().await;
200        let session = match current.as_ref() {
201            Some(s) => s,
202            None => return Ok(None),
203        };
204        match session.data.get(key) {
205            Some(v) => match serde_json::from_value(v.clone()) {
206                Ok(val) => Ok(Some(val)),
207                Err(e) => {
208                    tracing::warn!(key, error = %e, "Failed to deserialize session data key");
209                    Ok(None)
210                }
211            },
212            None => Ok(None),
213        }
214    }
215
216    /// Set a single key in the session data (immediate DB write).
217    pub async fn set<T: Serialize>(&self, key: &str, value: &T) -> Result<(), Error> {
218        let mut current = self.state.current_session.lock().await;
219        let session = current
220            .as_mut()
221            .ok_or_else(|| Error::from(HttpError::Unauthorized))?;
222
223        if !session.data.is_object() {
224            session.data = serde_json::Value::Object(Default::default());
225        }
226        if let serde_json::Value::Object(ref mut map) = session.data {
227            map.insert(
228                key.to_string(),
229                serde_json::to_value(value)
230                    .map_err(|e| Error::internal(format!("serialize session value: {e}")))?,
231            );
232        }
233        self.state
234            .store
235            .update_data(&session.id, session.data.clone())
236            .await
237    }
238
239    /// Remove a key from the session data (immediate DB write).
240    pub async fn remove_key(&self, key: &str) -> Result<(), Error> {
241        let mut current = self.state.current_session.lock().await;
242        let session = current
243            .as_mut()
244            .ok_or_else(|| Error::from(HttpError::Unauthorized))?;
245
246        if let serde_json::Value::Object(ref mut map) = session.data {
247            map.remove(key);
248        }
249        self.state
250            .store
251            .update_data(&session.id, session.data.clone())
252            .await
253    }
254}