Skip to main content

noise_protocol/
handshakestate.rs

1use crate::cipherstate::CipherState;
2use crate::handshakepattern::{HandshakePattern, Token};
3use crate::symmetricstate::SymmetricState;
4use crate::traits::{Cipher, Hash, U8Array, DH};
5use arrayvec::{ArrayString, ArrayVec};
6use core::fmt::{Display, Error as FmtError, Formatter, Write};
7
8#[cfg(feature = "use_alloc")]
9use alloc::vec::Vec;
10
/// Noise handshake state.
///
/// Holds the symmetric state, the local and remote keys, and the position
/// within the handshake pattern. Construct it with [`HandshakeState::new`]
/// or [`HandshakeStateBuilder`].
pub struct HandshakeState<D: DH, C: Cipher, H: Hash> {
    /// Symmetric state that handshake data is mixed into (`mix_hash` /
    /// `mix_key`) and that is finally split into transport ciphers.
    symmetric: SymmetricState<C, H>,
    /// Local static private key, if known (public half derived via `D::pubkey`).
    s: Option<D::Key>,
    /// Local ephemeral private key; `write_message` generates one if unset.
    e: Option<D::Key>,
    /// Remote static public key, if known.
    rs: Option<D::Pubkey>,
    /// Remote (semi-)ephemeral public key, if known.
    re: Option<D::Pubkey>,
    /// Whether this party is the handshake initiator.
    is_initiator: bool,
    /// Handshake pattern being executed.
    pattern: HandshakePattern,
    /// Index of the next message pattern to be written or read.
    message_index: usize,
    /// Cached `pattern.has_psk()`; when true, ephemeral public keys are also
    /// passed to `mix_key`.
    pattern_has_psk: bool,
    /// Queue of pre-shared keys, consumed front-first by PSK tokens (capacity 4).
    psks: ArrayVec<[u8; 32], 4>,
}
24
25impl<D, C, H> Clone for HandshakeState<D, C, H>
26where
27    D: DH,
28    C: Cipher,
29    H: Hash,
30{
31    fn clone(&self) -> Self {
32        Self {
33            symmetric: self.symmetric.clone(),
34            s: self.s.as_ref().map(U8Array::clone),
35            e: self.e.as_ref().map(U8Array::clone),
36            rs: self.rs.as_ref().map(U8Array::clone),
37            re: self.re.as_ref().map(U8Array::clone),
38            is_initiator: self.is_initiator,
39            pattern: self.pattern.clone(),
40            message_index: self.message_index,
41            pattern_has_psk: self.pattern_has_psk,
42            psks: self.psks.clone(),
43        }
44    }
45}
46
47impl<D, C, H> HandshakeState<D, C, H>
48where
49    D: DH,
50    C: Cipher,
51    H: Hash,
52{
53    /// Get protocol name, e.g. Noise_IK_25519_ChaChaPoly_BLAKE2s.
54    fn get_name(pattern_name: &str) -> ArrayString<256> {
55        let mut ret = ArrayString::new();
56        write!(
57            &mut ret,
58            "Noise_{}_{}_{}_{}",
59            pattern_name,
60            D::name(),
61            C::name(),
62            H::name()
63        )
64        .unwrap();
65        ret
66    }
67
68    /// Initialize a handshake state.
69    ///
70    /// If `e` is [`None`], a new ephemeral key will be generated if necessary
71    /// when [`write_message`](HandshakeState::write_message).
72    ///
73    /// # Setting Explicit Ephemeral Key
74    ///
75    /// An explicit `e` should only be specified for testing purposes, or in
76    /// fallback patterns. If you do pass in an explicit `e`, [`HandshakeState`]
77    /// will use it as is and will not generate new ephemeral keys in
78    /// [`write_message`](HandshakeState::write_message).
79    pub fn new<P>(
80        pattern: HandshakePattern,
81        is_initiator: bool,
82        prologue: P,
83        s: Option<D::Key>,
84        e: Option<D::Key>,
85        rs: Option<D::Pubkey>,
86        re: Option<D::Pubkey>,
87    ) -> Self
88    where
89        P: AsRef<[u8]>,
90    {
91        let mut symmetric = SymmetricState::new(Self::get_name(pattern.get_name()).as_bytes());
92        let pattern_has_psk = pattern.has_psk();
93
94        // Mix in prologue.
95        symmetric.mix_hash(prologue.as_ref());
96
97        // Mix in static keys known ahead of time.
98        for t in pattern.get_pre_i() {
99            match *t {
100                Token::S => {
101                    if is_initiator {
102                        symmetric.mix_hash(D::pubkey(s.as_ref().unwrap()).as_slice());
103                    } else {
104                        symmetric.mix_hash(rs.as_ref().unwrap().as_slice());
105                    }
106                }
107                _ => panic!("Unexpected token in pre message"),
108            }
109        }
110        for t in pattern.get_pre_r() {
111            match *t {
112                Token::S => {
113                    if is_initiator {
114                        symmetric.mix_hash(rs.as_ref().unwrap().as_slice());
115                    } else {
116                        symmetric.mix_hash(D::pubkey(s.as_ref().unwrap()).as_slice());
117                    }
118                }
119                Token::E => {
120                    if is_initiator {
121                        let re = re.as_ref().unwrap().as_slice();
122                        symmetric.mix_hash(re);
123                        if pattern_has_psk {
124                            symmetric.mix_key(re);
125                        }
126                    } else {
127                        let e = D::pubkey(e.as_ref().unwrap());
128                        symmetric.mix_hash(e.as_slice());
129                        if pattern_has_psk {
130                            symmetric.mix_key(e.as_slice());
131                        }
132                    }
133                }
134                _ => panic!("Unexpected token in pre message"),
135            }
136        }
137
138        HandshakeState {
139            symmetric,
140            s,
141            e,
142            rs,
143            re,
144            is_initiator,
145            pattern,
146            message_index: 0,
147            pattern_has_psk,
148            psks: ArrayVec::new(),
149        }
150    }
151
152    /// Calculate the size overhead of the next message.
153    ///
154    /// # Panics
155    ///
156    /// If these is no more message to read/write, i.e., if the handshake is
157    /// already completed.
158    pub fn get_next_message_overhead(&self) -> usize {
159        let m = self.pattern.get_message_pattern(self.message_index);
160
161        let mut overhead = 0;
162
163        let mut has_key = self.symmetric.has_key();
164
165        for &t in m {
166            match t {
167                Token::E => {
168                    overhead += D::Pubkey::len();
169                    if self.pattern_has_psk {
170                        has_key = true;
171                    }
172                }
173                Token::S => {
174                    overhead += D::Pubkey::len();
175                    if has_key {
176                        overhead += 16;
177                    }
178                }
179                _ => {
180                    has_key = true;
181                }
182            }
183        }
184
185        if has_key {
186            overhead += 16
187        }
188
189        overhead
190    }
191
192    /// Like [`write_message`](HandshakeState::write_message), but returns a [`Vec`].
193    #[cfg(any(feature = "use_std", feature = "use_alloc"))]
194    pub fn write_message_vec(&mut self, payload: &[u8]) -> Result<Vec<u8>, Error> {
195        let mut out = vec![0u8; payload.len() + self.get_next_message_overhead()];
196        self.write_message(payload, &mut out)?;
197        Ok(out)
198    }
199
200    /// Takes a payload and write the generated handshake message to
201    /// `out`.
202    ///
203    /// # Error Kinds
204    ///
205    /// - [DH](ErrorKind::DH): DH operation failed.
206    /// - [NeedPSK](ErrorKind::NeedPSK): A PSK token is encountered but none is available.
207    ///
208    /// # Panics
209    ///
210    /// * If a required static key is not set.
211    ///
212    /// * If `out.len() != payload.len() + self.get_next_message_overhead()`.
213    ///
214    /// * If it is not our turn to write.
215    ///
216    /// * If the handshake has already completed.
217    pub fn write_message(&mut self, payload: &[u8], out: &mut [u8]) -> Result<(), Error> {
218        debug_assert_eq!(out.len(), payload.len() + self.get_next_message_overhead());
219
220        // Check that it is our turn to send.
221        assert!(self.is_write_turn());
222
223        // Get the message pattern.
224        let m = self.pattern.get_message_pattern(self.message_index);
225        self.message_index += 1;
226
227        let mut cur: usize = 0;
228        // Process tokens.
229        for t in m {
230            match *t {
231                Token::E => {
232                    if self.e.is_none() {
233                        self.e = Some(D::genkey());
234                    }
235                    let e_pk = D::pubkey(self.e.as_ref().unwrap());
236                    self.symmetric.mix_hash(e_pk.as_slice());
237                    if self.pattern_has_psk {
238                        self.symmetric.mix_key(e_pk.as_slice());
239                    }
240                    out[cur..cur + D::Pubkey::len()].copy_from_slice(e_pk.as_slice());
241                    cur += D::Pubkey::len();
242                }
243                Token::S => {
244                    let len = if self.symmetric.has_key() {
245                        D::Pubkey::len() + 16
246                    } else {
247                        D::Pubkey::len()
248                    };
249
250                    let encrypted_s_out = &mut out[cur..cur + len];
251                    self.symmetric.encrypt_and_hash(
252                        D::pubkey(self.s.as_ref().unwrap()).as_slice(),
253                        encrypted_s_out,
254                    );
255                    cur += len;
256                }
257                Token::PSK => {
258                    if let Some(psk) = self.psks.pop_at(0) {
259                        self.symmetric.mix_key_and_hash(&psk);
260                    } else {
261                        return Err(Error::need_psk());
262                    }
263                }
264                t => {
265                    let dh_result = self.perform_dh(t).map_err(|_| Error::dh())?;
266                    self.symmetric.mix_key(dh_result.as_slice());
267                }
268            }
269        }
270
271        self.symmetric.encrypt_and_hash(payload, &mut out[cur..]);
272        Ok(())
273    }
274
275    /// Takes a handshake message, process it and update our internal
276    /// state, and write the encapsulated payload to `out`.
277    ///
278    /// # Error Kinds
279    ///
280    /// - [DH](ErrorKind::DH): DH operation failed.
281    /// - [NeedPSK](ErrorKind::NeedPSK): A PSK token is encountered but none is
282    ///   available.
283    /// - [Decryption](ErrorKind::Decryption): Decryption failed.
284    ///
285    /// # Error Recovery
286    ///
287    /// If [`read_message`](HandshakeState::read_message) fails, the whole
288    /// [`HandshakeState`] may be in invalid state and should not be used to
289    /// read or write any further messages. (But
290    /// [`get_re()`](HandshakeState::get_re) and
291    /// [`get_rs()`](HandshakeState::get_rs) is allowed.) In case error recovery
292    /// is desirable, [`clone`](Clone::clone) the [`HandshakeState`] before
293    /// calling [`read_message`](HandshakeState::read_message).
294    ///
295    /// # Panics
296    ///
297    /// * If `out.len() + self.get_next_message_overhead() != data.len()`.
298    ///
299    ///   (Notes that this implies `data.len() >= overhead`.)
300    ///
301    /// * If a required static key is not set.
302    ///
303    /// * If it is not our turn to read.
304    ///
305    /// * If the handshake has already completed.
306    pub fn read_message(&mut self, data: &[u8], out: &mut [u8]) -> Result<(), Error> {
307        debug_assert_eq!(out.len() + self.get_next_message_overhead(), data.len());
308
309        assert!(!self.is_write_turn());
310
311        // Get the message pattern.
312        let m = self.pattern.get_message_pattern(self.message_index);
313        self.message_index += 1;
314
315        let mut data = data;
316        // Consume the next `n` bytes of data.
317        let mut get = |n| {
318            let ret = &data[..n];
319            data = &data[n..];
320            ret
321        };
322
323        // Process tokens.
324        for t in m {
325            match *t {
326                Token::E => {
327                    let re = D::Pubkey::from_slice(get(D::Pubkey::len()));
328                    self.symmetric.mix_hash(re.as_slice());
329                    if self.pattern_has_psk {
330                        self.symmetric.mix_key(re.as_slice());
331                    }
332                    self.re = Some(re);
333                }
334                Token::S => {
335                    let temp = get(if self.symmetric.has_key() {
336                        D::Pubkey::len() + 16
337                    } else {
338                        D::Pubkey::len()
339                    });
340                    let mut rs = D::Pubkey::new();
341                    self.symmetric
342                        .decrypt_and_hash(temp, rs.as_mut())
343                        .map_err(|_| Error::decryption())?;
344                    self.rs = Some(rs);
345                }
346                Token::PSK => {
347                    if let Some(psk) = self.psks.pop_at(0) {
348                        self.symmetric.mix_key_and_hash(&psk);
349                    } else {
350                        return Err(Error::need_psk());
351                    }
352                }
353                t => {
354                    let dh_result = self.perform_dh(t).map_err(|_| Error::dh())?;
355                    self.symmetric.mix_key(dh_result.as_slice());
356                }
357            }
358        }
359
360        self.symmetric
361            .decrypt_and_hash(data, out)
362            .map_err(|_| Error::decryption())
363    }
364
365    /// Similar to [`read_message`](HandshakeState::read_message), but returns
366    /// result as a [`Vec`].
367    ///
368    /// In addition to possible errors from
369    /// [`read_message`](HandshakeState::read_message),
370    /// [TooShort](ErrorKind::TooShort) may be returned.
371    #[cfg(any(feature = "use_std", feature = "use_alloc"))]
372    pub fn read_message_vec(&mut self, data: &[u8]) -> Result<Vec<u8>, Error> {
373        let overhead = self.get_next_message_overhead();
374        if data.len() < overhead {
375            Err(Error::too_short())
376        } else {
377            let mut out = vec![0u8; data.len() - overhead];
378            self.read_message(data, &mut out)?;
379            Ok(out)
380        }
381    }
382
383    /// Push a PSK to the PSK-queue.
384    ///
385    /// # Panics
386    ///
387    /// If the PSK-queue becomes longer than 4.
388    pub fn push_psk(&mut self, psk: &[u8]) {
389        self.psks.push(U8Array::from_slice(psk));
390    }
391
392    /// Whether handshake has completed.
393    pub fn completed(&self) -> bool {
394        self.message_index == self.pattern.get_message_patterns_len()
395    }
396
397    /// Get handshake hash. Useful for e.g., channel binding.
398    pub fn get_hash(&self) -> &[u8] {
399        self.symmetric.get_hash()
400    }
401
402    /// Get ciphers that can be used to encrypt/decrypt further messages. The
403    /// first [`CipherState`] is for initiator to responder, and the second for
404    /// responder to initiator.
405    ///
406    /// Should be called after handshake is
407    /// [`completed`](HandshakeState::completed).
408    pub fn get_ciphers(&self) -> (CipherState<C>, CipherState<C>) {
409        self.symmetric.split()
410    }
411
412    /// Get remote static pubkey, if available.
413    pub fn get_rs(&self) -> Option<D::Pubkey> {
414        self.rs.as_ref().map(U8Array::clone)
415    }
416
417    /// Get remote semi-ephemeral pubkey.
418    ///
419    /// Returns [`None`](None) if we do not know.
420    ///
421    /// Useful for noise-pipes.
422    pub fn get_re(&self) -> Option<D::Pubkey> {
423        self.re.as_ref().map(U8Array::clone)
424    }
425
426    /// Set local static key.
427    ///
428    /// Useful if you want to choose the static key based on information
429    /// from previous messages, e.g. remote static pubkey.
430    ///
431    /// Handshake will panic if the static key is not set at a time it
432    /// is expected to be set.
433    pub fn set_s(&mut self, s: D::Key) {
434        self.s = Some(s);
435    }
436
437    /// Get whether this [`HandshakeState`] is created as initiator.
438    pub fn get_is_initiator(&self) -> bool {
439        self.is_initiator
440    }
441
442    /// Get handshake pattern this [`HandshakeState`] uses.
443    pub fn get_pattern(&self) -> &HandshakePattern {
444        &self.pattern
445    }
446
447    /// Check whether it is our turn to send in the handshake state.
448    pub fn is_write_turn(&self) -> bool {
449        self.message_index % 2 == if self.is_initiator { 0 } else { 1 }
450    }
451
452    fn perform_dh(&self, t: Token) -> Result<D::Output, ()> {
453        let dh = |a: Option<&D::Key>, b: Option<&D::Pubkey>| D::dh(a.unwrap(), b.unwrap());
454
455        match t {
456            Token::EE => dh(self.e.as_ref(), self.re.as_ref()),
457            Token::ES => {
458                if self.is_initiator {
459                    dh(self.e.as_ref(), self.rs.as_ref())
460                } else {
461                    dh(self.s.as_ref(), self.re.as_ref())
462                }
463            }
464            Token::SE => {
465                if self.is_initiator {
466                    dh(self.s.as_ref(), self.re.as_ref())
467                } else {
468                    dh(self.e.as_ref(), self.rs.as_ref())
469                }
470            }
471            Token::SS => dh(self.s.as_ref(), self.rs.as_ref()),
472            _ => unreachable!(),
473        }
474    }
475}
476
/// Handshake error.
#[derive(Debug)]
pub struct Error {
    kind: ErrorKind,
}

