1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
use std::iter::once;
use slop_alloc::{Backend, Buffer, CpuBackend, HasBackend};
use slop_tensor::{Dimensions, Tensor};
use sp1_gpu_cudart::{args, TaskScope};
#[derive(Clone, Debug)]
#[repr(C)]
pub struct JaggedMle<D: DenseData<A>, A: Backend> {
/// col_index[i / 2] is the column that the i'th element of the dense data belongs to.
pub col_index: Buffer<u32, A>,
/// start_indices[i] is the half the dense index of the first element of the i'th column.
pub start_indices: Buffer<u32, A>,
/// column_heights[i] is half of the height of the i'th column. Device-
/// resident — every fold runs on the GPU, and zerocheck consumes it
/// directly on device to derive per-chip layouts without a host round-trip.
pub column_heights: Buffer<u32, A>,
pub dense_data: D,
}
pub struct VirtualTensor<T, B: Backend> {
pub data: *const T,
pub sizes: Dimensions,
pub backend: B,
}
impl<T, B: Backend> VirtualTensor<T, B> {
pub fn new(data: *const T, sizes: Dimensions, backend: B) -> Self {
Self { data, sizes, backend }
}
pub fn sizes(&self) -> &[usize] {
self.sizes.sizes()
}
pub fn backend(&self) -> &B {
&self.backend
}
pub fn as_ptr(&self) -> *const T {
self.data
}
pub fn from_tensor(tensor: &Tensor<T, B>) -> Self {
Self {
data: tensor.as_ptr(),
sizes: tensor.shape().clone(),
backend: tensor.backend().clone(),
}
}
}
pub trait DenseData<A: Backend> {
type DenseDataRaw;
fn as_ptr(&self) -> Self::DenseDataRaw;
}
pub trait DenseDataMut<A: Backend>: DenseData<A> {
type DenseDataMutRaw;
fn as_mut_ptr(&mut self) -> Self::DenseDataMutRaw;
}
/// The raw pointer equivalent of [`JaggedMle`] for use in cuda kernels.
#[repr(C)]
pub struct JaggedMleRaw<D: DenseData<A>, A: Backend> {
col_index: *const u32,
start_indices: *const u32,
dense_data: D::DenseDataRaw,
}
/// The mutable raw pointer equivalent of [`JaggedMle`] for use in cuda kernels.
#[repr(C)]
pub struct JaggedMleMutRaw<D: DenseDataMut<A>, A: Backend> {
col_index: *mut u32,
start_indices: *mut u32,
dense_data: D::DenseDataMutRaw,
}
impl<D: DenseData<A>, A: Backend> JaggedMle<D, A> {
pub fn as_raw(&self) -> JaggedMleRaw<D, A> {
JaggedMleRaw {
col_index: self.col_index.as_ptr(),
start_indices: self.start_indices.as_ptr(),
dense_data: self.dense_data.as_ptr(),
}
}
pub fn as_mut_raw(&mut self) -> JaggedMleMutRaw<D, A>
where
D: DenseDataMut<A>,
{
JaggedMleMutRaw {
col_index: self.col_index.as_mut_ptr(),
start_indices: self.start_indices.as_mut_ptr(),
dense_data: self.dense_data.as_mut_ptr(),
}
}
pub fn new(
dense_data: D,
col_index: Buffer<u32, A>,
start_indices: Buffer<u32, A>,
column_heights: Buffer<u32, A>,
) -> Self {
Self { dense_data, col_index, start_indices, column_heights }
}
pub fn column_heights(&self) -> &Buffer<u32, A> {
&self.column_heights
}
pub fn dense(&self) -> &D {
&self.dense_data
}
pub fn dense_mut(&mut self) -> &mut D {
&mut self.dense_data
}
pub fn col_index(&self) -> &Buffer<u32, A> {
&self.col_index
}
pub fn col_index_mut(&mut self) -> &mut Buffer<u32, A> {
&mut self.col_index
}
pub fn start_indices(&self) -> &Buffer<u32, A> {
&self.start_indices
}
pub fn start_indices_mut(&mut self) -> &mut Buffer<u32, A> {
&mut self.start_indices
}
pub fn into_parts(self) -> (D, Buffer<u32, A>, Buffer<u32, A>) {
(self.dense_data, self.col_index, self.start_indices)
}
}
impl<D: DenseData<TaskScope>> JaggedMle<D, TaskScope> {
/// Computes the next start indices, column heights and *input* total
/// length for use in jagged fix last variable.
///
/// Returns host buffers; the caller uploads device copies as needed. We
/// download `column_heights` once and compute on host because the per-
/// round fold's hot work happens on device — this metadata derivation is
/// O(n_columns) and the round trip is cheap. The input length is returned
/// alongside so callers don't re-download `column_heights` to sum it.
///
/// TODO: ignore all of the padding stuff.
pub fn next_start_indices_and_column_heights(
&self,
) -> (Buffer<u32, CpuBackend>, Vec<u32>, u32) {
// SAFETY: `column_heights` was populated via `extend_from_host_slice`
// (or the equivalent during fold), so the device range is fully
// initialised up to `len()`.
let host_column_heights: Vec<u32> = unsafe { self.column_heights.copy_into_host_vec() };
let input_length = host_column_heights.iter().sum::<u32>();
let output_heights =
host_column_heights.iter().map(|height| height.div_ceil(4) * 2).collect::<Vec<u32>>();
let new_start_idx = once(0)
.chain(output_heights.iter().scan(0u32, |acc, x| {
*acc += x;
Some(*acc)
}))
.collect::<Vec<_>>();
let buffer_start_idx = Buffer::from(new_start_idx);
(buffer_start_idx, output_heights, input_length)
}
/// Device-resident counterpart of [`Self::next_start_indices_and_column_heights`].
///
/// Runs the `jagged_fold_metadata` kernel to compute the new `column_heights`
/// and `start_indices` on device (no host download of the input
/// `column_heights`, no host upload of the derived metadata). Reads back
/// only the final `output_height` scalar (last element of new
/// `start_indices`, ~4 bytes) since downstream callers need it to size
/// host-allocated output tensors.
///
/// Replaces the bulk `D2H column_heights + 2× H2D start_idx/heights`
/// pattern with a single kernel launch + tiny D2H — on `v6/rsp` this
/// drops ~50 k of `cudaMemcpyAsync` calls per prove across the 4
/// logup_gkr callers (`execution::layer_transition`,
/// `execution::first_layer_transition`, `sumcheck::fix_and_sum_first_layer`,
/// `sumcheck::fix_and_sum_layer_transition`).
///
/// Allocates the scan bookkeeping (`block_counter`, `flags`,
/// `scan_values`) inline. The bookkeeping is small (≤ `n_blocks + 1`
/// `u32` each, where `n_blocks = ceil(n_columns / SECTION_SIZE)` —
/// typically 1 for shards with a few hundred columns), so the extra
/// allocs are cheap compared to the bulk transfers eliminated.
///
/// Returns `(new_start_indices_dev, new_column_heights_dev, output_height)`.
pub fn next_start_indices_and_column_heights_dev(
&self,
) -> (Buffer<u32, TaskScope>, Buffer<u32, TaskScope>, u32) {
let backend = self.column_heights.backend();
let n_columns = self.column_heights.len();
let section_size =
unsafe { sp1_gpu_cudart::sys::kernels::jagged_fold_metadata_section_size() } as usize;
let block_dim = unsafe { sp1_gpu_cudart::sys::kernels::jagged_fold_metadata_block_dim() };
let n_blocks: usize = n_columns.div_ceil(section_size).max(1);
let mut new_column_heights =
Buffer::<u32, TaskScope>::with_capacity_in(n_columns, backend.clone());
let mut new_start_indices =
Buffer::<u32, TaskScope>::with_capacity_in(n_columns + 1, backend.clone());
// SAFETY: the fold-metadata kernel writes all `n_columns` +
// `n_columns + 1` slots before any downstream read.
unsafe {
new_column_heights.assume_init();
new_start_indices.assume_init();
}
// Decoupled-lookback scan bookkeeping. Per the contract in
// `fold_metadata.cuh`: `block_counter[0] = 0`, `flags[0] = 1` so
// the first block doesn't wait, `flags[1..]` and `scan_values[..]`
// start at zero.
let u32_bytes = std::mem::size_of::<u32>();
let mut block_counter = Buffer::<u32, TaskScope>::with_capacity_in(1, backend.clone());
let mut flags = Buffer::<u32, TaskScope>::with_capacity_in(n_blocks + 1, backend.clone());
let mut scan_values =
Buffer::<u32, TaskScope>::with_capacity_in(n_blocks + 1, backend.clone());
block_counter.write_bytes(0, u32_bytes).unwrap();
flags.write_bytes(1, u32_bytes).unwrap();
flags.write_bytes(0, n_blocks * u32_bytes).unwrap();
scan_values.write_bytes(0, (n_blocks + 1) * u32_bytes).unwrap();
// SAFETY: `args!` tuple matches `jagged_fold_metadata`'s C signature
// in `sys/include/jagged_assist/fold_metadata.cuh`; every pointer
// borrows from a Buffer owned for the launch's lifetime.
unsafe {
let a = args!(
self.column_heights.as_ptr(),
n_columns as u32,
new_column_heights.as_mut_ptr(),
new_start_indices.as_mut_ptr(),
block_counter.as_mut_ptr(),
flags.as_mut_ptr(),
scan_values.as_mut_ptr()
);
backend
.launch_kernel(
sp1_gpu_cudart::sys::kernels::jagged_fold_metadata_kernel(),
(n_blocks as u32, 1u32, 1u32),
(block_dim, 1u32, 1u32),
&a,
0,
)
.unwrap();
}
// Read back `output_height = new_start_indices[n_columns]`. The
// downstream caller needs this scalar to size the next layer's
// host-allocated output tensors. We download the whole
// `new_start_indices` buffer (n_columns + 1 u32 ≈ a few KB,
// *much* smaller than the bulk transfers this path replaces) and
// grab the last element. A future optimization could maintain
// `output_height` as a host-tracked scalar via the same
// recurrence the kernel runs, eliminating this final D2H.
// SAFETY: kernel above fully wrote `new_start_indices`; the
// download synchronizes on `backend`'s stream.
let host_start_idx: Vec<u32> = unsafe { new_start_indices.copy_into_host_vec() };
let output_height = *host_start_idx.last().unwrap();
(new_start_indices, new_column_heights, output_height)
}
}
impl<D: DenseData<A>, A: Backend> HasBackend for JaggedMle<D, A> {
type Backend = A;
fn backend(&self) -> &A {
self.col_index.backend()
}
}