rsasl/
mechname.rs

//! Utilities for handling and validating names of SASL mechanisms (RFC 4422 Section 3.1).
3use core::convert::TryFrom;
4
5use core::fmt;
6use core::ops::Deref;
7use thiserror::Error;
8
9use crate::mechname::MechanismNameError::InvalidChar;
10
/// A validated Mechanism name (akin to [`str`])
///
/// This struct, like `str`, is only ever passed by reference since it's `!Sized`. The main
/// reason to have this struct is to ensure at type level and with no run-time overhead that a
/// passed mechanism name was verified.
///
/// The main way to construct a `Mechname` is by calling [`Mechname::parse`].
///
/// This type implements `Deref<Target=str>` so it can be used anywhere where `&str` is expected.
/// Alternatively the methods [`Mechname::as_str`] and [`Mechname::as_bytes`] can be used to
/// manually extract a `&str` and `&[u8]` respectively.
///
/// Note: While RFC 4422 Section 3.1 explicitly limits mechanism names to 20 characters or
/// less, you **SHOULD NOT** rely on a `Mechname` being that short. Some mechanisms in use break
/// this rule, e.g. `ECDSA-NIST256P-CHALLENGE` (24 chars), used by some IRCv3 implementations.
/// [`Mechname::parse`] therefore does not enforce an upper length limit.
#[repr(transparent)]
#[derive(Eq, PartialEq)]
pub struct Mechname {
    // Invariant: when created via `Mechname::parse` this is non-empty and consists only of the
    // ASCII bytes `A-Z`, `0-9`, `-` and `_`. `#[repr(transparent)]` keeps the layout identical
    // to `str`, so `&Mechname` is the same fat pointer as `&str`.
    inner: str,
}
31
32impl Mechname {
33    /// Convert a byte slice into a `&Mechname` after checking it for validity.
34    ///
35    ///
36    pub fn parse(input: &[u8]) -> Result<&Self, MechanismNameError> {
37        if input.is_empty() {
38            Err(MechanismNameError::TooShort)
39        } else {
40            input.iter().enumerate().try_for_each(|(index, value)| {
41                if is_invalid(*value) {
42                    Err(InvalidChar {
43                        index,
44                        value: *value,
45                    })
46                } else {
47                    Ok(())
48                }
49            })?;
50            Ok(Self::const_new(input))
51        }
52    }
53
54    #[must_use]
55    #[inline(always)]
56    pub const fn as_str(&self) -> &str {
57        &self.inner
58    }
59
60    #[must_use]
61    #[inline(always)]
62    pub const fn as_bytes(&self) -> &[u8] {
63        self.inner.as_bytes()
64    }
65
66    pub(crate) const fn const_new(s: &[u8]) -> &Self {
67        unsafe { core::mem::transmute(s) }
68    }
69}
70
#[cfg(feature = "unstable_custom_mechanism")]
/// These associated functions are only available with feature `unstable_custom_mechanism`. They
/// are *not guaranteed to be stable under semver*.
impl Mechname {
    /// `const` capable conversion from `&'a [u8]` to `&'a Mechname` with no validity checking.
    ///
    /// Unlike [`Mechname::parse`], nothing about `s` is checked. This function does not verify
    /// that `s` is non-empty, that it only contains `A-Z`, `0-9`, `-` and `_`, or even that it
    /// is valid UTF-8. Upholding those properties is entirely the caller's responsibility.
    ///
    /// **Caution:** although this function is not marked `unsafe`, a `Mechname` wraps a `str`
    /// and hands it out via [`Mechname::as_str`] and `Deref<Target = str>`. Passing bytes that
    /// are not valid UTF-8 therefore produces an invalid `&str`, which can lead to undefined
    /// behaviour. Passing valid UTF-8 outside the allowed character set breaks the `Mechname`
    /// contract that other code may rely on.
    ///
    /// Uses transmute due to [rustc issue #51911](https://github.com/rust-lang/rust/issues/51911)
    #[inline(always)]
    #[must_use]
    pub const fn const_new_unchecked(s: &[u8]) -> &Self {
        Self::const_new(s)
    }
}
88
89impl fmt::Display for Mechname {
90    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
91        f.write_str(self.as_str())
92    }
93}
94
95impl fmt::Debug for Mechname {
96    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
97        write!(f, "MECHANISM({})", self.as_str())
98    }
99}
100
101impl PartialEq<[u8]> for Mechname {
102    fn eq(&self, other: &[u8]) -> bool {
103        self.as_bytes() == other
104    }
105}
106impl PartialEq<Mechname> for [u8] {
107    fn eq(&self, other: &Mechname) -> bool {
108        self == other.as_bytes()
109    }
110}
111
112impl PartialEq<str> for Mechname {
113    fn eq(&self, other: &str) -> bool {
114        self.as_str() == other
115    }
116}
117impl PartialEq<Mechname> for str {
118    fn eq(&self, other: &Mechname) -> bool {
119        self == other.as_str()
120    }
121}
122
123impl<'a> TryFrom<&'a [u8]> for &'a Mechname {
124    type Error = MechanismNameError;
125
126    fn try_from(value: &'a [u8]) -> Result<Self, Self::Error> {
127        Mechname::parse(value)
128    }
129}
130
131impl<'a> TryFrom<&'a str> for &'a Mechname {
132    type Error = MechanismNameError;
133
134    fn try_from(value: &'a str) -> Result<Self, Self::Error> {
135        Mechname::parse(value.as_bytes())
136    }
137}
138
139impl Deref for Mechname {
140    type Target = str;
141
142    fn deref(&self) -> &Self::Target {
143        self.as_str()
144    }
145}
146
/// Returns `true` if `byte` may appear in a mechanism name.
///
/// RFC 4422 section 3.1 defines the grammar:
///
/// ```text
/// sasl-mech    = 1*20mech-char
/// mech-char    = UPPER-ALPHA / DIGIT / HYPHEN / UNDERSCORE
/// ```
///
/// i.e. only the ASCII characters `A-Z`, `0-9`, `-` and `_`. This function checks the
/// character set only; the `1*20` length bound is not enforced here.
#[inline(always)]
const fn is_valid(byte: u8) -> bool {
    let is_upper_alpha = byte.is_ascii_uppercase();
    let is_digit = byte.is_ascii_digit();
    let is_separator = byte == b'-' || byte == b'_';
    is_upper_alpha || is_digit || is_separator
}

