use crate::{Error, Result, SYNC_POINTS};
use base64ct::{Base64Unpadded as B64, Encoding};
use core::str::FromStr;
#[cfg(feature = "password-hash")]
use password_hash::{ParamsString, PasswordHash};
/// Password-hashing parameters: memory, time and parallelism costs, an optional key
/// identifier and associated data, and an optional output length.
///
/// Values are range-checked when built through `ParamsBuilder`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Params {
    /// Memory cost; `segment_length` treats it as a number of memory blocks.
    m_cost: u32,
    /// Time cost (must be at least `MIN_T_COST`).
    t_cost: u32,
    /// Parallelism cost; `lanes()` returns this value as the number of lanes.
    p_cost: u32,
    /// Key identifier (empty unless one was supplied to the builder).
    keyid: KeyId,
    /// Associated data (empty unless supplied to the builder).
    data: AssociatedData,
    /// Explicitly requested output length (presumably in bytes — confirm), or `None`
    /// when unspecified.
    output_len: Option<usize>,
}
impl Params {
    /// Default memory cost.
    pub const DEFAULT_M_COST: u32 = 4096;

    /// Minimum memory cost: `2 * SYNC_POINTS` blocks.
    // `SYNC_POINTS` is a `usize`; the cast assumes it is small enough to fit in `u32`.
    #[allow(clippy::cast_possible_truncation)]
    pub const MIN_M_COST: u32 = 2 * SYNC_POINTS as u32;

    /// Maximum memory cost.
    pub const MAX_M_COST: u32 = 0x0FFFFFFF;

    /// Default time cost.
    pub const DEFAULT_T_COST: u32 = 3;

    /// Minimum time cost.
    pub const MIN_T_COST: u32 = 1;

    /// Maximum time cost.
    pub const MAX_T_COST: u32 = u32::MAX;

    /// Default parallelism cost (number of lanes).
    pub const DEFAULT_P_COST: u32 = 1;

    /// Minimum parallelism cost.
    pub const MIN_P_COST: u32 = 1;

    /// Maximum parallelism cost.
    pub const MAX_P_COST: u32 = 0xFFFFFF;

    /// Maximum length of a key identifier, in bytes.
    pub const MAX_KEYID_LEN: usize = 8;

    /// Maximum length of the associated data, in bytes.
    pub const MAX_DATA_LEN: usize = 32;

    /// Default output length.
    pub const DEFAULT_OUTPUT_LEN: usize = 32;

    /// Minimum accepted output length.
    pub const MIN_OUTPUT_LEN: usize = 4;

    /// Maximum accepted output length.
    pub const MAX_OUTPUT_LEN: usize = 0xFFFFFFFF;

    /// Create parameters from the three costs and an optional output length.
    ///
    /// The key identifier and associated data are left empty.
    ///
    /// # Errors
    ///
    /// Returns the same errors as [`ParamsBuilder::build`] when a value is out of range.
    pub fn new(m_cost: u32, t_cost: u32, p_cost: u32, output_len: Option<usize>) -> Result<Self> {
        let mut builder = ParamsBuilder::new();
        builder.m_cost(m_cost).t_cost(t_cost).p_cost(p_cost);
        if let Some(len) = output_len {
            builder.output_len(len);
        }
        builder.build()
    }

    /// Memory cost.
    pub fn m_cost(&self) -> u32 {
        self.m_cost
    }

    /// Time cost.
    pub fn t_cost(&self) -> u32 {
        self.t_cost
    }

    /// Parallelism cost.
    pub fn p_cost(&self) -> u32 {
        self.p_cost
    }

    /// Key identifier bytes (empty if none was set).
    pub fn keyid(&self) -> &[u8] {
        self.keyid.as_bytes()
    }

    /// Associated data bytes (empty if none was set).
    pub fn data(&self) -> &[u8] {
        self.data.as_bytes()
    }

    /// Explicitly requested output length, if any.
    pub fn output_len(&self) -> Option<usize> {
        self.output_len
    }

    /// Number of lanes, equal to `p_cost`.
    // `u32` -> `usize` can only truncate on targets narrower than 32 bits.
    #[allow(clippy::cast_possible_truncation)]
    pub(crate) fn lanes(&self) -> usize {
        self.p_cost as usize
    }

    /// Number of memory blocks per lane: `segment_length() * SYNC_POINTS`.
    pub(crate) fn lane_length(&self) -> usize {
        self.segment_length() * SYNC_POINTS
    }

    /// Number of memory blocks per segment (each lane holds `SYNC_POINTS` segments).
    ///
    /// `m_cost` is first raised to at least `2 * SYNC_POINTS` blocks per lane. It is then
    /// divided evenly over `lanes() * SYNC_POINTS` segments. Any remainder that would not
    /// fill a whole segment in every lane is dropped.
    pub(crate) fn segment_length(&self) -> usize {
        let lanes = self.lanes();
        let memory_blocks = (self.m_cost as usize).max(2 * SYNC_POINTS * lanes);
        memory_blocks / (lanes * SYNC_POINTS)
    }

