Skip to main content

mlsub/biunify/mod.rs

1#[cfg(test)]
2mod reference;
3#[cfg(test)]
4mod tests;
5
6use std::convert::Infallible;
7use std::fmt::{self, Debug};
8use std::iter::once;
9
10use crate::auto::{Automaton, StateId};
11use crate::{Constructor, Label, Polarity};
12
13pub type Result<C> = std::result::Result<(), Error<C>>;
14
15#[derive(Debug)]
16pub struct Error<C: Constructor> {
17    pub stack: Vec<(C::Label, C, C)>,
18    pub constraint: (C, C),
19}
20
21pub(crate) enum CacheEntry<C: Constructor> {
22    Root,
23    RequiredBy {
24        label: C::Label,
25        pos: (StateId, C),
26        neg: (StateId, C),
27    },
28}
29
30impl<C: Constructor> Automaton<C> {
31    /// Solves a set of constraints t⁺ ≤ t⁻ where t⁺ and t⁻ are represented by the states `qp` and `qn`.
32    pub fn biunify(&mut self, qp: StateId, qn: StateId) -> Result<C> {
33        self.biunify_all(once((qp, qn)))
34    }
35
36    /// Solves a set of constraints t⁺ ≤ t⁻ where t⁺ and t⁻ are represented by the states `qp` and `qn`.
37    pub fn biunify_all<I>(&mut self, constraints: I) -> Result<C>
38    where
39        I: IntoIterator<Item = (StateId, StateId)>,
40    {
41        let mut stack = Vec::with_capacity(20);
42        stack.extend(constraints.into_iter().filter(|&constraint| {
43            self.biunify_cache
44                .insert(constraint, CacheEntry::Root)
45                .is_none()
46        }));
47        while let Some(constraint) = stack.pop() {
48            self.biunify_impl(&mut stack, constraint)?;
49        }
50        Ok(())
51    }
52
53    fn biunify_impl(
54        &mut self,
55        stack: &mut Vec<(StateId, StateId)>,
56        (qp, qn): (StateId, StateId),
57    ) -> Result<C> {
58        #[cfg(debug_assertions)]
59        debug_assert_eq!(self[qp].pol, Polarity::Pos);
60        #[cfg(debug_assertions)]
61        debug_assert_eq!(self[qn].pol, Polarity::Neg);
62        debug_assert!(self.biunify_cache.contains_key(&(qp, qn)));
63
64        for (cp, cn) in product(self[qp].cons.iter(), self[qn].cons.iter()) {
65            if !(cp <= cn) {
66                return Err(self.make_error((qp, cp.clone()), (qn, cn.clone())));
67            }
68        }
69        for to in self[qn].flow.iter() {
70            self.merge(Polarity::Pos, to, qp);
71        }
72        for from in self[qp].flow.iter() {
73            self.merge(Polarity::Neg, from, qn);
74        }
75
76        let states = &self.states;
77        let biunify_cache = &mut self.biunify_cache;
78        let cps = &states[qp.as_u32() as usize].cons;
79        let cns = &states[qn.as_u32() as usize].cons;
80        for (cp, cn) in cps.intersection(cns) {
81            cp.visit_params_intersection::<_, Infallible>(&cn, |label, l, r| {
82                let (ps, ns) = label.polarity().flip(l, r);
83                stack.extend(product(ps, ns).filter(|&constraint| {
84                    biunify_cache
85                        .insert(
86                            constraint,
87                            CacheEntry::RequiredBy {
88                                label: label.clone(),
89                                pos: (qp, cp.clone()),
90                                neg: (qn, cn.clone()),
91                            },
92                        )
93                        .is_none()
94                }));
95                Ok(())
96            })
97            .unwrap();
98        }
99        Ok(())
100    }
101
102    fn make_error(&self, pos: (StateId, C), neg: (StateId, C)) -> Error<C> {
103        let mut stack = Vec::new();
104
105        let mut key = (pos.0, neg.0);
106        while let CacheEntry::RequiredBy { label, pos, neg } = &self.biunify_cache[&key] {
107            stack.push((label.clone(), pos.1.clone(), neg.1.clone()));
108            key = (pos.0, neg.0);
109        }
110
111        Error {
112            stack,
113            constraint: (pos.1, neg.1),
114        }
115    }
116}
117
/// Returns the Cartesian product of `lhs` and `rhs` as `(l, r)` pairs.
///
/// Pairs come out in row-major order: the first element of `lhs` is paired with every
/// element of `rhs`, then the second element, and so on. The result is empty if either
/// side is empty.
///
/// `rhs` is cloned once per element of `lhs` so it can be traversed repeatedly. Each left
/// element is cloned once per element of `rhs`, so only `Clone` (not `Copy`) is required.
fn product<I, J>(lhs: I, rhs: J) -> impl Iterator<Item = (I::Item, J::Item)>
where
    I: IntoIterator,
    I::Item: Clone,
    J: IntoIterator + Clone,
{
    lhs.into_iter()
        .flat_map(move |l| rhs.clone().into_iter().map(move |r| (l.clone(), r)))
}
128
129impl<C> Debug for CacheEntry<C>
130where
131    C: Constructor + Debug,
132    C::Label: Debug,
133{
134    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
135        match self {
136            CacheEntry::Root => f.debug_struct("Root").finish(),
137            CacheEntry::RequiredBy { label, pos, neg } => f
138                .debug_struct("RequiredBy")
139                .field("label", label)
140                .field("pos", pos)
141                .field("neg", neg)
142                .finish(),
143        }
144    }
145}