use crate::{Algorithm, Argon2, Error, Result, Version, SYNC_POINTS};
use base64ct::{Base64Unpadded as B64, Encoding};
use core::str::FromStr;
#[cfg(feature = "password-hash")]
use password_hash::{ParamsString, PasswordHash};
/// Argon2 parameters: memory, time and parallelism costs, plus an optional key
/// identifier, associated data and requested output length.
///
/// Values are range-checked by [`Params::new`]; key ID and associated data are
/// attached via [`ParamsBuilder`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Params {
    /// Memory size, as a number of memory blocks (1 KiB each per RFC 9106).
    m_cost: u32,
    /// Number of passes over memory.
    t_cost: u32,
    /// Degree of parallelism; this is also the number of lanes (see `lanes`).
    p_cost: u32,
    /// Key identifier (may be empty).
    keyid: KeyId,
    /// Associated data (may be empty).
    data: AssociatedData,
    /// Requested output length in bytes, if one was specified.
    output_len: Option<usize>,
}

impl Params {
    /// Default memory cost: `19 * 1024` blocks (19 MiB with 1 KiB blocks).
    pub const DEFAULT_M_COST: u32 = 19 * 1024;
    /// Minimum memory cost: two blocks per synchronization point.
    // `SYNC_POINTS` is a small compile-time constant, so narrowing it to `u32`
    // is not expected to truncate.
    #[allow(clippy::cast_possible_truncation)]
    pub const MIN_M_COST: u32 = 2 * SYNC_POINTS as u32;
    /// Maximum memory cost (`u32::MAX`), so no explicit upper check is needed.
    pub const MAX_M_COST: u32 = u32::MAX;
    /// Default number of passes.
    pub const DEFAULT_T_COST: u32 = 2;
    /// Minimum number of passes.
    pub const MIN_T_COST: u32 = 1;
    /// Maximum number of passes.
    pub const MAX_T_COST: u32 = u32::MAX;
    /// Default degree of parallelism.
    pub const DEFAULT_P_COST: u32 = 1;
    /// Minimum degree of parallelism.
    pub const MIN_P_COST: u32 = 1;
    /// Maximum degree of parallelism (`2^24 - 1`).
    pub const MAX_P_COST: u32 = 0xFFFFFF;
    /// Maximum key identifier length in bytes.
    pub const MAX_KEYID_LEN: usize = 8;
    /// Maximum associated data length in bytes.
    pub const MAX_DATA_LEN: usize = 32;
    /// Default output length in bytes.
    pub const DEFAULT_OUTPUT_LEN: usize = 32;
    /// Minimum output length in bytes.
    pub const MIN_OUTPUT_LEN: usize = 4;
    /// Maximum output length in bytes (`2^32 - 1`).
    pub const MAX_OUTPUT_LEN: usize = 0xFFFFFFFF;

    /// Default parameters: default costs, empty key ID / associated data, and
    /// no explicit output length.
    pub const DEFAULT: Self = Params {
        m_cost: Self::DEFAULT_M_COST,
        t_cost: Self::DEFAULT_T_COST,
        p_cost: Self::DEFAULT_P_COST,
        keyid: KeyId {
            bytes: [0u8; Self::MAX_KEYID_LEN],
            len: 0,
        },
        data: AssociatedData {
            bytes: [0u8; Self::MAX_DATA_LEN],
            len: 0,
        },
        output_len: None,
    };

