use std::{marker::PhantomData, sync::RwLock};
use anyhow::anyhow;
use ethers_core::{
types::{Bytes, H256},
utils::keccak256,
};
use halo2_base::{
gates::{GateInstructions, RangeChip},
poseidon::hasher::PoseidonCompactChunkInput,
safe_types::{SafeBytes32, SafeTypeChip},
utils::ScalarField,
AssignedValue, Context,
};
use itertools::Itertools;
use lazy_static::lazy_static;
use serde::{Deserialize, Serialize};
use type_map::concurrent::TypeMap;
use zkevm_hashes::keccak::{
component::{encode::pack_native_input, output::KeccakCircuitOutput, param::POSEIDON_RATE},
vanilla::keccak_packed_multi::get_num_keccak_f,
};
use crate::{
rlc::chip::RlcChip,
utils::{
component::{
types::{FixLenLogical, Flatten},
ComponentType, ComponentTypeId, LogicalInputValue, LogicalResult,
},
hilo::HiLo,
AssignedH256,
},
Field,
};
use super::promise::KeccakComponentCommiter;
/// An assigned keccak query over a fixed-length input, together with its assigned hash output.
#[derive(Clone, Debug)]
pub struct KeccakFixedLenQuery<F: ScalarField> {
    /// Assigned input witnesses (presumably one per input byte — TODO confirm against callers).
    pub input_assigned: Vec<AssignedValue<F>>,
    /// The hash output as 32 assigned bytes.
    pub output_bytes: SafeBytes32<F>,
    /// High half of the hash output (presumably the upper 128 bits, matching the hi/lo split
    /// used by `KeccakVirtualOutput`).
    pub output_hi: AssignedValue<F>,
    /// Low half of the hash output (presumably the lower 128 bits).
    pub output_lo: AssignedValue<F>,
}

impl<F: ScalarField> KeccakFixedLenQuery<F> {
    /// Returns the hash output as the pair `[hi, lo]`.
    pub fn hi_lo(&self) -> AssignedH256<F> {
        let Self { output_hi, output_lo, .. } = self;
        [*output_hi, *output_lo]
    }
}
/// An assigned keccak query over a variable-length input, together with its assigned hash output.
#[derive(Clone, Debug)]
pub struct KeccakVarLenQuery<F: ScalarField> {
    /// Minimum input length (presumably in bytes, a lower bound assumed for `length` — TODO confirm).
    pub min_bytes: usize,
    /// Assigned actual input length (presumably in bytes).
    pub length: AssignedValue<F>,
    /// Assigned input witnesses (presumably one per input byte, padded to a maximum length — TODO confirm).
    pub input_assigned: Vec<AssignedValue<F>>,
    /// The hash output as 32 assigned bytes.
    pub output_bytes: SafeBytes32<F>,
    /// High half of the hash output (presumably the upper 128 bits).
    pub output_hi: AssignedValue<F>,
    /// Low half of the hash output (presumably the lower 128 bits).
    pub output_lo: AssignedValue<F>,
}

impl<F: ScalarField> KeccakVarLenQuery<F> {
    /// Returns the hash output as the pair `[hi, lo]`.
    pub fn hi_lo(&self) -> AssignedH256<F> {
        let Self { output_hi: hi, output_lo: lo, .. } = self;
        [*hi, *lo]
    }
}
/// Keccak preimages forming the core input (presumably one inner `Vec<u8>` per keccak query —
/// TODO confirm against callers).
pub type CoreInputKeccak = Vec<Vec<u8>>;
/// A shard of keccak responses plus the keccak_f capacity it is padded to.
/// Serialized with camelCase field names.
#[derive(Clone, Debug, Default, Hash, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord)]
#[serde(rename_all = "camelCase")]
pub struct OutputKeccakShard {
    /// `(input bytes, optional hash)` pairs. A `None` hash is computed natively by
    /// `into_logical_results`.
    pub responses: Vec<(Bytes, Option<H256>)>,
    /// Total number of keccak_f permutations that `into_logical_results` pads the shard up to.
    pub capacity: usize,
}
impl OutputKeccakShard {
    /// Creates a shard with no responses and the given keccak_f `capacity`.
    pub fn create_dummy(capacity: usize) -> Self {
        Self { responses: vec![], capacity }
    }

