Skip to main content

ferro_rs/session/
middleware.rs

1//! Session middleware for Ferro framework
2
3use crate::http::cookie::{Cookie, SameSite};
4use crate::http::Response;
5use crate::middleware::{Middleware, Next};
6use crate::Request;
7use async_trait::async_trait;
8use rand::Rng;
9use std::sync::Arc;
10use tokio::sync::RwLock;
11
12use super::config::SessionConfig;
13use super::driver::DatabaseSessionDriver;
14use super::store::{SessionData, SessionStore};
15
// Task-local session context using tokio's task_local macro.
// This is async-safe, unlike thread_local, which can lose data across await
// points (on a multi-threaded runtime a task may resume on a different worker
// thread after an `.await`).
//
// The session middleware installs one context per request via
// `SESSION_CONTEXT.scope(...)`; the inner `Option` becomes `None` once the
// middleware takes the session back out to persist it. Note that tokio
// task-locals are not inherited by tasks spawned from inside the scope.
tokio::task_local! {
    static SESSION_CONTEXT: Arc<RwLock<Option<SessionData>>>;
}
21
22/// Get the current session (read-only)
23///
24/// Returns a clone of the current session data if available.
25///
26/// # Example
27///
28/// ```rust,ignore
29/// use ferro_rs::session::session;
30///
31/// if let Some(session) = session() {
32///     let name: Option<String> = session.get("name");
33/// }
34/// ```
35pub fn session() -> Option<SessionData> {
36    SESSION_CONTEXT
37        .try_with(|ctx| {
38            // Use try_read to avoid blocking - if locked, return None
39            ctx.try_read().ok().and_then(|guard| guard.clone())
40        })
41        .ok()
42        .flatten()
43}
44
45/// Get the current session and modify it
46///
47/// # Example
48///
49/// ```rust,ignore
50/// use ferro_rs::session::session_mut;
51///
52/// session_mut(|session| {
53///     session.put("name", "John");
54/// });
55/// ```
56pub fn session_mut<F, R>(f: F) -> Option<R>
57where
58    F: FnOnce(&mut SessionData) -> R,
59{
60    SESSION_CONTEXT
61        .try_with(|ctx| {
62            // Use try_write to avoid blocking
63            ctx.try_write()
64                .ok()
65                .and_then(|mut guard| guard.as_mut().map(f))
66        })
67        .ok()
68        .flatten()
69}
70
71/// Take the session out of the context (for saving)
72fn take_session_internal(ctx: &Arc<RwLock<Option<SessionData>>>) -> Option<SessionData> {
73    ctx.try_write().ok().and_then(|mut guard| guard.take())
74}
75
76/// Generate a cryptographically secure session ID
77///
78/// Generates a 40-character alphanumeric string.
79pub fn generate_session_id() -> String {
80    let mut rng = rand::thread_rng();
81    const CHARSET: &[u8] = b"abcdefghijklmnopqrstuvwxyz0123456789";
82
83    (0..40)
84        .map(|_| {
85            let idx = rng.gen_range(0..CHARSET.len());
86            CHARSET[idx] as char
87        })
88        .collect()
89}
90
91/// Generate a CSRF token
92///
93/// Same format as session ID for consistency.
94pub fn generate_csrf_token() -> String {
95    generate_session_id()
96}
97
98/// Session middleware
99///
100/// Handles session lifecycle:
101/// 1. Reads session ID from cookie
102/// 2. Loads session data from storage
103/// 3. Makes session available during request
104/// 4. Saves session after request
105/// 5. Sets session cookie on response
106pub struct SessionMiddleware {
107    config: SessionConfig,
108    store: Arc<dyn SessionStore>,
109}
110
111impl SessionMiddleware {
112    /// Create a new session middleware with the given configuration
113    pub fn new(config: SessionConfig) -> Self {
114        let store = Arc::new(DatabaseSessionDriver::new(config.lifetime));
115        Self { config, store }
116    }
117
118    /// Create session middleware with a custom store
119    pub fn with_store(config: SessionConfig, store: Arc<dyn SessionStore>) -> Self {
120        Self { config, store }
121    }
122
123    fn create_session_cookie(&self, session_id: &str) -> Cookie {
124        let mut cookie = Cookie::new(&self.config.cookie_name, session_id)
125            .http_only(self.config.cookie_http_only)
126            .secure(self.config.cookie_secure)
127            .path(&self.config.cookie_path)
128            .max_age(self.config.lifetime);
129
130        cookie = match self.config.cookie_same_site.to_lowercase().as_str() {
131            "strict" => cookie.same_site(SameSite::Strict),
132            "none" => cookie.same_site(SameSite::None),
133            _ => cookie.same_site(SameSite::Lax),
134        };
135
136        cookie
137    }
138}
139
140#[async_trait]
141impl Middleware for SessionMiddleware {
142    async fn handle(&self, request: Request, next: Next) -> Response {
143        // Get session ID from cookie or generate new one
144        let session_id = request
145            .cookie(&self.config.cookie_name)
146            .unwrap_or_else(generate_session_id);
147
148        // Load session from store
149        let mut session = match self.store.read(&session_id).await {
150            Ok(Some(s)) => s,
151            Ok(None) => {
152                // Create new session
153                SessionData::new(session_id.clone(), generate_csrf_token())
154            }
155            Err(e) => {
156                eprintln!("Session read error: {e}");
157                SessionData::new(session_id.clone(), generate_csrf_token())
158            }
159        };
160
161        // Age flash data from previous request
162        session.age_flash_data();
163
164        // Create task-local context and store session in it
165        let ctx = Arc::new(RwLock::new(Some(session)));
166
167        // Process the request within the task-local scope
168        // This makes session() and session_mut() work correctly across await points
169        let response = SESSION_CONTEXT
170            .scope(ctx.clone(), async { next(request).await })
171            .await;
172
173        // Get the potentially modified session from the context
174        let session = take_session_internal(&ctx);
175
176        // Save session and add cookie to response
177        if let Some(session) = session {
178            // Always save to update last_activity
179            if let Err(e) = self.store.write(&session).await {
180                eprintln!("Session write error: {e}");
181            }
182
183            // Add session cookie to response
184            let cookie = self.create_session_cookie(&session.id);
185
186            match response {
187                Ok(res) => Ok(res.cookie(cookie)),
188                Err(res) => Err(res.cookie(cookie)),
189            }
190        } else {
191            response
192        }
193    }
194}
195
196/// Regenerate the session ID (for security after login)
197///
198/// This creates a new session ID while preserving session data,
199/// which helps prevent session fixation attacks.
200pub fn regenerate_session_id() {
201    session_mut(|session| {
202        session.id = generate_session_id();
203        session.dirty = true;
204    });
205}
206
207/// Invalidate the current session (clear all data)
208pub fn invalidate_session() {
209    session_mut(|session| {
210        session.flush();
211        session.csrf_token = generate_csrf_token();
212    });
213}
214
215/// Helper to get the CSRF token from current session
216pub fn get_csrf_token() -> Option<String> {
217    session().map(|s| s.csrf_token)
218}
219
220/// Helper to check if user is authenticated
221pub fn is_authenticated() -> bool {
222    session().map(|s| s.user_id.is_some()).unwrap_or(false)
223}
224
225/// Helper to get the authenticated user ID
226pub fn auth_user_id() -> Option<i64> {
227    session().and_then(|s| s.user_id)
228}
229
230/// Helper to set the authenticated user
231pub fn set_auth_user(user_id: i64) {
232    session_mut(|session| {
233        session.user_id = Some(user_id);
234        session.dirty = true;
235    });
236}
237
238/// Helper to clear the authenticated user (logout)
239pub fn clear_auth_user() {
240    session_mut(|session| {
241        session.user_id = None;
242        session.dirty = true;
243    });
244}