1use super::{
51 small_set::{SmallSet, SmallSet256},
52 ArrayTransitionLabel, SimpleSlice,
53};
54use rsonpath_syntax::num::JsonUInt;
55use std::collections::HashMap;
56
/// A collection of transitions keyed by sets of array indices.
///
/// Each key is a normalised [`LinearSet`] of indices and maps to the transition (priority and
/// target states) taken for those indices. When a newly added set overlaps an existing one,
/// the overlap is stored as its own, separate entry.
#[derive(Debug)]
pub(super) struct ArrayTransitionSet {
    transitions: HashMap<LinearSet, LinearSetTransition>,
}
61
/// The transition stored for one [`LinearSet`] of indices.
#[derive(Debug)]
struct LinearSetTransition {
    /// Ordering key: transitions with a higher priority are yielded first when the
    /// transition set is iterated.
    priority: usize,
    /// The target states of this transition.
    target: SmallSet256,
}
67
/// A normalised set of array indices forming an arithmetic progression.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
enum LinearSet {
    /// Exactly one index.
    Singleton(JsonUInt),
    /// `BoundedSlice(start, end, step)`: `start, start + step, ...` strictly below `end`.
    BoundedSlice(JsonUInt, JsonUInt, JsonUInt),
    /// `OpenEndedSlice(start, step)`: `start, start + step, ...` with no upper bound.
    OpenEndedSlice(JsonUInt, JsonUInt),
}
74
/// Owning iterator over an [`ArrayTransitionSet`], yielding `(label, target)` pairs.
///
/// The entries are sorted by descending priority when the iterator is constructed.
pub(super) struct ArrayTransitionSetIterator {
    transitions: std::vec::IntoIter<(LinearSet, LinearSetTransition)>,
}
78
79impl ArrayTransitionSet {
80 pub(super) fn new() -> Self {
81 Self {
82 transitions: HashMap::new(),
83 }
84 }
85
86 pub(super) fn add_transition(&mut self, label: ArrayTransitionLabel, target: SmallSet256) {
87 use std::collections::hash_map::Entry;
88 let Some(label) = LinearSet::from_label(label) else {
89 return;
90 };
91 let overlaps: Vec<_> = self
92 .transitions
93 .iter()
94 .filter_map(|(other, trans)| {
95 let overlap = other.overlap_with(&label)?;
96 let priority = trans.priority + 1;
97 let mut overlap_target = target;
98 overlap_target.union(&trans.target);
99
100 Some((overlap, LinearSetTransition { priority, target }))
101 })
102 .collect();
103
104 for (label, trans) in overlaps {
105 match self.transitions.entry(label) {
106 Entry::Occupied(mut entry) => {
107 let entry = entry.get_mut();
108 entry.priority = std::cmp::max(entry.priority, trans.priority);
109 entry.target.union(&trans.target);
110 }
111 Entry::Vacant(entry) => {
112 entry.insert(trans);
113 }
114 }
115 }
116
117 match self.transitions.entry(label) {
118 Entry::Occupied(_) => (),
120 Entry::Vacant(entry) => {
121 entry.insert(LinearSetTransition { priority: 1, target });
122 }
123 }
124 }
125
126 pub(super) fn states_mut(&mut self) -> impl Iterator<Item = &mut SmallSet256> {
127 self.transitions.iter_mut().map(|(_, trans)| &mut trans.target)
128 }
129}
130
131impl ArrayTransitionSetIterator {
132 fn new(mut transitions: Vec<(LinearSet, LinearSetTransition)>) -> Self {
133 transitions.sort_by(|(_, x), (_, y)| x.priority.cmp(&y.priority).reverse());
134 Self {
135 transitions: transitions.into_iter(),
136 }
137 }
138}
139
140impl IntoIterator for ArrayTransitionSet {
141 type Item = (ArrayTransitionLabel, SmallSet256);
142
143 type IntoIter = ArrayTransitionSetIterator;
144
145 fn into_iter(self) -> Self::IntoIter {
146 ArrayTransitionSetIterator::new(self.transitions.into_iter().collect())
147 }
148}
149
150impl Iterator for ArrayTransitionSetIterator {
151 type Item = (ArrayTransitionLabel, SmallSet256);
152
153 fn next(&mut self) -> Option<Self::Item> {
154 let (label, transition) = self.transitions.next()?;
155 Some(match label {
156 LinearSet::Singleton(idx) => (ArrayTransitionLabel::Index(idx), transition.target),
157 LinearSet::BoundedSlice(start, end, step) => (
158 ArrayTransitionLabel::Slice(SimpleSlice::new(start, Some(end), step)),
159 transition.target,
160 ),
161 LinearSet::OpenEndedSlice(start, step) => (
162 ArrayTransitionLabel::Slice(SimpleSlice::new(start, None, step)),
163 transition.target,
164 ),
165 })
166 }
167}
168
169impl LinearSet {
170 fn from_label(label: ArrayTransitionLabel) -> Option<Self> {
171 match label {
172 ArrayTransitionLabel::Index(idx) => Some(Self::Singleton(idx)),
173 ArrayTransitionLabel::Slice(slice) => {
174 if slice.step == JsonUInt::ZERO {
175 None
176 } else if let Some(end) = slice.end {
177 if slice.start >= end {
178 None
179 } else if slice.start.as_u64().saturating_add(slice.step.as_u64()) >= end.as_u64() {
180 Some(Self::Singleton(slice.start))
182 } else {
183 debug_assert!(end > JsonUInt::ZERO);
184 Some(Self::BoundedSlice(slice.start, end, slice.step))
185 }
186 } else {
187 Some(Self::OpenEndedSlice(slice.start, slice.step))
188 }
189 }
190 }
191 }
192
193 fn from_slice(start: JsonUInt, end: Option<JsonUInt>, step: JsonUInt) -> Option<Self> {
194 if step == JsonUInt::ZERO {
195 None
196 } else if let Some(end) = end {
197 if start >= end {
198 None
199 } else if start.as_u64().saturating_add(step.as_u64()) >= end.as_u64() {
200 Some(Self::Singleton(start))
202 } else {
203 debug_assert!(end > JsonUInt::ZERO);
204 Some(Self::BoundedSlice(start, end, step))
205 }
206 } else {
207 Some(Self::OpenEndedSlice(start, step))
208 }
209 }
210
211 fn overlap_with(&self, other: &Self) -> Option<Self> {
212 if self.start() > other.start() {
214 return other.overlap_with(self);
215 }
216 assert_ne!(self.step().as_u64(), 0);
218 assert_ne!(other.step().as_u64(), 0);
219
220 let (first_element, gcd) = find_first_element(
228 self.start().into(),
229 self.step().into(),
230 other.start().into(),
231 other.step().into(),
232 )?;
233 let end = match (self.end_exclusive(), other.end_exclusive()) {
235 (None, Some(x)) | (Some(x), None) => Some(x),
236 (None, None) => None,
237 (Some(x), Some(y)) => Some(std::cmp::min(x, y)),
238 };
239 let common_step = (self.step().as_u64() / gcd).saturating_mul(other.step().as_u64());
243
244 let start = JsonUInt::try_from(first_element).ok()?;
245
246 return match JsonUInt::try_from(common_step).ok() {
247 Some(step) => Self::from_slice(start, end, step),
248 None if end.map_or(false, |end| end <= start) => None,
249 None => Some(Self::Singleton(start)),
250 };
251
252 fn find_first_element(a: i64, k: i64, b: i64, l: i64) -> Option<(i64, u64)> {
253 let c = umod(k - (b - a), k);
267 let (jumps, gcd) = solve_linear_congruence(l, c, k)?;
268 Some((jumps.checked_mul(l)?.checked_add(b)?, gcd))
269 }
270 }
271
272 fn start(&self) -> JsonUInt {
273 match self {
274 Self::Singleton(i) | Self::BoundedSlice(i, _, _) | Self::OpenEndedSlice(i, _) => *i,
275 }
276 }
277
278 fn end_exclusive(&self) -> Option<JsonUInt> {
279 match self {
280 Self::Singleton(i) => JsonUInt::try_from(i.as_u64() + 1).ok(),
281 Self::BoundedSlice(_, i, _) => Some(*i),
282 Self::OpenEndedSlice(_, _) => None,
283 }
284 }
285
286 fn step(&self) -> JsonUInt {
287 match self {
288 Self::Singleton(_) => JsonUInt::ONE,
289 Self::BoundedSlice(_, _, s) | Self::OpenEndedSlice(_, s) => *s,
290 }
291 }
292}
293
/// Mathematical (never negative) remainder of `x` modulo `m`.
///
/// Unlike `%`, the result always lies in `0..m`, e.g. `umod(-3, 5) == 2`.
///
/// # Panics
/// Panics if `m` is not strictly positive.
fn umod(x: i64, m: i64) -> i64 {
    assert!(m > 0);
    // For a positive modulus, `rem_euclid` is exactly "lift a negative `%` result by `m`".
    x.rem_euclid(m)
}
307
308fn solve_linear_congruence(a: i64, b: i64, m: i64) -> Option<(i64, u64)> {
311 let b = umod(b, m);
318 let (x, gcd) = extended_euclid(a, m);
319
320 if b % gcd != 0 {
321 None
322 } else {
323 Some((
324 umod(x.checked_mul(b / gcd)?, m / gcd),
325 u64::try_from(gcd).expect("negative gcd"),
326 ))
327 }
328}
329
/// Iterative extended Euclidean algorithm.
///
/// Returns `(x, g)`, where `g` is the gcd produced by the Euclidean algorithm and `x` is the
/// Bézout coefficient of `a`, i.e. `a * x + b * y == g` for some `y`. Only the coefficient of
/// `a` is tracked, since that is all `solve_linear_congruence` uses.
fn extended_euclid(a: i64, b: i64) -> (i64, i64) {
    // `rem` is (previous remainder, current remainder); `coef` holds the matching coefficients of `a`.
    let mut rem: (i64, i64) = (a, b);
    let mut coef: (i64, i64) = (1, 0);

    loop {
        let (prev, cur) = rem;
        if cur == 0 {
            break;
        }
        let q = prev / cur;
        rem = (cur, prev - q * cur);
        coef = (coef.1, coef.0 - q * coef.1);
    }

    (coef.0, rem.0)
}
343
#[cfg(test)]
mod tests {
    use test_case::test_case;

