1use crate::diff::diff;
10use crate::kernel::{ExprId, ExprPool};
11use crate::simplify::engine::simplify;
12use std::fmt;
13
14pub mod eigen;
15pub mod normal_form;
16mod smith;
17mod smith_poly;
18
19pub use eigen::{
20 characteristic_polynomial_lambda_minus_m, diagonalize, eigenvalues, eigenvectors, EigenError,
21};
22pub use normal_form::{
23 hermite_form, hermite_form_poly, smith_form, smith_form_poly, IntegerMatrix, NormalFormError,
24 PolyMatrixQ, RatUniPoly,
25};
26
27#[derive(Clone, Debug, PartialEq, Eq)]
33pub struct Matrix {
34 data: Vec<ExprId>,
36 pub rows: usize,
37 pub cols: usize,
38}
39
/// Errors reported by matrix construction and matrix operations.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum MatrixError {
    /// Row lengths or operand shapes are incompatible.
    DimensionMismatch {
        /// Human-readable description of the offending shapes.
        msg: String,
    },
    /// The operation requires a square matrix.
    NotSquare,
    /// The matrix is singular (zero determinant).
    SingularMatrix,
}

impl fmt::Display for MatrixError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::DimensionMismatch { msg } => write!(f, "dimension mismatch: {msg}"),
            Self::NotSquare => f.write_str("matrix is not square"),
            Self::SingularMatrix => f.write_str("matrix is singular"),
        }
    }
}

