// noise_protocol/handshakestate.rs

use crate::cipherstate::CipherState;
use crate::handshakepattern::{HandshakePattern, Token};
use crate::symmetricstate::SymmetricState;
use crate::traits::{Cipher, Hash, U8Array, DH};
use arrayvec::{ArrayString, ArrayVec};
use core::fmt::{Display, Error as FmtError, Formatter, Write};

#[cfg(feature = "use_alloc")]
use alloc::vec::Vec;

11/// Noise handshake state.
12pub struct HandshakeState<D: DH, C: Cipher, H: Hash> {
13    symmetric: SymmetricState<C, H>,
14    s: Option<D::Key>,
15    e: Option<D::Key>,
16    rs: Option<D::Pubkey>,
17    re: Option<D::Pubkey>,
18    is_initiator: bool,
19    pattern: HandshakePattern,
20    message_index: usize,
21    pattern_has_psk: bool,
22    psks: ArrayVec<[u8; 32], 4>,
23}
24
25impl<D, C, H> Clone for HandshakeState<D, C, H>
26where
27    D: DH,
28    C: Cipher,
29    H: Hash,
30{
31    fn clone(&self) -> Self {
32        Self {
33            symmetric: self.symmetric.clone(),
34            s: self.s.as_ref().map(U8Array::clone),
35            e: self.e.as_ref().map(U8Array::clone),
36            rs: self.rs.as_ref().map(U8Array::clone),
37            re: self.re.as_ref().map(U8Array::clone),
38            is_initiator: self.is_initiator,
39            pattern: self.pattern.clone(),
40            message_index: self.message_index,
41            pattern_has_psk: self.pattern_has_psk,
42            psks: self.psks.clone(),
43        }
44    }
45}
46
47impl<D, C, H> HandshakeState<D, C, H>
48where
49    D: DH,
50    C: Cipher,
51    H: Hash,
52{
53    /// Get protocol name, e.g. Noise_IK_25519_ChaChaPoly_BLAKE2s.
54    fn get_name(pattern_name: &str) -> ArrayString<256> {
55        let mut ret = ArrayString::new();
56        write!(
57            &mut ret,
58            "Noise_{}_{}_{}_{}",
59            pattern_name,
60            D::name(),
61            C::name(),
62            H::name()
63        )
64        .unwrap();
65        ret
66    }
67
68    /// Initialize a handshake state.
69    ///
70    /// If `e` is [`None`], a new ephemeral key will be generated if necessary
71    /// when [`write_message`](HandshakeState::write_message).
72    ///
73    /// # Setting Explicit Ephemeral Key
74    ///
75    /// An explicit `e` should only be specified for testing purposes, or in
76    /// fallback patterns. If you do pass in an explicit `e`, [`HandshakeState`]
77    /// will use it as is and will not generate new ephemeral keys in
78    /// [`write_message`](HandshakeState::write_message).
79    pub fn new<P>(
80        pattern: HandshakePattern,
81        is_initiator: bool,
82        prologue: P,
83        s: Option<D::Key>,
84        e: Option<D::Key>,
85        rs: Option<D::Pubkey>,
86        re: Option<D::Pubkey>,
87    ) -> Self
88    where
89        P: AsRef<[u8]>,
90    {
91        let mut symmetric = SymmetricState::new(Self::get_name(pattern.get_name()).as_bytes());
92        let pattern_has_psk = pattern.has_psk();
93
94        // Mix in prologue.
95        symmetric.mix_hash(prologue.as_ref());
96
97        // Mix in static keys known ahead of time.
98        for t in pattern.get_pre_i() {
99            match *t {
100                Token::S => {
101                    if is_initiator {
102                        symmetric.mix_hash(D::pubkey(s.as_ref().unwrap()).as_slice());
103                    } else {
104                        symmetric.mix_hash(rs.as_ref().unwrap().as_slice());
105                    }
106                }
107                _ => panic!("Unexpected token in pre message"),
108            }
109        }
110        for t in pattern.get_pre_r() {
111            match *t {
112                Token::S => {
113                    if is_initiator {
114                        symmetric.mix_hash(rs.as_ref().unwrap().as_slice());
115                    } else {
116                        symmetric.mix_hash(D::pubkey(s.as_ref().unwrap()).as_slice());
117                    }
118                }
119                Token::E => {
120                    if is_initiator {
121                        let re = re.as_ref().unwrap().as_slice();
122                        symmetric.mix_hash(re);
123                        if pattern_has_psk {
124                            symmetric.mix_key(re);
125                        }
126                    } else {
127                        let e = D::pubkey(e.as_ref().unwrap());
128                        symmetric.mix_hash(e.as_slice());
129                        if pattern_has_psk {
130                            symmetric.mix_key(e.as_slice());
131                        }
132                    }
133                }
134                _ => panic!("Unexpected token in pre message"),
135            }
136        }
137
138        HandshakeState {
139            symmetric,
140            s,
141            e,
142            rs,
143            re,
144            is_initiator,
145            pattern,
146            message_index: 0,
147            pattern_has_psk,
148            psks: ArrayVec::new(),
149        }
150    }
151
152    /// Calculate the size overhead of the next message.
153    ///
154    /// # Panics
155    ///
156    /// If these is no more message to read/write, i.e., if the handshake is
157    /// already completed.
158    pub fn get_next_message_overhead(&self) -> usize {
159        let m = self.pattern.get_message_pattern(self.message_index);
160
161        let mut overhead = 0;
162
163        let mut has_key = self.symmetric.has_key();
164
165        for &t in m {
166            match t {
167                Token::E => {
168                    overhead += D::Pubkey::len();
169                    if self.pattern_has_psk {
170                        has_key = true;
171                    }
172                }
173                Token::S => {
174                    overhead += D::Pubkey::len();
175                    if has_key {
176                        overhead += 16;
177                    }
178                }
179                _ => {
180                    has_key = true;
181                }
182            }
183        }
184
185        if has_key {
186            overhead += 16
187        }
188
189        overhead
190    }
191
192    /// Like [`write_message`](HandshakeState::write_message), but returns a [`Vec`].
193    #[cfg(any(feature = "use_std", feature = "use_alloc"))]
194    pub fn write_message_vec(&mut self, payload: &[u8]) -> Result<Vec<u8>, Error> {
195        let mut out = vec![0u8; payload.len() + self.get_next_message_overhead()];
196        self.write_message(payload, &mut out)?;
197        Ok(out)
198    }
199
200    /// Takes a payload and write the generated handshake message to
201    /// `out`.
202    ///
203    /// # Error Kinds
204    ///
205    /// - [DH](ErrorKind::DH): DH operation failed.
206    /// - [NeedPSK](ErrorKind::NeedPSK): A PSK token is encountered but none is available.
207    ///
208    /// # Panics
209    ///
210    /// * If a required static key is not set.
211    ///
212    /// * If `out.len() != payload.len() + self.get_next_message_overhead()`.
213    ///
214    /// * If it is not our turn to write.
215    ///
216    /// * If the handshake has already completed.
217    pub fn write_message(&mut self, payload: &[u8], out: &mut [u8]) -> Result<(), Error> {
218        debug_assert_eq!(out.len(), payload.len() + self.get_next_message_overhead());
219
220        // Check that it is our turn to send.
221        assert!(self.is_write_turn());
222
223        // Get the message pattern.
224        let m = self.pattern.get_message_pattern(self.message_index);
225        self.message_index += 1;
226
227        let mut cur: usize = 0;
228        // Process tokens.
229        for t in m {
230            match *t {
231                Token::E => {
232                    if self.e.is_none() {
233                        self.e = Some(D::genkey());
234                    }
235                    let e_pk = D::pubkey(self.e.as_ref().unwrap());
236                    self.symmetric.mix_hash(e_pk.as_slice());
237                    if self.pattern_has_psk {
238                        self.symmetric.mix_key(e_pk.as_slice());
239                    }
240                    out[cur..cur + D::Pubkey::len()].copy_from_slice(e_pk.as_slice());
241                    cur += D::Pubkey::len();
242                }
243                Token::S => {
244                    let len = if self.symmetric.has_key() {
245                        D::Pubkey::len() + 16
246                    } else {
247                        D::Pubkey::len()
248                    };
249
250                    let encrypted_s_out = &mut out[cur..cur + len];
251                    self.symmetric.encrypt_and_hash(
252                        D::pubkey(self.s.as_ref().unwrap()).as_slice(),
253                        encrypted_s_out,
254                    );
255                    cur += len;
256                }
257                Token::PSK => {
258                    if let Some(psk) = self.psks.pop_at(0) {
259                        self.symmetric.mix_key_and_hash(&psk);
260                    } else {
261                        return Err(Error::need_psk());
262                    }
263                }
264                t => {
265                    let dh_result = self.perform_dh(t).map_err(|_| Error::dh())?;
266                    self.symmetric.mix_key(dh_result.as_slice());
267                }
268            }
269        }
270
271        self.symmetric.encrypt_and_hash(payload, &mut out[cur..]);
272        Ok(())
273    }
274
275    /// Takes a handshake message, process it and update our internal
276    /// state, and write the encapsulated payload to `out`.
277    ///
278    /// # Error Kinds
279    ///
280    /// - [DH](ErrorKind::DH): DH operation failed.
281    /// - [NeedPSK](ErrorKind::NeedPSK): A PSK token is encountered but none is
282    ///   available.
283    /// - [Decryption](ErrorKind::Decryption): Decryption failed.
284    ///
285    /// # Error Recovery
286    ///
287    /// If [`read_message`](HandshakeState::read_message) fails, the whole
288    /// [`HandshakeState`] may be in invalid state and should not be used to
289    /// read or write any further messages. (But
290    /// [`get_re()`](HandshakeState::get_re) and
291    /// [`get_rs()`](HandshakeState::get_rs) is allowed.) In case error recovery
292    /// is desirable, [`clone`](Clone::clone) the [`HandshakeState`] before
293    /// calling [`read_message`](HandshakeState::read_message).
294    ///
295    /// # Panics
296    ///
297    /// * If `out.len() + self.get_next_message_overhead() != data.len()`.
298    ///
299    ///   (Notes that this implies `data.len() >= overhead`.)
300    ///
301    /// * If a required static key is not set.
302    ///
303    /// * If it is not our turn to read.
304    ///
305    /// * If the handshake has already completed.
306    pub fn read_message(&mut self, data: &[u8], out: &mut [u8]) -> Result<(), Error> {
307        debug_assert_eq!(out.len() + self.get_next_message_overhead(), data.len());
308
309        assert!(!self.is_write_turn());
310
311        // Get the message pattern.
312        let m = self.pattern.get_message_pattern(self.message_index);
313        self.message_index += 1;
314
315        let mut data = data;
316        // Consume the next `n` bytes of data.
317        let mut get = |n| {
318            let ret = &data[..n];
319            data = &data[n..];
320            ret
321        };
322
323        // Process tokens.
324        for t in m {
325            match *t {
326                Token::E => {
327                    let re = D::Pubkey::from_slice(get(D::Pubkey::len()));
328                    self.symmetric.mix_hash(re.as_slice());
329                    if self.pattern_has_psk {
330                        self.symmetric.mix_key(re.as_slice());
331                    }
332                    self.re = Some(re);
333                }
334                Token::S => {
335                    let temp = get(if self.symmetric.has_key() {
336                        D::Pubkey::len() + 16
337                    } else {
338                        D::Pubkey::len()
339                    });
340                    let mut rs = D::Pubkey::new();
341                    self.symmetric
342                        .decrypt_and_hash(temp, rs.as_mut())
343                        .map_err(|_| Error::decryption())?;
344                    self.rs = Some(rs);
345                }
346                Token::PSK => {
347                    if let Some(psk) = self.psks.pop_at(0) {
348                        self.symmetric.mix_key_and_hash(&psk);
349                    } else {
350                        return Err(Error::need_psk());
351                    }
352                }
353                t => {
354                    let dh_result = self.perform_dh(t).map_err(|_| Error::dh())?;
355                    self.symmetric.mix_key(dh_result.as_slice());
356                }
357            }
358        }
359
360        self.symmetric
361            .decrypt_and_hash(data, out)
362            .map_err(|_| Error::decryption())
363    }
364
365    /// Similar to [`read_message`](HandshakeState::read_message), but returns
366    /// result as a [`Vec`].
367    ///
368    /// In addition to possible errors from
369    /// [`read_message`](HandshakeState::read_message),
370    /// [TooShort](ErrorKind::TooShort) may be returned.
371    #[cfg(any(feature = "use_std", feature = "use_alloc"))]
372    pub fn read_message_vec(&mut self, data: &[u8]) -> Result<Vec<u8>, Error> {
373        let overhead = self.get_next_message_overhead();
374        if data.len() < overhead {
375            Err(Error::too_short())
376        } else {
377            let mut out = vec![0u8; data.len() - overhead];
378            self.read_message(data, &mut out)?;
379            Ok(out)
380        }
381    }
382
383    /// Push a PSK to the PSK-queue.
384    ///
385    /// # Panics
386    ///
387    /// If the PSK-queue becomes longer than 4.
388    pub fn push_psk(&mut self, psk: &[u8]) {
389        self.psks.push(U8Array::from_slice(psk));
390    }
391
392    /// Whether handshake has completed.
393    pub fn completed(&self) -> bool {
394        self.message_index == self.pattern.get_message_patterns_len()
395    }
396
397    /// Get handshake hash. Useful for e.g., channel binding.
398    pub fn get_hash(&self) -> &[u8] {
399        self.symmetric.get_hash()
400    }
401
402    /// Get ciphers that can be used to encrypt/decrypt further messages. The
403    /// first [`CipherState`] is for initiator to responder, and the second for
404    /// responder to initiator.
405    ///
406    /// Should be called after handshake is
407    /// [`completed`](HandshakeState::completed).
408    pub fn get_ciphers(&self) -> (CipherState<C>, CipherState<C>) {
409        self.symmetric.split()
410    }
411
412    /// Get remote static pubkey, if available.
413    pub fn get_rs(&self) -> Option<D::Pubkey> {
414        self.rs.as_ref().map(U8Array::clone)
415    }
416
417    /// Get remote semi-ephemeral pubkey.
418    ///
419    /// Returns [`None`](None) if we do not know.
420    ///
421    /// Useful for noise-pipes.
422    pub fn get_re(&self) -> Option<D::Pubkey> {
423        self.re.as_ref().map(U8Array::clone)
424    }
425
426    /// Get whether this [`HandshakeState`] is created as initiator.
427    pub fn get_is_initiator(&self) -> bool {
428        self.is_initiator
429    }
430
431    /// Get handshake pattern this [`HandshakeState`] uses.
432    pub fn get_pattern(&self) -> &HandshakePattern {
433        &self.pattern
434    }
435
436    /// Check whether it is our turn to send in the handshake state.
437    pub fn is_write_turn(&self) -> bool {
438        self.message_index % 2 == if self.is_initiator { 0 } else { 1 }
439    }
440
441    fn perform_dh(&self, t: Token) -> Result<D::Output, ()> {
442        let dh = |a: Option<&D::Key>, b: Option<&D::Pubkey>| D::dh(a.unwrap(), b.unwrap());
443
444        match t {
445            Token::EE => dh(self.e.as_ref(), self.re.as_ref()),
446            Token::ES => {
447                if self.is_initiator {
448                    dh(self.e.as_ref(), self.rs.as_ref())
449                } else {
450                    dh(self.s.as_ref(), self.re.as_ref())
451                }
452            }
453            Token::SE => {
454                if self.is_initiator {
455                    dh(self.s.as_ref(), self.re.as_ref())
456                } else {
457                    dh(self.e.as_ref(), self.rs.as_ref())
458                }
459            }
460            Token::SS => dh(self.s.as_ref(), self.rs.as_ref()),
461            _ => unreachable!(),
462        }
463    }
464}
465
/// Handshake error.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Error {
    kind: ErrorKind,
}

