//! k12 0.4.0
//!
//! Implementation of the KangarooTwelve family of extendable-output functions.
//! Documentation: <https://docs.rs/k12>
use crate::{
    consts::{CHUNK_SIZE, INTERMEDIATE_NODE_DS, PAD},
    utils::{copy_cv, xor_block},
};
use digest::{
    array::{Array, ArraySize},
    block_buffer::BlockSizes,
};
use keccak::{Fn1600, State1600};

/// Parallel version of TurboSHAKE specialized for computation of chaining values.
///
/// Ideally, this function should have the following signature:
/// ```rust,ignore
/// fn par_turbo_shake<B: Backend, const RATE: usize>(
///     p1600: ParFn1600<B>,
///     data: &[u8; CHUNK_SIZE * B::PAR_SIZE_1600],
/// ) -> [[u8; {200 - RATE}]; B::PAR_SIZE_1600] { ... }
/// ```
/// But it requires advanced const generics or to deal with annoying `typenum`-based trait bounds,
/// so instead we use "runtime" asserts which should be optimized out by the compiler, see:
/// <https://rust.godbolt.org/z/4Y7ervTd7>
// TODO(MSRV-1.88): use `as_chunks::<CHUNK_SIZE>()`
pub(crate) fn parallel<ParSize: ArraySize, Rate: BlockSizes>(
    par_p1600: fn(&mut Array<State1600, ParSize>),
    data: &[u8],
    par_cv_dst: &mut [u8],
) {
    let rate = Rate::USIZE;
    let cv_size = 200 - rate;
    // Pin the slice lengths up front so the slicing below cannot panic and the
    // compiler can elide the per-access bounds checks.
    assert_eq!(data.len(), CHUNK_SIZE * ParSize::USIZE);
    assert_eq!(par_cv_dst.len(), cv_size * ParSize::USIZE);

    let mut states: Array<State1600, ParSize> = Default::default();
    let full_blocks = CHUNK_SIZE / rate;

    // Absorb the full rate-sized blocks: XOR block `block_idx` of every chunk
    // into the matching state, then permute all states at once.
    for block_idx in 0..full_blocks {
        let offset = block_idx * rate;
        for (state, chunk) in states.iter_mut().zip(data.chunks_exact(CHUNK_SIZE)) {
            xor_block(state, &chunk[offset..offset + rate]);
        }
        par_p1600(&mut states);
    }

    // Absorb the partial tail block of every chunk; since `CHUNK_SIZE` is not a
    // multiple of the rate, the tail is always non-empty.
    let tail_start = full_blocks * rate;
    assert_ne!(CHUNK_SIZE - tail_start, 0);
    for (state, chunk) in states.iter_mut().zip(data.chunks_exact(CHUNK_SIZE)) {
        process_tail_data::<Rate>(state, &chunk[tail_start..]);
    }
    par_p1600(&mut states);

    // Write one chaining value per state into the destination buffer.
    let mut cv_chunks = par_cv_dst.chunks_exact_mut(cv_size);
    for (state, cv_dst) in states.iter_mut().zip(&mut cv_chunks) {
        copy_cv(state, cv_dst);
    }
    assert!(cv_chunks.into_remainder().is_empty());
}

/// Scalar version of TurboSHAKE specialized for computation of chaining values.
///
/// Ideally, this function should have the following signature:
/// ```rust,ignore
/// fn turbo_shake_cv<const RATE: usize>(
///     p1600: Fn1600,
///     data: &[u8; CHUNK_SIZE],
/// ) -> [u8; {200 - RATE}] { ... }
/// ```
pub(crate) fn scalar<Rate: BlockSizes>(p1600: Fn1600, data: &[u8], cv_dst: &mut [u8]) {
    let rate = Rate::USIZE;
    // Pin the slice lengths up front; these asserts should compile away.
    assert_eq!(data.len(), CHUNK_SIZE);
    assert_eq!(cv_dst.len(), 200 - rate);

    let mut state = State1600::default();

    // Absorb every full rate-sized block, permuting after each one.
    let mut blocks = data.chunks_exact(rate);
    for block in blocks.by_ref() {
        xor_block(&mut state, block);
        p1600(&mut state);
    }

    // Pad and absorb the incomplete tail block, then extract the chaining value.
    finalize::<Rate>(p1600, &mut state, blocks.remainder(), cv_dst);
}

