// solana-shake256 0.1.0
//
// A `no_std`, SVM-optimized SHAKE256 library.
//! SHAKE256 (FIPS 202) — hand-rolled, `no_std`, zero dependencies, tuned
//! for the Solana SBF target.
//!
//! This is the single source of the Keccak-f[1600] core shared by
//! [`solana-hawk512`] and [`solana-falcon512`] (the permutation,
//! `absorb`/`finalize` and SHAKE256 padding were byte-identical in both;
//! consolidating here keeps two consensus-critical verifiers provably
//! running the same primitive). The two output styles each verifier needs
//! are both exposed: the bulk **rate-draining** path
//! ([`Shake256::rate_lanes`] + [`Shake256::permute`], used by Falcon's
//! `hash_to_point` rejection sampling) and a fixed-length
//! [`Shake256::squeeze`] (used by HAWK's `hpub`/`M`/`h`).
//!
//! The `keccak_f1600` core uses Bertoni **lane-complementing** (the 6-lane
//! Keccak-Team set `{1,2,8,12,17,20}`, pre-/post-complemented once per
//! permute so ~456 NOTs are eliminated across 24 rounds) fused with an
//! **in-place chi-row + 10 cell-saves** layout (no `B[25]` scratch).
//!
//! [`solana-hawk512`]: https://github.com/blueshift-gg/solana-hawk512
//! [`solana-falcon512`]: https://github.com/blueshift-gg/solana-falcon512

#![no_std]

/// Iota round constants for Keccak-f[1600] (FIPS 202 §3.2.5), one per round.
/// Round `i` XORs `RC[i]` into lane 0 only.
const RC: [u64; 24] = [
    0x0000_0000_0000_0001,
    0x0000_0000_0000_8082,
    0x8000_0000_0000_808a,
    0x8000_0000_8000_8000,
    0x0000_0000_0000_808b,
    0x0000_0000_8000_0001,
    0x8000_0000_8000_8081,
    0x8000_0000_0000_8009,
    0x0000_0000_0000_008a,
    0x0000_0000_0000_0088,
    0x0000_0000_8000_8009,
    0x0000_0000_8000_000a,
    0x0000_0000_8000_808b,
    0x8000_0000_0000_008b,
    0x8000_0000_0000_8089,
    0x8000_0000_0000_8003,
    0x8000_0000_0000_8002,
    0x8000_0000_0000_0080,
    0x0000_0000_0000_800a,
    0x8000_0000_8000_000a,
    0x8000_0000_8000_8081,
    0x8000_0000_0000_8080,
    0x0000_0000_8000_0001,
    0x8000_0000_8000_8008,
];

