#[cfg(feature = "db")]
pub mod db;
use std::any::Any;
use std::borrow::Cow;
use std::fmt::Debug;
use std::sync::Arc;
use async_trait::async_trait;
use chrono::{DateTime, FixedOffset};
#[cfg(test)]
use mockall::automock;
use password_auth::VerifyError;
use serde::{Deserialize, Serialize};
use subtle::ConstantTimeEq;
use thiserror::Error;
use crate::config::SecretKey;
#[cfg(feature = "db")]
use crate::db::{ColumnType, DatabaseField, DbValue, FromDbValue, SqlxValueRef, ToDbValue};
use crate::request::{Request, RequestExt};
/// Errors that can occur while authenticating users or while reading and
/// writing authentication state in the session.
#[derive(Debug, Error)]
pub enum AuthError {
    /// A password hash was rejected: it is longer than the maximum accepted
    /// length, or `password_auth` could not parse it.
    #[error("Password hash is invalid")]
    PasswordHashInvalid,
    /// Reading from or writing to the session failed.
    #[error("Error while accessing the session object")]
    SessionAccess(#[from] tower_sessions::session::Error),
    /// An error reported by an [`AuthBackend`] implementation; build it with
    /// [`AuthError::backend_error`].
    #[error("Error while accessing the user object")]
    UserBackend(#[source] Box<dyn std::error::Error + Send + Sync + 'static>),
    /// The (type-erased) credentials passed to [`AuthBackend::authenticate`]
    /// are of a type the backend does not handle.
    #[error("Tried to authenticate with an unsupported credentials type")]
    CredentialsTypeNotSupported,
    /// The [`UserId`] variant passed to [`AuthBackend::get_by_id`] is not
    /// one the backend handles.
    #[error("Tried to get a user by an unsupported user ID type")]
    UserIdTypeNotSupported,
}
impl AuthError {
pub fn backend_error(error: impl std::error::Error + Send + Sync + 'static) -> Self {
Self::UserBackend(Box::new(error))
}
}
/// Shorthand for results whose error type is [`AuthError`].
pub type Result<T> = std::result::Result<T, AuthError>;
/// A user of the application, as seen by the authentication system.
///
/// Every method has a default implementation describing an anonymous user
/// (no ID, no username, inactive, unauthenticated, no timestamps, no session
/// hash), so implementors only override what they actually support.
#[cfg_attr(test, automock)]
pub trait User {
    /// Returns the user's identifier, or `None` if the user has none.
    ///
    /// [`AuthRequestExt::login`] stores this ID in the session; it is later
    /// handed to [`AuthBackend::get_by_id`] to reload the user.
    fn id(&self) -> Option<UserId> {
        None
    }

    /// Returns the user's name, or `None` if the user has none.
    // NOTE(review): the explicit lifetime is presumably kept because
    // `automock` needs it — confirm before simplifying.
    #[allow(clippy::needless_lifetimes)]
    fn username<'a>(&'a self) -> Option<Cow<'a, str>> {
        None
    }

    /// Returns whether the user is active. Defaults to `false`.
    fn is_active(&self) -> bool {
        false
    }

    /// Returns whether the user is authenticated. Defaults to `false`.
    fn is_authenticated(&self) -> bool {
        false
    }

    /// Returns the time of the user's last login, if known.
    fn last_login(&self) -> Option<DateTime<FixedOffset>> {
        None
    }

    /// Returns the time the user joined, if known.
    fn joined(&self) -> Option<DateTime<FixedOffset>> {
        None
    }

    /// Returns a hash derived with `secret_key` that is stored in the
    /// session at login and compared whenever the user is reloaded from the
    /// session.
    ///
    /// If the stored hash matches neither the hash for the current secret
    /// key nor the hash for any fallback key, the user is logged out.
    /// Returning `None` disables this check for the user.
    #[allow(unused_variables)]
    fn session_auth_hash(&self, secret_key: &SecretKey) -> Option<SessionAuthHash> {
        None
    }
}
/// Identifier of a user: either an integer or a string.
///
/// Serialized with `#[serde(untagged)]`, so it is written as a bare integer
/// or string (this is the form stored in the session by
/// [`AuthRequestExt::login`]).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(untagged)]
pub enum UserId {
    /// A numeric identifier.
    Int(i64),
    /// A textual identifier.
    String(String),
}
impl UserId {
    /// Returns the numeric ID, or `None` when this is a string ID.
    #[must_use]
    pub fn as_int(&self) -> Option<i64> {
        if let Self::Int(value) = self {
            Some(*value)
        } else {
            None
        }
    }

