1#![allow(
10 clippy::needless_range_loop,
11 clippy::cmp_owned,
12 clippy::unnecessary_min_or_max
13)]
14
15use super::smith;
16use super::smith_poly;
17
18use crate::errors::AlkahestError;
19use crate::flint::integer::FlintInteger;
20use crate::flint::mat::FlintMat;
21use rug::{Integer, Rational};
22use std::fmt;
23use std::ops::Mul;
24
/// Errors produced while constructing or combining the matrices of this module.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum NormalFormError {
    /// A row handed to a nested-rows constructor has a different length than the first row.
    DimensionMismatch {
        /// Zero-based index of the offending row.
        row: usize,
        /// Column count taken from the first row.
        expected_cols: usize,
        /// Actual length of the offending row.
        got: usize,
    },
    /// A product `A * B` was requested although `A.cols != B.rows`.
    IncompatibleMultiply {
        /// Column count of the left operand.
        left_cols: usize,
        /// Row count of the right operand.
        right_rows: usize,
    },
}
41
42impl fmt::Display for NormalFormError {
43 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
44 match self {
45 NormalFormError::DimensionMismatch {
46 row,
47 expected_cols,
48 got,
49 } => write!(f, "row {row} has {got} columns, expected {expected_cols}",),
50 NormalFormError::IncompatibleMultiply {
51 left_cols,
52 right_rows,
53 } => write!(
54 f,
55 "cannot multiply {left_cols}-wide matrix by matrix with {right_rows} rows",
56 ),
57 }
58 }
59}
60
61impl std::error::Error for NormalFormError {}
62
63impl AlkahestError for NormalFormError {
64 fn code(&self) -> &'static str {
65 match self {
66 NormalFormError::DimensionMismatch { .. } => "E-NFM-001",
67 NormalFormError::IncompatibleMultiply { .. } => "E-NFM-002",
68 }
69 }
70
71 fn remediation(&self) -> Option<&'static str> {
72 match self {
73 NormalFormError::DimensionMismatch { .. } => {
74 Some("every row in `IntegerMatrix::from_nested` must have equal width")
75 }
76 NormalFormError::IncompatibleMultiply { .. } => {
77 Some("for `A * B`, use matrices where `A.cols == B.rows`")
78 }
79 }
80 }
81}
82
83#[derive(Clone, Debug, PartialEq, Eq)]
89pub struct IntegerMatrix {
90 pub rows: usize,
91 pub cols: usize,
92 data: Vec<Integer>,
93}
94
95impl IntegerMatrix {
96 pub fn from_nested(rows: Vec<Vec<i64>>) -> Result<Self, NormalFormError> {
98 if rows.is_empty() {
99 return Ok(Self {
100 rows: 0,
101 cols: 0,
102 data: vec![],
103 });
104 }
105 let cols = rows[0].len();
106 let mut data = Vec::with_capacity(rows.len() * cols);
107 for (ri, r) in rows.iter().enumerate() {
108 if r.len() != cols {
109 return Err(NormalFormError::DimensionMismatch {
110 row: ri,
111 expected_cols: cols,
112 got: r.len(),
113 });
114 }
115 for &x in r {
116 data.push(Integer::from(x));
117 }
118 }
119 Ok(Self {
120 rows: rows.len(),
121 cols,
122 data,
123 })
124 }
125
126 fn from_rug_rows(rows: Vec<Vec<Integer>>) -> Result<Self, NormalFormError> {
127 if rows.is_empty() {
128 return Ok(Self {
129 rows: 0,
130 cols: 0,
131 data: vec![],
132 });
133 }
134 let cols = rows[0].len();
135 let mut data = Vec::with_capacity(rows.len() * cols);
136 for (ri, r) in rows.iter().enumerate() {
137 if r.len() != cols {
138 return Err(NormalFormError::DimensionMismatch {
139 row: ri,
140 expected_cols: cols,
141 got: r.len(),
142 });
143 }
144 for x in r {
145 data.push(x.clone());
146 }
147 }
148 Ok(Self {
149 rows: rows.len(),
150 cols,
151 data,
152 })
153 }
154
155 #[inline]
156 pub fn get(&self, r: usize, c: usize) -> &Integer {
157 &self.data[r * self.cols + c]
158 }
159
160 pub fn mul(&self, other: &IntegerMatrix) -> Result<Self, NormalFormError> {
162 if self.cols != other.rows {
163 return Err(NormalFormError::IncompatibleMultiply {
164 left_cols: self.cols,
165 right_rows: other.rows,
166 });
167 }
168 let m = self.rows;
169 let n = other.cols;
170 let k = self.cols;
171 let mut out = vec![Integer::from(0); m * n];
172 for i in 0..m {
173 for j in 0..n {
174 let mut acc = Integer::from(0);
175 for t in 0..k {
176 acc += self.get(i, t) * other.get(t, j);
177 }
178 out[i * n + j] = acc;
179 }
180 }
181 Ok(IntegerMatrix {
182 rows: m,
183 cols: n,
184 data: out,
185 })
186 }
187
188 fn to_flint(&self) -> FlintMat {
189 let mut a = FlintMat::new(self.rows, self.cols);
190 for i in 0..self.rows {
191 for j in 0..self.cols {
192 let fi = FlintInteger::from_rug(self.get(i, j));
193 a.set_entry(i, j, &fi);
194 }
195 }
196 a
197 }
198
199 fn from_flint(m: &FlintMat) -> Self {
200 let rows = m.rows();
201 let cols = m.cols();
202 let mut data = Vec::with_capacity(rows * cols);
203 for i in 0..rows {
204 for j in 0..cols {
205 data.push(m.get_flint(i, j).to_rug());
206 }
207 }
208 Self { rows, cols, data }
209 }
210
211 fn to_nested_integer(&self) -> Vec<Vec<Integer>> {
212 (0..self.rows)
213 .map(|i| (0..self.cols).map(|j| self.get(i, j).clone()).collect())
214 .collect()
215 }
216}
217
218pub fn hermite_form(m: &IntegerMatrix) -> (IntegerMatrix, IntegerMatrix) {
221 if m.rows == 0 || m.cols == 0 {
222 return (
223 IntegerMatrix {
224 rows: m.rows,
225 cols: m.cols,
226 data: vec![],
227 },
228 IntegerMatrix::identity(m.rows),
229 );
230 }
231 let a = m.to_flint();
232 let mut h = FlintMat::new(m.rows, m.cols);
233 let mut u = FlintMat::new(m.rows, m.rows);
234 a.hnf_transform(&mut h, &mut u);
235 (IntegerMatrix::from_flint(&h), IntegerMatrix::from_flint(&u))
236}
237
238impl IntegerMatrix {
239 fn identity(n: usize) -> Self {
240 let mut data = vec![Integer::from(0); n * n];
241 for i in 0..n {
242 data[i * n + i] = Integer::from(1);
243 }
244 Self {
245 rows: n,
246 cols: n,
247 data,
248 }
249 }
250}
251
252pub fn smith_form(
255 m: &IntegerMatrix,
256) -> Result<(IntegerMatrix, IntegerMatrix, IntegerMatrix), NormalFormError> {
257 if m.rows == 0 || m.cols == 0 {
258 return Ok((
259 IntegerMatrix {
260 rows: m.rows,
261 cols: m.cols,
262 data: vec![],
263 },
264 IntegerMatrix::identity(m.rows),
265 IntegerMatrix::identity(m.cols),
266 ));
267 }
268 let (s, u, v) = smith::smith_normal_decomp(m.to_nested_integer());
269 Ok((
270 IntegerMatrix::from_rug_rows(s)?,
271 IntegerMatrix::from_rug_rows(u)?,
272 IntegerMatrix::from_rug_rows(v)?,
273 ))
274}
275
276#[derive(Clone, Debug)]
282pub struct RatUniPoly {
283 pub coeffs: Vec<Rational>,
285}
286
287impl PartialEq for RatUniPoly {
288 fn eq(&self, other: &Self) -> bool {
289 self.coeffs == other.coeffs
290 }
291}
292
293impl Eq for RatUniPoly {}
294
295impl RatUniPoly {
296 pub fn zero() -> Self {
297 Self { coeffs: vec![] }
298 }
299
300 pub fn one() -> Self {
301 Self {
302 coeffs: vec![Rational::from(1)],
303 }
304 }
305
306 pub fn constant(c: Rational) -> Self {
307 if c == Rational::from(0) {
308 Self::zero()
309 } else {
310 Self { coeffs: vec![c] }
311 }
312 }
313
314 pub fn x() -> Self {
316 Self {
317 coeffs: vec![Rational::from(0), Rational::from(1)],
318 }
319 }
320
321 pub(crate) fn trim(mut self) -> Self {
322 while self.coeffs.last() == Some(&Rational::from(0)) {
323 self.coeffs.pop();
324 }
325 self
326 }
327
328 pub fn degree(&self) -> i32 {
329 self.coeffs.len() as i32 - 1
330 }
331
332 pub fn is_zero(&self) -> bool {
333 self.coeffs.is_empty()
334 }
335
336 pub(crate) fn leading_coeff(&self) -> Rational {
337 self.coeffs
338 .last()
339 .cloned()
340 .unwrap_or_else(|| Rational::from(0))
341 }
342
343 pub fn div_rem(a: &Self, b: &Self) -> (Self, Self) {
345 assert!(!b.is_zero());
346 let mut a = a.clone();
347 let mut a_c = std::mem::take(&mut a.coeffs);
348 let b = b.clone().trim();
349 let b_c = &b.coeffs;
350 let db = b_c.len() as i32 - 1;
351 let lb = b_c[b_c.len() - 1].clone();
352
353 let mut q = vec![Rational::from(0); (a_c.len().saturating_sub(b_c.len()) + 1).max(0)];
354
355 while a_c.len() as i32 > db && a_c.last().map(|v| v != &Rational::from(0)).unwrap_or(false)
356 {
357 let da = a_c.len() as i32 - 1;
358 let la = a_c.last().unwrap().clone();
359 let shift = (da - db) as usize;
360 if shift >= q.len() {
361 q.resize(shift + 1, Rational::from(0));
362 }
363 let t = la / &lb;
364 q[shift] += &t;
365 for j in 0..b_c.len() {
366 let i = shift + j;
367 let prod = t.clone() * b_c[j].clone();
368 a_c[i] -= ∏
369 }
370 while a_c.last() == Some(&Rational::from(0)) {
371 a_c.pop();
372 }
373 }
374
375 let q_poly = RatUniPoly { coeffs: q }.trim();
376 let r_poly = RatUniPoly { coeffs: a_c }.trim();
377 (q_poly, r_poly)
378 }
379
380 pub fn gcd(&self, other: &Self) -> Self {
381 let mut a = self.clone();
382 let mut b = other.clone();
383 if a.degree() < b.degree() {
384 std::mem::swap(&mut a, &mut b);
385 }
386 while !b.is_zero() {
387 let (_, r) = RatUniPoly::div_rem(&a, &b);
388 a = b;
389 b = r;
390 }
391 if a.is_zero() {
392 RatUniPoly::zero()
393 } else {
394 let mut g = a.trim();
395 let lc = g.leading_coeff();
396 for c in &mut g.coeffs {
397 *c /= lc.clone();
398 }
399 g.trim()
400 }
401 }
402
403 pub fn gcdex(a: &Self, b: &Self) -> (Self, Self, Self) {
404 if b.is_zero() {
405 if a.is_zero() {
406 return (Self::zero(), Self::one(), Self::zero());
407 }
408 let mut an = a.clone().trim();
409 let lc = an.leading_coeff();
410 let inv = Rational::from(1) / lc.clone();
411 for c in &mut an.coeffs {
412 *c *= inv.clone();
413 }
414 let an = an.trim();
415 return (Self::constant(inv), Self::zero(), an);
416 }
417 let (q, r) = Self::div_rem(a, b);
418 let (s1, t1, g) = Self::gcdex(b, &r);
419 let qt = &q * &t1;
420 let tt = &s1 - &qt;
421 (t1, tt.trim(), g)
422 }
423
424 pub(super) fn exquo(&self, g: &Self) -> Self {
425 let (q, r) = RatUniPoly::div_rem(self, g);
426 if !r.is_zero() {
427 panic!("RatUniPoly::exquo: not divisible");
428 }
429 q
430 }
431}
432
433impl std::ops::Add for &RatUniPoly {
434 type Output = RatUniPoly;
435 fn add(self, rhs: &RatUniPoly) -> RatUniPoly {
436 let n = self.coeffs.len().max(rhs.coeffs.len());
437 let mut c = vec![Rational::from(0); n];
438 for i in 0..n {
439 if i < self.coeffs.len() {
440 c[i] += self.coeffs[i].clone();
441 }
442 if i < rhs.coeffs.len() {
443 c[i] += rhs.coeffs[i].clone();
444 }
445 }
446 RatUniPoly { coeffs: c }.trim()
447 }
448}
449
450impl std::ops::Sub for &RatUniPoly {
451 type Output = RatUniPoly;
452 fn sub(self, rhs: &RatUniPoly) -> RatUniPoly {
453 let n = self.coeffs.len().max(rhs.coeffs.len());
454 let mut c = vec![Rational::from(0); n];
455 for i in 0..n {
456 if i < self.coeffs.len() {
457 c[i] += self.coeffs[i].clone();
458 }
459 if i < rhs.coeffs.len() {
460 c[i] -= rhs.coeffs[i].clone();
461 }
462 }
463 RatUniPoly { coeffs: c }.trim()
464 }
465}
466
467impl Mul for RatUniPoly {
468 type Output = Self;
469 fn mul(self, rhs: Self) -> Self {
470 (&self).mul(&rhs)
471 }
472}
473
474impl std::ops::Mul for &RatUniPoly {
475 type Output = RatUniPoly;
476 fn mul(self, rhs: &RatUniPoly) -> RatUniPoly {
477 if self.is_zero() || rhs.is_zero() {
478 return RatUniPoly::zero();
479 }
480 let mut c = vec![Rational::from(0); self.coeffs.len() + rhs.coeffs.len() - 1];
481 for (i, a) in self.coeffs.iter().enumerate() {
482 for (j, b) in rhs.coeffs.iter().enumerate() {
483 c[i + j] += a.clone() * b;
484 }
485 }
486 RatUniPoly { coeffs: c }.trim()
487 }
488}
489
490impl std::ops::Neg for &RatUniPoly {
491 type Output = RatUniPoly;
492 fn neg(self) -> RatUniPoly {
493 let coeffs = self.coeffs.iter().map(|c| -c.clone()).collect();
494 RatUniPoly { coeffs }.trim()
495 }
496}
497
498#[derive(Clone, Debug, PartialEq, Eq)]
504pub struct PolyMatrixQ {
505 pub rows: usize,
506 pub cols: usize,
507 data: Vec<RatUniPoly>,
508}
509
510impl PolyMatrixQ {
511 pub(super) fn shell(rows: usize, cols: usize) -> Self {
512 Self {
513 rows,
514 cols,
515 data: vec![],
516 }
517 }
518
519 pub fn from_nested(rows: Vec<Vec<RatUniPoly>>) -> Result<Self, NormalFormError> {
520 if rows.is_empty() {
521 return Ok(Self {
522 rows: 0,
523 cols: 0,
524 data: vec![],
525 });
526 }
527 let cols = rows[0].len();
528 let mut data = Vec::with_capacity(rows.len() * cols);
529 for (ri, r) in rows.iter().enumerate() {
530 if r.len() != cols {
531 return Err(NormalFormError::DimensionMismatch {
532 row: ri,
533 expected_cols: cols,
534 got: r.len(),
535 });
536 }
537 for p in r {
538 data.push(p.clone());
539 }
540 }
541 Ok(Self {
542 rows: rows.len(),
543 cols,
544 data,
545 })
546 }
547
548 #[inline]
549 pub fn get(&self, r: usize, c: usize) -> &RatUniPoly {
550 &self.data[r * self.cols + c]
551 }
552
553 pub fn mul(&self, other: &PolyMatrixQ) -> Result<Self, NormalFormError> {
554 if self.cols != other.rows {
555 return Err(NormalFormError::IncompatibleMultiply {
556 left_cols: self.cols,
557 right_rows: other.rows,
558 });
559 }
560 let m = self.rows;
561 let n = other.cols;
562 let k = self.cols;
563 let mut out = Vec::with_capacity(m * n);
564 for i in 0..m {
565 for j in 0..n {
566 let mut acc = RatUniPoly::zero();
567 for t in 0..k {
568 let prod = self.get(i, t).clone() * other.get(t, j).clone();
569 acc = (&acc + &prod).trim();
570 }
571 out.push(acc);
572 }
573 }
574 Ok(PolyMatrixQ {
575 rows: m,
576 cols: n,
577 data: out,
578 })
579 }
580
581 fn transpose(&self) -> PolyMatrixQ {
582 let mut data = Vec::with_capacity(self.rows * self.cols);
583 for j in 0..self.cols {
584 for i in 0..self.rows {
585 data.push(self.get(i, j).clone());
586 }
587 }
588 PolyMatrixQ {
589 rows: self.cols,
590 cols: self.rows,
591 data,
592 }
593 }
594}
595
596pub fn hermite_form_poly(m: &PolyMatrixQ) -> (PolyMatrixQ, PolyMatrixQ) {
599 let mt = m.transpose();
600 let (ht, v) = smith_poly::hermite_column_poly(&mt);
601 (ht.transpose(), v.transpose())
602}
603
604pub fn smith_form_poly(m: &PolyMatrixQ) -> (PolyMatrixQ, PolyMatrixQ, PolyMatrixQ) {
606 smith_poly::smith_normal_poly(m)
607}
608
// Unit tests for the integer (FLINT-backed) and polynomial normal-form entry points.
#[cfg(test)]
mod tests {
    use super::*;
    // Brings `.complete()` into scope for rug's incomplete `div_rem_floor_ref` result.
    use rug::Complete;

