use alloc::vec;
use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign};
use crate::traits::Scalar;
use super::DynMatrix;
/// `DynMatrix + DynMatrix`, taking both operands by value.
///
/// Forwards to the `&DynMatrix + &DynMatrix` implementation, which performs
/// the shape check and the element-wise addition.
impl<T: Scalar> Add for DynMatrix<T> {
    type Output = Self;

    fn add(self, rhs: Self) -> Self {
        <&DynMatrix<T> as Add<&DynMatrix<T>>>::add(&self, &rhs)
    }
}
/// `DynMatrix + &DynMatrix`: owned left operand, borrowed right operand.
///
/// Forwards to the `&DynMatrix + &DynMatrix` implementation.
impl<T: Scalar> Add<&DynMatrix<T>> for DynMatrix<T> {
    type Output = DynMatrix<T>;

    fn add(self, rhs: &DynMatrix<T>) -> DynMatrix<T> {
        <&DynMatrix<T> as Add<&DynMatrix<T>>>::add(&self, rhs)
    }
}
/// `&DynMatrix + DynMatrix`: borrowed left operand, owned right operand.
///
/// Forwards to the `&DynMatrix + &DynMatrix` implementation.
impl<T: Scalar> Add<DynMatrix<T>> for &DynMatrix<T> {
    type Output = DynMatrix<T>;

    fn add(self, rhs: DynMatrix<T>) -> DynMatrix<T> {
        <&DynMatrix<T> as Add<&DynMatrix<T>>>::add(self, &rhs)
    }
}
/// `&DynMatrix + &DynMatrix`: the implementation every other `+` variant
/// forwards to.
///
/// # Panics
///
/// Panics with a `dimension mismatch` message when the operands do not have
/// the same number of rows and columns.
impl<T: Scalar> Add<&DynMatrix<T>> for &DynMatrix<T> {
    type Output = DynMatrix<T>;

    fn add(self, rhs: &DynMatrix<T>) -> DynMatrix<T> {
        let lhs_shape = (self.nrows, self.ncols);
        let rhs_shape = (rhs.nrows, rhs.ncols);
        assert_eq!(
            lhs_shape, rhs_shape,
            "dimension mismatch: {}x{} + {}x{}",
            lhs_shape.0, lhs_shape.1, rhs_shape.0, rhs_shape.1,
        );

        // Zero-filled output buffer that the slice kernel writes into.
        let mut sum = vec![T::zero(); self.data.len()];
        crate::simd::add_slices_dispatch(&self.data, &rhs.data, &mut sum);

        DynMatrix {
            nrows: lhs_shape.0,
            ncols: lhs_shape.1,
            data: sum,
        }
    }
}
/// `DynMatrix += DynMatrix`, consuming the right-hand operand.
///
/// Forwards to the `+= &DynMatrix` implementation.
impl<T: Scalar> AddAssign for DynMatrix<T> {
    fn add_assign(&mut self, rhs: Self) {
        *self += &rhs;
    }
}
/// `DynMatrix += &DynMatrix`: adds `rhs` into `self` without allocating a
/// new matrix.
///
/// # Panics
///
/// Panics with a `dimension mismatch` message when the operands do not have
/// the same number of rows and columns.
impl<T: Scalar> AddAssign<&DynMatrix<T>> for DynMatrix<T> {
    fn add_assign(&mut self, rhs: &DynMatrix<T>) {
        let lhs_shape = (self.nrows, self.ncols);
        let rhs_shape = (rhs.nrows, rhs.ncols);
        assert_eq!(
            lhs_shape, rhs_shape,
            "dimension mismatch: {}x{} += {}x{}",
            lhs_shape.0, lhs_shape.1, rhs_shape.0, rhs_shape.1,
        );

        use crate::simd::scalar::add_assign_slices;
        add_assign_slices(&mut self.data, &rhs.data);
    }
}
/// `DynMatrix - DynMatrix`, taking both operands by value.
///
/// Forwards to the `&DynMatrix - &DynMatrix` implementation, which performs
/// the shape check and the element-wise subtraction.
impl<T: Scalar> Sub for DynMatrix<T> {
    type Output = Self;