/// Finalizes a chaining-value computation: absorbs `tail_data` together with
/// the TurboSHAKE padding into `state`, applies one last permutation, and
/// copies the resulting chaining value into `cv_dst`.
///
/// `tail_data` is the incomplete final block (its length constraints are
/// debug-checked in `process_tail_data`); callers are responsible for sizing
/// `cv_dst` to `200 - Rate` bytes.
pub(crate) fn finalize<Rate: BlockSizes>(
    p1600: Fn1600,
    state: &mut State1600,
    tail_data: &[u8],
    cv_dst: &mut [u8],
) {
    // XOR the tail into the state and apply the domain-separation/pad bytes.
    process_tail_data::<Rate>(state, tail_data);
    // Final permutation before squeezing.
    p1600(state);
    // Copy the chaining value out of the permuted state.
    copy_cv(state, cv_dst);
}

/// Absorbs the incomplete tail block into `state` and applies the TurboSHAKE
/// padding: the intermediate-node domain separator right after the data, and
/// the pad byte at the end of the block.
fn process_tail_data<Rate: BlockSizes>(state: &mut State1600, tail_data: &[u8]) {
    const WORD_SIZE: usize = size_of::<u64>();

    debug_assert_eq!(
        tail_data.len(),
        CHUNK_SIZE % Rate::USIZE,
        "tail_data has unexpected length",
    );
    // The XORs below work on whole 64-bit words, so the tail must be word-aligned.
    debug_assert_eq!(tail_data.len() % WORD_SIZE, 0);

    xor_block(state, tail_data);

    // Apply padding by XORing the state. Note that we use little endian byte
    // order: the domain separator lands in the lowest byte of the word right
    // after the tail data, and the pad byte (shifted by 56) lands in the
    // highest byte of the last word of the rate-sized block.
    let ds_word = tail_data.len() / WORD_SIZE;
    let pad_word = Rate::USIZE / WORD_SIZE - 1;
    state[ds_word] ^= u64::from(INTERMEDIATE_NODE_DS);
    state[pad_word] ^= u64::from(PAD) << 56;
}

/// Test vectors are generated by the `turbo-shake` crate.
#[cfg(test)]
mod tests {
    use super::{parallel, scalar};
    use crate::consts::{CHUNK_SIZE, ROUNDS};
    use digest::array::typenum::{U136, U168, Unsigned};
    use keccak::{Backend, BackendClosure};

    // Number of input chunks covered by the reference test vectors.
    const CHUNKS: usize = 32;
    // Chaining-value lengths: 200 - rate, with rate = 168 (KT128) and 136 (KT256).
    const KT128_CV_LEN: usize = 32;
    const KT256_CV_LEN: usize = 64;
    // Total lengths of the concatenated chaining values over all chunks.
    const KT128_CVS_LEN: usize = KT128_CV_LEN * CHUNKS;
    const KT256_CVS_LEN: usize = KT256_CV_LEN * CHUNKS;

    // Deterministic test input built in a `const` context (hence the `while`
    // loops): the byte at offset `j` of chunk `i` is `(i + j) as u8`.
    const DATA: &[u8] = &{
        let mut buf = [0u8; CHUNKS * CHUNK_SIZE];
        let mut i = 0;
        while i < CHUNKS {
            let mut j = 0;
            while j < CHUNK_SIZE {
                buf[i * CHUNK_SIZE + j] = (i + j) as u8;
                j += 1;
            }
            i += 1;
        }
        buf
    };

    // Expected chaining values for `DATA`, precomputed with `turbo-shake`.
    const KT128_CVS: &[u8; KT128_CVS_LEN] = include_bytes!("../tests/data/kt128_cvs.bin");
    const KT256_CVS: &[u8; KT256_CVS_LEN] = include_bytes!("../tests/data/kt256_cvs.bin");

    // Scalar KT128 path (rate = 168) against the reference CVs.
    #[test]
    fn kt128_cvs() {
        keccak::Keccak::new().with_p1600::<ROUNDS>(|p1600| {
            let mut cvs = [0u8; KT128_CVS_LEN];
            let mut data_chunks = DATA.chunks_exact(CHUNK_SIZE);
            let mut cvs_chunks = cvs.chunks_exact_mut(KT128_CV_LEN);

            for (data_chunk, par_cv) in (&mut data_chunks).zip(&mut cvs_chunks) {
                scalar::<U168>(p1600, data_chunk, par_cv);
            }

            // Both iterators must be fully consumed, i.e. every chunk was processed.
            assert!(data_chunks.remainder().is_empty());
            assert!(cvs_chunks.into_remainder().is_empty());
            assert_eq!(&cvs, KT128_CVS);
        });
    }

