use noxtls_core::{Error, Result};
/// ChaCha20 cipher state: sixteen 32-bit words holding the constants, the
/// key, the block counter and the nonce.
///
/// `Debug` is implemented by hand rather than derived so that formatting the
/// cipher (for example in a log line) never prints the key words.
#[derive(Clone)]
pub struct ChaCha20 {
    /// Words 0..4: constants, 4..12: key, 12: block counter, 13..16: nonce.
    state: [u32; 16],
}

impl core::fmt::Debug for ChaCha20 {
    /// Formats the cipher with its state replaced by a fixed placeholder.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.debug_struct("ChaCha20")
            .field("state", &"<redacted>")
            .finish()
    }
}
impl ChaCha20 {
pub fn new(key: &[u8; 32], nonce: &[u8; 12], counter: u32) -> Self {
let constants: [u8; 16] = *b"expand 32-byte k";
let mut state = [0_u32; 16];
state[0] = u32::from_le_bytes(constants[0..4].try_into().expect("len"));
state[1] = u32::from_le_bytes(constants[4..8].try_into().expect("len"));
state[2] = u32::from_le_bytes(constants[8..12].try_into().expect("len"));
state[3] = u32::from_le_bytes(constants[12..16].try_into().expect("len"));
for i in 0..8 {
state[4 + i] = u32::from_le_bytes(key[i * 4..i * 4 + 4].try_into().expect("len"));
}
state[12] = counter;
state[13] = u32::from_le_bytes(nonce[0..4].try_into().expect("len"));
state[14] = u32::from_le_bytes(nonce[4..8].try_into().expect("len"));
state[15] = u32::from_le_bytes(nonce[8..12].try_into().expect("len"));
Self { state }
}
fn quarter_round(state: &mut [u32; 16], a: usize, b: usize, c: usize, d: usize) {
state[a] = state[a].wrapping_add(state[b]);
state[d] ^= state[a];
state[d] = state[d].rotate_left(16);
state[c] = state[c].wrapping_add(state[d]);
state[b] ^= state[c];
state[b] = state[b].rotate_left(12);
state[a] = state[a].wrapping_add(state[b]);
state[d] ^= state[a];
state[d] = state[d].rotate_left(8);
state[c] = state[c].wrapping_add(state[d]);
state[b] ^= state[c];
state[b] = state[b].rotate_left(7);
}
pub fn block_output(&self) -> [u8; 64] {
self.block()
}
fn block(&self) -> [u8; 64] {
let mut working = self.state;
for _ in 0..10 {
Self::quarter_round(&mut working, 0, 4, 8, 12);
Self::quarter_round(&mut working, 1, 5, 9, 13);
Self::quarter_round(&mut working, 2, 6, 10, 14);
Self::quarter_round(&mut working, 3, 7, 11, 15);
Self::quarter_round(&mut working, 0, 5, 10, 15);
Self::quarter_round(&mut working, 1, 6, 11, 12);
Self::quarter_round(&mut working, 2, 7, 8, 13);
Self::quarter_round(&mut working, 3, 4, 9, 14);
}
for (w, s) in working.iter_mut().zip(self.state) {
*w = w.wrapping_add(s);
}
let mut out = [0_u8; 64];
for (i, word) in working.iter().enumerate() {
out[i * 4..(i + 1) * 4].copy_from_slice(&word.to_le_bytes());
}
out
}
pub fn apply_keystream(&mut self, input: &[u8], output: &mut [u8]) -> Result<()> {
if output.len() != input.len() {
return Err(Error::InvalidLength("input and output length mismatch"));
}
let mut offset = 0;
while offset < input.len() {
let block = self.block();
self.state[12] = self.state[12].wrapping_add(1);
let chunk_len = (input.len() - offset).min(64);
for idx in 0..chunk_len {
output[offset + idx] = input[offset + idx] ^ block[idx];
}
offset += chunk_len;
}
Ok(())
}
}