// conduit_cookie/session.rs

1use base64::{decode, encode};
2use std::collections::HashMap;
3use std::str;
4
5use conduit::RequestExt;
6use conduit_middleware::{AfterResult, BeforeResult};
7use cookie::{time::Duration, Cookie, Key, SameSite};
8
9use super::RequestCookies;
10
/// Lifetime of the issued session cookie (its `Max-Age`), in days.
const MAX_AGE_DAYS: i64 = 90;
12
13pub struct SessionMiddleware {
14    cookie_name: String,
15    key: Key,
16    secure: bool,
17}
18
/// Per-request session state, stored in the request extensions.
#[derive(Debug)]
pub struct Session {
    /// Decoded key/value pairs of the session.
    data: HashMap<String, String>,
    /// Set once the session has been borrowed mutably. Only a dirty session
    /// is re-encoded into a cookie after the request.
    dirty: bool,
}
23
24impl SessionMiddleware {
25    pub fn new(cookie: &str, key: Key, secure: bool) -> SessionMiddleware {
26        SessionMiddleware {
27            cookie_name: cookie.to_string(),
28            key,
29            secure,
30        }
31    }
32
33    pub fn decode(cookie: Cookie<'_>) -> HashMap<String, String> {
34        let mut ret = HashMap::new();
35        let bytes = decode(cookie.value().as_bytes()).unwrap_or_default();
36        let mut parts = bytes.split(|&a| a == 0xff);
37        while let (Some(key), Some(value)) = (parts.next(), parts.next()) {
38            if key.is_empty() {
39                break;
40            }
41            if let (Ok(key), Ok(value)) = (str::from_utf8(key), str::from_utf8(value)) {
42                ret.insert(key.to_string(), value.to_string());
43            }
44        }
45        ret
46    }
47
48    pub fn encode(h: &HashMap<String, String>) -> String {
49        let mut ret = Vec::new();
50        for (i, (k, v)) in h.iter().enumerate() {
51            if i != 0 {
52                ret.push(0xff)
53            }
54            ret.extend(k.bytes());
55            ret.push(0xff);
56            ret.extend(v.bytes());
57        }
58        while ret.len() * 8 % 6 != 0 {
59            ret.push(0xff);
60        }
61        encode(&ret[..])
62    }
63}
64
65impl conduit_middleware::Middleware for SessionMiddleware {
66    fn before(&self, req: &mut dyn RequestExt) -> BeforeResult {
67        let session = {
68            let jar = req.cookies_mut().signed(&self.key);
69            jar.get(&self.cookie_name)
70                .map(Self::decode)
71                .unwrap_or_else(HashMap::new)
72        };
73        req.mut_extensions().insert(Session {
74            data: session,
75            dirty: false,
76        });
77        Ok(())
78    }
79
80    fn after(&self, req: &mut dyn RequestExt, res: AfterResult) -> AfterResult {
81        let session = req.extensions().get::<Session>();
82        let session = session.expect("session must be present after request");
83        if session.dirty {
84            let encoded = Self::encode(&session.data);
85            let cookie = Cookie::build(self.cookie_name.to_string(), encoded)
86                .http_only(true)
87                .secure(self.secure)
88                .same_site(SameSite::Strict)
89                .max_age(Duration::days(MAX_AGE_DAYS))
90                .path("/")
91                .finish();
92            req.cookies_mut().signed_mut(&self.key).add(cookie);
93        }
94        res
95    }
96}
97
/// Access to the cookie-backed session attached to a request.
pub trait RequestSession {
    /// Returns the session data for read-only use.
    fn session(&self) -> &HashMap<String, String>;
    /// Returns the session data for modification.
    fn session_mut(&mut self) -> &mut HashMap<String, String>;
}
102
103impl<T: RequestExt + ?Sized> RequestSession for T {
104    fn session(&self) -> &HashMap<String, String> {
105        &self
106            .extensions()
107            .get::<Session>()
108            .expect("missing cookie session")
109            .data
110    }
111
112    fn session_mut(&mut self) -> &mut HashMap<String, String> {
113        let session = self
114            .mut_extensions()
115            .get_mut::<Session>()
116            .expect("missing cookie session");
117        session.dirty = true;
118        &mut session.data
119    }
120}
121
#[cfg(test)]
mod test {
    use std::collections::HashMap;

    use conduit::{header, Body, Handler, HttpResult, Method, RequestExt, Response};
    use conduit_middleware::MiddlewareBuilder;
    use conduit_test::MockRequest;
    use cookie::{Cookie, Key};

    use crate::{Middleware, RequestSession, SessionMiddleware};

    /// Deterministic signing key derived from the bytes `0..32`. Two
    /// independently built middleware stacks therefore agree on signatures.
    fn test_key() -> Key {
        let master: Vec<u8> = (0..32).collect();
        Key::derive_from(&master)
    }

    #[test]
    fn simple() {
        fn set_session(req: &mut dyn RequestExt) -> HttpResult {
            let previous = req
                .session_mut()
                .insert("foo".to_string(), "bar".to_string());
            assert!(previous.is_none());
            Response::builder().body(Body::empty())
        }

        fn use_session(req: &mut dyn RequestExt) -> HttpResult {
            assert_eq!(*req.session().get("foo").unwrap(), "bar");
            Response::builder().body(Body::empty())
        }

        let mut request = MockRequest::new(Method::POST, "/articles");

        // First pass: the handler writes to the session, so a cookie is issued.
        let mut writer = MiddlewareBuilder::new(set_session);
        writer.add(Middleware::new());
        writer.add(SessionMiddleware::new("lol", test_key(), false));
        let response = writer.call(&mut request).unwrap();

        let set_cookie = response
            .headers()
            .get(header::SET_COOKIE)
            .unwrap()
            .to_str()
            .unwrap();
        assert!(set_cookie.starts_with("lol"));

        // Second pass: replaying the cookie must restore the stored value.
        request.header(header::COOKIE, set_cookie);
        let mut reader = MiddlewareBuilder::new(use_session);
        reader.add(Middleware::new());
        reader.add(SessionMiddleware::new("lol", test_key(), false));
        assert!(reader.call(&mut request).is_ok());
    }

    #[test]
    fn no_equals() {
        let mut map = HashMap::new();
        map.insert("a".to_string(), "bc".to_string());
        let encoded = SessionMiddleware::encode(&map);
        assert!(!encoded.ends_with('='));

        let decoded = SessionMiddleware::decode(Cookie::new("foo", encoded));
        assert_eq!(*decoded.get("a").unwrap(), "bc");
    }

    #[test]
    fn dirty_tracking() {
        fn read_session(req: &mut dyn RequestExt) -> HttpResult {
            req.session();
            Response::builder().body(Body::empty())
        }

        fn modify_session(req: &mut dyn RequestExt) -> HttpResult {
            req.session_mut();
            Response::builder().body(Body::empty())
        }

        let mut request = MockRequest::new(Method::GET, "/");

        // Only reading the session must not emit a Set-Cookie header.
        let mut reading = MiddlewareBuilder::new(read_session);
        reading.add(Middleware::new());
        reading.add(SessionMiddleware::new("dirty", test_key(), false));
        let response = reading.call(&mut request).unwrap();
        assert!(response.headers().get(header::SET_COOKIE).is_none());

        // Merely borrowing the session mutably marks it dirty.
        let mut writing = MiddlewareBuilder::new(modify_session);
        writing.add(Middleware::new());
        writing.add(SessionMiddleware::new("dirty", test_key(), false));
        let response = writing.call(&mut request).unwrap();
        assert!(response.headers().get(header::SET_COOKIE).is_some());
    }
}