1use std::sync::Arc;
2
3use slop_challenger::IopCtx;
4use slop_dft::DftOrdering;
5use sp1_gpu_cudart::TaskScope;
6use sp1_gpu_merkle_tree::MerkleTree;
7use sp1_gpu_utils::Felt;
8
9use slop_algebra::{AbstractField, Field};
10use slop_tensor::{Tensor, TensorView};
11use sp1_gpu_cudart::{
12 sys::dft::{batch_coset_dft, sppark_init_default_stream},
13 CudaError, DeviceCopy,
14};
15use sp1_primitives::SP1Field;
16
17pub fn encode_batch<'a>(
18 dft: SpparkDftKoalaBear,
19 log_blowup: u32,
20 data: TensorView<'a, Felt, TaskScope>,
21 dst: &mut Tensor<Felt, TaskScope>,
22) -> Result<(), CudaError> {
23 dft.coset_dft_into(
24 data,
25 dst,
26 <Felt as AbstractField>::one(),
27 log_blowup as usize,
28 DftOrdering::BitReversed,
29 1,
30 )
31 .unwrap();
32 Ok(())
33}
34
35pub trait SpparkCudaDftSys<T: DeviceCopy>: 'static + Send + Sync {
36 #[allow(clippy::too_many_arguments)]
40 unsafe fn dft_unchecked(
41 &self,
42 d_out: *mut T,
43 d_in: *mut T,
44 lg_domain_size: u32,
45 lg_blowup: u32,
46 shift: T,
47 batch_size: u32,
48 bit_rev_output: bool,
49 backend: &TaskScope,
50 ) -> Result<(), CudaError>;
51}
52
/// A DFT engine parameterised by a kernel backend `F` and an element type `T`.
///
/// Field `.0` is the backend value. The second field is a zero-sized marker that
/// ties the engine to `T` without storing any `T`.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct SpparkDft<F, T>(pub F, ::core::marker::PhantomData<T>);
55
56#[derive(Clone)]
57pub struct CudaStackedPcsProverData<GC: IopCtx> {
58 pub merkle_tree_tcs_data: (MerkleTree<GC::Digest, TaskScope>, GC::Digest, usize, usize),
60 pub codeword_mle: Option<Arc<Tensor<GC::F, TaskScope>>>,
62}
63
64impl<T: Field, F: SpparkCudaDftSys<T>> SpparkDft<F, T> {
65 fn coset_dft_into<'a>(
67 &self,
68 src: TensorView<'a, T, TaskScope>,
69 dst: &mut Tensor<T, TaskScope>,
70 shift: T,
71 log_blowup: usize,
72 ordering: DftOrdering,
73 dim: usize,
74 ) -> Result<(), CudaError> {
75 let backend = src.backend();
76 let d_in = src.as_ptr() as *mut T;
77 let d_out = dst.as_mut_ptr();
78 let src_dimensions = src.sizes();
79 let dst_dimensions = dst.sizes();
80
81 let shift = shift / T::generator();
82
83 assert_eq!(
84 src_dimensions[0], dst_dimensions[0],
85 "dimension mismatch along the first dimension"
86 );
87 assert_eq!(src.sizes().len(), 2);
88 assert_eq!(dst.sizes().len(), 2);
89 assert_eq!(dim, 1);
90
91 let lg_domain_size = src_dimensions[1].ilog2();
92 let lg_blowup = dst_dimensions[1].ilog2() - lg_domain_size;
93 assert_eq!(log_blowup, lg_blowup as usize);
94 let batch_size = src_dimensions[0] as u32;
95 let bit_rev_output = ordering == DftOrdering::BitReversed;
96
97 unsafe {
98 dst.assume_init();
100 self.0.dft_unchecked(
102 d_out,
103 d_in,
104 lg_domain_size,
105 lg_blowup,
106 shift,
107 batch_size,
108 bit_rev_output,
109 backend,
110 )
111 }
112 }
113}
114
/// Zero-sized marker selecting the sppark "B31" DFT kernels.
///
/// The `Default` implementation also initialises sppark's default stream. Naming the
/// unit value directly (`SpparkB31Kernels`) skips that initialisation.
#[derive(Clone, Copy, Debug)]
pub struct SpparkB31Kernels;
117
/// `SpparkDft` instantiated with the sppark B31 kernel backend over `Felt` elements.
///
/// NOTE(review): `SpparkCudaDftSys` is implemented for `SP1Field` in this file. This
/// alias therefore relies on `Felt` being that same type. Confirm.
pub type SpparkDftKoalaBear = SpparkDft<SpparkB31Kernels, Felt>;
119
120impl Default for SpparkB31Kernels {
121 fn default() -> Self {
122 unsafe { sppark_init_default_stream() };
123 Self
124 }
125}
126
127impl SpparkCudaDftSys<SP1Field> for SpparkB31Kernels {
128 unsafe fn dft_unchecked(
129 &self,
130 d_out: *mut SP1Field,
131 d_in: *mut SP1Field,
132 lg_domain_size: u32,
133 lg_blowup: u32,
134 shift: SP1Field,
135 batch_size: u32,
136 bit_rev_output: bool,
137 scope: &TaskScope,
138 ) -> Result<(), CudaError> {
139 CudaError::result_from_ffi(batch_coset_dft(
140 d_out,
141 d_in,
142 lg_domain_size,
143 lg_blowup,
144 shift,
145 batch_size,
146 bit_rev_output,
147 scope.handle(),
148 ))
149 }
150}
151
#[cfg(test)]
mod tests {
    use itertools::Itertools;
    use rand::thread_rng;
    use slop_algebra::AbstractField;
    use slop_dft::{p3::Radix2DitParallel, Dft};

