use std::borrow::Cow;
use std::collections::{BTreeMap, BTreeSet};
use std::sync::Arc;
use itertools::Itertools;
use slop_air::{Air, BaseAir};
use slop_algebra::{
interpolate_univariate_polynomial, AbstractField, ExtensionField, Field, UnivariatePolynomial,
};
use slop_alloc::{Buffer, HasBackend};
use slop_challenger::{FieldChallenger, VariableLengthChallenger};
use slop_multilinear::Point;
use slop_sumcheck::PartialSumcheckProof;
use slop_tensor::Tensor;
use sp1_gpu_air::ir::{
analyze_constraints, build_dag, chunk_dag, enumerate_lowerings, lower_column_tile,
lower_sequential, ChunkBudget, ChunkBytecode, ColumnTileBytecode, DagBuilder, Lowering,
};
use sp1_gpu_cudart::sys::kernels::{
zerocheck_aggregate_partials_kernel, zerocheck_column_tile_ext_kernel,
zerocheck_column_tile_kb_kernel, zerocheck_fix_geq_state_kernel,
zerocheck_fused_sequential_ext_1024_kernel, zerocheck_fused_sequential_ext_128_kernel,
zerocheck_fused_sequential_ext_256_kernel, zerocheck_fused_sequential_ext_32_kernel,
zerocheck_fused_sequential_ext_512_kernel, zerocheck_fused_sequential_ext_64_kernel,
zerocheck_fused_sequential_kb_1024_kernel, zerocheck_fused_sequential_kb_128_kernel,
zerocheck_fused_sequential_kb_256_kernel, zerocheck_fused_sequential_kb_32_kernel,
zerocheck_fused_sequential_kb_512_kernel, zerocheck_fused_sequential_kb_64_kernel,
zerocheck_geq_corrections_kernel, zerocheck_gkr_sweep_ext_kernel,
zerocheck_gkr_sweep_kb_kernel, zerocheck_pad_adj_1024_kernel, zerocheck_pad_adj_128_kernel,
zerocheck_pad_adj_256_kernel, zerocheck_pad_adj_32_kernel, zerocheck_pad_adj_512_kernel,
zerocheck_pad_adj_64_kernel,
};
use sp1_gpu_cudart::sys::runtime::KernelPtr;
use sp1_gpu_cudart::{args, DeviceBuffer, DeviceCopy, DevicePoint, TaskScope};
use sp1_gpu_utils::{Ext, Felt, JaggedTraceMle};
use sp1_hypercube::air::MachineAir;
use sp1_hypercube::prover::ZerocheckAir;
use sp1_hypercube::{
AirOpenedValues, Chip, ChipEvaluation, ChipOpenedValues, LogUpEvaluations, ShardOpenedValues,
};
use crate::challenger_update;
use crate::primitives::{evaluate_jagged_fix_last_variable, JaggedFixLastVariableKernel};
/// Host-side bytecode for a single constraint chunk, tagged by the lowering that produced it.
#[derive(Debug, Clone)]
pub(crate) enum CompiledChunk {
/// Bytecode returned by `lower_sequential`.
Sequential(ChunkBytecode),
/// Bytecode returned by `lower_column_tile`.
ColumnTile(ColumnTileBytecode),
}
/// Host-side compilation result for one chip, as produced by `compile_chips`.
#[derive(Debug, Clone)]
pub(crate) struct CompiledChip {
/// Position of the chip in the iteration order of the chip set it was compiled from.
pub chip_idx: u32,
/// The AIR's name (`air.name()`).
pub name: String,
/// The AIR's main trace width (`air.width()`).
pub main_width: u32,
/// The AIR's preprocessed trace width (`air.preprocessed_width()`).
pub prep_width: u32,
/// One compiled entry per chunk returned by `chunk_dag`, in chunk order.
pub chunks: Vec<CompiledChunk>,
}
/// Compiles the constraints of every chip into host-side chunk bytecode.
///
/// Each chip's AIR is turned into a DAG, analysed, and split into chunks under `budget`. A chunk
/// is lowered as a column tile when such a lowering is offered and succeeds; otherwise it falls
/// back to the sequential lowering. The first sequential chunk of each chip additionally records
/// the chip's main/preprocessed widths in its `gkr_*_width` fields.
///
/// # Panics
///
/// Panics if a chunk has no sequential lowering to fall back to, or if a sequential chunk needs
/// more than 1024 registers.
pub(crate) fn compile_chips<A>(
    chips: &BTreeSet<Chip<Felt, A>>,
    budget: ChunkBudget,
) -> Vec<CompiledChip>
where
    A: MachineAir<Felt> + for<'a> Air<DagBuilder<'a>>,
{
    // Register cap enforced on every sequentially lowered chunk.
    const MAX_FUSED_REGS: u16 = 1024;

    let started = std::time::Instant::now();
    let mut compiled = Vec::with_capacity(chips.len());

    for (position, chip) in chips.iter().enumerate() {
        let air: &A = chip.air.as_ref();
        let dag = build_dag(air);
        let infos = analyze_constraints(&dag);
        let chunk_metas = chunk_dag(&infos, &budget);

        let mut chunks: Vec<CompiledChunk> = Vec::new();
        for meta in &chunk_metas {
            let candidates = enumerate_lowerings(meta, &infos, &dag);

            // First choice: a column-tile lowering, if one is offered and it lowers successfully.
            let tiled = candidates
                .iter()
                .find_map(|cand| if let Lowering::ColumnTile(plan) = cand { Some(plan) } else { None })
                .and_then(|plan| lower_column_tile(meta, &infos, &dag, plan));
            if let Some(bytecode) = tiled {
                chunks.push(CompiledChunk::ColumnTile(bytecode));
                continue;
            }

            // Fallback: the sequential lowering, which must always be available.
            let seq_plan = candidates
                .iter()
                .find_map(|cand| if let Lowering::Sequential(plan) = cand { Some(plan) } else { None })
                .expect("every chunk must have a Sequential lowering");
            let bytecode = lower_sequential(meta, &infos, &dag, seq_plan);
            assert!(
                bytecode.max_reg <= MAX_FUSED_REGS,
                "chip {}: chunk max_reg={} exceeds fused-kernel cap ({}); \
                 reduce CHUNKER_MAX_LEAFSET or implement the oversize-singleton \
                 escape valve",
                air.name(),
                bytecode.max_reg,
                MAX_FUSED_REGS,
            );
            if std::env::var("SP1_GPU_DEBUG_MAXREG").is_ok() {
                eprintln!(
                    "compile chip={} max_reg={} n_instrs={} n_asserts={}",
                    air.name(),
                    bytecode.max_reg,
                    bytecode.instrs.len(),
                    bytecode.asserts.len()
                );
            }
            chunks.push(CompiledChunk::Sequential(bytecode));
        }

        let main_width = air.width() as u32;
        let prep_width = air.preprocessed_width() as u32;

        // Only the first sequential chunk (if any) carries the chip's widths.
        for chunk in chunks.iter_mut() {
            if let CompiledChunk::Sequential(carrier) = chunk {
                carrier.gkr_main_width = main_width;
                carrier.gkr_prep_width = prep_width;
                break;
            }
        }

        compiled.push(CompiledChip {
            chip_idx: position as u32,
            name: air.name().to_string(),
            main_width,
            prep_width,
            chunks,
        });
    }

    if std::env::var("SP1_GPU_ZEROCHECK_TIMING").is_ok() {
        tracing::info!("compile_chips: {} chips in {:?}", compiled.len(), started.elapsed());
    }
    compiled
}
/// C-layout descriptor of one sequential chunk, uploaded to the device and handed to the fused
/// sequential and pad-adjustment kernels.
///
/// `build_seq_tiers` fills the pointer fields from `ChunkDeviceBufs`, whose pointers are offsets
/// into the flat buffers uploaded by `upload_compiled_bytecode`.
/// NOTE(review): the field order is presumably mirrored by a struct on the kernel side — confirm
/// before reordering.
#[repr(C)]
#[derive(Debug, Clone, Copy)]
pub struct ChunkStaticC {
/// Start of this chunk's instruction stream.
pub instrs: *const sp1_gpu_air::ir::DagInstr,
/// Start of this chunk's leaf references.
pub leaves: *const sp1_gpu_air::ir::LeafRef,
/// Start of this chunk's constants.
pub consts: *const Felt,
/// Start of this chunk's public-value entries.
pub publics: *const u32,
/// Start of this chunk's assert register list (first component of each assert pair).
pub assert_regs: *const u16,
/// Start of this chunk's assert alpha list (second component of each assert pair).
pub assert_alphas: *const u32,
/// Number of entries behind `instrs`.
pub n_instrs: u32,
/// Number of entries behind `assert_regs` / `assert_alphas`.
pub n_asserts: u32,
/// Index of the chip this chunk belongs to.
pub chip_idx: u32,
/// Main width recorded on the chunk; zeroed by `build_seq_tiers` for decoupled-GKR chips.
pub gkr_main_width: u32,
/// Preprocessed width recorded on the chunk; zeroed by `build_seq_tiers` for decoupled-GKR chips.
pub gkr_prep_width: u32,
/// Per-chip alpha offset, looked up by `chip_idx`.
pub chip_alpha_offset: u32,
}
/// C-layout pair of trace widths for one chip, uploaded as a per-chip device array.
#[repr(C)]
#[derive(Debug, Clone, Copy)]
pub(crate) struct ChipGkrInfoC {
/// Main trace width of the chip.
pub main_width: u32,
/// Preprocessed trace width of the chip.
pub prep_width: u32,
}
// `ChunkStaticC` holds raw pointers, so `Send`/`Sync` are not derived automatically.
// SAFETY: NOTE(review): presumably sound because the pointers are device addresses that the host
// never dereferences and that stay valid while the owning `MachineBytecode` is alive — confirm.
unsafe impl Send for ChunkStaticC {}
unsafe impl Sync for ChunkStaticC {}
/// C-layout per-chip trace layout record.
///
/// The device array of these is written by `jagged_chip_layouts_kernel` and passed to the
/// evaluation kernels.
/// NOTE(review): field meanings below are inferred from the names — confirm against the kernel.
#[repr(C)]
#[derive(Debug, Clone, Copy)]
pub struct ChipLayoutC {
/// Presumably the location of the chip's main trace data.
pub main_ptr: u64,
/// Presumably the location of the chip's preprocessed trace data.
pub preprocessed_ptr: u64,
/// Presumably the chip's current height.
pub height: u32,
/// Explicit padding word.
pub _pad: u32,
}
/// C-layout work item describing one tile of rows for a kernel block.
#[repr(C)]
#[derive(Debug, Clone, Copy)]
pub struct BlockDispatchC {
/// Index of the work unit: the chunk's position within its tier for sequential dispatch, or
/// the chip index for GKR dispatch.
pub chunk_id: u32,
/// First row of the tile.
pub row_offset: u32,
/// Number of rows in the tile.
pub n_rows: u32,
}
/// C-layout per-chip state uploaded for the geq-correction kernels.
///
/// `initialize_zerocheck_poly` creates one entry per initial chip height.
#[repr(C)]
#[derive(Debug, Clone, Copy)]
pub struct VirtualGeqStateC {
/// Initialised to the chip's initial height.
pub threshold: u32,
/// Initialised to the dimension of `zeta`.
pub num_vars: u32,
/// Initialised to one.
pub geq_coefficient: Ext,
/// Initialised to zero.
pub eq_coefficient: Ext,
}
/// C-layout description of where one chip's columns sit in the flattened column order, as built
/// by `build_chip_column_layouts`.
#[repr(C)]
#[derive(Debug, Clone, Copy)]
pub struct ChipColumnLayoutEntry {
/// Sum of the preprocessed widths of all chips that precede this one.
pub prep_col_idx: u32,
/// Main-section start column plus the sum of the main widths of all preceding chips.
pub main_col_idx: u32,
/// Preprocessed width of this chip.
pub prep_width: u32,
/// Main width of this chip.
pub main_width: u32,
}
/// Host-side bookkeeping of column heights for a shard, kept in "pair" units.
///
/// `chip_height_elements` doubles a pair height to obtain an element count, so one pair unit
/// corresponds to two elements.
pub struct ShardLayoutTracker {
    /// Per-chip preprocessed height, in pairs.
    pub chip_prep_h_pair: Vec<u32>,
    /// Per-chip main height, in pairs.
    pub chip_main_h_pair: Vec<u32>,
    /// Heights of the preprocessed padding columns.
    pub prep_padding_h_pair: Vec<u32>,
    /// Heights of the main padding columns.
    pub main_padding_h_pair: Vec<u32>,
    /// Per-chip preprocessed width.
    pub chip_prep_w: Vec<u32>,
    /// Per-chip main width.
    pub chip_main_w: Vec<u32>,
}

