use async_trait::async_trait;
use chrono::{DateTime, Duration, Utc};
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use crate::error::{Error, Result};
use crate::random::generate_random_base64_url;
/// Settings that control how magic-link tokens are issued and verified.
#[derive(Debug, Clone)]
pub struct MagicLinkConfig {
    /// Length value handed to the random token generator.
    pub token_length: usize,
    /// Lifetime of a freshly issued token.
    pub ttl: std::time::Duration,
    /// Number of stored tokens per identifier at which the oldest one is evicted
    /// before a new token is issued.
    pub max_active_per_user: usize,
    /// Whether a token is deleted after its first successful verification.
    pub consume_on_verify: bool,
}

impl Default for MagicLinkConfig {
    /// Defaults: length 32, 15-minute lifetime, 3 tokens per identifier,
    /// single-use tokens.
    fn default() -> Self {
        MagicLinkConfig {
            token_length: 32,
            ttl: Self::ttl_minutes(15),
            max_active_per_user: 3,
            consume_on_verify: true,
        }
    }
}

impl MagicLinkConfig {
    /// Converts a whole number of minutes into a `std::time::Duration`.
    fn ttl_minutes(minutes: u64) -> std::time::Duration {
        std::time::Duration::from_secs(minutes * 60)
    }

    /// Returns the default configuration; identical to [`Default::default`].
    pub fn new() -> Self {
        <Self as Default>::default()
    }

    /// Returns the configuration with `token_length` replaced by `length`.
    pub fn with_token_length(self, length: usize) -> Self {
        Self {
            token_length: length,
            ..self
        }
    }

    /// Returns the configuration with the token lifetime replaced by `ttl`.
    pub fn with_ttl(self, ttl: std::time::Duration) -> Self {
        Self { ttl, ..self }
    }

    /// Returns the configuration with the per-identifier token cap replaced by `max`.
    pub fn with_max_active_per_user(self, max: usize) -> Self {
        Self {
            max_active_per_user: max,
            ..self
        }
    }

    /// Returns the configuration with single-use behaviour switched on or off.
    pub fn with_consume_on_verify(self, consume: bool) -> Self {
        Self {
            consume_on_verify: consume,
            ..self
        }
    }

    /// Strict preset: length 48, 5-minute lifetime, one token per identifier,
    /// single-use tokens.
    pub fn high_security() -> Self {
        // Every field is set explicitly so the preset does not drift if the
        // defaults ever change.
        Self::new()
            .with_token_length(48)
            .with_ttl(Self::ttl_minutes(5))
            .with_max_active_per_user(1)
            .with_consume_on_verify(true)
    }

    /// Lenient preset: length 24, one-hour lifetime, ten tokens per identifier,
    /// single-use tokens.
    pub fn relaxed() -> Self {
        // Every field is set explicitly so the preset does not drift if the
        // defaults ever change.
        Self::new()
            .with_token_length(24)
            .with_ttl(Self::ttl_minutes(60))
            .with_max_active_per_user(10)
            .with_consume_on_verify(true)
    }
}
/// A freshly issued magic link, as returned by `MagicLinkManager::generate`.
#[derive(Debug, Clone)]
pub struct MagicLinkData {
    /// The generated token string.
    pub token: String,
    /// Identifier the token was issued for.
    pub identifier: String,
    /// Instant at which the token was issued.
    pub created_at: DateTime<Utc>,
    /// Instant after which the token is no longer valid.
    pub expires_at: DateTime<Utc>,
}

impl MagicLinkData {
    /// Returns `true` once the current UTC time is strictly later than
    /// `expires_at`.
    pub fn is_expired(&self) -> bool {
        let now = Utc::now();
        self.expires_at < now
    }

