Skip to main content

poulpy_hal/layouts/
vec_znx.rs

1use std::{
2    fmt,
3    hash::{DefaultHasher, Hasher},
4};
5
6use crate::{
7    alloc_aligned,
8    layouts::{
9        Backend, Data, DataView, DataViewMut, DigestU64, FillUniform, HostDataMut, HostDataRef, ReaderFrom, ScalarZnx,
10        ToOwnedDeep, WriterTo, ZnxInfos, ZnxView, ZnxViewMut, ZnxZero,
11    },
12    source::Source,
13};
14
15use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
16use rand::Rng;
17
18/// A vector of polynomials in `Z[X]/(X^N + 1)` with limb-decomposed
19/// (base-2^k) representation.
20///
21/// This is the central data type of the crate. Each `VecZnx` contains
22/// `cols` independent polynomial columns, each decomposed into `size`
23/// limbs of `N` coefficients. Coefficients are `i64` values.
24///
25/// **Memory layout:** limb-major, column-minor. Limb `j` of column `i`
26/// starts at scalar offset `N * (j * cols + i)`.
27///
28/// The type parameter `D` controls ownership: `Vec<u8>` for owned,
29/// `&[u8]` for shared borrows, `&mut [u8]` for mutable borrows.
30///
31/// **Invariant:** `size <= max_size`. The `max_size` field records the
32/// allocated capacity; `size` can be reduced without reallocation.
33#[repr(C)]
34#[derive(PartialEq, Eq, Clone, Copy, Hash, Debug, Default)]
35pub struct VecZnxShape {
36    n: usize,
37    cols: usize,
38    size: usize,
39    max_size: usize,
40}
41
42impl VecZnxShape {
43    pub const fn new(n: usize, cols: usize, size: usize, max_size: usize) -> Self {
44        Self { n, cols, size, max_size }
45    }
46
47    pub const fn n(self) -> usize {
48        self.n
49    }
50
51    pub const fn cols(self) -> usize {
52        self.cols
53    }
54
55    pub const fn size(self) -> usize {
56        self.size
57    }
58
59    pub const fn max_size(self) -> usize {
60        self.max_size
61    }
62
63    pub const fn with_size(self, size: usize) -> Self {
64        assert!(size <= self.max_size);
65        Self { size, ..self }
66    }
67}
68
69#[repr(C)]
70#[derive(PartialEq, Eq, Clone, Copy, Hash)]
71pub struct VecZnx<D: Data> {
72    pub data: D,
73    shape: VecZnxShape,
74}
75
76impl<D: HostDataRef> VecZnx<D> {
77    /// Returns a read-only [`ScalarZnx`] view of a single limb of a single column.
78    pub fn as_scalar_znx_ref(&self, col: usize, limb: usize) -> ScalarZnx<&[u8]> {
79        ScalarZnx::from_data(bytemuck::cast_slice(self.at(col, limb)), self.n(), 1)
80    }
81}
82
83impl<D: HostDataMut> VecZnx<D> {
84    /// Returns a mutable [`ScalarZnx`] view of a single limb of a single column.
85    pub fn as_scalar_znx_mut(&mut self, col: usize, limb: usize) -> ScalarZnx<&mut [u8]> {
86        let n = self.n();
87        ScalarZnx::from_data(bytemuck::cast_slice_mut(self.at_mut(col, limb)), n, 1)
88    }
89}
90
91impl<D: Data + Default> Default for VecZnx<D> {
92    fn default() -> Self {
93        Self {
94            data: D::default(),
95            shape: VecZnxShape::default(),
96        }
97    }
98}
99
100impl<D: HostDataRef> DigestU64 for VecZnx<D> {
101    fn digest_u64(&self) -> u64 {
102        let mut h: DefaultHasher = DefaultHasher::new();
103        h.write(self.data.as_ref());
104        h.write_usize(self.n());
105        h.write_usize(self.cols());
106        h.write_usize(self.size());
107        h.write_usize(self.max_size());
108        h.finish()
109    }
110}
111
112impl<D: HostDataRef> ToOwnedDeep for VecZnx<D> {
113    type Owned = VecZnx<Vec<u8>>;
114    fn to_owned_deep(&self) -> Self::Owned {
115        VecZnx {
116            data: self.data.as_ref().to_vec(),
117            shape: self.shape,
118        }
119    }
120}
121
122impl<D: Data> VecZnx<D> {
123    /// Rebuilds this backend-owned vector as a host-owned [`VecZnx<Vec<u8>>`].
124    pub fn to_host_owned<BE>(&self) -> VecZnx<Vec<u8>>
125    where
126        BE: Backend<OwnedBuf = D>,
127    {
128        let shape = self.shape();
129        VecZnx::from_data_with_max_size(
130            crate::layouts::HostBytesBackend::from_bytes(BE::to_host_bytes(&self.data)),
131            shape.n(),
132            shape.cols(),
133            shape.size(),
134            shape.max_size(),
135        )
136    }
137
138    /// Formats this backend-owned vector through the existing host [`fmt::Display`] implementation.
139    pub fn display_host<BE>(&self) -> String
140    where
141        BE: Backend<OwnedBuf = D>,
142    {
143        self.to_host_owned::<BE>().to_string()
144    }
145}
146
147impl<D: HostDataRef> fmt::Debug for VecZnx<D> {
148    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
149        write!(f, "{self}")
150    }
151}
152
153impl<D: Data> ZnxInfos for VecZnx<D> {
154    fn cols(&self) -> usize {
155        self.shape.cols()
156    }
157
158    fn rows(&self) -> usize {
159        1
160    }
161
162    fn n(&self) -> usize {
163        self.shape.n()
164    }
165
166    fn size(&self) -> usize {
167        self.shape.size()
168    }
169}
170
171impl<D: Data> DataView for VecZnx<D> {
172    type D = D;
173    fn data(&self) -> &Self::D {
174        &self.data
175    }
176}
177
178impl<D: Data> DataViewMut for VecZnx<D> {
179    fn data_mut(&mut self) -> &mut Self::D {
180        &mut self.data
181    }
182}
183
184impl<D: HostDataRef> ZnxView for VecZnx<D> {
185    type Scalar = i64;
186}
187
188impl<D: Data> VecZnx<D> {
189    pub fn n(&self) -> usize {
190        self.shape.n()
191    }
192
193    pub fn cols(&self) -> usize {
194        self.shape.cols()
195    }
196
197    pub fn size(&self) -> usize {
198        self.shape.size()
199    }
200
201    pub fn shape(&self) -> VecZnxShape {
202        self.shape
203    }
204
205    pub fn with_size(mut self, size: usize) -> Self {
206        assert!(size <= self.max_size());
207        self.shape = self.shape.with_size(size);
208        self
209    }
210
211    /// Returns the allocated limb capacity.
212    pub fn max_size(&self) -> usize {
213        self.shape.max_size()
214    }
215}
216
217impl<D: Data> VecZnx<D> {
218    /// Sets the active limb count.
219    ///
220    /// # Panics
221    ///
222    /// Panics if `size > max_size`.
223    pub fn set_size(&mut self, size: usize) {
224        self.shape = self.shape.with_size(size);
225    }
226}
227
228impl VecZnx<Vec<u8>> {
229    /// Returns the scratch space (in bytes) required by right-shift operations.
230    pub fn rsh_tmp_bytes(n: usize) -> usize {
231        n * size_of::<i64>()
232    }
233
234    /// Reallocates the backing buffer so capacity matches the `new_size` limb count.
235    pub fn reallocate_limbs(&mut self, new_size: usize) {
236        if self.size() == new_size {
237            return;
238        }
239
240        let mut compact: Self = Self::alloc(self.n(), self.cols(), new_size);
241        let copy_len = compact.raw().len().min(self.raw().len());
242        compact.raw_mut()[..copy_len].copy_from_slice(&self.raw()[..copy_len]);
243        *self = compact;
244    }
245}
246
247impl<D: HostDataMut> ZnxZero for VecZnx<D> {
248    fn zero(&mut self) {
249        self.raw_mut().fill(0)
250    }
251    fn zero_at(&mut self, i: usize, j: usize) {
252        self.at_mut(i, j).fill(0);
253    }
254}
255
256impl VecZnx<Vec<u8>> {
257    /// Returns the number of bytes required: `n * cols * size * 8`.
258    pub fn bytes_of(n: usize, cols: usize, size: usize) -> usize {
259        n * cols * size * size_of::<i64>()
260    }
261
262    /// Allocates a zero-initialized `VecZnx` aligned to [`DEFAULTALIGN`](crate::DEFAULTALIGN).
263    /// Sets `max_size = size`.
264    pub(crate) fn alloc(n: usize, cols: usize, size: usize) -> Self {
265        let data: Vec<u8> = alloc_aligned::<u8>(Self::bytes_of(n, cols, size));
266        Self {
267            data,
268            shape: VecZnxShape::new(n, cols, size, size),
269        }
270    }
271
272    /// Wraps an existing byte buffer into a `VecZnx`.
273    ///
274    /// # Panics
275    ///
276    /// Panics if the buffer length does not equal `bytes_of(n, cols, size)` or
277    /// the buffer is not aligned to [`DEFAULTALIGN`](crate::DEFAULTALIGN).
278    pub fn from_bytes(n: usize, cols: usize, size: usize, bytes: impl Into<Vec<u8>>) -> Self {
279        let data: Vec<u8> = bytes.into();
280        assert!(
281            data.len() == Self::bytes_of(n, cols, size),
282            "from_bytes: data.len()={} != bytes_of({}, {}, {})={}",
283            data.len(),
284            n,
285            cols,
286            size,
287            Self::bytes_of(n, cols, size)
288        );
289        crate::assert_alignment(data.as_ptr());
290        Self {
291            data,
292            shape: VecZnxShape::new(n, cols, size, size),
293        }
294    }
295}
296
297impl<D: Data> VecZnx<D> {
298    /// Constructs a `VecZnx` from raw parts without validation.
299    /// Sets `max_size = size`.
300    pub fn from_data(data: D, n: usize, cols: usize, size: usize) -> Self {
301        Self {
302            data,
303            shape: VecZnxShape::new(n, cols, size, size),
304        }
305    }
306
307    /// Constructs a `VecZnx` from raw parts, preserving both `size` and `max_size`.
308    ///
309    /// Used by cross-backend transfer to rebuild a layout over a fresh
310    /// buffer without shrinking its capacity.
311    pub fn from_data_with_max_size(data: D, n: usize, cols: usize, size: usize, max_size: usize) -> Self {
312        Self {
313            data,
314            shape: VecZnxShape::new(n, cols, size, max_size),
315        }
316    }
317}
318
319impl<D: HostDataRef> fmt::Display for VecZnx<D> {
320    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
321        writeln!(f, "VecZnx(n={}, cols={}, size={})", self.n(), self.cols(), self.size())?;
322
323        for col in 0..self.cols() {
324            writeln!(f, "Column {col}:")?;
325            for size in 0..self.size() {
326                let coeffs = self.at(col, size);
327                write!(f, "  Size {size}: [")?;
328
329                let max_show = 16;
330                let show_count = coeffs.len().min(max_show);
331
332                for (i, &coeff) in coeffs.iter().take(show_count).enumerate() {
333                    if i > 0 {
334                        write!(f, ", ")?;
335                    }
336                    write!(f, "{coeff}")?;
337                }
338
339                if coeffs.len() > max_show {
340                    write!(f, ", ... ({} more)", coeffs.len() - max_show)?;
341                }
342
343                writeln!(f, "]")?;
344            }
345        }
346        Ok(())
347    }
348}
349
350impl<D: HostDataMut> FillUniform for VecZnx<D> {
351    fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) {
352        match log_bound {
353            64 => source.fill_bytes(self.data.as_mut()),
354            0 => panic!("invalid log_bound, cannot be zero"),
355            _ => {
356                let mask: u64 = (1u64 << log_bound) - 1;
357                for x in self.raw_mut().iter_mut() {
358                    let r = source.next_u64() & mask;
359                    *x = ((r << (64 - log_bound)) as i64) >> (64 - log_bound);
360                }
361            }
362        }
363    }
364}
365
366/// Owned `VecZnx` backed by a `Vec<u8>`.
367pub type VecZnxOwned = VecZnx<Vec<u8>>;
368/// Mutably borrowed `VecZnx`.
369pub type VecZnxMut<'a> = VecZnx<&'a mut [u8]>;
370/// Immutably borrowed `VecZnx`.
371pub type VecZnxRef<'a> = VecZnx<&'a [u8]>;
372/// Shared backend-native borrow of a `VecZnx`.
373pub type VecZnxBackendRef<'a, B> = VecZnx<<B as Backend>::BufRef<'a>>;
374/// Mutable backend-native borrow of a `VecZnx`.
375pub type VecZnxBackendMut<'a, B> = VecZnx<<B as Backend>::BufMut<'a>>;
376
377/// Returns a shared backend-native scalar view into a backend-owned `VecZnx`.
378pub trait VecZnxAsScalarBackendRef<B: Backend> {
379    fn as_scalar_znx_backend_ref(&self, col: usize, limb: usize) -> ScalarZnx<B::BufRef<'_>>;
380}
381
382impl<B: Backend> VecZnxAsScalarBackendRef<B> for VecZnx<B::OwnedBuf> {
383    fn as_scalar_znx_backend_ref(&self, col: usize, limb: usize) -> ScalarZnx<B::BufRef<'_>> {
384        #[cfg(debug_assertions)]
385        {
386            assert!(limb < self.size(), "size: {limb} >= {}", self.size());
387            assert!(col < self.cols(), "cols: {col} >= {}", self.cols());
388        }
389        let start: usize = (limb * self.cols() + col) * self.n() * size_of::<i64>();
390        let len: usize = self.n() * size_of::<i64>();
391        ScalarZnx::from_data(B::region(&self.data, start, len), self.n(), 1)
392    }
393}
394
395/// Returns a mutable backend-native scalar view into a backend-owned `VecZnx`.
396pub trait VecZnxAsScalarBackendMut<B: Backend> {
397    fn as_scalar_znx_backend_mut(&mut self, col: usize, limb: usize) -> ScalarZnx<B::BufMut<'_>>;
398}
399
400impl<B: Backend> VecZnxAsScalarBackendMut<B> for VecZnx<B::OwnedBuf> {
401    fn as_scalar_znx_backend_mut(&mut self, col: usize, limb: usize) -> ScalarZnx<B::BufMut<'_>> {
402        #[cfg(debug_assertions)]
403        {
404            assert!(limb < self.size(), "size: {limb} >= {}", self.size());
405            assert!(col < self.cols(), "cols: {col} >= {}", self.cols());
406        }
407        let n = self.n();
408        let start: usize = (limb * self.cols() + col) * n * size_of::<i64>();
409        let len: usize = n * size_of::<i64>();
410        ScalarZnx::from_data(B::region_mut(&mut self.data, start, len), n, 1)
411    }
412}
413
414/// Borrow a backend-owned `VecZnx` using the backend's native view type.
415pub trait VecZnxToBackendRef<B: Backend = crate::layouts::HostBytesBackend> {
416    fn to_backend_ref(&self) -> VecZnxBackendRef<'_, B>;
417}
418
419impl<B: Backend> VecZnxToBackendRef<B> for VecZnx<B::OwnedBuf> {
420    fn to_backend_ref(&self) -> VecZnxBackendRef<'_, B> {
421        VecZnx {
422            data: B::view(&self.data),
423            shape: self.shape,
424        }
425    }
426}
427
428impl<'b, B: Backend + 'b> VecZnxToBackendRef<B> for &VecZnx<B::BufRef<'b>> {
429    fn to_backend_ref(&self) -> VecZnxBackendRef<'_, B> {
430        vec_znx_backend_ref_from_ref::<B>(self)
431    }
432}
433
434impl VecZnxToBackendRef<crate::layouts::HostBytesBackend> for VecZnx<&mut [u8]> {
435    fn to_backend_ref(&self) -> VecZnxBackendRef<'_, crate::layouts::HostBytesBackend> {
436        VecZnx {
437            data: self.data,
438            shape: self.shape,
439        }
440    }
441}
442
443impl VecZnxToBackendRef<crate::layouts::HostBytesBackend> for VecZnx<&[u8]> {
444    fn to_backend_ref(&self) -> VecZnxBackendRef<'_, crate::layouts::HostBytesBackend> {
445        VecZnx {
446            data: self.data,
447            shape: self.shape,
448        }
449    }
450}
451
452/// Reborrow an already backend-borrowed `VecZnx` as a shared backend-native view.
453pub trait VecZnxReborrowBackendRef<B: Backend = crate::layouts::HostBytesBackend> {
454    fn reborrow_backend_ref(&self) -> VecZnxBackendRef<'_, B>;
455}
456
457pub fn vec_znx_backend_ref_from_ref<'a, 'b, B: Backend + 'b>(vec: &'a VecZnx<B::BufRef<'b>>) -> VecZnxBackendRef<'a, B> {
458    VecZnx {
459        data: B::view_ref(&vec.data),
460        shape: vec.shape,
461    }
462}
463
464pub fn vec_znx_backend_ref_from_mut<'a, 'b, B: Backend + 'b>(vec: &'a VecZnx<B::BufMut<'b>>) -> VecZnxBackendRef<'a, B> {
465    VecZnx {
466        data: B::view_ref_mut(&vec.data),
467        shape: vec.shape,
468    }
469}
470
471impl<'b, B: Backend + 'b> VecZnxReborrowBackendRef<B> for VecZnx<B::BufMut<'b>> {
472    fn reborrow_backend_ref(&self) -> VecZnxBackendRef<'_, B> {
473        vec_znx_backend_ref_from_mut::<B>(self)
474    }
475}
476
477/// Mutably borrow a backend-owned `VecZnx` using the backend's native view type.
478pub trait VecZnxToBackendMut<B: Backend = crate::layouts::HostBytesBackend> {
479    fn to_backend_mut(&mut self) -> VecZnxBackendMut<'_, B>;
480}
481
482impl<B: Backend> VecZnxToBackendMut<B> for VecZnx<B::OwnedBuf> {
483    fn to_backend_mut(&mut self) -> VecZnxBackendMut<'_, B> {
484        VecZnx {
485            data: B::view_mut(&mut self.data),
486            shape: self.shape,
487        }
488    }
489}
490
491impl<'b, B: Backend + 'b> VecZnxToBackendMut<B> for &mut VecZnx<B::BufMut<'b>> {
492    fn to_backend_mut(&mut self) -> VecZnxBackendMut<'_, B> {
493        vec_znx_backend_mut_from_mut::<B>(self)
494    }
495}
496
497impl VecZnxToBackendMut<crate::layouts::HostBytesBackend> for VecZnx<&mut [u8]> {
498    fn to_backend_mut(&mut self) -> VecZnxBackendMut<'_, crate::layouts::HostBytesBackend> {
499        VecZnx {
500            data: self.data,
501            shape: self.shape,
502        }
503    }
504}
505
506/// Reborrow an already backend-borrowed `VecZnx` as a mutable backend-native view.
507pub trait VecZnxReborrowBackendMut<B: Backend = crate::layouts::HostBytesBackend> {
508    fn reborrow_backend_mut(&mut self) -> VecZnxBackendMut<'_, B>;
509}
510
511pub fn vec_znx_host_backend_ref<D: HostDataRef>(vec: &VecZnx<D>) -> VecZnxBackendRef<'_, crate::layouts::HostBytesBackend> {
512    VecZnx {
513        data: vec.data.as_ref(),
514        shape: vec.shape,
515    }
516}
517
518pub fn vec_znx_host_backend_mut<D: HostDataMut>(vec: &mut VecZnx<D>) -> VecZnxBackendMut<'_, crate::layouts::HostBytesBackend> {
519    VecZnx {
520        data: vec.data.as_mut(),
521        shape: vec.shape,
522    }
523}
524
525pub fn vec_znx_backend_mut_from_mut<'a, 'b, B: Backend + 'b>(vec: &'a mut VecZnx<B::BufMut<'b>>) -> VecZnxBackendMut<'a, B> {
526    VecZnx {
527        data: B::view_mut_ref(&mut vec.data),
528        shape: vec.shape,
529    }
530}
531
532impl<'b, B: Backend + 'b> VecZnxReborrowBackendMut<B> for VecZnx<B::BufMut<'b>> {
533    fn reborrow_backend_mut(&mut self) -> VecZnxBackendMut<'_, B> {
534        vec_znx_backend_mut_from_mut::<B>(self)
535    }
536}
537
538impl<D: HostDataMut> ReaderFrom for VecZnx<D> {
539    fn read_from<R: std::io::Read>(&mut self, reader: &mut R) -> std::io::Result<()> {
540        // Read into temporaries first to avoid leaving self in an inconsistent state on error.
541        let new_n: usize = reader.read_u64::<LittleEndian>()? as usize;
542        let new_cols: usize = reader.read_u64::<LittleEndian>()? as usize;
543        let new_size: usize = reader.read_u64::<LittleEndian>()? as usize;
544        let new_max_size: usize = reader.read_u64::<LittleEndian>()? as usize;
545        let len: usize = reader.read_u64::<LittleEndian>()? as usize;
546
547        // Validate metadata consistency: n * cols * size * sizeof(i64) must match data length.
548        let expected_len: usize = new_n * new_cols * new_size * size_of::<i64>();
549        if expected_len != len {
550            return Err(std::io::Error::new(
551                std::io::ErrorKind::InvalidData,
552                format!(
553                    "VecZnx metadata inconsistent: n={new_n} * cols={new_cols} * size={new_size} * 8 = {expected_len} != data len={len}"
554                ),
555            ));
556        }
557
558        let buf: &mut [u8] = self.data.as_mut();
559        if buf.len() < len {
560            return Err(std::io::Error::new(
561                std::io::ErrorKind::InvalidData,
562                format!("VecZnx buffer too small: self.data.len()={} < read len={len}", buf.len()),
563            ));
564        }
565        reader.read_exact(&mut buf[..len])?;
566
567        // Only commit metadata after successful read.
568        self.shape = VecZnxShape::new(new_n, new_cols, new_size, new_max_size);
569        Ok(())
570    }
571}
572
573impl<D: HostDataRef> WriterTo for VecZnx<D> {
574    fn write_to<W: std::io::Write>(&self, writer: &mut W) -> std::io::Result<()> {
575        writer.write_u64::<LittleEndian>(self.n() as u64)?;
576        writer.write_u64::<LittleEndian>(self.cols() as u64)?;
577        writer.write_u64::<LittleEndian>(self.size() as u64)?;
578        writer.write_u64::<LittleEndian>(self.max_size() as u64)?;
579        let coeff_bytes: usize = self.n() * self.cols() * self.size() * size_of::<i64>();
580        let buf: &[u8] = self.data.as_ref();
581        if buf.len() < coeff_bytes {
582            return Err(std::io::Error::new(
583                std::io::ErrorKind::InvalidData,
584                format!(
585                    "VecZnx buffer too small: self.data.len()={} < coeff_bytes={coeff_bytes}",
586                    buf.len()
587                ),
588            ));
589        }
590        writer.write_u64::<LittleEndian>(coeff_bytes as u64)?;
591        writer.write_all(&buf[..coeff_bytes])?;
592        Ok(())
593    }
594}