use super::Scalar;
use super::{
AddDiagonalVectorInPlace, GeneralRankOneUpdate, GramMatrix, LinearSolveError, LinearSolveSpd,
MatDiagonal, MatTransposeVec, MatVec, MatrixFromDiagonal, MatrixIdentity, RankOneUpdate,
ScaleInPlace, SymmetricEigen, SymmetricEigenError,
};
/// Dense matrix whose entries are stored contiguously in row-major order:
/// entry `(i, j)` lives at `data[i * cols + j]`.
///
/// The element type parameter `F` defaults to `f64`.
#[derive(Clone, Debug, PartialEq)]
pub struct DenseMatrix<F = f64> {
    // Entries laid out row by row. The indexing and slicing code in this file
    // relies on `data.len() == rows * cols`.
    data: Vec<F>,
    // Number of rows.
    rows: usize,
    // Number of columns (also the stride between consecutive rows in `data`).
    cols: usize,
}
impl<F: Scalar> DenseMatrix<F> {
    /// Builds a `rows × cols` matrix by copying `data`, read in row-major order
    /// (the first `cols` values form row 0, the next `cols` form row 1, ...).
    ///
    /// # Panics
    /// Panics if `rows * cols` overflows `usize`, or if `data.len()` is not
    /// exactly `rows * cols`.
    pub fn from_row_slice(rows: usize, cols: usize, data: &[F]) -> Self {
        // `checked_mul` instead of `*`: in release builds an overflowing product
        // wraps silently, which could let a too-short buffer pass the length
        // check below and break the `data.len() == rows * cols` invariant.
        let len = rows.checked_mul(cols).unwrap_or_else(|| {
            panic!("DenseMatrix::from_row_slice: {rows}×{cols} entry count overflows usize")
        });
        assert_eq!(
            data.len(),
            len,
            "DenseMatrix::from_row_slice: expected {} entries for a {}×{} matrix, got {}",
            len,
            rows,
            cols,
            data.len()
        );
        Self {
            data: data.to_vec(),
            rows,
            cols,
        }
    }

    /// Builds a `rows × cols` matrix whose entry `(i, j)` is `f(i, j)`.
    ///
    /// `f` is called exactly once per entry, in row-major order.
    ///
    /// # Panics
    /// Panics if `rows * cols` overflows `usize`.
    pub fn from_fn<G: FnMut(usize, usize) -> F>(rows: usize, cols: usize, mut f: G) -> Self {
        // Checked for the same reason as in `from_row_slice`: a wrapped product
        // would reserve the wrong capacity.
        let len = rows.checked_mul(cols).unwrap_or_else(|| {
            panic!("DenseMatrix::from_fn: {rows}×{cols} entry count overflows usize")
        });
        let mut data = Vec::with_capacity(len);
        for i in 0..rows {
            data.extend((0..cols).map(|j| f(i, j)));
        }
        Self { data, rows, cols }
    }

    /// Number of rows.
    pub fn nrows(&self) -> usize {
        self.rows
    }

    /// Number of columns.
    pub fn ncols(&self) -> usize {
        self.cols
    }

