1use crate::{
2 septic_curve::SepticCurve, septic_digest::SepticDigest, septic_extension::SepticExtension,
3 PROOF_MAX_NUM_PVS,
4};
5use hashbrown::HashMap;
6use itertools::Itertools;
7use p3_air::Air;
8use p3_challenger::{CanObserve, FieldChallenger};
9use p3_commit::Pcs;
10use p3_field::{AbstractExtensionField, AbstractField, Field, PrimeField32};
11use p3_matrix::{dense::RowMajorMatrix, Dimensions, Matrix};
12use p3_maybe_rayon::prelude::*;
13use p3_uni_stark::{get_symbolic_constraints, SymbolicAirBuilder};
14use serde::{de::DeserializeOwned, Deserialize, Serialize};
15use std::{cmp::Reverse, env, fmt::Debug, iter::once, time::Instant};
16use tracing::instrument;
17
18use super::{debug_constraints, Dom};
19use crate::{
20 air::{InteractionScope, MachineAir, MachineProgram},
21 count_permutation_constraints,
22 lookup::{debug_interactions_with_all_chips, InteractionKind},
23 record::MachineRecord,
24 DebugConstraintBuilder, ShardProof, VerifierConstraintFolder,
25};
26
27use super::{
28 Chip, Com, MachineProof, PcsProverData, StarkGenericConfig, Val, VerificationError, Verifier,
29};
30
/// Shorthand for a [`Chip`] over the base field `Val<SC>` of the STARK config `SC`.
pub type MachineChip<SC, A> = Chip<Val<SC>, A>;
33
/// A STARK machine: a configuration together with the chips whose traces it proves.
pub struct StarkMachine<SC: StarkGenericConfig, A> {
    /// The STARK configuration (provides the PCS via `pcs()` and the challenger type).
    config: SC,
    /// Every chip of the machine; a particular shard uses the subset returned by `shard_chips`.
    chips: Vec<Chip<Val<SC>, A>>,
    /// Number of leading public-value elements the verifier observes for each shard.
    num_pv_elts: usize,
    /// Flag reported by `contains_global_bus`; presumably whether chips interact over a
    /// global bus — TODO confirm against callers.
    contains_global_bus: bool,
}
47
48impl<SC: StarkGenericConfig, A> StarkMachine<SC, A> {
49 pub const fn new(
51 config: SC,
52 chips: Vec<Chip<Val<SC>, A>>,
53 num_pv_elts: usize,
54 contains_global_bus: bool,
55 ) -> Self {
56 Self { config, chips, num_pv_elts, contains_global_bus }
57 }
58}
59
/// Proving key produced by `StarkMachine::setup_core`.
#[derive(Clone, Serialize, Deserialize)]
#[serde(bound(serialize = "PcsProverData<SC>: Serialize"))]
#[serde(bound(deserialize = "PcsProverData<SC>: DeserializeOwned"))]
pub struct StarkProvingKey<SC: StarkGenericConfig> {
    /// PCS commitment to the preprocessed traces.
    pub commit: Com<SC>,
    /// Starting program counter, taken from `program.pc_start()`.
    pub pc_start: Val<SC>,
    /// Initial global cumulative sum the per-shard global sums are added to.
    pub initial_global_cumulative_sum: SepticDigest<Val<SC>>,
    /// Preprocessed traces, ordered by decreasing height and then by chip name.
    pub traces: Vec<RowMajorMatrix<Val<SC>>>,
    /// PCS prover data returned alongside `commit`.
    pub data: PcsProverData<SC>,
    /// Maps a chip name to the index of its preprocessed trace in `traces`.
    pub chip_ordering: HashMap<String, usize>,
    /// `chip.local_only()` for each entry of `traces`, in the same order.
    pub local_only: Vec<bool>,
    /// Maps a chip name to its number of main plus permutation constraints.
    pub constraints_map: HashMap<String, usize>,
}
82
83impl<SC: StarkGenericConfig> StarkProvingKey<SC> {
84 pub fn observe_into(&self, challenger: &mut SC::Challenger) {
86 challenger.observe(self.commit.clone());
87 challenger.observe(self.pc_start);
88 challenger.observe_slice(&self.initial_global_cumulative_sum.0.x.0);
89 challenger.observe_slice(&self.initial_global_cumulative_sum.0.y.0);
90 challenger.observe(Val::<SC>::zero());
92 }
93}
94
/// Verifying key produced by `StarkMachine::setup_core`.
#[derive(Clone, Serialize, Deserialize)]
#[serde(bound(serialize = "Dom<SC>: Serialize"))]
#[serde(bound(deserialize = "Dom<SC>: DeserializeOwned"))]
pub struct StarkVerifyingKey<SC: StarkGenericConfig> {
    /// PCS commitment to the preprocessed traces.
    pub commit: Com<SC>,
    /// Starting program counter, taken from `program.pc_start()`.
    pub pc_start: Val<SC>,
    /// Initial global cumulative sum the per-shard global sums are added to.
    pub initial_global_cumulative_sum: SepticDigest<Val<SC>>,
    /// For each preprocessed trace: chip name, PCS domain, and trace dimensions.
    pub chip_information: Vec<(String, Dom<SC>, Dimensions)>,
    /// Maps a chip name to the index of its preprocessed trace.
    pub chip_ordering: HashMap<String, usize>,
}
111
112impl<SC: StarkGenericConfig> StarkVerifyingKey<SC> {
113 pub fn observe_into(&self, challenger: &mut SC::Challenger) {
115 challenger.observe(self.commit.clone());
116 challenger.observe(self.pc_start);
117 challenger.observe_slice(&self.initial_global_cumulative_sum.0.x.0);
118 challenger.observe_slice(&self.initial_global_cumulative_sum.0.y.0);
119 challenger.observe(Val::<SC>::zero());
121 }
122}
123
124impl<SC: StarkGenericConfig> Debug for StarkVerifyingKey<SC> {
125 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
126 f.debug_struct("VerifyingKey").finish()
127 }
128}
129
130impl<SC: StarkGenericConfig, A: MachineAir<Val<SC>>> StarkMachine<SC, A> {
131 pub fn shard_chips_ordered<'a, 'b>(
133 &'a self,
134 chip_ordering: &'b HashMap<String, usize>,
135 ) -> impl Iterator<Item = &'b MachineChip<SC, A>>
136 where
137 'a: 'b,
138 {
139 self.chips
140 .iter()
141 .filter(|chip| chip_ordering.contains_key(&chip.name()))
142 .sorted_by_key(|chip| chip_ordering.get(&chip.name()))
143 }
144
145 pub const fn config(&self) -> &SC {
147 &self.config
148 }
149
150 pub fn chips(&self) -> &[MachineChip<SC, A>] {
152 &self.chips
153 }
154
155 pub const fn num_pv_elts(&self) -> usize {
157 self.num_pv_elts
158 }
159
160 pub fn shard_chips<'a, 'b>(
162 &'a self,
163 shard: &'b A::Record,
164 ) -> impl Iterator<Item = &'b MachineChip<SC, A>>
165 where
166 'a: 'b,
167 {
168 self.chips.iter().filter(|chip| chip.included(shard))
169 }
170
171 #[instrument("debug constraints", level = "debug", skip_all)]
173 pub fn debug_constraints(
174 &self,
175 pk: &StarkProvingKey<SC>,
176 records: Vec<A::Record>,
177 challenger: &mut SC::Challenger,
178 ) where
179 SC::Val: PrimeField32,
180 A: for<'a> Air<DebugConstraintBuilder<'a, Val<SC>, SC::Challenge>>,
181 {
182 tracing::debug!("checking constraints for each shard");
183
184 let mut permutation_challenges: Vec<SC::Challenge> = Vec::new();
186 for _ in 0..2 {
187 permutation_challenges.push(challenger.sample_ext_element());
188 }
189
190 let mut global_cumulative_sums = Vec::new();
191 global_cumulative_sums.push(pk.initial_global_cumulative_sum);
192
193 for shard in records.iter() {
194 let chips = self.shard_chips(shard).collect::<Vec<_>>();
196
197 let pre_traces = chips
199 .iter()
200 .map(|chip| pk.chip_ordering.get(&chip.name()).map(|index| &pk.traces[*index]))
201 .collect::<Vec<_>>();
202 let mut traces = chips
203 .par_iter()
204 .map(|chip| chip.generate_trace(shard, &mut A::Record::default()))
205 .zip(pre_traces)
206 .collect::<Vec<_>>();
207
208 let mut permutation_traces = Vec::with_capacity(chips.len());
210 let mut chip_cumulative_sums = Vec::with_capacity(chips.len());
211 tracing::debug_span!("generate permutation traces").in_scope(|| {
212 chips
213 .par_iter()
214 .zip(traces.par_iter_mut())
215 .map(|(chip, (main_trace, pre_trace))| {
216 let (trace, local_sum) = chip.generate_permutation_trace(
217 *pre_trace,
218 main_trace,
219 &permutation_challenges,
220 );
221 let global_sum = if chip.commit_scope() == InteractionScope::Local {
222 SepticDigest::<Val<SC>>::zero()
223 } else {
224 let main_trace_size = main_trace.height() * main_trace.width();
225 let last_row =
226 &main_trace.values[main_trace_size - 14..main_trace_size];
227 SepticDigest(SepticCurve {
228 x: SepticExtension::<Val<SC>>::from_base_fn(|i| last_row[i]),
229 y: SepticExtension::<Val<SC>>::from_base_fn(|i| last_row[i + 7]),
230 })
231 };
232 (trace, (global_sum, local_sum))
233 })
234 .unzip_into_vecs(&mut permutation_traces, &mut chip_cumulative_sums);
235 });
236
237 let global_cumulative_sum =
238 chip_cumulative_sums.iter().map(|sums| sums.0).sum::<SepticDigest<Val<SC>>>();
239 global_cumulative_sums.push(global_cumulative_sum);
240
241 let local_cumulative_sum =
242 chip_cumulative_sums.iter().map(|sums| sums.1).sum::<SC::Challenge>();
243
244 if !local_cumulative_sum.is_zero() {
245 tracing::warn!("Local cumulative sum is not zero");
246 tracing::debug_span!("debug local interactions").in_scope(|| {
247 debug_interactions_with_all_chips::<SC, A>(
248 self,
249 pk,
250 std::slice::from_ref(shard),
251 InteractionKind::all_kinds(),
252 InteractionScope::Local,
253 )
254 });
255 panic!("Local cumulative sum is not zero");
256 }
257
258 for i in 0..chips.len() {
260 let trace_width = traces[i].0.width();
261 let pre_width = traces[i].1.map_or(0, p3_matrix::Matrix::width);
262 let permutation_width = permutation_traces[i].width() *
263 <SC::Challenge as AbstractExtensionField<SC::Val>>::D;
264 let total_width = trace_width + pre_width + permutation_width;
265 tracing::debug!(
266 "{:<11} | Main Cols = {:<5} | Pre Cols = {:<5} | Perm Cols = {:<5} | Rows = {:<10} | Cells = {:<10}",
267 chips[i].name(),
268 trace_width,
269 pre_width,
270 permutation_width,
271 traces[i].0.height(),
272 total_width * traces[i].0.height(),
273 );
274 }
275
276 if env::var("SKIP_CONSTRAINTS").is_err() {
277 tracing::info_span!("debug constraints").in_scope(|| {
278 for i in 0..chips.len() {
279 let preprocessed_trace =
280 pk.chip_ordering.get(&chips[i].name()).map(|index| &pk.traces[*index]);
281 debug_constraints::<SC, A>(
282 chips[i],
283 preprocessed_trace,
284 &traces[i].0,
285 &permutation_traces[i],
286 &permutation_challenges,
287 &shard.public_values(),
288 &chip_cumulative_sums[i].1,
289 &chip_cumulative_sums[i].0,
290 );
291 }
292 });
293 }
294 }
295
296 tracing::info!("Constraints verified successfully");
297
298 let global_cumulative_sum: SepticDigest<Val<SC>> =
299 global_cumulative_sums.iter().copied().sum();
300
301 if !global_cumulative_sum.is_zero() {
303 tracing::warn!("Global cumulative sum is not zero");
304 tracing::debug_span!("debug global interactions").in_scope(|| {
305 debug_interactions_with_all_chips::<SC, A>(
306 self,
307 pk,
308 &records,
309 InteractionKind::all_kinds(),
310 InteractionScope::Global,
311 )
312 });
313 panic!("Global cumulative sum is not zero");
314 }
315 }
316}
317
318impl<SC: StarkGenericConfig, A: MachineAir<Val<SC>> + Air<SymbolicAirBuilder<Val<SC>>>>
319 StarkMachine<SC, A>
320{
321 pub const fn contains_global_bus(&self) -> bool {
323 self.contains_global_bus
324 }
325
326 pub fn preprocessed_chip_ids(&self) -> Vec<usize> {
328 self.chips
329 .iter()
330 .enumerate()
331 .filter(|(_, chip)| chip.preprocessed_width() > 0)
332 .map(|(i, _)| i)
333 .collect()
334 }
335
336 pub fn chips_sorted_indices(&self, proof: &ShardProof<SC>) -> Vec<Option<usize>> {
338 self.chips().iter().map(|chip| proof.chip_ordering.get(&chip.name()).copied()).collect()
339 }
340
341 pub fn setup_core(
344 &self,
345 program: &A::Program,
346 initial_global_cumulative_sum: SepticDigest<Val<SC>>,
347 ) -> (StarkProvingKey<SC>, StarkVerifyingKey<SC>) {
348 let parent_span = tracing::debug_span!("generate preprocessed traces");
349 let (named_preprocessed_traces, num_constraints): (Vec<_>, Vec<_>) =
350 parent_span.in_scope(|| {
351 self.chips()
352 .par_iter()
353 .map(|chip| {
354 let chip_name = chip.name();
355 let begin = Instant::now();
356 let prep_trace = chip.generate_preprocessed_trace(program);
357 tracing::debug!(
358 parent: &parent_span,
359 "generated preprocessed trace for chip {} in {:?}",
360 chip_name,
361 begin.elapsed()
362 );
363 let expected_width =
365 prep_trace.as_ref().map_or(0, p3_matrix::Matrix::width);
366 assert_eq!(
367 expected_width,
368 chip.preprocessed_width(),
369 "Incorrect number of preprocessed columns for chip {chip_name}"
370 );
371
372 let num_main_constraints = get_symbolic_constraints(
374 &chip.air,
375 chip.preprocessed_width(),
376 PROOF_MAX_NUM_PVS,
377 )
378 .len();
379
380 let num_permutation_constraints = count_permutation_constraints(
381 &chip.sends,
382 &chip.receives,
383 chip.logup_batch_size(),
384 chip.air.commit_scope(),
385 );
386
387 (
388 prep_trace.map(move |t| (chip.name(), chip.local_only(), t)),
389 (chip_name, num_main_constraints + num_permutation_constraints),
390 )
391 })
392 .unzip()
393 });
394
395 let mut named_preprocessed_traces =
396 named_preprocessed_traces.into_iter().flatten().collect::<Vec<_>>();
397
398 named_preprocessed_traces
400 .sort_by_key(|(name, _, trace)| (Reverse(trace.height()), name.clone()));
401
402 let pcs = self.config.pcs();
403 let (chip_information, domains_and_traces): (Vec<_>, Vec<_>) = named_preprocessed_traces
404 .iter()
405 .map(|(name, _, trace)| {
406 let domain = pcs.natural_domain_for_degree(trace.height());
407 ((name.to_owned(), domain, trace.dimensions()), (domain, trace.to_owned()))
408 })
409 .unzip();
410
411 let (commit, data) = tracing::debug_span!("commit to preprocessed traces")
413 .in_scope(|| pcs.commit(domains_and_traces));
414
415 let chip_ordering = named_preprocessed_traces
417 .iter()
418 .enumerate()
419 .map(|(i, (name, _, _))| (name.to_owned(), i))
420 .collect::<HashMap<_, _>>();
421
422 let local_only = named_preprocessed_traces
423 .iter()
424 .map(|(_, local_only, _)| local_only.to_owned())
425 .collect::<Vec<_>>();
426
427 let constraints_map: HashMap<_, _> = num_constraints.into_iter().collect();
428
429 let traces =
431 named_preprocessed_traces.into_iter().map(|(_, _, trace)| trace).collect::<Vec<_>>();
432
433 let pc_start = program.pc_start();
434
435 (
436 StarkProvingKey {
437 commit: commit.clone(),
438 pc_start,
439 initial_global_cumulative_sum,
440 traces,
441 data,
442 chip_ordering: chip_ordering.clone(),
443 local_only,
444 constraints_map,
445 },
446 StarkVerifyingKey {
447 commit,
448 pc_start,
449 initial_global_cumulative_sum,
450 chip_information,
451 chip_ordering,
452 },
453 )
454 }
455
456 #[instrument("setup machine", level = "debug", skip_all)]
461 #[allow(clippy::map_unwrap_or)]
462 #[allow(clippy::redundant_closure_for_method_calls)]
463 pub fn setup(&self, program: &A::Program) -> (StarkProvingKey<SC>, StarkVerifyingKey<SC>) {
464 let initial_global_cumulative_sum = program.initial_global_cumulative_sum();
465 self.setup_core(program, initial_global_cumulative_sum)
466 }
467
468 #[allow(clippy::needless_for_each)]
470 pub fn generate_dependencies(
471 &self,
472 records: &mut [A::Record],
473 opts: &<A::Record as MachineRecord>::Config,
474 chips_filter: Option<&[String]>,
475 ) {
476 let chips = self
477 .chips
478 .iter()
479 .filter(|chip| {
480 if let Some(chips_filter) = chips_filter {
481 chips_filter.contains(&chip.name())
482 } else {
483 true
484 }
485 })
486 .collect::<Vec<_>>();
487
488 records.iter_mut().for_each(|record| {
489 chips.iter().for_each(|chip| {
490 let mut output = A::Record::default();
491 chip.generate_dependencies(record, &mut output);
492 record.append(&mut output);
493 });
494 tracing::debug_span!("register nonces").in_scope(|| record.register_nonces(opts));
495 });
496 }
497
498 #[instrument("verify", level = "info", skip_all)]
500 #[allow(clippy::match_bool)]
501 pub fn verify(
502 &self,
503 vk: &StarkVerifyingKey<SC>,
504 proof: &MachineProof<SC>,
505 challenger: &mut SC::Challenger,
506 ) -> Result<(), MachineVerificationError<SC>>
507 where
508 SC::Challenger: Clone,
509 A: for<'a> Air<VerifierConstraintFolder<'a, SC>>,
510 {
511 vk.observe_into(challenger);
513
514 if proof.shard_proofs.is_empty() {
516 return Err(MachineVerificationError::EmptyProof);
517 }
518
519 tracing::debug_span!("verify shard proofs").in_scope(|| {
520 for (i, shard_proof) in proof.shard_proofs.iter().enumerate() {
521 tracing::debug_span!("verifying shard", shard = i).in_scope(|| {
522 let chips =
523 self.shard_chips_ordered(&shard_proof.chip_ordering).collect::<Vec<_>>();
524 let mut shard_challenger = challenger.clone();
525 shard_challenger
526 .observe_slice(&shard_proof.public_values[0..self.num_pv_elts()]);
527 Verifier::verify_shard(
528 &self.config,
529 vk,
530 &chips,
531 &mut shard_challenger,
532 shard_proof,
533 )
534 .map_err(MachineVerificationError::InvalidShardProof)
535 })?;
536 }
537
538 Ok(())
539 })?;
540
541 tracing::debug_span!("verify global cumulative sum is 0").in_scope(|| {
543 let sum = proof
544 .shard_proofs
545 .iter()
546 .map(ShardProof::global_cumulative_sum)
547 .chain(once(vk.initial_global_cumulative_sum))
548 .sum::<SepticDigest<Val<SC>>>();
549
550 if !sum.is_zero() {
551 return Err(MachineVerificationError::NonZeroCumulativeSum(
552 InteractionScope::Global,
553 0,
554 ));
555 }
556
557 Ok(())
558 })
559 }
560}
561
/// Errors returned when verifying a full machine proof.
pub enum MachineVerificationError<SC: StarkGenericConfig> {
    /// A shard proof failed verification; wraps the shard-level error.
    InvalidShardProof(VerificationError<SC>),
    /// A global proof failed verification; wraps the underlying error.
    InvalidGlobalProof(VerificationError<SC>),
    /// A cumulative sum in the given scope was non-zero. The `usize` is a shard index
    /// (`verify` reports `0` for the global check).
    NonZeroCumulativeSum(InteractionScope, usize),
    /// The public values digest did not match.
    InvalidPublicValuesDigest,
    /// Debugging of interactions failed.
    DebugInteractionsFailed,
    /// The proof contained no shard proofs.
    EmptyProof,
    /// Public values were invalid; the string describes why.
    InvalidPublicValues(&'static str),
    /// The proof contained too many shards.
    TooManyShards,
    /// A chip occurred where it should not; the string names the chip or the problem.
    InvalidChipOccurrence(String),
    /// The first shard did not contain the CPU chip.
    MissingCpuInFirstShard,
    /// The CPU chip's log degree exceeded the allowed maximum; carries the offending value.
    CpuLogDegreeTooLarge(usize),
    /// The verification key was rejected.
    InvalidVerificationKey,
}
589
590impl<SC: StarkGenericConfig> Debug for MachineVerificationError<SC> {
591 #[allow(clippy::uninlined_format_args)]
592 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
593 match self {
594 MachineVerificationError::InvalidShardProof(e) => {
595 write!(f, "Invalid shard proof: {:?}", e)
596 }
597 MachineVerificationError::InvalidGlobalProof(e) => {
598 write!(f, "Invalid global proof: {:?}", e)
599 }
600 MachineVerificationError::NonZeroCumulativeSum(scope, shard) => {
601 write!(f, "Non-zero cumulative sum. Scope: {}, Shard: {}", scope, shard)
602 }
603 MachineVerificationError::InvalidPublicValuesDigest => {
604 write!(f, "Invalid public values digest")
605 }
606 MachineVerificationError::EmptyProof => {
607 write!(f, "Empty proof")
608 }
609 MachineVerificationError::DebugInteractionsFailed => {
610 write!(f, "Debug interactions failed")
611 }
612 MachineVerificationError::InvalidPublicValues(s) => {
613 write!(f, "Invalid public values: {}", s)
614 }
615 MachineVerificationError::TooManyShards => {
616 write!(f, "Too many shards")
617 }
618 MachineVerificationError::InvalidChipOccurrence(s) => {
619 write!(f, "Invalid chip occurrence: {}", s)
620 }
621 MachineVerificationError::MissingCpuInFirstShard => {
622 write!(f, "Missing CPU in first shard")
623 }
624 MachineVerificationError::CpuLogDegreeTooLarge(log_degree) => {
625 write!(f, "CPU log degree too large: {}", log_degree)
626 }
627 MachineVerificationError::InvalidVerificationKey => {
628 write!(f, "Invalid verification key")
629 }
630 }
631 }
632}
633
634impl<SC: StarkGenericConfig> std::fmt::Display for MachineVerificationError<SC> {
635 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
636 Debug::fmt(self, f)
637 }
638}
639
/// Lets `MachineVerificationError` be used as a standard error; relies on its `Debug` and
/// `Display` impls and provides no `source`.
impl<SC: StarkGenericConfig> std::error::Error for MachineVerificationError<SC> {}
641
642impl<SC: StarkGenericConfig> MachineVerificationError<SC> {
643 pub fn is_constraints_failing(&self, expected_chip_name: &str) -> bool {
645 if let MachineVerificationError::InvalidShardProof(
646 VerificationError::OodEvaluationMismatch(chip_name),
647 ) = self
648 {
649 return chip_name == expected_chip_name;
650 }
651
652 false
653 }
654
655 pub fn is_local_cumulative_sum_failing(&self) -> bool {
657 matches!(
658 self,
659 MachineVerificationError::InvalidShardProof(VerificationError::CumulativeSumsError(
660 "local cumulative sum is not zero"
661 ))
662 )
663 }
664}