    use sp1_gpu_cudart::{run_sync_in_place, DeviceTensor};

    use super::*;

    /// Compares the GPU batched coset DFT against Plonky3's CPU `Radix2DitParallel`
    /// for every domain size from 2^1 to 2^15.
    #[test]
    fn test_batch_coset_dft() {
        const LOG_BLOWUP: usize = 1;
        const BATCH: usize = 16;

        let mut rng = thread_rng();
        let shift = SP1Field::generator();
        let reference = Radix2DitParallel;

        for log_height in 1..=15u32 {
            let height = 1usize << log_height;

            // The host input is `[height, BATCH]`. The GPU path transposes it to
            // `[BATCH, height]`, so each batch entry becomes one row.
            let host_input = Tensor::<SP1Field>::rand(&mut rng, [height, BATCH]);
            let input_for_gpu = host_input.clone();

            let gpu_output = run_sync_in_place(|t| {
                let uploaded =
                    DeviceTensor::from_host(&input_for_gpu, &t).unwrap().into_inner();
                let rows = DeviceTensor::from_raw(uploaded).transpose().into_inner();
                let engine = SpparkDftKoalaBear::default();
                let mut extended =
                    Tensor::<Felt, _>::with_sizes_in([BATCH, height << LOG_BLOWUP], t.clone());
                engine
                    .coset_dft_into(
                        rows.as_view(),
                        &mut extended,
                        shift,
                        LOG_BLOWUP,
                        DftOrdering::BitReversed,
                        1,
                    )
                    .unwrap();

                let back_to_columns = DeviceTensor::from_raw(extended).transpose();
                back_to_columns.to_host().unwrap()
            })
            .unwrap();

            let cpu_output = reference
                .coset_dft(&host_input, shift, LOG_BLOWUP, DftOrdering::BitReversed, 0)
                .unwrap();

            let pairs = gpu_output.as_slice().iter().zip_eq(cpu_output.as_slice());
            for (i, (got, want)) in pairs.enumerate() {
                assert_eq!(got, want, "Mismatch at index {i}");
            }
        }
    }
}