use std::ops::{Add, Neg, Sub};
use crate::{scalar::Scalar, CscMatrix, CsrMatrix, DokMatrix};
/// Sparse matrix kept in coordinate (COO) form: a flat list of
/// `(row, column, value)` triplets plus the matrix dimensions.
///
/// The list is stored exactly as given; the struct itself imposes no ordering
/// and no uniqueness on the coordinates.
#[derive(Clone, Debug)]
pub struct CooMatrix<T: Scalar> {
// Number of rows of the matrix.
nrows: usize,
// Number of columns of the matrix.
ncols: usize,
// Stored `(row, column, value)` triplets.
entries: Vec<(usize, usize, T)>,
}
/// Borrowing iterator over the stored triplets of a matrix.
///
/// Thin wrapper around a slice iterator over `(usize, usize, T)` tuples.
#[derive(Clone, Debug)]
pub struct Iter<'iter, T> {
// Underlying iterator over the borrowed triplet slice.
iter: std::slice::Iter<'iter, (usize, usize, T)>,
}
/// Mutably borrowing iterator over the stored triplets of a matrix.
///
/// Thin wrapper around a mutable slice iterator over `(usize, usize, T)` tuples.
#[derive(Debug)]
pub struct IterMut<'iter, T> {
// Underlying iterator over the mutably borrowed triplet slice.
iter: std::slice::IterMut<'iter, (usize, usize, T)>,
}
/// Owning iterator over `(usize, usize, T)` triplets.
///
/// Thin wrapper around the `Vec` by-value iterator.
#[derive(Debug)]
pub struct IntoIter<T> {
// Underlying by-value iterator over the triplet vector.
iter: std::vec::IntoIter<(usize, usize, T)>,
}
impl<T: Scalar> CooMatrix<T> {
    /// Creates an empty `nrows` x `ncols` matrix with no stored entries.
    ///
    /// # Panics
    ///
    /// Panics if `nrows` or `ncols` is zero.
    pub fn new(nrows: usize, ncols: usize) -> Self {
        assert!(nrows > 0, "matrix must have at least one row");
        assert!(ncols > 0, "matrix must have at least one column");
        Self {
            nrows,
            ncols,
            entries: Vec::new(),
        }
    }

    /// Creates a `size` x `size` matrix with `T::one()` stored at every
    /// diagonal position `(i, i)`.
    ///
    /// # Panics
    ///
    /// Panics if `size` is zero.
    pub fn eye(size: usize) -> Self {
        assert!(size > 0, "matrix size must be positive");
        Self {
            nrows: size,
            ncols: size,
            entries: (0..size).map(|i| (i, i, T::one())).collect(),
        }
    }

    /// Creates an empty `nrows` x `ncols` matrix whose entry storage is
    /// preallocated for at least `capacity` triplets.
    ///
    /// # Panics
    ///
    /// Panics if `nrows` or `ncols` is zero.
    pub fn with_capacity(nrows: usize, ncols: usize, capacity: usize) -> Self {
        assert!(nrows > 0, "matrix must have at least one row");
        assert!(ncols > 0, "matrix must have at least one column");
        Self {
            nrows,
            ncols,
            entries: Vec::with_capacity(capacity),
        }
    }

    /// Creates an `nrows` x `ncols` matrix from `(row, column, value)`
    /// triplets, stored in the order the iterator yields them.
    ///
    /// # Panics
    ///
    /// Panics if `nrows` or `ncols` is zero, or if any triplet has
    /// `row >= nrows` or `column >= ncols`.
    pub fn with_entries<I>(nrows: usize, ncols: usize, entries: I) -> Self
    where
        I: IntoIterator<Item = (usize, usize, T)>,
    {
        assert!(nrows > 0, "matrix must have at least one row");
        assert!(ncols > 0, "matrix must have at least one column");
        let entries: Vec<_> = entries.into_iter().collect();
        for &(row, col, _) in &entries {
            assert!(
                row < nrows,
                "row index {} out of bounds for {} rows",
                row,
                nrows
            );
            assert!(
                col < ncols,
                "column index {} out of bounds for {} columns",
                col,
                ncols
            );
        }
        Self {
            nrows,
            ncols,
            entries,
        }
    }

