use crate::crypto::{derive_key_from_password, encrypt, decrypt, EncryptionKey};
use crate::{Error, Result};
use sha2::{Sha256, Digest};
use std::collections::HashSet;
use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
use zeroize::Zeroizing;
/// Zero-knowledge encryption mode.
///
/// The default is [`ZkeMode::PerRequest`]. NOTE(review): the exact semantics of
/// each variant are defined by code outside this section — TODO document them.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ZkeMode {
    Full,
    Hybrid,
    PerRequest,
}

impl Default for ZkeMode {
    /// Returns [`ZkeMode::PerRequest`].
    fn default() -> Self {
        ZkeMode::PerRequest
    }
}
/// Configuration for zero-knowledge request handling.
#[derive(Debug, Clone)]
pub struct ZkeConfig {
    /// Selected [`ZkeMode`].
    pub mode: ZkeMode,
    /// When `true`, `ZkeRequestContext::validate` requires a matching key hash.
    pub require_key_hash: bool,
    /// When `true`, `ZkeRequestContext::validate` requires a nonce and checks
    /// it against a `NonceTracker`.
    pub replay_protection: bool,
    /// Window in seconds. `ZkeRequestContext::new` also uses it as the maximum
    /// accepted clock skew for request timestamps.
    pub nonce_window_secs: u64,
    /// Upper bound on cached nonces. NOTE(review): not read in this section;
    /// presumably passed to `NonceTracker::new` by the caller — confirm.
    pub max_cached_nonces: usize,
}

impl Default for ZkeConfig {
    /// Per-request mode with key-hash checks and replay protection enabled,
    /// a 300-second window and room for 10,000 cached nonces.
    fn default() -> Self {
        ZkeConfig {
            mode: ZkeMode::PerRequest,
            require_key_hash: true,
            replay_protection: true,
            nonce_window_secs: 300,
            max_cached_nonces: 10_000,
        }
    }
}
/// Key material produced by `ZkeKeyDerivation::derive_keys`.
///
/// Both keys are held in [`Zeroizing`] wrappers, so their bytes are wiped when
/// dropped. The type does not derive `Debug`, which keeps key bytes out of
/// debug output.
#[derive(Clone)]
pub struct ZkeDerivedKeys {
    /// Key derived with the `auth` salt label.
    pub auth_key: Zeroizing<EncryptionKey>,
    /// Key derived with the `encrypt` salt label.
    pub encryption_key: Zeroizing<EncryptionKey>,
    /// SHA-256 of `encryption_key`.
    pub encryption_key_hash: [u8; 32],
}

impl ZkeDerivedKeys {
    /// Returns `encryption_key_hash` encoded as a hex string.
    pub fn key_hash_hex(&self) -> String {
        let hash_bytes: &[u8; 32] = &self.encryption_key_hash;
        hex::encode(hash_bytes)
    }
}
/// Password-based derivation of the ZKE authentication and encryption keys.
pub struct ZkeKeyDerivation;

impl ZkeKeyDerivation {
    /// Derives an authentication key, an encryption key and the SHA-256 hash of
    /// the encryption key from `password`.
    ///
    /// Both derivations share a salt computed from `identifier`. The two keys
    /// are separated by appending a label (`b"auth"` / `b"encrypt"`) to that
    /// salt before calling `derive_key_from_password`.
    ///
    /// # Errors
    /// Propagates any error returned by `derive_key_from_password`.
    pub fn derive_keys(password: &str, identifier: &str) -> Result<ZkeDerivedKeys> {
        let shared_salt = Self::create_salt(identifier);
        let auth = Self::derive_labeled(password, &shared_salt, b"auth")?;
        let enc = Self::derive_labeled(password, &shared_salt, b"encrypt")?;
        let enc_hash = Self::compute_key_hash(&enc);
        Ok(ZkeDerivedKeys {
            auth_key: Zeroizing::new(auth),
            encryption_key: Zeroizing::new(enc),
            encryption_key_hash: enc_hash,
        })
    }