    /// Returns the string ID as a `&str`, or `None` when this is a numeric
    /// ID.
    #[must_use]
    pub fn as_string(&self) -> Option<&str> {
        if let Self::String(value) = self {
            Some(value.as_str())
        } else {
            None
        }
    }
}
/// A user that is not logged in.
///
/// It relies entirely on the default [`User`] methods: no ID or username,
/// neither active nor authenticated, no timestamps and no session hash.
///
/// The type has no fields, so all instances are equal. `PartialEq`, `Eq` and
/// `Hash` are derived, which also lets it be used as a map or set key.
#[derive(Debug, Copy, Clone, Default, PartialEq, Eq, Hash)]
pub struct AnonymousUser();
// Every `User` method keeps its default (anonymous) implementation.
impl User for AnonymousUser {}
/// A hash stored in the session at login and compared against the user's
/// freshly computed hash whenever the user is reloaded from the session.
///
/// Equality is checked in constant time and the `Debug` output masks the
/// bytes.
#[repr(transparent)]
#[derive(Clone)]
pub struct SessionAuthHash(Box<[u8]>);

impl SessionAuthHash {
    /// Creates a hash by copying `hash` into an owned buffer.
    #[must_use]
    pub fn new(hash: &[u8]) -> Self {
        let owned: Box<[u8]> = hash.into();
        Self(owned)
    }

    /// Borrows the raw hash bytes.
    #[must_use]
    pub fn as_bytes(&self) -> &[u8] {
        &self.0[..]
    }

    /// Consumes the hash and returns its raw bytes.
    #[must_use]
    pub fn into_bytes(self) -> Box<[u8]> {
        let Self(bytes) = self;
        bytes
    }
}

impl From<&[u8]> for SessionAuthHash {
    fn from(hash: &[u8]) -> Self {
        SessionAuthHash::new(hash)
    }
}
impl PartialEq for SessionAuthHash {
    /// Compares the hash bytes using `subtle::ConstantTimeEq` rather than a
    /// plain slice comparison.
    fn eq(&self, other: &Self) -> bool {
        let choice = self.0.ct_eq(&other.0);
        bool::from(choice)
    }
}

impl Eq for SessionAuthHash {}
impl Debug for SessionAuthHash {
    // The hash is secret material: always print a fixed mask, never the bytes.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut tuple = f.debug_tuple("SessionAuthHash");
        tuple.field(&"**********");
        tuple.finish()
    }
}
#[repr(transparent)]
#[derive(Clone)]
pub struct PasswordHash(String);
impl PasswordHash {
pub fn new<T: Into<String>>(hash: T) -> Result<Self> {
let hash = hash.into();
if hash.len() > MAX_PASSWORD_HASH_LENGTH as usize {
return Err(AuthError::PasswordHashInvalid);
}
password_auth::is_hash_obsolete(&hash).map_err(|_| AuthError::PasswordHashInvalid)?;
Ok(Self(hash))
}
#[must_use]
pub fn from_password(password: &Password) -> Self {
let hash = password_auth::generate_hash(password.as_str());
if hash.len() > MAX_PASSWORD_HASH_LENGTH as usize {
unreachable!("password hash should never exceed {MAX_PASSWORD_HASH_LENGTH} bytes");
}
Self(hash)
}
pub fn verify(&self, password: &Password) -> PasswordVerificationResult {
const VALID_ERROR_STR: &str = "password hash should always be valid if created with `PasswordHash::new` or `PasswordHash::from_password`";
match password_auth::verify_password(password.as_str(), &self.0) {
Ok(()) => {
let Ok(is_obsolete) = password_auth::is_hash_obsolete(&self.0) else {
unreachable!("{VALID_ERROR_STR}");
};
if is_obsolete {
PasswordVerificationResult::OkObsolete(PasswordHash::from_password(password))
} else {
PasswordVerificationResult::Ok
}
}
Err(error) => match error {
VerifyError::PasswordInvalid => PasswordVerificationResult::Invalid,
VerifyError::Parse(_) => unreachable!("{VALID_ERROR_STR}"),
},
}
}
#[must_use]
pub fn as_str(&self) -> &str {
&self.0
}
#[must_use]
pub fn into_string(self) -> String {
self.0
}
}
impl TryFrom<String> for PasswordHash {
type Error = AuthError;
fn try_from(value: String) -> std::result::Result<Self, Self::Error> {
Self::new(value)
}
}
/// Outcome of [`PasswordHash::verify`].
#[derive(Debug, Clone)]
#[must_use]
pub enum PasswordVerificationResult {
    /// The password matches the hash.
    Ok,
    /// The password matches, but `password_auth` reports the stored hash as
    /// obsolete. The contained value is a freshly generated hash of the same
    /// password, which callers may want to persist in place of the old one.
    OkObsolete(PasswordHash),
    /// The password does not match the hash.
    Invalid,
}
impl Debug for PasswordHash {
    /// Prints at most the first 10 bytes of the hash, followed by a mask, so
    /// the salt and digest never reach logs. For a hash such as
    /// `$argon2id$...`, these 10 bytes are the algorithm prefix.
    ///
    /// The prefix is taken with a checked slice. A string shorter than
    /// 10 bytes, or one whose 10th byte is not on a `char` boundary, prints
    /// only the mask rather than panicking inside `Debug`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const VISIBLE_PREFIX_LEN: usize = 10;
        let prefix = self.0.get(..VISIBLE_PREFIX_LEN).unwrap_or("");
        f.debug_tuple("PasswordHash")
            .field(&format!("{prefix}**********"))
            .finish()
    }
}
/// Maximum accepted length of a password hash string, in bytes. With the
/// `db` feature it is also the length of the string column storing hashes.
const MAX_PASSWORD_HASH_LENGTH: u32 = 128;
// Stored as a string column sized to the maximum accepted hash length.
#[cfg(feature = "db")]
impl DatabaseField for PasswordHash {
    const TYPE: ColumnType = ColumnType::String(MAX_PASSWORD_HASH_LENGTH);
}
/// Decodes a stored hash string and re-validates it with
/// [`PasswordHash::new`].
///
/// A malformed value read from the database surfaces as a value-decode error
/// rather than producing an invalid `PasswordHash`. All backends use the same
/// `crate::db` paths; the original mixed `cot::db` and `crate::db` for the
/// same items.
#[cfg(feature = "db")]
impl FromDbValue for PasswordHash {
    #[cfg(feature = "sqlite")]
    fn from_sqlite(value: crate::db::impl_sqlite::SqliteValueRef<'_>) -> crate::db::Result<Self> {
        PasswordHash::new(value.get::<String>()?).map_err(crate::db::DatabaseError::value_decode)
    }

