1use itertools::{concat, Itertools};
2use nonempty_collections::NEVec;
3use std::collections::HashSet;
4use std::fmt::{Display, Formatter};
5use thiserror::Error;
6use tracing::debug;
7
8use crate::algorithm_u::AlgorithmUError;
9use crate::algorithm_w::AlgorithmWError::UnknownVar;
10use crate::assumption::AssumptionSet;
11use crate::constraint::{eq_cst, explicit_cst, implicit_cst, Constraint, ConstraintsSolverError};
12use crate::language::{Language, Literal, Special, Var};
13use crate::r#type::PrimitiveType::Boolean;
14use crate::r#type::{fun1, MonomorphicType, PolymorphicType, PrimitiveType, TVar, Tuple};
15use crate::substitution::Substitutions;
16use crate::traits::{EnvironmentProvider, Substitutable};
17use crate::{language, InferState};
18
19#[derive(Debug, Error)]
20pub enum AlgorithmWError {
21 #[error(transparent)]
22 AlgorithmU(#[from] AlgorithmUError),
23 UnknownVar(NEVec<Var>),
24 #[error(transparent)]
25 FailedConstraints(#[from] ConstraintsSolverError),
26}
27
28#[derive(Debug, Error)]
29pub enum CompoundInferError<E> {
30 #[error(transparent)]
31 AlgoW(#[from] AlgorithmWError),
32 Both(AlgorithmWError, NEVec<E>),
33 Foreign(NEVec<E>)
34}
35
36struct AlgorithmW<'e> {
37 monomorphic_set: HashSet<TVar>,
38 env: &'e mut InferState,
39}
40
41type AWResult = Result<(AssumptionSet, Vec<Constraint>, MonomorphicType), AlgorithmWError>;
42
43impl<'e> AlgorithmW<'e> {
44 fn apply(&mut self, expression: &Language) -> AWResult {
45 match expression {
46 Language::Var(x) => self.apply_var(x),
47 Language::App(x) => self.apply_app(x),
48 Language::Lambda(x) => self.apply_lambda(x),
49 Language::Let(x) => self.apply_let(x),
50 Language::Special(x) => self.apply_special(x),
51 Language::Literal(x) => match x {
52 Literal::Integral(_) => Ok((
53 AssumptionSet::empty(),
54 vec![],
55 PrimitiveType::Integral.into(),
56 )),
57 Literal::Floating(_) => Ok((
58 AssumptionSet::empty(),
59 vec![],
60 PrimitiveType::Floating.into(),
61 )),
62 Literal::Tuple(vec) => self.apply_tuple(vec),
63 },
64 }
65 }
66
67 fn apply_var(&mut self, var: &Var) -> AWResult {
68 let fresh = self.env.new_var();
69 Ok((
70 AssumptionSet::single(var.clone(), fresh),
71 vec![],
72 fresh.into(),
73 ))
74 }
75
76 fn apply_app(&mut self, language::App { arg, func }: &language::App) -> AWResult {
77 let (as1, cs1, t1) = self.apply(func)?;
78 let (as2, cs2, t2) = self.apply(arg)?;
79 let tv = self.env.new_var();
80
81 Ok((
82 as1 + as2,
83 concat([cs1, cs2, vec![eq_cst(t1, fun1(t2, tv))]]),
84 tv.into(),
85 ))
86 }
87
88 fn apply_lambda(&mut self, language::Lambda { bind, expr }: &language::Lambda) -> AWResult {
89 let tv = self.env.new_var();
90 self.monomorphic_set.insert(tv);
91 let (as1, cs1, t1) = self.apply(expr)?;
92
93 let mut as_ = as1.clone();
94 as_.remove(&bind.var);
95 let eq_cs = as1
96 .get(&bind.var)
97 .iter()
98 .map(|it| eq_cst(tv, it.clone()))
99 .collect();
100 let bound = bind
101 .ty
102 .as_ref()
103 .map_or(vec![], |it| vec![eq_cst(tv, it.clone())]);
104
105 Ok((as_, concat([cs1, eq_cs, bound]), fun1(tv, t1)))
106 }
107
108 fn apply_let(
109 &mut self,
110 language::Let {
111 binder,
112 bind,
113 usage,
114 }: &language::Let,
115 ) -> AWResult {
116 let (as1, cs1, t1) = self.apply(binder)?;
117 let (as2, cs2, t2) = self.apply(usage)?;
118
119 let mut as_ = as1.clone() + &as2;
120 as_.remove(&bind.var);
121 let im_cs = as2
122 .get(&bind.var)
123 .iter()
124 .chain(as1.get(&bind.var).iter()) .map(|it| implicit_cst(it.clone(), self.monomorphic_set.clone(), t1.clone()))
126 .collect();
127 let bound = bind.ty.as_ref().map_or(vec![], |it| {
128 vec![implicit_cst(
129 it.clone(),
130 self.monomorphic_set.clone(),
131 t1.clone(),
132 )]
133 });
134
135 Ok((as_, concat([cs1, cs2, im_cs, bound]), t2))
136 }
137
138 fn apply_tuple(&mut self, tuple: &[Language]) -> AWResult {
139 let ctx: Vec<_> = tuple.iter().map(|it| self.apply(it)).try_collect()?;
140 let (a, c, t): (Vec<_>, Vec<_>, Vec<_>) = ctx.into_iter().multiunzip();
141 Ok((
142 AssumptionSet::merge_many(a),
143 c.into_iter().flatten().collect(),
144 Tuple(t).into(),
145 ))
146 }
147
148 fn apply_special(&mut self, special: &Special) -> AWResult {
149 match special {
150 Special::If {
151 condition,
152 body,
153 otherwise,
154 } => {
155 let (as1, cs1, t1) = self.apply(condition)?;
156 let (as2, cs2, t2) = self.apply(body)?;
157 let (as3, cs3, t3) = self.apply(otherwise)?;
158
159 Ok((
160 as1 + as2 + as3,
161 concat([
162 cs1,
163 cs2,
164 cs3,
165 vec![eq_cst(t1, Boolean), eq_cst(t2.clone(), t3)],
166 ]),
167 t2,
168 ))
169 }
170 }
171 }
172}
173
174impl Language {
175 fn infer_w<E>(
176 &self,
177 context: &mut AlgorithmW,
178 table: &impl EnvironmentProvider<Var, Error = E>,
179 ) -> Result<(Substitutions, MonomorphicType), CompoundInferError<E>> {
180 let (a, c, t) = context.apply(self)?;
181 let (errors, not_found, explicits) = a.keys().fold(
182 (Vec::new(), Vec::new(), Vec::new()),
183 |(mut errors, mut not_found, mut cst), next| {
184 match table.maybe_get(next) {
185 Ok(None) => not_found.push(next.clone()),
186 Ok(Some(s)) => cst.extend(
187 a.get(next)
188 .iter()
189 .map(|it| explicit_cst(it.clone(), s.clone().into_owned())),
190 ),
191 Err(e) => errors.push(e),
192 }
193 (errors, not_found, cst)
194 },
195 );
196
197 if let Some(not_found) = NEVec::from_vec(not_found) {
198 if let Some(errors) = NEVec::from_vec(errors) {
199 return Err(CompoundInferError::Both(UnknownVar(not_found), errors))
200 }
201 return Err(CompoundInferError::AlgoW(UnknownVar(not_found)))
202 } else if let Some(errors) = NEVec::from_vec(errors) {
203 return Err(CompoundInferError::Foreign(errors))
204 }
205
206 debug!("Inferred raw type and constraints: ");
207 debug!("{c:?} ++ {explicits:?}, {t}");
208 let substitutions = Constraint::solve(concat([c, explicits]), context.env)
209 .map_err(AlgorithmWError::FailedConstraints)?;
210 let t = t.substitute(&substitutions);
211 debug!("Inferred type and substitutions: ");
212 debug!("{}, {}", substitutions, t);
213 Ok((substitutions, t))
214 }
215
216 pub(crate) fn infer_with_env<E>(
217 &self,
218 context: &impl EnvironmentProvider<Var, Error = E>,
219 env: &mut InferState,
220 ) -> Result<PolymorphicType, CompoundInferError<E>> {
221 let mut ctx = AlgorithmW {
222 monomorphic_set: Default::default(),
223 env,
224 };
225 match self.infer_w(&mut ctx, context) {
226 Ok((s, t)) => Ok(t.substitute(&s).normalize()),
227 Err(e) => Err(e),
228 }
229 }
230
231 pub fn infer<E>(
232 &self,
233 table: &impl EnvironmentProvider<Var, Error = E>,
234 ) -> Result<PolymorphicType, CompoundInferError<E>> {
235 self.infer_with_env(table, &mut InferState::default())
236 }
237}
238
239impl Display for AlgorithmWError {
240 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
241 match self {
242 AlgorithmWError::AlgorithmU(x) => write!(f, "{x}"),
243 UnknownVar(vs) => write!(
244 f,
245 "Unknown references: [{}]",
246 vs.iter().into_iter().join(", ")
247 ),
248 AlgorithmWError::FailedConstraints(x) => write!(f, "{x}"),
249 }
250 }
251}