use std::collections::BTreeSet;
use crate::bits::Bits;
/// A sparse boolean matrix: each row is stored as the set of column indices
/// whose entry is one; every other entry is zero.
///
/// Rows are expected to reference only columns `0..column_count`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SparseBoolMatrix {
    /// For each row, the (sorted) column indices holding a one.
    ones_by_row: Vec<BTreeSet<usize>>,
    /// Number of columns; `solve` returns one slot per column.
    column_count: usize,
}
impl SparseBoolMatrix {
pub fn new(column_count: usize) -> Self {
Self {
ones_by_row: vec![],
column_count,
}
}
pub fn push_row(&mut self, one_indices: BTreeSet<usize>) {
self.ones_by_row.push(one_indices);
}
pub fn identity(row_count: usize) -> Self {
Self {
ones_by_row: (0..row_count).map(|i| BTreeSet::from([i])).collect(),
column_count: row_count,
}
}
pub fn solve<V: Bits>(mut self, mut targets: Vec<V>) -> Vec<Option<V>> {
let row_count = self.ones_by_row.len();
assert_eq!(row_count, targets.len());
let mut pivot_row = 0;
let mut pivot_column = 0;
while pivot_row <= row_count && pivot_column <= self.column_count {
let mut next_pivot_row = None;
for i in pivot_row..row_count {
if self.ones_by_row[i].contains(&pivot_column) {
next_pivot_row = Some(i);
break;
}
}
match next_pivot_row {
Some(i) => {
self.ones_by_row.swap(pivot_row, i);
targets.swap(pivot_row, i);
}
None => {
pivot_column += 1;
continue;
}
}
for i in (pivot_row + 1)..row_count {
self.ones_by_row[i].remove(&pivot_column);
if !self.ones_by_row[i].contains(&pivot_column) {
continue;
}
for j in (pivot_column + 1)..self.column_count {
if self.ones_by_row[i].contains(&j) {
self.ones_by_row[i].remove(&j);
} else {
self.ones_by_row[i].insert(j);
}
}
}
pivot_row += 1;
pivot_column += 1;
}
let mut assignment: Vec<Option<V>> = (0..self.column_count).map(|_| None).collect();
for (k, ones) in self.ones_by_row.into_iter().enumerate().rev() {
let mut last = None;
for i in ones {
match assignment[i] {
Some(value) => targets[k] ^= value,
None => {
last = Some(i);
let value = V::random();
targets[k] ^= value;
assignment[i] = Some(value);
}
}
}
assignment[last.unwrap()] = Some(assignment[last.unwrap()].unwrap() ^ targets[k]);
}
assignment
}
}
#[cfg(test)]
mod tests {
    use super::SparseBoolMatrix;

    /// Every identity row pins exactly one column, so the solution must reproduce
    /// the targets verbatim.
    #[test]
    fn test_identity_5x5() {
        let wanted = vec![10, 20, 30, 40, 50];
        let expected: Vec<_> = wanted.iter().copied().map(Some).collect();
        let solution = SparseBoolMatrix::identity(5).solve(wanted);
        assert_eq!(solution, expected);
    }

    /// With the last row dropped, column 4 appears in no equation and must come
    /// back unassigned.
    #[test]
    fn test_identity_4x5() {
        let mut matrix = SparseBoolMatrix::identity(5);
        matrix.ones_by_row.truncate(4);
        let solution = matrix.solve(vec![10, 20, 30, 40]);
        assert_eq!(&solution[..4], &[Some(10), Some(20), Some(30), Some(40)]);
        assert!(solution[4].is_none());
    }

    /// With the first row dropped, column 0 is unconstrained while columns 1..5
    /// still map one-to-one onto the targets.
    #[test]
    fn test_identity_4x5_shifted() {
        let mut matrix = SparseBoolMatrix::identity(5);
        let _dropped = matrix.ones_by_row.remove(0);
        let solution = matrix.solve(vec![10, 20, 30, 40]);
        assert!(solution[0].is_none());
        assert_eq!(&solution[1..5], &[Some(10), Some(20), Some(30), Some(40)]);
    }
}