use serde::{Deserialize, Serialize};
use std::cmp::Ordering;
use std::fmt::{self, Debug, Display};
use std::iter::Sum;
use std::ops::{Add, AddAssign, Div, Mul, Neg, Sub, SubAssign};
use std::str::FromStr;
pub trait FloatData<T>:
Mul<Output = T>
+ Display
+ Add<Output = T>
+ Div<Output = T>
+ Neg<Output = T>
+ Copy
+ Debug
+ PartialEq
+ PartialOrd
+ AddAssign
+ Sub<Output = T>
+ SubAssign
+ Sum
+ std::marker::Send
+ std::marker::Sync
{
const ZERO: T;
const ONE: T;
const MIN: T;
const MAX: T;
const NAN: T;
const INFINITY: T;
fn from_usize(v: usize) -> T;
fn from_u16(v: u16) -> T;
fn is_nan(self) -> bool;
fn ln(self) -> T;
fn exp(self) -> T;
fn total_cmp(&self, other: &T) -> Ordering;
}
impl FloatData<f64> for f64 {
const ZERO: f64 = 0.0;
const ONE: f64 = 1.0;
const MIN: f64 = f64::MIN;
const MAX: f64 = f64::MAX;
const NAN: f64 = f64::NAN;
const INFINITY: f64 = f64::INFINITY;
fn from_usize(v: usize) -> f64 {
v as f64
}
fn from_u16(v: u16) -> f64 {
f64::from(v)
}
fn is_nan(self) -> bool {
self.is_nan()
}
fn ln(self) -> f64 {
self.ln()
}
fn exp(self) -> f64 {
self.exp()
}
fn total_cmp(&self, other: &f64) -> Ordering {
f64::total_cmp(self, other)
}
}
impl FloatData<f32> for f32 {
const ZERO: f32 = 0.0;
const ONE: f32 = 1.0;
const MIN: f32 = f32::MIN;
const MAX: f32 = f32::MAX;
const NAN: f32 = f32::NAN;
const INFINITY: f32 = f32::INFINITY;
fn from_usize(v: usize) -> f32 {
v as f32
}
fn from_u16(v: u16) -> f32 {
f32::from(v)
}
fn is_nan(self) -> bool {
self.is_nan()
}
fn ln(self) -> f32 {
self.ln()
}
fn exp(self) -> f32 {
self.exp()
}
fn total_cmp(&self, other: &f32) -> Ordering {
f32::total_cmp(self, other)
}
}
pub struct Matrix<'a, T> {
pub data: &'a [T],
pub index: Vec<usize>,
pub rows: usize,
pub cols: usize,
stride1: usize,
stride2: usize,
}
impl<'a, T> Matrix<'a, T> {
pub fn new(data: &'a [T], rows: usize, cols: usize) -> Self {
Matrix {
data,
index: (0..rows).collect(),
rows,
cols,
stride1: rows,
stride2: 1,
}
}
pub fn get(&self, i: usize, j: usize) -> &T {
&self.data[self.item_index(i, j)]
}
fn item_index(&self, i: usize, j: usize) -> usize {
let mut idx = self.stride2 * i;
idx += j * self.stride1;
idx
}
pub fn get_row_iter(&self, row: usize) -> std::iter::StepBy<std::iter::Skip<std::slice::Iter<'a, T>>> {
self.data.iter().skip(row).step_by(self.rows)
}
pub fn get_col_slice(&self, col: usize, start_row: usize, end_row: usize) -> &[T] {
let i = self.item_index(start_row, col);
let j = self.item_index(end_row, col);
&self.data[i..j]
}
pub fn get_col(&self, col: usize) -> &[T] {
self.get_col_slice(col, 0, self.rows)
}
}
impl<'a, T> Matrix<'a, T>
where
T: Copy,
{
pub fn get_row(&self, row: usize) -> Vec<T> {
self.get_row_iter(row).copied().collect()
}
}
pub struct ColumnarMatrix<'a, T> {
pub columns: Vec<&'a [T]>,
pub masks: Option<Vec<Option<&'a [u8]>>>,
pub index: Vec<usize>,
pub rows: usize,
pub cols: usize,
}
impl<'a, T> ColumnarMatrix<'a, T> {
pub fn new(columns: Vec<&'a [T]>, masks: Option<Vec<Option<&'a [u8]>>>, rows: usize) -> Self {
let cols = columns.len();
if let Some(ref m) = masks {
assert_eq!(m.len(), cols, "Number of masks must match number of columns");
}
ColumnarMatrix {
columns,
masks,
index: (0..rows).collect(),
rows,
cols,
}
}
pub fn get(&self, i: usize, j: usize) -> &T {
&self.columns[j][i]
}
pub fn get_col(&self, col: usize) -> &[T] {
self.columns[col]
}
pub fn get_col_slice(&self, col: usize, start_row: usize, end_row: usize) -> &[T] {
&self.columns[col][start_row..end_row]
}
pub fn is_valid(&self, row: usize, col: usize) -> bool {
if let Some(mask) = self.masks.as_ref().and_then(|m| m[col]) {
let byte_idx = row / 8;
let bit_idx = row % 8;
if byte_idx < mask.len() {
return (mask[byte_idx] >> bit_idx) & 1 != 0;
}
return false;
}
true
}
}
impl<'a, T> ColumnarMatrix<'a, T>
where
T: Copy,
{
pub fn get_row(&self, row: usize) -> Vec<T> {
self.columns.iter().map(|col| col[row]).collect()
}
}
#[derive(Debug, Serialize, Deserialize)]
pub struct RowMajorMatrix<T> {
pub data: Vec<T>,
pub rows: usize,
pub cols: usize,
stride1: usize,
stride2: usize,
}
impl<T> RowMajorMatrix<T> {
pub fn new(data: Vec<T>, rows: usize, cols: usize) -> Self {
RowMajorMatrix {
data,
rows,
cols,
stride1: 1,
stride2: cols,
}
}
pub fn get(&self, i: usize, j: usize) -> &T {
&self.data[self.item_index(i, j)]
}
fn item_index(&self, i: usize, j: usize) -> usize {
let mut idx = self.stride2 * i;
idx += j * self.stride1;
idx
}
pub fn append_row(&mut self, items: Vec<T>) {
assert!(items.len().is_multiple_of(self.cols));
let new_rows = items.len() / self.cols;
self.rows += new_rows;
self.data.extend(items);
}
}
impl<'a, T> fmt::Display for Matrix<'a, T>
where
T: FromStr + std::fmt::Display,
<T as FromStr>::Err: 'static + std::error::Error,
{
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let mut val = String::new();
for i in 0..self.rows {
for j in 0..self.cols {
val.push_str(self.get(i, j).to_string().as_str());
if j == (self.cols - 1) {
val.push('\n');
} else {
val.push(' ');
}
}
}
write!(f, "{}", val)
}
}
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct JaggedMatrix<T> {
pub data: Vec<T>,
pub ends: Vec<usize>,
pub cols: usize,
pub n_records: usize,
}
impl<T> JaggedMatrix<T>
where
T: Copy,
{
pub fn from_vecs(vecs: &[Vec<T>]) -> Self {
let mut data = Vec::new();
let mut ends = Vec::new();
let mut e = 0;
let mut n_records = 0;
for vec in vecs {
for v in vec {
data.push(*v);
}
e += vec.len();
ends.push(e);
n_records += e;
}
let cols = vecs.len();
JaggedMatrix {
data,
ends,
cols,
n_records,
}
}
}
impl<T> JaggedMatrix<T> {
pub fn new() -> Self {
JaggedMatrix {
data: Vec::new(),
ends: Vec::new(),
cols: 0,
n_records: 0,
}
}
pub fn get_col(&self, col: usize) -> &[T] {
assert!(col < self.ends.len());
let (i, j) = if col == 0 {
(0, self.ends[col])
} else {
(self.ends[col - 1], self.ends[col])
};
&self.data[i..j]
}
pub fn get_col_mut(&mut self, col: usize) -> &mut [T] {
assert!(col < self.ends.len());
let (i, j) = if col == 0 {
(0, self.ends[col])
} else {
(self.ends[col - 1], self.ends[col])
};
&mut self.data[i..j]
}
}
impl<T> Default for JaggedMatrix<T> {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_rowmatrix_get() {
let v = vec![1, 2, 3, 5, 6, 7];
let m = RowMajorMatrix::new(v, 2, 3);
println!("{:?}", m);
assert_eq!(m.get(0, 0), &1);
assert_eq!(m.get(1, 0), &5);
assert_eq!(m.get(0, 2), &3);
assert_eq!(m.get(1, 1), &6);
}
#[test]
fn test_rowmatrix_append() {
let v = vec![1, 2, 3, 5, 6, 7];
let mut m = RowMajorMatrix::new(v, 2, 3);
m.append_row(vec![-1, -2, -3]);
assert_eq!(m.get(2, 1), &-2);
}
#[test]
fn test_matrix_get() {
let v = vec![1, 2, 3, 5, 6, 7];
let m = Matrix::new(&v, 2, 3);
println!("{}", m);
assert_eq!(m.get(0, 0), &1);
assert_eq!(m.get(1, 0), &2);
}
#[test]
fn test_matrix_get_col_slice() {
let v = vec![1, 2, 3, 5, 6, 7];
let m = Matrix::new(&v, 3, 2);
assert_eq!(m.get_col_slice(0, 0, 3), &vec![1, 2, 3]);
assert_eq!(m.get_col_slice(1, 0, 2), &vec![5, 6]);
assert_eq!(m.get_col_slice(1, 1, 3), &vec![6, 7]);
assert_eq!(m.get_col_slice(0, 1, 2), &vec![2]);
}
#[test]
fn test_matrix_get_col() {
let v = vec![1, 2, 3, 5, 6, 7];
let m = Matrix::new(&v, 3, 2);
assert_eq!(m.get_col(1), &vec![5, 6, 7]);
}
#[test]
fn test_matrix_row() {
let v = vec![1, 2, 3, 5, 6, 7];
let m = Matrix::new(&v, 3, 2);
assert_eq!(m.get_row(2), vec![3, 7]);
assert_eq!(m.get_row(0), vec![1, 5]);
assert_eq!(m.get_row(1), vec![2, 6]);
}
#[test]
fn test_jaggedmatrix_get_col() {
let vecs = vec![vec![0], vec![5, 4, 3, 2], vec![4, 5]];
let jmatrix = JaggedMatrix::from_vecs(&vecs);
assert_eq!(jmatrix.get_col(1), vec![5, 4, 3, 2]);
assert_eq!(jmatrix.get_col(0), vec![0]);
assert_eq!(jmatrix.get_col(2), vec![4, 5]);
}
#[test]
fn test_columnar_matrix() {
let col0 = vec![1.0, 2.0, 3.0];
let col1 = vec![4.0, 5.0, 6.0];
let cm = ColumnarMatrix::new(vec![&col0, &col1], None, 3);
assert_eq!(cm.rows, 3);
assert_eq!(cm.cols, 2);
assert_eq!(*cm.get(0, 0), 1.0);
assert_eq!(*cm.get(2, 1), 6.0);
assert_eq!(cm.get_col(0), &[1.0, 2.0, 3.0]);
assert_eq!(cm.get_col_slice(1, 0, 2), &[4.0, 5.0]);
assert_eq!(cm.get_row(1), vec![2.0, 5.0]);
}
#[test]
fn test_columnar_matrix_is_valid_no_mask() {
let col0 = vec![1.0, 2.0];
let cm = ColumnarMatrix::new(vec![&col0], None, 2);
assert!(cm.is_valid(0, 0));
assert!(cm.is_valid(1, 0));
}
#[test]
fn test_columnar_matrix_is_valid_with_mask() {
let col0 = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
let mask: Vec<u8> = vec![0b10101010, 0b00000001]; let cm = ColumnarMatrix::new(vec![&col0], Some(vec![Some(&mask)]), 9);
assert!(!cm.is_valid(0, 0));
assert!(cm.is_valid(1, 0));
assert!(cm.is_valid(7, 0));
assert!(cm.is_valid(8, 0));
}
#[test]
fn test_columnar_matrix_is_valid_none_mask() {
let col0 = vec![1.0, 2.0];
let cm = ColumnarMatrix::new(vec![&col0], Some(vec![None]), 2);
assert!(cm.is_valid(0, 0));
assert!(cm.is_valid(1, 0));
}
#[test]
fn test_float_data_trait() {
assert_eq!(f64::ZERO, 0.0);
assert_eq!(f64::ONE, 1.0);
assert_eq!(f64::from_usize(5), 5.0);
assert_eq!(f64::from_u16(10), 10.0);
assert!(f64::NAN.is_nan());
assert!((f64::ONE.ln() - 0.0_f64).abs() < 1e-10);
assert!((f64::ONE.exp() - std::f64::consts::E).abs() < 1e-10);
assert_eq!(f32::ZERO, 0.0_f32);
assert_eq!(f32::from_usize(5), 5.0_f32);
assert_eq!(f32::from_u16(10), 10.0_f32);
assert!(f32::NAN.is_nan());
}
#[test]
fn test_jagged_matrix_new_and_default() {
let jm: JaggedMatrix<f64> = JaggedMatrix::new();
assert_eq!(jm.cols, 0);
assert_eq!(jm.n_records, 0);
let jm2: JaggedMatrix<f64> = JaggedMatrix::default();
assert_eq!(jm2.cols, 0);
}
#[test]
fn test_jagged_matrix_get_col_mut() {
let vecs = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
let mut jm = JaggedMatrix::from_vecs(&vecs);
let col = jm.get_col_mut(0);
col[0] = 99.0;
assert_eq!(jm.data[0], 99.0);
let col1 = jm.get_col_mut(1);
col1[1] = 88.0;
assert_eq!(jm.data[3], 88.0);
}
#[test]
fn test_matrix_display() {
let v = vec![1, 2, 3, 4];
let m = Matrix::new(&v, 2, 2);
let s = format!("{}", m);
assert!(s.contains("1"));
assert!(s.contains("4"));
}
}