    /// Create validated parameters with an empty key ID and associated data.
    ///
    /// # Errors
    ///
    /// Checks are applied in this order:
    /// - [`Error::MemoryTooLittle`] if `m_cost < MIN_M_COST` or `m_cost < 8 * p_cost`;
    /// - [`Error::TimeTooSmall`] if `t_cost < MIN_T_COST`;
    /// - [`Error::ThreadsTooFew`] / [`Error::ThreadsTooMany`] if `p_cost` is
    ///   outside `MIN_P_COST..=MAX_P_COST`;
    /// - [`Error::OutputTooShort`] / [`Error::OutputTooLong`] if `output_len` is
    ///   `Some` and outside `MIN_OUTPUT_LEN..=MAX_OUTPUT_LEN`.
    pub const fn new(
        m_cost: u32,
        t_cost: u32,
        p_cost: u32,
        output_len: Option<usize>,
    ) -> Result<Self> {
        if m_cost < Params::MIN_M_COST {
            return Err(Error::MemoryTooLittle);
        }
        // Each lane needs at least 8 blocks. Widen to `u64` so that an
        // out-of-range `p_cost` (not yet validated at this point) cannot
        // overflow the multiplication: in `u32` that panics in debug builds and
        // wraps in release builds, silently skipping this check.
        if (m_cost as u64) < (p_cost as u64) * 8 {
            return Err(Error::MemoryTooLittle);
        }
        if t_cost < Params::MIN_T_COST {
            return Err(Error::TimeTooSmall);
        }
        if p_cost < Params::MIN_P_COST {
            return Err(Error::ThreadsTooFew);
        }
        if p_cost > Params::MAX_P_COST {
            return Err(Error::ThreadsTooMany);
        }
        if let Some(len) = output_len {
            if len < Params::MIN_OUTPUT_LEN {
                return Err(Error::OutputTooShort);
            }
            if len > Params::MAX_OUTPUT_LEN {
                return Err(Error::OutputTooLong);
            }
        }
        Ok(Params {
            m_cost,
            t_cost,
            p_cost,
            keyid: KeyId::EMPTY,
            data: AssociatedData::EMPTY,
            output_len,
        })
    }

    /// Memory cost, in blocks.
    pub const fn m_cost(&self) -> u32 {
        self.m_cost
    }

    /// Number of passes.
    pub const fn t_cost(&self) -> u32 {
        self.t_cost
    }

    /// Degree of parallelism.
    pub const fn p_cost(&self) -> u32 {
        self.p_cost
    }

    /// Key identifier bytes (empty if none was set).
    pub fn keyid(&self) -> &[u8] {
        self.keyid.as_bytes()
    }

    /// Associated data bytes (empty if none was set).
    pub fn data(&self) -> &[u8] {
        self.data.as_bytes()
    }

    /// Requested output length in bytes, if one was specified.
    pub const fn output_len(&self) -> Option<usize> {
        self.output_len
    }

    /// Number of lanes, equal to `p_cost`.
    // `u32 -> usize` can only truncate on targets where `usize` is narrower than 32 bits.
    #[allow(clippy::cast_possible_truncation)]
    pub(crate) const fn lanes(&self) -> usize {
        self.p_cost as usize
    }

    /// Number of blocks in one lane: `segment_length() * SYNC_POINTS`.
    pub(crate) const fn lane_length(&self) -> usize {
        self.segment_length() * SYNC_POINTS
    }

    /// Number of blocks in one segment.
    ///
    /// The memory size is first raised to at least `2 * SYNC_POINTS` blocks
    /// per lane, then divided (rounding down) by `lanes() * SYNC_POINTS`.
    pub(crate) const fn segment_length(&self) -> usize {
        let m_cost = self.m_cost as usize;
        let memory_blocks = if m_cost < 2 * SYNC_POINTS * self.lanes() {
            2 * SYNC_POINTS * self.lanes()
        } else {
            m_cost
        };
        memory_blocks / (self.lanes() * SYNC_POINTS)
    }

    /// Total number of memory blocks actually used.
    ///
    /// Because `segment_length` rounds down, this can be smaller than `m_cost`.
    pub const fn block_count(&self) -> usize {
        self.segment_length() * self.lanes() * SYNC_POINTS
    }
}