impl ShardLayoutTracker {
    /// Replaces every tracked height `h` with `ceil(h / 4) * 2`; widths are left untouched.
    #[inline]
    pub fn fold(&mut self) {
        fn shrink(heights: &mut [u32]) {
            for height in heights {
                *height = height.div_ceil(4) * 2;
            }
        }
        shrink(&mut self.chip_prep_h_pair);
        shrink(&mut self.chip_main_h_pair);
        shrink(&mut self.prep_padding_h_pair);
        shrink(&mut self.main_padding_h_pair);
    }

    /// Returns the total length in pairs: `width * height` over every chip's preprocessed and
    /// main sections, plus the height of every padding column.
    #[inline]
    pub fn total_length_pair(&self) -> u32 {
        let area = |widths: &[u32], heights: &[u32]| -> u32 {
            let mut acc = 0u32;
            for (w, h) in widths.iter().zip(heights) {
                acc += w * h;
            }
            acc
        };
        let chip_total = area(&self.chip_prep_w, &self.chip_prep_h_pair)
            + area(&self.chip_main_w, &self.chip_main_h_pair);

        let mut prep_padding = 0u32;
        for h in &self.prep_padding_h_pair {
            prep_padding += h;
        }
        let mut main_padding = 0u32;
        for h in &self.main_padding_h_pair {
            main_padding += h;
        }
        chip_total + (prep_padding + main_padding)
    }

