1use std::hash::Hash;
2
3use p3_air::{Air, BaseAir, PairBuilder};
4use p3_field::{ExtensionField, Field, PrimeField, PrimeField32};
5use p3_matrix::dense::RowMajorMatrix;
6use p3_uni_stark::{get_max_constraint_degree, SymbolicAirBuilder};
7use p3_util::log2_ceil_usize;
8
9use crate::{
10 air::{InteractionScope, MachineAir, MultiTableAirBuilder, SP1AirBuilder},
11 local_permutation_trace_width,
12 lookup::{Interaction, InteractionBuilder, InteractionKind},
13};
14
15use super::{
16 eval_permutation_constraints, generate_permutation_trace, scoped_interactions,
17 PROOF_MAX_NUM_PVS,
18};
19
20pub struct Chip<F: Field, A> {
22 pub air: A,
24 pub sends: Vec<Interaction<F>>,
26 pub receives: Vec<Interaction<F>>,
28 pub log_quotient_degree: usize,
30}
31
32impl<F: Field, A> Chip<F, A> {
33 pub fn sends(&self) -> &[Interaction<F>] {
35 &self.sends
36 }
37
38 pub fn receives(&self) -> &[Interaction<F>] {
40 &self.receives
41 }
42
43 pub const fn log_quotient_degree(&self) -> usize {
45 self.log_quotient_degree
46 }
47
48 pub fn into_inner(self) -> A {
50 self.air
51 }
52}
53
54impl<F: PrimeField32, A: MachineAir<F>> Chip<F, A> {
55 pub fn included(&self, shard: &A::Record) -> bool {
57 self.air.included(shard)
58 }
59}
60
61impl<F, A> Chip<F, A>
62where
63 F: Field,
64 A: BaseAir<F>,
65{
66 pub fn new(air: A) -> Self
68 where
69 A: MachineAir<F> + Air<InteractionBuilder<F>> + Air<SymbolicAirBuilder<F>>,
70 {
71 let mut builder = InteractionBuilder::new(air.preprocessed_width(), air.width());
72 air.eval(&mut builder);
73 let (sends, receives) = builder.interactions();
74
75 let nb_byte_sends = sends.iter().filter(|s| s.kind == InteractionKind::Byte).count();
76 let nb_byte_receives = receives.iter().filter(|r| r.kind == InteractionKind::Byte).count();
77 tracing::debug!(
78 "chip {} has {} byte interactions",
79 air.name(),
80 nb_byte_sends + nb_byte_receives
81 );
82
83 let mut max_constraint_degree =
84 get_max_constraint_degree(&air, air.preprocessed_width(), PROOF_MAX_NUM_PVS);
85
86 if !sends.is_empty() || !receives.is_empty() {
87 max_constraint_degree = max_constraint_degree.max(3);
88 }
89 assert!(max_constraint_degree > 0);
90 let log_quotient_degree = log2_ceil_usize(max_constraint_degree - 1);
91
92 Self { air, sends, receives, log_quotient_degree }
93 }
94
95 #[inline]
97 pub fn num_interactions(&self) -> usize {
98 self.sends.len() + self.receives.len()
99 }
100
101 #[inline]
103 pub fn num_sent_byte_lookups(&self) -> usize {
104 self.sends.iter().filter(|i| i.kind == InteractionKind::Byte).count()
105 }
106
107 #[inline]
109 pub fn num_sends_by_kind(&self, kind: InteractionKind) -> usize {
110 self.sends.iter().filter(|i| i.kind == kind).count()
111 }
112
113 #[inline]
115 pub fn num_receives_by_kind(&self, kind: InteractionKind) -> usize {
116 self.receives.iter().filter(|i| i.kind == kind).count()
117 }
118
119 pub fn generate_permutation_trace<EF: ExtensionField<F>>(
121 &self,
122 preprocessed: Option<&RowMajorMatrix<F>>,
123 main: &RowMajorMatrix<F>,
124 random_elements: &[EF],
125 ) -> (RowMajorMatrix<EF>, EF)
126 where
127 F: PrimeField,
128 A: MachineAir<F>,
129 {
130 let batch_size = self.logup_batch_size();
131 generate_permutation_trace::<F, EF>(
132 &self.sends,
133 &self.receives,
134 preprocessed,
135 main,
136 random_elements,
137 batch_size,
138 )
139 }
140
141 #[inline]
143 pub fn permutation_width(&self) -> usize {
144 let (scoped_sends, scoped_receives) = scoped_interactions(self.sends(), self.receives());
145 let empty = Vec::new();
146 let local_sends = scoped_sends.get(&InteractionScope::Local).unwrap_or(&empty);
147 let local_receives = scoped_receives.get(&InteractionScope::Local).unwrap_or(&empty);
148
149 local_permutation_trace_width(
150 local_sends.len() + local_receives.len(),
151 self.logup_batch_size(),
152 )
153 }
154
155 #[inline]
157 pub fn cost(&self) -> u64
158 where
159 A: MachineAir<F>,
160 {
161 let preprocessed_cols = self.preprocessed_width();
162 let main_cols = self.width();
163 let permutation_cols = self.permutation_width() * 4;
164 let quotient_cols = self.quotient_width() * 4;
165 (preprocessed_cols + main_cols + permutation_cols + quotient_cols) as u64
166 }
167
168 #[inline]
170 pub const fn quotient_width(&self) -> usize {
171 1 << self.log_quotient_degree
172 }
173
174 #[inline]
176 pub const fn logup_batch_size(&self) -> usize {
177 1 << self.log_quotient_degree
178 }
179}
180
181impl<F, A> BaseAir<F> for Chip<F, A>
182where
183 F: Field,
184 A: BaseAir<F>,
185{
186 fn width(&self) -> usize {
187 self.air.width()
188 }
189
190 fn preprocessed_trace(&self) -> Option<RowMajorMatrix<F>> {
191 panic!("Chip should not use the `BaseAir` method, but the `MachineAir` method.")
192 }
193}
194
195impl<F, A> MachineAir<F> for Chip<F, A>
196where
197 F: Field,
198 A: MachineAir<F>,
199{
200 type Record = A::Record;
201
202 type Program = A::Program;
203
204 fn name(&self) -> String {
205 self.air.name()
206 }
207
208 fn preprocessed_width(&self) -> usize {
209 <A as MachineAir<F>>::preprocessed_width(&self.air)
210 }
211
212 fn preprocessed_num_rows(&self, program: &Self::Program, instrs_len: usize) -> Option<usize> {
213 <A as MachineAir<F>>::preprocessed_num_rows(&self.air, program, instrs_len)
214 }
215
216 fn generate_preprocessed_trace(&self, program: &A::Program) -> Option<RowMajorMatrix<F>> {
217 <A as MachineAir<F>>::generate_preprocessed_trace(&self.air, program)
218 }
219
220 fn num_rows(&self, input: &A::Record) -> Option<usize> {
221 <A as MachineAir<F>>::num_rows(&self.air, input)
222 }
223
224 fn generate_trace(&self, input: &A::Record, output: &mut A::Record) -> RowMajorMatrix<F> {
225 self.air.generate_trace(input, output)
226 }
227
228 fn generate_dependencies(&self, input: &A::Record, output: &mut A::Record) {
229 self.air.generate_dependencies(input, output);
230 }
231
232 fn included(&self, shard: &Self::Record) -> bool {
233 self.air.included(shard)
234 }
235
236 fn commit_scope(&self) -> crate::air::InteractionScope {
237 self.air.commit_scope()
238 }
239
240 fn local_only(&self) -> bool {
241 self.air.local_only()
242 }
243}
244
245impl<'a, F, A, AB> Air<AB> for Chip<F, A>
247where
248 F: Field,
249 A: Air<AB> + MachineAir<F>,
250 AB: SP1AirBuilder<F = F> + MultiTableAirBuilder<'a> + PairBuilder + 'a,
251{
252 fn eval(&self, builder: &mut AB) {
253 self.air.eval(builder);
255 let batch_size = self.logup_batch_size();
257 eval_permutation_constraints(
258 &self.sends,
259 &self.receives,
260 batch_size,
261 self.air.commit_scope(),
262 builder,
263 );
264 }
265}
266
267impl<F, A> PartialEq for Chip<F, A>
268where
269 F: Field,
270 A: PartialEq,
271{
272 fn eq(&self, other: &Self) -> bool {
273 self.air == other.air
274 }
275}
276
277impl<F: Field, A: Eq> Eq for Chip<F, A> where F: Field + Eq {}
278
279impl<F, A> Hash for Chip<F, A>
280where
281 F: Field,
282 A: Hash,
283{
284 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
285 self.air.hash(state);
286 }
287}