#![warn(missing_docs)]
use std::fmt::Display;
use std::{collections::HashMap, error::Error, marker::PhantomData, str::FromStr};
use rayon::prelude::{IntoParallelIterator, IntoParallelRefIterator, ParallelIterator};
use serde::{Deserialize, Serialize};
use crate::{Matrix, MatrixElement, MatrixError, Shape};
// Row-major flattening: maps `(row, col)` in a matrix that is `$width` columns wide
// to the linear offset `row * width + col`, cast to `usize`.
macro_rules! at {
    ($r:ident, $c:ident, $width:expr) => {
        ($r * $width + $c) as usize
    };
}
/// Backing storage for [`SparseMatrix`]: maps a `(row, col)` position to the value
/// stored at that position.
///
/// The `'a` lifetime is not used by the alias itself; presumably it is kept so the
/// alias can be written with the same parameters as `SparseMatrix<'a, T>`.
pub type SparseMatrixData<'a, T> = HashMap<Shape, T>;
/// A matrix that stores only explicitly set entries in a hash map keyed by
/// `(row, col)`; any position without an entry reads as `T::zero()`.
#[derive(Clone, PartialEq, Debug, Serialize, Deserialize)]
pub struct SparseMatrix<'a, T>
where
    T: MatrixElement,
    <T as FromStr>::Err: Error + 'static,
    Vec<T>: IntoParallelIterator,
    Vec<&'a T>: IntoParallelRefIterator<'a>,
{
    // Stored entries keyed by `(row, col)`; absent positions are treated as zero.
    data: SparseMatrixData<'a, T>,
    /// Number of rows.
    pub nrows: usize,
    /// Number of columns.
    pub ncols: usize,
    // Zero-sized marker that ties the otherwise-unused lifetime `'a` to the type.
    _lifetime: PhantomData<&'a T>,
}
impl<'a, T> Display for SparseMatrix<'a, T>
where
    T: MatrixElement,
    <T as FromStr>::Err: Error + 'static,
    Vec<T>: IntoParallelIterator,
    Vec<&'a T>: IntoParallelRefIterator<'a>,
{
    /// Renders the matrix densely: each row's elements separated (and followed) by a
    /// space, one row per line, then a blank line and a `dtype = <type name>` line.
    /// Positions without a stored entry are printed as `T::zero()`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        for i in 0..self.nrows {
            for j in 0..self.ncols {
                let elem = self.data.get(&(i, j)).copied().unwrap_or_else(T::zero);
                // Propagate formatter errors instead of silently discarding them.
                write!(f, "{elem} ")?;
            }
            writeln!(f)?;
        }
        writeln!(f, "\ndtype = {}", std::any::type_name::<T>())
    }
}
impl<'a, T> Default for SparseMatrix<'a, T>
where
    T: MatrixElement,
    <T as FromStr>::Err: Error + 'static,
    Vec<T>: IntoParallelIterator,
    Vec<&'a T>: IntoParallelRefIterator<'a>,
{
    /// The default sparse matrix is the 3 x 3 identity.
    fn default() -> Self {
        const DEFAULT_SIZE: usize = 3;
        Self::identity(DEFAULT_SIZE)
    }
}
impl<'a, T> SparseMatrix<'a, T>
where
    T: MatrixElement,
    <T as FromStr>::Err: Error + 'static,
    Vec<T>: IntoParallelIterator,
    Vec<&'a T>: IntoParallelRefIterator<'a>,
{
    /// Creates a `rows` x `cols` matrix with no stored entries, so every element
    /// reads as zero.
    pub fn new(rows: usize, cols: usize) -> Self {
        Self::init(HashMap::new(), (rows, cols))
    }

    /// Builds a matrix from an existing `(row, col) -> value` map and a
    /// `(nrows, ncols)` shape.
    ///
    /// The map is taken as-is: keys are not validated against `shape`.
    pub fn init(data: SparseMatrixData<'a, T>, shape: Shape) -> Self {
        Self {
            data,
            nrows: shape.0,
            ncols: shape.1,
            _lifetime: PhantomData,
        }
    }

    /// Creates a `size` x `size` identity matrix (`T::one()` on the diagonal).
    pub fn eye(size: usize) -> Self {
        let data: SparseMatrixData<'a, T> = (0..size)
            .into_par_iter()
            .map(|i| ((i, i), T::one()))
            .collect();
        Self::init(data, (size, size))
    }

    /// Alias for [`Self::eye`].
    pub fn identity(size: usize) -> Self {
        Self::eye(size)
    }

    /// Reinterprets the matrix as `nrows` x `ncols`, keeping the row-major order of
    /// its elements (the same flattening as a dense reshape).
    ///
    /// If the new shape does not hold exactly as many elements as the current one,
    /// an error is printed to stderr and the matrix is left unchanged.
    pub fn reshape(&mut self, nrows: usize, ncols: usize) {
        if nrows.checked_mul(ncols) != Some(self.size()) {
            eprintln!(
                "Error, cannot reshape {}x{} matrix into {nrows}x{ncols}. Not reshaping",
                self.nrows, self.ncols
            );
            return;
        }
        // With the same column count (and same total size) every key keeps its
        // position. A zero column count means the matrix is empty, so there is
        // nothing to move and dividing by `ncols` would be invalid.
        if ncols != self.ncols && ncols != 0 {
            let old_ncols = self.ncols;
            self.data = std::mem::take(&mut self.data)
                .into_iter()
                .map(|((i, j), val)| {
                    let flat = at!(i, j, old_ncols);
                    ((flat / ncols, flat % ncols), val)
                })
                .collect();
        }
        self.nrows = nrows;
        self.ncols = ncols;
    }

    /// Converts a dense [`Matrix`] into sparse form, storing only non-zero elements.
    pub fn from_dense(matrix: Matrix<'a, T>) -> Self {
        let mut data: SparseMatrixData<'a, T> = HashMap::new();
        for i in 0..matrix.nrows {
            for j in 0..matrix.ncols {
                // `i` and `j` range over the dense matrix's own shape, so the lookup
                // is expected to succeed.
                let val = matrix.get(i, j).unwrap();
                if val != T::zero() {
                    data.insert((i, j), val);
                }
            }
        }
        Self::init(data, matrix.shape())
    }

    /// Returns the element at row `i`, column `j`.
    ///
    /// Positions without a stored entry yield `Some(T::zero())`. If `(i, j)` lies
    /// outside the matrix, an error is printed to stderr and `None` is returned.
    pub fn get(&self, i: usize, j: usize) -> Option<T> {
        if !self.in_bounds(i, j) {
            eprintln!("Error, index out of bounds. Not getting value");
            return None;
        }
        Some(self.at(i, j))
    }

    /// Returns the element at row `i`, column `j`, without bounds checking: any
    /// position with no stored entry, including out-of-range ones, reads as zero.
    pub fn at(&self, i: usize, j: usize) -> T {
        self.data.get(&(i, j)).copied().unwrap_or_else(T::zero)
    }

    /// Sets the element at row `i`, column `j` to `value`.
    ///
    /// Setting a position to zero removes its entry, so only non-zero values stay
    /// stored. If `(i, j)` lies outside the matrix, an error is printed to stderr
    /// and nothing is changed.
    pub fn set(&mut self, i: usize, j: usize, value: T) {
        if !self.in_bounds(i, j) {
            eprintln!("Error, index out of bounds. Not setting value");
            return;
        }
        if value == T::zero() {
            self.data.remove(&(i, j));
        } else {
            self.data.insert((i, j), value);
        }
    }

    /// Prints every stored entry as `row col: value`, with `decimals` digits of
    /// precision, in row-major order.
    pub fn print(&self, decimals: usize) {
        let mut entries: Vec<_> = self.data.iter().collect();
        // Hash-map iteration order is arbitrary; sort so the output is stable.
        entries.sort_unstable_by_key(|&(&pos, _)| pos);
        for ((i, j), val) in entries {
            println!("{i} {j}: {:.decimals$}", val);
        }
    }

    /// Total number of positions (`nrows * ncols`), stored or not.
    pub fn size(&self) -> usize {
        self.ncols * self.nrows
    }

    /// Fraction of positions that have no stored entry: `1 - stored / size`.
    ///
    /// For a matrix with zero positions this evaluates `0 / 0` and returns NaN.
    pub fn sparcity(&self) -> f64 {
        1.0 - self.data.len() as f64 / self.size() as f64
    }

    /// Returns the `(nrows, ncols)` shape.
    pub fn shape(&self) -> Shape {
        (self.nrows, self.ncols)
    }

    /// Returns `true` when `(i, j)` addresses a cell inside the current shape.
    /// Comparing each coordinate separately also rejects positions such as
    /// `(0, ncols)`, whose flat index would otherwise look valid.
    fn in_bounds(&self, i: usize, j: usize) -> bool {
        i < self.nrows && j < self.ncols
    }
}
impl<'a, T> SparseMatrix<'a, T>
where
    T: MatrixElement,
    <T as FromStr>::Err: Error + 'static,
    Vec<T>: IntoParallelIterator,
    Vec<&'a T>: IntoParallelRefIterator<'a>,
{
    /// Element-wise sum of two matrices of the same shape.
    ///
    /// Every position stored in either operand contributes to the result. Positions
    /// stored in only one operand are copied through unchanged. Sums that come out
    /// to exactly zero are dropped, so only non-zero entries are stored.
    ///
    /// # Errors
    /// Returns `MatrixError::MatrixDimensionMismatchError` if the shapes differ.
    pub fn add(&self, other: &Self) -> Result<Self, MatrixError> {
        if self.shape() != other.shape() {
            return Err(MatrixError::MatrixDimensionMismatchError);
        }
        // Start from a copy of `self` and fold `other` into it. This visits the union
        // of both key sets. Zipping the two maps' iterators would pair up entries in
        // unrelated hash orders and drop most of them.
        let mut data: SparseMatrixData<'a, T> = self.data.clone();
        for (&pos, &val) in &other.data {
            data.entry(pos)
                .and_modify(|acc| *acc = *acc + val)
                .or_insert(val);
        }
        data.retain(|_, val| *val != T::zero());
        Ok(Self::init(data, self.shape()))
    }

    /// Sparse-by-sparse matrix product. Not implemented yet; calling it panics.
    fn matmul_sparse(&self, _other: &Self) -> Self {
        unimplemented!()
    }
}