    /// Returns the height of chip `chip_idx` in elements (twice its pair height).
    ///
    /// The main section is used when the chip has main columns, otherwise the preprocessed
    /// section; a chip with neither reports 0.
    ///
    /// # Panics
    ///
    /// Panics if `chip_idx` is out of range for the vectors that are consulted.
    #[inline]
    pub fn chip_height_elements(&self, chip_idx: usize) -> u32 {
        let pairs = if self.chip_main_w[chip_idx] > 0 {
            self.chip_main_h_pair[chip_idx]
        } else if self.chip_prep_w[chip_idx] > 0 {
            self.chip_prep_h_pair[chip_idx]
        } else {
            return 0;
        };
        pairs * 2
    }
}
/// Per-chunk view into the flat bytecode buffers uploaded by `upload_compiled_bytecode`.
///
/// Every pointer is the base pointer of the corresponding flat buffer advanced by this chunk's
/// start offset. For streams a chunk kind does not use, the pointer marks the current end of the
/// stream and the matching count is 0.
#[derive(Clone, Copy)]
pub(crate) struct ChunkDeviceBufs {
/// Which lowering produced this chunk.
pub kind: ChunkKind,
/// Start of the chunk's leaf references.
pub leaves: *const sp1_gpu_air::ir::LeafRef,
/// Start of the chunk's constants.
pub consts: *const Felt,
/// Start of the chunk's public-value entries.
pub publics: *const u32,
/// Start of the chunk's instructions (sequential chunks only).
pub instrs: *const sp1_gpu_air::ir::DagInstr,
/// Start of the chunk's assert registers (sequential chunks only).
pub assert_regs: *const u16,
/// Start of the chunk's assert alphas (sequential chunks only).
pub assert_alphas: *const u32,
/// `max_reg` of the sequential bytecode; 0 for column-tile chunks.
pub max_reg: u16,
/// Number of instructions; 0 for column-tile chunks.
pub n_instrs: u32,
/// Number of asserts; 0 for column-tile chunks.
pub n_asserts: u32,
/// `gkr_main_width` of the sequential bytecode; 0 for column-tile chunks.
pub gkr_main_width: u32,
/// `gkr_prep_width` of the sequential bytecode; 0 for column-tile chunks.
pub gkr_prep_width: u32,
/// Start of the chunk's column terms (column-tile chunks only).
pub terms: *const sp1_gpu_air::ir::ColumnTermEntry,
/// Number of column terms; 0 for sequential chunks.
pub n_terms: u32,
}
/// Discriminates how a chunk was lowered; mirrors the variants of `CompiledChunk`.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub(crate) enum ChunkKind {
/// The chunk came from `CompiledChunk::Sequential`.
Sequential,
/// The chunk came from `CompiledChunk::ColumnTile`.
ColumnTile,
}
/// Device-side counterpart of `CompiledChip`: the chip's metadata plus per-chunk device pointers.
#[derive(Clone)]
pub(crate) struct CompiledChipDevice {
/// Copied from `CompiledChip::chip_idx`.
pub chip_idx: u32,
/// Copied from `CompiledChip::main_width`.
pub main_width: u32,
/// Copied from `CompiledChip::prep_width`.
pub prep_width: u32,
/// One entry per compiled chunk, in the same order as `CompiledChip::chunks`.
pub chunks: Vec<ChunkDeviceBufs>,
}
/// Uploaded bytecode for a whole machine: the flat device buffers plus per-chip chunk views.
///
/// The `_flat_*` buffers are never read through these fields; the raw pointers stored in
/// `chips[..].chunks` point into them.
pub struct MachineBytecode {
// Flat device buffers, one per bytecode stream.
_flat_instrs: Buffer<sp1_gpu_air::ir::DagInstr, TaskScope>,
_flat_leaves: Buffer<sp1_gpu_air::ir::LeafRef, TaskScope>,
_flat_consts: Buffer<Felt, TaskScope>,
_flat_publics: Buffer<u32, TaskScope>,
_flat_assert_regs: Buffer<u16, TaskScope>,
_flat_assert_alphas: Buffer<u32, TaskScope>,
_flat_terms: Buffer<sp1_gpu_air::ir::ColumnTermEntry, TaskScope>,
/// Per-chip device views, in the order the chips were compiled.
pub(crate) chips: Vec<CompiledChipDevice>,
/// Maps a chip name to its position in `chips`.
pub(crate) chip_index: BTreeMap<String, usize>,
}
// SAFETY: NOTE(review): presumably sound because the raw pointers inside `chips` refer to the
// buffers owned by this same struct and are only dereferenced on the device — confirm.
unsafe impl Send for MachineBytecode {}
unsafe impl Sync for MachineBytecode {}
/// Compiles `chips` under `budget` and uploads the resulting bytecode to the device.
///
/// Convenience wrapper: runs `compile_chips` and passes its output to
/// `upload_compiled_bytecode`.
pub fn upload_machine_bytecode<A>(
    chips: &BTreeSet<Chip<Felt, A>>,
    budget: ChunkBudget,
    scope: &TaskScope,
) -> MachineBytecode
where
    A: MachineAir<Felt> + for<'a> Air<DagBuilder<'a>>,
{
    let host_bytecode = compile_chips(chips, budget);
    upload_compiled_bytecode(host_bytecode, scope)
}
/// Flattens host bytecode into one array per stream, uploads each array to the device, and
/// returns per-chunk pointers into the uploaded buffers.
///
/// The streams are instructions, leaves, constants, publics, assert registers, assert alphas and
/// column terms. Every chunk of every chip appends its data to the relevant streams and remembers
/// the `(offset, len)` of what it appended. After upload, each chunk's pointers are the buffer
/// base pointer advanced by its offset.
///
/// A stream that ends up empty is padded with one zeroed element before upload, so every buffer
/// has a non-empty allocation to point into.
///
/// # Panics
///
/// Panics if any host-to-device copy fails.
pub(crate) fn upload_compiled_bytecode(
    compiled: Vec<CompiledChip>,
    scope: &TaskScope,
) -> MachineBytecode {
    let mut flat_instrs: Vec<sp1_gpu_air::ir::DagInstr> = Vec::new();
    let mut flat_leaves: Vec<sp1_gpu_air::ir::LeafRef> = Vec::new();
    let mut flat_consts: Vec<Felt> = Vec::new();
    let mut flat_publics: Vec<u32> = Vec::new();
    let mut flat_assert_regs: Vec<u16> = Vec::new();
    let mut flat_assert_alphas: Vec<u32> = Vec::new();
    let mut flat_terms: Vec<sp1_gpu_air::ir::ColumnTermEntry> = Vec::new();

    /// `(offset, len)` of one chunk's slice in each flat stream, plus its scalar metadata.
    struct ChunkOffsets {
        kind: ChunkKind,
        leaves: (usize, usize),
        consts: (usize, usize),
        publics: (usize, usize),
        instrs: (usize, usize),
        assert_regs: (usize, usize),
        assert_alphas: (usize, usize),
        terms: (usize, usize),
        max_reg: u16,
        gkr_main_width: u32,
        gkr_prep_width: u32,
    }
    let mut chip_offsets: Vec<Vec<ChunkOffsets>> = Vec::with_capacity(compiled.len());

    /// Appends `src` to `dst` and returns the `(offset, len)` of the appended range.
    fn extend_flat<T: Copy>(dst: &mut Vec<T>, src: &[T]) -> (usize, usize) {
        let off = dst.len();
        dst.extend_from_slice(src);
        (off, src.len())
    }

    for chip in &compiled {
        let mut chunks = Vec::with_capacity(chip.chunks.len());
        for c in chip.chunks.iter() {
            chunks.push(match c {
                CompiledChunk::Sequential(bc) => {
                    // Split the (register, alpha) assert pairs into two parallel streams.
                    let regs: Vec<u16> = bc.asserts.iter().map(|&(r, _)| r).collect();
                    let alphas: Vec<u32> = bc.asserts.iter().map(|&(_, a)| a).collect();
                    ChunkOffsets {
                        kind: ChunkKind::Sequential,
                        leaves: extend_flat(&mut flat_leaves, &bc.leaves),
                        consts: extend_flat(&mut flat_consts, &bc.consts),
                        publics: extend_flat(&mut flat_publics, &bc.publics),
                        instrs: extend_flat(&mut flat_instrs, &bc.instrs),
                        assert_regs: extend_flat(&mut flat_assert_regs, &regs),
                        assert_alphas: extend_flat(&mut flat_assert_alphas, &alphas),
                        // Sequential chunks contribute no column terms.
                        terms: (flat_terms.len(), 0),
                        max_reg: bc.max_reg,
                        gkr_main_width: bc.gkr_main_width,
                        gkr_prep_width: bc.gkr_prep_width,
                    }
                }
                CompiledChunk::ColumnTile(bc) => ChunkOffsets {
                    kind: ChunkKind::ColumnTile,
                    leaves: extend_flat(&mut flat_leaves, &bc.leaves),
                    consts: extend_flat(&mut flat_consts, &bc.consts),
                    publics: extend_flat(&mut flat_publics, &bc.publics),
                    // Column-tile chunks contribute no instructions or asserts.
                    instrs: (flat_instrs.len(), 0),
                    assert_regs: (flat_assert_regs.len(), 0),
                    assert_alphas: (flat_assert_alphas.len(), 0),
                    terms: extend_flat(&mut flat_terms, &bc.terms),
                    max_reg: 0,
                    gkr_main_width: 0,
                    gkr_prep_width: 0,
                },
            });
        }
        chip_offsets.push(chunks);
    }

    /// Uploads `v` to the device, first padding an empty `v` with a single zeroed element.
    fn upload_flat<T: Copy + 'static>(v: &mut Vec<T>, scope: &TaskScope) -> Buffer<T, TaskScope> {
        if v.is_empty() {
            // SAFETY: NOTE(review): relies on every `T` used below being plain data for which
            // the all-zero bit pattern is a valid value — confirm for the `sp1_gpu_air::ir`
            // types. The placeholder is never referenced by a chunk with a non-zero length.
            v.push(unsafe { std::mem::zeroed() });
        }
        DeviceBuffer::from_host_slice(v, scope).unwrap().into_inner()
    }

    if std::env::var("SP1_GPU_ZEROCHECK_TIMING").is_ok() {
        let mb = |n: usize, sz: usize| (n * sz) as f64 / (1024.0 * 1024.0);
        tracing::info!(
            "upload_compiled_bytecode: {} chips, flat bytes — instrs={:.1}M leaves={:.1}M \
             consts={:.1}M publics={:.1}M assert_regs={:.1}M assert_alphas={:.1}M terms={:.1}M",
            compiled.len(),
            mb(flat_instrs.len(), std::mem::size_of::<sp1_gpu_air::ir::DagInstr>()),
            mb(flat_leaves.len(), std::mem::size_of::<sp1_gpu_air::ir::LeafRef>()),
            mb(flat_consts.len(), std::mem::size_of::<Felt>()),
            mb(flat_publics.len(), 4),
            mb(flat_assert_regs.len(), 2),
            mb(flat_assert_alphas.len(), 4),
            mb(flat_terms.len(), std::mem::size_of::<sp1_gpu_air::ir::ColumnTermEntry>()),
        );
    }

    let flat_instrs_buf = upload_flat(&mut flat_instrs, scope);
    let flat_leaves_buf = upload_flat(&mut flat_leaves, scope);
    let flat_consts_buf = upload_flat(&mut flat_consts, scope);
    let flat_publics_buf = upload_flat(&mut flat_publics, scope);
    let flat_assert_regs_buf = upload_flat(&mut flat_assert_regs, scope);
    let flat_assert_alphas_buf = upload_flat(&mut flat_assert_alphas, scope);
    let flat_terms_buf = upload_flat(&mut flat_terms, scope);

    let instrs_base = flat_instrs_buf.as_ptr();
    let leaves_base = flat_leaves_buf.as_ptr();
    let consts_base = flat_consts_buf.as_ptr();
    let publics_base = flat_publics_buf.as_ptr();
    let assert_regs_base = flat_assert_regs_buf.as_ptr();
    let assert_alphas_base = flat_assert_alphas_buf.as_ptr();
    let terms_base = flat_terms_buf.as_ptr();

    let mut device_chips = Vec::with_capacity(compiled.len());
    let mut chip_index = BTreeMap::new();
    for (chip, offsets) in compiled.iter().zip(chip_offsets.iter()) {
        // SAFETY (for every `add` below): each offset was recorded as the length of the host
        // vector at append time, so it is at most the length of the uploaded buffer; the result
        // stays within, or one past the end of, that allocation.
        let chunks = offsets
            .iter()
            .map(|o| ChunkDeviceBufs {
                kind: o.kind,
                leaves: unsafe { leaves_base.add(o.leaves.0) },
                consts: unsafe { consts_base.add(o.consts.0) },
                publics: unsafe { publics_base.add(o.publics.0) },
                instrs: unsafe { instrs_base.add(o.instrs.0) },
                assert_regs: unsafe { assert_regs_base.add(o.assert_regs.0) },
                assert_alphas: unsafe { assert_alphas_base.add(o.assert_alphas.0) },
                terms: unsafe { terms_base.add(o.terms.0) },
                max_reg: o.max_reg,
                n_instrs: o.instrs.1 as u32,
                n_asserts: o.assert_regs.1 as u32,
                gkr_main_width: o.gkr_main_width,
                gkr_prep_width: o.gkr_prep_width,
                n_terms: o.terms.1 as u32,
            })
            .collect();
        chip_index.insert(chip.name.clone(), device_chips.len());
        device_chips.push(CompiledChipDevice {
            chip_idx: chip.chip_idx,
            main_width: chip.main_width,
            prep_width: chip.prep_width,
            chunks,
        });
    }

    MachineBytecode {
        _flat_instrs: flat_instrs_buf,
        _flat_leaves: flat_leaves_buf,
        _flat_consts: flat_consts_buf,
        _flat_publics: flat_publics_buf,
        _flat_assert_regs: flat_assert_regs_buf,
        _flat_assert_alphas: flat_assert_alphas_buf,
        _flat_terms: flat_terms_buf,
        chips: device_chips,
        chip_index,
    }
}
/// All host and device state needed to evaluate the zerocheck polynomial over a jagged trace.
///
/// Built by `initialize_zerocheck_poly`; the notes below describe the values that function
/// stores.
pub(crate) struct ZeroCheckJaggedPoly<'b, K: Field> {
/// The jagged trace; borrowed at initialisation.
pub data: Cow<'b, JaggedTraceMle<K, TaskScope>>,
/// Per-chip device chunk descriptors.
pub compiled: Vec<CompiledChipDevice>,
/// Shared handle to the uploaded bytecode.
/// NOTE(review): presumably held so the raw pointers in `compiled` stay valid — confirm.
pub machine_bytecode: Arc<MachineBytecode>,
/// Initialised to one.
pub eq_adjustment: Ext,
/// Evaluation point supplied by the caller.
pub zeta: Point<Ext>,
/// Claim supplied by the caller.
pub claim: Ext,
/// Per-chip output of `compute_padded_row_adjustment`, kept on the host.
pub padded_row_adjustment_host: Vec<Ext>,
/// Device copy of the public values.
pub public_values: Buffer<Felt, TaskScope>,
/// Device copy of the alpha powers.
pub powers_of_alpha: Buffer<Ext, TaskScope>,
/// Device copy of the GKR powers.
pub gkr_powers: Buffer<Ext, TaskScope>,
/// Device copy of the lambda powers.
pub powers_of_lambda: Buffer<Ext, TaskScope>,
/// Host-side height/width bookkeeping from `build_layout_tracker`.
pub layout_tracker: ShardLayoutTracker,
/// Device copy of the output of `build_chip_column_layouts`.
pub chip_column_layouts_dev: Buffer<ChipColumnLayoutEntry, TaskScope>,
/// Per-chip layouts written on the device by `launch_chip_layouts_kernel`.
pub chip_layouts_dev: Buffer<ChipLayoutC, TaskScope>,
/// Per chip: the maximum constraint count over all chips minus this chip's constraint count.
pub chip_alpha_offset: Vec<u32>,
/// Sequential chunks grouped into two tiers by `build_seq_tiers`.
pub seq_tiers: [SeqTierStatic; 2],
/// Device array with one `VirtualGeqStateC` per initial height.
pub chip_geq_state_dev: Buffer<VirtualGeqStateC, TaskScope>,
/// Device copy of `padded_row_adjustment_host`.
pub chip_pad_adj_dev: Buffer<Ext, TaskScope>,
/// Device list of chips that have a sequential chunk and a non-zero padded-row adjustment;
/// `None` when there are no such chips.
pub geq_chip_indices_dev: Option<Buffer<u32, TaskScope>>,
/// Number of entries in `geq_chip_indices_dev`.
pub n_geq_chips: usize,
/// Device array of per-chip `(main_width, prep_width)`.
pub chip_gkr_info_dev: Buffer<ChipGkrInfoC, TaskScope>,
/// Chips with non-zero total width that either use decoupled GKR or have no sequential chunk.
pub gkr_active_chips: Vec<u32>,
/// Reusable device buffers for the per-tier sequential dispatch lists; start as `None`.
pub cached_seq_dispatch: [Option<Buffer<BlockDispatchC, TaskScope>>; 2],
/// Reusable device buffer for the GKR dispatch list; starts as `None`.
pub cached_gkr_dispatch: Option<Buffer<BlockDispatchC, TaskScope>>,
/// Single `u32`, zeroed at initialisation.
/// NOTE(review): this and the two buffers below look like scratch space for a device scan
/// used elsewhere — confirm.
pub scan_block_counter: Buffer<u32, TaskScope>,
/// Allocated with capacity `initial_n_blocks + 1`; not filled at initialisation.
pub scan_flags: Buffer<u32, TaskScope>,
/// Allocated with capacity `initial_n_blocks + 1`; not filled at initialisation.
pub scan_values: Buffer<u32, TaskScope>,
}
/// One tier of sequential chunks that is launched together with a single kernel variant.
pub(crate) struct SeqTierStatic {
/// Host copy of the chunk descriptors in this tier.
pub static_host: Vec<ChunkStaticC>,
/// Chip index of each chunk in `static_host`, pushed in lockstep with it.
pub chip_indices: Vec<u32>,
/// Largest `max_reg` of any chunk in the tier; used to pick the kernel variant.
pub max_reg: u16,
/// Device copy of `static_host`; `None` when the tier is empty.
pub static_buf: Option<Buffer<ChunkStaticC, TaskScope>>,
}
/// Selects the evaluation kernels for a trace whose elements have type `K`.
pub(crate) trait EvalKernels<K: Field> {
/// Kernel used for column-tile chunks.
fn column_tile_kernel() -> KernelPtr;
/// Fused sequential kernel variant able to serve a tier whose largest register count is
/// `max_reg_in_tier`.
fn fused_sequential_kernel_for(max_reg_in_tier: u16) -> KernelPtr;
/// Kernel used for the GKR sweep.
fn gkr_sweep_kernel() -> KernelPtr;
}
/// Kernel selection for base-field (`Felt`) traces: the `*_kb_*` kernel family.
impl EvalKernels<Felt> for TaskScope {
    fn column_tile_kernel() -> KernelPtr {
        // SAFETY: NOTE(review): presumably a plain FFI accessor without preconditions — confirm.
        unsafe { zerocheck_column_tile_kb_kernel() }
    }

