1use crate::alloc::{borrow::Cow, string::String, vec::Vec};
2use crate::error::{MechanismError, MechanismErrorKind};
3use core::fmt::{Display, Formatter};
4use core::str::Utf8Error;
5use thiserror::Error;
6
7#[derive(Debug, Error, Copy, Clone, Eq, PartialEq)]
8pub enum SaslNameError {
9 #[error("empty string is invalid for name")]
10 Empty,
11 #[error("name contains invalid utf-8: {0}")]
12 InvalidUtf8(
13 #[from]
14 #[source]
15 Utf8Error,
16 ),
17 #[error("name contains invalid char {0}")]
18 InvalidChar(u8),
19 #[error("name contains invalid escape sequence")]
20 InvalidEscape,
21}
22
23impl MechanismError for SaslNameError {
24 fn kind(&self) -> MechanismErrorKind {
25 MechanismErrorKind::Parse
26 }
27}
28
/// Yields the `saslname` encoding of a single character.
///
/// `,` expands to `=2C`, `=` expands to `=3D`, and every other character is
/// yielded unchanged. Each variant records how much output is still pending.
#[derive(Clone)]
enum SaslEscapeState {
    /// Nothing left to yield.
    Done,
    /// Exactly one character left to yield.
    Char(char),
    /// All of `=2C` left to yield.
    Comma,
    /// `2C` left to yield.
    Comma1,
    /// All of `=3D` left to yield.
    Equals,
    /// `3D` left to yield.
    Equals1,
}

impl SaslEscapeState {
    /// Begins escaping `c`.
    pub const fn escape(c: char) -> Self {
        if c == ',' {
            Self::Comma
        } else if c == '=' {
            Self::Equals
        } else {
            Self::Char(c)
        }
    }
}

impl Iterator for SaslEscapeState {
    type Item = char;

    fn next(&mut self) -> Option<Self::Item> {
        // Each state emits one character and names the state that follows it.
        let (emitted, following) = match *self {
            Self::Done => return None,
            Self::Char(c) => (c, Self::Done),
            Self::Comma => ('=', Self::Comma1),
            Self::Comma1 => ('2', Self::Char('C')),
            Self::Equals => ('=', Self::Equals1),
            Self::Equals1 => ('3', Self::Char('D')),
        };
        *self = following;
        Some(emitted)
    }

    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        let pending = ExactSizeIterator::len(self);
        (pending, Some(pending))
    }
}

