1use std::collections::BTreeSet;
2use std::rc::Rc;
3
/// Sentinel return state associated with [`PredictionContext::Empty`].
///
/// `PredictionContext::singleton` maps this value to the empty context, and
/// the empty context reports it as its only return state. Because it is
/// `usize::MAX`, it sorts after every other return state.
pub const EMPTY_RETURN_STATE: usize = usize::MAX;
5
/// A reference-counted graph node pairing return states with parent contexts.
///
/// Nodes are shared through [`Rc`], so several contexts may point at the same
/// parent. The derived `Ord` compares variants in declaration order and then
/// their fields in declaration order; `PredictionContext::merge` uses that
/// ordering as a tie-breaker between entries with equal return states.
#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum PredictionContext {
    /// The context without a parent. The accessors report it as one entry
    /// whose return state is [`EMPTY_RETURN_STATE`] and whose parent is `None`.
    Empty,
    /// A context holding exactly one `(parent, return_state)` entry.
    Singleton {
        /// The context this entry leads back to.
        parent: Rc<Self>,
        /// The return state paired with `parent`.
        return_state: usize,
    },
    /// A context holding several entries stored as two parallel vectors:
    /// entry `i` pairs `parents[i]` with `return_states[i]`.
    Array {
        /// Parent of each entry, index-aligned with `return_states`.
        parents: Vec<Rc<Self>>,
        /// Return state of each entry, index-aligned with `parents`.
        return_states: Vec<usize>,
    },
}
18
19impl PredictionContext {
20 pub fn empty() -> Rc<Self> {
21 Rc::new(Self::Empty)
22 }
23
24 pub fn singleton(parent: Rc<Self>, return_state: usize) -> Rc<Self> {
25 if return_state == EMPTY_RETURN_STATE {
26 Self::empty()
27 } else {
28 Rc::new(Self::Singleton {
29 parent,
30 return_state,
31 })
32 }
33 }
34
35 pub const fn len(&self) -> usize {
36 match self {
37 Self::Empty => 1,
38 Self::Singleton { .. } => 1,
39 Self::Array { return_states, .. } => return_states.len(),
40 }
41 }
42
43 pub const fn is_empty(&self) -> bool {
44 matches!(self, Self::Empty)
45 }
46
47 pub fn return_state(&self, index: usize) -> Option<usize> {
48 match self {
49 Self::Empty if index == 0 => Some(EMPTY_RETURN_STATE),
50 Self::Singleton { return_state, .. } if index == 0 => Some(*return_state),
51 Self::Array { return_states, .. } => return_states.get(index).copied(),
52 Self::Empty => None,
53 Self::Singleton { .. } => None,
54 }
55 }
56
57 pub fn parent(&self, index: usize) -> Option<Rc<Self>> {
58 match self {
59 Self::Empty => None,
60 Self::Singleton { parent, .. } if index == 0 => Some(Rc::clone(parent)),
61 Self::Array { parents, .. } => parents.get(index).cloned(),
62 Self::Singleton { .. } => None,
63 }
64 }
65
66 pub fn merge(left: Rc<Self>, right: Rc<Self>) -> Rc<Self> {
73 if left == right {
74 return left;
75 }
76 let mut entries = Vec::new();
77 collect_entries(&left, &mut entries);
78 collect_entries(&right, &mut entries);
79 drop((left, right));
80 entries.sort_by(|(left_parent, left_return), (right_parent, right_return)| {
81 left_return
82 .cmp(right_return)
83 .then_with(|| left_parent.cmp(right_parent))
84 });
85 entries.dedup_by(|a, b| a.1 == b.1 && a.0 == b.0);
86 Rc::new(Self::Array {
87 parents: entries
88 .iter()
89 .map(|(parent, _)| Rc::clone(parent))
90 .collect(),
91 return_states: entries
92 .iter()
93 .map(|(_, return_state)| *return_state)
94 .collect(),
95 })
96 }
97}
98
99fn collect_entries(
100 context: &Rc<PredictionContext>,
101 entries: &mut Vec<(Rc<PredictionContext>, usize)>,
102) {
103 match context.as_ref() {
104 PredictionContext::Empty => entries.push((Rc::clone(context), EMPTY_RETURN_STATE)),
105 PredictionContext::Singleton {
106 parent,
107 return_state,
108 } => entries.push((Rc::clone(parent), *return_state)),
109 PredictionContext::Array {
110 parents,
111 return_states,
112 } => {
113 for (parent, return_state) in parents.iter().zip(return_states) {
114 entries.push((Rc::clone(parent), *return_state));
115 }
116 }
117 }
118}
119
/// A single configuration: a state, an alternative and a prediction context.
///
/// Equality, hashing and ordering are derived over all four fields, so two
/// configurations that differ only in `reaches_into_outer_context` are
/// distinct values.
#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct AtnConfig {
    /// State number of this configuration.
    pub state: usize,
    /// Alternative number of this configuration.
    pub alt: usize,
    /// Shared prediction context attached to this configuration.
    pub context: Rc<PredictionContext>,
    /// Counter that `AtnConfig::new` initializes to `0`; a value above zero
    /// makes `AtnConfigSet::add` flag the set as dipping into outer context.
    pub reaches_into_outer_context: usize,
}
127
128impl AtnConfig {
129 pub const fn new(state: usize, alt: usize, context: Rc<PredictionContext>) -> Self {
130 Self {
131 state,
132 alt,
133 context,
134 reaches_into_outer_context: 0,
135 }
136 }
137}
138
/// An insertion-ordered collection of distinct [`AtnConfig`] values.
///
/// The derived `Default` yields an empty set with every flag `false`.
#[derive(Clone, Debug, Default, Eq, Ord, PartialEq, PartialOrd)]
pub struct AtnConfigSet {
    /// Accepted configurations, in the order they were added.
    configs: Vec<AtnConfig>,
    /// Copy of every accepted configuration, used to reject duplicates.
    config_index: BTreeSet<AtnConfig>,
    /// Flag that, in the visible code, is changed only through
    /// `set_has_semantic_context`.
    has_semantic_context: bool,
    /// Set once a configuration with a non-zero `reaches_into_outer_context`
    /// has been added.
    dips_into_outer_context: bool,
    /// While `true`, `add` panics instead of mutating the set.
    readonly: bool,
}
147
148impl AtnConfigSet {
149 pub fn new() -> Self {
150 Self::default()
151 }
152
153 pub fn add(&mut self, config: AtnConfig) -> bool {
156 assert!(!self.readonly, "cannot mutate readonly ATN config set");
157 if self.config_index.insert(config.clone()) {
158 if config.reaches_into_outer_context > 0 {
159 self.dips_into_outer_context = true;
160 }
161 self.configs.push(config);
162 true
163 } else {
164 false
165 }
166 }
167
168 pub fn configs(&self) -> &[AtnConfig] {
169 &self.configs
170 }
171
172 pub const fn is_empty(&self) -> bool {
173 self.configs.is_empty()
174 }
175
176 pub const fn len(&self) -> usize {
177 self.configs.len()
178 }
179
180 pub const fn set_readonly(&mut self, readonly: bool) {
181 self.readonly = readonly;
182 }
183
184 pub const fn has_semantic_context(&self) -> bool {
185 self.has_semantic_context
186 }
187
188 pub const fn set_has_semantic_context(&mut self, value: bool) {
189 self.has_semantic_context = value;
190 }
191
192 pub const fn dips_into_outer_context(&self) -> bool {
193 self.dips_into_outer_context
194 }
195}
196
#[cfg(test)]
mod tests {
    use super::*;