    /// Total number of memory blocks: `segment_length() * lanes() * SYNC_POINTS`.
    ///
    /// This is the effective `m_cost` rounded down to a multiple of `lanes() * SYNC_POINTS`.
    pub fn block_count(&self) -> usize {
        // Already `usize`; the original trailing `as usize` cast was redundant.
        self.segment_length() * self.lanes() * SYNC_POINTS
    }
}
impl Default for Params {
fn default() -> Params {
Params {
m_cost: Self::DEFAULT_M_COST,
t_cost: Self::DEFAULT_T_COST,
p_cost: Self::DEFAULT_P_COST,
keyid: KeyId::default(),
data: AssociatedData::default(),
output_len: None,
}
}
}
/// Generate a fixed-capacity byte buffer type `$ty` holding at most `$max_len` bytes.
///
/// Arguments:
/// - `$ty`: the type name.
/// - `$name`: the type name as a string (used in docs).
/// - `$max_len`: the capacity.
/// - `$error`: the error returned when input is too long.
/// - `$doc`: the type's doc string.
macro_rules! param_buf {
    ($ty:ident, $name:expr, $max_len:expr, $error:expr, $doc:expr) => {
        #[doc = $doc]
        #[derive(Copy, Clone, Debug, Default, Eq, Hash, PartialEq, PartialOrd, Ord)]
        pub struct $ty {
            /// Fixed-capacity backing storage; only the first `len` bytes are in use.
            bytes: [u8; Self::MAX_LEN],
            /// Number of bytes of `bytes` that are in use.
            len: usize,
        }

        impl $ty {
            /// Maximum number of bytes this buffer can hold.
            pub const MAX_LEN: usize = $max_len;

            #[doc = "Create a new"]
            #[doc = $name]
            #[doc = "from a slice."]
            #[doc = ""]
            #[doc = "# Errors"]
            #[doc = ""]
            #[doc = "Returns this type's length error if `slice` is longer than `MAX_LEN` bytes."]
            pub fn new(slice: &[u8]) -> Result<Self> {
                let mut bytes = [0u8; Self::MAX_LEN];
                let len = slice.len();
                // `get_mut` yields `None` (instead of panicking) when `slice` is too long.
                bytes.get_mut(..len).ok_or($error)?.copy_from_slice(slice);
                Ok(Self { bytes, len })
            }

            #[doc = "Decode"]
            #[doc = $name]
            #[doc = " from a B64 string"]
            #[doc = ""]
            #[doc = "# Errors"]
            #[doc = ""]
            #[doc = "Returns this type's length error if the decoded value would exceed"]
            #[doc = "`MAX_LEN` bytes; otherwise propagates any Base64 decoding error."]
            pub fn from_b64(s: &str) -> Result<Self> {
                // Unpadded Base64 yields 3 bytes per full 4-character group, plus 1 or 2
                // bytes for a trailing group of 2 or 3 characters. Checking this up front
                // reports an over-long value with the same error as `new`, instead of a
                // generic Base64 failure. The expression is split so it cannot overflow.
                let decoded_len = s.len() / 4 * 3 + (s.len() % 4) * 3 / 4;
                if decoded_len > Self::MAX_LEN {
                    return Err($error);
                }

                let mut bytes = [0u8; Self::MAX_LEN];
                let len = B64::decode(s, &mut bytes)?.len();
                Ok(Self { bytes, len })
            }

            /// Borrow the bytes in use.
            pub fn as_bytes(&self) -> &[u8] {
                &self.bytes[..self.len]
            }

            /// Number of bytes in use.
            pub fn len(&self) -> usize {
                self.len
            }

            /// Whether no bytes are in use.
            pub fn is_empty(&self) -> bool {
                self.len() == 0
            }
        }

        impl AsRef<[u8]> for $ty {
            fn as_ref(&self) -> &[u8] {
                self.as_bytes()
            }
        }

        impl FromStr for $ty {
            type Err = Error;

            /// Parse from an unpadded Base64 string (same as `from_b64`).
            fn from_str(s: &str) -> Result<Self> {
                Self::from_b64(s)
            }
        }

        impl TryFrom<&[u8]> for $ty {
            type Error = Error;

            /// Copy from a byte slice (same as `new`).
            fn try_from(bytes: &[u8]) -> Result<Self> {
                Self::new(bytes)
            }
        }
    };
}
param_buf!(
KeyId,
"KeyId",
Params::MAX_KEYID_LEN,
Error::KeyIdTooLong,
"Key identifier"
);
// `AssociatedData`: buffer for the `data` parameter; `new` rejects inputs longer than
// `Params::MAX_DATA_LEN` bytes with `Error::AdTooLong`.
param_buf!(
    AssociatedData, "AssociatedData",
    Params::MAX_DATA_LEN, Error::AdTooLong,
    "Associated data"
);
#[cfg(feature = "password-hash")]
#[cfg_attr(docsrs, doc(cfg(feature = "password-hash")))]
impl<'a> TryFrom<&'a PasswordHash<'a>> for Params {
    type Error = password_hash::Error;

    /// Rebuild parameters from a parsed password hash.
    ///
    /// - Recognized parameter names: `m`, `t`, `p`, `keyid` and `data`.
    ///   Any other name yields `ParamNameInvalid`.
    /// - The output length is taken from the hash value, if one is present.
    /// - The assembled values are then validated by `ParamsBuilder::build`.
    fn try_from(hash: &'a PasswordHash<'a>) -> password_hash::Result<Self> {
        let mut builder = ParamsBuilder::new();

        for (name, value) in hash.params.iter() {
            match name.as_str() {
                "m" => builder.m_cost(value.decimal()?),
                "t" => builder.t_cost(value.decimal()?),
                "p" => builder.p_cost(value.decimal()?),
                "keyid" => builder.keyid(value.as_str().parse()?),
                "data" => builder.data(value.as_str().parse()?),
                _ => return Err(password_hash::Error::ParamNameInvalid),
            };
        }

        if let Some(output) = &hash.hash {
            builder.output_len(output.len());
        }

        builder.build().map_err(Into::into)
    }
}
#[cfg(feature = "password-hash")]
#[cfg_attr(docsrs, doc(cfg(feature = "password-hash")))]
impl TryFrom<Params> for ParamsString {
    type Error = password_hash::Error;

