hyphae_handshake/
handshake.rs

1use std::{iter::once, mem, ops::Deref};
2
3use rand_core::{CryptoRng, RngCore};
4
5use crate::{buffer::{AppendOnlyBuffer, Buffer, BufferFullError, MaxLenBuffer, VarIntSize, VarLengthPrefixBuffer}, crypto::{CryptoBackend, CryptoError, NoiseHandshake, SecretKeySetup, SymmetricKey, TransportCrypto}, customization::{HandshakeConfig, HandshakeDriver, HandshakeInfo}, Error};
6
7impl From<BufferFullError> for Error {
8    fn from(_: BufferFullError) -> Self {
9        Self::BufferSize
10    }
11}
12
13impl From<CryptoError> for Error {
14    fn from(value: CryptoError) -> Self {
15        match value {
16            CryptoError::DecryptionFailed => Error::HandshakeFailed,
17            _ => Error::Internal,
18        }
19    }
20}
21
/// Version of the Hyphae handshake protocol.
#[derive(Clone, Copy, Debug)]
#[non_exhaustive]
pub enum HandshakeVersion {
    /// Handshake version 1 (label `hyphae-h-v1`, id `1`).
    Version1,
}

impl HandshakeVersion {
    /// Byte label naming this version; it is used as part of the Noise
    /// handshake prologue.
    pub fn label(self) -> &'static [u8] {
        match self {
            Self::Version1 => b"hyphae-h-v1",
        }
    }

    /// Numeric identifier for this version.
    pub fn id(self) -> u8 {
        match self {
            Self::Version1 => 1,
        }
    }
}
41
// Labels used when deriving key material. Constants with the `_HKDF_LABEL`
// suffix are, per their names, meant for HKDF-based derivation;
// `HYPHAE_KEY_ASK_LABEL` is the label passed to `NoiseHandshake::get_ask`
// when exporting the next level secret.

/// Label for the initial secret.
pub const HYPHAE_INITIAL_SECRET_HKDF_LABEL: &[u8] = b"hyphae initial";
/// Label for the retry secret.
pub const HYPHAE_RETRY_SECRET_HKDF_LABEL: &[u8] = b"hyphae retry";
/// ASK label used to export level secrets from the Noise handshake.
pub const HYPHAE_KEY_ASK_LABEL: &[u8] = b"hyphae key";
/// Label for the initiator's data key.
pub const HYPHAE_INIT_DATA_HKDF_LABEL: &[u8] = b"init data";
/// Label for the responder's data key.
pub const HYPHAE_RESP_DATA_HKDF_LABEL: &[u8] = b"resp data";
/// Label for the initiator's header-protection key.
pub const HYPHAE_INIT_HP_HKDF_LABEL: &[u8] = b"init hp";
/// Label for the responder's header-protection key.
pub const HYPHAE_RESP_HP_HKDF_LABEL: &[u8] = b"resp hp";
49
/// Hyphae handshake state machine that keeps its driver and Noise state on
/// the heap (only available with the `alloc` feature).
///
/// `T` is the application handshake driver, `B` the crypto backend and `R`
/// a handle that dereferences to that backend.
#[cfg(feature = "alloc")]
pub struct AllocHyphaeHandshake<T, B, R>
where
    T: HandshakeDriver,
    B: CryptoBackend,
    R: Deref<Target = B>,
{
    /// Backend handle; creates the Noise state and derives transport crypto.
    crypto: R,
    /// Current role and position in the handshake state machine.
    phase: AllocHyphaeHandshakePhase,
    /// Application driver that produces and consumes handshake payloads.
    handshake_driver: Box<T>,
    /// Noise handshake state.
    noise_handshake: Box<B::NoiseHandshake>,
    /// Transport parameters received from the peer, once read.
    peer_transport_params: Option<Vec<u8>>,
    /// Responder's 0-RTT decision, as read from its deferred payload.
    peer_zero_rtt_accepted: Option<bool>,
    /// Set when a new level secret can be taken via `next_level_secret`;
    /// cleared once it has been taken.
    next_level_secret_ready: bool,
}
60
#[cfg(feature = "alloc")]
impl<T: HandshakeDriver, B: CryptoBackend, R: Deref<Target = B>> AllocHyphaeHandshake<T, B, R> {
    /// Creates the initiator side of a Hyphae handshake.
    ///
    /// `handshake_config` may emit a preamble. A non-empty preamble is sent as
    /// its own first message; otherwise the first message written is the
    /// initiator's Initial message carrying `transport_params`.
    ///
    /// # Errors
    ///
    /// Propagates failures from the config, the driver and the crypto
    /// backend. Returns [`Error::Internal`] if the driver did not initialize
    /// the Noise handshake while it was being constructed.
    pub fn new_initiator<C>(
        handshake_config: &C,
        crypto: R,
        version: HandshakeVersion,
        transport_label: &[u8],
        transport_params: Vec<u8>,
        server_name: &str,
    ) -> Result<Self, Error>
    where
        C: HandshakeConfig<Driver = T>,
    {
        let mut preamble = Vec::new();
        handshake_config.initiator_preamble(&mut preamble)?;

        let mut noise_handshake = Box::new(crypto.new_handshake()?);
        let mut noise_wrapper = NoiseHandshakeWrapper::wrap_init(
            noise_handshake.as_mut(),
            version,
            transport_label,
            &preamble,
            true,
        );
        let handshake_driver =
            Box::new(handshake_config.new_initiator(server_name, &mut noise_wrapper)?);

        // The driver is expected to initialize the Noise state through the
        // wrapper; a state that is still reset means it did not.
        if noise_handshake.is_reset() {
            return Err(Error::Internal);
        }

        // An empty preamble is never sent, so go straight to the Initial.
        let initiator_phase = if preamble.is_empty() {
            AllocHyphaeInitiatorPhase::WriteInitiatorConfigNoise { transport_params }
        } else {
            AllocHyphaeInitiatorPhase::WritePreamble { preamble, transport_params }
        };

        Ok(Self {
            crypto,
            phase: AllocHyphaeHandshakePhase::Initiator(initiator_phase),
            handshake_driver,
            noise_handshake,
            peer_transport_params: None,
            peer_zero_rtt_accepted: None,
            next_level_secret_ready: false,
        })
    }

