1mod recursion;
2mod riscv;
3
4use core::future::{ready, Future};
5use core::pin::pin;
6use std::collections::BTreeSet;
7use std::{collections::BTreeMap, sync::Arc};
8
9use futures::stream::FuturesUnordered;
10use futures::{join, StreamExt};
11use rayon::prelude::*;
12use slop_air::BaseAir;
13use slop_algebra::Field;
14use slop_alloc::mem::CopyError;
15use slop_multilinear::{Mle, PaddedMle};
16use sp1_gpu_cudart::{DeviceMle, DeviceTransposeKernel, TaskScope};
17use sp1_hypercube::prover::{MainTraceData, PreprocessedTraceData, ProverSemaphore, TraceData};
18use sp1_hypercube::{
19 air::MachineAir,
20 prover::{TraceGenerator, Traces},
21 Machine,
22};
23
24use sp1_hypercube::{Chip, MachineRecord};
25use sp1_primitives::SP1Field;
26use tracing::{debug_span, instrument, Instrument};
27
/// The base field used throughout this crate, fixed to [`SP1Field`].
pub(crate) type F = SP1Field;
30
/// A [`TraceGenerator`] implementation that generates traces on a CUDA device
/// where possible, falling back to host-side (rayon) generation for AIRs that
/// do not support device trace generation.
pub struct CudaTraceGenerator<F: Field, A> {
    // The machine whose chips' traces are generated.
    machine: Machine<F, A>,
    // Device task scope used to allocate device-resident trace memory.
    trace_allocator: TaskScope,
}
36
37impl<A: MachineAir<F>> CudaTraceGenerator<F, A> {
38 #[must_use]
40 pub fn new_in(machine: Machine<F, A>, trace_allocator: TaskScope) -> Self {
41 Self { machine, trace_allocator }
42 }
43}
44
/// Intermediate result of the host phase of trace generation.
struct HostPhaseTracegen<F, A> {
    // AIRs that support device-side trace generation; their traces have not
    // been generated yet and are produced later in the device phase.
    pub device_airs: Vec<Arc<A>>,
    // Receiver of `(air name, trace)` pairs streamed from the rayon workers
    // generating host-only traces in the background.
    pub host_traces: futures::channel::mpsc::UnboundedReceiver<(String, Mle<F>)>,
}
50
/// Shape/padding information computed on the host for main trace generation.
struct HostPhaseShapePadding<F: Field, A> {
    // The smallest chip cluster of the machine covering the chips included in
    // the record being proven.
    pub shard_chips: BTreeSet<Chip<F, A>>,
    // Zero-filled device traces for cluster chips that have no events in this
    // shard, keyed by chip name.
    pub padded_traces: BTreeMap<String, PaddedMle<F, TaskScope>>,
}
56
impl<F, A> CudaTraceGenerator<F, A>
where
    F: Field,
    A: CudaTracegenAir<F>,
    TaskScope: DeviceTransposeKernel<F>,
{
    /// Host phase of preprocessed trace generation.
    ///
    /// Splits the machine's AIRs into device-capable and host-only groups and
    /// kicks off host-side trace generation on the rayon pool. Host traces are
    /// streamed back through an unbounded channel as `(air name, trace)`
    /// pairs; the device-capable AIRs are returned untouched so the caller can
    /// generate their traces on the device later.
    #[instrument(skip_all, level = "debug")]
    fn host_preprocessed_tracegen(
        &self,
        program: Arc<<A as MachineAir<F>>::Program>,
    ) -> HostPhaseTracegen<F, A> {
        // Partition AIRs by whether they can generate their preprocessed trace
        // directly on the device.
        let (device_airs, host_airs): (Vec<_>, Vec<_>) = self
            .machine
            .chips()
            .iter()
            .map(|chip| chip.air.clone())
            .partition(|air| air.supports_device_preprocessed_tracegen());

        let (host_traces_tx, host_traces) = futures::channel::mpsc::unbounded();
        // Generate all host-only preprocessed traces in parallel, sending each
        // finished trace through the channel as soon as it is ready. AIRs that
        // have no preprocessed trace (`None`) are skipped.
        slop_futures::rayon::spawn(move || {
            host_airs.into_par_iter().for_each_with(host_traces_tx, |tx, air| {
                if let Some(trace) = air.generate_preprocessed_trace(&program) {
                    tx.unbounded_send((air.name().to_string(), Mle::from(trace))).unwrap();
                }
            });
            // Drop the program Arc on the worker thread, presumably to keep a
            // potentially expensive deallocation off the caller's path.
            drop(program);
        });
        HostPhaseTracegen { device_airs, host_traces }
    }

    /// Device phase of preprocessed trace generation.
    ///
    /// Concurrently (a) copies host-generated traces to the device as they
    /// arrive on the channel and (b) runs device-side trace generation for the
    /// device-capable AIRs. All resulting traces are zero-padded to
    /// `2^max_log_row_count` rows and collected into a name-keyed map.
    #[instrument(skip_all, level = "debug")]
    async fn device_preprocessed_tracegen(
        &self,
        program: Arc<<A as MachineAir<F>>::Program>,
        max_log_row_count: usize,
        host_phase_tracegen: HostPhaseTracegen<F, A>,
    ) -> Traces<F, TaskScope> {
        let HostPhaseTracegen { device_airs, host_traces } = host_phase_tracegen;

        // Stream of host traces, copied to the device as they arrive.
        let copied_host_traces = pin!(host_traces.then(|(name, trace)| async move {
            (name, DeviceMle::from_host(&trace, &self.trace_allocator).unwrap().into())
        }));
        // Stream of device-generated traces, completed in any order. AIRs that
        // produce no preprocessed trace are filtered out.
        let device_traces = device_airs
            .into_iter()
            .map(|air| {
                let program = program.as_ref();
                async move {
                    let maybe_trace = air
                        .generate_preprocessed_trace_device(program, &self.trace_allocator)
                        .await
                        .unwrap();
                    (air, maybe_trace)
                }
            })
            .collect::<FuturesUnordered<_>>()
            .filter_map(|(air, maybe_trace)| {
                ready(maybe_trace.map(|trace| (air.name().to_string(), trace.into())))
            });

        // Merge both streams and pad every trace to the shard height.
        let named_traces = futures::stream_select!(copied_host_traces, device_traces)
            .map(|(name, trace)| {
                (name, PaddedMle::padded_with_zeros(Arc::new(trace), max_log_row_count as u32))
            })
            .collect::<BTreeMap<_, _>>()
            .await;

        // Offload the (possibly last-reference) program drop to the rayon pool.
        rayon::spawn(move || drop(program));

        Traces { named_traces }
    }

    /// Host phase of main trace generation.
    ///
    /// Determines which chips are included for `record`, kicks off host-side
    /// trace generation for the non-device-capable ones, and precomputes the
    /// shard's chip cluster together with zero-filled device traces for
    /// cluster chips that have no events in this record.
    #[instrument(skip_all, level = "debug")]
    fn host_main_tracegen(
        &self,
        record: Arc<<A as MachineAir<F>>::Record>,
        max_log_row_count: usize,
    ) -> (HostPhaseTracegen<F, A>, HostPhaseShapePadding<F, A>)
    where
        F: Field,
        A: CudaTracegenAir<F>,
    {
        // Chips that actually have work to do for this record.
        let chip_set = self
            .machine
            .chips()
            .iter()
            .filter(|chip| chip.included(&record))
            .cloned()
            .collect::<BTreeSet<_>>();

        // Partition included AIRs by device main-tracegen support.
        let (device_airs, host_airs): (Vec<_>, Vec<_>) = chip_set
            .iter()
            .map(|chip| chip.air.clone())
            .partition(|c| c.supports_device_main_tracegen());

        let (host_traces_tx, host_traces) = futures::channel::mpsc::unbounded();
        // Generate host-only main traces in parallel; output events are
        // discarded into a throwaway default record.
        slop_futures::rayon::spawn(move || {
            host_airs.into_par_iter().for_each_with(host_traces_tx, |tx, air| {
                let trace = Mle::from(air.generate_trace(&record, &mut A::Record::default()));
                tx.unbounded_send((air.name().to_string(), trace)).unwrap();
            });
            // Drop the record Arc on the worker thread, presumably to keep a
            // potentially expensive deallocation off the caller's path.
            drop(record);
        });

        // The shard proves the smallest cluster containing the included chips;
        // cluster chips with no events get all-zero padded traces.
        let shard_chips = self.machine.smallest_cluster(&chip_set).unwrap().clone();
        let padded_traces = shard_chips
            .iter()
            .filter(|chip| !chip_set.contains(chip))
            .map(|chip| {
                let num_polynomials = chip.width();
                (
                    chip.name().to_string(),
                    PaddedMle::zeros_in(
                        num_polynomials,
                        max_log_row_count as u32,
                        self.trace_allocator.clone(),
                    ),
                )
            })
            .collect::<BTreeMap<_, _>>();

        (
            HostPhaseTracegen { device_airs, host_traces },
            HostPhaseShapePadding { shard_chips, padded_traces },
        )
    }

    /// Device phase of main trace generation.
    ///
    /// Concurrently copies host-generated traces to the device and runs
    /// device-side trace generation, merging the results (padded to
    /// `2^max_log_row_count` rows) into `padded_traces`. Also extracts the
    /// record's public values before the record is dropped.
    #[instrument(skip_all, level = "debug")]
    async fn device_main_tracegen(
        &self,
        max_log_row_count: usize,
        record: Arc<<A as MachineAir<F>>::Record>,
        host_phase_tracegen: HostPhaseTracegen<F, A>,
        padded_traces: BTreeMap<String, PaddedMle<F, TaskScope>>,
    ) -> (Traces<F, TaskScope>, Vec<F>)
    where
        F: Field,
        A: CudaTracegenAir<F>,
    {
        let HostPhaseTracegen { device_airs, host_traces } = host_phase_tracegen;

        // Stream of host traces, copied to the device as they arrive.
        let copied_host_traces = pin!(host_traces.then(|(name, trace)| async move {
            (name, DeviceMle::from_host(&trace, &self.trace_allocator).unwrap().into())
        }));
        // Stream of device-generated traces, completed in any order; output
        // events are discarded into a throwaway default record.
        let device_traces = device_airs
            .into_iter()
            .map(|air| {
                let record = record.as_ref();
                async move {
                    let trace = air
                        .generate_trace_device(
                            record,
                            &mut A::Record::default(),
                            &self.trace_allocator,
                        )
                        .await
                        .unwrap();
                    (air.name().to_string(), trace.into())
                }
            })
            .collect::<FuturesUnordered<_>>();

        // Start from the precomputed zero traces and overlay generated ones.
        let mut all_traces = padded_traces;

        futures::stream_select!(copied_host_traces, device_traces)
            .for_each(|(name, trace)| {
                all_traces.insert(
                    name,
                    PaddedMle::padded_with_zeros(Arc::new(trace), max_log_row_count as u32),
                );
                ready(())
            })
            .await;

        // Read the public values before giving up our reference to the record.
        let public_values = record.public_values::<F>();

        // Offload the (possibly last-reference) record drop to the rayon pool.
        rayon::spawn(move || drop(record));

        let traces = Traces { named_traces: all_traces };
        (traces, public_values)
    }
}
267
impl<F, A> TraceGenerator<F, A, TaskScope> for CudaTraceGenerator<F, A>
where
    F: Field,
    A: CudaTracegenAir<F>,
    TaskScope: DeviceTransposeKernel<F>,
{
    fn machine(&self) -> &Machine<F, A> {
        &self.machine
    }

    fn allocator(&self) -> &TaskScope {
        &self.trace_allocator
    }

    /// Generates all preprocessed traces for `program`.
    ///
    /// Host-side generation is started *before* acquiring the prover permit,
    /// presumably so CPU work overlaps with waiting on the semaphore; the
    /// device phase only runs once a permit is held.
    async fn generate_preprocessed_traces(
        &self,
        program: Arc<<A as MachineAir<F>>::Program>,
        max_log_row_count: usize,
        prover_permits: ProverSemaphore,
    ) -> PreprocessedTraceData<F, TaskScope> {
        let host_phase_tracegen = self.host_preprocessed_tracegen(Arc::clone(&program));

        let permit = prover_permits.acquire().instrument(debug_span!("acquire")).await.unwrap();

        let preprocessed_traces = self
            .device_preprocessed_tracegen(program, max_log_row_count, host_phase_tracegen)
            .await;
        PreprocessedTraceData { preprocessed_traces, permit }
    }

    /// Generates all main traces and public values for `record`.
    ///
    /// Same structure as [`Self::generate_preprocessed_traces`]: host phase
    /// first, then permit acquisition, then the device phase.
    async fn generate_main_traces(
        &self,
        record: <A as MachineAir<F>>::Record,
        max_log_row_count: usize,
        prover_permits: ProverSemaphore,
    ) -> MainTraceData<F, A, TaskScope> {
        let record = Arc::new(record);

        let (host_phase_tracegen, HostPhaseShapePadding { shard_chips, padded_traces }) =
            self.host_main_tracegen(Arc::clone(&record), max_log_row_count);

        let permit = prover_permits.acquire().instrument(debug_span!("acquire")).await.unwrap();

        let (traces, public_values) = self
            .device_main_tracegen(max_log_row_count, record, host_phase_tracegen, padded_traces)
            .await;

        MainTraceData { traces, public_values, permit, shard_chips }
    }

    /// Generates preprocessed and main traces in one call.
    ///
    /// Both host phases are started before acquiring the permit, and the two
    /// device phases are awaited concurrently via `join!`.
    async fn generate_traces(
        &self,
        program: Arc<<A as MachineAir<F>>::Program>,
        record: <A as MachineAir<F>>::Record,
        max_log_row_count: usize,
        prover_permits: sp1_hypercube::prover::ProverSemaphore,
    ) -> TraceData<F, A, TaskScope> {
        let record = Arc::new(record);

        let prep_host_phase_tracegen = self.host_preprocessed_tracegen(Arc::clone(&program));

        let (main_host_phase_tracegen, HostPhaseShapePadding { shard_chips, padded_traces }) =
            self.host_main_tracegen(Arc::clone(&record), max_log_row_count);

        let permit = prover_permits.acquire().instrument(debug_span!("acquire")).await.unwrap();

        let (preprocessed_traces, (traces, public_values)) = join!(
            self.device_preprocessed_tracegen(program, max_log_row_count, prep_host_phase_tracegen),
            self.device_main_tracegen(
                max_log_row_count,
                record,
                main_host_phase_tracegen,
                padded_traces,
            )
        );

        TraceData {
            preprocessed_traces,
            main_trace_data: MainTraceData { traces, public_values, permit, shard_chips },
        }
    }
}
365
/// Extension of [`MachineAir`] for AIRs that may generate their traces
/// directly on a CUDA device.
pub trait CudaTracegenAir<F: Field>: MachineAir<F> {
    /// Whether [`Self::generate_preprocessed_trace_device`] is implemented for
    /// this AIR. Defaults to `false`, routing the AIR to the host path.
    fn supports_device_preprocessed_tracegen(&self) -> bool {
        false
    }

