ferro_rs/session/
middleware.rs1use crate::http::cookie::{Cookie, SameSite};
4use crate::http::Response;
5use crate::middleware::{Middleware, Next};
6use crate::Request;
7use async_trait::async_trait;
8use rand::Rng;
9use std::sync::Arc;
10use tokio::sync::RwLock;
11
12use super::config::SessionConfig;
13use super::driver::DatabaseSessionDriver;
14use super::store::{SessionData, SessionStore};
15
16tokio::task_local! {
19 static SESSION_CONTEXT: Arc<RwLock<Option<SessionData>>>;
20}
21
22pub fn session() -> Option<SessionData> {
36 SESSION_CONTEXT
37 .try_with(|ctx| {
38 ctx.try_read().ok().and_then(|guard| guard.clone())
40 })
41 .ok()
42 .flatten()
43}
44
45pub fn session_mut<F, R>(f: F) -> Option<R>
57where
58 F: FnOnce(&mut SessionData) -> R,
59{
60 SESSION_CONTEXT
61 .try_with(|ctx| {
62 ctx.try_write()
64 .ok()
65 .and_then(|mut guard| guard.as_mut().map(f))
66 })
67 .ok()
68 .flatten()
69}
70
71fn take_session_internal(ctx: &Arc<RwLock<Option<SessionData>>>) -> Option<SessionData> {
73 ctx.try_write().ok().and_then(|mut guard| guard.take())
74}
75
76pub fn generate_session_id() -> String {
80 let mut rng = rand::thread_rng();
81 const CHARSET: &[u8] = b"abcdefghijklmnopqrstuvwxyz0123456789";
82
83 (0..40)
84 .map(|_| {
85 let idx = rng.gen_range(0..CHARSET.len());
86 CHARSET[idx] as char
87 })
88 .collect()
89}
90
91pub fn generate_csrf_token() -> String {
95 generate_session_id()
96}
97
98pub struct SessionMiddleware {
107 config: SessionConfig,
108 store: Arc<dyn SessionStore>,
109}
110
111impl SessionMiddleware {
112 pub fn new(config: SessionConfig) -> Self {
114 let store = Arc::new(DatabaseSessionDriver::new(config.lifetime));
115 Self { config, store }
116 }
117
118 pub fn with_store(config: SessionConfig, store: Arc<dyn SessionStore>) -> Self {
120 Self { config, store }
121 }
122
123 fn create_session_cookie(&self, session_id: &str) -> Cookie {
124 let mut cookie = Cookie::new(&self.config.cookie_name, session_id)
125 .http_only(self.config.cookie_http_only)
126 .secure(self.config.cookie_secure)
127 .path(&self.config.cookie_path)
128 .max_age(self.config.lifetime);
129
130 cookie = match self.config.cookie_same_site.to_lowercase().as_str() {
131 "strict" => cookie.same_site(SameSite::Strict),
132 "none" => cookie.same_site(SameSite::None),
133 _ => cookie.same_site(SameSite::Lax),
134 };
135
136 cookie
137 }
138}
139
140#[async_trait]
141impl Middleware for SessionMiddleware {
142 async fn handle(&self, request: Request, next: Next) -> Response {
143 let session_id = request
145 .cookie(&self.config.cookie_name)
146 .unwrap_or_else(generate_session_id);
147
148 let mut session = match self.store.read(&session_id).await {
150 Ok(Some(s)) => s,
151 Ok(None) => {
152 SessionData::new(session_id.clone(), generate_csrf_token())
154 }
155 Err(e) => {
156 eprintln!("Session read error: {e}");
157 SessionData::new(session_id.clone(), generate_csrf_token())
158 }
159 };
160
161 session.age_flash_data();
163
164 let ctx = Arc::new(RwLock::new(Some(session)));
166
167 let response = SESSION_CONTEXT
170 .scope(ctx.clone(), async { next(request).await })
171 .await;
172
173 let session = take_session_internal(&ctx);
175
176 if let Some(session) = session {
178 if let Err(e) = self.store.write(&session).await {
180 eprintln!("Session write error: {e}");
181 }
182
183 let cookie = self.create_session_cookie(&session.id);
185
186 match response {
187 Ok(res) => Ok(res.cookie(cookie)),
188 Err(res) => Err(res.cookie(cookie)),
189 }
190 } else {
191 response
192 }
193 }
194}
195
196pub fn regenerate_session_id() {
201 session_mut(|session| {
202 session.id = generate_session_id();
203 session.dirty = true;
204 });
205}
206
207pub fn invalidate_session() {
209 session_mut(|session| {
210 session.flush();
211 session.csrf_token = generate_csrf_token();
212 });
213}
214
215pub fn get_csrf_token() -> Option<String> {
217 session().map(|s| s.csrf_token)
218}
219
220pub fn is_authenticated() -> bool {
222 session().map(|s| s.user_id.is_some()).unwrap_or(false)
223}
224
225pub fn auth_user_id() -> Option<i64> {
227 session().and_then(|s| s.user_id)
228}
229
230pub fn set_auth_user(user_id: i64) {
232 session_mut(|session| {
233 session.user_id = Some(user_id);
234 session.dirty = true;
235 });
236}
237
238pub fn clear_auth_user() {
240 session_mut(|session| {
241 session.user_id = None;
242 session.dirty = true;
243 });
244}