    fn sub(self, rhs: Self) -> Self {
        <&DynMatrix<T> as Sub<&DynMatrix<T>>>::sub(&self, &rhs)
    }
}
/// `DynMatrix - &DynMatrix`: owned left operand, borrowed right operand.
///
/// Forwards to the `&DynMatrix - &DynMatrix` implementation.
impl<T: Scalar> Sub<&DynMatrix<T>> for DynMatrix<T> {
    type Output = DynMatrix<T>;

    fn sub(self, rhs: &DynMatrix<T>) -> DynMatrix<T> {
        <&DynMatrix<T> as Sub<&DynMatrix<T>>>::sub(&self, rhs)
    }
}
/// `&DynMatrix - DynMatrix`: borrowed left operand, owned right operand.
///
/// Forwards to the `&DynMatrix - &DynMatrix` implementation.
impl<T: Scalar> Sub<DynMatrix<T>> for &DynMatrix<T> {
    type Output = DynMatrix<T>;

    fn sub(self, rhs: DynMatrix<T>) -> DynMatrix<T> {
        <&DynMatrix<T> as Sub<&DynMatrix<T>>>::sub(self, &rhs)
    }
}
/// `&DynMatrix - &DynMatrix`: the implementation every other `-` variant
/// forwards to.
///
/// # Panics
///
/// Panics with a `dimension mismatch` message when the operands do not have
/// the same number of rows and columns.
impl<T: Scalar> Sub<&DynMatrix<T>> for &DynMatrix<T> {
    type Output = DynMatrix<T>;

    fn sub(self, rhs: &DynMatrix<T>) -> DynMatrix<T> {
        let lhs_shape = (self.nrows, self.ncols);
        let rhs_shape = (rhs.nrows, rhs.ncols);
        assert_eq!(
            lhs_shape, rhs_shape,
            "dimension mismatch: {}x{} - {}x{}",
            lhs_shape.0, lhs_shape.1, rhs_shape.0, rhs_shape.1,
        );

        // Zero-filled output buffer that the slice kernel writes into.
        let mut difference = vec![T::zero(); self.data.len()];
        crate::simd::sub_slices_dispatch(&self.data, &rhs.data, &mut difference);

        DynMatrix {
            nrows: lhs_shape.0,
            ncols: lhs_shape.1,
            data: difference,
        }
    }
}
/// `DynMatrix -= DynMatrix`, consuming the right-hand operand.
///
/// Forwards to the `-= &DynMatrix` implementation.
impl<T: Scalar> SubAssign for DynMatrix<T> {
    fn sub_assign(&mut self, rhs: Self) {
        *self -= &rhs;
    }
}
/// `DynMatrix -= &DynMatrix`: subtracts `rhs` from `self` without allocating
/// a new matrix.
///
/// # Panics
///
/// Panics when the operands do not have the same number of rows and columns.
/// The message reports both shapes, matching the `+=` implementation.
impl<T: Scalar> SubAssign<&DynMatrix<T>> for DynMatrix<T> {
    fn sub_assign(&mut self, rhs: &DynMatrix<T>) {
        assert_eq!(
            (self.nrows, self.ncols),
            (rhs.nrows, rhs.ncols),
            "dimension mismatch: {}x{} -= {}x{}",
            self.nrows, self.ncols, rhs.nrows, rhs.ncols,
        );
        crate::simd::scalar::sub_assign_slices(&mut self.data, &rhs.data);
    }
}
/// Unary `-` on an owned matrix.
///
/// Each entry `x` is replaced by `T::zero() - x`. The operand is consumed, so
/// its buffer is rewritten in place and returned instead of allocating a
/// second one; the shape is unchanged.
impl<T: Scalar> Neg for DynMatrix<T> {
    type Output = Self;