impl Default for Params {
    /// Returns [`Params::DEFAULT`].
    fn default() -> Params {
        Params::DEFAULT
    }
}
/// Generate a fixed-capacity byte buffer type with length tracking.
///
/// Arguments, in order: the type name, its display name (spliced into the
/// generated docs), the maximum length in bytes, the error returned when input
/// is longer than that, and the doc string for the type itself.
macro_rules! param_buf {
    ($ty:ident, $name:expr, $max_len:expr, $error:expr, $doc:expr) => {
        #[doc = $doc]
        #[derive(Copy, Clone, Debug, Default, Eq, Hash, PartialEq, PartialOrd, Ord)]
        pub struct $ty {
            /// Backing storage; only the first `len` bytes are in use.
            bytes: [u8; Self::MAX_LEN],
            /// Number of bytes of `bytes` currently in use.
            len: usize,
        }

        impl $ty {
            /// Maximum number of bytes this type can hold.
            pub const MAX_LEN: usize = $max_len;

            /// A zero-length value.
            pub const EMPTY: Self = Self {
                bytes: [0u8; Self::MAX_LEN],
                len: 0,
            };

            #[doc = "Create a new"]
            #[doc = $name]
            #[doc = "from a slice."]
            ///
            /// Fails if `slice` is longer than [`Self::MAX_LEN`] bytes.
            pub fn new(slice: &[u8]) -> Result<Self> {
                let used = slice.len();
                if used > Self::MAX_LEN {
                    return Err($error);
                }
                let mut storage = [0u8; Self::MAX_LEN];
                storage[..used].copy_from_slice(slice);
                Ok(Self {
                    bytes: storage,
                    len: used,
                })
            }

            #[doc = "Decode"]
            #[doc = $name]
            #[doc = " from a B64 string"]
            ///
            /// Fails if `s` cannot be decoded as unpadded Base64 into a
            /// [`Self::MAX_LEN`]-byte buffer.
            pub fn from_b64(s: &str) -> Result<Self> {
                let mut storage = [0u8; Self::MAX_LEN];
                let used = B64::decode(s, &mut storage).map(<[u8]>::len)?;
                Ok(Self {
                    bytes: storage,
                    len: used,
                })
            }

            /// Borrow the bytes currently in use.
            pub fn as_bytes(&self) -> &[u8] {
                &self.bytes[..self.len()]
            }

            /// Number of bytes currently in use.
            pub const fn len(&self) -> usize {
                self.len
            }

            /// Whether no bytes are in use.
            pub const fn is_empty(&self) -> bool {
                self.len == 0
            }
        }

        impl AsRef<[u8]> for $ty {
            fn as_ref(&self) -> &[u8] {
                self.as_bytes()
            }
        }

        impl FromStr for $ty {
            type Err = Error;

            /// Parse from unpadded Base64; same as `from_b64`.
            fn from_str(s: &str) -> Result<Self> {
                $ty::from_b64(s)
            }
        }

        impl TryFrom<&[u8]> for $ty {
            type Error = Error;

            /// Copy from a byte slice; same as `new`.
            fn try_from(bytes: &[u8]) -> Result<Self> {
                $ty::new(bytes)
            }
        }
    };
}
// `KeyId`: optional Argon2 key identifier, at most `Params::MAX_KEYID_LEN` bytes.
param_buf! {
    KeyId, "KeyId", Params::MAX_KEYID_LEN, Error::KeyIdTooLong, "Key identifier"
}

