use rustc_hash::FxHashMap;
use std::cmp::min;
use std::fmt;
/// A dense matrix over GF(2), stored row-major with one `u8` per entry.
///
/// Entries are expected to be 0 or 1; construction via `Mat2::new` does not
/// enforce this.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Mat2 {
    /// Row-major entries: `d[i][j]` is the entry in row `i`, column `j`.
    d: Vec<Vec<u8>>,
}
/// Elementary row operations.
///
/// Gaussian elimination on [`Mat2`] mirrors each of its own row operations
/// onto an implementor of this trait, so the operations can be recorded or
/// replayed elsewhere.
pub trait RowOps {
    /// Adds row `r0` into row `r1` (`r1 += r0`).
    fn row_add(&mut self, r0: usize, r1: usize);

    /// Exchanges rows `r0` and `r1`.
    fn row_swap(&mut self, r0: usize, r1: usize);
}
/// Elementary column operations, the column counterpart of [`RowOps`].
pub trait ColOps {
    /// Adds column `c0` into column `c1` (`c1 += c0`).
    fn col_add(&mut self, c0: usize, c1: usize);

    /// Exchanges columns `c0` and `c1`.
    fn col_swap(&mut self, c0: usize, c1: usize);
}
/// No-op row operations: passing `&mut ()` lets elimination run without
/// recording its row operations anywhere.
impl RowOps for () {
    fn row_add(&mut self, _r0: usize, _r1: usize) {}

    fn row_swap(&mut self, _r0: usize, _r1: usize) {}
}
impl Mat2 {
    /// Column-block width used by [`Mat2::gauss`], [`Mat2::inverse`] and
    /// [`Mat2::nullspace`] for the block-wise elimination in `gauss_helper`.
    const DEFAULT_BLOCKSIZE: usize = 3;

    /// Wraps a row-major table of entries as a matrix.
    ///
    /// No validation is performed. The other methods assume every row has the
    /// same length (e.g. [`Mat2::num_cols`] only inspects the first row) and
    /// that every entry is 0 or 1.
    pub fn new(d: Vec<Vec<u8>>) -> Mat2 {
        Mat2 { d }
    }

    /// Builds a `rows` x `cols` matrix whose entry `(i, j)` is 1 when
    /// `f(i, j)` returns `true` and 0 otherwise.
    pub fn build<F>(rows: usize, cols: usize, f: F) -> Mat2
    where
        F: Fn(usize, usize) -> bool,
    {
        Mat2 {
            d: (0..rows)
                .map(|i| (0..cols).map(|j| u8::from(f(i, j))).collect())
                .collect(),
        }
    }

    /// The `rows` x `cols` all-zero matrix.
    pub fn zeros(rows: usize, cols: usize) -> Mat2 {
        Mat2::build(rows, cols, |_, _| false)
    }

    /// The `rows` x `cols` all-one matrix.
    pub fn ones(rows: usize, cols: usize) -> Mat2 {
        Mat2::build(rows, cols, |_, _| true)
    }

    /// The `dim` x `dim` identity matrix.
    pub fn id(dim: usize) -> Mat2 {
        Mat2::build(dim, dim, |i, j| i == j)
    }

    /// The `dim` x 1 column vector with a single 1 in row `i`
    /// (all zeros if `i >= dim`).
    pub fn unit_vector(dim: usize, i: usize) -> Mat2 {
        Mat2::build(dim, 1, |r, _| r == i)
    }

    /// Number of rows.
    pub fn num_rows(&self) -> usize {
        self.d.len()
    }

    /// Number of columns, taken from the first row (0 for a matrix with no rows).
    pub fn num_cols(&self) -> usize {
        self.d.first().map_or(0, Vec::len)
    }

    /// Returns the transpose as a new matrix.
    pub fn transpose(&self) -> Mat2 {
        Mat2::build(self.num_cols(), self.num_rows(), |i, j| self[j][i] == 1)
    }

