/// Number of ChaCha double rounds run by `xor`: each iteration applies one column round and
/// one diagonal round (four quarter rounds each), so 10 double rounds = the 20 rounds of ChaCha20.
const DOUBLE_ROUNDS: usize = 10;
/// Applies the ChaCha quarter round to the words of `state` at indices `a`, `b`, `c`, `d`.
///
/// The quarter round is four add-xor-rotate steps with rotation amounts 16, 12, 8 and 7.
/// Every step reads and writes `state` directly and the steps run strictly in sequence.
///
/// # Panics
///
/// Panics if any of the indices is out of bounds for `state`.
#[inline(always)]
fn quarter_round(state: &mut [u32], a: usize, b: usize, c: usize, d: usize) {
    /// One ARX step: `words[sum] += words[addend]`, then
    /// `words[target] = (words[target] ^ words[sum]) <<< bits`.
    #[inline(always)]
    fn add_xor_rotate(words: &mut [u32], sum: usize, addend: usize, target: usize, bits: u32) {
        words[sum] = words[sum].wrapping_add(words[addend]);
        words[target] = (words[target] ^ words[sum]).rotate_left(bits);
    }

    add_xor_rotate(state, a, b, d, 16);
    add_xor_rotate(state, c, d, b, 12);
    add_xor_rotate(state, a, b, d, 8);
    add_xor_rotate(state, c, d, b, 7);
}
/// XORs `data` in place with a single 64-byte ChaCha20 keystream block.
///
/// The 16-word input state is built from the following parts, in this order:
/// - the four ChaCha constants (`"expand 32-byte k"` as little-endian words);
/// - the eight `key` words;
/// - the four `counter` words.
///
/// The state is run through [`DOUBLE_ROUNDS`] double rounds (a column round, then a diagonal
/// round). The original input state is then added back word by word, and the result is
/// serialized little-endian. The first `data.len()` keystream bytes are XORed into `data`;
/// any remaining keystream bytes are discarded.
///
/// `counter` is copied verbatim into state words 12..16 and is neither interpreted nor
/// advanced here. Callers that process more than 64 bytes must step it themselves. For the
/// RFC 8439 variant the layout is `[block_counter, nonce0, nonce1, nonce2]`.
///
/// XOR is its own inverse, so the same call both encrypts and decrypts.
///
/// # Panics
///
/// Panics if `data` is longer than 64 bytes, if `key` is not exactly 8 words, or if
/// `counter` is not exactly 4 words.
pub fn xor(data: &mut [u8], key: &[u32], counter: &[u32]) {
    /// `"expand 32-byte k"` read as four little-endian `u32` words.
    const SIGMA: [u32; 4] = [0x61707865, 0x3320646e, 0x79622d32, 0x6b206574];

    assert!(data.len() <= 64, "Data length must not exceed 64 bytes");
    // Checked up front so misuse reports what was expected, instead of surfacing as the
    // generic `copy_from_slice` length-mismatch panic further down.
    assert_eq!(key.len(), 8, "Key must be exactly 8 words (256 bits)");
    assert_eq!(counter.len(), 4, "Counter must be exactly 4 words (128 bits)");

    let mut state = [0u32; 16];
    state[..4].copy_from_slice(&SIGMA);
    state[4..12].copy_from_slice(key);
    state[12..].copy_from_slice(counter);

    let mut working_state = state;
    for _ in 0..DOUBLE_ROUNDS {
        // Column round: viewing the state as a row-major 4x4 matrix, mix each column.
        quarter_round(&mut working_state, 0, 4, 8, 12);
        quarter_round(&mut working_state, 1, 5, 9, 13);
        quarter_round(&mut working_state, 2, 6, 10, 14);
        quarter_round(&mut working_state, 3, 7, 11, 15);
        // Diagonal round: mix each (wrapping) diagonal of the same matrix.
        quarter_round(&mut working_state, 0, 5, 10, 15);
        quarter_round(&mut working_state, 1, 6, 11, 12);
        quarter_round(&mut working_state, 2, 7, 8, 13);
        quarter_round(&mut working_state, 3, 4, 9, 14);
    }
    // Final step of the block function: add the input state back in, word by word.
    for (w, s) in working_state.iter_mut().zip(state.iter()) {
        *w = w.wrapping_add(*s);
    }

    let mut keystream = [0u8; 64];
    for (chunk, word) in keystream.chunks_exact_mut(4).zip(working_state.iter()) {
        chunk.copy_from_slice(&word.to_le_bytes());
    }
    // NOTE(review): `state` (which holds the key) and `keystream` are plain stack arrays.
    // They are not zeroized before returning; add volatile clearing if that matters for
    // your threat model.
    for (byte, k) in data.iter_mut().zip(keystream.iter()) {
        *byte ^= *k;
    }
}
#[cfg(test)]
mod tests {
    // `super::` keeps the import valid even if this module is renamed or moved.
    use super::xor;

