use nalgebra::DMatrix;
/// Dense `nrows x ncols` matrix of `f64` values stored in column-major order.
///
/// Element `(i, j)` lives at `data[i + j * nrows]`, so each column is a contiguous
/// slice while a row is strided by `nrows`. The validating constructors reject any
/// `data` whose length differs from `nrows * ncols`.
///
/// NOTE(review): with the `serde` feature enabled, `Deserialize` is derived directly on
/// these fields, so a deserialized value is not checked against the
/// `data.len() == nrows * ncols` invariant — confirm whether a validating
/// `#[serde(try_from = ...)]` path is wanted.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct FdMatrix {
    /// Element storage, column-major.
    data: Vec<f64>,
    /// Number of rows (the length of every column).
    nrows: usize,
    /// Number of columns.
    ncols: usize,
}
impl FdMatrix {
    /// Checks that a buffer of `len` elements exactly fills an `nrows x ncols` matrix.
    ///
    /// The element count is computed with `checked_mul`: a plain `nrows * ncols` panics
    /// on overflow in debug builds and wraps in release builds, where the wrapped
    /// (smaller) product could match a short buffer and yield a matrix whose shape
    /// disagrees with its storage.
    fn check_len(len: usize, nrows: usize, ncols: usize) -> Result<(), crate::FdarError> {
        match nrows.checked_mul(ncols) {
            Some(expected) if expected == len => Ok(()),
            Some(expected) => Err(crate::FdarError::InvalidDimension {
                parameter: "data",
                expected: expected.to_string(),
                actual: len.to_string(),
            }),
            None => Err(crate::FdarError::InvalidDimension {
                parameter: "data",
                expected: format!("{nrows} * {ncols} elements (overflows usize)"),
                actual: len.to_string(),
            }),
        }
    }

    /// Creates a matrix that takes ownership of `data`, interpreted column-major:
    /// element `(i, j)` is `data[i + j * nrows]`.
    ///
    /// # Errors
    ///
    /// Returns `FdarError::InvalidDimension` (parameter `"data"`) if
    /// `data.len() != nrows * ncols` or if `nrows * ncols` overflows `usize`.
    pub fn from_column_major(
        data: Vec<f64>,
        nrows: usize,
        ncols: usize,
    ) -> Result<Self, crate::FdarError> {
        Self::check_len(data.len(), nrows, ncols)?;
        Ok(Self { data, nrows, ncols })
    }

    /// Like [`FdMatrix::from_column_major`], but copies the elements out of a borrowed
    /// column-major slice.
    ///
    /// # Errors
    ///
    /// Same conditions as [`FdMatrix::from_column_major`]; the slice is validated before
    /// anything is copied.
    pub fn from_slice(data: &[f64], nrows: usize, ncols: usize) -> Result<Self, crate::FdarError> {
        Self::check_len(data.len(), nrows, ncols)?;
        Ok(Self {
            data: data.to_vec(),
            nrows,
            ncols,
        })
    }

    /// Creates an `nrows x ncols` matrix with every element set to `0.0`.
    ///
    /// # Panics
    ///
    /// Panics if `nrows * ncols` overflows `usize`. Without the check, a release build
    /// would wrap and allocate a buffer that does not match the recorded shape.
    pub fn zeros(nrows: usize, ncols: usize) -> Self {
        let len = nrows
            .checked_mul(ncols)
            .expect("FdMatrix::zeros: nrows * ncols overflows usize");
        Self {
            data: vec![0.0; len],
            nrows,
            ncols,
        }
    }

    /// Number of rows.
    #[inline]
    pub fn nrows(&self) -> usize {
        self.nrows
    }

    /// Number of columns.
    #[inline]
    pub fn ncols(&self) -> usize {
        self.ncols
    }

    /// `(nrows, ncols)`.
    #[inline]
    pub fn shape(&self) -> (usize, usize) {
        (self.nrows, self.ncols)
    }

    /// Total number of stored elements.
    #[inline]
    pub fn len(&self) -> usize {
        self.data.len()
    }

