use noxtls_core::{Error, Result};
/// The four fixed words that occupy state positions 0..4.
///
/// They are the ASCII string `"expand 32-byte k"` split into 4-byte groups,
/// each interpreted as a little-endian `u32`.
const CONSTANTS: [u32; 4] = [
    u32::from_le_bytes(*b"expa"),
    u32::from_le_bytes(*b"nd 3"),
    u32::from_le_bytes(*b"2-by"),
    u32::from_le_bytes(*b"te k"),
];
/// ChaCha20 cipher state: sixteen 32-bit words.
///
/// `Debug` is implemented by hand instead of being derived so that formatting
/// a cipher (for example in a log line or a panic message) never prints the
/// key words stored in `state`.
#[derive(Clone)]
pub struct ChaCha20 {
    /// Words 0..4 hold the constants, 4..12 the key, 12 the block counter and
    /// 13..16 the nonce.
    state: [u32; 16],
}

impl core::fmt::Debug for ChaCha20 {
    /// Writes an opaque `ChaCha20 { .. }` placeholder.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // `state` is deliberately omitted: it contains secret key material.
        f.debug_struct("ChaCha20").finish_non_exhaustive()
    }
}
impl ChaCha20 {
/// Builds the nonce-independent prefix of the cipher state.
///
/// The returned array holds the four `CONSTANTS` words followed by `key`
/// read as eight little-endian 32-bit words. It can be reused for every
/// nonce/counter pair that shares the same key.
#[inline(always)]
pub(crate) fn noxtls_prepare_key_words(key: &[u8; 32]) -> [u32; 12] {
    let mut prepared = [0_u32; 12];
    let (constant_slots, key_slots) = prepared.split_at_mut(4);
    constant_slots.copy_from_slice(&CONSTANTS);
    // 32 key bytes split into exactly eight 4-byte groups, one per slot.
    for (slot, key_bytes) in key_slots.iter_mut().zip(key.chunks_exact(4)) {
        *slot = Self::read_u32_le(key_bytes);
    }
    prepared
}
/// Assembles a cipher from pre-expanded key words, a nonce and a counter.
///
/// * `prepared` - constants and key words, copied into state words 0..12.
/// * `nonce` - twelve bytes, read as three little-endian words into state
///   words 13..16.
/// * `counter` - initial block counter, stored in state word 12.
#[inline(always)]
pub(crate) fn noxtls_from_prepared(
    prepared: &[u32; 12],
    nonce: &[u8; 12],
    counter: u32,
) -> Self {
    let mut state = [0_u32; 16];
    let (key_part, per_message) = state.split_at_mut(12);
    key_part.copy_from_slice(prepared);
    per_message[0] = counter;
    // Remaining three slots take the nonce, four bytes at a time.
    for (slot, nonce_bytes) in per_message[1..].iter_mut().zip(nonce.chunks_exact(4)) {
        *slot = Self::read_u32_le(nonce_bytes);
    }
    Self { state }
}
/// Creates a cipher from a 32-byte key, a 12-byte nonce and an initial
/// 32-bit block counter.
///
/// This expands the key and then delegates to the prepared-key constructor.
pub fn noxtls_new(key: &[u8; 32], nonce: &[u8; 12], counter: u32) -> Self {
    Self::noxtls_from_prepared(&Self::noxtls_prepare_key_words(key), nonce, counter)
}
/// Applies one ChaCha quarter round to the four state words at indices
/// `a`, `b`, `c` and `d`, updating `state` in place.
///
/// The round consists of two add/xor/rotate passes: the first rotates by
/// 16 and 12 bits, the second by 8 and 7 bits. All additions wrap.
#[inline(always)]
fn quarter_round(state: &mut [u32; 16], a: usize, b: usize, c: usize, d: usize) {
    for &(rot_d, rot_b) in &[(16_u32, 12_u32), (8, 7)] {
        state[a] = state[a].wrapping_add(state[b]);
        state[d] = (state[d] ^ state[a]).rotate_left(rot_d);
        state[c] = state[c].wrapping_add(state[d]);
        state[b] = (state[b] ^ state[c]).rotate_left(rot_b);
    }
}
/// Returns the 64-byte keystream block for the current state.
///
/// The sixteen words from `block_words` are serialised in little-endian
/// order. The cipher state, including the counter, is not modified.
pub fn block_output(&self) -> [u8; 64] {
    let words = self.block_words();
    let mut bytes = [0_u8; 64];
    Self::write_block_bytes(&words, &mut bytes);
    bytes
}
/// Computes the sixteen keystream words for the current state without
/// modifying it.
///
/// A copy of the state is run through ten double rounds, and the original
/// state is then added back word by word (wrapping).
#[inline(always)]
pub(crate) fn block_words(&self) -> [u32; 16] {
    let initial = self.state;
    let mut mixed = initial;
    for _double_round in 0..10 {
        // Column round: each quarter round works on one column of the 4x4 state.
        Self::quarter_round(&mut mixed, 0, 4, 8, 12);
        Self::quarter_round(&mut mixed, 1, 5, 9, 13);
        Self::quarter_round(&mut mixed, 2, 6, 10, 14);
        Self::quarter_round(&mut mixed, 3, 7, 11, 15);
        // Diagonal round: each quarter round works on one diagonal.
        Self::quarter_round(&mut mixed, 0, 5, 10, 15);
        Self::quarter_round(&mut mixed, 1, 6, 11, 12);
        Self::quarter_round(&mut mixed, 2, 7, 8, 13);
        Self::quarter_round(&mut mixed, 3, 4, 9, 14);
    }
    // Feed-forward: add the starting state back into the mixed words.
    for (word, start) in mixed.iter_mut().zip(initial.iter()) {
        *word = word.wrapping_add(*start);
    }
    mixed
}
/// XORs `input` with the keystream and writes the result to `output`.
///
/// Data is processed in 64-byte blocks. After each block is generated the
/// block counter (state word 12) is incremented by one, wrapping on
/// overflow. A trailing partial block still consumes a whole block: its
/// unused keystream bytes are discarded, so a later call starts at the
/// next block. Empty input leaves the counter untouched.
///
/// # Errors
///
/// Returns `Error::InvalidLength` when `input` and `output` differ in
/// length; nothing is written in that case.
pub fn apply_keystream(&mut self, input: &[u8], output: &mut [u8]) -> Result<()> {
    if input.len() != output.len() {
        return Err(Error::InvalidLength("input and output length mismatch"));
    }
    for (src_block, dst_block) in input.chunks(64).zip(output.chunks_mut(64)) {
        let keystream = self.block_words();
        self.state[12] = self.state[12].wrapping_add(1);
        // Pair each 4-byte group with one keystream word; only the final
        // group of the final block can be shorter than four bytes.
        let groups = src_block.chunks(4).zip(dst_block.chunks_mut(4));
        for ((src, dst), word) in groups.zip(keystream.iter()) {
            let key_bytes = word.to_le_bytes();
            for ((out_byte, in_byte), key_byte) in dst.iter_mut().zip(src).zip(key_bytes.iter()) {
                *out_byte = *in_byte ^ *key_byte;
            }
        }
    }
    Ok(())
}
/// Reads the first four bytes of `bytes` as a little-endian `u32`.
///
/// Any bytes beyond the fourth are ignored.
///
/// # Panics
///
/// Panics if `bytes` holds fewer than four bytes.
#[inline(always)]
fn read_u32_le(bytes: &[u8]) -> u32 {
    let mut word = [0_u8; 4];
    word.copy_from_slice(&bytes[..4]);
    u32::from_le_bytes(word)
}
/// Serialises sixteen words into `out` as 64 little-endian bytes.
///
/// Word `i` of `block` lands in `out[4 * i..4 * i + 4]`; every byte of
/// `out` is overwritten.
#[inline(always)]
fn write_block_bytes(block: &[u32; 16], out: &mut [u8; 64]) {
    for (idx, word) in block.iter().enumerate() {
        let start = idx * 4;
        out[start..start + 4].copy_from_slice(&word.to_le_bytes());
    }
}
}