/// Handshake error kind.
#[derive(Debug, PartialEq, Eq, Copy, Clone)]
pub enum ErrorKind {
    /// A DH operation has failed.
    DH,
    /// A PSK is needed, but none is available.
    NeedPSK,
    /// Decryption failed.
    Decryption,
    /// The message is too short, and impossible to read.
    TooShort,
}

impl ErrorKind {
    /// Short human-readable description, shared by `Display` and
    /// `std::error::Error::description`.
    fn as_str(self) -> &'static str {
        match self {
            ErrorKind::DH => "DH error",
            ErrorKind::NeedPSK => "Need PSK",
            ErrorKind::Decryption => "Decryption failed",
            ErrorKind::TooShort => "Message is too short",
        }
    }
}

impl Display for ErrorKind {
    fn fmt(&self, fmt: &mut Formatter<'_>) -> Result<(), FmtError> {
        fmt.write_str(self.as_str())
    }
}

impl Error {
    fn dh() -> Error {
        Error {
            kind: ErrorKind::DH,
        }
    }

    fn need_psk() -> Error {
        Error {
            kind: ErrorKind::NeedPSK,
        }
    }

    fn decryption() -> Error {
        Error {
            kind: ErrorKind::Decryption,
        }
    }

    fn too_short() -> Error {
        Error {
            kind: ErrorKind::TooShort,
        }
    }

