use crate::crypto::pqc::types::{MlDsaPublicKey, MlDsaSecretKey, MlDsaSignature};
use crate::crypto::raw_public_keys::key_utils::generate_ml_dsa_keypair;
use crate::crypto::raw_public_keys::pqc::{
    ML_DSA_65_SIGNATURE_SIZE, sign_with_ml_dsa, verify_with_ml_dsa,
};
use crate::relay::{RelayError, RelayResult};
use std::collections::{HashSet, VecDeque};
use std::sync::{Arc, Mutex, MutexGuard, PoisonError};
use std::time::{SystemTime, UNIX_EPOCH};
19
/// A signed relay authorization token.
///
/// The fields `nonce`, `timestamp`, `bandwidth_limit` and `timeout_seconds` are the
/// signed payload; `signature` holds the ML-DSA-65 signature over them.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct AuthToken {
    /// Random value identifying this token, used for replay detection.
    pub nonce: u64,
    /// Creation time, in whole seconds since the Unix epoch.
    pub timestamp: u64,
    /// Bandwidth limit granted by the token.
    /// NOTE(review): the unit is not visible in this file — confirm with the relay code.
    pub bandwidth_limit: u32,
    /// Timeout granted by the token (seconds, per the field name).
    pub timeout_seconds: u32,
    /// Raw ML-DSA-65 signature bytes over the other fields.
    pub signature: Vec<u8>,
}
34
35#[derive(Debug)]
37pub struct RelayAuthenticator {
38 public_key: MlDsaPublicKey,
40 secret_key: MlDsaSecretKey,
42 used_nonces: Arc<Mutex<HashSet<u64>>>,
44 max_token_age: u64,
46 replay_window_size: u64,
48}
49
50impl AuthToken {
51 pub fn new(
53 bandwidth_limit: u32,
54 timeout_seconds: u32,
55 secret_key: &MlDsaSecretKey,
56 ) -> RelayResult<Self> {
57 let nonce = Self::generate_nonce();
58 let timestamp = Self::current_timestamp()?;
59
60 let mut token = Self {
61 nonce,
62 timestamp,
63 bandwidth_limit,
64 timeout_seconds,
65 signature: vec![0; ML_DSA_65_SIGNATURE_SIZE],
66 };
67
68 let sig = sign_with_ml_dsa(secret_key, &token.signable_data()).map_err(|_| {
70 RelayError::AuthenticationFailed {
71 reason: "ML-DSA-65 signing failed".to_string(),
72 }
73 })?;
74 token.signature = sig.as_bytes().to_vec();
75
76 Ok(token)
77 }
78
79 fn generate_nonce() -> u64 {
81 use rand::Rng;
82 use rand::rngs::OsRng;
83 OsRng.r#gen()
84 }
85
86 fn current_timestamp() -> RelayResult<u64> {
88 SystemTime::now()
89 .duration_since(UNIX_EPOCH)
90 .map(|d| d.as_secs())
91 .map_err(|_| RelayError::AuthenticationFailed {
92 reason: "System time before Unix epoch".to_string(),
93 })
94 }
95
96 fn signable_data(&self) -> Vec<u8> {
98 let mut data = Vec::new();
99 data.extend_from_slice(&self.nonce.to_le_bytes());
100 data.extend_from_slice(&self.timestamp.to_le_bytes());
101 data.extend_from_slice(&self.bandwidth_limit.to_le_bytes());
102 data.extend_from_slice(&self.timeout_seconds.to_le_bytes());
103 data
104 }
105
106 pub fn verify(&self, public_key: &MlDsaPublicKey) -> RelayResult<()> {
108 let signature = MlDsaSignature::from_bytes(&self.signature).map_err(|_| {
109 RelayError::AuthenticationFailed {
110 reason: "Invalid signature format".to_string(),
111 }
112 })?;
113
114 verify_with_ml_dsa(public_key, &self.signable_data(), &signature).map_err(|_| {
115 RelayError::AuthenticationFailed {
116 reason: "Signature verification failed".to_string(),
117 }
118 })
119 }
120
121 pub fn is_expired(&self, max_age_seconds: u64) -> RelayResult<bool> {
123 let current_time = Self::current_timestamp()?;
124 Ok(current_time > self.timestamp + max_age_seconds)
125 }
126}
127
128impl RelayAuthenticator {
129 pub fn new() -> RelayResult<Self> {
134 let (public_key, secret_key) =
135 generate_ml_dsa_keypair().map_err(|e| RelayError::AuthenticationFailed {
136 reason: format!("ML-DSA-65 keypair generation failed: {}", e),
137 })?;
138
139 Ok(Self {
140 public_key,
141 secret_key,
142 used_nonces: Arc::new(Mutex::new(HashSet::new())),
143 max_token_age: 300, replay_window_size: 1000,
145 })
146 }
147
148 pub fn with_keypair(public_key: MlDsaPublicKey, secret_key: MlDsaSecretKey) -> Self {
150 Self {
151 public_key,
152 secret_key,
153 used_nonces: Arc::new(Mutex::new(HashSet::new())),
154 max_token_age: 300,
155 replay_window_size: 1000,
156 }
157 }
158
159 pub fn public_key(&self) -> &MlDsaPublicKey {
161 &self.public_key
162 }
163
164 pub fn create_token(
166 &self,
167 bandwidth_limit: u32,
168 timeout_seconds: u32,
169 ) -> RelayResult<AuthToken> {
170 AuthToken::new(bandwidth_limit, timeout_seconds, &self.secret_key)
171 }
172
173 #[allow(clippy::expect_used)]
175 pub fn verify_token(
176 &self,
177 token: &AuthToken,
178 peer_public_key: &MlDsaPublicKey,
179 ) -> RelayResult<()> {
180 token.verify(peer_public_key)?;
182
183 if token.is_expired(self.max_token_age)? {
185 return Err(RelayError::AuthenticationFailed {
186 reason: "Token expired".to_string(),
187 });
188 }
189
190 let mut used_nonces = self
192 .used_nonces
193 .lock()
194 .expect("Mutex poisoning is unexpected in normal operation");
195
196 if used_nonces.contains(&token.nonce) {
197 return Err(RelayError::AuthenticationFailed {
198 reason: "Token replay detected".to_string(),
199 });
200 }
201
202 if used_nonces.len() >= self.replay_window_size as usize {
204 let to_remove: Vec<_> = used_nonces.iter().take(100).cloned().collect();
206 for nonce in to_remove {
207 used_nonces.remove(&nonce);
208 }
209 }
210
211 used_nonces.insert(token.nonce);
212
213 Ok(())
214 }
215
216 pub fn set_max_token_age(&mut self, max_age_seconds: u64) {
218 self.max_token_age = max_age_seconds;
219 }
220
221 pub fn max_token_age(&self) -> u64 {
223 self.max_token_age
224 }
225
226 #[allow(clippy::unwrap_used, clippy::expect_used)]
228 pub fn clear_nonces(&self) {
229 let mut used_nonces = self
230 .used_nonces
231 .lock()
232 .expect("Mutex poisoning is unexpected in normal operation");
233 used_nonces.clear();
234 }
235
236 #[allow(clippy::unwrap_used, clippy::expect_used)]
238 pub fn nonce_count(&self) -> usize {
239 let used_nonces = self
240 .used_nonces
241 .lock()
242 .expect("Mutex poisoning is unexpected in normal operation");
243 used_nonces.len()
244 }
245}
246
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_auth_token_creation_and_verification() {
        let authenticator = RelayAuthenticator::new().unwrap();
        let token = authenticator.create_token(1024, 300).unwrap();

        assert_eq!(token.bandwidth_limit, 1024);
        assert_eq!(token.timeout_seconds, 300);
        assert_ne!(token.nonce, 0);
        assert!(token.timestamp > 0);

        assert!(token.verify(authenticator.public_key()).is_ok());
    }

    #[test]
    fn test_token_verification_with_wrong_key() {
        let authenticator1 = RelayAuthenticator::new().unwrap();
        let authenticator2 = RelayAuthenticator::new().unwrap();

        let token = authenticator1.create_token(1024, 300).unwrap();

        assert!(token.verify(authenticator2.public_key()).is_err());
    }

    #[test]
    fn test_tampered_token_fails_verification() {
        let authenticator = RelayAuthenticator::new().unwrap();
        let token = authenticator.create_token(1024, 300).unwrap();

        // Changing any signed field must invalidate the signature.
        let tampered = AuthToken {
            bandwidth_limit: 4096,
            ..token.clone()
        };
        assert!(tampered.verify(authenticator.public_key()).is_err());

        // A truncated signature must be rejected.
        let mut truncated = token;
        truncated.signature.pop();
        assert!(truncated.verify(authenticator.public_key()).is_err());
    }

    #[test]
    fn test_token_expiration() {
        let mut authenticator = RelayAuthenticator::new().unwrap();
        authenticator.set_max_token_age(1);
        let token = authenticator.create_token(1024, 300).unwrap();

        let max_age = authenticator.max_token_age();
        assert!(!token.is_expired(max_age).unwrap());

        // Backdate a copy instead of sleeping. `is_expired` only looks at `timestamp`, so the
        // now-stale signature does not matter here.
        let stale = AuthToken {
            timestamp: token.timestamp - 10,
            ..token
        };
        assert!(stale.is_expired(max_age).unwrap());
    }

    #[test]
    fn test_anti_replay_protection() {
        let authenticator = RelayAuthenticator::new().unwrap();
        let token = authenticator.create_token(1024, 300).unwrap();

        assert!(
            authenticator
                .verify_token(&token, authenticator.public_key())
                .is_ok()
        );

        assert!(
            authenticator
                .verify_token(&token, authenticator.public_key())
                .is_err()
        );
    }

    #[test]
    fn test_nonce_uniqueness() {
        let authenticator = RelayAuthenticator::new().unwrap();
        let mut nonces = HashSet::new();

        for _ in 0..1000 {
            let token = authenticator.create_token(1024, 300).unwrap();
            assert!(nonces.insert(token.nonce), "Duplicate nonce detected");
        }
    }

    #[test]
    fn test_token_signable_data() {
        let authenticator = RelayAuthenticator::new().unwrap();
        let token1 = authenticator.create_token(1024, 300).unwrap();
        let token2 = authenticator.create_token(1024, 300).unwrap();

        assert_ne!(token1.signable_data(), token2.signable_data());
    }

    #[test]
    fn test_nonce_window_management() {
        let authenticator = RelayAuthenticator::new().unwrap();

        for _ in 0..1000 {
            let token = authenticator.create_token(1024, 300).unwrap();
            assert!(
                authenticator
                    .verify_token(&token, authenticator.public_key())
                    .is_ok()
            );
        }

        assert_eq!(authenticator.nonce_count(), 1000);

        // One more token past the window must still be accepted without growing the cache
        // beyond its bound.
        let token = authenticator.create_token(1024, 300).unwrap();
        assert!(
            authenticator
                .verify_token(&token, authenticator.public_key())
                .is_ok()
        );

        assert!(authenticator.nonce_count() <= 1000);
    }

    #[test]
    fn test_clear_nonces() {
        let authenticator = RelayAuthenticator::new().unwrap();
        let token = authenticator.create_token(1024, 300).unwrap();

        assert!(
            authenticator
                .verify_token(&token, authenticator.public_key())
                .is_ok()
        );
        assert!(authenticator.nonce_count() > 0);

        authenticator.clear_nonces();
        assert_eq!(authenticator.nonce_count(), 0);

        assert!(
            authenticator
                .verify_token(&token, authenticator.public_key())
                .is_ok()
        );
    }

    #[test]
    fn test_with_specific_keypair() {
        let (public_key, secret_key) = generate_ml_dsa_keypair().unwrap();
        let authenticator = RelayAuthenticator::with_keypair(public_key, secret_key);

        let token = authenticator.create_token(1024, 300).unwrap();
        assert!(token.verify(authenticator.public_key()).is_ok());
    }
}