1use alloc::vec;
2use alloc::vec::Vec;
3use core::borrow::{Borrow, BorrowMut};
4use core::marker::PhantomData;
5use core::ops::Deref;
6use core::{iter, slice};
7
8use p3_field::{ExtensionField, Field, PackedValue};
9use p3_maybe_rayon::prelude::*;
10use rand::distributions::{Distribution, Standard};
11use rand::Rng;
12use serde::{Deserialize, Serialize};
13
14use crate::Matrix;
15
/// Number of consecutive entries of each output row that `DenseMatrix::transpose` fills as a
/// single parallel work item.
const TRANSPOSE_BLOCK_SIZE: usize = 64;
18
/// A matrix whose entries are stored contiguously in row-major order.
///
/// `T` is the entry type and `V` is the backing storage, which defaults to an owned `Vec<T>`.
/// The entry at row `r`, column `c` lives at index `r * width + c` of `values`.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct DenseMatrix<T, V = Vec<T>> {
    /// Flat buffer holding the entries, one row after another.
    pub values: V,
    /// Number of columns, i.e. the length of each row inside `values`.
    pub width: usize,
    /// Ties the entry type `T` to the struct, since the storage type `V` need not name it.
    _phantom: PhantomData<T>,
}
26
27pub type RowMajorMatrix<T> = DenseMatrix<T, Vec<T>>;
28pub type RowMajorMatrixView<'a, T> = DenseMatrix<T, &'a [T]>;
29pub type RowMajorMatrixViewMut<'a, T> = DenseMatrix<T, &'a mut [T]>;
30
/// Marker trait for buffers that can back a `DenseMatrix`.
///
/// A storage type must be readable as a `[T]` slice, convertible into an owned `Vec<T>`, and
/// shareable across threads.
pub trait DenseStorage<T>: Borrow<[T]> + Into<Vec<T>> + Send + Sync {}

