/// The four ChaCha constant words: the ASCII string `"expand 32-byte k"`
/// split into 4-byte groups, each read as a little-endian 32-bit integer.
///
/// These fill words 0..4 of every initial state built in this file.
const CONSTANTS: [u32; 4] = [
    u32::from_le_bytes(*b"expa"),
    u32::from_le_bytes(*b"nd 3"),
    u32::from_le_bytes(*b"2-by"),
    u32::from_le_bytes(*b"te k"),
];
/// Applies one ChaCha quarter round, in place, to the state words at
/// indices `a`, `b`, `c` and `d`.
///
/// The quarter round is four add / xor / rotate-left steps using the
/// rotation amounts 16, 12, 8 and 7. All additions wrap modulo 2^32.
///
/// # Panics
///
/// Panics if any index is 16 or greater.
#[inline]
fn quarter_round(s: &mut [u32; 16], a: usize, b: usize, c: usize, d: usize) {
    /// One step: `s[sum] += s[addend]`, then
    /// `s[mixed] = (s[mixed] ^ s[sum]).rotate_left(rot)`.
    // Always inlined so the helper adds no call overhead to the round function.
    #[inline(always)]
    fn step(s: &mut [u32; 16], sum: usize, addend: usize, mixed: usize, rot: u32) {
        s[sum] = s[sum].wrapping_add(s[addend]);
        s[mixed] = (s[mixed] ^ s[sum]).rotate_left(rot);
    }

    step(s, a, b, d, 16);
    step(s, c, d, b, 12);
    step(s, a, b, d, 8);
    step(s, c, d, b, 7);
}
/// Runs the HChaCha20 function over a 256-bit key and a 128-bit input.
///
/// The 16-word state is laid out as the constants (words 0..4), the key
/// (words 4..12) and `nonce16` (words 12..16), each parsed as little-endian
/// 32-bit words. After ten double rounds the result is words 0..4 followed by
/// words 12..16, serialised little-endian. The initial state is not added
/// back to the permuted state.
pub(crate) fn hchacha20(key: &[u8; 32], nonce16: &[u8; 16]) -> [u8; 32] {
    /// Reads the first four bytes of `bytes` as a little-endian word.
    fn load_le(bytes: &[u8]) -> u32 {
        u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]])
    }

    let mut state = [0u32; 16];
    state[..4].copy_from_slice(&CONSTANTS);
    // Key words followed by nonce words fill the remaining twelve slots.
    let input = key.chunks_exact(4).chain(nonce16.chunks_exact(4));
    for (slot, bytes) in state[4..].iter_mut().zip(input) {
        *slot = load_le(bytes);
    }

    for _ in 0..10 {
        // Column round: (0,4,8,12), (1,5,9,13), (2,6,10,14), (3,7,11,15).
        for col in 0..4 {
            quarter_round(&mut state, col, col + 4, col + 8, col + 12);
        }
        // Diagonal round: (0,5,10,15), (1,6,11,12), (2,7,8,13), (3,4,9,14).
        for col in 0..4 {
            quarter_round(
                &mut state,
                col,
                4 + (col + 1) % 4,
                8 + (col + 2) % 4,
                12 + (col + 3) % 4,
            );
        }
    }

    let mut out = [0u8; 32];
    // Only the first and last rows of the state are emitted.
    let kept = state[..4].iter().chain(state[12..].iter());
    for (dst, word) in out.chunks_exact_mut(4).zip(kept) {
        dst.copy_from_slice(&word.to_le_bytes());
    }
    out
}
/// ChaCha20 keystream generator (32-bit block counter, 96-bit nonce) bound
/// to a single key.
///
/// The key is parsed once by [`ChaCha20::new`] and kept as eight
/// little-endian 32-bit words, ready to be copied into the state of every
/// block.
#[derive(Clone)]
pub struct ChaCha20 {
    /// Key material in the form it takes as state words 4..12.
    key: [u32; 8],
}

