use crate::structure::matrix::matrix_trait::MatrixDataTrait;
use crate::structure::matrix::MatrixVis;
use std::collections::HashMap;
/// A dense `ROWS` x `COLS` matrix backed by a fixed-size nested array.
///
/// `data[i][j]` holds the entry in row `i`, column `j`: the outer array has
/// `ROWS` elements, each an array of `COLS` entries. `COLS` defaults to
/// `ROWS`, so `Matrix<T, N>` names an `N` x `N` square matrix.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Matrix<T, const ROWS: usize, const COLS: usize = ROWS> {
/// Entries stored row by row (`data[row][col]`); readable crate-wide.
pub(crate) data: [[T; COLS]; ROWS],
}
impl<T: MatrixDataTrait, const ROWS: usize, const COLS: usize> Matrix<T, ROWS, COLS> {
    /// Creates a `ROWS` x `COLS` matrix with every entry set to `T::zero()`.
    pub fn new() -> Self {
        Self {
            data: [[T::zero(); COLS]; ROWS],
        }
    }

    /// Wraps an existing nested array; `data[i][j]` becomes the entry at
    /// row `i`, column `j`.
    pub fn from_array(data: [[T; COLS]; ROWS]) -> Self {
        Self { data }
    }

    /// Builds a [`MatrixVis`] for this matrix via `MatrixVis::from_mat`.
    ///
    /// NOTE(review): the meaning of `w` and `b` (presumably the glyphs for
    /// "white"/"black" cells) is defined by `MatrixVis::from_mat` — confirm there.
    pub fn draw(&self, w: char, b: char) -> MatrixVis<ROWS, COLS> {
        MatrixVis::<ROWS, COLS>::from_mat(self, w, b)
    }

    /// Returns a copy of row `index`.
    ///
    /// # Panics
    ///
    /// Panics if `index >= ROWS`.
    pub fn row(&self, index: usize) -> [T; COLS] {
        self.data[index]
    }

    /// Returns a copy of column `index`, ordered top to bottom.
    ///
    /// # Panics
    ///
    /// Panics if `index >= COLS` (unless `ROWS == 0`, in which case the
    /// result is empty and `index` is never read).
    pub fn col(&self, index: usize) -> [T; ROWS] {
        let mut column = [T::zero(); ROWS];
        for (slot, row) in column.iter_mut().zip(&self.data) {
            *slot = row[index];
        }
        column
    }

    /// Returns `true` when `ROWS == COLS`. This is fixed by the type, not by
    /// the matrix contents.
    pub fn is_square(&self) -> bool {
        ROWS == COLS
    }

    /// Returns `true` if the matrix is square with `T::one()` on the main
    /// diagonal and `T::zero()` everywhere else.
    pub fn is_identity(&self) -> bool {
        if !self.is_square() {
            return false;
        }
        let (zero, one) = (T::zero(), T::one());
        self.data.iter().enumerate().all(|(i, row)| {
            row.iter().enumerate().all(|(j, &value)| {
                let expected = if i == j { one } else { zero };
                value == expected
            })
        })
    }

    /// Returns `true` if the matrix is square and every entry off the main
    /// diagonal equals `T::zero()`.
    ///
    /// Diagonal entries may take any value, including zero. So the zero
    /// matrix is diagonal, and a matrix is diagonal exactly when it is both
    /// [`is_upper_triangular`](Self::is_upper_triangular) and
    /// [`is_lower_triangular`](Self::is_lower_triangular).
    pub fn is_diagonal(&self) -> bool {
        if !self.is_square() {
            return false;
        }
        let zero = T::zero();
        self.data.iter().enumerate().all(|(i, row)| {
            row.iter()
                .enumerate()
                .all(|(j, &value)| i == j || value == zero)
        })
    }

    /// Returns `true` if the matrix is square and equals its transpose.
    pub fn is_symmetric(&self) -> bool {
        if !self.is_square() {
            return false;
        }
        // `j` starts at `i` (not `i + 1`) so each diagonal entry is still
        // compared with itself, as a full scan would. This keeps e.g. a float
        // NaN on the diagonal reported as non-symmetric. `ROWS == COLS` here,
        // so `data[j][i]` is in bounds.
        (0..ROWS).all(|i| (i..COLS).all(|j| self.data[i][j] == self.data[j][i]))
    }

    /// Returns `true` if the matrix is square and every entry strictly below
    /// the main diagonal equals `T::zero()`.
    pub fn is_upper_triangular(&self) -> bool {
        if !self.is_square() {
            return false;
        }
        let zero = T::zero();
        // Row `i` has exactly `i` entries left of the diagonal.
        self.data
            .iter()
            .enumerate()
            .all(|(i, row)| row[..i].iter().all(|&value| value == zero))
    }

    /// Returns `true` if the matrix is square and every entry strictly above
    /// the main diagonal equals `T::zero()`.
    pub fn is_lower_triangular(&self) -> bool {
        if !self.is_square() {
            return false;
        }
        let zero = T::zero();
        // Entries right of the diagonal in row `i` start at column `i + 1`;
        // `i < COLS` because the matrix is square, so the slice is valid.
        self.data
            .iter()
            .enumerate()
            .all(|(i, row)| row[i + 1..].iter().all(|&value| value == zero))
    }
}
impl<T: MatrixDataTrait, const ROWS: usize, const COLS: usize> std::fmt::Display
    for Matrix<T, ROWS, COLS>
{
    /// Writes a leading newline, then one line per row.
    ///
    /// Each entry is written with `{:width$}`, where `width` is one more than
    /// the widest rendered entry in its column. Whether entries actually line
    /// up therefore depends on `T`'s `Display` honouring the width.
    ///
    /// # Errors
    ///
    /// Propagates any error reported by the underlying writer instead of
    /// panicking.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> core::fmt::Result {
        // Widest rendered entry per column, counted in `char`s because that
        // is the unit `{:width$}` pads by (byte length over-counts non-ASCII).
        let mut widths = [0usize; COLS];
        for row in &self.data {
            for (col_width, value) in widths.iter_mut().zip(row) {
                *col_width = (*col_width).max(value.to_string().chars().count());
            }
        }

        writeln!(f)?;
        for row in &self.data {
            for (value, &col_width) in row.iter().zip(&widths) {
                write!(f, "{:width$}", value, width = col_width + 1)?;
            }
            writeln!(f)?;
        }
        Ok(())
    }
}