snow/params/patterns.rs

1#![allow(clippy::enum_glob_use)]
2
3#[cfg(not(feature = "std"))]
4use alloc::{vec, vec::Vec};
5
6use crate::error::{Error, PatternProblem};
7use core::{convert::TryFrom, str::FromStr};
8
/// Builds a `MessagePatterns` value from a list of token slices, much like the standard
/// `vec![]` macro. It reserves spare capacity up front: ten entries for the outer list and
/// ten for every inner list. This avoids resizing when modifiers later insert tokens.
macro_rules! message_vec {
    ($($item:expr),*) => ({
        let groups: &[&[Token]] = &[$($item),*];
        let mut messages: MessagePatterns = Vec::with_capacity(10);
        messages.extend(groups.iter().map(|group| {
            let mut tokens = Vec::with_capacity(10);
            tokens.extend_from_slice(group);
            tokens
        }));
        messages
    });
}
23
/// Generates the enum of all handshake patterns from a flat list of variant names.
///
/// For `$name { A, B, ... }` it emits:
/// * the enum itself (`Copy`, `Clone`, `PartialEq`, `Eq`, `Hash`, `Debug`),
/// * a `FromStr` impl mapping each variant's exact identifier text to that variant
///   (anything else is `PatternProblem::UnsupportedHandshakeType`),
/// * an `as_str` method performing the reverse mapping, and
/// * a hidden `SUPPORTED_HANDSHAKE_PATTERNS` slice listing every variant in declaration order.
///
/// Generating both directions of the string mapping from one list keeps them from drifting
/// apart, which a hand-written list of str -> variant match arms would risk.
macro_rules! pattern_enum {
    // NOTE: see https://danielkeep.github.io/tlborm/book/mbe-macro-rules.html and
    // https://doc.rust-lang.org/rust-by-example/macros.html for a great overview
    // of `macro_rules!`.
    ($name:ident {
        $($variant:ident),* $(,)*
    }) => {
        /// One of the patterns as defined in the
        /// [Handshake Pattern](https://noiseprotocol.org/noise.html#handshake-patterns)
        /// section.
        #[allow(missing_docs)]
        #[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)]
        pub enum $name {
            $($variant),*,
        }

        impl FromStr for $name {
            type Err = Error;

            fn from_str(s: &str) -> Result<Self, Self::Err> {
                use self::$name::*;
                match s {
                    $(stringify!($variant) => Ok($variant),)*
                    _ => Err(PatternProblem::UnsupportedHandshakeType.into()),
                }
            }
        }

        impl $name {
            /// The equivalent of the `ToString` trait, but for `&'static str`.
            #[must_use]
            pub fn as_str(self) -> &'static str {
                use self::$name::*;
                match self {
                    $($variant => stringify!($variant),)*
                }
            }
        }

        // `const` items already have a `'static` lifetime, so it is elided here.
        #[doc(hidden)]
        pub const SUPPORTED_HANDSHAKE_PATTERNS: &[$name] = &[$($name::$variant),*];
    }
}
76
/// The Diffie-Hellman tokens that can appear in a Noise message pattern.
///
/// Each variant stands for the lowercase token of the same name (`ee`, `es`, `se`, `ss`)
/// in the spec's pattern notation.
///
/// See: <https://noiseprotocol.org/noise.html#handshake-patterns>
#[derive(Debug, Clone, Copy, PartialEq)]
pub(crate) enum DhToken {
    /// The `ee` token.
    Ee,
    /// The `es` token.
    Es,
    /// The `se` token.
    Se,
    /// The `ss` token.
    Ss,
}
87
88/// The tokens which describe message patterns.
89///
90/// See: <https://noiseprotocol.org/noise.html#handshake-patterns>
91#[derive(Copy, Clone, PartialEq, Debug)]
92pub(crate) enum Token {
93    E,
94    S,
95    Dh(DhToken),
96    Psk(u8),
97    #[cfg(feature = "hfs")]
98    E1,
99    #[cfg(feature = "hfs")]
100    Ekem1,
101}
102
#[cfg(feature = "hfs")]
impl Token {
    /// Returns `true` when this token is one of the DH tokens (`Token::Dh(_)`).
    fn is_dh(self) -> bool {
        matches!(self, Self::Dh(..))
    }
}
109
110// See the documentation in the macro above.
111pattern_enum! {
112    HandshakePattern {
113        // 7.4. One-way handshake patterns
114        N, X, K,
115
116        // 7.5. Interactive handshake patterns (fundamental)
117        NN, NK, NX, XN, XK, XX, KN, KK, KX, IN, IK, IX,
118
119        // 7.6. Interactive handshake patterns (deferred)
120        NK1, NX1, X1N, X1K, XK1, X1K1, X1X, XX1, X1X1, K1N, K1K, KK1, K1K1, K1X,
121        KX1, K1X1, I1N, I1K, IK1, I1K1, I1X, IX1, I1X1
122    }
123}
124
125impl HandshakePattern {
126    /// If the protocol is one-way only
127    ///
128    /// See: <https://noiseprotocol.org/noise.html#one-way-handshake-patterns>
129    #[must_use]
130    pub fn is_oneway(self) -> bool {
131        matches!(self, N | X | K)
132    }
133
134    /// Whether this pattern requires a long-term static key.
135    #[must_use]
136    pub fn needs_local_static_key(self, initiator: bool) -> bool {
137        if initiator {
138            !matches!(self, N | NN | NK | NX | NK1 | NX1)
139        } else {
140            !matches!(self, NN | XN | KN | IN | X1N | K1N | I1N)
141        }
142    }
143
144    /// Whether this pattern demands a remote public key pre-message.
145    #[rustfmt::skip]
146    #[must_use]    pub fn need_known_remote_pubkey(self, initiator: bool) -> bool {
147        if initiator {
148            matches!(
149                self,
150                N | K | X | NK | XK | KK | IK | NK1 | X1K | XK1 | X1K1 | K1K | KK1 | K1K1 | I1K | IK1 | I1K1
151            )
152        } else {
153            matches!(
154                self,
155                K | KN | KK | KX | K1N | K1K | KK1 | K1K1 | K1X | KX1 | K1X1
156            )
157        }
158    }
159}
160
/// A modifier appended to a base handshake pattern name, as defined in the Noise spec
/// (for example the `psk0` in `XXpsk0`).
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum HandshakeModifier {
    /// Insert a PSK token. The number selects where it goes:
    /// `0` prepends it to the first message, `N >= 1` appends it to message `N`.
    Psk(u8),

    /// Modify the base pattern to its "fallback" form.
    Fallback,

    /// Modify the base pattern to use Hybrid-Forward-Secrecy.
    #[cfg(feature = "hfs")]
    Hfs,
}
174
175impl FromStr for HandshakeModifier {
176    type Err = Error;
177
178    fn from_str(s: &str) -> Result<Self, Self::Err> {
179        match s {
180            s if s.starts_with("psk") => {
181                Ok(HandshakeModifier::Psk(s[3..].parse().map_err(|_| PatternProblem::InvalidPsk)?))
182            },
183            "fallback" => Ok(HandshakeModifier::Fallback),
184            #[cfg(feature = "hfs")]
185            "hfs" => Ok(HandshakeModifier::Hfs),
186            _ => Err(PatternProblem::UnsupportedModifier.into()),
187        }
188    }
189}
190
191/// Handshake modifiers that will be used during key exchange handshake.
192#[derive(Clone, PartialEq, Debug)]
193pub struct HandshakeModifierList {
194    /// List of parsed modifiers.
195    pub list: Vec<HandshakeModifier>,
196}
197
198impl FromStr for HandshakeModifierList {
199    type Err = Error;
200
201    fn from_str(s: &str) -> Result<Self, Self::Err> {
202        if s.is_empty() {
203            Ok(HandshakeModifierList { list: vec![] })
204        } else {
205            let modifier_names = s.split('+');
206            let mut modifiers = vec![];
207            for modifier_name in modifier_names {
208                let modifier: HandshakeModifier = modifier_name.parse()?;
209                if modifiers.contains(&modifier) {
210                    return Err(Error::Pattern(PatternProblem::DuplicateModifier));
211                }
212                modifiers.push(modifier);
213            }
214            Ok(HandshakeModifierList { list: modifiers })
215        }
216    }
217}
218
219/// The pattern/modifier combination choice (no primitives specified)
220/// for a full noise protocol definition.
221#[derive(Clone, PartialEq, Debug)]
222pub struct HandshakeChoice {
223    /// The base pattern itself
224    pub pattern: HandshakePattern,
225
226    /// The modifier(s) requested for the base pattern
227    pub modifiers: HandshakeModifierList,
228}
229
230impl HandshakeChoice {
231    /// Whether the handshake choice includes one or more PSK modifiers.
232    #[must_use]
233    pub fn is_psk(&self) -> bool {
234        for modifier in &self.modifiers.list {
235            if let HandshakeModifier::Psk(_) = *modifier {
236                return true;
237            }
238        }
239        false
240    }
241
242    /// Whether the handshake choice includes the fallback modifier.
243    #[must_use]
244    pub fn is_fallback(&self) -> bool {
245        self.modifiers.list.contains(&HandshakeModifier::Fallback)
246    }
247
248    /// Whether the handshake choice includes the hfs modifier.
249    #[cfg(feature = "hfs")]
250    #[must_use]
251    pub fn is_hfs(&self) -> bool {
252        self.modifiers.list.contains(&HandshakeModifier::Hfs)
253    }
254
255    /// Parse and split a base `HandshakePattern` from its optional modifiers
256    fn parse_pattern_and_modifier(s: &str) -> Result<(HandshakePattern, &str), Error> {
257        for i in (1..=4).rev() {
258            if s.len() > i - 1 && s.is_char_boundary(i) {
259                if let Ok(p) = s[..i].parse() {
260                    return Ok((p, &s[i..]));
261                }
262            }
263        }
264
265        Err(PatternProblem::UnsupportedHandshakeType.into())
266    }
267}
268
269impl FromStr for HandshakeChoice {
270    type Err = Error;
271
272    fn from_str(s: &str) -> Result<Self, Self::Err> {
273        let (pattern, remainder) = Self::parse_pattern_and_modifier(s)?;
274        let modifiers = remainder.parse()?;
275
276        Ok(HandshakeChoice { pattern, modifiers })
277    }
278}
279
280type PremessagePatterns = &'static [Token];
281pub(crate) type MessagePatterns = Vec<Vec<Token>>;
282
283/// The defined token patterns for a given handshake.
284///
285/// See: <https://noiseprotocol.org/noise.html#handshake-patterns>
286#[derive(Debug)]
287pub(crate) struct HandshakeTokens {
288    pub premsg_pattern_i: PremessagePatterns,
289    pub premsg_pattern_r: PremessagePatterns,
290    pub msg_patterns: MessagePatterns,
291}
292
293use self::{DhToken::*, HandshakePattern::*, Token::*};
294
295type Patterns = (PremessagePatterns, PremessagePatterns, MessagePatterns);
296
297impl<'a> TryFrom<&'a HandshakeChoice> for HandshakeTokens {
298    type Error = Error;
299
300    // We're going to ignore the clippy warnings here about this function being too long because
301    // it's essentially a lookup table and not problematic complex logic.
302    #[allow(clippy::cognitive_complexity)]
303    #[allow(clippy::too_many_lines)]
304    fn try_from(handshake: &'a HandshakeChoice) -> Result<Self, Self::Error> {
305        // Hfs cannot be combined with one-way handshake patterns
306        #[cfg(feature = "hfs")]
307        check_hfs_and_oneway_conflict(handshake)?;
308
309        #[rustfmt::skip]
310        let mut patterns: Patterns = match handshake.pattern {
311            N  => (
312                static_slice![Token: ],
313                static_slice![Token: S],
314                message_vec![&[E, Dh(Es)]]
315            ),
316            K  => (
317                static_slice![Token: S],
318                static_slice![Token: S],
319                message_vec![&[E, Dh(Es), Dh(Ss)]]
320            ),
321            X  => (
322                static_slice![Token: ],
323                static_slice![Token: S],
324                message_vec![&[E, Dh(Es), S, Dh(Ss)]]
325            ),
326            NN => (
327                static_slice![Token: ],
328                static_slice![Token: ],
329                message_vec![&[E], &[E, Dh(Ee)]]
330            ),
331            NK => (
332                static_slice![Token: ],
333                static_slice![Token: S],
334                message_vec![&[E, Dh(Es)], &[E, Dh(Ee)]]
335            ),
336            NX => (
337                static_slice![Token: ],
338                static_slice![Token: ],
339                message_vec![&[E], &[E, Dh(Ee), S, Dh(Es)]]
340            ),
341            XN => (
342                static_slice![Token: ],
343                static_slice![Token: ],
344                message_vec![&[E], &[E, Dh(Ee)], &[S, Dh(Se)]]
345            ),
346            XK => (
347                static_slice![Token: ],
348                static_slice![Token: S],
349                message_vec![&[E, Dh(Es)], &[E, Dh(Ee)], &[S, Dh(Se)]]
350            ),
351            XX => (
352                static_slice![Token: ],
353                static_slice![Token: ],
354                message_vec![&[E], &[E, Dh(Ee), S, Dh(Es)], &[S, Dh(Se)]],
355            ),
356            KN => (
357                static_slice![Token: S],
358                static_slice![Token: ],
359                message_vec![&[E], &[E, Dh(Ee), Dh(Se)]],
360            ),
361            KK => (
362                static_slice![Token: S],
363                static_slice![Token: S],
364                message_vec![&[E, Dh(Es), Dh(Ss)], &[E, Dh(Ee), Dh(Se)]],
365            ),
366            KX => (
367                static_slice![Token: S],
368                static_slice![Token: ],
369                message_vec![&[E], &[E, Dh(Ee), Dh(Se), S, Dh(Es)]],
370            ),
371            IN => (
372                static_slice![Token: ],
373                static_slice![Token: ],
374                message_vec![&[E, S], &[E, Dh(Ee), Dh(Se)]],
375            ),
376            IK => (
377                static_slice![Token: ],
378                static_slice![Token: S],
379                message_vec![&[E, Dh(Es), S, Dh(Ss)], &[E, Dh(Ee), Dh(Se)]],
380            ),
381            IX => (
382                static_slice![Token: ],
383                static_slice![Token: ],
384                message_vec![&[E, S], &[E, Dh(Ee), Dh(Se), S, Dh(Es)]],
385            ),
386            NK1 => (
387                static_slice![Token: ],
388                static_slice![Token: S],
389                message_vec![&[E], &[E, Dh(Ee), Dh(Es)]],
390            ),
391            NX1 => (
392                static_slice![Token: ],
393                static_slice![Token: ],
394                message_vec![&[E], &[E, Dh(Ee), S], &[Dh(Es)]]
395            ),
396            X1N => (
397                static_slice![Token: ],
398                static_slice![Token: ],
399                message_vec![&[E], &[E, Dh(Ee)], &[S], &[Dh(Se)]]
400            ),
401            X1K => (
402                static_slice![Token: ],
403                static_slice![Token: S],
404                message_vec![&[E, Dh(Es)], &[E, Dh(Ee)], &[S], &[Dh(Se)]]
405            ),
406            XK1 => (
407                static_slice![Token: ],
408                static_slice![Token: S],
409                message_vec![&[E], &[E, Dh(Ee), Dh(Es)], &[S, Dh(Se)]]
410            ),
411            X1K1 => (
412                static_slice![Token: ],
413                static_slice![Token: S],
414                message_vec![&[E], &[E, Dh(Ee), Dh(Es)], &[S], &[Dh(Se)]]
415            ),
416            X1X => (
417                static_slice![Token: ],
418                static_slice![Token: ],
419                message_vec![&[E], &[E, Dh(Ee), S, Dh(Es)], &[S], &[Dh(Se)]],
420            ),
421            XX1 => (
422                static_slice![Token: ],
423                static_slice![Token: ],
424                message_vec![&[E], &[E, Dh(Ee), S], &[Dh(Es), S, Dh(Se)]],
425            ),
426            X1X1 => (
427                static_slice![Token: ],
428                static_slice![Token: ],
429                message_vec![&[E], &[E, Dh(Ee), S], &[Dh(Es), S], &[Dh(Se)]],
430            ),
431            K1N => (
432                static_slice![Token: S],
433                static_slice![Token: ],
434                message_vec![&[E], &[E, Dh(Ee)], &[Dh(Se)]],
435            ),
436            K1K => (
437                static_slice![Token: S],
438                static_slice![Token: S],
439                message_vec![&[E, Dh(Es)], &[E, Dh(Ee)], &[Dh(Se)]],
440            ),
441            KK1 => (
442                static_slice![Token: S],
443                static_slice![Token: S],
444                message_vec![&[E], &[E, Dh(Ee), Dh(Se), Dh(Es)]],
445            ),
446            K1K1 => (
447                static_slice![Token: S],
448                static_slice![Token: S],
449                message_vec![&[E], &[E, Dh(Ee), Dh(Es)], &[Dh(Se)]],
450            ),
451            K1X => (
452                static_slice![Token: S],
453                static_slice![Token: ],
454                message_vec![&[E], &[E, Dh(Ee), S, Dh(Es)], &[Dh(Se)]],
455            ),
456            KX1 => (
457                static_slice![Token: S],
458                static_slice![Token: ],
459                message_vec![&[E], &[E, Dh(Ee), Dh(Se), S], &[Dh(Es)]],
460            ),
461            K1X1 => (
462                static_slice![Token: S],
463                static_slice![Token: ],
464                message_vec![&[E], &[E, Dh(Ee), S], &[Dh(Se), Dh(Es)]],
465            ),
466            I1N => (
467                static_slice![Token: ],
468                static_slice![Token: ],
469                message_vec![&[E, S], &[E, Dh(Ee)], &[Dh(Se)]],
470            ),
471            I1K => (
472                static_slice![Token: ],
473                static_slice![Token: S],
474                message_vec![&[E, Dh(Es), S], &[E, Dh(Ee)], &[Dh(Se)]],
475            ),
476            IK1 => (
477                static_slice![Token: ],
478                static_slice![Token: S],
479                message_vec![&[E, S], &[E, Dh(Ee), Dh(Se), Dh(Es)]],
480            ),
481            I1K1 => (
482                static_slice![Token: ],
483                static_slice![Token: S],
484                message_vec![&[E, S], &[E, Dh(Ee), Dh(Es)], &[Dh(Se)]],
485            ),
486            I1X => (
487                static_slice![Token: ],
488                static_slice![Token: ],
489                message_vec![&[E, S], &[E, Dh(Ee), S, Dh(Es)], &[Dh(Se)]],
490            ),
491            IX1 => (
492                static_slice![Token: ],
493                static_slice![Token: ],
494                message_vec![&[E, S], &[E, Dh(Ee), Dh(Se), S], &[Dh(Es)]],
495            ),
496            I1X1 => (
497                static_slice![Token: ],
498                static_slice![Token: ],
499                message_vec![&[E, S], &[E, Dh(Ee), S], &[Dh(Se), Dh(Es)]],
500            ),
501        };
502
503        for modifier in &handshake.modifiers.list {
504            match modifier {
505                HandshakeModifier::Psk(n) => apply_psk_modifier(&mut patterns, *n)?,
506                #[cfg(feature = "hfs")]
507                HandshakeModifier::Hfs => apply_hfs_modifier(&mut patterns),
508                _ => return Err(PatternProblem::UnsupportedModifier.into()),
509            }
510        }
511
512        Ok(HandshakeTokens {
513            premsg_pattern_i: patterns.0,
514            premsg_pattern_r: patterns.1,
515            msg_patterns: patterns.2,
516        })
517    }
518}
519
/// Reject the combination of the `hfs` modifier with a one-way pattern.
///
/// # Errors
///
/// Returns `PatternProblem::UnsupportedModifier` when `handshake` requests `hfs` on a
/// one-way pattern (`N`, `X`, `K`). Otherwise returns `Ok(())`.
#[cfg(feature = "hfs")]
fn check_hfs_and_oneway_conflict(handshake: &HandshakeChoice) -> Result<(), Error> {
    let hfs_on_oneway = handshake.is_hfs() && handshake.pattern.is_oneway();
    if !hfs_on_oneway {
        return Ok(());
    }
    Err(PatternProblem::UnsupportedModifier.into())
}
532
533/// Given our PSK modifier, we inject the token at the appropriate place.
534fn apply_psk_modifier(patterns: &mut Patterns, n: u8) -> Result<(), Error> {
535    let tokens = patterns
536        .2
537        .get_mut(usize::from(n).saturating_sub(1))
538        .ok_or(Error::Pattern(PatternProblem::InvalidPsk))?;
539    if n == 0 {
540        tokens.insert(0, Token::Psk(n));
541    } else {
542        tokens.push(Token::Psk(n));
543    }
544    Ok(())
545}
546
/// Apply the `hfs` (Hybrid Forward Secrecy) modifier by inserting the `e1` and `ekem1`
/// tokens into the message patterns.
///
/// From the HFS spec, Section 5:
///
/// > Add an "e1" token directly following the first occurrence of "e", unless there is a DH
/// > operation in this same message, in which case the "hfs" token is placed directly after
/// > this DH (so that the public key will be encrypted).
/// >
/// > The "hfs" modifier also adds an "ekem1" token directly following the first occurrence
/// > of "ee".
///
/// The "hfs" token in the first paragraph presumably means "e1", which is what gets inserted.
///
/// # Panics
///
/// Panics if exactly one of the two tokens could be placed. The caller is expected to have
/// rejected one-way patterns beforehand, which rules this out.
#[cfg(feature = "hfs")]
fn apply_hfs_modifier(patterns: &mut Patterns) {
    let messages = &mut patterns.2;

    // `e1` goes into the first message that contains `e`. It is placed right after that
    // message's first DH token when there is one, otherwise right after `e`.
    let e1_slot = messages.iter().enumerate().find_map(|(msg_idx, msg)| {
        let e_idx = msg.iter().position(|&token| token == Token::E)?;
        let anchor = msg.iter().copied().position(Token::is_dh).unwrap_or(e_idx);
        Some((msg_idx, anchor + 1))
    });
    if let Some((msg_idx, token_idx)) = e1_slot {
        messages[msg_idx].insert(token_idx, Token::E1);
    }

    // `ekem1` goes right after the first `ee`. This search runs after `e1` was inserted,
    // because that insertion may shift positions within the same message.
    let ekem1_slot = messages.iter().enumerate().find_map(|(msg_idx, msg)| {
        msg.iter()
            .position(|&token| token == Token::Dh(DhToken::Ee))
            .map(|ee_idx| (msg_idx, ee_idx + 1))
    });
    if let Some((msg_idx, token_idx)) = ekem1_slot {
        messages[msg_idx].insert(token_idx, Token::Ekem1);
    }

    // This should not be possible, because the caller verified that the
    // HandshakePattern is not one-way.
    assert!(
        e1_slot.is_some() == ekem1_slot.is_some(),
        "handshake messages contain one of the ['e1', 'ekem1'] tokens, but not the other",
    );
}