use std::ops::Index;
use std::ops::IndexMut;
use crate::*;
/// Matrix over the ring element type `E`, stored as a list of row vectors.
///
/// The constructors and composition methods in this file keep
/// `entries.len() == height` and give every row length `width`.
/// NOTE(review): `IndexMut` hands out `&mut Vector<E>` rows, so if `Vector`
/// can be resized through `&mut`, callers could break this invariant — confirm.
#[derive(Debug, Clone, PartialEq)]
pub struct Matrix<E: RingElement> {
/// Number of columns, i.e. the length of each row.
width: usize,
/// Number of rows.
height: usize,
/// The rows of the matrix; `entries[i]` is row `i`.
entries: Vec<Vector<E>>,
}
impl<E: RingElement> Matrix<E> {
    /// Creates a `height` x `width` matrix in which every row is `Vector::new(width)`.
    pub fn zero(height: usize, width: usize) -> Self {
        let entries = vec![Vector::new(width); height];
        Self {
            width,
            height,
            entries,
        }
    }

    /// Creates a `height` x `width` matrix. Each row is drawn, in order, with
    /// `Vector::sample_uniform(width, rng)`.
    pub fn random<R: Rng>(height: usize, width: usize, rng: &mut R) -> Self {
        let entries: Vec<Vector<E>> = (0..height)
            .map(|_| Vector::sample_uniform(width, rng))
            .collect();
        Self {
            width,
            height,
            entries,
        }
    }

    /// Returns `(height, width)`, i.e. (number of rows, number of columns).
    pub fn dimension(&self) -> (usize, usize) {
        let Self { width, height, .. } = self;
        (*height, *width)
    }

    /// Number of rows.
    pub fn height(&self) -> usize {
        let (rows, _) = self.dimension();
        rows
    }

    /// Number of columns.
    pub fn width(&self) -> usize {
        let (_, cols) = self.dimension();
        cols
    }

    /// Creates a `size` x `size` matrix whose diagonal holds `E::one()`.
    /// All other entries come from `zero`.
    pub fn identity(size: usize) -> Self {
        let mut result = Self::zero(size, size);
        result
            .entries
            .iter_mut()
            .enumerate()
            .for_each(|(k, row)| row[k] = E::one());
        result
    }

    /// Places `other` to the right of `self`, appending each `other` row to the
    /// matching `self` row.
    ///
    /// # Panics
    /// Panics if the two matrices have different heights.
    pub fn compose_horizontal(mut self, other: Self) -> Self {
        assert_eq!(
            self.height(),
            other.height(),
            "Matrix::compose_horizontal cannot compose matrices of unequal height"
        );
        self.width += other.width();
        self.entries
            .iter_mut()
            .zip(other.entries)
            .for_each(|(left, right)| left.append(right));
        self
    }

    /// Places `other` below `self`, appending its rows after those of `self`.
    ///
    /// # Panics
    /// Panics if the two matrices have different widths.
    pub fn compose_vertical(mut self, other: Self) -> Self {
        assert_eq!(
            self.width(),
            other.width(),
            "Matrix::compose_vertical cannot compose matrices of unequal width"
        );
        self.height += other.height();
        self.entries.extend(other.entries);
        self
    }

    /// Iterates over the rows, top to bottom.
    pub fn iter(&self) -> impl Iterator<Item = &Vector<E>> {
        self.entries[..].iter()
    }
}
impl<E: RingElement> FromIterator<Vector<E>> for Matrix<E> {
    /// Builds a matrix whose rows are the yielded vectors, in order.
    ///
    /// The first row fixes the width. An empty iterator yields a matrix with
    /// width 0 and height 0.
    ///
    /// # Panics
    /// Panics if any row's length differs from the first row's length.
    fn from_iter<T: IntoIterator<Item = Vector<E>>>(iter: T) -> Self {
        let mut width: Option<usize> = None;
        let entries: Vec<Vector<E>> = iter
            .into_iter()
            .inspect(|row| {
                // The first row seeds `width`; every later row is checked against it.
                let expected = *width.get_or_insert(row.len());
                assert_eq!(
                    expected,
                    row.len(),
                    "row length mismatch in Matrix construction from iterator"
                );
            })
            .collect();
        Self {
            width: width.unwrap_or(0),
            height: entries.len(),
            entries,
        }
    }
}
impl<E: RingElement> Index<usize> for Matrix<E> {
    type Output = Vector<E>;