    /// Runs the password KDF with `base_salt || label` as its salt.
    fn derive_labeled(password: &str, base_salt: &[u8], label: &[u8]) -> Result<EncryptionKey> {
        let mut labeled_salt = Vec::with_capacity(base_salt.len() + label.len());
        labeled_salt.extend_from_slice(base_salt);
        labeled_salt.extend_from_slice(label);
        derive_key_from_password(password, &labeled_salt)
    }

    /// Base salt: `SHA-256(identifier || "heliosdb.zke.salt")`.
    fn create_salt(identifier: &str) -> Vec<u8> {
        let mut sha = Sha256::new();
        sha.update(identifier.as_bytes());
        sha.update(b"heliosdb.zke.salt");
        let digest = sha.finalize();
        digest.to_vec()
    }

    /// Returns the SHA-256 digest of `key` as a fixed-size array.
    pub fn compute_key_hash(key: &EncryptionKey) -> [u8; 32] {
        let digest = {
            let mut sha = Sha256::new();
            sha.update(key);
            sha.finalize()
        };
        let mut out = [0u8; 32];
        out.copy_from_slice(digest.as_slice());
        out
    }
}
pub struct ZeroKnowledgeSession {
key: Zeroizing<EncryptionKey>,
key_hash: [u8; 32],
created_at: Instant,
nonce: Option<[u8; 16]>,
}
impl ZeroKnowledgeSession {
pub fn new(key: EncryptionKey) -> Result<Self> {
let key_hash = ZkeKeyDerivation::compute_key_hash(&key);
Ok(Self {
key: Zeroizing::new(key),
key_hash,
created_at: Instant::now(),
nonce: None,
})
}
pub fn from_hex_key(hex_key: &str) -> Result<Self> {
let key = Self::parse_hex_key(hex_key)?;
Self::new(key)
}
pub fn from_derived_keys(keys: &ZkeDerivedKeys) -> Result<Self> {
Ok(Self {
key: keys.encryption_key.clone(),
key_hash: keys.encryption_key_hash,
created_at: Instant::now(),
nonce: None,
})
}
pub fn with_nonce(mut self, nonce: [u8; 16]) -> Self {
self.nonce = Some(nonce);
self
}
pub fn with_random_nonce(mut self) -> Self {
self.nonce = Some(rand::random());
self
}
pub fn nonce(&self) -> Option<&[u8; 16]> {
self.nonce.as_ref()
}
pub fn nonce_hex(&self) -> Option<String> {
self.nonce.map(|n| hex::encode(n))
}
pub fn encrypt(&self, plaintext: &[u8]) -> Result<Vec<u8>> {
encrypt(&self.key, plaintext)
}
pub fn decrypt(&self, ciphertext: &[u8]) -> Result<Vec<u8>> {
decrypt(&self.key, ciphertext)
}
pub fn key(&self) -> &EncryptionKey {
&self.key
}
pub fn key_hash(&self) -> &[u8; 32] {
&self.key_hash
}
pub fn key_hash_hex(&self) -> String {
hex::encode(self.key_hash)
}
pub fn validate_key_hash(&self, expected_hash: &[u8; 32]) -> bool {
constant_time_compare(&self.key_hash, expected_hash)
}
pub fn validate_key_hash_hex(&self, expected_hash_hex: &str) -> Result<bool> {
let expected = hex::decode(expected_hash_hex)
.map_err(|e| Error::encryption(format!("Invalid hash hex: {}", e)))?;
if expected.len() != 32 {
return Err(Error::encryption("Hash must be 32 bytes"));
}
let mut hash = [0u8; 32];
hash.copy_from_slice(&expected);
Ok(self.validate_key_hash(&hash))
}
pub fn age_secs(&self) -> u64 {
self.created_at.elapsed().as_secs()
}
fn parse_hex_key(hex_str: &str) -> Result<EncryptionKey> {
let hex_str = hex_str.trim();
if hex_str.len() != 64 {
return Err(Error::encryption(format!(
"Hex key must be 64 characters (32 bytes), got {}",
hex_str.len()
)));
}
let mut key = [0u8; 32];
for (i, chunk) in hex_str.as_bytes().chunks(2).enumerate() {
let hex_byte = std::str::from_utf8(chunk)
.map_err(|_| Error::encryption("Invalid hex string"))?;
let dest = key.get_mut(i).ok_or_else(|| Error::encryption("Key index out of bounds"))?;
*dest = u8::from_str_radix(hex_byte, 16)
.map_err(|_| Error::encryption(format!("Invalid hex byte: {}", hex_byte)))?;
}
Ok(key)
}
}
/// Compares two byte slices for equality without short-circuiting on the
/// first differing byte.
///
/// Slices of different lengths return `false` immediately, so only the
/// contents, not the lengths, are compared in a data-independent way.
fn constant_time_compare(a: &[u8], b: &[u8]) -> bool {
    if a.len() == b.len() {
        let accumulated = a
            .iter()
            .zip(b.iter())
            .fold(0u8, |acc, (&left, &right)| acc | (left ^ right));
        accumulated == 0
    } else {
        false
    }
}
/// Remembers recently seen 16-byte nonces so replayed requests can be rejected.
///
/// A nonce is remembered for `window` after it is first recorded. The number
/// of remembered nonces is capped at `max_nonces` (or 1 if `max_nonces` is 0).
/// When the cap is reached, the oldest entries are evicted early.
/// NOTE(review): an evicted nonce that is still inside the window would be
/// accepted again, so `max_nonces` should cover the expected request volume
/// over one window.
pub struct NonceTracker {
    nonces: RwLock<HashSet<[u8; 16]>>,
    /// `(first_seen, nonce)` pairs in insertion order (oldest first). Entries
    /// are pushed with `Instant::now()` while this lock is held, so the
    /// timestamps are non-decreasing.
    expiry: RwLock<Vec<(Instant, [u8; 16])>>,
    window: Duration,
    max_nonces: usize,
}

