// graphix_compiler/typ/tvar.rs

1use crate::{
2    expr::ModPath,
3    typ::{FnType, PrintFlag, Type, PRINT_FLAGS},
4};
5use anyhow::{bail, Result};
6use arcstr::ArcStr;
7use compact_str::format_compact;
8use fxhash::{FxHashMap, FxHashSet};
9use parking_lot::{RwLock, RwLockReadGuard, RwLockWriteGuard};
10use std::{
11    cmp::{Eq, PartialEq},
12    collections::hash_map::Entry,
13    fmt::{self, Debug},
14    hash::Hash,
15    ops::Deref,
16};
17use triomphe::Arc;
18
// Declares `TVarId`, the identifier type carried by every type variable.
// The macro is defined elsewhere. Judging by its uses in this file
// (`TVarId::new()`, the `.0` field, and being copied out of a lock guard in
// `TVar::alias`), it is a `Copy` newtype whose `new()` hands out a fresh id
// on each call -- NOTE(review): confirm against the macro definition.
atomic_id!(TVarId);
20
21pub(super) fn would_cycle_inner(addr: usize, t: &Type) -> bool {
22    match t {
23        Type::Primitive(_) | Type::Any | Type::Bottom | Type::Ref { .. } => false,
24        Type::TVar(t) => {
25            Arc::as_ptr(&t.read().typ).addr() == addr
26                || match &*t.read().typ.read() {
27                    None => false,
28                    Some(t) => would_cycle_inner(addr, t),
29                }
30        }
31        Type::Abstract { id: _, params } => {
32            params.iter().any(|t| would_cycle_inner(addr, t))
33        }
34        Type::Error(t) => would_cycle_inner(addr, t),
35        Type::Array(a) => would_cycle_inner(addr, &**a),
36        Type::Map { key, value } => {
37            would_cycle_inner(addr, &**key) || would_cycle_inner(addr, &**value)
38        }
39        Type::ByRef(t) => would_cycle_inner(addr, t),
40        Type::Tuple(ts) => ts.iter().any(|t| would_cycle_inner(addr, t)),
41        Type::Variant(_, ts) => ts.iter().any(|t| would_cycle_inner(addr, t)),
42        Type::Struct(ts) => ts.iter().any(|(_, t)| would_cycle_inner(addr, t)),
43        Type::Set(s) => s.iter().any(|t| would_cycle_inner(addr, t)),
44        Type::Fn(f) => {
45            let FnType { args, vargs, rtype, constraints, throws, explicit_throws: _ } =
46                &**f;
47            args.iter().any(|t| would_cycle_inner(addr, &t.typ))
48                || match vargs {
49                    None => false,
50                    Some(t) => would_cycle_inner(addr, t),
51                }
52                || would_cycle_inner(addr, rtype)
53                || constraints.read().iter().any(|a| {
54                    Arc::as_ptr(&a.0.read().typ).addr() == addr
55                        || would_cycle_inner(addr, &a.1)
56                })
57                || would_cycle_inner(addr, &throws)
58        }
59    }
60}
61
/// Mutable state of a type variable. It is stored behind the
/// `RwLock` in `TVarInner::typ`.
#[derive(Debug)]
pub struct TVarInnerInner {
    /// Identifier allocated when the variable is created. `TVar::alias`
    /// overwrites it with the id of the variable being aliased.
    pub(crate) id: TVarId,
    /// While set, `TVar::alias` is a no-op. It is set by `TVar::alias` and
    /// `TVar::freeze`, and cleared by `Type::unfreeze_tvars`.
    pub(crate) frozen: bool,
    /// The binding cell. `None` means the variable is unbound.
    ///
    /// `TVar::alias` makes two variables share the same `Arc`, so binding one
    /// binds both. The cell's address serves as the variable's identity in
    /// `would_cycle_inner` and in the equality/ordering impls.
    pub(crate) typ: Arc<RwLock<Option<Type>>>,
}
68
/// Shared body of a [`TVar`]: a fixed name plus lock-protected state.
#[derive(Debug)]
pub struct TVarInner {
    /// The variable's name. `Display` prints it as `'name`.
    /// `alias_tvars`, `collect_tvars` and `replace_tvars` key on it.
    pub name: ArcStr,
    /// Lock-protected mutable state: id, frozen flag and binding cell.
    pub(crate) typ: RwLock<TVarInnerInner>,
}
74
/// Reference-counted handle to a type variable.
///
/// Cloning a `TVar` shares the same underlying `TVarInner`; it does not
/// create a new variable. `Type::reset_tvars` produces fresh, unbound
/// variables instead.
#[derive(Debug, Clone)]
pub struct TVar(Arc<TVarInner>);
77
78impl fmt::Display for TVar {
79    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
80        if !PRINT_FLAGS.get().contains(PrintFlag::DerefTVars) {
81            write!(f, "'{}", self.name)
82        } else {
83            write!(f, "'{}: ", self.name)?;
84            match &*self.read().typ.read() {
85                Some(t) => write!(f, "{t}"),
86                None => write!(f, "unbound"),
87            }
88        }
89    }
90}
91
92impl Default for TVar {
93    fn default() -> Self {
94        Self::empty_named(ArcStr::from(format_compact!("_{}", TVarId::new().0).as_str()))
95    }
96}
97
98impl Deref for TVar {
99    type Target = TVarInner;
100
101    fn deref(&self) -> &Self::Target {
102        &*self.0
103    }
104}
105
106impl PartialEq for TVar {
107    fn eq(&self, other: &Self) -> bool {
108        let t0 = self.read();
109        let t1 = other.read();
110        t0.typ.as_ptr().addr() == t1.typ.as_ptr().addr() || {
111            let t0 = t0.typ.read();
112            let t1 = t1.typ.read();
113            *t0 == *t1
114        }
115    }
116}
117
// Declares `TVar`'s `PartialEq` to be a full equivalence relation.
// NOTE(review): this relies on `Type`'s own equality being reflexive;
// `Type`'s impl is not visible here -- confirm.
impl Eq for TVar {}
119
120impl PartialOrd for TVar {
121    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
122        let t0 = self.read();
123        let t1 = other.read();
124        if t0.typ.as_ptr().addr() == t1.typ.as_ptr().addr() {
125            Some(std::cmp::Ordering::Equal)
126        } else {
127            let t0 = t0.typ.read();
128            let t1 = t1.typ.read();
129            t0.partial_cmp(&*t1)
130        }
131    }
132}
133
134impl Ord for TVar {
135    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
136        let t0 = self.read();
137        let t1 = other.read();
138        if t0.typ.as_ptr().addr() == t1.typ.as_ptr().addr() {
139            std::cmp::Ordering::Equal
140        } else {
141            let t0 = t0.typ.read();
142            let t1 = t1.typ.read();
143            t0.cmp(&*t1)
144        }
145    }
146}
147
148impl TVar {
149    pub fn scope_refs(&self, scope: &ModPath) -> Self {
150        match Type::TVar(self.clone()).scope_refs(scope) {
151            Type::TVar(tv) => tv,
152            _ => unreachable!(),
153        }
154    }
155
156    pub fn empty_named(name: ArcStr) -> Self {
157        Self(Arc::new(TVarInner {
158            name,
159            typ: RwLock::new(TVarInnerInner {
160                id: TVarId::new(),
161                frozen: false,
162                typ: Arc::new(RwLock::new(None)),
163            }),
164        }))
165    }
166
167    pub fn named(name: ArcStr, typ: Type) -> Self {
168        Self(Arc::new(TVarInner {
169            name,
170            typ: RwLock::new(TVarInnerInner {
171                id: TVarId::new(),
172                frozen: false,
173                typ: Arc::new(RwLock::new(Some(typ))),
174            }),
175        }))
176    }
177
178    pub fn read<'a>(&'a self) -> RwLockReadGuard<'a, TVarInnerInner> {
179        self.typ.read()
180    }
181
182    pub fn write<'a>(&'a self) -> RwLockWriteGuard<'a, TVarInnerInner> {
183        self.typ.write()
184    }
185
186    /// make self an alias for other
187    pub fn alias(&self, other: &Self) {
188        let mut s = self.write();
189        if !s.frozen {
190            s.frozen = true;
191            let o = other.read();
192            s.id = o.id;
193            s.typ = Arc::clone(&o.typ);
194        }
195    }
196
197    pub fn freeze(&self) {
198        self.write().frozen = true;
199    }
200
201    /// copy self from other
202    pub fn copy(&self, other: &Self) {
203        let s = self.read();
204        let o = other.read();
205        *s.typ.write() = o.typ.read().clone();
206    }
207
208    pub fn normalize(&self) -> Self {
209        match &mut *self.read().typ.write() {
210            None => (),
211            Some(t) => {
212                *t = t.normalize();
213            }
214        }
215        self.clone()
216    }
217
218    pub fn unbind(&self) {
219        *self.read().typ.write() = None
220    }
221
222    pub(super) fn would_cycle(&self, t: &Type) -> bool {
223        let addr = Arc::as_ptr(&self.read().typ).addr();
224        would_cycle_inner(addr, t)
225    }
226
227    pub(super) fn addr(&self) -> usize {
228        Arc::as_ptr(&self.0).addr()
229    }
230
231    pub(super) fn inner_addr(&self) -> usize {
232        Arc::as_ptr(&self.read().typ).addr()
233    }
234}
235
236impl Type {
237    pub fn unfreeze_tvars(&self) {
238        match self {
239            Type::Bottom | Type::Any | Type::Primitive(_) => (),
240            Type::Ref { params, .. } => {
241                for t in params.iter() {
242                    t.unfreeze_tvars();
243                }
244            }
245            Type::Error(t) => t.unfreeze_tvars(),
246            Type::Array(t) => t.unfreeze_tvars(),
247            Type::Map { key, value } => {
248                key.unfreeze_tvars();
249                value.unfreeze_tvars();
250            }
251            Type::ByRef(t) => t.unfreeze_tvars(),
252            Type::Tuple(ts) => {
253                for t in ts.iter() {
254                    t.unfreeze_tvars()
255                }
256            }
257            Type::Struct(ts) => {
258                for (_, t) in ts.iter() {
259                    t.unfreeze_tvars()
260                }
261            }
262            Type::Variant(_, ts) => {
263                for t in ts.iter() {
264                    t.unfreeze_tvars()
265                }
266            }
267            Type::TVar(tv) => tv.write().frozen = false,
268            Type::Fn(ft) => ft.unfreeze_tvars(),
269            Type::Set(s) => {
270                for typ in s.iter() {
271                    typ.unfreeze_tvars()
272                }
273            }
274            Type::Abstract { id: _, params } => {
275                for typ in params.iter() {
276                    typ.unfreeze_tvars()
277                }
278            }
279        }
280    }
281
282    /// alias type variables with the same name to each other
283    pub fn alias_tvars(&self, known: &mut FxHashMap<ArcStr, TVar>) {
284        match self {
285            Type::Bottom | Type::Any | Type::Primitive(_) => (),
286            Type::Ref { params, .. } => {
287                for t in params.iter() {
288                    t.alias_tvars(known);
289                }
290            }
291            Type::Error(t) => t.alias_tvars(known),
292            Type::Array(t) => t.alias_tvars(known),
293            Type::Map { key, value } => {
294                key.alias_tvars(known);
295                value.alias_tvars(known);
296            }
297            Type::ByRef(t) => t.alias_tvars(known),
298            Type::Tuple(ts) => {
299                for t in ts.iter() {
300                    t.alias_tvars(known)
301                }
302            }
303            Type::Struct(ts) => {
304                for (_, t) in ts.iter() {
305                    t.alias_tvars(known)
306                }
307            }
308            Type::Variant(_, ts) => {
309                for t in ts.iter() {
310                    t.alias_tvars(known)
311                }
312            }
313            Type::TVar(tv) => match known.entry(tv.name.clone()) {
314                Entry::Occupied(e) => {
315                    let v = e.get();
316                    v.freeze();
317                    tv.alias(v);
318                }
319                Entry::Vacant(e) => {
320                    e.insert(tv.clone());
321                    ()
322                }
323            },
324            Type::Fn(ft) => ft.alias_tvars(known),
325            Type::Set(s) => {
326                for typ in s.iter() {
327                    typ.alias_tvars(known)
328                }
329            }
330            Type::Abstract { id: _, params } => {
331                for typ in params.iter() {
332                    typ.alias_tvars(known)
333                }
334            }
335        }
336    }
337
338    pub fn collect_tvars(&self, known: &mut FxHashMap<ArcStr, TVar>) {
339        match self {
340            Type::Bottom | Type::Any | Type::Primitive(_) => (),
341            Type::Ref { params, .. } => {
342                for t in params.iter() {
343                    t.collect_tvars(known);
344                }
345            }
346            Type::Error(t) => t.collect_tvars(known),
347            Type::Array(t) => t.collect_tvars(known),
348            Type::Map { key, value } => {
349                key.collect_tvars(known);
350                value.collect_tvars(known);
351            }
352            Type::ByRef(t) => t.collect_tvars(known),
353            Type::Tuple(ts) => {
354                for t in ts.iter() {
355                    t.collect_tvars(known)
356                }
357            }
358            Type::Struct(ts) => {
359                for (_, t) in ts.iter() {
360                    t.collect_tvars(known)
361                }
362            }
363            Type::Variant(_, ts) => {
364                for t in ts.iter() {
365                    t.collect_tvars(known)
366                }
367            }
368            Type::TVar(tv) => match known.entry(tv.name.clone()) {
369                Entry::Occupied(_) => (),
370                Entry::Vacant(e) => {
371                    e.insert(tv.clone());
372                    ()
373                }
374            },
375            Type::Fn(ft) => ft.collect_tvars(known),
376            Type::Set(s) => {
377                for typ in s.iter() {
378                    typ.collect_tvars(known)
379                }
380            }
381            Type::Abstract { id: _, params } => {
382                for typ in params.iter() {
383                    typ.collect_tvars(known)
384                }
385            }
386        }
387    }
388
389    pub fn check_tvars_declared(&self, declared: &FxHashSet<ArcStr>) -> Result<()> {
390        match self {
391            Type::Bottom | Type::Any | Type::Primitive(_) => Ok(()),
392            Type::Ref { params, .. } => {
393                params.iter().try_for_each(|t| t.check_tvars_declared(declared))
394            }
395            Type::Error(t) => t.check_tvars_declared(declared),
396            Type::Array(t) => t.check_tvars_declared(declared),
397            Type::Map { key, value } => {
398                key.check_tvars_declared(declared)?;
399                value.check_tvars_declared(declared)
400            }
401            Type::ByRef(t) => t.check_tvars_declared(declared),
402            Type::Tuple(ts) => {
403                ts.iter().try_for_each(|t| t.check_tvars_declared(declared))
404            }
405            Type::Struct(ts) => {
406                ts.iter().try_for_each(|(_, t)| t.check_tvars_declared(declared))
407            }
408            Type::Variant(_, ts) => {
409                ts.iter().try_for_each(|t| t.check_tvars_declared(declared))
410            }
411            Type::TVar(tv) => {
412                if !declared.contains(&tv.name) {
413                    bail!("undeclared type variable '{}'", tv.name)
414                } else {
415                    Ok(())
416                }
417            }
418            Type::Set(s) => s.iter().try_for_each(|t| t.check_tvars_declared(declared)),
419            Type::Abstract { id: _, params } => {
420                params.iter().try_for_each(|t| t.check_tvars_declared(declared))
421            }
422            Type::Fn(_) => Ok(()),
423        }
424    }
425
426    pub fn has_unbound(&self) -> bool {
427        match self {
428            Type::Bottom | Type::Any | Type::Primitive(_) => false,
429            Type::Ref { .. } => false,
430            Type::Error(e) => e.has_unbound(),
431            Type::Array(t0) => t0.has_unbound(),
432            Type::Map { key, value } => key.has_unbound() || value.has_unbound(),
433            Type::ByRef(t0) => t0.has_unbound(),
434            Type::Tuple(ts) => ts.iter().any(|t| t.has_unbound()),
435            Type::Struct(ts) => ts.iter().any(|(_, t)| t.has_unbound()),
436            Type::Variant(_, ts) => ts.iter().any(|t| t.has_unbound()),
437            Type::TVar(tv) => tv.read().typ.read().is_some(),
438            Type::Set(s) => s.iter().any(|t| t.has_unbound()),
439            Type::Abstract { id: _, params } => params.iter().any(|t| t.has_unbound()),
440            Type::Fn(ft) => ft.has_unbound(),
441        }
442    }
443
444    /// bind all unbound type variables to the specified type
445    pub fn bind_as(&self, t: &Self) {
446        match self {
447            Type::Bottom | Type::Any | Type::Primitive(_) => (),
448            Type::Ref { .. } => (),
449            Type::Error(t0) => t0.bind_as(t),
450            Type::Array(t0) => t0.bind_as(t),
451            Type::Map { key, value } => {
452                key.bind_as(t);
453                value.bind_as(t);
454            }
455            Type::ByRef(t0) => t0.bind_as(t),
456            Type::Tuple(ts) => {
457                for elt in ts.iter() {
458                    elt.bind_as(t)
459                }
460            }
461            Type::Struct(ts) => {
462                for (_, elt) in ts.iter() {
463                    elt.bind_as(t)
464                }
465            }
466            Type::Variant(_, ts) => {
467                for elt in ts.iter() {
468                    elt.bind_as(t)
469                }
470            }
471            Type::TVar(tv) => {
472                let tv = tv.read();
473                let mut tv = tv.typ.write();
474                if tv.is_none() {
475                    *tv = Some(t.clone());
476                }
477            }
478            Type::Set(s) => {
479                for elt in s.iter() {
480                    elt.bind_as(t)
481                }
482            }
483            Type::Fn(ft) => ft.bind_as(t),
484            Type::Abstract { id: _, params } => {
485                for typ in params.iter() {
486                    typ.bind_as(t)
487                }
488            }
489        }
490    }
491
492    /// return a copy of self with all type variables unbound and
493    /// unaliased. self will not be modified
494    pub fn reset_tvars(&self) -> Type {
495        match self {
496            Type::Bottom => Type::Bottom,
497            Type::Any => Type::Any,
498            Type::Primitive(p) => Type::Primitive(*p),
499            Type::Ref { scope, name, params } => Type::Ref {
500                scope: scope.clone(),
501                name: name.clone(),
502                params: Arc::from_iter(params.iter().map(|t| t.reset_tvars())),
503            },
504            Type::Error(t0) => Type::Error(Arc::new(t0.reset_tvars())),
505            Type::Array(t0) => Type::Array(Arc::new(t0.reset_tvars())),
506            Type::Map { key, value } => {
507                let key = Arc::new(key.reset_tvars());
508                let value = Arc::new(value.reset_tvars());
509                Type::Map { key, value }
510            }
511            Type::ByRef(t0) => Type::ByRef(Arc::new(t0.reset_tvars())),
512            Type::Tuple(ts) => {
513                Type::Tuple(Arc::from_iter(ts.iter().map(|t| t.reset_tvars())))
514            }
515            Type::Struct(ts) => Type::Struct(Arc::from_iter(
516                ts.iter().map(|(n, t)| (n.clone(), t.reset_tvars())),
517            )),
518            Type::Variant(tag, ts) => Type::Variant(
519                tag.clone(),
520                Arc::from_iter(ts.iter().map(|t| t.reset_tvars())),
521            ),
522            Type::TVar(tv) => Type::TVar(TVar::empty_named(tv.name.clone())),
523            Type::Set(s) => Type::Set(Arc::from_iter(s.iter().map(|t| t.reset_tvars()))),
524            Type::Fn(fntyp) => Type::Fn(Arc::new(fntyp.reset_tvars())),
525            Type::Abstract { id, params } => Type::Abstract {
526                id: *id,
527                params: Arc::from_iter(params.iter().map(|t| t.reset_tvars())),
528            },
529        }
530    }
531
532    /// return a copy of self with every TVar named in known replaced
533    /// with the corresponding type
534    pub fn replace_tvars(&self, known: &FxHashMap<ArcStr, Self>) -> Type {
535        match self {
536            Type::TVar(tv) => match known.get(&tv.name) {
537                Some(t) => t.clone(),
538                None => Type::TVar(tv.clone()),
539            },
540            Type::Bottom => Type::Bottom,
541            Type::Any => Type::Any,
542            Type::Primitive(p) => Type::Primitive(*p),
543            Type::Ref { scope, name, params } => Type::Ref {
544                scope: scope.clone(),
545                name: name.clone(),
546                params: Arc::from_iter(params.iter().map(|t| t.replace_tvars(known))),
547            },
548            Type::Error(t0) => Type::Error(Arc::new(t0.replace_tvars(known))),
549            Type::Array(t0) => Type::Array(Arc::new(t0.replace_tvars(known))),
550            Type::Map { key, value } => {
551                let key = Arc::new(key.replace_tvars(known));
552                let value = Arc::new(value.replace_tvars(known));
553                Type::Map { key, value }
554            }
555            Type::ByRef(t0) => Type::ByRef(Arc::new(t0.replace_tvars(known))),
556            Type::Tuple(ts) => {
557                Type::Tuple(Arc::from_iter(ts.iter().map(|t| t.replace_tvars(known))))
558            }
559            Type::Struct(ts) => Type::Struct(Arc::from_iter(
560                ts.iter().map(|(n, t)| (n.clone(), t.replace_tvars(known))),
561            )),
562            Type::Variant(tag, ts) => Type::Variant(
563                tag.clone(),
564                Arc::from_iter(ts.iter().map(|t| t.replace_tvars(known))),
565            ),
566            Type::Set(s) => {
567                Type::Set(Arc::from_iter(s.iter().map(|t| t.replace_tvars(known))))
568            }
569            Type::Fn(fntyp) => Type::Fn(Arc::new(fntyp.replace_tvars(known))),
570            Type::Abstract { id, params } => Type::Abstract {
571                id: *id,
572                params: Arc::from_iter(params.iter().map(|t| t.replace_tvars(known))),
573            },
574        }
575    }
576
577    /// Unbind any bound tvars, but do not unalias them.
578    pub(crate) fn unbind_tvars(&self) {
579        match self {
580            Type::Bottom | Type::Any | Type::Primitive(_) | Type::Ref { .. } => (),
581            Type::Error(t0) => t0.unbind_tvars(),
582            Type::Array(t0) => t0.unbind_tvars(),
583            Type::Map { key, value } => {
584                key.unbind_tvars();
585                value.unbind_tvars();
586            }
587            Type::ByRef(t0) => t0.unbind_tvars(),
588            Type::Tuple(ts)
589            | Type::Variant(_, ts)
590            | Type::Set(ts)
591            | Type::Abstract { id: _, params: ts } => {
592                for t in ts.iter() {
593                    t.unbind_tvars()
594                }
595            }
596            Type::Struct(ts) => {
597                for (_, t) in ts.iter() {
598                    t.unbind_tvars()
599                }
600            }
601            Type::TVar(tv) => tv.unbind(),
602            Type::Fn(fntyp) => fntyp.unbind_tvars(),
603        }
604    }
605}