p3_matrix/
lib.rs

1//! Matrix library.
2
3#![no_std]
4
5extern crate alloc;
6
7use alloc::vec;
8use alloc::vec::Vec;
9use core::fmt::{Debug, Display, Formatter};
10use core::ops::Deref;
11
12use itertools::{izip, Itertools};
13use p3_field::{ExtensionField, Field, PackedValue};
14use p3_maybe_rayon::prelude::*;
15use serde::{Deserialize, Serialize};
16use strided::{VerticallyStridedMatrixView, VerticallyStridedRowIndexMap};
17
18use crate::dense::RowMajorMatrix;
19
20pub mod bitrev;
21pub mod dense;
22pub mod extension;
23pub mod mul;
24pub mod row_index_mapped;
25pub mod sparse;
26pub mod stack;
27pub mod strided;
28pub mod util;
29
30#[derive(Clone, Copy, Serialize, Deserialize)]
31pub struct Dimensions {
32    pub width: usize,
33    pub height: usize,
34}
35
36impl Debug for Dimensions {
37    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
38        write!(f, "{}x{}", self.width, self.height)
39    }
40}
41
42impl Display for Dimensions {
43    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
44        write!(f, "{}x{}", self.width, self.height)
45    }
46}
47
48pub trait Matrix<T: Send + Sync>: Send + Sync {
49    fn width(&self) -> usize;
50    fn height(&self) -> usize;
51
52    fn dimensions(&self) -> Dimensions {
53        Dimensions {
54            width: self.width(),
55            height: self.height(),
56        }
57    }
58
59    fn get(&self, r: usize, c: usize) -> T {
60        self.row(r).nth(c).unwrap()
61    }
62
63    type Row<'a>: Iterator<Item = T> + Send + Sync
64    where
65        Self: 'a;
66
67    fn row(&self, r: usize) -> Self::Row<'_>;
68
69    fn rows(&self) -> impl Iterator<Item = Self::Row<'_>> {
70        (0..self.height()).map(move |r| self.row(r))
71    }
72
73    fn par_rows(&self) -> impl IndexedParallelIterator<Item = Self::Row<'_>> {
74        (0..self.height()).into_par_iter().map(move |r| self.row(r))
75    }
76
77    // Opaque return type implicitly captures &'_ self
78    fn row_slice(&self, r: usize) -> impl Deref<Target = [T]> {
79        self.row(r).collect_vec()
80    }
81
82    fn first_row(&self) -> Self::Row<'_> {
83        self.row(0)
84    }
85
86    fn last_row(&self) -> Self::Row<'_> {
87        self.row(self.height() - 1)
88    }
89
90    fn to_row_major_matrix(self) -> RowMajorMatrix<T>
91    where
92        Self: Sized,
93        T: Clone,
94    {
95        RowMajorMatrix::new(
96            (0..self.height()).flat_map(|r| self.row(r)).collect(),
97            self.width(),
98        )
99    }
100
101    fn horizontally_packed_row<'a, P>(
102        &'a self,
103        r: usize,
104    ) -> (impl Iterator<Item = P>, impl Iterator<Item = T>)
105    where
106        P: PackedValue<Value = T>,
107        T: Clone + 'a,
108    {
109        let num_packed = self.width() / P::WIDTH;
110        let packed = (0..num_packed).map(move |c| P::from_fn(|i| self.get(r, P::WIDTH * c + i)));
111        let sfx = (num_packed * P::WIDTH..self.width()).map(move |c| self.get(r, c));
112        (packed, sfx)
113    }
114
115    /// Wraps at the end.
116    fn vertically_packed_row<P>(&self, r: usize) -> impl Iterator<Item = P>
117    where
118        P: PackedValue<Value = T>,
119    {
120        (0..self.width()).map(move |c| P::from_fn(|i| self.get((r + i) % self.height(), c)))
121    }
122
123    fn vertically_strided(self, stride: usize, offset: usize) -> VerticallyStridedMatrixView<Self>
124    where
125        Self: Sized,
126    {
127        VerticallyStridedRowIndexMap::new_view(self, stride, offset)
128    }
129
130    /// Compute Mᵀv, aka premultiply this matrix by the given vector,
131    /// aka scale each row by the corresponding entry in `v` and take the row-wise sum.
132    /// `v` can be a vector of extension elements.
133    fn columnwise_dot_product<EF>(&self, v: &[EF]) -> Vec<EF>
134    where
135        T: Field,
136        EF: ExtensionField<T>,
137    {
138        self.par_rows().zip(v).par_fold_reduce(
139            || vec![EF::zero(); self.width()],
140            |mut acc, (row, &scale)| {
141                izip!(&mut acc, row).for_each(|(a, x)| *a += scale * x);
142                acc
143            },
144            |mut acc_l, acc_r| {
145                izip!(&mut acc_l, acc_r).for_each(|(l, r)| *l += r);
146                acc_l
147            },
148        )
149    }
150}