1use crate::cipherstate::CipherState;
2use crate::handshakepattern::{HandshakePattern, Token};
3use crate::symmetricstate::SymmetricState;
4use crate::traits::{Cipher, Hash, U8Array, Unspecified, DH};
5use arrayvec::{ArrayString, ArrayVec};
6use core::fmt::{Display, Error as FmtError, Formatter, Write};
7
8#[cfg(feature = "alloc")]
9use alloc::vec::Vec;
10
/// State machine for a single Noise handshake.
///
/// Created with [`HandshakeState::new`] or [`HandshakeStateBuilder`].
/// Handshake messages are produced with `write_message` and consumed with
/// `read_message`, alternating according to the pattern (see
/// `is_write_turn`). `completed()` reports when every message pattern has
/// been processed. `get_ciphers` splits the symmetric state into a pair of
/// cipher states.
pub struct HandshakeState<D: DH, C: Cipher, H: Hash> {
    /// Symmetric state into which the protocol name, prologue, public keys,
    /// DH results and PSKs are mixed.
    symmetric: SymmetricState<C, H>,
    /// Local static private key; its public half is derived with `D::pubkey`.
    s: Option<D::Key>,
    /// Local ephemeral private key; generated on the first written `e` token
    /// if not supplied.
    e: Option<D::Key>,
    /// Remote static public key, from a pre-message or a received `s` token.
    rs: Option<D::Pubkey>,
    /// Remote ephemeral public key, from a pre-message or a received `e` token.
    re: Option<D::Pubkey>,
    /// Whether this party is the handshake initiator.
    is_initiator: bool,
    /// The handshake pattern being executed.
    pattern: HandshakePattern,
    /// Index of the next message pattern to write or read.
    message_index: usize,
    /// Cached `pattern.has_psk()`; in PSK patterns every `e` token also
    /// triggers `mix_key` on the ephemeral public key.
    pattern_has_psk: bool,
    /// Queued pre-shared keys, consumed front-first by `psk` tokens (at most 4).
    psks: ArrayVec<[u8; 32], 4>,
}
24
25impl<D, C, H> Clone for HandshakeState<D, C, H>
26where
27 D: DH,
28 <D as DH>::Key: Clone,
29 C: Cipher,
30 <C as Cipher>::Key: Clone,
31 H: Hash,
32{
33 fn clone(&self) -> Self {
34 Self {
35 symmetric: self.symmetric.clone(),
36 s: self.s.clone(),
37 e: self.e.clone(),
38 rs: self.rs.as_ref().map(U8Array::clone),
39 re: self.re.as_ref().map(U8Array::clone),
40 is_initiator: self.is_initiator,
41 pattern: self.pattern.clone(),
42 message_index: self.message_index,
43 pattern_has_psk: self.pattern_has_psk,
44 psks: self.psks.clone(),
45 }
46 }
47}
48
49impl<D, C, H> HandshakeState<D, C, H>
50where
51 D: DH,
52 C: Cipher,
53 H: Hash,
54{
55 fn get_name(pattern_name: &str) -> ArrayString<256> {
57 let mut ret = ArrayString::new();
58 write!(
59 &mut ret,
60 "Noise_{}_{}_{}_{}",
61 pattern_name,
62 D::name(),
63 C::name(),
64 H::name()
65 )
66 .unwrap();
67 ret
68 }
69
70 pub fn new<P>(
82 pattern: HandshakePattern,
83 is_initiator: bool,
84 prologue: P,
85 s: Option<D::Key>,
86 e: Option<D::Key>,
87 rs: Option<D::Pubkey>,
88 re: Option<D::Pubkey>,
89 ) -> Self
90 where
91 P: AsRef<[u8]>,
92 {
93 let mut symmetric = SymmetricState::new(Self::get_name(pattern.get_name()).as_bytes());
94 let pattern_has_psk = pattern.has_psk();
95
96 symmetric.mix_hash(prologue.as_ref());
98
99 for t in pattern.get_pre_i() {
101 match *t {
102 Token::S => {
103 if is_initiator {
104 symmetric.mix_hash(D::pubkey(s.as_ref().unwrap()).as_slice());
105 } else {
106 symmetric.mix_hash(rs.as_ref().unwrap().as_slice());
107 }
108 }
109 _ => panic!("Unexpected token in pre message"),
110 }
111 }
112 for t in pattern.get_pre_r() {
113 match *t {
114 Token::S => {
115 if is_initiator {
116 symmetric.mix_hash(rs.as_ref().unwrap().as_slice());
117 } else {
118 symmetric.mix_hash(D::pubkey(s.as_ref().unwrap()).as_slice());
119 }
120 }
121 Token::E => {
122 if is_initiator {
123 let re = re.as_ref().unwrap().as_slice();
124 symmetric.mix_hash(re);
125 if pattern_has_psk {
126 symmetric.mix_key(re);
127 }
128 } else {
129 let e = D::pubkey(e.as_ref().unwrap());
130 symmetric.mix_hash(e.as_slice());
131 if pattern_has_psk {
132 symmetric.mix_key(e.as_slice());
133 }
134 }
135 }
136 _ => panic!("Unexpected token in pre message"),
137 }
138 }
139
140 HandshakeState {
141 symmetric,
142 s,
143 e,
144 rs,
145 re,
146 is_initiator,
147 pattern,
148 message_index: 0,
149 pattern_has_psk,
150 psks: ArrayVec::new(),
151 }
152 }
153
154 pub fn get_next_message_overhead(&self) -> usize {
161 let m = self.pattern.get_message_pattern(self.message_index);
162
163 let mut overhead = 0;
164
165 let mut has_key = self.symmetric.has_key();
166
167 for &t in m {
168 match t {
169 Token::E => {
170 overhead += D::Pubkey::len();
171 if self.pattern_has_psk {
172 has_key = true;
173 }
174 }
175 Token::S => {
176 overhead += D::Pubkey::len();
177 if has_key {
178 overhead += 16;
179 }
180 }
181 _ => {
182 has_key = true;
183 }
184 }
185 }
186
187 if has_key {
188 overhead += 16
189 }
190
191 overhead
192 }
193
194 #[cfg(feature = "alloc")]
196 pub fn write_message_vec(&mut self, payload: &[u8]) -> Result<Vec<u8>, Error> {
197 let mut out = vec![0u8; payload.len() + self.get_next_message_overhead()];
198 self.write_message(payload, &mut out)?;
199 Ok(out)
200 }
201
202 pub fn write_message(&mut self, payload: &[u8], out: &mut [u8]) -> Result<(), Error> {
220 debug_assert_eq!(out.len(), payload.len() + self.get_next_message_overhead());
221
222 assert!(self.is_write_turn());
224
225 let m = self.pattern.get_message_pattern(self.message_index);
227 self.message_index += 1;
228
229 let mut cur: usize = 0;
230 for t in m {
232 match *t {
233 Token::E => {
234 if self.e.is_none() {
235 self.e = Some(D::genkey());
236 }
237 let e_pk = D::pubkey(self.e.as_ref().unwrap());
238 self.symmetric.mix_hash(e_pk.as_slice());
239 if self.pattern_has_psk {
240 self.symmetric.mix_key(e_pk.as_slice());
241 }
242 out[cur..cur + D::Pubkey::len()].copy_from_slice(e_pk.as_slice());
243 cur += D::Pubkey::len();
244 }
245 Token::S => {
246 let len = if self.symmetric.has_key() {
247 D::Pubkey::len() + 16
248 } else {
249 D::Pubkey::len()
250 };
251
252 let encrypted_s_out = &mut out[cur..cur + len];
253 self.symmetric.encrypt_and_hash(
254 D::pubkey(self.s.as_ref().unwrap()).as_slice(),
255 encrypted_s_out,
256 );
257 cur += len;
258 }
259 Token::PSK => {
260 if let Some(psk) = self.psks.pop_at(0) {
261 self.symmetric.mix_key_and_hash(&psk);
262 } else {
263 return Err(Error::need_psk());
264 }
265 }
266 t => {
267 let dh_result = self.perform_dh(t).map_err(|_| Error::dh())?;
268 self.symmetric.mix_key(dh_result.as_slice());
269 }
270 }
271 }
272
273 self.symmetric.encrypt_and_hash(payload, &mut out[cur..]);
274 Ok(())
275 }
276
277 pub fn read_message(&mut self, data: &[u8], out: &mut [u8]) -> Result<(), Error> {
309 debug_assert_eq!(out.len() + self.get_next_message_overhead(), data.len());
310
311 assert!(!self.is_write_turn());
312
313 let m = self.pattern.get_message_pattern(self.message_index);
315 self.message_index += 1;
316
317 let mut data = data;
318 let mut get = |n| {
320 let ret = &data[..n];
321 data = &data[n..];
322 ret
323 };
324
325 for t in m {
327 match *t {
328 Token::E => {
329 let re = D::Pubkey::from_slice(get(D::Pubkey::len()));
330 self.symmetric.mix_hash(re.as_slice());
331 if self.pattern_has_psk {
332 self.symmetric.mix_key(re.as_slice());
333 }
334 self.re = Some(re);
335 }
336 Token::S => {
337 let temp = get(if self.symmetric.has_key() {
338 D::Pubkey::len() + 16
339 } else {
340 D::Pubkey::len()
341 });
342 let mut rs = D::Pubkey::new();
343 self.symmetric
344 .decrypt_and_hash(temp, rs.as_mut())
345 .map_err(|_| Error::decryption())?;
346 self.rs = Some(rs);
347 }
348 Token::PSK => {
349 if let Some(psk) = self.psks.pop_at(0) {
350 self.symmetric.mix_key_and_hash(&psk);
351 } else {
352 return Err(Error::need_psk());
353 }
354 }
355 t => {
356 let dh_result = self.perform_dh(t).map_err(|_| Error::dh())?;
357 self.symmetric.mix_key(dh_result.as_slice());
358 }
359 }
360 }
361
362 self.symmetric
363 .decrypt_and_hash(data, out)
364 .map_err(|_| Error::decryption())
365 }
366
367 #[cfg(feature = "alloc")]
374 pub fn read_message_vec(&mut self, data: &[u8]) -> Result<Vec<u8>, Error> {
375 let overhead = self.get_next_message_overhead();
376 if data.len() < overhead {
377 Err(Error::too_short())
378 } else {
379 let mut out = vec![0u8; data.len() - overhead];
380 self.read_message(data, &mut out)?;
381 Ok(out)
382 }
383 }
384
385 pub fn push_psk(&mut self, psk: &[u8]) {
391 self.psks.push(U8Array::from_slice(psk));
392 }
393
394 pub fn completed(&self) -> bool {
396 self.message_index == self.pattern.get_message_patterns_len()
397 }
398
399 pub fn get_hash(&self) -> &[u8] {
401 self.symmetric.get_hash()
402 }
403
404 pub fn get_ciphers(&self) -> (CipherState<C>, CipherState<C>) {
411 self.symmetric.split()
412 }
413
414 pub fn get_rs(&self) -> Option<D::Pubkey> {
416 self.rs.as_ref().map(U8Array::clone)
417 }
418
419 pub fn get_re(&self) -> Option<D::Pubkey> {
425 self.re.as_ref().map(U8Array::clone)
426 }
427
428 pub fn get_is_initiator(&self) -> bool {
430 self.is_initiator
431 }
432
433 pub fn get_pattern(&self) -> &HandshakePattern {
435 &self.pattern
436 }
437
438 pub fn is_write_turn(&self) -> bool {
440 self.message_index % 2 == if self.is_initiator { 0 } else { 1 }
441 }
442
443 fn perform_dh(&self, t: Token) -> Result<D::Output, Unspecified> {
444 let dh = |a: Option<&D::Key>, b: Option<&D::Pubkey>| D::dh(a.unwrap(), b.unwrap());
445
446 match t {
447 Token::EE => dh(self.e.as_ref(), self.re.as_ref()),
448 Token::ES => {
449 if self.is_initiator {
450 dh(self.e.as_ref(), self.rs.as_ref())
451 } else {
452 dh(self.s.as_ref(), self.re.as_ref())
453 }
454 }
455 Token::SE => {
456 if self.is_initiator {
457 dh(self.s.as_ref(), self.re.as_ref())
458 } else {
459 dh(self.e.as_ref(), self.rs.as_ref())
460 }
461 }
462 Token::SS => dh(self.s.as_ref(), self.rs.as_ref()),
463 _ => unreachable!(),
464 }
465 }
466}
467
/// Error returned by handshake operations; see [`Error::kind`] for the cause.
#[derive(Debug)]
pub struct Error {
    /// What went wrong.
    kind: ErrorKind,
}
473
/// Category of a handshake [`Error`].
#[derive(Debug, PartialEq, Eq, Copy, Clone)]
pub enum ErrorKind {
    /// A Diffie-Hellman operation failed.
    DH,
    /// A `psk` token was reached but no pre-shared key was queued with
    /// `push_psk`.
    NeedPSK,
    /// Decrypting the peer's encrypted static key or the message payload failed.
    Decryption,
    /// The message is shorter than the overhead of the next handshake message.
    TooShort,
}
486
487impl Error {
488 fn dh() -> Error {
489 Error {
490 kind: ErrorKind::DH,
491 }
492 }
493
494 fn need_psk() -> Error {
495 Error {
496 kind: ErrorKind::NeedPSK,
497 }
498 }
499
500 fn decryption() -> Error {
501 Error {
502 kind: ErrorKind::Decryption,
503 }
504 }
505
506 fn too_short() -> Error {
507 Error {
508 kind: ErrorKind::TooShort,
509 }
510 }
511
512 pub fn kind(&self) -> ErrorKind {
514 self.kind
515 }
516}
517
518impl Display for Error {
519 fn fmt(&self, fmt: &mut Formatter<'_>) -> Result<(), FmtError> {
520 write!(fmt, "{:?}", self)
521 }
522}
523
#[cfg(feature = "std")]
impl ::std::error::Error for Error {
    /// Returns a static description of the error kind.
    ///
    /// NOTE(review): `description` is deprecated in favour of `Display`; it is
    /// kept for existing callers.
    fn description(&self) -> &'static str {
        let kind = self.kind;
        match kind {
            ErrorKind::TooShort => "Message is too short",
            ErrorKind::Decryption => "Decryption failed",
            ErrorKind::NeedPSK => "Need PSK",
            ErrorKind::DH => "DH error",
        }
    }
}
535
/// Builder for [`HandshakeState`].
///
/// `pattern`, `is_initiator` and `prologue` must be set before calling
/// `build_handshake_state`; the keys are optional.
pub struct HandshakeStateBuilder<'a, D: DH> {
    /// Handshake pattern to execute (required).
    pattern: Option<HandshakePattern>,
    /// Whether this party is the initiator (required).
    is_initiator: Option<bool>,
    /// Prologue mixed into the handshake hash (required).
    prologue: Option<&'a [u8]>,
    /// Local static private key.
    s: Option<D::Key>,
    /// Local ephemeral private key.
    e: Option<D::Key>,
    /// Remote static public key.
    rs: Option<D::Pubkey>,
    /// Remote ephemeral public key.
    re: Option<D::Pubkey>,
}
546
547impl<'a, D: DH> Default for HandshakeStateBuilder<'a, D> {
548 fn default() -> Self {
549 HandshakeStateBuilder::new()
550 }
551}
552
553impl<'a, D> HandshakeStateBuilder<'a, D>
554where
555 D: DH,
556{
557 pub fn new() -> Self {
559 HandshakeStateBuilder {
560 pattern: None,
561 is_initiator: None,
562 prologue: None,
563 s: None,
564 e: None,
565 rs: None,
566 re: None,
567 }
568 }
569
570 pub fn set_pattern(&mut self, p: HandshakePattern) -> &mut Self {
572 self.pattern = Some(p);
573 self
574 }
575
576 pub fn set_is_initiator(&mut self, is: bool) -> &mut Self {
578 self.is_initiator = Some(is);
579 self
580 }
581
582 pub fn set_prologue(&mut self, prologue: &'a [u8]) -> &mut Self {
584 self.prologue = Some(prologue);
585 self
586 }
587
588 pub fn set_e(&mut self, e: D::Key) -> &mut Self {
593 self.e = Some(e);
594 self
595 }
596
597 pub fn set_s(&mut self, s: D::Key) -> &mut Self {
599 self.s = Some(s);
600 self
601 }
602
603 pub fn set_re(&mut self, re: D::Pubkey) -> &mut Self {
607 self.re = Some(re);
608 self
609 }
610
611 pub fn set_rs(&mut self, rs: D::Pubkey) -> &mut Self {
613 self.rs = Some(rs);
614 self
615 }
616
617 pub fn build_handshake_state<C, H>(self) -> HandshakeState<D, C, H>
626 where
627 C: Cipher,
628 H: Hash,
629 {
630 HandshakeState::new(
631 self.pattern.unwrap(),
632 self.is_initiator.unwrap(),
633 self.prologue.unwrap(),
634 self.s,
635 self.e,
636 self.rs,
637 self.re,
638 )
639 }
640}