    /// Encode owned `params`; delegates to the `&Params` conversion.
    fn try_from(params: Params) -> password_hash::Result<ParamsString> {
        // Was the mis-encoded `¶ms` (HTML `&para` entity), which does not compile.
        ParamsString::try_from(&params)
    }
}
#[cfg(feature = "password-hash")]
#[cfg_attr(docsrs, doc(cfg(feature = "password-hash")))]
impl TryFrom<&Params> for ParamsString {
    type Error = password_hash::Error;

    /// Encode `params` as a parameter string.
    ///
    /// - `m`, `t` and `p` are always written, as decimals.
    /// - `keyid` and `data` are written as Base64 only when non-empty.
    /// - The output length is not encoded.
    fn try_from(params: &Params) -> password_hash::Result<ParamsString> {
        let mut encoded = ParamsString::new();

        let costs = [("m", params.m_cost), ("t", params.t_cost), ("p", params.p_cost)];
        for (name, cost) in costs {
            encoded.add_decimal(name, cost)?;
        }

        let optional = [("keyid", params.keyid.as_bytes()), ("data", params.data.as_bytes())];
        for (name, bytes) in optional {
            if !bytes.is_empty() {
                encoded.add_b64_bytes(name, bytes)?;
            }
        }

        Ok(encoded)
    }
}
/// Builder for [`Params`]; every value is range-checked in `ParamsBuilder::build`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ParamsBuilder {
    /// Memory cost.
    m_cost: u32,
    /// Time cost.
    t_cost: u32,
    /// Parallelism cost.
    p_cost: u32,
    /// Key identifier; `None` builds with an empty (default) `KeyId`.
    keyid: Option<KeyId>,
    /// Associated data; `None` builds with empty (default) `AssociatedData`.
    data: Option<AssociatedData>,
    /// Requested output length; `None` leaves it unspecified.
    output_len: Option<usize>,
}
impl ParamsBuilder {
    /// Create a builder initialized from `Params::default()`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Set the memory cost (validated by `build`).
    pub fn m_cost(&mut self, m_cost: u32) -> &mut Self {
        self.m_cost = m_cost;
        self
    }

    /// Set the time cost (validated by `build`).
    pub fn t_cost(&mut self, t_cost: u32) -> &mut Self {
        self.t_cost = t_cost;
        self
    }

    /// Set the parallelism cost (validated by `build`).
    pub fn p_cost(&mut self, p_cost: u32) -> &mut Self {
        self.p_cost = p_cost;
        self
    }

    /// Set the key identifier.
    pub fn keyid(&mut self, keyid: KeyId) -> &mut Self {
        self.keyid = Some(keyid);
        self
    }

    /// Set the associated data.
    pub fn data(&mut self, data: AssociatedData) -> &mut Self {
        self.data = Some(data);
        self
    }

    /// Request an explicit output length (validated by `build`).
    pub fn output_len(&mut self, len: usize) -> &mut Self {
        self.output_len = Some(len);
        self
    }

