1use crate::{
2 expr::ModPath,
3 typ::{FnType, PrintFlag, Type, PRINT_FLAGS},
4};
5use anyhow::{bail, Result};
6use arcstr::ArcStr;
7use compact_str::format_compact;
8use fxhash::{FxHashMap, FxHashSet};
9use parking_lot::{RwLock, RwLockReadGuard, RwLockWriteGuard};
10use std::{
11 cmp::{Eq, PartialEq},
12 collections::hash_map::Entry,
13 fmt::{self, Debug},
14 hash::Hash,
15 ops::Deref,
16};
17use triomphe::Arc;
18
// Declares `TVarId`, the identifier type for type variables. `atomic_id!` is
// defined elsewhere in the crate; as used in this file, `TVarId::new()` yields
// a fresh id, ids are copied by value (see `TVar::alias`), and `.0` exposes
// the underlying number (used to name anonymous variables `_N`).
atomic_id!(TVarId);
20
21pub(super) fn would_cycle_inner(addr: usize, t: &Type) -> bool {
22 match t {
23 Type::Primitive(_) | Type::Any | Type::Bottom | Type::Ref { .. } => false,
24 Type::TVar(t) => {
25 Arc::as_ptr(&t.read().typ).addr() == addr
26 || match &*t.read().typ.read() {
27 None => false,
28 Some(t) => would_cycle_inner(addr, t),
29 }
30 }
31 Type::Abstract { id: _, params } => {
32 params.iter().any(|t| would_cycle_inner(addr, t))
33 }
34 Type::Error(t) => would_cycle_inner(addr, t),
35 Type::Array(a) => would_cycle_inner(addr, &**a),
36 Type::Map { key, value } => {
37 would_cycle_inner(addr, &**key) || would_cycle_inner(addr, &**value)
38 }
39 Type::ByRef(t) => would_cycle_inner(addr, t),
40 Type::Tuple(ts) => ts.iter().any(|t| would_cycle_inner(addr, t)),
41 Type::Variant(_, ts) => ts.iter().any(|t| would_cycle_inner(addr, t)),
42 Type::Struct(ts) => ts.iter().any(|(_, t)| would_cycle_inner(addr, t)),
43 Type::Set(s) => s.iter().any(|t| would_cycle_inner(addr, t)),
44 Type::Fn(f) => {
45 let FnType {
46 args,
47 vargs,
48 rtype,
49 constraints,
50 throws,
51 explicit_throws: _,
52 lambda_ids: _,
53 } = &**f;
54 args.iter().any(|t| would_cycle_inner(addr, &t.typ))
55 || match vargs {
56 None => false,
57 Some(t) => would_cycle_inner(addr, t),
58 }
59 || would_cycle_inner(addr, rtype)
60 || constraints.read().iter().any(|a| {
61 Arc::as_ptr(&a.0.read().typ).addr() == addr
62 || would_cycle_inner(addr, &a.1)
63 })
64 || would_cycle_inner(addr, &throws)
65 }
66 }
67}
68
/// Mutable state of a type variable, held behind the lock in [`TVarInner`].
#[derive(Debug)]
pub struct TVarInnerInner {
    /// Identifier of the variable. Assigned from `TVarId::new()` when the
    /// variable is created and replaced with the target's id by `TVar::alias`.
    pub(crate) id: TVarId,
    /// While set, `TVar::alias` leaves this variable unchanged. Set by
    /// `TVar::alias` and `TVar::freeze`; cleared by `Type::unfreeze_tvars`.
    pub(crate) frozen: bool,
    /// Binding cell; `None` means the variable is unbound. `TVar::alias`
    /// makes aliased variables share one cell, so binding or unbinding
    /// through one affects all of them. The cell's pointer identity is what
    /// `would_cycle`, `PartialEq`, and `Ord` compare.
    pub(crate) typ: Arc<RwLock<Option<Type>>>,
}
75
/// Shared body of a [`TVar`] handle.
#[derive(Debug)]
pub struct TVarInner {
    /// Name of the variable, printed as `'name`. Type variables are matched
    /// by this name in `Type::alias_tvars`, `Type::collect_tvars`,
    /// `Type::check_tvars_declared` and `Type::replace_tvars`.
    pub name: ArcStr,
    /// Lock around the variable's mutable state; accessed through
    /// `TVar::read` / `TVar::write`.
    pub(crate) typ: RwLock<TVarInnerInner>,
}
81
/// Handle to a type variable.
///
/// The body sits behind an `Arc`, so cloning is cheap and yields another
/// handle to the same variable: changes made through one clone are seen by
/// all of them.
#[derive(Debug, Clone)]
pub struct TVar(Arc<TVarInner>);
84
85impl fmt::Display for TVar {
86 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
87 if !PRINT_FLAGS.get().contains(PrintFlag::DerefTVars) {
88 write!(f, "'{}", self.name)
89 } else {
90 write!(f, "'{}: ", self.name)?;
91 match &*self.read().typ.read() {
92 Some(t) => write!(f, "{t}"),
93 None => write!(f, "unbound"),
94 }
95 }
96 }
97}
98
99impl Default for TVar {
100 fn default() -> Self {
101 Self::empty_named(ArcStr::from(format_compact!("_{}", TVarId::new().0).as_str()))
102 }
103}
104
105impl Deref for TVar {
106 type Target = TVarInner;
107
108 fn deref(&self) -> &Self::Target {
109 &*self.0
110 }
111}
112
113impl PartialEq for TVar {
114 fn eq(&self, other: &Self) -> bool {
115 let t0 = self.read();
116 let t1 = other.read();
117 Arc::ptr_eq(&t0.typ, &t1.typ) || {
118 let t0 = t0.typ.read();
119 let t1 = t1.typ.read();
120 *t0 == *t1
121 }
122 }
123}
124
125impl Eq for TVar {}
126
127impl PartialOrd for TVar {
128 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
129 let t0 = self.read();
130 let t1 = other.read();
131 if Arc::ptr_eq(&t0.typ, &t1.typ) {
132 Some(std::cmp::Ordering::Equal)
133 } else {
134 let t0 = t0.typ.read();
135 let t1 = t1.typ.read();
136 t0.partial_cmp(&*t1)
137 }
138 }
139}
140
141impl Ord for TVar {
142 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
143 let t0 = self.read();
144 let t1 = other.read();
145 if Arc::ptr_eq(&t0.typ, &t1.typ) {
146 std::cmp::Ordering::Equal
147 } else {
148 let t0 = t0.typ.read();
149 let t1 = t1.typ.read();
150 t0.cmp(&*t1)
151 }
152 }
153}
154
155impl TVar {
156 pub fn scope_refs(&self, scope: &ModPath) -> Self {
157 match Type::TVar(self.clone()).scope_refs(scope) {
158 Type::TVar(tv) => tv,
159 _ => unreachable!(),
160 }
161 }
162
163 pub fn empty_named(name: ArcStr) -> Self {
164 Self(Arc::new(TVarInner {
165 name,
166 typ: RwLock::new(TVarInnerInner {
167 id: TVarId::new(),
168 frozen: false,
169 typ: Arc::new(RwLock::new(None)),
170 }),
171 }))
172 }
173
174 pub fn named(name: ArcStr, typ: Type) -> Self {
175 Self(Arc::new(TVarInner {
176 name,
177 typ: RwLock::new(TVarInnerInner {
178 id: TVarId::new(),
179 frozen: false,
180 typ: Arc::new(RwLock::new(Some(typ))),
181 }),
182 }))
183 }
184
185 pub fn read<'a>(&'a self) -> RwLockReadGuard<'a, TVarInnerInner> {
186 self.typ.read()
187 }
188
189 pub fn write<'a>(&'a self) -> RwLockWriteGuard<'a, TVarInnerInner> {
190 self.typ.write()
191 }
192
193 pub fn alias(&self, other: &Self) {
195 let mut s = self.write();
196 if !s.frozen {
197 s.frozen = true;
198 let o = other.read();
199 s.id = o.id;
200 s.typ = Arc::clone(&o.typ);
201 }
202 }
203
204 pub fn freeze(&self) {
205 self.write().frozen = true;
206 }
207
208 pub fn copy(&self, other: &Self) {
210 let s = self.read();
211 let o = other.read();
212 *s.typ.write() = o.typ.read().clone();
213 }
214
215 pub fn normalize(&self) -> Self {
216 match &mut *self.read().typ.write() {
217 None => (),
218 Some(t) => {
219 *t = t.normalize();
220 }
221 }
222 self.clone()
223 }
224
225 pub fn unbind(&self) {
226 *self.read().typ.write() = None
227 }
228
229 pub(super) fn would_cycle(&self, t: &Type) -> bool {
230 let addr = Arc::as_ptr(&self.read().typ).addr();
231 would_cycle_inner(addr, t)
232 }
233
234 pub(super) fn addr(&self) -> usize {
235 Arc::as_ptr(&self.0).addr()
236 }
237
238 pub(super) fn inner_addr(&self) -> usize {
239 Arc::as_ptr(&self.read().typ).addr()
240 }
241}
242
243impl Type {
244 pub fn unfreeze_tvars(&self) {
245 match self {
246 Type::Bottom | Type::Any | Type::Primitive(_) => (),
247 Type::Ref { params, .. } => {
248 for t in params.iter() {
249 t.unfreeze_tvars();
250 }
251 }
252 Type::Error(t) => t.unfreeze_tvars(),
253 Type::Array(t) => t.unfreeze_tvars(),
254 Type::Map { key, value } => {
255 key.unfreeze_tvars();
256 value.unfreeze_tvars();
257 }
258 Type::ByRef(t) => t.unfreeze_tvars(),
259 Type::Tuple(ts) => {
260 for t in ts.iter() {
261 t.unfreeze_tvars()
262 }
263 }
264 Type::Struct(ts) => {
265 for (_, t) in ts.iter() {
266 t.unfreeze_tvars()
267 }
268 }
269 Type::Variant(_, ts) => {
270 for t in ts.iter() {
271 t.unfreeze_tvars()
272 }
273 }
274 Type::TVar(tv) => tv.write().frozen = false,
275 Type::Fn(ft) => ft.unfreeze_tvars(),
276 Type::Set(s) => {
277 for typ in s.iter() {
278 typ.unfreeze_tvars()
279 }
280 }
281 Type::Abstract { id: _, params } => {
282 for typ in params.iter() {
283 typ.unfreeze_tvars()
284 }
285 }
286 }
287 }
288
289 pub fn alias_tvars(&self, known: &mut FxHashMap<ArcStr, TVar>) {
291 match self {
292 Type::Bottom | Type::Any | Type::Primitive(_) => (),
293 Type::Ref { params, .. } => {
294 for t in params.iter() {
295 t.alias_tvars(known);
296 }
297 }
298 Type::Error(t) => t.alias_tvars(known),
299 Type::Array(t) => t.alias_tvars(known),
300 Type::Map { key, value } => {
301 key.alias_tvars(known);
302 value.alias_tvars(known);
303 }
304 Type::ByRef(t) => t.alias_tvars(known),
305 Type::Tuple(ts) => {
306 for t in ts.iter() {
307 t.alias_tvars(known)
308 }
309 }
310 Type::Struct(ts) => {
311 for (_, t) in ts.iter() {
312 t.alias_tvars(known)
313 }
314 }
315 Type::Variant(_, ts) => {
316 for t in ts.iter() {
317 t.alias_tvars(known)
318 }
319 }
320 Type::TVar(tv) => match known.entry(tv.name.clone()) {
321 Entry::Occupied(e) => {
322 let v = e.get();
323 v.freeze();
324 tv.alias(v);
325 }
326 Entry::Vacant(e) => {
327 e.insert(tv.clone());
328 ()
329 }
330 },
331 Type::Fn(ft) => ft.alias_tvars(known),
332 Type::Set(s) => {
333 for typ in s.iter() {
334 typ.alias_tvars(known)
335 }
336 }
337 Type::Abstract { id: _, params } => {
338 for typ in params.iter() {
339 typ.alias_tvars(known)
340 }
341 }
342 }
343 }
344
345 pub fn collect_tvars(&self, known: &mut FxHashMap<ArcStr, TVar>) {
346 match self {
347 Type::Bottom | Type::Any | Type::Primitive(_) => (),
348 Type::Ref { params, .. } => {
349 for t in params.iter() {
350 t.collect_tvars(known);
351 }
352 }
353 Type::Error(t) => t.collect_tvars(known),
354 Type::Array(t) => t.collect_tvars(known),
355 Type::Map { key, value } => {
356 key.collect_tvars(known);
357 value.collect_tvars(known);
358 }
359 Type::ByRef(t) => t.collect_tvars(known),
360 Type::Tuple(ts) => {
361 for t in ts.iter() {
362 t.collect_tvars(known)
363 }
364 }
365 Type::Struct(ts) => {
366 for (_, t) in ts.iter() {
367 t.collect_tvars(known)
368 }
369 }
370 Type::Variant(_, ts) => {
371 for t in ts.iter() {
372 t.collect_tvars(known)
373 }
374 }
375 Type::TVar(tv) => match known.entry(tv.name.clone()) {
376 Entry::Occupied(_) => (),
377 Entry::Vacant(e) => {
378 e.insert(tv.clone());
379 ()
380 }
381 },
382 Type::Fn(ft) => ft.collect_tvars(known),
383 Type::Set(s) => {
384 for typ in s.iter() {
385 typ.collect_tvars(known)
386 }
387 }
388 Type::Abstract { id: _, params } => {
389 for typ in params.iter() {
390 typ.collect_tvars(known)
391 }
392 }
393 }
394 }
395
396 pub fn check_tvars_declared(&self, declared: &FxHashSet<ArcStr>) -> Result<()> {
397 match self {
398 Type::Bottom | Type::Any | Type::Primitive(_) => Ok(()),
399 Type::Ref { params, .. } => {
400 params.iter().try_for_each(|t| t.check_tvars_declared(declared))
401 }
402 Type::Error(t) => t.check_tvars_declared(declared),
403 Type::Array(t) => t.check_tvars_declared(declared),
404 Type::Map { key, value } => {
405 key.check_tvars_declared(declared)?;
406 value.check_tvars_declared(declared)
407 }
408 Type::ByRef(t) => t.check_tvars_declared(declared),
409 Type::Tuple(ts) => {
410 ts.iter().try_for_each(|t| t.check_tvars_declared(declared))
411 }
412 Type::Struct(ts) => {
413 ts.iter().try_for_each(|(_, t)| t.check_tvars_declared(declared))
414 }
415 Type::Variant(_, ts) => {
416 ts.iter().try_for_each(|t| t.check_tvars_declared(declared))
417 }
418 Type::TVar(tv) => {
419 if !declared.contains(&tv.name) {
420 bail!("undeclared type variable '{}'", tv.name)
421 } else {
422 Ok(())
423 }
424 }
425 Type::Set(s) => s.iter().try_for_each(|t| t.check_tvars_declared(declared)),
426 Type::Abstract { id: _, params } => {
427 params.iter().try_for_each(|t| t.check_tvars_declared(declared))
428 }
429 Type::Fn(_) => Ok(()),
430 }
431 }
432
433 pub fn has_unbound(&self) -> bool {
434 match self {
435 Type::Bottom | Type::Any | Type::Primitive(_) => false,
436 Type::Ref { .. } => false,
437 Type::Error(e) => e.has_unbound(),
438 Type::Array(t0) => t0.has_unbound(),
439 Type::Map { key, value } => key.has_unbound() || value.has_unbound(),
440 Type::ByRef(t0) => t0.has_unbound(),
441 Type::Tuple(ts) => ts.iter().any(|t| t.has_unbound()),
442 Type::Struct(ts) => ts.iter().any(|(_, t)| t.has_unbound()),
443 Type::Variant(_, ts) => ts.iter().any(|t| t.has_unbound()),
444 Type::TVar(tv) => tv.read().typ.read().is_none(),
445 Type::Set(s) => s.iter().any(|t| t.has_unbound()),
446 Type::Abstract { id: _, params } => params.iter().any(|t| t.has_unbound()),
447 Type::Fn(ft) => ft.has_unbound(),
448 }
449 }
450
451 pub fn bind_as(&self, t: &Self) {
453 match self {
454 Type::Bottom | Type::Any | Type::Primitive(_) => (),
455 Type::Ref { .. } => (),
456 Type::Error(t0) => t0.bind_as(t),
457 Type::Array(t0) => t0.bind_as(t),
458 Type::Map { key, value } => {
459 key.bind_as(t);
460 value.bind_as(t);
461 }
462 Type::ByRef(t0) => t0.bind_as(t),
463 Type::Tuple(ts) => {
464 for elt in ts.iter() {
465 elt.bind_as(t)
466 }
467 }
468 Type::Struct(ts) => {
469 for (_, elt) in ts.iter() {
470 elt.bind_as(t)
471 }
472 }
473 Type::Variant(_, ts) => {
474 for elt in ts.iter() {
475 elt.bind_as(t)
476 }
477 }
478 Type::TVar(tv) => {
479 let tv = tv.read();
480 let mut tv = tv.typ.write();
481 if tv.is_none() {
482 *tv = Some(t.clone());
483 }
484 }
485 Type::Set(s) => {
486 for elt in s.iter() {
487 elt.bind_as(t)
488 }
489 }
490 Type::Fn(ft) => ft.bind_as(t),
491 Type::Abstract { id: _, params } => {
492 for typ in params.iter() {
493 typ.bind_as(t)
494 }
495 }
496 }
497 }
498
499 pub fn reset_tvars(&self) -> Type {
502 match self {
503 Type::Bottom => Type::Bottom,
504 Type::Any => Type::Any,
505 Type::Primitive(p) => Type::Primitive(*p),
506 Type::Ref { scope, name, params } => Type::Ref {
507 scope: scope.clone(),
508 name: name.clone(),
509 params: Arc::from_iter(params.iter().map(|t| t.reset_tvars())),
510 },
511 Type::Error(t0) => Type::Error(Arc::new(t0.reset_tvars())),
512 Type::Array(t0) => Type::Array(Arc::new(t0.reset_tvars())),
513 Type::Map { key, value } => {
514 let key = Arc::new(key.reset_tvars());
515 let value = Arc::new(value.reset_tvars());
516 Type::Map { key, value }
517 }
518 Type::ByRef(t0) => Type::ByRef(Arc::new(t0.reset_tvars())),
519 Type::Tuple(ts) => {
520 Type::Tuple(Arc::from_iter(ts.iter().map(|t| t.reset_tvars())))
521 }
522 Type::Struct(ts) => Type::Struct(Arc::from_iter(
523 ts.iter().map(|(n, t)| (n.clone(), t.reset_tvars())),
524 )),
525 Type::Variant(tag, ts) => Type::Variant(
526 tag.clone(),
527 Arc::from_iter(ts.iter().map(|t| t.reset_tvars())),
528 ),
529 Type::TVar(tv) => Type::TVar(TVar::empty_named(tv.name.clone())),
530 Type::Set(s) => Type::Set(Arc::from_iter(s.iter().map(|t| t.reset_tvars()))),
531 Type::Fn(fntyp) => Type::Fn(Arc::new(fntyp.reset_tvars())),
532 Type::Abstract { id, params } => Type::Abstract {
533 id: *id,
534 params: Arc::from_iter(params.iter().map(|t| t.reset_tvars())),
535 },
536 }
537 }
538
539 pub fn replace_tvars(&self, known: &FxHashMap<ArcStr, Self>) -> Type {
544 use poolshark::local::LPooled;
545 self.replace_tvars_int(known, &mut LPooled::take())
546 }
547
548 pub(super) fn replace_tvars_int(
549 &self,
550 known: &FxHashMap<ArcStr, Self>,
551 renamed: &mut FxHashMap<ArcStr, TVar>,
552 ) -> Type {
553 match self {
554 Type::TVar(tv) => match known.get(&tv.name) {
555 Some(t) => t.clone(),
556 None => {
557 let fresh =
558 renamed.entry(tv.name.clone()).or_insert_with(TVar::default);
559 Type::TVar(fresh.clone())
560 }
561 },
562 Type::Bottom => Type::Bottom,
563 Type::Any => Type::Any,
564 Type::Primitive(p) => Type::Primitive(*p),
565 Type::Ref { scope, name, params } => Type::Ref {
566 scope: scope.clone(),
567 name: name.clone(),
568 params: Arc::from_iter(
569 params.iter().map(|t| t.replace_tvars_int(known, renamed)),
570 ),
571 },
572 Type::Error(t0) => {
573 Type::Error(Arc::new(t0.replace_tvars_int(known, renamed)))
574 }
575 Type::Array(t0) => {
576 Type::Array(Arc::new(t0.replace_tvars_int(known, renamed)))
577 }
578 Type::Map { key, value } => {
579 let key = Arc::new(key.replace_tvars_int(known, renamed));
580 let value = Arc::new(value.replace_tvars_int(known, renamed));
581 Type::Map { key, value }
582 }
583 Type::ByRef(t0) => {
584 Type::ByRef(Arc::new(t0.replace_tvars_int(known, renamed)))
585 }
586 Type::Tuple(ts) => Type::Tuple(Arc::from_iter(
587 ts.iter().map(|t| t.replace_tvars_int(known, renamed)),
588 )),
589 Type::Struct(ts) => Type::Struct(Arc::from_iter(
590 ts.iter().map(|(n, t)| (n.clone(), t.replace_tvars_int(known, renamed))),
591 )),
592 Type::Variant(tag, ts) => Type::Variant(
593 tag.clone(),
594 Arc::from_iter(ts.iter().map(|t| t.replace_tvars_int(known, renamed))),
595 ),
596 Type::Set(s) => Type::Set(Arc::from_iter(
597 s.iter().map(|t| t.replace_tvars_int(known, renamed)),
598 )),
599 Type::Fn(fntyp) => {
600 Type::Fn(Arc::new(fntyp.replace_tvars_int(known, renamed)))
601 }
602 Type::Abstract { id, params } => Type::Abstract {
603 id: *id,
604 params: Arc::from_iter(
605 params.iter().map(|t| t.replace_tvars_int(known, renamed)),
606 ),
607 },
608 }
609 }
610
611 pub(crate) fn unbind_tvars(&self) {
613 match self {
614 Type::Bottom | Type::Any | Type::Primitive(_) | Type::Ref { .. } => (),
615 Type::Error(t0) => t0.unbind_tvars(),
616 Type::Array(t0) => t0.unbind_tvars(),
617 Type::Map { key, value } => {
618 key.unbind_tvars();
619 value.unbind_tvars();
620 }
621 Type::ByRef(t0) => t0.unbind_tvars(),
622 Type::Tuple(ts)
623 | Type::Variant(_, ts)
624 | Type::Set(ts)
625 | Type::Abstract { id: _, params: ts } => {
626 for t in ts.iter() {
627 t.unbind_tvars()
628 }
629 }
630 Type::Struct(ts) => {
631 for (_, t) in ts.iter() {
632 t.unbind_tvars()
633 }
634 }
635 Type::TVar(tv) => tv.unbind(),
636 Type::Fn(fntyp) => fntyp.unbind_tvars(),
637 }
638 }
639}