use crate::{Encoding, Error, Result, Value};
use core::{fmt, str};
use crate::errors::InvalidValue;
#[cfg(feature = "rand_core")]
use rand_core::CryptoRngCore;
/// Panic message passed to `expect` wherever the code assumes a `SaltString`
/// holds a valid salt (UTF-8 text accepted by `Salt::from_b64`).
const INVARIANT_VIOLATED_MSG: &str = "salt string invariant violated";
/// A validated salt string borrowed from its input.
///
/// Construct it with `Salt::from_b64`, which enforces the length bounds
/// (`Salt::MIN_LENGTH`..=`Salt::MAX_LENGTH` bytes) and the allowed character set.
/// Internally the string is stored as a `Value`.
#[derive(Copy, Clone, Eq, PartialEq)]
pub struct Salt<'a>(Value<'a>);
#[allow(clippy::len_without_is_empty)]
impl<'a> Salt<'a> {
pub const MIN_LENGTH: usize = 4;
pub const MAX_LENGTH: usize = 64;
pub const RECOMMENDED_LENGTH: usize = 16;
pub fn from_b64(input: &'a str) -> Result<Self> {
let length = input.as_bytes().len();
if length < Self::MIN_LENGTH {
return Err(Error::SaltInvalid(InvalidValue::TooShort));
}
if length > Self::MAX_LENGTH {
return Err(Error::SaltInvalid(InvalidValue::TooLong));
}
for char in input.chars() {
if !matches!(char, 'a'..='z' | 'A'..='Z' | '0'..='9' | '/' | '+' | '.' | '-') {
return Err(Error::SaltInvalid(InvalidValue::InvalidChar(char)));
}
}
input.try_into().map(Self).map_err(|e| match e {
Error::ParamValueInvalid(value_err) => Error::SaltInvalid(value_err),
err => err,
})
}
pub fn decode_b64<'b>(&self, buf: &'b mut [u8]) -> Result<&'b [u8]> {
self.0.b64_decode(buf)
}
pub fn as_str(&self) -> &'a str {
self.0.as_str()
}
pub fn len(&self) -> usize {
self.as_str().len()
}
#[deprecated(since = "0.5.0", note = "use `from_b64` instead")]
pub fn new(input: &'a str) -> Result<Self> {
Self::from_b64(input)
}
#[deprecated(since = "0.5.0", note = "use `decode_b64` instead")]
pub fn b64_decode<'b>(&self, buf: &'b mut [u8]) -> Result<&'b [u8]> {
self.decode_b64(buf)
}
}
impl<'a> AsRef<str> for Salt<'a> {
fn as_ref(&self) -> &str {
self.as_str()
}
}
impl<'a> TryFrom<&'a str> for Salt<'a> {
type Error = Error;
fn try_from(input: &'a str) -> Result<Self> {
Self::from_b64(input)
}
}
impl fmt::Display for Salt<'_> {
    /// Writes the raw salt string. Width/fill flags are ignored because the
    /// text goes through `write_str` rather than `pad`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let text: &str = self.as_str();
        f.write_str(text)
    }
}
impl<'a> fmt::Debug for Salt<'a> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "Salt({:?})", self.as_str())
}
}
/// Owned, fixed-capacity counterpart of [`Salt`].
///
/// The salt characters are stored inline in a `Salt::MAX_LENGTH`-byte array.
/// `length` records how many leading bytes are in use. `PartialEq` is
/// implemented by hand in this file and compares only the in-use prefix
/// (via `as_str`).
#[derive(Clone, Eq)]
pub struct SaltString {
    // Salt characters; only `chars[..length]` is meaningful.
    chars: [u8; Salt::MAX_LENGTH],
    // Number of leading bytes of `chars` that hold the salt.
    length: u8,
}
// Every constructor validates against `Salt::MIN_LENGTH`, so a `SaltString` is
// never empty and an `is_empty` method would always return `false`.
#[allow(clippy::len_without_is_empty)]
impl SaltString {
    /// Generate a random [`SaltString`] by B64-encoding
    /// [`Salt::RECOMMENDED_LENGTH`] bytes drawn from `rng`.
    ///
    /// # Panics
    ///
    /// Panics (via `expect`) if the encoded bytes are rejected by
    /// [`SaltString::encode_b64`]. This is not expected for this fixed input size.
    #[cfg(feature = "rand_core")]
    pub fn generate(mut rng: impl CryptoRngCore) -> Self {
        let mut bytes = [0u8; Salt::RECOMMENDED_LENGTH];
        rng.fill_bytes(&mut bytes);
        Self::encode_b64(&bytes).expect(INVARIANT_VIOLATED_MSG)
    }

