use std::slice::Iter;
use crate::Dot;
use crate::vector::Vector;
use crate::Init;
use impl_ops::*;
use std::ops;
/// Dense matrix of `f64` values stored row by row.
///
/// `array` holds `m` rows, each of which is expected to contain `n` entries.
pub struct Matrix {
    /// Row storage: `array[i][j]` is the entry in row `i`, column `j`.
    array: Vec<Vec<f64>>,
    /// Number of rows.
    m: usize,
    /// Number of columns.
    n: usize,
}
impl Matrix {
    /// Builds a matrix from a vector of rows.
    ///
    /// The row count `m` is `matrix.len()`; the column count `n` is the length
    /// of the first row, or 0 when there are no rows.
    ///
    /// # Panics
    /// Panics if the rows do not all have the same length.
    pub fn from(matrix: Vec<Vec<f64>>) -> Matrix {
        let m = matrix.len();
        let n = matrix.first().map_or(0, Vec::len);
        // Every other method indexes assuming a rectangular layout.
        assert!(
            matrix.iter().all(|row| row.len() == n),
            "all rows of a matrix must have the same length"
        );
        Matrix { array: matrix, m, n }
    }

    /// Number of columns.
    #[inline]
    pub fn n(&self) -> usize {
        self.n
    }

    /// Number of rows.
    #[inline]
    pub fn m(&self) -> usize {
        self.m
    }

    /// Returns `true` when the matrix has as many rows as columns.
    pub fn is_square(&self) -> bool {
        self.n == self.m
    }

    /// Returns the `n x m` transpose: entry `(j, i)` of the result is entry `(i, j)` of `self`.
    pub fn transpose(&self) -> Matrix {
        let array: Vec<Vec<f64>> = (0..self.n)
            .map(|j| self.array.iter().map(|row| row[j]).collect())
            .collect();
        Matrix { array, m: self.n, n: self.m }
    }

    /// Iterates over the rows of the matrix.
    pub fn iter(&self) -> Iter<'_, Vec<f64>> {
        self.array.iter()
    }

    /// Computes the inverse by Gauss-Jordan elimination with partial pivoting.
    ///
    /// In each column, the row with the largest-magnitude candidate entry is swapped
    /// into the pivot position. Invertible matrices with a zero on the diagonal
    /// (e.g. a permutation matrix) are therefore handled.
    ///
    /// No error is reported for singular matrices. A zero pivot causes a division
    /// by zero, so the result then contains non-finite values.
    ///
    /// # Panics
    /// Panics if the matrix is not square.
    pub fn inv(&self) -> Matrix {
        assert!(self.is_square());
        let n = self.n;
        // `work` is reduced to the identity; `inv` (starting as the identity)
        // receives the same row operations and ends up as the inverse.
        let mut work = self.array.clone();
        let mut inv: Vec<Vec<f64>> = (0..n)
            .map(|i| (0..n).map(|j| if i == j { 1.0 } else { 0.0 }).collect())
            .collect();
        for k in 0..n {
            // Partial pivoting: choose the row at or below k with the largest |entry| in column k.
            let pivot = (k..n).fold(k, |best, i| {
                if work[i][k].abs() > work[best][k].abs() {
                    i
                } else {
                    best
                }
            });
            work.swap(k, pivot);
            inv.swap(k, pivot);

            let scale = 1.0 / work[k][k];
            work[k].iter_mut().for_each(|x| *x *= scale);
            inv[k].iter_mut().for_each(|x| *x *= scale);

            // Snapshot the pivot rows so the other rows can be updated without aliasing.
            let pivot_row = work[k].clone();
            let pivot_inv = inv[k].clone();
            for i in (0..n).filter(|&i| i != k) {
                let factor = work[i][k];
                for j in 0..n {
                    work[i][j] -= factor * pivot_row[j];
                    inv[i][j] -= factor * pivot_inv[j];
                }
            }
        }
        Matrix { array: inv, m: n, n }
    }

    /// Returns a copy of row `i` as a [`Vector`].
    ///
    /// # Panics
    /// Panics if `i >= self.m()`.
    pub fn line(&self, i: usize) -> Vector {
        Vector::from(self.array[i].clone())
    }

    /// Builds a square matrix with `vec` on the main diagonal and zeros elsewhere.
    pub fn diagonal(vec: Vector) -> Matrix {
        let size = vec.len();
        let array: Vec<Vec<f64>> = (0..size)
            .map(|i| {
                let mut row = vec![0.0; size];
                row[i] = vec[i];
                row
            })
            .collect();
        Matrix { array, m: size, n: size }
    }
}
impl Init for Matrix {
    type Output = Matrix;
    /// `(rows, columns)`.
    type Size = (usize, usize);

    /// Builds a `size.0 x size.1` matrix with every entry equal to `value`.
    fn init(value: f64, size: Self::Size) -> Self::Output {
        let (rows, cols) = size;
        Matrix {
            array: vec![vec![value; cols]; rows],
            m: rows,
            n: cols,
        }
    }

