Skip to main content

oxidite_auth/
session_middleware.rs

1use oxidite_core::{OxiditeRequest, OxiditeResponse, Error as CoreError};
2use tower::{Service, Layer};
3use std::task::{Context, Poll};
4use std::future::Future;
5use std::pin::Pin;
6use std::sync::Arc;
7use cookie::{Cookie, SameSite};
8use crate::session::SessionStore;
9
10const SESSION_COOKIE_NAME: &str = "oxidite_session";
11
12/// Session middleware
13#[derive(Clone)]
14pub struct SessionMiddleware<S> {
15    inner: S,
16    store: Arc<dyn SessionStore>,
17    cookie_secure: bool,
18    cookie_http_only: bool,
19    session_ttl_secs: u64,
20}
21
22impl<S> SessionMiddleware<S> {
23    pub fn new(
24        inner: S,
25        store: Arc<dyn SessionStore>,
26        cookie_secure: bool,
27        cookie_http_only: bool,
28        session_ttl_secs: u64,
29    ) -> Self {
30        Self {
31            inner,
32            store,
33            cookie_secure,
34            cookie_http_only,
35            session_ttl_secs,
36        }
37    }
38}
39
40impl<S> Service<OxiditeRequest> for SessionMiddleware<S>
41where
42    S: Service<OxiditeRequest, Response = OxiditeResponse, Error = CoreError> + Clone + Send + 'static,
43    S::Future: Send + 'static,
44{
45    type Response = S::Response;
46    type Error = S::Error;
47    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
48
49    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
50        self.inner.poll_ready(cx)
51    }
52
53    fn call(&mut self, req: OxiditeRequest) -> Self::Future {
54        // Extract session cookie
55        let session_id = req
56            .headers()
57            .get("cookie")
58            .and_then(|h| h.to_str().ok())
59            .and_then(|cookies| {
60                for cookie_str in cookies.split(';') {
61                    if let Ok(cookie) = Cookie::parse(cookie_str.trim()) {
62                        if cookie.name() == SESSION_COOKIE_NAME {
63                            return Some(cookie.value().to_string());
64                        }
65                    }
66                }
67                None
68            });
69
70        let store = self.store.clone();
71        let cookie_secure = self.cookie_secure;
72        let cookie_http_only = self.cookie_http_only;
73        let session_ttl_secs = self.session_ttl_secs;
74        let mut inner = self.inner.clone();
75
76        Box::pin(async move {
77            // Try to load existing session
78            let session = if let Some(sid) = session_id {
79                store.get(&sid).await.ok().flatten()
80            } else {
81                None
82            };
83
84            let mut req = req;
85            if let Some(sess) = session.clone() {
86                req.extensions_mut().insert(sess.clone());
87                req.extensions_mut().insert(sess.user_id.clone());
88                if let Ok(user_id) = sess.user_id.parse::<i64>() {
89                    req.extensions_mut().insert(user_id);
90                }
91            }
92
93            let mut response = inner.call(req).await?;
94
95            // If session was renewed or created, set cookie
96            if let Some(sess) = session {
97                if !sess.is_expired() {
98                    let cookie = Cookie::build((SESSION_COOKIE_NAME, sess.id.clone()))
99                        .secure(cookie_secure)
100                        .http_only(cookie_http_only)
101                        .same_site(SameSite::Lax)
102                        .max_age(cookie::time::Duration::seconds(session_ttl_secs as i64))
103                        .path("/")
104                        .build();
105
106                    if let Ok(cookie_val) = cookie.to_string().parse() {
107                        response.headers_mut().insert("set-cookie", cookie_val);
108                    }
109                }
110            }
111
112            Ok(response)
113        })
114    }
115}
116
117/// Layer for session middleware
118pub struct SessionLayer {
119    store: Arc<dyn SessionStore>,
120    cookie_secure: bool,
121    cookie_http_only: bool,
122    session_ttl_secs: u64,
123}
124
125impl SessionLayer {
126    pub fn new(
127        store: Arc<dyn SessionStore>,
128        cookie_secure: bool,
129        cookie_http_only: bool,
130        session_ttl_secs: u64,
131    ) -> Self {
132        Self {
133            store,
134            cookie_secure,
135            cookie_http_only,
136            session_ttl_secs,
137        }
138    }
139
140    pub fn with_defaults(store: Arc<dyn SessionStore>) -> Self {
141        Self::new(store, true, true, 3600)
142    }
143}
144
145impl<S> Layer<S> for SessionLayer {
146    type Service = SessionMiddleware<S>;
147
148    fn layer(&self, inner: S) -> Self::Service {
149        SessionMiddleware::new(
150            inner,
151            self.store.clone(),
152            self.cookie_secure,
153            self.cookie_http_only,
154            self.session_ttl_secs,
155        )
156    }
157}