use async_trait::async_trait;
use chrono::{DateTime, Duration, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use crate::error::{Error, Result};
use crate::random::{constant_time_compare_str, generate_random_in_range};
/// What a one-time passcode was issued for.
///
/// The purpose is part of the storage key alongside the identifier, so codes
/// issued for different purposes are stored and verified independently.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum OtpPurpose {
    /// Displayed as `login`.
    Login,
    /// Displayed as `registration`.
    Registration,
    /// Displayed as `password_reset`.
    PasswordReset,
    /// Displayed as `email_verification`.
    EmailVerification,
    /// Displayed as `phone_verification`.
    PhoneVerification,
    /// Displayed as `transaction_confirmation`.
    TransactionConfirmation,
    /// Displayed as `two_factor`.
    TwoFactor,
    /// Application-defined purpose identified by a small integer; displayed as `custom_<id>`.
    Custom(u8),
}
impl std::fmt::Display for OtpPurpose {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
OtpPurpose::Login => write!(f, "login"),
OtpPurpose::Registration => write!(f, "registration"),
OtpPurpose::PasswordReset => write!(f, "password_reset"),
OtpPurpose::EmailVerification => write!(f, "email_verification"),
OtpPurpose::PhoneVerification => write!(f, "phone_verification"),
OtpPurpose::TransactionConfirmation => write!(f, "transaction_confirmation"),
OtpPurpose::TwoFactor => write!(f, "two_factor"),
OtpPurpose::Custom(id) => write!(f, "custom_{}", id),
}
}
}
/// Policy for generating and verifying one-time passcodes.
#[derive(Debug, Clone)]
pub struct OtpConfig {
    /// Number of decimal digits in a generated code.
    pub code_length: usize,
    /// How long a generated code remains valid.
    pub ttl: std::time::Duration,
    /// Number of verification attempts a code starts with.
    pub max_attempts: u32,
    /// Whether a successful verification deletes the code.
    pub consume_on_verify: bool,
    /// When `false`, any existing code is deleted before a new one is saved.
    pub allow_multiple: bool,
    /// Minimum delay between two generations for the same identifier and
    /// purpose; `None` disables the cooldown.
    pub min_interval: Option<std::time::Duration>,
}

impl Default for OtpConfig {
    /// Six-digit codes valid for five minutes with three attempts, consumed on
    /// success, one code at a time, and a one-minute regeneration cooldown.
    fn default() -> Self {
        let one_minute = std::time::Duration::from_secs(60);
        Self {
            code_length: 6,
            ttl: one_minute * 5,
            max_attempts: 3,
            consume_on_verify: true,
            allow_multiple: false,
            min_interval: Some(one_minute),
        }
    }
}

impl OtpConfig {
    /// Same as [`OtpConfig::default`].
    pub fn new() -> Self {
        Default::default()
    }

    /// Sets the number of digits per code.
    ///
    /// # Panics
    /// Panics when `length` is outside `4..=10`.
    pub fn with_code_length(self, length: usize) -> Self {
        assert!(
            (4..=10).contains(&length),
            "code length must be between 4 and 10"
        );
        Self {
            code_length: length,
            ..self
        }
    }

    /// Sets how long a code stays valid after generation.
    pub fn with_ttl(self, ttl: std::time::Duration) -> Self {
        Self { ttl, ..self }
    }

    /// Sets the number of verification attempts per code.
    pub fn with_max_attempts(self, max: u32) -> Self {
        Self {
            max_attempts: max,
            ..self
        }
    }

    /// Chooses whether a successful verification deletes the code.
    pub fn with_consume_on_verify(self, consume: bool) -> Self {
        Self {
            consume_on_verify: consume,
            ..self
        }
    }

    /// Chooses whether an existing code is kept when a new one is generated.
    pub fn with_allow_multiple(self, allow: bool) -> Self {
        Self {
            allow_multiple: allow,
            ..self
        }
    }

