mathhook-core 0.2.0

Core mathematical engine for MathHook - expressions, algebra, and solving
Documentation
use super::NumericMatrix;
use crate::error::MathError;

/// Smallest pivot magnitude `lu_decomposition` accepts before it reports the matrix as
/// singular or nearly singular. This is an absolute bound: it is compared directly
/// against `|pivot|` without scaling by the magnitude of the matrix entries.
const PIVOT_THRESHOLD: f64 = 1.0e-10;

/// Absolute tolerance used by the unit tests when comparing floating-point results.
#[cfg(test)]
const EPSILON: f64 = 1.0e-10;

/// Output of `NumericMatrix::lu_decomposition`: factors satisfying `P·A = L·U`.
#[derive(Debug, Clone, PartialEq)]
pub struct LUResult {
    /// Unit lower-triangular factor: ones on the diagonal, elimination multipliers below it.
    pub l: NumericMatrix,
    /// Upper-triangular factor produced by Gaussian elimination; entries below the
    /// diagonal are zero (or negligible rounding residue).
    pub u: NumericMatrix,
    /// Row permutation: row `i` of `P·A` is row `p[i]` of the original matrix.
    pub p: Vec<usize>,
    /// Number of row interchanges performed; `det(P) = (-1)^num_swaps`.
    pub num_swaps: usize,
}

impl NumericMatrix {
    /// Factors this square matrix as `P·A = L·U` using Gaussian elimination with
    /// partial (row) pivoting.
    ///
    /// The returned [`LUResult`] holds:
    /// - `l`: unit lower-triangular factor (ones on the diagonal),
    /// - `u`: upper-triangular factor (entries below the diagonal are exactly `0.0`),
    /// - `p`: row permutation — row `i` of `P·A` is row `p[i]` of `self`,
    /// - `num_swaps`: number of row interchanges, so `det(P) = (-1)^num_swaps`.
    ///
    /// # Errors
    ///
    /// Returns [`MathError::DomainError`] when the matrix is not square, or when the
    /// largest pivot candidate in some column has magnitude below `PIVOT_THRESHOLD`
    /// (an absolute bound), i.e. the matrix is singular or nearly singular. Errors from
    /// `NumericMatrix::identity` are propagated unchanged.
    pub fn lu_decomposition(&self) -> Result<LUResult, MathError> {
        if !self.is_square() {
            return Err(MathError::DomainError {
                operation: "LU decomposition".to_string(),
                value: crate::Expression::integer(self.dimensions().0 as i64),
                reason: "LU decomposition requires square matrix".to_string(),
            });
        }

        let n = self.rows;
        let mut l = NumericMatrix::identity(n)?;
        let mut u = self.clone();
        let mut p: Vec<usize> = (0..n).collect();
        let mut num_swaps = 0;

        for k in 0..n {
            let (pivot_row, max_val) = partial_pivot(&u.data, n, k);

            if max_val < PIVOT_THRESHOLD {
                return Err(MathError::DomainError {
                    operation: "LU decomposition".to_string(),
                    value: crate::Expression::float(max_val),
                    reason: format!(
                        "Matrix is singular or nearly singular (pivot {} < {})",
                        max_val, PIVOT_THRESHOLD
                    ),
                });
            }

            if pivot_row != k {
                swap_row_segments(&mut u.data, n, k, pivot_row, 0..n);
                // Only the multipliers already computed (columns < k) travel with the row;
                // L's diagonal and upper part must stay those of the identity.
                swap_row_segments(&mut l.data, n, k, pivot_row, 0..k);
                p.swap(k, pivot_row);
                num_swaps += 1;
            }

            let pivot = u.data[k * n + k];
            for i in (k + 1)..n {
                let factor = u.data[i * n + k] / pivot;
                l.data[i * n + k] = factor;
                // This entry is eliminated by construction; store an exact zero instead of
                // the rounding residue `x - (x / p) * p` would leave behind.
                u.data[i * n + k] = 0.0;
                for j in (k + 1)..n {
                    u.data[i * n + j] -= factor * u.data[k * n + j];
                }
            }
        }

        Ok(LUResult { l, u, p, num_swaps })
    }

