use super::Vector;
use serde::{Deserialize, Serialize};
/// Dense two-dimensional matrix stored in one flat buffer.
///
/// The element at `(row, col)` lives at `data[row * cols + col]`, i.e. rows are
/// laid out one after another (row-major order).
///
/// NOTE(review): `Deserialize` is derived, so a deserialized value is not passed
/// through `from_vec`'s length check — confirm whether inputs are trusted.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Matrix<T> {
/// Flat element storage, row after row.
data: Vec<T>,
/// Number of rows.
rows: usize,
/// Number of columns; also the distance in `data` between vertically adjacent elements.
cols: usize,
}
impl<T: Copy> Matrix<T> {
    /// Builds a `rows` x `cols` matrix from a flat, row-major buffer.
    ///
    /// # Errors
    ///
    /// Returns an error when `data.len()` is not exactly `rows * cols`, including
    /// the case where that product does not fit in a `usize`.
    pub fn from_vec(rows: usize, cols: usize, data: Vec<T>) -> Result<Self, &'static str> {
        // `checked_mul` keeps an overflowing shape from wrapping around to a small
        // number that happens to match `data.len()`.
        if rows.checked_mul(cols) != Some(data.len()) {
            return Err("Data length must equal rows * cols");
        }
        Ok(Self { data, rows, cols })
    }

    /// Returns the dimensions as `(rows, cols)`.
    #[must_use]
    pub fn shape(&self) -> (usize, usize) {
        (self.rows, self.cols)
    }

    /// Returns the number of rows.
    #[must_use]
    pub fn n_rows(&self) -> usize {
        self.rows
    }

    /// Returns the number of columns.
    #[must_use]
    pub fn n_cols(&self) -> usize {
        self.cols
    }

    /// Returns a copy of the element at `(row, col)`.
    ///
    /// # Panics
    ///
    /// Panics if `row >= n_rows()` or `col >= n_cols()`.
    #[must_use]
    pub fn get(&self, row: usize, col: usize) -> T {
        self.data[self.flat_index(row, col)]
    }

    /// Overwrites the element at `(row, col)` with `value`.
    ///
    /// # Panics
    ///
    /// Panics if `row >= n_rows()` or `col >= n_cols()`.
    pub fn set(&mut self, row: usize, col: usize, value: T) {
        let idx = self.flat_index(row, col);
        self.data[idx] = value;
    }

    /// Copies row `row_idx` into a new [`Vector`].
    ///
    /// # Panics
    ///
    /// Panics if the row's element range lies outside the underlying storage.
    #[must_use]
    pub fn row(&self, row_idx: usize) -> Vector<T> {
        let start = row_idx * self.cols;
        Vector::from_slice(&self.data[start..start + self.cols])
    }

    /// Copies column `col_idx` into a new [`Vector`], top to bottom.
    ///
    /// # Panics
    ///
    /// Panics if any computed element index lies outside the underlying storage.
    #[must_use]
    pub fn column(&self, col_idx: usize) -> Vector<T> {
        let mut values = Vec::with_capacity(self.rows);
        for row in 0..self.rows {
            values.push(self.data[row * self.cols + col_idx]);
        }
        Vector::from_vec(values)
    }

    /// Borrows the flat, row-major element storage.
    #[must_use]
    pub fn as_slice(&self) -> &[T] {
        &self.data
    }

    /// Maps `(row, col)` to its position in `data`, rejecting coordinates outside
    /// the matrix.
    ///
    /// Checking the column separately matters: without it, `col >= cols` would
    /// silently address an element of a following row instead of failing.
    fn flat_index(&self, row: usize, col: usize) -> usize {
        assert!(
            row < self.rows && col < self.cols,
            "index out of bounds: ({}, {}) is outside a {}x{} matrix",
            row,
            col,
            self.rows,
            self.cols
        );
        row * self.cols + col
    }
}
impl Matrix<f32> {
    /// Creates a `rows` x `cols` matrix with every element set to `0.0`.
    ///
    /// # Panics
    ///
    /// Panics if `rows * cols` does not fit in a `usize`.
    #[must_use]
    pub fn zeros(rows: usize, cols: usize) -> Self {
        let len = rows
            .checked_mul(cols)
            .expect("matrix dimensions overflow usize");
        Self {
            data: vec![0.0; len],
            rows,
            cols,
        }
    }

