linalg_rs/sparse/mod.rs
1//! Module for defining sparse matrices.
2//!
3//! # What are sparse matrices
4//!
5//! Generally speaking, matrices with a lot of 0s
6//!
7//! # How are they represented
8//!
//! Since storing large sparse matrices densely in memory is expensive,
//! only the non-zero entries are stored.
10//!
11//!
12//! # What data structure does sukker use?
13//!
//! For now, a hash map where the keys are indices in the matrix
//! and the value is the value at that 2D index
16#![warn(missing_docs)]
17
18mod helper;
19
20use helper::*;
21use num_traits::Float;
22use rand::Rng;
23
24use itertools::Itertools;
25use std::fmt::Display;
26use std::fs;
27use std::{collections::HashMap, error::Error, marker::PhantomData, str::FromStr};
28
29use rayon::prelude::*;
30use serde::{Deserialize, Serialize};
31
32use crate::{at, LinAlgFloats, Matrix, MatrixElement, MatrixError, Operation, Shape};
33
/// SparseMatrixData represents the datatype used to store information
/// about non-zero values in a general matrix.
///
/// The keys are 2D indices (`Shape`, used as `(row, col)` throughout this
/// module), while the value is the value stored at that position.
/// Positions that are missing from the map are treated as zero by the
/// accessors of [`SparseMatrix`].
pub type SparseMatrixData<'a, T> = HashMap<Shape, T>;
40
#[derive(Clone, PartialEq, Debug, Serialize, Deserialize)]
/// Represents a sparse matrix and its data
///
/// Entries are kept in a hash map keyed by their 2D index, while the
/// dimensions are tracked separately in `nrows` and `ncols`.
pub struct SparseMatrix<'a, T>
where
    T: MatrixElement,
    <T as FromStr>::Err: Error + 'static,
    Vec<T>: IntoParallelIterator,
    Vec<&'a T>: IntoParallelRefIterator<'a>,
{
    /// Hash map containing all stored entries, keyed by `(row, col)`
    pub data: SparseMatrixData<'a, T>,
    /// Number of rows
    pub nrows: usize,
    /// Number of columns
    pub ncols: usize,
    // Zero-sized marker so that the struct actually uses the `'a` lifetime.
    _lifetime: PhantomData<&'a T>,
}
58
// Marker implementation with no overridden methods: it lets a `SparseMatrix`
// be used wherever a `std::error::Error` is expected. The `Debug` and
// `Display` implementations that `Error` requires are the derived `Debug`
// and the `Display` impl found in this file.
impl<'a, T> Error for SparseMatrix<'a, T>
where
    T: MatrixElement,
    <T as FromStr>::Err: Error + 'static,
    Vec<T>: IntoParallelIterator,
    Vec<&'a T>: IntoParallelRefIterator<'a>,
{
}
67
// SAFETY: NOTE(review): the struct only holds a `HashMap<Shape, T>`, two
// `usize` fields and a `PhantomData<&'a T>`. This impl asserts that moving
// the matrix to another thread is sound for every `T` meeting the bounds
// below — presumably `MatrixElement` guarantees that `T` itself is `Send`
// and `Sync`; confirm against the trait definition.
unsafe impl<'a, T> Send for SparseMatrix<'a, T>
where
    T: MatrixElement,
    <T as FromStr>::Err: Error + 'static,
    Vec<T>: IntoParallelIterator,
    Vec<&'a T>: IntoParallelRefIterator<'a>,
{
}
76
// SAFETY: NOTE(review): the struct exposes no interior mutability of its own
// (a `HashMap<Shape, T>`, two `usize` fields and a `PhantomData<&'a T>`).
// This impl asserts that sharing `&SparseMatrix` between threads is sound
// for every `T` meeting the bounds below — presumably `MatrixElement`
// guarantees that `T` is `Sync`; confirm against the trait definition.
unsafe impl<'a, T> Sync for SparseMatrix<'a, T>
where
    T: MatrixElement,
    <T as FromStr>::Err: Error + 'static,
    Vec<T>: IntoParallelIterator,
    Vec<&'a T>: IntoParallelRefIterator<'a>,
{
}
85
86impl<'a, T> FromStr for SparseMatrix<'a, T>
87where
88 T: MatrixElement,
89 <T as FromStr>::Err: Error + 'static,
90{
91 type Err = anyhow::Error;
92 fn from_str(s: &str) -> Result<Self, Self::Err> {
93 // Parse the input string and construct the matrix dynamically
94 let data = s
95 .trim()
96 .lines()
97 .skip(1)
98 .map(|l| {
99 let entry: Vec<&str> = l.split_whitespace().collect();
100
101 let row = entry[0].parse::<usize>().unwrap();
102 let col = entry[1].parse::<usize>().unwrap();
103 let val = entry[2].parse::<T>().unwrap();
104
105 ((row, col), val)
106 })
107 .collect::<SparseMatrixData<T>>();
108
109 let dims = s
110 .trim()
111 .lines()
112 .nth(0)
113 .unwrap()
114 .split_whitespace()
115 .map(|e| e.parse::<usize>().unwrap())
116 .collect::<Vec<usize>>();
117
118 Ok(Self::new(data, (dims[0], dims[1])))
119 }
120}
121
122impl<'a, T> Display for SparseMatrix<'a, T>
123where
124 T: MatrixElement,
125 <T as FromStr>::Err: Error + 'static,
126 Vec<T>: IntoParallelIterator,
127 Vec<&'a T>: IntoParallelRefIterator<'a>,
128{
129 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
130 for i in 0..self.nrows {
131 for j in 0..self.ncols {
132 let elem = match self.data.get(&(i, j)) {
133 Some(&val) => val,
134 None => T::zero(),
135 };
136
137 write!(f, "{elem} ");
138 }
139 writeln!(f);
140 }
141 writeln!(f, "\ndtype = {}", std::any::type_name::<T>())
142 }
143}
144
145impl<'a, T> Default for SparseMatrix<'a, T>
146where
147 T: MatrixElement,
148 <T as FromStr>::Err: Error + 'static,
149 Vec<T>: IntoParallelIterator,
150 Vec<&'a T>: IntoParallelRefIterator<'a>,
151{
152 /// Returns a sparse 3x3 identity matrix
153 fn default() -> Self {
154 Self {
155 data: HashMap::new(),
156 nrows: 0,
157 ncols: 0,
158 _lifetime: PhantomData::default(),
159 }
160 }
161}
162
163impl<'a, T> SparseMatrix<'a, T>
164where
165 T: MatrixElement,
166 <T as FromStr>::Err: Error + 'static,
167 Vec<T>: IntoParallelIterator,
168 Vec<&'a T>: IntoParallelRefIterator<'a>,
169{
170 /// Constructs a new sparse matrix based on a hashmap
171 /// containing the indices where value is not 0
172 ///
173 /// This function does not check whether or not the
174 /// indices are valid and according to shape. Use `reshape`
175 /// to fix this issue.
176 ///
177 /// Examples
178 ///
179 /// ```
180 /// use std::collections::HashMap;
181 /// use sukker::{smd, SparseMatrix, SparseMatrixData};
182 ///
183 /// // Here we can use the smd! macro
184 /// // to easily be able to set up a new hashmap
185 /// let indexes: SparseMatrixData<f64> = smd![
186 /// ( (0, 0), 2.0),
187 /// ( (0, 3), 4.0),
188 /// ( (4, 5), 6.0),
189 /// ( (2, 7), 8.0)
190 /// ];
191 ///
192 /// let sparse = SparseMatrix::<f64>::new(indexes, (3,3));
193 ///
194 /// assert_eq!(sparse.shape(), (3,3));
195 /// assert_eq!(sparse.get(4,5), None);
196 /// assert_eq!(sparse.get(0,1), Some(0.0));
197 /// ```
198 pub fn new(data: SparseMatrixData<'a, T>, shape: Shape) -> Self {
199 Self {
200 data,
201 nrows: shape.0,
202 ncols: shape.1,
203 _lifetime: PhantomData::default(),
204 }
205 }
206
207 /// Inits an empty matrix based on shape
208 pub fn init(nrows: usize, ncols: usize) -> Self {
209 Self {
210 data: HashMap::new(),
211 nrows,
212 ncols,
213 _lifetime: PhantomData::default(),
214 }
215 }
216
217 /// Returns a sparse eye matrix
218 ///
219 /// Examples
220 ///
221 /// ```
222 /// use sukker::SparseMatrix;
223 ///
224 /// let sparse = SparseMatrix::<i32>::eye(3);
225 ///
226 /// assert_eq!(sparse.ncols, 3);
227 /// assert_eq!(sparse.nrows, 3);
228 /// ```
229 pub fn eye(size: usize) -> Self {
230 let data: SparseMatrixData<'a, T> = (0..size)
231 .into_par_iter()
232 .map(|i| ((i, i), T::one()))
233 .collect();
234
235 Self::new(data, (size, size))
236 }
237
238 /// Produces an eye with the same shape as another
239 /// sparse matrix
240 pub fn eye_like(matrix: &Self) -> Self {
241 Self::eye(matrix.nrows)
242 }
243
244 /// Same as eye
245 ///
246 /// Examples
247 ///
248 /// ```
249 /// use sukker::SparseMatrix;
250 ///
251 /// let sparse = SparseMatrix::<f64>::identity(3);
252 ///
253 /// assert_eq!(sparse.ncols, 3);
254 /// assert_eq!(sparse.nrows, 3);
255 /// ```
    pub fn identity(size: usize) -> Self {
        // Pure alias: forwards to `eye` with the same size.
        Self::eye(size)
    }
259
260 /// Creates a matrix with only one values at random
261 /// locations
262 ///
263 /// Same as `random_like` but with range from 1.0..=1.0
264 pub fn ones(sparsity: f64, shape: Shape) -> Self {
265 Self::randomize_range(T::one(), T::one(), sparsity, shape)
266 }
267
268 /// Reshapes a sparse matrix
269 ///
270 /// Examples
271 ///
272 /// ```
273 /// use sukker::SparseMatrix;
274 ///
275 /// let mut sparse = SparseMatrix::<f64>::identity(3);
276 ///
277 /// sparse.reshape(5,5);
278 ///
279 /// assert_eq!(sparse.ncols, 5);
280 /// assert_eq!(sparse.nrows, 5);
281 /// ```
282 pub fn reshape(&mut self, nrows: usize, ncols: usize) {
283 self.nrows = nrows;
284 self.ncols = ncols;
285 }
286
287 /// Creates a sparse matrix from a already existent
288 /// dense one.
289 ///
290 /// Examples:
291 ///
292 /// ```
293 /// use sukker::{SparseMatrix, Matrix};
294 ///
295 /// let dense = Matrix::<i32>::eye(4);
296 ///
297 /// let sparse = SparseMatrix::from_dense(dense);
298 ///
299 /// assert_eq!(sparse.get(0,0), Some(1));
300 /// assert_eq!(sparse.get(1,0), Some(0));
301 /// assert_eq!(sparse.shape(), (4,4));
302 /// ```
303 pub fn from_dense(matrix: Matrix<'a, T>) -> Self {
304 let mut data: SparseMatrixData<'a, T> = HashMap::new();
305
306 for i in 0..matrix.nrows {
307 for j in 0..matrix.ncols {
308 let val = matrix.get(i, j).unwrap();
309 if val != T::zero() {
310 data.insert((i, j), val);
311 }
312 }
313 }
314
315 Self::new(data, matrix.shape())
316 }
317
318 /// Constructs a sparse matrix from 3 slices.
319 /// One for the rows, one for the cols, and one for the value.
    /// A combination of values from the same index corresponds to
    /// an entry in the hashmap.
322 ///
323 /// Csc in numpy uses 3 lists of same size
324 ///
325 /// Examples:
326 ///
327 /// ```
328 /// use sukker::SparseMatrix;
329 ///
330 /// let rows = vec![0,1,2,3];
331 /// let cols = vec![1,2,3,4];
332 /// let vals= vec![0.0,1.3,0.05,4.53];
333 ///
334 /// let shape = (6,7);
335 ///
336 /// let sparse = SparseMatrix::from_slices(&rows, &cols, &vals, shape).unwrap();
337 ///
338 /// assert_eq!(sparse.shape(), (6,7));
339 /// assert_eq!(sparse.at(1,2), 1.3);
340 /// assert_eq!(sparse.at(0,1), 0.0);
341 /// ```
342 pub fn from_slices(
343 rows: &[usize],
344 cols: &[usize],
345 vals: &[T],
346 shape: Shape,
347 ) -> Result<Self, MatrixError> {
348 if rows.len() != cols.len() && cols.len() != vals.len() {
349 return Err(MatrixError::MatrixDimensionMismatchError.into());
350 }
351
352 let data: SparseMatrixData<T> = rows
353 .iter()
354 .zip(cols.iter().zip(vals.iter()))
355 .map(|(&i, (&j, &val))| ((i, j), val))
356 .collect();
357
358 Ok(Self::new(data, shape))
359 }
360
361 /// Parses from file, but will return a default sparse matrix if nothing is given
362 ///
363 /// # Examples
364 ///
365 /// ```
366 /// use sukker::SparseMatrix;
367 ///
368 /// // let m: SparseMatrix<f32> = Matrix::from_file("../../test.txt").unwrap();
369 ///
370 /// // m.print(4);
371 /// ```
372 pub fn from_file(path: &'static str) -> Result<Self, MatrixError> {
373 let data =
374 fs::read_to_string(path).map_err(|_| MatrixError::MatrixFileReadError(path).into())?;
375
376 data.parse::<Self>()
377 .map_err(|_| MatrixError::MatrixParseError.into())
378 }
379
380 /// Gets an element from the sparse matrix.
381 ///
382 /// Returns None if index is out of bounds.
383 ///
384 /// Examples
385 ///
386 /// ```
387 /// use sukker::SparseMatrix;
388 ///
389 /// let sparse = SparseMatrix::<i32>::eye(3);
390 ///
391 /// assert_eq!(sparse.get(0,0), Some(1));
392 /// assert_eq!(sparse.get(1,0), Some(0));
393 /// assert_eq!(sparse.get(4,0), None);
394 /// ```
395 pub fn get(&self, i: usize, j: usize) -> Option<T> {
396 let idx = at!(i, j, self.ncols);
397
398 if idx >= self.size() {
399 eprintln!("Error, index out of bounds. Not setting value");
400 return None;
401 }
402
403 match self.data.get(&(i, j)) {
404 None => Some(T::zero()),
405 val => val.copied(),
406 }
407 }
408
    /// Creates a random sparse matrix with the given sparsity whose stored
    /// values are drawn from the inclusive range `start..=end`
410 ///
411 /// Examples
412 ///
413 /// ```
414 /// use sukker::SparseMatrix;
415 ///
416 /// let sparse = SparseMatrix::<f32>::randomize_range(1.0,2.0, 0.75, (4,4));
417 ///
418 /// assert_eq!(sparse.shape(), (4,4));
419 /// assert_eq!(sparse.sparsity(), 0.75);
420 /// assert_eq!(sparse.all(|(_, val)| val >= 1.0 && val <= 2.0), true);
421 /// assert_eq!(sparse.size(), 16);
422 /// ```
423 pub fn randomize_range(start: T, end: T, sparsity: f64, shape: Shape) -> Self {
424 let mut rng = rand::thread_rng();
425
426 let (rows, cols) = shape;
427
428 // If we insert in a position that's already filled up,
429 // we ahve to get a new one
430 let mut matrix = Self::init(shape.0, shape.1);
431
432 while matrix.sparsity() > sparsity {
433 let value: T = rng.gen_range(start..=end);
434
435 let row: usize = rng.gen_range(0..rows);
436 let col: usize = rng.gen_range(0..cols);
437
438 match matrix.data.get(&(row, col)) {
439 Some(_) => {}
440 None => matrix.set(value, (row, col)),
441 }
442 }
443
444 matrix
445 }
446
447 /// Randomizes a sparse matrix with values between 0 and 1.
448 ///
449 /// Examples
450 ///
451 /// ```
452 /// use sukker::SparseMatrix;
453 ///
454 /// let sparse = SparseMatrix::<f32>::randomize(0.75, (4,4));
455 ///
456 /// assert_eq!(sparse.shape(), (4,4));
457 /// assert_eq!(sparse.sparsity(), 0.75);
458 /// assert_eq!(sparse.all(|(_, val)| val >= 0.0 && val <= 1.0), true);
459 /// assert_eq!(sparse.size(), 16);
460 /// ```
    pub fn randomize(sparcity: f64, shape: Shape) -> Self {
        // Forwards to `randomize_range` with the range `T::zero()..=T::one()`.
        Self::randomize_range(T::zero(), T::one(), sparcity, shape)
    }
464
465 /// Randomizes a sparse matrix to have same shape and sparcity as another one
466 /// You can however set the range
467 ///
468 /// Examples
469 ///
470 /// ```
471 /// use sukker::SparseMatrix;
472 ///
473 /// let sparse = SparseMatrix::<f32>::randomize_range(1.0,2.0, 0.75, (4,4));
474 ///
475 /// let copy = SparseMatrix::randomize_range_like(2.0, 4.0, &sparse);
476 ///
477 /// assert_eq!(copy.shape(), (4,4));
478 /// assert_eq!(copy.sparsity(), 0.75);
479 /// assert_eq!(copy.all(|(_, val)| val >= 2.0 && val <= 4.0), true);
480 /// assert_eq!(copy.size(), 16);
481 /// ```
    pub fn randomize_range_like(start: T, end: T, matrix: &Self) -> Self {
        // Sparsity and shape are taken from `matrix`; the value range comes
        // from the caller.
        Self::randomize_range(start, end, matrix.sparsity(), matrix.shape())
    }
485
486 /// Randomizes a sparse matrix to have same shape and sparcity as another one
487 /// The values here are set to be between 0 and 1, no matter the value range
488 /// of the matrix whos shape is being copied.
489 ///
490 /// Examples
491 ///
492 /// ```
493 /// use sukker::SparseMatrix;
494 ///
495 /// let sparse = SparseMatrix::<f32>::randomize_range(2.0, 4.0, 0.75, (4,4));
496 ///
497 /// let copy = SparseMatrix::random_like(&sparse);
498 ///
499 /// assert_eq!(copy.shape(), (4,4));
500 /// assert_eq!(copy.sparsity(), 0.75);
501 /// assert_eq!(copy.all(|(_, val)| val >= 0.0 && val <= 1.0), true);
502 /// assert_eq!(copy.size(), 16);
503 /// ```
    pub fn random_like(matrix: &Self) -> Self {
        // Sparsity and shape are taken from `matrix`; the values come from
        // `randomize`.
        Self::randomize(matrix.sparsity(), matrix.shape())
    }
507
    /// Same as `get`, but without any bounds check: returns the stored value,
    /// or zero for every index that has no stored value
509 ///
510 /// Examples:
511 ///
512 /// ```
513 /// use sukker::SparseMatrix;
514 ///
515 /// let sparse = SparseMatrix::<i32>::eye(3);
516 ///
517 /// assert_eq!(sparse.at(0,0), 1);
518 /// assert_eq!(sparse.at(1,0), 0);
519 /// ```
520 #[inline(always)]
521 pub fn at(&self, i: usize, j: usize) -> T {
522 match self.data.get(&(i, j)) {
523 None => T::zero(),
524 Some(val) => val.clone(),
525 }
526 }
527
528 /// Sets an element
529 ///
530 /// If you're trying to insert a zero-value, this function
531 /// does nothing
532 ///
533 /// Mutates or inserts a value based on indeces given
534 pub fn set(&mut self, value: T, idx: Shape) {
535 if value == T::zero() {
536 eprintln!("You are trying to insert a 0 value.");
537 return;
538 }
539
540 let i = at!(idx.0, idx.1, self.ncols);
541
542 if i >= self.size() {
543 eprintln!("Error, index out of bounds. Not setting value");
544 return;
545 }
546
547 self.data
548 .entry(idx)
549 .and_modify(|val| *val = value)
550 .or_insert(value);
551 }
552
553 /// A way of inserting with individual row and col
    pub fn insert(&mut self, i: usize, j: usize, value: T) {
        // Same as `set`, with the index given as two separate arguments.
        self.set(value, (i, j));
    }
557
558 /// Prints out the sparse matrix data
559 ///
560 /// Only prints out the hashmap with a set amount of decimals
561 pub fn print(&self, decimals: usize) {
562 self.data
563 .iter()
564 .for_each(|((i, j), val)| println!("{i} {j}: {:.decimals$}", val));
565 }
566
    /// Gets the size of the sparse matrix, i.e. `ncols * nrows`
    ///
    /// Examples:
    ///
    /// ```
    /// use sukker::SparseMatrix;
    ///
    /// let sparse = SparseMatrix::<i32>::eye(4);
    ///
    /// assert_eq!(sparse.size(), 16);
    /// ```
    #[inline(always)]
    pub fn size(&self) -> usize {
        self.ncols * self.nrows
    }
581
582 /// Get's amount of 0s in the matrix
583 ///
584 /// Examples:
585 ///
586 /// ```
587 /// use sukker::SparseMatrix;
588 ///
589 /// let sparse = SparseMatrix::<i32>::eye(4);
590 ///
591 /// assert_eq!(sparse.get_zero_count(), 12);
592 #[inline(always)]
593 pub fn get_zero_count(&self) -> usize {
594 self.size() - self.data.len()
595 }
596
597 /// Calcualtes sparcity for the given matrix
598 /// Sparity is defined as the percantage of the matrix
599 /// filled with 0 values
600 ///
601 /// Examples:
602 ///
603 /// ```
604 /// use sukker::SparseMatrix;
605 ///
606 /// let sparse = SparseMatrix::<i32>::eye(4);
607 ///
608 /// assert_eq!(sparse.sparsity(), 0.75);
609 /// ```
610 #[inline(always)]
611 pub fn sparsity(&self) -> f64 {
612 1.0 - self.data.par_iter().count() as f64 / self.size() as f64
613 }
614
615 /// Shape of the matrix outputted as a tuple
616 ///
617 /// Examples:
618 ///
619 /// ```
620 /// use sukker::SparseMatrix;
621 ///
622 /// let sparse = SparseMatrix::<i32>::eye(3);
623 ///
624 /// assert_eq!(sparse.shape(), (3,3));
625 /// ```
    pub fn shape(&self) -> Shape {
        // Row count first, column count second.
        (self.nrows, self.ncols)
    }
629
630 /// Transpose the matrix
631 ///
632 /// Examples:
633 ///
634 /// ```
635 /// use sukker::SparseMatrix;
636 ///
637 /// let mut sparse = SparseMatrix::<i32>::init(4,4);
638 ///
639 /// sparse.set(1, (2,0));
640 /// sparse.set(2, (3,0));
641 /// sparse.set(3, (0,1));
642 /// sparse.set(4, (0,2));
643 ///
644 /// sparse.transpose();
645 ///
646 /// assert_eq!(sparse.at(0,2), 1);
647 /// assert_eq!(sparse.at(0,3), 2);
648 /// assert_eq!(sparse.at(1,0), 3);
649 /// assert_eq!(sparse.at(2,0), 4);
650 ///
651 /// // Old value is now gone
652 /// assert_eq!(sparse.get(3,0), Some(0));
653 /// ```
654 pub fn transpose(&mut self) {
655 let mut new_data: SparseMatrixData<T> = HashMap::new();
656
657 for (&(i, j), &val) in self.data.iter() {
658 new_data.insert((j, i), val);
659 }
660
661 self.data = new_data;
662
663 swap(&mut self.nrows, &mut self.ncols);
664 }
665
666 /// Shorthand for `transpose`
    pub fn t(&mut self) {
        // Pure alias: transposes in place.
        self.transpose();
    }
670
671 /// Tranpose the matrix into a new copy
672 ///
673 /// Examples:
674 ///
675 /// ```
676 /// use sukker::SparseMatrix;
677 ///
678 /// let mut mat = SparseMatrix::<i32>::init(4,4);
679 ///
680 /// mat.set(1, (2,0));
681 /// mat.set(2, (3,0));
682 /// mat.set(3, (0,1));
683 /// mat.set(4, (0,2));
684 ///
685 /// let sparse = mat.transpose_new();
686 ///
687 /// assert_eq!(sparse.at(0,2), 1);
688 /// assert_eq!(sparse.at(0,3), 2);
689 /// assert_eq!(sparse.at(1,0), 3);
690 /// assert_eq!(sparse.at(2,0), 4);
691 ///
692 /// assert_eq!(sparse.get(3,0), Some(0));
693 /// ```
694 pub fn transpose_new(&self) -> Self {
695 let mut res = self.clone();
696 res.transpose();
697 res
698 }
699
700 /// Finds max element of a sparse matrix
701 /// Will return 0 if matrix is empty
702 pub fn max(&self) -> T {
703 let elem = self
704 .data
705 .iter()
706 .max_by(|(_, v1), (_, v2)| v1.partial_cmp(v2).unwrap());
707
708 return match elem {
709 Some((_, &v)) => v,
710 None => T::zero(),
711 };
712 }
713
714 /// Finds minimum element of a sparse matrix
715 /// Will return 0 if matrix is empty
716 pub fn min(&self) -> T {
717 let elem = self
718 .data
719 .iter()
720 .max_by(|(_, v1), (_, v2)| v1.partial_cmp(v2).unwrap());
721
722 return match elem {
723 Some((_, &v)) => v,
724 None => T::zero(),
725 };
726 }
727
728 /// Negates all items
729 pub fn neg(&self) -> Self {
730 let data = self
731 .data
732 .par_iter()
733 .map(|((i, j), &e)| ((*i, *j), e.neg()))
734 .collect::<SparseMatrixData<T>>();
735
736 Self::new(data, self.shape())
737 }
738
739 /// Finds average value of a matrix
740 ///
741 /// Returns 0 if matrix is empty
742 pub fn avg(&self) -> T {
743 self.data.par_iter().map(|(_, &val)| val).sum::<T>()
744 / self.size().to_string().parse::<T>().unwrap()
745 }
746
747 /// Same as `avg`
    pub fn mean(&self) -> T {
        // Pure alias: forwards to `avg`.
        self.avg()
    }
751
752 /// Finds the median value of a matrix
753 ///
754 /// If the matrix is empty, 0 is returned
755 pub fn median(&self) -> T {
756 if self.size() == 0 {
757 return T::zero();
758 }
759
760 if self.size() == 1 {
761 return self.at(0, 0);
762 }
763
764 // If more than half the values are 0 and we only have
765 // values > 0, 0 is returned
766 if self.min() >= T::zero() && self.sparsity() >= 0.5 {
767 return T::zero();
768 }
769
770 let sorted_values: Vec<T> = self
771 .data
772 .values()
773 .copied()
774 .sorted_by(|a, b| a.partial_cmp(&b).unwrap())
775 .collect::<Vec<T>>();
776
777 match self.data.len() % 2 {
778 0 => {
779 let half: usize = self.data.len() / 2;
780
781 sorted_values
782 .iter()
783 .skip(half - 1)
784 .take(2)
785 .copied()
786 .sum::<T>()
787 / (T::one() + T::one())
788 }
789 1 => {
790 let half: usize = self.data.len() / 2;
791
792 sorted_values.iter().nth(half).unwrap().to_owned()
793 }
794 _ => unreachable!(),
795 }
796 }
797}
798
799/// Linear algebra on sparse matrices
800impl<'a, T> LinAlgFloats<'a, T> for SparseMatrix<'a, T>
801where
802 T: MatrixElement + Float,
803 <T as FromStr>::Err: Error + 'static,
804 Vec<T>: IntoParallelIterator,
805 Vec<&'a T>: IntoParallelRefIterator<'a>,
806{
807 fn ln(&self) -> Self {
808 let data = self.data.iter().map(|(&idx, &e)| (idx, e.ln())).collect();
809
810 Self::new(data, self.shape())
811 }
812
813 fn log(&self, base: T) -> Self {
814 let data = self
815 .data
816 .iter()
817 .map(|(&idx, &e)| (idx, e.log(base)))
818 .collect();
819
820 Self::new(data, self.shape())
821 }
822
823 fn sin(&self) -> Self {
824 let data = self.data.iter().map(|(&idx, &e)| (idx, e.sin())).collect();
825
826 Self::new(data, self.shape())
827 }
828
829 fn cos(&self) -> Self {
830 let data = self.data.iter().map(|(&idx, &e)| (idx, e.cos())).collect();
831
832 Self::new(data, self.shape())
833 }
834
835 fn tan(&self) -> Self {
836 let data = self.data.iter().map(|(&idx, &e)| (idx, e.tan())).collect();
837
838 Self::new(data, self.shape())
839 }
840
841 fn sqrt(&self) -> Self {
842 let data = self.data.iter().map(|(&idx, &e)| (idx, e.sqrt())).collect();
843
844 Self::new(data, self.shape())
845 }
846
847 fn sinh(&self) -> Self {
848 let data = self.data.iter().map(|(&idx, &e)| (idx, e.sinh())).collect();
849
850 Self::new(data, self.shape())
851 }
852
853 fn cosh(&self) -> Self {
854 let data = self.data.iter().map(|(&idx, &e)| (idx, e.cosh())).collect();
855
856 Self::new(data, self.shape())
857 }
858
859 fn tanh(&self) -> Self {
860 let data = self.data.iter().map(|(&idx, &e)| (idx, e.tanh())).collect();
861
862 Self::new(data, self.shape())
863 }
864
865 fn get_eigenvalues(&self) -> Option<Vec<T>> {
866 unimplemented!()
867 }
868
869 fn get_eigenvectors(&self) -> Option<Vec<T>> {
870 unimplemented!()
871 }
872}
873
874/// Operations on sparse matrices
875impl<'a, T> SparseMatrix<'a, T>
876where
877 T: MatrixElement,
878 <T as FromStr>::Err: Error + 'static,
879 Vec<T>: IntoParallelIterator,
880 Vec<&'a T>: IntoParallelRefIterator<'a>,
881{
882 /// Adds two sparse matrices together
883 /// and return a new one
884 ///
885 /// Examples:
886 ///
887 /// ```
888 /// use sukker::SparseMatrix;
889 ///
890 /// let sparse1 = SparseMatrix::<i32>::eye(3);
891 /// let sparse2 = SparseMatrix::<i32>::eye(3);
892 ///
893 /// let res = sparse1.add(&sparse2).unwrap();
894 ///
895 /// assert_eq!(res.shape(), (3,3));
896 /// assert_eq!(res.get(0,0).unwrap(), 2);
897 /// ```
    pub fn add(&self, other: &Self) -> Result<Self, MatrixError> {
        // Delegates to `sparse_helper` (defined outside this section) with
        // `Operation::ADD`.
        Self::sparse_helper(&self, other, Operation::ADD)
    }
901
902 /// Subtracts two sparse matrices
903 /// and return a new one
904 ///
905 /// Examples:
906 ///
907 /// ```
908 /// use sukker::SparseMatrix;
909 ///
910 /// let sparse1 = SparseMatrix::<i32>::eye(3);
911 /// let sparse2 = SparseMatrix::<i32>::eye(3);
912 ///
913 /// let res = sparse1.sub(&sparse2).unwrap();
914 ///
915 /// assert_eq!(res.shape(), (3,3));
    /// assert_eq!(res.get(0,0).unwrap(), 0);
917 /// ```
    pub fn sub(&self, other: &Self) -> Result<Self, MatrixError> {
        // Delegates to `sparse_helper` (defined outside this section) with
        // `Operation::SUB`.
        Self::sparse_helper(&self, other, Operation::SUB)
    }
921 /// Multiplies two sparse matrices together
922 /// and return a new one
923 ///
924 /// Examples:
925 ///
926 /// ```
927 /// use sukker::SparseMatrix;
928 ///
929 /// let sparse1 = SparseMatrix::<i32>::eye(3);
930 /// let sparse2 = SparseMatrix::<i32>::eye(3);
931 ///
932 /// let res = sparse1.mul(&sparse2).unwrap();
933 ///
934 /// assert_eq!(res.shape(), (3,3));
    /// assert_eq!(res.get(0,0).unwrap(), 1);
936 /// ```
    pub fn mul(&self, other: &Self) -> Result<Self, MatrixError> {
        // Delegates to `sparse_helper` (defined outside this section) with
        // `Operation::MUL`.
        Self::sparse_helper(&self, other, Operation::MUL)
    }
940
941 /// Same as `mul`. This kind of matrix multiplication is called
942 /// a dot product
    pub fn dot(&self, other: &Self) -> Result<Self, MatrixError> {
        // Pure alias: forwards to `mul`.
        self.mul(other)
    }
946
947 /// Divides two sparse matrices
948 /// and return a new one
949 ///
950 /// Examples:
951 ///
952 /// ```
953 /// use sukker::SparseMatrix;
954 ///
955 /// let sparse1 = SparseMatrix::<i32>::eye(3);
956 /// let sparse2 = SparseMatrix::<i32>::eye(3);
957 ///
958 /// let res = sparse1.div(&sparse2).unwrap();
959 ///
960 /// assert_eq!(res.shape(), (3,3));
    /// assert_eq!(res.get(0,0).unwrap(), 1);
962 /// ```
    pub fn div(&self, other: &Self) -> Result<Self, MatrixError> {
        // Delegates to `sparse_helper` (defined outside this section) with
        // `Operation::DIV`.
        Self::sparse_helper(&self, other, Operation::DIV)
    }
966
967 // =============================================================
968 //
969 // Matrix operations modifying the lhs
970 //
971 // =============================================================
972
973 /// Adds rhs matrix on to lhs matrix.
974 /// All elements from rhs gets inserted into lhs
975 ///
976 /// Examples:
977 ///
978 /// ```
979 /// use sukker::SparseMatrix;
980 ///
981 /// let mut sparse1 = SparseMatrix::<i32>::eye(3);
982 /// let sparse2 = SparseMatrix::<i32>::eye(3);
983 ///
984 /// sparse1.add_self(&sparse2);
985 ///
986 /// assert_eq!(sparse1.shape(), (3,3));
987 /// assert_eq!(sparse1.get(0,0).unwrap(), 2);
988 /// ```
    pub fn add_self(&mut self, other: &Self) {
        // Delegates to `sparse_helper_self` (defined outside this section)
        // with `Operation::ADD`; whatever it returns is discarded.
        Self::sparse_helper_self(self, other, Operation::ADD);
    }
992
993 /// Subs rhs matrix on to lhs matrix.
994 /// All elements from rhs gets inserted into lhs
995 ///
996 /// Examples:
997 ///
998 /// ```
999 /// use sukker::SparseMatrix;
1000 ///
1001 /// let mut sparse1 = SparseMatrix::<i32>::eye(3);
1002 /// let sparse2 = SparseMatrix::<i32>::eye(3);
1003 ///
1004 /// sparse1.sub_self(&sparse2);
1005 ///
1006 /// assert_eq!(sparse1.shape(), (3,3));
1007 /// assert_eq!(sparse1.get(0,0).unwrap(), 0);
1008 /// ```
    pub fn sub_self(&mut self, other: &Self) {
        // Delegates to `sparse_helper_self` (defined outside this section)
        // with `Operation::SUB`; whatever it returns is discarded.
        Self::sparse_helper_self(self, other, Operation::SUB);
    }
1012
1013 /// Multiplies rhs matrix on to lhs matrix.
1014 /// All elements from rhs gets inserted into lhs
1015 ///
1016 /// Examples:
1017 ///
1018 /// ```
1019 /// use sukker::SparseMatrix;
1020 ///
1021 /// let mut sparse1 = SparseMatrix::<i32>::eye(3);
1022 /// let sparse2 = SparseMatrix::<i32>::eye(3);
1023 ///
1024 /// sparse1.mul_self(&sparse2);
1025 ///
1026 /// assert_eq!(sparse1.shape(), (3,3));
1027 /// assert_eq!(sparse1.get(0,0).unwrap(), 1);
1028 /// ```
    pub fn mul_self(&mut self, other: &Self) {
        // Delegates to `sparse_helper_self` (defined outside this section)
        // with `Operation::MUL`; whatever it returns is discarded.
        Self::sparse_helper_self(self, other, Operation::MUL);
    }
1032
1033 /// Divides rhs matrix on to lhs matrix.
1034 /// All elements from rhs gets inserted into lhs
1035 ///
1036 /// Examples:
1037 ///
1038 /// ```
1039 /// use sukker::SparseMatrix;
1040 ///
1041 /// let mut sparse1 = SparseMatrix::<i32>::eye(3);
1042 /// let sparse2 = SparseMatrix::<i32>::eye(3);
1043 ///
1044 /// sparse1.div_self(&sparse2);
1045 ///
1046 /// assert_eq!(sparse1.shape(), (3,3));
1047 /// assert_eq!(sparse1.get(0,0).unwrap(), 1);
1048 /// ```
    pub fn div_self(&mut self, other: &Self) {
        // Delegates to `sparse_helper_self` (defined outside this section)
        // with `Operation::DIV`; whatever it returns is discarded.
        Self::sparse_helper_self(self, other, Operation::DIV);
    }
1052
1053 // =============================================================
1054 //
1055 // Matrix operations with a value
1056 //
1057 // =============================================================
1058
1059 /// Adds value to all non zero values in the matrix
1060 /// and return a new matrix
1061 ///
1062 /// Examples:
1063 ///
1064 /// ```
1065 /// use sukker::SparseMatrix;
1066 ///
1067 /// let sparse = SparseMatrix::<f32>::eye(3);
1068 /// let val: f32 = 4.5;
1069 ///
1070 /// let res = sparse.add_val(val);
1071 ///
1072 /// assert_eq!(res.get(0,0).unwrap(), 5.5);
1073 /// ```
    pub fn add_val(&self, value: T) -> Self {
        // Delegates to `sparse_helper_val` (defined outside this section)
        // with `Operation::ADD`.
        Self::sparse_helper_val(self, value, Operation::ADD)
    }
1077
1078 /// Subs value to all non zero values in the matrix
1079 /// and return a new matrix
1080 ///
1081 /// Examples:
1082 ///
1083 /// ```
1084 /// use sukker::SparseMatrix;
1085 ///
1086 /// let sparse = SparseMatrix::<f32>::eye(3);
1087 /// let val: f32 = 4.5;
1088 ///
1089 /// let res = sparse.sub_val(val);
1090 ///
1091 /// assert_eq!(res.get(0,0).unwrap(), -3.5);
1092 /// ```
    pub fn sub_val(&self, value: T) -> Self {
        // Delegates to `sparse_helper_val` (defined outside this section)
        // with `Operation::SUB`.
        Self::sparse_helper_val(self, value, Operation::SUB)
    }
1096
1097 /// Multiplies value to all non zero values in the matrix
1098 /// and return a new matrix
1099 ///
1100 /// Examples:
1101 ///
1102 /// ```
1103 /// use sukker::SparseMatrix;
1104 ///
1105 /// let sparse = SparseMatrix::<f32>::eye(3);
1106 /// let val: f32 = 4.5;
1107 ///
1108 /// let res = sparse.mul_val(val);
1109 ///
1110 /// assert_eq!(res.get(0,0).unwrap(), 4.5);
1111 /// ```
    pub fn mul_val(&self, value: T) -> Self {
        // Delegates to `sparse_helper_val` (defined outside this section)
        // with `Operation::MUL`.
        Self::sparse_helper_val(self, value, Operation::MUL)
    }
1115
1116 /// Divides value to all non zero values in the matrix
1117 /// and return a new matrix.
1118 ///
1119 /// Will panic if you choose to divide by zero
1120 ///
1121 /// Examples:
1122 ///
1123 /// ```
1124 /// use sukker::SparseMatrix;
1125 ///
1126 /// let sparse = SparseMatrix::<f32>::eye(3);
1127 /// let val: f32 = 4.0;
1128 ///
1129 /// let res = sparse.div_val(val);
1130 ///
1131 /// assert_eq!(res.get(0,0).unwrap(), 0.25);
1132 /// ```
    pub fn div_val(&self, value: T) -> Self {
        // Delegates to `sparse_helper_val` (defined outside this section)
        // with `Operation::DIV`.
        Self::sparse_helper_val(self, value, Operation::DIV)
    }
1136
1137 // =============================================================
1138 //
1139 // Matrix operations modyfing lhs with a value
1140 //
1141 // =============================================================
1142
1143 /// Adds value to all non zero elements in matrix
1144 ///
1145 /// Examples:
1146 ///
1147 /// ```
1148 /// use sukker::SparseMatrix;
1149 ///
1150 /// let mut sparse = SparseMatrix::<f64>::eye(3);
1151 /// let val = 10.0;
1152 ///
1153 /// sparse.add_val_self(val);
1154 ///
1155 /// assert_eq!(sparse.get(0,0).unwrap(), 11.0);
1156 /// ```
    pub fn add_val_self(&mut self, value: T) {
        // Delegates to `sparse_helper_self_val` (defined outside this
        // section) with `Operation::ADD`.
        Self::sparse_helper_self_val(self, value, Operation::ADD)
    }
1160
1161 /// Subtracts value to all non zero elements in matrix
1162 ///
1163 /// Examples:
1164 ///
1165 /// ```
1166 /// use sukker::SparseMatrix;
1167 ///
1168 /// let mut sparse = SparseMatrix::<f64>::eye(3);
1169 /// let val = 10.0;
1170 ///
1171 /// sparse.sub_val_self(val);
1172 ///
1173 /// assert_eq!(sparse.get(0,0).unwrap(), -9.0);
1174 /// ```
    pub fn sub_val_self(&mut self, value: T) {
        // Delegates to `sparse_helper_self_val` (defined outside this
        // section) with `Operation::SUB`.
        Self::sparse_helper_self_val(self, value, Operation::SUB)
    }
1178
1179 /// Multiplies value to all non zero elements in matrix
1180 ///
1181 /// Examples:
1182 ///
1183 /// ```
1184 /// use sukker::SparseMatrix;
1185 ///
1186 /// let mut sparse = SparseMatrix::<f64>::eye(3);
1187 /// let val = 10.0;
1188 ///
1189 /// sparse.mul_val_self(val);
1190 ///
1191 /// assert_eq!(sparse.get(0,0).unwrap(), 10.0);
1192 /// ```
    pub fn mul_val_self(&mut self, value: T) {
        // Delegates to `sparse_helper_self_val` (defined outside this
        // section) with `Operation::MUL`.
        Self::sparse_helper_self_val(self, value, Operation::MUL)
    }
1196
1197 /// Divides all non zero elemnts in matrix by value in-place
1198 ///
1199 /// Will panic if you choose to divide by zero
1200 ///
1201 /// Examples:
1202 ///
1203 /// ```
1204 /// use sukker::SparseMatrix;
1205 ///
1206 /// let mut sparse = SparseMatrix::<f64>::eye(3);
1207 /// let val = 10.0;
1208 ///
1209 /// sparse.div_val_self(val);
1210 ///
1211 /// assert_eq!(sparse.get(0,0).unwrap(), 0.1);
1212 /// ```
    pub fn div_val_self(&mut self, value: T) {
        // Delegates to `sparse_helper_self_val` (defined outside this
        // section) with `Operation::DIV`.
        Self::sparse_helper_self_val(self, value, Operation::DIV)
    }
1216
1217 /// Sparse matrix multiplication
1218 ///
1219 /// For two n x n matrices, we use this algorithm:
1220 /// https://theory.stanford.edu/~virgi/cs367/papers/sparsemult.pdf
1221 ///
1222 /// Else, we use this:
1223 /// link..
1224 ///
1225 /// In this example we have these two matrices:
1226 ///
1227 /// A:
1228 ///
1229 /// 0.0 2.0 0.0
1230 /// 4.0 6.0 0.0
1231 /// 0.0 0.0 8.0
1232 ///
1233 /// B:
1234 ///
1235 /// 2.0 0.0 0.0
1236 /// 4.0 8.0 0.0
1237 /// 8.0 6.0 0.0
1238 ///
1239 /// 0.0 24.0 0
1240 /// 8.0 72.0 0
1241 /// 0.0 0.0 48.0
1242 ///
1243 /// Examples
1244 ///
1245 /// ```
1246 /// use std::collections::HashMap;
1247 /// use sukker::{SparseMatrix, SparseMatrixData};
1248 ///
1249 /// let mut indexes: SparseMatrixData<f64> = HashMap::new();
1250 ///
1251 /// indexes.insert((0, 0), 2.0);
1252 /// indexes.insert((0, 1), 2.0);
1253 /// indexes.insert((1, 0), 2.0);
1254 /// indexes.insert((1, 1), 2.0);
1255 ///
1256 /// let sparse = SparseMatrix::<f64>::new(indexes, (2, 2));
1257 ///
1258 /// let mut indexes2: SparseMatrixData<f64> = HashMap::new();
1259 ///
1260 /// indexes2.insert((0, 0), 2.0);
1261 /// indexes2.insert((0, 1), 2.0);
1262 /// indexes2.insert((1, 0), 2.0);
1263 /// indexes2.insert((1, 1), 2.0);
1264 ///
1265 /// let sparse2 = SparseMatrix::<f64>::new(indexes2, (2, 2));
1266 ///
1267 /// let res = sparse.matmul_sparse(&sparse2).unwrap();
1268 ///
1269 /// assert_eq!(res.at(0, 0), 8.0);
1270 /// assert_eq!(res.at(0, 1), 8.0);
1271 /// assert_eq!(res.at(1, 0), 8.0);
1272 /// assert_eq!(res.at(1, 1), 8.0);
1273 /// ```
1274 pub fn matmul_sparse(&self, other: &Self) -> Result<Self, MatrixError> {
1275 if self.ncols != other.nrows {
1276 return Err(MatrixError::MatrixMultiplicationDimensionMismatchError.into());
1277 }
1278
1279 if self.shape() == other.shape() {
1280 return Ok(self.matmul_sparse_nn(other));
1281 }
1282
1283 Ok(self.matmul_sparse_mnnp(other))
1284 }
1285}
1286
1287/// Predicates for sparse matrices
1288impl<'a, T> SparseMatrix<'a, T>
1289where
1290 T: MatrixElement,
1291 <T as FromStr>::Err: Error + 'static,
1292 Vec<T>: IntoParallelIterator,
1293 Vec<&'a T>: IntoParallelRefIterator<'a>,
1294{
1295 /// Returns whether or not predicate holds for all values
1296 ///
1297 /// # Examples
1298 ///
1299 /// ```
1300 /// use sukker::SparseMatrix;
1301 ///
1302 /// let sparse = SparseMatrix::<i32>::eye(3);
1303 ///
1304 /// assert_eq!(sparse.shape(), (3,3));
1305 /// assert_eq!(sparse.all(|(idx, val)| val >= 0), true);
1306 /// ```
1307 pub fn all<F>(&self, pred: F) -> bool
1308 where
1309 F: Fn((Shape, T)) -> bool + Sync + Send,
1310 {
1311 self.data.clone().into_par_iter().all(pred)
1312 }
1313
1314 /// Returns whether or not predicate holds for any
1315 ///
1316 /// # Examples
1317 ///
1318 /// ```
1319 /// use sukker::SparseMatrix;
1320 ///
1321 /// let sparse = SparseMatrix::<i32>::eye(3);
1322 ///
1323 /// assert_eq!(sparse.shape(), (3,3));
1324 /// assert_eq!(sparse.any(|(_, val)| val == 1), true);
1325 /// ```
1326 pub fn any<F>(&self, pred: F) -> bool
1327 where
1328 F: Fn((Shape, T)) -> bool + Sync + Send,
1329 {
1330 self.data.clone().into_par_iter().any(pred)
1331 }
1332
1333 /// Counts all occurances where predicate holds
1334 ///
1335 /// # Examples
1336 ///
1337 /// ```
1338 /// use sukker::SparseMatrix;
1339 ///
1340 /// let sparse = SparseMatrix::<i32>::eye(3);
1341 ///
1342 /// assert_eq!(sparse.count_where(|(_, &val)| val == 1), 3);
1343 /// ```
1344 pub fn count_where<F>(&'a self, pred: F) -> usize
1345 where
1346 F: Fn((&Shape, &T)) -> bool + Sync,
1347 {
1348 self.data.par_iter().filter(|&e| pred(e)).count()
1349 }
1350
1351 /// Sums all occurances where predicate holds
1352 ///
1353 /// # Examples
1354 ///
1355 /// ```
1356 /// use sukker::SparseMatrix;
1357 ///
1358 /// let sparse = SparseMatrix::<f32>::eye(3);
1359 ///
1360 /// assert_eq!(sparse.sum_where(|(&(i, j), &val)| val == 1.0 && i > 0), 2.0);
1361 /// ```
1362 pub fn sum_where<F>(&self, pred: F) -> T
1363 where
1364 F: Fn((&Shape, &T)) -> bool + Sync,
1365 {
1366 let mut res = T::zero();
1367 for (idx, elem) in self.data.iter() {
1368 if pred((idx, elem)) {
1369 res += elem
1370 }
1371 }
1372
1373 res
1374 }
1375
1376 /// Sets all elements where predicate holds true.
1377 /// The new value is to be set inside the predicate as well
1378 ///
1379 /// # Examples
1380 ///
1381 /// ```
1382 /// ```
1383 pub fn set_where<F>(&mut self, mut pred: F)
1384 where
1385 F: FnMut((&Shape, &mut T)) + Sync + Send,
1386 {
1387 self.data.iter_mut().for_each(|e| pred(e));
1388 }
1389
1390 /// Finds value of first occurance where predicate holds true
1391 ///
1392 /// # Examples
1393 ///
1394 /// ```
1395 /// ```
1396 pub fn find<F>(&self, pred: F) -> Option<T>
1397 where
1398 F: Fn((&Shape, &T)) -> bool + Sync,
1399 {
1400 for entry in &self.data {
1401 if pred(entry) {
1402 return Some(*entry.1);
1403 }
1404 }
1405
1406 None
1407 }
1408
1409 /// Finds all values where predicates holds if possible
1410 ///
1411 /// # Examples
1412 ///
1413 /// ```
1414 /// ```
1415 fn find_all<F>(&self, pred: F) -> Option<Vec<T>>
1416 where
1417 F: Fn((&Shape, &T)) -> bool + Sync,
1418 {
1419 let mut idxs: Vec<T> = Vec::new();
1420 for entry in &self.data {
1421 if pred(entry) {
1422 idxs.push(*entry.1);
1423 }
1424 }
1425
1426 if !idxs.is_empty() {
1427 Some(idxs)
1428 } else {
1429 None
1430 }
1431 }
1432
1433 /// Finds indices of first occurance where predicate holds true
1434 ///
1435 /// # Examples
1436 ///
1437 /// ```
1438 /// ```
1439 pub fn position<F>(&self, pred: F) -> Option<Shape>
1440 where
1441 F: Fn((&Shape, &T)) -> bool + Sync,
1442 {
1443 for entry in &self.data {
1444 if pred(entry) {
1445 return Some(*entry.0);
1446 }
1447 }
1448
1449 None
1450 }
1451
1452 /// Finds all positions where predicates holds if possible
1453 ///
1454 /// # Examples
1455 ///
1456 /// ```
1457 /// ```
1458 fn positions<F>(&self, pred: F) -> Option<Vec<Shape>>
1459 where
1460 F: Fn((&Shape, &T)) -> bool + Sync,
1461 {
1462 let mut idxs: Vec<Shape> = Vec::new();
1463 for entry in &self.data {
1464 if pred(entry) {
1465 idxs.push(*entry.0);
1466 }
1467 }
1468
1469 if !idxs.is_empty() {
1470 Some(idxs)
1471 } else {
1472 None
1473 }
1474 }
1475}