1use crate::{
2 database::{
3 CreateKeyError, CreateSessionError, CreateUserError, DatabaseAdapter, KeyError, KeySchema,
4 SessionData, SessionError, SessionSchema, UserError,
5 },
6 errors::AuthError,
7 request::Request,
8 utils, Key, KeyTimestamp, KeyType, NaiveKeyType, Session, SessionId, SessionState, User,
9 UserData, UserId, ValidationSuccess,
10};
11use cookie::{time::OffsetDateTime, Cookie, CookieJar};
12use futures::{stream, StreamExt, TryStreamExt};
13use http::{HeaderMap, Method};
14use std::{marker::PhantomData, time::Duration};
15use tokio::join;
16use tracing::{debug, error, info, warn};
17use url::Url;
18use uuid::Uuid;
19
/// Name of the cookie that carries the session id between requests; read by
/// `parse_request_headers` and written by `create_session_cookie`.
const SESSION_COOKIE_NAME: &str = "auth_session";
21
22pub struct AuthenticatorBuilder<D> {
23 adapter: D,
24 generate_custom_user_id: fn() -> UserId,
25 auto_database_cleanup: bool,
26}
27
28impl<D> AuthenticatorBuilder<D> {
29 pub fn new(adapter: D) -> Self {
30 Self {
31 adapter,
32 generate_custom_user_id: || UserId::new(Uuid::new_v4().to_string()),
33 auto_database_cleanup: true,
34 }
35 }
36
37 pub fn set_user_id_generator<F>(self, function: fn() -> UserId) -> Self {
38 Self {
39 generate_custom_user_id: function,
40 ..self
41 }
42 }
43
44 pub fn enable_auto_database_cleanup(self, enable: bool) -> Self {
45 Self {
46 auto_database_cleanup: enable,
47 ..self
48 }
49 }
50
51 pub fn build<U>(self) -> Authenticator<D, U> {
52 Authenticator {
53 adapter: self.adapter,
54 generate_custom_user_id: self.generate_custom_user_id,
55 auto_database_cleanup: self.auto_database_cleanup,
56 _user_attributes: PhantomData::default(),
57 }
58 }
59}
60
61pub struct Authenticator<D, U> {
62 adapter: D,
63 generate_custom_user_id: fn() -> UserId,
64 auto_database_cleanup: bool,
65 _user_attributes: PhantomData<U>,
66}
67
68impl<D, U> Authenticator<D, U>
69where
70 D: DatabaseAdapter<U>,
71{
72 pub async fn create_user(&self, data: UserData, attributes: U) -> Result<User<U>, AuthError> {
73 let user_id = (self.generate_custom_user_id)();
74 let key_id = format!("{}:{}", data.provider_id, data.provider_user_id);
75 let hashed_password = if let Some(password) = data.password {
76 Some(utils::hash_password(password).await)
77 } else {
78 None
79 };
80 let res = self
81 .adapter
82 .create_user_and_key(
83 &attributes,
84 &KeySchema {
85 id: key_id,
86 user_id: user_id.clone(),
87 hashed_password,
88 primary_key: true,
89 expires: None,
90 },
91 )
92 .await;
93 match res {
94 Err(CreateUserError::DatabaseError(err)) => Err(AuthError::DatabaseError(err)),
95 Err(CreateUserError::UserAlreadyExists) => {
96 let attributes = self.get_user(&user_id).await?;
97 Ok(User {
98 user_id,
99 user_attributes: attributes,
100 })
101 }
102 Ok(()) => Ok(User {
103 user_id,
104 user_attributes: attributes,
105 }),
106 }
107 }
108
109 pub async fn use_key(
110 &self,
111 provider_id: &str,
112 provider_user_id: &str,
113 password: Option<String>,
114 ) -> Result<Key, AuthError> {
115 let key_id = format!("{provider_id}:{provider_user_id}");
116 let database_key_data = self
117 .adapter
118 .read_key(&key_id)
119 .await
120 .map_err(|err| match err {
121 KeyError::DatabaseError(err) => AuthError::DatabaseError(err),
122 KeyError::KeyDoesNotExist => {
123 error!(key_id, "key not found");
124 AuthError::InvalidKeyId
125 }
126 })?;
127 let single_use = database_key_data.expires.filter(|&expires| expires != 0);
128 let hashed_password = database_key_data.hashed_password.clone();
129 if let Some(hashed_password) = hashed_password {
130 info!("key includes password");
131 if let Some(password) = password {
132 if password.is_empty() || hashed_password.is_empty() {
133 return Err(AuthError::InvalidPassword);
134 }
135 if hashed_password.starts_with("$2a") {
136 return Err(AuthError::OutdatedPassword);
137 }
138 let valid_password =
139 utils::validate_password(password.clone(), hashed_password).await;
140 if !valid_password {
141 error!(password, "incorrect key password");
142 return Err(AuthError::InvalidPassword);
143 }
144 } else {
145 error!(key_id, "key password not provided");
146 return Err(AuthError::InvalidPassword);
147 }
148 warn!("validated key password");
149 } else {
150 info!("no password included in key");
151 }
152 if let Some(expires) = single_use {
153 info!("key type: single-use");
154 let within_expiration = utils::is_within_expiration(expires);
155 if !within_expiration {
156 error!(key_id, "key expired at {}", expires);
157 return Err(AuthError::ExpiredKey);
158 }
159 self.adapter
160 .delete_non_primary_key(&database_key_data.id)
161 .await?;
162 } else {
163 info!("key type: persistent");
164 }
165 info!(key_id, "validated key");
166 Ok(database_key_data.into())
167 }
168
169 pub async fn create_session(&self, user_id: UserId) -> Result<Session, AuthError> {
170 let session_info = generate_session_id();
171 let session_schema = SessionSchema {
172 session_data: session_info,
173 user_id: user_id.clone(),
174 };
175 if self.auto_database_cleanup {
176 let (res, _) = join!(
177 self.adapter.create_session(&session_schema),
178 self.delete_dead_user_sessions(&user_id)
179 );
180 res
181 } else {
182 self.adapter.create_session(&session_schema).await
183 }
184 .map_err(|err| match err {
185 CreateSessionError::DatabaseError(err) => AuthError::DatabaseError(err),
186 CreateSessionError::DuplicateSessionId => AuthError::DuplicateSessionId,
187 CreateSessionError::InvalidUserId => AuthError::InvalidUserId,
188 })?;
189 Ok(Session {
190 active_period_expires_at: session_schema.session_data.active_period_expires_at,
191 session_id: session_schema.session_data.session_id,
192 idle_period_expires_at: session_schema.session_data.idle_period_expires_at,
193 state: SessionState::Active,
194 fresh: true,
195 })
196 }
197
198 pub async fn invalidate_session(&self, session_id: &str) -> Result<(), AuthError> {
199 self.adapter.delete_session(session_id).await?;
200 info!(session_id, "invalidated session");
201 Ok(())
202 }
203
204 pub(crate) async fn validate_session_user(
205 &self,
206 session_id: &str,
207 ) -> Result<ValidationSuccess<U>, AuthError> {
208 let info = self.get_session_user(session_id).await?;
209 if info.session.state == SessionState::Active {
210 info!(info.session.session_id, "validated session");
211 Ok(info)
212 } else {
213 let renewed_session = self.get_session(session_id, true).await?;
214 Ok(ValidationSuccess {
215 session: renewed_session,
216 ..info
217 })
218 }
219 }
220
221 async fn get_session_user(&self, session_id: &str) -> Result<ValidationSuccess<U>, AuthError> {
222 if Uuid::try_parse(session_id).is_err() {
223 error!(session_id, "session id is invalid");
224 return Err(AuthError::InvalidSessionId);
225 }
226 let database_user_session = self
227 .adapter
228 .read_session_and_user_by_session_id(session_id)
229 .await
230 .map_err(|err| match err {
231 SessionError::SessionNotFound => {
232 error!(session_id, "session not found");
233 AuthError::InvalidSessionId
234 }
235 SessionError::DatabaseError(err) => AuthError::DatabaseError(err),
236 })?;
237 let database_user = database_user_session.user;
238 let session_data = database_user_session.session;
239 let session = utils::validate_database_session(session_data.session_data.clone());
240 if let Some(session) = session {
241 Ok(ValidationSuccess {
242 session,
243 user: database_user,
244 })
245 } else {
246 error!(
247 session_id,
248 "session expired at {}", session_data.session_data.idle_period_expires_at
249 );
250 if self.auto_database_cleanup {
251 self.adapter
252 .delete_session(&session_data.session_data.session_id)
253 .await?;
254 }
255 Err(AuthError::InvalidSessionId)
256 }
257 }
258
259 async fn get_session(&self, session_id: &str, renew: bool) -> Result<Session, AuthError> {
260 if Uuid::try_parse(session_id).is_err() {
261 error!(session_id, "session id is invalid");
262 return Err(AuthError::InvalidSessionId);
263 }
264 let database_session =
265 self.adapter
266 .read_session(session_id)
267 .await
268 .map_err(|err| match err {
269 SessionError::DatabaseError(err) => AuthError::DatabaseError(err),
270 SessionError::SessionNotFound => {
271 error!(session_id, "session not found");
272 AuthError::InvalidSessionId
273 }
274 })?;
275 let idle_expires = database_session.session_data.idle_period_expires_at;
276 let session = utils::validate_database_session(database_session.session_data);
277 if let Some(session) = session {
278 if renew {
279 let user_id = database_session.user_id;
280 let renewed_session = if self.auto_database_cleanup {
281 let (renewed_session, _) = join!(
282 self.create_session(user_id.clone()),
283 self.delete_dead_user_sessions(&user_id)
284 );
285 renewed_session
286 } else {
287 self.create_session(user_id).await
288 }?;
289 info!(renewed_session.session_id, "session renewed");
290 Ok(renewed_session)
291 } else {
292 Ok(session)
293 }
294 } else {
295 error!(session_id, "session expired at {}", idle_expires);
296 if self.auto_database_cleanup {
297 self.adapter.delete_session(session_id).await?;
298 }
299 Err(AuthError::InvalidSessionId)
300 }
301 }
302
303 pub async fn get_user(&self, user_id: &UserId) -> Result<U, AuthError> {
304 Ok(self
305 .adapter
306 .read_user(user_id)
307 .await
308 .map_err(|err| match err {
309 UserError::DatabaseError(err) => AuthError::DatabaseError(err),
310 UserError::UserDoesNotExist => AuthError::InvalidUserId,
311 })?
312 .user_attributes)
313 }
314
315 pub fn handle_request<'a>(
316 &'a self,
317 cookies: &CookieJar,
318 method: &Method,
319 headers: &HeaderMap,
320 origin_url: &Url,
321 ) -> Request<'a, D, U> {
322 Request::new(self, cookies, method, headers, origin_url)
323 }
324
325 async fn delete_dead_user_sessions(&self, user_id: &UserId) -> Result<(), AuthError> {
326 let database_sessions = self.adapter.read_sessions(user_id).await?;
327 let dead_session_ids = database_sessions.into_iter().filter_map(|s| {
328 if utils::is_within_expiration(s.session_data.idle_period_expires_at) {
329 None
330 } else {
331 Some(s.session_data.session_id)
332 }
333 });
334 stream::iter(dead_session_ids)
335 .map(|id| async move { self.adapter.delete_session(&id).await })
336 .buffer_unordered(10)
337 .try_collect()
338 .await?;
339 Ok(())
340 }
341
342 pub async fn update_user_attributes(
343 &self,
344 user_id: &UserId,
345 attributes: U,
346 ) -> Result<(), AuthError> {
347 if self.auto_database_cleanup {
348 let (res, _) = join!(
349 self.adapter.update_user(user_id, &attributes),
350 self.delete_dead_user_sessions(user_id)
351 );
352 res
353 } else {
354 self.adapter.update_user(user_id, &attributes).await
355 }
356 .map_err(|err| match err {
357 UserError::DatabaseError(err) => AuthError::DatabaseError(err),
358 UserError::UserDoesNotExist => AuthError::InvalidUserId,
359 })
360 }
361
362 pub async fn invalidate_all_user_sessions(&self, user_id: &UserId) -> Result<(), AuthError> {
363 Ok(self.adapter.delete_sessions_by_user_id(user_id).await?)
364 }
365
366 pub async fn delete_user(&self, user_id: UserId) -> Result<(), AuthError> {
367 self.adapter.delete_sessions_by_user_id(&user_id).await?;
368 self.adapter.delete_keys(&user_id).await?;
369 Ok(self.adapter.delete_user(&user_id).await?)
370 }
371
372 pub async fn create_key(
373 &self,
374 user_id: UserId,
375 user_data: UserData,
376 key_type: &NaiveKeyType,
377 ) -> Result<Key, AuthError> {
378 let key_id = format!("{}:{}", user_data.provider_id, user_data.provider_user_id);
379 let hashed_password = if let Some(password) = user_data.password.clone() {
380 Some(utils::hash_password(password).await)
381 } else {
382 None
383 };
384 let key_type = if let NaiveKeyType::SingleUse { expires_in } = key_type {
385 let expires_at = get_one_time_key_expiration(expires_in.get_timestamp());
386 self.adapter
387 .create_key(&KeySchema {
388 id: key_id,
389 hashed_password,
390 user_id: user_id.clone(),
391 primary_key: false,
392 expires: Some(expires_at),
393 })
394 .await
395 .map_err(|err| match err {
396 CreateKeyError::DatabaseError(err) => AuthError::DatabaseError(err),
397 CreateKeyError::KeyAlreadyExists => AuthError::DuplicateKeyId,
398 CreateKeyError::UserDoesNotExist => AuthError::InvalidUserId,
399 })?;
400 KeyType::SingleUse {
401 expires_in: expires_at.into(),
402 }
403 } else {
404 self.adapter
405 .create_key(&KeySchema {
406 id: key_id,
407 hashed_password,
408 user_id: user_id.clone(),
409 primary_key: false,
410 expires: None,
411 })
412 .await
413 .map_err(|err| match err {
414 CreateKeyError::DatabaseError(err) => AuthError::DatabaseError(err),
415 CreateKeyError::KeyAlreadyExists => AuthError::DuplicateKeyId,
416 CreateKeyError::UserDoesNotExist => AuthError::InvalidUserId,
417 })?;
418 KeyType::Persistent { primary: false }
419 };
420 Ok(Key {
421 key_type,
422 password_defined: if let Some(password) = user_data.password {
423 !password.is_empty()
424 } else {
425 false
426 },
427 user_id,
428 provider_id: user_data.provider_id,
429 provider_user_id: user_data.provider_user_id,
430 })
431 }
432
433 pub async fn get_key(
434 &self,
435 provider_id: &str,
436 provider_user_id: &str,
437 ) -> Result<Key, AuthError> {
438 let key_id = format!("{provider_id}:{provider_user_id}");
439 let database_key = self
440 .adapter
441 .read_key(&key_id)
442 .await
443 .map_err(|err| match err {
444 KeyError::DatabaseError(err) => AuthError::DatabaseError(err),
445 KeyError::KeyDoesNotExist => AuthError::InvalidKeyId,
446 })?;
447 Ok(database_key.into())
448 }
449
450 pub async fn get_all_user_keys(&self, user_id: &UserId) -> Result<Vec<Key>, AuthError> {
451 let database_data = self.adapter.read_keys_by_user_id(user_id).await?;
452 Ok(database_data
453 .into_iter()
454 .map(std::convert::Into::into)
455 .collect())
456 }
457
458 pub async fn update_key_password(&self, data: UserData) -> Result<(), AuthError> {
459 let key_id = format!("{}:{}", data.provider_id, data.provider_user_id);
460 if let Some(password) = data.password {
461 let hashed_password = utils::hash_password(password).await;
462 self.adapter
463 .update_key_password(&key_id, Some(&hashed_password))
464 .await
465 } else {
466 self.adapter.update_key_password(&key_id, None).await
467 }
468 .map_err(|err| match err {
469 KeyError::DatabaseError(err) => AuthError::DatabaseError(err),
470 KeyError::KeyDoesNotExist => AuthError::InvalidKeyId,
471 })
472 }
473
474 pub async fn delete_key(
475 &self,
476 provider_id: &str,
477 provider_user_id: &str,
478 ) -> Result<(), AuthError> {
479 let key_id = format!("{provider_id}:{provider_user_id}");
480 Ok(self.adapter.delete_non_primary_key(&key_id).await?)
481 }
482}
483
484#[must_use]
485pub fn parse_request_headers<'c>(
486 cookies: &'c CookieJar,
487 method: &Method,
488 headers: &HeaderMap,
489 origin_url: &Url,
490) -> Option<SessionId<'c>> {
491 debug!("{}, {}", method, origin_url);
492 let session_id: Option<SessionId> = cookies.get(SESSION_COOKIE_NAME).map(|c| c.value().into());
493 if let Some(session_id) = &session_id {
494 info!("found session cookie: {}", session_id.as_str());
495 } else {
496 info!("no session cookie found");
497 }
498 let csrf_check = method != Method::GET && method != Method::HEAD;
499 if csrf_check {
500 let request_origin = headers.get("origin");
501 match request_origin {
502 Some(request_origin) => {
503 if let Ok(request_origin) = request_origin.to_str() {
504 if origin_url.as_str() != request_origin {
505 error!(request_origin, "invalid request origin");
506 return None;
507 }
508 } else {
509 error!("invalid origin string: {:?}", request_origin);
510 return None;
511 }
512 info!("valid request origin: {:?}", request_origin);
513 }
514 None => {
515 error!("no request origin available");
516 return None;
517 }
518 }
519 }
520 session_id
521}
522
523fn generate_session_id() -> SessionData {
524 const ACTIVE_PERIOD: u64 = 1000 * 60 * 60 * 24;
525 const IDLE_PERIOD: u64 = 1000 * 60 * 60 * 24 * 14;
526 let session_id = Uuid::new_v4().to_string();
527 let active_period_expires_at = OffsetDateTime::now_utc() + Duration::from_millis(ACTIVE_PERIOD);
528 let idle_period_expires_at = active_period_expires_at + Duration::from_millis(IDLE_PERIOD);
529 SessionData {
530 active_period_expires_at: active_period_expires_at.unix_timestamp(),
531 idle_period_expires_at: idle_period_expires_at.unix_timestamp(),
532 session_id,
533 }
534}
535
536fn get_one_time_key_expiration(duration: i64) -> i64 {
537 assert!(duration >= 0, "duration cannot be negative");
538 (OffsetDateTime::now_utc() + Duration::from_millis(duration as u64 * 1000 * 1000))
539 .unix_timestamp()
540}
541
542#[must_use]
543pub fn create_session_cookie<'c>(session: Option<Session>) -> Cookie<'c> {
544 if let Some(session) = session {
545 Cookie::build(SESSION_COOKIE_NAME, session.session_id)
546 .same_site(cookie::SameSite::Lax)
547 .path("/")
548 .http_only(true)
549 .expires(OffsetDateTime::from_unix_timestamp(session.idle_period_expires_at).unwrap())
550 .secure(true)
551 .finish()
552 } else {
553 Cookie::build(SESSION_COOKIE_NAME, "")
554 .same_site(cookie::SameSite::Lax)
555 .path("/")
556 .http_only(true)
557 .expires(OffsetDateTime::UNIX_EPOCH)
558 .secure(true)
559 .finish()
560 }
561}
562
563impl KeyTimestamp {
564 #[must_use]
565 pub fn get_timestamp(&self) -> i64 {
566 self.0
567 }
568
569 #[must_use]
570 pub fn is_expired(&self) -> bool {
571 !utils::is_within_expiration(self.get_timestamp())
572 }
573}
574
575impl From<i64> for KeyTimestamp {
576 fn from(value: i64) -> Self {
577 Self(value)
578 }
579}
580
581impl From<KeySchema> for Key {
582 fn from(database_key: KeySchema) -> Self {
583 let user_id = database_key.user_id;
584 let is_password_defined = if let Some(hashed_password) = database_key.hashed_password {
585 !hashed_password.is_empty()
586 } else {
587 false
588 };
589 let (provider_id, provider_user_id) = database_key.id.split_once(':').unwrap();
590 let key_type = if let Some(expires) = database_key.expires {
591 KeyType::SingleUse {
592 expires_in: expires.into(),
593 }
594 } else {
595 KeyType::Persistent {
596 primary: database_key.primary_key,
597 }
598 };
599 Self {
600 key_type,
601 password_defined: is_password_defined,
602 user_id,
603 provider_id: provider_id.to_string(),
604 provider_user_id: provider_user_id.to_string(),
605 }
606 }
607}
608
#[cfg(test)]
mod tests {
    use crate::{database::SessionData, utils::validate_database_session, Session, SessionState};
    use cookie::time::OffsetDateTime;
    use std::time::Duration;