    /// Converts this shard into logical keccak results, padded to exactly `capacity`
    /// keccak_f permutations.
    ///
    /// Each response keeps its supplied hash when present. A `None` hash is computed natively
    /// with keccak256. If the responses use fewer than `capacity` permutations, dummy results
    /// (an empty input paired with the dummy circuit output hash) are appended.
    ///
    /// # Panics
    /// Panics if the responses require more keccak_f permutations than `capacity`.
    pub fn into_logical_results<F: Field>(self) -> Vec<LogicalResult<F, ComponentTypeKeccak<F>>> {
        let capacity = self.capacity;
        let mut total_capacity = 0;
        let mut promise_results = self
            .responses
            .into_iter()
            .map(|(input, output)| {
                let input = KeccakLogicalInput::new(input.to_vec());
                total_capacity += get_num_keccak_f(input.bytes.len());
                let v_output = match output {
                    Some(hash) => hash.into(),
                    None => input.compute_output::<F>(),
                };
                LogicalResult::<F, ComponentTypeKeccak<F>>::new(input, v_output)
            })
            .collect_vec();
        // Shard data may come from outside (it is deserializable), so report the mismatch
        // instead of panicking with a bare assertion.
        assert!(
            total_capacity <= capacity,
            "keccak shard responses need {} keccak_f permutations but capacity is {}",
            total_capacity,
            capacity
        );
        if total_capacity < capacity {
            // The padding count assumes each dummy (empty-input) entry consumes exactly one
            // keccak_f, i.e. `get_num_keccak_f(0) == 1`.
            let target_len = capacity - total_capacity + promise_results.len();
            let dummy = dummy_circuit_output::<F>();
            promise_results.resize(
                target_len,
                LogicalResult::new(
                    KeccakLogicalInput::new(vec![]),
                    KeccakVirtualOutput::<F> {
                        hash: HiLo::from_hi_lo([dummy.hash_hi, dummy.hash_lo]),
                    },
                ),
            );
        }
        promise_results
    }
}
#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)]
pub struct KeccakLogicalInput {
pub bytes: Vec<u8>,
}
impl KeccakLogicalInput {
pub fn new(bytes: Vec<u8>) -> Self {
Self { bytes }
}
pub fn compute_output<F: Field>(&self) -> KeccakVirtualOutput<F> {
let hash = H256(keccak256(&self.bytes));
hash.into()
}
}
impl<F: Field> LogicalInputValue<F> for KeccakLogicalInput {
    /// Number of keccak_f permutations needed to absorb `bytes`.
    fn get_capacity(&self) -> usize {
        let num_bytes = self.bytes.len();
        get_num_keccak_f(num_bytes)
    }
}
/// Number of packed input witnesses per keccak_f, i.e. per virtual input row.
pub(crate) const NUM_WITNESS_PER_KECCAK_F: usize = 6;

/// Bit sizes of the flattened `KeccakVirtualInput` fields.
///
/// Each packed witness is 192 bits (presumably three 64-bit keccak lanes — TODO confirm against
/// `pack_native_input`). The last entry is the 1-bit `is_final` flag. The array is derived from
/// `NUM_WITNESS_PER_KECCAK_F` so the two cannot drift apart.
const KECCAK_VIRTUAL_INPUT_FIELD_SIZE: [usize; NUM_WITNESS_PER_KECCAK_F + 1] = {
    let mut sizes = [192; NUM_WITNESS_PER_KECCAK_F + 1];
    sizes[NUM_WITNESS_PER_KECCAK_F] = 1;
    sizes
};

/// Bit sizes of the flattened `KeccakVirtualOutput` fields: the hash's hi and lo 128-bit halves.
const KECCAK_VIRTUAL_OUTPUT_FIELD_SIZE: [usize; 2] = [128, 128];
/// One virtual row of keccak component input: the packed witnesses for a single keccak_f,
/// plus a flag marking the last chunk of a logical input.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct KeccakVirtualInput<T: Clone> {
    /// Packed input witnesses for this chunk.
    pub packed_input: [T; NUM_WITNESS_PER_KECCAK_F],
    /// ONE if this is the final chunk of its logical input, ZERO otherwise.
    pub is_final: T,
}