    /// Sets (or clears, with `None`) the regeneration cooldown.
    pub fn with_min_interval(self, interval: Option<std::time::Duration>) -> Self {
        Self {
            min_interval: interval,
            ..self
        }
    }

    /// Eight-digit codes, three-minute validity, three attempts and a
    /// two-minute cooldown.
    pub fn high_security() -> Self {
        // Every field is set explicitly so the preset does not drift if the
        // defaults change.
        Self::new()
            .with_code_length(8)
            .with_ttl(std::time::Duration::from_secs(3 * 60))
            .with_max_attempts(3)
            .with_consume_on_verify(true)
            .with_allow_multiple(false)
            .with_min_interval(Some(std::time::Duration::from_secs(120)))
    }

    /// Four-digit codes, thirty-minute validity, ten attempts, multiple codes
    /// allowed and no cooldown.
    pub fn relaxed() -> Self {
        Self::new()
            .with_code_length(4)
            .with_ttl(std::time::Duration::from_secs(30 * 60))
            .with_max_attempts(10)
            .with_consume_on_verify(true)
            .with_allow_multiple(true)
            .with_min_interval(None)
    }
}
/// A freshly generated OTP as returned by `OtpManager::generate`.
///
/// NOTE(review): the derived `Debug` prints `code` in plain text — avoid
/// logging this value with `{:?}`.
#[derive(Debug, Clone)]
pub struct OtpData {
    /// The generated numeric code, zero-padded to the configured length.
    pub code: String,
    /// Who the code was issued to (e.g. an email address or phone number).
    pub identifier: String,
    /// What the code was issued for.
    pub purpose: OtpPurpose,
    /// When the code was generated.
    pub created_at: DateTime<Utc>,
    /// Instant after which the code is no longer accepted.
    pub expires_at: DateTime<Utc>,
    /// Verification attempts available for this code at generation time.
    pub remaining_attempts: u32,
}
impl OtpData {
    /// Returns `true` once the current time is strictly later than `expires_at`.
    pub fn is_expired(&self) -> bool {
        self.expires_at < Utc::now()
    }

