ferro_rs/session/
middleware.rs1use crate::http::cookie::{Cookie, SameSite};
4use crate::http::Response;
5use crate::middleware::{Middleware, Next};
6use crate::Request;
7use async_trait::async_trait;
8use rand::Rng;
9use std::sync::Arc;
10use tokio::sync::RwLock;
11
12use super::config::SessionConfig;
13use super::driver::DatabaseSessionDriver;
14use super::store::{SessionData, SessionStore};
15
tokio::task_local! {
    /// Per-request session slot.
    ///
    /// `SessionMiddleware::handle` installs it with `scope` for the duration of
    /// the downstream handler. The free functions in this module (`session`,
    /// `session_mut`, ...) read and write it from there. The inner `Option`
    /// becomes `None` once the middleware has taken the session back out to
    /// persist it.
    ///
    /// NOTE(review): task-locals are presumably not inherited by tasks started
    /// with `tokio::spawn`, so the accessors would see no session there — confirm
    /// against the tokio `task_local!` docs.
    static SESSION_CONTEXT: Arc<RwLock<Option<SessionData>>>;
}
21
22pub fn session() -> Option<SessionData> {
36 SESSION_CONTEXT
37 .try_with(|ctx| {
38 ctx.try_read().ok().and_then(|guard| guard.clone())
40 })
41 .ok()
42 .flatten()
43}
44
45pub fn session_mut<F, R>(f: F) -> Option<R>
57where
58 F: FnOnce(&mut SessionData) -> R,
59{
60 SESSION_CONTEXT
61 .try_with(|ctx| {
62 ctx.try_write()
64 .ok()
65 .and_then(|mut guard| guard.as_mut().map(f))
66 })
67 .ok()
68 .flatten()
69}
70
71fn take_session_internal(ctx: &Arc<RwLock<Option<SessionData>>>) -> Option<SessionData> {
73 ctx.try_write().ok().and_then(|mut guard| guard.take())
74}
75
76pub fn generate_session_id() -> String {
80 let mut rng = rand::thread_rng();
81 const CHARSET: &[u8] = b"abcdefghijklmnopqrstuvwxyz0123456789";
82
83 (0..40)
84 .map(|_| {
85 let idx = rng.gen_range(0..CHARSET.len());
86 CHARSET[idx] as char
87 })
88 .collect()
89}
90
91pub fn generate_csrf_token() -> String {
95 generate_session_id()
96}
97
/// Middleware that loads a session before the downstream handler runs, then
/// persists it and sets the session cookie on the response afterwards.
pub struct SessionMiddleware {
    /// Cookie name/flags and idle/absolute session lifetimes.
    config: SessionConfig,
    /// Backend used to read and write sessions.
    store: Arc<dyn SessionStore>,
}
110
111impl SessionMiddleware {
112 pub fn new(config: SessionConfig) -> Self {
114 let store = Arc::new(DatabaseSessionDriver::new(
115 config.idle_lifetime,
116 config.absolute_lifetime,
117 ));
118 Self { config, store }
119 }
120
121 pub fn with_store(config: SessionConfig, store: Arc<dyn SessionStore>) -> Self {
123 Self { config, store }
124 }
125
126 fn create_session_cookie(&self, session_id: &str) -> Cookie {
127 let mut cookie = Cookie::new(&self.config.cookie_name, session_id)
128 .http_only(self.config.cookie_http_only)
129 .secure(self.config.cookie_secure)
130 .path(&self.config.cookie_path)
131 .max_age(std::cmp::max(
132 self.config.idle_lifetime,
133 self.config.absolute_lifetime,
134 ));
135
136 cookie = match self.config.cookie_same_site.to_lowercase().as_str() {
137 "strict" => cookie.same_site(SameSite::Strict),
138 "none" => cookie.same_site(SameSite::None),
139 _ => cookie.same_site(SameSite::Lax),
140 };
141
142 cookie
143 }
144}
145
146#[async_trait]
147impl Middleware for SessionMiddleware {
148 async fn handle(&self, request: Request, next: Next) -> Response {
149 let session_id = request
151 .cookie(&self.config.cookie_name)
152 .unwrap_or_else(generate_session_id);
153
154 let mut session = match self.store.read(&session_id).await {
156 Ok(Some(s)) => s,
157 Ok(None) => {
158 SessionData::new(session_id.clone(), generate_csrf_token())
160 }
161 Err(e) => {
162 eprintln!("Session read error: {e}");
163 SessionData::new(session_id.clone(), generate_csrf_token())
164 }
165 };
166
167 session.age_flash_data();
169
170 let ctx = Arc::new(RwLock::new(Some(session)));
172
173 let response = SESSION_CONTEXT
176 .scope(ctx.clone(), async { next(request).await })
177 .await;
178
179 let session = take_session_internal(&ctx);
181
182 if let Some(session) = session {
184 if let Err(e) = self.store.write(&session).await {
186 eprintln!("Session write error: {e}");
187 }
188
189 let cookie = self.create_session_cookie(&session.id);
191
192 match response {
193 Ok(res) => Ok(res.cookie(cookie)),
194 Err(res) => Err(res.cookie(cookie)),
195 }
196 } else {
197 response
198 }
199 }
200}
201
202pub fn regenerate_session_id() {
207 session_mut(|session| {
208 session.id = generate_session_id();
209 session.dirty = true;
210 });
211}
212
213pub fn invalidate_session() {
215 session_mut(|session| {
216 session.flush();
217 session.csrf_token = generate_csrf_token();
218 });
219}
220
221pub fn get_csrf_token() -> Option<String> {
223 session().map(|s| s.csrf_token)
224}
225
226pub fn is_authenticated() -> bool {
228 session().map(|s| s.user_id.is_some()).unwrap_or(false)
229}
230
231pub fn auth_user_id() -> Option<i64> {
233 session().and_then(|s| s.user_id)
234}
235
236pub fn set_auth_user(user_id: i64) {
238 session_mut(|session| {
239 session.user_id = Some(user_id);
240 session.dirty = true;
241 });
242}
243
244pub fn clear_auth_user() {
246 session_mut(|session| {
247 session.user_id = None;
248 session.dirty = true;
249 });
250}
251
252pub async fn invalidate_all_for_user(
262 store: &dyn super::store::SessionStore,
263 user_id: i64,
264 except_session_id: Option<&str>,
265) -> Result<u64, crate::error::FrameworkError> {
266 store.destroy_for_user(user_id, except_session_id).await
267}