rwf/controller/
auth.rs

1//! Authentication system.
2//!
3//! Made to be easily extendable. Users need only to implement the [`Authentication`] trait
4//! and set it on their controller. Rwf also comes with several built-in authentication mechanisms that
5//! can be used out of the box.
6use super::Error;
7use crate::comms::WebsocketSender;
8use crate::config::get_config;
9use crate::http::{Authorization, Request, Response};
10use crate::view::{ToTemplateValue, Value};
11
12use async_trait::async_trait;
13use serde::{Deserialize, Serialize};
14use time::{Duration, OffsetDateTime};
15
16use std::collections::HashMap;
17use std::fmt::Debug;
18use std::sync::Arc;
19
20/// An authentication mechanism wrapper that can be attached to a controller.
21#[derive(Clone)]
22pub struct AuthHandler {
23    auth: Arc<Box<dyn Authentication>>,
24}
25
26impl Default for AuthHandler {
27    fn default() -> Self {
28        Self::new(AllowAll {})
29    }
30}
31
32impl AuthHandler {
33    /// Create new authentication mechanism using the provided authentication method.
34    pub fn new(auth: impl Authentication + 'static) -> Self {
35        AuthHandler {
36            auth: Arc::new(Box::new(auth)),
37        }
38    }
39
40    /// Get the authentication method.
41    pub fn auth(&self) -> &Box<dyn Authentication> {
42        &self.auth
43    }
44}
45
46/// Authenticators need to implement this trait.
47#[async_trait]
48#[allow(unused_variables)]
49pub trait Authentication: Sync + Send {
50    /// Perform the authentication and allow or deny the request from
51    /// going forward.
52    async fn authorize(&self, request: &Request) -> Result<bool, Error>;
53
54    /// If the request is denied, return a specific response.
55    /// Default is `403 - Forbidden`.
56    async fn denied(&self, request: &Request) -> Result<Response, Error> {
57        Ok(Response::forbidden())
58    }
59
60    /// Returns an authentication handler used when configuring
61    /// authentication on a controller.
62    fn handler(self) -> AuthHandler
63    where
64        Self: Sized + 'static,
65    {
66        AuthHandler::new(self)
67    }
68}
69
70/// Allow all requests. This is the default authentication method for all controllers.
71pub struct AllowAll;
72
73#[async_trait]
74impl Authentication for AllowAll {
75    async fn authorize(&self, _request: &Request) -> Result<bool, Error> {
76        Ok(true)
77    }
78}
79
80/// Deny all requests.
81///
82/// Not particularly useful, since there is no way to override it,
83/// but it is included to demonstrate how authentication works.
84pub struct DenyAll;
85
86#[async_trait]
87impl Authentication for DenyAll {
88    async fn authorize(&self, _request: &Request) -> Result<bool, Error> {
89        Ok(false)
90    }
91}
92
93/// HTTP Basic authentication.
94pub struct BasicAuth {
95    /// Username.
96    pub user: String,
97    /// Password.
98    pub password: String,
99}
100
101#[async_trait]
102impl Authentication for BasicAuth {
103    async fn authorize(&self, request: &Request) -> Result<bool, Error> {
104        Ok(
105            if let Some(Authorization::Basic { user, password }) = request.authorization() {
106                self.user == user && self.password == password
107            } else {
108                false
109            },
110        )
111    }
112
113    async fn denied(&self, _request: &Request) -> Result<Response, Error> {
114        Ok(Response::unauthorized("Basic"))
115    }
116}
117
118/// Static token authentication (basically a passphrase).
119///
120/// Not very secure since the token can leak, but helpful if you need
121/// to quickly protect an endpoint.
122pub struct Token {
123    /// A token string.
124    pub token: String,
125}
126
127#[async_trait]
128impl Authentication for Token {
129    async fn authorize(&self, request: &Request) -> Result<bool, Error> {
130        Ok(
131            if let Some(Authorization::Token { token }) = request.authorization() {
132                self.token == token
133            } else {
134                false
135            },
136        )
137    }
138}
139
140/// Type of session provided by the client in the request.
141#[derive(Clone, Debug, Serialize, Deserialize, Hash, PartialEq, Eq)]
142pub enum SessionId {
143    /// Guest user. All visitors are given a guest session.
144    Guest(String),
145    /// Authenticated user. This user has passed an authentication challenge, e.g. username and password.
146    Authenticated(i64),
147}
148
149impl SessionId {
150    /// The session is authenticated, i.e. it's a user.
151    pub fn authenticated(&self) -> bool {
152        use SessionId::*;
153
154        match self {
155            Guest(_) => false,
156            Authenticated(_) => true,
157        }
158    }
159
160    /// The session is a guest session, i.e. anonymous, not logged in.
161    pub fn guest(&self) -> bool {
162        !self.authenticated()
163    }
164
165    /// Get the user's ID. This is an arbitrary integer, but
166    /// should ideally be the primary key of a `"users"` table, if such exists.
167    pub fn user_id(&self) -> Option<i64> {
168        match self {
169            SessionId::Authenticated(id) => Some(*id),
170            _ => None,
171        }
172    }
173}
174
175impl std::fmt::Display for SessionId {
176    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
177        match self {
178            SessionId::Authenticated(id) => write!(f, "{}", id),
179            SessionId::Guest(id) => write!(f, "{}", id),
180        }
181    }
182}
183
184impl Default for SessionId {
185    fn default() -> Self {
186        use rand::{distributions::Alphanumeric, thread_rng, Rng};
187
188        SessionId::Guest(
189            thread_rng()
190                .sample_iter(&Alphanumeric)
191                .take(16)
192                .map(char::from)
193                .collect::<String>(),
194        )
195    }
196}
197
198/// A client's session.
199///
200/// This is a JSON-encoded object
201/// that's stored securely in a cookie (using encryption).
202#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
203pub struct Session {
204    /// Customizable session payload.
205    #[serde(rename = "p")]
206    pub payload: serde_json::Value,
207    /// Session expiration (UNIX timestamp in UTC).
208    #[serde(rename = "e")]
209    pub expiration: i64,
210    /// Type of session, e.g. guest or user.
211    #[serde(rename = "s")]
212    pub session_id: SessionId,
213}
214
215impl Default for Session {
216    fn default() -> Self {
217        Self::new(serde_json::json!({})).expect("json")
218    }
219}
220
221impl ToTemplateValue for Session {
222    fn to_template_value(&self) -> Result<Value, crate::view::Error> {
223        let mut hash = HashMap::new();
224        hash.insert("expiration".into(), Value::Integer(self.expiration));
225        hash.insert(
226            "session_id".into(),
227            Value::String(self.session_id.to_string()),
228        );
229        hash.insert(
230            "payload".into(),
231            Value::String(serde_json::to_string(&self.payload).unwrap()),
232        );
233
234        Ok(Value::Hash(hash))
235    }
236}
237
238impl Session {
239    /// Create a guest session.
240    pub fn anonymous() -> Self {
241        Self::default()
242    }
243
244    /// Alias for creating a guest session.
245    pub fn empty() -> Self {
246        Self::default()
247    }
248
249    /// Create new session with this payload. This creates a guest session.
250    pub fn new(payload: impl Serialize) -> Result<Self, Error> {
251        Ok(Self {
252            payload: serde_json::to_value(payload)?,
253            expiration: (OffsetDateTime::now_utc() + get_config().general.session_duration())
254                .unix_timestamp(),
255            session_id: SessionId::default(),
256        })
257    }
258
259    /// Create new session with this payload, authenticated to a particular user.
260    pub fn new_authenticated(payload: impl Serialize, user_id: i64) -> Result<Self, Error> {
261        let mut session = Self::new(payload)?;
262        session.session_id = SessionId::Authenticated(user_id);
263
264        Ok(session)
265    }
266
267    /// Renew the session for the specified duration.
268    pub fn renew(mut self, renew_for: Duration) -> Self {
269        self.expiration = (OffsetDateTime::now_utc() + renew_for).unix_timestamp();
270        self
271    }
272
273    /// The session is close to being expired and should be renewed automatically.
274    pub fn should_renew(&self) -> bool {
275        if let Ok(expiration) = OffsetDateTime::from_unix_timestamp(self.expiration) {
276            let now = OffsetDateTime::now_utc();
277            let remains = expiration - now;
278            let session_duration = get_config().general.session_duration();
279            remains < session_duration / 2 && remains.is_positive() // not expired
280        } else {
281            true
282        }
283    }
284
285    /// Check if the session has expired.
286    pub fn expired(&self) -> bool {
287        if let Ok(expiration) = OffsetDateTime::from_unix_timestamp(self.expiration) {
288            let now = OffsetDateTime::now_utc();
289            expiration < now
290        } else {
291            false
292        }
293    }
294
295    /// Get a Websocket sender for this session. This allows to send arbitray messages
296    /// to all browsers connected with this session.
297    pub fn websocket(&self) -> WebsocketSender {
298        use crate::comms::Comms;
299        Comms::websocket(&self.session_id)
300    }
301
302    /// This session is authenticated to a user and hasn't expired.
303    pub fn authenticated(&self) -> bool {
304        !self.expired() && self.session_id.authenticated()
305    }
306
307    /// This is a guest session.
308    pub fn guest(&self) -> bool {
309        !self.expired() && self.session_id.guest()
310    }
311}
312
/// Session authentication.
///
/// The default value has no redirect configured.
#[derive(Default)]
pub struct SessionAuth {
    /// URL that denied requests are redirected to, if any.
    redirect: Option<String>,
}

