1use crate::middleware::{SessionAction, SessionManagerState};
2use crate::types::{SessionData, SessionId};
3use modo::axum::extract::FromRequestParts;
4use modo::axum::http::request::Parts;
5use modo::{Error, HttpError};
6use serde::Serialize;
7use serde::de::DeserializeOwned;
8use std::sync::Arc;
9
10pub struct SessionManager {
25 state: Arc<SessionManagerState>,
26}
27
28impl<S: Send + Sync> FromRequestParts<S> for SessionManager {
29 type Rejection = Error;
30
31 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
32 let state = parts
33 .extensions
34 .get::<Arc<SessionManagerState>>()
35 .cloned()
36 .ok_or_else(|| Error::internal("session manager requires session middleware"))?;
37
38 Ok(Self { state })
39 }
40}
41
42impl SessionManager {
43 pub async fn authenticate(&self, user_id: &str) -> Result<(), Error> {
49 self.authenticate_with(user_id, serde_json::json!({})).await
50 }
51
52 pub async fn authenticate_with(
58 &self,
59 user_id: &str,
60 data: serde_json::Value,
61 ) -> Result<(), Error> {
62 {
64 let current = self.state.current_session.lock().await;
65 if let Some(ref session) = *current {
66 self.state.store.destroy(&session.id).await.map_err(|e| {
67 tracing::error!(
68 session_id = session.id.as_str(),
69 "Failed to destroy previous session during authentication: {e}"
70 );
71 Error::internal(format!("failed to invalidate previous session: {e}"))
72 })?;
73 }
74 }
75
76 let (session_data, token) = self
77 .state
78 .store
79 .create(&self.state.meta, user_id, Some(data))
80 .await?;
81
82 *self.state.current_session.lock().await = Some(session_data);
83 *self.state.action.lock().await = SessionAction::Set(token);
84 Ok(())
85 }
86
87 pub async fn logout(&self) -> Result<(), Error> {
92 {
93 let current = self.state.current_session.lock().await;
94 if let Some(ref session) = *current {
95 self.state.store.destroy(&session.id).await?;
96 }
97 }
98 *self.state.action.lock().await = SessionAction::Remove;
99 *self.state.current_session.lock().await = None;
100 Ok(())
101 }
102
103 pub async fn logout_all(&self) -> Result<(), Error> {
108 {
109 let current = self.state.current_session.lock().await;
110 if let Some(ref session) = *current {
111 self.state
112 .store
113 .destroy_all_for_user(&session.user_id)
114 .await?;
115 }
116 }
117 *self.state.action.lock().await = SessionAction::Remove;
118 *self.state.current_session.lock().await = None;
119 Ok(())
120 }
121
122 pub async fn logout_other(&self) -> Result<(), Error> {
126 let current = self.state.current_session.lock().await;
127 let session = current
128 .as_ref()
129 .ok_or_else(|| Error::from(HttpError::Unauthorized))?;
130 self.state
131 .store
132 .destroy_all_except(&session.user_id, &session.id)
133 .await
134 }
135
136 pub async fn revoke(&self, id: &SessionId) -> Result<(), Error> {
143 let current = self.state.current_session.lock().await;
144 let session = current
145 .as_ref()
146 .ok_or_else(|| Error::from(HttpError::Unauthorized))?;
147
148 let target = self
149 .state
150 .store
151 .read(id)
152 .await?
153 .ok_or_else(|| Error::from(HttpError::NotFound))?;
154
155 if target.user_id != session.user_id {
156 return Err(Error::from(HttpError::NotFound));
157 }
158
159 self.state.store.destroy(id).await
160 }
161
162 pub async fn rotate(&self) -> Result<(), Error> {
167 let session_id = {
168 let current = self.state.current_session.lock().await;
169 let session = current
170 .as_ref()
171 .ok_or_else(|| Error::from(HttpError::Unauthorized))?;
172 session.id.clone()
173 };
174
175 let new_token = self.state.store.rotate_token(&session_id).await?;
176
177 {
178 let mut current = self.state.current_session.lock().await;
179 if let Some(ref mut s) = *current {
180 s.token_hash = new_token.hash();
181 }
182 }
183
184 *self.state.action.lock().await = SessionAction::Set(new_token);
185 Ok(())
186 }
187
188 pub async fn current(&self) -> Option<SessionData> {
191 self.state.current_session.lock().await.clone()
192 }
193
194 pub async fn user_id(&self) -> Option<String> {
197 self.state
198 .current_session
199 .lock()
200 .await
201 .as_ref()
202 .map(|s| s.user_id.clone())
203 }
204
205 pub async fn is_authenticated(&self) -> bool {
208 self.state.current_session.lock().await.is_some()
209 }
210
211 pub async fn list_my_sessions(&self) -> Result<Vec<SessionData>, Error> {
216 let user_id = {
217 let current = self.state.current_session.lock().await;
218 let session = current
219 .as_ref()
220 .ok_or_else(|| Error::from(HttpError::Unauthorized))?;
221 session.user_id.clone()
222 };
223 self.state.store.list_for_user(&user_id).await
224 }
225
226 pub async fn get<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>, Error> {
232 let current = self.state.current_session.lock().await;
233 let session = match current.as_ref() {
234 Some(s) => s,
235 None => return Ok(None),
236 };
237 match session.data.get(key) {
238 Some(v) => match serde_json::from_value(v.clone()) {
239 Ok(val) => Ok(Some(val)),
240 Err(e) => {
241 tracing::warn!(key, error = %e, "Failed to deserialize session data key");
242 Ok(None)
243 }
244 },
245 None => Ok(None),
246 }
247 }
248
249 pub async fn set<T: Serialize>(&self, key: &str, value: &T) -> Result<(), Error> {
253 let mut current = self.state.current_session.lock().await;
254 let session = current
255 .as_mut()
256 .ok_or_else(|| Error::from(HttpError::Unauthorized))?;
257
258 if !session.data.is_object() {
259 session.data = serde_json::Value::Object(Default::default());
260 }
261 if let serde_json::Value::Object(ref mut map) = session.data {
262 map.insert(
263 key.to_string(),
264 serde_json::to_value(value)
265 .map_err(|e| Error::internal(format!("serialize session value: {e}")))?,
266 );
267 }
268 self.state
269 .store
270 .update_data(&session.id, session.data.clone())
271 .await
272 }
273
274 pub async fn remove_key(&self, key: &str) -> Result<(), Error> {
279 let mut current = self.state.current_session.lock().await;
280 let session = current
281 .as_mut()
282 .ok_or_else(|| Error::from(HttpError::Unauthorized))?;
283
284 if let serde_json::Value::Object(ref mut map) = session.data {
285 map.remove(key);
286 }
287 self.state
288 .store
289 .update_data(&session.id, session.data.clone())
290 .await
291 }
292}