kodept_inference/
algorithm_w.rs

1use itertools::{concat, Itertools};
2use nonempty_collections::NEVec;
3use std::collections::HashSet;
4use std::fmt::{Display, Formatter};
5use thiserror::Error;
6use tracing::debug;
7
8use crate::algorithm_u::AlgorithmUError;
9use crate::algorithm_w::AlgorithmWError::UnknownVar;
10use crate::assumption::AssumptionSet;
11use crate::constraint::{eq_cst, explicit_cst, implicit_cst, Constraint, ConstraintsSolverError};
12use crate::language::{Language, Literal, Special, Var};
13use crate::r#type::PrimitiveType::Boolean;
14use crate::r#type::{fun1, MonomorphicType, PolymorphicType, PrimitiveType, TVar, Tuple};
15use crate::substitution::Substitutions;
16use crate::traits::{EnvironmentProvider, Substitutable};
17use crate::{language, InferState};
18
19#[derive(Debug, Error)]
20pub enum AlgorithmWError {
21    #[error(transparent)]
22    AlgorithmU(#[from] AlgorithmUError),
23    UnknownVar(NEVec<Var>),
24    #[error(transparent)]
25    FailedConstraints(#[from] ConstraintsSolverError),
26}
27
28#[derive(Debug, Error)]
29pub enum CompoundInferError<E> {
30    #[error(transparent)]
31    AlgoW(#[from] AlgorithmWError),
32    Both(AlgorithmWError, NEVec<E>),
33    Foreign(NEVec<E>)
34}
35
36struct AlgorithmW<'e> {
37    monomorphic_set: HashSet<TVar>,
38    env: &'e mut InferState,
39}
40
41type AWResult = Result<(AssumptionSet, Vec<Constraint>, MonomorphicType), AlgorithmWError>;
42
43impl<'e> AlgorithmW<'e> {
44    fn apply(&mut self, expression: &Language) -> AWResult {
45        match expression {
46            Language::Var(x) => self.apply_var(x),
47            Language::App(x) => self.apply_app(x),
48            Language::Lambda(x) => self.apply_lambda(x),
49            Language::Let(x) => self.apply_let(x),
50            Language::Special(x) => self.apply_special(x),
51            Language::Literal(x) => match x {
52                Literal::Integral(_) => Ok((
53                    AssumptionSet::empty(),
54                    vec![],
55                    PrimitiveType::Integral.into(),
56                )),
57                Literal::Floating(_) => Ok((
58                    AssumptionSet::empty(),
59                    vec![],
60                    PrimitiveType::Floating.into(),
61                )),
62                Literal::Tuple(vec) => self.apply_tuple(vec),
63            },
64        }
65    }
66
67    fn apply_var(&mut self, var: &Var) -> AWResult {
68        let fresh = self.env.new_var();
69        Ok((
70            AssumptionSet::single(var.clone(), fresh),
71            vec![],
72            fresh.into(),
73        ))
74    }
75
76    fn apply_app(&mut self, language::App { arg, func }: &language::App) -> AWResult {
77        let (as1, cs1, t1) = self.apply(func)?;
78        let (as2, cs2, t2) = self.apply(arg)?;
79        let tv = self.env.new_var();
80
81        Ok((
82            as1 + as2,
83            concat([cs1, cs2, vec![eq_cst(t1, fun1(t2, tv))]]),
84            tv.into(),
85        ))
86    }
87
88    fn apply_lambda(&mut self, language::Lambda { bind, expr }: &language::Lambda) -> AWResult {
89        let tv = self.env.new_var();
90        self.monomorphic_set.insert(tv);
91        let (as1, cs1, t1) = self.apply(expr)?;
92
93        let mut as_ = as1.clone();
94        as_.remove(&bind.var);
95        let eq_cs = as1
96            .get(&bind.var)
97            .iter()
98            .map(|it| eq_cst(tv, it.clone()))
99            .collect();
100        let bound = bind
101            .ty
102            .as_ref()
103            .map_or(vec![], |it| vec![eq_cst(tv, it.clone())]);
104
105        Ok((as_, concat([cs1, eq_cs, bound]), fun1(tv, t1)))
106    }
107
108    fn apply_let(
109        &mut self,
110        language::Let {
111            binder,
112            bind,
113            usage,
114        }: &language::Let,
115    ) -> AWResult {
116        let (as1, cs1, t1) = self.apply(binder)?;
117        let (as2, cs2, t2) = self.apply(usage)?;
118
119        let mut as_ = as1.clone() + &as2;
120        as_.remove(&bind.var);
121        let im_cs = as2
122            .get(&bind.var)
123            .iter()
124            .chain(as1.get(&bind.var).iter()) // support for fix
125            .map(|it| implicit_cst(it.clone(), self.monomorphic_set.clone(), t1.clone()))
126            .collect();
127        let bound = bind.ty.as_ref().map_or(vec![], |it| {
128            vec![implicit_cst(
129                it.clone(),
130                self.monomorphic_set.clone(),
131                t1.clone(),
132            )]
133        });
134
135        Ok((as_, concat([cs1, cs2, im_cs, bound]), t2))
136    }
137
138    fn apply_tuple(&mut self, tuple: &[Language]) -> AWResult {
139        let ctx: Vec<_> = tuple.iter().map(|it| self.apply(it)).try_collect()?;
140        let (a, c, t): (Vec<_>, Vec<_>, Vec<_>) = ctx.into_iter().multiunzip();
141        Ok((
142            AssumptionSet::merge_many(a),
143            c.into_iter().flatten().collect(),
144            Tuple(t).into(),
145        ))
146    }
147
148    fn apply_special(&mut self, special: &Special) -> AWResult {
149        match special {
150            Special::If {
151                condition,
152                body,
153                otherwise,
154            } => {
155                let (as1, cs1, t1) = self.apply(condition)?;
156                let (as2, cs2, t2) = self.apply(body)?;
157                let (as3, cs3, t3) = self.apply(otherwise)?;
158
159                Ok((
160                    as1 + as2 + as3,
161                    concat([
162                        cs1,
163                        cs2,
164                        cs3,
165                        vec![eq_cst(t1, Boolean), eq_cst(t2.clone(), t3)],
166                    ]),
167                    t2,
168                ))
169            }
170        }
171    }
172}
173
174impl Language {
175    fn infer_w<E>(
176        &self,
177        context: &mut AlgorithmW,
178        table: &impl EnvironmentProvider<Var, Error = E>,
179    ) -> Result<(Substitutions, MonomorphicType), CompoundInferError<E>> {
180        let (a, c, t) = context.apply(self)?;
181        let (errors, not_found, explicits) = a.keys().fold(
182            (Vec::new(), Vec::new(), Vec::new()),
183            |(mut errors, mut not_found, mut cst), next| {
184                match table.maybe_get(next) {
185                    Ok(None) => not_found.push(next.clone()),
186                    Ok(Some(s)) => cst.extend(
187                        a.get(next)
188                            .iter()
189                            .map(|it| explicit_cst(it.clone(), s.clone().into_owned())),
190                    ),
191                    Err(e) => errors.push(e),
192                }
193                (errors, not_found, cst)
194            },
195        );
196
197        if let Some(not_found) = NEVec::from_vec(not_found) {
198            if let Some(errors) = NEVec::from_vec(errors) {
199                return Err(CompoundInferError::Both(UnknownVar(not_found), errors))
200            }
201            return Err(CompoundInferError::AlgoW(UnknownVar(not_found)))
202        } else if let Some(errors) = NEVec::from_vec(errors) {
203            return Err(CompoundInferError::Foreign(errors))
204        }
205
206        debug!("Inferred raw type and constraints: ");
207        debug!("{c:?} ++ {explicits:?}, {t}");
208        let substitutions = Constraint::solve(concat([c, explicits]), context.env)
209            .map_err(AlgorithmWError::FailedConstraints)?;
210        let t = t.substitute(&substitutions);
211        debug!("Inferred type and substitutions: ");
212        debug!("{}, {}", substitutions, t);
213        Ok((substitutions, t))
214    }
215
216    pub(crate) fn infer_with_env<E>(
217        &self,
218        context: &impl EnvironmentProvider<Var, Error = E>,
219        env: &mut InferState,
220    ) -> Result<PolymorphicType, CompoundInferError<E>> {
221        let mut ctx = AlgorithmW {
222            monomorphic_set: Default::default(),
223            env,
224        };
225        match self.infer_w(&mut ctx, context) {
226            Ok((s, t)) => Ok(t.substitute(&s).normalize()),
227            Err(e) => Err(e),
228        }
229    }
230
231    pub fn infer<E>(
232        &self,
233        table: &impl EnvironmentProvider<Var, Error = E>,
234    ) -> Result<PolymorphicType, CompoundInferError<E>> {
235        self.infer_with_env(table, &mut InferState::default())
236    }
237}
238
239impl Display for AlgorithmWError {
240    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
241        match self {
242            AlgorithmWError::AlgorithmU(x) => write!(f, "{x}"),
243            UnknownVar(vs) => write!(
244                f,
245                "Unknown references: [{}]",
246                vs.iter().into_iter().join(", ")
247            ),
248            AlgorithmWError::FailedConstraints(x) => write!(f, "{x}"),
249        }
250    }
251}