use super::NumericMatrix;
use crate::error::MathError;
/// Smallest absolute pivot magnitude `lu_decomposition` accepts before it reports the
/// matrix as singular or nearly singular.
const PIVOT_THRESHOLD: f64 = 1.0e-10;
/// Absolute tolerance used by the unit tests when comparing floating-point results.
#[cfg(test)]
const EPSILON: f64 = 1.0e-10;
/// Result of [`NumericMatrix::lu_decomposition`]: factors satisfying `P·A = L·U`.
#[derive(Debug, Clone, PartialEq)]
pub struct LUResult {
    /// Lower-triangular factor with a unit diagonal. It starts as the identity, and only
    /// entries below the diagonal are filled with elimination multipliers.
    pub l: NumericMatrix,
    /// Upper-triangular factor produced by Gaussian elimination.
    pub u: NumericMatrix,
    /// Row permutation: row `i` of `P·A` is row `p[i]` of the input matrix.
    pub p: Vec<usize>,
    /// Number of row interchanges performed while pivoting; the permutation's sign is
    /// `(-1)^num_swaps`.
    pub num_swaps: usize,
}
impl NumericMatrix {
    /// Computes the LU decomposition of a square matrix using partial (row) pivoting.
    ///
    /// On success the returned [`LUResult`] satisfies `P·A = L·U`, where row `i` of `P·A`
    /// is row `p[i]` of `self`. `L` has a unit diagonal with the elimination multipliers
    /// below it. `U` is upper triangular, and its entries below the diagonal are exactly
    /// `0.0`. `num_swaps` counts the row interchanges performed.
    ///
    /// # Errors
    ///
    /// Returns [`MathError::DomainError`] if the matrix is not square. It also errors if,
    /// at some elimination step, the largest absolute value among the remaining entries
    /// of the pivot column is below `PIVOT_THRESHOLD`, meaning the matrix is singular or
    /// nearly singular. Any error from `NumericMatrix::identity` is propagated.
    pub fn lu_decomposition(&self) -> Result<LUResult, MathError> {
        if !self.is_square() {
            return Err(MathError::DomainError {
                operation: "LU decomposition".to_string(),
                value: crate::Expression::integer(self.dimensions().0 as i64),
                reason: "LU decomposition requires square matrix".to_string(),
            });
        }
        let n = self.rows;
        let mut l = NumericMatrix::identity(n)?;
        let mut u = self.clone();
        let mut p: Vec<usize> = (0..n).collect();
        let mut num_swaps = 0;
        for k in 0..n {
            let (pivot_row, max_val) = partial_pivot(&u.data, n, k);
            if max_val < PIVOT_THRESHOLD {
                return Err(MathError::DomainError {
                    operation: "LU decomposition".to_string(),
                    value: crate::Expression::float(max_val),
                    reason: format!(
                        "Matrix is singular or nearly singular (pivot {} < {})",
                        max_val, PIVOT_THRESHOLD
                    ),
                });
            }
            if pivot_row != k {
                // U's rows are swapped in full. For L, only the multipliers already
                // computed (columns 0..k) move; its unit diagonal stays in place.
                swap_row_prefix(&mut u.data, n, k, pivot_row, n);
                swap_row_prefix(&mut l.data, n, k, pivot_row, k);
                p.swap(k, pivot_row);
                num_swaps += 1;
            }
            eliminate_below(&mut u.data, n, k, |i, factor| l.data[i * n + k] = factor);
        }
        Ok(LUResult { l, u, p, num_swaps })
    }