    /// Picks the smallest of the 32/64/128/256/512/1024 variants whose size is at least
    /// `max_reg_in_tier`; anything above 512 maps to the 1024 variant.
    fn fused_sequential_kernel_for(max_reg_in_tier: u16) -> KernelPtr {
        // SAFETY: NOTE(review): presumably plain FFI accessors without preconditions — confirm.
        unsafe {
            match max_reg_in_tier {
                0..=32 => zerocheck_fused_sequential_kb_32_kernel(),
                33..=64 => zerocheck_fused_sequential_kb_64_kernel(),
                65..=128 => zerocheck_fused_sequential_kb_128_kernel(),
                129..=256 => zerocheck_fused_sequential_kb_256_kernel(),
                257..=512 => zerocheck_fused_sequential_kb_512_kernel(),
                _ => zerocheck_fused_sequential_kb_1024_kernel(),
            }
        }
    }

    fn gkr_sweep_kernel() -> KernelPtr {
        // SAFETY: NOTE(review): presumably a plain FFI accessor without preconditions — confirm.
        unsafe { zerocheck_gkr_sweep_kb_kernel() }
    }
}
/// Kernel selection for extension-field (`Ext`) traces: the `*_ext_*` kernel family.
impl EvalKernels<Ext> for TaskScope {
    fn column_tile_kernel() -> KernelPtr {
        // SAFETY: NOTE(review): presumably a plain FFI accessor without preconditions — confirm.
        unsafe { zerocheck_column_tile_ext_kernel() }
    }

    /// Picks the smallest of the 32/64/128/256/512/1024 variants whose size is at least
    /// `max_reg_in_tier`; anything above 512 maps to the 1024 variant.
    fn fused_sequential_kernel_for(max_reg_in_tier: u16) -> KernelPtr {
        // SAFETY: NOTE(review): presumably plain FFI accessors without preconditions — confirm.
        unsafe {
            match max_reg_in_tier {
                0..=32 => zerocheck_fused_sequential_ext_32_kernel(),
                33..=64 => zerocheck_fused_sequential_ext_64_kernel(),
                65..=128 => zerocheck_fused_sequential_ext_128_kernel(),
                129..=256 => zerocheck_fused_sequential_ext_256_kernel(),
                257..=512 => zerocheck_fused_sequential_ext_512_kernel(),
                _ => zerocheck_fused_sequential_ext_1024_kernel(),
            }
        }
    }

    fn gkr_sweep_kernel() -> KernelPtr {
        // SAFETY: NOTE(review): presumably a plain FFI accessor without preconditions — confirm.
        unsafe { zerocheck_gkr_sweep_ext_kernel() }
    }
}
/// Builds the initial `ZeroCheckJaggedPoly` for a base-field trace.
///
/// Uploads the host-side inputs (public values and the alpha / GKR / lambda powers) to the
/// device, derives the per-chip layout and dispatch metadata, runs the chip-layout and
/// padded-row-adjustment kernels, and allocates the scan buffers.
///
/// # Panics
///
/// Panics if any host-to-device copy or kernel launch fails, or if `chip_idx` of a chip with a
/// sequential chunk is not a valid index into the padded-row-adjustment vector.
#[allow(clippy::too_many_arguments)]
pub(crate) fn initialize_zerocheck_poly<'b, A>(
data: &'b JaggedTraceMle<Felt, TaskScope>,
chips: &BTreeSet<Chip<Felt, A>>,
compiled_chips_dev: Vec<CompiledChipDevice>,
machine_bytecode: Arc<MachineBytecode>,
initial_heights: Vec<u32>,
public_values: Vec<Felt>,
powers_of_alpha: Vec<Ext>,
gkr_powers: Vec<Ext>,
powers_of_lambda: Vec<Ext>,
zeta: Point<Ext>,
claim: Ext,
) -> ZeroCheckJaggedPoly<'b, Felt>
where
A: MachineAir<Felt>,
{
let scope = data.dense().backend();
// Host-side height bookkeeping and the per-chip column layout table.
let layout_tracker = build_layout_tracker(chips, data);
let chip_column_layouts_host = build_chip_column_layouts(chips);
let chip_column_layouts_dev =
DeviceBuffer::from_host_slice(&chip_column_layouts_host, scope).unwrap().into_inner();
// The per-chip layout array is allocated here and then written by the kernel launched below.
let mut chip_layouts_dev =
Buffer::<ChipLayoutC, _>::with_capacity_in(chips.len(), scope.clone());
unsafe { chip_layouts_dev.assume_init() };
launch_chip_layouts_kernel(data, &chip_column_layouts_dev, &mut chip_layouts_dev);
// Each chip's alpha offset is its distance from the largest constraint count (at least 1).
let max_num_constraints =
chips.iter().map(|c| c.num_constraints).max().unwrap_or(1).max(1) as u32;
let chip_alpha_offset: Vec<u32> =
chips.iter().map(|c| max_num_constraints - c.num_constraints as u32).collect();
// Upload the caller-supplied vectors.
let public_values_device =
DeviceBuffer::from_host(&Buffer::from(public_values), scope).unwrap().into_inner();
let powers_of_alpha_device =
DeviceBuffer::from_host(&Buffer::from(powers_of_alpha), scope).unwrap().into_inner();
let gkr_powers_device =
DeviceBuffer::from_host(&Buffer::from(gkr_powers), scope).unwrap().into_inner();
let powers_of_lambda_device =
DeviceBuffer::from_host(&Buffer::from(powers_of_lambda), scope).unwrap().into_inner();
let seq_tiers = build_seq_tiers(&compiled_chips_dev, &chip_alpha_offset, scope);
let chip_gkr_info_host: Vec<ChipGkrInfoC> = compiled_chips_dev
.iter()
.map(|chip| ChipGkrInfoC { main_width: chip.main_width, prep_width: chip.prep_width })
.collect();
let chip_gkr_info_dev =
DeviceBuffer::from_host_slice(&chip_gkr_info_host, scope).unwrap().into_inner();
// A chip is GKR-active when it has columns and either uses decoupled GKR or has no
// sequential chunk.
let gkr_active_chips: Vec<u32> = compiled_chips_dev
.iter()
.filter(|chip| chip.main_width + chip.prep_width > 0)
.filter(|chip| {
let has_seq_carrier =
chip.chunks.iter().any(|c| matches!(c.kind, ChunkKind::Sequential));
chip_uses_decoupled_gkr(chip.main_width, chip.prep_width) || !has_seq_carrier
})
.map(|chip| chip.chip_idx)
.collect();
// One geq state per initial height: threshold = height, coefficients (1, 0).
let num_vars = zeta.dimension() as u32;
let geq_state_host: Vec<VirtualGeqStateC> = initial_heights
.iter()
.map(|&h| VirtualGeqStateC {
threshold: h,
num_vars,
geq_coefficient: Ext::one(),
eq_coefficient: Ext::zero(),
})
.collect();
let chip_geq_state_dev =
DeviceBuffer::from_host_slice(&geq_state_host, scope).unwrap().into_inner();
let padded_row_adjustment = compute_padded_row_adjustment(
compiled_chips_dev.len(),
&seq_tiers,
&public_values_device,
&powers_of_alpha_device,
scope,
);
let chip_pad_adj_dev =
DeviceBuffer::from_host_slice(&padded_row_adjustment, scope).unwrap().into_inner();
// Only chips with a sequential chunk and a non-zero adjustment are listed for geq corrections.
let geq_chip_indices_host: Vec<u32> = compiled_chips_dev
.iter()
.filter(|chip| {
chip.chunks.iter().any(|c| matches!(c.kind, ChunkKind::Sequential))
&& padded_row_adjustment[chip.chip_idx as usize] != Ext::zero()
})
.map(|chip| chip.chip_idx)
.collect();
let n_geq_chips = geq_chip_indices_host.len();
let geq_chip_indices_dev = if n_geq_chips > 0 {
Some(DeviceBuffer::from_host_slice(&geq_chip_indices_host, scope).unwrap().into_inner())
} else {
None
};
// Scan buffers are sized from the number of columns divided into metadata sections.
let section_size =
unsafe { sp1_gpu_cudart::sys::kernels::jagged_fold_metadata_section_size() } as usize;
let initial_n_blocks = data.column_heights.len().div_ceil(section_size).max(1);
let scan_block_counter = {
let mut b = Buffer::<u32, _>::with_capacity_in(1, scope.clone());
b.write_bytes(0, std::mem::size_of::<u32>()).unwrap();
b
};
let scan_flags = Buffer::<u32, _>::with_capacity_in(initial_n_blocks + 1, scope.clone());
let scan_values = Buffer::<u32, _>::with_capacity_in(initial_n_blocks + 1, scope.clone());
ZeroCheckJaggedPoly {
data: Cow::Borrowed(data),
compiled: compiled_chips_dev,
machine_bytecode,
eq_adjustment: Ext::one(),
zeta,
claim,
padded_row_adjustment_host: padded_row_adjustment,
public_values: public_values_device,
powers_of_alpha: powers_of_alpha_device,
gkr_powers: gkr_powers_device,
powers_of_lambda: powers_of_lambda_device,
layout_tracker,
chip_column_layouts_dev,
chip_layouts_dev,
chip_alpha_offset,
seq_tiers,
chip_geq_state_dev,
chip_pad_adj_dev,
geq_chip_indices_dev,
n_geq_chips,
chip_gkr_info_dev,
gkr_active_chips,
cached_seq_dispatch: [None, None],
cached_gkr_dispatch: None,
scan_block_counter,
scan_flags,
scan_values,
}
}
/// Sequential chunks whose `max_reg` exceeds this go to tier 1 when tier splitting is enabled.
const TIER_SPLIT_MAX_REG: u16 = 256;

/// Combined (main + preprocessed) width above which a chip uses the decoupled GKR path.
pub(crate) const WIDE_GKR_THRESHOLD: u32 = 256;

