nomad_protocol/crypto/
noise.rs

1//! Noise_IK handshake implementation
2//!
3//! Per 1-SECURITY.md, NOMAD uses the Noise_IK pattern for 1-RTT mutual authentication.
4//!
5//! ```text
6//! Noise_IK(s, rs):
7//!   <- s                    # Responder's static key known to Initiator
8//!   ...
9//!   -> e, es, s, ss         # Initiator sends ephemeral + encrypted static
10//!   <- e, ee, se            # Responder sends ephemeral, completes DH
11//! ```
12//!
13//! After handshake, both parties derive session keys using HKDF.
14
15use crate::core::{CryptoError, HASH_SIZE, PUBLIC_KEY_SIZE};
16use snow::{Builder, HandshakeState};
17use zeroize::Zeroize;
18
19use super::{SessionKey, StaticKeypair, SESSION_KEY_SIZE};
20
/// Full Noise protocol name handed to `snow`: the IK handshake pattern with
/// Curve25519 DH, the ChaCha20-Poly1305 AEAD and the BLAKE2s hash function.
const NOISE_PATTERN: &str = "Noise_IK_25519_ChaChaPoly_BLAKE2s";
23
24/// Result of a completed handshake
25pub struct HandshakeResult {
26    /// The handshake hash (used for key derivation)
27    pub handshake_hash: [u8; HASH_SIZE],
28}
29
30/// Handshake state machine for the initiator (client).
31pub struct InitiatorHandshake {
32    state: HandshakeState,
33}
34
35impl InitiatorHandshake {
36    /// Create a new initiator handshake.
37    ///
38    /// # Arguments
39    /// * `local_keypair` - The initiator's static keypair
40    /// * `remote_public` - The responder's known static public key
41    pub fn new(
42        local_keypair: &StaticKeypair,
43        remote_public: &[u8; PUBLIC_KEY_SIZE],
44    ) -> Result<Self, CryptoError> {
45        let builder = Builder::new(NOISE_PATTERN.parse().unwrap());
46        let state = builder
47            .local_private_key(local_keypair.private_key())
48            .remote_public_key(remote_public)
49            .build_initiator()
50            .map_err(|e| CryptoError::HandshakeFailed(e.to_string()))?;
51
52        Ok(Self { state })
53    }
54
55    /// Generate the first handshake message (-> e, es, s, ss).
56    ///
57    /// # Arguments
58    /// * `payload` - Optional payload to include (state type ID, extensions)
59    ///
60    /// # Returns
61    /// The handshake message bytes to send
62    pub fn write_message(&mut self, payload: &[u8]) -> Result<Vec<u8>, CryptoError> {
63        let mut buf = vec![0u8; 65535];
64        let len = self
65            .state
66            .write_message(payload, &mut buf)
67            .map_err(|e| CryptoError::HandshakeFailed(e.to_string()))?;
68        buf.truncate(len);
69        Ok(buf)
70    }
71
72    /// Process the handshake response (<- e, ee, se).
73    ///
74    /// # Arguments
75    /// * `message` - The handshake response from the responder
76    ///
77    /// # Returns
78    /// The payload from the responder and the handshake result
79    pub fn read_message(mut self, message: &[u8]) -> Result<(Vec<u8>, HandshakeResult), CryptoError> {
80        let mut payload = vec![0u8; 65535];
81        let len = self
82            .state
83            .read_message(message, &mut payload)
84            .map_err(|e| CryptoError::HandshakeFailed(e.to_string()))?;
85        payload.truncate(len);
86
87        // Get the handshake hash BEFORE transitioning to transport mode
88        let hash_slice = self.state.get_handshake_hash();
89        let mut handshake_hash = [0u8; HASH_SIZE];
90        handshake_hash.copy_from_slice(hash_slice);
91
92        // Verify handshake is complete
93        let _transport = self
94            .state
95            .into_transport_mode()
96            .map_err(|e| CryptoError::HandshakeFailed(e.to_string()))?;
97
98        Ok((payload, HandshakeResult { handshake_hash }))
99    }
100}
101
102/// Handshake state machine for the responder (server).
103pub struct ResponderHandshake {
104    state: HandshakeState,
105}
106
107impl ResponderHandshake {
108    /// Create a new responder handshake.
109    ///
110    /// # Arguments
111    /// * `local_keypair` - The responder's static keypair
112    pub fn new(local_keypair: &StaticKeypair) -> Result<Self, CryptoError> {
113        let builder = Builder::new(NOISE_PATTERN.parse().unwrap());
114        let state = builder
115            .local_private_key(local_keypair.private_key())
116            .build_responder()
117            .map_err(|e| CryptoError::HandshakeFailed(e.to_string()))?;
118
119        Ok(Self { state })
120    }
121
122    /// Process the initiator's handshake message (-> e, es, s, ss).
123    ///
124    /// # Arguments
125    /// * `message` - The handshake initiation from the initiator
126    ///
127    /// # Returns
128    /// The payload from the initiator and the remote static public key
129    pub fn read_message(&mut self, message: &[u8]) -> Result<(Vec<u8>, [u8; PUBLIC_KEY_SIZE]), CryptoError> {
130        let mut payload = vec![0u8; 65535];
131        let len = self
132            .state
133            .read_message(message, &mut payload)
134            .map_err(|e| CryptoError::HandshakeFailed(e.to_string()))?;
135        payload.truncate(len);
136
137        // Get the remote static public key
138        let remote_static = self
139            .state
140            .get_remote_static()
141            .ok_or_else(|| CryptoError::HandshakeFailed("no remote static key".into()))?;
142
143        let mut remote_public = [0u8; PUBLIC_KEY_SIZE];
144        remote_public.copy_from_slice(remote_static);
145
146        Ok((payload, remote_public))
147    }
148
149    /// Generate the handshake response (<- e, ee, se).
150    ///
151    /// # Arguments
152    /// * `payload` - Optional payload to include (ack, negotiated extensions)
153    ///
154    /// # Returns
155    /// The handshake response bytes and the handshake result
156    pub fn write_message(mut self, payload: &[u8]) -> Result<(Vec<u8>, HandshakeResult), CryptoError> {
157        let mut buf = vec![0u8; 65535];
158        let len = self
159            .state
160            .write_message(payload, &mut buf)
161            .map_err(|e| CryptoError::HandshakeFailed(e.to_string()))?;
162        buf.truncate(len);
163
164        // Get the handshake hash BEFORE transitioning to transport mode
165        let hash_slice = self.state.get_handshake_hash();
166        let mut handshake_hash = [0u8; HASH_SIZE];
167        handshake_hash.copy_from_slice(hash_slice);
168
169        // Verify handshake is complete
170        let _transport = self
171            .state
172            .into_transport_mode()
173            .map_err(|e| CryptoError::HandshakeFailed(e.to_string()))?;
174
175        Ok((buf, HandshakeResult { handshake_hash }))
176    }
177}
178
179/// Session keys derived from the Noise handshake.
180///
181/// Per 1-SECURITY.md:
182/// ```text
183/// (initiator_key, responder_key) = HKDF-Expand(
184///     handshake_hash,
185///     "nomad v1 session keys",
186///     64
187/// )
188/// ```
189///
190/// Additionally, for PCS (Post-Compromise Security), we derive:
191/// ```text
192/// rekey_auth_key = HKDF-Expand(
193///     static_dh_secret,   // DH(s_initiator, S_responder)
194///     "nomad v1 rekey auth",
195///     32
196/// )
197/// ```
198pub struct SessionKeys {
199    /// Key for initiator → responder messages
200    pub initiator_key: SessionKey,
201    /// Key for responder → initiator messages
202    pub responder_key: SessionKey,
203    /// The handshake hash (stored for rekeying)
204    pub handshake_hash: [u8; HASH_SIZE],
205    /// Rekey authentication key for PCS (derived from static DH)
206    pub rekey_auth_key: [u8; HASH_SIZE],
207}
208
209impl SessionKeys {
210    /// Derive session keys from the handshake result and static DH secret.
211    ///
212    /// Uses SHA-256 HKDF-Expand with the handshake hash as PRK for session keys,
213    /// and the static DH secret for the rekey authentication key (PCS).
214    ///
215    /// # Arguments
216    /// * `result` - The handshake result containing the handshake hash
217    /// * `static_dh_secret` - The DH(s_initiator, S_responder) shared secret
218    pub fn derive(result: &HandshakeResult, static_dh_secret: &[u8; 32]) -> Result<Self, CryptoError> {
219        use hkdf::Hkdf;
220        use sha2::Sha256;
221        use super::rekey::derive_rekey_auth_key;
222
223        let handshake_hash = &result.handshake_hash;
224
225        // HKDF-Expand using SHA-256
226        // PRK = handshake_hash (treated as already-extracted key)
227        // info = "nomad v1 session keys"
228        let label = b"nomad v1 session keys";
229
230        let hk = Hkdf::<Sha256>::from_prk(handshake_hash)
231            .map_err(|_| CryptoError::KeyDerivationFailed)?;
232        let mut key_material = [0u8; 64];
233        hk.expand(label, &mut key_material)
234            .map_err(|_| CryptoError::KeyDerivationFailed)?;
235
236        let mut initiator_key = [0u8; SESSION_KEY_SIZE];
237        let mut responder_key = [0u8; SESSION_KEY_SIZE];
238        initiator_key.copy_from_slice(&key_material[..32]);
239        responder_key.copy_from_slice(&key_material[32..]);
240
241        // Derive rekey authentication key from static DH for PCS
242        let rekey_auth_key = derive_rekey_auth_key(static_dh_secret);
243
244        // Zeroize the intermediate material
245        key_material.zeroize();
246
247        Ok(Self {
248            initiator_key: SessionKey::from_bytes(initiator_key),
249            responder_key: SessionKey::from_bytes(responder_key),
250            handshake_hash: *handshake_hash,
251            rekey_auth_key,
252        })
253    }
254}
255
/// Which side of the handshake this endpoint played.
///
/// The role selects which directional session key is used for sending and
/// which for receiving.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Role {
    /// Client side: the party that writes the first handshake message.
    Initiator,
    /// Server side: the party that answers the first handshake message.
    Responder,
}
264
265impl SessionKeys {
266    /// Get the send key for the given role.
267    pub fn send_key(&self, role: Role) -> &SessionKey {
268        match role {
269            Role::Initiator => &self.initiator_key,
270            Role::Responder => &self.responder_key,
271        }
272    }
273
274    /// Get the receive key for the given role.
275    pub fn recv_key(&self, role: Role) -> &SessionKey {
276        match role {
277            Role::Initiator => &self.responder_key,
278            Role::Responder => &self.initiator_key,
279        }
280    }
281}
282
#[cfg(test)]
mod tests {
    use super::*;