impl ChaCha20 {
    /// Creates a cipher instance from a 256-bit key.
    pub fn new(key: &[u8; 32]) -> Self {
        Self {
            key: core::array::from_fn(|i| {
                let at = 4 * i;
                u32::from_le_bytes([key[at], key[at + 1], key[at + 2], key[at + 3]])
            }),
        }
    }

    /// Builds the initial 16-word state for one block: constants, key,
    /// block counter, then the nonce as three little-endian words.
    fn state(&self, nonce: &[u8; 12], counter: u32) -> [u32; 16] {
        let nonce_word = |i: usize| {
            let at = 4 * i;
            u32::from_le_bytes([nonce[at], nonce[at + 1], nonce[at + 2], nonce[at + 3]])
        };
        let k = &self.key;
        [
            CONSTANTS[0],
            CONSTANTS[1],
            CONSTANTS[2],
            CONSTANTS[3],
            k[0],
            k[1],
            k[2],
            k[3],
            k[4],
            k[5],
            k[6],
            k[7],
            counter,
            nonce_word(0),
            nonce_word(1),
            nonce_word(2),
        ]
    }

    /// Produces the 64-byte keystream block for `nonce` and `counter`.
    ///
    /// Runs ten double rounds over the initial state, adds the initial state
    /// back word by word (wrapping), and serialises the sixteen words
    /// little-endian.
    pub fn block(&self, nonce: &[u8; 12], counter: u32) -> [u8; 64] {
        /// One column round followed by one diagonal round.
        // Always inlined: this is the hot loop body of the scalar path.
        #[inline(always)]
        fn double_round(s: &mut [u32; 16]) {
            quarter_round(s, 0, 4, 8, 12);
            quarter_round(s, 1, 5, 9, 13);
            quarter_round(s, 2, 6, 10, 14);
            quarter_round(s, 3, 7, 11, 15);
            quarter_round(s, 0, 5, 10, 15);
            quarter_round(s, 1, 6, 11, 12);
            quarter_round(s, 2, 7, 8, 13);
            quarter_round(s, 3, 4, 9, 14);
        }

        let input = self.state(nonce, counter);
        let mut working = input;
        for _ in 0..10 {
            double_round(&mut working);
        }

        let mut keystream = [0u8; 64];
        let summed = working
            .iter()
            .zip(input.iter())
            .map(|(mixed, start)| mixed.wrapping_add(*start));
        for (dst, word) in keystream.chunks_exact_mut(4).zip(summed) {
            dst.copy_from_slice(&word.to_le_bytes());
        }
        keystream
    }

