Skip to main content

sp1_gpu_commit/
commit.rs

1use std::{iter::once, sync::Arc};
2
3use slop_algebra::AbstractField;
4use slop_alloc::HasBackend;
5use slop_challenger::IopCtx;
6use slop_jagged::JaggedProverData;
7use slop_symmetric::{CryptographicHasher, PseudoCompressionFunction as _};
8use slop_tensor::Tensor;
9use sp1_gpu_basefold::{CudaStackedPcsProverData, FriCudaProver};
10use sp1_gpu_cudart::TaskScope;
11use sp1_gpu_merkle_tree::{CudaTcsProver, SingleLayerMerkleTreeProverError};
12use sp1_gpu_utils::{traces::JaggedTraceMle, Ext, Felt};
13
14/// TODO: document
15#[allow(clippy::type_complexity)]
16pub fn commit_multilinears<GC: IopCtx<F = Felt, EF = Ext>, P: CudaTcsProver<GC>>(
17    jagged_trace_mle: &JaggedTraceMle<Felt, TaskScope>,
18    max_log_row_count: u32,
19    use_preprocessed: bool,
20    drop_main_traces: bool,
21    basefold_prover: &FriCudaProver<GC, P, Felt>,
22) -> Result<
23    (GC::Digest, JaggedProverData<GC, CudaStackedPcsProverData<GC>>),
24    SingleLayerMerkleTreeProverError,
25> {
26    let (index, padding, dst) = if use_preprocessed {
27        (
28            &jagged_trace_mle.dense().preprocessed_table_index,
29            jagged_trace_mle.dense().preprocessed_padding,
30            Tensor::<Felt, TaskScope>::with_sizes_in(
31                [
32                    jagged_trace_mle.dense().preprocessed_offset >> basefold_prover.log_height,
33                    1 << (basefold_prover.log_height as usize
34                        + basefold_prover.config.log_blowup()),
35                ],
36                jagged_trace_mle.dense().dense.backend().clone(),
37            ),
38        )
39    } else {
40        (
41            &jagged_trace_mle.dense().main_table_index,
42            jagged_trace_mle.dense().main_padding,
43            Tensor::<Felt, TaskScope>::with_sizes_in(
44                [
45                    jagged_trace_mle.dense().main_size() >> basefold_prover.log_height,
46                    1 << (basefold_prover.log_height as usize
47                        + basefold_prover.config.log_blowup()),
48                ],
49                jagged_trace_mle.dense().dense.backend().clone(),
50            ),
51        )
52    };
53    let (mut row_counts, mut column_counts) = (
54        index.values().map(|x| x.poly_size).collect::<Vec<_>>(),
55        index.values().map(|x| x.num_polys).collect::<Vec<_>>(),
56    );
57
58    let drop_traces = drop_main_traces && !use_preprocessed;
59
60    let (commitment, data) =
61        basefold_prover.encode_and_commit(use_preprocessed, drop_traces, jagged_trace_mle, dst)?;
62
63    let num_added_cols = padding.div_ceil(1 << max_log_row_count).max(1);
64
65    row_counts.push(1 << max_log_row_count);
66    row_counts.push(padding - (num_added_cols - 1) * (1 << max_log_row_count));
67    column_counts.push(num_added_cols - 1);
68    column_counts.push(1);
69
70    let (hasher, compressor) = GC::default_hasher_and_compressor();
71
72    let hash = hasher.hash_iter(
73        once(Felt::from_canonical_u32(row_counts.len() as u32))
74            .chain(row_counts.clone().into_iter().map(|x| Felt::from_canonical_u32(x as u32)))
75            .chain(column_counts.clone().into_iter().map(|x| Felt::from_canonical_u32(x as u32))),
76    );
77
78    let final_commitment = compressor.compress([commitment, hash]);
79
80    let jagged_prover_data = JaggedProverData {
81        pcs_prover_data: data,
82        row_counts: Arc::new(row_counts),
83        column_counts: Arc::new(column_counts),
84        padding_column_count: num_added_cols,
85        original_commitment: commitment,
86    };
87
88    Ok((final_commitment, jagged_prover_data))
89}
90
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use serial_test::serial;
    use slop_alloc::{CpuBackend, ToHost};
    use slop_challenger::IopCtx;
    use slop_futures::queue::WorkerQueue;
    use slop_jagged::{JaggedPcsVerifier, JaggedProver};
    use slop_merkle_tree::Poseidon2KoalaBear16Prover;
    use slop_stacked::StackedPcsProver;
    use sp1_core_machine::io::SP1Stdin;
    use sp1_gpu_basefold::FriCudaProver;
    use sp1_gpu_cudart::{run_in_place, PinnedBuffer};
    use sp1_gpu_jagged_tracegen::test_utils::tracegen_setup::{
        self, CORE_MAX_LOG_ROW_COUNT, LOG_STACKING_HEIGHT,
    };
    use sp1_gpu_jagged_tracegen::{full_tracegen, CORE_MAX_TRACE_SIZE};
    use sp1_gpu_merkle_tree::{CudaTcsProver, Poseidon2SP1Field16CudaProver};
    use sp1_gpu_utils::{Felt, TestGC};
    use sp1_hypercube::prover::{DefaultTraceGenerator, ProverSemaphore, TraceGenerator};
    use sp1_hypercube::{SP1InnerPcs, SP1PcsProofInner};
    use sp1_primitives::fri_params::core_fri_config;

    use crate::commit::commit_multilinears;

    /// Runs the Fibonacci program through both the host jagged prover and the GPU
    /// `commit_multilinears`. Checks that both produce the same commitments and the same
    /// row/column metadata, for the preprocessed traces and for the main traces.
    #[serial]
    #[tokio::test]
    async fn test_commit_matches() {
        let (machine, record, program) =
            tracegen_setup::setup(&test_artifacts::FIBONACCI_ELF, SP1Stdin::new()).await;

        type InnerPcs = SP1InnerPcs;
        type HostProver = JaggedProver<
            TestGC,
            SP1PcsProofInner,
            StackedPcsProver<Poseidon2KoalaBear16Prover, TestGC>,
        >;

        run_in_place(|scope| async move {
            // --- Reference side: CPU tracegen + host jagged prover. ---
            let semaphore = ProverSemaphore::new(1);
            let host_generator = DefaultTraceGenerator::new_in(machine.clone(), CpuBackend);
            let host_traces = host_generator
                .generate_traces(
                    program.clone(),
                    record.clone(),
                    CORE_MAX_LOG_ROW_COUNT as usize,
                    semaphore.clone(),
                )
                .await;

            tracing::info!(
                "warmup traces generated: {:?}",
                host_traces.main_trace_data.shard_chips.len()
            );

            let num_rounds = 2;
            let verifier = JaggedPcsVerifier::<_, InnerPcs>::new_from_basefold_params(
                core_fri_config(),
                LOG_STACKING_HEIGHT,
                CORE_MAX_LOG_ROW_COUNT as usize,
                num_rounds,
            );
            let host_prover = HostProver::from_verifier(&verifier);

            // Copy every trace to host memory, preserving iteration order.
            let preprocessed_message = host_traces
                .preprocessed_traces
                .values()
                .map(|mle| mle.to_host().unwrap())
                .collect();
            let main_message = host_traces
                .main_trace_data
                .traces
                .values()
                .map(|mle| mle.to_host().unwrap())
                .collect();

            let (host_preprocessed_commitment, host_preprocessed_data) =
                host_prover.commit_multilinears(preprocessed_message).ok().unwrap();
            let (host_main_commitment, host_main_data) =
                host_prover.commit_multilinears(main_message).ok().unwrap();

            // --- Candidate side: GPU tracegen + `commit_multilinears`. ---
            let record = Arc::new(record);
            let pinned_queue = Arc::new(WorkerQueue::new(vec![
                PinnedBuffer::<Felt>::with_capacity(CORE_MAX_TRACE_SIZE as usize),
            ]));
            let buffer = pinned_queue.pop().await.unwrap();
            let (_public_values, jagged_trace_data, _chip_set, _permit) = full_tracegen(
                &machine,
                program.clone(),
                record.clone(),
                &buffer,
                CORE_MAX_TRACE_SIZE as usize,
                LOG_STACKING_HEIGHT,
                CORE_MAX_LOG_ROW_COUNT,
                &scope,
                ProverSemaphore::new(1),
                false,
            )
            .await;

            let basefold_prover = FriCudaProver::<TestGC, _, <TestGC as IopCtx>::F>::new(
                Poseidon2SP1Field16CudaProver::new(&scope),
                verifier.pcs_verifier.basefold_verifier.fri_config,
                LOG_STACKING_HEIGHT,
            );

            // Same trace data, same prover; only the trace-set selector changes.
            let gpu_commit = |use_preprocessed: bool| {
                commit_multilinears::<TestGC, _>(
                    &jagged_trace_data,
                    CORE_MAX_LOG_ROW_COUNT,
                    use_preprocessed,
                    false,
                    &basefold_prover,
                )
                .unwrap()
            };
            let (gpu_preprocessed_commitment, gpu_preprocessed_data) = gpu_commit(true);
            let (gpu_main_commitment, gpu_main_data) = gpu_commit(false);

            // --- Compare metadata first, then raw and final commitments. ---
            assert_eq!(host_preprocessed_data.row_counts, gpu_preprocessed_data.row_counts);
            assert_eq!(host_preprocessed_data.column_counts, gpu_preprocessed_data.column_counts);
            assert_eq!(
                host_preprocessed_data.padding_column_count,
                gpu_preprocessed_data.padding_column_count
            );
            assert_eq!(host_main_data.row_counts, gpu_main_data.row_counts);
            assert_eq!(host_main_data.column_counts, gpu_main_data.column_counts);
            assert_eq!(host_main_data.padding_column_count, gpu_main_data.padding_column_count);
            assert_eq!(
                host_preprocessed_data.original_commitment,
                gpu_preprocessed_data.original_commitment
            );
            assert_eq!(host_main_data.original_commitment, gpu_main_data.original_commitment);
            assert_eq!(host_preprocessed_commitment, gpu_preprocessed_commitment);
            assert_eq!(host_main_commitment, gpu_main_commitment);
        })
        .await;
    }
}