use std::ops::{Add, BitAnd, BitOr, BitXor, Div, Index, IndexMut, Mul, Not, Sub};
/// A dense 2-D matrix stored in column-major order.
///
/// Element `(r, c)` lives at `data[c * rows + r]`, so every column is a
/// contiguous slice of `data`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Matrix<T> {
    /// Number of rows.
    rows: usize,
    /// Number of columns.
    cols: usize,
    /// Elements laid out column after column.
    data: Vec<T>,
}
impl<T: Clone> Matrix<T> {
    /// Builds a matrix from a list of columns (each inner `Vec` is one column).
    ///
    /// # Panics
    ///
    /// Panics if `cols_data` is empty, if the first column is empty, or if any
    /// column's length differs from the first column's.
    pub fn from_cols(cols_data: Vec<Vec<T>>) -> Self {
        let cols = cols_data.len();
        assert!(cols > 0, "need at least one column");
        let rows = cols_data.first().map_or(0, |c| c.len());
        assert!(
            rows > 0 || cols == 0,
            "need at least one row if columns exist"
        );
        for (i, col) in cols_data.iter().enumerate() {
            assert!(
                col.len() == rows,
                "col {} has len {}, expected {}",
                i,
                col.len(),
                rows
            );
        }
        // Columns are already in storage order, so concatenating them yields
        // the column-major buffer directly.
        let data = cols_data.into_iter().flatten().collect();
        Matrix { rows, cols, data }
    }

    /// Wraps a column-major buffer (`data[c * rows + r]` is element `(r, c)`).
    ///
    /// A 0×0 matrix is accepted; a shape with rows but no columns (or columns
    /// but no rows) is rejected.
    ///
    /// # Panics
    ///
    /// Panics if exactly one of `rows`/`cols` is zero, or if `data.len()` does
    /// not equal `rows * cols`.
    pub fn from_vec(data: Vec<T>, rows: usize, cols: usize) -> Self {
        assert!(
            rows > 0 || cols == 0,
            "need at least one row if columns exist"
        );
        assert!(
            cols > 0 || rows == 0,
            "need at least one column if rows exist"
        );
        if rows * cols != 0 {
            assert_eq!(
                data.len(),
                rows * cols,
                "data length mismatch: expected {}, got {}",
                rows * cols,
                data.len()
            );
        } else {
            assert!(data.is_empty(), "data must be empty for 0-sized matrix");
        }
        Matrix { rows, cols, data }
    }

    /// Builds a matrix from a row-major buffer (`data[r * cols + c]` is
    /// element `(r, c)`), converting it to the internal column-major layout.
    ///
    /// # Panics
    ///
    /// Panics if `data.len() != rows * cols`, and under the same shape rules
    /// as [`Matrix::from_vec`].
    pub fn from_rows_vec(data: Vec<T>, rows: usize, cols: usize) -> Self {
        // Validate up front. Without this check a short buffer panicked with an
        // opaque slice-index error and a long buffer was silently truncated.
        assert_eq!(
            data.len(),
            rows * cols,
            "data length mismatch: expected {}, got {}",
            rows * cols,
            data.len()
        );
        let mut new_vec = Vec::with_capacity(rows * cols);
        for c in 0..cols {
            new_vec.extend((0..rows).map(|r| data[r * cols + c].clone()));
        }
        Matrix::from_vec(new_vec, rows, cols)
    }

    /// Returns the underlying column-major buffer.
    pub fn data(&self) -> &[T] {
        &self.data
    }

    /// Returns the underlying column-major buffer mutably.
    pub fn data_mut(&mut self) -> &mut [T] {
        &mut self.data
    }

    /// Consumes the matrix and returns its column-major buffer.
    pub fn into_vec(self) -> Vec<T> {
        self.data
    }

    /// Returns a copy of the column-major buffer.
    pub fn to_vec(&self) -> Vec<T> {
        self.data.clone()
    }

    /// Number of rows.
    pub fn rows(&self) -> usize {
        self.rows
    }

    /// Number of columns.
    pub fn cols(&self) -> usize {
        self.cols
    }

    /// `(rows, cols)`.
    pub fn shape(&self) -> (usize, usize) {
        (self.rows, self.cols)
    }

    /// Returns element `(r, c)`; panics if out of bounds.
    pub fn get(&self, r: usize, c: usize) -> &T {
        &self[(r, c)]
    }

    /// Returns element `(r, c)` mutably; panics if out of bounds.
    pub fn get_mut(&mut self, r: usize, c: usize) -> &mut T {
        &mut self[(r, c)]
    }

    /// Returns column `c` as a contiguous slice.
    ///
    /// # Panics
    ///
    /// Panics if `c >= cols`.
    #[inline]
    pub fn column(&self, c: usize) -> &[T] {
        assert!(
            c < self.cols,
            "column index {} out of bounds for {} columns",
            c,
            self.cols
        );
        let start = c * self.rows;
        &self.data[start..start + self.rows]
    }

    /// Returns column `c` as a contiguous mutable slice.
    ///
    /// # Panics
    ///
    /// Panics if `c >= cols`.
    #[inline]
    pub fn column_mut(&mut self, c: usize) -> &mut [T] {
        assert!(
            c < self.cols,
            "column index {} out of bounds for {} columns",
            c,
            self.cols
        );
        let start = c * self.rows;
        &mut self.data[start..start + self.rows]
    }

    /// Iterates over the columns, left to right, as slices.
    pub fn iter_columns(&self) -> impl Iterator<Item = &[T]> {
        (0..self.cols).map(move |c| self.column(c))
    }

    /// Iterates over the rows, top to bottom, as borrowed row views.
    pub fn iter_rows(&self) -> impl Iterator<Item = MatrixRow<'_, T>> {
        (0..self.rows).map(move |r| MatrixRow {
            matrix: self,
            row: r,
        })
    }

    /// Swaps columns `c1` and `c2` in place.
    ///
    /// # Panics
    ///
    /// Panics if either index is `>= cols`.
    pub fn swap_columns(&mut self, c1: usize, c2: usize) {
        assert!(
            c1 < self.cols,
            "column index c1={} out of bounds for {} columns",
            c1,
            self.cols
        );
        assert!(
            c2 < self.cols,
            "column index c2={} out of bounds for {} columns",
            c2,
            self.cols
        );
        if c1 == c2 || self.rows == 0 || self.cols == 0 {
            return;
        }
        // Distinct columns occupy disjoint ranges of `data`, so split the
        // buffer at the later column and swap the two slices wholesale.
        let rows = self.rows;
        let (lo, hi) = if c1 < c2 { (c1, c2) } else { (c2, c1) };
        let (head, tail) = self.data.split_at_mut(hi * rows);
        head[lo * rows..(lo + 1) * rows].swap_with_slice(&mut tail[..rows]);
    }

    /// Removes column `col`. The row count is unchanged, so deleting the last
    /// remaining column leaves a `rows × 0` matrix.
    ///
    /// # Panics
    ///
    /// Panics if `col >= cols`.
    pub fn delete_column(&mut self, col: usize) {
        assert!(
            col < self.cols,
            "column index {} out of bounds for {} columns",
            col,
            self.cols
        );
        let start = col * self.rows;
        self.data.drain(start..start + self.rows);
        self.cols -= 1;
    }

    /// Returns a copy of row `r`, ordered by column.
    ///
    /// # Panics
    ///
    /// Panics if `r >= rows`.
    #[inline]
    pub fn row(&self, r: usize) -> Vec<T> {
        assert!(
            r < self.rows,
            "row index {} out of bounds for {} rows",
            r,
            self.rows
        );
        (0..self.cols).map(|c| self[(r, c)].clone()).collect()
    }

    /// Overwrites row `r` with clones of `values` (one per column).
    ///
    /// # Panics
    ///
    /// Panics if `r >= rows` or if `values.len() != cols`.
    pub fn row_copy_from_slice(&mut self, r: usize, values: &[T]) {
        assert!(
            r < self.rows,
            "row index {} out of bounds for {} rows",
            r,
            self.rows
        );
        assert!(
            values.len() == self.cols,
            "input slice length {} does not match number of columns {}",
            values.len(),
            self.cols
        );
        for (c, value) in values.iter().enumerate() {
            let idx = r + c * self.rows;
            self.data[idx] = value.clone();
        }
    }

    /// Removes row `row`. The column count is unchanged, so deleting the last
    /// remaining row leaves a `0 × cols` matrix.
    ///
    /// # Panics
    ///
    /// Panics if `row >= rows`.
    pub fn delete_row(&mut self, row: usize) {
        assert!(
            row < self.rows,
            "row index {} out of bounds for {} rows",
            row,
            self.rows
        );
        // The assert guarantees `rows >= 1`, so the modulus below is safe.
        // `Vec::retain` visits elements once, in order; in column-major layout
        // position `i` belongs to row `i % rows`. Survivors are kept in place
        // instead of being cloned into a new buffer.
        let old_rows = self.rows;
        let mut pos = 0usize;
        self.data.retain(|_| {
            let keep = pos % old_rows != row;
            pos += 1;
            keep
        });
        self.rows = old_rows - 1;
    }

    /// Returns the transpose (`cols × rows`) of this matrix.
    ///
    /// Never panics. Degenerate shapes such as `0 × n` or `n × 0` transpose
    /// to `n × 0` / `0 × n`.
    pub fn transpose(&self) -> Matrix<T> {
        let (m, n) = (self.rows, self.cols);
        let mut data = Vec::with_capacity(m * n);
        // Column `r` of the result holds row `r` of `self`.
        for r in 0..m {
            for c in 0..n {
                data.push(self.data[c * m + r].clone());
            }
        }
        // Build the result directly rather than through `from_vec`. Shapes like
        // 0×3 (after deleting every row) and 2×0 (after deleting every column)
        // are reachable, and `from_vec` rejects their transposes.
        Matrix {
            rows: n,
            cols: m,
            data,
        }
    }
}
impl<T: Clone> Matrix<T> {
    /// Inserts `column` so that it becomes column `index`. Later columns shift
    /// right.
    ///
    /// # Panics
    ///
    /// Panics if `index > cols` or if `column.len() != rows`.
    pub fn add_column(&mut self, index: usize, column: Vec<T>) {
        assert!(
            index <= self.cols,
            "add_column index {} out of bounds for {} columns",
            index,
            self.cols
        );
        assert_eq!(
            column.len(),
            self.rows,
            "column length mismatch: expected {}, got {}",
            self.rows,
            column.len()
        );
        // Columns are contiguous, so insertion is a single splice. This also
        // covers the 0×0 case: `index` must then be 0 and `column` empty
        // (both enforced above), which yields a 0×1 matrix.
        let insert_pos = index * self.rows;
        self.data.splice(insert_pos..insert_pos, column);
        self.cols += 1;
    }

