layout/topo/
optimizer.rs

1//! This module contains optimization passes that transform the graphs in different
2//! phases of the program. Here you can find things like optimizations for
3//! sinking or hoisting nodes to reduce the number of live edges, and
4//! optimizations that move nodes within a row to reduce edge crossing.
5
6use crate::adt::dag::NodeHandle;
7use crate::adt::dag::DAG;
8use crate::core::base::Direction;
9
10/// This optimizations changes the order of nodes within a rank (ordering along
11/// the x-axis). The transformation tries to reduce the number of edges that
12/// cross each other.
13#[derive(Debug)]
14pub struct EdgeCrossOptimizer<'a> {
15    dag: &'a mut DAG,
16}
17impl<'a> EdgeCrossOptimizer<'a> {
18    pub fn new(dag: &'a mut DAG) -> Self {
19        Self { dag }
20    }
21
22    /// Given two nodes that may have connections in \p row, check how many of
23    /// these edges intersect. Check both successors and predecessors.
24    ///               A   B
25    ///             /   \/ \
26    ///            /    /\  \
27    ///  Row: [][][][][][][][][][]
28    fn num_crossing(
29        &self,
30        a: NodeHandle,
31        b: NodeHandle,
32        row: &[NodeHandle],
33    ) -> usize {
34        let mut sum = 0;
35        // Record the number of edges that previously connected with node B.
36        let mut num_b = 0;
37
38        let a_edges1 = self.dag.successors(a);
39        let a_edges2 = self.dag.predecessors(a);
40        let b_edges1 = self.dag.successors(b);
41        let b_edges2 = self.dag.predecessors(b);
42
43        for node in row {
44            let is_a1 = a_edges1.iter().any(|x| x == node);
45            let is_a2 = a_edges2.iter().any(|x| x == node);
46            let is_b1 = b_edges1.iter().any(|x| x == node);
47            let is_b2 = b_edges2.iter().any(|x| x == node);
48            if is_a1 || is_a2 {
49                sum += num_b;
50            }
51            if is_b1 || is_b2 {
52                num_b += 1;
53            }
54        }
55        sum
56    }
57
58    // Shuffle the nodes in all of the ranks.
59    pub fn perturb_rank(&mut self) {
60        for i in 0..self.dag.num_levels() {
61            let row = self.dag.row_mut(i);
62            let len = row.len();
63            for j in 0..len {
64                row.swap((j * 17) % len, j);
65            }
66        }
67    }
68
69    // Move the elements in the rank to the left, to perturb the graph.
70    pub fn rotate_rank(&mut self) {
71        for i in 0..self.dag.num_levels() {
72            let row = self.dag.row_mut(i);
73            row.rotate_left(1);
74        }
75    }
76
77    pub fn optimize(&mut self) {
78        self.dag.verify();
79        #[cfg(feature = "log")]
80        log::info!("Optimizing edge crossing.");
81        let mut best_rank = self.dag.ranks().clone();
82        let mut best_cnt = self.count_crossed_edges();
83        #[cfg(feature = "log")]
84        log::info!("Starting with {} crossings.", best_cnt);
85        for i in 0..50 {
86            let dir = match i % 4 {
87                0 => Direction::Both,
88                1 => Direction::Up,
89                _ => Direction::Down,
90            };
91            self.swap_crossed_edges(dir);
92            let new_cnt = self.count_crossed_edges();
93            if new_cnt < best_cnt {
94                #[cfg(feature = "log")]
95                log::info!("Found a rank with {} crossings.", new_cnt);
96                best_rank = self.dag.ranks().clone();
97                best_cnt = new_cnt;
98            }
99            self.rotate_rank();
100            if i % 10 == 0 {
101                self.perturb_rank();
102            }
103        }
104        *self.dag.ranks_mut() = best_rank;
105    }
106
107    fn count_crossed_edges(&self) -> usize {
108        let mut sum = 0;
109        // Compare each row to the row afterwards.
110        for row_idx in 0..self.dag.num_levels() - 1 {
111            let first_row = self.dag.row(row_idx);
112            let second_row = self.dag.row(row_idx + 1);
113            sum += self.count_crossing_in_rows(first_row, second_row);
114        }
115        sum
116    }
117
118    fn count_crossing_in_rows(
119        &self,
120        first: &[NodeHandle],
121        second: &[NodeHandle],
122    ) -> usize {
123        if first.len() < 2 {
124            return 0;
125        }
126        let mut sum = 0;
127        // Check for each pair of nodes a,b where b comes after a.
128        for i in 0..first.len() {
129            for j in i + 1..first.len() {
130                let a = first[i];
131                let b = first[j];
132                sum += self.num_crossing(a, b, second);
133            }
134        }
135        sum
136    }
137
138    /// Scan all of the node pairs in the module and count the number of crossed
139    /// edges. If \p allow_swap is set then swap the edges if it reduces the
140    /// number of crossing.
141    fn swap_crossed_edges(&mut self, dir: Direction) {
142        let mut changed = true;
143        while changed {
144            changed = false;
145            if dir.is_down() {
146                for i in 0..self.dag.num_levels() {
147                    changed |= self.swap_crossed_edges_on_row(i, dir);
148                }
149            }
150            if dir.is_up() {
151                for i in (0..self.dag.num_levels()).rev() {
152                    changed |= self.swap_crossed_edges_on_row(i, dir);
153                }
154            }
155        }
156    }
157
158    /// See swap_crossed_edges.
159    fn swap_crossed_edges_on_row(
160        &mut self,
161        row_idx: usize,
162        dir: Direction,
163    ) -> bool {
164        let mut changed = false;
165
166        let num_rows = self.dag.num_levels();
167
168        let prev_row = if row_idx > 0 && dir.is_up() {
169            self.dag.row(row_idx - 1).clone()
170        } else {
171            Vec::new()
172        };
173        let next_row = if row_idx + 1 < num_rows && dir.is_down() {
174            self.dag.row(row_idx + 1).clone()
175        } else {
176            Vec::new()
177        };
178
179        let mut row = self.dag.row(row_idx).clone();
180
181        if row.len() < 2 {
182            return false;
183        }
184
185        // For each two consecutive elements in the row:
186        for i in 0..row.len() - 1 {
187            let a = row[i];
188            let b = row[i + 1];
189
190            let mut ab = 0;
191            let mut ba = 0;
192            // Figure out if A crosses the edges of B, and vice versa, on both
193            // the edges pointing up and down.
194            ab += self.num_crossing(a, b, &prev_row);
195            ba += self.num_crossing(b, a, &prev_row);
196            ab += self.num_crossing(a, b, &next_row);
197            ba += self.num_crossing(b, a, &next_row);
198
199            // Swap the edges.
200            if ab > ba {
201                row[i] = b;
202                row[i + 1] = a;
203                changed = true;
204            }
205        }
206
207        if changed {
208            *self.dag.row_mut(row_idx) = row;
209        }
210        changed
211    }
212}
213
214/// This optimization sinks nodes in an attempt to shorten the length of edges
215/// that run through the graph.
216#[derive(Debug)]
217pub struct RankOptimizer<'a> {
218    dag: &'a mut DAG,
219}
220
221impl<'a> RankOptimizer<'a> {
222    pub fn new(dag: &'a mut DAG) -> Self {
223        Self { dag }
224    }
225
226    pub fn try_to_sink_node(&mut self, node: NodeHandle) -> bool {
227        let backs = self.dag.predecessors(node);
228        let fwds = self.dag.successors(node);
229
230        // Don't try to sink if we increase the number of live edges,
231        // or if there are no forward edges.
232        if backs.len() > fwds.len() || backs.len() + fwds.len() == 0 {
233            return false;
234        }
235
236        let curr_rank = self.dag.level(node);
237        let mut highest_next = self.dag.len();
238        for elem in fwds {
239            let next_rank = self.dag.level(*elem);
240            highest_next = highest_next.min(next_rank);
241        }
242
243        // We found an opportunity to sink a node.
244        if highest_next > curr_rank + 1 {
245            self.dag
246                .update_node_rank_level(node, highest_next - 1, None);
247            return true;
248        }
249        false
250    }
251
252    // Try to sink nodes to shorten the length of edges.
253    pub fn optimize(&mut self) {
254        self.dag.verify();
255
256        #[cfg(feature = "log")]
257        log::info!("Optimizing the ranks.");
258        #[cfg(feature = "log")]
259        let mut cnt = 0;
260        #[cfg(feature = "log")]
261        let mut iter = 0;
262
263        loop {
264            let mut c = 0;
265            for node in self.dag.iter() {
266                if self.try_to_sink_node(node) {
267                    c += 1;
268                }
269            }
270            #[cfg(feature = "log")]
271            {
272                cnt += c;
273                iter += 1;
274            }
275            if c == 0 {
276                break;
277            }
278        }
279
280        #[cfg(feature = "log")]
281        log::info!("Sank {} nodes in {} iteration.", cnt, iter);
282    }
283}