/// Returns `true` when the chip's combined width is strictly greater than
/// [`WIDE_GKR_THRESHOLD`].
fn chip_uses_decoupled_gkr(main_width: u32, prep_width: u32) -> bool {
    let combined_width = main_width + prep_width;
    WIDE_GKR_THRESHOLD < combined_width
}
/// Groups every sequential chunk into one of two tiers and uploads each non-empty tier.
///
/// All chunks go to tier 0 unless splitting is enabled. Splitting is enabled only when at least
/// one sequential chunk has `max_reg > TIER_SPLIT_MAX_REG` and such chunks make up at most a
/// tenth of all sequential chunks; those chunks then go to tier 1.
///
/// For chips that use decoupled GKR the chunk's `gkr_*_width` fields are written as 0.
///
/// # Panics
///
/// Panics if a chip index is out of range for `chip_alpha_offset` or if an upload fails.
fn build_seq_tiers(
    compiled: &[CompiledChipDevice],
    chip_alpha_offset: &[u32],
    scope: &TaskScope,
) -> [SeqTierStatic; 2] {
    // Count sequential chunks, and how many of them are register-heavy.
    let total_seq_count = compiled
        .iter()
        .flat_map(|chip| chip.chunks.iter())
        .filter(|chunk| chunk.kind == ChunkKind::Sequential)
        .count();
    let tier1_candidate_count = compiled
        .iter()
        .flat_map(|chip| chip.chunks.iter())
        .filter(|chunk| chunk.kind == ChunkKind::Sequential && chunk.max_reg > TIER_SPLIT_MAX_REG)
        .count();
    let do_tier_split = total_seq_count > 0
        && tier1_candidate_count > 0
        && tier1_candidate_count * 10 <= total_seq_count;

    let empty_tier = || SeqTierStatic {
        static_host: Vec::new(),
        chip_indices: Vec::new(),
        max_reg: 0,
        static_buf: None,
    };
    let mut tiers: [SeqTierStatic; 2] = [empty_tier(), empty_tier()];

    for chip in compiled {
        let chip_idx = chip.chip_idx;
        let decoupled = chip_uses_decoupled_gkr(chip.main_width, chip.prep_width);
        for chunk in chip.chunks.iter().filter(|c| c.kind == ChunkKind::Sequential) {
            let heavy = do_tier_split && chunk.max_reg > TIER_SPLIT_MAX_REG;
            let target = &mut tiers[usize::from(heavy)];
            target.max_reg = target.max_reg.max(chunk.max_reg);
            target.chip_indices.push(chip_idx);

            let (gkr_main_width, gkr_prep_width) =
                if decoupled { (0, 0) } else { (chunk.gkr_main_width, chunk.gkr_prep_width) };
            target.static_host.push(ChunkStaticC {
                instrs: chunk.instrs,
                leaves: chunk.leaves,
                consts: chunk.consts,
                publics: chunk.publics,
                assert_regs: chunk.assert_regs,
                assert_alphas: chunk.assert_alphas,
                n_instrs: chunk.n_instrs,
                n_asserts: chunk.n_asserts,
                chip_idx,
                gkr_main_width,
                gkr_prep_width,
                chip_alpha_offset: chip_alpha_offset[chip_idx as usize],
            });
        }
    }

    // Empty tiers keep `static_buf == None`.
    for tier in &mut tiers {
        if tier.static_host.is_empty() {
            continue;
        }
        let uploaded = DeviceBuffer::from_host_slice(&tier.static_host, scope).unwrap();
        tier.static_buf = Some(uploaded.into_inner());
    }
    tiers
}
/// Copies `host_data` into the cached device buffer and returns a reference to it.
///
/// A new buffer is allocated when the cache is empty or its capacity is smaller than
/// `max(host_data.len(), 1)`; otherwise the existing allocation is reused. In both cases the
/// buffer's length is reset to 0 before the host data is appended.
///
/// # Panics
///
/// Panics if the host-to-device copy fails.
fn refill_buffer<'a, T: Copy + DeviceCopy>(
    cache: &'a mut Option<Buffer<T, TaskScope>>,
    host_data: &[T],
    scope: &TaskScope,
) -> &'a Buffer<T, TaskScope> {
    let required = host_data.len().max(1);
    let too_small = match cache.as_ref() {
        Some(existing) => existing.capacity() < required,
        None => true,
    };
    if too_small {
        *cache = Some(Buffer::with_capacity_in(required, scope.clone()));
    }

    let target = cache.as_mut().unwrap();
    // SAFETY: NOTE(review): shrinking the length to 0 presumably cannot expose uninitialised
    // elements — confirm against `Buffer::set_len`'s contract.
    unsafe {
        target.set_len(0);
    }
    target.extend_from_host_slice(host_data).unwrap();
    cache.as_ref().unwrap()
}
/// Returns the pad-adjustment kernel variant for a tier whose largest register count is
/// `max_reg_in_tier`.
///
/// Picks the smallest of the 32/64/128/256/512/1024 variants whose size is at least
/// `max_reg_in_tier`; anything above 512 maps to the 1024 variant.
fn pad_adj_kernel_for(max_reg_in_tier: u16) -> KernelPtr {
    // SAFETY: NOTE(review): presumably plain FFI accessors without preconditions — confirm.
    unsafe {
        match max_reg_in_tier {
            0..=32 => zerocheck_pad_adj_32_kernel(),
            33..=64 => zerocheck_pad_adj_64_kernel(),
            65..=128 => zerocheck_pad_adj_128_kernel(),
            129..=256 => zerocheck_pad_adj_256_kernel(),
            257..=512 => zerocheck_pad_adj_512_kernel(),
            _ => zerocheck_pad_adj_1024_kernel(),
        }
    }
}
/// Runs the pad-adjustment kernel over every non-empty tier and sums the per-chunk results into
/// a per-chip vector of length `n_chips`.
///
/// Chips without any sequential chunk keep an adjustment of zero.
///
/// # Panics
///
/// Panics if a non-empty tier has no uploaded `static_buf`, if the kernel launch fails, or if a
/// chip index in a tier is not below `n_chips`.
fn compute_padded_row_adjustment(
    n_chips: usize,
    seq_tiers: &[SeqTierStatic; 2],
    public_values: &Buffer<Felt, TaskScope>,
    powers_of_alpha: &Buffer<Ext, TaskScope>,
    scope: &TaskScope,
) -> Vec<Ext> {
    const PAD_ADJ_BLOCK_SIZE: u32 = 64;

    let mut per_chip = vec![Ext::zero(); n_chips];
    for tier in seq_tiers.iter().filter(|tier| !tier.static_host.is_empty()) {
        let chunk_count = tier.static_host.len();
        let static_buf = tier.static_buf.as_ref().unwrap();

        // One output slot per chunk; the kernel is expected to write every slot.
        let mut output: Tensor<Ext, TaskScope> =
            Tensor::with_sizes_in([chunk_count], scope.clone());
        unsafe {
            output.assume_init();
        }

        let grid_blocks = (chunk_count as u32).div_ceil(PAD_ADJ_BLOCK_SIZE);
        unsafe {
            let args = args!(
                static_buf.as_ptr(),
                (chunk_count as u32),
                public_values.as_ptr(),
                powers_of_alpha.as_ptr(),
                output.as_mut_ptr()
            );
            scope
                .launch_kernel(
                    pad_adj_kernel_for(tier.max_reg),
                    (grid_blocks, 1u32, 1u32),
                    (PAD_ADJ_BLOCK_SIZE, 1u32, 1u32),
                    &args,
                    0,
                )
                .unwrap();
        }

        // Fold the per-chunk values into their owning chips.
        let per_chunk: Vec<Ext> = unsafe { output.into_buffer().copy_into_host_vec() };
        tier.chip_indices.iter().enumerate().for_each(|(slot, &chip_idx)| {
            per_chip[chip_idx as usize] += per_chunk[slot];
        });
    }
    per_chip
}
/// Computes, for every chip in set order, where its preprocessed and main columns start.
///
/// Preprocessed columns are numbered from 0. Main columns are numbered from
/// `total preprocessed width + 1`.
/// NOTE(review): the `+ 1` presumably skips a separator/padding column between the two
/// sections — confirm.
fn build_chip_column_layouts<A>(chips: &BTreeSet<Chip<Felt, A>>) -> Vec<ChipColumnLayoutEntry>
where
    A: MachineAir<Felt>,
{
    let main_section_start_col = chips.iter().map(|c| c.preprocessed_width()).sum::<usize>() + 1;

    // Running column cursors within the preprocessed and main sections.
    let mut prep_cursor = 0usize;
    let mut main_cursor = 0usize;
    chips
        .iter()
        .map(|chip| {
            let prep_width = chip.preprocessed_width() as u32;
            let main_width = chip.width() as u32;
            let entry = ChipColumnLayoutEntry {
                prep_col_idx: prep_cursor as u32,
                main_col_idx: (main_section_start_col + main_cursor) as u32,
                prep_width,
                main_width,
            };
            prep_cursor += prep_width as usize;
            main_cursor += main_width as usize;
            entry
        })
        .collect()
}
/// Launches `jagged_chip_layouts_kernel` to fill `chip_layouts_dev` from the trace's start
/// indices, column heights and the per-chip column layout table.
///
/// Does nothing when there are no chips. The launch uses 128-thread blocks, enough of them to
/// cover one thread per chip.
///
/// # Panics
///
/// Panics if the kernel launch fails.
fn launch_chip_layouts_kernel<K: Field>(
    data: &JaggedTraceMle<K, TaskScope>,
    chip_column_layouts_dev: &Buffer<ChipColumnLayoutEntry, TaskScope>,
    chip_layouts_dev: &mut Buffer<ChipLayoutC, TaskScope>,
) {
    const BLOCK: u32 = 128;

    let n_chips = chip_column_layouts_dev.len() as u32;
    if n_chips > 0 {
        let scope = data.dense().backend();
        let grid_blocks = n_chips.div_ceil(BLOCK);
        unsafe {
            let args = args!(
                data.0.start_indices.as_ptr(),
                data.0.column_heights.as_ptr(),
                chip_column_layouts_dev.as_ptr(),
                n_chips,
                chip_layouts_dev.as_mut_ptr()
            );
            scope
                .launch_kernel(
                    sp1_gpu_cudart::sys::kernels::jagged_chip_layouts_kernel(),
                    (grid_blocks, 1u32, 1u32),
                    (BLOCK, 1u32, 1u32),
                    &args,
                    0,
                )
                .unwrap();
        }
    }
}
/// Builds the host-side `ShardLayoutTracker` for `chips` from the trace metadata in `data`.
///
/// Chip heights are taken as half the `poly_size` recorded in the preprocessed/main table
/// index (0 for a chip without columns in that section). Padding heights are read from the
/// trace's column-height array, which is copied back to the host.
///
/// # Panics
///
/// Panics if a chip with columns is missing from the corresponding table index, or if the
/// padding ranges fall outside the column-height array.
fn build_layout_tracker<A>(
    chips: &BTreeSet<Chip<Felt, A>>,
    data: &JaggedTraceMle<Felt, TaskScope>,
) -> ShardLayoutTracker
where
    A: MachineAir<Felt>,
{
    let mut chip_prep_w: Vec<u32> = Vec::with_capacity(chips.len());
    for chip in chips.iter() {
        chip_prep_w.push(chip.preprocessed_width() as u32);
    }
    let mut chip_main_w: Vec<u32> = Vec::with_capacity(chips.len());
    for chip in chips.iter() {
        chip_main_w.push(chip.width() as u32);
    }

    let mut chip_prep_h_pair: Vec<u32> = Vec::with_capacity(chips.len());
    for chip in chips.iter() {
        let pairs = if chip.preprocessed_width() > 0 {
            let off = data.dense_data.preprocessed_table_index.get(chip.name()).unwrap();
            (off.poly_size as u32) / 2
        } else {
            0
        };
        chip_prep_h_pair.push(pairs);
    }
    let mut chip_main_h_pair: Vec<u32> = Vec::with_capacity(chips.len());
    for chip in chips.iter() {
        let pairs = if chip.width() > 0 {
            let off = data.dense_data.main_table_index.get(chip.name()).unwrap();
            (off.poly_size as u32) / 2
        } else {
            0
        };
        chip_main_h_pair.push(pairs);
    }

    let n_prep_padding = data.dense_data.prep_padding_col_count;
    let n_main_padding = data.dense_data.main_padding_col_count;
    let total_prep_w: usize = chip_prep_w.iter().sum::<u32>() as usize;
    let total_main_w: usize = chip_main_w.iter().sum::<u32>() as usize;

    let column_heights: Vec<u32> = unsafe { data.0.column_heights.copy_into_host_vec() };
    debug_assert_eq!(
        column_heights.len(),
        total_prep_w + n_prep_padding + total_main_w + n_main_padding,
        "TraceDenseData padding col counts disagree with column_heights structure",
    );

    // Column order: chip prep columns, prep padding, chip main columns, main padding.
    let prep_padding_range = total_prep_w..(total_prep_w + n_prep_padding);
    let main_padding_start = prep_padding_range.end + total_main_w;
    let prep_padding_h_pair: Vec<u32> = column_heights[prep_padding_range].to_vec();
    let main_padding_h_pair: Vec<u32> = column_heights[main_padding_start..].to_vec();

    ShardLayoutTracker {
        chip_prep_h_pair,
        chip_main_h_pair,
        prep_padding_h_pair,
        main_padding_h_pair,
        chip_prep_w,
        chip_main_w,
    }
}
pub(crate) fn evaluate_zerocheck<'b, K: Field>(
poly: &mut ZeroCheckJaggedPoly<'b, K>,
) -> UnivariatePolynomial<Ext>
where
TaskScope: EvalKernels<K>,
{
let backend = poly.data.backend();
const NUM_EVAL_POINT: usize = 3;
const MAX_GRID: u32 = 4096;
const BLOCK_SIZE_LOW_REG: u32 = 256;
const BLOCK_SIZE_HIGH_REG: u32 = 64;
const ROWS_PER_THREAD: u32 = 4;
let (rest, last) = poly.zeta.split_at(poly.zeta.dimension() - 1);
let last = *last[0];
let rest_point = DevicePoint::from_host(&rest, backend).unwrap();
let partial_lagrange = rest_point.partial_lagrange();
let rest_point_dim = rest.dimension() as u32;
let trace_ptr = poly.data.as_ref().dense_data.dense.as_ptr();
let block_size_for = |tier: usize| -> u32 {
if poly.seq_tiers[tier].max_reg > 128 {
BLOCK_SIZE_HIGH_REG
} else {
BLOCK_SIZE_LOW_REG
}
};
let chip_layouts_ptr = poly.chip_layouts_dev.as_ptr();
let mut ct_launches: Vec<(u32, &ChunkDeviceBufs, u32)> = Vec::new();
for chip in poly.compiled.iter() {
let chip_idx = chip.chip_idx;
let row_count = poly.layout_tracker.chip_height_elements(chip_idx as usize) / 2;
if row_count == 0 {
continue;
}
for chunk in chip.chunks.iter() {
if let ChunkKind::ColumnTile = chunk.kind {
ct_launches.push((chip_idx, chunk, row_count));
}
}
}
let mut dispatch_tiers: [Vec<BlockDispatchC>; 2] = [Vec::new(), Vec::new()];
for (t, tier) in poly.seq_tiers.iter().enumerate() {
let tile = block_size_for(t) * ROWS_PER_THREAD;
for (chunk_idx_in_tier, &chip_idx) in tier.chip_indices.iter().enumerate() {
let row_count = poly.layout_tracker.chip_height_elements(chip_idx as usize) / 2;
if row_count == 0 {
continue;
}
let mut row_offset = 0u32;
while row_offset < row_count {
let n_rows = (row_count - row_offset).min(tile);
dispatch_tiers[t].push(BlockDispatchC {
chunk_id: chunk_idx_in_tier as u32,
row_offset,
n_rows,
});
row_offset += tile;
}
}
}
const GKR_BLOCK_SIZE: u32 = 256;
let gkr_tile: u32 = GKR_BLOCK_SIZE * ROWS_PER_THREAD;
let mut gkr_dispatch: Vec<BlockDispatchC> = Vec::new();
for &chip_idx in poly.gkr_active_chips.iter() {
let row_count = poly.layout_tracker.chip_height_elements(chip_idx as usize) / 2;
if row_count == 0 {
continue;
}
let mut row_offset = 0u32;
while row_offset < row_count {
let n_rows = (row_count - row_offset).min(gkr_tile);
gkr_dispatch.push(BlockDispatchC { chunk_id: chip_idx, row_offset, n_rows });
row_offset += gkr_tile;
}
}
let mut tier_slot: [usize; 2] = [0, 0];
let mut total_slots: usize = 0;
for t in 0..2 {
tier_slot[t] = total_slots;
total_slots += dispatch_tiers[t].len() * NUM_EVAL_POINT;
}
let mut ct_slots: Vec<(usize, u32)> = Vec::with_capacity(ct_launches.len());
let ct_block_size: u32 = 128; for &(_, chunk, row_count) in &ct_launches {
let total = chunk.n_terms as u64 * row_count as u64;
let n_blocks = if total == 0 {
0
} else {
total.div_ceil(ct_block_size as u64).min(MAX_GRID as u64).max(1) as u32
};
ct_slots.push((total_slots, n_blocks));
total_slots += (n_blocks as usize) * NUM_EVAL_POINT;
}
let geq_slot = total_slots;
total_slots += poly.n_geq_chips * NUM_EVAL_POINT;
let gkr_slot = total_slots;
total_slots += gkr_dispatch.len() * NUM_EVAL_POINT;
let mut shared_output: Tensor<Ext, TaskScope> =
Tensor::with_sizes_in([total_slots.max(1)], backend.clone());
unsafe {
shared_output.assume_init();
}
let shared_output_ptr = shared_output.as_mut_ptr();
for t in 0..2 {
if dispatch_tiers[t].is_empty() {
continue;
}
let bs = block_size_for(t);
let dispatch_ptr =
refill_buffer(&mut poly.cached_seq_dispatch[t], &dispatch_tiers[t], backend).as_ptr();
let static_ptr = poly.seq_tiers[t].static_buf.as_ref().unwrap().as_ptr();
let max_reg = poly.seq_tiers[t].max_reg;
let out_ptr = unsafe { shared_output_ptr.add(tier_slot[t]) };
let shmem_bytes = (bs as usize / 32).max(1) * std::mem::size_of::<Ext>();
unsafe {
let args = args!(
dispatch_ptr,
static_ptr,
chip_layouts_ptr,
trace_ptr,
poly.public_values.as_ptr(),
poly.powers_of_alpha.as_ptr(),
partial_lagrange.as_ptr(),
poly.powers_of_lambda.as_ptr(),
poly.gkr_powers.as_ptr(),
rest_point_dim,
out_ptr
);
backend
.launch_kernel(
<TaskScope as EvalKernels<K>>::fused_sequential_kernel_for(max_reg),
(dispatch_tiers[t].len() as u32, 1, 3),
(bs, 1, 1),
&args,
shmem_bytes,
)
.unwrap();
}
}
for (i, &(chip_idx, chunk, row_count)) in ct_launches.iter().enumerate() {
let (slot, n_blocks) = ct_slots[i];
if n_blocks == 0 {
continue;
}
let out_slot = unsafe { shared_output_ptr.add(slot) };
launch_chunk_into::<K>(
backend,
chunk,
trace_ptr,
chip_layouts_ptr,
&poly.public_values,
&poly.powers_of_alpha,
poly.chip_alpha_offset[chip_idx as usize],
partial_lagrange.as_ptr(),
&poly.powers_of_lambda,
chip_idx,
rest_point_dim,
0,
row_count,
n_blocks,
ct_block_size,
out_slot,
);
}
if poly.n_geq_chips > 0 {
const GEQ_BLOCK_SIZE: u32 = 256;
let geq_indices = poly.geq_chip_indices_dev.as_ref().unwrap();
let out_ptr = unsafe { shared_output_ptr.add(geq_slot) };
let shmem_bytes = (GEQ_BLOCK_SIZE as usize / 32) * std::mem::size_of::<Ext>();
unsafe {
let args = args!(
geq_indices.as_ptr(),
(poly.n_geq_chips as u32),
poly.chip_geq_state_dev.as_ptr(),
poly.chip_pad_adj_dev.as_ptr(),
poly.powers_of_lambda.as_ptr(),
chip_layouts_ptr,
partial_lagrange.as_ptr(),
rest_point_dim,
out_ptr
);
backend
.launch_kernel(
zerocheck_geq_corrections_kernel(),
(poly.n_geq_chips as u32, 1, 1),
(GEQ_BLOCK_SIZE, 1, 1),
&args,
shmem_bytes,
)
.unwrap();
}
}
if !gkr_dispatch.is_empty() {
let gkr_ptr = refill_buffer(&mut poly.cached_gkr_dispatch, &gkr_dispatch, backend).as_ptr();
let out_ptr = unsafe { shared_output_ptr.add(gkr_slot) };
let shmem_bytes = (GKR_BLOCK_SIZE as usize / 32) * std::mem::size_of::<Ext>();
unsafe {
let args = args!(
gkr_ptr,
chip_layouts_ptr,
poly.chip_gkr_info_dev.as_ptr(),
trace_ptr,
poly.gkr_powers.as_ptr(),
partial_lagrange.as_ptr(),
poly.powers_of_lambda.as_ptr(),
rest_point_dim,
out_ptr
);
backend
.launch_kernel(
<TaskScope as EvalKernels<K>>::gkr_sweep_kernel(),
(gkr_dispatch.len() as u32, 1, 3),
(GKR_BLOCK_SIZE, 1, 1),
&args,
shmem_bytes,
)
.unwrap();
}
}
let mut totals_buf: Tensor<Ext, TaskScope> =
Tensor::with_sizes_in([NUM_EVAL_POINT], backend.clone());
unsafe {
totals_buf.assume_init();
}
{
const AGG_BLOCK_SIZE: u32 = 256;
let shmem_bytes = (AGG_BLOCK_SIZE as usize / 32) * std::mem::size_of::<Ext>();
unsafe {
let args = args!(
shared_output_ptr as *const Ext,
(total_slots as u32),
totals_buf.as_mut_ptr()
);
backend
.launch_kernel(
zerocheck_aggregate_partials_kernel(),
(1u32, 1u32, 1u32),
(AGG_BLOCK_SIZE, 1u32, 1u32),
&args,
shmem_bytes,
)
.unwrap();
}
}
let totals_vec: Vec<Ext> = unsafe { totals_buf.into_buffer().copy_into_host_vec() };
let totals: [Ext; NUM_EVAL_POINT] = [totals_vec[0], totals_vec[1], totals_vec[2]];
drop(shared_output);
let mut xs =
vec![Ext::from_canonical_u32(0), Ext::from_canonical_u32(2), Ext::from_canonical_u32(4)];
let mut ys: Vec<Ext> = xs
.iter()
.zip(totals.iter())
.map(|(&x, &t)| {
let last_var_eq = (Ext::one() - x) * (Ext::one() - last) + x * last;
t * last_var_eq * poly.eq_adjustment
})
.collect();
xs.push(Ext::from_canonical_u32(1));
ys.push(poly.claim - ys[0]);
xs.push((last - Ext::one()) / (last + last - Ext::one()));
ys.push(Ext::zero());
interpolate_univariate_polynomial(&xs, &ys)
}
/// Launches the column-tile evaluation kernel for a single compiled chunk, writing its
/// partial results starting at `output_ptr`.
///
/// * `scope` - task scope the kernel is launched on.
/// * `chunk` - device-side bytecode buffers of the chunk; must be a `ChunkKind::ColumnTile`.
/// * `trace_ptr`, `chip_layouts_ptr`, `partial_lagrange_ptr` - raw device pointers forwarded
///   to the kernel unchanged.
/// * `public_values`, `powers_of_lambda` - device buffers passed by base pointer.
/// * `powers_of_alpha`, `chip_alpha_offset` - the kernel receives the alpha-power pointer
///   advanced by `chip_alpha_offset` elements.
/// * `chip_idx`, `rest_point_dim`, `row_start`, `row_count` - scalars forwarded to the kernel.
/// * `n_blocks`, `block_size` - 1-D grid and block dimensions of the launch.
///
/// # Panics
///
/// Panics if `chunk.kind` is `ChunkKind::Sequential`, or if the kernel launch returns an error.
// The kernel needs every one of these values and they mirror its parameter list one-to-one,
// so the argument count is accepted here.
#[allow(clippy::too_many_arguments)]
fn launch_chunk_into<K: Field>(
    scope: &TaskScope,
    chunk: &ChunkDeviceBufs,
    trace_ptr: *const K,
    chip_layouts_ptr: *const ChipLayoutC,
    public_values: &Buffer<Felt, TaskScope>,
    powers_of_alpha: &Buffer<Ext, TaskScope>,
    chip_alpha_offset: u32,
    partial_lagrange_ptr: *const Ext,
    powers_of_lambda: &Buffer<Ext, TaskScope>,
    chip_idx: u32,
    rest_point_dim: u32,
    row_start: u32,
    row_count: u32,
    n_blocks: u32,
    block_size: u32,
    output_ptr: *mut Ext,
) where
    TaskScope: EvalKernels<K>,
{
    match chunk.kind {
        ChunkKind::Sequential => {
            unreachable!("Sequential chunks go through the fused kernel, not launch_chunk_into")
        }
        ChunkKind::ColumnTile => {
            // Dynamic shared memory: one `Ext` for every 32 threads of the block
            // (presumably one reduction slot per warp - confirm against the kernel).
            let scratch_bytes = (block_size as usize / 32) * std::mem::size_of::<Ext>();
            let kernel = <TaskScope as EvalKernels<K>>::column_tile_kernel();

            // SAFETY: NOTE(review) - relies on the caller passing device pointers that are
            // valid for this launch and on `chip_alpha_offset` staying inside
            // `powers_of_alpha`; neither is checked here.
            unsafe {
                // This chip's alpha powers start `chip_alpha_offset` elements into the buffer.
                let alpha_base = powers_of_alpha.as_ptr().add(chip_alpha_offset as usize);
                // The argument order below is the kernel's parameter order; do not reorder.
                let kernel_args = args!(
                    chunk.terms,
                    chunk.n_terms,
                    chunk.leaves,
                    chunk.consts,
                    chunk.publics,
                    trace_ptr,
                    chip_layouts_ptr,
                    public_values.as_ptr(),
                    alpha_base,
                    partial_lagrange_ptr,
                    powers_of_lambda.as_ptr(),
                    chip_idx,
                    rest_point_dim,
                    row_start,
                    row_count,
                    output_ptr
                );
                scope
                    .launch_kernel(
                        kernel,
                        (n_blocks, 1, 1),
                        (block_size, 1, 1),
                        &kernel_args,
                        scratch_bytes,
                    )
                    .unwrap();
            }
        }
    }
}
/// Fixes the last variable of the zerocheck polynomial to `point` and returns the polynomial
/// for the next sumcheck round.
///
/// * `input` - polynomial state of the current round; consumed, its buffers are moved into
///   the result.
/// * `point` - challenge the last variable is fixed to.
/// * `claim` - claim stored in the returned polynomial.
///
/// The returned polynomial owns the folded trace data (now over `Ext`), has `zeta` shortened
/// by its last coordinate, and has `eq_adjustment` multiplied by the eq factor of that
/// coordinate and `point`.
///
/// # Panics
///
/// Panics if a kernel launch returns an error.
pub(crate) fn zerocheck_fix_last_variable<'b, K: Field>(
input: ZeroCheckJaggedPoly<'b, K>,
point: Ext,
claim: Ext,
) -> ZeroCheckJaggedPoly<'b, Ext>
where
TaskScope: JaggedFixLastVariableKernel<K>,
Ext: ExtensionField<K>,
{
// Split off the last coordinate of zeta; `rest` becomes the next round's zeta.
let (rest, last) = input.zeta.split_at(input.zeta.dimension() - 1);
let last = *last[0];
// Length before folding, taken in the tracker's "pair" units.
let input_length = input.layout_tracker.total_length_pair();
let mut layout_tracker = input.layout_tracker;
// Advance the host-side layout bookkeeping by one fold.
layout_tracker.fold();
// Post-fold length passed to the fold routine: twice the folded pair count.
let new_total_length = layout_tracker.total_length_pair() * 2;
// Move the scan scratch buffers out of `input` so they can be lent mutably to the fold
// and then handed on to the returned polynomial.
let mut scan_block_counter = input.scan_block_counter;
let mut scan_flags = input.scan_flags;
let mut scan_values = input.scan_values;
let new_data = evaluate_jagged_fix_last_variable(
&input.data,
point,
input_length,
new_total_length,
crate::primitives::FoldMetadataScratch {
block_counter: &mut scan_block_counter,
flags: &mut scan_flags,
scan_values: &mut scan_values,
},
);
// eq(last, point) = (1 - last)(1 - point) + last * point, accumulated into the running
// eq adjustment.
let eq = (Ext::one() - last) * (Ext::one() - point) + last * point;
let eq_adjustment = input.eq_adjustment * eq;
let n_chips = input.compiled.len() as u32;
if n_chips > 0 {
// Launch enough 128-thread blocks to cover `n_chips`.
const BS: u32 = 128;
let n_blocks = n_chips.div_ceil(BS);
let scope = new_data.dense().backend();
// SAFETY: NOTE(review) - the kernel receives the `chip_geq_state_dev` pointer, the chip
// count and `point`; presumably it updates the per-chip state in place and needs the
// buffer to hold at least `n_chips` entries - confirm against the kernel.
unsafe {
let args = args!(input.chip_geq_state_dev.as_ptr(), n_chips, point);
scope
.launch_kernel(
zerocheck_fix_geq_state_kernel(),
(n_blocks, 1, 1),
(BS, 1, 1),
&args,
0,
)
.unwrap();
}
}
// Reuse the existing device buffer for the chip layouts and refresh it for the folded data.
let mut chip_layouts_dev = input.chip_layouts_dev;
launch_chip_layouts_kernel(&new_data, &input.chip_column_layouts_dev, &mut chip_layouts_dev);
// Everything not recomputed above is moved over from `input` unchanged.
ZeroCheckJaggedPoly {
data: Cow::Owned(new_data),
compiled: input.compiled,
machine_bytecode: input.machine_bytecode,
eq_adjustment,
zeta: rest,
claim,
padded_row_adjustment_host: input.padded_row_adjustment_host,
public_values: input.public_values,
powers_of_alpha: input.powers_of_alpha,
gkr_powers: input.gkr_powers,
powers_of_lambda: input.powers_of_lambda,
layout_tracker,
chip_column_layouts_dev: input.chip_column_layouts_dev,
chip_layouts_dev,
chip_alpha_offset: input.chip_alpha_offset,
seq_tiers: input.seq_tiers,
chip_geq_state_dev: input.chip_geq_state_dev,
chip_pad_adj_dev: input.chip_pad_adj_dev,
geq_chip_indices_dev: input.geq_chip_indices_dev,
n_geq_chips: input.n_geq_chips,
chip_gkr_info_dev: input.chip_gkr_info_dev,
gkr_active_chips: input.gkr_active_chips,
cached_seq_dispatch: input.cached_seq_dispatch,
cached_gkr_dispatch: input.cached_gkr_dispatch,
scan_block_counter,
scan_flags,
scan_values,
}
}
#[allow(clippy::too_many_arguments)]
pub fn zerocheck<A, C>(
chips: &BTreeSet<Chip<Felt, A>>,
machine_bytecode: &Arc<MachineBytecode>,
trace_mle: &JaggedTraceMle<Felt, TaskScope>,
batching_challenge: Ext,
gkr_opening_batch_randomness: Ext,
logup_evaluations: &LogUpEvaluations<Ext>,
public_values: Vec<Felt>,
challenger: &mut C,
max_log_row_count: u32,
) -> (ShardOpenedValues<Felt, Ext>, PartialSumcheckProof<Ext>)
where
A: ZerocheckAir<Felt, Ext>,
C: FieldChallenger<Felt>,
{
let data_input_heights: Vec<u32> = unsafe { trace_mle.column_heights.copy_into_host_vec() };
let initial_heights = trace_mle
.dense_data
.main_table_index
.values()
.map(|trace_offset| trace_offset.poly_size as u32)
.collect::<Vec<u32>>();
let max_num_constraints =
itertools::max(chips.iter().map(|chip| chip.num_constraints)).unwrap();
let max_columns =
itertools::max(chips.iter().map(|chip| chip.preprocessed_width() + chip.width())).unwrap();
let total_preprocessed_columns = trace_mle.dense().preprocessed_cols;
let mut powers_of_challenge =
batching_challenge.powers().take(max_num_constraints).collect::<Vec<_>>();
powers_of_challenge.reverse();
let num_chips = chips.len();
let debug_timing = std::env::var("SP1_GPU_ZEROCHECK_TIMING").is_ok();
let t_setup = std::time::Instant::now();
let gkr_powers =
gkr_opening_batch_randomness.powers().skip(1).take(max_columns).collect::<Vec<_>>();
let lambda: Ext = challenger.sample_ext_element();
let powers_of_lambda =
lambda.powers().take(num_chips).collect_vec().into_iter().rev().collect::<Vec<_>>();
let mut claim = Ext::zero();
let LogUpEvaluations { point: gkr_point, chip_openings } = logup_evaluations;
for chip in chips.iter() {
let ChipEvaluation {
main_trace_evaluations: main_opening,
preprocessed_trace_evaluations: prep_opening,
} = chip_openings.get(chip.name()).unwrap();
claim *= lambda;
let addend = main_opening
.evaluations()
.as_slice()
.iter()
.chain(
prep_opening
.as_ref()
.map_or_else(Vec::new, |mle| mle.evaluations().as_slice().to_vec())
.iter(),
)
.zip(gkr_powers.iter())
.map(|(opening, power)| *opening * *power)
.sum::<Ext>();
claim += addend;
}
let t_pra_and_claim = t_setup.elapsed();
let t_select = std::time::Instant::now();
let compiled_dev: Vec<CompiledChipDevice> = chips
.iter()
.enumerate()
.map(|(shard_idx, chip)| {
let m = *machine_bytecode
.chip_index
.get(chip.name())
.expect("shard chip not present in machine bytecode");
let mut view = machine_bytecode.chips[m].clone();
view.chip_idx = shard_idx as u32;
view
})
.collect();
let t_select = t_select.elapsed();
if debug_timing {
tracing::info!(
"zerocheck setup: num_chips={} pra+claim={:?} select={:?}",
num_chips,
t_pra_and_claim,
t_select,
);
}
let mut main_poly = initialize_zerocheck_poly(
trace_mle,
chips,
compiled_dev,
machine_bytecode.clone(),
initial_heights.clone(),
public_values,
powers_of_challenge,
gkr_powers,
powers_of_lambda,
gkr_point.clone(),
claim,
);
let mut univariate_polys = vec![];
let mut jagged_point: Point<Ext> = Point::from(vec![]);
let t_eval_total = std::time::Instant::now();
let mut total_fold = std::time::Duration::ZERO;
let mut total_eval = std::time::Duration::ZERO;
let mut total_chal = std::time::Duration::ZERO;
let t = std::time::Instant::now();
let mut result = evaluate_zerocheck(&mut main_poly);
if debug_timing {
total_eval += t.elapsed();
}
let t = std::time::Instant::now();
let (mut point, mut next_claim) = challenger_update(&result, challenger);
if debug_timing {
total_chal += t.elapsed();
}
univariate_polys.push(result);
jagged_point.add_dimension(point);
let t = std::time::Instant::now();
let mut next_poly = zerocheck_fix_last_variable(main_poly, point, next_claim);
if debug_timing {
total_fold += t.elapsed();
}
for _ in 0..max_log_row_count - 1 {
let t = std::time::Instant::now();
result = evaluate_zerocheck(&mut next_poly);
if debug_timing {
total_eval += t.elapsed();
}
let t = std::time::Instant::now();
(point, next_claim) = challenger_update(&result, challenger);
if debug_timing {
total_chal += t.elapsed();
}
univariate_polys.push(result);
jagged_point.add_dimension(point);
let t = std::time::Instant::now();
next_poly = zerocheck_fix_last_variable(next_poly, point, next_claim);
if debug_timing {
total_fold += t.elapsed();
}
}
if debug_timing {
tracing::info!(
"zerocheck: total={:?} eval={:?} fold={:?} chal={:?}",
t_eval_total.elapsed(),
total_eval,
total_fold,
total_chal
);
}
let final_jagged_data =
unsafe { next_poly.data.as_ref().dense_data.dense.copy_into_host_vec() };
let mut idx = 0;
let mut individual_column_evals = vec![Ext::zero(); data_input_heights.len()];
for i in 0..data_input_heights.len() {
if data_input_heights[i] != 0 {
individual_column_evals[i] = final_jagged_data[idx];
idx += 4;
}
}
let mut preprocessed_ptr = 0;
let mut main_ptr = total_preprocessed_columns;
let mut opened_values: BTreeMap<String, ChipOpenedValues<Felt, Ext>> = BTreeMap::new();
challenger.observe(Felt::from_canonical_usize(chips.len()));
for (i, chip) in chips.iter().enumerate() {
let preprocessed_width = chip.preprocessed_width();
let preprocessed = AirOpenedValues {
local: individual_column_evals[preprocessed_ptr..preprocessed_ptr + preprocessed_width]
.to_vec(),
};
challenger.observe_variable_length_extension_slice(&preprocessed.local);
preprocessed_ptr += preprocessed_width;
let width = chip.width();
let main =
AirOpenedValues { local: individual_column_evals[main_ptr..main_ptr + width].to_vec() };
challenger.observe_variable_length_extension_slice(&main.local);
main_ptr += width;
opened_values.insert(
chip.air.name().to_string(),
ChipOpenedValues {
preprocessed,
main,
degree: Point::from_usize(
initial_heights[i] as usize,
(max_log_row_count + 1) as usize,
),
},
);
}
let partial_sumcheck_proof = PartialSumcheckProof {
univariate_polys,
claimed_sum: claim,
point_and_eval: (jagged_point, next_claim),
};
let shard_open_values = ShardOpenedValues { chips: opened_values };
(shard_open_values, partial_sumcheck_proof)
}
/// Host-side tests checking that `ShardLayoutTracker` agrees with a flat list of per-column
/// heights built from the same inputs.
#[cfg(test)]
mod layout_tracker_tests {
    use super::ShardLayoutTracker;