impl<T: Clone> KeccakVirtualInput<T> {
    /// Builds a virtual input row from its packed witnesses and final-chunk flag.
    pub fn new(packed_input: [T; NUM_WITNESS_PER_KECCAK_F], is_final: T) -> Self {
        KeccakVirtualInput { is_final, packed_input }
    }
}
impl<T: Copy> TryFrom<Flatten<T>> for KeccakVirtualInput<T> {
type Error = anyhow::Error;
fn try_from(value: Flatten<T>) -> std::result::Result<Self, Self::Error> {
if value.field_size != KECCAK_VIRTUAL_INPUT_FIELD_SIZE {
return Err(anyhow::anyhow!("invalid field size"));
}
if value.field_size.len() != value.fields.len() {
return Err(anyhow::anyhow!("field length doesn't match"));
}
Ok(Self {
packed_input: value.fields[0..NUM_WITNESS_PER_KECCAK_F]
.try_into()
.map_err(|_| anyhow!("failed to convert flatten to KeccakVirtualInput"))?,
is_final: value.fields[NUM_WITNESS_PER_KECCAK_F],
})
}
}
impl<T: Copy> From<KeccakVirtualInput<T>> for Flatten<T> {
fn from(val: KeccakVirtualInput<T>) -> Self {
Self {
fields: [val.packed_input.as_slice(), [val.is_final].as_slice()].concat(),
field_size: &KECCAK_VIRTUAL_INPUT_FIELD_SIZE,
}
}
}
impl<T: Copy> FixLenLogical<T> for KeccakVirtualInput<T> {
    /// Field-size layout of a flattened virtual input: `NUM_WITNESS_PER_KECCAK_F` packed
    /// witnesses followed by the `is_final` flag.
    fn get_field_size() -> &'static [usize] {
        const LAYOUT: &[usize] = &KECCAK_VIRTUAL_INPUT_FIELD_SIZE;
        LAYOUT
    }
}
impl<F: Field> From<KeccakVirtualInput<AssignedValue<F>>>
    for PoseidonCompactChunkInput<F, POSEIDON_RATE>
{
    /// Regroups the packed witnesses into `POSEIDON_RATE`-sized chunks.
    ///
    /// `is_final` is reinterpreted as a boolean via `SafeTypeChip::unsafe_to_bool`. Judging by
    /// the name, that call adds no constraint, so this assumes the flag is already 0/1.
    ///
    /// # Panics
    /// If the number of packed witnesses is not a multiple of `POSEIDON_RATE`.
    fn from(val: KeccakVirtualInput<AssignedValue<F>>) -> Self {
        let packed = val.packed_input;
        assert!(packed.len() % POSEIDON_RATE == 0);
        let chunks: Vec<[AssignedValue<F>; POSEIDON_RATE]> = packed
            .chunks_exact(POSEIDON_RATE)
            .map(|chunk| <[AssignedValue<F>; POSEIDON_RATE]>::try_from(chunk).unwrap())
            .collect();
        Self::new(chunks, SafeTypeChip::unsafe_to_bool(val.is_final))
    }
}
/// One virtual row of keccak component output: a hash split into hi/lo halves.
#[derive(Default, Debug, Clone, Hash, PartialEq, Eq)]
pub struct KeccakVirtualOutput<T: Clone> {
    /// The hash as a hi/lo pair.
    pub hash: HiLo<T>,
}

