use parking_lot::RwLock;
use std::sync::OnceLock;
use argon2::{Algorithm, Argon2, Params, Version};
use chacha20poly1305::aead::{Aead, KeyInit, Payload};
use chacha20poly1305::{XChaCha20Poly1305, XNonce};
use rand::random;
use crate::error::{Error, Result};
/// Signature for a token encoder: takes the serialized primary-key payload and
/// the model name, and returns the opaque token string.
pub type TokenEncoder = fn(record_id: &str, model_name: &str) -> Result<String>;
/// Signature for a token decoder: takes the token and the model name, and
/// returns the serialized primary-key payload, or `None` when the token is invalid.
pub type TokenDecoder = fn(token: &str, model_name: &str) -> Result<Option<String>>;
// Process-wide tokenization state, lazily initialized on first access and
// guarded by `RwLock` so configuration can be replaced or reset at runtime.
static GLOBAL_ENCRYPTION_KEY: OnceLock<RwLock<Option<ConfiguredEncryptionKey>>> = OnceLock::new();
static GLOBAL_TOKEN_ENCODER: OnceLock<RwLock<Option<TokenEncoder>>> = OnceLock::new();
static GLOBAL_TOKEN_DECODER: OnceLock<RwLock<Option<TokenDecoder>>> = OnceLock::new();
/// A user-supplied encryption key paired with the Argon2id-derived 32-byte key
/// that is actually used for AEAD operations.
#[derive(Clone)]
struct ConfiguredEncryptionKey {
    raw: String,
    derived: [u8; 32],
}