/// Keccak-f[1600]: apply all 24 rounds of the permutation to `s` in place.
///
/// Lane `(x, y)` lives at `s[x + 5*y]`: the theta column parities below XOR
/// lanes `x, x+5, x+10, x+15, x+20`. The caller passes in, and gets back, the
/// ordinary (uncomplemented) state. The lane-complemented representation only
/// exists between the entry and exit complement passes inside this function.
fn keccak_f1600(s: &mut [u64; 25]) {
    // **Bertoni lane-complementing + chi-row** layout.
    //
    // Pre-complement the canonical 6-lane Keccak Team set
    //     CS = {1, 2, 8, 12, 17, 20}
    // chosen so that across one full round (theta+rho+pi+chi+iota), the
    // complementation pattern is invariant. Per-row IN-complemented b's at
    // post-pi positions (derived from theta+rho+pi propagation):
    //     row 0: b0, b2, b3   row 1: b0, b2     row 2: b0, b2
    //     row 3: b1, b3, b4   row 4: b0, b3
    // Per-row OUT-complemented (must store ~A_logical_new):
    //     row 0: x=1, x=2     row 1: x=3        row 2: x=2
    //     row 3: x=2          row 4: x=0
    // Net ~456 NOTs eliminated per 24-round permute, ~12 added at boundaries.
    //
    // Each row's five chi outputs below are complement-adjusted rewrites of
    // the textbook `b[x] ^ (!b[x+1] & b[x+2])`. Only six explicit `!`
    // remain per round (vs. 25 in the plain form).

    // Entry: complement the 6 CS lanes once.
    s[1] = !s[1];
    s[2] = !s[2];
    s[8] = !s[8];
    s[12] = !s[12];
    s[17] = !s[17];
    s[20] = !s[20];

    // One full round. `$rc` is that round's iota constant, fused into lane 0.
    // Each `bK` is the theta+rho+pi image of the lane it reads; the rotation
    // amounts are the FIPS 202 rho offsets for that source lane.
    macro_rules! round {
        ($rc:expr) => {{
            // theta — column parities
            let c0 = s[0] ^ s[5] ^ s[10] ^ s[15] ^ s[20];
            let c1 = s[1] ^ s[6] ^ s[11] ^ s[16] ^ s[21];
            let c2 = s[2] ^ s[7] ^ s[12] ^ s[17] ^ s[22];
            let c3 = s[3] ^ s[8] ^ s[13] ^ s[18] ^ s[23];
            let c4 = s[4] ^ s[9] ^ s[14] ^ s[19] ^ s[24];

            // theta — per-column mix value D[x] = C[x-1] ^ rotl(C[x+1], 1)
            let d0 = c4 ^ c1.rotate_left(1);
            let d1 = c0 ^ c2.rotate_left(1);
            let d2 = c1 ^ c3.rotate_left(1);
            let d3 = c2 ^ c4.rotate_left(1);
            let d4 = c3 ^ c0.rotate_left(1);

            // **In-place chi-row + 10 cell-saves**.
            // Row 0 outputs to s[0..5]; rows 1..4 read s[3], s[1], s[4], s[2]
            // from this range — save before overwriting.
            let s3 = s[3];
            let s1 = s[1];
            let s4 = s[4];
            let s2 = s[2];

            // Row 0 — IN: b0,b2,b3 complemented; OUT-complement: x=1,2.
            // Iota fused into lane 0.
            {
                let b0 = s[0] ^ d0;
                let b1 = (s[6] ^ d1).rotate_left(44);
                let b2 = (s[12] ^ d2).rotate_left(43);
                let b3 = (s[18] ^ d3).rotate_left(21);
                let b4 = (s[24] ^ d4).rotate_left(14);
                s[0] = b0 ^ (b1 | b2) ^ $rc;
                s[1] = b1 ^ ((!b2) | b3);
                s[2] = b2 ^ (b3 & b4);
                s[3] = b3 ^ (b4 | b0);
                s[4] = b4 ^ (b0 & b1);
            }

            // Row 1 outputs to s[5..10]; rows 2..4 read s[7], s[5], s[8].
            let s7 = s[7];
            let s5 = s[5];
            let s8 = s[8];

            // Row 1 — IN: b0,b2 complemented; OUT-complement: x=3.
            {
                let b0 = (s3 ^ d3).rotate_left(28);
                let b1 = (s[9] ^ d4).rotate_left(20);
                let b2 = (s[10] ^ d0).rotate_left(3);
                let b3 = (s[16] ^ d1).rotate_left(45);
                let b4 = (s[22] ^ d2).rotate_left(61);
                s[5] = b0 ^ (b1 | b2);
                s[6] = b1 ^ (b2 & b3);
                // Equivalent to `b2 ^ (b3 | !b4)`.
                s[7] = (!b2) ^ b4 ^ (b3 & b4);
                s[8] = b3 ^ (b4 | b0);
                s[9] = b4 ^ (b0 & b1);
            }

            // Row 2 outputs to s[10..15]; rows 3..4 read s[11], s[14].
            let s11 = s[11];
            let s14 = s[14];

            // Row 2 — IN: b0,b2 complemented; OUT-complement: x=2.
            {
                let b0 = (s1 ^ d1).rotate_left(1);
                let b1 = (s7 ^ d2).rotate_left(6);
                let b2 = (s[13] ^ d3).rotate_left(25);
                let b3 = (s[19] ^ d4).rotate_left(8);
                let b4 = (s[20] ^ d0).rotate_left(18);
                s[10] = b0 ^ (b1 | b2);
                s[11] = b1 ^ (b2 & b3);
                // Equivalent to `b2 ^ (!b3 & b4)`.
                s[12] = b2 ^ b4 ^ (b3 & b4);
                s[13] = b3 ^ !(b4 | b0);
                s[14] = b4 ^ (b0 & b1);
            }

            // Row 3 outputs to s[15..20]; row 4 reads s[15].
            let s15 = s[15];

            // Row 3 — IN: b1,b3,b4 complemented; OUT-complement: x=2.
            {
                let b0 = (s4 ^ d4).rotate_left(27);
                let b1 = (s5 ^ d0).rotate_left(36);
                let b2 = (s11 ^ d1).rotate_left(10);
                let b3 = (s[17] ^ d2).rotate_left(15);
                let b4 = (s[23] ^ d3).rotate_left(56);
                s[15] = b0 ^ (b1 & b2);
                s[16] = b1 ^ (b2 | b3);
                s[17] = b2 ^ ((!b3) | b4);
                s[18] = (!b3) ^ (b4 & b0);
                s[19] = b4 ^ (b0 | b1);
            }

            // Row 4 — IN: b0,b3 complemented; OUT-complement: x=0.
            // Last row: every lane it still needs from rows 0..3 was saved above.
            {
                let b0 = (s2 ^ d2).rotate_left(62);
                let b1 = (s8 ^ d3).rotate_left(55);
                let b2 = (s14 ^ d4).rotate_left(39);
                let b3 = (s15 ^ d0).rotate_left(41);
                let b4 = (s[21] ^ d1).rotate_left(2);
                // Equivalent to `b0 ^ (!b1 & b2)`.
                s[20] = b0 ^ b2 ^ (b1 & b2);
                s[21] = b1 ^ !(b2 | b3);
                s[22] = b2 ^ (b3 & b4);
                s[23] = b3 ^ (b4 | b0);
                s[24] = b4 ^ (b0 & b1);
            }
        }};
    }

    // Fully unrolled: one macro expansion per round, RC[0]..RC[23] in order.
    round!(RC[0]);
    round!(RC[1]);
    round!(RC[2]);
    round!(RC[3]);
    round!(RC[4]);
    round!(RC[5]);
    round!(RC[6]);
    round!(RC[7]);
    round!(RC[8]);
    round!(RC[9]);
    round!(RC[10]);
    round!(RC[11]);
    round!(RC[12]);
    round!(RC[13]);
    round!(RC[14]);
    round!(RC[15]);
    round!(RC[16]);
    round!(RC[17]);
    round!(RC[18]);
    round!(RC[19]);
    round!(RC[20]);
    round!(RC[21]);
    round!(RC[22]);
    round!(RC[23]);

    // Exit: un-complement the 6 CS lanes so the caller sees the normal
    // (uncomplemented) state. Cost paid once per permute.
    s[1] = !s[1];
    s[2] = !s[2];
    s[8] = !s[8];
    s[12] = !s[12];
    s[17] = !s[17];
    s[20] = !s[20];
}

