1use std::borrow::Cow;
2
3use crate::atn::{Atn, AtnState, AtnStateKind, AtnType, IntervalSet, LexerAction, Transition};
4use crate::errors::AntlrError;
5use crate::token::TOKEN_EOF;
6
/// Serialized ATN format version that [`AtnDeserializer::deserialize`] accepts.
///
/// Streams whose first value differs from this constant are rejected.
pub const SERIALIZED_VERSION: i32 = 4;
8
/// Flat sequence of `i32` values holding a serialized ATN.
///
/// The values are either borrowed from the caller or owned by the wrapper itself.
#[derive(Clone, Debug)]
pub struct SerializedAtn<'a> {
    values: Cow<'a, [i32]>,
}

impl<'a> SerializedAtn<'a> {
    /// Wraps an existing slice of serialized values without copying it.
    pub const fn from_i32(values: &'a [i32]) -> Self {
        SerializedAtn {
            values: Cow::Borrowed(values),
        }
    }

    /// Builds an owned serialized ATN by casting every character to `i32`.
    pub fn from_chars(chars: impl IntoIterator<Item = char>) -> SerializedAtn<'static> {
        let decoded: Vec<i32> = chars.into_iter().map(|ch| ch as i32).collect();
        SerializedAtn {
            values: Cow::Owned(decoded),
        }
    }

    /// Returns the serialized values as a slice.
    pub fn values(&self) -> &[i32] {
        &*self.values
    }
}
43
44#[derive(Debug)]
46pub struct AtnDeserializer<'a> {
47 values: &'a [i32],
48 cursor: usize,
49}
50
51impl<'a> AtnDeserializer<'a> {
52 pub fn new(serialized: &'a SerializedAtn<'_>) -> Self {
54 Self {
55 values: serialized.values(),
56 cursor: 0,
57 }
58 }
59
60 pub fn deserialize(mut self) -> Result<Atn, AntlrError> {
69 let version = self.read("version")?;
70 if version != SERIALIZED_VERSION {
71 return Err(AntlrError::Unsupported(format!(
72 "serialized ATN version {version}; expected {SERIALIZED_VERSION}"
73 )));
74 }
75
76 let grammar_type = match self.read("grammar type")? {
77 0 => AtnType::Lexer,
78 1 => AtnType::Parser,
79 other => {
80 return Err(AntlrError::Unsupported(format!(
81 "serialized ATN grammar type {other}"
82 )));
83 }
84 };
85 let max_token_type = self.read("max token type")?;
86 let mut atn = Atn::new(grammar_type, max_token_type);
87
88 self.deserialize_states(&mut atn)?;
89 self.deserialize_non_greedy_states(&mut atn)?;
90 self.deserialize_precedence_states(&mut atn)?;
91 self.deserialize_rules(&mut atn)?;
92 self.deserialize_modes(&mut atn)?;
93 let sets = self.deserialize_sets()?;
94 self.deserialize_edges(&mut atn, &sets)?;
95 self.deserialize_decisions(&mut atn)?;
96 if grammar_type == AtnType::Lexer {
97 self.deserialize_lexer_actions(&mut atn)?;
98 }
99 mark_precedence_decisions(&mut atn);
100
101 Ok(atn)
102 }
103
104 fn deserialize_states(&mut self, atn: &mut Atn) -> Result<(), AntlrError> {
107 let state_count = self.read_usize("state count")?;
108 for state_number in 0..state_count {
109 let kind = decode_state_kind(self.read("state type")?)?;
110 if kind == AtnStateKind::Invalid {
111 atn.add_state(AtnState::new(state_number, kind));
112 continue;
113 }
114
115 let rule_index = self.read("rule index")?;
116 let mut state = AtnState::new(state_number, kind);
117 if rule_index >= 0 {
118 let rule_index = usize::try_from(rule_index).map_err(|_| {
119 AntlrError::Unsupported(format!("rule index cannot be negative: {rule_index}"))
120 })?;
121 state = state.with_rule_index(rule_index);
122 }
123
124 match kind {
125 AtnStateKind::LoopEnd => {
126 state.loop_back_state = Some(self.read_usize("loop back state")?);
127 }
128 AtnStateKind::BlockStart
129 | AtnStateKind::PlusBlockStart
130 | AtnStateKind::StarBlockStart => {
131 state.end_state = Some(self.read_usize("block end state")?);
132 }
133 _ => {}
134 }
135
136 atn.add_state(state);
137 }
138 Ok(())
139 }
140
141 fn deserialize_non_greedy_states(&mut self, atn: &mut Atn) -> Result<(), AntlrError> {
144 let count = self.read_usize("non-greedy state count")?;
145 for _ in 0..count {
146 let state_number = self.read_usize("non-greedy state")?;
147 let Some(state) = atn.state_mut(state_number) else {
148 return Err(AntlrError::Unsupported(format!(
149 "non-greedy state {state_number} outside state list"
150 )));
151 };
152 state.non_greedy = true;
153 }
154 Ok(())
155 }
156
157 fn deserialize_precedence_states(&mut self, atn: &mut Atn) -> Result<(), AntlrError> {
160 let count = self.read_usize("precedence state count")?;
161 for _ in 0..count {
162 let state_number = self.read_usize("precedence state")?;
163 let Some(state) = atn.state_mut(state_number) else {
164 return Err(AntlrError::Unsupported(format!(
165 "precedence state {state_number} outside state list"
166 )));
167 };
168 state.left_recursive_rule = true;
169 }
170 Ok(())
171 }
172
173 fn deserialize_rules(&mut self, atn: &mut Atn) -> Result<(), AntlrError> {
176 let rule_count = self.read_usize("rule count")?;
177 let mut starts = Vec::with_capacity(rule_count);
178 let mut token_types = Vec::new();
179 for _ in 0..rule_count {
180 starts.push(self.read_usize("rule start state")?);
181 if atn.grammar_type() == AtnType::Lexer {
182 token_types.push(self.read("rule token type")?);
183 }
184 }
185
186 let mut stops = vec![usize::MAX; rule_count];
187 for state in atn.states() {
188 if state.kind == AtnStateKind::RuleStop {
189 let Some(rule_index) = state.rule_index else {
190 continue;
191 };
192 if let Some(stop) = stops.get_mut(rule_index) {
193 *stop = state.state_number;
194 }
195 }
196 }
197
198 atn.set_rule_to_start_state(starts);
199 atn.set_rule_to_stop_state(stops);
200 atn.set_rule_to_token_type(token_types);
201 Ok(())
202 }
203
204 fn deserialize_modes(&mut self, atn: &mut Atn) -> Result<(), AntlrError> {
206 let mode_count = self.read_usize("mode count")?;
207 for _ in 0..mode_count {
208 atn.add_mode_start_state(self.read_usize("mode start state")?);
209 }
210 Ok(())
211 }
212
213 fn deserialize_sets(&mut self) -> Result<Vec<IntervalSet>, AntlrError> {
216 let set_count = self.read_usize("set count")?;
217 let mut sets = Vec::with_capacity(set_count);
218 for _ in 0..set_count {
219 let interval_count = self.read_usize("interval count")?;
220 let mut set = IntervalSet::new();
221 let contains_eof = self.read("set contains EOF")? != 0;
222 if contains_eof {
223 set.add(TOKEN_EOF);
224 }
225 for _ in 0..interval_count {
226 let start = self.read("interval start")?;
227 let stop = self.read("interval stop")?;
228 set.add_range(start, stop);
229 }
230 sets.push(set);
231 }
232 Ok(sets)
233 }
234
235 fn deserialize_edges(&mut self, atn: &mut Atn, sets: &[IntervalSet]) -> Result<(), AntlrError> {
237 let transition_count = self.read_usize("transition count")?;
238 for _ in 0..transition_count {
239 let src = self.read_usize("transition source")?;
240 let target = self.read_usize("transition target")?;
241 let kind = self.read("transition type")?;
242 let a = self.read("transition arg 1")?;
243 let b = self.read("transition arg 2")?;
244 let c = self.read("transition arg 3")?;
245 let transition = decode_transition(target, kind, a, b, c, sets)?;
246 let Some(state) = atn.state_mut(src) else {
247 return Err(AntlrError::Unsupported(format!(
248 "transition source {src} outside state list"
249 )));
250 };
251 state.add_transition(transition);
252 }
253
254 let mut return_edges = Vec::new();
255 for state in atn.states() {
256 for transition in &state.transitions {
257 let Transition::Rule {
258 target,
259 follow_state,
260 ..
261 } = transition
262 else {
263 continue;
264 };
265 let Some(rule_index) = atn.state(*target).and_then(|state| state.rule_index) else {
266 continue;
267 };
268 let Some(stop_state) = atn.rule_to_stop_state().get(rule_index).copied() else {
269 continue;
270 };
271 if stop_state != usize::MAX {
272 return_edges.push((stop_state, *follow_state));
273 }
274 }
275 }
276 for (stop_state, follow_state) in return_edges {
277 if let Some(state) = atn.state_mut(stop_state) {
278 state.add_transition(Transition::Epsilon {
279 target: follow_state,
280 });
281 }
282 }
283
284 Ok(())
285 }
286
287 fn deserialize_decisions(&mut self, atn: &mut Atn) -> Result<(), AntlrError> {
289 let decision_count = self.read_usize("decision count")?;
290 for _ in 0..decision_count {
291 atn.add_decision_state(self.read_usize("decision state")?);
292 }
293 Ok(())
294 }
295
296 fn deserialize_lexer_actions(&mut self, atn: &mut Atn) -> Result<(), AntlrError> {
299 let action_count = self.read_usize("lexer action count")?;
300 let mut actions = Vec::with_capacity(action_count);
301 for _ in 0..action_count {
302 let action_type = self.read("lexer action type")?;
303 let data1 = self.read("lexer action data 1")?;
304 let data2 = self.read("lexer action data 2")?;
305 actions.push(decode_lexer_action(action_type, data1, data2)?);
306 }
307 atn.set_lexer_actions(actions);
308 Ok(())
309 }
310
311 fn read(&mut self, label: &str) -> Result<i32, AntlrError> {
314 let value = self.values.get(self.cursor).copied().ok_or_else(|| {
315 AntlrError::Unsupported(format!("serialized ATN ended while reading {label}"))
316 })?;
317 self.cursor += 1;
318 Ok(value)
319 }
320
321 fn read_usize(&mut self, label: &str) -> Result<usize, AntlrError> {
323 let value = self.read(label)?;
324 usize::try_from(value)
325 .map_err(|_| AntlrError::Unsupported(format!("{label} cannot be negative: {value}")))
326 }
327}
328
329fn decode_state_kind(value: i32) -> Result<AtnStateKind, AntlrError> {
331 let kind = match value {
332 0 => AtnStateKind::Invalid,
333 1 => AtnStateKind::Basic,
334 2 => AtnStateKind::RuleStart,
335 3 => AtnStateKind::BlockStart,
336 4 => AtnStateKind::PlusBlockStart,
337 5 => AtnStateKind::StarBlockStart,
338 6 => AtnStateKind::TokenStart,
339 7 => AtnStateKind::RuleStop,
340 8 => AtnStateKind::BlockEnd,
341 9 => AtnStateKind::StarLoopBack,
342 10 => AtnStateKind::StarLoopEntry,
343 11 => AtnStateKind::PlusLoopBack,
344 12 => AtnStateKind::LoopEnd,
345 other => return Err(AntlrError::Unsupported(format!("ATN state type {other}"))),
346 };
347 Ok(kind)
348}
349
350fn decode_transition(
352 target: usize,
353 kind: i32,
354 a: i32,
355 b: i32,
356 c: i32,
357 sets: &[IntervalSet],
358) -> Result<Transition, AntlrError> {
359 let transition = match kind {
360 1 => Transition::Epsilon { target },
361 2 => Transition::Range {
362 target,
363 start: if c != 0 { TOKEN_EOF } else { a },
364 stop: b,
365 },
366 3 => Transition::Rule {
367 target: read_index(a, "rule transition target")?,
368 rule_index: read_index(b, "rule transition rule index")?,
369 follow_state: target,
370 precedence: c,
371 },
372 4 => Transition::Predicate {
373 target,
374 rule_index: read_index(a, "predicate rule index")?,
375 pred_index: read_index(b, "predicate index")?,
376 context_dependent: c != 0,
377 },
378 5 => Transition::Atom {
379 target,
380 label: if c != 0 { TOKEN_EOF } else { a },
381 },
382 6 => Transition::Action {
383 target,
384 rule_index: read_index(a, "action rule index")?,
385 action_index: usize::try_from(b).ok(),
386 context_dependent: c != 0,
387 },
388 7 => Transition::Set {
389 target,
390 set: sets
391 .get(read_index(a, "set transition set index")?)
392 .cloned()
393 .ok_or_else(|| {
394 AntlrError::Unsupported(format!("set index {a} outside set list"))
395 })?,
396 },
397 8 => Transition::NotSet {
398 target,
399 set: sets
400 .get(read_index(a, "not-set transition set index")?)
401 .cloned()
402 .ok_or_else(|| {
403 AntlrError::Unsupported(format!("set index {a} outside set list"))
404 })?,
405 },
406 9 => Transition::Wildcard { target },
407 10 => Transition::Precedence {
408 target,
409 precedence: a,
410 },
411 other => {
412 return Err(AntlrError::Unsupported(format!(
413 "ATN transition type {other}"
414 )));
415 }
416 };
417 Ok(transition)
418}
419
420fn decode_lexer_action(
423 action_type: i32,
424 data1: i32,
425 data2: i32,
426) -> Result<LexerAction, AntlrError> {
427 let action = match action_type {
428 0 => LexerAction::Channel(data1),
429 1 => LexerAction::Custom {
430 rule_index: data1,
431 action_index: data2,
432 },
433 2 => LexerAction::Mode(data1),
434 3 => LexerAction::More,
435 4 => LexerAction::PopMode,
436 5 => LexerAction::PushMode(data1),
437 6 => LexerAction::Skip,
438 7 => LexerAction::Type(data1),
439 other => {
440 return Err(AntlrError::Unsupported(format!(
441 "lexer action type {other}"
442 )));
443 }
444 };
445 Ok(action)
446}
447
448fn mark_precedence_decisions(atn: &mut Atn) {
450 let mut decisions = Vec::new();
451 for state in atn.states() {
452 if state.kind != AtnStateKind::StarLoopEntry {
453 continue;
454 }
455 let Some(rule_index) = state.rule_index else {
456 continue;
457 };
458 let Some(rule_start) = atn
459 .rule_to_start_state()
460 .get(rule_index)
461 .and_then(|state_number| atn.state(*state_number))
462 else {
463 continue;
464 };
465 if !rule_start.left_recursive_rule {
466 continue;
467 }
468 let Some(loop_end_state) = state
469 .transitions
470 .last()
471 .and_then(|transition| atn.state(transition.target()))
472 else {
473 continue;
474 };
475 if loop_end_state.kind != AtnStateKind::LoopEnd {
476 continue;
477 }
478 let Some(target) = loop_end_state
479 .transitions
480 .first()
481 .and_then(|transition| atn.state(transition.target()))
482 else {
483 continue;
484 };
485 if target.kind == AtnStateKind::RuleStop {
486 decisions.push(state.state_number);
487 }
488 }
489
490 for state_number in decisions {
491 if let Some(state) = atn.state_mut(state_number) {
492 state.precedence_rule_decision = true;
493 }
494 }
495}
496
497fn read_index(value: i32, label: &str) -> Result<usize, AntlrError> {
500 usize::try_from(value)
501 .map_err(|_| AntlrError::Unsupported(format!("{label} cannot be negative: {value}")))
502}
503
#[cfg(test)]
mod tests {
    use super::*;

    /// Deserializes a hand-written two-state parser ATN and checks the tables built from it.
    #[test]
    fn reads_small_parser_atn() {
        let data: [i32; 23] = [
            // version, grammar type (parser), max token type
            4, 1, 9,
            // two states: a rule start and a rule stop, both in rule 0
            2, 2, 0, 7, 0,
            // no non-greedy states, no precedence states
            0, 0,
            // one rule starting at state 0
            1, 0,
            // no modes, no sets
            0, 0,
            // one atom transition from state 0 to state 1 labelled 42
            1, 0, 1, 5, 42, 0, 0,
            // one decision at state 0
            1, 0,
        ];
        let serialized = SerializedAtn::from_i32(&data);
        let deserializer = AtnDeserializer::new(&serialized);
        let atn = deserializer
            .deserialize()
            .expect("artificial parser ATN should deserialize");

        assert_eq!(atn.grammar_type(), AtnType::Parser);
        assert_eq!(atn.max_token_type(), 9);
        assert_eq!(atn.states().len(), 2);
        assert_eq!(atn.rule_to_start_state(), &[0]);
        assert_eq!(atn.rule_to_stop_state(), &[1]);
        assert_eq!(atn.decision_to_state(), &[0]);
    }
}