Skip to main content

alkahest_cas/matrix/
normal_form.rs

1//! V2-5 — Hermite and Smith normal forms for dense integer matrices (`IntegerMatrix`)
2//! and polynomial matrices over ℚ (`PolyMatrixQ` / `RatUniPoly`).
3//!
4//! Integer Hermite form uses FLINT `fmpz_mat_hnf_transform` (Storjohann-class implementations
5//! inside FLINT). Integer Smith form follows SymPy `smith_normal_decomp` (`U * M * V = S`).
6//! Polynomial Hermite / Smith use the same column-elimination pattern over the Euclidean
7//! domain `ℚ[x]`.
8
9#![allow(
10    clippy::needless_range_loop,
11    clippy::cmp_owned,
12    clippy::unnecessary_min_or_max
13)]
14
15use super::smith;
16use super::smith_poly;
17
18use crate::errors::AlkahestError;
19use crate::flint::integer::FlintInteger;
20use crate::flint::mat::FlintMat;
21use rug::{Integer, Rational};
22use std::fmt;
23use std::ops::Mul;
24
25// ---------------------------------------------------------------------------
26// Errors
27// ---------------------------------------------------------------------------
28
/// Failure modes when building or multiplying the dense matrices in this module.
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum NormalFormError {
    /// A nested-row initializer was ragged: row `row` held `got` entries while
    /// `expected_cols` (the width of the first row) was required.
    DimensionMismatch {
        row: usize,
        expected_cols: usize,
        got: usize,
    },
    /// A product `A * B` was requested although `A.cols != B.rows`.
    IncompatibleMultiply { left_cols: usize, right_rows: usize },
}

impl fmt::Display for NormalFormError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // All payload fields are `usize`, so destructure by value.
        match *self {
            Self::DimensionMismatch {
                row,
                expected_cols,
                got,
            } => {
                write!(f, "row {row} has {got} columns, expected {expected_cols}")
            }
            Self::IncompatibleMultiply {
                left_cols,
                right_rows,
            } => {
                write!(
                    f,
                    "cannot multiply {left_cols}-wide matrix by matrix with {right_rows} rows"
                )
            }
        }
    }
}