    /// Whole seconds left before expiry (truncated toward zero), never negative.
    pub fn remaining_seconds(&self) -> i64 {
        let time_left = self.expires_at.signed_duration_since(Utc::now());
        std::cmp::max(time_left.num_seconds(), 0)
    }
}
/// The persisted form of an OTP as returned by [`OtpStore::get`].
///
/// NOTE(review): the code is held in plain text and the derived `Debug`
/// prints it; avoid logging this value.
#[derive(Debug, Clone)]
pub struct StoredOtp {
    /// The code to compare against.
    pub code: String,
    /// Instant after which the code is rejected.
    pub expires_at: DateTime<Utc>,
    /// When the record was saved; used as the "last generated" timestamp by
    /// `InMemoryOtpStore`.
    pub created_at: DateTime<Utc>,
    /// Verification attempts left before the code is discarded.
    pub remaining_attempts: u32,
}
/// Map key for [`InMemoryOtpStore`]: exactly one record slot per
/// `(identifier, purpose)` pair.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct OtpKey {
    identifier: String,
    purpose: OtpPurpose,
}
/// Persistence backend used by `OtpManager`.
///
/// Every operation addresses a record by `(identifier, purpose)`.
#[async_trait]
pub trait OtpStore: Send + Sync {
    /// Persists `code` for the pair with the given expiry and attempt budget.
    /// `OtpManager::generate` calls this after optionally deleting the previous record.
    async fn save(
        &self,
        identifier: &str,
        purpose: OtpPurpose,
        code: &str,
        expires_at: DateTime<Utc>,
        max_attempts: u32,
    ) -> Result<()>;
    /// Fetches the current record for the pair, if any.
    async fn get(&self, identifier: &str, purpose: OtpPurpose) -> Result<Option<StoredOtp>>;
    /// Consumes one verification attempt; called by `OtpManager::verify` after a wrong guess.
    async fn decrement_attempts(&self, identifier: &str, purpose: OtpPurpose) -> Result<()>;
    /// Removes the record for the pair (used for consume, revoke, expiry and lockout).
    async fn delete(&self, identifier: &str, purpose: OtpPurpose) -> Result<()>;
    /// Returns when a code was last generated for the pair; `OtpManager` uses
    /// this to enforce `OtpConfig::min_interval`.
    async fn get_last_generated(
        &self,
        identifier: &str,
        purpose: OtpPurpose,
    ) -> Result<Option<DateTime<Utc>>>;
    /// Drops expired records. Presumably returns the number removed (that is
    /// what `InMemoryOtpStore` returns) — confirm for other implementations.
    async fn cleanup_expired(&self) -> Result<usize>;
}
/// Process-local [`OtpStore`] backed by a `HashMap` behind a std `RwLock`.
///
/// The map is wrapped in an `Arc`, so cloning the store is cheap and all
/// clones share the same records. Nothing is persisted across restarts.
#[derive(Debug, Clone, Default)]
pub struct InMemoryOtpStore {
    records: Arc<RwLock<HashMap<OtpKey, StoredOtp>>>,
}
impl InMemoryOtpStore {
pub fn new() -> Self {
Self::default()
}
pub fn len(&self) -> usize {
self.records.read().unwrap().len()
}
pub fn is_empty(&self) -> bool {
self.records.read().unwrap().is_empty()
}
}
/// [`OtpStore`] over the shared in-memory map.
///
/// Records are keyed by `(identifier, purpose)`, so saving a code for a pair
/// that already has one replaces it. `get_last_generated` reports the
/// `created_at` of the current record, so once that record is deleted
/// (consumed, revoked, expired or locked out) no previous generation is reported.
///
/// Each method takes the std lock and releases it without awaiting. A poisoned
/// lock causes a panic.
#[async_trait]
impl OtpStore for InMemoryOtpStore {
    async fn save(
        &self,
        identifier: &str,
        purpose: OtpPurpose,
        code: &str,
        expires_at: DateTime<Utc>,
        max_attempts: u32,
    ) -> Result<()> {
        let key = OtpKey {
            identifier: identifier.to_string(),
            purpose,
        };
        let record = StoredOtp {
            code: code.to_string(),
            expires_at,
            created_at: Utc::now(),
            remaining_attempts: max_attempts,
        };
        self.records.write().unwrap().insert(key, record);
        Ok(())
    }

    async fn get(&self, identifier: &str, purpose: OtpPurpose) -> Result<Option<StoredOtp>> {
        let key = OtpKey {
            identifier: identifier.to_string(),
            purpose,
        };
        let records = self.records.read().unwrap();
        Ok(records.get(&key).cloned())
    }

    /// Saturates at zero; a missing record is silently ignored.
    async fn decrement_attempts(&self, identifier: &str, purpose: OtpPurpose) -> Result<()> {
        let key = OtpKey {
            identifier: identifier.to_string(),
            purpose,
        };
        let mut records = self.records.write().unwrap();
        if let Some(record) = records.get_mut(&key) {
            record.remaining_attempts = record.remaining_attempts.saturating_sub(1);
        }
        Ok(())
    }

    /// Deleting a missing record is not an error.
    async fn delete(&self, identifier: &str, purpose: OtpPurpose) -> Result<()> {
        let key = OtpKey {
            identifier: identifier.to_string(),
            purpose,
        };
        self.records.write().unwrap().remove(&key);
        Ok(())
    }

    async fn get_last_generated(
        &self,
        identifier: &str,
        purpose: OtpPurpose,
    ) -> Result<Option<DateTime<Utc>>> {
        let key = OtpKey {
            identifier: identifier.to_string(),
            purpose,
        };
        let records = self.records.read().unwrap();
        Ok(records.get(&key).map(|record| record.created_at))
    }

