use crate::error::InitError;
use crate::validator::KeyStatus;
use crate::{
config::{Environment, KeyConfig, KeyPrefix},
error::Result,
generator::KeyGenerator,
hasher::KeyHasher,
secure::SecureString,
validator::KeyValidator,
ExposeSecret, HashConfig,
};
use chrono::{DateTime, Utc};
use derive_getters::Getters;
use std::fmt::Debug;
/// Generates, hashes and verifies API keys.
///
/// Only the hasher is reachable from outside through the derived `hasher()`
/// getter; every other field is internal state.
#[derive(Clone, Getters)]
pub struct ApiKeyManagerV0 {
    /// Hasher used to produce key hashes and key IDs (exposed via `hasher()`).
    hasher: KeyHasher,
    /// Produces new keys and verifies their checksums.
    #[getter(skip)]
    generator: KeyGenerator,
    /// Checks a presented key against a stored hash.
    #[getter(skip)]
    validator: KeyValidator,
    /// `true` when the key configuration has a non-zero checksum length.
    #[getter(skip)]
    include_checksum: bool,
    /// Grace period handed to the validator on every verification.
    #[getter(skip)]
    expiry_grace_period: std::time::Duration,
}
/// Hash material produced for an API key: its key ID and its encoded hash.
///
/// Both fields are read through the derived getters `key_id()` and `hash()`.
/// The type holds only owned strings (no secret material), so it is cheap to
/// clone and has total equality.
#[derive(Debug, Clone, PartialEq, Eq, Getters)]
pub struct Hash {
    /// Identifier derived from the key by the hasher; the same value is
    /// returned by `ApiKeyManagerV0::extract_key_id` for that key.
    key_id: String,
    /// Encoded hash of the key (a `$argon2id$...` string with the default
    /// hash configuration).
    hash: String,
}
/// Typestate marker for an `ApiKey` that has not been hashed yet.
///
/// It is zero-sized, so it is trivially copyable, comparable and defaultable.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct NoHash;
/// A plaintext API key paired with its hashing state.
///
/// `H` is a typestate parameter: [`NoHash`] before hashing and [`Hash`] after.
/// It is deliberately *not* named `Hash`, which would shadow the concrete
/// [`Hash`] struct inside this definition.
///
/// NOTE(review): the derived `Debug` formats `key`, so this relies on
/// `SecureString`'s own `Debug` impl not revealing the secret — confirm.
#[derive(Debug)]
pub struct ApiKey<H> {
    /// The plaintext key.
    key: SecureString,
    /// Hashing state: `NoHash` or `Hash`.
    hash: H,
}
impl ApiKeyManagerV0 {
    /// Expiry grace period used by the preset constructors
    /// ([`Self::init_default_config`] and [`Self::init_high_security_config`]).
    const DEFAULT_EXPIRY_GRACE_PERIOD: std::time::Duration = std::time::Duration::from_secs(10);

    /// Builds a manager from an explicit key configuration, hash configuration
    /// and expiry grace period.
    ///
    /// Checksum verification is enabled whenever `config.checksum_length()` is
    /// non-zero.
    ///
    /// # Errors
    ///
    /// Returns an [`InitError`] in any of these cases:
    /// - `prefix` is rejected by [`KeyPrefix::new`];
    /// - the key generator cannot be built from `config`;
    /// - hashing the generator's dummy key fails;
    /// - the validator cannot be constructed.
    pub fn init(
        prefix: impl Into<String>,
        config: KeyConfig,
        hash_config: HashConfig,
        expiry_grace_period: std::time::Duration,
    ) -> std::result::Result<Self, InitError> {
        // Must be read before `config` is moved into the generator below.
        let include_checksum = *config.checksum_length() != 0;
        let prefix = KeyPrefix::new(prefix)?;
        let generator = KeyGenerator::new(prefix, config)?;
        let hasher = KeyHasher::new(hash_config);
        // The validator is seeded with the generator's dummy key and its hash.
        // NOTE(review): presumably used so verification of unknown keys does
        // comparable work (uniform timing) — confirm against `KeyValidator`.
        let dummy_key = generator.dummy_key().clone();
        let (_dummy_key_id, dummy_hash) = hasher.hash(&dummy_key)?;
        let validator = KeyValidator::new(include_checksum, dummy_key, dummy_hash)?;
        Ok(Self {
            generator,
            hasher,
            validator,
            include_checksum,
            expiry_grace_period,
        })
    }

    /// Builds a manager with [`KeyConfig::default`], [`HashConfig::default`]
    /// and a 10-second expiry grace period.
    ///
    /// # Errors
    ///
    /// Same as [`Self::init`].
    pub fn init_default_config(prefix: impl Into<String>) -> std::result::Result<Self, InitError> {
        Self::init(
            prefix,
            KeyConfig::default(),
            HashConfig::default(),
            Self::DEFAULT_EXPIRY_GRACE_PERIOD,
        )
    }

