use alloy_primitives::Keccak256;
use super::error::EncryptionError;
use super::key::EncryptionKey;
/// Builds a Keccak-256 hasher that has already absorbed the bytes of `key`.
///
/// The returned hasher is not finalized, so callers can clone it and append
/// further input without re-hashing the key each time.
#[inline]
fn key_state(key: &EncryptionKey) -> Keccak256 {
    let mut hasher = Keccak256::new();
    hasher.update(key.as_bytes());
    hasher
}
/// Derives the 32-byte keystream segment for `counter`.
///
/// Two hashing rounds are performed:
/// 1. a clone of `key_state` absorbs `counter` as 4 little-endian bytes and is
///    finalized;
/// 2. that digest is hashed again with a fresh Keccak-256 instance.
///
/// The second digest is returned as the segment key. `key_state` itself is
/// left untouched because only the clone is consumed.
#[inline]
fn derive_segment_key(key_state: &Keccak256, counter: u32) -> [u8; 32] {
    // Round 1: continue from the pre-keyed state with the counter appended.
    let mut keyed = key_state.clone();
    keyed.update(counter.to_le_bytes());
    let counter_digest = keyed.finalize();

    // Round 2: hash the round-1 digest on its own.
    let mut outer = Keccak256::new();
    outer.update(counter_digest.as_slice());
    outer.finalize().into()
}
/// XORs `data` in place with the keystream derived from `key`.
///
/// `data` is processed in chunks of `EncryptionKey::SIZE` bytes (the final
/// chunk may be shorter). The first chunk uses counter `init_ctr`, and the
/// counter advances by one per chunk, wrapping around on `u32` overflow.
#[inline]
fn apply_keystream(key: &EncryptionKey, init_ctr: u32, data: &mut [u8]) {
    // Hash the key once; every segment key is derived from a clone of this state.
    let keyed = key_state(key);
    let mut counter = init_ctr;

    for chunk in data.chunks_mut(EncryptionKey::SIZE) {
        let segment_key = derive_segment_key(&keyed, counter);
        // Only as many keystream bytes as the chunk holds are consumed. The
        // slice panics if a chunk were ever longer than the 32-byte segment key.
        let pad = &segment_key[..chunk.len()];
        for (byte, mask) in chunk.iter_mut().zip(pad) {
            *byte ^= *mask;
        }
        counter = counter.wrapping_add(1);
    }
}
/// Copies `input` into the front of `output` and XORs it with the keystream
/// for `key`, starting at counter `init_ctr`.
///
/// Only the first `input.len()` bytes of `output` are written; any bytes
/// beyond that are left as they were. Because the transformation is an XOR
/// with a keystream determined by `key` and `init_ctr`, running it again on
/// the result with the same arguments yields the original bytes.
///
/// # Errors
///
/// Returns [`EncryptionError::OutputBufferTooSmall`] when `output` is shorter
/// than `input`; `output` is not modified in that case.
#[inline]
pub fn transcrypt(
    key: &EncryptionKey,
    init_ctr: u32,
    input: &[u8],
    output: &mut [u8],
) -> Result<(), EncryptionError> {
    let required = input.len();
    let len = output.len();

    // `get_mut` yields `None` exactly when `output` cannot hold `input`.
    let Some(dest) = output.get_mut(..required) else {
        return Err(EncryptionError::OutputBufferTooSmall { len, required });
    };

    dest.copy_from_slice(input);
    apply_keystream(key, init_ctr, dest);
    Ok(())
}
/// Transforms `data` in place by XORing it with the keystream for `key`,
/// starting at counter `init_ctr`.
///
/// This is the buffer-reusing counterpart of `transcrypt`: there is no
/// separate output slice and no error path. Calling it a second time with the
/// same `key` and `init_ctr` restores the original contents of `data`.
#[inline]
pub fn transcrypt_in_place(key: &EncryptionKey, init_ctr: u32, data: &mut [u8]) {
// All of the work happens in the shared keystream routine.
apply_keystream(key, init_ctr, data);
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Key paired with the 4096-byte reference fixture.
    const GO_VECTOR_KEY_HEX: &str =
        "8abf1502f557f15026716030fb6384792583daf39608a3cd02ff2f47e9bc6e49";

    /// Decodes [`GO_VECTOR_KEY_HEX`] into an [`EncryptionKey`].
    fn go_vector_key() -> EncryptionKey {
        let key_bytes: [u8; 32] = hex::decode(GO_VECTOR_KEY_HEX).unwrap().try_into().unwrap();
        EncryptionKey::from(key_bytes)
    }

    /// Loads the expected ciphertext of 4096 zero bytes under the fixture key.
    ///
    /// NOTE(review): the fixture name suggests it was produced by the Go
    /// implementation; confirm against the fixture's provenance.
    fn go_vector_expected() -> Vec<u8> {
        let expected_hex = include_str!("testdata/go_vector_4096.hex");
        hex::decode(expected_hex.trim()).unwrap()
    }

    /// `transcrypt` reproduces the reference vector.
    #[test]
    fn go_test_vector() {
        let key = go_vector_key();
        let input = [0u8; 4096];
        let mut output = [0u8; 4096];
        transcrypt(&key, 0, &input, &mut output).unwrap();
        let expected = go_vector_expected();
        assert_eq!(output.as_slice(), expected.as_slice());
    }

    /// `transcrypt_in_place` reproduces the reference vector.
    #[test]
    fn go_test_vector_in_place() {
        let key = go_vector_key();
        let mut data = [0u8; 4096];
        transcrypt_in_place(&key, 0, &mut data);
        let expected = go_vector_expected();
        assert_eq!(data.as_slice(), expected.as_slice());
    }

    /// Applying `transcrypt` twice with the same key and counter round-trips.
    #[test]
    fn symmetry() {
        let key = EncryptionKey::from([0x42; 32]);
        let plaintext = b"hello world, this is a test!!!!!";
        let mut ciphertext = [0u8; 32];
        let mut recovered = [0u8; 32];
        transcrypt(&key, 0, plaintext, &mut ciphertext).unwrap();
        assert_ne!(&ciphertext[..], plaintext);
        transcrypt(&key, 0, &ciphertext, &mut recovered).unwrap();
        assert_eq!(&recovered[..], plaintext);
    }

    /// Applying `transcrypt_in_place` twice round-trips.
    #[test]
    fn in_place_symmetry() {
        let key = EncryptionKey::from([0x42; 32]);
        let original = *b"hello world, this is a test!!!!!";
        let mut data = original;
        transcrypt_in_place(&key, 0, &mut data);
        assert_ne!(data, original);
        transcrypt_in_place(&key, 0, &mut data);
        assert_eq!(data, original);
    }

    /// Both entry points produce the same bytes for the same inputs.
    #[test]
    fn in_place_matches_transcrypt() {
        let key = EncryptionKey::from([0xbb; 32]);
        let input = [0x77u8; 256];
        let mut via_transcrypt = [0u8; 256];
        transcrypt(&key, 3, &input, &mut via_transcrypt).unwrap();
        let mut via_in_place = input;
        transcrypt_in_place(&key, 3, &mut via_in_place);
        assert_eq!(via_transcrypt, via_in_place);
    }

    /// Processing chunk `i` alone with counter `i` matches one whole-buffer call.
    #[test]
    fn segmented_equals_whole() {
        let key = EncryptionKey::from([0xaa; 32]);
        let input = [0x55u8; 128];
        let mut whole = [0u8; 128];
        transcrypt(&key, 0, &input, &mut whole).unwrap();
        let mut segmented = [0u8; 128];
        for (i, chunk) in input.chunks(EncryptionKey::SIZE).enumerate() {
            transcrypt(
                &key,
                i as u32,
                chunk,
                &mut segmented[i * EncryptionKey::SIZE..],
            )
            .unwrap();
        }
        assert_eq!(whole, segmented);
    }

    /// Inputs shorter than one segment still round-trip.
    #[test]
    fn partial_block() {
        let key = EncryptionKey::from([0x11; 32]);
        let input = [0xffu8; 17];
        let mut output = [0u8; 17];
        transcrypt(&key, 0, &input, &mut output).unwrap();
        let mut recovered = [0u8; 17];
        transcrypt(&key, 0, &output, &mut recovered).unwrap();
        assert_eq!(recovered, input);
    }

    /// Different starting counters yield different keystreams.
    #[test]
    fn nonzero_init_ctr() {
        let key = EncryptionKey::from([0x33; 32]);
        let input = [0u8; 32];
        let mut out_ctr0 = [0u8; 32];
        let mut out_ctr5 = [0u8; 32];
        transcrypt(&key, 0, &input, &mut out_ctr0).unwrap();
        transcrypt(&key, 5, &input, &mut out_ctr5).unwrap();
        assert_ne!(out_ctr0, out_ctr5);
    }

    /// The local hex helper is strict: no sign prefixes, no panics on bad input.
    #[test]
    fn hex_decode_is_strict() {
        assert_eq!(hex::decode("00ffAb").unwrap(), vec![0x00, 0xff, 0xab]);
        assert_eq!(hex::decode("").unwrap(), Vec::<u8>::new());
        assert!(hex::decode("abc").is_err());
        // A `+` is not a hex digit even though integer parsing would accept it.
        assert!(hex::decode("+f").is_err());
        // Multi-byte characters must yield an error rather than a slicing panic.
        assert!(hex::decode("a\u{e9}1").is_err());
    }

    /// Minimal hex decoder so the tests need no extra dependency.
    mod hex {
        /// Decodes an even-length string of ASCII hex digits (either case)
        /// into bytes.
        ///
        /// Returns `Err` for odd-length input or any byte that is not a hex
        /// digit. Works on raw bytes so non-ASCII input is rejected instead
        /// of panicking on a char-boundary slice.
        pub(super) fn decode(s: &str) -> Result<Vec<u8>, String> {
            let raw = s.as_bytes();
            if !raw.len().is_multiple_of(2) {
                return Err("odd length".into());
            }
            raw.chunks_exact(2)
                .map(|pair| -> Result<u8, String> {
                    let hi = nibble(pair[0])?;
                    let lo = nibble(pair[1])?;
                    Ok((hi << 4) | lo)
                })
                .collect()
        }

        /// Converts one ASCII hex digit to its 4-bit value.
        fn nibble(byte: u8) -> Result<u8, String> {
            char::from(byte)
                .to_digit(16)
                .and_then(|digit| u8::try_from(digit).ok())
                .ok_or_else(|| format!("invalid hex digit: {:?}", char::from(byte)))
        }
    }
}