use super::{
GeneralRankOneUpdate, MatDiagonal, MatTransposeVec, MatVec, MatrixFromDiagonal, MatrixIdentity,
RankOneUpdate, ScaleInPlace, SymmetricEigen, SymmetricEigenError,
};
/// A dense `f64` matrix stored in row-major order.
///
/// Entry `(i, j)` lives at `data[i * cols + j]`. The public constructors keep
/// `data.len() == rows * cols`.
#[derive(Clone, Debug, PartialEq)]
pub struct DenseMatrix {
    /// Entries in row-major order.
    data: Vec<f64>,
    /// Number of rows.
    rows: usize,
    /// Number of columns.
    cols: usize,
}
impl DenseMatrix {
    /// Builds a `rows × cols` matrix by copying `data`, read in row-major order.
    ///
    /// # Panics
    ///
    /// Panics if `rows * cols` overflows `usize`, or if `data.len()` is not
    /// exactly `rows * cols`.
    pub fn from_row_slice(rows: usize, cols: usize, data: &[f64]) -> Self {
        let len = Self::checked_len(rows, cols, "from_row_slice");
        assert_eq!(
            data.len(),
            len,
            "DenseMatrix::from_row_slice: expected {} entries for a {}×{} matrix, got {}",
            len,
            rows,
            cols,
            data.len()
        );
        Self {
            data: data.to_vec(),
            rows,
            cols,
        }
    }
    /// Builds a `rows × cols` matrix whose entry `(i, j)` is `f(i, j)`.
    ///
    /// `f` is called exactly once per entry, in row-major order.
    ///
    /// # Panics
    ///
    /// Panics if `rows * cols` overflows `usize`.
    pub fn from_fn<F: FnMut(usize, usize) -> f64>(rows: usize, cols: usize, mut f: F) -> Self {
        let mut data = Vec::with_capacity(Self::checked_len(rows, cols, "from_fn"));
        for i in 0..rows {
            for j in 0..cols {
                data.push(f(i, j));
            }
        }
        Self { data, rows, cols }
    }
    /// Number of rows.
    pub fn nrows(&self) -> usize {
        self.rows
    }
    /// Number of columns.
    pub fn ncols(&self) -> usize {
        self.cols
    }
    /// Returns entry `(i, j)`.
    ///
    /// # Panics
    ///
    /// Panics if `i >= self.nrows()` or `j >= self.ncols()`.
    pub fn get(&self, i: usize, j: usize) -> f64 {
        assert!(
            i < self.rows && j < self.cols,
            "DenseMatrix::get: index ({i}, {j}) out of bounds for a {}×{} matrix",
            self.rows,
            self.cols
        );
        self.data[i * self.cols + j]
    }
    /// Returns `rows * cols`, panicking with a message naming `caller` when the
    /// product does not fit in `usize`. An unchecked multiply would wrap in
    /// release builds and let an impossible shape through.
    fn checked_len(rows: usize, cols: usize, caller: &str) -> usize {
        rows.checked_mul(cols).unwrap_or_else(|| {
            panic!("DenseMatrix::{caller}: a {rows}×{cols} matrix has more entries than usize can count")
        })
    }
}
impl MatVec<Vec<f64>> for DenseMatrix {
    /// Returns the product `A·x`: entry `r` is the dot product of row `r` with `x`.
    ///
    /// # Panics
    ///
    /// Panics if `x.len()` differs from the number of columns.
    fn matvec(&self, x: &Vec<f64>) -> Vec<f64> {
        let width = self.cols;
        assert_eq!(
            x.len(),
            width,
            "matvec: x has length {} but the matrix has {} columns",
            x.len(),
            width
        );
        (0..self.rows)
            .map(|r| {
                let start = r * width;
                self.data[start..start + width]
                    .iter()
                    .zip(x)
                    .map(|(coef, &xj)| coef * xj)
                    .sum::<f64>()
            })
            .collect()
    }
}
impl MatTransposeVec<Vec<f64>> for DenseMatrix {
    /// Returns the product `Aᵀ·x` without materialising the transpose.
    ///
    /// # Panics
    ///
    /// Panics if `x.len()` differs from the number of rows.
    fn mat_transpose_vec(&self, x: &Vec<f64>) -> Vec<f64> {
        let (height, width) = (self.rows, self.cols);
        assert_eq!(
            x.len(),
            height,
            "mat_transpose_vec: x has length {} but the matrix has {} rows",
            x.len(),
            height
        );
        // Walk the rows in storage order (cache-friendly) and scatter each one,
        // scaled by the matching entry of `x`, into the column accumulator.
        x.iter()
            .enumerate()
            .fold(vec![0.0; width], |mut acc, (r, &scale)| {
                let start = r * width;
                for (slot, entry) in acc.iter_mut().zip(&self.data[start..start + width]) {
                    *slot += entry * scale;
                }
                acc
            })
    }
}
impl MatrixIdentity for DenseMatrix {
    /// Returns the `n × n` identity matrix (ones on the diagonal, zeros elsewhere).
    fn identity(n: usize) -> Self {
        Self::from_fn(n, n, |row, col| if row != col { 0.0 } else { 1.0 })
    }
}
impl ScaleInPlace for DenseMatrix {
    /// Multiplies every entry of the matrix by `scalar`, in place.
    fn scale_in_place(&mut self, scalar: f64) {
        self.data.iter_mut().for_each(|value| *value *= scalar);
    }
}
impl GeneralRankOneUpdate<Vec<f64>> for DenseMatrix {
    /// Applies `A ← A + alpha · u · vᵀ` in place, i.e. entry `(i, j)` grows by
    /// `alpha * u[i] * v[j]`.
    ///
    /// # Panics
    ///
    /// Panics if the matrix is not square, or if `u` / `v` do not match its
    /// row / column count.
    fn general_rank_one_update(&mut self, alpha: f64, u: &Vec<f64>, v: &Vec<f64>) {
        let (m, n) = (self.rows, self.cols);
        assert_eq!(
            m, n,
            "general_rank_one_update: matrix must be square, got {}x{}",
            m, n
        );
        assert_eq!(
            m,
            u.len(),
            "general_rank_one_update: matrix is {}x{} but u has length {}",
            m,
            n,
            u.len()
        );
        assert_eq!(
            n,
            v.len(),
            "general_rank_one_update: matrix is {}x{} but v has length {}",
            m,
            n,
            v.len()
        );
        // Fold alpha into the per-row factor so the inner loop is a single
        // multiply-add per entry.
        for (i, factor) in u.iter().map(|&ui| alpha * ui).enumerate() {
            let start = i * n;
            self.data[start..start + n]
                .iter_mut()
                .zip(v)
                .for_each(|(cell, &vj)| *cell += factor * vj);
        }
    }
}
impl RankOneUpdate<Vec<f64>> for DenseMatrix {
fn rank_one_update(&mut self, alpha: f64, v: &Vec<f64>) {
self.general_rank_one_update(alpha, v, v);
}
}
impl MatrixFromDiagonal<Vec<f64>> for DenseMatrix {
    /// Builds a square matrix with `diag` on the main diagonal and zeros elsewhere.
    /// The side length equals `diag.len()`.
    fn from_diagonal(diag: &Vec<f64>) -> Self {
        let size = diag.len();
        Self::from_fn(size, size, |row, col| if row != col { 0.0 } else { diag[row] })
    }
}
impl MatDiagonal<Vec<f64>> for DenseMatrix {
    /// Returns the main diagonal `[A(0,0), A(1,1), …]` of a square matrix.
    ///
    /// # Panics
    ///
    /// Panics if the matrix is not square.
    fn diagonal(&self) -> Vec<f64> {
        let (r, c) = (self.rows, self.cols);
        assert_eq!(r, c, "diagonal: matrix must be square, got {}x{}", r, c);
        // In row-major storage, consecutive diagonal entries sit `cols + 1` apart.
        (0..r).map(|k| self.data[k * (c + 1)]).collect()
    }
}
impl SymmetricEigen<Vec<f64>> for DenseMatrix {
    /// Computes an eigen-decomposition by delegating to
    /// `super::dense_eig::jacobi_eigen` on the row-major entry buffer.
    ///
    /// Returns `(b, eigenvalues)`. `b` is an `n × n` matrix whose storage is the
    /// eigenvector buffer returned by the helper, read row-major like every
    /// `DenseMatrix`.
    /// NOTE(review): whether eigenvectors end up as the columns or the rows of
    /// `b`, and whether the eigenvalues are sorted, is decided by `jacobi_eigen`.
    /// Confirm against that helper.
    ///
    /// Symmetry of `self` is not checked here.
    ///
    /// # Errors
    ///
    /// Returns `SymmetricEigenError::Failed` when `jacobi_eigen` returns `None`.
    ///
    /// # Panics
    ///
    /// Panics if the matrix is not square.
    fn try_eigh(&self) -> Result<(Self, Vec<f64>), SymmetricEigenError> {
        assert_eq!(
            self.rows, self.cols,
            "try_eigh: matrix must be square, got {}x{}",
            self.rows, self.cols
        );
        let n = self.rows;
        let (eigenvalues, eigenvectors) =
            super::dense_eig::jacobi_eigen(&self.data, n).ok_or(SymmetricEigenError::Failed)?;
        // The buffer comes from an external helper and is wrapped directly, which
        // bypasses the length check the public constructors perform. Re-check the
        // `data.len() == rows * cols` invariant here so a mismatch surfaces
        // immediately instead of as a confusing slice panic later.
        debug_assert_eq!(
            eigenvectors.len(),
            n * n,
            "try_eigh: eigenvector buffer has {} entries, expected {} for a {}x{} matrix",
            eigenvectors.len(),
            n * n,
            n,
            n
        );
        let b = Self {
            data: eigenvectors,
            rows: n,
            cols: n,
        };
        Ok((b, eigenvalues))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    /// The 2×3 matrix [[1, 2, 3], [4, 5, 6]] shared by most tests.
    fn fixture() -> DenseMatrix {
        DenseMatrix::from_row_slice(2, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
    }
    #[test]
    fn shape_and_entry_access() {
        let a = fixture();
        assert_eq!(a.nrows(), 2);
        assert_eq!(a.ncols(), 3);
        assert_eq!(a.get(0, 2), 3.0);
        assert_eq!(a.get(1, 0), 4.0);
    }
    #[test]
    #[should_panic(expected = "out of bounds")]
    fn get_rejects_out_of_bounds() {
        let _ = fixture().get(2, 0);
    }
    #[test]
    fn from_fn_matches_from_row_slice() {
        let by_fn = DenseMatrix::from_fn(2, 3, |i, j| (i * 3 + j + 1) as f64);
        assert_eq!(by_fn, fixture());
    }
    #[test]
    fn matvec_computes_a_times_x() {
        let a = fixture();
        let y = a.matvec(&vec![1.0, 1.0, 1.0]);
        assert_eq!(y, vec![6.0, 15.0]);
    }
    #[test]
    fn mat_transpose_vec_computes_a_transpose_times_x() {
        let a = fixture();
        let y = a.mat_transpose_vec(&vec![1.0, 1.0]);
        assert_eq!(y, vec![5.0, 7.0, 9.0]);
    }
    #[test]
    fn matvec_and_transpose_are_consistent() {
        let a = fixture();
        let x = vec![0.5, -1.0, 2.0];
        let v = vec![3.0, -2.0];
        let ax = a.matvec(&x);
        let atv = a.mat_transpose_vec(&v);
        let lhs: f64 = ax.iter().zip(&v).map(|(p, q)| p * q).sum();
        let rhs: f64 = x.iter().zip(&atv).map(|(p, q)| p * q).sum();
        assert!((lhs - rhs).abs() < 1e-12, "lhs={lhs}, rhs={rhs}");
    }
    #[test]
    fn zero_sized_products_are_well_defined() {
        let empty: Vec<f64> = Vec::new();
        let no_cols = DenseMatrix::from_row_slice(2, 0, &[]);
        assert_eq!(no_cols.matvec(&empty), vec![0.0, 0.0]);
        assert_eq!(no_cols.mat_transpose_vec(&vec![1.0, 1.0]), Vec::<f64>::new());
        let no_rows = DenseMatrix::from_row_slice(0, 3, &[]);
        assert_eq!(no_rows.matvec(&vec![1.0, 2.0, 3.0]), Vec::<f64>::new());
        assert_eq!(no_rows.mat_transpose_vec(&empty), vec![0.0, 0.0, 0.0]);
    }
    #[test]
    #[should_panic(expected = "from_row_slice")]
    fn from_row_slice_rejects_wrong_length() {
        let _ = DenseMatrix::from_row_slice(2, 2, &[1.0, 2.0, 3.0]);
    }
    #[test]
    #[should_panic(expected = "matvec")]
    fn matvec_rejects_length_mismatch() {
        let a = fixture();
        let _ = a.matvec(&vec![1.0, 1.0]);
    }
    #[test]
    #[should_panic(expected = "mat_transpose_vec")]
    fn mat_transpose_vec_rejects_length_mismatch() {
        let a = fixture();
        let _ = a.mat_transpose_vec(&vec![1.0, 1.0, 1.0]);
    }
    #[test]
    fn identity_is_square_with_unit_diagonal() {
        let id = DenseMatrix::identity(3);
        assert_eq!(id.nrows(), 3);
        assert_eq!(id.ncols(), 3);
        for i in 0..3 {
            for j in 0..3 {
                assert_eq!(id.get(i, j), if i == j { 1.0 } else { 0.0 });
            }
        }
    }
    #[test]
    fn scale_in_place_multiplies_every_entry() {
        let mut a = fixture();
        a.scale_in_place(2.0);
        assert_eq!(
            a,
            DenseMatrix::from_row_slice(2, 3, &[2.0, 4.0, 6.0, 8.0, 10.0, 12.0])
        );
    }
    #[test]
    fn general_rank_one_update_symmetric_case() {
        let mut a = DenseMatrix::identity(2);
        let v = vec![1.0, 2.0];
        a.general_rank_one_update(1.0, &v, &v);
        assert_eq!(a, DenseMatrix::from_row_slice(2, 2, &[2.0, 2.0, 2.0, 5.0]));
    }
    #[test]
    fn general_rank_one_update_asymmetric_case() {
        let mut a = DenseMatrix::from_row_slice(2, 2, &[0.0, 0.0, 0.0, 0.0]);
        a.general_rank_one_update(2.0, &vec![1.0, 0.0], &vec![3.0, 4.0]);
        assert_eq!(a, DenseMatrix::from_row_slice(2, 2, &[6.0, 8.0, 0.0, 0.0]));
    }
    #[test]
    #[should_panic(expected = "general_rank_one_update")]
    fn general_rank_one_update_rejects_non_square() {
        let mut a = fixture();
        a.general_rank_one_update(1.0, &vec![1.0, 1.0], &vec![1.0, 1.0, 1.0]);
    }
    #[test]
    #[should_panic(expected = "u has length 1")]
    fn general_rank_one_update_rejects_short_u() {
        let mut a = DenseMatrix::identity(2);
        a.general_rank_one_update(1.0, &vec![1.0], &vec![1.0, 1.0]);
    }
    #[test]
    fn rank_one_update_adds_symmetric_outer_product() {
        let mut a = DenseMatrix::identity(2);
        a.rank_one_update(0.5, &vec![2.0, 2.0]);
        assert_eq!(a, DenseMatrix::from_row_slice(2, 2, &[3.0, 2.0, 2.0, 3.0]));
    }
    #[test]
    fn from_diagonal_and_diagonal_round_trip() {
        let d = vec![2.0, -1.0, 3.5];
        let m = DenseMatrix::from_diagonal(&d);
        assert_eq!(
            m,
            DenseMatrix::from_row_slice(3, 3, &[2.0, 0.0, 0.0, 0.0, -1.0, 0.0, 0.0, 0.0, 3.5])
        );
        assert_eq!(m.diagonal(), d);
        assert_eq!(DenseMatrix::identity(0).diagonal(), Vec::<f64>::new());
    }
    #[test]
    #[should_panic(expected = "diagonal: matrix must be square")]
    fn diagonal_rejects_non_square() {
        let _ = fixture().diagonal();
    }
}