    /// Returns a copy of the entry at row `i`, column `j`.
    ///
    /// # Panics
    /// Panics if `i >= nrows()` or `j >= ncols()`.
    pub fn get(&self, i: usize, j: usize) -> F {
        assert!(
            i < self.rows && j < self.cols,
            "DenseMatrix::get: index ({i}, {j}) out of bounds for a {}×{} matrix",
            self.rows,
            self.cols
        );
        self.data[i * self.cols + j]
    }
}
impl<F: Scalar> MatVec<Vec<F>> for DenseMatrix<F> {
    /// Computes `y = A·x`, where `x` has one entry per column of `A`.
    ///
    /// Each output entry is the dot product of the corresponding row of `A`
    /// with `x`.
    ///
    /// # Panics
    /// Panics if `x.len()` differs from the number of columns.
    fn matvec(&self, x: &Vec<F>) -> Vec<F> {
        let width = self.cols;
        assert_eq!(
            x.len(),
            width,
            "matvec: x has length {} but the matrix has {} columns",
            x.len(),
            width
        );
        (0..self.rows)
            .map(|r| {
                let start = r * width;
                self.data[start..start + width]
                    .iter()
                    .zip(x)
                    .map(|(&a, &xv)| a * xv)
                    .sum::<F>()
            })
            .collect()
    }
}
impl<F: Scalar> MatTransposeVec<Vec<F>> for DenseMatrix<F> {
    /// Computes `y = Aᵀ·x`, where `x` has one entry per row of `A`.
    ///
    /// Works row by row: row `r` of `A`, scaled by `x[r]`, is added into the
    /// accumulator, so `A` is never transposed explicitly.
    ///
    /// # Panics
    /// Panics if `x.len()` differs from the number of rows.
    fn mat_transpose_vec(&self, x: &Vec<F>) -> Vec<F> {
        assert_eq!(
            x.len(),
            self.rows,
            "mat_transpose_vec: x has length {} but the matrix has {} rows",
            x.len(),
            self.rows
        );
        let width = self.cols;
        x.iter()
            .enumerate()
            .fold(vec![F::zero(); width], |mut acc, (r, &scale)| {
                let row = &self.data[r * width..(r + 1) * width];
                acc.iter_mut()
                    .zip(row)
                    .for_each(|(out, &a)| *out = *out + a * scale);
                acc
            })
    }
}
impl<F: Scalar> MatrixIdentity for DenseMatrix<F> {
    /// Returns the `n × n` identity matrix: ones on the diagonal, zeros elsewhere.
    fn identity(n: usize) -> Self {
        Self::from_fn(n, n, |r, c| if r != c { F::zero() } else { F::one() })
    }
}
impl<F: Scalar> ScaleInPlace<F> for DenseMatrix<F> {
    /// Multiplies every stored entry by `scalar`, in place.
    fn scale_in_place(&mut self, scalar: F) {
        self.data
            .iter_mut()
            .for_each(|value| *value = *value * scalar);
    }
}
impl<F: Scalar> GeneralRankOneUpdate<Vec<F>, F> for DenseMatrix<F> {
    /// Applies `A ← A + alpha · u · vᵀ` in place.
    ///
    /// # Panics
    /// Panics if the matrix is not square, if `u.len()` differs from the row
    /// count, or if `v.len()` differs from the column count.
    fn general_rank_one_update(&mut self, alpha: F, u: &Vec<F>, v: &Vec<F>) {
        let (m, n) = (self.rows, self.cols);
        assert_eq!(
            m, n,
            "general_rank_one_update: matrix must be square, got {}x{}",
            m, n
        );
        assert_eq!(
            m,
            u.len(),
            "general_rank_one_update: matrix is {}x{} but u has length {}",
            m,
            n,
            u.len()
        );
        assert_eq!(
            n,
            v.len(),
            "general_rank_one_update: matrix is {}x{} but v has length {}",
            m,
            n,
            v.len()
        );
        u.iter().enumerate().for_each(|(r, &ur)| {
            // Fold alpha into the row coefficient once, outside the column loop.
            let coeff = alpha * ur;
            let row = &mut self.data[r * n..(r + 1) * n];
            row.iter_mut()
                .zip(v)
                .for_each(|(slot, &vc)| *slot = *slot + coeff * vc);
        });
    }
}
impl<F: Scalar> RankOneUpdate<Vec<F>, F> for DenseMatrix<F> {
fn rank_one_update(&mut self, alpha: F, v: &Vec<F>) {
self.general_rank_one_update(alpha, v, v);
}
}
impl<F: Scalar> MatrixFromDiagonal<Vec<F>> for DenseMatrix<F> {
    /// Builds a square matrix with `diag` on its main diagonal and zeros elsewhere.
    fn from_diagonal(diag: &Vec<F>) -> Self {
        let size = diag.len();
        Self::from_fn(size, size, |r, c| if r != c { F::zero() } else { diag[r] })
    }
}
impl<F: Scalar> MatDiagonal<Vec<F>> for DenseMatrix<F> {
    /// Returns a copy of the main diagonal of a square matrix.
    ///
    /// In row-major storage of an `n × n` matrix, consecutive diagonal entries
    /// are `n + 1` slots apart, so a strided walk over `data` visits exactly them.
    ///
    /// # Panics
    /// Panics if the matrix is not square.
    fn diagonal(&self) -> Vec<F> {
        assert_eq!(
            self.rows, self.cols,
            "diagonal: matrix must be square, got {}x{}",
            self.rows, self.cols
        );
        let stride = self.cols + 1;
        self.data.iter().step_by(stride).copied().collect()
    }
}
impl<F: Scalar> SymmetricEigen<Vec<F>> for DenseMatrix<F> {
    /// Eigen-decomposes the matrix via `super::dense_eig::jacobi_eigen`.
    ///
    /// On success it returns `(eigenvectors, eigenvalues)`, with the eigenvector
    /// data wrapped as an `n × n` matrix. How vectors are laid out (as rows or
    /// as columns) is decided by `jacobi_eigen` and is not visible here.
    ///
    /// # Errors
    /// Returns `SymmetricEigenError::Failed` when `jacobi_eigen` returns `None`.
    ///
    /// # Panics
    /// Panics if the matrix is not square.
    fn try_eigh(&self) -> Result<(Self, Vec<F>), SymmetricEigenError> {
        assert_eq!(
            self.rows, self.cols,
            "try_eigh: matrix must be square, got {}x{}",
            self.rows, self.cols
        );
        let n = self.rows;
        match super::dense_eig::jacobi_eigen(&self.data, n) {
            Some((eigenvalues, eigenvectors)) => {
                let vectors = Self {
                    data: eigenvectors,
                    rows: n,
                    cols: n,
                };
                Ok((vectors, eigenvalues))
            }
            None => Err(SymmetricEigenError::Failed),
        }
    }
}
impl<F: Scalar> GramMatrix for DenseMatrix<F> {
    /// Returns the Gram matrix `AᵀA`, of size `ncols × ncols`.
    ///
    /// Since `(AᵀA)[i][j] = Σ_k A[k][i] · A[k][j]`, the result is accumulated
    /// by adding the outer product of each row of `A` with itself. A matrix
    /// with no columns yields the empty `0 × 0` matrix. A matrix with no rows
    /// yields an all-zero `ncols × ncols` matrix.
    fn gram(&self) -> Self {
        let n = self.cols;
        let mut data = vec![F::zero(); n * n];
        // `chunks_exact(0)` panics, so the zero-column case must skip the loop;
        // its (empty) result is already correct.
        if n > 0 {
            // Row-major storage: each `n`-sized chunk is one row of `A`.
            for row in self.data.chunks_exact(n) {
                for (i, &ri) in row.iter().enumerate() {
                    let g_row = &mut data[i * n..(i + 1) * n];
                    for (gij, &rj) in g_row.iter_mut().zip(row) {
                        *gij = *gij + ri * rj;
                    }
                }
            }
        }
        Self {
            data,
            rows: n,
            cols: n,
        }
    }
}
impl<F: Scalar> AddDiagonalVectorInPlace<Vec<F>> for DenseMatrix<F> {
    /// Adds `diag[i]` to entry `(i, i)` for every `i`, in place.
    ///
    /// Off-diagonal entries are untouched. In row-major storage of an `n × n`
    /// matrix, diagonal entries are `n + 1` slots apart.
    ///
    /// # Panics
    /// Panics if the matrix is not square or `diag.len()` differs from its size.
    fn add_diagonal_vector_in_place(&mut self, diag: &Vec<F>) {
        let (m, n) = (self.rows, self.cols);
        assert_eq!(
            m, n,
            "add_diagonal_vector_in_place: matrix must be square, got {}x{}",
            m, n
        );
        assert_eq!(
            m,
            diag.len(),
            "add_diagonal_vector_in_place: matrix is {}x{} but diag has length {}",
            m,
            n,
            diag.len()
        );
        let stride = n + 1;
        for (slot, &d) in self.data.iter_mut().step_by(stride).zip(diag) {
            *slot = *slot + d;
        }
    }
}
impl<F: Scalar> LinearSolveSpd<Vec<F>> for DenseMatrix<F> {
    /// Solves `A·x = b` via `super::dense_chol::cholesky_solve_spd`.
    ///
    /// # Errors
    /// Returns `LinearSolveError::NotPositiveDefinite` whenever the underlying
    /// solver returns `None`.
    ///
    /// # Panics
    /// Panics if the matrix is not square or `b.len()` differs from its size.
    fn solve_spd(&self, b: &Vec<F>) -> Result<Vec<F>, LinearSolveError> {
        let (m, n) = (self.rows, self.cols);
        assert_eq!(
            m, n,
            "solve_spd: matrix must be square, got {}x{}",
            m, n
        );
        assert_eq!(
            m,
            b.len(),
            "solve_spd: matrix is {}x{} but rhs has length {}",
            m,
            n,
            b.len()
        );
        match super::dense_chol::cholesky_solve_spd(&self.data, m, b) {
            Some(solution) => Ok(solution),
            None => Err(LinearSolveError::NotPositiveDefinite),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing floating-point results.
    const TOL: f64 = 1e-12;

    /// The 2×3 matrix `[[1, 2, 3], [4, 5, 6]]` shared by most tests.
    fn fixture() -> DenseMatrix {
        DenseMatrix::from_row_slice(2, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
    }

    /// Plain dot product of two slices (truncated to the shorter one).
    fn dot(p: &[f64], q: &[f64]) -> f64 {
        p.iter().zip(q).map(|(a, b)| a * b).sum()
    }

    /// Asserts that `actual` is within `TOL` of `expected`; `label` names the value on failure.
    fn assert_close(label: &str, actual: f64, expected: f64) {
        assert!((actual - expected).abs() < TOL, "{label} = {actual}");
    }

    #[test]
    fn shape_and_entry_access() {
        let m = fixture();
        assert_eq!((m.nrows(), m.ncols()), (2, 3));
        assert_eq!(m.get(0, 2), 3.0);
        assert_eq!(m.get(1, 0), 4.0);
    }

    #[test]
    fn from_fn_matches_from_row_slice() {
        let built = DenseMatrix::from_fn(2, 3, |r, c| (3 * r + c + 1) as f64);
        assert_eq!(built, fixture());
    }

    #[test]
    fn matvec_computes_a_times_x() {
        let ones = vec![1.0; 3];
        assert_eq!(fixture().matvec(&ones), vec![6.0, 15.0]);
    }

    #[test]
    fn mat_transpose_vec_computes_a_transpose_times_x() {
        let ones = vec![1.0; 2];
        assert_eq!(fixture().mat_transpose_vec(&ones), vec![5.0, 7.0, 9.0]);
    }

    #[test]
    fn matvec_and_transpose_are_consistent() {
        // The adjoint identity <A x, v> == <x, Aᵀ v> must hold for any x, v.
        let m = fixture();
        let x = vec![0.5, -1.0, 2.0];
        let v = vec![3.0, -2.0];
        let lhs = dot(&m.matvec(&x), &v);
        let rhs = dot(&x, &m.mat_transpose_vec(&v));
        assert!((lhs - rhs).abs() < TOL, "lhs={lhs}, rhs={rhs}");
    }

    #[test]
    #[should_panic(expected = "from_row_slice")]
    fn from_row_slice_rejects_wrong_length() {
        let _ = DenseMatrix::from_row_slice(2, 2, &[1.0, 2.0, 3.0]);
    }

    #[test]
    #[should_panic(expected = "matvec")]
    fn matvec_rejects_length_mismatch() {
        let _ = fixture().matvec(&vec![1.0, 1.0]);
    }

    #[test]
    #[should_panic(expected = "mat_transpose_vec")]
    fn mat_transpose_vec_rejects_length_mismatch() {
        let _ = fixture().mat_transpose_vec(&vec![1.0, 1.0, 1.0]);
    }

    #[test]
    fn identity_is_square_with_unit_diagonal() {
        let id: DenseMatrix = MatrixIdentity::identity(3);
        assert_eq!((id.nrows(), id.ncols()), (3, 3));
        for (i, j) in (0..3).flat_map(|i| (0..3).map(move |j| (i, j))) {
            let expected = if i == j { 1.0 } else { 0.0 };
            assert_eq!(id.get(i, j), expected);
        }
    }

    #[test]
    fn scale_in_place_multiplies_every_entry() {
        let mut m = fixture();
        m.scale_in_place(2.0);
        let doubled = DenseMatrix::from_row_slice(2, 3, &[2.0, 4.0, 6.0, 8.0, 10.0, 12.0]);
        assert_eq!(m, doubled);
    }

    #[test]
    fn general_rank_one_update_symmetric_case() {
        // I + v vᵀ with v = (1, 2).
        let mut m: DenseMatrix = MatrixIdentity::identity(2);
        let v = vec![1.0, 2.0];
        m.general_rank_one_update(1.0, &v, &v);
        assert_eq!(m, DenseMatrix::from_row_slice(2, 2, &[2.0, 2.0, 2.0, 5.0]));
    }

    #[test]
    fn general_rank_one_update_asymmetric_case() {
        let mut m = DenseMatrix::from_row_slice(2, 2, &[0.0; 4]);
        m.general_rank_one_update(2.0, &vec![1.0, 0.0], &vec![3.0, 4.0]);
        assert_eq!(m, DenseMatrix::from_row_slice(2, 2, &[6.0, 8.0, 0.0, 0.0]));
    }

    #[test]
    #[should_panic(expected = "general_rank_one_update")]
    fn general_rank_one_update_rejects_non_square() {
        let mut m = fixture();
        m.general_rank_one_update(1.0, &vec![1.0, 1.0], &vec![1.0, 1.0, 1.0]);
    }

    #[test]
    fn gram_computes_a_transpose_a() {
        let g = fixture().gram();
        assert_eq!((g.nrows(), g.ncols()), (3, 3));
        let expected = DenseMatrix::from_row_slice(
            3,
            3,
            &[17.0, 22.0, 27.0, 22.0, 29.0, 36.0, 27.0, 36.0, 45.0],
        );
        assert_eq!(g, expected);
    }

    #[test]
    fn gram_is_symmetric_positive_semidefinite_diagonal() {
        let g = fixture().gram();
        assert_eq!(g.get(0, 0), 17.0);
        assert_eq!(g.get(2, 2), 45.0);
        assert_eq!(g.get(0, 2), g.get(2, 0));
    }

    #[test]
    fn add_diagonal_vector_adds_to_diagonal_only() {
        let mut m = DenseMatrix::from_row_slice(2, 2, &[1.0, 2.0, 3.0, 4.0]);
        m.add_diagonal_vector_in_place(&vec![10.0, 20.0]);
        let expected = DenseMatrix::from_row_slice(2, 2, &[11.0, 2.0, 3.0, 24.0]);
        assert_eq!(m, expected);
    }

    #[test]
    #[should_panic(expected = "add_diagonal_vector_in_place")]
    fn add_diagonal_vector_rejects_non_square() {
        let mut m = fixture();
        m.add_diagonal_vector_in_place(&vec![1.0, 1.0]);
    }

    #[test]
    fn solve_spd_round_trips() {
        // [[4, 1], [1, 3]] x = (1, 2) has the exact solution (1/11, 7/11).
        let m: DenseMatrix = DenseMatrix::from_row_slice(2, 2, &[4.0, 1.0, 1.0, 3.0]);
        let x = m.solve_spd(&vec![1.0, 2.0]).unwrap();
        assert_close("x[0]", x[0], 1.0 / 11.0);
        assert_close("x[1]", x[1], 7.0 / 11.0);
    }

    #[test]
    fn solve_spd_rejects_non_positive_definite() {
        // Symmetric with eigenvalues 3 and -1, hence indefinite.
        let m: DenseMatrix = DenseMatrix::from_row_slice(2, 2, &[1.0, 2.0, 2.0, 1.0]);
        let outcome = m.solve_spd(&vec![1.0, 1.0]);
        assert_eq!(outcome, Err(LinearSolveError::NotPositiveDefinite));
    }

    #[test]
    fn gram_damp_solve_pipeline() {
        // JᵀJ = [[2, 1], [1, 2]]; adding I gives [[3, 1], [1, 3]], and solving
        // against (1, 0) yields (3/8, -1/8).
        let j: DenseMatrix = DenseMatrix::from_row_slice(3, 2, &[1.0, 0.0, 1.0, 1.0, 0.0, 1.0]);
        let mut g = j.gram();
        assert_eq!(g, DenseMatrix::from_row_slice(2, 2, &[2.0, 1.0, 1.0, 2.0]));
        g.add_diagonal_vector_in_place(&vec![1.0, 1.0]);
        let x = g.solve_spd(&vec![1.0, 0.0]).unwrap();
        assert_close("x[0]", x[0], 3.0 / 8.0);
        assert_close("x[1]", x[1], -1.0 / 8.0);
    }
}