roaring_graphs/
strictly_upper_triangular_logical_matrix.rs

1use roaring::{MultiOps, RoaringBitmap};
2
/// Number of entries strictly above the main diagonal of an `n × n` matrix,
/// i.e. `n * (n - 1) / 2` (zero when `n < 2`).
///
/// The arithmetic is done in `u32`, which is wide enough for every `u16` input.
#[inline]
pub fn strictly_upper_triangular_matrix_capacity(n: u16) -> u32 {
    let n = u32::from(n);
    // `n * (n - 1)` is a product of consecutive integers, so it is always even.
    n.saturating_sub(1) * n / 2
}
8
/// Iterator over every `(row, column)` coordinate strictly above the main
/// diagonal of a `size × size` matrix, in row-major order:
/// `(0, 1), (0, 2), …, (0, size - 1), (1, 2), …, (size - 2, size - 1)`.
///
/// Yields `size * (size - 1) / 2` items in total, and none when `size < 2`.
#[derive(Clone, Debug)]
pub struct RowColumnIterator {
    size: u16,
    // Coordinate returned by the next call to `next`; `j >= size` marks exhaustion.
    i: u16,
    j: u16,
}

impl RowColumnIterator {
    /// Creates an iterator over the strictly upper triangular coordinates of a
    /// `size × size` matrix.
    pub fn new(size: u16) -> Self {
        Self { size, i: 0, j: 1 }
    }

    /// Number of coordinates not yet yielded.
    fn remaining(&self) -> usize {
        let size = usize::from(self.size);
        let (i, j) = (usize::from(self.i), usize::from(self.j));
        if j >= size {
            return 0;
        }
        // The rest of row `i`, plus rows `i + 1 ..= size - 1`, where row `r`
        // holds `size - 1 - r` coordinates. With `m = size - 1 - i`, the latter
        // sum is 0 + 1 + … + (m - 1) = m(m - 1)/2. Because `i < j < size`,
        // `m >= 1`, so `m - 1` cannot underflow.
        let later_rows = size - 1 - i;
        (size - j) + later_rows * (later_rows - 1) / 2
    }
}

impl Iterator for RowColumnIterator {
    type Item = (u16, u16);

    fn next(&mut self) -> Option<Self::Item> {
        // This also covers size 0 and 1, where the initial column 1 is already
        // out of range.
        if self.j >= self.size {
            return None;
        }
        let item = (self.i, self.j);
        if self.j + 1 < self.size {
            self.j += 1;
        } else {
            // End of the row: move to the first coordinate of the next row.
            // Leaving the last non-empty row sets `j = size`, which is the
            // exhausted state.
            self.i += 1;
            self.j = self.i + 1;
        }
        Some(item)
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.remaining();
        (remaining, Some(remaining))
    }
}

impl ExactSizeIterator for RowColumnIterator {}

