use std::{any::Any, marker::PhantomData};
use getset::Getters;
use halo2_base::{
gates::{circuit::builder::BaseCircuitBuilder, GateInstructions, RangeChip, RangeInstructions},
safe_types::{FixLenBytesVec, VarLenBytesVec},
utils::bit_length,
AssignedValue, Context,
QuantumCell::Constant,
};
use itertools::Itertools;
use num_bigint::BigUint;
use snark_verifier::util::hash::Poseidon;
use snark_verifier_sdk::NativeLoader;
use zkevm_hashes::keccak::{
component::{
encode::{format_input, num_word_per_witness, pack_native_input},
param::{POSEIDON_RATE, POSEIDON_T},
},
vanilla::{
keccak_packed_multi::get_num_keccak_f,
param::{NUM_BITS_PER_WORD, NUM_BYTES_TO_ABSORB},
},
};
use crate::{
rlc::chip::RlcChip,
utils::component::{
promise_loader::comp_loader::ComponentCommiter,
types::Flatten,
utils::{create_hasher, into_key, native_poseidon_hasher, try_from_key},
ComponentCircuit, ComponentType, ComponentTypeId, LogicalInputValue, PromiseCallWitness,
TypelessLogicalInput,
},
Field,
};
use super::types::{
ComponentTypeKeccak, KeccakLogicalInput, KeccakVirtualInput, KeccakVirtualOutput,
OutputKeccakShard, NUM_WITNESS_PER_KECCAK_F,
};
/// A keccak promise call on a byte array whose length is fixed when the circuit is built.
#[derive(Clone, Debug, Getters)]
pub struct KeccakFixLenCall<F: Field> {
    /// The assigned bytes to be hashed.
    #[getset(get = "pub")]
    bytes: FixLenBytesVec<F>,
}

impl<F: Field> KeccakFixLenCall<F> {
    /// Wraps `bytes` as a fixed-length keccak call.
    pub fn new(bytes: FixLenBytesVec<F>) -> Self {
        KeccakFixLenCall { bytes }
    }

    /// Reads the witness value of every byte back out as a native `u8` and wraps the result
    /// in a [`KeccakLogicalInput`].
    pub fn to_logical_input(&self) -> KeccakLogicalInput {
        let mut raw_bytes = Vec::with_capacity(self.bytes.len());
        for byte in self.bytes.bytes() {
            // Entries are expected to hold byte values, so only the low bits are kept.
            raw_bytes.push(byte.as_ref().value().get_lower_64() as u8);
        }
        KeccakLogicalInput::new(raw_bytes)
    }
}
impl<F: Field> PromiseCallWitness<F> for KeccakFixLenCall<F> {
    /// All keccak calls belong to the keccak component.
    fn get_component_type_id(&self) -> ComponentTypeId {
        ComponentTypeKeccak::<F>::get_type_id()
    }

    /// Number of keccak-f rows (as computed by `get_num_keccak_f`) for the fixed byte length.
    fn get_capacity(&self) -> usize {
        let byte_len = self.bytes.len();
        get_num_keccak_f(byte_len)
    }

    /// Packs the bytes into keccak witness words with `format_input` and returns the
    /// fixed-length RLC of all packed words, flattened in order.
    fn to_rlc(
        &self,
        (gate_ctx, rlc_ctx): (&mut Context<F>, &mut Context<F>),
        range_chip: &RangeChip<F>,
        rlc_chip: &RlcChip<F>,
    ) -> AssignedValue<F> {
        // The length is known at build time, so `len + 1` is loaded as a constant and
        // handed to `format_input`.
        let byte_len_plus_one = F::from((self.bytes.len() + 1) as u64);
        let byte_len_plus_one = gate_ctx.load_constant(byte_len_plus_one);
        let gate = &range_chip.gate;
        let words = format_input(gate_ctx, gate, self.bytes.bytes(), byte_len_plus_one);
        let flat_words = words.concat().concat();
        rlc_chip.compute_rlc_fixed_len(rlc_ctx, flat_words).rlc_val
    }

    /// Serializes the native logical input into a typeless key.
    fn to_typeless_logical_input(&self) -> TypelessLogicalInput {
        into_key(self.to_logical_input())
    }

    /// Computes the keccak output natively from the witness bytes and flattens it.
    fn get_mock_output(&self) -> Flatten<F> {
        let logical_input = self.to_logical_input();
        let output_val = logical_input.compute_output();
        output_val.into()
    }