    /// Borrows row `index`.
    ///
    /// # Panics
    /// Panics if `index` is not smaller than the number of stored rows.
    fn index(&self, index: usize) -> &Self::Output {
        let rows: &[Vector<E>] = &self.entries;
        &rows[index]
    }
}
impl<E: RingElement> IndexMut<usize> for Matrix<E> {
fn index_mut(&mut self, index: usize) -> &mut Self::Output {
&mut self.entries[index]
}
}
impl<E: RingElement> AddAssign<&Self> for Matrix<E> {
    /// Adds `rhs` to `self` row by row: row `i` of `self` is updated with `+=`
    /// using row `i` of `rhs`.
    ///
    /// # Panics
    /// Panics if the two matrices differ in width or height.
    fn add_assign(&mut self, rhs: &Self) {
        assert_eq!(
            self.width, rhs.width,
            "cannot add matrices of different width"
        );
        assert_eq!(
            self.height, rhs.height,
            "cannot add matrices of different height"
        );
        // Pair rows by index. Each row of `self` must be combined only with the
        // matching row of `rhs`, not with every row of `rhs`.
        for (self_row, other_row) in self.entries.iter_mut().zip(rhs.entries.iter()) {
            *self_row += other_row;
        }
    }
}
impl<E: RingElement> Add<&Self> for Matrix<E> {
    type Output = Self;

    /// Consumes `self`, applies `+=` with `rhs`, and returns the updated matrix.
    fn add(self, rhs: &Self) -> Self::Output {
        let mut sum = self;
        sum += rhs;
        sum
    }
}
impl<E: RingElement> MulAssign<&Self> for Matrix<E> {
    /// Multiplies `self` by `rhs` row by row: row `i` of `self` is updated with
    /// `*=` using row `i` of `rhs`.
    ///
    /// Both matrices must have identical dimensions, so this is an entry-wise
    /// combination, not the matrix product. NOTE(review): presumably a Hadamard
    /// product, assuming `Vector`'s `*=` is element-wise — confirm.
    ///
    /// # Panics
    /// Panics if the two matrices differ in width or height.
    fn mul_assign(&mut self, rhs: &Self) {
        assert_eq!(
            self.width, rhs.width,
            "cannot mul matrices of different width"
        );
        assert_eq!(
            self.height, rhs.height,
            "cannot mul matrices of different height"
        );
        // Pair rows by index. Each row of `self` must be combined only with the
        // matching row of `rhs`, not with every row of `rhs`.
        for (self_row, other_row) in self.entries.iter_mut().zip(rhs.entries.iter()) {
            *self_row *= other_row;
        }
    }
}
impl<E: RingElement> Mul<&Self> for Matrix<E> {
    type Output = Self;

    /// Consumes `self`, applies `*=` with `rhs`, and returns the updated matrix.
    fn mul(self, rhs: &Self) -> Self::Output {
        let mut product = self;
        product *= rhs;
        product
    }
}
impl<E: RingElement> Mul<&Vector<E>> for Matrix<E> {
    type Output = Vector<E>;