// `AssociatedData`: optional Argon2 associated data, at most `Params::MAX_DATA_LEN` bytes.
param_buf! {
    AssociatedData, "AssociatedData", Params::MAX_DATA_LEN, Error::AdTooLong, "Associated data"
}
#[cfg(feature = "password-hash")]
#[cfg_attr(docsrs, doc(cfg(feature = "password-hash")))]
impl<'a> TryFrom<&'a PasswordHash<'a>> for Params {
    type Error = password_hash::Error;

    /// Extract Argon2 parameters from a parsed PHC hash.
    ///
    /// Recognizes the `m`, `t`, `p`, `keyid` and `data` parameters; any other
    /// name yields [`password_hash::Error::ParamNameInvalid`]. If the hash has
    /// an output, its length is used as the requested output length. Final
    /// validation is done by [`ParamsBuilder::build`].
    fn try_from(hash: &'a PasswordHash<'a>) -> password_hash::Result<Self> {
        let mut builder = ParamsBuilder::new();

        for (ident, value) in hash.params.iter() {
            match ident.as_str() {
                "m" => builder.m_cost(value.decimal()?),
                "t" => builder.t_cost(value.decimal()?),
                "p" => builder.p_cost(value.decimal()?),
                "keyid" => builder.keyid(value.as_str().parse()?),
                "data" => builder.data(value.as_str().parse()?),
                _ => return Err(password_hash::Error::ParamNameInvalid),
            };
        }

        if let Some(output) = hash.hash.as_ref() {
            builder.output_len(output.len());
        }

        builder.build().map_err(Into::into)
    }
}
#[cfg(feature = "password-hash")]
#[cfg_attr(docsrs, doc(cfg(feature = "password-hash")))]
impl TryFrom<Params> for ParamsString {
    type Error = password_hash::Error;

    /// Encode owned [`Params`] by delegating to the `TryFrom<&Params>` impl.
    fn try_from(params: Params) -> password_hash::Result<ParamsString> {
        ParamsString::try_from(&params)
    }
}
#[cfg(feature = "password-hash")]
#[cfg_attr(docsrs, doc(cfg(feature = "password-hash")))]
impl TryFrom<&Params> for ParamsString {
    type Error = password_hash::Error;

    /// Encode the parameters as a PHC parameter string.
    ///
    /// Emits `m`, `t` and `p` as decimals, then `keyid` and `data` as Base64,
    /// each only when non-empty. The output length is not encoded here.
    fn try_from(params: &Params) -> password_hash::Result<ParamsString> {
        let mut encoded = ParamsString::new();

        let costs = [
            ("m", params.m_cost),
            ("t", params.t_cost),
            ("p", params.p_cost),
        ];
        for (name, cost) in costs {
            encoded.add_decimal(name, cost)?;
        }

        let optional_bytes = [
            ("keyid", params.keyid.as_bytes()),
            ("data", params.data.as_bytes()),
        ];
        for (name, bytes) in optional_bytes {
            if !bytes.is_empty() {
                encoded.add_b64_bytes(name, bytes)?;
            }
        }

        Ok(encoded)
    }
}
/// Builder for [`Params`].
///
/// Setters only record values; validation happens in [`ParamsBuilder::build`],
/// which delegates to [`Params::new`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ParamsBuilder {
    /// Memory cost, in blocks.
    m_cost: u32,
    /// Number of passes.
    t_cost: u32,
    /// Degree of parallelism.
    p_cost: u32,
    /// Key identifier to copy into the built [`Params`], if set.
    keyid: Option<KeyId>,
    /// Associated data to copy into the built [`Params`], if set.
    data: Option<AssociatedData>,
    /// Requested output length in bytes, if set.
    output_len: Option<usize>,
}

impl ParamsBuilder {
    /// Builder holding the costs and output length of [`Params::DEFAULT`],
    /// with no key ID or associated data.
    pub const DEFAULT: ParamsBuilder = ParamsBuilder {
        m_cost: Params::DEFAULT.m_cost,
        t_cost: Params::DEFAULT.t_cost,
        p_cost: Params::DEFAULT.p_cost,
        keyid: None,
        data: None,
        output_len: Params::DEFAULT.output_len,
    };

    /// Create a builder equal to [`ParamsBuilder::DEFAULT`].
    pub const fn new() -> Self {
        Self::DEFAULT
    }

    /// Set the memory cost, in blocks.
    pub fn m_cost(&mut self, m_cost: u32) -> &mut Self {
        self.m_cost = m_cost;
        self
    }

