use std::collections::HashSet;
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
/// Declares one secure-token field on a record together with the length of the
/// tokens generated for it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SecureTokenConfig {
    /// Name of the field that stores the token.
    pub field: String,
    /// Number of characters in tokens generated for this field.
    pub length: usize,
}

impl SecureTokenConfig {
    /// Token length used when the caller does not override it.
    const DEFAULT_LENGTH: usize = 24;

    /// Creates a configuration for `field` using the default length of 24 characters.
    #[must_use]
    pub fn new(field: &str) -> Self {
        let field = String::from(field);
        Self {
            field,
            length: Self::DEFAULT_LENGTH,
        }
    }

    /// Overrides the generated token length; a requested length of 0 is raised to 1.
    #[must_use]
    pub fn length(self, length: usize) -> Self {
        Self {
            length: std::cmp::max(length, 1),
            ..self
        }
    }
}
/// Errors produced by the [`SecureToken`] helper methods. Each variant carries
/// the name of the token field involved.
#[derive(Debug, thiserror::Error, Clone, PartialEq, Eq)]
pub enum SecureTokenError {
    /// The requested field is not listed in the type's secure-token configurations.
    #[error("unknown secure token field: {0}")]
    UnknownField(String),
    /// Every candidate token drawn within the retry limit was already reserved.
    #[error("could not generate a unique token for {0}")]
    ExhaustedAttempts(String),
    /// A token already stored in the field collides with a reserved token
    /// (one of the caller-supplied existing tokens or another field of the record).
    #[error("token for {0} must be unique")]
    DuplicateToken(String),
}
/// Declares a secure-token field named `field` with the default token length.
///
/// Shorthand for [`SecureTokenConfig::new`]; chain [`SecureTokenConfig::length`]
/// to override the length.
#[must_use]
pub fn has_secure_token(field: &str) -> SecureTokenConfig {
    SecureTokenConfig::new(field)
}
/// Behaviour for records that carry one or more randomly generated secure tokens.
///
/// Implementors declare their token fields through
/// [`SecureToken::secure_token_configurations`] and expose string accessors for
/// them. The provided methods fill in missing tokens and regenerate existing ones,
/// keeping every token unique against a caller-supplied set of taken tokens and
/// against the record's own other token fields.
pub trait SecureToken {
    /// Returns the token fields declared for this type, in processing order.
    fn secure_token_configurations() -> &'static [SecureTokenConfig];

    /// Returns the token currently stored in `field`, or `None` when it is unset.
    fn get_secure_token(&self, field: &str) -> Option<&str>;

    /// Stores `token` in `field`.
    fn set_secure_token(&mut self, field: &str, token: String);

    /// Generates a standalone random token of 32 characters.
    fn generate_token() -> String {
        generate_token_with_length(32)
    }

    /// Fills in every declared token field that is currently unset.
    ///
    /// Tokens that are already present are kept, but each must be unique: it may
    /// not appear in `existing_tokens` nor repeat another token field of this
    /// record. Newly generated tokens avoid `existing_tokens` and every token the
    /// record holds or is about to receive.
    ///
    /// The record is modified only when the whole operation succeeds; on error no
    /// field is changed.
    ///
    /// # Errors
    ///
    /// * [`SecureTokenError::DuplicateToken`] if a present token collides with
    ///   `existing_tokens` or with another token field of this record.
    /// * [`SecureTokenError::ExhaustedAttempts`] if a unique token could not be
    ///   generated for some unset field.
    fn ensure_secure_tokens(
        &mut self,
        existing_tokens: &HashSet<String>,
    ) -> Result<(), SecureTokenError> {
        let configs = Self::secure_token_configurations();
        let mut reserved = existing_tokens.clone();

        // Pass 1: validate and reserve every token that is already set, so a token
        // generated for an earlier field can never collide with a later field's value.
        let present = configs.iter().filter_map(|config| {
            self.get_secure_token(&config.field)
                .map(|token| (config, token))
        });
        for (config, token) in present {
            if !reserved.insert(token.to_owned()) {
                return Err(SecureTokenError::DuplicateToken(config.field.clone()));
            }
        }

        // Pass 2: generate the missing tokens without touching `self`, so a failure
        // part-way through leaves the record exactly as it was.
        let mut generated = Vec::new();
        for config in configs {
            if self.get_secure_token(&config.field).is_some() {
                continue;
            }
            let token = generate_unique_token(config.length, &reserved, &config.field)?;
            reserved.insert(token.clone());
            generated.push((config, token));
        }

        for (config, token) in generated {
            self.set_secure_token(&config.field, token);
        }
        Ok(())
    }

