// compiler_course_helper/grammar/eliminate_left_recursion.rs
use std::collections::{HashMap, HashSet};

use super::{grammar::NonTerminal, Grammar, EPSILON};

5impl Grammar {
6 pub fn eliminate_left_recursion(&mut self) {
7 if !self.is_nullable_first_follow_valid() {
8 self.calculate_nullable_first_follow();
9 }
10
11 let epsilon_idx = self.get_symbol_index(EPSILON).unwrap();
12 let offset = self.symbols.len();
13
14 let mut non_terminals = self.non_terminal_iter_mut().collect::<Vec<_>>();
15 let map: HashMap<usize, usize> =
16 non_terminals
17 .iter()
18 .enumerate()
19 .fold(HashMap::new(), |mut map, (i, nt)| {
20 map.insert(nt.index, i);
21 map
22 });
23
24 let mut new_non_terminals: Vec<NonTerminal> = Vec::new();
25
26 for i in 0..non_terminals.len() {
27 let (replace, b) = non_terminals.split_at_mut(i);
28 let (nt, _) = b.split_first_mut().unwrap();
29 let replace = &replace[..];
30
31 let old_productions = std::mem::replace(&mut nt.productions, Vec::new());
32 let mut recursive_productions: Vec<Vec<usize>> = Vec::new();
33 for mut production in old_productions {
34 if let Some(idx) = production.first() {
35 if let Some(&arr_idx) = map.get(idx) {
36 match arr_idx.cmp(&i) {
37 std::cmp::Ordering::Less => {
38 for prefix in &replace[arr_idx].productions {
39 let new_production =
40 prefix.iter().chain(production.iter().skip(1)).cloned();
41
42 if Some(&nt.index) == prefix.first() {
43 recursive_productions.push(new_production.skip(1).collect())
44 } else {
45 nt.productions.push(new_production.collect())
46 }
47 }
48 }
49 std::cmp::Ordering::Equal => {
50 production.remove(0);
51 recursive_productions.push(production);
52 }
53 std::cmp::Ordering::Greater => {
54 nt.productions.push(production);
55 }
56 };
57 } else {
58 nt.productions.push(production);
59 }
60 }
61 }
62
63 if recursive_productions.len() > 0 {
64 let nt_prime_idx = offset + new_non_terminals.len();
65 for production in &mut nt.productions {
66 production.push(nt_prime_idx);
67 }
68 for production in &mut recursive_productions {
69 production.push(nt_prime_idx);
70 }
71 recursive_productions.push(vec![epsilon_idx]);
72 new_non_terminals.push(NonTerminal {
73 index: nt_prime_idx,
74 nullable: false,
75 name: nt.name.clone(),
76 first: HashSet::new(),
77 follow: HashSet::new(),
78 productions: recursive_productions,
79 });
80 }
81 }
82
83 for mut nt in new_non_terminals {
84 nt.name = self.get_symbol_prime_name(nt.name);
85 self.symbol_table.insert(nt.name.clone(), nt.index);
86 self.symbols.push(super::grammar::Symbol::NonTerminal(nt));
87 }
88
89 self.invalidate_nullable_first_follow();
90 self.calculate_nullable_first_follow();
91 }
92}