    /// Creates a `rows` x `cols` matrix with every element set to `1.0`.
    ///
    /// # Panics
    ///
    /// Panics if `rows * cols` does not fit in a `usize`.
    #[must_use]
    pub fn ones(rows: usize, cols: usize) -> Self {
        let len = rows
            .checked_mul(cols)
            .expect("matrix dimensions overflow usize");
        Self {
            data: vec![1.0; len],
            rows,
            cols,
        }
    }

    /// Creates the `n` x `n` identity matrix.
    ///
    /// # Panics
    ///
    /// Panics if `n * n` does not fit in a `usize`.
    #[must_use]
    pub fn eye(n: usize) -> Self {
        let len = n.checked_mul(n).expect("matrix dimensions overflow usize");
        let mut data = vec![0.0; len];
        // Diagonal entries sit at flat indices 0, n + 1, 2 * (n + 1), ...
        for diagonal in data.iter_mut().step_by(n + 1) {
            *diagonal = 1.0;
        }
        Self {
            data,
            rows: n,
            cols: n,
        }
    }

    /// Returns the transpose as a new matrix.
    ///
    /// The copy proceeds in `TILE` x `TILE` sub-blocks (presumably to keep the
    /// strided writes cache-friendly); the result does not depend on the tile size.
    #[must_use]
    pub fn transpose(&self) -> Self {
        const TILE: usize = 32;

        let (rows, cols) = (self.rows, self.cols);
        let mut data = vec![0.0; rows * cols];
        for row_start in (0..rows).step_by(TILE) {
            let row_end = (row_start + TILE).min(rows);
            for col_start in (0..cols).step_by(TILE) {
                let col_end = (col_start + TILE).min(cols);
                for i in row_start..row_end {
                    let src_base = i * cols;
                    for j in col_start..col_end {
                        // Element (i, j) of `self` becomes element (j, i) of the result.
                        data[j * rows + i] = self.data[src_base + j];
                    }
                }
            }
        }
        Self {
            data,
            rows: cols,
            cols: rows,
        }
    }

