1use std::{
2 fmt,
3 hash::{DefaultHasher, Hasher},
4};
5
6use crate::{
7 alloc_aligned,
8 layouts::{
9 Backend, Data, DataView, DataViewMut, DigestU64, FillUniform, HostDataMut, HostDataRef, ReaderFrom, ScalarZnx,
10 ToOwnedDeep, WriterTo, ZnxInfos, ZnxView, ZnxViewMut, ZnxZero,
11 },
12 source::Source,
13};
14
15use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
16use rand::Rng;
17
/// Geometry descriptor of a [`VecZnx`].
///
/// * `n` — number of `i64` coefficients in each (column, limb) polynomial,
/// * `cols` — number of columns,
/// * `size` — number of active limbs,
/// * `max_size` — largest value `size` may be set to through [`VecZnxShape::with_size`].
#[repr(C)]
#[derive(PartialEq, Eq, Clone, Copy, Hash, Debug, Default)]
pub struct VecZnxShape {
    n: usize,
    cols: usize,
    size: usize,
    max_size: usize,
}

impl VecZnxShape {
    /// Builds a shape from its four raw components (no validation is performed).
    pub const fn new(n: usize, cols: usize, size: usize, max_size: usize) -> Self {
        VecZnxShape {
            n,
            cols,
            size,
            max_size,
        }
    }

    /// Coefficients per polynomial.
    pub const fn n(self) -> usize {
        let VecZnxShape { n, .. } = self;
        n
    }

    /// Number of columns.
    pub const fn cols(self) -> usize {
        let VecZnxShape { cols, .. } = self;
        cols
    }

    /// Number of active limbs.
    pub const fn size(self) -> usize {
        let VecZnxShape { size, .. } = self;
        size
    }

    /// Upper bound accepted by [`VecZnxShape::with_size`].
    pub const fn max_size(self) -> usize {
        let VecZnxShape { max_size, .. } = self;
        max_size
    }

