rsasl/mechanisms/scram/
parser.rs

1use crate::alloc::{borrow::Cow, string::String, vec::Vec};
2use crate::error::{MechanismError, MechanismErrorKind};
3use core::fmt::{Display, Formatter};
4use core::str::Utf8Error;
5use thiserror::Error;
6
7#[derive(Debug, Error, Copy, Clone, Eq, PartialEq)]
8pub enum SaslNameError {
9    #[error("empty string is invalid for name")]
10    Empty,
11    #[error("name contains invalid utf-8: {0}")]
12    InvalidUtf8(
13        #[from]
14        #[source]
15        Utf8Error,
16    ),
17    #[error("name contains invalid char {0}")]
18    InvalidChar(u8),
19    #[error("name contains invalid escape sequence")]
20    InvalidEscape,
21}
22
23impl MechanismError for SaslNameError {
24    fn kind(&self) -> MechanismErrorKind {
25        MechanismErrorKind::Parse
26    }
27}
28
/// Small state machine yielding the `saslname`-escaped form of one `char`.
///
/// `,` is produced as `=2C`, `=` as `=3D`, and any other character unchanged.
#[derive(Clone)]
enum SaslEscapeState {
    /// Nothing left to yield
    Done,
    /// Yield this character, then finish
    Char(char),
    /// About to yield the `=` of `=2C`
    Comma,
    /// About to yield the `2` of `=2C`
    Comma1,
    /// About to yield the `=` of `=3D`
    Equals,
    /// About to yield the `3` of `=3D`
    Equals1,
}

impl SaslEscapeState {
    /// Initial state for escaping `c`.
    pub const fn escape(c: char) -> Self {
        if c == ',' {
            Self::Comma
        } else if c == '=' {
            Self::Equals
        } else {
            Self::Char(c)
        }
    }
}

impl Iterator for SaslEscapeState {
    type Item = char;

    fn next(&mut self) -> Option<Self::Item> {
        // Each live state emits exactly one char and names its successor.
        let (emitted, successor) = match *self {
            Self::Done => return None,
            Self::Char(c) => (c, Self::Done),
            Self::Comma => ('=', Self::Comma1),
            Self::Comma1 => ('2', Self::Char('C')),
            Self::Equals => ('=', Self::Equals1),
            Self::Equals1 => ('3', Self::Char('D')),
        };
        *self = successor;
        Some(emitted)
    }

    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.len();
        (remaining, Some(remaining))
    }
}

