1use crate::middleware::{SessionAction, SessionManagerState};
2use crate::types::{SessionData, SessionId};
3use modo::axum::extract::FromRequestParts;
4use modo::axum::http::request::Parts;
5use modo::{Error, HttpError};
6use serde::Serialize;
7use serde::de::DeserializeOwned;
8use std::sync::Arc;
9
/// Handle for inspecting and changing the session attached to a request.
///
/// Obtained through its `FromRequestParts` implementation, which clones the
/// shared `Arc<SessionManagerState>` out of the request extensions.
pub struct SessionManager {
    /// Shared session state: the store, the current session, and the pending
    /// `SessionAction` (presumably applied by the session middleware when the
    /// response is built — confirm against the middleware).
    state: Arc<SessionManagerState>,
}
23
24impl<S: Send + Sync> FromRequestParts<S> for SessionManager {
25 type Rejection = Error;
26
27 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
28 let state = parts
29 .extensions
30 .get::<Arc<SessionManagerState>>()
31 .cloned()
32 .ok_or_else(|| Error::internal("SessionManager requires session middleware"))?;
33
34 Ok(Self { state })
35 }
36}
37
38impl SessionManager {
39 pub async fn authenticate(&self, user_id: &str) -> Result<(), Error> {
42 self.authenticate_with(user_id, serde_json::json!({})).await
43 }
44
45 pub async fn authenticate_with(
50 &self,
51 user_id: &str,
52 data: serde_json::Value,
53 ) -> Result<(), Error> {
54 {
56 let current = self.state.current_session.lock().await;
57 if let Some(ref session) = *current {
58 self.state.store.destroy(&session.id).await.map_err(|e| {
59 tracing::error!(
60 session_id = session.id.as_str(),
61 "Failed to destroy previous session during authentication: {e}"
62 );
63 Error::internal(format!("failed to invalidate previous session: {e}"))
64 })?;
65 }
66 }
67
68 let (session_data, token) = self
69 .state
70 .store
71 .create(&self.state.meta, user_id, Some(data))
72 .await?;
73
74 *self.state.current_session.lock().await = Some(session_data);
75 *self.state.action.lock().await = SessionAction::Set(token);
76 Ok(())
77 }
78
79 pub async fn logout(&self) -> Result<(), Error> {
81 {
82 let current = self.state.current_session.lock().await;
83 if let Some(ref session) = *current {
84 self.state.store.destroy(&session.id).await?;
85 }
86 }
87 *self.state.action.lock().await = SessionAction::Remove;
88 *self.state.current_session.lock().await = None;
89 Ok(())
90 }
91
92 pub async fn logout_all(&self) -> Result<(), Error> {
94 {
95 let current = self.state.current_session.lock().await;
96 if let Some(ref session) = *current {
97 self.state
98 .store
99 .destroy_all_for_user(&session.user_id)
100 .await?;
101 }
102 }
103 *self.state.action.lock().await = SessionAction::Remove;
104 *self.state.current_session.lock().await = None;
105 Ok(())
106 }
107
108 pub async fn logout_other(&self) -> Result<(), Error> {
110 let current = self.state.current_session.lock().await;
111 let session = current
112 .as_ref()
113 .ok_or_else(|| Error::from(HttpError::Unauthorized))?;
114 self.state
115 .store
116 .destroy_all_except(&session.user_id, &session.id)
117 .await
118 }
119
120 pub async fn revoke(&self, id: &SessionId) -> Result<(), Error> {
123 let current = self.state.current_session.lock().await;
124 let session = current
125 .as_ref()
126 .ok_or_else(|| Error::from(HttpError::Unauthorized))?;
127
128 let target = self
129 .state
130 .store
131 .read(id)
132 .await?
133 .ok_or_else(|| Error::from(HttpError::NotFound))?;
134
135 if target.user_id != session.user_id {
136 return Err(Error::from(HttpError::NotFound));
137 }
138
139 self.state.store.destroy(id).await
140 }
141
142 pub async fn rotate(&self) -> Result<(), Error> {
144 let session_id = {
145 let current = self.state.current_session.lock().await;
146 let session = current
147 .as_ref()
148 .ok_or_else(|| Error::from(HttpError::Unauthorized))?;
149 session.id.clone()
150 };
151
152 let new_token = self.state.store.rotate_token(&session_id).await?;
153
154 {
155 let mut current = self.state.current_session.lock().await;
156 if let Some(ref mut s) = *current {
157 s.token_hash = new_token.hash();
158 }
159 }
160
161 *self.state.action.lock().await = SessionAction::Set(new_token);
162 Ok(())
163 }
164
165 pub async fn current(&self) -> Option<SessionData> {
167 self.state.current_session.lock().await.clone()
168 }
169
170 pub async fn user_id(&self) -> Option<String> {
172 self.state
173 .current_session
174 .lock()
175 .await
176 .as_ref()
177 .map(|s| s.user_id.clone())
178 }
179
180 pub async fn is_authenticated(&self) -> bool {
182 self.state.current_session.lock().await.is_some()
183 }
184
185 pub async fn list_my_sessions(&self) -> Result<Vec<SessionData>, Error> {
187 let user_id = {
188 let current = self.state.current_session.lock().await;
189 let session = current
190 .as_ref()
191 .ok_or_else(|| Error::from(HttpError::Unauthorized))?;
192 session.user_id.clone()
193 };
194 self.state.store.list_for_user(&user_id).await
195 }
196
197 pub async fn get<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>, Error> {
199 let current = self.state.current_session.lock().await;
200 let session = match current.as_ref() {
201 Some(s) => s,
202 None => return Ok(None),
203 };
204 match session.data.get(key) {
205 Some(v) => match serde_json::from_value(v.clone()) {
206 Ok(val) => Ok(Some(val)),
207 Err(e) => {
208 tracing::warn!(key, error = %e, "Failed to deserialize session data key");
209 Ok(None)
210 }
211 },
212 None => Ok(None),
213 }
214 }
215
216 pub async fn set<T: Serialize>(&self, key: &str, value: &T) -> Result<(), Error> {
218 let mut current = self.state.current_session.lock().await;
219 let session = current
220 .as_mut()
221 .ok_or_else(|| Error::from(HttpError::Unauthorized))?;
222
223 if !session.data.is_object() {
224 session.data = serde_json::Value::Object(Default::default());
225 }
226 if let serde_json::Value::Object(ref mut map) = session.data {
227 map.insert(
228 key.to_string(),
229 serde_json::to_value(value)
230 .map_err(|e| Error::internal(format!("serialize session value: {e}")))?,
231 );
232 }
233 self.state
234 .store
235 .update_data(&session.id, session.data.clone())
236 .await
237 }
238
239 pub async fn remove_key(&self, key: &str) -> Result<(), Error> {
241 let mut current = self.state.current_session.lock().await;
242 let session = current
243 .as_mut()
244 .ok_or_else(|| Error::from(HttpError::Unauthorized))?;
245
246 if let serde_json::Value::Object(ref mut map) = session.data {
247 map.remove(key);
248 }
249 self.state
250 .store
251 .update_data(&session.id, session.data.clone())
252 .await
253 }
254}