/// Handshake error kind.
#[derive(Debug, PartialEq, Eq, Copy, Clone)]
pub enum ErrorKind {
    /// A DH operation has failed.
    DH,
    /// A PSK is needed, but none is available.
    NeedPSK,
    /// Decryption failed.
    Decryption,
    /// The message is too short, and impossible to read.
    TooShort,
}

impl Error {
    fn dh() -> Error {
        Error {
            kind: ErrorKind::DH,
        }
    }

    fn need_psk() -> Error {
        Error {
            kind: ErrorKind::NeedPSK,
        }
    }

    fn decryption() -> Error {
        Error {
            kind: ErrorKind::Decryption,
        }
    }

    fn too_short() -> Error {
        Error {
            kind: ErrorKind::TooShort,
        }
    }

    /// Error kind.
    pub fn kind(&self) -> ErrorKind {
        self.kind
    }

    /// Short human-readable description of this error. This is the single
    /// source of truth for both `Display` and `std::error::Error::description`.
    fn message(&self) -> &'static str {
        match self.kind {
            ErrorKind::DH => "DH error",
            ErrorKind::NeedPSK => "Need PSK",
            ErrorKind::Decryption => "Decryption failed",
            ErrorKind::TooShort => "Message is too short",
        }
    }
}

/// User-facing text; the structural form is available through `Debug`.
impl Display for Error {
    fn fmt(&self, fmt: &mut Formatter<'_>) -> Result<(), FmtError> {
        fmt.write_str(self.message())
    }
}

