use std::sync::OnceLock;
use crate::error::{Error, Result};
/// Signature of a pluggable token encoder: maps a record id plus model name
/// to an opaque token string.
pub type TokenEncoder = fn(record_id: i64, model_name: &str) -> Result<String>;
/// Signature of a pluggable token decoder: recovers the record id from a
/// token, or `None` when the token is invalid for the given model.
pub type TokenDecoder = fn(token: &str, model_name: &str) -> Option<i64>;
// Process-wide configuration. OnceLock makes each setter first-write-wins:
// later `set` calls are silently ignored.
static GLOBAL_ENCRYPTION_KEY: OnceLock<String> = OnceLock::new();
static GLOBAL_TOKEN_ENCODER: OnceLock<TokenEncoder> = OnceLock::new();
static GLOBAL_TOKEN_DECODER: OnceLock<TokenDecoder> = OnceLock::new();
/// Process-wide token configuration: the encryption key plus optional
/// custom encoder/decoder hooks. All setters are first-write-wins because
/// the backing storage is `OnceLock`.
pub struct TokenConfig;

impl TokenConfig {
    /// Install the global encryption key. Only the first call has any effect.
    pub fn set_encryption_key(key: &str) {
        // `set` errors when already initialized; ignoring that is intentional.
        let _ = GLOBAL_ENCRYPTION_KEY.set(key.to_string());
    }

    /// The configured key, or the built-in development fallback when no key
    /// has been set.
    pub fn get_encryption_key() -> String {
        match GLOBAL_ENCRYPTION_KEY.get() {
            Some(key) => key.clone(),
            None => "tideorm-default-dev-key-32bytes!".to_string(),
        }
    }

    /// Whether an explicit key was configured (i.e. the default is not in use).
    pub fn has_encryption_key() -> bool {
        matches!(GLOBAL_ENCRYPTION_KEY.get(), Some(_))
    }

    /// Install a global custom encoder. Only the first call has any effect.
    pub fn set_encoder(encoder: TokenEncoder) {
        let _ = GLOBAL_TOKEN_ENCODER.set(encoder);
    }

    /// Install a global custom decoder. Only the first call has any effect.
    pub fn set_decoder(decoder: TokenDecoder) {
        let _ = GLOBAL_TOKEN_DECODER.set(decoder);
    }

    /// The active encoder: the configured one, else [`default_encode`].
    pub fn get_encoder() -> TokenEncoder {
        match GLOBAL_TOKEN_ENCODER.get() {
            Some(&encoder) => encoder,
            None => default_encode,
        }
    }

    /// The active decoder: the configured one, else [`default_decode`].
    pub fn get_decoder() -> TokenDecoder {
        match GLOBAL_TOKEN_DECODER.get() {
            Some(&decoder) => decoder,
            None => default_decode,
        }
    }

    /// Encode `record_id` for `model_name` with the active encoder.
    pub fn encode(record_id: i64, model_name: &str) -> Result<String> {
        let encoder = Self::get_encoder();
        encoder(record_id, model_name)
    }

