use crate::crypto::pqc::types::{MlDsaPublicKey, MlDsaSecretKey, MlDsaSignature};
use crate::crypto::raw_public_keys::key_utils::generate_ml_dsa_keypair;
use crate::crypto::raw_public_keys::pqc::{
ML_DSA_65_SIGNATURE_SIZE, sign_with_ml_dsa, verify_with_ml_dsa,
};
use crate::relay::{RelayError, RelayResult};
use std::collections::{HashSet, VecDeque};
use std::sync::{Arc, Mutex};
use std::time::{SystemTime, UNIX_EPOCH};
/// A signed relay authentication token.
///
/// Created by `AuthToken::new`, which fills `nonce` from the OS RNG and
/// `timestamp` from the system clock. It then signs the first four fields with
/// ML-DSA-65 and stores the result in `signature`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AuthToken {
    /// Random per-token value used for replay detection.
    pub nonce: u64,
    /// Creation time in whole seconds since the Unix epoch.
    pub timestamp: u64,
    /// Caller-supplied bandwidth limit; covered by the signature but not
    /// interpreted in this file (units presumably defined by the relay layer — confirm).
    pub bandwidth_limit: u32,
    /// Caller-supplied timeout in seconds; covered by the signature but not
    /// interpreted in this file. Distinct from the authenticator's max token age.
    pub timeout_seconds: u32,
    /// ML-DSA-65 signature bytes over the token's signable data (the four
    /// fields above; the signature itself is not signed).
    pub signature: Vec<u8>,
}
/// Bounded set of recently seen nonces with first-in, first-out eviction.
///
/// `set` answers membership queries. `order` remembers insertion order so the
/// oldest nonces can be dropped once the window grows past its limit. Both
/// collections always hold the same nonces.
#[derive(Debug, Default)]
struct NonceWindow {
    order: VecDeque<u64>,
    set: HashSet<u64>,
}
impl NonceWindow {
    /// Reports whether `nonce` is currently remembered.
    fn contains(&self, nonce: u64) -> bool {
        self.set.contains(&nonce)
    }

    /// Remembers `nonce`, then evicts the oldest nonces until at most
    /// `max_size` remain.
    ///
    /// Re-inserting a nonce that is already present does not refresh its age.
    fn insert_with_limit(&mut self, nonce: u64, max_size: usize) {
        let newly_added = self.set.insert(nonce);
        if newly_added {
            self.order.push_back(nonce);
        }
        while self.set.len() > max_size {
            let Some(evicted) = self.order.pop_front() else {
                break;
            };
            self.set.remove(&evicted);
        }
    }

    /// Forgets every remembered nonce.
    fn clear(&mut self) {
        self.set.clear();
        self.order.clear();
    }

    /// Number of nonces currently remembered.
    fn len(&self) -> usize {
        self.set.len()
    }
}
/// Issues and verifies ML-DSA-65 signed relay tokens with replay protection.
///
/// NOTE(review): `Debug` is derived, so `{:?}` output includes `secret_key`.
/// Whether that leaks key material depends on `MlDsaSecretKey`'s `Debug` impl,
/// which is not visible here. Confirm it is redacted, or implement `Debug` manually.
#[derive(Debug)]
pub struct RelayAuthenticator {
    /// Public half of this authenticator's signing keypair.
    public_key: MlDsaPublicKey,
    /// Secret half used to sign tokens issued by `create_token`.
    secret_key: MlDsaSecretKey,
    /// Recently accepted nonces, guarded by a mutex so `verify_token` can
    /// check and record a nonce under one lock.
    used_nonces: Arc<Mutex<NonceWindow>>,
    /// Maximum accepted token age in seconds (passed to `AuthToken::is_expired`).
    max_token_age: u64,
    /// Maximum number of nonces retained in `used_nonces`.
    replay_window_size: u64,
}
impl AuthToken {
    /// Creates a token with a random nonce and the current Unix time, signed
    /// with `secret_key` using ML-DSA-65.
    ///
    /// # Errors
    ///
    /// Returns `RelayError::AuthenticationFailed` if the system clock reads
    /// before the Unix epoch or if signing fails.
    pub fn new(
        bandwidth_limit: u32,
        timeout_seconds: u32,
        secret_key: &MlDsaSecretKey,
    ) -> RelayResult<Self> {
        let mut token = Self {
            nonce: Self::generate_nonce(),
            timestamp: Self::current_timestamp()?,
            bandwidth_limit,
            timeout_seconds,
            // The signature is not part of `signable_data`, so start empty and
            // fill it in once signing succeeds.
            signature: Vec::new(),
        };
        let sig = sign_with_ml_dsa(secret_key, &token.signable_data()).map_err(|_| {
            RelayError::AuthenticationFailed {
                reason: "ML-DSA-65 signing failed".to_string(),
            }
        })?;
        token.signature = sig.as_bytes().to_vec();
        Ok(token)
    }