    /// Inserts `row` so that it becomes row `index`. Later rows shift down.
    ///
    /// Adding the (necessarily empty) row to a 0×0 matrix is a no-op, so the
    /// matrix stays 0×0.
    ///
    /// # Panics
    ///
    /// Panics if `index > rows` or if `row.len() != cols`.
    pub fn add_row(&mut self, index: usize, row: Vec<T>) {
        assert!(
            index <= self.rows,
            "add_row index {} out of bounds for {} rows",
            index,
            self.rows
        );
        assert_eq!(
            row.len(),
            self.cols,
            "row length mismatch: expected {} (cols), got {}",
            self.cols,
            row.len()
        );
        if self.cols == 0 && self.rows == 0 {
            return;
        }
        let old_rows = self.rows;
        // Move the existing elements rather than cloning them. For each column,
        // emit the `index` entries above the insertion point, then the new
        // value, then the remaining `old_rows - index` entries.
        let mut old = std::mem::take(&mut self.data).into_iter();
        let mut new_data = Vec::with_capacity((old_rows + 1) * self.cols);
        for value in row {
            new_data.extend(old.by_ref().take(index));
            new_data.push(value);
            new_data.extend(old.by_ref().take(old_rows - index));
        }
        self.data = new_data;
        self.rows = old_rows + 1;
    }

    /// Returns an `n × cols` matrix whose every row is a copy of row 0.
    ///
    /// # Panics
    ///
    /// Panics if this matrix has no rows, or (via [`Matrix::from_vec`]) if
    /// `n == 0` while `cols > 0`.
    pub fn repeat_rows(&self, n: usize) -> Matrix<T>
    where
        T: Clone,
    {
        let first_row = self.row(0);
        let mut data = Vec::with_capacity(n * self.cols());
        // Column-major: each column of the result is `n` copies of one value.
        for value in &first_row {
            data.extend(std::iter::repeat(value).take(n).cloned());
        }
        Matrix::from_vec(data, n, self.cols)
    }

    /// Returns a `rows × cols` matrix with every element a clone of `value`.
    pub fn filled(rows: usize, cols: usize, value: T) -> Self {
        Matrix {
            rows,
            cols,
            data: vec![value; rows * cols],
        }
    }

    /// Repeats this single-row matrix `target_rows` times to produce a
    /// `target_rows × target_cols` matrix.
    ///
    /// # Panics
    ///
    /// Panics if this matrix does not have exactly one row, if its column
    /// count differs from `target_cols`, or under the shape rules of
    /// [`Matrix::from_vec`].
    pub fn broadcast_row_to_target_shape(
        &self,
        target_rows: usize,
        target_cols: usize,
    ) -> Matrix<T> {
        assert_eq!(
            self.rows(),
            1,
            "broadcast_row_to_target_shape can only be called on a 1-row matrix."
        );
        assert_eq!(
            self.cols(),
            target_cols,
            "Column count mismatch for broadcasting: source has {} columns, target has {} columns.",
            self.cols(),
            target_cols
        );
        // With a single row, the column-major buffer *is* that row. Emitting
        // `target_rows` copies of each entry builds the column-major result
        // directly, with no row-major intermediate and no second clone pass.
        let mut data = Vec::with_capacity(target_rows * target_cols);
        for value in &self.data {
            data.extend(std::iter::repeat(value).take(target_rows).cloned());
        }
        Matrix::from_vec(data, target_rows, target_cols)
    }
}
impl Matrix<f64> {
pub fn zeros(rows: usize, cols: usize) -> Self {
Matrix::filled(rows, cols, 0.0)
}
pub fn ones(rows: usize, cols: usize) -> Self {
Matrix::filled(rows, cols, 1.0)
}
pub fn nan(rows: usize, cols: usize) -> Matrix<f64> {
Matrix::filled(rows, cols, f64::NAN)
}
}
impl<T> Index<(usize, usize)> for Matrix<T> {
    type Output = T;

    /// Returns a reference to element `(r, c)` (row, column).
    ///
    /// # Panics
    ///
    /// Panics if `r >= rows` or `c >= cols`.
    #[inline]
    fn index(&self, (r, c): (usize, usize)) -> &T {
        let in_bounds = r < self.rows && c < self.cols;
        assert!(
            in_bounds,
            "index out of bounds: ({}, {}) vs {}x{}",
            r, c, self.rows, self.cols
        );
        // Column-major: each column occupies `rows` consecutive slots.
        let offset = c * self.rows + r;
        &self.data[offset]
    }
}
impl<T> IndexMut<(usize, usize)> for Matrix<T> {
    /// Returns a mutable reference to element `(r, c)` (row, column).
    ///
    /// # Panics
    ///
    /// Panics if `r >= rows` or `c >= cols`.
    #[inline]
    fn index_mut(&mut self, (r, c): (usize, usize)) -> &mut T {
        let in_bounds = r < self.rows && c < self.cols;
        assert!(
            in_bounds,
            "index out of bounds: ({}, {}) vs {}x{}",
            r, c, self.rows, self.cols
        );
        // Column-major: each column occupies `rows` consecutive slots.
        let offset = c * self.rows + r;
        &mut self.data[offset]
    }
}
/// A borrowed view of one row of a [`Matrix`], as yielded by
/// [`Matrix::iter_rows`].
pub struct MatrixRow<'a, T> {
    /// The matrix this row belongs to.
    matrix: &'a Matrix<T>,
    /// Index of the viewed row.
    row: usize,
}
impl<'a, T> MatrixRow<'a, T> {
    /// Returns the element of this row in column `c`.
    ///
    /// # Panics
    ///
    /// Panics if `c` is not a valid column index of the underlying matrix.
    pub fn get(&self, c: usize) -> &T {
        let (source, r) = (self.matrix, self.row);
        &source[(r, c)]
    }

    /// Iterates over this row's elements from the first column to the last.
    pub fn iter(&self) -> impl Iterator<Item = &T> {
        let width = self.matrix.cols;
        (0..width).map(move |c| self.get(c))
    }
}
/// Names one of a matrix's two axes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Axis {
    /// The column axis.
    Col,
    /// The row axis.
    Row,
}
/// A right-hand operand that can be expanded into one value per element of
/// a `rows × cols` matrix, for element-wise operations.
pub trait Broadcastable<T> {
    /// Produces the expanded operand as a flat buffer of `rows * cols` values.
    fn to_vec(&self, rows: usize, cols: usize) -> Vec<T>;
}