    /// Builds a `size.0 x size.1` matrix whose entry `(i, j)` is `func((i, j))`.
    fn init_func<F>(func: F, size: Self::Size) -> Self::Output
    where
        F: Fn(Self::Size) -> f64,
    {
        let (rows, cols) = size;
        Matrix {
            array: (0..rows)
                .map(|i| (0..cols).map(|j| func((i, j))).collect())
                .collect(),
            // `m` counts the outer Vec (rows) and `n` each row's length (columns),
            // consistent with `init` above.
            m: rows,
            n: cols,
        }
    }
}
impl std::clone::Clone for Matrix {
fn clone(&self) -> Self {
Matrix {
array : self.array.clone(),
m : self.m,
n : self.n,
}
}
}
impl Dot<Vector> for Matrix {
    type Output = Vector;

    /// Matrix-vector product `self · rhs`.
    ///
    /// Returns a vector of length `self.m()` whose entry `i` is the dot product
    /// of row `i` with `rhs`.
    ///
    /// # Panics
    /// Panics if `rhs.len()` differs from the column count `self.n()`.
    fn dot(&self, rhs: &Vector) -> Self::Output {
        assert_eq!(self.n, rhs.len());
        // One output entry per row (m of them), each summed over the n columns.
        let mut result = Vector::zeros(self.m);
        for (i, row) in self.array.iter().enumerate() {
            result[i] = row
                .iter()
                .enumerate()
                .fold(0.0, |acc, (k, &a)| acc + a * rhs[k]);
        }
        result
    }
}
impl Dot<Matrix> for Matrix {
type Output = Matrix;
fn dot(&self, rhs: &Matrix) -> Self::Output {
assert_eq!(self.n, rhs.m);
let mut result = Matrix::zeros((self.m, rhs.n));
for i in 0..result.n {
for j in 0..result.m {
let mut s = 0f64;
for k in 0..self.m {
s += self.array[i][k]*rhs.array[k][j]
}
result.array[i][j] = s;
}
}
result
}
}
impl std::ops::Index<usize> for Matrix {
    type Output = Vec<f64>;

    /// Borrows row `row` of the matrix.
    ///
    /// # Panics
    /// Panics if `row` is out of bounds.
    #[inline]
    fn index(&self, row: usize) -> &Vec<f64> {
        let rows = &self.array;
        &rows[row]
    }
}
impl std::ops::Index<(usize, usize)> for Matrix {
    type Output = f64;

    /// Borrows the entry at `(row, col)`.
    ///
    /// # Panics
    /// Panics if either coordinate is out of bounds.
    #[inline]
    fn index(&self, (row, col): (usize, usize)) -> &f64 {
        &self.array[row][col]
    }
}
impl std::ops::IndexMut<usize> for Matrix {
    /// Mutably borrows row `row` of the matrix.
    ///
    /// # Panics
    /// Panics if `row` is out of bounds.
    #[inline]
    fn index_mut(&mut self, row: usize) -> &mut Vec<f64> {
        let rows = &mut self.array;
        &mut rows[row]
    }
}
impl std::ops::IndexMut<(usize, usize)> for Matrix {
    /// Mutably borrows the entry at `(row, col)`.
    ///
    /// # Panics
    /// Panics if either coordinate is out of bounds.
    #[inline]
    fn index_mut(&mut self, (row, col): (usize, usize)) -> &mut f64 {
        &mut self.array[row][col]
    }
}
impl std::ops::Neg for Matrix {
type Output = Matrix;
fn neg(self) -> Self::Output {
Matrix::from(self.array.iter().map(|line|
line.iter().map(|x| -x).collect()
).collect())
}
}
impl std::ops::Neg for &Matrix {
type Output = Matrix;
fn neg(self) -> Self::Output {
Matrix::from(self.array.iter().map(|line|
line.iter().map(|x| -x).collect()
).collect())
}
}
/// Generates element-wise binary operators for `Matrix` via `impl_ops::impl_op_ex!`.
///
/// For each listed operator `op` two operator impls are generated:
/// * `Matrix op Matrix` — `op` is applied entry by entry. `*` is therefore the
///   element-wise (Hadamard) product, not the matrix product; use `Dot` for that.
///   Panics if the two matrices have different shapes.
/// * `Matrix op f64` — `op` is applied between every entry and the scalar.
///
/// `impl_op_ex!` is the impl_ops variant that also accepts borrowed operands
/// (see the impl_ops documentation).
macro_rules! impl_operator {
    ( $($op:tt),* ) => {
        $(
            impl_op_ex!($op |a: &Matrix, b: &Matrix| -> Matrix {
                // `zip` would silently truncate to the smaller shape; reject mismatches instead.
                assert_eq!(
                    (a.m, a.n),
                    (b.m, b.n),
                    "element-wise matrix operation on mismatched shapes"
                );
                let array: Vec<Vec<f64>> = a.iter()
                    .zip(b.iter())
                    .map(|(line_a, line_b)| {
                        line_a.iter()
                            .zip(line_b.iter())
                            .map(|(x, y)| x $op y)
                            .collect()
                    })
                    .collect();
                Matrix { array, m: a.m, n: a.n }
            });
            impl_op_ex!($op |a: &Matrix, k: &f64| -> Matrix {
                let array: Vec<Vec<f64>> = a.iter()
                    .map(|line| line.iter().map(|x| x $op k).collect())
                    .collect();
                Matrix { array, m: a.m, n: a.n }
            });
        )*
    }
}
// Element-wise arithmetic between matrices, and between a matrix and a scalar.
impl_operator![+, -, *, /];