p3_matrix/
lib.rs

1//! Matrix library.
2
3#![no_std]
4
5extern crate alloc;
6
7use alloc::vec;
8use alloc::vec::Vec;
9use core::fmt::{Debug, Display, Formatter};
10use core::ops::Deref;
11
12use itertools::{izip, Itertools};
13use p3_field::{dot_product, AbstractExtensionField, ExtensionField, Field, PackedValue};
14use p3_maybe_rayon::prelude::*;
15use strided::{VerticallyStridedMatrixView, VerticallyStridedRowIndexMap};
16
17use crate::dense::RowMajorMatrix;
18
19pub mod bitrev;
20pub mod dense;
21pub mod extension;
22pub mod mul;
23pub mod row_index_mapped;
24pub mod sparse;
25pub mod stack;
26pub mod strided;
27pub mod util;
28
/// The size of a matrix, expressed as a column count and a row count.
///
/// Two values compare equal exactly when both their widths and heights match.
#[derive(Clone, Copy, PartialEq, Eq)]
pub struct Dimensions {
    /// Number of columns.
    pub width: usize,
    /// Number of rows.
    pub height: usize,
}
34
35impl Debug for Dimensions {
36    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
37        write!(f, "{}x{}", self.width, self.height)
38    }
39}
40
41impl Display for Dimensions {
42    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
43        write!(f, "{}x{}", self.width, self.height)
44    }
45}
46
47pub trait Matrix<T: Send + Sync>: Send + Sync {
48    fn width(&self) -> usize;
49    fn height(&self) -> usize;
50
51    fn dimensions(&self) -> Dimensions {
52        Dimensions {
53            width: self.width(),
54            height: self.height(),
55        }
56    }
57
58    fn get(&self, r: usize, c: usize) -> T {
59        self.row(r).nth(c).unwrap()
60    }
61
62    type Row<'a>: Iterator<Item = T> + Send + Sync
63    where
64        Self: 'a;
65
66    fn row(&self, r: usize) -> Self::Row<'_>;
67
68    fn rows(&self) -> impl Iterator<Item = Self::Row<'_>> {
69        (0..self.height()).map(move |r| self.row(r))
70    }
71
72    fn par_rows(&self) -> impl IndexedParallelIterator<Item = Self::Row<'_>> {
73        (0..self.height()).into_par_iter().map(move |r| self.row(r))
74    }
75
76    // Opaque return type implicitly captures &'_ self
77    fn row_slice(&self, r: usize) -> impl Deref<Target = [T]> {
78        self.row(r).collect_vec()
79    }
80
81    fn first_row(&self) -> Self::Row<'_> {
82        self.row(0)
83    }
84
85    fn last_row(&self) -> Self::Row<'_> {
86        self.row(self.height() - 1)
87    }
88
89    fn to_row_major_matrix(self) -> RowMajorMatrix<T>
90    where
91        Self: Sized,
92        T: Clone,
93    {
94        RowMajorMatrix::new(
95            (0..self.height()).flat_map(|r| self.row(r)).collect(),
96            self.width(),
97        )
98    }
99
100    fn horizontally_packed_row<'a, P>(
101        &'a self,
102        r: usize,
103    ) -> (
104        impl Iterator<Item = P> + Send + Sync,
105        impl Iterator<Item = T> + Send + Sync,
106    )
107    where
108        P: PackedValue<Value = T>,
109        T: Clone + 'a,
110    {
111        let num_packed = self.width() / P::WIDTH;
112        let packed = (0..num_packed).map(move |c| P::from_fn(|i| self.get(r, P::WIDTH * c + i)));
113        let sfx = (num_packed * P::WIDTH..self.width()).map(move |c| self.get(r, c));
114        (packed, sfx)
115    }
116
117    /// Zero padded.
118    fn padded_horizontally_packed_row<'a, P>(
119        &'a self,
120        r: usize,
121    ) -> impl Iterator<Item = P> + Send + Sync
122    where
123        P: PackedValue<Value = T>,
124        T: Clone + Default + 'a,
125    {
126        let mut row_iter = self.row(r);
127        let num_elems = self.width().next_multiple_of(P::WIDTH);
128        // array::from_fn currently always calls in order, but it's not clear whether that's guaranteed.
129        (0..num_elems).map(move |_| P::from_fn(|_| row_iter.next().unwrap_or_default()))
130    }
131
132    fn par_horizontally_packed_rows<'a, P>(
133        &'a self,
134    ) -> impl IndexedParallelIterator<
135        Item = (
136            impl Iterator<Item = P> + Send + Sync,
137            impl Iterator<Item = T> + Send + Sync,
138        ),
139    >
140    where
141        P: PackedValue<Value = T>,
142        T: Clone + 'a,
143    {
144        (0..self.height())
145            .into_par_iter()
146            .map(|r| self.horizontally_packed_row(r))
147    }
148
149    fn par_padded_horizontally_packed_rows<'a, P>(
150        &'a self,
151    ) -> impl IndexedParallelIterator<Item = impl Iterator<Item = P> + Send + Sync>
152    where
153        P: PackedValue<Value = T>,
154        T: Clone + Default + 'a,
155    {
156        (0..self.height())
157            .into_par_iter()
158            .map(|r| self.padded_horizontally_packed_row(r))
159    }
160
161    /// Wraps at the end.
162    fn vertically_packed_row<P>(&self, r: usize) -> impl Iterator<Item = P>
163    where
164        P: PackedValue<Value = T>,
165    {
166        (0..self.width()).map(move |c| P::from_fn(|i| self.get((r + i) % self.height(), c)))
167    }
168
169    fn vertically_strided(self, stride: usize, offset: usize) -> VerticallyStridedMatrixView<Self>
170    where
171        Self: Sized,
172    {
173        VerticallyStridedRowIndexMap::new_view(self, stride, offset)
174    }
175
176    /// Compute Mᵀv, aka premultiply this matrix by the given vector,
177    /// aka scale each row by the corresponding entry in `v` and take the row-wise sum.
178    /// `v` can be a vector of extension elements.
179    fn columnwise_dot_product<EF>(&self, v: &[EF]) -> Vec<EF>
180    where
181        T: Field,
182        EF: ExtensionField<T>,
183    {
184        self.par_rows().zip(v).par_fold_reduce(
185            || vec![EF::zero(); self.width()],
186            |mut acc, (row, &scale)| {
187                izip!(&mut acc, row).for_each(|(a, x)| *a += scale * x);
188                acc
189            },
190            |mut acc_l, acc_r| {
191                izip!(&mut acc_l, acc_r).for_each(|(l, r)| *l += r);
192                acc_l
193            },
194        )
195    }
196
197    /// Multiply this matrix by the vector of powers of `base`, which is an extension element.
198    fn dot_ext_powers<EF>(&self, base: EF) -> impl IndexedParallelIterator<Item = EF>
199    where
200        T: Field,
201        EF: ExtensionField<T>,
202    {
203        let powers_packed = base
204            .ext_powers_packed()
205            .take(self.width().next_multiple_of(T::Packing::WIDTH))
206            .collect_vec();
207        self.par_padded_horizontally_packed_rows::<T::Packing>()
208            .map(move |row_packed| {
209                let packed_sum_of_packed: EF::ExtensionPacking =
210                    dot_product(powers_packed.iter().copied(), row_packed);
211                let sum_of_packed: EF = EF::from_base_fn(|i| {
212                    packed_sum_of_packed.as_base_slice()[i]
213                        .as_slice()
214                        .iter()
215                        .copied()
216                        .sum()
217                });
218                sum_of_packed
219            })
220    }
221}