ldpc_toolbox/
sparse.rs

1//! # Sparse binary matrix representation and functions
2//!
3//! This module implements a representation for sparse binary matrices based on
4//! the alist format used to handle LDPC parity check matrices.
5
6use std::borrow::Borrow;
7use std::slice::Iter;
8
9mod bfs;
10mod girth;
11
12pub use bfs::BFSResults;
13
/// A [`String`] with a description of the error.
pub type Error = String;
/// A [`std::result::Result`] whose error type is [`Error`] (a [`String`]).
pub type Result<T> = std::result::Result<T, Error>;
18
19/// A sparse binary matrix
20///
21/// The internal representation for this matrix is based on the alist format.
22#[derive(Eq, Debug, Clone)]
23pub struct SparseMatrix {
24    rows: Vec<Vec<usize>>,
25    cols: Vec<Vec<usize>>,
26}
27
28impl PartialEq for SparseMatrix {
29    fn eq(&self, other: &SparseMatrix) -> bool {
30        if self.num_rows() != other.num_rows() {
31            return false;
32        }
33        if self.num_cols() != other.num_cols() {
34            return false;
35        }
36        for (r1, r2) in self.rows.iter().zip(other.rows.iter()) {
37            let mut r1 = r1.clone();
38            let mut r2 = r2.clone();
39            r1.sort();
40            r2.sort();
41            if r1 != r2 {
42                return false;
43            }
44        }
45        true
46    }
47}
48
49impl SparseMatrix {
50    /// Create a new sparse matrix of a given size
51    ///
52    /// The matrix is inizialized to the zero matrix.
53    ///
54    /// # Examples
55    /// ```
56    /// # use ldpc_toolbox::sparse::SparseMatrix;
57    /// let h = SparseMatrix::new(10, 30);
58    /// assert_eq!(h.num_rows(), 10);
59    /// assert_eq!(h.num_cols(), 30);
60    /// ```
61    pub fn new(nrows: usize, ncols: usize) -> SparseMatrix {
62        use std::iter::repeat_with;
63        let rows = repeat_with(Vec::new).take(nrows).collect();
64        let cols = repeat_with(Vec::new).take(ncols).collect();
65        SparseMatrix { rows, cols }
66    }
67
68    /// Returns the number of rows of the matrix
69    pub fn num_rows(&self) -> usize {
70        self.rows.len()
71    }
72
73    /// Returns the number of columns of the matrix
74    pub fn num_cols(&self) -> usize {
75        self.cols.len()
76    }
77
78    /// Returns the row weight of `row`
79    ///
80    /// The row weight is defined as the number of entries equal to
81    /// one in a particular row. Rows are indexed starting by zero.
82    pub fn row_weight(&self, row: usize) -> usize {
83        self.rows[row].len()
84    }
85
86    /// Returns the column weight of `column`
87    ///
88    /// The column weight is defined as the number of entries equal to
89    /// one in a particular column. Columns are indexed starting by zero.
90    pub fn col_weight(&self, col: usize) -> usize {
91        self.cols[col].len()
92    }
93
94    /// Returns `true` if the entry corresponding to a particular
95    /// row and column is a one
96    pub fn contains(&self, row: usize, col: usize) -> bool {
97        // typically columns are shorter, so we search in the column
98        self.cols[col].contains(&row)
99    }
100
101    /// Inserts a one in a particular row and column.
102    ///
103    /// If there is already a one in this row and column, this function does
104    /// nothing.
105    ///
106    /// # Examples
107    /// ```
108    /// # use ldpc_toolbox::sparse::SparseMatrix;
109    /// let mut h = SparseMatrix::new(10, 30);
110    /// assert!(!h.contains(3, 7));
111    /// h.insert(3, 7);
112    /// assert!(h.contains(3, 7));
113    /// ```
114    pub fn insert(&mut self, row: usize, col: usize) {
115        if !self.contains(row, col) {
116            self.rows[row].push(col);
117            self.cols[col].push(row);
118        }
119    }
120
121    /// Removes a one in a particular row and column.
122    ///
123    /// If there is no one in this row and column, this function does nothing.
124    ///
125    /// # Examples
126    /// ```
127    /// # use ldpc_toolbox::sparse::SparseMatrix;
128    /// let mut h = SparseMatrix::new(10, 30);
129    /// h.insert(3, 7);
130    /// assert!(h.contains(3, 7));
131    /// h.remove(3, 7);
132    /// assert!(!h.contains(3, 7));
133    /// ```
134    pub fn remove(&mut self, row: usize, col: usize) {
135        self.rows[row].retain(|&c| c != col);
136        self.cols[col].retain(|&r| r != row);
137    }
138
139    /// Toggles the 0/1 in a particular row and column.
140    ///
141    /// If the row and column contains a zero, this function sets a one, and
142    /// vice versa. This is useful to implement addition modulo 2.
143    pub fn toggle(&mut self, row: usize, col: usize) {
144        match self.contains(row, col) {
145            true => self.remove(row, col),
146            false => self.insert(row, col),
147        }
148    }
149
150    /// Inserts ones in particular columns of a row
151    ///
152    /// This effect is as calling `insert()` on each of the elements
153    /// of the iterator `cols`.
154    ///
155    /// # Examples
156    /// ```
157    /// # use ldpc_toolbox::sparse::SparseMatrix;
158    /// let mut h1 = SparseMatrix::new(10, 30);
159    /// let mut h2 = SparseMatrix::new(10, 30);
160    /// let c = vec![3, 7, 9];
161    /// h1.insert_row(0, c.iter());
162    /// for a in &c {
163    ///     h2.insert(0, *a);
164    /// }
165    /// assert_eq!(h1, h2);
166    /// ```
167    pub fn insert_row<T, S>(&mut self, row: usize, cols: T)
168    where
169        T: Iterator<Item = S>,
170        S: Borrow<usize>,
171    {
172        for col in cols {
173            self.insert(row, *col.borrow());
174        }
175    }
176
177    /// Inserts ones in a particular rows of a column
178    ///
179    /// This works like `insert_row()`.
180    pub fn insert_col<T, S>(&mut self, col: usize, rows: T)
181    where
182        T: Iterator<Item = S>,
183        S: Borrow<usize>,
184    {
185        for row in rows {
186            self.insert(*row.borrow(), col);
187        }
188    }
189
190    /// Remove all the ones in a particular row
191    pub fn clear_row(&mut self, row: usize) {
192        for &col in &self.rows[row] {
193            self.cols[col].retain(|r| *r != row);
194        }
195        self.rows[row].clear();
196    }
197
198    /// Remove all the ones in a particular column
199    pub fn clear_col(&mut self, col: usize) {
200        for &row in &self.cols[col] {
201            self.rows[row].retain(|c| *c != col);
202        }
203        self.cols[col].clear();
204    }
205
206    /// Set the elements that are equal to one in a row
207    ///
208    /// The effect of this is like calling `clear_row()` followed
209    /// by `insert_row()`.
210    pub fn set_row<T, S>(&mut self, row: usize, cols: T)
211    where
212        T: Iterator<Item = S>,
213        S: Borrow<usize>,
214    {
215        self.clear_row(row);
216        self.insert_row(row, cols);
217    }
218
219    /// Set the elements that are equal to one in a column
220    pub fn set_col<T, S>(&mut self, col: usize, rows: T)
221    where
222        T: Iterator<Item = S>,
223        S: Borrow<usize>,
224    {
225        self.clear_col(col);
226        self.insert_col(col, rows);
227    }
228
229    /// Returns an [Iterator] over the indices entries equal to one in all the
230    /// matrix.
231    pub fn iter_all(&self) -> impl Iterator<Item = (usize, usize)> + '_ {
232        self.rows
233            .iter()
234            .enumerate()
235            .flat_map(|(j, r)| r.iter().map(move |&k| (j, k)))
236    }
237
238    /// Returns an [Iterator] over the entries equal to one
239    /// in a particular row
240    pub fn iter_row(&self, row: usize) -> Iter<'_, usize> {
241        self.rows[row].iter()
242    }
243
244    /// Returns an [Iterator] over the entries equal to one
245    /// in a particular column
246    pub fn iter_col(&self, col: usize) -> Iter<'_, usize> {
247        self.cols[col].iter()
248    }
249
250    fn write_alist_maybe_padding<W: std::fmt::Write>(
251        &self,
252        w: &mut W,
253        use_padding: bool,
254    ) -> std::fmt::Result {
255        writeln!(w, "{} {}", self.num_cols(), self.num_rows())?;
256        let directions = [&self.cols, &self.rows];
257        let mut direction_lengths = [0, 0];
258        for (dir, len) in directions.iter().zip(direction_lengths.iter_mut()) {
259            *len = dir.iter().map(|el| el.len()).max().unwrap_or(0);
260        }
261        writeln!(w, "{} {}", direction_lengths[0], direction_lengths[1])?;
262        for dir in directions.iter() {
263            let mut lengths = dir.iter().map(|el| el.len());
264            if let Some(len) = lengths.next() {
265                write!(w, "{}", len)?;
266            }
267            for len in lengths {
268                write!(w, " {}", len)?;
269            }
270            writeln!(w)?;
271        }
272        for (dir, &dirlen) in directions.iter().zip(direction_lengths.iter()) {
273            for el in *dir {
274                let mut v = el.clone();
275                v.sort_unstable();
276                let vlen = v.len();
277                let mut v = v.iter().map(|x| x + 1);
278                if let Some(x) = v.next() {
279                    write!(w, "{}", x)?;
280                }
281                for x in v {
282                    write!(w, " {}", x)?;
283                }
284                if use_padding {
285                    if vlen == 0 {
286                        write!(w, "0")?;
287                    }
288                    // .max(1) because we've added one padding element if vlen
289                    // was zero
290                    let num_padding = dirlen - vlen.max(1);
291                    for _ in 0..num_padding {
292                        write!(w, " 0")?;
293                    }
294                }
295                writeln!(w)?;
296            }
297        }
298        Ok(())
299    }
300
301    /// Writes the matrix in alist format to a writer.
302    ///
303    /// This function includes zeros as padding for irregular codes, as
304    /// originally defined by MacKay.
305    ///
306    /// # Errors
307    /// If a call to `write!()` returns an error, this function returns
308    /// such an error.
309    pub fn write_alist<W: std::fmt::Write>(&self, w: &mut W) -> std::fmt::Result {
310        self.write_alist_maybe_padding(w, true)
311    }
312
313    /// Writes the matrix in alist format to a writer.
314    ///
315    /// This function does not include zeros as padding for irregular codes.
316    ///
317    /// # Errors
318    /// If a call to `write!()` returns an error, this function returns
319    /// such an error.
320    pub fn write_alist_no_padding<W: std::fmt::Write>(&self, w: &mut W) -> std::fmt::Result {
321        self.write_alist_maybe_padding(w, false)
322    }
323
324    /// Returns a [`String`] with the alist representation of the matrix.
325    ///
326    /// This function includes zeros as padding for irregular codes, as
327    /// originally defined by MacKay.
328    pub fn alist(&self) -> String {
329        let mut s = String::new();
330        self.write_alist(&mut s).unwrap();
331        s
332    }
333
334    /// Returns a [`String`] with the alist representation of the matrix.
335    ///
336    /// This function does not include zeros as padding for irregular codes.
337    pub fn alist_no_padding(&self) -> String {
338        let mut s = String::new();
339        self.write_alist_no_padding(&mut s).unwrap();
340        s
341    }
342
343    /// Constructs and returns a sparse matrix from its alist representation.
344    ///
345    /// This function is able to read alists that use zeros for padding in the
346    /// case of an irregular code (as was defined originally by MacKay), as well
347    /// as alists that omit these zeros.
348    ///
349    /// # Errors
350    /// `alist` should hold a valid alist representation. If an error is found
351    /// while parsing `alist`, a `String` describing the error will be returned.
352    pub fn from_alist(alist: &str) -> Result<SparseMatrix> {
353        let mut alist = alist.split('\n');
354        let sizes = alist
355            .next()
356            .ok_or_else(|| String::from("alist first line not found"))?;
357        let mut sizes = sizes.split_whitespace();
358        let ncols = sizes
359            .next()
360            .ok_or_else(|| String::from("alist first line does not contain enough elements"))?
361            .parse()
362            .map_err(|_| String::from("ncols is not a number"))?;
363        let nrows = sizes
364            .next()
365            .ok_or_else(|| String::from("alist first line does not contain enough elements"))?
366            .parse()
367            .map_err(|_| String::from("nrows is not a number"))?;
368        let mut h = SparseMatrix::new(nrows, ncols);
369        alist.next(); // skip max weights
370        alist.next();
371        alist.next(); // skip weights
372        for col in 0..ncols {
373            let col_data = alist
374                .next()
375                .ok_or_else(|| String::from("alist does not contain expected number of lines"))?;
376            let col_data = col_data.split_whitespace();
377            for row in col_data {
378                let row: usize = row
379                    .parse()
380                    .map_err(|_| String::from("row value is not a number"))?;
381                // row == 0 is used for padding in irregular codes
382                if row != 0 {
383                    h.insert(row - 1, col);
384                }
385            }
386        }
387        // we do not need to process the rows of the alist
388        Ok(h)
389    }
390
391    /// Returns the girth of the bipartite graph defined by the matrix
392    ///
393    /// The girth is the length of the shortest cycle. If there are no
394    /// cycles, `None` is returned.
395    ///
396    /// # Examples
397    /// The following shows that a 2 x 2 matrix whose entries are all
398    /// equal to one has a girth of 4, which is the smallest girth that
399    /// a bipartite graph can have.
400    /// ```
401    /// # use ldpc_toolbox::sparse::SparseMatrix;
402    /// let mut h = SparseMatrix::new(2, 2);
403    /// for j in 0..2 {
404    ///     for k in 0..2 {
405    ///         h.insert(j, k);
406    ///     }
407    /// }
408    /// assert_eq!(h.girth(), Some(4));
409    /// ```
410    pub fn girth(&self) -> Option<usize> {
411        self.girth_with_max(usize::MAX)
412    }
413
414    /// Returns the girth of the bipartite graph defined by the matrix
415    /// as long as it is smaller than a maximum
416    ///
417    /// By imposing a maximum value in the girth search algorithm,
418    /// the execution time is reduced, since paths longer than the
419    /// maximum do not need to be explored.
420    ///
421    /// Often it is only necessary to check that a graph has at least
422    /// some minimum girth, so it is possible to use `girth_with_max()`.
423    ///
424    /// If there are no cycles with length smaller or equal to `max`, then
425    /// `None` is returned.
426    pub fn girth_with_max(&self, max: usize) -> Option<usize> {
427        (0..self.num_cols())
428            .filter_map(|c| self.girth_at_node_with_max(Node::Col(c), max))
429            .min()
430    }
431
432    /// Returns the local girth at a particular node
433    ///
434    /// The local girth at a node of a graph is defined as the minimum
435    /// length of the cycles containing that node.
436    ///
437    /// This function returns the local girth of at the node correponding
438    /// to a column or row of the matrix, or `None` if there are no cycles containing
439    /// that node.
440    pub fn girth_at_node(&self, node: Node) -> Option<usize> {
441        self.girth_at_node_with_max(node, usize::MAX)
442    }
443
444    /// Returns the girth at a particular node with a maximum
445    ///
446    /// This function works like `girth_at_node()` but imposes a maximum in the
447    /// length of the cycles considered. `None` is returned if there are no
448    /// cycles containing the node with length smaller or equal than `max`.
449    pub fn girth_at_node_with_max(&self, node: Node, max: usize) -> Option<usize> {
450        bfs::BFSContext::new(self, node).local_girth(max)
451    }
452
453    /// Run the BFS algorithm
454    ///
455    /// This uses a node of the graph associated to the matrix as the root
456    /// for the BFS algorithm and finds the distances from each of the nodes
457    /// of the graph to that root using breadth-first search.
458    /// # Examples
459    /// Run BFS on a matrix that has two connected components.
460    /// ```
461    /// # use ldpc_toolbox::sparse::SparseMatrix;
462    /// # use ldpc_toolbox::sparse::Node;
463    /// let mut h = SparseMatrix::new(4, 4);
464    /// for j in 0..4 {
465    ///     for k in 0..4 {
466    ///         if (j % 2) == (k % 2) {
467    ///             h.insert(j, k);
468    ///         }
469    ///     }
470    /// }
471    /// println!("{:?}", h.bfs(Node::Col(0)));
472    /// ```
473    pub fn bfs(&self, node: Node) -> BFSResults {
474        bfs::BFSContext::new(self, node).bfs()
475    }
476}
477
478/// A node in the graph associated to a sparse matrix
479///
480/// A node can represent a row or a column of the graph.
481#[derive(Debug, Copy, Clone, Eq, PartialEq)]
482pub enum Node {
483    /// Node representing row number `n`
484    Row(usize),
485    /// Node representing column number `n`
486    Col(usize),
487}
488
489impl Node {
490    fn iter(self, h: &SparseMatrix) -> impl Iterator<Item = Node> + '_ {
491        match self {
492            Node::Row(n) => h.iter_row(n),
493            Node::Col(n) => h.iter_col(n),
494        }
495        .map(move |&x| match self {
496            Node::Row(_) => Node::Col(x),
497            Node::Col(_) => Node::Row(x),
498        })
499    }
500}
501
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Builds alist text from its lines, terminating every line with a
    /// newline in the same way as `SparseMatrix::alist()`.
    fn alist_text(lines: &[&str]) -> String {
        let mut text = lines.join("\n");
        text.push('\n');
        text
    }

    #[test]
    fn test_insert() {
        let mut h = SparseMatrix::new(100, 300);
        let (row, col) = (27, 154);
        assert!(!h.contains(row, col));
        h.insert(row, col);
        assert!(h.contains(row, col));
        assert!(!h.contains(row + 1, col));
    }

    #[test]
    fn test_insert_twice() {
        let mut h = SparseMatrix::new(100, 300);
        for &(row, col) in [(27, 154), (43, 28), (53, 135)].iter() {
            h.insert(row, col);
        }
        let before = h.clone();
        // inserting an existing one must leave the matrix unchanged
        h.insert(43, 28);
        assert_eq!(h, before);
    }

    #[test]
    fn iter_all() {
        let entries = [
            (7, 8),
            (5, 14),
            (6, 6),
            (6, 7),
            (8, 10),
            (0, 4),
            (0, 0),
            (0, 15),
        ];
        let mut h = SparseMatrix::new(10, 20);
        for &(row, col) in entries.iter() {
            h.insert(row, col);
        }
        let expected: HashSet<(usize, usize)> = entries.iter().copied().collect();
        let result: HashSet<(usize, usize)> = h.iter_all().collect();
        assert_eq!(result, expected);
    }

    #[test]
    fn test_alist() {
        let mut h = SparseMatrix::new(4, 12);
        for j in 0..4 {
            h.insert_row(j, [j, j + 4, j + 8].iter());
        }
        let expected = alist_text(&[
            "12 4",
            "1 3",
            "1 1 1 1 1 1 1 1 1 1 1 1",
            "3 3 3 3",
            "1",
            "2",
            "3",
            "4",
            "1",
            "2",
            "3",
            "4",
            "1",
            "2",
            "3",
            "4",
            "1 5 9",
            "2 6 10",
            "3 7 11",
            "4 8 12",
        ]);
        assert_eq!(h.alist(), expected);

        let parsed = SparseMatrix::from_alist(&expected).unwrap();
        assert_eq!(parsed.alist(), expected);
    }

    #[test]
    fn test_alist_irregular() {
        let mut h = SparseMatrix::new(4, 12);
        for j in 0..4 {
            // the first two rows have weight 3, the last two weight 2
            let weight = if j < 2 { 3 } else { 2 };
            h.insert_row(j, [j, j + 4, j + 8].iter().take(weight));
        }

        let expected = alist_text(&[
            "12 4",
            "1 3",
            "1 1 1 1 1 1 1 1 1 1 0 0",
            "3 3 2 2",
            "1",
            "2",
            "3",
            "4",
            "1",
            "2",
            "3",
            "4",
            "1",
            "2",
            "0",
            "0",
            "1 5 9",
            "2 6 10",
            "3 7 0",
            "4 8 0",
        ]);
        let expected_no_padding = alist_text(&[
            "12 4",
            "1 3",
            "1 1 1 1 1 1 1 1 1 1 0 0",
            "3 3 2 2",
            "1",
            "2",
            "3",
            "4",
            "1",
            "2",
            "3",
            "4",
            "1",
            "2",
            "",
            "",
            "1 5 9",
            "2 6 10",
            "3 7",
            "4 8",
        ]);

        assert_eq!(h.alist(), expected);
        assert_eq!(h.alist_no_padding(), expected_no_padding);
        // both flavours of alist must parse back to the same matrix
        for text in [&expected, &expected_no_padding].iter() {
            let parsed = SparseMatrix::from_alist(text).unwrap();
            assert_eq!(parsed.alist(), expected);
            assert_eq!(parsed.alist_no_padding(), expected_no_padding);
        }
    }
}