    /// `true` when the matrix stores no elements (either dimension is zero).
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.data.is_empty()
    }

    /// Borrows column `col` as a contiguous slice of length `nrows`.
    ///
    /// # Panics
    ///
    /// Panics if `col >= ncols` while `nrows > 0`. For a matrix with zero rows, every
    /// `col` yields an empty slice.
    #[inline]
    pub fn column(&self, col: usize) -> &[f64] {
        let start = col * self.nrows;
        &self.data[start..start + self.nrows]
    }

    /// Mutably borrows column `col` as a contiguous slice of length `nrows`.
    ///
    /// # Panics
    ///
    /// Same conditions as [`FdMatrix::column`].
    #[inline]
    pub fn column_mut(&mut self, col: usize) -> &mut [f64] {
        let start = col * self.nrows;
        &mut self.data[start..start + self.nrows]
    }

    /// Copies row `row` into a new `Vec` of length `ncols`.
    ///
    /// Rows are strided in column-major storage, so this allocates; use
    /// [`FdMatrix::row_to_buf`] to reuse a buffer.
    ///
    /// # Panics
    ///
    /// Panics if `row >= nrows` (when `ncols > 0`).
    pub fn row(&self, row: usize) -> Vec<f64> {
        self.iter_columns().map(|col| col[row]).collect()
    }

    /// Copies row `row` into the first `ncols` slots of `buf`; any extra slots are untouched.
    ///
    /// # Panics
    ///
    /// Panics if `buf.len() < ncols`, or if `row >= nrows` (when `ncols > 0`).
    #[inline]
    pub fn row_to_buf(&self, row: usize, buf: &mut [f64]) {
        debug_assert!(
            row < self.nrows,
            "row {row} out of bounds for {} rows",
            self.nrows
        );
        debug_assert!(
            buf.len() >= self.ncols,
            "buffer len {} < ncols {}",
            buf.len(),
            self.ncols
        );
        // Slicing once up front enforces the buffer length in every build and lets the
        // copy loop run without a per-element bounds check on `buf`.
        let dst = &mut buf[..self.ncols];
        for (slot, col) in dst.iter_mut().zip(self.iter_columns()) {
            *slot = col[row];
        }
    }

    /// Dot product of row `row_a` of `self` with row `row_b` of `other`.
    ///
    /// # Panics
    ///
    /// Panics if the two matrices have different column counts, or if either row index
    /// is out of range (when `ncols > 0`). The column check runs in every build: a
    /// debug-only check would let a release call silently sum over only `self.ncols`
    /// columns when `other` is wider.
    #[inline]
    pub fn row_dot(&self, row_a: usize, other: &FdMatrix, row_b: usize) -> f64 {
        assert_eq!(self.ncols, other.ncols, "ncols mismatch in row_dot");
        // Sequential left-to-right accumulation, so the floating-point result is
        // deterministic for a given pair of rows.
        self.iter_columns()
            .zip(other.iter_columns())
            .fold(0.0, |acc, (a, b)| acc + a[row_a] * b[row_b])
    }

    /// Squared Euclidean distance between row `row_a` of `self` and row `row_b` of `other`.
    ///
    /// # Panics
    ///
    /// Same conditions as [`FdMatrix::row_dot`].
    #[inline]
    pub fn row_l2_sq(&self, row_a: usize, other: &FdMatrix, row_b: usize) -> f64 {
        assert_eq!(self.ncols, other.ncols, "ncols mismatch in row_l2_sq");
        self.iter_columns()
            .zip(other.iter_columns())
            .fold(0.0, |acc, (a, b)| {
                let d = a[row_a] - b[row_b];
                acc + d * d
            })
    }

    /// Lazily yields each row as a freshly allocated `Vec`, top to bottom.
    pub fn iter_rows(&self) -> impl Iterator<Item = Vec<f64>> + '_ {
        (0..self.nrows).map(move |i| self.row(i))
    }

    /// Yields each column as a borrowed slice, left to right (no allocation).
    pub fn iter_columns(&self) -> impl Iterator<Item = &[f64]> {
        (0..self.ncols).map(move |j| self.column(j))
    }

    /// Collects every row into its own `Vec`.
    pub fn rows(&self) -> Vec<Vec<f64>> {
        self.iter_rows().collect()
    }

    /// Returns a copy of the elements in row-major order (row 0 first).
    pub fn to_row_major(&self) -> Vec<f64> {
        let mut out = Vec::with_capacity(self.data.len());
        for i in 0..self.nrows {
            out.extend(self.iter_columns().map(|col| col[i]));
        }
        out
    }

    /// Borrows the raw column-major storage.
    #[inline]
    pub fn as_slice(&self) -> &[f64] {
        &self.data
    }

    /// Mutably borrows the raw column-major storage. A slice cannot change length, so
    /// the shape stays consistent.
    #[inline]
    pub fn as_mut_slice(&mut self) -> &mut [f64] {
        &mut self.data
    }

    /// Consumes the matrix and returns its column-major storage.
    pub fn into_vec(self) -> Vec<f64> {
        self.data
    }

    /// Copies the elements into a new nalgebra `DMatrix` of the same shape via
    /// `DMatrix::from_column_slice`.
    pub fn to_dmatrix(&self) -> DMatrix<f64> {
        DMatrix::from_column_slice(self.nrows, self.ncols, &self.data)
    }

    /// Copies an nalgebra `DMatrix` into a new `FdMatrix` of the same shape.
    ///
    /// nalgebra stores dense matrices column-major, so `as_slice()` already has the
    /// layout this type uses.
    pub fn from_dmatrix(mat: &DMatrix<f64>) -> Self {
        let (nrows, ncols) = mat.shape();
        Self {
            data: mat.as_slice().to_vec(),
            nrows,
            ncols,
        }
    }

    /// Returns element `(row, col)`, or `None` if either index is out of range.
    #[inline]
    pub fn get(&self, row: usize, col: usize) -> Option<f64> {
        if row < self.nrows && col < self.ncols {
            Some(self.data[row + col * self.nrows])
        } else {
            None
        }
    }

    /// Sets element `(row, col)` to `value`.
    ///
    /// Returns `true` if the write happened, or `false` (leaving the matrix unchanged)
    /// if either index is out of range.
    #[inline]
    pub fn set(&mut self, row: usize, col: usize, value: f64) -> bool {
        if row < self.nrows && col < self.ncols {
            self.data[row + col * self.nrows] = value;
            true
        } else {
            false
        }
    }
}
impl std::ops::Index<(usize, usize)> for FdMatrix {
    type Output = f64;

