// vyre 0.4.0
//
// GPU compute intermediate representation with a standard operation library.
//! CPU reference kernel for `crypto.chacha20_block`.

use crate::ir::DataType;
use crate::ops::{Backend, IntrinsicDescriptor, OpSpec};

/// Input signature: a single `Bytes` buffer (expected to be the 64-byte
/// ChaCha20 state that [`cpu`] consumes).
pub const INPUTS: &[DataType] = &[DataType::Bytes];
/// Output signature: a single `Bytes` buffer (the 64-byte keystream block
/// that [`cpu`] produces).
pub const OUTPUTS: &[DataType] = &[DataType::Bytes];
/// Algebraic laws declared for this op: none.
pub const LAWS: &[crate::ops::AlgebraicLaw] = &[];

/// Backend predicate supplied to [`SPEC`].
///
/// Always returns `true`; the backend argument is never inspected.
// NOTE(review): despite the name, no backend is rejected here. Confirm that every
// `Backend` variant can run the WGSL intrinsic, or narrow this check.
pub fn wgsl_only(_backend: &Backend) -> bool {
    true
}

/// Declarative operation specification for `crypto.chacha20_block`.
///
/// Registers the op as an intrinsic with the [`INPUTS`] / [`OUTPUTS`]
/// signature, no algebraic laws ([`LAWS`]), [`wgsl_only`] as the backend
/// predicate, and [`cpu`] as the CPU reference implementation.
pub const SPEC: OpSpec = OpSpec::intrinsic(
    "crypto.chacha20_block",
    INPUTS,
    OUTPUTS,
    LAWS,
    wgsl_only,
    // NOTE(review): the two strings presumably name the WGSL source and the kernel
    // identifier; confirm against `IntrinsicDescriptor::new`.
    IntrinsicDescriptor::new("crypto.chacha20_block.wgsl", "wgsl-chacha20-block", cpu),
);

/// Flat-byte CPU adapter for the ChaCha20 block intrinsic.
///
/// `output` is always cleared first.
///
/// - If `input` is exactly 64 bytes, it is treated as a ChaCha20 state and the
///   resulting 64-byte keystream block (see [`chacha20_block`]) is appended.
/// - For any other length, an error is logged through `tracing` and `output`
///   is left empty; this function does not panic on bad input.
pub fn cpu(input: &[u8], output: &mut Vec<u8>) {
    output.clear();
    match input.len() {
        64 => output.extend_from_slice(&chacha20_block(input)),
        other => {
            tracing::error!(
                "ChaCha20 block input was {} bytes. Fix: pass exactly one 64-byte ChaCha20 state.",
                other
            );
        }
    }
}

/// Build a 64-byte ChaCha20 initial state from key, nonce, and counter.
///
/// Byte layout of the returned state:
///
/// | bytes    | contents                                   |
/// |----------|--------------------------------------------|
/// | `0..16`  | the ASCII constant `"expand 32-byte k"`    |
/// | `16..48` | `key`, copied verbatim                     |
/// | `48..52` | `counter`, little-endian                   |
/// | `52..64` | `nonce`, copied verbatim                   |
///
/// # Panics
///
/// Panics if `key.len() != 32` or `nonce.len() != 12`.
#[must_use]
pub fn chacha20_state_from_key_nonce_counter(
    key: &[u8],
    nonce: &[u8],
    counter: u32,
) -> [u8; 64] {
    assert_eq!(key.len(), 32, "Fix: ChaCha20 key must be exactly 32 bytes");
    assert_eq!(
        nonce.len(),
        12,
        "Fix: ChaCha20 nonce must be exactly 12 bytes"
    );

    let mut state = [0u8; 64];
    {
        // Carve the state into its four disjoint regions, then fill each one.
        let (sigma, rest) = state.split_at_mut(16);
        let (key_region, rest) = rest.split_at_mut(32);
        let (counter_region, nonce_region) = rest.split_at_mut(4);

        // These ASCII bytes equal the little-endian encoding of the words
        // 0x61707865, 0x3320646e, 0x79622d32, 0x6b206574.
        sigma.copy_from_slice(b"expand 32-byte k");
        key_region.copy_from_slice(key);
        counter_region.copy_from_slice(&counter.to_le_bytes());
        nonce_region.copy_from_slice(nonce);
    }
    state
}

