Skip to main content

ferro_rs/session/
middleware.rs

1//! Session middleware for Ferro framework
2
3use crate::http::cookie::{Cookie, SameSite};
4use crate::http::Response;
5use crate::middleware::{Middleware, Next};
6use crate::Request;
7use async_trait::async_trait;
8use rand::Rng;
9use std::sync::Arc;
10use tokio::sync::RwLock;
11
12use super::config::SessionConfig;
13use super::driver::DatabaseSessionDriver;
14use super::store::{SessionData, SessionStore};
15
16// Task-local session context using tokio's task_local macro
17// This is async-safe unlike thread_local which can lose data across await points
18tokio::task_local! {
19    static SESSION_CONTEXT: Arc<RwLock<Option<SessionData>>>;
20}
21
22/// Get the current session (read-only)
23///
24/// Returns a clone of the current session data if available.
25///
26/// # Example
27///
28/// ```rust,ignore
29/// use ferro_rs::session::session;
30///
31/// if let Some(session) = session() {
32///     let name: Option<String> = session.get("name");
33/// }
34/// ```
35pub fn session() -> Option<SessionData> {
36    SESSION_CONTEXT
37        .try_with(|ctx| {
38            // Use try_read to avoid blocking - if locked, return None
39            ctx.try_read().ok().and_then(|guard| guard.clone())
40        })
41        .ok()
42        .flatten()
43}
44
45/// Get the current session and modify it
46///
47/// # Example
48///
49/// ```rust,ignore
50/// use ferro_rs::session::session_mut;
51///
52/// session_mut(|session| {
53///     session.put("name", "John");
54/// });
55/// ```
56pub fn session_mut<F, R>(f: F) -> Option<R>
57where
58    F: FnOnce(&mut SessionData) -> R,
59{
60    SESSION_CONTEXT
61        .try_with(|ctx| {
62            // Use try_write to avoid blocking
63            ctx.try_write()
64                .ok()
65                .and_then(|mut guard| guard.as_mut().map(f))
66        })
67        .ok()
68        .flatten()
69}
70
71/// Take the session out of the context (for saving)
72fn take_session_internal(ctx: &Arc<RwLock<Option<SessionData>>>) -> Option<SessionData> {
73    ctx.try_write().ok().and_then(|mut guard| guard.take())
74}
75
76/// Generate a cryptographically secure session ID
77///
78/// Generates a 40-character alphanumeric string.
79pub fn generate_session_id() -> String {
80    let mut rng = rand::thread_rng();
81    const CHARSET: &[u8] = b"abcdefghijklmnopqrstuvwxyz0123456789";
82
83    (0..40)
84        .map(|_| {
85            let idx = rng.gen_range(0..CHARSET.len());
86            CHARSET[idx] as char
87        })
88        .collect()
89}
90
91/// Generate a CSRF token
92///
93/// Same format as session ID for consistency.
94pub fn generate_csrf_token() -> String {
95    generate_session_id()
96}
97
98/// Session middleware
99///
100/// Handles session lifecycle:
101/// 1. Reads session ID from cookie
102/// 2. Loads session data from storage
103/// 3. Makes session available during request
104/// 4. Saves session after request
105/// 5. Sets session cookie on response
106pub struct SessionMiddleware {
107    config: SessionConfig,
108    store: Arc<dyn SessionStore>,
109}
110
111impl SessionMiddleware {
112    /// Create a new session middleware with the given configuration
113    pub fn new(config: SessionConfig) -> Self {
114        let store = Arc::new(DatabaseSessionDriver::new(
115            config.idle_lifetime,
116            config.absolute_lifetime,
117        ));
118        Self { config, store }
119    }
120
121    /// Create session middleware with a custom store
122    pub fn with_store(config: SessionConfig, store: Arc<dyn SessionStore>) -> Self {
123        Self { config, store }
124    }
125
126    fn create_session_cookie(&self, session_id: &str) -> Cookie {
127        let mut cookie = Cookie::new(&self.config.cookie_name, session_id)
128            .http_only(self.config.cookie_http_only)
129            .secure(self.config.cookie_secure)
130            .path(&self.config.cookie_path)
131            .max_age(std::cmp::max(
132                self.config.idle_lifetime,
133                self.config.absolute_lifetime,
134            ));
135
136        cookie = match self.config.cookie_same_site.to_lowercase().as_str() {
137            "strict" => cookie.same_site(SameSite::Strict),
138            "none" => cookie.same_site(SameSite::None),
139            _ => cookie.same_site(SameSite::Lax),
140        };
141
142        cookie
143    }
144}
145
146#[async_trait]
147impl Middleware for SessionMiddleware {
148    async fn handle(&self, request: Request, next: Next) -> Response {
149        // Get session ID from cookie or generate new one
150        let session_id = request
151            .cookie(&self.config.cookie_name)
152            .unwrap_or_else(generate_session_id);
153
154        // Load session from store
155        let mut session = match self.store.read(&session_id).await {
156            Ok(Some(s)) => s,
157            Ok(None) => {
158                // Create new session
159                SessionData::new(session_id.clone(), generate_csrf_token())
160            }
161            Err(e) => {
162                eprintln!("Session read error: {e}");
163                SessionData::new(session_id.clone(), generate_csrf_token())
164            }
165        };
166
167        // Age flash data from previous request
168        session.age_flash_data();
169
170        // Create task-local context and store session in it
171        let ctx = Arc::new(RwLock::new(Some(session)));
172
173        // Process the request within the task-local scope
174        // This makes session() and session_mut() work correctly across await points
175        let response = SESSION_CONTEXT
176            .scope(ctx.clone(), async { next(request).await })
177            .await;
178
179        // Get the potentially modified session from the context
180        let session = take_session_internal(&ctx);
181
182        // Save session and add cookie to response
183        if let Some(session) = session {
184            // Always save to update last_activity
185            if let Err(e) = self.store.write(&session).await {
186                eprintln!("Session write error: {e}");
187            }
188
189            // Add session cookie to response
190            let cookie = self.create_session_cookie(&session.id);
191
192            match response {
193                Ok(res) => Ok(res.cookie(cookie)),
194                Err(res) => Err(res.cookie(cookie)),
195            }
196        } else {
197            response
198        }
199    }
200}
201
202/// Regenerate the session ID (for security after login)
203///
204/// This creates a new session ID while preserving session data,
205/// which helps prevent session fixation attacks.
206pub fn regenerate_session_id() {
207    session_mut(|session| {
208        session.id = generate_session_id();
209        session.dirty = true;
210    });
211}
212
213/// Invalidate the current session (clear all data)
214pub fn invalidate_session() {
215    session_mut(|session| {
216        session.flush();
217        session.csrf_token = generate_csrf_token();
218    });
219}
220
221/// Helper to get the CSRF token from current session
222pub fn get_csrf_token() -> Option<String> {
223    session().map(|s| s.csrf_token)
224}
225
226/// Helper to check if user is authenticated
227pub fn is_authenticated() -> bool {
228    session().map(|s| s.user_id.is_some()).unwrap_or(false)
229}
230
231/// Helper to get the authenticated user ID
232pub fn auth_user_id() -> Option<i64> {
233    session().and_then(|s| s.user_id)
234}
235
236/// Helper to set the authenticated user
237pub fn set_auth_user(user_id: i64) {
238    session_mut(|session| {
239        session.user_id = Some(user_id);
240        session.dirty = true;
241    });
242}
243
244/// Helper to clear the authenticated user (logout)
245pub fn clear_auth_user() {
246    session_mut(|session| {
247        session.user_id = None;
248        session.dirty = true;
249    });
250}
251
252/// Destroy all sessions for a user, with optional exception for the current session.
253///
254/// Uses the session store's `destroy_for_user` method for direct DB deletion.
255/// This is auth-method-agnostic — works for password-based, OAuth, or any auth.
256///
257/// # Arguments
258/// * `store` - The session store to use
259/// * `user_id` - The user whose sessions to destroy
260/// * `except_session_id` - Optional session ID to preserve (current session)
261pub async fn invalidate_all_for_user(
262    store: &dyn super::store::SessionStore,
263    user_id: i64,
264    except_session_id: Option<&str>,
265) -> Result<u64, crate::error::FrameworkError> {
266    store.destroy_for_user(user_id, except_session_id).await
267}