linalg_rs/sparse/
mod.rs

1//!  Module for defining sparse matrices.
2//!
3//! # What are sparse matrices
4//!
5//! Generally speaking, matrices with a lot of 0s
6//!
7//! # How are they represented
8//!
//! Since storing large sparse matrices densely in memory is expensive,
//! only the non-zero values are kept.
10//!
11//!
12//! # What data structure does sukker use?
13//!
//! For now, a hash map where the keys are indices in the matrix
//! and the value is the value at that 2d index
16#![warn(missing_docs)]
17
18mod helper;
19
20use helper::*;
21use num_traits::Float;
22use rand::Rng;
23
24use itertools::Itertools;
25use std::fmt::Display;
26use std::fs;
27use std::{collections::HashMap, error::Error, marker::PhantomData, str::FromStr};
28
29use rayon::prelude::*;
30use serde::{Deserialize, Serialize};
31
32use crate::{at, LinAlgFloats, Matrix, MatrixElement, MatrixError, Operation, Shape};
33
/// SparseMatrixData represents the datatype used to store information
/// about non-zero values in a general matrix.
///
/// The key is the `(row, col)` position of an entry inside the matrix,
/// while the value is the value stored at that position.
///
/// The lifetime parameter is not used by the aliased type itself.
pub type SparseMatrixData<'a, T> = HashMap<Shape, T>;
40
#[derive(Clone, PartialEq, Debug, Serialize, Deserialize)]
/// Represents a sparse matrix and its data
///
/// Only explicitly stored entries live in `data`; the accessors in this
/// module treat every position without an entry as zero.
pub struct SparseMatrix<'a, T>
where
    T: MatrixElement,
    <T as FromStr>::Err: Error + 'static,
    Vec<T>: IntoParallelIterator,
    Vec<&'a T>: IntoParallelRefIterator<'a>,
{
    /// Hash map from `(row, col)` position to the value stored there
    pub data: SparseMatrixData<'a, T>,
    /// Number of rows
    pub nrows: usize,
    /// Number of columns
    pub ncols: usize,
    // Zero-sized marker so the `'a` lifetime used by the bounds above
    // appears in a field of the struct.
    _lifetime: PhantomData<&'a T>,
}
58
// Marker implementation: no methods of `Error` are overridden, so all
// defaults apply.
// NOTE(review): it is unclear from this file why a matrix type needs to be
// usable as an error value — confirm before relying on it.
impl<'a, T> Error for SparseMatrix<'a, T>
where
    T: MatrixElement,
    <T as FromStr>::Err: Error + 'static,
    Vec<T>: IntoParallelIterator,
    Vec<&'a T>: IntoParallelRefIterator<'a>,
{
}
67
// SAFETY: NOTE(review) — the struct only holds a `HashMap<Shape, T>`, two
// `usize` values and a `PhantomData<&'a T>`. This impl is sound only if `T`
// itself can be sent/shared across threads; presumably `MatrixElement`
// guarantees that — TODO confirm against the trait definition.
unsafe impl<'a, T> Send for SparseMatrix<'a, T>
where
    T: MatrixElement,
    <T as FromStr>::Err: Error + 'static,
    Vec<T>: IntoParallelIterator,
    Vec<&'a T>: IntoParallelRefIterator<'a>,
{
}
76
// SAFETY: NOTE(review) — the struct has no interior mutability of its own
// (a `HashMap<Shape, T>`, two `usize` values and a `PhantomData<&'a T>`).
// This impl is sound only if `T` can be shared across threads; presumably
// `MatrixElement` guarantees that — TODO confirm against the trait definition.
unsafe impl<'a, T> Sync for SparseMatrix<'a, T>
where
    T: MatrixElement,
    <T as FromStr>::Err: Error + 'static,
    Vec<T>: IntoParallelIterator,
    Vec<&'a T>: IntoParallelRefIterator<'a>,
{
}
85
86impl<'a, T> FromStr for SparseMatrix<'a, T>
87where
88    T: MatrixElement,
89    <T as FromStr>::Err: Error + 'static,
90{
91    type Err = anyhow::Error;
92    fn from_str(s: &str) -> Result<Self, Self::Err> {
93        // Parse the input string and construct the matrix dynamically
94        let data = s
95            .trim()
96            .lines()
97            .skip(1)
98            .map(|l| {
99                let entry: Vec<&str> = l.split_whitespace().collect();
100
101                let row = entry[0].parse::<usize>().unwrap();
102                let col = entry[1].parse::<usize>().unwrap();
103                let val = entry[2].parse::<T>().unwrap();
104
105                ((row, col), val)
106            })
107            .collect::<SparseMatrixData<T>>();
108
109        let dims = s
110            .trim()
111            .lines()
112            .nth(0)
113            .unwrap()
114            .split_whitespace()
115            .map(|e| e.parse::<usize>().unwrap())
116            .collect::<Vec<usize>>();
117
118        Ok(Self::new(data, (dims[0], dims[1])))
119    }
120}
121
122impl<'a, T> Display for SparseMatrix<'a, T>
123where
124    T: MatrixElement,
125    <T as FromStr>::Err: Error + 'static,
126    Vec<T>: IntoParallelIterator,
127    Vec<&'a T>: IntoParallelRefIterator<'a>,
128{
129    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
130        for i in 0..self.nrows {
131            for j in 0..self.ncols {
132                let elem = match self.data.get(&(i, j)) {
133                    Some(&val) => val,
134                    None => T::zero(),
135                };
136
137                write!(f, "{elem} ");
138            }
139            writeln!(f);
140        }
141        writeln!(f, "\ndtype = {}", std::any::type_name::<T>())
142    }
143}
144
145impl<'a, T> Default for SparseMatrix<'a, T>
146where
147    T: MatrixElement,
148    <T as FromStr>::Err: Error + 'static,
149    Vec<T>: IntoParallelIterator,
150    Vec<&'a T>: IntoParallelRefIterator<'a>,
151{
152    /// Returns a sparse 3x3 identity matrix
153    fn default() -> Self {
154        Self {
155            data: HashMap::new(),
156            nrows: 0,
157            ncols: 0,
158            _lifetime: PhantomData::default(),
159        }
160    }
161}
162
163impl<'a, T> SparseMatrix<'a, T>
164where
165    T: MatrixElement,
166    <T as FromStr>::Err: Error + 'static,
167    Vec<T>: IntoParallelIterator,
168    Vec<&'a T>: IntoParallelRefIterator<'a>,
169{
170    /// Constructs a new sparse matrix based on a hashmap
171    /// containing the indices where value is not 0
172    ///
173    /// This function does not check whether or not the
174    /// indices are valid and according to shape. Use `reshape`
175    /// to fix this issue.
176    ///
177    /// Examples
178    ///
179    /// ```
180    /// use std::collections::HashMap;
181    /// use sukker::{smd, SparseMatrix, SparseMatrixData};
182    ///
183    /// // Here we can use the smd! macro
184    /// // to easily be able to set up a new hashmap
185    /// let indexes: SparseMatrixData<f64> = smd![
186    ///     ( (0, 0), 2.0),
187    ///     ( (0, 3), 4.0),
188    ///     ( (4, 5), 6.0),
189    ///     ( (2, 7), 8.0)
190    /// ];
191    ///
192    /// let sparse = SparseMatrix::<f64>::new(indexes, (3,3));
193    ///
194    /// assert_eq!(sparse.shape(), (3,3));
195    /// assert_eq!(sparse.get(4,5), None);
196    /// assert_eq!(sparse.get(0,1), Some(0.0));
197    /// ```
198    pub fn new(data: SparseMatrixData<'a, T>, shape: Shape) -> Self {
199        Self {
200            data,
201            nrows: shape.0,
202            ncols: shape.1,
203            _lifetime: PhantomData::default(),
204        }
205    }
206
207    /// Inits an empty matrix based on shape
208    pub fn init(nrows: usize, ncols: usize) -> Self {
209        Self {
210            data: HashMap::new(),
211            nrows,
212            ncols,
213            _lifetime: PhantomData::default(),
214        }
215    }
216
217    /// Returns a sparse eye matrix
218    ///
219    /// Examples
220    ///
221    /// ```
222    /// use sukker::SparseMatrix;
223    ///
224    /// let sparse = SparseMatrix::<i32>::eye(3);
225    ///
226    /// assert_eq!(sparse.ncols, 3);
227    /// assert_eq!(sparse.nrows, 3);
228    /// ```
229    pub fn eye(size: usize) -> Self {
230        let data: SparseMatrixData<'a, T> = (0..size)
231            .into_par_iter()
232            .map(|i| ((i, i), T::one()))
233            .collect();
234
235        Self::new(data, (size, size))
236    }
237
238    /// Produces an eye with the same shape as another
239    /// sparse matrix
240    pub fn eye_like(matrix: &Self) -> Self {
241        Self::eye(matrix.nrows)
242    }
243
244    /// Same as eye
245    ///
246    /// Examples
247    ///
248    /// ```
249    /// use sukker::SparseMatrix;
250    ///
251    /// let sparse = SparseMatrix::<f64>::identity(3);
252    ///
253    /// assert_eq!(sparse.ncols, 3);
254    /// assert_eq!(sparse.nrows, 3);
255    /// ```
    pub fn identity(size: usize) -> Self {
        // Pure alias: forwards to `eye` unchanged.
        Self::eye(size)
    }
259
    /// Creates a matrix whose stored values are all one, placed at random
    /// locations
    ///
    /// Same as `randomize_range` with both ends of the range set to one
    pub fn ones(sparsity: f64, shape: Shape) -> Self {
        Self::randomize_range(T::one(), T::one(), sparsity, shape)
    }
267
268    /// Reshapes a sparse matrix
269    ///
270    /// Examples
271    ///
272    /// ```
273    /// use sukker::SparseMatrix;
274    ///
275    /// let mut sparse = SparseMatrix::<f64>::identity(3);
276    ///
277    /// sparse.reshape(5,5);
278    ///
279    /// assert_eq!(sparse.ncols, 5);
280    /// assert_eq!(sparse.nrows, 5);
281    /// ```
    pub fn reshape(&mut self, nrows: usize, ncols: usize) {
        // Only the recorded dimensions change; the entries in `data` are
        // neither moved nor removed.
        self.nrows = nrows;
        self.ncols = ncols;
    }
286
287    /// Creates a sparse matrix from a already existent
288    /// dense one.
289    ///
290    /// Examples:
291    ///
292    /// ```
293    /// use sukker::{SparseMatrix, Matrix};
294    ///
295    /// let dense = Matrix::<i32>::eye(4);
296    ///
297    /// let sparse = SparseMatrix::from_dense(dense);
298    ///
299    /// assert_eq!(sparse.get(0,0), Some(1));
300    /// assert_eq!(sparse.get(1,0), Some(0));
301    /// assert_eq!(sparse.shape(), (4,4));
302    /// ```
303    pub fn from_dense(matrix: Matrix<'a, T>) -> Self {
304        let mut data: SparseMatrixData<'a, T> = HashMap::new();
305
306        for i in 0..matrix.nrows {
307            for j in 0..matrix.ncols {
308                let val = matrix.get(i, j).unwrap();
309                if val != T::zero() {
310                    data.insert((i, j), val);
311                }
312            }
313        }
314
315        Self::new(data, matrix.shape())
316    }
317
318    /// Constructs a sparse matrix from 3 slices.
319    /// One for the rows, one for the cols, and one for the value.
320    /// A combination of values fromt the same index corresponds to
321    /// an entry in the hasmap.
322    ///
323    /// Csc in numpy uses 3 lists of same size
324    ///
325    /// Examples:
326    ///
327    /// ```
328    /// use sukker::SparseMatrix;
329    ///
330    /// let rows = vec![0,1,2,3];
331    /// let cols = vec![1,2,3,4];
332    /// let vals= vec![0.0,1.3,0.05,4.53];
333    ///
334    /// let shape = (6,7);
335    ///
336    /// let sparse = SparseMatrix::from_slices(&rows, &cols, &vals, shape).unwrap();
337    ///
338    /// assert_eq!(sparse.shape(), (6,7));
339    /// assert_eq!(sparse.at(1,2), 1.3);
340    /// assert_eq!(sparse.at(0,1), 0.0);
341    /// ```
342    pub fn from_slices(
343        rows: &[usize],
344        cols: &[usize],
345        vals: &[T],
346        shape: Shape,
347    ) -> Result<Self, MatrixError> {
348        if rows.len() != cols.len() && cols.len() != vals.len() {
349            return Err(MatrixError::MatrixDimensionMismatchError.into());
350        }
351
352        let data: SparseMatrixData<T> = rows
353            .iter()
354            .zip(cols.iter().zip(vals.iter()))
355            .map(|(&i, (&j, &val))| ((i, j), val))
356            .collect();
357
358        Ok(Self::new(data, shape))
359    }
360
    /// Reads the file at `path` and parses its contents into a sparse matrix.
    ///
    /// The file contents are parsed through the `FromStr` implementation
    /// of `SparseMatrix`.
    ///
    /// # Errors
    ///
    /// Returns `MatrixError::MatrixFileReadError` if the file can not be
    /// read, and `MatrixError::MatrixParseError` if parsing reports an error.
    ///
    /// # Examples
    ///
    /// ```
    /// use sukker::SparseMatrix;
    ///
    /// // let m: SparseMatrix<f32> = SparseMatrix::from_file("../../test.txt").unwrap();
    ///
    /// // m.print(4);
    /// ```
    pub fn from_file(path: &'static str) -> Result<Self, MatrixError> {
        let data =
            fs::read_to_string(path).map_err(|_| MatrixError::MatrixFileReadError(path).into())?;

        data.parse::<Self>()
            .map_err(|_| MatrixError::MatrixParseError.into())
    }
379
380    /// Gets an element from the sparse matrix.
381    ///
382    /// Returns None if index is out of bounds.
383    ///
384    /// Examples
385    ///
386    /// ```
387    /// use sukker::SparseMatrix;
388    ///
389    /// let sparse = SparseMatrix::<i32>::eye(3);
390    ///
391    /// assert_eq!(sparse.get(0,0), Some(1));
392    /// assert_eq!(sparse.get(1,0), Some(0));
393    /// assert_eq!(sparse.get(4,0), None);
394    /// ```
395    pub fn get(&self, i: usize, j: usize) -> Option<T> {
396        let idx = at!(i, j, self.ncols);
397
398        if idx >= self.size() {
399            eprintln!("Error, index out of bounds. Not setting value");
400            return None;
401        }
402
403        match self.data.get(&(i, j)) {
404            None => Some(T::zero()),
405            val => val.copied(),
406        }
407    }
408
409    /// Gets the size of the sparse matrix
410    ///
411    /// Examples
412    ///
413    /// ```
414    /// use sukker::SparseMatrix;
415    ///
416    /// let sparse = SparseMatrix::<f32>::randomize_range(1.0,2.0, 0.75, (4,4));
417    ///
418    /// assert_eq!(sparse.shape(), (4,4));
419    /// assert_eq!(sparse.sparsity(), 0.75);
420    /// assert_eq!(sparse.all(|(_, val)| val >= 1.0 && val <= 2.0), true);
421    /// assert_eq!(sparse.size(), 16);
422    /// ```
423    pub fn randomize_range(start: T, end: T, sparsity: f64, shape: Shape) -> Self {
424        let mut rng = rand::thread_rng();
425
426        let (rows, cols) = shape;
427
428        // If we insert in a position that's already filled up,
429        // we ahve to get a new one
430        let mut matrix = Self::init(shape.0, shape.1);
431
432        while matrix.sparsity() > sparsity {
433            let value: T = rng.gen_range(start..=end);
434
435            let row: usize = rng.gen_range(0..rows);
436            let col: usize = rng.gen_range(0..cols);
437
438            match matrix.data.get(&(row, col)) {
439                Some(_) => {}
440                None => matrix.set(value, (row, col)),
441            }
442        }
443
444        matrix
445    }
446
447    /// Randomizes a sparse matrix with values between 0 and 1.
448    ///
449    /// Examples
450    ///
451    /// ```
452    /// use sukker::SparseMatrix;
453    ///
454    /// let sparse = SparseMatrix::<f32>::randomize(0.75, (4,4));
455    ///
456    /// assert_eq!(sparse.shape(), (4,4));
457    /// assert_eq!(sparse.sparsity(), 0.75);
458    /// assert_eq!(sparse.all(|(_, val)| val >= 0.0 && val <= 1.0), true);
459    /// assert_eq!(sparse.size(), 16);
460    /// ```
    pub fn randomize(sparcity: f64, shape: Shape) -> Self {
        // Fixed value range of zero to one; everything else is forwarded.
        Self::randomize_range(T::zero(), T::one(), sparcity, shape)
    }
464
465    /// Randomizes a sparse matrix  to have same shape and sparcity as another one
466    /// You can however set the range
467    ///
468    /// Examples
469    ///
470    /// ```
471    /// use sukker::SparseMatrix;
472    ///
473    /// let sparse = SparseMatrix::<f32>::randomize_range(1.0,2.0, 0.75, (4,4));
474    ///
475    /// let copy = SparseMatrix::randomize_range_like(2.0, 4.0, &sparse);
476    ///
477    /// assert_eq!(copy.shape(), (4,4));
478    /// assert_eq!(copy.sparsity(), 0.75);
479    /// assert_eq!(copy.all(|(_, val)| val >= 2.0 && val <= 4.0), true);
480    /// assert_eq!(copy.size(), 16);
481    /// ```
    pub fn randomize_range_like(start: T, end: T, matrix: &Self) -> Self {
        // Sparsity and shape are read from `matrix`; its values are not used.
        Self::randomize_range(start, end, matrix.sparsity(), matrix.shape())
    }
485
486    /// Randomizes a sparse matrix  to have same shape and sparcity as another one
487    /// The values here are set to be between 0 and 1, no matter the value range
488    /// of the matrix whos shape is being copied.
489    ///
490    /// Examples
491    ///
492    /// ```
493    /// use sukker::SparseMatrix;
494    ///
495    /// let sparse = SparseMatrix::<f32>::randomize_range(2.0, 4.0, 0.75, (4,4));
496    ///
497    /// let copy = SparseMatrix::random_like(&sparse);
498    ///
499    /// assert_eq!(copy.shape(), (4,4));
500    /// assert_eq!(copy.sparsity(), 0.75);
501    /// assert_eq!(copy.all(|(_, val)| val >= 0.0 && val <= 1.0), true);
502    /// assert_eq!(copy.size(), 16);
503    /// ```
    pub fn random_like(matrix: &Self) -> Self {
        // Sparsity and shape are read from `matrix`; its values are not used.
        Self::randomize(matrix.sparsity(), matrix.shape())
    }
507
508    /// Same as `get`, but will panic if indexes are out of bounds
509    ///
510    /// Examples:
511    ///
512    /// ```
513    /// use sukker::SparseMatrix;
514    ///
515    /// let sparse = SparseMatrix::<i32>::eye(3);
516    ///
517    /// assert_eq!(sparse.at(0,0), 1);
518    /// assert_eq!(sparse.at(1,0), 0);
519    /// ```
520    #[inline(always)]
521    pub fn at(&self, i: usize, j: usize) -> T {
522        match self.data.get(&(i, j)) {
523            None => T::zero(),
524            Some(val) => val.clone(),
525        }
526    }
527
528    /// Sets an element
529    ///
530    /// If you're trying to insert a zero-value, this function
531    /// does nothing
532    ///
533    /// Mutates or inserts a value based on indeces given
534    pub fn set(&mut self, value: T, idx: Shape) {
535        if value == T::zero() {
536            eprintln!("You are trying to insert a 0 value.");
537            return;
538        }
539
540        let i = at!(idx.0, idx.1, self.ncols);
541
542        if i >= self.size() {
543            eprintln!("Error, index out of bounds. Not setting value");
544            return;
545        }
546
547        self.data
548            .entry(idx)
549            .and_modify(|val| *val = value)
550            .or_insert(value);
551    }
552
    /// A way of inserting with individual row and col
    ///
    /// Forwards to `set` with the position `(i, j)`
    pub fn insert(&mut self, i: usize, j: usize, value: T) {
        self.set(value, (i, j));
    }
557
558    /// Prints out the sparse matrix data
559    ///
560    /// Only prints out the hashmap with a set amount of decimals
561    pub fn print(&self, decimals: usize) {
562        self.data
563            .iter()
564            .for_each(|((i, j), val)| println!("{i} {j}: {:.decimals$}", val));
565    }
566
    /// Gets the size of the sparse matrix, i.e. the number of
    /// positions (`nrows * ncols`), not the number of stored entries
    ///
    /// Examples:
    ///
    /// ```
    /// use sukker::SparseMatrix;
    ///
    /// let sparse = SparseMatrix::<i32>::eye(4);
    ///
    /// assert_eq!(sparse.size(), 16);
    /// ```
    #[inline(always)]
    pub fn size(&self) -> usize {
        self.ncols * self.nrows
    }
581
582    /// Get's amount of 0s in the matrix
583    ///
584    /// Examples:
585    ///
586    /// ```
587    /// use sukker::SparseMatrix;
588    ///
589    /// let sparse = SparseMatrix::<i32>::eye(4);
590    ///
591    /// assert_eq!(sparse.get_zero_count(), 12);
592    #[inline(always)]
593    pub fn get_zero_count(&self) -> usize {
594        self.size() - self.data.len()
595    }
596
597    /// Calcualtes sparcity for the given matrix
598    /// Sparity is defined as the percantage of the matrix
599    /// filled with 0 values
600    ///
601    /// Examples:
602    ///
603    /// ```
604    /// use sukker::SparseMatrix;
605    ///
606    /// let sparse = SparseMatrix::<i32>::eye(4);
607    ///
608    /// assert_eq!(sparse.sparsity(), 0.75);
609    /// ```
610    #[inline(always)]
611    pub fn sparsity(&self) -> f64 {
612        1.0 - self.data.par_iter().count() as f64 / self.size() as f64
613    }
614
615    /// Shape of the matrix outputted as a tuple
616    ///
617    /// Examples:
618    ///
619    /// ```
620    /// use sukker::SparseMatrix;
621    ///
622    /// let sparse = SparseMatrix::<i32>::eye(3);
623    ///
624    /// assert_eq!(sparse.shape(), (3,3));
625    /// ```
    pub fn shape(&self) -> Shape {
        // Order is (rows, columns).
        (self.nrows, self.ncols)
    }
629
630    /// Transpose the matrix
631    ///
632    /// Examples:
633    ///
634    /// ```
635    /// use sukker::SparseMatrix;
636    ///
637    /// let mut sparse = SparseMatrix::<i32>::init(4,4);
638    ///
639    /// sparse.set(1, (2,0));
640    /// sparse.set(2, (3,0));
641    /// sparse.set(3, (0,1));
642    /// sparse.set(4, (0,2));
643    ///
644    /// sparse.transpose();
645    ///
646    /// assert_eq!(sparse.at(0,2), 1);
647    /// assert_eq!(sparse.at(0,3), 2);
648    /// assert_eq!(sparse.at(1,0), 3);
649    /// assert_eq!(sparse.at(2,0), 4);
650    ///
651    /// // Old value is now gone
652    /// assert_eq!(sparse.get(3,0), Some(0));
653    /// ```
654    pub fn transpose(&mut self) {
655        let mut new_data: SparseMatrixData<T> = HashMap::new();
656
657        for (&(i, j), &val) in self.data.iter() {
658            new_data.insert((j, i), val);
659        }
660
661        self.data = new_data;
662
663        swap(&mut self.nrows, &mut self.ncols);
664    }
665
    /// Shorthand for `transpose`: transposes the matrix in place
    pub fn t(&mut self) {
        self.transpose();
    }
670
671    /// Tranpose the matrix into a new copy
672    ///
673    /// Examples:
674    ///
675    /// ```
676    /// use sukker::SparseMatrix;
677    ///
678    /// let mut mat = SparseMatrix::<i32>::init(4,4);
679    ///
680    /// mat.set(1, (2,0));
681    /// mat.set(2, (3,0));
682    /// mat.set(3, (0,1));
683    /// mat.set(4, (0,2));
684    ///
685    /// let sparse = mat.transpose_new();
686    ///
687    /// assert_eq!(sparse.at(0,2), 1);
688    /// assert_eq!(sparse.at(0,3), 2);
689    /// assert_eq!(sparse.at(1,0), 3);
690    /// assert_eq!(sparse.at(2,0), 4);
691    ///
692    /// assert_eq!(sparse.get(3,0), Some(0));
693    /// ```
694    pub fn transpose_new(&self) -> Self {
695        let mut res = self.clone();
696        res.transpose();
697        res
698    }
699
700    /// Finds max element of a sparse matrix
701    /// Will return 0 if matrix is empty
702    pub fn max(&self) -> T {
703        let elem = self
704            .data
705            .iter()
706            .max_by(|(_, v1), (_, v2)| v1.partial_cmp(v2).unwrap());
707
708        return match elem {
709            Some((_, &v)) => v,
710            None => T::zero(),
711        };
712    }
713
714    /// Finds minimum element of a sparse matrix
715    /// Will return 0 if matrix is empty
716    pub fn min(&self) -> T {
717        let elem = self
718            .data
719            .iter()
720            .max_by(|(_, v1), (_, v2)| v1.partial_cmp(v2).unwrap());
721
722        return match elem {
723            Some((_, &v)) => v,
724            None => T::zero(),
725        };
726    }
727
728    /// Negates all items
729    pub fn neg(&self) -> Self {
730        let data = self
731            .data
732            .par_iter()
733            .map(|((i, j), &e)| ((*i, *j), e.neg()))
734            .collect::<SparseMatrixData<T>>();
735
736        Self::new(data, self.shape())
737    }
738
739    /// Finds average value of a matrix
740    ///
741    /// Returns 0 if matrix is empty
742    pub fn avg(&self) -> T {
743        self.data.par_iter().map(|(_, &val)| val).sum::<T>()
744            / self.size().to_string().parse::<T>().unwrap()
745    }
746
    /// Same as `avg`: forwards to it unchanged
    pub fn mean(&self) -> T {
        self.avg()
    }
751
752    /// Finds the median value of a matrix
753    ///
754    /// If the matrix is empty, 0 is returned
755    pub fn median(&self) -> T {
756        if self.size() == 0 {
757            return T::zero();
758        }
759
760        if self.size() == 1 {
761            return self.at(0, 0);
762        }
763
764        // If more than half the values are 0 and we only have
765        // values > 0, 0 is returned
766        if self.min() >= T::zero() && self.sparsity() >= 0.5 {
767            return T::zero();
768        }
769
770        let sorted_values: Vec<T> = self
771            .data
772            .values()
773            .copied()
774            .sorted_by(|a, b| a.partial_cmp(&b).unwrap())
775            .collect::<Vec<T>>();
776
777        match self.data.len() % 2 {
778            0 => {
779                let half: usize = self.data.len() / 2;
780
781                sorted_values
782                    .iter()
783                    .skip(half - 1)
784                    .take(2)
785                    .copied()
786                    .sum::<T>()
787                    / (T::one() + T::one())
788            }
789            1 => {
790                let half: usize = self.data.len() / 2;
791
792                sorted_values.iter().nth(half).unwrap().to_owned()
793            }
794            _ => unreachable!(),
795        }
796    }
797}
798
799/// Linear algebra on sparse matrices
800impl<'a, T> LinAlgFloats<'a, T> for SparseMatrix<'a, T>
801where
802    T: MatrixElement + Float,
803    <T as FromStr>::Err: Error + 'static,
804    Vec<T>: IntoParallelIterator,
805    Vec<&'a T>: IntoParallelRefIterator<'a>,
806{
807    fn ln(&self) -> Self {
808        let data = self.data.iter().map(|(&idx, &e)| (idx, e.ln())).collect();
809
810        Self::new(data, self.shape())
811    }
812
813    fn log(&self, base: T) -> Self {
814        let data = self
815            .data
816            .iter()
817            .map(|(&idx, &e)| (idx, e.log(base)))
818            .collect();
819
820        Self::new(data, self.shape())
821    }
822
823    fn sin(&self) -> Self {
824        let data = self.data.iter().map(|(&idx, &e)| (idx, e.sin())).collect();
825
826        Self::new(data, self.shape())
827    }
828
829    fn cos(&self) -> Self {
830        let data = self.data.iter().map(|(&idx, &e)| (idx, e.cos())).collect();
831
832        Self::new(data, self.shape())
833    }
834
835    fn tan(&self) -> Self {
836        let data = self.data.iter().map(|(&idx, &e)| (idx, e.tan())).collect();
837
838        Self::new(data, self.shape())
839    }
840
841    fn sqrt(&self) -> Self {
842        let data = self.data.iter().map(|(&idx, &e)| (idx, e.sqrt())).collect();
843
844        Self::new(data, self.shape())
845    }
846
847    fn sinh(&self) -> Self {
848        let data = self.data.iter().map(|(&idx, &e)| (idx, e.sinh())).collect();
849
850        Self::new(data, self.shape())
851    }
852
853    fn cosh(&self) -> Self {
854        let data = self.data.iter().map(|(&idx, &e)| (idx, e.cosh())).collect();
855
856        Self::new(data, self.shape())
857    }
858
859    fn tanh(&self) -> Self {
860        let data = self.data.iter().map(|(&idx, &e)| (idx, e.tanh())).collect();
861
862        Self::new(data, self.shape())
863    }
864
865    fn get_eigenvalues(&self) -> Option<Vec<T>> {
866        unimplemented!()
867    }
868
869    fn get_eigenvectors(&self) -> Option<Vec<T>> {
870        unimplemented!()
871    }
872}
873
874/// Operations on sparse matrices
875impl<'a, T> SparseMatrix<'a, T>
876where
877    T: MatrixElement,
878    <T as FromStr>::Err: Error + 'static,
879    Vec<T>: IntoParallelIterator,
880    Vec<&'a T>: IntoParallelRefIterator<'a>,
881{
882    /// Adds two sparse matrices together
883    /// and return a new one
884    ///
885    /// Examples:
886    ///
887    /// ```
888    /// use sukker::SparseMatrix;
889    ///
890    /// let sparse1 = SparseMatrix::<i32>::eye(3);
891    /// let sparse2 = SparseMatrix::<i32>::eye(3);
892    ///
893    /// let res = sparse1.add(&sparse2).unwrap();
894    ///
895    /// assert_eq!(res.shape(), (3,3));
896    /// assert_eq!(res.get(0,0).unwrap(), 2);
897    /// ```
    pub fn add(&self, other: &Self) -> Result<Self, MatrixError> {
        // Delegates to the shared helper; `Operation::ADD` selects the operation.
        Self::sparse_helper(&self, other, Operation::ADD)
    }
901
    /// Subtracts two sparse matrices
    /// and return a new one
    ///
    /// Examples:
    ///
    /// ```
    /// use sukker::SparseMatrix;
    ///
    /// let sparse1 = SparseMatrix::<i32>::eye(3);
    /// let sparse2 = SparseMatrix::<i32>::eye(3);
    ///
    /// let res = sparse1.sub(&sparse2).unwrap();
    ///
    /// assert_eq!(res.shape(), (3,3));
    /// // 1 - 1 on the diagonal
    /// assert_eq!(res.get(0,0).unwrap(), 0);
    /// ```
    pub fn sub(&self, other: &Self) -> Result<Self, MatrixError> {
        Self::sparse_helper(&self, other, Operation::SUB)
    }
    /// Multiplies two sparse matrices together
    /// and return a new one
    ///
    /// Examples:
    ///
    /// ```
    /// use sukker::SparseMatrix;
    ///
    /// let sparse1 = SparseMatrix::<i32>::eye(3);
    /// let sparse2 = SparseMatrix::<i32>::eye(3);
    ///
    /// let res = sparse1.mul(&sparse2).unwrap();
    ///
    /// assert_eq!(res.shape(), (3,3));
    /// // identity times identity keeps a 1 on the diagonal
    /// assert_eq!(res.get(0,0).unwrap(), 1);
    /// ```
    pub fn mul(&self, other: &Self) -> Result<Self, MatrixError> {
        Self::sparse_helper(&self, other, Operation::MUL)
    }
940
    /// Same as `mul`: forwards to it unchanged.
    ///
    /// This kind of matrix multiplication is called a dot product
    pub fn dot(&self, other: &Self) -> Result<Self, MatrixError> {
        self.mul(other)
    }
946
    /// Divides two sparse matrices
    /// and return a new one
    ///
    /// Examples:
    ///
    /// ```
    /// use sukker::SparseMatrix;
    ///
    /// let sparse1 = SparseMatrix::<i32>::eye(3);
    /// let sparse2 = SparseMatrix::<i32>::eye(3);
    ///
    /// let res = sparse1.div(&sparse2).unwrap();
    ///
    /// assert_eq!(res.shape(), (3,3));
    /// // 1 / 1 on the diagonal
    /// assert_eq!(res.get(0,0).unwrap(), 1);
    /// ```
    pub fn div(&self, other: &Self) -> Result<Self, MatrixError> {
        Self::sparse_helper(&self, other, Operation::DIV)
    }
966
967    // =============================================================
968    //
969    //    Matrix operations modifying the lhs
970    //
971    // =============================================================
972
973    /// Adds rhs matrix on to lhs matrix.
974    /// All elements from rhs gets inserted into lhs
975    ///
976    /// Examples:
977    ///
978    /// ```
979    /// use sukker::SparseMatrix;
980    ///
981    /// let mut sparse1 = SparseMatrix::<i32>::eye(3);
982    /// let sparse2 = SparseMatrix::<i32>::eye(3);
983    ///
984    /// sparse1.add_self(&sparse2);
985    ///
986    /// assert_eq!(sparse1.shape(), (3,3));
987    /// assert_eq!(sparse1.get(0,0).unwrap(), 2);
988    /// ```
    pub fn add_self(&mut self, other: &Self) {
        // Delegates to the in-place helper; its return value, if any, is discarded.
        Self::sparse_helper_self(self, other, Operation::ADD);
    }
992
993    /// Subs rhs matrix on to lhs matrix.
994    /// All elements from rhs gets inserted into lhs
995    ///
996    /// Examples:
997    ///
998    /// ```
999    /// use sukker::SparseMatrix;
1000    ///
1001    /// let mut sparse1 = SparseMatrix::<i32>::eye(3);
1002    /// let sparse2 = SparseMatrix::<i32>::eye(3);
1003    ///
1004    /// sparse1.sub_self(&sparse2);
1005    ///
1006    /// assert_eq!(sparse1.shape(), (3,3));
1007    /// assert_eq!(sparse1.get(0,0).unwrap(), 0);
1008    /// ```
    pub fn sub_self(&mut self, other: &Self) {
        // Delegates to the in-place helper; its return value, if any, is discarded.
        Self::sparse_helper_self(self, other, Operation::SUB);
    }
1012
1013    /// Multiplies  rhs matrix on to lhs matrix.
1014    /// All elements from rhs gets inserted into lhs
1015    ///
1016    /// Examples:
1017    ///
1018    /// ```
1019    /// use sukker::SparseMatrix;
1020    ///
1021    /// let mut sparse1 = SparseMatrix::<i32>::eye(3);
1022    /// let sparse2 = SparseMatrix::<i32>::eye(3);
1023    ///
1024    /// sparse1.mul_self(&sparse2);
1025    ///
1026    /// assert_eq!(sparse1.shape(), (3,3));
1027    /// assert_eq!(sparse1.get(0,0).unwrap(), 1);
1028    /// ```
    pub fn mul_self(&mut self, other: &Self) {
        // Delegates to the in-place helper; its return value, if any, is discarded.
        Self::sparse_helper_self(self, other, Operation::MUL);
    }
1032
1033    /// Divides rhs matrix on to lhs matrix.
1034    /// All elements from rhs gets inserted into lhs
1035    ///
1036    /// Examples:
1037    ///
1038    /// ```
1039    /// use sukker::SparseMatrix;
1040    ///
1041    /// let mut sparse1 = SparseMatrix::<i32>::eye(3);
1042    /// let sparse2 = SparseMatrix::<i32>::eye(3);
1043    ///
1044    /// sparse1.div_self(&sparse2);
1045    ///
1046    /// assert_eq!(sparse1.shape(), (3,3));
1047    /// assert_eq!(sparse1.get(0,0).unwrap(), 1);
1048    /// ```
    pub fn div_self(&mut self, other: &Self) {
        // Delegates to the in-place helper; its return value, if any, is discarded.
        Self::sparse_helper_self(self, other, Operation::DIV);
    }
1052
1053    // =============================================================
1054    //
1055    //    Matrix operations  with a value
1056    //
1057    // =============================================================
1058
1059    /// Adds value to all non zero values in the matrix
1060    /// and return a new matrix
1061    ///
1062    /// Examples:
1063    ///
1064    /// ```
1065    /// use sukker::SparseMatrix;
1066    ///
1067    /// let sparse = SparseMatrix::<f32>::eye(3);
1068    /// let val: f32 = 4.5;
1069    ///
1070    /// let res = sparse.add_val(val);
1071    ///
1072    /// assert_eq!(res.get(0,0).unwrap(), 5.5);
1073    /// ```
1074    pub fn add_val(&self, value: T) -> Self {
1075        Self::sparse_helper_val(self, value, Operation::ADD)
1076    }
1077
1078    /// Subs value to all non zero values in the matrix
1079    /// and return a new matrix
1080    ///
1081    /// Examples:
1082    ///
1083    /// ```
1084    /// use sukker::SparseMatrix;
1085    ///
1086    /// let sparse = SparseMatrix::<f32>::eye(3);
1087    /// let val: f32 = 4.5;
1088    ///
1089    /// let res = sparse.sub_val(val);
1090    ///
1091    /// assert_eq!(res.get(0,0).unwrap(), -3.5);
1092    /// ```
1093    pub fn sub_val(&self, value: T) -> Self {
1094        Self::sparse_helper_val(self, value, Operation::SUB)
1095    }
1096
1097    /// Multiplies value to all non zero values in the matrix
1098    /// and return a new matrix
1099    ///
1100    /// Examples:
1101    ///
1102    /// ```
1103    /// use sukker::SparseMatrix;
1104    ///
1105    /// let sparse = SparseMatrix::<f32>::eye(3);
1106    /// let val: f32 = 4.5;
1107    ///
1108    /// let res = sparse.mul_val(val);
1109    ///
1110    /// assert_eq!(res.get(0,0).unwrap(), 4.5);
1111    /// ```
1112    pub fn mul_val(&self, value: T) -> Self {
1113        Self::sparse_helper_val(self, value, Operation::MUL)
1114    }
1115
1116    /// Divides value to all non zero values in the matrix
1117    /// and return a new matrix.
1118    ///
1119    /// Will panic if you choose to divide by zero
1120    ///
1121    /// Examples:
1122    ///
1123    /// ```
1124    /// use sukker::SparseMatrix;
1125    ///
1126    /// let sparse = SparseMatrix::<f32>::eye(3);
1127    /// let val: f32 = 4.0;
1128    ///
1129    /// let res = sparse.div_val(val);
1130    ///
1131    /// assert_eq!(res.get(0,0).unwrap(), 0.25);
1132    /// ```
1133    pub fn div_val(&self, value: T) -> Self {
1134        Self::sparse_helper_val(self, value, Operation::DIV)
1135    }
1136
1137    // =============================================================
1138    //
    //    Matrix operations modifying lhs with a value
1140    //
1141    // =============================================================
1142
1143    /// Adds value to all non zero elements in matrix
1144    ///
1145    /// Examples:
1146    ///
1147    /// ```
1148    /// use sukker::SparseMatrix;
1149    ///
1150    /// let mut sparse = SparseMatrix::<f64>::eye(3);
1151    /// let val = 10.0;
1152    ///
1153    /// sparse.add_val_self(val);
1154    ///
1155    /// assert_eq!(sparse.get(0,0).unwrap(), 11.0);
1156    /// ```
1157    pub fn add_val_self(&mut self, value: T) {
1158        Self::sparse_helper_self_val(self, value, Operation::ADD)
1159    }
1160
1161    /// Subtracts value to all non zero elements in matrix
1162    ///
1163    /// Examples:
1164    ///
1165    /// ```
1166    /// use sukker::SparseMatrix;
1167    ///
1168    /// let mut sparse = SparseMatrix::<f64>::eye(3);
1169    /// let val = 10.0;
1170    ///
1171    /// sparse.sub_val_self(val);
1172    ///
1173    /// assert_eq!(sparse.get(0,0).unwrap(), -9.0);
1174    /// ```
1175    pub fn sub_val_self(&mut self, value: T) {
1176        Self::sparse_helper_self_val(self, value, Operation::SUB)
1177    }
1178
1179    /// Multiplies value to all non zero elements in matrix
1180    ///
1181    /// Examples:
1182    ///
1183    /// ```
1184    /// use sukker::SparseMatrix;
1185    ///
1186    /// let mut sparse = SparseMatrix::<f64>::eye(3);
1187    /// let val = 10.0;
1188    ///
1189    /// sparse.mul_val_self(val);
1190    ///
1191    /// assert_eq!(sparse.get(0,0).unwrap(), 10.0);
1192    /// ```
1193    pub fn mul_val_self(&mut self, value: T) {
1194        Self::sparse_helper_self_val(self, value, Operation::MUL)
1195    }
1196
1197    /// Divides all non zero elemnts in matrix by value in-place
1198    ///
1199    /// Will panic if you choose to divide by zero
1200    ///
1201    /// Examples:
1202    ///
1203    /// ```
1204    /// use sukker::SparseMatrix;
1205    ///
1206    /// let mut sparse = SparseMatrix::<f64>::eye(3);
1207    /// let val = 10.0;
1208    ///
1209    /// sparse.div_val_self(val);
1210    ///
1211    /// assert_eq!(sparse.get(0,0).unwrap(), 0.1);
1212    /// ```
1213    pub fn div_val_self(&mut self, value: T) {
1214        Self::sparse_helper_self_val(self, value, Operation::DIV)
1215    }
1216
1217    /// Sparse matrix multiplication
1218    ///
1219    /// For two n x n matrices, we use this algorithm:
1220    /// https://theory.stanford.edu/~virgi/cs367/papers/sparsemult.pdf
1221    ///
1222    /// Else, we use this:
1223    /// link..
1224    ///
1225    /// In this example we have these two matrices:
1226    ///
1227    /// A:
1228    ///
1229    /// 0.0 2.0 0.0
1230    /// 4.0 6.0 0.0
1231    /// 0.0 0.0 8.0
1232    ///
1233    /// B:
1234    ///
1235    /// 2.0 0.0 0.0
1236    /// 4.0 8.0 0.0
1237    /// 8.0 6.0 0.0
1238    ///
1239    /// 0.0 24.0 0
1240    /// 8.0 72.0 0
1241    /// 0.0 0.0  48.0
1242    ///
1243    /// Examples
1244    ///
1245    /// ```
1246    /// use std::collections::HashMap;
1247    /// use sukker::{SparseMatrix, SparseMatrixData};
1248    ///
1249    /// let mut indexes: SparseMatrixData<f64> = HashMap::new();
1250    ///
1251    /// indexes.insert((0, 0), 2.0);
1252    /// indexes.insert((0, 1), 2.0);
1253    /// indexes.insert((1, 0), 2.0);
1254    /// indexes.insert((1, 1), 2.0);
1255    ///
1256    /// let sparse = SparseMatrix::<f64>::new(indexes, (2, 2));
1257    ///
1258    /// let mut indexes2: SparseMatrixData<f64> = HashMap::new();
1259    ///
1260    /// indexes2.insert((0, 0), 2.0);
1261    /// indexes2.insert((0, 1), 2.0);
1262    /// indexes2.insert((1, 0), 2.0);
1263    /// indexes2.insert((1, 1), 2.0);
1264    ///
1265    /// let sparse2 = SparseMatrix::<f64>::new(indexes2, (2, 2));
1266    ///
1267    /// let res = sparse.matmul_sparse(&sparse2).unwrap();
1268    ///
1269    /// assert_eq!(res.at(0, 0), 8.0);
1270    /// assert_eq!(res.at(0, 1), 8.0);
1271    /// assert_eq!(res.at(1, 0), 8.0);
1272    /// assert_eq!(res.at(1, 1), 8.0);
1273    /// ```
1274    pub fn matmul_sparse(&self, other: &Self) -> Result<Self, MatrixError> {
1275        if self.ncols != other.nrows {
1276            return Err(MatrixError::MatrixMultiplicationDimensionMismatchError.into());
1277        }
1278
1279        if self.shape() == other.shape() {
1280            return Ok(self.matmul_sparse_nn(other));
1281        }
1282
1283        Ok(self.matmul_sparse_mnnp(other))
1284    }
1285}
1286
1287/// Predicates for sparse matrices
1288impl<'a, T> SparseMatrix<'a, T>
1289where
1290    T: MatrixElement,
1291    <T as FromStr>::Err: Error + 'static,
1292    Vec<T>: IntoParallelIterator,
1293    Vec<&'a T>: IntoParallelRefIterator<'a>,
1294{
1295    /// Returns whether or not predicate holds for all values
1296    ///
1297    /// # Examples
1298    ///
1299    /// ```
1300    /// use sukker::SparseMatrix;
1301    ///
1302    /// let sparse = SparseMatrix::<i32>::eye(3);
1303    ///
1304    /// assert_eq!(sparse.shape(), (3,3));
1305    /// assert_eq!(sparse.all(|(idx, val)| val >= 0), true);
1306    /// ```
1307    pub fn all<F>(&self, pred: F) -> bool
1308    where
1309        F: Fn((Shape, T)) -> bool + Sync + Send,
1310    {
1311        self.data.clone().into_par_iter().all(pred)
1312    }
1313
1314    /// Returns whether or not predicate holds for any
1315    ///
1316    /// # Examples
1317    ///
1318    /// ```
1319    /// use sukker::SparseMatrix;
1320    ///
1321    /// let sparse = SparseMatrix::<i32>::eye(3);
1322    ///
1323    /// assert_eq!(sparse.shape(), (3,3));
1324    /// assert_eq!(sparse.any(|(_, val)| val == 1), true);
1325    /// ```
1326    pub fn any<F>(&self, pred: F) -> bool
1327    where
1328        F: Fn((Shape, T)) -> bool + Sync + Send,
1329    {
1330        self.data.clone().into_par_iter().any(pred)
1331    }
1332
1333    /// Counts all occurances where predicate holds
1334    ///
1335    /// # Examples
1336    ///
1337    /// ```
1338    /// use sukker::SparseMatrix;
1339    ///
1340    /// let sparse = SparseMatrix::<i32>::eye(3);
1341    ///
1342    /// assert_eq!(sparse.count_where(|(_, &val)| val == 1), 3);
1343    /// ```
1344    pub fn count_where<F>(&'a self, pred: F) -> usize
1345    where
1346        F: Fn((&Shape, &T)) -> bool + Sync,
1347    {
1348        self.data.par_iter().filter(|&e| pred(e)).count()
1349    }
1350
1351    /// Sums all occurances where predicate holds
1352    ///
1353    /// # Examples
1354    ///
1355    /// ```
1356    /// use sukker::SparseMatrix;
1357    ///
1358    /// let sparse = SparseMatrix::<f32>::eye(3);
1359    ///
1360    /// assert_eq!(sparse.sum_where(|(&(i, j), &val)| val == 1.0 && i > 0), 2.0);
1361    /// ```
1362    pub fn sum_where<F>(&self, pred: F) -> T
1363    where
1364        F: Fn((&Shape, &T)) -> bool + Sync,
1365    {
1366        let mut res = T::zero();
1367        for (idx, elem) in self.data.iter() {
1368            if pred((idx, elem)) {
1369                res += elem
1370            }
1371        }
1372
1373        res
1374    }
1375
1376    /// Sets all elements where predicate holds true.
1377    /// The new value is to be set inside the predicate as well
1378    ///
1379    /// # Examples
1380    ///
1381    /// ```
1382    /// ```
1383    pub fn set_where<F>(&mut self, mut pred: F)
1384    where
1385        F: FnMut((&Shape, &mut T)) + Sync + Send,
1386    {
1387        self.data.iter_mut().for_each(|e| pred(e));
1388    }
1389
1390    /// Finds value of first occurance where predicate holds true
1391    ///
1392    /// # Examples
1393    ///
1394    /// ```
1395    /// ```
1396    pub fn find<F>(&self, pred: F) -> Option<T>
1397    where
1398        F: Fn((&Shape, &T)) -> bool + Sync,
1399    {
1400        for entry in &self.data {
1401            if pred(entry) {
1402                return Some(*entry.1);
1403            }
1404        }
1405
1406        None
1407    }
1408
1409    /// Finds all values where predicates holds if possible
1410    ///
1411    /// # Examples
1412    ///
1413    /// ```
1414    /// ```
1415    fn find_all<F>(&self, pred: F) -> Option<Vec<T>>
1416    where
1417        F: Fn((&Shape, &T)) -> bool + Sync,
1418    {
1419        let mut idxs: Vec<T> = Vec::new();
1420        for entry in &self.data {
1421            if pred(entry) {
1422                idxs.push(*entry.1);
1423            }
1424        }
1425
1426        if !idxs.is_empty() {
1427            Some(idxs)
1428        } else {
1429            None
1430        }
1431    }
1432
1433    /// Finds indices of first occurance where predicate holds true
1434    ///
1435    /// # Examples
1436    ///
1437    /// ```
1438    /// ```
1439    pub fn position<F>(&self, pred: F) -> Option<Shape>
1440    where
1441        F: Fn((&Shape, &T)) -> bool + Sync,
1442    {
1443        for entry in &self.data {
1444            if pred(entry) {
1445                return Some(*entry.0);
1446            }
1447        }
1448
1449        None
1450    }
1451
1452    /// Finds all positions  where predicates holds if possible
1453    ///
1454    /// # Examples
1455    ///
1456    /// ```
1457    /// ```
1458    fn positions<F>(&self, pred: F) -> Option<Vec<Shape>>
1459    where
1460        F: Fn((&Shape, &T)) -> bool + Sync,
1461    {
1462        let mut idxs: Vec<Shape> = Vec::new();
1463        for entry in &self.data {
1464            if pred(entry) {
1465                idxs.push(*entry.0);
1466            }
1467        }
1468
1469        if !idxs.is_empty() {
1470            Some(idxs)
1471        } else {
1472            None
1473        }
1474    }
1475}