    /// Whole seconds left until `expires_at`, clamped to zero once the link
    /// has expired.
    pub fn remaining_seconds(&self) -> i64 {
        let left = self
            .expires_at
            .signed_duration_since(Utc::now())
            .num_seconds();
        if left < 0 {
            0
        } else {
            left
        }
    }
}
/// Record kept by the in-memory store for a single token.
#[derive(Debug, Clone)]
struct StoredToken {
/// Identifier the token was issued for.
identifier: String,
/// Instant after which the token is treated as expired.
expires_at: DateTime<Utc>,
/// Instant the record was saved; used to pick the oldest token of an identifier.
created_at: DateTime<Utc>,
}
/// Storage backend used by `MagicLinkManager` to persist issued tokens.
///
/// Implementations must be `Send + Sync`.
#[async_trait]
pub trait MagicLinkStore: Send + Sync {
/// Stores `token` for `identifier` together with its expiry instant.
async fn save(&self, token: &str, identifier: &str, expires_at: DateTime<Utc>) -> Result<()>;
/// Looks up `token`, returning its identifier and expiry instant, or `None`
/// when the token is not stored.
async fn get(&self, token: &str) -> Result<Option<(String, DateTime<Utc>)>>;
/// Removes `token` from the store.
async fn delete(&self, token: &str) -> Result<()>;
/// Lists the tokens currently stored for `identifier`.
async fn get_user_tokens(&self, identifier: &str) -> Result<Vec<String>>;
/// Removes the oldest token stored for `identifier`, if there is one.
async fn delete_oldest_user_token(&self, identifier: &str) -> Result<()>;
/// Removes expired tokens and returns how many were removed.
async fn cleanup_expired(&self) -> Result<usize>;
}
/// Token store that keeps every record in a lock-protected in-process map.
///
/// Cloning the store yields another handle to the same underlying map.
#[derive(Debug, Clone, Default)]
pub struct InMemoryMagicLinkStore {
    /// Token string mapped to its stored record.
    tokens: Arc<RwLock<HashMap<String, StoredToken>>>,
}

impl InMemoryMagicLinkStore {
    /// Creates an empty store.
    pub fn new() -> Self {
        InMemoryMagicLinkStore {
            tokens: Arc::default(),
        }
    }

    /// Number of tokens currently held, expired ones included.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub fn len(&self) -> usize {
        let map = self.tokens.read().unwrap();
        map.len()
    }

    /// Returns `true` when the store holds no tokens at all.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
/// In-memory implementation of [`MagicLinkStore`].
///
/// Each method takes the map lock for one short, synchronous operation and
/// releases it before returning; no guard is held across an `.await`.
///
/// # Panics
///
/// Every method panics if the internal lock is poisoned.
#[async_trait]
impl MagicLinkStore for InMemoryMagicLinkStore {
    /// Inserts (or overwrites) the record for `token`, stamping it with the
    /// current time as its creation instant.
    async fn save(&self, token: &str, identifier: &str, expires_at: DateTime<Utc>) -> Result<()> {
        let record = StoredToken {
            identifier: identifier.to_string(),
            expires_at,
            created_at: Utc::now(),
        };
        self.tokens
            .write()
            .unwrap()
            .insert(token.to_string(), record);
        Ok(())
    }

    /// Returns the identifier and expiry of `token`, or `None` if it is not
    /// stored. Expired records are returned as-is; the caller checks expiry.
    async fn get(&self, token: &str) -> Result<Option<(String, DateTime<Utc>)>> {
        let tokens = self.tokens.read().unwrap();
        Ok(tokens
            .get(token)
            .map(|record| (record.identifier.clone(), record.expires_at)))
    }

    /// Removes `token`; removing an unknown token is not an error.
    async fn delete(&self, token: &str) -> Result<()> {
        self.tokens.write().unwrap().remove(token);
        Ok(())
    }

    /// Collects every stored token whose identifier equals `identifier`,
    /// in no particular order.
    async fn get_user_tokens(&self, identifier: &str) -> Result<Vec<String>> {
        let tokens = self.tokens.read().unwrap();
        Ok(tokens
            .iter()
            .filter(|(_, record)| record.identifier == identifier)
            .map(|(token, _)| token.clone())
            .collect())
    }

