pub mod aliases;
mod block;
mod linalg;
mod mixed_ops;
mod norm;
mod ops;
mod slice;
mod square;
mod util;
mod vector;
pub use aliases::*;
pub use linalg::{DynCholesky, DynLu, DynQr, DynQrPivot, DynSchur, DynSvd, DynSymmetricEigen};
pub use vector::DynVector;
use alloc::vec;
use alloc::vec::Vec;
use core::ops::{Index, IndexMut};
use crate::traits::{MatrixMut, MatrixRef, Scalar};
use crate::Matrix;
/// Error describing a matrix whose shape differs from the one that was required.
///
/// Produced, for example, by the `TryFrom<&DynMatrix<T>>` conversion into a
/// fixed-size [`Matrix`]. Both fields are `(rows, cols)` pairs.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct DimensionMismatch {
/// The `(rows, cols)` shape that was required.
pub expected: (usize, usize),
/// The `(rows, cols)` shape that was actually supplied.
pub got: (usize, usize),
}
impl core::fmt::Display for DimensionMismatch {
    /// Renders the error as `dimension mismatch: expected RxC, got RxC`.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let (want_rows, want_cols) = self.expected;
        let (have_rows, have_cols) = self.got;
        f.write_fmt(format_args!(
            "dimension mismatch: expected {}x{}, got {}x{}",
            want_rows, want_cols, have_rows, have_cols
        ))
    }
}
/// A heap-allocated matrix whose dimensions are chosen at runtime.
///
/// Elements live in a single `Vec` in column-major order: entry `(row, col)`
/// is stored at `data[col * nrows + row]`, so each column is a contiguous run
/// of `nrows` elements. The constructors in this file build `data` with
/// `nrows * ncols` elements (the `From<Matrix>` conversions assume
/// `Matrix::as_slice` yields `M * N` elements — NOTE(review): confirm).
///
/// The derived `PartialEq` compares both the shape and the stored elements.
#[derive(Debug, Clone, PartialEq)]
pub struct DynMatrix<T> {
/// Column-major element storage.
data: Vec<T>,
/// Number of rows; also the stride between consecutive columns in `data`.
nrows: usize,
/// Number of columns.
ncols: usize,
}
impl<T: Scalar> DynMatrix<T> {
    /// Creates an `nrows x ncols` matrix with every entry set to `T::zero()`.
    ///
    /// # Panics
    ///
    /// Panics if `nrows * ncols` overflows `usize`.
    pub fn zeros(nrows: usize, ncols: usize) -> Self {
        Self::fill(nrows, ncols, T::zero())
    }

    /// Creates an `nrows x ncols` matrix with every entry set to `value`.
    ///
    /// # Panics
    ///
    /// Panics if `nrows * ncols` overflows `usize`.
    pub fn fill(nrows: usize, ncols: usize, value: T) -> Self {
        // Checked: a wrapped product would allocate a buffer smaller than the
        // advertised shape, silently breaking `data.len() == nrows * ncols`.
        let len = nrows
            .checked_mul(ncols)
            .expect("matrix dimensions overflow usize");
        Self {
            data: vec![value; len],
            nrows,
            ncols,
        }
    }

    /// Creates the `n x n` identity matrix.
    ///
    /// # Panics
    ///
    /// Panics if `n * n` overflows `usize`.
    pub fn eye(n: usize) -> Self {
        let mut m = Self::zeros(n, n);
        for i in 0..n {
            // Diagonal entry (i, i) in column-major storage.
            m.data[i * n + i] = T::one();
        }
        m
    }

    /// Creates a matrix by copying `slice`, which must already be in
    /// column-major order.
    ///
    /// # Panics
    ///
    /// Panics if `nrows * ncols` overflows `usize`, or if
    /// `slice.len() != nrows * ncols`.
    pub fn from_slice(nrows: usize, ncols: usize, slice: &[T]) -> Self {
        let len = nrows
            .checked_mul(ncols)
            .expect("matrix dimensions overflow usize");
        assert_eq!(
            slice.len(),
            len,
            "slice length {} does not match {}x{} matrix",
            slice.len(),
            nrows,
            ncols,
        );
        Self {
            data: slice.to_vec(),
            nrows,
            ncols,
        }
    }