    /// Returns a reference to element `(row, col)`, stored at column-major offset
    /// `row + col * nrows`.
    ///
    /// # Panics
    ///
    /// Panics if `row >= nrows` or `col >= ncols`. The check runs in every build. With
    /// only a debug assertion, an out-of-range `row` whose flat offset still lands inside
    /// the buffer would silently read a different element in release builds; for example,
    /// `(nrows, 0)` is element `(0, 1)`.
    #[inline]
    fn index(&self, (row, col): (usize, usize)) -> &f64 {
        assert!(
            row < self.nrows && col < self.ncols,
            "FdMatrix index ({}, {}) out of bounds for {}x{} matrix",
            row,
            col,
            self.nrows,
            self.ncols
        );
        &self.data[row + col * self.nrows]
    }
}
impl std::ops::IndexMut<(usize, usize)> for FdMatrix {
    /// Returns a mutable reference to element `(row, col)`, stored at column-major offset
    /// `row + col * nrows`.
    ///
    /// # Panics
    ///
    /// Panics if `row >= nrows` or `col >= ncols`. The check runs in every build. With
    /// only a debug assertion, an out-of-range `row` whose flat offset still lands inside
    /// the buffer would silently overwrite a different element in release builds.
    #[inline]
    fn index_mut(&mut self, (row, col): (usize, usize)) -> &mut f64 {
        assert!(
            row < self.nrows && col < self.ncols,
            "FdMatrix index ({}, {}) out of bounds for {}x{} matrix",
            row,
            col,
            self.nrows,
            self.ncols
        );
        &mut self.data[row + col * self.nrows]
    }
}
impl std::fmt::Display for FdMatrix {
    /// Writes a shape-only summary such as `FdMatrix(3x4)`; element values are omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (rows, cols) = (self.nrows, self.ncols);
        f.write_fmt(format_args!("FdMatrix({}x{})", rows, cols))
    }
}
/// A set of curves with one [`FdMatrix`] per dimension.
///
/// In every matrix, rows index curves and columns index evaluation points
/// (`FdCurveSet::ncurves` reads the row count and `FdCurveSet::npoints` the column count
/// of the first matrix).
///
/// NOTE(review): `dims` is public, so a value built with a struct literal can hold
/// matrices of different shapes, or none at all; only `FdCurveSet::from_dims` checks that
/// the shapes agree.
#[derive(Debug, Clone, PartialEq)]
pub struct FdCurveSet {
    /// One matrix per dimension.
    pub dims: Vec<FdMatrix>,
}
impl FdCurveSet {
    /// Number of dimensions, i.e. the number of component matrices.
    pub fn ndim(&self) -> usize {
        self.dims.len()
    }

