p3_matrix/
dense.rs

1use alloc::vec;
2use alloc::vec::Vec;
3use core::borrow::{Borrow, BorrowMut};
4use core::marker::PhantomData;
5use core::ops::Deref;
6use core::{iter, slice};
7
8use p3_field::{ExtensionField, Field, PackedValue};
9use p3_maybe_rayon::prelude::*;
10use rand::distributions::{Distribution, Standard};
11use rand::Rng;
12use serde::{Deserialize, Serialize};
13
14use crate::Matrix;
15
/// Number of output entries each parallel task fills per chunk of an output row in
/// [`DenseMatrix::transpose`]. The value was chosen with 32-byte element types in mind.
const TRANSPOSE_BLOCK_SIZE: usize = 64;
18
19/// A dense matrix stored in row-major form.
20#[derive(Copy, Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
21pub struct DenseMatrix<T, V = Vec<T>> {
22    pub values: V,
23    pub width: usize,
24    _phantom: PhantomData<T>,
25}
26
27pub type RowMajorMatrix<T> = DenseMatrix<T, Vec<T>>;
28pub type RowMajorMatrixView<'a, T> = DenseMatrix<T, &'a [T]>;
29pub type RowMajorMatrixViewMut<'a, T> = DenseMatrix<T, &'a mut [T]>;
30
31pub trait DenseStorage<T>: Borrow<[T]> + Into<Vec<T>> + Send + Sync {}
32impl<T, S: Borrow<[T]> + Into<Vec<T>> + Send + Sync> DenseStorage<T> for S {}
33
34impl<T: Clone + Send + Sync + Default> DenseMatrix<T> {
35    /// Create a new dense matrix of the given dimensions, backed by a `Vec`, and filled with
36    /// default values.
37    #[must_use]
38    pub fn default(width: usize, height: usize) -> Self {
39        Self::new(vec![T::default(); width * height], width)
40    }
41}
42
43impl<T: Clone + Send + Sync, S: DenseStorage<T>> DenseMatrix<T, S> {
44    #[must_use]
45    pub fn new(values: S, width: usize) -> Self {
46        debug_assert!(width == 0 || values.borrow().len() % width == 0);
47        Self {
48            values,
49            width,
50            _phantom: PhantomData,
51        }
52    }
53
54    #[must_use]
55    pub fn new_row(values: S) -> Self {
56        let width = values.borrow().len();
57        Self::new(values, width)
58    }
59
60    #[must_use]
61    pub fn new_col(values: S) -> Self {
62        Self::new(values, 1)
63    }
64
65    pub fn as_view(&self) -> RowMajorMatrixView<'_, T> {
66        RowMajorMatrixView::new(self.values.borrow(), self.width)
67    }
68
69    pub fn as_view_mut(&mut self) -> RowMajorMatrixViewMut<'_, T>
70    where
71        S: BorrowMut<[T]>,
72    {
73        RowMajorMatrixViewMut::new(self.values.borrow_mut(), self.width)
74    }
75
76    pub fn flatten_to_base<F: Field>(&self) -> RowMajorMatrix<F>
77    where
78        T: ExtensionField<F>,
79    {
80        let width = self.width * T::D;
81        let values = self
82            .values
83            .borrow()
84            .iter()
85            .flat_map(|x| x.as_base_slice().iter().copied())
86            .collect();
87        RowMajorMatrix::new(values, width)
88    }
89
90    pub fn par_row_slices(&self) -> impl IndexedParallelIterator<Item = &[T]>
91    where
92        T: Sync,
93    {
94        self.values.borrow().par_chunks_exact(self.width)
95    }
96
97    pub fn row_mut(&mut self, r: usize) -> &mut [T]
98    where
99        S: BorrowMut<[T]>,
100    {
101        &mut self.values.borrow_mut()[r * self.width..(r + 1) * self.width]
102    }
103
104    pub fn rows_mut(&mut self) -> impl Iterator<Item = &mut [T]>
105    where
106        S: BorrowMut<[T]>,
107    {
108        self.values.borrow_mut().chunks_exact_mut(self.width)
109    }
110
111    pub fn par_rows_mut<'a>(&'a mut self) -> impl IndexedParallelIterator<Item = &'a mut [T]>
112    where
113        T: 'a + Send,
114        S: BorrowMut<[T]>,
115    {
116        self.values.borrow_mut().par_chunks_exact_mut(self.width)
117    }
118
119    pub fn horizontally_packed_row_mut<P>(&mut self, r: usize) -> (&mut [P], &mut [T])
120    where
121        P: PackedValue<Value = T>,
122        S: BorrowMut<[T]>,
123    {
124        P::pack_slice_with_suffix_mut(self.row_mut(r))
125    }
126
127    pub fn scale_row(&mut self, r: usize, scale: T)
128    where
129        T: Field,
130        S: BorrowMut<[T]>,
131    {
132        let (packed, sfx) = self.horizontally_packed_row_mut::<T::Packing>(r);
133        let packed_scale: T::Packing = scale.into();
134        packed.iter_mut().for_each(|x| *x *= packed_scale);
135        sfx.iter_mut().for_each(|x| *x *= scale);
136    }
137
138    pub fn scale(&mut self, scale: T)
139    where
140        T: Field,
141        S: BorrowMut<[T]>,
142    {
143        let (packed, sfx) = T::Packing::pack_slice_with_suffix_mut(self.values.borrow_mut());
144        let packed_scale: T::Packing = scale.into();
145        packed.iter_mut().for_each(|x| *x *= packed_scale);
146        sfx.iter_mut().for_each(|x| *x *= scale);
147    }
148
149    pub fn split_rows(&self, r: usize) -> (RowMajorMatrixView<T>, RowMajorMatrixView<T>) {
150        let (lo, hi) = self.values.borrow().split_at(r * self.width);
151        (
152            DenseMatrix::new(lo, self.width),
153            DenseMatrix::new(hi, self.width),
154        )
155    }
156
157    pub fn split_rows_mut(
158        &mut self,
159        r: usize,
160    ) -> (RowMajorMatrixViewMut<T>, RowMajorMatrixViewMut<T>)
161    where
162        S: BorrowMut<[T]>,
163    {
164        let (lo, hi) = self.values.borrow_mut().split_at_mut(r * self.width);
165        (
166            DenseMatrix::new(lo, self.width),
167            DenseMatrix::new(hi, self.width),
168        )
169    }
170
171    pub fn par_row_chunks_mut(
172        &mut self,
173        chunk_rows: usize,
174    ) -> impl IndexedParallelIterator<Item = RowMajorMatrixViewMut<T>>
175    where
176        T: Send,
177        S: BorrowMut<[T]>,
178    {
179        self.values
180            .borrow_mut()
181            .par_chunks_mut(self.width * chunk_rows)
182            .map(|slice| RowMajorMatrixViewMut::new(slice, self.width))
183    }
184
185    pub fn par_row_chunks_exact_mut(
186        &mut self,
187        chunk_rows: usize,
188    ) -> impl IndexedParallelIterator<Item = RowMajorMatrixViewMut<T>>
189    where
190        T: Send,
191        S: BorrowMut<[T]>,
192    {
193        self.values
194            .borrow_mut()
195            .par_chunks_exact_mut(self.width * chunk_rows)
196            .map(|slice| RowMajorMatrixViewMut::new(slice, self.width))
197    }
198
199    pub fn row_pair_mut(&mut self, row_1: usize, row_2: usize) -> (&mut [T], &mut [T])
200    where
201        S: BorrowMut<[T]>,
202    {
203        debug_assert_ne!(row_1, row_2);
204        let start_1 = row_1 * self.width;
205        let start_2 = row_2 * self.width;
206        let (lo, hi) = self.values.borrow_mut().split_at_mut(start_2);
207        (&mut lo[start_1..][..self.width], &mut hi[..self.width])
208    }
209
210    #[allow(clippy::type_complexity)]
211    pub fn packed_row_pair_mut<P>(
212        &mut self,
213        row_1: usize,
214        row_2: usize,
215    ) -> ((&mut [P], &mut [T]), (&mut [P], &mut [T]))
216    where
217        S: BorrowMut<[T]>,
218        P: PackedValue<Value = T>,
219    {
220        let (slice_1, slice_2) = self.row_pair_mut(row_1, row_2);
221        (
222            P::pack_slice_with_suffix_mut(slice_1),
223            P::pack_slice_with_suffix_mut(slice_2),
224        )
225    }
226
227    pub fn bit_reversed_zero_pad(self, added_bits: usize) -> RowMajorMatrix<T>
228    where
229        T: Copy + Default + Send + Sync,
230    {
231        if added_bits == 0 {
232            return self.to_row_major_matrix();
233        }
234
235        // This is equivalent to:
236        //     reverse_matrix_index_bits(mat);
237        //     mat
238        //         .values
239        //         .resize(mat.values.len() << added_bits, F::zero());
240        //     reverse_matrix_index_bits(mat);
241        // But rather than implement it with bit reversals, we directly construct the resulting matrix,
242        // whose rows are zero except for rows whose low `added_bits` bits are zero.
243
244        let w = self.width;
245        let mut padded = RowMajorMatrix::new(
246            vec![T::default(); self.values.borrow().len() << added_bits],
247            w,
248        );
249        padded
250            .par_row_chunks_exact_mut(1 << added_bits)
251            .zip(self.par_row_slices())
252            .for_each(|(mut ch, r)| ch.row_mut(0).copy_from_slice(r));
253
254        padded
255    }
256}
257
258impl<T: Clone + Send + Sync, S: DenseStorage<T>> Matrix<T> for DenseMatrix<T, S> {
259    fn width(&self) -> usize {
260        self.width
261    }
262    fn height(&self) -> usize {
263        if self.width == 0 {
264            0
265        } else {
266            self.values.borrow().len() / self.width
267        }
268    }
269    fn get(&self, r: usize, c: usize) -> T {
270        self.values.borrow()[r * self.width + c].clone()
271    }
272    type Row<'a>
273        = iter::Cloned<slice::Iter<'a, T>>
274    where
275        Self: 'a;
276    fn row(&self, r: usize) -> Self::Row<'_> {
277        self.values.borrow()[r * self.width..(r + 1) * self.width]
278            .iter()
279            .cloned()
280    }
281    fn row_slice(&self, r: usize) -> impl Deref<Target = [T]> {
282        &self.values.borrow()[r * self.width..(r + 1) * self.width]
283    }
284    fn to_row_major_matrix(self) -> RowMajorMatrix<T>
285    where
286        Self: Sized,
287        T: Clone,
288    {
289        RowMajorMatrix::new(self.values.into(), self.width)
290    }
291    fn horizontally_packed_row<'a, P>(
292        &'a self,
293        r: usize,
294    ) -> (impl Iterator<Item = P>, impl Iterator<Item = T>)
295    where
296        P: PackedValue<Value = T>,
297        T: Clone + 'a,
298    {
299        let buf = &self.values.borrow()[r * self.width..(r + 1) * self.width];
300        let (packed, sfx) = P::pack_slice_with_suffix(buf);
301        (packed.iter().cloned(), sfx.iter().cloned())
302    }
303}
304
305impl<T: Clone + Default + Send + Sync> DenseMatrix<T, Vec<T>> {
306    pub fn rand<R: Rng>(rng: &mut R, rows: usize, cols: usize) -> Self
307    where
308        Standard: Distribution<T>,
309    {
310        let values = rng.sample_iter(Standard).take(rows * cols).collect();
311        Self::new(values, cols)
312    }
313
314    pub fn rand_nonzero<R: Rng>(rng: &mut R, rows: usize, cols: usize) -> Self
315    where
316        T: Field,
317        Standard: Distribution<T>,
318    {
319        let values = rng
320            .sample_iter(Standard)
321            .filter(|x| !x.is_zero())
322            .take(rows * cols)
323            .collect();
324        Self::new(values, cols)
325    }
326
327    pub fn transpose(self) -> Self {
328        let block_size = TRANSPOSE_BLOCK_SIZE;
329        let height = self.height();
330        let width = self.width();
331
332        let transposed_values: Vec<T> = vec![T::default(); width * height];
333        let mut transposed = Self::new(transposed_values, height);
334
335        transposed
336            .values
337            .par_chunks_mut(height)
338            .enumerate()
339            .for_each(|(row_ind, row)| {
340                row.par_chunks_mut(block_size)
341                    .enumerate()
342                    .for_each(|(block_num, row_block)| {
343                        let row_block_len = row_block.len();
344                        (0..row_block_len).for_each(|col_ind| {
345                            let original_mat_row_ind = block_size * block_num + col_ind;
346                            let original_mat_col_ind = row_ind;
347                            let original_values_index =
348                                original_mat_row_ind * width + original_mat_col_ind;
349
350                            row_block[col_ind] = self.values[original_values_index].clone();
351                        });
352                    });
353            });
354
355        transposed
356    }
357}
358
#[cfg(test)]
mod tests {
    use super::*;

