Skip to main content

oxihuman_core/
bit_matrix.rs

1// Copyright (C) 2026 COOLJAPAN OU (Team KitaSan)
2// SPDX-License-Identifier: Apache-2.0
3
4//! Compact boolean matrix (bitset rows with u64 words).
5
6#![allow(dead_code)]
7
/// A compact boolean matrix stored as rows of u64 bit-words.
///
/// Storage is row-major: every row owns `words_per_row` consecutive `u64`
/// words of `data`, and column `c` of a row lives in word `c / 64` at bit
/// `c % 64`. When `cols` is not a multiple of 64 the unused high bits of the
/// last word of each row are padding; every method of this type leaves them
/// at zero, so popcounts and equality only reflect real cells.
#[allow(dead_code)]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BitMatrix {
    /// Number of rows. Treat as read-only: the backing storage is sized from
    /// this value once, in [`BitMatrix::new`].
    pub rows: usize,
    /// Number of columns. Treat as read-only for the same reason as `rows`.
    pub cols: usize,
    /// Number of `u64` words backing one row (`cols` divided by 64, rounded up).
    words_per_row: usize,
    /// Row-major bit storage, `rows * words_per_row` words long.
    data: Vec<u64>,
}

#[allow(dead_code)]
impl BitMatrix {
    /// Create a `rows` x `cols` matrix with every bit cleared.
    ///
    /// # Panics
    ///
    /// Panics if the number of backing words (`rows` times the words per row)
    /// does not fit in `usize`. Without this check the product would wrap
    /// silently in release builds and produce an undersized buffer.
    pub fn new(rows: usize, cols: usize) -> Self {
        let words_per_row = cols.div_ceil(64);
        let total_words = rows
            .checked_mul(words_per_row)
            .expect("BitMatrix dimensions overflow usize");
        Self {
            rows,
            cols,
            words_per_row,
            data: vec![0u64; total_words],
        }
    }

    /// Map `(row, col)` to `(index into data, bit position inside that word)`.
    fn index(&self, row: usize, col: usize) -> (usize, usize) {
        (row * self.words_per_row + col / 64, col % 64)
    }

    /// Set bit at (row, col) to true.
    ///
    /// # Panics
    ///
    /// Panics if `row >= rows` or `col >= cols`.
    pub fn set(&mut self, row: usize, col: usize) {
        assert!(row < self.rows && col < self.cols);
        let (w, b) = self.index(row, col);
        self.data[w] |= 1u64 << b;
    }

    /// Clear bit at (row, col).
    ///
    /// # Panics
    ///
    /// Panics if `row >= rows` or `col >= cols`.
    pub fn clear_bit(&mut self, row: usize, col: usize) {
        assert!(row < self.rows && col < self.cols);
        let (w, b) = self.index(row, col);
        self.data[w] &= !(1u64 << b);
    }

    /// Toggle bit at (row, col).
    ///
    /// # Panics
    ///
    /// Panics if `row >= rows` or `col >= cols`.
    pub fn toggle(&mut self, row: usize, col: usize) {
        assert!(row < self.rows && col < self.cols);
        let (w, b) = self.index(row, col);
        self.data[w] ^= 1u64 << b;
    }

    /// Get bit at (row, col).
    ///
    /// # Panics
    ///
    /// Panics if `row >= rows` or `col >= cols`.
    pub fn get(&self, row: usize, col: usize) -> bool {
        assert!(row < self.rows && col < self.cols);
        let (w, b) = self.index(row, col);
        (self.data[w] >> b) & 1 == 1
    }

    /// Count set bits in a row.
    ///
    /// # Panics
    ///
    /// Panics if `row >= rows`.
    pub fn row_popcount(&self, row: usize) -> u32 {
        assert!(row < self.rows);
        let start = row * self.words_per_row;
        let end = start + self.words_per_row;
        self.data[start..end].iter().map(|w| w.count_ones()).sum()
    }

    /// Count all set bits.
    pub fn total_popcount(&self) -> u32 {
        self.data.iter().map(|w| w.count_ones()).sum()
    }

    /// Zero all bits.
    pub fn clear_all(&mut self) {
        self.data.fill(0);
    }

    /// Fill all bits.
    ///
    /// Only the `cols` real cells of each row end up set; padding bits in the
    /// last word of every row are cleared again so popcounts stay exact.
    pub fn fill_all(&mut self) {
        self.data.fill(u64::MAX);
        let tail_bits = self.cols % 64;
        if tail_bits != 0 {
            // `tail_bits != 0` implies `cols > 0`, hence `words_per_row >= 1`,
            // so the chunk size below is never zero.
            let mask = (1u64 << tail_bits) - 1;
            let last = self.words_per_row - 1;
            for row_words in self.data.chunks_exact_mut(self.words_per_row) {
                row_words[last] &= mask;
            }
        }
    }

