modo_session/
middleware.rs1use crate::meta::{SessionMeta, extract_client_ip, header_str};
16use crate::store::SessionStore;
17use crate::types::{SessionData, SessionToken};
18use chrono::Utc;
19use futures_util::future::BoxFuture;
20use http::{Request, Response};
21use modo::axum::extract::connect_info::ConnectInfo;
22use modo::cookies::{CookieOptions, build_cookie};
23use std::net::SocketAddr;
24use std::sync::Arc;
25use std::task::{Context, Poll};
26use tokio::sync::Mutex;
27use tower::{Layer, Service};
28
29#[derive(Clone)]
32pub(crate) enum SessionAction {
33 None,
34 Set(SessionToken),
35 Remove,
36}
37
38pub(crate) struct SessionManagerState {
39 pub store: SessionStore,
40 pub current_session: Mutex<Option<SessionData>>,
41 pub meta: SessionMeta,
42 pub action: Mutex<SessionAction>,
43}
44
45#[derive(Clone)]
48pub struct SessionLayer {
49 store: Arc<SessionStore>,
50}
51
52impl SessionLayer {
53 fn new(store: SessionStore) -> Self {
54 Self {
55 store: Arc::new(store),
56 }
57 }
58}
59
60impl<S> Layer<S> for SessionLayer {
61 type Service = SessionMiddleware<S>;
62
63 fn layer(&self, inner: S) -> Self::Service {
64 SessionMiddleware {
65 inner,
66 store: self.store.clone(),
67 }
68 }
69}
70
71pub fn layer(store: SessionStore) -> SessionLayer {
73 SessionLayer::new(store)
74}
75
76#[derive(Clone)]
79pub struct SessionMiddleware<S> {
80 inner: S,
81 store: Arc<SessionStore>,
82}
83
84impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for SessionMiddleware<S>
85where
86 S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone + Send + 'static,
87 S::Future: Send + 'static,
88 ReqBody: Send + 'static,
89 ResBody: Default + Send + 'static,
90{
91 type Response = Response<ResBody>;
92 type Error = S::Error;
93 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
94
95 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
96 self.inner.poll_ready(cx)
97 }
98
99 fn call(&mut self, mut request: Request<ReqBody>) -> Self::Future {
100 let store = self.store.clone();
101 let mut inner = self.inner.clone();
102
103 Box::pin(async move {
104 let config = store.config();
105 let cookie_name = &config.cookie_name;
106
107 let connect_ip = request
109 .extensions()
110 .get::<ConnectInfo<SocketAddr>>()
111 .map(|ci| ci.0.ip());
112 let headers = request.headers();
113 let ip = extract_client_ip(headers, &config.trusted_proxies, connect_ip);
114 let ua = header_str(headers, "user-agent");
115 let accept_lang = header_str(headers, "accept-language");
116 let accept_enc = header_str(headers, "accept-encoding");
117 let meta = SessionMeta::from_headers(ip, ua, accept_lang, accept_enc);
118
119 let session_token = read_session_cookie(headers, cookie_name);
121 let had_cookie = session_token.is_some();
122
123 let (current_session, read_failed) = if let Some(ref token) = session_token {
125 match store.read_by_token(token).await {
126 Ok(session) => (session, false),
127 Err(e) => {
128 tracing::error!("Failed to read session: {e}");
129 (None, true)
130 }
131 }
132 } else {
133 (None, false)
134 };
135
136 let current_session = if let Some(session) = current_session {
138 if config.validate_fingerprint && meta.fingerprint != session.fingerprint {
139 tracing::warn!(
140 session_id = session.id.as_str(),
141 user_id = session.user_id,
142 "Session fingerprint mismatch — possible hijack, destroying session"
143 );
144 let _ = store.destroy(&session.id).await;
145 None
146 } else {
147 Some(session)
148 }
149 } else {
150 None
151 };
152
153 let should_touch = current_session.as_ref().is_some_and(|s| {
155 let elapsed = Utc::now() - s.last_active_at;
156 elapsed >= chrono::Duration::seconds(config.touch_interval_secs as i64)
157 });
158
159 let manager_state = Arc::new(SessionManagerState {
161 store: (*store).clone(),
162 current_session: Mutex::new(current_session.clone()),
163 meta,
164 action: Mutex::new(SessionAction::None),
165 });
166
167 request.extensions_mut().insert(manager_state.clone());
168
169 let mut response = inner.call(request).await?;
171
172 let action = {
174 let guard = manager_state.action.lock().await;
175 guard.clone()
176 };
177
178 let ttl_secs = config.session_ttl_secs;
179
180 match action {
181 SessionAction::Set(token) => {
182 let opts = CookieOptions::from_config(store.cookie_config()).max_age(ttl_secs);
183 append_cookie_header(&mut response, cookie_name, &token.as_hex(), &opts);
184 }
185 SessionAction::Remove => {
186 let opts = CookieOptions::from_config(store.cookie_config()).max_age(0);
188 append_cookie_header(&mut response, cookie_name, "", &opts);
189 }
190 SessionAction::None => {
191 if should_touch && let Some(ref session) = current_session {
192 let new_expires = Utc::now() + chrono::Duration::seconds(ttl_secs as i64);
193 if let Err(e) = store.touch(&session.id, new_expires).await {
194 tracing::error!(
195 session_id = session.id.as_str(),
196 "Failed to touch session: {e}"
197 );
198 } else if let Some(ref token) = session_token {
199 let opts =
200 CookieOptions::from_config(store.cookie_config()).max_age(ttl_secs);
201 append_cookie_header(
202 &mut response,
203 cookie_name,
204 &token.as_hex(),
205 &opts,
206 );
207 }
208 }
209
210 if had_cookie && current_session.is_none() && !read_failed {
212 let opts = CookieOptions::from_config(store.cookie_config()).max_age(0);
214 append_cookie_header(&mut response, cookie_name, "", &opts);
215 }
216 }
217 }
218
219 Ok(response)
220 })
221 }
222}
223
224pub fn user_id_from_extensions(extensions: &http::Extensions) -> Option<String> {
231 extensions
232 .get::<Arc<SessionManagerState>>()
233 .and_then(|state| match state.current_session.try_lock() {
234 Ok(guard) => guard.as_ref().map(|s| s.user_id.clone()),
235 Err(_) => {
236 tracing::trace!("user_id_from_extensions: session lock contended, returning None");
237 None
238 }
239 })
240}
241
242fn read_session_cookie(headers: &http::HeaderMap, cookie_name: &str) -> Option<SessionToken> {
245 headers
246 .get_all(http::header::COOKIE)
247 .iter()
248 .find_map(|val| {
249 let val = val.to_str().ok()?;
250 for pair in val.split(';') {
251 let pair = pair.trim();
252 if let Some(value) = pair.strip_prefix(cookie_name) {
253 let value = value.strip_prefix('=')?;
254 return SessionToken::from_hex(value).ok();
255 }
256 }
257 None
258 })
259}
260
261fn append_cookie_header<B>(
262 response: &mut Response<B>,
263 name: &str,
264 value: &str,
265 opts: &CookieOptions,
266) {
267 let cookie = build_cookie(name, value, opts);
268 match http::HeaderValue::try_from(cookie.to_string()) {
269 Ok(val) => {
270 response.headers_mut().append(http::header::SET_COOKIE, val);
271 }
272 Err(e) => {
273 tracing::warn!(
274 cookie_name = name,
275 "Failed to serialize session cookie: {e}"
276 );
277 }
278 }
279}