Skip to main content

graphix_compiler/typ/
tvar.rs

1use crate::{
2    expr::ModPath,
3    typ::{FnType, PrintFlag, Type, TypeRef, PRINT_FLAGS},
4};
5use ahash::{AHashMap, AHashSet};
6use anyhow::{bail, Result};
7use arcstr::ArcStr;
8use compact_str::format_compact;
9use parking_lot::{RwLock, RwLockReadGuard, RwLockWriteGuard};
10use std::{
11    cmp::{Eq, PartialEq},
12    collections::hash_map::Entry,
13    fmt::{self, Debug},
14    hash::Hash,
15    ops::Deref,
16};
17use triomphe::Arc;
18
// Declares the `TVarId` identifier type stored in `TVarInnerInner::id`.
// NOTE(review): `atomic_id!` is defined elsewhere; presumably it generates a
// newtype around an atomically incremented counter (used here as
// `TVarId::new()` and read through `.0`) — confirm against the macro.
atomic_id!(TVarId);
20
21pub(super) fn would_cycle_inner(addr: usize, t: &Type) -> bool {
22    match t {
23        Type::Primitive(_) | Type::Any | Type::Bottom | Type::Ref(TypeRef { .. }) => {
24            false
25        }
26        Type::TVar(t) => {
27            Arc::as_ptr(&t.read().typ).addr() == addr
28                || match &*t.read().typ.read() {
29                    None => false,
30                    Some(t) => would_cycle_inner(addr, t),
31                }
32        }
33        Type::Abstract { id: _, params } => {
34            params.iter().any(|t| would_cycle_inner(addr, t))
35        }
36        Type::Error(t) => would_cycle_inner(addr, t),
37        Type::Array(a) => would_cycle_inner(addr, &**a),
38        Type::Map { key, value } => {
39            would_cycle_inner(addr, &**key) || would_cycle_inner(addr, &**value)
40        }
41        Type::ByRef(t) => would_cycle_inner(addr, t),
42        Type::Tuple(ts) => ts.iter().any(|t| would_cycle_inner(addr, t)),
43        Type::Variant(_, ts) => ts.iter().any(|t| would_cycle_inner(addr, t)),
44        Type::Struct(ts) => ts.iter().any(|(_, t)| would_cycle_inner(addr, t)),
45        Type::Set(s) => s.iter().any(|t| would_cycle_inner(addr, t)),
46        Type::Fn(f) => {
47            let FnType {
48                args,
49                vargs,
50                rtype,
51                constraints,
52                throws,
53                explicit_throws: _,
54                lambda_ids: _,
55            } = &**f;
56            args.iter().any(|t| would_cycle_inner(addr, &t.typ))
57                || match vargs {
58                    None => false,
59                    Some(t) => would_cycle_inner(addr, t),
60                }
61                || would_cycle_inner(addr, rtype)
62                || constraints.read().iter().any(|a| {
63                    Arc::as_ptr(&a.0.read().typ).addr() == addr
64                        || would_cycle_inner(addr, &a.1)
65                })
66                || would_cycle_inner(addr, &throws)
67        }
68    }
69}
70
/// Mutable state of a type variable, guarded by `TVarInner::typ`.
#[derive(Debug)]
pub struct TVarInnerInner {
    /// Identifier of the variable. `TVar::alias` overwrites it with the id of
    /// the variable being aliased.
    pub(crate) id: TVarId,
    /// When `true`, `TVar::alias` leaves this variable unchanged. Set by
    /// `TVar::freeze` and `TVar::alias`; cleared by `Type::unfreeze_tvars`.
    pub(crate) frozen: bool,
    /// The binding cell: `None` means unbound. `TVar::alias` makes aliased
    /// variables share the same cell (same `Arc`), so binding one binds all.
    pub(crate) typ: Arc<RwLock<Option<Type>>>,
}
77
/// Shared record behind a `TVar` handle.
#[derive(Debug)]
pub struct TVarInner {
    /// Source-level name; printed by `Display` and used as the key when
    /// aliasing/collecting variables by name (`Type::alias_tvars`,
    /// `Type::collect_tvars`).
    pub name: ArcStr,
    /// Id, frozen flag and binding cell, behind a lock so aliasing can
    /// replace them through a shared reference.
    pub(crate) typ: RwLock<TVarInnerInner>,
}
83
/// A type variable handle. Cloning is cheap (an `Arc` clone) and every clone
/// refers to the same underlying `TVarInner`.
#[derive(Debug, Clone)]
pub struct TVar(Arc<TVarInner>);
86
87impl fmt::Display for TVar {
88    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
89        if !PRINT_FLAGS.get().contains(PrintFlag::DerefTVars) {
90            write!(f, "'{}", self.name)
91        } else {
92            write!(f, "'{}: ", self.name)?;
93            match &*self.read().typ.read() {
94                Some(t) => write!(f, "{t}"),
95                None => write!(f, "unbound"),
96            }
97        }
98    }
99}
100
101impl Default for TVar {
102    fn default() -> Self {
103        Self::empty_named(ArcStr::from(format_compact!("_{}", TVarId::new().0).as_str()))
104    }
105}
106
107impl Deref for TVar {
108    type Target = TVarInner;
109
110    fn deref(&self) -> &Self::Target {
111        &*self.0
112    }
113}
114
115impl PartialEq for TVar {
116    fn eq(&self, other: &Self) -> bool {
117        let t0 = self.read();
118        let t1 = other.read();
119        Arc::ptr_eq(&t0.typ, &t1.typ) || {
120            let t0 = t0.typ.read();
121            let t1 = t1.typ.read();
122            *t0 == *t1
123        }
124    }
125}
126
127impl Eq for TVar {}
128
129impl PartialOrd for TVar {
130    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
131        let t0 = self.read();
132        let t1 = other.read();
133        if Arc::ptr_eq(&t0.typ, &t1.typ) {
134            Some(std::cmp::Ordering::Equal)
135        } else {
136            let t0 = t0.typ.read();
137            let t1 = t1.typ.read();
138            t0.partial_cmp(&*t1)
139        }
140    }
141}
142
143impl Ord for TVar {
144    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
145        let t0 = self.read();
146        let t1 = other.read();
147        if Arc::ptr_eq(&t0.typ, &t1.typ) {
148            std::cmp::Ordering::Equal
149        } else {
150            let t0 = t0.typ.read();
151            let t1 = t1.typ.read();
152            t0.cmp(&*t1)
153        }
154    }
155}
156
157impl TVar {
158    pub fn scope_refs(&self, scope: &ModPath) -> Self {
159        match Type::TVar(self.clone()).scope_refs(scope) {
160            Type::TVar(tv) => tv,
161            _ => unreachable!(),
162        }
163    }
164
165    pub fn empty_named(name: ArcStr) -> Self {
166        Self(Arc::new(TVarInner {
167            name,
168            typ: RwLock::new(TVarInnerInner {
169                id: TVarId::new(),
170                frozen: false,
171                typ: Arc::new(RwLock::new(None)),
172            }),
173        }))
174    }
175
176    pub fn named(name: ArcStr, typ: Type) -> Self {
177        Self(Arc::new(TVarInner {
178            name,
179            typ: RwLock::new(TVarInnerInner {
180                id: TVarId::new(),
181                frozen: false,
182                typ: Arc::new(RwLock::new(Some(typ))),
183            }),
184        }))
185    }
186
187    pub fn read<'a>(&'a self) -> RwLockReadGuard<'a, TVarInnerInner> {
188        self.typ.read()
189    }
190
191    pub fn write<'a>(&'a self) -> RwLockWriteGuard<'a, TVarInnerInner> {
192        self.typ.write()
193    }
194
195    /// make self an alias for other
196    pub fn alias(&self, other: &Self) {
197        let mut s = self.write();
198        if !s.frozen {
199            s.frozen = true;
200            let o = other.read();
201            s.id = o.id;
202            s.typ = Arc::clone(&o.typ);
203        }
204    }
205
206    pub fn freeze(&self) {
207        self.write().frozen = true;
208    }
209
210    /// copy self from other
211    pub fn copy(&self, other: &Self) {
212        let s = self.read();
213        let o = other.read();
214        *s.typ.write() = o.typ.read().clone();
215    }
216
217    pub fn normalize(&self) -> Self {
218        match &mut *self.read().typ.write() {
219            None => (),
220            Some(t) => {
221                *t = t.normalize();
222            }
223        }
224        self.clone()
225    }
226
227    pub fn unbind(&self) {
228        *self.read().typ.write() = None
229    }
230
231    pub(super) fn would_cycle(&self, t: &Type) -> bool {
232        let addr = Arc::as_ptr(&self.read().typ).addr();
233        would_cycle_inner(addr, t)
234    }
235
236    pub(super) fn addr(&self) -> usize {
237        Arc::as_ptr(&self.0).addr()
238    }
239
240    pub(super) fn inner_addr(&self) -> usize {
241        Arc::as_ptr(&self.read().typ).addr()
242    }
243}
244
245impl Type {
246    pub fn unfreeze_tvars(&self) {
247        match self {
248            Type::Bottom | Type::Any | Type::Primitive(_) => (),
249            Type::Ref(TypeRef { params, .. }) => {
250                for t in params.iter() {
251                    t.unfreeze_tvars();
252                }
253            }
254            Type::Error(t) => t.unfreeze_tvars(),
255            Type::Array(t) => t.unfreeze_tvars(),
256            Type::Map { key, value } => {
257                key.unfreeze_tvars();
258                value.unfreeze_tvars();
259            }
260            Type::ByRef(t) => t.unfreeze_tvars(),
261            Type::Tuple(ts) => {
262                for t in ts.iter() {
263                    t.unfreeze_tvars()
264                }
265            }
266            Type::Struct(ts) => {
267                for (_, t) in ts.iter() {
268                    t.unfreeze_tvars()
269                }
270            }
271            Type::Variant(_, ts) => {
272                for t in ts.iter() {
273                    t.unfreeze_tvars()
274                }
275            }
276            Type::TVar(tv) => tv.write().frozen = false,
277            Type::Fn(ft) => ft.unfreeze_tvars(),
278            Type::Set(s) => {
279                for typ in s.iter() {
280                    typ.unfreeze_tvars()
281                }
282            }
283            Type::Abstract { id: _, params } => {
284                for typ in params.iter() {
285                    typ.unfreeze_tvars()
286                }
287            }
288        }
289    }
290
291    /// alias type variables with the same name to each other
292    pub fn alias_tvars(&self, known: &mut AHashMap<ArcStr, TVar>) {
293        match self {
294            Type::Bottom | Type::Any | Type::Primitive(_) => (),
295            Type::Ref(TypeRef { params, .. }) => {
296                for t in params.iter() {
297                    t.alias_tvars(known);
298                }
299            }
300            Type::Error(t) => t.alias_tvars(known),
301            Type::Array(t) => t.alias_tvars(known),
302            Type::Map { key, value } => {
303                key.alias_tvars(known);
304                value.alias_tvars(known);
305            }
306            Type::ByRef(t) => t.alias_tvars(known),
307            Type::Tuple(ts) => {
308                for t in ts.iter() {
309                    t.alias_tvars(known)
310                }
311            }
312            Type::Struct(ts) => {
313                for (_, t) in ts.iter() {
314                    t.alias_tvars(known)
315                }
316            }
317            Type::Variant(_, ts) => {
318                for t in ts.iter() {
319                    t.alias_tvars(known)
320                }
321            }
322            Type::TVar(tv) => match known.entry(tv.name.clone()) {
323                Entry::Occupied(e) => {
324                    let v = e.get();
325                    v.freeze();
326                    tv.alias(v);
327                }
328                Entry::Vacant(e) => {
329                    e.insert(tv.clone());
330                    ()
331                }
332            },
333            Type::Fn(ft) => ft.alias_tvars(known),
334            Type::Set(s) => {
335                for typ in s.iter() {
336                    typ.alias_tvars(known)
337                }
338            }
339            Type::Abstract { id: _, params } => {
340                for typ in params.iter() {
341                    typ.alias_tvars(known)
342                }
343            }
344        }
345    }
346
347    pub fn collect_tvars(&self, known: &mut AHashMap<ArcStr, TVar>) {
348        match self {
349            Type::Bottom | Type::Any | Type::Primitive(_) => (),
350            Type::Ref(TypeRef { params, .. }) => {
351                for t in params.iter() {
352                    t.collect_tvars(known);
353                }
354            }
355            Type::Error(t) => t.collect_tvars(known),
356            Type::Array(t) => t.collect_tvars(known),
357            Type::Map { key, value } => {
358                key.collect_tvars(known);
359                value.collect_tvars(known);
360            }
361            Type::ByRef(t) => t.collect_tvars(known),
362            Type::Tuple(ts) => {
363                for t in ts.iter() {
364                    t.collect_tvars(known)
365                }
366            }
367            Type::Struct(ts) => {
368                for (_, t) in ts.iter() {
369                    t.collect_tvars(known)
370                }
371            }
372            Type::Variant(_, ts) => {
373                for t in ts.iter() {
374                    t.collect_tvars(known)
375                }
376            }
377            Type::TVar(tv) => match known.entry(tv.name.clone()) {
378                Entry::Occupied(_) => (),
379                Entry::Vacant(e) => {
380                    e.insert(tv.clone());
381                    ()
382                }
383            },
384            Type::Fn(ft) => ft.collect_tvars(known),
385            Type::Set(s) => {
386                for typ in s.iter() {
387                    typ.collect_tvars(known)
388                }
389            }
390            Type::Abstract { id: _, params } => {
391                for typ in params.iter() {
392                    typ.collect_tvars(known)
393                }
394            }
395        }
396    }
397
398    pub fn check_tvars_declared(&self, declared: &AHashSet<ArcStr>) -> Result<()> {
399        match self {
400            Type::Bottom | Type::Any | Type::Primitive(_) => Ok(()),
401            Type::Ref(TypeRef { params, .. }) => {
402                params.iter().try_for_each(|t| t.check_tvars_declared(declared))
403            }
404            Type::Error(t) => t.check_tvars_declared(declared),
405            Type::Array(t) => t.check_tvars_declared(declared),
406            Type::Map { key, value } => {
407                key.check_tvars_declared(declared)?;
408                value.check_tvars_declared(declared)
409            }
410            Type::ByRef(t) => t.check_tvars_declared(declared),
411            Type::Tuple(ts) => {
412                ts.iter().try_for_each(|t| t.check_tvars_declared(declared))
413            }
414            Type::Struct(ts) => {
415                ts.iter().try_for_each(|(_, t)| t.check_tvars_declared(declared))
416            }
417            Type::Variant(_, ts) => {
418                ts.iter().try_for_each(|t| t.check_tvars_declared(declared))
419            }
420            Type::TVar(tv) => {
421                if !declared.contains(&tv.name) {
422                    bail!("undeclared type variable '{}'", tv.name)
423                } else {
424                    Ok(())
425                }
426            }
427            Type::Set(s) => s.iter().try_for_each(|t| t.check_tvars_declared(declared)),
428            Type::Abstract { id: _, params } => {
429                params.iter().try_for_each(|t| t.check_tvars_declared(declared))
430            }
431            Type::Fn(_) => Ok(()),
432        }
433    }
434
435    pub fn has_unbound(&self) -> bool {
436        match self {
437            Type::Bottom | Type::Any | Type::Primitive(_) => false,
438            Type::Ref(TypeRef { .. }) => false,
439            Type::Error(e) => e.has_unbound(),
440            Type::Array(t0) => t0.has_unbound(),
441            Type::Map { key, value } => key.has_unbound() || value.has_unbound(),
442            Type::ByRef(t0) => t0.has_unbound(),
443            Type::Tuple(ts) => ts.iter().any(|t| t.has_unbound()),
444            Type::Struct(ts) => ts.iter().any(|(_, t)| t.has_unbound()),
445            Type::Variant(_, ts) => ts.iter().any(|t| t.has_unbound()),
446            Type::TVar(tv) => tv.read().typ.read().is_none(),
447            Type::Set(s) => s.iter().any(|t| t.has_unbound()),
448            Type::Abstract { id: _, params } => params.iter().any(|t| t.has_unbound()),
449            Type::Fn(ft) => ft.has_unbound(),
450        }
451    }
452
453    /// bind all unbound type variables to the specified type
454    pub fn bind_as(&self, t: &Self) {
455        match self {
456            Type::Bottom | Type::Any | Type::Primitive(_) => (),
457            Type::Ref(TypeRef { .. }) => (),
458            Type::Error(t0) => t0.bind_as(t),
459            Type::Array(t0) => t0.bind_as(t),
460            Type::Map { key, value } => {
461                key.bind_as(t);
462                value.bind_as(t);
463            }
464            Type::ByRef(t0) => t0.bind_as(t),
465            Type::Tuple(ts) => {
466                for elt in ts.iter() {
467                    elt.bind_as(t)
468                }
469            }
470            Type::Struct(ts) => {
471                for (_, elt) in ts.iter() {
472                    elt.bind_as(t)
473                }
474            }
475            Type::Variant(_, ts) => {
476                for elt in ts.iter() {
477                    elt.bind_as(t)
478                }
479            }
480            Type::TVar(tv) => {
481                let tv = tv.read();
482                let mut tv = tv.typ.write();
483                if tv.is_none() {
484                    *tv = Some(t.clone());
485                }
486            }
487            Type::Set(s) => {
488                for elt in s.iter() {
489                    elt.bind_as(t)
490                }
491            }
492            Type::Fn(ft) => ft.bind_as(t),
493            Type::Abstract { id: _, params } => {
494                for typ in params.iter() {
495                    typ.bind_as(t)
496                }
497            }
498        }
499    }
500
501    /// return a copy of self with all type variables unbound and
502    /// unaliased. self will not be modified
503    pub fn reset_tvars(&self) -> Type {
504        match self {
505            Type::Bottom => Type::Bottom,
506            Type::Any => Type::Any,
507            Type::Primitive(p) => Type::Primitive(*p),
508            Type::Ref(TypeRef { scope, name, params, .. }) => Type::Ref(TypeRef {
509                scope: scope.clone(),
510                name: name.clone(),
511                params: Arc::from_iter(params.iter().map(|t| t.reset_tvars())),
512                ..Default::default()
513            }),
514            Type::Error(t0) => Type::Error(Arc::new(t0.reset_tvars())),
515            Type::Array(t0) => Type::Array(Arc::new(t0.reset_tvars())),
516            Type::Map { key, value } => {
517                let key = Arc::new(key.reset_tvars());
518                let value = Arc::new(value.reset_tvars());
519                Type::Map { key, value }
520            }
521            Type::ByRef(t0) => Type::ByRef(Arc::new(t0.reset_tvars())),
522            Type::Tuple(ts) => {
523                Type::Tuple(Arc::from_iter(ts.iter().map(|t| t.reset_tvars())))
524            }
525            Type::Struct(ts) => Type::Struct(Arc::from_iter(
526                ts.iter().map(|(n, t)| (n.clone(), t.reset_tvars())),
527            )),
528            Type::Variant(tag, ts) => Type::Variant(
529                tag.clone(),
530                Arc::from_iter(ts.iter().map(|t| t.reset_tvars())),
531            ),
532            Type::TVar(tv) => Type::TVar(TVar::empty_named(tv.name.clone())),
533            Type::Set(s) => Type::Set(Arc::from_iter(s.iter().map(|t| t.reset_tvars()))),
534            Type::Fn(fntyp) => Type::Fn(Arc::new(fntyp.reset_tvars())),
535            Type::Abstract { id, params } => Type::Abstract {
536                id: *id,
537                params: Arc::from_iter(params.iter().map(|t| t.reset_tvars())),
538            },
539        }
540    }
541
542    /// return a copy of self with every TVar named in known replaced
543    /// with the corresponding type. TVars not in known are replaced with
544    /// fresh TVars using unique names to avoid entanglement with the caller's
545    /// TVars that happen to share the same name.
546    pub fn replace_tvars(&self, known: &AHashMap<ArcStr, Self>) -> Type {
547        use poolshark::local::LPooled;
548        self.replace_tvars_int(known, &mut LPooled::take())
549    }
550
551    pub(super) fn replace_tvars_int(
552        &self,
553        known: &AHashMap<ArcStr, Self>,
554        renamed: &mut AHashMap<ArcStr, TVar>,
555    ) -> Type {
556        match self {
557            Type::TVar(tv) => match known.get(&tv.name) {
558                Some(t) => t.clone(),
559                None => {
560                    let fresh =
561                        renamed.entry(tv.name.clone()).or_insert_with(TVar::default);
562                    Type::TVar(fresh.clone())
563                }
564            },
565            Type::Bottom => Type::Bottom,
566            Type::Any => Type::Any,
567            Type::Primitive(p) => Type::Primitive(*p),
568            Type::Ref(TypeRef { scope, name, params, .. }) => Type::Ref(TypeRef {
569                scope: scope.clone(),
570                name: name.clone(),
571                params: Arc::from_iter(
572                    params.iter().map(|t| t.replace_tvars_int(known, renamed)),
573                ),
574                ..Default::default()
575            }),
576            Type::Error(t0) => {
577                Type::Error(Arc::new(t0.replace_tvars_int(known, renamed)))
578            }
579            Type::Array(t0) => {
580                Type::Array(Arc::new(t0.replace_tvars_int(known, renamed)))
581            }
582            Type::Map { key, value } => {
583                let key = Arc::new(key.replace_tvars_int(known, renamed));
584                let value = Arc::new(value.replace_tvars_int(known, renamed));
585                Type::Map { key, value }
586            }
587            Type::ByRef(t0) => {
588                Type::ByRef(Arc::new(t0.replace_tvars_int(known, renamed)))
589            }
590            Type::Tuple(ts) => Type::Tuple(Arc::from_iter(
591                ts.iter().map(|t| t.replace_tvars_int(known, renamed)),
592            )),
593            Type::Struct(ts) => Type::Struct(Arc::from_iter(
594                ts.iter().map(|(n, t)| (n.clone(), t.replace_tvars_int(known, renamed))),
595            )),
596            Type::Variant(tag, ts) => Type::Variant(
597                tag.clone(),
598                Arc::from_iter(ts.iter().map(|t| t.replace_tvars_int(known, renamed))),
599            ),
600            Type::Set(s) => Type::Set(Arc::from_iter(
601                s.iter().map(|t| t.replace_tvars_int(known, renamed)),
602            )),
603            Type::Fn(fntyp) => {
604                Type::Fn(Arc::new(fntyp.replace_tvars_int(known, renamed)))
605            }
606            Type::Abstract { id, params } => Type::Abstract {
607                id: *id,
608                params: Arc::from_iter(
609                    params.iter().map(|t| t.replace_tvars_int(known, renamed)),
610                ),
611            },
612        }
613    }
614
615    /// Unbind any bound tvars, but do not unalias them.
616    pub(crate) fn unbind_tvars(&self) {
617        match self {
618            Type::Bottom | Type::Any | Type::Primitive(_) | Type::Ref(TypeRef { .. }) => {
619                ()
620            }
621            Type::Error(t0) => t0.unbind_tvars(),
622            Type::Array(t0) => t0.unbind_tvars(),
623            Type::Map { key, value } => {
624                key.unbind_tvars();
625                value.unbind_tvars();
626            }
627            Type::ByRef(t0) => t0.unbind_tvars(),
628            Type::Tuple(ts)
629            | Type::Variant(_, ts)
630            | Type::Set(ts)
631            | Type::Abstract { id: _, params: ts } => {
632                for t in ts.iter() {
633                    t.unbind_tvars()
634                }
635            }
636            Type::Struct(ts) => {
637                for (_, t) in ts.iter() {
638                    t.unbind_tvars()
639                }
640            }
641            Type::TVar(tv) => tv.unbind(),
642            Type::Fn(fntyp) => fntyp.unbind_tvars(),
643        }
644    }
645}