1use crate::{
2 expr::ModPath,
3 typ::{FnType, PrintFlag, Type, TypeRef, PRINT_FLAGS},
4};
5use ahash::{AHashMap, AHashSet};
6use anyhow::{bail, Result};
7use arcstr::ArcStr;
8use compact_str::format_compact;
9use parking_lot::{RwLock, RwLockReadGuard, RwLockWriteGuard};
10use std::{
11 cmp::{Eq, PartialEq},
12 collections::hash_map::Entry,
13 fmt::{self, Debug},
14 hash::Hash,
15 ops::Deref,
16};
17use triomphe::Arc;
18
// Declares the `TVarId` identifier type through the crate's `atomic_id!`
// macro. In this file it is used via `TVarId::new()` (called once for every
// newly constructed type variable) and through its `.0` field (to synthesize
// the name of a default variable).
atomic_id!(TVarId);
20
21pub(super) fn would_cycle_inner(addr: usize, t: &Type) -> bool {
22 match t {
23 Type::Primitive(_) | Type::Any | Type::Bottom | Type::Ref(TypeRef { .. }) => {
24 false
25 }
26 Type::TVar(t) => {
27 Arc::as_ptr(&t.read().typ).addr() == addr
28 || match &*t.read().typ.read() {
29 None => false,
30 Some(t) => would_cycle_inner(addr, t),
31 }
32 }
33 Type::Abstract { id: _, params } => {
34 params.iter().any(|t| would_cycle_inner(addr, t))
35 }
36 Type::Error(t) => would_cycle_inner(addr, t),
37 Type::Array(a) => would_cycle_inner(addr, &**a),
38 Type::Map { key, value } => {
39 would_cycle_inner(addr, &**key) || would_cycle_inner(addr, &**value)
40 }
41 Type::ByRef(t) => would_cycle_inner(addr, t),
42 Type::Tuple(ts) => ts.iter().any(|t| would_cycle_inner(addr, t)),
43 Type::Variant(_, ts) => ts.iter().any(|t| would_cycle_inner(addr, t)),
44 Type::Struct(ts) => ts.iter().any(|(_, t)| would_cycle_inner(addr, t)),
45 Type::Set(s) => s.iter().any(|t| would_cycle_inner(addr, t)),
46 Type::Fn(f) => {
47 let FnType {
48 args,
49 vargs,
50 rtype,
51 constraints,
52 throws,
53 explicit_throws: _,
54 lambda_ids: _,
55 } = &**f;
56 args.iter().any(|t| would_cycle_inner(addr, &t.typ))
57 || match vargs {
58 None => false,
59 Some(t) => would_cycle_inner(addr, t),
60 }
61 || would_cycle_inner(addr, rtype)
62 || constraints.read().iter().any(|a| {
63 Arc::as_ptr(&a.0.read().typ).addr() == addr
64 || would_cycle_inner(addr, &a.1)
65 })
66 || would_cycle_inner(addr, &throws)
67 }
68 }
69}
70
/// Mutable state of a type variable, stored behind the lock in [`TVarInner`].
#[derive(Debug)]
pub struct TVarInnerInner {
    /// Identifier of the variable; `TVar::alias` overwrites it with the id
    /// of the variable being aliased.
    pub(crate) id: TVarId,
    /// When `true`, `TVar::alias` leaves this variable unchanged.
    pub(crate) frozen: bool,
    /// Shared binding cell: `None` while the variable is unbound,
    /// `Some(type)` once it is bound. Aliased variables hold clones of the
    /// same `Arc`, so they observe each other's bindings.
    pub(crate) typ: Arc<RwLock<Option<Type>>>,
}
77
/// Heap-allocated payload of a [`TVar`]: the variable's name plus its
/// lock-protected mutable state.
#[derive(Debug)]
pub struct TVarInner {
    /// Name of the variable; `Display` prints it with a leading `'`.
    pub name: ArcStr,
    /// Mutable state (id, frozen flag and binding cell), accessed through
    /// `TVar::read` and `TVar::write`.
    pub(crate) typ: RwLock<TVarInnerInner>,
}
83
/// A type variable: a reference-counted handle to a [`TVarInner`].
///
/// Cloning a `TVar` clones the `Arc`, so every clone refers to the same
/// underlying variable.
#[derive(Debug, Clone)]
pub struct TVar(Arc<TVarInner>);
86
87impl fmt::Display for TVar {
88 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
89 if !PRINT_FLAGS.get().contains(PrintFlag::DerefTVars) {
90 write!(f, "'{}", self.name)
91 } else {
92 write!(f, "'{}: ", self.name)?;
93 match &*self.read().typ.read() {
94 Some(t) => write!(f, "{t}"),
95 None => write!(f, "unbound"),
96 }
97 }
98 }
99}
100
101impl Default for TVar {
102 fn default() -> Self {
103 Self::empty_named(ArcStr::from(format_compact!("_{}", TVarId::new().0).as_str()))
104 }
105}
106
107impl Deref for TVar {
108 type Target = TVarInner;
109
110 fn deref(&self) -> &Self::Target {
111 &*self.0
112 }
113}
114
115impl PartialEq for TVar {
116 fn eq(&self, other: &Self) -> bool {
117 let t0 = self.read();
118 let t1 = other.read();
119 Arc::ptr_eq(&t0.typ, &t1.typ) || {
120 let t0 = t0.typ.read();
121 let t1 = t1.typ.read();
122 *t0 == *t1
123 }
124 }
125}
126
// Marker impl: the `PartialEq` implementation above (shared cell or equal
// bindings) is declared to be a full equivalence relation.
impl Eq for TVar {}
128
129impl PartialOrd for TVar {
130 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
131 let t0 = self.read();
132 let t1 = other.read();
133 if Arc::ptr_eq(&t0.typ, &t1.typ) {
134 Some(std::cmp::Ordering::Equal)
135 } else {
136 let t0 = t0.typ.read();
137 let t1 = t1.typ.read();
138 t0.partial_cmp(&*t1)
139 }
140 }
141}
142
143impl Ord for TVar {
144 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
145 let t0 = self.read();
146 let t1 = other.read();
147 if Arc::ptr_eq(&t0.typ, &t1.typ) {
148 std::cmp::Ordering::Equal
149 } else {
150 let t0 = t0.typ.read();
151 let t1 = t1.typ.read();
152 t0.cmp(&*t1)
153 }
154 }
155}
156
157impl TVar {
158 pub fn scope_refs(&self, scope: &ModPath) -> Self {
159 match Type::TVar(self.clone()).scope_refs(scope) {
160 Type::TVar(tv) => tv,
161 _ => unreachable!(),
162 }
163 }
164
165 pub fn empty_named(name: ArcStr) -> Self {
166 Self(Arc::new(TVarInner {
167 name,
168 typ: RwLock::new(TVarInnerInner {
169 id: TVarId::new(),
170 frozen: false,
171 typ: Arc::new(RwLock::new(None)),
172 }),
173 }))
174 }
175
176 pub fn named(name: ArcStr, typ: Type) -> Self {
177 Self(Arc::new(TVarInner {
178 name,
179 typ: RwLock::new(TVarInnerInner {
180 id: TVarId::new(),
181 frozen: false,
182 typ: Arc::new(RwLock::new(Some(typ))),
183 }),
184 }))
185 }
186
187 pub fn read<'a>(&'a self) -> RwLockReadGuard<'a, TVarInnerInner> {
188 self.typ.read()
189 }
190
191 pub fn write<'a>(&'a self) -> RwLockWriteGuard<'a, TVarInnerInner> {
192 self.typ.write()
193 }
194
195 pub fn alias(&self, other: &Self) {
197 let mut s = self.write();
198 if !s.frozen {
199 s.frozen = true;
200 let o = other.read();
201 s.id = o.id;
202 s.typ = Arc::clone(&o.typ);
203 }
204 }
205
206 pub fn freeze(&self) {
207 self.write().frozen = true;
208 }
209
210 pub fn copy(&self, other: &Self) {
212 let s = self.read();
213 let o = other.read();
214 *s.typ.write() = o.typ.read().clone();
215 }
216
217 pub fn normalize(&self) -> Self {
218 match &mut *self.read().typ.write() {
219 None => (),
220 Some(t) => {
221 *t = t.normalize();
222 }
223 }
224 self.clone()
225 }
226
227 pub fn unbind(&self) {
228 *self.read().typ.write() = None
229 }
230
231 pub(super) fn would_cycle(&self, t: &Type) -> bool {
232 let addr = Arc::as_ptr(&self.read().typ).addr();
233 would_cycle_inner(addr, t)
234 }
235
236 pub(super) fn addr(&self) -> usize {
237 Arc::as_ptr(&self.0).addr()
238 }
239
240 pub(super) fn inner_addr(&self) -> usize {
241 Arc::as_ptr(&self.read().typ).addr()
242 }
243}
244
245impl Type {
246 pub fn unfreeze_tvars(&self) {
247 match self {
248 Type::Bottom | Type::Any | Type::Primitive(_) => (),
249 Type::Ref(TypeRef { params, .. }) => {
250 for t in params.iter() {
251 t.unfreeze_tvars();
252 }
253 }
254 Type::Error(t) => t.unfreeze_tvars(),
255 Type::Array(t) => t.unfreeze_tvars(),
256 Type::Map { key, value } => {
257 key.unfreeze_tvars();
258 value.unfreeze_tvars();
259 }
260 Type::ByRef(t) => t.unfreeze_tvars(),
261 Type::Tuple(ts) => {
262 for t in ts.iter() {
263 t.unfreeze_tvars()
264 }
265 }
266 Type::Struct(ts) => {
267 for (_, t) in ts.iter() {
268 t.unfreeze_tvars()
269 }
270 }
271 Type::Variant(_, ts) => {
272 for t in ts.iter() {
273 t.unfreeze_tvars()
274 }
275 }
276 Type::TVar(tv) => tv.write().frozen = false,
277 Type::Fn(ft) => ft.unfreeze_tvars(),
278 Type::Set(s) => {
279 for typ in s.iter() {
280 typ.unfreeze_tvars()
281 }
282 }
283 Type::Abstract { id: _, params } => {
284 for typ in params.iter() {
285 typ.unfreeze_tvars()
286 }
287 }
288 }
289 }
290
291 pub fn alias_tvars(&self, known: &mut AHashMap<ArcStr, TVar>) {
293 match self {
294 Type::Bottom | Type::Any | Type::Primitive(_) => (),
295 Type::Ref(TypeRef { params, .. }) => {
296 for t in params.iter() {
297 t.alias_tvars(known);
298 }
299 }
300 Type::Error(t) => t.alias_tvars(known),
301 Type::Array(t) => t.alias_tvars(known),
302 Type::Map { key, value } => {
303 key.alias_tvars(known);
304 value.alias_tvars(known);
305 }
306 Type::ByRef(t) => t.alias_tvars(known),
307 Type::Tuple(ts) => {
308 for t in ts.iter() {
309 t.alias_tvars(known)
310 }
311 }
312 Type::Struct(ts) => {
313 for (_, t) in ts.iter() {
314 t.alias_tvars(known)
315 }
316 }
317 Type::Variant(_, ts) => {
318 for t in ts.iter() {
319 t.alias_tvars(known)
320 }
321 }
322 Type::TVar(tv) => match known.entry(tv.name.clone()) {
323 Entry::Occupied(e) => {
324 let v = e.get();
325 v.freeze();
326 tv.alias(v);
327 }
328 Entry::Vacant(e) => {
329 e.insert(tv.clone());
330 ()
331 }
332 },
333 Type::Fn(ft) => ft.alias_tvars(known),
334 Type::Set(s) => {
335 for typ in s.iter() {
336 typ.alias_tvars(known)
337 }
338 }
339 Type::Abstract { id: _, params } => {
340 for typ in params.iter() {
341 typ.alias_tvars(known)
342 }
343 }
344 }
345 }
346
347 pub fn collect_tvars(&self, known: &mut AHashMap<ArcStr, TVar>) {
348 match self {
349 Type::Bottom | Type::Any | Type::Primitive(_) => (),
350 Type::Ref(TypeRef { params, .. }) => {
351 for t in params.iter() {
352 t.collect_tvars(known);
353 }
354 }
355 Type::Error(t) => t.collect_tvars(known),
356 Type::Array(t) => t.collect_tvars(known),
357 Type::Map { key, value } => {
358 key.collect_tvars(known);
359 value.collect_tvars(known);
360 }
361 Type::ByRef(t) => t.collect_tvars(known),
362 Type::Tuple(ts) => {
363 for t in ts.iter() {
364 t.collect_tvars(known)
365 }
366 }
367 Type::Struct(ts) => {
368 for (_, t) in ts.iter() {
369 t.collect_tvars(known)
370 }
371 }
372 Type::Variant(_, ts) => {
373 for t in ts.iter() {
374 t.collect_tvars(known)
375 }
376 }
377 Type::TVar(tv) => match known.entry(tv.name.clone()) {
378 Entry::Occupied(_) => (),
379 Entry::Vacant(e) => {
380 e.insert(tv.clone());
381 ()
382 }
383 },
384 Type::Fn(ft) => ft.collect_tvars(known),
385 Type::Set(s) => {
386 for typ in s.iter() {
387 typ.collect_tvars(known)
388 }
389 }
390 Type::Abstract { id: _, params } => {
391 for typ in params.iter() {
392 typ.collect_tvars(known)
393 }
394 }
395 }
396 }
397
398 pub fn check_tvars_declared(&self, declared: &AHashSet<ArcStr>) -> Result<()> {
399 match self {
400 Type::Bottom | Type::Any | Type::Primitive(_) => Ok(()),
401 Type::Ref(TypeRef { params, .. }) => {
402 params.iter().try_for_each(|t| t.check_tvars_declared(declared))
403 }
404 Type::Error(t) => t.check_tvars_declared(declared),
405 Type::Array(t) => t.check_tvars_declared(declared),
406 Type::Map { key, value } => {
407 key.check_tvars_declared(declared)?;
408 value.check_tvars_declared(declared)
409 }
410 Type::ByRef(t) => t.check_tvars_declared(declared),
411 Type::Tuple(ts) => {
412 ts.iter().try_for_each(|t| t.check_tvars_declared(declared))
413 }
414 Type::Struct(ts) => {
415 ts.iter().try_for_each(|(_, t)| t.check_tvars_declared(declared))
416 }
417 Type::Variant(_, ts) => {
418 ts.iter().try_for_each(|t| t.check_tvars_declared(declared))
419 }
420 Type::TVar(tv) => {
421 if !declared.contains(&tv.name) {
422 bail!("undeclared type variable '{}'", tv.name)
423 } else {
424 Ok(())
425 }
426 }
427 Type::Set(s) => s.iter().try_for_each(|t| t.check_tvars_declared(declared)),
428 Type::Abstract { id: _, params } => {
429 params.iter().try_for_each(|t| t.check_tvars_declared(declared))
430 }
431 Type::Fn(_) => Ok(()),
432 }
433 }
434
435 pub fn has_unbound(&self) -> bool {
436 match self {
437 Type::Bottom | Type::Any | Type::Primitive(_) => false,
438 Type::Ref(TypeRef { .. }) => false,
439 Type::Error(e) => e.has_unbound(),
440 Type::Array(t0) => t0.has_unbound(),
441 Type::Map { key, value } => key.has_unbound() || value.has_unbound(),
442 Type::ByRef(t0) => t0.has_unbound(),
443 Type::Tuple(ts) => ts.iter().any(|t| t.has_unbound()),
444 Type::Struct(ts) => ts.iter().any(|(_, t)| t.has_unbound()),
445 Type::Variant(_, ts) => ts.iter().any(|t| t.has_unbound()),
446 Type::TVar(tv) => tv.read().typ.read().is_none(),
447 Type::Set(s) => s.iter().any(|t| t.has_unbound()),
448 Type::Abstract { id: _, params } => params.iter().any(|t| t.has_unbound()),
449 Type::Fn(ft) => ft.has_unbound(),
450 }
451 }
452
453 pub fn bind_as(&self, t: &Self) {
455 match self {
456 Type::Bottom | Type::Any | Type::Primitive(_) => (),
457 Type::Ref(TypeRef { .. }) => (),
458 Type::Error(t0) => t0.bind_as(t),
459 Type::Array(t0) => t0.bind_as(t),
460 Type::Map { key, value } => {
461 key.bind_as(t);
462 value.bind_as(t);
463 }
464 Type::ByRef(t0) => t0.bind_as(t),
465 Type::Tuple(ts) => {
466 for elt in ts.iter() {
467 elt.bind_as(t)
468 }
469 }
470 Type::Struct(ts) => {
471 for (_, elt) in ts.iter() {
472 elt.bind_as(t)
473 }
474 }
475 Type::Variant(_, ts) => {
476 for elt in ts.iter() {
477 elt.bind_as(t)
478 }
479 }
480 Type::TVar(tv) => {
481 let tv = tv.read();
482 let mut tv = tv.typ.write();
483 if tv.is_none() {
484 *tv = Some(t.clone());
485 }
486 }
487 Type::Set(s) => {
488 for elt in s.iter() {
489 elt.bind_as(t)
490 }
491 }
492 Type::Fn(ft) => ft.bind_as(t),
493 Type::Abstract { id: _, params } => {
494 for typ in params.iter() {
495 typ.bind_as(t)
496 }
497 }
498 }
499 }
500
501 pub fn reset_tvars(&self) -> Type {
504 match self {
505 Type::Bottom => Type::Bottom,
506 Type::Any => Type::Any,
507 Type::Primitive(p) => Type::Primitive(*p),
508 Type::Ref(TypeRef { scope, name, params, .. }) => Type::Ref(TypeRef {
509 scope: scope.clone(),
510 name: name.clone(),
511 params: Arc::from_iter(params.iter().map(|t| t.reset_tvars())),
512 ..Default::default()
513 }),
514 Type::Error(t0) => Type::Error(Arc::new(t0.reset_tvars())),
515 Type::Array(t0) => Type::Array(Arc::new(t0.reset_tvars())),
516 Type::Map { key, value } => {
517 let key = Arc::new(key.reset_tvars());
518 let value = Arc::new(value.reset_tvars());
519 Type::Map { key, value }
520 }
521 Type::ByRef(t0) => Type::ByRef(Arc::new(t0.reset_tvars())),
522 Type::Tuple(ts) => {
523 Type::Tuple(Arc::from_iter(ts.iter().map(|t| t.reset_tvars())))
524 }
525 Type::Struct(ts) => Type::Struct(Arc::from_iter(
526 ts.iter().map(|(n, t)| (n.clone(), t.reset_tvars())),
527 )),
528 Type::Variant(tag, ts) => Type::Variant(
529 tag.clone(),
530 Arc::from_iter(ts.iter().map(|t| t.reset_tvars())),
531 ),
532 Type::TVar(tv) => Type::TVar(TVar::empty_named(tv.name.clone())),
533 Type::Set(s) => Type::Set(Arc::from_iter(s.iter().map(|t| t.reset_tvars()))),
534 Type::Fn(fntyp) => Type::Fn(Arc::new(fntyp.reset_tvars())),
535 Type::Abstract { id, params } => Type::Abstract {
536 id: *id,
537 params: Arc::from_iter(params.iter().map(|t| t.reset_tvars())),
538 },
539 }
540 }
541
542 pub fn replace_tvars(&self, known: &AHashMap<ArcStr, Self>) -> Type {
547 use poolshark::local::LPooled;
548 self.replace_tvars_int(known, &mut LPooled::take())
549 }
550
551 pub(super) fn replace_tvars_int(
552 &self,
553 known: &AHashMap<ArcStr, Self>,
554 renamed: &mut AHashMap<ArcStr, TVar>,
555 ) -> Type {
556 match self {
557 Type::TVar(tv) => match known.get(&tv.name) {
558 Some(t) => t.clone(),
559 None => {
560 let fresh =
561 renamed.entry(tv.name.clone()).or_insert_with(TVar::default);
562 Type::TVar(fresh.clone())
563 }
564 },
565 Type::Bottom => Type::Bottom,
566 Type::Any => Type::Any,
567 Type::Primitive(p) => Type::Primitive(*p),
568 Type::Ref(TypeRef { scope, name, params, .. }) => Type::Ref(TypeRef {
569 scope: scope.clone(),
570 name: name.clone(),
571 params: Arc::from_iter(
572 params.iter().map(|t| t.replace_tvars_int(known, renamed)),
573 ),
574 ..Default::default()
575 }),
576 Type::Error(t0) => {
577 Type::Error(Arc::new(t0.replace_tvars_int(known, renamed)))
578 }
579 Type::Array(t0) => {
580 Type::Array(Arc::new(t0.replace_tvars_int(known, renamed)))
581 }
582 Type::Map { key, value } => {
583 let key = Arc::new(key.replace_tvars_int(known, renamed));
584 let value = Arc::new(value.replace_tvars_int(known, renamed));
585 Type::Map { key, value }
586 }
587 Type::ByRef(t0) => {
588 Type::ByRef(Arc::new(t0.replace_tvars_int(known, renamed)))
589 }
590 Type::Tuple(ts) => Type::Tuple(Arc::from_iter(
591 ts.iter().map(|t| t.replace_tvars_int(known, renamed)),
592 )),
593 Type::Struct(ts) => Type::Struct(Arc::from_iter(
594 ts.iter().map(|(n, t)| (n.clone(), t.replace_tvars_int(known, renamed))),
595 )),
596 Type::Variant(tag, ts) => Type::Variant(
597 tag.clone(),
598 Arc::from_iter(ts.iter().map(|t| t.replace_tvars_int(known, renamed))),
599 ),
600 Type::Set(s) => Type::Set(Arc::from_iter(
601 s.iter().map(|t| t.replace_tvars_int(known, renamed)),
602 )),
603 Type::Fn(fntyp) => {
604 Type::Fn(Arc::new(fntyp.replace_tvars_int(known, renamed)))
605 }
606 Type::Abstract { id, params } => Type::Abstract {
607 id: *id,
608 params: Arc::from_iter(
609 params.iter().map(|t| t.replace_tvars_int(known, renamed)),
610 ),
611 },
612 }
613 }
614
615 pub(crate) fn unbind_tvars(&self) {
617 match self {
618 Type::Bottom | Type::Any | Type::Primitive(_) | Type::Ref(TypeRef { .. }) => {
619 ()
620 }
621 Type::Error(t0) => t0.unbind_tvars(),
622 Type::Array(t0) => t0.unbind_tvars(),
623 Type::Map { key, value } => {
624 key.unbind_tvars();
625 value.unbind_tvars();
626 }
627 Type::ByRef(t0) => t0.unbind_tvars(),
628 Type::Tuple(ts)
629 | Type::Variant(_, ts)
630 | Type::Set(ts)
631 | Type::Abstract { id: _, params: ts } => {
632 for t in ts.iter() {
633 t.unbind_tvars()
634 }
635 }
636 Type::Struct(ts) => {
637 for (_, t) in ts.iter() {
638 t.unbind_tvars()
639 }
640 }
641 Type::TVar(tv) => tv.unbind(),
642 Type::Fn(fntyp) => fntyp.unbind_tvars(),
643 }
644 }
645}