compiler_course_helper/grammar/
lr_fsm.rs

1use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet, VecDeque};
2
3use serde::{Deserialize, Serialize};
4
5use crate::Grammar;
6
7use super::{grammar::Symbol, END_MARK, EPSILON};
8
/// A production with a dot position and an optional lookahead list.
///
/// The derived ordering compares fields in declaration order: `left`,
/// `production`, `position`, then `lookahead`.
#[derive(PartialEq, Eq, Hash, Debug, PartialOrd, Ord, Clone, Serialize)]
pub struct DotProduction {
    /// Name of the symbol on the left-hand side of the production.
    pub left: String,
    /// Right-hand side symbols, stored by name.
    pub production: Vec<String>,
    /// Index into `production` of the symbol right after the dot; equal to
    /// `production.len()` once the dot has moved past the last symbol.
    pub position: usize,
    /// Lookahead symbol names, or `None` when the item carries no lookahead.
    pub lookahead: Option<Vec<String>>,
}
16
17impl DotProduction {
18    pub fn new(left: String, production: Vec<String>, lookahead: Option<Vec<String>>) -> Self {
19        let mut i = 0;
20        while i < production.len() && production[i] == EPSILON {
21            i += 1;
22        }
23        Self {
24            left,
25            production,
26            position: i,
27            lookahead,
28        }
29    }
30
31    pub fn generate_next(&self) -> Self {
32        let mut i = self.position + 1;
33        while i < self.production.len() && self.production[i] == EPSILON {
34            i += 1;
35        }
36
37        Self {
38            left: self.left.clone(),
39            production: self.production.clone(),
40            position: i,
41            lookahead: self.lookahead.clone(),
42        }
43    }
44}
45
/// One state of the LR automaton: a set of dotted productions plus its
/// outgoing transitions.
#[derive(PartialEq, Eq, Debug, Clone, Serialize)]
pub struct LRItem {
    /// Items the state was created from; kept sorted by `LRItem::new`.
    pub kernel: Vec<DotProduction>,
    /// Items added by `calculate_extend`, derived from the kernel.
    pub extend: Vec<DotProduction>,
    /// Transition symbol name mapped to the index of the target state.
    pub edges: BTreeMap<String, usize>,
}
52
53impl LRItem {
54    fn calculate_extend(&mut self, g: &Grammar) {
55        let is_lr1 = self.kernel[0].lookahead.is_some();
56        let mut extend: HashMap<usize, Option<HashSet<usize>>> = HashMap::new();
57        let mut q: VecDeque<usize> = VecDeque::new();
58
59        let calculate_first = |production: &[String]| -> Vec<usize> {
60            g.calculate_first_for_production(
61                &production
62                    .iter()
63                    .map(|s| g.get_symbol_index(s).unwrap())
64                    .collect::<Vec<_>>(),
65            )
66            .into_iter()
67            .collect()
68        };
69
70        // use self.kernel to initialize self.extend
71        for c in &self.kernel {
72            if let Some(symbol) = c.production.get(c.position) {
73                if let Symbol::NonTerminal(nt) = g.get_symbol_by_name(symbol.as_str()) {
74                    if !extend.contains_key(&nt.index) {
75                        extend.insert(nt.index, if is_lr1 { Some(HashSet::new()) } else { None });
76                        q.push_back(nt.index);
77                    }
78
79                    if is_lr1 {
80                        let lookahead = if c.position + 1 < c.production.len() {
81                            calculate_first(&c.production[c.position + 1..])
82                        } else {
83                            c.lookahead
84                                .as_ref()
85                                .unwrap()
86                                .iter()
87                                .map(|s| g.get_symbol_index(s).unwrap())
88                                .collect()
89                        };
90                        extend
91                            .get_mut(&nt.index)
92                            .unwrap()
93                            .as_mut()
94                            .unwrap()
95                            .extend(lookahead.into_iter());
96                    }
97                }
98            }
99        }
100
101        // iteratively calculate self.extend
102        while let Some(s_idx) = q.pop_front() {
103            for production in &g.symbols[s_idx].non_terminal().unwrap().productions {
104                if let Symbol::NonTerminal(nt) = &g.symbols[production[0]] {
105                    if !extend.contains_key(&nt.index) {
106                        extend.insert(nt.index, if is_lr1 { Some(HashSet::new()) } else { None });
107                        q.push_back(nt.index);
108                    }
109
110                    if is_lr1 {
111                        let lookahead = if production.len() > 1 {
112                            g.calculate_first_for_production(&production[1..])
113                        } else {
114                            extend[&s_idx].as_ref().unwrap().clone()
115                        };
116                        extend
117                            .get_mut(&nt.index)
118                            .unwrap()
119                            .as_mut()
120                            .unwrap()
121                            .extend(lookahead);
122                    }
123                }
124            }
125        }
126
127        for (nt_idx, lookahead) in extend {
128            let nt = g.symbols[nt_idx].non_terminal().unwrap();
129
130            let lookahead: Option<Vec<String>> = lookahead.and_then(|lookahead| {
131                let mut lookahead = lookahead
132                    .iter()
133                    .map(|&i| g.get_symbol_name(i).to_string())
134                    .collect::<Vec<_>>();
135                lookahead.sort();
136                Some(lookahead)
137            });
138
139            for production in &nt.productions {
140                self.extend.push(DotProduction::new(
141                    nt.name.clone(),
142                    g.production_to_vec_str(production)
143                        .iter()
144                        .map(|s| s.to_string())
145                        .collect(),
146                    lookahead.clone(),
147                ));
148            }
149
150            self.extend.sort();
151        }
152    }
153}
154
155impl LRItem {
156    fn new(mut kernel: Vec<DotProduction>) -> Self {
157        kernel.sort();
158        Self {
159            kernel,
160            extend: Vec::new(),
161            edges: BTreeMap::new(),
162        }
163    }
164
165    fn core_eq(&self, rhs: &LRItem) -> bool {
166        if self.kernel.len() != rhs.kernel.len() || self.extend.len() != rhs.extend.len() {
167            return false;
168        }
169        let a = self.kernel.iter().chain(self.extend.iter());
170        let b = rhs.kernel.iter().chain(rhs.extend.iter());
171        a.zip(b).all(|(x, y)| {
172            x.left == y.left && x.production == y.production && x.position == y.position
173        })
174    }
175}
176
/// Selects which kind of LR automaton `Grammar::to_lr_fsm` builds.
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
pub enum LRFSMType {
    /// Items carry no lookahead; the resulting `LRFSM` also exposes FOLLOW sets.
    LR0,
    /// Items carry lookaheads; states are kept as constructed.
    LR1,
    /// Built like `LR1`, then states with equal cores are merged.
    LALR,
}
183
/// The LR state machine produced by `Grammar::to_lr_fsm`.
#[derive(Debug, Serialize)]
pub struct LRFSM {
    /// Kind of automaton that was requested.
    pub t: LRFSMType,
    /// Terminal names; their positions define the action table columns.
    pub(super) terminals: Vec<String>,
    /// Non-terminal names; their positions define the goto table columns.
    pub(super) non_terminals: Vec<String>,