    /// Row-reduces `self` in place over GF(2) and returns its rank.
    ///
    /// Columns are processed in blocks of `blocksize`. Within each block, rows
    /// whose (non-zero) restriction to the block is identical are first
    /// cancelled against each other. This is intended to cut down the row
    /// additions the ordinary pivoting step needs afterwards.
    ///
    /// Every row operation applied to `self` is mirrored on `x`. Passing an
    /// identity matrix as `x` therefore accumulates the transformation performed.
    ///
    /// * `full_reduce`: if true, entries above each pivot are cleared as
    ///   well, giving reduced row echelon form; otherwise row echelon form.
    /// * `blocksize`: width of the column blocks; `0` is treated as `1`.
    /// * `pivot_cols`: the pivot column of each pivot row is appended here,
    ///   in row order.
    fn gauss_helper<T: RowOps>(
        &mut self,
        full_reduce: bool,
        blocksize: usize,
        x: &mut T,
        pivot_cols: &mut Vec<usize>,
    ) -> usize {
        // A zero block width would divide by zero below; one column per block
        // is the natural degenerate choice.
        let blocksize = blocksize.max(1);
        let rows = self.num_rows();
        let cols = self.num_cols();
        let mut pivot_row = 0;
        let num_blocks = if cols % blocksize == 0 {
            cols / blocksize
        } else {
            (cols / blocksize) + 1
        };

        // Forward pass: row echelon form, one column block at a time.
        for sec in 0..num_blocks {
            let i0 = sec * blocksize;
            let i1 = min(cols, (sec + 1) * blocksize);
            // Cancel repeated non-zero block patterns among the unreduced rows.
            let mut chunks: FxHashMap<Vec<u8>, usize> = FxHashMap::default();
            for r in pivot_row..rows {
                let ch = self.d[r][i0..i1].to_vec();
                if ch.iter().all(|&b| b == 0) {
                    continue;
                }
                if let Some(&r1) = chunks.get(&ch) {
                    self.row_add(r1, r);
                    x.row_add(r1, r);
                } else {
                    chunks.insert(ch, r);
                }
            }
            // Ordinary pivoting on each column of the block.
            for p in i0..i1 {
                for r0 in pivot_row..rows {
                    if self.d[r0][p] != 0 {
                        if r0 != pivot_row {
                            // Bring a 1 into the pivot position by adding
                            // (rather than swapping) row r0 into it.
                            self.row_add(r0, pivot_row);
                            x.row_add(r0, pivot_row);
                        }
                        for r1 in pivot_row + 1..rows {
                            if self.d[r1][p] != 0 {
                                self.row_add(pivot_row, r1);
                                x.row_add(pivot_row, r1);
                            }
                        }
                        pivot_cols.push(p);
                        pivot_row += 1;
                        break;
                    }
                }
            }
        }

        let rank = pivot_row;
        if full_reduce && rank != 0 {
            // Backward pass: clear entries above the pivots, block by block
            // from the right. `pivot_row` now indexes the last pivot row.
            pivot_row -= 1;
            // Only the pivots found by this call; the caller may have passed
            // a vector that already held entries.
            let mut pending = pivot_cols[pivot_cols.len() - rank..].to_vec();
            for sec in (0..num_blocks).rev() {
                let i0 = sec * blocksize;
                let i1 = min(cols, (sec + 1) * blocksize);
                let mut chunks: FxHashMap<Vec<u8>, usize> = FxHashMap::default();
                for r in (0..=pivot_row).rev() {
                    let ch = self.d[r][i0..i1].to_vec();
                    if ch.iter().all(|&b| b == 0) {
                        continue;
                    }
                    if let Some(&r1) = chunks.get(&ch) {
                        self.row_add(r1, r);
                        x.row_add(r1, r);
                    } else {
                        chunks.insert(ch, r);
                    }
                }
                while let Some(&pcol) = pending.last() {
                    if !(i0..i1).contains(&pcol) {
                        break;
                    }
                    pending.pop();
                    for r in 0..pivot_row {
                        if self.d[r][pcol] != 0 {
                            self.row_add(pivot_row, r);
                            x.row_add(pivot_row, r);
                        }
                    }
                    pivot_row = pivot_row.saturating_sub(1);
                }
            }
        }
        rank
    }

    /// Row-reduces `self` in place and returns its rank.
    ///
    /// With `full_reduce` the result is in reduced row echelon form, otherwise
    /// in row echelon form.
    pub fn gauss(&mut self, full_reduce: bool) -> usize {
        self.gauss_helper(full_reduce, Self::DEFAULT_BLOCKSIZE, &mut (), &mut Vec::new())
    }