    /// Creates an `nrows` x `ncols` matrix from three parallel sequences:
    /// the k-th entry is `(rowind[k], colind[k], values[k])`.
    ///
    /// # Panics
    ///
    /// Panics if `nrows` or `ncols` is zero, if the three sequences do not
    /// have the same length, or if any row/column index is out of bounds.
    pub fn with_triplets<R, C, V>(
        nrows: usize,
        ncols: usize,
        rowind: R,
        colind: C,
        values: V,
    ) -> Self
    where
        R: IntoIterator<Item = usize>,
        C: IntoIterator<Item = usize>,
        V: IntoIterator<Item = T>,
    {
        assert!(nrows > 0, "matrix must have at least one row");
        assert!(ncols > 0, "matrix must have at least one column");
        // The inputs are arbitrary iterators, so they are materialised first
        // in order to compare their lengths before anything is zipped.
        let rowind: Vec<_> = rowind.into_iter().collect();
        let colind: Vec<_> = colind.into_iter().collect();
        let values: Vec<_> = values.into_iter().collect();
        assert_eq!(
            rowind.len(),
            values.len(),
            "rowind and values must have the same length"
        );
        assert_eq!(
            colind.len(),
            values.len(),
            "colind and values must have the same length"
        );
        for &row in &rowind {
            assert!(
                row < nrows,
                "row index {} out of bounds for {} rows",
                row,
                nrows
            );
        }
        for &col in &colind {
            assert!(
                col < ncols,
                "column index {} out of bounds for {} columns",
                col,
                ncols
            );
        }
        let entries = rowind
            .into_iter()
            .zip(colind)
            .zip(values)
            .map(|((row, col), value)| (row, col, value))
            .collect();
        Self {
            nrows,
            ncols,
            entries,
        }
    }

    /// Returns the number of rows.
    pub fn nrows(&self) -> usize {
        self.nrows
    }

    /// Returns the number of columns.
    pub fn ncols(&self) -> usize {
        self.ncols
    }

    /// Returns the dimensions as `(nrows, ncols)`.
    pub fn shape(&self) -> (usize, usize) {
        (self.nrows, self.ncols)
    }

    /// Returns the number of stored triplets.
    pub fn length(&self) -> usize {
        self.entries.len()
    }

    /// Returns how many triplets the entry storage can hold without
    /// reallocating.
    pub fn capacity(&self) -> usize {
        self.entries.capacity()
    }

    /// Returns the triplet stored at position `index` of the entry list, or
    /// `None` if `index` is past the end.
    ///
    /// `index` is a position in the list of stored triplets, not a matrix
    /// coordinate.
    pub fn get(&self, index: usize) -> Option<(&usize, &usize, &T)> {
        self.entries
            .get(index)
            .map(|(row, col, value)| (row, col, value))
    }

    /// Like [`get`](Self::get), but the value is returned mutably.
    ///
    /// The row and column are handed out as shared references, so only the
    /// value can be modified through the result.
    pub fn get_mut(&mut self, index: usize) -> Option<(&usize, &usize, &mut T)> {
        self.entries
            .get_mut(index)
            .map(|(row, col, value)| (&*row, &*col, value))
    }

    /// Appends the triplet `(row, col, value)` to the entry list.
    ///
    /// # Panics
    ///
    /// Panics if `row >= nrows` or `col >= ncols`.
    pub fn push(&mut self, row: usize, col: usize, value: T) {
        assert!(
            row < self.nrows,
            "row index {} out of bounds for {} rows",
            row,
            self.nrows
        );
        assert!(
            col < self.ncols,
            "column index {} out of bounds for {} columns",
            col,
            self.ncols
        );
        self.entries.push((row, col, value))
    }

    /// Removes and returns the last stored triplet, or `None` if there are
    /// no entries.
    pub fn pop(&mut self) -> Option<(usize, usize, T)> {
        self.entries.pop()
    }

    /// Removes every stored triplet; the dimensions are left unchanged.
    pub fn clear(&mut self) {
        self.entries.clear()
    }

