use super::*;
use hex;
use std::path::{Path, PathBuf};
/// Returns the directory containing the legacy NIST `.rsp` SHAKE vector files.
///
/// The path is resolved relative to this crate's manifest directory: two levels
/// up, then `tests/src/vectors/legacy_rsp/shake`. The `..` components are kept
/// as-is (the path is not canonicalized).
fn vectors_dir() -> PathBuf {
    const RELATIVE: [&str; 7] = ["..", "..", "tests", "src", "vectors", "legacy_rsp", "shake"];
    let manifest_dir = Path::new(env!("CARGO_MANIFEST_DIR"));
    RELATIVE
        .iter()
        .fold(manifest_dir.to_path_buf(), |dir, part| dir.join(part))
}
/// Known-answer test: SHAKE128 of the empty message, compared as lowercase hex.
#[test]
fn test_shake128_empty() {
    let digest = Shake128::digest(&[]).unwrap();
    let actual = hex::encode(&digest);
    assert_eq!(
        actual,
        "7f9c2ba4e88f827d616045507605853ed73b8093f6efbc88eb1a6eacfa66ef26"
    );
}
/// Known-answer test: SHAKE128 of the ASCII message `"abc"`, compared as lowercase hex.
#[test]
fn test_shake128_abc() {
    let digest = Shake128::digest(b"abc").unwrap();
    let actual = hex::encode(&digest);
    assert_eq!(
        actual,
        "5881092dd818bf5cf8a3ddb793fbcba74097d5c526a6d35f97b83351940f2cc8"
    );
}
/// Known-answer test: SHAKE256 of the empty message, compared as lowercase hex.
#[test]
fn test_shake256_empty() {
    let digest = Shake256::digest(&[]).unwrap();
    let actual = hex::encode(&digest);
    assert_eq!(
        actual,
        "46b9dd2b0ba88d13233b3feb743eeb243fcd52ea62b81b82b50c27646ed5762fd75dc4ddd8c0f200cb05019d67b592f6fc821c49479ab48640292eacb3b7c4be"
    );
}
/// Runs every `Len`/`Msg`/`Output` vector in a NIST CAVP SHAKE `.rsp` file
/// against the hash `H`, reporting pass and skip counts on stdout.
///
/// Two kinds of vector are skipped and counted:
/// * vectors whose expected output length differs from `H::output_size()`;
///   these are tallied by output bit-length;
/// * vectors whose message length (`Len`, in bits) is not a multiple of 8.
///   `H::digest` only accepts whole bytes, so such messages cannot be hashed
///   through it.
///
/// # Panics
///
/// Panics in any of these cases:
/// * the file cannot be opened or read;
/// * a `Len`/`OutLen` value or a hex field is malformed;
/// * a `Msg` is shorter than its `Len`;
/// * `H::digest` returns an error;
/// * a computed digest differs from the expected `Output`.
fn run_shake_tests<H: HashFunction>(filepath: &str, name: &str)
where
    H::Output: AsRef<[u8]> + std::fmt::Debug,
{
    use std::collections::BTreeMap;
    use std::fs::File;
    use std::io::{BufRead, BufReader};

    /// One record from a `.rsp` file.
    struct ShakeTestVector {
        /// Message length in bits (`Len = ...`).
        len: usize,
        /// Hex-encoded message (`Msg = ...`).
        msg: String,
        /// Output length in bits. Taken from `OutLen = ...` if present,
        /// otherwise derived from the hex length of `Output`.
        output_len: usize,
        /// Hex-encoded expected output (`Output = ...`).
        output: String,
    }

    /// Parses all vectors from `reader`. Each `Len = ` line starts a new vector.
    /// Blank lines and `#` comments are ignored.
    fn parse_vectors(reader: impl BufRead, filepath: &str) -> Vec<ShakeTestVector> {
        let mut vectors = Vec::new();
        let mut current: Option<ShakeTestVector> = None;
        for line in reader.lines() {
            // A read error must not silently truncate the vector list.
            let line = line.unwrap_or_else(|e| panic!("Failed to read {}: {}", filepath, e));
            let line = line.trim();
            if line.is_empty() || line.starts_with('#') {
                continue;
            }
            if let Some(len_str) = line.strip_prefix("Len = ") {
                vectors.extend(current.take());
                let len = len_str
                    .parse::<usize>()
                    .unwrap_or_else(|e| panic!("{}: invalid Len {:?}: {}", filepath, len_str, e));
                current = Some(ShakeTestVector {
                    len,
                    msg: String::new(),
                    output_len: 0,
                    output: String::new(),
                });
            } else if let Some(outlen_str) = line.strip_prefix("OutLen = ") {
                if let Some(vector) = current.as_mut() {
                    vector.output_len = outlen_str.parse::<usize>().unwrap_or_else(|e| {
                        panic!("{}: invalid OutLen {:?}: {}", filepath, outlen_str, e)
                    });
                }
            } else if let Some(vector) = current.as_mut() {
                if let Some(msg) = line.strip_prefix("Msg = ") {
                    vector.msg = msg.to_string();
                } else if let Some(output) = line.strip_prefix("Output = ") {
                    vector.output = output.to_string();
                    if vector.output_len == 0 && !vector.output.is_empty() {
                        // Each hex digit encodes 4 bits.
                        vector.output_len = vector.output.len() * 4;
                    }
                }
            }
        }
        vectors.extend(current);
        vectors
    }

    // A missing fixture must fail the test rather than let it pass vacuously.
    let file = File::open(filepath)
        .unwrap_or_else(|e| panic!("Test vector file not found: {} ({})", filepath, e));
    let test_vectors = parse_vectors(BufReader::new(file), filepath);
    println!("Found {} test vectors in {}", test_vectors.len(), filepath);

    let mut tested = 0;
    let mut skipped = 0;
    let mut skipped_bit_level = 0;
    let mut skipped_output_sizes: BTreeMap<usize, usize> = BTreeMap::new();

    for (i, test) in test_vectors.iter().enumerate() {
        let expected_output_bytes = if test.output_len == 0 {
            test.output.len() / 2
        } else {
            test.output_len / 8
        };
        if expected_output_bytes != H::output_size() {
            skipped += 1;
            *skipped_output_sizes.entry(test.output_len).or_insert(0) += 1;
            continue;
        }
        // `H::digest` hashes whole bytes only. A message ending in a partial byte
        // cannot be represented, and padding it to a full byte hashes a different
        // message, so skip the vector instead of reporting a bogus failure.
        if test.len % 8 != 0 {
            skipped += 1;
            skipped_bit_level += 1;
            continue;
        }

        let msg = hex::decode(&test.msg)
            .unwrap_or_else(|e| panic!("{} test case {}: invalid Msg hex: {}", name, i, e));
        // Keep exactly Len/8 bytes. For `Len = 0`, NIST files carry `Msg = 00`,
        // which this slicing correctly reduces to the empty message.
        let msg = msg.get(..test.len / 8).unwrap_or_else(|| {
            panic!(
                "{} test case {}: Msg is shorter than Len = {} bits",
                name, i, test.len
            )
        });
        let expected = hex::decode(&test.output)
            .unwrap_or_else(|e| panic!("{} test case {}: invalid Output hex: {}", name, i, e));

        let hash = H::digest(msg).unwrap();
        assert_eq!(
            hash.as_ref(),
            expected.as_slice(),
            "{} test case {} failed.",
            name,
            i
        );
        tested += 1;
    }

    println!("{} tests: {} passed, {} skipped", name, tested, skipped);
    if !skipped_output_sizes.is_empty() {
        println!("Skipped test vectors by output size:");
        // BTreeMap iterates in ascending key order, so no explicit sort is needed.
        for (output_len, count) in &skipped_output_sizes {
            println!(
                " - {} test vectors with {} bits output (expected {} bytes)",
                count,
                output_len,
                H::output_size()
            );
        }
    }
    if skipped_bit_level > 0 {
        println!(
            "Skipped {} test vectors whose message length is not a whole number of bytes",
            skipped_bit_level
        );
    }
}
/// Runs the NIST short-message SHAKE vector files (`SHAKE128ShortMsg.rsp` and
/// `SHAKE256ShortMsg.rsp`) through `run_shake_tests`.
///
/// Both files are checked for existence up front, so a missing fixture fails
/// the test instead of being skipped.
#[test]
fn test_shake_nist_short_vectors() {
    let dir = vectors_dir();
    let shake128_file = dir.join("SHAKE128ShortMsg.rsp");
    let shake256_file = dir.join("SHAKE256ShortMsg.rsp");

    [&shake128_file, &shake256_file].iter().for_each(|file| {
        assert!(
            file.exists(),
            "Test vector file not found: {}",
            file.display()
        );
    });

    run_shake_tests::<Shake128>(shake128_file.to_str().unwrap(), "SHAKE-128");
    run_shake_tests::<Shake256>(shake256_file.to_str().unwrap(), "SHAKE-256");
}
/// Runs the NIST long-message SHAKE vector files (`SHAKE128LongMsg.rsp` and
/// `SHAKE256LongMsg.rsp`) through `run_shake_tests`.
///
/// Both files are checked for existence up front, so a missing fixture fails
/// the test instead of being skipped.
#[test]
fn test_shake_nist_long_vectors() {
    let dir = vectors_dir();
    let shake128_file = dir.join("SHAKE128LongMsg.rsp");
    let shake256_file = dir.join("SHAKE256LongMsg.rsp");

    [&shake128_file, &shake256_file].iter().for_each(|file| {
        assert!(
            file.exists(),
            "Test vector file not found: {}",
            file.display()
        );
    });

    run_shake_tests::<Shake128>(shake128_file.to_str().unwrap(), "SHAKE-128");
    run_shake_tests::<Shake256>(shake256_file.to_str().unwrap(), "SHAKE-256");
}