    /// Packs `bytes` into little-endian `u32` words, filling `out` completely.
    fn load_le_words(bytes: &[u8], out: &mut [u32]) {
        assert_eq!(bytes.len(), out.len() * 4);
        for (word, chunk) in out.iter_mut().zip(bytes.chunks_exact(4)) {
            *word = u32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]);
        }
    }

    /// Round-trip: applying `xor` twice with the same key/counter restores the input.
    #[test]
    fn test_chacha20_xor_encrypt_decrypt() {
        let key = [0u8; 32];
        let nonce = [0u8; 12];
        let mut key_words = [0u32; 8];
        load_le_words(&key, &mut key_words);
        // Word 0 is the block counter (left at 0); words 1..4 carry the 96-bit nonce.
        let mut counter = [0u32; 4];
        load_le_words(&nonce, &mut counter[1..]);

        let plaintext = b"Hello, ChaCha20 fallback test!";
        let mut buffer = *plaintext;
        xor(&mut buffer, &key_words, &counter);
        let ciphertext = buffer;
        assert_ne!(&ciphertext[..], &plaintext[..]);
        xor(&mut buffer, &key_words, &counter);
        assert_eq!(&buffer[..], &plaintext[..]);
    }

    /// Known-answer test from RFC 8439 section 2.3.2. XORing zeros exposes the raw keystream,
    /// which a round-trip test alone cannot validate.
    #[test]
    fn test_chacha20_rfc8439_block_vector() {
        let mut key = [0u8; 32];
        for (i, b) in key.iter_mut().enumerate() {
            *b = i as u8;
        }
        let nonce = [0x00u8, 0x00, 0x00, 0x09, 0x00, 0x00, 0x00, 0x4a, 0x00, 0x00, 0x00, 0x00];
        let mut key_words = [0u32; 8];
        load_le_words(&key, &mut key_words);
        let mut counter = [1u32, 0, 0, 0];
        load_le_words(&nonce, &mut counter[1..]);

        let expected: [u8; 64] = [
            0x10, 0xf1, 0xe7, 0xe4, 0xd1, 0x3b, 0x59, 0x15, 0x50, 0x0f, 0xdd, 0x1f, 0xa3, 0x20,
            0x71, 0xc4, 0xc7, 0xd1, 0xf4, 0xc7, 0x33, 0xc0, 0x68, 0x03, 0x04, 0x22, 0xaa, 0x9a,
            0xc3, 0xd4, 0x6c, 0x4e, 0xd2, 0x82, 0x64, 0x46, 0x07, 0x9f, 0xaa, 0x09, 0x14, 0xc2,
            0xd7, 0x05, 0xd9, 0x8b, 0x02, 0xa2, 0xb5, 0x12, 0x9c, 0xd1, 0xde, 0x16, 0x4e, 0xb9,
            0xcb, 0xd0, 0x83, 0xe8, 0xa2, 0x50, 0x3c, 0x4e,
        ];

        let mut block = [0u8; 64];
        xor(&mut block, &key_words, &counter);
        assert_eq!(&block[..], &expected[..]);

        // A shorter buffer must receive exactly the keystream prefix.
        let mut prefix = [0u8; 7];
        xor(&mut prefix, &key_words, &counter);
        assert_eq!(&prefix[..], &expected[..7]);
    }

    #[test]
    #[should_panic]
    fn test_xor_rejects_more_than_one_block() {
        let mut data = [0u8; 65];
        xor(&mut data, &[0u32; 8], &[0u32; 4]);
    }

    #[test]
    #[should_panic]
    fn test_xor_rejects_short_key() {
        let mut data = [0u8; 16];
        xor(&mut data, &[0u32; 7], &[0u32; 4]);
    }
}