axiom_eth/utils/component/
mod.rs1use std::{any::Any, collections::HashMap, fmt::Debug, hash::Hash, marker::PhantomData};
2
3use crate::{Field, RawField};
4use getset::Getters;
5use halo2_base::{
6 gates::{circuit::builder::BaseCircuitBuilder, GateInstructions, RangeChip},
7 halo2_proofs::halo2curves::bn256::Fr,
8 AssignedValue, Context,
9};
10use itertools::Itertools;
11use serde::{de::DeserializeOwned, Deserialize, Serialize};
12use static_assertions::assert_impl_all;
13
14use crate::rlc::chip::RlcChip;
15
16use self::{
17 promise_loader::{
18 comp_loader::{BasicComponentCommiter, ComponentCommiter},
19 flatten_witness_to_rlc,
20 },
21 types::{ComponentPublicInstances, FixLenLogical, Flatten},
22 utils::{into_key, try_from_key},
23};
24
25pub mod circuit;
26pub mod param;
27pub mod promise_collector;
28pub mod promise_loader;
29#[cfg(test)]
30mod tests;
31pub mod types;
32pub mod utils;
33
/// Numeric component identifier.
pub type ComponentId = u64;

/// String identifier of a component type (the value returned by `ComponentType::get_type_id`).
pub type ComponentTypeId = String;

