use crate::array::Array;
use crate::error::{NumRs2Error, Result};
use num_traits::{One, Zero};
use std::convert::TryFrom;
use std::fmt;
use std::ops::{Add, Div, Mul, Sub};
/// A matrix type that wraps an [`Array`].
///
/// `Matrix::new` rejects arrays that are not 2-dimensional, and the shape
/// accessors (`nrows`, `ncols`, `shape`) read the first two shape entries of
/// the wrapped array.
#[derive(Clone)]
pub struct Matrix<T> {
// Underlying storage for the matrix elements.
data: Array<T>,
}
impl<T> Matrix<T>
where
T: Clone + Zero + One + PartialEq + Default + PartialOrd,
{
/// Wraps an existing array as a matrix.
///
/// # Errors
///
/// Returns [`NumRs2Error::DimensionMismatch`] when `array` is not
/// 2-dimensional.
pub fn new(array: Array<T>) -> Result<Self> {
    let ndim = array.ndim();
    match ndim {
        2 => Ok(Self { data: array }),
        _ => Err(NumRs2Error::DimensionMismatch(format!(
            "Matrix must be 2-dimensional, got {}-dimensional array",
            ndim
        ))),
    }
}
/// Builds an `n x 1` matrix (a column vector) from a flat vector of `n`
/// elements.
pub fn from_vec(vec: Vec<T>) -> Self {
    let len = vec.len();
    Self {
        data: Array::from_vec(vec).reshape(&[len, 1]),
    }
}
/// Builds a matrix from a vector of rows.
///
/// Each inner vector becomes one row; the rows are concatenated in order and
/// reshaped to `[row_count, row_length]`.
///
/// # Errors
///
/// Returns [`NumRs2Error::InvalidOperation`] when `nested_vec` is empty or
/// when the rows do not all have the same length as the first row.
pub fn from_nested_vec(nested_vec: Vec<Vec<T>>) -> Result<Self> {
    let cols = match nested_vec.first() {
        Some(first) => first.len(),
        None => {
            return Err(NumRs2Error::InvalidOperation(
                "Cannot create matrix from empty vector".to_string(),
            ));
        }
    };
    if nested_vec.iter().any(|row| row.len() != cols) {
        return Err(NumRs2Error::InvalidOperation(
            "Inconsistent row lengths in nested vector".to_string(),
        ));
    }

    let rows = nested_vec.len();
    let mut flat = Vec::with_capacity(rows * cols);
    flat.extend(nested_vec.into_iter().flatten());

    Ok(Self {
        data: Array::from_vec(flat).reshape(&[rows, cols]),
    })
}
/// Creates a `rows x cols` matrix whose storage comes from
/// `Array::zeros(&[rows, cols])`.
pub fn zeros(rows: usize, cols: usize) -> Self
where
    T: Default + Clone,
{
    Self {
        data: Array::zeros(&[rows, cols]),
    }
}
/// Creates a `rows x cols` matrix whose storage comes from
/// `Array::ones(&[rows, cols])`.
pub fn ones(rows: usize, cols: usize) -> Self
where
    T: From<u8> + Clone,
{
    Self {
        data: Array::ones(&[rows, cols]),
    }
}
/// Creates an `n x n` identity-style matrix.
///
/// Starts from `Array::zeros(&[n, n])` and writes `T::from(1u8)` at every
/// diagonal position `(i, i)`.
///
/// # Panics
///
/// Panics if the backing array rejects a diagonal index, which would mean the
/// freshly created `n x n` array does not have the requested shape.
pub fn eye(n: usize) -> Self
where
    T: From<u8> + Clone + Default,
{
    let mut array = Array::zeros(&[n, n]);
    for i in 0..n {
        // Move the freshly built value straight into the array; cloning it
        // first would only create a copy that is dropped immediately.
        array
            .set(&[i, i], T::from(1u8))
            .expect("eye: diagonal index should always be valid");
    }
    Self { data: array }
}
/// Returns the number of rows (entry 0 of the underlying array's shape).
pub fn nrows(&self) -> usize {
    let dims = self.data.shape();
    dims[0]
}
/// Returns the number of columns (entry 1 of the underlying array's shape).
pub fn ncols(&self) -> usize {
    let dims = self.data.shape();
    dims[1]
}
/// Returns the matrix dimensions as `(rows, cols)`.
pub fn shape(&self) -> (usize, usize) {
    (self.nrows(), self.ncols())
}
/// Returns the size reported by the underlying array.
pub fn size(&self) -> usize {
    self.array().size()
}
pub fn array(&self) -> &Array<T> {
&self.data
}
pub fn to_array(&self) -> Array<T> {
self.data.clone()
}
/// Copies the matrix into a vector of rows.
///
/// The outer vector has one entry per row and each inner vector holds that
/// row's elements in column order.
///
/// # Panics
///
/// Panics if the backing array rejects an index inside `shape()`.
pub fn to_nested_vec(&self) -> Vec<Vec<T>> {
    let (rows, cols) = self.shape();
    (0..rows)
        .map(|i| {
            (0..cols)
                .map(|j| {
                    self.data
                        .get(&[i, j])
                        .expect("to_nested_vec: index within matrix bounds should be valid")
                        .clone()
                })
                .collect::<Vec<T>>()
        })
        .collect()
}
/// Returns a copy of the element at row `i`, column `j`.
///
/// # Errors
///
/// Returns [`NumRs2Error::IndexOutOfBounds`] when `i` or `j` lies outside the
/// matrix, and propagates any error from the backing array lookup.
pub fn get(&self, i: usize, j: usize) -> Result<T> {
    let (rows, cols) = self.shape();
    if i < rows && j < cols {
        let value = self.data.get(&[i, j])?;
        return Ok(value.clone());
    }
    Err(NumRs2Error::IndexOutOfBounds(format!(
        "Index ({}, {}) out of bounds for matrix with shape ({}, {})",
        i, j, rows, cols
    )))
}
/// Stores `value` at row `i`, column `j`.
///
/// # Errors
///
/// Returns [`NumRs2Error::IndexOutOfBounds`] when `i` or `j` lies outside the
/// matrix; otherwise returns whatever the backing array's `set` returns.
pub fn set(&mut self, i: usize, j: usize, value: T) -> Result<()> {
    let (rows, cols) = self.shape();
    if i < rows && j < cols {
        return self.data.set(&[i, j], value);
    }
    Err(NumRs2Error::IndexOutOfBounds(format!(
        "Index ({}, {}) out of bounds for matrix with shape ({}, {})",
        i, j, rows, cols
    )))
}
/// Returns a new matrix wrapping the result of `Array::transpose` on the
/// underlying array.
pub fn transpose(&self) -> Self {
    Self {
        data: self.data.transpose(),
    }
}
pub fn row(&self, i: usize) -> Result<Self> {
if i >= self.nrows() {
return Err(NumRs2Error::IndexOutOfBounds(format!(
"Row index {} out of bounds for matrix with {} rows",
i,
self.nrows()
)));
}
let cols = self.ncols();
let mut row_data = Vec::with_capacity(cols);
for j in 0..cols {
row_data.push(self.data.get(&[i, j])?.clone());
}
let row_array = Array::from_vec(row_data).reshape(&[1, cols]);
Ok(Self { data: row_array })
}
pub fn column(&self, j: usize) -> Result<Self> {
if j >= self.ncols() {
return Err(NumRs2Error::IndexOutOfBounds(format!(
"Column index {} out of bounds for matrix with {} columns",
j,
self.ncols()
)));
}
let rows = self.nrows();
let mut col_data = Vec::with_capacity(rows);
for i in 0..rows {
col_data.push(self.data.get(&[i, j])?.clone());
}
let col_array = Array::from_vec(col_data).reshape(&[rows, 1]);
Ok(Self { data: col_array })
}
pub fn diagonal(&self) -> Self {
let (rows, cols) = self.shape();
let diag_len = rows.min(cols);
let mut diag_data = Vec::with_capacity(diag_len);
for i in 0..diag_len {
diag_data.push(
self.data
.get(&[i, i])
.expect("diagonal: index within min(rows, cols) should be valid")
.clone(),
);
}
let diag_array = Array::from_vec(diag_data).reshape(&[diag_len, 1]);
Self { data: diag_array }
}
/// Returns `true` when the matrix has as many rows as columns.
pub fn is_square(&self) -> bool {
    let (rows, cols) = self.shape();
    rows == cols
}
/// Returns `true` when the matrix is square and every element above the
/// diagonal compares equal to its mirror below the diagonal.
///
/// Non-square matrices are never symmetric.
///
/// # Panics
///
/// Panics if the backing array rejects an index inside the square bounds.
pub fn is_symmetric(&self) -> bool
where
    T: PartialEq,
{
    if !self.is_square() {
        return false;
    }
    let n = self.nrows();
    // Only the strict upper triangle needs checking; each pair is visited once.
    let has_mismatch = (0..n).any(|i| {
        ((i + 1)..n).any(|j| {
            let upper = self
                .data
                .get(&[i, j])
                .expect("is_symmetric: index within square matrix should be valid");
            let lower = self
                .data
                .get(&[j, i])
                .expect("is_symmetric: index within square matrix should be valid");
            upper != lower
        })
    });
    !has_mismatch
}
}
impl<T> Matrix<T>
where
    T: Clone + Default + Add<Output = T> + Mul<Output = T> + Zero + One + PartialEq + PartialOrd,
{
    /// Computes the matrix product of `self` and `other`.
    ///
    /// For an `m x p` receiver and a `p x n` argument the result is `m x n`,
    /// with entry `(i, j)` equal to the sum over `k` of
    /// `self[i, k] * other[k, j]`.
    ///
    /// # Errors
    ///
    /// Returns [`NumRs2Error::ShapeMismatch`] when
    /// `self.ncols() != other.nrows()`. `expected` holds the shape `other`
    /// would need (`[self.ncols(), other.ncols()]`) and `actual` holds the
    /// shape `other` really has. Errors from writing into the result matrix
    /// are propagated.
    ///
    /// # Panics
    ///
    /// Panics if the backing array rejects an index that lies inside the
    /// checked bounds.
    pub fn dot(&self, other: &Matrix<T>) -> Result<Matrix<T>> {
        let (m, p) = self.shape();
        let (other_rows, n) = other.shape();
        if p != other_rows {
            // Report the operand that is actually wrong: `other` needs `p`
            // rows to be multiplied with an `m x p` matrix.
            return Err(NumRs2Error::ShapeMismatch {
                expected: vec![p, n],
                actual: vec![other_rows, n],
            });
        }

        let mut result = Matrix::zeros(m, n);
        for i in 0..m {
            for j in 0..n {
                // Accumulate from the additive identity rather than relying
                // on `Default` happening to be zero for `T`.
                let mut sum = T::zero();
                for k in 0..p {
                    let a_ik = self
                        .data
                        .get(&[i, k])
                        .expect("dot: index within self matrix bounds should be valid")
                        .clone();
                    let b_kj = other
                        .data
                        .get(&[k, j])
                        .expect("dot: index within other matrix bounds should be valid")
                        .clone();
                    sum = sum + (a_ik * b_kj);
                }
                result.set(i, j, sum)?;
            }
        }
        Ok(result)
    }
}
/// Fallible conversion from an [`Array`]; succeeds exactly when
/// `Matrix::new` does.
impl<T> TryFrom<Array<T>> for Matrix<T>
where
    T: Clone + Zero + One + PartialEq + Default + PartialOrd,
{
    type Error = NumRs2Error;

    fn try_from(array: Array<T>) -> Result<Self> {
        Self::new(array)
    }
}
/// Addition of two matrix references via `Array::add_broadcast`.
impl<T> Add for &Matrix<T>
where
    T: Clone + Add<Output = T> + Zero + One + PartialEq + Default + PartialOrd,
{
    type Output = Matrix<T>;

    /// # Panics
    ///
    /// Panics when `add_broadcast` returns an error.
    fn add(self, other: &Matrix<T>) -> Matrix<T> {
        Matrix {
            data: self
                .data
                .add_broadcast(&other.data)
                .expect("Matrix addition: shapes must be broadcastable"),
        }
    }
}
/// Subtraction of two matrix references via `Array::subtract_broadcast`.
impl<T> Sub for &Matrix<T>
where
    T: Clone + Sub<Output = T> + Zero + One + PartialEq + Default + PartialOrd,
{
    type Output = Matrix<T>;

    /// # Panics
    ///
    /// Panics when `subtract_broadcast` returns an error.
    fn sub(self, other: &Matrix<T>) -> Matrix<T> {
        Matrix {
            data: self
                .data
                .subtract_broadcast(&other.data)
                .expect("Matrix subtraction: shapes must be broadcastable"),
        }
    }
}
impl<
T: Clone + Default + Add<Output = T> + Mul<Output = T> + Zero + One + PartialEq + PartialOrd,
> Mul for &Matrix<T>
{
type Output = Matrix<T>;
fn mul(self, other: &Matrix<T>) -> Matrix<T> {
self.dot(other)
.expect("Matrix multiplication: inner dimensions must match")
}
}
/// Division of a matrix reference by a scalar: every element is divided by
/// `scalar` through `Array::map`.
impl<T> Div<T> for &Matrix<T>
where
    T: Clone + Div<Output = T> + Zero + One + PartialEq + Default + PartialOrd,
{
    type Output = Matrix<T>;

    fn div(self, scalar: T) -> Matrix<T> {
        Matrix {
            // The scalar is cloned per element because `/` consumes its operands.
            data: self.data.map(|x| x.clone() / scalar.clone()),
        }
    }
}
/// Human-readable rendering of a matrix.
///
/// The output starts with a header line `Matrix(rows, cols)`, followed by one
/// line per row of the form `[a, b, c]`. Every line, including the last, ends
/// with a newline.
impl<T> fmt::Display for Matrix<T>
where
    T: Clone + fmt::Display + Zero + One + PartialEq + Default + PartialOrd,
{
    /// # Panics
    ///
    /// Panics if the backing array rejects an index inside `shape()`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let (rows, cols) = self.shape();
        writeln!(f, "Matrix({}, {})", rows, cols)?;
        for i in 0..rows {
            write!(f, "[")?;
            for j in 0..cols {
                // Format the looked-up value directly; cloning each element
                // only to print it is unnecessary work.
                let value = self
                    .data
                    .get(&[i, j])
                    .expect("Display: index within matrix bounds should be valid");
                if j > 0 {
                    write!(f, ", ")?;
                }
                write!(f, "{}", value)?;
            }
            writeln!(f, "]")?;
        }
        Ok(())
    }
}
/// Converts an array of any dimensionality into a [`Matrix`].
///
/// * 0-dimensional: the single element becomes a `1 x 1` matrix.
/// * 1-dimensional: the array is reshaped to a `1 x n` row vector.
/// * 2-dimensional: the array is used as is.
/// * higher-dimensional: the array is flattened with `flatten(None)` and
///   reshaped to `1 x size`.
///
/// # Errors
///
/// Returns whatever `Matrix::new` returns for the prepared array.
///
/// # Panics
///
/// Panics if a 0-dimensional array cannot be read with an empty index.
pub fn matrix<T>(data: Array<T>) -> Result<Matrix<T>>
where
    T: Clone + Zero + One + PartialEq + Default + PartialOrd,
{
    let two_dimensional = match data.ndim() {
        0 => {
            let only = data
                .get(&[])
                .expect("matrix: 0-dimensional array should have a single element")
                .clone();
            Array::from_vec(vec![only]).reshape(&[1, 1])
        }
        1 => {
            let len = data.shape()[0];
            data.reshape(&[1, len])
        }
        2 => data,
        _ => {
            let total = data.size();
            data.flatten(None).reshape(&[1, total])
        }
    };
    Matrix::new(two_dimensional)
}
/// Free-function form of `Matrix::from_nested_vec`: builds a matrix from a
/// vector of rows and returns the same errors.
pub fn matrix_from_nested<T>(nested_vec: Vec<Vec<T>>) -> Result<Matrix<T>>
where
    T: Clone + Zero + One + PartialEq + Default + PartialOrd,
{
    Matrix::<T>::from_nested_vec(nested_vec)
}
/// Wraps a single value in a `1 x 1` matrix.
///
/// # Panics
///
/// Panics if `Matrix::new` rejects the reshaped `1 x 1` array.
pub fn matrix_from_scalar<T>(scalar: T) -> Matrix<T>
where
    T: Clone + Zero + One + PartialEq + Default + PartialOrd,
{
    let one_by_one = Array::from_vec(vec![scalar]).reshape(&[1, 1]);
    let wrapped = Matrix::new(one_by_one);
    wrapped.expect("matrix_from_scalar: 1x1 array is always a valid matrix")
}
/// Alias for [`matrix`]: converts an array into a [`Matrix`] using the same
/// rules and returning the same errors.
pub fn asmatrix<T>(data: Array<T>) -> Result<Matrix<T>>
where
    T: Clone + Zero + One + PartialEq + Default + PartialOrd,
{
    matrix::<T>(data)
}
/// Alias for [`matrix_from_nested`]: builds a matrix from a vector of rows.
pub fn asmatrix_from_nested<T>(nested_vec: Vec<Vec<T>>) -> Result<Matrix<T>>
where
    T: Clone + Zero + One + PartialEq + Default + PartialOrd,
{
    matrix_from_nested::<T>(nested_vec)
}