    /// Creates the responder side of a Hyphae handshake from the first
    /// message received from the initiator.
    ///
    /// If `first_message` is a Preamble message, its payload is handed to the
    /// config and the initiator's Initial message must be supplied afterwards
    /// through [`Self::read_message`]. Any other `first_message` is read
    /// immediately as the Initial message.
    ///
    /// # Errors
    ///
    /// Propagates decoding, config, driver and crypto failures. Returns
    /// [`Error::Internal`] if the driver did not initialize the Noise
    /// handshake while it was being constructed.
    pub fn new_responder<C>(
        handshake_config: &C,
        crypto: R,
        version: HandshakeVersion,
        transport_label: &[u8],
        transport_params: Vec<u8>,
        mut first_message: Vec<u8>,
    ) -> Result<Self, Error>
    where
        C: HandshakeConfig<Driver = T>,
    {
        let mut noise_handshake = Box::new(crypto.new_handshake()?);

        // Branch on the message type, not the payload length: a Preamble
        // message with an empty payload must not be re-read as an Initial.
        let has_preamble =
            MessageReader::decode_message_type(&first_message)? == HandshakeMessage::Preamble;
        let preamble = if has_preamble {
            let reader = MessageReader::decode_in_place(
                &mut first_message,
                HandshakeMessage::Preamble,
                noise_handshake.as_mut(),
            )?;
            reader.payload()?
        } else {
            &[]
        };

        let mut noise_wrapper = NoiseHandshakeWrapper::wrap_init(
            noise_handshake.as_mut(),
            version,
            transport_label,
            preamble,
            false,
        );
        let handshake_driver =
            Box::new(handshake_config.new_responder(preamble, &mut noise_wrapper)?);

        // See `new_initiator`: the driver must have initialized the Noise state.
        if noise_handshake.is_reset() {
            return Err(Error::Internal);
        }

        let mut this = Self {
            crypto,
            phase: AllocHyphaeHandshakePhase::Responder(
                AllocHyphaeResponderPhase::ReadInitiatorConfigNoise { transport_params },
            ),
            handshake_driver,
            noise_handshake,
            peer_transport_params: None,
            peer_zero_rtt_accepted: None,
            next_level_secret_ready: false,
        };

        if !has_preamble {
            this.read_message(first_message)?;
        }

        Ok(this)
    }

    /// Transport parameters received from the peer, once they have been read.
    pub fn peer_params(&self) -> Option<&[u8]> {
        self.peer_transport_params.as_deref()
    }

    /// Returns true once the Noise portion of the handshake is finished
    /// and 1-RTT keys are available.
    ///
    /// Peers will still exchange final messages but this happens in the
    /// 1-RTT packet space and cannot fail the handshake.
    ///
    /// The Noise handshake state and its key material will be discarded
    /// after a peer sends its final message.
    pub fn is_handshake_finished(&self) -> bool {
        matches!(
            self.phase,
            AllocHyphaeHandshakePhase::Initiator(
                AllocHyphaeInitiatorPhase::SendFinal { .. }
                    | AllocHyphaeInitiatorPhase::RecvFinal
                    | AllocHyphaeInitiatorPhase::Finalized
            ) | AllocHyphaeHandshakePhase::Responder(
                AllocHyphaeResponderPhase::SendFinal { .. }
                    | AllocHyphaeResponderPhase::RecvFinal
                    | AllocHyphaeResponderPhase::Finalized
            )
        )
    }

    /// Returns true once this peer has sent and received its final
    /// message.
    ///
    /// At this point, all handshake state can be discarded.
    pub fn is_handshake_finalized(&self) -> bool {
        // TODO: broken, fix this - also need to dispose of the Noise handshake
        // before reading each other's finals to clear keys in case the other
        // side never responds.
        matches!(
            self.phase,
            AllocHyphaeHandshakePhase::Initiator(AllocHyphaeInitiatorPhase::Finalized)
                | AllocHyphaeHandshakePhase::Responder(AllocHyphaeResponderPhase::Finalized)
        )
    }

    /// Whether the local Noise handshake state is the initiator.
    pub fn is_initiator(&self) -> bool {
        self.noise_handshake.is_initiator()
    }

    /// The peer's static public key, if the Noise state has one.
    pub fn remote_public(&self) -> Option<&[u8]> {
        self.noise_handshake.remote_public()
    }

    /// The Noise handshake hash, available only once Noise has finished.
    pub fn final_handshake_hash(&self) -> Option<&[u8]> {
        if self.noise_handshake.is_finished() {
            Some(self.noise_handshake.handshake_hash())
        } else {
            None
        }
    }

    /// The application handshake driver.
    pub fn handshake_driver(&self) -> &T {
        &self.handshake_driver
    }

    /// The responder's 0-RTT decision, once its deferred payload was read.
    pub fn zero_rtt_accepted(&self) -> Option<bool> {
        self.peer_zero_rtt_accepted
    }

    /// Whether a new level secret can be taken with [`Self::next_level_secret`].
    pub fn next_level_secret_ready(&self) -> bool {
        self.next_level_secret_ready
    }

    /// Writes the next level secret into `level_secret`, exporting it from the
    /// Noise state with [`HYPHAE_KEY_ASK_LABEL`], and clears the ready flag.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Internal`] if no secret is ready, or the export error.
    pub fn next_level_secret(&mut self, level_secret: &mut SymmetricKey) -> Result<(), Error> {
        if !self.next_level_secret_ready {
            return Err(Error::Internal);
        }
        self.noise_handshake.get_ask(HYPHAE_KEY_ASK_LABEL, level_secret)?;
        self.next_level_secret_ready = false;
        Ok(())
    }

    /// Builds transport crypto from the current Noise state via the backend.
    pub fn transport_crypto(&self) -> Result<B::TransportCrypto, Error> {
        Ok(self.crypto.transport_crypto(&self.noise_handshake)?)
    }

    /// Exports 1-RTT rekey material from the Noise state into `rekey`.
    pub fn export_1rtt_rekey(&mut self, rekey: &mut B::TransportRekey) -> Result<(), Error> {
        Ok(self.crypto.export_1rtt_rekey(&mut self.noise_handshake, rekey)?)
    }

    /// Feeds a message received from the peer into the state machine.
    ///
    /// # Errors
    ///
    /// Returns [`Error::HandshakeFailed`] for messages that are not expected
    /// in the current phase, or any decoding/driver/crypto error.
    pub fn read_message(&mut self, message: Vec<u8>) -> Result<(), Error> {
        match self.phase {
            AllocHyphaeHandshakePhase::Initiator(_) => self.initiator_read_message(message),
            AllocHyphaeHandshakePhase::Responder(_) => self.responder_read_message(message),
        }
    }

    /// Writes this peer's next handshake message into `buffer`.
    ///
    /// Writes nothing (and succeeds) when there is nothing for this peer to
    /// send in the current phase.
    pub fn write_message(&mut self, buffer: &mut impl Buffer) -> Result<(), Error> {
        match self.phase {
            AllocHyphaeHandshakePhase::Initiator(_) => self.initiator_write_message(buffer),
            AllocHyphaeHandshakePhase::Responder(_) => self.responder_write_message(buffer),
        }
    }