    /// Returns an iterator over the stored triplets, borrowing the matrix.
    pub fn iter(&self) -> Iter<'_, T> {
        Iter {
            iter: self.entries.iter(),
        }
    }

    /// Returns an iterator over the stored triplets, mutably borrowing the
    /// matrix.
    pub fn iter_mut(&mut self) -> IterMut<'_, T> {
        IterMut {
            iter: self.entries.iter_mut(),
        }
    }

    /// Returns a new matrix with the dimensions swapped and the row and
    /// column of every stored triplet exchanged.
    pub fn transpose(&self) -> Self {
        let entries = self.entries.iter().map(|&(r, c, v)| (c, r, v)).collect();
        Self {
            nrows: self.ncols(),
            ncols: self.nrows(),
            entries,
        }
    }
}
impl<T: Scalar> Extend<(usize, usize, T)> for CooMatrix<T> {
    /// Appends every `(row, column, value)` triplet produced by `iter`.
    ///
    /// # Panics
    ///
    /// Panics if any triplet has `row >= nrows` or `column >= ncols`. All
    /// triplets are checked before the first one is stored, so a panic leaves
    /// the existing entries untouched.
    fn extend<I: IntoIterator<Item = (usize, usize, T)>>(&mut self, iter: I) {
        // Stage the incoming triplets so validation happens before any mutation.
        let staged = iter.into_iter().collect::<Vec<_>>();
        staged.iter().for_each(|(row, col, _)| {
            assert!(*row < self.nrows);
            assert!(*col < self.ncols);
        });
        self.entries.extend(staged)
    }
}
impl<T: Scalar> IntoIterator for CooMatrix<T> {
    type Item = (usize, usize, T);
    type IntoIter = IntoIter<T>;

    /// Consumes the matrix and returns an iterator that yields its stored
    /// `(row, column, value)` triplets by value.
    fn into_iter(self) -> Self::IntoIter {
        let iter = self.entries.into_iter();
        IntoIter { iter }
    }
}
impl<'iter, T> Iterator for Iter<'iter, T> {
    type Item = (usize, usize, &'iter T);

    /// Yields the next triplet as `(row, column, &value)`; the indices are
    /// copied out and the value stays borrowed.
    fn next(&mut self) -> Option<Self::Item> {
        self.iter.next().map(|(r, c, v)| (*r, *c, v))
    }

    /// Forwards the exact bounds of the wrapped slice iterator so that
    /// consumers such as `collect` can size their allocation up front.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.iter.size_hint()
    }
}
impl<'iter, T> Iterator for IterMut<'iter, T> {
    type Item = (usize, usize, &'iter mut T);

    /// Yields the next triplet as `(row, column, &mut value)`; the indices
    /// are copied out, so only the value can be modified through the item.
    fn next(&mut self) -> Option<Self::Item> {
        self.iter.next().map(|(r, c, v)| (*r, *c, v))
    }

    /// Forwards the exact bounds of the wrapped slice iterator so that
    /// consumers such as `collect` can size their allocation up front.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.iter.size_hint()
    }
}
impl<T: Scalar> Iterator for IntoIter<T> {
type Item = (usize, usize, T);
fn next(&mut self) -> Option<Self::Item> {
self.iter.next()
}
}
impl<T: Scalar> From<CscMatrix<T>> for CooMatrix<T> {
    /// Converts an owned `CscMatrix` by keeping its dimensions and collecting
    /// the triplets yielded by its by-value iterator, in iteration order.
    fn from(csc: CscMatrix<T>) -> Self {
        // The shape is read first because `into_iter` consumes `csc`.
        let (nrows, ncols) = (csc.nrows(), csc.ncols());
        let entries = csc.into_iter().collect();
        Self {
            nrows,
            ncols,
            entries,
        }
    }
}
impl<T: Scalar> From<&CscMatrix<T>> for CooMatrix<T> {
    /// Converts a borrowed `CscMatrix` by keeping its dimensions and copying
    /// each value yielded by its borrowing iterator, in iteration order.
    fn from(csc: &CscMatrix<T>) -> Self {
        let (nrows, ncols) = (csc.nrows(), csc.ncols());
        let entries = csc
            .iter()
            .map(|(row, col, value)| (row, col, *value))
            .collect();
        Self {
            nrows,
            ncols,
            entries,
        }
    }
}
impl<T: Scalar> From<CsrMatrix<T>> for CooMatrix<T> {
    /// Converts an owned `CsrMatrix` by keeping its dimensions and collecting
    /// the triplets yielded by its by-value iterator, in iteration order.
    fn from(csr: CsrMatrix<T>) -> Self {
        // The shape is read first because `into_iter` consumes `csr`.
        let (nrows, ncols) = (csr.nrows(), csr.ncols());
        let entries = csr.into_iter().collect();
        Self {
            nrows,
            ncols,
            entries,
        }
    }
}
impl<T: Scalar> From<&CsrMatrix<T>> for CooMatrix<T> {
    /// Converts a borrowed `CsrMatrix` by keeping its dimensions and copying
    /// each value yielded by its borrowing iterator, in iteration order.
    fn from(csr: &CsrMatrix<T>) -> Self {
        let (nrows, ncols) = (csr.nrows(), csr.ncols());
        let entries = csr
            .iter()
            .map(|(row, col, value)| (row, col, *value))
            .collect();
        Self {
            nrows,
            ncols,
            entries,
        }
    }
}
impl<T: Scalar> From<DokMatrix<T>> for CooMatrix<T> {
    /// Converts an owned `DokMatrix` by keeping its dimensions and collecting
    /// the triplets yielded by its by-value iterator, in iteration order.
    fn from(dok: DokMatrix<T>) -> Self {
        // The shape is read first because `into_iter` consumes `dok`.
        let (nrows, ncols) = (dok.nrows(), dok.ncols());
        let entries = dok.into_iter().collect();
        Self {
            nrows,
            ncols,
            entries,
        }
    }
}
impl<T: Scalar> From<&DokMatrix<T>> for CooMatrix<T> {
    /// Converts a borrowed `DokMatrix` by keeping its dimensions and copying
    /// each value yielded by its borrowing iterator, in iteration order.
    fn from(dok: &DokMatrix<T>) -> Self {
        let (nrows, ncols) = (dok.nrows(), dok.ncols());
        let entries = dok
            .iter()
            .map(|(row, col, value)| (row, col, *value))
            .collect();
        Self {
            nrows,
            ncols,
            entries,
        }
    }
}
impl<T: Scalar> Add for &CooMatrix<T> {
    type Output = CooMatrix<T>;

