1use std::{iter::once, mem, ops::Deref};
2
3use rand_core::{CryptoRng, RngCore};
4
5use crate::{buffer::{AppendOnlyBuffer, Buffer, BufferFullError, MaxLenBuffer, VarIntSize, VarLengthPrefixBuffer}, crypto::{CryptoBackend, CryptoError, NoiseHandshake, SecretKeySetup, SymmetricKey, TransportCrypto}, customization::{HandshakeConfig, HandshakeDriver, HandshakeInfo}, Error};
6
7impl From<BufferFullError> for Error {
8 fn from(_: BufferFullError) -> Self {
9 Self::BufferSize
10 }
11}
12
13impl From<CryptoError> for Error {
14 fn from(value: CryptoError) -> Self {
15 match value {
16 CryptoError::DecryptionFailed => Error::HandshakeFailed,
17 _ => Error::Internal,
18 }
19 }
20}
21
/// Version of the Hyphae handshake.
#[non_exhaustive]
#[derive(Clone, Copy, Debug)]
pub enum HandshakeVersion {
    /// The first (and, in this file, only) version.
    Version1,
}

impl HandshakeVersion {
    /// Version-specific label placed at the very start of the Noise prologue.
    pub fn label(self) -> &'static [u8] {
        match self {
            Self::Version1 => b"hyphae-h-v1".as_slice(),
        }
    }

    /// Version byte expected at the start of compound handshake payloads.
    pub fn id(self) -> u8 {
        match self {
            Self::Version1 => 1u8,
        }
    }
}
41
// Domain-separation labels. A reference in a `const` is already `'static`, so the
// explicit lifetime is redundant (clippy `redundant_static_lifetimes`).
// Except for `HYPHAE_KEY_ASK_LABEL`, these are used outside this section of the
// file; the descriptions below follow the constant names.

