// conduit_cookie/session.rs
use base64::{decode, encode};
use std::collections::HashMap;
use std::str;

use conduit::RequestExt;
use conduit_middleware::{AfterResult, BeforeResult};
use cookie::{time::Duration, Cookie, Key, SameSite};

use super::RequestCookies;

/// Lifetime, in days, given to the session cookie through its `Max-Age` attribute.
const MAX_AGE_DAYS: i64 = 90;
12
13pub struct SessionMiddleware {
14 cookie_name: String,
15 key: Key,
16 secure: bool,
17}
18
/// Per-request session state kept in the request's extensions.
pub struct Session {
    /// The session's key/value pairs.
    data: HashMap<String, String>,
    /// Set once mutable access to `data` has been handed out; the cookie is
    /// only rewritten when this flag is `true`.
    dirty: bool,
}
23
24impl SessionMiddleware {
25 pub fn new(cookie: &str, key: Key, secure: bool) -> SessionMiddleware {
26 SessionMiddleware {
27 cookie_name: cookie.to_string(),
28 key,
29 secure,
30 }
31 }
32
33 pub fn decode(cookie: Cookie<'_>) -> HashMap<String, String> {
34 let mut ret = HashMap::new();
35 let bytes = decode(cookie.value().as_bytes()).unwrap_or_default();
36 let mut parts = bytes.split(|&a| a == 0xff);
37 while let (Some(key), Some(value)) = (parts.next(), parts.next()) {
38 if key.is_empty() {
39 break;
40 }
41 if let (Ok(key), Ok(value)) = (str::from_utf8(key), str::from_utf8(value)) {
42 ret.insert(key.to_string(), value.to_string());
43 }
44 }
45 ret
46 }
47
48 pub fn encode(h: &HashMap<String, String>) -> String {
49 let mut ret = Vec::new();
50 for (i, (k, v)) in h.iter().enumerate() {
51 if i != 0 {
52 ret.push(0xff)
53 }
54 ret.extend(k.bytes());
55 ret.push(0xff);
56 ret.extend(v.bytes());
57 }
58 while ret.len() * 8 % 6 != 0 {
59 ret.push(0xff);
60 }
61 encode(&ret[..])
62 }
63}
64
65impl conduit_middleware::Middleware for SessionMiddleware {
66 fn before(&self, req: &mut dyn RequestExt) -> BeforeResult {
67 let session = {
68 let jar = req.cookies_mut().signed(&self.key);
69 jar.get(&self.cookie_name)
70 .map(Self::decode)
71 .unwrap_or_else(HashMap::new)
72 };
73 req.mut_extensions().insert(Session {
74 data: session,
75 dirty: false,
76 });
77 Ok(())
78 }
79
80 fn after(&self, req: &mut dyn RequestExt, res: AfterResult) -> AfterResult {
81 let session = req.extensions().get::<Session>();
82 let session = session.expect("session must be present after request");
83 if session.dirty {
84 let encoded = Self::encode(&session.data);
85 let cookie = Cookie::build(self.cookie_name.to_string(), encoded)
86 .http_only(true)
87 .secure(self.secure)
88 .same_site(SameSite::Strict)
89 .max_age(Duration::days(MAX_AGE_DAYS))
90 .path("/")
91 .finish();
92 req.cookies_mut().signed_mut(&self.key).add(cookie);
93 }
94 res
95 }
96}
97
/// Access to the key/value session attached to a request.
pub trait RequestSession {
    /// Returns the session's key/value pairs for reading.
    fn session(&self) -> &HashMap<String, String>;

    /// Returns the session's key/value pairs for modification.
    fn session_mut(&mut self) -> &mut HashMap<String, String>;
}
102
103impl<T: RequestExt + ?Sized> RequestSession for T {
104 fn session(&self) -> &HashMap<String, String> {
105 &self
106 .extensions()
107 .get::<Session>()
108 .expect("missing cookie session")
109 .data
110 }
111
112 fn session_mut(&mut self) -> &mut HashMap<String, String> {
113 let session = self
114 .mut_extensions()
115 .get_mut::<Session>()
116 .expect("missing cookie session");
117 session.dirty = true;
118 &mut session.data
119 }
120}
121
122#[cfg(test)]
123mod test {
124 use std::collections::HashMap;
125
126 use conduit::{header, Body, Handler, HttpResult, Method, RequestExt, Response};
127 use conduit_middleware::MiddlewareBuilder;
128 use conduit_test::MockRequest;
129 use cookie::{Cookie, Key};
130
131 use crate::{Middleware, RequestSession, SessionMiddleware};
132
133 fn test_key() -> Key {
134 let master_key: Vec<u8> = (0..32).collect();
135 Key::derive_from(&master_key)
136 }
137
138 #[test]
139 fn simple() {
140 let mut req = MockRequest::new(Method::POST, "/articles");
141 let key = test_key();
142
143 let mut app = MiddlewareBuilder::new(set_session);
145 app.add(Middleware::new());
146 app.add(SessionMiddleware::new("lol", key, false));
147 let response = app.call(&mut req).unwrap();
148
149 let v = response
150 .headers()
151 .get(header::SET_COOKIE)
152 .unwrap()
153 .to_str()
154 .unwrap();
155 assert!(v.starts_with("lol"));
156
157 req.header(header::COOKIE, v);
159 let key = test_key();
160 let mut app = MiddlewareBuilder::new(use_session);
161 app.add(Middleware::new());
162 app.add(SessionMiddleware::new("lol", key, false));
163 assert!(app.call(&mut req).is_ok());
164
165 fn set_session(req: &mut dyn RequestExt) -> HttpResult {
166 assert!(req
167 .session_mut()
168 .insert("foo".to_string(), "bar".to_string())
169 .is_none());
170 Response::builder().body(Body::empty())
171 }
172 fn use_session(req: &mut dyn RequestExt) -> HttpResult {
173 assert_eq!(*req.session().get("foo").unwrap(), "bar");
174 Response::builder().body(Body::empty())
175 }
176 }
177
178 #[test]
179 fn no_equals() {
180 let e = {
181 let mut map = HashMap::new();
182 map.insert("a".to_string(), "bc".to_string());
183 SessionMiddleware::encode(&map)
184 };
185 assert!(!e.ends_with('='));
186
187 let m = SessionMiddleware::decode(Cookie::new("foo", e));
188 assert_eq!(*m.get("a").unwrap(), "bc");
189 }
190
191 #[test]
192 fn dirty_tracking() {
193 let mut req = MockRequest::new(Method::GET, "/");
194
195 let mut app = MiddlewareBuilder::new(read_session);
196 app.add(Middleware::new());
197 app.add(SessionMiddleware::new("dirty", test_key(), false));
198 let response = app.call(&mut req).unwrap();
199
200 assert!(response.headers().get(header::SET_COOKIE).is_none());
201
202 let mut app = MiddlewareBuilder::new(modify_session);
203 app.add(Middleware::new());
204 app.add(SessionMiddleware::new("dirty", test_key(), false));
205 let response = app.call(&mut req).unwrap();
206
207 assert!(response.headers().get(header::SET_COOKIE).is_some());
208
209 fn read_session(req: &mut dyn RequestExt) -> HttpResult {
210 req.session();
211 Response::builder().body(Body::empty())
212 }
213 fn modify_session(req: &mut dyn RequestExt) -> HttpResult {
214 req.session_mut();
215 Response::builder().body(Body::empty())
216 }
217 }
218}