/// `Debug` + `Display` already provide everything `std::error::Error` requires.
impl std::error::Error for NormalFormError {}
62
63impl AlkahestError for NormalFormError {
64    fn code(&self) -> &'static str {
65        match self {
66            NormalFormError::DimensionMismatch { .. } => "E-NFM-001",
67            NormalFormError::IncompatibleMultiply { .. } => "E-NFM-002",
68        }
69    }
70
71    fn remediation(&self) -> Option<&'static str> {
72        match self {
73            NormalFormError::DimensionMismatch { .. } => {
74                Some("every row in `IntegerMatrix::from_nested` must have equal width")
75            }
76            NormalFormError::IncompatibleMultiply { .. } => {
77                Some("for `A * B`, use matrices where `A.cols == B.rows`")
78            }
79        }
80    }
81}
82
83// ---------------------------------------------------------------------------
84// Integer matrices
85// ---------------------------------------------------------------------------
86
87/// Dense `m × n` matrix over ℤ (row-major).
88#[derive(Clone, Debug, PartialEq, Eq)]
89pub struct IntegerMatrix {
90    pub rows: usize,
91    pub cols: usize,
92    data: Vec<Integer>,
93}
94
95impl IntegerMatrix {
96    /// Build from nested rows of `i64` (must be rectangular).
97    pub fn from_nested(rows: Vec<Vec<i64>>) -> Result<Self, NormalFormError> {
98        if rows.is_empty() {
99            return Ok(Self {
100                rows: 0,
101                cols: 0,
102                data: vec![],
103            });
104        }
105        let cols = rows[0].len();
106        let mut data = Vec::with_capacity(rows.len() * cols);
107        for (ri, r) in rows.iter().enumerate() {
108            if r.len() != cols {
109                return Err(NormalFormError::DimensionMismatch {
110                    row: ri,
111                    expected_cols: cols,
112                    got: r.len(),
113                });
114            }
115            for &x in r {
116                data.push(Integer::from(x));
117            }
118        }
119        Ok(Self {
120            rows: rows.len(),
121            cols,
122            data,
123        })
124    }
125
126    fn from_rug_rows(rows: Vec<Vec<Integer>>) -> Result<Self, NormalFormError> {
127        if rows.is_empty() {
128            return Ok(Self {
129                rows: 0,
130                cols: 0,
131                data: vec![],
132            });
133        }
134        let cols = rows[0].len();
135        let mut data = Vec::with_capacity(rows.len() * cols);
136        for (ri, r) in rows.iter().enumerate() {
137            if r.len() != cols {
138                return Err(NormalFormError::DimensionMismatch {
139                    row: ri,
140                    expected_cols: cols,
141                    got: r.len(),
142                });
143            }
144            for x in r {
145                data.push(x.clone());
146            }
147        }
148        Ok(Self {
149            rows: rows.len(),
150            cols,
151            data,
152        })
153    }
154
155    #[inline]
156    pub fn get(&self, r: usize, c: usize) -> &Integer {
157        &self.data[r * self.cols + c]
158    }
159
160    /// Matrix product `self * other`.
161    pub fn mul(&self, other: &IntegerMatrix) -> Result<Self, NormalFormError> {
162        if self.cols != other.rows {
163            return Err(NormalFormError::IncompatibleMultiply {
164                left_cols: self.cols,
165                right_rows: other.rows,
166            });
167        }
168        let m = self.rows;
169        let n = other.cols;
170        let k = self.cols;
171        let mut out = vec![Integer::from(0); m * n];
172        for i in 0..m {
173            for j in 0..n {
174                let mut acc = Integer::from(0);
175                for t in 0..k {
176                    acc += self.get(i, t) * other.get(t, j);
177                }
178                out[i * n + j] = acc;
179            }
180        }
181        Ok(IntegerMatrix {
182            rows: m,
183            cols: n,
184            data: out,
185        })
186    }
187
188    fn to_flint(&self) -> FlintMat {
189        let mut a = FlintMat::new(self.rows, self.cols);
190        for i in 0..self.rows {
191            for j in 0..self.cols {
192                let fi = FlintInteger::from_rug(self.get(i, j));
193                a.set_entry(i, j, &fi);
194            }
195        }
196        a
197    }
198
199    fn from_flint(m: &FlintMat) -> Self {
200        let rows = m.rows();
201        let cols = m.cols();
202        let mut data = Vec::with_capacity(rows * cols);
203        for i in 0..rows {
204            for j in 0..cols {
205                data.push(m.get_flint(i, j).to_rug());
206            }
207        }
208        Self { rows, cols, data }
209    }
210
211    fn to_nested_integer(&self) -> Vec<Vec<Integer>> {
212        (0..self.rows)
213            .map(|i| (0..self.cols).map(|j| self.get(i, j).clone()).collect())
214            .collect()
215    }
216}
217
218/// Hermite normal form: returns `(H, U)` with `U * M = H`, where `U` is unimodular.
219/// Uses FLINT `fmpz_mat_hnf_transform`.
220pub fn hermite_form(m: &IntegerMatrix) -> (IntegerMatrix, IntegerMatrix) {
221    if m.rows == 0 || m.cols == 0 {
222        return (
223            IntegerMatrix {
224                rows: m.rows,
225                cols: m.cols,
226                data: vec![],
227            },
228            IntegerMatrix::identity(m.rows),
229        );
230    }
231    let a = m.to_flint();
232    let mut h = FlintMat::new(m.rows, m.cols);
233    let mut u = FlintMat::new(m.rows, m.rows);
234    a.hnf_transform(&mut h, &mut u);
235    (IntegerMatrix::from_flint(&h), IntegerMatrix::from_flint(&u))
236}
237
238impl IntegerMatrix {
239    fn identity(n: usize) -> Self {
240        let mut data = vec![Integer::from(0); n * n];
241        for i in 0..n {
242            data[i * n + i] = Integer::from(1);
243        }
244        Self {
245            rows: n,
246            cols: n,
247            data,
248        }
249    }
250}
251
252/// Smith normal form: `(S, U, V)` with `S == U * M * V`, `S` rectangular-diagonal, invariant
253/// factors dividing along the diagonal.
254pub fn smith_form(
255    m: &IntegerMatrix,
256) -> Result<(IntegerMatrix, IntegerMatrix, IntegerMatrix), NormalFormError> {
257    if m.rows == 0 || m.cols == 0 {
258        return Ok((
259            IntegerMatrix {
260                rows: m.rows,
261                cols: m.cols,
262                data: vec![],
263            },
264            IntegerMatrix::identity(m.rows),
265            IntegerMatrix::identity(m.cols),
266        ));
267    }
268    let (s, u, v) = smith::smith_normal_decomp(m.to_nested_integer());
269    Ok((
270        IntegerMatrix::from_rug_rows(s)?,
271        IntegerMatrix::from_rug_rows(u)?,
272        IntegerMatrix::from_rug_rows(v)?,
273    ))
274}
275
276// ---------------------------------------------------------------------------
277// ℚ[x] polynomials (dense, ascending degree)
278// ---------------------------------------------------------------------------
279
280/// Univariate polynomial over ℚ, `∑ cᵢ xⁱ`.
281#[derive(Clone, Debug)]
282pub struct RatUniPoly {
283    /// Ascending coefficients; trailing zeros are stripped.
284    pub coeffs: Vec<Rational>,
285}
286
287impl PartialEq for RatUniPoly {
288    fn eq(&self, other: &Self) -> bool {
289        self.coeffs == other.coeffs
290    }
291}
292
293impl Eq for RatUniPoly {}
294
295impl RatUniPoly {
296    pub fn zero() -> Self {
297        Self { coeffs: vec![] }
298    }
299
300    pub fn one() -> Self {
301        Self {
302            coeffs: vec![Rational::from(1)],
303        }
304    }
305
306    pub fn constant(c: Rational) -> Self {
307        if c == Rational::from(0) {
308            Self::zero()
309        } else {
310            Self { coeffs: vec![c] }
311        }
312    }
313
314    /// The polynomial `x`.
315    pub fn x() -> Self {
316        Self {
317            coeffs: vec![Rational::from(0), Rational::from(1)],
318        }
319    }
320
321    pub(crate) fn trim(mut self) -> Self {
322        while self.coeffs.last() == Some(&Rational::from(0)) {
323            self.coeffs.pop();
324        }
325        self
326    }
327
328    pub fn degree(&self) -> i32 {
329        self.coeffs.len() as i32 - 1
330    }
331
332    pub fn is_zero(&self) -> bool {
333        self.coeffs.is_empty()
334    }
335
336    pub(crate) fn leading_coeff(&self) -> Rational {
337        self.coeffs
338            .last()
339            .cloned()
340            .unwrap_or_else(|| Rational::from(0))
341    }
342
343    /// Euclidean division: `a = q * b + r`, `deg r < deg b` (or `r = 0`).
344    pub fn div_rem(a: &Self, b: &Self) -> (Self, Self) {
345        assert!(!b.is_zero());
346        let mut a = a.clone();
347        let mut a_c = std::mem::take(&mut a.coeffs);
348        let b = b.clone().trim();
349        let b_c = &b.coeffs;
350        let db = b_c.len() as i32 - 1;
351        let lb = b_c[b_c.len() - 1].clone();
352
353        let mut q = vec![Rational::from(0); (a_c.len().saturating_sub(b_c.len()) + 1).max(0)];
354
355        while a_c.len() as i32 > db && a_c.last().map(|v| v != &Rational::from(0)).unwrap_or(false)
356        {
357            let da = a_c.len() as i32 - 1;
358            let la = a_c.last().unwrap().clone();
359            let shift = (da - db) as usize;
360            if shift >= q.len() {
361                q.resize(shift + 1, Rational::from(0));
362            }
363            let t = la / &lb;
364            q[shift] += &t;
365            for j in 0..b_c.len() {
366                let i = shift + j;
367                let prod = t.clone() * b_c[j].clone();
368                a_c[i] -= &prod;
369            }
370            while a_c.last() == Some(&Rational::from(0)) {
371                a_c.pop();
372            }
373        }
374
375        let q_poly = RatUniPoly { coeffs: q }.trim();
376        let r_poly = RatUniPoly { coeffs: a_c }.trim();
377        (q_poly, r_poly)
378    }
379
380    pub fn gcd(&self, other: &Self) -> Self {
381        let mut a = self.clone();
382        let mut b = other.clone();
383        if a.degree() < b.degree() {
384            std::mem::swap(&mut a, &mut b);
385        }
386        while !b.is_zero() {
387            let (_, r) = RatUniPoly::div_rem(&a, &b);
388            a = b;
389            b = r;
390        }
391        if a.is_zero() {
392            RatUniPoly::zero()
393        } else {
394            let mut g = a.trim();
395            let lc = g.leading_coeff();
396            for c in &mut g.coeffs {
397                *c /= lc.clone();
398            }
399            g.trim()
400        }
401    }
402
403    pub fn gcdex(a: &Self, b: &Self) -> (Self, Self, Self) {
404        if b.is_zero() {
405            if a.is_zero() {
406                return (Self::zero(), Self::one(), Self::zero());
407            }
408            let mut an = a.clone().trim();
409            let lc = an.leading_coeff();
410            let inv = Rational::from(1) / lc.clone();
411            for c in &mut an.coeffs {
412                *c *= inv.clone();
413            }
414            let an = an.trim();
415            return (Self::constant(inv), Self::zero(), an);
416        }
417        let (q, r) = Self::div_rem(a, b);
418        let (s1, t1, g) = Self::gcdex(b, &r);
419        let qt = &q * &t1;
420        let tt = &s1 - &qt;
421        (t1, tt.trim(), g)
422    }
423
424    pub(super) fn exquo(&self, g: &Self) -> Self {
425        let (q, r) = RatUniPoly::div_rem(self, g);
426        if !r.is_zero() {
427            panic!("RatUniPoly::exquo: not divisible");
428        }
429        q
430    }
431}
432
433impl std::ops::Add for &RatUniPoly {
434    type Output = RatUniPoly;
435    fn add(self, rhs: &RatUniPoly) -> RatUniPoly {
436        let n = self.coeffs.len().max(rhs.coeffs.len());
437        let mut c = vec![Rational::from(0); n];
438        for i in 0..n {
439            if i < self.coeffs.len() {
440                c[i] += self.coeffs[i].clone();
441            }
442            if i < rhs.coeffs.len() {
443                c[i] += rhs.coeffs[i].clone();
444            }
445        }
446        RatUniPoly { coeffs: c }.trim()
447    }
448}
449
450impl std::ops::Sub for &RatUniPoly {
451    type Output = RatUniPoly;
452    fn sub(self, rhs: &RatUniPoly) -> RatUniPoly {
453        let n = self.coeffs.len().max(rhs.coeffs.len());
454        let mut c = vec![Rational::from(0); n];
455        for i in 0..n {
456            if i < self.coeffs.len() {
457                c[i] += self.coeffs[i].clone();
458            }
459            if i < rhs.coeffs.len() {
460                c[i] -= rhs.coeffs[i].clone();
461            }
462        }
463        RatUniPoly { coeffs: c }.trim()
464    }
465}
466
467impl Mul for RatUniPoly {
468    type Output = Self;
469    fn mul(self, rhs: Self) -> Self {
470        (&self).mul(&rhs)
471    }
472}
473
474impl std::ops::Mul for &RatUniPoly {
475    type Output = RatUniPoly;
476    fn mul(self, rhs: &RatUniPoly) -> RatUniPoly {
477        if self.is_zero() || rhs.is_zero() {
478            return RatUniPoly::zero();
479        }
480        let mut c = vec![Rational::from(0); self.coeffs.len() + rhs.coeffs.len() - 1];
481        for (i, a) in self.coeffs.iter().enumerate() {
482            for (j, b) in rhs.coeffs.iter().enumerate() {
483                c[i + j] += a.clone() * b;
484            }
485        }
486        RatUniPoly { coeffs: c }.trim()
487    }
488}
489
490impl std::ops::Neg for &RatUniPoly {
491    type Output = RatUniPoly;
492    fn neg(self) -> RatUniPoly {
493        let coeffs = self.coeffs.iter().map(|c| -c.clone()).collect();
494        RatUniPoly { coeffs }.trim()
495    }
496}
497
498// ---------------------------------------------------------------------------
499// Polynomial matrices over ℚ[x]
500// ---------------------------------------------------------------------------
501
502/// Rectangular matrix whose entries are univariate polynomials over ℚ.
503#[derive(Clone, Debug, PartialEq, Eq)]
504pub struct PolyMatrixQ {
505    pub rows: usize,
506    pub cols: usize,
507    data: Vec<RatUniPoly>,
508}
509
510impl PolyMatrixQ {
511    pub(super) fn shell(rows: usize, cols: usize) -> Self {
512        Self {
513            rows,
514            cols,
515            data: vec![],
516        }
517    }
518
519    pub fn from_nested(rows: Vec<Vec<RatUniPoly>>) -> Result<Self, NormalFormError> {
520        if rows.is_empty() {
521            return Ok(Self {
522                rows: 0,
523                cols: 0,
524                data: vec![],
525            });
526        }
527        let cols = rows[0].len();
528        let mut data = Vec::with_capacity(rows.len() * cols);
529        for (ri, r) in rows.iter().enumerate() {
530            if r.len() != cols {
531                return Err(NormalFormError::DimensionMismatch {
532                    row: ri,
533                    expected_cols: cols,
534                    got: r.len(),
535                });
536            }
537            for p in r {
538                data.push(p.clone());
539            }
540        }
541        Ok(Self {
542            rows: rows.len(),
543            cols,
544            data,
545        })
546    }
547
548    #[inline]
549    pub fn get(&self, r: usize, c: usize) -> &RatUniPoly {
550        &self.data[r * self.cols + c]
551    }
552
553    pub fn mul(&self, other: &PolyMatrixQ) -> Result<Self, NormalFormError> {
554        if self.cols != other.rows {
555            return Err(NormalFormError::IncompatibleMultiply {
556                left_cols: self.cols,
557                right_rows: other.rows,
558            });
559        }
560        let m = self.rows;
561        let n = other.cols;
562        let k = self.cols;
563        let mut out = Vec::with_capacity(m * n);
564        for i in 0..m {
565            for j in 0..n {
566                let mut acc = RatUniPoly::zero();
567                for t in 0..k {
568                    let prod = self.get(i, t).clone() * other.get(t, j).clone();
569                    acc = (&acc + &prod).trim();
570                }
571                out.push(acc);
572            }
573        }
574        Ok(PolyMatrixQ {
575            rows: m,
576            cols: n,
577            data: out,
578        })
579    }
580
581    fn transpose(&self) -> PolyMatrixQ {
582        let mut data = Vec::with_capacity(self.rows * self.cols);
583        for j in 0..self.cols {
584            for i in 0..self.rows {
585                data.push(self.get(i, j).clone());
586            }
587        }
588        PolyMatrixQ {
589            rows: self.cols,
590            cols: self.rows,
591            data,
592        }
593    }
594}
595
596/// Hermite column-form on `Mᵀ`, then transpose — yields `(H, U)` with `U * M = H`
597/// for the row-style convention used by integer matrices.
598pub fn hermite_form_poly(m: &PolyMatrixQ) -> (PolyMatrixQ, PolyMatrixQ) {
599    let mt = m.transpose();
600    let (ht, v) = smith_poly::hermite_column_poly(&mt);
601    (ht.transpose(), v.transpose())
602}
603
604/// Smith normal form over `ℚ[x]`: `(S, U, V)` with `S == U * M * V`.
605pub fn smith_form_poly(m: &PolyMatrixQ) -> (PolyMatrixQ, PolyMatrixQ, PolyMatrixQ) {
606    smith_poly::smith_normal_poly(m)
607}
608
609// ---------------------------------------------------------------------------
610// Tests
611// ---------------------------------------------------------------------------
612
#[cfg(test)]
mod tests {
    use super::*;
    use rug::Complete;