    /// Apply `op(dst_word, src_word)` to each word pair of rows `dst` and
    /// `src`, writing the result back into row `dst`.
    ///
    /// The source word is copied out before the destination is written, so
    /// `dst == src` is handled as well.
    fn combine_rows(&mut self, dst: usize, src: usize, op: impl Fn(u64, u64) -> u64) {
        assert!(dst < self.rows && src < self.rows);
        let dst_start = dst * self.words_per_row;
        let src_start = src * self.words_per_row;
        for offset in 0..self.words_per_row {
            let src_word = self.data[src_start + offset];
            let dst_word = &mut self.data[dst_start + offset];
            *dst_word = op(*dst_word, src_word);
        }
    }

    /// Bitwise AND row r1 with r2, result in r1.
    ///
    /// # Panics
    ///
    /// Panics if `r1 >= rows` or `r2 >= rows`.
    pub fn row_and(&mut self, r1: usize, r2: usize) {
        self.combine_rows(r1, r2, |a, b| a & b);
    }

    /// Bitwise OR row r1 with r2, result in r1.
    ///
    /// # Panics
    ///
    /// Panics if `r1 >= rows` or `r2 >= rows`.
    pub fn row_or(&mut self, r1: usize, r2: usize) {
        self.combine_rows(r1, r2, |a, b| a | b);
    }
}
112
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built matrix has no bits set.
    #[test]
    fn new_all_clear() {
        let matrix = BitMatrix::new(4, 64);
        let set_bits = matrix.total_popcount();
        assert_eq!(set_bits, 0);
    }

    /// Setting one cell makes it readable and leaves its neighbour untouched.
    #[test]
    fn set_and_get() {
        let mut matrix = BitMatrix::new(4, 64);
        matrix.set(2, 33);
        let target = matrix.get(2, 33);
        let neighbour = matrix.get(2, 32);
        assert!(target);
        assert!(!neighbour);
    }

    /// Clearing a previously set cell reads back as false.
    #[test]
    fn clear_bit() {
        let (row, col) = (1, 10);
        let mut matrix = BitMatrix::new(4, 64);
        matrix.set(row, col);
        matrix.clear_bit(row, col);
        assert!(!matrix.get(row, col));
    }

    /// Toggling twice returns the cell to its initial state.
    #[test]
    fn toggle() {
        let mut matrix = BitMatrix::new(4, 64);
        matrix.toggle(0, 0);
        assert!(matrix.get(0, 0));
        matrix.toggle(0, 0);
        assert!(!matrix.get(0, 0));
    }

    /// The per-row count reflects exactly the cells set in that row.
    #[test]
    fn row_popcount() {
        let mut matrix = BitMatrix::new(3, 64);
        for col in [0, 10, 63] {
            matrix.set(1, col);
        }
        assert_eq!(matrix.row_popcount(1), 3);
    }

    /// The global count sums cells set across different rows.
    #[test]
    fn total_popcount() {
        let mut matrix = BitMatrix::new(3, 64);
        for diagonal in 0..3 {
            matrix.set(diagonal, diagonal);
        }
        assert_eq!(matrix.total_popcount(), 3);
    }

    /// Filling a 2 x 10 matrix sets 20 cells and no padding bits.
    #[test]
    fn fill_all_popcount() {
        let mut matrix = BitMatrix::new(2, 10);
        matrix.fill_all();
        assert_eq!(matrix.total_popcount(), 20);
    }

    /// AND keeps only the cells present in both rows.
    #[test]
    fn row_and() {
        let mut matrix = BitMatrix::new(2, 64);
        for (row, col) in [(0, 5), (0, 10), (1, 5)] {
            matrix.set(row, col);
        }
        matrix.row_and(0, 1);
        assert!(matrix.get(0, 5));
        assert!(!matrix.get(0, 10));
    }

    /// OR merges the cells of the second row into the first.
    #[test]
    fn row_or() {
        let mut matrix = BitMatrix::new(2, 64);
        for (row, col) in [(0, 5), (1, 10)] {
            matrix.set(row, col);
        }
        matrix.row_or(0, 1);
        assert!(matrix.get(0, 5));
        assert!(matrix.get(0, 10));
    }

    /// Cells beyond the first word of a row are addressed independently per row.
    #[test]
    fn wide_matrix() {
        let mut matrix = BitMatrix::new(2, 200);
        matrix.set(0, 199);
        matrix.set(1, 65);
        assert!(matrix.get(0, 199));
        assert!(matrix.get(1, 65));
        assert!(!matrix.get(0, 65));
    }
}