impl NonceTracker {
    /// Creates a tracker that remembers nonces for `window_secs` seconds and
    /// holds at most about `max_nonces` of them.
    pub fn new(window_secs: u64, max_nonces: usize) -> Self {
        Self {
            nonces: RwLock::new(HashSet::new()),
            expiry: RwLock::new(Vec::new()),
            window: Duration::from_secs(window_secs),
            max_nonces,
        }
    }

    /// Returns `true` and records `nonce` if it is not currently remembered.
    /// Returns `false` if it is (a replay).
    ///
    /// Expired entries are purged first. If the tracker is full, the oldest
    /// entries are evicted only when a *new* nonce is about to be recorded.
    /// A replay is therefore always checked against the full set and cannot
    /// evict itself and be accepted.
    pub fn check_and_record(&self, nonce: &[u8; 16]) -> bool {
        // Lock order (nonces, then expiry) matches every other method. Purge,
        // duplicate check and insert all happen in one critical section.
        let mut nonces = self.nonces.write().unwrap_or_else(|e| e.into_inner());
        let mut expiry = self.expiry.write().unwrap_or_else(|e| e.into_inner());
        self.cleanup_expired(&mut nonces, &mut expiry);
        if nonces.contains(nonce) {
            return false;
        }
        if nonces.len() >= self.max_nonces {
            self.force_cleanup(&mut nonces, &mut expiry);
        }
        nonces.insert(*nonce);
        expiry.push((Instant::now(), *nonce));
        true
    }

    /// Drops every entry first seen more than `window` ago.
    fn cleanup_expired(
        &self,
        nonces: &mut HashSet<[u8; 16]>,
        expiry: &mut Vec<(Instant, [u8; 16])>,
    ) {
        let now = Instant::now();
        // `expiry` is oldest-first, so the expired entries form a prefix.
        let expired = expiry.partition_point(|(seen, _)| now.duration_since(*seen) > self.window);
        for (_, old) in expiry.drain(..expired) {
            nonces.remove(&old);
        }
    }

    /// Evicts the oldest ~10% of capacity, and always at least one entry.
    /// With plain `max_nonces / 10`, any capacity below 10 would evict nothing
    /// and the set would grow without bound.
    fn force_cleanup(
        &self,
        nonces: &mut HashSet<[u8; 16]>,
        expiry: &mut Vec<(Instant, [u8; 16])>,
    ) {
        let remove_count = (self.max_nonces / 10).max(1).min(expiry.len());
        for (_, old) in expiry.drain(..remove_count) {
            nonces.remove(&old);
        }
    }

