nomad_protocol/crypto/
noise.rs1use crate::core::{CryptoError, HASH_SIZE, PUBLIC_KEY_SIZE};
16use snow::{Builder, HandshakeState};
17use zeroize::Zeroize;
18
19use super::{SessionKey, StaticKeypair, SESSION_KEY_SIZE};
20
/// Noise protocol name shared by both handshake roles in this module.
///
/// Per the Noise naming scheme this selects the `IK` pattern, X25519
/// Diffie-Hellman, the ChaCha20-Poly1305 AEAD and the BLAKE2s hash. In `IK`
/// the initiator must already know the responder's static public key, and
/// the responder learns the initiator's static key from the first message.
const NOISE_PATTERN: &str = "Noise_IK_25519_ChaChaPoly_BLAKE2s";
23
24pub struct HandshakeResult {
26 pub handshake_hash: [u8; HASH_SIZE],
28}
29
30pub struct InitiatorHandshake {
32 state: HandshakeState,
33}
34
35impl InitiatorHandshake {
36 pub fn new(
42 local_keypair: &StaticKeypair,
43 remote_public: &[u8; PUBLIC_KEY_SIZE],
44 ) -> Result<Self, CryptoError> {
45 let builder = Builder::new(NOISE_PATTERN.parse().unwrap());
46 let state = builder
47 .local_private_key(local_keypair.private_key())
48 .remote_public_key(remote_public)
49 .build_initiator()
50 .map_err(|e| CryptoError::HandshakeFailed(e.to_string()))?;
51
52 Ok(Self { state })
53 }
54
55 pub fn write_message(&mut self, payload: &[u8]) -> Result<Vec<u8>, CryptoError> {
63 let mut buf = vec![0u8; 65535];
64 let len = self
65 .state
66 .write_message(payload, &mut buf)
67 .map_err(|e| CryptoError::HandshakeFailed(e.to_string()))?;
68 buf.truncate(len);
69 Ok(buf)
70 }
71
72 pub fn read_message(mut self, message: &[u8]) -> Result<(Vec<u8>, HandshakeResult), CryptoError> {
80 let mut payload = vec![0u8; 65535];
81 let len = self
82 .state
83 .read_message(message, &mut payload)
84 .map_err(|e| CryptoError::HandshakeFailed(e.to_string()))?;
85 payload.truncate(len);
86
87 let hash_slice = self.state.get_handshake_hash();
89 let mut handshake_hash = [0u8; HASH_SIZE];
90 handshake_hash.copy_from_slice(hash_slice);
91
92 let _transport = self
94 .state
95 .into_transport_mode()
96 .map_err(|e| CryptoError::HandshakeFailed(e.to_string()))?;
97
98 Ok((payload, HandshakeResult { handshake_hash }))
99 }
100}
101
102pub struct ResponderHandshake {
104 state: HandshakeState,
105}
106
107impl ResponderHandshake {
108 pub fn new(local_keypair: &StaticKeypair) -> Result<Self, CryptoError> {
113 let builder = Builder::new(NOISE_PATTERN.parse().unwrap());
114 let state = builder
115 .local_private_key(local_keypair.private_key())
116 .build_responder()
117 .map_err(|e| CryptoError::HandshakeFailed(e.to_string()))?;
118
119 Ok(Self { state })
120 }
121
122 pub fn read_message(&mut self, message: &[u8]) -> Result<(Vec<u8>, [u8; PUBLIC_KEY_SIZE]), CryptoError> {
130 let mut payload = vec![0u8; 65535];
131 let len = self
132 .state
133 .read_message(message, &mut payload)
134 .map_err(|e| CryptoError::HandshakeFailed(e.to_string()))?;
135 payload.truncate(len);
136
137 let remote_static = self
139 .state
140 .get_remote_static()
141 .ok_or_else(|| CryptoError::HandshakeFailed("no remote static key".into()))?;
142
143 let mut remote_public = [0u8; PUBLIC_KEY_SIZE];
144 remote_public.copy_from_slice(remote_static);
145
146 Ok((payload, remote_public))
147 }
148
149 pub fn write_message(mut self, payload: &[u8]) -> Result<(Vec<u8>, HandshakeResult), CryptoError> {
157 let mut buf = vec![0u8; 65535];
158 let len = self
159 .state
160 .write_message(payload, &mut buf)
161 .map_err(|e| CryptoError::HandshakeFailed(e.to_string()))?;
162 buf.truncate(len);
163
164 let hash_slice = self.state.get_handshake_hash();
166 let mut handshake_hash = [0u8; HASH_SIZE];
167 handshake_hash.copy_from_slice(hash_slice);
168
169 let _transport = self
171 .state
172 .into_transport_mode()
173 .map_err(|e| CryptoError::HandshakeFailed(e.to_string()))?;
174
175 Ok((buf, HandshakeResult { handshake_hash }))
176 }
177}
178
179pub struct SessionKeys {
199 pub initiator_key: SessionKey,
201 pub responder_key: SessionKey,
203 pub handshake_hash: [u8; HASH_SIZE],
205 pub rekey_auth_key: [u8; HASH_SIZE],
207}
208
209impl SessionKeys {
210 pub fn derive(result: &HandshakeResult, static_dh_secret: &[u8; 32]) -> Result<Self, CryptoError> {
219 use hkdf::Hkdf;
220 use sha2::Sha256;
221 use super::rekey::derive_rekey_auth_key;
222
223 let handshake_hash = &result.handshake_hash;
224
225 let label = b"nomad v1 session keys";
229
230 let hk = Hkdf::<Sha256>::from_prk(handshake_hash)
231 .map_err(|_| CryptoError::KeyDerivationFailed)?;
232 let mut key_material = [0u8; 64];
233 hk.expand(label, &mut key_material)
234 .map_err(|_| CryptoError::KeyDerivationFailed)?;
235
236 let mut initiator_key = [0u8; SESSION_KEY_SIZE];
237 let mut responder_key = [0u8; SESSION_KEY_SIZE];
238 initiator_key.copy_from_slice(&key_material[..32]);
239 responder_key.copy_from_slice(&key_material[32..]);
240
241 let rekey_auth_key = derive_rekey_auth_key(static_dh_secret);
243
244 key_material.zeroize();
246
247 Ok(Self {
248 initiator_key: SessionKey::from_bytes(initiator_key),
249 responder_key: SessionKey::from_bytes(responder_key),
250 handshake_hash: *handshake_hash,
251 rekey_auth_key,
252 })
253 }
254}
255
/// Which side of the handshake the local endpoint played.
///
/// Passed to [`SessionKeys::send_key`] and [`SessionKeys::recv_key`] to pick
/// the directional key.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Role {
    /// The side that started the handshake (an [`InitiatorHandshake`]).
    Initiator,
    /// The side that answered the handshake (a [`ResponderHandshake`]).
    Responder,
}
264
265impl SessionKeys {
266 pub fn send_key(&self, role: Role) -> &SessionKey {
268 match role {
269 Role::Initiator => &self.initiator_key,
270 Role::Responder => &self.responder_key,
271 }
272 }
273
274 pub fn recv_key(&self, role: Role) -> &SessionKey {
276 match role {
277 Role::Initiator => &self.responder_key,
278 Role::Responder => &self.initiator_key,
279 }
280 }
281}
282
#[cfg(test)]
mod tests {
    use super::*;

