use crate::{
core::{
bounds::FieldBounds,
circuits::{
boolean::{
boolean_array::BooleanArray,
boolean_value::{Boolean, BooleanValue},
byte::Byte,
},
traits::arithmetic_circuit::ArithmeticCircuit,
},
expressions::{bit_expr::BitExpr, expr::EvalFailure},
global_value::{global_expr_store::with_global_expr_store_as_local, value::FieldValue},
ir_builder::{ExprStore, IRBuilder},
},
traits::Keccak,
utils::field::BaseField,
};
use core::panic;
use sha3::{Digest, Sha3_256, Sha3_512};
/// A 64-bit Keccak lane stored as 8 little-endian bytes of `Boolean`s.
#[derive(Clone, Copy)]
struct Lane64<B: Boolean>([Byte<B>; 8]);
impl<B: Boolean> Lane64<B> {
    /// Builds a lane from its 8 bytes (byte 0 first).
    pub fn load64(bytes: [Byte<B>; 8]) -> Self {
        Self(bytes)
    }
    /// Returns the lane's 8 bytes (byte 0 first).
    pub fn store64(&self) -> [Byte<B>; 8] {
        self.0
    }
    /// Rotates the lane by `n` bit positions (taken mod 64).
    ///
    /// The 64 bits are flattened in the order produced by `Byte::get_bits`,
    /// byte 0 first, and bit `i` of that sequence moves to `(i + n) % 64`.
    /// NOTE(review): this is a rotate-left of the 64-bit value provided
    /// `get_bits` yields the least-significant bit first — confirm.
    pub fn rol64(&self, n: usize) -> Self {
        let mut bits: Vec<B> = self.0.iter().flat_map(|byte| byte.get_bits()).collect();
        // `rotate_right(k)` moves element `i` to `(i + k) % len`.
        bits.rotate_right(n % 64);
        let mut out = self.0;
        for (slot, chunk) in out.iter_mut().zip(bits.chunks_exact(8)) {
            *slot = Byte::new(
                chunk
                    .try_into()
                    .expect("chunks_exact(8) always yields 8-element slices"),
            );
        }
        Self(out)
    }
    /// Combines `self` and `other` byte by byte with `op`.
    fn zip_bytes(
        &self,
        other: Lane64<B>,
        op: impl Fn(Byte<B>, Byte<B>) -> Byte<B>,
    ) -> Lane64<B> {
        let mut out = self.0;
        for (lhs, rhs) in out.iter_mut().zip(other.0) {
            *lhs = op(*lhs, rhs);
        }
        Self(out)
    }
    /// Bitwise XOR of two lanes.
    pub fn xor(&self, other: Lane64<B>) -> Lane64<B> {
        self.zip_bytes(other, |lhs, rhs| lhs ^ rhs)
    }
    /// Bitwise AND of two lanes.
    pub fn and(&self, other: Lane64<B>) -> Lane64<B> {
        self.zip_bytes(other, |lhs, rhs| lhs & rhs)
    }
    /// Bitwise NOT of the lane.
    pub fn not(&self) -> Lane64<B> {
        Self(self.0.map(|byte| !byte))
    }
}
/// The 5×5 Keccak state as a grid of lanes, indexed `grid[x][y]`.
struct Lanes<B: Boolean>([[Lane64<B>; 5]; 5]);
impl<B: Boolean> Lanes<B> {
    /// Wraps an already-built `[x][y]` lane grid.
    pub fn new(grid: [[Lane64<B>; 5]; 5]) -> Self {
        Lanes(grid)
    }
    /// Returns a copy of the whole lane grid.
    pub fn inner(&self) -> [[Lane64<B>; 5]; 5] {
        let Lanes(grid) = self;
        *grid
    }
    /// Overwrites the lane at column `x`, row `y`.
    pub fn set(&mut self, x: usize, y: usize, lane: Lane64<B>) {
        let column = &mut self.0[x];
        column[y] = lane;
    }
}
/// Keccak-f[1600] permutation written once over any [`Boolean`] type.
///
/// The structure mirrors the compact FIPS 202 reference implementation.
/// Lanes are indexed `lanes[x][y]`, and lane `(x, y)` occupies state bytes
/// `8 * (x + 5 * y)..8 * (x + 5 * y) + 8`.
#[derive(Debug, Clone)]
pub struct KeccakPermutation;
impl KeccakPermutation {
    /// Number of rounds of Keccak-f[1600].
    const ROUNDS: usize = 24;

