ldpc_toolbox/sparse/
bfs.rs

1use crate::sparse::{Node, SparseMatrix};
2use std::collections::VecDeque;
3
4#[derive(Debug, Clone, Eq, PartialEq)]
5struct PathHead {
6    node: Node,
7    parent: Option<Node>,
8    path_length: usize,
9}
10
11impl PathHead {
12    fn iter<'a>(&'a self, h: &'a SparseMatrix) -> impl Iterator<Item = PathHead> + 'a {
13        self.node
14            .iter(h)
15            .filter(move |&x| {
16                if let Some(parent) = self.parent {
17                    x != parent
18                } else {
19                    true
20                }
21            })
22            .map(move |x| PathHead {
23                node: x,
24                parent: Some(self.node),
25                path_length: self.path_length + 1,
26            })
27    }
28}
29
/// Outcome of a breadth-first search (BFS) over the row/column graph.
///
/// Every row node and every column node has an entry holding its distance,
/// measured in edges, from the node used as the root of the search. Nodes
/// that the search cannot reach from the root are marked with `None`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct BFSResults {
    /// Distance from the root to each row node, indexed by row.
    pub row_nodes_distance: Vec<Option<usize>>,
    /// Distance from the root to each column node, indexed by column.
    pub col_nodes_distance: Vec<Option<usize>>,
}
43
44impl BFSResults {
45    fn get_node_mut(&mut self, node: Node) -> &mut Option<usize> {
46        match node {
47            Node::Row(n) => &mut self.row_nodes_distance[n],
48            Node::Col(n) => &mut self.col_nodes_distance[n],
49        }
50    }
51}
52
53pub struct BFSContext<'a> {
54    results: BFSResults,
55    to_visit: VecDeque<PathHead>,
56    h: &'a SparseMatrix,
57}
58
59impl BFSContext<'_> {
60    pub fn new(h: &SparseMatrix, node: Node) -> BFSContext<'_> {
61        let mut to_visit = VecDeque::new();
62        to_visit.push_back(PathHead {
63            node,
64            parent: None,
65            path_length: 0,
66        });
67        let mut results = BFSResults {
68            row_nodes_distance: vec![None; h.num_rows()],
69            col_nodes_distance: vec![None; h.num_cols()],
70        };
71        results.get_node_mut(node).replace(0);
72        BFSContext {
73            results,
74            to_visit,
75            h,
76        }
77    }
78
79    pub fn bfs(mut self) -> BFSResults {
80        while let Some(head) = self.to_visit.pop_front() {
81            for next_head in head.iter(self.h) {
82                let next_dist = self.results.get_node_mut(next_head.node);
83                if next_dist.is_none() {
84                    *next_dist = Some(next_head.path_length);
85                    self.to_visit.push_back(next_head);
86                }
87            }
88        }
89        self.results
90    }
91
92    pub fn local_girth(mut self, max: usize) -> Option<usize> {
93        while let Some(head) = self.to_visit.pop_front() {
94            for next_head in head.iter(self.h) {
95                let next_dist = self.results.get_node_mut(next_head.node);
96                if let Some(dist) = *next_dist {
97                    let total = dist + next_head.path_length;
98                    return if total <= max { Some(total) } else { None };
99                } else {
100                    *next_dist = Some(next_head.path_length);
101                    if next_head.path_length < max {
102                        self.to_visit.push_back(next_head);
103                    }
104                }
105            }
106        }
107        None
108    }
109}
110
#[cfg(test)]
mod tests {
    use super::*;

    /// Two isolated edges: (row 0, col 0) and (row 1, col 1). No cycles.
    fn disconnected_matrix() -> SparseMatrix {
        let mut h = SparseMatrix::new(2, 2);
        h.insert(0, 0);
        h.insert(1, 1);
        h
    }

    /// Matrix with every entry set, i.e. a complete bipartite graph.
    fn complete_matrix(n: usize, m: usize) -> SparseMatrix {
        let mut h = SparseMatrix::new(n, m);
        for i in 0..n {
            for j in 0..m {
                h.insert(i, j);
            }
        }
        h
    }

    /// Square matrix whose row `j` has ones in columns `j` and `j + 1 (mod n)`.
    /// For `n >= 2` its graph is a single cycle through all `2 * n` nodes.
    fn circulant_matrix(n: usize) -> SparseMatrix {
        let mut h = SparseMatrix::new(n, n);
        for j in 0..n {
            h.insert(j, j);
            h.insert(j, (j + 1) % n);
        }
        h
    }

    #[test]
    fn disconnected_2x2() {
        let h = disconnected_matrix();
        let r = h.bfs(Node::Col(0));
        assert_eq!(r.row_nodes_distance[0], Some(1));
        assert_eq!(r.row_nodes_distance[1], None);
        assert_eq!(r.col_nodes_distance[0], Some(0));
        assert_eq!(r.col_nodes_distance[1], None);
    }

    #[test]
    fn complete_nxm() {
        let n = 20;
        let m = 10;
        let h = complete_matrix(n, m);
        let r = h.bfs(Node::Row(0));
        assert_eq!(r.row_nodes_distance[0], Some(0));
        for i in 1..n {
            assert_eq!(r.row_nodes_distance[i], Some(2));
        }
        for i in 0..m {
            assert_eq!(r.col_nodes_distance[i], Some(1));
        }
    }

    #[test]
    fn circulant() {
        let n = 20;
        let h = circulant_matrix(n);
        let r = h.bfs(Node::Row(0));
        assert_eq!(r.row_nodes_distance[0], Some(0));
        for j in 1..n {
            let dist = std::cmp::min(2 * j, 2 * (n - j));
            assert_eq!(r.row_nodes_distance[j], Some(dist));
        }
        for j in 1..n + 1 {
            let dist = std::cmp::min(2 * j - 1, 2 * (n - j) + 1);
            assert_eq!(r.col_nodes_distance[j % n], Some(dist));
        }
    }

    #[test]
    fn local_girth_disconnected_2x2() {
        let h = disconnected_matrix();
        for &node in &[Node::Row(0), Node::Row(1), Node::Col(0), Node::Col(1)] {
            assert_eq!(BFSContext::new(&h, node).local_girth(usize::MAX), None);
        }
    }

    #[test]
    fn local_girth_complete_nxm() {
        let n = 20;
        let m = 10;
        let h = complete_matrix(n, m);
        for i in 0..n {
            assert_eq!(
                BFSContext::new(&h, Node::Row(i)).local_girth(usize::MAX),
                Some(4)
            );
        }
        for j in 0..m {
            assert_eq!(
                BFSContext::new(&h, Node::Col(j)).local_girth(usize::MAX),
                Some(4)
            );
        }
        // A bound below the shortest cycle length hides the cycle.
        assert_eq!(BFSContext::new(&h, Node::Row(0)).local_girth(3), None);
    }

    #[test]
    fn local_girth_circulant() {
        let n = 20;
        let h = circulant_matrix(n);
        for j in 0..n {
            assert_eq!(
                BFSContext::new(&h, Node::Row(j)).local_girth(usize::MAX),
                Some(2 * n)
            );
            assert_eq!(
                BFSContext::new(&h, Node::Col(j)).local_girth(2 * n),
                Some(2 * n)
            );
            assert_eq!(BFSContext::new(&h, Node::Col(j)).local_girth(2 * n - 1), None);
        }
    }
}
166}