use std::fmt;
/// A fixed-length vector of bits packed into 64-bit words.
///
/// Bit `i` is stored in `words[i / 64]` at bit position `i % 64`. The
/// mutating methods only ever touch bits with index `< len`, so padding bits
/// in the final word stay zero; the derived `PartialEq`/`Hash` compare the raw
/// words and rely on that.
#[derive(Hash, Eq, PartialEq, Clone)]
pub struct BitVec {
    /// Packed storage, least-significant bit first.
    words: Vec<u64>,
    /// Logical number of bits.
    len: usize,
}
impl BitVec {
    /// Creates a vector of `len` bits, all cleared.
    #[must_use]
    pub fn zeros(len: usize) -> Self {
        Self {
            words: vec![0; len.div_ceil(64)],
            len,
        }
    }
    /// Creates a vector of `len` bits in which only bit `bit` is set.
    ///
    /// # Panics
    ///
    /// Panics if `bit >= len`.
    #[must_use]
    pub fn singleton(len: usize, bit: usize) -> Self {
        let mut v = Self::zeros(len);
        v.set(bit);
        v
    }
    /// Returns the number of bits (not the number of set bits).
    #[must_use]
    pub const fn len(&self) -> usize {
        self.len
    }
    /// Returns `true` if the vector has length zero.
    ///
    /// This is about length, not contents; see [`BitVec::is_zero`] for
    /// "no bit is set".
    #[must_use]
    pub const fn is_empty(&self) -> bool {
        self.len == 0
    }
    /// Returns `true` if no bit is set.
    #[must_use]
    pub fn is_zero(&self) -> bool {
        self.words.iter().all(|&w| w == 0)
    }
    /// Maps bit index `i` to its storage word index and a single-bit mask.
    ///
    /// This is the shared bounds check for the single-bit accessors, so they all
    /// reject out-of-range indices with the same message. `#[track_caller]`
    /// reports a failed check at the calling accessor, as the old inline
    /// asserts did.
    ///
    /// # Panics
    ///
    /// Panics if `i >= self.len`.
    #[track_caller]
    fn word_and_mask(&self, i: usize) -> (usize, u64) {
        assert!(
            i < self.len,
            "bit index {i} out of range (len={})",
            self.len
        );
        (i / 64, 1u64 << (i % 64))
    }
    /// Returns whether bit `i` is set.
    ///
    /// # Panics
    ///
    /// Panics if `i >= self.len()`.
    #[must_use]
    pub fn get(&self, i: usize) -> bool {
        let (word, mask) = self.word_and_mask(i);
        (self.words[word] & mask) != 0
    }
    /// Sets bit `i` to 1.
    ///
    /// # Panics
    ///
    /// Panics if `i >= self.len()`.
    pub fn set(&mut self, i: usize) {
        let (word, mask) = self.word_and_mask(i);
        self.words[word] |= mask;
    }
    /// Sets bit `i` to 0.
    ///
    /// # Panics
    ///
    /// Panics if `i >= self.len()`.
    pub fn clear(&mut self, i: usize) {
        let (word, mask) = self.word_and_mask(i);
        self.words[word] &= !mask;
    }
    /// Toggles bit `i`.
    ///
    /// # Panics
    ///
    /// Panics if `i >= self.len()`.
    pub fn flip(&mut self, i: usize) {
        let (word, mask) = self.word_and_mask(i);
        self.words[word] ^= mask;
    }
    /// XORs `other` into `self` bit by bit (addition over GF(2)).
    ///
    /// # Panics
    ///
    /// Panics if `self.len() != other.len()`.
    pub fn xor_assign(&mut self, other: &Self) {
        assert_eq!(
            self.len, other.len,
            "xor_assign: length mismatch ({} vs {})",
            self.len, other.len
        );
        for (a, &b) in self.words.iter_mut().zip(&other.words) {
            *a ^= b;
        }
    }
    /// Returns a new vector equal to `self XOR other`.
    ///
    /// # Panics
    ///
    /// Panics if `self.len() != other.len()`.
    #[must_use]
    pub fn xor(&self, other: &Self) -> Self {
        let mut result = self.clone();
        result.xor_assign(other);
        result
    }
    /// Returns the index of the lowest set bit, or `None` if no bit is set.
    ///
    /// Note: this is the *lowest* set bit.
    /// [`BoundaryMatrix::column_pivot`] uses [`BitVec::highest_bit`] instead.
    #[must_use]
    pub fn pivot(&self) -> Option<usize> {
        self.words
            .iter()
            .enumerate()
            .find(|&(_, &w)| w != 0)
            .map(|(word_idx, &w)| word_idx * 64 + w.trailing_zeros() as usize)
            // Defensive: never report a padding bit past `len`.
            .filter(|&index| index < self.len)
    }
    /// Returns the index of the highest set bit, or `None` if no bit is set.
    #[must_use]
    pub fn highest_bit(&self) -> Option<usize> {
        let last = self.words.len().checked_sub(1)?;
        let tail = self.len % 64;
        for (word_idx, &word) in self.words.iter().enumerate().rev() {
            // Defensive: mask off padding bits past `len` in the final word.
            let w = if word_idx == last && tail != 0 {
                word & ((1u64 << tail) - 1)
            } else {
                word
            };
            if w != 0 {
                return Some(word_idx * 64 + (63 - w.leading_zeros() as usize));
            }
        }
        None
    }
    /// Returns the number of set bits.
    #[must_use]
    pub fn count_ones(&self) -> usize {
        self.words.iter().map(|w| w.count_ones() as usize).sum()
    }
    /// Iterates over the indices of the set bits in ascending order.
    pub fn ones(&self) -> impl Iterator<Item = usize> + '_ {
        self.words
            .iter()
            .enumerate()
            .flat_map(move |(word_idx, &word)| BitIter {
                word,
                base: word_idx * 64,
                len: self.len,
            })
    }
}
impl fmt::Debug for BitVec {
    /// Formats as `BitVec(<len>, [<set bit indices>])`, e.g. `BitVec(10, [3, 7])`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "BitVec({}, [", self.len)?;
        for (n, idx) in self.ones().enumerate() {
            // Separator before every entry except the first.
            if n > 0 {
                f.write_str(", ")?;
            }
            write!(f, "{idx}")?;
        }
        f.write_str("])")
    }
}
/// Iterator over the set bits of a single storage word of a [`BitVec`].
///
/// It yields absolute bit indices (`base + offset`) in ascending order and
/// stops at the first index that is not below `len`.
struct BitIter {
    /// Bits not yet visited; each yielded bit is cleared from here.
    word: u64,
    /// Absolute index of bit 0 of `word`.
    base: usize,
    /// Exclusive upper bound on the indices that may be yielded.
    len: usize,
}
impl Iterator for BitIter {
    type Item = usize;
    fn next(&mut self) -> Option<usize> {
        if self.word != 0 {
            let position = self.base + self.word.trailing_zeros() as usize;
            if position < self.len {
                // Drop the lowest set bit so the next call finds the one after it.
                self.word &= self.word - 1;
                return Some(position);
            }
            // Every remaining bit is at or beyond `len`; exhaust the iterator.
            self.word = 0;
        }
        None
    }
}
/// A binary matrix stored column-major, with one [`BitVec`] per column.
///
/// Column addition is XOR, so arithmetic is over GF(2).
#[derive(Clone)]
pub struct BoundaryMatrix {
    /// Number of rows. Columns are expected to have exactly this length
    /// (checked by `from_columns`).
    rows: usize,
    /// Column storage; its length is the column count.
    columns: Vec<BitVec>,
}
impl BoundaryMatrix {
    /// Creates a `rows` x `cols` matrix with every entry zero.
    #[must_use]
    pub fn zeros(rows: usize, cols: usize) -> Self {
        Self {
            rows,
            columns: (0..cols).map(|_| BitVec::zeros(rows)).collect(),
        }
    }
    /// Builds a matrix with `rows` rows from the given columns.
    ///
    /// # Panics
    ///
    /// Panics if any column's length differs from `rows`.
    #[must_use]
    pub fn from_columns(rows: usize, columns: Vec<BitVec>) -> Self {
        for (i, col) in columns.iter().enumerate() {
            assert_eq!(
                col.len(),
                rows,
                "column {i} has length {}, expected {rows}",
                col.len()
            );
        }
        Self { rows, columns }
    }
    /// Returns the number of rows.
    #[must_use]
    pub const fn rows(&self) -> usize {
        self.rows
    }
    /// Returns the number of columns.
    #[must_use]
    pub fn cols(&self) -> usize {
        self.columns.len()
    }
    /// Borrows column `j`.
    ///
    /// # Panics
    ///
    /// Panics if `j >= self.cols()`.
    #[must_use]
    pub fn column(&self, j: usize) -> &BitVec {
        &self.columns[j]
    }
    /// Mutably borrows column `j`.
    ///
    /// Assigning a column of a different length through this reference
    /// bypasses the length check done by [`BoundaryMatrix::from_columns`].
    ///
    /// # Panics
    ///
    /// Panics if `j >= self.cols()`.
    pub fn column_mut(&mut self, j: usize) -> &mut BitVec {
        &mut self.columns[j]
    }
    /// Sets the entry at row `i` of column `j` to 1.
    ///
    /// # Panics
    ///
    /// Panics if `j >= self.cols()` or `i >= self.rows()`.
    pub fn set(&mut self, i: usize, j: usize) {
        self.columns[j].set(i);
    }
    /// Returns the entry at row `i` of column `j`.
    ///
    /// # Panics
    ///
    /// Panics if `j >= self.cols()` or `i >= self.rows()`.
    #[must_use]
    pub fn get(&self, i: usize, j: usize) -> bool {
        self.columns[j].get(i)
    }
    /// Adds column `src` into column `dst` (bitwise XOR, i.e. addition over GF(2)).
    ///
    /// Both columns are borrowed in place, so no temporary copy of `src` is made.
    ///
    /// # Panics
    ///
    /// Panics if `dst == src`, if either index is `>= self.cols()`, or if the
    /// two columns have different lengths.
    pub fn xor_columns(&mut self, dst: usize, src: usize) {
        assert_ne!(dst, src, "xor_columns: src and dst must differ");
        let (target, source) = Self::column_pair_mut(&mut self.columns, dst, src);
        target.xor_assign(source);
    }
    /// Borrows `columns[dst]` mutably and `columns[src]` shared at the same time,
    /// by splitting the slice between them. Requires `dst != src`.
    fn column_pair_mut(
        columns: &mut [BitVec],
        dst: usize,
        src: usize,
    ) -> (&mut BitVec, &BitVec) {
        if dst < src {
            let (head, tail) = columns.split_at_mut(src);
            (&mut head[dst], &tail[0])
        } else {
            let (head, tail) = columns.split_at_mut(dst);
            (&mut tail[0], &head[src])
        }
    }
    /// Returns the pivot of column `j`: the index of its highest set row, or
    /// `None` if the column is zero.
    ///
    /// # Panics
    ///
    /// Panics if `j >= self.cols()`.
    #[must_use]
    pub fn column_pivot(&self, j: usize) -> Option<usize> {
        self.columns[j].highest_bit()
    }
    /// Column-reduces a copy of the matrix over GF(2) and returns it together
    /// with the pivot assignment. `self` is left unchanged.
    ///
    /// Columns are processed left to right. While the pivot of column `j` (see
    /// [`BoundaryMatrix::column_pivot`]) is already claimed by an earlier
    /// column, that earlier column is XOR-ed into `j`. Column `j` then either
    /// becomes zero or claims its pivot row in `pivot_map`. Afterwards no two
    /// nonzero columns share a pivot row, and `pivot_map[r] == Some(j)` exactly
    /// when reduced column `j` has pivot `r`.
    #[must_use]
    pub fn reduce(&self) -> ReducedMatrix {
        let mut reduced = self.clone();
        let mut pivot_map: Vec<Option<usize>> = vec![None; self.rows];
        for j in 0..reduced.cols() {
            while let Some(pivot) = reduced.column_pivot(j) {
                let Some(existing_col) = pivot_map[pivot] else {
                    pivot_map[pivot] = Some(j);
                    break;
                };
                // `pivot_map` only records columns finished before `j`, so the two
                // indices differ and the split borrow in `xor_columns` is valid.
                debug_assert!(existing_col < j);
                reduced.xor_columns(j, existing_col);
            }
        }
        ReducedMatrix {
            matrix: reduced,
            pivot_map,
        }
    }
}
impl fmt::Debug for BoundaryMatrix {
    /// Renders a header line `BoundaryMatrix(<rows>x<cols>):` followed by one
    /// line per row. Each row line is a single space followed by that row's
    /// entries as `0`/`1` digits, left to right by column.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let width = self.cols();
        writeln!(f, "BoundaryMatrix({}x{}):", self.rows, width)?;
        for row in 0..self.rows {
            f.write_str(" ")?;
            for col in 0..width {
                f.write_str(if self.get(row, col) { "1" } else { "0" })?;
            }
            f.write_str("\n")?;
        }
        Ok(())
    }
}
/// Result of [`BoundaryMatrix::reduce`]: the column-reduced matrix and the
/// pivot assignment produced during the reduction.
#[derive(Clone)]
pub struct ReducedMatrix {
    /// The reduced matrix. As returned by `reduce`, no two nonzero columns
    /// share a pivot (highest set) row.
    pub matrix: BoundaryMatrix,
    /// Indexed by row: `pivot_map[r] == Some(j)` when reduced column `j` has
    /// pivot row `r`, as returned by `reduce`.
    pub pivot_map: Vec<Option<usize>>,
}
impl ReducedMatrix {
    /// Extracts persistence information from the reduced matrix.
    ///
    /// * `pairs`: one `(pivot_row, column)` entry for every nonzero reduced
    ///   column, in column order.
    /// * `unpaired`: column indices `j` that meet all of these conditions:
    ///   - the reduced column `j` is zero;
    ///   - `j` is not the column of any pair;
    ///   - `j` is not the pivot row of any pair. This is only checked when `j`
    ///     is a valid row index.
    #[must_use]
    pub fn persistence_pairs(&self) -> PersistencePairs {
        let cols = self.matrix.cols();
        let pairs: Vec<(usize, usize)> = (0..cols)
            .filter_map(|col| self.matrix.column_pivot(col).map(|row| (row, col)))
            .collect();
        let mut kills = vec![false; cols];
        let mut creates = vec![false; self.matrix.rows()];
        for &(row, col) in &pairs {
            creates[row] = true;
            kills[col] = true;
        }
        let unpaired: Vec<usize> = (0..cols)
            .filter(|&col| {
                // Out-of-range column indices simply cannot be a pivot row.
                let is_creator = creates.get(col).copied().unwrap_or(false);
                !kills[col] && !is_creator && self.matrix.column(col).is_zero()
            })
            .collect();
        PersistencePairs { pairs, unpaired }
    }
}
/// Output of [`ReducedMatrix::persistence_pairs`].
#[derive(Clone, Debug)]
pub struct PersistencePairs {
    /// `(birth, death)` pairs: the pivot row of a nonzero reduced column and
    /// that column's index.
    pub pairs: Vec<(usize, usize)>,
    /// Indices whose reduced column is zero and that take part in no pair.
    pub unpaired: Vec<usize>,
}
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn bitvec_zeros() {
        let v = BitVec::zeros(100);
        assert_eq!(v.len(), 100);
        assert!(v.is_zero());
        assert_eq!(v.pivot(), None);
        assert_eq!(v.count_ones(), 0);
    }
    #[test]
    fn bitvec_singleton() {
        let v = BitVec::singleton(128, 65);
        assert!(v.get(65));
        assert!(!v.get(0));
        assert!(!v.get(64));
        assert_eq!(v.pivot(), Some(65));
        assert_eq!(v.highest_bit(), Some(65));
        assert_eq!(v.count_ones(), 1);
    }
    #[test]
    fn bitvec_set_clear_flip() {
        let mut v = BitVec::zeros(10);
        v.set(3);
        v.set(7);
        assert!(v.get(3));
        assert!(v.get(7));
        assert_eq!(v.count_ones(), 2);
        v.clear(3);
        assert!(!v.get(3));
        assert_eq!(v.count_ones(), 1);
        v.flip(7);
        assert!(!v.get(7));
        assert!(v.is_zero());
        v.flip(0);
        assert!(v.get(0));
    }
    #[test]
    fn bitvec_xor() {
        let mut a = BitVec::zeros(8);
        a.set(0);
        a.set(2);
        a.set(4);
        let mut b = BitVec::zeros(8);
        b.set(2);
        b.set(3);
        b.set(4);
        let c = a.xor(&b);
        assert!(c.get(0));
        assert!(!c.get(2));
        assert!(c.get(3));
        assert!(!c.get(4));
        assert_eq!(c.count_ones(), 2);
    }
    #[test]
    fn bitvec_pivot_and_highest() {
        let mut v = BitVec::zeros(200);
        v.set(5);
        v.set(100);
        v.set(150);
        assert_eq!(v.pivot(), Some(5));
        assert_eq!(v.highest_bit(), Some(150));
    }
    #[test]
    fn bitvec_highest_bit_partial_last_word() {
        let mut v = BitVec::zeros(70);
        assert_eq!(v.highest_bit(), None);
        v.set(69);
        assert_eq!(v.highest_bit(), Some(69));
        v.set(3);
        v.clear(69);
        assert_eq!(v.highest_bit(), Some(3));
    }
    #[test]
    fn bitvec_ones_iterator() {
        let mut v = BitVec::zeros(130);
        v.set(0);
        v.set(63);
        v.set(64);
        v.set(127);
        v.set(129);
        let ones: Vec<usize> = v.ones().collect();
        assert_eq!(ones, vec![0, 63, 64, 127, 129]);
    }
    #[test]
    fn bitvec_large_xor() {
        let mut a = BitVec::zeros(1000);
        let mut b = BitVec::zeros(1000);
        for i in (0..1000).step_by(3) {
            a.set(i);
        }
        for i in (0..1000).step_by(5) {
            b.set(i);
        }
        let c = a.xor(&b);
        for i in 0..1000 {
            let in_a = i % 3 == 0;
            let in_b = i % 5 == 0;
            assert_eq!(c.get(i), in_a ^ in_b, "mismatch at bit {i}");
        }
    }
    #[test]
    #[should_panic(expected = "bit index 10 out of range (len=10)")]
    fn bitvec_get_out_of_range_panics() {
        let _ = BitVec::zeros(10).get(10);
    }
    #[test]
    #[should_panic(expected = "xor_assign: length mismatch (3 vs 4)")]
    fn bitvec_xor_length_mismatch_panics() {
        let mut a = BitVec::zeros(3);
        a.xor_assign(&BitVec::zeros(4));
    }
    #[test]
    fn bitvec_debug_format() {
        let mut v = BitVec::zeros(10);
        v.set(3);
        v.set(7);
        assert_eq!(format!("{v:?}"), "BitVec(10, [3, 7])");
        assert_eq!(format!("{:?}", BitVec::zeros(0)), "BitVec(0, [])");
    }
    #[test]
    fn matrix_basic() {
        let mut m = BoundaryMatrix::zeros(3, 3);
        m.set(0, 1);
        m.set(1, 1);
        m.set(1, 2);
        m.set(2, 2);
        assert!(m.get(0, 1));
        assert!(m.get(1, 1));
        assert!(m.get(1, 2));
        assert!(m.get(2, 2));
        assert!(!m.get(0, 0));
    }
    #[test]
    fn matrix_xor_columns() {
        let mut m = BoundaryMatrix::zeros(4, 3);
        m.set(0, 0);
        m.set(1, 0);
        m.set(1, 1);
        m.set(2, 1);
        m.set(0, 2);
        m.set(2, 2);
        m.xor_columns(2, 0);
        assert!(!m.get(0, 2));
        assert!(m.get(1, 2));
        assert!(m.get(2, 2));
    }
    #[test]
    fn matrix_xor_columns_dst_before_src() {
        let mut m = BoundaryMatrix::zeros(3, 2);
        m.set(0, 0);
        m.set(0, 1);
        m.set(1, 1);
        m.xor_columns(0, 1);
        assert!(!m.get(0, 0));
        assert!(m.get(1, 0));
        // The source column is left untouched.
        assert!(m.get(0, 1));
        assert!(m.get(1, 1));
    }
    #[test]
    #[should_panic(expected = "xor_columns: src and dst must differ")]
    fn matrix_xor_columns_same_index_panics() {
        let mut m = BoundaryMatrix::zeros(2, 2);
        m.xor_columns(1, 1);
    }
    #[test]
    #[should_panic(expected = "column 1 has length 2, expected 3")]
    fn from_columns_rejects_wrong_length() {
        let _ = BoundaryMatrix::from_columns(3, vec![BitVec::zeros(3), BitVec::zeros(2)]);
    }
    #[test]
    fn matrix_debug_format() {
        let mut m = BoundaryMatrix::zeros(2, 2);
        m.set(0, 1);
        assert_eq!(format!("{m:?}"), "BoundaryMatrix(2x2):\n 01\n 00\n");
    }
    #[test]
    fn matrix_reduce_triangle() {
        let mut d1 = BoundaryMatrix::zeros(3, 3);
        d1.set(0, 0);
        d1.set(1, 0);
        d1.set(0, 1);
        d1.set(2, 1);
        d1.set(1, 2);
        d1.set(2, 2);
        let reduced = d1.reduce();
        assert_eq!(reduced.matrix.column_pivot(0), Some(1));
        assert_eq!(reduced.matrix.column_pivot(1), Some(2));
        assert_eq!(reduced.matrix.column_pivot(2), None);
        assert_eq!(reduced.pivot_map, vec![None, Some(0), Some(1)]);
        let pairs = reduced.persistence_pairs();
        assert_eq!(pairs.pairs, vec![(1, 0), (2, 1)]);
        assert!(pairs.unpaired.is_empty());
    }
    #[test]
    fn reduce_determinism() {
        let mut d = BoundaryMatrix::zeros(5, 5);
        d.set(0, 1);
        d.set(1, 1);
        d.set(0, 2);
        d.set(2, 2);
        d.set(1, 3);
        d.set(2, 3);
        d.set(3, 4);
        d.set(4, 4);
        let r1 = d.reduce();
        let r2 = d.reduce();
        for j in 0..5 {
            assert_eq!(
                r1.matrix.column_pivot(j),
                r2.matrix.column_pivot(j),
                "pivot mismatch at column {j}"
            );
        }
    }
    #[test]
    fn reduce_expected_pivots() {
        let mut d = BoundaryMatrix::zeros(5, 5);
        d.set(0, 1);
        d.set(1, 1);
        d.set(0, 2);
        d.set(2, 2);
        d.set(1, 3);
        d.set(2, 3);
        d.set(3, 4);
        d.set(4, 4);
        let reduced = d.reduce();
        let pivots: Vec<Option<usize>> =
            (0..5).map(|j| reduced.matrix.column_pivot(j)).collect();
        assert_eq!(pivots, vec![None, Some(1), Some(2), None, Some(4)]);
        assert_eq!(reduced.pivot_map, vec![None, Some(1), Some(2), None, Some(4)]);
    }
    #[test]
    fn reduce_keeps_zero_column_zero() {
        let mut d = BoundaryMatrix::zeros(4, 3);
        d.set(0, 0);
        d.set(1, 0);
        d.set(2, 2);
        let reduced = d.reduce();
        assert_eq!(reduced.matrix.column_pivot(1), None);
    }
    #[test]
    fn bitvec_empty() {
        let v = BitVec::zeros(0);
        assert_eq!(v.len(), 0);
        assert!(v.is_zero());
        assert_eq!(v.pivot(), None);
        assert_eq!(v.count_ones(), 0);
        assert_eq!(v.ones().count(), 0);
    }
    #[test]
    fn bitvec_word_boundary() {
        let mut v = BitVec::zeros(128);
        v.set(63);
        v.set(64);
        assert_eq!(v.pivot(), Some(63));
        assert_eq!(v.highest_bit(), Some(64));
        assert_eq!(v.count_ones(), 2);
        v.clear(63);
        assert_eq!(v.pivot(), Some(64));
    }
    #[test]
    fn persistence_pairs_simple() {
        let mut d = BoundaryMatrix::zeros(2, 1);
        d.set(0, 0);
        d.set(1, 0);
        let reduced = d.reduce();
        let pairs = reduced.persistence_pairs();
        assert_eq!(pairs.pairs.len(), 1);
        assert_eq!(pairs.pairs[0], (1, 0));
        assert!(pairs.unpaired.is_empty());
    }
    #[test]
    fn persistence_filled_triangle() {
        // Simplices: 0, 1, 2 are vertices; 3 = {0,1}, 4 = {0,2}, 5 = {1,2} are
        // edges; 6 is the filled face with boundary {3, 4, 5}.
        let mut d = BoundaryMatrix::zeros(7, 7);
        d.set(0, 3);
        d.set(1, 3);
        d.set(0, 4);
        d.set(2, 4);
        d.set(1, 5);
        d.set(2, 5);
        d.set(3, 6);
        d.set(4, 6);
        d.set(5, 6);
        let pairs = d.reduce().persistence_pairs();
        assert_eq!(pairs.pairs, vec![(1, 3), (2, 4), (5, 6)]);
        // One essential connected component, born at vertex 0.
        assert_eq!(pairs.unpaired, vec![0]);
    }
    #[test]
    fn persistence_pairs_non_square_more_rows_than_cols() {
        let mut d = BoundaryMatrix::zeros(6, 2);
        d.set(0, 0);
        d.set(1, 0);
        d.set(2, 0);
        d.set(2, 1);
        d.set(3, 1);
        d.set(4, 1);
        let reduced = d.reduce();
        let pairs = reduced.persistence_pairs();
        assert_eq!(pairs.pairs, vec![(2, 0), (4, 1)]);
        assert!(pairs.unpaired.is_empty());
    }
    #[test]
    fn persistence_pairs_non_square_more_cols_than_rows() {
        let mut d = BoundaryMatrix::zeros(3, 5);
        d.set(0, 0);
        d.set(1, 0);
        d.set(1, 1);
        d.set(2, 1);
        d.set(0, 2);
        d.set(2, 2);
        d.set(0, 3);
        d.set(1, 3);
        d.set(0, 4);
        d.set(1, 4);
        let reduced = d.reduce();
        let pairs = reduced.persistence_pairs();
        assert_eq!(pairs.pairs, vec![(1, 0), (2, 1)]);
        // Column 2 reduces to zero but is the pivot row of column 1, so it is
        // paired; columns 3 and 4 have no row counterpart and stay unpaired.
        assert_eq!(pairs.unpaired, vec![3, 4]);
    }
    #[test]
    fn bitvec_clone_eq_hash() {
        use std::collections::HashSet;
        let mut a = BitVec::zeros(128);
        a.set(0);
        a.set(64);
        let b = a.clone();
        assert_eq!(a, b);
        let mut set = HashSet::new();
        set.insert(a.clone());
        assert!(set.contains(&b));
    }
    #[test]
    fn persistence_pairs_debug_clone() {
        let pp = PersistencePairs {
            pairs: vec![(0, 1), (2, 3)],
            unpaired: vec![4],
        };
        let cloned = pp.clone();
        assert_eq!(cloned.pairs.len(), 2);
        assert_eq!(cloned.unpaired, vec![4]);
        let dbg = format!("{pp:?}");
        assert!(dbg.contains("PersistencePairs"));
    }
}