    /// Trimmed `RatUniPoly` from ascending integer coefficients.
    fn poly(coeffs: &[i64]) -> RatUniPoly {
        RatUniPoly {
            coeffs: coeffs.iter().map(|&c| Rational::from(c)).collect(),
        }
        .trim()
    }

    #[test]
    fn hnf_transform_matches_flint_and_um_equals_h() {
        let m = IntegerMatrix::from_nested(vec![vec![12, 6, 4], vec![3, 9, 6], vec![2, 16, 14]])
            .unwrap();
        let (h, u) = hermite_form(&m);
        let um = u.mul(&m).unwrap();
        assert_eq!(um, h);
        let fh = h.to_flint();
        assert!(fh.is_in_hnf());
    }

    #[test]
    fn snf_sympy_example_3x3() {
        let m = IntegerMatrix::from_nested(vec![vec![12, 6, 4], vec![3, 9, 6], vec![2, 16, 14]])
            .unwrap();
        let (s, u, v) = smith_form(&m).unwrap();
        let umv = u.mul(&m).unwrap().mul(&v).unwrap();
        assert_eq!(umv, s);
        assert!(s.to_flint().is_in_snf());
        // invariant divisibility on diagonal
        let d = m.rows.min(m.cols);
        for i in 0..d.saturating_sub(1) {
            let a = s.get(i, i).clone();
            let b = s.get(i + 1, i + 1).clone();
            if a != Integer::from(0) && b != Integer::from(0) {
                let (_, r) = b.div_rem_floor_ref(&a).complete();
                assert_eq!(r, Integer::from(0));
            }
        }
    }