impl SessionAuth {
    /// Create session authentication which redirects denied requests to this
    /// URL instead of just returning `403 - Forbidden`.
    pub fn redirect(url: impl ToString) -> Self {
        let target = url.to_string();

        Self {
            redirect: Some(target),
        }
    }
}
328
329#[async_trait]
330impl Authentication for SessionAuth {
331    async fn authorize(&self, request: &Request) -> Result<bool, Error> {
332        Ok(request.session().authenticated())
333    }
334
335    async fn denied(&self, _request: &Request) -> Result<Response, Error> {
336        if let Some(ref redirect) = self.redirect {
337            Ok(Response::new().redirect(redirect))
338        } else {
339            Ok(Response::forbidden())
340        }
341    }
342}
343
#[cfg(test)]
mod test {
    use super::*;

    /// A session is due for renewal only once less than half of the
    /// session duration remains.
    #[test]
    fn test_should_renew() {
        let half_duration = Duration::weeks(2);
        let margin = Duration::seconds(5);

        // A brand new session has its full duration left.
        let mut session = Session::default();
        assert!(!session.should_renew());

        // The offsets below assume the default four-week session duration.
        assert_eq!(get_config().general.session_duration(), Duration::weeks(4));

        // Just under half the duration remains: renew.
        let just_under_half = OffsetDateTime::now_utc() + half_duration - margin;
        session.expiration = just_under_half.unix_timestamp();
        assert!(session.should_renew());

        // Just over half the duration remains: don't renew yet.
        let just_over_half = OffsetDateTime::now_utc() + half_duration + margin;
        session.expiration = just_over_half.unix_timestamp();
        assert!(!session.should_renew());
    }
}