#![forbid(unsafe_code)]
use oxicrypto_core::{CryptoError, SecretVec};
use crate::argon2_kdf::{argon2id_derive, Argon2Params};
use crate::balloon::balloon_sha256;
use crate::pbkdf2_kdf::pbkdf2_sha256;
use crate::scrypt_kdf::scrypt_derive;
/// Tuning parameters for the Argon2id backend ([`StretchParams::Argon2id`]).
#[derive(Debug, Clone, Copy)]
pub struct Argon2idStretchParams {
    /// Argon2 cost settings, forwarded unchanged to `argon2id_derive`.
    pub params: Argon2Params,
    /// Number of key bytes to derive; `Stretcher::stretch` rejects `0` with
    /// `CryptoError::BadInput`.
    pub out_len: usize,
}
/// Tuning parameters for the scrypt backend ([`StretchParams::Scrypt`]).
///
/// Plain data, so it is `Copy` and can be compared and hashed (e.g. to
/// deduplicate or look up stored configurations).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ScryptStretchParams {
    /// Cost exponent forwarded to `scrypt_derive` as `log_n`
    /// (conventionally N = 2^log_n — confirm against `scrypt_derive`).
    pub log_n: u8,
    /// scrypt `r` parameter, forwarded to `scrypt_derive`.
    pub r: u32,
    /// scrypt `p` parameter, forwarded to `scrypt_derive`.
    pub p: u32,
    /// Number of key bytes to derive; `Stretcher::stretch` rejects `0` with
    /// `CryptoError::BadInput`.
    pub out_len: usize,
}
/// Tuning parameters for the PBKDF2-SHA256 backend
/// ([`StretchParams::Pbkdf2Sha256`]).
///
/// Plain data, so it is `Copy` and can be compared and hashed.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Pbkdf2StretchParams {
    /// Iteration count forwarded to `pbkdf2_sha256`.
    pub iterations: u32,
    /// Number of key bytes to derive; `Stretcher::stretch` rejects `0` with
    /// `CryptoError::BadInput`.
    pub out_len: usize,
}
/// Tuning parameters for the balloon-hashing backend
/// ([`StretchParams::BalloonSha256`]).
///
/// There is no length field: this backend always produces a 32-byte key.
/// Plain data, so it is `Copy` and can be compared and hashed.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct BalloonStretchParams {
    /// Space cost forwarded to `balloon_sha256`.
    pub space_cost: u64,
    /// Time cost forwarded to `balloon_sha256`.
    pub time_cost: u64,
}
/// Selects a key-stretching backend together with its tuning parameters.
///
/// Consumed by [`Stretcher`]. [`StretchParams::output_len`] reports how many
/// key bytes each variant produces.
#[derive(Debug, Clone, Copy)]
pub enum StretchParams {
    /// Argon2id, computed by `argon2id_derive`.
    Argon2id(Argon2idStretchParams),
    /// scrypt, computed by `scrypt_derive`.
    Scrypt(ScryptStretchParams),
    /// PBKDF2 over SHA-256, computed by `pbkdf2_sha256`.
    Pbkdf2Sha256(Pbkdf2StretchParams),
    /// Balloon hashing over SHA-256, computed by `balloon_sha256`; the output
    /// is always 32 bytes.
    BalloonSha256(BalloonStretchParams),
}
impl StretchParams {
#[must_use]
pub fn output_len(&self) -> usize {
match self {
StretchParams::Argon2id(p) => p.out_len,
StretchParams::Scrypt(p) => p.out_len,
StretchParams::Pbkdf2Sha256(p) => p.out_len,
StretchParams::BalloonSha256(_) => 32,
}
}
#[must_use]
pub fn name(&self) -> &'static str {
match self {
StretchParams::Argon2id(_) => "argon2id",
StretchParams::Scrypt(_) => "scrypt",
StretchParams::Pbkdf2Sha256(_) => "pbkdf2-sha256",
StretchParams::BalloonSha256(_) => "balloon-sha256",
}
}
}
/// Password-based key stretching behind a uniform, object-safe interface.
///
/// Implementors turn a password plus a salt into key material wrapped in a
/// [`SecretVec`], and report the key size and algorithm name up front.
pub trait KeyStretcher {
    /// Derives key material from `password` and `salt`.
    ///
    /// # Errors
    ///
    /// Returns a [`CryptoError`] when the configuration or inputs are rejected
    /// or the underlying derivation fails.
    fn stretch(&self, password: &[u8], salt: &[u8]) -> Result<SecretVec, CryptoError>;

    /// Length, in bytes, of the key a successful [`stretch`](Self::stretch)
    /// call is expected to return.
    fn output_len(&self) -> usize;

    /// Short identifier of the underlying algorithm (e.g. `"argon2id"`).
    fn name(&self) -> &'static str;
}
/// The concrete [`KeyStretcher`]: it dispatches each call to the backend
/// selected by its [`StretchParams`].
#[derive(Debug, Clone, Copy)]
pub struct Stretcher {
    /// Backend choice plus that backend's tuning parameters.
    params: StretchParams,
}
impl Stretcher {
    /// Creates a stretcher that uses the backend described by `params`.
    ///
    /// This is a `const fn` so a stretcher can be built in `const` contexts.
    #[must_use]
    pub const fn new(params: StretchParams) -> Self {
        Self { params }
    }

