1use std::iter::once;
2
3use slop_alloc::{Backend, Buffer, CpuBackend, HasBackend};
4use slop_tensor::{Dimensions, Tensor};
5use sp1_gpu_cudart::{args, TaskScope};
6
/// Dense storage plus per-column metadata describing a jagged collection of
/// columns with individual heights.
///
/// All metadata buffers are typed on the same backend `A` as the dense data
/// (`D: DenseData<A>`). `#[repr(C)]` fixes the field order, so do not reorder
/// the fields below.
#[derive(Clone, Debug)]
#[repr(C)]
pub struct JaggedMle<D: DenseData<A>, A: Backend> {
    /// Column-index metadata. Presumably this maps dense entries (or rows) to
    /// the column they belong to; TODO confirm against the consuming kernels.
    pub col_index: Buffer<u32, A>,
    /// Start offset of each column in the dense data. The fold helpers on this
    /// type produce `n_columns + 1` entries (starting at 0, with the total
    /// appended); presumably the same layout holds here, so confirm.
    pub start_indices: Buffer<u32, A>,
    /// Height of each column, one entry per column.
    pub column_heights: Buffer<u32, A>,
    /// Backing dense storage for all columns.
    pub dense_data: D,
}
20
/// A non-owning tensor view made of a raw data pointer, a shape, and a backend
/// handle.
///
/// The type carries no lifetime, so nothing ties `data` to the allocation it
/// came from. Callers must keep the source storage alive, and unmoved, for as
/// long as the view is used.
pub struct VirtualTensor<T, B: Backend> {
    /// Raw pointer to the first element. It is treated as borrowed, but no
    /// lifetime enforces that.
    pub data: *const T,
    /// Shape of the viewed tensor.
    pub sizes: Dimensions,
    /// Backend handle. `from_tensor` clones it from the source tensor.
    pub backend: B,
}
26
27impl<T, B: Backend> VirtualTensor<T, B> {
28 pub fn new(data: *const T, sizes: Dimensions, backend: B) -> Self {
29 Self { data, sizes, backend }
30 }
31
32 pub fn sizes(&self) -> &[usize] {
33 self.sizes.sizes()
34 }
35
36 pub fn backend(&self) -> &B {
37 &self.backend
38 }
39
40 pub fn as_ptr(&self) -> *const T {
41 self.data
42 }
43
44 pub fn from_tensor(tensor: &Tensor<T, B>) -> Self {
45 Self {
46 data: tensor.as_ptr(),
47 sizes: tensor.shape().clone(),
48 backend: tensor.backend().clone(),
49 }
50 }
51}
52
/// Dense storage on backend `A` that can be lowered to a raw representation.
///
/// The raw form is presumably passed across FFI, e.g. as a GPU kernel argument
/// inside `JaggedMleRaw`; confirm with the implementors.
pub trait DenseData<A: Backend> {
    /// Raw, implementor-defined representation returned by `as_ptr`.
    type DenseDataRaw;
    /// Returns the raw view of the storage. By naming, this is the read-only
    /// counterpart of `DenseDataMut::as_mut_ptr`.
    fn as_ptr(&self) -> Self::DenseDataRaw;
}
57
/// Dense storage that can additionally be lowered to a mutable raw
/// representation.
pub trait DenseDataMut<A: Backend>: DenseData<A> {
    /// Raw, implementor-defined mutable representation returned by
    /// `as_mut_ptr`.
    type DenseDataMutRaw;
    /// Returns the mutable raw view of the storage. It takes `&mut self` so the
    /// caller holds exclusive access while producing it.
    fn as_mut_ptr(&mut self) -> Self::DenseDataMutRaw;
}
62
/// Read-only, pointer-based view of a `JaggedMle`, produced by
/// `JaggedMle::as_raw`.
///
/// `#[repr(C)]` fixes the field order. It presumably must match a C/CUDA-side
/// struct definition, so do not reorder fields. The pointers are not tied to
/// the source `JaggedMle` by a lifetime; the source must outlive every use of
/// this value. `column_heights` is not included.
#[repr(C)]
pub struct JaggedMleRaw<D: DenseData<A>, A: Backend> {
    col_index: *const u32,
    start_indices: *const u32,
    dense_data: D::DenseDataRaw,
}
70
/// Mutable, pointer-based view of a `JaggedMle`, produced by
/// `JaggedMle::as_mut_raw`.
///
/// The layout mirrors `JaggedMleRaw` field for field, with `*mut` pointers and
/// the mutable dense representation. `#[repr(C)]` fixes the order, so do not
/// reorder fields. No lifetime ties these pointers to the source; it must
/// outlive every use of this value.
#[repr(C)]
pub struct JaggedMleMutRaw<D: DenseDataMut<A>, A: Backend> {
    col_index: *mut u32,
    start_indices: *mut u32,
    dense_data: D::DenseDataMutRaw,
}
78
79impl<D: DenseData<A>, A: Backend> JaggedMle<D, A> {
80 pub fn as_raw(&self) -> JaggedMleRaw<D, A> {
81 JaggedMleRaw {
82 col_index: self.col_index.as_ptr(),
83 start_indices: self.start_indices.as_ptr(),
84 dense_data: self.dense_data.as_ptr(),
85 }
86 }
87
88 pub fn as_mut_raw(&mut self) -> JaggedMleMutRaw<D, A>
89 where
90 D: DenseDataMut<A>,
91 {
92 JaggedMleMutRaw {
93 col_index: self.col_index.as_mut_ptr(),
94 start_indices: self.start_indices.as_mut_ptr(),
95 dense_data: self.dense_data.as_mut_ptr(),
96 }
97 }
98
99 pub fn new(
100 dense_data: D,
101 col_index: Buffer<u32, A>,
102 start_indices: Buffer<u32, A>,
103 column_heights: Buffer<u32, A>,
104 ) -> Self {
105 Self { dense_data, col_index, start_indices, column_heights }
106 }
107
108 pub fn column_heights(&self) -> &Buffer<u32, A> {
109 &self.column_heights
110 }
111
112 pub fn dense(&self) -> &D {
113 &self.dense_data
114 }
115
116 pub fn dense_mut(&mut self) -> &mut D {
117 &mut self.dense_data
118 }
119
120 pub fn col_index(&self) -> &Buffer<u32, A> {
121 &self.col_index
122 }
123
124 pub fn col_index_mut(&mut self) -> &mut Buffer<u32, A> {
125 &mut self.col_index
126 }
127
128 pub fn start_indices(&self) -> &Buffer<u32, A> {
129 &self.start_indices
130 }
131
132 pub fn start_indices_mut(&mut self) -> &mut Buffer<u32, A> {
133 &mut self.start_indices
134 }
135
136 pub fn into_parts(self) -> (D, Buffer<u32, A>, Buffer<u32, A>) {
137 (self.dense_data, self.col_index, self.start_indices)
138 }
139}
140
141impl<D: DenseData<TaskScope>> JaggedMle<D, TaskScope> {
142 pub fn next_start_indices_and_column_heights(
153 &self,
154 ) -> (Buffer<u32, CpuBackend>, Vec<u32>, u32) {
155 let host_column_heights: Vec<u32> = unsafe { self.column_heights.copy_into_host_vec() };
159 let input_length = host_column_heights.iter().sum::<u32>();
160 let output_heights =
161 host_column_heights.iter().map(|height| height.div_ceil(4) * 2).collect::<Vec<u32>>();
162
163 let new_start_idx = once(0)
164 .chain(output_heights.iter().scan(0u32, |acc, x| {
165 *acc += x;
166 Some(*acc)
167 }))
168 .collect::<Vec<_>>();
169 let buffer_start_idx = Buffer::from(new_start_idx);
170 (buffer_start_idx, output_heights, input_length)
171 }
172
173 pub fn next_start_indices_and_column_heights_dev(
197 &self,
198 ) -> (Buffer<u32, TaskScope>, Buffer<u32, TaskScope>, u32) {
199 let backend = self.column_heights.backend();
200 let n_columns = self.column_heights.len();
201 let section_size =
202 unsafe { sp1_gpu_cudart::sys::kernels::jagged_fold_metadata_section_size() } as usize;
203 let block_dim = unsafe { sp1_gpu_cudart::sys::kernels::jagged_fold_metadata_block_dim() };
204 let n_blocks: usize = n_columns.div_ceil(section_size).max(1);
205
206 let mut new_column_heights =
207 Buffer::<u32, TaskScope>::with_capacity_in(n_columns, backend.clone());
208 let mut new_start_indices =
209 Buffer::<u32, TaskScope>::with_capacity_in(n_columns + 1, backend.clone());
210 unsafe {
213 new_column_heights.assume_init();
214 new_start_indices.assume_init();
215 }
216
217 let u32_bytes = std::mem::size_of::<u32>();
222 let mut block_counter = Buffer::<u32, TaskScope>::with_capacity_in(1, backend.clone());
223 let mut flags = Buffer::<u32, TaskScope>::with_capacity_in(n_blocks + 1, backend.clone());
224 let mut scan_values =
225 Buffer::<u32, TaskScope>::with_capacity_in(n_blocks + 1, backend.clone());
226 block_counter.write_bytes(0, u32_bytes).unwrap();
227 flags.write_bytes(1, u32_bytes).unwrap();
228 flags.write_bytes(0, n_blocks * u32_bytes).unwrap();
229 scan_values.write_bytes(0, (n_blocks + 1) * u32_bytes).unwrap();
230
231 unsafe {
235 let a = args!(
236 self.column_heights.as_ptr(),
237 n_columns as u32,
238 new_column_heights.as_mut_ptr(),
239 new_start_indices.as_mut_ptr(),
240 block_counter.as_mut_ptr(),
241 flags.as_mut_ptr(),
242 scan_values.as_mut_ptr()
243 );
244 backend
245 .launch_kernel(
246 sp1_gpu_cudart::sys::kernels::jagged_fold_metadata_kernel(),
247 (n_blocks as u32, 1u32, 1u32),
248 (block_dim, 1u32, 1u32),
249 &a,
250 0,
251 )
252 .unwrap();
253 }
254
255 let host_start_idx: Vec<u32> = unsafe { new_start_indices.copy_into_host_vec() };
266 let output_height = *host_start_idx.last().unwrap();
267
268 (new_start_indices, new_column_heights, output_height)
269 }
270}
271
272impl<D: DenseData<A>, A: Backend> HasBackend for JaggedMle<D, A> {
273 type Backend = A;
274 fn backend(&self) -> &A {
275 self.col_index.backend()
276 }
277}