use ff::Field;
use num_traits::Zero;
use std::ops::{Add, AddAssign, Index, IndexMut, Mul, Sub, SubAssign};
/// A dense matrix whose entries are stored contiguously in row-major order.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Matrix<T>
where
    T: Copy,
{
    /// Entries in row-major order; expected to hold exactly `nrows * ncols` items.
    data: Vec<T>,
    /// Number of rows.
    pub nrows: usize,
    /// Number of columns.
    pub ncols: usize,
}
impl<T: Copy> Matrix<T> {
    /// Creates a matrix of shape `size = (nrows, ncols)` with every entry set to `item`.
    pub fn new(size: (usize, usize), item: T) -> Self {
        let (nrows, ncols) = size;
        Matrix {
            data: vec![item; nrows * ncols],
            nrows,
            ncols,
        }
    }

    /// Creates a matrix of shape `size = (nrows, ncols)` from entries given in row-major order.
    ///
    /// # Panics
    /// Panics if `iterator` does not yield exactly `nrows * ncols` items.
    pub fn new_from_iter<U: Iterator<Item = T>>(size: (usize, usize), iterator: U) -> Self {
        let data: Vec<T> = iterator.collect();
        assert_eq!(
            data.len(),
            size.0 * size.1,
            "iterator of size {} for matrix of size {}x{}",
            data.len(),
            size.0,
            size.1
        );
        Matrix {
            data,
            nrows: size.0,
            ncols: size.1,
        }
    }

    /// Creates a matrix of shape `size = (nrows, ncols)` from entries given in column-major
    /// order, converting them to the internal row-major layout.
    ///
    /// # Panics
    /// Panics if `iterator` does not yield exactly `nrows * ncols` items.
    pub fn new_from_column_major_iter<U: Iterator<Item = T>>(
        size: (usize, usize),
        iterator: U,
    ) -> Self {
        let (nrows, ncols) = size;
        let column_major_data = iterator.collect::<Vec<T>>();
        assert_eq!(
            column_major_data.len(),
            nrows * ncols,
            "iterator of size {} for matrix of size {}x{}",
            column_major_data.len(),
            nrows,
            ncols
        );
        // Entry (i, j) sits at `i + j * nrows` in column-major storage.
        let src = &column_major_data;
        let data: Vec<T> = (0..nrows)
            .flat_map(|i| (0..ncols).map(move |j| src[i + j * nrows]))
            .collect();
        Matrix { data, nrows, ncols }
    }

    /// Maps `(row, column)` to its position in the row-major buffer, or `None` when either
    /// coordinate is out of range (a column >= `ncols` must not wrap into the next row).
    fn flat_index(&self, x: usize, y: usize) -> Option<usize> {
        (x < self.nrows && y < self.ncols).then(|| x * self.ncols + y)
    }

    /// Returns a reference to the entry at `(row, column)`, or `None` if out of bounds.
    pub fn get(&self, location: (usize, usize)) -> Option<&T> {
        let (x, y) = location;
        self.flat_index(x, y).and_then(|i| self.data.get(i))
    }

    /// Returns a mutable reference to the entry at `(row, column)`, or `None` if out of bounds.
    pub fn get_mut(&mut self, location: (usize, usize)) -> Option<&mut T> {
        let (x, y) = location;
        let i = self.flat_index(x, y)?;
        self.data.get_mut(i)
    }

    /// Returns column `index` as an `nrows x 1` matrix.
    ///
    /// # Panics
    /// Panics if `index >= ncols`.
    pub fn col(&self, index: usize) -> Self {
        if index >= self.ncols {
            panic!(
                "index for column extraction must be less than {} (found {})",
                self.ncols, index
            );
        }
        let data = self
            .data
            .iter()
            .skip(index)
            .step_by(self.ncols)
            .copied()
            .collect();
        Self {
            data,
            nrows: self.nrows,
            ncols: 1,
        }
    }

    /// Replaces every entry `v` with `f(v)`, visiting entries in row-major order.
    pub fn map_mut(&mut self, mut f: impl FnMut(T) -> T) {
        for entry in self.data.iter_mut() {
            *entry = f(*entry);
        }
    }

