1#![allow(clippy::enum_glob_use)]
2
3#[cfg(not(feature = "std"))]
4use alloc::{vec, vec::Vec};
5
6use crate::error::{Error, PatternProblem};
7use core::{convert::TryFrom, str::FromStr};
8
// Builds a `MessagePatterns` value from slice literals, one slice per handshake
// message: `message_vec![&[E], &[E, Dh(Ee)]]` yields `vec![vec![E], vec![E, Dh(Ee)]]`.
//
// The outer vector and every inner token vector start with capacity 10; the
// modifier passes later insert extra tokens (`psk`, and with `hfs` also
// `e1`/`ekem1`) into these inner vectors.
macro_rules! message_vec {
    ($($item:expr),*) => ({
        let groups: &[&[Token]] = &[$($item),*];
        let mut messages: MessagePatterns = Vec::with_capacity(10);
        messages.extend(groups.iter().map(|group| {
            let mut tokens = Vec::with_capacity(10);
            tokens.extend_from_slice(group);
            tokens
        }));
        messages
    });
}
23
// Defines a fieldless enum whose variant names double as their textual form.
//
// `pattern_enum! { Name { A, B } }` expands to:
// - `pub enum Name { A, B }` deriving `Copy`, `Clone`, `PartialEq`, `Eq`, `Debug`;
// - a `FromStr` impl that accepts exactly the variant names (case-sensitive) and
//   otherwise fails with `PatternProblem::UnsupportedHandshakeType`;
// - `Name::as_str`, the inverse of parsing;
// - a hidden `SUPPORTED_HANDSHAKE_PATTERNS` constant listing every variant in
//   declaration order.
//
// The expansion refers to `FromStr`, `Error` and `PatternProblem`, so those must
// be in scope at the invocation site.
macro_rules! pattern_enum {
    ($name:ident {
        $($variant:ident),* $(,)*
    }) => {
        // Variant names are the protocol's own pattern names; per-variant docs
        // would only repeat them.
        #[allow(missing_docs)]
        #[derive(Copy, Clone, PartialEq, Eq, Debug)]
        pub enum $name {
            $($variant),*,
        }

        impl FromStr for $name {
            type Err = Error;
            fn from_str(s: &str) -> Result<Self, Self::Err> {
                use self::$name::*;
                match s {
                    $(stringify!($variant) => Ok($variant),)*
                    _ => Err(PatternProblem::UnsupportedHandshakeType.into()),
                }
            }
        }

        impl $name {
            /// Returns the variant's name exactly as `FromStr` accepts it.
            #[must_use]
            pub fn as_str(self) -> &'static str {
                use self::$name::*;
                match self {
                    $($variant => stringify!($variant),)*
                }
            }
        }

        #[doc(hidden)]
        pub const SUPPORTED_HANDSHAKE_PATTERNS: &[$name] = &[$($name::$variant),*];
    }
}
76
/// The Diffie-Hellman tokens that may appear in a message pattern, mirroring
/// the Noise token names `ee`, `es`, `se` and `ss`.
#[derive(Copy, Clone, PartialEq, Eq, Debug)]
pub(crate) enum DhToken {
    /// The `ee` token.
    Ee,
    /// The `es` token.
    Es,
    /// The `se` token.
    Se,
    /// The `ss` token.
    Ss,
}
87
88#[derive(Copy, Clone, PartialEq, Debug)]
92pub(crate) enum Token {
93 E,
94 S,
95 Dh(DhToken),
96 Psk(u8),
97 #[cfg(feature = "hfs")]
98 E1,
99 #[cfg(feature = "hfs")]
100 Ekem1,
101}
102
#[cfg(feature = "hfs")]
impl Token {
    /// Reports whether this token is any of the `Dh(..)` tokens.
    fn is_dh(self) -> bool {
        matches!(self, Token::Dh(..))
    }
}
109
pattern_enum! {
    // Every handshake pattern name this module can expand into tokens. The
    // listed order is the order of the generated `SUPPORTED_HANDSHAKE_PATTERNS`.
    HandshakePattern {
        // One-way patterns (see `HandshakePattern::is_oneway`).
        N, X, K,

        // Interactive patterns with two-letter names.
        NN, NK, NX, XN, XK, XX, KN, KK, KX, IN, IK, IX,

        // Deferred patterns (names containing `1`).
        NK1, NX1, X1N, X1K, XK1, X1K1, X1X, XX1, X1X1, K1N, K1K, KK1, K1K1, K1X,
        KX1, K1X1, I1N, I1K, IK1, I1K1, I1X, IX1, I1X1
    }
}
124
125impl HandshakePattern {
126 #[must_use]
130 pub fn is_oneway(self) -> bool {
131 matches!(self, N | X | K)
132 }
133
134 #[must_use]
136 pub fn needs_local_static_key(self, initiator: bool) -> bool {
137 if initiator {
138 !matches!(self, N | NN | NK | NX | NK1 | NX1)
139 } else {
140 !matches!(self, NN | XN | KN | IN | X1N | K1N | I1N)
141 }
142 }
143
144 #[rustfmt::skip]
146 #[must_use] pub fn need_known_remote_pubkey(self, initiator: bool) -> bool {
147 if initiator {
148 matches!(
149 self,
150 N | K | X | NK | XK | KK | IK | NK1 | X1K | XK1 | X1K1 | K1K | KK1 | K1K1 | I1K | IK1 | I1K1
151 )
152 } else {
153 matches!(
154 self,
155 K | KN | KK | KX | K1N | K1K | KK1 | K1K1 | K1X | KX1 | K1X1
156 )
157 }
158 }
159}
160
/// A modifier that follows a handshake pattern name, such as the `psk0` in `NNpsk0`.
#[derive(Copy, Clone, PartialEq, Eq, Debug)]
pub enum HandshakeModifier {
    /// `psk<n>`: insert a `psk` token. `n == 0` places it at the start of the
    /// first message; otherwise it is appended to message `n` (1-based).
    Psk(u8),

