#![warn(missing_docs)]
mod helper;
use helper::*;
use num_traits::Float;
use rand::Rng;
use itertools::Itertools;
use std::fmt::Display;
use std::fs;
use std::{collections::HashMap, error::Error, marker::PhantomData, str::FromStr};
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
use crate::{at, LinAlgFloats, Matrix, MatrixElement, MatrixError, Operation, Shape};
/// Backing storage of a [`SparseMatrix`]: a hash map from a `(row, column)` key (a [`Shape`])
/// to the value stored at that position.
///
/// NOTE(review): the `'a` lifetime parameter does not appear on the right-hand side of the alias.
pub type SparseMatrixData<'a, T> = HashMap<Shape, T>;
/// A two-dimensional sparse matrix that keeps only explicitly stored entries in a hash map.
///
/// Positions that have no entry in `data` are read back as `T::zero()` by the accessors
/// defined in this file (`get` and `at`).
#[derive(Clone, PartialEq, Debug, Serialize, Deserialize)]
pub struct SparseMatrix<'a, T>
where
T: MatrixElement,
<T as FromStr>::Err: Error + 'static,
Vec<T>: IntoParallelIterator,
Vec<&'a T>: IntoParallelRefIterator<'a>,
{
/// Stored entries, keyed by `(row, column)`.
pub data: SparseMatrixData<'a, T>,
/// Number of rows.
pub nrows: usize,
/// Number of columns.
pub ncols: usize,
// Zero-sized marker that ties the `'a` lifetime parameter to the struct.
_lifetime: PhantomData<&'a T>,
}
// Empty marker impl: no `Error` method is overridden, so all trait defaults apply.
// NOTE(review): implementing `std::error::Error` for a data container is unusual —
// confirm that some caller actually needs a matrix to be usable as an error value.
impl<'a, T> Error for SparseMatrix<'a, T>
where
T: MatrixElement,
<T as FromStr>::Err: Error + 'static,
Vec<T>: IntoParallelIterator,
Vec<&'a T>: IntoParallelRefIterator<'a>,
{
}
// SAFETY: the struct holds a `HashMap<Shape, T>`, two `usize` fields and a
// `PhantomData<&'a T>`, so sending it across threads is sound only if `T: Send + Sync`.
// This impl does not spell that bound out. Other code in this file calls rayon's
// `par_iter`/`into_par_iter` on `data` with only `T: MatrixElement`, which presumably means
// `MatrixElement` already requires `Send + Sync` — TODO confirm; if so this impl is redundant.
unsafe impl<'a, T> Send for SparseMatrix<'a, T>
where
T: MatrixElement,
<T as FromStr>::Err: Error + 'static,
Vec<T>: IntoParallelIterator,
Vec<&'a T>: IntoParallelRefIterator<'a>,
{
}
// SAFETY: sharing `&SparseMatrix` between threads is sound only if `T: Sync`, because the
// struct exposes `&T` through its `HashMap<Shape, T>`. This impl does not spell that bound
// out. Other code in this file calls rayon's `par_iter` on `data` with only
// `T: MatrixElement`, which presumably means `MatrixElement` already requires `Sync` —
// TODO confirm; if so this impl is redundant.
unsafe impl<'a, T> Sync for SparseMatrix<'a, T>
where
T: MatrixElement,
<T as FromStr>::Err: Error + 'static,
Vec<T>: IntoParallelIterator,
Vec<&'a T>: IntoParallelRefIterator<'a>,
{
}
impl<'a, T> FromStr for SparseMatrix<'a, T>
where
T: MatrixElement,
<T as FromStr>::Err: Error + 'static,
{
type Err = anyhow::Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let data = s
.trim()
.lines()
.skip(1)
.map(|l| {
let entry: Vec<&str> = l.split_whitespace().collect();
let row = entry[0].parse::<usize>().unwrap();
let col = entry[1].parse::<usize>().unwrap();
let val = entry[2].parse::<T>().unwrap();
((row, col), val)
})
.collect::<SparseMatrixData<T>>();
let dims = s
.trim()
.lines()
.nth(0)
.unwrap()
.split_whitespace()
.map(|e| e.parse::<usize>().unwrap())
.collect::<Vec<usize>>();
Ok(Self::new(data, (dims[0], dims[1])))
}
}
/// Renders the matrix densely: one text line per row, every element (implicit zeros
/// included) followed by a single space, then a blank line and a `dtype = <type name>` line.
impl<'a, T> Display for SparseMatrix<'a, T>
where
    T: MatrixElement,
    <T as FromStr>::Err: Error + 'static,
    Vec<T>: IntoParallelIterator,
    Vec<&'a T>: IntoParallelRefIterator<'a>,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        for i in 0..self.nrows {
            for j in 0..self.ncols {
                // Positions without a stored entry are printed as zero.
                let elem = self.data.get(&(i, j)).copied().unwrap_or_else(T::zero);
                // Propagate formatter errors instead of silently dropping them.
                write!(f, "{elem} ")?;
            }
            writeln!(f)?;
        }
        writeln!(f, "\ndtype = {}", std::any::type_name::<T>())
    }
}
/// The default matrix has no stored entries and a `0 x 0` shape.
impl<'a, T> Default for SparseMatrix<'a, T>
where
    T: MatrixElement,
    <T as FromStr>::Err: Error + 'static,
    Vec<T>: IntoParallelIterator,
    Vec<&'a T>: IntoParallelRefIterator<'a>,
{
    fn default() -> Self {
        let (nrows, ncols) = (0, 0);
        let data = HashMap::new();
        Self {
            data,
            nrows,
            ncols,
            _lifetime: PhantomData,
        }
    }
}
/// Constructors, element access, shape queries and simple statistics.
impl<'a, T> SparseMatrix<'a, T>
where
    T: MatrixElement,
    <T as FromStr>::Err: Error + 'static,
    Vec<T>: IntoParallelIterator,
    Vec<&'a T>: IntoParallelRefIterator<'a>,
{
    /// Builds a matrix from already prepared storage and a `(nrows, ncols)` shape.
    ///
    /// The entries are taken as-is: keys are not checked against `shape` and zero values
    /// are not filtered out.
    pub fn new(data: SparseMatrixData<'a, T>, shape: Shape) -> Self {
        let (nrows, ncols) = shape;
        Self {
            data,
            nrows,
            ncols,
            _lifetime: PhantomData,
        }
    }

    /// Creates a matrix of the given dimensions with no stored entries.
    pub fn init(nrows: usize, ncols: usize) -> Self {
        Self {
            data: HashMap::new(),
            nrows,
            ncols,
            _lifetime: PhantomData,
        }
    }

    /// Creates the `size x size` identity matrix (`T::one()` stored on the diagonal).
    pub fn eye(size: usize) -> Self {
        let data: SparseMatrixData<'a, T> = (0..size)
            .into_par_iter()
            .map(|i| ((i, i), T::one()))
            .collect();
        Self::new(data, (size, size))
    }

    /// Creates an identity matrix whose size is `matrix.nrows`.
    pub fn eye_like(matrix: &Self) -> Self {
        Self::eye(matrix.nrows)
    }

    /// Alias for [`Self::eye`].
    pub fn identity(size: usize) -> Self {
        Self::eye(size)
    }

    /// Creates a matrix of the given shape in which randomly chosen positions hold
    /// `T::one()`, filled until the requested `sparsity` is reached.
    ///
    /// See [`Self::randomize_range`] for how `sparsity` is interpreted.
    pub fn ones(sparsity: f64, shape: Shape) -> Self {
        Self::randomize_range(T::one(), T::one(), sparsity, shape)
    }

    /// Overwrites the stored dimensions.
    ///
    /// Only `nrows` and `ncols` change: stored entries keep their `(row, column)` keys and
    /// no check is made that the element count stays the same.
    ///
    /// NOTE(review): entries may end up outside the new bounds; confirm whether callers
    /// expect keys to be remapped.
    pub fn reshape(&mut self, nrows: usize, ncols: usize) {
        self.nrows = nrows;
        self.ncols = ncols;
    }

    /// Converts a dense matrix, storing only the elements that differ from `T::zero()`.
    ///
    /// # Panics
    ///
    /// Panics if `matrix.get(i, j)` fails for an index inside `matrix.nrows x matrix.ncols`.
    pub fn from_dense(matrix: Matrix<'a, T>) -> Self {
        let mut data: SparseMatrixData<'a, T> = HashMap::new();
        for i in 0..matrix.nrows {
            for j in 0..matrix.ncols {
                let val = matrix.get(i, j).unwrap();
                if val != T::zero() {
                    data.insert((i, j), val);
                }
            }
        }
        Self::new(data, matrix.shape())
    }

    /// Builds a matrix from parallel slices of row indices, column indices and values.
    ///
    /// If a `(row, col)` pair occurs more than once, the last value wins.
    ///
    /// # Errors
    ///
    /// Returns `MatrixError::MatrixDimensionMismatchError` unless all three slices have the
    /// same length.
    pub fn from_slices(
        rows: &[usize],
        cols: &[usize],
        vals: &[T],
        shape: Shape,
    ) -> Result<Self, MatrixError> {
        // Any single mismatch is an error. The previous `&&` only rejected the input when
        // both comparisons failed, silently truncating to the shortest slice otherwise.
        if rows.len() != cols.len() || cols.len() != vals.len() {
            return Err(MatrixError::MatrixDimensionMismatchError);
        }
        let data: SparseMatrixData<T> = rows
            .iter()
            .zip(cols)
            .zip(vals)
            .map(|((&i, &j), &val)| ((i, j), val))
            .collect();
        Ok(Self::new(data, shape))
    }

    /// Reads the file at `path` and parses its contents with the `FromStr` implementation.
    ///
    /// # Errors
    ///
    /// Returns `MatrixError::MatrixFileReadError(path)` if the file cannot be read and
    /// `MatrixError::MatrixParseError` if its contents cannot be parsed.
    pub fn from_file(path: &'static str) -> Result<Self, MatrixError> {
        let contents =
            fs::read_to_string(path).map_err(|_| MatrixError::MatrixFileReadError(path))?;
        contents
            .parse::<Self>()
            .map_err(|_| MatrixError::MatrixParseError)
    }

    /// Returns the element at `(i, j)`, or `None` when the position is out of bounds.
    ///
    /// In-bounds positions without a stored entry yield `Some(T::zero())`. An out-of-bounds
    /// request also prints a message to stderr.
    pub fn get(&self, i: usize, j: usize) -> Option<T> {
        // Check each coordinate on its own: a flattened-index check accepts an
        // out-of-range column on every row but the last.
        if i >= self.nrows || j >= self.ncols {
            eprintln!("Error, index out of bounds. Not getting value");
            return None;
        }
        Some(self.at(i, j))
    }

    /// Creates a matrix of the given shape and inserts random values drawn from
    /// `start..=end` at random free positions until the matrix sparsity is no longer
    /// greater than `sparsity`.
    ///
    /// `sparsity` is the target fraction of positions without a stored entry; values below
    /// `0.0` are treated as `0.0`. Zero draws are skipped because zeros are never stored.
    /// An empty shape, or a range that contains only zero, yields a matrix without entries.
    ///
    /// # Panics
    ///
    /// Panics if `start > end` (empty sampling range).
    pub fn randomize_range(start: T, end: T, sparsity: f64, shape: Shape) -> Self {
        let (rows, cols) = shape;
        let mut matrix = Self::init(rows, cols);
        let total = matrix.size();

        // Nothing can ever be stored: no positions, or every possible draw is zero.
        if total == 0 || (start == T::zero() && end == T::zero()) {
            return matrix;
        }

        // A negative target can never be reached and would loop forever.
        let target = if sparsity < 0.0 { 0.0 } else { sparsity };

        let mut rng = rand::thread_rng();
        // Same formula as `sparsity()`, evaluated from the entry count directly.
        while 1.0 - matrix.data.len() as f64 / total as f64 > target {
            let value: T = rng.gen_range(start..=end);
            let row: usize = rng.gen_range(0..rows);
            let col: usize = rng.gen_range(0..cols);
            if value != T::zero() {
                // Occupied positions keep their value.
                matrix.data.entry((row, col)).or_insert(value);
            }
        }
        matrix
    }

    /// Like [`Self::randomize_range`] with the range `T::zero()..=T::one()`.
    pub fn randomize(sparcity: f64, shape: Shape) -> Self {
        Self::randomize_range(T::zero(), T::one(), sparcity, shape)
    }

    /// Like [`Self::randomize_range`], taking sparsity and shape from `matrix`.
    pub fn randomize_range_like(start: T, end: T, matrix: &Self) -> Self {
        Self::randomize_range(start, end, matrix.sparsity(), matrix.shape())
    }

    /// Like [`Self::randomize`], taking sparsity and shape from `matrix`.
    pub fn random_like(matrix: &Self) -> Self {
        Self::randomize(matrix.sparsity(), matrix.shape())
    }

    /// Returns the value stored under `(i, j)`, or `T::zero()` when nothing is stored there.
    ///
    /// No bounds check is performed.
    #[inline(always)]
    pub fn at(&self, i: usize, j: usize) -> T {
        self.data.get(&(i, j)).copied().unwrap_or_else(T::zero)
    }

    /// Stores `value` at position `idx`, replacing any previous value.
    ///
    /// Nothing is stored (and a message is printed to stderr) when `value` equals
    /// `T::zero()` or when `idx` lies outside the matrix.
    pub fn set(&mut self, value: T, idx: Shape) {
        if value == T::zero() {
            eprintln!("You are trying to insert a 0 value.");
            return;
        }
        // Check each coordinate on its own: a flattened-index check accepts an
        // out-of-range column on every row but the last.
        if idx.0 >= self.nrows || idx.1 >= self.ncols {
            eprintln!("Error, index out of bounds. Not setting value");
            return;
        }
        self.data.insert(idx, value);
    }

    /// Same as [`Self::set`] with the arguments in `(i, j, value)` order.
    pub fn insert(&mut self, i: usize, j: usize, value: T) {
        self.set(value, (i, j));
    }

    /// Prints every stored entry to stdout as `i j: value`, with `decimals` digits after
    /// the decimal point. Entries appear in the hash map's iteration order.
    pub fn print(&self, decimals: usize) {
        for ((i, j), val) in self.data.iter() {
            println!("{i} {j}: {:.decimals$}", val);
        }
    }

    /// Total number of positions, `nrows * ncols`.
    #[inline(always)]
    pub fn size(&self) -> usize {
        self.ncols * self.nrows
    }

    /// Number of positions without a stored entry.
    ///
    /// Saturates at zero if more entries are stored than the shape has positions.
    #[inline(always)]
    pub fn get_zero_count(&self) -> usize {
        self.size().saturating_sub(self.data.len())
    }

    /// Fraction of positions without a stored entry: `1 - stored / size`.
    ///
    /// The result is NaN when `size()` is zero (floating-point `0 / 0`).
    #[inline(always)]
    pub fn sparsity(&self) -> f64 {
        1.0 - self.data.len() as f64 / self.size() as f64
    }

    /// Returns `(nrows, ncols)`.
    pub fn shape(&self) -> Shape {
        (self.nrows, self.ncols)
    }

    /// Transposes the matrix in place: every key `(i, j)` becomes `(j, i)` and the
    /// dimensions are swapped.
    pub fn transpose(&mut self) {
        self.data = std::mem::take(&mut self.data)
            .into_iter()
            .map(|((i, j), val)| ((j, i), val))
            .collect();
        std::mem::swap(&mut self.nrows, &mut self.ncols);
    }

    /// Shorthand for [`Self::transpose`].
    pub fn t(&mut self) {
        self.transpose();
    }

    /// Returns a transposed copy and leaves `self` untouched.
    pub fn transpose_new(&self) -> Self {
        let mut res = self.clone();
        res.transpose();
        res
    }

    /// Largest element of the matrix.
    ///
    /// Positions without a stored entry count as zero, so the result is never negative
    /// while at least one such position exists. A matrix without entries yields `T::zero()`.
    ///
    /// # Panics
    ///
    /// Panics if two stored values cannot be ordered (e.g. NaN).
    pub fn max(&self) -> T {
        let stored_max = self.data.values().copied().max_by(|a, b| {
            a.partial_cmp(b)
                .expect("matrix values must be comparable (no NaN)")
        });
        match stored_max {
            Some(v) if v > T::zero() || self.get_zero_count() == 0 => v,
            _ => T::zero(),
        }
    }

    /// Smallest element of the matrix.
    ///
    /// Positions without a stored entry count as zero, so the result is never positive
    /// while at least one such position exists. A matrix without entries yields `T::zero()`.
    ///
    /// # Panics
    ///
    /// Panics if two stored values cannot be ordered (e.g. NaN).
    pub fn min(&self) -> T {
        // `min_by`, not `max_by`: the previous code returned the maximum here.
        let stored_min = self.data.values().copied().min_by(|a, b| {
            a.partial_cmp(b)
                .expect("matrix values must be comparable (no NaN)")
        });
        match stored_min {
            Some(v) if v < T::zero() || self.get_zero_count() == 0 => v,
            _ => T::zero(),
        }
    }

    /// Returns a new matrix of the same shape with every stored value negated.
    pub fn neg(&self) -> Self {
        let data = self
            .data
            .par_iter()
            .map(|(&idx, &val)| (idx, val.neg()))
            .collect::<SparseMatrixData<T>>();
        Self::new(data, self.shape())
    }

    /// Arithmetic mean over all `nrows * ncols` positions (stored sum divided by `size()`).
    ///
    /// Returns `T::zero()` for a matrix with no positions.
    ///
    /// # Panics
    ///
    /// Panics if `size()` cannot be represented in `T`.
    pub fn avg(&self) -> T {
        let total = self.size();
        if total == 0 {
            return T::zero();
        }
        let sum = self.data.par_iter().map(|(_, &val)| val).sum::<T>();
        // The element count is converted to `T` through its decimal text form.
        let count = total
            .to_string()
            .parse::<T>()
            .expect("element count must be representable in the element type");
        sum / count
    }

    /// Alias for [`Self::avg`].
    pub fn mean(&self) -> T {
        self.avg()
    }

    /// Median over all `nrows * ncols` positions, counting positions without a stored
    /// entry as zero.
    ///
    /// For an even number of positions the two middle elements are added and divided by
    /// two using `T`'s own division. Returns `T::zero()` for a matrix with no positions.
    ///
    /// # Panics
    ///
    /// Panics if two stored values cannot be ordered (e.g. NaN).
    pub fn median(&self) -> T {
        let total = self.size();
        if total == 0 {
            return T::zero();
        }

        let mut stored: Vec<T> = self.data.values().copied().collect();
        stored.sort_unstable_by(|a, b| {
            a.partial_cmp(b)
                .expect("matrix values must be comparable (no NaN)")
        });

        // In the full sorted sequence the implicit zeros sit between the negative and the
        // non-negative stored values.
        let zeros = total.saturating_sub(stored.len());
        let negatives = stored.iter().take_while(|&&v| v < T::zero()).count();

        // k-th smallest element (0-based) of the full matrix, without materialising zeros.
        let kth = |k: usize| -> T {
            if k < negatives {
                stored[k]
            } else if k < negatives + zeros {
                T::zero()
            } else {
                stored[k - zeros]
            }
        };

        let mid = total / 2;
        if total % 2 == 1 {
            kth(mid)
        } else {
            (kth(mid - 1) + kth(mid)) / (T::one() + T::one())
        }
    }
}
/// Element-wise floating-point functions.
///
/// Every method below applies the function to the stored entries only and returns a new
/// matrix of the same shape; positions without a stored entry are not visited.
impl<'a, T> LinAlgFloats<'a, T> for SparseMatrix<'a, T>
where
    T: MatrixElement + Float,
    <T as FromStr>::Err: Error + 'static,
    Vec<T>: IntoParallelIterator,
    Vec<&'a T>: IntoParallelRefIterator<'a>,
{
    /// Natural logarithm of every stored entry.
    fn ln(&self) -> Self {
        Self::new(
            self.data
                .iter()
                .map(|(&pos, &val)| (pos, val.ln()))
                .collect(),
            self.shape(),
        )
    }

    /// Logarithm to the given `base` of every stored entry.
    fn log(&self, base: T) -> Self {
        Self::new(
            self.data
                .iter()
                .map(|(&pos, &val)| (pos, val.log(base)))
                .collect(),
            self.shape(),
        )
    }

    /// Sine of every stored entry.
    fn sin(&self) -> Self {
        Self::new(
            self.data
                .iter()
                .map(|(&pos, &val)| (pos, val.sin()))
                .collect(),
            self.shape(),
        )
    }

    /// Cosine of every stored entry.
    fn cos(&self) -> Self {
        Self::new(
            self.data
                .iter()
                .map(|(&pos, &val)| (pos, val.cos()))
                .collect(),
            self.shape(),
        )
    }

    /// Tangent of every stored entry.
    fn tan(&self) -> Self {
        Self::new(
            self.data
                .iter()
                .map(|(&pos, &val)| (pos, val.tan()))
                .collect(),
            self.shape(),
        )
    }

    /// Square root of every stored entry.
    fn sqrt(&self) -> Self {
        Self::new(
            self.data
                .iter()
                .map(|(&pos, &val)| (pos, val.sqrt()))
                .collect(),
            self.shape(),
        )
    }

    /// Hyperbolic sine of every stored entry.
    fn sinh(&self) -> Self {
        Self::new(
            self.data
                .iter()
                .map(|(&pos, &val)| (pos, val.sinh()))
                .collect(),
            self.shape(),
        )
    }

    /// Hyperbolic cosine of every stored entry.
    fn cosh(&self) -> Self {
        Self::new(
            self.data
                .iter()
                .map(|(&pos, &val)| (pos, val.cosh()))
                .collect(),
            self.shape(),
        )
    }

    /// Hyperbolic tangent of every stored entry.
    fn tanh(&self) -> Self {
        Self::new(
            self.data
                .iter()
                .map(|(&pos, &val)| (pos, val.tanh()))
                .collect(),
            self.shape(),
        )
    }

    /// Not implemented; calling this panics.
    fn get_eigenvalues(&self) -> Option<Vec<T>> {
        unimplemented!()
    }

    /// Not implemented; calling this panics.
    fn get_eigenvectors(&self) -> Option<Vec<T>> {
        unimplemented!()
    }
}
/// Arithmetic entry points.
///
/// Apart from `matmul_sparse`, every method here forwards to a `sparse_helper*` function
/// (not defined in this part of the file) together with an `Operation` variant.
/// NOTE(review): the helpers presumably work element-wise — confirm against their definitions.
impl<'a, T> SparseMatrix<'a, T>
where
T: MatrixElement,
<T as FromStr>::Err: Error + 'static,
Vec<T>: IntoParallelIterator,
Vec<&'a T>: IntoParallelRefIterator<'a>,
{
/// Combines `self` and `other` via `sparse_helper` with `Operation::ADD`, returning a new matrix.
pub fn add(&self, other: &Self) -> Result<Self, MatrixError> {
Self::sparse_helper(&self, other, Operation::ADD)
}
/// Combines `self` and `other` via `sparse_helper` with `Operation::SUB`, returning a new matrix.
pub fn sub(&self, other: &Self) -> Result<Self, MatrixError> {
Self::sparse_helper(&self, other, Operation::SUB)
}
/// Combines `self` and `other` via `sparse_helper` with `Operation::MUL`, returning a new matrix.
///
/// This is a different code path from [`Self::matmul_sparse`].
pub fn mul(&self, other: &Self) -> Result<Self, MatrixError> {
Self::sparse_helper(&self, other, Operation::MUL)
}
/// Forwards to [`Self::mul`].
///
/// NOTE(review): despite the name this does not call `matmul_sparse` — confirm intended.
pub fn dot(&self, other: &Self) -> Result<Self, MatrixError> {
self.mul(other)
}
/// Combines `self` and `other` via `sparse_helper` with `Operation::DIV`, returning a new matrix.
pub fn div(&self, other: &Self) -> Result<Self, MatrixError> {
Self::sparse_helper(&self, other, Operation::DIV)
}
/// In-place variant: calls `sparse_helper_self` with `Operation::ADD`; its result is discarded.
pub fn add_self(&mut self, other: &Self) {
Self::sparse_helper_self(self, other, Operation::ADD);
}
/// In-place variant: calls `sparse_helper_self` with `Operation::SUB`; its result is discarded.
pub fn sub_self(&mut self, other: &Self) {
Self::sparse_helper_self(self, other, Operation::SUB);
}
/// In-place variant: calls `sparse_helper_self` with `Operation::MUL`; its result is discarded.
pub fn mul_self(&mut self, other: &Self) {
Self::sparse_helper_self(self, other, Operation::MUL);
}
/// In-place variant: calls `sparse_helper_self` with `Operation::DIV`; its result is discarded.
pub fn div_self(&mut self, other: &Self) {
Self::sparse_helper_self(self, other, Operation::DIV);
}
/// Combines `self` with the scalar `value` via `sparse_helper_val` and `Operation::ADD`.
pub fn add_val(&self, value: T) -> Self {
Self::sparse_helper_val(self, value, Operation::ADD)
}
/// Combines `self` with the scalar `value` via `sparse_helper_val` and `Operation::SUB`.
pub fn sub_val(&self, value: T) -> Self {
Self::sparse_helper_val(self, value, Operation::SUB)
}
/// Combines `self` with the scalar `value` via `sparse_helper_val` and `Operation::MUL`.
pub fn mul_val(&self, value: T) -> Self {
Self::sparse_helper_val(self, value, Operation::MUL)
}
/// Combines `self` with the scalar `value` via `sparse_helper_val` and `Operation::DIV`.
pub fn div_val(&self, value: T) -> Self {
Self::sparse_helper_val(self, value, Operation::DIV)
}
/// In-place scalar variant: calls `sparse_helper_self_val` with `Operation::ADD`.
pub fn add_val_self(&mut self, value: T) {
Self::sparse_helper_self_val(self, value, Operation::ADD)
}
/// In-place scalar variant: calls `sparse_helper_self_val` with `Operation::SUB`.
pub fn sub_val_self(&mut self, value: T) {
Self::sparse_helper_self_val(self, value, Operation::SUB)
}
/// In-place scalar variant: calls `sparse_helper_self_val` with `Operation::MUL`.
pub fn mul_val_self(&mut self, value: T) {
Self::sparse_helper_self_val(self, value, Operation::MUL)
}
/// In-place scalar variant: calls `sparse_helper_self_val` with `Operation::DIV`.
pub fn div_val_self(&mut self, value: T) {
Self::sparse_helper_self_val(self, value, Operation::DIV)
}
/// Matrix product of `self` and `other`.
///
/// # Errors
///
/// Returns `MatrixError::MatrixMultiplicationDimensionMismatchError` when
/// `self.ncols != other.nrows`.
pub fn matmul_sparse(&self, other: &Self) -> Result<Self, MatrixError> {
if self.ncols != other.nrows {
return Err(MatrixError::MatrixMultiplicationDimensionMismatchError.into());
}
// Equal shapes (together with the check above: both operands square) use the `_nn`
// routine; every other compatible pair uses the `_mnnp` routine.
if self.shape() == other.shape() {
return Ok(self.matmul_sparse_nn(other));
}
Ok(self.matmul_sparse_mnnp(other))
}
}
/// Predicate-based queries over the stored entries.
///
/// All methods look only at entries present in `data`; positions without a stored entry
/// are never passed to the predicate. The map's iteration order is unspecified, so methods
/// that return "the first" match return an arbitrary one when several entries match.
impl<'a, T> SparseMatrix<'a, T>
where
    T: MatrixElement,
    <T as FromStr>::Err: Error + 'static,
    Vec<T>: IntoParallelIterator,
    Vec<&'a T>: IntoParallelRefIterator<'a>,
{
    /// Returns `true` when `pred` holds for every stored `(index, value)` pair
    /// (vacuously `true` when nothing is stored). Evaluated in parallel.
    pub fn all<F>(&self, pred: F) -> bool
    where
        F: Fn((Shape, T)) -> bool + Sync + Send,
    {
        // Iterate by reference and copy each pair, rather than cloning the whole map.
        self.data.par_iter().all(|(&idx, &val)| pred((idx, val)))
    }

    /// Returns `true` when `pred` holds for at least one stored `(index, value)` pair.
    /// Evaluated in parallel.
    pub fn any<F>(&self, pred: F) -> bool
    where
        F: Fn((Shape, T)) -> bool + Sync + Send,
    {
        // Iterate by reference and copy each pair, rather than cloning the whole map.
        self.data.par_iter().any(|(&idx, &val)| pred((idx, val)))
    }

    /// Counts the stored entries for which `pred` returns `true`. Evaluated in parallel.
    pub fn count_where<F>(&self, pred: F) -> usize
    where
        F: Fn((&Shape, &T)) -> bool + Sync,
    {
        self.data.par_iter().filter(|&entry| pred(entry)).count()
    }

    /// Sums the stored values for which `pred` returns `true`, starting from `T::zero()`.
    pub fn sum_where<F>(&self, pred: F) -> T
    where
        F: Fn((&Shape, &T)) -> bool + Sync,
    {
        let mut total = T::zero();
        for (_, elem) in self.data.iter().filter(|&entry| pred(entry)) {
            total += elem;
        }
        total
    }

    /// Calls `pred` once for every stored entry with mutable access to its value.
    ///
    /// No check is made on the values written; in particular a value set to zero stays
    /// stored as an explicit entry.
    pub fn set_where<F>(&mut self, mut pred: F)
    where
        F: FnMut((&Shape, &mut T)) + Sync + Send,
    {
        for entry in self.data.iter_mut() {
            pred(entry);
        }
    }

    /// Returns the value of a stored entry matching `pred`, or `None` if none matches.
    pub fn find<F>(&self, pred: F) -> Option<T>
    where
        F: Fn((&Shape, &T)) -> bool + Sync,
    {
        self.data
            .iter()
            .find(|&entry| pred(entry))
            .map(|(_, &val)| val)
    }

    /// Returns the values of all stored entries matching `pred`, or `None` if none matches.
    pub fn find_all<F>(&self, pred: F) -> Option<Vec<T>>
    where
        F: Fn((&Shape, &T)) -> bool + Sync,
    {
        let hits: Vec<T> = self
            .data
            .iter()
            .filter(|&entry| pred(entry))
            .map(|(_, &val)| val)
            .collect();
        if hits.is_empty() {
            None
        } else {
            Some(hits)
        }
    }

    /// Returns the `(row, column)` key of a stored entry matching `pred`, or `None` if none
    /// matches.
    pub fn position<F>(&self, pred: F) -> Option<Shape>
    where
        F: Fn((&Shape, &T)) -> bool + Sync,
    {
        self.data
            .iter()
            .find(|&entry| pred(entry))
            .map(|(&idx, _)| idx)
    }

    /// Returns the `(row, column)` keys of all stored entries matching `pred`, or `None` if
    /// none matches.
    pub fn positions<F>(&self, pred: F) -> Option<Vec<Shape>>
    where
        F: Fn((&Shape, &T)) -> bool + Sync,
    {
        let hits: Vec<Shape> = self
            .data
            .iter()
            .filter(|&entry| pred(entry))
            .map(|(&idx, _)| idx)
            .collect();
        if hits.is_empty() {
            None
        } else {
            Some(hits)
        }
    }
}