    #[test]
    fn snf_random_small_matches_flint_diagonal() {
        use rug::rand::RandState;
        let mut rand = RandState::new();
        for _ in 0..30 {
            let mut rows = vec![];
            for _ in 0..4 {
                let mut r = vec![];
                for _ in 0..4 {
                    let x: u32 = rand.bits(6);
                    r.push(x as i64);
                }
                rows.push(r);
            }
            let m = IntegerMatrix::from_nested(rows).unwrap();
            let (s, u, v) = smith_form(&m).unwrap();
            let umv = u.mul(&m).unwrap().mul(&v).unwrap();
            assert_eq!(umv, s);
            let fa = m.to_flint();
            let mut fs = FlintMat::new(m.rows, m.cols);
            fa.snf_diagonal(&mut fs);
            assert!(s.to_flint().equals(&fs));
        }
    }

    #[test]
    fn poly_hermite_and_smith_diag_x() {
        let x = RatUniPoly::x();
        let z = RatUniPoly::zero();
        let m =
            PolyMatrixQ::from_nested(vec![vec![x.clone(), z.clone()], vec![z.clone(), x.clone()]])
                .unwrap();
        let (h, u) = hermite_form_poly(&m);
        let um = u.mul(&m).unwrap();
        assert_eq!(um, h);

        let (s, us, vs) = smith_form_poly(&m);
        let prod = us.mul(&m).unwrap().mul(&vs).unwrap();
        assert_eq!(prod, s);
    }