    /// Computes the matrix product `self * other`.
    ///
    /// The loops run in i-k-j order so the innermost loop walks contiguous rows of
    /// both `other` and the result. Each output element still accumulates its
    /// products in ascending `k` starting from `0.0`, so the rounding is the same
    /// as for a plain per-element dot product.
    ///
    /// # Errors
    ///
    /// Returns an error when `self.n_cols() != other.n_rows()`.
    pub fn matmul(&self, other: &Self) -> Result<Self, &'static str> {
        if self.cols != other.rows {
            return Err("Matrix dimensions don't match for multiplication");
        }
        let inner = self.cols;
        let out_cols = other.cols;
        let mut result = vec![0.0_f32; self.rows * out_cols];
        for i in 0..self.rows {
            let out_row = &mut result[i * out_cols..(i + 1) * out_cols];
            for k in 0..inner {
                let lhs = self.data[i * inner + k];
                let rhs_row = &other.data[k * out_cols..(k + 1) * out_cols];
                for (acc, &rhs) in out_row.iter_mut().zip(rhs_row) {
                    *acc += lhs * rhs;
                }
            }
        }
        Ok(Self {
            data: result,
            rows: self.rows,
            cols: out_cols,
        })
    }

    /// Computes the matrix-vector product `self * vec`.
    ///
    /// Element `i` of the result is the dot product of row `i` with `vec`.
    ///
    /// # Errors
    ///
    /// Returns an error when `self.n_cols() != vec.len()`.
    pub fn matvec(&self, vec: &Vector<f32>) -> Result<Vector<f32>, &'static str> {
        if self.cols != vec.len() {
            return Err("Matrix columns must match vector length");
        }
        let products: Vec<f32> = (0..self.rows).map(|i| self.row(i).dot(vec)).collect();
        Ok(Vector::from_vec(products))
    }

    /// Adds `other` to `self` element by element.
    ///
    /// # Errors
    ///
    /// Returns an error when the two shapes differ.
    pub fn add(&self, other: &Self) -> Result<Self, &'static str> {
        if self.rows != other.rows || self.cols != other.cols {
            return Err("Matrix dimensions must match for addition");
        }
        let data: Vec<f32> = self
            .data
            .iter()
            .zip(&other.data)
            .map(|(lhs, rhs)| lhs + rhs)
            .collect();
        Ok(Self {
            data,
            rows: self.rows,
            cols: self.cols,
        })
    }

    /// Subtracts `other` from `self` element by element.
    ///
    /// # Errors
    ///
    /// Returns an error when the two shapes differ.
    pub fn sub(&self, other: &Self) -> Result<Self, &'static str> {
        if self.rows != other.rows || self.cols != other.cols {
            return Err("Matrix dimensions must match for subtraction");
        }
        let data: Vec<f32> = self
            .data
            .iter()
            .zip(&other.data)
            .map(|(lhs, rhs)| lhs - rhs)
            .collect();
        Ok(Self {
            data,
            rows: self.rows,
            cols: self.cols,
        })
    }

    /// Returns a new matrix with every element multiplied by `scalar`.
    #[must_use]
    pub fn mul_scalar(&self, scalar: f32) -> Self {
        Self {
            data: self.data.iter().map(|x| x * scalar).collect(),
            rows: self.rows,
            cols: self.cols,
        }
    }

    /// Solves `self * x = b` via the Cholesky factorization `self = L * Lᵀ`.
    ///
    /// Only the lower triangle of `self` (including the diagonal) is read; the
    /// matrix is not checked for symmetry.
    ///
    /// # Errors
    ///
    /// Returns an error when the matrix is not square, when `b.len()` differs from
    /// the number of rows, or when a pivot is not strictly positive (including a
    /// NaN pivot), i.e. the matrix is not positive definite.
    pub fn cholesky_solve(&self, b: &Vector<f32>) -> Result<Vector<f32>, &'static str> {
        if self.rows != self.cols {
            return Err("Matrix must be square for Cholesky decomposition");
        }
        if self.rows != b.len() {
            return Err("Matrix rows must match vector length");
        }
        let n = self.rows;
        let factor = self.cholesky_factor(n)?;
        // L * y = b, then Lᵀ * x = y.
        let y = Self::forward_substitute(&factor, b, n);
        let x = Self::backward_substitute(&factor, &y, n);
        Ok(Vector::from_vec(x))
    }

    /// Computes the lower-triangular Cholesky factor `L` as a flat, row-major
    /// `n * n` buffer (entries above the diagonal stay `0.0`).
    ///
    /// Fails with "Matrix is not positive definite" as soon as a diagonal pivot is
    /// not strictly positive.
    fn cholesky_factor(&self, n: usize) -> Result<Vec<f32>, &'static str> {
        let mut l = vec![0.0_f32; n * n];
        for i in 0..n {
            // Off-diagonal entries of row i, left to right.
            for j in 0..i {
                let dot = (0..j).fold(0.0_f32, |acc, k| acc + l[i * n + k] * l[j * n + k]);
                l[i * n + j] = (self.get(i, j) - dot) / l[j * n + j];
            }
            // Diagonal entry of row i.
            let squares = (0..i).fold(0.0_f32, |acc, k| acc + l[i * n + k] * l[i * n + k]);
            let diag = self.get(i, i) - squares;
            // A NaN pivot compares false against everything, so it must be tested
            // explicitly or it would flow into `sqrt` and poison the whole result.
            if diag.is_nan() || diag <= 0.0 {
                return Err("Matrix is not positive definite");
            }
            l[i * n + i] = diag.sqrt();
        }
        Ok(l)
    }

    /// Solves `L * y = b` for `y`, where `l` is a flat row-major `n * n`
    /// lower-triangular matrix.
    fn forward_substitute(l: &[f32], b: &Vector<f32>, n: usize) -> Vec<f32> {
        let mut y = vec![0.0_f32; n];
        for i in 0..n {
            let known = (0..i).fold(0.0_f32, |acc, j| acc + l[i * n + j] * y[j]);
            y[i] = (b[i] - known) / l[i * n + i];
        }
        y
    }

    /// Solves `Lᵀ * x = y` for `x`, reading `l` (flat row-major `n * n`,
    /// lower-triangular) column-wise so no explicit transpose is needed.
    fn backward_substitute(l: &[f32], y: &[f32], n: usize) -> Vec<f32> {
        let mut x = vec![0.0_f32; n];
        for i in (0..n).rev() {
            let known = ((i + 1)..n).fold(0.0_f32, |acc, j| acc + l[j * n + i] * x[j]);
            x[i] = (y[i] - known) / l[i * n + i];
        }
        x
    }
}
// Test module loaded from `matrix_tests.rs`; compiled only when `cfg(test)` is set.
#[cfg(test)]
#[path = "matrix_tests.rs"]
mod tests;
// Second test module, loaded from `tests_matrix_contract.rs`; also `cfg(test)` only.
#[cfg(test)]
#[path = "tests_matrix_contract.rs"]
mod tests_matrix_contract;