    fn neg(mut self) -> Self {
        for entry in self.data.iter_mut() {
            *entry = T::zero() - *entry;
        }
        self
    }
}
/// Unary `-` on a borrowed matrix.
///
/// Builds a new matrix of the same shape whose entries are `T::zero() - x`
/// for each entry `x` of the operand.
impl<T: Scalar> Neg for &DynMatrix<T> {
    type Output = DynMatrix<T>;

    fn neg(self) -> DynMatrix<T> {
        DynMatrix {
            data: self
                .data
                .iter()
                .copied()
                .map(|entry| T::zero() - entry)
                .collect(),
            nrows: self.nrows,
            ncols: self.ncols,
        }
    }
}
/// Matrix product `DynMatrix * DynMatrix`, taking both operands by value.
///
/// Forwards to the `&DynMatrix * &DynMatrix` implementation, which performs
/// the inner-dimension check and the multiplication.
impl<T: Scalar> Mul for DynMatrix<T> {
    type Output = Self;

    fn mul(self, rhs: Self) -> Self {
        <&DynMatrix<T> as Mul<&DynMatrix<T>>>::mul(&self, &rhs)
    }
}
/// Matrix product `DynMatrix * &DynMatrix`: owned left operand, borrowed
/// right operand.
///
/// Forwards to the `&DynMatrix * &DynMatrix` implementation.
impl<T: Scalar> Mul<&DynMatrix<T>> for DynMatrix<T> {
    type Output = DynMatrix<T>;

    fn mul(self, rhs: &DynMatrix<T>) -> DynMatrix<T> {
        <&DynMatrix<T> as Mul<&DynMatrix<T>>>::mul(&self, rhs)
    }
}
/// Matrix product `&DynMatrix * DynMatrix`: borrowed left operand, owned
/// right operand.
///
/// Forwards to the `&DynMatrix * &DynMatrix` implementation.
impl<T: Scalar> Mul<DynMatrix<T>> for &DynMatrix<T> {
    type Output = DynMatrix<T>;

    fn mul(self, rhs: DynMatrix<T>) -> DynMatrix<T> {
        <&DynMatrix<T> as Mul<&DynMatrix<T>>>::mul(self, &rhs)
    }
}
/// Matrix product `&DynMatrix * &DynMatrix`: the implementation every other
/// matrix-by-matrix `*` variant forwards to.
///
/// The result has `self.nrows` rows and `rhs.ncols` columns.
///
/// # Panics
///
/// Panics with a `dimension mismatch` message when `self.ncols` differs from
/// `rhs.nrows`.
impl<T: Scalar> Mul<&DynMatrix<T>> for &DynMatrix<T> {
    type Output = DynMatrix<T>;

    fn mul(self, rhs: &DynMatrix<T>) -> DynMatrix<T> {
        assert_eq!(
            self.ncols, rhs.nrows,
            "dimension mismatch: {}x{} * {}x{}",
            self.nrows, self.ncols, rhs.nrows, rhs.ncols,
        );

        let (rows, inner, cols) = (self.nrows, self.ncols, rhs.ncols);

        // Zero-filled output buffer handed to the multiplication kernel.
        let mut product = vec![T::zero(); rows * cols];
        crate::simd::matmul_dispatch(&self.data, &rhs.data, &mut product, rows, inner, cols);

        DynMatrix {
            data: product,
            nrows: rows,
            ncols: cols,
        }
    }
}
/// `DynMatrix * scalar`, taking the matrix by value.
///
/// Forwards to the `&DynMatrix * scalar` implementation.
impl<T: Scalar> Mul<T> for DynMatrix<T> {
    type Output = Self;

    fn mul(self, rhs: T) -> Self {
        <&DynMatrix<T> as Mul<T>>::mul(&self, rhs)
    }
}
/// `&DynMatrix * scalar`: builds a new matrix of the same shape from the
/// operand's entries and the factor `rhs`.
impl<T: Scalar> Mul<T> for &DynMatrix<T> {
    type Output = DynMatrix<T>;