    /// Transposing a 3x3 matrix swaps its rows and columns.
    #[test]
    fn test_transpose_square_matrix() {
        const START_INDEX: usize = 1;
        const VALUE_LEN: usize = 9;
        const WIDTH: usize = 3;
        const HEIGHT: usize = 3;

        let matrix_values = (START_INDEX..=VALUE_LEN).collect::<Vec<_>>();
        let matrix = RowMajorMatrix::new(matrix_values, WIDTH);
        let transposed = matrix.transpose();
        let should_be_transposed_values = vec![1, 4, 7, 2, 5, 8, 3, 6, 9];
        let should_be_transposed = RowMajorMatrix::new(should_be_transposed_values, HEIGHT);
        assert_eq!(transposed, should_be_transposed);
    }

    /// A 30x1 column matrix transposes to a 1x30 row matrix holding the same values.
    #[test]
    fn test_transpose_column_matrix() {
        const START_INDEX: usize = 1;
        const VALUE_LEN: usize = 30;
        const WIDTH: usize = 1;
        const HEIGHT: usize = 30;

        let matrix_values = (START_INDEX..=VALUE_LEN).collect::<Vec<_>>();
        let matrix = RowMajorMatrix::new(matrix_values.clone(), WIDTH);
        let transposed = matrix.transpose();
        let should_be_transposed = RowMajorMatrix::new(matrix_values, HEIGHT);
        assert_eq!(transposed, should_be_transposed);
    }

