1use std::{iter::once, sync::Arc};
2
3use slop_algebra::AbstractField;
4use slop_alloc::HasBackend;
5use slop_challenger::IopCtx;
6use slop_jagged::JaggedProverData;
7use slop_symmetric::{CryptographicHasher, PseudoCompressionFunction as _};
8use slop_tensor::Tensor;
9use sp1_gpu_basefold::{CudaStackedPcsProverData, FriCudaProver};
10use sp1_gpu_cudart::TaskScope;
11use sp1_gpu_merkle_tree::{CudaTcsProver, SingleLayerMerkleTreeProverError};
12use sp1_gpu_utils::{traces::JaggedTraceMle, Ext, Felt};
13
14#[allow(clippy::type_complexity)]
16pub fn commit_multilinears<GC: IopCtx<F = Felt, EF = Ext>, P: CudaTcsProver<GC>>(
17 jagged_trace_mle: &JaggedTraceMle<Felt, TaskScope>,
18 max_log_row_count: u32,
19 use_preprocessed: bool,
20 drop_main_traces: bool,
21 basefold_prover: &FriCudaProver<GC, P, Felt>,
22) -> Result<
23 (GC::Digest, JaggedProverData<GC, CudaStackedPcsProverData<GC>>),
24 SingleLayerMerkleTreeProverError,
25> {
26 let (index, padding, dst) = if use_preprocessed {
27 (
28 &jagged_trace_mle.dense().preprocessed_table_index,
29 jagged_trace_mle.dense().preprocessed_padding,
30 Tensor::<Felt, TaskScope>::with_sizes_in(
31 [
32 jagged_trace_mle.dense().preprocessed_offset >> basefold_prover.log_height,
33 1 << (basefold_prover.log_height as usize
34 + basefold_prover.config.log_blowup()),
35 ],
36 jagged_trace_mle.dense().dense.backend().clone(),
37 ),
38 )
39 } else {
40 (
41 &jagged_trace_mle.dense().main_table_index,
42 jagged_trace_mle.dense().main_padding,
43 Tensor::<Felt, TaskScope>::with_sizes_in(
44 [
45 jagged_trace_mle.dense().main_size() >> basefold_prover.log_height,
46 1 << (basefold_prover.log_height as usize
47 + basefold_prover.config.log_blowup()),
48 ],
49 jagged_trace_mle.dense().dense.backend().clone(),
50 ),
51 )
52 };
53 let (mut row_counts, mut column_counts) = (
54 index.values().map(|x| x.poly_size).collect::<Vec<_>>(),
55 index.values().map(|x| x.num_polys).collect::<Vec<_>>(),
56 );
57
58 let drop_traces = drop_main_traces && !use_preprocessed;
59
60 let (commitment, data) =
61 basefold_prover.encode_and_commit(use_preprocessed, drop_traces, jagged_trace_mle, dst)?;
62
63 let num_added_cols = padding.div_ceil(1 << max_log_row_count).max(1);
64
65 row_counts.push(1 << max_log_row_count);
66 row_counts.push(padding - (num_added_cols - 1) * (1 << max_log_row_count));
67 column_counts.push(num_added_cols - 1);
68 column_counts.push(1);
69
70 let (hasher, compressor) = GC::default_hasher_and_compressor();
71
72 let hash = hasher.hash_iter(
73 once(Felt::from_canonical_u32(row_counts.len() as u32))
74 .chain(row_counts.clone().into_iter().map(|x| Felt::from_canonical_u32(x as u32)))
75 .chain(column_counts.clone().into_iter().map(|x| Felt::from_canonical_u32(x as u32))),
76 );
77
78 let final_commitment = compressor.compress([commitment, hash]);
79
80 let jagged_prover_data = JaggedProverData {
81 pcs_prover_data: data,
82 row_counts: Arc::new(row_counts),
83 column_counts: Arc::new(column_counts),
84 padding_column_count: num_added_cols,
85 original_commitment: commitment,
86 };
87
88 Ok((final_commitment, jagged_prover_data))
89}
90
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use serial_test::serial;
    use slop_alloc::{CpuBackend, ToHost};
    use slop_challenger::IopCtx;
    use slop_futures::queue::WorkerQueue;
    use slop_jagged::{JaggedPcsVerifier, JaggedProver};
    use slop_merkle_tree::Poseidon2KoalaBear16Prover;
    use slop_stacked::StackedPcsProver;
    use sp1_core_machine::io::SP1Stdin;
    use sp1_gpu_basefold::FriCudaProver;
    use sp1_gpu_cudart::{run_in_place, PinnedBuffer};
    use sp1_gpu_jagged_tracegen::test_utils::tracegen_setup::{
        self, CORE_MAX_LOG_ROW_COUNT, LOG_STACKING_HEIGHT,
    };
    use sp1_gpu_jagged_tracegen::{full_tracegen, CORE_MAX_TRACE_SIZE};
    use sp1_gpu_merkle_tree::{CudaTcsProver, Poseidon2SP1Field16CudaProver};
    use sp1_gpu_utils::{Felt, TestGC};
    use sp1_hypercube::prover::{DefaultTraceGenerator, ProverSemaphore, TraceGenerator};
    use sp1_hypercube::{SP1InnerPcs, SP1PcsProofInner};
    use sp1_primitives::fri_params::core_fri_config;

    use crate::commit::commit_multilinears;

    /// End-to-end check that the CUDA `commit_multilinears` agrees with the
    /// reference CPU `JaggedProver` on both the preprocessed and main tables:
    /// row/column counts, padding column count, raw commitment, and final
    /// commitment must all match.
    #[serial]
    #[tokio::test]
    async fn test_commit_matches() {
        let (machine, record, program) =
            tracegen_setup::setup(&test_artifacts::FIBONACCI_ELF, SP1Stdin::new()).await;

        type JC = SP1InnerPcs;
        type Prover = JaggedProver<
            TestGC,
            SP1PcsProofInner,
            StackedPcsProver<Poseidon2KoalaBear16Prover, TestGC>,
        >;

        run_in_place(|scope| async move {
            // --- Reference path: CPU trace generation + reference prover. ---
            let semaphore = ProverSemaphore::new(1);
            let trace_generator = DefaultTraceGenerator::new_in(machine.clone(), CpuBackend);
            let old_traces = trace_generator
                .generate_traces(
                    program.clone(),
                    record.clone(),
                    CORE_MAX_LOG_ROW_COUNT as usize,
                    semaphore.clone(),
                )
                .await;

            tracing::info!(
                "warmup traces generated: {:?}",
                old_traces.main_trace_data.shard_chips.len()
            );

            let num_rounds = 2;

            let jagged_verifier = JaggedPcsVerifier::<_, JC>::new_from_basefold_params(
                core_fri_config(),
                LOG_STACKING_HEIGHT,
                CORE_MAX_LOG_ROW_COUNT as usize,
                num_rounds,
            );

            let jagged_prover = Prover::from_verifier(&jagged_verifier);

            // Move the reference traces to host memory and commit with the
            // reference prover.
            let preprocessed_message = old_traces
                .preprocessed_traces
                .values()
                .map(|mle| mle.to_host().unwrap())
                .collect();
            let main_message = old_traces
                .main_trace_data
                .traces
                .values()
                .map(|mle| mle.to_host().unwrap())
                .collect();

            let (old_preprocessed_commitment, old_preprocessed_data) =
                jagged_prover.commit_multilinears(preprocessed_message).ok().unwrap();
            let (old_main_commitment, old_main_data) =
                jagged_prover.commit_multilinears(main_message).ok().unwrap();

            // --- GPU path: full trace generation into a pinned buffer. ---
            let record = Arc::new(record);
            let capacity = CORE_MAX_TRACE_SIZE as usize;
            let buffer = PinnedBuffer::<Felt>::with_capacity(capacity);
            let queue = Arc::new(WorkerQueue::new(vec![buffer]));
            let buffer = queue.pop().await.unwrap();
            let (_public_values, jagged_trace_data, _chip_set, _permit) = full_tracegen(
                &machine,
                program.clone(),
                record.clone(),
                &buffer,
                CORE_MAX_TRACE_SIZE as usize,
                LOG_STACKING_HEIGHT,
                CORE_MAX_LOG_ROW_COUNT,
                &scope,
                ProverSemaphore::new(1),
                false,
            )
            .await;

            let tcs_prover = Poseidon2SP1Field16CudaProver::new(&scope);

            let basefold_prover = FriCudaProver::<TestGC, _, <TestGC as IopCtx>::F>::new(
                tcs_prover,
                jagged_verifier.pcs_verifier.basefold_verifier.fri_config,
                LOG_STACKING_HEIGHT,
            );

            // Commit the preprocessed tables (use_preprocessed = true) and the
            // main tables (use_preprocessed = false) on the GPU.
            let (new_preprocessed_commitment, new_preprocessed_data) =
                commit_multilinears::<TestGC, _>(
                    &jagged_trace_data,
                    CORE_MAX_LOG_ROW_COUNT,
                    true,
                    false,
                    &basefold_prover,
                )
                .unwrap();

            let (new_main_commitment, new_main_data) = commit_multilinears::<TestGC, _>(
                &jagged_trace_data,
                CORE_MAX_LOG_ROW_COUNT,
                false,
                false,
                &basefold_prover,
            )
            .unwrap();

            // Both paths must agree on all jagged metadata and commitments.
            assert_eq!(old_preprocessed_data.row_counts, new_preprocessed_data.row_counts);
            assert_eq!(old_preprocessed_data.column_counts, new_preprocessed_data.column_counts);
            assert_eq!(
                old_preprocessed_data.padding_column_count,
                new_preprocessed_data.padding_column_count
            );
            assert_eq!(old_main_data.row_counts, new_main_data.row_counts);
            assert_eq!(old_main_data.column_counts, new_main_data.column_counts);
            assert_eq!(old_main_data.padding_column_count, new_main_data.padding_column_count);
            assert_eq!(
                old_preprocessed_data.original_commitment,
                new_preprocessed_data.original_commitment
            );
            assert_eq!(old_main_data.original_commitment, new_main_data.original_commitment);
            assert_eq!(old_preprocessed_commitment, new_preprocessed_commitment);
            assert_eq!(old_main_commitment, new_main_commitment);
        })
        .await;
    }
}