    /// `fallback`: accepted by the parser, but rejected when the handshake's
    /// token lists are built.
    Fallback,

    /// `hfs`: adds the `e1` and `ekem1` tokens. Requires the `hfs` feature.
    #[cfg(feature = "hfs")]
    Hfs,
}
174
175impl FromStr for HandshakeModifier {
176 type Err = Error;
177
178 fn from_str(s: &str) -> Result<Self, Self::Err> {
179 match s {
180 s if s.starts_with("psk") => {
181 Ok(HandshakeModifier::Psk(s[3..].parse().map_err(|_| PatternProblem::InvalidPsk)?))
182 },
183 "fallback" => Ok(HandshakeModifier::Fallback),
184 #[cfg(feature = "hfs")]
185 "hfs" => Ok(HandshakeModifier::Hfs),
186 _ => Err(PatternProblem::UnsupportedModifier.into()),
187 }
188 }
189}
190
191#[derive(Clone, PartialEq, Debug)]
193pub struct HandshakeModifierList {
194 pub list: Vec<HandshakeModifier>,
196}
197
198impl FromStr for HandshakeModifierList {
199 type Err = Error;
200
201 fn from_str(s: &str) -> Result<Self, Self::Err> {
202 if s.is_empty() {
203 Ok(HandshakeModifierList { list: vec![] })
204 } else {
205 let modifier_names = s.split('+');
206 let mut modifiers = vec![];
207 for modifier_name in modifier_names {
208 let modifier: HandshakeModifier = modifier_name.parse()?;
209 if modifiers.contains(&modifier) {
210 return Err(Error::Pattern(PatternProblem::DuplicateModifier));
211 }
212 modifiers.push(modifier);
213 }
214 Ok(HandshakeModifierList { list: modifiers })
215 }
216 }
217}
218
219#[derive(Clone, PartialEq, Debug)]
222pub struct HandshakeChoice {
223 pub pattern: HandshakePattern,
225
226 pub modifiers: HandshakeModifierList,
228}
229
230impl HandshakeChoice {
231 #[must_use]
233 pub fn is_psk(&self) -> bool {
234 for modifier in &self.modifiers.list {
235 if let HandshakeModifier::Psk(_) = *modifier {
236 return true;
237 }
238 }
239 false
240 }
241
242 #[must_use]
244 pub fn is_fallback(&self) -> bool {
245 self.modifiers.list.contains(&HandshakeModifier::Fallback)
246 }
247
248 #[cfg(feature = "hfs")]
250 #[must_use]
251 pub fn is_hfs(&self) -> bool {
252 self.modifiers.list.contains(&HandshakeModifier::Hfs)
253 }
254
255 fn parse_pattern_and_modifier(s: &str) -> Result<(HandshakePattern, &str), Error> {
257 for i in (1..=4).rev() {
258 if s.len() > i - 1 && s.is_char_boundary(i) {
259 if let Ok(p) = s[..i].parse() {
260 return Ok((p, &s[i..]));
261 }
262 }
263 }
264
265 Err(PatternProblem::UnsupportedHandshakeType.into())
266 }
267}
268
269impl FromStr for HandshakeChoice {
270 type Err = Error;
271
272 fn from_str(s: &str) -> Result<Self, Self::Err> {
273 let (pattern, remainder) = Self::parse_pattern_and_modifier(s)?;
274 let modifiers = remainder.parse()?;
275
276 Ok(HandshakeChoice { pattern, modifiers })
277 }
278}
279
/// Tokens of one pre-message pattern. Modifiers never change these lists, so
/// they are kept as `'static` slices.
type PremessagePatterns = &'static [Token];
/// One `Vec<Token>` per handshake message, in order. The vectors are owned so
/// that modifier passes can insert tokens into them.
pub(crate) type MessagePatterns = Vec<Vec<Token>>;
282
/// The fully expanded token lists for one handshake, including every applied
/// modifier. Built with `HandshakeTokens::try_from(&HandshakeChoice)`.
#[derive(Debug)]
pub(crate) struct HandshakeTokens {
    /// Tokens of the initiator's pre-message pattern.
    pub premsg_pattern_i: PremessagePatterns,
    /// Tokens of the responder's pre-message pattern.
    pub premsg_pattern_r: PremessagePatterns,
    /// Tokens of each handshake message, in order.
    pub msg_patterns: MessagePatterns,
}
292
293use self::{DhToken::*, HandshakePattern::*, Token::*};
294
/// Working form of a handshake's tokens while modifiers are applied. The tuple
/// is `(initiator pre-message, responder pre-message, messages)`.
type Patterns = (PremessagePatterns, PremessagePatterns, MessagePatterns);
296
impl<'a> TryFrom<&'a HandshakeChoice> for HandshakeTokens {
    type Error = Error;