    /// Like [`Mat2::gauss`], but with an explicit column `blocksize` (`0` is
    /// treated as `1`). Every row operation is also applied to `x`.
    pub fn gauss_x(&mut self, full_reduce: bool, blocksize: usize, x: &mut impl RowOps) -> usize {
        self.gauss_helper(full_reduce, blocksize, x, &mut Vec::new())
    }

    /// Rank over GF(2). `self` is left unchanged; elimination runs on a copy.
    pub fn rank(&self) -> usize {
        let mut m = self.clone();
        m.gauss(false)
    }

    /// Inverse over GF(2), or `None` if `self` is not square or is singular.
    ///
    /// A copy of `self` is fully reduced while its row operations are mirrored
    /// onto an identity matrix. When the copy reaches full rank, that mirrored
    /// matrix is the inverse.
    pub fn inverse(&self) -> Option<Mat2> {
        if self.num_rows() != self.num_cols() {
            return None;
        }
        let mut m = self.clone();
        let mut inv = Mat2::id(self.num_rows());
        let rank = m.gauss_helper(true, Self::DEFAULT_BLOCKSIZE, &mut inv, &mut Vec::new());
        if rank < self.num_rows() {
            None
        } else {
            Some(inv)
        }
    }

    /// Sum of the entries of `row`, widened to `usize` so it cannot overflow.
    fn entry_sum(row: &[u8]) -> usize {
        row.iter().map(|&b| usize::from(b)).sum()
    }

    /// Narrows a weight to the `u8` returned by the public weight accessors.
    ///
    /// # Panics
    /// Panics if `total` exceeds `u8::MAX`. Previously this overflowed (a panic
    /// in debug builds, a silently wrong result in release builds).
    fn weight_as_u8(total: usize) -> u8 {
        assert!(
            total <= usize::from(u8::MAX),
            "matrix weight {} does not fit in a u8",
            total
        );
        total as u8
    }

    /// Sum of the entries of row `i`, i.e. its Hamming weight for 0/1 entries.
    ///
    /// # Panics
    /// Panics if `i` is out of bounds or the weight exceeds `u8::MAX`.
    pub fn row_weight(&self, i: usize) -> u8 {
        Self::weight_as_u8(Self::entry_sum(&self.d[i]))
    }

    /// Sum of all entries, i.e. the number of ones for 0/1 entries.
    ///
    /// # Panics
    /// Panics if the weight exceeds `u8::MAX`.
    pub fn weight(&self) -> u8 {
        Self::weight_as_u8(self.d.iter().map(|row| Self::entry_sum(row)).sum())
    }

    /// Indices of the rows whose entries sum to exactly 1.
    pub fn unit_rows(&self) -> Vec<usize> {
        // Summing in usize avoids u8 wrap-around making a heavy row look like a unit row.
        self.d
            .iter()
            .enumerate()
            .filter(|(_, row)| Self::entry_sum(row) == 1)
            .map(|(i, _)| i)
            .collect()
    }

    /// A basis of the right nullspace `{ v : self * v^T = 0 }`.
    ///
    /// Each basis vector is returned as a 1 x `num_cols` row matrix, one per
    /// non-pivot (free) column of the reduced matrix, in increasing column
    /// order. The result is empty when `self` has full column rank.
    pub fn nullspace(&self) -> Vec<Self> {
        let n = self.num_cols();
        let mut mat = self.clone();
        let mut pivot_cols = Vec::new();
        let rank = mat.gauss_helper(true, Self::DEFAULT_BLOCKSIZE, &mut (), &mut pivot_cols);
        if rank == n {
            return Vec::new();
        }
        let mut is_pivot = vec![false; n];
        for &p in &pivot_cols {
            is_pivot[p] = true;
        }
        (0..n)
            .filter(|&col| !is_pivot[col])
            .map(|free_var| {
                let mut v = Self::zeros(1, n);
                v[0][free_var] = 1;
                // In RREF each pivot variable equals the sum of the free
                // variables appearing in its row; only `free_var` is set here.
                for (row, &pivot_col) in pivot_cols.iter().enumerate() {
                    if free_var > pivot_col && mat[row][free_var] == 1 {
                        v[0][pivot_col] = 1;
                    }
                }
                v
            })
            .collect()
    }