    /// Decode `token` for `model_name` with the active decoder.
    pub fn decode(token: &str, model_name: &str) -> Option<i64> {
        let decoder = Self::get_decoder();
        decoder(token, model_name)
    }
}
/// XOR `data` against a repeating `key`. The operation is its own inverse:
/// applying the same key twice restores the original bytes, so this serves
/// as both "encrypt" and "decrypt".
///
/// NOTE(review): repeating-key XOR is obfuscation, not real encryption.
/// Panics if `key` is empty and `data` is not (modulo by zero), same as
/// the indexing form has always done.
fn xor_encrypt(data: &[u8], key: &[u8]) -> Vec<u8> {
    (0..data.len())
        .map(|i| data[i] ^ key[i % key.len()])
        .collect()
}
/// Derive an 8-byte authentication tag from `data` and `key`.
///
/// NOTE(review): this is not a real HMAC — it feeds both inputs through
/// `DefaultHasher`, whose algorithm the standard library explicitly does
/// not guarantee to be stable across Rust releases. Tags are deterministic
/// within one compiled binary but may change after a toolchain upgrade,
/// and the construction provides no cryptographic authenticity. Confirm
/// this trade-off is acceptable for persisted tokens.
fn compute_hmac(data: &[u8], key: &[u8]) -> [u8; 8] {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};
    let mut state = DefaultHasher::new();
    data.hash(&mut state);
    key.hash(&mut state);
    // Big-endian so the byte order of the tag is platform-independent.
    state.finish().to_be_bytes()
}
/// Deterministically derive a 16-byte IV from `key` and `model_name`.
///
/// Because the derivation is deterministic, the same (key, model) pair
/// always yields the same IV — tokens are stable and repeatable rather
/// than randomized. Built from two chained `DefaultHasher` passes, so the
/// same stability caveat as `compute_hmac` applies.
fn generate_iv(key: &[u8], model_name: &str) -> [u8; 16] {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    // High half: hash of (key, model_name).
    let mut first_pass = DefaultHasher::new();
    key.hash(&mut first_pass);
    model_name.hash(&mut first_pass);
    let high = first_pass.finish();

    // Low half: hash of (high half, reversed key) for extra diffusion.
    let mut second_pass = DefaultHasher::new();
    high.hash(&mut second_pass);
    key.iter().rev().collect::<Vec<_>>().hash(&mut second_pass);
    let low = second_pass.finish();

    let mut iv = [0u8; 16];
    iv[..8].copy_from_slice(&high.to_be_bytes());
    iv[8..].copy_from_slice(&low.to_be_bytes());
    iv
}
/// Encode bytes as unpadded URL-safe base64 (RFC 4648 §5 alphabet, no `=`).
///
/// Streams 8-bit input into a bit accumulator and emits 6-bit groups; any
/// leftover bits are left-aligned into one final character.
fn base64_url_encode(data: &[u8]) -> String {
    const ALPHABET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";
    // Every 3 input bytes become 4 output chars (round up, no padding):
    // preallocating avoids repeated reallocation while pushing.
    let mut result = String::with_capacity((data.len() + 2) / 3 * 4);
    let mut bits = 0u32;
    let mut bit_count = 0;
    for &byte in data {
        bits = (bits << 8) | (byte as u32);
        bit_count += 8;
        // Drain complete 6-bit groups from the top of the accumulator.
        while bit_count >= 6 {
            bit_count -= 6;
            result.push(ALPHABET[((bits >> bit_count) & 0x3F) as usize] as char);
        }
    }
    if bit_count > 0 {
        // Left-align the trailing partial group (zero-filled low bits).
        bits <<= 6 - bit_count;
        result.push(ALPHABET[(bits & 0x3F) as usize] as char);
    }
    result
}
/// Decode an unpadded URL-safe base64 string (RFC 4648 §5 alphabet).
///
/// Returns `None` if any character is outside the alphabet. Trailing bits
/// that do not complete a byte are discarded, matching the unpadded output
/// of `base64_url_encode`. The empty string decodes to an empty `Vec`.
fn base64_url_decode(encoded: &str) -> Option<Vec<u8>> {
    // Map one base64url character back to its 6-bit value.
    fn char_to_value(c: char) -> Option<u8> {
        match c {
            'A'..='Z' => Some(c as u8 - b'A'),
            'a'..='z' => Some(c as u8 - b'a' + 26),
            '0'..='9' => Some(c as u8 - b'0' + 52),
            '-' => Some(62),
            '_' => Some(63),
            _ => None,
        }
    }
    // Every 4 input chars yield 3 bytes; preallocate the exact output size
    // to avoid repeated reallocation while pushing.
    let mut result = Vec::with_capacity(encoded.len() * 3 / 4);
    let mut bits = 0u32;
    let mut bit_count = 0;
    for c in encoded.chars() {
        let value = char_to_value(c)?;
        bits = (bits << 6) | (value as u32);
        bit_count += 6;
        // Emit a byte as soon as 8 bits have accumulated.
        if bit_count >= 8 {
            bit_count -= 8;
            result.push((bits >> bit_count) as u8);
        }
    }
    Some(result)
}
/// Default token encoder: `IV ‖ XORed id ‖ tag`, base64url-encoded.
///
/// Token layout (32 bytes before encoding):
///   bytes  0..16  deterministic IV derived from (key, model_name)
///   bytes 16..24  record id (big-endian i64) XORed with key ‖ IV
///   bytes 24..32  integrity tag over IV ‖ ciphertext ‖ model_name
///
/// Deterministic: the same (id, model, key) always yields the same token.
pub fn default_encode(record_id: i64, model_name: &str) -> Result<String> {
    let key = TokenConfig::get_encryption_key();
    let key_bytes = key.as_bytes();
    let iv = generate_iv(key_bytes, model_name);

    // XOR stream key = configured key followed by the IV, so tokens for
    // different models diverge even for the same record id.
    let mut combined_key = key_bytes.to_vec();
    combined_key.extend_from_slice(&iv);
    let encrypted = xor_encrypt(&record_id.to_be_bytes(), &combined_key);

    // The tag binds the model name so a token cannot cross models.
    let mut hmac_data = Vec::new();
    hmac_data.extend_from_slice(&iv);
    hmac_data.extend_from_slice(&encrypted);
    hmac_data.extend_from_slice(model_name.as_bytes());
    let hmac = compute_hmac(&hmac_data, key_bytes);

    // Assemble IV ‖ ciphertext ‖ tag (16 + 8 + 8 = 32 bytes).
    let mut token_data = Vec::with_capacity(32);
    for part in [&iv[..], &encrypted[..], &hmac[..]] {
        token_data.extend_from_slice(part);
    }
    Ok(base64_url_encode(&token_data))
}
/// Default token decoder: the inverse of [`default_encode`].
///
/// Returns `None` for malformed base64, wrong length, a failed integrity
/// check, or a token minted for a different model.
pub fn default_decode(token: &str, model_name: &str) -> Option<i64> {
    let key = TokenConfig::get_encryption_key();
    let key_bytes = key.as_bytes();
    let token_data = base64_url_decode(token)?;
    // A well-formed token is exactly IV(16) + ciphertext(8) + tag(8).
    // Require the exact length (previously only `< 32`): extra trailing
    // bytes were silently ignored, so appending garbage to a valid token
    // produced another accepted token. Legitimate tokens are always exactly
    // 32 bytes, so this change rejects only forged/malformed input.
    if token_data.len() != 32 {
        return None;
    }
    let iv = &token_data[0..16];
    let encrypted = &token_data[16..24];
    let provided_hmac = &token_data[24..32];
    // Recompute the tag over IV ‖ ciphertext ‖ model name; a mismatch means
    // tampering or a token issued for another model.
    let mut hmac_data = Vec::new();
    hmac_data.extend_from_slice(iv);
    hmac_data.extend_from_slice(encrypted);
    hmac_data.extend_from_slice(model_name.as_bytes());
    let computed_hmac = compute_hmac(&hmac_data, key_bytes);
    // NOTE(review): not a constant-time comparison; tolerable only because
    // the scheme is obfuscation rather than real cryptography.
    if provided_hmac != computed_hmac {
        return None;
    }
    // XOR with the same key ‖ IV stream recovers the big-endian id bytes.
    let mut combined_key = Vec::with_capacity(key_bytes.len() + iv.len());
    combined_key.extend_from_slice(key_bytes);
    combined_key.extend_from_slice(iv);
    let decrypted = xor_encrypt(encrypted, &combined_key);
    // `encrypted` is always 8 bytes here, so this conversion cannot fail;
    // `ok()?` keeps the function total regardless.
    let id_bytes: [u8; 8] = decrypted.try_into().ok()?;
    Some(i64::from_be_bytes(id_bytes))
}
#[async_trait::async_trait]
pub trait Tokenizable: Sized + Send + Sync {
fn token_model_name() -> &'static str;
fn token_primary_key(&self) -> i64;
fn tokenization_enabled() -> bool {
true
}
fn token_encoder() -> Option<TokenEncoder> {
None
}
fn token_decoder() -> Option<TokenDecoder> {
None
}
fn to_token(&self) -> Result<String> {
if !Self::tokenization_enabled() {
return Err(Error::tokenization("Tokenization is not enabled for this model"));
}
let encoder = Self::token_encoder()
.unwrap_or_else(TokenConfig::get_encoder);
encoder(self.token_primary_key(), Self::token_model_name())
}
fn tokenize(&self) -> Result<String> {
self.to_token()
}
fn tokenize_id(id: i64) -> Result<String> {
if !Self::tokenization_enabled() {
return Err(Error::tokenization("Tokenization is not enabled for this model"));
}
let encoder = Self::token_encoder()
.unwrap_or_else(TokenConfig::get_encoder);
encoder(id, Self::token_model_name())
}
async fn from_token(token: &str) -> Result<Self>;
fn detokenize(token: &str) -> Result<i64> {
Self::decode_token(token)
}
fn decode_token(token: &str) -> Result<i64> {
if !Self::tokenization_enabled() {
return Err(Error::tokenization("Tokenization is not enabled for this model"));
}
let decoder = Self::token_decoder()
.unwrap_or_else(TokenConfig::get_decoder);
decoder(token, Self::token_model_name())
.ok_or_else(|| Error::invalid_token("Failed to decode token"))
}
fn regenerate_token(&self) -> Result<String> {
self.to_token()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // -- base64url ---------------------------------------------------------

    #[test]
    fn test_base64_url_encode_decode() {
        let original = b"Hello, World!";
        let roundtrip = base64_url_decode(&base64_url_encode(original)).unwrap();
        assert_eq!(original.to_vec(), roundtrip);
    }

    #[test]
    fn test_base64_url_various_lengths() {
        // Exercise every partial-group case (len mod 3 = 0, 1, 2).
        for len in 1..=32 {
            let input: Vec<u8> = (0..len).map(|i| i as u8).collect();
            let restored = base64_url_decode(&base64_url_encode(&input)).unwrap();
            assert_eq!(input, restored, "Failed for length {}", len);
        }
    }

    // -- default encoder/decoder round trips -------------------------------

    #[test]
    fn test_default_encode_decode() {
        let token = default_encode(12345i64, "User").unwrap();
        assert_eq!(default_decode(&token, "User"), Some(12345i64));
    }

    #[test]
    fn test_encode_decode_negative_id() {
        let token = default_encode(-99999i64, "NegativeModel").unwrap();
        assert_eq!(default_decode(&token, "NegativeModel"), Some(-99999i64));
    }

    #[test]
    fn test_encode_decode_zero() {
        let token = default_encode(0i64, "ZeroModel").unwrap();
        assert_eq!(default_decode(&token, "ZeroModel"), Some(0i64));
    }

    #[test]
    fn test_encode_decode_max_i64() {
        let token = default_encode(i64::MAX, "MaxModel").unwrap();
        assert_eq!(default_decode(&token, "MaxModel"), Some(i64::MAX));
    }

    // -- rejection paths ----------------------------------------------------

    #[test]
    fn test_wrong_model_fails() {
        let token = default_encode(42i64, "User").unwrap();
        assert_eq!(default_decode(&token, "Product"), None);
    }

    #[test]
    fn test_tampered_token_fails() {
        let token = default_encode(42i64, "User").unwrap();
        // Flip one character in the middle of the token.
        let mut symbols: Vec<char> = token.chars().collect();
        if let Some(c) = symbols.get_mut(10) {
            *c = if *c == 'A' { 'B' } else { 'A' };
        }
        let tampered: String = symbols.into_iter().collect();
        assert_eq!(default_decode(&tampered, "User"), None);
    }

    #[test]
    fn test_invalid_base64_fails() {
        assert_eq!(default_decode("not-valid-base64!!!", "User"), None);
    }

    #[test]
    fn test_too_short_token_fails() {
        assert_eq!(default_decode("abc", "User"), None);
    }

    // -- token properties ----------------------------------------------------

    #[test]
    fn test_token_is_url_safe() {
        let token = default_encode(999999999i64, "User").unwrap();
        let url_safe = |c: char| c.is_ascii_alphanumeric() || c == '-' || c == '_';
        assert!(token.chars().all(url_safe));
    }

    #[test]
    fn test_different_ids_different_tokens() {
        assert_ne!(
            default_encode(1, "User").unwrap(),
            default_encode(2, "User").unwrap()
        );
    }

    #[test]
    fn test_same_id_same_token() {
        // Encoding is deterministic: same inputs, same token.
        assert_eq!(
            default_encode(42, "User").unwrap(),
            default_encode(42, "User").unwrap()
        );
    }

    #[test]
    fn test_xor_encrypt_decrypt() {
        let plaintext = b"test data";
        let secret = b"secret key";
        let roundtrip = xor_encrypt(&xor_encrypt(plaintext, secret), secret);
        assert_eq!(plaintext.to_vec(), roundtrip);
    }

    #[test]
    fn test_token_config_encode_decode() {
        let token = TokenConfig::encode(123, "TestModel").unwrap();
        assert_eq!(TokenConfig::decode(&token, "TestModel"), Some(123));
    }
}