    /// Number of curves: the row count of the first dimension, or 0 when there are none.
    pub fn ncurves(&self) -> usize {
        self.dims.first().map_or(0, |m| m.nrows())
    }

    /// Number of evaluation points per curve: the column count of the first dimension,
    /// or 0 when there are none.
    pub fn npoints(&self) -> usize {
        self.dims.first().map_or(0, |m| m.ncols())
    }

    /// Wraps a single matrix as a one-dimensional curve set.
    pub fn from_1d(data: FdMatrix) -> Self {
        Self { dims: vec![data] }
    }

    /// Builds a curve set from one matrix per dimension.
    ///
    /// # Errors
    ///
    /// Returns `FdarError::InvalidDimension` (parameter `"dims"`) if `dims` is empty, or
    /// if any matrix's shape differs from the first matrix's shape.
    pub fn from_dims(dims: Vec<FdMatrix>) -> Result<Self, crate::FdarError> {
        if dims.is_empty() {
            return Err(crate::FdarError::InvalidDimension {
                parameter: "dims",
                expected: "non-empty".to_string(),
                actual: "empty".to_string(),
            });
        }
        let (n, m) = dims[0].shape();
        if dims.iter().any(|d| d.shape() != (n, m)) {
            return Err(crate::FdarError::InvalidDimension {
                parameter: "dims",
                expected: format!("all ({n}, {m})"),
                actual: "inconsistent shapes".to_string(),
            });
        }
        Ok(Self { dims })
    }