    /// Stacks `other` below `self`.
    ///
    /// A matrix with no rows cannot record a column count (`num_cols` reports
    /// 0), so it is treated as neutral and the other operand is returned.
    ///
    /// # Panics
    /// Panics if both operands have rows and their column counts differ.
    pub fn vstack(&self, other: &Self) -> Self {
        if self.num_rows() == 0 {
            return other.clone();
        }
        if other.num_rows() == 0 {
            return self.clone();
        }
        assert_eq!(
            self.num_cols(),
            other.num_cols(),
            "Matrices must have the same number of columns for vertical stacking"
        );
        let mut result = self.d.clone();
        result.extend(other.d.iter().cloned());
        Mat2 { d: result }
    }

    /// Places `other` to the right of `self`.
    ///
    /// # Panics
    /// Panics if the row counts differ.
    pub fn hstack(&self, other: &Self) -> Self {
        assert_eq!(
            self.num_rows(),
            other.num_rows(),
            "Matrices must have the same number of rows for horizontal stacking"
        );
        let mut result = self.clone();
        for (dst, src) in result.d.iter_mut().zip(&other.d) {
            dst.extend_from_slice(src);
        }
        result
    }
}
/// Row operations on the matrix itself; row addition is entry-wise XOR.
impl RowOps for Mat2 {
    fn row_add(&mut self, r0: usize, r1: usize) {
        let width = self.num_cols();
        for col in 0..width {
            let bit = self.d[r0][col];
            self.d[r1][col] ^= bit;
        }
    }

    fn row_swap(&mut self, r0: usize, r1: usize) {
        let rows = &mut self.d;
        rows.swap(r0, r1);
    }
}
/// Column operations on the matrix itself, applied to every row; column
/// addition is entry-wise XOR.
impl ColOps for Mat2 {
    fn col_add(&mut self, c0: usize, c1: usize) {
        for row in &mut self.d {
            row[c1] ^= row[c0];
        }
    }

    fn col_swap(&mut self, c0: usize, c1: usize) {
        self.d.iter_mut().for_each(|row| row.swap(c0, c1));
    }
}
impl fmt::Display for Mat2 {
    /// Writes one bracketed line per row, e.g. `[ 1 0 1 ]`, each entry
    /// followed by a space.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        self.d.iter().try_for_each(|row| {
            f.write_str("[ ")?;
            row.iter().try_for_each(|bit| write!(f, "{bit} "))?;
            f.write_str("]\n")
        })
    }
}
impl std::ops::Index<(usize, usize)> for Mat2 {
    type Output = u8;

    /// Returns the entry in row `row`, column `col`.
    fn index(&self, (row, col): (usize, usize)) -> &Self::Output {
        &self.d[row][col]
    }
}
impl std::ops::IndexMut<(usize, usize)> for Mat2 {
    /// Mutable access to the entry in row `row`, column `col`.
    fn index_mut(&mut self, (row, col): (usize, usize)) -> &mut Self::Output {
        &mut self.d[row][col]
    }
}
impl std::ops::Index<usize> for Mat2 {
    type Output = Vec<u8>;

    /// Returns row `row` as its underlying vector of entries.
    fn index(&self, row: usize) -> &Vec<u8> {
        &self.d[row]
    }
}
impl std::ops::IndexMut<usize> for Mat2 {
    /// Mutable access to row `row`.
    fn index_mut(&mut self, row: usize) -> &mut Vec<u8> {
        &mut self.d[row]
    }
}
/// Matrix product over GF(2): entries are multiplied with AND and summed with XOR.
impl std::ops::Mul<&Mat2> for &Mat2 {
    type Output = Mat2;

    // AND/XOR are GF(2) multiplication/addition, so clippy's lint about
    // unexpected operators inside a `Mul` impl is a false positive here.
    #[allow(clippy::suspicious_arithmetic_impl)]
    fn mul(self, rhs: &Mat2) -> Self::Output {
        let inner = self.num_cols();
        assert!(
            inner == rhs.num_rows(),
            "Cannot multiply matrices with mismatched dimensions."
        );
        Mat2::build(self.num_rows(), rhs.num_cols(), |row, col| {
            let dot = (0..inner).fold(0u8, |acc, k| acc ^ (self.d[row][k] & rhs.d[k][col]));
            dot == 1
        })
    }
}
/// `&a * b`: forwards to the by-reference product.
impl std::ops::Mul<Mat2> for &Mat2 {
    type Output = Mat2;