// `Display` + `Debug` are all `std::error::Error` needs; no `source()` chain.
impl std::error::Error for MatrixError {}
58
59impl crate::errors::AlkahestError for MatrixError {
60 fn code(&self) -> &'static str {
61 match self {
62 MatrixError::DimensionMismatch { .. } => "E-MAT-001",
63 MatrixError::NotSquare => "E-MAT-002",
64 MatrixError::SingularMatrix => "E-MAT-003",
65 }
66 }
67
68 fn remediation(&self) -> Option<&'static str> {
69 match self {
70 MatrixError::DimensionMismatch { .. } => Some(
71 "ensure all rows have the same column count and operand dimensions match",
72 ),
73 MatrixError::NotSquare => Some(
74 "determinant and inverse require a square matrix; use the pseudo-inverse for rectangular matrices",
75 ),
76 MatrixError::SingularMatrix => Some(
77 "the matrix has a zero determinant; check your system of equations for linear dependence",
78 ),
79 }
80 }
81}
82
83impl Matrix {
84 pub fn new(rows: Vec<Vec<ExprId>>) -> Result<Self, MatrixError> {
86 if rows.is_empty() {
87 return Ok(Matrix {
88 data: vec![],
89 rows: 0,
90 cols: 0,
91 });
92 }
93 let cols = rows[0].len();
94 for r in &rows {
95 if r.len() != cols {
96 return Err(MatrixError::DimensionMismatch {
97 msg: format!("expected {cols} columns, got {}", r.len()),
98 });
99 }
100 }
101 let nrows = rows.len();
102 let data: Vec<ExprId> = rows.into_iter().flatten().collect();
103 Ok(Matrix {
104 data,
105 rows: nrows,
106 cols,
107 })
108 }
109
110 pub fn zeros(rows: usize, cols: usize, pool: &ExprPool) -> Self {
112 let zero = pool.integer(0_i32);
113 Matrix {
114 data: vec![zero; rows * cols],
115 rows,
116 cols,
117 }
118 }
119
120 pub fn identity(n: usize, pool: &ExprPool) -> Self {
122 let zero = pool.integer(0_i32);
123 let one = pool.integer(1_i32);
124 let mut data = vec![zero; n * n];
125 for i in 0..n {
126 data[i * n + i] = one;
127 }
128 Matrix {
129 data,
130 rows: n,
131 cols: n,
132 }
133 }
134
135 pub fn get(&self, r: usize, c: usize) -> ExprId {
137 self.data[r * self.cols + c]
138 }
139
140 pub fn set(&mut self, r: usize, c: usize, val: ExprId) {
142 self.data[r * self.cols + c] = val;
143 }
144
145 pub fn row(&self, r: usize) -> Vec<ExprId> {
147 self.data[r * self.cols..(r + 1) * self.cols].to_vec()
148 }
149
150 pub fn col(&self, c: usize) -> Vec<ExprId> {
152 (0..self.rows).map(|r| self.get(r, c)).collect()
153 }
154
155 pub fn transpose(&self) -> Self {
157 let mut data = Vec::with_capacity(self.rows * self.cols);
158 for c in 0..self.cols {
159 for r in 0..self.rows {
160 data.push(self.get(r, c));
161 }
162 }
163 Matrix {
164 data,
165 rows: self.cols,
166 cols: self.rows,
167 }
168 }
169
170 pub fn add(&self, other: &Matrix, pool: &ExprPool) -> Result<Matrix, MatrixError> {
172 self.check_same_shape(other)?;
173 let data = self
174 .data
175 .iter()
176 .zip(other.data.iter())
177 .map(|(&a, &b)| pool.add(vec![a, b]))
178 .collect();
179 Ok(Matrix {
180 data,
181 rows: self.rows,
182 cols: self.cols,
183 })
184 }
185
186 pub fn sub(&self, other: &Matrix, pool: &ExprPool) -> Result<Matrix, MatrixError> {
188 self.check_same_shape(other)?;
189 let neg_one = pool.integer(-1_i32);
190 let data = self
191 .data
192 .iter()
193 .zip(other.data.iter())
194 .map(|(&a, &b)| {
195 let neg_b = pool.mul(vec![neg_one, b]);
196 pool.add(vec![a, neg_b])
197 })
198 .collect();
199 Ok(Matrix {
200 data,
201 rows: self.rows,
202 cols: self.cols,
203 })
204 }
205
206 pub fn mul(&self, other: &Matrix, pool: &ExprPool) -> Result<Matrix, MatrixError> {
208 if self.cols != other.rows {
209 return Err(MatrixError::DimensionMismatch {
210 msg: format!(
211 "cannot multiply {}×{} by {}×{}",
212 self.rows, self.cols, other.rows, other.cols
213 ),
214 });
215 }
216 let m = self.rows;
217 let n = other.cols;
218 let k = self.cols;
219 let mut data = Vec::with_capacity(m * n);
220 for r in 0..m {
221 for c in 0..n {
222 let terms: Vec<ExprId> = (0..k)
223 .map(|i| pool.mul(vec![self.get(r, i), other.get(i, c)]))
224 .collect();
225 let entry = if terms.is_empty() {
226 pool.integer(0_i32)
227 } else if terms.len() == 1 {
228 terms[0]
229 } else {
230 pool.add(terms)
231 };
232 data.push(entry);
233 }
234 }
235 Ok(Matrix {
236 data,
237 rows: m,
238 cols: n,
239 })
240 }
241
242 pub fn scale(&self, scalar: ExprId, pool: &ExprPool) -> Matrix {
244 let data = self
245 .data
246 .iter()
247 .map(|&e| pool.mul(vec![scalar, e]))
248 .collect();
249 Matrix {
250 data,
251 rows: self.rows,
252 cols: self.cols,
253 }
254 }
255
256 pub fn simplify_entries(&self, pool: &ExprPool) -> Matrix {
258 let data = self.data.iter().map(|&e| simplify(e, pool).value).collect();
259 Matrix {
260 data,
261 rows: self.rows,
262 cols: self.cols,
263 }
264 }
265
266 pub fn det(&self, pool: &ExprPool) -> Result<ExprId, MatrixError> {
268 if self.rows != self.cols {
269 return Err(MatrixError::NotSquare);
270 }
271 let n = self.rows;
272 if n == 0 {
273 return Ok(pool.integer(1_i32));
274 }
275 if n == 1 {
276 return Ok(self.get(0, 0));
277 }
278 if n == 2 {
279 let ad = pool.mul(vec![self.get(0, 0), self.get(1, 1)]);
281 let bc = pool.mul(vec![self.get(0, 1), self.get(1, 0)]);
282 let neg_bc = pool.mul(vec![pool.integer(-1_i32), bc]);
283 return Ok(simplify(pool.add(vec![ad, neg_bc]), pool).value);
284 }
285 let mut terms: Vec<ExprId> = Vec::new();
287 for j in 0..n {
288 let minor = self.minor(0, j);
289 let minor_det = minor.det(pool)?;
290 let sign = if j % 2 == 0 {
291 pool.integer(1_i32)
292 } else {
293 pool.integer(-1_i32)
294 };
295 terms.push(pool.mul(vec![sign, self.get(0, j), minor_det]));
296 }
297 Ok(simplify(pool.add(terms), pool).value)
298 }
299
300 fn minor(&self, skip_row: usize, skip_col: usize) -> Matrix {
302 let n = self.rows;
303 let mut data = Vec::with_capacity((n - 1) * (n - 1));
304 for r in 0..n {
305 if r == skip_row {
306 continue;
307 }
308 for c in 0..n {
309 if c == skip_col {
310 continue;
311 }
312 data.push(self.get(r, c));
313 }
314 }
315 Matrix {
316 data,
317 rows: n - 1,
318 cols: n - 1,
319 }
320 }
321
322 fn check_same_shape(&self, other: &Matrix) -> Result<(), MatrixError> {
323 if self.rows != other.rows || self.cols != other.cols {
324 Err(MatrixError::DimensionMismatch {
325 msg: format!(
326 "{}×{} vs {}×{}",
327 self.rows, self.cols, other.rows, other.cols
328 ),
329 })
330 } else {
331 Ok(())
332 }
333 }
334
335 pub fn entries(&self) -> &[ExprId] {
337 &self.data
338 }
339
340 pub fn to_nested(&self) -> Vec<Vec<ExprId>> {
342 (0..self.rows).map(|r| self.row(r)).collect()
343 }
344
345 pub fn characteristic_polynomial_lambda_minus_m(
347 &self,
348 pool: &ExprPool,
349 ) -> Result<(ExprId, ExprId), EigenError> {
350 eigen::characteristic_polynomial_lambda_minus_m(self, pool)
351 }
352
353 pub fn eigenvalues(&self, pool: &ExprPool) -> Result<Vec<(ExprId, usize)>, EigenError> {
356 eigen::eigenvalues(self, pool)
357 }
358
359 pub fn eigenvectors(
361 &self,
362 pool: &ExprPool,
363 ) -> Result<Vec<(ExprId, usize, Vec<Matrix>)>, EigenError> {
364 eigen::eigenvectors(self, pool)
365 }
366
367 pub fn diagonalize(&self, pool: &ExprPool) -> Result<(Matrix, Matrix), EigenError> {
369 eigen::diagonalize(self, pool)
370 }
371}
372
373pub fn jacobian(
382 f_vec: &[ExprId],
383 x_vec: &[ExprId],
384 pool: &ExprPool,
385) -> Result<Matrix, crate::diff::diff_impl::DiffError> {
386 let m = f_vec.len();
387 let n = x_vec.len();
388 let mut data = Vec::with_capacity(m * n);
389 for &f in f_vec {
390 for &x in x_vec {
391 let df = diff(f, x, pool)?.value;
392 data.push(df);
393 }
394 }
395 Ok(Matrix {
396 data,
397 rows: m,
398 cols: n,
399 })
400}
401
402impl Matrix {
407 pub fn display(&self, pool: &ExprPool) -> String {
408 let rows: Vec<String> = (0..self.rows)
409 .map(|r| {
410 let entries: Vec<String> = self
411 .row(r)
412 .into_iter()
413 .map(|e| pool.display(e).to_string())
414 .collect();
415 format!("[{}]", entries.join(", "))
416 })
417 .collect();
418 format!("[{}]", rows.join(", "))
419 }
420}
421
#[cfg(test)]
mod tests {
    use super::*;
    use crate::errors::AlkahestError;
    use crate::kernel::{Domain, ExprPool};