    /// θ step: computes the column parities `C[x]`, then XORs every lane of
    /// column `x` with `D[x] = C[x - 1] ^ rol(C[x + 1], 1)` (indices mod 5).
    fn theta<B: Boolean>(lanes: &mut Lanes<B>) {
        let c = lanes
            .0
            .map(|column| column[0].xor(column[1]).xor(column[2]).xor(column[3]).xor(column[4]));
        let d = [0usize, 1, 2, 3, 4].map(|x| c[(x + 4) % 5].xor(c[(x + 1) % 5].rol64(1)));
        for (column, d_x) in lanes.0.iter_mut().zip(d) {
            for lane in column.iter_mut() {
                *lane = lane.xor(d_x);
            }
        }
    }

    /// Combined ρ and π steps.
    ///
    /// Starting from lane (1, 0), walks the 24 positions
    /// `(x, y) -> (y, (2x + 3y) % 5)`. Each displaced lane is moved one step
    /// along the walk and rotated by the triangular number
    /// `(t + 1)(t + 2) / 2`.
    fn rho_pi<B: Boolean>(lanes: &mut Lanes<B>) {
        let (mut x, mut y) = (1usize, 0usize);
        let mut current = lanes.0[x][y];
        for t in 0..24usize {
            (x, y) = (y, (2 * x + 3 * y) % 5);
            let displaced = lanes.0[x][y];
            lanes.0[x][y] = current.rol64(((t + 1) * (t + 2)) / 2);
            current = displaced;
        }
    }

    /// χ step: for each row `y`, sets `a[x] ^= !a[x + 1] & a[x + 2]`, using a
    /// snapshot of the row taken before it is modified.
    fn chi<B: Boolean>(lanes: &mut Lanes<B>) {
        for y in 0..5 {
            let row = lanes.0.map(|column| column[y]);
            for x in 0..5 {
                lanes.0[x][y] = row[x].xor(row[(x + 1) % 5].not().and(row[(x + 2) % 5]));
            }
        }
    }

    /// ι step: XORs this round's constant into lane (0, 0).
    ///
    /// The constant's set bits (positions `2^j - 1`) come from the LFSR state
    /// `r` (feedback 0x71). `r` is carried across rounds, so the caller must
    /// pass the same `r` to every round in order, starting at 1.
    fn iota<B: Boolean>(lanes: &mut Lanes<B>, r: &mut u16) {
        for j in 0..7u32 {
            *r = ((*r << 1) ^ ((*r >> 7) * 0x71)) % 256;
            if (*r >> 1) & 1 == 1 {
                let bit_position = (1u32 << j) - 1;
                let rc = Lane64::load64((1u64 << bit_position).to_le_bytes().map(Byte::<B>::from));
                lanes.0[0][0] = lanes.0[0][0].xor(rc);
            }
        }
    }

    /// Applies all `ROUNDS` rounds of Keccak-f[1600] to a lane grid.
    ///
    /// Lanes are read and written in place through `lanes.0`, not by copying
    /// the whole grid via `Lanes::inner` on every access.
    fn f1600_on_lanes<B: Boolean>(mut lanes: Lanes<B>) -> Lanes<B> {
        let mut r = 1u16;
        for _ in 0..Self::ROUNDS {
            Self::theta(&mut lanes);
            Self::rho_pi(&mut lanes);
            Self::chi(&mut lanes);
            Self::iota(&mut lanes, &mut r);
        }
        lanes
    }

