spa_rs/
session.rs

//! Tower middleware for reading and writing session data identified by a cookie.
2//!
3use crate::filter::Predicate;
4use axum::{extract::Request, http::StatusCode, response::Response};
5use headers::{Cookie, HeaderMapExt};
6use parking_lot::RwLock;
7use std::{cmp::PartialEq, collections::HashMap, sync::Arc};
8
9/// Session object, can access by Extension in RequireSession layer.
10///
11/// See [RequireSession] example for usage
12#[derive(Clone)]
13pub struct Session<T> {
14    /// current session data
15    pub current: T,
16    /// session storage
17    pub all: Arc<SessionStore<T>>,
18}
19
20/// Session storage, can access by Extersion in AddSession layer.
21///
22/// See [AddSession] example for usage
23#[derive(Debug)]
24pub struct SessionStore<T> {
25    key: String,
26    inner: RwLock<HashMap<String, T>>,
27}
28
29impl<T: PartialEq> SessionStore<T> {
30    /// return new SessionStore with specific key
31    pub fn new(key: impl Into<String>) -> Self {
32        SessionStore {
33            key: key.into(),
34            inner: RwLock::new(HashMap::new()),
35        }
36    }
37
38    /// get the key reference
39    pub fn key(&self) -> &str {
40        &self.key
41    }
42
43    /// insert a new session item
44    pub fn insert(&self, k: impl Into<String>, v: T) {
45        self.inner.write().insert(k.into(), v);
46    }
47
48    /// remove the session item
49    pub fn remove(&self, v: T) {
50        self.inner.write().retain(|_, x| *x != v);
51    }
52}
53
54/// Middleware that can access and modify all sessions data. Usually used for **Login** handler
55///
56/// # Example
57///```
58/// # use spa_rs::routing::{post, Router};
59/// # use spa_rs::Extension;
60/// # use spa_rs::session::AddSession;
61/// # use spa_rs::session::SessionStore;
62/// # use axum_help::filter::FilterExLayer;
63/// # use std::sync::Arc;
64/// #
65/// #[derive(PartialEq, Clone)]
66/// struct User;
67///
68/// async fn login(Extension(session): Extension<Arc<SessionStore<User>>>) {
69///     let new_user = User;
70///     session.insert("session_id", new_user);
71/// }
72///
73/// #[tokio::main]
74/// async fn main() {
75///     let session = Arc::new(SessionStore::<User>::new("my_session"));
76///     let app = Router::new()
77///         .route("/login", post(login))
78///         .layer(FilterExLayer::new(AddSession::new(session.clone())));
79/// #   axum::Server::bind(&"0.0.0.0:3000".parse().unwrap()).serve(app.into_make_service());
80/// }
81///```
82#[derive(Clone, Debug)]
83pub struct AddSession<T>(Arc<SessionStore<T>>);
84
85impl<T> AddSession<T> {
86    pub fn new(store: Arc<SessionStore<T>>) -> Self {
87        Self(store)
88    }
89}
90
91impl<T> Predicate<Request> for AddSession<T>
92where
93    T: Send + Sync + 'static,
94{
95    type Request = Request;
96    type Response = Response;
97
98    fn check(&self, mut request: Request) -> Result<Self::Request, Self::Response> {
99        request.extensions_mut().insert(self.0.clone());
100        Ok(request)
101    }
102}
103
104/// Middleware that can access and modify all sessions data.
105///
106/// # Example
107///```
108/// # use spa_rs::routing::{post, Router};
109/// # use spa_rs::Extension;
110/// # use spa_rs::session::RequireSession;
111/// # use spa_rs::session::SessionStore;
112/// # use spa_rs::session::Session;
113/// # use axum_help::filter::FilterExLayer;
114/// # use std::sync::Arc;
115/// #
116/// #[derive(PartialEq, Clone, Debug)]
117/// struct User;
118///
119/// async fn action(Extension(session): Extension<Arc<Session<User>>>) {
120///     println!("current user: {:?}", session.current);
121/// }
122///
123/// #[tokio::main]
124/// async fn main() {
125///     let session = Arc::new(SessionStore::<User>::new("my_session"));
126///     let app = Router::new()
127///         .route("/action", post(action))
128///         .layer(FilterExLayer::new(RequireSession::new(session.clone())));
129/// #   axum::Server::bind(&"0.0.0.0:3000".parse().unwrap()).serve(app.into_make_service());
130/// }
131///```
132#[derive(Clone, Debug)]
133pub struct RequireSession<T>(Arc<SessionStore<T>>);
134
135impl<T> RequireSession<T> {
136    pub fn new(store: Arc<SessionStore<T>>) -> Self {
137        Self(store)
138    }
139}
140
141impl<T> Predicate<Request> for RequireSession<T>
142where
143    T: Clone + Send + Sync + 'static,
144{
145    type Request = Request;
146    type Response = Response;
147
148    fn check(&self, mut request: Request) -> Result<Self::Request, Self::Response> {
149        if let Some(cookie) = request.headers().typed_get::<Cookie>() {
150            let sessions = self.0.inner.read();
151            for (k, v) in cookie.iter() {
152                if k == self.0.key {
153                    if let Some(u) = sessions.get(v) {
154                        request.extensions_mut().insert(Session {
155                            current: u.clone(),
156                            all: self.0.clone(),
157                        });
158                        return Ok(request);
159                    }
160                }
161            }
162        }
163
164        Err({
165            let mut response = Response::default();
166            *response.status_mut() = StatusCode::UNAUTHORIZED;
167            response
168        })
169    }
170}