// compiler_course_helper/grammar/nullable_first_follow.rs

use std::collections::HashSet;
use std::mem;

use super::{grammar::Symbol, Grammar, END_MARK};

5impl Grammar {
6    pub fn calculate_nullable_first_follow(&mut self) {
7        if let Some(start_idx) = self.start_symbol {
8            self.symbols[start_idx]
9                .mut_non_terminal()
10                .unwrap()
11                .follow
12                .insert(self.symbol_table[END_MARK]);
13            self.calculate_nullable();
14            self.calculate_first();
15            self.calculate_follow();
16
17            self.validate_nullable_first_follow();
18        }
19    }
20
21    pub fn reset_nullable_first_follow(&mut self) {
22        for nt in self.non_terminal_iter_mut() {
23            nt.nullable = false;
24            nt.first = HashSet::new();
25            nt.follow = HashSet::new();
26        }
27    }
28
29    fn calculate_nullable(&mut self) {
30        let mut changed = true;
31        while changed {
32            changed = false;
33            for i in 0..self.symbols.len() {
34                let nullable: bool = match &self.symbols[i] {
35                    Symbol::Terminal(_) => continue,
36                    Symbol::NonTerminal(nt) => {
37                        if nt.nullable {
38                            continue;
39                        }
40                        nt.productions.iter().any(|production| {
41                            production.iter().all(|s| match &self.symbols[*s] {
42                                Symbol::Terminal(_) => false,
43                                Symbol::NonTerminal(e) => e.nullable,
44                            })
45                        })
46                    }
47                };
48
49                if nullable {
50                    self.symbols[i].mut_non_terminal().unwrap().nullable = true;
51                    changed = true;
52                }
53            }
54        }
55    }
56
57    pub fn calculate_first_for_production(&self, production: &[usize]) -> HashSet<usize> {
58        let mut first: HashSet<usize> = HashSet::new();
59        for (idx, symbol) in production.iter().map(|i| (*i, &self.symbols[*i])) {
60            match symbol {
61                Symbol::Terminal(_) => {
62                    first.insert(idx);
63                    break;
64                }
65                Symbol::NonTerminal(nt) => {
66                    first.extend(nt.first.iter().cloned());
67                    if !nt.nullable {
68                        break;
69                    }
70                }
71            }
72        }
73        first
74    }
75
76    fn calculate_first(&mut self) {
77        let mut changed = true;
78        while changed {
79            changed = false;
80            for i in 0..self.symbols.len() {
81                let first: HashSet<usize> = match &self.symbols[i] {
82                    Symbol::Terminal(_) => continue,
83                    Symbol::NonTerminal(nt) => {
84                        nt.productions
85                            .iter()
86                            .fold(HashSet::new(), |mut first, production| {
87                                first.extend(
88                                    self.calculate_first_for_production(production).into_iter(),
89                                );
90                                first
91                            })
92                    }
93                };
94
95                let nt = self.symbols[i].mut_non_terminal().unwrap();
96                if nt.first.len() != first.len() {
97                    changed = true;
98                    nt.first = first;
99                }
100            }
101        }
102    }
103
104    pub fn calculate_follow_for_production(&self, production: &Vec<usize>) -> HashSet<usize> {
105        let mut follow = HashSet::new();
106        for idx in production.iter().rev() {
107            match &self.symbols[*idx] {
108                Symbol::Terminal(_) => {
109                    follow.insert(*idx);
110                    break;
111                }
112                Symbol::NonTerminal(nt) => {
113                    follow.extend(nt.follow.iter().cloned());
114                    if !nt.nullable {
115                        break;
116                    }
117                }
118            }
119        }
120        follow
121    }
122
123    fn calculate_follow(&mut self) {
124        let mut changed = true;
125        while changed {
126            changed = false;
127            for i in 0..self.symbols.len() {
128                if let Symbol::Terminal(_) = self.symbols[i] {
129                    continue;
130                }
131
132                let productions = self.symbols[i].non_terminal().unwrap().productions.clone();
133                for production in productions {
134                    let mut first: HashSet<usize> = HashSet::new();
135                    let mut left_follow =
136                        Some(self.symbols[i].non_terminal().unwrap().follow.clone());
137
138                    for i in (0..production.len()).rev() {
139                        match &mut self.symbols[production[i]] {
140                            Symbol::Terminal(_) => {
141                                first = HashSet::new();
142                                first.insert(production[i]);
143                                left_follow = None;
144                            }
145                            Symbol::NonTerminal(nt) => {
146                                let len = nt.follow.len();
147
148                                if let Some(left_follow) = &left_follow {
149                                    nt.follow.extend(left_follow.iter().cloned());
150                                }
151                                nt.follow.extend(first.iter().cloned());
152                                changed |= len != nt.follow.len();
153
154                                if !nt.nullable {
155                                    first = nt.first.clone();
156                                    left_follow = None;
157                                } else {
158                                    first.extend(nt.first.iter().cloned());
159                                }
160                            }
161                        }
162                    }
163                }
164            }
165        }
166    }
167}