use super::DVec;
use alloc::vec::Vec;
#[cfg(all(feature = "accelerate", target_os = "macos"))]
use core::any::TypeId;
use core::ops::{Add, Index, IndexMut, Mul, Neg, Sub};
use crate::Scalar;
/// Dense matrix whose dimensions are chosen at runtime.
///
/// All elements live in one flat `Vec` laid out column by column: the entry at
/// `(row, col)` is stored at `data[col * nrows + row]`.
#[derive(Clone, Debug, PartialEq)]
pub struct DMat<S> {
/// Flat column-major element storage; every constructor sizes it to `nrows * ncols`.
data: Vec<S>,
/// Number of rows (also the distance in `data` between consecutive columns).
nrows: usize,
/// Number of columns.
ncols: usize,
}
impl<S: Scalar> DMat<S> {
pub fn from_raw(nrows: usize, ncols: usize, data: Vec<S>) -> Self {
assert_eq!(data.len(), nrows * ncols, "DMat: data length mismatch");
Self { data, nrows, ncols }
}
/// Builds a matrix by evaluating `f(row, col)` for every entry.
///
/// `f` is called once per element in storage order: all rows of column 0,
/// then all rows of column 1, and so on.
pub fn from_fn(nrows: usize, ncols: usize, f: impl Fn(usize, usize) -> S) -> Self {
    let mut storage = Vec::with_capacity(nrows * ncols);
    for col in 0..ncols {
        // Append one whole column at a time to keep the column-major layout.
        storage.extend((0..nrows).map(|row| f(row, col)));
    }
    Self {
        data: storage,
        nrows,
        ncols,
    }
}
/// Creates an `nrows` x `ncols` matrix with every entry set to `S::ZERO`.
pub fn zeros(nrows: usize, ncols: usize) -> Self {
    let len = nrows * ncols;
    let data = alloc::vec![S::ZERO; len];
    Self { data, nrows, ncols }
}
/// Creates the `n` x `n` identity matrix: `S::ONE` on the diagonal, `S::ZERO` elsewhere.
pub fn identity(n: usize) -> Self {
    let mut eye = Self::zeros(n, n);
    for k in 0..n {
        // Diagonal entry (k, k) in column-major storage.
        eye.data[k * n + k] = S::ONE;
    }
    eye
}
/// Fills an `nrows` x `ncols` matrix from the first `nrows * ncols` items of `iter`.
///
/// Items are consumed in storage (column-major) order; any surplus items are
/// left unconsumed.
///
/// # Panics
/// Panics if the iterator yields fewer than `nrows * ncols` items.
pub fn from_iterator(nrows: usize, ncols: usize, iter: impl IntoIterator<Item = S>) -> Self {
    let wanted = nrows * ncols;
    let mut data: Vec<S> = Vec::new();
    data.extend(iter.into_iter().take(wanted));
    assert_eq!(
        data.len(),
        wanted,
        "DMat::from_iterator: iterator yielded fewer than expected"
    );
    Self::from_raw(nrows, ncols, data)
}
/// Builds a matrix from a slice given in row-major order.
///
/// Element `(i, j)` is read from `data[i * ncols + j]` and stored in the
/// matrix's own column-major layout.
///
/// # Panics
/// Panics if `data.len()` differs from `nrows * ncols`.
pub fn from_row_slice(nrows: usize, ncols: usize, data: &[S]) -> Self {
    assert_eq!(data.len(), nrows * ncols, "DMat: data length mismatch");
    let row_major_at = |r: usize, c: usize| data[r * ncols + c];
    Self::from_fn(nrows, ncols, row_major_at)
}
/// Creates a square matrix with `diag` on the main diagonal and `S::ZERO` elsewhere.
///
/// The result has `diag.len()` rows and columns.
pub fn from_diagonal(diag: &DVec<S>) -> Self {
    let order = diag.len();
    let mut out = Self::zeros(order, order);
    for k in 0..order {
        out.data[k * order + k] = diag[k];
    }
    out
}
/// Returns the number of rows.
#[inline]
pub fn nrows(&self) -> usize {
self.nrows
}
/// Returns the number of columns.
#[inline]
pub fn ncols(&self) -> usize {
self.ncols
}
/// Returns the element at (`row`, `col`) by value.
///
/// # Panics
/// Panics if `row >= self.nrows()` or `col >= self.ncols()`.
#[inline]
pub fn get(&self, row: usize, col: usize) -> S {
    // Storage is column-major, so without this check an out-of-range `row`
    // would silently read an element of a later column instead of failing.
    assert!(
        row < self.nrows && col < self.ncols,
        "DMat::get: index ({}, {}) out of bounds for {}x{} matrix",
        row,
        col,
        self.nrows,
        self.ncols
    );
    self.data[col * self.nrows + row]
}
/// Returns a mutable reference to the element at (`row`, `col`).
///
/// # Panics
/// Panics if `row >= self.nrows()` or `col >= self.ncols()`.
#[inline]
pub fn get_mut(&mut self, row: usize, col: usize) -> &mut S {
    // Storage is column-major, so without this check an out-of-range `row`
    // would silently hand out an element of a later column.
    assert!(
        row < self.nrows && col < self.ncols,
        "DMat::get_mut: index ({}, {}) out of bounds for {}x{} matrix",
        row,
        col,
        self.nrows,
        self.ncols
    );
    &mut self.data[col * self.nrows + row]
}
/// Overwrites the element at (`row`, `col`) with `val`.
///
/// # Panics
/// Panics if `row >= self.nrows()` or `col >= self.ncols()`.
#[inline]
pub fn set(&mut self, row: usize, col: usize, val: S) {
    // Storage is column-major, so without this check an out-of-range `row`
    // would silently overwrite an element of a later column.
    assert!(
        row < self.nrows && col < self.ncols,
        "DMat::set: index ({}, {}) out of bounds for {}x{} matrix",
        row,
        col,
        self.nrows,
        self.ncols
    );
    self.data[col * self.nrows + row] = val;
}
/// Borrows the whole backing buffer as a flat slice (column-major order).
#[inline]
pub fn as_slice(&self) -> &[S] {
&self.data
}
/// Mutably borrows the whole backing buffer as a flat slice (column-major order).
#[inline]
pub fn as_mut_slice(&mut self) -> &mut [S] {
&mut self.data
}
/// Mutably borrows column `j` as a contiguous slice of `nrows` elements.
///
/// # Panics
/// Panics if the column's range lies outside the backing buffer.
#[inline]
pub fn col_mut(&mut self, j: usize) -> &mut [S] {
    let first = j * self.nrows;
    let past_end = first + self.nrows;
    &mut self.data[first..past_end]
}
/// Returns mutable access to the backing `Vec` itself.
///
/// NOTE(review): changing the vector's length through this reference leaves
/// it out of sync with `nrows * ncols`; nothing in this type re-checks that,
/// so callers are expected to keep the length unchanged.
#[inline]
pub fn data_mut(&mut self) -> &mut Vec<S> {
&mut self.data
}
/// Borrows column `j` as a contiguous slice of `nrows` elements.
///
/// # Panics
/// Panics if the column's range lies outside the backing buffer.
pub fn col(&self, j: usize) -> &[S] {
    let first = j * self.nrows;
    let past_end = first + self.nrows;
    &self.data[first..past_end]
}
/// Copies column `j` into a freshly built `DVec`.
pub fn col_vec(&self, j: usize) -> DVec<S> {
    let column = self.col(j);
    DVec::from_slice(column)
}
/// Copies row `i` into a freshly built `DVec` of length `ncols`.
pub fn row_vec(&self, i: usize) -> DVec<S> {
    let entry_in_row = |c: usize| self.get(i, c);
    DVec::from_fn(self.ncols, entry_in_row)
}
/// Returns the main diagonal as a `DVec` of length `min(nrows, ncols)`.
pub fn diagonal(&self) -> DVec<S> {
    let len = usize::min(self.nrows, self.ncols);
    let diag_entry = |k: usize| self.get(k, k);
    DVec::from_fn(len, diag_entry)
}
/// Returns a new `ncols` x `nrows` matrix whose entry `(i, j)` is this matrix's `(j, i)`.
pub fn transpose(&self) -> Self {
    let (out_rows, out_cols) = (self.ncols, self.nrows);
    Self::from_fn(out_rows, out_cols, |r, c| self.get(c, r))
}
/// Computes the matrix-vector product `self * x`.
///
/// With the `accelerate` feature on macOS, `f32` and `f64` matrices whose
/// dimensions fit in an `i32` are handed to `super::blas::sgemv` / `dgemv`.
/// Every other case (other scalar types, other platforms, oversized
/// dimensions) uses the generic column-by-column loop at the end.
///
/// # Panics
/// Panics if `x.len() != self.ncols()`.
pub fn mul_vec(&self, x: &DVec<S>) -> DVec<S> {
    assert_eq!(self.ncols, x.len(), "DMat mul_vec: dimension mismatch");
    #[cfg(all(feature = "accelerate", target_os = "macos"))]
    {
        // The BLAS wrappers take `i32` dimensions. A bare `as i32` would
        // silently truncate larger sizes, so only take the accelerated path
        // when both dimensions are representable; otherwise fall through.
        let max_dim = i32::MAX as usize;
        if self.nrows <= max_dim && self.ncols <= max_dim {
            let m = self.nrows as i32;
            let n = self.ncols as i32;
            if TypeId::of::<S>() == TypeId::of::<f32>() {
                let mut y = DVec::zeros(self.nrows);
                // SAFETY: the `TypeId` check proves `S` is `f32`, so the
                // pointer casts do not change the pointee type. `x` has
                // `ncols` elements (asserted above) and `y` has `nrows`
                // elements (just allocated). Assumes `sgemv` reads `m * n`
                // values from the matrix, `n` from `x`, and writes `m` to
                // `y` -- TODO confirm against `super::blas`.
                unsafe {
                    super::blas::sgemv(
                        self.as_slice().as_ptr() as *const f32,
                        x.as_slice().as_ptr() as *const f32,
                        y.as_mut_slice().as_mut_ptr() as *mut f32,
                        m, n,
                    );
                }
                return y;
            }
            if TypeId::of::<S>() == TypeId::of::<f64>() {
                let mut y = DVec::zeros(self.nrows);
                // SAFETY: same reasoning as the `f32` branch, with `S`
                // proven to be `f64` by the `TypeId` check.
                unsafe {
                    super::blas::dgemv(
                        self.as_slice().as_ptr() as *const f64,
                        x.as_slice().as_ptr() as *const f64,
                        y.as_mut_slice().as_mut_ptr() as *mut f64,
                        m, n,
                    );
                }
                return y;
            }
        }
    }
    // Generic path: accumulate one scaled column at a time so the inner loop
    // walks the column-major storage contiguously.
    let mut y = DVec::zeros(self.nrows);
    for j in 0..self.ncols {
        let xj = x[j];
        for i in 0..self.nrows {
            y[i] += self.get(i, j) * xj;
        }
    }
    y
}
/// Computes the matrix product `self * rhs`.
///
/// With the `accelerate` feature on macOS, `f32` and `f64` products whose
/// three dimensions fit in an `i32` are handed to `super::blas::sgemm` /
/// `dgemm`. Every other case uses the generic triple loop at the end.
///
/// # Panics
/// Panics if `self.ncols() != rhs.nrows()`.
pub fn mul_mat(&self, rhs: &DMat<S>) -> DMat<S> {
    assert_eq!(self.ncols, rhs.nrows, "DMat mul_mat: dimension mismatch");
    let m = self.nrows;
    let n = rhs.ncols;
    let p = self.ncols;
    #[cfg(all(feature = "accelerate", target_os = "macos"))]
    {
        // The BLAS wrappers take `i32` dimensions. A bare `as i32` would
        // silently truncate larger sizes, so only take the accelerated path
        // when all three dimensions are representable; otherwise fall through.
        let max_dim = i32::MAX as usize;
        if m <= max_dim && n <= max_dim && p <= max_dim {
            let mi = m as i32;
            let ni = n as i32;
            let ki = p as i32;
            if TypeId::of::<S>() == TypeId::of::<f32>() {
                let mut c = DMat::zeros(m, n);
                // SAFETY: the `TypeId` check proves `S` is `f32`, so the
                // pointer casts do not change the pointee type. `c` was just
                // allocated with `m * n` elements. Assumes `sgemm` reads
                // `m * k` values from `self`, `k * n` from `rhs`, and writes
                // `m * n` to `c` -- TODO confirm against `super::blas`.
                unsafe {
                    super::blas::sgemm(
                        self.as_slice().as_ptr() as *const f32,
                        rhs.as_slice().as_ptr() as *const f32,
                        c.as_mut_slice().as_mut_ptr() as *mut f32,
                        mi, ni, ki,
                    );
                }
                return c;
            }
            if TypeId::of::<S>() == TypeId::of::<f64>() {
                let mut c = DMat::zeros(m, n);
                // SAFETY: same reasoning as the `f32` branch, with `S`
                // proven to be `f64` by the `TypeId` check.
                unsafe {
                    super::blas::dgemm(
                        self.as_slice().as_ptr() as *const f64,
                        rhs.as_slice().as_ptr() as *const f64,
                        c.as_mut_slice().as_mut_ptr() as *mut f64,
                        mi, ni, ki,
                    );
                }
                return c;
            }
        }
    }
    // Generic path. The loops are ordered j -> k -> i so the innermost loop
    // walks one column of `self` and one column of the result contiguously.
    let mut c = DMat::zeros(m, n);
    let a = self.as_slice();
    let b = rhs.as_slice();
    let c_data = c.as_mut_slice();
    for j in 0..n {
        let c_col = j * m;
        for k in 0..p {
            let b_kj = b[j * rhs.nrows + k];
            let a_col = k * m;
            for i in 0..m {
                c_data[c_col + i] += a[a_col + i] * b_kj;
            }
        }
    }
    c
}
/// Returns the sum of the squares of all stored elements.
pub fn norm_sq(&self) -> S {
    let mut total = S::ZERO;
    self.data.iter().for_each(|&value| total += value * value);
    total
}
/// Returns the square root of [`norm_sq`](Self::norm_sq).
pub fn norm(&self) -> S {
    let sum_of_squares = self.norm_sq();
    sum_of_squares.sqrt()
}
/// Returns the sum of the first `min(nrows, ncols)` main-diagonal entries.
pub fn trace(&self) -> S {
    let extent = usize::min(self.nrows, self.ncols);
    let mut total = S::ZERO;
    (0..extent).for_each(|k| total += self.get(k, k));
    total
}
/// Returns a new matrix with every entry multiplied by `s`.
pub fn scale(&self, s: S) -> Self {
    let scaled_entry = |r: usize, c: usize| self.get(r, c) * s;
    Self::from_fn(self.nrows, self.ncols, scaled_entry)
}
/// Returns `true` when the row and column counts are equal.
#[inline]
pub fn is_square(&self) -> bool {
self.nrows == self.ncols
}
/// Swaps rows `a` and `b` in place. Does nothing when `a == b`.
///
/// # Panics
/// Panics if `a != b` and either index is `>= self.nrows()`.
pub fn swap_rows(&mut self, a: usize, b: usize) {
    if a == b {
        return;
    }
    // Validate before touching any data: with column-major storage an
    // out-of-range row would otherwise swap elements belonging to other
    // columns and only panic once the last column is reached, leaving the
    // matrix partially modified.
    assert!(
        a < self.nrows && b < self.nrows,
        "DMat::swap_rows: row index out of bounds"
    );
    let stride = self.nrows;
    for j in 0..self.ncols {
        let base = j * stride;
        self.data.swap(base + a, base + b);
    }
}
/// Copies the `nrows` x `ncols` window whose top-left corner is
/// (`row_start`, `col_start`) into a new matrix.
///
/// # Panics
/// Panics if the window extends past the last row or column of `self`.
pub fn submatrix(
    &self,
    row_start: usize,
    col_start: usize,
    nrows: usize,
    ncols: usize,
) -> Self {
    // `checked_add` keeps the range test itself from overflowing.
    let rows_fit = row_start
        .checked_add(nrows)
        .map_or(false, |end| end <= self.nrows);
    let cols_fit = col_start
        .checked_add(ncols)
        .map_or(false, |end| end <= self.ncols);
    // Without this check a window that is too tall would silently wrap into
    // the neighbouring columns of the column-major storage.
    assert!(
        rows_fit && cols_fit,
        "DMat::submatrix: window exceeds matrix bounds"
    );
    Self::from_fn(nrows, ncols, |i, j| self.get(row_start + i, col_start + j))
}
/// Copies column `j` into a `DVec`; simply forwards to `col_vec`.
pub fn column(&self, j: usize) -> DVec<S> {
self.col_vec(j)
}
/// Builds a `super::SymmetricEigen` from this matrix via `SymmetricEigen::new`.
///
/// NOTE(review): presumably expects a symmetric matrix; nothing is checked
/// here -- confirm against `SymmetricEigen::new`.
pub fn symmetric_eigen(&self) -> super::SymmetricEigen<S> {
super::SymmetricEigen::new(self)
}
/// Builds a `super::Svd` from this matrix via `Svd::new`.
///
/// Both flag parameters are currently unused: the call forwards to
/// `Svd::new(self)` regardless of their values.
pub fn svd(&self, _compute_u: bool, _compute_v: bool) -> super::Svd<S> {
super::Svd::new(self)
}
pub fn try_inverse(&self) -> Option<Self> {
assert!(self.is_square(), "DMat::try_inverse: not square");
let n = self.nrows;
let lu = super::Lu::new(self)?;
let mut inv = DMat::zeros(n, n);
for j in 0..n {
let mut e = DVec::zeros(n);
e[j] = S::ONE;
let col = lu.solve(&e);
for i in 0..n {
inv.set(i, j, col[i]);
}
}
Some(inv)
}
/// Consumes the matrix and wraps it in a `DMatLu`.
///
/// No factorization is performed here; this only moves `self` into the wrapper.
pub fn lu(self) -> DMatLu<S> {
DMatLu(self)
}
}
/// Wrapper produced by `DMat::lu` that offers an LU-based `solve`.
///
/// It stores the original matrix, not its factors.
#[derive(Clone, Debug)]
pub struct DMatLu<S>(DMat<S>);
impl<S: Scalar> DMatLu<S> {
    /// Solves the linear system for right-hand side `b`.
    ///
    /// The wrapped matrix is factorized with `super::Lu::new` on every call.
    /// Returns `None` when that factorization returns `None`.
    pub fn solve(&self, b: &DVec<S>) -> Option<DVec<S>> {
        let factors = super::Lu::new(&self.0)?;
        Some(factors.solve(b))
    }
}
impl<S: Scalar> Index<(usize, usize)> for DMat<S> {
    type Output = S;

