1use std::{marker::PhantomData, sync::RwLock};
2
3use anyhow::anyhow;
4
5use ethers_core::{
6 types::{Bytes, H256},
7 utils::keccak256,
8};
9use halo2_base::{
10 gates::{GateInstructions, RangeChip},
11 poseidon::hasher::PoseidonCompactChunkInput,
12 safe_types::{SafeBytes32, SafeTypeChip},
13 utils::ScalarField,
14 AssignedValue, Context,
15};
16use itertools::Itertools;
17use lazy_static::lazy_static;
18use serde::{Deserialize, Serialize};
19
20use type_map::concurrent::TypeMap;
21use zkevm_hashes::keccak::{
22 component::{encode::pack_native_input, output::KeccakCircuitOutput, param::POSEIDON_RATE},
23 vanilla::keccak_packed_multi::get_num_keccak_f,
24};
25
26use crate::{
27 rlc::chip::RlcChip,
28 utils::{
29 component::{
30 types::{FixLenLogical, Flatten},
31 ComponentType, ComponentTypeId, LogicalInputValue, LogicalResult,
32 },
33 hilo::HiLo,
34 AssignedH256,
35 },
36 Field,
37};
38
39use super::promise::KeccakComponentCommiter;
40
41#[derive(Clone, Debug)]
42pub struct KeccakFixedLenQuery<F: ScalarField> {
43 pub input_assigned: Vec<AssignedValue<F>>,
45 pub output_bytes: SafeBytes32<F>,
48 pub output_hi: AssignedValue<F>,
50 pub output_lo: AssignedValue<F>,
52}
53
54impl<F: ScalarField> KeccakFixedLenQuery<F> {
55 pub fn hi_lo(&self) -> AssignedH256<F> {
56 [self.output_hi, self.output_lo]
57 }
58}
59
60#[derive(Clone, Debug)]
61pub struct KeccakVarLenQuery<F: ScalarField> {
62 pub min_bytes: usize,
63 pub length: AssignedValue<F>,
67 pub input_assigned: Vec<AssignedValue<F>>,
68 pub output_bytes: SafeBytes32<F>,
71 pub output_hi: AssignedValue<F>,
73 pub output_lo: AssignedValue<F>,
75}
76
77impl<F: ScalarField> KeccakVarLenQuery<F> {
78 pub fn hi_lo(&self) -> AssignedH256<F> {
79 [self.output_hi, self.output_lo]
80 }
81}
82
83pub type CoreInputKeccak = Vec<Vec<u8>>;
85
86#[derive(Clone, Debug, Default, Hash, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord)]
87#[serde(rename_all = "camelCase")]
88pub struct OutputKeccakShard {
89 pub responses: Vec<(Bytes, Option<H256>)>,
92 pub capacity: usize,
94}
95
96impl OutputKeccakShard {
97 pub fn create_dummy(capacity: usize) -> Self {
99 Self { responses: vec![], capacity }
100 }
101 pub fn into_logical_results<F: Field>(self) -> Vec<LogicalResult<F, ComponentTypeKeccak<F>>> {
102 let mut total_capacity = 0;
103 let mut promise_results = self
104 .responses
105 .into_iter()
106 .map(|(input, output)| {
107 let input = KeccakLogicalInput::new(input.to_vec());
108 total_capacity += get_num_keccak_f(input.bytes.len());
109 let v_output =
110 if let Some(hash) = output { hash.into() } else { input.compute_output::<F>() };
111 LogicalResult::<F, ComponentTypeKeccak<F>>::new(input, v_output)
112 })
113 .collect_vec();
114 assert!(total_capacity <= self.capacity);
115 if total_capacity < self.capacity {
116 let target_len = self.capacity - total_capacity + promise_results.len();
117 let dummy = dummy_circuit_output::<F>();
118 promise_results.resize(
119 target_len,
120 LogicalResult::new(
121 KeccakLogicalInput::new(vec![]),
122 KeccakVirtualOutput::<F> {
123 hash: HiLo::from_hi_lo([dummy.hash_hi, dummy.hash_lo]),
124 },
125 ),
126 );
127 }
128 promise_results
129 }
130}
131
132#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)]
134pub struct KeccakLogicalInput {
135 pub bytes: Vec<u8>,
136}
137impl KeccakLogicalInput {
138 pub fn new(bytes: Vec<u8>) -> Self {
140 Self { bytes }
141 }
142 pub fn compute_output<F: Field>(&self) -> KeccakVirtualOutput<F> {
143 let hash = H256(keccak256(&self.bytes));
144 hash.into()
145 }
146}
147
148impl<F: Field> LogicalInputValue<F> for KeccakLogicalInput {
149 fn get_capacity(&self) -> usize {
150 get_num_keccak_f(self.bytes.len())
151 }
152}
153
/// Number of packed input field elements carried by one virtual input row.
pub(crate) const NUM_WITNESS_PER_KECCAK_F: usize = 6;