/// Returns `true` if `byte` must be rejected in a mechanism name (negation of [`is_valid`]).
#[inline(always)]
const fn is_invalid(byte: u8) -> bool {
    !is_valid(byte)
}
161
162#[derive(Debug, Ord, PartialOrd, Eq, PartialEq, Copy, Clone, Error)]
163#[non_exhaustive]
164pub enum MechanismNameError {
165    /// Mechanism name is shorter than 1 character
166    #[error("a mechanism name can not be empty")]
167    TooShort,
168
169    /// Mechanism name contained a character outside of [A-Z0-9-_] at `index`
170    ///
171    ///
172    #[error("contains invalid character at offset {index}: {value:#x}")]
173    InvalidChar {
174        /// Index of the invalid character byte
175        index: usize,
176        /// Value of the invalid character byte
177        value: u8,
178    },
179}
180
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mechname() {
        let valids = [
            "PLAIN",
            "SCRAM-SHA256-PLUS",
            "GS2-KRB5-PLUS",
            "XOAUTHBEARER",
            "EXACTLY_20_CHAR_LONG",
            // Names longer than RFC 4422's 20 character limit are deliberately accepted.
            "X-THIS-MECHNAME-IS-TOO-LONG",
            "EXACTLY_21_CHARS_LONG",
            // Real-world example from the `Mechname` docs (IRCv3), 24 characters long.
            "ECDSA-NIST256P-CHALLENGE",
        ];
        // (input, index of the first invalid byte, value of that byte)
        let invalidchars = [
            ("PLAIN GSSAPI LOGIN", 5, b' '),
            ("SCRAM-SHA256-PLUS GSSAPI X-OAUTH2", 17, b' '),
            ("X-CONTAINS-NULL\0", 15, b'\0'),
            ("PLAIN\0", 5, b'\0'),
            ("X-lowercase", 2, b'l'),
            ("X-LÄTIN1", 3, b'\xC3'),
        ];

        for m in valids {
            println!("Checking {m}");
            let res = Mechname::parse(m.as_bytes()).map(Mechname::as_bytes);
            assert_eq!(res, Ok(m.as_bytes()));
        }
        for (m, index, value) in invalidchars {
            let e = Mechname::parse(m.as_bytes())
                .map(Mechname::as_bytes)
                .unwrap_err();
            println!("Checking {m}: {e}");
            assert_eq!(e, MechanismNameError::InvalidChar { index, value });
        }
    }

    #[test]
    fn test_empty_mechname() {
        assert_eq!(
            Mechname::parse(b"").map(Mechname::as_bytes),
            Err(MechanismNameError::TooShort)
        );
        assert_eq!(
            <&Mechname as TryFrom<&str>>::try_from("").map(Mechname::as_str),
            Err(MechanismNameError::TooShort)
        );
    }
}