    /// A 1x30 row matrix transposes to a 30x1 column matrix holding the same values.
    #[test]
    fn test_transpose_row_matrix() {
        const START_INDEX: usize = 1;
        const VALUE_LEN: usize = 30;

        let matrix_values = (START_INDEX..=VALUE_LEN).collect::<Vec<_>>();
        let matrix = RowMajorMatrix::new_row(matrix_values.clone());
        let transposed = matrix.transpose();
        assert_eq!(transposed, RowMajorMatrix::new_col(matrix_values));
    }

    /// Transposing a 6x5 matrix.
    #[test]
    fn test_transpose_rectangular_matrix() {
        const START_INDEX: usize = 1;
        const VALUE_LEN: usize = 30;
        const WIDTH: usize = 5;
        const HEIGHT: usize = 6;

        let matrix_values = (START_INDEX..=VALUE_LEN).collect::<Vec<_>>();
        let matrix = RowMajorMatrix::new(matrix_values, WIDTH);
        let transposed = matrix.transpose();
        let should_be_transposed_values = vec![
            1, 6, 11, 16, 21, 26, 2, 7, 12, 17, 22, 27, 3, 8, 13, 18, 23, 28, 4, 9, 14, 19, 24, 29,
            5, 10, 15, 20, 25, 30,
        ];
        let should_be_transposed = RowMajorMatrix::new(should_be_transposed_values, HEIGHT);
        assert_eq!(transposed, should_be_transposed);
    }