    /// Borrows the element at `(row, col)`.
    ///
    /// # Panics
    /// Panics if `row >= nrows` or `col >= ncols`.
    #[inline]
    fn index(&self, (row, col): (usize, usize)) -> &S {
        // Storage is column-major, so without this check an out-of-range
        // `row` would silently resolve to an element of a later column.
        assert!(
            row < self.nrows && col < self.ncols,
            "DMat: index ({}, {}) out of bounds for {}x{} matrix",
            row,
            col,
            self.nrows,
            self.ncols
        );
        &self.data[col * self.nrows + row]
    }
}
impl<S: Scalar> IndexMut<(usize, usize)> for DMat<S> {
    /// Mutably borrows the element at `(row, col)`.
    ///
    /// # Panics
    /// Panics if `row >= nrows` or `col >= ncols`.
    #[inline]
    fn index_mut(&mut self, (row, col): (usize, usize)) -> &mut S {
        // Storage is column-major, so without this check an out-of-range
        // `row` would silently resolve to an element of a later column.
        assert!(
            row < self.nrows && col < self.ncols,
            "DMat: index ({}, {}) out of bounds for {}x{} matrix",
            row,
            col,
            self.nrows,
            self.ncols
        );
        &mut self.data[col * self.nrows + row]
    }
}
impl<S: Scalar> Add for &DMat<S> {
    type Output = DMat<S>;

