use crate::error::{Result, TidewayError};
use async_trait::async_trait;
/// Persistence backend for refresh-token families.
///
/// Each family is identified by a string key. For every family the store keeps a generation
/// counter and a revocation status, and it records which families belong to which user so
/// that [`RefreshTokenStore::revoke_all_for_user`] can revoke them together.
#[async_trait]
pub trait RefreshTokenStore: Send + Sync {
    /// Reports whether `family` has been revoked.
    async fn is_family_revoked(&self, family: &str) -> Result<bool>;

    /// Returns the stored generation of `family`, or `None` when nothing is stored for it.
    async fn get_family_generation(&self, family: &str) -> Result<Option<u32>>;

    /// Stores `generation` as the current generation of `family`.
    async fn set_family_generation(&self, family: &str, generation: u32) -> Result<()>;

    /// Replaces the generation of `family` with `new_generation` only if it currently equals
    /// `expected_generation`, returning whether the swap took place.
    ///
    /// # Errors
    ///
    /// The provided default never performs a swap: it always returns an internal error.
    /// Implementors must override it with an atomic operation rather than inherit a
    /// get-then-set fallback.
    async fn compare_and_swap_family_generation(
        &self,
        _family: &str,
        _expected_generation: u32,
        _new_generation: u32,
    ) -> Result<bool> {
        let message =
            "RefreshTokenStore::compare_and_swap_family_generation must be implemented atomically";
        Err(TidewayError::internal(message))
    }

    /// Marks `family` as revoked.
    async fn revoke_family(&self, family: &str) -> Result<()>;

    /// Revokes every family previously associated with `user_id`.
    async fn revoke_all_for_user(&self, user_id: &str) -> Result<()>;

    /// Records that `family` belongs to `user_id`.
    async fn associate_family_with_user(&self, family: &str, user_id: &str) -> Result<()>;
}
/// Storage for MFA tokens that expire after a TTL and are consumed on use.
#[async_trait]
pub trait MfaTokenStore: Send + Sync {
    /// Saves `token` as belonging to `user_id`, valid for `ttl` from now.
    async fn store(&self, token: &str, user_id: &str, ttl: std::time::Duration) -> Result<()>;

    /// Takes `token` out of the store, yielding its user id when it is present and valid.
    async fn consume(&self, token: &str) -> Result<Option<String>>;
}
#[cfg(any(test, feature = "test-auth-bypass"))]
pub mod test {
use super::*;
use std::collections::HashMap;
use std::sync::RwLock;
use std::time::{Duration, Instant};
/// Process-local [`RefreshTokenStore`] built on `RwLock`-guarded hash maps.
///
/// Compiled only for tests or with the `test-auth-bypass` feature.
#[derive(Default)]
pub struct InMemoryRefreshTokenStore {
    // family id -> generation / revocation state
    families: RwLock<HashMap<String, FamilyState>>,
    // user id -> family ids recorded through `associate_family_with_user`
    user_families: RwLock<HashMap<String, Vec<String>>>,
}
/// Bookkeeping the in-memory refresh-token store keeps for one family.
struct FamilyState {
    // Current generation counter of the family.
    generation: u32,
    // Set once the family has been revoked.
    revoked: bool,
}
impl InMemoryRefreshTokenStore {
    /// Builds an empty store: no families and no user associations.
    pub fn new() -> Self {
        Default::default()
    }
}
/// In-memory implementation. Revocation is sticky: once a family is revoked, neither
/// `set_family_generation` nor `compare_and_swap_family_generation` can make it usable again.
#[async_trait]
impl RefreshTokenStore for InMemoryRefreshTokenStore {
    /// Families with no entry are reported as not revoked.
    async fn is_family_revoked(&self, family: &str) -> Result<bool> {
        let families = self.families.read().unwrap();
        Ok(families.get(family).map(|state| state.revoked).unwrap_or(false))
    }

    /// Returns `None` for families with no entry.
    async fn get_family_generation(&self, family: &str) -> Result<Option<u32>> {
        let families = self.families.read().unwrap();
        Ok(families.get(family).map(|state| state.generation))
    }

    /// Stores `generation` for `family`, creating a non-revoked entry if none exists.
    ///
    /// An existing revocation is preserved. Overwriting it with `revoked: false` would
    /// silently resurrect a revoked family.
    async fn set_family_generation(&self, family: &str, generation: u32) -> Result<()> {
        let mut families = self.families.write().unwrap();
        families
            .entry(family.to_string())
            .and_modify(|state| state.generation = generation)
            .or_insert(FamilyState {
                generation,
                revoked: false,
            });
        Ok(())
    }

