Skip to main content

fips_core/noise/
handshake.rs

1use super::{
2    CipherState, EPOCH_ENCRYPTED_SIZE, EPOCH_SIZE, HANDSHAKE_MSG1_SIZE, HANDSHAKE_MSG2_SIZE,
3    HandshakeProgress, HandshakeRole, NoiseError, NoisePattern, NoiseSession, PROTOCOL_NAME_IK,
4    PROTOCOL_NAME_XK, PUBKEY_SIZE, XK_HANDSHAKE_MSG1_SIZE, XK_HANDSHAKE_MSG2_SIZE,
5    XK_HANDSHAKE_MSG3_SIZE,
6};
7use hkdf::Hkdf;
8use rand::Rng;
9use secp256k1::{Keypair, PublicKey, Secp256k1, SecretKey, ecdh::shared_secret_point};
10use sha2::{Digest, Sha256};
11use std::fmt;
12
13/// Symmetric state during handshake.
14///
15/// Maintains the chaining key (ck), handshake hash (h), and current cipher.
16struct SymmetricState {
17    /// Chaining key for key derivation.
18    ck: [u8; 32],
19    /// Handshake hash for transcript binding.
20    h: [u8; 32],
21    /// Current cipher state for encrypting handshake payloads.
22    cipher: CipherState,
23}
24
25impl SymmetricState {
26    /// Initialize with protocol name.
27    fn initialize(protocol_name: &[u8]) -> Self {
28        // If protocol name <= 32 bytes, pad with zeros
29        // If > 32 bytes, hash it
30        let h = if protocol_name.len() <= 32 {
31            let mut h = [0u8; 32];
32            h[..protocol_name.len()].copy_from_slice(protocol_name);
33            h
34        } else {
35            let mut hasher = Sha256::new();
36            hasher.update(protocol_name);
37            hasher.finalize().into()
38        };
39
40        Self {
41            ck: h,
42            h,
43            cipher: CipherState::empty(),
44        }
45    }
46
47    /// Mix data into the handshake hash.
48    fn mix_hash(&mut self, data: &[u8]) {
49        let mut hasher = Sha256::new();
50        hasher.update(self.h);
51        hasher.update(data);
52        self.h = hasher.finalize().into();
53    }
54
55    /// Mix key material into the chaining key.
56    fn mix_key(&mut self, input_key_material: &[u8]) {
57        let hk = Hkdf::<Sha256>::new(Some(&self.ck), input_key_material);
58        let mut output = [0u8; 64];
59        hk.expand(&[], &mut output)
60            .expect("64 bytes is valid output length");
61
62        self.ck.copy_from_slice(&output[..32]);
63
64        // Initialize cipher with derived key for handshake encryption
65        let mut key = [0u8; 32];
66        key.copy_from_slice(&output[32..64]);
67        self.cipher.initialize_key(key);
68    }
69
70    /// Encrypt and mix into hash.
71    fn encrypt_and_hash(&mut self, plaintext: &[u8]) -> Result<Vec<u8>, NoiseError> {
72        let ciphertext = self.cipher.encrypt(plaintext)?;
73        self.mix_hash(&ciphertext);
74        Ok(ciphertext)
75    }
76
77    /// Decrypt and mix ciphertext into hash.
78    fn decrypt_and_hash(&mut self, ciphertext: &[u8]) -> Result<Vec<u8>, NoiseError> {
79        let plaintext = self.cipher.decrypt(ciphertext)?;
80        self.mix_hash(ciphertext);
81        Ok(plaintext)
82    }
83
84    /// Split into two cipher states for transport.
85    fn split(&self) -> (CipherState, CipherState) {
86        let hk = Hkdf::<Sha256>::new(Some(&self.ck), &[]);
87        let mut output = [0u8; 64];
88        hk.expand(&[], &mut output)
89            .expect("64 bytes is valid output length");
90
91        let mut k1 = [0u8; 32];
92        let mut k2 = [0u8; 32];
93        k1.copy_from_slice(&output[..32]);
94        k2.copy_from_slice(&output[32..64]);
95
96        (CipherState::new(k1), CipherState::new(k2))
97    }
98
99    /// Get the handshake hash (for channel binding).
100    fn handshake_hash(&self) -> [u8; 32] {
101        self.h
102    }
103}
104
105/// Handshake state for Noise IK and XK patterns.
106pub struct HandshakeState {
107    /// Which Noise pattern is being used.
108    pattern: NoisePattern,
109    /// Our role in the handshake.
110    role: HandshakeRole,
111    /// Current progress.
112    progress: HandshakeProgress,
113    /// Symmetric state.
114    symmetric: SymmetricState,
115    /// Our static keypair.
116    static_keypair: Keypair,
117    /// Our ephemeral keypair (generated at handshake start).
118    ephemeral_keypair: Option<Keypair>,
119    /// Remote static public key.
120    /// For IK initiator: known before handshake (from config).
121    /// For IK responder: learned from message 1.
122    /// For XK initiator: known before handshake (from config).
123    /// For XK responder: learned from message 3.
124    remote_static: Option<PublicKey>,
125    /// Remote ephemeral public key (learned during handshake).
126    remote_ephemeral: Option<PublicKey>,
127    /// Secp256k1 context.
128    secp: Secp256k1<secp256k1::All>,
129    /// Our startup epoch for restart detection.
130    local_epoch: Option<[u8; 8]>,
131    /// Remote peer's startup epoch (learned during handshake).
132    remote_epoch: Option<[u8; 8]>,
133}
134
135impl HandshakeState {
136    /// Normalize a compressed public key to even parity for pre-message hashing.
137    ///
138    /// Nostr npubs encode x-only keys (no parity). The Noise IK pre-message
139    /// mixes the responder's static key into the hash before any messages.
140    /// Both sides must mix identical bytes. Since the initiator may only have
141    /// the x-only key (from an npub), we normalize to even parity (0x02 prefix)
142    /// so the hash chain matches regardless of the key's actual parity.
143    ///
144    /// This does NOT affect ECDH operations (which use x-coordinate-only output)
145    /// or the keys sent in handshake messages (which use actual parity).
146    fn normalize_for_premessage(pubkey: &PublicKey) -> [u8; PUBKEY_SIZE] {
147        let mut bytes = pubkey.serialize();
148        bytes[0] = 0x02; // Force even parity
149        bytes
150    }
151
152    /// Create a new IK handshake as initiator.
153    ///
154    /// The initiator knows the responder's static key and will send first.
155    /// Used by FMP (link layer).
156    pub fn new_initiator(static_keypair: Keypair, remote_static: PublicKey) -> Self {
157        let secp = Secp256k1::new();
158        let mut state = Self {
159            pattern: NoisePattern::Ik,
160            role: HandshakeRole::Initiator,
161            progress: HandshakeProgress::Initial,
162            symmetric: SymmetricState::initialize(PROTOCOL_NAME_IK),
163            static_keypair,
164            ephemeral_keypair: None,
165            remote_static: Some(remote_static),
166            remote_ephemeral: None,
167            secp,
168            local_epoch: None,
169            remote_epoch: None,
170        };
171
172        // Mix in pre-message: <- s (responder's static is known)
173        // Normalize to even parity so initiator and responder hash chains match
174        // even when the initiator only has the x-only key (from npub).
175        let normalized = Self::normalize_for_premessage(&remote_static);
176        state.symmetric.mix_hash(&normalized);
177
178        state
179    }
180
181    /// Create a new IK handshake as responder.
182    ///
183    /// The responder does NOT know the initiator's static key - it will be
184    /// learned from message 1. Used by FMP (link layer).
185    pub fn new_responder(static_keypair: Keypair) -> Self {
186        let secp = Secp256k1::new();
187        let mut state = Self {
188            pattern: NoisePattern::Ik,
189            role: HandshakeRole::Responder,
190            progress: HandshakeProgress::Initial,
191            symmetric: SymmetricState::initialize(PROTOCOL_NAME_IK),
192            static_keypair,
193            ephemeral_keypair: None,
194            remote_static: None, // Will learn from message 1
195            remote_ephemeral: None,
196            secp,
197            local_epoch: None,
198            remote_epoch: None,
199        };
200
201        // Mix in pre-message: <- s (our static, since we're responder)
202        // Normalize to even parity to match initiator's hash chain.
203        let normalized = Self::normalize_for_premessage(&state.static_keypair.public_key());
204        state.symmetric.mix_hash(&normalized);
205
206        state
207    }
208
209    /// Create a new XK handshake as initiator.
210    ///
211    /// The initiator knows the responder's static key. XK defers the
212    /// initiator's static key reveal to msg3. Used by FSP (session layer).
213    pub fn new_xk_initiator(static_keypair: Keypair, remote_static: PublicKey) -> Self {
214        let secp = Secp256k1::new();
215        let mut state = Self {
216            pattern: NoisePattern::Xk,
217            role: HandshakeRole::Initiator,
218            progress: HandshakeProgress::Initial,
219            symmetric: SymmetricState::initialize(PROTOCOL_NAME_XK),
220            static_keypair,
221            ephemeral_keypair: None,
222            remote_static: Some(remote_static),
223            remote_ephemeral: None,
224            secp,
225            local_epoch: None,
226            remote_epoch: None,
227        };
228
229        // Mix in pre-message: <- s (responder's static is known)
230        let normalized = Self::normalize_for_premessage(&remote_static);
231        state.symmetric.mix_hash(&normalized);
232
233        state
234    }
235
236    /// Create a new XK handshake as responder.
237    ///
238    /// The responder does NOT know the initiator's static key - it will be
239    /// learned from message 3. Used by FSP (session layer).
240    pub fn new_xk_responder(static_keypair: Keypair) -> Self {
241        let secp = Secp256k1::new();
242        let mut state = Self {
243            pattern: NoisePattern::Xk,
244            role: HandshakeRole::Responder,
245            progress: HandshakeProgress::Initial,
246            symmetric: SymmetricState::initialize(PROTOCOL_NAME_XK),
247            static_keypair,
248            ephemeral_keypair: None,
249            remote_static: None, // Will learn from message 3
250            remote_ephemeral: None,
251            secp,
252            local_epoch: None,
253            remote_epoch: None,
254        };
255
256        // Mix in pre-message: <- s (our static, since we're responder)
257        let normalized = Self::normalize_for_premessage(&state.static_keypair.public_key());
258        state.symmetric.mix_hash(&normalized);
259
260        state
261    }
262
263    /// Get our role.
264    pub fn role(&self) -> HandshakeRole {
265        self.role
266    }
267
268    /// Get current progress.
269    pub fn progress(&self) -> HandshakeProgress {
270        self.progress
271    }
272
273    /// Check if handshake is complete.
274    pub fn is_complete(&self) -> bool {
275        self.progress == HandshakeProgress::Complete
276    }
277
278    /// Get the remote static key (available after message 1 for responder).
279    pub fn remote_static(&self) -> Option<&PublicKey> {
280        self.remote_static.as_ref()
281    }
282
283    /// Set the local startup epoch for restart detection.
284    pub fn set_local_epoch(&mut self, epoch: [u8; 8]) {
285        self.local_epoch = Some(epoch);
286    }
287
288    /// Get the remote peer's startup epoch (available after processing their message).
289    pub fn remote_epoch(&self) -> Option<[u8; 8]> {
290        self.remote_epoch
291    }
292
293    /// Generate ephemeral keypair.
294    fn generate_ephemeral(&mut self) {
295        let mut rng = rand::rng();
296        let mut secret_bytes = [0u8; 32];
297        rng.fill_bytes(&mut secret_bytes);
298
299        let secret_key =
300            SecretKey::from_slice(&secret_bytes).expect("32 random bytes is valid secret key");
301        self.ephemeral_keypair = Some(Keypair::from_secret_key(&self.secp, &secret_key));
302    }
303
304    /// Perform ECDH between our secret and their public key.
305    ///
306    /// Uses x-only hashing (SHA-256 of just the x-coordinate) to produce
307    /// a parity-independent shared secret. This is necessary because Nostr
308    /// npubs encode x-only keys without parity information, so the initiator
309    /// may have the wrong parity for the responder's static key. Since P and
310    /// -P produce ECDH result points with the same x-coordinate, hashing
311    /// only x ensures both sides derive the same shared secret.
312    fn ecdh(&self, our_secret: &SecretKey, their_public: &PublicKey) -> [u8; 32] {
313        // Get raw (x, y) coordinates (64 bytes) without any hashing
314        let point = shared_secret_point(their_public, our_secret);
315        // Hash only the x-coordinate (first 32 bytes), ignoring y/parity
316        let mut hasher = Sha256::new();
317        hasher.update(&point[..32]);
318        let hash = hasher.finalize();
319        let mut result = [0u8; 32];
320        result.copy_from_slice(&hash);
321        result
322    }
323
324    /// Write message 1 (initiator only).
325    ///
326    /// Message 1 contains:
327    /// - e: ephemeral public key (33 bytes)
328    /// - encrypted s: our static public key encrypted (33 + 16 = 49 bytes)
329    /// - encrypted epoch: startup epoch for restart detection (8 + 16 = 24 bytes)
330    ///
331    /// Total: 106 bytes
332    pub fn write_message_1(&mut self) -> Result<Vec<u8>, NoiseError> {
333        if self.role != HandshakeRole::Initiator {
334            return Err(NoiseError::WrongState {
335                expected: "initiator".to_string(),
336                got: "responder".to_string(),
337            });
338        }
339        if self.progress != HandshakeProgress::Initial {
340            return Err(NoiseError::WrongState {
341                expected: HandshakeProgress::Initial.to_string(),
342                got: self.progress.to_string(),
343            });
344        }
345
346        let remote_static = self
347            .remote_static
348            .expect("initiator must have remote static");
349        let epoch = self
350            .local_epoch
351            .expect("local epoch must be set before write_message_1");
352
353        // Generate ephemeral keypair
354        self.generate_ephemeral();
355        let ephemeral = self.ephemeral_keypair.as_ref().unwrap();
356        let e_pub = ephemeral.public_key().serialize();
357
358        let mut message = Vec::with_capacity(HANDSHAKE_MSG1_SIZE);
359
360        // -> e: send ephemeral, mix into hash
361        message.extend_from_slice(&e_pub);
362        self.symmetric.mix_hash(&e_pub);
363
364        // -> es: DH(e, rs), mix into key
365        let es = self.ecdh(&ephemeral.secret_key(), &remote_static);
366        self.symmetric.mix_key(&es);
367
368        // -> s: encrypt our static and send
369        let our_static = self.static_keypair.public_key().serialize();
370        let encrypted_static = self.symmetric.encrypt_and_hash(&our_static)?;
371        message.extend_from_slice(&encrypted_static);
372
373        // -> ss: DH(s, rs), mix into key
374        let ss = self.ecdh(&self.static_keypair.secret_key(), &remote_static);
375        self.symmetric.mix_key(&ss);
376
377        // -> epoch: encrypt startup epoch for restart detection
378        let encrypted_epoch = self.symmetric.encrypt_and_hash(&epoch)?;
379        debug_assert_eq!(encrypted_epoch.len(), EPOCH_ENCRYPTED_SIZE);
380        message.extend_from_slice(&encrypted_epoch);
381
382        self.progress = HandshakeProgress::Message1Done;
383
384        Ok(message)
385    }
386
387    /// Read message 1 (responder only).
388    ///
389    /// Processes the initiator's first message and learns their identity and epoch.
390    pub fn read_message_1(&mut self, message: &[u8]) -> Result<(), NoiseError> {
391        if self.role != HandshakeRole::Responder {
392            return Err(NoiseError::WrongState {
393                expected: "responder".to_string(),
394                got: "initiator".to_string(),
395            });
396        }
397        if self.progress != HandshakeProgress::Initial {
398            return Err(NoiseError::WrongState {
399                expected: HandshakeProgress::Initial.to_string(),
400                got: self.progress.to_string(),
401            });
402        }
403        if message.len() != HANDSHAKE_MSG1_SIZE {
404            return Err(NoiseError::MessageTooShort {
405                expected: HANDSHAKE_MSG1_SIZE,
406                got: message.len(),
407            });
408        }
409
410        // -> e: parse remote ephemeral, mix into hash
411        let re = PublicKey::from_slice(&message[..PUBKEY_SIZE])
412            .map_err(|_| NoiseError::InvalidPublicKey)?;
413        self.remote_ephemeral = Some(re);
414        self.symmetric.mix_hash(&message[..PUBKEY_SIZE]);
415
416        // -> es: DH(s, re), mix into key
417        // (responder uses their static with initiator's ephemeral)
418        let es = self.ecdh(&self.static_keypair.secret_key(), &re);
419        self.symmetric.mix_key(&es);
420
421        // -> s: decrypt initiator's static
422        let encrypted_static_end = PUBKEY_SIZE + PUBKEY_SIZE + super::TAG_SIZE;
423        let encrypted_static = &message[PUBKEY_SIZE..encrypted_static_end];
424        let decrypted_static = self.symmetric.decrypt_and_hash(encrypted_static)?;
425        let rs =
426            PublicKey::from_slice(&decrypted_static).map_err(|_| NoiseError::InvalidPublicKey)?;
427        self.remote_static = Some(rs);
428
429        // -> ss: DH(s, rs), mix into key
430        let ss = self.ecdh(&self.static_keypair.secret_key(), &rs);
431        self.symmetric.mix_key(&ss);
432
433        // -> epoch: decrypt initiator's startup epoch
434        let encrypted_epoch = &message[encrypted_static_end..];
435        debug_assert_eq!(encrypted_epoch.len(), EPOCH_ENCRYPTED_SIZE);
436        let decrypted_epoch = self.symmetric.decrypt_and_hash(encrypted_epoch)?;
437        debug_assert_eq!(decrypted_epoch.len(), EPOCH_SIZE);
438        let mut epoch = [0u8; EPOCH_SIZE];
439        epoch.copy_from_slice(&decrypted_epoch);
440        self.remote_epoch = Some(epoch);
441
442        self.progress = HandshakeProgress::Message1Done;
443
444        Ok(())
445    }
446
447    /// Write message 2 (responder only).
448    ///
449    /// Message 2 contains:
450    /// - e: ephemeral public key (33 bytes)
451    /// - encrypted epoch: startup epoch for restart detection (8 + 16 = 24 bytes)
452    ///
453    /// Total: 57 bytes
454    pub fn write_message_2(&mut self) -> Result<Vec<u8>, NoiseError> {
455        if self.role != HandshakeRole::Responder {
456            return Err(NoiseError::WrongState {
457                expected: "responder".to_string(),
458                got: "initiator".to_string(),
459            });
460        }
461        if self.progress != HandshakeProgress::Message1Done {
462            return Err(NoiseError::WrongState {
463                expected: HandshakeProgress::Message1Done.to_string(),
464                got: self.progress.to_string(),
465            });
466        }
467
468        let re = self.remote_ephemeral.expect("should have remote ephemeral");
469        let epoch = self
470            .local_epoch
471            .expect("local epoch must be set before write_message_2");
472
473        // Generate ephemeral keypair
474        self.generate_ephemeral();
475        let ephemeral = self.ephemeral_keypair.as_ref().unwrap();
476        let e_pub = ephemeral.public_key().serialize();
477
478        let mut message = Vec::with_capacity(HANDSHAKE_MSG2_SIZE);
479
480        // <- e: send ephemeral, mix into hash
481        message.extend_from_slice(&e_pub);
482        self.symmetric.mix_hash(&e_pub);
483
484        // <- ee: DH(e, re), mix into key
485        let ee = self.ecdh(&ephemeral.secret_key(), &re);
486        self.symmetric.mix_key(&ee);
487
488        // <- se: DH(s, re), mix into key
489        let se = self.ecdh(&self.static_keypair.secret_key(), &re);
490        self.symmetric.mix_key(&se);
491
492        // <- epoch: encrypt startup epoch for restart detection
493        let encrypted_epoch = self.symmetric.encrypt_and_hash(&epoch)?;
494        debug_assert_eq!(encrypted_epoch.len(), EPOCH_ENCRYPTED_SIZE);
495        message.extend_from_slice(&encrypted_epoch);
496
497        self.progress = HandshakeProgress::Complete;
498
499        Ok(message)
500    }
501
502    /// Read message 2 (initiator only).
503    ///
504    /// Processes the responder's message and completes the handshake.
505    pub fn read_message_2(&mut self, message: &[u8]) -> Result<(), NoiseError> {
506        if self.role != HandshakeRole::Initiator {
507            return Err(NoiseError::WrongState {
508                expected: "initiator".to_string(),
509                got: "responder".to_string(),
510            });
511        }
512        if self.progress != HandshakeProgress::Message1Done {
513            return Err(NoiseError::WrongState {
514                expected: HandshakeProgress::Message1Done.to_string(),
515                got: self.progress.to_string(),
516            });
517        }
518        if message.len() != HANDSHAKE_MSG2_SIZE {
519            return Err(NoiseError::MessageTooShort {
520                expected: HANDSHAKE_MSG2_SIZE,
521                got: message.len(),
522            });
523        }
524
525        // <- e: parse remote ephemeral, mix into hash
526        let e_pub = &message[..PUBKEY_SIZE];
527        let re = PublicKey::from_slice(e_pub).map_err(|_| NoiseError::InvalidPublicKey)?;
528        self.remote_ephemeral = Some(re);
529        self.symmetric.mix_hash(e_pub);
530
531        // <- ee: DH(e, re), mix into key
532        let ephemeral = self.ephemeral_keypair.as_ref().unwrap();
533        let ee = self.ecdh(&ephemeral.secret_key(), &re);
534        self.symmetric.mix_key(&ee);
535
536        // <- se: DH(e, rs), mix into key
537        // (initiator uses their ephemeral with responder's static)
538        let rs = self.remote_static.expect("initiator has remote static");
539        let se = self.ecdh(&ephemeral.secret_key(), &rs);
540        self.symmetric.mix_key(&se);
541
542        // <- epoch: decrypt responder's startup epoch
543        let encrypted_epoch = &message[PUBKEY_SIZE..];
544        debug_assert_eq!(encrypted_epoch.len(), EPOCH_ENCRYPTED_SIZE);
545        let decrypted_epoch = self.symmetric.decrypt_and_hash(encrypted_epoch)?;
546        debug_assert_eq!(decrypted_epoch.len(), EPOCH_SIZE);
547        let mut epoch = [0u8; EPOCH_SIZE];
548        epoch.copy_from_slice(&decrypted_epoch);
549        self.remote_epoch = Some(epoch);
550
551        self.progress = HandshakeProgress::Complete;
552
553        Ok(())
554    }
555
556    // ========================================================================
557    // XK Pattern Methods (Session Layer)
558    // ========================================================================
559
560    /// Write XK message 1 (initiator only).
561    ///
562    /// XK msg1: `-> e, es`
563    /// - e: ephemeral public key (33 bytes)
564    /// - es: DH(e_priv, rs_pub), mix_key
565    ///
566    /// Total: 33 bytes (ephemeral only — no static, no epoch)
567    pub fn write_xk_message_1(&mut self) -> Result<Vec<u8>, NoiseError> {
568        if self.role != HandshakeRole::Initiator {
569            return Err(NoiseError::WrongState {
570                expected: "initiator".to_string(),
571                got: "responder".to_string(),
572            });
573        }
574        if self.progress != HandshakeProgress::Initial {
575            return Err(NoiseError::WrongState {
576                expected: HandshakeProgress::Initial.to_string(),
577                got: self.progress.to_string(),
578            });
579        }
580
581        let remote_static = self
582            .remote_static
583            .expect("initiator must have remote static");
584
585        // Generate ephemeral keypair
586        self.generate_ephemeral();
587        let ephemeral = self.ephemeral_keypair.as_ref().unwrap();
588        let e_pub = ephemeral.public_key().serialize();
589
590        let mut message = Vec::with_capacity(XK_HANDSHAKE_MSG1_SIZE);
591
592        // -> e: send ephemeral, mix into hash
593        message.extend_from_slice(&e_pub);
594        self.symmetric.mix_hash(&e_pub);
595
596        // -> es: DH(e, rs), mix into key
597        let es = self.ecdh(&ephemeral.secret_key(), &remote_static);
598        self.symmetric.mix_key(&es);
599
600        self.progress = HandshakeProgress::Message1Done;
601
602        Ok(message)
603    }
604
605    /// Read XK message 1 (responder only).
606    ///
607    /// Processes the initiator's first message. Does NOT learn initiator's
608    /// identity (that comes in msg3).
609    pub fn read_xk_message_1(&mut self, message: &[u8]) -> Result<(), NoiseError> {
610        if self.role != HandshakeRole::Responder {
611            return Err(NoiseError::WrongState {
612                expected: "responder".to_string(),
613                got: "initiator".to_string(),
614            });
615        }
616        if self.progress != HandshakeProgress::Initial {
617            return Err(NoiseError::WrongState {
618                expected: HandshakeProgress::Initial.to_string(),
619                got: self.progress.to_string(),
620            });
621        }
622        if message.len() != XK_HANDSHAKE_MSG1_SIZE {
623            return Err(NoiseError::MessageTooShort {
624                expected: XK_HANDSHAKE_MSG1_SIZE,
625                got: message.len(),
626            });
627        }
628
629        // -> e: parse remote ephemeral, mix into hash
630        let re = PublicKey::from_slice(&message[..PUBKEY_SIZE])
631            .map_err(|_| NoiseError::InvalidPublicKey)?;
632        self.remote_ephemeral = Some(re);
633        self.symmetric.mix_hash(&message[..PUBKEY_SIZE]);
634
635        // -> es: DH(s, re), mix into key
636        // (responder uses their static with initiator's ephemeral)
637        let es = self.ecdh(&self.static_keypair.secret_key(), &re);
638        self.symmetric.mix_key(&es);
639
640        self.progress = HandshakeProgress::Message1Done;
641
642        Ok(())
643    }
644
645    /// Write XK message 2 (responder only).
646    ///
647    /// XK msg2: `<- e, ee` + encrypted epoch
648    /// - e: ephemeral public key (33 bytes)
649    /// - ee: DH(e_priv, re_pub), mix_key
650    /// - encrypted epoch (24 bytes)
651    ///
652    /// Total: 57 bytes
653    pub fn write_xk_message_2(&mut self) -> Result<Vec<u8>, NoiseError> {
654        if self.role != HandshakeRole::Responder {
655            return Err(NoiseError::WrongState {
656                expected: "responder".to_string(),
657                got: "initiator".to_string(),
658            });
659        }
660        if self.progress != HandshakeProgress::Message1Done {
661            return Err(NoiseError::WrongState {
662                expected: HandshakeProgress::Message1Done.to_string(),
663                got: self.progress.to_string(),
664            });
665        }
666
667        let re = self.remote_ephemeral.expect("should have remote ephemeral");
668        let epoch = self
669            .local_epoch
670            .expect("local epoch must be set before write_xk_message_2");
671
672        // Generate ephemeral keypair
673        self.generate_ephemeral();
674        let ephemeral = self.ephemeral_keypair.as_ref().unwrap();
675        let e_pub = ephemeral.public_key().serialize();
676
677        let mut message = Vec::with_capacity(XK_HANDSHAKE_MSG2_SIZE);
678
679        // <- e: send ephemeral, mix into hash
680        message.extend_from_slice(&e_pub);
681        self.symmetric.mix_hash(&e_pub);
682
683        // <- ee: DH(e, re), mix into key
684        let ee = self.ecdh(&ephemeral.secret_key(), &re);
685        self.symmetric.mix_key(&ee);
686
687        // <- epoch: encrypt startup epoch for restart detection
688        let encrypted_epoch = self.symmetric.encrypt_and_hash(&epoch)?;
689        debug_assert_eq!(encrypted_epoch.len(), EPOCH_ENCRYPTED_SIZE);
690        message.extend_from_slice(&encrypted_epoch);
691
692        self.progress = HandshakeProgress::Message2Done;
693
694        Ok(message)
695    }
696
697    /// Read XK message 2 (initiator only).
698    ///
699    /// Processes the responder's message and extracts the responder's epoch.
700    /// Does NOT complete the handshake — msg3 still needed.
701    pub fn read_xk_message_2(&mut self, message: &[u8]) -> Result<(), NoiseError> {
702        if self.role != HandshakeRole::Initiator {
703            return Err(NoiseError::WrongState {
704                expected: "initiator".to_string(),
705                got: "responder".to_string(),
706            });
707        }
708        if self.progress != HandshakeProgress::Message1Done {
709            return Err(NoiseError::WrongState {
710                expected: HandshakeProgress::Message1Done.to_string(),
711                got: self.progress.to_string(),
712            });
713        }
714        if message.len() != XK_HANDSHAKE_MSG2_SIZE {
715            return Err(NoiseError::MessageTooShort {
716                expected: XK_HANDSHAKE_MSG2_SIZE,
717                got: message.len(),
718            });
719        }
720
721        // <- e: parse remote ephemeral, mix into hash
722        let e_pub = &message[..PUBKEY_SIZE];
723        let re = PublicKey::from_slice(e_pub).map_err(|_| NoiseError::InvalidPublicKey)?;
724        self.remote_ephemeral = Some(re);
725        self.symmetric.mix_hash(e_pub);
726
727        // <- ee: DH(e, re), mix into key
728        let ephemeral = self.ephemeral_keypair.as_ref().unwrap();
729        let ee = self.ecdh(&ephemeral.secret_key(), &re);
730        self.symmetric.mix_key(&ee);
731
732        // <- epoch: decrypt responder's startup epoch
733        let encrypted_epoch = &message[PUBKEY_SIZE..];
734        debug_assert_eq!(encrypted_epoch.len(), EPOCH_ENCRYPTED_SIZE);
735        let decrypted_epoch = self.symmetric.decrypt_and_hash(encrypted_epoch)?;
736        debug_assert_eq!(decrypted_epoch.len(), EPOCH_SIZE);
737        let mut epoch = [0u8; EPOCH_SIZE];
738        epoch.copy_from_slice(&decrypted_epoch);
739        self.remote_epoch = Some(epoch);
740
741        self.progress = HandshakeProgress::Message2Done;
742
743        Ok(())
744    }
745
746    /// Write XK message 3 (initiator only).
747    ///
748    /// XK msg3: `-> s, se` + encrypted epoch
749    /// - s: encrypt_and_hash(s_pub) — encrypted static (49 bytes)
750    /// - se: DH(s_priv, re_pub), mix_key
751    /// - encrypted epoch (24 bytes)
752    ///
753    /// Total: 73 bytes
754    pub fn write_xk_message_3(&mut self) -> Result<Vec<u8>, NoiseError> {
755        if self.role != HandshakeRole::Initiator {
756            return Err(NoiseError::WrongState {
757                expected: "initiator".to_string(),
758                got: "responder".to_string(),
759            });
760        }
761        if self.progress != HandshakeProgress::Message2Done {
762            return Err(NoiseError::WrongState {
763                expected: HandshakeProgress::Message2Done.to_string(),
764                got: self.progress.to_string(),
765            });
766        }
767
768        let re = self
769            .remote_ephemeral
770            .expect("should have remote ephemeral after msg2");
771        let epoch = self
772            .local_epoch
773            .expect("local epoch must be set before write_xk_message_3");
774
775        let mut message = Vec::with_capacity(XK_HANDSHAKE_MSG3_SIZE);
776
777        // -> s: encrypt our static and send
778        let our_static = self.static_keypair.public_key().serialize();
779        let encrypted_static = self.symmetric.encrypt_and_hash(&our_static)?;
780        message.extend_from_slice(&encrypted_static);
781
782        // -> se: DH(s, re), mix into key
783        let se = self.ecdh(&self.static_keypair.secret_key(), &re);
784        self.symmetric.mix_key(&se);
785
786        // -> epoch: encrypt startup epoch for restart detection
787        let encrypted_epoch = self.symmetric.encrypt_and_hash(&epoch)?;
788        debug_assert_eq!(encrypted_epoch.len(), EPOCH_ENCRYPTED_SIZE);
789        message.extend_from_slice(&encrypted_epoch);
790
791        self.progress = HandshakeProgress::Complete;
792
793        Ok(message)
794    }
795
796    /// Read XK message 3 (responder only).
797    ///
798    /// Processes the initiator's encrypted static key and epoch.
799    /// After this, the responder learns the initiator's identity.
800    pub fn read_xk_message_3(&mut self, message: &[u8]) -> Result<(), NoiseError> {
801        if self.role != HandshakeRole::Responder {
802            return Err(NoiseError::WrongState {
803                expected: "responder".to_string(),
804                got: "initiator".to_string(),
805            });
806        }
807        if self.progress != HandshakeProgress::Message2Done {
808            return Err(NoiseError::WrongState {
809                expected: HandshakeProgress::Message2Done.to_string(),
810                got: self.progress.to_string(),
811            });
812        }
813        if message.len() != XK_HANDSHAKE_MSG3_SIZE {
814            return Err(NoiseError::MessageTooShort {
815                expected: XK_HANDSHAKE_MSG3_SIZE,
816                got: message.len(),
817            });
818        }
819
820        // -> s: decrypt initiator's static
821        let encrypted_static_end = PUBKEY_SIZE + super::TAG_SIZE;
822        let encrypted_static = &message[..encrypted_static_end];
823        let decrypted_static = self.symmetric.decrypt_and_hash(encrypted_static)?;
824        let rs =
825            PublicKey::from_slice(&decrypted_static).map_err(|_| NoiseError::InvalidPublicKey)?;
826        self.remote_static = Some(rs);
827
828        // -> se: DH(e, rs), mix into key
829        // (responder uses their ephemeral with initiator's now-known static)
830        let ephemeral = self
831            .ephemeral_keypair
832            .as_ref()
833            .expect("should have ephemeral after msg2");
834        let se = self.ecdh(&ephemeral.secret_key(), &rs);
835        self.symmetric.mix_key(&se);
836
837        // -> epoch: decrypt initiator's startup epoch
838        let encrypted_epoch = &message[encrypted_static_end..];
839        debug_assert_eq!(encrypted_epoch.len(), EPOCH_ENCRYPTED_SIZE);
840        let decrypted_epoch = self.symmetric.decrypt_and_hash(encrypted_epoch)?;
841        debug_assert_eq!(decrypted_epoch.len(), EPOCH_SIZE);
842        let mut epoch = [0u8; EPOCH_SIZE];
843        epoch.copy_from_slice(&decrypted_epoch);
844        self.remote_epoch = Some(epoch);
845
846        self.progress = HandshakeProgress::Complete;
847
848        Ok(())
849    }
850
851    /// Complete the handshake and return a NoiseSession.
852    ///
853    /// Must be called after the handshake is complete.
854    pub fn into_session(self) -> Result<NoiseSession, NoiseError> {
855        if !self.is_complete() {
856            return Err(NoiseError::HandshakeNotComplete);
857        }
858
859        let (c1, c2) = self.symmetric.split();
860        let handshake_hash = self.symmetric.handshake_hash();
861        let remote_static = self
862            .remote_static
863            .expect("remote static must be known after handshake");
864
865        // Initiator sends with c1, receives with c2
866        // Responder sends with c2, receives with c1
867        let (send_cipher, recv_cipher) = match self.role {
868            HandshakeRole::Initiator => (c1, c2),
869            HandshakeRole::Responder => (c2, c1),
870        };
871
872        Ok(NoiseSession::from_handshake(
873            self.role,
874            send_cipher,
875            recv_cipher,
876            handshake_hash,
877            remote_static,
878        ))
879    }
880
881    /// Get the handshake hash (for channel binding, available after complete).
882    pub fn handshake_hash(&self) -> [u8; 32] {
883        self.symmetric.handshake_hash()
884    }
885}
886
887impl fmt::Debug for HandshakeState {
888    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
889        f.debug_struct("HandshakeState")
890            .field("pattern", &self.pattern)
891            .field("role", &self.role)
892            .field("progress", &self.progress)
893            .field("has_ephemeral", &self.ephemeral_keypair.is_some())
894            .field("has_remote_static", &self.remote_static.is_some())
895            .field("has_remote_ephemeral", &self.remote_ephemeral.is_some())
896            .field("has_local_epoch", &self.local_epoch.is_some())
897            .field("has_remote_epoch", &self.remote_epoch.is_some())
898            .finish()
899    }
900}