extern crate alloc;
use alloc::format;
use alloc::string::String;
use alloc::vec;
use alloc::vec::Vec;
use redoubt_util::hex_to_bytes;
use redoubt_hkdf_core::HkdfApi;
/// Labels attached to a Wycheproof HKDF test case (see `TestCase::flags`).
///
/// NOTE(review): variant names presumably mirror the flag strings used in the
/// Wycheproof vector file; confirm against the generated vectors module.
// Not every variant is necessarily referenced by the generated vectors, so
// silence the unused-variant lint for the whole enum.
#[allow(dead_code)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Flag {
    Normal,
    EmptySalt,
    MaximalOutputSize,
    SizeTooLarge,
    OutputCollision,
}
/// Expected outcome of a Wycheproof test case.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TestResult {
    /// Derivation is expected to succeed with the recorded output.
    Valid,
    /// Derivation is expected to be rejected by the backend.
    Invalid,
    // No vector is guaranteed to carry this result, hence the lint allowance.
    #[allow(dead_code)]
    Acceptable,
}
/// A single HKDF test vector, with byte fields stored as hex strings.
pub struct TestCase {
    /// Identifier reported in failure messages.
    pub tc_id: usize,
    /// Free-form description reported in failure messages.
    pub comment: String,
    /// Labels attached to this case by the vector set.
    pub flags: Vec<Flag>,
    /// Input keying material, hex-encoded.
    pub ikm: String,
    /// Salt, hex-encoded (may be the empty string).
    pub salt: String,
    /// Context/application info, hex-encoded.
    pub info: String,
    /// Number of output bytes requested from the backend.
    pub size: usize,
    /// Expected output keying material, hex-encoded.
    pub okm: String,
    /// Whether the backend is expected to accept or reject this case.
    pub result: TestResult,
}
/// Runs one Wycheproof HKDF test case against `backend`.
///
/// The hex-encoded `ikm`, `salt`, `info` and `okm` fields of `tc` are decoded
/// with `hex_to_bytes`, and `backend.api_hkdf` is asked to fill an output
/// buffer of exactly `tc.size` bytes.
///
/// Expected-outcome rules:
/// - [`TestResult::Valid`]: derivation must succeed and produce exactly `tc.okm`.
/// - [`TestResult::Acceptable`]: the backend may either accept or reject the
///   input; if it accepts, the output must still equal `tc.okm`.
/// - [`TestResult::Invalid`]: derivation must fail.
///
/// # Errors
///
/// Returns a message naming `tc.tc_id` and `tc.comment` when the backend's
/// behavior does not satisfy the rule for `tc.result`.
fn run_test_case(backend: &mut impl HkdfApi, tc: &TestCase) -> Result<(), String> {
    let ikm = hex_to_bytes(&tc.ikm);
    let salt = hex_to_bytes(&tc.salt);
    let info = hex_to_bytes(&tc.info);
    let expected_okm = hex_to_bytes(&tc.okm);
    let mut out = vec![0u8; tc.size];
    let result = backend.api_hkdf(&salt, &ikm, &info, &mut out);
    match (&tc.result, &result) {
        // Whenever the backend accepts the input, its output must be correct.
        (TestResult::Valid, Ok(())) | (TestResult::Acceptable, Ok(())) => {
            if out == expected_okm {
                Ok(())
            } else {
                Err(format!(
                    "tc_id {} ({}): output mismatch\n expected: {}\n got: {}",
                    tc.tc_id,
                    tc.comment,
                    tc.okm,
                    hex::encode(&out)
                ))
            }
        }
        (TestResult::Valid, Err(e)) => Err(format!(
            "tc_id {} ({}): expected valid but got error: {:?}",
            tc.tc_id, tc.comment, e
        )),
        // Wycheproof marks a vector "acceptable" when rejecting it is a
        // legitimate implementation choice, so an error here is not a failure.
        (TestResult::Acceptable, Err(_)) => Ok(()),
        (TestResult::Invalid, Ok(())) => Err(format!(
            "tc_id {} ({}): expected invalid but derivation succeeded",
            tc.tc_id, tc.comment
        )),
        (TestResult::Invalid, Err(_)) => Ok(()),
    }
}
/// Minimal lowercase hex encoder, used to render derived output in failure messages.
mod hex {
    use super::String;

    /// Lowercase hexadecimal digits, indexed by nibble value (0..=15).
    const DIGITS: &[u8; 16] = b"0123456789abcdef";

    /// Encodes `data` as a lowercase hex string, two characters per byte.
    ///
    /// An empty slice yields an empty string.
    pub fn encode(data: &[u8]) -> String {
        // Every byte becomes exactly two ASCII digits, so size the buffer once
        // instead of allocating a temporary `String` per byte.
        let mut encoded = String::with_capacity(data.len() * 2);
        for &byte in data {
            encoded.push(char::from(DIGITS[usize::from(byte >> 4)]));
            encoded.push(char::from(DIGITS[usize::from(byte & 0x0f)]));
        }
        encoded
    }
}
pub fn run_hkdf_wycheproof_tests(backend: &mut impl HkdfApi) {
use super::hkdf_sha256_wycheproof_vectors::test_vectors;
let vectors = test_vectors();
let mut failures = Vec::new();
for tc in vectors.iter() {
if let Err(msg) = run_test_case(backend, tc) {
failures.push(msg);
}
}
if !failures.is_empty() {
panic!(
"HKDF-SHA256 Wycheproof test failures ({}/{}):\n{}",
failures.len(),
vectors.len(),
failures.join("\n")
);
}
}