    use super::LinearSet;

    // `extended_euclid(a, b)` returns `(x, gcd)` such that `a * x + b * y == gcd` for some `y`.
    #[test_case(1, 1 => (0, 1))]
    #[test_case(4, 10 => (-2, 2))]
    #[test_case(7, 10 => (3, 1))]
    #[test_case(8, 10 => (-1, 2))]
    #[test_case(161, 28 => (-1, 7))]
    fn extended_euclid_tests(a: i64, b: i64) -> (i64, i64) {
        super::extended_euclid(a, b)
    }

    // `solve_linear_congruence(a, b, m)` solves `a * x ≡ b (mod m)` and returns the smallest
    // non-negative `x` together with `gcd(a, m)`, or `None` when no solution exists
    // (e.g. `8x ≡ 3 (mod 10)`, since gcd 2 does not divide 3).
    #[test_case(7, 3, 10 => Some((9, 1)))]
    #[test_case(7, 8, 10 => Some((4, 1)))]
    #[test_case(8, 3, 10 => None)]
    #[test_case(8, 2, 10 => Some((4, 2)))]
    #[test_case(94_253_004_627_829, 666_084_837_845, 888_777_666_555_119 => Some((2_412_193, 121_216_531)))]
    #[test_case(6_253_004_621, 2_156_208_490, 27_815_089_521 => Some((116, 215_620_849)))]
    fn linear_congruence_tests(a: i64, b: i64, m: i64) -> Option<(i64, u64)> {
        super::solve_linear_congruence(a, b, m)
    }