    // HNF: checks the transform identity `U * M == H` and that FLINT accepts `H` as HNF.
    #[test]
    fn hnf_transform_matches_flint_and_um_equals_h() {
        let m = IntegerMatrix::from_nested(vec![vec![12, 6, 4], vec![3, 9, 6], vec![2, 16, 14]])
            .unwrap();
        let (h, u) = hermite_form(&m);
        let um = u.mul(&m).unwrap();
        assert_eq!(um, h);
        let fh = h.to_flint();
        assert!(fh.is_in_hnf());
    }

    // SNF on a fixed 3x3 example (the name suggests it mirrors SymPy's docs example):
    // checks `U * M * V == S`, FLINT's SNF predicate, and the divisibility chain.
    #[test]
    fn snf_sympy_example_3x3() {
        let m = IntegerMatrix::from_nested(vec![vec![12, 6, 4], vec![3, 9, 6], vec![2, 16, 14]])
            .unwrap();
        let (s, u, v) = smith_form(&m).unwrap();
        let umv = u.mul(&m).unwrap().mul(&v).unwrap();
        assert_eq!(umv, s);
        assert!(s.to_flint().is_in_snf());
        let d = m.rows.min(m.cols);
        // Each non-zero diagonal entry must divide the next non-zero one:
        // the floor-division remainder of s[i+1][i+1] by s[i][i] is zero.
        for i in 0..d.saturating_sub(1) {
            let a = s.get(i, i).clone();
            let b = s.get(i + 1, i + 1).clone();
            if a != Integer::from(0) && b != Integer::from(0) {
                let (_, r) = b.div_rem_floor_ref(&a).complete();
                assert_eq!(r, Integer::from(0));
            }
        }
    }