    fn initiator_write_message(&mut self, buffer: &mut impl Buffer) -> Result<(), Error> {
        let AllocHyphaeHandshakePhase::Initiator(ref mut phase) = self.phase else {
            unreachable!();
        };

        match phase {
            AllocHyphaeInitiatorPhase::WritePreamble { preamble, transport_params } => {
                write_preamble(buffer, preamble.as_slice())?;
                *phase = AllocHyphaeInitiatorPhase::WriteInitiatorConfigNoise {
                    transport_params: mem::take(transport_params),
                };
                Ok(())
            }

            AllocHyphaeInitiatorPhase::WriteInitiatorConfigNoise { transport_params } => {
                let transport_params = mem::take(transport_params);
                write_initiator_initial(
                    buffer,
                    self.noise_handshake.as_mut(),
                    transport_params.as_slice(),
                    self.handshake_driver.as_mut(),
                    phase.message_position()?,
                )?;
                *phase = AllocHyphaeInitiatorPhase::ReadResponderConfigNoise;
                Ok(())
            }

            AllocHyphaeInitiatorPhase::Noise { .. } if self.noise_handshake.is_my_turn() => {
                self.write_noise_message(buffer)
            }

            AllocHyphaeInitiatorPhase::SendFinal { received_final } => {
                // A pending level secret must be taken before moving on.
                if self.next_level_secret_ready {
                    return Err(Error::Internal); // TODO: maybe always check this for safety.
                }

                write_final(buffer, self.noise_handshake.as_mut(), self.handshake_driver.as_mut())?;
                // TODO: destroy the Noise handshake state here.
                *phase = if *received_final {
                    AllocHyphaeInitiatorPhase::Finalized
                } else {
                    AllocHyphaeInitiatorPhase::RecvFinal
                };
                Ok(())
            }

            _ => Ok(()),
        }
    }

    fn responder_read_message(&mut self, mut message: Vec<u8>) -> Result<(), Error> {
        let AllocHyphaeHandshakePhase::Responder(ref mut phase) = self.phase else {
            unreachable!();
        };

        match phase {
            AllocHyphaeResponderPhase::ReadInitiatorConfigNoise { transport_params } => {
                let transport_params = mem::take(transport_params);

                // Capture the hash before this message is decoded; the driver
                // receives it alongside the payload.
                let prev_hash = self.noise_handshake.handshake_hash().to_vec();
                let reader = MessageReader::decode_in_place(
                    &mut message,
                    HandshakeMessage::Initial,
                    self.noise_handshake.as_mut(),
                )?;
                let (peer_transport_params, app_payload) = reader.initial_init_payloads()?;
                self.peer_transport_params = Some(peer_transport_params.to_vec());

                let mut noise_wrapper = NoiseHandshakeWrapper::wrap_payload(
                    self.noise_handshake.as_mut(),
                    Some(phase.message_position()?),
                    Some(&prev_hash),
                );
                self.handshake_driver.read_noise_payload(app_payload, &mut noise_wrapper)?;

                *phase = AllocHyphaeResponderPhase::WriteResponderConfigNoise { transport_params };
                Ok(())
            }

            AllocHyphaeResponderPhase::Noise { .. } if !self.noise_handshake.is_my_turn() => {
                self.read_noise_message(message)
            }

            AllocHyphaeResponderPhase::SendFinal { received_final: false }
            | AllocHyphaeResponderPhase::RecvFinal => {
                if self.next_level_secret_ready {
                    return Err(Error::Internal);
                }

                let prev_hash = self.noise_handshake.handshake_hash().to_vec();
                let reader = MessageReader::decode_in_place(
                    &mut message,
                    HandshakeMessage::Final,
                    self.noise_handshake.as_mut(),
                )?;
                let final_payload = reader.final_payload()?;
                let mut noise_wrapper = NoiseHandshakeWrapper::wrap_payload(
                    self.noise_handshake.as_mut(),
                    None,
                    Some(&prev_hash),
                );
                self.handshake_driver.read_final_payload(final_payload, &mut noise_wrapper)?;

                match phase {
                    AllocHyphaeResponderPhase::SendFinal { received_final: false } => {
                        *phase = AllocHyphaeResponderPhase::SendFinal { received_final: true }
                    }
                    _ => *phase = AllocHyphaeResponderPhase::Finalized,
                }

                Ok(())
            }

            _ => Err(Error::HandshakeFailed),
        }
    }

    fn responder_write_message(&mut self, buffer: &mut impl Buffer) -> Result<(), Error> {
        let AllocHyphaeHandshakePhase::Responder(ref mut phase) = self.phase else {
            unreachable!();
        };

        match phase {
            AllocHyphaeResponderPhase::WriteResponderConfigNoise { transport_params } => {
                let transport_params = mem::take(transport_params);

                // Build the deferred payload first so the Initial message can
                // commit to its hash.
                let mut deferred_payload = Vec::new();
                write_responder_deferred_payload(
                    &mut deferred_payload,
                    self.noise_handshake.as_mut(),
                    transport_params.as_slice(),
                    false,
                    self.handshake_driver.as_mut(),
                    phase.message_position()?,
                )?;

                let crypto = self.crypto.transport_crypto(&self.noise_handshake)?;
                // Hash everything after the first byte (presumably the message
                // type tag); the initiator hashes the decoded payload. An empty
                // buffer would be a writer bug, so report it instead of
                // panicking on the slice.
                let deferred_body = deferred_payload.get(1..).ok_or(Error::Internal)?;
                let mut deferred_payload_hash = crypto.zeros_hash();
                crypto.hash_into(deferred_body, &mut deferred_payload_hash);

                // Build initial message with deferred payload hash.
                write_responder_initial(
                    buffer,
                    self.noise_handshake.as_mut(),
                    &crypto.hash_as_slice(&deferred_payload_hash),
                )?;

                *phase = AllocHyphaeResponderPhase::WriteResponderDeferredPayload { deferred_payload };
                self.next_level_secret_ready = true;
                Ok(())
            }

            AllocHyphaeResponderPhase::WriteResponderDeferredPayload { deferred_payload } => {
                if self.next_level_secret_ready {
                    return Err(Error::Internal);
                }

                buffer.extend_from_slice(&deferred_payload)?;
                *phase = AllocHyphaeResponderPhase::Noise { position: 3 };
                self.check_noise_finished()
            }

            AllocHyphaeResponderPhase::Noise { .. } if self.noise_handshake.is_my_turn() => {
                self.write_noise_message(buffer)
            }

            AllocHyphaeResponderPhase::SendFinal { received_final } => {
                if self.next_level_secret_ready {
                    return Err(Error::Internal); // TODO: maybe always check this for safety.
                }

                write_final(buffer, self.noise_handshake.as_mut(), self.handshake_driver.as_mut())?;
                // TODO: destroy the Noise handshake state here.
                *phase = if *received_final {
                    AllocHyphaeResponderPhase::Finalized
                } else {
                    AllocHyphaeResponderPhase::RecvFinal
                };
                Ok(())
            }

            _ => Ok(()),
        }
    }

