// sp1_gpu_basefold/encoder.rs

1use std::sync::Arc;
2
3use slop_challenger::IopCtx;
4use slop_dft::DftOrdering;
5use sp1_gpu_cudart::TaskScope;
6use sp1_gpu_merkle_tree::MerkleTree;
7use sp1_gpu_utils::Felt;
8
9use slop_algebra::{AbstractField, Field};
10use slop_tensor::{Tensor, TensorView};
11use sp1_gpu_cudart::{
12    sys::dft::{batch_coset_dft, sppark_init_default_stream},
13    CudaError, DeviceCopy,
14};
15use sp1_primitives::SP1Field;
16
17pub fn encode_batch<'a>(
18    dft: SpparkDftKoalaBear,
19    log_blowup: u32,
20    data: TensorView<'a, Felt, TaskScope>,
21    dst: &mut Tensor<Felt, TaskScope>,
22) -> Result<(), CudaError> {
23    dft.coset_dft_into(
24        data,
25        dst,
26        <Felt as AbstractField>::one(),
27        log_blowup as usize,
28        DftOrdering::BitReversed,
29        1,
30    )
31    .unwrap();
32    Ok(())
33}
34
35pub trait SpparkCudaDftSys<T: DeviceCopy>: 'static + Send + Sync {
36    /// # Safety
37    ///
38    /// The caller must ensure the validity of pointers, allocation size, and lifetimes.
39    #[allow(clippy::too_many_arguments)]
40    unsafe fn dft_unchecked(
41        &self,
42        d_out: *mut T,
43        d_in: *mut T,
44        lg_domain_size: u32,
45        lg_blowup: u32,
46        shift: T,
47        batch_size: u32,
48        bit_rev_output: bool,
49        backend: &TaskScope,
50    ) -> Result<(), CudaError>;
51}
52
/// DFT engine parameterised by a kernel provider `F` and a field element type `T`.
///
/// The first field holds the kernel provider (see [`SpparkCudaDftSys`]); `T` is tracked only
/// at the type level.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct SpparkDft<F, T>(pub F, core::marker::PhantomData<T>);
55
56#[derive(Clone)]
57pub struct CudaStackedPcsProverData<GC: IopCtx> {
58    /// The usizes are the height of the Merkle tree and the number of elements in a leaf.
59    pub merkle_tree_tcs_data: (MerkleTree<GC::Digest, TaskScope>, GC::Digest, usize, usize),
60    /// The codeword (encoded polynomial). This is `None` when `drop_traces` is true.
61    pub codeword_mle: Option<Arc<Tensor<GC::F, TaskScope>>>,
62}
63
64impl<T: Field, F: SpparkCudaDftSys<T>> SpparkDft<F, T> {
65    /// Performs a discrete Fourier transform along the last dimension of the input tensor.
66    fn coset_dft_into<'a>(
67        &self,
68        src: TensorView<'a, T, TaskScope>,
69        dst: &mut Tensor<T, TaskScope>,
70        shift: T,
71        log_blowup: usize,
72        ordering: DftOrdering,
73        dim: usize,
74    ) -> Result<(), CudaError> {
75        let backend = src.backend();
76        let d_in = src.as_ptr() as *mut T;
77        let d_out = dst.as_mut_ptr();
78        let src_dimensions = src.sizes();
79        let dst_dimensions = dst.sizes();
80
81        let shift = shift / T::generator();
82
83        assert_eq!(
84            src_dimensions[0], dst_dimensions[0],
85            "dimension mismatch along the first dimension"
86        );
87        assert_eq!(src.sizes().len(), 2);
88        assert_eq!(dst.sizes().len(), 2);
89        assert_eq!(dim, 1);
90
91        let lg_domain_size = src_dimensions[1].ilog2();
92        let lg_blowup = dst_dimensions[1].ilog2() - lg_domain_size;
93        assert_eq!(log_blowup, lg_blowup as usize);
94        let batch_size = src_dimensions[0] as u32;
95        let bit_rev_output = ordering == DftOrdering::BitReversed;
96
97        unsafe {
98            // Set the correct length for the output tensor
99            dst.assume_init();
100            // Call the function.
101            self.0.dft_unchecked(
102                d_out,
103                d_in,
104                lg_domain_size,
105                lg_blowup,
106                shift,
107                batch_size,
108                bit_rev_output,
109                backend,
110            )
111        }
112    }
113}
114
/// Zero-sized handle selecting the sppark DFT kernels for [`SP1Field`].
///
/// NOTE(review): `Default::default()` initializes sppark's default stream, but naming the unit
/// value directly skips that step — confirm that is acceptable for callers.
#[derive(Clone, Copy, Debug)]
pub struct SpparkB31Kernels;
117
118pub type SpparkDftKoalaBear = SpparkDft<SpparkB31Kernels, Felt>;
119
120impl Default for SpparkB31Kernels {
121    fn default() -> Self {
122        unsafe { sppark_init_default_stream() };
123        Self
124    }
125}
126
127impl SpparkCudaDftSys<SP1Field> for SpparkB31Kernels {
128    unsafe fn dft_unchecked(
129        &self,
130        d_out: *mut SP1Field,
131        d_in: *mut SP1Field,
132        lg_domain_size: u32,
133        lg_blowup: u32,
134        shift: SP1Field,
135        batch_size: u32,
136        bit_rev_output: bool,
137        scope: &TaskScope,
138    ) -> Result<(), CudaError> {
139        CudaError::result_from_ffi(batch_coset_dft(
140            d_out,
141            d_in,
142            lg_domain_size,
143            lg_blowup,
144            shift,
145            batch_size,
146            bit_rev_output,
147            scope.handle(),
148        ))
149    }
150}
151
#[cfg(test)]
mod tests {
    use itertools::Itertools;
    use rand::thread_rng;
    use slop_algebra::AbstractField;
    use slop_dft::{p3::Radix2DitParallel, Dft};