    fn as_any(&self) -> &dyn Any {
        self
    }
}
/// A keccak promise call on a variable-length byte array.
///
/// `min_len` is a caller-supplied length bound; it is used to size the fixed prefix of the
/// RLC in `to_rlc` (presumably a lower bound on the witness length — confirm with callers).
#[derive(Clone, Debug, Getters)]
pub struct KeccakVarLenCall<F: Field> {
    /// The assigned bytes (with an assigned length) to be hashed.
    #[getset(get = "pub")]
    bytes: VarLenBytesVec<F>,
    min_len: usize,
}

impl<F: Field> KeccakVarLenCall<F> {
    /// Wraps `bytes` as a variable-length keccak call with the given `min_len`.
    pub fn new(bytes: VarLenBytesVec<F>, min_len: usize) -> Self {
        KeccakVarLenCall { bytes, min_len }
    }

    /// Reads the first `len` witness bytes (where `len` is the witness value of the assigned
    /// length) as native `u8`s and wraps them in a [`KeccakLogicalInput`].
    ///
    /// # Panics
    /// If the witness length exceeds the number of stored bytes.
    pub fn to_logical_input(&self) -> KeccakLogicalInput {
        let witness_len = self.bytes.len().value().get_lower_64() as usize;
        let mut raw_bytes = Vec::new();
        for byte in self.bytes.bytes()[..witness_len].iter() {
            raw_bytes.push(byte.as_ref().value().get_lower_64() as u8);
        }
        KeccakLogicalInput::new(raw_bytes)
    }

    /// Constrains and returns `len / NUM_BYTES_TO_ABSORB` in-circuit; `to_rlc` increments
    /// this value to obtain the keccak-f count.
    pub fn num_keccak_f_m1(
        &self,
        gate_ctx: &mut Context<F>,
        range_chip: &RangeChip<F>,
    ) -> AssignedValue<F> {
        // Range-check width for `len`; assumes the assigned length never exceeds `max_len`.
        let len_bits = bit_length(self.bytes.max_len() as u64);
        let assigned_len = *self.bytes.len();
        let rate = BigUint::from(NUM_BYTES_TO_ABSORB);
        range_chip.div_mod(gate_ctx, assigned_len, rate, len_bits).0
    }
}
impl<F: Field> PromiseCallWitness<F> for KeccakVarLenCall<F> {
    /// All keccak calls belong to the keccak component.
    fn get_component_type_id(&self) -> ComponentTypeId {
        ComponentTypeKeccak::<F>::get_type_id()
    }

    /// Number of keccak-f rows for the witness value of the assigned length.
    fn get_capacity(&self) -> usize {
        let witness_len = self.bytes.len().value().get_lower_64() as usize;
        get_num_keccak_f(witness_len)
    }

    /// Packs the (zero-padded) bytes into keccak witness words and returns their RLC over
    /// the first `NUM_WITNESS_PER_KECCAK_F * num_keccak_f` words, where `num_keccak_f` is
    /// computed in-circuit from the assigned length.
    fn to_rlc(
        &self,
        (gate_ctx, rlc_ctx): (&mut Context<F>, &mut Context<F>),
        range_chip: &RangeChip<F>,
        rlc_chip: &RlcChip<F>,
    ) -> AssignedValue<F> {
        let gate = &range_chip.gate;
        // Per its name, `ensure_0_padding` makes the bytes past `len` zero in-circuit.
        let padded = self.bytes.ensure_0_padding(gate_ctx, gate);
        let keccak_f_minus_one = self.num_keccak_f_m1(gate_ctx, range_chip);
        let len_plus_one = gate.inc(gate_ctx, *padded.len());
        let keccak_f_count = gate.inc(gate_ctx, keccak_f_minus_one);
        let flat_words = format_input(gate_ctx, gate, padded.bytes(), len_plus_one)
            .into_iter()
            .flatten()
            .flatten();
        let witness_per_f = Constant(F::from(NUM_WITNESS_PER_KECCAK_F as u64));
        let rlc_len = gate.mul(gate_ctx, witness_per_f, keccak_f_count);
        let min_rlc_len = get_num_keccak_f(self.min_len) * NUM_WITNESS_PER_KECCAK_F;
        rlc_chip
            .compute_rlc_with_min_len((gate_ctx, rlc_ctx), gate, flat_words, rlc_len, min_rlc_len)
            .rlc_val
    }

    /// Serializes the native logical input into a typeless key.
    fn to_typeless_logical_input(&self) -> TypelessLogicalInput {
        into_key(self.to_logical_input())
    }

    /// Computes the keccak output natively from the first `len` witness bytes and flattens it.
    fn get_mock_output(&self) -> Flatten<F> {
        let logical_input = self.to_logical_input();
        let output_val: <ComponentTypeKeccak<F> as ComponentType<F>>::OutputValue =
            logical_input.compute_output();
        output_val.into()
    }