/// Compute one ChaCha20 block from a 64-byte initial state.
///
/// `state` must be exactly 64 bytes (16 little-endian u32 words).
/// Returns a 64-byte keystream block (16 little-endian u32 words).
///
/// # Panics
///
/// Panics if `state.len() != 64`.
#[must_use]
pub fn chacha20_block(state: &[u8]) -> [u8; 64] {
    assert_eq!(
        state.len(),
        64,
        "Fix: ChaCha20 state must be exactly 64 bytes"
    );

    let mut words = [0u32; 16];
    for i in 0..16 {
        words[i] = u32::from_le_bytes([
            state[i * 4],
            state[i * 4 + 1],
            state[i * 4 + 2],
            state[i * 4 + 3],
        ]);
    }

    let original = words;

    for _ in 0..10 {
        // Column rounds
        quarter_round(&mut words, 0, 4, 8, 12);
        quarter_round(&mut words, 1, 5, 9, 13);
        quarter_round(&mut words, 2, 6, 10, 14);
        quarter_round(&mut words, 3, 7, 11, 15);
        // Diagonal rounds
        quarter_round(&mut words, 0, 5, 10, 15);
        quarter_round(&mut words, 1, 6, 11, 12);
        quarter_round(&mut words, 2, 7, 8, 13);
        quarter_round(&mut words, 3, 4, 9, 14);
    }

    for i in 0..16 {
        words[i] = words[i].wrapping_add(original[i]);
    }

    let mut out = [0u8; 64];
    for i in 0..16 {
        out[i * 4..i * 4 + 4].copy_from_slice(&words[i].to_le_bytes());
    }
    out
}