impl ExactSizeIterator for SaslEscapeState {
    fn len(&self) -> usize {
        match self {
            Self::Done => 0,
            Self::Char(_) => 1,
            Self::Comma1 | Self::Equals1 => 2,
            Self::Comma | Self::Equals => 3,
        }
    }
}
95
96#[repr(transparent)]
97pub struct SaslName<'a>(Cow<'a, str>);
99impl<'a> SaslName<'a> {
100 pub fn escape(input: &str) -> Result<Cow<'_, str>, SaslNameError> {
104 if input.is_empty() {
105 return Err(SaslNameError::Empty);
106 }
107 if input.contains('\0') {
108 return Err(SaslNameError::InvalidChar(0));
109 }
110
111 if input.contains([',', '=']) {
112 let escaped: String = input.chars().flat_map(SaslEscapeState::escape).collect();
113 Ok(Cow::Owned(escaped))
114 } else {
115 Ok(Cow::Borrowed(input))
116 }
117 }
118
119 #[allow(unused)]
120 pub fn unescape(input: &[u8]) -> Result<Cow<'_, str>, SaslNameError> {
124 if input.is_empty() {
125 return Err(SaslNameError::Empty);
126 }
127
128 if let Some(c) = input.iter().find(|byte| matches!(**byte, b'\0' | b',')) {
129 return Err(SaslNameError::InvalidChar(*c));
130 }
131
132 if let Some(bad) = input.iter().position(|b| matches!(b, b'=')) {
133 let mut out = String::with_capacity(input.len());
134 let good = core::str::from_utf8(&input[..bad]).map_err(SaslNameError::InvalidUtf8)?;
135 out.push_str(good);
136 let mut input = &input[bad..];
137
138 while let Some(bad) = input.iter().position(|b| matches!(b, b'=')) {
139 let good =
140 core::str::from_utf8(&input[..bad]).map_err(SaslNameError::InvalidUtf8)?;
141 out.push_str(good);
142 let c = match &input[bad + 1..bad + 3] {
143 b"2C" => ',',
144 b"3D" => '=',
145 _ => return Err(SaslNameError::InvalidEscape),
146 };
147 out.push(c);
148 input = &input[bad..];
149 }
150
151 Ok(out.into())
152 } else {
153 Ok(Cow::Borrowed(core::str::from_utf8(input)?))
154 }
155 }
156}
157
158#[derive(Copy, Clone, Eq, PartialEq, Debug, Error)]
159pub enum ParseError {
160 #[error("bad channel flag")]
161 BadCBFlag,
162 #[error("channel binding name contains invalid byte {0:#x}")]
163 BadCBName(u8),
164 #[error("invalid gs2header")]
165 BadGS2Header,
166 #[error("attribute contains invalid byte {0:#x}")]
167 InvalidAttribute(u8),
168 #[error("required attribute is missing")]
169 MissingAttributes,
170 #[error("an extension is unknown but marked mandatory")]
171 UnknownMandatoryExtensions,
172 #[error("invalid UTF-8: {0}")]
173 BadUtf8(
174 #[from]
175 #[source]
176 Utf8Error,
177 ),
178 #[error("nonce contains invalid character")]
179 BadNonce,
180}
181
182#[derive(Copy, Clone, Ord, PartialOrd, Eq, PartialEq, Debug)]
183pub enum GS2CBindFlag<'scram> {
184 SupportedNotUsed,
185 NotSupported,
186 Used(&'scram str),
192}
193impl<'scram> GS2CBindFlag<'scram> {
194 pub fn parse(input: &'scram [u8]) -> Result<Self, ParseError> {
195 match input {
196 b"n" => Ok(Self::NotSupported),
197 b"y" => Ok(Self::SupportedNotUsed),
198 _x if input.len() > 2 && input[0] == b'p' && input[1] == b'=' => {
199 let cbname = &input[2..];
200 cbname
201 .iter()
202 .find(|b| !(matches!(b, b'.' | b'-' | b'0'..=b'9' | b'A'..=b'Z' | b'a'..=b'z')))
205 .map_or_else(
206 || {
207 let name = unsafe { core::str::from_utf8_unchecked(cbname) };
209 Ok(Self::Used(name))
210 },
211 |bad| Err(ParseError::BadCBName(*bad)),
212 )
213 }
214 _ => Err(ParseError::BadCBFlag),
215 }
216 }
217
218 pub const fn as_ioslices(&self) -> [&'scram [u8]; 2] {
219 match self {
220 Self::NotSupported => [b"n", &[]],
221 Self::SupportedNotUsed => [b"y", &[]],
222 Self::Used(name) => [b"p=", name.as_bytes()],
223 }
224 }
225}
226
227#[derive(Copy, Clone, Ord, PartialOrd, Eq, PartialEq, Debug)]
228pub struct ClientFirstMessage<'scram> {
229 pub cbflag: GS2CBindFlag<'scram>,
230 pub authzid: Option<&'scram str>,
231 pub username: &'scram str,
232 pub nonce: &'scram [u8],
233}
234impl<'scram> ClientFirstMessage<'scram> {
235 #[allow(unused)]
236 pub const fn new(
237 cbflag: GS2CBindFlag<'scram>,
238 authzid: Option<&'scram str>,
239 username: &'scram str,
240 nonce: &'scram [u8],
241 ) -> Self {
242 Self {
243 cbflag,
244 authzid,
245 username,
246 nonce,
247 }
248 }
249
250 pub fn parse(input: &'scram [u8]) -> Result<Self, ParseError> {
251 let mut partiter = input.split(|b| matches!(b, b','));
252
253 let first = partiter.next().ok_or(ParseError::BadCBFlag)?;
254 let cbflag = GS2CBindFlag::parse(first)?;
255
256 let authzid = partiter.next().ok_or(ParseError::BadGS2Header)?;
257 let authzid = if authzid.is_empty() {
258 None
259 } else {
260 Some(core::str::from_utf8(&authzid[2..]).map_err(ParseError::BadUtf8)?)
261 };
262
263 let next = partiter.next().ok_or(ParseError::MissingAttributes)?;
264 if &next[0..2] == b"m=" {
265 return Err(ParseError::UnknownMandatoryExtensions);
266 }
267
268 let username = if &next[0..2] == b"n=" {
269 core::str::from_utf8(&next[2..]).map_err(ParseError::BadUtf8)?
270 } else {
271 return Err(ParseError::InvalidAttribute(next[0]));
272 };
273
274 let next = partiter.next().ok_or(ParseError::MissingAttributes)?;
275 let nonce = if &next[0..2] == b"r=" {
276 &next[2..]
277 } else {
278 return Err(ParseError::InvalidAttribute(next[0]));
279 };
280 if !nonce.iter().all(|b| matches!(b, 0x21..=0x2B | 0x2D..=0x7E)) {
281 return Err(ParseError::BadNonce);
282 }
283
284 Ok(Self {
285 cbflag,
286 authzid,
287 username,
288 nonce,
289 })
290 }
291
292 #[allow(clippy::similar_names)]
293 fn gs2_header_parts(&self) -> [&'scram [u8]; 4] {
294 let [cba, cbb] = self.cbflag.as_ioslices();
295
296 let (prefix, authzid): (&[u8], &[u8]) = self
297 .authzid
298 .map_or((b",", &[]), |authzid| (b",a=", authzid.as_bytes()));
299
300 [cba, cbb, prefix, authzid]
301 }
302
303 #[allow(clippy::similar_names)]
304 #[allow(unused)]
305 pub fn as_ioslices(&self) -> [&'scram [u8]; 8] {
306 let [cba, cbb, prefix, authzid] = self.gs2_header_parts();
307
308 [
309 cba,
310 cbb,
311 prefix,
312 authzid,
313 b",n=",
314 self.username.as_bytes(),
315 b",r=",
316 self.nonce,
317 ]
318 }
319
320 #[allow(clippy::similar_names)]
321 pub(super) fn build_gs2_header_vec(&self) -> Vec<u8> {
322 let [cba, cbb, prefix, authzid] = self.gs2_header_parts();
323
324 let gs2_header_len = cba.len() + cbb.len() + prefix.len() + authzid.len() + 1;
325 let mut gs2_header = Vec::with_capacity(gs2_header_len);
326
327 gs2_header.extend_from_slice(cba);
329 gs2_header.extend_from_slice(cbb);
331 gs2_header.extend_from_slice(prefix);
333 gs2_header.extend_from_slice(authzid);
335 gs2_header.extend_from_slice(b",");
337
338 gs2_header
339 }
340}
341
342#[derive(Copy, Clone, Ord, PartialOrd, Eq, PartialEq, Debug)]
343pub struct ServerFirst<'scram> {
344 pub nonce: &'scram [u8],
349 pub server_nonce: Option<&'scram [u8]>,
350 pub salt: &'scram [u8],
351 pub iteration_count: &'scram [u8],
352}
353
354impl<'scram> ServerFirst<'scram> {
355 pub const fn new(
356 client_nonce: &'scram [u8],
357 server_nonce: &'scram [u8],
358 salt: &'scram [u8],
359 iteration_count: &'scram [u8],
360 ) -> Self {
361 Self {
362 nonce: client_nonce,
363 server_nonce: Some(server_nonce),
364 salt,
365 iteration_count,
366 }
367 }
368
369 pub fn parse(input: &'scram [u8]) -> Result<Self, ParseError> {
370 let mut partiter = input.split(|b| matches!(b, b','));
371
372 let next = partiter.next().ok_or(ParseError::MissingAttributes)?;
373 if next.len() < 2 {
374 return Err(ParseError::MissingAttributes);
375 }
376 if &next[0..2] == b"m=" {
377 return Err(ParseError::UnknownMandatoryExtensions);
378 }
379
380 let nonce = if &next[0..2] == b"r=" {
381 &next[2..]
382 } else {
383 return Err(ParseError::InvalidAttribute(next[0]));
384 };
385
386 let next = partiter.next().ok_or(ParseError::MissingAttributes)?;
387 let salt = if &next[0..2] == b"s=" {
388 &next[2..]
389 } else {
390 return Err(ParseError::InvalidAttribute(next[0]));
391 };
392
393 let next = partiter.next().ok_or(ParseError::MissingAttributes)?;
394 let iteration_count = if &next[0..2] == b"i=" {
395 &next[2..]
396 } else {
397 return Err(ParseError::InvalidAttribute(next[0]));
398 };
399
400 if let Some(next) = partiter.next() {
401 return Err(ParseError::InvalidAttribute(next[0]));
402 }
403
404 Ok(Self {
405 nonce,
406 server_nonce: None,
407 salt,
408 iteration_count,
409 })
410 }
411
412 pub fn as_ioslices(&self) -> [&'scram [u8]; 7] {
413 [
414 b"r=",
415 self.nonce,
416 self.server_nonce.unwrap_or(&[]),
417 b",s=",
418 self.salt,
419 b",i=",
420 self.iteration_count,
421 ]
422 }
423}
424
425pub struct ClientFinal<'scram> {
426 pub channel_binding: &'scram [u8],
427 pub nonce: &'scram [u8],
428 pub proof: &'scram [u8],
429}
430
431impl<'scram> ClientFinal<'scram> {
432 pub const fn new(
433 channel_binding: &'scram [u8],
434 nonce: &'scram [u8],
435 proof: &'scram [u8],
436 ) -> Self {
437 Self {
438 channel_binding,
439 nonce,
440 proof,
441 }
442 }
443
444 pub fn parse(input: &'scram [u8]) -> Result<Self, ParseError> {
445 let mut partiter = input.split(|b| matches!(b, b','));
446
447 let next = partiter.next().ok_or(ParseError::MissingAttributes)?;
448 let channel_binding = if &next[0..2] == b"c=" {
449 &next[2..]
450 } else {
451 return Err(ParseError::InvalidAttribute(next[0]));
452 };
453
454 let next = partiter.next().ok_or(ParseError::MissingAttributes)?;
455 let nonce = if &next[0..2] == b"r=" {
456 &next[2..]
457 } else {
458 return Err(ParseError::InvalidAttribute(next[0]));
459 };
460
461 let proof = loop {
462 let next = partiter.next().ok_or(ParseError::MissingAttributes)?;
465 if &next[0..2] == b"p=" {
466 break &next[2..];
467 } else if &next[0..2] == b"m=" {
468 return Err(ParseError::UnknownMandatoryExtensions);
469 };
470 };
471
472 if let Some(next) = partiter.next() {
473 return Err(ParseError::InvalidAttribute(next[0]));
474 }
475
476 Ok(Self {
477 channel_binding,
478 nonce,
479 proof,
480 })
481 }
482
483 pub const fn to_ioslices(&self) -> [&'scram [u8]; 6] {
484 [
485 b"c=",
486 self.channel_binding,
487 b",r=",
488 self.nonce,
489 b",p=",
490 self.proof,
491 ]
492 }
493}
494
/// The error codes a SCRAM server can report with `e=`.
///
/// Each variant's doc comment gives its wire spelling.
#[derive(Copy, Clone, Ord, PartialOrd, Eq, PartialEq, Debug)]
pub enum ServerErrorValue {
    /// `invalid-encoding`
    InvalidEncoding,
    /// `extensions-not-supported`
    ExtensionsNotSupported,
    /// `invalid-proof`
    InvalidProof,
    /// `channel-bindings-dont-match`
    ChannelBindingsDontMatch,
    /// `server-does-support-channel-binding`
    ServerDoesSupportChannelBinding,
    /// `channel-binding-not-supported`
    ChannelBindingNotSupported,
    /// `unsupported-channel-binding-type`
    UnsupportedChannelBindingType,
    /// `unknown-user`
    UnknownUser,
    /// `invalid-username-encoding`
    InvalidUsernameEncoding,
    /// `no-resources`
    NoResources,
    /// `other-error`
    OtherError,
}

