use crate::{Error, Result, StringBuf};
use base64ct::{Base64Unpadded as B64, Encoding};
use core::{
fmt,
ops::Deref,
str::{self, FromStr},
};
#[cfg(feature = "rand_core")]
use rand_core::{CryptoRng, TryCryptoRng};
/// Panic message used by `expect` calls whose failure would mean a salt length or
/// encoding invariant (enforced by the constructors in this file) has been broken.
const INVARIANT_VIOLATED_MSG: &str = "salt string invariant violated";
#[derive(Copy, Clone, Eq, PartialEq)]
pub struct Salt {
pub(super) length: u8,
pub(super) bytes: [u8; Self::MAX_LENGTH],
}
// NOTE(review): no inherent `len` is defined in this impl (length is exposed via
// `Deref<Target = [u8]>`), so this allow may be stale — confirm before removing.
#[allow(clippy::len_without_is_empty)]
impl Salt {
    /// Smallest accepted salt, in bytes.
    pub const MIN_LENGTH: usize = 8;

    /// Largest accepted salt, in bytes (also the size of the inline buffer).
    pub const MAX_LENGTH: usize = 48;

    /// Number of random bytes drawn by `generate` and `try_from_rng`.
    pub const RECOMMENDED_LENGTH: usize = 16;

    /// Draws [`Salt::RECOMMENDED_LENGTH`] random bytes via `getrandom`.
    ///
    /// # Panics
    ///
    /// Panics with `"RNG failure"` if `getrandom::fill` returns an error.
    #[cfg(feature = "getrandom")]
    pub fn generate() -> Self {
        let mut random = [0u8; Self::RECOMMENDED_LENGTH];
        getrandom::fill(&mut random).expect("RNG failure");
        // RECOMMENDED_LENGTH lies within MIN_LENGTH..=MAX_LENGTH, so `new` succeeds.
        Self::new(&random).expect(INVARIANT_VIOLATED_MSG)
    }

    /// Builds a salt from an infallible cryptographic RNG.
    #[cfg(feature = "rand_core")]
    pub fn from_rng<R: CryptoRng + ?Sized>(rng: &mut R) -> Self {
        // Irrefutable pattern: relies on the `CryptoRng` error type being uninhabited.
        let Ok(salt) = Self::try_from_rng(rng);
        salt
    }

    /// Builds a salt of [`Salt::RECOMMENDED_LENGTH`] bytes from a fallible cryptographic RNG.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `rng.try_fill_bytes`.
    #[cfg(feature = "rand_core")]
    pub fn try_from_rng<R: TryCryptoRng + ?Sized>(
        rng: &mut R,
    ) -> core::result::Result<Self, R::Error> {
        let mut random = [0u8; Self::RECOMMENDED_LENGTH];
        rng.try_fill_bytes(&mut random)?;
        let salt = Self::new(&random).expect(INVARIANT_VIOLATED_MSG);
        Ok(salt)
    }

    /// Copies `slice` into a new salt.
    ///
    /// # Errors
    ///
    /// - [`Error::SaltTooShort`] if `slice` is shorter than [`Salt::MIN_LENGTH`].
    /// - [`Error::SaltTooLong`] if `slice` is longer than [`Salt::MAX_LENGTH`].
    pub fn new(slice: &[u8]) -> Result<Self> {
        let len = slice.len();
        if len < Self::MIN_LENGTH {
            return Err(Error::SaltTooShort);
        }
        let mut bytes = [0u8; Self::MAX_LENGTH];
        // `get_mut` fails exactly when `len` exceeds the buffer, i.e. MAX_LENGTH.
        let Some(prefix) = bytes.get_mut(..len) else {
            return Err(Error::SaltTooLong);
        };
        prefix.copy_from_slice(slice);
        // `len` is at most MAX_LENGTH (48) here, so it fits in a `u8`.
        Ok(Self {
            length: len as u8,
            bytes,
        })
    }

    /// Decodes an unpadded Base64 string into a salt.
    ///
    /// # Errors
    ///
    /// - [`Error::SaltTooShort`] if `b64.len()` is below [`SaltString::MIN_LENGTH`].
    /// - [`Error::SaltTooLong`] if `b64.len()` is above [`SaltString::MAX_LENGTH`].
    /// - A Base64 decoding error (converted into [`Error`]) for malformed input.
    pub fn from_b64(b64: &str) -> Result<Self> {
        let encoded_len = b64.len();
        if encoded_len < SaltString::MIN_LENGTH {
            return Err(Error::SaltTooShort);
        } else if encoded_len > SaltString::MAX_LENGTH {
            return Err(Error::SaltTooLong);
        }
        // At most 64 Base64 chars decode to at most 48 bytes, so the buffer suffices.
        let mut bytes = [0u8; Self::MAX_LENGTH];
        let decoded_len = B64::decode(b64, &mut bytes)?.len();
        Ok(Self {
            length: decoded_len as u8,
            bytes,
        })
    }