// Once `j >= size` the state never changes again, so `next` keeps returning `None`.
impl std::iter::FusedIterator for RowColumnIterator {}
45
46/// A zero-indexed [row-major
47/// packed](https://www.intel.com/content/www/us/en/develop/documentation/onemkl-developer-reference-c/top/lapack-routines/matrix-storage-schemes-for-lapack-routines.html)
48/// matrix of booleans.
49#[derive(Clone, Debug)]
50pub struct StrictlyUpperTriangularLogicalMatrix {
51    size: u16,
52    matrix: RoaringBitmap,
53}
54
55impl Eq for StrictlyUpperTriangularLogicalMatrix {}
56
57impl PartialEq for StrictlyUpperTriangularLogicalMatrix {
58    fn eq(&self, other: &Self) -> bool {
59        if self.size != other.size {
60            return false;
61        }
62        self.iter_ones().eq(other.iter_ones())
63    }
64}
65
/// Maps `(row, column)` to its bit index in a `size × size` matrix stored in
/// plain zero-indexed row-major order: `row * size + column`.
///
/// This is *not* a packed layout: positions on and below the diagonal keep
/// their slots. The arithmetic is done in `u32`, where even
/// `u16::MAX * u16::MAX + u16::MAX` fits, so it cannot overflow.
///
/// No bounds check is performed. Callers are expected to pass
/// `row < size` and `column < size`; otherwise distinct coordinates alias.
#[inline]
pub(crate) fn index_from_row_column(row: u16, column: u16, size: u16) -> u32 {
    u32::from(row) * u32::from(size) + u32::from(column)
}
72
/// Row holding bit `index` in a row-major matrix with `size` columns, i.e.
/// `index / size`.
///
/// # Panics
///
/// Panics if `size` is zero, or if the quotient does not fit in a `u16`.
#[inline]
pub(crate) fn row_from_index(index: u32, size: u16) -> u16 {
    let width = u32::from(size);
    let row = index / width;
    u16::try_from(row).unwrap()
}
77
/// Column holding bit `index` in a row-major matrix with `size` columns, i.e.
/// `index % size`.
///
/// # Panics
///
/// Panics if `size` is zero. The remainder is always below `size`, so the
/// `u16` conversion itself cannot fail.
#[inline]
pub(crate) fn column_from_index(index: u32, size: u16) -> u16 {
    let width = u32::from(size);
    let column = index % width;
    u16::try_from(column).unwrap()
}
82
/// Splits a row-major bit `index` of a matrix with `size` columns back into
/// its `(row, column)` coordinates, i.e. `(index / size, index % size)`.
///
/// # Panics
///
/// Panics if `size` is zero, or if the row does not fit in a `u16`.
#[inline]
pub(crate) fn row_column_from_index(index: u32, size: u16) -> (u16, u16) {
    let width = u32::from(size);
    let (quotient, remainder) = (index / width, index % width);
    (
        u16::try_from(quotient).unwrap(),
        u16::try_from(remainder).unwrap(),
    )
}
89
90impl StrictlyUpperTriangularLogicalMatrix {
91    pub fn zeroed(size: u16) -> Self {
92        Self {
93            size,
94            matrix: RoaringBitmap::new(),
95        }
96    }
97
98    pub fn from_bitset(size: u16, bitset: RoaringBitmap) -> Self {
99        Self {
100            size,
101            matrix: bitset,
102        }
103    }
104
105    pub fn from_iter<I: Iterator<Item = (u16, u16)>>(size: u16, iter: I) -> Self {
106        let mut bitmap = RoaringBitmap::new();
107        for (i, j) in iter {
108            let index = index_from_row_column(i, j, size);
109            bitmap.insert(index);
110        }
111        Self::from_bitset(size, bitmap)
112    }
113
114    #[inline]
115    pub fn size(&self) -> u16 {
116        self.size
117    }
118
119    pub fn get(&self, i: u16, j: u16) -> bool {
120        let index = index_from_row_column(i, j, self.size);
121        self.matrix.contains(index)
122    }
123
124    /// Returns the previous value.
125    pub fn set_to(&mut self, i: u16, j: u16, value: bool) -> bool {
126        let index = index_from_row_column(i, j, self.size);
127        let current = self.matrix.contains(index);
128        if value {
129            self.matrix.insert(index);
130        } else {
131            self.matrix.remove(index);
132        }
133        current
134    }
135
136    /// Returns the previous value.
137    pub fn set(&mut self, i: u16, j: u16) {
138        let index = index_from_row_column(i, j, self.size);
139        self.matrix.insert(index);
140    }
141
142    pub fn clear(&mut self, i: u16, j: u16) {
143        let index = index_from_row_column(i, j, self.size);
144        self.matrix.remove(index);
145    }
146
147    pub fn iter_ones(&self) -> impl Iterator<Item = (u16, u16)> + '_ {
148        self.matrix.iter().map(|index| row_column_from_index(index, self.size))
149    }
150
151    pub fn iter_ones_at_row(&self, i: u16) -> impl Iterator<Item = u16> + '_ {
152        assert!(i < self.size());
153        ((i*self.size+i+1)..((i+1)*self.size)).into_iter().filter(|index| self.matrix.contains((*index).into())).map(|index| column_from_index(index.into(), self.size))
154    }
155
156    pub fn iter_ones_at_column(&self, j: u16) -> impl Iterator<Item = u16> + '_ {
157        assert!(j < self.size());
158        let mask = RoaringBitmap::from_iter((0..j).map(|k| u32::from(k * self.size + j)));
159        let ones_indexes = [&self.matrix, &mask].intersection();
160        ones_indexes.into_iter().map(|index| row_from_index(index, self.size))
161    }
162
163    pub fn into_bitset(self) -> RoaringBitmap {
164        self.matrix
165    }
166}
167
#[cfg(test)]
mod tests {
    use super::*;