    /// Computes the matrix product `self * rhs`, where each term is formed as
    /// `rhs[(k, j)] * self[(i, k)]` and accumulated from `U::zero()`.
    ///
    /// # Panics
    /// Panics if `self.ncols != rhs.nrows`.
    pub fn mat_mul<U: Copy + Add<Output = U> + Mul<T, Output = U> + Zero>(
        &self,
        rhs: &Matrix<U>,
    ) -> Matrix<U> {
        assert_eq!(
            self.ncols, rhs.nrows,
            "cannot multiply a {}x{} matrix by a {}x{} matrix",
            self.nrows, self.ncols, rhs.nrows, rhs.ncols
        );
        let mut mat: Matrix<U> = Matrix::new((self.nrows, rhs.ncols), U::zero());
        for i in 0..self.nrows {
            for j in 0..rhs.ncols {
                mat[(i, j)] = (0..self.ncols)
                    .fold(U::zero(), |acc, k| acc + rhs[(k, j)] * self[(i, k)]);
            }
        }
        mat
    }

    /// Computes the determinant by Gaussian elimination over the field `T`.
    ///
    /// For each column the first row at or below the diagonal with a non-zero entry is
    /// swapped into the pivot position. Every swap negates the determinant. Subtracting
    /// multiples of the pivot row from the rows below leaves it unchanged. The result is the
    /// signed product of the pivots, or zero if some column has no usable pivot.
    ///
    /// # Panics
    /// Panics if the matrix is not square or has no columns.
    pub fn det(&self) -> T
    where
        T: Field,
    {
        assert!(
            self.nrows == self.ncols && self.ncols != 0,
            "determinant requires a non-empty square matrix (found {}x{})",
            self.nrows,
            self.ncols
        );
        let n = self.ncols;
        let mut rows: Vec<Vec<T>> = self.data.chunks(n).map(|row| row.to_vec()).collect();
        let mut det = T::ONE;
        for col in 0..n {
            let Some(pivot_idx) = (col..n).find(|&r| rows[r][col] != T::ZERO) else {
                return T::ZERO;
            };
            if pivot_idx != col {
                // A row transposition flips the sign of the determinant.
                rows.swap(col, pivot_idx);
                det = -det;
            }
            let pivot = rows[col][col];
            det *= pivot;
            // The pivot was checked to be non-zero, so inversion cannot fail.
            let pivot_inverse = pivot.invert().unwrap();
            let (upper, lower) = rows.split_at_mut(col + 1);
            let pivot_row = &upper[col];
            for row in lower.iter_mut() {
                let factor = row[col] * pivot_inverse;
                // Columns before `col` are already zero in rows below the diagonal.
                for (entry, &p) in row[col..].iter_mut().zip(&pivot_row[col..]) {
                    *entry = *entry - factor * p;
                }
            }
        }
        det
    }

    /// Converts every entry with `U::from`, keeping the matrix shape.
    pub fn convert<U: From<T> + Copy>(&self) -> Matrix<U> {
        Matrix {
            data: self.data.iter().map(|&c| U::from(c)).collect(),
            nrows: self.nrows,
            ncols: self.ncols,
        }
    }
}

impl<T: Copy> Index<(usize, usize)> for Matrix<T> {
    type Output = T;
    /// Returns the entry at `(row, column)`.
    ///
    /// # Panics
    /// Panics if either coordinate is out of bounds.
    fn index(&self, index: (usize, usize)) -> &Self::Output {
        let (nrows, ncols) = (self.nrows, self.ncols);
        self.get(index)
            .unwrap_or_else(|| panic!("index {index:?} out of bounds for {nrows}x{ncols} matrix"))
    }
}

impl<T: Copy> IndexMut<(usize, usize)> for Matrix<T> {
    /// Returns a mutable reference to the entry at `(row, column)`.
    ///
    /// # Panics
    /// Panics if either coordinate is out of bounds.
    fn index_mut(&mut self, index: (usize, usize)) -> &mut Self::Output {
        let (nrows, ncols) = (self.nrows, self.ncols);
        self.get_mut(index)
            .unwrap_or_else(|| panic!("index {index:?} out of bounds for {nrows}x{ncols} matrix"))
    }
}

/// Consumes the matrix, yielding its entries in row-major order.
impl<T: Copy> IntoIterator for Matrix<T> {
    type Item = T;
    type IntoIter = std::vec::IntoIter<Self::Item>;
    fn into_iter(self) -> Self::IntoIter {
        self.data.into_iter()
    }
}

/// Yields copies of the entries in row-major order.
///
/// Note: this copies the whole entry buffer up front to keep the `vec::IntoIter` type.
impl<T: Copy> IntoIterator for &Matrix<T> {
    type Item = T;
    type IntoIter = std::vec::IntoIter<Self::Item>;
    fn into_iter(self) -> Self::IntoIter {
        self.data.clone().into_iter()
    }
}