    /// Generates the preprocessed trace on the device, returning `Ok(None)`
    /// if the AIR has no preprocessed trace.
    ///
    /// # Panics
    /// The default body panics via `unimplemented!` as soon as the method is
    /// called; only implement/call it when
    /// [`Self::supports_device_preprocessed_tracegen`] returns `true`.
    #[allow(unused_variables)]
    fn generate_preprocessed_trace_device(
        &self,
        program: &Self::Program,
        scope: &TaskScope,
    ) -> impl Future<Output = Result<Option<DeviceMle<F>>, CopyError>> + Send {
        // `unimplemented!` diverges, so the `ready` future is never built.
        #[allow(unreachable_code)]
        ready(unimplemented!())
    }

    /// Whether [`Self::generate_trace_device`] is implemented for this AIR.
    /// Defaults to `false`, routing the AIR to the host path.
    fn supports_device_main_tracegen(&self) -> bool {
        false
    }

    /// Generates the main trace on the device.
    ///
    /// # Panics
    /// The default body panics via `unimplemented!` as soon as the method is
    /// called; only implement/call it when
    /// [`Self::supports_device_main_tracegen`] returns `true`.
    #[allow(unused_variables)]
    fn generate_trace_device(
        &self,
        input: &Self::Record,
        output: &mut Self::Record,
        scope: &TaskScope,
    ) -> impl Future<Output = Result<DeviceMle<F>, CopyError>> + Send {
        // `unimplemented!` diverges, so the `ready` future is never built.
        #[allow(unreachable_code)]
        ready(unimplemented!())
    }
}
407
#[cfg(test)]
pub(crate) mod tests {
    use super::{CudaTracegenAir, F};
    use rand::{rngs::StdRng, SeedableRng};
    use slop_tensor::Tensor;
    use sp1_gpu_cudart::TaskScope;
    use sp1_hypercube::air::MachineAir;
    use std::collections::BTreeSet;

