use serde::{Deserialize, Serialize};
use std::{
error::Error,
fmt::{Debug, Display},
fs,
marker::PhantomData,
ops::Div,
str::FromStr,
};
use itertools::{iproduct, Itertools};
use num_traits::{
pow,
sign::{abs, Signed},
Num, NumAssign, NumAssignOps, NumAssignRef, NumOps, One, Zero,
};
use rand::{distributions::uniform::SampleUniform, Rng};
use rayon::prelude::*;
use std::iter::{Product, Sum};
use crate::MatrixError;
/// A 2-D extent or position written as `(rows, cols)` / `(row, col)`, matching what
/// `Matrix::shape` returns (`(nrows, ncols)`).
pub type Shape = (usize, usize);
/// Exchanges the values behind two `usize` references.
///
/// Thin wrapper over [`std::mem::swap`] instead of a hand-rolled temp-variable swap;
/// kept so existing callers of this helper keep compiling.
fn swap(lhs: &mut usize, rhs: &mut usize) {
    std::mem::swap(lhs, rhs);
}
/// Flat row-major offset of `(row, col)` in a matrix that has `ncols` columns,
/// i.e. `row * ncols + col`.
macro_rules! at {
    ($row:expr, $col:expr, $ncols:expr) => {
        $col + $row * $ncols
    };
}
/// A dense 2-D matrix of `T` values stored in row-major order.
///
/// The element at `(row, col)` lives at `data[row * ncols + col]` (the `at!` macro).
/// Constructors such as `Matrix::new` require `data.len() == nrows * ncols`.
///
/// The lifetime `'a` exists only through `PhantomData`; the struct holds no borrowed data.
// NOTE(review): `nrows`/`ncols` are public, so callers can break the
// `data.len() == nrows * ncols` invariant by assigning to them directly.
#[derive(Clone, PartialEq, PartialOrd, Debug, Serialize, Deserialize)]
pub struct Matrix<'a, T>
where
    T: MatrixElement,
    <T as FromStr>::Err: Error + 'static,
    Vec<T>: IntoParallelIterator,
    Vec<&'a T>: IntoParallelRefIterator<'a>,
{
    /// Elements in row-major order.
    data: Vec<T>,
    /// Number of rows.
    pub nrows: usize,
    /// Number of columns.
    pub ncols: usize,
    /// Zero-sized marker tying the type to `'a`; stores nothing.
    _lifetime: PhantomData<&'a T>,
}
/// Element types that can be stored in a [`Matrix`].
///
/// The supertraits are grouped by what the matrix code needs from an element:
/// - value semantics and thread-safety: `Copy`, `Clone`, `Sized`, `Default`, `Send`, `Sync`
/// - comparison: `PartialEq`, `PartialOrd`
/// - arithmetic: `Zero`, `One`, `Num`, `NumOps`, `NumAssign*`, `Signed`
/// - aggregation: `Sum`, `Product`
/// - text I/O: `Display`, `Debug`, `FromStr`
/// - random generation: `SampleUniform`
pub trait MatrixElement:
    Copy + Clone + Sized + Default + Send + Sync
    + PartialEq + PartialOrd
    + Zero + One + Num + NumOps + NumAssignOps + NumAssignRef + NumAssign + Signed
    + Sum + Product
    + Display + Debug + FromStr
    + SampleUniform
{
}
/// Lets a `Matrix` be used as a `std::error::Error` value.
///
/// No `Error` methods are overridden. The impl relies on the derived `Debug` and on this
/// file's `Display` impl, both of which `Error` requires.
// NOTE(review): treating a matrix as an error type is unusual; confirm a caller needs it.
impl<'a, T> Error for Matrix<'a, T>
where
    T: MatrixElement,
    <T as FromStr>::Err: Error + 'static,
    Vec<T>: IntoParallelIterator,
    Vec<&'a T>: IntoParallelRefIterator<'a>,
{
}
// `Matrix<'a, T>` needs no manual `unsafe impl Send`. The compiler derives `Send`
// automatically because every field is `Send`:
// - `Vec<T>`, since `T: Send`, a `MatrixElement` supertrait;
// - `PhantomData<&'a T>`, since `T: Sync`, also a supertrait;
// - the two `usize`s.
// The function below makes the compiler re-prove that on every build, so the guarantee
// cannot silently disappear.
const _: () = {
    fn _assert_send<S: Send>() {}
    fn _matrix_is_send<'a, T>()
    where
        T: MatrixElement,
        <T as FromStr>::Err: Error + 'static,
        Vec<T>: IntoParallelIterator,
        Vec<&'a T>: IntoParallelRefIterator<'a>,
    {
        _assert_send::<Matrix<'a, T>>();
    }
};
// `Matrix<'a, T>` needs no manual `unsafe impl Sync`. The compiler derives `Sync`
// automatically because every field is `Sync`:
// - `Vec<T>` and `PhantomData<&'a T>`, since `T: Sync` (a `MatrixElement` supertrait);
// - the two `usize`s.
// The function below makes the compiler re-prove that on every build.
const _: () = {
    fn _assert_sync<S: Sync>() {}
    fn _matrix_is_sync<'a, T>()
    where
        T: MatrixElement,
        <T as FromStr>::Err: Error + 'static,
        Vec<T>: IntoParallelIterator,
        Vec<&'a T>: IntoParallelRefIterator<'a>,
    {
        _assert_sync::<Matrix<'a, T>>();
    }
};
impl MatrixElement for i8 {}
impl MatrixElement for i16 {}
impl MatrixElement for i32 {}
impl MatrixElement for i64 {}
impl MatrixElement for i128 {}
impl MatrixElement for f32 {}
impl MatrixElement for f64 {}
impl<'a, T> FromStr for Matrix<'a, T>
where
    T: MatrixElement,
    <T as FromStr>::Err: Error + 'static,
    Vec<T>: IntoParallelIterator,
    Vec<&'a T>: IntoParallelRefIterator<'a>,
{
    type Err = anyhow::Error;

    /// Parses a matrix written one row per line, with elements separated by whitespace.
    ///
    /// For example, `"1 2\n3 4"` is a 2x2 matrix. Whitespace around the whole input is
    /// ignored.
    ///
    /// # Errors
    ///
    /// Returns an error instead of panicking when:
    /// - the input is empty;
    /// - an element fails to parse as `T`;
    /// - rows have different numbers of elements.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let mut data: Vec<T> = Vec::new();
        let mut nrows = 0usize;
        let mut ncols: Option<usize> = None;

        for (row, line) in s.trim().lines().enumerate() {
            let row_start = data.len();
            for token in line.split_whitespace() {
                // `T::Err` is only known to be `Error + 'static` (not `Send + Sync`), so it
                // cannot be converted into `anyhow::Error` with `?`; render it instead.
                let value = token.parse::<T>().map_err(|e| {
                    anyhow::anyhow!("invalid element {:?} on row {}: {}", token, row, e)
                })?;
                data.push(value);
            }

            // Every row must have the width of the first one. Checking the total count
            // alone would accept ragged input such as rows of 3 and 1 elements as 2x2.
            let width = data.len() - row_start;
            if let Some(expected) = ncols {
                if width != expected {
                    anyhow::bail!("row {} has {} elements, expected {}", row, width, expected);
                }
            } else {
                ncols = Some(width);
            }
            nrows += 1;
        }

        let ncols =
            ncols.ok_or_else(|| anyhow::anyhow!("cannot parse a matrix from empty input"))?;
        Self::new(data, (nrows, ncols)).map_err(|_| {
            anyhow::anyhow!(
                "parsed data does not form a {}x{} matrix",
                nrows,
                ncols
            )
        })
    }
}
impl<'a, T> Display for Matrix<'a, T>
where
    T: MatrixElement,
    <T as FromStr>::Err: Error + 'static,
    Vec<T>: IntoParallelIterator,
    Vec<&'a T>: IntoParallelRefIterator<'a>,
{
    /// Writes the matrix inside `[` … `]`, one row per line, then `, dtype=<type name>`
    /// and a trailing newline.
    ///
    /// Elements use `{:.4}`, which gives 4 decimal places for floats; precision is
    /// ignored for integers. Elements of the first row are followed by a space, and
    /// elements of later rows are preceded by one. Matrices larger than 10 in either
    /// dimension get a `...` marker after the opening bracket, but every element is
    /// still written.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "[")?;
        if self.nrows > 10 || self.ncols > 10 {
            write!(f, "...")?;
        }
        for i in 0..self.nrows {
            // A line break separates consecutive rows (none after the last one).
            if i > 0 {
                writeln!(f)?;
            }
            for j in 0..self.ncols {
                let value = self.at(i, j);
                if i == 0 {
                    write!(f, "{:.4} ", value)?;
                } else {
                    write!(f, " {:.4}", value)?;
                }
            }
        }
        writeln!(f, "], dtype={}", std::any::type_name::<T>())
    }
}
impl<'a, T> Default for Matrix<'a, T>
where
    T: MatrixElement,
    <T as FromStr>::Err: Error + 'static,
    Vec<T>: IntoParallelIterator,
    Vec<&'a T>: IntoParallelRefIterator<'a>,
{
    /// A 3x3 matrix in which every element is `T::one()`.
    fn default() -> Self {
        let (nrows, ncols) = (3, 3);
        Self {
            nrows,
            ncols,
            data: vec![T::one(); nrows * ncols],
            _lifetime: PhantomData,
        }
    }
}
impl<'a, T> Matrix<'a, T>
where
    T: MatrixElement,
    <T as FromStr>::Err: Error + 'static,
    Vec<T>: IntoParallelIterator,
    Vec<&'a T>: IntoParallelRefIterator<'a>,
{
    /// Prints the matrix to stdout, one row per line, followed by `, dtype=<type name>`.
    ///
    /// `decimals` sets the precision used for floats; integer formatting ignores it.
    /// Elements of the first row are followed by a space, and elements of later rows are
    /// preceded by one. Matrices larger than 10 in either dimension get a `...` marker
    /// after `[`, but every element is still printed.
    pub fn print(&self, decimals: usize) {
        print!("[");
        if self.nrows > 10 || self.ncols > 10 {
            print!("...");
        }
        for i in 0..self.nrows {
            // Rows are separated by a newline; none follows the last row.
            if i > 0 {
                println!();
            }
            for j in 0..self.ncols {
                let val = self.at(i, j);
                if i == 0 {
                    print!("{val:.dec$} ", dec = decimals, val = val);
                } else {
                    print!(" {val:.dec$}", dec = decimals, val = val);
                }
            }
        }
        println!("], dtype={}", std::any::type_name::<T>());
    }

    /// Fraction of elements equal to `T::zero()`, in `[0, 1]`.
    ///
    /// The spelling is kept for API compatibility. An empty matrix gives `0.0 / 0.0`,
    /// i.e. `NaN`. The method takes `&self` rather than `&'a self`, so it no longer
    /// forces a borrow lasting the matrix's whole lifetime parameter.
    pub fn sparcity(&self) -> f64 {
        let zeros = self.data.par_iter().filter(|&&e| e == T::zero()).count();
        zeros as f64 / self.size() as f64
    }

    /// The matrix dimensions as `(nrows, ncols)`.
    pub fn shape(&self) -> Shape {
        (self.nrows, self.ncols)
    }
}
impl<'a, T> Matrix<'a, T>
where
T: MatrixElement,
<T as FromStr>::Err: Error,
Vec<T>: IntoParallelIterator,
Vec<&'a T>: IntoParallelRefIterator<'a>,
{
pub fn new(data: Vec<T>, shape: Shape) -> Result<Self, MatrixError> {
if shape.0 * shape.1 != data.len() {
return Err(MatrixError::MatrixCreationError.into());
}
Ok(Self {
data,
nrows: shape.0,
ncols: shape.1,
_lifetime: PhantomData::default(),
})
}
pub fn init(value: T, shape: Shape) -> Self {
Self::from_shape(value, shape)
}
pub fn eye(size: usize) -> Self {
let mut data: Vec<T> = vec![T::zero(); size * size];
(0..size).for_each(|i| data[i * size + i] = T::one());
Self::new(data, (size, size)).unwrap()
}
pub fn identity(size: usize) -> Self {
Self::eye(size)
}
pub fn from_slice(arr: &[T], shape: Shape) -> Result<Self, MatrixError> {
if shape.0 * shape.1 != arr.len() {
return Err(MatrixError::MatrixCreationError.into());
}
Ok(Self::new(arr.to_owned(), shape).unwrap())
}
pub fn zeros(shape: Shape) -> Self {
Self::from_shape(T::zero(), shape)
}
pub fn ones(shape: Shape) -> Self {
Self::from_shape(T::one(), shape)
}
pub fn zeros_like(other: &Self) -> Self {
Self::from_shape(T::zero(), other.shape())
}
pub fn ones_like(other: &Self) -> Self {
Self::from_shape(T::one(), other.shape())
}
pub fn random_like(matrix: &Self) -> Self {
Self::randomize_range(T::zero(), T::one(), matrix.shape())
}
pub fn randomize_range(start: T, end: T, shape: Shape) -> Self {
let mut rng = rand::thread_rng();
let (rows, cols) = shape;
let len: usize = rows * cols;
let data: Vec<T> = (0..len).map(|_| rng.gen_range(start..=end)).collect();
Self::new(data, shape).unwrap()
}
pub fn randomize(shape: Shape) -> Self {
Self::randomize_range(T::zero(), T::one(), shape)
}
pub fn from_file(path: &'static str) -> Result<Self, MatrixError> {
let data =
fs::read_to_string(path).map_err(|_| MatrixError::MatrixFileReadError(path).into())?;
data.parse::<Self>()
.map_err(|_| MatrixError::MatrixParseError.into())
}
fn from_shape(value: T, shape: Shape) -> Self {
let (rows, cols) = shape;
let len: usize = rows * cols;
let data = vec![value; len];
Self::new(data, shape).unwrap()
}
}
/// Chooses whether a matrix operation works along a row or along a column.
///
/// The discriminants are explicit: `Row = 0`, `Col = 1`. The type derives the usual
/// value traits, so it can be copied, compared, hashed and debug-printed.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Dimension {
    Row = 0,
    Col = 1,
}
impl<'a, T> Matrix<'a, T>
where
    T: MatrixElement + Div<Output = T> + Sum<T>,
    <T as FromStr>::Err: Error + 'static,
    Vec<T>: IntoParallelIterator,
    Vec<&'a T>: IntoParallelRefIterator<'a>,
{
    /// Reinterprets the row-major storage with a new `(rows, cols)` shape; no data moves.
    ///
    /// If `new_shape` does not hold exactly `self.size()` elements (including when
    /// `rows * cols` overflows), a message goes to stderr and the shape is unchanged.
    pub fn reshape(&mut self, new_shape: Shape) {
        if new_shape.0.checked_mul(new_shape.1) != Some(self.size()) {
            eprintln!("Err: Can not reshape.. Keeping old dimensions for now");
            return;
        }
        self.nrows = new_shape.0;
        self.ncols = new_shape.1;
    }

    /// Number of elements, `nrows * ncols`.
    pub fn size(&self) -> usize {
        self.nrows * self.ncols
    }

    /// Element at row `i`, column `j`, or `None` if either index is out of range.
    pub fn get(&self, i: usize, j: usize) -> Option<T> {
        // Check both coordinates. Checking only the flat offset would let a column index
        // >= ncols silently wrap around into the next row.
        if i < self.nrows && j < self.ncols {
            Some(self.at(i, j))
        } else {
            None
        }
    }

    /// Element at row `i`, column `j`, without checking each coordinate separately.
    ///
    /// # Panics
    ///
    /// Panics if the flat offset `i * ncols + j` is outside the storage.
    pub fn at(&self, i: usize, j: usize) -> T {
        self.data[at!(i, j, self.ncols)]
    }

    /// Copies the in-bounds cells of a rectangular window, in row-major order.
    ///
    /// `start_idx` is the `(row, col)` of the top-left corner. `size` is read as
    /// `(columns, rows)`: the window spans `size.1` rows and `size.0` columns. Cells
    /// outside the matrix are skipped, so the result can be shorter than
    /// `size.0 * size.1`.
    // NOTE(review): this `(cols, rows)` order is the reverse of the `Shape` convention
    // used by `get_sub_matrix`; it is kept as-is for compatibility with existing callers.
    pub fn get_vec_slice(&self, start_idx: Shape, size: Shape) -> Vec<T> {
        let (start_row, start_col) = start_idx;
        let (dx, dy) = size;
        iproduct!(start_row..start_row + dy, start_col..start_col + dx)
            .filter_map(|(i, j)| self.get(i, j))
            .collect()
    }

    /// A copy of the underlying row-major storage.
    pub fn get_vec(&self) -> Vec<T> {
        self.data.clone()
    }

    /// Copies the sub-matrix of shape `size = (rows, cols)` whose top-left corner is
    /// `start_idx = (row, col)`.
    ///
    /// # Errors
    ///
    /// Returns `MatrixError::MatrixIndexOutOfBoundsError` if the window does not fit
    /// inside the matrix.
    pub fn get_sub_matrix(&self, start_idx: Shape, size: Shape) -> Result<Self, MatrixError> {
        let (start_row, start_col) = start_idx;
        let (rows, cols) = size;
        let row_end = start_row.checked_add(rows).filter(|&end| end <= self.nrows);
        let col_end = start_col.checked_add(cols).filter(|&end| end <= self.ncols);
        match (row_end, col_end) {
            (Some(row_end), Some(col_end)) => {
                // `size.0` rows by `size.1` columns, so the data layout matches `size`.
                let data = iproduct!(start_row..row_end, start_col..col_end)
                    .map(|(i, j)| self.at(i, j))
                    .collect();
                Self::new(data, size).map_err(|_| MatrixError::MatrixIndexOutOfBoundsError.into())
            }
            _ => Err(MatrixError::MatrixIndexOutOfBoundsError.into()),
        }
    }

    /// Row-major data of `[self | other]`: each row of `self` followed by the same row of
    /// `other`. Callers must ensure `self.nrows == other.nrows`.
    fn hstack_data(&self, other: &Self) -> Vec<T> {
        let (lhs_cols, rhs_cols) = (self.ncols, other.ncols);
        let mut data = Vec::with_capacity(self.data.len() + other.data.len());
        for row in 0..self.nrows {
            data.extend_from_slice(&self.data[row * lhs_cols..(row + 1) * lhs_cols]);
            data.extend_from_slice(&other.data[row * rhs_cols..(row + 1) * rhs_cols]);
        }
        data
    }

    /// Concatenates into a new matrix: `Row` stacks `other` below `self`, `Col` places it
    /// to the right.
    ///
    /// # Errors
    ///
    /// Returns `MatrixError::MatrixConcatinationError` if the shared dimension differs
    /// (`ncols` for `Row`, `nrows` for `Col`).
    pub fn concat(&self, other: &Self, dim: Dimension) -> Result<Self, MatrixError> {
        match dim {
            Dimension::Row => {
                if self.ncols != other.ncols {
                    return Err(MatrixError::MatrixConcatinationError.into());
                }
                let mut data = self.data.clone();
                data.extend_from_slice(&other.data);
                Self::new(data, (self.nrows + other.nrows, self.ncols))
            }
            Dimension::Col => {
                if self.nrows != other.nrows {
                    return Err(MatrixError::MatrixConcatinationError.into());
                }
                Self::new(self.hstack_data(other), (self.nrows, self.ncols + other.ncols))
            }
        }
    }

    /// In-place version of [`Self::concat`].
    ///
    /// On a dimension mismatch, a message goes to stderr and `self` is left unchanged.
    pub fn extend(&mut self, other: &Self, dim: Dimension) {
        match dim {
            Dimension::Row => {
                if self.ncols != other.ncols {
                    eprintln!("Error: Dimension mismatch");
                    return;
                }
                self.data.extend_from_slice(&other.data);
                self.nrows += other.nrows;
            }
            Dimension::Col => {
                if self.nrows != other.nrows {
                    eprintln!("Error: Dimension mismatch");
                    return;
                }
                // The interleaved data must replace the old storage. Updating only `ncols`
                // would break the `data.len() == nrows * ncols` invariant.
                self.data = self.hstack_data(other);
                self.ncols += other.ncols;
            }
        }
    }

    /// Sets the element at `idx = (row, col)` to `value`.
    ///
    /// If either index is out of range, a message goes to stderr and nothing is written.
    pub fn set(&mut self, value: T, idx: Shape) {
        let (row, col) = idx;
        if row >= self.nrows || col >= self.ncols {
            eprintln!("Error: Index out of bounds. Not setting value.");
            return;
        }
        self.data[at!(row, col, self.ncols)] = value;
    }

    /// Calls [`Self::set`] with `value` for every position in `idx_list`.
    pub fn set_many(&mut self, idx_list: Vec<Shape>, value: T) {
        for idx in idx_list {
            self.set(value, idx);
        }
    }

    /// Sets the flat elements `start..=stop` (inclusive) to `value`; no-op if `start > stop`.
    ///
    /// # Panics
    ///
    /// Panics if `stop` is not a valid flat index when the range is non-empty.
    pub fn set_range(&mut self, start: usize, stop: usize, value: T) {
        (start..=stop).for_each(|i| self.data[i] = value);
    }

    /// Converts a flat row-major index into `(row, col)`.
    ///
    /// # Panics
    ///
    /// Panics on division by zero when `ncols == 0`.
    pub fn one_to_2d_idx(&self, idx: usize) -> Shape {
        (idx / self.ncols, idx % self.ncols)
    }

    /// Largest element.
    ///
    /// # Panics
    ///
    /// Panics if the matrix is empty or two elements are incomparable (e.g. `NaN`).
    pub fn max(&self) -> T {
        *self
            .data
            .par_iter()
            .max_by(|a, b| a.partial_cmp(b).unwrap())
            .unwrap()
    }

    /// Smallest element.
    ///
    /// # Panics
    ///
    /// Panics if the matrix is empty or two elements are incomparable (e.g. `NaN`).
    pub fn min(&self) -> T {
        *self
            .data
            .par_iter()
            .min_by(|a, b| a.partial_cmp(b).unwrap())
            .unwrap()
    }

    /// Flat indices of row `rowcol` (`Dimension::Row`) or column `rowcol`
    /// (`Dimension::Col`), or `None` if `rowcol` is out of range.
    fn lane_indices(
        &self,
        rowcol: usize,
        dimension: Dimension,
    ) -> Option<impl Iterator<Item = usize>> {
        let (start, step, len) = match dimension {
            Dimension::Row if rowcol < self.nrows => (rowcol * self.ncols, 1, self.ncols),
            Dimension::Col if rowcol < self.ncols => (rowcol, self.ncols, self.nrows),
            Dimension::Row | Dimension::Col => return None,
        };
        Some((0..len).map(move |k| start + k * step))
    }

    /// Position of the first element in the chosen lane that no other element is
    /// `better` than. Returns `None` for an out-of-range or empty lane.
    fn arg_extreme<F>(&self, rowcol: usize, dimension: Dimension, better: F) -> Option<Shape>
    where
        F: Fn(T, T) -> bool,
    {
        let best = self.lane_indices(rowcol, dimension)?.reduce(|best, idx| {
            if better(self.data[idx], self.data[best]) {
                idx
            } else {
                best
            }
        })?;
        Some(self.one_to_2d_idx(best))
    }

    /// `(row, col)` of the first maximum in row/column `rowcol`, or `None` if out of range.
    fn argmax(&self, rowcol: usize, dimension: Dimension) -> Option<Shape> {
        self.arg_extreme(rowcol, dimension, |candidate, best| candidate > best)
    }

    /// `(row, col)` of the first minimum in row/column `rowcol`, or `None` if out of range.
    fn argmin(&self, rowcol: usize, dimension: Dimension) -> Option<Shape> {
        self.arg_extreme(rowcol, dimension, |candidate, best| candidate < best)
    }

    /// Sum of all elements (a total, not a running sum despite the name); zero if empty.
    pub fn cumsum(&self) -> T {
        if self.size() == 0 {
            return T::zero();
        }
        self.data.par_iter().copied().sum()
    }

    /// Product of all elements (a total, not a running product); zero if empty.
    pub fn cumprod(&self) -> T {
        if self.size() == 0 {
            return T::zero();
        }
        self.data.par_iter().copied().product()
    }

    /// Arithmetic mean of all elements, using `T`'s division (truncating for integers).
    ///
    /// Returns zero if the matrix is empty.
    pub fn avg(&self) -> T {
        if self.size() == 0 {
            return T::zero();
        }
        // `T` has no conversion from `usize`, so the element count is accumulated as a `T`.
        let count: T = self.data.iter().map(|_| T::one()).sum();
        let total: T = self.data.par_iter().copied().sum();
        total / count
    }

    /// Alias of [`Self::avg`].
    pub fn mean(&self) -> T {
        self.avg()
    }

    /// Median of all elements; for an even count, the mean of the two middle values.
    ///
    /// Returns zero for an empty matrix.
    ///
    /// # Panics
    ///
    /// Panics if two elements are incomparable (e.g. `NaN`).
    pub fn median(&self) -> T {
        if self.data.is_empty() {
            return T::zero();
        }
        let mut sorted = self.data.clone();
        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
        let mid = sorted.len() / 2;
        if sorted.len() % 2 == 0 {
            (sorted[mid - 1] + sorted[mid]) / (T::one() + T::one())
        } else {
            sorted[mid]
        }
    }

    /// Sum of row `rowcol` (`Dimension::Row`) or column `rowcol` (`Dimension::Col`).
    ///
    /// Returns zero if `rowcol` is out of range.
    pub fn sum(&self, rowcol: usize, dimension: Dimension) -> T {
        match self.lane_indices(rowcol, dimension) {
            Some(indices) => indices.map(|idx| self.data[idx]).sum(),
            None => T::zero(),
        }
    }

    /// Product of row `rowcol` (`Dimension::Row`) or column `rowcol` (`Dimension::Col`).
    ///
    /// Returns one (the empty product) if `rowcol` is out of range.
    pub fn prod(&self, rowcol: usize, dimension: Dimension) -> T {
        match self.lane_indices(rowcol, dimension) {
            Some(indices) => indices.map(|idx| self.data[idx]).product(),
            None => T::one(),
        }
    }
}
impl<'a, T> Matrix<'a, T>
where
    T: MatrixElement,
    <T as FromStr>::Err: Error + 'static,
    Vec<T>: IntoParallelIterator,
    Vec<&'a T>: IntoParallelRefIterator<'a>,
{
    /// Applies `op` pairwise to the elements of `self` and `other`.
    ///
    /// Returns `MatrixError::MatrixDimensionMismatchError` if the shapes differ.
    fn zip_with<F>(&self, other: &Self, op: F) -> Result<Self, MatrixError>
    where
        F: Fn(T, T) -> T,
    {
        if self.shape() != other.shape() {
            return Err(MatrixError::MatrixDimensionMismatchError.into());
        }
        let data = self
            .data
            .iter()
            .zip(other.data.iter())
            .map(|(&x, &y)| op(x, y))
            .collect_vec();
        Self::new(data, self.shape())
    }

    /// Applies `op` to every element in parallel and returns a matrix of the same shape.
    fn map_elements<F>(&self, op: F) -> Self
    where
        F: Fn(T) -> T + Sync + Send,
    {
        let data: Vec<T> = self.data.par_iter().map(|&e| op(e)).collect();
        Self::new(data, self.shape()).expect("element-wise map preserves the element count")
    }

    /// Applies `op(self_elem, other_elem)` in place.
    ///
    /// On a shape mismatch, a message goes to stderr and `self` is left unchanged, instead
    /// of silently updating only a prefix of the elements.
    fn zip_assign<F>(&mut self, other: &Self, op: F)
    where
        F: Fn(&mut T, T) + Sync + Send,
    {
        if self.shape() != other.shape() {
            eprintln!("Error: Dimension mismatch");
            return;
        }
        self.data
            .par_iter_mut()
            .zip(&other.data)
            .for_each(|(a, &b)| op(a, b));
    }

    /// Element-wise `self + other`.
    ///
    /// # Errors
    ///
    /// `MatrixError::MatrixDimensionMismatchError` if the shapes differ.
    pub fn add(&self, other: &Self) -> Result<Self, MatrixError> {
        self.zip_with(other, |x, y| x + y)
    }

    /// Element-wise `self - other`.
    ///
    /// # Errors
    ///
    /// `MatrixError::MatrixDimensionMismatchError` if the shapes differ.
    pub fn sub(&self, other: &Self) -> Result<Self, MatrixError> {
        self.zip_with(other, |x, y| x - y)
    }

    /// Element-wise absolute difference `|self - other|`.
    ///
    /// # Errors
    ///
    /// `MatrixError::MatrixDimensionMismatchError` if the shapes differ.
    pub fn sub_abs(&self, other: &Self) -> Result<Self, MatrixError> {
        self.zip_with(other, |x, y| if x > y { x - y } else { y - x })
    }

    /// Element-wise (Hadamard) product.
    ///
    /// # Errors
    ///
    /// `MatrixError::MatrixDimensionMismatchError` if the shapes differ.
    pub fn mul(&self, other: &Self) -> Result<Self, MatrixError> {
        self.zip_with(other, |x, y| x * y)
    }

    /// Alias of [`Self::mul`]: an element-wise product, not a matrix product (see
    /// [`Self::matmul`]).
    pub fn dot(&self, other: &Self) -> Result<Self, MatrixError> {
        self.mul(other)
    }

    /// Element-wise `self / other`.
    ///
    /// # Errors
    ///
    /// - `MatrixError::MatrixDimensionMismatchError` if the shapes differ; this is checked
    ///   first.
    /// - `MatrixError::MatrixDivideByZeroError` if any element of `other` is zero.
    pub fn div(&self, other: &Self) -> Result<Self, MatrixError> {
        if self.shape() != other.shape() {
            return Err(MatrixError::MatrixDimensionMismatchError.into());
        }
        if other.any(|e| e == &T::zero()) {
            return Err(MatrixError::MatrixDivideByZeroError.into());
        }
        self.zip_with(other, |x, y| x / y)
    }

    /// Adds `val` to every element.
    pub fn add_val(&self, val: T) -> Self {
        self.map_elements(|e| e + val)
    }

    /// Subtracts `val` from every element.
    pub fn sub_val(&self, val: T) -> Self {
        self.map_elements(|e| e - val)
    }

    /// Multiplies every element by `val`.
    pub fn mul_val(&self, val: T) -> Self {
        self.map_elements(|e| e * val)
    }

    /// Divides every element by `val`.
    ///
    /// Integer division by zero panics.
    pub fn div_val(&self, val: T) -> Self {
        self.map_elements(|e| e / val)
    }

    fn log(&self, _base: T) -> Self {
        unimplemented!()
    }

    fn ln(&self) -> Self {
        unimplemented!()
    }

    fn tanh(&self) -> Self {
        unimplemented!()
    }

    fn sinh(&self) -> Self {
        unimplemented!()
    }

    fn cosh(&self) -> Self {
        unimplemented!()
    }

    /// Raises every element to the non-negative integer power `val`.
    pub fn pow(&self, val: usize) -> Self {
        self.map_elements(|e| pow(e, val))
    }

    /// Absolute value of every element.
    pub fn abs(&self) -> Self {
        self.map_elements(|e| abs(e))
    }

    /// In-place element-wise `self += other`; prints an error and does nothing on a
    /// shape mismatch.
    pub fn add_self(&mut self, other: &Self) {
        self.zip_assign(other, |a, b| *a += b);
    }

    /// In-place element-wise `self -= other`; prints an error and does nothing on a
    /// shape mismatch.
    pub fn sub_self(&mut self, other: &Self) {
        self.zip_assign(other, |a, b| *a -= b);
    }

    /// In-place element-wise `self *= other`; prints an error and does nothing on a
    /// shape mismatch.
    pub fn mul_self(&mut self, other: &Self) {
        self.zip_assign(other, |a, b| *a *= b);
    }

    /// In-place element-wise `self /= other`; prints an error and does nothing on a
    /// shape mismatch.
    ///
    /// Integer division by zero panics.
    pub fn div_self(&mut self, other: &Self) {
        self.zip_assign(other, |a, b| *a /= b);
    }

    /// Replaces every element with its absolute value.
    pub fn abs_self(&mut self) {
        self.data.par_iter_mut().for_each(|e| *e = abs(*e))
    }

    /// Adds `val` to every element in place.
    pub fn add_val_self(&mut self, val: T) {
        self.data.par_iter_mut().for_each(|e| *e += val);
    }

    /// Subtracts `val` from every element in place.
    pub fn sub_val_self(&mut self, val: T) {
        self.data.par_iter_mut().for_each(|e| *e -= val);
    }

    /// Multiplies every element by `val` in place.
    pub fn mul_val_self(&mut self, val: T) {
        self.data.par_iter_mut().for_each(|e| *e *= val);
    }

    /// Divides every element by `val` in place.
    pub fn div_val_self(&mut self, val: T) {
        self.data.par_iter_mut().for_each(|e| *e /= val);
    }

    /// Matrix product: an `r x n` `self` times an `n x c` `other` gives `r x c`.
    ///
    /// Output cells are computed in parallel.
    ///
    /// # Errors
    ///
    /// `MatrixError::MatrixDimensionMismatchError` if `self.ncols != other.nrows`.
    pub fn matmul(&self, other: &Self) -> Result<Self, MatrixError> {
        if self.ncols != other.nrows {
            return Err(MatrixError::MatrixDimensionMismatchError.into());
        }
        let (rows, inner, cols) = (self.nrows, self.ncols, other.ncols);
        // Transposing `other` makes each of its columns contiguous, so every output cell is
        // a dot product of two contiguous slices.
        let other_t = other.transpose_copy();
        let data: Vec<T> = (0..rows * cols)
            .into_par_iter()
            .map(|idx| {
                let (i, j) = (idx / cols, idx % cols);
                let lhs = &self.data[i * inner..(i + 1) * inner];
                let rhs = &other_t.data[j * inner..(j + 1) * inner];
                lhs.iter().zip(rhs).map(|(&a, &b)| a * b).sum::<T>()
            })
            .collect();
        // The result has `rows` rows and `cols` columns, in that order.
        Self::new(data, (rows, cols))
    }

    /// Transposes in place; works for any shape, not only square matrices.
    pub fn transpose(&mut self) {
        let (rows, cols) = (self.nrows, self.ncols);
        let mut transposed = Vec::with_capacity(self.data.len());
        // Row `j` of the result is column `j` of the original.
        for j in 0..cols {
            for i in 0..rows {
                transposed.push(self.data[at!(i, j, cols)]);
            }
        }
        self.data = transposed;
        swap(&mut self.nrows, &mut self.ncols);
    }

    /// Alias of [`Self::transpose`].
    pub fn t(&mut self) {
        self.transpose()
    }

    /// Returns a transposed copy, leaving `self` untouched.
    pub fn transpose_copy(&self) -> Self {
        let mut res = self.clone();
        res.transpose();
        res
    }

    fn eigenvalue(&self) -> T {
        todo!()
    }
}
impl<'a, T> Matrix<'a, T>
where
    T: MatrixElement,
    <T as FromStr>::Err: Error + 'static,
    Vec<T>: IntoParallelIterator,
    Vec<&'a T>: IntoParallelRefIterator<'a>,
{
    /// Number of elements for which `pred` returns `true`, evaluated in parallel.
    ///
    /// Takes `&self` rather than `&'a self`, so it does not force a borrow lasting the
    /// matrix's whole lifetime parameter.
    pub fn count_where<F>(&self, pred: F) -> usize
    where
        F: Fn(&T) -> bool + Sync,
    {
        self.data.par_iter().filter(|&e| pred(e)).count()
    }

    /// Sum of the elements for which `pred` returns `true`, evaluated in parallel.
    pub fn sum_where<F>(&self, pred: F) -> T
    where
        F: Fn(&T) -> bool + Sync,
    {
        self.data
            .par_iter()
            .filter(|&e| pred(e))
            .copied()
            .sum::<T>()
    }

    /// Calls `pred` with a mutable reference to every element, sequentially and in
    /// row-major order.
    ///
    /// Despite the name, nothing is filtered: the closure decides whether to modify each
    /// element. No `Sync`/`Send` bound is required because the iteration is sequential.
    pub fn set_where<P>(&mut self, pred: P)
    where
        P: FnMut(&mut T),
    {
        self.data.iter_mut().for_each(pred);
    }

    /// `true` if `pred` holds for at least one element (parallel, short-circuiting).
    pub fn any<F>(&self, pred: F) -> bool
    where
        F: Fn(&T) -> bool + Sync + Send,
    {
        self.data.par_iter().any(pred)
    }

    /// `true` if `pred` holds for every element (parallel, short-circuiting).
    pub fn all<F>(&self, pred: F) -> bool
    where
        F: Fn(&T) -> bool + Sync + Send,
    {
        self.data.par_iter().all(pred)
    }

    /// `(row, col)` of the first element in row-major order satisfying `pred`, if any.
    pub fn find<F>(&self, pred: F) -> Option<Shape>
    where
        F: Fn(&T) -> bool + Sync,
    {
        self.data
            .iter()
            .position(|e| pred(e))
            .map(|idx| self.one_to_2d_idx(idx))
    }

    /// `(row, col)` of every element satisfying `pred`, in row-major order.
    ///
    /// Returns `None` when there are no matches.
    pub fn find_all<F>(&self, pred: F) -> Option<Vec<Shape>>
    where
        F: Fn(&T) -> bool + Sync,
    {
        let positions: Vec<Shape> = self
            .data
            .par_iter()
            .enumerate()
            .filter_map(|(idx, elem)| {
                if pred(elem) {
                    Some(self.one_to_2d_idx(idx))
                } else {
                    None
                }
            })
            .collect();
        if positions.is_empty() {
            None
        } else {
            Some(positions)
        }
    }
}