    /// Everything observable from one complete initiator/responder exchange.
    struct Transcript {
        /// Payload the responder decrypted from the first message.
        payload_at_responder: Vec<u8>,
        /// Initiator static public key as learned by the responder.
        initiator_static_at_responder: [u8; PUBLIC_KEY_SIZE],
        /// Payload the initiator decrypted from the second message.
        payload_at_initiator: Vec<u8>,
        /// Handshake result on the initiator side.
        initiator_result: HandshakeResult,
        /// Handshake result on the responder side.
        responder_result: HandshakeResult,
    }

    /// Runs both IK handshake messages between two fresh state machines,
    /// panicking if any step fails.
    fn complete_handshake(
        initiator_keypair: &StaticKeypair,
        responder_keypair: &StaticKeypair,
        initiator_payload: &[u8],
        responder_payload: &[u8],
    ) -> Transcript {
        let mut initiator =
            InitiatorHandshake::new(initiator_keypair, responder_keypair.public_key()).unwrap();
        let mut responder = ResponderHandshake::new(responder_keypair).unwrap();

        // -> e, es, s, ss
        let first = initiator.write_message(initiator_payload).unwrap();
        let (payload_at_responder, initiator_static_at_responder) =
            responder.read_message(&first).unwrap();

        // <- e, ee, se
        let (second, responder_result) = responder.write_message(responder_payload).unwrap();
        let (payload_at_initiator, initiator_result) = initiator.read_message(&second).unwrap();

        Transcript {
            payload_at_responder,
            initiator_static_at_responder,
            payload_at_initiator,
            initiator_result,
            responder_result,
        }
    }