impl ExactSizeIterator for SaslEscapeState {
    /// Number of chars still to be yielded from the current state.
    fn len(&self) -> usize {
        match self {
            Self::Comma | Self::Equals => 3,
            Self::Comma1 | Self::Equals1 => 2,
            Self::Char(_) => 1,
            Self::Done => 0,
        }
    }
}
95
96#[repr(transparent)]
97/// Escaped saslname type
98pub struct SaslName<'a>(Cow<'a, str>);
99impl<'a> SaslName<'a> {
100    /// Convert a Rust-side string into the representation required by SCRAM
101    ///
102    /// This will clone the given string if characters need escaping
103    pub fn escape(input: &str) -> Result<Cow<'_, str>, SaslNameError> {
104        if input.is_empty() {
105            return Err(SaslNameError::Empty);
106        }
107        if input.contains('\0') {
108            return Err(SaslNameError::InvalidChar(0));
109        }
110
111        if input.contains([',', '=']) {
112            let escaped: String = input.chars().flat_map(SaslEscapeState::escape).collect();
113            Ok(Cow::Owned(escaped))
114        } else {
115            Ok(Cow::Borrowed(input))
116        }
117    }
118
119    #[allow(unused)]
120    /// Convert a SCRAM-side string into the representation expected by Rust
121    ///
122    /// This will clone the given string if characters need unescaping
123    pub fn unescape(input: &[u8]) -> Result<Cow<'_, str>, SaslNameError> {
124        if input.is_empty() {
125            return Err(SaslNameError::Empty);
126        }
127
128        if let Some(c) = input.iter().find(|byte| matches!(**byte, b'\0' | b',')) {
129            return Err(SaslNameError::InvalidChar(*c));
130        }
131
132        if let Some(bad) = input.iter().position(|b| matches!(b, b'=')) {
133            let mut out = String::with_capacity(input.len());
134            let good = core::str::from_utf8(&input[..bad]).map_err(SaslNameError::InvalidUtf8)?;
135            out.push_str(good);
136            let mut input = &input[bad..];
137
138            while let Some(bad) = input.iter().position(|b| matches!(b, b'=')) {
139                let good =
140                    core::str::from_utf8(&input[..bad]).map_err(SaslNameError::InvalidUtf8)?;
141                out.push_str(good);
142                let c = match &input[bad + 1..bad + 3] {
143                    b"2C" => ',',
144                    b"3D" => '=',
145                    _ => return Err(SaslNameError::InvalidEscape),
146                };
147                out.push(c);
148                input = &input[bad..];
149            }
150
151            Ok(out.into())
152        } else {
153            Ok(Cow::Borrowed(core::str::from_utf8(input)?))
154        }
155    }
156}
157
158#[derive(Copy, Clone, Eq, PartialEq, Debug, Error)]
159pub enum ParseError {
160    #[error("bad channel flag")]
161    BadCBFlag,
162    #[error("channel binding name contains invalid byte {0:#x}")]
163    BadCBName(u8),
164    #[error("invalid gs2header")]
165    BadGS2Header,
166    #[error("attribute contains invalid byte {0:#x}")]
167    InvalidAttribute(u8),
168    #[error("required attribute is missing")]
169    MissingAttributes,
170    #[error("an extension is unknown but marked mandatory")]
171    UnknownMandatoryExtensions,
172    #[error("invalid UTF-8: {0}")]
173    BadUtf8(
174        #[from]
175        #[source]
176        Utf8Error,
177    ),
178    #[error("nonce contains invalid character")]
179    BadNonce,
180}
181
182#[derive(Copy, Clone, Ord, PartialOrd, Eq, PartialEq, Debug)]
183pub enum GS2CBindFlag<'scram> {
184    SupportedNotUsed,
185    NotSupported,
186    /// Channel bindings of the given name are used
187    ///
188    /// RFC 5056 Section 7 limits the channel binding name to "any string composed of US-ASCII
189    /// alphanumeric characters, period ('.'), and dash ('-')", which is always valid UTF-8
190    /// making the use of `str` here correct.
191    Used(&'scram str),
192}
193impl<'scram> GS2CBindFlag<'scram> {
194    pub fn parse(input: &'scram [u8]) -> Result<Self, ParseError> {
195        match input {
196            b"n" => Ok(Self::NotSupported),
197            b"y" => Ok(Self::SupportedNotUsed),
198            _x if input.len() > 2 && input[0] == b'p' && input[1] == b'=' => {
199                let cbname = &input[2..];
200                cbname
201                    .iter()
202                    // According to [RFC5056 Section 7](https://www.rfc-editor.org/rfc/rfc5056#section-7)
203                    // valid cb names are only composed of ASCII alphanumeric, '.' and '-'
204                    .find(|b| !(matches!(b, b'.' | b'-' | b'0'..=b'9' | b'A'..=b'Z' | b'a'..=b'z')))
205                    .map_or_else(
206                        || {
207                            // SAFE because we just checked for a subset of ASCII which is always UTF-8
208                            let name = unsafe { core::str::from_utf8_unchecked(cbname) };
209                            Ok(Self::Used(name))
210                        },
211                        |bad| Err(ParseError::BadCBName(*bad)),
212                    )
213            }
214            _ => Err(ParseError::BadCBFlag),
215        }
216    }
217
218    pub const fn as_ioslices(&self) -> [&'scram [u8]; 2] {
219        match self {
220            Self::NotSupported => [b"n", &[]],
221            Self::SupportedNotUsed => [b"y", &[]],
222            Self::Used(name) => [b"p=", name.as_bytes()],
223        }
224    }
225}
226
227#[derive(Copy, Clone, Ord, PartialOrd, Eq, PartialEq, Debug)]
228pub struct ClientFirstMessage<'scram> {
229    pub cbflag: GS2CBindFlag<'scram>,
230    pub authzid: Option<&'scram str>,
231    pub username: &'scram str,
232    pub nonce: &'scram [u8],
233}
234impl<'scram> ClientFirstMessage<'scram> {
235    #[allow(unused)]
236    pub const fn new(
237        cbflag: GS2CBindFlag<'scram>,
238        authzid: Option<&'scram str>,
239        username: &'scram str,
240        nonce: &'scram [u8],
241    ) -> Self {
242        Self {
243            cbflag,
244            authzid,
245            username,
246            nonce,
247        }
248    }
249
250    pub fn parse(input: &'scram [u8]) -> Result<Self, ParseError> {
251        let mut partiter = input.split(|b| matches!(b, b','));
252
253        let first = partiter.next().ok_or(ParseError::BadCBFlag)?;
254        let cbflag = GS2CBindFlag::parse(first)?;
255
256        let authzid = partiter.next().ok_or(ParseError::BadGS2Header)?;
257        let authzid = if authzid.is_empty() {
258            None
259        } else {
260            Some(core::str::from_utf8(&authzid[2..]).map_err(ParseError::BadUtf8)?)
261        };
262
263        let next = partiter.next().ok_or(ParseError::MissingAttributes)?;
264        if &next[0..2] == b"m=" {
265            return Err(ParseError::UnknownMandatoryExtensions);
266        }
267
268        let username = if &next[0..2] == b"n=" {
269            core::str::from_utf8(&next[2..]).map_err(ParseError::BadUtf8)?
270        } else {
271            return Err(ParseError::InvalidAttribute(next[0]));
272        };
273
274        let next = partiter.next().ok_or(ParseError::MissingAttributes)?;
275        let nonce = if &next[0..2] == b"r=" {
276            &next[2..]
277        } else {
278            return Err(ParseError::InvalidAttribute(next[0]));
279        };
280        if !nonce.iter().all(|b| matches!(b, 0x21..=0x2B | 0x2D..=0x7E)) {
281            return Err(ParseError::BadNonce);
282        }
283
284        Ok(Self {
285            cbflag,
286            authzid,
287            username,
288            nonce,
289        })
290    }
291
292    #[allow(clippy::similar_names)]
293    fn gs2_header_parts(&self) -> [&'scram [u8]; 4] {
294        let [cba, cbb] = self.cbflag.as_ioslices();
295
296        let (prefix, authzid): (&[u8], &[u8]) = self
297            .authzid
298            .map_or((b",", &[]), |authzid| (b",a=", authzid.as_bytes()));
299
300        [cba, cbb, prefix, authzid]
301    }
302
303    #[allow(clippy::similar_names)]
304    #[allow(unused)]
305    pub fn as_ioslices(&self) -> [&'scram [u8]; 8] {
306        let [cba, cbb, prefix, authzid] = self.gs2_header_parts();
307
308        [
309            cba,
310            cbb,
311            prefix,
312            authzid,
313            b",n=",
314            self.username.as_bytes(),
315            b",r=",
316            self.nonce,
317        ]
318    }
319
320    #[allow(clippy::similar_names)]
321    pub(super) fn build_gs2_header_vec(&self) -> Vec<u8> {
322        let [cba, cbb, prefix, authzid] = self.gs2_header_parts();
323
324        let gs2_header_len = cba.len() + cbb.len() + prefix.len() + authzid.len() + 1;
325        let mut gs2_header = Vec::with_capacity(gs2_header_len);
326
327        // y | n | p=
328        gs2_header.extend_from_slice(cba);
329        // &[] | cbname
330        gs2_header.extend_from_slice(cbb);
331        // b","
332        gs2_header.extend_from_slice(prefix);
333        // authzid
334        gs2_header.extend_from_slice(authzid);
335        // b","
336        gs2_header.extend_from_slice(b",");
337
338        gs2_header
339    }
340}
341
342#[derive(Copy, Clone, Ord, PartialOrd, Eq, PartialEq, Debug)]
343pub struct ServerFirst<'scram> {
344    /// Client or Client+Server Nonce
345    ///
346    /// If the field `server_nonce` is None this contains both client and server nonce
347    /// concatenated, otherwise it contains only the client nonce.
348    pub nonce: &'scram [u8],
349    pub server_nonce: Option<&'scram [u8]>,
350    pub salt: &'scram [u8],
351    pub iteration_count: &'scram [u8],
352}
353
354impl<'scram> ServerFirst<'scram> {
355    pub const fn new(
356        client_nonce: &'scram [u8],
357        server_nonce: &'scram [u8],
358        salt: &'scram [u8],
359        iteration_count: &'scram [u8],
360    ) -> Self {
361        Self {
362            nonce: client_nonce,
363            server_nonce: Some(server_nonce),
364            salt,
365            iteration_count,
366        }
367    }
368
369    pub fn parse(input: &'scram [u8]) -> Result<Self, ParseError> {
370        let mut partiter = input.split(|b| matches!(b, b','));
371
372        let next = partiter.next().ok_or(ParseError::MissingAttributes)?;
373        if next.len() < 2 {
374            return Err(ParseError::MissingAttributes);
375        }
376        if &next[0..2] == b"m=" {
377            return Err(ParseError::UnknownMandatoryExtensions);
378        }
379
380        let nonce = if &next[0..2] == b"r=" {
381            &next[2..]
382        } else {
383            return Err(ParseError::InvalidAttribute(next[0]));
384        };
385
386        let next = partiter.next().ok_or(ParseError::MissingAttributes)?;
387        let salt = if &next[0..2] == b"s=" {
388            &next[2..]
389        } else {
390            return Err(ParseError::InvalidAttribute(next[0]));
391        };
392
393        let next = partiter.next().ok_or(ParseError::MissingAttributes)?;
394        let iteration_count = if &next[0..2] == b"i=" {
395            &next[2..]
396        } else {
397            return Err(ParseError::InvalidAttribute(next[0]));
398        };
399
400        if let Some(next) = partiter.next() {
401            return Err(ParseError::InvalidAttribute(next[0]));
402        }
403
404        Ok(Self {
405            nonce,
406            server_nonce: None,
407            salt,
408            iteration_count,
409        })
410    }
411
412    pub fn as_ioslices(&self) -> [&'scram [u8]; 7] {
413        [
414            b"r=",
415            self.nonce,
416            self.server_nonce.unwrap_or(&[]),
417            b",s=",
418            self.salt,
419            b",i=",
420            self.iteration_count,
421        ]
422    }
423}
424
425pub struct ClientFinal<'scram> {
426    pub channel_binding: &'scram [u8],
427    pub nonce: &'scram [u8],
428    pub proof: &'scram [u8],
429}
430
431impl<'scram> ClientFinal<'scram> {
432    pub const fn new(
433        channel_binding: &'scram [u8],
434        nonce: &'scram [u8],
435        proof: &'scram [u8],
436    ) -> Self {
437        Self {
438            channel_binding,
439            nonce,
440            proof,
441        }
442    }
443
444    pub fn parse(input: &'scram [u8]) -> Result<Self, ParseError> {
445        let mut partiter = input.split(|b| matches!(b, b','));
446
447        let next = partiter.next().ok_or(ParseError::MissingAttributes)?;
448        let channel_binding = if &next[0..2] == b"c=" {
449            &next[2..]
450        } else {
451            return Err(ParseError::InvalidAttribute(next[0]));
452        };
453
454        let next = partiter.next().ok_or(ParseError::MissingAttributes)?;
455        let nonce = if &next[0..2] == b"r=" {
456            &next[2..]
457        } else {
458            return Err(ParseError::InvalidAttribute(next[0]));
459        };
460
461        let proof = loop {
462            // Skip all extensions in between nonce and proof since we can't handle them.
463            // If they are mandatory-to-implement extensions we error.
464            let next = partiter.next().ok_or(ParseError::MissingAttributes)?;
465            if &next[0..2] == b"p=" {
466                break &next[2..];
467            } else if &next[0..2] == b"m=" {
468                return Err(ParseError::UnknownMandatoryExtensions);
469            };
470        };
471
472        if let Some(next) = partiter.next() {
473            return Err(ParseError::InvalidAttribute(next[0]));
474        }
475
476        Ok(Self {
477            channel_binding,
478            nonce,
479            proof,
480        })
481    }
482
483    pub const fn to_ioslices(&self) -> [&'scram [u8]; 6] {
484        [
485            b"c=",
486            self.channel_binding,
487            b",r=",
488            self.nonce,
489            b",p=",
490            self.proof,
491        ]
492    }
493}
494
/// Error values a server can report in the `e=` attribute of a `server-final-message`.
#[derive(Copy, Clone, Ord, PartialOrd, Eq, PartialEq, Debug)]
pub enum ServerErrorValue {
    /// `invalid-encoding`
    InvalidEncoding,
    /// `extensions-not-supported`
    ExtensionsNotSupported,
    /// `invalid-proof`
    InvalidProof,
    /// `channel-bindings-dont-match`
    ChannelBindingsDontMatch,
    /// `server-does-support-channel-binding`
    ServerDoesSupportChannelBinding,
    /// `channel-binding-not-supported`
    ChannelBindingNotSupported,
    /// `unsupported-channel-binding-type`
    UnsupportedChannelBindingType,
    /// `unknown-user`
    UnknownUser,
    /// `invalid-username-encoding`
    InvalidUsernameEncoding,
    /// `no-resources`
    NoResources,
    /// `other-error`
    OtherError,
}
impl ServerErrorValue {
    /// The hyphenated token that represents this error on the wire.
    const fn wire_token(self) -> &'static str {
        use ServerErrorValue as V;
        match self {
            V::InvalidEncoding => "invalid-encoding",
            V::ExtensionsNotSupported => "extensions-not-supported",
            V::InvalidProof => "invalid-proof",
            V::ChannelBindingsDontMatch => "channel-bindings-dont-match",
            V::ServerDoesSupportChannelBinding => "server-does-support-channel-binding",
            V::ChannelBindingNotSupported => "channel-binding-not-supported",
            V::UnsupportedChannelBindingType => "unsupported-channel-binding-type",
            V::UnknownUser => "unknown-user",
            V::InvalidUsernameEncoding => "invalid-username-encoding",
            V::NoResources => "no-resources",
            V::OtherError => "other-error",
        }
    }

    /// The wire token as bytes, e.g. `b"invalid-proof"`, suitable for writing after `e=`.
    pub const fn as_bytes(self) -> &'static [u8] {
        self.wire_token().as_bytes()
    }

    /// Human-readable text: the wire token with each `-` replaced by a space.
    pub const fn as_str(self) -> &'static str {
        use ServerErrorValue as V;
        match self {
            V::InvalidEncoding => "invalid encoding",
            V::ExtensionsNotSupported => "extensions not supported",
            V::InvalidProof => "invalid proof",
            V::ChannelBindingsDontMatch => "channel bindings dont match",
            V::ServerDoesSupportChannelBinding => "server does support channel binding",
            V::ChannelBindingNotSupported => "channel binding not supported",
            V::UnsupportedChannelBindingType => "unsupported channel binding type",
            V::UnknownUser => "unknown user",
            V::InvalidUsernameEncoding => "invalid username encoding",
            V::NoResources => "no resources",
            V::OtherError => "other error",
        }
    }
}
impl Display for ServerErrorValue {
    /// Writes the text of `as_str` verbatim; width and alignment flags are not applied.
    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
        let text = self.as_str();
        f.write_str(text)
    }
}
547
548pub enum ServerFinal<'scram> {
549    Verifier(&'scram [u8]),
550    Error(ServerErrorValue),
551}
552
553impl<'scram> ServerFinal<'scram> {
554    pub fn parse(input: &'scram [u8]) -> Result<Self, ParseError> {
555        if &input[0..2] == b"v=" {
556            Ok(Self::Verifier(&input[2..]))
557        } else if &input[0..2] == b"e=" {
558            use ServerErrorValue::{
559                ChannelBindingNotSupported, ChannelBindingsDontMatch, ExtensionsNotSupported,
560                InvalidEncoding, InvalidProof, InvalidUsernameEncoding, NoResources, OtherError,
561                ServerDoesSupportChannelBinding, UnknownUser, UnsupportedChannelBindingType,
562            };
563            let e = match &input[2..] {
564                b"invalid-encoding" => InvalidEncoding,
565                b"extensions-not-supported" => ExtensionsNotSupported,
566                b"invalid-proof" => InvalidProof,
567                b"channel-bindings-dont-match" => ChannelBindingsDontMatch,
568                b"server-does-support-channel-binding" => ServerDoesSupportChannelBinding,
569                b"channel-binding-not-supported" => ChannelBindingNotSupported,
570                b"unsupported-channel-binding-type" => UnsupportedChannelBindingType,
571                b"unknown-user" => UnknownUser,
572                b"invalid-username-encoding" => InvalidUsernameEncoding,
573                b"no-resources" => NoResources,
574                _ => OtherError,
575            };
576            Ok(Self::Error(e))
577        } else {
578            Err(ParseError::InvalidAttribute(input[0]))
579        }
580    }
581
582    pub const fn to_ioslices(&self) -> [&'scram [u8]; 2] {
583        match self {
584            Self::Verifier(v) => [b"v=", v],
585            Self::Error(e) => [b"e=", e.as_bytes()],
586        }
587    }
588}
589
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_gs2_cbind_flag() {
        let valid: [(&[u8], GS2CBindFlag); 7] = [
            (b"n", GS2CBindFlag::NotSupported),
            (b"y", GS2CBindFlag::SupportedNotUsed),
            (b"p=tls-unique", GS2CBindFlag::Used("tls-unique")),
            (b"p=.", GS2CBindFlag::Used(".")),
            (b"p=-", GS2CBindFlag::Used("-")),
            (b"p=a", GS2CBindFlag::Used("a")),
            (
                b"p=a-very-long-cb-name.indeed",
                GS2CBindFlag::Used("a-very-long-cb-name.indeed"),
            ),
        ];

        for (input, output) in &valid {
            assert_eq!(GS2CBindFlag::parse(input), Ok(*output));
        }
    }

    #[test]
    fn test_parse_gs2_cbind_flag_invalid() {
        // Anything that is not exactly `n`, `y`, or `p=` followed by a non-empty name
        let bad_flags: [&[u8]; 6] = [b"", b"x", b"nn", b"p", b"p=", b"q=tls-unique"];
        for input in &bad_flags {
            assert_eq!(GS2CBindFlag::parse(input), Err(ParseError::BadCBFlag));
        }

        // Names may only contain ASCII alphanumerics, '.' and '-'; the first bad byte is reported
        let bad_names: [(&[u8], u8); 3] = [
            (b"p=tls_unique", b'_'),
            (b"p=tls unique", b' '),
            (b"p=caf\xc3\xa9", 0xc3),
        ];
        for (input, bad) in &bad_names {
            assert_eq!(GS2CBindFlag::parse(input), Err(ParseError::BadCBName(*bad)));
        }
    }

    #[test]
    fn test_saslname_escape() {
        assert_eq!(SaslName::escape("user").as_deref(), Ok("user"));
        assert_eq!(SaslName::escape("a,b=c").as_deref(), Ok("a=2Cb=3Dc"));
        assert_eq!(SaslName::escape("").as_deref(), Err(&SaslNameError::Empty));
        assert_eq!(
            SaslName::escape("a\0b").as_deref(),
            Err(&SaslNameError::InvalidChar(0))
        );
    }

    #[test]
    fn test_saslname_unescape_without_escapes() {
        assert_eq!(SaslName::unescape(b"user").as_deref(), Ok("user"));
        assert_eq!(SaslName::unescape(b"").as_deref(), Err(&SaslNameError::Empty));
        assert_eq!(
            SaslName::unescape(b"a,b").as_deref(),
            Err(&SaslNameError::InvalidChar(b','))
        );
        assert_eq!(
            SaslName::unescape(b"a\0b").as_deref(),
            Err(&SaslNameError::InvalidChar(0))
        );
    }

    #[test]
    fn test_parse_client_first() {
        let msg = ClientFirstMessage::parse(b"n,,n=user,r=fyko+d2lbbFgONRv9qkxdawL").unwrap();
        assert_eq!(msg.cbflag, GS2CBindFlag::NotSupported);
        assert_eq!(msg.authzid, None);
        assert_eq!(msg.username, "user");
        assert_eq!(msg.nonce, b"fyko+d2lbbFgONRv9qkxdawL");
        assert_eq!(msg.build_gs2_header_vec(), b"n,,");

        let with_authzid = ClientFirstMessage::parse(b"y,a=admin,n=user,r=abc").unwrap();
        assert_eq!(with_authzid.authzid, Some("admin"));
        assert_eq!(with_authzid.build_gs2_header_vec(), b"y,a=admin,");
        assert_eq!(
            with_authzid.as_ioslices().concat(),
            b"y,a=admin,n=user,r=abc"
        );

        assert_eq!(
            ClientFirstMessage::parse(b"n,,m=ext,n=user,r=abc"),
            Err(ParseError::UnknownMandatoryExtensions)
        );
        assert_eq!(
            ClientFirstMessage::parse(b"n,,n=user,r=a b"),
            Err(ParseError::BadNonce)
        );
    }

    #[test]
    fn test_parse_server_final() {
        assert!(matches!(
            ServerFinal::parse(b"v=rmF9pqV8S7suAoZWja4dJRkFsKQ="),
            Ok(ServerFinal::Verifier(v)) if v == b"rmF9pqV8S7suAoZWja4dJRkFsKQ="
        ));
        assert!(matches!(
            ServerFinal::parse(b"e=invalid-proof"),
            Ok(ServerFinal::Error(ServerErrorValue::InvalidProof))
        ));
        assert!(matches!(
            ServerFinal::parse(b"e=some-future-error"),
            Ok(ServerFinal::Error(ServerErrorValue::OtherError))
        ));
        assert!(matches!(
            ServerFinal::parse(b"x=1"),
            Err(ParseError::InvalidAttribute(b'x'))
        ));
    }
}