    /// Create a [`SaltString`] by copying an already B64-encoded salt string.
    ///
    /// # Errors
    ///
    /// Returns the same errors as [`Salt::from_b64`] when `s` is not a valid salt.
    pub fn from_b64(s: &str) -> Result<Self> {
        // Validating first guarantees `s` fits in the inline buffer (at most
        // `Salt::MAX_LENGTH` bytes). That bound must stay <= `u8::MAX` for the
        // `as u8` cast below to be lossless.
        Salt::from_b64(s)?;
        let len = s.len();
        let mut bytes = [0u8; Salt::MAX_LENGTH];
        bytes[..len].copy_from_slice(s.as_bytes());
        Ok(SaltString {
            chars: bytes,
            length: len as u8,
        })
    }

    /// Decode the stored B64 salt into `buf` and return the decoded bytes,
    /// borrowed from `buf`. Delegates to [`Salt::decode_b64`].
    ///
    /// # Errors
    ///
    /// Propagates any decoding error from [`Salt::decode_b64`].
    pub fn decode_b64<'a>(&self, buf: &'a mut [u8]) -> Result<&'a [u8]> {
        self.as_salt().decode_b64(buf)
    }

    /// B64-encode `input` into a new [`SaltString`].
    ///
    /// # Errors
    ///
    /// - Returns the encoder's error if `input` cannot be encoded into a
    ///   [`Salt::MAX_LENGTH`]-byte buffer (i.e. `input` is too long).
    /// - Returns [`Error::SaltInvalid`] if the encoding is not a valid [`Salt`],
    ///   e.g. when it is shorter than [`Salt::MIN_LENGTH`].
    pub fn encode_b64(input: &[u8]) -> Result<Self> {
        let mut bytes = [0u8; Salt::MAX_LENGTH];
        let encoded = Encoding::B64.encode(input, &mut bytes)?;

        // Enforce the invariant `as_salt` relies on. Without this check, a too-short
        // encoding would be accepted here and only panic later in `as_salt`,
        // `decode_b64`, or `From<&SaltString> for Salt`.
        Salt::from_b64(encoded)?;

        // `encoded` fits in the `Salt::MAX_LENGTH`-byte buffer, so it fits in a `u8`.
        let length = encoded.len() as u8;
        Ok(Self {
            chars: bytes,
            length,
        })
    }

    /// Borrow this value as a [`Salt`].
    ///
    /// # Panics
    ///
    /// Panics if the stored string is not a valid salt. Every constructor in this
    /// impl validates its result, so a panic indicates a broken invariant.
    pub fn as_salt(&self) -> Salt<'_> {
        Salt::from_b64(self.as_str()).expect(INVARIANT_VIOLATED_MSG)
    }

    /// Borrow the in-use prefix of the buffer as a string slice.
    ///
    /// # Panics
    ///
    /// Panics if that prefix is not UTF-8. The constructors only store validated
    /// salt text, so this indicates a broken invariant.
    pub fn as_str(&self) -> &str {
        str::from_utf8(&self.chars[..(self.length as usize)]).expect(INVARIANT_VIOLATED_MSG)
    }

    /// Length of the salt string in bytes.
    pub fn len(&self) -> usize {
        self.as_str().len()
    }

    /// Deprecated alias of [`SaltString::from_b64`].
    #[deprecated(since = "0.5.0", note = "use `from_b64` instead")]
    pub fn new(s: &str) -> Result<Self> {
        Self::from_b64(s)
    }

    /// Deprecated alias of [`SaltString::decode_b64`].
    #[deprecated(since = "0.5.0", note = "use `decode_b64` instead")]
    pub fn b64_decode<'a>(&self, buf: &'a mut [u8]) -> Result<&'a [u8]> {
        self.decode_b64(buf)
    }

    /// Deprecated alias of [`SaltString::encode_b64`].
    #[deprecated(since = "0.5.0", note = "use `encode_b64` instead")]
    pub fn b64_encode(input: &[u8]) -> Result<Self> {
        Self::encode_b64(input)
    }
}
impl AsRef<str> for SaltString {
fn as_ref(&self) -> &str {
self.as_str()
}
}
impl PartialEq for SaltString {
    /// Two salt strings are equal when their in-use text is equal. Bytes in the
    /// buffer beyond `length` are never compared.
    fn eq(&self, other: &Self) -> bool {
        let lhs: &str = self.as_str();
        let rhs: &str = other.as_str();
        lhs == rhs
    }
}
impl<'a> From<&'a SaltString> for Salt<'a> {
fn from(salt_string: &'a SaltString) -> Salt<'a> {
salt_string.as_salt()
}
}
impl fmt::Display for SaltString {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str(self.as_str())
}
}
impl fmt::Debug for SaltString {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "SaltString({:?})", self.as_str())
}
}
#[cfg(test)]
mod tests {
    use super::{Error, Salt, SaltString};
    use crate::errors::InvalidValue;