    /// Sums a flat list of per-column heights.
    fn column_sum(heights: &[u32]) -> u32 {
        heights.iter().sum()
    }

    /// Builds a tracker plus the matching flat column-height list.
    ///
    /// Each entry of `chips` is `(prep_width, main_width, prep_height, main_height)`.
    /// The flat list is laid out as: every chip's preprocessed columns, the preprocessed
    /// padding columns, every chip's main columns, then the main padding columns.
    fn synthetic(
        chips: &[(u32, u32, u32, u32)],
        prep_padding: &[u32],
        main_padding: &[u32],
    ) -> (ShardLayoutTracker, Vec<u32>) {
        let mut chip_prep_w = Vec::with_capacity(chips.len());
        let mut chip_main_w = Vec::with_capacity(chips.len());
        let mut chip_prep_h_pair = Vec::with_capacity(chips.len());
        let mut chip_main_h_pair = Vec::with_capacity(chips.len());
        for &(prep_w, main_w, prep_h, main_h) in chips {
            chip_prep_w.push(prep_w);
            chip_main_w.push(main_w);
            chip_prep_h_pair.push(prep_h);
            chip_main_h_pair.push(main_h);
        }

        // Repeats each chip's height once per column of that chip.
        let expand = |widths: &[u32], heights: &[u32]| -> Vec<u32> {
            widths
                .iter()
                .zip(heights)
                .flat_map(|(&w, &h)| std::iter::repeat(h).take(w as usize))
                .collect()
        };

        let mut column_heights = expand(&chip_prep_w, &chip_prep_h_pair);
        column_heights.extend_from_slice(prep_padding);
        column_heights.extend(expand(&chip_main_w, &chip_main_h_pair));
        column_heights.extend_from_slice(main_padding);

        let tracker = ShardLayoutTracker {
            chip_prep_h_pair,
            chip_main_h_pair,
            prep_padding_h_pair: prep_padding.to_vec(),
            main_padding_h_pair: main_padding.to_vec(),
            chip_prep_w,
            chip_main_w,
        };
        (tracker, column_heights)
    }

