Skip to main content

axiom_eth/keccak/
types.rs

1use std::{marker::PhantomData, sync::RwLock};
2
3use anyhow::anyhow;
4
5use ethers_core::{
6    types::{Bytes, H256},
7    utils::keccak256,
8};
9use halo2_base::{
10    gates::{GateInstructions, RangeChip},
11    poseidon::hasher::PoseidonCompactChunkInput,
12    safe_types::{SafeBytes32, SafeTypeChip},
13    utils::ScalarField,
14    AssignedValue, Context,
15};
16use itertools::Itertools;
17use lazy_static::lazy_static;
18use serde::{Deserialize, Serialize};
19
20use type_map::concurrent::TypeMap;
21use zkevm_hashes::keccak::{
22    component::{encode::pack_native_input, output::KeccakCircuitOutput, param::POSEIDON_RATE},
23    vanilla::keccak_packed_multi::get_num_keccak_f,
24};
25
26use crate::{
27    rlc::chip::RlcChip,
28    utils::{
29        component::{
30            types::{FixLenLogical, Flatten},
31            ComponentType, ComponentTypeId, LogicalInputValue, LogicalResult,
32        },
33        hilo::HiLo,
34        AssignedH256,
35    },
36    Field,
37};
38
39use super::promise::KeccakComponentCommiter;
40
41#[derive(Clone, Debug)]
42pub struct KeccakFixedLenQuery<F: ScalarField> {
43    /// Input in bytes
44    pub input_assigned: Vec<AssignedValue<F>>,
45    /// The hash digest, in bytes
46    // For backwards compatbility we always compute this; we can consider computing it on-demand in the future
47    pub output_bytes: SafeBytes32<F>,
48    /// The hash digest, hi 128 bits (range checked by lookup table)
49    pub output_hi: AssignedValue<F>,
50    /// The hash digest, lo 128 bits (range checked by lookup table)
51    pub output_lo: AssignedValue<F>,
52}
53
54impl<F: ScalarField> KeccakFixedLenQuery<F> {
55    pub fn hi_lo(&self) -> AssignedH256<F> {
56        [self.output_hi, self.output_lo]
57    }
58}
59
60#[derive(Clone, Debug)]
61pub struct KeccakVarLenQuery<F: ScalarField> {
62    pub min_bytes: usize,
63    // pub max_bytes: usize, // equal to input_assigned.len()
64    // pub num_bytes: usize,
65    /// Actual length of input
66    pub length: AssignedValue<F>,
67    pub input_assigned: Vec<AssignedValue<F>>,
68    /// The hash digest, in bytes
69    // For backwards compatbility we always compute this; we can consider computing it on-demand in the future
70    pub output_bytes: SafeBytes32<F>,
71    /// The hash digest, hi 128 bits (range checked by lookup table)
72    pub output_hi: AssignedValue<F>,
73    /// The hash digest, lo 128 bits (range checked by lookup table)
74    pub output_lo: AssignedValue<F>,
75}
76
77impl<F: ScalarField> KeccakVarLenQuery<F> {
78    pub fn hi_lo(&self) -> AssignedH256<F> {
79        [self.output_hi, self.output_lo]
80    }
81}
82
83/// The core logical input to the keccak component circuit.
84pub type CoreInputKeccak = Vec<Vec<u8>>;
85
86#[derive(Clone, Debug, Default, Hash, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord)]
87#[serde(rename_all = "camelCase")]
88pub struct OutputKeccakShard {
89    /// The (assumed to be deduplicated) list of requests, in the form of variable
90    /// length byte arrays to be hashed. Optionally include the calculated hash.
91    pub responses: Vec<(Bytes, Option<H256>)>,
92    /// To prevent inconsistencies, also specify the capacity of the keccak circuit
93    pub capacity: usize,
94}
95
96impl OutputKeccakShard {
97    /// Createa a dummy OutputKeccakShard with the given capacity.
98    pub fn create_dummy(capacity: usize) -> Self {
99        Self { responses: vec![], capacity }
100    }
101    pub fn into_logical_results<F: Field>(self) -> Vec<LogicalResult<F, ComponentTypeKeccak<F>>> {
102        let mut total_capacity = 0;
103        let mut promise_results = self
104            .responses
105            .into_iter()
106            .map(|(input, output)| {
107                let input = KeccakLogicalInput::new(input.to_vec());
108                total_capacity += get_num_keccak_f(input.bytes.len());
109                let v_output =
110                    if let Some(hash) = output { hash.into() } else { input.compute_output::<F>() };
111                LogicalResult::<F, ComponentTypeKeccak<F>>::new(input, v_output)
112            })
113            .collect_vec();
114        assert!(total_capacity <= self.capacity);
115        if total_capacity < self.capacity {
116            let target_len = self.capacity - total_capacity + promise_results.len();
117            let dummy = dummy_circuit_output::<F>();
118            promise_results.resize(
119                target_len,
120                LogicalResult::new(
121                    KeccakLogicalInput::new(vec![]),
122                    KeccakVirtualOutput::<F> {
123                        hash: HiLo::from_hi_lo([dummy.hash_hi, dummy.hash_lo]),
124                    },
125                ),
126            );
127        }
128        promise_results
129    }
130}
131
132/// KeccakLogicalInput is the logical input of Keccak Component.
133#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)]
134pub struct KeccakLogicalInput {
135    pub bytes: Vec<u8>,
136}
137impl KeccakLogicalInput {
138    // Create KeccakLogicalInput
139    pub fn new(bytes: Vec<u8>) -> Self {
140        Self { bytes }
141    }
142    pub fn compute_output<F: Field>(&self) -> KeccakVirtualOutput<F> {
143        let hash = H256(keccak256(&self.bytes));
144        hash.into()
145    }
146}
147
148impl<F: Field> LogicalInputValue<F> for KeccakLogicalInput {
149    fn get_capacity(&self) -> usize {
150        get_num_keccak_f(self.bytes.len())
151    }
152}
153
/// Number of packed witnesses describing one keccak_f chunk of input.
pub(crate) const NUM_WITNESS_PER_KECCAK_F: usize = 6;
/// Bit sizes of the flattened fields of a `KeccakVirtualInput`.
/// There are `NUM_WITNESS_PER_KECCAK_F` packed-input fields of 192 bits each,
/// followed by the 1-bit `is_final` flag.
const KECCAK_VIRTUAL_INPUT_FIELD_SIZE: [usize; NUM_WITNESS_PER_KECCAK_F + 1] = {
    let mut sizes = [192; NUM_WITNESS_PER_KECCAK_F + 1];
    // The trailing slot is the is_final flag.
    sizes[NUM_WITNESS_PER_KECCAK_F] = 1;
    sizes
};
/// Bit sizes of the flattened fields of a `KeccakVirtualOutput`: the hi and lo 128-bit halves.
const KECCAK_VIRTUAL_OUTPUT_FIELD_SIZE: [usize; 2] = [128, 128];
160
161/// Virtual input of Keccak Component.
162/// TODO: this cannot work if F::capacity < 192.
163#[derive(Debug, Clone, Hash, PartialEq, Eq)]
164pub struct KeccakVirtualInput<T: Clone> {
165    // 1 length + 17 64-byte words, every 3 are compressed into 1 witness.
166    // spec: https://github.com/axiom-crypto/halo2-lib/blob/9e6c9a16196e7e2ce58ccb6ffc31984fc0ba69d9/hashes/zkevm/src/keccak/component/encode.rs#L25
167    pub packed_input: [T; NUM_WITNESS_PER_KECCAK_F],
168    // Whether this is the last chunk of the input.
169    // TODO: this is hacky because it can be derived from packed_input but it's not really committed.
170    pub is_final: T,
171}
172
173impl<T: Clone> KeccakVirtualInput<T> {
174    pub fn new(packed_input: [T; NUM_WITNESS_PER_KECCAK_F], is_final: T) -> Self {
175        Self { packed_input, is_final }
176    }
177}
178
179impl<T: Copy> TryFrom<Flatten<T>> for KeccakVirtualInput<T> {
180    type Error = anyhow::Error;
181
182    fn try_from(value: Flatten<T>) -> std::result::Result<Self, Self::Error> {
183        if value.field_size != KECCAK_VIRTUAL_INPUT_FIELD_SIZE {
184            return Err(anyhow::anyhow!("invalid field size"));
185        }
186        if value.field_size.len() != value.fields.len() {
187            return Err(anyhow::anyhow!("field length doesn't match"));
188        }
189
190        Ok(Self {
191            packed_input: value.fields[0..NUM_WITNESS_PER_KECCAK_F]
192                .try_into()
193                .map_err(|_| anyhow!("failed to convert flatten to KeccakVirtualInput"))?,
194            is_final: value.fields[NUM_WITNESS_PER_KECCAK_F],
195        })
196    }
197}
198impl<T: Copy> From<KeccakVirtualInput<T>> for Flatten<T> {
199    fn from(val: KeccakVirtualInput<T>) -> Self {
200        Self {
201            fields: [val.packed_input.as_slice(), [val.is_final].as_slice()].concat(),
202            field_size: &KECCAK_VIRTUAL_INPUT_FIELD_SIZE,
203        }
204    }
205}
206impl<T: Copy> FixLenLogical<T> for KeccakVirtualInput<T> {
207    fn get_field_size() -> &'static [usize] {
208        &KECCAK_VIRTUAL_INPUT_FIELD_SIZE
209    }
210}
211
212impl<F: Field> From<KeccakVirtualInput<AssignedValue<F>>>
213    for PoseidonCompactChunkInput<F, POSEIDON_RATE>
214{
215    fn from(val: KeccakVirtualInput<AssignedValue<F>>) -> Self {
216        let KeccakVirtualInput::<AssignedValue<F>> { packed_input, is_final } = val;
217        assert!(packed_input.len() % POSEIDON_RATE == 0);
218        let inputs: Vec<[AssignedValue<F>; POSEIDON_RATE]> = packed_input
219            .into_iter()
220            .chunks(POSEIDON_RATE)
221            .into_iter()
222            .map(|c| c.collect_vec().try_into().unwrap())
223            .collect_vec();
224        let is_final = SafeTypeChip::unsafe_to_bool(is_final);
225        Self::new(inputs, is_final)
226    }
227}
228
229/// Virtual input of Keccak Component.
230#[derive(Default, Debug, Clone, Hash, PartialEq, Eq)]
231pub struct KeccakVirtualOutput<T: Clone> {
232    /// Keccak hash result
233    pub hash: HiLo<T>,
234}
235
236impl<T: Clone> KeccakVirtualOutput<T> {
237    pub fn new(hash: HiLo<T>) -> Self {
238        Self { hash }
239    }
240}
241
242impl<T: Copy> TryFrom<Flatten<T>> for KeccakVirtualOutput<T> {
243    type Error = anyhow::Error;
244
245    fn try_from(value: Flatten<T>) -> std::result::Result<Self, Self::Error> {
246        if value.field_size != KECCAK_VIRTUAL_OUTPUT_FIELD_SIZE {
247            return Err(anyhow::anyhow!("invalid field size"));
248        }
249        if value.field_size.len() != value.fields.len() {
250            return Err(anyhow::anyhow!("field length doesn't match"));
251        }
252
253        Ok(Self {
254            hash: HiLo::from_hi_lo(
255                value
256                    .fields
257                    .try_into()
258                    .map_err(|_| anyhow!("failed to convert flatten to KeccakVirtualOutput"))?,
259            ),
260        })
261    }
262}
263impl<T: Copy> From<KeccakVirtualOutput<T>> for Flatten<T> {
264    fn from(val: KeccakVirtualOutput<T>) -> Self {
265        Self { fields: val.hash.hi_lo().to_vec(), field_size: &KECCAK_VIRTUAL_OUTPUT_FIELD_SIZE }
266    }
267}
268impl<T: Copy> FixLenLogical<T> for KeccakVirtualOutput<T> {
269    fn get_field_size() -> &'static [usize] {
270        &KECCAK_VIRTUAL_OUTPUT_FIELD_SIZE
271    }
272}
273impl<F: Field> From<H256> for KeccakVirtualOutput<F> {
274    fn from(hash: H256) -> Self {
275        let hash_hi = u128::from_be_bytes(hash[..16].try_into().unwrap());
276        let hash_lo = u128::from_be_bytes(hash[16..].try_into().unwrap());
277        Self { hash: HiLo::from_hi_lo([F::from_u128(hash_hi), F::from_u128(hash_lo)]) }
278    }
279}
280
/// Marker type for the Keccak component; implements [`ComponentType`] over field `F`.
#[derive(Debug, Clone)]
pub struct ComponentTypeKeccak<F: Field>(PhantomData<F>);