    fn initiator_read_message(&mut self, mut message: Vec<u8>) -> Result<(), Error> {
        let AllocHyphaeHandshakePhase::Initiator(ref mut phase) = self.phase else {
            unreachable!();
        };

        match phase {
            AllocHyphaeInitiatorPhase::ReadResponderConfigNoise => {
                let prev_noise_hash = self.noise_handshake.handshake_hash().to_vec();
                let reader = MessageReader::decode_in_place(
                    &mut message,
                    HandshakeMessage::Initial,
                    self.noise_handshake.as_mut(),
                )?;
                let deferred_payload_hash = reader.initial_resp_payloads()?;

                *phase = AllocHyphaeInitiatorPhase::ReadResponderDeferredPayload {
                    deferred_payload_hash: deferred_payload_hash.to_vec(),
                    prev_noise_hash,
                };

                self.next_level_secret_ready = true;
                Ok(())
            }

            AllocHyphaeInitiatorPhase::ReadResponderDeferredPayload {
                deferred_payload_hash,
                prev_noise_hash,
            } => {
                if self.next_level_secret_ready {
                    return Err(Error::Internal);
                }

                let prev_noise_hash = mem::take(prev_noise_hash);

                let reader = MessageReader::decode_in_place(
                    &mut message,
                    HandshakeMessage::DeferredPayload,
                    self.noise_handshake.as_mut(),
                )?;

                // The payload must match the hash the responder committed to
                // in its Initial message.
                let crypto = self.crypto.transport_crypto(&self.noise_handshake)?;
                let mut actual_payload_hash = crypto.zeros_hash();
                crypto.hash_into(reader.payload, &mut actual_payload_hash);
                if deferred_payload_hash.as_slice() != crypto.hash_as_slice(&actual_payload_hash) {
                    return Err(Error::HandshakeFailed);
                }

                let (peer_params, zero_rtt_acc, app_payload) = reader.deferred_resp_payloads()?;
                self.peer_zero_rtt_accepted = Some(zero_rtt_acc);

                let mut noise_wrapper = NoiseHandshakeWrapper::wrap_payload(
                    self.noise_handshake.as_mut(),
                    Some(phase.message_position()?),
                    Some(&prev_noise_hash),
                );
                self.handshake_driver.read_noise_payload(app_payload, &mut noise_wrapper)?;

                self.peer_transport_params = Some(peer_params.to_vec());
                *phase = AllocHyphaeInitiatorPhase::Noise { position: 3 };
                self.check_noise_finished()
            }

            AllocHyphaeInitiatorPhase::Noise { .. } if !self.noise_handshake.is_my_turn() => {
                self.read_noise_message(message)
            }

            AllocHyphaeInitiatorPhase::SendFinal { received_final: false }
            | AllocHyphaeInitiatorPhase::RecvFinal => {
                if self.next_level_secret_ready {
                    return Err(Error::Internal);
                }

                // Same as the responder: give the driver the hash captured
                // before the final message was decoded.
                let prev_hash = self.noise_handshake.handshake_hash().to_vec();
                let reader = MessageReader::decode_in_place(
                    &mut message,
                    HandshakeMessage::Final,
                    self.noise_handshake.as_mut(),
                )?;
                let final_payload = reader.final_payload()?;
                let mut noise_wrapper = NoiseHandshakeWrapper::wrap_payload(
                    self.noise_handshake.as_mut(),
                    None,
                    Some(&prev_hash),
                );
                self.handshake_driver.read_final_payload(final_payload, &mut noise_wrapper)?;

                match phase {
                    AllocHyphaeInitiatorPhase::SendFinal { received_final: false } => {
                        *phase = AllocHyphaeInitiatorPhase::SendFinal { received_final: true }
                    }
                    _ => *phase = AllocHyphaeInitiatorPhase::Finalized,
                }

                Ok(())
            }

            _ => Err(Error::HandshakeFailed),
        }
    }

    /// Writes the next Noise message and advances the message position.
    fn write_noise_message(&mut self, buffer: &mut impl Buffer) -> Result<(), Error> {
        write_noise(
            buffer,
            self.noise_handshake.as_mut(),
            self.handshake_driver.as_mut(),
            self.phase.message_position()?,
        )?;
        self.phase.advance_message_position()?;
        self.check_noise_finished()
    }

    /// Reads a Noise message, hands its payload to the driver and advances
    /// the message position.
    fn read_noise_message(&mut self, mut message: Vec<u8>) -> Result<(), Error> {
        let prev_hash = self.noise_handshake.handshake_hash().to_vec();
        let reader = MessageReader::decode_in_place(
            &mut message,
            HandshakeMessage::Noise,
            self.noise_handshake.as_mut(),
        )?;
        let mut noise_wrapper = NoiseHandshakeWrapper::wrap_payload(
            self.noise_handshake.as_mut(),
            Some(self.phase.message_position()?),
            Some(&prev_hash),
        );
        self.handshake_driver.read_noise_payload(reader.payload()?, &mut noise_wrapper)?;
        self.phase.advance_message_position()?;
        self.check_noise_finished()
    }

