use core::iter::once;
use getset::Getters;
use halo2_base::{
gates::{GateChip, GateInstructions, RangeChip, RangeInstructions},
safe_types::{SafeBytes32, SafeTypeChip},
utils::{bit_length, ScalarField},
AssignedValue, Context,
QuantumCell::Constant,
};
use itertools::Itertools;
use zkevm_hashes::keccak::vanilla::param::NUM_BYTES_TO_SQUEEZE;
use crate::{
keccak::promise::KeccakVarLenCall,
utils::{
bytes_be_to_u128, component::promise_collector::PromiseCaller, u128s_to_bytes_be,
AssignedH256,
},
Field,
};
use self::{
promise::KeccakFixLenCall,
types::{ComponentTypeKeccak, KeccakFixedLenQuery, KeccakVarLenQuery},
};
mod component_shim;
pub mod promise;
#[cfg(test)]
mod tests;
pub mod types;
/// Chip for requesting keccak hashes from inside a circuit.
///
/// The methods on this type do not compute hashes themselves: they issue calls through
/// `promise_caller` and use `range` to decompose the returned `(hi, lo)` digest into bytes.
#[derive(Clone, Debug, Getters)]
pub struct KeccakChip<F: Field> {
/// Caller through which the fixed-length and variable-length keccak calls are made.
#[getset(get = "pub")]
promise_caller: PromiseCaller<F>,
/// Range chip used for length checks and byte decomposition; its `gate` field backs
/// [`KeccakChip::gate`].
#[getset(get = "pub")]
range: RangeChip<F>,
}
impl<F: Field> KeccakChip<F> {
pub fn new(range: RangeChip<F>, promise_collector: PromiseCaller<F>) -> Self {
Self::new_with_promise_collector(range, promise_collector)
}
/// Builds a [`KeccakChip`], storing `range` as the range chip and `promise_collector`
/// as the promise caller.
pub fn new_with_promise_collector(
    range: RangeChip<F>,
    promise_collector: PromiseCaller<F>,
) -> Self {
    let promise_caller = promise_collector;
    Self { promise_caller, range }
}
pub fn gate(&self) -> &GateChip<F> {
&self.range.gate
}
/// Requests the keccak hash of `input`, treated as a byte string of fixed length
/// `input.len()`.
///
/// The digest is obtained as a `(hi, lo)` pair by making a [`KeccakFixLenCall`] through the
/// promise caller, and is then decomposed into bytes with `u128s_to_bytes_be`.
///
/// `input` is wrapped with `SafeTypeChip::unsafe_to_fix_len_bytes_vec`; presumably this adds
/// no byte range checks, so callers are expected to pass cells already known to be bytes
/// (TODO confirm against halo2-base).
///
/// # Panics
/// Panics if the promise call does not succeed (its result is unwrapped).
pub fn keccak_fixed_len(
    &self,
    ctx: &mut Context<F>,
    input: Vec<AssignedValue<F>>,
) -> KeccakFixedLenQuery<F> {
    // `input` is cloned because the original cells are returned in the query below.
    let fixed_len_input = SafeTypeChip::unsafe_to_fix_len_bytes_vec(input.clone(), input.len());
    let call = KeccakFixLenCall::new(fixed_len_input);
    let [output_hi, output_lo] = self
        .promise_caller
        .call::<KeccakFixLenCall<F>, ComponentTypeKeccak<F>>(ctx, call)
        .unwrap()
        .hash
        .hi_lo();

    // Split the two 128-bit halves into big-endian bytes, then re-wrap the raw cells as the
    // safe type expected by the query struct.
    let hash_bytes = u128s_to_bytes_be(ctx, self.range(), &[output_hi, output_lo]);
    let hash_cells: Vec<AssignedValue<F>> =
        hash_bytes.into_iter().map(|byte| byte.into()).collect();

    KeccakFixedLenQuery {
        input_assigned: input,
        output_bytes: SafeTypeChip::unsafe_to_safe_type(hash_cells),
        output_hi,
        output_lo,
    }
}
/// Requests the keccak hash of the first `len` bytes of `input`.
///
/// * `input` - cells holding the (padded) input; `input.len()` is used as the maximum length.
/// * `len` - assigned cell holding the actual byte length.
/// * `min_len` - lower bound on `len`; when non-zero, `min_len - 1 < len` is checked.
///
/// `len` is checked to be less than `input.len() + 1`. The digest is obtained as a
/// `(hi, lo)` pair by making a [`KeccakVarLenCall`] through the promise caller and is then
/// decomposed into bytes with `u128s_to_bytes_be`.
///
/// `input` is wrapped with `SafeTypeChip::unsafe_to_var_len_bytes_vec`; presumably this adds
/// no byte range checks, so callers are expected to pass cells already known to be bytes
/// (TODO confirm against halo2-base).
///
/// # Panics
/// Panics if the promise call does not succeed, or if the decomposed output cannot be
/// converted into the query's `output_bytes` type. In debug builds it also panics when the
/// witness value of `len` exceeds `input.len()`.
pub fn keccak_var_len(
    &self,
    ctx: &mut Context<F>,
    input: Vec<AssignedValue<F>>,
    len: AssignedValue<F>,
    min_len: usize,
) -> KeccakVarLenQuery<F> {
    let max_len = input.len();

    let range = self.range();
    // Upper bound: len < max_len + 1.
    range.check_less_than_safe(ctx, len, (max_len + 1) as u64);
    if min_len != 0 {
        // Lower bound: min_len - 1 < len. Skipped for min_len == 0, where it is vacuous
        // and `min_len - 1` would underflow.
        range.check_less_than(
            ctx,
            Constant(F::from((min_len - 1) as u64)),
            len,
            bit_length((max_len + 1) as u64),
        );
    }
    // Witness-side sanity check only. The number of input bytes is exactly `max_len`, so
    // compare against it directly instead of materialising the byte values.
    debug_assert!(max_len >= len.value().get_lower_64() as usize);

    let [output_hi, output_lo] = {
        let output = self
            .promise_caller
            .call::<KeccakVarLenCall<F>, ComponentTypeKeccak<F>>(
                ctx,
                KeccakVarLenCall::new(
                    // `input` is cloned because the original cells are returned below.
                    SafeTypeChip::unsafe_to_var_len_bytes_vec(input.clone(), len, max_len),
                    min_len,
                ),
            )
            .unwrap();
        output.hash.hi_lo()
    };
    // Split the two 128-bit halves into big-endian bytes.
    let output_bytes = u128s_to_bytes_be(ctx, self.range(), &[output_hi, output_lo]);

    KeccakVarLenQuery {
        min_bytes: min_len,
        length: len,
        input_assigned: input,
        output_bytes: output_bytes.try_into().unwrap(),
        output_hi,
        output_lo,
    }
}
/// Computes the keccak Merkle root of a full binary tree whose bottom layer is `leaves`.
///
/// Adjacent leaves are concatenated and hashed with [`Self::keccak_fixed_len`]; every higher
/// layer hashes the concatenated `output_bytes` of adjacent pairs from the layer below until
/// a single hash remains.
///
/// Returns the root both as bytes and as its `[hi, lo]` pair.
///
/// # Panics
/// Panics if `leaves` is empty, if `leaves.len()` is not a power of two, or if there is only
/// a single leaf. These checks are enforced in all build profiles: with a non-power-of-two
/// leaf count the layer loop below would otherwise ignore trailing leaves and return a root
/// that does not commit to them.
pub fn merkle_tree_root(
    &self,
    ctx: &mut Context<F>,
    leaves: &[impl AsRef<[AssignedValue<F>]>],
) -> (SafeBytes32<F>, AssignedH256<F>) {
    assert!(!leaves.is_empty(), "Merkle root of zero leaves is ill-defined");
    let depth = leaves.len().ilog2() as usize;
    assert_eq!(1 << depth, leaves.len(), "number of leaves must be a power of two");
    assert_ne!(depth, 0, "Merkle root of a single leaf is ill-defined");

    // Bottom layer: hash each adjacent pair of leaves.
    let mut hashes = leaves
        .chunks(2)
        .map(|pair| {
            let leaves_concat = [pair[0].as_ref(), pair[1].as_ref()].concat();
            self.keccak_fixed_len(ctx, leaves_concat)
        })
        .collect_vec();
    debug_assert_eq!(hashes.len(), 1 << (depth - 1));

    // Remaining layers are computed in place: slot `i` is overwritten only after slots
    // `2 * i` and `2 * i + 1` have been read, and later iterations read higher indices only.
    for d in (0..depth - 1).rev() {
        for i in 0..(1 << d) {
            let leaves_concat =
                [2 * i, 2 * i + 1].map(|idx| hashes[idx].output_bytes.as_ref()).concat();
            hashes[i] = self.keccak_fixed_len(ctx, leaves_concat);
        }
    }
    (hashes[0].output_bytes.clone(), [hashes[0].output_hi, hashes[0].output_lo])
}
/// Computes a keccak Merkle mountain range over `leaves`, returning one entry per bit of
/// `num_leaves_bits`, ordered from the largest tree to the smallest.
///
/// * `leaves` - must have power-of-two length; each leaf is expected to be
///   `NUM_BYTES_TO_SQUEEZE` cells long (checked in debug builds on the leaves being shifted).
/// * `num_leaves_bits` - `log2(leaves.len()) + 1` selector cells, consumed from the last
///   index down to index 0. Presumably the little-endian bits of the number of real leaves
///   (TODO confirm against callers).
///
/// The first entry is the root over all of `leaves`. For each lower `depth` the entry is
/// the root of the first `2^depth` working leaves, after which the working leaves are
/// conditionally shifted left by `2^depth` positions using `select` on that depth's bit.
/// The final entry (depth 0) is the first working leaf itself, converted to bytes and a
/// `[hi, lo]` pair rather than hashed.
///
/// # Panics
/// Panics if `leaves.len()` is not a power of two, if `num_leaves_bits.len()` is not
/// `log2(leaves.len()) + 1`, or wherever [`Self::merkle_tree_root`] panics.
pub fn merkle_mountain_range(
    &self,
    ctx: &mut Context<F>,
    leaves: &[Vec<AssignedValue<F>>],
    num_leaves_bits: &[AssignedValue<F>],
) -> Vec<(SafeBytes32<F>, AssignedH256<F>)> {
    let max_depth = leaves.len().ilog2() as usize;
    assert_eq!(leaves.len(), 1 << max_depth);
    assert_eq!(num_leaves_bits.len(), max_depth + 1);

    // Working copy of the leaves that gets shifted as we descend through the depths.
    let mut shifted = leaves.to_vec();
    let mut peaks = Vec::with_capacity(num_leaves_bits.len());
    peaks.push(self.merkle_tree_root(ctx, leaves));

    // The top bit is covered by the full root above; walk the remaining bits downwards.
    for (depth, &sel) in num_leaves_bits.iter().enumerate().rev().skip(1) {
        if depth == 0 {
            // Last peak is a single leaf: nothing to hash and no further shifting needed.
            let leaf_bytes =
                SafeTypeChip::unsafe_to_fix_len_bytes_vec(shifted[0].clone(), 32).into_bytes();
            let hi_lo: [_; 2] =
                bytes_be_to_u128(ctx, self.gate(), &leaf_bytes).try_into().unwrap();
            let bytes = SafeBytes32::try_from(leaf_bytes).unwrap();
            peaks.push((bytes, hi_lo));
        } else {
            let width = 1usize << depth;
            let peak = self.merkle_tree_root(ctx, &shifted[..width]);
            // For each of the first `width` leaves, select between the leaf `width`
            // positions further along and the current leaf, keyed on `sel`.
            let (lower, upper) = shifted.split_at_mut(width);
            for (dst, src) in lower.iter_mut().zip(upper.iter()) {
                debug_assert_eq!(dst.len(), NUM_BYTES_TO_SQUEEZE);
                for (j, cell) in dst.iter_mut().enumerate() {
                    *cell = self.gate().select(ctx, src[j], *cell, sel);
                }
            }
            peaks.push(peak);
        }
    }
    peaks
}
}
/// Reads the witness value of every cell in `bytes` and returns them as a `Vec<u8>`.
///
/// Each value is taken through `get_lower_64` and then cast with `as u8`, so a cell whose
/// value does not fit in a byte is truncated rather than rejected.
pub fn get_bytes<F: ScalarField>(bytes: &[impl AsRef<AssignedValue<F>>]) -> Vec<u8> {
    let mut out = Vec::with_capacity(bytes.len());
    for cell in bytes {
        out.push(cell.as_ref().value().get_lower_64() as u8);
    }
    out
}