kit_rs/session/middleware.rs

1//! Session middleware for Kit framework
2
3use crate::http::cookie::{Cookie, SameSite};
4use crate::http::Response;
5use crate::middleware::{Middleware, Next};
6use crate::Request;
7use async_trait::async_trait;
8use rand::Rng;
9use std::cell::RefCell;
10use std::sync::Arc;
11
12use super::config::SessionConfig;
13use super::driver::DatabaseSessionDriver;
14use super::store::{SessionData, SessionStore};
15
16// Thread-local session context for storing the current request's session data
17thread_local! {
18    static SESSION_CONTEXT: RefCell<Option<SessionData>> = const { RefCell::new(None) };
19}
20
21/// Get the current session (read-only)
22///
23/// Returns a clone of the current session data if available.
24///
25/// # Example
26///
27/// ```rust,ignore
28/// use kit::session::session;
29///
30/// if let Some(session) = session() {
31///     let name: Option<String> = session.get("name");
32/// }
33/// ```
34pub fn session() -> Option<SessionData> {
35    SESSION_CONTEXT.with(|ctx| ctx.borrow().clone())
36}
37
38/// Get the current session and modify it
39///
40/// # Example
41///
42/// ```rust,ignore
43/// use kit::session::session_mut;
44///
45/// session_mut(|session| {
46///     session.put("name", "John");
47/// });
48/// ```
49pub fn session_mut<F, R>(f: F) -> Option<R>
50where
51    F: FnOnce(&mut SessionData) -> R,
52{
53    SESSION_CONTEXT.with(|ctx| {
54        let mut session_opt = ctx.borrow_mut();
55        session_opt.as_mut().map(f)
56    })
57}
58
59/// Set the session context for the current request
60pub fn set_session(session: SessionData) {
61    SESSION_CONTEXT.with(|ctx| {
62        *ctx.borrow_mut() = Some(session);
63    });
64}
65
66/// Clear the session context
67pub fn clear_session() {
68    SESSION_CONTEXT.with(|ctx| {
69        *ctx.borrow_mut() = None;
70    });
71}
72
73/// Take the session out of the context (for saving)
74pub fn take_session() -> Option<SessionData> {
75    SESSION_CONTEXT.with(|ctx| ctx.borrow_mut().take())
76}
77
78/// Generate a cryptographically secure session ID
79///
80/// Generates a 40-character alphanumeric string.
81pub fn generate_session_id() -> String {
82    let mut rng = rand::thread_rng();
83    const CHARSET: &[u8] = b"abcdefghijklmnopqrstuvwxyz0123456789";
84
85    (0..40)
86        .map(|_| {
87            let idx = rng.gen_range(0..CHARSET.len());
88            CHARSET[idx] as char
89        })
90        .collect()
91}
92
93/// Generate a CSRF token
94///
95/// Same format as session ID for consistency.
96pub fn generate_csrf_token() -> String {
97    generate_session_id()
98}
99
100/// Session middleware
101///
102/// Handles session lifecycle:
103/// 1. Reads session ID from cookie
104/// 2. Loads session data from storage
105/// 3. Makes session available during request
106/// 4. Saves session after request
107/// 5. Sets session cookie on response
108pub struct SessionMiddleware {
109    config: SessionConfig,
110    store: Arc<dyn SessionStore>,
111}
112
113impl SessionMiddleware {
114    /// Create a new session middleware with the given configuration
115    pub fn new(config: SessionConfig) -> Self {
116        let store = Arc::new(DatabaseSessionDriver::new(config.lifetime));
117        Self { config, store }
118    }
119
120    /// Create session middleware with a custom store
121    pub fn with_store(config: SessionConfig, store: Arc<dyn SessionStore>) -> Self {
122        Self { config, store }
123    }
124
125    fn create_session_cookie(&self, session_id: &str) -> Cookie {
126        let mut cookie = Cookie::new(&self.config.cookie_name, session_id)
127            .http_only(self.config.cookie_http_only)
128            .secure(self.config.cookie_secure)
129            .path(&self.config.cookie_path)
130            .max_age(self.config.lifetime);
131
132        cookie = match self.config.cookie_same_site.to_lowercase().as_str() {
133            "strict" => cookie.same_site(SameSite::Strict),
134            "none" => cookie.same_site(SameSite::None),
135            _ => cookie.same_site(SameSite::Lax),
136        };
137
138        cookie
139    }
140}
141
142#[async_trait]
143impl Middleware for SessionMiddleware {
144    async fn handle(&self, request: Request, next: Next) -> Response {
145        // Get session ID from cookie or generate new one
146        let session_id = request
147            .cookie(&self.config.cookie_name)
148            .unwrap_or_else(generate_session_id);
149
150        // Load session from store
151        let mut session = match self.store.read(&session_id).await {
152            Ok(Some(s)) => s,
153            Ok(None) => {
154                // Create new session
155                SessionData::new(session_id.clone(), generate_csrf_token())
156            }
157            Err(e) => {
158                eprintln!("Session read error: {}", e);
159                SessionData::new(session_id.clone(), generate_csrf_token())
160            }
161        };
162
163        // Age flash data from previous request
164        session.age_flash_data();
165
166        // Store session in thread-local context
167        set_session(session);
168
169        // Process the request
170        let response = next(request).await;
171
172        // Get the potentially modified session
173        let session = take_session();
174
175        // Save session and add cookie to response
176        if let Some(session) = session {
177            // Always save to update last_activity
178            if let Err(e) = self.store.write(&session).await {
179                eprintln!("Session write error: {}", e);
180            }
181
182            // Add session cookie to response
183            let cookie = self.create_session_cookie(&session.id);
184
185            match response {
186                Ok(res) => Ok(res.cookie(cookie)),
187                Err(res) => Err(res.cookie(cookie)),
188            }
189        } else {
190            response
191        }
192    }
193}
194
195/// Regenerate the session ID (for security after login)
196///
197/// This creates a new session ID while preserving session data,
198/// which helps prevent session fixation attacks.
199pub fn regenerate_session_id() {
200    session_mut(|session| {
201        session.id = generate_session_id();
202        session.dirty = true;
203    });
204}
205
206/// Invalidate the current session (clear all data)
207pub fn invalidate_session() {
208    session_mut(|session| {
209        session.flush();
210        session.csrf_token = generate_csrf_token();
211    });
212}
213
214/// Helper to get the CSRF token from current session
215pub fn get_csrf_token() -> Option<String> {
216    session().map(|s| s.csrf_token)
217}
218
219/// Helper to check if user is authenticated
220pub fn is_authenticated() -> bool {
221    session().map(|s| s.user_id.is_some()).unwrap_or(false)
222}
223
224/// Helper to get the authenticated user ID
225pub fn auth_user_id() -> Option<i64> {
226    session().and_then(|s| s.user_id)
227}
228
229/// Helper to set the authenticated user
230pub fn set_auth_user(user_id: i64) {
231    session_mut(|session| {
232        session.user_id = Some(user_id);
233        session.dirty = true;
234    });
235}
236
237/// Helper to clear the authenticated user (logout)
238pub fn clear_auth_user() {
239    session_mut(|session| {
240        session.user_id = None;
241        session.dirty = true;
242    });
243}