1use crate::atn::{Atn, AtnState, AtnStateKind, AtnType, IntervalSet, LexerAction, Transition};
2use crate::errors::AntlrError;
3use crate::token::TOKEN_EOF;
4
/// Serialized-ATN format version understood by [`AtnDeserializer`]; input declaring any other
/// version is rejected before anything else is decoded.
pub const SERIALIZED_VERSION: i32 = 4;
6
/// Flat sequence of integers encoding an ATN, as consumed by [`AtnDeserializer`].
#[derive(Clone, Debug)]
pub struct SerializedAtn {
    values: Vec<i32>,
}

impl SerializedAtn {
    /// Wraps an already-decoded integer sequence without transforming it.
    pub fn from_i32(values: impl Into<Vec<i32>>) -> Self {
        let values: Vec<i32> = values.into();
        Self { values }
    }

    /// Builds the integer sequence from characters, using each character's Unicode scalar value
    /// as one entry (no further decoding is applied).
    pub fn from_chars(chars: impl IntoIterator<Item = char>) -> Self {
        let mut values = Vec::new();
        values.extend(chars.into_iter().map(|ch| ch as i32));
        Self { values }
    }

    /// Borrows the encoded integers.
    pub fn values(&self) -> &[i32] {
        &self.values
    }
}
41
/// Cursor-based reader that decodes a [`SerializedAtn`] into an [`Atn`].
#[derive(Debug)]
pub struct AtnDeserializer<'a> {
    /// Serialized values being decoded, borrowed from the source [`SerializedAtn`].
    values: &'a [i32],
    /// Position of the next value to read from `values`.
    cursor: usize,
}
48
49impl<'a> AtnDeserializer<'a> {
50 pub fn new(serialized: &'a SerializedAtn) -> Self {
52 Self {
53 values: serialized.values(),
54 cursor: 0,
55 }
56 }
57
58 pub fn deserialize(mut self) -> Result<Atn, AntlrError> {
67 let version = self.read("version")?;
68 if version != SERIALIZED_VERSION {
69 return Err(AntlrError::Unsupported(format!(
70 "serialized ATN version {version}; expected {SERIALIZED_VERSION}"
71 )));
72 }
73
74 let grammar_type = match self.read("grammar type")? {
75 0 => AtnType::Lexer,
76 1 => AtnType::Parser,
77 other => {
78 return Err(AntlrError::Unsupported(format!(
79 "serialized ATN grammar type {other}"
80 )));
81 }
82 };
83 let max_token_type = self.read("max token type")?;
84 let mut atn = Atn::new(grammar_type, max_token_type);
85
86 self.deserialize_states(&mut atn)?;
87 self.deserialize_non_greedy_states(&mut atn)?;
88 self.deserialize_precedence_states(&mut atn)?;
89 self.deserialize_rules(&mut atn)?;
90 self.deserialize_modes(&mut atn)?;
91 let sets = self.deserialize_sets()?;
92 self.deserialize_edges(&mut atn, &sets)?;
93 self.deserialize_decisions(&mut atn)?;
94 if grammar_type == AtnType::Lexer {
95 self.deserialize_lexer_actions(&mut atn)?;
96 }
97 mark_precedence_decisions(&mut atn);
98
99 Ok(atn)
100 }
101
102 fn deserialize_states(&mut self, atn: &mut Atn) -> Result<(), AntlrError> {
105 let state_count = self.read_usize("state count")?;
106 for state_number in 0..state_count {
107 let kind = decode_state_kind(self.read("state type")?)?;
108 if kind == AtnStateKind::Invalid {
109 atn.add_state(AtnState::new(state_number, kind));
110 continue;
111 }
112
113 let rule_index = self.read("rule index")?;
114 let mut state = AtnState::new(state_number, kind);
115 if rule_index >= 0 {
116 let rule_index = usize::try_from(rule_index).map_err(|_| {
117 AntlrError::Unsupported(format!("rule index cannot be negative: {rule_index}"))
118 })?;
119 state = state.with_rule_index(rule_index);
120 }
121
122 match kind {
123 AtnStateKind::LoopEnd => {
124 state.loop_back_state = Some(self.read_usize("loop back state")?);
125 }
126 AtnStateKind::BlockStart
127 | AtnStateKind::PlusBlockStart
128 | AtnStateKind::StarBlockStart => {
129 state.end_state = Some(self.read_usize("block end state")?);
130 }
131 _ => {}
132 }
133
134 atn.add_state(state);
135 }
136 Ok(())
137 }
138
139 fn deserialize_non_greedy_states(&mut self, atn: &mut Atn) -> Result<(), AntlrError> {
142 let count = self.read_usize("non-greedy state count")?;
143 for _ in 0..count {
144 let state_number = self.read_usize("non-greedy state")?;
145 let Some(state) = atn.state_mut(state_number) else {
146 return Err(AntlrError::Unsupported(format!(
147 "non-greedy state {state_number} outside state list"
148 )));
149 };
150 state.non_greedy = true;
151 }
152 Ok(())
153 }
154
155 fn deserialize_precedence_states(&mut self, atn: &mut Atn) -> Result<(), AntlrError> {
158 let count = self.read_usize("precedence state count")?;
159 for _ in 0..count {
160 let state_number = self.read_usize("precedence state")?;
161 let Some(state) = atn.state_mut(state_number) else {
162 return Err(AntlrError::Unsupported(format!(
163 "precedence state {state_number} outside state list"
164 )));
165 };
166 state.left_recursive_rule = true;
167 }
168 Ok(())
169 }
170
171 fn deserialize_rules(&mut self, atn: &mut Atn) -> Result<(), AntlrError> {
174 let rule_count = self.read_usize("rule count")?;
175 let mut starts = Vec::with_capacity(rule_count);
176 let mut token_types = Vec::new();
177 for _ in 0..rule_count {
178 starts.push(self.read_usize("rule start state")?);
179 if atn.grammar_type() == AtnType::Lexer {
180 token_types.push(self.read("rule token type")?);
181 }
182 }
183
184 let mut stops = vec![usize::MAX; rule_count];
185 for state in atn.states() {
186 if state.kind == AtnStateKind::RuleStop {
187 let Some(rule_index) = state.rule_index else {
188 continue;
189 };
190 if let Some(stop) = stops.get_mut(rule_index) {
191 *stop = state.state_number;
192 }
193 }
194 }
195
196 atn.set_rule_to_start_state(starts);
197 atn.set_rule_to_stop_state(stops);
198 atn.set_rule_to_token_type(token_types);
199 Ok(())
200 }
201
202 fn deserialize_modes(&mut self, atn: &mut Atn) -> Result<(), AntlrError> {
204 let mode_count = self.read_usize("mode count")?;
205 for _ in 0..mode_count {
206 atn.add_mode_start_state(self.read_usize("mode start state")?);
207 }
208 Ok(())
209 }
210
211 fn deserialize_sets(&mut self) -> Result<Vec<IntervalSet>, AntlrError> {
214 let set_count = self.read_usize("set count")?;
215 let mut sets = Vec::with_capacity(set_count);
216 for _ in 0..set_count {
217 let interval_count = self.read_usize("interval count")?;
218 let mut set = IntervalSet::new();
219 let contains_eof = self.read("set contains EOF")? != 0;
220 if contains_eof {
221 set.add(TOKEN_EOF);
222 }
223 for _ in 0..interval_count {
224 let start = self.read("interval start")?;
225 let stop = self.read("interval stop")?;
226 set.add_range(start, stop);
227 }
228 sets.push(set);
229 }
230 Ok(sets)
231 }
232
233 fn deserialize_edges(&mut self, atn: &mut Atn, sets: &[IntervalSet]) -> Result<(), AntlrError> {
235 let transition_count = self.read_usize("transition count")?;
236 for _ in 0..transition_count {
237 let src = self.read_usize("transition source")?;
238 let target = self.read_usize("transition target")?;
239 let kind = self.read("transition type")?;
240 let a = self.read("transition arg 1")?;
241 let b = self.read("transition arg 2")?;
242 let c = self.read("transition arg 3")?;
243 let transition = decode_transition(target, kind, a, b, c, sets)?;
244 let Some(state) = atn.state_mut(src) else {
245 return Err(AntlrError::Unsupported(format!(
246 "transition source {src} outside state list"
247 )));
248 };
249 state.add_transition(transition);
250 }
251
252 let mut return_edges = Vec::new();
253 for state in atn.states() {
254 for transition in &state.transitions {
255 let Transition::Rule {
256 target,
257 follow_state,
258 ..
259 } = transition
260 else {
261 continue;
262 };
263 let Some(rule_index) = atn.state(*target).and_then(|state| state.rule_index) else {
264 continue;
265 };
266 let Some(stop_state) = atn.rule_to_stop_state().get(rule_index).copied() else {
267 continue;
268 };
269 if stop_state != usize::MAX {
270 return_edges.push((stop_state, *follow_state));
271 }
272 }
273 }
274 for (stop_state, follow_state) in return_edges {
275 if let Some(state) = atn.state_mut(stop_state) {
276 state.add_transition(Transition::Epsilon {
277 target: follow_state,
278 });
279 }
280 }
281
282 Ok(())
283 }
284
285 fn deserialize_decisions(&mut self, atn: &mut Atn) -> Result<(), AntlrError> {
287 let decision_count = self.read_usize("decision count")?;
288 for _ in 0..decision_count {
289 atn.add_decision_state(self.read_usize("decision state")?);
290 }
291 Ok(())
292 }
293
294 fn deserialize_lexer_actions(&mut self, atn: &mut Atn) -> Result<(), AntlrError> {
297 let action_count = self.read_usize("lexer action count")?;
298 let mut actions = Vec::with_capacity(action_count);
299 for _ in 0..action_count {
300 let action_type = self.read("lexer action type")?;
301 let data1 = self.read("lexer action data 1")?;
302 let data2 = self.read("lexer action data 2")?;
303 actions.push(decode_lexer_action(action_type, data1, data2)?);
304 }
305 atn.set_lexer_actions(actions);
306 Ok(())
307 }
308
309 fn read(&mut self, label: &str) -> Result<i32, AntlrError> {
312 let value = self.values.get(self.cursor).copied().ok_or_else(|| {
313 AntlrError::Unsupported(format!("serialized ATN ended while reading {label}"))
314 })?;
315 self.cursor += 1;
316 Ok(value)
317 }
318
319 fn read_usize(&mut self, label: &str) -> Result<usize, AntlrError> {
321 let value = self.read(label)?;
322 usize::try_from(value)
323 .map_err(|_| AntlrError::Unsupported(format!("{label} cannot be negative: {value}")))
324 }
325}
326
327fn decode_state_kind(value: i32) -> Result<AtnStateKind, AntlrError> {
329 let kind = match value {
330 0 => AtnStateKind::Invalid,
331 1 => AtnStateKind::Basic,
332 2 => AtnStateKind::RuleStart,
333 3 => AtnStateKind::BlockStart,
334 4 => AtnStateKind::PlusBlockStart,
335 5 => AtnStateKind::StarBlockStart,
336 6 => AtnStateKind::TokenStart,
337 7 => AtnStateKind::RuleStop,
338 8 => AtnStateKind::BlockEnd,
339 9 => AtnStateKind::StarLoopBack,
340 10 => AtnStateKind::StarLoopEntry,
341 11 => AtnStateKind::PlusLoopBack,
342 12 => AtnStateKind::LoopEnd,
343 other => return Err(AntlrError::Unsupported(format!("ATN state type {other}"))),
344 };
345 Ok(kind)
346}
347
348fn decode_transition(
350 target: usize,
351 kind: i32,
352 a: i32,
353 b: i32,
354 c: i32,
355 sets: &[IntervalSet],
356) -> Result<Transition, AntlrError> {
357 let transition = match kind {
358 1 => Transition::Epsilon { target },
359 2 => Transition::Range {
360 target,
361 start: if c != 0 { TOKEN_EOF } else { a },
362 stop: b,
363 },
364 3 => Transition::Rule {
365 target: read_index(a, "rule transition target")?,
366 rule_index: read_index(b, "rule transition rule index")?,
367 follow_state: target,
368 precedence: c,
369 },
370 4 => Transition::Predicate {
371 target,
372 rule_index: read_index(a, "predicate rule index")?,
373 pred_index: read_index(b, "predicate index")?,
374 context_dependent: c != 0,
375 },
376 5 => Transition::Atom {
377 target,
378 label: if c != 0 { TOKEN_EOF } else { a },
379 },
380 6 => Transition::Action {
381 target,
382 rule_index: read_index(a, "action rule index")?,
383 action_index: usize::try_from(b).ok(),
384 context_dependent: c != 0,
385 },
386 7 => Transition::Set {
387 target,
388 set: sets
389 .get(read_index(a, "set transition set index")?)
390 .cloned()
391 .ok_or_else(|| {
392 AntlrError::Unsupported(format!("set index {a} outside set list"))
393 })?,
394 },
395 8 => Transition::NotSet {
396 target,
397 set: sets
398 .get(read_index(a, "not-set transition set index")?)
399 .cloned()
400 .ok_or_else(|| {
401 AntlrError::Unsupported(format!("set index {a} outside set list"))
402 })?,
403 },
404 9 => Transition::Wildcard { target },
405 10 => Transition::Precedence {
406 target,
407 precedence: a,
408 },
409 other => {
410 return Err(AntlrError::Unsupported(format!(
411 "ATN transition type {other}"
412 )));
413 }
414 };
415 Ok(transition)
416}
417
418fn decode_lexer_action(
421 action_type: i32,
422 data1: i32,
423 data2: i32,
424) -> Result<LexerAction, AntlrError> {
425 let action = match action_type {
426 0 => LexerAction::Channel(data1),
427 1 => LexerAction::Custom {
428 rule_index: data1,
429 action_index: data2,
430 },
431 2 => LexerAction::Mode(data1),
432 3 => LexerAction::More,
433 4 => LexerAction::PopMode,
434 5 => LexerAction::PushMode(data1),
435 6 => LexerAction::Skip,
436 7 => LexerAction::Type(data1),
437 other => {
438 return Err(AntlrError::Unsupported(format!(
439 "lexer action type {other}"
440 )));
441 }
442 };
443 Ok(action)
444}
445
446fn mark_precedence_decisions(atn: &mut Atn) {
448 let mut decisions = Vec::new();
449 for state in atn.states() {
450 if state.kind != AtnStateKind::StarLoopEntry {
451 continue;
452 }
453 let Some(rule_index) = state.rule_index else {
454 continue;
455 };
456 let Some(rule_start) = atn
457 .rule_to_start_state()
458 .get(rule_index)
459 .and_then(|state_number| atn.state(*state_number))
460 else {
461 continue;
462 };
463 if !rule_start.left_recursive_rule {
464 continue;
465 }
466 let Some(loop_end_state) = state
467 .transitions
468 .last()
469 .and_then(|transition| atn.state(transition.target()))
470 else {
471 continue;
472 };
473 if loop_end_state.kind != AtnStateKind::LoopEnd {
474 continue;
475 }
476 let Some(target) = loop_end_state
477 .transitions
478 .first()
479 .and_then(|transition| atn.state(transition.target()))
480 else {
481 continue;
482 };
483 if target.kind == AtnStateKind::RuleStop {
484 decisions.push(state.state_number);
485 }
486 }
487
488 for state_number in decisions {
489 if let Some(state) = atn.state_mut(state_number) {
490 state.precedence_rule_decision = true;
491 }
492 }
493}
494
495fn read_index(value: i32, label: &str) -> Result<usize, AntlrError> {
498 usize::try_from(value)
499 .map_err(|_| AntlrError::Unsupported(format!("{label} cannot be negative: {value}")))
500}
501
#[cfg(test)]
mod tests {
    use super::*;

    /// One parser rule whose start state reaches its stop state on token 42.
    #[test]
    fn reads_small_parser_atn() {
        let serialized = SerializedAtn::from_i32([
            4, // serialized version
            1, // grammar type: parser
            9, // max token type
            2, // state count
            2, 0, // state 0: rule start, rule 0
            7, 0, // state 1: rule stop, rule 0
            0, // non-greedy state count
            0, // precedence state count
            1, 0, // rule count, then rule 0's start state
            0, // mode count
            0, // set count
            1, // transition count
            0, 1, 5, 42, 0, 0, // state 0 -> state 1, atom on token 42
            1, 0, // decision count, then decision 0's state
        ]);

        let atn = AtnDeserializer::new(&serialized)
            .deserialize()
            .expect("artificial parser ATN should deserialize");

        assert_eq!(atn.grammar_type(), AtnType::Parser);
        assert_eq!(atn.max_token_type(), 9);
        assert_eq!(atn.states().len(), 2);
        assert_eq!(atn.rule_to_start_state(), &[0]);
        assert_eq!(atn.rule_to_stop_state(), &[1]);
        assert_eq!(atn.decision_to_state(), &[0]);
    }
}