    /// Moves from the `Noise` phase to `SendFinal` once the Noise handshake is
    /// complete, and flags the next level secret as ready.
    ///
    /// Must only be called in the `Noise` phase; otherwise returns
    /// [`Error::Internal`].
    fn check_noise_finished(&mut self) -> Result<(), Error> {
        match &mut self.phase {
            AllocHyphaeHandshakePhase::Initiator(AllocHyphaeInitiatorPhase::Noise { .. }) => {
                if self.noise_handshake.is_finished() {
                    self.phase = AllocHyphaeHandshakePhase::Initiator(
                        AllocHyphaeInitiatorPhase::SendFinal { received_final: false },
                    );
                }
            }
            AllocHyphaeHandshakePhase::Responder(AllocHyphaeResponderPhase::Noise { .. }) => {
                if self.noise_handshake.is_finished() {
                    self.phase = AllocHyphaeHandshakePhase::Responder(
                        AllocHyphaeResponderPhase::SendFinal { received_final: false },
                    );
                }
            }
            _ => return Err(Error::Internal),
        }

        if self.noise_handshake.is_finished() {
            self.next_level_secret_ready = true;
        }
        Ok(())
    }
}
498
499enum AllocHyphaeHandshakePhase {
500    Initiator (AllocHyphaeInitiatorPhase),
501    Responder (AllocHyphaeResponderPhase),
502}
503
504impl AllocHyphaeHandshakePhase {
505    pub fn message_position(&self) -> Result<u8, Error> {
506        match self {
507            AllocHyphaeHandshakePhase::Initiator(phase) => phase.message_position(),
508            AllocHyphaeHandshakePhase::Responder(phase) => phase.message_position(),
509        }
510    }
511
512    pub fn advance_message_position(&mut self) -> Result<(), Error> {
513        let position = match self {
514            AllocHyphaeHandshakePhase::Initiator(AllocHyphaeInitiatorPhase::Noise { position }) => position,
515            AllocHyphaeHandshakePhase::Responder(AllocHyphaeResponderPhase::Noise { position }) => position,
516            _ => return Err(Error::Internal)
517        };
518        *position = position.checked_add(1).ok_or(Error::Internal)?;
519        Ok(())
520    }
521}
522
523enum AllocHyphaeInitiatorPhase {
524    WritePreamble {
525        preamble: Vec<u8>,
526        transport_params: Vec<u8>,
527    },
528    WriteInitiatorConfigNoise {
529        transport_params: Vec<u8>,
530    },
531    ReadResponderConfigNoise,
532    ReadResponderDeferredPayload {
533        deferred_payload_hash: Vec<u8>,
534        prev_noise_hash: Vec<u8>,
535    },
536    Noise {
537        position: u8,
538    },
539    SendFinal {
540        received_final: bool,
541    },
542    RecvFinal,
543    Finalized,
544}
545
546impl AllocHyphaeInitiatorPhase {
547    pub fn message_position(&self) -> Result<u8, Error> {
548        match self {
549            AllocHyphaeInitiatorPhase::WriteInitiatorConfigNoise { .. } => Ok(1),
550            AllocHyphaeInitiatorPhase::ReadResponderDeferredPayload { .. } => Ok(2),
551            AllocHyphaeInitiatorPhase::Noise { position } => Ok(*position),
552            _ => Err(Error::Internal),
553        }
554    }
555}
556
557enum AllocHyphaeResponderPhase {
558    ReadInitiatorConfigNoise {
559        transport_params: Vec<u8>,
560    },
561    WriteResponderConfigNoise {
562        transport_params: Vec<u8>,
563    },
564    WriteResponderDeferredPayload {
565        deferred_payload: Vec<u8>,
566    },
567    Noise {
568        position: u8,
569    },
570    SendFinal {
571        received_final: bool,
572    },
573    RecvFinal,
574    Finalized,
575}
576
577impl AllocHyphaeResponderPhase {
578    pub fn message_position(&self) -> Result<u8, Error> {
579        match self {
580            AllocHyphaeResponderPhase::ReadInitiatorConfigNoise { .. } => Ok(1),
581            AllocHyphaeResponderPhase::WriteResponderConfigNoise { .. } => Ok(2),
582            AllocHyphaeResponderPhase::Noise { position } => Ok(*position),
583            _ => Err(Error::Internal),
584        }
585    }
586}
587
/// Handshake message types; the discriminant is the wire identifier.
#[repr(u8)]
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
enum HandshakeMessage {
    /// Optional preamble sent before the initiator's Initial message.
    Preamble = 0x01,
    /// Initial (config + first Noise) message of either peer.
    Initial = 0x02,
    /// Responder's deferred payload.
    DeferredPayload = 0x03,
    /// Subsequent Noise message.
    Noise = 0x04,
    FinalPayload = 0x7E,
    /// Final message exchanged after Noise has finished.
    Final = 0x7F,
    Failed = 0xFF,
}
599
600impl HandshakeMessage {
601    pub fn from_id(id: u8) -> Result<Self, Error> {
602        match id {
603            x if x == Self::Preamble as u8 => Ok(Self::Preamble),
604            x if x == Self::Initial as u8 => Ok(Self::Initial),
605            x if x == Self::DeferredPayload as u8 => Ok(Self::DeferredPayload),
606            x if x == Self::Noise as u8 => Ok(Self::Noise),
607            x if x == Self::Final as u8 => Ok(Self::Final),
608            x if x == Self::FinalPayload as u8 => Ok(Self::FinalPayload),
609            x if x == Self::Failed as u8 => Ok(Self::Failed),
610            _ => Err(Error::HandshakeFailed)
611        }
612    }
613
614    pub fn is_encrypted(self) -> bool {
615        match self {
616            HandshakeMessage::Initial |
617            HandshakeMessage::Noise => true,
618            _ => false,
619        }
620    }
621
622    pub fn has_compound_payload(self) -> bool {
623        match self {
624            HandshakeMessage::Initial |
625            HandshakeMessage::DeferredPayload |
626            HandshakeMessage::FinalPayload => true,
627            _ => false,
628        }
629    }
630
631    pub fn has_payload(self) -> Option<bool> {
632        match self {
633            HandshakeMessage::Preamble => Some(true),
634            HandshakeMessage::Initial => Some(true),
635            HandshakeMessage::DeferredPayload => Some(true),
636            HandshakeMessage::Noise => None,
637            HandshakeMessage::FinalPayload => Some(true),
638            HandshakeMessage::Final => Some(false),
639            HandshakeMessage::Failed => Some(false),
640        }
641    }
642}
643
/// Frame types inside compound handshake payloads; the discriminant is the
/// wire frame ID.
#[repr(u8)]
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
enum PayloadFrame {
    /// Single-byte padding, skipped by readers.
    Padding = 0x00,
    /// Application payload; extends to the end of the message.
    ApplicationPayload = 0x01,
    /// Length-prefixed transport parameters.
    TransportParameters = 0x02,
    /// Hash of the responder's deferred payload; extends to the end of the message.
    DeferredPayloadHash = 0x03,
    /// Marker frame with no body.
    ZeroRttAccepted = 0x40,
}
653
654impl PayloadFrame {
655    /// Optional frame lower bound (inclusive).
656    /// 
657    /// Optional frames can be ignored if they are not supported. These
658    /// frames must begin with a `VarInt` length prefix and not be
659    /// essential to the handshake. Frame IDs less than the optional
660    /// base must be recognized or the handshake will fail.
661    /// 
662    /// This feature isn't used yet, but is here to allow extensibility
663    /// without revving the handshake version.
664    const OPTIONAL_BASE: u8 = 128;
665
666    fn ok_in(self, message: HandshakeMessage, from_initiator: bool) -> Result<(), Error> {
667        match (message, from_initiator, self) {
668            (HandshakeMessage::Initial, true, Self::Padding) => Ok(()),
669            (HandshakeMessage::Initial, true, Self::TransportParameters) => Ok(()),
670            (HandshakeMessage::Initial, true, Self::ApplicationPayload) => Ok(()),
671            (HandshakeMessage::Initial, false, Self::Padding) => Ok(()),
672            (HandshakeMessage::Initial, false, Self::DeferredPayloadHash) => Ok(()),
673            (HandshakeMessage::DeferredPayload, false, Self::Padding) => Ok(()),
674            (HandshakeMessage::DeferredPayload, false, Self::TransportParameters) => Ok(()),
675            (HandshakeMessage::DeferredPayload, false, Self::ZeroRttAccepted) => Ok(()),
676            (HandshakeMessage::DeferredPayload, false, Self::ApplicationPayload) => Ok(()),
677            (HandshakeMessage::FinalPayload, false, Self::ApplicationPayload) => Ok(()),
678            _ => Err(Error::HandshakeFailed)
679        }
680    }
681
682    fn from_id(frame_id: u8) -> Result<Option<Self>, Error> {
683        match frame_id {
684            id if id == Self::Padding as u8 => Ok(Some(Self::Padding)),
685            id if id == Self::ApplicationPayload as u8 => Ok(Some(Self::ApplicationPayload)),
686            id if id == Self::TransportParameters as u8 => Ok(Some(Self::TransportParameters)),
687            id if id == Self::DeferredPayloadHash as u8 => Ok(Some(Self::DeferredPayloadHash)),
688            id if id == Self::ZeroRttAccepted as u8 => Ok(Some(Self::ZeroRttAccepted)),
689            id if id >= Self::OPTIONAL_BASE => Ok(None),
690            _ => Err(Error::HandshakeFailed),
691        }
692    }
693
694    fn get_frame_payload(this: Option<Self>, mut remaining: &[u8]) -> Result<(&[u8], &[u8]), Error> {
695        let payload_len = match this {
696            Some(Self::ApplicationPayload) => remaining.len(),
697            Some(Self::DeferredPayloadHash) => remaining.len(),
698            Some(Self::Padding) => 0,
699            Some(Self::ZeroRttAccepted) => 0,
700            None | Some(Self::TransportParameters) => {
701                // todo, better varint decoding
702                let prefix_len = VarIntSize::from_msb(remaining.get(0).copied().ok_or(Error::HandshakeFailed)?);
703                if remaining.len() < prefix_len.len() {
704                    return Err(Error::HandshakeFailed);
705                }
706                let (prefix, r) = remaining.split_at(prefix_len.len());
707                remaining = r;
708                let mut prefix64 = [0u8; 8];
709                prefix64[8 - prefix.len()..].copy_from_slice(prefix);
710                prefix64[8 - prefix.len()] &= !0xC0;
711                u64::from_be_bytes(prefix64).try_into().map_err(|_| Error::HandshakeFailed)?
712            }
713        };
714        if payload_len > remaining.len() {
715            return Err(Error::HandshakeFailed)
716        }
717        Ok(remaining.split_at(payload_len))
718    }
719
720    pub fn next_frame(remaining: &[u8], message: HandshakeMessage, from_initiator: bool) -> Result<Option<(Self, &[u8], &[u8])>, Error> {
721        let Some(frame_id) = remaining.get(0).cloned() else {
722            return Ok(None);
723        };
724
725        let frame_type = Self::from_id(frame_id)?;
726        if let Some(frame_type) = frame_type {
727            frame_type.ok_in(message, from_initiator)?;
728        }
729        let (frame_payload, remaining) = Self::get_frame_payload(frame_type, &remaining[1..])?;
730
731        match frame_type {
732            Some(frame_type) if frame_type != Self::Padding =>
733                Ok(Some((frame_type, frame_payload, remaining))),
734
735            _ => Self::next_frame(remaining, message, from_initiator), // todo, this could be an issue if it isn't a tail call
736        }
737    }
738    
739}
740
741struct NoiseHandshakeWrapper<'a, X: NoiseHandshake> {
742    inner: &'a mut X,
743    init_info: Option<(HandshakeVersion, &'a [u8], &'a [u8])>,
744    initiator: Option<bool>,
745    position: Option<u8>,
746    prev_hash: Option<&'a [u8]>,
747}
748
749impl <'a, X: NoiseHandshake> NoiseHandshakeWrapper<'a, X> {
750    pub fn wrap_init(inner: &'a mut X, version: HandshakeVersion, transport_label: &'a [u8], preamble: &'a [u8], initiator: bool) -> Self {
751        Self {
752            inner,
753            init_info: Some((version, transport_label, preamble)),
754            initiator: Some(initiator),
755            position: None,
756            prev_hash: None,
757        }
758    }
759
760    pub fn wrap_payload(inner: &'a mut X, position: Option<u8>, prev_hash: Option<&'a [u8]>) -> Self {
761        Self {
762            inner,
763            init_info: None,
764            initiator: None,
765            position,
766            prev_hash,
767        }
768    }
769}
770
771impl <X: NoiseHandshake> HandshakeInfo for NoiseHandshakeWrapper<'_, X> {
772    fn initialize(&mut self, rng: &mut (impl CryptoRng + RngCore), protocol: &str, prologue: &[u8], s: Option<SecretKeySetup>, rs: Option<&[u8]>) -> Result<(), CryptoError> {
773        let Some(initiator) = self.initiator else {
774            return Err(CryptoError::Internal);
775        };
776        let Some((version, transport_label, preamble)) = self.init_info else {
777            return Err(CryptoError::Internal);
778        };
779        let Ok(preamble_len) = u16::try_from(preamble.len()) else {
780            return Err(CryptoError::Internal);
781        };
782        let preamble_len_le = preamble_len.to_le_bytes();
783
784        if !self.inner.is_reset() {
785            return Err(CryptoError::Internal);
786        }
787
788        let handshake_prologue =
789            once(version.label())
790            .chain(once(b".".as_slice()))
791            .chain(once(transport_label))
792            .chain(once(b".".as_slice()))
793            .chain(once(preamble_len_le.as_slice()))
794            .chain(once(preamble))
795            .chain(once(prologue));
796
797        self.inner.initialize(rng, protocol, initiator, handshake_prologue, s, rs)
798    }
799
800    fn set_token(&mut self, _token: &str, _value: &[u8]) -> Result<(), CryptoError> {
801        Err(CryptoError::Internal)
802    }
803
804    fn is_initiator(&self) -> bool {
805        if self.inner.is_reset() {
806            self.initiator.unwrap_or_default()
807        } else {
808            self.inner.is_initiator()
809        }
810    }
811
812    fn is_finished(&self) -> bool {
813        self.inner.is_finished()
814    }
815    
816    fn handshake_position(&self) -> Option<u8> {
817        self.position
818    }
819
820    fn prev_handshake_hash(&self) -> Option<&[u8]> {
821        self.prev_hash.or_else(|| Some(self.inner.handshake_hash()))
822    }
823
824    fn final_handshake_hash(&self) -> Option<&[u8]> {
825        match self.inner.is_finished() {
826            true => Some(self.inner.handshake_hash()),
827            false => None,
828        }
829    }
830}
831
832struct MessageReader<'a> {
833    payload: &'a [u8],
834    message_type: HandshakeMessage,
835}
836
837impl <'a> MessageReader<'a> {
838    pub fn decode_message_type(buffer: &[u8]) -> Result<HandshakeMessage, Error> {
839        if buffer.is_empty() {
840            return Err(Error::HandshakeFailed);
841        }
842
843        HandshakeMessage::from_id(buffer[0])
844    }
845
846    pub fn decode_in_place(buffer: &'a mut [u8], expect: HandshakeMessage, noise: &mut impl NoiseHandshake) -> Result<Self, Error> {
847        let message_type = Self::decode_message_type(buffer)?;
848        let buffer = &mut buffer[1..];
849
850        let expected = match expect {
851            HandshakeMessage::Final => 
852                message_type == HandshakeMessage::Final ||
853                message_type == HandshakeMessage::FinalPayload,
854            expect => expect == message_type,
855        };
856        if !expected {
857            return Err(Error::HandshakeFailed);
858        }
859        
860        // Decrypt Noise messages in place, extract payload.
861        let payload = if message_type.is_encrypted() {
862            noise.read_message_in_place(buffer)?
863        } else {
864            buffer
865        };
866
867        // Check compound payload version.
868        if message_type.has_compound_payload() &&
869           (payload.is_empty() || payload[0] != HandshakeVersion::Version1.id())
870        {
871            return Err(Error::HandshakeFailed);
872        }
873
874        // Verify payload expectation.
875        if let Some(has_payload) = message_type.has_payload() {
876            if has_payload == payload.is_empty() {
877                return Err(Error::HandshakeFailed);
878            }
879        }
880
881        Ok(Self {
882            payload,
883            message_type,
884        })
885    }
886
887    /// Return the payload from messages without a compound payload.
888    pub fn payload(&self) -> Result<&'a [u8], Error> {
889        if self.message_type.has_compound_payload() {
890            return Err(Error::Internal);
891        }
892        Ok(self.payload)
893    }
894
895    /// Return `(transport_params, application_payload)` from the
896    /// initiator's initial message.
897    pub fn initial_init_payloads(&self) -> Result<(&'a [u8], &'a [u8]), Error> {
898        if self.message_type != HandshakeMessage::Initial {
899            return Err(Error::Internal);
900        }
901
902        let mut frame_cursor = &self.payload[1..];
903        let mut transport_params = None;
904        let mut application_payload = None;
905
906        loop {
907            let Some((frame, payload, remaining)) = PayloadFrame::next_frame(frame_cursor, self.message_type, true)? else {
908                break;
909            };
910            match frame {
911                PayloadFrame::ApplicationPayload if application_payload.is_some() => return Err(Error::HandshakeFailed),
912                PayloadFrame::ApplicationPayload => application_payload = Some(payload),
913                PayloadFrame::TransportParameters if transport_params.is_some() => return Err(Error::HandshakeFailed),
914                PayloadFrame::TransportParameters => transport_params = Some(payload),
915                _ => {}
916            }
917            frame_cursor = remaining;
918        }
919
920        if let Some(true) = application_payload.map(|s| s.is_empty()) {
921            return Err(Error::HandshakeFailed);
922        }
923
924        application_payload.get_or_insert(&[]);
925
926        match (transport_params, application_payload) {
927            (Some(tp), Some(ap)) => Ok((tp, ap)),
928            _ => Err(Error::HandshakeFailed)
929        }
930    }
931
932    /// Return the deferred payload hash from the responder's initial
933    /// message.
934    pub fn initial_resp_payloads(&self) -> Result<&'a [u8], Error> {
935        if self.message_type != HandshakeMessage::Initial {
936            return Err(Error::Internal);
937        }
938
939        let mut frame_cursor = &self.payload[1..];
940        let mut deferred_hash = None;
941
942        loop {
943            let Some((frame, payload, remaining)) = PayloadFrame::next_frame(frame_cursor, self.message_type, false)? else {
944                break;
945            };
946            match frame {
947                PayloadFrame::DeferredPayloadHash if deferred_hash.is_some() => return Err(Error::HandshakeFailed),
948                PayloadFrame::DeferredPayloadHash => deferred_hash = Some(payload),
949                _ => {}
950            }
951            frame_cursor = remaining;
952        }
953        
954        if let Some(true) = deferred_hash.map(|s| s.is_empty()) {
955            return Err(Error::HandshakeFailed);
956        }
957
958        match deferred_hash {
959            Some(dh) => Ok(dh),
960            _ => Err(Error::HandshakeFailed)
961        }
962    }
963
964    /// Return `(transport_params, zero_rtt_accepted, application_payload)`
965    /// from the responder's deferred payload message.
966    pub fn deferred_resp_payloads(&self) -> Result<(&'a [u8], bool, &'a [u8]), Error> {
967        if self.message_type != HandshakeMessage::DeferredPayload {
968            return Err(Error::Internal);
969        }
970
971        let mut frame_cursor = &self.payload[1..];
972        let mut transport_params = None;
973        let mut application_payload = None;
974        let mut zero_rtt_accepted = None;
975
976        loop {
977            let Some((frame, payload, remaining)) = PayloadFrame::next_frame(frame_cursor, self.message_type, false)? else {
978                break;
979            };
980            match frame {
981                PayloadFrame::ApplicationPayload if application_payload.is_some() => return Err(Error::HandshakeFailed),
982                PayloadFrame::ApplicationPayload => application_payload = Some(payload),
983                PayloadFrame::TransportParameters if transport_params.is_some() => return Err(Error::HandshakeFailed),
984                PayloadFrame::TransportParameters => transport_params = Some(payload),
985                PayloadFrame::ZeroRttAccepted if zero_rtt_accepted.is_some() => return Err(Error::HandshakeFailed),
986                PayloadFrame::ZeroRttAccepted => zero_rtt_accepted = Some(true),
987                _ => {}
988            }
989            frame_cursor = remaining;
990        }
991        
992        if let Some(true) = application_payload.map(|s| s.is_empty()) {
993            return Err(Error::HandshakeFailed);
994        }
995
996        application_payload.get_or_insert(&[]);
997        zero_rtt_accepted.get_or_insert(false);
998
999        match (transport_params, zero_rtt_accepted, application_payload) {
1000            (Some(tp), Some(zrtt), Some(ap)) => Ok((tp, zrtt, ap)),
1001            _ => Err(Error::HandshakeFailed)
1002        }
1003    }
1004
1005    /// Return the final message's payload or an empty slice if one
1006    /// wasn't sent.
1007    pub fn final_payload(&self) -> Result<&'a [u8], Error> {
1008        match self.message_type {
1009            HandshakeMessage::Final => return Ok(&[]),
1010            HandshakeMessage::FinalPayload => {},
1011            _ => return Err(Error::Internal),
1012        }
1013
1014        let mut frame_cursor = &self.payload[1..];
1015        let mut final_payload = None;
1016
1017        loop {
1018            let Some((frame, payload, remaining)) = PayloadFrame::next_frame(frame_cursor, self.message_type, false)? else {
1019                break;
1020            };
1021            match frame {
1022                PayloadFrame::ApplicationPayload if final_payload.is_some() => return Err(Error::HandshakeFailed),
1023                PayloadFrame::ApplicationPayload => final_payload = Some(payload),
1024                _ => {}
1025            }
1026            frame_cursor = remaining;
1027        }
1028
1029        match final_payload {
1030            Some(fp) => Ok(fp),
1031            _ => Err(Error::HandshakeFailed)
1032        }
1033    }
1034}
1035
1036fn write_preamble(buffer: &mut impl Buffer, preamble: &[u8]) -> Result<(), Error> {
1037    let mut buffer = MaxLenBuffer::new(buffer, u16::MAX as usize)?;
1038    buffer.push(HandshakeMessage::Preamble as u8)?;
1039    buffer.extend_from_slice(preamble)?;
1040    Ok(())
1041}
1042
1043fn write_initiator_initial(buffer: &mut impl Buffer, noise: &mut impl NoiseHandshake, transport_params: &[u8], driver: &mut impl HandshakeDriver, position: u8) -> Result<(), Error> {
1044    let mut buffer = MaxLenBuffer::new(buffer, u16::MAX as usize)?;
1045    let (token_padding, tag_padding) = noise.next_message_layout()?;
1046    buffer.push(HandshakeMessage::Initial as u8)?;
1047    insert_padding(&mut buffer, token_padding)?;
1048    buffer.push(HandshakeVersion::Version1.id())?;
1049    insert_varlen_frame(&mut buffer, PayloadFrame::TransportParameters, transport_params)?;
1050    insert_application_payload(&mut buffer, noise, driver, Some(position))?;
1051    insert_padding(&mut buffer, tag_padding)?;
1052    noise.write_message_in_place(&mut buffer.as_mut()[1..])?;
1053    Ok(())
1054}
1055
1056fn write_responder_initial(buffer: &mut impl Buffer, noise: &mut impl NoiseHandshake, deferred_payload_hash: &[u8]) -> Result<(), Error> {
1057    let mut buffer = MaxLenBuffer::new(buffer, u16::MAX as usize)?;
1058    let (token_padding, tag_padding) = noise.next_message_layout()?;
1059    buffer.push(HandshakeMessage::Initial as u8)?;
1060    insert_padding(&mut buffer, token_padding)?;
1061    buffer.push(HandshakeVersion::Version1.id())?;
1062    buffer.push(PayloadFrame::DeferredPayloadHash as u8)?;
1063    buffer.extend_from_slice(deferred_payload_hash)?;
1064    insert_padding(&mut buffer, tag_padding)?;
1065    noise.write_message_in_place(&mut buffer.as_mut()[1..])?;
1066    Ok(())
1067}
1068
1069fn write_responder_deferred_payload(buffer: &mut impl Buffer, noise: &mut impl NoiseHandshake, transport_params: &[u8], zero_rtt_accepted: bool, driver: &mut impl HandshakeDriver, position: u8) -> Result<(), Error> {
1070    let mut buffer = MaxLenBuffer::new(buffer, u16::MAX as usize)?;
1071    buffer.push(HandshakeMessage::DeferredPayload as u8)?;
1072    buffer.push(HandshakeVersion::Version1.id())?;
1073    insert_varlen_frame(&mut buffer, PayloadFrame::TransportParameters, transport_params)?;
1074    if zero_rtt_accepted {
1075        buffer.push(PayloadFrame::ZeroRttAccepted as u8)?;
1076    }
1077    insert_application_payload(&mut buffer, noise, driver, Some(position))?;
1078    Ok(())
1079}
1080
1081fn write_noise(buffer: &mut impl Buffer, noise: &mut impl NoiseHandshake, driver: &mut impl HandshakeDriver, position: u8) -> Result<(), Error> {
1082    let mut buffer = MaxLenBuffer::new(buffer, u16::MAX as usize)?;
1083    let (token_padding, tag_padding) = noise.next_message_layout()?;
1084    buffer.push(HandshakeMessage::Noise as u8)?;
1085    insert_padding(&mut buffer, token_padding)?;
1086
1087    let mut noise_wrapper = NoiseHandshakeWrapper::wrap_payload(noise, Some(position), None);
1088    driver.write_noise_payload(&mut AppendOnlyBuffer::new(&mut buffer), &mut noise_wrapper)?;
1089
1090    insert_padding(&mut buffer, tag_padding)?;
1091    noise.write_message_in_place(&mut buffer.as_mut()[1..])?;
1092    Ok(())
1093}
1094
1095fn write_final(buffer: &mut impl Buffer, noise: &mut impl NoiseHandshake, driver: &mut impl HandshakeDriver) -> Result<(), Error> {
1096    let mut buffer = MaxLenBuffer::new(buffer, u16::MAX as usize)?;
1097    let mut buffer = AppendOnlyBuffer::new(&mut buffer);
1098    buffer.push(HandshakeMessage::FinalPayload as u8)?;
1099    buffer.push(HandshakeVersion::Version1.id())?;
1100    buffer.push(PayloadFrame::ApplicationPayload as u8)?;
1101    let mut noise_wrapper = NoiseHandshakeWrapper::wrap_payload(noise, None, None);
1102    driver.write_final_payload(&mut AppendOnlyBuffer::new(&mut buffer), &mut noise_wrapper)?;
1103    if buffer.len() == 3 {
1104        buffer.clear();
1105        buffer.push(HandshakeMessage::Final as u8)?;
1106    }
1107    Ok(())
1108}
1109
1110fn insert_varlen_frame(buffer: &mut impl Buffer, frame: PayloadFrame, payload: &[u8]) -> Result<(), Error> {
1111    buffer.push(frame as u8)?;
1112    let mut len_buffer = VarLengthPrefixBuffer::new(buffer, payload.len())?;
1113    len_buffer.extend_from_slice(payload)?;
1114    Ok(())
1115}
1116
1117fn insert_application_payload(buffer: &mut impl Buffer, noise: &mut impl NoiseHandshake, driver: &mut impl HandshakeDriver, position: Option<u8>) -> Result<(), Error> {
1118    let mut buffer = AppendOnlyBuffer::new(buffer);
1119    buffer.push(PayloadFrame::ApplicationPayload as u8)?;
1120    let mut noise_wrapper = NoiseHandshakeWrapper::wrap_payload(noise, position, None);
1121    driver.write_noise_payload(&mut AppendOnlyBuffer::new(&mut buffer), &mut noise_wrapper)?;
1122    if buffer.len() == 1 {
1123        buffer.clear();
1124    }
1125    Ok(())
1126}
1127
1128fn insert_padding(buffer: &mut impl Buffer, len: usize) -> Result<(), Error> {
1129    for _ in 0..len {
1130        buffer.push(0)?;
1131    }
1132    Ok(())
1133}