    /// Returns a copy of this shape with `size` active limbs.
    ///
    /// # Panics
    /// If `size > self.max_size`.
    pub const fn with_size(self, size: usize) -> Self {
        assert!(size <= self.max_size);
        let mut resized = self;
        resized.size = size;
        resized
    }
}
68
/// A vector of `cols` columns, each holding `size` limbs, where every limb is a
/// polynomial of `n` `i64` coefficients stored in the byte buffer `data`.
///
/// The (col, limb) polynomial starts at byte offset `(limb * cols + col) * n * 8`,
/// i.e. limbs are outermost and columns are interleaved inside each limb.
#[repr(C)]
#[derive(PartialEq, Eq, Clone, Copy, Hash)]
pub struct VecZnx<D: Data> {
    /// Backing storage (host bytes or a backend-specific buffer).
    pub data: D,
    /// Logical geometry: `n`, `cols`, active `size` and `max_size`.
    shape: VecZnxShape,
}
75
76impl<D: HostDataRef> VecZnx<D> {
77 pub fn as_scalar_znx_ref(&self, col: usize, limb: usize) -> ScalarZnx<&[u8]> {
79 ScalarZnx::from_data(bytemuck::cast_slice(self.at(col, limb)), self.n(), 1)
80 }
81}
82
83impl<D: HostDataMut> VecZnx<D> {
84 pub fn as_scalar_znx_mut(&mut self, col: usize, limb: usize) -> ScalarZnx<&mut [u8]> {
86 let n = self.n();
87 ScalarZnx::from_data(bytemuck::cast_slice_mut(self.at_mut(col, limb)), n, 1)
88 }
89}
90
91impl<D: Data + Default> Default for VecZnx<D> {
92 fn default() -> Self {
93 Self {
94 data: D::default(),
95 shape: VecZnxShape::default(),
96 }
97 }
98}
99
100impl<D: HostDataRef> DigestU64 for VecZnx<D> {
101 fn digest_u64(&self) -> u64 {
102 let mut h: DefaultHasher = DefaultHasher::new();
103 h.write(self.data.as_ref());
104 h.write_usize(self.n());
105 h.write_usize(self.cols());
106 h.write_usize(self.size());
107 h.write_usize(self.max_size());
108 h.finish()
109 }
110}
111
112impl<D: HostDataRef> ToOwnedDeep for VecZnx<D> {
113 type Owned = VecZnx<Vec<u8>>;
114 fn to_owned_deep(&self) -> Self::Owned {
115 VecZnx {
116 data: self.data.as_ref().to_vec(),
117 shape: self.shape,
118 }
119 }
120}
121
122impl<D: Data> VecZnx<D> {
123 pub fn to_host_owned<BE>(&self) -> VecZnx<Vec<u8>>
125 where
126 BE: Backend<OwnedBuf = D>,
127 {
128 let shape = self.shape();
129 VecZnx::from_data_with_max_size(
130 crate::layouts::HostBytesBackend::from_bytes(BE::to_host_bytes(&self.data)),
131 shape.n(),
132 shape.cols(),
133 shape.size(),
134 shape.max_size(),
135 )
136 }
137
138 pub fn display_host<BE>(&self) -> String
140 where
141 BE: Backend<OwnedBuf = D>,
142 {
143 self.to_host_owned::<BE>().to_string()
144 }
145}
146
147impl<D: HostDataRef> fmt::Debug for VecZnx<D> {
148 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
149 write!(f, "{self}")
150 }
151}
152
153impl<D: Data> ZnxInfos for VecZnx<D> {
154 fn cols(&self) -> usize {
155 self.shape.cols()
156 }
157
158 fn rows(&self) -> usize {
159 1
160 }
161
162 fn n(&self) -> usize {
163 self.shape.n()
164 }
165
166 fn size(&self) -> usize {
167 self.shape.size()
168 }
169}
170
171impl<D: Data> DataView for VecZnx<D> {
172 type D = D;
173 fn data(&self) -> &Self::D {
174 &self.data
175 }
176}
177
178impl<D: Data> DataViewMut for VecZnx<D> {
179 fn data_mut(&mut self) -> &mut Self::D {
180 &mut self.data
181 }
182}
183
184impl<D: HostDataRef> ZnxView for VecZnx<D> {
185 type Scalar = i64;
186}
187
188impl<D: Data> VecZnx<D> {
189 pub fn n(&self) -> usize {
190 self.shape.n()
191 }
192
193 pub fn cols(&self) -> usize {
194 self.shape.cols()
195 }
196
197 pub fn size(&self) -> usize {
198 self.shape.size()
199 }
200
201 pub fn shape(&self) -> VecZnxShape {
202 self.shape
203 }
204
205 pub fn with_size(mut self, size: usize) -> Self {
206 assert!(size <= self.max_size());
207 self.shape = self.shape.with_size(size);
208 self
209 }
210
211 pub fn max_size(&self) -> usize {
213 self.shape.max_size()
214 }
215}
216
217impl<D: Data> VecZnx<D> {
218 pub fn set_size(&mut self, size: usize) {
224 self.shape = self.shape.with_size(size);
225 }
226}
227
228impl VecZnx<Vec<u8>> {
229 pub fn rsh_tmp_bytes(n: usize) -> usize {
231 n * size_of::<i64>()
232 }
233
234 pub fn reallocate_limbs(&mut self, new_size: usize) {
236 if self.size() == new_size {
237 return;
238 }
239
240 let mut compact: Self = Self::alloc(self.n(), self.cols(), new_size);
241 let copy_len = compact.raw().len().min(self.raw().len());
242 compact.raw_mut()[..copy_len].copy_from_slice(&self.raw()[..copy_len]);
243 *self = compact;
244 }
245}
246
247impl<D: HostDataMut> ZnxZero for VecZnx<D> {
248 fn zero(&mut self) {
249 self.raw_mut().fill(0)
250 }
251 fn zero_at(&mut self, i: usize, j: usize) {
252 self.at_mut(i, j).fill(0);
253 }
254}
255
256impl VecZnx<Vec<u8>> {
257 pub fn bytes_of(n: usize, cols: usize, size: usize) -> usize {
259 n * cols * size * size_of::<i64>()
260 }
261
262 pub(crate) fn alloc(n: usize, cols: usize, size: usize) -> Self {
265 let data: Vec<u8> = alloc_aligned::<u8>(Self::bytes_of(n, cols, size));
266 Self {
267 data,
268 shape: VecZnxShape::new(n, cols, size, size),
269 }
270 }
271
272 pub fn from_bytes(n: usize, cols: usize, size: usize, bytes: impl Into<Vec<u8>>) -> Self {
279 let data: Vec<u8> = bytes.into();
280 assert!(
281 data.len() == Self::bytes_of(n, cols, size),
282 "from_bytes: data.len()={} != bytes_of({}, {}, {})={}",
283 data.len(),
284 n,
285 cols,
286 size,
287 Self::bytes_of(n, cols, size)
288 );
289 crate::assert_alignment(data.as_ptr());
290 Self {
291 data,
292 shape: VecZnxShape::new(n, cols, size, size),
293 }
294 }
295}
296
297impl<D: Data> VecZnx<D> {
298 pub fn from_data(data: D, n: usize, cols: usize, size: usize) -> Self {
301 Self {
302 data,
303 shape: VecZnxShape::new(n, cols, size, size),
304 }
305 }
306
307 pub fn from_data_with_max_size(data: D, n: usize, cols: usize, size: usize, max_size: usize) -> Self {
312 Self {
313 data,
314 shape: VecZnxShape::new(n, cols, size, max_size),
315 }
316 }
317}
318
319impl<D: HostDataRef> fmt::Display for VecZnx<D> {
320 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
321 writeln!(f, "VecZnx(n={}, cols={}, size={})", self.n(), self.cols(), self.size())?;
322
323 for col in 0..self.cols() {
324 writeln!(f, "Column {col}:")?;
325 for size in 0..self.size() {
326 let coeffs = self.at(col, size);
327 write!(f, " Size {size}: [")?;
328
329 let max_show = 16;
330 let show_count = coeffs.len().min(max_show);
331
332 for (i, &coeff) in coeffs.iter().take(show_count).enumerate() {
333 if i > 0 {
334 write!(f, ", ")?;
335 }
336 write!(f, "{coeff}")?;
337 }
338
339 if coeffs.len() > max_show {
340 write!(f, ", ... ({} more)", coeffs.len() - max_show)?;
341 }
342
343 writeln!(f, "]")?;
344 }
345 }
346 Ok(())
347 }
348}
349
350impl<D: HostDataMut> FillUniform for VecZnx<D> {
351 fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) {
352 match log_bound {
353 64 => source.fill_bytes(self.data.as_mut()),
354 0 => panic!("invalid log_bound, cannot be zero"),
355 _ => {
356 let mask: u64 = (1u64 << log_bound) - 1;
357 for x in self.raw_mut().iter_mut() {
358 let r = source.next_u64() & mask;
359 *x = ((r << (64 - log_bound)) as i64) >> (64 - log_bound);
360 }
361 }
362 }
363 }
364}
365
/// Host-owned [`VecZnx`] backed by a `Vec<u8>`.
pub type VecZnxOwned = VecZnx<Vec<u8>>;
/// Host [`VecZnx`] that mutably borrows its bytes.
pub type VecZnxMut<'a> = VecZnx<&'a mut [u8]>;
/// Host [`VecZnx`] that immutably borrows its bytes.
pub type VecZnxRef<'a> = VecZnx<&'a [u8]>;
/// [`VecZnx`] over backend `B`'s shared buffer view (`B::BufRef`).
pub type VecZnxBackendRef<'a, B> = VecZnx<<B as Backend>::BufRef<'a>>;
/// [`VecZnx`] over backend `B`'s mutable buffer view (`B::BufMut`).
pub type VecZnxBackendMut<'a, B> = VecZnx<<B as Backend>::BufMut<'a>>;
376
377pub trait VecZnxAsScalarBackendRef<B: Backend> {
379 fn as_scalar_znx_backend_ref(&self, col: usize, limb: usize) -> ScalarZnx<B::BufRef<'_>>;
380}
381
382impl<B: Backend> VecZnxAsScalarBackendRef<B> for VecZnx<B::OwnedBuf> {
383 fn as_scalar_znx_backend_ref(&self, col: usize, limb: usize) -> ScalarZnx<B::BufRef<'_>> {
384 #[cfg(debug_assertions)]
385 {
386 assert!(limb < self.size(), "size: {limb} >= {}", self.size());
387 assert!(col < self.cols(), "cols: {col} >= {}", self.cols());
388 }
389 let start: usize = (limb * self.cols() + col) * self.n() * size_of::<i64>();
390 let len: usize = self.n() * size_of::<i64>();
391 ScalarZnx::from_data(B::region(&self.data, start, len), self.n(), 1)
392 }
393}
394
395pub trait VecZnxAsScalarBackendMut<B: Backend> {
397 fn as_scalar_znx_backend_mut(&mut self, col: usize, limb: usize) -> ScalarZnx<B::BufMut<'_>>;
398}
399
400impl<B: Backend> VecZnxAsScalarBackendMut<B> for VecZnx<B::OwnedBuf> {
401 fn as_scalar_znx_backend_mut(&mut self, col: usize, limb: usize) -> ScalarZnx<B::BufMut<'_>> {
402 #[cfg(debug_assertions)]
403 {
404 assert!(limb < self.size(), "size: {limb} >= {}", self.size());
405 assert!(col < self.cols(), "cols: {col} >= {}", self.cols());
406 }
407 let n = self.n();
408 let start: usize = (limb * self.cols() + col) * n * size_of::<i64>();
409 let len: usize = n * size_of::<i64>();
410 ScalarZnx::from_data(B::region_mut(&mut self.data, start, len), n, 1)
411 }
412}
413
414pub trait VecZnxToBackendRef<B: Backend = crate::layouts::HostBytesBackend> {
416 fn to_backend_ref(&self) -> VecZnxBackendRef<'_, B>;
417}
418
419impl<B: Backend> VecZnxToBackendRef<B> for VecZnx<B::OwnedBuf> {
420 fn to_backend_ref(&self) -> VecZnxBackendRef<'_, B> {
421 VecZnx {
422 data: B::view(&self.data),
423 shape: self.shape,
424 }
425 }
426}
427
428impl<'b, B: Backend + 'b> VecZnxToBackendRef<B> for &VecZnx<B::BufRef<'b>> {
429 fn to_backend_ref(&self) -> VecZnxBackendRef<'_, B> {
430 vec_znx_backend_ref_from_ref::<B>(self)
431 }
432}
433
434impl VecZnxToBackendRef<crate::layouts::HostBytesBackend> for VecZnx<&mut [u8]> {
435 fn to_backend_ref(&self) -> VecZnxBackendRef<'_, crate::layouts::HostBytesBackend> {
436 VecZnx {
437 data: self.data,
438 shape: self.shape,
439 }
440 }
441}
442
443impl VecZnxToBackendRef<crate::layouts::HostBytesBackend> for VecZnx<&[u8]> {
444 fn to_backend_ref(&self) -> VecZnxBackendRef<'_, crate::layouts::HostBytesBackend> {
445 VecZnx {
446 data: self.data,
447 shape: self.shape,
448 }
449 }
450}
451
452pub trait VecZnxReborrowBackendRef<B: Backend = crate::layouts::HostBytesBackend> {
454 fn reborrow_backend_ref(&self) -> VecZnxBackendRef<'_, B>;
455}
456
457pub fn vec_znx_backend_ref_from_ref<'a, 'b, B: Backend + 'b>(vec: &'a VecZnx<B::BufRef<'b>>) -> VecZnxBackendRef<'a, B> {
458 VecZnx {
459 data: B::view_ref(&vec.data),
460 shape: vec.shape,
461 }
462}
463
464pub fn vec_znx_backend_ref_from_mut<'a, 'b, B: Backend + 'b>(vec: &'a VecZnx<B::BufMut<'b>>) -> VecZnxBackendRef<'a, B> {
465 VecZnx {
466 data: B::view_ref_mut(&vec.data),
467 shape: vec.shape,
468 }
469}
470
471impl<'b, B: Backend + 'b> VecZnxReborrowBackendRef<B> for VecZnx<B::BufMut<'b>> {
472 fn reborrow_backend_ref(&self) -> VecZnxBackendRef<'_, B> {
473 vec_znx_backend_ref_from_mut::<B>(self)
474 }
475}
476
477pub trait VecZnxToBackendMut<B: Backend = crate::layouts::HostBytesBackend> {
479 fn to_backend_mut(&mut self) -> VecZnxBackendMut<'_, B>;
480}
481
482impl<B: Backend> VecZnxToBackendMut<B> for VecZnx<B::OwnedBuf> {
483 fn to_backend_mut(&mut self) -> VecZnxBackendMut<'_, B> {
484 VecZnx {
485 data: B::view_mut(&mut self.data),
486 shape: self.shape,
487 }
488 }
489}
490
491impl<'b, B: Backend + 'b> VecZnxToBackendMut<B> for &mut VecZnx<B::BufMut<'b>> {
492 fn to_backend_mut(&mut self) -> VecZnxBackendMut<'_, B> {
493 vec_znx_backend_mut_from_mut::<B>(self)
494 }
495}
496
497impl VecZnxToBackendMut<crate::layouts::HostBytesBackend> for VecZnx<&mut [u8]> {
498 fn to_backend_mut(&mut self) -> VecZnxBackendMut<'_, crate::layouts::HostBytesBackend> {
499 VecZnx {
500 data: self.data,
501 shape: self.shape,
502 }
503 }
504}
505
506pub trait VecZnxReborrowBackendMut<B: Backend = crate::layouts::HostBytesBackend> {
508 fn reborrow_backend_mut(&mut self) -> VecZnxBackendMut<'_, B>;
509}
510
511pub fn vec_znx_host_backend_ref<D: HostDataRef>(vec: &VecZnx<D>) -> VecZnxBackendRef<'_, crate::layouts::HostBytesBackend> {
512 VecZnx {
513 data: vec.data.as_ref(),
514 shape: vec.shape,
515 }
516}
517
518pub fn vec_znx_host_backend_mut<D: HostDataMut>(vec: &mut VecZnx<D>) -> VecZnxBackendMut<'_, crate::layouts::HostBytesBackend> {
519 VecZnx {
520 data: vec.data.as_mut(),
521 shape: vec.shape,
522 }
523}
524
525pub fn vec_znx_backend_mut_from_mut<'a, 'b, B: Backend + 'b>(vec: &'a mut VecZnx<B::BufMut<'b>>) -> VecZnxBackendMut<'a, B> {
526 VecZnx {
527 data: B::view_mut_ref(&mut vec.data),
528 shape: vec.shape,
529 }
530}
531
532impl<'b, B: Backend + 'b> VecZnxReborrowBackendMut<B> for VecZnx<B::BufMut<'b>> {
533 fn reborrow_backend_mut(&mut self) -> VecZnxBackendMut<'_, B> {
534 vec_znx_backend_mut_from_mut::<B>(self)
535 }
536}
537
538impl<D: HostDataMut> ReaderFrom for VecZnx<D> {
539 fn read_from<R: std::io::Read>(&mut self, reader: &mut R) -> std::io::Result<()> {
540 let new_n: usize = reader.read_u64::<LittleEndian>()? as usize;
542 let new_cols: usize = reader.read_u64::<LittleEndian>()? as usize;
543 let new_size: usize = reader.read_u64::<LittleEndian>()? as usize;
544 let new_max_size: usize = reader.read_u64::<LittleEndian>()? as usize;
545 let len: usize = reader.read_u64::<LittleEndian>()? as usize;
546
547 let expected_len: usize = new_n * new_cols * new_size * size_of::<i64>();
549 if expected_len != len {
550 return Err(std::io::Error::new(
551 std::io::ErrorKind::InvalidData,
552 format!(
553 "VecZnx metadata inconsistent: n={new_n} * cols={new_cols} * size={new_size} * 8 = {expected_len} != data len={len}"
554 ),
555 ));
556 }
557
558 let buf: &mut [u8] = self.data.as_mut();
559 if buf.len() < len {
560 return Err(std::io::Error::new(
561 std::io::ErrorKind::InvalidData,
562 format!("VecZnx buffer too small: self.data.len()={} < read len={len}", buf.len()),
563 ));
564 }
565 reader.read_exact(&mut buf[..len])?;
566
567 self.shape = VecZnxShape::new(new_n, new_cols, new_size, new_max_size);
569 Ok(())
570 }
571}
572
573impl<D: HostDataRef> WriterTo for VecZnx<D> {
574 fn write_to<W: std::io::Write>(&self, writer: &mut W) -> std::io::Result<()> {
575 writer.write_u64::<LittleEndian>(self.n() as u64)?;
576 writer.write_u64::<LittleEndian>(self.cols() as u64)?;
577 writer.write_u64::<LittleEndian>(self.size() as u64)?;
578 writer.write_u64::<LittleEndian>(self.max_size() as u64)?;
579 let coeff_bytes: usize = self.n() * self.cols() * self.size() * size_of::<i64>();
580 let buf: &[u8] = self.data.as_ref();
581 if buf.len() < coeff_bytes {
582 return Err(std::io::Error::new(
583 std::io::ErrorKind::InvalidData,
584 format!(
585 "VecZnx buffer too small: self.data.len()={} < coeff_bytes={coeff_bytes}",
586 buf.len()
587 ),
588 ));
589 }
590 writer.write_u64::<LittleEndian>(coeff_bytes as u64)?;
591 writer.write_all(&buf[..coeff_bytes])?;
592 Ok(())
593 }
594}