    /// Returns the value of curve `curve` at evaluation index `time_idx` in every
    /// dimension, in dimension order.
    ///
    /// # Panics
    ///
    /// Panics if `(curve, time_idx)` is out of range for any dimension's matrix. Each
    /// element is read through the bounds-checked `FdMatrix::get`, so an out-of-range
    /// `curve` cannot silently read a neighbouring element through the column-major
    /// offset arithmetic.
    pub fn point(&self, curve: usize, time_idx: usize) -> Vec<f64> {
        self.dims
            .iter()
            .map(|d| {
                d.get(curve, time_idx).unwrap_or_else(|| {
                    panic!(
                        "FdCurveSet point ({curve}, {time_idx}) out of bounds for {}x{} dimension",
                        d.nrows(),
                        d.ncols()
                    )
                })
            })
            .collect()
    }
}
impl std::fmt::Display for FdCurveSet {
    /// Writes a size-only summary: dimension count `d`, curve count `n` and point count
    /// `m`. No data values are printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let d = self.ndim();
        let n = self.ncurves();
        let m = self.npoints();
        f.write_fmt(format_args!("FdCurveSet(d={}, n={}, m={})", d, n, m))
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for `FdMatrix` and `FdCurveSet`. Expected values are worked out by hand
    // from the column-major layout: element `(i, j)` sits at flat offset `i + j * nrows`.
    use super::*;
    // 3x4 fixture holding 1..=12 column-major: column j is [3j+1, 3j+2, 3j+3],
    // so row i is [i+1, i+4, i+7, i+10].
    fn sample_3x4() -> FdMatrix {
        let data = vec![
            1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
        ];
        FdMatrix::from_column_major(data, 3, 4).unwrap()
    }
    #[test]
    fn test_from_column_major_valid() {
        let mat = sample_3x4();
        assert_eq!(mat.nrows(), 3);
        assert_eq!(mat.ncols(), 4);
        assert_eq!(mat.shape(), (3, 4));
        assert_eq!(mat.len(), 12);
        assert!(!mat.is_empty());
    }
    #[test]
    fn test_from_column_major_invalid() {
        // 2 elements cannot fill a 3x4 matrix.
        assert!(FdMatrix::from_column_major(vec![1.0, 2.0], 3, 4).is_err());
    }
    #[test]
    fn test_from_slice() {
        let data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let mat = FdMatrix::from_slice(&data, 2, 3).unwrap();
        assert_eq!(mat[(0, 0)], 1.0);
        assert_eq!(mat[(1, 0)], 2.0);
        assert_eq!(mat[(0, 1)], 3.0);
    }
    #[test]
    fn test_from_slice_invalid() {
        assert!(FdMatrix::from_slice(&[1.0, 2.0], 3, 3).is_err());
    }
    #[test]
    fn test_zeros() {
        let mat = FdMatrix::zeros(2, 3);
        assert_eq!(mat.nrows(), 2);
        assert_eq!(mat.ncols(), 3);
        for j in 0..3 {
            for i in 0..2 {
                assert_eq!(mat[(i, j)], 0.0);
            }
        }
    }
    #[test]
    fn test_index() {
        let mat = sample_3x4();
        assert_eq!(mat[(0, 0)], 1.0);
        assert_eq!(mat[(1, 0)], 2.0);
        assert_eq!(mat[(2, 0)], 3.0);
        assert_eq!(mat[(0, 1)], 4.0);
        assert_eq!(mat[(1, 1)], 5.0);
        assert_eq!(mat[(2, 3)], 12.0);
    }
    #[test]
    fn test_index_mut() {
        let mut mat = sample_3x4();
        mat[(1, 2)] = 99.0;
        assert_eq!(mat[(1, 2)], 99.0);
    }
    #[test]
    fn test_column() {
        let mat = sample_3x4();
        assert_eq!(mat.column(0), &[1.0, 2.0, 3.0]);
        assert_eq!(mat.column(1), &[4.0, 5.0, 6.0]);
        assert_eq!(mat.column(3), &[10.0, 11.0, 12.0]);
    }
    #[test]
    fn test_column_mut() {
        let mut mat = sample_3x4();
        mat.column_mut(1)[0] = 99.0;
        assert_eq!(mat[(0, 1)], 99.0);
    }
    #[test]
    fn test_row() {
        let mat = sample_3x4();
        assert_eq!(mat.row(0), vec![1.0, 4.0, 7.0, 10.0]);
        assert_eq!(mat.row(1), vec![2.0, 5.0, 8.0, 11.0]);
        assert_eq!(mat.row(2), vec![3.0, 6.0, 9.0, 12.0]);
    }
    #[test]
    fn test_rows() {
        let mat = sample_3x4();
        let rows = mat.rows();
        assert_eq!(rows.len(), 3);
        assert_eq!(rows[0], vec![1.0, 4.0, 7.0, 10.0]);
        assert_eq!(rows[2], vec![3.0, 6.0, 9.0, 12.0]);
    }
    #[test]
    fn test_as_slice() {
        let mat = sample_3x4();
        let expected = vec![
            1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
        ];
        assert_eq!(mat.as_slice(), expected.as_slice());
    }
    #[test]
    fn test_into_vec() {
        let mat = sample_3x4();
        let v = mat.into_vec();
        assert_eq!(v.len(), 12);
        assert_eq!(v[0], 1.0);
    }
    #[test]
    fn test_get_bounds_check() {
        let mat = sample_3x4();
        assert_eq!(mat.get(0, 0), Some(1.0));
        assert_eq!(mat.get(2, 3), Some(12.0));
        // Row 3 and column 4 are each one past the end.
        assert_eq!(mat.get(3, 0), None);
        assert_eq!(mat.get(0, 4), None);
    }
    #[test]
    fn test_set_bounds_check() {
        let mut mat = sample_3x4();
        assert!(mat.set(1, 1, 99.0));
        assert_eq!(mat[(1, 1)], 99.0);
        // An out-of-range write is refused and reported as `false`.
        assert!(!mat.set(5, 0, 99.0));
    }
    #[test]
    fn test_nalgebra_roundtrip() {
        let mat = sample_3x4();
        let dmat = mat.to_dmatrix();
        assert_eq!(dmat.nrows(), 3);
        assert_eq!(dmat.ncols(), 4);
        assert_eq!(dmat[(0, 0)], 1.0);
        // (1, 2) is flat offset 1 + 2 * 3 = 7, i.e. the value 8.0.
        assert_eq!(dmat[(1, 2)], 8.0);
        let back = FdMatrix::from_dmatrix(&dmat);
        assert_eq!(mat, back);
    }
    #[test]
    fn test_empty() {
        let mat = FdMatrix::zeros(0, 0);
        assert!(mat.is_empty());
        assert_eq!(mat.len(), 0);
    }
    #[test]
    fn test_single_element() {
        let mat = FdMatrix::from_column_major(vec![42.0], 1, 1).unwrap();
        assert_eq!(mat[(0, 0)], 42.0);
        assert_eq!(mat.column(0), &[42.0]);
        assert_eq!(mat.row(0), vec![42.0]);
    }
    #[test]
    fn test_display() {
        let mat = sample_3x4();
        assert_eq!(format!("{}", mat), "FdMatrix(3x4)");
    }
    #[test]
    fn test_clone() {
        let mat = sample_3x4();
        let cloned = mat.clone();
        assert_eq!(mat, cloned);
    }
    #[test]
    fn test_as_mut_slice() {
        let mut mat = FdMatrix::zeros(2, 2);
        let s = mat.as_mut_slice();
        s[0] = 1.0;
        s[1] = 2.0;
        s[2] = 3.0;
        s[3] = 4.0;
        // Raw writes land column-major: offsets 0, 1 fill column 0 and offsets 2, 3 fill column 1.
        assert_eq!(mat[(0, 0)], 1.0);
        assert_eq!(mat[(1, 0)], 2.0);
        assert_eq!(mat[(0, 1)], 3.0);
        assert_eq!(mat[(1, 1)], 4.0);
    }
    #[test]
    fn test_fd_curve_set_empty() {
        // `from_dims` rejects an empty list, so the fallback struct literal (possible
        // because `dims` is pub) builds an empty set to check the zero-dimension accessors.
        assert!(FdCurveSet::from_dims(vec![]).is_err());
        let cs = FdCurveSet::from_dims(vec![]).unwrap_or(FdCurveSet { dims: vec![] });
        assert_eq!(cs.ndim(), 0);
        assert_eq!(cs.ncurves(), 0);
        assert_eq!(cs.npoints(), 0);
        assert_eq!(format!("{}", cs), "FdCurveSet(d=0, n=0, m=0)");
    }
    #[test]
    fn test_fd_curve_set_from_1d() {
        let mat = sample_3x4();
        let cs = FdCurveSet::from_1d(mat.clone());
        assert_eq!(cs.ndim(), 1);
        assert_eq!(cs.ncurves(), 3);
        assert_eq!(cs.npoints(), 4);
        assert_eq!(cs.point(0, 0), vec![1.0]);
        assert_eq!(cs.point(1, 2), vec![8.0]);
    }
    #[test]
    fn test_fd_curve_set_from_dims_consistent() {
        let m1 = FdMatrix::from_column_major(vec![1.0, 2.0, 3.0, 4.0], 2, 2).unwrap();
        let m2 = FdMatrix::from_column_major(vec![5.0, 6.0, 7.0, 8.0], 2, 2).unwrap();
        let cs = FdCurveSet::from_dims(vec![m1, m2]).unwrap();
        assert_eq!(cs.ndim(), 2);
        // A point collects the same (curve, time) element from each dimension in order.
        assert_eq!(cs.point(0, 0), vec![1.0, 5.0]);
        assert_eq!(cs.point(1, 1), vec![4.0, 8.0]);
        assert_eq!(format!("{}", cs), "FdCurveSet(d=2, n=2, m=2)");
    }
    #[test]
    fn test_fd_curve_set_from_dims_inconsistent() {
        // A 2x1 matrix and a 3x1 matrix disagree on the curve count.
        let m1 = FdMatrix::from_column_major(vec![1.0, 2.0], 2, 1).unwrap();
        let m2 = FdMatrix::from_column_major(vec![1.0, 2.0, 3.0], 3, 1).unwrap();
        assert!(FdCurveSet::from_dims(vec![m1, m2]).is_err());
    }
    #[test]
    fn test_to_row_major() {
        let mat = sample_3x4();
        let rm = mat.to_row_major();
        assert_eq!(
            rm,
            vec![1.0, 4.0, 7.0, 10.0, 2.0, 5.0, 8.0, 11.0, 3.0, 6.0, 9.0, 12.0]
        );
    }
    #[test]
    fn test_row_to_buf() {
        let mat = sample_3x4();
        let mut buf = vec![0.0; 4];
        mat.row_to_buf(0, &mut buf);
        assert_eq!(buf, vec![1.0, 4.0, 7.0, 10.0]);
        mat.row_to_buf(1, &mut buf);
        assert_eq!(buf, vec![2.0, 5.0, 8.0, 11.0]);
        mat.row_to_buf(2, &mut buf);
        assert_eq!(buf, vec![3.0, 6.0, 9.0, 12.0]);
    }
    #[test]
    fn test_row_to_buf_larger_buffer() {
        let mat = sample_3x4();
        // The buffer is longer than ncols; slots past ncols must be left untouched.
        let mut buf = vec![99.0; 6];
        mat.row_to_buf(0, &mut buf);
        assert_eq!(&buf[..4], &[1.0, 4.0, 7.0, 10.0]);
        assert_eq!(buf[4], 99.0);
    }
    #[test]
    fn test_row_dot_same_matrix() {
        let mat = sample_3x4();
        // row 0 = [1, 4, 7, 10], row 1 = [2, 5, 8, 11]: 2 + 20 + 56 + 110 = 188.
        assert_eq!(mat.row_dot(0, &mat, 1), 188.0);
        // 1 + 16 + 49 + 100 = 166.
        assert_eq!(mat.row_dot(0, &mat, 0), 166.0);
    }
    #[test]
    fn test_row_dot_different_matrices() {
        let mat1 = sample_3x4();
        // mat2 is mat1 scaled by 10, so row0 . row0 becomes 10 * 166.
        let data2 = vec![
            10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0, 110.0, 120.0,
        ];
        let mat2 = FdMatrix::from_column_major(data2, 3, 4).unwrap();
        assert_eq!(mat1.row_dot(0, &mat2, 0), 1660.0);
    }
    #[test]
    fn test_row_l2_sq_identical() {
        let mat = sample_3x4();
        assert_eq!(mat.row_l2_sq(0, &mat, 0), 0.0);
        assert_eq!(mat.row_l2_sq(1, &mat, 1), 0.0);
    }
    #[test]
    fn test_row_l2_sq_different() {
        let mat = sample_3x4();
        // Rows 0 and 1 differ by exactly 1 in each of the 4 columns.
        assert_eq!(mat.row_l2_sq(0, &mat, 1), 4.0);
    }
    #[test]
    fn test_row_l2_sq_cross_matrix() {
        // 3-4-5 triangle: 3^2 + 4^2 = 25.
        let mat1 = FdMatrix::from_column_major(vec![0.0, 0.0], 1, 2).unwrap();
        let mat2 = FdMatrix::from_column_major(vec![3.0, 4.0], 1, 2).unwrap();
        assert_eq!(mat1.row_l2_sq(0, &mat2, 0), 25.0);
    }
    #[test]
    fn test_column_major_layout_matches_manual() {
        let n = 5;
        let m = 7;
        let data: Vec<f64> = (0..n * m).map(|x| x as f64).collect();
        let mat = FdMatrix::from_column_major(data.clone(), n, m).unwrap();
        for j in 0..m {
            for i in 0..n {
                assert_eq!(mat[(i, j)], data[i + j * n]);
            }
        }
    }
    #[test]
    fn test_iter_rows() {
        let mat = sample_3x4();
        let rows: Vec<Vec<f64>> = mat.iter_rows().collect();
        assert_eq!(rows.len(), 3);
        assert_eq!(rows[0], vec![1.0, 4.0, 7.0, 10.0]);
        assert_eq!(rows[1], vec![2.0, 5.0, 8.0, 11.0]);
        assert_eq!(rows[2], vec![3.0, 6.0, 9.0, 12.0]);
    }
    #[test]
    fn test_iter_rows_matches_rows() {
        let mat = sample_3x4();
        let from_iter: Vec<Vec<f64>> = mat.iter_rows().collect();
        let from_rows = mat.rows();
        assert_eq!(from_iter, from_rows);
    }
    #[test]
    fn test_iter_rows_partial() {
        let mat = sample_3x4();
        let first_two: Vec<Vec<f64>> = mat.iter_rows().take(2).collect();
        assert_eq!(first_two.len(), 2);
        assert_eq!(first_two[0], vec![1.0, 4.0, 7.0, 10.0]);
        assert_eq!(first_two[1], vec![2.0, 5.0, 8.0, 11.0]);
    }
    #[test]
    fn test_iter_rows_empty() {
        let mat = FdMatrix::zeros(0, 0);
        let rows: Vec<Vec<f64>> = mat.iter_rows().collect();
        assert!(rows.is_empty());
    }
    #[test]
    fn test_iter_rows_single_row() {
        let mat = FdMatrix::from_column_major(vec![1.0, 2.0, 3.0], 1, 3).unwrap();
        let rows: Vec<Vec<f64>> = mat.iter_rows().collect();
        assert_eq!(rows, vec![vec![1.0, 2.0, 3.0]]);
    }
    #[test]
    fn test_iter_rows_single_column() {
        let mat = FdMatrix::from_column_major(vec![1.0, 2.0, 3.0], 3, 1).unwrap();
        let rows: Vec<Vec<f64>> = mat.iter_rows().collect();
        assert_eq!(rows, vec![vec![1.0], vec![2.0], vec![3.0]]);
    }
    #[test]
    fn test_iter_columns() {
        let mat = sample_3x4();
        let cols: Vec<&[f64]> = mat.iter_columns().collect();
        assert_eq!(cols.len(), 4);
        assert_eq!(cols[0], &[1.0, 2.0, 3.0]);
        assert_eq!(cols[1], &[4.0, 5.0, 6.0]);
        assert_eq!(cols[2], &[7.0, 8.0, 9.0]);
        assert_eq!(cols[3], &[10.0, 11.0, 12.0]);
    }
    #[test]
    fn test_iter_columns_partial() {
        let mat = sample_3x4();
        let first_two: Vec<&[f64]> = mat.iter_columns().take(2).collect();
        assert_eq!(first_two.len(), 2);
        assert_eq!(first_two[0], &[1.0, 2.0, 3.0]);
        assert_eq!(first_two[1], &[4.0, 5.0, 6.0]);
    }
    #[test]
    fn test_iter_columns_empty() {
        let mat = FdMatrix::zeros(0, 0);
        let cols: Vec<&[f64]> = mat.iter_columns().collect();
        assert!(cols.is_empty());
    }
    #[test]
    fn test_iter_columns_single_column() {
        let mat = FdMatrix::from_column_major(vec![1.0, 2.0, 3.0], 3, 1).unwrap();
        let cols: Vec<&[f64]> = mat.iter_columns().collect();
        assert_eq!(cols, vec![&[1.0, 2.0, 3.0]]);
    }
    #[test]
    fn test_iter_columns_single_row() {
        let mat = FdMatrix::from_column_major(vec![1.0, 2.0, 3.0], 1, 3).unwrap();
        let cols: Vec<&[f64]> = mat.iter_columns().collect();
        // The explicit `&[f64]` cast unifies the element type of the expected Vec.
        assert_eq!(cols, vec![&[1.0_f64] as &[f64], &[2.0], &[3.0]]);
    }
    #[test]
    fn test_iter_rows_enumerate() {
        let mat = sample_3x4();
        for (i, row) in mat.iter_rows().enumerate() {
            assert_eq!(row, mat.row(i));
        }
    }
    #[test]
    fn test_iter_columns_enumerate() {
        let mat = sample_3x4();
        for (j, col) in mat.iter_columns().enumerate() {
            assert_eq!(col, mat.column(j));
        }
    }
}