Skip to main content

alkahest_cas/matrix/
eigen.rs

1//! V2-17 — Eigenvalues, eigenvectors, and diagonalization for dense symbolic matrices
2//! whose characteristic polynomial splits over ℚ into linear and quadratic factors.
3//!
4//! The characteristic polynomial is `det(λI − M)` in a fresh λ symbol. Entries may
5//! be rational; the determinant is read as a ℚ-polynomial in λ and cleared to a
6//! ℤ-polynomial for factorization (same roots).
7
8#![allow(clippy::needless_range_loop)]
9
10use crate::kernel::{Domain, ExprData, ExprId, ExprPool};
11use crate::matrix::Matrix;
12use crate::poly::error::ConversionError;
13use crate::poly::unipoly::UniPoly;
14use crate::poly::{factor_univariate_z, FactorError};
15use crate::simplify::engine::{simplify, simplify_expanded};
16use rug::Rational;
17use std::fmt;
18use std::sync::atomic::{AtomicUsize, Ordering};
19
20/// Errors from eigen-decomposition helpers.
21#[derive(Debug, Clone, PartialEq, Eq)]
22pub enum EigenError {
23    /// `eigenvals` requires a square matrix.
24    NonSquare,
25    /// The determinant polynomial could not be read as ℤ\[λ\].
26    CharPolyConversion(ConversionError),
27    /// FLINT factorization failed.
28    Factorization(FactorError),
29    /// The characteristic polynomial has an irreducible factor of degree greater than two.
30    UnsupportedIrreducibleDegree { degree: usize },
31    /// Algebraic and geometric multiplicity disagree (Jordan block situation).
32    NonDiagonalizable,
33    /// Gaussian elimination failed to produce a numerical field representation.
34    KernelComputationFailed,
35    /// `P` in `diagonalize` is singular / not invertible.
36    SingularModalMatrix,
37}
38
39impl fmt::Display for EigenError {
40    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
41        match self {
42            EigenError::NonSquare => write!(f, "eigen decomposition requires a square matrix"),
43            EigenError::CharPolyConversion(e) => write!(f, "characteristic polynomial: {e}"),
44            EigenError::Factorization(e) => write!(f, "factorization failed: {e}"),
45            EigenError::UnsupportedIrreducibleDegree { degree } => write!(
46                f,
47                "irreducible characteristic factor of degree {degree}; only degrees 1–2 are supported"
48            ),
49            EigenError::NonDiagonalizable => {
50                write!(f, "matrix is not diagonalizable over the computed eigenbasis")
51            }
52            EigenError::KernelComputationFailed => write!(
53                f,
54                "could not compute eigenvectors (nullspace) for this coefficient field"
55            ),
56            EigenError::SingularModalMatrix => {
57                write!(f, "eigenvector matrix is singular — no diagonal decomposition")
58            }
59        }
60    }
61}
62
63impl std::error::Error for EigenError {}
64
65impl crate::errors::AlkahestError for EigenError {
66    fn code(&self) -> &'static str {
67        match self {
68            EigenError::NonSquare => "E-EIGEN-001",
69            EigenError::CharPolyConversion(_) => "E-EIGEN-002",
70            EigenError::Factorization(_) => "E-EIGEN-003",
71            EigenError::UnsupportedIrreducibleDegree { .. } => "E-EIGEN-004",
72            EigenError::NonDiagonalizable => "E-EIGEN-005",
73            EigenError::KernelComputationFailed => "E-EIGEN-006",
74            EigenError::SingularModalMatrix => "E-EIGEN-007",
75        }
76    }
77
78    fn remediation(&self) -> Option<&'static str> {
79        match self {
80            EigenError::NonSquare => Some("pass a square n×n matrix"),
81            EigenError::CharPolyConversion(_) => Some(
82                "entries must simplify to rationals/constants so det(λI−M) is a polynomial in λ",
83            ),
84            EigenError::Factorization(_) => None,
85            EigenError::UnsupportedIrreducibleDegree { .. } => {
86                Some("higher-degree irreducible characteristic factors require a CAS / algebraic-numbers extension")
87            }
88            EigenError::NonDiagonalizable => {
89                Some("use Jordan-form tooling or restrict to diagonalizable matrices")
90            }
91            EigenError::KernelComputationFailed => Some(
92                "try a purely rational spectrum or a matrix with quadratic splitting only over ℚ",
93            ),
94            EigenError::SingularModalMatrix => Some(
95                "the computed eigenvectors are linearly dependent; check multiplicities or input",
96            ),
97        }
98    }
99}
100
101static EIGEN_LAM_SEQ: AtomicUsize = AtomicUsize::new(0);
102
103fn fresh_lambda(pool: &ExprPool) -> ExprId {
104    let n = EIGEN_LAM_SEQ.fetch_add(1, Ordering::Relaxed);
105    pool.symbol(format!("__eigen_lambda_{n}"), Domain::Complex)
106}
107
108/// √(−1), used internally for quadratic splitting roots and ℚ(\`i\`) nullspaces.
109pub(crate) fn imag_unit(pool: &ExprPool) -> ExprId {
110    let neg_one = pool.integer(-1_i32);
111    pool.func("sqrt", vec![neg_one])
112}
113
114/// `(det(λ I − M), λ)` — the determinant is simplified before return.
115pub fn characteristic_polynomial_lambda_minus_m(
116    m: &Matrix,
117    pool: &ExprPool,
118) -> Result<(ExprId, ExprId), EigenError> {
119    if m.rows != m.cols {
120        return Err(EigenError::NonSquare);
121    }
122    let lam = fresh_lambda(pool);
123    let lm = lambda_identity_minus_m(m, lam, pool);
124    let det = lm.det(pool).map_err(|_| EigenError::NonSquare)?;
125    Ok((simplify(det, pool).value, lam))
126}
127
128/// multiset of eigenvalue Expr → algebraic multiplicity  
129pub fn eigenvalues(m: &Matrix, pool: &ExprPool) -> Result<Vec<(ExprId, usize)>, EigenError> {
130    let (poly_e, lam) = characteristic_polynomial_lambda_minus_m(m, pool)?;
131    eigenvalues_from_char_poly(poly_e, lam, pool)
132}
133
134/// `(value, multiplicity, column eigenvectors)`
135pub fn eigenvectors(
136    m: &Matrix,
137    pool: &ExprPool,
138) -> Result<Vec<(ExprId, usize, Vec<Matrix>)>, EigenError> {
139    let vals = eigenvalues(m, pool)?;
140    let mut out = Vec::with_capacity(vals.len());
141    for (lambda, mult) in vals {
142        let b = m_minus_lambda_scaled(m, lambda, pool);
143        let vecs =
144            kernel_column_basis(&b, pool).map_err(|_| EigenError::KernelComputationFailed)?;
145        out.push((lambda, mult, vecs));
146    }
147    Ok(out)
148}
149
150/// Returns `(P, D)` with `M·P == P·D` (same convention as SymPy: columns of `P` are eigenvectors).
151pub fn diagonalize(m: &Matrix, pool: &ExprPool) -> Result<(Matrix, Matrix), EigenError> {
152    let evecs = eigenvectors(m, pool)?;
153    for (_lambda, alg_m, vecs) in &evecs {
154        if vecs.len() != *alg_m {
155            return Err(EigenError::NonDiagonalizable);
156        }
157    }
158    let n = m.rows;
159    let mut cols: Vec<Matrix> = Vec::with_capacity(n);
160    let mut diag_entries: Vec<ExprId> = Vec::with_capacity(n);
161    for (lambda, _alg_m, vecs) in evecs {
162        for v in vecs {
163            cols.push(v);
164            diag_entries.push(lambda);
165        }
166    }
167    if cols.len() != n {
168        return Err(EigenError::NonDiagonalizable);
169    }
170    let p_mat =
171        concatenate_columns(&cols, pool).map_err(|_| EigenError::KernelComputationFailed)?;
172    // Verify full rank geometrically via det / invertibility later
173    let d_mat = diagonal_from_entries(&diag_entries, pool);
174    if !columns_match_eigen_relation(m, &p_mat, pool, &diag_entries) {
175        return Err(EigenError::NonDiagonalizable);
176    }
177    Ok((p_mat, d_mat))
178}
179
180// ---------------------------------------------------------------------------
181// Characteristic polynomial → eigenvalues
182// ---------------------------------------------------------------------------
183
184fn eigenvalues_from_char_poly(
185    poly_e: ExprId,
186    lam: ExprId,
187    pool: &ExprPool,
188) -> Result<Vec<(ExprId, usize)>, EigenError> {
189    let uni = UniPoly::from_symbolic_clear_denoms(poly_e, lam, pool)
190        .map_err(EigenError::CharPolyConversion)?;
191    let fac = factor_univariate_z(&uni).map_err(EigenError::Factorization)?;
192    let mut pairs: Vec<(ExprId, usize)> = Vec::new();
193    for (base, exp) in fac.factors {
194        let d = base.degree() as usize;
195        match d {
196            0 => continue,
197            1 => {
198                let root = linear_root(&base, pool)
199                    .ok_or(EigenError::UnsupportedIrreducibleDegree { degree: d })?;
200                pairs.push((root, exp as usize));
201            }
202            2 => {
203                let (r1, r2) = quadratic_roots(&base, pool)?;
204                pairs.push((r1, exp as usize));
205                pairs.push((r2, exp as usize));
206            }
207            _ => return Err(EigenError::UnsupportedIrreducibleDegree { degree: d }),
208        }
209    }
210    sort_eigenpairs(&pairs, pool)
211}
212
213fn linear_root(p: &UniPoly, pool: &ExprPool) -> Option<ExprId> {
214    let c = p.coefficients();
215    if c.len() != 2 {
216        return None;
217    }
218    // c0 + c1 λ
219    let numer = -&c[0];
220    let denom = c[1].clone();
221    if denom == 0 {
222        None
223    } else {
224        Some(pool.rational(numer, denom))
225    }
226}
227
228fn quadratic_roots(p: &UniPoly, pool: &ExprPool) -> Result<(ExprId, ExprId), EigenError> {
229    let c = p.coefficients();
230    if c.len() != 3 {
231        return Err(EigenError::UnsupportedIrreducibleDegree {
232            degree: p.degree().max(0) as usize,
233        });
234    }
235    let c0 = c[0].clone();
236    let c1 = c[1].clone();
237    let c2 = c[2].clone();
238    if c2 == 0 {
239        return Err(EigenError::UnsupportedIrreducibleDegree { degree: 1 });
240    }
241    // Prefer ±√(-c₀/c₂) when c₁ = 0 (e.g. λ² + 1) so roots use √(−1) and match the ℚ(i) path.
242    if c1 == 0 {
243        let r_rat = Rational::from((rug::Integer::from(0) - &c0, c2.clone()));
244        let inner_sqrt = if r_rat.denom().clone() == 1 {
245            pool.integer(r_rat.numer().clone())
246        } else {
247            pool.rational(r_rat.numer().clone(), r_rat.denom().clone())
248        };
249        let sd = simplify(pool.func("sqrt", vec![inner_sqrt]), pool).value;
250        let r_minus = simplify(pool.mul(vec![pool.integer(-1_i32), sd]), pool).value;
251        let (x, y) = order_two_roots(sd, r_minus, pool);
252        return Ok((x, y));
253    }
254    let mut disc = c1.clone() * c1.clone();
255    disc -= rug::Integer::from(4) * c0 * c2.clone();
256    let sqrt_d = simplify(
257        pool.func("sqrt", vec![expr_from_integer(&disc, pool)]),
258        pool,
259    )
260    .value;
261    let neg_c1 = rug::Integer::from(0) - &c1;
262    let minus_c1 = expr_from_integer(&neg_c1, pool);
263    let two_a = c2.clone() * rug::Integer::from(2);
264    let inv_2a = pool.rational(rug::Integer::from(1), two_a.clone());
265    let r_plus = simplify(
266        pool.mul(vec![
267            inv_2a,
268            simplify(pool.add(vec![minus_c1, sqrt_d]), pool).value,
269        ]),
270        pool,
271    )
272    .value;
273    let neg_sqrt = simplify(pool.mul(vec![pool.integer(-1_i32), sqrt_d]), pool).value;
274    let r_minus = simplify(
275        pool.mul(vec![
276            inv_2a,
277            simplify(pool.add(vec![minus_c1, neg_sqrt]), pool).value,
278        ]),
279        pool,
280    )
281    .value;
282    let (x, y) = order_two_roots(r_plus, r_minus, pool);
283    Ok((x, y))
284}
285
286fn expr_from_integer(n: &rug::Integer, pool: &ExprPool) -> ExprId {
287    pool.integer(n.clone())
288}
289
290/// Lexicographic sort on `pool.display` for stable tests.
291fn sort_eigenpairs(
292    pairs: &[(ExprId, usize)],
293    pool: &ExprPool,
294) -> Result<Vec<(ExprId, usize)>, EigenError> {
295    let mut v: Vec<(ExprId, usize)> = pairs.to_vec();
296    v.sort_by(|a, b| {
297        pool.display(a.0)
298            .to_string()
299            .cmp(&pool.display(b.0).to_string())
300    });
301    Ok(v)
302}
303
304fn order_two_roots(a: ExprId, b: ExprId, pool: &ExprPool) -> (ExprId, ExprId) {
305    let sa = pool.display(a).to_string();
306    let sb = pool.display(b).to_string();
307    if sa <= sb {
308        (a, b)
309    } else {
310        (b, a)
311    }
312}
313
314// ---------------------------------------------------------------------------
315// λ I − M and M − λ I
316// ---------------------------------------------------------------------------
317
318fn lambda_identity_minus_m(m: &Matrix, lam: ExprId, pool: &ExprPool) -> Matrix {
319    let n = m.rows;
320    let neg_one = pool.integer(-1_i32);
321    let mut data = Vec::with_capacity(n * n);
322    for r in 0..n {
323        for c in 0..n {
324            let e = if r == c {
325                let term = pool.mul(vec![neg_one, m.get(r, c)]);
326                pool.add(vec![lam, term])
327            } else {
328                pool.mul(vec![neg_one, m.get(r, c)])
329            };
330            data.push(e);
331        }
332    }
333    Matrix {
334        data,
335        rows: n,
336        cols: n,
337    }
338}
339
340fn m_minus_lambda_scaled(m: &Matrix, lambda: ExprId, pool: &ExprPool) -> Matrix {
341    let n = m.rows;
342    let mut data = Vec::with_capacity(n * n);
343    for r in 0..n {
344        for c in 0..n {
345            let e = if r == c {
346                let neg_l = pool.mul(vec![pool.integer(-1_i32), lambda]);
347                pool.add(vec![m.get(r, c), neg_l])
348            } else {
349                m.get(r, c)
350            };
351            data.push(e);
352        }
353    }
354    Matrix {
355        data,
356        rows: n,
357        cols: n,
358    }
359}
360
361// ---------------------------------------------------------------------------
362// Nullspace
363// ---------------------------------------------------------------------------
364
365fn kernel_2x2_column_basis(m: &Matrix, pool: &ExprPool) -> Option<Vec<Matrix>> {
366    let a00 = simplify(m.get(0, 0), pool).value;
367    let b01 = simplify(m.get(0, 1), pool).value;
368    let c10 = simplify(m.get(1, 0), pool).value;
369    let d11 = simplify(m.get(1, 1), pool).value;
370    let neg_one = pool.integer(-1_i32);
371    let (a, b) = if expr_is_exactly_zero(pool, a00) && expr_is_exactly_zero(pool, b01) {
372        if expr_is_exactly_zero(pool, c10) && expr_is_exactly_zero(pool, d11) {
373            return None;
374        }
375        (c10, d11)
376    } else {
377        (a00, b01)
378    };
379    let neg_b = simplify(pool.mul(vec![neg_one, b]), pool).value;
380    let col = Matrix::new(vec![vec![neg_b], vec![a]]).ok()?;
381    Some(vec![col])
382}
383
384fn kernel_column_basis(m: &Matrix, pool: &ExprPool) -> Result<Vec<Matrix>, ()> {
385    if m.rows == 2 && m.cols == 2 {
386        if let Some(bas) = kernel_2x2_column_basis(m, pool) {
387            return Ok(bas);
388        }
389    }
390    let cols = m.cols;
391    let n = m.rows;
392    if let Some(rm) = matrix_to_rational_grid(m, pool) {
393        let bas = rational_nullspace_basis(&rm, n, cols);
394        return Ok(bas
395            .into_iter()
396            .map(|col| col_vector_from_rationals(&col, pool))
397            .collect());
398    }
399    let iu = imag_unit(pool);
400    if let Some(qm) = matrix_to_qi_grid(m, iu, pool) {
401        let bas = qi_nullspace_basis(&qm, n, cols);
402        return Ok(bas
403            .into_iter()
404            .map(|col| col_vector_from_qi(&col, iu, pool))
405            .collect());
406    }
407    let bas = gauss_nullspace_expr(m, pool)?;
408    Ok(bas
409        .into_iter()
410        .map(|col| col_vector_from_expr_slice(&col, pool))
411        .collect())
412}
413
414fn col_vector_from_rationals(v: &[Rational], pool: &ExprPool) -> Matrix {
415    let rows: Vec<Vec<ExprId>> = v
416        .iter()
417        .map(|r| vec![pool.rational(r.numer().clone(), r.denom().clone())])
418        .collect();
419    Matrix::new(rows).expect("cols")
420}
421
422fn col_vector_from_qi(v: &[(Rational, Rational)], imag: ExprId, pool: &ExprPool) -> Matrix {
423    let rows: Vec<Vec<ExprId>> = v
424        .iter()
425        .map(|(re, im)| {
426            let re_e = pool.rational(re.numer().clone(), re.denom().clone());
427            if im == &Rational::from(0) {
428                vec![re_e]
429            } else {
430                let im_e = pool.rational(im.numer().clone(), im.denom().clone());
431                let im_term = simplify(pool.mul(vec![im_e, imag]), pool).value;
432                vec![simplify(pool.add(vec![re_e, im_term]), pool).value]
433            }
434        })
435        .collect();
436    Matrix::new(rows).expect("qi col")
437}
438
439fn col_vector_from_expr_slice(v: &[ExprId], _pool: &ExprPool) -> Matrix {
440    let rows: Vec<Vec<ExprId>> = v.iter().copied().map(|e| vec![e]).collect();
441    Matrix::new(rows).expect("expr col")
442}
443
444fn is_sqrt_of_negative_one(pool: &ExprPool, e: ExprId) -> bool {
445    match pool.get(e) {
446        ExprData::Func { name, args } if name == "sqrt" && args.len() == 1 => {
447            args[0] == pool.integer(-1_i32)
448        }
449        _ => false,
450    }
451}
452
453fn squash_sqrt_minus_one_squared(e: ExprId, pool: &ExprPool) -> ExprId {
454    fn rec(e: ExprId, pool: &ExprPool) -> ExprId {
455        match pool.get(e) {
456            ExprData::Pow { base, exp } => {
457                if let ExprData::Integer(n) = pool.get(exp) {
458                    if n.0 == 2 && is_sqrt_of_negative_one(pool, base) {
459                        return pool.integer(-1_i32);
460                    }
461                }
462                let nb = rec(base, pool);
463                let ne = rec(exp, pool);
464                pool.pow(nb, ne)
465            }
466            ExprData::Add(args) => {
467                let v: Vec<ExprId> = args.iter().map(|&a| rec(a, pool)).collect();
468                pool.add(v)
469            }
470            ExprData::Mul(args) => {
471                let v: Vec<ExprId> = args.iter().map(|&a| rec(a, pool)).collect();
472                pool.mul(v)
473            }
474            _ => e,
475        }
476    }
477    rec(e, pool)
478}
479
480fn deep_normalize_for_compare(expr: ExprId, pool: &ExprPool, rounds: usize) -> ExprId {
481    let mut cur = squash_sqrt_minus_one_squared(expr, pool);
482    for _ in 0..rounds {
483        let n = simplify_expanded(cur, pool).value;
484        let n2 = simplify(n, pool).value;
485        if n2 == cur {
486            break;
487        }
488        cur = n2;
489    }
490    cur
491}
492
493#[allow(dead_code)]
494fn matrix_eq_simplified(a: &Matrix, b: &Matrix, pool: &ExprPool) -> bool {
495    if a.rows != b.rows || a.cols != b.cols {
496        return false;
497    }
498    for i in 0..a.rows * a.cols {
499        let ea = deep_normalize_for_compare(a.entries()[i], pool, 12);
500        let eb = deep_normalize_for_compare(b.entries()[i], pool, 12);
501        if ea != eb {
502            return false;
503        }
504    }
505    true
506}
507
508/// `(M P)_{r,j} == λ_j P_{r,j}` for all `r`, `j`.
509fn columns_match_eigen_relation(
510    m: &Matrix,
511    p: &Matrix,
512    pool: &ExprPool,
513    lambdas: &[ExprId],
514) -> bool {
515    let n = m.rows;
516    if m.cols != n || p.rows != n || p.cols != n || lambdas.len() != n {
517        return false;
518    }
519    for j in 0..n {
520        let lam = lambdas[j];
521        for r in 0..n {
522            let mut terms: Vec<ExprId> = Vec::with_capacity(n);
523            for k in 0..n {
524                terms.push(pool.mul(vec![m.get(r, k), p.get(k, j)]));
525            }
526            let lhs = simplify(pool.add(terms), pool).value;
527            let rhs = simplify(pool.mul(vec![lam, p.get(r, j)]), pool).value;
528            let lhs3 = deep_normalize_for_compare(lhs, pool, 12);
529            let rhs3 = deep_normalize_for_compare(rhs, pool, 12);
530            if lhs3 != rhs3 {
531                return false;
532            }
533        }
534    }
535    true
536}
537
538// --- ℚ grid ---
539
540fn matrix_to_rational_grid(m: &Matrix, pool: &ExprPool) -> Option<Vec<Vec<Rational>>> {
541    let mut g = Vec::with_capacity(m.rows);
542    for r in 0..m.rows {
543        let mut row = Vec::with_capacity(m.cols);
544        for c in 0..m.cols {
545            row.push(expr_to_rational_strict(m.get(r, c), pool)?);
546        }
547        g.push(row);
548    }
549    Some(g)
550}
551
552fn expr_to_rational_strict(e: ExprId, pool: &ExprPool) -> Option<Rational> {
553    match pool.get(e) {
554        ExprData::Integer(ref n) => Some(Rational::from((n.0.clone(), rug::Integer::from(1)))),
555        ExprData::Rational(ref r) => Some(r.0.clone()),
556        ExprData::Add(ref args) => {
557            let mut acc = Rational::from(0);
558            for &a in args {
559                acc += expr_to_rational_strict(a, pool)?;
560            }
561            Some(acc)
562        }
563        ExprData::Mul(ref args) => {
564            let mut acc = Rational::from(1);
565            for &a in args {
566                acc *= expr_to_rational_strict(a, pool)?;
567            }
568            Some(acc)
569        }
570        ExprData::Pow { base, exp } => match pool.get(exp) {
571            ExprData::Integer(n) => {
572                let ei = n.0.to_i32()?;
573                if ei < 0 {
574                    None
575                } else {
576                    let b = expr_to_rational_strict(base, pool)?;
577                    Some(if ei == 0 {
578                        Rational::from(1)
579                    } else {
580                        let mut acc = Rational::from(1);
581                        for _ in 0..ei {
582                            acc *= b.clone();
583                        }
584                        acc
585                    })
586                }
587            }
588            _ => None,
589        },
590        _ => None,
591    }
592}
593
594fn rational_nullspace_basis(mat: &[Vec<Rational>], rows: usize, cols: usize) -> Vec<Vec<Rational>> {
595    let mut a = mat.to_vec();
596    let mut pivot_cols: Vec<usize> = Vec::new();
597    let mut r = 0usize;
598    for c in 0..cols {
599        if r >= rows {
600            break;
601        }
602        let mut piv = None;
603        for rr in r..rows {
604            if a[rr][c] != 0 {
605                piv = Some(rr);
606                break;
607            }
608        }
609        let Some(pr) = piv else { continue };
610        if pr != r {
611            a.swap(pr, r);
612        }
613        let inv = Rational::from(1) / a[r][c].clone();
614        for cc in 0..cols {
615            a[r][cc] *= inv.clone();
616        }
617        for rr in 0..rows {
618            if rr == r {
619                continue;
620            }
621            let f = a[rr][c].clone();
622            if f == 0 {
623                continue;
624            }
625            for cc in 0..cols {
626                let sub = f.clone() * a[r][cc].clone();
627                a[rr][cc] -= sub;
628            }
629        }
630        pivot_cols.push(c);
631        r += 1;
632    }
633    let mut is_pivot = vec![false; cols];
634    for &pc in &pivot_cols {
635        is_pivot[pc] = true;
636    }
637    let mut bases: Vec<Vec<Rational>> = Vec::new();
638    for fc in 0..cols {
639        if is_pivot[fc] {
640            continue;
641        }
642        let mut v = vec![Rational::from(0); cols];
643        v[fc] = Rational::from(1);
644        for (i, &pc) in pivot_cols.iter().enumerate() {
645            if i >= r {
646                break;
647            }
648            v[pc] = -a[i][fc].clone();
649        }
650        bases.push(v);
651    }
652    bases
653}
654
655// --- ℚ(i) when entries are `(re) + (im)*sqrt(-1)` with rational re, im ---
656
657fn matrix_to_qi_grid(
658    m: &Matrix,
659    imag: ExprId,
660    pool: &ExprPool,
661) -> Option<Vec<Vec<(Rational, Rational)>>> {
662    let mut g = Vec::with_capacity(m.rows);
663    for r in 0..m.rows {
664        let mut row = Vec::with_capacity(m.cols);
665        for c in 0..m.cols {
666            row.push(split_qi_linear(m.get(r, c), imag, pool)?);
667        }
668        g.push(row);
669    }
670    Some(g)
671}
672
673fn split_qi_linear(e: ExprId, imag: ExprId, pool: &ExprPool) -> Option<(Rational, Rational)> {
674    if let Some(r) = expr_to_rational_strict(e, pool) {
675        return Some((r, Rational::from(0)));
676    }
677    if e == imag {
678        return Some((Rational::from(0), Rational::from(1)));
679    }
680    match pool.get(e) {
681        ExprData::Mul(ref args) if args.contains(&imag) => {
682            let rest: Vec<ExprId> = args.iter().copied().filter(|&x| x != imag).collect();
683            let prod = if rest.is_empty() {
684                pool.integer(1_i32)
685            } else if rest.len() == 1 {
686                rest[0]
687            } else {
688                pool.mul(rest)
689            };
690            Some((Rational::from(0), expr_to_rational_strict(prod, pool)?))
691        }
692        ExprData::Add(ref args) => {
693            let mut re = Rational::from(0);
694            let mut im = Rational::from(0);
695            for &a in args {
696                if a == imag {
697                    im += Rational::from(1);
698                } else if let ExprData::Mul(ms) = pool.get(a) {
699                    if ms.contains(&imag) {
700                        let rest: Vec<ExprId> = ms.iter().copied().filter(|&x| x != imag).collect();
701                        let prod = if rest.is_empty() {
702                            pool.integer(1_i32)
703                        } else if rest.len() == 1 {
704                            rest[0]
705                        } else {
706                            pool.mul(rest)
707                        };
708                        im += expr_to_rational_strict(prod, pool)?;
709                    } else {
710                        re += expr_to_rational_strict(a, pool)?;
711                    }
712                } else {
713                    re += expr_to_rational_strict(a, pool)?;
714                }
715            }
716            Some((re, im))
717        }
718        _ => None,
719    }
720}
721
722fn qi_mul(a: (Rational, Rational), b: (Rational, Rational)) -> (Rational, Rational) {
723    let (ar, ai) = a;
724    let (br, bi) = b;
725    (
726        ar.clone() * br.clone() - ai.clone() * bi.clone(),
727        ar * bi + ai * br,
728    )
729}
730
731fn qi_add(a: (Rational, Rational), b: (Rational, Rational)) -> (Rational, Rational) {
732    (a.0 + b.0, a.1 + b.1)
733}
734
735fn qi_neg(a: (Rational, Rational)) -> (Rational, Rational) {
736    (-a.0, -a.1)
737}
738
739fn qi_is_zero(q: &(Rational, Rational)) -> bool {
740    q.0.is_zero() && q.1.is_zero()
741}
742
743fn qi_inv(a: (Rational, Rational)) -> Option<(Rational, Rational)> {
744    let norm = a.0.clone() * a.0.clone() + a.1.clone() * a.1.clone();
745    if norm.is_zero() {
746        None
747    } else {
748        Some((a.0.clone() / norm.clone(), (-a.1.clone()) / norm.clone()))
749    }
750}
751
752fn qi_nullspace_basis(
753    mat: &[Vec<(Rational, Rational)>],
754    rows: usize,
755    cols: usize,
756) -> Vec<Vec<(Rational, Rational)>> {
757    let mut a = mat.to_vec();
758    let mut pivot_cols: Vec<usize> = Vec::new();
759    let mut r = 0usize;
760    for c in 0..cols {
761        if r >= rows {
762            break;
763        }
764        let mut piv = None;
765        for rr in r..rows {
766            if !qi_is_zero(&a[rr][c]) {
767                piv = Some(rr);
768                break;
769            }
770        }
771        let Some(pr) = piv else { continue };
772        if pr != r {
773            a.swap(pr, r);
774        }
775        let inv = qi_inv(a[r][c].clone()).unwrap();
776        for cc in 0..cols {
777            a[r][cc] = qi_mul(inv.clone(), a[r][cc].clone());
778        }
779        for rr in 0..rows {
780            if rr == r {
781                continue;
782            }
783            let f = a[rr][c].clone();
784            if qi_is_zero(&f) {
785                continue;
786            }
787            for cc in 0..cols {
788                let sub = qi_mul(f.clone(), a[r][cc].clone());
789                a[rr][cc] = qi_add(a[rr][cc].clone(), qi_neg(sub));
790            }
791        }
792        pivot_cols.push(c);
793        r += 1;
794    }
795    let mut is_pivot = vec![false; cols];
796    for &pc in &pivot_cols {
797        is_pivot[pc] = true;
798    }
799    let mut bases: Vec<Vec<(Rational, Rational)>> = Vec::new();
800    for fc in 0..cols {
801        if is_pivot[fc] {
802            continue;
803        }
804        let mut v = vec![(Rational::from(0), Rational::from(0)); cols];
805        v[fc] = (Rational::from(1), Rational::from(0));
806        for (i, &pc) in pivot_cols.iter().enumerate() {
807            if i >= r {
808                break;
809            }
810            v[pc] = qi_neg(a[i][fc].clone());
811        }
812        bases.push(v);
813    }
814    bases
815}
816
817// --- Expr Gaussian fallback ---
818
819fn gauss_nullspace_expr(m: &Matrix, pool: &ExprPool) -> Result<Vec<Vec<ExprId>>, ()> {
820    let rows = m.rows;
821    let cols = m.cols;
822    let mut a: Vec<Vec<ExprId>> = (0..rows)
823        .map(|r| {
824            (0..cols)
825                .map(|c| simplify(m.get(r, c), pool).value)
826                .collect()
827        })
828        .collect();
829    let neg_one = pool.integer(-1_i32);
830    let mut pivot_cols: Vec<usize> = Vec::new();
831    let mut r_at = 0usize;
832    for c in 0..cols {
833        if r_at >= rows {
834            break;
835        }
836        let mut prow = None;
837        for rr in r_at..rows {
838            let e = simplify(a[rr][c], pool).value;
839            if !expr_is_exactly_zero(pool, e) {
840                prow = Some((rr, e));
841                break;
842            }
843        }
844        let Some((pr, piv)) = prow else { continue };
845        if pr != r_at {
846            a.swap(pr, r_at);
847        }
848        let inv_p = simplify(pool.pow(piv, pool.integer(-1_i32)), pool).value;
849        for cc in 0..cols {
850            a[r_at][cc] = simplify(pool.mul(vec![inv_p, a[r_at][cc]]), pool).value;
851        }
852        for rr in 0..rows {
853            if rr == r_at {
854                continue;
855            }
856            let f = simplify(a[rr][c], pool).value;
857            if expr_is_exactly_zero(pool, f) {
858                continue;
859            }
860            for cc in 0..cols {
861                let term = simplify(pool.mul(vec![f, a[r_at][cc]]), pool).value;
862                let neg_term = simplify(pool.mul(vec![neg_one, term]), pool).value;
863                a[rr][cc] = simplify(pool.add(vec![a[rr][cc], neg_term]), pool).value;
864            }
865        }
866        pivot_cols.push(c);
867        r_at += 1;
868    }
869    let mut is_pivot = vec![false; cols];
870    for &pc in &pivot_cols {
871        is_pivot[pc] = true;
872    }
873    let mut bases: Vec<Vec<ExprId>> = Vec::new();
874    for fc in 0..cols {
875        if is_pivot[fc] {
876            continue;
877        }
878        let mut v = vec![pool.integer(0_i32); cols];
879        v[fc] = pool.integer(1_i32);
880        for (i, &pc) in pivot_cols.iter().enumerate() {
881            if i >= r_at {
882                break;
883            }
884            let neg_entry = simplify(pool.mul(vec![neg_one, a[i][fc]]), pool).value;
885            v[pc] = neg_entry;
886        }
887        bases.push(v);
888    }
889    Ok(bases)
890}
891
892fn expr_is_exactly_zero(pool: &ExprPool, e: ExprId) -> bool {
893    match pool.get(e) {
894        ExprData::Integer(n) => n.0 == 0,
895        ExprData::Rational(r) => r.0 == 0,
896        _ => false,
897    }
898}
899
900// ---------------------------------------------------------------------------
901// Column concat + diagonal
902// ---------------------------------------------------------------------------
903
904fn concatenate_columns(cols: &[Matrix], _pool: &ExprPool) -> Result<Matrix, ()> {
905    if cols.is_empty() {
906        return Err(());
907    }
908    let n = cols[0].rows;
909    for c in cols {
910        if c.rows != n || c.cols != 1 {
911            return Err(());
912        }
913    }
914    let mut data = Vec::with_capacity(n * cols.len());
915    for r in 0..n {
916        for col in cols {
917            data.push(col.get(r, 0));
918        }
919    }
920    Ok(Matrix {
921        data,
922        rows: n,
923        cols: cols.len(),
924    })
925}
926
927fn diagonal_from_entries(d: &[ExprId], pool: &ExprPool) -> Matrix {
928    let n = d.len();
929    let mut mat = Matrix::zeros(n, n, pool);
930    for i in 0..n {
931        mat.set(i, i, d[i]);
932    }
933    mat
934}
935
936// ---------------------------------------------------------------------------
937// Tests
938// ---------------------------------------------------------------------------
939
#[cfg(test)]
mod tests {
    use super::*;