    /// Atomic under the `families` write lock.
    ///
    /// A missing family is treated as generation 0, so swapping from 0 creates it. Revoked
    /// families never swap.
    async fn compare_and_swap_family_generation(
        &self,
        family: &str,
        expected_generation: u32,
        new_generation: u32,
    ) -> Result<bool> {
        let mut families = self.families.write().unwrap();
        match families.get_mut(family) {
            Some(state) => {
                if state.revoked || state.generation != expected_generation {
                    return Ok(false);
                }
                state.generation = new_generation;
                Ok(true)
            }
            None if expected_generation == 0 => {
                families.insert(
                    family.to_string(),
                    FamilyState {
                        generation: new_generation,
                        revoked: false,
                    },
                );
                Ok(true)
            }
            None => Ok(false),
        }
    }

    /// Marks `family` revoked.
    ///
    /// If the family has no entry yet, a revoked entry (generation 0) is inserted. Otherwise
    /// a later `compare_and_swap_family_generation(family, 0, _)`, which creates missing
    /// families, would make the revoked family usable. As a consequence,
    /// `get_family_generation` reports `Some(0)` for such a family.
    async fn revoke_family(&self, family: &str) -> Result<()> {
        let mut families = self.families.write().unwrap();
        families
            .entry(family.to_string())
            .and_modify(|state| state.revoked = true)
            .or_insert(FamilyState {
                generation: 0,
                revoked: true,
            });
        Ok(())
    }

    /// Revokes every family associated with `user_id`, including families that have no
    /// generation entry yet. Such families get a revoked entry, as in `revoke_family`.
    async fn revoke_all_for_user(&self, user_id: &str) -> Result<()> {
        // Lock order: `user_families` then `families`; no other method holds both locks.
        let user_families = self.user_families.read().unwrap();
        if let Some(owned) = user_families.get(user_id) {
            let mut families = self.families.write().unwrap();
            for family in owned {
                families
                    .entry(family.clone())
                    .and_modify(|state| state.revoked = true)
                    .or_insert(FamilyState {
                        generation: 0,
                        revoked: true,
                    });
            }
        }
        Ok(())
    }

    /// Records `family` under `user_id`. Repeated associations of the same pair are ignored.
    async fn associate_family_with_user(&self, family: &str, user_id: &str) -> Result<()> {
        let mut user_families = self.user_families.write().unwrap();
        let owned = user_families.entry(user_id.to_string()).or_default();
        // Skip duplicates so repeated calls do not grow the list without bound.
        if !owned.iter().any(|known| known == family) {
            owned.push(family.to_string());
        }
        Ok(())
    }
}
/// Process-local [`MfaTokenStore`] holding tokens in a `RwLock`-guarded hash map.
///
/// Compiled only for tests or with the `test-auth-bypass` feature.
#[derive(Default)]
pub struct InMemoryMfaTokenStore {
    // token -> (user id, expiry instant)
    tokens: RwLock<HashMap<String, (String, Instant)>>,
}
impl InMemoryMfaTokenStore {
    /// Builds an empty store holding no tokens.
    pub fn new() -> Self {
        Default::default()
    }
}
#[async_trait]
impl MfaTokenStore for InMemoryMfaTokenStore {
async fn store(&self, token: &str, user_id: &str, ttl: Duration) -> Result<()> {
let mut tokens = self.tokens.write().unwrap();
tokens.insert(
token.to_string(),
(user_id.to_string(), Instant::now() + ttl),
);
Ok(())
}
async fn consume(&self, token: &str) -> Result<Option<String>> {
let mut tokens = self.tokens.write().unwrap();
if let Some((user_id, expires)) = tokens.remove(token) {
if Instant::now() < expires {
return Ok(Some(user_id));
}
}
Ok(None)
}
}
/// Test double that does not override `compare_and_swap_family_generation`, so calls to it
/// reach the trait's default implementation.
#[cfg(test)]
struct NonAtomicDefaultStore;
// Every required method is a trivial success. `compare_and_swap_family_generation` is
// intentionally left out so the trait's default body is the one that runs.
#[cfg(test)]
#[async_trait]
impl RefreshTokenStore for NonAtomicDefaultStore {
    async fn associate_family_with_user(&self, _family: &str, _user_id: &str) -> Result<()> {
        Ok(())
    }

    async fn revoke_all_for_user(&self, _user_id: &str) -> Result<()> {
        Ok(())
    }

    async fn revoke_family(&self, _family: &str) -> Result<()> {
        Ok(())
    }

    async fn set_family_generation(&self, _family: &str, _generation: u32) -> Result<()> {
        Ok(())
    }

    async fn get_family_generation(&self, _family: &str) -> Result<Option<u32>> {
        Ok(None)
    }

    async fn is_family_revoked(&self, _family: &str) -> Result<bool> {
        Ok(false)
    }
}
#[cfg(test)]
#[tokio::test]
async fn test_default_compare_and_swap_fails_closed() {
    // The inherited CAS must refuse outright instead of succeeding non-atomically.
    let store = NonAtomicDefaultStore;
    let outcome = store
        .compare_and_swap_family_generation("family-1", 0, 1)
        .await;
    let message = outcome
        .expect_err("default compare_and_swap_family_generation should return an error")
        .to_string();
    assert!(
        message.contains("must be implemented atomically"),
        "unexpected error message: {}",
        message
    );
}
}