    /// XORs the keystream into `buf` in place, using `counter` for the first
    /// 64-byte block and incrementing it by one per block.
    ///
    /// On x86_64 builds with the `std` feature the work is handed to the
    /// `simd` module when it reports support; otherwise blocks are generated
    /// one at a time with [`ChaCha20::block`].
    ///
    /// # Panics
    ///
    /// Panics if `counter` plus the number of blocks needed for `buf` would
    /// exceed 2^32, i.e. if the 32-bit block counter would wrap.
    pub fn apply_keystream(&self, nonce: &[u8; 12], counter: u32, buf: &mut [u8]) {
        let block_count = buf.len().div_ceil(64) as u64;
        let end_exclusive = u64::from(counter) + block_count;
        assert!(
            end_exclusive <= 1u64 << 32,
            "ChaCha20 counter would overflow 2^32 (buf too large for a single invocation)"
        );

        #[cfg(all(feature = "std", target_arch = "x86_64"))]
        if simd::supported() {
            simd::apply_keystream(&self.key, nonce, counter, buf);
            return;
        }

        for (index, chunk) in buf.chunks_mut(64).enumerate() {
            // `index` is below 2^32 thanks to the assert above, so the cast is lossless.
            let keystream = self.block(nonce, counter.wrapping_add(index as u32));
            chunk
                .iter_mut()
                .zip(keystream.iter())
                .for_each(|(byte, k)| *byte ^= k);
        }
    }
}
/// AVX2 implementation of the ChaCha20 keystream XOR.
///
/// Eight blocks are computed per iteration: each `__m256i` holds the same
/// state word for eight consecutive block counters, one block per 32-bit lane.
/// Only compiled on x86_64 with the `std` feature (needed for runtime feature
/// detection).
#[cfg(all(feature = "std", target_arch = "x86_64"))]
#[allow(unsafe_code)]
mod simd {
use super::CONSTANTS;
use core::arch::x86_64::*;
/// Returns `true` when the running CPU reports AVX2 support.
pub(super) fn supported() -> bool {
std::is_x86_feature_detected!("avx2")
}
/// Rotates every 32-bit lane of `x` left by `N` bits, built from a left shift
/// by `N` OR-ed with a right shift by `M`.
///
/// This is only a rotation when `M == 32 - N`; every call site in this module
/// passes such a pair (16/16, 12/20, 8/24, 7/25).
///
/// # Safety
///
/// Uses AVX2 intrinsics; must only execute on a CPU with AVX2.
#[inline(always)]
unsafe fn rol<const N: i32, const M: i32>(x: __m256i) -> __m256i {
// SAFETY: AVX2 availability is the caller's obligation (see `# Safety`).
unsafe { _mm256_or_si256(_mm256_slli_epi32::<N>(x), _mm256_srli_epi32::<M>(x)) }
}
/// Vectorised quarter round: the same add / xor / rotate sequence (rotations
/// 16, 12, 8, 7) applied to all eight lanes of state words `a`, `b`, `c`, `d`.
///
/// # Safety
///
/// Uses AVX2 intrinsics; must only execute on a CPU with AVX2.
#[inline(always)]
// NOTE(review): `qr` takes five arguments, so this allow looks unnecessary —
// confirm against the project's clippy configuration before removing it.
#[allow(clippy::too_many_arguments)]
unsafe fn qr(v: &mut [__m256i; 16], a: usize, b: usize, c: usize, d: usize) {
// SAFETY: AVX2 availability is the caller's obligation (see `# Safety`).
unsafe {
v[a] = _mm256_add_epi32(v[a], v[b]);
v[d] = rol::<16, 16>(_mm256_xor_si256(v[d], v[a]));
v[c] = _mm256_add_epi32(v[c], v[d]);
v[b] = rol::<12, 20>(_mm256_xor_si256(v[b], v[c]));
v[a] = _mm256_add_epi32(v[a], v[b]);
v[d] = rol::<8, 24>(_mm256_xor_si256(v[d], v[a]));
v[c] = _mm256_add_epi32(v[c], v[d]);
v[b] = rol::<7, 25>(_mm256_xor_si256(v[b], v[c]));
}
}
/// XORs the ChaCha20 keystream for (`key`, `nonce`, starting block `counter`)
/// into `buf` using the AVX2 path.
///
/// This wrapper performs no CPU feature check of its own.
pub(super) fn apply_keystream(key: &[u32; 8], nonce: &[u8; 12], counter: u32, buf: &mut [u8]) {
// SAFETY: `apply_keystream_avx2` requires AVX2. The call site in
// `ChaCha20::apply_keystream` only reaches this function after `supported()`
// returned true.
// NOTE(review): this is a safe fn that relies on its caller for that check;
// any new caller must gate on `supported()` as well.
unsafe { apply_keystream_avx2(key, nonce, counter, buf) }
}
/// AVX2 body of [`apply_keystream`]: generates eight 64-byte blocks per loop
/// iteration and XORs as many bytes as `buf` still needs.
///
/// # Safety
///
/// The CPU must support AVX2.
#[target_feature(enable = "avx2")]
unsafe fn apply_keystream_avx2(key: &[u32; 8], nonce: &[u8; 12], counter: u32, buf: &mut [u8]) {
// SAFETY: the intrinsics need AVX2, which the caller guarantees. The only raw
// memory access is the unaligned 32-byte store into `words[i]`, a `[u32; 8]`
// that is exactly 32 bytes long.
unsafe {
// Little-endian word of the nonce starting at byte offset `o`.
let n =
|o: usize| u32::from_le_bytes([nonce[o], nonce[o + 1], nonce[o + 2], nonce[o + 3]]);
let (n0, n1, n2) = (n(0), n(4), n(8));
// Per-lane counter offsets: lane j handles block `ctr + j`.
let lanes = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
let mut ctr = counter;
let mut off = 0usize;
while off < buf.len() {
// Initial state, one vector per state word: constants, key, counter, nonce.
// Every word is broadcast to all lanes except the counter (index 12).
let mut v = [
_mm256_set1_epi32(CONSTANTS[0] as i32),
_mm256_set1_epi32(CONSTANTS[1] as i32),
_mm256_set1_epi32(CONSTANTS[2] as i32),
_mm256_set1_epi32(CONSTANTS[3] as i32),
_mm256_set1_epi32(key[0] as i32),
_mm256_set1_epi32(key[1] as i32),
_mm256_set1_epi32(key[2] as i32),
_mm256_set1_epi32(key[3] as i32),
_mm256_set1_epi32(key[4] as i32),
_mm256_set1_epi32(key[5] as i32),
_mm256_set1_epi32(key[6] as i32),
_mm256_set1_epi32(key[7] as i32),
// 32-bit lane addition wraps; lanes beyond the blocks `buf` needs are
// computed but their keystream is never used.
_mm256_add_epi32(_mm256_set1_epi32(ctr as i32), lanes),
_mm256_set1_epi32(n0 as i32),
_mm256_set1_epi32(n1 as i32),
_mm256_set1_epi32(n2 as i32),
];
// Kept for the final "add the initial state back" step.
let init = v;
// Ten double rounds: four column rounds, then four diagonal rounds.
for _ in 0..10 {
qr(&mut v, 0, 4, 8, 12);
qr(&mut v, 1, 5, 9, 13);
qr(&mut v, 2, 6, 10, 14);
qr(&mut v, 3, 7, 11, 15);
qr(&mut v, 0, 5, 10, 15);
qr(&mut v, 1, 6, 11, 12);
qr(&mut v, 2, 7, 8, 13);
qr(&mut v, 3, 4, 9, 14);
}
// words[i][b] = output word i of block b (lane b of vector i).
let mut words = [[0u32; 8]; 16];
for i in 0..16 {
let added = _mm256_add_epi32(v[i], init[i]);
_mm256_storeu_si256(words[i].as_mut_ptr() as *mut __m256i, added);
}
// At most 512 bytes (eight blocks) are consumed per iteration.
let avail = (buf.len() - off).min(512);
// Transpose into byte order: block b occupies ks[64 * b..64 * b + 64], its
// sixteen words serialised little-endian.
let mut ks = [0u8; 512];
for (b, blk) in ks.chunks_exact_mut(64).enumerate() {
for (i, word) in blk.chunks_exact_mut(4).enumerate() {
word.copy_from_slice(&words[i][b].to_le_bytes());
}
}
// XOR only the bytes that remain; surplus keystream is discarded.
for (dst, k) in buf[off..off + avail].iter_mut().zip(ks.iter()) {
*dst ^= *k;
}
off += 512;
// Eight block counters were used by this batch.
ctr = ctr.wrapping_add(8);
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_util::from_hex;

    /// The 32-byte key `00 01 .. 1f` shared by the published vectors below.
    fn vector_key() -> [u8; 32] {
        from_hex::<32>("000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f")
    }

    /// Block-function test vector from RFC 8439 (block counter 1).
    #[test]
    fn rfc8439_block_function() {
        let cipher = ChaCha20::new(&vector_key());
        let nonce = from_hex::<12>("000000090000004a00000000");
        let expected = from_hex::<64>(
            "10f1e7e4d13b5915500fdd1fa32071c4c7d1f4c733c068030422aa9ac3d46c4e\
             d2826446079faa0914c2d705d98b02a2b5129cd1de164eb9cbd083e8a2503c4e",
        );
        assert_eq!(cipher.block(&nonce, 1), expected);
    }

    /// HChaCha20 test vector from the XChaCha draft.
    #[test]
    fn xchacha_draft_hchacha20() {
        let nonce = from_hex::<16>("000000090000004a0000000031415927");
        let expected =
            from_hex::<32>("82413b4227b27bfed30e42508a877d73a0f9e4d58a74a853c12ec41326d3ecdc");
        assert_eq!(hchacha20(&vector_key(), &nonce), expected);
    }

    /// Encryption test vector from RFC 8439, followed by a round trip:
    /// applying the same keystream a second time must restore the plaintext.
    #[test]
    fn rfc8439_encryption() {
        let cipher = ChaCha20::new(&vector_key());
        let nonce = from_hex::<12>("000000000000004a00000000");
        let plaintext = b"Ladies and Gentlemen of the class of '99: If I could offer you \
                          only one tip for the future, sunscreen would be it.";
        let expected = from_hex::<114>(
            "6e2e359a2568f98041ba0728dd0d6981e97e7aec1d4360c20a27afccfd9fae0b\
             f91b65c5524733ab8f593dabcd62b3571639d624e65152ab8f530c359f0861d8\
             07ca0dbf500d6a6156a38e088a22b65e52bc514d16ccf806818ce91ab7793736\
             5af90bbf74a35be6b40b8eedf2785e42874d",
        );

        let mut data = plaintext.to_vec();
        cipher.apply_keystream(&nonce, 1, &mut data);
        assert_eq!(data, expected);

        cipher.apply_keystream(&nonce, 1, &mut data);
        assert_eq!(data, plaintext);
    }

    /// Cross-checks `apply_keystream` (which takes the SIMD path when it is
    /// supported) against a block-by-block scalar reference, over lengths
    /// around the 64-byte block and 512-byte batch boundaries and counters
    /// close to the 32-bit limit.
    #[cfg(all(feature = "std", target_arch = "x86_64"))]
    #[test]
    fn simd_matches_scalar() {
        use alloc::vec::Vec;

        const LENGTHS: [usize; 12] =
            [0, 1, 63, 64, 65, 200, 511, 512, 513, 1024, 1500, 4096 + 7];
        const COUNTERS: [u32; 4] = [0, 1, 100, u32::MAX - 20];
        const COUNTER_LIMIT: u64 = 1 << 32;

        if !super::simd::supported() {
            return;
        }

        let mut key = [0u8; 32];
        for (i, byte) in key.iter_mut().enumerate() {
            *byte = (i as u8).wrapping_mul(7).wrapping_add(1);
        }
        let mut nonce = [0u8; 12];
        for (i, byte) in nonce.iter_mut().enumerate() {
            *byte = (i as u8).wrapping_mul(5).wrapping_add(3);
        }
        let cipher = ChaCha20::new(&key);

        for &len in &LENGTHS {
            let blocks = len.div_ceil(64) as u64;
            let input: Vec<u8> = (0..len).map(|i| (i * 31 + 9) as u8).collect();
            for &ctr in &COUNTERS {
                // Skip combinations that `apply_keystream` rejects by design.
                if u64::from(ctr) + blocks > COUNTER_LIMIT {
                    continue;
                }

                let mut simd = input.clone();
                cipher.apply_keystream(&nonce, ctr, &mut simd);

                let mut scalar = input.clone();
                for (index, chunk) in scalar.chunks_mut(64).enumerate() {
                    let ks = cipher.block(&nonce, ctr.wrapping_add(index as u32));
                    chunk
                        .iter_mut()
                        .zip(ks.iter())
                        .for_each(|(byte, k)| *byte ^= k);
                }

                assert_eq!(simd, scalar, "len={len} ctr={ctr}");
            }
        }
    }
}