// poulpy_hal/layouts/vec_znx_big.rs

1use std::{
2    hash::{DefaultHasher, Hasher},
3    marker::PhantomData,
4};
5
6use rand_distr::num_traits::Zero;
7use std::fmt;
8
9use crate::layouts::{
10    Backend, Data, DataView, DataViewMut, DigestU64, HostDataMut, HostDataRef, VecZnxShape, ZnxInfos, ZnxView, ZnxViewMut,
11    ZnxZero,
12};
13
/// Extended-precision polynomial vector used as a result accumulator.
///
/// `VecZnxBig` has the same structural shape as [`VecZnx`](crate::layouts::VecZnx)
/// (`cols` columns, `size` limbs, ring degree `N`) but uses
/// [`Backend::ScalarBig`] as its coefficient type instead of `i64`.
/// The wider scalar type allows lossless accumulation of intermediate
/// products before normalization back to `i64` limbs.
///
/// The active limb count `size` may be smaller than `max_size`, the limb
/// capacity recorded when the value was built; see [`VecZnxBig::with_size`].
///
/// The exact scalar width and memory layout are backend-specific.
#[repr(C)]
#[derive(PartialEq, Eq, Hash)]
pub struct VecZnxBig<D: Data, B: Backend> {
    /// Backing storage: a backend-owned buffer or a backend-native borrowed view.
    pub data: D,
    /// Ring degree, column count, active limb count and limb capacity.
    shape: VecZnxShape,
    /// Ties the layout to backend `B` without storing a value of it.
    pub _phantom: PhantomData<B>,
}