    /// Adds two matrices of identical shape.
    ///
    /// The result stores the triplets of `self` followed by the triplets of
    /// `rhs`; entries that share a coordinate are kept side by side rather
    /// than being summed into one.
    ///
    /// # Panics
    ///
    /// Panics if the operands differ in row or column count.
    fn add(self, rhs: Self) -> Self::Output {
        assert_eq!(self.nrows(), rhs.nrows());
        assert_eq!(self.ncols(), rhs.ncols());
        let mut merged = Vec::with_capacity(self.entries.len() + rhs.entries.len());
        merged.extend_from_slice(&self.entries);
        merged.extend_from_slice(&rhs.entries);
        CooMatrix {
            nrows: self.nrows(),
            ncols: self.ncols(),
            entries: merged,
        }
    }
}
impl<T: Scalar> Sub for &CooMatrix<T> {
    type Output = CooMatrix<T>;

    /// Subtracts `rhs` from `self`; both must have the same shape.
    ///
    /// The result stores the triplets of `self` followed by the triplets of
    /// `rhs` with negated values; entries that share a coordinate are kept
    /// side by side rather than being combined into one.
    ///
    /// # Panics
    ///
    /// Panics if the operands differ in row or column count.
    fn sub(self, rhs: Self) -> Self::Output {
        assert_eq!(self.nrows(), rhs.nrows());
        assert_eq!(self.ncols(), rhs.ncols());
        let mut combined = Vec::with_capacity(self.entries.len() + rhs.entries.len());
        combined.extend_from_slice(&self.entries);
        combined.extend(
            rhs.entries
                .iter()
                .map(|&(row, col, value)| (row, col, -value)),
        );
        CooMatrix {
            nrows: self.nrows(),
            ncols: self.ncols(),
            entries: combined,
        }
    }
}
impl<T: Scalar> Neg for &CooMatrix<T> {
    type Output = CooMatrix<T>;