    #[cfg(feature = "postgres")]
    fn from_postgres(
        value: crate::db::impl_postgres::PostgresValueRef<'_>,
    ) -> crate::db::Result<Self> {
        PasswordHash::new(value.get::<String>()?).map_err(crate::db::DatabaseError::value_decode)
    }

    #[cfg(feature = "mysql")]
    fn from_mysql(value: crate::db::impl_mysql::MySqlValueRef<'_>) -> crate::db::Result<Self>
    where
        Self: Sized,
    {
        PasswordHash::new(value.get::<String>()?).map_err(crate::db::DatabaseError::value_decode)
    }
}
/// Stores the hash as its plain string form.
#[cfg(feature = "db")]
impl ToDbValue for PasswordHash {
    fn to_db_value(&self) -> DbValue {
        let owned: String = self.as_str().to_owned();
        owned.into()
    }
}
/// A plaintext password.
///
/// Its `Debug` output is masked so the password does not end up in logs.
#[derive(Clone)]
pub struct Password(String);

impl Debug for Password {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut tuple = f.debug_tuple("Password");
        tuple.field(&"**********");
        tuple.finish()
    }
}

impl Password {
    /// Creates a password from anything convertible into a `String`.
    #[must_use]
    pub fn new<T: Into<String>>(password: T) -> Self {
        let value: String = password.into();
        Self(value)
    }

    /// Borrows the plaintext password.
    #[must_use]
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }

    /// Consumes the wrapper and returns the plaintext password.
    #[must_use]
    pub fn into_string(self) -> String {
        let Self(value) = self;
        value
    }
}

impl From<&Password> for Password {
    fn from(password: &Password) -> Self {
        Password::clone(password)
    }
}

impl From<&str> for Password {
    fn from(password: &str) -> Self {
        Password::new(password)
    }
}

impl From<String> for Password {
    fn from(password: String) -> Self {
        Password::new(password)
    }
}
// Sealing module: `AuthRequestExt` requires `private::Sealed`, which cannot
// be named outside this module, so the extension trait cannot be
// implemented by other crates.
mod private {
    pub trait Sealed {}
}
/// Authentication helpers available on [`Request`].
///
/// This trait is sealed; it is implemented only for [`Request`].
#[async_trait]
pub trait AuthRequestExt: private::Sealed {
    /// Returns the current user.
    ///
    /// The user is cached in the request extensions. On the first call it
    /// is loaded through the auth backend, using the user ID stored in the
    /// session, and its session hash is validated. If no user can be loaded,
    /// the request is logged out and an [`AnonymousUser`] is returned.
    ///
    /// # Errors
    ///
    /// Returns an error if session access or the auth backend fails.
    async fn user(&mut self) -> Result<&dyn User>;