/// Element-wise `self += rhs`.
///
/// # Panics
/// Panics if the shapes differ.
impl<'a, T: Copy + Add<Output = T>> AddAssign<&'a Matrix<T>> for Matrix<T> {
    fn add_assign(&mut self, rhs: &'a Matrix<T>) {
        assert_eq!(self.nrows, rhs.nrows);
        assert_eq!(self.ncols, rhs.ncols);
        for (lhs, &r) in self.data.iter_mut().zip(&rhs.data) {
            *lhs = *lhs + r;
        }
    }
}

/// Element-wise `self -= rhs`.
///
/// # Panics
/// Panics if the shapes differ.
impl<'a, T: Copy + Sub<Output = T>> SubAssign<&'a Matrix<T>> for Matrix<T> {
    fn sub_assign(&mut self, rhs: &'a Matrix<T>) {
        assert_eq!(self.nrows, rhs.nrows);
        assert_eq!(self.ncols, rhs.ncols);
        for (lhs, &r) in self.data.iter_mut().zip(&rhs.data) {
            *lhs = *lhs - r;
        }
    }
}

/// Element-wise sum; panics if the shapes differ.
impl<T: Copy + Add<Output = T>> Add for Matrix<T> {
    type Output = Matrix<T>;
    fn add(mut self, rhs: Self) -> Self::Output {
        self += &rhs;
        self
    }
}

/// Element-wise difference; panics if the shapes differ.
impl<T: Copy + Sub<Output = T>> Sub for Matrix<T> {
    type Output = Matrix<T>;
    fn sub(mut self, rhs: Self) -> Self::Output {
        self -= &rhs;
        self
    }
}

/// Builds a column vector (`len x 1` matrix) that takes ownership of `v`.
impl<T: Copy> From<Vec<T>> for Matrix<T> {
    fn from(v: Vec<T>) -> Self {
        let nrows = v.len();
        Matrix {
            data: v,
            nrows,
            ncols: 1,
        }
    }
}

/// Builds a column vector (`len x 1` matrix) by copying the slice.
impl<'a, T: Copy> From<&'a [T]> for Matrix<T> {
    fn from(v: &'a [T]) -> Self {
        Matrix {
            data: v.to_vec(),
            nrows: v.len(),
            ncols: 1,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::field::ScalarField;
    use ff::Field;
    type F = ScalarField;

    #[test]
    fn test_det_dim3() {
        let data = vec![
            F::from(4),
            F::from(2),
            F::from(4),
            F::ZERO,
            F::ZERO,
            F::from(3),
            F::from(5),
            F::from(7),
            F::from(7),
        ];
        let mat = Matrix::new_from_iter((3, 3), data.into_iter());
        let det = mat.det();
        // 4*(0*7 - 3*7) - 2*(0*7 - 3*5) + 4*(0*7 - 0*5) = -84 + 30 = -54
        assert_eq!(F::from(-54), det);
    }

    #[test]
    fn test_det_dim4() {
        let data = vec![
            F::from(6),
            F::from(4),
            F::from(7),
            F::from(8),
            F::from(9),
            F::from(3),
            F::from(9),
            F::from(8),
            F::from(8),
            F::from(3),
            F::from(4),
            F::from(9),
            F::from(5),
            F::from(4),
            F::from(1),
            F::from(3),
        ];
        let mat = Matrix::new_from_iter((4, 4), data.into_iter());
        let det = mat.det();
        assert_eq!(F::from(-476), det);
    }

    #[test]
    fn test_det_row_swap_flips_sign() {
        // [[0, 1], [1, 0]] is the identity with its two rows swapped.
        let data = vec![F::ZERO, F::ONE, F::ONE, F::ZERO];
        let mat = Matrix::new_from_iter((2, 2), data.into_iter());
        assert_eq!(F::from(-1), mat.det());
    }

    #[test]
    fn test_det_singular() {
        let data = vec![F::from(1), F::from(2), F::from(2), F::from(4)];
        let mat = Matrix::new_from_iter((2, 2), data.into_iter());
        assert_eq!(F::ZERO, mat.det());
    }

    #[test]
    fn test_get_out_of_range_column_is_none() {
        let mat = Matrix::new_from_iter((2, 2), 0..4u32);
        assert_eq!(mat.get((1, 1)), Some(&3));
        assert_eq!(mat.get((0, 2)), None);
        assert_eq!(mat.get((2, 0)), None);
    }

    #[test]
    fn test_new_from_column_major_iter() {
        let mat = Matrix::new_from_column_major_iter((2, 3), 0..6u32);
        let expected = Matrix::new_from_iter((2, 3), vec![0u32, 2, 4, 1, 3, 5].into_iter());
        assert_eq!(mat, expected);
    }
}