/// SHAKE256 rate in bytes: the 1600-bit Keccak state minus the 512-bit
/// (2 · 256) capacity leaves 1088 bits, i.e. 136 bytes = 17 lanes per block.
pub const RATE: usize = (1600 - 2 * 256) / 8;

/// Incremental SHAKE256 (FIPS 202) sponge.
///
/// Call order: `new` → `absorb`* → `finalize`. After that, either drain the
/// rate block by block (`rate_lanes` + `permute`) or `squeeze` a fixed number
/// of bytes. Nothing enforces this order at runtime. Absorbing after
/// `finalize`, or finalizing twice, silently produces a stream that is not
/// SHAKE256.
///
/// Methods are `#[inline]`, and `absorb`/`finalize` are `#[inline(always)]`.
/// The consumer (built `lto`/`opt-level=3` for SBF) therefore folds the whole
/// thing in exactly as if it were a local module; the crate boundary has no
/// codegen cost.
///
/// `Clone` lets a caller absorb a shared prefix once and then fork the sponge
/// for several suffixes.
#[derive(Clone)]
pub struct Shake256 {
    /// Keccak-f[1600] state, lane `(x, y)` at index `x + 5*y`. It is always
    /// stored uncomplemented between method calls.
    state: [u64; 25],
    /// Byte offset of the next byte to absorb or squeeze within the current
    /// rate block. The methods keep it in `0..RATE`.
    pos: usize,
}

impl Default for Shake256 {
    #[inline]
    fn default() -> Self {
        Self::new()
    }
}

impl Shake256 {
    #[inline]
    pub fn new() -> Self {
        Self {
            state: [0; 25],
            pos: 0,
        }
    }

