use logging_timer::time;
use rand::{
distributions::{Alphanumeric, DistString},
thread_rng,
};
use serde_with::{DeserializeFromStr, SerializeDisplay};
use std::convert::From;
use std::fmt;
/// Maximum number of bytes a [`Salt`] can hold. `Salt::from_str` rejects longer input
/// (measured in UTF-8 bytes), and [`Salt::generate_random`] samples exactly this many characters.
pub const MAX_LENGTH_BYTES: usize = 32;
/// `expect` message used by [`Salt::generate_random`]: the sampled string is exactly
/// `MAX_LENGTH_BYTES` ASCII characters long, so parsing it cannot hit the too-long error.
const STRING_CONVERSION_ERR_MSG: &str = "A failure should not be possible here because the length of the random string exactly matches the max allowed length";
#[derive(Debug, Clone, PartialEq, SerializeDisplay, DeserializeFromStr)]
pub struct Salt([u8; 32]);
impl Salt {
pub fn as_bytes(&self) -> &[u8; 32] {
&self.0
}
#[time("debug", "NdmSmt::NdmSmtSalts::{}")]
pub fn generate_random() -> Self {
let mut rng = thread_rng();
let random_str = Alphanumeric.sample_string(&mut rng, MAX_LENGTH_BYTES);
Salt::from_str(&random_str).expect(STRING_CONVERSION_ERR_MSG)
}
}
impl fmt::Display for Salt {
    /// Writes all 32 salt bytes as text, decoded with `String::from_utf8_lossy`.
    ///
    /// Zero padding shows up as trailing NUL characters. Invalid UTF-8 sequences are
    /// replaced with U+FFFD, so such output can be longer than 32 bytes and may not
    /// parse back into the same salt.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str(&String::from_utf8_lossy(&self.0))
    }
}
use crate::kdf;
impl From<kdf::Key> for Salt {
fn from(key: kdf::Key) -> Self {
let bytes: [u8; 32] = key.into();
Salt(bytes)
}
}
use std::str::FromStr;
impl FromStr for Salt {
    type Err = SaltParserError;

    /// Parses a salt from the UTF-8 bytes of `s`, right-padding with zero bytes.
    ///
    /// # Errors
    /// Returns [`SaltParserError::StringTooLongError`] if `s` is longer than
    /// [`MAX_LENGTH_BYTES`] bytes. The limit counts bytes, not characters.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let raw = s.as_bytes();
        if raw.len() > MAX_LENGTH_BYTES {
            return Err(SaltParserError::StringTooLongError);
        }
        let mut buf = [0u8; MAX_LENGTH_BYTES];
        buf[..raw.len()].copy_from_slice(raw);
        Ok(Salt(buf))
    }
}
impl From<Salt> for [u8; 32] {
fn from(item: Salt) -> Self {
item.0
}
}
impl From<u64> for Salt {
    /// Builds a salt whose first 8 bytes are `num` in little-endian order. The remaining
    /// 24 bytes are zero.
    fn from(num: u64) -> Self {
        let mut padded = [0u8; 32];
        for (slot, byte) in padded.iter_mut().zip(num.to_le_bytes()) {
            *slot = byte;
        }
        Salt(padded)
    }
}
use clap::builder::OsStr;
impl From<Salt> for OsStr {
fn from(salt: Salt) -> OsStr {
OsStr::from(String::from_utf8_lossy(&salt.0).into_owned())
}
}
impl Default for Salt {
fn default() -> Self {
Salt::generate_random()
}
}
#[derive(Debug, thiserror::Error)]
pub enum SaltParserError {
#[error("The given string has more than the max allowed bytes of {MAX_LENGTH_BYTES}")]
StringTooLongError,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Two independently generated salts should agree on fewer than `THRESHOLD` byte
    /// positions; many matches would indicate the generator is not actually random.
    #[test]
    fn randomly_generated_salts_differ_enough() {
        const THRESHOLD: usize = 10;
        let first = Salt::generate_random();
        let second = Salt::generate_random();
        let matching_positions = first
            .as_bytes()
            .iter()
            .zip(second.as_bytes())
            .filter(|(a, b)| a == b)
            .count();
        assert!(matching_positions < THRESHOLD);
    }
}