/// The reserved [`ComponentId`] `0`, named for the user component.
pub const USER_COMPONENT_ID: ComponentId = 0;
37
38pub trait LogicalInputValue<F: Field>:
45 Debug + Send + Sync + Clone + Eq + Serialize + DeserializeOwned + 'static
46{
47 fn get_capacity(&self) -> usize;
50}
51pub trait PromiseCallWitness<F: Field>: Debug + Send + Sync + 'static {
53 fn get_component_type_id(&self) -> ComponentTypeId;
55 fn get_capacity(&self) -> usize;
57 fn to_rlc(
60 &self,
61 ctx_pair: (&mut Context<F>, &mut Context<F>),
62 range_chip: &RangeChip<F>,
63 rlc_chip: &RlcChip<F>,
64 ) -> AssignedValue<F>;
65 fn to_typeless_logical_input(&self) -> TypelessLogicalInput;
67 fn get_mock_output(&self) -> Flatten<F>;
69 fn as_any(&self) -> &dyn Any;
71}
72
73pub type FlattenVirtualTable<F> = Vec<FlattenVirtualRow<F>>;
75pub type FlattenVirtualRow<F> = (Flatten<F>, Flatten<F>);
77
78#[derive(Clone)]
80pub struct LogicalResult<F: Field, T: ComponentType<F>> {
81 pub input: T::LogicalInput,
82 pub output: T::OutputValue,
83 pub _marker: PhantomData<F>,
84}
85impl<F: Field, T: ComponentType<F>> LogicalResult<F, T> {
86 pub fn new(input: T::LogicalInput, output: T::OutputValue) -> Self {
88 Self { input, output, _marker: PhantomData }
89 }
90}
91impl<F: Field, T: ComponentType<F>> TryFrom<ComponentPromiseResult<F>> for LogicalResult<F, T> {
92 type Error = anyhow::Error;
93 fn try_from(value: ComponentPromiseResult<F>) -> Result<Self, Self::Error> {
94 let (input, output) = value;
95 let input = try_from_key::<T::LogicalInput>(&input)?;
96 Ok(Self::new(input, T::OutputValue::try_from_raw(output)?))
97 }
98}
99impl<F: Field, T: ComponentType<F>> From<LogicalResult<F, T>> for ComponentPromiseResult<F> {
100 fn from(value: LogicalResult<F, T>) -> Self {
101 let LogicalResult { input, output, .. } = value;
102 (into_key(input), output.into_raw())
103 }
104}
105impl<F: Field, T: ComponentType<F>> From<LogicalResult<F, T>> for Vec<FlattenVirtualRow<F>> {
106 fn from(value: LogicalResult<F, T>) -> Self {
107 let logical_virtual_rows = T::logical_result_to_virtual_rows(&value);
108 logical_virtual_rows
109 .into_iter()
110 .map(|(input, output)| (input.into(), output.into()))
111 .collect_vec()
112 }
113}
114pub trait ComponentType<F: Field>: 'static + Sized {
116 type InputValue: FixLenLogical<F>;
117 type InputWitness: FixLenLogical<AssignedValue<F>>;
118 type OutputValue: FixLenLogical<F>;
119 type OutputWitness: FixLenLogical<AssignedValue<F>>;
120 type LogicalInput: LogicalInputValue<F>;
121 type Commiter: ComponentCommiter<F> = BasicComponentCommiter<F>;
122
123 fn get_type_id() -> ComponentTypeId;
125 fn get_type_name() -> &'static str {
127 std::any::type_name::<Self>()
128 }
129
130 fn logical_result_to_virtual_rows(
132 ins: &LogicalResult<F, Self>,
133 ) -> Vec<(Self::InputValue, Self::OutputValue)> {
134 let v_rows = Self::logical_result_to_virtual_rows_impl(ins);
135 assert_eq!(v_rows.len(), ins.input.get_capacity());
136 v_rows
137 }
138 fn logical_result_to_virtual_rows_impl(
140 ins: &LogicalResult<F, Self>,
141 ) -> Vec<(Self::InputValue, Self::OutputValue)>;
142
143 fn logical_input_to_virtual_rows(li: &Self::LogicalInput) -> Vec<Self::InputValue> {
146 let v_rows = Self::logical_input_to_virtual_rows_impl(li);
147 assert_eq!(v_rows.len(), li.get_capacity());
148 v_rows
149 }
150 fn logical_input_to_virtual_rows_impl(li: &Self::LogicalInput) -> Vec<Self::InputValue>;
152
153 fn rlc_virtual_rows(
156 (gate_ctx, rlc_ctx): (&mut Context<F>, &mut Context<F>),
157 range_chip: &RangeChip<F>,
158 rlc_chip: &RlcChip<F>,
159 inputs: &[(Self::InputWitness, Self::OutputWitness)],
160 ) -> Vec<AssignedValue<F>> {
161 let input_multiplier = rlc_chip.rlc_pow_fixed(
162 gate_ctx,
163 &range_chip.gate,
164 Self::OutputWitness::get_num_fields(),
165 );
166
167 inputs
168 .iter()
169 .map(|(input, output)| {
170 let i_rlc = flatten_witness_to_rlc(rlc_ctx, rlc_chip, &input.clone().into());
171 let o_rlc = flatten_witness_to_rlc(rlc_ctx, rlc_chip, &output.clone().into());
172 range_chip.gate.mul_add(gate_ctx, i_rlc, input_multiplier, o_rlc)
173 })
174 .collect_vec()
175 }
176}
177
178pub type TypelessLogicalInput = Vec<u8>;
180#[derive(Debug, Clone, Serialize, Deserialize, Hash, PartialEq, Eq, PartialOrd, Ord)]
181pub struct TypelessPromiseCall {
182 pub capacity: usize,
183 pub logical_input: TypelessLogicalInput,
184}
185
186pub type GroupedPromiseCalls = HashMap<ComponentTypeId, Vec<TypelessPromiseCall>>;
188pub type ComponentPromiseResult<F> = (TypelessLogicalInput, Vec<F>);
190
191#[derive(Debug, Clone, Serialize, Deserialize, Hash, PartialEq, Eq)]
193pub struct PromiseShardMetadata<F: RawField> {
194 pub commit: F,
195 pub capacity: usize,
196}
197pub type SelectedDataShard<S> = (usize, S);
199pub type SelectedPromiseResultShard<F> = SelectedDataShard<Vec<ComponentPromiseResult<F>>>;
201
202#[derive(Debug, Clone, Getters, Serialize, Deserialize, Hash, PartialEq, Eq)]
203pub struct SelectedDataShardsInMerkle<F: RawField, S: Clone> {
204 #[getset(get = "pub")]
206 leaves: Vec<PromiseShardMetadata<F>>,
207 #[getset(get = "pub")]
209 shards: Vec<SelectedDataShard<S>>,
210}
211
212impl<F: Field, S: Clone> SelectedDataShardsInMerkle<F, S> {
213 pub fn new(leaves: Vec<PromiseShardMetadata<F>>, shards: Vec<SelectedDataShard<S>>) -> Self {
215 assert!(leaves.len().is_power_of_two());
216 Self { leaves, shards }
218 }
219 pub fn map_data<NS: Clone>(self, f: impl Fn(S) -> NS) -> SelectedDataShardsInMerkle<F, NS> {
221 SelectedDataShardsInMerkle::new(
222 self.leaves,
223 self.shards.into_iter().map(|(i, s)| (i, f(s))).collect(),
224 )
225 }
226}
227
228pub type ComponentPromiseResultsInMerkle<F> =
230 SelectedDataShardsInMerkle<F, Vec<ComponentPromiseResult<F>>>;
231
232impl<F: Field> ComponentPromiseResultsInMerkle<F> {
233 pub fn from_single_shard<T: ComponentType<F>>(lr: Vec<LogicalResult<F, T>>) -> Self {
235 let vt = lr.iter().flat_map(T::logical_result_to_virtual_rows).collect_vec();
236 let mut mock_builder = BaseCircuitBuilder::<F>::new(true).use_k(18).use_lookup_bits(8);
237 let ctx = mock_builder.main(0);
238 let witness_vt = vt
239 .into_iter()
240 .map(|(v_i, v_o)| (v_i.into().assign(ctx), v_o.into().assign(ctx)))
241 .collect_vec();
242 let witness_commit = T::Commiter::compute_commitment(&mut mock_builder, &witness_vt);
243 let commit = *witness_commit.value();
244 mock_builder.clear(); Self {
246 leaves: vec![PromiseShardMetadata::<F> { commit, capacity: witness_vt.len() }],
247 shards: vec![(0, lr.into_iter().map(|lr| lr.into()).collect())],
248 }
249 }
250}
251pub type GroupedPromiseResults<F> = HashMap<ComponentTypeId, ComponentPromiseResultsInMerkle<F>>;
252
253assert_impl_all!(ComponentPromiseResultsInMerkle<Fr>: Serialize, DeserializeOwned);
254
/// Number of public instances owned by a component (fixed at 2).
///
/// NOTE(review): confirm which instances these are against `ComponentPublicInstances`.
pub const NUM_COMPONENT_OWNED_INSTANCES: usize = 2;
256
257pub trait ComponentCircuit<F: Field> {
258 fn clear_witnesses(&self);
259 fn compute_promise_calls(&self) -> anyhow::Result<GroupedPromiseCalls>;
261 fn feed_input(&self, input: Box<dyn Any>) -> anyhow::Result<()>;
266 fn fulfill_promise_results(
268 &self,
269 promise_results: &GroupedPromiseResults<F>,
270 ) -> anyhow::Result<()>;
271 fn compute_outputs(&self) -> anyhow::Result<ComponentPromiseResultsInMerkle<F>>;
277 fn get_public_instances(&self) -> ComponentPublicInstances<F>;
279}