    /// Creates a matrix from elements given in row-major order
    /// (`row_major[i * ncols + j]` becomes entry `(i, j)`).
    ///
    /// # Panics
    ///
    /// Panics if `nrows * ncols` overflows `usize`, or if
    /// `row_major.len() != nrows * ncols`.
    pub fn from_rows(nrows: usize, ncols: usize, row_major: &[T]) -> Self {
        let len = nrows
            .checked_mul(ncols)
            .expect("matrix dimensions overflow usize");
        assert_eq!(
            row_major.len(),
            len,
            "slice length {} does not match {}x{} matrix",
            row_major.len(),
            nrows,
            ncols,
        );
        // Transpose into column-major storage: column `j` is every `ncols`-th
        // element of the row-major input, starting at offset `j`. Building
        // the columns directly avoids zero-filling a buffer first.
        let mut data = Vec::with_capacity(len);
        for j in 0..ncols {
            data.extend(row_major.iter().skip(j).step_by(ncols).copied());
        }
        Self { data, nrows, ncols }
    }

    /// Creates a matrix that takes ownership of `data`, which must already be
    /// in column-major order.
    ///
    /// # Panics
    ///
    /// Panics if `nrows * ncols` overflows `usize`, or if
    /// `data.len() != nrows * ncols`.
    pub fn from_vec(nrows: usize, ncols: usize, data: Vec<T>) -> Self {
        let len = nrows
            .checked_mul(ncols)
            .expect("matrix dimensions overflow usize");
        assert_eq!(
            data.len(),
            len,
            "vec length {} does not match {}x{} matrix",
            data.len(),
            nrows,
            ncols,
        );
        Self { data, nrows, ncols }
    }
}
impl<T> DynMatrix<T> {
    /// Number of rows.
    #[inline]
    pub fn nrows(&self) -> usize {
        self.nrows
    }

    /// Number of columns.
    #[inline]
    pub fn ncols(&self) -> usize {
        self.ncols
    }

    /// Returns `true` when the matrix has as many rows as columns
    /// (this includes the empty `0 x 0` matrix).
    #[inline]
    pub fn is_square(&self) -> bool {
        self.nrows == self.ncols
    }

    /// Builds an `nrows x ncols` matrix whose entry `(i, j)` is `f(i, j)`.
    ///
    /// `f` is called exactly once per entry, column by column (every row of
    /// column 0, then column 1, ...), i.e. in storage order. Because `f` is
    /// `FnMut`, stateful generators (counters, RNGs) are accepted; any `Fn`
    /// closure still works.
    ///
    /// # Panics
    ///
    /// Panics if `nrows * ncols` overflows `usize`.
    pub fn from_fn(nrows: usize, ncols: usize, mut f: impl FnMut(usize, usize) -> T) -> Self {
        // Checked so an overflowing shape fails fast instead of under-reserving
        // and then pushing until memory is exhausted.
        let len = nrows
            .checked_mul(ncols)
            .expect("matrix dimensions overflow usize");
        let mut data = Vec::with_capacity(len);
        for j in 0..ncols {
            for i in 0..nrows {
                data.push(f(i, j));
            }
        }
        Self { data, nrows, ncols }
    }
}
impl<T> MatrixRef<T> for DynMatrix<T> {
    #[inline]
    fn nrows(&self) -> usize {
        self.nrows
    }

    #[inline]
    fn ncols(&self) -> usize {
        self.ncols
    }

    /// Returns a reference to entry `(row, col)`.
    ///
    /// # Panics
    ///
    /// Panics if `row >= nrows` or `col >= ncols`.
    #[inline]
    fn get(&self, row: usize, col: usize) -> &T {
        // Storage is a flat column-major Vec, so without this check an
        // out-of-range `row` silently reads an element of the next column.
        assert!(
            row < self.nrows && col < self.ncols,
            "index ({}, {}) out of bounds for {}x{} matrix",
            row,
            col,
            self.nrows,
            self.ncols
        );
        &self.data[col * self.nrows + row]
    }