    /// Fresh expression pool for each test.
    fn p() -> ExprPool {
        ExprPool::new()
    }

    #[test]
    fn identity_2x2() {
        let pool = p();
        let id = Matrix::identity(2, &pool);
        assert_eq!(id.rows, 2);
        assert_eq!(id.cols, 2);
        assert_eq!(id.get(0, 0), pool.integer(1_i32));
        assert_eq!(id.get(0, 1), pool.integer(0_i32));
        assert_eq!(id.get(1, 0), pool.integer(0_i32));
        assert_eq!(id.get(1, 1), pool.integer(1_i32));
    }

    #[test]
    fn transpose_2x3() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let y = pool.symbol("y", Domain::Real);
        let z = pool.symbol("z", Domain::Real);
        let a = pool.integer(1_i32);
        let b = pool.integer(2_i32);
        let c = pool.integer(3_i32);
        let m = Matrix::new(vec![vec![x, y, z], vec![a, b, c]]).unwrap();
        let t = m.transpose();
        assert_eq!(t.rows, 3);
        assert_eq!(t.cols, 2);
        assert_eq!(t.get(0, 0), x);
        assert_eq!(t.get(1, 1), b);
        assert_eq!(t.get(2, 0), z);
        assert_eq!(t.get(0, 1), a);
        // Transposing twice must give back the original matrix.
        assert_eq!(t.transpose(), m);
    }

    #[test]
    fn add_matrices() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let one = pool.integer(1_i32);
        let m1 = Matrix::new(vec![vec![x, one]]).unwrap();
        let m2 = Matrix::new(vec![vec![one, x]]).unwrap();
        let result = m1.add(&m2, &pool).unwrap();
        let r00_str = pool.display(result.get(0, 0)).to_string();
        assert!(
            r00_str.contains("x") && r00_str.contains("1"),
            "got: {r00_str}"
        );
    }

    #[test]
    fn mul_2x2() {
        let pool = p();
        let id = Matrix::identity(2, &pool);
        let x = pool.symbol("x", Domain::Real);
        let y = pool.symbol("y", Domain::Real);
        let m = Matrix::new(vec![vec![x, y], vec![y, x]]).unwrap();
        let result = id.mul(&m, &pool).unwrap().simplify_entries(&pool);
        assert_eq!(result.get(0, 0), x);
        assert_eq!(result.get(0, 1), y);
    }

    #[test]
    fn det_2x2() {
        let pool = p();
        let a = pool.symbol("a", Domain::Real);
        let b = pool.symbol("b", Domain::Real);
        let c = pool.symbol("c", Domain::Real);
        let d = pool.symbol("d", Domain::Real);
        let m = Matrix::new(vec![vec![a, b], vec![c, d]]).unwrap();
        let det = m.det(&pool).unwrap();
        let s = pool.display(det).to_string();
        // ad - bc must mention all four symbols.
        assert!(
            s.contains("a") && s.contains("b") && s.contains("c") && s.contains("d"),
            "got: {s}"
        );
    }

    #[test]
    fn det_3x3_identity_is_one() {
        let pool = p();
        let id = Matrix::identity(3, &pool);
        let det = id.det(&pool).unwrap();
        assert_eq!(det, pool.integer(1_i32));
    }

    #[test]
    fn det_of_empty_matrix_is_one() {
        let pool = p();
        let empty = Matrix::identity(0, &pool);
        assert_eq!(empty.det(&pool).unwrap(), pool.integer(1_i32));
    }

    #[test]
    fn det_requires_square() {
        let pool = p();
        let err = Matrix::zeros(2, 3, &pool).det(&pool).unwrap_err();
        assert_eq!(err, MatrixError::NotSquare);
        assert_eq!(err.to_string(), "matrix is not square");
        assert_eq!(err.code(), "E-MAT-002");
    }

    #[test]
    fn new_rejects_ragged_rows() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let err = Matrix::new(vec![vec![x, x], vec![x]]).unwrap_err();
        assert!(
            matches!(err, MatrixError::DimensionMismatch { .. }),
            "got: {err:?}"
        );
        assert_eq!(err.code(), "E-MAT-001");
    }

    #[test]
    fn new_empty_is_zero_by_zero() {
        let m = Matrix::new(Vec::new()).unwrap();
        assert_eq!((m.rows, m.cols), (0, 0));
        assert!(m.entries().is_empty());
    }

    #[test]
    fn shape_mismatches_are_errors() {
        let pool = p();
        let two_by_three = Matrix::zeros(2, 3, &pool);
        let three_by_two = Matrix::zeros(3, 2, &pool);
        assert!(matches!(
            two_by_three.add(&three_by_two, &pool),
            Err(MatrixError::DimensionMismatch { .. })
        ));
        assert!(matches!(
            two_by_three.sub(&three_by_two, &pool),
            Err(MatrixError::DimensionMismatch { .. })
        ));
        assert!(matches!(
            two_by_three.mul(&two_by_three, &pool),
            Err(MatrixError::DimensionMismatch { .. })
        ));
        let product = two_by_three.mul(&three_by_two, &pool).unwrap();
        assert_eq!((product.rows, product.cols), (2, 2));
    }

    #[test]
    fn jacobian_linear() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let y = pool.symbol("y", Domain::Real);
        let neg_y = pool.mul(vec![pool.integer(-1_i32), y]);
        let f1 = pool.add(vec![x, y]);
        let f2 = pool.add(vec![x, neg_y]);
        let j = jacobian(&[f1, f2], &[x, y], &pool).unwrap();
        assert_eq!(j.rows, 2);
        assert_eq!(j.cols, 2);
        assert_eq!(j.get(0, 0), pool.integer(1_i32));
        assert_eq!(j.get(0, 1), pool.integer(1_i32));
        assert_eq!(j.get(1, 0), pool.integer(1_i32));
        assert_eq!(j.get(1, 1), pool.integer(-1_i32));
    }

    #[test]
    fn jacobian_quadratic() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let y = pool.symbol("y", Domain::Real);
        let f1 = pool.pow(x, pool.integer(2_i32));
        let f2 = pool.pow(y, pool.integer(2_i32));
        let j = jacobian(&[f1, f2], &[x, y], &pool).unwrap();
        assert_eq!(j.get(0, 1), pool.integer(0_i32));
        assert_eq!(j.get(1, 0), pool.integer(0_i32));
    }
}