    fn pool() -> ExprPool {
        ExprPool::new()
    }

    /// `[[2, 1], [0, 2]]` is a single Jordan block. It has one eigenvalue of algebraic
    /// multiplicity 2 but only a one-dimensional eigenspace, so diagonalization must fail.
    #[test]
    fn jordan_block_eigenspace_one() {
        let ctx = pool();
        let (two, one, zero) = (ctx.integer(2_i32), ctx.integer(1_i32), ctx.integer(0_i32));
        let jordan = Matrix::new(vec![vec![two, one], vec![zero, two]]).unwrap();

        let spectrum = eigenvalues(&jordan, &ctx).unwrap();
        assert_eq!(spectrum.len(), 1);
        assert_eq!(spectrum[0].1, 2);

        let eigen = eigenvectors(&jordan, &ctx).unwrap();
        assert_eq!(eigen.len(), 1);
        assert_eq!(eigen[0].2.len(), 1);

        assert!(diagonalize(&jordan, &ctx).is_err());
    }

    /// A 3×3 matrix with rational entries that is similar to an integer diagonal matrix. The
    /// test only requires that eigenvalue computation succeeds.
    #[test]
    fn similar_rational_three_by_three_eigenvals() {
        // Same shape as SymPy random seed-17 trial — rational similar to an integer diagonal.
        let ctx = pool();
        let q = |num: i64, den: i64| ctx.rational(num, den);
        let mat = Matrix::new(vec![
            vec![ctx.integer(2), q(-18, 7), q(-6, 7)],
            vec![ctx.integer(0), q(12, 7), q(32, 7)],
            vec![ctx.integer(-2), q(6, 7), q(-26, 7)],
        ])
        .unwrap();
        eigenvalues(&mat, &ctx).unwrap();
    }