    fn mul(self, rhs: T) -> DynMatrix<T> {
        // Zero-filled output buffer that the scaling kernel writes into.
        let mut scaled = vec![T::zero(); self.data.len()];
        crate::simd::scale_slices_dispatch(&self.data, rhs, &mut scaled);

        DynMatrix {
            nrows: self.nrows,
            ncols: self.ncols,
            data: scaled,
        }
    }
}
/// `DynMatrix *= scalar`: scales the matrix's own buffer by `rhs`.
impl<T: Scalar> MulAssign<T> for DynMatrix<T> {
    fn mul_assign(&mut self, rhs: T) {
        use crate::simd::scalar::scale_assign_slices;

        scale_assign_slices(&mut self.data, rhs);
    }
}
/// Generates `scalar * DynMatrix` and `scalar * &DynMatrix` for each listed
/// primitive type.
///
/// A blanket `impl<T> Mul<DynMatrix<T>> for T` is not possible, so the
/// left-scalar forms are stamped out per concrete type; each one swaps the
/// operands and forwards to the corresponding `matrix * scalar` impl.
macro_rules! impl_scalar_mul_dyn {
    ($($scalar:ty),*) => {
        $(
            impl Mul<DynMatrix<$scalar>> for $scalar {
                type Output = DynMatrix<$scalar>;

                fn mul(self, rhs: DynMatrix<$scalar>) -> DynMatrix<$scalar> {
                    <DynMatrix<$scalar> as Mul<$scalar>>::mul(rhs, self)
                }
            }

            impl Mul<&DynMatrix<$scalar>> for $scalar {
                type Output = DynMatrix<$scalar>;

                fn mul(self, rhs: &DynMatrix<$scalar>) -> DynMatrix<$scalar> {
                    <&DynMatrix<$scalar> as Mul<$scalar>>::mul(rhs, self)
                }
            }
        )*
    };
}

// Left-scalar multiplication for the float and integer primitives.
impl_scalar_mul_dyn!(f32, f64, i8, i16, i32, i64, i128, u8, u16, u32, u64, u128);
/// `DynMatrix / scalar`, taking the matrix by value.
///
/// Each entry `x` is replaced by `x / rhs`. The operand is consumed, so its
/// buffer is rewritten in place and returned instead of allocating a second
/// one; the shape is unchanged.
impl<T: Scalar> Div<T> for DynMatrix<T> {
    type Output = Self;

    fn div(mut self, rhs: T) -> Self {
        for entry in self.data.iter_mut() {
            *entry = *entry / rhs;
        }
        self
    }
}
/// `&DynMatrix / scalar`: builds a new matrix of the same shape whose
/// entries are `x / rhs` for each entry `x` of the operand.
impl<T: Scalar> Div<T> for &DynMatrix<T> {
    type Output = DynMatrix<T>;

    fn div(self, rhs: T) -> DynMatrix<T> {
        DynMatrix {
            data: self.data.iter().copied().map(|entry| entry / rhs).collect(),
            nrows: self.nrows,
            ncols: self.ncols,
        }
    }
}
/// `DynMatrix /= scalar`: divides every stored entry by `rhs` in place.
impl<T: Scalar> DivAssign<T> for DynMatrix<T> {
    fn div_assign(&mut self, rhs: T) {
        self.data.iter_mut().for_each(|entry| *entry = *entry / rhs);
    }
}
/// Element-wise and structural operations that are not expressed through
/// operator traits.
impl<T: Scalar> DynMatrix<T> {
    /// Element-wise (Hadamard) product: entry `(i, j)` of the result is
    /// `self[(i, j)] * rhs[(i, j)]`.
    ///
    /// # Panics
    ///
    /// Panics when the operands do not have the same number of rows and
    /// columns; the message reports both shapes.
    pub fn element_mul(&self, rhs: &Self) -> Self {
        assert_eq!(
            (self.nrows, self.ncols),
            (rhs.nrows, rhs.ncols),
            "dimension mismatch: {}x{} .* {}x{}",
            self.nrows, self.ncols, rhs.nrows, rhs.ncols,
        );
        let data = self
            .data
            .iter()
            .zip(rhs.data.iter())
            .map(|(&a, &b)| a * b)
            .collect();
        DynMatrix {
            data,
            nrows: self.nrows,
            ncols: self.ncols,
        }
    }