    /// Returns column `col` from row `row_start` through the last row as a
    /// contiguous slice.
    ///
    /// # Panics
    ///
    /// Panics via slice indexing if `row_start > nrows`, or if `col >= ncols`
    /// on a matrix with at least one row.
    #[inline]
    fn col_as_slice(&self, col: usize, row_start: usize) -> &[T] {
        let start = col * self.nrows + row_start;
        let end = col * self.nrows + self.nrows;
        &self.data[start..end]
    }
}
impl<T> MatrixMut<T> for DynMatrix<T> {
    /// Returns a mutable reference to entry `(row, col)`.
    ///
    /// # Panics
    ///
    /// Panics if `row >= nrows` or `col >= ncols`.
    #[inline]
    fn get_mut(&mut self, row: usize, col: usize) -> &mut T {
        // Storage is a flat column-major Vec, so without this check an
        // out-of-range `row` silently writes into the next column.
        assert!(
            row < self.nrows && col < self.ncols,
            "index ({}, {}) out of bounds for {}x{} matrix",
            row,
            col,
            self.nrows,
            self.ncols
        );
        &mut self.data[col * self.nrows + row]
    }

    /// Returns column `col` from row `row_start` through the last row as a
    /// contiguous mutable slice.
    ///
    /// # Panics
    ///
    /// Panics via slice indexing if `row_start > nrows`, or if `col >= ncols`
    /// on a matrix with at least one row.
    #[inline]
    fn col_as_mut_slice(&mut self, col: usize, row_start: usize) -> &mut [T] {
        let start = col * self.nrows + row_start;
        let end = col * self.nrows + self.nrows;
        &mut self.data[start..end]
    }
}
impl<T> Index<(usize, usize)> for DynMatrix<T> {
    type Output = T;

    /// Returns entry `(row, col)`.
    ///
    /// # Panics
    ///
    /// Panics if `row >= nrows` or `col >= ncols`.
    #[inline]
    fn index(&self, (row, col): (usize, usize)) -> &T {
        // Storage is a flat column-major Vec, so without this check an
        // out-of-range `row` silently reads an element of the next column.
        assert!(
            row < self.nrows && col < self.ncols,
            "index ({}, {}) out of bounds for {}x{} matrix",
            row,
            col,
            self.nrows,
            self.ncols
        );
        &self.data[col * self.nrows + row]
    }
}
impl<T> IndexMut<(usize, usize)> for DynMatrix<T> {
    /// Returns a mutable reference to entry `(row, col)`.
    ///
    /// # Panics
    ///
    /// Panics if `row >= nrows` or `col >= ncols`.
    #[inline]
    fn index_mut(&mut self, (row, col): (usize, usize)) -> &mut T {
        // Storage is a flat column-major Vec, so without this check an
        // out-of-range `row` silently writes into the next column.
        assert!(
            row < self.nrows && col < self.ncols,
            "index ({}, {}) out of bounds for {}x{} matrix",
            row,
            col,
            self.nrows,
            self.ncols
        );
        &mut self.data[col * self.nrows + row]
    }
}
impl<T: Scalar, const M: usize, const N: usize> From<Matrix<T, M, N>> for DynMatrix<T> {
    /// Converts an owned fixed-size [`Matrix`] into an `M x N` [`DynMatrix`].
    ///
    /// The elements are copied exactly as the by-reference conversion does,
    /// so this simply borrows `m` and delegates.
    fn from(m: Matrix<T, M, N>) -> Self {
        Self::from(&m)
    }
}
impl<T: Scalar, const M: usize, const N: usize> From<&Matrix<T, M, N>> for DynMatrix<T> {
fn from(m: &Matrix<T, M, N>) -> Self {
Self {
data: m.as_slice().to_vec(),
nrows: M,
ncols: N,
}
}
}
impl<T: Scalar, const M: usize, const N: usize> TryFrom<&DynMatrix<T>> for Matrix<T, M, N> {
    type Error = DimensionMismatch;