    /// Removes the token of `identifier` with the earliest creation instant;
    /// does nothing when the identifier has no tokens.
    async fn delete_oldest_user_token(&self, identifier: &str) -> Result<()> {
        let mut tokens = self.tokens.write().unwrap();
        // The key is cloned so the immutable borrow from `iter()` ends before
        // the map is mutated.
        let oldest = tokens
            .iter()
            .filter(|(_, record)| record.identifier == identifier)
            .min_by_key(|(_, record)| record.created_at)
            .map(|(token, _)| token.clone());
        if let Some(token) = oldest {
            tokens.remove(&token);
        }
        Ok(())
    }

    /// Drops every expired record and returns how many were dropped.
    async fn cleanup_expired(&self) -> Result<usize> {
        let mut tokens = self.tokens.write().unwrap();
        let now = Utc::now();
        let before = tokens.len();
        // A token is expired only once `now` is strictly past `expires_at`,
        // the same boundary `MagicLinkManager::verify` and
        // `MagicLinkData::is_expired` use. Keeping `expires_at == now` ensures
        // cleanup never removes a token that verification would still accept.
        tokens.retain(|_, record| record.expires_at >= now);
        Ok(before - tokens.len())
    }
}
/// Issues, verifies and revokes magic-link tokens on top of a [`MagicLinkStore`].
///
/// The store type parameter defaults to [`InMemoryMagicLinkStore`].
pub struct MagicLinkManager<S: MagicLinkStore = InMemoryMagicLinkStore> {
/// Backend that persists issued tokens.
store: S,
/// Settings applied when generating and verifying tokens.
config: MagicLinkConfig,
}
impl MagicLinkManager<InMemoryMagicLinkStore> {
pub fn new(config: MagicLinkConfig) -> Self {
Self {
store: InMemoryMagicLinkStore::new(),
config,
}
}
pub fn with_default_config() -> Self {
Self::new(MagicLinkConfig::default())
}
}
impl<S: MagicLinkStore> MagicLinkManager<S> {
pub fn with_store(store: S, config: MagicLinkConfig) -> Self {
Self { store, config }
}
pub async fn generate(&self, identifier: impl Into<String>) -> Result<MagicLinkData> {
let identifier = identifier.into();
let user_tokens = self.store.get_user_tokens(&identifier).await?;
if user_tokens.len() >= self.config.max_active_per_user {
self.store.delete_oldest_user_token(&identifier).await?;
}
let token = generate_random_base64_url(self.config.token_length)?;
let created_at = Utc::now();
let expires_at = created_at + Duration::seconds(self.config.ttl.as_secs() as i64);
self.store.save(&token, &identifier, expires_at).await?;
Ok(MagicLinkData {
token,
identifier,
created_at,
expires_at,
})
}
pub async fn verify(&self, token: &str) -> Result<String> {
let (identifier, expires_at) = self
.store
.get(token)
.await?
.ok_or_else(|| Error::validation("invalid or expired magic link token"))?;
if Utc::now() > expires_at {
self.store.delete(token).await?;
return Err(Error::validation("magic link token has expired"));
}
if self.config.consume_on_verify {
self.store.delete(token).await?;
}
Ok(identifier)
}
pub async fn revoke(&self, token: &str) -> Result<()> {
self.store.delete(token).await
}
pub async fn revoke_all_for_user(&self, identifier: &str) -> Result<usize> {
let tokens = self.store.get_user_tokens(identifier).await?;
let count = tokens.len();
for token in tokens {
self.store.delete(&token).await?;
}
Ok(count)
}
pub async fn cleanup(&self) -> Result<usize> {
self.store.cleanup_expired().await
}
pub fn config(&self) -> &MagicLinkConfig {
&self.config
}
}
#[cfg(test)]
mod tests {
    //! Unit tests for the magic-link manager and the in-memory store.
    //!
    //! Expiry scenarios plant already-expired tokens through a shared store
    //! handle instead of sleeping, so no test blocks the async runtime thread.
    use super::*;
    use std::time::Duration as StdDuration;

