compiler_course_helper/grammar/eliminate_left_recursion.rs

1use std::collections::{HashMap, HashSet};
2
3use super::{grammar::NonTerminal, Grammar, EPSILON};
4
5impl Grammar {
6    pub fn eliminate_left_recursion(&mut self) {
7        if !self.is_nullable_first_follow_valid() {
8            self.calculate_nullable_first_follow();
9        }
10
11        let epsilon_idx = self.get_symbol_index(EPSILON).unwrap();
12        let offset = self.symbols.len();
13
14        let mut non_terminals = self.non_terminal_iter_mut().collect::<Vec<_>>();
15        let map: HashMap<usize, usize> =
16            non_terminals
17                .iter()
18                .enumerate()
19                .fold(HashMap::new(), |mut map, (i, nt)| {
20                    map.insert(nt.index, i);
21                    map
22                });
23
24        let mut new_non_terminals: Vec<NonTerminal> = Vec::new();
25
26        for i in 0..non_terminals.len() {
27            let (replace, b) = non_terminals.split_at_mut(i);
28            let (nt, _) = b.split_first_mut().unwrap();
29            let replace = &replace[..];
30
31            let old_productions = std::mem::replace(&mut nt.productions, Vec::new());
32            let mut recursive_productions: Vec<Vec<usize>> = Vec::new();
33            for mut production in old_productions {
34                if let Some(idx) = production.first() {
35                    if let Some(&arr_idx) = map.get(idx) {
36                        match arr_idx.cmp(&i) {
37                            std::cmp::Ordering::Less => {
38                                for prefix in &replace[arr_idx].productions {
39                                    let new_production =
40                                        prefix.iter().chain(production.iter().skip(1)).cloned();
41
42                                    if Some(&nt.index) == prefix.first() {
43                                        recursive_productions.push(new_production.skip(1).collect())
44                                    } else {
45                                        nt.productions.push(new_production.collect())
46                                    }
47                                }
48                            }
49                            std::cmp::Ordering::Equal => {
50                                production.remove(0);
51                                recursive_productions.push(production);
52                            }
53                            std::cmp::Ordering::Greater => {
54                                nt.productions.push(production);
55                            }
56                        };
57                    } else {
58                        nt.productions.push(production);
59                    }
60                }
61            }
62
63            if recursive_productions.len() > 0 {
64                let nt_prime_idx = offset + new_non_terminals.len();
65                for production in &mut nt.productions {
66                    production.push(nt_prime_idx);
67                }
68                for production in &mut recursive_productions {
69                    production.push(nt_prime_idx);
70                }
71                recursive_productions.push(vec![epsilon_idx]);
72                new_non_terminals.push(NonTerminal {
73                    index: nt_prime_idx,
74                    nullable: false,
75                    name: nt.name.clone(),
76                    first: HashSet::new(),
77                    follow: HashSet::new(),
78                    productions: recursive_productions,
79                });
80            }
81        }
82
83        for mut nt in new_non_terminals {
84            nt.name = self.get_symbol_prime_name(nt.name);
85            self.symbol_table.insert(nt.name.clone(), nt.index);
86            self.symbols.push(super::grammar::Symbol::NonTerminal(nt));
87        }
88
89        self.invalidate_nullable_first_follow();
90        self.calculate_nullable_first_follow();
91    }
92}