/// ChaCha20 quarter round on a 16-word state.
///
/// Mixes `s[a]`, `s[b]`, `s[c]` and `s[d]` in place with the ChaCha20
/// add-xor-rotate sequence (rotations 16, 12, 8, 7). Every other word is left
/// untouched.
///
/// # Panics
///
/// Panics if any of `a`, `b`, `c`, `d` is `>= 16`.
pub fn quarter_round(s: &mut [u32; 16], a: usize, b: usize, c: usize, d: usize) {
    // Each step is: s[x] += s[y] (wrapping); s[z] = (s[z] ^ s[x]) <<< rot.
    // Performed in exactly the order the specification lists them.
    let steps = [(a, b, d, 16), (c, d, b, 12), (a, b, d, 8), (c, d, b, 7)];
    for (x, y, z, rot) in steps.iter().copied() {
        s[x] = s[x].wrapping_add(s[y]);
        s[z] = (s[z] ^ s[x]).rotate_left(rot);
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Key used by RFC 8439 §2.3.2 and §2.4.2: the bytes `0x00..=0x1f` in order.
    const SEQUENTIAL_KEY: [u8; 32] = [
        0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d,
        0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b,
        0x1c, 0x1d, 0x1e, 0x1f,
    ];

    #[test]
    fn rfc8439_test_vector_1_all_zeros() {
        let key = [0u8; 32];
        let nonce = [0u8; 12];
        let state = chacha20_state_from_key_nonce_counter(&key, &nonce, 0);
        let block = chacha20_block(&state);

        let expected: [u8; 64] = [
            0x76, 0xb8, 0xe0, 0xad, 0xa0, 0xf1, 0x3d, 0x90, 0x40, 0x5d, 0x6a, 0xe5, 0x53, 0x86,
            0xbd, 0x28, 0xbd, 0xd2, 0x19, 0xb8, 0xa0, 0x8d, 0xed, 0x1a, 0xa8, 0x36, 0xef, 0xcc,
            0x8b, 0x77, 0x0d, 0xc7, 0xda, 0x41, 0x59, 0x7c, 0x51, 0x57, 0x48, 0x8d, 0x77, 0x24,
            0xe0, 0x3f, 0xb8, 0xd8, 0x4a, 0x37, 0x6a, 0x43, 0xb8, 0xf4, 0x15, 0x18, 0xa1, 0x1c,
            0xc3, 0x87, 0xb6, 0x69, 0xb2, 0xee, 0x65, 0x86,
        ];
        assert_eq!(block, expected, "RFC 8439 Test Vector #1 mismatch");
    }

    /// RFC 8439 §2.3.2: full serialized output of the block function.
    #[test]
    fn rfc8439_block_function_test_vector() {
        let nonce: [u8; 12] = [
            0x00, 0x00, 0x00, 0x09, 0x00, 0x00, 0x00, 0x4a, 0x00, 0x00, 0x00, 0x00,
        ];
        let state = chacha20_state_from_key_nonce_counter(&SEQUENTIAL_KEY, &nonce, 1);
        let block = chacha20_block(&state);

        let expected: [u8; 64] = [
            0x10, 0xf1, 0xe7, 0xe4, 0xd1, 0x3b, 0x59, 0x15, 0x50, 0x0f, 0xdd, 0x1f, 0xa3, 0x20,
            0x71, 0xc4, 0xc7, 0xd1, 0xf4, 0xc7, 0x33, 0xc0, 0x68, 0x03, 0x04, 0x22, 0xaa, 0x9a,
            0xc3, 0xd4, 0x6c, 0x4e, 0xd2, 0x82, 0x64, 0x46, 0x07, 0x9f, 0xaa, 0x09, 0x14, 0xc2,
            0xd7, 0x05, 0xd9, 0x8b, 0x02, 0xa2, 0xb5, 0x12, 0x9c, 0xd1, 0xde, 0x16, 0x4e, 0xb9,
            0xcb, 0xd0, 0x83, 0xe8, 0xa2, 0x50, 0x3c, 0x4e,
        ];
        assert_eq!(block, expected, "RFC 8439 §2.3.2 block function mismatch");
    }

    #[test]
    fn rfc8439_test_vector_sunscreen_block_1() {
        let nonce: [u8; 12] = [
            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x4a, 0x00, 0x00, 0x00, 0x00,
        ];
        let state = chacha20_state_from_key_nonce_counter(&SEQUENTIAL_KEY, &nonce, 1);
        let block = chacha20_block(&state);

        let expected_first_16_bytes: [u8; 16] = [
            0x22, 0x4f, 0x51, 0xf3, 0x40, 0x1b, 0xd9, 0xe1, 0x2f, 0xde, 0x27, 0x6f, 0xb8, 0x63,
            0x1d, 0xed,
        ];
        assert_eq!(
            block[0..16],
            expected_first_16_bytes,
            "RFC 8439 Sunscreen block 1 mismatch"
        );
    }

    /// RFC 8439 §2.3.2: initial state words produced by the state builder.
    #[test]
    fn state_builder_matches_rfc8439_layout() {
        let nonce: [u8; 12] = [
            0x00, 0x00, 0x00, 0x09, 0x00, 0x00, 0x00, 0x4a, 0x00, 0x00, 0x00, 0x00,
        ];
        let state = chacha20_state_from_key_nonce_counter(&SEQUENTIAL_KEY, &nonce, 1);

        let expected_words: [u32; 16] = [
            0x6170_7865, 0x3320_646e, 0x7962_2d32, 0x6b20_6574,
            0x0302_0100, 0x0706_0504, 0x0b0a_0908, 0x0f0e_0d0c,
            0x1312_1110, 0x1716_1514, 0x1b1a_1918, 0x1f1e_1d1c,
            0x0000_0001, 0x0900_0000, 0x4a00_0000, 0x0000_0000,
        ];
        for (i, (bytes, want)) in state.chunks_exact(4).zip(expected_words.iter()).enumerate() {
            let got = u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]);
            assert_eq!(got, *want, "state word {} mismatch", i);
        }
    }

    /// RFC 8439 §2.1.1 quarter-round test vector.
    #[test]
    fn rfc8439_quarter_round_test_vector() {
        let mut s = [0u32; 16];
        s[0] = 0x1111_1111;
        s[1] = 0x0102_0304;
        s[2] = 0x9b8d_6f43;
        s[3] = 0x0123_4567;
        quarter_round(&mut s, 0, 1, 2, 3);
        assert_eq!(s[0], 0xea2a_92f4);
        assert_eq!(s[1], 0xcb1c_f8ce);
        assert_eq!(s[2], 0x4581_472e);
        assert_eq!(s[3], 0x5881_c4bb);
        // Words outside the four selected indices must be untouched.
        assert!(s[4..].iter().all(|&w| w == 0));
    }

    #[test]
    fn spec_is_intrinsic() {
        assert_eq!(SPEC.id(), "crypto.chacha20_block");
        assert!(matches!(SPEC.category(), crate::ops::Category::C { .. }));
    }

    #[test]
    fn cpu_adapter_matches_block_function() {
        let state = chacha20_state_from_key_nonce_counter(&SEQUENTIAL_KEY, &[0u8; 12], 7);
        // Stale contents must be discarded by the adapter.
        let mut output = vec![0xAA_u8; 3];
        cpu(&state, &mut output);
        assert_eq!(output, chacha20_block(&state).to_vec());
    }

    #[test]
    fn cpu_adapter_wrong_length_yields_empty_output() {
        let mut output = vec![0xAA_u8; 3];
        cpu(&[0u8; 63], &mut output);
        assert!(output.is_empty());

        let mut output = vec![0xAA_u8; 3];
        cpu(&[0u8; 65], &mut output);
        assert!(output.is_empty());
    }

    #[test]
    #[should_panic(expected = "Fix: ChaCha20 state must be exactly 64 bytes")]
    fn panics_on_short_state() {
        let _ = chacha20_block(&[0u8; 63]);
    }

    #[test]
    #[should_panic(expected = "Fix: ChaCha20 state must be exactly 64 bytes")]
    fn panics_on_long_state() {
        let _ = chacha20_block(&[0u8; 65]);
    }

    #[test]
    #[should_panic(expected = "Fix: ChaCha20 key must be exactly 32 bytes")]
    fn state_builder_panics_on_short_key() {
        let _ = chacha20_state_from_key_nonce_counter(&[0u8; 31], &[0u8; 12], 0);
    }

    #[test]
    #[should_panic(expected = "Fix: ChaCha20 nonce must be exactly 12 bytes")]
    fn state_builder_panics_on_short_nonce() {
        let _ = chacha20_state_from_key_nonce_counter(&[0u8; 32], &[0u8; 11], 0);
    }
}