use std::ops::Add;
use num_traits::One;
use crate::error::{Error, Result};
/// Sparse matrix that records each stored entry twice: once in a per-row
/// list and once in a per-column list.
///
/// Row and column indices are held as `u16`.
#[derive(Debug, Clone)]
pub struct SMatrix<T> {
    /// Number of rows.
    m: usize,
    /// Number of columns.
    n: usize,
    /// Per row: column indices of the stored entries, kept in ascending order.
    mlist: Vec<Vec<u16>>,
    /// Per column: row indices of the stored entries, kept in ascending order.
    nlist: Vec<Vec<u16>>,
    /// Per row: stored values, positionally paired with `mlist`.
    mvals: Vec<Vec<T>>,
    /// Per column: stored values, positionally paired with `nlist`.
    nvals: Vec<Vec<T>>,
    /// Per row: number of stored entries.
    num_mlist: Vec<usize>,
    /// Per column: number of stored entries.
    num_nlist: Vec<usize>,
    /// Largest per-row entry count.
    max_num_mlist: usize,
    /// Largest per-column entry count.
    max_num_nlist: usize,
}
impl<T: Default + Copy + PartialEq + std::fmt::Display + One + Add<Output = T>> SMatrix<T> {
pub fn new(m: usize, n: usize) -> Result<Self> {
if m == 0 || n == 0 {
return Err(Error::Config("smatrix_create(), dimensions must be greater than zero".to_string()));
}
Ok(SMatrix {
m,
n,
mlist: vec![Vec::new(); m],
nlist: vec![Vec::new(); n],
mvals: vec![Vec::new(); m],
nvals: vec![Vec::new(); n],
num_mlist: vec![0; m],
num_nlist: vec![0; n],
max_num_mlist: 0,
max_num_nlist: 0,
})
}
pub fn from_array(v: &[T], m: usize, n: usize) -> Result<Self> {
let mut q = Self::new(m, n)?;
for i in 0..m {
for j in 0..n {
if v[i * n + j] != T::default() {
q.set(i, j, v[i * n + j]);
}
}
}
Ok(q)
}
pub fn print(&self) {
println!("dims : {} {}", self.m, self.n);
println!("max : {} {}", self.max_num_mlist, self.max_num_nlist);
print!("rows :");
for i in 0..self.m {
print!(" {}", self.num_mlist[i]);
}
println!();
print!("cols :");
for j in 0..self.n {
print!(" {}", self.num_nlist[j]);
}
println!();
println!("row indices:");
for i in 0..self.m {
if self.num_mlist[i] == 0 {
continue;
}
print!(" {:3} :", i);
for j in 0..self.num_mlist[i] {
print!(" {}", self.mlist[i][j]);
}
println!();
}
println!("column indices:");
for j in 0..self.n {
if self.num_nlist[j] == 0 {
continue;
}
print!(" {:3} :", j);
for i in 0..self.num_nlist[j] {
print!(" {}", self.nlist[j][i]);
}
println!();
}
println!("row values:");
for i in 0..self.m {
print!(" {:3} :", i);
for j in 0..self.num_mlist[i] {
print!(" {:6.2}", self.mvals[i][j]);
}
println!();
}
println!("column values:");
for j in 0..self.n {
print!(" {:3} :", j);
for i in 0..self.num_nlist[j] {
print!(" {:6.2}", self.nvals[j][i]);
}
println!();
}
}
pub fn print_expanded(&self) {
for i in 0..self.m {
let mut t = 0;
for j in 0..self.n {
if t == self.num_mlist[i] {
print!("{:6.2} ", T::default());
} else if self.mlist[i][t] == j as u16 {
print!("{:6.2} ", self.mvals[i][t]);
t += 1;
} else {
print!("{:6.2} ", T::default());
}
}
println!();
}
}
pub fn size(&self) -> (usize, usize) {
(self.m, self.n)
}
pub fn clear(&mut self) {
for i in 0..self.m {
for j in 0..self.num_mlist[i] {
self.mvals[i][j] = T::default();
}
}
for j in 0..self.n {
for i in 0..self.num_nlist[j] {
self.nvals[j][i] = T::default();
}
}
}
pub fn reset(&mut self) {
for i in 0..self.m {
self.num_mlist[i] = 0;
}
for j in 0..self.n {
self.num_nlist[j] = 0;
}
self.max_num_mlist = 0;
self.max_num_nlist = 0;
}
pub fn isset(&self, m: usize, n: usize) -> Result<bool> {
if m >= self.m || n >= self.n {
return Err(Error::Range(format!("smatrix_isset({},{}), index exceeds matrix dimension ({},{})", m, n, self.m, self.n)));
}
Ok(self.mlist[m].contains(&(n as u16)))
}
fn insert(&mut self, m: usize, n: usize, v: T) -> Result<()> {
if m >= self.m || n >= self.n {
return Err(Error::Range(format!("smatrix_insert({},{}), index exceeds matrix dimension ({},{})", m, n, self.m, self.n)));
}
if self.isset(m, n)? {
self.set(m, n, v);
return Ok(());
}
self.num_mlist[m] += 1;
self.num_nlist[n] += 1;
let mindex = Self::indexsearch(&self.mlist[m], self.num_mlist[m] - 1, n as u16);
let nindex = Self::indexsearch(&self.nlist[n], self.num_nlist[n] - 1, m as u16);
self.mlist[m].insert(mindex, n as u16);
self.nlist[n].insert(nindex, m as u16);
self.mvals[m].insert(mindex, v);
self.nvals[n].insert(nindex, v);
self.max_num_mlist = self.max_num_mlist.max(self.num_mlist[m]);
self.max_num_nlist = self.max_num_nlist.max(self.num_nlist[n]);
Ok(())
}
pub fn delete(&mut self, m: usize, n: usize) -> Result<()> {
if m > self.m || n > self.n {
return Err(Error::Range(format!("smatrix_delete({},{}), index exceeds matrix dimension ({},{})", m, n, self.m, self.n)));
}
if !self.isset(m, n)? {
return Ok(());
}
let mindex = self.mlist[m].iter().position(|&x| x == n as u16).unwrap();
self.mlist[m].remove(mindex);
let nindex = self.nlist[n].iter().position(|&x| x == m as u16).unwrap();
self.nlist[n].remove(nindex);
self.num_mlist[m] -= 1;
self.num_nlist[n] -= 1;
if self.max_num_mlist == self.num_mlist[m] + 1 {
self.reset_max_mlist();
}
if self.max_num_nlist == self.num_nlist[n] + 1 {
self.reset_max_nlist();
}
Ok(())
}
pub fn set(&mut self, m: usize, n: usize, v: T) {
if m >= self.m || n >= self.n {
panic!("smatrix_set({},{}), index exceeds matrix dimension ({},{})", m, n, self.m, self.n);
}
if !self.isset(m, n).unwrap() {
self.insert(m, n, v).unwrap();
return;
}
let mindex = self.mlist[m].iter().position(|&x| x == n as u16).unwrap();
self.mvals[m][mindex] = v;
let nindex = self.nlist[n].iter().position(|&x| x == m as u16).unwrap();
self.nvals[n][nindex] = v;
()
}
pub fn get(&self, m: usize, n: usize) -> T {
if m >= self.m || n >= self.n {
panic!("smatrix_get({},{}), index exceeds matrix dimension ({},{})", m, n, self.m, self.n);
}
if let Some(mindex) = self.mlist[m].iter().position(|&x| x == n as u16) {
self.mvals[m][mindex]
} else {
T::default()
}
}
pub fn eye(&mut self) {
self.reset();
let dmin = self.m.min(self.n);
for i in 0..dmin {
self.set(i, i, T::one());
}
}
pub fn mul(&self, b: &SMatrix<T>, c: &mut SMatrix<T>) -> Result<()> {
if c.m != self.m || c.n != b.n || self.n != b.m {
return Err(Error::Range("smatrix_mul(), invalid dimensions".to_string()));
}
c.clear();
for r in 0..c.m {
let nnz_a_row = self.num_mlist[r];
if nnz_a_row == 0 {
continue;
}
for col in 0..c.n {
let nnz_b_col = b.num_nlist[col];
let mut p = T::default();
let mut set_value = false;
let mut i = 0; let mut j = 0; while i < nnz_a_row && j < nnz_b_col {
let ca = self.mlist[r][i];
let rb = b.nlist[col][j];
if ca == rb {
p = p + self.mvals[r][i] * b.nvals[col][j];
set_value = true;
i += 1;
j += 1;
} else if ca < rb {
i += 1; } else {
j += 1; }
}
if set_value {
c.set(r, col, p);
}
}
}
Ok(())
}
pub fn vmul(&self, x: &[T], y: &mut [T]) {
for i in 0..self.m {
y[i] = T::default();
}
for i in 0..self.m {
let mut p = T::default();
for j in 0..self.num_mlist[i] {
let col_index = self.mlist[i][j] as usize;
p = p + self.mvals[i][j] * x[col_index];
}
y[i] = p;
}
}
fn reset_max_mlist(&mut self) {
self.max_num_mlist = self.num_mlist.iter().max().copied().unwrap_or(0);
}
fn reset_max_nlist(&mut self) {
self.max_num_nlist = self.num_nlist.iter().max().copied().unwrap_or(0);
}
fn indexsearch(v: &[u16], n: usize, x: u16) -> usize {
let mut i = 0;
while i < n && v[i] <= x {
i += 1;
}
i
}
}
impl SMatrix<u8> {
pub fn mul_f32(&self, x: &[f32], mx: usize, nx: usize, y: &mut [f32], my: usize, ny: usize) -> Result<()> {
if my != self.m || ny != nx || self.n != mx {
return Err(Error::Range("smatrix_mul(), invalid dimensions".to_string()));
}
y.fill(0.0);
for i in 0..self.m {
for &j in &self.mlist[i][..self.num_mlist[i]] {
for k in 0..ny {
y[i * ny + k] += x[j as usize * nx + k];
}
}
}
Ok(())
}
pub fn vmul_f32(&self, x: &[f32], y: &mut [f32]) {
for i in 0..self.m {
y[i] = 0.0;
for &j in &self.mlist[i][..self.num_mlist[i]] {
y[i] += x[j as usize];
}
}
}
pub fn wrap_bool(v: u8) -> u8 {
v % 2
}
pub fn wrap_bools(v: &mut [u8]) {
for i in 0..v.len() {
v[i] = v[i] % 2;
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use test_macro::autotest_annotate;

    /// Writes, applied in order, that build the 4x5 matrix `A` shared by the
    /// float and integer tests. Position (3,4) is written twice (zero, then
    /// one) so the overwrite path of `set` is exercised as well.
    const A_WRITES: [(usize, usize, i16); 5] = [(0, 4, 4), (2, 3, 3), (3, 0, 2), (3, 4, 0), (3, 4, 1)];

    /// Writes, applied in order, that build the 5x3 matrix `B`.
    const B_WRITES: [(usize, usize, i16); 4] = [(0, 0, 7), (0, 1, 6), (3, 1, 5), (4, 0, 2)];

    /// Positions of the ones in the 8x12 binary matrix shared by several
    /// `smatrixb` tests, in the order they are set.
    const B8X12_ONES: [(usize, usize); 18] = [
        (0, 0), (2, 0), (6, 0),
        (3, 2), (6, 2), (7, 2),
        (1, 3), (7, 5),
        (3, 6), (5, 6), (7, 6),
        (3, 7), (2, 8), (5, 8),
        (2, 9), (5, 10), (6, 10),
        (6, 11),
    ];

    /// Builds the shared 8x12 binary test matrix from `B8X12_ONES`.
    fn smatrixb_8x12() -> SMatrix<u8> {
        let mut a = SMatrix::<u8>::new(8, 12).unwrap();
        for &(row, col) in B8X12_ONES.iter() {
            a.set(row, col, 1);
        }
        a
    }

    #[test]
    #[autotest_annotate(autotest_smatrixf_vmul)]
    fn test_smatrixf_vmul() {
        let tol = 1e-6f32;
        let mut a = SMatrix::<f32>::new(4, 5).unwrap();
        for &(row, col, v) in A_WRITES.iter() {
            a.set(row, col, f32::from(v));
        }
        let x = [7.0f32, 1.0, 5.0, 2.0, 2.0];
        let expected = [8.0f32, 0.0, 6.0, 16.0];
        let mut y = [0.0f32; 4];
        a.vmul(&x, &mut y);
        for (got, want) in y.iter().zip(expected.iter()) {
            assert_relative_eq!(*got, *want, epsilon = tol);
        }
    }

    #[test]
    #[autotest_annotate(autotest_smatrixf_mul)]
    fn test_smatrixf_mul() {
        let tol = 1e-6f32;
        let mut a = SMatrix::<f32>::new(4, 5).unwrap();
        let mut b = SMatrix::<f32>::new(5, 3).unwrap();
        let mut c = SMatrix::<f32>::new(4, 3).unwrap();
        for &(row, col, v) in A_WRITES.iter() {
            a.set(row, col, f32::from(v));
        }
        for &(row, col, v) in B_WRITES.iter() {
            b.set(row, col, f32::from(v));
        }
        a.mul(&b, &mut c).unwrap();
        let expected: [f32; 12] = [
            8.0, 0.0, 0.0,
            0.0, 0.0, 0.0,
            0.0, 15.0, 0.0,
            16.0, 12.0, 0.0,
        ];
        for (k, &want) in expected.iter().enumerate() {
            assert_relative_eq!(c.get(k / 3, k % 3), want, epsilon = tol);
        }
    }

    #[test]
    #[autotest_annotate(autotest_smatrixb_vmul)]
    fn test_smatrixb_vmul() {
        let a = smatrixb_8x12();
        let x: [u8; 12] = [1, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1];
        let mut y = vec![0u8; 8];
        a.vmul(&x, &mut y);
        SMatrix::<u8>::wrap_bools(&mut y);
        assert_eq!(y, vec![1u8, 0, 1, 1, 0, 1, 0, 0]);
    }

    #[test]
    #[autotest_annotate(autotest_smatrixb_mul)]
    fn test_smatrixb_mul() {
        let a_dense: Vec<u8> = vec![
            0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0,
            0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0,
            0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0,
        ];
        let b_dense: Vec<u8> = vec![
            1, 1, 0, 0, 0,
            0, 0, 0, 0, 1,
            0, 0, 0, 0, 0,
            0, 0, 0, 0, 0,
            0, 0, 0, 0, 0,
            0, 0, 0, 0, 1,
            0, 0, 0, 1, 0,
            0, 0, 0, 1, 0,
            0, 0, 0, 0, 0,
            0, 1, 0, 0, 1,
            1, 0, 0, 1, 0,
            0, 1, 0, 0, 0,
        ];
        let expected: Vec<u8> = vec![
            0, 0, 0, 0, 0,
            0, 0, 0, 0, 0,
            0, 0, 0, 1, 0,
            0, 0, 0, 0, 0,
            0, 0, 0, 0, 0,
            0, 0, 0, 0, 0,
            0, 0, 0, 0, 1,
            0, 0, 0, 1, 0,
        ];
        let a = SMatrix::<u8>::from_array(&a_dense, 8, 12).unwrap();
        let b = SMatrix::<u8>::from_array(&b_dense, 12, 5).unwrap();
        let mut c = SMatrix::<u8>::new(8, 5).unwrap();
        a.mul(&b, &mut c).unwrap();
        for (k, &want) in expected.iter().enumerate() {
            assert_eq!(SMatrix::<u8>::wrap_bool(c.get(k / 5, k % 5)), want);
        }
    }

    #[test]
    #[autotest_annotate(autotest_smatrixb_mulf)]
    fn test_smatrixb_mulf() {
        let tol = 1e-6f32;
        let a = smatrixb_8x12();
        let x: Vec<f32> = vec![
            -4.3, -0.7, 3.7,
            -1.7, 2.8, 4.3,
            2.0, 1.9, 0.6,
            3.6, 1.0, -3.7,
            4.3, 0.7, 2.1,
            4.6, 0.5, 0.8,
            1.6, -3.8, -0.8,
            -1.9, -2.1, 2.8,
            -1.5, 2.5, 0.8,
            8.4, 1.5, -3.1,
            -5.8, 0.0, 2.5,
            -4.9, -2.1, -1.5,
        ];
        let expected: Vec<f32> = vec![
            -4.3, -0.7, 3.7,
            3.6, 1.0, -3.7,
            2.6, 3.3, 1.4,
            1.7, -4.0, 2.6,
            0.0, 0.0, 0.0,
            -5.7, -1.3, 2.5,
            -13.0, -0.9, 5.3,
            8.2, -1.4, 0.6,
        ];
        let mut y = vec![0.0f32; 24];
        a.mul_f32(&x, 12, 3, &mut y, 8, 3).unwrap();
        for (got, want) in y.iter().zip(expected.iter()) {
            assert_relative_eq!(*got, *want, epsilon = tol);
        }
    }

    #[test]
    #[autotest_annotate(autotest_smatrixb_vmulf)]
    fn test_smatrixb_vmulf() {
        let tol = 1e-6f32;
        let a = smatrixb_8x12();
        let x: Vec<f32> = vec![
            3.4, -5.7, 0.3, 2.3, 1.9, 3.9,
            2.3, -4.0, -0.5, 1.5, -0.6, -1.0,
        ];
        let expected: Vec<f32> = vec![3.4, 2.3, 4.4, -1.4, 0.0, 1.2, 2.1, 6.5];
        let mut y = vec![0.0f32; 8];
        a.vmul_f32(&x, &mut y);
        for (got, want) in y.iter().zip(expected.iter()) {
            assert_relative_eq!(*got, *want, epsilon = tol);
        }
    }

    #[test]
    #[autotest_annotate(autotest_smatrixi_vmul)]
    fn test_smatrixi_vmul() {
        let mut a = SMatrix::<i16>::new(4, 5).unwrap();
        for &(row, col, v) in A_WRITES.iter() {
            a.set(row, col, v);
        }
        let x: [i16; 5] = [7, 1, 5, 2, 2];
        let mut y = [0i16; 4];
        a.vmul(&x, &mut y);
        assert_eq!(y, [8, 0, 6, 16]);
    }

    #[test]
    #[autotest_annotate(autotest_smatrixi_mul)]
    fn test_smatrixi_mul() {
        let mut a = SMatrix::<i16>::new(4, 5).unwrap();
        let mut b = SMatrix::<i16>::new(5, 3).unwrap();
        let mut c = SMatrix::<i16>::new(4, 3).unwrap();
        for &(row, col, v) in A_WRITES.iter() {
            a.set(row, col, v);
        }
        for &(row, col, v) in B_WRITES.iter() {
            b.set(row, col, v);
        }
        a.mul(&b, &mut c).unwrap();
        let expected: [i16; 12] = [
            8, 0, 0,
            0, 0, 0,
            0, 15, 0,
            16, 12, 0,
        ];
        for (k, &want) in expected.iter().enumerate() {
            assert_eq!(c.get(k / 3, k % 3), want);
        }
    }
}