    /// Element-wise quotient: entry `(i, j)` of the result is
    /// `self[(i, j)] / rhs[(i, j)]`.
    ///
    /// # Panics
    ///
    /// Panics when the operands do not have the same number of rows and
    /// columns; the message reports both shapes.
    pub fn element_div(&self, rhs: &Self) -> Self {
        assert_eq!(
            (self.nrows, self.ncols),
            (rhs.nrows, rhs.ncols),
            "dimension mismatch: {}x{} ./ {}x{}",
            self.nrows, self.ncols, rhs.nrows, rhs.ncols,
        );
        let data = self
            .data
            .iter()
            .zip(rhs.data.iter())
            .map(|(&a, &b)| a / b)
            .collect();
        DynMatrix {
            data,
            nrows: self.nrows,
            ncols: self.ncols,
        }
    }

    /// Kronecker product of an `m x n` matrix with a `p x q` matrix.
    ///
    /// The result is `(m * p) x (n * q)`; its entry at row `i * p + k`,
    /// column `j * q + l` is `self[(i, j)] * rhs[(k, l)]`. Both operands and
    /// the result are addressed as contiguous columns of the backing buffer.
    pub fn kronecker(&self, rhs: &Self) -> Self {
        let m = self.nrows;
        let n = self.ncols;
        let p = rhs.nrows;
        let q = rhs.ncols;
        let out_rows = m * p;
        let out_cols = n * q;
        let mut data = vec![T::zero(); out_rows * out_cols];
        for j in 0..n {
            for l in 0..q {
                // Output column (j, l) is built from column j of `self` and
                // column l of `rhs`.
                let out_col = j * q + l;
                let self_col = &self.data[j * m..(j + 1) * m];
                let rhs_col = &rhs.data[l * p..(l + 1) * p];
                let out_slice = &mut data[out_col * out_rows..(out_col + 1) * out_rows];
                for i in 0..m {
                    let a_ij = self_col[i];
                    // Rows i*p .. (i+1)*p hold a_ij times the rhs column.
                    let base = i * p;
                    for k in 0..p {
                        out_slice[base + k] = a_ij * rhs_col[k];
                    }
                }
            }
        }
        DynMatrix {
            data,
            nrows: out_rows,
            ncols: out_cols,
        }
    }