    /// A generated token verifies and yields the identifier it was issued for.
    #[tokio::test]
    async fn test_generate_and_verify() {
        let manager = MagicLinkManager::new(MagicLinkConfig::default());
        let data = manager.generate("test@example.com").await.unwrap();
        assert!(!data.token.is_empty());
        assert_eq!(data.identifier, "test@example.com");
        assert!(!data.is_expired());
        let email = manager.verify(&data.token).await.unwrap();
        assert_eq!(email, "test@example.com");
    }

    /// With the default single-use setting, a second verification fails.
    #[tokio::test]
    async fn test_token_consumed_after_verify() {
        let manager = MagicLinkManager::new(MagicLinkConfig::default());
        let data = manager.generate("test@example.com").await.unwrap();
        assert!(manager.verify(&data.token).await.is_ok());
        assert!(manager.verify(&data.token).await.is_err());
    }

    /// With `consume_on_verify` off, the same token verifies repeatedly.
    #[tokio::test]
    async fn test_token_not_consumed_when_disabled() {
        let config = MagicLinkConfig::default().with_consume_on_verify(false);
        let manager = MagicLinkManager::new(config);
        let data = manager.generate("test@example.com").await.unwrap();
        assert!(manager.verify(&data.token).await.is_ok());
        assert!(manager.verify(&data.token).await.is_ok());
        assert!(manager.verify(&data.token).await.is_ok());
    }

    /// Generated tokens carry the configured TTL, and an expired token is
    /// rejected and removed from the store.
    #[tokio::test]
    async fn test_token_expiration() {
        let config = MagicLinkConfig::default().with_ttl(StdDuration::from_secs(1));
        let store = InMemoryMagicLinkStore::new();
        // `store.clone()` shares the same underlying map with the manager.
        let manager = MagicLinkManager::with_store(store.clone(), config);

        let data = manager.generate("test@example.com").await.unwrap();
        assert!(!data.is_expired());
        assert_eq!(data.expires_at - data.created_at, Duration::seconds(1));

        // Plant a token whose expiry is already in the past rather than
        // blocking the runtime while a real token ages out.
        store
            .save(
                "stale-token",
                "test@example.com",
                Utc::now() - Duration::seconds(1),
            )
            .await
            .unwrap();
        assert!(manager.verify("stale-token").await.is_err());
        // The failed verification deletes the stale entry; only the generated
        // token is left.
        assert_eq!(store.len(), 1);
    }

    /// Issuing beyond the per-identifier cap evicts the oldest token.
    #[tokio::test]
    async fn test_max_active_tokens_per_user() {
        let config = MagicLinkConfig::default().with_max_active_per_user(2);
        let manager = MagicLinkManager::new(config);
        let token1 = manager.generate("user@example.com").await.unwrap();
        let token2 = manager.generate("user@example.com").await.unwrap();
        let token3 = manager.generate("user@example.com").await.unwrap();
        assert!(manager.verify(&token1.token).await.is_err());
        assert!(manager.verify(&token2.token).await.is_ok());
        assert!(manager.verify(&token3.token).await.is_ok());
    }

    /// A revoked token no longer verifies.
    #[tokio::test]
    async fn test_revoke_token() {
        let manager = MagicLinkManager::new(MagicLinkConfig::default());
        let data = manager.generate("test@example.com").await.unwrap();
        manager.revoke(&data.token).await.unwrap();
        assert!(manager.verify(&data.token).await.is_err());
    }

