iota-crypto 0.15.3

The canonical source of cryptographic ground-truth for all IOTA Rust projects
Documentation
// Copyright 2020-2021 IOTA Stiftung
// SPDX-License-Identifier: Apache-2.0

//! A batched version of the `CurlP` hash.

mod bct;
mod bct_curlp;

use alloc::vec::Vec;

use bct::{BcTrit, BcTritArr, BcTritBuf};
use bct_curlp::BctCurlP;

use crate::{
    encoding::ternary::{
        raw::{RawEncoding, RawEncodingBuf},
        Btrit, TritBuf,
    },
    hashes::ternary::HASH_LENGTH,
};

/// The number of inputs that can be processed in a single batch.
pub const BATCH_SIZE: usize = 8 * core::mem::size_of::<usize>();
const HIGH_BITS: usize = usize::max_value();
const NUM_ROUNDS: usize = 81;

/// A hasher that can process several inputs at the same time in batches.
///
/// This hasher works by interleaving the trits of the inputs in each batch and hashing this
/// interleaved representation. It is also able to fall back to the regular CurlP algorithm if
/// required.
pub struct CurlPBatchHasher<B: RawEncodingBuf> {
    /// The trits of the inputs before being interleaved.
    trit_inputs: Vec<TritBuf<B>>,
    /// An interleaved representation of the input trits.
    bct_inputs: BcTritBuf,
    /// An interleaved representation of the output trits.
    bct_hashes: BcTritArr<HASH_LENGTH>,
    /// A buffer for demultiplexing.
    buf_demux: TritBuf,
    /// The CurlP hasher for binary coded trits.
    bct_curlp: BctCurlP,
}

impl<B> CurlPBatchHasher<B>
where
    B: RawEncodingBuf,
    B::Slice: RawEncoding<Trit = Btrit>,
{
    /// Create a new hasher.
    ///
    /// It requires the length of the input, and the length of the output hash.
    pub fn new(input_length: usize) -> Self {
        Self {
            trit_inputs: Vec::with_capacity(BATCH_SIZE),
            bct_inputs: BcTritBuf::zeros(input_length),
            bct_hashes: BcTritArr::<HASH_LENGTH>::zeros(),
            buf_demux: TritBuf::zeros(HASH_LENGTH),
            bct_curlp: BctCurlP::new(),
        }
    }

    /// Add a new input to the batch.
    ///
    /// It panics if the size of the batch exceeds `BATCH_SIZE` or if `input.len()` is not equal to
    /// the `input_length` parameter of the constructor.
    pub fn add(&mut self, input: TritBuf<B>) {
        assert!(self.trit_inputs.len() < BATCH_SIZE, "Batch is full.");
        assert_eq!(input.len(), self.bct_inputs.len(), "Input has an incorrect size.");
        self.trit_inputs.push(input);
    }

    /// Return the length of the current batch.
    pub fn len(&self) -> usize {
        self.trit_inputs.len()
    }

    /// Check if the current batch is empty.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Multiplex or interleave the input trits in the bash.
    ///
    /// Before doing the actual interleaving, each trit is encoded as two bits which are usually
    /// refered as the low and high bit.
    ///
    /// | Trit | Low bit | High bit |
    /// |------|---------|----------|
    /// |  -1  |    1    |     0    |
    /// |   0  |    0    |     1    |
    /// |   1  |    1    |     1    |
    ///
    /// Then the low and high bits are interleaved into two vectors of integers. Each integer has
    /// size `BATCH_SIZE` and there are `input_length` integers in each vector.  This means that
    /// the low and high bits of the transaction number `N` in the batch are stored in the position
    /// `N` of each integer.
    ///
    /// This step works correctly even if there are less than `BATCH_SIZE` inputs.
    fn mux(&mut self) {
        let count = self.trit_inputs.len();
        for i in 0..self.bct_inputs.len() {
            // This is safe because `i < self.bct_inputs.len()`.
            let BcTrit(lo, hi) = unsafe { self.bct_inputs.get_unchecked_mut(i) };

            for j in 0..count {
                // this is safe because `j < self.trit_inputs.len()` and
                // `i < self.trit_inputs[j].len()` (the `add` method guarantees that all the inputs
                // have the same length as `self.trit_inputs`).
                match unsafe { self.trit_inputs.get_unchecked(j).get_unchecked(i) } {
                    Btrit::NegOne => *lo |= 1 << j,
                    Btrit::PlusOne => *hi |= 1 << j,
                    Btrit::Zero => {
                        *lo |= 1 << j;
                        *hi |= 1 << j;
                    }
                }
            }
        }
    }

    /// Demultiplex the bits of the output to obtain the hash of the input with a specific index.
    ///
    /// This is the inverse of the `mux` function, but it is applied over the vector with the
    /// binary encoding of the output hashes. Each pair of low and high bits in the `bct_hashes`
    /// field is decoded into a trit using the same convention as the `mux` step with an additional
    /// rule for the `(0, 0)` pair of bits which is mapped to the `0` trit.
    fn demux(&mut self, index: usize) -> TritBuf {
        for (bc_trit, btrit) in self.bct_hashes.iter().zip(self.buf_demux.iter_mut()) {
            let lo = (bc_trit.lo() >> index) & 1;
            let hi = (bc_trit.hi() >> index) & 1;

            *btrit = match (lo, hi) {
                (1, 0) => Btrit::NegOne,
                (0, 1) => Btrit::PlusOne,
                // This can only be `(0, 0)` or `(1, 1)`.
                _ => Btrit::Zero,
            };
        }

        self.buf_demux.clone()
    }

    /// Hash the received inputs.
    ///
    /// This function also takes care of cleaning the buffers of the struct and resetting the
    /// batched CurlP hasher so it can be called at any time.
    pub fn hash(&mut self) -> impl Iterator<Item = TritBuf> + '_ {
        let total = self.trit_inputs.len();
        // Reset batched CurlP hasher.
        self.bct_curlp.reset();
        // Multiplex the trits in `trits` and dump them into `inputs`
        self.mux();
        // Do the regular sponge steps.
        self.bct_curlp.absorb(&self.bct_inputs);
        self.bct_curlp.squeeze_into(&mut self.bct_hashes);
        // Clear the `trits` buffer to allow receiving a new batch.
        self.trit_inputs.clear();
        // Fill the `inputs` buffer with zeros.
        self.bct_inputs.fill(0);
        // Return an iterator for the output hashes.
        BatchedHashes {
            hasher: self,
            range: 0..total,
        }
    }
}

/// A helper iterator type for the output of the `hash_batched` method.
struct BatchedHashes<'a, B: RawEncodingBuf> {
    hasher: &'a mut CurlPBatchHasher<B>,
    range: core::ops::Range<usize>,
}

impl<'a, B> Iterator for BatchedHashes<'a, B>
where
    B: RawEncodingBuf,
    B::Slice: RawEncoding<Trit = Btrit>,
{
    type Item = TritBuf;

    fn next(&mut self) -> Option<Self::Item> {
        let index = self.range.next()?;
        Some(self.hasher.demux(index))
    }
}