impl ConfiguredEncryptionKey {
    /// Runs the (expensive) key derivation once at configuration time so that
    /// later reads only copy the cached result.
    fn new(raw: &str) -> Self {
        let derived = derive_encryption_key(raw);
        Self {
            raw: raw.to_owned(),
            derived,
        }
    }
}
/// Returns the lazily-initialized global encryption-key slot.
fn global_encryption_key_state() -> &'static RwLock<Option<ConfiguredEncryptionKey>> {
    GLOBAL_ENCRYPTION_KEY.get_or_init(Default::default)
}
/// Returns the lazily-initialized global encoder-override slot.
fn global_token_encoder_state() -> &'static RwLock<Option<TokenEncoder>> {
    GLOBAL_TOKEN_ENCODER.get_or_init(Default::default)
}
/// Returns the lazily-initialized global decoder-override slot.
fn global_token_decoder_state() -> &'static RwLock<Option<TokenDecoder>> {
    GLOBAL_TOKEN_DECODER.get_or_init(Default::default)
}
pub struct TokenConfig;
impl TokenConfig {
pub fn set_encryption_key(key: &str) {
let configured_key = ConfiguredEncryptionKey::new(key);
*global_encryption_key_state().write() = Some(configured_key);
}
pub fn get_encryption_key() -> Result<String> {
Self::current_encryption_key()
.map(|configured| configured.raw)
.ok_or_else(|| Error::tokenization("No encryption key configured"))
}
pub(crate) fn get_derived_encryption_key() -> Result<[u8; 32]> {
Self::current_encryption_key()
.map(|configured| configured.derived)
.ok_or_else(|| Error::tokenization("No encryption key configured"))
}
pub fn has_encryption_key() -> bool {
global_encryption_key_state().read().is_some()
}
fn current_encryption_key() -> Option<ConfiguredEncryptionKey> {
global_encryption_key_state().read().clone()
}
pub fn set_encoder(encoder: TokenEncoder) {
*global_token_encoder_state().write() = Some(encoder);
}
pub fn set_decoder(decoder: TokenDecoder) {
*global_token_decoder_state().write() = Some(decoder);
}
pub fn reset() {
*global_encryption_key_state().write() = None;
*global_token_encoder_state().write() = None;
*global_token_decoder_state().write() = None;
}
pub fn get_encoder() -> TokenEncoder {
(*global_token_encoder_state().read()).unwrap_or(default_encode)
}
pub fn get_decoder() -> TokenDecoder {
(*global_token_decoder_state().read()).unwrap_or(default_decode)
}
pub fn encode(record_id: &str, model_name: &str) -> Result<String> {
Self::get_encoder()(record_id, model_name)
}
pub fn decode(token: &str, model_name: &str) -> Result<Option<String>> {
Self::get_decoder()(token, model_name)
}
}
/// Derives a 32-byte XChaCha20-Poly1305 key from the configured passphrase
/// using Argon2id (64 MiB memory, 3 iterations, 1 lane) with a fixed,
/// versioned application salt.
pub(crate) fn derive_encryption_key(key: &str) -> [u8; 32] {
    const DERIVED_KEY_LEN: usize = 32;
    const TOKENIZATION_KDF_SALT: &[u8] = b"tideorm::xchacha20poly1305-key::v2";
    let mut derived = [0u8; DERIVED_KEY_LEN];
    // These parameters are compile-time constants, so both `expect`s encode
    // invariants rather than recoverable failures.
    let params = Params::new(64 * 1024, 3, 1, Some(DERIVED_KEY_LEN))
        .expect("argon2 params for tokenization key derivation should be valid");
    Argon2::new(Algorithm::Argon2id, Version::V0x13, params)
        .hash_password_into(key.as_bytes(), TOKENIZATION_KDF_SALT, &mut derived)
        .expect("argon2 key derivation should succeed with static parameters");
    derived
}
/// Encodes `data` as unpadded base64url (RFC 4648 §5 alphabet, no `=` padding).
pub(crate) fn base64_url_encode(data: &[u8]) -> String {
    const ALPHABET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";
    // Every 3 input bytes become 4 output chars; a partial group adds 2-3 more.
    let mut encoded = String::with_capacity((data.len() + 2) / 3 * 4);
    let mut acc: u32 = 0;
    let mut pending = 0u32;
    for &byte in data {
        acc = (acc << 8) | u32::from(byte);
        pending += 8;
        while pending >= 6 {
            pending -= 6;
            encoded.push(ALPHABET[((acc >> pending) & 0x3F) as usize] as char);
        }
    }
    if pending > 0 {
        // Left-align the leftover bits; the low bits of the final char are zero.
        encoded.push(ALPHABET[((acc << (6 - pending)) & 0x3F) as usize] as char);
    }
    encoded
}
/// Decodes an unpadded base64url string (RFC 4648 §5 alphabet).
///
/// Returns `None` for any character outside the alphabet, for lengths no byte
/// sequence can encode to (length ≡ 1 mod 4), and for non-canonical encodings
/// whose trailing bits are non-zero. The original implementation accepted
/// those malformed inputs, which meant many distinct strings decoded to the
/// same bytes (token malleability); every string produced by
/// `base64_url_encode` still decodes identically.
pub(crate) fn base64_url_decode(encoded: &str) -> Option<Vec<u8>> {
    fn char_to_value(c: char) -> Option<u8> {
        match c {
            'A'..='Z' => Some(c as u8 - b'A'),
            'a'..='z' => Some(c as u8 - b'a' + 26),
            '0'..='9' => Some(c as u8 - b'0' + 52),
            '-' => Some(62),
            '_' => Some(63),
            _ => None,
        }
    }
    let mut result = Vec::with_capacity(encoded.len() * 3 / 4);
    let mut bits = 0u32;
    let mut bit_count = 0u32;
    for c in encoded.chars() {
        let value = char_to_value(c)?;
        bits = (bits << 6) | (value as u32);
        bit_count += 6;
        if bit_count >= 8 {
            bit_count -= 8;
            result.push((bits >> bit_count) as u8);
        }
    }
    // After the loop bit_count is 0, 2, 4, or 6. Six leftover bits mean the
    // length was ≡ 1 mod 4 (impossible for whole bytes); any non-zero leftover
    // bits mean the input was not produced by the canonical encoder.
    if bit_count == 6 || (bits & ((1 << bit_count) - 1)) != 0 {
        return None;
    }
    Some(result)
}
/// Default token encoder: encrypts `record_id` with XChaCha20-Poly1305 using a
/// random 24-byte nonce and the model name as associated data, then emits
/// `base64url(nonce || ciphertext)`.
pub fn default_encode(record_id: &str, model_name: &str) -> Result<String> {
    let key = TokenConfig::get_derived_encryption_key()?;
    let cipher = XChaCha20Poly1305::new((&key).into());
    // Fresh random nonce per token; XChaCha20's 192-bit nonce makes random
    // generation collision-safe in practice.
    let nonce_bytes: [u8; 24] = random();
    let payload = Payload {
        msg: record_id.as_bytes(),
        aad: model_name.as_bytes(),
    };
    let ciphertext = cipher
        .encrypt(XNonce::from_slice(&nonce_bytes), payload)
        .map_err(|_| Error::tokenization("Failed to encrypt token payload"))?;
    let token_data = [nonce_bytes.as_slice(), ciphertext.as_slice()].concat();
    Ok(base64_url_encode(&token_data))
}
/// Default token decoder: the inverse of [`default_encode`].
///
/// Returns `Ok(None)` (rather than an error) for anything that is not a valid
/// token: bad base64url, too-short data, failed authentication against the
/// model-name AAD, or non-UTF-8 plaintext. Only a missing key is an `Err`.
pub fn default_decode(token: &str, model_name: &str) -> Result<Option<String>> {
    const NONCE_LEN: usize = 24;
    let key = TokenConfig::get_derived_encryption_key()?;
    let cipher = XChaCha20Poly1305::new((&key).into());
    let token_data = match base64_url_decode(token) {
        // Must hold strictly more than the nonce to carry any ciphertext.
        Some(data) if data.len() > NONCE_LEN => data,
        _ => return Ok(None),
    };
    let (nonce_bytes, ciphertext) = token_data.split_at(NONCE_LEN);
    let payload = Payload {
        msg: ciphertext,
        aad: model_name.as_bytes(),
    };
    let Ok(plaintext) = cipher.decrypt(XNonce::from_slice(nonce_bytes), payload) else {
        return Ok(None);
    };
    Ok(String::from_utf8(plaintext).ok())
}
/// Models that can expose their primary key as an opaque, encrypted token.
#[async_trait::async_trait]
pub trait Tokenizable: Sized + Send + Sync {
    /// The primary-key type serialized (as JSON) into the token payload.
    type TokenPrimaryKey: Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static;
    /// Stable model identifier, passed to the encoder/decoder alongside the payload.
    fn token_model_name() -> &'static str;
    /// The primary key of this record.
    fn token_primary_key(&self) -> Self::TokenPrimaryKey;
    /// Whether tokenization is enabled for this model (default: yes).
    fn tokenization_enabled() -> bool {
        true
    }
    /// Optional per-model encoder override; `None` falls back to the global one.
    fn token_encoder() -> Option<TokenEncoder> {
        None
    }
    /// Optional per-model decoder override; `None` falls back to the global one.
    fn token_decoder() -> Option<TokenDecoder> {
        None
    }
    /// Encodes this record's primary key into a token.
    fn to_token(&self) -> Result<String> {
        // Delegate to `tokenize_id` so the enabled check, encoder selection,
        // and serialization live in exactly one place (the original duplicated
        // that logic verbatim here).
        Self::tokenize_id(self.token_primary_key())
    }
    /// Alias for [`Tokenizable::to_token`].
    fn tokenize(&self) -> Result<String> {
        self.to_token()
    }
    /// Encodes an arbitrary primary key into a token without needing a record.
    fn tokenize_id(id: Self::TokenPrimaryKey) -> Result<String> {
        if !Self::tokenization_enabled() {
            return Err(Error::tokenization(
                "Tokenization is not enabled for this model",
            ));
        }
        let encoder = Self::token_encoder().unwrap_or_else(TokenConfig::get_encoder);
        let payload = serde_json::to_string(&id).map_err(|error| {
            Error::tokenization(format!("Failed to serialize token primary key: {error}"))
        })?;
        encoder(&payload, Self::token_model_name())
    }
    /// Loads the record identified by `token`.
    async fn from_token(token: &str) -> Result<Self>;
    /// Alias for [`Tokenizable::decode_token`].
    fn detokenize(token: &str) -> Result<Self::TokenPrimaryKey> {
        Self::decode_token(token)
    }
    /// Decodes `token` back into the primary key it encodes.
    fn decode_token(token: &str) -> Result<Self::TokenPrimaryKey> {
        if !Self::tokenization_enabled() {
            return Err(Error::tokenization(
                "Tokenization is not enabled for this model",
            ));
        }
        let decoder = Self::token_decoder().unwrap_or_else(TokenConfig::get_decoder);
        let payload = decoder(token, Self::token_model_name())?
            .ok_or_else(|| Error::invalid_token("Failed to decode token"))?;
        serde_json::from_str::<Self::TokenPrimaryKey>(&payload).map_err(|error| {
            Error::invalid_token(format!(
                "Failed to deserialize decoded token payload '{}' for model {}: {}",
                payload,
                Self::token_model_name(),
                error
            ))
        })
    }
    /// Issues a fresh token (new random nonce by default) for the same primary key.
    fn regenerate_token(&self) -> Result<String> {
        self.to_token()
    }
}
#[cfg(test)]
#[path = "testing/tokenization_tests.rs"]
mod tests;