    /// Asks the configured [`AuthBackend`] to authenticate `credentials`.
    ///
    /// This does not log the user in; pass a returned user to
    /// [`AuthRequestExt::login`] for that.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the auth backend.
    async fn authenticate(
        &mut self,
        credentials: &(dyn Any + Send + Sync),
    ) -> Result<Option<Box<dyn User + Send + Sync>>>;

    /// Logs `user` in: writes the user's ID and session hash (when the user
    /// provides them) to the session and makes the user the current user
    /// for this request.
    ///
    /// # Errors
    ///
    /// Returns an error if writing to the session fails.
    async fn login(&mut self, user: Box<dyn User + Send + Sync + 'static>) -> Result<()>;

    /// Logs the current user out: removes the stored user ID and session
    /// hash from the session and replaces the current user with
    /// [`AnonymousUser`].
    ///
    /// # Errors
    ///
    /// Returns an error if writing to the session fails.
    async fn logout(&mut self) -> Result<()>;
}
/// Session key under which the logged-in user's [`UserId`] is stored.
const USER_ID_SESSION_KEY: &str = "__cot_auth_user_id";
/// Session key under which the logged-in user's [`SessionAuthHash`] bytes
/// are stored.
const SESSION_HASH_SESSION_KEY: &str = "__cot_auth_session_hash";
/// Type under which the current user is cached in the request extensions.
type UserExtension = Arc<dyn User + Send + Sync + 'static>;
// Only `Request` may implement `AuthRequestExt`.
impl private::Sealed for Request {}
#[async_trait]
impl AuthRequestExt for Request {
    async fn user(&mut self) -> Result<&dyn User> {
        // Load the user lazily, once per request; later calls reuse the cached
        // extension.
        if self.extensions().get::<UserExtension>().is_none() {
            if let Some(user) = get_user_with_saved_id(self).await? {
                self.extensions_mut().insert(UserExtension::from(user));
            } else {
                // No valid user in the session: clear stale auth data and cache
                // `AnonymousUser` as the current user.
                self.logout().await?;
            }
        }

        Ok(&**self
            .extensions()
            .get::<UserExtension>()
            .expect("User extension should have just been added"))
    }

    async fn authenticate(
        &mut self,
        credentials: &(dyn Any + Send + Sync),
    ) -> Result<Option<Box<dyn User + Send + Sync>>> {
        self.context()
            .auth_backend()
            .authenticate(self, credentials)
            .await
    }

    async fn login(&mut self, user: Box<dyn User + Send + Sync + 'static>) -> Result<()> {
        let user = UserExtension::from(user);

        // Always overwrite or clear both session keys so nothing from an earlier
        // login survives. Otherwise, logging in a user without an ID would keep
        // the previous user's ID in the session, and the next request would
        // load that previous user.
        match user.id() {
            Some(user_id) => {
                self.session_mut()
                    .insert(USER_ID_SESSION_KEY, user_id)
                    .await?;
            }
            None => {
                self.session_mut().remove_value(USER_ID_SESSION_KEY).await?;
            }
        }

        let secret_key = &self.project_config().secret_key;
        match user.session_auth_hash(secret_key) {
            Some(session_auth_hash) => {
                self.session_mut()
                    .insert(SESSION_HASH_SESSION_KEY, session_auth_hash.as_bytes())
                    .await?;
            }
            None => {
                self.session_mut()
                    .remove_value(SESSION_HASH_SESSION_KEY)
                    .await?;
            }
        }

        self.extensions_mut().insert(user);
        Ok(())
    }

