use super::*;
use hex;
/// Asserts that SHAKE-128 reports an error when asked for zero bytes of
/// output, through both the `Vec`-returning and the slice-filling API.
#[test]
fn test_shake128_empty_output() {
    let mut shake = ShakeXof128::new();

    // Vec-returning variant with a requested length of zero.
    assert!(shake.squeeze_into_vec(0).is_err());

    // Slice-filling variant with a zero-length destination.
    let mut no_bytes = [];
    assert!(shake.squeeze(&mut no_bytes).is_err());
}
/// Asserts that `update` fails with `Error::Processing` once SHAKE-256 has
/// been finalized, both before and after output has been squeezed.
#[test]
fn test_shake256_state_errors() {
    let mut shake = ShakeXof256::new();
    shake.finalize().unwrap();

    // Absorbing after `finalize` must be refused.
    assert!(matches!(
        shake.update(b"test"),
        Err(Error::Processing { .. })
    ));

    // Squeezing is expected to succeed in the finalized state...
    let mut block = [0u8; 32];
    assert!(shake.squeeze(&mut block).is_ok());

    // ...and absorbing must still be refused afterwards.
    assert!(matches!(
        shake.update(b"test"),
        Err(Error::Processing { .. })
    ));
}
/// Asserts that `reset` returns a fully used SHAKE-128 instance to a state
/// where it accepts input again and both state flags are cleared.
#[test]
fn test_shake_reset() {
    let mut shake = ShakeXof128::new();

    // Drive the instance through absorb -> finalize -> squeeze.
    shake.update(b"test").unwrap();
    shake.finalize().unwrap();
    let mut scratch = [0u8; 32];
    shake.squeeze(&mut scratch).unwrap();

    shake.reset().unwrap();

    // After the reset, absorbing works again and the flags are back to false.
    assert!(shake.update(b"test").is_ok());
    assert!(!shake.is_finalized);
    assert!(!shake.squeezing);
}
/// Checks the first 32 output bytes of SHAKE-128 for the empty message and
/// for `"abc"` against fixed reference strings, and checks that a second
/// 32-byte squeeze differs from the first.
#[test]
fn test_shake128_xof_variable_length() {
    // Case 1: empty message.
    let mut shake = ShakeXof128::new();
    shake.update(&[]).unwrap();
    let first_block = shake.squeeze_into_vec(32).unwrap();
    assert_eq!(
        hex::encode(&first_block),
        "7f9c2ba4e88f827d616045507605853ed73b8093f6efbc88eb1a6eacfa66ef26"
    );

    // Squeezing again must continue the stream rather than repeat it.
    let second_block = shake.squeeze_into_vec(32).unwrap();
    assert_ne!(first_block, second_block);

    // Case 2: the three-byte message "abc" on a fresh instance.
    let mut shake = ShakeXof128::new();
    shake.update(b"abc").unwrap();
    let abc_block = shake.squeeze_into_vec(32).unwrap();
    assert_eq!(
        hex::encode(&abc_block),
        "5881092dd818bf5cf8a3ddb793fbcba74097d5c526a6d35f97b83351940f2cc8"
    );
}
/// Checks 64 bytes of SHAKE-256 output for the empty message against a fixed
/// reference string, once squeezed in a single call and once as two 32-byte
/// squeezes joined together.
#[test]
fn test_shake256_xof_variable_length() {
    let reference = "46b9dd2b0ba88d13233b3feb743eeb243fcd52ea62b81b82b50c27646ed5762fd75dc4ddd8c0f200cb05019d67b592f6fc821c49479ab48640292eacb3b7c4be";

    // One 64-byte squeeze.
    let mut shake = ShakeXof256::new();
    shake.update(&[]).unwrap();
    let whole = shake.squeeze_into_vec(64).unwrap();
    assert_eq!(hex::encode(&whole), reference);

    // Two consecutive 32-byte squeezes from a fresh instance.
    let mut shake = ShakeXof256::new();
    shake.update(&[]).unwrap();
    let front = shake.squeeze_into_vec(32).unwrap();
    let back = shake.squeeze_into_vec(32).unwrap();

    // Front followed by back must reproduce the same 64 bytes.
    let mut stitched = Vec::with_capacity(64);
    for half in [&front, &back] {
        stitched.extend_from_slice(half);
    }
    assert_eq!(hex::encode(&stitched), reference);
}
/// Asserts that SHAKE-256 refuses further input after an explicit `finalize`
/// and also after output has been squeezed without calling `finalize` first.
#[test]
fn test_xof_reuse_error() {
    // Path 1: explicit finalize, then try to absorb more.
    let mut finalized = ShakeXof256::new();
    finalized.update(b"test").unwrap();
    finalized.finalize().unwrap();
    assert!(finalized.update(b"more data").is_err());

    // Path 2: squeeze directly, then try to absorb more.
    let mut squeezed = ShakeXof256::new();
    squeezed.update(b"test").unwrap();
    let _ = squeezed.squeeze_into_vec(32).unwrap();
    assert!(squeezed.update(b"more data").is_err());
}
/// Asserts that a SHAKE-256 instance produces the same 32 bytes for the same
/// input before and after `reset`.
#[test]
fn test_xof_reset() {
    let mut shake = ShakeXof256::new();

    // First pass.
    shake.update(b"test").unwrap();
    let before_reset = shake.squeeze_into_vec(32).unwrap();

    // Second pass on the same instance after a reset.
    shake.reset().unwrap();
    shake.update(b"test").unwrap();
    let after_reset = shake.squeeze_into_vec(32).unwrap();

    assert_eq!(before_reset, after_reset);
}
/// Asserts, for SHAKE-128 and SHAKE-256, that squeezing 50 + 50 bytes from
/// one instance yields the same 100 bytes as a single 100-byte squeeze from a
/// second instance fed the same message.
#[test]
fn test_shake_xof_incremental_output() {
    let message = b"test data for incremental output";

    // SHAKE-128: two-step squeeze after an explicit finalize.
    let stitched_128 = {
        let mut xof = ShakeXof128::new();
        xof.update(message).unwrap();
        xof.finalize().unwrap();
        let mut head = xof.squeeze_into_vec(50).unwrap();
        let tail = xof.squeeze_into_vec(50).unwrap();
        head.extend_from_slice(&tail);
        head
    };
    // SHAKE-128: single squeeze of the full length.
    let one_shot_128 = {
        let mut xof = ShakeXof128::new();
        xof.update(message).unwrap();
        xof.squeeze_into_vec(100).unwrap()
    };
    assert_eq!(
        stitched_128, one_shot_128,
        "SHAKE-128 incremental output doesn't match combined output"
    );

    // SHAKE-256: same procedure.
    let stitched_256 = {
        let mut xof = ShakeXof256::new();
        xof.update(message).unwrap();
        xof.finalize().unwrap();
        let mut head = xof.squeeze_into_vec(50).unwrap();
        let tail = xof.squeeze_into_vec(50).unwrap();
        head.extend_from_slice(&tail);
        head
    };
    let one_shot_256 = {
        let mut xof = ShakeXof256::new();
        xof.update(message).unwrap();
        xof.squeeze_into_vec(100).unwrap()
    };
    assert_eq!(
        stitched_256, one_shot_256,
        "SHAKE-256 incremental output doesn't match combined output"
    );
}
/// Diagnostic helper: prints actual and expected SHAKE outputs for a few
/// fixed inputs, then prints parts of the internal buffer after `finalize`.
///
/// This test makes no assertions on the printed values; it can only fail if
/// one of the `unwrap` calls panics.
#[test]
fn debug_shake_implementation() {
    println!("\nDebugging SHAKE implementation:");

    let no_input: [u8; 0] = [];

    // SHAKE-128, empty message.
    let mut shake128 = ShakeXof128::new();
    shake128.update(&no_input).unwrap();
    let empty_digest_128 = shake128.squeeze_into_vec(32).unwrap();
    println!(
        "SHAKE-128 empty input (actual): {}",
        hex::encode(&empty_digest_128)
    );
    println!("SHAKE-128 empty input (expected): 7f9c2ba4e88f827d616045507605853ed73b8093f6efbc88eb1a6eacfa66ef26");

    // SHAKE-128, "abc".
    let abc = b"abc";
    let mut shake128 = ShakeXof128::new();
    shake128.update(abc).unwrap();
    let abc_digest_128 = shake128.squeeze_into_vec(32).unwrap();
    println!(
        "SHAKE-128 'abc' input (actual): {}",
        hex::encode(&abc_digest_128)
    );
    println!("SHAKE-128 'abc' input (expected): 5881092dd818bf5cf8a3ddb793fbcba74097d5c526a6d35f97b83351940f2cc8");

    // SHAKE-256, empty message.
    let mut shake256 = ShakeXof256::new();
    shake256.update(&no_input).unwrap();
    let empty_digest_256 = shake256.squeeze_into_vec(64).unwrap();
    println!(
        "\nSHAKE-256 empty input (actual): {}",
        hex::encode(&empty_digest_256)
    );
    println!("SHAKE-256 empty input (expected): 46b9dd2b0ba88d13233b3feb743eeb243fcd52ea62b81b82b50c27646ed5762fd75dc4ddd8c0f200cb05019d67b592f6fc821c49479ab48640292eacb3b7c4be");

    // Peek at the internal state around `finalize` for an empty message.
    println!("\nAnalyzing SHAKE state transitions:");
    let mut probe = ShakeXof128::new();
    probe.update(&no_input).unwrap();
    println!("Before finalize, buffer_idx: {}", probe.buffer_idx);
    probe.finalize().unwrap();

    // The buffer dump is only printed when the index reads zero afterwards.
    if probe.buffer_idx != 0 {
        return;
    }
    println!("After finalize:");
    println!(
        "  First byte (should be 0x1F): {:02x}",
        probe.buffer.as_ref()[0]
    );
    println!(
        "  Last byte (should be 0x80): {:02x}",
        probe.buffer.as_ref()[SHAKE128_RATE - 1]
    );
}
/// Runs every vector parsed from the response file at `filepath` against the
/// XOF `X`, panicking on the first output mismatch.
///
/// `name` only labels the console output and the assertion messages.
///
/// A vector is skipped (and counted as skipped) when it declares no output,
/// or when its declared output length in bits does not correspond to the
/// number of hex-encoded output bytes supplied. If the parser returns no
/// vectors, nothing is checked and the function returns normally.
///
/// The complete squeezed output is compared against the complete expected
/// output, so an output of the wrong length is reported as a failure rather
/// than being compared on a common prefix only.
///
/// # Panics
///
/// Panics if a `Msg`/`Output` field is not valid hex, if the XOF returns an
/// error, or if the squeezed bytes differ from the expected bytes.
fn run_shake_xof_tests<X: ExtendableOutputFunction>(filepath: &str, name: &str) {
    let test_vectors = parse_shake_test_file(filepath);
    println!("Found {} test vectors in {}", test_vectors.len(), filepath);

    let mut tested = 0;
    let mut skipped = 0;

    for (i, test) in test_vectors.iter().enumerate() {
        let bit_len = test.output_len;
        let output_bytes = bit_len.div_ceil(8);

        if bit_len == 0 || test.output.is_empty() {
            println!(
                "Skipping test case {}: zero-length output ({} bits)",
                i, bit_len
            );
            skipped += 1;
            continue;
        }

        // Two hex characters encode one byte.
        let expected_bytes = test.output.len() / 2;
        if output_bytes != expected_bytes {
            println!(
                "Skipping test case {}: declared {} bits → {} bytes, but hex is {} bytes",
                i, bit_len, output_bytes, expected_bytes
            );
            skipped += 1;
            continue;
        }

        // Work out the bytes to absorb plus a description for failure output.
        let (input, context) = if test.len == 0 {
            // A declared length of zero means the empty message; the `Msg`
            // field is deliberately not decoded in this case.
            (Vec::new(), format!("empty input, {} bits", bit_len))
        } else {
            let msg_bytes = hex::decode(&test.msg).unwrap();
            let whole_bytes = test.len / 8;
            let spare_bits = test.len % 8;

            if spare_bits != 0 && whole_bytes < msg_bytes.len() {
                // Keep the whole bytes, then append the partial byte with
                // only its low `spare_bits` bits retained.
                // NOTE(review): the XOF is fed whole bytes here, so such a
                // vector is hashed as a byte string ending in the masked
                // byte — confirm this is the intended treatment of
                // non-byte-aligned vectors.
                let mut truncated = msg_bytes[..whole_bytes].to_vec();
                truncated.push(msg_bytes[whole_bytes] & ((1u8 << spare_bits) - 1));
                (
                    truncated,
                    format!("{}-bit input, {} bits out", test.len, bit_len),
                )
            } else {
                let context = format!("{} bytes in → {} bits out", msg_bytes.len(), bit_len);
                (msg_bytes, context)
            }
        };

        let mut xof = X::new();
        xof.update(&input).unwrap();
        let result = xof.squeeze_into_vec(output_bytes).unwrap();
        let expected = hex::decode(&test.output).unwrap();

        // Compare the full slices: a length mismatch is itself a failure.
        assert_eq!(
            &result[..],
            &expected[..],
            "{} test case {} failed ({})",
            name,
            i,
            context
        );
        tested += 1;
    }

    println!("{} tests: {} passed, {} skipped", name, tested, skipped);
}
/// One record read from a SHAKE response (`.rsp`) file.
#[derive(Debug)]
struct ShakeTestVector {
    /// Value of the record's `Len = ` line (0 for records opened by `Count = `).
    len: usize,
    /// Hex text taken from the `Msg = ` line; empty if the line was absent.
    msg: String,
    /// Value of the `OutLen = ` line, or, when that is absent or zero, four
    /// times the number of hex characters in `output`.
    output_len: usize,
    /// Hex text taken from the `Output = ` line; empty if the line was absent.
    output: String,
}