    /// Computes the determinant of a square matrix.
    ///
    /// 1×1 and 2×2 matrices use the closed-form expressions. Larger matrices are reduced
    /// by Gaussian elimination with partial pivoting, on a copy of the data. The
    /// determinant is then the product of the pivots, negated once per row interchange.
    ///
    /// A singular matrix, detected when the pivot column is exactly zero at some step,
    /// yields `Ok(0.0)`, consistent with the 2×2 formula. It is not an error.
    /// Nearly singular matrices return their (possibly tiny) determinant. A 0×0 matrix
    /// has determinant `1.0`.
    ///
    /// # Errors
    ///
    /// Returns [`MathError::DomainError`] if the matrix is not square.
    pub fn determinant(&self) -> Result<f64, MathError> {
        if !self.is_square() {
            return Err(MathError::DomainError {
                operation: "determinant".to_string(),
                value: crate::Expression::integer(self.dimensions().0 as i64),
                reason: "Determinant requires square matrix".to_string(),
            });
        }
        let n = self.rows;
        match n {
            1 => return Ok(self.data[0]),
            2 => return Ok(self.data[0] * self.data[3] - self.data[1] * self.data[2]),
            _ => {}
        }
        let mut work = self.data.clone();
        let mut det = 1.0;
        for k in 0..n {
            let (pivot_row, _) = partial_pivot(&work, n, k);
            if pivot_row != k {
                swap_row_prefix(&mut work, n, k, pivot_row, n);
                det = -det;
            }
            let pivot = work[k * n + k];
            // An exactly-zero pivot means the column has no nonzero entry at or below
            // the diagonal, so the matrix is singular.
            if pivot == 0.0 {
                return Ok(0.0);
            }
            det *= pivot;
            eliminate_below(&mut work, n, k, |_, _| {});
        }
        Ok(det)
    }
}

/// Chooses the partial-pivoting row for column `k` of a row-major `n × n` buffer.
///
/// Scans rows `k..n` and returns `(row, |value|)` for the entry of largest absolute value
/// in column `k`. Ties keep the earliest row. If no entry's magnitude exceeds `0.0` (for
/// example, an all-zero column), the result is `(k, 0.0)`.
fn partial_pivot(data: &[f64], n: usize, k: usize) -> (usize, f64) {
    let mut best = (k, 0.0);
    for i in k..n {
        let val = data[i * n + k].abs();
        if val > best.1 {
            best = (i, val);
        }
    }
    best
}

/// Swaps the first `len` entries of rows `r1` and `r2` in a row-major buffer with row
/// stride `n`.
fn swap_row_prefix(data: &mut [f64], n: usize, r1: usize, r2: usize, len: usize) {
    for j in 0..len {
        data.swap(r1 * n + j, r2 * n + j);
    }
}