    /// Error kind.
    pub fn kind(&self) -> ErrorKind {
        self.kind
    }
}

impl Display for Error {
    /// Formats the human-readable message for the error kind (not the
    /// `Debug` representation).
    fn fmt(&self, fmt: &mut Formatter<'_>) -> Result<(), FmtError> {
        Display::fmt(&self.kind, fmt)
    }
}

#[cfg(feature = "use_std")]
impl ::std::error::Error for Error {
    fn description(&self) -> &'static str {
        self.kind.as_str()
    }
}

534/// Builder for `HandshakeState`.
535pub struct HandshakeStateBuilder<'a, D: DH> {
536    pattern: Option<HandshakePattern>,
537    is_initiator: Option<bool>,
538    prologue: Option<&'a [u8]>,
539    s: Option<D::Key>,
540    e: Option<D::Key>,
541    rs: Option<D::Pubkey>,
542    re: Option<D::Pubkey>,
543}
544
545impl<'a, D: DH> Default for HandshakeStateBuilder<'a, D> {
546    fn default() -> Self {
547        HandshakeStateBuilder::new()
548    }
549}
550
551impl<'a, D> HandshakeStateBuilder<'a, D>
552where
553    D: DH,
554{
555    /// Create a new [`HandshakeStateBuilder`].
556    pub fn new() -> Self {
557        HandshakeStateBuilder {
558            pattern: None,
559            is_initiator: None,
560            prologue: None,
561            s: None,
562            e: None,
563            rs: None,
564            re: None,
565        }
566    }
567
568    /// Set handshake pattern.
569    pub fn set_pattern(&mut self, p: HandshakePattern) -> &mut Self {
570        self.pattern = Some(p);
571        self
572    }
573
574    /// Set whether the [`HandshakeState`] is initiator.
575    pub fn set_is_initiator(&mut self, is: bool) -> &mut Self {
576        self.is_initiator = Some(is);
577        self
578    }
579
580    /// Set prologue.
581    pub fn set_prologue(&mut self, prologue: &'a [u8]) -> &mut Self {
582        self.prologue = Some(prologue);
583        self
584    }
585
586    /// Set ephemeral key.
587    ///
588    /// This is not encouraged and usually not necessary. Cf.
589    /// [`HandshakeState::new()`].
590    pub fn set_e(&mut self, e: D::Key) -> &mut Self {
591        self.e = Some(e);
592        self
593    }
594
595    /// Set static key.
596    pub fn set_s(&mut self, s: D::Key) -> &mut Self {
597        self.s = Some(s);
598        self
599    }
600
601    /// Set peer semi-ephemeral public key.
602    ///
603    /// Usually used in fallback patterns.
604    pub fn set_re(&mut self, re: D::Pubkey) -> &mut Self {
605        self.re = Some(re);
606        self
607    }
608
609    /// Set peer static public key.
610    pub fn set_rs(&mut self, rs: D::Pubkey) -> &mut Self {
611        self.rs = Some(rs);
612        self
613    }
614
615    /// Build [`HandshakeState`].
616    ///
617    /// # Panics
618    ///
619    /// If any of [`set_pattern`](HandshakeStateBuilder::set_pattern),
620    /// [`set_prologue`](HandshakeStateBuilder::set_prologue) or
621    /// [`set_is_initiator`](HandshakeStateBuilder::set_is_initiator) has not
622    /// been called yet.
623    pub fn build_handshake_state<C, H>(self) -> HandshakeState<D, C, H>
624    where
625        C: Cipher,
626        H: Hash,
627    {
628        HandshakeState::new(
629            self.pattern.unwrap(),
630            self.is_initiator.unwrap(),
631            self.prologue.unwrap(),
632            self.s,
633            self.e,
634            self.rs,
635            self.re,
636        )
637    }
638}