impl<T: Clone> KeccakVirtualOutput<T> {
    /// Wraps a hi/lo hash as a virtual output row.
    pub fn new(hash: HiLo<T>) -> Self {
        KeccakVirtualOutput { hash }
    }
}
impl<T: Copy> TryFrom<Flatten<T>> for KeccakVirtualOutput<T> {
type Error = anyhow::Error;
fn try_from(value: Flatten<T>) -> std::result::Result<Self, Self::Error> {
if value.field_size != KECCAK_VIRTUAL_OUTPUT_FIELD_SIZE {
return Err(anyhow::anyhow!("invalid field size"));
}
if value.field_size.len() != value.fields.len() {
return Err(anyhow::anyhow!("field length doesn't match"));
}
Ok(Self {
hash: HiLo::from_hi_lo(
value
.fields
.try_into()
.map_err(|_| anyhow!("failed to convert flatten to KeccakVirtualOutput"))?,
),
})
}
}
impl<T: Copy> From<KeccakVirtualOutput<T>> for Flatten<T> {
fn from(val: KeccakVirtualOutput<T>) -> Self {
Self { fields: val.hash.hi_lo().to_vec(), field_size: &KECCAK_VIRTUAL_OUTPUT_FIELD_SIZE }
}
}
impl<T: Copy> FixLenLogical<T> for KeccakVirtualOutput<T> {
    /// Field-size layout of a flattened virtual output: the hi and lo halves.
    fn get_field_size() -> &'static [usize] {
        const LAYOUT: &[usize] = &KECCAK_VIRTUAL_OUTPUT_FIELD_SIZE;
        LAYOUT
    }
}
impl<F: Field> From<H256> for KeccakVirtualOutput<F> {
    /// Splits a 32-byte hash into big-endian 128-bit halves and embeds each as a field element:
    /// bytes `0..16` become `hi`, bytes `16..32` become `lo`.
    fn from(hash: H256) -> Self {
        let bytes: [u8; 32] = hash.0;
        let (hi_bytes, lo_bytes) = bytes.split_at(16);
        let to_field = |half: &[u8]| F::from_u128(u128::from_be_bytes(half.try_into().unwrap()));
        Self::new(HiLo::from_hi_lo([to_field(hi_bytes), to_field(lo_bytes)]))
    }
}
/// Marker type implementing `ComponentType` for the keccak component. It carries no data;
/// `F` is tracked only through `PhantomData`.
#[derive(Debug, Clone)]
pub struct ComponentTypeKeccak<F: Field>(PhantomData<F>);
impl<F: Field> ComponentType<F> for ComponentTypeKeccak<F> {
    // One virtual row per packed keccak_f chunk: packed witnesses + `is_final` flag.
    type InputValue = KeccakVirtualInput<F>;
    type InputWitness = KeccakVirtualInput<AssignedValue<F>>;
    // Hash output as hi/lo halves.
    type OutputValue = KeccakVirtualOutput<F>;
    type OutputWitness = KeccakVirtualOutput<AssignedValue<F>>;
    type LogicalInput = KeccakLogicalInput;
    type Commiter = KeccakComponentCommiter<F>;
    /// String identifier for this component type.
    fn get_type_id() -> ComponentTypeId {
        "axiom-eth:ComponentTypeKeccak".to_string()
    }
    /// Expands one logical keccak result into `(input, output)` virtual rows.
    ///
    /// There is one row per input produced by `logical_input_to_virtual_rows_impl`. Only the
    /// last row carries the real hash output. Every earlier row is paired with the cached dummy
    /// circuit output hash.
    ///
    /// NOTE(review): `len - 1` underflows if the input produced zero rows. This relies on
    /// `pack_native_input` returning at least one chunk even for empty input — confirm in
    /// zkevm_hashes.
    fn logical_result_to_virtual_rows_impl(
        ins: &LogicalResult<F, Self>,
    ) -> Vec<(Self::InputValue, Self::OutputValue)> {
        let virtual_inputs = Self::logical_input_to_virtual_rows_impl(&ins.input);
        let len = virtual_inputs.len();
        let mut virtual_outputs = Vec::with_capacity(len);
        let dummy = dummy_circuit_output();
        // Non-final rows: placeholder output taken from the dummy circuit output.
        virtual_outputs.resize(
            len - 1,
            Self::OutputValue { hash: HiLo::from_hi_lo([dummy.hash_hi, dummy.hash_lo]) },
        );
        // Final row: the actual hash of this logical input.
        virtual_outputs.push(ins.output.clone());
        // `zip_eq` panics if the two sides differ in length.
        virtual_inputs.into_iter().zip_eq(virtual_outputs).collect_vec()
    }
    /// Converts a logical input into virtual input rows.
    ///
    /// `pack_native_input` splits the bytes into packed chunks. An `is_final` flag is appended
    /// to each chunk (`F::ONE` for the last, `F::ZERO` otherwise), and the result is parsed into
    /// a `KeccakVirtualInput`.
    ///
    /// # Panics
    /// If `try_from_raw` rejects a chunk. Presumably that happens when a packed chunk does not
    /// hold exactly `NUM_WITNESS_PER_KECCAK_F` elements — TODO confirm.
    fn logical_input_to_virtual_rows_impl(li: &Self::LogicalInput) -> Vec<Self::InputValue> {
        let mut packed_inputs = pack_native_input::<F>(&li.bytes);
        let len = packed_inputs.len();
        for (i, packed_input) in packed_inputs.iter_mut().enumerate() {
            let is_final = if i + 1 == len { F::ONE } else { F::ZERO };
            packed_input.push(is_final);
        }
        packed_inputs
            .into_iter()
            .map(|p| KeccakVirtualInput::try_from_raw(p).unwrap())
            .collect_vec()
    }
    /// Computes one RLC value per virtual row `(input, output)`.
    ///
    /// Per row, the result is `input_rlc * output_multiplier + output_rlc`:
    /// - `input_rlc` is the running RLC of all chunks of the current logical input, each earlier
    ///   chunk shifted by `chunk_multiplier`. It is exposed only on the row whose `is_final` is
    ///   set; other rows use the constant `empty_input_rlc`.
    /// - The first packed word of the first chunk of each logical input has 1 added (via
    ///   `last_is_final`), marking the start of a new logical input.
    ///
    /// NOTE(review): `input.is_final` is not constrained to 0/1 here. The selects below assume
    /// it is boolean, so it must be constrained elsewhere — confirm.
    /// NOTE(review): cells are assigned into `gate_ctx`/`rlc_ctx` in program order, so
    /// reordering these statements would change the circuit's witness layout.
    fn rlc_virtual_rows(
        (gate_ctx, rlc_ctx): (&mut Context<F>, &mut Context<F>),
        range_chip: &RangeChip<F>,
        rlc_chip: &RlcChip<F>,
        virtual_rows: &[(Self::InputWitness, Self::OutputWitness)],
    ) -> Vec<AssignedValue<F>> {
        let gate = &range_chip.gate;
        let one = gate_ctx.load_constant(F::ONE);
        let zero = gate_ctx.load_zero();
        // Presumably gamma^(NUM_WITNESS_PER_KECCAK_F - 1), i.e. the RLC of a chunk
        // `[1, 0, ..., 0]` — TODO confirm `rlc_pow_fixed` / `compute_rlc_fixed_len` conventions.
        let empty_input_rlc = rlc_chip.rlc_pow_fixed(gate_ctx, gate, NUM_WITNESS_PER_KECCAK_F - 1);
        // Shift applied to the accumulated input RLC before adding the next chunk's RLC.
        let chunk_multiplier =
            rlc_chip.rlc_pow_fixed(gate_ctx, &range_chip.gate, NUM_WITNESS_PER_KECCAK_F);
        // Shift applied to the input RLC before adding the output RLC (one power per output field).
        let output_multiplier = rlc_chip.rlc_pow_fixed(
            gate_ctx,
            &range_chip.gate,
            Self::OutputWitness::get_num_fields(),
        );
        // Starts at one so the very first row is treated as the start of a logical input.
        let mut last_is_final = one;
        let mut curr_rlc = zero;
        let mut virtual_row_rlcs = Vec::with_capacity(virtual_rows.len());
        for (input, output) in virtual_rows {
            let mut input_to_rlc = input.packed_input;
            // Add 1 to the first word only when the previous row closed a logical input.
            input_to_rlc[0] = range_chip.gate.add(gate_ctx, input_to_rlc[0], last_is_final);
            let chunk_rlc = rlc_chip.compute_rlc_fixed_len(rlc_ctx, input_to_rlc).rlc_val;
            // curr_rlc = curr_rlc * chunk_multiplier + chunk_rlc
            curr_rlc = range_chip.gate.mul_add(gate_ctx, curr_rlc, chunk_multiplier, chunk_rlc);
            // is_final ? curr_rlc : empty_input_rlc
            let input_rlc =
                range_chip.gate.select(gate_ctx, curr_rlc, empty_input_rlc, input.is_final);
            let output_rlc = rlc_chip.compute_rlc_fixed_len(rlc_ctx, output.hash.hi_lo()).rlc_val;
            // virtual_row_rlc = input_rlc * output_multiplier + output_rlc
            let virtual_row_rlc =
                range_chip.gate.mul_add(gate_ctx, input_rlc, output_multiplier, output_rlc);
            virtual_row_rlcs.push(virtual_row_rlc);
            // Reset the accumulator after a final row: is_final ? zero : curr_rlc
            curr_rlc = range_chip.gate.select(gate_ctx, zero, curr_rlc, input.is_final);
            last_is_final = input.is_final;
        }
        virtual_row_rlcs
    }
}
lazy_static! {
    /// Process-wide cache of `dummy_circuit_output` results, keyed by the concrete
    /// `KeccakCircuitOutput<F>` type (one entry per field type).
    static ref CACHED_DUMMY_CIRCUIT_OUTPUT: RwLock<TypeMap> = RwLock::new(TypeMap::new());
}
/// Returns the keccak component's dummy circuit output for field `F`, memoized per field type
/// in `CACHED_DUMMY_CIRCUIT_OUTPUT`.
///
/// The read lock is released before computing on a cache miss, so concurrent first callers may
/// each compute and insert the value. Panics if the cache lock is poisoned.
fn dummy_circuit_output<F: crate::RawField>() -> KeccakCircuitOutput<F> {
    use zkevm_hashes::keccak::component::output::dummy_circuit_output as compute_uncached;

    let cache = CACHED_DUMMY_CIRCUIT_OUTPUT.read().unwrap();
    let hit = cache.get::<KeccakCircuitOutput<F>>().cloned();
    // Release the read guard before any write-lock acquisition below to avoid self-deadlock.
    drop(cache);

    match hit {
        Some(output) => output,
        None => {
            let output = compute_uncached::<F>();
            CACHED_DUMMY_CIRCUIT_OUTPUT.write().unwrap().insert(output);
            output
        }
    }
}