Skip to main content

proof_engine/symbolic/
matrix.rs

//! Symbolic matrix operations — construction, product, transpose, determinant/minors/cofactors, trace, and numeric 2×2 eigenvalues.
2
3use super::expr::Expr;
4
5/// A symbolic matrix of expressions.
6#[derive(Debug, Clone)]
7pub struct SymMatrix {
8    pub rows: usize,
9    pub cols: usize,
10    pub data: Vec<Vec<Expr>>,
11}
12
13impl SymMatrix {
14    pub fn new(rows: usize, cols: usize) -> Self {
15        Self { rows, cols, data: vec![vec![Expr::zero(); cols]; rows] }
16    }
17
18    pub fn identity(n: usize) -> Self {
19        let mut m = Self::new(n, n);
20        for i in 0..n { m.data[i][i] = Expr::one(); }
21        m
22    }
23
24    pub fn from_f64(data: &[&[f64]]) -> Self {
25        let rows = data.len();
26        let cols = if rows > 0 { data[0].len() } else { 0 };
27        let mut m = Self::new(rows, cols);
28        for i in 0..rows {
29            for j in 0..cols {
30                m.data[i][j] = Expr::c(data[i][j]);
31            }
32        }
33        m
34    }
35
36    pub fn get(&self, r: usize, c: usize) -> &Expr { &self.data[r][c] }
37    pub fn set(&mut self, r: usize, c: usize, val: Expr) { self.data[r][c] = val; }
38
39    /// Matrix multiplication.
40    pub fn mul(&self, other: &SymMatrix) -> SymMatrix {
41        assert_eq!(self.cols, other.rows);
42        let mut result = SymMatrix::new(self.rows, other.cols);
43        for i in 0..self.rows {
44            for j in 0..other.cols {
45                let mut sum = Expr::zero();
46                for k in 0..self.cols {
47                    sum = sum.add(self.data[i][k].clone().mul(other.data[k][j].clone()));
48                }
49                result.data[i][j] = sum;
50            }
51        }
52        result
53    }
54
55    /// Transpose.
56    pub fn transpose(&self) -> SymMatrix {
57        let mut result = SymMatrix::new(self.cols, self.rows);
58        for i in 0..self.rows {
59            for j in 0..self.cols {
60                result.data[j][i] = self.data[i][j].clone();
61            }
62        }
63        result
64    }
65
66    /// Determinant (recursive cofactor expansion).
67    pub fn determinant(&self) -> Expr {
68        assert_eq!(self.rows, self.cols);
69        let n = self.rows;
70        if n == 1 { return self.data[0][0].clone(); }
71        if n == 2 {
72            let a = self.data[0][0].clone().mul(self.data[1][1].clone());
73            let b = self.data[0][1].clone().mul(self.data[1][0].clone());
74            return a.sub(b);
75        }
76        let mut det = Expr::zero();
77        for j in 0..n {
78            let cofactor = self.cofactor(0, j);
79            let term = self.data[0][j].clone().mul(cofactor);
80            if j % 2 == 0 { det = det.add(term); }
81            else { det = det.sub(term); }
82        }
83        det
84    }
85
86    /// Minor: determinant of the submatrix with row i and col j removed.
87    pub fn minor(&self, row: usize, col: usize) -> Expr {
88        let sub = self.submatrix(row, col);
89        sub.determinant()
90    }
91
92    /// Cofactor: (-1)^(i+j) * minor(i,j).
93    pub fn cofactor(&self, row: usize, col: usize) -> Expr {
94        let m = self.minor(row, col);
95        if (row + col) % 2 == 0 { m } else { m.neg() }
96    }
97
98    /// Remove row i and column j.
99    pub fn submatrix(&self, row: usize, col: usize) -> SymMatrix {
100        let mut result = SymMatrix::new(self.rows - 1, self.cols - 1);
101        let mut ri = 0;
102        for i in 0..self.rows {
103            if i == row { continue; }
104            let mut ci = 0;
105            for j in 0..self.cols {
106                if j == col { continue; }
107                result.data[ri][ci] = self.data[i][j].clone();
108                ci += 1;
109            }
110            ri += 1;
111        }
112        result
113    }
114
115    /// Trace: sum of diagonal elements.
116    pub fn trace(&self) -> Expr {
117        let mut sum = Expr::zero();
118        for i in 0..self.rows.min(self.cols) {
119            sum = sum.add(self.data[i][i].clone());
120        }
121        sum
122    }
123
124    /// Numerical eigenvalues for a 2x2 matrix.
125    pub fn eigenvalues_2x2(&self) -> Option<(f64, f64)> {
126        if self.rows != 2 || self.cols != 2 { return None; }
127        let vars = std::collections::HashMap::new();
128        let a = self.data[0][0].eval(&vars);
129        let b = self.data[0][1].eval(&vars);
130        let c = self.data[1][0].eval(&vars);
131        let d = self.data[1][1].eval(&vars);
132
133        let trace = a + d;
134        let det = a * d - b * c;
135        let disc = trace * trace - 4.0 * det;
136        if disc < 0.0 { return None; }
137        let sqrt_disc = disc.sqrt();
138        Some(((trace + sqrt_disc) / 2.0, (trace - sqrt_disc) / 2.0))
139    }
140}
141
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Builds a constant matrix and evaluates its determinant numerically.
    fn det_of(rows: &[&[f64]]) -> f64 {
        SymMatrix::from_f64(rows).determinant().eval(&HashMap::new())
    }

    #[test]
    fn det_2x2() {
        let val = det_of(&[&[1.0, 2.0], &[3.0, 4.0]]);
        assert!((val - (-2.0)).abs() < 1e-10);
    }

    #[test]
    fn det_3x3() {
        let val = det_of(&[&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0], &[7.0, 8.0, 0.0]]);
        assert!((val - 27.0).abs() < 1e-8);
    }

    /// A 4×4 matrix exercises nested cofactor expansion: each term of the
    /// top-level expansion contains a 3×3 determinant.
    #[test]
    fn det_4x4() {
        let val = det_of(&[
            &[2.0, 0.0, 1.0, 3.0],
            &[1.0, 1.0, 0.0, 2.0],
            &[0.0, 3.0, 1.0, 1.0],
            &[1.0, 0.0, 2.0, 1.0],
        ]);
        assert!((val - (-1.0)).abs() < 1e-8);
    }

    #[test]
    fn identity_det_is_one() {
        let m = SymMatrix::identity(3);
        let det = m.determinant();
        let val = det.eval(&HashMap::new());
        assert!((val - 1.0).abs() < 1e-10);
    }

    /// The minor is unsigned; the cofactor applies (-1)^(row+col).
    #[test]
    fn cofactor_carries_sign() {
        let m = SymMatrix::from_f64(&[&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0], &[7.0, 8.0, 0.0]]);
        let vars = HashMap::new();
        assert!((m.minor(0, 1).eval(&vars) - (-42.0)).abs() < 1e-10);
        assert!((m.cofactor(0, 1).eval(&vars) - 42.0).abs() < 1e-10);
    }

    #[test]
    fn mul_2x2() {
        let a = SymMatrix::from_f64(&[&[1.0, 2.0], &[3.0, 4.0]]);
        let b = SymMatrix::from_f64(&[&[5.0, 6.0], &[7.0, 8.0]]);
        let p = a.mul(&b);
        let vars = HashMap::new();
        let expected = [[19.0, 22.0], [43.0, 50.0]];
        for (i, row) in expected.iter().enumerate() {
            for (j, &want) in row.iter().enumerate() {
                assert!((p.get(i, j).eval(&vars) - want).abs() < 1e-10);
            }
        }
    }

    #[test]
    fn trace_uses_shorter_diagonal() {
        let m = SymMatrix::from_f64(&[&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0]]);
        assert!((m.trace().eval(&HashMap::new()) - 6.0).abs() < 1e-10);
    }

    #[test]
    fn eigenvalues_diagonal() {
        let m = SymMatrix::from_f64(&[&[3.0, 0.0], &[0.0, 5.0]]);
        let (e1, e2) = m.eigenvalues_2x2().unwrap();
        assert!((e1 - 5.0).abs() < 1e-10);
        assert!((e2 - 3.0).abs() < 1e-10);
    }

    #[test]
    fn eigenvalues_none_cases() {
        // A rotation by 90° has eigenvalues ±i, so there are no real eigenvalues.
        let rotation = SymMatrix::from_f64(&[&[0.0, -1.0], &[1.0, 0.0]]);
        assert!(rotation.eigenvalues_2x2().is_none());
        assert!(SymMatrix::identity(3).eigenvalues_2x2().is_none());
    }

    #[test]
    fn transpose() {
        let m = SymMatrix::from_f64(&[&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0]]);
        let t = m.transpose();
        assert_eq!((t.rows, t.cols), (3, 2));
        let val = t.data[1][0].eval(&HashMap::new());
        assert!((val - 2.0).abs() < 1e-10);
        let val = t.data[2][1].eval(&HashMap::new());
        assert!((val - 6.0).abs() < 1e-10);
    }
}