    /// Computes the determinant of this square matrix.
    ///
    /// 1×1 and 2×2 matrices use the closed-form expressions. Larger matrices are
    /// reduced to upper-triangular form by Gaussian elimination with partial pivoting
    /// on a scratch copy of the entries; the determinant is the product of the pivots,
    /// negated once per row interchange. An empty (0×0) matrix yields `1.0`.
    ///
    /// Unlike [`NumericMatrix::lu_decomposition`], this does not reject singular or
    /// nearly singular input: a column with no nonzero pivot candidate yields `0.0`
    /// (or NaN if that column contains NaN). A numerically singular matrix may therefore
    /// return a tiny nonzero value rather than exactly `0.0`.
    ///
    /// # Errors
    ///
    /// Returns [`MathError::DomainError`] if the matrix is not square.
    pub fn determinant(&self) -> Result<f64, MathError> {
        if !self.is_square() {
            return Err(MathError::DomainError {
                operation: "determinant".to_string(),
                value: crate::Expression::integer(self.dimensions().0 as i64),
                reason: "Determinant requires square matrix".to_string(),
            });
        }

        let n = self.rows;
        match n {
            1 => return Ok(self.data[0]),
            2 => return Ok(self.data[0] * self.data[3] - self.data[1] * self.data[2]),
            _ => {}
        }

        let mut a = self.data.to_vec();
        let mut det = 1.0;

        for k in 0..n {
            let (pivot_row, max_val) = partial_pivot(&a, n, k);

            if max_val == 0.0 {
                // No nonzero pivot candidate: the matrix is exactly singular. Propagate NaN
                // rather than masking it as a zero determinant.
                let has_nan = (k..n).any(|i| a[i * n + k].is_nan());
                return Ok(if has_nan { f64::NAN } else { 0.0 });
            }

            if pivot_row != k {
                // Columns left of `k` are never read again, so only the tail is swapped.
                swap_row_segments(&mut a, n, k, pivot_row, k..n);
                det = -det;
            }

            let pivot = a[k * n + k];
            det *= pivot;

            for i in (k + 1)..n {
                let factor = a[i * n + k] / pivot;
                for j in (k + 1)..n {
                    a[i * n + j] -= factor * a[k * n + j];
                }
            }
        }

        Ok(det)
    }
}

/// Locates the partial-pivoting row for column `k` of a row-major `n × n` buffer.
///
/// Returns `(row, magnitude)` for the entry of largest absolute value among rows `k..n`.
/// Ties keep the topmost row and NaN entries are never selected; if no entry exceeds
/// `0.0`, the result is `(k, 0.0)`.
fn partial_pivot(data: &[f64], n: usize, k: usize) -> (usize, f64) {
    (k..n).fold((k, 0.0), |(best_row, best_val), i| {
        let val = data[i * n + k].abs();
        if val > best_val {
            (i, val)
        } else {
            (best_row, best_val)
        }
    })
}