    fn mul(self, rhs: Mat2) -> Self::Output {
        std::ops::Mul::mul(self, &rhs)
    }
}
/// `a * &b`: forwards to the by-reference product.
impl std::ops::Mul<&Mat2> for Mat2 {
    type Output = Mat2;

    fn mul(self, rhs: &Mat2) -> Self::Output {
        std::ops::Mul::mul(&self, rhs)
    }
}
/// `a * b` on owned matrices: forwards to the by-reference product.
impl std::ops::Mul<Mat2> for Mat2 {
    type Output = Mat2;

    fn mul(self, rhs: Mat2) -> Self::Output {
        std::ops::Mul::mul(&self, &rhs)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn mat_mul() {
        let v = Mat2::new(vec![vec![1, 0, 1, 0], vec![1, 1, 1, 1], vec![0, 0, 1, 1]]);
        let w = Mat2::new(vec![vec![1, 1], vec![1, 0], vec![0, 0], vec![0, 1]]);
        let u = Mat2::new(vec![vec![1, 1], vec![0, 0], vec![0, 1]]);
        assert_eq!(&v * &w, u);
    }

    #[test]
    fn transpose() {
        let v = Mat2::new(vec![vec![1, 0, 1, 0], vec![1, 1, 1, 1], vec![0, 0, 1, 1]]);
        let vt = Mat2::new(vec![
            vec![1, 1, 0],
            vec![0, 1, 0],
            vec![1, 1, 1],
            vec![0, 1, 1],
        ]);
        assert_eq!(v.transpose(), vt);
    }

    #[test]
    fn unit_vecs() {
        let v = Mat2::new(vec![vec![1, 0, 1, 0], vec![1, 1, 1, 1], vec![0, 0, 1, 1]]);
        let c0 = Mat2::new(vec![vec![1, 1, 0]]).transpose();
        let c1 = Mat2::new(vec![vec![0, 1, 0]]).transpose();
        let c2 = Mat2::new(vec![vec![1, 1, 1]]).transpose();
        let c3 = Mat2::new(vec![vec![0, 1, 1]]).transpose();
        assert_eq!(&v * Mat2::unit_vector(4, 0), c0);
        assert_eq!(&v * Mat2::unit_vector(4, 1), c1);
        assert_eq!(&v * Mat2::unit_vector(4, 2), c2);
        assert_eq!(&v * Mat2::unit_vector(4, 3), c3);
    }

    #[test]
    fn row_ops() {
        let mut v = Mat2::new(vec![vec![1, 0, 1, 0], vec![1, 1, 1, 1], vec![0, 0, 1, 1]]);
        let w1 = Mat2::new(vec![vec![1, 0, 1, 0], vec![1, 1, 1, 1], vec![1, 1, 0, 0]]);
        let w2 = Mat2::new(vec![vec![1, 1, 1, 1], vec![1, 0, 1, 0], vec![1, 1, 0, 0]]);
        v.row_add(1, 2);
        assert_eq!(v, w1);
        v.row_swap(0, 1);
        assert_eq!(v, w2);
    }

    #[test]
    fn col_ops() {
        let mut v = Mat2::new(vec![vec![1, 0, 1, 0], vec![1, 1, 1, 1], vec![0, 0, 1, 1]]);
        v.col_add(0, 1);
        assert_eq!(
            v,
            Mat2::new(vec![vec![1, 1, 1, 0], vec![1, 0, 1, 1], vec![0, 0, 1, 1]])
        );
        v.col_swap(0, 3);
        assert_eq!(
            v,
            Mat2::new(vec![vec![0, 1, 1, 1], vec![1, 0, 1, 1], vec![1, 0, 1, 0]])
        );
    }

    #[test]
    fn gauss() {
        let v = Mat2::id(4);
        let mut w = v.clone();
        w.gauss(true);
        assert_eq!(v, w);
    }

    #[test]
    fn gauss_full_reduce_gives_rref() {
        let mut m = Mat2::new(vec![vec![1, 0, 1, 0], vec![1, 1, 1, 1], vec![0, 1, 0, 1]]);
        assert_eq!(m.gauss(true), 2);
        assert_eq!(
            m,
            Mat2::new(vec![vec![1, 0, 1, 0], vec![0, 1, 0, 1], vec![0, 0, 0, 0]])
        );
    }

    #[test]
    fn gauss_x_records_row_ops() {
        // Mirroring every row operation onto the identity keeps `m == x * v`.
        let v = Mat2::new(vec![vec![1, 0, 1, 0], vec![1, 1, 1, 1], vec![0, 1, 0, 1]]);
        let mut m = v.clone();
        let mut x = Mat2::id(3);
        assert_eq!(m.gauss_x(true, 2, &mut x), 2);
        assert_eq!(&x * &v, m);
    }

    #[test]
    fn ranks() {
        let v = Mat2::new(vec![vec![1, 0, 1, 0], vec![1, 1, 1, 1], vec![0, 0, 1, 1]]);
        assert_eq!(v.rank(), 3);
        let v = Mat2::new(vec![vec![1, 0, 1, 0], vec![1, 1, 1, 1], vec![0, 1, 0, 1]]);
        assert_eq!(v.rank(), 2);
    }

    #[test]
    fn inv() {
        let v = Mat2::new(vec![vec![1, 1, 1], vec![0, 1, 1], vec![0, 0, 1]]);
        assert_eq!(v.rank(), 3);
        let vi = v.inverse().expect("v should be invertible");
        assert_eq!(&v * &vi, Mat2::id(3));
        assert_eq!(&vi * &v, Mat2::id(3));
        let vi_exp = Mat2::new(vec![vec![1, 1, 0], vec![0, 1, 1], vec![0, 0, 1]]);
        assert_eq!(vi_exp, vi);
    }

    #[test]
    fn inverse_of_singular_or_non_square_is_none() {
        assert!(Mat2::new(vec![vec![1, 1], vec![1, 1]]).inverse().is_none());
        assert!(Mat2::new(vec![vec![1, 0, 1]]).inverse().is_none());
    }

    #[test]
    fn weights() {
        let v = Mat2::new(vec![vec![1, 0, 1, 0], vec![1, 1, 1, 1], vec![0, 0, 1, 0]]);
        assert_eq!(v.row_weight(1), 4);
        assert_eq!(v.weight(), 7);
        assert_eq!(v.unit_rows(), vec![2]);
    }

    #[test]
    fn test_nullspace() {
        let mat = Mat2::new(vec![vec![1, 0, 1], vec![0, 1, 1]]);
        let nullspace = mat.nullspace();
        assert_eq!(nullspace.len(), 1);
        assert_eq!(nullspace[0].d, vec![vec![1, 1, 1]]);
        println!("Matrix is \n{}", mat)
    }

    #[test]
    fn nullspace_vectors_are_annihilated() {
        let m = Mat2::new(vec![vec![1, 0, 1, 0], vec![1, 1, 1, 1], vec![0, 1, 0, 1]]);
        let ns = m.nullspace();
        assert_eq!(ns.len(), 2);
        assert_eq!(ns[0], Mat2::new(vec![vec![1, 0, 1, 0]]));
        assert_eq!(ns[1], Mat2::new(vec![vec![0, 1, 0, 1]]));
        for v in &ns {
            assert_eq!(&m * v.transpose(), Mat2::zeros(3, 1));
        }
        assert!(Mat2::id(3).nullspace().is_empty());
    }

    #[test]
    #[should_panic(expected = "Matrices must have the same number of rows for horizontal stacking")]
    fn test_hstack_panic() {
        let a = Mat2::new(vec![vec![1, 0]]);
        let b = Mat2::new(vec![vec![1], vec![0]]);
        a.hstack(&b);
    }

    #[test]
    fn test_hstack() {
        let a = Mat2::new(vec![vec![1], vec![0]]);
        let b = Mat2::new(vec![vec![1], vec![0]]);
        let c = a.hstack(&b);
        assert_eq!(c, Mat2::new(vec![vec![1, 1], vec![0, 0]]));
    }

    #[test]
    fn test_vstack() {
        let a = Mat2::new(vec![vec![1, 0]]);
        let b = Mat2::new(vec![vec![1, 0]]);
        let c = a.vstack(&b);
        assert_eq!(c, Mat2::new(vec![vec![1, 0], vec![1, 0]]));
    }

    #[test]
    #[should_panic(expected = "Matrices must have the same number of columns for vertical stacking")]
    fn test_vstack_panic() {
        let a = Mat2::new(vec![vec![1, 0]]);
        let b = Mat2::new(vec![vec![1]]);
        a.vstack(&b);
    }
}