    /// The 90° rotation `[[0, −1], [1, 0]]` has the roots of λ² + 1 as eigenvalues, each with a
    /// one-dimensional eigenspace. The resulting `(P, D)` must satisfy `M·P = P·D`.
    #[test]
    fn rotation_imag_roots() {
        let ctx = pool();
        let zero = ctx.integer(0_i32);
        let one = ctx.integer(1_i32);
        let minus_one = ctx.integer(-1_i32);
        let rotation = Matrix::new(vec![vec![zero, minus_one], vec![one, zero]]).unwrap();

        let spectrum = eigenvalues(&rotation, &ctx).unwrap();
        assert_eq!(spectrum.len(), 2);

        let eigen = eigenvectors(&rotation, &ctx).unwrap();
        assert_eq!(eigen.len(), 2);
        assert_eq!(eigen[0].2.len(), 1);
        assert_eq!(eigen[1].2.len(), 1);

        let (modal, diag) = diagonalize(&rotation, &ctx).unwrap();
        let lhs = rotation.mul(&modal, &ctx).unwrap().simplify_entries(&ctx);
        let rhs = modal.mul(&diag, &ctx).unwrap().simplify_entries(&ctx);
        assert!(matrix_eq_simplified(&lhs, &rhs, &ctx));
    }
}