cl_noise_protocol/
handshakestate.rs

1use crate::cipherstate::CipherState;
2use crate::handshakepattern::{HandshakePattern, Token};
3use crate::symmetricstate::SymmetricState;
4use crate::traits::{Cipher, Hash, U8Array, Unspecified, DH};
5use arrayvec::{ArrayString, ArrayVec};
6use core::fmt::{Display, Error as FmtError, Formatter, Write};
7
8#[cfg(feature = "alloc")]
9use alloc::vec::Vec;
10
11/// Noise handshake state.
12pub struct HandshakeState<D: DH, C: Cipher, H: Hash> {
13    symmetric: SymmetricState<C, H>,
14    s: Option<D::Key>,
15    e: Option<D::Key>,
16    rs: Option<D::Pubkey>,
17    re: Option<D::Pubkey>,
18    is_initiator: bool,
19    pattern: HandshakePattern,
20    message_index: usize,
21    pattern_has_psk: bool,
22    psks: ArrayVec<[u8; 32], 4>,
23}
24
25impl<D, C, H> Clone for HandshakeState<D, C, H>
26where
27    D: DH,
28    <D as DH>::Key: Clone,
29    C: Cipher,
30    <C as Cipher>::Key: Clone,
31    H: Hash,
32{
33    fn clone(&self) -> Self {
34        Self {
35            symmetric: self.symmetric.clone(),
36            s: self.s.clone(),
37            e: self.e.clone(),
38            rs: self.rs.as_ref().map(U8Array::clone),
39            re: self.re.as_ref().map(U8Array::clone),
40            is_initiator: self.is_initiator,
41            pattern: self.pattern.clone(),
42            message_index: self.message_index,
43            pattern_has_psk: self.pattern_has_psk,
44            psks: self.psks.clone(),
45        }
46    }
47}
48
49impl<D, C, H> HandshakeState<D, C, H>
50where
51    D: DH,
52    C: Cipher,
53    H: Hash,
54{
55    /// Get protocol name, e.g. Noise_IK_25519_ChaChaPoly_BLAKE2s.
56    fn get_name(pattern_name: &str) -> ArrayString<256> {
57        let mut ret = ArrayString::new();
58        write!(
59            &mut ret,
60            "Noise_{}_{}_{}_{}",
61            pattern_name,
62            D::name(),
63            C::name(),
64            H::name()
65        )
66        .unwrap();
67        ret
68    }
69
70    /// Initialize a handshake state.
71    ///
72    /// If `e` is [`None`], a new ephemeral key will be generated if necessary
73    /// when [`write_message`](HandshakeState::write_message).
74    ///
75    /// # Setting Explicit Ephemeral Key
76    ///
77    /// An explicit `e` should only be specified for testing purposes, or in
78    /// fallback patterns. If you do pass in an explicit `e`, [`HandshakeState`]
79    /// will use it as is and will not generate new ephemeral keys in
80    /// [`write_message`](HandshakeState::write_message).
81    pub fn new<P>(
82        pattern: HandshakePattern,
83        is_initiator: bool,
84        prologue: P,
85        s: Option<D::Key>,
86        e: Option<D::Key>,
87        rs: Option<D::Pubkey>,
88        re: Option<D::Pubkey>,
89    ) -> Self
90    where
91        P: AsRef<[u8]>,
92    {
93        let mut symmetric = SymmetricState::new(Self::get_name(pattern.get_name()).as_bytes());
94        let pattern_has_psk = pattern.has_psk();
95
96        // Mix in prologue.
97        symmetric.mix_hash(prologue.as_ref());
98
99        // Mix in static keys known ahead of time.
100        for t in pattern.get_pre_i() {
101            match *t {
102                Token::S => {
103                    if is_initiator {
104                        symmetric.mix_hash(D::pubkey(s.as_ref().unwrap()).as_slice());
105                    } else {
106                        symmetric.mix_hash(rs.as_ref().unwrap().as_slice());
107                    }
108                }
109                _ => panic!("Unexpected token in pre message"),
110            }
111        }
112        for t in pattern.get_pre_r() {
113            match *t {
114                Token::S => {
115                    if is_initiator {
116                        symmetric.mix_hash(rs.as_ref().unwrap().as_slice());
117                    } else {
118                        symmetric.mix_hash(D::pubkey(s.as_ref().unwrap()).as_slice());
119                    }
120                }
121                Token::E => {
122                    if is_initiator {
123                        let re = re.as_ref().unwrap().as_slice();
124                        symmetric.mix_hash(re);
125                        if pattern_has_psk {
126                            symmetric.mix_key(re);
127                        }
128                    } else {
129                        let e = D::pubkey(e.as_ref().unwrap());
130                        symmetric.mix_hash(e.as_slice());
131                        if pattern_has_psk {
132                            symmetric.mix_key(e.as_slice());
133                        }
134                    }
135                }
136                _ => panic!("Unexpected token in pre message"),
137            }
138        }
139
140        HandshakeState {
141            symmetric,
142            s,
143            e,
144            rs,
145            re,
146            is_initiator,
147            pattern,
148            message_index: 0,
149            pattern_has_psk,
150            psks: ArrayVec::new(),
151        }
152    }
153
154    /// Calculate the size overhead of the next message.
155    ///
156    /// # Panics
157    ///
158    /// If these is no more message to read/write, i.e., if the handshake is
159    /// already completed.
160    pub fn get_next_message_overhead(&self) -> usize {
161        let m = self.pattern.get_message_pattern(self.message_index);
162
163        let mut overhead = 0;
164
165        let mut has_key = self.symmetric.has_key();
166
167        for &t in m {
168            match t {
169                Token::E => {
170                    overhead += D::Pubkey::len();
171                    if self.pattern_has_psk {
172                        has_key = true;
173                    }
174                }
175                Token::S => {
176                    overhead += D::Pubkey::len();
177                    if has_key {
178                        overhead += 16;
179                    }
180                }
181                _ => {
182                    has_key = true;
183                }
184            }
185        }
186
187        if has_key {
188            overhead += 16
189        }
190
191        overhead
192    }
193
194    /// Like [`write_message`](HandshakeState::write_message), but returns a [`Vec`].
195    #[cfg(feature = "alloc")]
196    pub fn write_message_vec(&mut self, payload: &[u8]) -> Result<Vec<u8>, Error> {
197        let mut out = vec![0u8; payload.len() + self.get_next_message_overhead()];
198        self.write_message(payload, &mut out)?;
199        Ok(out)
200    }
201
202    /// Takes a payload and write the generated handshake message to
203    /// `out`.
204    ///
205    /// # Error Kinds
206    ///
207    /// - [DH](ErrorKind::DH): DH operation failed.
208    /// - [NeedPSK](ErrorKind::NeedPSK): A PSK token is encountered but none is available.
209    ///
210    /// # Panics
211    ///
212    /// * If a required static key is not set.
213    ///
214    /// * If `out.len() != payload.len() + self.get_next_message_overhead()`.
215    ///
216    /// * If it is not our turn to write.
217    ///
218    /// * If the handshake has already completed.
219    pub fn write_message(&mut self, payload: &[u8], out: &mut [u8]) -> Result<(), Error> {
220        debug_assert_eq!(out.len(), payload.len() + self.get_next_message_overhead());
221
222        // Check that it is our turn to send.
223        assert!(self.is_write_turn());
224
225        // Get the message pattern.
226        let m = self.pattern.get_message_pattern(self.message_index);
227        self.message_index += 1;
228
229        let mut cur: usize = 0;
230        // Process tokens.
231        for t in m {
232            match *t {
233                Token::E => {
234                    if self.e.is_none() {
235                        self.e = Some(D::genkey());
236                    }
237                    let e_pk = D::pubkey(self.e.as_ref().unwrap());
238                    self.symmetric.mix_hash(e_pk.as_slice());
239                    if self.pattern_has_psk {
240                        self.symmetric.mix_key(e_pk.as_slice());
241                    }
242                    out[cur..cur + D::Pubkey::len()].copy_from_slice(e_pk.as_slice());
243                    cur += D::Pubkey::len();
244                }
245                Token::S => {
246                    let len = if self.symmetric.has_key() {
247                        D::Pubkey::len() + 16
248                    } else {
249                        D::Pubkey::len()
250                    };
251
252                    let encrypted_s_out = &mut out[cur..cur + len];
253                    self.symmetric.encrypt_and_hash(
254                        D::pubkey(self.s.as_ref().unwrap()).as_slice(),
255                        encrypted_s_out,
256                    );
257                    cur += len;
258                }
259                Token::PSK => {
260                    if let Some(psk) = self.psks.pop_at(0) {
261                        self.symmetric.mix_key_and_hash(&psk);
262                    } else {
263                        return Err(Error::need_psk());
264                    }
265                }
266                t => {
267                    let dh_result = self.perform_dh(t).map_err(|_| Error::dh())?;
268                    self.symmetric.mix_key(dh_result.as_slice());
269                }
270            }
271        }
272
273        self.symmetric.encrypt_and_hash(payload, &mut out[cur..]);
274        Ok(())
275    }
276
277    /// Takes a handshake message, process it and update our internal
278    /// state, and write the encapsulated payload to `out`.
279    ///
280    /// # Error Kinds
281    ///
282    /// - [DH](ErrorKind::DH): DH operation failed.
283    /// - [NeedPSK](ErrorKind::NeedPSK): A PSK token is encountered but none is
284    ///   available.
285    /// - [Decryption](ErrorKind::Decryption): Decryption failed.
286    ///
287    /// # Error Recovery
288    ///
289    /// If [`read_message`](HandshakeState::read_message) fails, the whole
290    /// [`HandshakeState`] may be in invalid state and should not be used to
291    /// read or write any further messages. (But
292    /// [`get_re()`](HandshakeState::get_re) and
293    /// [`get_rs()`](HandshakeState::get_rs) is allowed.) In case error recovery
294    /// is desirable, [`clone`](Clone::clone) the [`HandshakeState`] before
295    /// calling [`read_message`](HandshakeState::read_message).
296    ///
297    /// # Panics
298    ///
299    /// * If `out.len() + self.get_next_message_overhead() != data.len()`.
300    ///
301    ///   (Notes that this implies `data.len() >= overhead`.)
302    ///
303    /// * If a required static key is not set.
304    ///
305    /// * If it is not our turn to read.
306    ///
307    /// * If the handshake has already completed.
308    pub fn read_message(&mut self, data: &[u8], out: &mut [u8]) -> Result<(), Error> {
309        debug_assert_eq!(out.len() + self.get_next_message_overhead(), data.len());
310
311        assert!(!self.is_write_turn());
312
313        // Get the message pattern.
314        let m = self.pattern.get_message_pattern(self.message_index);
315        self.message_index += 1;
316
317        let mut data = data;
318        // Consume the next `n` bytes of data.
319        let mut get = |n| {
320            let ret = &data[..n];
321            data = &data[n..];
322            ret
323        };
324
325        // Process tokens.
326        for t in m {
327            match *t {
328                Token::E => {
329                    let re = D::Pubkey::from_slice(get(D::Pubkey::len()));
330                    self.symmetric.mix_hash(re.as_slice());
331                    if self.pattern_has_psk {
332                        self.symmetric.mix_key(re.as_slice());
333                    }
334                    self.re = Some(re);
335                }
336                Token::S => {
337                    let temp = get(if self.symmetric.has_key() {
338                        D::Pubkey::len() + 16
339                    } else {
340                        D::Pubkey::len()
341                    });
342                    let mut rs = D::Pubkey::new();
343                    self.symmetric
344                        .decrypt_and_hash(temp, rs.as_mut())
345                        .map_err(|_| Error::decryption())?;
346                    self.rs = Some(rs);
347                }
348                Token::PSK => {
349                    if let Some(psk) = self.psks.pop_at(0) {
350                        self.symmetric.mix_key_and_hash(&psk);
351                    } else {
352                        return Err(Error::need_psk());
353                    }
354                }
355                t => {
356                    let dh_result = self.perform_dh(t).map_err(|_| Error::dh())?;
357                    self.symmetric.mix_key(dh_result.as_slice());
358                }
359            }
360        }
361
362        self.symmetric
363            .decrypt_and_hash(data, out)
364            .map_err(|_| Error::decryption())
365    }
366
367    /// Similar to [`read_message`](HandshakeState::read_message), but returns
368    /// result as a [`Vec`].
369    ///
370    /// In addition to possible errors from
371    /// [`read_message`](HandshakeState::read_message),
372    /// [TooShort](ErrorKind::TooShort) may be returned.
373    #[cfg(feature = "alloc")]
374    pub fn read_message_vec(&mut self, data: &[u8]) -> Result<Vec<u8>, Error> {
375        let overhead = self.get_next_message_overhead();
376        if data.len() < overhead {
377            Err(Error::too_short())
378        } else {
379            let mut out = vec![0u8; data.len() - overhead];
380            self.read_message(data, &mut out)?;
381            Ok(out)
382        }
383    }
384
385    /// Push a PSK to the PSK-queue.
386    ///
387    /// # Panics
388    ///
389    /// If the PSK-queue becomes longer than 4.
390    pub fn push_psk(&mut self, psk: &[u8]) {
391        self.psks.push(U8Array::from_slice(psk));
392    }
393
394    /// Whether handshake has completed.
395    pub fn completed(&self) -> bool {
396        self.message_index == self.pattern.get_message_patterns_len()
397    }
398
399    /// Get handshake hash. Useful for e.g., channel binding.
400    pub fn get_hash(&self) -> &[u8] {
401        self.symmetric.get_hash()
402    }
403
404    /// Get ciphers that can be used to encrypt/decrypt further messages. The
405    /// first [`CipherState`] is for initiator to responder, and the second for
406    /// responder to initiator.
407    ///
408    /// Should be called after handshake is
409    /// [`completed`](HandshakeState::completed).
410    pub fn get_ciphers(&self) -> (CipherState<C>, CipherState<C>) {
411        self.symmetric.split()
412    }
413
414    /// Get remote static pubkey, if available.
415    pub fn get_rs(&self) -> Option<D::Pubkey> {
416        self.rs.as_ref().map(U8Array::clone)
417    }
418
419    /// Get remote semi-ephemeral pubkey.
420    ///
421    /// Returns [`None`](None) if we do not know.
422    ///
423    /// Useful for noise-pipes.
424    pub fn get_re(&self) -> Option<D::Pubkey> {
425        self.re.as_ref().map(U8Array::clone)
426    }
427
428    /// Get whether this [`HandshakeState`] is created as initiator.
429    pub fn get_is_initiator(&self) -> bool {
430        self.is_initiator
431    }
432
433    /// Get handshake pattern this [`HandshakeState`] uses.
434    pub fn get_pattern(&self) -> &HandshakePattern {
435        &self.pattern
436    }
437
438    /// Check whether it is our turn to send in the handshake state.
439    pub fn is_write_turn(&self) -> bool {
440        self.message_index % 2 == if self.is_initiator { 0 } else { 1 }
441    }
442
443    fn perform_dh(&self, t: Token) -> Result<D::Output, Unspecified> {
444        let dh = |a: Option<&D::Key>, b: Option<&D::Pubkey>| D::dh(a.unwrap(), b.unwrap());
445
446        match t {
447            Token::EE => dh(self.e.as_ref(), self.re.as_ref()),
448            Token::ES => {
449                if self.is_initiator {
450                    dh(self.e.as_ref(), self.rs.as_ref())
451                } else {
452                    dh(self.s.as_ref(), self.re.as_ref())
453                }
454            }
455            Token::SE => {
456                if self.is_initiator {
457                    dh(self.s.as_ref(), self.re.as_ref())
458                } else {
459                    dh(self.e.as_ref(), self.rs.as_ref())
460                }
461            }
462            Token::SS => dh(self.s.as_ref(), self.rs.as_ref()),
463            _ => unreachable!(),
464        }
465    }
466}
467
468/// Handshake error.
469#[derive(Debug)]
470pub struct Error {
471    kind: ErrorKind,
472}
473
/// Category of a handshake [`Error`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ErrorKind {
    /// A DH operation failed.
    DH,
    /// A PSK token was encountered, but the PSK-queue was empty.
    NeedPSK,
    /// Decrypting part of a handshake message failed.
    Decryption,
    /// The message is shorter than the overhead of the next handshake
    /// message, so it cannot possibly be read.
    TooShort,
}
486
487impl Error {
488    fn dh() -> Error {
489        Error {
490            kind: ErrorKind::DH,
491        }
492    }
493
494    fn need_psk() -> Error {
495        Error {
496            kind: ErrorKind::NeedPSK,
497        }
498    }
499
500    fn decryption() -> Error {
501        Error {
502            kind: ErrorKind::Decryption,
503        }
504    }
505
506    fn too_short() -> Error {
507        Error {
508            kind: ErrorKind::TooShort,
509        }
510    }
511
512    /// Error kind.
513    pub fn kind(&self) -> ErrorKind {
514        self.kind
515    }
516}
517
518impl Display for Error {
519    fn fmt(&self, fmt: &mut Formatter<'_>) -> Result<(), FmtError> {
520        write!(fmt, "{:?}", self)
521    }
522}
523
#[cfg(feature = "std")]
impl ::std::error::Error for Error {
    /// Static, human-readable summary of the error kind.
    fn description(&self) -> &'static str {
        match self.kind() {
            ErrorKind::TooShort => "Message is too short",
            ErrorKind::Decryption => "Decryption failed",
            ErrorKind::NeedPSK => "Need PSK",
            ErrorKind::DH => "DH error",
        }
    }
}
535
536/// Builder for `HandshakeState`.
537pub struct HandshakeStateBuilder<'a, D: DH> {
538    pattern: Option<HandshakePattern>,
539    is_initiator: Option<bool>,
540    prologue: Option<&'a [u8]>,
541    s: Option<D::Key>,
542    e: Option<D::Key>,
543    rs: Option<D::Pubkey>,
544    re: Option<D::Pubkey>,
545}
546
547impl<'a, D: DH> Default for HandshakeStateBuilder<'a, D> {
548    fn default() -> Self {
549        HandshakeStateBuilder::new()
550    }
551}
552
553impl<'a, D> HandshakeStateBuilder<'a, D>
554where
555    D: DH,
556{
557    /// Create a new [`HandshakeStateBuilder`].
558    pub fn new() -> Self {
559        HandshakeStateBuilder {
560            pattern: None,
561            is_initiator: None,
562            prologue: None,
563            s: None,
564            e: None,
565            rs: None,
566            re: None,
567        }
568    }
569
570    /// Set handshake pattern.
571    pub fn set_pattern(&mut self, p: HandshakePattern) -> &mut Self {
572        self.pattern = Some(p);
573        self
574    }
575
576    /// Set whether the [`HandshakeState`] is initiator.
577    pub fn set_is_initiator(&mut self, is: bool) -> &mut Self {
578        self.is_initiator = Some(is);
579        self
580    }
581
582    /// Set prologue.
583    pub fn set_prologue(&mut self, prologue: &'a [u8]) -> &mut Self {
584        self.prologue = Some(prologue);
585        self
586    }
587
588    /// Set ephemeral key.
589    ///
590    /// This is not encouraged and usually not necessary. Cf.
591    /// [`HandshakeState::new()`].
592    pub fn set_e(&mut self, e: D::Key) -> &mut Self {
593        self.e = Some(e);
594        self
595    }
596
597    /// Set static key.
598    pub fn set_s(&mut self, s: D::Key) -> &mut Self {
599        self.s = Some(s);
600        self
601    }
602
603    /// Set peer semi-ephemeral public key.
604    ///
605    /// Usually used in fallback patterns.
606    pub fn set_re(&mut self, re: D::Pubkey) -> &mut Self {
607        self.re = Some(re);
608        self
609    }
610
611    /// Set peer static public key.
612    pub fn set_rs(&mut self, rs: D::Pubkey) -> &mut Self {
613        self.rs = Some(rs);
614        self
615    }
616
617    /// Build [`HandshakeState`].
618    ///
619    /// # Panics
620    ///
621    /// If any of [`set_pattern`](HandshakeStateBuilder::set_pattern),
622    /// [`set_prologue`](HandshakeStateBuilder::set_prologue) or
623    /// [`set_is_initiator`](HandshakeStateBuilder::set_is_initiator) has not
624    /// been called yet.
625    pub fn build_handshake_state<C, H>(self) -> HandshakeState<D, C, H>
626    where
627        C: Cipher,
628        H: Hash,
629    {
630        HandshakeState::new(
631            self.pattern.unwrap(),
632            self.is_initiator.unwrap(),
633            self.prologue.unwrap(),
634            self.s,
635            self.e,
636            self.rs,
637            self.re,
638        )
639    }
640}