    /// Builds a manager with [`KeyConfig::high_security`],
    /// [`HashConfig::high_security`] and a 10-second expiry grace period.
    ///
    /// # Errors
    ///
    /// Same as [`Self::init`].
    pub fn init_high_security_config(
        prefix: impl Into<String>,
    ) -> std::result::Result<Self, InitError> {
        Self::init(
            prefix,
            KeyConfig::high_security(),
            HashConfig::high_security(),
            Self::DEFAULT_EXPIRY_GRACE_PERIOD,
        )
    }

    /// Generates a new key with no expiry for `environment` and hashes it.
    ///
    /// # Errors
    ///
    /// Propagates errors from key generation or hashing.
    pub fn generate(&self, environment: impl Into<Environment>) -> Result<ApiKey<Hash>> {
        self.generate_hashed(environment.into(), None)
    }

    /// Generates a new key for `environment` that expires at `expiry`, and
    /// hashes it.
    ///
    /// # Errors
    ///
    /// Propagates errors from key generation or hashing.
    pub fn generate_with_expiry(
        &self,
        environment: impl Into<Environment>,
        expiry: DateTime<Utc>,
    ) -> Result<ApiKey<Hash>> {
        self.generate_hashed(environment.into(), Some(expiry))
    }

    /// Shared body of `generate` / `generate_with_expiry`: generate, then hash.
    fn generate_hashed(
        &self,
        environment: Environment,
        expiry: Option<DateTime<Utc>>,
    ) -> Result<ApiKey<Hash>> {
        let key = self.generator.generate(environment, expiry)?;
        ApiKey::new(key).into_hashed(&self.hasher)
    }

    /// Checks `key` against `stored_hash`.
    ///
    /// When checksums are enabled, a key whose checksum does not verify is
    /// rejected with [`KeyStatus::Invalid`] without consulting the hash.
    /// Otherwise the status comes from the validator, which also receives the
    /// configured expiry grace period.
    ///
    /// # Errors
    ///
    /// Propagates errors from checksum verification or from the validator.
    pub fn verify(&self, key: &SecureString, stored_hash: impl AsRef<str>) -> Result<KeyStatus> {
        // Cheap pre-check: skip the hash comparison for keys with a bad checksum.
        if self.include_checksum && !self.verify_checksum(key)? {
            return Ok(KeyStatus::Invalid);
        }
        self.validator.verify(
            key.expose_secret(),
            stored_hash.as_ref(),
            self.expiry_grace_period,
        )
    }

    /// Delegates checksum verification of `key` to the key generator.
    ///
    /// # Errors
    ///
    /// Propagates any error reported by the generator.
    pub fn verify_checksum(&self, key: &SecureString) -> Result<bool> {
        self.generator.verify_checksum(key)
    }

    /// Returns the key ID the hasher derives from `key`.
    ///
    /// For a key produced by this manager, the returned ID equals the
    /// `key_id` stored in that key's [`Hash`].
    pub fn extract_key_id(&self, key: &SecureString) -> String {
        self.hasher.generate_key_id(key)
    }
}
impl<T> ApiKey<T> {
pub fn key(&self) -> &SecureString {
&self.key
}
}
impl ApiKey<NoHash> {
    /// Wraps a plaintext key that has not been hashed yet.
    pub fn new(key: SecureString) -> ApiKey<NoHash> {
        Self { hash: NoHash, key }
    }

    /// Hashes the key with `hasher` and moves it into the hashed state.
    ///
    /// # Errors
    ///
    /// Propagates any error from `KeyHasher::hash`.
    pub fn into_hashed(self, hasher: &KeyHasher) -> Result<ApiKey<Hash>> {
        let Self { key, .. } = self;
        let (key_id, hash) = hasher.hash(&key)?;
        let hash = Hash { key_id, hash };
        Ok(ApiKey { key, hash })
    }

    /// Hashes the key with `hasher`, passing `phc_hash` through to
    /// `KeyHasher::hash_with_phc`, and moves it into the hashed state.
    ///
    /// NOTE(review): presumably the parameters/salt encoded in `phc_hash` are
    /// reused so the result can be compared with a previously stored hash —
    /// confirm against `KeyHasher::hash_with_phc`.
    ///
    /// # Errors
    ///
    /// Propagates any error from `KeyHasher::hash_with_phc`.
    pub fn into_hashed_with_phc(self, hasher: &KeyHasher, phc_hash: &str) -> Result<ApiKey<Hash>> {
        let Self { key, .. } = self;
        let (key_id, hash) = hasher.hash_with_phc(&key, phc_hash)?;
        let hash = Hash { key_id, hash };
        Ok(ApiKey { key, hash })
    }