    #[inline(always)]
    pub fn absorb(&mut self, data: &[u8]) {
        let mut i = 0;
        let len = data.len();

        // Phase 1: byte-by-byte until lane-aligned.
        while i < len && !self.pos.is_multiple_of(8) {
            let lane = self.pos / 8;
            let shift = 8 * (self.pos % 8);
            self.state[lane] ^= (data[i] as u64) << shift;
            self.pos += 1;
            if self.pos == RATE {
                keccak_f1600(&mut self.state);
                self.pos = 0;
            }
            i += 1;
        }

        // Phase 2: bulk 8-byte chunks XORed straight into a lane. Bytes within
        // a lane are little-endian per FIPS 202, so `from_le_bytes` is the
        // correct assembly.
        while i + 8 <= len {
            // SAFETY: phase 1 made `self.pos` lane-aligned (multiple of 8),
            // and `pos < RATE = 136 = 17 * 8`, so `pos / 8 < 17 < 25`.
            unsafe { core::hint::assert_unchecked(self.pos / 8 < 17) };
            let chunk_bytes: [u8; 8] = data[i..i + 8].try_into().unwrap();
            let chunk = u64::from_le_bytes(chunk_bytes);
            self.state[self.pos / 8] ^= chunk;
            self.pos += 8;
            i += 8;
            if self.pos == RATE {
                keccak_f1600(&mut self.state);
                self.pos = 0;
            }
        }

        // Phase 3: tail bytes (< 8 left).
        while i < len {
            let lane = self.pos / 8;
            let shift = 8 * (self.pos % 8);
            self.state[lane] ^= (data[i] as u64) << shift;
            self.pos += 1;
            if self.pos == RATE {
                keccak_f1600(&mut self.state);
                self.pos = 0;
            }
            i += 1;
        }
    }

    #[inline(always)]
    pub fn finalize(&mut self) {
        let lane = self.pos / 8;
        let shift = 8 * (self.pos % 8);
        self.state[lane] ^= 0x1Fu64 << shift;
        let last = RATE - 1;
        self.state[last / 8] ^= 0x80u64 << (8 * (last % 8));
        keccak_f1600(&mut self.state);
        self.pos = 0;
    }

    /// First 17 u64 lanes (= the 136-byte rate). Bytes within each lane are
    /// little-endian per FIPS 202: byte at offset `b` of lane `l` is
    /// `(rate_lanes()[l] >> (8*b)) & 0xff`. Drain this, then call
    /// [`permute`](Self::permute) for the next block — the bulk-rate squeeze
    /// path (e.g. Falcon's `hash_to_point` rejection sampling).
    #[inline]
    pub fn rate_lanes(&self) -> &[u64] {
        &self.state[..17]
    }

    /// Apply Keccak-f[1600] to refill the rate (used with
    /// [`rate_lanes`](Self::rate_lanes)).
    #[inline]
    pub fn permute(&mut self) {
        keccak_f1600(&mut self.state);
    }

    /// Squeeze exactly `LEN` bytes, handling rate-boundary permutes. `LEN`
    /// is const so the loop bounds and the rate-boundary check fold and the
    /// bulk lane copy fully unrolls (the fixed-length path, e.g. HAWK's
    /// `hpub`/`M`/`h`). Bytes within a lane are little-endian per FIPS 202,
    /// so a lane-aligned run of ≥ 8 bytes is one `to_le_bytes` copy (the
    /// same bytes as eight `state[lane] >> 8k` reads).
    #[inline]
    pub fn squeeze<const LEN: usize>(&mut self, out: &mut [u8; LEN]) {
        let len = LEN;
        let mut i = 0;

        // Byte-by-byte until lane-aligned.
        while i < len && !self.pos.is_multiple_of(8) {
            let lane = self.pos / 8;
            let shift = 8 * (self.pos % 8);
            out[i] = (self.state[lane] >> shift) as u8;
            self.pos += 1;
            i += 1;
            if self.pos == RATE {
                keccak_f1600(&mut self.state);
                self.pos = 0;
            }
        }

        // Bulk 8-byte lanes (no rate boundary lands mid-lane: RATE = 17·8).
        while i + 8 <= len && self.pos + 8 <= RATE {
            // SAFETY: `pos` lane-aligned, `pos + 8 ≤ RATE = 136` ⇒
            // `pos/8 ≤ 16 < 17 < 25`.
            unsafe { core::hint::assert_unchecked(self.pos / 8 < 17) };
            out[i..i + 8].copy_from_slice(&self.state[self.pos / 8].to_le_bytes());
            self.pos += 8;
            i += 8;
            if self.pos == RATE {
                keccak_f1600(&mut self.state);
                self.pos = 0;
            }
        }

        // Tail bytes (< 8 left, or a partial lane before a rate boundary).
        while i < len {
            let lane = self.pos / 8;
            let shift = 8 * (self.pos % 8);
            out[i] = (self.state[lane] >> shift) as u8;
            self.pos += 1;
            i += 1;
            if self.pos == RATE {
                keccak_f1600(&mut self.state);
                self.pos = 0;
            }
        }
    }
}