1use super::Error;
7use crate::comms::WebsocketSender;
8use crate::config::get_config;
9use crate::http::{Authorization, Request, Response};
10use crate::view::{ToTemplateValue, Value};
11
12use async_trait::async_trait;
13use serde::{Deserialize, Serialize};
14use time::{Duration, OffsetDateTime};
15
16use std::collections::HashMap;
17use std::fmt::Debug;
18use std::sync::Arc;
19
20#[derive(Clone)]
22pub struct AuthHandler {
23 auth: Arc<Box<dyn Authentication>>,
24}
25
26impl Default for AuthHandler {
27 fn default() -> Self {
28 Self::new(AllowAll {})
29 }
30}
31
32impl AuthHandler {
33 pub fn new(auth: impl Authentication + 'static) -> Self {
35 AuthHandler {
36 auth: Arc::new(Box::new(auth)),
37 }
38 }
39
40 pub fn auth(&self) -> &Box<dyn Authentication> {
42 &self.auth
43 }
44}
45
46#[async_trait]
48#[allow(unused_variables)]
49pub trait Authentication: Sync + Send {
50 async fn authorize(&self, request: &Request) -> Result<bool, Error>;
53
54 async fn denied(&self, request: &Request) -> Result<Response, Error> {
57 Ok(Response::forbidden())
58 }
59
60 fn handler(self) -> AuthHandler
63 where
64 Self: Sized + 'static,
65 {
66 AuthHandler::new(self)
67 }
68}
69
70pub struct AllowAll;
72
73#[async_trait]
74impl Authentication for AllowAll {
75 async fn authorize(&self, _request: &Request) -> Result<bool, Error> {
76 Ok(true)
77 }
78}
79
80pub struct DenyAll;
85
86#[async_trait]
87impl Authentication for DenyAll {
88 async fn authorize(&self, _request: &Request) -> Result<bool, Error> {
89 Ok(false)
90 }
91}
92
93pub struct BasicAuth {
95 pub user: String,
97 pub password: String,
99}
100
101#[async_trait]
102impl Authentication for BasicAuth {
103 async fn authorize(&self, request: &Request) -> Result<bool, Error> {
104 Ok(
105 if let Some(Authorization::Basic { user, password }) = request.authorization() {
106 self.user == user && self.password == password
107 } else {
108 false
109 },
110 )
111 }
112
113 async fn denied(&self, _request: &Request) -> Result<Response, Error> {
114 Ok(Response::unauthorized("Basic"))
115 }
116}
117
118pub struct Token {
123 pub token: String,
125}
126
127#[async_trait]
128impl Authentication for Token {
129 async fn authorize(&self, request: &Request) -> Result<bool, Error> {
130 Ok(
131 if let Some(Authorization::Token { token }) = request.authorization() {
132 self.token == token
133 } else {
134 false
135 },
136 )
137 }
138}
139
140#[derive(Clone, Debug, Serialize, Deserialize, Hash, PartialEq, Eq)]
142pub enum SessionId {
143 Guest(String),
145 Authenticated(i64),
147}
148
149impl SessionId {
150 pub fn authenticated(&self) -> bool {
152 use SessionId::*;
153
154 match self {
155 Guest(_) => false,
156 Authenticated(_) => true,
157 }
158 }
159
160 pub fn guest(&self) -> bool {
162 !self.authenticated()
163 }
164
165 pub fn user_id(&self) -> Option<i64> {
168 match self {
169 SessionId::Authenticated(id) => Some(*id),
170 _ => None,
171 }
172 }
173}
174
175impl std::fmt::Display for SessionId {
176 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
177 match self {
178 SessionId::Authenticated(id) => write!(f, "{}", id),
179 SessionId::Guest(id) => write!(f, "{}", id),
180 }
181 }
182}
183
184impl Default for SessionId {
185 fn default() -> Self {
186 use rand::{distributions::Alphanumeric, thread_rng, Rng};
187
188 SessionId::Guest(
189 thread_rng()
190 .sample_iter(&Alphanumeric)
191 .take(16)
192 .map(char::from)
193 .collect::<String>(),
194 )
195 }
196}
197
198#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
203pub struct Session {
204 #[serde(rename = "p")]
206 pub payload: serde_json::Value,
207 #[serde(rename = "e")]
209 pub expiration: i64,
210 #[serde(rename = "s")]
212 pub session_id: SessionId,
213}
214
215impl Default for Session {
216 fn default() -> Self {
217 Self::new(serde_json::json!({})).expect("json")
218 }
219}
220
221impl ToTemplateValue for Session {
222 fn to_template_value(&self) -> Result<Value, crate::view::Error> {
223 let mut hash = HashMap::new();
224 hash.insert("expiration".into(), Value::Integer(self.expiration));
225 hash.insert(
226 "session_id".into(),
227 Value::String(self.session_id.to_string()),
228 );
229 hash.insert(
230 "payload".into(),
231 Value::String(serde_json::to_string(&self.payload).unwrap()),
232 );
233
234 Ok(Value::Hash(hash))
235 }
236}
237
238impl Session {
239 pub fn anonymous() -> Self {
241 Self::default()
242 }
243
244 pub fn empty() -> Self {
246 Self::default()
247 }
248
249 pub fn new(payload: impl Serialize) -> Result<Self, Error> {
251 Ok(Self {
252 payload: serde_json::to_value(payload)?,
253 expiration: (OffsetDateTime::now_utc() + get_config().general.session_duration())
254 .unix_timestamp(),
255 session_id: SessionId::default(),
256 })
257 }
258
259 pub fn new_authenticated(payload: impl Serialize, user_id: i64) -> Result<Self, Error> {
261 let mut session = Self::new(payload)?;
262 session.session_id = SessionId::Authenticated(user_id);
263
264 Ok(session)
265 }
266
267 pub fn renew(mut self, renew_for: Duration) -> Self {
269 self.expiration = (OffsetDateTime::now_utc() + renew_for).unix_timestamp();
270 self
271 }
272
273 pub fn should_renew(&self) -> bool {
275 if let Ok(expiration) = OffsetDateTime::from_unix_timestamp(self.expiration) {
276 let now = OffsetDateTime::now_utc();
277 let remains = expiration - now;
278 let session_duration = get_config().general.session_duration();
279 remains < session_duration / 2 && remains.is_positive() } else {
281 true
282 }
283 }
284
285 pub fn expired(&self) -> bool {
287 if let Ok(expiration) = OffsetDateTime::from_unix_timestamp(self.expiration) {
288 let now = OffsetDateTime::now_utc();
289 expiration < now
290 } else {
291 false
292 }
293 }
294
295 pub fn websocket(&self) -> WebsocketSender {
298 use crate::comms::Comms;
299 Comms::websocket(&self.session_id)
300 }
301
302 pub fn authenticated(&self) -> bool {
304 !self.expired() && self.session_id.authenticated()
305 }
306
307 pub fn guest(&self) -> bool {
309 !self.expired() && self.session_id.guest()
310 }
311}
312
313#[derive(Default)]
315pub struct SessionAuth {
316 redirect: Option<String>,
317}
318
319impl SessionAuth {
320 pub fn redirect(url: impl ToString) -> Self {
323 Self {
324 redirect: Some(url.to_string()),
325 }
326 }
327}
328
329#[async_trait]
330impl Authentication for SessionAuth {
331 async fn authorize(&self, request: &Request) -> Result<bool, Error> {
332 Ok(request.session().authenticated())
333 }
334
335 async fn denied(&self, _request: &Request) -> Result<Response, Error> {
336 if let Some(ref redirect) = self.redirect {
337 Ok(Response::new().redirect(redirect))
338 } else {
339 Ok(Response::forbidden())
340 }
341 }
342}
343
#[cfg(test)]
mod test {
    use super::*;

    /// Unix timestamp lying `offset` away from the current time.
    fn from_now(offset: Duration) -> i64 {
        (OffsetDateTime::now_utc() + offset).unix_timestamp()
    }

    #[test]
    fn test_should_renew() {
        let half_life = Duration::weeks(2);
        let margin = Duration::seconds(5);

        // A brand-new session has its whole lifetime ahead of it.
        let mut session = Session::default();
        assert!(!session.should_renew());

        // The offsets below assume a four-week session lifetime.
        assert_eq!(get_config().general.session_duration(), Duration::weeks(4));

        // Just under half of the lifetime remaining: time to renew.
        session.expiration = from_now(half_life - margin);
        assert!(session.should_renew());

        // Just over half of the lifetime remaining: leave it alone.
        session.expiration = from_now(half_life + margin);
        assert!(!session.should_renew());
    }
}
365}