    #[test]
    fn test_handshake_roundtrip() {
        let initiator_keypair = StaticKeypair::generate();
        let responder_keypair = StaticKeypair::generate();
        let init_payload = b"nomad.echo.v1";
        let resp_payload = b"OK";

        let t = complete_handshake(
            &initiator_keypair,
            &responder_keypair,
            init_payload,
            resp_payload,
        );

        // Payloads and identities made it across intact.
        assert_eq!(t.payload_at_responder, init_payload);
        assert_eq!(&t.initiator_static_at_responder, initiator_keypair.public_key());
        assert_eq!(t.payload_at_initiator, resp_payload);

        // Both sides agree on the transcript hash.
        assert_eq!(t.initiator_result.handshake_hash, t.responder_result.handshake_hash);

        // Static DH is symmetric.
        let dh_at_initiator = initiator_keypair.compute_static_dh(responder_keypair.public_key());
        let dh_at_responder = responder_keypair.compute_static_dh(initiator_keypair.public_key());
        assert_eq!(dh_at_initiator, dh_at_responder);

        let initiator_keys = SessionKeys::derive(&t.initiator_result, &dh_at_initiator).unwrap();
        let responder_keys = SessionKeys::derive(&t.responder_result, &dh_at_responder).unwrap();

        // Each direction's key matches on both ends.
        assert_eq!(
            initiator_keys.send_key(Role::Initiator).as_bytes(),
            responder_keys.recv_key(Role::Responder).as_bytes()
        );
        assert_eq!(
            initiator_keys.recv_key(Role::Initiator).as_bytes(),
            responder_keys.send_key(Role::Responder).as_bytes()
        );

        assert_eq!(initiator_keys.rekey_auth_key, responder_keys.rekey_auth_key);
    }

