kodept_inference/
constraint.rs

1use std::collections::{HashSet, VecDeque};
2use std::fmt::{Debug, Display, Formatter};
3
4use derive_more::Display;
5use itertools::Either::{Left, Right};
6use itertools::{Either, Itertools};
7use thiserror::Error;
8
9use Constraint::{ExplicitInstance, ImplicitInstance};
10use ConstraintsSolverError::AlgorithmU;
11
12use crate::algorithm_u::AlgorithmUError;
13use crate::constraint::Constraint::Eq;
14use crate::constraint::ConstraintsSolverError::Ambiguous;
15use crate::r#type::{MonomorphicType, PolymorphicType, TVar};
16use crate::substitution::Substitutions;
17use crate::traits::{ActiveTVars, FreeTypeVars, Substitutable};
18use crate::InferState;
19
20#[derive(Debug, Error)]
21pub enum ConstraintsSolverError {
22    #[error(transparent)]
23    AlgorithmU(#[from] AlgorithmUError),
24    Ambiguous(Vec<Constraint>),
25}
26
27#[derive(Debug, PartialEq, Clone, Display)]
28#[display(fmt = "{t1} ≡ {t2}")]
29pub struct EqConstraint {
30    pub t1: MonomorphicType,
31    pub t2: MonomorphicType,
32}
33
34/// Types of constraints used in algorithm W
35#[derive(PartialEq, Clone)]
36pub enum Constraint {
37    /// t1 should be unified with t2
38    Eq(EqConstraint),
39    /// t should be an instance of s
40    ExplicitInstance {
41        t: MonomorphicType,
42        s: PolymorphicType,
43    },
44    /// t1 should be an instance of generalize(t2, ctx)
45    ImplicitInstance {
46        t1: MonomorphicType,
47        ctx: HashSet<TVar>,
48        t2: MonomorphicType,
49    },
50}
51
52impl Display for ConstraintsSolverError {
53    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
54        match self {
55            AlgorithmU(x) => write!(f, "{x}")?,
56            Ambiguous(x) => {
57                for item in x {
58                    match item {
59                        Eq(EqConstraint { t1, t2 }) => {
60                            write!(
61                                f,
62                                "Cannot match expected type `{t1}` with actual type `{t2}`"
63                            )?;
64                        }
65                        ExplicitInstance { t, s } => {
66                            write!(f, "Cannot match instance `{t}` of type `{s}`")?;
67                        }
68                        ImplicitInstance { t1, ctx, t2 } => {
69                            write!(f, "Cannot match expected type `{t1}` with generalization of type `{t2}` in context {{{}}}", ctx.iter().join(", "))?;
70                        }
71                    }
72                }
73            }
74        }
75        Ok(())
76    }
77}
78
79impl Constraint {
80    fn solvable(c: &Constraint, cs: &VecDeque<Constraint>) -> bool {
81        match c {
82            Eq(EqConstraint { .. }) => true,
83            ExplicitInstance { .. } => true,
84            ImplicitInstance { ctx, t2, .. } => {
85                let v1 = &t2.free_types() - ctx;
86                let active = cs.active_vars();
87                (&v1 & &active).is_empty()
88            }
89        }
90    }
91
92    fn solve_pair(
93        c: Constraint,
94        env: &mut InferState,
95    ) -> Result<Either<Substitutions, Constraint>, ConstraintsSolverError> {
96        match c {
97            Eq(EqConstraint { t1, t2 }) => Ok(Left(t1.unify(&t2)?)),
98            ExplicitInstance { t, s } => {
99                let t2 = s.instantiate(env);
100                Ok(Right(Eq(EqConstraint { t1: t, t2 })))
101            }
102            ImplicitInstance { t1, ctx, t2 } => {
103                let s = t2.generalize(&ctx);
104                Ok(Right(ExplicitInstance { t: t1, s }))
105            }
106        }
107    }
108
109    pub(crate) fn solve(
110        constraints: Vec<Constraint>,
111        env: &mut InferState,
112    ) -> Result<Substitutions, ConstraintsSolverError> {
113        let mut cs = VecDeque::from(constraints);
114        let mut s0 = Substitutions::empty();
115
116        // solver should always find suitable constraint to solve
117        while let Some(c) = cs.pop_back() {
118            if Self::solvable(&c, &cs) {
119                match Self::solve_pair(c, env)? {
120                    Left(s) => {
121                        cs = cs.make_contiguous().substitute(&s).into();
122                        s0 = s0 + s;
123                    }
124                    Right(c) => cs.push_front(c),
125                }
126            } else {
127                cs.push_front(c)
128            }
129        }
130        Ok(s0)
131    }
132}
133
134impl Display for Constraint {
135    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
136        match self {
137            Eq(x) => write!(f, "{x}"),
138            ExplicitInstance { t, s } => write!(f, "{t} ≼ ({s})"),
139            ImplicitInstance { t1, ctx, t2 } => {
140                write!(f, "{t1} ≤{{{}}} {t2}", ctx.iter().join(", "))
141            }
142        }
143    }
144}
145
146impl Debug for Constraint {
147    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
148        write!(f, "{}", self)
149    }
150}
151
152pub fn eq_cst(t1: impl Into<MonomorphicType>, t2: impl Into<MonomorphicType>) -> Constraint {
153    Eq(EqConstraint {
154        t1: t1.into(),
155        t2: t2.into(),
156    })
157}
158
159pub fn implicit_cst(
160    t1: impl Into<MonomorphicType>,
161    ctx: impl Into<HashSet<TVar>>,
162    t2: impl Into<MonomorphicType>,
163) -> Constraint {
164    ImplicitInstance {
165        t1: t1.into(),
166        ctx: ctx.into(),
167        t2: t2.into(),
168    }
169}
170
171pub fn explicit_cst(t: impl Into<MonomorphicType>, s: impl Into<PolymorphicType>) -> Constraint {
172    ExplicitInstance {
173        t: t.into(),
174        s: s.into(),
175    }
176}
177
#[cfg(test)]
// `unwrap` is fine in tests: a panic is exactly the failure signal we want.
#[allow(clippy::unwrap_used)]
mod tests {
    use crate::constraint::{eq_cst, implicit_cst, Constraint};
    use crate::r#type::MonomorphicType::Var;
    use crate::r#type::PrimitiveType::Boolean;
    use crate::r#type::{fun1, TVar};
    use crate::substitution::Substitutions;
    use crate::InferState;