    async fn logout(&mut self) -> Result<()> {
        self.session_mut().remove_value(USER_ID_SESSION_KEY).await?;
        self.session_mut()
            .remove_value(SESSION_HASH_SESSION_KEY)
            .await?;
        // Replace any cached user so later `user()` calls see an anonymous user.
        self.extensions_mut()
            .insert::<UserExtension>(Arc::new(AnonymousUser()));
        Ok(())
    }
}
/// Loads the user whose ID is stored in the session.
///
/// Returns `Ok(None)` in any of these cases:
/// - the session holds no user ID;
/// - the auth backend does not know that ID;
/// - the user's session hash no longer validates against the stored one.
async fn get_user_with_saved_id(
    request: &mut Request,
) -> Result<Option<Box<dyn User + Send + Sync>>> {
    let saved_id = request.session().get::<UserId>(USER_ID_SESSION_KEY).await?;
    let Some(user_id) = saved_id else {
        return Ok(None);
    };

    let backend_user = request
        .context()
        .auth_backend()
        .get_by_id(request, user_id)
        .await?;
    let Some(user) = backend_user else {
        return Ok(None);
    };

    let hash_is_valid = session_auth_hash_valid(&*user, request).await?;
    Ok(hash_is_valid.then_some(user))
}
/// Checks that the session hash stored in the session matches `user`.
///
/// Returns `Ok(true)` when any of these hold:
/// - the user provides no session hash, which disables the check;
/// - the stored hash matches the hash computed with the current secret key;
/// - the stored hash matches a hash computed with one of the fallback
///   secret keys. In that case the stored hash is replaced by the one for
///   the current key, so later requests match directly.
///
/// Returns `Ok(false)` when the user provides a session hash but the session
/// holds none, or when no key produces a matching hash. The caller treats
/// that as "not logged in".
///
/// # Errors
///
/// Returns [`AuthError::SessionAccess`] if reading or writing the session
/// fails.
async fn session_auth_hash_valid(
    user: &(dyn User + Send + Sync),
    request: &mut Request,
) -> Result<bool> {
    let config = request.project_config();
    let Some(user_hash) = user.session_auth_hash(&config.secret_key) else {
        return Ok(true);
    };

    // A missing stored hash reflects the session's contents (for example, a
    // session created before this user provided a hash), not a broken
    // invariant. Fail the check so the user is logged out, instead of
    // panicking while handling the request.
    let Some(stored_hash) = request
        .session()
        .get::<Vec<u8>>(SESSION_HASH_SESSION_KEY)
        .await?
    else {
        return Ok(false);
    };
    let stored_hash = SessionAuthHash::new(&stored_hash);

    if user_hash == stored_hash {
        return Ok(true);
    }

    // The stored hash may have been generated with an old secret key, so try
    // the fallback keys. A user returning no hash for a key simply does not
    // match with it. The check runs before touching the session so that the
    // borrow of `config` (taken from `request`) has ended by the time the
    // session is mutated below.
    let matches_fallback = config.fallback_secret_keys.iter().any(|fallback_key| {
        user.session_auth_hash(fallback_key).as_ref() == Some(&stored_hash)
    });
    if !matches_fallback {
        return Ok(false);
    }

    // Re-key the session with the hash for the current secret key.
    request
        .session_mut()
        .insert(SESSION_HASH_SESSION_KEY, user_hash.as_bytes())
        .await?;
    Ok(true)
}
/// A source of users: authenticates credentials and looks users up by ID.
///
/// The request helpers in [`AuthRequestExt`] reach the configured backend
/// through `request.context().auth_backend()`.
#[async_trait]
pub trait AuthBackend: Send + Sync {
    /// Authenticates `credentials` and returns the matching user, or
    /// `Ok(None)` if they identify no user.
    ///
    /// `credentials` is type-erased (`dyn Any`). Implementations presumably
    /// downcast it, and may return [`AuthError::CredentialsTypeNotSupported`]
    /// for types they do not handle.
    ///
    /// # Errors
    ///
    /// Returns an error if the backend fails.
    async fn authenticate(
        &self,
        request: &Request,
        credentials: &(dyn Any + Send + Sync),
    ) -> Result<Option<Box<dyn User + Send + Sync>>>;

    /// Returns the user with the given `id`, or `Ok(None)` if there is none.
    ///
    /// [`AuthRequestExt::user`] treats `Ok(None)` as "not logged in" and
    /// logs the request out.
    ///
    /// # Errors
    ///
    /// Returns an error if the backend fails.
    async fn get_by_id(
        &self,
        request: &Request,
        id: UserId,
    ) -> Result<Option<Box<dyn User + Send + Sync>>>;
}
/// An [`AuthBackend`] with no users: authentication never succeeds and no
/// user can be loaded by ID, so [`AuthRequestExt::user`] always ends up with
/// an [`AnonymousUser`].
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub struct NoAuthBackend;

#[async_trait]
impl AuthBackend for NoAuthBackend {
    /// Always returns `Ok(None)`.
    async fn authenticate(
        &self,
        _request: &Request,
        _credentials: &(dyn Any + Send + Sync),
    ) -> Result<Option<Box<dyn User + Send + Sync>>> {
        Ok(None)
    }