    /// Adding an equal configuration a second time must be rejected.
    #[test]
    fn config_set_deduplicates_configs() {
        let root = PredictionContext::empty();
        let mut config_set = AtnConfigSet::new();

        let first_insert = config_set.add(AtnConfig::new(1, 1, Rc::clone(&root)));
        assert!(first_insert);

        let second_insert = config_set.add(AtnConfig::new(1, 1, Rc::clone(&root)));
        assert!(!second_insert);

        assert_eq!(config_set.len(), 1);
    }

    /// A singleton exposes its only entry at index zero.
    #[test]
    fn singleton_context_reports_parent_and_return_state() {
        let root = PredictionContext::empty();
        let child = PredictionContext::singleton(Rc::clone(&root), 42);

        assert_eq!(child.return_state(0), Some(42));
        assert_eq!(child.parent(0), Some(root));
    }

    /// Merging with the empty context keeps both entries, empty one last.
    #[test]
    fn merge_with_empty_preserves_non_empty_return_state() {
        let root = PredictionContext::empty();
        let child = PredictionContext::singleton(Rc::clone(&root), 42);

        let merged = PredictionContext::merge(Rc::clone(&child), Rc::clone(&root));

        assert_eq!(merged.len(), 2);
        let expected = [(42, &root), (EMPTY_RETURN_STATE, &root)];
        for (index, &(return_state, parent)) in expected.iter().enumerate() {
            assert_eq!(merged.return_state(index), Some(return_state));
            assert_eq!(merged.parent(index), Some(Rc::clone(parent)));
        }
    }

    /// An entry present on both sides appears only once in the result.
    #[test]
    fn merge_deduplicates_entries_with_same_parent_and_return_state() {
        let root = PredictionContext::empty();
        let first_parent = PredictionContext::singleton(Rc::clone(&root), 1);
        let second_parent = PredictionContext::singleton(Rc::clone(&root), 2);
        let array_side = Rc::new(PredictionContext::Array {
            parents: vec![Rc::clone(&first_parent), second_parent],
            return_states: vec![42, 42],
        });
        let singleton_side = PredictionContext::singleton(Rc::clone(&first_parent), 42);

        let merged = PredictionContext::merge(array_side, singleton_side);

        assert_eq!(merged.len(), 2);
    }
}