network_protocol/protocol/
handshake.rs1use crate::error::{ProtocolError, Result};
13use crate::protocol::message::Message;
14use rand_core::{OsRng, RngCore};
15use sha2::{Digest, Sha256};
16use std::time::{SystemTime, UNIX_EPOCH};
17use x25519_dalek::{EphemeralSecret, PublicKey, SharedSecret};
18use zeroize::Zeroize;
19
20#[allow(unused_imports)]
21use tracing::{debug, instrument, warn};
22
23#[derive(Zeroize)]
25#[zeroize(drop)]
26pub struct ClientHandshakeState {
27 secret: Option<EphemeralSecret>,
28 public: Option<[u8; 32]>,
29 server_public: Option<[u8; 32]>,
30 client_nonce: Option<[u8; 16]>,
31 server_nonce: Option<[u8; 16]>,
32}
33
34impl ClientHandshakeState {
35 pub fn new() -> Self {
37 Self {
38 secret: None,
39 public: None,
40 server_public: None,
41 client_nonce: None,
42 server_nonce: None,
43 }
44 }
45
46 #[cfg(test)]
48 pub fn client_nonce(&self) -> Option<&[u8; 16]> {
49 self.client_nonce.as_ref()
50 }
51
52 #[cfg(test)]
54 pub fn server_nonce(&self) -> Option<&[u8; 16]> {
55 self.server_nonce.as_ref()
56 }
57}
58
59impl Default for ClientHandshakeState {
60 fn default() -> Self {
61 Self::new()
62 }
63}
64
65#[derive(Zeroize)]
67#[zeroize(drop)]
68pub struct ServerHandshakeState {
69 secret: Option<EphemeralSecret>,
70 public: Option<[u8; 32]>,
71 client_public: Option<[u8; 32]>,
72 client_nonce: Option<[u8; 16]>,
73 server_nonce: Option<[u8; 16]>,
74}
75
76impl ServerHandshakeState {
77 pub fn new() -> Self {
79 Self {
80 secret: None,
81 public: None,
82 client_public: None,
83 client_nonce: None,
84 server_nonce: None,
85 }
86 }
87
88 #[cfg(test)]
90 pub fn server_nonce(&self) -> Option<&[u8; 16]> {
91 self.server_nonce.as_ref()
92 }
93
94 #[cfg(test)]
96 pub fn client_public(&self) -> Option<&[u8; 32]> {
97 self.client_public.as_ref()
98 }
99}
100
101impl Default for ServerHandshakeState {
102 fn default() -> Self {
103 Self::new()
104 }
105}
106
107fn current_timestamp() -> Result<u64> {
112 SystemTime::now()
113 .duration_since(UNIX_EPOCH)
114 .map(|duration| duration.as_millis() as u64)
115 .map_err(|_| ProtocolError::Custom("System time error: time went backwards".to_string()))
116}
117
118fn generate_nonce() -> [u8; 16] {
120 let mut nonce = [0u8; 16];
121 OsRng.fill_bytes(&mut nonce);
122 nonce
123}
124
125pub fn verify_timestamp(timestamp: u64, max_age_seconds: u64) -> bool {
128 let current = match current_timestamp() {
129 Ok(time) => time,
130 Err(_) => return false,
131 };
132
133 let max_age_ms = max_age_seconds * 1000;
134 const FUTURE_TOLERANCE_MS: u64 = 2000; if timestamp > current + FUTURE_TOLERANCE_MS {
138 return false;
139 }
140
141 if current > timestamp && current - timestamp > max_age_ms {
143 return false;
144 }
145
146 true
147}
148
149fn hash_nonce(nonce: &[u8]) -> [u8; 32] {
151 let mut hasher = Sha256::new();
152 hasher.update(nonce);
153 hasher.finalize().into()
154}
155
156fn derive_key_from_shared_secret(
158 shared_secret: &SharedSecret,
159 client_nonce: &[u8],
160 server_nonce: &[u8],
161) -> [u8; 32] {
162 let mut hasher = Sha256::new();
163
164 hasher.update(shared_secret.as_bytes());
166
167 hasher.update(b"client_nonce");
169 hasher.update(client_nonce);
170 hasher.update(b"server_nonce");
171 hasher.update(server_nonce);
172
173 hasher.finalize().into()
174}
175
176#[instrument]
185pub fn client_secure_handshake_init() -> Result<(ClientHandshakeState, Message)> {
186 let client_secret = EphemeralSecret::random_from_rng(OsRng);
188 let client_public = PublicKey::from(&client_secret);
189
190 let nonce = generate_nonce();
192 let timestamp = current_timestamp()?;
193
194 let mut state = ClientHandshakeState::new();
195 state.secret = Some(client_secret);
196 state.public = Some(client_public.to_bytes());
197 state.client_nonce = Some(nonce);
198
199 debug!("Client initiating secure handshake");
200
201 Ok((
202 state,
203 Message::SecureHandshakeInit {
204 pub_key: client_public.to_bytes(),
205 timestamp,
206 nonce,
207 },
208 ))
209}
210
211#[instrument(skip(client_pub_key, client_nonce))]
220pub fn server_secure_handshake_response(
221 client_pub_key: [u8; 32],
222 client_nonce: [u8; 16],
223 client_timestamp: u64,
224) -> Result<(ServerHandshakeState, Message)> {
225 if !verify_timestamp(client_timestamp, 30) {
227 return Err(ProtocolError::HandshakeError(
228 "Invalid or stale timestamp".to_string(),
229 ));
230 }
231
232 let server_secret = EphemeralSecret::random_from_rng(OsRng);
234 let server_public = PublicKey::from(&server_secret);
235 let server_nonce = generate_nonce();
236
237 let nonce_verification = hash_nonce(&client_nonce);
239
240 let mut state = ServerHandshakeState::new();
241 state.secret = Some(server_secret);
242 state.public = Some(server_public.to_bytes());
243 state.client_public = Some(client_pub_key);
244 state.client_nonce = Some(client_nonce);
245 state.server_nonce = Some(server_nonce);
246
247 debug!("Server responding to handshake initiation");
248
249 Ok((
250 state,
251 Message::SecureHandshakeResponse {
252 pub_key: server_public.to_bytes(),
253 nonce: server_nonce,
254 nonce_verification,
255 },
256 ))
257}
258
259#[instrument(skip(state, server_pub_key, server_nonce, nonce_verification))]
268pub fn client_secure_handshake_verify(
269 mut state: ClientHandshakeState,
270 server_pub_key: [u8; 32],
271 server_nonce: [u8; 16],
272 nonce_verification: [u8; 32],
273) -> Result<(ClientHandshakeState, Message)> {
274 let client_nonce = state
276 .client_nonce
277 .ok_or_else(|| ProtocolError::HandshakeError("Client nonce not found".to_string()))?;
278
279 let expected_verification = hash_nonce(&client_nonce);
280
281 if expected_verification != nonce_verification {
282 return Err(ProtocolError::HandshakeError(
283 "Server failed to verify client nonce".to_string(),
284 ));
285 }
286
287 state.server_public = Some(server_pub_key);
289 state.server_nonce = Some(server_nonce);
290
291 let hash = hash_nonce(&server_nonce);
293
294 debug!("Client verified server response");
295
296 Ok((
297 state,
298 Message::SecureHandshakeConfirm {
299 nonce_verification: hash,
300 },
301 ))
302}
303
304#[instrument(skip(state, nonce_verification))]
313pub fn server_secure_handshake_finalize(
314 mut state: ServerHandshakeState,
315 nonce_verification: [u8; 32],
316) -> Result<[u8; 32]> {
317 let server_nonce = state
319 .server_nonce
320 .ok_or_else(|| ProtocolError::HandshakeError("Server nonce not found".to_string()))?;
321
322 let expected_verification = hash_nonce(&server_nonce);
323
324 if expected_verification != nonce_verification {
325 return Err(ProtocolError::HandshakeError(
326 "Client failed to verify server nonce".to_string(),
327 ));
328 }
329
330 let server_secret = state
332 .secret
333 .take()
334 .ok_or_else(|| ProtocolError::HandshakeError("Server secret not found".to_string()))?;
335 let client_public_bytes = state
336 .client_public
337 .ok_or_else(|| ProtocolError::HandshakeError("Client public key not found".to_string()))?;
338 let client_nonce = state
339 .client_nonce
340 .ok_or_else(|| ProtocolError::HandshakeError("Client nonce not found".to_string()))?;
341
342 let client_public = PublicKey::from(client_public_bytes);
344 let shared_secret = server_secret.diffie_hellman(&client_public);
345
346 let key = derive_key_from_shared_secret(&shared_secret, &client_nonce, &server_nonce);
348
349 debug!("Server finalized handshake and derived session key");
351
352 Ok(key)
353}
354
355#[instrument(skip(state))]
364pub fn client_derive_session_key(mut state: ClientHandshakeState) -> Result<[u8; 32]> {
365 let client_secret = state
367 .secret
368 .take()
369 .ok_or_else(|| ProtocolError::HandshakeError("Client secret not found".to_string()))?;
370 let server_public_bytes = state
371 .server_public
372 .ok_or_else(|| ProtocolError::HandshakeError("Server public key not found".to_string()))?;
373 let client_nonce = state
374 .client_nonce
375 .ok_or_else(|| ProtocolError::HandshakeError("Client nonce not found".to_string()))?;
376 let server_nonce = state
377 .server_nonce
378 .ok_or_else(|| ProtocolError::HandshakeError("Server nonce not found".to_string()))?;
379
380 let server_public = PublicKey::from(server_public_bytes);
382 let shared_secret = client_secret.diffie_hellman(&server_public);
383
384 let key = derive_key_from_shared_secret(&shared_secret, &client_nonce, &server_nonce);
386
387 debug!("Client derived session key");
389
390 Ok(key)
391}
392
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;

    /// Extracts `(pub_key, timestamp, nonce)` from a `SecureHandshakeInit` message.
    fn unpack_init(msg: Message) -> ([u8; 32], u64, [u8; 16]) {
        match msg {
            Message::SecureHandshakeInit {
                pub_key,
                timestamp,
                nonce,
            } => (pub_key, timestamp, nonce),
            _ => panic!("Wrong message type"),
        }
    }

    /// Extracts `(pub_key, nonce, nonce_verification)` from a `SecureHandshakeResponse`.
    fn unpack_response(msg: Message) -> ([u8; 32], [u8; 16], [u8; 32]) {
        match msg {
            Message::SecureHandshakeResponse {
                pub_key,
                nonce,
                nonce_verification,
            } => (pub_key, nonce, nonce_verification),
            _ => panic!("Wrong message type"),
        }
    }

    /// Extracts `nonce_verification` from a `SecureHandshakeConfirm` message.
    fn unpack_confirm(msg: Message) -> [u8; 32] {
        match msg {
            Message::SecureHandshakeConfirm { nonce_verification } => nonce_verification,
            _ => panic!("Wrong message type"),
        }
    }

    #[test]
    fn test_per_session_state_isolation() {
        let (client1, msg1) = client_secure_handshake_init().unwrap();
        let (client2, msg2) = client_secure_handshake_init().unwrap();

        let (pub_key1, ts1, nonce1) = unpack_init(msg1);
        let (pub_key2, ts2, nonce2) = unpack_init(msg2);

        // Each session gets its own ephemeral key and nonce.
        assert_ne!(pub_key1, pub_key2);
        assert_ne!(nonce1, nonce2);

        let (server1, resp1) = server_secure_handshake_response(pub_key1, nonce1, ts1).unwrap();
        let (server2, resp2) = server_secure_handshake_response(pub_key2, nonce2, ts2).unwrap();

        let (server_pub1, server_nonce1, verify1) = unpack_response(resp1);
        let (server_pub2, server_nonce2, verify2) = unpack_response(resp2);

        assert_ne!(server_pub1, server_pub2);
        assert_ne!(server_nonce1, server_nonce2);

        let (client1_verified, confirm1) =
            client_secure_handshake_verify(client1, server_pub1, server_nonce1, verify1).unwrap();
        let (client2_verified, confirm2) =
            client_secure_handshake_verify(client2, server_pub2, server_nonce2, verify2).unwrap();

        let confirm_hash1 = unpack_confirm(confirm1);
        let confirm_hash2 = unpack_confirm(confirm2);

        assert_ne!(confirm_hash1, confirm_hash2);

        let key1_server = server_secure_handshake_finalize(server1, confirm_hash1).unwrap();
        let key1_client = client_derive_session_key(client1_verified).unwrap();

        let key2_server = server_secure_handshake_finalize(server2, confirm_hash2).unwrap();
        let key2_client = client_derive_session_key(client2_verified).unwrap();

        // Both ends of a session agree on the key; different sessions get different keys.
        assert_eq!(key1_server, key1_client);
        assert_eq!(key2_server, key2_client);
        assert_ne!(key1_server, key2_server);
    }

    #[test]
    fn test_stale_init_timestamp_rejected() {
        let (_client, msg) = client_secure_handshake_init().unwrap();
        let (pub_key, timestamp, nonce) = unpack_init(msg);

        // One minute old is outside the server's 30 second window.
        let result = server_secure_handshake_response(pub_key, nonce, timestamp - 60_000);
        assert!(matches!(result, Err(ProtocolError::HandshakeError(_))));
    }

    #[test]
    fn test_forged_server_verification_rejected() {
        let (client, msg) = client_secure_handshake_init().unwrap();
        let (pub_key, timestamp, nonce) = unpack_init(msg);
        let (_server, resp) = server_secure_handshake_response(pub_key, nonce, timestamp).unwrap();
        let (server_pub, server_nonce, mut verification) = unpack_response(resp);

        verification[0] ^= 0xFF;
        let result = client_secure_handshake_verify(client, server_pub, server_nonce, verification);
        assert!(matches!(result, Err(ProtocolError::HandshakeError(_))));
    }

    #[test]
    fn test_forged_client_confirmation_rejected() {
        let (client, msg) = client_secure_handshake_init().unwrap();
        let (pub_key, timestamp, nonce) = unpack_init(msg);
        let (server, resp) = server_secure_handshake_response(pub_key, nonce, timestamp).unwrap();
        let (server_pub, server_nonce, verification) = unpack_response(resp);
        let (_client, confirm) =
            client_secure_handshake_verify(client, server_pub, server_nonce, verification).unwrap();

        let mut confirmation = unpack_confirm(confirm);
        confirmation[0] ^= 0xFF;
        let result = server_secure_handshake_finalize(server, confirmation);
        assert!(matches!(result, Err(ProtocolError::HandshakeError(_))));
    }

    #[test]
    fn test_timestamp_validation() {
        let now = current_timestamp().unwrap();
        assert!(verify_timestamp(now, 30));
        // Within the 30 second age limit.
        assert!(verify_timestamp(now - 10_000, 30));
        // Older than the 30 second age limit.
        assert!(!verify_timestamp(now - 31_000, 30));
        // Within the 2 second future tolerance.
        assert!(verify_timestamp(now + 1_000, 30));
        // Beyond the 2 second future tolerance.
        assert!(!verify_timestamp(now + 3_000, 30));
    }

    #[test]
    fn test_nonce_verification() {
        let nonce = generate_nonce();
        let hash = hash_nonce(&nonce);
        assert_eq!(hash.len(), 32);
        assert_eq!(hash, hash_nonce(&nonce));

        let mut different_nonce = nonce;
        different_nonce[0] ^= 0xFF;
        assert_ne!(hash, hash_nonce(&different_nonce));
    }
}