1use crate::{
2 expr::ModPath,
3 typ::{FnType, PrintFlag, Type, PRINT_FLAGS},
4};
5use anyhow::{bail, Result};
6use arcstr::ArcStr;
7use compact_str::format_compact;
8use fxhash::{FxHashMap, FxHashSet};
9use parking_lot::{RwLock, RwLockReadGuard, RwLockWriteGuard};
10use std::{
11 cmp::{Eq, PartialEq},
12 collections::hash_map::Entry,
13 fmt::{self, Debug},
14 hash::Hash,
15 ops::Deref,
16};
17use triomphe::Arc;
18
// Declares the `TVarId` identifier type through the project's `atomic_id!`
// macro. `TVarId::new()` mints a fresh id (used for every new tvar and for the
// generated names of default tvars); NOTE(review): presumably ids come from a
// process-wide atomic counter and are therefore unique — confirm against the
// macro definition.
atomic_id!(TVarId);
20
21pub(super) fn would_cycle_inner(addr: usize, t: &Type) -> bool {
22 match t {
23 Type::Primitive(_) | Type::Any | Type::Bottom | Type::Ref { .. } => false,
24 Type::TVar(t) => {
25 Arc::as_ptr(&t.read().typ).addr() == addr
26 || match &*t.read().typ.read() {
27 None => false,
28 Some(t) => would_cycle_inner(addr, t),
29 }
30 }
31 Type::Abstract { id: _, params } => {
32 params.iter().any(|t| would_cycle_inner(addr, t))
33 }
34 Type::Error(t) => would_cycle_inner(addr, t),
35 Type::Array(a) => would_cycle_inner(addr, &**a),
36 Type::Map { key, value } => {
37 would_cycle_inner(addr, &**key) || would_cycle_inner(addr, &**value)
38 }
39 Type::ByRef(t) => would_cycle_inner(addr, t),
40 Type::Tuple(ts) => ts.iter().any(|t| would_cycle_inner(addr, t)),
41 Type::Variant(_, ts) => ts.iter().any(|t| would_cycle_inner(addr, t)),
42 Type::Struct(ts) => ts.iter().any(|(_, t)| would_cycle_inner(addr, t)),
43 Type::Set(s) => s.iter().any(|t| would_cycle_inner(addr, t)),
44 Type::Fn(f) => {
45 let FnType { args, vargs, rtype, constraints, throws, explicit_throws: _ } =
46 &**f;
47 args.iter().any(|t| would_cycle_inner(addr, &t.typ))
48 || match vargs {
49 None => false,
50 Some(t) => would_cycle_inner(addr, t),
51 }
52 || would_cycle_inner(addr, rtype)
53 || constraints.read().iter().any(|a| {
54 Arc::as_ptr(&a.0.read().typ).addr() == addr
55 || would_cycle_inner(addr, &a.1)
56 })
57 || would_cycle_inner(addr, &throws)
58 }
59 }
60}
61
/// Mutable state of a type variable, guarded by the lock in `TVarInner::typ`.
#[derive(Debug)]
pub struct TVarInnerInner {
    /// Identity of the tvar; replaced by the target's id when this tvar is
    /// aliased to another one (`TVar::alias`).
    pub(crate) id: TVarId,
    /// While set, `TVar::alias` leaves this tvar untouched. Set by
    /// `TVar::freeze`/`TVar::alias`, cleared by `Type::unfreeze_tvars`.
    pub(crate) frozen: bool,
    /// Shared binding cell: `None` while unbound, `Some(t)` once bound.
    /// Aliased tvars hold clones of the same `Arc`, so binding one binds all of
    /// them; the cell's address is what cycle checks and equality compare.
    pub(crate) typ: Arc<RwLock<Option<Type>>>,
}
68
/// Shared state behind a `TVar` handle: an immutable name plus lock-protected
/// mutable state.
#[derive(Debug)]
pub struct TVarInner {
    /// Source-level name (displayed as `'name`); tvars are matched by this
    /// name when aliasing, collecting, or replacing them within a type.
    pub name: ArcStr,
    /// Lock around the tvar's id, frozen flag and binding cell.
    pub(crate) typ: RwLock<TVarInnerInner>,
}
74
/// Reference-counted handle to a type variable. Cloning the handle shares the
/// underlying `TVarInner`, so clones see the same name, id and binding.
#[derive(Debug, Clone)]
pub struct TVar(Arc<TVarInner>);
77
78impl fmt::Display for TVar {
79 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
80 if !PRINT_FLAGS.get().contains(PrintFlag::DerefTVars) {
81 write!(f, "'{}", self.name)
82 } else {
83 write!(f, "'{}: ", self.name)?;
84 match &*self.read().typ.read() {
85 Some(t) => write!(f, "{t}"),
86 None => write!(f, "unbound"),
87 }
88 }
89 }
90}
91
92impl Default for TVar {
93 fn default() -> Self {
94 Self::empty_named(ArcStr::from(format_compact!("_{}", TVarId::new().0).as_str()))
95 }
96}
97
98impl Deref for TVar {
99 type Target = TVarInner;
100
101 fn deref(&self) -> &Self::Target {
102 &*self.0
103 }
104}
105
106impl PartialEq for TVar {
107 fn eq(&self, other: &Self) -> bool {
108 let t0 = self.read();
109 let t1 = other.read();
110 t0.typ.as_ptr().addr() == t1.typ.as_ptr().addr() || {
111 let t0 = t0.typ.read();
112 let t1 = t1.typ.read();
113 *t0 == *t1
114 }
115 }
116}
117
118impl Eq for TVar {}
119
120impl PartialOrd for TVar {
121 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
122 let t0 = self.read();
123 let t1 = other.read();
124 if t0.typ.as_ptr().addr() == t1.typ.as_ptr().addr() {
125 Some(std::cmp::Ordering::Equal)
126 } else {
127 let t0 = t0.typ.read();
128 let t1 = t1.typ.read();
129 t0.partial_cmp(&*t1)
130 }
131 }
132}
133
134impl Ord for TVar {
135 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
136 let t0 = self.read();
137 let t1 = other.read();
138 if t0.typ.as_ptr().addr() == t1.typ.as_ptr().addr() {
139 std::cmp::Ordering::Equal
140 } else {
141 let t0 = t0.typ.read();
142 let t1 = t1.typ.read();
143 t0.cmp(&*t1)
144 }
145 }
146}
147
148impl TVar {
149 pub fn scope_refs(&self, scope: &ModPath) -> Self {
150 match Type::TVar(self.clone()).scope_refs(scope) {
151 Type::TVar(tv) => tv,
152 _ => unreachable!(),
153 }
154 }
155
156 pub fn empty_named(name: ArcStr) -> Self {
157 Self(Arc::new(TVarInner {
158 name,
159 typ: RwLock::new(TVarInnerInner {
160 id: TVarId::new(),
161 frozen: false,
162 typ: Arc::new(RwLock::new(None)),
163 }),
164 }))
165 }
166
167 pub fn named(name: ArcStr, typ: Type) -> Self {
168 Self(Arc::new(TVarInner {
169 name,
170 typ: RwLock::new(TVarInnerInner {
171 id: TVarId::new(),
172 frozen: false,
173 typ: Arc::new(RwLock::new(Some(typ))),
174 }),
175 }))
176 }
177
178 pub fn read<'a>(&'a self) -> RwLockReadGuard<'a, TVarInnerInner> {
179 self.typ.read()
180 }
181
182 pub fn write<'a>(&'a self) -> RwLockWriteGuard<'a, TVarInnerInner> {
183 self.typ.write()
184 }
185
186 pub fn alias(&self, other: &Self) {
188 let mut s = self.write();
189 if !s.frozen {
190 s.frozen = true;
191 let o = other.read();
192 s.id = o.id;
193 s.typ = Arc::clone(&o.typ);
194 }
195 }
196
197 pub fn freeze(&self) {
198 self.write().frozen = true;
199 }
200
201 pub fn copy(&self, other: &Self) {
203 let s = self.read();
204 let o = other.read();
205 *s.typ.write() = o.typ.read().clone();
206 }
207
208 pub fn normalize(&self) -> Self {
209 match &mut *self.read().typ.write() {
210 None => (),
211 Some(t) => {
212 *t = t.normalize();
213 }
214 }
215 self.clone()
216 }
217
218 pub fn unbind(&self) {
219 *self.read().typ.write() = None
220 }
221
222 pub(super) fn would_cycle(&self, t: &Type) -> bool {
223 let addr = Arc::as_ptr(&self.read().typ).addr();
224 would_cycle_inner(addr, t)
225 }
226
227 pub(super) fn addr(&self) -> usize {
228 Arc::as_ptr(&self.0).addr()
229 }
230
231 pub(super) fn inner_addr(&self) -> usize {
232 Arc::as_ptr(&self.read().typ).addr()
233 }
234}
235
236impl Type {
237 pub fn unfreeze_tvars(&self) {
238 match self {
239 Type::Bottom | Type::Any | Type::Primitive(_) => (),
240 Type::Ref { params, .. } => {
241 for t in params.iter() {
242 t.unfreeze_tvars();
243 }
244 }
245 Type::Error(t) => t.unfreeze_tvars(),
246 Type::Array(t) => t.unfreeze_tvars(),
247 Type::Map { key, value } => {
248 key.unfreeze_tvars();
249 value.unfreeze_tvars();
250 }
251 Type::ByRef(t) => t.unfreeze_tvars(),
252 Type::Tuple(ts) => {
253 for t in ts.iter() {
254 t.unfreeze_tvars()
255 }
256 }
257 Type::Struct(ts) => {
258 for (_, t) in ts.iter() {
259 t.unfreeze_tvars()
260 }
261 }
262 Type::Variant(_, ts) => {
263 for t in ts.iter() {
264 t.unfreeze_tvars()
265 }
266 }
267 Type::TVar(tv) => tv.write().frozen = false,
268 Type::Fn(ft) => ft.unfreeze_tvars(),
269 Type::Set(s) => {
270 for typ in s.iter() {
271 typ.unfreeze_tvars()
272 }
273 }
274 Type::Abstract { id: _, params } => {
275 for typ in params.iter() {
276 typ.unfreeze_tvars()
277 }
278 }
279 }
280 }
281
282 pub fn alias_tvars(&self, known: &mut FxHashMap<ArcStr, TVar>) {
284 match self {
285 Type::Bottom | Type::Any | Type::Primitive(_) => (),
286 Type::Ref { params, .. } => {
287 for t in params.iter() {
288 t.alias_tvars(known);
289 }
290 }
291 Type::Error(t) => t.alias_tvars(known),
292 Type::Array(t) => t.alias_tvars(known),
293 Type::Map { key, value } => {
294 key.alias_tvars(known);
295 value.alias_tvars(known);
296 }
297 Type::ByRef(t) => t.alias_tvars(known),
298 Type::Tuple(ts) => {
299 for t in ts.iter() {
300 t.alias_tvars(known)
301 }
302 }
303 Type::Struct(ts) => {
304 for (_, t) in ts.iter() {
305 t.alias_tvars(known)
306 }
307 }
308 Type::Variant(_, ts) => {
309 for t in ts.iter() {
310 t.alias_tvars(known)
311 }
312 }
313 Type::TVar(tv) => match known.entry(tv.name.clone()) {
314 Entry::Occupied(e) => {
315 let v = e.get();
316 v.freeze();
317 tv.alias(v);
318 }
319 Entry::Vacant(e) => {
320 e.insert(tv.clone());
321 ()
322 }
323 },
324 Type::Fn(ft) => ft.alias_tvars(known),
325 Type::Set(s) => {
326 for typ in s.iter() {
327 typ.alias_tvars(known)
328 }
329 }
330 Type::Abstract { id: _, params } => {
331 for typ in params.iter() {
332 typ.alias_tvars(known)
333 }
334 }
335 }
336 }
337
338 pub fn collect_tvars(&self, known: &mut FxHashMap<ArcStr, TVar>) {
339 match self {
340 Type::Bottom | Type::Any | Type::Primitive(_) => (),
341 Type::Ref { params, .. } => {
342 for t in params.iter() {
343 t.collect_tvars(known);
344 }
345 }
346 Type::Error(t) => t.collect_tvars(known),
347 Type::Array(t) => t.collect_tvars(known),
348 Type::Map { key, value } => {
349 key.collect_tvars(known);
350 value.collect_tvars(known);
351 }
352 Type::ByRef(t) => t.collect_tvars(known),
353 Type::Tuple(ts) => {
354 for t in ts.iter() {
355 t.collect_tvars(known)
356 }
357 }
358 Type::Struct(ts) => {
359 for (_, t) in ts.iter() {
360 t.collect_tvars(known)
361 }
362 }
363 Type::Variant(_, ts) => {
364 for t in ts.iter() {
365 t.collect_tvars(known)
366 }
367 }
368 Type::TVar(tv) => match known.entry(tv.name.clone()) {
369 Entry::Occupied(_) => (),
370 Entry::Vacant(e) => {
371 e.insert(tv.clone());
372 ()
373 }
374 },
375 Type::Fn(ft) => ft.collect_tvars(known),
376 Type::Set(s) => {
377 for typ in s.iter() {
378 typ.collect_tvars(known)
379 }
380 }
381 Type::Abstract { id: _, params } => {
382 for typ in params.iter() {
383 typ.collect_tvars(known)
384 }
385 }
386 }
387 }
388
389 pub fn check_tvars_declared(&self, declared: &FxHashSet<ArcStr>) -> Result<()> {
390 match self {
391 Type::Bottom | Type::Any | Type::Primitive(_) => Ok(()),
392 Type::Ref { params, .. } => {
393 params.iter().try_for_each(|t| t.check_tvars_declared(declared))
394 }
395 Type::Error(t) => t.check_tvars_declared(declared),
396 Type::Array(t) => t.check_tvars_declared(declared),
397 Type::Map { key, value } => {
398 key.check_tvars_declared(declared)?;
399 value.check_tvars_declared(declared)
400 }
401 Type::ByRef(t) => t.check_tvars_declared(declared),
402 Type::Tuple(ts) => {
403 ts.iter().try_for_each(|t| t.check_tvars_declared(declared))
404 }
405 Type::Struct(ts) => {
406 ts.iter().try_for_each(|(_, t)| t.check_tvars_declared(declared))
407 }
408 Type::Variant(_, ts) => {
409 ts.iter().try_for_each(|t| t.check_tvars_declared(declared))
410 }
411 Type::TVar(tv) => {
412 if !declared.contains(&tv.name) {
413 bail!("undeclared type variable '{}'", tv.name)
414 } else {
415 Ok(())
416 }
417 }
418 Type::Set(s) => s.iter().try_for_each(|t| t.check_tvars_declared(declared)),
419 Type::Abstract { id: _, params } => {
420 params.iter().try_for_each(|t| t.check_tvars_declared(declared))
421 }
422 Type::Fn(_) => Ok(()),
423 }
424 }
425
426 pub fn has_unbound(&self) -> bool {
427 match self {
428 Type::Bottom | Type::Any | Type::Primitive(_) => false,
429 Type::Ref { .. } => false,
430 Type::Error(e) => e.has_unbound(),
431 Type::Array(t0) => t0.has_unbound(),
432 Type::Map { key, value } => key.has_unbound() || value.has_unbound(),
433 Type::ByRef(t0) => t0.has_unbound(),
434 Type::Tuple(ts) => ts.iter().any(|t| t.has_unbound()),
435 Type::Struct(ts) => ts.iter().any(|(_, t)| t.has_unbound()),
436 Type::Variant(_, ts) => ts.iter().any(|t| t.has_unbound()),
437 Type::TVar(tv) => tv.read().typ.read().is_some(),
438 Type::Set(s) => s.iter().any(|t| t.has_unbound()),
439 Type::Abstract { id: _, params } => params.iter().any(|t| t.has_unbound()),
440 Type::Fn(ft) => ft.has_unbound(),
441 }
442 }
443
444 pub fn bind_as(&self, t: &Self) {
446 match self {
447 Type::Bottom | Type::Any | Type::Primitive(_) => (),
448 Type::Ref { .. } => (),
449 Type::Error(t0) => t0.bind_as(t),
450 Type::Array(t0) => t0.bind_as(t),
451 Type::Map { key, value } => {
452 key.bind_as(t);
453 value.bind_as(t);
454 }
455 Type::ByRef(t0) => t0.bind_as(t),
456 Type::Tuple(ts) => {
457 for elt in ts.iter() {
458 elt.bind_as(t)
459 }
460 }
461 Type::Struct(ts) => {
462 for (_, elt) in ts.iter() {
463 elt.bind_as(t)
464 }
465 }
466 Type::Variant(_, ts) => {
467 for elt in ts.iter() {
468 elt.bind_as(t)
469 }
470 }
471 Type::TVar(tv) => {
472 let tv = tv.read();
473 let mut tv = tv.typ.write();
474 if tv.is_none() {
475 *tv = Some(t.clone());
476 }
477 }
478 Type::Set(s) => {
479 for elt in s.iter() {
480 elt.bind_as(t)
481 }
482 }
483 Type::Fn(ft) => ft.bind_as(t),
484 Type::Abstract { id: _, params } => {
485 for typ in params.iter() {
486 typ.bind_as(t)
487 }
488 }
489 }
490 }
491
492 pub fn reset_tvars(&self) -> Type {
495 match self {
496 Type::Bottom => Type::Bottom,
497 Type::Any => Type::Any,
498 Type::Primitive(p) => Type::Primitive(*p),
499 Type::Ref { scope, name, params } => Type::Ref {
500 scope: scope.clone(),
501 name: name.clone(),
502 params: Arc::from_iter(params.iter().map(|t| t.reset_tvars())),
503 },
504 Type::Error(t0) => Type::Error(Arc::new(t0.reset_tvars())),
505 Type::Array(t0) => Type::Array(Arc::new(t0.reset_tvars())),
506 Type::Map { key, value } => {
507 let key = Arc::new(key.reset_tvars());
508 let value = Arc::new(value.reset_tvars());
509 Type::Map { key, value }
510 }
511 Type::ByRef(t0) => Type::ByRef(Arc::new(t0.reset_tvars())),
512 Type::Tuple(ts) => {
513 Type::Tuple(Arc::from_iter(ts.iter().map(|t| t.reset_tvars())))
514 }
515 Type::Struct(ts) => Type::Struct(Arc::from_iter(
516 ts.iter().map(|(n, t)| (n.clone(), t.reset_tvars())),
517 )),
518 Type::Variant(tag, ts) => Type::Variant(
519 tag.clone(),
520 Arc::from_iter(ts.iter().map(|t| t.reset_tvars())),
521 ),
522 Type::TVar(tv) => Type::TVar(TVar::empty_named(tv.name.clone())),
523 Type::Set(s) => Type::Set(Arc::from_iter(s.iter().map(|t| t.reset_tvars()))),
524 Type::Fn(fntyp) => Type::Fn(Arc::new(fntyp.reset_tvars())),
525 Type::Abstract { id, params } => Type::Abstract {
526 id: *id,
527 params: Arc::from_iter(params.iter().map(|t| t.reset_tvars())),
528 },
529 }
530 }
531
532 pub fn replace_tvars(&self, known: &FxHashMap<ArcStr, Self>) -> Type {
535 match self {
536 Type::TVar(tv) => match known.get(&tv.name) {
537 Some(t) => t.clone(),
538 None => Type::TVar(tv.clone()),
539 },
540 Type::Bottom => Type::Bottom,
541 Type::Any => Type::Any,
542 Type::Primitive(p) => Type::Primitive(*p),
543 Type::Ref { scope, name, params } => Type::Ref {
544 scope: scope.clone(),
545 name: name.clone(),
546 params: Arc::from_iter(params.iter().map(|t| t.replace_tvars(known))),
547 },
548 Type::Error(t0) => Type::Error(Arc::new(t0.replace_tvars(known))),
549 Type::Array(t0) => Type::Array(Arc::new(t0.replace_tvars(known))),
550 Type::Map { key, value } => {
551 let key = Arc::new(key.replace_tvars(known));
552 let value = Arc::new(value.replace_tvars(known));
553 Type::Map { key, value }
554 }
555 Type::ByRef(t0) => Type::ByRef(Arc::new(t0.replace_tvars(known))),
556 Type::Tuple(ts) => {
557 Type::Tuple(Arc::from_iter(ts.iter().map(|t| t.replace_tvars(known))))
558 }
559 Type::Struct(ts) => Type::Struct(Arc::from_iter(
560 ts.iter().map(|(n, t)| (n.clone(), t.replace_tvars(known))),
561 )),
562 Type::Variant(tag, ts) => Type::Variant(
563 tag.clone(),
564 Arc::from_iter(ts.iter().map(|t| t.replace_tvars(known))),
565 ),
566 Type::Set(s) => {
567 Type::Set(Arc::from_iter(s.iter().map(|t| t.replace_tvars(known))))
568 }
569 Type::Fn(fntyp) => Type::Fn(Arc::new(fntyp.replace_tvars(known))),
570 Type::Abstract { id, params } => Type::Abstract {
571 id: *id,
572 params: Arc::from_iter(params.iter().map(|t| t.replace_tvars(known))),
573 },
574 }
575 }
576
577 pub(crate) fn unbind_tvars(&self) {
579 match self {
580 Type::Bottom | Type::Any | Type::Primitive(_) | Type::Ref { .. } => (),
581 Type::Error(t0) => t0.unbind_tvars(),
582 Type::Array(t0) => t0.unbind_tvars(),
583 Type::Map { key, value } => {
584 key.unbind_tvars();
585 value.unbind_tvars();
586 }
587 Type::ByRef(t0) => t0.unbind_tvars(),
588 Type::Tuple(ts)
589 | Type::Variant(_, ts)
590 | Type::Set(ts)
591 | Type::Abstract { id: _, params: ts } => {
592 for t in ts.iter() {
593 t.unbind_tvars()
594 }
595 }
596 Type::Struct(ts) => {
597 for (_, t) in ts.iter() {
598 t.unbind_tvars()
599 }
600 }
601 Type::TVar(tv) => tv.unbind(),
602 Type::Fn(fntyp) => fntyp.unbind_tvars(),
603 }
604 }
605}