Skip to main content

sp1_gpu_utils/
traces.rs

1use std::collections::BTreeMap;
2use std::ops::{Deref, DerefMut, Range};
3
4use slop_algebra::Field;
5use slop_alloc::{Backend, Buffer, CpuBackend, HasBackend};
6use slop_tensor::{Dimensions, Tensor, TensorView};
7use sp1_gpu_cudart::{DeviceBuffer, TaskScope};
8
9use crate::jagged::JaggedMle;
10use crate::{DenseData, DenseDataMut};
11
/// Where one chip's trace lives inside the dense buffer of a [`TraceDenseData`], and its shape.
#[derive(Debug, Clone)]
pub struct TraceOffset {
    /// Half-open range of element indices in the dense buffer occupied by this trace.
    pub dense_offset: Range<usize>,
    /// Height (number of evaluations) of every polynomial in this trace.
    pub poly_size: usize,
    /// How many polynomials (columns) this trace contains.
    pub num_polys: usize,
}
21
22#[derive(Clone)]
23pub struct JaggedTraceMle<F: Field, B: Backend>(pub JaggedMle<TraceDenseData<F, B>, B>);
24
25impl<F: Field, B: Backend> HasBackend for JaggedTraceMle<F, B> {
26    type Backend = B;
27    fn backend(&self) -> &B {
28        self.0.backend()
29    }
30}
31
32impl<F: Field, B: Backend> Deref for JaggedTraceMle<F, B> {
33    type Target = JaggedMle<TraceDenseData<F, B>, B>;
34
35    fn deref(&self) -> &Self::Target {
36        &self.0
37    }
38}
39
40impl<F: Field, B: Backend> DerefMut for JaggedTraceMle<F, B> {
41    fn deref_mut(&mut self) -> &mut Self::Target {
42        &mut self.0
43    }
44}
45
46/// Jagged representation of the traces.
47#[derive(Clone, Debug)]
48pub struct TraceDenseData<F: Field, B: Backend> {
49    /// The dense representation of the traces.
50    pub dense: Buffer<F, B>,
51    /// The dense offset of the preprocessed traces.
52    pub preprocessed_offset: usize,
53    /// The total number of columns in the preprocessed traces.
54    pub preprocessed_cols: usize,
55    /// The amount of preprocessed padding, to the next multiple of 2^log_stacking_height.
56    pub preprocessed_padding: usize,
57    /// The amount of main padding, to the next multiple of 2^log_stacking_height.
58    pub main_padding: usize,
59    /// Number of *columns* of preprocessed padding between the chip prep
60    /// section and the chip main section in the jagged structure. Equal to
61    /// `cols_so_far - Σ chip_prep_widths` after the prep section is
62    /// generated; both construction paths (real `jagged_tracegen` and
63    /// `from_chip_layout`) record this explicitly so consumers don't have
64    /// to guess from `preprocessed_padding` (which is in *element* units).
65    /// The real tracegen path can emit more than one such column when the
66    /// "fill to next stacking-multiple" loop allocates several.
67    pub prep_padding_col_count: usize,
68    /// Number of *columns* of main padding at the tail of the jagged
69    /// structure. Set after the main section is generated.
70    pub main_padding_col_count: usize,
71    /// A mapping from chip name to the range of dense data it occupies for preprocessed traces.
72    pub preprocessed_table_index: BTreeMap<String, TraceOffset>,
73    /// A mapping from chip name to the range of dense data it occupies for main traces.
74    pub main_table_index: BTreeMap<String, TraceOffset>,
75}
76
77impl<F: Field, B: Backend> TraceDenseData<F, B> {
78    pub fn main_virtual_tensor(&'_ self, log_stacking_height: u32) -> TensorView<'_, F, B> {
79        let ptr = unsafe { self.dense.as_ptr().add(self.preprocessed_offset) };
80        let sizes = Dimensions::try_from([
81            self.main_size() / (1 << log_stacking_height),
82            1 << log_stacking_height,
83        ])
84        .unwrap();
85        // This is safe because we inherit the lifetime of self and the offset should be valid.
86        unsafe { TensorView::from_raw_parts(ptr, sizes, self.backend().clone()) }
87    }
88
89    /// Copies the correct data from dense to a new tensor for main traces.
90    pub fn main_tensor(&self, log_stacking_height: u32) -> Tensor<F, B> {
91        let mut tensor = Tensor::with_sizes_in(
92            [self.main_size() / (1 << log_stacking_height), 1 << log_stacking_height],
93            self.backend().clone(),
94        );
95        let backend = self.dense.backend();
96        unsafe {
97            tensor.assume_init();
98            tensor
99                .as_mut_buffer()
100                .copy_from_slice(&self.dense[self.preprocessed_offset..], backend)
101                .unwrap();
102        }
103        tensor
104    }
105
106    pub fn preprocessed_virtual_tensor(&'_ self, log_stacking_height: u32) -> TensorView<'_, F, B> {
107        let ptr = self.dense.as_ptr();
108        let sizes = Dimensions::try_from([
109            self.preprocessed_offset / (1 << log_stacking_height),
110            1 << log_stacking_height,
111        ])
112        .unwrap();
113        unsafe { TensorView::from_raw_parts(ptr, sizes, self.backend().clone()) }
114    }
115
116    /// Copies the correct data from dense to a new tensor for preprocessed traces.
117    pub fn preprocessed_tensor(&self, log_stacking_height: u32) -> Tensor<F, B> {
118        let mut tensor = Tensor::with_sizes_in(
119            [self.preprocessed_offset / (1 << log_stacking_height), 1 << log_stacking_height],
120            self.backend().clone(),
121        );
122        let backend = self.dense.backend();
123        unsafe {
124            tensor.assume_init();
125            tensor
126                .as_mut_buffer()
127                .copy_from_slice(&self.dense[..self.preprocessed_offset], backend)
128                .unwrap();
129        }
130        tensor
131    }
132
133    /// The size of the main polynomial.
134    #[inline]
135    pub fn main_poly_height(&self, name: &str) -> Option<usize> {
136        self.main_table_index.get(name).map(|offset| offset.poly_size)
137    }
138
139    /// The size of the preprocessed polynomial.
140    #[inline]
141    pub fn preprocessed_poly_height(&self, name: &str) -> Option<usize> {
142        self.preprocessed_table_index.get(name).map(|offset| offset.poly_size)
143    }
144
145    /// The number of polynomials in the main trace.
146    #[inline]
147    pub fn main_num_polys(&self, name: &str) -> Option<usize> {
148        self.main_table_index.get(name).map(|offset| offset.num_polys)
149    }
150
151    /// The size of the main trace dense data, including padding.
152    #[inline]
153    pub fn main_size(&self) -> usize {
154        self.dense.len() - self.preprocessed_offset
155    }
156
157    /// The number of polynomials in the preprocessed trace.
158    #[inline]
159    pub fn preprocessed_num_polys(&self, name: &str) -> Option<usize> {
160        self.preprocessed_table_index.get(name).map(|offset| offset.num_polys)
161    }
162}
163
/// Abstract description of a chip layout used to build [`TraceDenseData`] / [`JaggedTraceMle`].
///
/// Each entry is `(chip_name, preprocessed_width, main_width)` for one chip.
pub struct AbstractChipLayout(Vec<(String, usize, usize)>);

impl AbstractChipLayout {
    /// Wraps the given per-chip entries as-is.
    pub fn new(entries: Vec<(String, usize, usize)>) -> Self {
        AbstractChipLayout(entries)
    }

    /// The per-chip entries, in the order they were supplied.
    pub fn entries(&self) -> &[(String, usize, usize)] {
        self.0.as_slice()
    }
}
177
/// An [`AbstractChipLayout`] that additionally carries each chip's row count.
///
/// Each entry is `(chip_name, preprocessed_width, main_width, height)` for one chip.
pub struct AbstractChipLayoutWithHeights(Vec<(String, usize, usize, usize)>);

impl AbstractChipLayoutWithHeights {
    /// Wraps the given per-chip entries as-is.
    pub fn new(entries: Vec<(String, usize, usize, usize)>) -> Self {
        AbstractChipLayoutWithHeights(entries)
    }

    /// The per-chip entries, in the order they were supplied.
    pub fn entries(&self) -> &[(String, usize, usize, usize)] {
        self.0.as_slice()
    }

    /// Iterates over the chip names in layout order.
    pub fn chip_names(&self) -> impl Iterator<Item = &str> {
        self.entries().iter().map(|entry| entry.0.as_str())
    }
}
196
197impl<F: Field> TraceDenseData<F, CpuBackend> {
198    /// Build a `TraceDenseData` over a pre-allocated `dense` buffer using an
199    /// [`AbstractChipLayoutWithHeights`].
200    ///
201    /// The `dense` buffer must be sized as `padded_preprocessed + padded_main`, where
202    /// each section is the unpadded total rounded up to the next multiple of
203    /// `2^log_stacking_height`.
204    pub fn from_chip_layout(
205        dense: Buffer<F, CpuBackend>,
206        layout: &AbstractChipLayoutWithHeights,
207        log_stacking_height: u32,
208    ) -> Self {
209        let stacking = 1usize << log_stacking_height;
210
211        let total_preprocessed: usize = layout.0.iter().map(|(_, p, _, h)| p * h).sum();
212        let total_main: usize = layout.0.iter().map(|(_, _, m, h)| m * h).sum();
213
214        // note that this makes sure there is always at least one main and one preprocessed column
215        let padded_preprocessed = total_preprocessed.next_multiple_of(stacking).max(stacking);
216        let padded_main = total_main.next_multiple_of(stacking).max(stacking);
217
218        assert_eq!(
219            dense.len(),
220            padded_preprocessed + padded_main,
221            "dense buffer length must equal padded_preprocessed + padded_main",
222        );
223
224        let preprocessed_cols: usize = layout.0.iter().map(|(_, p, _, _)| p).sum();
225
226        let mut preprocessed_table_index = BTreeMap::new();
227        let mut main_table_index = BTreeMap::new();
228        let mut preprocessed_ptr = 0usize;
229        let mut main_ptr = padded_preprocessed;
230        for (name, prep_w, main_w, h) in layout.0.iter() {
231            let prep_lo = preprocessed_ptr;
232            let prep_hi = prep_lo + h * prep_w;
233            preprocessed_table_index.insert(
234                name.clone(),
235                TraceOffset { dense_offset: prep_lo..prep_hi, poly_size: *h, num_polys: *prep_w },
236            );
237            preprocessed_ptr = prep_hi;
238
239            let main_lo = main_ptr;
240            let main_hi = main_lo + h * main_w;
241            main_table_index.insert(
242                name.clone(),
243                TraceOffset { dense_offset: main_lo..main_hi, poly_size: *h, num_polys: *main_w },
244            );
245            main_ptr = main_hi;
246        }
247
248        let preprocessed_padding = padded_preprocessed - total_preprocessed;
249        let main_padding = padded_main - total_main;
250        TraceDenseData {
251            dense,
252            preprocessed_offset: padded_preprocessed,
253            preprocessed_cols,
254            preprocessed_padding,
255            main_padding,
256            // `from_chip_layout` emits exactly one prep/main padding column
257            // when the corresponding padding is non-zero (see
258            // `JaggedTraceMle::from_chip_layout`).
259            prep_padding_col_count: (preprocessed_padding > 0) as usize,
260            main_padding_col_count: (main_padding > 0) as usize,
261            preprocessed_table_index,
262            main_table_index,
263        }
264    }
265}
266
267impl<F: Field, B: Backend> HasBackend for TraceDenseData<F, B> {
268    type Backend = B;
269    fn backend(&self) -> &B {
270        self.dense.backend()
271    }
272}
273
274impl<F: Field, B: Backend> JaggedTraceMle<F, B> {
275    pub fn new(
276        dense_data: TraceDenseData<F, B>,
277        col_index: Buffer<u32, B>,
278        start_indices: Buffer<u32, B>,
279        column_heights: Buffer<u32, B>,
280    ) -> Self {
281        JaggedTraceMle(JaggedMle::new(dense_data, col_index, start_indices, column_heights))
282    }
283}
284
285impl<F: Field> JaggedTraceMle<F, CpuBackend> {
286    /// Build a `JaggedTraceMle` over a pre-allocated `dense` buffer using a chip-layout
287    /// description as parallel slices. Constructs the inner [`TraceDenseData`] with the
288    /// same layout as [`TraceDenseData::from_chip_layout`], plus the jagged column
289    /// metadata: one logical column per chip column for both preprocessed and main,
290    /// plus one padding column per section that has nonzero padding.
291    ///
292    /// All heights must be even, since column heights and column-index entries
293    /// are stored at half-element granularity.
294    pub fn from_chip_layout(
295        dense: Buffer<F, CpuBackend>,
296        layout: &AbstractChipLayoutWithHeights,
297        log_stacking_height: u32,
298    ) -> Self {
299        assert!(layout.0.iter().all(|(_, _, _, h)| h % 2 == 0), "heights must be even");
300
301        let dense_data = TraceDenseData::from_chip_layout(dense, layout, log_stacking_height);
302
303        let total_dense = dense_data.dense.len();
304        let preprocessed_padding = dense_data.preprocessed_padding;
305        let main_padding = dense_data.main_padding;
306
307        let num_data_cols: usize = layout.0.iter().map(|(_, p, m, _)| p + m).sum();
308        let num_cols =
309            num_data_cols + (preprocessed_padding > 0) as usize + (main_padding > 0) as usize;
310
311        let mut col_index = vec![0u32; total_dense / 2];
312        let mut start_idx = vec![0u32; num_cols + 1];
313        let mut column_heights: Vec<u32> = Vec::with_capacity(num_cols);
314
315        let mut col: u32 = 0;
316        let mut cnt: usize = 0;
317
318        let mut emit = |w: usize, h: usize, col: &mut u32, cnt: &mut usize| {
319            let half = h / 2;
320            for _ in 0..w {
321                col_index[*cnt..*cnt + half].fill(*col);
322                *cnt += half;
323                start_idx[*col as usize + 1] = start_idx[*col as usize] + half as u32;
324                column_heights.push(half as u32);
325                *col += 1;
326            }
327        };
328
329        for (_, prep_w, _, h) in layout.0.iter() {
330            emit(*prep_w, *h, &mut col, &mut cnt);
331        }
332        if preprocessed_padding > 0 {
333            emit(1, preprocessed_padding, &mut col, &mut cnt);
334        }
335        for (_, _, main_w, h) in layout.0.iter() {
336            emit(*main_w, *h, &mut col, &mut cnt);
337        }
338        if main_padding > 0 {
339            emit(1, main_padding, &mut col, &mut cnt);
340        }
341
342        debug_assert_eq!(cnt, total_dense / 2);
343        debug_assert_eq!(col as usize, num_cols);
344
345        Self::new(
346            dense_data,
347            Buffer::from(col_index),
348            Buffer::from(start_idx),
349            Buffer::from(column_heights),
350        )
351    }
352}
353
354impl<F: Field> JaggedTraceMle<F, TaskScope> {
355    pub fn preprocessed_virtual_tensor(
356        &'_ self,
357        log_stacking_height: u32,
358    ) -> TensorView<'_, F, TaskScope> {
359        self.dense_data.preprocessed_virtual_tensor(log_stacking_height)
360    }
361
362    pub fn main_virtual_tensor(&'_ self, log_stacking_height: u32) -> TensorView<'_, F, TaskScope> {
363        self.dense_data.main_virtual_tensor(log_stacking_height)
364    }
365
366    pub fn main_poly_height(&self, name: &str) -> Option<usize> {
367        self.dense_data.main_poly_height(name)
368    }
369
370    pub fn preprocessed_poly_height(&self, name: &str) -> Option<usize> {
371        self.dense_data.preprocessed_poly_height(name)
372    }
373
374    pub fn main_num_polys(&self, name: &str) -> Option<usize> {
375        self.dense_data.main_num_polys(name)
376    }
377
378    pub fn main_size(&self) -> usize {
379        self.dense_data.main_size()
380    }
381
382    pub fn preprocessed_num_polys(&self, name: &str) -> Option<usize> {
383        self.dense_data.preprocessed_num_polys(name)
384    }
385}
386
387/// The raw pointer to the dense data, for use in CUDA FFI calls.
388#[repr(C)]
389pub struct TraceDenseDataRaw<F> {
390    dense: *const F,
391}
392
393/// The raw pointer to the dense data, for use in CUDA FFI calls.
394#[repr(C)]
395pub struct TraceDenseDataMutRaw<F> {
396    dense: *mut F,
397}
398
399impl<F: Field, B: Backend> DenseData<B> for TraceDenseData<F, B> {
400    type DenseDataRaw = TraceDenseDataRaw<F>;
401
402    fn as_ptr(&self) -> TraceDenseDataRaw<F> {
403        TraceDenseDataRaw { dense: self.dense.as_ptr() }
404    }
405}
406
407impl<F: Field, B: Backend> DenseDataMut<B> for TraceDenseData<F, B> {
408    type DenseDataMutRaw = TraceDenseDataMutRaw<F>;
409
410    fn as_mut_ptr(&mut self) -> TraceDenseDataMutRaw<F> {
411        TraceDenseDataMutRaw { dense: self.dense.as_mut_ptr() }
412    }
413}
414
415impl<F: Field> JaggedTraceMle<F, CpuBackend> {
416    pub fn into_device(self, t: &TaskScope) -> JaggedTraceMle<F, TaskScope> {
417        let JaggedMle { col_index, start_indices, column_heights, dense_data } = self.0;
418        JaggedTraceMle::new(
419            dense_data.into_device_in(t),
420            DeviceBuffer::from_host(&col_index, t).unwrap().into_inner(),
421            DeviceBuffer::from_host(&start_indices, t).unwrap().into_inner(),
422            DeviceBuffer::from_host(&column_heights, t).unwrap().into_inner(),
423        )
424    }
425}
426
427impl<F: Field> TraceDenseData<F, CpuBackend> {
428    pub fn into_device_in(self, t: &TaskScope) -> TraceDenseData<F, TaskScope> {
429        TraceDenseData {
430            dense: DeviceBuffer::from_host(&self.dense, t).unwrap().into_inner(),
431            preprocessed_offset: self.preprocessed_offset,
432            preprocessed_cols: self.preprocessed_cols,
433            preprocessed_table_index: self.preprocessed_table_index,
434            main_table_index: self.main_table_index,
435            preprocessed_padding: self.preprocessed_padding,
436            main_padding: self.main_padding,
437            prep_padding_col_count: self.prep_padding_col_count,
438            main_padding_col_count: self.main_padding_col_count,
439        }
440    }
441}
442
443impl<F: Field> JaggedTraceMle<F, TaskScope> {
444    pub fn into_host(self) -> JaggedTraceMle<F, CpuBackend> {
445        let JaggedMle { col_index, start_indices, column_heights, dense_data } = self.0;
446        let host_dense = dense_data.into_host();
447        // Convert device buffers to host using DeviceBuffer wrapper
448        let col_index_host = DeviceBuffer::from_raw(col_index).to_host().unwrap().into();
449        let start_indices_host = DeviceBuffer::from_raw(start_indices).to_host().unwrap().into();
450        let column_heights_host = DeviceBuffer::from_raw(column_heights).to_host().unwrap().into();
451        JaggedTraceMle::new(host_dense, col_index_host, start_indices_host, column_heights_host)
452    }
453}
454
455impl<F: Field> TraceDenseData<F, TaskScope> {
456    pub fn into_host(self) -> TraceDenseData<F, CpuBackend> {
457        let host_dense = DeviceBuffer::from_raw(self.dense).to_host().unwrap().into();
458        TraceDenseData {
459            dense: host_dense,
460            preprocessed_offset: self.preprocessed_offset,
461            preprocessed_cols: self.preprocessed_cols,
462            preprocessed_table_index: self.preprocessed_table_index,
463            main_table_index: self.main_table_index,
464            preprocessed_padding: self.preprocessed_padding,
465            main_padding: self.main_padding,
466            prep_padding_col_count: self.prep_padding_col_count,
467            main_padding_col_count: self.main_padding_col_count,
468        }
469    }
470}