Skip to main content

poulpy_hal/layouts/
znx_base.rs

1use std::fmt::{Debug, Display};
2
3use crate::{
4    layouts::{Data, HostDataMut, HostDataRef},
5    source::Source,
6};
7use bytemuck::Pod;
8use rand_distr::num_traits::Zero;
9
/// Shape metadata shared by every polynomial container in this crate.
///
/// Implementors report the ring degree together with their row, column and
/// limb counts; [`ZnxInfos::log_n`] and [`ZnxInfos::poly_count`] are derived
/// from those four values.
pub trait ZnxInfos {
    /// Ring degree `N` of the polynomials in `Z[X]/(X^N + 1)`.
    fn n(&self) -> usize;

    /// Base-two logarithm of the ring degree `N`.
    ///
    /// Computed as the bit length of `N - 1`, i.e. `ceil(log2(N))`: exact when
    /// `N` is a power of two, rounded up otherwise. `n()` is expected to be at
    /// least 1; `n() == 0` underflows (panicking when overflow checks are on).
    fn log_n(&self) -> usize {
        let largest_index: usize = self.n() - 1;
        let bit_length: u32 = usize::BITS - largest_index.leading_zeros();
        bit_length as usize
    }

    /// Number of rows.
    fn rows(&self) -> usize;

    /// Number of polynomials (columns) per row.
    fn cols(&self) -> usize;

    /// Number of limbs per polynomial.
    fn size(&self) -> usize;

    /// Total number of small polynomials, `rows() * cols() * size()`.
    fn poly_count(&self) -> usize {
        [self.rows(), self.cols(), self.size()].iter().product()
    }
}
37
38/// Read-only access to the underlying data container of a layout type.
39pub trait DataView {
40    type D: Data;
41    fn data(&self) -> &Self::D;
42}
43
44/// Mutable access to the underlying data container of a layout type.
45pub trait DataViewMut: DataView {
46    fn data_mut(&mut self) -> &mut Self::D;
47}
48
49/// Read-only view into a polynomial container's coefficient data.
50///
51/// Coefficients are stored in a **limb-major, column-minor** layout.
52/// For a container with `cols` columns and `size` limbs, limb `j` of
53/// column `i` starts at scalar offset `n * (j * cols + i)`.
54///
55/// The associated `Scalar` type is `i64` for coefficient-domain types
56/// and a backend-specific type for DFT/big representations.
57pub trait ZnxView: ZnxInfos + DataView<D: HostDataRef> {
58    type Scalar: Copy + Zero + Display + Debug + Pod;
59
60    /// Returns a non-mutable pointer to the underlying coefficients array.
61    fn as_ptr(&self) -> *const Self::Scalar {
62        self.data().as_ref().as_ptr() as *const Self::Scalar
63    }
64
65    /// Returns a non-mutable reference to the entire underlying coefficient array.
66    fn raw(&self) -> &[Self::Scalar] {
67        unsafe { std::slice::from_raw_parts(self.as_ptr(), self.n() * self.poly_count()) }
68    }
69
70    /// Returns a non-mutable pointer starting at the j-th small polynomial of the i-th column.
71    fn at_ptr(&self, i: usize, j: usize) -> *const Self::Scalar {
72        assert!(i < self.cols(), "cols: {} >= self.cols(): {}", i, self.cols());
73        assert!(j < self.size(), "size: {} >= self.size(): {}", j, self.size());
74        let offset: usize = self.n() * (j * self.cols() + i);
75        unsafe { self.as_ptr().add(offset) }
76    }
77
78    /// Returns non-mutable reference to the (i, j)-th small polynomial.
79    fn at(&self, i: usize, j: usize) -> &[Self::Scalar] {
80        unsafe { std::slice::from_raw_parts(self.at_ptr(i, j), self.n()) }
81    }
82}
83
84/// Mutable view into a polynomial container's coefficient data.
85///
86/// Extends [`ZnxView`] with mutable pointer and slice accessors.
87pub trait ZnxViewMut: ZnxView + DataViewMut<D: HostDataMut> {
88    /// Returns a mutable pointer to the underlying coefficients array.
89    fn as_mut_ptr(&mut self) -> *mut Self::Scalar {
90        self.data_mut().as_mut().as_mut_ptr() as *mut Self::Scalar
91    }
92
93    /// Returns a mutable reference to the entire underlying coefficient array.
94    fn raw_mut(&mut self) -> &mut [Self::Scalar] {
95        unsafe { std::slice::from_raw_parts_mut(self.as_mut_ptr(), self.n() * self.poly_count()) }
96    }
97
98    /// Returns a mutable pointer starting at the j-th small polynomial of the i-th column.
99    fn at_mut_ptr(&mut self, i: usize, j: usize) -> *mut Self::Scalar {
100        assert!(i < self.cols(), "cols: {} >= self.cols(): {}", i, self.cols());
101        assert!(j < self.size(), "size: {} >= self.size(): {}", j, self.size());
102        let offset: usize = self.n() * (j * self.cols() + i);
103        unsafe { self.as_mut_ptr().add(offset) }
104    }
105
106    /// Returns mutable reference to the (i, j)-th small polynomial.
107    fn at_mut(&mut self, i: usize, j: usize) -> &mut [Self::Scalar] {
108        unsafe { std::slice::from_raw_parts_mut(self.at_mut_ptr(i, j), self.n()) }
109    }
110}
111
112// Note: Cannot provide blanket impl of ZnxView because Scalar is not known.
113impl<T> ZnxViewMut for T where T: ZnxView + DataViewMut<D: HostDataMut> {}
114
/// Zero-fill operations for polynomial containers.
///
/// Only `Sized` types can implement this trait.
pub trait ZnxZero: Sized {
    /// Overwrites every coefficient, in every column and every limb, with zero.
    fn zero(&mut self);

    /// Overwrites with zero only the coefficients of limb `j` in column `i`.
    fn zero_at(&mut self, i: usize, j: usize);
}
125
126/// Fill a polynomial container with uniformly distributed random coefficients.
127pub trait FillUniform {
128    /// Fills all coefficients with values drawn uniformly from
129    /// `[-2^(log_bound-1), 2^(log_bound-1))`.
130    ///
131    /// When `log_bound == 64`, all 64 bits are used (full `i64` range).
132    ///
133    /// # Panics
134    ///
135    /// Panics if `log_bound == 0`.
136    fn fill_uniform(&mut self, log_bound: usize, source: &mut Source);
137}