/// Any cloneable value broadcasts by repeating itself once per element.
impl<T: Clone> Broadcastable<T> for T {
    fn to_vec(&self, rows: usize, cols: usize) -> Vec<T> {
        let count = rows * cols;
        std::iter::repeat(self).take(count).cloned().collect()
    }
}
/// A matrix broadcasts only to its own shape; its storage is copied as-is.
impl<T: Clone> Broadcastable<T> for Matrix<T> {
    fn to_vec(&self, rows: usize, cols: usize) -> Vec<T> {
        assert_eq!(self.rows, rows, "row count mismatch in broadcast");
        assert_eq!(self.cols, cols, "col count mismatch in broadcast");
        self.data.iter().cloned().collect()
    }
}
// Generates inherent `Matrix<T>` methods that compare every element against a
// broadcast right-hand side and collect the results into a `BoolMatrix`.
macro_rules! impl_elementwise_cmp {
    ($($method:ident => $op:tt),* $(,)?) => {
        impl<T: PartialOrd + Clone> Matrix<T> {
            $(
                #[doc = concat!("Element-wise comparison `self ", stringify!($op), " rhs`,\n\
                                 where `rhs` may be a `Matrix<T>` or a scalar T.\n\
                                 Returns a `BoolMatrix`.")]
                pub fn $method<Rhs>(&self, rhs: Rhs) -> BoolMatrix
                where
                    Rhs: Broadcastable<T>,
                {
                    // Expand `rhs` to one value per element of `self` (a matrix
                    // rhs asserts matching shape), then compare pairwise.
                    let expanded = rhs.to_vec(self.rows, self.cols);
                    let mut flags = Vec::with_capacity(self.data.len());
                    for (lhs_val, rhs_val) in self.data.iter().zip(expanded.iter()) {
                        flags.push(lhs_val $op rhs_val);
                    }
                    Matrix::<bool>::from_vec(flags, self.rows, self.cols)
                }
            )*
        }
    };
}
impl_elementwise_cmp! {
    eq_elem => ==,
    ne_elem => !=,
    lt_elem => <,
    le_elem => <=,
    gt_elem => >,
    ge_elem => >=,
}
/// Panics unless `lhs` and `rhs` have the same number of rows and columns.
///
/// Used as the shape precondition of the element-wise operator impls.
fn check_matrix_dims_for_ops<T>(lhs: &Matrix<T>, rhs: &Matrix<T>) {
    let (left_rows, left_cols) = (lhs.rows, lhs.cols);
    let (right_rows, right_cols) = (rhs.rows, rhs.cols);
    if left_rows != right_rows {
        panic!(
            "Row count mismatch: left has {} rows, right has {} rows",
            left_rows, right_rows
        );
    }
    if left_cols != right_cols {
        panic!(
            "Column count mismatch: left has {} columns, right has {} columns",
            left_cols, right_cols
        );
    }
}
// Implements a binary arithmetic operator element-wise between two matrices of
// identical shape, for all four owned/borrowed operand combinations. Owned
// operands are updated in place and returned; only `&M op &M` allocates a new
// buffer. Every variant panics via `check_matrix_dims_for_ops` on a shape
// mismatch.
macro_rules! impl_elementwise_op_matrix_all {
    ($OpTrait:ident, $method:ident, $op:tt) => {
        // &Matrix op &Matrix -> fresh matrix.
        impl<'a, 'b, T> $OpTrait<&'b Matrix<T>> for &'a Matrix<T>
        where
            T: Clone + $OpTrait<Output = T>,
        {
            type Output = Matrix<T>;
            fn $method(self, rhs: &'b Matrix<T>) -> Matrix<T> {
                check_matrix_dims_for_ops(self, rhs);
                let mut data = Vec::with_capacity(self.data.len());
                for (lhs_val, rhs_val) in self.data.iter().zip(rhs.data.iter()) {
                    data.push(lhs_val.clone() $op rhs_val.clone());
                }
                Matrix { rows: self.rows, cols: self.cols, data }
            }
        }

        // Matrix op &Matrix -> reuse the left operand's buffer.
        impl<'b, T> $OpTrait<&'b Matrix<T>> for Matrix<T>
        where
            T: Clone + $OpTrait<Output = T>,
        {
            type Output = Matrix<T>;
            fn $method(mut self, rhs: &'b Matrix<T>) -> Matrix<T> {
                check_matrix_dims_for_ops(&self, rhs);
                for (slot, rhs_val) in self.data.iter_mut().zip(rhs.data.iter()) {
                    let rhs_owned = rhs_val.clone();
                    *slot = slot.clone() $op rhs_owned;
                }
                self
            }
        }

        // &Matrix op Matrix -> reuse the right operand's buffer.
        impl<'a, T> $OpTrait<Matrix<T>> for &'a Matrix<T>
        where
            T: Clone + $OpTrait<Output = T>,
        {
            type Output = Matrix<T>;
            fn $method(self, mut rhs: Matrix<T>) -> Matrix<T> {
                check_matrix_dims_for_ops(self, &rhs);
                for (lhs_val, slot) in self.data.iter().zip(rhs.data.iter_mut()) {
                    let lhs_owned = lhs_val.clone();
                    *slot = lhs_owned $op slot.clone();
                }
                rhs
            }
        }

        // Matrix op Matrix -> reuse the left buffer, consume the right one.
        impl<T> $OpTrait<Matrix<T>> for Matrix<T>
        where
            T: Clone + $OpTrait<Output = T>,
        {
            type Output = Matrix<T>;
            fn $method(mut self, rhs: Matrix<T>) -> Matrix<T> {
                check_matrix_dims_for_ops(&self, &rhs);
                for (slot, rhs_val) in self.data.iter_mut().zip(rhs.data) {
                    *slot = slot.clone() $op rhs_val;
                }
                self
            }
        }
    };
}
// Implements a binary arithmetic operator between a matrix and a scalar of its
// element type, computing `element op scalar` for every element. A borrowed
// matrix yields a new matrix; an owned one is updated in place and returned.
macro_rules! impl_elementwise_op_scalar_all {
    ($OpTrait:ident, $method:ident, $op:tt) => {
        impl<'a, T> $OpTrait<T> for &'a Matrix<T>
        where
            T: Clone + $OpTrait<Output = T>,
        {
            type Output = Matrix<T>;
            fn $method(self, rhs: T) -> Matrix<T> {
                let data = self
                    .data
                    .iter()
                    .map(|elem| elem.clone() $op rhs.clone())
                    .collect();
                Matrix { rows: self.rows, cols: self.cols, data }
            }
        }

        impl<T> $OpTrait<T> for Matrix<T>
        where
            T: Clone + $OpTrait<Output = T>,
        {
            type Output = Matrix<T>;
            fn $method(mut self, rhs: T) -> Matrix<T> {
                self.data
                    .iter_mut()
                    .for_each(|slot| *slot = slot.clone() $op rhs.clone());
                self
            }
        }
    };
}
// Element-wise arithmetic, grouped per operator: matrix-matrix then
// matrix-scalar.
impl_elementwise_op_matrix_all!(Add, add, +);
impl_elementwise_op_scalar_all!(Add, add, +);
impl_elementwise_op_matrix_all!(Sub, sub, -);
impl_elementwise_op_scalar_all!(Sub, sub, -);
impl_elementwise_op_matrix_all!(Mul, mul, *);
impl_elementwise_op_scalar_all!(Mul, mul, *);
impl_elementwise_op_matrix_all!(Div, div, /);
impl_elementwise_op_scalar_all!(Div, div, /);
// Implements a bitwise operator element-wise between two boolean matrices of
// identical shape, for all four owned/borrowed operand combinations. Owned
// operands are updated in place; shape mismatches panic via
// `check_matrix_dims_for_ops`.
macro_rules! impl_bitwise_op_all {
    ($OpTrait:ident, $method:ident, $op:tt) => {
        impl<'a, 'b> $OpTrait<&'b Matrix<bool>> for &'a Matrix<bool> {
            type Output = Matrix<bool>;
            fn $method(self, rhs: &'b Matrix<bool>) -> Matrix<bool> {
                check_matrix_dims_for_ops(self, rhs);
                let mut data = Vec::with_capacity(self.data.len());
                for (&l, &r) in self.data.iter().zip(rhs.data.iter()) {
                    data.push(l $op r);
                }
                Matrix { rows: self.rows, cols: self.cols, data }
            }
        }

        impl<'b> $OpTrait<&'b Matrix<bool>> for Matrix<bool> {
            type Output = Matrix<bool>;
            fn $method(mut self, rhs: &'b Matrix<bool>) -> Matrix<bool> {
                check_matrix_dims_for_ops(&self, rhs);
                for (slot, &r) in self.data.iter_mut().zip(rhs.data.iter()) {
                    *slot = *slot $op r;
                }
                self
            }
        }

        impl<'a> $OpTrait<Matrix<bool>> for &'a Matrix<bool> {
            type Output = Matrix<bool>;
            fn $method(self, mut rhs: Matrix<bool>) -> Matrix<bool> {
                check_matrix_dims_for_ops(self, &rhs);
                for (&l, slot) in self.data.iter().zip(rhs.data.iter_mut()) {
                    *slot = l $op *slot;
                }
                rhs
            }
        }

        impl $OpTrait<Matrix<bool>> for Matrix<bool> {
            type Output = Matrix<bool>;
            fn $method(mut self, rhs: Matrix<bool>) -> Matrix<bool> {
                check_matrix_dims_for_ops(&self, &rhs);
                for (slot, r) in self.data.iter_mut().zip(rhs.data) {
                    *slot = *slot $op r;
                }
                self
            }
        }
    };
}
impl_bitwise_op_all!(BitAnd, bitand, &);
impl_bitwise_op_all!(BitOr, bitor, |);
impl_bitwise_op_all!(BitXor, bitxor, ^);
impl Not for Matrix<bool> {
    type Output = Matrix<bool>;