/// Every type meeting the supertrait bounds is automatically a `DenseStorage`.
impl<T, S> DenseStorage<T> for S where S: Borrow<[T]> + Into<Vec<T>> + Send + Sync {}
33
34impl<T: Clone + Send + Sync + Default> DenseMatrix<T> {
35 #[must_use]
38 pub fn default(width: usize, height: usize) -> Self {
39 Self::new(vec![T::default(); width * height], width)
40 }
41}
42
43impl<T: Clone + Send + Sync, S: DenseStorage<T>> DenseMatrix<T, S> {
44 #[must_use]
45 pub fn new(values: S, width: usize) -> Self {
46 debug_assert!(width == 0 || values.borrow().len() % width == 0);
47 Self {
48 values,
49 width,
50 _phantom: PhantomData,
51 }
52 }
53
54 #[must_use]
55 pub fn new_row(values: S) -> Self {
56 let width = values.borrow().len();
57 Self::new(values, width)
58 }
59
60 #[must_use]
61 pub fn new_col(values: S) -> Self {
62 Self::new(values, 1)
63 }
64
65 pub fn as_view(&self) -> RowMajorMatrixView<'_, T> {
66 RowMajorMatrixView::new(self.values.borrow(), self.width)
67 }
68
69 pub fn as_view_mut(&mut self) -> RowMajorMatrixViewMut<'_, T>
70 where
71 S: BorrowMut<[T]>,
72 {
73 RowMajorMatrixViewMut::new(self.values.borrow_mut(), self.width)
74 }
75
76 pub fn flatten_to_base<F: Field>(&self) -> RowMajorMatrix<F>
77 where
78 T: ExtensionField<F>,
79 {
80 let width = self.width * T::D;
81 let values = self
82 .values
83 .borrow()
84 .iter()
85 .flat_map(|x| x.as_base_slice().iter().copied())
86 .collect();
87 RowMajorMatrix::new(values, width)
88 }
89
90 pub fn par_row_slices(&self) -> impl IndexedParallelIterator<Item = &[T]>
91 where
92 T: Sync,
93 {
94 self.values.borrow().par_chunks_exact(self.width)
95 }
96
97 pub fn row_mut(&mut self, r: usize) -> &mut [T]
98 where
99 S: BorrowMut<[T]>,
100 {
101 &mut self.values.borrow_mut()[r * self.width..(r + 1) * self.width]
102 }
103
104 pub fn rows_mut(&mut self) -> impl Iterator<Item = &mut [T]>
105 where
106 S: BorrowMut<[T]>,
107 {
108 self.values.borrow_mut().chunks_exact_mut(self.width)
109 }
110
111 pub fn par_rows_mut<'a>(&'a mut self) -> impl IndexedParallelIterator<Item = &'a mut [T]>
112 where
113 T: 'a + Send,
114 S: BorrowMut<[T]>,
115 {
116 self.values.borrow_mut().par_chunks_exact_mut(self.width)
117 }
118
119 pub fn horizontally_packed_row_mut<P>(&mut self, r: usize) -> (&mut [P], &mut [T])
120 where
121 P: PackedValue<Value = T>,
122 S: BorrowMut<[T]>,
123 {
124 P::pack_slice_with_suffix_mut(self.row_mut(r))
125 }
126
127 pub fn scale_row(&mut self, r: usize, scale: T)
128 where
129 T: Field,
130 S: BorrowMut<[T]>,
131 {
132 let (packed, sfx) = self.horizontally_packed_row_mut::<T::Packing>(r);
133 let packed_scale: T::Packing = scale.into();
134 packed.iter_mut().for_each(|x| *x *= packed_scale);
135 sfx.iter_mut().for_each(|x| *x *= scale);
136 }
137
138 pub fn scale(&mut self, scale: T)
139 where
140 T: Field,
141 S: BorrowMut<[T]>,
142 {
143 let (packed, sfx) = T::Packing::pack_slice_with_suffix_mut(self.values.borrow_mut());
144 let packed_scale: T::Packing = scale.into();
145 packed.iter_mut().for_each(|x| *x *= packed_scale);
146 sfx.iter_mut().for_each(|x| *x *= scale);
147 }
148
149 pub fn split_rows(&self, r: usize) -> (RowMajorMatrixView<T>, RowMajorMatrixView<T>) {
150 let (lo, hi) = self.values.borrow().split_at(r * self.width);
151 (
152 DenseMatrix::new(lo, self.width),
153 DenseMatrix::new(hi, self.width),
154 )
155 }
156
157 pub fn split_rows_mut(
158 &mut self,
159 r: usize,
160 ) -> (RowMajorMatrixViewMut<T>, RowMajorMatrixViewMut<T>)
161 where
162 S: BorrowMut<[T]>,
163 {
164 let (lo, hi) = self.values.borrow_mut().split_at_mut(r * self.width);
165 (
166 DenseMatrix::new(lo, self.width),
167 DenseMatrix::new(hi, self.width),
168 )
169 }
170
171 pub fn par_row_chunks_mut(
172 &mut self,
173 chunk_rows: usize,
174 ) -> impl IndexedParallelIterator<Item = RowMajorMatrixViewMut<T>>
175 where
176 T: Send,
177 S: BorrowMut<[T]>,
178 {
179 self.values
180 .borrow_mut()
181 .par_chunks_mut(self.width * chunk_rows)
182 .map(|slice| RowMajorMatrixViewMut::new(slice, self.width))
183 }
184
185 pub fn par_row_chunks_exact_mut(
186 &mut self,
187 chunk_rows: usize,
188 ) -> impl IndexedParallelIterator<Item = RowMajorMatrixViewMut<T>>
189 where
190 T: Send,
191 S: BorrowMut<[T]>,
192 {
193 self.values
194 .borrow_mut()
195 .par_chunks_exact_mut(self.width * chunk_rows)
196 .map(|slice| RowMajorMatrixViewMut::new(slice, self.width))
197 }
198
199 pub fn row_pair_mut(&mut self, row_1: usize, row_2: usize) -> (&mut [T], &mut [T])
200 where
201 S: BorrowMut<[T]>,
202 {
203 debug_assert_ne!(row_1, row_2);
204 let start_1 = row_1 * self.width;
205 let start_2 = row_2 * self.width;
206 let (lo, hi) = self.values.borrow_mut().split_at_mut(start_2);
207 (&mut lo[start_1..][..self.width], &mut hi[..self.width])
208 }
209
210 #[allow(clippy::type_complexity)]
211 pub fn packed_row_pair_mut<P>(
212 &mut self,
213 row_1: usize,
214 row_2: usize,
215 ) -> ((&mut [P], &mut [T]), (&mut [P], &mut [T]))
216 where
217 S: BorrowMut<[T]>,
218 P: PackedValue<Value = T>,
219 {
220 let (slice_1, slice_2) = self.row_pair_mut(row_1, row_2);
221 (
222 P::pack_slice_with_suffix_mut(slice_1),
223 P::pack_slice_with_suffix_mut(slice_2),
224 )
225 }
226
227 pub fn bit_reversed_zero_pad(self, added_bits: usize) -> RowMajorMatrix<T>
228 where
229 T: Copy + Default + Send + Sync,
230 {
231 if added_bits == 0 {
232 return self.to_row_major_matrix();
233 }
234
235 let w = self.width;
245 let mut padded = RowMajorMatrix::new(
246 vec![T::default(); self.values.borrow().len() << added_bits],
247 w,
248 );
249 padded
250 .par_row_chunks_exact_mut(1 << added_bits)
251 .zip(self.par_row_slices())
252 .for_each(|(mut ch, r)| ch.row_mut(0).copy_from_slice(r));
253
254 padded
255 }
256}
257
258impl<T: Clone + Send + Sync, S: DenseStorage<T>> Matrix<T> for DenseMatrix<T, S> {
259 fn width(&self) -> usize {
260 self.width
261 }
262 fn height(&self) -> usize {
263 if self.width == 0 {
264 0
265 } else {
266 self.values.borrow().len() / self.width
267 }
268 }
269 fn get(&self, r: usize, c: usize) -> T {
270 self.values.borrow()[r * self.width + c].clone()
271 }
272 type Row<'a>
273 = iter::Cloned<slice::Iter<'a, T>>
274 where
275 Self: 'a;
276 fn row(&self, r: usize) -> Self::Row<'_> {
277 self.values.borrow()[r * self.width..(r + 1) * self.width]
278 .iter()
279 .cloned()
280 }
281 fn row_slice(&self, r: usize) -> impl Deref<Target = [T]> {
282 &self.values.borrow()[r * self.width..(r + 1) * self.width]
283 }
284 fn to_row_major_matrix(self) -> RowMajorMatrix<T>
285 where
286 Self: Sized,
287 T: Clone,
288 {
289 RowMajorMatrix::new(self.values.into(), self.width)
290 }
291 fn horizontally_packed_row<'a, P>(
292 &'a self,
293 r: usize,
294 ) -> (impl Iterator<Item = P>, impl Iterator<Item = T>)
295 where
296 P: PackedValue<Value = T>,
297 T: Clone + 'a,
298 {
299 let buf = &self.values.borrow()[r * self.width..(r + 1) * self.width];
300 let (packed, sfx) = P::pack_slice_with_suffix(buf);
301 (packed.iter().cloned(), sfx.iter().cloned())
302 }
303}
304
305impl<T: Clone + Default + Send + Sync> DenseMatrix<T, Vec<T>> {
306 pub fn rand<R: Rng>(rng: &mut R, rows: usize, cols: usize) -> Self
307 where
308 Standard: Distribution<T>,
309 {
310 let values = rng.sample_iter(Standard).take(rows * cols).collect();
311 Self::new(values, cols)
312 }
313
314 pub fn rand_nonzero<R: Rng>(rng: &mut R, rows: usize, cols: usize) -> Self
315 where
316 T: Field,
317 Standard: Distribution<T>,
318 {
319 let values = rng
320 .sample_iter(Standard)
321 .filter(|x| !x.is_zero())
322 .take(rows * cols)
323 .collect();
324 Self::new(values, cols)
325 }
326
327 pub fn transpose(self) -> Self {
328 let block_size = TRANSPOSE_BLOCK_SIZE;
329 let height = self.height();
330 let width = self.width();
331
332 let transposed_values: Vec<T> = vec![T::default(); width * height];
333 let mut transposed = Self::new(transposed_values, height);
334
335 transposed
336 .values
337 .par_chunks_mut(height)
338 .enumerate()
339 .for_each(|(row_ind, row)| {
340 row.par_chunks_mut(block_size)
341 .enumerate()
342 .for_each(|(block_num, row_block)| {
343 let row_block_len = row_block.len();
344 (0..row_block_len).for_each(|col_ind| {
345 let original_mat_row_ind = block_size * block_num + col_ind;
346 let original_mat_col_ind = row_ind;
347 let original_values_index =
348 original_mat_row_ind * width + original_mat_col_ind;
349
350 row_block[col_ind] = self.values[original_values_index].clone();
351 });
352 });
353 });
354
355 transposed
356 }
357}
358
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a matrix with `width` columns and `height` rows whose entries are
    /// `1, 2, ..., width * height` in row-major order.
    fn counting_matrix(width: usize, height: usize) -> RowMajorMatrix<usize> {
        let values: Vec<usize> = (1..=width * height).collect();
        RowMajorMatrix::new(values, width)
    }

    /// Asserts that `transposed` has the swapped dimensions of `original` and that every entry
    /// sits at the mirrored position.
    fn assert_is_transpose(original: &RowMajorMatrix<usize>, transposed: &RowMajorMatrix<usize>) {
        let width = original.width();
        let height = original.height();

        assert_eq!(transposed.width(), height);
        assert_eq!(transposed.height(), width);

        for col_index in 0..width {
            for row_index in 0..height {
                assert_eq!(
                    original.values[row_index * width + col_index],
                    transposed.values[col_index * height + row_index]
                );
            }
        }
    }

    /// A 3 x 3 matrix is mirrored across its diagonal.
    #[test]
    fn test_transpose_square_matrix() {
        let transposed = counting_matrix(3, 3).transpose();
        let expected = RowMajorMatrix::new(vec![1, 4, 7, 2, 5, 8, 3, 6, 9], 3);
        assert_eq!(transposed, expected);
    }

    /// A single column of 30 entries becomes a single row holding the same entries.
    #[test]
    fn test_transpose_row_matrix() {
        let entries: Vec<usize> = (1..=30).collect();
        let transposed = RowMajorMatrix::new(entries.clone(), 1).transpose();
        let expected = RowMajorMatrix::new(entries, 30);
        assert_eq!(transposed, expected);
    }

    /// A matrix with 5 columns and 6 rows becomes one with 6 columns and 5 rows.
    #[test]
    fn test_transpose_rectangular_matrix() {
        let transposed = counting_matrix(5, 6).transpose();
        let expected = RowMajorMatrix::new(
            vec![
                1, 6, 11, 16, 21, 26, 2, 7, 12, 17, 22, 27, 3, 8, 13, 18, 23, 28, 4, 9, 14, 19,
                24, 29, 5, 10, 15, 20, 25, 30,
            ],
            6,
        );
        assert_eq!(transposed, expected);
    }

    /// 256 columns by 512 rows (131072 entries): spans several transpose blocks per output row.
    #[test]
    fn test_transpose_larger_rectangular_matrix() {
        let matrix = counting_matrix(256, 512);
        let transposed = matrix.clone().transpose();
        assert_is_transpose(&matrix, &transposed);
    }

    /// 1024 columns by 1024 rows (1048576 entries).
    #[test]
    fn test_transpose_very_large_rectangular_matrix() {
        let matrix = counting_matrix(1024, 1024);
        let transposed = matrix.clone().transpose();
        assert_is_transpose(&matrix, &transposed);
    }
}