network_protocol/protocol/
handshake.rs

//! Secure handshake protocol implementation using Elliptic Curve Diffie-Hellman (ECDH).
//!
//! This module implements an ephemeral X25519 key exchange on top of `x25519-dalek`.
//! Both sides contribute fresh random nonces, and the client's init message carries a
//! timestamp so the server can reject stale or future-dated requests. This narrows the
//! window for replaying an init message, but this module keeps no record of previously
//! seen nonces, so it does not by itself rule out replay inside that window.
//!
//! **Per-session state**
//! Handshake state is not held in global singletons. It lives in session-scoped values
//! (`ClientHandshakeState`, `ServerHandshakeState`) that are passed through the
//! handshake functions. Concurrent handshakes therefore cannot trample each other's
//! state, and every connection starts from a clean state.
11
12use crate::error::{ProtocolError, Result};
13use crate::protocol::message::Message;
14use rand_core::{OsRng, RngCore};
15use sha2::{Digest, Sha256};
16use std::time::{SystemTime, UNIX_EPOCH};
17use x25519_dalek::{EphemeralSecret, PublicKey, SharedSecret};
18use zeroize::Zeroize;
19
20#[allow(unused_imports)]
21use tracing::{debug, instrument, warn};
22
23/// Client-side handshake state - passed through the handshake flow
24#[derive(Zeroize)]
25#[zeroize(drop)]
26pub struct ClientHandshakeState {
27    secret: Option<EphemeralSecret>,
28    public: Option<[u8; 32]>,
29    server_public: Option<[u8; 32]>,
30    client_nonce: Option<[u8; 16]>,
31    server_nonce: Option<[u8; 16]>,
32}
33
34impl ClientHandshakeState {
35    /// Create a new empty client handshake state
36    pub fn new() -> Self {
37        Self {
38            secret: None,
39            public: None,
40            server_public: None,
41            client_nonce: None,
42            server_nonce: None,
43        }
44    }
45
46    /// Get reference to client nonce (for testing)
47    #[cfg(test)]
48    pub fn client_nonce(&self) -> Option<&[u8; 16]> {
49        self.client_nonce.as_ref()
50    }
51
52    /// Get reference to server nonce (for testing)
53    #[cfg(test)]
54    pub fn server_nonce(&self) -> Option<&[u8; 16]> {
55        self.server_nonce.as_ref()
56    }
57}
58
59impl Default for ClientHandshakeState {
60    fn default() -> Self {
61        Self::new()
62    }
63}
64
65/// Server-side handshake state - passed through the handshake flow
66#[derive(Zeroize)]
67#[zeroize(drop)]
68pub struct ServerHandshakeState {
69    secret: Option<EphemeralSecret>,
70    public: Option<[u8; 32]>,
71    client_public: Option<[u8; 32]>,
72    client_nonce: Option<[u8; 16]>,
73    server_nonce: Option<[u8; 16]>,
74}
75
76impl ServerHandshakeState {
77    /// Create a new empty server handshake state
78    pub fn new() -> Self {
79        Self {
80            secret: None,
81            public: None,
82            client_public: None,
83            client_nonce: None,
84            server_nonce: None,
85        }
86    }
87
88    /// Get reference to server nonce (for testing)
89    #[cfg(test)]
90    pub fn server_nonce(&self) -> Option<&[u8; 16]> {
91        self.server_nonce.as_ref()
92    }
93
94    /// Get reference to client public key (for testing)
95    #[cfg(test)]
96    pub fn client_public(&self) -> Option<&[u8; 32]> {
97        self.client_public.as_ref()
98    }
99}
100
101impl Default for ServerHandshakeState {
102    fn default() -> Self {
103        Self::new()
104    }
105}
106
107/// Get the current timestamp in milliseconds
108///
109/// # Errors
110/// Returns a `ProtocolError::Custom` if the system time is earlier than UNIX_EPOCH
111fn current_timestamp() -> Result<u64> {
112    SystemTime::now()
113        .duration_since(UNIX_EPOCH)
114        .map(|duration| duration.as_millis() as u64)
115        .map_err(|_| ProtocolError::Custom("System time error: time went backwards".to_string()))
116}
117
118/// Generate a cryptographically secure random nonce
119fn generate_nonce() -> [u8; 16] {
120    let mut nonce = [0u8; 16];
121    OsRng.fill_bytes(&mut nonce);
122    nonce
123}
124
125/// Verify that a timestamp is recent enough
126/// Default threshold is 30 seconds with a strict 2-second future tolerance for clock skew
127pub fn verify_timestamp(timestamp: u64, max_age_seconds: u64) -> bool {
128    let current = match current_timestamp() {
129        Ok(time) => time,
130        Err(_) => return false,
131    };
132
133    let max_age_ms = max_age_seconds * 1000;
134    const FUTURE_TOLERANCE_MS: u64 = 2000; // 2 seconds max clock skew
135
136    // Check if timestamp is from the future (strict tolerance for clock skew)
137    if timestamp > current + FUTURE_TOLERANCE_MS {
138        return false;
139    }
140
141    // Check if timestamp is too old
142    if current > timestamp && current - timestamp > max_age_ms {
143        return false;
144    }
145
146    true
147}
148
149/// Compute hash of a nonce for verification
150fn hash_nonce(nonce: &[u8]) -> [u8; 32] {
151    let mut hasher = Sha256::new();
152    hasher.update(nonce);
153    hasher.finalize().into()
154}
155
156/// Derive a session key from a shared secret and nonces
157fn derive_key_from_shared_secret(
158    shared_secret: &SharedSecret,
159    client_nonce: &[u8],
160    server_nonce: &[u8],
161) -> [u8; 32] {
162    let mut hasher = Sha256::new();
163
164    // Include shared secret
165    hasher.update(shared_secret.as_bytes());
166
167    // Include both nonces for additional security (order matters for domain separation)
168    hasher.update(b"client_nonce");
169    hasher.update(client_nonce);
170    hasher.update(b"server_nonce");
171    hasher.update(server_nonce);
172
173    hasher.finalize().into()
174}
175
176/// Initiates secure handshake from the client side.
177/// Generates a new key pair and nonce for the client.
178///
179/// # Returns
180/// A tuple of (new `ClientHandshakeState`, `Message::SecureHandshakeInit`)
181///
182/// # Errors
183/// Returns timestamp errors if system time is invalid
184#[instrument]
185pub fn client_secure_handshake_init() -> Result<(ClientHandshakeState, Message)> {
186    // Generate a new client key pair using OsRng
187    let client_secret = EphemeralSecret::random_from_rng(OsRng);
188    let client_public = PublicKey::from(&client_secret);
189
190    // Generate nonce and timestamp
191    let nonce = generate_nonce();
192    let timestamp = current_timestamp()?;
193
194    let mut state = ClientHandshakeState::new();
195    state.secret = Some(client_secret);
196    state.public = Some(client_public.to_bytes());
197    state.client_nonce = Some(nonce);
198
199    debug!("Client initiating secure handshake");
200
201    Ok((
202        state,
203        Message::SecureHandshakeInit {
204            pub_key: client_public.to_bytes(),
205            timestamp,
206            nonce,
207        },
208    ))
209}
210
211/// Generates server response to client handshake initialization.
212/// Validates client timestamp, generates server key pair and nonce.
213///
214/// # Returns
215/// A tuple of (new `ServerHandshakeState`, `Message::SecureHandshakeResponse`)
216///
217/// # Errors
218/// Returns `ProtocolError::HandshakeError` if client timestamp is invalid or too old
219#[instrument(skip(client_pub_key, client_nonce))]
220pub fn server_secure_handshake_response(
221    client_pub_key: [u8; 32],
222    client_nonce: [u8; 16],
223    client_timestamp: u64,
224) -> Result<(ServerHandshakeState, Message)> {
225    // Validate the client timestamp (must be within last 30 seconds)
226    if !verify_timestamp(client_timestamp, 30) {
227        return Err(ProtocolError::HandshakeError(
228            "Invalid or stale timestamp".to_string(),
229        ));
230    }
231
232    // Generate server key pair and nonce
233    let server_secret = EphemeralSecret::random_from_rng(OsRng);
234    let server_public = PublicKey::from(&server_secret);
235    let server_nonce = generate_nonce();
236
237    // Compute verification hash of client nonce
238    let nonce_verification = hash_nonce(&client_nonce);
239
240    let mut state = ServerHandshakeState::new();
241    state.secret = Some(server_secret);
242    state.public = Some(server_public.to_bytes());
243    state.client_public = Some(client_pub_key);
244    state.client_nonce = Some(client_nonce);
245    state.server_nonce = Some(server_nonce);
246
247    debug!("Server responding to handshake initiation");
248
249    Ok((
250        state,
251        Message::SecureHandshakeResponse {
252            pub_key: server_public.to_bytes(),
253            nonce: server_nonce,
254            nonce_verification,
255        },
256    ))
257}
258
259/// Client verifies server response and sends verification message.
260/// Updates client state and returns confirmation message.
261///
262/// # Returns
263/// Updated `ClientHandshakeState` and `Message::SecureHandshakeConfirm`
264///
265/// # Errors
266/// Returns `ProtocolError::HandshakeError` if verification fails
267#[instrument(skip(state, server_pub_key, server_nonce, nonce_verification))]
268pub fn client_secure_handshake_verify(
269    mut state: ClientHandshakeState,
270    server_pub_key: [u8; 32],
271    server_nonce: [u8; 16],
272    nonce_verification: [u8; 32],
273) -> Result<(ClientHandshakeState, Message)> {
274    // Verify that server correctly verified our nonce
275    let client_nonce = state
276        .client_nonce
277        .ok_or_else(|| ProtocolError::HandshakeError("Client nonce not found".to_string()))?;
278
279    let expected_verification = hash_nonce(&client_nonce);
280
281    if expected_verification != nonce_verification {
282        return Err(ProtocolError::HandshakeError(
283            "Server failed to verify client nonce".to_string(),
284        ));
285    }
286
287    // Store server info
288    state.server_public = Some(server_pub_key);
289    state.server_nonce = Some(server_nonce);
290
291    // Hash the server nonce for verification
292    let hash = hash_nonce(&server_nonce);
293
294    debug!("Client verified server response");
295
296    Ok((
297        state,
298        Message::SecureHandshakeConfirm {
299            nonce_verification: hash,
300        },
301    ))
302}
303
304/// Server verifies client's confirmation and derives session key.
305/// Returns the session key if verification succeeds.
306///
307/// # Returns
308/// The derived session key (32 bytes)
309///
310/// # Errors
311/// Returns `ProtocolError::HandshakeError` if verification fails or state is incomplete
312#[instrument(skip(state, nonce_verification))]
313pub fn server_secure_handshake_finalize(
314    mut state: ServerHandshakeState,
315    nonce_verification: [u8; 32],
316) -> Result<[u8; 32]> {
317    // Verify that client correctly verified our nonce
318    let server_nonce = state
319        .server_nonce
320        .ok_or_else(|| ProtocolError::HandshakeError("Server nonce not found".to_string()))?;
321
322    let expected_verification = hash_nonce(&server_nonce);
323
324    if expected_verification != nonce_verification {
325        return Err(ProtocolError::HandshakeError(
326            "Client failed to verify server nonce".to_string(),
327        ));
328    }
329
330    // Extract and take ownership of secret data for key derivation
331    let server_secret = state
332        .secret
333        .take()
334        .ok_or_else(|| ProtocolError::HandshakeError("Server secret not found".to_string()))?;
335    let client_public_bytes = state
336        .client_public
337        .ok_or_else(|| ProtocolError::HandshakeError("Client public key not found".to_string()))?;
338    let client_nonce = state
339        .client_nonce
340        .ok_or_else(|| ProtocolError::HandshakeError("Client nonce not found".to_string()))?;
341
342    // Perform ECDH to derive shared secret
343    let client_public = PublicKey::from(client_public_bytes);
344    let shared_secret = server_secret.diffie_hellman(&client_public);
345
346    // Derive final key using shared secret and both nonces
347    let key = derive_key_from_shared_secret(&shared_secret, &client_nonce, &server_nonce);
348
349    // State will be zeroized on drop due to Zeroize derive
350    debug!("Server finalized handshake and derived session key");
351
352    Ok(key)
353}
354
355/// Client derives the session key.
356/// Must be called after `client_secure_handshake_verify`.
357///
358/// # Returns
359/// The derived session key (32 bytes)
360///
361/// # Errors
362/// Returns `ProtocolError::HandshakeError` if state is incomplete
363#[instrument(skip(state))]
364pub fn client_derive_session_key(mut state: ClientHandshakeState) -> Result<[u8; 32]> {
365    // Extract required data
366    let client_secret = state
367        .secret
368        .take()
369        .ok_or_else(|| ProtocolError::HandshakeError("Client secret not found".to_string()))?;
370    let server_public_bytes = state
371        .server_public
372        .ok_or_else(|| ProtocolError::HandshakeError("Server public key not found".to_string()))?;
373    let client_nonce = state
374        .client_nonce
375        .ok_or_else(|| ProtocolError::HandshakeError("Client nonce not found".to_string()))?;
376    let server_nonce = state
377        .server_nonce
378        .ok_or_else(|| ProtocolError::HandshakeError("Server nonce not found".to_string()))?;
379
380    // Perform ECDH to derive shared secret
381    let server_public = PublicKey::from(server_public_bytes);
382    let shared_secret = client_secret.diffie_hellman(&server_public);
383
384    // Derive session key
385    let key = derive_key_from_shared_secret(&shared_secret, &client_nonce, &server_nonce);
386
387    // State will be zeroized on drop
388    debug!("Client derived session key");
389
390    Ok(key)
391}
392
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;

    /// Split a `SecureHandshakeInit` into (pub_key, timestamp, nonce).
    fn expect_init(msg: Message) -> ([u8; 32], u64, [u8; 16]) {
        match msg {
            Message::SecureHandshakeInit {
                pub_key,
                timestamp,
                nonce,
            } => (pub_key, timestamp, nonce),
            _ => panic!("Wrong message type"),
        }
    }

    /// Split a `SecureHandshakeResponse` into (pub_key, nonce, nonce_verification).
    fn expect_response(msg: Message) -> ([u8; 32], [u8; 16], [u8; 32]) {
        match msg {
            Message::SecureHandshakeResponse {
                pub_key,
                nonce,
                nonce_verification,
            } => (pub_key, nonce, nonce_verification),
            _ => panic!("Wrong message type"),
        }
    }

    /// Extract the hash carried by a `SecureHandshakeConfirm`.
    fn expect_confirm(msg: Message) -> [u8; 32] {
        match msg {
            Message::SecureHandshakeConfirm { nonce_verification } => nonce_verification,
            _ => panic!("Wrong message type"),
        }
    }

    #[test]
    fn test_per_session_state_isolation() {
        // Two handshakes run side by side must never share state.
        let (client_a, init_a) = client_secure_handshake_init().unwrap();
        let (client_b, init_b) = client_secure_handshake_init().unwrap();

        let (client_pub_a, ts_a, client_nonce_a) = expect_init(init_a);
        let (client_pub_b, ts_b, client_nonce_b) = expect_init(init_b);

        assert_ne!(client_pub_a, client_pub_b);
        assert_ne!(client_nonce_a, client_nonce_b);

        // Each server response is built from its own init only.
        let (server_a, resp_a) =
            server_secure_handshake_response(client_pub_a, client_nonce_a, ts_a).unwrap();
        let (server_b, resp_b) =
            server_secure_handshake_response(client_pub_b, client_nonce_b, ts_b).unwrap();

        let (server_pub_a, server_nonce_a, check_a) = expect_response(resp_a);
        let (server_pub_b, server_nonce_b, check_b) = expect_response(resp_b);

        assert_ne!(server_pub_a, server_pub_b);
        assert_ne!(server_nonce_a, server_nonce_b);

        // Clients check their own server's reply.
        let (client_a, confirm_a) =
            client_secure_handshake_verify(client_a, server_pub_a, server_nonce_a, check_a)
                .unwrap();
        let (client_b, confirm_b) =
            client_secure_handshake_verify(client_b, server_pub_b, server_nonce_b, check_b)
                .unwrap();

        let confirm_a = expect_confirm(confirm_a);
        let confirm_b = expect_confirm(confirm_b);

        assert_ne!(confirm_a, confirm_b);

        // Finish both sides of each handshake.
        let server_key_a = server_secure_handshake_finalize(server_a, confirm_a).unwrap();
        let client_key_a = client_derive_session_key(client_a).unwrap();

        let server_key_b = server_secure_handshake_finalize(server_b, confirm_b).unwrap();
        let client_key_b = client_derive_session_key(client_b).unwrap();

        // Each pair agrees on its key...
        assert_eq!(server_key_a, client_key_a);
        assert_eq!(server_key_b, client_key_b);

        // ...and the two sessions end up with different keys.
        assert_ne!(server_key_a, server_key_b);
    }

    #[test]
    fn test_timestamp_validation() {
        let now = current_timestamp().unwrap();
        let cases: [(u64, bool); 5] = [
            (now, true),
            (now - 10000, true),  // 10 seconds ago
            (now - 31000, false), // 31 seconds ago
            (now + 1000, true),   // 1 second ahead (within tolerance)
            (now + 3000, false),  // 3 seconds ahead (beyond tolerance)
        ];
        for &(ts, expected) in cases.iter() {
            assert_eq!(verify_timestamp(ts, 30), expected, "timestamp {}", ts);
        }
    }

    #[test]
    fn test_nonce_verification() {
        let nonce = generate_nonce();
        let digest = hash_nonce(&nonce);
        assert_eq!(digest.len(), 32);

        // Hashing is deterministic.
        assert_eq!(digest, hash_nonce(&nonce));

        // Flipping one byte changes the hash.
        let mut tweaked = nonce;
        tweaked[0] ^= 0xFF;
        assert_ne!(digest, hash_nonce(&tweaked));
    }
}