1use std::{
2 hash::{DefaultHasher, Hasher},
3 marker::PhantomData,
4};
5
6use rand_distr::num_traits::Zero;
7use std::fmt;
8
9use crate::layouts::{
10 Backend, Data, DataView, DataViewMut, DigestU64, HostDataMut, HostDataRef, VecZnxShape, ZnxInfos, ZnxView, ZnxViewMut,
11 ZnxZero,
12};
13
/// Backend-tagged buffer viewed as a `VecZnx`-shaped array of `B::ScalarBig` values.
///
/// - `data`: the underlying storage. The type aliases in this module
///   instantiate `D` with the backend's `OwnedBuf`, `BufRef<'a>` or
///   `BufMut<'a>` buffer types.
/// - `shape`: the `n`, `cols`, active `size` and `max_size` dimensions.
/// - `_phantom`: binds the value to backend `B` without storing a `B`.
///
/// Equality and hashing are derived, so they compare/hash the whole `data`
/// value together with the shape.
// NOTE(review): `repr(C)` pins the field order/layout; presumably relied upon
// by backend code — confirm before reordering or adding fields.
#[repr(C)]
#[derive(PartialEq, Eq, Hash)]
pub struct VecZnxBig<D: Data, B: Backend> {
    /// Raw backing storage.
    pub data: D,
    /// Dimensions of the vector; private, read from outside via `shape()`.
    shape: VecZnxShape,
    /// Zero-sized marker tying this value to backend `B`.
    pub _phantom: PhantomData<B>,
}
30
31impl<D: HostDataRef, B: Backend> DigestU64 for VecZnxBig<D, B> {
32 fn digest_u64(&self) -> u64 {
33 let mut h: DefaultHasher = DefaultHasher::new();
34 h.write(self.data.as_ref());
35 h.write_usize(self.n());
36 h.write_usize(self.cols());
37 h.write_usize(self.size());
38 h.write_usize(self.max_size());
39 h.finish()
40 }
41}
42
/// Host-readable views of a `VecZnxBig` interpret its storage as the
/// backend's `ScalarBig` coefficient type.
impl<D: HostDataRef, B: Backend> ZnxView for VecZnxBig<D, B> {
    type Scalar = B::ScalarBig;
}
46
47impl<D: Data, B: Backend> ZnxInfos for VecZnxBig<D, B> {
48 fn cols(&self) -> usize {
49 self.shape.cols()
50 }
51
52 fn rows(&self) -> usize {
53 1
54 }
55
56 fn n(&self) -> usize {
57 self.shape.n()
58 }
59
60 fn size(&self) -> usize {
61 self.shape.size()
62 }
63}
64
65impl<D: Data, B: Backend> DataView for VecZnxBig<D, B> {
66 type D = D;
67 fn data(&self) -> &Self::D {
68 &self.data
69 }
70}
71
72impl<D: Data, B: Backend> DataViewMut for VecZnxBig<D, B> {
73 fn data_mut(&mut self) -> &mut Self::D {
74 &mut self.data
75 }
76}
77
78impl<D: Data, B: Backend> VecZnxBig<D, B> {
79 pub fn n(&self) -> usize {
80 self.shape.n()
81 }
82
83 pub fn cols(&self) -> usize {
84 self.shape.cols()
85 }
86
87 pub fn size(&self) -> usize {
88 self.shape.size()
89 }
90
91 pub fn shape(&self) -> VecZnxShape {
92 self.shape
93 }
94
95 pub fn max_size(&self) -> usize {
96 self.shape.max_size()
97 }
98
99 pub fn with_size(mut self, size: usize) -> Self {
100 assert!(size <= self.max_size());
101 self.shape = self.shape.with_size(size);
102 self
103 }
104
105 pub fn set_size(&mut self, size: usize) {
106 self.shape = self.shape.with_size(size);
107 }
108}
109
110impl<D: HostDataMut, B: Backend> ZnxZero for VecZnxBig<D, B>
111where
112 Self: ZnxViewMut,
113 <Self as ZnxView>::Scalar: Zero + Copy,
114{
115 fn zero(&mut self) {
116 self.raw_mut().fill(<Self as ZnxView>::Scalar::zero())
117 }
118 fn zero_at(&mut self, i: usize, j: usize) {
119 self.at_mut(i, j).fill(<Self as ZnxView>::Scalar::zero());
120 }
121}
122
123impl<B: Backend> VecZnxBig<<B as Backend>::OwnedBuf, B> {
124 pub(crate) fn alloc(n: usize, cols: usize, size: usize) -> Self {
125 let data: <B as Backend>::OwnedBuf = B::alloc_zeroed_bytes(B::bytes_of_vec_znx_big(n, cols, size));
126 Self {
127 data,
128 shape: VecZnxShape::new(n, cols, size, size),
129 _phantom: PhantomData,
130 }
131 }
132
133 pub fn from_bytes(n: usize, cols: usize, size: usize, bytes: impl Into<Vec<u8>>) -> Self {
134 let data: Vec<u8> = bytes.into();
135 assert!(data.len() == B::bytes_of_vec_znx_big(n, cols, size));
136 let data: <B as Backend>::OwnedBuf = B::from_host_bytes(&data);
137 Self {
138 data,
139 shape: VecZnxShape::new(n, cols, size, size),
140 _phantom: PhantomData,
141 }
142 }
143}
144
145impl<D: Data, B: Backend> VecZnxBig<D, B> {
146 pub fn from_data(data: D, n: usize, cols: usize, size: usize) -> Self {
147 Self {
148 data,
149 shape: VecZnxShape::new(n, cols, size, size),
150 _phantom: PhantomData,
151 }
152 }
153
154 pub fn from_data_with_max_size(data: D, n: usize, cols: usize, size: usize, max_size: usize) -> Self {
155 Self {
156 data,
157 shape: VecZnxShape::new(n, cols, size, max_size),
158 _phantom: PhantomData,
159 }
160 }
161}
162
/// `VecZnxBig` backed by the backend's `OwnedBuf` buffer type.
pub type VecZnxBigOwned<B> = VecZnxBig<<B as Backend>::OwnedBuf, B>;
/// `VecZnxBig` backed by the backend's `BufRef<'a>` buffer type.
pub type VecZnxBigBackendRef<'a, B> = VecZnxBig<<B as Backend>::BufRef<'a>, B>;
/// `VecZnxBig` backed by the backend's `BufMut<'a>` buffer type.
pub type VecZnxBigBackendMut<'a, B> = VecZnxBig<<B as Backend>::BufMut<'a>, B>;
169
170pub fn vec_znx_big_backend_ref_from_mut<'a, 'b, B: Backend + 'b>(
172 vec: &'a VecZnxBigBackendMut<'b, B>,
173) -> VecZnxBigBackendRef<'a, B> {
174 VecZnxBig {
175 data: B::view_ref_mut(&vec.data),
176 shape: vec.shape,
177 _phantom: PhantomData,
178 }
179}
180
/// Conversion into a `VecZnxBigBackendRef` that borrows from `self`.
pub trait VecZnxBigToBackendRef<B: Backend> {
    /// Returns a `VecZnxBigBackendRef` whose lifetime is tied to `&self`.
    fn to_backend_ref(&self) -> VecZnxBigBackendRef<'_, B>;
}
185
186impl<B: Backend> VecZnxBigToBackendRef<B> for VecZnxBig<B::OwnedBuf, B> {
187 fn to_backend_ref(&self) -> VecZnxBigBackendRef<'_, B> {
188 VecZnxBig {
189 data: B::view(&self.data),
190 shape: self.shape,
191 _phantom: std::marker::PhantomData,
192 }
193 }
194}
195
196impl<'b, B: Backend + 'b> VecZnxBigToBackendRef<B> for &VecZnxBig<B::BufRef<'b>, B> {
197 fn to_backend_ref(&self) -> VecZnxBigBackendRef<'_, B> {
198 VecZnxBig {
199 data: B::view_ref(&self.data),
200 shape: self.shape,
201 _phantom: std::marker::PhantomData,
202 }
203 }
204}
205
/// Reborrowing of a value as a `VecZnxBigBackendRef` tied to `&self`.
pub trait VecZnxBigReborrowBackendRef<B: Backend> {
    /// Returns a `VecZnxBigBackendRef` whose lifetime is tied to `&self`.
    fn reborrow_backend_ref(&self) -> VecZnxBigBackendRef<'_, B>;
}
210
211impl<'b, B: Backend + 'b> VecZnxBigReborrowBackendRef<B> for VecZnxBig<B::BufMut<'b>, B> {
212 fn reborrow_backend_ref(&self) -> VecZnxBigBackendRef<'_, B> {
213 vec_znx_big_backend_ref_from_mut::<B>(self)
214 }
215}
216
/// Conversion into a `VecZnxBigBackendMut` that mutably borrows from `self`.
pub trait VecZnxBigToBackendMut<B: Backend> {
    /// Returns a `VecZnxBigBackendMut` whose lifetime is tied to `&mut self`.
    fn to_backend_mut(&mut self) -> VecZnxBigBackendMut<'_, B>;
}
221
222impl<B: Backend> VecZnxBigToBackendMut<B> for VecZnxBig<B::OwnedBuf, B> {
223 fn to_backend_mut(&mut self) -> VecZnxBigBackendMut<'_, B> {
224 VecZnxBig {
225 data: B::view_mut(&mut self.data),
226 shape: self.shape,
227 _phantom: std::marker::PhantomData,
228 }
229 }
230}
231
232impl<'b, B: Backend + 'b> VecZnxBigToBackendMut<B> for &mut VecZnxBig<B::BufMut<'b>, B> {
233 fn to_backend_mut(&mut self) -> VecZnxBigBackendMut<'_, B> {
234 VecZnxBig {
235 data: B::view_mut_ref(&mut self.data),
236 shape: self.shape,
237 _phantom: std::marker::PhantomData,
238 }
239 }
240}
241
/// Reborrowing of a value as a `VecZnxBigBackendMut` tied to `&mut self`.
pub trait VecZnxBigReborrowBackendMut<B: Backend> {
    /// Returns a `VecZnxBigBackendMut` whose lifetime is tied to `&mut self`.
    fn reborrow_backend_mut(&mut self) -> VecZnxBigBackendMut<'_, B>;
}
246
247impl<'b, B: Backend + 'b> VecZnxBigReborrowBackendMut<B> for VecZnxBig<B::BufMut<'b>, B> {
248 fn reborrow_backend_mut(&mut self) -> VecZnxBigBackendMut<'_, B> {
249 VecZnxBig {
250 data: B::view_mut_ref(&mut self.data),
251 shape: self.shape,
252 _phantom: std::marker::PhantomData,
253 }
254 }
255}
256
257impl<D: HostDataRef, B: Backend> fmt::Display for VecZnxBig<D, B> {
258 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
259 writeln!(f, "VecZnxBig(n={}, cols={}, size={})", self.n(), self.cols(), self.size())?;
260
261 for col in 0..self.cols() {
262 writeln!(f, "Column {col}:")?;
263 for size in 0..self.size() {
264 let coeffs = self.at(col, size);
265 write!(f, " Size {size}: [")?;
266
267 let max_show = 100;
268 let show_count = coeffs.len().min(max_show);
269
270 for (i, &coeff) in coeffs.iter().take(show_count).enumerate() {
271 if i > 0 {
272 write!(f, ", ")?;
273 }
274 write!(f, "{coeff}")?;
275 }
276
277 if coeffs.len() > max_show {
278 write!(f, ", ... ({} more)", coeffs.len() - max_show)?;
279 }
280
281 writeln!(f, "]")?;
282 }
283 }
284 Ok(())
285 }
286}