1use crate::infer::{
12 InferenceCandidate, InferenceContext, InferenceError, InferenceInfo, InferenceVar,
13 MAX_CONSTRAINT_ITERATIONS, MAX_TYPE_RECURSION_DEPTH,
14};
15use crate::instantiate::TypeSubstitution;
16use crate::types::{InferencePriority, TemplateSpan, TypeData, TypeId};
17use crate::widening;
18use rustc_hash::FxHashSet;
19use tsz_common::interner::Atom;
20
21struct VarianceState<'a> {
22 target_param: Atom,
23 covariant: &'a mut u32,
24 contravariant: &'a mut u32,
25}
26
27impl<'a> InferenceContext<'a> {
28 pub fn resolve_with_constraints(
40 &mut self,
41 var: InferenceVar,
42 ) -> Result<TypeId, InferenceError> {
43 if let Some(ty) = self.probe(var) {
45 return Ok(ty);
46 }
47
48 let (root, result, upper_bounds, upper_bounds_only) = self.compute_constraint_result(var);
49
50 if !upper_bounds_only {
52 let filtered_upper_bounds = Self::filter_relevant_upper_bounds(&upper_bounds);
53 if let Some(upper) =
54 self.first_failed_upper_bound(result, &filtered_upper_bounds, |a, b| {
55 self.is_subtype(a, b)
56 })
57 {
58 return Err(InferenceError::BoundsViolation {
59 var,
60 lower: result,
61 upper,
62 });
63 }
64 }
65
66 if self.occurs_in(root, result) {
67 return Err(InferenceError::OccursCheck {
68 var: root,
69 ty: result,
70 });
71 }
72
73 self.table.union_value(
75 root,
76 InferenceInfo {
77 resolved: Some(result),
78 ..InferenceInfo::default()
79 },
80 );
81
82 Ok(result)
83 }
84
85 pub fn resolve_with_constraints_by<F>(
88 &mut self,
89 var: InferenceVar,
90 is_subtype: F,
91 ) -> Result<TypeId, InferenceError>
92 where
93 F: FnMut(TypeId, TypeId) -> bool,
94 {
95 if let Some(ty) = self.probe(var) {
97 return Ok(ty);
98 }
99
100 let (root, result, upper_bounds, upper_bounds_only) = self.compute_constraint_result(var);
101
102 if !upper_bounds_only {
103 let filtered_upper_bounds = Self::filter_relevant_upper_bounds(&upper_bounds);
104 if let Some(upper) =
105 self.first_failed_upper_bound(result, &filtered_upper_bounds, is_subtype)
106 {
107 return Err(InferenceError::BoundsViolation {
108 var,
109 lower: result,
110 upper,
111 });
112 }
113 }
114
115 if self.occurs_in(root, result) {
116 return Err(InferenceError::OccursCheck {
117 var: root,
118 ty: result,
119 });
120 }
121
122 self.table.union_value(
123 root,
124 InferenceInfo {
125 resolved: Some(result),
126 ..InferenceInfo::default()
127 },
128 );
129
130 Ok(result)
131 }
132
133 fn filter_relevant_upper_bounds(upper_bounds: &[TypeId]) -> Vec<TypeId> {
134 upper_bounds
135 .iter()
136 .copied()
137 .filter(|&upper| !matches!(upper, TypeId::ANY | TypeId::UNKNOWN | TypeId::ERROR))
138 .collect()
139 }
140
141 fn first_failed_upper_bound<F>(
142 &self,
143 result: TypeId,
144 filtered_upper_bounds: &[TypeId],
145 mut is_subtype: F,
146 ) -> Option<TypeId>
147 where
148 F: FnMut(TypeId, TypeId) -> bool,
149 {
150 match filtered_upper_bounds {
151 [] => None,
152 [single] => (!is_subtype(result, *single)).then_some(*single),
153 many => {
154 if many.len() <= Self::UPPER_BOUND_INTERSECTION_FAST_PATH_LIMIT {
158 let intersection = self.interner.intersection(many.to_vec());
159 if is_subtype(result, intersection) {
160 return None;
161 }
162 }
163 if many.len() >= Self::UPPER_BOUND_INTERSECTION_LARGE_SET_THRESHOLD
167 && self.should_try_large_upper_bound_intersection(result, many)
168 {
169 let intersection = self.interner.intersection(many.to_vec());
170 if is_subtype(result, intersection) {
171 return None;
172 }
173 }
174 many.iter()
175 .copied()
176 .find(|&upper| !is_subtype(result, upper))
177 }
178 }
179 }
180
181 fn should_try_large_upper_bound_intersection(&self, result: TypeId, bounds: &[TypeId]) -> bool {
182 self.is_object_like_upper_bound(result)
183 && bounds
184 .iter()
185 .copied()
186 .all(|bound| self.is_object_like_upper_bound(bound))
187 }
188
189 fn is_object_like_upper_bound(&self, ty: TypeId) -> bool {
190 match self.interner.lookup(ty) {
191 Some(
192 TypeData::Object(_)
193 | TypeData::ObjectWithIndex(_)
194 | TypeData::Lazy(_)
195 | TypeData::Intersection(_),
196 ) => true,
197 Some(TypeData::TypeParameter(info)) => info
198 .constraint
199 .is_some_and(|constraint| self.is_object_like_upper_bound(constraint)),
200 _ => false,
201 }
202 }
203
204 fn compute_constraint_result(
205 &mut self,
206 var: InferenceVar,
207 ) -> (InferenceVar, TypeId, Vec<TypeId>, bool) {
208 let root = self.table.find(var);
209 let info = self.table.probe_value(root);
210 let target_names = self.type_param_names_for_root(root);
211 let mut upper_bounds = Vec::new();
212 let mut seen_upper_bounds = FxHashSet::default();
213 let mut candidates = info.candidates;
214 for bound in info.upper_bounds {
215 if self.occurs_in(root, bound) {
216 continue;
217 }
218 if !target_names.is_empty() && self.upper_bound_cycles_param(bound, &target_names) {
219 self.expand_cyclic_upper_bound(
220 root,
221 bound,
222 &target_names,
223 &mut candidates,
224 &mut upper_bounds,
225 );
226 continue;
227 }
228 if seen_upper_bounds.insert(bound) {
229 upper_bounds.push(bound);
230 }
231 }
232
233 if !upper_bounds.is_empty() {
234 candidates.retain(|candidate| {
235 !matches!(
236 candidate.type_id,
237 TypeId::ANY | TypeId::UNKNOWN | TypeId::ERROR
238 )
239 });
240 }
241
242 let is_const = self.is_var_const(root);
244
245 let upper_bounds_only = candidates.is_empty() && !upper_bounds.is_empty();
246
247 let result = if !candidates.is_empty() {
248 self.resolve_from_candidates(&candidates, is_const, &upper_bounds)
249 } else if !upper_bounds.is_empty() {
250 if upper_bounds.len() == 1 {
254 upper_bounds[0]
255 } else {
256 self.interner.intersection(upper_bounds.clone())
257 }
258 } else {
259 TypeId::UNKNOWN
261 };
262
263 (root, result, upper_bounds, upper_bounds_only)
264 }
265
266 pub fn resolve_all_with_constraints(&mut self) -> Result<Vec<(Atom, TypeId)>, InferenceError> {
268 self.strengthen_constraints()?;
273
274 let type_params: Vec<_> = self.type_params.clone();
275 let mut results = Vec::new();
276
277 for (name, var, _) in type_params {
278 let ty = self.resolve_with_constraints(var)?;
279 results.push((name, ty));
280 }
281
282 Ok(results)
283 }
284
285 fn resolve_from_candidates(
286 &self,
287 candidates: &[InferenceCandidate],
288 is_const: bool,
289 upper_bounds: &[TypeId],
290 ) -> TypeId {
291 let filtered = self.filter_candidates_by_priority(candidates);
292 if filtered.is_empty() {
293 return TypeId::UNKNOWN;
294 }
295 let filtered_no_never: Vec<_> = filtered
296 .iter()
297 .filter(|c| c.type_id != TypeId::NEVER)
298 .cloned()
299 .collect();
300 if filtered_no_never.is_empty() {
301 return TypeId::NEVER;
302 }
303 let preserve_literals = is_const || self.constraint_implies_literals(upper_bounds);
306 let widened = if preserve_literals {
307 if is_const {
308 filtered_no_never
309 .iter()
310 .map(|c| widening::apply_const_assertion(self.interner, c.type_id))
311 .collect()
312 } else {
313 filtered_no_never.iter().map(|c| c.type_id).collect()
314 }
315 } else {
316 self.widen_candidate_types(&filtered_no_never)
317 };
318 self.best_common_type(&widened)
319 }
320
321 fn constraint_implies_literals(&self, upper_bounds: &[TypeId]) -> bool {
323 upper_bounds
324 .iter()
325 .any(|&bound| self.type_implies_literals(bound))
326 }
327
328 fn type_implies_literals(&self, type_id: TypeId) -> bool {
330 match self.interner.lookup(type_id) {
331 Some(TypeData::Literal(_)) => true,
332 Some(TypeData::Union(list_id)) => {
333 let members = self.interner.type_list(list_id);
334 members.iter().any(|&m| self.type_implies_literals(m))
335 }
336 Some(TypeData::Intersection(list_id)) => {
337 let members = self.interner.type_list(list_id);
338 members.iter().any(|&m| self.type_implies_literals(m))
339 }
340 _ => false,
341 }
342 }
343
344 fn filter_candidates_by_priority(
352 &self,
353 candidates: &[InferenceCandidate],
354 ) -> Vec<InferenceCandidate> {
355 let Some(best_priority) = candidates.iter().map(|c| c.priority).min() else {
356 return Vec::new();
357 };
358 candidates
359 .iter()
360 .filter(|candidate| candidate.priority == best_priority)
361 .cloned()
362 .collect()
363 }
364
365 fn widen_candidate_types(&self, candidates: &[InferenceCandidate]) -> Vec<TypeId> {
366 candidates
367 .iter()
368 .map(|candidate| {
369 if candidate.is_fresh_literal {
375 self.get_base_type(candidate.type_id)
376 .unwrap_or(candidate.type_id)
377 } else {
378 candidate.type_id
379 }
380 })
381 .collect()
382 }
383
384 pub fn infer_from_conditional(
392 &mut self,
393 var: InferenceVar,
394 check_type: TypeId,
395 extends_type: TypeId,
396 true_type: TypeId,
397 false_type: TypeId,
398 ) {
399 if let Some(TypeData::TypeParameter(info)) = self.interner.lookup(check_type)
401 && let Some(check_var) = self.find_type_param(info.name)
402 && check_var == self.table.find(var)
403 {
404 self.add_upper_bound(var, extends_type);
407 }
408
409 self.infer_from_type(var, true_type);
411 self.infer_from_type(var, false_type);
412 }
413
414 fn infer_from_type(&mut self, var: InferenceVar, ty: TypeId) {
416 let root = self.table.find(var);
417
418 if !self.contains_inference_var(ty, root) {
420 return;
421 }
422
423 match self.interner.lookup(ty) {
424 Some(TypeData::TypeParameter(info)) => {
425 if let Some(param_var) = self.find_type_param(info.name)
426 && self.table.find(param_var) == root
427 {
428 if let Some(constraint) = info.constraint {
431 self.add_upper_bound(var, constraint);
432 }
433 }
434 }
435 Some(TypeData::Array(elem)) => {
436 self.infer_from_type(var, elem);
437 }
438 Some(TypeData::Tuple(elements)) => {
439 let elements = self.interner.tuple_list(elements);
440 for elem in elements.iter() {
441 self.infer_from_type(var, elem.type_id);
442 }
443 }
444 Some(TypeData::Union(members) | TypeData::Intersection(members)) => {
445 let members = self.interner.type_list(members);
446 for &member in members.iter() {
447 self.infer_from_type(var, member);
448 }
449 }
450 Some(TypeData::Object(shape_id)) => {
451 let shape = self.interner.object_shape(shape_id);
452 for prop in &shape.properties {
453 self.infer_from_type(var, prop.type_id);
454 }
455 }
456 Some(TypeData::ObjectWithIndex(shape_id)) => {
457 let shape = self.interner.object_shape(shape_id);
458 for prop in &shape.properties {
459 self.infer_from_type(var, prop.type_id);
460 }
461 if let Some(index) = shape.string_index.as_ref() {
462 self.infer_from_type(var, index.key_type);
463 self.infer_from_type(var, index.value_type);
464 }
465 if let Some(index) = shape.number_index.as_ref() {
466 self.infer_from_type(var, index.key_type);
467 self.infer_from_type(var, index.value_type);
468 }
469 }
470 Some(TypeData::Application(app_id)) => {
471 let app = self.interner.type_application(app_id);
472 self.infer_from_type(var, app.base);
473 for &arg in &app.args {
474 self.infer_from_type(var, arg);
475 }
476 }
477 Some(TypeData::Function(shape_id)) => {
478 let shape = self.interner.function_shape(shape_id);
479 for param in &shape.params {
480 self.infer_from_type(var, param.type_id);
481 }
482 if let Some(this_type) = shape.this_type {
483 self.infer_from_type(var, this_type);
484 }
485 self.infer_from_type(var, shape.return_type);
486 }
487 Some(TypeData::Conditional(cond_id)) => {
488 let cond = self.interner.conditional_type(cond_id);
489 self.infer_from_conditional(
490 var,
491 cond.check_type,
492 cond.extends_type,
493 cond.true_type,
494 cond.false_type,
495 );
496 }
497 Some(TypeData::TemplateLiteral(spans)) => {
498 let spans = self.interner.template_list(spans);
500 for span in spans.iter() {
501 if let TemplateSpan::Type(inner) = span {
502 self.infer_from_type(var, *inner);
503 }
504 }
505 }
506 _ => {}
507 }
508 }
509
510 pub(crate) fn contains_inference_var(&mut self, ty: TypeId, var: InferenceVar) -> bool {
512 let mut visited = FxHashSet::default();
513 self.contains_inference_var_inner(ty, var, &mut visited, 0)
514 }
515
516 fn contains_inference_var_inner(
517 &mut self,
518 ty: TypeId,
519 var: InferenceVar,
520 visited: &mut FxHashSet<TypeId>,
521 depth: usize,
522 ) -> bool {
523 if depth > MAX_TYPE_RECURSION_DEPTH {
525 return false;
526 }
527 if !visited.insert(ty) {
529 return false;
530 }
531
532 let root = self.table.find(var);
533
534 match self.interner.lookup(ty) {
535 Some(TypeData::TypeParameter(info) | TypeData::Infer(info)) => {
536 if let Some(param_var) = self.find_type_param(info.name) {
537 self.table.find(param_var) == root
538 } else {
539 false
540 }
541 }
542 Some(TypeData::Array(elem)) => {
543 self.contains_inference_var_inner(elem, var, visited, depth + 1)
544 }
545 Some(TypeData::Tuple(elements)) => {
546 let elements = self.interner.tuple_list(elements);
547 elements
548 .iter()
549 .any(|e| self.contains_inference_var_inner(e.type_id, var, visited, depth + 1))
550 }
551 Some(TypeData::Union(members) | TypeData::Intersection(members)) => {
552 let members = self.interner.type_list(members);
553 members
554 .iter()
555 .any(|&m| self.contains_inference_var_inner(m, var, visited, depth + 1))
556 }
557 Some(TypeData::Object(shape_id)) => {
558 let shape = self.interner.object_shape(shape_id);
559 shape
560 .properties
561 .iter()
562 .any(|p| self.contains_inference_var_inner(p.type_id, var, visited, depth + 1))
563 }
564 Some(TypeData::ObjectWithIndex(shape_id)) => {
565 let shape = self.interner.object_shape(shape_id);
566 shape
567 .properties
568 .iter()
569 .any(|p| self.contains_inference_var_inner(p.type_id, var, visited, depth + 1))
570 || shape.string_index.as_ref().is_some_and(|idx| {
571 self.contains_inference_var_inner(idx.key_type, var, visited, depth + 1)
572 || self.contains_inference_var_inner(
573 idx.value_type,
574 var,
575 visited,
576 depth + 1,
577 )
578 })
579 || shape.number_index.as_ref().is_some_and(|idx| {
580 self.contains_inference_var_inner(idx.key_type, var, visited, depth + 1)
581 || self.contains_inference_var_inner(
582 idx.value_type,
583 var,
584 visited,
585 depth + 1,
586 )
587 })
588 }
589 Some(TypeData::Application(app_id)) => {
590 let app = self.interner.type_application(app_id);
591 self.contains_inference_var_inner(app.base, var, visited, depth + 1)
592 || app
593 .args
594 .iter()
595 .any(|&arg| self.contains_inference_var_inner(arg, var, visited, depth + 1))
596 }
597 Some(TypeData::Function(shape_id)) => {
598 let shape = self.interner.function_shape(shape_id);
599 shape
600 .params
601 .iter()
602 .any(|p| self.contains_inference_var_inner(p.type_id, var, visited, depth + 1))
603 || shape.this_type.is_some_and(|t| {
604 self.contains_inference_var_inner(t, var, visited, depth + 1)
605 })
606 || self.contains_inference_var_inner(shape.return_type, var, visited, depth + 1)
607 }
608 Some(TypeData::Conditional(cond_id)) => {
609 let cond = self.interner.conditional_type(cond_id);
610 self.contains_inference_var_inner(cond.check_type, var, visited, depth + 1)
611 || self.contains_inference_var_inner(cond.extends_type, var, visited, depth + 1)
612 || self.contains_inference_var_inner(cond.true_type, var, visited, depth + 1)
613 || self.contains_inference_var_inner(cond.false_type, var, visited, depth + 1)
614 }
615 Some(TypeData::TemplateLiteral(spans)) => {
616 let spans = self.interner.template_list(spans);
617 spans.iter().any(|span| match span {
618 TemplateSpan::Text(_) => false,
619 TemplateSpan::Type(inner) => {
620 self.contains_inference_var_inner(*inner, var, visited, depth + 1)
621 }
622 })
623 }
624 _ => false,
625 }
626 }
627
628 pub fn compute_variance(&self, ty: TypeId, target_param: Atom) -> (u32, u32, u32, u32) {
635 let mut covariant = 0u32;
636 let mut contravariant = 0u32;
637 let invariant = 0u32;
638 let bivariant = 0u32;
639 let mut state = VarianceState {
640 target_param,
641 covariant: &mut covariant,
642 contravariant: &mut contravariant,
643 };
644
645 self.compute_variance_helper(ty, true, &mut state);
646
647 (covariant, contravariant, invariant, bivariant)
648 }
649
650 fn compute_variance_helper(
651 &self,
652 ty: TypeId,
653 polarity: bool, state: &mut VarianceState<'_>,
655 ) {
656 match self.interner.lookup(ty) {
657 Some(TypeData::TypeParameter(info)) if info.name == state.target_param => {
658 if polarity {
659 *state.covariant += 1;
660 } else {
661 *state.contravariant += 1;
662 }
663 }
664 Some(TypeData::Array(elem)) => {
665 self.compute_variance_helper(elem, polarity, state);
666 }
667 Some(TypeData::Tuple(elements)) => {
668 let elements = self.interner.tuple_list(elements);
669 for elem in elements.iter() {
670 self.compute_variance_helper(elem.type_id, polarity, state);
671 }
672 }
673 Some(TypeData::Union(members) | TypeData::Intersection(members)) => {
674 let members = self.interner.type_list(members);
675 for &member in members.iter() {
676 self.compute_variance_helper(member, polarity, state);
677 }
678 }
679 Some(TypeData::Object(shape_id)) => {
680 let shape = self.interner.object_shape(shape_id);
681 for prop in &shape.properties {
682 self.compute_variance_helper(prop.type_id, polarity, state);
684 if prop.write_type != prop.type_id && !prop.readonly {
686 self.compute_variance_helper(prop.write_type, !polarity, state);
687 }
688 }
689 }
690 Some(TypeData::ObjectWithIndex(shape_id)) => {
691 let shape = self.interner.object_shape(shape_id);
692 for prop in &shape.properties {
693 self.compute_variance_helper(prop.type_id, polarity, state);
694 if prop.write_type != prop.type_id && !prop.readonly {
695 self.compute_variance_helper(prop.write_type, !polarity, state);
696 }
697 }
698 if let Some(index) = shape.string_index.as_ref() {
699 self.compute_variance_helper(index.value_type, polarity, state);
700 }
701 if let Some(index) = shape.number_index.as_ref() {
702 self.compute_variance_helper(index.value_type, polarity, state);
703 }
704 }
705 Some(TypeData::Application(app_id)) => {
706 let app = self.interner.type_application(app_id);
707 for &arg in &app.args {
710 self.compute_variance_helper(arg, polarity, state);
711 }
712 }
713 Some(TypeData::Function(shape_id)) => {
714 let shape = self.interner.function_shape(shape_id);
715 for param in &shape.params {
717 self.compute_variance_helper(param.type_id, !polarity, state);
718 }
719 self.compute_variance_helper(shape.return_type, polarity, state);
721 }
722 Some(TypeData::Conditional(cond_id)) => {
723 let cond = self.interner.conditional_type(cond_id);
724 self.compute_variance_helper(cond.check_type, false, state);
726 self.compute_variance_helper(cond.extends_type, false, state);
727 self.compute_variance_helper(cond.true_type, polarity, state);
729 self.compute_variance_helper(cond.false_type, polarity, state);
730 }
731 _ => {}
732 }
733 }
734
735 pub fn is_invariant_position(&self, ty: TypeId, target_param: Atom) -> bool {
737 let (_, _, invariant, _) = self.compute_variance(ty, target_param);
738 invariant > 0
739 }
740
741 pub fn is_bivariant_position(&self, ty: TypeId, target_param: Atom) -> bool {
743 let (_, _, _, bivariant) = self.compute_variance(ty, target_param);
744 bivariant > 0
745 }
746
747 pub fn get_variance(&self, ty: TypeId, target_param: Atom) -> &'static str {
749 let (covariant, contravariant, invariant, bivariant) =
750 self.compute_variance(ty, target_param);
751
752 if invariant > 0 {
753 "invariant"
754 } else if bivariant > 0 {
755 "bivariant"
756 } else if covariant > 0 && contravariant > 0 {
757 "invariant" } else if covariant > 0 {
759 "covariant"
760 } else if contravariant > 0 {
761 "contravariant"
762 } else {
763 "unused"
764 }
765 }
766
767 pub fn infer_from_context(
775 &mut self,
776 var: InferenceVar,
777 context_type: TypeId,
778 ) -> Result<(), InferenceError> {
779 self.add_upper_bound(var, context_type);
781
782 let root = self.table.find(var);
785 if self.contains_inference_var(context_type, root) {
786 return Err(InferenceError::OccursCheck {
789 var: root,
790 ty: context_type,
791 });
792 }
793
794 Ok(())
795 }
796
797 fn unify_circular_constraints(&mut self) -> Result<(), InferenceError> {
801 use rustc_hash::{FxHashMap, FxHashSet};
802
803 let type_params: Vec<_> = self.type_params.clone();
804
805 let mut graph: FxHashMap<InferenceVar, FxHashSet<InferenceVar>> = FxHashMap::default();
807 let mut var_for_param: FxHashMap<Atom, InferenceVar> = FxHashMap::default();
808
809 for (name, var, _) in &type_params {
810 let root = self.table.find(*var);
811 var_for_param.insert(*name, root);
812 graph.entry(root).or_default();
813 }
814
815 for (_name, var, _) in &type_params {
817 let root = self.table.find(*var);
818 let info = self.table.probe_value(root);
819
820 for &upper in &info.upper_bounds {
821 if let Some(TypeData::TypeParameter(param_info)) = self.interner.lookup(upper)
823 && let Some(&upper_var) = var_for_param.get(¶m_info.name)
824 {
825 let upper_root = self.table.find(upper_var);
826 graph.entry(root).or_default().insert(upper_root);
828 }
829 }
830 }
831
832 let mut index_counter = 0;
834 let mut indices: FxHashMap<InferenceVar, usize> = FxHashMap::default();
835 let mut lowlink: FxHashMap<InferenceVar, usize> = FxHashMap::default();
836 let mut stack: Vec<InferenceVar> = Vec::new();
837 let mut on_stack: FxHashSet<InferenceVar> = FxHashSet::default();
838 let mut sccs: Vec<Vec<InferenceVar>> = Vec::new();
839
840 struct TarjanState<'a> {
841 graph: &'a FxHashMap<InferenceVar, FxHashSet<InferenceVar>>,
842 index_counter: &'a mut usize,
843 indices: &'a mut FxHashMap<InferenceVar, usize>,
844 lowlink: &'a mut FxHashMap<InferenceVar, usize>,
845 stack: &'a mut Vec<InferenceVar>,
846 on_stack: &'a mut FxHashSet<InferenceVar>,
847 sccs: &'a mut Vec<Vec<InferenceVar>>,
848 }
849
850 fn strongconnect(var: InferenceVar, state: &mut TarjanState) {
851 state.indices.insert(var, *state.index_counter);
852 state.lowlink.insert(var, *state.index_counter);
853 *state.index_counter += 1;
854 state.stack.push(var);
855 state.on_stack.insert(var);
856
857 if let Some(neighbors) = state.graph.get(&var) {
858 for &neighbor in neighbors {
859 if !state.indices.contains_key(&neighbor) {
860 strongconnect(neighbor, state);
861 let neighbor_low = *state.lowlink.get(&neighbor).unwrap_or(&0);
862 let var_low = state.lowlink.get_mut(&var).unwrap();
863 *var_low = (*var_low).min(neighbor_low);
864 } else if state.on_stack.contains(&neighbor) {
865 let neighbor_idx = *state.indices.get(&neighbor).unwrap_or(&0);
866 let var_low = state.lowlink.get_mut(&var).unwrap();
867 *var_low = (*var_low).min(neighbor_idx);
868 }
869 }
870 }
871
872 if *state.lowlink.get(&var).unwrap_or(&0) == *state.indices.get(&var).unwrap_or(&0) {
873 let mut scc = Vec::new();
874 loop {
875 let w = state.stack.pop().unwrap();
876 state.on_stack.remove(&w);
877 scc.push(w);
878 if w == var {
879 break;
880 }
881 }
882 state.sccs.push(scc);
883 }
884 }
885
886 for &var in graph.keys() {
888 if !indices.contains_key(&var) {
889 let mut state = TarjanState {
890 graph: &graph,
891 index_counter: &mut index_counter,
892 indices: &mut indices,
893 lowlink: &mut lowlink,
894 stack: &mut stack,
895 on_stack: &mut on_stack,
896 sccs: &mut sccs,
897 };
898 strongconnect(var, &mut state);
899 }
900 }
901
902 for scc in sccs {
904 if scc.len() > 1 {
905 let first = scc[0];
907 for &other in &scc[1..] {
908 self.unify_vars(first, other)?;
909 }
910 }
911 }
912
913 Ok(())
914 }
915
916 pub fn strengthen_constraints(&mut self) -> Result<(), InferenceError> {
919 self.unify_circular_constraints()?;
923
924 let type_params: Vec<_> = self.type_params.clone();
925 let mut changed = true;
926 let mut iterations = 0;
927
928 while changed && iterations < MAX_CONSTRAINT_ITERATIONS {
931 changed = false;
932 iterations += 1;
933
934 for (name, var, _) in &type_params {
935 let root = self.table.find(*var);
936
937 let info = self.table.probe_value(root).clone();
940
941 for &upper in &info.upper_bounds {
944 if self.propagate_candidates_to_upper(root, upper, *name)? {
945 changed = true;
946 }
947 }
948 }
949 }
950 Ok(())
951 }
952
953 fn propagate_candidates_to_upper(
956 &mut self,
957 var_root: InferenceVar,
958 upper: TypeId,
959 exclude_param: Atom,
960 ) -> Result<bool, InferenceError> {
961 if let Some(TypeData::TypeParameter(info)) = self.interner.lookup(upper)
963 && info.name != exclude_param
964 && let Some(upper_var) = self.find_type_param(info.name)
965 {
966 let upper_root = self.table.find(upper_var);
967
968 if var_root == upper_root {
970 return Ok(false);
971 }
972
973 let var_candidates = self.table.probe_value(var_root).candidates;
975
976 let mut changed = false;
978 for candidate in var_candidates {
979 if self.add_candidate_if_new(
981 upper_root,
982 candidate.type_id,
983 InferencePriority::Circular,
984 ) {
985 changed = true;
986 }
987 }
988 return Ok(changed);
989 }
990 Ok(false)
991 }
992
993 fn add_candidate_if_new(
995 &mut self,
996 var: InferenceVar,
997 ty: TypeId,
998 priority: InferencePriority,
999 ) -> bool {
1000 let root = self.table.find(var);
1001 let info = self.table.probe_value(root);
1002
1003 if info.candidates.iter().any(|c| c.type_id == ty) {
1005 return false;
1006 }
1007
1008 self.add_candidate(var, ty, priority);
1009 true
1010 }
1011
1012 pub fn validate_variance(&mut self) -> Result<(), InferenceError> {
1014 let type_params: Vec<_> = self.type_params.clone();
1015 for (_name, var, _) in &type_params {
1016 let resolved = match self.probe(*var) {
1017 Some(ty) => ty,
1018 None => continue,
1019 };
1020
1021 if self.occurs_in(*var, resolved) {
1024 let root = self.table.find(*var);
1025 return Err(InferenceError::OccursCheck {
1027 var: root,
1028 ty: resolved,
1029 });
1030 }
1031
1032 }
1036
1037 Ok(())
1038 }
1039
1040 pub fn fix_current_variables(&mut self) -> Result<(), InferenceError> {
1053 let type_params: Vec<_> = self.type_params.clone();
1054
1055 for (_name, var, _is_const) in &type_params {
1056 let root = self.table.find(*var);
1057 let info = self.table.probe_value(root);
1058
1059 if info.resolved.is_some() {
1061 continue;
1062 }
1063
1064 if info.candidates.is_empty() {
1066 continue;
1067 }
1068
1069 let is_const = self.is_var_const(root);
1073 let result =
1074 self.resolve_from_candidates(&info.candidates, is_const, &info.upper_bounds);
1075
1076 if self.occurs_in(root, result) {
1078 continue;
1080 }
1081
1082 self.table.union_value(
1085 root,
1086 InferenceInfo {
1087 resolved: Some(result),
1088 candidates: info.candidates,
1090 upper_bounds: info.upper_bounds,
1091 },
1092 );
1093 }
1094
1095 Ok(())
1096 }
1097
1098 pub fn get_current_substitution(&mut self) -> TypeSubstitution {
1104 let mut subst = TypeSubstitution::new();
1105 let type_params: Vec<_> = self.type_params.clone();
1106
1107 for (name, var, _) in &type_params {
1108 let ty = match self.probe(*var) {
1109 Some(resolved) => {
1110 tracing::trace!(
1111 ?name,
1112 ?var,
1113 ?resolved,
1114 "get_current_substitution: already resolved"
1115 );
1116 resolved
1117 }
1118 None => {
1119 let root = self.table.find(*var);
1121 let info = self.table.probe_value(root);
1122 tracing::trace!(
1123 ?name, ?var,
1124 candidates_count = info.candidates.len(),
1125 upper_bounds_count = info.upper_bounds.len(),
1126 upper_bounds = ?info.upper_bounds,
1127 "get_current_substitution: not resolved"
1128 );
1129
1130 if !info.candidates.is_empty() {
1131 let is_const = self.is_var_const(root);
1132 self.resolve_from_candidates(&info.candidates, is_const, &info.upper_bounds)
1133 } else if !info.upper_bounds.is_empty() {
1134 if info.upper_bounds.len() == 1 {
1140 info.upper_bounds[0]
1141 } else {
1142 self.interner.intersection(info.upper_bounds.to_vec())
1143 }
1144 } else {
1145 TypeId::UNKNOWN
1147 }
1148 }
1149 };
1150
1151 subst.insert(*name, ty);
1152 }
1153
1154 subst
1155 }
1156}