    /// Validate the configured values and produce [`Params`].
    ///
    /// # Errors
    ///
    /// - `Error::MemoryTooLittle`: `m_cost` is below `Params::MIN_M_COST`, or below 8 blocks
    ///   per lane (`8 * p_cost`).
    /// - `Error::MemoryTooMuch`: `m_cost` exceeds `Params::MAX_M_COST`.
    /// - `Error::TimeTooSmall`: `t_cost` is below `Params::MIN_T_COST`.
    /// - `Error::ThreadsTooFew` / `Error::ThreadsTooMany`: `p_cost` is outside
    ///   `Params::MIN_P_COST..=Params::MAX_P_COST`.
    /// - `Error::OutputTooShort` / `Error::OutputTooLong`: a requested output length is
    ///   outside `Params::MIN_OUTPUT_LEN..=Params::MAX_OUTPUT_LEN`.
    pub fn build(&self) -> Result<Params> {
        if self.m_cost < Params::MIN_M_COST {
            return Err(Error::MemoryTooLittle);
        }
        if self.m_cost > Params::MAX_M_COST {
            return Err(Error::MemoryTooMuch);
        }
        // At least 8 blocks per lane. `p_cost` has not been range-checked yet, so widen to
        // `u64`: `p_cost * 8` in `u32` overflows for large `p_cost`. That panics in debug
        // builds and wraps to a bogus bound in release.
        if u64::from(self.m_cost) < u64::from(self.p_cost) * 8 {
            return Err(Error::MemoryTooLittle);
        }
        if self.t_cost < Params::MIN_T_COST {
            return Err(Error::TimeTooSmall);
        }
        if self.p_cost < Params::MIN_P_COST {
            return Err(Error::ThreadsTooFew);
        }
        if self.p_cost > Params::MAX_P_COST {
            return Err(Error::ThreadsTooMany);
        }
        if let Some(len) = self.output_len {
            if len < Params::MIN_OUTPUT_LEN {
                return Err(Error::OutputTooShort);
            }
            if len > Params::MAX_OUTPUT_LEN {
                return Err(Error::OutputTooLong);
            }
        }

        Ok(Params {
            m_cost: self.m_cost,
            t_cost: self.t_cost,
            p_cost: self.p_cost,
            keyid: self.keyid.unwrap_or_default(),
            data: self.data.unwrap_or_default(),
            output_len: self.output_len,
        })
    }
}
impl Default for ParamsBuilder {
fn default() -> Self {
let params = Params::default();
Self {
m_cost: params.m_cost,
t_cost: params.t_cost,
p_cost: params.p_cost,
keyid: None,
data: None,
output_len: params.output_len,
}
}
}
impl TryFrom<ParamsBuilder> for Params {
type Error = Error;
fn try_from(builder: ParamsBuilder) -> Result<Params> {
builder.build()
}
}
#[cfg(all(test, feature = "alloc", feature = "password-hash"))]
mod tests {
    use super::*;

    /// Assert that `builder.build()` fails with exactly `expected`.
    #[track_caller]
    fn assert_build_err(builder: &ParamsBuilder, expected: Error) {
        assert_eq!(builder.build(), Err(expected));
    }

    #[test]
    fn params_builder_bad_values() {
        assert_build_err(
            ParamsBuilder::new().m_cost(Params::MIN_M_COST - 1),
            Error::MemoryTooLittle,
        );
        assert_build_err(
            ParamsBuilder::new().m_cost(Params::MAX_M_COST + 1),
            Error::MemoryTooMuch,
        );
        assert_build_err(
            ParamsBuilder::new().t_cost(Params::MIN_T_COST - 1),
            Error::TimeTooSmall,
        );
        assert_build_err(
            ParamsBuilder::new().p_cost(Params::MIN_P_COST - 1),
            Error::ThreadsTooFew,
        );
        assert_build_err(
            ParamsBuilder::new().m_cost(Params::DEFAULT_P_COST * 8 - 1),
            Error::MemoryTooLittle,
        );

        let too_many_lanes = Params::MAX_P_COST + 1;
        assert_build_err(
            ParamsBuilder::new()
                .m_cost(too_many_lanes * 8)
                .p_cost(too_many_lanes),
            Error::ThreadsTooMany,
        );
    }

    #[test]
    fn associated_data_too_long() {
        let oversized = [0u8; Params::MAX_DATA_LEN + 1];
        assert_eq!(AssociatedData::new(&oversized), Err(Error::AdTooLong));
    }

    #[test]
    fn keyid_too_long() {
        let oversized = [0u8; Params::MAX_KEYID_LEN + 1];
        assert_eq!(KeyId::new(&oversized), Err(Error::KeyIdTooLong));
    }
}