1use std::collections::{HashSet, VecDeque};
2use std::fmt::{Debug, Display, Formatter};
3
4use derive_more::Display;
5use itertools::Either::{Left, Right};
6use itertools::{Either, Itertools};
7use thiserror::Error;
8
9use Constraint::{ExplicitInstance, ImplicitInstance};
10use ConstraintsSolverError::AlgorithmU;
11
12use crate::algorithm_u::AlgorithmUError;
13use crate::constraint::Constraint::Eq;
14use crate::constraint::ConstraintsSolverError::Ambiguous;
15use crate::r#type::{MonomorphicType, PolymorphicType, TVar};
16use crate::substitution::Substitutions;
17use crate::traits::{ActiveTVars, FreeTypeVars, Substitutable};
18use crate::InferState;
19
/// Errors returned while solving a set of [`Constraint`]s.
///
/// `Display` is implemented by hand further down in this file; the `Error`
/// derive supplies the `std::error::Error` impl and the `From` conversion.
#[derive(Debug, Error)]
pub enum ConstraintsSolverError {
    /// A failure coming from the unification step. Built automatically from an
    /// [`AlgorithmUError`] through `From` (so `?` works on unification results)
    /// and displayed exactly like the wrapped error.
    #[error(transparent)]
    AlgorithmU(#[from] AlgorithmUError),
    /// The constraints that are reported as impossible to match; each one is
    /// rendered as its own "Cannot match ..." message.
    // NOTE(review): presumably meant for constraints the solver cannot make
    // progress on — confirm against the code that constructs this variant.
    Ambiguous(Vec<Constraint>),
}
26
/// An equality constraint between two monomorphic types, displayed as `t1 ≡ t2`.
#[derive(Debug, PartialEq, Clone, Display)]
#[display(fmt = "{t1} ≡ {t2}")]
pub struct EqConstraint {
    /// Left-hand side; error messages call it the "expected type".
    pub t1: MonomorphicType,
    /// Right-hand side; error messages call it the "actual type".
    pub t2: MonomorphicType,
}
33
/// A typing constraint consumed by [`Constraint::solve`].
///
/// `Display` and `Debug` are implemented by hand below, which is why `Debug`
/// is not derived here.
#[derive(PartialEq, Clone)]
pub enum Constraint {
    /// The two types of the wrapped [`EqConstraint`] must be equal; solved by
    /// unifying them.
    Eq(EqConstraint),
    /// `t` must be an instance of the polymorphic type `s`; solved by
    /// instantiating `s` and equating the result with `t`.
    ExplicitInstance {
        /// The monomorphic type that has to match an instance of `s`.
        t: MonomorphicType,
        /// The polymorphic type that gets instantiated.
        s: PolymorphicType,
    },
    /// `t1` must be an instance of `t2` once `t2` has been generalized with
    /// respect to `ctx`; solved by turning it into an `ExplicitInstance`.
    ImplicitInstance {
        /// The monomorphic type that has to match the generalization of `t2`.
        t1: MonomorphicType,
        /// Type variables handed to `generalize`.
        // NOTE(review): presumably the variables that must stay monomorphic
        // (not be quantified) — confirm against `generalize`.
        ctx: HashSet<TVar>,
        /// The type that is generalized before being instantiated.
        t2: MonomorphicType,
    },
}
51
52impl Display for ConstraintsSolverError {
53 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
54 match self {
55 AlgorithmU(x) => write!(f, "{x}")?,
56 Ambiguous(x) => {
57 for item in x {
58 match item {
59 Eq(EqConstraint { t1, t2 }) => {
60 write!(
61 f,
62 "Cannot match expected type `{t1}` with actual type `{t2}`"
63 )?;
64 }
65 ExplicitInstance { t, s } => {
66 write!(f, "Cannot match instance `{t}` of type `{s}`")?;
67 }
68 ImplicitInstance { t1, ctx, t2 } => {
69 write!(f, "Cannot match expected type `{t1}` with generalization of type `{t2}` in context {{{}}}", ctx.iter().join(", "))?;
70 }
71 }
72 }
73 }
74 }
75 Ok(())
76 }
77}
78
79impl Constraint {
80 fn solvable(c: &Constraint, cs: &VecDeque<Constraint>) -> bool {
81 match c {
82 Eq(EqConstraint { .. }) => true,
83 ExplicitInstance { .. } => true,
84 ImplicitInstance { ctx, t2, .. } => {
85 let v1 = &t2.free_types() - ctx;
86 let active = cs.active_vars();
87 (&v1 & &active).is_empty()
88 }
89 }
90 }
91
92 fn solve_pair(
93 c: Constraint,
94 env: &mut InferState,
95 ) -> Result<Either<Substitutions, Constraint>, ConstraintsSolverError> {
96 match c {
97 Eq(EqConstraint { t1, t2 }) => Ok(Left(t1.unify(&t2)?)),
98 ExplicitInstance { t, s } => {
99 let t2 = s.instantiate(env);
100 Ok(Right(Eq(EqConstraint { t1: t, t2 })))
101 }
102 ImplicitInstance { t1, ctx, t2 } => {
103 let s = t2.generalize(&ctx);
104 Ok(Right(ExplicitInstance { t: t1, s }))
105 }
106 }
107 }
108
109 pub(crate) fn solve(
110 constraints: Vec<Constraint>,
111 env: &mut InferState,
112 ) -> Result<Substitutions, ConstraintsSolverError> {
113 let mut cs = VecDeque::from(constraints);
114 let mut s0 = Substitutions::empty();
115
116 while let Some(c) = cs.pop_back() {
118 if Self::solvable(&c, &cs) {
119 match Self::solve_pair(c, env)? {
120 Left(s) => {
121 cs = cs.make_contiguous().substitute(&s).into();
122 s0 = s0 + s;
123 }
124 Right(c) => cs.push_front(c),
125 }
126 } else {
127 cs.push_front(c)
128 }
129 }
130 Ok(s0)
131 }
132}
133
134impl Display for Constraint {
135 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
136 match self {
137 Eq(x) => write!(f, "{x}"),
138 ExplicitInstance { t, s } => write!(f, "{t} ≼ ({s})"),
139 ImplicitInstance { t1, ctx, t2 } => {
140 write!(f, "{t1} ≤{{{}}} {t2}", ctx.iter().join(", "))
141 }
142 }
143 }
144}
145
146impl Debug for Constraint {
147 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
148 write!(f, "{}", self)
149 }
150}
151
152pub fn eq_cst(t1: impl Into<MonomorphicType>, t2: impl Into<MonomorphicType>) -> Constraint {
153 Eq(EqConstraint {
154 t1: t1.into(),
155 t2: t2.into(),
156 })
157}
158
159pub fn implicit_cst(
160 t1: impl Into<MonomorphicType>,
161 ctx: impl Into<HashSet<TVar>>,
162 t2: impl Into<MonomorphicType>,
163) -> Constraint {
164 ImplicitInstance {
165 t1: t1.into(),
166 ctx: ctx.into(),
167 t2: t2.into(),
168 }
169}
170
171pub fn explicit_cst(t: impl Into<MonomorphicType>, s: impl Into<PolymorphicType>) -> Constraint {
172 ExplicitInstance {
173 t: t.into(),
174 s: s.into(),
175 }
176}
177
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use crate::constraint::{eq_cst, implicit_cst, Constraint};
    use crate::r#type::MonomorphicType::Var;
    use crate::r#type::PrimitiveType::Boolean;
    use crate::r#type::{fun1, TVar};
    use crate::substitution::Substitutions;
    use crate::InferState;

    /// Two equalities and two implicit instances over five variables.
    #[test]
    fn test_1() {
        let [t1, t2, t3, t4, t5] = [1, 2, 3, 4, 5].map(TVar);
        let mut env = InferState::default();
        // Start the counter above the hand-written variables 1..=5.
        env.variable_index = 6;

        let constraints = vec![
            eq_cst(t2, fun1(Boolean, t3)),
            implicit_cst(t4, [t5], t3),
            implicit_cst(t2, [t5], t1),
            eq_cst(t5, t1),
        ];

        let result = Constraint::solve(constraints, &mut env).unwrap();

        let expected = Substitutions::from_iter([
            (t4, Var(t3)),
            (t1, fun1(Boolean, t3)),
            (t5, fun1(Boolean, t3)),
            (t2, fun1(Boolean, t3)),
        ]);
        assert_eq!(result, expected);
    }

    /// Two implicit instances that both refer to the same type `t3`.
    #[test]
    fn test_2() {
        let [t0, t1, t2, t3, t4] = [0, 1, 2, 3, 4].map(TVar);
        let mut env = InferState::default();
        // Start the counter above the hand-written variables 0..=4.
        env.variable_index = 5;

        let constraints = vec![
            eq_cst(t1, fun1(t2, t3)),
            implicit_cst(t4, [t0], t3),
            implicit_cst(t2, [t0], t3),
            eq_cst(t0, t1),
        ];

        let result = Constraint::solve(constraints, &mut env).unwrap();

        let expected = Substitutions::from_iter([
            (t0, fun1(t3, t3)),
            (t2, t3.into()),
            (t4, t3.into()),
            (t1, fun1(t3, t3)),
        ]);
        assert_eq!(result, expected);
    }
}