    // 30 random 4x4 matrices with entries drawn from `rand.bits(6)` (values below 64):
    // the transform identity must hold, and our diagonal must equal FLINT's `snf_diagonal`.
    // NOTE(review): `RandState::new()` is presumably default-seeded, making this
    // reproducible; the draw order is part of the test, so keep statement order intact.
    #[test]
    fn snf_random_small_matches_flint_diagonal() {
        use rug::rand::RandState;
        let mut rand = RandState::new();
        for _ in 0..30 {
            let mut rows = vec![];
            for _ in 0..4 {
                let mut r = vec![];
                for _ in 0..4 {
                    let x: u32 = rand.bits(6);
                    r.push(x as i64);
                }
                rows.push(r);
            }
            let m = IntegerMatrix::from_nested(rows).unwrap();
            let (s, u, v) = smith_form(&m).unwrap();
            let umv = u.mul(&m).unwrap().mul(&v).unwrap();
            assert_eq!(umv, s);
            let fa = m.to_flint();
            let mut fs = FlintMat::new(m.rows, m.cols);
            fa.snf_diagonal(&mut fs);
            assert!(s.to_flint().equals(&fs));
        }
    }

    // Polynomial path on diag(x, x): checks `U * M == H` for the Hermite form and
    // `U * M * V == S` for the Smith form.
    #[test]
    fn poly_hermite_and_smith_diag_x() {
        let x = RatUniPoly::x();
        let z = RatUniPoly::zero();
        let m =
            PolyMatrixQ::from_nested(vec![vec![x.clone(), z.clone()], vec![z.clone(), x.clone()]])
                .unwrap();
        let (h, u) = hermite_form_poly(&m);
        let um = u.mul(&m).unwrap();
        assert_eq!(um, h);

        let (s, us, vs) = smith_form_poly(&m);
        let prod = us.mul(&m).unwrap().mul(&vs).unwrap();
        assert_eq!(prod, s);
    }
}
689}