    /// Element-wise sum of two matrices with identical dimensions.
    ///
    /// # Panics
    /// Panics if the row counts or the column counts differ.
    fn add(self, rhs: &DMat<S>) -> DMat<S> {
        assert_eq!(self.nrows, rhs.nrows);
        assert_eq!(self.ncols, rhs.ncols);
        let (rows, cols) = (self.nrows, self.ncols);
        let sum_at = |r: usize, c: usize| self.get(r, c) + rhs.get(r, c);
        DMat::from_fn(rows, cols, sum_at)
    }
}
impl<S: Scalar> Sub for &DMat<S> {
    type Output = DMat<S>;

    /// Element-wise difference of two matrices with identical dimensions.
    ///
    /// # Panics
    /// Panics if the row counts or the column counts differ.
    fn sub(self, rhs: &DMat<S>) -> DMat<S> {
        assert_eq!(self.nrows, rhs.nrows);
        assert_eq!(self.ncols, rhs.ncols);
        let (rows, cols) = (self.nrows, self.ncols);
        let diff_at = |r: usize, c: usize| self.get(r, c) - rhs.get(r, c);
        DMat::from_fn(rows, cols, diff_at)
    }
}
impl<S: Scalar> Neg for &DMat<S> {
    type Output = DMat<S>;