    /// Draws a random 64-bit nonce from the operating system RNG.
    fn generate_nonce() -> u64 {
        use rand::Rng;
        use rand::rngs::OsRng;
        OsRng.r#gen()
    }

    /// Current time in whole seconds since the Unix epoch.
    fn current_timestamp() -> RelayResult<u64> {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map(|d| d.as_secs())
            .map_err(|_| RelayError::AuthenticationFailed {
                reason: "System time before Unix epoch".to_string(),
            })
    }

    /// Bytes covered by the signature: little-endian `nonce`, `timestamp`,
    /// `bandwidth_limit` and `timeout_seconds`, concatenated (24 bytes).
    fn signable_data(&self) -> Vec<u8> {
        let mut data = Vec::with_capacity(8 + 8 + 4 + 4);
        data.extend_from_slice(&self.nonce.to_le_bytes());
        data.extend_from_slice(&self.timestamp.to_le_bytes());
        data.extend_from_slice(&self.bandwidth_limit.to_le_bytes());
        data.extend_from_slice(&self.timeout_seconds.to_le_bytes());
        data
    }

    /// Checks that `signature` is a valid ML-DSA-65 signature by `public_key`
    /// over this token's signable data.
    ///
    /// # Errors
    ///
    /// Returns `RelayError::AuthenticationFailed` if the signature bytes are
    /// malformed or verification fails.
    pub fn verify(&self, public_key: &MlDsaPublicKey) -> RelayResult<()> {
        let signature = MlDsaSignature::from_bytes(&self.signature).map_err(|_| {
            RelayError::AuthenticationFailed {
                reason: "Invalid signature format".to_string(),
            }
        })?;
        verify_with_ml_dsa(public_key, &self.signable_data(), &signature).map_err(|_| {
            RelayError::AuthenticationFailed {
                reason: "Signature verification failed".to_string(),
            }
        })
    }

    /// Returns `true` if more than `max_age_seconds` have passed since
    /// `timestamp`.
    ///
    /// NOTE(review): tokens whose `timestamp` is in the future are never
    /// reported as expired. Consider rejecting those beyond a clock-skew
    /// allowance at the call site.
    ///
    /// # Errors
    ///
    /// Returns `RelayError::AuthenticationFailed` if the system clock reads
    /// before the Unix epoch.
    pub fn is_expired(&self, max_age_seconds: u64) -> RelayResult<bool> {
        let current_time = Self::current_timestamp()?;
        // `timestamp` is chosen by whoever signed the token, and the max age is
        // configurable, so a plain `+` can overflow. That panics in debug
        // builds and wraps in release. A deadline past `u64::MAX` can never be
        // reached, so saturating gives the exact answer.
        let deadline = self.timestamp.saturating_add(max_age_seconds);
        Ok(current_time > deadline)
    }
}
impl RelayAuthenticator {
    /// Default maximum accepted token age, in seconds.
    const DEFAULT_MAX_TOKEN_AGE: u64 = 300;
    /// Default number of recently seen nonces kept for replay detection.
    const DEFAULT_REPLAY_WINDOW_SIZE: u64 = 1000;

    /// Creates an authenticator with a freshly generated ML-DSA-65 keypair and
    /// default limits (300 s maximum token age, 1000-nonce replay window).
    ///
    /// # Errors
    ///
    /// Returns `RelayError::AuthenticationFailed` if keypair generation fails.
    pub fn new() -> RelayResult<Self> {
        let (public_key, secret_key) =
            generate_ml_dsa_keypair().map_err(|e| RelayError::AuthenticationFailed {
                reason: format!("ML-DSA-65 keypair generation failed: {}", e),
            })?;
        Ok(Self::with_keypair(public_key, secret_key))
    }