    /// All states; `edges` inside each state refer to indices of this vector.
    pub states: Vec<LRItem>,
    /// Index of the initial state.
    pub start: usize,
    /// Index of the state holding the completed augmented start production.
    pub end: usize,
    /// FOLLOW sets by non-terminal name; only populated for `LRFSMType::LR0`.
    pub follow: Option<HashMap<String, Vec<String>>>,
}
195
196impl Grammar {
197    pub fn to_lr_fsm(&mut self, t: LRFSMType) -> Result<LRFSM, String> {
198        if self.start_symbol.is_none() {
199            return Err("start symbol is not set".to_string());
200        }
201
202        if t == LRFSMType::LR0 && !self.is_nullable_first_follow_valid() {
203            self.calculate_nullable_first_follow();
204        }
205
206        let real_start = self.get_symbol_name(self.start_symbol.unwrap()).to_string();
207        let dummy_start = self.get_symbol_prime_name(real_start.clone());
208        let mut start_state = LRItem::new(vec![DotProduction::new(
209            dummy_start.clone(),
210            vec![real_start],
211            if t == LRFSMType::LR1 || t == LRFSMType::LALR {
212                Some(vec![END_MARK.to_string()])
213            } else {
214                None
215            },
216        )]);
217        start_state.calculate_extend(self);
218        let mut states = vec![start_state];
219        let mut q: VecDeque<usize> = VecDeque::new();
220        q.push_back(0);
221
222        let mut end: usize = 0;
223
224        while let Some(u) = q.pop_front() {
225            let mut edges: BTreeMap<String, BTreeSet<DotProduction>> = BTreeMap::new();
226
227            let productions = states[u].kernel.iter().chain(states[u].extend.iter());
228            for production in productions {
229                if production.production.len() == 1
230                    && production.position == 1
231                    && production.left == dummy_start
232                {
233                    end = u;
234                }
235
236                if production.position < production.production.len() {
237                    let e = production.production[production.position].clone();
238                    let item = edges.entry(e).or_insert(BTreeSet::new());
239                    item.insert(production.generate_next());
240                }
241            }
242
243            for (e, kernel) in edges {
244                let mut s = LRItem::new(kernel.into_iter().collect());
245                s.calculate_extend(self);
246
247                let mut entry_or_insert = |s: LRItem| {
248                    for (i, state) in states.iter().enumerate() {
249                        if state.kernel == s.kernel && state.extend == s.extend {
250                            return i;
251                        }
252                    }
253                    states.push(s);
254                    q.push_back(states.len() - 1);
255                    states.len() - 1
256                };
257
258                let v_idx = entry_or_insert(s);
259                states[u].edges.insert(e.clone(), v_idx);
260            }
261        }
262
263        if t == LRFSMType::LALR {
264            let mut new_id: Vec<Option<usize>> = vec![None; states.len()];
265            let mut cnt: usize = 0;
266            for i in 0..states.len() {
267                if new_id[i].is_some() {
268                    continue;
269                }
270                let id = cnt;
271                cnt += 1;
272                new_id[i] = Some(id);
273                for j in i + 1..states.len() {
274                    if states[i].core_eq(&states[j]) {
275                        assert_eq!(new_id[j], None);
276                        new_id[j] = Some(id);
277                    }
278                }
279            }
280
281            let mut new_states: Vec<Vec<LRItem>> = vec![Vec::new(); cnt];
282            for (i, s) in states.into_iter().enumerate() {
283                new_states[new_id[i].unwrap()].push(s);
284            }
285
286            states = new_states
287                .into_iter()
288                .map(|mut arr| {
289                    for (_, v) in arr[0].edges.iter_mut() {
290                        *v = new_id[*v].unwrap();
291                    }
292
293                    arr.into_iter()
294                        .reduce(|mut accum, s| {
295                            for (x, y) in accum
296                                .kernel
297                                .iter_mut()
298                                .chain(accum.extend.iter_mut())
299                                .zip(s.kernel.iter().chain(s.extend.iter()))
300                            {
301                                x.lookahead
302                                    .as_mut()
303                                    .unwrap()
304                                    .extend(y.lookahead.as_ref().unwrap().iter().cloned());
305                                x.lookahead.as_mut().unwrap().sort();
306                                x.lookahead.as_mut().unwrap().dedup();
307                            }
308
309                            for (e, v) in s.edges {
310                                let to = accum.edges.entry(e).or_insert(new_id[v].unwrap());
311                                assert_eq!(*to, new_id[v].unwrap());
312                            }
313
314                            accum
315                        })
316                        .unwrap()
317                })
318                .collect();
319        }
320
321        Ok(LRFSM {
322            t,
323            terminals: self.terminal_iter().cloned().collect(),
324            non_terminals: self.non_terminal_iter().map(|nt| nt.name.clone()).collect(),
325            states,
326            start: 0,
327            end,
328            follow: if t == LRFSMType::LR0 {
329                let mut r: HashMap<String, Vec<String>> = HashMap::new();
330                r.insert(dummy_start, vec![END_MARK.to_string()]);
331                for nt in self.non_terminal_iter() {
332                    r.insert(
333                        nt.name.clone(),
334                        nt.follow
335                            .iter()
336                            .map(|i| self.get_symbol_name(*i).to_string())
337                            .collect(),
338                    );
339                }
340                Some(r)
341            } else {
342                None
343            },
344        })
345    }
346}
347
/// A single entry of an action table cell.
#[derive(Debug, Clone, Serialize)]
pub enum LRParsingTableAction {
    /// Shift and move to the state with this index.
    Shift(usize),
    /// Reduce by the production given as (left-hand side name, right-hand side symbols).
    Reduce((String, Vec<String>)),
    /// Accept the input.
    Accept,
}
354
/// Action and goto tables derived from an `LRFSM`.
#[derive(Serialize)]
pub struct LRParsingTable {
    /// Kind of automaton the table was built from.
    pub t: LRFSMType,
    /// Terminal names; the position of a name is its action table column.
    pub terminals: Vec<String>,
    /// Non-terminal names; the position of a name is its goto table column.
    pub non_terminals: Vec<String>,
    /// `action[state][terminal]` lists every action recorded for that pair;
    /// a cell may hold zero, one or several entries.
    pub action: Vec<Vec<Vec<LRParsingTableAction>>>,
    /// `goto[state][non_terminal]` is the target state index, if any.
    pub goto: Vec<Vec<Option<usize>>>,
}
363
364impl LRFSM {
365    pub fn to_parsing_table(&self) -> LRParsingTable {
366        let dummy_start = &self.states[0].kernel[0].left;
367
368        let mut terminal_idx_map: HashMap<&str, usize> = HashMap::new();
369        for (i, s) in self.terminals.iter().enumerate() {
370            terminal_idx_map.insert(s, i);
371        }
372
373        let mut non_terminal_idx_map: HashMap<&str, usize> = HashMap::new();
374        for (i, s) in self.non_terminals.iter().enumerate() {
375            non_terminal_idx_map.insert(s, i);
376        }
377
378        let mut table = LRParsingTable {
379            t: self.t,
380            terminals: self.terminals.clone(),
381            non_terminals: self.non_terminals.clone(),
382            action: Vec::new(),
383            goto: Vec::new(),
384        };
385
386        for state in &self.states {
387            let mut action_row: Vec<Vec<LRParsingTableAction>> =
388                vec![Vec::new(); self.terminals.len()];
389            let mut goto_row: Vec<Option<usize>> = vec![None; self.non_terminals.len()];
390            for prodcution in state.kernel.iter().chain(state.extend.iter()) {
391                if prodcution.production.len() == prodcution.position {
392                    if &prodcution.left == dummy_start {
393                        action_row[terminal_idx_map[END_MARK]].push(LRParsingTableAction::Accept);
394                        continue;
395                    }
396
397                    let lookahead = if let Some(lookahead) = &prodcution.lookahead {
398                        lookahead
399                    } else {
400                        &self.follow.as_ref().unwrap()[&prodcution.left]
401                    };
402                    for terminal in lookahead {
403                        action_row[terminal_idx_map[terminal.as_str()]].push(
404                            LRParsingTableAction::Reduce((
405                                prodcution.left.clone(),
406                                prodcution.production.clone(),
407                            )),
408                        );
409                    }
410                }
411            }
412            for (e, v) in &state.edges {
413                if let Some(idx) = terminal_idx_map.get(e.as_str()) {
414                    action_row[*idx].push(LRParsingTableAction::Shift(*v));
415                }
416                if let Some(idx) = non_terminal_idx_map.get(e.as_str()) {
417                    goto_row[*idx] = Some(*v);
418                }
419            }
420            table.action.push(action_row);
421            table.goto.push(goto_row);
422        }
423
424        table
425    }
426}