    /// Keccak-f[1600] on a 200-byte state.
    ///
    /// Lane `(x, y)` is read from and written back to bytes
    /// `8 * (x + 5 * y)..+8`.
    fn f1600<B: Boolean>(state: [Byte<B>; 200]) -> [Byte<B>; 200] {
        const IDX: [usize; 5] = [0, 1, 2, 3, 4];
        let lanes = Lanes::new(IDX.map(|x| {
            IDX.map(|y| {
                let offset = 8 * (x + 5 * y);
                Lane64::load64(
                    state[offset..offset + 8]
                        .try_into()
                        .expect("a lane spans exactly 8 state bytes"),
                )
            })
        }));
        let lanes = Self::f1600_on_lanes(lanes);
        // The 25 lanes tile all 200 bytes exactly once, so every byte of the
        // copied input buffer is overwritten below.
        let mut out = state;
        for (x, column) in lanes.0.iter().enumerate() {
            for (y, lane) in column.iter().enumerate() {
                let offset = 8 * (x + 5 * y);
                out[offset..offset + 8].copy_from_slice(&lane.store64());
            }
        }
        out
    }
}
impl Keccak for BooleanValue {
    /// Keccak-f[1600] over `BooleanValue` bytes.
    ///
    /// If any input bit is not plaintext, the generic bit-level circuit
    /// `KeccakPermutation::f1600` is built. Otherwise, 1600 expressions
    /// `BitExpr::KeccakF1600(input_ids, i)` are pushed into the global
    /// expression store, one per output bit `i`, each referencing all 1600
    /// input bit ids. The resulting bits are regrouped 8 per byte.
    fn f1600(state: [Byte<BooleanValue>; 200]) -> [Byte<BooleanValue>; 200] {
        let all_plaintext = state
            .iter()
            .all(|byte| byte.get_bits().iter().all(|bit| bit.is_plaintext()));
        if !all_plaintext {
            return KeccakPermutation::f1600(state);
        }
        let output_ids = with_global_expr_store_as_local(|expr_store| {
            let input_ids: Vec<usize> = state
                .iter()
                .flat_map(|byte| byte.get_bits())
                .map(|bit| bit.get_id())
                .collect();
            let mut ids = Vec::with_capacity(1600);
            for i in 0..1600 {
                ids.push(<IRBuilder as ExprStore<BaseField>>::push_bit(
                    expr_store,
                    BitExpr::KeccakF1600(input_ids.clone(), i),
                ));
            }
            ids
        });
        let bits: Vec<BooleanValue> = output_ids.into_iter().map(BooleanValue::new).collect();
        let mut bytes: Vec<Byte<BooleanValue>> = Vec::new();
        for chunk in bits.chunks(8) {
            let byte_bits = chunk
                .to_vec()
                .try_into()
                .unwrap_or_else(|v: Vec<BooleanValue>| {
                    panic!("Expected a Vec of length 8 (found {})", v.len())
                });
            bytes.push(Byte::new(byte_bits));
        }
        bytes
            .try_into()
            .unwrap_or_else(|v: Vec<Byte<BooleanValue>>| {
                panic!("Expected a Vec of length 200 (found {})", v.len())
            })
    }
}
impl<const N: usize> Keccak for BooleanArray<N> {
    /// Keccak-f[1600] over `BooleanArray<N>` bytes, delegating to the generic
    /// bit-level implementation.
    fn f1600(state: [Byte<BooleanArray<N>>; 200]) -> [Byte<BooleanArray<N>>; 200] {
        KeccakPermutation::f1600::<BooleanArray<N>>(state)
    }
}
impl Keccak for bool {
    /// Keccak-f[1600] on plaintext bytes.
    ///
    /// Each 8-byte chunk of `state` is packed little-endian into one `u64`
    /// lane. The lanes are permuted with `keccak::f1600` and unpacked back
    /// into the same byte positions.
    fn f1600(state: [Byte<bool>; 200]) -> [Byte<bool>; 200] {
        let mut lanes = [0u64; 25];
        for (lane, chunk) in lanes.iter_mut().zip(state.chunks_exact(8)) {
            let mut le_bytes = [0u8; 8];
            for (dst, byte) in le_bytes.iter_mut().zip(chunk) {
                *dst = u8::from(*byte);
            }
            *lane = u64::from_le_bytes(le_bytes);
        }
        keccak::f1600(&mut lanes);
        // 25 lanes × 8 bytes cover all 200 bytes, so every byte is overwritten.
        let mut out = state;
        for (chunk, lane) in out.chunks_exact_mut(8).zip(lanes) {
            for (dst, byte) in chunk.iter_mut().zip(lane.to_le_bytes()) {
                *dst = Byte::<bool>::from(byte);
            }
        }
        out
    }
}
/// SHA3-256 computed over bytes of any [`Boolean`] type via the Keccak sponge.
#[derive(Clone, Debug, Default)]
// The type name deliberately keeps the standard's upper-case spelling.
#[allow(clippy::upper_case_acronyms)]
pub struct SHA3_256;
impl SHA3_256 {
    /// Sponge rate in bits passed to `Keccak::sponge`.
    const RATE_BITS: usize = 1088;
    /// Sponge capacity in bits passed to `Keccak::sponge`.
    const CAPACITY_BITS: usize = 512;

