use serde::{Deserialize, Serialize};
use std::{
error::Error,
fmt::{Debug, Display},
fs,
marker::PhantomData,
ops::Div,
str::FromStr,
};
use itertools::{iproduct, Itertools};
use num_traits::{
pow,
sign::{abs, Signed},
Num, NumAssign, NumAssignOps, NumAssignRef, NumOps, One, Zero,
};
use rand::{distributions::uniform::SampleUniform, Rng};
use rayon::prelude::*;
use std::iter::{Product, Sum};
/// Matrix dimensions as `(rows, cols)`.
pub type Shape = (usize, usize);
/// Exchanges the values behind `lhs` and `rhs`.
fn swap(lhs: &mut usize, rhs: &mut usize) {
    std::mem::swap(lhs, rhs);
}
// Flat row-major index of element (`row`, `col`) in a matrix that is
// `width` columns wide: `row * width + col`.
macro_rules! at {
    ($r:ident, $c:ident, $width:expr) => {
        ($r * $width) + $c
    };
}
/// Dense matrix stored in row-major order: element `(i, j)` lives at
/// `data[i * ncols + j]`.
///
/// `shape`, `nrows` and `ncols` are redundant copies of the dimensions and are
/// expected to agree (`shape == (nrows, ncols)`); constructors set all three.
#[derive(Clone, PartialEq, PartialOrd, Debug, Serialize, Deserialize)]
pub struct Matrix<'a, T>
where
    T: MatrixElement,
    <T as FromStr>::Err: Error + 'static,
    Vec<T>: IntoParallelIterator,
    Vec<&'a T>: IntoParallelRefIterator<'a>,
{
    /// Elements in row-major order; expected length is `nrows * ncols`.
    pub data: Vec<T>,
    /// Dimensions as `(rows, cols)`.
    pub shape: Shape,
    /// Number of rows.
    pub nrows: usize,
    /// Number of columns; also the row stride into `data`.
    pub ncols: usize,
    // Only carries the `'a` lifetime parameter; no field stores a reference.
    _lifetime: PhantomData<&'a T>,
}
/// Scalar types that can be stored in a [`Matrix`].
///
/// Bundles the numeric (`Num`, `Signed`, assign-ops), formatting (`Display`,
/// `Debug`), parsing (`FromStr`), aggregation (`Sum`, `Product`), threading
/// (`Send`, `Sync`) and random-sampling (`SampleUniform`) capabilities the
/// matrix code relies on. In this file it is implemented for `i8`..`i128`,
/// `f32` and `f64`.
///
/// Several bounds are implied by others (e.g. `Copy: Clone`,
/// `Num: Zero + One + NumOps + PartialEq`); they are spelled out for readability.
pub trait MatrixElement:
    Copy
    + Clone
    + PartialOrd
    + Signed
    + Sum
    + Product
    + Display
    + Debug
    + FromStr
    + Default
    + One
    + PartialEq
    + Zero
    + Send
    + Sync
    + Sized
    + Num
    + NumOps
    + NumAssignOps
    + NumAssignRef
    + NumAssign
    + SampleUniform
{
}
// SAFETY: every field (`Vec<T>`, `Shape`, `usize`, `PhantomData<&'a T>`) is
// `Send` whenever `T: Send + Sync`, and `MatrixElement` requires both.
// NOTE(review): for that reason the compiler's auto-derived `Send` should
// already apply, so this `unsafe impl` looks redundant. Consider removing it.
unsafe impl<'a, T> Send for Matrix<'a, T>
where
    T: MatrixElement,
    <T as FromStr>::Err: Error + 'static,
    Vec<T>: IntoParallelIterator,
    Vec<&'a T>: IntoParallelRefIterator<'a>,
{
}
// SAFETY: every field (`Vec<T>`, `Shape`, `usize`, `PhantomData<&'a T>`) is
// `Sync` whenever `T: Sync`, which `MatrixElement` requires.
// NOTE(review): for that reason the compiler's auto-derived `Sync` should
// already apply, so this `unsafe impl` looks redundant. Consider removing it.
unsafe impl<'a, T> Sync for Matrix<'a, T>
where
    T: MatrixElement,
    <T as FromStr>::Err: Error + 'static,
    Vec<T>: IntoParallelIterator,
    Vec<&'a T>: IntoParallelRefIterator<'a>,
{
}
impl MatrixElement for i8 {}
impl MatrixElement for i16 {}
impl MatrixElement for i32 {}
impl MatrixElement for i64 {}
impl MatrixElement for i128 {}
impl MatrixElement for f32 {}
impl MatrixElement for f64 {}
impl<'a, T> FromStr for Matrix<'a, T>
where
    T: MatrixElement,
    <T as FromStr>::Err: Error + 'static,
    Vec<T>: IntoParallelIterator,
    Vec<&'a T>: IntoParallelRefIterator<'a>,
{
    type Err = anyhow::Error;

    /// Parses a matrix written as one row per line, elements separated by
    /// whitespace. Leading/trailing whitespace of the whole input is ignored.
    ///
    /// # Errors
    /// Returns an error if the input is empty, if any token fails to parse as
    /// `T`, or if the rows do not all have the same number of elements.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let mut data: Vec<T> = Vec::new();
        let mut rows = 0usize;
        // Width of the first row; every later row must match it.
        let mut cols: Option<usize> = None;

        for (row, line) in s.trim().lines().enumerate() {
            let before = data.len();
            for token in line.split_whitespace() {
                // `T::Err` is not guaranteed `Send + Sync`, so it cannot be
                // converted with `?`; format it into the anyhow error instead.
                let value = token.parse::<T>().map_err(|e| {
                    anyhow::anyhow!("row {}: failed to parse {:?}: {}", row + 1, token, e)
                })?;
                data.push(value);
            }
            let width = data.len() - before;
            match cols {
                None => cols = Some(width),
                Some(expected) if expected != width => {
                    return Err(anyhow::anyhow!(
                        "row {}: expected {} elements, found {}",
                        row + 1,
                        expected,
                        width
                    ));
                }
                Some(_) => {}
            }
            rows += 1;
        }

        let cols =
            cols.ok_or_else(|| anyhow::anyhow!("cannot parse a matrix from empty input"))?;
        Ok(Self::new(data, (rows, cols)))
    }
}
impl<'a, T> Display for Matrix<'a, T>
where
    T: MatrixElement,
    <T as FromStr>::Err: Error + 'static,
    Vec<T>: IntoParallelIterator,
    Vec<&'a T>: IntoParallelRefIterator<'a>,
{
    /// Writes the matrix row by row with 4 decimal places, wrapped in `[...]`,
    /// followed by `dtype=<type name>` and a newline.
    ///
    /// A `...` marker is emitted after `[` when either dimension exceeds 10, but
    /// every element is still written.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "[")?;
        if self.nrows > 10 || self.ncols > 10 {
            write!(f, "...")?;
        }
        for i in 0..self.nrows {
            for j in 0..self.ncols {
                // The first row sits right after `[`; later rows get a leading
                // space so the columns line up under it.
                if i == 0 {
                    write!(f, "{:.4} ", self.get(i, j))?;
                } else {
                    write!(f, " {:.4}", self.get(i, j))?;
                }
            }
            // Newline between rows only; the last row is followed by `]`.
            if i + 1 < self.nrows {
                writeln!(f)?;
            }
        }
        writeln!(f, "], dtype={}", std::any::type_name::<T>())
    }
}
impl<'a, T> Default for Matrix<'a, T>
where
    T: MatrixElement,
    <T as FromStr>::Err: Error + 'static,
    Vec<T>: IntoParallelIterator,
    Vec<&'a T>: IntoParallelRefIterator<'a>,
{
    /// Returns a 3x3 matrix whose every element is `T::one()`.
    ///
    /// `Matrix::new` also falls back to this value when its data and shape disagree.
    fn default() -> Self {
        const SIDE: usize = 3;
        Self {
            nrows: SIDE,
            ncols: SIDE,
            shape: (SIDE, SIDE),
            data: vec![T::one(); SIDE * SIDE],
            _lifetime: PhantomData,
        }
    }
}
impl<'a, T> Matrix<'a, T>
where
    T: MatrixElement,
    <T as FromStr>::Err: Error + 'static,
    Vec<T>: IntoParallelIterator,
    Vec<&'a T>: IntoParallelRefIterator<'a>,
{
    /// Prints the matrix to stdout in the same layout as `Display`, but with
    /// `decimals` digits after the decimal point.
    ///
    /// A `...` marker is printed after `[` when either dimension exceeds 10, but
    /// every element is still printed.
    pub fn print(&self, decimals: usize) {
        print!("[");
        if self.nrows > 10 || self.ncols > 10 {
            print!("...");
        }
        for i in 0..self.nrows {
            for j in 0..self.ncols {
                // The first row sits right after `[`; later rows get a leading
                // space so the columns line up under it.
                if i == 0 {
                    print!("{val:.dec$} ", dec = decimals, val = self.get(i, j));
                } else {
                    print!(" {val:.dec$}", dec = decimals, val = self.get(i, j));
                }
            }
            // Newline between rows only; compare against `nrows` (what the loop
            // iterates) so a stale `shape` can neither skip it nor underflow.
            if i + 1 < self.nrows {
                println!();
            }
        }
        println!("], dtype={}", std::any::type_name::<T>());
    }
}
impl<'a, T> Matrix<'a, T>
where
    T: MatrixElement,
    <T as FromStr>::Err: Error,
    Vec<T>: IntoParallelIterator,
    Vec<&'a T>: IntoParallelRefIterator<'a>,
{
    /// Builds a matrix from row-major `data` with the given `(rows, cols)` shape.
    ///
    /// If `shape.0 * shape.1 != data.len()` this does not fail. It silently
    /// returns `Self::default()` (a 3x3 matrix of ones) instead. Use
    /// [`Matrix::from_slice`] when a mismatch must be detected.
    pub fn new(data: Vec<T>, shape: Shape) -> Self {
        if shape.0 * shape.1 != data.len() {
            return Self::default();
        }
        Self {
            data,
            shape,
            nrows: shape.0,
            ncols: shape.1,
            _lifetime: PhantomData,
        }
    }

    /// Creates a `shape`-sized matrix with every element set to `value`.
    pub fn init(value: T, shape: Shape) -> Self {
        Self::from_shape(value, shape)
    }

    /// Creates the `size x size` identity matrix.
    pub fn eye(size: usize) -> Self {
        let mut data: Vec<T> = vec![T::zero(); size * size];
        // In row-major storage consecutive diagonal entries are `size + 1` apart.
        data.iter_mut()
            .step_by(size + 1)
            .for_each(|d| *d = T::one());
        Self::new(data, (size, size))
    }

    /// Alias for [`Matrix::eye`].
    pub fn identity(size: usize) -> Self {
        Self::eye(size)
    }

    /// Copies `arr` (row-major) into a new matrix of the given shape.
    ///
    /// Returns `None` if `arr` does not hold exactly `shape.0 * shape.1` elements.
    pub fn from_slice(arr: &[T], shape: Shape) -> Option<Self> {
        (shape.0 * shape.1 == arr.len()).then(|| Self::new(arr.to_vec(), shape))
    }

    /// Creates a `shape`-sized matrix of zeros.
    pub fn zeros(shape: Shape) -> Self {
        Self::from_shape(T::zero(), shape)
    }

    /// Creates a `shape`-sized matrix of ones.
    pub fn ones(shape: Shape) -> Self {
        Self::from_shape(T::one(), shape)
    }

    /// Creates a matrix of zeros with the same `shape` as `other`.
    pub fn zeros_like(other: &Self) -> Self {
        Self::from_shape(T::zero(), other.shape)
    }

    /// Creates a matrix of ones with the same `shape` as `other`.
    pub fn ones_like(other: &Self) -> Self {
        Self::from_shape(T::one(), other.shape)
    }

    /// Random matrix with the same `shape` as `matrix`; see [`Matrix::randomize`].
    pub fn random_like(matrix: &Self) -> Self {
        Self::randomize(matrix.shape)
    }

    /// Creates a `shape`-sized matrix with elements drawn uniformly from the
    /// inclusive range `start..=end`.
    ///
    /// # Panics
    /// `rand`'s `gen_range` panics on an empty range, i.e. when `start > end`.
    pub fn randomize_range(start: T, end: T, shape: Shape) -> Self {
        let mut rng = rand::thread_rng();
        let len = shape.0 * shape.1;
        let data: Vec<T> = (0..len).map(|_| rng.gen_range(start..=end)).collect();
        Self::new(data, shape)
    }

    /// Creates a `shape`-sized matrix with elements drawn uniformly from the
    /// inclusive range `0..=1`. For integer `T` this yields only 0s and 1s.
    pub fn randomize(shape: Shape) -> Self {
        Self::randomize_range(T::zero(), T::one(), shape)
    }

    /// Reads a matrix from the text file at `path`. The text is parsed with
    /// `str::parse::<Self>()` (whitespace-separated values, one row per line).
    ///
    /// Falls back to `Self::default()` if the contents fail to parse.
    ///
    /// # Panics
    /// Panics if the file cannot be read.
    pub fn from_file(path: &str) -> Self {
        let data =
            fs::read_to_string(path).unwrap_or_else(|_| panic!("Failed to read file: {}", path));
        data.parse::<Self>().unwrap_or_else(|_| Self::default())
    }

    /// Creates a `shape`-sized matrix filled with `value`.
    fn from_shape(value: T, shape: Shape) -> Self {
        let (rows, cols) = shape;
        Self::new(vec![value; rows * cols], shape)
    }
}
/// Selects whether a per-line reduction runs over a single row or a single column.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Dimension {
    /// Operate on one row.
    Row = 0,
    /// Operate on one column.
    Col = 1,
}
impl<'a, T> Matrix<'a, T>
where
    T: MatrixElement + Div<Output = T> + Sum<T>,
    <T as FromStr>::Err: Error + 'static,
    Vec<T>: IntoParallelIterator,
    Vec<&'a T>: IntoParallelRefIterator<'a>,
{
    /// Placeholder for a tensor conversion. Not implemented yet; always panics.
    pub fn to_tensor(&self) {
        todo!()
    }

    /// Reinterprets the row-major data with `new_shape = (rows, cols)`.
    ///
    /// If `new_shape` does not hold exactly `self.size()` elements, a message is
    /// printed and the matrix is left unchanged.
    pub fn reshape(&mut self, new_shape: Shape) {
        if new_shape.0 * new_shape.1 != self.size() {
            println!("Can not reshape.. Keeping old dimensions for now");
            return;
        }
        // Indexing (`get`, `set`, `inverse_at`, ...) is driven by
        // `nrows`/`ncols`, so they must change together with `shape`.
        self.shape = new_shape;
        self.nrows = new_shape.0;
        self.ncols = new_shape.1;
    }

    /// Number of elements, `nrows * ncols`.
    pub fn size(&self) -> usize {
        self.nrows * self.ncols
    }

    /// Returns the element at row `i`, column `j`.
    ///
    /// # Panics
    /// Panics if the flat index `i * ncols + j` is outside `data`. `j` is not
    /// checked against `ncols` on its own.
    pub fn get(&self, i: usize, j: usize) -> T {
        self.data[at!(i, j, self.ncols)]
    }

    /// Copies the window of `size = (rows, cols)` elements whose top-left corner
    /// is `start_idx = (row, col)` into a row-major `Vec`.
    ///
    /// # Panics
    /// Panics (via [`Matrix::get`]) if the window reaches outside the data.
    pub fn get_vec_slice(&self, start_idx: Shape, size: Shape) -> Vec<T> {
        let (start_row, start_col) = start_idx;
        let (rows, cols) = size;
        iproduct!(start_row..start_row + rows, start_col..start_col + cols)
            .map(|(i, j)| self.get(i, j))
            .collect()
    }

    /// Returns the `size = (rows, cols)` sub-matrix whose top-left corner is
    /// `start_idx = (row, col)`.
    ///
    /// # Panics
    /// Panics (via [`Matrix::get`]) if the window reaches outside the data.
    pub fn get_sub_matrix(&self, start_idx: Shape, size: Shape) -> Self {
        // The window is gathered with the same `(rows, cols)` interpretation of
        // `size` that labels the result, so non-square windows are not transposed.
        Self::new(self.get_vec_slice(start_idx, size), size)
    }

    /// Overwrites the element at row `i`, column `j`.
    ///
    /// # Panics
    /// Panics if the flat index `i * ncols + j` is outside `data`.
    pub fn set(&mut self, i: usize, j: usize, value: T) {
        self.data[at!(i, j, self.ncols)] = value;
    }

    /// Converts a flat row-major index back into `(row, col)`.
    ///
    /// # Panics
    /// Panics if `ncols == 0` (division by zero).
    pub fn inverse_at(&self, idx: usize) -> Shape {
        (idx / self.ncols, idx % self.ncols)
    }

    /// Largest element.
    ///
    /// # Panics
    /// Panics if the matrix is empty or contains values `partial_cmp` cannot
    /// order (e.g. NaN).
    pub fn max(&self) -> T {
        *self
            .data
            .par_iter()
            .max_by(|a, b| a.partial_cmp(b).unwrap())
            .unwrap()
    }

    /// Smallest element.
    ///
    /// # Panics
    /// Panics if the matrix is empty or contains values `partial_cmp` cannot
    /// order (e.g. NaN).
    pub fn min(&self) -> T {
        *self
            .data
            .par_iter()
            .min_by(|a, b| a.partial_cmp(b).unwrap())
            .unwrap()
    }

    /// Largest value in row (`Dimension::Row`) or column (`Dimension::Col`)
    /// number `rowcol`.
    ///
    /// Despite the name this returns the maximum *value*, not its index.
    /// Returns `None` if `rowcol` is out of range.
    pub fn argmax(&self, rowcol: usize, dimension: Dimension) -> Option<T> {
        match dimension {
            Dimension::Row => {
                if rowcol >= self.nrows {
                    return None;
                }
                self.data
                    .par_iter()
                    .skip(rowcol * self.ncols)
                    .take(self.ncols)
                    .max_by(|a, b| a.partial_cmp(b).unwrap())
                    .copied()
            }
            Dimension::Col => {
                if rowcol >= self.ncols {
                    return None;
                }
                self.data
                    .par_iter()
                    .skip(rowcol)
                    .step_by(self.ncols)
                    .max_by(|a, b| a.partial_cmp(b).unwrap())
                    .copied()
            }
        }
    }

    /// Smallest value in row (`Dimension::Row`) or column (`Dimension::Col`)
    /// number `rowcol`.
    ///
    /// Despite the name this returns the minimum *value*, not its index.
    /// Returns `None` if `rowcol` is out of range.
    pub fn argmin(&self, rowcol: usize, dimension: Dimension) -> Option<T> {
        match dimension {
            Dimension::Row => {
                if rowcol >= self.nrows {
                    return None;
                }
                self.data
                    .par_iter()
                    .skip(rowcol * self.ncols)
                    .take(self.ncols)
                    .min_by(|a, b| a.partial_cmp(b).unwrap())
                    .copied()
            }
            Dimension::Col => {
                if rowcol >= self.ncols {
                    return None;
                }
                self.data
                    .par_iter()
                    .skip(rowcol)
                    .step_by(self.ncols)
                    .min_by(|a, b| a.partial_cmp(b).unwrap())
                    .copied()
            }
        }
    }

    /// Sum of all elements. Despite the name this is a single total, not a
    /// running (cumulative) sum.
    pub fn cumsum(&self) -> T {
        self.data.par_iter().copied().sum()
    }

    /// Product of all elements. Despite the name this is a single total, not a
    /// running (cumulative) product.
    pub fn cumprod(&self) -> T {
        self.data.par_iter().copied().product()
    }

    /// Arithmetic mean of all elements. For integer `T` the division truncates.
    ///
    /// An empty matrix divides by zero: NaN for floats, a panic for integers.
    pub fn avg(&self) -> T {
        // `T` has no generic conversion from `usize`, so the element count is
        // accumulated as one `T::one()` per element.
        let count: T = self.data.iter().map(|_| T::one()).sum();
        let total: T = self.data.par_iter().copied().sum();
        total / count
    }

    /// Alias for [`Matrix::avg`].
    pub fn mean(&self) -> T {
        self.avg()
    }

    /// Median of all elements. With an even count it is the mean of the two
    /// middle values (truncated for integer `T`).
    ///
    /// # Panics
    /// Panics if the matrix is empty or contains values `partial_cmp` cannot
    /// order (e.g. NaN).
    pub fn median(&self) -> T {
        assert!(
            !self.data.is_empty(),
            "median of an empty matrix is undefined"
        );
        let sorted = self
            .data
            .iter()
            .copied()
            .sorted_by(|a, b| a.partial_cmp(b).unwrap())
            .collect_vec();
        let half = sorted.len() / 2;
        if sorted.len() % 2 == 0 {
            (sorted[half - 1] + sorted[half]) / (T::one() + T::one())
        } else {
            sorted[half]
        }
    }

    /// Sum of row (`Dimension::Row`) or column (`Dimension::Col`) number
    /// `rowcol`. An out-of-range `rowcol` yields `T::zero()`.
    pub fn sum(&self, rowcol: usize, dimension: Dimension) -> T {
        match dimension {
            Dimension::Row => self
                .data
                .par_iter()
                .skip(rowcol * self.ncols)
                .take(self.ncols)
                .copied()
                .sum(),
            Dimension::Col => self
                .data
                .par_iter()
                .skip(rowcol)
                .step_by(self.ncols)
                .copied()
                .sum(),
        }
    }

    /// Product of row (`Dimension::Row`) or column (`Dimension::Col`) number
    /// `rowcol`. An out-of-range `rowcol` yields `T::one()`.
    pub fn prod(&self, rowcol: usize, dimension: Dimension) -> T {
        match dimension {
            Dimension::Row => self
                .data
                .par_iter()
                .skip(rowcol * self.ncols)
                .take(self.ncols)
                .copied()
                .product(),
            Dimension::Col => self
                .data
                .par_iter()
                .skip(rowcol)
                .step_by(self.ncols)
                .copied()
                .product(),
        }
    }
}
/// Element-wise arithmetic, scalar arithmetic and linear-algebra operations.
///
/// Methods taking `&self` and returning `Self` leave the receiver untouched.
/// The `*_self` variants and `transpose`/`t` take `&mut self` and modify it in place.
pub trait MatrixLinAlg<'a, T>
where
    T: MatrixElement,
{
    /// Element-wise `self + other`.
    fn add(&self, other: &Self) -> Self;
    /// Element-wise `self - other`.
    fn sub(&self, other: &Self) -> Self;
    /// Element-wise absolute difference `|self - other|`.
    fn sub_abs(&self, other: &Self) -> Self;
    /// Element-wise (Hadamard) product; see `matmul` for the matrix product.
    fn mul(&self, other: &Self) -> Self;
    /// Element-wise `self / other`.
    fn div(&self, other: &Self) -> Self;
    /// Adds `val` to every element.
    fn add_val(&self, val: T) -> Self;
    /// Subtracts `val` from every element.
    fn sub_val(&self, val: T) -> Self;
    /// Multiplies every element by `val`.
    fn mul_val(&self, val: T) -> Self;
    /// Divides every element by `val`.
    fn div_val(&self, val: T) -> Self;
    /// Element-wise logarithm in `base` (the `Matrix` impl is currently `unimplemented!()`).
    fn log(&self, base: T) -> Self;
    /// Element-wise natural logarithm (the `Matrix` impl is currently `unimplemented!()`).
    fn ln(&self) -> Self;
    /// Element-wise hyperbolic tangent (the `Matrix` impl is currently `unimplemented!()`).
    fn tanh(&self) -> Self;
    /// Raises every element to the integer power `val`.
    fn pow(&self, val: usize) -> Self;
    /// Element-wise absolute value.
    fn abs(&self) -> Self;
    /// In-place element-wise `self += other`.
    fn add_self(&mut self, other: &Self);
    /// In-place element-wise `self -= other`.
    fn sub_self(&mut self, other: &Self);
    /// In-place element-wise `self *= other`.
    fn mul_self(&mut self, other: &Self);
    /// In-place element-wise `self /= other`.
    fn div_self(&mut self, other: &Self);
    /// In-place element-wise absolute value.
    fn abs_self(&mut self);
    /// Adds `val` to every element in place.
    fn add_val_self(&mut self, val: T);
    /// Subtracts `val` from every element in place.
    fn sub_val_self(&mut self, val: T);
    /// Multiplies every element by `val` in place.
    fn mul_val_self(&mut self, val: T);
    /// Divides every element by `val` in place.
    fn div_val_self(&mut self, val: T);
    /// Matrix product `self · other`.
    fn matmul(&self, other: &Self) -> Self;
    /// Transposes the matrix in place.
    fn transpose(&mut self);
    /// Short alias for `transpose`.
    fn t(&mut self);
    /// Returns a transposed copy, leaving `self` unchanged.
    fn transpose_copy(&self) -> Self;
    /// Eigenvalue computation (the `Matrix` impl is currently `todo!()`).
    fn eigenvalue(&self) -> T;
}
impl<'a, T> MatrixLinAlg<'a, T> for Matrix<'a, T>
where
T: MatrixElement,
<T as FromStr>::Err: Error + 'static,
Vec<T>: IntoParallelIterator,
Vec<&'a T>: IntoParallelRefIterator<'a>,
{
fn add(&self, other: &Self) -> Self {
if self.nrows != other.nrows || self.ncols != other.ncols {
panic!("NOOO!");
}
let data = (0..self.nrows)
.flat_map(|i| (0..self.ncols).map(move |j| self.get(i, j) + other.get(i, j)))
.collect_vec();
Self::new(data, self.shape)
}
fn sub(&self, other: &Self) -> Self {
if self.nrows != other.nrows || self.ncols != other.ncols {
panic!("NOOO!");
}
let data = (0..self.nrows)
.flat_map(|i| (0..self.ncols).map(move |j| self.get(i, j) - other.get(i, j)))
.collect_vec();
Self::new(data, self.shape)
}
fn sub_abs(&self, other: &Self) -> Self {
if self.nrows != other.nrows || self.ncols != other.ncols {
panic!("NOOO!");
}
let data = (0..self.nrows)
.flat_map(|i| {
(0..self.ncols).map(move |j| {
let a = self.get(i, j);
let b = other.get(i, j);
if a > b {
a - b
} else {
b - a
}
})
})
.collect_vec();
Self::new(data, self.shape)
}
fn mul(&self, other: &Self) -> Self {
if self.nrows != other.nrows || self.ncols != other.ncols {
panic!("NOOO!");
}
let data = (0..self.nrows)
.flat_map(|i| (0..self.ncols).map(move |j| self.get(i, j) * other.get(i, j)))
.collect_vec();
Self::new(data, self.shape)
}
fn div(&self, other: &Self) -> Self {
if self.nrows != other.nrows || self.ncols != other.ncols {
panic!("NOOO!");
}
if other.any(|e| e == &T::zero()) {
panic!("NOOOOO")
}
let data = (0..self.nrows)
.flat_map(|i| (0..self.ncols).map(move |j| self.get(i, j) / other.get(i, j)))
.collect_vec();
Self::new(data, self.shape)
}
fn add_val(&self, val: T) -> Self {
let data: Vec<T> = self.data.par_iter().map(|&e| e + val).collect();
Self::new(data, self.shape)
}
fn sub_val(&self, val: T) -> Self {
let data: Vec<T> = self.data.par_iter().map(|&e| e - val).collect();
Self::new(data, self.shape)
}
fn mul_val(&self, val: T) -> Self {
let data: Vec<T> = self.data.par_iter().map(|&e| e * val).collect();
Self::new(data, self.shape)
}
fn div_val(&self, val: T) -> Self {
let data: Vec<T> = self.data.par_iter().map(|&e| e / val).collect();
Self::new(data, self.shape)
}
fn log(&self, base: T) -> Self {
unimplemented!()
}
fn ln(&self) -> Self {
unimplemented!()
}
fn tanh(&self) -> Self {
unimplemented!()
}
fn pow(&self, val: usize) -> Self {
let data: Vec<T> = self.data.par_iter().map(|&e| pow(e, val)).collect();
Self::new(data, self.shape)
}
fn abs(&self) -> Self {
let data: Vec<T> = self.data.par_iter().map(|&e| abs(e)).collect();
Self::new(data, self.shape)
}
fn add_self(&mut self, other: &Self) {
self.data
.par_iter_mut()
.zip(&other.data)
.for_each(|(a, b)| *a += *b);
}
fn sub_self(&mut self, other: &Self) {
self.data
.par_iter_mut()
.zip(&other.data)
.for_each(|(a, b)| *a -= *b);
}
fn mul_self(&mut self, other: &Self) {
self.data
.par_iter_mut()
.zip(&other.data)
.for_each(|(a, b)| *a *= *b);
}
fn div_self(&mut self, other: &Self) {
self.data
.par_iter_mut()
.zip(&other.data)
.for_each(|(a, b)| *a /= *b);
}
fn abs_self(&mut self) {
self.data.par_iter_mut().for_each(|e| *e = abs(*e))
}
fn add_val_self(&mut self, val: T) {
self.data.par_iter_mut().for_each(|e| *e += val);
}
fn sub_val_self(&mut self, val: T) {
self.data.par_iter_mut().for_each(|e| *e -= val);
}
fn mul_val_self(&mut self, val: T) {
self.data.par_iter_mut().for_each(|e| *e *= val);
}
fn div_val_self(&mut self, val: T) {
self.data.par_iter_mut().for_each(|e| *e /= val);
}
fn matmul(&self, other: &Self) -> Self {
if self.ncols != other.nrows {
panic!("Oops, dimensions do not match");
}
let r1 = self.nrows;
let c1 = self.ncols;
let c2 = other.ncols;
let mut data = vec![T::zero(); c2 * r1];
let t_other = other.transpose_copy();
for i in 0..r1 {
for j in 0..c2 {
data[at!(i, j, c2)] = (0..c1)
.into_par_iter()
.map(|k| self.data[at!(i, k, c1)] * t_other.data[at!(j, k, t_other.ncols)])
.sum();
}
}
Self::new(data, (c2, r1))
}
fn transpose(&mut self) {
for i in 0..self.nrows {
for j in (i + 1)..self.ncols {
let lhs = at!(i, j, self.ncols);
let rhs = at!(j, i, self.nrows);
self.data.swap(lhs, rhs);
}
}
swap(&mut self.shape.0, &mut self.shape.1);
swap(&mut self.nrows, &mut self.ncols);
}
fn t(&mut self) {
self.transpose()
}
fn transpose_copy(&self) -> Self {
let mut res = self.clone();
res.transpose();
res
}
fn eigenvalue(&self) -> T {
todo!()
}
}
impl<'a, T> Matrix<'a, T>
where
    T: MatrixElement,
    <T as FromStr>::Err: Error + 'static,
    Vec<T>: IntoParallelIterator,
    Vec<&'a T>: IntoParallelRefIterator<'a>,
{
    /// Counts the elements for which `pred` returns `true`.
    ///
    /// Takes a plain `&self` borrow. Requiring `&'a self` tied the borrow to the
    /// struct's lifetime parameter and rejected ordinary short-lived borrows.
    pub fn count_where<F>(&self, pred: F) -> usize
    where
        F: Fn(&T) -> bool + Sync,
    {
        self.data.par_iter().filter(|&e| pred(e)).count()
    }

    /// Sums the elements for which `pred` returns `true`. Yields `T::zero()`
    /// when nothing matches.
    pub fn sum_where<F>(&self, pred: F) -> T
    where
        F: Fn(&T) -> bool + Sync,
    {
        self.data
            .par_iter()
            .filter(|&e| pred(e))
            .copied()
            .sum::<T>()
    }

    /// Returns `true` if at least one element satisfies `pred`.
    pub fn any<F>(&self, pred: F) -> bool
    where
        F: Fn(&T) -> bool + Sync + Send,
    {
        self.data.par_iter().any(pred)
    }

    /// Returns `true` if every element satisfies `pred` (vacuously for an empty matrix).
    pub fn all<F>(&self, pred: F) -> bool
    where
        F: Fn(&T) -> bool + Sync + Send,
    {
        self.data.par_iter().all(pred)
    }

    /// `(row, col)` of the first element in row-major order satisfying `pred`,
    /// or `None` if none does.
    pub fn find<F>(&self, pred: F) -> Option<Shape>
    where
        F: Fn(&T) -> bool + Sync,
    {
        self.data
            .iter()
            .find_position(|&e| pred(e))
            .map(|(idx, _)| self.inverse_at(idx))
    }

    /// `(row, col)` of every element satisfying `pred`, in row-major order.
    /// Returns `None` instead of an empty `Vec` when nothing matches.
    pub fn find_all<F>(&self, pred: F) -> Option<Vec<Shape>>
    where
        F: Fn(&T) -> bool + Sync,
    {
        // rayon's `collect` into a `Vec` keeps the original (row-major) order.
        let hits: Vec<Shape> = self
            .data
            .par_iter()
            .enumerate()
            .filter_map(|(idx, elem)| {
                if pred(elem) {
                    Some(self.inverse_at(idx))
                } else {
                    None
                }
            })
            .collect();
        if hits.is_empty() {
            None
        } else {
            Some(hits)
        }
    }
}