    /// Consumes the unhashed wrapper and returns the plaintext key.
    pub fn into_key(self) -> SecureString {
        let Self { key, .. } = self;
        key
    }
}
impl ApiKey<Hash> {
    /// Borrows this key's hash material (key ID and encoded hash).
    pub fn expose_hash(&self) -> &Hash {
        let Self { hash, .. } = self;
        hash
    }

    /// Consumes the hashed key and returns the plaintext key, discarding the hash.
    pub fn into_key(self) -> SecureString {
        let Self { key, .. } = self;
        key
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ExposeSecret, SecureStringExt};

    /// Generate a key, accept it against its own hash, reject an unrelated key.
    #[test]
    fn test_full_lifecycle() {
        let manager = ApiKeyManagerV0::init_default_config("sk").unwrap();
        let api_key = manager.generate(Environment::production()).unwrap();
        let presented = api_key.key();
        let stored = api_key.expose_hash().hash();

        assert!(presented.expose_secret().starts_with("sk-live-"));
        assert!(stored.starts_with("$argon2id$"));
        assert_eq!(manager.verify(presented, stored).unwrap(), KeyStatus::Valid);

        let bogus = SecureString::from("wrong_key".to_string());
        assert_eq!(manager.verify(&bogus, stored).unwrap(), KeyStatus::Invalid);
    }

    /// The high-security preset yields longer keys than the default preset.
    #[test]
    fn test_different_presets() {
        let balanced = ApiKeyManagerV0::init_default_config("pk")
            .unwrap()
            .generate(Environment::test())
            .unwrap();
        let high_sec = ApiKeyManagerV0::init_high_security_config("sk")
            .unwrap()
            .generate(Environment::Production)
            .unwrap();

        assert!(!balanced.key().is_empty());
        assert!(high_sec.key().len() > balanced.key().len());
    }

    /// A custom entropy setting still produces keys whose checksum verifies.
    #[test]
    fn test_custom_config() {
        let config = KeyConfig::new().with_entropy(32).unwrap();
        let manager = ApiKeyManagerV0::init(
            "custom",
            config,
            HashConfig::default(),
            std::time::Duration::ZERO,
        )
        .unwrap();

        let api_key = manager.generate(Environment::production()).unwrap();
        assert!(manager.verify_checksum(api_key.key()).unwrap());
    }

    /// Re-hashing a key with its stored PHC string reproduces the same `Hash`.
    #[test]
    fn compare_hash() {
        let manager = ApiKeyManagerV0::init_default_config("sk").unwrap();
        let original = manager.generate(Environment::production()).unwrap();
        let stored_phc = original.expose_hash().hash();

        let rehashed = ApiKey::new(SecureString::from(original.key().expose_secret()))
            .into_hashed_with_phc(manager.hasher(), stored_phc)
            .unwrap();

        assert_eq!(rehashed.expose_hash(), original.expose_hash());
    }

    /// Key IDs match the stored ID, are deterministic, unique per key, and 32 hex chars.
    #[test]
    fn test_extract_key_id() {
        let manager = ApiKeyManagerV0::init_default_config("sk").unwrap();
        let first = manager.generate(Environment::production()).unwrap();
        let second = manager.generate(Environment::production()).unwrap();

        let id = manager.extract_key_id(first.key());
        assert_eq!(id, *first.expose_hash().key_id());
        assert_eq!(id, manager.extract_key_id(first.key()));
        assert_ne!(id, manager.extract_key_id(second.key()));
        assert_eq!(id.len(), 32);
        assert!(id.chars().all(|c| c.is_ascii_hexdigit()));
    }

    /// Re-hashing keeps the key ID but produces a different hash string.
    #[test]
    fn test_key_id_stability_across_rehashing() {
        let manager = ApiKeyManagerV0::init_default_config("sk").unwrap();
        let original = manager.generate(Environment::production()).unwrap();
        let rehashed = ApiKey::new(SecureString::from(original.key().expose_secret()))
            .into_hashed(manager.hasher())
            .unwrap();

        let (before, after) = (original.expose_hash(), rehashed.expose_hash());
        assert_eq!(before.key_id(), after.key_id());
        assert_ne!(before.hash(), after.hash());
    }

    /// Simulates storing the public parts of a key and later looking it up by ID.
    #[test]
    fn test_key_id_database_lookup() {
        let manager = ApiKeyManagerV0::init_default_config("sk").unwrap();
        let api_key = manager.generate(Environment::production()).unwrap();

        // What a database row would persist.
        let stored_key_id = api_key.expose_hash().key_id().to_string();
        let stored_hash = api_key.expose_hash().hash().to_string();

        // What a client would later present.
        let incoming_key = SecureString::from(api_key.key().expose_secret());

        assert_eq!(manager.extract_key_id(&incoming_key), stored_key_id);
        assert_eq!(
            manager.verify(&incoming_key, &stored_hash).unwrap(),
            KeyStatus::Valid
        );
    }
}