    /// Full IK handshake: payloads round-trip, both sides agree on the
    /// handshake hash, and the derived keys pair up across the two peers.
    #[test]
    fn test_handshake_roundtrip() {
        let initiator_keypair = StaticKeypair::generate();
        let responder_keypair = StaticKeypair::generate();

        let mut initiator =
            InitiatorHandshake::new(&initiator_keypair, responder_keypair.public_key()).unwrap();
        let mut responder = ResponderHandshake::new(&responder_keypair).unwrap();

        // Message 1: initiator -> responder; reveals the initiator's static key.
        let init_payload = b"nomad.echo.v1";
        let init_message = initiator.write_message(init_payload).unwrap();

        let (recv_payload, remote_public) = responder.read_message(&init_message).unwrap();
        assert_eq!(recv_payload, init_payload);
        assert_eq!(&remote_public, initiator_keypair.public_key());

        // Message 2: responder -> initiator; completes the handshake.
        let resp_payload = b"OK";
        let (resp_message, responder_result) = responder.write_message(resp_payload).unwrap();

        let (recv_resp_payload, initiator_result) =
            initiator.read_message(&resp_message).unwrap();
        assert_eq!(recv_resp_payload, resp_payload);

        assert_eq!(initiator_result.handshake_hash, responder_result.handshake_hash);

        let initiator_static_dh =
            initiator_keypair.compute_static_dh(responder_keypair.public_key());
        let responder_static_dh =
            responder_keypair.compute_static_dh(initiator_keypair.public_key());
        assert_eq!(initiator_static_dh, responder_static_dh);

        let initiator_keys =
            SessionKeys::derive(&initiator_result, &initiator_static_dh).unwrap();
        let responder_keys =
            SessionKeys::derive(&responder_result, &responder_static_dh).unwrap();

        // What one side sends with, the other must receive with.
        assert_eq!(
            initiator_keys.send_key(Role::Initiator).as_bytes(),
            responder_keys.recv_key(Role::Responder).as_bytes()
        );
        assert_eq!(
            initiator_keys.recv_key(Role::Initiator).as_bytes(),
            responder_keys.send_key(Role::Responder).as_bytes()
        );

        // The two traffic directions must never share a key.
        assert_ne!(
            initiator_keys.send_key(Role::Initiator).as_bytes(),
            initiator_keys.recv_key(Role::Initiator).as_bytes()
        );

        assert_eq!(initiator_keys.handshake_hash, responder_keys.handshake_hash);
        assert_eq!(initiator_keys.rekey_auth_key, responder_keys.rekey_auth_key);
    }