    /// Always returns `Ok(None)`.
    async fn get_by_id(
        &self,
        _request: &Request,
        _id: UserId,
    ) -> Result<Option<Box<dyn User + Send + Sync>>> {
        Ok(None)
    }
}
#[cfg(test)]
mod tests {
    use std::sync::Mutex;

    use mockall::predicate::eq;

    use super::*;
    use crate::config::ProjectConfig;
    use crate::test::TestRequestBuilder;

    /// Test backend that builds a fresh user with `return_user` for every
    /// `authenticate` and `get_by_id` call, regardless of the arguments.
    struct MockAuthBackend<F> {
        return_user: F,
    }

    #[async_trait]
    impl<F: Fn() -> MockUser + Send + Sync + 'static> AuthBackend for MockAuthBackend<F> {
        async fn authenticate(
            &self,
            _request: &Request,
            _credentials: &(dyn Any + Send + Sync),
        ) -> Result<Option<Box<dyn User + Send + Sync>>> {
            Ok(Some(Box::new((self.return_user)())))
        }

        async fn get_by_id(
            &self,
            _request: &Request,
            _id: UserId,
        ) -> Result<Option<Box<dyn User + Send + Sync>>> {
            Ok(Some(Box::new((self.return_user)())))
        }
    }

    const TEST_KEY_1: &[u8] = b"key1";
    const TEST_KEY_2: &[u8] = b"key2";
    const TEST_KEY_3: &[u8] = b"key3";