    /// Implicit instances must wait until `t5` is fixed by the equality constraints.
    #[test]
    fn test_1() {
        let [t1, t2, t3, t4, t5] = [1, 2, 3, 4, 5].map(TVar);
        let mut env = InferState::default();
        // Fresh variables must not collide with t1..t5.
        env.variable_index = 6;

        let constraints = vec![
            eq_cst(t2, fun1(Boolean, t3)),
            implicit_cst(t4, [t5], t3),
            implicit_cst(t2, [t5], t1),
            eq_cst(t5, t1),
        ];

        let solution = Constraint::solve(constraints, &mut env).unwrap();
        let expected = Substitutions::from_iter([
            (t4, Var(t3)),
            (t1, fun1(Boolean, t3)),
            (t5, fun1(Boolean, t3)),
            (t2, fun1(Boolean, t3)),
        ]);
        assert_eq!(solution, expected)
    }

    /// Same shape as `test_1`, with a function whose argument type is itself inferred.
    #[test]
    fn test_2() {
        let [t0, t1, t2, t3, t4] = [0, 1, 2, 3, 4].map(TVar);
        let mut env = InferState::default();
        // Fresh variables must not collide with t0..t4.
        env.variable_index = 5;

        let constraints = vec![
            eq_cst(t1, fun1(t2, t3)),
            implicit_cst(t4, [t0], t3),
            implicit_cst(t2, [t0], t3),
            eq_cst(t0, t1),
        ];

        let solution = Constraint::solve(constraints, &mut env).unwrap();
        let expected = Substitutions::from_iter([
            (t0, fun1(t3, t3)),
            (t2, t3.into()),
            (t4, t3.into()),
            (t1, fun1(t3, t3)),
        ]);
        assert_eq!(solution, expected);
    }
}