/// Label for the initial secret.
pub const HYPHAE_INITIAL_SECRET_HKDF_LABEL: &[u8] = b"hyphae initial";
/// Label for the retry secret.
pub const HYPHAE_RETRY_SECRET_HKDF_LABEL: &[u8] = b"hyphae retry";
/// Label passed to the Noise state's `get_ask` when exporting each next-level secret.
pub const HYPHAE_KEY_ASK_LABEL: &[u8] = b"hyphae key";
/// Label for the initiator's data key.
pub const HYPHAE_INIT_DATA_HKDF_LABEL: &[u8] = b"init data";
/// Label for the responder's data key.
pub const HYPHAE_RESP_DATA_HKDF_LABEL: &[u8] = b"resp data";
/// Label for the initiator's header-protection key.
pub const HYPHAE_INIT_HP_HKDF_LABEL: &[u8] = b"init hp";
/// Label for the responder's header-protection key.
pub const HYPHAE_RESP_HP_HKDF_LABEL: &[u8] = b"resp hp";
49
/// Heap-allocating state machine for one side of a Hyphae handshake.
///
/// Owns a boxed Noise handshake created by the backend `B` (reached through `R`)
/// and a boxed `HandshakeDriver` `T` that reads and writes the application-level
/// payloads carried in the Hyphae messages.
#[cfg(feature = "alloc")]
pub struct AllocHyphaeHandshake<T, B, R>
where
    T: HandshakeDriver,
    B: CryptoBackend,
    R: Deref<Target = B>,
{
    /// Backend used to create the Noise state and derive transport crypto from it.
    crypto: R,
    /// Current position in the Hyphae message sequence, tagged with this side's role.
    phase: AllocHyphaeHandshakePhase,
    /// Application-level handshake logic.
    handshake_driver: Box<T>,
    /// The underlying Noise handshake state.
    noise_handshake: Box<B::NoiseHandshake>,
    /// Transport parameters received from the peer, once they have been read.
    peer_transport_params: Option<Vec<u8>>,
    /// Whether the peer signalled 0-RTT acceptance, once known.
    peer_zero_rtt_accepted: Option<bool>,
    /// Set when a new level secret is available through `next_level_secret`.
    next_level_secret_ready: bool,
}
60
61#[cfg(feature = "alloc")]
62impl <T: HandshakeDriver, B: CryptoBackend, R: Deref<Target = B>> AllocHyphaeHandshake<T, B, R> {
63 pub fn new_initiator<C> (handshake_config: &C, crypto: R, version: HandshakeVersion, transport_label: &[u8], transport_params: Vec<u8>, server_name: &str) -> Result<Self, Error>
64 where
65 C: HandshakeConfig<Driver = T>,
66 {
67 let mut preamble = Vec::new();
68 handshake_config.initiator_preamble(&mut preamble)?;
69
70 let mut noise_handshake = Box::new(crypto.new_handshake()?);
71 let mut noise_wrapper = NoiseHandshakeWrapper::wrap_init(noise_handshake.as_mut(), version, transport_label, &preamble, true);
72 let handshake_driver = Box::new(handshake_config.new_initiator(server_name, &mut noise_wrapper)?);
73
74 if noise_handshake.is_reset() {
75 return Err(Error::Internal);
76 }
77
78 let phase = if preamble.is_empty() {
79 AllocHyphaeHandshakePhase::Initiator(
80 AllocHyphaeInitiatorPhase::WriteInitiatorConfigNoise { transport_params }
81 )
82 } else {
83 AllocHyphaeHandshakePhase::Initiator(
84 AllocHyphaeInitiatorPhase::WritePreamble { preamble, transport_params }
85 )
86 };
87
88 Ok(Self {
89 crypto,
90 phase,
91 handshake_driver,
92 noise_handshake,
93 peer_transport_params: None,
94 peer_zero_rtt_accepted: None,
95 next_level_secret_ready: false,
96 })
97 }
98
    /// Creates the responder side of a handshake from the peer's first message.
    ///
    /// If `first_message` is a `Preamble` message, its payload is passed to the
    /// config's `new_responder` and to the Noise wrapper. The peer's `Initial`
    /// message is then expected in a later `read_message` call. Otherwise there is
    /// no preamble, and `first_message` is processed immediately as the
    /// initiator's `Initial` message.
    ///
    /// Returns `Error::Internal` if the config's driver left the Noise state reset.
    pub fn new_responder<C> (handshake_config: &C, crypto: R, version: HandshakeVersion, transport_label: &[u8], transport_params: Vec<u8>, mut first_message: Vec<u8>) -> Result<Self, Error>
    where
        C: HandshakeConfig<Driver = T>,
    {
        let mut noise_handshake = Box::new(crypto.new_handshake()?);

        // `preamble` borrows from `first_message` when a preamble was sent.
        let preamble = if MessageReader::decode_message_type(&first_message)? == HandshakeMessage::Preamble {
            let reader = MessageReader::decode_in_place(&mut first_message, HandshakeMessage::Preamble, noise_handshake.as_mut())?;
            reader.payload()?
        } else {
            &[]
        };

        let mut noise_wrapper = NoiseHandshakeWrapper::wrap_init(noise_handshake.as_mut(), version, transport_label, preamble, false);
        let handshake_driver = Box::new(handshake_config.new_responder(preamble, &mut noise_wrapper)?);

        // The driver must have initialized the Noise state through the wrapper.
        if noise_handshake.is_reset() {
            return Err(Error::Internal);
        }

        let phase = AllocHyphaeHandshakePhase::Responder(
            AllocHyphaeResponderPhase::ReadInitiatorConfigNoise { transport_params }
        );

        let mut this = Self {
            crypto,
            phase,
            handshake_driver,
            noise_handshake,
            peer_transport_params: None,
            peer_zero_rtt_accepted: None,
            next_level_secret_ready: false,
        };

        // A Preamble message must have a non-empty payload (decode_in_place rejects
        // an empty one), so an empty `preamble` means `first_message` was not a
        // preamble and is the Initial. This is the last use of `preamble`, which
        // lets `first_message` be moved here.
        if preamble.is_empty() {
            this.read_message(first_message)?
        }

        Ok(this)
    }
139
140 pub fn peer_params(&self) -> Option<&[u8]> {
141 self.peer_transport_params.as_ref().map(Vec::as_slice)
142 }
143
144 pub fn is_handshake_finished(&self) -> bool {
153 match self.phase {
154 AllocHyphaeHandshakePhase::Initiator(AllocHyphaeInitiatorPhase::SendFinal { .. }) => true,
155 AllocHyphaeHandshakePhase::Responder(AllocHyphaeResponderPhase::SendFinal { .. }) => true,
156 AllocHyphaeHandshakePhase::Initiator(AllocHyphaeInitiatorPhase::RecvFinal) => true,
157 AllocHyphaeHandshakePhase::Responder(AllocHyphaeResponderPhase::RecvFinal) => true,
158 AllocHyphaeHandshakePhase::Initiator(AllocHyphaeInitiatorPhase::Finalized) => true,
159 AllocHyphaeHandshakePhase::Responder(AllocHyphaeResponderPhase::Finalized) => true,
160 _ => false,
161 }
162 }
163
164 pub fn is_handshake_finalized(&self) -> bool {
169 match self.phase {
171 AllocHyphaeHandshakePhase::Initiator(AllocHyphaeInitiatorPhase::Finalized) => true,
172 AllocHyphaeHandshakePhase::Responder(AllocHyphaeResponderPhase::Finalized) => true,
173 _ => false,
174 }
175 }
176
177 pub fn is_initiator(&self) -> bool {
178 self.noise_handshake.is_initiator()
179 }
180
181 pub fn remote_public(&self) -> Option<&[u8]> {
182 self.noise_handshake.remote_public()
183 }
184
185 pub fn final_handshake_hash(&self) -> Option<&[u8]> {
186 match self.noise_handshake.is_finished() {
187 true => Some(self.noise_handshake.handshake_hash()),
188 false => None,
189 }
190 }
191
192 pub fn handshake_driver(&self) -> &T {
193 &self.handshake_driver
194 }
195
196 pub fn zero_rtt_accepted(&self) -> Option<bool> {
197 self.peer_zero_rtt_accepted
198 }
199
200 pub fn next_level_secret_ready(&self) -> bool {
201 self.next_level_secret_ready
202 }
203
204 pub fn next_level_secret(&mut self, level_secret: &mut SymmetricKey) -> Result<(), Error> {
205 if self.next_level_secret_ready {
206 self.noise_handshake.get_ask(HYPHAE_KEY_ASK_LABEL, level_secret)?;
207 self.next_level_secret_ready = false;
208 Ok(())
209 } else {
210 Err(Error::Internal)
211 }
212 }
213
214 pub fn transport_crypto(&self) -> Result<B::TransportCrypto, Error> {
215 Ok(self.crypto.transport_crypto(&self.noise_handshake)?)
216 }
217
218 pub fn export_1rtt_rekey(&mut self, rekey: &mut B::TransportRekey) -> Result<(), Error> {
219 Ok(self.crypto.export_1rtt_rekey(&mut self.noise_handshake, rekey)?)
220 }
221
222 pub fn read_message(&mut self, message: Vec<u8>) -> Result<(), Error> {
223 match self.phase {
224 AllocHyphaeHandshakePhase::Initiator(_) => self.initiator_read_message(message),
225 AllocHyphaeHandshakePhase::Responder(_) => self.responder_read_message(message),
226 }
227 }
228
229 pub fn write_message(&mut self, buffer: &mut impl Buffer) -> Result<(), Error> {
230 match self.phase {
231 AllocHyphaeHandshakePhase::Initiator(_) => self.initiator_write_message(buffer),
232 AllocHyphaeHandshakePhase::Responder(_) => self.responder_write_message(buffer),
233 }
234 }
235
    /// Writes the initiator's next outgoing message into `buffer`.
    ///
    /// Phases with nothing to send (waiting for the peer, or a `Noise` phase when
    /// it is not our turn) write nothing and return `Ok(())`.
    fn initiator_write_message(&mut self, buffer: &mut impl Buffer) -> Result<(), Error> {
        let AllocHyphaeHandshakePhase::Initiator(ref mut phase) = self.phase else {
            unreachable!();
        };

        match phase {
            // The preamble goes out in its own message, before any Noise message.
            AllocHyphaeInitiatorPhase::WritePreamble { preamble, transport_params } => {
                write_preamble(buffer, preamble.as_slice())?;
                *phase = AllocHyphaeInitiatorPhase::WriteInitiatorConfigNoise {
                    transport_params: mem::take(transport_params),
                };
                Ok(())
            },

            // First Noise message (position 1). Moving the params out ends the
            // mutable borrow of `phase`, which `message_position` needs.
            // NOTE(review): if the write fails, the params are already gone from `phase`.
            AllocHyphaeInitiatorPhase::WriteInitiatorConfigNoise { transport_params } => {
                let transport_params = mem::take(transport_params);
                write_initiator_initial(buffer, self.noise_handshake.as_mut(), transport_params.as_slice(), self.handshake_driver.as_mut(), phase.message_position()?)?;
                *phase = AllocHyphaeInitiatorPhase::ReadResponderConfigNoise;
                Ok(())
            },

            AllocHyphaeInitiatorPhase::Noise { .. } if self.noise_handshake.is_my_turn() => self.write_noise_message(buffer),

            AllocHyphaeInitiatorPhase::SendFinal { received_final } => {
                // The pending level secret must be fetched before the final message is sent.
                if self.next_level_secret_ready {
                    return Err(Error::Internal);
                }

                write_final(buffer, self.noise_handshake.as_mut(), self.handshake_driver.as_mut())?;
                // Finished only if the peer's final message already arrived.
                match received_final {
                    true => *phase = AllocHyphaeInitiatorPhase::Finalized,
                    false => *phase = AllocHyphaeInitiatorPhase::RecvFinal,
                }
                Ok(())
            },

            _ => Ok(())
        }
    }
276
    /// Processes one message received by the responder.
    ///
    /// Returns `Error::HandshakeFailed` for a message that is not expected in the
    /// current phase.
    fn responder_read_message(&mut self, mut message: Vec<u8>) -> Result<(), Error> {
        let AllocHyphaeHandshakePhase::Responder(ref mut phase) = self.phase else {
            unreachable!();
        };

        match phase {
            // The initiator's Initial: Noise message at position 1.
            AllocHyphaeResponderPhase::ReadInitiatorConfigNoise { transport_params } => {
                // NOTE(review): the params are moved out before decoding, so a failed
                // decode leaves this phase holding an empty parameter list.
                let transport_params = mem::take(transport_params);

                // Handshake hash from before this message is decoded, handed to the driver.
                let prev_hash = self.noise_handshake.handshake_hash().to_vec();
                let reader = MessageReader::decode_in_place(&mut message, HandshakeMessage::Initial, self.noise_handshake.as_mut())?;
                let (peer_transport_params, app_payload) = reader.initial_init_payloads()?;
                self.peer_transport_params = Some(peer_transport_params.to_vec());

                let mut noise_wrapper = NoiseHandshakeWrapper::wrap_payload(self.noise_handshake.as_mut(),Some(phase.message_position()?), Some(&prev_hash));
                self.handshake_driver.read_noise_payload(app_payload, &mut noise_wrapper)?;

                *phase = AllocHyphaeResponderPhase::WriteResponderConfigNoise {
                    transport_params,
                };

                Ok(())
            },

            AllocHyphaeResponderPhase::Noise { .. } if !self.noise_handshake.is_my_turn() => self.read_noise_message(message),

            // The peer's final message may arrive before or after ours is sent.
            AllocHyphaeResponderPhase::SendFinal { received_final: false } |
            AllocHyphaeResponderPhase::RecvFinal => {
                // The pending level secret must be fetched first.
                if self.next_level_secret_ready {
                    return Err(Error::Internal);
                }

                let prev_hash = self.noise_handshake.handshake_hash().to_vec();
                let reader = MessageReader::decode_in_place(&mut message, HandshakeMessage::Final, self.noise_handshake.as_mut())?;
                let final_payload = reader.final_payload()?;
                let mut noise_wrapper = NoiseHandshakeWrapper::wrap_payload(self.noise_handshake.as_mut(), None, Some(&prev_hash));
                self.handshake_driver.read_final_payload(final_payload, &mut noise_wrapper)?;

                // Still owe our own final message, or done.
                match phase {
                    AllocHyphaeResponderPhase::SendFinal { received_final: false } => {
                        *phase = AllocHyphaeResponderPhase::SendFinal { received_final: true }
                    },
                    _ => *phase = AllocHyphaeResponderPhase::Finalized,
                }

                Ok(())
            },

            _ => Err(Error::HandshakeFailed)
        }
    }
328
329
    /// Writes the responder's next outgoing message into `buffer`.
    ///
    /// Phases with nothing to send write nothing and return `Ok(())`.
    fn responder_write_message(&mut self, buffer: &mut impl Buffer) -> Result<(), Error> {
        let AllocHyphaeHandshakePhase::Responder(ref mut phase) = self.phase else {
            unreachable!();
        };

        match phase {
            // Responder's Initial (Noise message at position 2). The deferred payload
            // (built from the transport params and the driver) is built first. Only its
            // hash goes into the Initial; the payload itself is sent by the next write.
            AllocHyphaeResponderPhase::WriteResponderConfigNoise { transport_params } => {
                let transport_params = mem::take(transport_params);
                let mut deferred_payload = Vec::new();
                // NOTE(review): the `false` argument is presumably the 0-RTT-accepted flag; confirm.
                write_responder_deferred_payload(&mut deferred_payload, self.noise_handshake.as_mut(), transport_params.as_slice(), false, self.handshake_driver.as_mut(), phase.message_position()?)?;

                // Hash everything after the first byte. The initiator hashes the body
                // that decode_in_place leaves after stripping the message-type byte.
                let crypto = self.crypto.transport_crypto(&self.noise_handshake)?;
                let mut deferred_payload_hash = crypto.zeros_hash();
                crypto.hash_into(&deferred_payload[1..], &mut deferred_payload_hash);

                write_responder_initial(buffer, self.noise_handshake.as_mut(), &crypto.hash_as_slice(&deferred_payload_hash))?;

                *phase = AllocHyphaeResponderPhase::WriteResponderDeferredPayload {
                    deferred_payload,
                };

                // A new level secret is available once the Initial has been written.
                self.next_level_secret_ready = true;

                Ok(())
            },

            // The pending level secret must be fetched before the deferred payload is sent.
            AllocHyphaeResponderPhase::WriteResponderDeferredPayload { deferred_payload } => {
                if self.next_level_secret_ready {
                    return Err(Error::Internal);
                }

                buffer.extend_from_slice(&deferred_payload)?;
                *phase = AllocHyphaeResponderPhase::Noise { position: 3 };
                self.check_noise_finished()
            },

            AllocHyphaeResponderPhase::Noise { .. } if self.noise_handshake.is_my_turn() => self.write_noise_message(buffer),

            AllocHyphaeResponderPhase::SendFinal { received_final } => {
                if self.next_level_secret_ready {
                    return Err(Error::Internal);
                }

                write_final(buffer, self.noise_handshake.as_mut(), self.handshake_driver.as_mut())?;
                // Finished only if the peer's final message already arrived.
                match received_final {
                    true => *phase = AllocHyphaeResponderPhase::Finalized,
                    false => *phase = AllocHyphaeResponderPhase::RecvFinal,
                }
                Ok(())
            },

            _ => Ok(())
        }
    }
387
    /// Processes one message received by the initiator.
    ///
    /// Returns `Error::HandshakeFailed` for a message that is not expected in the
    /// current phase, or for a deferred payload whose hash does not match.
    fn initiator_read_message(&mut self, mut message: Vec<u8>) -> Result<(), Error> {
        let AllocHyphaeHandshakePhase::Initiator(ref mut phase) = self.phase else {
            unreachable!();
        };

        match phase {
            // Responder's Initial (Noise message 2). It carries only the hash of the
            // deferred payload, which arrives in the next message.
            AllocHyphaeInitiatorPhase::ReadResponderConfigNoise => {
                let prev_noise_hash = self.noise_handshake.handshake_hash().to_vec();
                let reader = MessageReader::decode_in_place(&mut message, HandshakeMessage::Initial, self.noise_handshake.as_mut())?;
                let deferred_payload_hash = reader.initial_resp_payloads()?;

                *phase = AllocHyphaeInitiatorPhase::ReadResponderDeferredPayload {
                    deferred_payload_hash: deferred_payload_hash.to_vec(),
                    prev_noise_hash,
                };

                // A new level secret is available once the responder's Initial is read.
                self.next_level_secret_ready = true;
                Ok(())
            },

            AllocHyphaeInitiatorPhase::ReadResponderDeferredPayload { deferred_payload_hash, prev_noise_hash } => {
                // The pending level secret must be fetched first.
                if self.next_level_secret_ready {
                    return Err(Error::Internal);
                }

                let prev_noise_hash = mem::take(prev_noise_hash);

                let reader = MessageReader::decode_in_place(&mut message, HandshakeMessage::DeferredPayload, self.noise_handshake.as_mut())?;

                // The deferred payload is not Noise-encrypted; it is authenticated by
                // comparing its hash with the one from the encrypted Initial.
                let crypto = self.crypto.transport_crypto(&self.noise_handshake)?;
                let mut actual_payload_hash = crypto.zeros_hash();
                crypto.hash_into(reader.payload, &mut actual_payload_hash);
                if deferred_payload_hash.as_slice() != crypto.hash_as_slice(&actual_payload_hash) {
                    return Err(Error::HandshakeFailed);
                }

                let (peer_params, zero_rtt_acc, app_payload) = reader.deferred_resp_payloads()?;
                self.peer_zero_rtt_accepted = Some(zero_rtt_acc);

                // The payload belongs to Noise message 2, so the driver sees position 2
                // and the hash from before the responder's Initial was decoded.
                let mut noise_wrapper = NoiseHandshakeWrapper::wrap_payload(self.noise_handshake.as_mut(), Some(phase.message_position()?), Some(&prev_noise_hash));
                self.handshake_driver.read_noise_payload(app_payload, &mut noise_wrapper)?;

                self.peer_transport_params = Some(peer_params.to_vec());
                *phase = AllocHyphaeInitiatorPhase::Noise { position: 3 };
                self.check_noise_finished()
            },

            AllocHyphaeInitiatorPhase::Noise { .. } if !self.noise_handshake.is_my_turn() => self.read_noise_message(message),

            // The peer's final message may arrive before or after ours is sent.
            AllocHyphaeInitiatorPhase::SendFinal { received_final: false } |
            AllocHyphaeInitiatorPhase::RecvFinal => {
                if self.next_level_secret_ready {
                    return Err(Error::Internal);
                }

                let reader = MessageReader::decode_in_place(&mut message, HandshakeMessage::Final, self.noise_handshake.as_mut())?;
                let final_payload = reader.final_payload()?;
                // No previous hash is passed; the wrapper then reports the current handshake hash.
                let mut noise_wrapper = NoiseHandshakeWrapper::wrap_payload(self.noise_handshake.as_mut(), None, None);
                self.handshake_driver.read_final_payload(final_payload, &mut noise_wrapper)?;

                // Still owe our own final message, or done.
                match phase {
                    AllocHyphaeInitiatorPhase::SendFinal { received_final: false } => {
                        *phase = AllocHyphaeInitiatorPhase::SendFinal { received_final: true }
                    },
                    _ => *phase = AllocHyphaeInitiatorPhase::Finalized,
                }

                Ok(())
            },

            _ => Err(Error::HandshakeFailed),
        }
    }
461
462 fn write_noise_message(&mut self, buffer: &mut impl Buffer) -> Result<(), Error> {
463 write_noise(buffer, self.noise_handshake.as_mut(), self.handshake_driver.as_mut(), self.phase.message_position()?)?;
464 self.phase.advance_message_position()?;
465 self.check_noise_finished()
466 }
467
468 fn read_noise_message(&mut self, mut message: Vec<u8>) -> Result<(), Error> {
469 let prev_hash = self.noise_handshake.handshake_hash().to_vec();
470 let reader = MessageReader::decode_in_place(&mut message, HandshakeMessage::Noise, self.noise_handshake.as_mut())?;
471 let mut noise_wrapper = NoiseHandshakeWrapper::wrap_payload(self.noise_handshake.as_mut(), Some(self.phase.message_position()?), Some(&prev_hash));
472 self.handshake_driver.read_noise_payload(reader.payload()?, &mut noise_wrapper)?;
473 self.phase.advance_message_position()?;
474 self.check_noise_finished()
475 }
476
477 fn check_noise_finished(&mut self) -> Result<(), Error> {
478 match &mut self.phase {
479 AllocHyphaeHandshakePhase::Initiator(AllocHyphaeInitiatorPhase::Noise { .. }) => {
480 if self.noise_handshake.is_finished() {
481 self.phase = AllocHyphaeHandshakePhase::Initiator(AllocHyphaeInitiatorPhase::SendFinal { received_final: false });
482 }
483 },
484 AllocHyphaeHandshakePhase::Responder(AllocHyphaeResponderPhase::Noise { .. }) => {
485 if self.noise_handshake.is_finished() {
486 self.phase = AllocHyphaeHandshakePhase::Responder(AllocHyphaeResponderPhase::SendFinal { received_final: false });
487 }
488 },
489 _ => return Err(Error::Internal)
490 }
491
492 if self.noise_handshake.is_finished() {
493 self.next_level_secret_ready = true;
494 }
495 Ok(())
496 }
497}
498
499enum AllocHyphaeHandshakePhase {
500 Initiator (AllocHyphaeInitiatorPhase),
501 Responder (AllocHyphaeResponderPhase),
502}
503
504impl AllocHyphaeHandshakePhase {
505 pub fn message_position(&self) -> Result<u8, Error> {
506 match self {
507 AllocHyphaeHandshakePhase::Initiator(phase) => phase.message_position(),
508 AllocHyphaeHandshakePhase::Responder(phase) => phase.message_position(),
509 }
510 }
511
512 pub fn advance_message_position(&mut self) -> Result<(), Error> {
513 let position = match self {
514 AllocHyphaeHandshakePhase::Initiator(AllocHyphaeInitiatorPhase::Noise { position }) => position,
515 AllocHyphaeHandshakePhase::Responder(AllocHyphaeResponderPhase::Noise { position }) => position,
516 _ => return Err(Error::Internal)
517 };
518 *position = position.checked_add(1).ok_or(Error::Internal)?;
519 Ok(())
520 }
521}
522
523enum AllocHyphaeInitiatorPhase {
524 WritePreamble {
525 preamble: Vec<u8>,
526 transport_params: Vec<u8>,
527 },
528 WriteInitiatorConfigNoise {
529 transport_params: Vec<u8>,
530 },
531 ReadResponderConfigNoise,
532 ReadResponderDeferredPayload {
533 deferred_payload_hash: Vec<u8>,
534 prev_noise_hash: Vec<u8>,
535 },
536 Noise {
537 position: u8,
538 },
539 SendFinal {
540 received_final: bool,
541 },
542 RecvFinal,
543 Finalized,
544}
545
546impl AllocHyphaeInitiatorPhase {
547 pub fn message_position(&self) -> Result<u8, Error> {
548 match self {
549 AllocHyphaeInitiatorPhase::WriteInitiatorConfigNoise { .. } => Ok(1),
550 AllocHyphaeInitiatorPhase::ReadResponderDeferredPayload { .. } => Ok(2),
551 AllocHyphaeInitiatorPhase::Noise { position } => Ok(*position),
552 _ => Err(Error::Internal),
553 }
554 }
555}
556
557enum AllocHyphaeResponderPhase {
558 ReadInitiatorConfigNoise {
559 transport_params: Vec<u8>,
560 },
561 WriteResponderConfigNoise {
562 transport_params: Vec<u8>,
563 },
564 WriteResponderDeferredPayload {
565 deferred_payload: Vec<u8>,
566 },
567 Noise {
568 position: u8,
569 },
570 SendFinal {
571 received_final: bool,
572 },
573 RecvFinal,
574 Finalized,
575}
576
577impl AllocHyphaeResponderPhase {
578 pub fn message_position(&self) -> Result<u8, Error> {
579 match self {
580 AllocHyphaeResponderPhase::ReadInitiatorConfigNoise { .. } => Ok(1),
581 AllocHyphaeResponderPhase::WriteResponderConfigNoise { .. } => Ok(2),
582 AllocHyphaeResponderPhase::Noise { position } => Ok(*position),
583 _ => Err(Error::Internal),
584 }
585 }
586}
587
588#[repr(u8)]
589#[derive(Clone, Copy, PartialEq, Eq, Debug)]
590enum HandshakeMessage {
591 Preamble = 1,
592 Initial = 2,
593 DeferredPayload = 3,
594 Noise = 4,
595 FinalPayload = 126,
596 Final = 127,
597 Failed = 255,
598}
599
600impl HandshakeMessage {
601 pub fn from_id(id: u8) -> Result<Self, Error> {
602 match id {
603 x if x == Self::Preamble as u8 => Ok(Self::Preamble),
604 x if x == Self::Initial as u8 => Ok(Self::Initial),
605 x if x == Self::DeferredPayload as u8 => Ok(Self::DeferredPayload),
606 x if x == Self::Noise as u8 => Ok(Self::Noise),
607 x if x == Self::Final as u8 => Ok(Self::Final),
608 x if x == Self::FinalPayload as u8 => Ok(Self::FinalPayload),
609 x if x == Self::Failed as u8 => Ok(Self::Failed),
610 _ => Err(Error::HandshakeFailed)
611 }
612 }
613
614 pub fn is_encrypted(self) -> bool {
615 match self {
616 HandshakeMessage::Initial |
617 HandshakeMessage::Noise => true,
618 _ => false,
619 }
620 }
621
622 pub fn has_compound_payload(self) -> bool {
623 match self {
624 HandshakeMessage::Initial |
625 HandshakeMessage::DeferredPayload |
626 HandshakeMessage::FinalPayload => true,
627 _ => false,
628 }
629 }
630
631 pub fn has_payload(self) -> Option<bool> {
632 match self {
633 HandshakeMessage::Preamble => Some(true),
634 HandshakeMessage::Initial => Some(true),
635 HandshakeMessage::DeferredPayload => Some(true),
636 HandshakeMessage::Noise => None,
637 HandshakeMessage::FinalPayload => Some(true),
638 HandshakeMessage::Final => Some(false),
639 HandshakeMessage::Failed => Some(false),
640 }
641 }
642}
643
644#[repr(u8)]
645#[derive(Clone, Copy, PartialEq, Eq, Debug)]
646enum PayloadFrame {
647 Padding = 0,
648 ApplicationPayload = 1,
649 TransportParameters = 2,
650 DeferredPayloadHash = 3,
651 ZeroRttAccepted = 64,
652}
653
654impl PayloadFrame {
655 const OPTIONAL_BASE: u8 = 128;
665
666 fn ok_in(self, message: HandshakeMessage, from_initiator: bool) -> Result<(), Error> {
667 match (message, from_initiator, self) {
668 (HandshakeMessage::Initial, true, Self::Padding) => Ok(()),
669 (HandshakeMessage::Initial, true, Self::TransportParameters) => Ok(()),
670 (HandshakeMessage::Initial, true, Self::ApplicationPayload) => Ok(()),
671 (HandshakeMessage::Initial, false, Self::Padding) => Ok(()),
672 (HandshakeMessage::Initial, false, Self::DeferredPayloadHash) => Ok(()),
673 (HandshakeMessage::DeferredPayload, false, Self::Padding) => Ok(()),
674 (HandshakeMessage::DeferredPayload, false, Self::TransportParameters) => Ok(()),
675 (HandshakeMessage::DeferredPayload, false, Self::ZeroRttAccepted) => Ok(()),
676 (HandshakeMessage::DeferredPayload, false, Self::ApplicationPayload) => Ok(()),
677 (HandshakeMessage::FinalPayload, false, Self::ApplicationPayload) => Ok(()),
678 _ => Err(Error::HandshakeFailed)
679 }
680 }
681
682 fn from_id(frame_id: u8) -> Result<Option<Self>, Error> {
683 match frame_id {
684 id if id == Self::Padding as u8 => Ok(Some(Self::Padding)),
685 id if id == Self::ApplicationPayload as u8 => Ok(Some(Self::ApplicationPayload)),
686 id if id == Self::TransportParameters as u8 => Ok(Some(Self::TransportParameters)),
687 id if id == Self::DeferredPayloadHash as u8 => Ok(Some(Self::DeferredPayloadHash)),
688 id if id == Self::ZeroRttAccepted as u8 => Ok(Some(Self::ZeroRttAccepted)),
689 id if id >= Self::OPTIONAL_BASE => Ok(None),
690 _ => Err(Error::HandshakeFailed),
691 }
692 }
693
694 fn get_frame_payload(this: Option<Self>, mut remaining: &[u8]) -> Result<(&[u8], &[u8]), Error> {
695 let payload_len = match this {
696 Some(Self::ApplicationPayload) => remaining.len(),
697 Some(Self::DeferredPayloadHash) => remaining.len(),
698 Some(Self::Padding) => 0,
699 Some(Self::ZeroRttAccepted) => 0,
700 None | Some(Self::TransportParameters) => {
701 let prefix_len = VarIntSize::from_msb(remaining.get(0).copied().ok_or(Error::HandshakeFailed)?);
703 if remaining.len() < prefix_len.len() {
704 return Err(Error::HandshakeFailed);
705 }
706 let (prefix, r) = remaining.split_at(prefix_len.len());
707 remaining = r;
708 let mut prefix64 = [0u8; 8];
709 prefix64[8 - prefix.len()..].copy_from_slice(prefix);
710 prefix64[8 - prefix.len()] &= !0xC0;
711 u64::from_be_bytes(prefix64).try_into().map_err(|_| Error::HandshakeFailed)?
712 }
713 };
714 if payload_len > remaining.len() {
715 return Err(Error::HandshakeFailed)
716 }
717 Ok(remaining.split_at(payload_len))
718 }
719
720 pub fn next_frame(remaining: &[u8], message: HandshakeMessage, from_initiator: bool) -> Result<Option<(Self, &[u8], &[u8])>, Error> {
721 let Some(frame_id) = remaining.get(0).cloned() else {
722 return Ok(None);
723 };
724
725 let frame_type = Self::from_id(frame_id)?;
726 if let Some(frame_type) = frame_type {
727 frame_type.ok_in(message, from_initiator)?;
728 }
729 let (frame_payload, remaining) = Self::get_frame_payload(frame_type, &remaining[1..])?;
730
731 match frame_type {
732 Some(frame_type) if frame_type != Self::Padding =>
733 Ok(Some((frame_type, frame_payload, remaining))),
734
735 _ => Self::next_frame(remaining, message, from_initiator), }
737 }
738
739}
740
741struct NoiseHandshakeWrapper<'a, X: NoiseHandshake> {
742 inner: &'a mut X,
743 init_info: Option<(HandshakeVersion, &'a [u8], &'a [u8])>,
744 initiator: Option<bool>,
745 position: Option<u8>,
746 prev_hash: Option<&'a [u8]>,
747}
748
749impl <'a, X: NoiseHandshake> NoiseHandshakeWrapper<'a, X> {
750 pub fn wrap_init(inner: &'a mut X, version: HandshakeVersion, transport_label: &'a [u8], preamble: &'a [u8], initiator: bool) -> Self {
751 Self {
752 inner,
753 init_info: Some((version, transport_label, preamble)),
754 initiator: Some(initiator),
755 position: None,
756 prev_hash: None,
757 }
758 }
759
760 pub fn wrap_payload(inner: &'a mut X, position: Option<u8>, prev_hash: Option<&'a [u8]>) -> Self {
761 Self {
762 inner,
763 init_info: None,
764 initiator: None,
765 position,
766 prev_hash,
767 }
768 }
769}
770
771impl <X: NoiseHandshake> HandshakeInfo for NoiseHandshakeWrapper<'_, X> {
772 fn initialize(&mut self, rng: &mut (impl CryptoRng + RngCore), protocol: &str, prologue: &[u8], s: Option<SecretKeySetup>, rs: Option<&[u8]>) -> Result<(), CryptoError> {
773 let Some(initiator) = self.initiator else {
774 return Err(CryptoError::Internal);
775 };
776 let Some((version, transport_label, preamble)) = self.init_info else {
777 return Err(CryptoError::Internal);
778 };
779 let Ok(preamble_len) = u16::try_from(preamble.len()) else {
780 return Err(CryptoError::Internal);
781 };
782 let preamble_len_le = preamble_len.to_le_bytes();
783
784 if !self.inner.is_reset() {
785 return Err(CryptoError::Internal);
786 }
787
788 let handshake_prologue =
789 once(version.label())
790 .chain(once(b".".as_slice()))
791 .chain(once(transport_label))
792 .chain(once(b".".as_slice()))
793 .chain(once(preamble_len_le.as_slice()))
794 .chain(once(preamble))
795 .chain(once(prologue));
796
797 self.inner.initialize(rng, protocol, initiator, handshake_prologue, s, rs)
798 }
799
800 fn set_token(&mut self, _token: &str, _value: &[u8]) -> Result<(), CryptoError> {
801 Err(CryptoError::Internal)
802 }
803
804 fn is_initiator(&self) -> bool {
805 if self.inner.is_reset() {
806 self.initiator.unwrap_or_default()
807 } else {
808 self.inner.is_initiator()
809 }
810 }
811
812 fn is_finished(&self) -> bool {
813 self.inner.is_finished()
814 }
815
816 fn handshake_position(&self) -> Option<u8> {
817 self.position
818 }
819
820 fn prev_handshake_hash(&self) -> Option<&[u8]> {
821 self.prev_hash.or_else(|| Some(self.inner.handshake_hash()))
822 }
823
824 fn final_handshake_hash(&self) -> Option<&[u8]> {
825 match self.inner.is_finished() {
826 true => Some(self.inner.handshake_hash()),
827 false => None,
828 }
829 }
830}
831
832struct MessageReader<'a> {
833 payload: &'a [u8],
834 message_type: HandshakeMessage,
835}
836
837impl <'a> MessageReader<'a> {
838 pub fn decode_message_type(buffer: &[u8]) -> Result<HandshakeMessage, Error> {
839 if buffer.is_empty() {
840 return Err(Error::HandshakeFailed);
841 }
842
843 HandshakeMessage::from_id(buffer[0])
844 }
845
846 pub fn decode_in_place(buffer: &'a mut [u8], expect: HandshakeMessage, noise: &mut impl NoiseHandshake) -> Result<Self, Error> {
847 let message_type = Self::decode_message_type(buffer)?;
848 let buffer = &mut buffer[1..];
849
850 let expected = match expect {
851 HandshakeMessage::Final =>
852 message_type == HandshakeMessage::Final ||
853 message_type == HandshakeMessage::FinalPayload,
854 expect => expect == message_type,
855 };
856 if !expected {
857 return Err(Error::HandshakeFailed);
858 }
859
860 let payload = if message_type.is_encrypted() {
862 noise.read_message_in_place(buffer)?
863 } else {
864 buffer
865 };
866
867 if message_type.has_compound_payload() &&
869 (payload.is_empty() || payload[0] != HandshakeVersion::Version1.id())
870 {
871 return Err(Error::HandshakeFailed);
872 }
873
874 if let Some(has_payload) = message_type.has_payload() {
876 if has_payload == payload.is_empty() {
877 return Err(Error::HandshakeFailed);
878 }
879 }
880
881 Ok(Self {
882 payload,
883 message_type,
884 })
885 }
886
887 pub fn payload(&self) -> Result<&'a [u8], Error> {
889 if self.message_type.has_compound_payload() {
890 return Err(Error::Internal);
891 }
892 Ok(self.payload)
893 }
894
895 pub fn initial_init_payloads(&self) -> Result<(&'a [u8], &'a [u8]), Error> {
898 if self.message_type != HandshakeMessage::Initial {
899 return Err(Error::Internal);
900 }
901
902 let mut frame_cursor = &self.payload[1..];
903 let mut transport_params = None;
904 let mut application_payload = None;
905
906 loop {
907 let Some((frame, payload, remaining)) = PayloadFrame::next_frame(frame_cursor, self.message_type, true)? else {
908 break;
909 };
910 match frame {
911 PayloadFrame::ApplicationPayload if application_payload.is_some() => return Err(Error::HandshakeFailed),
912 PayloadFrame::ApplicationPayload => application_payload = Some(payload),
913 PayloadFrame::TransportParameters if transport_params.is_some() => return Err(Error::HandshakeFailed),
914 PayloadFrame::TransportParameters => transport_params = Some(payload),
915 _ => {}
916 }
917 frame_cursor = remaining;
918 }
919
920 if let Some(true) = application_payload.map(|s| s.is_empty()) {
921 return Err(Error::HandshakeFailed);
922 }
923
924 application_payload.get_or_insert(&[]);
925
926 match (transport_params, application_payload) {
927 (Some(tp), Some(ap)) => Ok((tp, ap)),
928 _ => Err(Error::HandshakeFailed)
929 }
930 }
931
932 pub fn initial_resp_payloads(&self) -> Result<&'a [u8], Error> {
935 if self.message_type != HandshakeMessage::Initial {
936 return Err(Error::Internal);
937 }
938
939 let mut frame_cursor = &self.payload[1..];
940 let mut deferred_hash = None;
941
942 loop {
943 let Some((frame, payload, remaining)) = PayloadFrame::next_frame(frame_cursor, self.message_type, false)? else {
944 break;
945 };
946 match frame {
947 PayloadFrame::DeferredPayloadHash if deferred_hash.is_some() => return Err(Error::HandshakeFailed),
948 PayloadFrame::DeferredPayloadHash => deferred_hash = Some(payload),
949 _ => {}
950 }
951 frame_cursor = remaining;
952 }
953
954 if let Some(true) = deferred_hash.map(|s| s.is_empty()) {
955 return Err(Error::HandshakeFailed);
956 }
957
958 match deferred_hash {
959 Some(dh) => Ok(dh),
960 _ => Err(Error::HandshakeFailed)
961 }
962 }
963
964 pub fn deferred_resp_payloads(&self) -> Result<(&'a [u8], bool, &'a [u8]), Error> {
967 if self.message_type != HandshakeMessage::DeferredPayload {
968 return Err(Error::Internal);
969 }
970
971 let mut frame_cursor = &self.payload[1..];
972 let mut transport_params = None;
973 let mut application_payload = None;
974 let mut zero_rtt_accepted = None;
975
976 loop {
977 let Some((frame, payload, remaining)) = PayloadFrame::next_frame(frame_cursor, self.message_type, false)? else {
978 break;
979 };
980 match frame {
981 PayloadFrame::ApplicationPayload if application_payload.is_some() => return Err(Error::HandshakeFailed),
982 PayloadFrame::ApplicationPayload => application_payload = Some(payload),
983 PayloadFrame::TransportParameters if transport_params.is_some() => return Err(Error::HandshakeFailed),
984 PayloadFrame::TransportParameters => transport_params = Some(payload),
985 PayloadFrame::ZeroRttAccepted if zero_rtt_accepted.is_some() => return Err(Error::HandshakeFailed),
986 PayloadFrame::ZeroRttAccepted => zero_rtt_accepted = Some(true),
987 _ => {}
988 }
989 frame_cursor = remaining;
990 }
991
992 if let Some(true) = application_payload.map(|s| s.is_empty()) {
993 return Err(Error::HandshakeFailed);
994 }
995
996 application_payload.get_or_insert(&[]);
997 zero_rtt_accepted.get_or_insert(false);
998
999 match (transport_params, zero_rtt_accepted, application_payload) {
1000 (Some(tp), Some(zrtt), Some(ap)) => Ok((tp, zrtt, ap)),
1001 _ => Err(Error::HandshakeFailed)
1002 }
1003 }
1004
1005 pub fn final_payload(&self) -> Result<&'a [u8], Error> {
1008 match self.message_type {
1009 HandshakeMessage::Final => return Ok(&[]),
1010 HandshakeMessage::FinalPayload => {},
1011 _ => return Err(Error::Internal),
1012 }
1013
1014 let mut frame_cursor = &self.payload[1..];
1015 let mut final_payload = None;
1016
1017 loop {
1018 let Some((frame, payload, remaining)) = PayloadFrame::next_frame(frame_cursor, self.message_type, false)? else {
1019 break;
1020 };
1021 match frame {
1022 PayloadFrame::ApplicationPayload if final_payload.is_some() => return Err(Error::HandshakeFailed),
1023 PayloadFrame::ApplicationPayload => final_payload = Some(payload),
1024 _ => {}
1025 }
1026 frame_cursor = remaining;
1027 }
1028
1029 match final_payload {
1030 Some(fp) => Ok(fp),
1031 _ => Err(Error::HandshakeFailed)
1032 }
1033 }
1034}
1035
1036fn write_preamble(buffer: &mut impl Buffer, preamble: &[u8]) -> Result<(), Error> {
1037 let mut buffer = MaxLenBuffer::new(buffer, u16::MAX as usize)?;
1038 buffer.push(HandshakeMessage::Preamble as u8)?;
1039 buffer.extend_from_slice(preamble)?;
1040 Ok(())
1041}
1042
1043fn write_initiator_initial(buffer: &mut impl Buffer, noise: &mut impl NoiseHandshake, transport_params: &[u8], driver: &mut impl HandshakeDriver, position: u8) -> Result<(), Error> {
1044 let mut buffer = MaxLenBuffer::new(buffer, u16::MAX as usize)?;
1045 let (token_padding, tag_padding) = noise.next_message_layout()?;
1046 buffer.push(HandshakeMessage::Initial as u8)?;
1047 insert_padding(&mut buffer, token_padding)?;
1048 buffer.push(HandshakeVersion::Version1.id())?;
1049 insert_varlen_frame(&mut buffer, PayloadFrame::TransportParameters, transport_params)?;
1050 insert_application_payload(&mut buffer, noise, driver, Some(position))?;
1051 insert_padding(&mut buffer, tag_padding)?;
1052 noise.write_message_in_place(&mut buffer.as_mut()[1..])?;
1053 Ok(())
1054}
1055
1056fn write_responder_initial(buffer: &mut impl Buffer, noise: &mut impl NoiseHandshake, deferred_payload_hash: &[u8]) -> Result<(), Error> {
1057 let mut buffer = MaxLenBuffer::new(buffer, u16::MAX as usize)?;
1058 let (token_padding, tag_padding) = noise.next_message_layout()?;
1059 buffer.push(HandshakeMessage::Initial as u8)?;
1060 insert_padding(&mut buffer, token_padding)?;
1061 buffer.push(HandshakeVersion::Version1.id())?;
1062 buffer.push(PayloadFrame::DeferredPayloadHash as u8)?;
1063 buffer.extend_from_slice(deferred_payload_hash)?;
1064 insert_padding(&mut buffer, tag_padding)?;
1065 noise.write_message_in_place(&mut buffer.as_mut()[1..])?;
1066 Ok(())
1067}
1068
1069fn write_responder_deferred_payload(buffer: &mut impl Buffer, noise: &mut impl NoiseHandshake, transport_params: &[u8], zero_rtt_accepted: bool, driver: &mut impl HandshakeDriver, position: u8) -> Result<(), Error> {
1070 let mut buffer = MaxLenBuffer::new(buffer, u16::MAX as usize)?;
1071 buffer.push(HandshakeMessage::DeferredPayload as u8)?;
1072 buffer.push(HandshakeVersion::Version1.id())?;
1073 insert_varlen_frame(&mut buffer, PayloadFrame::TransportParameters, transport_params)?;
1074 if zero_rtt_accepted {
1075 buffer.push(PayloadFrame::ZeroRttAccepted as u8)?;
1076 }
1077 insert_application_payload(&mut buffer, noise, driver, Some(position))?;
1078 Ok(())
1079}
1080
1081fn write_noise(buffer: &mut impl Buffer, noise: &mut impl NoiseHandshake, driver: &mut impl HandshakeDriver, position: u8) -> Result<(), Error> {
1082 let mut buffer = MaxLenBuffer::new(buffer, u16::MAX as usize)?;
1083 let (token_padding, tag_padding) = noise.next_message_layout()?;
1084 buffer.push(HandshakeMessage::Noise as u8)?;
1085 insert_padding(&mut buffer, token_padding)?;
1086
1087 let mut noise_wrapper = NoiseHandshakeWrapper::wrap_payload(noise, Some(position), None);
1088 driver.write_noise_payload(&mut AppendOnlyBuffer::new(&mut buffer), &mut noise_wrapper)?;
1089
1090 insert_padding(&mut buffer, tag_padding)?;
1091 noise.write_message_in_place(&mut buffer.as_mut()[1..])?;
1092 Ok(())
1093}
1094
1095fn write_final(buffer: &mut impl Buffer, noise: &mut impl NoiseHandshake, driver: &mut impl HandshakeDriver) -> Result<(), Error> {
1096 let mut buffer = MaxLenBuffer::new(buffer, u16::MAX as usize)?;
1097 let mut buffer = AppendOnlyBuffer::new(&mut buffer);
1098 buffer.push(HandshakeMessage::FinalPayload as u8)?;
1099 buffer.push(HandshakeVersion::Version1.id())?;
1100 buffer.push(PayloadFrame::ApplicationPayload as u8)?;
1101 let mut noise_wrapper = NoiseHandshakeWrapper::wrap_payload(noise, None, None);
1102 driver.write_final_payload(&mut AppendOnlyBuffer::new(&mut buffer), &mut noise_wrapper)?;
1103 if buffer.len() == 3 {
1104 buffer.clear();
1105 buffer.push(HandshakeMessage::Final as u8)?;
1106 }
1107 Ok(())
1108}
1109
1110fn insert_varlen_frame(buffer: &mut impl Buffer, frame: PayloadFrame, payload: &[u8]) -> Result<(), Error> {
1111 buffer.push(frame as u8)?;
1112 let mut len_buffer = VarLengthPrefixBuffer::new(buffer, payload.len())?;
1113 len_buffer.extend_from_slice(payload)?;
1114 Ok(())
1115}
1116
1117fn insert_application_payload(buffer: &mut impl Buffer, noise: &mut impl NoiseHandshake, driver: &mut impl HandshakeDriver, position: Option<u8>) -> Result<(), Error> {
1118 let mut buffer = AppendOnlyBuffer::new(buffer);
1119 buffer.push(PayloadFrame::ApplicationPayload as u8)?;
1120 let mut noise_wrapper = NoiseHandshakeWrapper::wrap_payload(noise, position, None);
1121 driver.write_noise_payload(&mut AppendOnlyBuffer::new(&mut buffer), &mut noise_wrapper)?;
1122 if buffer.len() == 1 {
1123 buffer.clear();
1124 }
1125 Ok(())
1126}
1127
1128fn insert_padding(buffer: &mut impl Buffer, len: usize) -> Result<(), Error> {
1129 for _ in 0..len {
1130 buffer.push(0)?;
1131 }
1132 Ok(())
1133}