    /// Builds a request backed by a `MockAuthBackend` returning users created
    /// by `return_user`.
    fn test_request<T: Fn() -> MockUser + Send + Sync + 'static>(return_user: T) -> Request {
        test_request_with_auth_backend(MockAuthBackend { return_user })
    }

    /// Builds a request with a session, `TEST_KEY_1` as the secret key, no
    /// fallback keys, and the given auth backend.
    fn test_request_with_auth_backend<T: AuthBackend + 'static>(auth_backend: T) -> Request {
        TestRequestBuilder::get("/")
            .with_session()
            .config(test_project_config(SecretKey::new(TEST_KEY_1), vec![]))
            .auth_backend(auth_backend)
            .build()
    }

    /// Builds a request whose session comes from `session_source` (via
    /// `with_session_from`), with the given config and backend. Used to
    /// simulate a follow-up request after the configuration changed.
    fn test_request_with_auth_config_and_session<T: AuthBackend + 'static>(
        auth_backend: T,
        config: ProjectConfig,
        session_source: &Request,
    ) -> Request {
        TestRequestBuilder::get("/")
            .with_session_from(session_source)
            .config(config)
            .auth_backend(auth_backend)
            .build()
    }

    /// Builds a project config with the given secret key and fallback keys.
    // NOTE(review): the `.clone()` is presumably needed to get an owned
    // builder for `build()` — confirm against the builder's API.
    fn test_project_config(secret_key: SecretKey, fallback_keys: Vec<SecretKey>) -> ProjectConfig {
        ProjectConfig::builder()
            .secret_key(secret_key)
            .fallback_secret_keys(fallback_keys)
            .clone()
            .build()
    }

    // `AnonymousUser` must expose only the default (anonymous) `User` values.
    #[test]
    fn anonymous_user() {
        let anonymous_user = AnonymousUser();
        assert_eq!(anonymous_user.id(), None);
        assert_eq!(anonymous_user.username(), None);
        assert!(!anonymous_user.is_active());
        assert!(!anonymous_user.is_authenticated());
        assert_eq!(anonymous_user.last_login(), None);
        assert_eq!(anonymous_user.joined(), None);
        assert_eq!(
            anonymous_user.session_auth_hash(&SecretKey::new(b"key")),
            None
        );

        let anonymous_user2 = AnonymousUser();
        assert_eq!(anonymous_user, anonymous_user2);
    }

    // Hashing is slow under miri, hence the `ignore`.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn password_hash() {
        let password = Password::new("password".to_string());
        let hash = PasswordHash::from_password(&password);
        match hash.verify(&password) {
            PasswordVerificationResult::Ok => {}
            _ => panic!("Password hash verification failed"),
        }
    }

    #[test]
    fn session_auth_hash_debug() {
        let hash = SessionAuthHash::from([1, 2, 3].as_ref());
        assert_eq!(format!("{hash:?}"), "SessionAuthHash(\"**********\")");
    }

    #[test]
    fn password_debug() {
        let password = Password::new("password");
        assert_eq!(format!("{password:?}"), "Password(\"**********\")");
    }

    #[test]
    fn password_str() {
        let password = Password::new("password");
        assert_eq!(password.as_str(), "password");
        assert_eq!(password.into_string(), "password");
    }

    /// A fixed argon2id hash (PHC string) used by the hash parsing tests.
    const TEST_PASSWORD_HASH: &str = "$argon2id$v=19$m=19456,t=2,p=1$QAAI3EMU1eTLT9NzzBhQjg$khq4zuHsEyk9trGjuqMBFYnTbpqkmn0wXGxFn1nkPBc";

    // Only the 10-byte algorithm prefix may appear in the `Debug` output.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn password_hash_debug() {
        let hash = PasswordHash::new(TEST_PASSWORD_HASH).unwrap();
        assert_eq!(
            format!("{hash:?}"),
            "PasswordHash(\"$argon2id$**********\")"
        );
    }

    // Checks both the matching and the non-matching password paths.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn password_hash_verify() {
        let password = Password::new("password");
        let hash = PasswordHash::from_password(&password);
        match hash.verify(&password) {
            PasswordVerificationResult::Ok => {}
            _ => panic!("Password hash verification failed"),
        }

        let wrong_password = Password::new("wrongpassword");
        match hash.verify(&wrong_password) {
            PasswordVerificationResult::Invalid => {}
            _ => panic!("Password hash verification failed"),
        }
    }

    // Both `new` and `TryFrom<String>` must keep the hash string unchanged.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn password_hash_str() {
        let hash = PasswordHash::new(TEST_PASSWORD_HASH).unwrap();
        assert_eq!(hash.as_str(), TEST_PASSWORD_HASH);
        assert_eq!(hash.into_string(), TEST_PASSWORD_HASH);

        let hash = PasswordHash::try_from(TEST_PASSWORD_HASH.to_string()).unwrap();
        assert_eq!(hash.as_str(), TEST_PASSWORD_HASH);
        assert_eq!(hash.into_string(), TEST_PASSWORD_HASH);
    }

    // With no user ID in the session, `user()` falls back to `AnonymousUser`.
    #[cot::test]
    async fn user_anonymous() {
        let mut request = test_request_with_auth_backend(NoAuthBackend {});

        let user = request.user().await.unwrap();
        assert!(!user.is_authenticated());
        assert!(!user.is_active());
    }

    // A user ID already in the session is resolved through the backend.
    #[cot::test]
    async fn user() {
        let mut request = test_request(|| {
            let mut mock_user = MockUser::new();
            mock_user.expect_id().return_const(UserId::Int(1));
            mock_user.expect_session_auth_hash().return_const(None);
            mock_user
                .expect_username()
                .return_const(Some(Cow::from("mockuser")));
            mock_user
        });
        request
            .session_mut()
            .insert(USER_ID_SESSION_KEY, UserId::Int(1))
            .await
            .unwrap();

        let user = request.user().await.unwrap();
        assert_eq!(user.username(), Some(Cow::from("mockuser")));
    }

    // `authenticate` returns whatever the backend returns.
    #[cot::test]
    async fn authenticate() {
        let mut request = test_request(|| {
            let mut mock_user = MockUser::new();
            mock_user
                .expect_username()
                .return_const(Some(Cow::from("mockuser")));
            mock_user
        });

        let credentials: &(dyn Any + Send + Sync) = &();
        let user = request.authenticate(credentials).await.unwrap().unwrap();
        assert_eq!(user.username(), Some(Cow::from("mockuser")));
    }

    // `login` makes the user current; `logout` replaces it with an anonymous
    // user. The backend's user has no expectations, so it must never be
    // loaded here.
    #[cot::test]
    async fn login_logout() {
        let mut request = test_request(MockUser::new);

        let mut mock_user = MockUser::new();
        mock_user.expect_id().return_const(UserId::Int(1));
        mock_user.expect_session_auth_hash().return_const(None);
        mock_user
            .expect_username()
            .return_const(Some(Cow::from("mockuser")));

        request.login(Box::new(mock_user)).await.unwrap();
        let user = request.user().await.unwrap();
        assert_eq!(user.username(), Some(Cow::from("mockuser")));

        request.logout().await.unwrap();
        let user = request.user().await.unwrap();
        assert!(user.username().is_none());
    }

    // A session user ID the backend cannot resolve yields an anonymous user.
    #[cot::test]
    async fn logout_on_invalid_user_id_in_session() {
        let mut request = test_request_with_auth_backend(NoAuthBackend {});

        request
            .session_mut()
            .insert(USER_ID_SESSION_KEY, UserId::Int(1))
            .await
            .unwrap();
        let user = request.user().await.unwrap();
        assert_eq!(user.username(), None);
        assert!(!user.is_authenticated());
    }

    // Changing the user's session hash invalidates the stored session.
    // Removing the cached `UserExtension` forces `user()` to reload the user
    // from the session and re-check the hash.
    #[cot::test]
    async fn logout_on_session_hash_change() {
        let session_auth_hash = Arc::new(Mutex::new(SessionAuthHash::new(&[1, 2, 3])));
        let session_auth_hash_clone = Arc::clone(&session_auth_hash);
        let create_user = move || {
            let session_auth_hash_clone = Arc::clone(&session_auth_hash_clone);
            let mut mock_user = MockUser::new();
            mock_user.expect_id().return_const(UserId::Int(1));
            mock_user
                .expect_session_auth_hash()
                .returning(move |_| Some(session_auth_hash_clone.lock().unwrap().clone()));
            mock_user
                .expect_username()
                .return_const(Some(Cow::from("mockuser")));
            mock_user
        };

        let mut request = test_request(create_user.clone());

        request.login(Box::new(create_user())).await.unwrap();
        let user = request.user().await.unwrap();
        assert_eq!(user.username(), Some(Cow::from("mockuser")));

        // Reloading with an unchanged hash keeps the user logged in.
        request.extensions_mut().remove::<UserExtension>();
        let user = request.user().await.unwrap();
        assert_eq!(user.username(), Some(Cow::from("mockuser")));

        // After the hash changes, reloading logs the user out.
        request.extensions_mut().remove::<UserExtension>();
        *session_auth_hash.lock().unwrap() = SessionAuthHash::new(&[4, 5, 6]);
        let user = request.user().await.unwrap();
        assert!(!user.is_authenticated());
        assert_eq!(user.username(), None);
    }

    // Secret key rotation, in three steps:
    // 1. Moving `TEST_KEY_1` to the fallbacks keeps the session valid and
    //    re-keys the stored hash to `TEST_KEY_2`.
    // 2. Dropping the fallback afterwards still works, because the stored hash
    //    was updated.
    // 3. Switching to an unrelated key logs the user out.
    #[cot::test]
    async fn user_secret_key_change() {
        let create_user = move || {
            let mut mock_user = MockUser::new();
            mock_user.expect_id().return_const(UserId::Int(1));
            mock_user
                .expect_session_auth_hash()
                .with(eq(SecretKey::new(TEST_KEY_1)))
                .returning(move |_| Some(SessionAuthHash::new(&[1, 2, 3])));
            mock_user
                .expect_session_auth_hash()
                .with(eq(SecretKey::new(TEST_KEY_2)))
                .returning(move |_| Some(SessionAuthHash::new(&[4, 5, 6])));
            mock_user
                .expect_session_auth_hash()
                .with(eq(SecretKey::new(TEST_KEY_3)))
                .returning(move |_| Some(SessionAuthHash::new(&[7, 8, 9])));
            mock_user
                .expect_username()
                .return_const(Some(Cow::from("mockuser")));
            mock_user
        };

        let mut request = test_request(create_user);

        request.login(Box::new(create_user())).await.unwrap();
        let user = request.user().await.unwrap();
        assert_eq!(user.username(), Some(Cow::from("mockuser")));

        // Rebuilds `request` with new keys while keeping its session.
        let replace_keys = move |request: &mut Request, secret_key, fallback_keys| {
            let auth_backend = MockAuthBackend {
                return_user: create_user,
            };
            let new_config = test_project_config(secret_key, fallback_keys);
            *request = test_request_with_auth_config_and_session(auth_backend, new_config, request);
        };

        // Step 1: TEST_KEY_1 becomes a fallback key.
        replace_keys(
            &mut request,
            SecretKey::new(TEST_KEY_2),
            vec![SecretKey::new(TEST_KEY_1)],
        );
        let user = request.user().await.unwrap();
        assert_eq!(user.username(), Some(Cow::from("mockuser")));

        // Step 2: the fallback key is dropped.
        replace_keys(&mut request, SecretKey::new(TEST_KEY_2), vec![]);
        let user = request.user().await.unwrap();
        assert_eq!(user.username(), Some(Cow::from("mockuser")));

        // Step 3: an unrelated key, with no fallbacks, invalidates the session.
        replace_keys(&mut request, SecretKey::new(TEST_KEY_3), vec![]);
        let user = request.user().await.unwrap();
        assert_eq!(user.username(), None);
        assert!(!user.is_authenticated());
    }
}