impl ShakeTestVector {
    /// Creates a record with the given `len` and every other field empty/zero.
    fn with_len(len: usize) -> Self {
        Self {
            len,
            msg: String::new(),
            output_len: 0,
            output: String::new(),
        }
    }
}

/// Parses the response file at `filepath` into a list of test vectors.
///
/// Recognised lines (after trimming surrounding whitespace):
///
/// * blank lines and lines starting with `#` are ignored;
/// * a line of the form `[...]` marks that a test group has been seen;
/// * `Len = N` closes the record in progress and opens a new one;
/// * `OutLen = N`, `Msg = HEX` and `Output = HEX` fill in the open record;
/// * `Count = ...` closes the open record and opens a new one with `len` 0,
///   but only once a `[...]` line has been seen and a record is already open.
///
/// NOTE(review): only these exact key spellings are matched. If a vector
/// file uses other spellings (for example `COUNT = ` / `Outputlen = `), its
/// records are not picked up — confirm against the actual vector files.
///
/// Returns an empty vector (after printing a notice) if the file cannot be
/// opened. If reading fails part-way through, a warning is printed and the
/// records parsed so far are returned. The last record in the file is kept
/// only if it has a non-empty `Output`.
fn parse_shake_test_file(filepath: &str) -> Vec<ShakeTestVector> {
    use std::fs::File;
    use std::io::{BufRead, BufReader};

    let file = match File::open(filepath) {
        Ok(f) => f,
        Err(_) => {
            println!("Test vector file not found: {}", filepath);
            println!("Please ensure the test vectors are in the correct directory.");
            return Vec::new();
        }
    };

    let mut test_vectors = Vec::new();
    let mut current_vector: Option<ShakeTestVector> = None;
    let mut in_test_group = false;

    for (index, line) in BufReader::new(file).lines().enumerate() {
        let line = match line {
            Ok(text) => text,
            Err(err) => {
                // Stop at the first unreadable line, but say so instead of
                // silently returning a truncated vector set.
                println!(
                    "Warning: stopped reading {} at line {}: {}",
                    filepath,
                    index + 1,
                    err
                );
                break;
            }
        };
        let line = line.trim();

        if line.is_empty() || line.starts_with('#') {
            continue;
        }
        if line.starts_with('[') && line.ends_with(']') {
            in_test_group = true;
            continue;
        }

        if let Some(value) = line.strip_prefix("Len = ") {
            // A `Len` line always closes the record in progress, even when
            // the length itself turns out to be unparsable.
            test_vectors.extend(current_vector.take());
            match value.trim().parse::<usize>() {
                Ok(len) => current_vector = Some(ShakeTestVector::with_len(len)),
                Err(_) => println!("Warning: Invalid length format in line: {}", line),
            }
            continue;
        }

        if let Some(value) = line.strip_prefix("OutLen = ") {
            if let Some(vector) = current_vector.as_mut() {
                vector.output_len = match value.trim().parse::<usize>() {
                    Ok(bits) => bits,
                    Err(_) => {
                        println!("Warning: Invalid output length format in line: {}", line);
                        0
                    }
                };
            }
            continue;
        }

        // The remaining keys only make sense while a record is open.
        let Some(vector) = current_vector.as_mut() else {
            continue;
        };

        if let Some(value) = line.strip_prefix("Msg = ") {
            vector.msg = value.trim().to_string();
        } else if let Some(value) = line.strip_prefix("Output = ") {
            vector.output = value.trim().to_string();
            // No explicit output length: derive it from the hex text
            // (each hex character carries four bits).
            if vector.output_len == 0 && !vector.output.is_empty() {
                vector.output_len = vector.output.len() * 4;
            }
        } else if line.starts_with("Count = ") && in_test_group {
            // Swap a blank record in and store the one that was open.
            let finished = std::mem::replace(vector, ShakeTestVector::with_len(0));
            test_vectors.push(finished);
        }
    }

    // Keep the trailing record only if it actually carries an output.
    test_vectors.extend(current_vector.filter(|vector| !vector.output.is_empty()));
    test_vectors
}
/// Runs the file-based vector checks for SHAKE-128 and SHAKE-256 against the
/// `VariableOut` and `ShortMsg` response files located under
/// `../dcrypt-test/src/vectors/shake` relative to this crate's manifest
/// directory.
#[test]
fn test_shake_nist_variable_output() {
    let vectors_dir = format!(
        "{}/../dcrypt-test/src/vectors",
        env!("CARGO_MANIFEST_DIR")
    );

    // Variable-output-length files.
    run_shake_xof_tests::<ShakeXof128>(
        &format!("{}/shake/SHAKE128VariableOut.rsp", vectors_dir),
        "SHAKE-128",
    );
    run_shake_xof_tests::<ShakeXof256>(
        &format!("{}/shake/SHAKE256VariableOut.rsp", vectors_dir),
        "SHAKE-256",
    );

    // Short-message files.
    run_shake_xof_tests::<ShakeXof128>(
        &format!("{}/shake/SHAKE128ShortMsg.rsp", vectors_dir),
        "SHAKE-128 (Short)",
    );
    run_shake_xof_tests::<ShakeXof256>(
        &format!("{}/shake/SHAKE256ShortMsg.rsp", vectors_dir),
        "SHAKE-256 (Short)",
    );
}