1use crate::cipherstate::CipherState;
2use crate::handshakepattern::{HandshakePattern, Token};
3use crate::symmetricstate::SymmetricState;
4use crate::traits::{Cipher, Hash, U8Array, DH};
5use arrayvec::{ArrayString, ArrayVec};
6use core::fmt::{Display, Error as FmtError, Formatter, Write};
7
8#[cfg(feature = "use_alloc")]
9use alloc::vec::Vec;
10
/// State of a Noise handshake in progress, generic over the DH function `D`,
/// the cipher `C` and the hash `H`.
///
/// Messages are produced with `write_message` and consumed with `read_message`
/// (alternating according to `is_write_turn`) until `completed()` is true; the
/// transport cipher states are then obtained from `get_ciphers()`.
pub struct HandshakeState<D: DH, C: Cipher, H: Hash> {
    // Symmetric state: supplies mix_hash / mix_key / encrypt_and_hash /
    // decrypt_and_hash / split for the handshake.
    symmetric: SymmetricState<C, H>,
    // Local static key; its public half is sent for `s` tokens.
    s: Option<D::Key>,
    // Local ephemeral key; generated lazily by `write_message` if still `None`.
    e: Option<D::Key>,
    // Remote static public key: given up front or learned from an `s` token.
    rs: Option<D::Pubkey>,
    // Remote ephemeral public key: given up front or learned from an `e` token.
    re: Option<D::Pubkey>,
    // Whether the local party sends the first message.
    is_initiator: bool,
    // Handshake pattern being executed.
    pattern: HandshakePattern,
    // Index of the next message pattern to write or read.
    message_index: usize,
    // Cached `pattern.has_psk()`; when set, every `e` is also fed to mix_key.
    pattern_has_psk: bool,
    // Pre-shared keys queued by `push_psk`, consumed front-first by `psk` tokens
    // (at most 4 can be queued).
    psks: ArrayVec<[u8; 32], 4>,
}
24
25impl<D, C, H> Clone for HandshakeState<D, C, H>
26where
27 D: DH,
28 C: Cipher,
29 H: Hash,
30{
31 fn clone(&self) -> Self {
32 Self {
33 symmetric: self.symmetric.clone(),
34 s: self.s.as_ref().map(U8Array::clone),
35 e: self.e.as_ref().map(U8Array::clone),
36 rs: self.rs.as_ref().map(U8Array::clone),
37 re: self.re.as_ref().map(U8Array::clone),
38 is_initiator: self.is_initiator,
39 pattern: self.pattern.clone(),
40 message_index: self.message_index,
41 pattern_has_psk: self.pattern_has_psk,
42 psks: self.psks.clone(),
43 }
44 }
45}
46
47impl<D, C, H> HandshakeState<D, C, H>
48where
49 D: DH,
50 C: Cipher,
51 H: Hash,
52{
53 fn get_name(pattern_name: &str) -> ArrayString<256> {
55 let mut ret = ArrayString::new();
56 write!(
57 &mut ret,
58 "Noise_{}_{}_{}_{}",
59 pattern_name,
60 D::name(),
61 C::name(),
62 H::name()
63 )
64 .unwrap();
65 ret
66 }
67
68 pub fn new<P>(
80 pattern: HandshakePattern,
81 is_initiator: bool,
82 prologue: P,
83 s: Option<D::Key>,
84 e: Option<D::Key>,
85 rs: Option<D::Pubkey>,
86 re: Option<D::Pubkey>,
87 ) -> Self
88 where
89 P: AsRef<[u8]>,
90 {
91 let mut symmetric = SymmetricState::new(Self::get_name(pattern.get_name()).as_bytes());
92 let pattern_has_psk = pattern.has_psk();
93
94 symmetric.mix_hash(prologue.as_ref());
96
97 for t in pattern.get_pre_i() {
99 match *t {
100 Token::S => {
101 if is_initiator {
102 symmetric.mix_hash(D::pubkey(s.as_ref().unwrap()).as_slice());
103 } else {
104 symmetric.mix_hash(rs.as_ref().unwrap().as_slice());
105 }
106 }
107 _ => panic!("Unexpected token in pre message"),
108 }
109 }
110 for t in pattern.get_pre_r() {
111 match *t {
112 Token::S => {
113 if is_initiator {
114 symmetric.mix_hash(rs.as_ref().unwrap().as_slice());
115 } else {
116 symmetric.mix_hash(D::pubkey(s.as_ref().unwrap()).as_slice());
117 }
118 }
119 Token::E => {
120 if is_initiator {
121 let re = re.as_ref().unwrap().as_slice();
122 symmetric.mix_hash(re);
123 if pattern_has_psk {
124 symmetric.mix_key(re);
125 }
126 } else {
127 let e = D::pubkey(e.as_ref().unwrap());
128 symmetric.mix_hash(e.as_slice());
129 if pattern_has_psk {
130 symmetric.mix_key(e.as_slice());
131 }
132 }
133 }
134 _ => panic!("Unexpected token in pre message"),
135 }
136 }
137
138 HandshakeState {
139 symmetric,
140 s,
141 e,
142 rs,
143 re,
144 is_initiator,
145 pattern,
146 message_index: 0,
147 pattern_has_psk,
148 psks: ArrayVec::new(),
149 }
150 }
151
152 pub fn get_next_message_overhead(&self) -> usize {
159 let m = self.pattern.get_message_pattern(self.message_index);
160
161 let mut overhead = 0;
162
163 let mut has_key = self.symmetric.has_key();
164
165 for &t in m {
166 match t {
167 Token::E => {
168 overhead += D::Pubkey::len();
169 if self.pattern_has_psk {
170 has_key = true;
171 }
172 }
173 Token::S => {
174 overhead += D::Pubkey::len();
175 if has_key {
176 overhead += 16;
177 }
178 }
179 _ => {
180 has_key = true;
181 }
182 }
183 }
184
185 if has_key {
186 overhead += 16
187 }
188
189 overhead
190 }
191
192 #[cfg(any(feature = "use_std", feature = "use_alloc"))]
194 pub fn write_message_vec(&mut self, payload: &[u8]) -> Result<Vec<u8>, Error> {
195 let mut out = vec![0u8; payload.len() + self.get_next_message_overhead()];
196 self.write_message(payload, &mut out)?;
197 Ok(out)
198 }
199
200 pub fn write_message(&mut self, payload: &[u8], out: &mut [u8]) -> Result<(), Error> {
218 debug_assert_eq!(out.len(), payload.len() + self.get_next_message_overhead());
219
220 assert!(self.is_write_turn());
222
223 let m = self.pattern.get_message_pattern(self.message_index);
225 self.message_index += 1;
226
227 let mut cur: usize = 0;
228 for t in m {
230 match *t {
231 Token::E => {
232 if self.e.is_none() {
233 self.e = Some(D::genkey());
234 }
235 let e_pk = D::pubkey(self.e.as_ref().unwrap());
236 self.symmetric.mix_hash(e_pk.as_slice());
237 if self.pattern_has_psk {
238 self.symmetric.mix_key(e_pk.as_slice());
239 }
240 out[cur..cur + D::Pubkey::len()].copy_from_slice(e_pk.as_slice());
241 cur += D::Pubkey::len();
242 }
243 Token::S => {
244 let len = if self.symmetric.has_key() {
245 D::Pubkey::len() + 16
246 } else {
247 D::Pubkey::len()
248 };
249
250 let encrypted_s_out = &mut out[cur..cur + len];
251 self.symmetric.encrypt_and_hash(
252 D::pubkey(self.s.as_ref().unwrap()).as_slice(),
253 encrypted_s_out,
254 );
255 cur += len;
256 }
257 Token::PSK => {
258 if let Some(psk) = self.psks.pop_at(0) {
259 self.symmetric.mix_key_and_hash(&psk);
260 } else {
261 return Err(Error::need_psk());
262 }
263 }
264 t => {
265 let dh_result = self.perform_dh(t).map_err(|_| Error::dh())?;
266 self.symmetric.mix_key(dh_result.as_slice());
267 }
268 }
269 }
270
271 self.symmetric.encrypt_and_hash(payload, &mut out[cur..]);
272 Ok(())
273 }
274
275 pub fn read_message(&mut self, data: &[u8], out: &mut [u8]) -> Result<(), Error> {
307 debug_assert_eq!(out.len() + self.get_next_message_overhead(), data.len());
308
309 assert!(!self.is_write_turn());
310
311 let m = self.pattern.get_message_pattern(self.message_index);
313 self.message_index += 1;
314
315 let mut data = data;
316 let mut get = |n| {
318 let ret = &data[..n];
319 data = &data[n..];
320 ret
321 };
322
323 for t in m {
325 match *t {
326 Token::E => {
327 let re = D::Pubkey::from_slice(get(D::Pubkey::len()));
328 self.symmetric.mix_hash(re.as_slice());
329 if self.pattern_has_psk {
330 self.symmetric.mix_key(re.as_slice());
331 }
332 self.re = Some(re);
333 }
334 Token::S => {
335 let temp = get(if self.symmetric.has_key() {
336 D::Pubkey::len() + 16
337 } else {
338 D::Pubkey::len()
339 });
340 let mut rs = D::Pubkey::new();
341 self.symmetric
342 .decrypt_and_hash(temp, rs.as_mut())
343 .map_err(|_| Error::decryption())?;
344 self.rs = Some(rs);
345 }
346 Token::PSK => {
347 if let Some(psk) = self.psks.pop_at(0) {
348 self.symmetric.mix_key_and_hash(&psk);
349 } else {
350 return Err(Error::need_psk());
351 }
352 }
353 t => {
354 let dh_result = self.perform_dh(t).map_err(|_| Error::dh())?;
355 self.symmetric.mix_key(dh_result.as_slice());
356 }
357 }
358 }
359
360 self.symmetric
361 .decrypt_and_hash(data, out)
362 .map_err(|_| Error::decryption())
363 }
364
365 #[cfg(any(feature = "use_std", feature = "use_alloc"))]
372 pub fn read_message_vec(&mut self, data: &[u8]) -> Result<Vec<u8>, Error> {
373 let overhead = self.get_next_message_overhead();
374 if data.len() < overhead {
375 Err(Error::too_short())
376 } else {
377 let mut out = vec![0u8; data.len() - overhead];
378 self.read_message(data, &mut out)?;
379 Ok(out)
380 }
381 }
382
383 pub fn push_psk(&mut self, psk: &[u8]) {
389 self.psks.push(U8Array::from_slice(psk));
390 }
391
392 pub fn completed(&self) -> bool {
394 self.message_index == self.pattern.get_message_patterns_len()
395 }
396
397 pub fn get_hash(&self) -> &[u8] {
399 self.symmetric.get_hash()
400 }
401
402 pub fn get_ciphers(&self) -> (CipherState<C>, CipherState<C>) {
409 self.symmetric.split()
410 }
411
412 pub fn get_rs(&self) -> Option<D::Pubkey> {
414 self.rs.as_ref().map(U8Array::clone)
415 }
416
417 pub fn get_re(&self) -> Option<D::Pubkey> {
423 self.re.as_ref().map(U8Array::clone)
424 }
425
426 pub fn get_is_initiator(&self) -> bool {
428 self.is_initiator
429 }
430
431 pub fn get_pattern(&self) -> &HandshakePattern {
433 &self.pattern
434 }
435
436 pub fn is_write_turn(&self) -> bool {
438 self.message_index % 2 == if self.is_initiator { 0 } else { 1 }
439 }
440
441 fn perform_dh(&self, t: Token) -> Result<D::Output, ()> {
442 let dh = |a: Option<&D::Key>, b: Option<&D::Pubkey>| D::dh(a.unwrap(), b.unwrap());
443
444 match t {
445 Token::EE => dh(self.e.as_ref(), self.re.as_ref()),
446 Token::ES => {
447 if self.is_initiator {
448 dh(self.e.as_ref(), self.rs.as_ref())
449 } else {
450 dh(self.s.as_ref(), self.re.as_ref())
451 }
452 }
453 Token::SE => {
454 if self.is_initiator {
455 dh(self.s.as_ref(), self.re.as_ref())
456 } else {
457 dh(self.e.as_ref(), self.rs.as_ref())
458 }
459 }
460 Token::SS => dh(self.s.as_ref(), self.rs.as_ref()),
461 _ => unreachable!(),
462 }
463 }
464}
465
466#[derive(Debug)]
468pub struct Error {
469 kind: ErrorKind,
470}
471
/// Category of a handshake `Error`.
#[derive(Debug, PartialEq, Eq, Copy, Clone)]
pub enum ErrorKind {
    /// A Diffie-Hellman operation for a DH token failed.
    DH,
    /// A `psk` token was reached but no PSK had been queued with `push_psk`.
    NeedPSK,
    /// Decrypting an encrypted static key or a payload failed.
    Decryption,
    /// The message was shorter than the overhead expected for the current
    /// handshake step.
    TooShort,
}
484
485impl Error {
486 fn dh() -> Error {
487 Error {
488 kind: ErrorKind::DH,
489 }
490 }
491
492 fn need_psk() -> Error {
493 Error {
494 kind: ErrorKind::NeedPSK,
495 }
496 }
497
498 fn decryption() -> Error {
499 Error {
500 kind: ErrorKind::Decryption,
501 }
502 }
503
504 fn too_short() -> Error {
505 Error {
506 kind: ErrorKind::TooShort,
507 }
508 }
509
510 pub fn kind(&self) -> ErrorKind {
512 self.kind
513 }
514}
515
516impl Display for Error {
517 fn fmt(&self, fmt: &mut Formatter<'_>) -> Result<(), FmtError> {
518 write!(fmt, "{:?}", self)
519 }
520}
521
#[cfg(feature = "use_std")]
impl ::std::error::Error for Error {
    /// Static one-line description of the failure, selected by `ErrorKind`.
    fn description(&self) -> &'static str {
        match self.kind {
            ErrorKind::TooShort => "Message is too short",
            ErrorKind::Decryption => "Decryption failed",
            ErrorKind::NeedPSK => "Need PSK",
            ErrorKind::DH => "DH error",
        }
    }
}
533
/// Incremental builder for `HandshakeState`.
///
/// Every field stays `None` until its setter is called; see
/// `build_handshake_state` for which ones must be provided.
pub struct HandshakeStateBuilder<'a, D: DH> {
    // Handshake pattern to execute.
    pattern: Option<HandshakePattern>,
    // Whether the local party sends the first message.
    is_initiator: Option<bool>,
    // Prologue bytes, borrowed for the builder's lifetime `'a`.
    prologue: Option<&'a [u8]>,
    // Local static key.
    s: Option<D::Key>,
    // Local ephemeral key.
    e: Option<D::Key>,
    // Remote static public key.
    rs: Option<D::Pubkey>,
    // Remote ephemeral public key.
    re: Option<D::Pubkey>,
}
544
545impl<'a, D: DH> Default for HandshakeStateBuilder<'a, D> {
546 fn default() -> Self {
547 HandshakeStateBuilder::new()
548 }
549}
550
551impl<'a, D> HandshakeStateBuilder<'a, D>
552where
553 D: DH,
554{
555 pub fn new() -> Self {
557 HandshakeStateBuilder {
558 pattern: None,
559 is_initiator: None,
560 prologue: None,
561 s: None,
562 e: None,
563 rs: None,
564 re: None,
565 }
566 }
567
568 pub fn set_pattern(&mut self, p: HandshakePattern) -> &mut Self {
570 self.pattern = Some(p);
571 self
572 }
573
574 pub fn set_is_initiator(&mut self, is: bool) -> &mut Self {
576 self.is_initiator = Some(is);
577 self
578 }
579
580 pub fn set_prologue(&mut self, prologue: &'a [u8]) -> &mut Self {
582 self.prologue = Some(prologue);
583 self
584 }
585
586 pub fn set_e(&mut self, e: D::Key) -> &mut Self {
591 self.e = Some(e);
592 self
593 }
594
595 pub fn set_s(&mut self, s: D::Key) -> &mut Self {
597 self.s = Some(s);
598 self
599 }
600
601 pub fn set_re(&mut self, re: D::Pubkey) -> &mut Self {
605 self.re = Some(re);
606 self
607 }
608
609 pub fn set_rs(&mut self, rs: D::Pubkey) -> &mut Self {
611 self.rs = Some(rs);
612 self
613 }
614
615 pub fn build_handshake_state<C, H>(self) -> HandshakeState<D, C, H>
624 where
625 C: Cipher,
626 H: Hash,
627 {
628 HandshakeState::new(
629 self.pattern.unwrap(),
630 self.is_initiator.unwrap(),
631 self.prologue.unwrap(),
632 self.s,
633 self.e,
634 self.rs,
635 self.re,
636 )
637 }
638}