    use sp1_gpu_cudart::{run_sync_in_place, DeviceTensor};

    use super::*;

    /// Compares the GPU batched coset DFT against the CPU `Radix2DitParallel` reference.
    ///
    /// Covers every power-of-two length from 2^1 to 2^15, with a blowup factor of 2 and
    /// 16 polynomials per batch.
    #[test]
    fn test_batch_coset_dft() {
        const LOG_BLOWUP: usize = 1;
        const BATCH_SIZE: usize = 16;

        let mut rng = thread_rng();
        let shift = SP1Field::generator();
        let reference = Radix2DitParallel;

        for log_len in 1..=15usize {
            let len = 1usize << log_len;

            // Host matrix of shape [len, BATCH_SIZE]: one column per polynomial.
            let host_input = Tensor::<SP1Field>::rand(&mut rng, [len, BATCH_SIZE]);
            let input_for_gpu = host_input.clone();

            let gpu_output = run_sync_in_place(|t| {
                // Upload, then transpose to [BATCH_SIZE, len] so each row is one polynomial.
                let uploaded = DeviceTensor::from_host(&input_for_gpu, &t).unwrap().into_inner();
                let rows = DeviceTensor::from_raw(uploaded).transpose().into_inner();

                let dft = SpparkDftKoalaBear::default();
                let mut encoded =
                    Tensor::<Felt, _>::with_sizes_in([BATCH_SIZE, len << LOG_BLOWUP], t.clone());
                dft.coset_dft_into(
                    rows.as_view(),
                    &mut encoded,
                    shift,
                    LOG_BLOWUP,
                    DftOrdering::BitReversed,
                    1,
                )
                .unwrap();

                // Transpose back to the [len << LOG_BLOWUP, BATCH_SIZE] layout of the reference.
                let transposed = DeviceTensor::from_raw(encoded).transpose();
                transposed.to_host().unwrap()
            })
            .unwrap();

            let expected = reference
                .coset_dft(&host_input, shift, LOG_BLOWUP, DftOrdering::BitReversed, 0)
                .unwrap();

            let pairs = gpu_output.as_slice().iter().zip_eq(expected.as_slice());
            for (idx, (got, want)) in pairs.enumerate() {
                assert_eq!(got, want, "Mismatch at index {idx}");
            }
        }
    }
}