use core::convert::TryFrom;
use core::fmt;
use core::ops::Deref;
use thiserror::Error;
use crate::mechname::MechanismNameError::InvalidChar;
/// A mechanism name, stored as a dynamically sized wrapper around `str`.
///
/// The only field is an unsized `str`, so values of this type are always
/// handled behind a reference (`&Mechname`). The checked way to obtain one is
/// `Mechname::parse`, which rejects empty input and any byte other than
/// `A-Z`, `0-9`, `-` and `_`.
///
/// NOTE(review): the accepted alphabet looks like the SASL `mech-char` set
/// from RFC 4422 — confirm before citing the RFC in public docs.
// `repr(transparent)` gives `Mechname` the same layout as its single `str`
// field; `const_new` depends on this when it reinterprets a byte slice.
#[repr(transparent)]
#[derive(Eq, PartialEq)]
pub struct Mechname {
// The name itself. `parse` only ever stores bytes that passed `is_valid`.
inner: str,
}
impl Mechname {
    /// Validates `input` and, on success, borrows it as a `Mechname`.
    ///
    /// No copy is made: the returned reference points at the same bytes as
    /// `input`. There is no upper length bound; only emptiness and the byte
    /// alphabet are checked.
    ///
    /// # Errors
    ///
    /// * `MechanismNameError::TooShort` if `input` is empty.
    /// * `MechanismNameError::InvalidChar` with the offset and value of the
    ///   first byte for which `is_invalid` returns `true`.
    pub fn parse(input: &[u8]) -> Result<&Self, MechanismNameError> {
        if input.is_empty() {
            return Err(MechanismNameError::TooShort);
        }

        // Stop at the first offending byte so the error reports its offset.
        for (index, &value) in input.iter().enumerate() {
            if is_invalid(value) {
                return Err(InvalidChar { index, value });
            }
        }

        Ok(Self::const_new(input))
    }

    /// Returns the mechanism name as a string slice.
    #[must_use]
    #[inline(always)]
    pub const fn as_str(&self) -> &str {
        &self.inner
    }

    /// Returns the mechanism name as raw bytes.
    #[must_use]
    #[inline(always)]
    pub const fn as_bytes(&self) -> &[u8] {
        self.as_str().as_bytes()
    }

    /// Reinterprets `s` as a `Mechname` without performing any validation.
    ///
    /// Crate-internal: `parse` calls this only after every byte has passed
    /// the alphabet check.
    pub(crate) const fn const_new(s: &[u8]) -> &Self {
        // SAFETY: `Mechname` is `repr(transparent)` over `str`, and `str` has
        // the same layout as `[u8]`, so `&[u8]` and `&Mechname` share one
        // representation. Nothing here checks that `s` is valid UTF-8; `parse`
        // only forwards ASCII bytes, and any other caller has to guarantee
        // this itself.
        unsafe { core::mem::transmute(s) }
    }
}
#[cfg(feature = "unstable_custom_mechanism")]
impl Mechname {
    /// Wraps `s` as a `Mechname` in a `const` context, skipping every check
    /// that `Mechname::parse` performs.
    ///
    /// This simply forwards to the crate-internal `const_new`, which
    /// reinterprets the slice in place.
    ///
    /// NOTE(review): neither emptiness nor the byte alphabet is verified
    /// here; callers are presumably expected to pass only input that `parse`
    /// would accept — confirm and consider documenting this as a contract.
    #[must_use]
    #[inline(always)]
    pub const fn const_new_unchecked(s: &[u8]) -> &Self {
        Mechname::const_new(s)
    }
}
impl fmt::Display for Mechname {
    /// Writes the bare mechanism name to `f` with a single `write_str` call.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name: &str = &self.inner;
        f.write_str(name)
    }
}
impl fmt::Debug for Mechname {
    /// Renders the name wrapped as `MECHANISM(<name>)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = self.as_str();
        f.write_fmt(format_args!("MECHANISM({})", name))
    }
}
impl PartialEq<[u8]> for Mechname {
fn eq(&self, other: &[u8]) -> bool {
self.as_bytes() == other
}
}
impl PartialEq<Mechname> for [u8] {
    /// Mirror of the `Mechname == [u8]` comparison with the operands swapped.
    fn eq(&self, other: &Mechname) -> bool {
        let name_bytes = other.as_bytes();
        name_bytes == self
    }
}
impl PartialEq<str> for Mechname {
fn eq(&self, other: &str) -> bool {
self.as_str() == other
}
}
impl PartialEq<Mechname> for str {
    /// Mirror of the `Mechname == str` comparison with the operands swapped.
    fn eq(&self, other: &Mechname) -> bool {
        let name = other.as_str();
        name == self
    }
}
impl<'a> TryFrom<&'a [u8]> for &'a Mechname {
    type Error = MechanismNameError;

    /// Fallible borrow of a byte slice as a mechanism name.
    ///
    /// Equivalent to calling `Mechname::parse` on `value`; the same errors
    /// are returned.
    fn try_from(value: &'a [u8]) -> Result<Self, Self::Error> {
        let parsed = Mechname::parse(value);
        parsed
    }
}
impl<'a> TryFrom<&'a str> for &'a Mechname {
    type Error = MechanismNameError;

    /// Fallible borrow of a string slice as a mechanism name.
    ///
    /// The string's bytes are handed to `Mechname::parse`, so offsets in an
    /// `InvalidChar` error are byte offsets, not character offsets.
    fn try_from(value: &'a str) -> Result<Self, Self::Error> {
        let bytes = value.as_bytes();
        Mechname::parse(bytes)
    }
}
impl Deref for Mechname {
    type Target = str;

