1use crate::error::{constants, ProtocolError, Result};
13use crate::protocol::message::Message;
14use crate::utils::replay_cache::ReplayCache;
15use rand_core::{OsRng, RngCore};
16use sha2::{Digest, Sha256};
17use std::time::{SystemTime, UNIX_EPOCH};
18use x25519_dalek::{EphemeralSecret, PublicKey, SharedSecret};
19use zeroize::Zeroize;
20
21#[allow(unused_imports)]
22use tracing::{debug, instrument, warn};
23
24#[derive(Zeroize)]
26#[zeroize(drop)]
27pub struct ClientHandshakeState {
28 secret: Option<EphemeralSecret>,
29 public: Option<[u8; 32]>,
30 server_public: Option<[u8; 32]>,
31 client_nonce: Option<[u8; 16]>,
32 server_nonce: Option<[u8; 16]>,
33}
34
35impl ClientHandshakeState {
36 pub fn new() -> Self {
38 Self {
39 secret: None,
40 public: None,
41 server_public: None,
42 client_nonce: None,
43 server_nonce: None,
44 }
45 }
46
47 #[cfg(test)]
49 pub fn client_nonce(&self) -> Option<&[u8; 16]> {
50 self.client_nonce.as_ref()
51 }
52
53 #[cfg(test)]
55 pub fn server_nonce(&self) -> Option<&[u8; 16]> {
56 self.server_nonce.as_ref()
57 }
58}
59
60impl Default for ClientHandshakeState {
61 fn default() -> Self {
62 Self::new()
63 }
64}
65
66#[derive(Zeroize)]
68#[zeroize(drop)]
69pub struct ServerHandshakeState {
70 secret: Option<EphemeralSecret>,
71 public: Option<[u8; 32]>,
72 client_public: Option<[u8; 32]>,
73 client_nonce: Option<[u8; 16]>,
74 server_nonce: Option<[u8; 16]>,
75}
76
77impl ServerHandshakeState {
78 pub fn new() -> Self {
80 Self {
81 secret: None,
82 public: None,
83 client_public: None,
84 client_nonce: None,
85 server_nonce: None,
86 }
87 }
88
89 #[cfg(test)]
91 pub fn server_nonce(&self) -> Option<&[u8; 16]> {
92 self.server_nonce.as_ref()
93 }
94
95 #[cfg(test)]
97 pub fn client_public(&self) -> Option<&[u8; 32]> {
98 self.client_public.as_ref()
99 }
100}
101
102impl Default for ServerHandshakeState {
103 fn default() -> Self {
104 Self::new()
105 }
106}
107
108fn current_timestamp() -> Result<u64> {
113 SystemTime::now()
114 .duration_since(UNIX_EPOCH)
115 .map(|duration| duration.as_millis() as u64)
116 .map_err(|_| ProtocolError::Custom(constants::ERR_SYSTEM_TIME.into()))
117}
118
119fn generate_nonce() -> [u8; 16] {
121 let mut nonce = [0u8; 16];
122 OsRng.fill_bytes(&mut nonce);
123 nonce
124}
125
126pub fn verify_timestamp(timestamp: u64, max_age_seconds: u64) -> bool {
129 let current = match current_timestamp() {
130 Ok(time) => time,
131 Err(_) => return false,
132 };
133
134 let max_age_ms = max_age_seconds * 1000;
135 const FUTURE_TOLERANCE_MS: u64 = 2000; if timestamp > current + FUTURE_TOLERANCE_MS {
139 return false;
140 }
141
142 if current > timestamp && current - timestamp > max_age_ms {
144 return false;
145 }
146
147 true
148}
149
150fn hash_nonce(nonce: &[u8]) -> [u8; 32] {
152 let mut hasher = Sha256::new();
153 hasher.update(nonce);
154 hasher.finalize().into()
155}
156
157fn derive_key_from_shared_secret(
159 shared_secret: &SharedSecret,
160 client_nonce: &[u8],
161 server_nonce: &[u8],
162) -> [u8; 32] {
163 let mut hasher = Sha256::new();
164
165 hasher.update(shared_secret.as_bytes());
167
168 hasher.update(b"client_nonce");
170 hasher.update(client_nonce);
171 hasher.update(b"server_nonce");
172 hasher.update(server_nonce);
173
174 hasher.finalize().into()
175}
176
177#[instrument]
186pub fn client_secure_handshake_init() -> Result<(ClientHandshakeState, Message)> {
187 let client_secret = EphemeralSecret::random_from_rng(OsRng);
189 let client_public = PublicKey::from(&client_secret);
190
191 let nonce = generate_nonce();
193 let timestamp = current_timestamp()?;
194
195 let mut state = ClientHandshakeState::new();
196 state.secret = Some(client_secret);
197 state.public = Some(client_public.to_bytes());
198 state.client_nonce = Some(nonce);
199
200 debug!("Client initiating secure handshake");
201
202 Ok((
203 state,
204 Message::SecureHandshakeInit {
205 pub_key: client_public.to_bytes(),
206 timestamp,
207 nonce,
208 },
209 ))
210}
211
212#[instrument(skip(client_pub_key, client_nonce, replay_cache))]
221pub fn server_secure_handshake_response(
222 client_pub_key: [u8; 32],
223 client_nonce: [u8; 16],
224 client_timestamp: u64,
225 peer_id: &str,
226 replay_cache: &mut ReplayCache,
227) -> Result<(ServerHandshakeState, Message)> {
228 if !verify_timestamp(client_timestamp, 30) {
230 return Err(ProtocolError::HandshakeError(
231 constants::ERR_INVALID_TIMESTAMP.into(),
232 ));
233 }
234
235 if replay_cache.is_replay(peer_id, &client_nonce, client_timestamp) {
237 return Err(ProtocolError::HandshakeError(
238 constants::ERR_REPLAY_ATTACK.into(),
239 ));
240 }
241
242 let server_secret = EphemeralSecret::random_from_rng(OsRng);
244 let server_public = PublicKey::from(&server_secret);
245 let server_nonce = generate_nonce();
246
247 let nonce_verification = hash_nonce(&client_nonce);
249
250 let mut state = ServerHandshakeState::new();
251 state.secret = Some(server_secret);
252 state.public = Some(server_public.to_bytes());
253 state.client_public = Some(client_pub_key);
254 state.client_nonce = Some(client_nonce);
255 state.server_nonce = Some(server_nonce);
256
257 debug!("Server responding to handshake initiation");
258
259 Ok((
260 state,
261 Message::SecureHandshakeResponse {
262 pub_key: server_public.to_bytes(),
263 nonce: server_nonce,
264 nonce_verification,
265 },
266 ))
267}
268
269#[instrument(skip(state, server_pub_key, server_nonce, nonce_verification, replay_cache))]
278pub fn client_secure_handshake_verify(
279 mut state: ClientHandshakeState,
280 server_pub_key: [u8; 32],
281 server_nonce: [u8; 16],
282 nonce_verification: [u8; 32],
283 peer_id: &str,
284 replay_cache: &mut ReplayCache,
285) -> Result<(ClientHandshakeState, Message)> {
286 if replay_cache.is_replay(peer_id, &server_nonce, 0) {
288 return Err(ProtocolError::HandshakeError(
290 constants::ERR_REPLAY_ATTACK.into(),
291 ));
292 }
293
294 let client_nonce = state.client_nonce.ok_or_else(|| {
296 ProtocolError::HandshakeError(constants::ERR_CLIENT_NONCE_NOT_FOUND.into())
297 })?;
298
299 let expected_verification = hash_nonce(&client_nonce);
300
301 if expected_verification != nonce_verification {
302 return Err(ProtocolError::HandshakeError(
303 constants::ERR_NONCE_VERIFICATION_FAILED.into(),
304 ));
305 }
306
307 state.server_public = Some(server_pub_key);
309 state.server_nonce = Some(server_nonce);
310 let client_nonce = state.client_nonce.ok_or_else(|| {
312 ProtocolError::HandshakeError(constants::ERR_CLIENT_NONCE_NOT_FOUND.into())
313 })?;
314
315 let expected_verification = hash_nonce(&client_nonce);
316
317 if expected_verification != nonce_verification {
318 return Err(ProtocolError::HandshakeError(
319 constants::ERR_NONCE_VERIFICATION_FAILED.into(),
320 ));
321 }
322
323 state.server_public = Some(server_pub_key);
325 state.server_nonce = Some(server_nonce);
326
327 let hash = hash_nonce(&server_nonce);
329
330 debug!("Client verified server response");
331
332 Ok((
333 state,
334 Message::SecureHandshakeConfirm {
335 nonce_verification: hash,
336 },
337 ))
338}
339
340#[instrument(skip(state, nonce_verification))]
349pub fn server_secure_handshake_finalize(
350 mut state: ServerHandshakeState,
351 nonce_verification: [u8; 32],
352) -> Result<[u8; 32]> {
353 let server_nonce = state.server_nonce.ok_or_else(|| {
355 ProtocolError::HandshakeError(constants::ERR_SERVER_NONCE_NOT_FOUND.into())
356 })?;
357
358 let expected_verification = hash_nonce(&server_nonce);
359
360 if expected_verification != nonce_verification {
361 return Err(ProtocolError::HandshakeError(
362 constants::ERR_SERVER_VERIFICATION_FAILED.into(),
363 ));
364 }
365
366 let server_secret = state.secret.take().ok_or_else(|| {
368 ProtocolError::HandshakeError(constants::ERR_SERVER_SECRET_NOT_FOUND.into())
369 })?;
370 let client_public_bytes = state.client_public.ok_or_else(|| {
371 ProtocolError::HandshakeError(constants::ERR_CLIENT_PUBLIC_NOT_FOUND.into())
372 })?;
373 let client_nonce = state.client_nonce.ok_or_else(|| {
374 ProtocolError::HandshakeError(constants::ERR_CLIENT_NONCE_NOT_FOUND.into())
375 })?;
376
377 let client_public = PublicKey::from(client_public_bytes);
379 let shared_secret = server_secret.diffie_hellman(&client_public);
380
381 let key = derive_key_from_shared_secret(&shared_secret, &client_nonce, &server_nonce);
383
384 debug!("Server finalized handshake and derived session key");
386
387 Ok(key)
388}
389
390#[instrument(skip(state))]
399pub fn client_derive_session_key(mut state: ClientHandshakeState) -> Result<[u8; 32]> {
400 let client_secret = state.secret.take().ok_or_else(|| {
402 ProtocolError::HandshakeError(constants::ERR_CLIENT_SECRET_NOT_FOUND.into())
403 })?;
404 let server_public_bytes = state.server_public.ok_or_else(|| {
405 ProtocolError::HandshakeError(constants::ERR_SERVER_PUBLIC_NOT_FOUND.into())
406 })?;
407 let client_nonce = state.client_nonce.ok_or_else(|| {
408 ProtocolError::HandshakeError(constants::ERR_CLIENT_NONCE_NOT_FOUND.into())
409 })?;
410 let server_nonce = state.server_nonce.ok_or_else(|| {
411 ProtocolError::HandshakeError(constants::ERR_SERVER_NONCE_NOT_FOUND.into())
412 })?;
413
414 let server_public = PublicKey::from(server_public_bytes);
416 let shared_secret = client_secret.diffie_hellman(&server_public);
417
418 let key = derive_key_from_shared_secret(&shared_secret, &client_nonce, &server_nonce);
420
421 debug!("Client derived session key");
423
424 Ok(key)
425}
426
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;

    /// Unpacks a `SecureHandshakeInit` into `(pub_key, timestamp, nonce)`.
    fn expect_init(msg: Message) -> ([u8; 32], u64, [u8; 16]) {
        match msg {
            Message::SecureHandshakeInit {
                pub_key,
                timestamp,
                nonce,
            } => (pub_key, timestamp, nonce),
            _ => panic!("Wrong message type"),
        }
    }

    /// Unpacks a `SecureHandshakeResponse` into `(pub_key, nonce, nonce_verification)`.
    fn expect_response(msg: Message) -> ([u8; 32], [u8; 16], [u8; 32]) {
        match msg {
            Message::SecureHandshakeResponse {
                pub_key,
                nonce,
                nonce_verification,
            } => (pub_key, nonce, nonce_verification),
            _ => panic!("Wrong message type"),
        }
    }

    /// Unpacks a `SecureHandshakeConfirm` into its `nonce_verification` hash.
    fn expect_confirm(msg: Message) -> [u8; 32] {
        match msg {
            Message::SecureHandshakeConfirm { nonce_verification } => nonce_verification,
            _ => panic!("Wrong message type"),
        }
    }

    /// Two handshakes run side by side must not share keys, nonces or session keys.
    #[test]
    fn test_per_session_state_isolation() {
        let mut replay_cache = crate::utils::replay_cache::ReplayCache::new();
        let peer_id = "test-peer";

        // Step 1: both clients start a handshake.
        let (client1, msg1) = client_secure_handshake_init().unwrap();
        let (client2, msg2) = client_secure_handshake_init().unwrap();

        let (pub_key1, ts1, nonce1) = expect_init(msg1);
        let (pub_key2, ts2, nonce2) = expect_init(msg2);

        assert_ne!(pub_key1, pub_key2);
        assert_ne!(nonce1, nonce2);

        // Step 2: the server answers each init message.
        let (server1, resp1) =
            server_secure_handshake_response(pub_key1, nonce1, ts1, peer_id, &mut replay_cache)
                .unwrap();
        let (server2, resp2) =
            server_secure_handshake_response(pub_key2, nonce2, ts2, peer_id, &mut replay_cache)
                .unwrap();

        let (server_pub1, server_nonce1, verify1) = expect_response(resp1);
        let (server_pub2, server_nonce2, verify2) = expect_response(resp2);

        assert_ne!(server_pub1, server_pub2);
        assert_ne!(server_nonce1, server_nonce2);

        // Step 3: each client verifies its response and produces a confirmation.
        let (client1_verified, confirm1) = client_secure_handshake_verify(
            client1,
            server_pub1,
            server_nonce1,
            verify1,
            peer_id,
            &mut replay_cache,
        )
        .unwrap();
        let (client2_verified, confirm2) = client_secure_handshake_verify(
            client2,
            server_pub2,
            server_nonce2,
            verify2,
            peer_id,
            &mut replay_cache,
        )
        .unwrap();

        let confirm_hash1 = expect_confirm(confirm1);
        let confirm_hash2 = expect_confirm(confirm2);

        assert_ne!(confirm_hash1, confirm_hash2);

        // Step 4: both sides of each session derive their key.
        let key1_server = server_secure_handshake_finalize(server1, confirm_hash1).unwrap();
        let key1_client = client_derive_session_key(client1_verified).unwrap();

        let key2_server = server_secure_handshake_finalize(server2, confirm_hash2).unwrap();
        let key2_client = client_derive_session_key(client2_verified).unwrap();

        // Peers of one session agree on the key ...
        assert_eq!(key1_server, key1_client);
        assert_eq!(key2_server, key2_client);

        // ... and distinct sessions end up with distinct keys.
        assert_ne!(key1_server, key2_server);
    }

    /// Exercises the age limit and the future tolerance of `verify_timestamp`.
    #[test]
    fn test_timestamp_validation() {
        let now = current_timestamp().unwrap();

        // Current time is accepted.
        assert!(verify_timestamp(now, 30));
        // 10 s old: inside the 30 s window.
        assert!(verify_timestamp(now - 10000, 30));
        // 31 s old: outside the 30 s window.
        assert!(!verify_timestamp(now - 31000, 30));
        // 1 s ahead: inside the 2 s future tolerance.
        assert!(verify_timestamp(now + 1000, 30));
        // 3 s ahead: beyond the 2 s future tolerance.
        assert!(!verify_timestamp(now + 3000, 30));
    }

    /// `hash_nonce` is deterministic and sensitive to changes in its input.
    #[test]
    fn test_nonce_verification() {
        let nonce = generate_nonce();
        let hash = hash_nonce(&nonce);
        assert_eq!(hash.len(), 32);
        assert_eq!(hash, hash_nonce(&nonce));

        // Flipping every bit of the first byte must change the digest.
        let mut different_nonce = nonce;
        different_nonce[0] ^= 0xFF;
        assert_ne!(hash, hash_nonce(&different_nonce));
    }
}