    /// Creates an authenticator from an existing keypair, using the same
    /// default limits as `RelayAuthenticator::new`.
    pub fn with_keypair(public_key: MlDsaPublicKey, secret_key: MlDsaSecretKey) -> Self {
        Self {
            public_key,
            secret_key,
            used_nonces: Arc::new(Mutex::new(NonceWindow::default())),
            max_token_age: Self::DEFAULT_MAX_TOKEN_AGE,
            replay_window_size: Self::DEFAULT_REPLAY_WINDOW_SIZE,
        }
    }

    /// Returns this authenticator's ML-DSA-65 public key.
    pub fn public_key(&self) -> &MlDsaPublicKey {
        &self.public_key
    }

    /// Issues a new token signed with this authenticator's secret key.
    ///
    /// # Errors
    ///
    /// Propagates any error from `AuthToken::new`.
    pub fn create_token(
        &self,
        bandwidth_limit: u32,
        timeout_seconds: u32,
    ) -> RelayResult<AuthToken> {
        AuthToken::new(bandwidth_limit, timeout_seconds, &self.secret_key)
    }

    /// Verifies `token` against `peer_public_key` and records its nonce.
    ///
    /// Checks run in this order: signature, age (`max_token_age`), then replay
    /// (nonce already seen). Only a token that passes all three has its nonce
    /// recorded. The replay check and the insert happen under a single lock.
    ///
    /// NOTE(review): eviction from the nonce window is count-based, not
    /// time-based. If more than `replay_window_size` tokens are verified within
    /// `max_token_age` seconds, a still-valid token's nonce can be evicted and
    /// that token replayed. Consider evicting by token expiry instead.
    ///
    /// # Errors
    ///
    /// Returns `RelayError::AuthenticationFailed` if the signature is invalid,
    /// the token is expired, the nonce was already seen, or the nonce-window
    /// mutex is poisoned.
    pub fn verify_token(
        &self,
        token: &AuthToken,
        peer_public_key: &MlDsaPublicKey,
    ) -> RelayResult<()> {
        token.verify(peer_public_key)?;
        if token.is_expired(self.max_token_age)? {
            return Err(RelayError::AuthenticationFailed {
                reason: "Token expired".to_string(),
            });
        }
        let Ok(mut used_nonces) = self.used_nonces.lock() else {
            // A poisoned lock means the replay window may be inconsistent, so
            // fail closed rather than risk accepting a replayed token.
            tracing::error!(
                "Mutex poisoned in relay authenticator - potential security compromise, \
                 failing authentication to prevent replay attacks"
            );
            return Err(RelayError::AuthenticationFailed {
                reason: "Internal security state compromised".to_string(),
            });
        };
        if used_nonces.contains(token.nonce) {
            return Err(RelayError::AuthenticationFailed {
                reason: "Token replay detected".to_string(),
            });
        }
        // `usize` can be narrower than `u64` (32-bit targets). Saturate instead
        // of letting `as` silently truncate the window size.
        let window = usize::try_from(self.replay_window_size).unwrap_or(usize::MAX);
        used_nonces.insert_with_limit(token.nonce, window);
        Ok(())
    }

    /// Sets the maximum accepted token age, in seconds.
    pub fn set_max_token_age(&mut self, max_age_seconds: u64) {
        self.max_token_age = max_age_seconds;
    }

    /// Maximum accepted token age, in seconds.
    pub fn max_token_age(&self) -> u64 {
        self.max_token_age
    }

    /// Forgets every recorded nonce. Previously seen tokens are then accepted
    /// again while still unexpired.
    ///
    /// Clearing is safe even on a poisoned lock, so the guard is recovered
    /// instead of panicking.
    pub fn clear_nonces(&self) {
        self.used_nonces
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .clear();
    }