    /// Exposes the underlying `str`, so `str` methods can be called on a
    /// `&Mechname` directly.
    fn deref(&self) -> &Self::Target {
        &self.inner
    }
}
/// Returns `true` when `byte` may NOT appear in a mechanism name.
///
/// Exact negation of `is_valid`.
#[inline(always)]
const fn is_invalid(byte: u8) -> bool {
    let acceptable = is_valid(byte);
    !acceptable
}
/// Returns `true` when `byte` may appear in a mechanism name.
///
/// Accepted bytes are ASCII upper-case letters, ASCII digits, `-` and `_`.
#[inline(always)]
const fn is_valid(byte: u8) -> bool {
    byte.is_ascii_uppercase() || byte.is_ascii_digit() || byte == b'-' || byte == b'_'
}
/// Reasons a byte string is rejected as a mechanism name.
///
/// The derived ordering follows declaration order of the variants and of the
/// fields inside `InvalidChar`, so reordering either changes `Ord` results.
#[derive(Debug, Ord, PartialOrd, Eq, PartialEq, Copy, Clone, Error)]
#[non_exhaustive]
pub enum MechanismNameError {
/// The input was empty.
#[error("a mechanism name can not be empty")]
TooShort,
/// The input contains a byte outside `A-Z`, `0-9`, `-` and `_`.
#[error("contains invalid character at offset {index}: {value:#x}")]
InvalidChar {
/// Zero-based byte offset of the first rejected byte.
index: usize,
/// The rejected byte itself.
value: u8,
},
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Well-formed names round-trip unchanged; malformed names report the
    /// offset and value of the first rejected byte.
    #[test]
    fn test_mechname() {
        // `parse` enforces no upper length bound, so the entries longer than
        // 20 bytes are expected to be accepted as well.
        let valids = [
            "PLAIN",
            "SCRAM-SHA256-PLUS",
            "GS2-KRB5-PLUS",
            "XOAUTHBEARER",
            "EXACTLY_20_CHAR_LONG",
            "X-THIS-MECHNAME-IS-TOO-LONG",
            "EXACTLY_21_CHARS_LONG",
        ];
        // (input, byte offset of the first rejected byte, the rejected byte)
        let invalidchars = [
            ("PLAIN GSSAPI LOGIN", 5, b' '),
            ("SCRAM-SHA256-PLUS GSSAPI X-OAUTH2", 17, b' '),
            ("X-CONTAINS-NULL\0", 15, b'\0'),
            ("PLAIN\0", 5, b'\0'),
            ("X-lowercase", 2, b'l'),
            // `Ä` is two bytes in UTF-8; the first one (0xC3) is reported.
            ("X-LÄTIN1", 3, b'\xC3'),
        ];

        for m in valids {
            println!("Checking {m}");
            let res = Mechname::parse(m.as_bytes()).map(Mechname::as_bytes);
            assert_eq!(res, Ok(m.as_bytes()));
        }

        for (m, index, value) in invalidchars {
            let e = Mechname::parse(m.as_bytes())
                .map(Mechname::as_bytes)
                .unwrap_err();
            println!("Checking {m}: {e}");
            assert_eq!(e, MechanismNameError::InvalidChar { index, value });
        }
    }

    /// Empty input is rejected with `TooShort` through every entry point.
    #[test]
    fn test_mechname_empty() {
        assert_eq!(
            Mechname::parse(b"").map(Mechname::as_bytes),
            Err(MechanismNameError::TooShort)
        );
        assert_eq!(
            <&Mechname as TryFrom<&[u8]>>::try_from(&b""[..]).map(Mechname::as_bytes),
            Err(MechanismNameError::TooShort)
        );
        assert_eq!(
            <&Mechname as TryFrom<&str>>::try_from("").map(Mechname::as_bytes),
            Err(MechanismNameError::TooShort)
        );
    }
}