    /// Number of remembered nonces. This may include expired entries that
    /// have not been purged yet; purging happens in `check_and_record`.
    pub fn len(&self) -> usize {
        self.nonces.read().unwrap_or_else(|e| e.into_inner()).len()
    }

    /// Returns `true` when no nonces are remembered.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}

impl Default for NonceTracker {
    /// 300-second window, at most 10,000 nonces.
    fn default() -> Self {
        Self::new(300, 10000)
    }
}
/// Accepts Unix timestamps (in seconds) that lie within `max_skew_secs` of the
/// local system clock, in either direction.
pub struct TimestampValidator {
    max_skew_secs: u64,
}

impl TimestampValidator {
    /// Creates a validator allowing up to `max_skew_secs` of clock difference.
    pub fn new(max_skew_secs: u64) -> Self {
        TimestampValidator { max_skew_secs }
    }

    /// Returns `true` when `request_timestamp_secs` is at most `max_skew_secs`
    /// away from the current time (past or future).
    pub fn validate(&self, request_timestamp_secs: u64) -> bool {
        let server_now = Self::current_timestamp();
        let skew = if request_timestamp_secs >= server_now {
            request_timestamp_secs - server_now
        } else {
            server_now - request_timestamp_secs
        };
        skew <= self.max_skew_secs
    }