    /// Expands a parsed handshake choice into its pre-message and message token lists.
    ///
    /// The base pattern selects a fixed entry from the table below. Each
    /// modifier is then applied in the order it appears in
    /// `handshake.modifiers.list`.
    ///
    /// # Errors
    ///
    /// - `PatternProblem::UnsupportedModifier` when any modifier other than
    ///   `Psk` (or `Hfs` with the `hfs` feature) is present. This includes
    ///   `Fallback`. With the `hfs` feature, it is also returned when `hfs` is
    ///   combined with a one-way pattern.
    /// - `PatternProblem::InvalidPsk` when a `psk<n>` modifier targets a
    ///   message the pattern does not have.
    // The per-pattern table below is data rather than logic, but it still trips
    // these size and complexity lints.
    #[allow(clippy::cognitive_complexity)]
    #[allow(clippy::too_many_lines)]
    fn try_from(handshake: &'a HandshakeChoice) -> Result<Self, Self::Error> {
        // Reject `hfs` on one-way patterns before building any tokens.
        #[cfg(feature = "hfs")]
        check_hfs_and_oneway_conflict(handshake)?;

        // Each entry is (initiator pre-message, responder pre-message, messages).
        #[rustfmt::skip]
        let mut patterns: Patterns = match handshake.pattern {
            // One-way patterns.
            N => (
                static_slice![Token: ],
                static_slice![Token: S],
                message_vec![&[E, Dh(Es)]]
            ),
            K => (
                static_slice![Token: S],
                static_slice![Token: S],
                message_vec![&[E, Dh(Es), Dh(Ss)]]
            ),
            X => (
                static_slice![Token: ],
                static_slice![Token: S],
                message_vec![&[E, Dh(Es), S, Dh(Ss)]]
            ),
            // Interactive patterns.
            NN => (
                static_slice![Token: ],
                static_slice![Token: ],
                message_vec![&[E], &[E, Dh(Ee)]]
            ),
            NK => (
                static_slice![Token: ],
                static_slice![Token: S],
                message_vec![&[E, Dh(Es)], &[E, Dh(Ee)]]
            ),
            NX => (
                static_slice![Token: ],
                static_slice![Token: ],
                message_vec![&[E], &[E, Dh(Ee), S, Dh(Es)]]
            ),
            XN => (
                static_slice![Token: ],
                static_slice![Token: ],
                message_vec![&[E], &[E, Dh(Ee)], &[S, Dh(Se)]]
            ),
            XK => (
                static_slice![Token: ],
                static_slice![Token: S],
                message_vec![&[E, Dh(Es)], &[E, Dh(Ee)], &[S, Dh(Se)]]
            ),
            XX => (
                static_slice![Token: ],
                static_slice![Token: ],
                message_vec![&[E], &[E, Dh(Ee), S, Dh(Es)], &[S, Dh(Se)]],
            ),
            KN => (
                static_slice![Token: S],
                static_slice![Token: ],
                message_vec![&[E], &[E, Dh(Ee), Dh(Se)]],
            ),
            KK => (
                static_slice![Token: S],
                static_slice![Token: S],
                message_vec![&[E, Dh(Es), Dh(Ss)], &[E, Dh(Ee), Dh(Se)]],
            ),
            KX => (
                static_slice![Token: S],
                static_slice![Token: ],
                message_vec![&[E], &[E, Dh(Ee), Dh(Se), S, Dh(Es)]],
            ),
            IN => (
                static_slice![Token: ],
                static_slice![Token: ],
                message_vec![&[E, S], &[E, Dh(Ee), Dh(Se)]],
            ),
            IK => (
                static_slice![Token: ],
                static_slice![Token: S],
                message_vec![&[E, Dh(Es), S, Dh(Ss)], &[E, Dh(Ee), Dh(Se)]],
            ),
            IX => (
                static_slice![Token: ],
                static_slice![Token: ],
                message_vec![&[E, S], &[E, Dh(Ee), Dh(Se), S, Dh(Es)]],
            ),
            // Deferred patterns.
            NK1 => (
                static_slice![Token: ],
                static_slice![Token: S],
                message_vec![&[E], &[E, Dh(Ee), Dh(Es)]],
            ),
            NX1 => (
                static_slice![Token: ],
                static_slice![Token: ],
                message_vec![&[E], &[E, Dh(Ee), S], &[Dh(Es)]]
            ),
            X1N => (
                static_slice![Token: ],
                static_slice![Token: ],
                message_vec![&[E], &[E, Dh(Ee)], &[S], &[Dh(Se)]]
            ),
            X1K => (
                static_slice![Token: ],
                static_slice![Token: S],
                message_vec![&[E, Dh(Es)], &[E, Dh(Ee)], &[S], &[Dh(Se)]]
            ),
            XK1 => (
                static_slice![Token: ],
                static_slice![Token: S],
                message_vec![&[E], &[E, Dh(Ee), Dh(Es)], &[S, Dh(Se)]]
            ),
            X1K1 => (
                static_slice![Token: ],
                static_slice![Token: S],
                message_vec![&[E], &[E, Dh(Ee), Dh(Es)], &[S], &[Dh(Se)]]
            ),
            X1X => (
                static_slice![Token: ],
                static_slice![Token: ],
                message_vec![&[E], &[E, Dh(Ee), S, Dh(Es)], &[S], &[Dh(Se)]],
            ),
            XX1 => (
                static_slice![Token: ],
                static_slice![Token: ],
                message_vec![&[E], &[E, Dh(Ee), S], &[Dh(Es), S, Dh(Se)]],
            ),
            X1X1 => (
                static_slice![Token: ],
                static_slice![Token: ],
                message_vec![&[E], &[E, Dh(Ee), S], &[Dh(Es), S], &[Dh(Se)]],
            ),
            K1N => (
                static_slice![Token: S],
                static_slice![Token: ],
                message_vec![&[E], &[E, Dh(Ee)], &[Dh(Se)]],
            ),
            K1K => (
                static_slice![Token: S],
                static_slice![Token: S],
                message_vec![&[E, Dh(Es)], &[E, Dh(Ee)], &[Dh(Se)]],
            ),
            KK1 => (
                static_slice![Token: S],
                static_slice![Token: S],
                message_vec![&[E], &[E, Dh(Ee), Dh(Se), Dh(Es)]],
            ),
            K1K1 => (
                static_slice![Token: S],
                static_slice![Token: S],
                message_vec![&[E], &[E, Dh(Ee), Dh(Es)], &[Dh(Se)]],
            ),
            K1X => (
                static_slice![Token: S],
                static_slice![Token: ],
                message_vec![&[E], &[E, Dh(Ee), S, Dh(Es)], &[Dh(Se)]],
            ),
            KX1 => (
                static_slice![Token: S],
                static_slice![Token: ],
                message_vec![&[E], &[E, Dh(Ee), Dh(Se), S], &[Dh(Es)]],
            ),
            K1X1 => (
                static_slice![Token: S],
                static_slice![Token: ],
                message_vec![&[E], &[E, Dh(Ee), S], &[Dh(Se), Dh(Es)]],
            ),
            I1N => (
                static_slice![Token: ],
                static_slice![Token: ],
                message_vec![&[E, S], &[E, Dh(Ee)], &[Dh(Se)]],
            ),
            I1K => (
                static_slice![Token: ],
                static_slice![Token: S],
                message_vec![&[E, Dh(Es), S], &[E, Dh(Ee)], &[Dh(Se)]],
            ),
            IK1 => (
                static_slice![Token: ],
                static_slice![Token: S],
                message_vec![&[E, S], &[E, Dh(Ee), Dh(Se), Dh(Es)]],
            ),
            I1K1 => (
                static_slice![Token: ],
                static_slice![Token: S],
                message_vec![&[E, S], &[E, Dh(Ee), Dh(Es)], &[Dh(Se)]],
            ),
            I1X => (
                static_slice![Token: ],
                static_slice![Token: ],
                message_vec![&[E, S], &[E, Dh(Ee), S, Dh(Es)], &[Dh(Se)]],
            ),
            IX1 => (
                static_slice![Token: ],
                static_slice![Token: ],
                message_vec![&[E, S], &[E, Dh(Ee), Dh(Se), S], &[Dh(Es)]],
            ),
            I1X1 => (
                static_slice![Token: ],
                static_slice![Token: ],
                message_vec![&[E, S], &[E, Dh(Ee), S], &[Dh(Se), Dh(Es)]],
            ),
        };

        // Modifiers are applied in their written order, each one mutating `patterns`.
        for modifier in &handshake.modifiers.list {
            match modifier {
                HandshakeModifier::Psk(n) => apply_psk_modifier(&mut patterns, *n)?,
                #[cfg(feature = "hfs")]
                HandshakeModifier::Hfs => apply_hfs_modifier(&mut patterns),
                // `Fallback` (and `Hfs` without the feature) lands here: it parses,
                // but cannot be expanded into tokens.
                _ => return Err(PatternProblem::UnsupportedModifier.into()),
            }
        }

        Ok(HandshakeTokens {
            premsg_pattern_i: patterns.0,
            premsg_pattern_r: patterns.1,
            msg_patterns: patterns.2,
        })
    }
}
519
/// Rejects the `hfs` modifier on one-way patterns.
///
/// # Errors
///
/// Returns `PatternProblem::UnsupportedModifier` when `handshake` has the
/// `hfs` modifier and a one-way base pattern.
#[cfg(feature = "hfs")]
fn check_hfs_and_oneway_conflict(handshake: &HandshakeChoice) -> Result<(), Error> {
    let conflicting = handshake.is_hfs() && handshake.pattern.is_oneway();
    if conflicting {
        return Err(PatternProblem::UnsupportedModifier.into());
    }
    Ok(())
}
532
533fn apply_psk_modifier(patterns: &mut Patterns, n: u8) -> Result<(), Error> {
535 let tokens = patterns
536 .2
537 .get_mut(usize::from(n).saturating_sub(1))
538 .ok_or(Error::Pattern(PatternProblem::InvalidPsk))?;
539 if n == 0 {
540 tokens.insert(0, Token::Psk(n));
541 } else {
542 tokens.push(Token::Psk(n));
543 }
544 Ok(())
545}
546
/// Applies the `hfs` modifier by inserting the `E1` and `Ekem1` tokens.
///
/// - `E1` goes into the first message that contains `E`. If that message
///   contains any DH token, `E1` is placed directly after the first DH token;
///   otherwise it is placed directly after `E`.
/// - `Ekem1` is placed directly after the first `Dh(Ee)`, in the first message
///   that contains one.
///
/// # Panics
///
/// Panics if exactly one of the two tokens could be placed. The caller
/// (`TryFrom<&HandshakeChoice>`) runs `check_hfs_and_oneway_conflict` first,
/// which rejects `hfs` on one-way patterns.
#[cfg(feature = "hfs")]
fn apply_hfs_modifier(patterns: &mut Patterns) {
    // Add the `e1` token. `e1_insert_idx` is only set in the first message that
    // contains `E`, and that same iteration inserts and breaks.
    let mut e1_insert_idx = None;
    for msg in &mut patterns.2 {
        if let Some(e_idx) = msg.iter().position(|x| *x == Token::E) {
            if let Some(dh_idx) = msg.iter().copied().position(Token::is_dh) {
                e1_insert_idx = Some(dh_idx + 1);
            } else {
                e1_insert_idx = Some(e_idx + 1);
            }
        }
        if let Some(idx) = e1_insert_idx {
            msg.insert(idx, Token::E1);
            break;
        }
    }

    // Add the `ekem1` token right after the first `ee`. Positions are
    // recomputed here, so the `e1` insertion above is already accounted for.
    let mut ekem1_insert_idx = None;
    for msg in &mut patterns.2 {
        if let Some(ee_idx) = msg.iter().position(|x| *x == Token::Dh(Ee)) {
            ekem1_insert_idx = Some(ee_idx + 1);
        }
        if let Some(idx) = ekem1_insert_idx {
            msg.insert(idx, Token::Ekem1);
            break;
        }
    }

    // Invariant: either both tokens were inserted or neither was.
    assert!(
        !(e1_insert_idx.is_some() ^ ekem1_insert_idx.is_some()),
        "handshake messages contain one of the ['e1', 'ekem1'] tokens, but not the other",
    );
}
593}