    /// Creates a hasher; the type carries no state.
    pub fn new() -> Self {
        SHA3_256
    }
    /// Digest length in bytes (256 bits).
    pub fn digest_in_bytes(&self) -> usize {
        256 / 8
    }
    /// Sponge rate in bytes.
    pub fn rate_in_bytes(&self) -> usize {
        Self::RATE_BITS / 8
    }
    /// Hashes `message` and returns the 32-byte digest.
    pub fn digest<B: Boolean>(&self, message: Vec<Byte<B>>) -> [Byte<B>; 32] {
        Keccak::sponge::<32>(Self::RATE_BITS, Self::CAPACITY_BITS, message)
    }
}
/// SHA3-512 computed over bytes of any [`Boolean`] type via the Keccak sponge.
#[derive(Clone, Debug, Default)]
// The type name deliberately keeps the standard's upper-case spelling.
#[allow(clippy::upper_case_acronyms)]
pub struct SHA3_512;
impl SHA3_512 {
    /// Sponge rate in bits passed to `Keccak::sponge`.
    const RATE_BITS: usize = 576;
    /// Sponge capacity in bits passed to `Keccak::sponge`.
    const CAPACITY_BITS: usize = 1024;

    /// Creates a hasher; the type carries no state.
    pub fn new() -> Self {
        SHA3_512
    }
    /// Digest length in bytes (512 bits).
    pub fn digest_in_bytes(&self) -> usize {
        512 / 8
    }
    /// Sponge rate in bytes.
    pub fn rate_in_bytes(&self) -> usize {
        Self::RATE_BITS / 8
    }
    /// Hashes `message` and returns the 64-byte digest.
    pub fn digest<B: Boolean>(&self, message: Vec<Byte<B>>) -> [Byte<B>; 64] {
        Keccak::sponge::<64>(Self::RATE_BITS, Self::CAPACITY_BITS, message)
    }
}
impl ArithmeticCircuit<BaseField> for SHA3_256 {
    /// Plaintext reference: treats each element of `x` as one byte and hashes
    /// the message with the `sha3` crate's SHA3-256.
    ///
    /// # Panics
    /// Panics if any element of `x` is greater than 255.
    fn eval(&self, x: Vec<BaseField>) -> Result<Vec<BaseField>, EvalFailure> {
        for byte in x.iter() {
            assert!(*byte <= BaseField::from(255));
        }
        let mut message: Vec<u8> = Vec::with_capacity(x.len());
        for val in x {
            message.push(val.to_le_bytes()[0]);
        }
        let mut hasher = Sha3_256::new();
        hasher.update(message);
        let output: Vec<BaseField> = hasher
            .finalize()
            .iter()
            .map(|&b| BaseField::from(u64::from(b)))
            .collect();
        Ok(output)
    }
    /// Every digest byte lies in `[0, 255]`.
    fn bounds(&self, _bounds: Vec<FieldBounds<BaseField>>) -> Vec<FieldBounds<BaseField>> {
        let byte_bounds = FieldBounds::new(BaseField::from(0), BaseField::from(255));
        vec![byte_bounds; 32]
    }
    /// Builds the hash over `BooleanValue` bytes, one byte per input value.
    fn run(&self, vals: Vec<FieldValue<BaseField>>) -> Vec<FieldValue<BaseField>> {
        let message: Vec<Byte<BooleanValue>> = vals.into_iter().map(Byte::from).collect();
        let digest = SHA3_256::new().digest(message);
        digest
            .into_iter()
            .map(FieldValue::<BaseField>::from)
            .collect()
    }
}
impl ArithmeticCircuit<BaseField> for SHA3_512 {
    /// Plaintext reference: treats each element of `x` as one byte and hashes
    /// the message with the `sha3` crate's SHA3-512.
    ///
    /// # Panics
    /// Panics if any element of `x` is greater than 255.
    fn eval(&self, x: Vec<BaseField>) -> Result<Vec<BaseField>, EvalFailure> {
        for byte in x.iter() {
            assert!(*byte <= BaseField::from(255));
        }
        let mut message: Vec<u8> = Vec::with_capacity(x.len());
        for val in x {
            message.push(val.to_le_bytes()[0]);
        }
        let mut hasher = Sha3_512::new();
        hasher.update(message);
        let output: Vec<BaseField> = hasher
            .finalize()
            .iter()
            .map(|&b| BaseField::from(u64::from(b)))
            .collect();
        Ok(output)
    }
    /// Every digest byte lies in `[0, 255]`.
    fn bounds(&self, _bounds: Vec<FieldBounds<BaseField>>) -> Vec<FieldBounds<BaseField>> {
        let byte_bounds = FieldBounds::new(BaseField::from(0), BaseField::from(255));
        vec![byte_bounds; 64]
    }
    /// Builds the hash over `BooleanValue` bytes, one byte per input value.
    fn run(&self, vals: Vec<FieldValue<BaseField>>) -> Vec<FieldValue<BaseField>> {
        let message: Vec<Byte<BooleanValue>> = vals.into_iter().map(Byte::from).collect();
        let digest = SHA3_512::new().digest(message);
        digest
            .into_iter()
            .map(FieldValue::<BaseField>::from)
            .collect()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::circuits::traits::arithmetic_circuit::tests::TestedArithmeticCircuit;
    use rand::Rng;

    /// Returns `start + 3 * k`, where `k` is the number of `p = 0.75` coin
    /// flips that succeed before the first failure.
    fn random_len<R: Rng + ?Sized>(rng: &mut R, start: usize) -> usize {
        let extra_steps = std::iter::repeat(())
            .take_while(|_| rng.gen_bool(0.75))
            .count();
        start + 3 * extra_steps
    }

    /// Bounds for a single input byte: `[0, 255]`.
    fn byte_bounds() -> FieldBounds<BaseField> {
        FieldBounds::new(BaseField::from(0), BaseField::from(255))
    }

    impl TestedArithmeticCircuit<BaseField> for SHA3_256 {
        fn gen_desc<R: Rng + ?Sized>(_rng: &mut R) -> Self {
            SHA3_256
        }
        fn gen_n_inputs<R: Rng + ?Sized>(&self, rng: &mut R) -> usize {
            // 130 is 6 bytes below the 136-byte rate.
            // NOTE(review): presumably chosen to straddle the first block boundary.
            random_len(rng, 130)
        }
        fn gen_input_bounds<R: Rng + ?Sized>(_rng: &mut R) -> FieldBounds<BaseField> {
            byte_bounds()
        }
    }

    impl TestedArithmeticCircuit<BaseField> for SHA3_512 {
        fn gen_desc<R: Rng + ?Sized>(_rng: &mut R) -> Self {
            SHA3_512
        }
        fn gen_n_inputs<R: Rng + ?Sized>(&self, rng: &mut R) -> usize {
            // 66 is 6 bytes below the 72-byte rate.
            // NOTE(review): presumably chosen to straddle the first block boundary.
            random_len(rng, 66)
        }
        fn gen_input_bounds<R: Rng + ?Sized>(_rng: &mut R) -> FieldBounds<BaseField> {
            byte_bounds()
        }
    }

    #[test]
    fn tested_sha3_256() {
        SHA3_256::test(1, 1)
    }

    #[test]
    fn tested_sha3_512() {
        SHA3_512::test(1, 1)
    }
}