#![warn(missing_docs)]
mod helper;
use helper::*;
use std::fmt::Display;
use std::{collections::HashMap, error::Error, marker::PhantomData, str::FromStr};
use rayon::prelude::{IntoParallelIterator, IntoParallelRefIterator, ParallelIterator};
use serde::{Deserialize, Serialize};
use crate::{Matrix, MatrixElement, MatrixError, Shape};
// Row-major linear index of element (`$row`, `$col`) in a matrix with `$ncols` columns.
// Each argument is evaluated exactly once, in the order row, ncols, col.
macro_rules! at {
    ($row:expr, $col:expr, $ncols:expr) => {{
        let (row, stride, col) = ($row, $ncols, $col);
        (row * stride + col) as usize
    }};
}
/// Backing storage of a [`SparseMatrix`]: maps a `(row, col)` position to its stored value.
///
/// Positions that are absent from the map are read as zero by the matrix accessors.
/// The lifetime parameter `'a` is not used by the alias itself.
pub type SparseMatrixData<'a, T> = HashMap<Shape, T>;
/// A matrix that keeps its elements in a `HashMap` keyed by `(row, col)`.
///
/// Positions without a stored entry are read as zero by the accessors, so only the
/// entries that are actually stored cost memory.
#[derive(Clone, PartialEq, Debug, Serialize, Deserialize)]
pub struct SparseMatrix<'a, T>
where
T: MatrixElement,
<T as FromStr>::Err: Error + 'static,
Vec<T>: IntoParallelIterator,
Vec<&'a T>: IntoParallelRefIterator<'a>,
{
// Stored entries, keyed by `(row, col)`.
data: SparseMatrixData<'a, T>,
/// Number of rows.
pub nrows: usize,
/// Number of columns.
pub ncols: usize,
// Zero-sized marker tying the struct to the lifetime `'a`; it stores nothing.
_lifetime: PhantomData<&'a T>,
}
impl<'a, T> Display for SparseMatrix<'a, T>
where
    T: MatrixElement,
    <T as FromStr>::Err: Error + 'static,
    Vec<T>: IntoParallelIterator,
    Vec<&'a T>: IntoParallelRefIterator<'a>,
{
    /// Writes the matrix densely: one line per row, with every element followed by a
    /// space (missing entries shown as zero). The rows are followed by a blank line and
    /// `dtype = <type name of T>`.
    ///
    /// Formatter errors are propagated instead of being silently ignored.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        for i in 0..self.nrows {
            for j in 0..self.ncols {
                let elem = self.data.get(&(i, j)).copied().unwrap_or_else(T::zero);
                write!(f, "{elem} ")?;
            }
            writeln!(f)?;
        }
        writeln!(f, "\ndtype = {}", std::any::type_name::<T>())
    }
}
impl<'a, T> Default for SparseMatrix<'a, T>
where
    T: MatrixElement,
    <T as FromStr>::Err: Error + 'static,
    Vec<T>: IntoParallelIterator,
    Vec<&'a T>: IntoParallelRefIterator<'a>,
{
    /// The default sparse matrix is the 3x3 identity matrix.
    fn default() -> Self {
        const DEFAULT_DIM: usize = 3;
        Self::identity(DEFAULT_DIM)
    }
}
impl<'a, T> SparseMatrix<'a, T>
where
    T: MatrixElement,
    <T as FromStr>::Err: Error + 'static,
    Vec<T>: IntoParallelIterator,
    Vec<&'a T>: IntoParallelRefIterator<'a>,
{
    /// Creates a `rows x cols` matrix with no stored entries, i.e. all zeros.
    pub fn new(rows: usize, cols: usize) -> Self {
        Self {
            data: HashMap::new(),
            nrows: rows,
            ncols: cols,
            _lifetime: PhantomData,
        }
    }

    /// Creates a matrix from already-built storage and a `(nrows, ncols)` shape.
    ///
    /// `data` is used as-is: its keys are not checked against `shape`, and any
    /// explicitly stored zeros are kept.
    pub fn init(data: SparseMatrixData<'a, T>, shape: Shape) -> Self {
        Self {
            data,
            nrows: shape.0,
            ncols: shape.1,
            _lifetime: PhantomData,
        }
    }

    /// Creates the `size x size` identity matrix: `T::one()` is stored on the diagonal
    /// and nothing elsewhere.
    pub fn eye(size: usize) -> Self {
        let data: SparseMatrixData<'a, T> = (0..size)
            .into_par_iter()
            .map(|i| ((i, i), T::one()))
            .collect();
        Self::init(data, (size, size))
    }

    /// Alias for [`SparseMatrix::eye`].
    pub fn identity(size: usize) -> Self {
        Self::eye(size)
    }

    /// Gives the matrix the shape `nrows x ncols`, keeping row-major element order.
    ///
    /// Every stored entry keeps its row-major linear index (`i * ncols + j`) and is
    /// moved to the matching position in the new shape. Reshaping therefore behaves like
    /// reinterpreting the flattened matrix. Entries whose linear index does not fit in
    /// `nrows * ncols` are discarded.
    pub fn reshape(&mut self, nrows: usize, ncols: usize) {
        let old_ncols = self.ncols;
        let new_size = nrows * ncols;
        self.data = std::mem::take(&mut self.data)
            .into_iter()
            .filter_map(|((i, j), val)| {
                let linear = at!(i, j, old_ncols);
                // `linear < new_size` implies `ncols > 0`, so the division below is safe.
                (linear < new_size).then(|| ((linear / ncols, linear % ncols), val))
            })
            .collect();
        self.nrows = nrows;
        self.ncols = ncols;
    }

    /// Converts a dense [`Matrix`] into sparse form, storing only its non-zero elements.
    pub fn from_dense(matrix: Matrix<'a, T>) -> Self {
        let mut data: SparseMatrixData<'a, T> = HashMap::new();
        for i in 0..matrix.nrows {
            for j in 0..matrix.ncols {
                // (i, j) lies within the dense matrix's own shape, so this lookup
                // presumably cannot fail.
                let val = matrix.get(i, j).unwrap();
                if val != T::zero() {
                    data.insert((i, j), val);
                }
            }
        }
        Self::init(data, matrix.shape())
    }

    /// Returns the element at `(i, j)`. Positions without a stored entry read as zero.
    ///
    /// Returns `None` (and prints an error to stderr) when `i >= nrows` or `j >= ncols`.
    pub fn get(&self, i: usize, j: usize) -> Option<T> {
        // Check each coordinate separately: a linearised index can stay below `size()`
        // even when `j` is past the last column.
        if i >= self.nrows || j >= self.ncols {
            eprintln!("Error, index out of bounds. Not getting value");
            return None;
        }
        Some(self.at(i, j))
    }

    /// Returns the element at `(i, j)` without bounds checking.
    ///
    /// Positions without a stored entry, including out-of-range ones, read as zero.
    pub fn at(&self, i: usize, j: usize) -> T {
        self.data.get(&(i, j)).copied().unwrap_or_else(T::zero)
    }

    /// Writes `value` at position `idx = (i, j)`.
    ///
    /// Out-of-range positions (`i >= nrows` or `j >= ncols`) are reported on stderr and
    /// ignored. Writing zero removes the stored entry, so the map only holds non-zero
    /// values and `get_zero_count`/`sparcity` stay accurate.
    pub fn set(&mut self, idx: Shape, value: T) {
        let (i, j) = idx;
        if i >= self.nrows || j >= self.ncols {
            eprintln!("Error, index out of bounds. Not setting value");
            return;
        }
        if value == T::zero() {
            self.data.remove(&idx);
        } else {
            self.data.insert(idx, value);
        }
    }

    /// Prints every stored entry as `i j: value`, with `decimals` digits after the point.
    ///
    /// Only stored entries are printed, in the map's arbitrary iteration order.
    pub fn print(&self, decimals: usize) {
        self.data
            .iter()
            .for_each(|((i, j), val)| println!("{i} {j}: {:.decimals$}", val));
    }

    /// Total number of positions, `nrows * ncols`.
    #[inline]
    pub fn size(&self) -> usize {
        self.ncols * self.nrows
    }

    /// Number of positions without a stored entry.
    ///
    /// Saturates at zero if the storage holds more entries than `size()`, which can
    /// happen with storage passed to `init`.
    #[inline]
    pub fn get_zero_count(&self) -> usize {
        self.size().saturating_sub(self.data.len())
    }

    /// Fraction of positions without a stored entry: `1 - stored / size`.
    ///
    /// Returns `NaN` for a matrix with zero size.
    #[inline]
    pub fn sparcity(&self) -> f64 {
        // `len()` is O(1); a parallel count over the map is unnecessary.
        1.0 - self.data.len() as f64 / self.size() as f64
    }

    /// Returns `(nrows, ncols)`.
    pub fn shape(&self) -> Shape {
        (self.nrows, self.ncols)
    }

    /// Transposes the matrix in place: the entry at `(i, j)` moves to `(j, i)` and the
    /// row and column counts are swapped.
    pub fn transpose(&mut self) {
        self.data = std::mem::take(&mut self.data)
            .into_iter()
            .map(|((i, j), val)| ((j, i), val))
            .collect();
        swap(&mut self.nrows, &mut self.ncols);
    }

    /// Shorthand for [`SparseMatrix::transpose`].
    pub fn t(&mut self) {
        self.transpose();
    }

    /// Returns a transposed copy and leaves `self` unchanged.
    pub fn transpose_new(&self) -> Self {
        let mut res = self.clone();
        res.transpose();
        res
    }
}
impl<'a, T> SparseMatrix<'a, T>
where
    T: MatrixElement,
    <T as FromStr>::Err: Error + 'static,
    Vec<T>: IntoParallelIterator,
    Vec<&'a T>: IntoParallelRefIterator<'a>,
{
    /// Combines `self` with `other` using `Operation::ADD` (via `sparse_helper`) and
    /// returns the result as a new matrix.
    ///
    /// # Errors
    /// Propagates any `MatrixError` produced by `sparse_helper` (presumably on a shape
    /// mismatch; see the `helper` module).
    pub fn add(&self, other: &Self) -> Result<Self, MatrixError> {
        Self::sparse_helper(self, other, Operation::ADD)
    }

    /// Combines `self` with `other` using `Operation::SUB` (via `sparse_helper`) and
    /// returns the result as a new matrix.
    ///
    /// # Errors
    /// Propagates any `MatrixError` produced by `sparse_helper`.
    pub fn sub(&self, other: &Self) -> Result<Self, MatrixError> {
        Self::sparse_helper(self, other, Operation::SUB)
    }

    /// Combines `self` with `other` using `Operation::MUL` (via `sparse_helper`) and
    /// returns the result as a new matrix.
    ///
    /// # Errors
    /// Propagates any `MatrixError` produced by `sparse_helper`.
    pub fn mul(&self, other: &Self) -> Result<Self, MatrixError> {
        Self::sparse_helper(self, other, Operation::MUL)
    }

    /// Combines `self` with `other` using `Operation::DIV` (via `sparse_helper`) and
    /// returns the result as a new matrix.
    ///
    /// # Errors
    /// Propagates any `MatrixError` produced by `sparse_helper`.
    pub fn div(&self, other: &Self) -> Result<Self, MatrixError> {
        Self::sparse_helper(self, other, Operation::DIV)
    }

    /// Applies `Operation::ADD` with `other` to `self` in place (via `sparse_helper_self`).
    pub fn add_self(&mut self, other: &Self) {
        Self::sparse_helper_self(self, other, Operation::ADD);
    }

    /// Applies `Operation::SUB` with `other` to `self` in place (via `sparse_helper_self`).
    pub fn sub_self(&mut self, other: &Self) {
        Self::sparse_helper_self(self, other, Operation::SUB);
    }

    /// Applies `Operation::MUL` with `other` to `self` in place (via `sparse_helper_self`).
    pub fn mul_self(&mut self, other: &Self) {
        Self::sparse_helper_self(self, other, Operation::MUL);
    }

    /// Applies `Operation::DIV` with `other` to `self` in place (via `sparse_helper_self`).
    pub fn div_self(&mut self, other: &Self) {
        Self::sparse_helper_self(self, other, Operation::DIV);
    }

    /// Combines the matrix with the scalar `value` using `Operation::ADD`
    /// (via `sparse_helper_val`) and returns the result as a new matrix.
    pub fn add_val(&self, value: T) -> Self {
        Self::sparse_helper_val(self, value, Operation::ADD)
    }

    /// Combines the matrix with the scalar `value` using `Operation::SUB`
    /// (via `sparse_helper_val`) and returns the result as a new matrix.
    pub fn sub_val(&self, value: T) -> Self {
        Self::sparse_helper_val(self, value, Operation::SUB)
    }

    /// Combines the matrix with the scalar `value` using `Operation::MUL`
    /// (via `sparse_helper_val`) and returns the result as a new matrix.
    pub fn mul_val(&self, value: T) -> Self {
        Self::sparse_helper_val(self, value, Operation::MUL)
    }

    /// Combines the matrix with the scalar `value` using `Operation::DIV`
    /// (via `sparse_helper_val`) and returns the result as a new matrix.
    pub fn div_val(&self, value: T) -> Self {
        Self::sparse_helper_val(self, value, Operation::DIV)
    }

    /// Combines the matrix with the scalar `value` in place using `Operation::ADD`
    /// (via `sparse_helper_self_val`).
    pub fn add_val_self(&mut self, value: T) {
        Self::sparse_helper_self_val(self, value, Operation::ADD)
    }

    /// Combines the matrix with the scalar `value` in place using `Operation::SUB`
    /// (via `sparse_helper_self_val`).
    pub fn sub_val_self(&mut self, value: T) {
        Self::sparse_helper_self_val(self, value, Operation::SUB)
    }

    /// Combines the matrix with the scalar `value` in place using `Operation::MUL`
    /// (via `sparse_helper_self_val`).
    pub fn mul_val_self(&mut self, value: T) {
        Self::sparse_helper_self_val(self, value, Operation::MUL)
    }

    /// Combines the matrix with the scalar `value` in place using `Operation::DIV`
    /// (via `sparse_helper_self_val`).
    pub fn div_val_self(&mut self, value: T) {
        Self::sparse_helper_self_val(self, value, Operation::DIV)
    }

    /// Placeholder for a sparse matrix product of `self` and `other` (judging by the name).
    ///
    /// # Panics
    /// Always panics: not implemented yet.
    fn matmul_sparse(&self, _other: &Self) -> Self {
        unimplemented!()
    }
}
impl<'a, T> SparseMatrix<'a, T>
where
    T: MatrixElement,
    <T as FromStr>::Err: Error + 'static,
    Vec<T>: IntoParallelIterator,
    Vec<&'a T>: IntoParallelRefIterator<'a>,
{
    /// Returns `true` if `pred` holds for every stored `(position, value)` entry.
    ///
    /// Only stored entries are visited (in parallel); positions without an entry are
    /// never passed to `pred`. A matrix with no stored entries yields `true`.
    pub fn all<F>(&self, pred: F) -> bool
    where
        F: Fn((Shape, T)) -> bool + Sync + Send,
    {
        // Borrow the entries rather than cloning the whole map just to iterate it.
        self.data.par_iter().all(|(&idx, &val)| pred((idx, val)))
    }

    /// Returns `true` if `pred` holds for at least one stored `(position, value)` entry.
    ///
    /// Only stored entries are visited (in parallel); positions without an entry are
    /// never passed to `pred`. A matrix with no stored entries yields `false`.
    pub fn any<F>(&self, pred: F) -> bool
    where
        F: Fn((Shape, T)) -> bool + Sync + Send,
    {
        self.data.par_iter().any(|(&idx, &val)| pred((idx, val)))
    }

    /// Counts the stored entries for which `pred` holds (evaluated in parallel).
    pub fn count_where<F>(&self, pred: F) -> usize
    where
        F: Fn((&Shape, &T)) -> bool + Sync,
    {
        self.data.par_iter().filter(|&entry| pred(entry)).count()
    }

    /// Sums the values of the stored entries for which `pred` holds.
    ///
    /// Returns `T::zero()` when no entry matches.
    pub fn sum_where<F>(&self, pred: F) -> T
    where
        F: Fn((&Shape, &T)) -> bool + Sync,
    {
        let mut res = T::zero();
        for (idx, elem) in &self.data {
            if pred((idx, elem)) {
                res += elem;
            }
        }
        res
    }

    /// Calls `pred` on every stored entry, giving it mutable access to the value.
    ///
    /// Despite its name, `pred` acts as a mutator: it is expected to update values in
    /// place. Positions without a stored entry are not visited.
    pub fn set_where<F>(&mut self, pred: F)
    where
        F: FnMut((&Shape, &mut T)) + Sync + Send,
    {
        self.data.iter_mut().for_each(pred);
    }

    /// Returns the first position, in row-major order, whose element satisfies `pred`.
    ///
    /// Implicit zeros are considered too: if `pred` accepts zero, every position is
    /// scanned. Otherwise only stored entries inside the matrix bounds are checked.
    fn find<F>(&self, pred: F) -> Option<Shape>
    where
        F: Fn(&T) -> bool + Sync,
    {
        if pred(&T::zero()) {
            (0..self.nrows)
                .flat_map(|i| (0..self.ncols).map(move |j| (i, j)))
                .find(|&(i, j)| pred(&self.at(i, j)))
        } else {
            // Tuple ordering on (row, col) is row-major, so the minimum key is the first hit.
            self.data
                .iter()
                .filter(|&(&(i, j), val)| i < self.nrows && j < self.ncols && pred(val))
                .map(|(&idx, _)| idx)
                .min()
        }
    }

    /// Returns all positions whose element satisfies `pred`, in row-major order, or
    /// `None` when there is no match.
    ///
    /// Implicit zeros are considered too: if `pred` accepts zero, every position is
    /// scanned. Otherwise only stored entries inside the matrix bounds are checked.
    fn find_all<F>(&self, pred: F) -> Option<Vec<Shape>>
    where
        F: Fn(&T) -> bool + Sync,
    {
        let mut hits: Vec<Shape> = if pred(&T::zero()) {
            (0..self.nrows)
                .flat_map(|i| (0..self.ncols).map(move |j| (i, j)))
                .filter(|&(i, j)| pred(&self.at(i, j)))
                .collect()
        } else {
            self.data
                .iter()
                .filter(|&(&(i, j), val)| i < self.nrows && j < self.ncols && pred(val))
                .map(|(&idx, _)| idx)
                .collect()
        };
        // HashMap iteration order is arbitrary; sort so results are row-major.
        hits.sort_unstable();
        if hits.is_empty() {
            None
        } else {
            Some(hits)
        }
    }
}