    /// Transposing a 512x256 matrix, large enough to span several transpose blocks per row.
    #[test]
    fn test_transpose_larger_rectangular_matrix() {
        const START_INDEX: usize = 1;
        const VALUE_LEN: usize = 131072; // 512 * 256
        const WIDTH: usize = 256;
        const HEIGHT: usize = 512;

        let matrix_values = (START_INDEX..=VALUE_LEN).collect::<Vec<_>>();
        let matrix = RowMajorMatrix::new(matrix_values, WIDTH);
        let transposed = matrix.clone().transpose();

        assert_eq!(transposed.width(), HEIGHT);
        assert_eq!(transposed.height(), WIDTH);

        for col_index in 0..WIDTH {
            for row_index in 0..HEIGHT {
                assert_eq!(
                    matrix.values[row_index * WIDTH + col_index],
                    transposed.values[col_index * HEIGHT + row_index]
                );
            }
        }
    }

    /// Transposing a 1024x1024 matrix.
    #[test]
    fn test_transpose_very_large_rectangular_matrix() {
        const START_INDEX: usize = 1;
        const VALUE_LEN: usize = 1048576; // 1024 * 1024
        const WIDTH: usize = 1024;
        const HEIGHT: usize = 1024;

        let matrix_values = (START_INDEX..=VALUE_LEN).collect::<Vec<_>>();
        let matrix = RowMajorMatrix::new(matrix_values, WIDTH);
        let transposed = matrix.clone().transpose();

        assert_eq!(transposed.width(), HEIGHT);
        assert_eq!(transposed.height(), WIDTH);

        for col_index in 0..WIDTH {
            for row_index in 0..HEIGHT {
                assert_eq!(
                    matrix.values[row_index * WIDTH + col_index],
                    transposed.values[col_index * HEIGHT + row_index]
                );
            }
        }
    }

