use std::collections::HashMap;
use std::sync::Mutex;
use serde::{Deserialize, Serialize};
/// Which verification flow a token was minted for.
///
/// `VerificationStore::consume` rejects a token presented for a different
/// flow with `VerificationError::KindMismatch`.
///
/// NOTE(review): serde's `rename_all = "lowercase"` lowercases variant names
/// without inserting separators. Serde therefore (de)serializes these as
/// `"passwordreset"`, `"emailchange"` and `"magiclink"`, while
/// `TokenKind::as_str` returns snake_case (`"password_reset"`, ...). Any code
/// that mixes the two representations will not match. Switching to
/// `"snake_case"` would change the stored/wire format. Confirm that no
/// persisted data or clients depend on the current form before changing it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum TokenKind {
/// Password reset; tokens live 30 minutes.
PasswordReset,
/// Email address change; tokens live 24 hours.
EmailChange,
/// Passwordless sign-in link; tokens live 15 minutes.
MagicLink,
}
impl TokenKind {
pub fn as_str(&self) -> &'static str {
match self {
Self::PasswordReset => "password_reset",
Self::EmailChange => "email_change",
Self::MagicLink => "magic_link",
}
}
}
/// The stored record for one minted verification token.
///
/// Only an HMAC of the secret is kept (`token_hash`). The plaintext itself is
/// returned once from `VerificationStore::mint` and is never written to the
/// backend.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct VerificationToken {
/// Storage key: `"vt_"` followed by 20 random bytes, base64url-encoded.
pub id: String,
/// Flow this token was minted for; `consume` rejects any other flow.
pub kind: TokenKind,
/// Recipient address, lowercased at mint time.
pub email: String,
/// Optional caller-supplied user id; stored and returned unchanged.
pub user_id: Option<String>,
/// Optional caller-supplied data; stored and returned unchanged. The tests use
/// it to carry the new address for an email change.
pub payload: Option<String>,
/// Lowercase hex HMAC-SHA256 of the plaintext, keyed with the server pepper.
pub token_hash: String,
/// First 8 characters of the plaintext, kept in the clear as a lookup key.
/// `consume` compares hashes only against tokens that share this prefix.
pub token_prefix: String,
/// Mint time, in whole seconds since the Unix epoch.
pub created_at: u64,
/// `created_at` plus the per-kind TTL. The token is rejected once
/// `expires_at <= now`.
pub expires_at: u64,
/// Redemption time (Unix seconds), or `None` while the token is unused.
pub consumed_at: Option<u64>,
}
/// Reasons why `VerificationStore::consume` can reject a presented token.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum VerificationError {
    /// No stored token matches the presented plaintext.
    NotFound,
    /// The token matched but its expiry has passed.
    Expired,
    /// The token matched but has already been redeemed.
    AlreadyConsumed,
    /// The token matched but was minted for a different `TokenKind`.
    KindMismatch,
}

impl std::fmt::Display for VerificationError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self {
            Self::NotFound => "verification token not found",
            Self::Expired => "verification token expired",
            Self::AlreadyConsumed => "verification token already consumed",
            Self::KindMismatch => "verification token is for a different flow",
        })
    }
}

// Lets callers propagate this error with `?` into `Box<dyn Error>`-style
// results, or wrap it as the `source()` of their own error types. The
// `Display` impl above supplies the message.
impl std::error::Error for VerificationError {}
/// Storage for `VerificationToken` records, used by `VerificationStore`.
///
/// The `Send + Sync` supertraits make `Box<dyn VerificationBackend>`, and
/// therefore `VerificationStore`, shareable across threads.
pub trait VerificationBackend: Send + Sync {
/// Stores `token` under `token.id`. The in-memory backend overwrites any
/// existing entry with the same id.
fn put(&self, token: &VerificationToken);
/// Returns a copy of the token with the given `id`, if one exists.
fn get(&self, id: &str) -> Option<VerificationToken>;
/// Returns copies of every stored token whose `token_prefix` equals `prefix`.
/// Several tokens may share a prefix; `VerificationStore::consume`
/// disambiguates by comparing full hashes.
fn by_prefix(&self, prefix: &str) -> Vec<VerificationToken>;
/// Sets `consumed_at = Some(now)` and returns `true` if the token exists and
/// is not yet consumed. Returns `false` if it is missing or already consumed.
/// `VerificationStore::consume` treats `false` as `AlreadyConsumed` after its
/// own snapshot checks. This call must therefore be an atomic check-and-set,
/// or tokens stop being single-use under concurrent redemption.
fn mark_consumed(&self, id: &str, now: u64) -> bool;
/// Intended to remove tokens that have expired as of `now`. `consume` treats
/// `expires_at <= now` as expired.
fn purge_expired(&self, now: u64);
}
/// In-process `VerificationBackend`.
///
/// Tokens live in a `HashMap` keyed by token id behind a single `Mutex`.
/// Nothing is persisted beyond the lifetime of this value.
/// `Default` yields an empty store.
#[derive(Default)]
pub struct InMemoryVerificationBackend {
    tokens: Mutex<HashMap<String, VerificationToken>>,
}
// Every method acquires the single map lock for its whole body, so each call
// is atomic with respect to the others.
impl VerificationBackend for InMemoryVerificationBackend {
    fn put(&self, token: &VerificationToken) {
        // Keyed by id: putting a token whose id already exists replaces it.
        self.tokens
            .lock()
            .unwrap()
            .insert(token.id.clone(), token.clone());
    }