impl ServerErrorValue {
    /// Returns the wire spelling of this error value as bytes.
    pub const fn as_bytes(self) -> &'static [u8] {
        self.wire_name().as_bytes()
    }

    /// Returns a human-readable description: the wire spelling with spaces in
    /// place of hyphens.
    pub const fn as_str(self) -> &'static str {
        match self {
            Self::InvalidEncoding => "invalid encoding",
            Self::ExtensionsNotSupported => "extensions not supported",
            Self::InvalidProof => "invalid proof",
            Self::ChannelBindingsDontMatch => "channel bindings dont match",
            Self::ServerDoesSupportChannelBinding => "server does support channel binding",
            Self::ChannelBindingNotSupported => "channel binding not supported",
            Self::UnsupportedChannelBindingType => "unsupported channel binding type",
            Self::UnknownUser => "unknown user",
            Self::InvalidUsernameEncoding => "invalid username encoding",
            Self::NoResources => "no resources",
            Self::OtherError => "other error",
        }
    }

    /// Returns the wire spelling of this error value.
    const fn wire_name(self) -> &'static str {
        match self {
            Self::InvalidEncoding => "invalid-encoding",
            Self::ExtensionsNotSupported => "extensions-not-supported",
            Self::InvalidProof => "invalid-proof",
            Self::ChannelBindingsDontMatch => "channel-bindings-dont-match",
            Self::ServerDoesSupportChannelBinding => "server-does-support-channel-binding",
            Self::ChannelBindingNotSupported => "channel-binding-not-supported",
            Self::UnsupportedChannelBindingType => "unsupported-channel-binding-type",
            Self::UnknownUser => "unknown-user",
            Self::InvalidUsernameEncoding => "invalid-username-encoding",
            Self::NoResources => "no-resources",
            Self::OtherError => "other-error",
        }
    }
}

