1use std::mem;
2
3use erg_common::consts::DEBUG_MODE;
4use erg_common::set::Set;
5use erg_common::traits::{Locational, Stream};
6use erg_common::Str;
7use erg_common::{dict, fn_name, get_hash, set};
8#[allow(unused_imports)]
9use erg_common::{fmt_vec, log};
10
11use crate::hir::GuardClause;
12use crate::module::GeneralizationResult;
13use crate::ty::constructors::*;
14use crate::ty::free::{CanbeFree, Constraint, Free, HasLevel};
15use crate::ty::typaram::{TyParam, TyParamLambda};
16use crate::ty::value::ValueObj;
17use crate::ty::{HasType, Predicate, SharedFrees, SubrType, Type};
18
19use crate::context::{Context, Variance};
20use crate::error::{TyCheckError, TyCheckErrors, TyCheckResult};
21use crate::{feature_error, hir, mono_type_pattern, mono_value_pattern, unreachable_error};
22
23use Type::*;
24use Variance::*;
25
26use super::eval::{Substituter, UndoableLinkedList};
27
28pub struct Generalizer<'c> {
29 ctx: &'c Context,
30 variance: Variance,
31 qnames: Set<Str>,
32 structural_inner: bool,
33}
34
35impl<'c> Generalizer<'c> {
36 pub fn new(ctx: &'c Context) -> Self {
37 Self {
38 ctx,
39 variance: Covariant,
40 qnames: set! {},
41 structural_inner: false,
42 }
43 }
44
45 fn generalize_tp(&mut self, free: TyParam, uninit: bool) -> TyParam {
46 match free {
47 TyParam::Type(t) => TyParam::t(self.generalize_t(*t, uninit)),
48 TyParam::Value(val) => {
49 TyParam::Value(val.map_t(&mut |t| self.generalize_t(t, uninit)).map_tp(
50 &mut |tp| self.generalize_tp(tp, uninit),
51 &SharedFrees::new(),
52 ))
53 }
54 TyParam::FreeVar(fv) if fv.is_generalized() => TyParam::FreeVar(fv),
55 TyParam::FreeVar(fv) if fv.is_linked() => {
56 let tp = fv.crack().clone();
57 self.generalize_tp(tp, uninit)
58 }
59 TyParam::FreeVar(fv) if fv.level() > Some(self.ctx.level) => {
61 let constr = self.generalize_constraint(&fv);
62 fv.update_constraint(constr, true);
63 fv.generalize();
64 TyParam::FreeVar(fv)
65 }
66 TyParam::List(tps) => TyParam::List(
67 tps.into_iter()
68 .map(|tp| self.generalize_tp(tp, uninit))
69 .collect(),
70 ),
71 TyParam::UnsizedList(tp) => {
72 TyParam::UnsizedList(Box::new(self.generalize_tp(*tp, uninit)))
73 }
74 TyParam::Tuple(tps) => TyParam::Tuple(
75 tps.into_iter()
76 .map(|tp| self.generalize_tp(tp, uninit))
77 .collect(),
78 ),
79 TyParam::Set(set) => TyParam::Set(
80 set.into_iter()
81 .map(|tp| self.generalize_tp(tp, uninit))
82 .collect(),
83 ),
84 TyParam::Dict(tps) => TyParam::Dict(
85 tps.into_iter()
86 .map(|(k, v)| (self.generalize_tp(k, uninit), self.generalize_tp(v, uninit)))
87 .collect(),
88 ),
89 TyParam::Record(rec) => TyParam::Record(
90 rec.into_iter()
91 .map(|(field, tp)| (field, self.generalize_tp(tp, uninit)))
92 .collect(),
93 ),
94 TyParam::DataClass { name, fields } => {
95 let fields = fields
96 .into_iter()
97 .map(|(field, tp)| (field, self.generalize_tp(tp, uninit)))
98 .collect();
99 TyParam::DataClass { name, fields }
100 }
101 TyParam::Lambda(lambda) => {
102 let nd_params = lambda
103 .nd_params
104 .into_iter()
105 .map(|pt| pt.map_type(&mut |t| self.generalize_t(t, uninit)))
106 .collect::<Vec<_>>();
107 let var_params = lambda
108 .var_params
109 .map(|pt| pt.map_type(&mut |t| self.generalize_t(t, uninit)));
110 let d_params = lambda
111 .d_params
112 .into_iter()
113 .map(|pt| pt.map_type(&mut |t| self.generalize_t(t, uninit)))
114 .collect::<Vec<_>>();
115 let kw_var_params = lambda
116 .kw_var_params
117 .map(|pt| pt.map_type(&mut |t| self.generalize_t(t, uninit)));
118 let body = lambda
119 .body
120 .into_iter()
121 .map(|tp| self.generalize_tp(tp, uninit))
122 .collect();
123 TyParam::Lambda(TyParamLambda::new(
124 lambda.const_,
125 nd_params,
126 var_params,
127 d_params,
128 kw_var_params,
129 body,
130 ))
131 }
132 TyParam::FreeVar(_) => free,
133 TyParam::Proj { obj, attr } => {
134 let obj = self.generalize_tp(*obj, uninit);
135 TyParam::proj(obj, attr)
136 }
137 TyParam::ProjCall { obj, attr, args } => {
138 let obj = self.generalize_tp(*obj, uninit);
139 let args = args
140 .into_iter()
141 .map(|tp| self.generalize_tp(tp, uninit))
142 .collect();
143 TyParam::proj_call(obj, attr, args)
144 }
145 TyParam::Erased(t) => TyParam::erased(self.generalize_t(*t, uninit)),
146 TyParam::App { name, args } => {
147 let args = args
148 .into_iter()
149 .map(|tp| self.generalize_tp(tp, uninit))
150 .collect();
151 TyParam::App { name, args }
152 }
153 TyParam::BinOp { op, lhs, rhs } => {
154 let lhs = self.generalize_tp(*lhs, uninit);
155 let rhs = self.generalize_tp(*rhs, uninit);
156 TyParam::bin(op, lhs, rhs)
157 }
158 TyParam::UnaryOp { op, val } => {
159 let val = self.generalize_tp(*val, uninit);
160 TyParam::unary(op, val)
161 }
162 TyParam::Mono(_) | TyParam::Failure => free,
163 }
164 }
165
166 fn generalize_t(&mut self, free_type: Type, uninit: bool) -> Type {
174 match free_type {
175 FreeVar(fv) if fv.is_linked() => self.generalize_t(fv.unwrap_linked(), uninit),
176 FreeVar(fv) if fv.is_generalized() => Type::FreeVar(fv),
177 FreeVar(fv) if fv.level().unwrap() > self.ctx.level => {
179 fv.generalize();
180 if uninit {
181 return Type::FreeVar(fv);
182 }
183 if let Some((sub, sup)) = fv.get_subsup() {
184 if sub == sup {
186 let t = self.generalize_t(sub, uninit);
187 let res = FreeVar(fv);
188 res.set_level(1);
189 res.destructive_link(&t);
190 res.generalize();
191 res
192 } else if sup != Obj
193 && self.variance == Contravariant
194 && !self.qnames.contains(&fv.unbound_name().unwrap())
195 {
196 self.generalize_t(sup, uninit)
198 } else if sub != Never
199 && self.variance == Covariant
200 && !self.qnames.contains(&fv.unbound_name().unwrap())
201 {
202 self.generalize_t(sub, uninit)
204 } else {
205 let constr = self.generalize_constraint(&fv);
206 let ty = Type::FreeVar(fv);
207 ty.update_constraint(constr, None, true);
208 ty
209 }
210 } else {
211 let constr = self.generalize_constraint(&fv);
213 let ty = Type::FreeVar(fv);
214 ty.update_constraint(constr, None, true);
215 ty
216 }
217 }
218 FreeVar(_) => free_type,
219 Subr(mut subr) => {
220 self.variance = Contravariant;
221 let qnames = subr.essential_qnames();
222 self.qnames.extend(qnames.clone());
223 subr.non_default_params.iter_mut().for_each(|nd_param| {
224 *nd_param.typ_mut() = self.generalize_t(mem::take(nd_param.typ_mut()), uninit);
225 });
226 if let Some(var_params) = &mut subr.var_params {
227 *var_params.typ_mut() =
228 self.generalize_t(mem::take(var_params.typ_mut()), uninit);
229 }
230 subr.default_params.iter_mut().for_each(|d_param| {
231 *d_param.typ_mut() = self.generalize_t(mem::take(d_param.typ_mut()), uninit);
232 if let Some(default) = d_param.default_typ_mut() {
233 *default = self.generalize_t(mem::take(default), uninit);
234 }
235 });
236 if let Some(kw_var_params) = &mut subr.kw_var_params {
237 *kw_var_params.typ_mut() =
238 self.generalize_t(mem::take(kw_var_params.typ_mut()), uninit);
239 if let Some(default) = kw_var_params.default_typ_mut() {
240 *default = self.generalize_t(mem::take(default), uninit);
241 }
242 }
243 self.variance = Covariant;
244 let return_t = self.generalize_t(*subr.return_t, uninit);
245 self.qnames = self.qnames.difference(&qnames);
246 subr_t(
247 subr.kind,
248 subr.non_default_params,
249 subr.var_params.map(|x| *x),
250 subr.default_params,
251 subr.kw_var_params.map(|x| *x),
252 return_t,
253 )
254 }
255 Quantified(quant) => {
256 log!(err "{quant}");
257 quant.quantify()
258 }
259 Record(rec) => {
260 let fields = rec
261 .into_iter()
262 .map(|(name, t)| (name, self.generalize_t(t, uninit)))
263 .collect();
264 Type::Record(fields)
265 }
266 NamedTuple(rec) => {
267 let fields = rec
268 .into_iter()
269 .map(|(name, t)| (name, self.generalize_t(t, uninit)))
270 .collect();
271 Type::NamedTuple(fields)
272 }
273 Callable { param_ts, return_t } => {
274 let param_ts = param_ts
275 .into_iter()
276 .map(|t| self.generalize_t(t, uninit))
277 .collect();
278 let return_t = self.generalize_t(*return_t, uninit);
279 callable(param_ts, return_t)
280 }
281 Ref(t) => ref_(self.generalize_t(*t, uninit)),
282 RefMut { before, after } => {
283 let after = after.map(|aft| self.generalize_t(*aft, uninit));
284 ref_mut(self.generalize_t(*before, uninit), after)
285 }
286 Refinement(refine) => {
287 let t = self.generalize_t(*refine.t, uninit);
288 let pred = self.generalize_pred(*refine.pred, uninit);
289 refinement(refine.var, t, pred)
290 }
291 Poly { name, mut params } => {
292 let params = params
293 .iter_mut()
294 .map(|p| self.generalize_tp(mem::take(p), uninit))
295 .collect::<Vec<_>>();
296 poly(name, params)
297 }
298 Proj { lhs, rhs } => {
299 let lhs = self.generalize_t(*lhs, uninit);
300 proj(lhs, rhs)
301 }
302 ProjCall {
303 lhs,
304 attr_name,
305 mut args,
306 } => {
307 let lhs = self.generalize_tp(*lhs, uninit);
308 for arg in args.iter_mut() {
309 *arg = self.generalize_tp(mem::take(arg), uninit);
310 }
311 proj_call(lhs, attr_name, args)
312 }
313 And(ands, idx) => {
314 let ands = ands
316 .into_iter()
317 .map(|t| self.generalize_t(t, uninit))
318 .collect::<Vec<_>>();
319 let isec = ands
320 .into_iter()
321 .fold(Obj, |acc, t| self.ctx.intersection(&acc, &t));
322 if let Some(idx) = idx {
323 isec.with_default_intersec_index(idx)
324 } else {
325 isec
326 }
327 }
328 Or(ors) => {
329 let ors = ors
331 .into_iter()
332 .map(|t| self.generalize_t(t, uninit))
333 .collect::<Set<_>>();
334 ors.into_iter()
335 .fold(Never, |acc, t| self.ctx.union(&acc, &t))
336 }
337 Not(l) => not(self.generalize_t(*l, uninit)),
338 Structural(ty) => {
339 if self.structural_inner {
340 ty.structuralize()
341 } else {
342 if ty.is_recursive() {
343 self.structural_inner = true;
344 }
345 let res = self.generalize_t(*ty, uninit).structuralize();
346 self.structural_inner = false;
347 res
348 }
349 }
350 Guard(grd) => {
351 let to = self.generalize_t(*grd.to, uninit);
352 guard(grd.namespace, *grd.target, to)
353 }
354 Bounded { sub, sup } => {
355 let sub = self.generalize_t(*sub, uninit);
356 let sup = self.generalize_t(*sup, uninit);
357 bounded(sub, sup)
358 }
359 Int | Nat | Float | Ratio | Complex | Bool | Str | Never | Obj | Type | Error
360 | Code | Frame | NoneType | Inf | NegInf | NotImplementedType | Ellipsis
361 | ClassType | TraitType | Patch | Failure | Uninited | Mono(_) => free_type,
362 }
363 }
364
365 fn generalize_constraint<T: CanbeFree + Send + Clone>(&mut self, fv: &Free<T>) -> Constraint {
366 if let Some((sub, sup)) = fv.get_subsup() {
367 let sub = self.generalize_t(sub, true);
368 let sup = self.generalize_t(sup, true);
369 Constraint::new_sandwiched(sub, sup)
370 } else if let Some(ty) = fv.get_type() {
371 let t = self.generalize_t(ty, true);
372 Constraint::new_type_of(t)
373 } else {
374 unreachable!()
375 }
376 }
377
378 fn generalize_pred(&mut self, pred: Predicate, uninit: bool) -> Predicate {
379 match pred {
380 Predicate::Const(_) | Predicate::Failure => pred,
381 Predicate::Value(val) => {
382 Predicate::Value(val.map_t(&mut |t| self.generalize_t(t, uninit)))
383 }
384 Predicate::Call {
385 receiver,
386 name,
387 args,
388 } => {
389 let receiver = self.generalize_tp(receiver, uninit);
390 let mut new_args = vec![];
391 for arg in args.into_iter() {
392 new_args.push(self.generalize_tp(arg, uninit));
393 }
394 Predicate::call(receiver, name, new_args)
395 }
396 Predicate::Attr { receiver, name } => {
397 let receiver = self.generalize_tp(receiver, uninit);
398 Predicate::attr(receiver, name)
399 }
400 Predicate::GeneralEqual { lhs, rhs } => {
401 let lhs = self.generalize_pred(*lhs, uninit);
402 let rhs = self.generalize_pred(*rhs, uninit);
403 Predicate::general_eq(lhs, rhs)
404 }
405 Predicate::GeneralGreaterEqual { lhs, rhs } => {
406 let lhs = self.generalize_pred(*lhs, uninit);
407 let rhs = self.generalize_pred(*rhs, uninit);
408 Predicate::general_ge(lhs, rhs)
409 }
410 Predicate::GeneralLessEqual { lhs, rhs } => {
411 let lhs = self.generalize_pred(*lhs, uninit);
412 let rhs = self.generalize_pred(*rhs, uninit);
413 Predicate::general_le(lhs, rhs)
414 }
415 Predicate::GeneralNotEqual { lhs, rhs } => {
416 let lhs = self.generalize_pred(*lhs, uninit);
417 let rhs = self.generalize_pred(*rhs, uninit);
418 Predicate::general_ne(lhs, rhs)
419 }
420 Predicate::Equal { lhs, rhs } => {
421 let rhs = self.generalize_tp(rhs, uninit);
422 Predicate::eq(lhs, rhs)
423 }
424 Predicate::GreaterEqual { lhs, rhs } => {
425 let rhs = self.generalize_tp(rhs, uninit);
426 Predicate::ge(lhs, rhs)
427 }
428 Predicate::LessEqual { lhs, rhs } => {
429 let rhs = self.generalize_tp(rhs, uninit);
430 Predicate::le(lhs, rhs)
431 }
432 Predicate::NotEqual { lhs, rhs } => {
433 let rhs = self.generalize_tp(rhs, uninit);
434 Predicate::ne(lhs, rhs)
435 }
436 Predicate::And(lhs, rhs) => {
437 let lhs = self.generalize_pred(*lhs, uninit);
438 let rhs = self.generalize_pred(*rhs, uninit);
439 Predicate::and(lhs, rhs)
440 }
441 Predicate::Or(preds) => Predicate::Or(
442 preds
443 .into_iter()
444 .map(|pred| self.generalize_pred(pred, uninit))
445 .collect(),
446 ),
447 Predicate::Not(pred) => {
448 let pred = self.generalize_pred(*pred, uninit);
449 !pred
450 }
451 }
452 }
453}
454
455pub struct Dereferencer<'c, 'q, 'l, L: Locational> {
456 ctx: &'c Context,
457 level: usize,
459 coerce: bool,
460 variance_stack: Vec<Variance>,
461 qnames: &'q Set<Str>,
462 loc: &'l L,
463}
464
465impl<'c, 'q, 'l, L: Locational> Dereferencer<'c, 'q, 'l, L> {
466 pub fn new(
467 ctx: &'c Context,
468 variance: Variance,
469 coerce: bool,
470 qnames: &'q Set<Str>,
471 loc: &'l L,
472 ) -> Self {
473 Self {
474 ctx,
475 level: ctx.level,
476 coerce,
477 variance_stack: vec![Invariant, variance],
478 qnames,
479 loc,
480 }
481 }
482
483 pub fn simple(ctx: &'c Context, qnames: &'q Set<Str>, loc: &'l L) -> Self {
484 Self::new(ctx, Variance::Covariant, true, qnames, loc)
485 }
486
487 pub fn set_level(&mut self, level: usize) {
488 self.level = level;
489 }
490
491 fn push_variance(&mut self, variance: Variance) {
492 self.variance_stack.push(variance);
493 }
494
495 fn pop_variance(&mut self) {
496 self.variance_stack.pop();
497 }
498
499 fn current_variance(&self) -> Variance {
500 *self.variance_stack.last().unwrap()
501 }
502
503 fn deref_value(&mut self, val: ValueObj) -> TyCheckResult<ValueObj> {
504 match val {
505 ValueObj::Type(mut t) => {
506 t.try_map_t(&mut |t| self.deref_tyvar(t.clone()))?;
507 Ok(ValueObj::Type(t))
508 }
509 ValueObj::List(vs) => {
510 let mut new_vs = vec![];
511 for v in vs.iter() {
512 new_vs.push(self.deref_value(v.clone())?);
513 }
514 Ok(ValueObj::List(new_vs.into()))
515 }
516 ValueObj::Tuple(vs) => {
517 let mut new_vs = vec![];
518 for v in vs.iter() {
519 new_vs.push(self.deref_value(v.clone())?);
520 }
521 Ok(ValueObj::Tuple(new_vs.into()))
522 }
523 ValueObj::Dict(dic) => {
524 let mut new_dic = dict! {};
525 for (k, v) in dic.into_iter() {
526 let k = self.deref_value(k)?;
527 let v = self.deref_value(v)?;
528 new_dic.insert(k, v);
529 }
530 Ok(ValueObj::Dict(new_dic))
531 }
532 ValueObj::Set(set) => {
533 let mut new_set = set! {};
534 for v in set.into_iter() {
535 new_set.insert(self.deref_value(v)?);
536 }
537 Ok(ValueObj::Set(new_set))
538 }
539 ValueObj::Record(rec) => {
540 let mut new_rec = dict! {};
541 for (field, v) in rec.into_iter() {
542 new_rec.insert(field, self.deref_value(v)?);
543 }
544 Ok(ValueObj::Record(new_rec))
545 }
546 ValueObj::DataClass { name, fields } => {
547 let mut new_fields = dict! {};
548 for (field, v) in fields.into_iter() {
549 new_fields.insert(field, self.deref_value(v)?);
550 }
551 Ok(ValueObj::DataClass {
552 name,
553 fields: new_fields,
554 })
555 }
556 ValueObj::UnsizedList(v) => Ok(ValueObj::UnsizedList(Box::new(self.deref_value(*v)?))),
557 ValueObj::Subr(subr) => Ok(ValueObj::Subr(subr)),
558 mono_value_pattern!() => Ok(val),
559 }
560 }
561
562 pub(crate) fn deref_tp(&mut self, tp: TyParam) -> TyCheckResult<TyParam> {
563 match tp {
564 TyParam::FreeVar(fv) if fv.is_linked() => {
565 let inner = fv.unwrap_linked();
566 self.deref_tp(inner)
567 }
568 TyParam::FreeVar(fv)
569 if fv.is_generalized() && self.qnames.contains(&fv.unbound_name().unwrap()) =>
570 {
571 Ok(TyParam::FreeVar(fv))
572 }
573 TyParam::FreeVar(_) if self.level == 0 => {
575 let t = self.ctx.get_tp_t(&tp).unwrap_or(Type::Obj);
576 Ok(TyParam::erased(self.deref_tyvar(t)?))
577 }
578 TyParam::FreeVar(fv) if fv.get_type().is_some() => {
579 let t = self.deref_tyvar(fv.get_type().unwrap())?;
580 fv.update_type(t);
581 Ok(TyParam::FreeVar(fv))
582 }
583 TyParam::FreeVar(_) => Ok(tp),
584 TyParam::Type(t) => Ok(TyParam::t(self.deref_tyvar(*t)?)),
585 TyParam::Value(val) => self.deref_value(val).map(TyParam::Value),
586 TyParam::Erased(t) => Ok(TyParam::erased(self.deref_tyvar(*t)?)),
587 TyParam::App { name, mut args } => {
588 for param in args.iter_mut() {
589 *param = self.deref_tp(mem::take(param))?;
590 }
591 Ok(TyParam::App { name, args })
592 }
593 TyParam::BinOp { op, lhs, rhs } => {
594 let lhs = self.deref_tp(*lhs)?;
595 let rhs = self.deref_tp(*rhs)?;
596 Ok(TyParam::BinOp {
597 op,
598 lhs: Box::new(lhs),
599 rhs: Box::new(rhs),
600 })
601 }
602 TyParam::UnaryOp { op, val } => {
603 let val = self.deref_tp(*val)?;
604 Ok(TyParam::UnaryOp {
605 op,
606 val: Box::new(val),
607 })
608 }
609 TyParam::List(tps) => {
610 let mut new_tps = vec![];
611 for tp in tps {
612 new_tps.push(self.deref_tp(tp)?);
613 }
614 Ok(TyParam::List(new_tps))
615 }
616 TyParam::UnsizedList(tp) => Ok(TyParam::UnsizedList(Box::new(self.deref_tp(*tp)?))),
617 TyParam::Tuple(tps) => {
618 let mut new_tps = vec![];
619 for tp in tps {
620 new_tps.push(self.deref_tp(tp)?);
621 }
622 Ok(TyParam::Tuple(new_tps))
623 }
624 TyParam::Dict(dic) => {
625 let mut new_dic = dict! {};
626 for (k, v) in dic.into_iter() {
627 let k = self.deref_tp(k)?;
628 let v = self.deref_tp(v)?;
629 new_dic
630 .entry(k)
631 .and_modify(|old_v| {
632 if let Some(union) = self.ctx.union_tp(&mem::take(old_v), &v) {
633 *old_v = union;
634 }
635 })
636 .or_insert(v);
637 }
638 Ok(TyParam::Dict(new_dic))
639 }
640 TyParam::Set(set) => {
641 let mut new_set = set! {};
642 for v in set.into_iter() {
643 new_set.insert(self.deref_tp(v)?);
644 }
645 Ok(TyParam::Set(new_set))
646 }
647 TyParam::Record(rec) => {
648 let mut new_rec = dict! {};
649 for (field, tp) in rec.into_iter() {
650 new_rec.insert(field, self.deref_tp(tp)?);
651 }
652 Ok(TyParam::Record(new_rec))
653 }
654 TyParam::DataClass { name, fields } => {
655 let mut new_fields = dict! {};
656 for (field, tp) in fields.into_iter() {
657 new_fields.insert(field, self.deref_tp(tp)?);
658 }
659 Ok(TyParam::DataClass {
660 name,
661 fields: new_fields,
662 })
663 }
664 TyParam::Lambda(lambda) => {
665 let nd_params = lambda
666 .nd_params
667 .into_iter()
668 .map(|pt| pt.try_map_type(&mut |t| self.deref_tyvar(t)))
669 .collect::<TyCheckResult<_>>()?;
670 let var_params = lambda
671 .var_params
672 .map(|pt| pt.try_map_type(&mut |t| self.deref_tyvar(t)))
673 .transpose()?;
674 let d_params = lambda
675 .d_params
676 .into_iter()
677 .map(|pt| pt.try_map_type(&mut |t| self.deref_tyvar(t)))
678 .collect::<TyCheckResult<_>>()?;
679 let kw_var_params = lambda
680 .kw_var_params
681 .map(|pt| pt.try_map_type(&mut |t| self.deref_tyvar(t)))
682 .transpose()?;
683 let body = lambda
684 .body
685 .into_iter()
686 .map(|tp| self.deref_tp(tp))
687 .collect::<TyCheckResult<Vec<_>>>()?;
688 Ok(TyParam::Lambda(TyParamLambda::new(
689 lambda.const_,
690 nd_params,
691 var_params,
692 d_params,
693 kw_var_params,
694 body,
695 )))
696 }
697 TyParam::Proj { obj, attr } => {
698 let obj = self.deref_tp(*obj)?;
699 Ok(TyParam::Proj {
700 obj: Box::new(obj),
701 attr,
702 })
703 }
704 TyParam::ProjCall { obj, attr, args } => {
705 let obj = self.deref_tp(*obj)?;
706 let mut new_args = vec![];
707 for arg in args.into_iter() {
708 new_args.push(self.deref_tp(arg)?);
709 }
710 Ok(TyParam::ProjCall {
711 obj: Box::new(obj),
712 attr,
713 args: new_args,
714 })
715 }
716 TyParam::Mono(_) | TyParam::Failure => Ok(tp),
717 }
718 }
719
720 fn deref_pred(&mut self, pred: Predicate) -> TyCheckResult<Predicate> {
721 match pred {
722 Predicate::Equal { lhs, rhs } => {
723 let rhs = self.deref_tp(rhs)?;
724 Ok(Predicate::eq(lhs, rhs))
725 }
726 Predicate::GreaterEqual { lhs, rhs } => {
727 let rhs = self.deref_tp(rhs)?;
728 Ok(Predicate::ge(lhs, rhs))
729 }
730 Predicate::LessEqual { lhs, rhs } => {
731 let rhs = self.deref_tp(rhs)?;
732 Ok(Predicate::le(lhs, rhs))
733 }
734 Predicate::NotEqual { lhs, rhs } => {
735 let rhs = self.deref_tp(rhs)?;
736 Ok(Predicate::ne(lhs, rhs))
737 }
738 Predicate::GeneralEqual { lhs, rhs } => {
739 let lhs = self.deref_pred(*lhs)?;
740 let rhs = self.deref_pred(*rhs)?;
741 match (lhs, rhs) {
742 (Predicate::Value(lhs), Predicate::Value(rhs)) => {
743 Ok(Predicate::Value(ValueObj::Bool(lhs == rhs)))
744 }
745 (lhs, rhs) => Ok(Predicate::general_eq(lhs, rhs)),
746 }
747 }
748 Predicate::GeneralNotEqual { lhs, rhs } => {
749 let lhs = self.deref_pred(*lhs)?;
750 let rhs = self.deref_pred(*rhs)?;
751 match (lhs, rhs) {
752 (Predicate::Value(lhs), Predicate::Value(rhs)) => {
753 Ok(Predicate::Value(ValueObj::Bool(lhs != rhs)))
754 }
755 (lhs, rhs) => Ok(Predicate::general_ne(lhs, rhs)),
756 }
757 }
758 Predicate::GeneralGreaterEqual { lhs, rhs } => {
759 let lhs = self.deref_pred(*lhs)?;
760 let rhs = self.deref_pred(*rhs)?;
761 match (lhs, rhs) {
762 (Predicate::Value(lhs), Predicate::Value(rhs)) => {
763 let Some(ValueObj::Bool(res)) = lhs.try_ge(rhs) else {
764 return Err(TyCheckErrors::from(TyCheckError::dummy_infer_error(
766 self.ctx.cfg.input.clone(),
767 fn_name!(),
768 line!(),
769 )));
770 };
771 Ok(Predicate::Value(ValueObj::Bool(res)))
772 }
773 (lhs, rhs) => Ok(Predicate::general_ge(lhs, rhs)),
774 }
775 }
776 Predicate::GeneralLessEqual { lhs, rhs } => {
777 let lhs = self.deref_pred(*lhs)?;
778 let rhs = self.deref_pred(*rhs)?;
779 match (lhs, rhs) {
780 (Predicate::Value(lhs), Predicate::Value(rhs)) => {
781 let Some(ValueObj::Bool(res)) = lhs.try_le(rhs) else {
782 return Err(TyCheckErrors::from(TyCheckError::dummy_infer_error(
783 self.ctx.cfg.input.clone(),
784 fn_name!(),
785 line!(),
786 )));
787 };
788 Ok(Predicate::Value(ValueObj::Bool(res)))
789 }
790 (lhs, rhs) => Ok(Predicate::general_le(lhs, rhs)),
791 }
792 }
793 Predicate::Call {
794 receiver,
795 name,
796 args,
797 } => {
798 let Ok(receiver) = self.deref_tp(receiver.clone()) else {
799 return Ok(Predicate::call(receiver, name, args));
800 };
801 let mut new_args = vec![];
802 for arg in args.into_iter() {
803 let Ok(arg) = self.deref_tp(arg) else {
804 return Ok(Predicate::call(receiver, name, new_args));
805 };
806 new_args.push(arg);
807 }
808 let evaled = if let Some(name) = &name {
809 self.ctx
810 .eval_proj_call(receiver.clone(), name.clone(), new_args.clone(), &())
811 } else {
812 self.ctx.eval_call(receiver.clone(), new_args.clone(), &())
813 };
814 match evaled {
815 Ok(TyParam::Value(value)) => Ok(Predicate::Value(value)),
816 _ => Ok(Predicate::call(receiver, name, new_args)),
817 }
818 }
819 Predicate::And(lhs, rhs) => {
820 let lhs = self.deref_pred(*lhs)?;
821 let rhs = self.deref_pred(*rhs)?;
822 Ok(Predicate::and(lhs, rhs))
823 }
824 Predicate::Or(preds) => {
825 let mut new_preds = Set::with_capacity(preds.len());
826 for pred in preds.into_iter() {
827 new_preds.insert(self.deref_pred(pred)?);
828 }
829 Ok(Predicate::Or(new_preds))
830 }
831 Predicate::Not(pred) => {
832 let pred = self.deref_pred(*pred)?;
833 Ok(!pred)
834 }
835 Predicate::Attr { receiver, name } => {
836 let receiver = self.deref_tp(receiver)?;
837 Ok(Predicate::attr(receiver, name))
838 }
839 Predicate::Value(v) => self.deref_value(v).map(Predicate::Value),
840 Predicate::Const(_) | Predicate::Failure => Ok(pred),
841 }
842 }
843
844 fn deref_constraint(&mut self, constraint: Constraint) -> TyCheckResult<Constraint> {
845 match constraint {
846 Constraint::Sandwiched { sub, sup } => Ok(Constraint::new_sandwiched(
847 self.deref_tyvar(sub)?,
848 self.deref_tyvar(sup)?,
849 )),
850 Constraint::TypeOf(t) => Ok(Constraint::new_type_of(self.deref_tyvar(t)?)),
851 _ => unreachable_error!(TyCheckErrors, TyCheckError, self.ctx),
852 }
853 }
854
    /// Resolves free type variables in `t`.
    ///
    /// - Linked variables are followed; a link to a recursive type yields `Never`.
    /// - Generalized variables whose name is in `self.qnames` are kept (after
    ///   `update_init`).
    /// - A sandwiched variable at or below `self.level` is validated against its
    ///   bounds via `validate_subsup`. If validation fails, a non-generalized
    ///   variable is linked to `Never` and the errors are returned.
    /// - A variable with a type-of constraint whose type has refinement values is
    ///   replaced by the union of those values converted to types.
    /// - At level 0, an unbound variable whose `TypeOf` constraint is not a type
    ///   is an inference error. At other levels, its constraint is resolved in place.
    /// - Compound types are resolved recursively. Poly parameters and subroutine
    ///   parameter and return types are visited with their variance pushed.
    pub(crate) fn deref_tyvar(&mut self, t: Type) -> TyCheckResult<Type> {
        match t {
            FreeVar(fv) if fv.is_linked() => {
                let t = fv.unwrap_linked();
                if t.is_recursive() {
                    Ok(Type::Never)
                } else {
                    self.deref_tyvar(t)
                }
            }
            FreeVar(mut fv)
                if fv.is_generalized() && self.qnames.contains(&fv.unbound_name().unwrap()) =>
            {
                fv.update_init();
                Ok(Type::FreeVar(fv))
            }
            FreeVar(fv) if fv.constraint_is_sandwiched() => {
                // Hash of the variable, used by `validate_simple_subsup` as a cache key.
                let fv_hash = get_hash(&fv);
                let (sub_t, super_t) = fv.get_subsup().unwrap();
                if self.level <= fv.level().unwrap() {
                    let list = UndoableLinkedList::new();
                    // While the bounds are validated, the variable is temporarily linked:
                    // - both bounds mention it: a dummy link (reverted by `undo` below);
                    // - only the sub bound mentions it: linked to the super bound;
                    // - otherwise: linked to the sub bound.
                    let fv_t = Type::FreeVar(fv.clone());
                    let dummy = match (sub_t.contains_type(&fv_t), super_t.contains_type(&fv_t)) {
                        (true, true) => {
                            fv.dummy_link();
                            true
                        }
                        (true, false) => {
                            fv_t.undoable_link(&super_t, &list);
                            false
                        }
                        (false, true | false) => {
                            fv_t.undoable_link(&sub_t, &list);
                            false
                        }
                    };
                    let res = self.validate_subsup(sub_t, super_t, fv_hash);
                    if dummy {
                        fv.undo();
                    } else {
                        // NOTE(review): presumably dropping `list` reverts the undoable
                        // link made above — confirm in `UndoableLinkedList`'s Drop impl.
                        drop(list);
                    }
                    match res {
                        Ok(ty) => {
                            Ok(ty)
                        }
                        Err(errs) => {
                            // Pin a failed, non-generalized variable to `Never`.
                            if !fv.is_generalized() {
                                Type::FreeVar(fv).destructive_link(&Never);
                            }
                            Err(errs)
                        }
                    }
                } else {
                    // Variables from a shallower level than `self.level` are left alone.
                    Ok(Type::FreeVar(fv))
                }
            }
            FreeVar(fv) if fv.get_type().is_some() => {
                let ty = fv.get_type().unwrap();
                if self.level <= fv.level().unwrap() {
                    // A variable typed by a refinement with known values becomes the
                    // union of those values converted to types (unconvertible ones skipped).
                    if let Some(tys) = ty.refinement_values() {
                        let mut union = Never;
                        for tp in tys {
                            if let Ok(ty) = self.ctx.convert_tp_into_type(tp.clone()) {
                                union = self.ctx.union(&union, &ty);
                            }
                        }
                        return Ok(union);
                    }
                    Ok(Type::FreeVar(fv))
                } else {
                    Ok(Type::FreeVar(fv))
                }
            }
            FreeVar(fv) if fv.is_unbound() => {
                if self.level == 0 {
                    match &*fv.crack_constraint() {
                        Constraint::TypeOf(t) if !t.is_type() => {
                            return Err(TyCheckErrors::from(TyCheckError::dummy_infer_error(
                                self.ctx.cfg.input.clone(),
                                fn_name!(),
                                line!(),
                            )));
                        }
                        _ => {}
                    }
                    Ok(Type::FreeVar(fv))
                } else {
                    let new_constraint = fv.crack_constraint().clone();
                    let new_constraint = self.deref_constraint(new_constraint)?;
                    let ty = Type::FreeVar(fv);
                    ty.update_constraint(new_constraint, None, true);
                    Ok(ty)
                }
            }
            FreeVar(_) => Ok(t),
            Poly { name, mut params } => {
                let typ = poly(&name, params.clone());
                let ctx = self.ctx.get_nominal_type_ctx(&typ).ok_or_else(|| {
                    TyCheckError::type_not_found(
                        self.ctx.cfg.input.clone(),
                        line!() as usize,
                        self.loc.loc(),
                        self.ctx.caused_by(),
                        &typ,
                    )
                })?;
                // Every parameter is attempted; errors are collected, not short-circuited.
                let mut errs = TyCheckErrors::empty();
                let variances = ctx.type_params_variance();
                // Parameters beyond the declared variances are treated as invariant.
                for (param, variance) in params
                    .iter_mut()
                    .zip(variances.into_iter().chain(std::iter::repeat(Invariant)))
                {
                    self.push_variance(variance);
                    match self.deref_tp(mem::take(param)) {
                        Ok(t) => *param = t,
                        Err(es) => errs.extend(es),
                    }
                    self.pop_variance();
                }
                if errs.is_empty() {
                    Ok(Type::Poly { name, params })
                } else {
                    Err(errs)
                }
            }
            Subr(mut subr) => {
                // Parameters are visited contravariantly and the return type
                // covariantly; errors are collected across all of them.
                let mut errs = TyCheckErrors::empty();
                for param in subr.non_default_params.iter_mut() {
                    self.push_variance(Contravariant);
                    match self.deref_tyvar(mem::take(param.typ_mut())) {
                        Ok(t) => *param.typ_mut() = t,
                        Err(es) => errs.extend(es),
                    }
                    self.pop_variance();
                }
                if let Some(var_params) = &mut subr.var_params {
                    self.push_variance(Contravariant);
                    match self.deref_tyvar(mem::take(var_params.typ_mut())) {
                        Ok(t) => *var_params.typ_mut() = t,
                        Err(es) => errs.extend(es),
                    }
                    self.pop_variance();
                }
                for d_param in subr.default_params.iter_mut() {
                    self.push_variance(Contravariant);
                    match self.deref_tyvar(mem::take(d_param.typ_mut())) {
                        Ok(t) => *d_param.typ_mut() = t,
                        Err(es) => errs.extend(es),
                    }
                    if let Some(default) = d_param.default_typ_mut() {
                        match self.deref_tyvar(mem::take(default)) {
                            Ok(t) => *default = t,
                            Err(es) => errs.extend(es),
                        }
                    }
                    self.pop_variance();
                }
                if let Some(kw_var_params) = &mut subr.kw_var_params {
                    self.push_variance(Contravariant);
                    match self.deref_tyvar(mem::take(kw_var_params.typ_mut())) {
                        Ok(t) => *kw_var_params.typ_mut() = t,
                        Err(es) => errs.extend(es),
                    }
                    if let Some(default) = kw_var_params.default_typ_mut() {
                        match self.deref_tyvar(mem::take(default)) {
                            Ok(t) => *default = t,
                            Err(es) => errs.extend(es),
                        }
                    }
                    self.pop_variance();
                }
                self.push_variance(Covariant);
                match self.deref_tyvar(mem::take(&mut subr.return_t)) {
                    Ok(t) => *subr.return_t = t,
                    Err(es) => errs.extend(es),
                }
                self.pop_variance();
                if errs.is_empty() {
                    Ok(Type::Subr(subr))
                } else {
                    Err(errs)
                }
            }
            Callable {
                mut param_ts,
                return_t,
            } => {
                for param_t in param_ts.iter_mut() {
                    *param_t = self.deref_tyvar(mem::take(param_t))?;
                }
                let return_t = self.deref_tyvar(*return_t)?;
                Ok(callable(param_ts, return_t))
            }
            Quantified(subr) => self.eliminate_needless_quant(*subr),
            Ref(t) => {
                let t = self.deref_tyvar(*t)?;
                Ok(ref_(t))
            }
            RefMut { before, after } => {
                let before = self.deref_tyvar(*before)?;
                let after = if let Some(after) = after {
                    Some(self.deref_tyvar(*after)?)
                } else {
                    None
                };
                Ok(ref_mut(before, after))
            }
            Record(mut rec) => {
                for (_, field) in rec.iter_mut() {
                    *field = self.deref_tyvar(mem::take(field))?;
                }
                Ok(Type::Record(rec))
            }
            NamedTuple(mut rec) => {
                for (_, t) in rec.iter_mut() {
                    *t = self.deref_tyvar(mem::take(t))?;
                }
                Ok(Type::NamedTuple(rec))
            }
            Refinement(refine) => {
                let t = self.deref_tyvar(*refine.t)?;
                let pred = self.deref_pred(*refine.pred)?;
                Ok(refinement(refine.var, t, pred))
            }
            // NOTE(review): the intersection index carried by `And` is dropped here,
            // whereas `Generalizer::generalize_t` re-applies it — confirm intended.
            And(ands, _) => {
                let mut new_ands = vec![];
                for t in ands.into_iter() {
                    new_ands.push(self.deref_tyvar(t)?);
                }
                Ok(new_ands
                    .into_iter()
                    .fold(Type::Obj, |acc, t| self.ctx.intersection(&acc, &t)))
            }
            Or(ors) => {
                let mut new_ors = vec![];
                for t in ors.into_iter() {
                    new_ors.push(self.deref_tyvar(t)?);
                }
                Ok(new_ors
                    .into_iter()
                    .fold(Type::Never, |acc, t| self.ctx.union(&acc, &t)))
            }
            Not(ty) => {
                let ty = self.deref_tyvar(*ty)?;
                Ok(self.ctx.complement(&ty))
            }
            Proj { lhs, rhs } => {
                // Try evaluating the projection as-is first, then with `lhs` resolved;
                // if both fail the result is `Failure` (not an error).
                let proj = self
                    .ctx
                    .eval_proj(*lhs.clone(), rhs.clone(), self.level, self.loc)
                    .or_else(|_| {
                        let lhs = self.deref_tyvar(*lhs)?;
                        self.ctx.eval_proj(lhs, rhs, self.level, self.loc)
                    })
                    .unwrap_or(Failure);
                Ok(proj)
            }
            ProjCall {
                lhs,
                attr_name,
                args,
            } => {
                let lhs = self.deref_tp(*lhs)?;
                let mut new_args = vec![];
                for arg in args.into_iter() {
                    new_args.push(self.deref_tp(arg)?);
                }
                // An evaluation failure yields `Failure` rather than an error.
                let proj = self
                    .ctx
                    .eval_proj_call_t(lhs, attr_name, new_args, self.level, self.loc)
                    .unwrap_or(Failure);
                Ok(proj)
            }
            Structural(inner) => {
                let inner = self.deref_tyvar(*inner)?;
                Ok(inner.structuralize())
            }
            Guard(grd) => {
                let to = self.deref_tyvar(*grd.to)?;
                Ok(guard(grd.namespace, *grd.target, to))
            }
            Bounded { sub, sup } => {
                let sub = self.deref_tyvar(*sub)?;
                let sup = self.deref_tyvar(*sup)?;
                Ok(bounded(sub, sup))
            }
            mono_type_pattern!() => Ok(t),
        }
    }
1167
1168 fn validate_subsup(
1169 &mut self,
1170 sub_t: Type,
1171 super_t: Type,
1172 fv_hash: usize,
1173 ) -> TyCheckResult<Type> {
1174 match (sub_t, super_t) {
1176 (
1184 Poly {
1185 name: ln,
1186 params: lps,
1187 },
1188 Poly {
1189 name: rn,
1190 params: rps,
1191 },
1192 ) if ln == rn => {
1193 let typ = poly(ln, lps.clone());
1194 let ctx = self.ctx.get_nominal_type_ctx(&typ).ok_or_else(|| {
1195 TyCheckError::type_not_found(
1196 self.ctx.cfg.input.clone(),
1197 line!() as usize,
1198 self.loc.loc(),
1199 self.ctx.caused_by(),
1200 &typ,
1201 )
1202 })?;
1203 let variances = ctx.type_params_variance();
1204 let mut tps = vec![];
1205 for ((lp, rp), variance) in lps
1206 .into_iter()
1207 .zip(rps.into_iter())
1208 .zip(variances.into_iter().chain(std::iter::repeat(Invariant)))
1209 {
1210 self.ctx
1211 .sub_unify_tp(&lp, &rp, Some(variance), self.loc, false)?;
1212 let param = if variance == Covariant { lp } else { rp };
1213 tps.push(param);
1214 }
1215 Ok(poly(rn, tps))
1216 }
1217 (sub_t, super_t) => self.validate_simple_subsup(sub_t, super_t, fv_hash),
1218 }
1219 }
1220
1221 fn validate_simple_subsup(
1222 &mut self,
1223 sub_t: Type,
1224 super_t: Type,
1225 fv_hash: usize,
1226 ) -> TyCheckResult<Type> {
1227 let opt_res = self.ctx.shared().gen_cache.get(&fv_hash);
1228 if opt_res.is_none() && self.ctx.is_class(&sub_t) && self.ctx.is_trait(&super_t) {
1229 self.ctx
1230 .check_trait_impl(&sub_t, &super_t, self.qnames, self.loc)?;
1231 }
1232 let is_subtype = opt_res.map(|res| res.is_subtype).unwrap_or_else(|| {
1233 let is_subtype = self.ctx.subtype_of(&sub_t, &super_t); let res = GeneralizationResult {
1235 is_subtype,
1236 impl_trait: true,
1237 };
1238 self.ctx.shared().gen_cache.insert(fv_hash, res);
1239 is_subtype
1240 });
1241 let sub_t = self.deref_tyvar(sub_t)?;
1242 let super_t = self.deref_tyvar(super_t)?;
1243 if sub_t == super_t {
1244 Ok(sub_t)
1245 } else if is_subtype {
1246 match self.current_variance() {
1247 Variance::Covariant if self.coerce => {
1251 if sub_t != Never || super_t == Obj {
1252 Ok(sub_t)
1253 } else {
1254 Ok(bounded(sub_t, super_t))
1255 }
1256 }
1257 Variance::Contravariant if self.coerce => Ok(super_t),
1258 Variance::Covariant | Variance::Contravariant => Ok(bounded(sub_t, super_t)),
1259 Variance::Invariant => {
1260 if self.ctx.supertype_of(&sub_t, &super_t) {
1262 Ok(sub_t)
1263 } else {
1264 Err(TyCheckErrors::from(TyCheckError::invariant_error(
1265 self.ctx.cfg.input.clone(),
1266 line!() as usize,
1267 &sub_t,
1268 &super_t,
1269 self.loc.loc(),
1270 self.ctx.caused_by(),
1271 )))
1272 }
1273 }
1274 }
1275 } else {
1276 Err(TyCheckErrors::from(TyCheckError::subtyping_error(
1277 self.ctx.cfg.input.clone(),
1278 line!() as usize,
1279 &sub_t,
1280 &super_t,
1281 self.loc.loc(),
1282 self.ctx.caused_by(),
1283 )))
1284 }
1285 }
1286
1287 fn eliminate_needless_quant(&mut self, subr: Type) -> TyCheckResult<Type> {
1300 let Ok(mut subr) = SubrType::try_from(subr) else {
1301 unreachable!()
1302 };
1303 let essential_qnames = subr.essential_qnames();
1304 let mut _self = Dereferencer::new(
1305 self.ctx,
1306 self.current_variance(),
1307 self.coerce,
1308 &essential_qnames,
1309 self.loc,
1310 );
1311 for param in subr.non_default_params.iter_mut() {
1312 _self.push_variance(Contravariant);
1313 *param.typ_mut() = _self
1314 .deref_tyvar(mem::take(param.typ_mut()))
1315 .inspect_err(|_e| _self.pop_variance())?;
1316 _self.pop_variance();
1317 }
1318 if let Some(var_args) = &mut subr.var_params {
1319 _self.push_variance(Contravariant);
1320 *var_args.typ_mut() = _self
1321 .deref_tyvar(mem::take(var_args.typ_mut()))
1322 .inspect_err(|_e| _self.pop_variance())?;
1323 _self.pop_variance();
1324 }
1325 for d_param in subr.default_params.iter_mut() {
1326 _self.push_variance(Contravariant);
1327 *d_param.typ_mut() = _self
1328 .deref_tyvar(mem::take(d_param.typ_mut()))
1329 .inspect_err(|_e| {
1330 _self.pop_variance();
1331 })?;
1332 if let Some(default) = d_param.default_typ_mut() {
1333 *default = _self
1334 .deref_tyvar(mem::take(default))
1335 .inspect_err(|_e| _self.pop_variance())?;
1336 }
1337 _self.pop_variance();
1338 }
1339 if let Some(kw_var_args) = &mut subr.kw_var_params {
1340 _self.push_variance(Contravariant);
1341 *kw_var_args.typ_mut() = _self
1342 .deref_tyvar(mem::take(kw_var_args.typ_mut()))
1343 .inspect_err(|_e| _self.pop_variance())?;
1344 if let Some(default) = kw_var_args.default_typ_mut() {
1345 *default = _self
1346 .deref_tyvar(mem::take(default))
1347 .inspect_err(|_e| _self.pop_variance())?;
1348 }
1349 _self.pop_variance();
1350 }
1351 _self.push_variance(Covariant);
1352 *subr.return_t = _self
1353 .deref_tyvar(mem::take(&mut subr.return_t))
1354 .inspect_err(|_e| _self.pop_variance())?;
1355 _self.pop_variance();
1356 let subr = Type::Subr(subr);
1357 if subr.has_qvar() {
1358 Ok(subr.quantify())
1359 } else {
1360 Ok(subr)
1361 }
1362 }
1363}
1364
1365impl Context {
1366 pub const TOP_LEVEL: usize = 1;
1367
1368 pub(crate) fn generalize_t(&self, free_type: Type) -> Type {
1371 let mut generalizer = Generalizer::new(self);
1372 let maybe_unbound_t = generalizer.generalize_t(free_type, false);
1373 if maybe_unbound_t.is_subr() && maybe_unbound_t.has_qvar() {
1374 maybe_unbound_t.quantify()
1375 } else {
1376 maybe_unbound_t
1377 }
1378 }
1379
1380 pub fn readable_type(&self, t: Type) -> Type {
1381 let qnames = set! {};
1382 let mut dereferencer = Dereferencer::new(self, Covariant, false, &qnames, &());
1383 dereferencer.set_level(0);
1384 dereferencer.deref_tyvar(t.clone()).unwrap_or(t)
1385 }
1386
1387 pub(crate) fn coerce(&self, t: Type, t_loc: &impl Locational) -> TyCheckResult<Type> {
1398 let qnames = set! {};
1399 let mut dereferencer = Dereferencer::new(self, Covariant, true, &qnames, t_loc);
1400 dereferencer.deref_tyvar(t)
1401 }
1402
1403 pub(crate) fn coerce_tp(&self, tp: TyParam, t_loc: &impl Locational) -> TyCheckResult<TyParam> {
1404 let qnames = set! {};
1405 let mut dereferencer = Dereferencer::new(self, Covariant, true, &qnames, t_loc);
1406 dereferencer.deref_tp(tp)
1407 }
1408
1409 pub(crate) fn trait_impl_exists(&self, class: &Type, trait_: &Type) -> bool {
1410 if self.subtype_of(class, &Type::Never) {
1412 return true;
1413 }
1414 if class.is_monomorphic() {
1415 self.mono_class_trait_impl_exist(class, trait_)
1416 } else {
1417 self.poly_class_trait_impl_exists(class, trait_)
1418 }
1419 }
1420
1421 fn mono_class_trait_impl_exist(&self, class: &Type, trait_: &Type) -> bool {
1422 let mut super_exists = false;
1423 for imp in self.get_trait_impls(trait_).into_iter() {
1424 if self.supertype_of(&imp.sub_type, class) && self.supertype_of(&imp.sup_trait, trait_)
1425 {
1426 super_exists = true;
1427 break;
1428 }
1429 }
1430 super_exists
1431 }
1432
1433 fn poly_class_trait_impl_exists(&self, class: &Type, trait_: &Type) -> bool {
1437 for imp in self.get_trait_impls(trait_).into_iter() {
1438 let _sub_subs = Substituter::substitute_typarams(self, &imp.sub_type, class).ok();
1439 let _sup_subs = Substituter::substitute_typarams(self, &imp.sup_trait, trait_).ok();
1440 if self.supertype_of(&imp.sub_type, class) && self.supertype_of(&imp.sup_trait, trait_)
1441 {
1442 return true;
1443 }
1444 }
1445 false
1446 }
1447
1448 fn check_trait_impl(
1449 &self,
1450 class: &Type,
1451 trait_: &Type,
1452 qnames: &Set<Str>,
1453 loc: &impl Locational,
1454 ) -> TyCheckResult<()> {
1455 if !self.trait_impl_exists(class, trait_) {
1456 let mut dereferencer = Dereferencer::new(self, Variance::Covariant, false, qnames, loc);
1457 let class = if DEBUG_MODE {
1458 class.clone()
1459 } else {
1460 dereferencer.deref_tyvar(class.clone())?
1461 };
1462 let trait_ = if DEBUG_MODE {
1463 trait_.clone()
1464 } else {
1465 dereferencer.deref_tyvar(trait_.clone())?
1466 };
1467 Err(TyCheckErrors::from(TyCheckError::no_trait_impl_error(
1468 self.cfg.input.clone(),
1469 line!() as usize,
1470 &class,
1471 &trait_,
1472 loc.loc(),
1473 self.caused_by(),
1474 self.get_simple_type_mismatch_hint(&trait_, &class),
1475 )))
1476 } else {
1477 Ok(())
1478 }
1479 }
1480
1481 pub(crate) fn resolve(
1484 &mut self,
1485 mut hir: hir::HIR,
1486 ) -> Result<hir::HIR, (hir::HIR, TyCheckErrors)> {
1487 self.level = 0;
1488 let mut errs = TyCheckErrors::empty();
1489 for chunk in hir.module.iter_mut() {
1490 if let Err(es) = self.resolve_expr_t(chunk, &set! {}) {
1491 errs.extend(es);
1492 }
1493 }
1494 self.resolve_ctx_vars();
1495 if errs.is_empty() {
1496 Ok(hir)
1497 } else {
1498 Err((hir, errs))
1499 }
1500 }
1501
1502 fn resolve_ctx_vars(&mut self) {
1503 let mut locals = mem::take(&mut self.locals);
1504 let mut params = mem::take(&mut self.params);
1505 let mut methods_list = mem::take(&mut self.methods_list);
1506 for (name, vi) in locals.iter_mut() {
1507 let qnames = set! {};
1508 let mut derferencer = Dereferencer::simple(self, &qnames, name);
1509 if let Ok(t) = derferencer.deref_tyvar(mem::take(&mut vi.t)) {
1510 vi.t = t;
1511 }
1512 }
1513 for (name, vi) in params.iter_mut() {
1514 let qnames = set! {};
1515 let mut derferencer = Dereferencer::simple(self, &qnames, name);
1516 if let Ok(t) = derferencer.deref_tyvar(mem::take(&mut vi.t)) {
1517 vi.t = t;
1518 }
1519 }
1520 for methods in methods_list.iter_mut() {
1521 methods.resolve_ctx_vars();
1522 }
1523 self.locals = locals;
1524 self.params = params;
1525 self.methods_list = methods_list;
1526 }
1527
1528 fn resolve_params_t(&self, params: &mut hir::Params, qnames: &Set<Str>) -> TyCheckResult<()> {
1529 for param in params.non_defaults.iter_mut() {
1530 param.vi.t.generalize();
1533 let t = mem::take(&mut param.vi.t);
1534 let mut dereferencer = Dereferencer::new(self, Contravariant, false, qnames, param);
1535 param.vi.t = dereferencer.deref_tyvar(t)?;
1536 }
1537 if let Some(var_params) = &mut params.var_params {
1538 var_params.vi.t.generalize();
1539 let t = mem::take(&mut var_params.vi.t);
1540 let mut dereferencer =
1541 Dereferencer::new(self, Contravariant, false, qnames, var_params.as_ref());
1542 var_params.vi.t = dereferencer.deref_tyvar(t)?;
1543 }
1544 for param in params.defaults.iter_mut() {
1545 param.sig.vi.t.generalize();
1546 let t = mem::take(&mut param.sig.vi.t);
1547 let mut dereferencer = Dereferencer::new(self, Contravariant, false, qnames, param);
1548 param.sig.vi.t = dereferencer.deref_tyvar(t)?;
1549 self.resolve_expr_t(&mut param.default_val, qnames)?;
1550 }
1551 if let Some(kw_var) = &mut params.kw_var_params {
1552 kw_var.vi.t.generalize();
1553 let t = mem::take(&mut kw_var.vi.t);
1554 let mut dereferencer =
1555 Dereferencer::new(self, Contravariant, false, qnames, kw_var.as_ref());
1556 kw_var.vi.t = dereferencer.deref_tyvar(t)?;
1557 }
1558 for guard in params.guards.iter_mut() {
1559 match guard {
1560 GuardClause::Bind(def) => {
1561 self.resolve_def_t(def, qnames)?;
1562 }
1563 GuardClause::Condition(cond) => {
1564 self.resolve_expr_t(cond, qnames)?;
1565 }
1566 }
1567 }
1568 Ok(())
1569 }
1570
1571 pub(crate) fn resolve_expr_t(
1575 &self,
1576 expr: &mut hir::Expr,
1577 qnames: &Set<Str>,
1578 ) -> TyCheckResult<()> {
1579 match expr {
1580 hir::Expr::Literal(_) => Ok(()),
1581 hir::Expr::Accessor(acc) => {
1582 let t = mem::take(acc.ref_mut_t().unwrap());
1583 let mut dereferencer = Dereferencer::simple(self, qnames, acc);
1584 *acc.ref_mut_t().unwrap() = dereferencer.deref_tyvar(t)?;
1585 if let hir::Accessor::Attr(attr) = acc {
1586 self.resolve_expr_t(&mut attr.obj, qnames)?;
1587 }
1588 Ok(())
1589 }
1590 hir::Expr::List(list) => match list {
1591 hir::List::Normal(lis) => {
1592 for elem in lis.elems.pos_args.iter_mut() {
1593 self.resolve_expr_t(&mut elem.expr, qnames)?;
1594 }
1595 let t = mem::take(&mut lis.t);
1596 let mut dereferencer = Dereferencer::simple(self, qnames, lis);
1597 lis.t = dereferencer.deref_tyvar(t)?;
1598 Ok(())
1599 }
1600 hir::List::WithLength(lis) => {
1601 self.resolve_expr_t(&mut lis.elem, qnames)?;
1602 if let Some(len) = &mut lis.len {
1603 self.resolve_expr_t(len, qnames)?;
1604 }
1605 let t = mem::take(&mut lis.t);
1606 let mut dereferencer = Dereferencer::simple(self, qnames, lis);
1607 lis.t = dereferencer.deref_tyvar(t)?;
1608 Ok(())
1609 }
1610 other => feature_error!(
1611 TyCheckErrors,
1612 TyCheckError,
1613 self,
1614 other.loc(),
1615 "resolve types of array comprehension"
1616 ),
1617 },
1618 hir::Expr::Tuple(tuple) => match tuple {
1619 hir::Tuple::Normal(tup) => {
1620 for elem in tup.elems.pos_args.iter_mut() {
1621 self.resolve_expr_t(&mut elem.expr, qnames)?;
1622 }
1623 let t = mem::take(&mut tup.t);
1624 let mut dereferencer = Dereferencer::simple(self, qnames, tup);
1625 tup.t = dereferencer.deref_tyvar(t)?;
1626 Ok(())
1627 }
1628 },
1629 hir::Expr::Set(set) => match set {
1630 hir::Set::Normal(st) => {
1631 for elem in st.elems.pos_args.iter_mut() {
1632 self.resolve_expr_t(&mut elem.expr, qnames)?;
1633 }
1634 let t = mem::take(&mut st.t);
1635 let mut dereferencer = Dereferencer::simple(self, qnames, st);
1636 st.t = dereferencer.deref_tyvar(t)?;
1637 Ok(())
1638 }
1639 hir::Set::WithLength(st) => {
1640 self.resolve_expr_t(&mut st.elem, qnames)?;
1641 self.resolve_expr_t(&mut st.len, qnames)?;
1642 let t = mem::take(&mut st.t);
1643 let mut dereferencer = Dereferencer::simple(self, qnames, st);
1644 st.t = dereferencer.deref_tyvar(t)?;
1645 Ok(())
1646 }
1647 },
1648 hir::Expr::Dict(dict) => match dict {
1649 hir::Dict::Normal(dic) => {
1650 for kv in dic.kvs.iter_mut() {
1651 self.resolve_expr_t(&mut kv.key, qnames)?;
1652 self.resolve_expr_t(&mut kv.value, qnames)?;
1653 }
1654 let t = mem::take(&mut dic.t);
1655 let mut dereferencer = Dereferencer::simple(self, qnames, dic);
1656 dic.t = dereferencer.deref_tyvar(t)?;
1657 Ok(())
1658 }
1659 other => feature_error!(
1660 TyCheckErrors,
1661 TyCheckError,
1662 self,
1663 other.loc(),
1664 "resolve types of dict comprehension"
1665 ),
1666 },
1667 hir::Expr::Record(record) => {
1668 for attr in record.attrs.iter_mut() {
1669 let t = mem::take(attr.sig.ref_mut_t().unwrap());
1670 let mut dereferencer = Dereferencer::simple(self, qnames, &attr.sig);
1671 let t = dereferencer.deref_tyvar(t)?;
1672 *attr.sig.ref_mut_t().unwrap() = t;
1673 for chunk in attr.body.block.iter_mut() {
1674 self.resolve_expr_t(chunk, qnames)?;
1675 }
1676 }
1677 let t = mem::take(&mut record.t);
1678 let mut dereferencer = Dereferencer::simple(self, qnames, record);
1679 record.t = dereferencer.deref_tyvar(t)?;
1680 Ok(())
1681 }
1682 hir::Expr::BinOp(binop) => {
1683 let t = mem::take(binop.signature_mut_t().unwrap());
1684 let mut dereferencer = Dereferencer::simple(self, qnames, binop);
1685 *binop.signature_mut_t().unwrap() = dereferencer.deref_tyvar(t)?;
1686 self.resolve_expr_t(&mut binop.lhs, qnames)?;
1687 self.resolve_expr_t(&mut binop.rhs, qnames)?;
1688 Ok(())
1689 }
1690 hir::Expr::UnaryOp(unaryop) => {
1691 let t = mem::take(unaryop.signature_mut_t().unwrap());
1692 let mut dereferencer = Dereferencer::simple(self, qnames, unaryop);
1693 *unaryop.signature_mut_t().unwrap() = dereferencer.deref_tyvar(t)?;
1694 self.resolve_expr_t(&mut unaryop.expr, qnames)?;
1695 Ok(())
1696 }
1697 hir::Expr::Call(call) => {
1698 for arg in call.args.pos_args.iter_mut() {
1699 self.resolve_expr_t(&mut arg.expr, qnames)?;
1700 }
1701 if let Some(var_args) = &mut call.args.var_args {
1702 self.resolve_expr_t(&mut var_args.expr, qnames)?;
1703 }
1704 for arg in call.args.kw_args.iter_mut() {
1705 self.resolve_expr_t(&mut arg.expr, qnames)?;
1706 }
1707 if let Some(kw_var) = &mut call.args.kw_var {
1708 self.resolve_expr_t(&mut kw_var.expr, qnames)?;
1709 }
1710 self.resolve_expr_t(&mut call.obj, qnames)?;
1711 if let Some(t) = call.signature_mut_t() {
1712 let t = mem::take(t);
1713 let mut dereferencer = Dereferencer::simple(self, qnames, call);
1714 *call.signature_mut_t().unwrap() = dereferencer.deref_tyvar(t)?;
1715 }
1716 Ok(())
1717 }
1718 hir::Expr::Def(def) => self.resolve_def_t(def, qnames),
1719 hir::Expr::Lambda(lambda) => {
1720 let qnames = if let Type::Quantified(quant) = lambda.ref_t() {
1721 let Ok(subr) = <&SubrType>::try_from(quant.as_ref()) else {
1722 unreachable!()
1723 };
1724 subr.essential_qnames()
1725 } else {
1726 qnames.clone()
1727 };
1728 let mut errs = TyCheckErrors::empty();
1729 for chunk in lambda.body.iter_mut() {
1730 if let Err(es) = self.resolve_expr_t(chunk, &qnames) {
1731 errs.extend(es);
1732 }
1733 }
1734 if let Err(es) = self.resolve_params_t(&mut lambda.params, &qnames) {
1735 errs.extend(es);
1736 }
1737 let t = mem::take(&mut lambda.t);
1738 let mut dereferencer = Dereferencer::simple(self, &qnames, lambda);
1739 match dereferencer.deref_tyvar(t) {
1740 Ok(t) => lambda.t = t,
1741 Err(es) => errs.extend(es),
1742 }
1743 if !errs.is_empty() {
1744 Err(errs)
1745 } else {
1746 Ok(())
1747 }
1748 }
1749 hir::Expr::ClassDef(class_def) => {
1750 for def in class_def.all_methods_mut() {
1751 self.resolve_expr_t(def, qnames)?;
1752 }
1753 Ok(())
1754 }
1755 hir::Expr::PatchDef(patch_def) => {
1756 for def in patch_def.methods.iter_mut() {
1757 self.resolve_expr_t(def, qnames)?;
1758 }
1759 Ok(())
1760 }
1761 hir::Expr::ReDef(redef) => {
1762 for chunk in redef.block.iter_mut() {
1764 self.resolve_expr_t(chunk, qnames)?;
1765 }
1766 Ok(())
1767 }
1768 hir::Expr::TypeAsc(tasc) => self.resolve_expr_t(&mut tasc.expr, qnames),
1769 hir::Expr::Code(chunks) | hir::Expr::Compound(chunks) => {
1770 for chunk in chunks.iter_mut() {
1771 self.resolve_expr_t(chunk, qnames)?;
1772 }
1773 Ok(())
1774 }
1775 hir::Expr::Dummy(chunks) => {
1776 for chunk in chunks.iter_mut() {
1777 self.resolve_expr_t(chunk, qnames)?;
1778 }
1779 Ok(())
1780 }
1781 hir::Expr::Import(_) => unreachable_error!(TyCheckErrors, TyCheckError, self),
1782 }
1783 }
1784
1785 fn resolve_def_t(&self, def: &mut hir::Def, qnames: &Set<Str>) -> TyCheckResult<()> {
1786 let qnames = if let Type::Quantified(quant) = def.sig.ref_t() {
1787 let Ok(subr) = <&SubrType>::try_from(quant.as_ref()) else {
1789 unreachable!()
1790 };
1791 subr.essential_qnames()
1792 } else {
1793 qnames.clone()
1794 };
1795 let t = mem::take(def.sig.ref_mut_t().unwrap());
1796 let mut dereferencer = Dereferencer::simple(self, &qnames, &def.sig);
1797 *def.sig.ref_mut_t().unwrap() = dereferencer.deref_tyvar(t)?;
1798 if let Some(params) = def.sig.params_mut() {
1799 self.resolve_params_t(params, &qnames)?;
1800 }
1801 for chunk in def.body.block.iter_mut() {
1802 self.resolve_expr_t(chunk, &qnames)?;
1803 }
1804 Ok(())
1805 }
1806
1807 pub(crate) fn squash_tyvar(&self, typ: Type) -> Type {
1813 match typ {
1814 Or(tys) => {
1815 let new_tys = tys
1816 .into_iter()
1817 .map(|t| self.squash_tyvar(t))
1818 .collect::<Vec<_>>();
1819 let mut union = Never;
1820 if new_tys.iter().all(|t| t.is_unnamed_unbound_var()) {
1822 for ty in new_tys.iter() {
1823 if union == Never {
1824 union = ty.clone();
1825 continue;
1826 }
1827 match (self.subtype_of(&union, ty), self.subtype_of(&union, ty)) {
1828 (true, true) | (true, false) => {
1829 let _ = self.sub_unify(&union, ty, &(), None);
1830 }
1831 (false, true) => {
1832 let _ = self.sub_unify(ty, &union, &(), None);
1833 }
1834 _ => {}
1835 }
1836 }
1837 }
1838 new_tys
1839 .into_iter()
1840 .fold(Never, |acc, t| self.union(&acc, &t))
1841 }
1842 FreeVar(ref fv) if fv.constraint_is_sandwiched() => {
1843 let (sub_t, super_t) = fv.get_subsup().unwrap();
1844 let sub_t = self.squash_tyvar(sub_t);
1845 let super_t = self.squash_tyvar(super_t);
1846 typ.update_tyvar(sub_t, super_t, None, false);
1847 typ
1848 }
1849 other => other,
1850 }
1851 }
1852}