    fn get(&self, id: &str) -> Option<VerificationToken> {
        self.tokens.lock().unwrap().get(id).cloned()
    }

    fn by_prefix(&self, prefix: &str) -> Vec<VerificationToken> {
        // Linear scan: O(n) in the number of stored tokens.
        self.tokens
            .lock()
            .unwrap()
            .values()
            .filter(|t| t.token_prefix == prefix)
            .cloned()
            .collect()
    }

    fn mark_consumed(&self, id: &str, now: u64) -> bool {
        // Check-and-set under one lock acquisition. Among concurrent callers
        // for the same id, exactly one sees `true`.
        let mut map = self.tokens.lock().unwrap();
        let Some(t) = map.get_mut(id) else {
            return false;
        };
        if t.consumed_at.is_some() {
            return false;
        }
        t.consumed_at = Some(now);
        true
    }

    fn purge_expired(&self, now: u64) {
        // Drop every token at or past its expiry, consumed or not.
        // `consume` rejects `expires_at <= now`, so an expired token can never
        // be redeemed again. Keeping unconsumed ones (the common case for
        // unused links) would make the map grow without bound.
        self.tokens
            .lock()
            .unwrap()
            .retain(|_, t| t.expires_at > now);
    }
}
/// Mints and redeems single-use, expiring verification tokens on top of a
/// pluggable `VerificationBackend`.
pub struct VerificationStore {
/// Where token records live. `VerificationStore::new` uses
/// `InMemoryVerificationBackend`.
backend: Box<dyn VerificationBackend>,
}
impl Default for VerificationStore {
fn default() -> Self {
Self::new()
}
}
/// Result of `VerificationStore::mint`: the stored record plus the one-time
/// plaintext secret.
///
/// Only `token` is written to the backend, and it carries an HMAC of the
/// secret rather than the secret itself. `plaintext` exists only in this value
/// and cannot be recovered from the store afterwards.
#[derive(Clone)]
pub struct MintedToken {
    /// The record as handed to the backend.
    pub token: VerificationToken,
    /// The bearer secret to give to the user; redeem it with
    /// `VerificationStore::consume`.
    pub plaintext: String,
}