    /// Copies `d` into a fixed-size `M x N` [`Matrix`], passing its
    /// column-major storage to `Matrix::from_slice`.
    ///
    /// # Errors
    ///
    /// Returns [`DimensionMismatch`] when `d` is not exactly `M x N`.
    fn try_from(d: &DynMatrix<T>) -> Result<Self, Self::Error> {
        let shape = (d.nrows, d.ncols);
        if shape == (M, N) {
            Ok(Matrix::from_slice(d.data.as_slice()))
        } else {
            Err(DimensionMismatch {
                expected: (M, N),
                got: shape,
            })
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn zeros() {
        let m = DynMatrix::<f64>::zeros(3, 4);
        assert_eq!(m.nrows(), 3);
        assert_eq!(m.ncols(), 4);
        for i in 0..3 {
            for j in 0..4 {
                assert_eq!(m[(i, j)], 0.0);
            }
        }
    }

    #[test]
    fn fill() {
        let m = DynMatrix::fill(2, 3, 7.0_f64);
        for i in 0..2 {
            for j in 0..3 {
                assert_eq!(m[(i, j)], 7.0);
            }
        }
    }

    #[test]
    fn eye() {
        let m = DynMatrix::<f64>::eye(3);
        for i in 0..3 {
            for j in 0..3 {
                let expected = if i == j { 1.0 } else { 0.0 };
                assert_eq!(m[(i, j)], expected);
            }
        }
    }

    #[test]
    fn from_rows() {
        let m = DynMatrix::from_rows(2, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        assert_eq!(m[(0, 0)], 1.0);
        assert_eq!(m[(0, 2)], 3.0);
        assert_eq!(m[(1, 0)], 4.0);
        assert_eq!(m[(1, 2)], 6.0);
    }

    #[test]
    #[should_panic(expected = "slice length")]
    fn from_rows_wrong_length() {
        let _ = DynMatrix::from_rows(2, 2, &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn from_rows_matches_from_fn() {
        let a = DynMatrix::from_rows(2, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = DynMatrix::from_fn(2, 3, |i, j| (i * 3 + j + 1) as f64);
        assert_eq!(a, b);
    }

    #[test]
    fn from_slice_is_column_major() {
        let m = DynMatrix::from_slice(2, 2, &[1.0, 3.0, 2.0, 4.0]);
        assert_eq!(m[(0, 0)], 1.0);
        assert_eq!(m[(1, 0)], 3.0);
        assert_eq!(m[(0, 1)], 2.0);
        assert_eq!(m[(1, 1)], 4.0);
    }

    #[test]
    #[should_panic(expected = "slice length")]
    fn from_slice_wrong_length() {
        let _ = DynMatrix::from_slice(2, 3, &[1.0, 2.0]);
    }

    #[test]
    fn from_vec() {
        let m = DynMatrix::from_vec(2, 2, vec![1.0, 3.0, 2.0, 4.0]);
        assert_eq!(m[(0, 0)], 1.0);
        assert_eq!(m[(1, 1)], 4.0);
    }

    #[test]
    #[should_panic(expected = "vec length")]
    fn from_vec_wrong_length() {
        let _ = DynMatrix::from_vec(3, 1, vec![1.0, 2.0]);
    }

    #[test]
    fn from_fn() {
        let m = DynMatrix::from_fn(3, 3, |i, j| (i * 3 + j) as f64);
        assert_eq!(m[(0, 0)], 0.0);
        assert_eq!(m[(1, 1)], 4.0);
        assert_eq!(m[(2, 2)], 8.0);
    }

    #[test]
    fn from_fn_off_diagonal() {
        // Off-diagonal entries distinguish row-major from column-major mixups.
        let m = DynMatrix::from_fn(2, 3, |i, j| (i * 10 + j) as f64);
        assert_eq!(m[(1, 0)], 10.0);
        assert_eq!(m[(0, 2)], 2.0);
        assert_eq!(m[(1, 2)], 12.0);
    }

    #[test]
    fn empty_dimensions() {
        let m = DynMatrix::<f64>::zeros(0, 4);
        assert_eq!(m.nrows(), 0);
        assert_eq!(m.ncols(), 4);
        assert!(!m.is_square());
        assert!(DynMatrix::<f64>::eye(0).is_square());
    }

    #[test]
    fn index_mut() {
        let mut m = DynMatrix::<f64>::zeros(2, 2);
        m[(0, 1)] = 5.0;
        assert_eq!(m[(0, 1)], 5.0);
    }

    #[test]
    fn matrix_ref_trait() {
        let m = DynMatrix::from_rows(2, 2, &[1.0, 2.0, 3.0, 4.0]);
        fn trace<T: Scalar>(m: &impl MatrixRef<T>) -> T {
            let mut sum = T::zero();
            let n = m.nrows().min(m.ncols());
            for i in 0..n {
                sum = sum + *m.get(i, i);
            }
            sum
        }
        assert_eq!(trace(&m), 5.0);
    }

    #[test]
    fn matrix_mut_trait() {
        let mut m = DynMatrix::<f64>::zeros(2, 2);
        fn set_diag<T: Scalar>(m: &mut impl MatrixMut<T>, val: T) {
            let n = m.nrows().min(m.ncols());
            for i in 0..n {
                *m.get_mut(i, i) = val;
            }
        }
        set_diag(&mut m, 7.0);
        assert_eq!(m[(0, 0)], 7.0);
        assert_eq!(m[(1, 1)], 7.0);
        assert_eq!(m[(0, 1)], 0.0);
    }

    #[test]
    fn col_slices() {
        let mut m = DynMatrix::from_rows(3, 2, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        // Fully-qualified calls so an inherent method of the same name
        // elsewhere in the crate cannot shadow the trait methods under test.
        assert_eq!(MatrixRef::col_as_slice(&m, 1, 0), &[2.0, 4.0, 6.0]);
        assert_eq!(MatrixRef::col_as_slice(&m, 0, 2), &[5.0]);
        MatrixMut::col_as_mut_slice(&mut m, 0, 1)[0] = 9.0;
        assert_eq!(m[(1, 0)], 9.0);
    }

    #[test]
    fn from_matrix() {
        let m = Matrix::new([[1.0, 2.0], [3.0, 4.0]]);
        let d: DynMatrix<f64> = m.into();
        assert_eq!(d.nrows(), 2);
        assert_eq!(d.ncols(), 2);
        assert_eq!(d[(0, 0)], 1.0);
        assert_eq!(d[(1, 1)], 4.0);
    }

    #[test]
    fn from_matrix_ref() {
        let m = Matrix::new([[1.0, 2.0], [3.0, 4.0]]);
        let d: DynMatrix<f64> = (&m).into();
        assert_eq!(d[(0, 0)], 1.0);
    }

    #[test]
    fn try_into_matrix() {
        let d = DynMatrix::from_rows(2, 2, &[1.0, 2.0, 3.0, 4.0]);
        let m: Matrix<f64, 2, 2> = (&d).try_into().unwrap();
        assert_eq!(m[(0, 0)], 1.0);
        assert_eq!(m[(1, 1)], 4.0);
    }

    #[test]
    fn try_into_matrix_wrong_dims() {
        let d = DynMatrix::from_rows(2, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let result: Result<Matrix<f64, 2, 2>, _> = (&d).try_into();
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert_eq!(err.expected, (2, 2));
        assert_eq!(err.got, (2, 3));
    }

    #[test]
    fn dimension_mismatch_display() {
        let err = DimensionMismatch {
            expected: (2, 2),
            got: (3, 1),
        };
        assert_eq!(
            alloc::format!("{}", err),
            "dimension mismatch: expected 2x2, got 3x1"
        );
    }

    #[test]
    fn is_square() {
        let sq = DynMatrix::<f64>::zeros(3, 3);
        assert!(sq.is_square());
        let rect = DynMatrix::<f64>::zeros(2, 3);
        assert!(!rect.is_square());
    }

    #[test]
    fn clone_eq() {
        let a = DynMatrix::from_rows(2, 2, &[1.0, 2.0, 3.0, 4.0]);
        let b = a.clone();
        assert_eq!(a, b);
    }
}