impl<F: Field> ComponentType<F> for ComponentTypeKeccak<F> {
    type InputValue = KeccakVirtualInput<F>;
    type InputWitness = KeccakVirtualInput<AssignedValue<F>>;
    type OutputValue = KeccakVirtualOutput<F>;
    type OutputWitness = KeccakVirtualOutput<AssignedValue<F>>;
    type LogicalInput = KeccakLogicalInput;
    type Commiter = KeccakComponentCommiter<F>;

    /// Fixed string identifier of this component type.
    fn get_type_id() -> ComponentTypeId {
        "axiom-eth:ComponentTypeKeccak".to_string()
    }

    /// Expands one logical result into `(virtual input, virtual output)` rows, one per
    /// virtual input produced by `logical_input_to_virtual_rows_impl`.
    ///
    /// Every row except the last carries the dummy circuit output's hash. The last row carries
    /// the logical result's real output.
    fn logical_result_to_virtual_rows_impl(
        ins: &LogicalResult<F, Self>,
    ) -> Vec<(Self::InputValue, Self::OutputValue)> {
        let virtual_inputs = Self::logical_input_to_virtual_rows_impl(&ins.input);
        let len = virtual_inputs.len();
        let mut virtual_outputs = Vec::with_capacity(len);
        let dummy = dummy_circuit_output();
        // NOTE(review): `len - 1` assumes at least one virtual row. This holds if
        // `pack_native_input` always yields >= 1 chunk, including for empty input — confirm.
        virtual_outputs.resize(
            len - 1,
            Self::OutputValue { hash: HiLo::from_hi_lo([dummy.hash_hi, dummy.hash_lo]) },
        );
        virtual_outputs.push(ins.output.clone());
        virtual_inputs.into_iter().zip_eq(virtual_outputs).collect_vec()
    }
    /// Packs the logical input bytes with `pack_native_input` and appends an `is_final` flag
    /// to each packed chunk. The flag is `F::ONE` on the last chunk and `F::ZERO` otherwise.
    ///
    /// # Panics
    /// If a packed chunk plus its flag does not match the `KeccakVirtualInput` field layout.
    fn logical_input_to_virtual_rows_impl(li: &Self::LogicalInput) -> Vec<Self::InputValue> {
        let mut packed_inputs = pack_native_input::<F>(&li.bytes);
        let len = packed_inputs.len();
        for (i, packed_input) in packed_inputs.iter_mut().enumerate() {
            let is_final = if i + 1 == len { F::ONE } else { F::ZERO };
            packed_input.push(is_final);
        }
        packed_inputs
            .into_iter()
            .map(|p| KeccakVirtualInput::try_from_raw(p).unwrap())
            .collect_vec()
    }
    /// Computes one RLC per virtual row, in the same order as `virtual_rows`.
    ///
    /// Chunk RLCs accumulate in `curr_rlc` across the chunks of one logical input:
    /// `curr = curr * chunk_multiplier + chunk_rlc`.
    /// On a row whose `is_final` is set, the row's input RLC is that accumulated value, and
    /// `curr_rlc` is then reset to zero. On every other row the input RLC is `empty_input_rlc`.
    /// Each row's RLC is `input_rlc * output_multiplier + rlc(hi, lo)`.
    ///
    /// NOTE: every gate/RLC call below assigns cells into `gate_ctx`/`rlc_ctx` in call order,
    /// so reordering these statements would change the circuit layout.
    fn rlc_virtual_rows(
        (gate_ctx, rlc_ctx): (&mut Context<F>, &mut Context<F>),
        range_chip: &RangeChip<F>,
        rlc_chip: &RlcChip<F>,
        virtual_rows: &[(Self::InputWitness, Self::OutputWitness)],
    ) -> Vec<AssignedValue<F>> {
        let gate = &range_chip.gate;
        let one = gate_ctx.load_constant(F::ONE);
        let zero = gate_ctx.load_zero();
        // Per the note below, this equals the RLC of [1, 0, 0, 0, 0, 0]. That is an empty
        // input whose first witness carries the +1 length offset.
        let empty_input_rlc = rlc_chip.rlc_pow_fixed(gate_ctx, gate, NUM_WITNESS_PER_KECCAK_F - 1);
        // = rlc_chip.compute_rlc_fixed_len(rlc_ctx, [one, zero, zero, zero, zero, zero]).rlc_val;
        // empty_input_rlc[0] = empty_input_len + 1 = 1. empty_input corresponds to input = []

        // Presumably gamma^NUM_WITNESS_PER_KECCAK_F: shifts the accumulated RLC past one chunk.
        let chunk_multiplier =
            rlc_chip.rlc_pow_fixed(gate_ctx, &range_chip.gate, NUM_WITNESS_PER_KECCAK_F);
        // Presumably gamma^(number of output fields): shifts the input RLC past the output.
        let output_multiplier = rlc_chip.rlc_pow_fixed(
            gate_ctx,
            &range_chip.gate,
            Self::OutputWitness::get_num_fields(),
        );

        // If last chunk is a final chunk.
        // Starts at one so the very first row is treated as the start of a logical input.
        let mut last_is_final = one;
        // RLC of the current logical input.
        let mut curr_rlc = zero;
        let mut virtual_row_rlcs = Vec::with_capacity(virtual_rows.len());
        for (input, output) in virtual_rows {
            let mut input_to_rlc = input.packed_input;
            // +1 to length when calculating RLC in order to make sure 0 is not a valid RLC for any input. Therefore the lookup
            // table column doesn't need a selector.
            // Only the first chunk of each logical input gets the +1, because only then was the
            // previous row final.
            input_to_rlc[0] = range_chip.gate.add(gate_ctx, input_to_rlc[0], last_is_final);

            let chunk_rlc = rlc_chip.compute_rlc_fixed_len(rlc_ctx, input_to_rlc).rlc_val;
            curr_rlc = range_chip.gate.mul_add(gate_ctx, curr_rlc, chunk_multiplier, chunk_rlc);

            // select(a, b, sel) = sel ? a : b: the full input RLC on the final chunk,
            // otherwise the empty-input RLC.
            let input_rlc =
                range_chip.gate.select(gate_ctx, curr_rlc, empty_input_rlc, input.is_final);
            let output_rlc = rlc_chip.compute_rlc_fixed_len(rlc_ctx, output.hash.hi_lo()).rlc_val;
            let virtual_row_rlc =
                range_chip.gate.mul_add(gate_ctx, input_rlc, output_multiplier, output_rlc);
            virtual_row_rlcs.push(virtual_row_rlc);

            // Reset the accumulator after a final chunk so the next logical input starts fresh.
            curr_rlc = range_chip.gate.select(gate_ctx, zero, curr_rlc, input.is_final);

            last_is_final = input.is_final;
        }
        virtual_row_rlcs
    }
}
371
372lazy_static! {
373    /// We cache the dummy circuit output to avoid re-computing it.
374    /// The recomputation involves creating an optimized Poseidon spec, which is
375    /// time intensive.
376    static ref CACHED_DUMMY_CIRCUIT_OUTPUT: RwLock<TypeMap> = RwLock::new(TypeMap::new());
377}
378
379/// The default dummy_circuit_output needs to do Poseidon. Poseidon generic over F
380/// requires re-computing the optimized Poseidon spec, which is computationally
381/// intensive. Since we call dummy_circuit_output very often, we cache the result
382/// as a performance optimization.
383fn dummy_circuit_output<F: crate::RawField>() -> KeccakCircuitOutput<F> {
384    use zkevm_hashes::keccak::component::output::dummy_circuit_output;
385
386    let cached_output =
387        CACHED_DUMMY_CIRCUIT_OUTPUT.read().unwrap().get::<KeccakCircuitOutput<F>>().cloned();
388    if let Some(cached_output) = cached_output {
389        return cached_output;
390    }
391    let output = dummy_circuit_output::<F>();
392    CACHED_DUMMY_CIRCUIT_OUTPUT.write().unwrap().insert(output);
393    output
394}