    /// Returns a new matrix with every entry negated.
    fn neg(self) -> DMat<S> {
        let negated_at = |r: usize, c: usize| -self.get(r, c);
        DMat::from_fn(self.nrows, self.ncols, negated_at)
    }
}
impl<S: Scalar> Mul<&DVec<S>> for &DMat<S> {
type Output = DVec<S>;
fn mul(self, rhs: &DVec<S>) -> DVec<S> {
self.mul_vec(rhs)
}
}
impl<S: Scalar> Mul for &DMat<S> {
type Output = DMat<S>;
fn mul(self, rhs: &DMat<S>) -> DMat<S> {
self.mul_mat(rhs)
}
}
// Unit tests for `DMat`. The nested `blas_tests` module is compiled only with
// the `accelerate` feature on macOS and exercises `mul_vec` / `mul_mat` for
// `f32` and `f64`.
#[cfg(test)]
mod tests {
use super::*;
// Multiplying by the 3x3 identity must return the input vector unchanged.
#[test]
fn identity_mul() {
let i = DMat::<f64>::identity(3);
let x = DVec::from_slice(&[1.0, 2.0, 3.0]);
let y = i.mul_vec(&x);
assert_eq!(y[0], 1.0);
assert_eq!(y[1], 2.0);
assert_eq!(y[2], 3.0);
}
// [[1,2,3],[4,5,6]] * [[1,2],[3,4],[5,6]] = [[22,28],[49,64]].
#[test]
fn mat_mul() {
let a = DMat::from_fn(2, 3, |i, j| (i * 3 + j + 1) as f64);
let b = DMat::from_fn(3, 2, |i, j| (i * 2 + j + 1) as f64);
let c = a.mul_mat(&b);
assert_eq!(c.nrows(), 2);
assert_eq!(c.ncols(), 2);
assert_eq!(c.get(0, 0), 22.0);
assert_eq!(c.get(0, 1), 28.0);
assert_eq!(c.get(1, 0), 49.0);
assert_eq!(c.get(1, 1), 64.0);
}
// Transposing a 2x3 matrix yields a 3x2 matrix with rows and columns exchanged.
#[test]
fn transpose() {
let m = DMat::from_fn(2, 3, |i, j| (i * 3 + j) as f64);
let mt = m.transpose();
assert_eq!(mt.nrows(), 3);
assert_eq!(mt.ncols(), 2);
assert_eq!(mt.get(0, 0), 0.0);
assert_eq!(mt.get(0, 1), 3.0);
assert_eq!(mt.get(1, 0), 1.0);
}
// Trace of diag(1, 2, 3) is 6.
#[test]
fn trace() {
let m = DMat::from_fn(3, 3, |i, j| if i == j { (i + 1) as f64 } else { 0.0 });
assert_eq!(m.trace(), 6.0);
}
// Diagonal of the 3x3 matrix with entries i * 3 + j is (0, 4, 8).
#[test]
fn diagonal() {
let m = DMat::from_fn(3, 3, |i, j| (i * 3 + j) as f64);
let d = m.diagonal();
assert_eq!(d[0], 0.0);
assert_eq!(d[1], 4.0);
assert_eq!(d[2], 8.0);
}
// Tests that only build when the Accelerate-backed code in `mul_vec` /
// `mul_mat` is compiled in.
#[cfg(all(feature = "accelerate", target_os = "macos"))]
mod blas_tests {
use super::*;
// Same 2x3 * 3x2 product as `mat_mul`, in single precision.
#[test]
fn f32_matmul_small() {
let a = DMat::from_fn(2, 3, |i, j| (i * 3 + j + 1) as f32);
let b = DMat::from_fn(3, 2, |i, j| (i * 2 + j + 1) as f32);
let c = a.mul_mat(&b);
assert_eq!(c.get(0, 0), 22.0f32);
assert_eq!(c.get(0, 1), 28.0f32);
assert_eq!(c.get(1, 0), 49.0f32);
assert_eq!(c.get(1, 1), 64.0f32);
}
// Same 2x3 * 3x2 product as `mat_mul`, in double precision.
#[test]
fn f64_matmul_small() {
let a = DMat::from_fn(2, 3, |i, j| (i * 3 + j + 1) as f64);
let b = DMat::from_fn(3, 2, |i, j| (i * 2 + j + 1) as f64);
let c = a.mul_mat(&b);
assert_eq!(c.get(0, 0), 22.0);
assert_eq!(c.get(0, 1), 28.0);
assert_eq!(c.get(1, 0), 49.0);
assert_eq!(c.get(1, 1), 64.0);
}
// 64x64 product: entry (0, 0) is compared with a hand-computed dot product
// of row 0 of `a` and column 0 of `b`, within a 1e-3 tolerance.
#[test]
fn f32_matmul_large() {
let n = 64;
let a = DMat::from_fn(n, n, |i, j| ((i + j) % 7) as f32);
let b = DMat::from_fn(n, n, |i, j| ((i * 3 + j) % 11) as f32);
let c = a.mul_mat(&b);
let mut expected = 0.0f32;
for k in 0..n {
expected += a.get(0, k) * b.get(k, 0);
}
assert!((c.get(0, 0) - expected).abs() < 1e-3);
}
// [[1,2,3],[4,5,6]] * (1, 2, 3) = (14, 32), single precision.
#[test]
fn f32_matvec() {
let a = DMat::from_fn(2, 3, |i, j| (i * 3 + j + 1) as f32);
let x = DVec::from_slice(&[1.0f32, 2.0, 3.0]);
let y = a.mul_vec(&x);
assert_eq!(y[0], 14.0f32);
assert_eq!(y[1], 32.0f32);
}
// [[1,2,3],[4,5,6]] * (1, 2, 3) = (14, 32), double precision.
#[test]
fn f64_matvec() {
let a = DMat::from_fn(2, 3, |i, j| (i * 3 + j + 1) as f64);
let x = DVec::from_slice(&[1.0, 2.0, 3.0]);
let y = a.mul_vec(&x);
assert_eq!(y[0], 14.0);
assert_eq!(y[1], 32.0);
}
// Non-square operands (5x3 times 3x4): checks the result shape and that
// entry (0, 0) equals 0*1 + 1*1 + 2*1 = 3.
#[test]
fn f32_rectangular() {
let a = DMat::from_fn(5, 3, |i, j| (i + j) as f32);
let b = DMat::from_fn(3, 4, |i, j| (i * j + 1) as f32);
let c = a.mul_mat(&b);
assert_eq!(c.nrows(), 5);
assert_eq!(c.ncols(), 4);
assert_eq!(c.get(0, 0), 3.0f32);
}
// Multiplying by the 4x4 identity must return the input vector unchanged.
#[test]
fn f64_identity_matvec() {
let eye = DMat::<f64>::identity(4);
let x = DVec::from_slice(&[10.0, 20.0, 30.0, 40.0]);
let y = eye.mul_vec(&x);
for i in 0..4 {
assert_eq!(y[i], x[i]);
}
}
}
}