    /// Encodes this salt as a [`SaltString`].
    pub fn to_salt_string(&self) -> SaltString {
        SaltString::from(self)
    }
}
impl AsRef<[u8]> for Salt {
    /// Borrows only the meaningful prefix of the inline buffer.
    fn as_ref(&self) -> &[u8] {
        let len = usize::from(self.length);
        &self.bytes[..len]
    }
}
impl Deref for Salt {
    type Target = [u8];

    /// Dereferences to the same byte slice returned by `AsRef<[u8]>`.
    fn deref(&self) -> &Self::Target {
        <Self as AsRef<[u8]>>::as_ref(self)
    }
}
impl FromStr for Salt {
type Err = Error;
fn from_str(b64: &str) -> Result<Self> {
Self::from_b64(b64)
}
}
impl TryFrom<&[u8]> for Salt {
type Error = Error;
fn try_from(slice: &[u8]) -> Result<Self> {
Self::new(slice)
}
}
impl TryFrom<&str> for Salt {
    type Error = Error;

    /// Parses an unpadded Base64 salt via the `FromStr` impl (i.e. `Salt::from_b64`).
    fn try_from(s: &str) -> Result<Self> {
        s.parse()
    }
}
impl fmt::Display for Salt {
    /// Formats the salt as its unpadded Base64 [`SaltString`] encoding.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Name the trait explicitly: neither `fmt::Display` nor `fmt::Debug` is
        // imported as a trait in this file, and `SaltString` implements both, so a
        // bare `.fmt(f)` method call relies on trait scoping to pick the right one.
        fmt::Display::fmt(&self.to_salt_string(), f)
    }
}
impl fmt::Debug for Salt {
    /// Renders as `Salt([..bytes..])`, showing only the meaningful bytes.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let bytes: &[u8] = self;
        let mut tuple = f.debug_tuple("Salt");
        tuple.field(&bytes);
        tuple.finish()
    }
}
/// Unpadded-Base64 encoding of a salt, held in a `StringBuf` sized by
/// `SaltString::MAX_LENGTH`.
#[derive(Eq, Clone)]
pub struct SaltString(StringBuf<{ SaltString::MAX_LENGTH }>);
// NOTE(review): no inherent `len` is defined in this impl (length is exposed via
// `Deref<Target = str>`), so this allow may be stale — confirm before removing.
#[allow(clippy::len_without_is_empty)]
impl SaltString {
    /// Shortest accepted encoded salt, in characters (decodes to 8 bytes).
    pub const MIN_LENGTH: usize = 11;

    /// Longest accepted encoded salt, in characters (decodes to 48 bytes).
    pub const MAX_LENGTH: usize = 64;

    /// Generates a random salt via `Salt::generate` and encodes it.
    #[cfg(feature = "getrandom")]
    pub fn generate() -> Self {
        Self::from(Salt::generate())
    }

    /// Generates and encodes a salt from an infallible cryptographic RNG.
    #[cfg(feature = "rand_core")]
    pub fn from_rng<R: CryptoRng + ?Sized>(rng: &mut R) -> Self {
        // Irrefutable pattern: relies on the `CryptoRng` error type being uninhabited.
        let Ok(salt_string) = Self::try_from_rng(rng);
        salt_string
    }

    /// Generates and encodes a salt from a fallible cryptographic RNG.
    ///
    /// # Errors
    ///
    /// Propagates the RNG error from `Salt::try_from_rng`.
    #[cfg(feature = "rand_core")]
    pub fn try_from_rng<R: TryCryptoRng + ?Sized>(
        rng: &mut R,
    ) -> core::result::Result<Self, R::Error> {
        Salt::try_from_rng(rng).map(Self::from)
    }

    /// Validates `s` as an unpadded Base64 salt and stores it verbatim.
    ///
    /// # Errors
    ///
    /// Returns any error from `Salt::from_b64` (length or encoding), or from
    /// parsing `s` into the inner `StringBuf`.
    pub fn from_b64(s: &str) -> Result<Self> {
        // Decode only to validate length and alphabet; the decoded bytes are discarded.
        let _ = Salt::from_b64(s)?;
        let inner = s.parse()?;
        Ok(Self(inner))
    }

    /// Decodes this string back into a [`Salt`].
    pub fn to_salt(&self) -> Salt {
        Salt::from(self)
    }
}
impl AsRef<str> for SaltString {
    /// Borrows the encoded salt as a string slice.
    fn as_ref(&self) -> &str {
        // Deref coercion through `Deref<Target = str>`, which returns `&self.0`.
        self
    }
}
impl Deref for SaltString {
    type Target = str;

