conrad_core/
auth.rs

1use crate::{
2    database::{
3        CreateKeyError, CreateSessionError, CreateUserError, DatabaseAdapter, KeyError, KeySchema,
4        SessionData, SessionError, SessionSchema, UserError,
5    },
6    errors::AuthError,
7    request::Request,
8    utils, Key, KeyTimestamp, KeyType, NaiveKeyType, Session, SessionId, SessionState, User,
9    UserData, UserId, ValidationSuccess,
10};
11use cookie::{time::OffsetDateTime, Cookie, CookieJar};
12use futures::{stream, StreamExt, TryStreamExt};
13use http::{HeaderMap, Method};
14use std::{marker::PhantomData, time::Duration};
15use tokio::join;
16use tracing::{debug, error, info, warn};
17use url::Url;
18use uuid::Uuid;
19
20const SESSION_COOKIE_NAME: &str = "auth_session";
21
22pub struct AuthenticatorBuilder<D> {
23    adapter: D,
24    generate_custom_user_id: fn() -> UserId,
25    auto_database_cleanup: bool,
26}
27
28impl<D> AuthenticatorBuilder<D> {
29    pub fn new(adapter: D) -> Self {
30        Self {
31            adapter,
32            generate_custom_user_id: || UserId::new(Uuid::new_v4().to_string()),
33            auto_database_cleanup: true,
34        }
35    }
36
37    pub fn set_user_id_generator<F>(self, function: fn() -> UserId) -> Self {
38        Self {
39            generate_custom_user_id: function,
40            ..self
41        }
42    }
43
44    pub fn enable_auto_database_cleanup(self, enable: bool) -> Self {
45        Self {
46            auto_database_cleanup: enable,
47            ..self
48        }
49    }
50
51    pub fn build<U>(self) -> Authenticator<D, U> {
52        Authenticator {
53            adapter: self.adapter,
54            generate_custom_user_id: self.generate_custom_user_id,
55            auto_database_cleanup: self.auto_database_cleanup,
56            _user_attributes: PhantomData::default(),
57        }
58    }
59}
60
61pub struct Authenticator<D, U> {
62    adapter: D,
63    generate_custom_user_id: fn() -> UserId,
64    auto_database_cleanup: bool,
65    _user_attributes: PhantomData<U>,
66}
67
68impl<D, U> Authenticator<D, U>
69where
70    D: DatabaseAdapter<U>,
71{
72    pub async fn create_user(&self, data: UserData, attributes: U) -> Result<User<U>, AuthError> {
73        let user_id = (self.generate_custom_user_id)();
74        let key_id = format!("{}:{}", data.provider_id, data.provider_user_id);
75        let hashed_password = if let Some(password) = data.password {
76            Some(utils::hash_password(password).await)
77        } else {
78            None
79        };
80        let res = self
81            .adapter
82            .create_user_and_key(
83                &attributes,
84                &KeySchema {
85                    id: key_id,
86                    user_id: user_id.clone(),
87                    hashed_password,
88                    primary_key: true,
89                    expires: None,
90                },
91            )
92            .await;
93        match res {
94            Err(CreateUserError::DatabaseError(err)) => Err(AuthError::DatabaseError(err)),
95            Err(CreateUserError::UserAlreadyExists) => {
96                let attributes = self.get_user(&user_id).await?;
97                Ok(User {
98                    user_id,
99                    user_attributes: attributes,
100                })
101            }
102            Ok(()) => Ok(User {
103                user_id,
104                user_attributes: attributes,
105            }),
106        }
107    }
108
109    pub async fn use_key(
110        &self,
111        provider_id: &str,
112        provider_user_id: &str,
113        password: Option<String>,
114    ) -> Result<Key, AuthError> {
115        let key_id = format!("{provider_id}:{provider_user_id}");
116        let database_key_data = self
117            .adapter
118            .read_key(&key_id)
119            .await
120            .map_err(|err| match err {
121                KeyError::DatabaseError(err) => AuthError::DatabaseError(err),
122                KeyError::KeyDoesNotExist => {
123                    error!(key_id, "key not found");
124                    AuthError::InvalidKeyId
125                }
126            })?;
127        let single_use = database_key_data.expires.filter(|&expires| expires != 0);
128        let hashed_password = database_key_data.hashed_password.clone();
129        if let Some(hashed_password) = hashed_password {
130            info!("key includes password");
131            if let Some(password) = password {
132                if password.is_empty() || hashed_password.is_empty() {
133                    return Err(AuthError::InvalidPassword);
134                }
135                if hashed_password.starts_with("$2a") {
136                    return Err(AuthError::OutdatedPassword);
137                }
138                let valid_password =
139                    utils::validate_password(password.clone(), hashed_password).await;
140                if !valid_password {
141                    error!(password, "incorrect key password");
142                    return Err(AuthError::InvalidPassword);
143                }
144            } else {
145                error!(key_id, "key password not provided");
146                return Err(AuthError::InvalidPassword);
147            }
148            warn!("validated key password");
149        } else {
150            info!("no password included in key");
151        }
152        if let Some(expires) = single_use {
153            info!("key type: single-use");
154            let within_expiration = utils::is_within_expiration(expires);
155            if !within_expiration {
156                error!(key_id, "key expired at {}", expires);
157                return Err(AuthError::ExpiredKey);
158            }
159            self.adapter
160                .delete_non_primary_key(&database_key_data.id)
161                .await?;
162        } else {
163            info!("key type: persistent");
164        }
165        info!(key_id, "validated key");
166        Ok(database_key_data.into())
167    }
168
169    pub async fn create_session(&self, user_id: UserId) -> Result<Session, AuthError> {
170        let session_info = generate_session_id();
171        let session_schema = SessionSchema {
172            session_data: session_info,
173            user_id: user_id.clone(),
174        };
175        if self.auto_database_cleanup {
176            let (res, _) = join!(
177                self.adapter.create_session(&session_schema),
178                self.delete_dead_user_sessions(&user_id)
179            );
180            res
181        } else {
182            self.adapter.create_session(&session_schema).await
183        }
184        .map_err(|err| match err {
185            CreateSessionError::DatabaseError(err) => AuthError::DatabaseError(err),
186            CreateSessionError::DuplicateSessionId => AuthError::DuplicateSessionId,
187            CreateSessionError::InvalidUserId => AuthError::InvalidUserId,
188        })?;
189        Ok(Session {
190            active_period_expires_at: session_schema.session_data.active_period_expires_at,
191            session_id: session_schema.session_data.session_id,
192            idle_period_expires_at: session_schema.session_data.idle_period_expires_at,
193            state: SessionState::Active,
194            fresh: true,
195        })
196    }
197
198    pub async fn invalidate_session(&self, session_id: &str) -> Result<(), AuthError> {
199        self.adapter.delete_session(session_id).await?;
200        info!(session_id, "invalidated session");
201        Ok(())
202    }
203
204    pub(crate) async fn validate_session_user(
205        &self,
206        session_id: &str,
207    ) -> Result<ValidationSuccess<U>, AuthError> {
208        let info = self.get_session_user(session_id).await?;
209        if info.session.state == SessionState::Active {
210            info!(info.session.session_id, "validated session");
211            Ok(info)
212        } else {
213            let renewed_session = self.get_session(session_id, true).await?;
214            Ok(ValidationSuccess {
215                session: renewed_session,
216                ..info
217            })
218        }
219    }
220
221    async fn get_session_user(&self, session_id: &str) -> Result<ValidationSuccess<U>, AuthError> {
222        if Uuid::try_parse(session_id).is_err() {
223            error!(session_id, "session id is invalid");
224            return Err(AuthError::InvalidSessionId);
225        }
226        let database_user_session = self
227            .adapter
228            .read_session_and_user_by_session_id(session_id)
229            .await
230            .map_err(|err| match err {
231                SessionError::SessionNotFound => {
232                    error!(session_id, "session not found");
233                    AuthError::InvalidSessionId
234                }
235                SessionError::DatabaseError(err) => AuthError::DatabaseError(err),
236            })?;
237        let database_user = database_user_session.user;
238        let session_data = database_user_session.session;
239        let session = utils::validate_database_session(session_data.session_data.clone());
240        if let Some(session) = session {
241            Ok(ValidationSuccess {
242                session,
243                user: database_user,
244            })
245        } else {
246            error!(
247                session_id,
248                "session expired at {}", session_data.session_data.idle_period_expires_at
249            );
250            if self.auto_database_cleanup {
251                self.adapter
252                    .delete_session(&session_data.session_data.session_id)
253                    .await?;
254            }
255            Err(AuthError::InvalidSessionId)
256        }
257    }
258
259    async fn get_session(&self, session_id: &str, renew: bool) -> Result<Session, AuthError> {
260        if Uuid::try_parse(session_id).is_err() {
261            error!(session_id, "session id is invalid");
262            return Err(AuthError::InvalidSessionId);
263        }
264        let database_session =
265            self.adapter
266                .read_session(session_id)
267                .await
268                .map_err(|err| match err {
269                    SessionError::DatabaseError(err) => AuthError::DatabaseError(err),
270                    SessionError::SessionNotFound => {
271                        error!(session_id, "session not found");
272                        AuthError::InvalidSessionId
273                    }
274                })?;
275        let idle_expires = database_session.session_data.idle_period_expires_at;
276        let session = utils::validate_database_session(database_session.session_data);
277        if let Some(session) = session {
278            if renew {
279                let user_id = database_session.user_id;
280                let renewed_session = if self.auto_database_cleanup {
281                    let (renewed_session, _) = join!(
282                        self.create_session(user_id.clone()),
283                        self.delete_dead_user_sessions(&user_id)
284                    );
285                    renewed_session
286                } else {
287                    self.create_session(user_id).await
288                }?;
289                info!(renewed_session.session_id, "session renewed");
290                Ok(renewed_session)
291            } else {
292                Ok(session)
293            }
294        } else {
295            error!(session_id, "session expired at {}", idle_expires);
296            if self.auto_database_cleanup {
297                self.adapter.delete_session(session_id).await?;
298            }
299            Err(AuthError::InvalidSessionId)
300        }
301    }
302
303    pub async fn get_user(&self, user_id: &UserId) -> Result<U, AuthError> {
304        Ok(self
305            .adapter
306            .read_user(user_id)
307            .await
308            .map_err(|err| match err {
309                UserError::DatabaseError(err) => AuthError::DatabaseError(err),
310                UserError::UserDoesNotExist => AuthError::InvalidUserId,
311            })?
312            .user_attributes)
313    }
314
315    pub fn handle_request<'a>(
316        &'a self,
317        cookies: &CookieJar,
318        method: &Method,
319        headers: &HeaderMap,
320        origin_url: &Url,
321    ) -> Request<'a, D, U> {
322        Request::new(self, cookies, method, headers, origin_url)
323    }
324
325    async fn delete_dead_user_sessions(&self, user_id: &UserId) -> Result<(), AuthError> {
326        let database_sessions = self.adapter.read_sessions(user_id).await?;
327        let dead_session_ids = database_sessions.into_iter().filter_map(|s| {
328            if utils::is_within_expiration(s.session_data.idle_period_expires_at) {
329                None
330            } else {
331                Some(s.session_data.session_id)
332            }
333        });
334        stream::iter(dead_session_ids)
335            .map(|id| async move { self.adapter.delete_session(&id).await })
336            .buffer_unordered(10)
337            .try_collect()
338            .await?;
339        Ok(())
340    }
341
342    pub async fn update_user_attributes(
343        &self,
344        user_id: &UserId,
345        attributes: U,
346    ) -> Result<(), AuthError> {
347        if self.auto_database_cleanup {
348            let (res, _) = join!(
349                self.adapter.update_user(user_id, &attributes),
350                self.delete_dead_user_sessions(user_id)
351            );
352            res
353        } else {
354            self.adapter.update_user(user_id, &attributes).await
355        }
356        .map_err(|err| match err {
357            UserError::DatabaseError(err) => AuthError::DatabaseError(err),
358            UserError::UserDoesNotExist => AuthError::InvalidUserId,
359        })
360    }
361
362    pub async fn invalidate_all_user_sessions(&self, user_id: &UserId) -> Result<(), AuthError> {
363        Ok(self.adapter.delete_sessions_by_user_id(user_id).await?)
364    }
365
366    pub async fn delete_user(&self, user_id: UserId) -> Result<(), AuthError> {
367        self.adapter.delete_sessions_by_user_id(&user_id).await?;
368        self.adapter.delete_keys(&user_id).await?;
369        Ok(self.adapter.delete_user(&user_id).await?)
370    }
371
372    pub async fn create_key(
373        &self,
374        user_id: UserId,
375        user_data: UserData,
376        key_type: &NaiveKeyType,
377    ) -> Result<Key, AuthError> {
378        let key_id = format!("{}:{}", user_data.provider_id, user_data.provider_user_id);
379        let hashed_password = if let Some(password) = user_data.password.clone() {
380            Some(utils::hash_password(password).await)
381        } else {
382            None
383        };
384        let key_type = if let NaiveKeyType::SingleUse { expires_in } = key_type {
385            let expires_at = get_one_time_key_expiration(expires_in.get_timestamp());
386            self.adapter
387                .create_key(&KeySchema {
388                    id: key_id,
389                    hashed_password,
390                    user_id: user_id.clone(),
391                    primary_key: false,
392                    expires: Some(expires_at),
393                })
394                .await
395                .map_err(|err| match err {
396                    CreateKeyError::DatabaseError(err) => AuthError::DatabaseError(err),
397                    CreateKeyError::KeyAlreadyExists => AuthError::DuplicateKeyId,
398                    CreateKeyError::UserDoesNotExist => AuthError::InvalidUserId,
399                })?;
400            KeyType::SingleUse {
401                expires_in: expires_at.into(),
402            }
403        } else {
404            self.adapter
405                .create_key(&KeySchema {
406                    id: key_id,
407                    hashed_password,
408                    user_id: user_id.clone(),
409                    primary_key: false,
410                    expires: None,
411                })
412                .await
413                .map_err(|err| match err {
414                    CreateKeyError::DatabaseError(err) => AuthError::DatabaseError(err),
415                    CreateKeyError::KeyAlreadyExists => AuthError::DuplicateKeyId,
416                    CreateKeyError::UserDoesNotExist => AuthError::InvalidUserId,
417                })?;
418            KeyType::Persistent { primary: false }
419        };
420        Ok(Key {
421            key_type,
422            password_defined: if let Some(password) = user_data.password {
423                !password.is_empty()
424            } else {
425                false
426            },
427            user_id,
428            provider_id: user_data.provider_id,
429            provider_user_id: user_data.provider_user_id,
430        })
431    }
432
433    pub async fn get_key(
434        &self,
435        provider_id: &str,
436        provider_user_id: &str,
437    ) -> Result<Key, AuthError> {
438        let key_id = format!("{provider_id}:{provider_user_id}");
439        let database_key = self
440            .adapter
441            .read_key(&key_id)
442            .await
443            .map_err(|err| match err {
444                KeyError::DatabaseError(err) => AuthError::DatabaseError(err),
445                KeyError::KeyDoesNotExist => AuthError::InvalidKeyId,
446            })?;
447        Ok(database_key.into())
448    }
449
450    pub async fn get_all_user_keys(&self, user_id: &UserId) -> Result<Vec<Key>, AuthError> {
451        let database_data = self.adapter.read_keys_by_user_id(user_id).await?;
452        Ok(database_data
453            .into_iter()
454            .map(std::convert::Into::into)
455            .collect())
456    }
457
458    pub async fn update_key_password(&self, data: UserData) -> Result<(), AuthError> {
459        let key_id = format!("{}:{}", data.provider_id, data.provider_user_id);
460        if let Some(password) = data.password {
461            let hashed_password = utils::hash_password(password).await;
462            self.adapter
463                .update_key_password(&key_id, Some(&hashed_password))
464                .await
465        } else {
466            self.adapter.update_key_password(&key_id, None).await
467        }
468        .map_err(|err| match err {
469            KeyError::DatabaseError(err) => AuthError::DatabaseError(err),
470            KeyError::KeyDoesNotExist => AuthError::InvalidKeyId,
471        })
472    }
473
474    pub async fn delete_key(
475        &self,
476        provider_id: &str,
477        provider_user_id: &str,
478    ) -> Result<(), AuthError> {
479        let key_id = format!("{provider_id}:{provider_user_id}");
480        Ok(self.adapter.delete_non_primary_key(&key_id).await?)
481    }
482}
483
484#[must_use]
485pub fn parse_request_headers<'c>(
486    cookies: &'c CookieJar,
487    method: &Method,
488    headers: &HeaderMap,
489    origin_url: &Url,
490) -> Option<SessionId<'c>> {
491    debug!("{}, {}", method, origin_url);
492    let session_id: Option<SessionId> = cookies.get(SESSION_COOKIE_NAME).map(|c| c.value().into());
493    if let Some(session_id) = &session_id {
494        info!("found session cookie: {}", session_id.as_str());
495    } else {
496        info!("no session cookie found");
497    }
498    let csrf_check = method != Method::GET && method != Method::HEAD;
499    if csrf_check {
500        let request_origin = headers.get("origin");
501        match request_origin {
502            Some(request_origin) => {
503                if let Ok(request_origin) = request_origin.to_str() {
504                    if origin_url.as_str() != request_origin {
505                        error!(request_origin, "invalid request origin");
506                        return None;
507                    }
508                } else {
509                    error!("invalid origin string: {:?}", request_origin);
510                    return None;
511                }
512                info!("valid request origin: {:?}", request_origin);
513            }
514            None => {
515                error!("no request origin available");
516                return None;
517            }
518        }
519    }
520    session_id
521}
522
523fn generate_session_id() -> SessionData {
524    const ACTIVE_PERIOD: u64 = 1000 * 60 * 60 * 24;
525    const IDLE_PERIOD: u64 = 1000 * 60 * 60 * 24 * 14;
526    let session_id = Uuid::new_v4().to_string();
527    let active_period_expires_at = OffsetDateTime::now_utc() + Duration::from_millis(ACTIVE_PERIOD);
528    let idle_period_expires_at = active_period_expires_at + Duration::from_millis(IDLE_PERIOD);
529    SessionData {
530        active_period_expires_at: active_period_expires_at.unix_timestamp(),
531        idle_period_expires_at: idle_period_expires_at.unix_timestamp(),
532        session_id,
533    }
534}
535
536fn get_one_time_key_expiration(duration: i64) -> i64 {
537    assert!(duration >= 0, "duration cannot be negative");
538    (OffsetDateTime::now_utc() + Duration::from_millis(duration as u64 * 1000 * 1000))
539        .unix_timestamp()
540}
541
542#[must_use]
543pub fn create_session_cookie<'c>(session: Option<Session>) -> Cookie<'c> {
544    if let Some(session) = session {
545        Cookie::build(SESSION_COOKIE_NAME, session.session_id)
546            .same_site(cookie::SameSite::Lax)
547            .path("/")
548            .http_only(true)
549            .expires(OffsetDateTime::from_unix_timestamp(session.idle_period_expires_at).unwrap())
550            .secure(true)
551            .finish()
552    } else {
553        Cookie::build(SESSION_COOKIE_NAME, "")
554            .same_site(cookie::SameSite::Lax)
555            .path("/")
556            .http_only(true)
557            .expires(OffsetDateTime::UNIX_EPOCH)
558            .secure(true)
559            .finish()
560    }
561}
562
563impl KeyTimestamp {
564    #[must_use]
565    pub fn get_timestamp(&self) -> i64 {
566        self.0
567    }
568
569    #[must_use]
570    pub fn is_expired(&self) -> bool {
571        !utils::is_within_expiration(self.get_timestamp())
572    }
573}
574
575impl From<i64> for KeyTimestamp {
576    fn from(value: i64) -> Self {
577        Self(value)
578    }
579}
580
581impl From<KeySchema> for Key {
582    fn from(database_key: KeySchema) -> Self {
583        let user_id = database_key.user_id;
584        let is_password_defined = if let Some(hashed_password) = database_key.hashed_password {
585            !hashed_password.is_empty()
586        } else {
587            false
588        };
589        let (provider_id, provider_user_id) = database_key.id.split_once(':').unwrap();
590        let key_type = if let Some(expires) = database_key.expires {
591            KeyType::SingleUse {
592                expires_in: expires.into(),
593            }
594        } else {
595            KeyType::Persistent {
596                primary: database_key.primary_key,
597            }
598        };
599        Self {
600            key_type,
601            password_defined: is_password_defined,
602            user_id,
603            provider_id: provider_id.to_string(),
604            provider_user_id: provider_user_id.to_string(),
605        }
606    }
607}
608
#[cfg(test)]
mod tests {
    use crate::{database::SessionData, utils::validate_database_session, Session, SessionState};
    use cookie::time::OffsetDateTime;

    /// Unix timestamp `offset_secs` seconds from now (negative values point to the past).
    fn timestamp_from_now(offset_secs: i64) -> i64 {
        OffsetDateTime::now_utc().unix_timestamp() + offset_secs
    }

    /// Session data whose active/idle periods end at the given offsets from now.
    fn session_data(active_offset_secs: i64, idle_offset_secs: i64) -> SessionData {
        SessionData {
            active_period_expires_at: timestamp_from_now(active_offset_secs),
            idle_period_expires_at: timestamp_from_now(idle_offset_secs),
            session_id: String::new(),
        }
    }

    #[test]
    fn validate_database_session_returns_none_if_dead_state() {
        let output = validate_database_session(session_data(0, -10));
        assert!(output.is_none());
    }

    #[test]
    fn validate_database_session_returns_idle_session_if_idle_state() {
        let output = validate_database_session(session_data(-10, 10));
        assert!(matches!(
            output,
            Some(Session {
                state: SessionState::Idle,
                ..
            })
        ));
    }

    #[test]
    fn validate_database_session_returns_active_session_if_active_state() {
        let output = validate_database_session(session_data(10, 10));
        assert!(matches!(
            output,
            Some(Session {
                state: SessionState::Active,
                ..
            })
        ));
    }
}