/// Field bit-widths of a flattened `KeccakVirtualInput`: a 192 entry for each packed input
/// witness, followed by a 1 entry for the boolean `is_final` flag.
const KECCAK_VIRTUAL_INPUT_FIELD_SIZE: [usize; NUM_WITNESS_PER_KECCAK_F + 1] = {
    let mut sizes = [192; NUM_WITNESS_PER_KECCAK_F + 1];
    sizes[NUM_WITNESS_PER_KECCAK_F] = 1;
    sizes
};

/// Field bit-widths of a flattened `KeccakVirtualOutput`: the hash as two 128-bit halves (hi, lo).
const KECCAK_VIRTUAL_OUTPUT_FIELD_SIZE: [usize; 2] = [128; 2];
160
161#[derive(Debug, Clone, Hash, PartialEq, Eq)]
164pub struct KeccakVirtualInput<T: Clone> {
165 pub packed_input: [T; NUM_WITNESS_PER_KECCAK_F],
168 pub is_final: T,
171}
172
173impl<T: Clone> KeccakVirtualInput<T> {
174 pub fn new(packed_input: [T; NUM_WITNESS_PER_KECCAK_F], is_final: T) -> Self {
175 Self { packed_input, is_final }
176 }
177}
178
179impl<T: Copy> TryFrom<Flatten<T>> for KeccakVirtualInput<T> {
180 type Error = anyhow::Error;
181
182 fn try_from(value: Flatten<T>) -> std::result::Result<Self, Self::Error> {
183 if value.field_size != KECCAK_VIRTUAL_INPUT_FIELD_SIZE {
184 return Err(anyhow::anyhow!("invalid field size"));
185 }
186 if value.field_size.len() != value.fields.len() {
187 return Err(anyhow::anyhow!("field length doesn't match"));
188 }
189
190 Ok(Self {
191 packed_input: value.fields[0..NUM_WITNESS_PER_KECCAK_F]
192 .try_into()
193 .map_err(|_| anyhow!("failed to convert flatten to KeccakVirtualInput"))?,
194 is_final: value.fields[NUM_WITNESS_PER_KECCAK_F],
195 })
196 }
197}
198impl<T: Copy> From<KeccakVirtualInput<T>> for Flatten<T> {
199 fn from(val: KeccakVirtualInput<T>) -> Self {
200 Self {
201 fields: [val.packed_input.as_slice(), [val.is_final].as_slice()].concat(),
202 field_size: &KECCAK_VIRTUAL_INPUT_FIELD_SIZE,
203 }
204 }
205}
206impl<T: Copy> FixLenLogical<T> for KeccakVirtualInput<T> {
207 fn get_field_size() -> &'static [usize] {
208 &KECCAK_VIRTUAL_INPUT_FIELD_SIZE
209 }
210}
211
212impl<F: Field> From<KeccakVirtualInput<AssignedValue<F>>>
213 for PoseidonCompactChunkInput<F, POSEIDON_RATE>
214{
215 fn from(val: KeccakVirtualInput<AssignedValue<F>>) -> Self {
216 let KeccakVirtualInput::<AssignedValue<F>> { packed_input, is_final } = val;
217 assert!(packed_input.len() % POSEIDON_RATE == 0);
218 let inputs: Vec<[AssignedValue<F>; POSEIDON_RATE]> = packed_input
219 .into_iter()
220 .chunks(POSEIDON_RATE)
221 .into_iter()
222 .map(|c| c.collect_vec().try_into().unwrap())
223 .collect_vec();
224 let is_final = SafeTypeChip::unsafe_to_bool(is_final);
225 Self::new(inputs, is_final)
226 }
227}
228
229#[derive(Default, Debug, Clone, Hash, PartialEq, Eq)]
231pub struct KeccakVirtualOutput<T: Clone> {
232 pub hash: HiLo<T>,
234}
235
236impl<T: Clone> KeccakVirtualOutput<T> {
237 pub fn new(hash: HiLo<T>) -> Self {
238 Self { hash }
239 }
240}
241
242impl<T: Copy> TryFrom<Flatten<T>> for KeccakVirtualOutput<T> {
243 type Error = anyhow::Error;
244
245 fn try_from(value: Flatten<T>) -> std::result::Result<Self, Self::Error> {
246 if value.field_size != KECCAK_VIRTUAL_OUTPUT_FIELD_SIZE {
247 return Err(anyhow::anyhow!("invalid field size"));
248 }
249 if value.field_size.len() != value.fields.len() {
250 return Err(anyhow::anyhow!("field length doesn't match"));
251 }
252
253 Ok(Self {
254 hash: HiLo::from_hi_lo(
255 value
256 .fields
257 .try_into()
258 .map_err(|_| anyhow!("failed to convert flatten to KeccakVirtualOutput"))?,
259 ),
260 })
261 }
262}
263impl<T: Copy> From<KeccakVirtualOutput<T>> for Flatten<T> {
264 fn from(val: KeccakVirtualOutput<T>) -> Self {
265 Self { fields: val.hash.hi_lo().to_vec(), field_size: &KECCAK_VIRTUAL_OUTPUT_FIELD_SIZE }
266 }
267}
268impl<T: Copy> FixLenLogical<T> for KeccakVirtualOutput<T> {
269 fn get_field_size() -> &'static [usize] {
270 &KECCAK_VIRTUAL_OUTPUT_FIELD_SIZE
271 }
272}
273impl<F: Field> From<H256> for KeccakVirtualOutput<F> {
274 fn from(hash: H256) -> Self {
275 let hash_hi = u128::from_be_bytes(hash[..16].try_into().unwrap());
276 let hash_lo = u128::from_be_bytes(hash[16..].try_into().unwrap());
277 Self { hash: HiLo::from_hi_lo([F::from_u128(hash_hi), F::from_u128(hash_lo)]) }
278 }
279}
280
/// Marker type for the keccak component; it stores no data, only the field type `F`.
#[derive(Debug, Clone)]
pub struct ComponentTypeKeccak<F: Field>(PhantomData<F>);