    /// Replaces the token in `field` with a freshly generated one and returns it.
    ///
    /// The new token avoids `existing_tokens`, the field's current token, and the
    /// tokens held by this record's other token fields. The record is left
    /// unchanged on error.
    ///
    /// # Errors
    ///
    /// * [`SecureTokenError::UnknownField`] if `field` is not declared in
    ///   [`SecureToken::secure_token_configurations`].
    /// * [`SecureTokenError::ExhaustedAttempts`] if no unique token could be generated.
    fn regenerate_token(
        &mut self,
        field: &str,
        existing_tokens: &HashSet<String>,
    ) -> Result<String, SecureTokenError> {
        let configs = Self::secure_token_configurations();
        let config = configs
            .iter()
            .find(|config| config.field == field)
            .ok_or_else(|| SecureTokenError::UnknownField(field.to_owned()))?;

        // Reserve every token the record currently holds, including the one being
        // replaced, so the new value differs from all of them.
        let mut reserved = existing_tokens.clone();
        reserved.extend(
            configs
                .iter()
                .filter_map(|other| self.get_secure_token(&other.field))
                .map(str::to_owned),
        );

        let token = generate_unique_token(config.length, &reserved, field)?;
        self.set_secure_token(field, token.clone());
        Ok(token)
    }
}
/// Draws up to 32 candidate tokens of `length` characters and returns the first
/// one that does not already appear in `existing_tokens`.
///
/// # Errors
///
/// Returns [`SecureTokenError::ExhaustedAttempts`] naming `field` when every
/// candidate collided with an existing token.
fn generate_unique_token(
    length: usize,
    existing_tokens: &HashSet<String>,
    field: &str,
) -> Result<String, SecureTokenError> {
    const MAX_ATTEMPTS: usize = 32;

    // `repeat_with` is lazy, so candidates are only drawn until one is accepted.
    std::iter::repeat_with(|| generate_token_with_length(length))
        .take(MAX_ATTEMPTS)
        .find(|candidate| !existing_tokens.contains(candidate))
        .ok_or_else(|| SecureTokenError::ExhaustedAttempts(field.to_owned()))
}
/// Builds a random token of exactly `length` characters from URL-safe,
/// unpadded base64 text.
///
/// Random bytes are drawn 24 at a time and their encodings are concatenated
/// until enough characters are available. Only the needed prefix of the final
/// chunk is kept. A `length` of 0 yields an empty string without drawing any
/// random bytes.
fn generate_token_with_length(length: usize) -> String {
    let mut token = String::with_capacity(length);
    while token.len() < length {
        let chunk = URL_SAFE_NO_PAD.encode(rand::random::<[u8; 24]>());
        let needed = length - token.len();
        // Base64 output is ASCII, so slicing by byte count always lands on a char
        // boundary; when the chunk is shorter than `needed`, take all of it.
        token.push_str(chunk.get(..needed).unwrap_or(chunk.as_str()));
    }
    token
}
#[cfg(test)]
mod tests {
    use std::collections::HashSet;
    use std::sync::LazyLock;

    use super::{SecureToken, SecureTokenConfig, SecureTokenError, has_secure_token};

    /// Test double with two independent token slots.
    #[derive(Debug, Default)]
    struct ApiKeyRecord {
        token: Option<String>,
        recovery_token: Option<String>,
    }

