Skip to main content

network_protocol/protocol/
handshake.rs

1//! Secure handshake protocol implementation using Elliptic Curve Diffie-Hellman (ECDH)
2//!
3//! This module implements a secure cryptographic handshake based on x25519-dalek
4//! with protection against replay attacks using timestamped nonces.
5//!
6//! **Key Change: Per-Session State**
7//! Instead of global singletons, handshake state is now managed through session-scoped
8//! structures (`ClientHandshakeState`, `ServerHandshakeState`) that are passed through
9//! the handshake flow. This prevents concurrent handshake state trampling and ensures
10//! clean state per connection.
11
12use crate::error::{constants, ProtocolError, Result};
13use crate::protocol::message::Message;
14use crate::utils::replay_cache::ReplayCache;
15use rand_core::{OsRng, RngCore};
16use sha2::{Digest, Sha256};
17use std::time::{SystemTime, UNIX_EPOCH};
18use x25519_dalek::{EphemeralSecret, PublicKey, SharedSecret};
19use zeroize::Zeroize;
20
21#[allow(unused_imports)]
22use tracing::{debug, instrument, warn};
23
24/// Client-side handshake state - passed through the handshake flow
25#[derive(Zeroize)]
26#[zeroize(drop)]
27pub struct ClientHandshakeState {
28    secret: Option<EphemeralSecret>,
29    public: Option<[u8; 32]>,
30    server_public: Option<[u8; 32]>,
31    client_nonce: Option<[u8; 16]>,
32    server_nonce: Option<[u8; 16]>,
33}
34
35impl ClientHandshakeState {
36    /// Create a new empty client handshake state
37    pub fn new() -> Self {
38        Self {
39            secret: None,
40            public: None,
41            server_public: None,
42            client_nonce: None,
43            server_nonce: None,
44        }
45    }
46
47    /// Get reference to client nonce (for testing)
48    #[cfg(test)]
49    pub fn client_nonce(&self) -> Option<&[u8; 16]> {
50        self.client_nonce.as_ref()
51    }
52
53    /// Get reference to server nonce (for testing)
54    #[cfg(test)]
55    pub fn server_nonce(&self) -> Option<&[u8; 16]> {
56        self.server_nonce.as_ref()
57    }
58}
59
60impl Default for ClientHandshakeState {
61    fn default() -> Self {
62        Self::new()
63    }
64}
65
66/// Server-side handshake state - passed through the handshake flow
67#[derive(Zeroize)]
68#[zeroize(drop)]
69pub struct ServerHandshakeState {
70    secret: Option<EphemeralSecret>,
71    public: Option<[u8; 32]>,
72    client_public: Option<[u8; 32]>,
73    client_nonce: Option<[u8; 16]>,
74    server_nonce: Option<[u8; 16]>,
75}
76
77impl ServerHandshakeState {
78    /// Create a new empty server handshake state
79    pub fn new() -> Self {
80        Self {
81            secret: None,
82            public: None,
83            client_public: None,
84            client_nonce: None,
85            server_nonce: None,
86        }
87    }
88
89    /// Get reference to server nonce (for testing)
90    #[cfg(test)]
91    pub fn server_nonce(&self) -> Option<&[u8; 16]> {
92        self.server_nonce.as_ref()
93    }
94
95    /// Get reference to client public key (for testing)
96    #[cfg(test)]
97    pub fn client_public(&self) -> Option<&[u8; 32]> {
98        self.client_public.as_ref()
99    }
100}
101
102impl Default for ServerHandshakeState {
103    fn default() -> Self {
104        Self::new()
105    }
106}
107
108/// Get the current timestamp in milliseconds
109///
110/// # Errors
111/// Returns a `ProtocolError::Custom` if the system time is earlier than UNIX_EPOCH
112fn current_timestamp() -> Result<u64> {
113    SystemTime::now()
114        .duration_since(UNIX_EPOCH)
115        .map(|duration| duration.as_millis() as u64)
116        .map_err(|_| ProtocolError::Custom(constants::ERR_SYSTEM_TIME.into()))
117}
118
119/// Generate a cryptographically secure random nonce
120fn generate_nonce() -> [u8; 16] {
121    let mut nonce = [0u8; 16];
122    OsRng.fill_bytes(&mut nonce);
123    nonce
124}
125
126/// Verify that a timestamp is recent enough
127/// Default threshold is 30 seconds with a strict 2-second future tolerance for clock skew
128pub fn verify_timestamp(timestamp: u64, max_age_seconds: u64) -> bool {
129    let current = match current_timestamp() {
130        Ok(time) => time,
131        Err(_) => return false,
132    };
133
134    let max_age_ms = max_age_seconds * 1000;
135    const FUTURE_TOLERANCE_MS: u64 = 2000; // 2 seconds max clock skew
136
137    // Check if timestamp is from the future (strict tolerance for clock skew)
138    if timestamp > current + FUTURE_TOLERANCE_MS {
139        return false;
140    }
141
142    // Check if timestamp is too old
143    if current > timestamp && current - timestamp > max_age_ms {
144        return false;
145    }
146
147    true
148}
149
150/// Compute hash of a nonce for verification
151fn hash_nonce(nonce: &[u8]) -> [u8; 32] {
152    let mut hasher = Sha256::new();
153    hasher.update(nonce);
154    hasher.finalize().into()
155}
156
157/// Derive a session key from a shared secret and nonces
158fn derive_key_from_shared_secret(
159    shared_secret: &SharedSecret,
160    client_nonce: &[u8],
161    server_nonce: &[u8],
162) -> [u8; 32] {
163    let mut hasher = Sha256::new();
164
165    // Include shared secret
166    hasher.update(shared_secret.as_bytes());
167
168    // Include both nonces for additional security (order matters for domain separation)
169    hasher.update(b"client_nonce");
170    hasher.update(client_nonce);
171    hasher.update(b"server_nonce");
172    hasher.update(server_nonce);
173
174    hasher.finalize().into()
175}
176
177/// Initiates secure handshake from the client side.
178/// Generates a new key pair and nonce for the client.
179///
180/// # Returns
181/// A tuple of (new `ClientHandshakeState`, `Message::SecureHandshakeInit`)
182///
183/// # Errors
184/// Returns timestamp errors if system time is invalid
185#[instrument]
186pub fn client_secure_handshake_init() -> Result<(ClientHandshakeState, Message)> {
187    // Generate a new client key pair using OsRng
188    let client_secret = EphemeralSecret::random_from_rng(OsRng);
189    let client_public = PublicKey::from(&client_secret);
190
191    // Generate nonce and timestamp
192    let nonce = generate_nonce();
193    let timestamp = current_timestamp()?;
194
195    let mut state = ClientHandshakeState::new();
196    state.secret = Some(client_secret);
197    state.public = Some(client_public.to_bytes());
198    state.client_nonce = Some(nonce);
199
200    debug!("Client initiating secure handshake");
201
202    Ok((
203        state,
204        Message::SecureHandshakeInit {
205            pub_key: client_public.to_bytes(),
206            timestamp,
207            nonce,
208        },
209    ))
210}
211
212/// Generates server response to client handshake initialization.
213/// Validates client timestamp, generates server key pair and nonce.
214///
215/// # Returns
216/// A tuple of (new `ServerHandshakeState`, `Message::SecureHandshakeResponse`)
217///
218/// # Errors
219/// Returns `ProtocolError::HandshakeError` if client timestamp is invalid or too old
220#[instrument(skip(client_pub_key, client_nonce, replay_cache))]
221pub fn server_secure_handshake_response(
222    client_pub_key: [u8; 32],
223    client_nonce: [u8; 16],
224    client_timestamp: u64,
225    peer_id: &str,
226    replay_cache: &mut ReplayCache,
227) -> Result<(ServerHandshakeState, Message)> {
228    // Validate the client timestamp (must be within last 30 seconds)
229    if !verify_timestamp(client_timestamp, 30) {
230        return Err(ProtocolError::HandshakeError(
231            constants::ERR_INVALID_TIMESTAMP.into(),
232        ));
233    }
234
235    // Check for replay attacks using the cache
236    if replay_cache.is_replay(peer_id, &client_nonce, client_timestamp) {
237        return Err(ProtocolError::HandshakeError(
238            constants::ERR_REPLAY_ATTACK.into(),
239        ));
240    }
241
242    // Generate server key pair and nonce
243    let server_secret = EphemeralSecret::random_from_rng(OsRng);
244    let server_public = PublicKey::from(&server_secret);
245    let server_nonce = generate_nonce();
246
247    // Compute verification hash of client nonce
248    let nonce_verification = hash_nonce(&client_nonce);
249
250    let mut state = ServerHandshakeState::new();
251    state.secret = Some(server_secret);
252    state.public = Some(server_public.to_bytes());
253    state.client_public = Some(client_pub_key);
254    state.client_nonce = Some(client_nonce);
255    state.server_nonce = Some(server_nonce);
256
257    debug!("Server responding to handshake initiation");
258
259    Ok((
260        state,
261        Message::SecureHandshakeResponse {
262            pub_key: server_public.to_bytes(),
263            nonce: server_nonce,
264            nonce_verification,
265        },
266    ))
267}
268
269/// Client verifies server response and sends verification message.
270/// Updates client state and returns confirmation message.
271///
272/// # Returns
273/// Updated `ClientHandshakeState` and `Message::SecureHandshakeConfirm`
274///
275/// # Errors
276/// Returns `ProtocolError::HandshakeError` if verification fails
277#[instrument(skip(state, server_pub_key, server_nonce, nonce_verification, replay_cache))]
278pub fn client_secure_handshake_verify(
279    mut state: ClientHandshakeState,
280    server_pub_key: [u8; 32],
281    server_nonce: [u8; 16],
282    nonce_verification: [u8; 32],
283    peer_id: &str,
284    replay_cache: &mut ReplayCache,
285) -> Result<(ClientHandshakeState, Message)> {
286    // Check for replay attacks using the cache
287    if replay_cache.is_replay(peer_id, &server_nonce, 0) {
288        // Use 0 for server nonce timestamp check
289        return Err(ProtocolError::HandshakeError(
290            constants::ERR_REPLAY_ATTACK.into(),
291        ));
292    }
293
294    // Verify that server correctly verified our nonce
295    let client_nonce = state.client_nonce.ok_or_else(|| {
296        ProtocolError::HandshakeError(constants::ERR_CLIENT_NONCE_NOT_FOUND.into())
297    })?;
298
299    let expected_verification = hash_nonce(&client_nonce);
300
301    if expected_verification != nonce_verification {
302        return Err(ProtocolError::HandshakeError(
303            constants::ERR_NONCE_VERIFICATION_FAILED.into(),
304        ));
305    }
306
307    // Store server info
308    state.server_public = Some(server_pub_key);
309    state.server_nonce = Some(server_nonce);
310    // Verify that server correctly verified our nonce
311    let client_nonce = state.client_nonce.ok_or_else(|| {
312        ProtocolError::HandshakeError(constants::ERR_CLIENT_NONCE_NOT_FOUND.into())
313    })?;
314
315    let expected_verification = hash_nonce(&client_nonce);
316
317    if expected_verification != nonce_verification {
318        return Err(ProtocolError::HandshakeError(
319            constants::ERR_NONCE_VERIFICATION_FAILED.into(),
320        ));
321    }
322
323    // Store server info
324    state.server_public = Some(server_pub_key);
325    state.server_nonce = Some(server_nonce);
326
327    // Hash the server nonce for verification
328    let hash = hash_nonce(&server_nonce);
329
330    debug!("Client verified server response");
331
332    Ok((
333        state,
334        Message::SecureHandshakeConfirm {
335            nonce_verification: hash,
336        },
337    ))
338}
339
340/// Server verifies client's confirmation and derives session key.
341/// Returns the session key if verification succeeds.
342///
343/// # Returns
344/// The derived session key (32 bytes)
345///
346/// # Errors
347/// Returns `ProtocolError::HandshakeError` if verification fails or state is incomplete
348#[instrument(skip(state, nonce_verification))]
349pub fn server_secure_handshake_finalize(
350    mut state: ServerHandshakeState,
351    nonce_verification: [u8; 32],
352) -> Result<[u8; 32]> {
353    // Verify that client correctly verified our nonce
354    let server_nonce = state.server_nonce.ok_or_else(|| {
355        ProtocolError::HandshakeError(constants::ERR_SERVER_NONCE_NOT_FOUND.into())
356    })?;
357
358    let expected_verification = hash_nonce(&server_nonce);
359
360    if expected_verification != nonce_verification {
361        return Err(ProtocolError::HandshakeError(
362            constants::ERR_SERVER_VERIFICATION_FAILED.into(),
363        ));
364    }
365
366    // Extract and take ownership of secret data for key derivation
367    let server_secret = state.secret.take().ok_or_else(|| {
368        ProtocolError::HandshakeError(constants::ERR_SERVER_SECRET_NOT_FOUND.into())
369    })?;
370    let client_public_bytes = state.client_public.ok_or_else(|| {
371        ProtocolError::HandshakeError(constants::ERR_CLIENT_PUBLIC_NOT_FOUND.into())
372    })?;
373    let client_nonce = state.client_nonce.ok_or_else(|| {
374        ProtocolError::HandshakeError(constants::ERR_CLIENT_NONCE_NOT_FOUND.into())
375    })?;
376
377    // Perform ECDH to derive shared secret
378    let client_public = PublicKey::from(client_public_bytes);
379    let shared_secret = server_secret.diffie_hellman(&client_public);
380
381    // Derive final key using shared secret and both nonces
382    let key = derive_key_from_shared_secret(&shared_secret, &client_nonce, &server_nonce);
383
384    // State will be zeroized on drop due to Zeroize derive
385    debug!("Server finalized handshake and derived session key");
386
387    Ok(key)
388}
389
390/// Client derives the session key.
391/// Must be called after `client_secure_handshake_verify`.
392///
393/// # Returns
394/// The derived session key (32 bytes)
395///
396/// # Errors
397/// Returns `ProtocolError::HandshakeError` if state is incomplete
398#[instrument(skip(state))]
399pub fn client_derive_session_key(mut state: ClientHandshakeState) -> Result<[u8; 32]> {
400    // Extract required data
401    let client_secret = state.secret.take().ok_or_else(|| {
402        ProtocolError::HandshakeError(constants::ERR_CLIENT_SECRET_NOT_FOUND.into())
403    })?;
404    let server_public_bytes = state.server_public.ok_or_else(|| {
405        ProtocolError::HandshakeError(constants::ERR_SERVER_PUBLIC_NOT_FOUND.into())
406    })?;
407    let client_nonce = state.client_nonce.ok_or_else(|| {
408        ProtocolError::HandshakeError(constants::ERR_CLIENT_NONCE_NOT_FOUND.into())
409    })?;
410    let server_nonce = state.server_nonce.ok_or_else(|| {
411        ProtocolError::HandshakeError(constants::ERR_SERVER_NONCE_NOT_FOUND.into())
412    })?;
413
414    // Perform ECDH to derive shared secret
415    let server_public = PublicKey::from(server_public_bytes);
416    let shared_secret = client_secret.diffie_hellman(&server_public);
417
418    // Derive session key
419    let key = derive_key_from_shared_secret(&shared_secret, &client_nonce, &server_nonce);
420
421    // State will be zeroized on drop
422    debug!("Client derived session key");
423
424    Ok(key)
425}
426
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;

    /// Split a `SecureHandshakeInit` into `(pub_key, timestamp, nonce)`.
    fn unpack_init(msg: Message) -> ([u8; 32], u64, [u8; 16]) {
        match msg {
            Message::SecureHandshakeInit {
                pub_key,
                timestamp,
                nonce,
            } => (pub_key, timestamp, nonce),
            _ => panic!("Wrong message type"),
        }
    }

    /// Split a `SecureHandshakeResponse` into `(pub_key, nonce, nonce_verification)`.
    fn unpack_response(msg: Message) -> ([u8; 32], [u8; 16], [u8; 32]) {
        match msg {
            Message::SecureHandshakeResponse {
                pub_key,
                nonce,
                nonce_verification,
            } => (pub_key, nonce, nonce_verification),
            _ => panic!("Wrong message type"),
        }
    }

    /// Extract the `nonce_verification` from a `SecureHandshakeConfirm`.
    fn unpack_confirm(msg: Message) -> [u8; 32] {
        match msg {
            Message::SecureHandshakeConfirm { nonce_verification } => nonce_verification,
            _ => panic!("Wrong message type"),
        }
    }

    #[test]
    fn test_per_session_state_isolation() {
        let mut replay_cache = crate::utils::replay_cache::ReplayCache::new();
        let peer_id = "test-peer";

        // Two handshakes interleaved step by step must not interfere with each other.
        let (client_a, init_a) = client_secure_handshake_init().unwrap();
        let (client_b, init_b) = client_secure_handshake_init().unwrap();

        let (client_pub_a, ts_a, client_nonce_a) = unpack_init(init_a);
        let (client_pub_b, ts_b, client_nonce_b) = unpack_init(init_b);
        assert_ne!(client_pub_a, client_pub_b);
        assert_ne!(client_nonce_a, client_nonce_b);

        // Server side: each response gets its own state.
        let (server_a, resp_a) = server_secure_handshake_response(
            client_pub_a,
            client_nonce_a,
            ts_a,
            peer_id,
            &mut replay_cache,
        )
        .unwrap();
        let (server_b, resp_b) = server_secure_handshake_response(
            client_pub_b,
            client_nonce_b,
            ts_b,
            peer_id,
            &mut replay_cache,
        )
        .unwrap();

        let (server_pub_a, server_nonce_a, echo_a) = unpack_response(resp_a);
        let (server_pub_b, server_nonce_b, echo_b) = unpack_response(resp_b);
        assert_ne!(server_pub_a, server_pub_b);
        assert_ne!(server_nonce_a, server_nonce_b);

        // Client side: verify each response against its own state.
        let (client_a, confirm_a) = client_secure_handshake_verify(
            client_a,
            server_pub_a,
            server_nonce_a,
            echo_a,
            peer_id,
            &mut replay_cache,
        )
        .unwrap();
        let (client_b, confirm_b) = client_secure_handshake_verify(
            client_b,
            server_pub_b,
            server_nonce_b,
            echo_b,
            peer_id,
            &mut replay_cache,
        )
        .unwrap();

        let confirm_a = unpack_confirm(confirm_a);
        let confirm_b = unpack_confirm(confirm_b);
        assert_ne!(confirm_a, confirm_b);

        // Both ends of each session must agree; the two sessions must not.
        let server_key_a = server_secure_handshake_finalize(server_a, confirm_a).unwrap();
        let client_key_a = client_derive_session_key(client_a).unwrap();

        let server_key_b = server_secure_handshake_finalize(server_b, confirm_b).unwrap();
        let client_key_b = client_derive_session_key(client_b).unwrap();

        assert_eq!(server_key_a, client_key_a);
        assert_eq!(server_key_b, client_key_b);
        assert_ne!(server_key_a, server_key_b);
    }

    #[test]
    fn test_timestamp_validation() {
        let now = current_timestamp().unwrap();

        // Fresh enough: now, 10 s ago, and 1 s ahead (within the skew tolerance).
        let accepted = [now, now - 10_000, now + 1_000];
        // Rejected: 31 s ago (too old) and 3 s ahead (beyond the skew tolerance).
        let rejected = [now - 31_000, now + 3_000];

        for ts in accepted {
            assert!(verify_timestamp(ts, 30), "timestamp {ts} should be accepted");
        }
        for ts in rejected {
            assert!(!verify_timestamp(ts, 30), "timestamp {ts} should be rejected");
        }
    }

    #[test]
    fn test_nonce_verification() {
        let nonce = generate_nonce();
        let digest = hash_nonce(&nonce);
        assert_eq!(digest.len(), 32);

        // Hashing is deterministic...
        assert_eq!(hash_nonce(&nonce), digest);

        // ...and sensitive to a single flipped byte.
        let mut tweaked = nonce;
        tweaked[0] ^= 0xFF;
        assert_ne!(hash_nonce(&tweaked), digest);
    }
}