1use crate::{
2 alloc_aligned,
3 layouts::{
4 Backend, Data, DataView, DataViewMut, DigestU64, FillUniform, HostDataMut, HostDataRef, ReaderFrom, ToOwnedDeep, VecZnx,
5 WriterTo, ZnxInfos, ZnxView, ZnxViewMut, ZnxZero,
6 },
7 source::Source,
8};
9use std::{
10 fmt,
11 hash::{DefaultHasher, Hasher},
12};
13
14use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
15use rand::Rng;
16
/// Dimensions of a [`MatZnx`], stored as five plain `usize` values.
///
/// The type is `Copy`, so every accessor takes `self` by value.
#[repr(C)]
#[derive(PartialEq, Eq, Clone, Copy, Hash, Debug, Default)]
pub struct MatZnxShape {
    n: usize,
    size: usize,
    rows: usize,
    cols_in: usize,
    cols_out: usize,
}

impl MatZnxShape {
    /// Builds a shape from its five dimensions.
    ///
    /// The argument order is `n, rows, cols_in, cols_out, size`, which differs from the
    /// field declaration order (where `size` comes second).
    pub const fn new(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> Self {
        MatZnxShape { n, size, rows, cols_in, cols_out }
    }

    /// Returns the `n` dimension.
    pub const fn n(self) -> usize {
        self.n
    }

    /// Returns the `size` dimension.
    pub const fn size(self) -> usize {
        self.size
    }

    /// Returns the number of rows.
    pub const fn rows(self) -> usize {
        self.rows
    }

    /// Returns the number of input columns.
    pub const fn cols_in(self) -> usize {
        self.cols_in
    }

    /// Returns the number of output columns.
    pub const fn cols_out(self) -> usize {
        self.cols_out
    }
}
58
/// Matrix container made of one backing buffer `data` of type `D` plus a [`MatZnxShape`].
///
/// The accessors in this file treat the buffer as `rows * cols_in` consecutive entries,
/// each of which is exposed as a `VecZnx` with `cols_out` columns.
#[repr(C)]
#[derive(PartialEq, Eq, Clone, Hash)]
pub struct MatZnx<D: Data> {
    // Backing storage; its concrete kind (owned, borrowed, backend buffer) is chosen by `D`.
    data: D,
    // Dimensions used to interpret `data`.
    shape: MatZnxShape,
}
74
75impl<D: HostDataRef> DigestU64 for MatZnx<D> {
76 fn digest_u64(&self) -> u64 {
77 let mut h: DefaultHasher = DefaultHasher::new();
78 h.write(self.data.as_ref());
79 h.write_usize(self.n());
80 h.write_usize(self.size());
81 h.write_usize(self.rows());
82 h.write_usize(self.cols_in());
83 h.write_usize(self.cols_out());
84 h.finish()
85 }
86}
87
88impl<D: HostDataRef> ToOwnedDeep for MatZnx<D> {
89 type Owned = MatZnx<Vec<u8>>;
90 fn to_owned_deep(&self) -> Self::Owned {
91 MatZnx {
92 data: self.data.as_ref().to_vec(),
93 shape: self.shape,
94 }
95 }
96}
97
98impl<D: HostDataRef> fmt::Debug for MatZnx<D> {
99 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
100 write!(f, "{self}")
101 }
102}
103
104impl<D: Data> ZnxInfos for MatZnx<D> {
105 fn cols(&self) -> usize {
106 self.shape.cols_in()
107 }
108
109 fn rows(&self) -> usize {
110 self.shape.rows()
111 }
112
113 fn n(&self) -> usize {
114 self.shape.n()
115 }
116
117 fn size(&self) -> usize {
118 self.shape.size()
119 }
120
121 fn poly_count(&self) -> usize {
122 self.rows() * self.cols_in() * self.cols_out() * self.size()
123 }
124}
125
126impl<D: Data> DataView for MatZnx<D> {
127 type D = D;
128 fn data(&self) -> &Self::D {
129 &self.data
130 }
131}
132
133impl<D: Data> DataViewMut for MatZnx<D> {
134 fn data_mut(&mut self) -> &mut Self::D {
135 &mut self.data
136 }
137}
138
// Implements `ZnxView` for host-readable matrices; the only item set here is the scalar type.
impl<D: HostDataRef> ZnxView for MatZnx<D> {
    // Scalar type declared for the `ZnxView` trait: signed 64-bit integers.
    type Scalar = i64;
}
142
143impl<D: Data> MatZnx<D> {
144 pub fn shape(&self) -> MatZnxShape {
145 self.shape
146 }
147
148 pub fn n(&self) -> usize {
149 self.shape.n()
150 }
151
152 pub fn rows(&self) -> usize {
153 self.shape.rows()
154 }
155
156 pub fn size(&self) -> usize {
157 self.shape.size()
158 }
159
160 pub fn cols_in(&self) -> usize {
162 self.shape.cols_in()
163 }
164
165 pub fn cols_out(&self) -> usize {
167 self.shape.cols_out()
168 }
169
170 pub fn into_data(self) -> D {
172 self.data
173 }
174}
175
176impl MatZnx<Vec<u8>> {
177 pub fn bytes_of(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
179 rows * cols_in * VecZnx::<Vec<u8>>::bytes_of(n, cols_out, size)
180 }
181
182 pub(crate) fn alloc(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> Self {
184 let data: Vec<u8> = alloc_aligned(Self::bytes_of(n, rows, cols_in, cols_out, size));
185 Self {
186 data,
187 shape: MatZnxShape::new(n, rows, cols_in, cols_out, size),
188 }
189 }
190
191 pub fn from_bytes(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize, bytes: impl Into<Vec<u8>>) -> Self {
192 let data: Vec<u8> = bytes.into();
193 assert!(data.len() == Self::bytes_of(n, rows, cols_in, cols_out, size));
194 crate::assert_alignment(data.as_ptr());
195 Self {
196 data,
197 shape: MatZnxShape::new(n, rows, cols_in, cols_out, size),
198 }
199 }
200}
201
202impl<D: HostDataRef> MatZnx<D> {
203 pub fn at(&self, row: usize, col: usize) -> VecZnx<&[u8]> {
209 #[cfg(debug_assertions)]
210 {
211 assert!(row < self.rows(), "rows: {} >= {}", row, self.rows());
212 assert!(col < self.cols_in(), "cols: {} >= {}", col, self.cols_in());
213 }
214
215 let self_ref = MatZnx {
216 data: self.data.as_ref(),
217 shape: self.shape,
218 };
219 let nb_bytes: usize = VecZnx::<Vec<u8>>::bytes_of(self.n(), self.cols_out(), self.size());
220 let start: usize = nb_bytes * self.cols() * row + col * nb_bytes;
221 let end: usize = start + nb_bytes;
222
223 VecZnx::from_data(&self_ref.data[start..end], self.n(), self.cols_out(), self.size())
224 }
225}
226
227impl<D: HostDataMut> MatZnx<D> {
228 pub fn at_mut(&mut self, row: usize, col: usize) -> VecZnx<&mut [u8]> {
234 #[cfg(debug_assertions)]
235 {
236 assert!(row < self.rows(), "rows: {} >= {}", row, self.rows());
237 assert!(col < self.cols_in(), "cols: {} >= {}", col, self.cols_in());
238 }
239
240 let n: usize = self.n();
241 let rows: usize = self.rows();
242 let cols_out: usize = self.cols_out();
243 let cols_in: usize = self.cols_in();
244 let size: usize = self.size();
245
246 let self_ref = MatZnx {
247 data: self.data.as_mut(),
248 shape: MatZnxShape::new(n, rows, cols_in, cols_out, size),
249 };
250 let nb_bytes: usize = VecZnx::<Vec<u8>>::bytes_of(n, cols_out, size);
251 let start: usize = nb_bytes * cols_in * row + col * nb_bytes;
252 let end: usize = start + nb_bytes;
253
254 VecZnx::from_data(&mut self_ref.data[start..end], n, cols_out, size)
255 }
256}
257
/// Read-only access to a single matrix entry through backend `B`'s borrowed buffer type.
pub trait MatZnxAtBackendRef<B: Backend> {
    /// Returns the entry at (`row`, `col`) as a `VecZnx` over `B::BufRef`.
    fn at_backend(&self, row: usize, col: usize) -> VecZnx<B::BufRef<'_>>;
}
262
263impl<B: Backend> MatZnxAtBackendRef<B> for MatZnx<B::OwnedBuf> {
264 fn at_backend(&self, row: usize, col: usize) -> VecZnx<B::BufRef<'_>> {
265 #[cfg(debug_assertions)]
266 {
267 assert!(row < self.rows(), "rows: {} >= {}", row, self.rows());
268 assert!(col < self.cols_in(), "cols: {} >= {}", col, self.cols_in());
269 }
270
271 let nb_bytes: usize = VecZnx::<Vec<u8>>::bytes_of(self.n(), self.cols_out(), self.size());
272 let start: usize = nb_bytes * self.cols() * row + col * nb_bytes;
273 let end: usize = start + nb_bytes;
274
275 VecZnx::from_data(
276 B::region(&self.data, start, end - start),
277 self.n(),
278 self.cols_out(),
279 self.size(),
280 )
281 }
282}
283
284pub fn mat_znx_at_backend_ref_from_ref<'a, 'b, B: Backend + 'b>(
285 mat: &'a MatZnx<B::BufRef<'b>>,
286 row: usize,
287 col: usize,
288) -> VecZnx<B::BufRef<'a>> {
289 #[cfg(debug_assertions)]
290 {
291 assert!(row < mat.rows(), "rows: {} >= {}", row, mat.rows());
292 assert!(col < mat.cols_in(), "cols: {} >= {}", col, mat.cols_in());
293 }
294
295 let nb_bytes: usize = VecZnx::<Vec<u8>>::bytes_of(mat.n(), mat.cols_out(), mat.size());
296 let start: usize = nb_bytes * mat.cols() * row + col * nb_bytes;
297 let end: usize = start + nb_bytes;
298
299 VecZnx::from_data(
300 B::region_ref(&mat.data, start, end - start),
301 mat.n(),
302 mat.cols_out(),
303 mat.size(),
304 )
305}
306
307pub fn mat_znx_at_backend_ref_from_mut<'a, 'b, B: Backend + 'b>(
308 mat: &'a MatZnx<B::BufMut<'b>>,
309 row: usize,
310 col: usize,
311) -> VecZnx<B::BufRef<'a>> {
312 #[cfg(debug_assertions)]
313 {
314 assert!(row < mat.rows(), "rows: {} >= {}", row, mat.rows());
315 assert!(col < mat.cols_in(), "cols: {} >= {}", col, mat.cols_in());
316 }
317
318 let nb_bytes: usize = VecZnx::<Vec<u8>>::bytes_of(mat.n(), mat.cols_out(), mat.size());
319 let start: usize = nb_bytes * mat.cols() * row + col * nb_bytes;
320 let end: usize = start + nb_bytes;
321
322 VecZnx::from_data(
323 B::region_ref_mut(&mat.data, start, end - start),
324 mat.n(),
325 mat.cols_out(),
326 mat.size(),
327 )
328}
329
/// Mutable access to a single matrix entry through backend `B`'s mutable buffer type.
pub trait MatZnxAtBackendMut<B: Backend> {
    /// Returns the entry at (`row`, `col`) as a `VecZnx` over `B::BufMut`.
    fn at_backend_mut(&mut self, row: usize, col: usize) -> VecZnx<B::BufMut<'_>>;
}
334
335impl<B: Backend> MatZnxAtBackendMut<B> for MatZnx<B::OwnedBuf> {
336 fn at_backend_mut(&mut self, row: usize, col: usize) -> VecZnx<B::BufMut<'_>> {
337 #[cfg(debug_assertions)]
338 {
339 assert!(row < self.rows(), "rows: {} >= {}", row, self.rows());
340 assert!(col < self.cols_in(), "cols: {} >= {}", col, self.cols_in());
341 }
342
343 let n: usize = self.n();
344 let cols_out: usize = self.cols_out();
345 let cols_in: usize = self.cols_in();
346 let size: usize = self.size();
347 let nb_bytes: usize = VecZnx::<Vec<u8>>::bytes_of(n, cols_out, size);
348 let start: usize = nb_bytes * cols_in * row + col * nb_bytes;
349 let end: usize = start + nb_bytes;
350
351 VecZnx::from_data(B::region_mut(&mut self.data, start, end - start), n, cols_out, size)
352 }
353}
354
355pub fn mat_znx_at_backend_mut_from_mut<'a, 'b, B: Backend + 'b>(
356 mat: &'a mut MatZnx<B::BufMut<'b>>,
357 row: usize,
358 col: usize,
359) -> VecZnx<B::BufMut<'a>> {
360 #[cfg(debug_assertions)]
361 {
362 assert!(row < mat.rows(), "rows: {} >= {}", row, mat.rows());
363 assert!(col < mat.cols_in(), "cols: {} >= {}", col, mat.cols_in());
364 }
365
366 let n: usize = mat.n();
367 let cols_out: usize = mat.cols_out();
368 let cols_in: usize = mat.cols_in();
369 let size: usize = mat.size();
370 let nb_bytes: usize = VecZnx::<Vec<u8>>::bytes_of(n, cols_out, size);
371 let start: usize = nb_bytes * cols_in * row + col * nb_bytes;
372 let end: usize = start + nb_bytes;
373
374 VecZnx::from_data(B::region_mut_ref(&mut mat.data, start, end - start), n, cols_out, size)
375}
376
377impl<D: HostDataMut> FillUniform for MatZnx<D> {
378 fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) {
379 match log_bound {
380 64 => source.fill_bytes(self.data.as_mut()),
381 0 => panic!("invalid log_bound, cannot be zero"),
382 _ => {
383 let mask: u64 = (1u64 << log_bound) - 1;
384 for x in self.raw_mut().iter_mut() {
385 let r = source.next_u64() & mask;
386 *x = ((r << (64 - log_bound)) as i64) >> (64 - log_bound);
387 }
388 }
389 }
390 }
391}
392
/// [`MatZnx`] that owns its bytes in a `Vec<u8>`.
pub type MatZnxOwned = MatZnx<Vec<u8>>;
/// [`MatZnx`] over a mutably borrowed byte slice.
pub type MatZnxMut<'a> = MatZnx<&'a mut [u8]>;
/// [`MatZnx`] over an immutably borrowed byte slice.
pub type MatZnxRef<'a> = MatZnx<&'a [u8]>;
/// [`MatZnx`] over backend `B`'s borrowed buffer type `BufRef<'a>`.
pub type MatZnxBackendRef<'a, B> = MatZnx<<B as Backend>::BufRef<'a>>;
/// [`MatZnx`] over backend `B`'s mutable buffer type `BufMut<'a>`.
pub type MatZnxBackendMut<'a, B> = MatZnx<<B as Backend>::BufMut<'a>>;
403
/// Conversion of a matrix (or a reference to one) into a read-only backend view of the
/// whole matrix.
pub trait MatZnxToBackendRef<B: Backend> {
    /// Returns a `MatZnx` over `B::BufRef` that borrows from `self`.
    fn to_backend_ref(&self) -> MatZnxBackendRef<'_, B>;
}
408
409impl<B: Backend> MatZnxToBackendRef<B> for MatZnx<B::OwnedBuf> {
410 fn to_backend_ref(&self) -> MatZnxBackendRef<'_, B> {
411 MatZnx {
412 data: B::view(&self.data),
413 shape: self.shape,
414 }
415 }
416}
417
// Lets a `&MatZnx` over a borrowed backend buffer be used wherever `MatZnxToBackendRef` is
// expected.
impl<'b, B: Backend + 'b> MatZnxToBackendRef<B> for &MatZnx<B::BufRef<'b>> {
    /// Forwards to [`mat_znx_backend_ref_from_ref`]; `self` (a reference to a reference)
    /// is deref-coerced to the inner `&MatZnx` at the call.
    fn to_backend_ref(&self) -> MatZnxBackendRef<'_, B> {
        mat_znx_backend_ref_from_ref::<B>(self)
    }
}
423
// Lets a `&mut MatZnx` over a mutable backend buffer be used wherever a read-only
// `MatZnxToBackendRef` is expected.
impl<'b, B: Backend + 'b> MatZnxToBackendRef<B> for &mut MatZnx<B::BufMut<'b>> {
    /// Forwards to [`mat_znx_backend_ref_from_mut`]; `self` is deref-coerced to a shared
    /// reference to the inner matrix at the call.
    fn to_backend_ref(&self) -> MatZnxBackendRef<'_, B> {
        mat_znx_backend_ref_from_mut::<B>(self)
    }
}
429
430pub fn mat_znx_backend_ref_from_ref<'a, 'b, B: Backend + 'b>(mat: &'a MatZnx<B::BufRef<'b>>) -> MatZnxBackendRef<'a, B> {
431 MatZnx {
432 data: B::view_ref(&mat.data),
433 shape: mat.shape,
434 }
435}
436
437pub fn mat_znx_backend_ref_from_mut<'a, 'b, B: Backend + 'b>(mat: &'a MatZnx<B::BufMut<'b>>) -> MatZnxBackendRef<'a, B> {
438 MatZnx {
439 data: B::view_ref_mut(&mat.data),
440 shape: mat.shape,
441 }
442}
443
/// Conversion of a matrix (or a mutable reference to one) into a mutable backend view of
/// the whole matrix.
pub trait MatZnxToBackendMut<B: Backend> {
    /// Returns a `MatZnx` over `B::BufMut` that mutably borrows from `self`.
    fn to_backend_mut(&mut self) -> MatZnxBackendMut<'_, B>;
}
448
449impl<B: Backend> MatZnxToBackendMut<B> for MatZnx<B::OwnedBuf> {
450 fn to_backend_mut(&mut self) -> MatZnxBackendMut<'_, B> {
451 MatZnx {
452 data: B::view_mut(&mut self.data),
453 shape: self.shape,
454 }
455 }
456}
457
// Lets a `&mut MatZnx` over a mutable backend buffer be used wherever `MatZnxToBackendMut`
// is expected.
impl<'b, B: Backend + 'b> MatZnxToBackendMut<B> for &mut MatZnx<B::BufMut<'b>> {
    /// Forwards to [`mat_znx_backend_mut_from_mut`]; `self` (a mutable reference to a mutable
    /// reference) is deref-coerced to the inner `&mut MatZnx` at the call.
    fn to_backend_mut(&mut self) -> MatZnxBackendMut<'_, B> {
        mat_znx_backend_mut_from_mut::<B>(self)
    }
}
463
464pub fn mat_znx_backend_mut_from_mut<'a, 'b, B: Backend + 'b>(mat: &'a mut MatZnx<B::BufMut<'b>>) -> MatZnxBackendMut<'a, B> {
465 MatZnx {
466 data: B::view_mut_ref(&mut mat.data),
467 shape: mat.shape,
468 }
469}
470
471impl<D: Data> MatZnx<D> {
472 pub fn from_data(data: D, n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> Self {
473 Self {
474 data,
475 shape: MatZnxShape::new(n, rows, cols_in, cols_out, size),
476 }
477 }
478}
479
480impl<D: HostDataMut> ReaderFrom for MatZnx<D> {
481 fn read_from<R: std::io::Read>(&mut self, reader: &mut R) -> std::io::Result<()> {
482 let new_n: usize = reader.read_u64::<LittleEndian>()? as usize;
483 let new_size: usize = reader.read_u64::<LittleEndian>()? as usize;
484 let new_rows: usize = reader.read_u64::<LittleEndian>()? as usize;
485 let new_cols_in: usize = reader.read_u64::<LittleEndian>()? as usize;
486 let new_cols_out: usize = reader.read_u64::<LittleEndian>()? as usize;
487 let len: usize = reader.read_u64::<LittleEndian>()? as usize;
488
489 let expected_len: usize = new_rows * new_cols_in * new_n * new_cols_out * new_size * size_of::<i64>();
490 if expected_len != len {
491 return Err(std::io::Error::new(
492 std::io::ErrorKind::InvalidData,
493 format!(
494 "MatZnx metadata inconsistent: rows={new_rows} * cols_in={new_cols_in} * n={new_n} * cols_out={new_cols_out} * size={new_size} * 8 = {expected_len} != data len={len}"
495 ),
496 ));
497 }
498
499 let buf: &mut [u8] = self.data.as_mut();
500 if buf.len() < len {
501 return Err(std::io::Error::new(
502 std::io::ErrorKind::InvalidData,
503 format!("MatZnx buffer too small: self.data.len()={} < read len={len}", buf.len()),
504 ));
505 }
506 reader.read_exact(&mut buf[..len])?;
507
508 self.shape = MatZnxShape::new(new_n, new_rows, new_cols_in, new_cols_out, new_size);
509 Ok(())
510 }
511}
512
513impl<D: HostDataRef> WriterTo for MatZnx<D> {
514 fn write_to<W: std::io::Write>(&self, writer: &mut W) -> std::io::Result<()> {
515 writer.write_u64::<LittleEndian>(self.n() as u64)?;
516 writer.write_u64::<LittleEndian>(self.size() as u64)?;
517 writer.write_u64::<LittleEndian>(self.rows() as u64)?;
518 writer.write_u64::<LittleEndian>(self.cols_in() as u64)?;
519 writer.write_u64::<LittleEndian>(self.cols_out() as u64)?;
520 let logical_len: usize = MatZnx::<Vec<u8>>::bytes_of(self.n(), self.rows(), self.cols_in(), self.cols_out(), self.size());
521 let buf: &[u8] = self.data.as_ref();
522 if buf.len() < logical_len {
523 return Err(std::io::Error::new(
524 std::io::ErrorKind::InvalidData,
525 format!(
526 "MatZnx buffer too small: self.data.len()={} < logical_len={logical_len}",
527 buf.len()
528 ),
529 ));
530 }
531 writer.write_u64::<LittleEndian>(logical_len as u64)?;
532 writer.write_all(&buf[..logical_len])?;
533 Ok(())
534 }
535}
536
537impl<D: HostDataRef> fmt::Display for MatZnx<D> {
538 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
539 writeln!(
540 f,
541 "MatZnx(n={}, rows={}, cols_in={}, cols_out={}, size={})",
542 self.n(),
543 self.rows(),
544 self.cols_in(),
545 self.cols_out(),
546 self.size()
547 )?;
548
549 for row_i in 0..self.rows() {
550 writeln!(f, "Row {row_i}:")?;
551 for col_i in 0..self.cols_in() {
552 writeln!(f, "cols_in {col_i}:")?;
553 writeln!(f, "{}:", self.at(row_i, col_i))?;
554 }
555 }
556 Ok(())
557 }
558}
559
560impl<D: HostDataMut> ZnxZero for MatZnx<D> {
561 fn zero(&mut self) {
562 self.raw_mut().fill(0)
563 }
564
565 fn zero_at(&mut self, i: usize, j: usize) {
566 self.at_mut(i, j).zero();
567 }
568}