Skip to main content

axiom_eth/utils/component/
mod.rs

1use std::{any::Any, collections::HashMap, fmt::Debug, hash::Hash, marker::PhantomData};
2
3use crate::{Field, RawField};
4use getset::Getters;
5use halo2_base::{
6    gates::{circuit::builder::BaseCircuitBuilder, GateInstructions, RangeChip},
7    halo2_proofs::halo2curves::bn256::Fr,
8    AssignedValue, Context,
9};
10use itertools::Itertools;
11use serde::{de::DeserializeOwned, Deserialize, Serialize};
12use static_assertions::assert_impl_all;
13
14use crate::rlc::chip::RlcChip;
15
16use self::{
17    promise_loader::{
18        comp_loader::{BasicComponentCommiter, ComponentCommiter},
19        flatten_witness_to_rlc,
20    },
21    types::{ComponentPublicInstances, FixLenLogical, Flatten},
22    utils::{into_key, try_from_key},
23};
24
25pub mod circuit;
26pub mod param;
27pub mod promise_collector;
28pub mod promise_loader;
29#[cfg(test)]
30mod tests;
31pub mod types;
32pub mod utils;
33
/// Numeric identifier of a component.
pub type ComponentId = u64;
/// String identifier of a component type (what `ComponentType::get_type_id` returns).
pub type ComponentTypeId = String;
/// The `ComponentId` value `0`.
/// NOTE(review): presumably reserved for the user's own component — confirm against callers.
pub const USER_COMPONENT_ID: ComponentId = 0;
37
38/// Unified representation of a logical input of a component type.
39/// TODO: Can this be extended to variable length output?
40/// In the caller end, there could be multiple formats of promise calls for a component type. e.g.
41/// fix/var length array to keccak. But in the receiver end, we only need to know the logical input.
42/// In the receiver end, a logical input could take 1(fix len input) or multiple virtual rows(var len).
43/// The number of virtual rows a logical input take is "capacity".
44pub trait LogicalInputValue<F: Field>:
45    Debug + Send + Sync + Clone + Eq + Serialize + DeserializeOwned + 'static
46{
47    /// Get the capacity of this logical input.
48    /// The default implementaion is for the fixed length case.
49    fn get_capacity(&self) -> usize;
50}
51/// A format of a promise call to component type T.
52pub trait PromiseCallWitness<F: Field>: Debug + Send + Sync + 'static {
53    /// The component type this promise call is for.
54    fn get_component_type_id(&self) -> ComponentTypeId;
55    /// Get the capacity of this promise call.
56    fn get_capacity(&self) -> usize;
57    /// Encode the promise call into RLC.
58    /// TODO: maybe pass builder here for better flexiability? but constructing chips are slow.
59    fn to_rlc(
60        &self,
61        ctx_pair: (&mut Context<F>, &mut Context<F>),
62        range_chip: &RangeChip<F>,
63        rlc_chip: &RlcChip<F>,
64    ) -> AssignedValue<F>;
65    /// Get the logical input of this promise call.
66    fn to_typeless_logical_input(&self) -> TypelessLogicalInput;
67    /// Get dummy output of this promise call.
68    fn get_mock_output(&self) -> Flatten<F>;
69    /// Enable downcasting
70    fn as_any(&self) -> &dyn Any;
71}
72
73/// The flatten version of output of a component.
74pub type FlattenVirtualTable<F> = Vec<FlattenVirtualRow<F>>;
75/// A flatten virtual row in a virtual table.
76pub type FlattenVirtualRow<F> = (Flatten<F>, Flatten<F>);
77
78/// Logical result of a component type.
79#[derive(Clone)]
80pub struct LogicalResult<F: Field, T: ComponentType<F>> {
81    pub input: T::LogicalInput,
82    pub output: T::OutputValue,
83    pub _marker: PhantomData<F>,
84}
85impl<F: Field, T: ComponentType<F>> LogicalResult<F, T> {
86    /// Create LogicalResult
87    pub fn new(input: T::LogicalInput, output: T::OutputValue) -> Self {
88        Self { input, output, _marker: PhantomData }
89    }
90}
91impl<F: Field, T: ComponentType<F>> TryFrom<ComponentPromiseResult<F>> for LogicalResult<F, T> {
92    type Error = anyhow::Error;
93    fn try_from(value: ComponentPromiseResult<F>) -> Result<Self, Self::Error> {
94        let (input, output) = value;
95        let input = try_from_key::<T::LogicalInput>(&input)?;
96        Ok(Self::new(input, T::OutputValue::try_from_raw(output)?))
97    }
98}
99impl<F: Field, T: ComponentType<F>> From<LogicalResult<F, T>> for ComponentPromiseResult<F> {
100    fn from(value: LogicalResult<F, T>) -> Self {
101        let LogicalResult { input, output, .. } = value;
102        (into_key(input), output.into_raw())
103    }
104}
105impl<F: Field, T: ComponentType<F>> From<LogicalResult<F, T>> for Vec<FlattenVirtualRow<F>> {
106    fn from(value: LogicalResult<F, T>) -> Self {
107        let logical_virtual_rows = T::logical_result_to_virtual_rows(&value);
108        logical_virtual_rows
109            .into_iter()
110            .map(|(input, output)| (input.into(), output.into()))
111            .collect_vec()
112    }
113}
114/// Specify the logical types of a component type.
115pub trait ComponentType<F: Field>: 'static + Sized {
116    type InputValue: FixLenLogical<F>;
117    type InputWitness: FixLenLogical<AssignedValue<F>>;
118    type OutputValue: FixLenLogical<F>;
119    type OutputWitness: FixLenLogical<AssignedValue<F>>;
120    type LogicalInput: LogicalInputValue<F>;
121    type Commiter: ComponentCommiter<F> = BasicComponentCommiter<F>;
122
123    /// Get ComponentTypeId of this component type.
124    fn get_type_id() -> ComponentTypeId;
125    /// Get ComponentTypeName for logging/debugging.
126    fn get_type_name() -> &'static str {
127        std::any::type_name::<Self>()
128    }
129
130    /// Wrap logical_result_to_virtual_rows_impl with sanity check.
131    fn logical_result_to_virtual_rows(
132        ins: &LogicalResult<F, Self>,
133    ) -> Vec<(Self::InputValue, Self::OutputValue)> {
134        let v_rows = Self::logical_result_to_virtual_rows_impl(ins);
135        assert_eq!(v_rows.len(), ins.input.get_capacity());
136        v_rows
137    }
138    /// Convert a logical result to 1 or multiple virtual rows.
139    fn logical_result_to_virtual_rows_impl(
140        ins: &LogicalResult<F, Self>,
141    ) -> Vec<(Self::InputValue, Self::OutputValue)>;
142
143    /// Wrap logical_input_to_virtual_rows_impl with sanity check.
144    /// TODO: we are not using this.
145    fn logical_input_to_virtual_rows(li: &Self::LogicalInput) -> Vec<Self::InputValue> {
146        let v_rows = Self::logical_input_to_virtual_rows_impl(li);
147        assert_eq!(v_rows.len(), li.get_capacity());
148        v_rows
149    }
150    /// Real implementation to convert a logical input to virtual rows.
151    fn logical_input_to_virtual_rows_impl(li: &Self::LogicalInput) -> Vec<Self::InputValue>;
152
153    /// RLC virtual rows. A logical input might take multiple virtual rows.
154    /// The default implementation is for the fixed length case.
155    fn rlc_virtual_rows(
156        (gate_ctx, rlc_ctx): (&mut Context<F>, &mut Context<F>),
157        range_chip: &RangeChip<F>,
158        rlc_chip: &RlcChip<F>,
159        inputs: &[(Self::InputWitness, Self::OutputWitness)],
160    ) -> Vec<AssignedValue<F>> {
161        let input_multiplier = rlc_chip.rlc_pow_fixed(
162            gate_ctx,
163            &range_chip.gate,
164            Self::OutputWitness::get_num_fields(),
165        );
166
167        inputs
168            .iter()
169            .map(|(input, output)| {
170                let i_rlc = flatten_witness_to_rlc(rlc_ctx, rlc_chip, &input.clone().into());
171                let o_rlc = flatten_witness_to_rlc(rlc_ctx, rlc_chip, &output.clone().into());
172                range_chip.gate.mul_add(gate_ctx, i_rlc, input_multiplier, o_rlc)
173            })
174            .collect_vec()
175    }
176}
177
178// ============= Data types passed between components =============
179pub type TypelessLogicalInput = Vec<u8>;
180#[derive(Debug, Clone, Serialize, Deserialize, Hash, PartialEq, Eq, PartialOrd, Ord)]
181pub struct TypelessPromiseCall {
182    pub capacity: usize,
183    pub logical_input: TypelessLogicalInput,
184}
185
186/// (Receiver ComponentType, serialized logical input)
187pub type GroupedPromiseCalls = HashMap<ComponentTypeId, Vec<TypelessPromiseCall>>;
188/// (typeless logical input, output)
189pub type ComponentPromiseResult<F> = (TypelessLogicalInput, Vec<F>);
190
191/// Metadata for a promise shard
192#[derive(Debug, Clone, Serialize, Deserialize, Hash, PartialEq, Eq)]
193pub struct PromiseShardMetadata<F: RawField> {
194    pub commit: F,
195    pub capacity: usize,
196}
197/// (shard index, shard data)
198pub type SelectedDataShard<S> = (usize, S);
199/// (shard index, vec of ComponentPromiseResult)
200pub type SelectedPromiseResultShard<F> = SelectedDataShard<Vec<ComponentPromiseResult<F>>>;
201
202#[derive(Debug, Clone, Getters, Serialize, Deserialize, Hash, PartialEq, Eq)]
203pub struct SelectedDataShardsInMerkle<F: RawField, S: Clone> {
204    // metadata of leaves of this Merkle tree
205    #[getset(get = "pub")]
206    leaves: Vec<PromiseShardMetadata<F>>,
207    /// Selected shards.
208    #[getset(get = "pub")]
209    shards: Vec<SelectedDataShard<S>>,
210}
211
212impl<F: Field, S: Clone> SelectedDataShardsInMerkle<F, S> {
213    // create SelectedDataShardsInMerkle
214    pub fn new(leaves: Vec<PromiseShardMetadata<F>>, shards: Vec<SelectedDataShard<S>>) -> Self {
215        assert!(leaves.len().is_power_of_two());
216        // TODO: check capacity of each shard.
217        Self { leaves, shards }
218    }
219    /// Map data into another type.
220    pub fn map_data<NS: Clone>(self, f: impl Fn(S) -> NS) -> SelectedDataShardsInMerkle<F, NS> {
221        SelectedDataShardsInMerkle::new(
222            self.leaves,
223            self.shards.into_iter().map(|(i, s)| (i, f(s))).collect(),
224        )
225    }
226}
227
228/// Each shard is a virtual table, so shards is a vector of virtual tables.
229pub type ComponentPromiseResultsInMerkle<F> =
230    SelectedDataShardsInMerkle<F, Vec<ComponentPromiseResult<F>>>;
231
232impl<F: Field> ComponentPromiseResultsInMerkle<F> {
233    /// Helper function to create ComponentPromiseResults from a single shard.
234    pub fn from_single_shard<T: ComponentType<F>>(lr: Vec<LogicalResult<F, T>>) -> Self {
235        let vt = lr.iter().flat_map(T::logical_result_to_virtual_rows).collect_vec();
236        let mut mock_builder = BaseCircuitBuilder::<F>::new(true).use_k(18).use_lookup_bits(8);
237        let ctx = mock_builder.main(0);
238        let witness_vt = vt
239            .into_iter()
240            .map(|(v_i, v_o)| (v_i.into().assign(ctx), v_o.into().assign(ctx)))
241            .collect_vec();
242        let witness_commit = T::Commiter::compute_commitment(&mut mock_builder, &witness_vt);
243        let commit = *witness_commit.value();
244        mock_builder.clear(); // prevent drop warning
245        Self {
246            leaves: vec![PromiseShardMetadata::<F> { commit, capacity: witness_vt.len() }],
247            shards: vec![(0, lr.into_iter().map(|lr| lr.into()).collect())],
248        }
249    }
250}
251pub type GroupedPromiseResults<F> = HashMap<ComponentTypeId, ComponentPromiseResultsInMerkle<F>>;
252
253assert_impl_all!(ComponentPromiseResultsInMerkle<Fr>: Serialize, DeserializeOwned);
254
255pub const NUM_COMPONENT_OWNED_INSTANCES: usize = 2;
256
257pub trait ComponentCircuit<F: Field> {
258    fn clear_witnesses(&self);
259    /// Compute promise calls.
260    fn compute_promise_calls(&self) -> anyhow::Result<GroupedPromiseCalls>;
261    /// Feed inputs into the core builder. The `input` type should be the `CoreInput` type specified by the `CoreBuilder`.
262    /// It is the caller's responsibility to ensure that the capacity of the input
263    /// is equal to the configured capacity of the component circuit. This function
264    /// does **not** check this.
265    fn feed_input(&self, input: Box<dyn Any>) -> anyhow::Result<()>;
266    /// Fulfill promise results.
267    fn fulfill_promise_results(
268        &self,
269        promise_results: &GroupedPromiseResults<F>,
270    ) -> anyhow::Result<()>;
271    /// When inputs and promise results are ready, we can generate outputs of this component.
272    /// * When you call `compute_outputs`, `feed_inputs` must have already be called.
273    /// * Input capacity checking should happen when calling `feed_inputs`, not in this function. This function assumes that the input capacity is equal to the configured capacity of the component circuit.
274    // We don't have padding in the framework level because we don't have a formal interface to get a dummy input with capacity = 1. But even if we want to pad, it should happen when `feed_input`.
275    /// * The only goal of `compute_outputs` is to return the virtual table and its commit.
276    fn compute_outputs(&self) -> anyhow::Result<ComponentPromiseResultsInMerkle<F>>;
277    // Get public instances of this component.
278    fn get_public_instances(&self) -> ComponentPublicInstances<F>;
279}