    /// Negates every element in place, reusing this matrix's buffer.
    fn not(mut self) -> Matrix<bool> {
        self.data.iter_mut().for_each(|flag| *flag = !*flag);
        self
    }
}
impl Not for &Matrix<bool> {
    type Output = Matrix<bool>;

    /// Returns a new matrix holding the negation of every element; the
    /// borrowed matrix is left untouched.
    fn not(self) -> Matrix<bool> {
        let mut data = Vec::with_capacity(self.data.len());
        for flag in &self.data {
            data.push(!*flag);
        }
        Matrix {
            rows: self.rows,
            cols: self.cols,
            data,
        }
    }
}
/// Matrix of `f64` values.
pub type FloatMatrix = Matrix<f64>;
/// Matrix of `bool` values, as returned by the element-wise comparisons.
pub type BoolMatrix = Matrix<bool>;
/// Matrix of `i32` values.
pub type IntMatrix = Matrix<i32>;
/// Matrix of `String` values.
pub type StringMatrix = Matrix<String>;
#[cfg(test)]
mod tests {
use crate::matrix::BoolOps;
use super::*;
fn make_f64_matrix(a: f64, b: f64, c: f64, d: f64) -> FloatMatrix {
Matrix::from_cols(vec![vec![a, c], vec![b, d]])
}
fn make_bool_matrix(a: bool, b: bool, c: bool, d: bool) -> BoolMatrix {
Matrix::from_cols(vec![vec![a, c], vec![b, d]])
}
#[test]
fn test_add_f64() {
let m1 = make_f64_matrix(1.0, 2.0, 3.0, 4.0);
let m2 = make_f64_matrix(5.0, 6.0, 7.0, 8.0);
let expected = make_f64_matrix(6.0, 8.0, 10.0, 12.0);
assert_eq!(m1.clone() + m2.clone(), expected, "M + M");
assert_eq!(m1.clone() + &m2, expected, "M + &M");
assert_eq!(&m1 + m2.clone(), expected, "&M + M");
assert_eq!(&m1 + &m2, expected, "&M + &M");
}
#[test]
fn test_add_scalar_f64() {
let m1 = make_f64_matrix(1.0, 2.0, 3.0, 4.0);
let scalar = 10.0;
let expected = make_f64_matrix(11.0, 12.0, 13.0, 14.0);
assert_eq!(m1.clone() + scalar, expected, "M + S");
assert_eq!(&m1 + scalar, expected, "&M + S");
}
#[test]
fn test_sub_f64() {
let m1 = make_f64_matrix(10.0, 20.0, 30.0, 40.0);
let m2 = make_f64_matrix(1.0, 2.0, 3.0, 4.0);
let expected = make_f64_matrix(9.0, 18.0, 27.0, 36.0);
assert_eq!(m1.clone() - m2.clone(), expected, "M - M");
assert_eq!(m1.clone() - &m2, expected, "M - &M");
assert_eq!(&m1 - m2.clone(), expected, "&M - M");
assert_eq!(&m1 - &m2, expected, "&M - &M");
}
#[test]
fn test_sub_scalar_f64() {
let m1 = make_f64_matrix(11.0, 12.0, 13.0, 14.0);
let scalar = 10.0;
let expected = make_f64_matrix(1.0, 2.0, 3.0, 4.0);
assert_eq!(m1.clone() - scalar, expected, "M - S");
assert_eq!(&m1 - scalar, expected, "&M - S");
}
#[test]
fn test_mul_f64() {
let m1 = make_f64_matrix(1.0, 2.0, 3.0, 4.0);
let m2 = make_f64_matrix(5.0, 6.0, 7.0, 8.0);
let expected = make_f64_matrix(5.0, 12.0, 21.0, 32.0);
assert_eq!(m1.clone() * m2.clone(), expected, "M * M");
assert_eq!(m1.clone() * &m2, expected, "M * &M");
assert_eq!(&m1 * m2.clone(), expected, "&M * M");
assert_eq!(&m1 * &m2, expected, "&M * &M");
}
#[test]
fn test_mul_scalar_f64() {
let m1 = make_f64_matrix(1.0, 2.0, 3.0, 4.0);
let scalar = 3.0;
let expected = make_f64_matrix(3.0, 6.0, 9.0, 12.0);
assert_eq!(m1.clone() * scalar, expected, "M * S");
assert_eq!(&m1 * scalar, expected, "&M * S");
}
#[test]
fn test_div_f64() {
let m1 = make_f64_matrix(10.0, 20.0, 30.0, 40.0);
let m2 = make_f64_matrix(2.0, 5.0, 6.0, 8.0);
let expected = make_f64_matrix(5.0, 4.0, 5.0, 5.0);
assert_eq!(m1.clone() / m2.clone(), expected, "M / M");
assert_eq!(m1.clone() / &m2, expected, "M / &M");
assert_eq!(&m1 / m2.clone(), expected, "&M / M");
assert_eq!(&m1 / &m2, expected, "&M / &M");
}
#[test]
fn test_div_scalar_f64() {
let m1 = make_f64_matrix(10.0, 20.0, 30.0, 40.0);
let scalar = 10.0;
let expected = make_f64_matrix(1.0, 2.0, 3.0, 4.0);
assert_eq!(m1.clone() / scalar, expected, "M / S");
assert_eq!(&m1 / scalar, expected, "&M / S");
}
#[test]
fn test_chained_ops_f64() {
let m = make_f64_matrix(1.0, 2.0, 3.0, 4.0);
let result = (((m.clone() + 1.0) * 2.0) - 4.0) / 2.0;
let expected = make_f64_matrix(0.0, 1.0, 2.0, 3.0);
assert_eq!(result, expected);
}
#[test]
fn test_bitand_bool() {
let m1 = make_bool_matrix(true, false, true, false);
let m2 = make_bool_matrix(true, true, false, false);
let expected = make_bool_matrix(true, false, false, false);
assert_eq!(m1.clone() & m2.clone(), expected, "M & M");
assert_eq!(m1.clone() & &m2, expected, "M & &M");
assert_eq!(&m1 & m2.clone(), expected, "&M & M");
assert_eq!(&m1 & &m2, expected, "&M & &M");
}
#[test]
fn test_bitor_bool() {
let m1 = make_bool_matrix(true, false, true, false);
let m2 = make_bool_matrix(true, true, false, false);
let expected = make_bool_matrix(true, true, true, false);
assert_eq!(m1.clone() | m2.clone(), expected, "M | M");
assert_eq!(m1.clone() | &m2, expected, "M | &M");
assert_eq!(&m1 | m2.clone(), expected, "&M | M");
assert_eq!(&m1 | &m2, expected, "&M | &M");
}
#[test]
fn test_bitxor_bool() {
let m1 = make_bool_matrix(true, false, true, false);
let m2 = make_bool_matrix(true, true, false, false);
let expected = make_bool_matrix(false, true, true, false);
assert_eq!(m1.clone() ^ m2.clone(), expected, "M ^ M");
assert_eq!(m1.clone() ^ &m2, expected, "M ^ &M");
assert_eq!(&m1 ^ m2.clone(), expected, "&M ^ M");
assert_eq!(&m1 ^ &m2, expected, "&M ^ &M");
}
#[test]
fn test_not_bool() {
let m = make_bool_matrix(true, false, true, false);
let expected = make_bool_matrix(false, true, false, true);
assert_eq!(!m.clone(), expected, "!M (consuming)");
assert_eq!(!&m, expected, "!&M (borrowing)");
let original = make_bool_matrix(true, false, true, false);
let _negated_ref = !&original;
assert_eq!(original, make_bool_matrix(true, false, true, false));
}
#[test]
fn test_comparison_eq_elem() {
let m1 = make_f64_matrix(1.0, 2.0, 3.0, 4.0);
let m2 = make_f64_matrix(1.0, 0.0, 3.0, 5.0);
let s = 3.0;
let expected_m = make_bool_matrix(true, false, true, false);
let expected_s = make_bool_matrix(false, false, true, false);
assert_eq!(m1.eq_elem(m2), expected_m, "eq_elem matrix");
assert_eq!(m1.eq_elem(s), expected_s, "eq_elem scalar");
}
#[test]
fn test_comparison_gt_elem() {
let m1 = make_f64_matrix(1.0, 2.0, 3.0, 4.0);
let m2 = make_f64_matrix(0.0, 3.0, 3.0, 5.0);
let s = 2.5;
let expected_m = make_bool_matrix(true, false, false, false);
let expected_s = make_bool_matrix(false, false, true, true);
assert_eq!(m1.gt_elem(m2), expected_m, "gt_elem matrix");
assert_eq!(m1.gt_elem(s), expected_s, "gt_elem scalar");
}
#[test]
fn test_indexing() {
let m = make_f64_matrix(1.0, 2.0, 3.0, 4.0);
assert_eq!(m[(0, 0)], 1.0);
assert_eq!(m[(0, 1)], 2.0);
assert_eq!(m[(1, 0)], 3.0);
assert_eq!(m[(1, 1)], 4.0);
assert_eq!(*m.get(1, 0), 3.0); }
#[test]
#[should_panic]
fn test_index_out_of_bounds_row() {
let m = make_f64_matrix(1.0, 2.0, 3.0, 4.0);
let _ = m[(2, 0)]; }
#[test]
#[should_panic]
fn test_index_out_of_bounds_col() {
let m = make_f64_matrix(1.0, 2.0, 3.0, 4.0);
let _ = m[(0, 2)]; }
#[test]
fn test_dimensions() {
let m = make_f64_matrix(1.0, 2.0, 3.0, 4.0);
assert_eq!(m.rows(), 2);
assert_eq!(m.cols(), 2);
}
#[test]
fn test_from_vec() {
let data = vec![1.0, 3.0, 2.0, 4.0]; let m = Matrix::from_vec(data, 2, 2);
let expected = make_f64_matrix(1.0, 2.0, 3.0, 4.0);
assert_eq!(m, expected);
assert_eq!(m.to_vec(), vec![1.0, 3.0, 2.0, 4.0]);
}
#[test]
fn test_from_rows_vec() {
let rows_data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
let matrix = Matrix::from_rows_vec(rows_data, 2, 3);
let data = vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]; let expected = Matrix::from_vec(data, 2, 3);
assert_eq!(matrix, expected);
}
fn static_test_matrix() -> Matrix<i32> {
let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9];
Matrix::from_vec(data, 3, 3)
}
fn static_test_matrix_2x4() -> Matrix<i32> {
let data = vec![1, 2, 3, 4, 5, 6, 7, 8];
Matrix::from_vec(data, 2, 4)
}
#[test]
fn test_from_vec_basic() {
let data = vec![1, 2, 3, 4, 5, 6]; let matrix = Matrix::from_vec(data, 2, 3);
assert_eq!(matrix.rows(), 2);
assert_eq!(matrix.cols(), 3);
assert_eq!(matrix.data(), &[1, 2, 3, 4, 5, 6]);
assert_eq!(matrix[(0, 0)], 1); assert_eq!(matrix[(1, 0)], 2); assert_eq!(matrix[(0, 1)], 3); assert_eq!(matrix[(1, 2)], 6); }
#[test]
fn test_transpose() {
let matrix = static_test_matrix();
let transposed = matrix.transpose();
let round_triped = transposed.transpose();
assert_eq!(
round_triped, matrix,
"Transposing twice should return original matrix"
);
for r in 0..matrix.rows() {
for c in 0..matrix.cols() {
assert_eq!(matrix[(r, c)], transposed[(c, r)]);
}
}
}
#[test]
fn test_transpose_big() {
let data: Vec<i32> = (1..=20000).collect(); let matrix = Matrix::from_vec(data, 100, 200);
let transposed = matrix.transpose();
assert_eq!(transposed.rows(), 200);
assert_eq!(transposed.cols(), 100);
assert_eq!(transposed.data().len(), 20000);
assert_eq!(transposed[(0, 0)], 1);
let round_trip = transposed.transpose();
assert_eq!(
round_trip, matrix,
"Transposing back should return original matrix"
);
}
#[test]
#[should_panic(expected = "data length mismatch")]
fn test_from_vec_wrong_length() {
let data = vec![1, 2, 3, 4, 5]; Matrix::from_vec(data, 2, 3);
}
#[test]
#[should_panic(expected = "need at least one row")]
fn test_from_vec_zero_rows() {
let data = vec![1, 2, 3];
Matrix::from_vec(data, 0, 3);
}
#[test]
#[should_panic(expected = "need at least one column")]
fn test_from_vec_zero_cols() {
let data = vec![1, 2, 3];
Matrix::from_vec(data, 3, 0);
}
#[test]
fn test_from_cols_basic() {
let cols_data = vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8, 9]];
let matrix = Matrix::from_cols(cols_data);
assert_eq!(matrix.rows(), 3);
assert_eq!(matrix.cols(), 3);
assert_eq!(matrix.data(), &[1, 2, 3, 4, 5, 6, 7, 8, 9]);
assert_eq!(matrix[(0, 0)], 1);
assert_eq!(matrix[(2, 0)], 3);
assert_eq!(matrix[(1, 1)], 5);
assert_eq!(matrix[(0, 2)], 7);
}
#[test]
fn test_from_cols_1x1() {
let cols_data = vec![vec![42]];
let matrix = Matrix::from_cols(cols_data);
assert_eq!(matrix.rows(), 1);
assert_eq!(matrix.cols(), 1);
assert_eq!(matrix.data(), &[42]);
assert_eq!(matrix[(0, 0)], 42);
}
#[test]
#[should_panic(expected = "need at least one column")]
fn test_from_cols_empty_cols() {
let empty_cols: Vec<Vec<i32>> = vec![];
Matrix::from_cols(empty_cols);
}
#[test]
#[should_panic(expected = "need at least one row")]
fn test_from_cols_empty_rows() {
let empty_row: Vec<Vec<String>> = vec![vec![], vec![]];
Matrix::from_cols(empty_row);
}
#[test]
#[should_panic(expected = "col 1 has len 2, expected 3")]
fn test_from_cols_mismatched_lengths() {
let cols_data = vec![vec![1, 2, 3], vec![4, 5], vec![6, 7, 8]];
Matrix::from_cols(cols_data);
}
#[test]
fn test_getters() {
let matrix = static_test_matrix();
assert_eq!(matrix.rows(), 3);
assert_eq!(matrix.cols(), 3);
assert_eq!(matrix.data(), &[1, 2, 3, 4, 5, 6, 7, 8, 9]);
}
#[test]
fn test_index_and_get() {
let matrix = static_test_matrix();
assert_eq!(matrix[(0, 0)], 1);
assert_eq!(matrix[(1, 1)], 5);
assert_eq!(matrix[(2, 2)], 9);
assert_eq!(*matrix.get(0, 0), 1);
assert_eq!(*matrix.get(1, 1), 5);
assert_eq!(*matrix.get(2, 2), 9);
}
#[test]
#[should_panic(expected = "index out of bounds")]
fn test_index_out_of_bounds_row_alt() {
let matrix = static_test_matrix();
let _ = matrix[(3, 0)];
}
#[test]
#[should_panic(expected = "index out of bounds")]
fn test_index_out_of_bounds_col_alt() {
let matrix = static_test_matrix();
let _ = matrix[(0, 3)];
}
#[test]
fn test_index_mut_and_get_mut() {
let mut matrix = static_test_matrix();
matrix[(0, 0)] = 10;
matrix[(1, 1)] = 20;
matrix[(2, 2)] = 30;
assert_eq!(matrix[(0, 0)], 10);
assert_eq!(matrix[(1, 1)], 20);
assert_eq!(matrix[(2, 2)], 30);
*matrix.get_mut(0, 1) = 15;
*matrix.get_mut(2, 1) = 25;
assert_eq!(matrix[(0, 1)], 15);
assert_eq!(matrix[(2, 1)], 25);
assert_eq!(matrix.data(), &[10, 2, 3, 15, 20, 25, 7, 8, 30]);
}
#[test]
#[should_panic(expected = "index out of bounds")]
fn test_index_mut_out_of_bounds_row() {
let mut matrix = static_test_matrix();
matrix[(3, 0)] = 99;
}
#[test]
#[should_panic(expected = "index out of bounds")]
fn test_index_mut_out_of_bounds_col() {
let mut matrix = static_test_matrix();
matrix[(0, 3)] = 99;
}
#[test]
fn test_row() {
let ma = static_test_matrix();
assert_eq!(ma.row(0), &[1, 4, 7]);
assert_eq!(ma.row(1), &[2, 5, 8]);
assert_eq!(ma.row(2), &[3, 6, 9]);
}
#[test]
fn test_row_copy_from_slice() {
let mut ma = static_test_matrix();
let new_row = vec![10, 20, 30];
ma.row_copy_from_slice(1, &new_row);
assert_eq!(ma.row(1), &[10, 20, 30]);
}
#[test]
#[should_panic(expected = "row index 4 out of bounds for 3 rows")]
fn test_row_copy_from_slice_out_of_bounds() {
let mut ma = static_test_matrix();
let new_row = vec![10, 20, 30];
ma.row_copy_from_slice(4, &new_row);
}
#[test]
#[should_panic(expected = "row index 3 out of bounds for 3 rows")]
fn test_row_out_of_bounds_index() {
let ma = static_test_matrix();
ma.row(3);
}
#[test]
#[should_panic(expected = "input slice length 2 does not match number of columns 3")]
fn test_row_copy_from_slice_wrong_length() {
let mut ma = static_test_matrix();
let new_row = vec![10, 20]; ma.row_copy_from_slice(1, &new_row);
}
#[test]
fn test_shape() {
let ma = static_test_matrix_2x4();
assert_eq!(ma.shape(), (2, 4));
assert_eq!(ma.rows(), 2);
assert_eq!(ma.cols(), 4);
}
#[test]
fn test_repeat_rows() {
let ma = static_test_matrix();
let repeated = ma.repeat_rows(3);
for r in 0..repeated.rows() {
assert_eq!(repeated.row(r), ma.row(0));
}
}
#[test]
#[should_panic(expected = "row index 3 out of bounds for 3 rows")]
fn test_row_out_of_bounds() {
let ma = static_test_matrix();
ma.row(3);
}
#[test]
fn test_column() {
let matrix = static_test_matrix_2x4();
assert_eq!(matrix.column(0), &[1, 2]);
assert_eq!(matrix.column(1), &[3, 4]);
assert_eq!(matrix.column(2), &[5, 6]);
assert_eq!(matrix.column(3), &[7, 8]);
}
#[test]
#[should_panic(expected = "column index 4 out of bounds for 4 columns")]
fn test_column_out_of_bounds() {
let matrix = static_test_matrix_2x4();
matrix.column(4);
}
#[test]
fn test_column_mut() {
let mut matrix = static_test_matrix_2x4();
let col1_mut = matrix.column_mut(1);
col1_mut[0] = 30;
col1_mut[1] = 40;
let col3_mut = matrix.column_mut(3);
col3_mut[0] = 70;
assert_eq!(matrix[(0, 1)], 30);
assert_eq!(matrix[(1, 1)], 40);
assert_eq!(matrix[(0, 3)], 70);
assert_eq!(matrix[(1, 3)], 8);
assert_eq!(matrix.data(), &[1, 2, 30, 40, 5, 6, 70, 8]);
}
#[test]
#[should_panic(expected = "column index 4 out of bounds for 4 columns")]
fn test_column_mut_out_of_bounds() {
let mut matrix = static_test_matrix_2x4();
matrix.column_mut(4);
}
#[test]
fn test_iter_columns() {
let matrix = static_test_matrix_2x4();
let cols: Vec<&[i32]> = matrix.iter_columns().collect();
assert_eq!(cols.len(), 4);
assert_eq!(cols[0], &[1, 2]);
assert_eq!(cols[1], &[3, 4]);
assert_eq!(cols[2], &[5, 6]);
assert_eq!(cols[3], &[7, 8]);
}
#[test]
fn test_iter_rows() {
let matrix = static_test_matrix_2x4();
let rows: Vec<Vec<i32>> = matrix
.iter_rows()
.map(|row| row.iter().cloned().collect())
.collect();
assert_eq!(rows.len(), 2);
assert_eq!(rows[0], vec![1, 3, 5, 7]);
assert_eq!(rows[1], vec![2, 4, 6, 8]);
}
#[test]
fn test_data_mut() {
let mut matrix = static_test_matrix();
let data_mut = matrix.data_mut();
data_mut[0] = 10;
data_mut[1] = 20;
assert_eq!(matrix[(0, 0)], 10);
assert_eq!(matrix[(1, 0)], 20);
}
#[test]
fn test_matrix_row_get_and_iter() {
let matrix = static_test_matrix_2x4();
let row0 = matrix.iter_rows().next().unwrap();
assert_eq!(*row0.get(0), 1);
assert_eq!(*row0.get(1), 3);
assert_eq!(*row0.get(3), 7);
let row0_vec: Vec<i32> = row0.iter().cloned().collect();
assert_eq!(row0_vec, vec![1, 3, 5, 7]);
let row1 = matrix.iter_rows().nth(1).unwrap();
assert_eq!(*row1.get(0), 2);
assert_eq!(*row1.get(2), 6);
let row1_vec: Vec<i32> = row1.iter().cloned().collect();
assert_eq!(row1_vec, vec![2, 4, 6, 8]);
}
#[test]
fn test_swap_columns() {
let mut matrix = static_test_matrix();
matrix.swap_columns(0, 2);
assert_eq!(matrix.rows(), 3);
assert_eq!(matrix.cols(), 3);
assert_eq!(matrix[(0, 0)], 7);
assert_eq!(matrix[(1, 0)], 8);
assert_eq!(matrix[(2, 0)], 9);
assert_eq!(matrix[(0, 1)], 4); assert_eq!(matrix[(1, 1)], 5);
assert_eq!(matrix[(2, 1)], 6);
assert_eq!(matrix[(0, 2)], 1);
assert_eq!(matrix[(1, 2)], 2);
assert_eq!(matrix[(2, 2)], 3);
let original_data = matrix.data().to_vec();
matrix.swap_columns(1, 1);
assert_eq!(matrix.data(), &original_data);
assert_eq!(matrix.data(), &[7, 8, 9, 4, 5, 6, 1, 2, 3]);
}
#[test]
#[should_panic(expected = "column index c2=3 out of bounds for 3 columns")]
fn test_swap_columns_out_of_bounds() {
let mut matrix = static_test_matrix();
matrix.swap_columns(0, 3);
}
#[test]
fn test_delete_column() {
let mut matrix = static_test_matrix_2x4();
matrix.delete_column(1);
assert_eq!(matrix.rows(), 2);
assert_eq!(matrix.cols(), 3);
assert_eq!(matrix[(0, 0)], 1);
assert_eq!(matrix[(1, 0)], 2);
assert_eq!(matrix[(0, 1)], 5);
assert_eq!(matrix[(1, 1)], 6);
assert_eq!(matrix[(0, 2)], 7);
assert_eq!(matrix[(1, 2)], 8);
assert_eq!(matrix.data(), &[1, 2, 5, 6, 7, 8]);
matrix.delete_column(0);
assert_eq!(matrix.rows(), 2);
assert_eq!(matrix.cols(), 2);
assert_eq!(matrix.data(), &[5, 6, 7, 8]);
matrix.delete_column(1);
assert_eq!(matrix.rows(), 2);
assert_eq!(matrix.cols(), 1);
assert_eq!(matrix.data(), &[5, 6]);
matrix.delete_column(0);
assert_eq!(matrix.rows(), 2); assert_eq!(matrix.cols(), 0); assert_eq!(matrix.data(), &[]);
}
#[test]
#[should_panic(expected = "column index 4 out of bounds for 4 columns")]
fn test_delete_column_out_of_bounds() {
let mut matrix = static_test_matrix_2x4();
matrix.delete_column(4);
}
#[test]
fn test_delete_row() {
let mut matrix = static_test_matrix();
matrix.delete_row(1);
assert_eq!(matrix.rows(), 2);
assert_eq!(matrix.cols(), 3);
assert_eq!(matrix[(0, 0)], 1);
assert_eq!(matrix[(1, 0)], 3);
assert_eq!(matrix[(0, 1)], 4);
assert_eq!(matrix[(1, 1)], 6);
assert_eq!(matrix[(0, 2)], 7);
assert_eq!(matrix[(1, 2)], 9);
assert_eq!(matrix.data(), &[1, 3, 4, 6, 7, 9]);
matrix.delete_row(0);
assert_eq!(matrix.rows(), 1);
assert_eq!(matrix.cols(), 3);
assert_eq!(matrix.data(), &[3, 6, 9]);
matrix.delete_row(0);
assert_eq!(matrix.rows(), 0); assert_eq!(matrix.cols(), 3); assert_eq!(matrix.data(), &[]);
}
#[test]
#[should_panic(expected = "row index 3 out of bounds for 3 rows")]
fn test_delete_row_out_of_bounds() {
let mut matrix = static_test_matrix();
matrix.delete_row(3);
}
#[test]
fn test_add_column() {
    let mut m = static_test_matrix_2x4();

    // Insert in the middle: the former columns 2 and 3 shift right by one.
    m.add_column(2, vec![9, 10]);
    assert_eq!(m.shape(), (2, 5));
    let columns_after_insert = [[1, 2], [3, 4], [9, 10], [5, 6], [7, 8]];
    for (c, column) in columns_after_insert.iter().enumerate() {
        for (r, &want) in column.iter().enumerate() {
            assert_eq!(m[(r, c)], want);
        }
    }
    assert_eq!(m.data(), &[1, 2, 3, 4, 9, 10, 5, 6, 7, 8]);

    // Insert at the front.
    m.add_column(0, vec![11, 12]);
    assert_eq!(m.shape(), (2, 6));
    assert_eq!((m[(0, 0)], m[(1, 0)]), (11, 12));
    assert_eq!(m.data(), &[11, 12, 1, 2, 3, 4, 9, 10, 5, 6, 7, 8]);

    // An index equal to the current column count appends at the end.
    m.add_column(6, vec![13, 14]);
    assert_eq!(m.shape(), (2, 7));
    assert_eq!((m[(0, 6)], m[(1, 6)]), (13, 14));
    assert_eq!(
        m.data(),
        &[11, 12, 1, 2, 3, 4, 9, 10, 5, 6, 7, 8, 13, 14]
    );
}
#[test]
#[should_panic(expected = "add_column index 5 out of bounds for 4 columns")]
fn test_add_column_out_of_bounds() {
    // Insertion points for a 4-column matrix are 0..=4; 5 is past the end.
    let mut two_by_four = static_test_matrix_2x4();
    two_by_four.add_column(5, vec![9, 10]);
}
#[test]
#[should_panic(expected = "column length mismatch")]
fn test_add_column_length_mismatch() {
    // The matrix has 2 rows, so a 3-element column cannot fit.
    let mut two_by_four = static_test_matrix_2x4();
    two_by_four.add_column(0, vec![9, 10, 11]);
}
#[test]
fn test_add_row() {
    let mut m = static_test_matrix_2x4();

    // Insert between the two existing rows.
    m.add_row(1, vec![9, 10, 11, 12]);
    assert_eq!(m.shape(), (3, 4));
    let rows_after_insert = [[1, 3, 5, 7], [9, 10, 11, 12], [2, 4, 6, 8]];
    for (r, row) in rows_after_insert.iter().enumerate() {
        for (c, &want) in row.iter().enumerate() {
            assert_eq!(m[(r, c)], want);
        }
    }
    // Storage is column-major, so each column now carries three entries.
    assert_eq!(m.data(), &[1, 9, 2, 3, 10, 4, 5, 11, 6, 7, 12, 8]);

    // Insert at the top.
    m.add_row(0, vec![13, 14, 15, 16]);
    assert_eq!(m.shape(), (4, 4));
    for (c, &want) in [13, 14, 15, 16].iter().enumerate() {
        assert_eq!(m[(0, c)], want);
    }
    assert_eq!(m[(1, 0)], 1);
    assert_eq!(m[(2, 1)], 10);
    assert_eq!(m[(3, 3)], 8);

    // An index equal to the current row count appends at the bottom.
    m.add_row(4, vec![17, 18, 19, 20]);
    assert_eq!(m.shape(), (5, 4));
    assert_eq!(m[(4, 0)], 17);
    assert_eq!(m[(4, 3)], 20);
}
#[test]
#[should_panic(expected = "add_row index 3 out of bounds for 2 rows")]
fn test_add_row_out_of_bounds() {
    // Insertion points for a 2-row matrix are 0..=2; 3 is past the end.
    let mut two_by_four = static_test_matrix_2x4();
    two_by_four.add_row(3, vec![9, 10, 11, 12]);
}
#[test]
#[should_panic(expected = "row length mismatch")]
fn test_add_row_length_mismatch() {
    // The matrix has 4 columns, so a 3-element row cannot fit.
    let mut two_by_four = static_test_matrix_2x4();
    two_by_four.add_row(0, vec![9, 10, 11]);
}
#[test]
fn test_elementwise_add() {
    let lhs = static_test_matrix();
    // rhs holds lhs's values in reverse order, so every pairwise sum is 10.
    let rhs = Matrix::from_vec(vec![9, 8, 7, 6, 5, 4, 3, 2, 1], 3, 3);
    let sum = &lhs + &rhs;
    assert_eq!(sum.shape(), (3, 3));
    assert_eq!(sum.data(), &[10; 9]);
    for d in 0..3 {
        assert_eq!(sum[(d, d)], 10);
    }
}
#[test]
fn test_elementwise_sub() {
    let lhs = static_test_matrix();
    let rhs = Matrix::from_vec(vec![1, 1, 1, 2, 2, 2, 3, 3, 3], 3, 3);
    let diff = &lhs - &rhs;
    assert_eq!(diff.shape(), (3, 3));
    assert_eq!(diff.data(), &[0, 1, 2, 2, 3, 4, 4, 5, 6]);
    // Spot-check the main diagonal through the (row, col) index.
    for (d, &want) in [0, 3, 6].iter().enumerate() {
        assert_eq!(diff[(d, d)], want);
    }
}
#[test]
fn test_elementwise_mul() {
    let lhs = static_test_matrix();
    let rhs = Matrix::from_vec(vec![1, 2, 3, 1, 2, 3, 1, 2, 3], 3, 3);
    // `*` is the Hadamard (element-by-element) product, not matrix multiplication.
    let prod = &lhs * &rhs;
    assert_eq!(prod.shape(), (3, 3));
    assert_eq!(prod.data(), &[1, 4, 9, 4, 10, 18, 7, 16, 27]);
    for (d, &want) in [1, 10, 27].iter().enumerate() {
        assert_eq!(prod[(d, d)], want);
    }
}
#[test]
fn test_elementwise_div() {
    let lhs = static_test_matrix();
    let rhs = Matrix::from_vec(vec![1, 1, 1, 2, 2, 2, 7, 8, 9], 3, 3);
    // Integer division truncates, e.g. 5 / 2 == 2.
    let quot = &lhs / &rhs;
    assert_eq!(quot.shape(), (3, 3));
    assert_eq!(quot.data(), &[1, 2, 3, 2, 2, 3, 1, 1, 1]);
    for (d, &want) in [1, 2, 1].iter().enumerate() {
        assert_eq!(quot[(d, d)], want);
    }
}
#[test]
#[should_panic(expected = "Row count mismatch: left has 3 rows, right has 2 rows")]
fn test_elementwise_op_row_mismatch() {
    // 3x3 vs 2x4: both dimensions differ; the row mismatch is the one reported.
    let (three_by_three, two_by_four) = (static_test_matrix(), static_test_matrix_2x4());
    let _ = &three_by_three + &two_by_four;
}
#[test]
#[should_panic(expected = "Column count mismatch: left has 3 columns, right has 4 columns")]
fn test_elementwise_op_col_mismatch() {
    // Row counts agree (3 == 3), so the only mismatch is in the column count (3 vs 4).
    // Pairing with a matrix of different row count would trip the row check first
    // and never reach the column check this test is meant to cover.
    let lhs = static_test_matrix();
    let rhs = Matrix::from_vec((1..=12).collect::<Vec<i32>>(), 3, 4);
    let _ = &lhs * &rhs;
}
#[test]
fn test_bitwise_and() {
    let lhs = BoolMatrix::from_vec(vec![true, false, true, false, true, false], 2, 3);
    let rhs = BoolMatrix::from_vec(vec![true, true, false, false, true, true], 2, 3);
    // AND is true only where both operands are true.
    let want = BoolMatrix::from_vec(vec![true, false, false, false, true, false], 2, 3);
    assert_eq!(&lhs & &rhs, want);
}
#[test]
fn test_bitwise_or() {
    let lhs = BoolMatrix::from_vec(vec![true, false, true, false, true, false], 2, 3);
    let rhs = BoolMatrix::from_vec(vec![true, true, false, false, true, true], 2, 3);
    // OR is false only where both operands are false.
    let want = BoolMatrix::from_vec(vec![true, true, true, false, true, true], 2, 3);
    assert_eq!(&lhs | &rhs, want);
}
#[test]
fn test_bitwise_xor() {
    let lhs = BoolMatrix::from_vec(vec![true, false, true, false, true, false], 2, 3);
    let rhs = BoolMatrix::from_vec(vec![true, true, false, false, true, true], 2, 3);
    // XOR is true exactly where the operands differ.
    let want = BoolMatrix::from_vec(vec![false, true, true, false, false, true], 2, 3);
    assert_eq!(&lhs ^ &rhs, want);
}
#[test]
fn test_bitwise_not() {
    let original = BoolMatrix::from_vec(vec![true, false, true, false, true, false], 2, 3);
    // Negation flips every element independently.
    let flipped = BoolMatrix::from_vec(vec![false, true, false, true, false, true], 2, 3);
    assert_eq!(!original, flipped);
}
#[test]
#[should_panic(expected = "Row count mismatch: left has 2 rows, right has 3 rows")]
fn test_bitwise_op_row_mismatch() {
    // Both operands have 2 columns, so only the row counts (2 vs 3) disagree.
    // This exercises the row-shape check of the bitwise ops; the column check is
    // covered separately by test_bitwise_op_col_mismatch.
    let lhs = BoolMatrix::from_vec(vec![true, false, true, false], 2, 2);
    let rhs = BoolMatrix::from_vec(vec![true, true, false, false, true, true], 3, 2);
    let _ = &lhs & &rhs;
}
#[test]
#[should_panic(expected = "Column count mismatch: left has 2 columns, right has 3 columns")]
fn test_bitwise_op_col_mismatch() {
    // Both operands have 2 rows; only the column counts (2 vs 3) disagree.
    let narrow = BoolMatrix::from_vec(vec![true, false, true, false], 2, 2);
    let wide = BoolMatrix::from_vec(vec![true, true, false, false, true, true], 2, 3);
    let _ = &narrow | &wide;
}
#[test]
fn test_string_matrix() {
    let data: Vec<String> = ["a", "b", "c", "d"].iter().map(|s| s.to_string()).collect();
    let mut m = StringMatrix::from_vec(data, 2, 2);

    // Column-major layout: the first column is "a", "b".
    for (c, column) in [["a", "b"], ["c", "d"]].iter().enumerate() {
        for (r, &want) in column.iter().enumerate() {
            assert_eq!(m[(r, c)], want);
        }
    }

    // Elements can be overwritten in place through IndexMut.
    m[(0, 0)] = "hello".to_string();
    assert_eq!(m[(0, 0)], "hello");

    // Inserting a column at index 1 pushes ["c", "d"] to index 2.
    m.add_column(1, vec!["e".to_string(), "f".to_string()]);
    assert_eq!(m.shape(), (2, 3));
    for (c, column) in [["hello", "b"], ["e", "f"], ["c", "d"]].iter().enumerate() {
        for (r, &want) in column.iter().enumerate() {
            assert_eq!(m[(r, c)], want);
        }
    }

    // Inserting a row at the top shifts the existing rows down by one.
    m.add_row(0, vec!["g".to_string(), "h".to_string(), "i".to_string()]);
    assert_eq!(m.shape(), (3, 3));
    for (c, &want) in ["g", "h", "i"].iter().enumerate() {
        assert_eq!(m[(0, c)], want);
    }
    assert_eq!(m[(1, 0)], "hello");
    assert_eq!(m[(2, 2)], "d");
}
#[test]
fn test_float_matrix_ops() {
    let lhs_vals = [1.0, 2.0, 3.0, 4.0];
    let rhs_vals = [0.5, 1.5, 2.5, 3.5];
    let lhs = FloatMatrix::from_vec(lhs_vals.to_vec(), 2, 2);
    let rhs = FloatMatrix::from_vec(rhs_vals.to_vec(), 2, 2);

    let (sum, diff, prod, quot) = (&lhs + &rhs, &lhs - &rhs, &lhs * &rhs, &lhs / &rhs);

    // These results are exactly representable in binary floating point,
    // so exact comparison is safe here.
    assert_eq!(sum.data(), &[1.5, 3.5, 5.5, 7.5]);
    assert_eq!(diff.data(), &[0.5, 0.5, 0.5, 0.5]);
    assert_eq!(prod.data(), &[0.5, 3.0, 7.5, 14.0]);

    // Quotients such as 2.0 / 1.5 are inexact, so compare within a tolerance.
    assert_eq!(quot.shape(), (2, 2));
    for (k, (num, den)) in lhs_vals.iter().zip(rhs_vals.iter()).enumerate() {
        // Storage is column-major: flat index k sits at (k % 2, k / 2) in a 2x2 matrix.
        assert!((quot[(k % 2, k / 2)] - num / den).abs() < 1e-9);
    }
}
/// Builds the 3x3 fixture whose columns are [1, 2, 3], [4, 5, 6] and [7, 8, 9].
///
/// Storage is column-major, so the flat data is simply 1..=9.
fn create_test_matrix_i32() -> Matrix<i32> {
    Matrix::from_vec((1..=9).collect(), 3, 3)
}
#[test]
fn test_matrix_swap_columns_directly() {
    // Swap the two outer columns of a 3x3 matrix.
    let mut m = create_test_matrix_i32();
    let before: Vec<Vec<i32>> = (0..3).map(|c| m.column(c).to_vec()).collect();
    m.swap_columns(0, 2);

    assert_eq!(m.rows(), 3, "Matrix rows should remain unchanged");
    assert_eq!(m.cols(), 3, "Matrix cols should remain unchanged");
    assert_eq!(
        m.column(1),
        before[1].as_slice(),
        "Column 1 data should be unchanged"
    );
    assert_eq!(
        m.column(2),
        before[0].as_slice(),
        "Column 2 should now contain the original data from column 0"
    );
    assert_eq!(
        m.column(0),
        before[2].as_slice(),
        "Column 0 should now contain the original data from column 2"
    );
    assert_eq!(
        m.data(),
        &[7, 8, 9, 4, 5, 6, 1, 2, 3],
        "Underlying data vector is incorrect after swap"
    );

    // Swapping a column with itself must be a no-op.
    let snapshot = m.clone();
    m.swap_columns(1, 1);
    assert_eq!(
        m, snapshot,
        "Swapping a column with itself should not change the matrix"
    );

    // Swap two adjacent columns on a fresh fixture.
    let mut adjacent = create_test_matrix_i32();
    let (first, second) = (adjacent.column(0).to_vec(), adjacent.column(1).to_vec());
    adjacent.swap_columns(0, 1);
    assert_eq!(adjacent.column(0), second.as_slice());
    assert_eq!(adjacent.column(1), first.as_slice());
    assert_eq!(adjacent.data(), &[4, 5, 6, 1, 2, 3, 7, 8, 9]);
}
#[test]
fn test_comparision_broadcast() {
    let matrix = static_test_matrix();

    // Every entry of the fixture is strictly positive.
    let result = matrix.gt_elem(0).into_vec();
    let expected = vec![true; result.len()];
    assert_eq!(result, expected);

    // Element-wise equality against an identical copy holds everywhere.
    let ma = static_test_matrix();
    let mb = static_test_matrix();
    assert!(ma.eq_elem(mb).all());

    // i32::MAX is the value `1e10 as i32` saturates to. Writing it directly avoids
    // depending on the saturating semantics of float-to-int `as` casts.
    assert!(matrix.lt_elem(i32::MAX).all());

    // The fixture's entries are pairwise distinct, so `x <= v && x >= v`
    // (i.e. x == v) must hold for exactly one cell.
    for i in 0..matrix.rows() {
        for j in 0..matrix.cols() {
            let vx = matrix[(i, j)];
            let c = &(matrix.le_elem(vx)) & &(matrix.ge_elem(vx));
            assert_eq!(c.count(), 1);
        }
    }
}
#[test]
fn test_arithmetic_broadcast() {
    let m = static_test_matrix();
    let plus_one = &m + 1;
    let doubled = &m * 2;
    let halved = &m / 2;

    // Visit every (row, col) cell and compare against the scalar op applied by hand.
    let cells = (0..m.rows()).flat_map(|r| (0..m.cols()).map(move |c| (r, c)));
    for (r, c) in cells {
        let x = m[(r, c)];
        assert_eq!(plus_one[(r, c)], x + 1);
        assert_eq!(doubled[(r, c)], x * 2);
        assert_eq!(halved[(r, c)], x / 2);
    }
}
#[test]
fn test_matrix_zeros_ones_filled() {
    let zeros = Matrix::<f64>::zeros(2, 3);
    assert_eq!(zeros.shape(), (2, 3));
    assert_eq!(zeros.data(), &[0.0; 6]);

    let ones = Matrix::<f64>::ones(3, 2);
    assert_eq!(ones.shape(), (3, 2));
    assert_eq!(ones.data(), &[1.0; 6]);

    let float_fill = Matrix::<f64>::filled(2, 2, 42.5);
    assert_eq!(float_fill.shape(), (2, 2));
    assert_eq!(float_fill.data(), &[42.5; 4]);

    // `filled` is also usable with integer element types.
    let int_fill = Matrix::<i32>::filled(2, 3, 7);
    assert_eq!(int_fill.shape(), (2, 3));
    assert_eq!(int_fill.data(), &[7; 6]);

    // NaN never compares equal to itself, so each element is checked with is_nan.
    let nans = Matrix::nan(3, 3);
    assert_eq!(nans.shape(), (3, 3));
    for &value in nans.data() {
        assert!(value.is_nan(), "Expected NaN, got {}", value);
    }
}
#[test]
fn test_broadcast_row_to_target_shape_basic() {
    let (rows, cols) = (5, 3);
    let source = Matrix::from_rows_vec(vec![1.0, 2.0, 3.0], 1, 3);
    let tiled = source.broadcast_row_to_target_shape(rows, cols);

    assert_eq!(tiled.rows(), rows);
    assert_eq!(tiled.cols(), cols);
    // Every output row must be a copy of the single source row.
    (0..rows).for_each(|r| assert_eq!(tiled.row(r), vec![1.0, 2.0, 3.0]));
}
#[test]
fn test_broadcast_row_to_target_shape_single_row() {
    // Broadcasting a 1xN matrix to 1xN should reproduce the row unchanged.
    let (rows, cols) = (1, 2);
    let tiled = Matrix::from_rows_vec(vec![10.0, 20.0], 1, 2)
        .broadcast_row_to_target_shape(rows, cols);
    assert_eq!(tiled.rows(), rows);
    assert_eq!(tiled.cols(), cols);
    assert_eq!(tiled.row(0), vec![10.0, 20.0]);
}
#[test]
#[should_panic(
    expected = "broadcast_row_to_target_shape can only be called on a 1-row matrix."
)]
fn test_broadcast_row_to_target_shape_panic_not_1_row() {
    // The source has two rows, so row-broadcasting must be rejected.
    let two_by_two = Matrix::from_rows_vec(vec![1.0, 2.0, 3.0, 4.0], 2, 2);
    two_by_two.broadcast_row_to_target_shape(3, 2);
}
#[test]
#[should_panic(
    expected = "Column count mismatch for broadcasting: source has 3 columns, target has 4 columns."
)]
fn test_broadcast_row_to_target_shape_panic_col_mismatch() {
    // A 3-wide row cannot be broadcast into a 4-column target.
    let one_by_three = Matrix::from_rows_vec(vec![1.0, 2.0, 3.0], 1, 3);
    one_by_three.broadcast_row_to_target_shape(5, 4);
}
}