1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
//! Utilities for handling and validating names of Mechanisms
//!
use core::convert::TryFrom;

use core::fmt;
use core::ops::Deref;
use thiserror::Error;

use crate::mechname::MechanismNameError::InvalidChar;

#[repr(transparent)]
#[derive(Eq, PartialEq)]
/// A validated Mechanism name (akin to [`str`])
///
/// This struct, like `str`, is only ever passed by reference since it's `!Sized`. The main
/// reason to have this struct is to ensure at type level and with no run-time overhead that a
/// passed mechanism name was verified.
///
/// The main way to construct a `Mechname` is by calling [`Mechname::parse`].
///
/// This type implements `Deref<Target=str>` so it can be used anywhere where `&str` is expected.
/// Alternatively the methods [`Mechname::as_str`] and [`Mechname::as_bytes`] can be used to
/// manually extract a `&str` and `&[u8]` respectively.
///
/// Note: While RFC 4422 Section 3.1 explicitly limits Mechanism name to 20 characters or less you
/// **SHOULD NOT** rely on this behaviour as there are mechanisms in use that break this
/// rule, e.g. `ECDSA-NIST256P-CHALLENGE` (25 chars) used by some IRCv3 implementations.
pub struct Mechname {
    inner: str,
}

impl Mechname {
    /// Convert a byte slice into a `&Mechname` after checking it for validity.
    ///
    ///
    pub fn parse(input: &[u8]) -> Result<&Self, MechanismNameError> {
        if input.is_empty() {
            Err(MechanismNameError::TooShort)
        } else {
            input.iter().enumerate().try_for_each(|(index, value)| {
                if is_invalid(*value) {
                    Err(InvalidChar {
                        index,
                        value: *value,
                    })
                } else {
                    Ok(())
                }
            })?;
            Ok(Self::const_new(input))
        }
    }

    #[must_use]
    #[inline(always)]
    pub const fn as_str(&self) -> &str {
        &self.inner
    }

    #[must_use]
    #[inline(always)]
    pub const fn as_bytes(&self) -> &[u8] {
        self.inner.as_bytes()
    }

    pub(crate) const fn const_new(s: &[u8]) -> &Self {
        unsafe { core::mem::transmute(s) }
    }
}

#[cfg(feature = "unstable_custom_mechanism")]
/// These associated functions are only available with feature `unstable_custom_mechanism`. They
/// are *not guaranteed to be stable under semver*
impl Mechname {
    #[inline(always)]
    /// `const` capable conversion from `&'a [u8]` to `&'a Mechname` with no validity checking.
    ///
    /// While this is safe from a memory protection standpoint since `&Mechname` and `&[u8]` have
    /// the exact same representation it can be used to break the contract of `Mechname` only
    /// containing a subset of ASCII, which may result in undefined behaviour.
    ///
    /// Uses transmute due to [rustc issue #51911](https://github.com/rust-lang/rust/issues/51911)
    #[must_use]
    pub const fn const_new_unchecked(s: &[u8]) -> &Self {
        Self::const_new(s)
    }
}

impl fmt::Display for Mechname {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(self.as_str())
    }
}

impl fmt::Debug for Mechname {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "MECHANISM({})", self.as_str())
    }
}

impl PartialEq<[u8]> for Mechname {
    fn eq(&self, other: &[u8]) -> bool {
        self.as_bytes() == other
    }
}
impl PartialEq<Mechname> for [u8] {
    fn eq(&self, other: &Mechname) -> bool {
        self == other.as_bytes()
    }
}

impl PartialEq<str> for Mechname {
    fn eq(&self, other: &str) -> bool {
        self.as_str() == other
    }
}
impl PartialEq<Mechname> for str {
    fn eq(&self, other: &Mechname) -> bool {
        self == other.as_str()
    }
}

impl<'a> TryFrom<&'a [u8]> for &'a Mechname {
    type Error = MechanismNameError;

    fn try_from(value: &'a [u8]) -> Result<Self, Self::Error> {
        Mechname::parse(value)
    }
}

impl<'a> TryFrom<&'a str> for &'a Mechname {
    type Error = MechanismNameError;

    fn try_from(value: &'a str) -> Result<Self, Self::Error> {
        Mechname::parse(value.as_bytes())
    }
}

impl Deref for Mechname {
    type Target = str;

    fn deref(&self) -> &Self::Target {
        self.as_str()
    }
}

#[inline(always)]
const fn is_invalid(byte: u8) -> bool {
    !(is_valid(byte))
}

#[inline(always)]
const fn is_valid(byte: u8) -> bool {
    // RFC 4422 section 3.1 limits mechanism names to:
    //     sasl-mech    = 1*20mech-char
    //     mech-char    = UPPER-ALPHA / DIGIT / HYPHEN / UNDERSCORE
    //     ; mech-char is restricted to A-Z (uppercase only), 0-9, -, and _
    //     ; from ASCII character set.
    core::matches!(byte, b'A'..=b'Z' | b'0'..=b'9' | b'-' | b'_')
}

#[derive(Debug, Ord, PartialOrd, Eq, PartialEq, Copy, Clone, Error)]
#[non_exhaustive]
pub enum MechanismNameError {
    /// Mechanism name is shorter than 1 character
    #[error("a mechanism name can not be empty")]
    TooShort,

    /// Mechanism name contained a character outside of [A-Z0-9-_] at `index`
    ///
    ///
    #[error("contains invalid character at offset {index}: {value:#x}")]
    InvalidChar {
        /// Index of the invalid character byte
        index: usize,
        /// Value of the invalid character byte
        value: u8,
    },
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mechname() {
        let valids = [
            "PLAIN",
            "SCRAM-SHA256-PLUS",
            "GS2-KRB5-PLUS",
            "XOAUTHBEARER",
            "EXACTLY_20_CHAR_LONG",
            "X-THIS-MECHNAME-IS-TOO-LONG",
            "EXACTLY_21_CHARS_LONG",
        ];
        let invalidchars = [
            ("PLAIN GSSAPI LOGIN", 5, b' '),
            ("SCRAM-SHA256-PLUS GSSAPI X-OAUTH2", 17, b' '),
            ("X-CONTAINS-NULL\0", 15, b'\0'),
            ("PLAIN\0", 5, b'\0'),
            ("X-lowercase", 2, b'l'),
            ("X-LÄTIN1", 3, b'\xC3'),
        ];

        for m in valids {
            println!("Checking {m}");
            let res = Mechname::parse(m.as_bytes()).map(Mechname::as_bytes);
            assert_eq!(res, Ok(m.as_bytes()));
        }
        for (m, index, value) in invalidchars {
            let e = Mechname::parse(m.as_bytes())
                .map(Mechname::as_bytes)
                .unwrap_err();
            println!("Checking {m}: {e}");
            assert_eq!(e, MechanismNameError::InvalidChar { index, value });
        }
    }
}