Skip to main content

poulpy_hal/layouts/
vec_znx_dft.rs

1use std::{
2    fmt,
3    hash::{DefaultHasher, Hasher},
4    marker::PhantomData,
5};
6
7use rand_distr::num_traits::Zero;
8
9use crate::layouts::{
10    Backend, Data, DataView, DataViewMut, DigestU64, HostDataMut, HostDataRef, VecZnxBig, VecZnxShape, ZnxInfos, ZnxView,
11    ZnxViewMut, ZnxZero,
12};
13
14/// Polynomial vector in DFT (evaluation) domain.
15///
16/// `VecZnxDft` has the same structural shape as [`VecZnx`](crate::layouts::VecZnx)
17/// but stores coefficients as [`Backend::ScalarPrep`] values in the
18/// frequency domain rather than `i64` values in the coefficient domain.
19///
20/// Multiplication and scalar-vector/vector-matrix products are performed
21/// in this representation to exploit FFT-based convolution. Use
22/// [`VecZnxDftApply`](crate::api::VecZnxDftApply) /
23/// [`VecZnxIdftApply`](crate::api::VecZnxIdftApply) to convert
24/// between coefficient and DFT domains.
25#[repr(C)]
26#[derive(PartialEq, Eq)]
27pub struct VecZnxDft<D: Data, B: Backend> {
28    pub data: D,
29    shape: VecZnxShape,
30    pub _phantom: PhantomData<B>,
31}
32
33impl<D: HostDataRef, B: Backend> DigestU64 for VecZnxDft<D, B> {
34    fn digest_u64(&self) -> u64 {
35        let mut h: DefaultHasher = DefaultHasher::new();
36        h.write(self.data.as_ref());
37        h.write_usize(self.n());
38        h.write_usize(self.cols());
39        h.write_usize(self.size());
40        h.write_usize(self.max_size());
41        h.finish()
42    }
43}
44
45impl<D: HostDataRef, B: Backend> ZnxView for VecZnxDft<D, B> {
46    type Scalar = B::ScalarPrep;
47}
48
49impl<D: Data, B: Backend> VecZnxDft<D, B> {
50    pub fn n(&self) -> usize {
51        self.shape.n()
52    }
53
54    pub fn cols(&self) -> usize {
55        self.shape.cols()
56    }
57
58    pub fn size(&self) -> usize {
59        self.shape.size()
60    }
61
62    /// Reinterprets this DFT vector as a [`VecZnxBig`], consuming `self`.
63    ///
64    /// This is a zero-copy conversion that changes only the type tag;
65    /// the underlying data buffer is moved as-is.
66    pub fn into_big(self) -> VecZnxBig<D, B> {
67        let shape = self.shape;
68        VecZnxBig::<D, B>::from_data(self.data, shape.n(), shape.cols(), shape.size())
69    }
70}
71
72impl<D: Data, B: Backend> ZnxInfos for VecZnxDft<D, B> {
73    fn cols(&self) -> usize {
74        self.shape.cols()
75    }
76
77    fn rows(&self) -> usize {
78        1
79    }
80
81    fn n(&self) -> usize {
82        self.shape.n()
83    }
84
85    fn size(&self) -> usize {
86        self.shape.size()
87    }
88}
89
90impl<D: Data, B: Backend> DataView for VecZnxDft<D, B> {
91    type D = D;
92    fn data(&self) -> &Self::D {
93        &self.data
94    }
95}
96
97impl<D: Data, B: Backend> DataViewMut for VecZnxDft<D, B> {
98    fn data_mut(&mut self) -> &mut Self::D {
99        &mut self.data
100    }
101}
102
103impl<D: Data, B: Backend> VecZnxDft<D, B> {
104    pub fn shape(&self) -> VecZnxShape {
105        self.shape
106    }
107
108    pub fn max_size(&self) -> usize {
109        self.shape.max_size()
110    }
111
112    pub fn with_size(mut self, size: usize) -> Self {
113        assert!(size <= self.max_size());
114        self.shape = self.shape.with_size(size);
115        self
116    }
117}
118
119impl<D: Data, B: Backend> VecZnxDft<D, B> {
120    /// Sets the active limb count.
121    ///
122    /// # Panics
123    ///
124    /// Panics if `size > max_size`.
125    pub fn set_size(&mut self, size: usize) {
126        self.shape = self.shape.with_size(size)
127    }
128}
129
130impl<D: HostDataMut, B: Backend> ZnxZero for VecZnxDft<D, B>
131where
132    Self: ZnxViewMut,
133    <Self as ZnxView>::Scalar: Zero + Copy,
134{
135    fn zero(&mut self) {
136        self.raw_mut().fill(<Self as ZnxView>::Scalar::zero())
137    }
138    fn zero_at(&mut self, i: usize, j: usize) {
139        self.at_mut(i, j).fill(<Self as ZnxView>::Scalar::zero());
140    }
141}
142
143impl<B: Backend> VecZnxDft<<B as Backend>::OwnedBuf, B> {
144    pub fn alloc(n: usize, cols: usize, size: usize) -> Self {
145        let data: <B as Backend>::OwnedBuf = B::alloc_zeroed_bytes(B::bytes_of_vec_znx_dft(n, cols, size));
146        Self {
147            data,
148            shape: VecZnxShape::new(n, cols, size, size),
149            _phantom: PhantomData,
150        }
151    }
152
153    pub fn from_bytes(n: usize, cols: usize, size: usize, bytes: impl Into<Vec<u8>>) -> Self {
154        let data: Vec<u8> = bytes.into();
155        assert!(data.len() == B::bytes_of_vec_znx_dft(n, cols, size));
156        let data: <B as Backend>::OwnedBuf = B::from_host_bytes(&data);
157        Self {
158            data,
159            shape: VecZnxShape::new(n, cols, size, size),
160            _phantom: PhantomData,
161        }
162    }
163}
164
165/// Owned `VecZnxDft` backed by a backend-owned buffer.
166pub type VecZnxDftOwned<B> = VecZnxDft<<B as Backend>::OwnedBuf, B>;
167/// Shared backend-native borrow of a `VecZnxDft`.
168pub type VecZnxDftBackendRef<'a, B> = VecZnxDft<<B as Backend>::BufRef<'a>, B>;
169/// Mutable backend-native borrow of a `VecZnxDft`.
170pub type VecZnxDftBackendMut<'a, B> = VecZnxDft<<B as Backend>::BufMut<'a>, B>;
171
172/// Reborrow a mutable backend-native `VecZnxDft` view as a shared backend-native view.
173pub fn vec_znx_dft_backend_ref_from_mut<'a, 'b, B: Backend + 'b>(
174    vec: &'a VecZnxDftBackendMut<'b, B>,
175) -> VecZnxDftBackendRef<'a, B> {
176    VecZnxDft {
177        data: B::view_ref_mut(&vec.data),
178        shape: vec.shape,
179        _phantom: PhantomData,
180    }
181}
182
183pub fn vec_znx_dft_backend_mut_from_mut<'a, 'b, B: Backend + 'b>(
184    vec: &'a mut VecZnxDftBackendMut<'b, B>,
185) -> VecZnxDftBackendMut<'a, B> {
186    VecZnxDft {
187        data: B::view_mut_ref(&mut vec.data),
188        shape: vec.shape,
189        _phantom: PhantomData,
190    }
191}
192
193impl<D: Data, B: Backend> VecZnxDft<D, B> {
194    /// Constructs a `VecZnxDft` from raw parts without validation.
195    pub fn from_data(data: D, n: usize, cols: usize, size: usize) -> Self {
196        Self {
197            data,
198            shape: VecZnxShape::new(n, cols, size, size),
199            _phantom: PhantomData,
200        }
201    }
202
203    pub fn from_data_with_max_size(data: D, n: usize, cols: usize, size: usize, max_size: usize) -> Self {
204        Self {
205            data,
206            shape: VecZnxShape::new(n, cols, size, max_size),
207            _phantom: PhantomData,
208        }
209    }
210}
211
212/// Borrow a backend-owned `VecZnxDft` using the backend's native view type.
213pub trait VecZnxDftToBackendRef<B: Backend> {
214    fn to_backend_ref(&self) -> VecZnxDftBackendRef<'_, B>;
215}
216
217impl<B: Backend> VecZnxDftToBackendRef<B> for VecZnxDft<B::OwnedBuf, B> {
218    fn to_backend_ref(&self) -> VecZnxDftBackendRef<'_, B> {
219        VecZnxDft {
220            data: B::view(&self.data),
221            shape: self.shape,
222            _phantom: std::marker::PhantomData,
223        }
224    }
225}
226
227impl<'b, B: Backend + 'b> VecZnxDftToBackendRef<B> for &VecZnxDft<B::BufRef<'b>, B> {
228    fn to_backend_ref(&self) -> VecZnxDftBackendRef<'_, B> {
229        VecZnxDft {
230            data: B::view_ref(&self.data),
231            shape: self.shape,
232            _phantom: std::marker::PhantomData,
233        }
234    }
235}
236
237/// Reborrow an already backend-borrowed `VecZnxDft` as a shared backend-native view.
238pub trait VecZnxDftReborrowBackendRef<B: Backend> {
239    fn reborrow_backend_ref(&self) -> VecZnxDftBackendRef<'_, B>;
240}
241
242impl<'b, B: Backend + 'b> VecZnxDftReborrowBackendRef<B> for VecZnxDft<B::BufMut<'b>, B> {
243    fn reborrow_backend_ref(&self) -> VecZnxDftBackendRef<'_, B> {
244        vec_znx_dft_backend_ref_from_mut::<B>(self)
245    }
246}
247
248/// Mutably borrow a backend-owned `VecZnxDft` using the backend's native view type.
249pub trait VecZnxDftToBackendMut<B: Backend> {
250    fn to_backend_mut(&mut self) -> VecZnxDftBackendMut<'_, B>;
251}
252
253impl<B: Backend> VecZnxDftToBackendMut<B> for VecZnxDft<B::OwnedBuf, B> {
254    fn to_backend_mut(&mut self) -> VecZnxDftBackendMut<'_, B> {
255        VecZnxDft {
256            data: B::view_mut(&mut self.data),
257            shape: self.shape,
258            _phantom: std::marker::PhantomData,
259        }
260    }
261}
262
263impl<'b, B: Backend + 'b> VecZnxDftToBackendMut<B> for &mut VecZnxDft<B::BufMut<'b>, B> {
264    fn to_backend_mut(&mut self) -> VecZnxDftBackendMut<'_, B> {
265        vec_znx_dft_backend_mut_from_mut::<B>(self)
266    }
267}
268
269/// Reborrow an already backend-borrowed `VecZnxDft` as a mutable backend-native view.
270pub trait VecZnxDftReborrowBackendMut<B: Backend> {
271    fn reborrow_backend_mut(&mut self) -> VecZnxDftBackendMut<'_, B>;
272}
273
274impl<'b, B: Backend + 'b> VecZnxDftReborrowBackendMut<B> for VecZnxDft<B::BufMut<'b>, B> {
275    fn reborrow_backend_mut(&mut self) -> VecZnxDftBackendMut<'_, B> {
276        vec_znx_dft_backend_mut_from_mut::<B>(self)
277    }
278}
279
280impl<D: HostDataRef, B: Backend> fmt::Display for VecZnxDft<D, B> {
281    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
282        writeln!(f, "VecZnxDft(n={}, cols={}, size={})", self.n(), self.cols(), self.size())?;
283
284        for col in 0..self.cols() {
285            writeln!(f, "Column {col}:")?;
286            for size in 0..self.size() {
287                let coeffs = self.at(col, size);
288                write!(f, "  Size {size}: [")?;
289
290                let max_show = 100;
291                let show_count = coeffs.len().min(max_show);
292
293                for (i, &coeff) in coeffs.iter().take(show_count).enumerate() {
294                    if i > 0 {
295                        write!(f, ", ")?;
296                    }
297                    write!(f, "{coeff}")?;
298                }
299
300                if coeffs.len() > max_show {
301                    write!(f, ", ... ({} more)", coeffs.len() - max_show)?;
302                }
303
304                writeln!(f, "]")?;
305            }
306        }
307        Ok(())
308    }
309}