Skip to main content

amari_core/gf2/
matrix.rs

1//! GF(2) matrices with Gaussian elimination, rank, null space, and linear system solving.
2
3use super::scalar::GF2;
4use super::vector::GF2Vector;
5use crate::error::{CoreError, CoreResult};
6use alloc::vec;
7use alloc::vec::Vec;
8use core::fmt;
9
10/// A matrix over GF(2), stored as row vectors.
11///
12/// Supports Gaussian elimination, rank computation, null space extraction,
13/// and matrix-vector multiplication — all via bitwise operations.
14#[derive(Debug, Clone, PartialEq, Eq)]
15pub struct GF2Matrix {
16    rows: Vec<GF2Vector>,
17    nrows: usize,
18    ncols: usize,
19}
20
21impl GF2Matrix {
22    /// Create a zero matrix.
23    #[must_use]
24    pub fn zero(nrows: usize, ncols: usize) -> Self {
25        let rows = (0..nrows).map(|_| GF2Vector::zero(ncols)).collect();
26        Self { rows, nrows, ncols }
27    }
28
29    /// Create an identity matrix.
30    #[must_use]
31    pub fn identity(n: usize) -> Self {
32        let mut m = Self::zero(n, n);
33        for i in 0..n {
34            m.set(i, i, GF2::ONE);
35        }
36        m
37    }
38
39    /// Create from row vectors. All rows must have the same dimension.
40    #[must_use]
41    pub fn from_rows(rows: Vec<GF2Vector>) -> Self {
42        let nrows = rows.len();
43        let ncols = if nrows > 0 { rows[0].dim() } else { 0 };
44        debug_assert!(rows.iter().all(|r| r.dim() == ncols));
45        Self { rows, nrows, ncols }
46    }
47
48    /// Number of rows.
49    #[inline]
50    #[must_use]
51    pub fn nrows(&self) -> usize {
52        self.nrows
53    }
54
55    /// Number of columns.
56    #[inline]
57    #[must_use]
58    pub fn ncols(&self) -> usize {
59        self.ncols
60    }
61
62    /// Get element at (row, col).
63    #[inline]
64    #[must_use]
65    pub fn get(&self, row: usize, col: usize) -> GF2 {
66        self.rows[row].get(col)
67    }
68
69    /// Set element at (row, col).
70    #[inline]
71    pub fn set(&mut self, row: usize, col: usize, value: GF2) {
72        self.rows[row].set(col, value);
73    }
74
75    /// Get a reference to row i.
76    #[must_use]
77    pub fn row(&self, i: usize) -> &GF2Vector {
78        &self.rows[i]
79    }
80
81    /// Matrix-vector product over GF(2).
82    #[must_use]
83    pub fn mul_vec(&self, v: &GF2Vector) -> GF2Vector {
84        assert_eq!(self.ncols, v.dim(), "dimension mismatch");
85        let bits: Vec<u8> = self.rows.iter().map(|row| row.dot(v).value()).collect();
86        GF2Vector::from_bits(&bits)
87    }
88
89    /// Matrix-matrix product over GF(2).
90    #[must_use]
91    pub fn mul_mat(&self, other: &Self) -> Self {
92        assert_eq!(self.ncols, other.nrows, "dimension mismatch");
93        let other_t = other.transpose();
94        let rows: Vec<GF2Vector> = self
95            .rows
96            .iter()
97            .map(|row| {
98                let bits: Vec<u8> = other_t
99                    .rows
100                    .iter()
101                    .map(|col| row.dot(col).value())
102                    .collect();
103                GF2Vector::from_bits(&bits)
104            })
105            .collect();
106        Self::from_rows(rows)
107    }
108
109    /// Transpose.
110    #[must_use]
111    pub fn transpose(&self) -> Self {
112        let mut t = Self::zero(self.ncols, self.nrows);
113        for i in 0..self.nrows {
114            for j in 0..self.ncols {
115                t.set(j, i, self.get(i, j));
116            }
117        }
118        t
119    }
120
121    /// Reduced row echelon form (in-place). Returns pivot column indices.
122    pub fn reduced_row_echelon(&mut self) -> Vec<usize> {
123        let mut pivots = Vec::new();
124        let mut pivot_row = 0;
125
126        for col in 0..self.ncols {
127            // Find a row with a 1 in this column at or below pivot_row.
128            let found = (pivot_row..self.nrows).find(|&r| self.get(r, col).is_one());
129
130            if let Some(swap_row) = found {
131                self.rows.swap(pivot_row, swap_row);
132
133                // Eliminate all other rows with a 1 in this column.
134                for r in 0..self.nrows {
135                    if r != pivot_row && self.get(r, col).is_one() {
136                        let pivot = self.rows[pivot_row].clone();
137                        self.rows[r] = self.rows[r].add(&pivot);
138                    }
139                }
140
141                pivots.push(col);
142                pivot_row += 1;
143            }
144        }
145        pivots
146    }
147
148    /// Row echelon form (in-place). Returns pivot column indices.
149    ///
150    /// Over GF(2), this produces the same result as `reduced_row_echelon` since
151    /// the elimination above and below is equivalent when the only nonzero scalar is 1.
152    pub fn row_echelon(&mut self) -> Vec<usize> {
153        self.reduced_row_echelon()
154    }
155
156    /// Rank = number of pivots.
157    #[must_use]
158    pub fn rank(&self) -> usize {
159        let mut copy = self.clone();
160        copy.reduced_row_echelon().len()
161    }
162
163    /// Null space basis vectors (kernel of the matrix).
164    #[must_use]
165    pub fn null_space(&self) -> Vec<GF2Vector> {
166        let mut rref = self.clone();
167        let pivots = rref.reduced_row_echelon();
168
169        let pivot_set: Vec<bool> = (0..self.ncols).map(|c| pivots.contains(&c)).collect();
170
171        // Map pivot columns to their row index.
172        let mut pivot_row_for_col = vec![usize::MAX; self.ncols];
173        for (row, &col) in pivots.iter().enumerate() {
174            pivot_row_for_col[col] = row;
175        }
176
177        let free_cols: Vec<usize> = (0..self.ncols).filter(|c| !pivot_set[*c]).collect();
178
179        let mut basis = Vec::new();
180        for &fc in &free_cols {
181            let mut v = GF2Vector::zero(self.ncols);
182            v.set(fc, GF2::ONE);
183            // For each pivot column, read the entry in the RREF at (pivot_row, fc).
184            for &pc in &pivots {
185                let pr = pivot_row_for_col[pc];
186                v.set(pc, rref.get(pr, fc));
187            }
188            basis.push(v);
189        }
190        basis
191    }
192
193    /// Determinant (only for square matrices).
194    pub fn determinant(&self) -> CoreResult<GF2> {
195        if self.nrows != self.ncols {
196            return Err(CoreError::GF2NotSquare {
197                rows: self.nrows,
198                cols: self.ncols,
199            });
200        }
201        let r = self.rank();
202        Ok(if r == self.nrows { GF2::ONE } else { GF2::ZERO })
203    }
204
205    /// Column space basis vectors (image of the matrix).
206    #[must_use]
207    pub fn column_space(&self) -> Vec<GF2Vector> {
208        let t = self.transpose();
209        let mut rref = t.clone();
210        let pivots = rref.reduced_row_echelon();
211        pivots.iter().map(|&c| t.row(c).clone()).collect()
212    }
213
214    /// Check if a vector is in the column space.
215    #[must_use]
216    pub fn in_column_space(&self, v: &GF2Vector) -> bool {
217        self.solve(v).is_some()
218    }
219
220    /// Solve Ax = b over GF(2). Returns None if no solution exists.
221    #[must_use]
222    pub fn solve(&self, b: &GF2Vector) -> Option<GF2Vector> {
223        assert_eq!(self.nrows, b.dim(), "dimension mismatch");
224        let mut aug = self.augment(b);
225        let pivots = aug.reduced_row_echelon();
226
227        // Check for inconsistency: pivot in the augmented column (last column).
228        let aug_col = self.ncols;
229        if pivots.contains(&aug_col) {
230            return None;
231        }
232
233        // Extract solution: for each pivot column, read the value from the augmented column.
234        let mut x = GF2Vector::zero(self.ncols);
235        for (row, &col) in pivots.iter().enumerate() {
236            x.set(col, aug.get(row, aug_col));
237        }
238        Some(x)
239    }
240
241    /// Augmented matrix [A | b].
242    #[must_use]
243    pub fn augment(&self, b: &GF2Vector) -> Self {
244        assert_eq!(self.nrows, b.dim(), "dimension mismatch");
245        let new_ncols = self.ncols + 1;
246        let rows: Vec<GF2Vector> = self
247            .rows
248            .iter()
249            .enumerate()
250            .map(|(i, row)| {
251                let mut new_row = GF2Vector::zero(new_ncols);
252                for j in 0..self.ncols {
253                    new_row.set(j, row.get(j));
254                }
255                new_row.set(self.ncols, b.get(i));
256                new_row
257            })
258            .collect();
259        Self {
260            rows,
261            nrows: self.nrows,
262            ncols: new_ncols,
263        }
264    }
265
266    /// Horizontal concatenation [A | B].
267    #[must_use]
268    pub fn hcat(&self, other: &Self) -> Self {
269        assert_eq!(self.nrows, other.nrows, "row count mismatch");
270        let new_ncols = self.ncols + other.ncols;
271        let rows: Vec<GF2Vector> = self
272            .rows
273            .iter()
274            .zip(other.rows.iter())
275            .map(|(a, b)| {
276                let mut new_row = GF2Vector::zero(new_ncols);
277                for j in 0..self.ncols {
278                    new_row.set(j, a.get(j));
279                }
280                for j in 0..other.ncols {
281                    new_row.set(self.ncols + j, b.get(j));
282                }
283                new_row
284            })
285            .collect();
286        Self {
287            rows,
288            nrows: self.nrows,
289            ncols: new_ncols,
290        }
291    }
292}
293
294impl fmt::Display for GF2Matrix {
295    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
296        for (i, row) in self.rows.iter().enumerate() {
297            if i > 0 {
298                writeln!(f)?;
299            }
300            write!(f, "{}", row)?;
301        }
302        Ok(())
303    }
304}
305
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_identity_properties() {
        let id = GF2Matrix::identity(3);
        assert_eq!(id.rank(), 3);
        assert_eq!(id.determinant().unwrap(), GF2::ONE);

        let v = GF2Vector::from_bits(&[1, 0, 1]);
        assert_eq!(id.mul_vec(&v), v);
    }

    #[test]
    fn test_matrix_vector_product() {
        // [[1,0,1],[0,1,1]] * [1,1,0] = [1, 1]
        let m = GF2Matrix::from_rows(vec![
            GF2Vector::from_bits(&[1, 0, 1]),
            GF2Vector::from_bits(&[0, 1, 1]),
        ]);
        let v = GF2Vector::from_bits(&[1, 1, 0]);
        let result = m.mul_vec(&v);
        assert_eq!(result, GF2Vector::from_bits(&[1, 1]));
    }

    #[test]
    fn test_row_echelon_and_rank() {
        // Third row is the GF(2) sum of the first two, so the rank is 2 and the
        // pivots land in columns 0 and 1.
        let mut m = GF2Matrix::from_rows(vec![
            GF2Vector::from_bits(&[1, 0, 1, 0]),
            GF2Vector::from_bits(&[0, 1, 1, 0]),
            GF2Vector::from_bits(&[1, 1, 0, 0]),
        ]);
        let pivots = m.reduced_row_echelon();
        assert_eq!(pivots, vec![0, 1]);
        // The dependent row is eliminated to zero.
        assert!(m.row(2).is_zero());
    }

    #[test]
    fn test_full_rank() {
        let m = GF2Matrix::from_rows(vec![
            GF2Vector::from_bits(&[1, 0, 0]),
            GF2Vector::from_bits(&[0, 1, 0]),
            GF2Vector::from_bits(&[0, 0, 1]),
        ]);
        assert_eq!(m.rank(), 3);
        assert_eq!(m.determinant().unwrap(), GF2::ONE);
    }

    #[test]
    fn test_rank_deficient() {
        // Row 2 = row 0 + row 1 over GF(2).
        let m = GF2Matrix::from_rows(vec![
            GF2Vector::from_bits(&[1, 1, 0]),
            GF2Vector::from_bits(&[0, 0, 1]),
            GF2Vector::from_bits(&[1, 1, 1]),
        ]);
        assert_eq!(m.rank(), 2);
        assert_eq!(m.determinant().unwrap(), GF2::ZERO);
    }

    #[test]
    fn test_null_space() {
        // [[1,0,1],[0,1,1]]: the null space is spanned by [1,1,1].
        let m = GF2Matrix::from_rows(vec![
            GF2Vector::from_bits(&[1, 0, 1]),
            GF2Vector::from_bits(&[0, 1, 1]),
        ]);
        let ns = m.null_space();
        assert_eq!(ns.len(), 1);
        assert_eq!(ns[0], GF2Vector::from_bits(&[1, 1, 1]));
        // Verify Ax = 0 for each null space vector.
        for v in &ns {
            let product = m.mul_vec(v);
            assert!(product.is_zero(), "null space vector not in kernel");
        }
    }

    #[test]
    fn test_zero_matrix_edge_cases() {
        // Every column of a zero matrix is free.
        let z = GF2Matrix::zero(3, 4);
        assert_eq!(z.rank(), 0);
        let ns = z.null_space();
        assert_eq!(ns.len(), 4);
        for v in &ns {
            assert!(z.mul_vec(v).is_zero());
        }
        // The empty 0x0 matrix is (vacuously) invertible.
        assert_eq!(GF2Matrix::zero(0, 0).determinant().unwrap(), GF2::ONE);
    }

    #[test]
    fn test_determinant_non_square() {
        let m = GF2Matrix::zero(2, 3);
        assert!(m.determinant().is_err());
    }

    #[test]
    fn test_solve() {
        // A = [[1,0],[0,1]], b = [1,1] => x = [1,1]
        let a = GF2Matrix::identity(2);
        let b = GF2Vector::from_bits(&[1, 1]);
        let x = a.solve(&b).unwrap();
        assert_eq!(x, b);
        assert_eq!(a.mul_vec(&x), b);
        assert!(a.in_column_space(&b));
    }

    #[test]
    fn test_solve_inconsistent() {
        // A = [[1,0],[1,0]] means both equations read x0 = b_i, so b = [0,1]
        // demands x0 = 0 and x0 = 1 at once: no solution.
        let a = GF2Matrix::from_rows(vec![
            GF2Vector::from_bits(&[1, 0]),
            GF2Vector::from_bits(&[1, 0]),
        ]);
        let b = GF2Vector::from_bits(&[0, 1]);
        assert!(a.solve(&b).is_none());
        assert!(!a.in_column_space(&b));
    }

    #[test]
    fn test_transpose_roundtrip() {
        let m = GF2Matrix::from_rows(vec![
            GF2Vector::from_bits(&[1, 0, 1]),
            GF2Vector::from_bits(&[0, 1, 0]),
        ]);
        let tt = m.transpose().transpose();
        assert_eq!(m, tt);
    }

    #[test]
    fn test_matrix_product() {
        let a = GF2Matrix::identity(3);
        let b = GF2Matrix::from_rows(vec![
            GF2Vector::from_bits(&[1, 1, 0]),
            GF2Vector::from_bits(&[0, 1, 1]),
            GF2Vector::from_bits(&[1, 0, 1]),
        ]);
        // The identity is neutral on both sides.
        assert_eq!(a.mul_mat(&b), b);
        assert_eq!(b.mul_mat(&a), b);
    }

    #[test]
    fn test_augment_and_hcat() {
        let a = GF2Matrix::identity(2);
        let b = GF2Vector::from_bits(&[1, 0]);

        let aug = a.augment(&b);
        assert_eq!(aug.nrows(), 2);
        assert_eq!(aug.ncols(), 3);
        assert_eq!(*aug.row(0), GF2Vector::from_bits(&[1, 0, 1]));
        assert_eq!(*aug.row(1), GF2Vector::from_bits(&[0, 1, 0]));

        let cat = a.hcat(&a);
        assert_eq!(cat.ncols(), 4);
        assert_eq!(*cat.row(0), GF2Vector::from_bits(&[1, 0, 1, 0]));
        assert_eq!(*cat.row(1), GF2Vector::from_bits(&[0, 1, 0, 1]));
    }
}