/// Eliminates column `k` below the diagonal of a row-major `n × n` buffer.
///
/// For each row `i > k`, the function:
/// - computes `factor = data[i][k] / data[k][k]` and passes it to `record(i, factor)`;
/// - stores an exact `0.0` in `data[i][k]`;
/// - subtracts `factor × row k` from columns `k + 1..n` of row `i`.
fn eliminate_below(data: &mut [f64], n: usize, k: usize, mut record: impl FnMut(usize, f64)) {
    let pivot = data[k * n + k];
    for i in (k + 1)..n {
        let factor = data[i * n + k] / pivot;
        record(i, factor);
        // The eliminated entry is zero by construction. Store it exactly rather than
        // leaving the rounding residue `x - (x / p) * p` in the lower triangle.
        data[i * n + k] = 0.0;
        for j in (k + 1)..n {
            data[i * n + j] -= factor * data[k * n + j];
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute-tolerance comparison of two floats using `EPSILON`.
    fn approx_eq(a: f64, b: f64) -> bool {
        (a - b).abs() < EPSILON
    }

    /// True when `a` and `b` have the same dimensions and agree element-wise within
    /// `EPSILON`.
    fn matrix_approx_eq(a: &NumericMatrix, b: &NumericMatrix) -> bool {
        a.dimensions() == b.dimensions()
            && a
                .data
                .iter()
                .zip(b.data.iter())
                .all(|(&x, &y)| approx_eq(x, y))
    }

    /// Builds `P·A` for an `n × n` matrix `a`: row `i` of the result is row `p[i]` of `a`.
    fn permuted_rows(a: &NumericMatrix, p: &[usize], n: usize) -> NumericMatrix {
        let mut flat = Vec::with_capacity(n * n);
        for &src_row in p {
            for col in 0..n {
                flat.push(a.get(src_row, col).unwrap());
            }
        }
        NumericMatrix::from_flat(n, n, flat).unwrap()
    }

    /// Asserts `P·A ≈ L·U` for the `n × n` matrix `a` and its decomposition `lu`.
    fn assert_pa_equals_lu(a: &NumericMatrix, lu: &LUResult, n: usize) {
        let pa = permuted_rows(a, &lu.p, n);
        let lu_product = lu.l.multiply(&lu.u).unwrap();
        assert!(matrix_approx_eq(&pa, &lu_product));
    }

    #[test]
    fn test_lu_2x2() {
        let a = NumericMatrix::from_flat(2, 2, vec![2.0, 1.0, 4.0, 3.0]).unwrap();
        let lu = a.lu_decomposition().unwrap();
        assert_eq!(lu.l.dimensions(), (2, 2));
        assert_eq!(lu.u.dimensions(), (2, 2));
        // L has a unit diagonal and a zero upper-right entry.
        let expected_l_entries = [(0, 0, 1.0), (1, 1, 1.0), (0, 1, 0.0)];
        for &(row, col, want) in &expected_l_entries {
            assert!(approx_eq(lu.l.get(row, col).unwrap(), want));
        }
        assert_pa_equals_lu(&a, &lu, 2);
    }

    #[test]
    fn test_lu_3x3() {
        let a = NumericMatrix::from_flat(3, 3, vec![2.0, 1.0, 1.0, 4.0, 3.0, 3.0, 8.0, 7.0, 9.0])
            .unwrap();
        let lu = a.lu_decomposition().unwrap();
        assert_eq!(lu.l.dimensions(), (3, 3));
        assert_eq!(lu.u.dimensions(), (3, 3));
        for d in 0..3 {
            assert!(approx_eq(lu.l.get(d, d).unwrap(), 1.0));
        }
        // L must vanish above the diagonal and U below it.
        for row in 0..3 {
            for col in (row + 1)..3 {
                assert!(approx_eq(lu.l.get(row, col).unwrap(), 0.0));
                assert!(approx_eq(lu.u.get(col, row).unwrap(), 0.0));
            }
        }
        assert_pa_equals_lu(&a, &lu, 3);
    }

    #[test]
    fn test_lu_identity() {
        let eye = NumericMatrix::identity(3).unwrap();
        let lu = eye.lu_decomposition().unwrap();
        let expected = NumericMatrix::identity(3).unwrap();
        assert!(matrix_approx_eq(&lu.l, &expected));
        assert!(matrix_approx_eq(&lu.u, &expected));
        assert_eq!(lu.num_swaps, 0);
    }

    #[test]
    fn test_lu_singular() {
        let rank_one = NumericMatrix::from_flat(2, 2, vec![1.0, 2.0, 2.0, 4.0]).unwrap();
        assert!(rank_one.lu_decomposition().is_err());
    }

    #[test]
    fn test_lu_non_square() {
        let wide = NumericMatrix::from_flat(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        assert!(wide.lu_decomposition().is_err());
    }

    #[test]
    fn test_determinant_1x1() {
        let scalar = NumericMatrix::from_flat(1, 1, vec![5.0]).unwrap();
        assert!(approx_eq(scalar.determinant().unwrap(), 5.0));
    }

    #[test]
    fn test_determinant_2x2() {
        let m = NumericMatrix::from_flat(2, 2, vec![3.0, 8.0, 4.0, 6.0]).unwrap();
        let expected = 3.0 * 6.0 - 8.0 * 4.0;
        assert!(approx_eq(m.determinant().unwrap(), expected));
    }

    #[test]
    fn test_determinant_3x3() {
        let m = NumericMatrix::from_flat(3, 3, vec![2.0, 1.0, 1.0, 4.0, 3.0, 3.0, 8.0, 7.0, 9.0])
            .unwrap();
        assert!(approx_eq(m.determinant().unwrap(), 4.0));
    }

    #[test]
    fn test_determinant_identity() {
        let eye = NumericMatrix::identity(4).unwrap();
        assert!(approx_eq(eye.determinant().unwrap(), 1.0));
    }

    #[test]
    fn test_determinant_singular() {
        let rank_one = NumericMatrix::from_flat(2, 2, vec![1.0, 2.0, 2.0, 4.0]).unwrap();
        assert!(approx_eq(rank_one.determinant().unwrap(), 0.0));
    }

    #[test]
    fn test_determinant_non_square() {
        let wide = NumericMatrix::from_flat(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        assert!(wide.determinant().is_err());
    }
}