use std::collections::HashMap;
use std::sync::Mutex;
use std::time::{Duration, Instant};
/// Storage backend for recovery tokens.
///
/// Each token is bound to the identifier of a credential. Implementors must be
/// `Send + Sync` so a single store can be shared between threads.
pub trait RecoveryTokenStore: Send + Sync {
    /// Returns the credential identifier bound to `token`, if any.
    ///
    /// Intended as a non-consuming lookup, in contrast to [`remove`](Self::remove).
    fn peek(&self, token: &str) -> Option<String>;

    /// Takes `token` out of the store, returning the credential identifier it
    /// was bound to, if any.
    fn remove(&self, token: &str) -> Option<String>;

    /// Binds `token` to `credential_identifier`.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message when the token cannot be stored.
    fn store(&self, token: &str, credential_identifier: &str) -> Result<(), String>;
}
/// A stored token's payload together with the moment it was stored.
struct TokenEntry {
    /// When the entry was created; used to decide expiry.
    created_at: Instant,
    /// Identifier of the credential the token is bound to.
    credential_identifier: String,
}
/// Process-local recovery-token store backed by a mutex-guarded `HashMap`.
///
/// Entries are kept only in memory and are lost when the value is dropped.
pub struct InMemoryRecoveryTokenStore {
    /// Map from token string to its stored entry.
    tokens: Mutex<HashMap<String, TokenEntry>>,
    /// How long an entry stays valid after it is stored.
    ttl: Duration,
    /// Upper bound on the number of entries the store will accept.
    max_tokens: usize,
}
impl InMemoryRecoveryTokenStore {
    /// Default token lifetime in seconds (600 s = 10 minutes).
    pub const DEFAULT_TTL_SECS: u64 = 600;
    /// Default maximum number of entries the store will accept.
    pub const DEFAULT_MAX_TOKENS: usize = 10_000;

    /// Creates an empty store using [`Self::DEFAULT_TTL_SECS`] and
    /// [`Self::DEFAULT_MAX_TOKENS`].
    pub fn new() -> Self {
        Self::with_config(
            Duration::from_secs(Self::DEFAULT_TTL_SECS),
            Self::DEFAULT_MAX_TOKENS,
        )
    }

    /// Creates an empty store whose entries expire `ttl` after being stored and
    /// which accepts at most `max_tokens` entries at a time.
    pub fn with_config(ttl: Duration, max_tokens: usize) -> Self {
        Self {
            ttl,
            max_tokens,
            tokens: Mutex::new(HashMap::new()),
        }
    }

    /// Drops every entry whose age has reached `ttl`.
    ///
    /// The clock is read once, so the whole sweep uses one consistent "now".
    /// Ages are measured with `saturating_duration_since` instead of computing
    /// `Instant::now() - ttl`. That subtraction panics when the result predates
    /// the platform's earliest representable `Instant`. Examples are a huge
    /// `ttl`, or the default TTL shortly after boot on platforms whose clock
    /// counts up from zero. Such a panic would happen while the caller holds
    /// the mutex.
    fn cleanup(tokens: &mut HashMap<String, TokenEntry>, ttl: Duration) {
        let now = Instant::now();
        tokens.retain(|_, entry| now.saturating_duration_since(entry.created_at) < ttl);
    }

    /// Returns `true` while `entry` is strictly younger than `ttl`.
    fn is_valid(entry: &TokenEntry, ttl: Duration) -> bool {
        entry.created_at.elapsed() < ttl
    }
}
impl Default for InMemoryRecoveryTokenStore {
fn default() -> Self {
Self::new()
}
}
// Locking note: a panic in another thread while it holds `tokens` poisons the
// mutex. With plain `unwrap()`, every later call would then panic, turning one
// failure into a permanent outage. The map has no cross-entry invariants, and
// any stale entry is still filtered by `is_valid`. So each method recovers the
// guard via `PoisonError::into_inner` instead.
impl RecoveryTokenStore for InMemoryRecoveryTokenStore {
    /// Binds `token` to `credential_identifier`, stamped with the current time.
    ///
    /// Expired entries are purged first so they do not count against the cap.
    /// An existing entry under the same `token` is overwritten and its timer
    /// restarted. The cap is checked before inserting, so a re-store is also
    /// refused when the store is full.
    ///
    /// # Errors
    ///
    /// Returns `Err` when `max_tokens` entries remain after purging.
    fn store(&self, token: &str, credential_identifier: &str) -> Result<(), String> {
        let mut tokens = self
            .tokens
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        Self::cleanup(&mut tokens, self.ttl);
        if tokens.len() >= self.max_tokens {
            return Err("Too many pending recovery tokens".to_string());
        }
        tokens.insert(
            token.to_string(),
            TokenEntry {
                credential_identifier: credential_identifier.to_string(),
                created_at: Instant::now(),
            },
        );
        Ok(())
    }

    /// Purges expired entries, then returns a copy of the credential identifier
    /// bound to `token` if it is still valid. The token itself stays stored.
    fn peek(&self, token: &str) -> Option<String> {
        let mut tokens = self
            .tokens
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        Self::cleanup(&mut tokens, self.ttl);
        // Re-check validity: time has advanced since `cleanup` read the clock.
        tokens
            .get(token)
            .filter(|entry| Self::is_valid(entry, self.ttl))
            .map(|entry| entry.credential_identifier.clone())
    }

    /// Removes `token` unconditionally, returning its credential identifier only
    /// if the entry had not yet expired. Other expired entries are not swept.
    fn remove(&self, token: &str) -> Option<String> {
        let mut tokens = self
            .tokens
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        tokens
            .remove(token)
            .filter(|entry| Self::is_valid(entry, self.ttl))
            .map(|entry| entry.credential_identifier)
    }
}