31impl<D: HostDataRef, B: Backend> DigestU64 for VecZnxBig<D, B> {
32    fn digest_u64(&self) -> u64 {
33        let mut h: DefaultHasher = DefaultHasher::new();
34        h.write(self.data.as_ref());
35        h.write_usize(self.n());
36        h.write_usize(self.cols());
37        h.write_usize(self.size());
38        h.write_usize(self.max_size());
39        h.finish()
40    }
41}
42
43impl<D: HostDataRef, B: Backend> ZnxView for VecZnxBig<D, B> {
44    type Scalar = B::ScalarBig;
45}
46
47impl<D: Data, B: Backend> ZnxInfos for VecZnxBig<D, B> {
48    fn cols(&self) -> usize {
49        self.shape.cols()
50    }
51
52    fn rows(&self) -> usize {
53        1
54    }
55
56    fn n(&self) -> usize {
57        self.shape.n()
58    }
59
60    fn size(&self) -> usize {
61        self.shape.size()
62    }
63}
64
65impl<D: Data, B: Backend> DataView for VecZnxBig<D, B> {
66    type D = D;
67    fn data(&self) -> &Self::D {
68        &self.data
69    }
70}
71
72impl<D: Data, B: Backend> DataViewMut for VecZnxBig<D, B> {
73    fn data_mut(&mut self) -> &mut Self::D {
74        &mut self.data
75    }
76}
77
78impl<D: Data, B: Backend> VecZnxBig<D, B> {
79    pub fn n(&self) -> usize {
80        self.shape.n()
81    }
82
83    pub fn cols(&self) -> usize {
84        self.shape.cols()
85    }
86
87    pub fn size(&self) -> usize {
88        self.shape.size()
89    }
90
91    pub fn shape(&self) -> VecZnxShape {
92        self.shape
93    }
94
95    pub fn max_size(&self) -> usize {
96        self.shape.max_size()
97    }
98
99    pub fn with_size(mut self, size: usize) -> Self {
100        assert!(size <= self.max_size());
101        self.shape = self.shape.with_size(size);
102        self
103    }
104
105    pub fn set_size(&mut self, size: usize) {
106        self.shape = self.shape.with_size(size);
107    }
108}
109
110impl<D: HostDataMut, B: Backend> ZnxZero for VecZnxBig<D, B>
111where
112    Self: ZnxViewMut,
113    <Self as ZnxView>::Scalar: Zero + Copy,
114{
115    fn zero(&mut self) {
116        self.raw_mut().fill(<Self as ZnxView>::Scalar::zero())
117    }
118    fn zero_at(&mut self, i: usize, j: usize) {
119        self.at_mut(i, j).fill(<Self as ZnxView>::Scalar::zero());
120    }
121}
122
123impl<B: Backend> VecZnxBig<<B as Backend>::OwnedBuf, B> {
124    pub(crate) fn alloc(n: usize, cols: usize, size: usize) -> Self {
125        let data: <B as Backend>::OwnedBuf = B::alloc_zeroed_bytes(B::bytes_of_vec_znx_big(n, cols, size));
126        Self {
127            data,
128            shape: VecZnxShape::new(n, cols, size, size),
129            _phantom: PhantomData,
130        }
131    }
132
133    pub fn from_bytes(n: usize, cols: usize, size: usize, bytes: impl Into<Vec<u8>>) -> Self {
134        let data: Vec<u8> = bytes.into();
135        assert!(data.len() == B::bytes_of_vec_znx_big(n, cols, size));
136        let data: <B as Backend>::OwnedBuf = B::from_host_bytes(&data);
137        Self {
138            data,
139            shape: VecZnxShape::new(n, cols, size, size),
140            _phantom: PhantomData,
141        }
142    }
143}
144
145impl<D: Data, B: Backend> VecZnxBig<D, B> {
146    pub fn from_data(data: D, n: usize, cols: usize, size: usize) -> Self {
147        Self {
148            data,
149            shape: VecZnxShape::new(n, cols, size, size),
150            _phantom: PhantomData,
151        }
152    }
153
154    pub fn from_data_with_max_size(data: D, n: usize, cols: usize, size: usize, max_size: usize) -> Self {
155        Self {
156            data,
157            shape: VecZnxShape::new(n, cols, size, max_size),
158            _phantom: PhantomData,
159        }
160    }
161}
162
163/// Owned `VecZnxBig` backed by a backend-owned buffer.
164pub type VecZnxBigOwned<B> = VecZnxBig<<B as Backend>::OwnedBuf, B>;
165/// Shared backend-native borrow of a `VecZnxBig`.
166pub type VecZnxBigBackendRef<'a, B> = VecZnxBig<<B as Backend>::BufRef<'a>, B>;
167/// Mutable backend-native borrow of a `VecZnxBig`.
168pub type VecZnxBigBackendMut<'a, B> = VecZnxBig<<B as Backend>::BufMut<'a>, B>;
169
170/// Reborrow a mutable backend-native `VecZnxBig` view as a shared backend-native view.
171pub fn vec_znx_big_backend_ref_from_mut<'a, 'b, B: Backend + 'b>(
172    vec: &'a VecZnxBigBackendMut<'b, B>,
173) -> VecZnxBigBackendRef<'a, B> {
174    VecZnxBig {
175        data: B::view_ref_mut(&vec.data),
176        shape: vec.shape,
177        _phantom: PhantomData,
178    }
179}
180
181/// Borrow a backend-owned `VecZnxBig` using the backend's native view type.
182pub trait VecZnxBigToBackendRef<B: Backend> {
183    fn to_backend_ref(&self) -> VecZnxBigBackendRef<'_, B>;
184}
185
186impl<B: Backend> VecZnxBigToBackendRef<B> for VecZnxBig<B::OwnedBuf, B> {
187    fn to_backend_ref(&self) -> VecZnxBigBackendRef<'_, B> {
188        VecZnxBig {
189            data: B::view(&self.data),
190            shape: self.shape,
191            _phantom: std::marker::PhantomData,
192        }
193    }
194}
195
196impl<'b, B: Backend + 'b> VecZnxBigToBackendRef<B> for &VecZnxBig<B::BufRef<'b>, B> {
197    fn to_backend_ref(&self) -> VecZnxBigBackendRef<'_, B> {
198        VecZnxBig {
199            data: B::view_ref(&self.data),
200            shape: self.shape,
201            _phantom: std::marker::PhantomData,
202        }
203    }
204}
205
206/// Reborrow an already backend-borrowed `VecZnxBig` as a shared backend-native view.
207pub trait VecZnxBigReborrowBackendRef<B: Backend> {
208    fn reborrow_backend_ref(&self) -> VecZnxBigBackendRef<'_, B>;
209}
210
211impl<'b, B: Backend + 'b> VecZnxBigReborrowBackendRef<B> for VecZnxBig<B::BufMut<'b>, B> {
212    fn reborrow_backend_ref(&self) -> VecZnxBigBackendRef<'_, B> {
213        vec_znx_big_backend_ref_from_mut::<B>(self)
214    }
215}
216
217/// Mutably borrow a backend-owned `VecZnxBig` using the backend's native view type.
218pub trait VecZnxBigToBackendMut<B: Backend> {
219    fn to_backend_mut(&mut self) -> VecZnxBigBackendMut<'_, B>;
220}
221
222impl<B: Backend> VecZnxBigToBackendMut<B> for VecZnxBig<B::OwnedBuf, B> {
223    fn to_backend_mut(&mut self) -> VecZnxBigBackendMut<'_, B> {
224        VecZnxBig {
225            data: B::view_mut(&mut self.data),
226            shape: self.shape,
227            _phantom: std::marker::PhantomData,
228        }
229    }
230}
231
232impl<'b, B: Backend + 'b> VecZnxBigToBackendMut<B> for &mut VecZnxBig<B::BufMut<'b>, B> {
233    fn to_backend_mut(&mut self) -> VecZnxBigBackendMut<'_, B> {
234        VecZnxBig {
235            data: B::view_mut_ref(&mut self.data),
236            shape: self.shape,
237            _phantom: std::marker::PhantomData,
238        }
239    }
240}
241
242/// Reborrow an already backend-borrowed `VecZnxBig` as a mutable backend-native view.
243pub trait VecZnxBigReborrowBackendMut<B: Backend> {
244    fn reborrow_backend_mut(&mut self) -> VecZnxBigBackendMut<'_, B>;
245}
246
247impl<'b, B: Backend + 'b> VecZnxBigReborrowBackendMut<B> for VecZnxBig<B::BufMut<'b>, B> {
248    fn reborrow_backend_mut(&mut self) -> VecZnxBigBackendMut<'_, B> {
249        VecZnxBig {
250            data: B::view_mut_ref(&mut self.data),
251            shape: self.shape,
252            _phantom: std::marker::PhantomData,
253        }
254    }
255}
256
257impl<D: HostDataRef, B: Backend> fmt::Display for VecZnxBig<D, B> {
258    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
259        writeln!(f, "VecZnxBig(n={}, cols={}, size={})", self.n(), self.cols(), self.size())?;
260
261        for col in 0..self.cols() {
262            writeln!(f, "Column {col}:")?;
263            for size in 0..self.size() {
264                let coeffs = self.at(col, size);
265                write!(f, "  Size {size}: [")?;
266
267                let max_show = 100;
268                let show_count = coeffs.len().min(max_show);
269
270                for (i, &coeff) in coeffs.iter().take(show_count).enumerate() {
271                    if i > 0 {
272                        write!(f, ", ")?;
273                    }
274                    write!(f, "{coeff}")?;
275                }
276
277                if coeffs.len() > max_show {
278                    write!(f, ", ... ({} more)", coeffs.len() - max_show)?;
279                }
280
281                writeln!(f, "]")?;
282            }
283        }
284        Ok(())
285    }
286}