    // Intersections of linear sets. For example, `5, 12, 19, 26, 33, ...` (step 7) and
    // `3, 13, 23, 33, ...` (step 10) first meet at 33 and then repeat every lcm(7, 10) = 70.
    // In the last case the first common index (about 2.3e20) overflows during the computation,
    // so no overlap is expected.
    #[test_case(LinearSet::Singleton(1.into()), LinearSet::Singleton(1.into()) => Some(LinearSet::Singleton(1.into())))]
    #[test_case(LinearSet::Singleton(1.into()), LinearSet::Singleton(2.into()) => None)]
    #[test_case(
        LinearSet::Singleton(3.into()),
        LinearSet::BoundedSlice(3.into(), 15.into(), 2.into())
        => Some(LinearSet::Singleton(3.into())))]
    #[test_case(
        LinearSet::Singleton(5.into()),
        LinearSet::BoundedSlice(3.into(), 15.into(), 2.into())
        => Some(LinearSet::Singleton(5.into())))]
    #[test_case(
        LinearSet::Singleton(15.into()),
        LinearSet::BoundedSlice(3.into(), 15.into(), 2.into())
        => None)]
    #[test_case(
        LinearSet::BoundedSlice(3.into(), 15.into(), 2.into()),
        LinearSet::BoundedSlice(3.into(), 15.into(), 2.into())
        => Some(LinearSet::BoundedSlice(3.into(), 15.into(), 2.into())))]
    #[test_case(
        LinearSet::BoundedSlice(5.into(), 1024.into(), 7.into()),
        LinearSet::BoundedSlice(3.into(), 911.into(), 10.into())
        => Some(LinearSet::BoundedSlice(33.into(), 911.into(), 70.into())))]
    #[test_case(
        LinearSet::OpenEndedSlice(5.into(), 7.into()),
        LinearSet::OpenEndedSlice(3.into(), 10.into())
        => Some(LinearSet::OpenEndedSlice(33.into(), 70.into())))]
    #[test_case(
        LinearSet::OpenEndedSlice(5.into(), 8.into()),
        LinearSet::OpenEndedSlice(3.into(), 10.into())
        => Some(LinearSet::OpenEndedSlice(13.into(), 40.into())))]
    #[test_case(
        LinearSet::OpenEndedSlice(156_208_490.try_into().unwrap(), 6_253_004_621_u64.try_into().unwrap()),
        LinearSet::OpenEndedSlice(4_253_004_621_u64.try_into().unwrap(), 27_815_089_521_u64.try_into().unwrap())
        => Some(LinearSet::OpenEndedSlice(87_698_273_184_u64.try_into().unwrap(), 806_637_596_109_u64.try_into().unwrap())))]
    #[test_case(
        LinearSet::OpenEndedSlice(666_123_456_789_u64.try_into().unwrap(), 888_777_666_555_119_u64.try_into().unwrap()),
        LinearSet::OpenEndedSlice(888_777_705_174_063_u64.try_into().unwrap(), 94_253_004_627_829_u64.try_into().unwrap())
        => None)]
    fn overlap_tests(a: LinearSet, b: LinearSet) -> Option<LinearSet> {
        a.overlap_with(&b)
    }
}