    /// Current Unix time in whole seconds, or 0 if the system clock reads
    /// earlier than the Unix epoch.
    pub fn current_timestamp() -> u64 {
        match SystemTime::now().duration_since(UNIX_EPOCH) {
            Ok(since_epoch) => since_epoch.as_secs(),
            Err(_) => 0,
        }
    }
}

impl Default for TimestampValidator {
    /// Allows 300 seconds of skew.
    fn default() -> Self {
        TimestampValidator::new(300)
    }
}
pub struct ZkeRequestContext {
session: ZeroKnowledgeSession,
nonce_tracker: Arc<NonceTracker>,
timestamp_validator: TimestampValidator,
config: ZkeConfig,
}
impl ZkeRequestContext {
pub fn new(
session: ZeroKnowledgeSession,
nonce_tracker: Arc<NonceTracker>,
config: ZkeConfig,
) -> Self {
Self {
session,
nonce_tracker,
timestamp_validator: TimestampValidator::new(config.nonce_window_secs),
config,
}
}
pub fn validate(
&self,
expected_key_hash: Option<&str>,
nonce: Option<&[u8; 16]>,
timestamp: Option<u64>,
) -> Result<()> {
if self.config.require_key_hash {
if let Some(hash) = expected_key_hash {
if !self.session.validate_key_hash_hex(hash)? {
return Err(Error::encryption("Key hash validation failed"));
}
} else {
return Err(Error::encryption("Key hash required but not provided"));
}
}
if self.config.replay_protection {
if let Some(n) = nonce {
if !self.nonce_tracker.check_and_record(n) {
return Err(Error::encryption("Replay attack detected: nonce already used"));
}
} else {
return Err(Error::encryption("Nonce required for replay protection"));
}
if let Some(ts) = timestamp {
if !self.timestamp_validator.validate(ts) {
return Err(Error::encryption("Request timestamp out of valid range"));
}
}
}
Ok(())
}
pub fn session(&self) -> &ZeroKnowledgeSession {
&self.session
}
pub fn encrypt(&self, plaintext: &[u8]) -> Result<Vec<u8>> {
self.session.encrypt(plaintext)
}
pub fn decrypt(&self, ciphertext: &[u8]) -> Result<Vec<u8>> {
self.session.decrypt(ciphertext)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // Shared helpers. Each one panics (failing the calling test) if the
    // underlying call returns an error.

    fn derive(password: &str, identifier: &str) -> ZkeDerivedKeys {
        ZkeKeyDerivation::derive_keys(password, identifier).expect("Key derivation failed")
    }

    fn session_from(keys: &ZkeDerivedKeys) -> ZeroKnowledgeSession {
        ZeroKnowledgeSession::from_derived_keys(keys).unwrap()
    }

    fn assert_round_trip(session: &ZeroKnowledgeSession, plaintext: &[u8]) {
        let sealed = session.encrypt(plaintext).unwrap();
        let opened = session.decrypt(&sealed).unwrap();
        assert_eq!(plaintext, &opened[..]);
    }

    #[test]
    fn test_derive_keys() {
        let keys = derive("password123", "user@example.com");
        assert_eq!(keys.auth_key.len(), 32);
        assert_eq!(keys.encryption_key.len(), 32);
        assert_eq!(keys.encryption_key_hash.len(), 32);
        // The "auth" and "encrypt" salt labels must yield distinct keys.
        assert_ne!(*keys.auth_key, *keys.encryption_key);
    }

    #[test]
    fn test_derive_keys_deterministic() {
        let first = derive("password", "user@test.com");
        let second = derive("password", "user@test.com");
        assert_eq!(*first.auth_key, *second.auth_key);
        assert_eq!(*first.encryption_key, *second.encryption_key);
        assert_eq!(first.encryption_key_hash, second.encryption_key_hash);
    }

    #[test]
    fn test_derive_keys_different_passwords() {
        let one = derive("password1", "user@test.com");
        let two = derive("password2", "user@test.com");
        assert_ne!(*one.encryption_key, *two.encryption_key);
    }

    #[test]
    fn test_zke_session_encrypt_decrypt() {
        let keys = derive("test", "test@test.com");
        assert_round_trip(&session_from(&keys), b"SELECT * FROM secret_data");
    }

    #[test]
    fn test_zke_session_from_hex() {
        let session = ZeroKnowledgeSession::from_hex_key(
            "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef",
        )
        .unwrap();
        assert_round_trip(&session, b"test data");
    }

    #[test]
    fn test_key_hash_validation() {
        let keys = derive("test", "user");
        let session = session_from(&keys);
        assert!(session.validate_key_hash(&keys.encryption_key_hash));
        let mut tampered = keys.encryption_key_hash;
        tampered[0] ^= 0xFF;
        assert!(!session.validate_key_hash(&tampered));
    }

    #[test]
    fn test_key_hash_hex_validation() {
        let keys = derive("test", "user");
        let session = session_from(&keys);
        let encoded = keys.key_hash_hex();
        assert!(session.validate_key_hash_hex(&encoded).unwrap());
    }

    #[test]
    fn test_nonce_tracker() {
        let tracker = NonceTracker::new(300, 100);
        let (first, second): ([u8; 16], [u8; 16]) = (rand::random(), rand::random());
        // The first sighting is accepted; any repeat is rejected.
        assert!(tracker.check_and_record(&first));
        assert!(tracker.check_and_record(&second));
        assert!(!tracker.check_and_record(&first));
        assert!(!tracker.check_and_record(&second));
    }

    #[test]
    fn test_timestamp_validator() {
        let validator = TimestampValidator::new(60);
        let now = TimestampValidator::current_timestamp();
        let cases = [
            (now, true),
            (now - 30, true),
            (now + 30, true),
            (now - 120, false),
            (now + 120, false),
        ];
        for &(ts, expected) in cases.iter() {
            assert_eq!(validator.validate(ts), expected, "timestamp {}", ts);
        }
    }

    #[test]
    fn test_session_with_nonce() {
        let key: EncryptionKey = rand::random();
        let session = ZeroKnowledgeSession::new(key).unwrap().with_random_nonce();
        assert!(session.nonce().is_some());
        assert!(session.nonce_hex().is_some());
    }

    #[test]
    fn test_request_context_validation() {
        let keys = derive("test", "user");
        let session = session_from(&keys).with_random_nonce();
        let nonce = *session.nonce().unwrap();
        let context = ZkeRequestContext::new(
            session,
            Arc::new(NonceTracker::default()),
            ZkeConfig::default(),
        );
        let timestamp = TimestampValidator::current_timestamp();
        let hash_hex = keys.key_hash_hex();
        assert!(context
            .validate(Some(&hash_hex), Some(&nonce), Some(timestamp))
            .is_ok());
    }

    #[test]
    fn test_constant_time_compare() {
        let reference = [1u8, 2, 3, 4];
        assert!(constant_time_compare(&reference, &[1u8, 2, 3, 4]));
        assert!(!constant_time_compare(&reference, &[1u8, 2, 3, 5]));
    }
}