/// A dense 2-D matrix of `f64` values whose elements can be accessed as one
/// flat slice.
///
/// This trait lets code read and write the storage of different matrix
/// backends through a common interface.
///
/// # Layout
///
/// The element order of [`flat_slice`](Matrix::flat_slice) and
/// [`flat_mut_slice`](Matrix::flat_mut_slice) is backend-specific and is *not*
/// normalized by this trait. Implementors should document whether they expose
/// row-major or column-major order. Callers that map flat indices to
/// `(row, col)` pairs must account for the backend they are using.
pub trait Matrix {
    /// Borrows all elements as a flat, read-only slice in the implementor's
    /// storage order. The slice is expected to hold exactly `rows * cols`
    /// elements, where `(rows, cols) == self.dims()`.
    fn flat_slice(&self) -> &[f64];

    /// Mutably borrows all elements as a flat slice. The element order must
    /// match [`flat_slice`](Matrix::flat_slice) for the same matrix.
    fn flat_mut_slice(&mut self) -> &mut [f64];

    /// Creates a `rows` x `cols` matrix with every element set to `0.0`.
    fn zeros(rows: usize, cols: usize) -> Self;

    /// Returns the shape as `(rows, cols)`.
    fn dims(&self) -> (usize, usize);
}
#[cfg(feature = "ndarray")]
impl Matrix for ndarray::Array2<f64> {
    /// Borrows the elements in row-major (C) order.
    ///
    /// # Panics
    ///
    /// Panics if the array is not in standard layout (row-major and
    /// contiguous), e.g. after `reversed_axes()` or when it was built in
    /// Fortran order. [`Matrix::flat_mut_slice`] normalizes the layout, so
    /// calling it first avoids the panic.
    fn flat_slice(&self) -> &[f64] {
        self.as_slice().expect(
            "ndarray::Array2 must be in standard (row-major, contiguous) layout; \
             call flat_mut_slice first to normalize it",
        )
    }

    /// Mutably borrows the elements in row-major (C) order.
    ///
    /// If the array is not already in standard layout (e.g. Fortran-ordered or
    /// non-contiguous), it is first rebuilt in standard layout with the same
    /// logical contents. This method therefore does not panic, and its
    /// element order always matches [`Matrix::flat_slice`].
    fn flat_mut_slice(&mut self) -> &mut [f64] {
        if !self.is_standard_layout() {
            // `as_slice_mut` only succeeds for standard layout; copy the data
            // into that layout instead of panicking on an otherwise valid matrix.
            *self = self.as_standard_layout().into_owned();
        }
        self.as_slice_mut()
            .expect("array was converted to standard layout above")
    }

    /// Creates a `rows` x `cols` matrix of zeros in standard (row-major) layout.
    fn zeros(rows: usize, cols: usize) -> Self {
        ndarray::Array2::zeros((rows, cols))
    }

    /// Returns the shape as `(rows, cols)`.
    fn dims(&self) -> (usize, usize) {
        (self.nrows(), self.ncols())
    }
}
#[cfg(feature = "nalgebra")]
impl Matrix for nalgebra::DMatrix<f64> {
    /// Allocates a `rows` x `cols` matrix with every entry set to `0.0`.
    fn zeros(rows: usize, cols: usize) -> Self {
        nalgebra::DMatrix::from_element(rows, cols, 0.0)
    }

    /// Reports the shape as `(rows, cols)`.
    fn dims(&self) -> (usize, usize) {
        self.shape()
    }

    /// Exposes nalgebra's contiguous backing buffer, which is stored in
    /// column-major order.
    fn flat_slice(&self) -> &[f64] {
        self.as_slice()
    }

    /// Mutable view of the same column-major backing buffer returned by
    /// `flat_slice`.
    fn flat_mut_slice(&mut self) -> &mut [f64] {
        self.as_mut_slice()
    }
}
#[cfg(test)]
mod tests {
    // Every test and backend import is gated on the feature(s) it needs: the
    // `Matrix` impls above only exist when those features are enabled, so an
    // ungated test module would fail to compile in other feature combinations.
    #[cfg(any(feature = "ndarray", feature = "nalgebra"))]
    use super::*;
    #[cfg(feature = "nalgebra")]
    use nalgebra::DMatrix;
    #[cfg(feature = "ndarray")]
    use ndarray::Array2;

    // The trait's `zeros` yields an all-zero matrix of the requested shape, and
    // writes through `flat_mut_slice` are visible through `flat_slice`.
    #[cfg(feature = "ndarray")]
    #[test]
    fn test_ndarray_matrix_operations() {
        let mat = <Array2<f64> as Matrix>::zeros(2, 3);
        assert_eq!(mat.dims(), (2, 3));
        assert_eq!(mat.flat_slice().len(), 6);
        assert!(mat.flat_slice().iter().all(|&x| x == 0.0));

        let mut mat = Array2::<f64>::zeros((2, 2));
        {
            let slice = mat.flat_mut_slice();
            slice[0] = 1.0;
            slice[3] = 4.0;
        }
        assert_eq!(mat.flat_slice(), &[1.0, 0.0, 0.0, 4.0]);
    }

    #[cfg(feature = "nalgebra")]
    #[test]
    fn test_nalgebra_matrix_operations() {
        let mat = <DMatrix<f64> as Matrix>::zeros(2, 3);
        assert_eq!(mat.dims(), (2, 3));
        assert_eq!(mat.flat_slice().len(), 6);
        assert!(mat.flat_slice().iter().all(|&x| x == 0.0));

        let mut mat = DMatrix::<f64>::zeros(2, 2);
        {
            let slice = mat.flat_mut_slice();
            slice[0] = 1.0;
            slice[3] = 4.0;
        }
        assert_eq!(mat.flat_slice(), &[1.0, 0.0, 0.0, 4.0]);
    }

    // The two backends expose different flat orders: ndarray is row-major and
    // nalgebra is column-major. Pin that difference via (row, col) lookups,
    // which a flat write/read round-trip alone cannot detect.
    #[cfg(all(feature = "ndarray", feature = "nalgebra"))]
    #[test]
    fn test_matrix_layout() {
        let expected = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0];
        let mut ndarray_mat = Array2::<f64>::zeros((2, 3));
        let mut nalgebra_mat = DMatrix::<f64>::zeros(2, 3);
        ndarray_mat.flat_mut_slice().copy_from_slice(&expected);
        nalgebra_mat.flat_mut_slice().copy_from_slice(&expected);

        assert_eq!(ndarray_mat.flat_slice(), &expected);
        assert_eq!(nalgebra_mat.flat_slice(), &expected);

        // Row-major: flat index 1 is (0, 1) and flat index 3 is (1, 0).
        assert_eq!(ndarray_mat[[0, 1]], 1.0);
        assert_eq!(ndarray_mat[[1, 0]], 3.0);
        // Column-major: flat index 1 is (1, 0) and flat index 2 is (0, 1).
        assert_eq!(nalgebra_mat[(1, 0)], 1.0);
        assert_eq!(nalgebra_mat[(0, 1)], 2.0);
    }
}