impl Display for ServerErrorValue {
    /// Writes the human-readable description verbatim (formatter padding is not applied).
    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
        let description = self.as_str();
        f.write_str(description)
    }
}
547
548pub enum ServerFinal<'scram> {
549 Verifier(&'scram [u8]),
550 Error(ServerErrorValue),
551}
552
553impl<'scram> ServerFinal<'scram> {
554 pub fn parse(input: &'scram [u8]) -> Result<Self, ParseError> {
555 if &input[0..2] == b"v=" {
556 Ok(Self::Verifier(&input[2..]))
557 } else if &input[0..2] == b"e=" {
558 use ServerErrorValue::{
559 ChannelBindingNotSupported, ChannelBindingsDontMatch, ExtensionsNotSupported,
560 InvalidEncoding, InvalidProof, InvalidUsernameEncoding, NoResources, OtherError,
561 ServerDoesSupportChannelBinding, UnknownUser, UnsupportedChannelBindingType,
562 };
563 let e = match &input[2..] {
564 b"invalid-encoding" => InvalidEncoding,
565 b"extensions-not-supported" => ExtensionsNotSupported,
566 b"invalid-proof" => InvalidProof,
567 b"channel-bindings-dont-match" => ChannelBindingsDontMatch,
568 b"server-does-support-channel-binding" => ServerDoesSupportChannelBinding,
569 b"channel-binding-not-supported" => ChannelBindingNotSupported,
570 b"unsupported-channel-binding-type" => UnsupportedChannelBindingType,
571 b"unknown-user" => UnknownUser,
572 b"invalid-username-encoding" => InvalidUsernameEncoding,
573 b"no-resources" => NoResources,
574 _ => OtherError,
575 };
576 Ok(Self::Error(e))
577 } else {
578 Err(ParseError::InvalidAttribute(input[0]))
579 }
580 }
581
582 pub const fn to_ioslices(&self) -> [&'scram [u8]; 2] {
583 match self {
584 Self::Verifier(v) => [b"v=", v],
585 Self::Error(e) => [b"e=", e.as_bytes()],
586 }
587 }
588}
589
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_gs2_cbind_flag() {
        let valid: [(&[u8], GS2CBindFlag); 7] = [
            (b"n", GS2CBindFlag::NotSupported),
            (b"y", GS2CBindFlag::SupportedNotUsed),
            (b"p=tls-unique", GS2CBindFlag::Used("tls-unique")),
            (b"p=.", GS2CBindFlag::Used(".")),
            (b"p=-", GS2CBindFlag::Used("-")),
            (b"p=a", GS2CBindFlag::Used("a")),
            (
                b"p=a-very-long-cb-name.indeed",
                GS2CBindFlag::Used("a-very-long-cb-name.indeed"),
            ),
        ];

        for (input, output) in &valid {
            assert_eq!(GS2CBindFlag::parse(input), Ok(*output));
        }
    }

    /// Malformed flags must be rejected with the matching error, never accepted.
    #[test]
    fn test_parse_gs2_cbind_flag_invalid() {
        let invalid: [(&[u8], ParseError); 7] = [
            (b"", ParseError::BadCBFlag),
            (b"N", ParseError::BadCBFlag),
            (b"yy", ParseError::BadCBFlag),
            (b"p", ParseError::BadCBFlag),
            (b"p=", ParseError::BadCBFlag),
            (b"p=tls unique", ParseError::BadCBName(b' ')),
            (b"p=tls_unique", ParseError::BadCBName(b'_')),
        ];

        for (input, error) in &invalid {
            assert_eq!(GS2CBindFlag::parse(input), Err(*error));
        }
    }

    #[test]
    fn test_sasl_name_escape() {
        assert_eq!(SaslName::escape("user").unwrap(), "user");
        assert_eq!(SaslName::escape("a,b=c").unwrap(), "a=2Cb=3Dc");
        assert_eq!(SaslName::escape(""), Err(SaslNameError::Empty));
        assert_eq!(SaslName::escape("a\0b"), Err(SaslNameError::InvalidChar(0)));
    }

    #[test]
    fn test_sasl_name_unescape_without_escapes() {
        assert_eq!(SaslName::unescape(b"user").unwrap(), "user");
        assert_eq!(SaslName::unescape(b""), Err(SaslNameError::Empty));
        assert_eq!(
            SaslName::unescape(b"a,b"),
            Err(SaslNameError::InvalidChar(b','))
        );
    }

    #[test]
    fn test_parse_client_first() {
        let msg = ClientFirstMessage::parse(b"n,,n=user,r=fyko+d2lbbFgONRv9qkxdawL").unwrap();
        assert_eq!(
            msg,
            ClientFirstMessage::new(
                GS2CBindFlag::NotSupported,
                None,
                "user",
                b"fyko+d2lbbFgONRv9qkxdawL"
            )
        );
        assert_eq!(msg.build_gs2_header_vec(), b"n,,");

        let msg = ClientFirstMessage::parse(b"y,a=admin,n=user,r=abc").unwrap();
        assert_eq!(msg.authzid, Some("admin"));
        assert_eq!(msg.build_gs2_header_vec(), b"y,a=admin,");
    }

    #[test]
    fn test_parse_server_first() {
        assert_eq!(
            ServerFirst::parse(b"r=abcdef,s=QSXCR+Q6sek8bf92,i=4096"),
            Ok(ServerFirst {
                nonce: b"abcdef",
                server_nonce: None,
                salt: b"QSXCR+Q6sek8bf92",
                iteration_count: b"4096",
            })
        );
    }

    #[test]
    fn test_parse_client_final() {
        let msg = ClientFinal::parse(b"c=biws,r=abc,p=dHpi").unwrap();
        assert_eq!(msg.channel_binding, b"biws");
        assert_eq!(msg.nonce, b"abc");
        assert_eq!(msg.proof, b"dHpi");
    }

    #[test]
    fn test_parse_server_final() {
        assert!(matches!(
            ServerFinal::parse(b"v=abc"),
            Ok(ServerFinal::Verifier(b"abc"))
        ));
        assert!(matches!(
            ServerFinal::parse(b"e=invalid-proof"),
            Ok(ServerFinal::Error(ServerErrorValue::InvalidProof))
        ));
        assert!(matches!(
            ServerFinal::parse(b"e=something-new"),
            Ok(ServerFinal::Error(ServerErrorValue::OtherError))
        ));
    }
}