    /// Set the number of passes.
    pub fn t_cost(&mut self, t_cost: u32) -> &mut Self {
        self.t_cost = t_cost;
        self
    }

    /// Set the degree of parallelism.
    pub fn p_cost(&mut self, p_cost: u32) -> &mut Self {
        self.p_cost = p_cost;
        self
    }

    /// Set the key identifier.
    pub fn keyid(&mut self, keyid: KeyId) -> &mut Self {
        self.keyid = Some(keyid);
        self
    }

    /// Set the associated data.
    pub fn data(&mut self, data: AssociatedData) -> &mut Self {
        self.data = Some(data);
        self
    }

    /// Set the requested output length in bytes.
    pub fn output_len(&mut self, len: usize) -> &mut Self {
        self.output_len = Some(len);
        self
    }

    /// Validate the recorded values and produce [`Params`].
    ///
    /// # Errors
    ///
    /// Any error returned by [`Params::new`] for the costs and output length.
    pub const fn build(&self) -> Result<Params> {
        // `?` cannot be used in a `const fn`, hence the explicit `match`.
        match Params::new(self.m_cost, self.t_cost, self.p_cost, self.output_len) {
            Ok(mut params) => {
                if let Some(keyid) = self.keyid {
                    params.keyid = keyid;
                }
                if let Some(data) = self.data {
                    params.data = data;
                }
                Ok(params)
            }
            Err(err) => Err(err),
        }
    }

    /// Build the parameters and wrap them in an [`Argon2`] context.
    ///
    /// # Errors
    ///
    /// Any error from [`ParamsBuilder::build`].
    pub fn context(&self, algorithm: Algorithm, version: Version) -> Result<Argon2<'_>> {
        self.build()
            .map(|params| Argon2::new(algorithm, version, params))
    }
}

impl Default for ParamsBuilder {
    /// Returns [`ParamsBuilder::DEFAULT`].
    fn default() -> Self {
        Self::new()
    }
}

impl TryFrom<ParamsBuilder> for Params {
    type Error = Error;

    /// Equivalent to [`ParamsBuilder::build`].
    fn try_from(builder: ParamsBuilder) -> Result<Params> {
        ParamsBuilder::build(&builder)
    }
}
#[cfg(all(test, feature = "alloc", feature = "password-hash"))]
mod tests {
    use super::*;

    #[test]
    fn params_builder_bad_values() {
        // Below the absolute memory floor.
        let res = ParamsBuilder::new().m_cost(Params::MIN_M_COST - 1).build();
        assert_eq!(res, Err(Error::MemoryTooLittle));

        // Zero passes.
        let res = ParamsBuilder::new().t_cost(Params::MIN_T_COST - 1).build();
        assert_eq!(res, Err(Error::TimeTooSmall));

        // Zero lanes.
        let res = ParamsBuilder::new().p_cost(Params::MIN_P_COST - 1).build();
        assert_eq!(res, Err(Error::ThreadsTooFew));

        // Fewer than 8 blocks for the default lane count.
        let res = ParamsBuilder::new()
            .m_cost(Params::DEFAULT_P_COST * 8 - 1)
            .build();
        assert_eq!(res, Err(Error::MemoryTooLittle));

        // One lane past the maximum, with enough memory to pass the per-lane check.
        let too_many = Params::MAX_P_COST + 1;
        let res = ParamsBuilder::new()
            .m_cost(too_many * 8)
            .p_cost(too_many)
            .build();
        assert_eq!(res, Err(Error::ThreadsTooMany));
    }

    #[test]
    fn associated_data_too_long() {
        let oversized = [0u8; Params::MAX_DATA_LEN + 1];
        assert_eq!(AssociatedData::new(&oversized), Err(Error::AdTooLong));
    }

    #[test]
    fn keyid_too_long() {
        let oversized = [0u8; Params::MAX_KEYID_LEN + 1];
        assert_eq!(KeyId::new(&oversized), Err(Error::KeyIdTooLong));
    }
}