erg_compiler/context/generalize.rs

1use std::mem;
2
3use erg_common::consts::DEBUG_MODE;
4use erg_common::set::Set;
5use erg_common::traits::{Locational, Stream};
6use erg_common::Str;
7use erg_common::{dict, fn_name, get_hash, set};
8#[allow(unused_imports)]
9use erg_common::{fmt_vec, log};
10
11use crate::hir::GuardClause;
12use crate::module::GeneralizationResult;
13use crate::ty::constructors::*;
14use crate::ty::free::{CanbeFree, Constraint, Free, HasLevel};
15use crate::ty::typaram::{TyParam, TyParamLambda};
16use crate::ty::value::ValueObj;
17use crate::ty::{HasType, Predicate, SharedFrees, SubrType, Type};
18
19use crate::context::{Context, Variance};
20use crate::error::{TyCheckError, TyCheckErrors, TyCheckResult};
21use crate::{feature_error, hir, mono_type_pattern, mono_value_pattern, unreachable_error};
22
23use Type::*;
24use Variance::*;
25
26use super::eval::{Substituter, UndoableLinkedList};
27
28pub struct Generalizer<'c> {
29    ctx: &'c Context,
30    variance: Variance,
31    qnames: Set<Str>,
32    structural_inner: bool,
33}
34
35impl<'c> Generalizer<'c> {
36    pub fn new(ctx: &'c Context) -> Self {
37        Self {
38            ctx,
39            variance: Covariant,
40            qnames: set! {},
41            structural_inner: false,
42        }
43    }
44
45    fn generalize_tp(&mut self, free: TyParam, uninit: bool) -> TyParam {
46        match free {
47            TyParam::Type(t) => TyParam::t(self.generalize_t(*t, uninit)),
48            TyParam::Value(val) => {
49                TyParam::Value(val.map_t(&mut |t| self.generalize_t(t, uninit)).map_tp(
50                    &mut |tp| self.generalize_tp(tp, uninit),
51                    &SharedFrees::new(),
52                ))
53            }
54            TyParam::FreeVar(fv) if fv.is_generalized() => TyParam::FreeVar(fv),
55            TyParam::FreeVar(fv) if fv.is_linked() => {
56                let tp = fv.crack().clone();
57                self.generalize_tp(tp, uninit)
58            }
59            // TODO: Polymorphic generalization
60            TyParam::FreeVar(fv) if fv.level() > Some(self.ctx.level) => {
61                let constr = self.generalize_constraint(&fv);
62                fv.update_constraint(constr, true);
63                fv.generalize();
64                TyParam::FreeVar(fv)
65            }
66            TyParam::List(tps) => TyParam::List(
67                tps.into_iter()
68                    .map(|tp| self.generalize_tp(tp, uninit))
69                    .collect(),
70            ),
71            TyParam::UnsizedList(tp) => {
72                TyParam::UnsizedList(Box::new(self.generalize_tp(*tp, uninit)))
73            }
74            TyParam::Tuple(tps) => TyParam::Tuple(
75                tps.into_iter()
76                    .map(|tp| self.generalize_tp(tp, uninit))
77                    .collect(),
78            ),
79            TyParam::Set(set) => TyParam::Set(
80                set.into_iter()
81                    .map(|tp| self.generalize_tp(tp, uninit))
82                    .collect(),
83            ),
84            TyParam::Dict(tps) => TyParam::Dict(
85                tps.into_iter()
86                    .map(|(k, v)| (self.generalize_tp(k, uninit), self.generalize_tp(v, uninit)))
87                    .collect(),
88            ),
89            TyParam::Record(rec) => TyParam::Record(
90                rec.into_iter()
91                    .map(|(field, tp)| (field, self.generalize_tp(tp, uninit)))
92                    .collect(),
93            ),
94            TyParam::DataClass { name, fields } => {
95                let fields = fields
96                    .into_iter()
97                    .map(|(field, tp)| (field, self.generalize_tp(tp, uninit)))
98                    .collect();
99                TyParam::DataClass { name, fields }
100            }
101            TyParam::Lambda(lambda) => {
102                let nd_params = lambda
103                    .nd_params
104                    .into_iter()
105                    .map(|pt| pt.map_type(&mut |t| self.generalize_t(t, uninit)))
106                    .collect::<Vec<_>>();
107                let var_params = lambda
108                    .var_params
109                    .map(|pt| pt.map_type(&mut |t| self.generalize_t(t, uninit)));
110                let d_params = lambda
111                    .d_params
112                    .into_iter()
113                    .map(|pt| pt.map_type(&mut |t| self.generalize_t(t, uninit)))
114                    .collect::<Vec<_>>();
115                let kw_var_params = lambda
116                    .kw_var_params
117                    .map(|pt| pt.map_type(&mut |t| self.generalize_t(t, uninit)));
118                let body = lambda
119                    .body
120                    .into_iter()
121                    .map(|tp| self.generalize_tp(tp, uninit))
122                    .collect();
123                TyParam::Lambda(TyParamLambda::new(
124                    lambda.const_,
125                    nd_params,
126                    var_params,
127                    d_params,
128                    kw_var_params,
129                    body,
130                ))
131            }
132            TyParam::FreeVar(_) => free,
133            TyParam::Proj { obj, attr } => {
134                let obj = self.generalize_tp(*obj, uninit);
135                TyParam::proj(obj, attr)
136            }
137            TyParam::ProjCall { obj, attr, args } => {
138                let obj = self.generalize_tp(*obj, uninit);
139                let args = args
140                    .into_iter()
141                    .map(|tp| self.generalize_tp(tp, uninit))
142                    .collect();
143                TyParam::proj_call(obj, attr, args)
144            }
145            TyParam::Erased(t) => TyParam::erased(self.generalize_t(*t, uninit)),
146            TyParam::App { name, args } => {
147                let args = args
148                    .into_iter()
149                    .map(|tp| self.generalize_tp(tp, uninit))
150                    .collect();
151                TyParam::App { name, args }
152            }
153            TyParam::BinOp { op, lhs, rhs } => {
154                let lhs = self.generalize_tp(*lhs, uninit);
155                let rhs = self.generalize_tp(*rhs, uninit);
156                TyParam::bin(op, lhs, rhs)
157            }
158            TyParam::UnaryOp { op, val } => {
159                let val = self.generalize_tp(*val, uninit);
160                TyParam::unary(op, val)
161            }
162            TyParam::Mono(_) | TyParam::Failure => free,
163        }
164    }
165
166    /// see doc/LANG/compiler/inference.md#一般化 for details
167    /// ```python
168    /// generalize_t(?T) == 'T: Type
169    /// generalize_t(?T(<: Nat) -> ?T) == |'T <: Nat| 'T -> 'T
170    /// generalize_t(?T(<: Add(?T(<: Eq(?T(<: ...)))) -> ?T) == |'T <: Add('T)| 'T -> 'T
171    /// generalize_t(?T(<: TraitX) -> Int) == TraitX -> Int // 戻り値に現れないなら量化しない
172    /// ```
173    fn generalize_t(&mut self, free_type: Type, uninit: bool) -> Type {
174        match free_type {
175            FreeVar(fv) if fv.is_linked() => self.generalize_t(fv.unwrap_linked(), uninit),
176            FreeVar(fv) if fv.is_generalized() => Type::FreeVar(fv),
177            // TODO: Polymorphic generalization
178            FreeVar(fv) if fv.level().unwrap() > self.ctx.level => {
179                fv.generalize();
180                if uninit {
181                    return Type::FreeVar(fv);
182                }
183                if let Some((sub, sup)) = fv.get_subsup() {
184                    // |Int <: T <: Int| T -> T ==> Int -> Int
185                    if sub == sup {
186                        let t = self.generalize_t(sub, uninit);
187                        let res = FreeVar(fv);
188                        res.set_level(1);
189                        res.destructive_link(&t);
190                        res.generalize();
191                        res
192                    } else if sup != Obj
193                        && self.variance == Contravariant
194                        && !self.qnames.contains(&fv.unbound_name().unwrap())
195                    {
196                        // |T <: Bool| T -> Int ==> Bool -> Int
197                        self.generalize_t(sup, uninit)
198                    } else if sub != Never
199                        && self.variance == Covariant
200                        && !self.qnames.contains(&fv.unbound_name().unwrap())
201                    {
202                        // |T :> Int| X -> T ==> X -> Int
203                        self.generalize_t(sub, uninit)
204                    } else {
205                        let constr = self.generalize_constraint(&fv);
206                        let ty = Type::FreeVar(fv);
207                        ty.update_constraint(constr, None, true);
208                        ty
209                    }
210                } else {
211                    // ?S(: Str) => 'S
212                    let constr = self.generalize_constraint(&fv);
213                    let ty = Type::FreeVar(fv);
214                    ty.update_constraint(constr, None, true);
215                    ty
216                }
217            }
218            FreeVar(_) => free_type,
219            Subr(mut subr) => {
220                self.variance = Contravariant;
221                let qnames = subr.essential_qnames();
222                self.qnames.extend(qnames.clone());
223                subr.non_default_params.iter_mut().for_each(|nd_param| {
224                    *nd_param.typ_mut() = self.generalize_t(mem::take(nd_param.typ_mut()), uninit);
225                });
226                if let Some(var_params) = &mut subr.var_params {
227                    *var_params.typ_mut() =
228                        self.generalize_t(mem::take(var_params.typ_mut()), uninit);
229                }
230                subr.default_params.iter_mut().for_each(|d_param| {
231                    *d_param.typ_mut() = self.generalize_t(mem::take(d_param.typ_mut()), uninit);
232                    if let Some(default) = d_param.default_typ_mut() {
233                        *default = self.generalize_t(mem::take(default), uninit);
234                    }
235                });
236                if let Some(kw_var_params) = &mut subr.kw_var_params {
237                    *kw_var_params.typ_mut() =
238                        self.generalize_t(mem::take(kw_var_params.typ_mut()), uninit);
239                    if let Some(default) = kw_var_params.default_typ_mut() {
240                        *default = self.generalize_t(mem::take(default), uninit);
241                    }
242                }
243                self.variance = Covariant;
244                let return_t = self.generalize_t(*subr.return_t, uninit);
245                self.qnames = self.qnames.difference(&qnames);
246                subr_t(
247                    subr.kind,
248                    subr.non_default_params,
249                    subr.var_params.map(|x| *x),
250                    subr.default_params,
251                    subr.kw_var_params.map(|x| *x),
252                    return_t,
253                )
254            }
255            Quantified(quant) => {
256                log!(err "{quant}");
257                quant.quantify()
258            }
259            Record(rec) => {
260                let fields = rec
261                    .into_iter()
262                    .map(|(name, t)| (name, self.generalize_t(t, uninit)))
263                    .collect();
264                Type::Record(fields)
265            }
266            NamedTuple(rec) => {
267                let fields = rec
268                    .into_iter()
269                    .map(|(name, t)| (name, self.generalize_t(t, uninit)))
270                    .collect();
271                Type::NamedTuple(fields)
272            }
273            Callable { param_ts, return_t } => {
274                let param_ts = param_ts
275                    .into_iter()
276                    .map(|t| self.generalize_t(t, uninit))
277                    .collect();
278                let return_t = self.generalize_t(*return_t, uninit);
279                callable(param_ts, return_t)
280            }
281            Ref(t) => ref_(self.generalize_t(*t, uninit)),
282            RefMut { before, after } => {
283                let after = after.map(|aft| self.generalize_t(*aft, uninit));
284                ref_mut(self.generalize_t(*before, uninit), after)
285            }
286            Refinement(refine) => {
287                let t = self.generalize_t(*refine.t, uninit);
288                let pred = self.generalize_pred(*refine.pred, uninit);
289                refinement(refine.var, t, pred)
290            }
291            Poly { name, mut params } => {
292                let params = params
293                    .iter_mut()
294                    .map(|p| self.generalize_tp(mem::take(p), uninit))
295                    .collect::<Vec<_>>();
296                poly(name, params)
297            }
298            Proj { lhs, rhs } => {
299                let lhs = self.generalize_t(*lhs, uninit);
300                proj(lhs, rhs)
301            }
302            ProjCall {
303                lhs,
304                attr_name,
305                mut args,
306            } => {
307                let lhs = self.generalize_tp(*lhs, uninit);
308                for arg in args.iter_mut() {
309                    *arg = self.generalize_tp(mem::take(arg), uninit);
310                }
311                proj_call(lhs, attr_name, args)
312            }
313            And(ands, idx) => {
314                // not `self.intersection` because types are generalized
315                let ands = ands
316                    .into_iter()
317                    .map(|t| self.generalize_t(t, uninit))
318                    .collect::<Vec<_>>();
319                let isec = ands
320                    .into_iter()
321                    .fold(Obj, |acc, t| self.ctx.intersection(&acc, &t));
322                if let Some(idx) = idx {
323                    isec.with_default_intersec_index(idx)
324                } else {
325                    isec
326                }
327            }
328            Or(ors) => {
329                // not `self.union` because types are generalized
330                let ors = ors
331                    .into_iter()
332                    .map(|t| self.generalize_t(t, uninit))
333                    .collect::<Set<_>>();
334                ors.into_iter()
335                    .fold(Never, |acc, t| self.ctx.union(&acc, &t))
336            }
337            Not(l) => not(self.generalize_t(*l, uninit)),
338            Structural(ty) => {
339                if self.structural_inner {
340                    ty.structuralize()
341                } else {
342                    if ty.is_recursive() {
343                        self.structural_inner = true;
344                    }
345                    let res = self.generalize_t(*ty, uninit).structuralize();
346                    self.structural_inner = false;
347                    res
348                }
349            }
350            Guard(grd) => {
351                let to = self.generalize_t(*grd.to, uninit);
352                guard(grd.namespace, *grd.target, to)
353            }
354            Bounded { sub, sup } => {
355                let sub = self.generalize_t(*sub, uninit);
356                let sup = self.generalize_t(*sup, uninit);
357                bounded(sub, sup)
358            }
359            Int | Nat | Float | Ratio | Complex | Bool | Str | Never | Obj | Type | Error
360            | Code | Frame | NoneType | Inf | NegInf | NotImplementedType | Ellipsis
361            | ClassType | TraitType | Patch | Failure | Uninited | Mono(_) => free_type,
362        }
363    }
364
365    fn generalize_constraint<T: CanbeFree + Send + Clone>(&mut self, fv: &Free<T>) -> Constraint {
366        if let Some((sub, sup)) = fv.get_subsup() {
367            let sub = self.generalize_t(sub, true);
368            let sup = self.generalize_t(sup, true);
369            Constraint::new_sandwiched(sub, sup)
370        } else if let Some(ty) = fv.get_type() {
371            let t = self.generalize_t(ty, true);
372            Constraint::new_type_of(t)
373        } else {
374            unreachable!()
375        }
376    }
377
378    fn generalize_pred(&mut self, pred: Predicate, uninit: bool) -> Predicate {
379        match pred {
380            Predicate::Const(_) | Predicate::Failure => pred,
381            Predicate::Value(val) => {
382                Predicate::Value(val.map_t(&mut |t| self.generalize_t(t, uninit)))
383            }
384            Predicate::Call {
385                receiver,
386                name,
387                args,
388            } => {
389                let receiver = self.generalize_tp(receiver, uninit);
390                let mut new_args = vec![];
391                for arg in args.into_iter() {
392                    new_args.push(self.generalize_tp(arg, uninit));
393                }
394                Predicate::call(receiver, name, new_args)
395            }
396            Predicate::Attr { receiver, name } => {
397                let receiver = self.generalize_tp(receiver, uninit);
398                Predicate::attr(receiver, name)
399            }
400            Predicate::GeneralEqual { lhs, rhs } => {
401                let lhs = self.generalize_pred(*lhs, uninit);
402                let rhs = self.generalize_pred(*rhs, uninit);
403                Predicate::general_eq(lhs, rhs)
404            }
405            Predicate::GeneralGreaterEqual { lhs, rhs } => {
406                let lhs = self.generalize_pred(*lhs, uninit);
407                let rhs = self.generalize_pred(*rhs, uninit);
408                Predicate::general_ge(lhs, rhs)
409            }
410            Predicate::GeneralLessEqual { lhs, rhs } => {
411                let lhs = self.generalize_pred(*lhs, uninit);
412                let rhs = self.generalize_pred(*rhs, uninit);
413                Predicate::general_le(lhs, rhs)
414            }
415            Predicate::GeneralNotEqual { lhs, rhs } => {
416                let lhs = self.generalize_pred(*lhs, uninit);
417                let rhs = self.generalize_pred(*rhs, uninit);
418                Predicate::general_ne(lhs, rhs)
419            }
420            Predicate::Equal { lhs, rhs } => {
421                let rhs = self.generalize_tp(rhs, uninit);
422                Predicate::eq(lhs, rhs)
423            }
424            Predicate::GreaterEqual { lhs, rhs } => {
425                let rhs = self.generalize_tp(rhs, uninit);
426                Predicate::ge(lhs, rhs)
427            }
428            Predicate::LessEqual { lhs, rhs } => {
429                let rhs = self.generalize_tp(rhs, uninit);
430                Predicate::le(lhs, rhs)
431            }
432            Predicate::NotEqual { lhs, rhs } => {
433                let rhs = self.generalize_tp(rhs, uninit);
434                Predicate::ne(lhs, rhs)
435            }
436            Predicate::And(lhs, rhs) => {
437                let lhs = self.generalize_pred(*lhs, uninit);
438                let rhs = self.generalize_pred(*rhs, uninit);
439                Predicate::and(lhs, rhs)
440            }
441            Predicate::Or(preds) => Predicate::Or(
442                preds
443                    .into_iter()
444                    .map(|pred| self.generalize_pred(pred, uninit))
445                    .collect(),
446            ),
447            Predicate::Not(pred) => {
448                let pred = self.generalize_pred(*pred, uninit);
449                !pred
450            }
451        }
452    }
453}
454
455pub struct Dereferencer<'c, 'q, 'l, L: Locational> {
456    ctx: &'c Context,
457    /// This is basically the same as `ctx.level`, but can be changed
458    level: usize,
459    coerce: bool,
460    variance_stack: Vec<Variance>,
461    qnames: &'q Set<Str>,
462    loc: &'l L,
463}
464
465impl<'c, 'q, 'l, L: Locational> Dereferencer<'c, 'q, 'l, L> {
466    pub fn new(
467        ctx: &'c Context,
468        variance: Variance,
469        coerce: bool,
470        qnames: &'q Set<Str>,
471        loc: &'l L,
472    ) -> Self {
473        Self {
474            ctx,
475            level: ctx.level,
476            coerce,
477            variance_stack: vec![Invariant, variance],
478            qnames,
479            loc,
480        }
481    }
482
483    pub fn simple(ctx: &'c Context, qnames: &'q Set<Str>, loc: &'l L) -> Self {
484        Self::new(ctx, Variance::Covariant, true, qnames, loc)
485    }
486
487    pub fn set_level(&mut self, level: usize) {
488        self.level = level;
489    }
490
491    fn push_variance(&mut self, variance: Variance) {
492        self.variance_stack.push(variance);
493    }
494
495    fn pop_variance(&mut self) {
496        self.variance_stack.pop();
497    }
498
499    fn current_variance(&self) -> Variance {
500        *self.variance_stack.last().unwrap()
501    }
502
503    fn deref_value(&mut self, val: ValueObj) -> TyCheckResult<ValueObj> {
504        match val {
505            ValueObj::Type(mut t) => {
506                t.try_map_t(&mut |t| self.deref_tyvar(t.clone()))?;
507                Ok(ValueObj::Type(t))
508            }
509            ValueObj::List(vs) => {
510                let mut new_vs = vec![];
511                for v in vs.iter() {
512                    new_vs.push(self.deref_value(v.clone())?);
513                }
514                Ok(ValueObj::List(new_vs.into()))
515            }
516            ValueObj::Tuple(vs) => {
517                let mut new_vs = vec![];
518                for v in vs.iter() {
519                    new_vs.push(self.deref_value(v.clone())?);
520                }
521                Ok(ValueObj::Tuple(new_vs.into()))
522            }
523            ValueObj::Dict(dic) => {
524                let mut new_dic = dict! {};
525                for (k, v) in dic.into_iter() {
526                    let k = self.deref_value(k)?;
527                    let v = self.deref_value(v)?;
528                    new_dic.insert(k, v);
529                }
530                Ok(ValueObj::Dict(new_dic))
531            }
532            ValueObj::Set(set) => {
533                let mut new_set = set! {};
534                for v in set.into_iter() {
535                    new_set.insert(self.deref_value(v)?);
536                }
537                Ok(ValueObj::Set(new_set))
538            }
539            ValueObj::Record(rec) => {
540                let mut new_rec = dict! {};
541                for (field, v) in rec.into_iter() {
542                    new_rec.insert(field, self.deref_value(v)?);
543                }
544                Ok(ValueObj::Record(new_rec))
545            }
546            ValueObj::DataClass { name, fields } => {
547                let mut new_fields = dict! {};
548                for (field, v) in fields.into_iter() {
549                    new_fields.insert(field, self.deref_value(v)?);
550                }
551                Ok(ValueObj::DataClass {
552                    name,
553                    fields: new_fields,
554                })
555            }
556            ValueObj::UnsizedList(v) => Ok(ValueObj::UnsizedList(Box::new(self.deref_value(*v)?))),
557            ValueObj::Subr(subr) => Ok(ValueObj::Subr(subr)),
558            mono_value_pattern!() => Ok(val),
559        }
560    }
561
562    pub(crate) fn deref_tp(&mut self, tp: TyParam) -> TyCheckResult<TyParam> {
563        match tp {
564            TyParam::FreeVar(fv) if fv.is_linked() => {
565                let inner = fv.unwrap_linked();
566                self.deref_tp(inner)
567            }
568            TyParam::FreeVar(fv)
569                if fv.is_generalized() && self.qnames.contains(&fv.unbound_name().unwrap()) =>
570            {
571                Ok(TyParam::FreeVar(fv))
572            }
573            // REVIEW:
574            TyParam::FreeVar(_) if self.level == 0 => {
575                let t = self.ctx.get_tp_t(&tp).unwrap_or(Type::Obj);
576                Ok(TyParam::erased(self.deref_tyvar(t)?))
577            }
578            TyParam::FreeVar(fv) if fv.get_type().is_some() => {
579                let t = self.deref_tyvar(fv.get_type().unwrap())?;
580                fv.update_type(t);
581                Ok(TyParam::FreeVar(fv))
582            }
583            TyParam::FreeVar(_) => Ok(tp),
584            TyParam::Type(t) => Ok(TyParam::t(self.deref_tyvar(*t)?)),
585            TyParam::Value(val) => self.deref_value(val).map(TyParam::Value),
586            TyParam::Erased(t) => Ok(TyParam::erased(self.deref_tyvar(*t)?)),
587            TyParam::App { name, mut args } => {
588                for param in args.iter_mut() {
589                    *param = self.deref_tp(mem::take(param))?;
590                }
591                Ok(TyParam::App { name, args })
592            }
593            TyParam::BinOp { op, lhs, rhs } => {
594                let lhs = self.deref_tp(*lhs)?;
595                let rhs = self.deref_tp(*rhs)?;
596                Ok(TyParam::BinOp {
597                    op,
598                    lhs: Box::new(lhs),
599                    rhs: Box::new(rhs),
600                })
601            }
602            TyParam::UnaryOp { op, val } => {
603                let val = self.deref_tp(*val)?;
604                Ok(TyParam::UnaryOp {
605                    op,
606                    val: Box::new(val),
607                })
608            }
609            TyParam::List(tps) => {
610                let mut new_tps = vec![];
611                for tp in tps {
612                    new_tps.push(self.deref_tp(tp)?);
613                }
614                Ok(TyParam::List(new_tps))
615            }
616            TyParam::UnsizedList(tp) => Ok(TyParam::UnsizedList(Box::new(self.deref_tp(*tp)?))),
617            TyParam::Tuple(tps) => {
618                let mut new_tps = vec![];
619                for tp in tps {
620                    new_tps.push(self.deref_tp(tp)?);
621                }
622                Ok(TyParam::Tuple(new_tps))
623            }
624            TyParam::Dict(dic) => {
625                let mut new_dic = dict! {};
626                for (k, v) in dic.into_iter() {
627                    let k = self.deref_tp(k)?;
628                    let v = self.deref_tp(v)?;
629                    new_dic
630                        .entry(k)
631                        .and_modify(|old_v| {
632                            if let Some(union) = self.ctx.union_tp(&mem::take(old_v), &v) {
633                                *old_v = union;
634                            }
635                        })
636                        .or_insert(v);
637                }
638                Ok(TyParam::Dict(new_dic))
639            }
640            TyParam::Set(set) => {
641                let mut new_set = set! {};
642                for v in set.into_iter() {
643                    new_set.insert(self.deref_tp(v)?);
644                }
645                Ok(TyParam::Set(new_set))
646            }
647            TyParam::Record(rec) => {
648                let mut new_rec = dict! {};
649                for (field, tp) in rec.into_iter() {
650                    new_rec.insert(field, self.deref_tp(tp)?);
651                }
652                Ok(TyParam::Record(new_rec))
653            }
654            TyParam::DataClass { name, fields } => {
655                let mut new_fields = dict! {};
656                for (field, tp) in fields.into_iter() {
657                    new_fields.insert(field, self.deref_tp(tp)?);
658                }
659                Ok(TyParam::DataClass {
660                    name,
661                    fields: new_fields,
662                })
663            }
664            TyParam::Lambda(lambda) => {
665                let nd_params = lambda
666                    .nd_params
667                    .into_iter()
668                    .map(|pt| pt.try_map_type(&mut |t| self.deref_tyvar(t)))
669                    .collect::<TyCheckResult<_>>()?;
670                let var_params = lambda
671                    .var_params
672                    .map(|pt| pt.try_map_type(&mut |t| self.deref_tyvar(t)))
673                    .transpose()?;
674                let d_params = lambda
675                    .d_params
676                    .into_iter()
677                    .map(|pt| pt.try_map_type(&mut |t| self.deref_tyvar(t)))
678                    .collect::<TyCheckResult<_>>()?;
679                let kw_var_params = lambda
680                    .kw_var_params
681                    .map(|pt| pt.try_map_type(&mut |t| self.deref_tyvar(t)))
682                    .transpose()?;
683                let body = lambda
684                    .body
685                    .into_iter()
686                    .map(|tp| self.deref_tp(tp))
687                    .collect::<TyCheckResult<Vec<_>>>()?;
688                Ok(TyParam::Lambda(TyParamLambda::new(
689                    lambda.const_,
690                    nd_params,
691                    var_params,
692                    d_params,
693                    kw_var_params,
694                    body,
695                )))
696            }
697            TyParam::Proj { obj, attr } => {
698                let obj = self.deref_tp(*obj)?;
699                Ok(TyParam::Proj {
700                    obj: Box::new(obj),
701                    attr,
702                })
703            }
704            TyParam::ProjCall { obj, attr, args } => {
705                let obj = self.deref_tp(*obj)?;
706                let mut new_args = vec![];
707                for arg in args.into_iter() {
708                    new_args.push(self.deref_tp(arg)?);
709                }
710                Ok(TyParam::ProjCall {
711                    obj: Box::new(obj),
712                    attr,
713                    args: new_args,
714                })
715            }
716            TyParam::Mono(_) | TyParam::Failure => Ok(tp),
717        }
718    }
719
720    fn deref_pred(&mut self, pred: Predicate) -> TyCheckResult<Predicate> {
721        match pred {
722            Predicate::Equal { lhs, rhs } => {
723                let rhs = self.deref_tp(rhs)?;
724                Ok(Predicate::eq(lhs, rhs))
725            }
726            Predicate::GreaterEqual { lhs, rhs } => {
727                let rhs = self.deref_tp(rhs)?;
728                Ok(Predicate::ge(lhs, rhs))
729            }
730            Predicate::LessEqual { lhs, rhs } => {
731                let rhs = self.deref_tp(rhs)?;
732                Ok(Predicate::le(lhs, rhs))
733            }
734            Predicate::NotEqual { lhs, rhs } => {
735                let rhs = self.deref_tp(rhs)?;
736                Ok(Predicate::ne(lhs, rhs))
737            }
738            Predicate::GeneralEqual { lhs, rhs } => {
739                let lhs = self.deref_pred(*lhs)?;
740                let rhs = self.deref_pred(*rhs)?;
741                match (lhs, rhs) {
742                    (Predicate::Value(lhs), Predicate::Value(rhs)) => {
743                        Ok(Predicate::Value(ValueObj::Bool(lhs == rhs)))
744                    }
745                    (lhs, rhs) => Ok(Predicate::general_eq(lhs, rhs)),
746                }
747            }
748            Predicate::GeneralNotEqual { lhs, rhs } => {
749                let lhs = self.deref_pred(*lhs)?;
750                let rhs = self.deref_pred(*rhs)?;
751                match (lhs, rhs) {
752                    (Predicate::Value(lhs), Predicate::Value(rhs)) => {
753                        Ok(Predicate::Value(ValueObj::Bool(lhs != rhs)))
754                    }
755                    (lhs, rhs) => Ok(Predicate::general_ne(lhs, rhs)),
756                }
757            }
758            Predicate::GeneralGreaterEqual { lhs, rhs } => {
759                let lhs = self.deref_pred(*lhs)?;
760                let rhs = self.deref_pred(*rhs)?;
761                match (lhs, rhs) {
762                    (Predicate::Value(lhs), Predicate::Value(rhs)) => {
763                        let Some(ValueObj::Bool(res)) = lhs.try_ge(rhs) else {
764                            // TODO:
765                            return Err(TyCheckErrors::from(TyCheckError::dummy_infer_error(
766                                self.ctx.cfg.input.clone(),
767                                fn_name!(),
768                                line!(),
769                            )));
770                        };
771                        Ok(Predicate::Value(ValueObj::Bool(res)))
772                    }
773                    (lhs, rhs) => Ok(Predicate::general_ge(lhs, rhs)),
774                }
775            }
776            Predicate::GeneralLessEqual { lhs, rhs } => {
777                let lhs = self.deref_pred(*lhs)?;
778                let rhs = self.deref_pred(*rhs)?;
779                match (lhs, rhs) {
780                    (Predicate::Value(lhs), Predicate::Value(rhs)) => {
781                        let Some(ValueObj::Bool(res)) = lhs.try_le(rhs) else {
782                            return Err(TyCheckErrors::from(TyCheckError::dummy_infer_error(
783                                self.ctx.cfg.input.clone(),
784                                fn_name!(),
785                                line!(),
786                            )));
787                        };
788                        Ok(Predicate::Value(ValueObj::Bool(res)))
789                    }
790                    (lhs, rhs) => Ok(Predicate::general_le(lhs, rhs)),
791                }
792            }
793            Predicate::Call {
794                receiver,
795                name,
796                args,
797            } => {
798                let Ok(receiver) = self.deref_tp(receiver.clone()) else {
799                    return Ok(Predicate::call(receiver, name, args));
800                };
801                let mut new_args = vec![];
802                for arg in args.into_iter() {
803                    let Ok(arg) = self.deref_tp(arg) else {
804                        return Ok(Predicate::call(receiver, name, new_args));
805                    };
806                    new_args.push(arg);
807                }
808                let evaled = if let Some(name) = &name {
809                    self.ctx
810                        .eval_proj_call(receiver.clone(), name.clone(), new_args.clone(), &())
811                } else {
812                    self.ctx.eval_call(receiver.clone(), new_args.clone(), &())
813                };
814                match evaled {
815                    Ok(TyParam::Value(value)) => Ok(Predicate::Value(value)),
816                    _ => Ok(Predicate::call(receiver, name, new_args)),
817                }
818            }
819            Predicate::And(lhs, rhs) => {
820                let lhs = self.deref_pred(*lhs)?;
821                let rhs = self.deref_pred(*rhs)?;
822                Ok(Predicate::and(lhs, rhs))
823            }
824            Predicate::Or(preds) => {
825                let mut new_preds = Set::with_capacity(preds.len());
826                for pred in preds.into_iter() {
827                    new_preds.insert(self.deref_pred(pred)?);
828                }
829                Ok(Predicate::Or(new_preds))
830            }
831            Predicate::Not(pred) => {
832                let pred = self.deref_pred(*pred)?;
833                Ok(!pred)
834            }
835            Predicate::Attr { receiver, name } => {
836                let receiver = self.deref_tp(receiver)?;
837                Ok(Predicate::attr(receiver, name))
838            }
839            Predicate::Value(v) => self.deref_value(v).map(Predicate::Value),
840            Predicate::Const(_) | Predicate::Failure => Ok(pred),
841        }
842    }
843
844    fn deref_constraint(&mut self, constraint: Constraint) -> TyCheckResult<Constraint> {
845        match constraint {
846            Constraint::Sandwiched { sub, sup } => Ok(Constraint::new_sandwiched(
847                self.deref_tyvar(sub)?,
848                self.deref_tyvar(sup)?,
849            )),
850            Constraint::TypeOf(t) => Ok(Constraint::new_type_of(self.deref_tyvar(t)?)),
851            _ => unreachable_error!(TyCheckErrors, TyCheckError, self.ctx),
852        }
853    }
854
    /// Dereferences the free type variables in `t`, recursing into compound types.
    ///
    /// How a free variable is handled depends on `self.level` and `self.qnames`. For sandwiched
    /// variables it also depends on the current variance and `self.coerce` (via `validate_subsup`).
    /// Errors from the parameters of `Poly` and `Subr` types are collected together. Most other
    /// compound types return the first error.
    ///
    /// e.g.
    /// ```erg
    /// ?T(:> Nat, <: Int)[n] ==> Nat (self.level <= n)
    /// ?T(:> Nat, <: Sub(?U(:> {1}))) ==> Nat
    /// ?T(:> Nat, <: Sub(?U(:> {1}))) -> ?U ==> |U: Type, T <: Sub(U)| T -> U
    /// ?T(:> Nat, <: Sub(Str)) ==> Error!
    /// ?T(:> {1, "a"}, <: Eq(?T(:> {1, "a"}, ...)) ==> Error!
    /// ```
    pub(crate) fn deref_tyvar(&mut self, t: Type) -> TyCheckResult<Type> {
        match t {
            FreeVar(fv) if fv.is_linked() => {
                let t = fv.unwrap_linked();
                // (((((...))))) == Never
                if t.is_recursive() {
                    Ok(Type::Never)
                } else {
                    self.deref_tyvar(t)
                }
            }
            // A generalized variable named in `self.qnames` is kept as a type variable.
            FreeVar(mut fv)
                if fv.is_generalized() && self.qnames.contains(&fv.unbound_name().unwrap()) =>
            {
                fv.update_init();
                Ok(Type::FreeVar(fv))
            }
            // ?T(:> Nat, <: Int)[n] ==> Nat (self.level <= n)
            // ?T(:> Nat, <: Sub ?U(:> {1}))[n] ==> Nat
            // ?T(<: Int, :> Add(?T)) ==> Int
            // ?T(:> Nat, <: Sub(Str)) ==> Error!
            // ?T(:> {1, "a"}, <: Eq(?T(:> {1, "a"}, ...)) ==> Error!
            FreeVar(fv) if fv.constraint_is_sandwiched() => {
                // The hash of the variable is the key of the shared generalization cache.
                let fv_hash = get_hash(&fv);
                let (sub_t, super_t) = fv.get_subsup().unwrap();
                if self.level <= fv.level().unwrap() {
                    // we need to force linking to avoid infinite loop
                    // e.g. fv == ?T(<: Int, :> Add(?T))
                    //      fv == ?T(:> ?T.Output, <: Add(Int))
                    let list = UndoableLinkedList::new();
                    let fv_t = Type::FreeVar(fv.clone());
                    // Temporarily link the variable to a bound that does not contain it.
                    let dummy = match (sub_t.contains_type(&fv_t), super_t.contains_type(&fv_t)) {
                        // REVIEW: to prevent infinite recursion, but this may cause a nonsense error
                        (true, true) => {
                            fv.dummy_link();
                            true
                        }
                        (true, false) => {
                            fv_t.undoable_link(&super_t, &list);
                            false
                        }
                        (false, true | false) => {
                            fv_t.undoable_link(&sub_t, &list);
                            false
                        }
                    };
                    let res = self.validate_subsup(sub_t, super_t, fv_hash);
                    // Revert the temporary link: a dummy link via `undo`; otherwise dropping
                    // `list` (presumably undoes the recorded links — confirm in `UndoableLinkedList`).
                    if dummy {
                        fv.undo();
                    } else {
                        drop(list);
                    }
                    match res {
                        Ok(ty) => {
                            // TODO: T(:> Nat <: Int) -> T(:> Nat, <: Int) ==> Int -> Nat
                            // Type::FreeVar(fv).destructive_link(&ty);
                            Ok(ty)
                        }
                        Err(errs) => {
                            // On failure a non-generalized variable is permanently linked to `Never`.
                            if !fv.is_generalized() {
                                Type::FreeVar(fv).destructive_link(&Never);
                            }
                            Err(errs)
                        }
                    }
                } else {
                    // no dereference at this point
                    Ok(Type::FreeVar(fv))
                }
            }
            // NOTE(review): the variable is returned unchanged unless its type has
            // refinement values; both inner branches return `Type::FreeVar(fv)`.
            FreeVar(fv) if fv.get_type().is_some() => {
                let ty = fv.get_type().unwrap();
                if self.level <= fv.level().unwrap() {
                    // T: {Int, Str} => Int or Str
                    if let Some(tys) = ty.refinement_values() {
                        let mut union = Never;
                        for tp in tys {
                            if let Ok(ty) = self.ctx.convert_tp_into_type(tp.clone()) {
                                union = self.ctx.union(&union, &ty);
                            }
                        }
                        return Ok(union);
                    }
                    Ok(Type::FreeVar(fv))
                } else {
                    Ok(Type::FreeVar(fv))
                }
            }
            FreeVar(fv) if fv.is_unbound() => {
                if self.level == 0 {
                    // At level 0, a `TypeOf(t)` constraint whose `t` is not `Type` is an error.
                    match &*fv.crack_constraint() {
                        Constraint::TypeOf(t) if !t.is_type() => {
                            return Err(TyCheckErrors::from(TyCheckError::dummy_infer_error(
                                self.ctx.cfg.input.clone(),
                                fn_name!(),
                                line!(),
                            )));
                        }
                        _ => {}
                    }
                    Ok(Type::FreeVar(fv))
                } else {
                    // Otherwise dereference the types inside the constraint and store them back.
                    let new_constraint = fv.crack_constraint().clone();
                    let new_constraint = self.deref_constraint(new_constraint)?;
                    let ty = Type::FreeVar(fv);
                    ty.update_constraint(new_constraint, None, true);
                    Ok(ty)
                }
            }
            FreeVar(_) => Ok(t),
            Poly { name, mut params } => {
                let typ = poly(&name, params.clone());
                let ctx = self.ctx.get_nominal_type_ctx(&typ).ok_or_else(|| {
                    TyCheckError::type_not_found(
                        self.ctx.cfg.input.clone(),
                        line!() as usize,
                        self.loc.loc(),
                        self.ctx.caused_by(),
                        &typ,
                    )
                })?;
                let mut errs = TyCheckErrors::empty();
                // Each parameter is dereferenced under its declared variance; parameters beyond
                // the declared ones are treated as invariant.
                let variances = ctx.type_params_variance();
                for (param, variance) in params
                    .iter_mut()
                    .zip(variances.into_iter().chain(std::iter::repeat(Invariant)))
                {
                    self.push_variance(variance);
                    match self.deref_tp(mem::take(param)) {
                        Ok(t) => *param = t,
                        Err(es) => errs.extend(es),
                    }
                    self.pop_variance();
                }
                if errs.is_empty() {
                    Ok(Type::Poly { name, params })
                } else {
                    Err(errs)
                }
            }
            // Parameter types are dereferenced contravariantly and the return type covariantly.
            // Errors from every position are collected before returning.
            Subr(mut subr) => {
                let mut errs = TyCheckErrors::empty();
                for param in subr.non_default_params.iter_mut() {
                    self.push_variance(Contravariant);
                    match self.deref_tyvar(mem::take(param.typ_mut())) {
                        Ok(t) => *param.typ_mut() = t,
                        Err(es) => errs.extend(es),
                    }
                    self.pop_variance();
                }
                if let Some(var_params) = &mut subr.var_params {
                    self.push_variance(Contravariant);
                    match self.deref_tyvar(mem::take(var_params.typ_mut())) {
                        Ok(t) => *var_params.typ_mut() = t,
                        Err(es) => errs.extend(es),
                    }
                    self.pop_variance();
                }
                for d_param in subr.default_params.iter_mut() {
                    self.push_variance(Contravariant);
                    match self.deref_tyvar(mem::take(d_param.typ_mut())) {
                        Ok(t) => *d_param.typ_mut() = t,
                        Err(es) => errs.extend(es),
                    }
                    if let Some(default) = d_param.default_typ_mut() {
                        match self.deref_tyvar(mem::take(default)) {
                            Ok(t) => *default = t,
                            Err(es) => errs.extend(es),
                        }
                    }
                    self.pop_variance();
                }
                if let Some(kw_var_params) = &mut subr.kw_var_params {
                    self.push_variance(Contravariant);
                    match self.deref_tyvar(mem::take(kw_var_params.typ_mut())) {
                        Ok(t) => *kw_var_params.typ_mut() = t,
                        Err(es) => errs.extend(es),
                    }
                    if let Some(default) = kw_var_params.default_typ_mut() {
                        match self.deref_tyvar(mem::take(default)) {
                            Ok(t) => *default = t,
                            Err(es) => errs.extend(es),
                        }
                    }
                    self.pop_variance();
                }
                self.push_variance(Covariant);
                match self.deref_tyvar(mem::take(&mut subr.return_t)) {
                    Ok(t) => *subr.return_t = t,
                    Err(es) => errs.extend(es),
                }
                self.pop_variance();
                if errs.is_empty() {
                    Ok(Type::Subr(subr))
                } else {
                    Err(errs)
                }
            }
            // NOTE(review): unlike `Subr`, no variance is pushed for the parameter types here.
            Callable {
                mut param_ts,
                return_t,
            } => {
                for param_t in param_ts.iter_mut() {
                    *param_t = self.deref_tyvar(mem::take(param_t))?;
                }
                let return_t = self.deref_tyvar(*return_t)?;
                Ok(callable(param_ts, return_t))
            }
            Quantified(subr) => self.eliminate_needless_quant(*subr),
            Ref(t) => {
                let t = self.deref_tyvar(*t)?;
                Ok(ref_(t))
            }
            RefMut { before, after } => {
                let before = self.deref_tyvar(*before)?;
                let after = if let Some(after) = after {
                    Some(self.deref_tyvar(*after)?)
                } else {
                    None
                };
                Ok(ref_mut(before, after))
            }
            Record(mut rec) => {
                for (_, field) in rec.iter_mut() {
                    *field = self.deref_tyvar(mem::take(field))?;
                }
                Ok(Type::Record(rec))
            }
            NamedTuple(mut rec) => {
                for (_, t) in rec.iter_mut() {
                    *t = self.deref_tyvar(mem::take(t))?;
                }
                Ok(Type::NamedTuple(rec))
            }
            Refinement(refine) => {
                let t = self.deref_tyvar(*refine.t)?;
                let pred = self.deref_pred(*refine.pred)?;
                Ok(refinement(refine.var, t, pred))
            }
            // Intersections and unions are rebuilt from their dereferenced members.
            And(ands, _) => {
                let mut new_ands = vec![];
                for t in ands.into_iter() {
                    new_ands.push(self.deref_tyvar(t)?);
                }
                Ok(new_ands
                    .into_iter()
                    .fold(Type::Obj, |acc, t| self.ctx.intersection(&acc, &t)))
            }
            Or(ors) => {
                let mut new_ors = vec![];
                for t in ors.into_iter() {
                    new_ors.push(self.deref_tyvar(t)?);
                }
                Ok(new_ors
                    .into_iter()
                    .fold(Type::Never, |acc, t| self.ctx.union(&acc, &t)))
            }
            Not(ty) => {
                let ty = self.deref_tyvar(*ty)?;
                Ok(self.ctx.complement(&ty))
            }
            // Evaluate the projection as-is first, then with a dereferenced `lhs`.
            // If both fail, the result is `Failure` rather than an error.
            Proj { lhs, rhs } => {
                let proj = self
                    .ctx
                    .eval_proj(*lhs.clone(), rhs.clone(), self.level, self.loc)
                    .or_else(|_| {
                        let lhs = self.deref_tyvar(*lhs)?;
                        self.ctx.eval_proj(lhs, rhs, self.level, self.loc)
                    })
                    .unwrap_or(Failure);
                Ok(proj)
            }
            // A projection call that cannot be evaluated becomes `Failure`.
            ProjCall {
                lhs,
                attr_name,
                args,
            } => {
                let lhs = self.deref_tp(*lhs)?;
                let mut new_args = vec![];
                for arg in args.into_iter() {
                    new_args.push(self.deref_tp(arg)?);
                }
                let proj = self
                    .ctx
                    .eval_proj_call_t(lhs, attr_name, new_args, self.level, self.loc)
                    .unwrap_or(Failure);
                Ok(proj)
            }
            Structural(inner) => {
                let inner = self.deref_tyvar(*inner)?;
                Ok(inner.structuralize())
            }
            Guard(grd) => {
                let to = self.deref_tyvar(*grd.to)?;
                Ok(guard(grd.namespace, *grd.target, to))
            }
            Bounded { sub, sup } => {
                let sub = self.deref_tyvar(*sub)?;
                let sup = self.deref_tyvar(*sup)?;
                Ok(bounded(sub, sup))
            }
            mono_type_pattern!() => Ok(t),
        }
    }
1167
1168    fn validate_subsup(
1169        &mut self,
1170        sub_t: Type,
1171        super_t: Type,
1172        fv_hash: usize,
1173    ) -> TyCheckResult<Type> {
1174        // TODO: Subr, ...
1175        match (sub_t, super_t) {
1176            /*(sub_t @ Type::Refinement(_), super_t @ Type::Refinement(_)) => {
1177                self.validate_simple_subsup(sub_t, super_t)
1178            }
1179            (Type::Refinement(refine), super_t) => {
1180                self.validate_simple_subsup(*refine.t, super_t)
1181            }*/
1182            // See tests\should_err\subtyping.er:8~13
1183            (
1184                Poly {
1185                    name: ln,
1186                    params: lps,
1187                },
1188                Poly {
1189                    name: rn,
1190                    params: rps,
1191                },
1192            ) if ln == rn => {
1193                let typ = poly(ln, lps.clone());
1194                let ctx = self.ctx.get_nominal_type_ctx(&typ).ok_or_else(|| {
1195                    TyCheckError::type_not_found(
1196                        self.ctx.cfg.input.clone(),
1197                        line!() as usize,
1198                        self.loc.loc(),
1199                        self.ctx.caused_by(),
1200                        &typ,
1201                    )
1202                })?;
1203                let variances = ctx.type_params_variance();
1204                let mut tps = vec![];
1205                for ((lp, rp), variance) in lps
1206                    .into_iter()
1207                    .zip(rps.into_iter())
1208                    .zip(variances.into_iter().chain(std::iter::repeat(Invariant)))
1209                {
1210                    self.ctx
1211                        .sub_unify_tp(&lp, &rp, Some(variance), self.loc, false)?;
1212                    let param = if variance == Covariant { lp } else { rp };
1213                    tps.push(param);
1214                }
1215                Ok(poly(rn, tps))
1216            }
1217            (sub_t, super_t) => self.validate_simple_subsup(sub_t, super_t, fv_hash),
1218        }
1219    }
1220
1221    fn validate_simple_subsup(
1222        &mut self,
1223        sub_t: Type,
1224        super_t: Type,
1225        fv_hash: usize,
1226    ) -> TyCheckResult<Type> {
1227        let opt_res = self.ctx.shared().gen_cache.get(&fv_hash);
1228        if opt_res.is_none() && self.ctx.is_class(&sub_t) && self.ctx.is_trait(&super_t) {
1229            self.ctx
1230                .check_trait_impl(&sub_t, &super_t, self.qnames, self.loc)?;
1231        }
1232        let is_subtype = opt_res.map(|res| res.is_subtype).unwrap_or_else(|| {
1233            let is_subtype = self.ctx.subtype_of(&sub_t, &super_t); // PERF NOTE: bottleneck
1234            let res = GeneralizationResult {
1235                is_subtype,
1236                impl_trait: true,
1237            };
1238            self.ctx.shared().gen_cache.insert(fv_hash, res);
1239            is_subtype
1240        });
1241        let sub_t = self.deref_tyvar(sub_t)?;
1242        let super_t = self.deref_tyvar(super_t)?;
1243        if sub_t == super_t {
1244            Ok(sub_t)
1245        } else if is_subtype {
1246            match self.current_variance() {
1247                // ?T(<: Sup) --> Sup (Sup != Obj), because completion will not work if Never is selected.
1248                // ?T(:> Never, <: Obj) --> Never
1249                // ?T(:> Never, <: Int) --> Never..Int == Int
1250                Variance::Covariant if self.coerce => {
1251                    if sub_t != Never || super_t == Obj {
1252                        Ok(sub_t)
1253                    } else {
1254                        Ok(bounded(sub_t, super_t))
1255                    }
1256                }
1257                Variance::Contravariant if self.coerce => Ok(super_t),
1258                Variance::Covariant | Variance::Contravariant => Ok(bounded(sub_t, super_t)),
1259                Variance::Invariant => {
1260                    // need to check if sub_t == super_t (sub_t <: super_t is already checked)
1261                    if self.ctx.supertype_of(&sub_t, &super_t) {
1262                        Ok(sub_t)
1263                    } else {
1264                        Err(TyCheckErrors::from(TyCheckError::invariant_error(
1265                            self.ctx.cfg.input.clone(),
1266                            line!() as usize,
1267                            &sub_t,
1268                            &super_t,
1269                            self.loc.loc(),
1270                            self.ctx.caused_by(),
1271                        )))
1272                    }
1273                }
1274            }
1275        } else {
1276            Err(TyCheckErrors::from(TyCheckError::subtyping_error(
1277                self.ctx.cfg.input.clone(),
1278                line!() as usize,
1279                &sub_t,
1280                &super_t,
1281                self.loc.loc(),
1282                self.ctx.caused_by(),
1283            )))
1284        }
1285    }
1286
1287    /// here ?T can be eliminated
1288    /// ```erg
1289    /// ?T -> Int
1290    /// ?T, ?U -> K(?U)
1291    /// Int -> ?T
1292    /// ```
1293    /// here ?T cannot be eliminated
1294    /// ```erg
1295    /// ?T -> ?T
1296    /// ?T -> K(?T)
1297    /// ?T -> ?U(:> ?T)
1298    /// ```
1299    fn eliminate_needless_quant(&mut self, subr: Type) -> TyCheckResult<Type> {
1300        let Ok(mut subr) = SubrType::try_from(subr) else {
1301            unreachable!()
1302        };
1303        let essential_qnames = subr.essential_qnames();
1304        let mut _self = Dereferencer::new(
1305            self.ctx,
1306            self.current_variance(),
1307            self.coerce,
1308            &essential_qnames,
1309            self.loc,
1310        );
1311        for param in subr.non_default_params.iter_mut() {
1312            _self.push_variance(Contravariant);
1313            *param.typ_mut() = _self
1314                .deref_tyvar(mem::take(param.typ_mut()))
1315                .inspect_err(|_e| _self.pop_variance())?;
1316            _self.pop_variance();
1317        }
1318        if let Some(var_args) = &mut subr.var_params {
1319            _self.push_variance(Contravariant);
1320            *var_args.typ_mut() = _self
1321                .deref_tyvar(mem::take(var_args.typ_mut()))
1322                .inspect_err(|_e| _self.pop_variance())?;
1323            _self.pop_variance();
1324        }
1325        for d_param in subr.default_params.iter_mut() {
1326            _self.push_variance(Contravariant);
1327            *d_param.typ_mut() = _self
1328                .deref_tyvar(mem::take(d_param.typ_mut()))
1329                .inspect_err(|_e| {
1330                    _self.pop_variance();
1331                })?;
1332            if let Some(default) = d_param.default_typ_mut() {
1333                *default = _self
1334                    .deref_tyvar(mem::take(default))
1335                    .inspect_err(|_e| _self.pop_variance())?;
1336            }
1337            _self.pop_variance();
1338        }
1339        if let Some(kw_var_args) = &mut subr.kw_var_params {
1340            _self.push_variance(Contravariant);
1341            *kw_var_args.typ_mut() = _self
1342                .deref_tyvar(mem::take(kw_var_args.typ_mut()))
1343                .inspect_err(|_e| _self.pop_variance())?;
1344            if let Some(default) = kw_var_args.default_typ_mut() {
1345                *default = _self
1346                    .deref_tyvar(mem::take(default))
1347                    .inspect_err(|_e| _self.pop_variance())?;
1348            }
1349            _self.pop_variance();
1350        }
1351        _self.push_variance(Covariant);
1352        *subr.return_t = _self
1353            .deref_tyvar(mem::take(&mut subr.return_t))
1354            .inspect_err(|_e| _self.pop_variance())?;
1355        _self.pop_variance();
1356        let subr = Type::Subr(subr);
1357        if subr.has_qvar() {
1358            Ok(subr.quantify())
1359        } else {
1360            Ok(subr)
1361        }
1362    }
1363}
1364
1365impl Context {
    /// The level value `1`, named `TOP_LEVEL`.
    /// NOTE(review): presumably the level of module top-level definitions (`resolve` uses level 0);
    /// confirm at the use sites.
    pub const TOP_LEVEL: usize = 1;
1367
1368    /// Quantification occurs only once in function types.
1369    /// Therefore, this method is called only once at the top level, and `generalize_t_inner` is called inside.
1370    pub(crate) fn generalize_t(&self, free_type: Type) -> Type {
1371        let mut generalizer = Generalizer::new(self);
1372        let maybe_unbound_t = generalizer.generalize_t(free_type, false);
1373        if maybe_unbound_t.is_subr() && maybe_unbound_t.has_qvar() {
1374            maybe_unbound_t.quantify()
1375        } else {
1376            maybe_unbound_t
1377        }
1378    }
1379
1380    pub fn readable_type(&self, t: Type) -> Type {
1381        let qnames = set! {};
1382        let mut dereferencer = Dereferencer::new(self, Covariant, false, &qnames, &());
1383        dereferencer.set_level(0);
1384        dereferencer.deref_tyvar(t.clone()).unwrap_or(t)
1385    }
1386
1387    /// Fix type variables at their lower bound
1388    /// ```erg
1389    /// i: ?T(:> Int)
1390    /// assert i.Real == 1 # ?T is coerced
1391    /// i: (Int)
1392    /// ```
1393    ///
1394    /// ```erg
1395    /// ?T(:> ?U(:> Int)).coerce(): ?T == ?U == Int
1396    /// ```
1397    pub(crate) fn coerce(&self, t: Type, t_loc: &impl Locational) -> TyCheckResult<Type> {
1398        let qnames = set! {};
1399        let mut dereferencer = Dereferencer::new(self, Covariant, true, &qnames, t_loc);
1400        dereferencer.deref_tyvar(t)
1401    }
1402
1403    pub(crate) fn coerce_tp(&self, tp: TyParam, t_loc: &impl Locational) -> TyCheckResult<TyParam> {
1404        let qnames = set! {};
1405        let mut dereferencer = Dereferencer::new(self, Covariant, true, &qnames, t_loc);
1406        dereferencer.deref_tp(tp)
1407    }
1408
1409    pub(crate) fn trait_impl_exists(&self, class: &Type, trait_: &Type) -> bool {
1410        // `Never` implements any trait
1411        if self.subtype_of(class, &Type::Never) {
1412            return true;
1413        }
1414        if class.is_monomorphic() {
1415            self.mono_class_trait_impl_exist(class, trait_)
1416        } else {
1417            self.poly_class_trait_impl_exists(class, trait_)
1418        }
1419    }
1420
1421    fn mono_class_trait_impl_exist(&self, class: &Type, trait_: &Type) -> bool {
1422        let mut super_exists = false;
1423        for imp in self.get_trait_impls(trait_).into_iter() {
1424            if self.supertype_of(&imp.sub_type, class) && self.supertype_of(&imp.sup_trait, trait_)
1425            {
1426                super_exists = true;
1427                break;
1428            }
1429        }
1430        super_exists
1431    }
1432
1433    /// Check if a trait implementation exists for a polymorphic class.
1434    /// This is needed because the trait implementation spec can contain projection types.
1435    /// e.g. `Tuple(Ts) <: Container(Ts.union())`
1436    fn poly_class_trait_impl_exists(&self, class: &Type, trait_: &Type) -> bool {
1437        for imp in self.get_trait_impls(trait_).into_iter() {
1438            let _sub_subs = Substituter::substitute_typarams(self, &imp.sub_type, class).ok();
1439            let _sup_subs = Substituter::substitute_typarams(self, &imp.sup_trait, trait_).ok();
1440            if self.supertype_of(&imp.sub_type, class) && self.supertype_of(&imp.sup_trait, trait_)
1441            {
1442                return true;
1443            }
1444        }
1445        false
1446    }
1447
1448    fn check_trait_impl(
1449        &self,
1450        class: &Type,
1451        trait_: &Type,
1452        qnames: &Set<Str>,
1453        loc: &impl Locational,
1454    ) -> TyCheckResult<()> {
1455        if !self.trait_impl_exists(class, trait_) {
1456            let mut dereferencer = Dereferencer::new(self, Variance::Covariant, false, qnames, loc);
1457            let class = if DEBUG_MODE {
1458                class.clone()
1459            } else {
1460                dereferencer.deref_tyvar(class.clone())?
1461            };
1462            let trait_ = if DEBUG_MODE {
1463                trait_.clone()
1464            } else {
1465                dereferencer.deref_tyvar(trait_.clone())?
1466            };
1467            Err(TyCheckErrors::from(TyCheckError::no_trait_impl_error(
1468                self.cfg.input.clone(),
1469                line!() as usize,
1470                &class,
1471                &trait_,
1472                loc.loc(),
1473                self.caused_by(),
1474                self.get_simple_type_mismatch_hint(&trait_, &class),
1475            )))
1476        } else {
1477            Ok(())
1478        }
1479    }
1480
1481    /// Check if all types are resolvable (if traits, check if an implementation exists)
1482    /// And replace them if resolvable
1483    pub(crate) fn resolve(
1484        &mut self,
1485        mut hir: hir::HIR,
1486    ) -> Result<hir::HIR, (hir::HIR, TyCheckErrors)> {
1487        self.level = 0;
1488        let mut errs = TyCheckErrors::empty();
1489        for chunk in hir.module.iter_mut() {
1490            if let Err(es) = self.resolve_expr_t(chunk, &set! {}) {
1491                errs.extend(es);
1492            }
1493        }
1494        self.resolve_ctx_vars();
1495        if errs.is_empty() {
1496            Ok(hir)
1497        } else {
1498            Err((hir, errs))
1499        }
1500    }
1501
1502    fn resolve_ctx_vars(&mut self) {
1503        let mut locals = mem::take(&mut self.locals);
1504        let mut params = mem::take(&mut self.params);
1505        let mut methods_list = mem::take(&mut self.methods_list);
1506        for (name, vi) in locals.iter_mut() {
1507            let qnames = set! {};
1508            let mut derferencer = Dereferencer::simple(self, &qnames, name);
1509            if let Ok(t) = derferencer.deref_tyvar(mem::take(&mut vi.t)) {
1510                vi.t = t;
1511            }
1512        }
1513        for (name, vi) in params.iter_mut() {
1514            let qnames = set! {};
1515            let mut derferencer = Dereferencer::simple(self, &qnames, name);
1516            if let Ok(t) = derferencer.deref_tyvar(mem::take(&mut vi.t)) {
1517                vi.t = t;
1518            }
1519        }
1520        for methods in methods_list.iter_mut() {
1521            methods.resolve_ctx_vars();
1522        }
1523        self.locals = locals;
1524        self.params = params;
1525        self.methods_list = methods_list;
1526    }
1527
1528    fn resolve_params_t(&self, params: &mut hir::Params, qnames: &Set<Str>) -> TyCheckResult<()> {
1529        for param in params.non_defaults.iter_mut() {
1530            // generalization should work properly for the subroutine type, but may not work for the parameters' own types
1531            // HACK: so generalize them manually
1532            param.vi.t.generalize();
1533            let t = mem::take(&mut param.vi.t);
1534            let mut dereferencer = Dereferencer::new(self, Contravariant, false, qnames, param);
1535            param.vi.t = dereferencer.deref_tyvar(t)?;
1536        }
1537        if let Some(var_params) = &mut params.var_params {
1538            var_params.vi.t.generalize();
1539            let t = mem::take(&mut var_params.vi.t);
1540            let mut dereferencer =
1541                Dereferencer::new(self, Contravariant, false, qnames, var_params.as_ref());
1542            var_params.vi.t = dereferencer.deref_tyvar(t)?;
1543        }
1544        for param in params.defaults.iter_mut() {
1545            param.sig.vi.t.generalize();
1546            let t = mem::take(&mut param.sig.vi.t);
1547            let mut dereferencer = Dereferencer::new(self, Contravariant, false, qnames, param);
1548            param.sig.vi.t = dereferencer.deref_tyvar(t)?;
1549            self.resolve_expr_t(&mut param.default_val, qnames)?;
1550        }
1551        if let Some(kw_var) = &mut params.kw_var_params {
1552            kw_var.vi.t.generalize();
1553            let t = mem::take(&mut kw_var.vi.t);
1554            let mut dereferencer =
1555                Dereferencer::new(self, Contravariant, false, qnames, kw_var.as_ref());
1556            kw_var.vi.t = dereferencer.deref_tyvar(t)?;
1557        }
1558        for guard in params.guards.iter_mut() {
1559            match guard {
1560                GuardClause::Bind(def) => {
1561                    self.resolve_def_t(def, qnames)?;
1562                }
1563                GuardClause::Condition(cond) => {
1564                    self.resolve_expr_t(cond, qnames)?;
1565                }
1566            }
1567        }
1568        Ok(())
1569    }
1570
    /// Resolves (dereferences) the types recorded in `expr` and its sub-expressions,
    /// replacing each one in place.
    ///
    /// Resolution should start at a deeper level.
    /// For example, if it is a lambda function, the body should be checked before the signature.
    /// However, a binop call error, etc., is more important than binop operands.
    ///
    /// Every arm except `Lambda` returns at the first error (`?`). The `Lambda` arm collects
    /// the errors from its body, its parameters, and its own type, and returns them together.
    /// Each type slot is emptied with `mem::take` before being dereferenced. If dereferencing
    /// fails, the slot is therefore left holding `Type::default()`.
    ///
    /// # Errors
    /// Returns dereferencing errors. It also returns a feature error for the list/dict forms
    /// not handled below (comprehensions), and an internal "unreachable" error for `Import`.
    pub(crate) fn resolve_expr_t(
        &self,
        expr: &mut hir::Expr,
        qnames: &Set<Str>,
    ) -> TyCheckResult<()> {
        match expr {
            hir::Expr::Literal(_) => Ok(()),
            hir::Expr::Accessor(acc) => {
                // the accessor's own type first, then (for attributes) the receiver object
                let t = mem::take(acc.ref_mut_t().unwrap());
                let mut dereferencer = Dereferencer::simple(self, qnames, acc);
                *acc.ref_mut_t().unwrap() = dereferencer.deref_tyvar(t)?;
                if let hir::Accessor::Attr(attr) = acc {
                    self.resolve_expr_t(&mut attr.obj, qnames)?;
                }
                Ok(())
            }
            hir::Expr::List(list) => match list {
                hir::List::Normal(lis) => {
                    // elements (deeper) before the list's own type
                    for elem in lis.elems.pos_args.iter_mut() {
                        self.resolve_expr_t(&mut elem.expr, qnames)?;
                    }
                    let t = mem::take(&mut lis.t);
                    let mut dereferencer = Dereferencer::simple(self, qnames, lis);
                    lis.t = dereferencer.deref_tyvar(t)?;
                    Ok(())
                }
                hir::List::WithLength(lis) => {
                    self.resolve_expr_t(&mut lis.elem, qnames)?;
                    if let Some(len) = &mut lis.len {
                        self.resolve_expr_t(len, qnames)?;
                    }
                    let t = mem::take(&mut lis.t);
                    let mut dereferencer = Dereferencer::simple(self, qnames, lis);
                    lis.t = dereferencer.deref_tyvar(t)?;
                    Ok(())
                }
                // any other list form (comprehension) is not supported yet
                other => feature_error!(
                    TyCheckErrors,
                    TyCheckError,
                    self,
                    other.loc(),
                    "resolve types of array comprehension"
                ),
            },
            hir::Expr::Tuple(tuple) => match tuple {
                hir::Tuple::Normal(tup) => {
                    for elem in tup.elems.pos_args.iter_mut() {
                        self.resolve_expr_t(&mut elem.expr, qnames)?;
                    }
                    let t = mem::take(&mut tup.t);
                    let mut dereferencer = Dereferencer::simple(self, qnames, tup);
                    tup.t = dereferencer.deref_tyvar(t)?;
                    Ok(())
                }
            },
            hir::Expr::Set(set) => match set {
                hir::Set::Normal(st) => {
                    for elem in st.elems.pos_args.iter_mut() {
                        self.resolve_expr_t(&mut elem.expr, qnames)?;
                    }
                    let t = mem::take(&mut st.t);
                    let mut dereferencer = Dereferencer::simple(self, qnames, st);
                    st.t = dereferencer.deref_tyvar(t)?;
                    Ok(())
                }
                hir::Set::WithLength(st) => {
                    self.resolve_expr_t(&mut st.elem, qnames)?;
                    self.resolve_expr_t(&mut st.len, qnames)?;
                    let t = mem::take(&mut st.t);
                    let mut dereferencer = Dereferencer::simple(self, qnames, st);
                    st.t = dereferencer.deref_tyvar(t)?;
                    Ok(())
                }
            },
            hir::Expr::Dict(dict) => match dict {
                hir::Dict::Normal(dic) => {
                    for kv in dic.kvs.iter_mut() {
                        self.resolve_expr_t(&mut kv.key, qnames)?;
                        self.resolve_expr_t(&mut kv.value, qnames)?;
                    }
                    let t = mem::take(&mut dic.t);
                    let mut dereferencer = Dereferencer::simple(self, qnames, dic);
                    dic.t = dereferencer.deref_tyvar(t)?;
                    Ok(())
                }
                // any other dict form (comprehension) is not supported yet
                other => feature_error!(
                    TyCheckErrors,
                    TyCheckError,
                    self,
                    other.loc(),
                    "resolve types of dict comprehension"
                ),
            },
            hir::Expr::Record(record) => {
                // each attribute: its signature type, then its body; the record type last
                for attr in record.attrs.iter_mut() {
                    let t = mem::take(attr.sig.ref_mut_t().unwrap());
                    let mut dereferencer = Dereferencer::simple(self, qnames, &attr.sig);
                    let t = dereferencer.deref_tyvar(t)?;
                    *attr.sig.ref_mut_t().unwrap() = t;
                    for chunk in attr.body.block.iter_mut() {
                        self.resolve_expr_t(chunk, qnames)?;
                    }
                }
                let t = mem::take(&mut record.t);
                let mut dereferencer = Dereferencer::simple(self, qnames, record);
                record.t = dereferencer.deref_tyvar(t)?;
                Ok(())
            }
            hir::Expr::BinOp(binop) => {
                // the operator signature goes first so its error is reported before the operands'
                let t = mem::take(binop.signature_mut_t().unwrap());
                let mut dereferencer = Dereferencer::simple(self, qnames, binop);
                *binop.signature_mut_t().unwrap() = dereferencer.deref_tyvar(t)?;
                self.resolve_expr_t(&mut binop.lhs, qnames)?;
                self.resolve_expr_t(&mut binop.rhs, qnames)?;
                Ok(())
            }
            hir::Expr::UnaryOp(unaryop) => {
                // same ordering as `BinOp`: signature before operand
                let t = mem::take(unaryop.signature_mut_t().unwrap());
                let mut dereferencer = Dereferencer::simple(self, qnames, unaryop);
                *unaryop.signature_mut_t().unwrap() = dereferencer.deref_tyvar(t)?;
                self.resolve_expr_t(&mut unaryop.expr, qnames)?;
                Ok(())
            }
            hir::Expr::Call(call) => {
                // arguments and callee first, then the call's signature (if it has one)
                for arg in call.args.pos_args.iter_mut() {
                    self.resolve_expr_t(&mut arg.expr, qnames)?;
                }
                if let Some(var_args) = &mut call.args.var_args {
                    self.resolve_expr_t(&mut var_args.expr, qnames)?;
                }
                for arg in call.args.kw_args.iter_mut() {
                    self.resolve_expr_t(&mut arg.expr, qnames)?;
                }
                if let Some(kw_var) = &mut call.args.kw_var {
                    self.resolve_expr_t(&mut kw_var.expr, qnames)?;
                }
                self.resolve_expr_t(&mut call.obj, qnames)?;
                if let Some(t) = call.signature_mut_t() {
                    let t = mem::take(t);
                    let mut dereferencer = Dereferencer::simple(self, qnames, call);
                    *call.signature_mut_t().unwrap() = dereferencer.deref_tyvar(t)?;
                }
                Ok(())
            }
            hir::Expr::Def(def) => self.resolve_def_t(def, qnames),
            hir::Expr::Lambda(lambda) => {
                // a quantified lambda type supplies its own quantified names
                let qnames = if let Type::Quantified(quant) = lambda.ref_t() {
                    let Ok(subr) = <&SubrType>::try_from(quant.as_ref()) else {
                        unreachable!()
                    };
                    subr.essential_qnames()
                } else {
                    qnames.clone()
                };
                // body, then params, then the lambda's own type; errors are accumulated
                // rather than short-circuited
                let mut errs = TyCheckErrors::empty();
                for chunk in lambda.body.iter_mut() {
                    if let Err(es) = self.resolve_expr_t(chunk, &qnames) {
                        errs.extend(es);
                    }
                }
                if let Err(es) = self.resolve_params_t(&mut lambda.params, &qnames) {
                    errs.extend(es);
                }
                let t = mem::take(&mut lambda.t);
                let mut dereferencer = Dereferencer::simple(self, &qnames, lambda);
                match dereferencer.deref_tyvar(t) {
                    Ok(t) => lambda.t = t,
                    Err(es) => errs.extend(es),
                }
                if !errs.is_empty() {
                    Err(errs)
                } else {
                    Ok(())
                }
            }
            hir::Expr::ClassDef(class_def) => {
                for def in class_def.all_methods_mut() {
                    self.resolve_expr_t(def, qnames)?;
                }
                Ok(())
            }
            hir::Expr::PatchDef(patch_def) => {
                for def in patch_def.methods.iter_mut() {
                    self.resolve_expr_t(def, qnames)?;
                }
                Ok(())
            }
            hir::Expr::ReDef(redef) => {
                // REVIEW: redef.attr is not dereferenced
                for chunk in redef.block.iter_mut() {
                    self.resolve_expr_t(chunk, qnames)?;
                }
                Ok(())
            }
            // only the ascribed expression is resolved here
            hir::Expr::TypeAsc(tasc) => self.resolve_expr_t(&mut tasc.expr, qnames),
            hir::Expr::Code(chunks) | hir::Expr::Compound(chunks) => {
                for chunk in chunks.iter_mut() {
                    self.resolve_expr_t(chunk, qnames)?;
                }
                Ok(())
            }
            hir::Expr::Dummy(chunks) => {
                for chunk in chunks.iter_mut() {
                    self.resolve_expr_t(chunk, qnames)?;
                }
                Ok(())
            }
            // `Import` is treated as unreachable in this pass and reported as an internal error
            hir::Expr::Import(_) => unreachable_error!(TyCheckErrors, TyCheckError, self),
        }
    }
1784
1785    fn resolve_def_t(&self, def: &mut hir::Def, qnames: &Set<Str>) -> TyCheckResult<()> {
1786        let qnames = if let Type::Quantified(quant) = def.sig.ref_t() {
1787            // double quantification is not allowed
1788            let Ok(subr) = <&SubrType>::try_from(quant.as_ref()) else {
1789                unreachable!()
1790            };
1791            subr.essential_qnames()
1792        } else {
1793            qnames.clone()
1794        };
1795        let t = mem::take(def.sig.ref_mut_t().unwrap());
1796        let mut dereferencer = Dereferencer::simple(self, &qnames, &def.sig);
1797        *def.sig.ref_mut_t().unwrap() = dereferencer.deref_tyvar(t)?;
1798        if let Some(params) = def.sig.params_mut() {
1799            self.resolve_params_t(params, &qnames)?;
1800        }
1801        for chunk in def.body.block.iter_mut() {
1802            self.resolve_expr_t(chunk, &qnames)?;
1803        }
1804        Ok(())
1805    }
1806
1807    /// ```erg
1808    /// squash_tyvar(?1 or ?2) == ?1(== ?2)
1809    /// squash_tyvar(?T or ?U) == ?T or ?U
1810    /// squash_tyvar(?T or NoneType) == ?T or Nonetype
1811    /// ```
1812    pub(crate) fn squash_tyvar(&self, typ: Type) -> Type {
1813        match typ {
1814            Or(tys) => {
1815                let new_tys = tys
1816                    .into_iter()
1817                    .map(|t| self.squash_tyvar(t))
1818                    .collect::<Vec<_>>();
1819                let mut union = Never;
1820                // REVIEW:
1821                if new_tys.iter().all(|t| t.is_unnamed_unbound_var()) {
1822                    for ty in new_tys.iter() {
1823                        if union == Never {
1824                            union = ty.clone();
1825                            continue;
1826                        }
1827                        match (self.subtype_of(&union, ty), self.subtype_of(&union, ty)) {
1828                            (true, true) | (true, false) => {
1829                                let _ = self.sub_unify(&union, ty, &(), None);
1830                            }
1831                            (false, true) => {
1832                                let _ = self.sub_unify(ty, &union, &(), None);
1833                            }
1834                            _ => {}
1835                        }
1836                    }
1837                }
1838                new_tys
1839                    .into_iter()
1840                    .fold(Never, |acc, t| self.union(&acc, &t))
1841            }
1842            FreeVar(ref fv) if fv.constraint_is_sandwiched() => {
1843                let (sub_t, super_t) = fv.get_subsup().unwrap();
1844                let sub_t = self.squash_tyvar(sub_t);
1845                let super_t = self.squash_tyvar(super_t);
1846                typ.update_tyvar(sub_t, super_t, None, false);
1847                typ
1848            }
1849            other => other,
1850        }
1851    }
1852}