#![forbid(unsafe_code)]
use oxicrypto_core::{CryptoError, PasswordHash as PasswordHashTrait, PasswordHashParams};
use scrypt::{scrypt, Params as RcScryptParams};
/// scrypt cost parameters: `N = 2^log_n`, block size `r`, parallelism `p`.
///
/// The fields are public, so a value built with a struct literal is *not*
/// validated; use [`ScryptParams::new`] to have the combination checked
/// against the `scrypt` crate's limits.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ScryptParams {
    /// Base-2 logarithm of the CPU/memory cost parameter `N`.
    pub log_n: u8,
    /// Block size parameter `r`.
    pub r: u32,
    /// Parallelization parameter `p`.
    pub p: u32,
}
impl ScryptParams {
    /// Builds a parameter set after checking it with `scrypt::Params::new`.
    ///
    /// # Errors
    ///
    /// Returns [`CryptoError::BadInput`] when the `scrypt` crate rejects the
    /// `(log_n, r, p)` combination.
    #[must_use = "ScryptParams creation result must be checked"]
    pub fn new(log_n: u8, r: u32, p: u32) -> Result<Self, CryptoError> {
        match RcScryptParams::new(log_n, r, p) {
            Ok(_) => Ok(Self { log_n, r, p }),
            Err(_) => Err(CryptoError::BadInput),
        }
    }

    /// Lightest preset: `log_n = 15`, `r = 8`, `p = 1`.
    #[must_use]
    pub fn interactive() -> Self {
        Self::preset(15)
    }

    /// Middle preset: `log_n = 17`, `r = 8`, `p = 1`.
    #[must_use]
    pub fn moderate() -> Self {
        Self::preset(17)
    }

    /// Heaviest preset: `log_n = 20`, `r = 8`, `p = 1`.
    #[must_use]
    pub fn sensitive() -> Self {
        Self::preset(20)
    }

    /// Shared shape of every preset: only `log_n` varies, with `r = 8` and `p = 1`.
    const fn preset(log_n: u8) -> Self {
        Self { log_n, r: 8, p: 1 }
    }
}
impl PasswordHashParams for ScryptParams {
    /// Approximate memory cost in KiB, computed as `128 * N * r` bytes with `N = 2^log_n`
    /// (the size of scrypt's ROMix working array per RFC 7914).
    ///
    /// Returns `None` when the figure is not representable: `log_n >= 64` (so `N`
    /// does not fit in a `u64`), the byte count overflows `u64`, or the KiB value
    /// exceeds `u32::MAX`.
    fn memory_cost(&self) -> Option<u32> {
        // A plain `1u64 << log_n` panics in debug builds for `log_n >= 64` and masks
        // the shift amount in release builds; `checked_shl` reports it as `None`.
        let n = 1u64.checked_shl(u32::from(self.log_n))?;
        let bytes = 128u64.checked_mul(n)?.checked_mul(u64::from(self.r))?;
        u32::try_from(bytes / 1024).ok()
    }

    /// scrypt has no separate iteration-count parameter, so this is always `None`.
    fn time_cost(&self) -> Option<u32> {
        None
    }

