gluon_vm/core/
dead_code.rs

1use std::iter::FromIterator;
2
3use petgraph::visit::Walker;
4
5use base::{
6    fnv::FnvSet,
7    merge::merge,
8    scoped_map::ScopedMap,
9    symbol::{Symbol, SymbolRef},
10};
11
12use crate::core::{
13    self,
14    optimize::{walk_closures, walk_expr, walk_expr_alloc, SameLifetime, Visitor},
15    Allocator, CExpr, Expr, LetBinding, Pattern,
16};
17
18pub fn dead_code_elimination<'a>(
19    used_bindings: &FnvSet<&'a SymbolRef>,
20    allocator: &'a Allocator<'a>,
21    expr: CExpr<'a>,
22) -> CExpr<'a> {
23    trace!("dead_code_elimination: {}", expr);
24    struct DeadCodeEliminator<'a, 'b> {
25        allocator: &'a Allocator<'a>,
26        used_bindings: &'b FnvSet<&'a SymbolRef>,
27    }
28    impl DeadCodeEliminator<'_, '_> {
29        fn is_used(&self, s: &Symbol) -> bool {
30            self.used_bindings.contains(&**s)
31        }
32    }
33
34    impl<'e> Visitor<'e, 'e> for DeadCodeEliminator<'e, '_> {
35        type Producer = SameLifetime<'e>;
36
37        fn visit_expr(&mut self, expr: CExpr<'e>) -> Option<CExpr<'e>> {
38            match expr {
39                Expr::Let(bind, body) => {
40                    let new_body = self.visit_expr(body);
41                    let new_named = match &bind.expr {
42                        core::Named::Recursive(closures) => {
43                            let used_closures: Vec<_> = closures
44                                .iter()
45                                .filter(|closure| self.is_used(&closure.name.name))
46                                .cloned()
47                                .collect();
48
49                            walk_closures(self, &used_closures)
50                                .or_else(|| {
51                                    if used_closures.len() == closures.len() {
52                                        None
53                                    } else {
54                                        Some(used_closures)
55                                    }
56                                })
57                                .map(core::Named::Recursive)
58                        }
59
60                        core::Named::Expr(bind_expr) => {
61                            if self.is_used(&bind.name.name) {
62                                let new_bind_expr = self.visit_expr(bind_expr);
63                                new_bind_expr.map(core::Named::Expr)
64                            } else {
65                                return Some(new_body.unwrap_or(body));
66                            }
67                        }
68                    };
69                    let new_bind = new_named.map(|expr| {
70                        &*self.allocator.let_binding_arena.alloc(LetBinding {
71                            name: bind.name.clone(),
72                            expr,
73                            span_start: bind.span_start,
74                        })
75                    });
76                    merge(bind, new_bind, body, new_body, |bind, body| {
77                        match &bind.expr {
78                            core::Named::Recursive(closures) if closures.is_empty() => body,
79                            _ => &*self.allocator.arena.alloc(Expr::Let(bind, body)),
80                        }
81                    })
82                }
83
84                Expr::Match(_, alts) if alts.len() == 1 => match &alts[0].pattern {
85                    Pattern::Record { fields, .. } => {
86                        if fields
87                            .iter()
88                            .map(|(x, y)| y.as_ref().unwrap_or(&x.name))
89                            .any(|field_bind| self.is_used(&field_bind))
90                        {
91                            walk_expr_alloc(self, expr)
92                        } else {
93                            Some(
94                                self.visit_expr(alts[0].expr)
95                                    .unwrap_or_else(|| alts[0].expr),
96                            )
97                        }
98                    }
99                    _ => walk_expr_alloc(self, expr),
100                },
101
102                _ => walk_expr_alloc(self, expr),
103            }
104        }
105        fn detach_allocator(&self) -> Option<&'e Allocator<'e>> {
106            Some(self.allocator)
107        }
108    }
109
110    let mut free_vars = DeadCodeEliminator {
111        allocator,
112        used_bindings,
113    };
114    free_vars.visit_expr(expr).unwrap_or(expr)
115}
116
117#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
118enum Scope<'a> {
119    Symbol(&'a SymbolRef),
120    Match(usize),
121}
122
123#[derive(Default)]
124pub struct DepGraph<'a> {
125    graph: petgraph::Graph<Scope<'a>, ()>,
126    symbol_map: ScopedMap<Scope<'a>, petgraph::graph::NodeIndex>,
127    currents: Vec<(BindType, petgraph::graph::NodeIndex)>,
128    match_id: usize,
129}
130
/// Name of the synthetic root symbol from which the used-binding search starts.
const TOP: &str = "<top>";
132
/// Kind of binding that an entry on the `currents` stack belongs to.
///
/// This is a field-less tag, so it derives the full set of cheap value traits
/// (`Copy`, `Eq`, `Debug`) rather than `PartialEq` alone.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum BindType {
    /// The binding is one closure of a recursive `let` group.
    Closure,
    /// The binding is a plain expression (also used for the root scope and match scrutinees).
    Expr,
}
138
139impl<'a> DepGraph<'a> {
140    fn scope(&mut self, id: &'a SymbolRef, bind_type: BindType, f: impl FnOnce(&mut Self)) {
141        let current_idx = self.add_node(Scope::Symbol(id));
142        self.scope_idx(current_idx, bind_type, f)
143    }
144
145    fn scope_idx(
146        &mut self,
147        idx: petgraph::graph::NodeIndex,
148        bind_type: BindType,
149        f: impl FnOnce(&mut Self),
150    ) {
151        self.currents.push((bind_type, idx));
152
153        f(self);
154
155        self.currents.pop();
156    }
157
158    fn add_node(&mut self, scope: Scope<'a>) -> petgraph::graph::NodeIndex {
159        let Self {
160            symbol_map, graph, ..
161        } = self;
162        *symbol_map
163            .entry(scope)
164            .or_insert_with(|| graph.add_node(scope))
165    }
166
167    fn bind_pattern(&mut self, pattern: &'a Pattern, scrutinee_id: petgraph::graph::NodeIndex) {
168        match pattern {
169            Pattern::Ident(ref id) => {
170                let id_id = self.add_node(Scope::Symbol(&id.name));
171                self.graph.add_edge(id_id, scrutinee_id, ());
172            }
173            Pattern::Record { fields, .. } => {
174                for field in fields {
175                    let name = field.1.as_ref().unwrap_or(&field.0.name);
176                    let id_id = self.add_node(Scope::Symbol(name));
177                    self.graph.add_edge(id_id, scrutinee_id, ());
178                }
179            }
180            Pattern::Constructor(_, fields) => {
181                for field in fields {
182                    let id_id = self.add_node(Scope::Symbol(&field.name));
183                    self.graph.add_edge(id_id, scrutinee_id, ());
184                }
185            }
186            Pattern::Literal(_) => (),
187        }
188    }
189
190    pub fn cycles<'s>(
191        &'s self,
192    ) -> impl Iterator<Item = impl Iterator<Item = &'a SymbolRef> + 's> + 's {
193        petgraph::algo::tarjan_scc(&self.graph)
194            .into_iter()
195            .filter_map(move |cycle| {
196                if cycle.len() == 1 {
197                    let node = cycle[0];
198                    if self.graph.find_edge(node, node).is_some() {
199                        Some(itertools::Either::Left(
200                            self.graph.node_weight(node).cloned().into_iter(),
201                        ))
202                    } else {
203                        None
204                    }
205                } else {
206                    Some(itertools::Either::Right(
207                        cycle
208                            .into_iter()
209                            .flat_map(move |node| self.graph.node_weight(node).cloned()),
210                    ))
211                }
212            })
213            .map(|iter| {
214                iter.filter_map(|scope| match scope {
215                    Scope::Symbol(s) => Some(s),
216                    Scope::Match(_) => None,
217                })
218            })
219    }
220
221    pub fn used_bindings<F>(&mut self, expr: CExpr<'a>) -> F
222    where
223        F: FromIterator<&'a SymbolRef>,
224    {
225        let top_symbol = SymbolRef::new(TOP);
226        let top = self.graph.add_node(Scope::Symbol(top_symbol));
227        self.symbol_map.insert(Scope::Symbol(top_symbol), top);
228
229        self.scope_idx(top, BindType::Expr, |dep_graph| {
230            dep_graph.visit_expr(expr);
231        });
232
233        trace!("DepGraph: {:?}", petgraph::dot::Dot::new(&self.graph));
234
235        let graph = &self.graph;
236        petgraph::visit::Dfs::new(graph, top)
237            .iter(graph)
238            .flat_map(|idx| {
239                graph.node_weight(idx).and_then(|scope| match *scope {
240                    Scope::Symbol(s) => Some(s),
241                    Scope::Match(_) => None,
242                })
243            })
244            .collect()
245    }
246}
247
248impl<'e> Visitor<'e, 'e> for DepGraph<'e> {
249    type Producer = SameLifetime<'e>;
250
251    fn visit_expr(&mut self, expr: CExpr<'e>) -> Option<CExpr<'e>> {
252        match expr {
253            Expr::Ident(id, ..) => {
254                let current = self.currents.last().unwrap().1;
255                let used_id = self.add_node(Scope::Symbol(&id.name));
256                self.graph.add_edge(current, used_id, ());
257
258                None
259            }
260
261            Expr::Call(Expr::Ident(id, ..), ..) if !id.name.as_str().starts_with('#') => {
262                for window in self
263                    .currents
264                    .windows(2)
265                    .rev()
266                    .take_while(|t| t[1].0 == BindType::Expr)
267                {
268                    self.graph.add_edge(window[0].1, window[1].1, ());
269                }
270                walk_expr(self, expr);
271                None
272            }
273
274            Expr::Let(bind, body) => {
275                self.symbol_map.enter_scope();
276
277                match &bind.expr {
278                    core::Named::Recursive(closures) => {
279                        for closure in closures {
280                            let id = &closure.name.name;
281                            self.add_node(Scope::Symbol(id));
282                        }
283
284                        for closure in closures {
285                            self.scope(&closure.name.name, BindType::Closure, |self_| {
286                                self_.visit_expr(closure.expr);
287                            });
288                        }
289                    }
290                    core::Named::Expr(bind_expr) => {
291                        self.scope(&bind.name.name, BindType::Expr, |self_| {
292                            self_.visit_expr(bind_expr);
293                        });
294                    }
295                }
296
297                self.visit_expr(body);
298
299                self.symbol_map.exit_scope();
300
301                None
302            }
303
304            Expr::Match(scrutinee, alts) => {
305                let id = Scope::Match(self.match_id);
306                self.match_id += 1;
307                let scrutinee_id = self.graph.add_node(id);
308
309                for alt in *alts {
310                    self.bind_pattern(&alt.pattern, scrutinee_id);
311                }
312
313                self.scope_idx(scrutinee_id, BindType::Expr, |self_| {
314                    self_.visit_expr(scrutinee);
315                });
316
317                if alts.iter().any(|alt| match alt.pattern {
318                    Pattern::Constructor(..) | Pattern::Literal(..) => true,
319                    _ => false,
320                }) {
321                    let current = self.currents.last().unwrap().1;
322                    self.graph.add_edge(current, scrutinee_id, ());
323                }
324
325                for alt in *alts {
326                    self.visit_expr(&alt.expr);
327                }
328
329                None
330            }
331
332            _ => {
333                walk_expr(self, expr);
334                None
335            }
336        }
337    }
338    fn detach_allocator(&self) -> Option<&'e Allocator<'e>> {
339        None
340    }
341}
342
#[cfg(test)]
mod tests {
    use super::*;

    use base::symbol::Symbols;

    use crate::core::optimize::tests::check_optimization;

    /// Computes the used bindings of `expr` and runs the elimination pass with them.
    fn run_pass<'a>(allocator: &'a Allocator<'a>, expr: CExpr<'a>) -> CExpr<'a> {
        let mut dep_graph = DepGraph::default();
        let used_bindings = dep_graph.used_bindings(expr);

        super::dead_code_elimination(&used_bindings, allocator, expr)
    }

    /// Asserts that the pass turns `source` into `optimized`.
    fn assert_eliminates(source: &'static str, optimized: &'static str) {
        check_optimization(source, optimized, run_pass);
    }

    /// Asserts that the pass leaves `source` as it is.
    fn assert_untouched(source: &'static str) {
        assert_eliminates(source, source);
    }

    #[test]
    fn basic() {
        let source = r#"
            let x = 1
            in
            2
            "#;
        let optimized = r#"
            2
            "#;
        assert_eliminates(source, optimized);
    }

    #[test]
    fn recursive_basic() {
        let source = r#"
            rec let f x = x
            in
            2
            "#;
        let optimized = r#"
            2
            "#;
        assert_eliminates(source, optimized);
    }

    #[test]
    fn eliminate_inner() {
        let source = r#"
            let x =
                let y = ""
                in
                1
            in
            x
            "#;
        let optimized = r#"
            let x = 1
            in
            x
            "#;
        assert_eliminates(source, optimized);
    }

    #[test]
    fn eliminate_redundant_match() {
        let source = r#"
            match { x = 1 } with
            | { x } -> 1
            end
            "#;
        let optimized = r#"
            1
            "#;
        assert_eliminates(source, optimized);
    }

    #[test]
    fn eliminate_let_used_in_redundant_match() {
        let source = r#"
            let a = 1 in
            match { x = a } with
            | { x } -> 1
            end
            "#;
        let optimized = r#"
            1
            "#;
        assert_eliminates(source, optimized);
    }

    #[test]
    fn dont_eliminate_used_match() {
        let source = r#"
            rec let f y = y
            in
            let x = f 123
            in
            match { x } with
            | { x } -> x
            end
            "#;
        assert_untouched(source);
    }

    #[test]
    fn dont_eliminate_let_in_constructor_match() {
        let source = r#"
            let y = 1 in
            match y with
            | LT -> 1
            end
            "#;
        assert_untouched(source);
    }

    #[test]
    fn dont_eliminate_constructor_with_used_binding() {
        let source = r#"
            rec let f y =
                match y with
                | Test a ->
                    match a with
                    | 1 -> 1
                    | _ -> 2
                    end
                end
            in f
            "#;
        let optimized = r#"
            rec let f y =
                match y with
                | Test a ->
                    match a with
                    | 1 -> 1
                    | _ -> 2
                    end
                end
            in f
           "#;
        assert_eliminates(source, optimized);
    }

    #[test]
    fn cycles() {
        let expr_str = r#"
            let z = 1
            in

            rec let f y1 =
                let a = z in
                g y1 z
            rec let g y2 = f y2
            in

            rec let h y3 = h y3
            in

            let a = g y
            in 1
            "#;

        let mut symbols = Symbols::new();
        let allocator = Allocator::new();
        let expr = crate::core::interpreter::tests::parse_expr(&mut symbols, &allocator, expr_str);

        let mut dep_graph = DepGraph::default();
        dep_graph.visit_expr(expr);

        let found: Vec<Vec<String>> = dep_graph
            .cycles()
            .map(|group| group.map(|s| s.to_string()).collect())
            .collect();
        let expected = vec![
            vec!["g".to_string(), "f".to_string()],
            vec!["h".to_string()],
        ];
        assert_eq!(
            found,
            expected,
            "{:?}",
            petgraph::dot::Dot::new(&dep_graph.graph)
        );
    }
}