    /// Consumes the matrix and multiplies it by `rhs`.
    ///
    /// Output entry `i` is `(row_i * rhs).into_sum()`. NOTE(review): presumably
    /// the dot product of row `i` with `rhs`, assuming `Vector * &Vector` is
    /// element-wise — confirm.
    ///
    /// # Panics
    /// Panics if `self.width() != rhs.len()`.
    fn mul(self, rhs: &Vector<E>) -> Self::Output {
        assert_eq!(self.width(), rhs.len());
        #[cfg(not(feature = "rayon"))]
        {
            self.entries
                .into_iter()
                .map(|row| (row * rhs).into_sum())
                .collect::<Vector<_>>()
        }
        #[cfg(feature = "rayon")]
        {
            use rayon::prelude::*;
            // `self` is owned, so each row can be moved into the worker closure
            // instead of being cloned.
            self.entries
                .into_par_iter()
                .map(|row| (row * rhs).into_sum())
                .collect::<Vec<E>>()
                .into()
        }
    }
}
impl<E: RingElement> Mul<&Vector<E>> for &Matrix<E> {
    type Output = Vector<E>;

    /// Multiplies the borrowed matrix by `rhs` without consuming it.
    ///
    /// Output entry `i` starts at `E::zero()` and accumulates
    /// `row_i[j] * rhs[j]` over the positions the two vectors share. `zip` stops
    /// at the shorter of the two.
    ///
    /// # Panics
    /// Panics if `self.width() != rhs.len()`.
    fn mul(self, rhs: &Vector<E>) -> Self::Output {
        assert_eq!(self.width(), rhs.len());
        // One accumulation routine, shared by the sequential and parallel paths.
        let dot = move |row: &Vector<E>| -> E {
            let mut acc = E::zero();
            for (lhs_entry, rhs_entry) in row.iter().zip(rhs.iter()) {
                acc += *lhs_entry * *rhs_entry;
            }
            acc
        };
        #[cfg(not(feature = "rayon"))]
        {
            self.entries.iter().map(dot).collect::<Vector<_>>()
        }
        #[cfg(feature = "rayon")]
        {
            use rayon::prelude::*;
            self.entries.par_iter().map(dot).collect::<Vec<E>>().into()
        }
    }
}
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn matrix_dimension() {
        type Field = MilliScalarMont;
        let width = 50;
        let height = 100;
        let m = Matrix::<Field>::zero(height, width);
        assert_eq!(m.dimension(), (height, width));
        assert_eq!(m.height(), height);
        assert_eq!(m.width(), width);
    }

    #[test]
    fn matrix_identity() {
        type Field = MilliScalarMont;
        let mut rng = rand::rng();
        for s in 1..100 {
            let identity = Matrix::<Field>::identity(s);
            let vec = Vector::sample_uniform(s, &mut rng);
            assert_eq!(vec, &identity * &vec);
        }
    }

    #[test]
    fn matrix_compose_horizontal() {
        type Field = MilliScalarMont;
        let mut rng = rand::rng();
        let height = 200;
        let width1 = 100;
        let width2 = 50;
        let m1 = Matrix::<Field>::random(height, width1, &mut rng);
        let m2 = Matrix::<Field>::random(height, width2, &mut rng);
        let m_composed = m1.clone().compose_horizontal(m2.clone());
        assert_eq!(m_composed.dimension(), (height, width1 + width2));
        assert_eq!(m_composed.iter().count(), height);
        // Row i must be row i of m1 followed by row i of m2, not merely have the
        // right length.
        for ((row_composed, row1), row2) in m_composed.iter().zip(m1.iter()).zip(m2.iter()) {
            assert_eq!(row_composed.len(), width1 + width2);
            let mut expected = row1.clone();
            expected.append(row2.clone());
            assert_eq!(row_composed, &expected);
        }
    }

    #[test]
    fn matrix_compose_vertical() {
        type Field = MilliScalarMont;
        let mut rng = rand::rng();
        let height1 = 100;
        let height2 = 50;
        let m1 = Matrix::<Field>::random(height1, 200, &mut rng);
        let m2 = Matrix::<Field>::random(height2, 200, &mut rng);
        let m_composed = m1.clone().compose_vertical(m2.clone());
        assert_eq!(m_composed.height(), height1 + height2);
        assert_eq!(m_composed.entries.len(), height1 + height2);
        for (row_composed, row) in m_composed.iter().zip(m1.iter().chain(m2.iter())) {
            assert_eq!(row_composed, row);
        }
    }
}