modo_session/
middleware.rs1use crate::meta::{SessionMeta, extract_client_ip, header_str};
16use crate::store::SessionStore;
17use crate::types::{SessionData, SessionToken};
18use chrono::Utc;
19use futures_util::future::BoxFuture;
20use http::{Request, Response};
21use modo::axum::extract::connect_info::ConnectInfo;
22use modo::cookies::{CookieOptions, build_cookie};
23use std::net::SocketAddr;
24use std::sync::Arc;
25use std::task::{Context, Poll};
26use tokio::sync::Mutex;
27use tower::{Layer, Service};
28
29#[derive(Clone)]
32pub(crate) enum SessionAction {
33 None,
34 Set(SessionToken),
35 Remove,
36}
37
38pub(crate) struct SessionManagerState {
39 pub store: SessionStore,
40 pub current_session: Mutex<Option<SessionData>>,
41 pub meta: SessionMeta,
42 pub action: Mutex<SessionAction>,
43}
44
45#[derive(Clone)]
51pub struct SessionContextLayer {
52 store: Arc<SessionStore>,
53}
54
55impl SessionContextLayer {
56 fn new(store: SessionStore) -> Self {
57 Self {
58 store: Arc::new(store),
59 }
60 }
61}
62
63impl<S> Layer<S> for SessionContextLayer {
64 type Service = SessionMiddleware<S>;
65
66 fn layer(&self, inner: S) -> Self::Service {
67 SessionMiddleware {
68 inner,
69 store: self.store.clone(),
70 }
71 }
72}
73
74pub fn layer(store: SessionStore) -> SessionContextLayer {
80 SessionContextLayer::new(store)
81}
82
83#[derive(Clone)]
90pub struct SessionMiddleware<S> {
91 inner: S,
92 store: Arc<SessionStore>,
93}
94
95impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for SessionMiddleware<S>
96where
97 S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone + Send + 'static,
98 S::Future: Send + 'static,
99 ReqBody: Send + 'static,
100 ResBody: Default + Send + 'static,
101{
102 type Response = Response<ResBody>;
103 type Error = S::Error;
104 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
105
106 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
107 self.inner.poll_ready(cx)
108 }
109
110 fn call(&mut self, mut request: Request<ReqBody>) -> Self::Future {
111 let store = self.store.clone();
112 let mut inner = self.inner.clone();
113
114 Box::pin(async move {
115 let config = store.config();
116 let cookie_name = &config.cookie_name;
117
118 let connect_ip = request
120 .extensions()
121 .get::<ConnectInfo<SocketAddr>>()
122 .map(|ci| ci.0.ip());
123 let headers = request.headers();
124 let ip = extract_client_ip(headers, &config.trusted_proxies, connect_ip);
125 let ua = header_str(headers, "user-agent");
126 let accept_lang = header_str(headers, "accept-language");
127 let accept_enc = header_str(headers, "accept-encoding");
128 let meta = SessionMeta::from_headers(ip, ua, accept_lang, accept_enc);
129
130 let session_token = read_session_cookie(headers, cookie_name);
132 let had_cookie = session_token.is_some();
133
134 let (current_session, read_failed) = if let Some(ref token) = session_token {
136 match store.read_by_token(token).await {
137 Ok(session) => (session, false),
138 Err(e) => {
139 tracing::error!("Failed to read session: {e}");
140 (None, true)
141 }
142 }
143 } else {
144 (None, false)
145 };
146
147 let current_session = if let Some(session) = current_session {
149 if config.validate_fingerprint && meta.fingerprint != session.fingerprint {
150 tracing::warn!(
151 session_id = session.id.as_str(),
152 user_id = session.user_id,
153 "Session fingerprint mismatch — possible hijack, destroying session"
154 );
155 let _ = store.destroy(&session.id).await;
156 None
157 } else {
158 Some(session)
159 }
160 } else {
161 None
162 };
163
164 let should_touch = current_session.as_ref().is_some_and(|s| {
166 let elapsed = Utc::now() - s.last_active_at;
167 elapsed >= chrono::Duration::seconds(config.touch_interval_secs as i64)
168 });
169
170 let manager_state = Arc::new(SessionManagerState {
172 store: (*store).clone(),
173 current_session: Mutex::new(current_session.clone()),
174 meta,
175 action: Mutex::new(SessionAction::None),
176 });
177
178 request.extensions_mut().insert(manager_state.clone());
179
180 let mut response = inner.call(request).await?;
182
183 let action = {
185 let guard = manager_state.action.lock().await;
186 guard.clone()
187 };
188
189 let ttl_secs = config.session_ttl_secs;
190
191 match action {
192 SessionAction::Set(token) => {
193 let opts = CookieOptions::from_config(store.cookie_config()).max_age(ttl_secs);
194 append_cookie_header(&mut response, cookie_name, &token.as_hex(), &opts);
195 }
196 SessionAction::Remove => {
197 let opts = CookieOptions::from_config(store.cookie_config()).max_age(0);
199 append_cookie_header(&mut response, cookie_name, "", &opts);
200 }
201 SessionAction::None => {
202 if should_touch && let Some(ref session) = current_session {
203 let new_expires = Utc::now() + chrono::Duration::seconds(ttl_secs as i64);
204 if let Err(e) = store.touch(&session.id, new_expires).await {
205 tracing::error!(
206 session_id = session.id.as_str(),
207 "Failed to touch session: {e}"
208 );
209 } else if let Some(ref token) = session_token {
210 let opts =
211 CookieOptions::from_config(store.cookie_config()).max_age(ttl_secs);
212 append_cookie_header(
213 &mut response,
214 cookie_name,
215 &token.as_hex(),
216 &opts,
217 );
218 }
219 }
220
221 if had_cookie && current_session.is_none() && !read_failed {
223 let opts = CookieOptions::from_config(store.cookie_config()).max_age(0);
225 append_cookie_header(&mut response, cookie_name, "", &opts);
226 }
227 }
228 }
229
230 Ok(response)
231 })
232 }
233}
234
235pub fn user_id_from_extensions(extensions: &http::Extensions) -> Option<String> {
244 extensions
245 .get::<Arc<SessionManagerState>>()
246 .and_then(|state| match state.current_session.try_lock() {
247 Ok(guard) => guard.as_ref().map(|s| s.user_id.clone()),
248 Err(_) => {
249 tracing::trace!("user_id_from_extensions: session lock contended, returning None");
250 None
251 }
252 })
253}
254
255fn read_session_cookie(headers: &http::HeaderMap, cookie_name: &str) -> Option<SessionToken> {
258 headers
259 .get_all(http::header::COOKIE)
260 .iter()
261 .find_map(|val| {
262 let val = val.to_str().ok()?;
263 for pair in val.split(';') {
264 let pair = pair.trim();
265 if let Some(value) = pair.strip_prefix(cookie_name) {
266 let value = value.strip_prefix('=')?;
267 return SessionToken::from_hex(value).ok();
268 }
269 }
270 None
271 })
272}
273
274fn append_cookie_header<B>(
275 response: &mut Response<B>,
276 name: &str,
277 value: &str,
278 opts: &CookieOptions,
279) {
280 let cookie = build_cookie(name, value, opts);
281 match http::HeaderValue::try_from(cookie.to_string()) {
282 Ok(val) => {
283 response.headers_mut().append(http::header::SET_COOKIE, val);
284 }
285 Err(e) => {
286 tracing::warn!(
287 cookie_name = name,
288 "Failed to serialize session cookie: {e}"
289 );
290 }
291 }
292}