1#[cfg(test)]
2mod reference;
3#[cfg(test)]
4mod tests;
5
6use std::convert::Infallible;
7use std::fmt::{self, Debug};
8use std::iter::once;
9
10use crate::auto::{Automaton, StateId};
11use crate::{Constructor, Label, Polarity};
12
13pub type Result<C> = std::result::Result<(), Error<C>>;
14
15#[derive(Debug)]
16pub struct Error<C: Constructor> {
17 pub stack: Vec<(C::Label, C, C)>,
18 pub constraint: (C, C),
19}
20
21pub(crate) enum CacheEntry<C: Constructor> {
22 Root,
23 RequiredBy {
24 label: C::Label,
25 pos: (StateId, C),
26 neg: (StateId, C),
27 },
28}
29
30impl<C: Constructor> Automaton<C> {
31 pub fn biunify(&mut self, qp: StateId, qn: StateId) -> Result<C> {
33 self.biunify_all(once((qp, qn)))
34 }
35
36 pub fn biunify_all<I>(&mut self, constraints: I) -> Result<C>
38 where
39 I: IntoIterator<Item = (StateId, StateId)>,
40 {
41 let mut stack = Vec::with_capacity(20);
42 stack.extend(constraints.into_iter().filter(|&constraint| {
43 self.biunify_cache
44 .insert(constraint, CacheEntry::Root)
45 .is_none()
46 }));
47 while let Some(constraint) = stack.pop() {
48 self.biunify_impl(&mut stack, constraint)?;
49 }
50 Ok(())
51 }
52
53 fn biunify_impl(
54 &mut self,
55 stack: &mut Vec<(StateId, StateId)>,
56 (qp, qn): (StateId, StateId),
57 ) -> Result<C> {
58 #[cfg(debug_assertions)]
59 debug_assert_eq!(self[qp].pol, Polarity::Pos);
60 #[cfg(debug_assertions)]
61 debug_assert_eq!(self[qn].pol, Polarity::Neg);
62 debug_assert!(self.biunify_cache.contains_key(&(qp, qn)));
63
64 for (cp, cn) in product(self[qp].cons.iter(), self[qn].cons.iter()) {
65 if !(cp <= cn) {
66 return Err(self.make_error((qp, cp.clone()), (qn, cn.clone())));
67 }
68 }
69 for to in self[qn].flow.iter() {
70 self.merge(Polarity::Pos, to, qp);
71 }
72 for from in self[qp].flow.iter() {
73 self.merge(Polarity::Neg, from, qn);
74 }
75
76 let states = &self.states;
77 let biunify_cache = &mut self.biunify_cache;
78 let cps = &states[qp.as_u32() as usize].cons;
79 let cns = &states[qn.as_u32() as usize].cons;
80 for (cp, cn) in cps.intersection(cns) {
81 cp.visit_params_intersection::<_, Infallible>(&cn, |label, l, r| {
82 let (ps, ns) = label.polarity().flip(l, r);
83 stack.extend(product(ps, ns).filter(|&constraint| {
84 biunify_cache
85 .insert(
86 constraint,
87 CacheEntry::RequiredBy {
88 label: label.clone(),
89 pos: (qp, cp.clone()),
90 neg: (qn, cn.clone()),
91 },
92 )
93 .is_none()
94 }));
95 Ok(())
96 })
97 .unwrap();
98 }
99 Ok(())
100 }
101
102 fn make_error(&self, pos: (StateId, C), neg: (StateId, C)) -> Error<C> {
103 let mut stack = Vec::new();
104
105 let mut key = (pos.0, neg.0);
106 while let CacheEntry::RequiredBy { label, pos, neg } = &self.biunify_cache[&key] {
107 stack.push((label.clone(), pos.1.clone(), neg.1.clone()));
108 key = (pos.0, neg.0);
109 }
110
111 Error {
112 stack,
113 constraint: (pos.1, neg.1),
114 }
115 }
116}
117
/// Lazily yields the Cartesian product of `lhs` and `rhs` as `(l, r)` pairs.
///
/// Pairs are produced in row-major order: every element of `rhs` is paired with the
/// first element of `lhs`, then with the second, and so on. `rhs` is cloned once per
/// element of `lhs`, so it must be cheaply re-iterable. Left items are `Copy` so
/// they can be repeated without explicit cloning.
fn product<I, J>(lhs: I, rhs: J) -> impl Iterator<Item = (I::Item, J::Item)>
where
    I: IntoIterator,
    I::Item: Copy,
    J: IntoIterator + Clone,
{
    lhs.into_iter()
        .flat_map(move |l| rhs.clone().into_iter().map(move |r| (l, r)))
}
128
129impl<C> Debug for CacheEntry<C>
130where
131 C: Constructor + Debug,
132 C::Label: Debug,
133{
134 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
135 match self {
136 CacheEntry::Root => f.debug_struct("Root").finish(),
137 CacheEntry::RequiredBy { label, pos, neg } => f
138 .debug_struct("RequiredBy")
139 .field("label", label)
140 .field("pos", pos)
141 .field("neg", neg)
142 .finish(),
143 }
144 }
145}