    #[test]
    fn total_length_matches_column_heights_sum() {
        let chips = [(3, 7, 12, 12), (5, 4, 8, 8), (0, 2, 0, 16)];
        let (tracker, column_heights) = synthetic(&chips, &[5], &[]);
        assert_eq!(tracker.total_length_pair(), column_sum(&column_heights));
    }

    #[test]
    fn total_length_matches_with_multi_col_prep_padding() {
        let chips = [(3, 7, 12, 12), (5, 4, 8, 8)];
        let (tracker, column_heights) = synthetic(&chips, &[1024, 1024, 1024, 512], &[]);
        assert_eq!(tracker.total_length_pair(), column_sum(&column_heights));
    }

    #[test]
    fn total_length_matches_with_main_padding() {
        let chips = [(2, 3, 10, 10), (1, 1, 4, 4)];
        let (tracker, column_heights) = synthetic(&chips, &[6], &[3, 1]);
        assert_eq!(tracker.total_length_pair(), column_sum(&column_heights));
    }

    #[test]
    fn fold_stays_in_lockstep_with_element_wise_transform() {
        let chips = [(3, 7, 13, 14), (5, 4, 9, 6), (0, 2, 0, 17)];
        let (mut tracker, mut column_heights) =
            synthetic(&chips, &[1024, 1024, 1024, 511], &[3, 1]);
        for _ in 0..25 {
            tracker.fold();
            // Reference transform applied independently to every column height.
            column_heights.iter_mut().for_each(|h| *h = h.div_ceil(4) * 2);
            assert_eq!(
                tracker.total_length_pair(),
                column_sum(&column_heights),
                "tracker drifted from device-equivalent column_heights",
            );
        }
    }

    #[test]
    fn chip_height_elements_uses_main_when_present_else_prep() {
        // Chip 0 has both sections, chip 1 only preprocessed columns, chip 2 only main columns.
        let chips = [(3, 7, 10, 20), (5, 0, 8, 0), (0, 4, 0, 12)];
        let (tracker, _) = synthetic(&chips, &[5], &[]);
        for (chip, expected) in [(0, 40), (1, 16), (2, 24)] {
            assert_eq!(tracker.chip_height_elements(chip), expected);
        }
    }

    #[test]
    fn empty_padding_sections_are_handled() {
        let (tracker, column_heights) = synthetic(&[(2, 3, 10, 10)], &[], &[]);
        let total = tracker.total_length_pair();
        assert_eq!(total, column_sum(&column_heights));
        // Two preprocessed plus three main columns, each of height 10.
        assert_eq!(total, (2 + 3) * 10);
    }
}