    /// Revoking all tokens of one identifier leaves other identifiers untouched.
    #[tokio::test]
    async fn test_revoke_all_for_user() {
        let config = MagicLinkConfig::default()
            .with_max_active_per_user(10)
            .with_consume_on_verify(false);
        let manager = MagicLinkManager::new(config);
        let t1 = manager.generate("user@example.com").await.unwrap();
        let t2 = manager.generate("user@example.com").await.unwrap();
        let t3 = manager.generate("other@example.com").await.unwrap();
        let count = manager
            .revoke_all_for_user("user@example.com")
            .await
            .unwrap();
        assert_eq!(count, 2);
        assert!(manager.verify(&t1.token).await.is_err());
        assert!(manager.verify(&t2.token).await.is_err());
        assert!(manager.verify(&t3.token).await.is_ok());
    }

    /// Cleanup removes exactly the expired tokens and keeps live ones.
    #[tokio::test]
    async fn test_cleanup_expired() {
        let config = MagicLinkConfig::default().with_max_active_per_user(10);
        let store = InMemoryMagicLinkStore::new();
        // `store.clone()` shares the same underlying map with the manager.
        let manager = MagicLinkManager::with_store(store.clone(), config);

        // Three tokens that are already past their expiry.
        let past = Utc::now() - Duration::seconds(1);
        store.save("stale-1", "user1@example.com", past).await.unwrap();
        store.save("stale-2", "user2@example.com", past).await.unwrap();
        store.save("stale-3", "user3@example.com", past).await.unwrap();
        // One live token that must survive the sweep.
        let live = manager.generate("user4@example.com").await.unwrap();

        let cleaned = manager.cleanup().await.unwrap();
        assert_eq!(cleaned, 3);
        assert_eq!(store.len(), 1);
        assert!(manager.verify(&live.token).await.is_ok());
    }

    /// Tokens of different identifiers resolve to their own identifier.
    #[tokio::test]
    async fn test_different_users_independent() {
        let manager = MagicLinkManager::new(MagicLinkConfig::default());
        let data1 = manager.generate("user1@example.com").await.unwrap();
        let data2 = manager.generate("user2@example.com").await.unwrap();
        let email1 = manager.verify(&data1.token).await.unwrap();
        assert_eq!(email1, "user1@example.com");
        let email2 = manager.verify(&data2.token).await.unwrap();
        assert_eq!(email2, "user2@example.com");
    }

    /// `remaining_seconds` reports a value close to the configured TTL.
    #[tokio::test]
    async fn test_remaining_seconds() {
        let config = MagicLinkConfig::default().with_ttl(StdDuration::from_secs(300));
        let manager = MagicLinkManager::new(config);
        let data = manager.generate("test@example.com").await.unwrap();
        let remaining = data.remaining_seconds();
        assert!(remaining > 295 && remaining <= 300);
    }

    /// The strict preset carries the expected values.
    #[test]
    fn test_high_security_config() {
        let config = MagicLinkConfig::high_security();
        assert_eq!(config.token_length, 48);
        assert_eq!(config.ttl, StdDuration::from_secs(5 * 60));
        assert_eq!(config.max_active_per_user, 1);
    }

    /// The lenient preset carries the expected values.
    #[test]
    fn test_relaxed_config() {
        let config = MagicLinkConfig::relaxed();
        assert_eq!(config.token_length, 24);
        assert_eq!(config.ttl, StdDuration::from_secs(60 * 60));
        assert_eq!(config.max_active_per_user, 10);
    }

    /// A token that was never issued is rejected.
    #[tokio::test]
    async fn test_invalid_token() {
        let manager = MagicLinkManager::new(MagicLinkConfig::default());
        assert!(manager.verify("invalid-token").await.is_err());
    }

    /// `len` and `is_empty` reflect the store contents.
    #[tokio::test]
    async fn test_store_len_and_is_empty() {
        let store = InMemoryMagicLinkStore::new();
        assert!(store.is_empty());
        assert_eq!(store.len(), 0);
        store
            .save(
                "token1",
                "user@example.com",
                Utc::now() + Duration::hours(1),
            )
            .await
            .unwrap();
        assert!(!store.is_empty());
        assert_eq!(store.len(), 1);
    }
}