    /// Distance from "now" used by every past/future fixture timestamp.
    const SHIFT: Duration = Duration::from_secs(10);

    /// Current time as a unix timestamp.
    fn now_ts() -> i64 {
        OffsetDateTime::now_utc().unix_timestamp()
    }

    /// Unix timestamp `SHIFT` before now.
    fn past_ts() -> i64 {
        (OffsetDateTime::now_utc() - SHIFT).unix_timestamp()
    }

    /// Unix timestamp `SHIFT` after now.
    fn future_ts() -> i64 {
        (OffsetDateTime::now_utc() + SHIFT).unix_timestamp()
    }

    /// Session fixture with an empty id and the given period ends.
    fn fixture(active_period_expires_at: i64, idle_period_expires_at: i64) -> SessionData {
        SessionData {
            active_period_expires_at,
            idle_period_expires_at,
            session_id: String::new(),
        }
    }

    #[test]
    fn validate_database_session_returns_none_if_dead_state() {
        let output = validate_database_session(fixture(now_ts(), past_ts()));
        assert!(output.is_none());
    }

    #[test]
    fn validate_database_session_returns_idle_session_if_idle_state() {
        let output = validate_database_session(fixture(past_ts(), future_ts()));
        assert!(matches!(
            output,
            Some(Session {
                state: SessionState::Idle,
                ..
            })
        ));
    }

    #[test]
    fn validate_database_session_returns_active_session_if_active_state() {
        let output = validate_database_session(fixture(future_ts(), future_ts()));
        assert!(matches!(
            output,
            Some(Session {
                state: SessionState::Active,
                ..
            })
        ));
    }
}