    #[test]
    fn poly_div_rem_reconstructs_dividend() {
        // (x^3 + 2x + 5) divided by (2x - 1).
        let a = poly(&[5, 2, 0, 1]);
        let b = poly(&[-1, 2]);
        let (q, r) = RatUniPoly::div_rem(&a, &b);
        assert!(r.degree() < b.degree());
        assert_eq!(&(&q * &b) + &r, a);
    }

    #[test]
    fn poly_gcd_is_monic_common_factor() {
        // x^2 - 1 = (x - 1)(x + 1); its gcd with 2x - 2 is the monic x - 1.
        let a = poly(&[-1, 0, 1]);
        let b = poly(&[-2, 2]);
        assert_eq!(a.gcd(&b), poly(&[-1, 1]));
    }

    #[test]
    fn poly_gcdex_satisfies_bezout() {
        let cases = vec![
            (poly(&[-1, 0, 1]), poly(&[-2, 2])),
            (poly(&[1, 0, 1]), poly(&[0, 1])),
            (poly(&[3]), RatUniPoly::zero()),
        ];
        for (a, b) in cases {
            let (s, t, g) = RatUniPoly::gcdex(&a, &b);
            assert_eq!(&(&s * &a) + &(&t * &b), g);
            assert_eq!(g, a.gcd(&b));
        }
    }
}