    /// An initiator that targets the wrong responder key must be rejected.
    #[test]
    fn test_handshake_wrong_key_fails() {
        let initiator_keypair = StaticKeypair::generate();
        let responder_keypair = StaticKeypair::generate();
        let wrong_keypair = StaticKeypair::generate();

        let mut initiator =
            InitiatorHandshake::new(&initiator_keypair, wrong_keypair.public_key()).unwrap();
        let mut responder = ResponderHandshake::new(&responder_keypair).unwrap();

        let init_message = initiator.write_message(b"test").unwrap();

        let result = responder.read_message(&init_message);
        assert!(result.is_err());
    }

    /// A modified responder reply must fail authentication on the initiator.
    #[test]
    fn test_tampered_response_fails() {
        let initiator_keypair = StaticKeypair::generate();
        let responder_keypair = StaticKeypair::generate();

        let mut initiator =
            InitiatorHandshake::new(&initiator_keypair, responder_keypair.public_key()).unwrap();
        let mut responder = ResponderHandshake::new(&responder_keypair).unwrap();

        let init_message = initiator.write_message(b"").unwrap();
        responder.read_message(&init_message).unwrap();
        let (mut resp_message, _) = responder.write_message(b"").unwrap();

        // Flip a bit in the last byte (part of the encrypted payload's
        // authentication tag); decryption must then fail.
        let last = resp_message.len() - 1;
        resp_message[last] ^= 0x01;
        assert!(initiator.read_message(&resp_message).is_err());
    }

    /// `send_key` / `recv_key` map each role to the expected stored key.
    #[test]
    fn test_role_keys() {
        let initiator_keypair = StaticKeypair::generate();
        let responder_keypair = StaticKeypair::generate();

        let mut initiator =
            InitiatorHandshake::new(&initiator_keypair, responder_keypair.public_key()).unwrap();
        let mut responder = ResponderHandshake::new(&responder_keypair).unwrap();

        let init_message = initiator.write_message(b"").unwrap();
        responder.read_message(&init_message).unwrap();
        let (resp_message, responder_result) = responder.write_message(b"").unwrap();
        let (_, initiator_result) = initiator.read_message(&resp_message).unwrap();

        let static_dh = initiator_keypair.compute_static_dh(responder_keypair.public_key());

        let initiator_keys = SessionKeys::derive(&initiator_result, &static_dh).unwrap();
        let responder_keys = SessionKeys::derive(&responder_result, &static_dh).unwrap();

        assert_eq!(
            initiator_keys.send_key(Role::Initiator).as_bytes(),
            initiator_keys.initiator_key.as_bytes()
        );
        assert_eq!(
            initiator_keys.recv_key(Role::Initiator).as_bytes(),
            initiator_keys.responder_key.as_bytes()
        );
        assert_eq!(
            responder_keys.send_key(Role::Responder).as_bytes(),
            responder_keys.responder_key.as_bytes()
        );
        assert_eq!(
            responder_keys.recv_key(Role::Responder).as_bytes(),
            responder_keys.initiator_key.as_bytes()
        );
    }
}