    /// Number of nonces currently held in the replay window.
    ///
    /// Reading the length is safe even on a poisoned lock, so the guard is
    /// recovered instead of panicking.
    pub fn nonce_count(&self) -> usize {
        self.used_nonces
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .len()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    #[test]
    fn test_auth_token_creation_and_verification() {
        let authenticator = RelayAuthenticator::new().unwrap();
        let token = authenticator.create_token(1024, 300).unwrap();
        assert_eq!(token.bandwidth_limit, 1024);
        assert_eq!(token.timeout_seconds, 300);
        assert_ne!(token.nonce, 0);
        assert!(token.timestamp > 0);
        assert!(token.verify(authenticator.public_key()).is_ok());
    }

    #[test]
    fn test_token_verification_with_wrong_key() {
        let authenticator1 = RelayAuthenticator::new().unwrap();
        let authenticator2 = RelayAuthenticator::new().unwrap();
        let token = authenticator1.create_token(1024, 300).unwrap();
        assert!(token.verify(authenticator2.public_key()).is_err());
    }

    #[test]
    fn test_token_expiration() {
        let mut authenticator = RelayAuthenticator::new().unwrap();
        authenticator.set_max_token_age(60);
        let max_age = authenticator.max_token_age();
        let mut token = authenticator.create_token(1024, 300).unwrap();
        assert!(!token.is_expired(max_age).unwrap());
        // Back-date the token instead of sleeping. `is_expired` only inspects
        // the timestamp, not the signature, so the test stays fast and
        // deterministic.
        token.timestamp -= max_age + 1;
        assert!(token.is_expired(max_age).unwrap());
    }

    #[test]
    fn test_anti_replay_protection() {
        let authenticator = RelayAuthenticator::new().unwrap();
        let token = authenticator.create_token(1024, 300).unwrap();
        assert!(
            authenticator
                .verify_token(&token, authenticator.public_key())
                .is_ok()
        );
        assert!(
            authenticator
                .verify_token(&token, authenticator.public_key())
                .is_err()
        );
    }

    #[test]
    fn test_nonce_uniqueness() {
        let authenticator = RelayAuthenticator::new().unwrap();
        let mut nonces = HashSet::new();
        for _ in 0..1000 {
            let token = authenticator.create_token(1024, 300).unwrap();
            assert!(nonces.insert(token.nonce), "Duplicate nonce detected");
        }
    }

    #[test]
    fn test_token_signable_data() {
        let authenticator = RelayAuthenticator::new().unwrap();
        let token1 = authenticator.create_token(1024, 300).unwrap();
        let token2 = authenticator.create_token(1024, 300).unwrap();
        assert_ne!(token1.signable_data(), token2.signable_data());
    }

    #[test]
    fn test_nonce_window_management() {
        let authenticator = RelayAuthenticator::new().unwrap();
        for _ in 0..1000 {
            let token = authenticator.create_token(1024, 300).unwrap();
            assert!(
                authenticator
                    .verify_token(&token, authenticator.public_key())
                    .is_ok()
            );
        }
        assert_eq!(authenticator.nonce_count(), 1000);
        let token = authenticator.create_token(1024, 300).unwrap();
        assert!(
            authenticator
                .verify_token(&token, authenticator.public_key())
                .is_ok()
        );
        // The 1001st nonce evicts the oldest, so the window stays at its limit.
        assert_eq!(authenticator.nonce_count(), 1000);
    }

    #[test]
    fn test_clear_nonces() {
        let authenticator = RelayAuthenticator::new().unwrap();
        let token = authenticator.create_token(1024, 300).unwrap();
        assert!(
            authenticator
                .verify_token(&token, authenticator.public_key())
                .is_ok()
        );
        assert_eq!(authenticator.nonce_count(), 1);
        authenticator.clear_nonces();
        assert_eq!(authenticator.nonce_count(), 0);
        assert!(
            authenticator
                .verify_token(&token, authenticator.public_key())
                .is_ok()
        );
    }

    #[test]
    fn test_with_specific_keypair() {
        let (public_key, secret_key) = generate_ml_dsa_keypair().unwrap();
        let authenticator = RelayAuthenticator::with_keypair(public_key, secret_key);
        let token = authenticator.create_token(1024, 300).unwrap();
        assert!(token.verify(authenticator.public_key()).is_ok());
    }
}