1use super::scalar::GF2;
4use super::vector::GF2Vector;
5use crate::error::{CoreError, CoreResult};
6use alloc::vec;
7use alloc::vec::Vec;
8use core::fmt;
9
/// A matrix over GF(2), stored as a list of row vectors.
///
/// Intended invariant: `rows.len() == nrows` and every row has
/// `dim() == ncols`; the methods index rows and columns assuming it holds.
///
/// The derived `PartialEq` compares the stored dimensions as well as the
/// rows, so two row-less matrices with different `ncols` are not equal.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GF2Matrix {
    /// Row vectors, top to bottom.
    rows: Vec<GF2Vector>,
    /// Number of rows (kept equal to `rows.len()`).
    nrows: usize,
    /// Number of columns. Stored explicitly so a matrix with no rows
    /// (e.g. `GF2Matrix::zero(0, n)`) still records its width.
    ncols: usize,
}
20
21impl GF2Matrix {
22 #[must_use]
24 pub fn zero(nrows: usize, ncols: usize) -> Self {
25 let rows = (0..nrows).map(|_| GF2Vector::zero(ncols)).collect();
26 Self { rows, nrows, ncols }
27 }
28
29 #[must_use]
31 pub fn identity(n: usize) -> Self {
32 let mut m = Self::zero(n, n);
33 for i in 0..n {
34 m.set(i, i, GF2::ONE);
35 }
36 m
37 }
38
39 #[must_use]
41 pub fn from_rows(rows: Vec<GF2Vector>) -> Self {
42 let nrows = rows.len();
43 let ncols = if nrows > 0 { rows[0].dim() } else { 0 };
44 debug_assert!(rows.iter().all(|r| r.dim() == ncols));
45 Self { rows, nrows, ncols }
46 }
47
48 #[inline]
50 #[must_use]
51 pub fn nrows(&self) -> usize {
52 self.nrows
53 }
54
55 #[inline]
57 #[must_use]
58 pub fn ncols(&self) -> usize {
59 self.ncols
60 }
61
62 #[inline]
64 #[must_use]
65 pub fn get(&self, row: usize, col: usize) -> GF2 {
66 self.rows[row].get(col)
67 }
68
69 #[inline]
71 pub fn set(&mut self, row: usize, col: usize, value: GF2) {
72 self.rows[row].set(col, value);
73 }
74
75 #[must_use]
77 pub fn row(&self, i: usize) -> &GF2Vector {
78 &self.rows[i]
79 }
80
81 #[must_use]
83 pub fn mul_vec(&self, v: &GF2Vector) -> GF2Vector {
84 assert_eq!(self.ncols, v.dim(), "dimension mismatch");
85 let bits: Vec<u8> = self.rows.iter().map(|row| row.dot(v).value()).collect();
86 GF2Vector::from_bits(&bits)
87 }
88
89 #[must_use]
91 pub fn mul_mat(&self, other: &Self) -> Self {
92 assert_eq!(self.ncols, other.nrows, "dimension mismatch");
93 let other_t = other.transpose();
94 let rows: Vec<GF2Vector> = self
95 .rows
96 .iter()
97 .map(|row| {
98 let bits: Vec<u8> = other_t
99 .rows
100 .iter()
101 .map(|col| row.dot(col).value())
102 .collect();
103 GF2Vector::from_bits(&bits)
104 })
105 .collect();
106 Self::from_rows(rows)
107 }
108
109 #[must_use]
111 pub fn transpose(&self) -> Self {
112 let mut t = Self::zero(self.ncols, self.nrows);
113 for i in 0..self.nrows {
114 for j in 0..self.ncols {
115 t.set(j, i, self.get(i, j));
116 }
117 }
118 t
119 }
120
121 pub fn reduced_row_echelon(&mut self) -> Vec<usize> {
123 let mut pivots = Vec::new();
124 let mut pivot_row = 0;
125
126 for col in 0..self.ncols {
127 let found = (pivot_row..self.nrows).find(|&r| self.get(r, col).is_one());
129
130 if let Some(swap_row) = found {
131 self.rows.swap(pivot_row, swap_row);
132
133 for r in 0..self.nrows {
135 if r != pivot_row && self.get(r, col).is_one() {
136 let pivot = self.rows[pivot_row].clone();
137 self.rows[r] = self.rows[r].add(&pivot);
138 }
139 }
140
141 pivots.push(col);
142 pivot_row += 1;
143 }
144 }
145 pivots
146 }
147
148 pub fn row_echelon(&mut self) -> Vec<usize> {
153 self.reduced_row_echelon()
154 }
155
156 #[must_use]
158 pub fn rank(&self) -> usize {
159 let mut copy = self.clone();
160 copy.reduced_row_echelon().len()
161 }
162
163 #[must_use]
165 pub fn null_space(&self) -> Vec<GF2Vector> {
166 let mut rref = self.clone();
167 let pivots = rref.reduced_row_echelon();
168
169 let pivot_set: Vec<bool> = (0..self.ncols).map(|c| pivots.contains(&c)).collect();
170
171 let mut pivot_row_for_col = vec![usize::MAX; self.ncols];
173 for (row, &col) in pivots.iter().enumerate() {
174 pivot_row_for_col[col] = row;
175 }
176
177 let free_cols: Vec<usize> = (0..self.ncols).filter(|c| !pivot_set[*c]).collect();
178
179 let mut basis = Vec::new();
180 for &fc in &free_cols {
181 let mut v = GF2Vector::zero(self.ncols);
182 v.set(fc, GF2::ONE);
183 for &pc in &pivots {
185 let pr = pivot_row_for_col[pc];
186 v.set(pc, rref.get(pr, fc));
187 }
188 basis.push(v);
189 }
190 basis
191 }
192
193 pub fn determinant(&self) -> CoreResult<GF2> {
195 if self.nrows != self.ncols {
196 return Err(CoreError::GF2NotSquare {
197 rows: self.nrows,
198 cols: self.ncols,
199 });
200 }
201 let r = self.rank();
202 Ok(if r == self.nrows { GF2::ONE } else { GF2::ZERO })
203 }
204
205 #[must_use]
207 pub fn column_space(&self) -> Vec<GF2Vector> {
208 let t = self.transpose();
209 let mut rref = t.clone();
210 let pivots = rref.reduced_row_echelon();
211 pivots.iter().map(|&c| t.row(c).clone()).collect()
212 }
213
214 #[must_use]
216 pub fn in_column_space(&self, v: &GF2Vector) -> bool {
217 self.solve(v).is_some()
218 }
219
220 #[must_use]
222 pub fn solve(&self, b: &GF2Vector) -> Option<GF2Vector> {
223 assert_eq!(self.nrows, b.dim(), "dimension mismatch");
224 let mut aug = self.augment(b);
225 let pivots = aug.reduced_row_echelon();
226
227 let aug_col = self.ncols;
229 if pivots.contains(&aug_col) {
230 return None;
231 }
232
233 let mut x = GF2Vector::zero(self.ncols);
235 for (row, &col) in pivots.iter().enumerate() {
236 x.set(col, aug.get(row, aug_col));
237 }
238 Some(x)
239 }
240
241 #[must_use]
243 pub fn augment(&self, b: &GF2Vector) -> Self {
244 assert_eq!(self.nrows, b.dim(), "dimension mismatch");
245 let new_ncols = self.ncols + 1;
246 let rows: Vec<GF2Vector> = self
247 .rows
248 .iter()
249 .enumerate()
250 .map(|(i, row)| {
251 let mut new_row = GF2Vector::zero(new_ncols);
252 for j in 0..self.ncols {
253 new_row.set(j, row.get(j));
254 }
255 new_row.set(self.ncols, b.get(i));
256 new_row
257 })
258 .collect();
259 Self {
260 rows,
261 nrows: self.nrows,
262 ncols: new_ncols,
263 }
264 }
265
266 #[must_use]
268 pub fn hcat(&self, other: &Self) -> Self {
269 assert_eq!(self.nrows, other.nrows, "row count mismatch");
270 let new_ncols = self.ncols + other.ncols;
271 let rows: Vec<GF2Vector> = self
272 .rows
273 .iter()
274 .zip(other.rows.iter())
275 .map(|(a, b)| {
276 let mut new_row = GF2Vector::zero(new_ncols);
277 for j in 0..self.ncols {
278 new_row.set(j, a.get(j));
279 }
280 for j in 0..other.ncols {
281 new_row.set(self.ncols + j, b.get(j));
282 }
283 new_row
284 })
285 .collect();
286 Self {
287 rows,
288 nrows: self.nrows,
289 ncols: new_ncols,
290 }
291 }
292}
293
294impl fmt::Display for GF2Matrix {
295 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
296 for (i, row) in self.rows.iter().enumerate() {
297 if i > 0 {
298 writeln!(f)?;
299 }
300 write!(f, "{}", row)?;
301 }
302 Ok(())
303 }
304}
305
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `GF2Vector` from a slice of 0/1 bits.
    fn bits(raw: &[u8]) -> GF2Vector {
        GF2Vector::from_bits(raw)
    }

    #[test]
    fn test_identity_properties() {
        let eye = GF2Matrix::identity(3);
        assert_eq!(eye.rank(), 3);
        assert_eq!(eye.determinant().unwrap(), GF2::ONE);

        let probe = bits(&[1, 0, 1]);
        assert_eq!(eye.mul_vec(&probe), probe);
    }

    #[test]
    fn test_matrix_vector_product() {
        // [1 0 1; 0 1 1] · (1, 1, 0) = (1, 1)
        let a = GF2Matrix::from_rows(vec![bits(&[1, 0, 1]), bits(&[0, 1, 1])]);
        let product = a.mul_vec(&bits(&[1, 1, 0]));
        assert_eq!(product, bits(&[1, 1]));
    }

    #[test]
    fn test_row_echelon_and_rank() {
        // The third row is the sum of the first two, so only two pivots exist.
        let mut a = GF2Matrix::from_rows(vec![
            bits(&[1, 0, 1, 0]),
            bits(&[0, 1, 1, 0]),
            bits(&[1, 1, 0, 0]),
        ]);
        assert_eq!(a.reduced_row_echelon().len(), 2);
    }

    #[test]
    fn test_full_rank() {
        let a = GF2Matrix::from_rows(vec![
            bits(&[1, 0, 0]),
            bits(&[0, 1, 0]),
            bits(&[0, 0, 1]),
        ]);
        assert_eq!(a.rank(), 3);
        assert_eq!(a.determinant().unwrap(), GF2::ONE);
    }

    #[test]
    fn test_rank_deficient() {
        // Row 3 = row 1 + row 2.
        let a = GF2Matrix::from_rows(vec![
            bits(&[1, 1, 0]),
            bits(&[0, 0, 1]),
            bits(&[1, 1, 1]),
        ]);
        assert_eq!(a.rank(), 2);
        assert_eq!(a.determinant().unwrap(), GF2::ZERO);
    }

    #[test]
    fn test_null_space() {
        // Two independent equations in three unknowns leave a one-dimensional kernel.
        let a = GF2Matrix::from_rows(vec![bits(&[1, 0, 1]), bits(&[0, 1, 1])]);
        let kernel = a.null_space();
        assert_eq!(kernel.len(), 1);
        for basis_vec in kernel.iter() {
            assert!(
                a.mul_vec(basis_vec).is_zero(),
                "null space vector not in kernel"
            );
        }
    }

    #[test]
    fn test_determinant_non_square() {
        assert!(GF2Matrix::zero(2, 3).determinant().is_err());
    }

    #[test]
    fn test_solve() {
        let a = GF2Matrix::identity(2);
        let rhs = bits(&[1, 1]);
        let x = a.solve(&rhs).unwrap();
        assert_eq!(a.mul_vec(&x), rhs);
    }

    #[test]
    fn test_solve_inconsistent() {
        // Both rows constrain x0, to 0 and to 1, so no solution exists.
        let a = GF2Matrix::from_rows(vec![bits(&[1, 0]), bits(&[1, 0])]);
        assert!(a.solve(&bits(&[0, 1])).is_none());
    }

    #[test]
    fn test_transpose_roundtrip() {
        let a = GF2Matrix::from_rows(vec![bits(&[1, 0, 1]), bits(&[0, 1, 0])]);
        assert_eq!(a, a.transpose().transpose());
    }

    #[test]
    fn test_matrix_product() {
        let b = GF2Matrix::from_rows(vec![
            bits(&[1, 1, 0]),
            bits(&[0, 1, 1]),
            bits(&[1, 0, 1]),
        ]);
        assert_eq!(GF2Matrix::identity(3).mul_mat(&b), b);
    }
}