    /// Asserts that a CPU-generated trace and a GPU-generated trace are equal,
    /// logging a detailed per-row/per-column mismatch report (split into
    /// eventful vs. padding rows) before the final assertion so failures are
    /// easy to localize.
    pub(crate) fn test_traces_eq(
        trace: &Tensor<F>,
        gpu_trace: &Tensor<F>,
        events: &[impl core::fmt::Debug],
    ) {
        assert_eq!(gpu_trace.dimensions, trace.dimensions);

        tracing::info!("{:?}", trace.dimensions);

        // Columns that mismatched on rows backed by an event vs. padding rows.
        let mut eventful_mismatched_columns = BTreeSet::new();
        let mut padding_mismatched_columns = BTreeSet::new();
        for row_idx in 0..trace.sizes()[0] {
            let mut col_mismatches = BTreeSet::new();
            for col_idx in 0..trace.sizes()[1] {
                let actual = gpu_trace[[row_idx, col_idx]];
                let expected = trace[[row_idx, col_idx]];
                if actual != expected {
                    tracing::error!(
                        "mismatch on row {} col {}. actual: {:?} expected: {:?}",
                        row_idx,
                        col_idx,
                        *actual,
                        *expected
                    );
                    col_mismatches.insert(col_idx);
                }
            }
            // Row-to-event mapping assumes one event per row (see log text).
            let event = events.get(row_idx);
            if col_mismatches.is_empty() {
                tracing::info!(
                    "row {row_idx} matches . event (assuming events/row = 1): {event:?}"
                );
            } else {
                tracing::error!(
                    "row {row_idx} MISMATCHES. event (assuming events/row = 1): {event:?}"
                );
                tracing::error!("mismatched columns: {col_mismatches:?}");
            }
            if event.is_some() {
                eventful_mismatched_columns.extend(col_mismatches);
            } else {
                padding_mismatched_columns.extend(col_mismatches);
            }
        }
        tracing::info!("eventful mismatched columns: {eventful_mismatched_columns:?}");
        tracing::info!("padding mismatched columns: {padding_mismatched_columns:?}");

        assert_eq!(gpu_trace, trace);
    }