    #[test]
    fn new_with_valid_min_length_input() {
        let s = "abcd";
        assert_eq!(s.len(), Salt::MIN_LENGTH);
        let salt = Salt::from_b64(s).unwrap();
        assert_eq!(salt.as_ref(), s);
    }

    #[test]
    fn new_with_valid_max_length_input() {
        // Exactly `Salt::MAX_LENGTH` characters, so the upper bound itself is
        // exercised (the length assert guards against the literal drifting).
        let s = "0123456789112345678921234567893123456789412345678951234567896123";
        assert_eq!(s.len(), Salt::MAX_LENGTH);
        let salt = Salt::from_b64(s).unwrap();
        assert_eq!(salt.as_ref(), s);
    }

    #[test]
    fn reject_new_too_short() {
        for &too_short in &["", "a", "ab", "abc"] {
            let err = Salt::from_b64(too_short).unwrap_err();
            assert_eq!(err, Error::SaltInvalid(InvalidValue::TooShort));
        }
    }

    #[test]
    fn reject_new_too_long() {
        // One character past `Salt::MAX_LENGTH`.
        let s = "01234567891123456789212345678931234567894123456785234567896234567";
        assert_eq!(s.len(), Salt::MAX_LENGTH + 1);
        let err = Salt::from_b64(s).unwrap_err();
        assert_eq!(err, Error::SaltInvalid(InvalidValue::TooLong));
    }

    #[test]
    fn reject_new_invalid_char() {
        let s = "01234_abcd";
        let err = Salt::from_b64(s).unwrap_err();
        assert_eq!(err, Error::SaltInvalid(InvalidValue::InvalidChar('_')));
    }

    #[test]
    fn salt_string_from_b64_round_trips() {
        let s = "abcdefgh";
        let salt_string = SaltString::from_b64(s).unwrap();
        assert_eq!(salt_string.as_str(), s);
        assert_eq!(salt_string.len(), s.len());
        assert_eq!(salt_string.as_salt().as_str(), s);
    }

    #[test]
    fn salt_string_encode_decode_round_trip() {
        let raw = [0x42u8; Salt::RECOMMENDED_LENGTH];
        let salt_string = SaltString::encode_b64(&raw).unwrap();
        let mut buf = [0u8; Salt::RECOMMENDED_LENGTH];
        let decoded = salt_string.decode_b64(&mut buf).unwrap();
        assert_eq!(decoded, &raw[..]);
    }
}