// proof_engine/symbolic/matrix.rs
use super::expr::Expr;

5#[derive(Debug, Clone)]
7pub struct SymMatrix {
8 pub rows: usize,
9 pub cols: usize,
10 pub data: Vec<Vec<Expr>>,
11}
12
13impl SymMatrix {
14 pub fn new(rows: usize, cols: usize) -> Self {
15 Self { rows, cols, data: vec![vec![Expr::zero(); cols]; rows] }
16 }
17
18 pub fn identity(n: usize) -> Self {
19 let mut m = Self::new(n, n);
20 for i in 0..n { m.data[i][i] = Expr::one(); }
21 m
22 }
23
24 pub fn from_f64(data: &[&[f64]]) -> Self {
25 let rows = data.len();
26 let cols = if rows > 0 { data[0].len() } else { 0 };
27 let mut m = Self::new(rows, cols);
28 for i in 0..rows {
29 for j in 0..cols {
30 m.data[i][j] = Expr::c(data[i][j]);
31 }
32 }
33 m
34 }
35
36 pub fn get(&self, r: usize, c: usize) -> &Expr { &self.data[r][c] }
37 pub fn set(&mut self, r: usize, c: usize, val: Expr) { self.data[r][c] = val; }
38
39 pub fn mul(&self, other: &SymMatrix) -> SymMatrix {
41 assert_eq!(self.cols, other.rows);
42 let mut result = SymMatrix::new(self.rows, other.cols);
43 for i in 0..self.rows {
44 for j in 0..other.cols {
45 let mut sum = Expr::zero();
46 for k in 0..self.cols {
47 sum = sum.add(self.data[i][k].clone().mul(other.data[k][j].clone()));
48 }
49 result.data[i][j] = sum;
50 }
51 }
52 result
53 }
54
55 pub fn transpose(&self) -> SymMatrix {
57 let mut result = SymMatrix::new(self.cols, self.rows);
58 for i in 0..self.rows {
59 for j in 0..self.cols {
60 result.data[j][i] = self.data[i][j].clone();
61 }
62 }
63 result
64 }
65
66 pub fn determinant(&self) -> Expr {
68 assert_eq!(self.rows, self.cols);
69 let n = self.rows;
70 if n == 1 { return self.data[0][0].clone(); }
71 if n == 2 {
72 let a = self.data[0][0].clone().mul(self.data[1][1].clone());
73 let b = self.data[0][1].clone().mul(self.data[1][0].clone());
74 return a.sub(b);
75 }
76 let mut det = Expr::zero();
77 for j in 0..n {
78 let cofactor = self.cofactor(0, j);
79 let term = self.data[0][j].clone().mul(cofactor);
80 if j % 2 == 0 { det = det.add(term); }
81 else { det = det.sub(term); }
82 }
83 det
84 }
85
86 pub fn minor(&self, row: usize, col: usize) -> Expr {
88 let sub = self.submatrix(row, col);
89 sub.determinant()
90 }
91
92 pub fn cofactor(&self, row: usize, col: usize) -> Expr {
94 let m = self.minor(row, col);
95 if (row + col) % 2 == 0 { m } else { m.neg() }
96 }
97
98 pub fn submatrix(&self, row: usize, col: usize) -> SymMatrix {
100 let mut result = SymMatrix::new(self.rows - 1, self.cols - 1);
101 let mut ri = 0;
102 for i in 0..self.rows {
103 if i == row { continue; }
104 let mut ci = 0;
105 for j in 0..self.cols {
106 if j == col { continue; }
107 result.data[ri][ci] = self.data[i][j].clone();
108 ci += 1;
109 }
110 ri += 1;
111 }
112 result
113 }
114
115 pub fn trace(&self) -> Expr {
117 let mut sum = Expr::zero();
118 for i in 0..self.rows.min(self.cols) {
119 sum = sum.add(self.data[i][i].clone());
120 }
121 sum
122 }
123
124 pub fn eigenvalues_2x2(&self) -> Option<(f64, f64)> {
126 if self.rows != 2 || self.cols != 2 { return None; }
127 let vars = std::collections::HashMap::new();
128 let a = self.data[0][0].eval(&vars);
129 let b = self.data[0][1].eval(&vars);
130 let c = self.data[1][0].eval(&vars);
131 let d = self.data[1][1].eval(&vars);
132
133 let trace = a + d;
134 let det = a * d - b * c;
135 let disc = trace * trace - 4.0 * det;
136 if disc < 0.0 { return None; }
137 let sqrt_disc = disc.sqrt();
138 Some(((trace + sqrt_disc) / 2.0, (trace - sqrt_disc) / 2.0))
139 }
140}
141
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Evaluates an expression that contains no free variables.
    fn eval_const(e: &Expr) -> f64 {
        e.eval(&HashMap::new())
    }

    /// Asserts that `m` has the shape of `expected` and that every entry evaluates to the
    /// corresponding value within `1e-10`.
    fn assert_entries(m: &SymMatrix, expected: &[&[f64]]) {
        assert_eq!(m.rows, expected.len());
        for (i, row) in expected.iter().enumerate() {
            assert_eq!(m.cols, row.len());
            for (j, &want) in row.iter().enumerate() {
                let got = eval_const(m.get(i, j));
                assert!(
                    (got - want).abs() < 1e-10,
                    "entry ({}, {}): got {}, want {}",
                    i,
                    j,
                    got,
                    want
                );
            }
        }
    }

    #[test]
    fn det_2x2() {
        let m = SymMatrix::from_f64(&[&[1.0, 2.0], &[3.0, 4.0]]);
        let val = eval_const(&m.determinant());
        assert!((val - (-2.0)).abs() < 1e-10);
    }

    #[test]
    fn det_3x3() {
        let m = SymMatrix::from_f64(&[&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0], &[7.0, 8.0, 0.0]]);
        let val = eval_const(&m.determinant());
        assert!((val - 27.0).abs() < 1e-8);
    }

    #[test]
    fn det_4x4() {
        // Goes through the recursive cofactor path twice (4x4 -> 3x3 -> 2x2 closed form),
        // so a sign error in the expansion shows up here.
        let m = SymMatrix::from_f64(&[
            &[1.0, 0.0, 2.0, -1.0],
            &[3.0, 0.0, 0.0, 5.0],
            &[2.0, 1.0, 4.0, -3.0],
            &[1.0, 0.0, 5.0, 0.0],
        ]);
        let val = eval_const(&m.determinant());
        assert!((val - 30.0).abs() < 1e-8);
    }

    #[test]
    fn identity_det_is_one() {
        let m = SymMatrix::identity(3);
        let val = eval_const(&m.determinant());
        assert!((val - 1.0).abs() < 1e-10);
    }

    #[test]
    fn eigenvalues_diagonal() {
        let m = SymMatrix::from_f64(&[&[3.0, 0.0], &[0.0, 5.0]]);
        let (e1, e2) = m.eigenvalues_2x2().unwrap();
        assert!((e1 - 5.0).abs() < 1e-10);
        assert!((e2 - 3.0).abs() < 1e-10);
    }

    #[test]
    fn eigenvalues_complex_is_none() {
        // A 90-degree rotation has eigenvalues +i and -i: no real ones.
        let m = SymMatrix::from_f64(&[&[0.0, -1.0], &[1.0, 0.0]]);
        assert!(m.eigenvalues_2x2().is_none());
    }

    #[test]
    fn eigenvalues_requires_2x2() {
        assert!(SymMatrix::identity(3).eigenvalues_2x2().is_none());
    }

    #[test]
    fn transpose() {
        let m = SymMatrix::from_f64(&[&[1.0, 2.0], &[3.0, 4.0]]);
        let t = m.transpose();
        let val = eval_const(&t.data[1][0]);
        assert!((val - 2.0).abs() < 1e-10);
    }

    #[test]
    fn transpose_rectangular() {
        let m = SymMatrix::from_f64(&[&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0]]);
        assert_entries(&m.transpose(), &[&[1.0, 4.0], &[2.0, 5.0], &[3.0, 6.0]]);
    }

    #[test]
    fn mul_2x2() {
        let a = SymMatrix::from_f64(&[&[1.0, 2.0], &[3.0, 4.0]]);
        let b = SymMatrix::from_f64(&[&[5.0, 6.0], &[7.0, 8.0]]);
        assert_entries(&a.mul(&b), &[&[19.0, 22.0], &[43.0, 50.0]]);
    }

    #[test]
    fn mul_by_identity_is_noop() {
        let m = SymMatrix::from_f64(&[&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0]]);
        let product = m.mul(&SymMatrix::identity(3));
        assert_entries(&product, &[&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0]]);
    }

    #[test]
    #[should_panic]
    fn mul_dimension_mismatch_panics() {
        let m = SymMatrix::identity(2);
        let _ = m.mul(&SymMatrix::identity(3));
    }

    #[test]
    fn submatrix_drops_row_and_col() {
        let m = SymMatrix::from_f64(&[&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0], &[7.0, 8.0, 9.0]]);
        assert_entries(&m.submatrix(1, 1), &[&[1.0, 3.0], &[7.0, 9.0]]);
    }

    #[test]
    fn cofactor_applies_sign() {
        let m = SymMatrix::from_f64(&[&[1.0, 2.0], &[3.0, 4.0]]);
        assert!((eval_const(&m.cofactor(0, 0)) - 4.0).abs() < 1e-10);
        assert!((eval_const(&m.cofactor(0, 1)) - (-3.0)).abs() < 1e-10);
    }

    #[test]
    fn trace_sums_diagonal() {
        let m = SymMatrix::from_f64(&[&[1.0, 2.0], &[3.0, 4.0]]);
        assert!((eval_const(&m.trace()) - 5.0).abs() < 1e-10);
    }

    #[test]
    #[should_panic]
    fn from_f64_rejects_ragged_rows() {
        let _ = SymMatrix::from_f64(&[&[1.0, 2.0], &[3.0]]);
    }
}
186}