// Hand-written instead of derived so that `{:?}` (logs, panic messages,
// `dbg!`) never prints the live secret: anyone holding `plaintext` can redeem
// the token.
impl std::fmt::Debug for MintedToken {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("MintedToken")
            .field("token", &self.token)
            .field("plaintext", &"<redacted>")
            .finish()
    }
}
impl VerificationStore {
const PASSWORD_RESET_TTL_SECS: u64 = 30 * 60; const MAGIC_LINK_TTL_SECS: u64 = 15 * 60; const EMAIL_CHANGE_TTL_SECS: u64 = 24 * 60 * 60;
pub fn new() -> Self {
Self::with_backend(Box::new(InMemoryVerificationBackend::default()))
}
pub fn with_backend(backend: Box<dyn VerificationBackend>) -> Self {
Self { backend }
}
pub fn mint(
&self,
kind: TokenKind,
email: &str,
user_id: Option<String>,
payload: Option<String>,
) -> MintedToken {
let id = format!("vt_{}", random_token(20));
let plaintext = random_token(32);
let prefix: String = plaintext.chars().take(8).collect();
let token_hash = hash_token(&plaintext);
let now = now_secs();
let ttl = match kind {
TokenKind::PasswordReset => Self::PASSWORD_RESET_TTL_SECS,
TokenKind::MagicLink => Self::MAGIC_LINK_TTL_SECS,
TokenKind::EmailChange => Self::EMAIL_CHANGE_TTL_SECS,
};
let token = VerificationToken {
id,
kind,
email: email.to_lowercase(),
user_id,
payload,
token_hash,
token_prefix: prefix,
created_at: now,
expires_at: now + ttl,
consumed_at: None,
};
self.backend.put(&token);
MintedToken {
token,
plaintext,
}
}
pub fn consume(
&self,
plaintext: &str,
expected_kind: TokenKind,
) -> Result<VerificationToken, VerificationError> {
let prefix: String = plaintext.chars().take(8).collect();
let expected_hash = hash_token(plaintext);
let candidates = self.backend.by_prefix(&prefix);
let now = now_secs();
for t in candidates {
if !crate::constant_time_eq(t.token_hash.as_bytes(), expected_hash.as_bytes()) {
continue;
}
if t.kind != expected_kind {
return Err(VerificationError::KindMismatch);
}
if t.consumed_at.is_some() {
return Err(VerificationError::AlreadyConsumed);
}
if t.expires_at <= now {
return Err(VerificationError::Expired);
}
if !self.backend.mark_consumed(&t.id, now) {
return Err(VerificationError::AlreadyConsumed);
}
return Ok(t);
}
Err(VerificationError::NotFound)
}
pub fn purge_expired(&self) {
self.backend.purge_expired(now_secs());
}
}
/// Returns `n_bytes` bytes from `rand::thread_rng()`, encoded as unpadded
/// URL-safe base64.
fn random_token(n_bytes: usize) -> String {
    use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
    use rand::RngCore;

    let mut raw = vec![0u8; n_bytes];
    rand::thread_rng().fill_bytes(raw.as_mut_slice());
    URL_SAFE_NO_PAD.encode(&raw)
}
/// Lowercase hex HMAC-SHA256 of `plaintext`, keyed with the pepper from the
/// `PYLON_API_KEY_PEPPER` environment variable. The variable is read on every
/// call.
///
/// If the variable is unset or not valid Unicode, a fixed development pepper
/// is used instead.
///
/// NOTE(review): that fallback is silent. A deployment missing the variable
/// would hash with a publicly known key, so consider failing loudly outside
/// development. The variable name suggests it is shared with API-key hashing;
/// confirm that is intended.
fn hash_token(plaintext: &str) -> String {
    use hmac::{Hmac, Mac};
    use sha2::Sha256;
    use std::fmt::Write;

    let key = match std::env::var("PYLON_API_KEY_PEPPER") {
        Ok(value) => value,
        Err(_) => String::from("pylon-dev-api-key-pepper-not-for-production"),
    };
    let mut mac = Hmac::<Sha256>::new_from_slice(key.as_bytes())
        .expect("HMAC accepts any key length");
    mac.update(plaintext.as_bytes());
    let digest = mac.finalize().into_bytes();

    digest
        .iter()
        .fold(String::with_capacity(64), |mut hex, byte| {
            let _ = write!(hex, "{byte:02x}");
            hex
        })
}
/// Current wall-clock time in whole seconds since the Unix epoch.
///
/// If the system clock reads earlier than the epoch, this returns `0` instead
/// of panicking.
fn now_secs() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Mints a `kind` token for `email` on `store` with no user id or payload.
    fn mint_bare(store: &VerificationStore, kind: TokenKind, email: &str) -> MintedToken {
        store.mint(kind, email, None, None)
    }

    #[test]
    fn mint_and_consume_round_trip() {
        let store = VerificationStore::new();
        let minted = mint_bare(&store, TokenKind::PasswordReset, "alice@example.com");

        let redeemed = store
            .consume(&minted.plaintext, TokenKind::PasswordReset)
            .expect("consume");

        assert_eq!(redeemed.id, minted.token.id);
        assert_eq!(redeemed.email, "alice@example.com");
    }

    #[test]
    fn consume_is_single_use() {
        let store = VerificationStore::new();
        let minted = mint_bare(&store, TokenKind::MagicLink, "a@b.com");

        assert!(store.consume(&minted.plaintext, TokenKind::MagicLink).is_ok());
        assert_eq!(
            store.consume(&minted.plaintext, TokenKind::MagicLink),
            Err(VerificationError::AlreadyConsumed)
        );
    }

    #[test]
    fn cross_kind_replay_rejected() {
        let store = VerificationStore::new();
        let minted = mint_bare(&store, TokenKind::MagicLink, "a@b.com");

        assert_eq!(
            store.consume(&minted.plaintext, TokenKind::PasswordReset),
            Err(VerificationError::KindMismatch)
        );
    }

    #[test]
    fn unknown_token_returns_not_found() {
        let store = VerificationStore::new();
        let outcome = store.consume(
            "nonexistent_plaintext_xxxxxxxxxxxxxxxxxxxx",
            TokenKind::PasswordReset,
        );
        assert_eq!(outcome, Err(VerificationError::NotFound));
    }

    #[test]
    fn email_lowercased_at_mint() {
        let store = VerificationStore::new();
        let minted = mint_bare(&store, TokenKind::MagicLink, "MIXED@CASE.com");
        assert_eq!(minted.token.email, "mixed@case.com");
    }

    #[test]
    fn payload_round_trips() {
        let store = VerificationStore::new();
        let minted = store.mint(
            TokenKind::EmailChange,
            "new@example.com",
            Some(String::from("user-1")),
            Some(String::from("new@example.com")),
        );

        let redeemed = store
            .consume(&minted.plaintext, TokenKind::EmailChange)
            .unwrap();

        assert_eq!(redeemed.user_id.as_deref(), Some("user-1"));
        assert_eq!(redeemed.payload.as_deref(), Some("new@example.com"));
    }

    #[test]
    fn expired_token_rejected() {
        let minting_store = VerificationStore::new();
        let minted = mint_bare(&minting_store, TokenKind::MagicLink, "a@b.com");

        // Move a copy of the record into a fresh backend with an expiry far
        // in the past.
        let backend = InMemoryVerificationBackend::default();
        backend.put(&VerificationToken {
            expires_at: 1,
            ..minted.token.clone()
        });
        let store = VerificationStore::with_backend(Box::new(backend));

        assert_eq!(
            store.consume(&minted.plaintext, TokenKind::MagicLink),
            Err(VerificationError::Expired)
        );
    }
}