    /// Collects the ones of row `i` so the result can be compared with literals.
    fn row_ones(matrix: &StrictlyUpperTriangularLogicalMatrix, i: u16) -> Vec<u16> {
        matrix.iter_ones_at_row(i).collect()
    }

    /// Collects the ones of column `j` so the result can be compared with literals.
    fn column_ones(matrix: &StrictlyUpperTriangularLogicalMatrix, j: u16) -> Vec<u16> {
        matrix.iter_ones_at_column(j).collect()
    }

    #[test]
    fn positive_test_3x3_matrix() {
        let mut matrix = StrictlyUpperTriangularLogicalMatrix::zeroed(3);
        assert!(!matrix.get(0, 1));
        assert_eq!(matrix.iter_ones().count(), 0);

        // `set_to` reports the previous value: clear first, then set.
        assert!(!matrix.set_to(0, 1, true));
        assert!(matrix.set_to(0, 1, true));
        assert_eq!(matrix.iter_ones().collect::<Vec<_>>(), vec![(0, 1)]);

        assert!(matrix.set_to(0, 1, false));
        assert!(!matrix.get(0, 1));
    }

    #[test]
    fn ones_at_row() {
        let mut matrix = StrictlyUpperTriangularLogicalMatrix::zeroed(3);
        matrix.set(0, 1);
        assert_eq!(row_ones(&matrix, 0), vec![1]);
        assert!(row_ones(&matrix, 1).is_empty());
        assert!(row_ones(&matrix, 2).is_empty());
    }

    #[test]
    fn ones_at_column_bug1() {
        let mut matrix = StrictlyUpperTriangularLogicalMatrix::zeroed(3);
        matrix.set(0, 1);
        assert!(column_ones(&matrix, 0).is_empty());
        assert_eq!(column_ones(&matrix, 1), vec![0]);
        assert!(column_ones(&matrix, 2).is_empty());
    }

    #[test]
    fn ones_at_column_bug2() {
        let matrix = StrictlyUpperTriangularLogicalMatrix::from_iter(5, [(1, 2)].into_iter());
        for j in 0..5 {
            let expected: Vec<u16> = if j == 2 { vec![1] } else { Vec::new() };
            assert_eq!(column_ones(&matrix, j), expected, "column {j}");
        }
    }

    #[test]
    fn equality_compares_size_and_ones() {
        let built =
            StrictlyUpperTriangularLogicalMatrix::from_iter(4, [(0, 3), (1, 2)].into_iter());
        let mut assembled = StrictlyUpperTriangularLogicalMatrix::zeroed(4);
        assembled.set(1, 2);
        assembled.set(0, 3);
        assert_eq!(built, assembled);
        assert_ne!(built, StrictlyUpperTriangularLogicalMatrix::zeroed(4));
        assert_ne!(
            StrictlyUpperTriangularLogicalMatrix::zeroed(3),
            StrictlyUpperTriangularLogicalMatrix::zeroed(4)
        );
    }

    #[test]
    fn row_column_iterator_visits_upper_triangle_in_row_major_order() {
        let pairs: Vec<(u16, u16)> = RowColumnIterator::new(4).collect();
        assert_eq!(pairs, vec![(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]);
        assert_eq!(RowColumnIterator::new(0).count(), 0);
        assert_eq!(RowColumnIterator::new(1).count(), 0);
    }
}