    /// Returns a matrix of the same shape in which every stored value is
    /// negated; coordinates and entry order are preserved.
    fn neg(self) -> Self::Output {
        let negated = self
            .entries
            .iter()
            .copied()
            .map(|(row, col, value)| (row, col, -value))
            .collect();
        CooMatrix {
            nrows: self.nrows(),
            ncols: self.ncols(),
            entries: negated,
        }
    }
}
// Unit tests for `CooMatrix`: constructors and their panics, accessors,
// entry manipulation, iteration, and the `Add`/`Sub`/`Neg` operators.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn new() {
        let matrix: CooMatrix<f64> = CooMatrix::new(1, 2);
        assert_eq!(matrix.nrows(), 1);
        assert_eq!(matrix.ncols(), 2);
        assert_eq!(matrix.length(), 0);
        assert_eq!(matrix.capacity(), 0);
    }

    #[test]
    #[should_panic]
    fn new_invalid_nrows() {
        CooMatrix::<f64>::new(0, 1);
    }

    #[test]
    #[should_panic]
    fn new_invalid_ncols() {
        CooMatrix::<f64>::new(1, 0);
    }

    #[test]
    fn with_capacity() {
        let matrix: CooMatrix<f64> = CooMatrix::with_capacity(1, 2, 4);
        assert_eq!(matrix.nrows(), 1);
        assert_eq!(matrix.ncols(), 2);
        assert_eq!(matrix.length(), 0);
        assert!(matrix.capacity() >= 4);
    }

    #[test]
    #[should_panic]
    fn with_capacity_invalid_nrows() {
        CooMatrix::<f64>::with_capacity(0, 1, 1);
    }

    #[test]
    #[should_panic]
    fn with_capacity_invalid_ncols() {
        // A zero column count with a valid row count, so that the panic can
        // only come from the column check.
        CooMatrix::<f64>::with_capacity(1, 0, 1);
    }

    #[test]
    fn with_entries() {
        let entries = vec![(0, 0, 1.0), (1, 0, 2.0), (0, 2, 3.0)];
        let matrix = CooMatrix::with_entries(2, 3, entries);
        assert_eq!(matrix.length(), 3);
        assert!(matrix.capacity() >= 3);
        assert_eq!(matrix.get(0), Some((&0, &0, &1.0)));
        assert_eq!(matrix.get(1), Some((&1, &0, &2.0)));
        assert_eq!(matrix.get(2), Some((&0, &2, &3.0)));
        assert!(matrix.get(3).is_none());
    }

    #[test]
    #[should_panic]
    fn with_entries_invalid_nrows() {
        CooMatrix::<f64>::with_entries(0, 1, vec![]);
    }

    #[test]
    #[should_panic]
    fn with_entries_invalid_ncols() {
        CooMatrix::<f64>::with_entries(1, 0, vec![]);
    }

    #[test]
    #[should_panic]
    fn with_entries_invalid_row() {
        CooMatrix::<f64>::with_entries(1, 1, vec![(1, 0, 1.0)]);
    }

    #[test]
    #[should_panic]
    fn with_entries_invalid_col() {
        CooMatrix::<f64>::with_entries(1, 1, vec![(0, 1, 1.0)]);
    }

    #[test]
    fn with_triplets() {
        let rowind = vec![0, 1];
        let colind = vec![1, 0];
        let values = vec![-1.0, 1.0];
        let matrix = CooMatrix::with_triplets(2, 2, rowind, colind, values);
        assert_eq!(matrix.length(), 2);
        assert!(matrix.capacity() >= 2);
        assert_eq!(matrix.get(0), Some((&0, &1, &-1.0)));
        assert_eq!(matrix.get(1), Some((&1, &0, &1.0)));
        assert!(matrix.get(2).is_none());
    }

    #[test]
    #[should_panic]
    fn with_triplets_invalid_nrows() {
        CooMatrix::<f64>::with_triplets(0, 1, vec![], vec![], vec![]);
    }

    #[test]
    #[should_panic]
    fn with_triplets_invalid_ncols() {
        CooMatrix::<f64>::with_triplets(1, 0, vec![], vec![], vec![]);
    }

    #[test]
    #[should_panic]
    fn with_triplets_invalid_triplets_rowind_length() {
        CooMatrix::<f64>::with_triplets(1, 1, vec![], vec![0], vec![1.0]);
    }

    #[test]
    #[should_panic]
    fn with_triplets_invalid_triplets_colind_length() {
        CooMatrix::<f64>::with_triplets(1, 1, vec![0], vec![], vec![1.0]);
    }

    #[test]
    #[should_panic]
    fn with_triplets_invalid_triplets_values_length() {
        CooMatrix::<f64>::with_triplets(1, 1, vec![0], vec![0], vec![]);
    }

    #[test]
    #[should_panic]
    fn with_triplets_invalid_row() {
        CooMatrix::<f64>::with_triplets(1, 1, vec![1], vec![0], vec![1.0]);
    }

    #[test]
    #[should_panic]
    fn with_triplets_invalid_col() {
        CooMatrix::<f64>::with_triplets(1, 1, vec![0], vec![1], vec![1.0]);
    }

    #[test]
    fn nrows() {
        let matrix: CooMatrix<f64> = CooMatrix::new(1, 2);
        assert_eq!(matrix.nrows(), 1);
    }

    #[test]
    fn ncols() {
        let matrix: CooMatrix<f64> = CooMatrix::new(1, 2);
        assert_eq!(matrix.ncols(), 2);
    }

    #[test]
    fn shape() {
        let matrix: CooMatrix<f64> = CooMatrix::new(1, 2);
        assert_eq!(matrix.shape(), (1, 2));
    }

    #[test]
    fn length() {
        let mut matrix: CooMatrix<f64> = CooMatrix::new(1, 1);
        assert_eq!(matrix.length(), 0);
        matrix.push(0, 0, 1.0);
        assert_eq!(matrix.length(), 1);
    }

    #[test]
    fn capacity() {
        let mut matrix: CooMatrix<f64> = CooMatrix::new(1, 1);
        assert_eq!(matrix.capacity(), 0);
        matrix.push(0, 0, 1.0);
        assert!(matrix.capacity() >= 1);
    }

    #[test]
    fn get() {
        let entries = vec![(0, 0, 1.0)];
        let matrix = CooMatrix::with_entries(1, 1, entries);
        assert_eq!(matrix.get(0), Some((&0, &0, &1.0)));
        assert!(matrix.get(1).is_none());
    }

    #[test]
    fn get_mut() {
        let entries = vec![(0, 0, 1.0)];
        let mut matrix = CooMatrix::with_entries(1, 1, entries);
        assert_eq!(matrix.get_mut(0), Some((&0, &0, &mut 1.0)));
        assert!(matrix.get_mut(1).is_none());
    }

    #[test]
    fn push() {
        let mut matrix: CooMatrix<f64> = CooMatrix::new(1, 1);
        matrix.push(0, 0, 1.0);
        assert_eq!(matrix.get(0), Some((&0, &0, &1.0)));
    }

    #[test]
    #[should_panic]
    fn push_invalid_row() {
        let mut matrix: CooMatrix<f64> = CooMatrix::new(1, 1);
        matrix.push(1, 0, 1.0);
    }

    #[test]
    #[should_panic]
    fn push_invalid_col() {
        let mut matrix: CooMatrix<f64> = CooMatrix::new(1, 1);
        matrix.push(0, 1, 1.0);
    }

    #[test]
    fn pop() {
        let entries = vec![(0, 0, 1.0)];
        let mut matrix = CooMatrix::with_entries(1, 1, entries);
        assert_eq!(matrix.pop(), Some((0, 0, 1.0)));
        assert_eq!(matrix.length(), 0);
        assert!(matrix.pop().is_none());
    }

    #[test]
    fn clear() {
        let entries = vec![(0, 0, 1.0)];
        let mut matrix = CooMatrix::with_entries(1, 1, entries);
        matrix.clear();
        assert_eq!(matrix.length(), 0);
    }

    #[test]
    fn iter() {
        let entries = vec![(0, 0, 1.0), (1, 0, 2.0), (0, 2, 3.0)];
        let matrix = CooMatrix::with_entries(2, 3, entries);
        let mut iter = matrix.iter();
        assert_eq!(iter.next(), Some((0, 0, &1.0)));
        assert_eq!(iter.next(), Some((1, 0, &2.0)));
        assert_eq!(iter.next(), Some((0, 2, &3.0)));
        assert!(iter.next().is_none());
    }

    #[test]
    fn iter_mut() {
        let entries = vec![(0, 0, 1.0), (1, 0, 2.0), (0, 2, 3.0)];
        let mut matrix = CooMatrix::with_entries(2, 3, entries);
        let mut iter = matrix.iter_mut();
        assert_eq!(iter.next(), Some((0, 0, &mut 1.0)));
        assert_eq!(iter.next(), Some((1, 0, &mut 2.0)));
        assert_eq!(iter.next(), Some((0, 2, &mut 3.0)));
        assert!(iter.next().is_none());
    }

    #[test]
    fn extend() {
        let entries = vec![(0, 0, 1.0), (1, 0, 2.0), (0, 2, 3.0)];
        let mut matrix = CooMatrix::new(2, 3);
        matrix.extend(entries);
        assert_eq!(matrix.length(), 3);
        assert!(matrix.capacity() >= 3);
        assert_eq!(matrix.get(0), Some((&0, &0, &1.0)));
        assert_eq!(matrix.get(1), Some((&1, &0, &2.0)));
        assert_eq!(matrix.get(2), Some((&0, &2, &3.0)));
        assert!(matrix.get(3).is_none());
    }

    #[test]
    fn into_iter() {
        let entries = vec![(0, 0, 1.0), (1, 0, 2.0), (0, 2, 3.0)];
        let matrix = CooMatrix::with_entries(2, 3, entries);
        let mut iter = matrix.into_iter();
        assert_eq!(iter.next(), Some((0, 0, 1.0)));
        assert_eq!(iter.next(), Some((1, 0, 2.0)));
        assert_eq!(iter.next(), Some((0, 2, 3.0)));
        assert!(iter.next().is_none());
    }

    #[test]
    fn add() {
        let entries = vec![(0, 0, 1.0), (1, 0, 2.0), (0, 2, 3.0)];
        let lhs = CooMatrix::with_entries(2, 3, entries);
        let entries = vec![(0, 0, 2.0), (1, 1, 4.0), (1, 2, 6.0)];
        let rhs = CooMatrix::with_entries(2, 3, entries);
        let mat = &lhs + &rhs;
        let mut iter = mat.into_iter();
        assert_eq!(iter.next(), Some((0, 0, 1.0)));
        assert_eq!(iter.next(), Some((1, 0, 2.0)));
        assert_eq!(iter.next(), Some((0, 2, 3.0)));
        assert_eq!(iter.next(), Some((0, 0, 2.0)));
        assert_eq!(iter.next(), Some((1, 1, 4.0)));
        assert_eq!(iter.next(), Some((1, 2, 6.0)));
        assert!(iter.next().is_none());
    }

    #[test]
    fn sub() {
        let entries = vec![(0, 0, 1.0), (1, 0, 2.0), (0, 2, 3.0)];
        let lhs = CooMatrix::with_entries(2, 3, entries);
        let entries = vec![(0, 0, 2.0), (1, 1, 4.0), (1, 2, 6.0)];
        let rhs = CooMatrix::with_entries(2, 3, entries);
        let mat = &lhs - &rhs;
        let mut iter = mat.into_iter();
        assert_eq!(iter.next(), Some((0, 0, 1.0)));
        assert_eq!(iter.next(), Some((1, 0, 2.0)));
        assert_eq!(iter.next(), Some((0, 2, 3.0)));
        assert_eq!(iter.next(), Some((0, 0, -2.0)));
        assert_eq!(iter.next(), Some((1, 1, -4.0)));
        assert_eq!(iter.next(), Some((1, 2, -6.0)));
        assert!(iter.next().is_none());
    }

    #[test]
    fn neg() {
        let entries = vec![(0, 0, 1.0), (1, 0, 2.0), (0, 2, 3.0)];
        let mat = -&CooMatrix::with_entries(2, 3, entries);
        let mut iter = mat.into_iter();
        assert_eq!(iter.next(), Some((0, 0, -1.0)));
        assert_eq!(iter.next(), Some((1, 0, -2.0)));
        assert_eq!(iter.next(), Some((0, 2, -3.0)));
        assert!(iter.next().is_none());
    }
}