Skip to main content

alkahest_cas/matrix/
mod.rs

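//! Phase 15 — Symbolic matrices and vectors.
//!
//! Provides a dense `Matrix` of `ExprId` values together with:
//! - arithmetic methods: `add`, `sub`, `mul` (matrix product) and `scale` (scalar product)
//! - `transpose()`
//! - `det()` (recursive cofactor expansion along the first row; exponential cost, so intended
//!   for small matrices)
//! - `jacobian(f_vec, x_vec, pool)` — the `m×n` matrix `∂f_i/∂x_j`
//! - eigenvalue/eigenvector helpers and Hermite/Smith normal forms, re-exported from the
//!   `eigen` and `normal_form` submodules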
8
9use crate::diff::diff;
10use crate::kernel::{ExprId, ExprPool};
11use crate::simplify::engine::simplify;
12use std::fmt;
13
14pub mod eigen;
15pub mod normal_form;
16mod smith;
17mod smith_poly;
18
19pub use eigen::{
20    characteristic_polynomial_lambda_minus_m, diagonalize, eigenvalues, eigenvectors, EigenError,
21};
22pub use normal_form::{
23    hermite_form, hermite_form_poly, smith_form, smith_form_poly, IntegerMatrix, NormalFormError,
24    PolyMatrixQ, RatUniPoly,
25};
26
27// ---------------------------------------------------------------------------
28// Matrix type
29// ---------------------------------------------------------------------------
30
31/// A dense symbolic matrix stored in row-major order.
32#[derive(Clone, Debug, PartialEq, Eq)]
33pub struct Matrix {
34    /// Row-major flat storage of `ExprId` entries.
35    data: Vec<ExprId>,
36    pub rows: usize,
37    pub cols: usize,
38}
39
/// Errors produced when building or operating on a [`Matrix`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MatrixError {
    /// Shapes are incompatible (ragged rows in a constructor, or operands whose
    /// sizes do not fit `add`/`sub`/`mul`); `msg` describes the conflict.
    DimensionMismatch { msg: String },
    /// A square matrix was required (e.g. by `det`) but a rectangular one was given.
    NotSquare,
    /// The matrix is singular.
    SingularMatrix,
}

impl fmt::Display for MatrixError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Only the dimension-mismatch message carries a payload; the rest are fixed text.
        match self {
            Self::DimensionMismatch { msg } => write!(f, "dimension mismatch: {msg}"),
            Self::NotSquare => f.write_str("matrix is not square"),
            Self::SingularMatrix => f.write_str("matrix is singular"),
        }
    }
}

impl std::error::Error for MatrixError {}
58
59impl crate::errors::AlkahestError for MatrixError {
60    fn code(&self) -> &'static str {
61        match self {
62            MatrixError::DimensionMismatch { .. } => "E-MAT-001",
63            MatrixError::NotSquare => "E-MAT-002",
64            MatrixError::SingularMatrix => "E-MAT-003",
65        }
66    }
67
68    fn remediation(&self) -> Option<&'static str> {
69        match self {
70            MatrixError::DimensionMismatch { .. } => Some(
71                "ensure all rows have the same column count and operand dimensions match",
72            ),
73            MatrixError::NotSquare => Some(
74                "determinant and inverse require a square matrix; use the pseudo-inverse for rectangular matrices",
75            ),
76            MatrixError::SingularMatrix => Some(
77                "the matrix has a zero determinant; check your system of equations for linear dependence",
78            ),
79        }
80    }
81}
82
83impl Matrix {
84    /// Create a matrix from row-major nested vectors.
85    pub fn new(rows: Vec<Vec<ExprId>>) -> Result<Self, MatrixError> {
86        if rows.is_empty() {
87            return Ok(Matrix {
88                data: vec![],
89                rows: 0,
90                cols: 0,
91            });
92        }
93        let cols = rows[0].len();
94        for r in &rows {
95            if r.len() != cols {
96                return Err(MatrixError::DimensionMismatch {
97                    msg: format!("expected {cols} columns, got {}", r.len()),
98                });
99            }
100        }
101        let nrows = rows.len();
102        let data: Vec<ExprId> = rows.into_iter().flatten().collect();
103        Ok(Matrix {
104            data,
105            rows: nrows,
106            cols,
107        })
108    }
109
110    /// Create a zero matrix (all entries are `pool.integer(0)`).
111    pub fn zeros(rows: usize, cols: usize, pool: &ExprPool) -> Self {
112        let zero = pool.integer(0_i32);
113        Matrix {
114            data: vec![zero; rows * cols],
115            rows,
116            cols,
117        }
118    }
119
120    /// Create an identity matrix.
121    pub fn identity(n: usize, pool: &ExprPool) -> Self {
122        let zero = pool.integer(0_i32);
123        let one = pool.integer(1_i32);
124        let mut data = vec![zero; n * n];
125        for i in 0..n {
126            data[i * n + i] = one;
127        }
128        Matrix {
129            data,
130            rows: n,
131            cols: n,
132        }
133    }
134
135    /// Get entry at row `r`, column `c` (0-indexed).
136    pub fn get(&self, r: usize, c: usize) -> ExprId {
137        self.data[r * self.cols + c]
138    }
139
140    /// Set entry at row `r`, column `c`.
141    pub fn set(&mut self, r: usize, c: usize, val: ExprId) {
142        self.data[r * self.cols + c] = val;
143    }
144
145    /// Get a row as a vector.
146    pub fn row(&self, r: usize) -> Vec<ExprId> {
147        self.data[r * self.cols..(r + 1) * self.cols].to_vec()
148    }
149
150    /// Get a column as a vector.
151    pub fn col(&self, c: usize) -> Vec<ExprId> {
152        (0..self.rows).map(|r| self.get(r, c)).collect()
153    }
154
155    /// Transpose.
156    pub fn transpose(&self) -> Self {
157        let mut data = Vec::with_capacity(self.rows * self.cols);
158        for c in 0..self.cols {
159            for r in 0..self.rows {
160                data.push(self.get(r, c));
161            }
162        }
163        Matrix {
164            data,
165            rows: self.cols,
166            cols: self.rows,
167        }
168    }
169
170    /// Element-wise addition.
171    pub fn add(&self, other: &Matrix, pool: &ExprPool) -> Result<Matrix, MatrixError> {
172        self.check_same_shape(other)?;
173        let data = self
174            .data
175            .iter()
176            .zip(other.data.iter())
177            .map(|(&a, &b)| pool.add(vec![a, b]))
178            .collect();
179        Ok(Matrix {
180            data,
181            rows: self.rows,
182            cols: self.cols,
183        })
184    }
185
186    /// Element-wise subtraction.
187    pub fn sub(&self, other: &Matrix, pool: &ExprPool) -> Result<Matrix, MatrixError> {
188        self.check_same_shape(other)?;
189        let neg_one = pool.integer(-1_i32);
190        let data = self
191            .data
192            .iter()
193            .zip(other.data.iter())
194            .map(|(&a, &b)| {
195                let neg_b = pool.mul(vec![neg_one, b]);
196                pool.add(vec![a, neg_b])
197            })
198            .collect();
199        Ok(Matrix {
200            data,
201            rows: self.rows,
202            cols: self.cols,
203        })
204    }
205
206    /// Matrix multiplication (`self` is m×k, `other` is k×n → result is m×n).
207    pub fn mul(&self, other: &Matrix, pool: &ExprPool) -> Result<Matrix, MatrixError> {
208        if self.cols != other.rows {
209            return Err(MatrixError::DimensionMismatch {
210                msg: format!(
211                    "cannot multiply {}×{} by {}×{}",
212                    self.rows, self.cols, other.rows, other.cols
213                ),
214            });
215        }
216        let m = self.rows;
217        let n = other.cols;
218        let k = self.cols;
219        let mut data = Vec::with_capacity(m * n);
220        for r in 0..m {
221            for c in 0..n {
222                let terms: Vec<ExprId> = (0..k)
223                    .map(|i| pool.mul(vec![self.get(r, i), other.get(i, c)]))
224                    .collect();
225                let entry = if terms.is_empty() {
226                    pool.integer(0_i32)
227                } else if terms.len() == 1 {
228                    terms[0]
229                } else {
230                    pool.add(terms)
231                };
232                data.push(entry);
233            }
234        }
235        Ok(Matrix {
236            data,
237            rows: m,
238            cols: n,
239        })
240    }
241
242    /// Scalar multiplication.
243    pub fn scale(&self, scalar: ExprId, pool: &ExprPool) -> Matrix {
244        let data = self
245            .data
246            .iter()
247            .map(|&e| pool.mul(vec![scalar, e]))
248            .collect();
249        Matrix {
250            data,
251            rows: self.rows,
252            cols: self.cols,
253        }
254    }
255
256    /// Simplify all entries.
257    pub fn simplify_entries(&self, pool: &ExprPool) -> Matrix {
258        let data = self.data.iter().map(|&e| simplify(e, pool).value).collect();
259        Matrix {
260            data,
261            rows: self.rows,
262            cols: self.cols,
263        }
264    }
265
266    /// Determinant using Bareiss algorithm (exact over integers, symbolic otherwise).
267    pub fn det(&self, pool: &ExprPool) -> Result<ExprId, MatrixError> {
268        if self.rows != self.cols {
269            return Err(MatrixError::NotSquare);
270        }
271        let n = self.rows;
272        if n == 0 {
273            return Ok(pool.integer(1_i32));
274        }
275        if n == 1 {
276            return Ok(self.get(0, 0));
277        }
278        if n == 2 {
279            // ad - bc
280            let ad = pool.mul(vec![self.get(0, 0), self.get(1, 1)]);
281            let bc = pool.mul(vec![self.get(0, 1), self.get(1, 0)]);
282            let neg_bc = pool.mul(vec![pool.integer(-1_i32), bc]);
283            return Ok(simplify(pool.add(vec![ad, neg_bc]), pool).value);
284        }
285        // Cofactor expansion along first row for n >= 3
286        let mut terms: Vec<ExprId> = Vec::new();
287        for j in 0..n {
288            let minor = self.minor(0, j);
289            let minor_det = minor.det(pool)?;
290            let sign = if j % 2 == 0 {
291                pool.integer(1_i32)
292            } else {
293                pool.integer(-1_i32)
294            };
295            terms.push(pool.mul(vec![sign, self.get(0, j), minor_det]));
296        }
297        Ok(simplify(pool.add(terms), pool).value)
298    }
299
300    /// Submatrix obtained by removing row `r` and column `c`.
301    fn minor(&self, skip_row: usize, skip_col: usize) -> Matrix {
302        let n = self.rows;
303        let mut data = Vec::with_capacity((n - 1) * (n - 1));
304        for r in 0..n {
305            if r == skip_row {
306                continue;
307            }
308            for c in 0..n {
309                if c == skip_col {
310                    continue;
311                }
312                data.push(self.get(r, c));
313            }
314        }
315        Matrix {
316            data,
317            rows: n - 1,
318            cols: n - 1,
319        }
320    }
321
322    fn check_same_shape(&self, other: &Matrix) -> Result<(), MatrixError> {
323        if self.rows != other.rows || self.cols != other.cols {
324            Err(MatrixError::DimensionMismatch {
325                msg: format!(
326                    "{}×{} vs {}×{}",
327                    self.rows, self.cols, other.rows, other.cols
328                ),
329            })
330        } else {
331            Ok(())
332        }
333    }
334
335    /// Return a flat reference to all entries.
336    pub fn entries(&self) -> &[ExprId] {
337        &self.data
338    }
339
340    /// Return entries as a nested `Vec<Vec<ExprId>>`.
341    pub fn to_nested(&self) -> Vec<Vec<ExprId>> {
342        (0..self.rows).map(|r| self.row(r)).collect()
343    }
344
345    /// V2-17 — `det(λI − M)` as a pooled expression plus the fresh λ symbol used.
346    pub fn characteristic_polynomial_lambda_minus_m(
347        &self,
348        pool: &ExprPool,
349    ) -> Result<(ExprId, ExprId), EigenError> {
350        eigen::characteristic_polynomial_lambda_minus_m(self, pool)
351    }
352
353    /// V2-17 — Algebraic eigenvalues `(value, multiplicity)` for matrices whose characteristic
354    /// polynomial factors over ℚ into linear and quadratic terms.
355    pub fn eigenvalues(&self, pool: &ExprPool) -> Result<Vec<(ExprId, usize)>, EigenError> {
356        eigen::eigenvalues(self, pool)
357    }
358
359    /// V2-17 — Eigenvalue tuples `(λ, multiplicity, column eigenvectors)`.
360    pub fn eigenvectors(
361        &self,
362        pool: &ExprPool,
363    ) -> Result<Vec<(ExprId, usize, Vec<Matrix>)>, EigenError> {
364        eigen::eigenvectors(self, pool)
365    }
366
367    /// V2-17 — `(P, D)` with `M·P == P·D` when diagonalizable in the ℚ-splitting-field sense.
368    pub fn diagonalize(&self, pool: &ExprPool) -> Result<(Matrix, Matrix), EigenError> {
369        eigen::diagonalize(self, pool)
370    }
371}
372
373// ---------------------------------------------------------------------------
374// Jacobian
375// ---------------------------------------------------------------------------
376
377/// Compute the Jacobian matrix `J[i][j] = ∂f_i/∂x_j`.
378///
379/// `f_vec` is a slice of m scalar expressions; `x_vec` is a slice of n
380/// variable expressions.  The result is an m×n `Matrix`.
381pub fn jacobian(
382    f_vec: &[ExprId],
383    x_vec: &[ExprId],
384    pool: &ExprPool,
385) -> Result<Matrix, crate::diff::diff_impl::DiffError> {
386    let m = f_vec.len();
387    let n = x_vec.len();
388    let mut data = Vec::with_capacity(m * n);
389    for &f in f_vec {
390        for &x in x_vec {
391            let df = diff(f, x, pool)?.value;
392            data.push(df);
393        }
394    }
395    Ok(Matrix {
396        data,
397        rows: m,
398        cols: n,
399    })
400}
401
402// ---------------------------------------------------------------------------
403// Display
404// ---------------------------------------------------------------------------
405
406impl Matrix {
407    pub fn display(&self, pool: &ExprPool) -> String {
408        let rows: Vec<String> = (0..self.rows)
409            .map(|r| {
410                let entries: Vec<String> = self
411                    .row(r)
412                    .into_iter()
413                    .map(|e| pool.display(e).to_string())
414                    .collect();
415                format!("[{}]", entries.join(", "))
416            })
417            .collect();
418        format!("[{}]", rows.join(", "))
419    }
420}
421
422// ---------------------------------------------------------------------------
423// Tests
424// ---------------------------------------------------------------------------
425
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kernel::{Domain, ExprPool};

    fn p() -> ExprPool {
        ExprPool::new()
    }

    #[test]
    fn identity_2x2() {
        let pool = p();
        let id = Matrix::identity(2, &pool);
        assert_eq!(id.rows, 2);
        assert_eq!(id.cols, 2);
        assert_eq!(id.get(0, 0), pool.integer(1_i32));
        assert_eq!(id.get(0, 1), pool.integer(0_i32));
        assert_eq!(id.get(1, 0), pool.integer(0_i32));
        assert_eq!(id.get(1, 1), pool.integer(1_i32));
    }

    #[test]
    fn zeros_2x3() {
        let pool = p();
        let z = Matrix::zeros(2, 3, &pool);
        assert_eq!((z.rows, z.cols), (2, 3));
        assert_eq!(z.entries().len(), 6);
        assert!(z.entries().iter().all(|&e| e == pool.integer(0_i32)));
    }

    #[test]
    fn new_empty_is_zero_by_zero() {
        let m = Matrix::new(vec![]).unwrap();
        assert_eq!((m.rows, m.cols), (0, 0));
        assert!(m.entries().is_empty());
    }

    #[test]
    fn new_rejects_ragged_rows() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let err = Matrix::new(vec![vec![x, x], vec![x]]).unwrap_err();
        assert!(
            matches!(err, MatrixError::DimensionMismatch { .. }),
            "got: {err:?}"
        );
    }

    #[test]
    fn transpose_2x3() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let y = pool.symbol("y", Domain::Real);
        let z = pool.symbol("z", Domain::Real);
        let a = pool.integer(1_i32);
        let b = pool.integer(2_i32);
        let c = pool.integer(3_i32);
        // [[x, y, z], [a, b, c]]  →  [[x,a],[y,b],[z,c]]
        let m = Matrix::new(vec![vec![x, y, z], vec![a, b, c]]).unwrap();
        let t = m.transpose();
        assert_eq!(t.rows, 3);
        assert_eq!(t.cols, 2);
        assert_eq!(t.get(0, 0), x);
        assert_eq!(t.get(1, 1), b);
        assert_eq!(t.get(2, 0), z);
        assert_eq!(t.get(2, 1), c);
    }

    #[test]
    fn add_matrices() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let one = pool.integer(1_i32);
        let m1 = Matrix::new(vec![vec![x, one]]).unwrap();
        let m2 = Matrix::new(vec![vec![one, x]]).unwrap();
        let result = m1.add(&m2, &pool).unwrap();
        // result[0][0] = x + 1
        let r00_str = pool.display(result.get(0, 0)).to_string();
        assert!(
            r00_str.contains("x") && r00_str.contains("1"),
            "got: {r00_str}"
        );
    }

    #[test]
    fn add_rejects_shape_mismatch() {
        let pool = p();
        let a = Matrix::zeros(2, 3, &pool);
        let b = Matrix::zeros(3, 2, &pool);
        assert!(matches!(
            a.add(&b, &pool),
            Err(MatrixError::DimensionMismatch { .. })
        ));
    }

    #[test]
    fn mul_2x2() {
        let pool = p();
        // [[1,0],[0,1]] * [[x,y],[y,x]] = [[x,y],[y,x]]
        let id = Matrix::identity(2, &pool);
        let x = pool.symbol("x", Domain::Real);
        let y = pool.symbol("y", Domain::Real);
        let m = Matrix::new(vec![vec![x, y], vec![y, x]]).unwrap();
        let result = id.mul(&m, &pool).unwrap().simplify_entries(&pool);
        assert_eq!(result.get(0, 0), x);
        assert_eq!(result.get(0, 1), y);
        assert_eq!(result.get(1, 0), y);
        assert_eq!(result.get(1, 1), x);
    }

    #[test]
    fn mul_rejects_incompatible_shapes() {
        let pool = p();
        let a = Matrix::zeros(2, 3, &pool);
        let b = Matrix::zeros(2, 3, &pool);
        assert!(matches!(
            a.mul(&b, &pool),
            Err(MatrixError::DimensionMismatch { .. })
        ));
    }

    #[test]
    fn det_2x2() {
        let pool = p();
        // det([[a,b],[c,d]]) = ad - bc
        let a = pool.symbol("a", Domain::Real);
        let b = pool.symbol("b", Domain::Real);
        let c = pool.symbol("c", Domain::Real);
        let d = pool.symbol("d", Domain::Real);
        let m = Matrix::new(vec![vec![a, b], vec![c, d]]).unwrap();
        let det = m.det(&pool).unwrap();
        let s = pool.display(det).to_string();
        assert!(s.contains("a") && s.contains("d"), "got: {s}");
    }

    #[test]
    fn det_small_cases() {
        let pool = p();
        // 0×0 determinant is the empty product, 1.
        let empty = Matrix::new(vec![]).unwrap();
        assert_eq!(empty.det(&pool).unwrap(), pool.integer(1_i32));
        // 1×1 determinant is the sole entry.
        let x = pool.symbol("x", Domain::Real);
        let single = Matrix::new(vec![vec![x]]).unwrap();
        assert_eq!(single.det(&pool).unwrap(), x);
    }

    #[test]
    fn det_rejects_non_square() {
        let pool = p();
        let m = Matrix::zeros(2, 3, &pool);
        assert_eq!(m.det(&pool), Err(MatrixError::NotSquare));
    }

    #[test]
    fn det_3x3_identity_is_one() {
        let pool = p();
        let id = Matrix::identity(3, &pool);
        let det = id.det(&pool).unwrap();
        assert_eq!(det, pool.integer(1_i32));
    }

    #[test]
    fn jacobian_linear() {
        // f = [x + y, x - y], vars = [x, y]
        // J = [[1, 1], [1, -1]]
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let y = pool.symbol("y", Domain::Real);
        let neg_y = pool.mul(vec![pool.integer(-1_i32), y]);
        let f1 = pool.add(vec![x, y]);
        let f2 = pool.add(vec![x, neg_y]);
        let j = jacobian(&[f1, f2], &[x, y], &pool).unwrap();
        assert_eq!(j.rows, 2);
        assert_eq!(j.cols, 2);
        assert_eq!(j.get(0, 0), pool.integer(1_i32)); // ∂f1/∂x
        assert_eq!(j.get(0, 1), pool.integer(1_i32)); // ∂f1/∂y
        assert_eq!(j.get(1, 0), pool.integer(1_i32)); // ∂f2/∂x
        assert_eq!(j.get(1, 1), pool.integer(-1_i32)); // ∂f2/∂y
    }

    #[test]
    fn jacobian_quadratic() {
        // f = [x², y²], vars = [x, y]
        // J = [[2x, 0], [0, 2y]]
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let y = pool.symbol("y", Domain::Real);
        let f1 = pool.pow(x, pool.integer(2_i32));
        let f2 = pool.pow(y, pool.integer(2_i32));
        let j = jacobian(&[f1, f2], &[x, y], &pool).unwrap();
        // ∂f1/∂y = 0, ∂f2/∂x = 0
        assert_eq!(j.get(0, 1), pool.integer(0_i32));
        assert_eq!(j.get(1, 0), pool.integer(0_i32));
    }
}