    /// `token` keeps the default length; `recovery_token` overrides it to 12.
    static TOKEN_CONFIGS: LazyLock<Vec<SecureTokenConfig>> = LazyLock::new(|| {
        let primary = has_secure_token("token");
        let recovery = has_secure_token("recovery_token").length(12);
        vec![primary, recovery]
    });

    impl SecureToken for ApiKeyRecord {
        fn secure_token_configurations() -> &'static [SecureTokenConfig] {
            &TOKEN_CONFIGS[..]
        }

        fn get_secure_token(&self, field: &str) -> Option<&str> {
            let slot = match field {
                "token" => &self.token,
                "recovery_token" => &self.recovery_token,
                _ => return None,
            };
            slot.as_deref()
        }

        fn set_secure_token(&mut self, field: &str, token: String) {
            let slot = match field {
                "token" => &mut self.token,
                "recovery_token" => &mut self.recovery_token,
                _ => return,
            };
            *slot = Some(token);
        }
    }

    /// Record whose primary token is preset and whose recovery token is unset.
    fn record_with_token(token: &str) -> ApiKeyRecord {
        ApiKeyRecord {
            token: Some(token.to_owned()),
            ..ApiKeyRecord::default()
        }
    }

    /// An empty reservation set: no token is taken yet.
    fn nothing_taken() -> HashSet<String> {
        HashSet::new()
    }

    #[test]
    fn ensure_secure_tokens_generates_missing_tokens() {
        let mut record = ApiKeyRecord::default();
        record
            .ensure_secure_tokens(&nothing_taken())
            .expect("tokens should be generated");

        assert!(record.token.is_some());
        assert!(record.recovery_token.is_some());
        let recovery_len = record.recovery_token.as_deref().map(str::len);
        assert_eq!(recovery_len, Some(12));
    }

    #[test]
    fn ensure_secure_tokens_preserves_unique_existing_tokens() {
        let mut record = record_with_token("existing-token");
        record
            .ensure_secure_tokens(&nothing_taken())
            .expect("existing token should be preserved");

        assert_eq!(record.token.as_deref(), Some("existing-token"));
        assert!(record.recovery_token.is_some());
    }

    #[test]
    fn ensure_secure_tokens_rejects_duplicate_existing_tokens() {
        let mut record = record_with_token("taken");
        let taken: HashSet<String> = std::iter::once("taken".to_owned()).collect();

        let outcome = record.ensure_secure_tokens(&taken);
        assert_eq!(
            outcome,
            Err(SecureTokenError::DuplicateToken("token".to_owned()))
        );
    }

    #[test]
    fn regenerate_token_replaces_existing_value() {
        let mut record = record_with_token("current");
        let token = record
            .regenerate_token("token", &nothing_taken())
            .expect("token should regenerate");

        assert_eq!(record.token.as_deref(), Some(token.as_str()));
        assert_ne!(token, "current");
    }

    #[test]
    fn regenerate_token_rejects_unknown_fields() {
        let mut record = ApiKeyRecord::default();
        let outcome = record.regenerate_token("missing", &nothing_taken());
        assert_eq!(
            outcome,
            Err(SecureTokenError::UnknownField("missing".to_owned()))
        );
    }

    #[test]
    fn metadata_builder_preserves_length_overrides() {
        let SecureTokenConfig { field, length } = has_secure_token("auth_token").length(10);
        assert_eq!(field, "auth_token");
        assert_eq!(length, 10);
    }

    #[test]
    fn generate_token_returns_minimum_length() {
        let token = ApiKeyRecord::generate_token();
        assert!(token.len() >= 32);
    }

    #[test]
    fn ensure_secure_tokens_still_honors_length_overrides() {
        let mut record = ApiKeyRecord::default();
        record
            .ensure_secure_tokens(&nothing_taken())
            .expect("tokens should be generated");

        let recovery_len = record.recovery_token.map(|token| token.len());
        assert_eq!(recovery_len, Some(12));
    }
}