    // Scalar KT256 path (rate = 136) against the reference CVs.
    #[test]
    fn kt256_cvs() {
        keccak::Keccak::new().with_p1600::<ROUNDS>(|p1600| {
            let mut cvs = [0u8; KT256_CVS_LEN];
            let mut data_chunks = DATA.chunks_exact(CHUNK_SIZE);
            let mut cvs_chunks = cvs.chunks_exact_mut(KT256_CV_LEN);

            for (data_chunk, par_cv) in (&mut data_chunks).zip(&mut cvs_chunks) {
                scalar::<U136>(p1600, data_chunk, par_cv);
            }

            // Both iterators must be fully consumed, i.e. every chunk was processed.
            assert!(data_chunks.remainder().is_empty());
            assert!(cvs_chunks.into_remainder().is_empty());
            assert_eq!(&cvs, KT256_CVS);
        });
    }

    // Parallel KT128 path: process as many chunks as the backend allows per
    // call, then finish the leftover chunks with the scalar path.
    #[test]
    fn kt128_par_cvs() {
        // NOTE(review): `with_backend` presumably selects a backend and invokes
        // `call_once::<B>()` with it — confirm against the `keccak` crate docs.
        struct Closure;

        impl BackendClosure for Closure {
            fn call_once<B: Backend>(self) {
                let par_p1600 = B::get_par_p1600::<ROUNDS>();
                let p1600 = B::get_p1600::<ROUNDS>();

                let mut cvs = [0u8; KT128_CVS_LEN];

                // Each `parallel` call consumes `par_size` chunks at once.
                let par_size = B::ParSize1600::USIZE;
                let par_data_size = CHUNK_SIZE * par_size;
                let par_cv_size = KT128_CV_LEN * par_size;
                let mut data_chunks = DATA.chunks_exact(par_data_size);
                let mut par_cvs = cvs.chunks_exact_mut(par_cv_size);

                for (data_chunk, par_cv) in (&mut data_chunks).zip(&mut par_cvs) {
                    parallel::<_, U168>(par_p1600, data_chunk, par_cv);
                }

                // Fewer than `par_size` chunks remain; handle them one by one.
                let mut data_chunks = data_chunks.remainder().chunks_exact(CHUNK_SIZE);
                let mut cvs_chunks = par_cvs.into_remainder().chunks_exact_mut(KT128_CV_LEN);

                for (data_chunk, par_cv) in (&mut data_chunks).zip(&mut cvs_chunks) {
                    scalar::<U168>(p1600, data_chunk, par_cv);
                }

                // Every chunk must be accounted for by either path.
                assert!(data_chunks.remainder().is_empty());
                assert!(cvs_chunks.into_remainder().is_empty());
                assert_eq!(&cvs, KT128_CVS);
            }
        }

        keccak::Keccak::new().with_backend(Closure);
    }

    // Parallel KT256 path: same structure as `kt128_par_cvs` with rate 136.
    #[test]
    fn kt256_par_cvs() {
        struct Closure;

        impl BackendClosure for Closure {
            fn call_once<B: Backend>(self) {
                let par_p1600 = B::get_par_p1600::<ROUNDS>();
                let p1600 = B::get_p1600::<ROUNDS>();

                let mut cvs = [0u8; KT256_CVS_LEN];

                // Each `parallel` call consumes `par_size` chunks at once.
                let par_size = B::ParSize1600::USIZE;
                let par_data_size = CHUNK_SIZE * par_size;
                let par_cv_size = KT256_CV_LEN * par_size;
                let mut data_chunks = DATA.chunks_exact(par_data_size);
                let mut par_cvs = cvs.chunks_exact_mut(par_cv_size);

                for (data_chunk, par_cv) in (&mut data_chunks).zip(&mut par_cvs) {
                    parallel::<_, U136>(par_p1600, data_chunk, par_cv);
                }

                // Fewer than `par_size` chunks remain; handle them one by one.
                let mut data_chunks = data_chunks.remainder().chunks_exact(CHUNK_SIZE);
                let mut cvs_chunks = par_cvs.into_remainder().chunks_exact_mut(KT256_CV_LEN);

                for (data_chunk, par_cv) in (&mut data_chunks).zip(&mut cvs_chunks) {
                    scalar::<U136>(p1600, data_chunk, par_cv);
                }

                // Every chunk must be accounted for by either path.
                assert!(data_chunks.remainder().is_empty());
                assert!(cvs_chunks.into_remainder().is_empty());
                assert_eq!(&cvs, KT256_CVS);
            }
        }

        keccak::Keccak::new().with_backend(Closure);
    }
}