/// Swaps the entries in columns `cols` between rows `a` and `b` of a row-major buffer
/// whose rows are `n` entries long.
fn swap_row_segments(
    data: &mut [f64],
    n: usize,
    a: usize,
    b: usize,
    cols: std::ops::Range<usize>,
) {
    for j in cols {
        data.swap(a * n + j, b * n + j);
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// True when `lhs` and `rhs` differ by strictly less than the shared `EPSILON`.
    fn approx_eq(lhs: f64, rhs: f64) -> bool {
        (lhs - rhs).abs() < EPSILON
    }

    /// Element-wise `approx_eq` over two matrices; differing shapes compare unequal.
    fn matrix_approx_eq(lhs: &NumericMatrix, rhs: &NumericMatrix) -> bool {
        lhs.dimensions() == rhs.dimensions()
            && lhs
                .data
                .iter()
                .zip(rhs.data.iter())
                .all(|(&x, &y)| approx_eq(x, y))
    }

    /// Builds `P·A` from a permutation vector: row `i` of the result is row `perm[i]` of `a`.
    fn permute_rows(a: &NumericMatrix, perm: &[usize]) -> NumericMatrix {
        let (rows, cols) = a.dimensions();
        let data: Vec<f64> = perm
            .iter()
            .flat_map(|&src| (0..cols).map(move |col| a.get(src, col).unwrap()))
            .collect();
        NumericMatrix::from_flat(rows, cols, data).unwrap()
    }

    /// Asserts that the factors in `lu` reproduce `P·A` for the matrix `a`.
    fn assert_reconstructs(a: &NumericMatrix, lu: &LUResult) {
        let pa = permute_rows(a, &lu.p);
        let product = lu.l.multiply(&lu.u).unwrap();
        assert!(matrix_approx_eq(&pa, &product));
    }

    #[test]
    fn test_lu_2x2() {
        let a = NumericMatrix::from_flat(2, 2, vec![2.0, 1.0, 4.0, 3.0]).unwrap();
        let lu = a.lu_decomposition().unwrap();

        assert_eq!(lu.l.dimensions(), (2, 2));
        assert_eq!(lu.u.dimensions(), (2, 2));

        // L is unit lower-triangular.
        for (row, col, expected) in [(0, 0, 1.0), (1, 1, 1.0), (0, 1, 0.0)] {
            assert!(approx_eq(lu.l.get(row, col).unwrap(), expected));
        }

        assert_reconstructs(&a, &lu);
    }

    #[test]
    fn test_lu_3x3() {
        let a = NumericMatrix::from_flat(3, 3, vec![2.0, 1.0, 1.0, 4.0, 3.0, 3.0, 8.0, 7.0, 9.0])
            .unwrap();
        let lu = a.lu_decomposition().unwrap();

        assert_eq!(lu.l.dimensions(), (3, 3));
        assert_eq!(lu.u.dimensions(), (3, 3));

        for diag in 0..3 {
            assert!(approx_eq(lu.l.get(diag, diag).unwrap(), 1.0));
        }

        // The strictly upper part of L and the strictly lower part of U vanish.
        for row in 0..3 {
            for col in (row + 1)..3 {
                assert!(approx_eq(lu.l.get(row, col).unwrap(), 0.0));
                assert!(approx_eq(lu.u.get(col, row).unwrap(), 0.0));
            }
        }

        assert_reconstructs(&a, &lu);
    }

    #[test]
    fn test_lu_identity() {
        let eye = NumericMatrix::identity(3).unwrap();
        let lu = eye.lu_decomposition().unwrap();

        assert!(matrix_approx_eq(&lu.l, &eye));
        assert!(matrix_approx_eq(&lu.u, &eye));
        assert_eq!(lu.num_swaps, 0);
    }

    #[test]
    fn test_lu_singular() {
        let rank_one = NumericMatrix::from_flat(2, 2, vec![1.0, 2.0, 2.0, 4.0]).unwrap();
        assert!(rank_one.lu_decomposition().is_err());
    }

    #[test]
    fn test_lu_non_square() {
        let wide = NumericMatrix::from_flat(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        assert!(wide.lu_decomposition().is_err());
    }

    #[test]
    fn test_determinant_1x1() {
        let scalar = NumericMatrix::from_flat(1, 1, vec![5.0]).unwrap();
        assert!(approx_eq(scalar.determinant().unwrap(), 5.0));
    }

    #[test]
    fn test_determinant_2x2() {
        let m = NumericMatrix::from_flat(2, 2, vec![3.0, 8.0, 4.0, 6.0]).unwrap();
        let expected = 3.0 * 6.0 - 8.0 * 4.0;
        assert!(approx_eq(m.determinant().unwrap(), expected));
    }

    #[test]
    fn test_determinant_3x3() {
        let m = NumericMatrix::from_flat(3, 3, vec![2.0, 1.0, 1.0, 4.0, 3.0, 3.0, 8.0, 7.0, 9.0])
            .unwrap();
        assert!(approx_eq(m.determinant().unwrap(), 4.0));
    }

    #[test]
    fn test_determinant_identity() {
        let eye = NumericMatrix::identity(4).unwrap();
        assert!(approx_eq(eye.determinant().unwrap(), 1.0));
    }

    #[test]
    fn test_determinant_singular() {
        let rank_one = NumericMatrix::from_flat(2, 2, vec![1.0, 2.0, 2.0, 4.0]).unwrap();
        assert!(approx_eq(rank_one.determinant().unwrap(), 0.0));
    }

    #[test]
    fn test_determinant_non_square() {
        let wide = NumericMatrix::from_flat(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        assert!(wide.determinant().is_err());
    }
}