1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
use nalgebra::base::allocator::Allocator;
use nalgebra::{DefaultAllocator, Dim, OMatrix};

use super::elementary::{
    adding_column_matrix, changing_column_sign_matrix, swapping_column_matrix,
};

/// Column-style Hermite normal form of an integer `M x N` matrix.
///
/// Values are produced by [`HNF::new`], which checks before returning that
/// `h == basis * r` holds for the input matrix `basis`.
#[derive(Debug)]
// The type keeps the all-caps abbreviation of "Hermite normal form" as its name.
#[allow(clippy::upper_case_acronyms)]
pub struct HNF<M: Dim, N: Dim>
where
    DefaultAllocator: Allocator<i32, M, N> + Allocator<i32, N, N>,
{
    /// The `M x N` input after all column operations have been applied.
    pub h: OMatrix<i32, M, N>,
    /// The `N x N` matrix accumulating those column operations, so that `h == basis * r`.
    /// NOTE(review): presumably unimodular, since it is a product of elementary column
    /// operations -- confirm against `super::elementary`.
    pub r: OMatrix<i32, N, N>,
}

impl<M: Dim, N: Dim> HNF<M, N>
where
    DefaultAllocator: Allocator<i32, M, N> + Allocator<i32, N, N>,
{
    /// Computes the column-wise Hermite normal form of `basis`.
    ///
    /// Rows are processed from top to bottom. For row `s`, integer column operations are
    /// repeated until `h[(s, s)]` is the only nonzero entry of that row among the columns
    /// `s..N`, it is positive, and every other entry `h[(s, j)]` lies in `0..h[(s, s)]`.
    /// A row whose entries in columns `s..N` are all zero (which includes every row with
    /// `s >= N`) is left untouched.
    ///
    /// Each operation applied to `h` is also accumulated into `r`, so the returned value
    /// satisfies `h == basis * r`.
    ///
    /// # Panics
    ///
    /// Panics if the final consistency check `h == basis * r` fails.
    pub fn new(basis: &OMatrix<i32, M, N>) -> Self {
        let (row_dim, col_dim) = basis.shape_generic();
        let (nrows, ncols) = (row_dim.value(), col_dim.value());
        let mut h = basis.clone();
        let mut r: OMatrix<i32, N, N> = OMatrix::identity_generic(col_dim, col_dim);

        // Process the `s`th row
        for s in 0..nrows {
            loop {
                // Pivot column: the nonzero entry of row `s` in columns `s..` with the
                // smallest absolute value (leftmost on ties). If there is none, the row
                // is already reduced.
                let pivot = match (s..ncols)
                    .filter(|&j| h[(s, j)] != 0)
                    .min_by_key(|&j| h[(s, j)].abs())
                {
                    Some(j) => j,
                    None => break,
                };
                h.swap_columns(s, pivot);
                r *= swapping_column_matrix(col_dim, s, pivot);

                // Guarantee that h[(s, s)] is positive
                if h[(s, s)] < 0 {
                    for i in 0..nrows {
                        h[(i, s)] *= -1;
                    }
                    r *= changing_column_sign_matrix(col_dim, s);
                }
                assert_ne!(h[(s, s)], 0);

                // Column `s` is never modified below, so its diagonal entry can be read once.
                let divisor = h[(s, s)];

                // Subtract multiples of the `s`th column from every other column so that
                // each h[(s, j)] becomes its Euclidean remainder modulo `divisor`.
                let mut updated = false;
                for j in (0..ncols).filter(|&j| j != s) {
                    let k = h[(s, j)].div_euclid(divisor);
                    if k == 0 {
                        continue;
                    }
                    updated = true;
                    // h[(:, j)] -= k * h[(:, s)]
                    for i in 0..nrows {
                        h[(i, j)] -= k * h[(i, s)];
                    }
                    r *= adding_column_matrix(col_dim, s, j, -k);
                }

                // Repeat until a full pass changes no column
                if !updated {
                    break;
                }
            }
        }
        // Sanity check of the invariant maintained above; multiplying by `&r` avoids copying `r`.
        assert_eq!(h, basis * &r);
        Self { h, r }
    }
}

#[cfg(test)]
mod tests {
    use nalgebra::{matrix, SMatrix};
    use rand::prelude::*;
    use rand::rngs::StdRng;
    use rand::SeedableRng;

    use super::HNF;

    /// Fixed inputs compared against their expected Hermite normal forms.
    #[test]
    fn test_hnf_small() {
        // Square 3x3 input
        let square = matrix![
            -1, 0, 0;
            1, 2, 2;
            0, -1, -2;
        ];
        let square_expected = matrix![
            1, 0, 0;
            1, 2, 0;
            0, 0, 1;
        ];
        assert_eq!(HNF::new(&square).h, square_expected);

        // Square 2x2 input
        let planar = matrix![
            20, -6;
            -2, 1;
        ];
        let planar_expected = matrix![2, 0; 1, 4];
        assert_eq!(HNF::new(&planar).h, planar_expected);

        // Wide 3x4 input (more columns than rows)
        let wide = matrix![
            2, 3, 6, 2;
            5, 6, 1, 6;
            8, 3, 1, 1;
        ];
        let wide_expected = matrix![
            1, 0, 0, 0;
            0, 1, 0, 0;
            0, 0, 1, 0;
        ];
        assert_eq!(HNF::new(&wide).h, wide_expected);
    }

    /// Smoke test on seeded random matrices with entries drawn from `-4..4`; it passes as
    /// long as `HNF::new` does not panic, i.e. its internal assertions hold.
    #[test]
    fn test_hnf_random() {
        let mut rng = StdRng::from_seed([0; 32]);

        for _trial in 0..256 {
            // Square
            let square = SMatrix::<i32, 3, 3>::from_fn(|_, _| rng.gen_range(-4..4));
            let _ = HNF::new(&square);

            // More columns than rows
            let wide = SMatrix::<i32, 5, 7>::from_fn(|_, _| rng.gen_range(-4..4));
            let _ = HNF::new(&wide);

            // More rows than columns
            let tall = SMatrix::<i32, 7, 5>::from_fn(|_, _| rng.gen_range(-4..4));
            let _ = HNF::new(&tall);
        }
    }
}