impl<F: Field> ComponentType<F> for ComponentTypeKeccak<F> {
    type InputValue = KeccakVirtualInput<F>;
    type InputWitness = KeccakVirtualInput<AssignedValue<F>>;
    type OutputValue = KeccakVirtualOutput<F>;
    type OutputWitness = KeccakVirtualOutput<AssignedValue<F>>;
    type LogicalInput = KeccakLogicalInput;
    type Commiter = KeccakComponentCommiter<F>;

    /// Identifier string of this component type.
    fn get_type_id() -> ComponentTypeId {
        "axiom-eth:ComponentTypeKeccak".to_string()
    }

    /// Expands one logical keccak result into `(virtual input, virtual output)` rows, one per
    /// packed chunk of the input. Only the last row carries the real hash output; every earlier
    /// row is paired with the dummy circuit output's hash.
    fn logical_result_to_virtual_rows_impl(
        ins: &LogicalResult<F, Self>,
    ) -> Vec<(Self::InputValue, Self::OutputValue)> {
        let virtual_inputs = Self::logical_input_to_virtual_rows_impl(&ins.input);
        let len = virtual_inputs.len();
        let mut virtual_outputs = Vec::with_capacity(len);
        let dummy = dummy_circuit_output();
        // `len - 1` underflows if there are no chunks. NOTE(review): this relies on
        // pack_native_input returning at least one chunk, even for empty input — confirm.
        virtual_outputs.resize(
            len - 1,
            Self::OutputValue { hash: HiLo::from_hi_lo([dummy.hash_hi, dummy.hash_lo]) },
        );
        virtual_outputs.push(ins.output.clone());
        virtual_inputs.into_iter().zip_eq(virtual_outputs).collect_vec()
    }
    /// Packs the logical input bytes with `pack_native_input` and appends an `is_final` flag to
    /// each chunk (`F::ONE` on the last chunk, `F::ZERO` otherwise). Produces one virtual input
    /// row per chunk.
    ///
    /// # Panics
    ///
    /// If a flagged chunk cannot be converted by `KeccakVirtualInput::try_from_raw` (e.g. if its
    /// length does not match the virtual input layout).
    fn logical_input_to_virtual_rows_impl(li: &Self::LogicalInput) -> Vec<Self::InputValue> {
        let mut packed_inputs = pack_native_input::<F>(&li.bytes);
        let len = packed_inputs.len();
        for (i, packed_input) in packed_inputs.iter_mut().enumerate() {
            let is_final = if i + 1 == len { F::ONE } else { F::ZERO };
            packed_input.push(is_final);
        }
        packed_inputs
            .into_iter()
            .map(|p| KeccakVirtualInput::try_from_raw(p).unwrap())
            .collect_vec()
    }
    /// Computes one RLC value per virtual row, combining the row's input RLC with its output RLC
    /// as `input_rlc * output_multiplier + output_rlc`.
    ///
    /// Chunk RLCs are folded into a running accumulator (`curr_rlc`) across the rows of one
    /// logical input. A row's input RLC is the accumulator when its `is_final` is set, and the
    /// constant `empty_input_rlc` otherwise. The accumulator resets to zero after each final row.
    /// The halo2-base `select(a, b, sel)` yields `a` when `sel` is one.
    ///
    /// The order of the `gate_ctx`/`rlc_ctx` operations below determines the assigned cells, so
    /// the statement order must not be changed.
    fn rlc_virtual_rows(
        (gate_ctx, rlc_ctx): (&mut Context<F>, &mut Context<F>),
        range_chip: &RangeChip<F>,
        rlc_chip: &RlcChip<F>,
        virtual_rows: &[(Self::InputWitness, Self::OutputWitness)],
    ) -> Vec<AssignedValue<F>> {
        let gate = &range_chip.gate;
        let one = gate_ctx.load_constant(F::ONE);
        let zero = gate_ctx.load_zero();
        // Input RLC used for non-final rows. NOTE(review): presumably the RLC challenge raised to
        // NUM_WITNESS_PER_KECCAK_F - 1 — confirm against RlcChip::rlc_pow_fixed.
        let empty_input_rlc = rlc_chip.rlc_pow_fixed(gate_ctx, gate, NUM_WITNESS_PER_KECCAK_F - 1);
        // Multiplier that shifts the accumulator by one chunk before the next chunk RLC is added.
        let chunk_multiplier =
            rlc_chip.rlc_pow_fixed(gate_ctx, &range_chip.gate, NUM_WITNESS_PER_KECCAK_F);
        // Multiplier that shifts the input RLC past the output's fields.
        let output_multiplier = rlc_chip.rlc_pow_fixed(
            gate_ctx,
            &range_chip.gate,
            Self::OutputWitness::get_num_fields(),
        );

        // Start as though a final row preceded the first row, so row 0 begins a logical input.
        let mut last_is_final = one;
        // Running RLC over the chunks of the current logical input.
        let mut curr_rlc = zero;
        let mut virtual_row_rlcs = Vec::with_capacity(virtual_rows.len());
        for (input, output) in virtual_rows {
            let mut input_to_rlc = input.packed_input;
            // Adds 1 to the first packed element of the first chunk of each logical input (when
            // the previous row was final). NOTE(review): the reason is not visible here — confirm
            // against the keccak component's packing spec.
            input_to_rlc[0] = range_chip.gate.add(gate_ctx, input_to_rlc[0], last_is_final);

            let chunk_rlc = rlc_chip.compute_rlc_fixed_len(rlc_ctx, input_to_rlc).rlc_val;
            curr_rlc = range_chip.gate.mul_add(gate_ctx, curr_rlc, chunk_multiplier, chunk_rlc);

            let input_rlc =
                range_chip.gate.select(gate_ctx, curr_rlc, empty_input_rlc, input.is_final);
            let output_rlc = rlc_chip.compute_rlc_fixed_len(rlc_ctx, output.hash.hi_lo()).rlc_val;
            let virtual_row_rlc =
                range_chip.gate.mul_add(gate_ctx, input_rlc, output_multiplier, output_rlc);
            virtual_row_rlcs.push(virtual_row_rlc);

            // Reset the accumulator once the final chunk of a logical input has been consumed.
            curr_rlc = range_chip.gate.select(gate_ctx, zero, curr_rlc, input.is_final);

            last_is_final = input.is_final;
        }
        virtual_row_rlcs
    }
}
371
372lazy_static! {
373 static ref CACHED_DUMMY_CIRCUIT_OUTPUT: RwLock<TypeMap> = RwLock::new(TypeMap::new());
377}
378
379fn dummy_circuit_output<F: crate::RawField>() -> KeccakCircuitOutput<F> {
384 use zkevm_hashes::keccak::component::output::dummy_circuit_output;
385
386 let cached_output =
387 CACHED_DUMMY_CIRCUIT_OUTPUT.read().unwrap().get::<KeccakCircuitOutput<F>>().cloned();
388 if let Some(cached_output) = cached_output {
389 return cached_output;
390 }
391 let output = dummy_circuit_output::<F>();
392 CACHED_DUMMY_CIRCUIT_OUTPUT.write().unwrap().insert(output);
393 output
394}