    fn as_any(&self) -> &dyn Any {
        self
    }
}
/// Natively computes the Poseidon key of the empty keccak input.
///
/// Clears `native_poseidon`, then absorbs every witness from `pack_native_input(&[])` in
/// `POSEIDON_RATE`-sized blocks, zero-padding a short final block. Returns the squeezed value.
fn get_dummy_key<F: Field>(native_poseidon: &mut Poseidon<F, F, POSEIDON_T, POSEIDON_RATE>) -> F {
    native_poseidon.clear();
    let empty_input_witnesses = pack_native_input(&[]);
    for keccak_f_witnesses in empty_input_witnesses {
        for chunk in keccak_f_witnesses.chunks(POSEIDON_RATE) {
            let filled = chunk.len();
            let mut block = [F::ZERO; POSEIDON_RATE];
            block[..filled].copy_from_slice(chunk);
            native_poseidon.update(&block);
        }
    }
    native_poseidon.squeeze()
}
/// Commits to the keccak component's virtual rows, both in-circuit (`compute_commitment`)
/// and natively (`compute_native_commitment`).
pub struct KeccakComponentCommiter<F: Field>(PhantomData<F>);
impl<F: Field> ComponentCommiter<F> for KeccakComponentCommiter<F> {
    /// In-circuit commitment over `witness_virtual_rows`.
    ///
    /// Each row is parsed as a (`KeccakVirtualInput`, `KeccakVirtualOutput`) pair. Every input
    /// row's `is_final` flag is constrained against a keccak-f countdown derived from the low
    /// word of its first packed witness. The inputs are hashed with
    /// `hash_compact_chunk_inputs`, and the result is the Poseidon hash of the flattened
    /// `[key, hash.lo, hash.hi]` triples. `key` is the input hash on final rows and a
    /// constant dummy key on all other rows.
    ///
    /// # Panics
    /// If a row cannot be converted into the keccak virtual input/output types.
    fn compute_commitment(
        builder: &mut BaseCircuitBuilder<F>,
        witness_virtual_rows: &[(Flatten<AssignedValue<F>>, Flatten<AssignedValue<F>>)],
    ) -> AssignedValue<F> {
        let range_chip = &builder.range_chip();
        let ctx = builder.main(0);
        let mut hasher = create_hasher::<F>();
        hasher.initialize_consts(ctx, &range_chip.gate);
        // Dummy key = native Poseidon hash of the packed empty input, computed with the same
        // spec as the in-circuit hasher and loaded as a constant.
        let dummy_key = {
            let mut native_poseidon = Poseidon::from_spec(&NativeLoader, hasher.spec().clone());
            get_dummy_key(&mut native_poseidon)
        };
        let dummy_input = ctx.load_constant(dummy_key);
        let parsed_virtual_rows: Vec<(KeccakVirtualInput<_>, KeccakVirtualOutput<_>)> =
            witness_virtual_rows
                .iter()
                .map(|(v_i, v_o)| {
                    (v_i.clone().try_into().unwrap(), v_o.clone().try_into().unwrap())
                })
                .collect_vec();
        // Countdown of keccak-f rows still to come for the current input; 0 means the next
        // row starts a new input.
        let mut remaining_keccak_f = ctx.load_zero();
        for (v_i, _) in &parsed_virtual_rows {
            // Low NUM_BITS_PER_WORD bits of the first packed witness (the length placeholder).
            let (_, length_placeholder) = range_chip.div_mod(
                ctx,
                v_i.packed_input[0],
                BigUint::from(1u128 << NUM_BITS_PER_WORD),
                NUM_BITS_PER_WORD * num_word_per_witness::<F>(),
            );
            // placeholder / NUM_BYTES_TO_ABSORB: the countdown value to load at an input's first row.
            let (num_keccak_f_dec, _) = range_chip.div_mod(
                ctx,
                length_placeholder,
                BigUint::from(NUM_BYTES_TO_ABSORB),
                NUM_BITS_PER_WORD,
            );
            let remaining_keccak_f_is_zero = range_chip.gate.is_zero(ctx, remaining_keccak_f);
            let remaining_keccak_f_dec = range_chip.gate.dec(ctx, remaining_keccak_f);
            // If the countdown is 0 this row starts a new input, so reload it from this row;
            // otherwise decrement it.
            remaining_keccak_f = range_chip.gate.select(
                ctx,
                num_keccak_f_dec,
                remaining_keccak_f_dec,
                remaining_keccak_f_is_zero,
            );
            // A row is final exactly when the countdown reaches 0.
            let is_final = range_chip.gate.is_zero(ctx, remaining_keccak_f);
            ctx.constrain_equal(&is_final, &v_i.is_final);
        }
        let mut inputs_to_poseidon = Vec::with_capacity(parsed_virtual_rows.len());
        let mut virtual_outputs = Vec::with_capacity(parsed_virtual_rows.len());
        for (v_i, v_o) in parsed_virtual_rows {
            inputs_to_poseidon.push(v_i.into());
            virtual_outputs.push(v_o);
        }
        let poseidon_results =
            hasher.hash_compact_chunk_inputs(ctx, &range_chip.gate, &inputs_to_poseidon);
        // One [key, lo, hi] triple per row; non-final rows use the dummy key.
        let keccak_outputs = poseidon_results
            .into_iter()
            .zip_eq(virtual_outputs)
            .map(|(po, vo)| {
                let key = range_chip.gate.select(ctx, po.hash(), dummy_input, po.is_final());
                vec![key, vo.hash.lo(), vo.hash.hi()]
            })
            .concat();
        hasher.hash_fix_len_array(ctx, &range_chip.gate, &keccak_outputs)
    }
    /// Native counterpart of `compute_commitment`.
    ///
    /// Absorbs each row's packed input into a running sponge. On rows whose `is_final`
    /// equals one it squeezes a key and resets the sponge; other rows use the dummy key.
    /// Returns the Poseidon hash of the flattened `[key, lo, hi]` triples.
    ///
    /// # Panics
    /// If a row cannot be converted into the keccak virtual input/output types.
    fn compute_native_commitment(witness_virtual_rows: &[(Flatten<F>, Flatten<F>)]) -> F {
        let mut hasher = native_poseidon_hasher();
        let dummy_key = get_dummy_key(&mut hasher);
        // Reset the sponge after deriving the dummy key before hashing any rows.
        hasher.clear();
        let keccak_outputs: Vec<_> = witness_virtual_rows
            .iter()
            .flat_map(|(v_i, v_o)| {
                let (v_i, v_o): (KeccakVirtualInput<_>, KeccakVirtualOutput<_>) =
                    (v_i.clone().try_into().unwrap(), v_o.clone().try_into().unwrap());
                hasher.update(&v_i.packed_input);
                let key = if v_i.is_final == F::ONE {
                    let key = hasher.squeeze();
                    hasher.clear();
                    key
                } else {
                    dummy_key
                };
                let [hi, lo] = v_o.hash.hi_lo();
                // Same [key, lo, hi] order as the in-circuit commitment.
                [key, lo, hi]
            })
            .collect();
        hasher.clear();
        hasher.update(&keccak_outputs);
        hasher.squeeze()
    }
}
/// Collects every keccak promise call made by `comp_circuit` and packages the requested
/// inputs into an [`OutputKeccakShard`] with the given `capacity`.
///
/// Each response is `(input bytes, None)`.
///
/// # Errors
/// - if `comp_circuit.compute_promise_calls()` fails,
/// - if `comp_circuit` made no keccak calls,
/// - if a call's logical input cannot be decoded as a [`KeccakLogicalInput`],
/// - if the total keccak capacity used by the calls exceeds `capacity`.
pub fn generate_keccak_shards_from_calls<F: Field>(
    comp_circuit: &dyn ComponentCircuit<F>,
    capacity: usize,
) -> anyhow::Result<OutputKeccakShard> {
    let calls = comp_circuit.compute_promise_calls()?;
    let keccak_type_id = ComponentTypeKeccak::<F>::get_type_id();
    // `ok_or_else` only builds the error when it is actually needed.
    let keccak_calls =
        calls.get(&keccak_type_id).ok_or_else(|| anyhow::anyhow!("no keccak calls"))?;
    let mut used_capacity = 0;
    let mut responses = Vec::new();
    for call in keccak_calls.iter() {
        // A malformed logical input is reported to the caller instead of panicking.
        let li = try_from_key::<KeccakLogicalInput>(&call.logical_input).map_err(|e| {
            anyhow::anyhow!("failed to decode keccak logical input: {:?}", e)
        })?;
        used_capacity += <KeccakLogicalInput as LogicalInputValue<F>>::get_capacity(&li);
        // `li` is not used again, so its bytes are moved rather than cloned.
        responses.push((li.bytes.into(), None));
    }
    log::info!("Keccak used capacity: {}", used_capacity);
    anyhow::ensure!(
        used_capacity <= capacity,
        "used capacity {} exceeds capacity {}",
        used_capacity,
        capacity
    );
    Ok(OutputKeccakShard { responses, capacity })
}