    /// Generic harness for main-tracegen parity tests: builds two identical
    /// records from 1000 seeded random events, generates the trace on the host
    /// and on the device, and asserts the results are equal.
    pub async fn test_main_tracegen<A, Event, Record>(
        chip: A,
        mut make_event: impl FnMut(&mut StdRng) -> Event,
        mut insert_events: impl FnMut(Vec<Event>) -> Record,
        scope: TaskScope,
    ) where
        A: CudaTracegenAir<F> + MachineAir<F, Record = Record>,
        Record: Default,
        Event: Clone + core::fmt::Debug,
    {
        // Fixed seed so the test is deterministic across runs.
        let mut rng = StdRng::seed_from_u64(0xDEADBEEF);

        let events =
            core::iter::repeat_with(|| make_event(&mut rng)).take(1000).collect::<Vec<_>>();

        // Two independent records built from the same events, one per backend.
        let [shard, gpu_shard] = core::array::from_fn(|_| insert_events(events.clone()));

        let trace = Tensor::<F>::from(chip.generate_trace(&shard, &mut Record::default()));

        let gpu_trace = chip
            .generate_trace_device(&gpu_shard, &mut Record::default(), &scope)
            .await
            .expect("should copy events to device successfully")
            .to_host()
            .expect("should copy trace to host successfully")
            .into_guts();

        crate::tests::test_traces_eq(&trace, &gpu_trace, &events);
    }
}