    /// Borrows the inner buffer as a string slice.
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}
impl From<Salt> for SaltString {
fn from(salt: Salt) -> Self {
SaltString::from(&salt)
}
}
impl From<&Salt> for SaltString {
fn from(salt: &Salt) -> Self {
let mut buf = [0; SaltString::MAX_LENGTH];
let b64 = B64::encode(salt, &mut buf).expect(INVARIANT_VIOLATED_MSG);
SaltString(b64.parse().expect(INVARIANT_VIOLATED_MSG))
}
}
impl From<SaltString> for Salt {
fn from(salt: SaltString) -> Self {
Salt::from(&salt)
}
}
impl From<&SaltString> for Salt {
    /// Decodes the string; it was validated by `Salt::from_b64` (or produced by
    /// encoding a valid `Salt`) when constructed, so decoding is expected to succeed.
    fn from(salt: &SaltString) -> Self {
        Self::from_b64(salt).expect(INVARIANT_VIOLATED_MSG)
    }
}
impl FromStr for SaltString {
type Err = Error;
fn from_str(s: &str) -> Result<Self> {
Self::from_b64(s)
}
}
impl PartialEq for SaltString {
    /// Compares the encoded text of both salt strings.
    fn eq(&self, other: &Self) -> bool {
        **self == **other
    }
}
impl fmt::Display for SaltString {
    /// Writes the encoded salt verbatim (formatter flags are not applied).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(self)
    }
}
impl fmt::Debug for SaltString {
    /// Renders as `SaltString("...")` using the string's `Debug` quoting.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "SaltString({:?})", &**self)
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::{Error, Salt, SaltString};

    #[test]
    fn new_with_valid_min_length_input() {
        let s = "abcdabcdabc";
        let salt = Salt::from_b64(s).unwrap();
        assert_eq!(
            salt.as_ref(),
            &[0x69, 0xb7, 0x1d, 0x69, 0xb7, 0x1d, 0x69, 0xb7]
        );
    }

    /// 48 Base64 characters decode to 36 bytes, which is below `Salt::MAX_LENGTH`.
    #[test]
    fn new_with_valid_48_char_input() {
        let s = "012345678911234567892123456789312345678941234567";
        let salt = Salt::from_b64(s).unwrap();
        assert_eq!(
            salt.as_ref(),
            &[
                0xd3, 0x5d, 0xb7, 0xe3, 0x9e, 0xbb, 0xf3, 0xdd, 0x75, 0xdb, 0x7e, 0x39, 0xeb, 0xbf,
                0x3d, 0xdb, 0x5d, 0xb7, 0xe3, 0x9e, 0xbb, 0xf3, 0xdd, 0xf5, 0xdb, 0x7e, 0x39, 0xeb,
                0xbf, 0x3d, 0xe3, 0x5d, 0xb7, 0xe3, 0x9e, 0xbb
            ]
        );
    }

    /// The longest accepted encoding (`SaltString::MAX_LENGTH` chars) decodes to
    /// exactly `Salt::MAX_LENGTH` bytes and re-encodes to the same text.
    #[test]
    fn new_with_valid_max_length_input() {
        let s = "0123456789112345678921234567893123456789412345678952345678963456";
        assert_eq!(s.len(), SaltString::MAX_LENGTH);
        let salt = Salt::from_b64(s).unwrap();
        assert_eq!(salt.len(), Salt::MAX_LENGTH);
        assert_eq!(&*salt.to_salt_string(), s);
    }

    #[test]
    fn reject_new_too_short() {
        // "abcdabcdab" is a well-formed 10-char encoding (7 bytes) but still too short.
        for &too_short in &["", "a", "ab", "abc", "abcdabcdab"] {
            let err = Salt::from_b64(too_short).err().unwrap();
            assert_eq!(err, Error::SaltTooShort);
        }
    }

    #[test]
    fn reject_new_too_long() {
        let s = "01234567891123456789212345678931234567894123456785234567896234567";
        let err = Salt::from_b64(s).err().unwrap();
        assert_eq!(err, Error::SaltTooLong);
    }

    #[test]
    fn reject_new_invalid_char() {
        let s = "01234_abcde";
        let err = Salt::from_b64(s).err().unwrap();
        assert_eq!(err, Error::Base64(base64ct::Error::InvalidEncoding));
    }

    #[test]
    fn new_from_slice_enforces_length_bounds() {
        assert_eq!(Salt::new(&[0u8; 7]).err().unwrap(), Error::SaltTooShort);
        assert_eq!(Salt::new(&[0u8; 49]).err().unwrap(), Error::SaltTooLong);
        assert_eq!(
            Salt::new(&[0u8; Salt::MIN_LENGTH]).unwrap().len(),
            Salt::MIN_LENGTH
        );
        assert_eq!(
            Salt::new(&[0u8; Salt::MAX_LENGTH]).unwrap().len(),
            Salt::MAX_LENGTH
        );
    }
}