// graphix_compiler/typ/tvar.rs

1use crate::{
2    expr::ModPath,
3    typ::{FnType, PrintFlag, Type, PRINT_FLAGS, TypeRef},
4};
5use anyhow::{bail, Result};
6use arcstr::ArcStr;
7use compact_str::format_compact;
8use fxhash::{FxHashMap, FxHashSet};
9use parking_lot::{RwLock, RwLockReadGuard, RwLockWriteGuard};
10use std::{
11    cmp::{Eq, PartialEq},
12    collections::hash_map::Entry,
13    fmt::{self, Debug},
14    hash::Hash,
15    ops::Deref,
16};
17use triomphe::Arc;
18
// Declares the `TVarId` identifier type through the `atomic_id!` macro, which
// is defined outside this file. The code below relies on `TVarId::new()` to
// obtain an id and reads its value through the `.0` field.
// NOTE(review): presumably each `new()` call yields a distinct id -- confirm
// against the macro definition.
atomic_id!(TVarId);
20
21pub(super) fn would_cycle_inner(addr: usize, t: &Type) -> bool {
22    match t {
23        Type::Primitive(_) | Type::Any | Type::Bottom | Type::Ref (TypeRef { .. }) => false,
24        Type::TVar(t) => {
25            Arc::as_ptr(&t.read().typ).addr() == addr
26                || match &*t.read().typ.read() {
27                    None => false,
28                    Some(t) => would_cycle_inner(addr, t),
29                }
30        }
31        Type::Abstract { id: _, params } => {
32            params.iter().any(|t| would_cycle_inner(addr, t))
33        }
34        Type::Error(t) => would_cycle_inner(addr, t),
35        Type::Array(a) => would_cycle_inner(addr, &**a),
36        Type::Map { key, value } => {
37            would_cycle_inner(addr, &**key) || would_cycle_inner(addr, &**value)
38        }
39        Type::ByRef(t) => would_cycle_inner(addr, t),
40        Type::Tuple(ts) => ts.iter().any(|t| would_cycle_inner(addr, t)),
41        Type::Variant(_, ts) => ts.iter().any(|t| would_cycle_inner(addr, t)),
42        Type::Struct(ts) => ts.iter().any(|(_, t)| would_cycle_inner(addr, t)),
43        Type::Set(s) => s.iter().any(|t| would_cycle_inner(addr, t)),
44        Type::Fn(f) => {
45            let FnType {
46                args,
47                vargs,
48                rtype,
49                constraints,
50                throws,
51                explicit_throws: _,
52                lambda_ids: _,
53            } = &**f;
54            args.iter().any(|t| would_cycle_inner(addr, &t.typ))
55                || match vargs {
56                    None => false,
57                    Some(t) => would_cycle_inner(addr, t),
58                }
59                || would_cycle_inner(addr, rtype)
60                || constraints.read().iter().any(|a| {
61                    Arc::as_ptr(&a.0.read().typ).addr() == addr
62                        || would_cycle_inner(addr, &a.1)
63                })
64                || would_cycle_inner(addr, &throws)
65        }
66    }
67}
68
/// The mutable state of a type variable, guarded by the lock in
/// [`TVarInner::typ`].
#[derive(Debug)]
pub struct TVarInnerInner {
    /// Identifier of this variable. `TVar::alias` overwrites it with the id
    /// of the variable being aliased.
    pub(crate) id: TVarId,
    /// When `true`, `TVar::alias` leaves this variable untouched. Set by
    /// `TVar::freeze` and `TVar::alias`, cleared by `Type::unfreeze_tvars`.
    pub(crate) frozen: bool,
    /// The binding cell: `None` while the variable is unbound, `Some(t)` once
    /// bound. `TVar::alias` clones this `Arc`, so aliased variables share one
    /// cell and therefore one binding.
    pub(crate) typ: Arc<RwLock<Option<Type>>>,
}
75
/// The shared payload behind a [`TVar`] handle: an immutable name plus the
/// lock-protected mutable state.
#[derive(Debug)]
pub struct TVarInner {
    /// The variable's name; printed by `Display` with a leading `'`.
    pub name: ArcStr,
    /// Lock around the id, frozen flag and binding cell. Accessed through
    /// `TVar::read` and `TVar::write`.
    pub(crate) typ: RwLock<TVarInnerInner>,
}
81
/// A type variable handle. Cloning copies the `Arc`, so every clone refers to
/// the same [`TVarInner`].
#[derive(Debug, Clone)]
pub struct TVar(Arc<TVarInner>);
84
85impl fmt::Display for TVar {
86    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
87        if !PRINT_FLAGS.get().contains(PrintFlag::DerefTVars) {
88            write!(f, "'{}", self.name)
89        } else {
90            write!(f, "'{}: ", self.name)?;
91            match &*self.read().typ.read() {
92                Some(t) => write!(f, "{t}"),
93                None => write!(f, "unbound"),
94            }
95        }
96    }
97}
98
99impl Default for TVar {
100    fn default() -> Self {
101        Self::empty_named(ArcStr::from(format_compact!("_{}", TVarId::new().0).as_str()))
102    }
103}
104
105impl Deref for TVar {
106    type Target = TVarInner;
107
108    fn deref(&self) -> &Self::Target {
109        &*self.0
110    }
111}
112
113impl PartialEq for TVar {
114    fn eq(&self, other: &Self) -> bool {
115        let t0 = self.read();
116        let t1 = other.read();
117        Arc::ptr_eq(&t0.typ, &t1.typ) || {
118            let t0 = t0.typ.read();
119            let t1 = t1.typ.read();
120            *t0 == *t1
121        }
122    }
123}
124
125impl Eq for TVar {}
126
127impl PartialOrd for TVar {
128    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
129        let t0 = self.read();
130        let t1 = other.read();
131        if Arc::ptr_eq(&t0.typ, &t1.typ) {
132            Some(std::cmp::Ordering::Equal)
133        } else {
134            let t0 = t0.typ.read();
135            let t1 = t1.typ.read();
136            t0.partial_cmp(&*t1)
137        }
138    }
139}
140
141impl Ord for TVar {
142    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
143        let t0 = self.read();
144        let t1 = other.read();
145        if Arc::ptr_eq(&t0.typ, &t1.typ) {
146            std::cmp::Ordering::Equal
147        } else {
148            let t0 = t0.typ.read();
149            let t1 = t1.typ.read();
150            t0.cmp(&*t1)
151        }
152    }
153}
154
155impl TVar {
156    pub fn scope_refs(&self, scope: &ModPath) -> Self {
157        match Type::TVar(self.clone()).scope_refs(scope) {
158            Type::TVar(tv) => tv,
159            _ => unreachable!(),
160        }
161    }
162
163    pub fn empty_named(name: ArcStr) -> Self {
164        Self(Arc::new(TVarInner {
165            name,
166            typ: RwLock::new(TVarInnerInner {
167                id: TVarId::new(),
168                frozen: false,
169                typ: Arc::new(RwLock::new(None)),
170            }),
171        }))
172    }
173
174    pub fn named(name: ArcStr, typ: Type) -> Self {
175        Self(Arc::new(TVarInner {
176            name,
177            typ: RwLock::new(TVarInnerInner {
178                id: TVarId::new(),
179                frozen: false,
180                typ: Arc::new(RwLock::new(Some(typ))),
181            }),
182        }))
183    }
184
185    pub fn read<'a>(&'a self) -> RwLockReadGuard<'a, TVarInnerInner> {
186        self.typ.read()
187    }
188
189    pub fn write<'a>(&'a self) -> RwLockWriteGuard<'a, TVarInnerInner> {
190        self.typ.write()
191    }
192
193    /// make self an alias for other
194    pub fn alias(&self, other: &Self) {
195        let mut s = self.write();
196        if !s.frozen {
197            s.frozen = true;
198            let o = other.read();
199            s.id = o.id;
200            s.typ = Arc::clone(&o.typ);
201        }
202    }
203
204    pub fn freeze(&self) {
205        self.write().frozen = true;
206    }
207
208    /// copy self from other
209    pub fn copy(&self, other: &Self) {
210        let s = self.read();
211        let o = other.read();
212        *s.typ.write() = o.typ.read().clone();
213    }
214
215    pub fn normalize(&self) -> Self {
216        match &mut *self.read().typ.write() {
217            None => (),
218            Some(t) => {
219                *t = t.normalize();
220            }
221        }
222        self.clone()
223    }
224
225    pub fn unbind(&self) {
226        *self.read().typ.write() = None
227    }
228
229    pub(super) fn would_cycle(&self, t: &Type) -> bool {
230        let addr = Arc::as_ptr(&self.read().typ).addr();
231        would_cycle_inner(addr, t)
232    }
233
234    pub(super) fn addr(&self) -> usize {
235        Arc::as_ptr(&self.0).addr()
236    }
237
238    pub(super) fn inner_addr(&self) -> usize {
239        Arc::as_ptr(&self.read().typ).addr()
240    }
241}
242
243impl Type {
244    pub fn unfreeze_tvars(&self) {
245        match self {
246            Type::Bottom | Type::Any | Type::Primitive(_) => (),
247            Type::Ref (TypeRef { params, .. }) => {
248                for t in params.iter() {
249                    t.unfreeze_tvars();
250                }
251            }
252            Type::Error(t) => t.unfreeze_tvars(),
253            Type::Array(t) => t.unfreeze_tvars(),
254            Type::Map { key, value } => {
255                key.unfreeze_tvars();
256                value.unfreeze_tvars();
257            }
258            Type::ByRef(t) => t.unfreeze_tvars(),
259            Type::Tuple(ts) => {
260                for t in ts.iter() {
261                    t.unfreeze_tvars()
262                }
263            }
264            Type::Struct(ts) => {
265                for (_, t) in ts.iter() {
266                    t.unfreeze_tvars()
267                }
268            }
269            Type::Variant(_, ts) => {
270                for t in ts.iter() {
271                    t.unfreeze_tvars()
272                }
273            }
274            Type::TVar(tv) => tv.write().frozen = false,
275            Type::Fn(ft) => ft.unfreeze_tvars(),
276            Type::Set(s) => {
277                for typ in s.iter() {
278                    typ.unfreeze_tvars()
279                }
280            }
281            Type::Abstract { id: _, params } => {
282                for typ in params.iter() {
283                    typ.unfreeze_tvars()
284                }
285            }
286        }
287    }
288
289    /// alias type variables with the same name to each other
290    pub fn alias_tvars(&self, known: &mut FxHashMap<ArcStr, TVar>) {
291        match self {
292            Type::Bottom | Type::Any | Type::Primitive(_) => (),
293            Type::Ref (TypeRef { params, .. }) => {
294                for t in params.iter() {
295                    t.alias_tvars(known);
296                }
297            }
298            Type::Error(t) => t.alias_tvars(known),
299            Type::Array(t) => t.alias_tvars(known),
300            Type::Map { key, value } => {
301                key.alias_tvars(known);
302                value.alias_tvars(known);
303            }
304            Type::ByRef(t) => t.alias_tvars(known),
305            Type::Tuple(ts) => {
306                for t in ts.iter() {
307                    t.alias_tvars(known)
308                }
309            }
310            Type::Struct(ts) => {
311                for (_, t) in ts.iter() {
312                    t.alias_tvars(known)
313                }
314            }
315            Type::Variant(_, ts) => {
316                for t in ts.iter() {
317                    t.alias_tvars(known)
318                }
319            }
320            Type::TVar(tv) => match known.entry(tv.name.clone()) {
321                Entry::Occupied(e) => {
322                    let v = e.get();
323                    v.freeze();
324                    tv.alias(v);
325                }
326                Entry::Vacant(e) => {
327                    e.insert(tv.clone());
328                    ()
329                }
330            },
331            Type::Fn(ft) => ft.alias_tvars(known),
332            Type::Set(s) => {
333                for typ in s.iter() {
334                    typ.alias_tvars(known)
335                }
336            }
337            Type::Abstract { id: _, params } => {
338                for typ in params.iter() {
339                    typ.alias_tvars(known)
340                }
341            }
342        }
343    }
344
345    pub fn collect_tvars(&self, known: &mut FxHashMap<ArcStr, TVar>) {
346        match self {
347            Type::Bottom | Type::Any | Type::Primitive(_) => (),
348            Type::Ref (TypeRef { params, .. }) => {
349                for t in params.iter() {
350                    t.collect_tvars(known);
351                }
352            }
353            Type::Error(t) => t.collect_tvars(known),
354            Type::Array(t) => t.collect_tvars(known),
355            Type::Map { key, value } => {
356                key.collect_tvars(known);
357                value.collect_tvars(known);
358            }
359            Type::ByRef(t) => t.collect_tvars(known),
360            Type::Tuple(ts) => {
361                for t in ts.iter() {
362                    t.collect_tvars(known)
363                }
364            }
365            Type::Struct(ts) => {
366                for (_, t) in ts.iter() {
367                    t.collect_tvars(known)
368                }
369            }
370            Type::Variant(_, ts) => {
371                for t in ts.iter() {
372                    t.collect_tvars(known)
373                }
374            }
375            Type::TVar(tv) => match known.entry(tv.name.clone()) {
376                Entry::Occupied(_) => (),
377                Entry::Vacant(e) => {
378                    e.insert(tv.clone());
379                    ()
380                }
381            },
382            Type::Fn(ft) => ft.collect_tvars(known),
383            Type::Set(s) => {
384                for typ in s.iter() {
385                    typ.collect_tvars(known)
386                }
387            }
388            Type::Abstract { id: _, params } => {
389                for typ in params.iter() {
390                    typ.collect_tvars(known)
391                }
392            }
393        }
394    }
395
396    pub fn check_tvars_declared(&self, declared: &FxHashSet<ArcStr>) -> Result<()> {
397        match self {
398            Type::Bottom | Type::Any | Type::Primitive(_) => Ok(()),
399            Type::Ref (TypeRef { params, .. }) => {
400                params.iter().try_for_each(|t| t.check_tvars_declared(declared))
401            }
402            Type::Error(t) => t.check_tvars_declared(declared),
403            Type::Array(t) => t.check_tvars_declared(declared),
404            Type::Map { key, value } => {
405                key.check_tvars_declared(declared)?;
406                value.check_tvars_declared(declared)
407            }
408            Type::ByRef(t) => t.check_tvars_declared(declared),
409            Type::Tuple(ts) => {
410                ts.iter().try_for_each(|t| t.check_tvars_declared(declared))
411            }
412            Type::Struct(ts) => {
413                ts.iter().try_for_each(|(_, t)| t.check_tvars_declared(declared))
414            }
415            Type::Variant(_, ts) => {
416                ts.iter().try_for_each(|t| t.check_tvars_declared(declared))
417            }
418            Type::TVar(tv) => {
419                if !declared.contains(&tv.name) {
420                    bail!("undeclared type variable '{}'", tv.name)
421                } else {
422                    Ok(())
423                }
424            }
425            Type::Set(s) => s.iter().try_for_each(|t| t.check_tvars_declared(declared)),
426            Type::Abstract { id: _, params } => {
427                params.iter().try_for_each(|t| t.check_tvars_declared(declared))
428            }
429            Type::Fn(_) => Ok(()),
430        }
431    }
432
433    pub fn has_unbound(&self) -> bool {
434        match self {
435            Type::Bottom | Type::Any | Type::Primitive(_) => false,
436            Type::Ref (TypeRef { .. }) => false,
437            Type::Error(e) => e.has_unbound(),
438            Type::Array(t0) => t0.has_unbound(),
439            Type::Map { key, value } => key.has_unbound() || value.has_unbound(),
440            Type::ByRef(t0) => t0.has_unbound(),
441            Type::Tuple(ts) => ts.iter().any(|t| t.has_unbound()),
442            Type::Struct(ts) => ts.iter().any(|(_, t)| t.has_unbound()),
443            Type::Variant(_, ts) => ts.iter().any(|t| t.has_unbound()),
444            Type::TVar(tv) => tv.read().typ.read().is_none(),
445            Type::Set(s) => s.iter().any(|t| t.has_unbound()),
446            Type::Abstract { id: _, params } => params.iter().any(|t| t.has_unbound()),
447            Type::Fn(ft) => ft.has_unbound(),
448        }
449    }
450
451    /// bind all unbound type variables to the specified type
452    pub fn bind_as(&self, t: &Self) {
453        match self {
454            Type::Bottom | Type::Any | Type::Primitive(_) => (),
455            Type::Ref (TypeRef { .. }) => (),
456            Type::Error(t0) => t0.bind_as(t),
457            Type::Array(t0) => t0.bind_as(t),
458            Type::Map { key, value } => {
459                key.bind_as(t);
460                value.bind_as(t);
461            }
462            Type::ByRef(t0) => t0.bind_as(t),
463            Type::Tuple(ts) => {
464                for elt in ts.iter() {
465                    elt.bind_as(t)
466                }
467            }
468            Type::Struct(ts) => {
469                for (_, elt) in ts.iter() {
470                    elt.bind_as(t)
471                }
472            }
473            Type::Variant(_, ts) => {
474                for elt in ts.iter() {
475                    elt.bind_as(t)
476                }
477            }
478            Type::TVar(tv) => {
479                let tv = tv.read();
480                let mut tv = tv.typ.write();
481                if tv.is_none() {
482                    *tv = Some(t.clone());
483                }
484            }
485            Type::Set(s) => {
486                for elt in s.iter() {
487                    elt.bind_as(t)
488                }
489            }
490            Type::Fn(ft) => ft.bind_as(t),
491            Type::Abstract { id: _, params } => {
492                for typ in params.iter() {
493                    typ.bind_as(t)
494                }
495            }
496        }
497    }
498
499    /// return a copy of self with all type variables unbound and
500    /// unaliased. self will not be modified
501    pub fn reset_tvars(&self) -> Type {
502        match self {
503            Type::Bottom => Type::Bottom,
504            Type::Any => Type::Any,
505            Type::Primitive(p) => Type::Primitive(*p),
506            Type::Ref (TypeRef { scope, name, params, .. }) => Type::Ref (TypeRef {
507                scope: scope.clone(),
508                name: name.clone(),
509                params: Arc::from_iter(params.iter().map(|t| t.reset_tvars())),
510             ..Default::default()}),
511            Type::Error(t0) => Type::Error(Arc::new(t0.reset_tvars())),
512            Type::Array(t0) => Type::Array(Arc::new(t0.reset_tvars())),
513            Type::Map { key, value } => {
514                let key = Arc::new(key.reset_tvars());
515                let value = Arc::new(value.reset_tvars());
516                Type::Map { key, value }
517            }
518            Type::ByRef(t0) => Type::ByRef(Arc::new(t0.reset_tvars())),
519            Type::Tuple(ts) => {
520                Type::Tuple(Arc::from_iter(ts.iter().map(|t| t.reset_tvars())))
521            }
522            Type::Struct(ts) => Type::Struct(Arc::from_iter(
523                ts.iter().map(|(n, t)| (n.clone(), t.reset_tvars())),
524            )),
525            Type::Variant(tag, ts) => Type::Variant(
526                tag.clone(),
527                Arc::from_iter(ts.iter().map(|t| t.reset_tvars())),
528            ),
529            Type::TVar(tv) => Type::TVar(TVar::empty_named(tv.name.clone())),
530            Type::Set(s) => Type::Set(Arc::from_iter(s.iter().map(|t| t.reset_tvars()))),
531            Type::Fn(fntyp) => Type::Fn(Arc::new(fntyp.reset_tvars())),
532            Type::Abstract { id, params } => Type::Abstract {
533                id: *id,
534                params: Arc::from_iter(params.iter().map(|t| t.reset_tvars())),
535            },
536        }
537    }
538
539    /// return a copy of self with every TVar named in known replaced
540    /// with the corresponding type. TVars not in known are replaced with
541    /// fresh TVars using unique names to avoid entanglement with the caller's
542    /// TVars that happen to share the same name.
543    pub fn replace_tvars(&self, known: &FxHashMap<ArcStr, Self>) -> Type {
544        use poolshark::local::LPooled;
545        self.replace_tvars_int(known, &mut LPooled::take())
546    }
547
548    pub(super) fn replace_tvars_int(
549        &self,
550        known: &FxHashMap<ArcStr, Self>,
551        renamed: &mut FxHashMap<ArcStr, TVar>,
552    ) -> Type {
553        match self {
554            Type::TVar(tv) => match known.get(&tv.name) {
555                Some(t) => t.clone(),
556                None => {
557                    let fresh =
558                        renamed.entry(tv.name.clone()).or_insert_with(TVar::default);
559                    Type::TVar(fresh.clone())
560                }
561            },
562            Type::Bottom => Type::Bottom,
563            Type::Any => Type::Any,
564            Type::Primitive(p) => Type::Primitive(*p),
565            Type::Ref (TypeRef { scope, name, params, .. }) => Type::Ref (TypeRef {
566                scope: scope.clone(),
567                name: name.clone(),
568                params: Arc::from_iter(
569                    params.iter().map(|t| t.replace_tvars_int(known, renamed)),
570                ),
571             ..Default::default()}),
572            Type::Error(t0) => {
573                Type::Error(Arc::new(t0.replace_tvars_int(known, renamed)))
574            }
575            Type::Array(t0) => {
576                Type::Array(Arc::new(t0.replace_tvars_int(known, renamed)))
577            }
578            Type::Map { key, value } => {
579                let key = Arc::new(key.replace_tvars_int(known, renamed));
580                let value = Arc::new(value.replace_tvars_int(known, renamed));
581                Type::Map { key, value }
582            }
583            Type::ByRef(t0) => {
584                Type::ByRef(Arc::new(t0.replace_tvars_int(known, renamed)))
585            }
586            Type::Tuple(ts) => Type::Tuple(Arc::from_iter(
587                ts.iter().map(|t| t.replace_tvars_int(known, renamed)),
588            )),
589            Type::Struct(ts) => Type::Struct(Arc::from_iter(
590                ts.iter().map(|(n, t)| (n.clone(), t.replace_tvars_int(known, renamed))),
591            )),
592            Type::Variant(tag, ts) => Type::Variant(
593                tag.clone(),
594                Arc::from_iter(ts.iter().map(|t| t.replace_tvars_int(known, renamed))),
595            ),
596            Type::Set(s) => Type::Set(Arc::from_iter(
597                s.iter().map(|t| t.replace_tvars_int(known, renamed)),
598            )),
599            Type::Fn(fntyp) => {
600                Type::Fn(Arc::new(fntyp.replace_tvars_int(known, renamed)))
601            }
602            Type::Abstract { id, params } => Type::Abstract {
603                id: *id,
604                params: Arc::from_iter(
605                    params.iter().map(|t| t.replace_tvars_int(known, renamed)),
606                ),
607            },
608        }
609    }
610
611    /// Unbind any bound tvars, but do not unalias them.
612    pub(crate) fn unbind_tvars(&self) {
613        match self {
614            Type::Bottom | Type::Any | Type::Primitive(_) | Type::Ref (TypeRef { .. }) => (),
615            Type::Error(t0) => t0.unbind_tvars(),
616            Type::Array(t0) => t0.unbind_tvars(),
617            Type::Map { key, value } => {
618                key.unbind_tvars();
619                value.unbind_tvars();
620            }
621            Type::ByRef(t0) => t0.unbind_tvars(),
622            Type::Tuple(ts)
623            | Type::Variant(_, ts)
624            | Type::Set(ts)
625            | Type::Abstract { id: _, params: ts } => {
626                for t in ts.iter() {
627                    t.unbind_tvars()
628                }
629            }
630            Type::Struct(ts) => {
631                for (_, t) in ts.iter() {
632                    t.unbind_tvars()
633                }
634            }
635            Type::TVar(tv) => tv.unbind(),
636            Type::Fn(fntyp) => fntyp.unbind_tvars(),
637        }
638    }
639}