    /// Returns the backend configuration this stretcher was created with.
    #[must_use]
    pub const fn params(&self) -> &StretchParams {
        &self.params
    }
}
impl KeyStretcher for Stretcher {
fn stretch(&self, password: &[u8], salt: &[u8]) -> Result<SecretVec, CryptoError> {
match self.params {
StretchParams::Argon2id(p) => {
if p.out_len == 0 {
return Err(CryptoError::BadInput);
}
let mut out = vec![0u8; p.out_len];
argon2id_derive(password, salt, p.params, &mut out)?;
Ok(SecretVec::new(out))
}
StretchParams::Scrypt(p) => {
if p.out_len == 0 {
return Err(CryptoError::BadInput);
}
let mut out = vec![0u8; p.out_len];
scrypt_derive(password, salt, p.log_n, p.r, p.p, &mut out)?;
Ok(SecretVec::new(out))
}
StretchParams::Pbkdf2Sha256(p) => {
if p.out_len == 0 {
return Err(CryptoError::BadInput);
}
let mut out = vec![0u8; p.out_len];
pbkdf2_sha256(password, salt, p.iterations, &mut out)?;
Ok(SecretVec::new(out))
}
StretchParams::BalloonSha256(p) => {
let mut out = vec![0u8; 32];
balloon_sha256(password, salt, p.space_cost, p.time_cost, &mut out)?;
Ok(SecretVec::new(out))
}
}
}
fn output_len(&self) -> usize {
self.params.output_len()
}
fn name(&self) -> &'static str {
self.params.name()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    const SALT16: &[u8] = b"0123456789abcdef";

    /// Iteration count used for every PBKDF2 configuration in these tests.
    const PBKDF2_ITERS: u32 = 1000;

    /// Argon2id with the cheap test parameters and the given output length.
    fn argon2id(out_len: usize) -> StretchParams {
        StretchParams::Argon2id(Argon2idStretchParams {
            params: Argon2Params::TEST_PARAMS,
            out_len,
        })
    }

    /// Small-cost scrypt with the given output length.
    fn scrypt(out_len: usize) -> StretchParams {
        StretchParams::Scrypt(ScryptStretchParams {
            log_n: 4,
            r: 8,
            p: 1,
            out_len,
        })
    }

    /// PBKDF2-SHA256 with [`PBKDF2_ITERS`] iterations and the given length.
    fn pbkdf2(out_len: usize) -> StretchParams {
        StretchParams::Pbkdf2Sha256(Pbkdf2StretchParams {
            iterations: PBKDF2_ITERS,
            out_len,
        })
    }

    /// Small-cost balloon hashing (fixed 32-byte output).
    fn balloon() -> StretchParams {
        StretchParams::BalloonSha256(BalloonStretchParams {
            space_cost: 8,
            time_cost: 3,
        })
    }

    /// Checks the contract every backend must meet:
    /// - the advertised length matches the real one;
    /// - derivation is deterministic;
    /// - the key is not all zeros;
    /// - changing the salt changes the key.
    fn check_backend(params: StretchParams, expect_len: usize) {
        let stretcher = Stretcher::new(params);
        assert_eq!(stretcher.output_len(), expect_len);
        let k1 = stretcher.stretch(b"password", SALT16).expect("stretch 1");
        let k2 = stretcher.stretch(b"password", SALT16).expect("stretch 2");
        assert_eq!(k1.as_bytes(), k2.as_bytes(), "{}", stretcher.name());
        assert_eq!(k1.len(), expect_len);
        assert_ne!(k1.as_bytes(), vec![0u8; expect_len].as_slice());
        let k3 = stretcher
            .stretch(b"password", b"fedcba9876543210")
            .expect("stretch 3");
        assert_ne!(
            k1.as_bytes(),
            k3.as_bytes(),
            "{} salt sensitivity",
            stretcher.name()
        );
    }

    #[test]
    fn argon2id_backend() {
        check_backend(argon2id(32), 32);
    }

    #[test]
    fn scrypt_backend() {
        check_backend(scrypt(32), 32);
    }

    #[test]
    fn pbkdf2_backend() {
        check_backend(pbkdf2(48), 48);
    }

    #[test]
    fn balloon_backend() {
        check_backend(balloon(), 32);
    }

    #[test]
    fn trait_object_dispatch() {
        let backends: Vec<Box<dyn KeyStretcher>> = vec![
            Box::new(Stretcher::new(argon2id(32))),
            Box::new(Stretcher::new(scrypt(32))),
            Box::new(Stretcher::new(pbkdf2(32))),
            Box::new(Stretcher::new(balloon())),
        ];
        // Derive each key once. Check its length through the trait object,
        // then keep the bytes for the pairwise-distinctness check below.
        let outs: Vec<Vec<u8>> = backends
            .iter()
            .map(|b| {
                let key = b.stretch(b"password", SALT16).expect("dispatch stretch");
                assert_eq!(key.len(), b.output_len(), "{}", b.name());
                key.as_bytes().to_vec()
            })
            .collect();
        for (i, lhs) in outs.iter().enumerate() {
            for (j, rhs) in outs.iter().enumerate().skip(i + 1) {
                assert_ne!(lhs, rhs, "backends {i} and {j} collided");
            }
        }
    }

    #[test]
    fn matches_standalone_pbkdf2() {
        let stretcher = Stretcher::new(pbkdf2(32));
        let via_trait = stretcher.stretch(b"password", b"salt").expect("trait");
        let mut direct = [0u8; 32];
        pbkdf2_sha256(b"password", b"salt", PBKDF2_ITERS, &mut direct).expect("direct");
        assert_eq!(via_trait.as_bytes(), &direct[..]);
    }

    #[test]
    fn zero_output_len_rejected() {
        // Every backend with a configurable length must refuse an empty key.
        for params in [argon2id(0), scrypt(0), pbkdf2(0)] {
            let stretcher = Stretcher::new(params);
            assert_eq!(
                stretcher.stretch(b"pw", SALT16).err(),
                Some(CryptoError::BadInput),
                "{}",
                stretcher.name()
            );
        }
    }
}