    /// The parallelization parameter `p`.
    fn parallelism(&self) -> Option<u32> {
        Some(self.p)
    }
}
/// Derives key material from `password` and `salt` with scrypt, filling all of `out`.
///
/// Cost parameters: `N = 2^log_n`, block size `r`, parallelism `p`.
///
/// # Errors
///
/// * [`CryptoError::BadInput`] if `out` is empty, or if `scrypt::Params::new`
///   rejects the `(log_n, r, p)` combination.
/// * [`CryptoError::Internal`] if the underlying `scrypt` call itself fails.
#[must_use = "scrypt derive result must be checked"]
pub fn scrypt_derive(
    password: &[u8],
    salt: &[u8],
    log_n: u8,
    r: u32,
    p: u32,
    out: &mut [u8],
) -> Result<(), CryptoError> {
    // Reject a zero-length output up front so callers get `BadInput` rather than
    // depending on how the backend treats an empty buffer.
    if out.is_empty() {
        return Err(CryptoError::BadInput);
    }
    let params = RcScryptParams::new(log_n, r, p).map_err(|_| CryptoError::BadInput)?;
    scrypt(password, salt, &params, out).map_err(|_| CryptoError::Internal("scrypt failed"))
}
/// scrypt implementation of the crate's password-hashing trait.
///
/// Hashing always uses the cost parameters stored in `params`; the `_params`
/// argument passed to `hash_password` is ignored. The field is public, so a
/// hasher built with a struct literal bypasses the validation performed by
/// `ScryptHasher::new`.
#[derive(Debug, Clone, Copy)]
pub struct ScryptHasher {
    /// Cost parameters applied by every call to `hash_password`.
    pub params: ScryptParams,
}
impl ScryptHasher {
pub fn new(params: ScryptParams) -> Result<Self, CryptoError> {
RcScryptParams::new(params.log_n, params.r, params.p).map_err(|_| CryptoError::BadInput)?;
Ok(Self { params })
}
pub fn new_checked(params: ScryptParams) -> Self {
Self::new(params).expect("invalid ScryptParams")
}
#[must_use]
pub fn interactive() -> Self {
Self {
params: ScryptParams::interactive(),
}
}
#[must_use]
pub fn moderate() -> Self {
Self {
params: ScryptParams::moderate(),
}
}
#[must_use]
pub fn sensitive() -> Self {
Self {
params: ScryptParams::sensitive(),
}
}
}
impl PasswordHashTrait for ScryptHasher {
fn name(&self) -> &'static str {
"scrypt"
}
fn hash_password(
&self,
password: &[u8],
salt: &[u8],
_params: &dyn PasswordHashParams,
out: &mut [u8],
) -> Result<(), CryptoError> {
scrypt_derive(
password,
salt,
self.params.log_n,
self.params.r,
self.params.p,
out,
)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Salt shared by every test; exactly 16 bytes long.
    const SALT: &[u8] = b"test-salt-16byte";

    /// Deliberately tiny cost parameters (`N = 2`, `r = 1`, `p = 1`) so tests run fast.
    fn test_params() -> ScryptParams {
        ScryptParams { log_n: 1, r: 1, p: 1 }
    }

    /// Hasher built from [`test_params`]; construction is expected to succeed.
    fn test_hasher() -> ScryptHasher {
        ScryptHasher::new(test_params()).expect("test params valid")
    }

    #[test]
    fn scrypt_derive_deterministic() {
        let ScryptParams { log_n, r, p } = test_params();
        let derive_once = |label: &str| {
            let mut key = [0u8; 32];
            scrypt_derive(b"password", SALT, log_n, r, p, &mut key).expect(label);
            key
        };
        let first = derive_once("derive 1");
        let second = derive_once("derive 2");
        assert_eq!(first, second, "scrypt must be deterministic");
        assert_ne!(first, [0u8; 32]);
    }

    #[test]
    fn scrypt_derive_empty_output_errors() {
        let ScryptParams { log_n, r, p } = test_params();
        let outcome = scrypt_derive(b"password", SALT, log_n, r, p, &mut []);
        assert_eq!(outcome, Err(CryptoError::BadInput));
    }

    #[test]
    fn password_hash_trait_deterministic() {
        let hasher = test_hasher();
        let hash_once = |label: &str| {
            let mut digest = [0u8; 32];
            hasher
                .hash_password(b"password", SALT, &hasher.params, &mut digest)
                .expect(label);
            digest
        };
        let first = hash_once("hash 1");
        let second = hash_once("hash 2");
        assert_eq!(first, second);
        assert_ne!(first, [0u8; 32]);
    }

    #[test]
    fn preset_cost_ordering() {
        // Presets listed from lightest to heaviest; each step must strictly increase.
        let ladder = [
            ScryptParams::interactive(),
            ScryptParams::moderate(),
            ScryptParams::sensitive(),
        ];
        for pair in ladder.windows(2) {
            let (lighter, heavier) = (&pair[0], &pair[1]);
            assert!(heavier.log_n > lighter.log_n);
            assert!(heavier.memory_cost() > lighter.memory_cost());
        }
    }

    #[test]
    fn scrypt_params_password_hash_params_impl() {
        let params = ScryptParams::interactive();
        assert!(params.memory_cost().is_some());
        assert_eq!(params.time_cost(), None);
        assert_eq!(params.parallelism(), Some(1));
    }

    #[test]
    fn hasher_name() {
        let hasher = test_hasher();
        assert_eq!(hasher.name(), "scrypt");
    }
}