#[cfg(feature = "use_std")]
impl ::std::error::Error for Error {
    fn description(&self) -> &'static str {
        self.message()
    }
}
544
/// Builder for `HandshakeState`.
///
/// Every field starts unset. `pattern`, `is_initiator` and `prologue` must be
/// set before `build_handshake_state` is called; the key fields are optional.
pub struct HandshakeStateBuilder<'a, D: DH> {
    /// Handshake pattern (required).
    pattern: Option<HandshakePattern>,
    /// Whether the built state is the initiator (required).
    is_initiator: Option<bool>,
    /// Prologue bytes, borrowed for `'a` (required).
    prologue: Option<&'a [u8]>,
    /// Local static private key (optional).
    s: Option<D::Key>,
    /// Local ephemeral private key (optional; generated on demand when unset).
    e: Option<D::Key>,
    /// Remote static public key (optional).
    rs: Option<D::Pubkey>,
    /// Remote semi-ephemeral public key (optional).
    re: Option<D::Pubkey>,
}
555
556impl<'a, D: DH> Default for HandshakeStateBuilder<'a, D> {
557    fn default() -> Self {
558        HandshakeStateBuilder::new()
559    }
560}
561
562impl<'a, D> HandshakeStateBuilder<'a, D>
563where
564    D: DH,
565{
566    /// Create a new [`HandshakeStateBuilder`].
567    pub fn new() -> Self {
568        HandshakeStateBuilder {
569            pattern: None,
570            is_initiator: None,
571            prologue: None,
572            s: None,
573            e: None,
574            rs: None,
575            re: None,
576        }
577    }
578
579    /// Set handshake pattern.
580    pub fn set_pattern(&mut self, p: HandshakePattern) -> &mut Self {
581        self.pattern = Some(p);
582        self
583    }
584
585    /// Set whether the [`HandshakeState`] is initiator.
586    pub fn set_is_initiator(&mut self, is: bool) -> &mut Self {
587        self.is_initiator = Some(is);
588        self
589    }
590
591    /// Set prologue.
592    pub fn set_prologue(&mut self, prologue: &'a [u8]) -> &mut Self {
593        self.prologue = Some(prologue);
594        self
595    }
596
597    /// Set ephemeral key.
598    ///
599    /// This is not encouraged and usually not necessary. Cf.
600    /// [`HandshakeState::new()`].
601    pub fn set_e(&mut self, e: D::Key) -> &mut Self {
602        self.e = Some(e);
603        self
604    }
605
606    /// Set static key.
607    pub fn set_s(&mut self, s: D::Key) -> &mut Self {
608        self.s = Some(s);
609        self
610    }
611
612    /// Set peer semi-ephemeral public key.
613    ///
614    /// Usually used in fallback patterns.
615    pub fn set_re(&mut self, re: D::Pubkey) -> &mut Self {
616        self.re = Some(re);
617        self
618    }
619
620    /// Set peer static public key.
621    pub fn set_rs(&mut self, rs: D::Pubkey) -> &mut Self {
622        self.rs = Some(rs);
623        self
624    }
625
626    /// Build [`HandshakeState`].
627    ///
628    /// # Panics
629    ///
630    /// If any of [`set_pattern`](HandshakeStateBuilder::set_pattern),
631    /// [`set_prologue`](HandshakeStateBuilder::set_prologue) or
632    /// [`set_is_initiator`](HandshakeStateBuilder::set_is_initiator) has not
633    /// been called yet.
634    pub fn build_handshake_state<C, H>(self) -> HandshakeState<D, C, H>
635    where
636        C: Cipher,
637        H: Hash,
638    {
639        HandshakeState::new(
640            self.pattern.unwrap(),
641            self.is_initiator.unwrap(),
642            self.prologue.unwrap(),
643            self.s,
644            self.e,
645            self.rs,
646            self.re,
647        )
648    }
649}