    /// Removes every record whose expiry has passed and returns how many were removed.
    async fn cleanup_expired(&self) -> Result<usize> {
        let mut records = self.records.write().unwrap();
        let now = Utc::now();
        let before = records.len();
        // A code only counts as expired once `now > expires_at` (the rule used
        // by `OtpManager::verify` and `OtpData::is_expired`), so keep it while
        // `expires_at >= now` instead of dropping it at the exact expiry instant.
        records.retain(|_, record| record.expires_at >= now);
        Ok(before - records.len())
    }
}
/// Issues and verifies one-time passcodes according to an [`OtpConfig`],
/// delegating persistence to an [`OtpStore`].
///
/// The store type defaults to [`InMemoryOtpStore`].
pub struct OtpManager<S: OtpStore = InMemoryOtpStore> {
    /// Backend holding codes, expiry times and attempt counters.
    store: S,
    /// Policy applied to generation and verification.
    config: OtpConfig,
}
impl OtpManager<InMemoryOtpStore> {
    /// Builds a manager backed by a fresh, empty [`InMemoryOtpStore`].
    pub fn new(config: OtpConfig) -> Self {
        Self::with_store(InMemoryOtpStore::new(), config)
    }

    /// Builds an in-memory manager using [`OtpConfig::default`].
    pub fn with_default_config() -> Self {
        let config: OtpConfig = Default::default();
        Self::new(config)
    }
}
impl<S: OtpStore> OtpManager<S> {
pub fn with_store(store: S, config: OtpConfig) -> Self {
Self { store, config }
}
fn generate_code(&self) -> String {
let min = 10u64.pow((self.config.code_length - 1) as u32);
let max = 10u64.pow(self.config.code_length as u32);
let code = generate_random_in_range(min, max);
format!("{:0>width$}", code, width = self.config.code_length)
}
pub async fn generate(
&self,
identifier: impl Into<String>,
purpose: OtpPurpose,
) -> Result<OtpData> {
let identifier = identifier.into();
if let Some(min_interval) = self.config.min_interval
&& let Some(last_generated) =
self.store.get_last_generated(&identifier, purpose).await?
{
let elapsed = Utc::now() - last_generated;
let min_seconds = min_interval.as_secs() as i64;
if elapsed.num_seconds() < min_seconds {
let wait_seconds = min_seconds - elapsed.num_seconds();
return Err(Error::validation(format!(
"please wait {} seconds before requesting a new code",
wait_seconds
)));
}
}
if !self.config.allow_multiple {
self.store.delete(&identifier, purpose).await?;
}
let code = self.generate_code();
let created_at = Utc::now();
let expires_at = created_at + Duration::seconds(self.config.ttl.as_secs() as i64);
self.store
.save(
&identifier,
purpose,
&code,
expires_at,
self.config.max_attempts,
)
.await?;
Ok(OtpData {
code,
identifier,
purpose,
created_at,
expires_at,
remaining_attempts: self.config.max_attempts,
})
}
pub async fn verify(&self, identifier: &str, code: &str, purpose: OtpPurpose) -> Result<()> {
let stored = self
.store
.get(identifier, purpose)
.await?
.ok_or_else(|| Error::validation("no OTP found for this identifier and purpose"))?;
if Utc::now() > stored.expires_at {
self.store.delete(identifier, purpose).await?;
return Err(Error::validation("OTP has expired"));
}
if stored.remaining_attempts == 0 {
self.store.delete(identifier, purpose).await?;
return Err(Error::validation("maximum attempts exceeded"));
}
if !constant_time_compare_str(code, &stored.code) {
self.store.decrement_attempts(identifier, purpose).await?;
let remaining = stored.remaining_attempts.saturating_sub(1);
if remaining == 0 {
self.store.delete(identifier, purpose).await?;
return Err(Error::validation("invalid OTP, maximum attempts exceeded"));
}
return Err(Error::validation(format!(
"invalid OTP, {} attempts remaining",
remaining
)));
}
if self.config.consume_on_verify {
self.store.delete(identifier, purpose).await?;
}
Ok(())
}
pub async fn can_generate(&self, identifier: &str, purpose: OtpPurpose) -> Result<bool> {
if let Some(min_interval) = self.config.min_interval
&& let Some(last_generated) = self.store.get_last_generated(identifier, purpose).await?
{
let elapsed = Utc::now() - last_generated;
let min_seconds = min_interval.as_secs() as i64;
return Ok(elapsed.num_seconds() >= min_seconds);
}
Ok(true)
}
pub async fn seconds_until_can_generate(
&self,
identifier: &str,
purpose: OtpPurpose,
) -> Result<i64> {
if let Some(min_interval) = self.config.min_interval
&& let Some(last_generated) = self.store.get_last_generated(identifier, purpose).await?
{
let elapsed = Utc::now() - last_generated;
let min_seconds = min_interval.as_secs() as i64;
let remaining = min_seconds - elapsed.num_seconds();
return Ok(remaining.max(0));
}
Ok(0)
}
pub async fn revoke(&self, identifier: &str, purpose: OtpPurpose) -> Result<()> {
self.store.delete(identifier, purpose).await
}
pub async fn cleanup(&self) -> Result<usize> {
self.store.cleanup_expired().await
}
pub fn config(&self) -> &OtpConfig {
&self.config
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread::sleep;
    use std::time::Duration as StdDuration;

    /// Returns a code of the same length as `code` that is guaranteed to differ
    /// from it (its last digit is rotated by one). "Wrong code" tests therefore
    /// never depend on which values the generator can or cannot produce.
    fn wrong_code(code: &str) -> String {
        let mut digits: Vec<char> = code.chars().collect();
        if let Some(last) = digits.last_mut() {
            let value = last.to_digit(10).expect("OTP codes are numeric");
            *last = char::from_digit((value + 1) % 10, 10).expect("value is a single digit");
        }
        digits.into_iter().collect()
    }

    #[tokio::test]
    async fn test_generate_and_verify() {
        let manager = OtpManager::new(OtpConfig::default().with_min_interval(None));
        let otp = manager
            .generate("user@example.com", OtpPurpose::Login)
            .await
            .unwrap();
        assert_eq!(otp.code.len(), 6);
        assert_eq!(otp.identifier, "user@example.com");
        assert_eq!(otp.purpose, OtpPurpose::Login);
        assert!(!otp.is_expired());
        assert!(
            manager
                .verify("user@example.com", &otp.code, OtpPurpose::Login)
                .await
                .is_ok()
        );
    }

    #[tokio::test]
    async fn test_otp_consumed_after_verify() {
        let manager = OtpManager::new(OtpConfig::default().with_min_interval(None));
        let otp = manager
            .generate("user@example.com", OtpPurpose::Login)
            .await
            .unwrap();
        assert!(
            manager
                .verify("user@example.com", &otp.code, OtpPurpose::Login)
                .await
                .is_ok()
        );
        assert!(
            manager
                .verify("user@example.com", &otp.code, OtpPurpose::Login)
                .await
                .is_err()
        );
    }

    #[tokio::test]
    async fn test_otp_not_consumed_when_disabled() {
        let config = OtpConfig::default()
            .with_consume_on_verify(false)
            .with_min_interval(None);
        let manager = OtpManager::new(config);
        let otp = manager
            .generate("user@example.com", OtpPurpose::Login)
            .await
            .unwrap();
        assert!(
            manager
                .verify("user@example.com", &otp.code, OtpPurpose::Login)
                .await
                .is_ok()
        );
        assert!(
            manager
                .verify("user@example.com", &otp.code, OtpPurpose::Login)
                .await
                .is_ok()
        );
    }

    #[tokio::test]
    async fn test_wrong_code() {
        let manager = OtpManager::new(OtpConfig::default().with_min_interval(None));
        let otp = manager
            .generate("user@example.com", OtpPurpose::Login)
            .await
            .unwrap();
        let bad = wrong_code(&otp.code);
        assert!(
            manager
                .verify("user@example.com", &bad, OtpPurpose::Login)
                .await
                .is_err()
        );
        assert!(
            manager
                .verify("user@example.com", &otp.code, OtpPurpose::Login)
                .await
                .is_ok()
        );
    }

    #[tokio::test]
    async fn test_max_attempts() {
        let config = OtpConfig::default()
            .with_max_attempts(2)
            .with_min_interval(None);
        let manager = OtpManager::new(config);
        let otp = manager
            .generate("user@example.com", OtpPurpose::Login)
            .await
            .unwrap();
        let bad = wrong_code(&otp.code);
        assert!(
            manager
                .verify("user@example.com", &bad, OtpPurpose::Login)
                .await
                .is_err()
        );
        assert!(
            manager
                .verify("user@example.com", &bad, OtpPurpose::Login)
                .await
                .is_err()
        );
        // The attempt budget is spent, so even the correct code is now rejected.
        assert!(
            manager
                .verify("user@example.com", &otp.code, OtpPurpose::Login)
                .await
                .is_err()
        );
    }

    #[tokio::test]
    async fn test_otp_expiration() {
        let config = OtpConfig::default()
            .with_ttl(StdDuration::from_millis(100))
            .with_min_interval(None);
        let manager = OtpManager::new(config);
        let otp = manager
            .generate("user@example.com", OtpPurpose::Login)
            .await
            .unwrap();
        sleep(StdDuration::from_millis(150));
        assert!(
            manager
                .verify("user@example.com", &otp.code, OtpPurpose::Login)
                .await
                .is_err()
        );
    }

    #[tokio::test]
    async fn test_different_purposes_independent() {
        let manager = OtpManager::new(OtpConfig::default().with_min_interval(None));
        let login_otp = manager
            .generate("user@example.com", OtpPurpose::Login)
            .await
            .unwrap();
        let reset_otp = manager
            .generate("user@example.com", OtpPurpose::PasswordReset)
            .await
            .unwrap();
        assert!(
            manager
                .verify(
                    "user@example.com",
                    &login_otp.code,
                    OtpPurpose::PasswordReset
                )
                .await
                .is_err()
        );
        assert!(
            manager
                .verify("user@example.com", &reset_otp.code, OtpPurpose::Login)
                .await
                .is_err()
        );
        assert!(
            manager
                .verify("user@example.com", &login_otp.code, OtpPurpose::Login)
                .await
                .is_ok()
        );
        assert!(
            manager
                .verify(
                    "user@example.com",
                    &reset_otp.code,
                    OtpPurpose::PasswordReset
                )
                .await
                .is_ok()
        );
    }

    #[tokio::test]
    async fn test_min_interval() {
        let config = OtpConfig::default().with_min_interval(Some(StdDuration::from_secs(1)));
        let manager = OtpManager::new(config);
        manager
            .generate("user@example.com", OtpPurpose::Login)
            .await
            .unwrap();
        assert!(
            manager
                .generate("user@example.com", OtpPurpose::Login)
                .await
                .is_err()
        );
        sleep(StdDuration::from_millis(1100));
        assert!(
            manager
                .generate("user@example.com", OtpPurpose::Login)
                .await
                .is_ok()
        );
    }

    #[tokio::test]
    async fn test_code_length() {
        let config = OtpConfig::default()
            .with_code_length(8)
            .with_min_interval(None);
        let manager = OtpManager::new(config);
        let otp = manager
            .generate("user@example.com", OtpPurpose::Login)
            .await
            .unwrap();
        assert_eq!(otp.code.len(), 8);
    }

    #[tokio::test]
    async fn test_revoke() {
        let manager = OtpManager::new(OtpConfig::default().with_min_interval(None));
        let otp = manager
            .generate("user@example.com", OtpPurpose::Login)
            .await
            .unwrap();
        manager
            .revoke("user@example.com", OtpPurpose::Login)
            .await
            .unwrap();
        assert!(
            manager
                .verify("user@example.com", &otp.code, OtpPurpose::Login)
                .await
                .is_err()
        );
    }

    #[tokio::test]
    async fn test_cleanup_expired() {
        let config = OtpConfig::default()
            .with_ttl(StdDuration::from_millis(100))
            .with_min_interval(None);
        let manager = OtpManager::new(config);
        manager
            .generate("user1@example.com", OtpPurpose::Login)
            .await
            .unwrap();
        manager
            .generate("user2@example.com", OtpPurpose::Login)
            .await
            .unwrap();
        sleep(StdDuration::from_millis(150));
        let cleaned = manager.cleanup().await.unwrap();
        assert_eq!(cleaned, 2);
    }

    #[tokio::test]
    async fn test_can_generate() {
        let config = OtpConfig::default().with_min_interval(Some(StdDuration::from_secs(1)));
        let manager = OtpManager::new(config);
        assert!(
            manager
                .can_generate("user@example.com", OtpPurpose::Login)
                .await
                .unwrap()
        );
        manager
            .generate("user@example.com", OtpPurpose::Login)
            .await
            .unwrap();
        assert!(
            !manager
                .can_generate("user@example.com", OtpPurpose::Login)
                .await
                .unwrap()
        );
        sleep(StdDuration::from_millis(1100));
        assert!(
            manager
                .can_generate("user@example.com", OtpPurpose::Login)
                .await
                .unwrap()
        );
    }

    #[tokio::test]
    async fn test_seconds_until_can_generate() {
        let config = OtpConfig::default().with_min_interval(Some(StdDuration::from_secs(60)));
        let manager = OtpManager::new(config);
        assert_eq!(
            manager
                .seconds_until_can_generate("user@example.com", OtpPurpose::Login)
                .await
                .unwrap(),
            0
        );
        manager
            .generate("user@example.com", OtpPurpose::Login)
            .await
            .unwrap();
        let seconds = manager
            .seconds_until_can_generate("user@example.com", OtpPurpose::Login)
            .await
            .unwrap();
        assert!(seconds > 55 && seconds <= 60);
    }

    #[test]
    fn test_high_security_config() {
        let config = OtpConfig::high_security();
        assert_eq!(config.code_length, 8);
        assert_eq!(config.ttl, StdDuration::from_secs(3 * 60));
        assert_eq!(config.max_attempts, 3);
    }

    #[test]
    fn test_relaxed_config() {
        let config = OtpConfig::relaxed();
        assert_eq!(config.code_length, 4);
        assert_eq!(config.ttl, StdDuration::from_secs(30 * 60));
        assert_eq!(config.max_attempts, 10);
    }

    #[test]
    fn test_otp_purpose_display() {
        assert_eq!(OtpPurpose::Login.to_string(), "login");
        assert_eq!(OtpPurpose::PasswordReset.to_string(), "password_reset");
        assert_eq!(OtpPurpose::Custom(42).to_string(), "custom_42");
    }

    #[tokio::test]
    async fn test_remaining_seconds() {
        let config = OtpConfig::default()
            .with_ttl(StdDuration::from_secs(300))
            .with_min_interval(None);
        let manager = OtpManager::new(config);
        let otp = manager
            .generate("user@example.com", OtpPurpose::Login)
            .await
            .unwrap();
        let remaining = otp.remaining_seconds();
        assert!(remaining > 295 && remaining <= 300);
    }

    #[tokio::test]
    async fn test_store_len_and_is_empty() {
        let store = InMemoryOtpStore::new();
        assert!(store.is_empty());
        assert_eq!(store.len(), 0);
        store
            .save(
                "user@example.com",
                OtpPurpose::Login,
                "123456",
                Utc::now() + Duration::hours(1),
                3,
            )
            .await
            .unwrap();
        assert!(!store.is_empty());
        assert_eq!(store.len(), 1);
    }
}