    /// Returns the transpose: an `ncols x nrows` matrix whose entry `(i, j)`
    /// is `self[(j, i)]`.
    pub fn transpose(&self) -> Self
    where
        T: Copy,
    {
        let m = self.nrows;
        let n = self.ncols;
        DynMatrix::from_fn(n, m, |i, j| self[(j, i)])
    }
}
// Unit tests for the `DynMatrix` operator impls and the element-wise,
// Kronecker and transpose helpers.
//
// Matrices are built with `from_rows`, whose slice argument is read in
// row-major order (the indexed assertions in `neg` and `transpose` rely on it).
#[cfg(test)]
mod tests {
use super::*;
// `&a + &b` and `&b - &a` combine matching entries.
#[test]
fn add_sub() {
let a = DynMatrix::from_rows(2, 2, &[1.0, 2.0, 3.0, 4.0]);
let b = DynMatrix::from_rows(2, 2, &[5.0, 6.0, 7.0, 8.0]);
let c = &a + &b;
assert_eq!(c[(0, 0)], 6.0);
assert_eq!(c[(1, 1)], 12.0);
let d = &b - &a;
assert_eq!(d[(0, 0)], 4.0);
assert_eq!(d[(1, 1)], 4.0);
}
// `+=` followed by `-=` with the same operand restores the original entry.
#[test]
fn add_assign() {
let mut a = DynMatrix::from_rows(2, 2, &[1.0, 2.0, 3.0, 4.0]);
let b = DynMatrix::from_rows(2, 2, &[5.0, 6.0, 7.0, 8.0]);
a += &b;
assert_eq!(a[(0, 0)], 6.0);
a -= &b;
assert_eq!(a[(0, 0)], 1.0);
}
// Unary minus on an owned matrix flips the sign of each entry.
#[test]
fn neg() {
let a = DynMatrix::from_rows(2, 2, &[1.0, -2.0, 3.0, -4.0]);
let b = -a;
assert_eq!(b[(0, 0)], -1.0);
assert_eq!(b[(0, 1)], 2.0);
}
// 2x2 matrix product checked against hand-computed entries.
#[test]
fn matrix_multiply() {
let a = DynMatrix::from_rows(2, 2, &[1.0, 2.0, 3.0, 4.0]);
let b = DynMatrix::from_rows(2, 2, &[5.0, 6.0, 7.0, 8.0]);
let c = &a * &b;
assert_eq!(c[(0, 0)], 19.0);
assert_eq!(c[(0, 1)], 22.0);
assert_eq!(c[(1, 0)], 43.0);
assert_eq!(c[(1, 1)], 50.0);
}
// (2x3) * (3x2) yields a 2x2 result.
#[test]
fn matrix_multiply_non_square() {
let a = DynMatrix::from_rows(2, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
let b = DynMatrix::from_rows(3, 2, &[7.0, 8.0, 9.0, 10.0, 11.0, 12.0]);
let c = &a * &b;
assert_eq!(c.nrows(), 2);
assert_eq!(c.ncols(), 2);
assert_eq!(c[(0, 0)], 58.0);
assert_eq!(c[(0, 1)], 64.0);
}
// (2x3) * (2x2) has mismatched inner dimensions and must panic.
#[test]
#[should_panic(expected = "dimension mismatch")]
fn multiply_dim_mismatch() {
let a = DynMatrix::from_rows(2, 3, &[0.0; 6]);
let b = DynMatrix::from_rows(2, 2, &[0.0; 4]);
let _ = &a * &b;
}
// `matrix * scalar` and `scalar * matrix` give the same result.
#[test]
fn scalar_multiply() {
let a = DynMatrix::from_rows(2, 2, &[1.0, 2.0, 3.0, 4.0]);
let b = &a * 3.0;
assert_eq!(b[(0, 0)], 3.0);
assert_eq!(b[(1, 1)], 12.0);
let c = 3.0 * &a;
assert_eq!(c, b);
}
// `&matrix / scalar` divides each entry.
#[test]
fn scalar_divide() {
let a = DynMatrix::from_rows(2, 2, &[2.0, 4.0, 6.0, 8.0]);
let b = &a / 2.0;
assert_eq!(b[(0, 0)], 1.0);
assert_eq!(b[(1, 1)], 4.0);
}
// `*=` followed by `/=` with the same scalar restores the original entry.
#[test]
fn mul_div_assign() {
let mut a = DynMatrix::from_rows(2, 2, &[1.0, 2.0, 3.0, 4.0]);
a *= 2.0;
assert_eq!(a[(0, 0)], 2.0);
a /= 2.0;
assert_eq!(a[(0, 0)], 1.0);
}
// Element-wise product multiplies matching entries.
#[test]
fn element_mul() {
let a = DynMatrix::from_rows(2, 2, &[1.0, 2.0, 3.0, 4.0]);
let b = DynMatrix::from_rows(2, 2, &[5.0, 6.0, 7.0, 8.0]);
let c = a.element_mul(&b);
assert_eq!(c[(0, 0)], 5.0);
assert_eq!(c[(1, 1)], 32.0);
}
// Element-wise quotient divides matching entries.
#[test]
fn element_div() {
let a = DynMatrix::from_rows(2, 2, &[10.0, 12.0, 21.0, 32.0]);
let b = DynMatrix::from_rows(2, 2, &[5.0, 6.0, 7.0, 8.0]);
let c = a.element_div(&b);
assert_eq!(c[(0, 0)], 2.0);
assert_eq!(c[(1, 1)], 4.0);
}
// Transposing a 2x3 matrix gives a 3x2 matrix with swapped indices.
#[test]
fn transpose() {
let a = DynMatrix::from_rows(2, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
let t = a.transpose();
assert_eq!(t.nrows(), 3);
assert_eq!(t.ncols(), 2);
assert_eq!(t[(0, 0)], 1.0);
assert_eq!(t[(1, 0)], 2.0);
assert_eq!(t[(2, 1)], 6.0);
}
// All four owned/borrowed operand combinations of `+` agree.
#[test]
fn ref_variants() {
let a = DynMatrix::from_rows(2, 2, &[1.0, 2.0, 3.0, 4.0]);
let b = DynMatrix::from_rows(2, 2, &[5.0, 6.0, 7.0, 8.0]);
let sum1 = &a + &b;
let sum2 = a.clone() + &b;
let sum3 = &a + b.clone();
let sum4 = a.clone() + b.clone();
assert_eq!(sum1, sum2);
assert_eq!(sum1, sum3);
assert_eq!(sum1, sum4);
}
// Multiplying by the identity on either side leaves the matrix unchanged.
#[test]
fn identity_multiply() {
let a = DynMatrix::from_rows(2, 2, &[1.0, 2.0, 3.0, 4.0]);
let id = DynMatrix::<f64>::eye(2);
assert_eq!(&a * &id, a);
assert_eq!(&id * &a, a);
}
// The Kronecker product of two 2x2 identities is the 4x4 identity.
#[test]
fn kronecker_identity() {
let i2 = DynMatrix::<f64>::eye(2);
let i4 = i2.kronecker(&i2);
assert_eq!(i4.nrows(), 4);
assert_eq!(i4.ncols(), 4);
let expected = DynMatrix::<f64>::eye(4);
assert_eq!(i4, expected);
}
// Full 4x4 Kronecker product of two 2x2 matrices, compared entry for entry.
#[test]
fn kronecker_2x2() {
let a = DynMatrix::from_rows(2, 2, &[1.0, 2.0, 3.0, 4.0]);
let b = DynMatrix::from_rows(2, 2, &[0.0, 5.0, 6.0, 7.0]);
let c = a.kronecker(&b);
assert_eq!(c.nrows(), 4);
assert_eq!(c.ncols(), 4);
let expected = DynMatrix::from_rows(
4,
4,
&[
0.0, 5.0, 0.0, 10.0, 6.0, 7.0, 12.0, 14.0, 0.0, 15.0, 0.0, 20.0, 18.0, 21.0,
24.0, 28.0,
],
);
assert_eq!(c, expected);
}
// Kronecker product of a 2x3 matrix with the 2x2 identity: a 4x6 result in
// which each entry of `a` appears on the diagonal of its 2x2 block.
#[test]
fn kronecker_rectangular() {
let a = DynMatrix::from_rows(2, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
let b = DynMatrix::from_rows(2, 2, &[1.0, 0.0, 0.0, 1.0]);
let c = a.kronecker(&b);
assert_eq!(c.nrows(), 4);
assert_eq!(c.ncols(), 6);
assert_eq!(c[(0, 0)], 1.0);
assert_eq!(c[(0, 1)], 0.0);
assert_eq!(c[(0, 2)], 2.0);
assert_eq!(c[(1, 1)], 1.0);
assert_eq!(c[(1, 3)], 2.0);
assert_eq!(c[(2, 2)], 5.0);
assert_eq!(c[(2, 4)], 6.0);
assert_eq!(c[(3, 3)], 5.0);
assert_eq!(c[(3, 5)], 6.0);
}
// Checks the identity tr(A (x) B) = tr(A) * tr(B) within 1e-12.
#[test]
fn kronecker_commute_trace() {
let a = DynMatrix::<f64>::from_rows(3, 3, &[2.0, 1.0, 0.0, 0.0, 3.0, 1.0, 1.0, 0.0, 5.0]);
let b = DynMatrix::<f64>::from_rows(2, 2, &[4.0, 7.0, 2.0, 6.0]);
let ab = a.kronecker(&b);
let tr_ab: f64 = ab.trace();
let tr_a: f64 = a.trace();
let tr_b: f64 = b.trace();
assert!((tr_ab - tr_a * tr_b).abs() < 1e-12);
}
}