Skip to main content

poulpy_hal/layouts/
mat_znx.rs

1use crate::{
2    alloc_aligned,
3    layouts::{
4        Backend, Data, DataView, DataViewMut, DigestU64, FillUniform, HostDataMut, HostDataRef, ReaderFrom, ToOwnedDeep, VecZnx,
5        WriterTo, ZnxInfos, ZnxView, ZnxViewMut, ZnxZero,
6    },
7    source::Source,
8};
9use std::{
10    fmt,
11    hash::{DefaultHasher, Hasher},
12};
13
14use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
15use rand::Rng;
16
/// Dimensions of a [`MatZnx`].
///
/// Stores the ring degree `n`, the limb count `size`, the row count and the
/// input/output column counts. The field order below is part of the
/// `#[repr(C)]` layout and of the derived `Hash`/`Debug` output, so it must
/// not be rearranged.
#[repr(C)]
#[derive(PartialEq, Eq, Clone, Copy, Hash, Debug, Default)]
pub struct MatZnxShape {
    n: usize,
    size: usize,
    rows: usize,
    cols_in: usize,
    cols_out: usize,
}

impl MatZnxShape {
    /// Builds a shape from its five dimensions.
    ///
    /// The argument order is `n, rows, cols_in, cols_out, size`, which differs
    /// from the field order.
    pub const fn new(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> Self {
        Self {
            rows,
            cols_in,
            cols_out,
            size,
            n,
        }
    }

    /// Ring degree `n` (number of coefficients of each small polynomial).
    pub const fn n(self) -> usize {
        let Self { n, .. } = self;
        n
    }

    /// Number of limbs of each entry.
    pub const fn size(self) -> usize {
        let Self { size, .. } = self;
        size
    }

    /// Number of matrix rows.
    pub const fn rows(self) -> usize {
        let Self { rows, .. } = self;
        rows
    }

    /// Number of entries per row.
    pub const fn cols_in(self) -> usize {
        let Self { cols_in, .. } = self;
        cols_in
    }

    /// Number of columns of each entry.
    pub const fn cols_out(self) -> usize {
        let Self { cols_out, .. } = self;
        cols_out
    }
}
58
59/// Matrix of polynomials in `Z[X]/(X^N + 1)`.
60///
61/// A `MatZnx` has `rows` rows, each containing `cols_in` entries.
62/// Each entry is itself a [`VecZnx`] with `cols_out` columns and `size` limbs.
63/// This gives a total of `rows * cols_in * cols_out * size` small polynomials.
64///
65/// Used primarily as the plaintext input to [`VmpPrepare`](crate::api::VmpPrepare),
66/// which converts a `MatZnx` into a prepared [`VmpPMat`](crate::layouts::VmpPMat)
67/// for vector-matrix products.
68#[repr(C)]
69#[derive(PartialEq, Eq, Clone, Hash)]
70pub struct MatZnx<D: Data> {
71    data: D,
72    shape: MatZnxShape,
73}
74
75impl<D: HostDataRef> DigestU64 for MatZnx<D> {
76    fn digest_u64(&self) -> u64 {
77        let mut h: DefaultHasher = DefaultHasher::new();
78        h.write(self.data.as_ref());
79        h.write_usize(self.n());
80        h.write_usize(self.size());
81        h.write_usize(self.rows());
82        h.write_usize(self.cols_in());
83        h.write_usize(self.cols_out());
84        h.finish()
85    }
86}
87
88impl<D: HostDataRef> ToOwnedDeep for MatZnx<D> {
89    type Owned = MatZnx<Vec<u8>>;
90    fn to_owned_deep(&self) -> Self::Owned {
91        MatZnx {
92            data: self.data.as_ref().to_vec(),
93            shape: self.shape,
94        }
95    }
96}
97
98impl<D: HostDataRef> fmt::Debug for MatZnx<D> {
99    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
100        write!(f, "{self}")
101    }
102}
103
104impl<D: Data> ZnxInfos for MatZnx<D> {
105    fn cols(&self) -> usize {
106        self.shape.cols_in()
107    }
108
109    fn rows(&self) -> usize {
110        self.shape.rows()
111    }
112
113    fn n(&self) -> usize {
114        self.shape.n()
115    }
116
117    fn size(&self) -> usize {
118        self.shape.size()
119    }
120
121    fn poly_count(&self) -> usize {
122        self.rows() * self.cols_in() * self.cols_out() * self.size()
123    }
124}
125
126impl<D: Data> DataView for MatZnx<D> {
127    type D = D;
128    fn data(&self) -> &Self::D {
129        &self.data
130    }
131}
132
133impl<D: Data> DataViewMut for MatZnx<D> {
134    fn data_mut(&mut self) -> &mut Self::D {
135        &mut self.data
136    }
137}
138
139impl<D: HostDataRef> ZnxView for MatZnx<D> {
140    type Scalar = i64;
141}
142
143impl<D: Data> MatZnx<D> {
144    pub fn shape(&self) -> MatZnxShape {
145        self.shape
146    }
147
148    pub fn n(&self) -> usize {
149        self.shape.n()
150    }
151
152    pub fn rows(&self) -> usize {
153        self.shape.rows()
154    }
155
156    pub fn size(&self) -> usize {
157        self.shape.size()
158    }
159
160    /// Returns the number of input columns (first matrix dimension after rows).
161    pub fn cols_in(&self) -> usize {
162        self.shape.cols_in()
163    }
164
165    /// Returns the number of output columns (the column count of each inner [`VecZnx`]).
166    pub fn cols_out(&self) -> usize {
167        self.shape.cols_out()
168    }
169
170    /// Consumes the `MatZnx` and returns its backing data.
171    pub fn into_data(self) -> D {
172        self.data
173    }
174}
175
176impl MatZnx<Vec<u8>> {
177    /// Returns the number of bytes required to store the matrix.
178    pub fn bytes_of(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
179        rows * cols_in * VecZnx::<Vec<u8>>::bytes_of(n, cols_out, size)
180    }
181
182    /// Allocates a zero-initialized `MatZnx` aligned to [`DEFAULTALIGN`](crate::DEFAULTALIGN).
183    pub(crate) fn alloc(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> Self {
184        let data: Vec<u8> = alloc_aligned(Self::bytes_of(n, rows, cols_in, cols_out, size));
185        Self {
186            data,
187            shape: MatZnxShape::new(n, rows, cols_in, cols_out, size),
188        }
189    }
190
191    pub fn from_bytes(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize, bytes: impl Into<Vec<u8>>) -> Self {
192        let data: Vec<u8> = bytes.into();
193        assert!(data.len() == Self::bytes_of(n, rows, cols_in, cols_out, size));
194        crate::assert_alignment(data.as_ptr());
195        Self {
196            data,
197            shape: MatZnxShape::new(n, rows, cols_in, cols_out, size),
198        }
199    }
200}
201
202impl<D: HostDataRef> MatZnx<D> {
203    /// Returns a shared [`VecZnx`] view of the entry at `(row, col)`.
204    ///
205    /// # Panics (debug)
206    ///
207    /// Debug-asserts that `row < rows` and `col < cols_in`.
208    pub fn at(&self, row: usize, col: usize) -> VecZnx<&[u8]> {
209        #[cfg(debug_assertions)]
210        {
211            assert!(row < self.rows(), "rows: {} >= {}", row, self.rows());
212            assert!(col < self.cols_in(), "cols: {} >= {}", col, self.cols_in());
213        }
214
215        let self_ref = MatZnx {
216            data: self.data.as_ref(),
217            shape: self.shape,
218        };
219        let nb_bytes: usize = VecZnx::<Vec<u8>>::bytes_of(self.n(), self.cols_out(), self.size());
220        let start: usize = nb_bytes * self.cols() * row + col * nb_bytes;
221        let end: usize = start + nb_bytes;
222
223        VecZnx::from_data(&self_ref.data[start..end], self.n(), self.cols_out(), self.size())
224    }
225}
226
227impl<D: HostDataMut> MatZnx<D> {
228    /// Returns a mutable [`VecZnx`] view of the entry at `(row, col)`.
229    ///
230    /// # Panics (debug)
231    ///
232    /// Debug-asserts that `row < rows` and `col < cols_in`.
233    pub fn at_mut(&mut self, row: usize, col: usize) -> VecZnx<&mut [u8]> {
234        #[cfg(debug_assertions)]
235        {
236            assert!(row < self.rows(), "rows: {} >= {}", row, self.rows());
237            assert!(col < self.cols_in(), "cols: {} >= {}", col, self.cols_in());
238        }
239
240        let n: usize = self.n();
241        let rows: usize = self.rows();
242        let cols_out: usize = self.cols_out();
243        let cols_in: usize = self.cols_in();
244        let size: usize = self.size();
245
246        let self_ref = MatZnx {
247            data: self.data.as_mut(),
248            shape: MatZnxShape::new(n, rows, cols_in, cols_out, size),
249        };
250        let nb_bytes: usize = VecZnx::<Vec<u8>>::bytes_of(n, cols_out, size);
251        let start: usize = nb_bytes * cols_in * row + col * nb_bytes;
252        let end: usize = start + nb_bytes;
253
254        VecZnx::from_data(&mut self_ref.data[start..end], n, cols_out, size)
255    }
256}
257
258/// Returns a shared backend-native entry view of a backend-owned `MatZnx`.
259pub trait MatZnxAtBackendRef<B: Backend> {
260    fn at_backend(&self, row: usize, col: usize) -> VecZnx<B::BufRef<'_>>;
261}
262
263impl<B: Backend> MatZnxAtBackendRef<B> for MatZnx<B::OwnedBuf> {
264    fn at_backend(&self, row: usize, col: usize) -> VecZnx<B::BufRef<'_>> {
265        #[cfg(debug_assertions)]
266        {
267            assert!(row < self.rows(), "rows: {} >= {}", row, self.rows());
268            assert!(col < self.cols_in(), "cols: {} >= {}", col, self.cols_in());
269        }
270
271        let nb_bytes: usize = VecZnx::<Vec<u8>>::bytes_of(self.n(), self.cols_out(), self.size());
272        let start: usize = nb_bytes * self.cols() * row + col * nb_bytes;
273        let end: usize = start + nb_bytes;
274
275        VecZnx::from_data(
276            B::region(&self.data, start, end - start),
277            self.n(),
278            self.cols_out(),
279            self.size(),
280        )
281    }
282}
283
284pub fn mat_znx_at_backend_ref_from_ref<'a, 'b, B: Backend + 'b>(
285    mat: &'a MatZnx<B::BufRef<'b>>,
286    row: usize,
287    col: usize,
288) -> VecZnx<B::BufRef<'a>> {
289    #[cfg(debug_assertions)]
290    {
291        assert!(row < mat.rows(), "rows: {} >= {}", row, mat.rows());
292        assert!(col < mat.cols_in(), "cols: {} >= {}", col, mat.cols_in());
293    }
294
295    let nb_bytes: usize = VecZnx::<Vec<u8>>::bytes_of(mat.n(), mat.cols_out(), mat.size());
296    let start: usize = nb_bytes * mat.cols() * row + col * nb_bytes;
297    let end: usize = start + nb_bytes;
298
299    VecZnx::from_data(
300        B::region_ref(&mat.data, start, end - start),
301        mat.n(),
302        mat.cols_out(),
303        mat.size(),
304    )
305}
306
307pub fn mat_znx_at_backend_ref_from_mut<'a, 'b, B: Backend + 'b>(
308    mat: &'a MatZnx<B::BufMut<'b>>,
309    row: usize,
310    col: usize,
311) -> VecZnx<B::BufRef<'a>> {
312    #[cfg(debug_assertions)]
313    {
314        assert!(row < mat.rows(), "rows: {} >= {}", row, mat.rows());
315        assert!(col < mat.cols_in(), "cols: {} >= {}", col, mat.cols_in());
316    }
317
318    let nb_bytes: usize = VecZnx::<Vec<u8>>::bytes_of(mat.n(), mat.cols_out(), mat.size());
319    let start: usize = nb_bytes * mat.cols() * row + col * nb_bytes;
320    let end: usize = start + nb_bytes;
321
322    VecZnx::from_data(
323        B::region_ref_mut(&mat.data, start, end - start),
324        mat.n(),
325        mat.cols_out(),
326        mat.size(),
327    )
328}
329
330/// Returns a mutable backend-native entry view of a backend-owned `MatZnx`.
331pub trait MatZnxAtBackendMut<B: Backend> {
332    fn at_backend_mut(&mut self, row: usize, col: usize) -> VecZnx<B::BufMut<'_>>;
333}
334
335impl<B: Backend> MatZnxAtBackendMut<B> for MatZnx<B::OwnedBuf> {
336    fn at_backend_mut(&mut self, row: usize, col: usize) -> VecZnx<B::BufMut<'_>> {
337        #[cfg(debug_assertions)]
338        {
339            assert!(row < self.rows(), "rows: {} >= {}", row, self.rows());
340            assert!(col < self.cols_in(), "cols: {} >= {}", col, self.cols_in());
341        }
342
343        let n: usize = self.n();
344        let cols_out: usize = self.cols_out();
345        let cols_in: usize = self.cols_in();
346        let size: usize = self.size();
347        let nb_bytes: usize = VecZnx::<Vec<u8>>::bytes_of(n, cols_out, size);
348        let start: usize = nb_bytes * cols_in * row + col * nb_bytes;
349        let end: usize = start + nb_bytes;
350
351        VecZnx::from_data(B::region_mut(&mut self.data, start, end - start), n, cols_out, size)
352    }
353}
354
355pub fn mat_znx_at_backend_mut_from_mut<'a, 'b, B: Backend + 'b>(
356    mat: &'a mut MatZnx<B::BufMut<'b>>,
357    row: usize,
358    col: usize,
359) -> VecZnx<B::BufMut<'a>> {
360    #[cfg(debug_assertions)]
361    {
362        assert!(row < mat.rows(), "rows: {} >= {}", row, mat.rows());
363        assert!(col < mat.cols_in(), "cols: {} >= {}", col, mat.cols_in());
364    }
365
366    let n: usize = mat.n();
367    let cols_out: usize = mat.cols_out();
368    let cols_in: usize = mat.cols_in();
369    let size: usize = mat.size();
370    let nb_bytes: usize = VecZnx::<Vec<u8>>::bytes_of(n, cols_out, size);
371    let start: usize = nb_bytes * cols_in * row + col * nb_bytes;
372    let end: usize = start + nb_bytes;
373
374    VecZnx::from_data(B::region_mut_ref(&mut mat.data, start, end - start), n, cols_out, size)
375}
376
377impl<D: HostDataMut> FillUniform for MatZnx<D> {
378    fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) {
379        match log_bound {
380            64 => source.fill_bytes(self.data.as_mut()),
381            0 => panic!("invalid log_bound, cannot be zero"),
382            _ => {
383                let mask: u64 = (1u64 << log_bound) - 1;
384                for x in self.raw_mut().iter_mut() {
385                    let r = source.next_u64() & mask;
386                    *x = ((r << (64 - log_bound)) as i64) >> (64 - log_bound);
387                }
388            }
389        }
390    }
391}
392
393/// Owned `MatZnx` backed by a `Vec<u8>`.
394pub type MatZnxOwned = MatZnx<Vec<u8>>;
395/// Mutably borrowed `MatZnx`.
396pub type MatZnxMut<'a> = MatZnx<&'a mut [u8]>;
397/// Immutably borrowed `MatZnx`.
398pub type MatZnxRef<'a> = MatZnx<&'a [u8]>;
399/// Shared backend-native borrow of a `MatZnx`.
400pub type MatZnxBackendRef<'a, B> = MatZnx<<B as Backend>::BufRef<'a>>;
401/// Mutable backend-native borrow of a `MatZnx`.
402pub type MatZnxBackendMut<'a, B> = MatZnx<<B as Backend>::BufMut<'a>>;
403
404/// Borrow a backend-owned `MatZnx` using the backend's native view type.
405pub trait MatZnxToBackendRef<B: Backend> {
406    fn to_backend_ref(&self) -> MatZnxBackendRef<'_, B>;
407}
408
409impl<B: Backend> MatZnxToBackendRef<B> for MatZnx<B::OwnedBuf> {
410    fn to_backend_ref(&self) -> MatZnxBackendRef<'_, B> {
411        MatZnx {
412            data: B::view(&self.data),
413            shape: self.shape,
414        }
415    }
416}
417
418impl<'b, B: Backend + 'b> MatZnxToBackendRef<B> for &MatZnx<B::BufRef<'b>> {
419    fn to_backend_ref(&self) -> MatZnxBackendRef<'_, B> {
420        mat_znx_backend_ref_from_ref::<B>(self)
421    }
422}
423
424impl<'b, B: Backend + 'b> MatZnxToBackendRef<B> for &mut MatZnx<B::BufMut<'b>> {
425    fn to_backend_ref(&self) -> MatZnxBackendRef<'_, B> {
426        mat_znx_backend_ref_from_mut::<B>(self)
427    }
428}
429
430pub fn mat_znx_backend_ref_from_ref<'a, 'b, B: Backend + 'b>(mat: &'a MatZnx<B::BufRef<'b>>) -> MatZnxBackendRef<'a, B> {
431    MatZnx {
432        data: B::view_ref(&mat.data),
433        shape: mat.shape,
434    }
435}
436
437pub fn mat_znx_backend_ref_from_mut<'a, 'b, B: Backend + 'b>(mat: &'a MatZnx<B::BufMut<'b>>) -> MatZnxBackendRef<'a, B> {
438    MatZnx {
439        data: B::view_ref_mut(&mat.data),
440        shape: mat.shape,
441    }
442}
443
444/// Mutably borrow a backend-owned `MatZnx` using the backend's native view type.
445pub trait MatZnxToBackendMut<B: Backend> {
446    fn to_backend_mut(&mut self) -> MatZnxBackendMut<'_, B>;
447}
448
449impl<B: Backend> MatZnxToBackendMut<B> for MatZnx<B::OwnedBuf> {
450    fn to_backend_mut(&mut self) -> MatZnxBackendMut<'_, B> {
451        MatZnx {
452            data: B::view_mut(&mut self.data),
453            shape: self.shape,
454        }
455    }
456}
457
458impl<'b, B: Backend + 'b> MatZnxToBackendMut<B> for &mut MatZnx<B::BufMut<'b>> {
459    fn to_backend_mut(&mut self) -> MatZnxBackendMut<'_, B> {
460        mat_znx_backend_mut_from_mut::<B>(self)
461    }
462}
463
464pub fn mat_znx_backend_mut_from_mut<'a, 'b, B: Backend + 'b>(mat: &'a mut MatZnx<B::BufMut<'b>>) -> MatZnxBackendMut<'a, B> {
465    MatZnx {
466        data: B::view_mut_ref(&mut mat.data),
467        shape: mat.shape,
468    }
469}
470
471impl<D: Data> MatZnx<D> {
472    pub fn from_data(data: D, n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> Self {
473        Self {
474            data,
475            shape: MatZnxShape::new(n, rows, cols_in, cols_out, size),
476        }
477    }
478}
479
480impl<D: HostDataMut> ReaderFrom for MatZnx<D> {
481    fn read_from<R: std::io::Read>(&mut self, reader: &mut R) -> std::io::Result<()> {
482        let new_n: usize = reader.read_u64::<LittleEndian>()? as usize;
483        let new_size: usize = reader.read_u64::<LittleEndian>()? as usize;
484        let new_rows: usize = reader.read_u64::<LittleEndian>()? as usize;
485        let new_cols_in: usize = reader.read_u64::<LittleEndian>()? as usize;
486        let new_cols_out: usize = reader.read_u64::<LittleEndian>()? as usize;
487        let len: usize = reader.read_u64::<LittleEndian>()? as usize;
488
489        let expected_len: usize = new_rows * new_cols_in * new_n * new_cols_out * new_size * size_of::<i64>();
490        if expected_len != len {
491            return Err(std::io::Error::new(
492                std::io::ErrorKind::InvalidData,
493                format!(
494                    "MatZnx metadata inconsistent: rows={new_rows} * cols_in={new_cols_in} * n={new_n} * cols_out={new_cols_out} * size={new_size} * 8 = {expected_len} != data len={len}"
495                ),
496            ));
497        }
498
499        let buf: &mut [u8] = self.data.as_mut();
500        if buf.len() < len {
501            return Err(std::io::Error::new(
502                std::io::ErrorKind::InvalidData,
503                format!("MatZnx buffer too small: self.data.len()={} < read len={len}", buf.len()),
504            ));
505        }
506        reader.read_exact(&mut buf[..len])?;
507
508        self.shape = MatZnxShape::new(new_n, new_rows, new_cols_in, new_cols_out, new_size);
509        Ok(())
510    }
511}
512
513impl<D: HostDataRef> WriterTo for MatZnx<D> {
514    fn write_to<W: std::io::Write>(&self, writer: &mut W) -> std::io::Result<()> {
515        writer.write_u64::<LittleEndian>(self.n() as u64)?;
516        writer.write_u64::<LittleEndian>(self.size() as u64)?;
517        writer.write_u64::<LittleEndian>(self.rows() as u64)?;
518        writer.write_u64::<LittleEndian>(self.cols_in() as u64)?;
519        writer.write_u64::<LittleEndian>(self.cols_out() as u64)?;
520        let logical_len: usize = MatZnx::<Vec<u8>>::bytes_of(self.n(), self.rows(), self.cols_in(), self.cols_out(), self.size());
521        let buf: &[u8] = self.data.as_ref();
522        if buf.len() < logical_len {
523            return Err(std::io::Error::new(
524                std::io::ErrorKind::InvalidData,
525                format!(
526                    "MatZnx buffer too small: self.data.len()={} < logical_len={logical_len}",
527                    buf.len()
528                ),
529            ));
530        }
531        writer.write_u64::<LittleEndian>(logical_len as u64)?;
532        writer.write_all(&buf[..logical_len])?;
533        Ok(())
534    }
535}
536
537impl<D: HostDataRef> fmt::Display for MatZnx<D> {
538    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
539        writeln!(
540            f,
541            "MatZnx(n={}, rows={}, cols_in={}, cols_out={}, size={})",
542            self.n(),
543            self.rows(),
544            self.cols_in(),
545            self.cols_out(),
546            self.size()
547        )?;
548
549        for row_i in 0..self.rows() {
550            writeln!(f, "Row {row_i}:")?;
551            for col_i in 0..self.cols_in() {
552                writeln!(f, "cols_in {col_i}:")?;
553                writeln!(f, "{}:", self.at(row_i, col_i))?;
554            }
555        }
556        Ok(())
557    }
558}
559
560impl<D: HostDataMut> ZnxZero for MatZnx<D> {
561    fn zero(&mut self) {
562        self.raw_mut().fill(0)
563    }
564
565    fn zero_at(&mut self, i: usize, j: usize) {
566        self.at_mut(i, j).zero();
567    }
568}