    #[test]
    fn test_handshake_wrong_key_fails() {
        let initiator_keypair = StaticKeypair::generate();
        let responder_keypair = StaticKeypair::generate();
        let impostor_keypair = StaticKeypair::generate();

        // The initiator pins a key that does not belong to the responder.
        let mut initiator =
            InitiatorHandshake::new(&initiator_keypair, impostor_keypair.public_key()).unwrap();
        let mut responder = ResponderHandshake::new(&responder_keypair).unwrap();

        let first = initiator.write_message(b"test").unwrap();

        // The responder cannot decrypt the initiator's static key.
        assert!(responder.read_message(&first).is_err());
    }

    #[test]
    fn test_role_keys() {
        let initiator_keypair = StaticKeypair::generate();
        let responder_keypair = StaticKeypair::generate();

        let t = complete_handshake(&initiator_keypair, &responder_keypair, b"", b"");

        let static_dh = initiator_keypair.compute_static_dh(responder_keypair.public_key());
        let initiator_keys = SessionKeys::derive(&t.initiator_result, &static_dh).unwrap();
        let responder_keys = SessionKeys::derive(&t.responder_result, &static_dh).unwrap();

        // Accessors select the expected field for each role.
        assert_eq!(
            initiator_keys.send_key(Role::Initiator).as_bytes(),
            initiator_keys.initiator_key.as_bytes()
        );
        assert_eq!(
            initiator_keys.recv_key(Role::Initiator).as_bytes(),
            initiator_keys.responder_key.as_bytes()
        );
        assert_eq!(
            responder_keys.send_key(Role::Responder).as_bytes(),
            responder_keys.responder_key.as_bytes()
        );
        assert_eq!(
            responder_keys.recv_key(Role::Responder).as_bytes(),
            responder_keys.initiator_key.as_bytes()
        );
    }
}