poulpy_hal/layouts/znx_base.rs
1use std::fmt::{Debug, Display};
2
3use crate::{
4 layouts::{Data, HostDataMut, HostDataRef},
5 source::Source,
6};
7use bytemuck::Pod;
8use rand_distr::num_traits::Zero;
9
/// Shape metadata of a polynomial container.
///
/// Layout types implement `ZnxInfos` to expose their ring degree,
/// row/column counts, and limb count; the provided methods derive further
/// quantities from those four values.
pub trait ZnxInfos {
    /// Returns the ring degree `N` of the polynomials in `Z[X]/(X^N + 1)`.
    fn n(&self) -> usize;

    /// Returns the base two logarithm of the ring degree.
    ///
    /// The value is `ceil(log2(n()))`, which is the exact logarithm whenever
    /// `n()` is a power of two.
    ///
    /// # Panics
    ///
    /// Panics if `n()` is zero, for which the logarithm is undefined.
    fn log_n(&self) -> usize {
        let n = self.n();
        // Without this guard `n - 1` underflows: an opaque overflow panic in
        // debug builds and a silent, meaningless result of 64 in release.
        assert!(n > 0, "log_n is undefined for a ring degree of 0");
        // The bit length of `n - 1` is the smallest `k` such that `n <= 2^k`.
        (usize::BITS - (n - 1).leading_zeros()) as usize
    }

    /// Returns the number of rows.
    fn rows(&self) -> usize;

    /// Returns the number of polynomials in each row.
    fn cols(&self) -> usize;

    /// Returns the number of limbs per polynomial.
    fn size(&self) -> usize;

    /// Returns the total number of small polynomials, i.e.
    /// `rows() * cols() * size()`.
    fn poly_count(&self) -> usize {
        self.rows() * self.cols() * self.size()
    }
}
37
38/// Read-only access to the underlying data container of a layout type.
39pub trait DataView {
40 type D: Data;
41 fn data(&self) -> &Self::D;
42}
43
44/// Mutable access to the underlying data container of a layout type.
45pub trait DataViewMut: DataView {
46 fn data_mut(&mut self) -> &mut Self::D;
47}
48
49/// Read-only view into a polynomial container's coefficient data.
50///
51/// Coefficients are stored in a **limb-major, column-minor** layout.
52/// For a container with `cols` columns and `size` limbs, limb `j` of
53/// column `i` starts at scalar offset `n * (j * cols + i)`.
54///
55/// The associated `Scalar` type is `i64` for coefficient-domain types
56/// and a backend-specific type for DFT/big representations.
57pub trait ZnxView: ZnxInfos + DataView<D: HostDataRef> {
58 type Scalar: Copy + Zero + Display + Debug + Pod;
59
60 /// Returns a non-mutable pointer to the underlying coefficients array.
61 fn as_ptr(&self) -> *const Self::Scalar {
62 self.data().as_ref().as_ptr() as *const Self::Scalar
63 }
64
65 /// Returns a non-mutable reference to the entire underlying coefficient array.
66 fn raw(&self) -> &[Self::Scalar] {
67 unsafe { std::slice::from_raw_parts(self.as_ptr(), self.n() * self.poly_count()) }
68 }
69
70 /// Returns a non-mutable pointer starting at the j-th small polynomial of the i-th column.
71 fn at_ptr(&self, i: usize, j: usize) -> *const Self::Scalar {
72 assert!(i < self.cols(), "cols: {} >= self.cols(): {}", i, self.cols());
73 assert!(j < self.size(), "size: {} >= self.size(): {}", j, self.size());
74 let offset: usize = self.n() * (j * self.cols() + i);
75 unsafe { self.as_ptr().add(offset) }
76 }
77
78 /// Returns non-mutable reference to the (i, j)-th small polynomial.
79 fn at(&self, i: usize, j: usize) -> &[Self::Scalar] {
80 unsafe { std::slice::from_raw_parts(self.at_ptr(i, j), self.n()) }
81 }
82}
83
84/// Mutable view into a polynomial container's coefficient data.
85///
86/// Extends [`ZnxView`] with mutable pointer and slice accessors.
87pub trait ZnxViewMut: ZnxView + DataViewMut<D: HostDataMut> {
88 /// Returns a mutable pointer to the underlying coefficients array.
89 fn as_mut_ptr(&mut self) -> *mut Self::Scalar {
90 self.data_mut().as_mut().as_mut_ptr() as *mut Self::Scalar
91 }
92
93 /// Returns a mutable reference to the entire underlying coefficient array.
94 fn raw_mut(&mut self) -> &mut [Self::Scalar] {
95 unsafe { std::slice::from_raw_parts_mut(self.as_mut_ptr(), self.n() * self.poly_count()) }
96 }
97
98 /// Returns a mutable pointer starting at the j-th small polynomial of the i-th column.
99 fn at_mut_ptr(&mut self, i: usize, j: usize) -> *mut Self::Scalar {
100 assert!(i < self.cols(), "cols: {} >= self.cols(): {}", i, self.cols());
101 assert!(j < self.size(), "size: {} >= self.size(): {}", j, self.size());
102 let offset: usize = self.n() * (j * self.cols() + i);
103 unsafe { self.as_mut_ptr().add(offset) }
104 }
105
106 /// Returns mutable reference to the (i, j)-th small polynomial.
107 fn at_mut(&mut self, i: usize, j: usize) -> &mut [Self::Scalar] {
108 unsafe { std::slice::from_raw_parts_mut(self.at_mut_ptr(i, j), self.n()) }
109 }
110}
111
112// Note: Cannot provide blanket impl of ZnxView because Scalar is not known.
113impl<T> ZnxViewMut for T where T: ZnxView + DataViewMut<D: HostDataMut> {}
114
/// Operations that reset coefficients of a polynomial container to zero.
pub trait ZnxZero: Sized {
    /// Zeroes every coefficient, across all columns and limbs.
    fn zero(&mut self);

    /// Zeroes every coefficient of limb `j` of column `i`.
    fn zero_at(&mut self, i: usize, j: usize);
}
125
126/// Fill a polynomial container with uniformly distributed random coefficients.
127pub trait FillUniform {
128 /// Fills all coefficients with values drawn uniformly from
129 /// `[-2^(log_bound-1), 2^(log_bound-1))`.
130 ///
131 /// When `log_bound == 64`, all 64 bits are used (full `i64` range).
132 ///
133 /// # Panics
134 ///
135 /// Panics if `log_bound == 0`.
136 fn fill_uniform(&mut self, log_bound: usize, source: &mut Source);
137}