    /// `row_pair_mut` hands out the two requested rows, which can then be modified together.
    #[test]
    fn test_row_pair_mut() {
        let mut matrix = RowMajorMatrix::new((1..=6).collect::<Vec<usize>>(), 2);
        let (first, third) = matrix.row_pair_mut(0, 2);
        assert_eq!(first.to_vec(), vec![1, 2]);
        assert_eq!(third.to_vec(), vec![5, 6]);
        first.swap_with_slice(third);
        assert_eq!(matrix.values, vec![5, 6, 3, 4, 1, 2]);
    }

    /// `split_rows` separates the first `r` rows from the rest.
    #[test]
    fn test_split_rows() {
        let matrix = RowMajorMatrix::new((1..=6).collect::<Vec<usize>>(), 2);
        let (top, bottom) = matrix.split_rows(1);
        assert_eq!(top.height(), 1);
        assert_eq!(bottom.height(), 2);
        assert_eq!(top.values.to_vec(), vec![1, 2]);
        assert_eq!(bottom.values.to_vec(), vec![3, 4, 5, 6]);
    }

    /// Padding by one bit interleaves each original row with a default-valued row.
    #[test]
    fn test_bit_reversed_zero_pad() {
        let matrix = RowMajorMatrix::new(vec![1usize, 2, 3, 4], 2);
        let padded = matrix.clone().bit_reversed_zero_pad(1);
        assert_eq!(padded, RowMajorMatrix::new(vec![1, 2, 0, 0, 3, 4, 0, 0], 2));
        assert_eq!(matrix.clone().bit_reversed_zero_pad(0), matrix);
    }
}