graphix_compiler/typ/
mod.rs

1use crate::{
2    env::Env, expr::ModPath, format_with_flags, PrintFlag, Rt, UserEvent, PRINT_FLAGS,
3};
4use anyhow::{anyhow, bail, Result};
5use arcstr::ArcStr;
6use enumflags2::BitFlags;
7use fxhash::{FxHashMap, FxHashSet};
8use netidx::{
9    publisher::{Typ, Value},
10    utils::Either,
11};
12use netidx_value::ValArray;
13use parking_lot::RwLock;
14use smallvec::{smallvec, SmallVec};
15use std::{
16    cell::RefCell,
17    cmp::{Eq, PartialEq},
18    collections::{hash_map::Entry, HashMap, HashSet},
19    fmt::{self, Debug},
20    iter,
21};
22use triomphe::Arc;
23
24mod fntyp;
25mod tval;
26mod tvar;
27
28pub use fntyp::{FnArgType, FnType};
29pub use tval::TVal;
30use tvar::would_cycle_inner;
31pub use tvar::TVar;
32
/// Boolean "and" accumulator usable as a `collect` target.
///
/// Lets a sequence of fallible boolean checks be combined with
/// `collect::<Result<AndAc>>()?`. Collection stops pulling items at the
/// first `false`, so later items of a lazy iterator are never evaluated.
struct AndAc(bool);

impl FromIterator<bool> for AndAc {
    fn from_iter<T: IntoIterator<Item = bool>>(iter: T) -> Self {
        for b in iter {
            if !b {
                // short-circuit: do not consume the rest of the iterator
                return AndAc(false);
            }
        }
        AndAc(true)
    }
}
40
/// A type in the graphix type system.
///
/// NOTE(review): `PartialOrd`/`Ord` are derived, so the order in which the
/// variants (and the fields of `Ref`) are declared determines how types
/// compare; keep that in mind before reordering anything here.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub enum Type {
    /// The bottom type. `contains` treats it as both containing and being
    /// contained by every type; it is also what `Type::default()` returns.
    Bottom,
    /// The top type; `contains` treats it as containing every type.
    Any,
    /// A set of primitive value types. An empty set is used as the empty
    /// type (e.g. the result of a `diff` that removes everything).
    Primitive(BitFlags<Typ>),
    /// A reference to the typedef `name` in `scope`, applied to `params`;
    /// resolved with `Type::lookup_ref`.
    Ref { scope: ModPath, name: ModPath, params: Arc<[Type]> },
    /// A function type.
    Fn(Arc<FnType>),
    /// A union of the listed member types.
    Set(Arc<[Type]>),
    /// A type variable, which may or may not be bound to a type.
    TVar(TVar),
    /// An array whose elements have the inner type.
    Array(Arc<Type>),
    /// A by-reference wrapper around the inner type.
    ByRef(Arc<Type>),
    /// A fixed-length tuple with positional element types.
    Tuple(Arc<[Type]>),
    /// A struct as `(field name, field type)` pairs; the `contains` logic
    /// assumes these are sorted by field name.
    Struct(Arc<[(ArcStr, Type)]>),
    /// A variant: a tag followed by positional argument types.
    Variant(ArcStr, Arc<[Type]>),
}
56
57impl Default for Type {
58    fn default() -> Self {
59        Self::Bottom
60    }
61}
62
63impl Type {
64    pub fn empty_tvar() -> Self {
65        Type::TVar(TVar::default())
66    }
67
68    fn iter_prims(&self) -> impl Iterator<Item = Self> {
69        match self {
70            Self::Primitive(p) => {
71                Either::Left(p.iter().map(|t| Type::Primitive(t.into())))
72            }
73            t => Either::Right(iter::once(t.clone())),
74        }
75    }
76
77    pub fn is_defined(&self) -> bool {
78        match self {
79            Self::Bottom
80            | Self::Any
81            | Self::Primitive(_)
82            | Self::Fn(_)
83            | Self::Set(_)
84            | Self::Array(_)
85            | Self::ByRef(_)
86            | Self::Tuple(_)
87            | Self::Struct(_)
88            | Self::Variant(_, _) => true,
89            Self::TVar(tv) => tv.read().typ.read().is_some(),
90            Self::Ref { .. } => true,
91        }
92    }
93
94    pub fn lookup_ref<'a, R: Rt, E: UserEvent>(
95        &'a self,
96        env: &'a Env<R, E>,
97    ) -> Result<&'a Type> {
98        match self {
99            Self::Ref { scope, name, params } => {
100                let def = env
101                    .lookup_typedef(scope, name)
102                    .ok_or_else(|| anyhow!("undefined type {name} in {scope}"))?;
103                if def.params.len() != params.len() {
104                    bail!("{} expects {} type parameters", name, def.params.len());
105                }
106                def.typ.unbind_tvars();
107                for ((tv, ct), arg) in def.params.iter().zip(params.iter()) {
108                    if let Some(ct) = ct {
109                        ct.check_contains(env, arg)?;
110                    }
111                    if !tv.would_cycle(arg) {
112                        *tv.read().typ.write() = Some(arg.clone());
113                    }
114                }
115                Ok(&def.typ)
116            }
117            t => Ok(t),
118        }
119    }
120
121    pub fn check_contains<R: Rt, E: UserEvent>(
122        &self,
123        env: &Env<R, E>,
124        t: &Self,
125    ) -> Result<()> {
126        if self.contains(env, t)? {
127            Ok(())
128        } else {
129            format_with_flags(PrintFlag::DerefTVars | PrintFlag::ReplacePrims, || {
130                bail!("type mismatch {self} does not contain {t}")
131            })
132        }
133    }
134
135    fn contains_int<R: Rt, E: UserEvent>(
136        &self,
137        env: &Env<R, E>,
138        hist: &mut FxHashMap<(usize, usize), bool>,
139        t: &Self,
140    ) -> Result<bool> {
141        if (self as *const Type) == (t as *const Type) {
142            return Ok(true);
143        }
144        match (self, t) {
145            (
146                Self::Ref { scope: s0, name: n0, .. },
147                Self::Ref { scope: s1, name: n1, .. },
148            ) if s0 == s1 && n0 == n1 => Ok(true),
149            (t0 @ Self::Ref { .. }, t1) | (t0, t1 @ Self::Ref { .. }) => {
150                let t0 = t0.lookup_ref(env)?;
151                let t1 = t1.lookup_ref(env)?;
152                let t0_addr = (t0 as *const Type).addr();
153                let t1_addr = (t1 as *const Type).addr();
154                match hist.get(&(t0_addr, t1_addr)) {
155                    Some(r) => Ok(*r),
156                    None => {
157                        hist.insert((t0_addr, t1_addr), true);
158                        match t0.contains_int(env, hist, t1) {
159                            Ok(r) => {
160                                hist.insert((t0_addr, t1_addr), r);
161                                Ok(r)
162                            }
163                            Err(e) => {
164                                hist.remove(&(t0_addr, t1_addr));
165                                Err(e)
166                            }
167                        }
168                    }
169                }
170            }
171            (Self::TVar(t0), Self::Bottom) => {
172                if let Some(_) = &*t0.read().typ.read() {
173                    return Ok(true);
174                }
175                *t0.read().typ.write() = Some(Self::Bottom);
176                Ok(true)
177            }
178            (Self::TVar(t0), Self::Any) => {
179                if let Some(t0) = &*t0.read().typ.read() {
180                    return t0.contains_int(env, hist, t);
181                }
182                *t0.read().typ.write() = Some(Self::Any);
183                Ok(true)
184            }
185            (Self::Any, _) => Ok(true),
186            (Self::Bottom, _) | (_, Self::Bottom) => Ok(true),
187            (Self::Primitive(p0), Self::Primitive(p1)) => Ok(p0.contains(*p1)),
188            (
189                Self::Primitive(p),
190                Self::Array(_) | Self::Tuple(_) | Self::Struct(_) | Self::Variant(_, _),
191            ) => Ok(p.contains(Typ::Array)),
192            (Self::Array(t0), Self::Array(t1)) => t0.contains_int(env, hist, t1),
193            (Self::Array(t0), Self::Primitive(p)) if *p == BitFlags::from(Typ::Array) => {
194                t0.contains_int(env, hist, &Type::Primitive(BitFlags::all()))
195            }
196            (Self::Tuple(t0), Self::Tuple(t1))
197                if t0.as_ptr().addr() == t1.as_ptr().addr() =>
198            {
199                Ok(true)
200            }
201            (Self::Tuple(t0), Self::Tuple(t1)) => Ok(t0.len() == t1.len()
202                && t0
203                    .iter()
204                    .zip(t1.iter())
205                    .map(|(t0, t1)| t0.contains_int(env, hist, t1))
206                    .collect::<Result<AndAc>>()?
207                    .0),
208            (Self::Struct(t0), Self::Struct(t1))
209                if t0.as_ptr().addr() == t1.as_ptr().addr() =>
210            {
211                Ok(true)
212            }
213            (Self::Struct(t0), Self::Struct(t1)) => {
214                Ok(t0.len() == t1.len() && {
215                    // struct types are always sorted by field name
216                    t0.iter()
217                        .zip(t1.iter())
218                        .map(|((n0, t0), (n1, t1))| {
219                            Ok(n0 == n1 && t0.contains_int(env, hist, t1)?)
220                        })
221                        .collect::<Result<AndAc>>()?
222                        .0
223                })
224            }
225            (Self::Variant(tg0, t0), Self::Variant(tg1, t1))
226                if tg0.as_ptr() == tg1.as_ptr()
227                    && t0.as_ptr().addr() == t1.as_ptr().addr() =>
228            {
229                Ok(true)
230            }
231            (Self::Variant(tg0, t0), Self::Variant(tg1, t1)) => Ok(tg0 == tg1
232                && t0.len() == t1.len()
233                && t0
234                    .iter()
235                    .zip(t1.iter())
236                    .map(|(t0, t1)| t0.contains_int(env, hist, t1))
237                    .collect::<Result<AndAc>>()?
238                    .0),
239            (Self::ByRef(t0), Self::ByRef(t1)) => t0.contains_int(env, hist, t1),
240            (Self::Tuple(_), Self::Array(_))
241            | (Self::Tuple(_), Self::Primitive(_))
242            | (Self::Tuple(_), Self::Struct(_))
243            | (Self::Tuple(_), Self::Variant(_, _))
244            | (Self::Array(_), Self::Primitive(_))
245            | (Self::Array(_), Self::Tuple(_))
246            | (Self::Array(_), Self::Struct(_))
247            | (Self::Array(_), Self::Variant(_, _))
248            | (Self::Struct(_), Self::Primitive(_))
249            | (Self::Struct(_), Self::Array(_))
250            | (Self::Struct(_), Self::Tuple(_))
251            | (Self::Struct(_), Self::Variant(_, _))
252            | (Self::Variant(_, _), Self::Array(_))
253            | (Self::Variant(_, _), Self::Struct(_))
254            | (Self::Variant(_, _), Self::Primitive(_))
255            | (Self::Variant(_, _), Self::Tuple(_)) => Ok(false),
256            (Self::TVar(t0), Self::TVar(t1)) if t0.addr() == t1.addr() => Ok(true),
257            (Self::TVar(t0), tt1 @ Self::TVar(t1)) => {
258                #[derive(Debug)]
259                enum Act {
260                    RightCopy,
261                    RightAlias,
262                    LeftAlias,
263                    LeftCopy,
264                }
265                let act = {
266                    let t0 = t0.read();
267                    let t1 = t1.read();
268                    let addr = Arc::as_ptr(&t0.typ).addr();
269                    if addr == Arc::as_ptr(&t1.typ).addr() {
270                        return Ok(true);
271                    }
272                    let t0i = t0.typ.read();
273                    let t1i = t1.typ.read();
274                    match (&*t0i, &*t1i) {
275                        (Some(t0), Some(t1)) => return t0.contains_int(env, hist, &*t1),
276                        (None, None) => {
277                            if would_cycle_inner(addr, tt1) {
278                                return Ok(true);
279                            }
280                            if t0.frozen && t1.frozen {
281                                return Ok(true);
282                            }
283                            if t0.frozen {
284                                Act::RightAlias
285                            } else {
286                                Act::LeftAlias
287                            }
288                        }
289                        (Some(_), None) => {
290                            if would_cycle_inner(addr, tt1) {
291                                return Ok(true);
292                            }
293                            Act::RightCopy
294                        }
295                        (None, Some(_)) => {
296                            if would_cycle_inner(addr, tt1) {
297                                return Ok(true);
298                            }
299                            Act::LeftCopy
300                        }
301                    }
302                };
303                match act {
304                    Act::RightCopy => t1.copy(t0),
305                    Act::RightAlias => t1.alias(t0),
306                    Act::LeftAlias => t0.alias(t1),
307                    Act::LeftCopy => t0.copy(t1),
308                }
309                Ok(true)
310            }
311            (Self::TVar(t0), t1) if !t0.would_cycle(t1) => {
312                if let Some(t0) = &*t0.read().typ.read() {
313                    return t0.contains_int(env, hist, t1);
314                }
315                *t0.read().typ.write() = Some(t1.clone());
316                Ok(true)
317            }
318            (t0, Self::TVar(t1)) if !t1.would_cycle(t0) => {
319                if let Some(t1) = &*t1.read().typ.read() {
320                    return t0.contains_int(env, hist, t1);
321                }
322                *t1.read().typ.write() = Some(t0.clone());
323                Ok(true)
324            }
325            (Self::Set(s0), Self::Set(s1))
326                if s0.as_ptr().addr() == s1.as_ptr().addr() =>
327            {
328                Ok(true)
329            }
330            (t0, Self::Set(s)) => Ok(s
331                .iter()
332                .map(|t1| t0.contains_int(env, hist, t1))
333                .collect::<Result<AndAc>>()?
334                .0),
335            (Self::Set(s), t) => Ok(s
336                .iter()
337                .fold(Ok::<_, anyhow::Error>(false), |acc, t0| {
338                    Ok(acc? || t0.contains_int(env, hist, t)?)
339                })?
340                || t.iter_prims().fold(Ok::<_, anyhow::Error>(true), |acc, t1| {
341                    Ok(acc?
342                        && s.iter().fold(Ok::<_, anyhow::Error>(false), |acc, t0| {
343                            Ok(acc? || t0.contains_int(env, hist, &t1)?)
344                        })?)
345                })?),
346            (Self::Fn(f0), Self::Fn(f1)) => {
347                Ok(f0.as_ptr() == f1.as_ptr() || f0.contains_int(env, hist, f1)?)
348            }
349            (_, Self::Any)
350            | (_, Self::TVar(_))
351            | (Self::TVar(_), _)
352            | (Self::Fn(_), _)
353            | (Self::ByRef(_), _)
354            | (_, Self::ByRef(_))
355            | (_, Self::Fn(_)) => Ok(false),
356        }
357    }
358
359    pub fn contains<R: Rt, E: UserEvent>(
360        &self,
361        env: &Env<R, E>,
362        t: &Self,
363    ) -> Result<bool> {
364        thread_local! {
365            static HIST: RefCell<FxHashMap<(usize, usize), bool>> = RefCell::new(HashMap::default());
366        }
367        HIST.with_borrow_mut(|hist| {
368            hist.clear();
369            self.contains_int(env, hist, t)
370        })
371    }
372
373    fn could_match_int<R: Rt, E: UserEvent>(
374        &self,
375        env: &Env<R, E>,
376        hist: &mut FxHashMap<(usize, usize), bool>,
377        t: &Self,
378    ) -> Result<bool> {
379        match (self, t) {
380            (
381                Self::Ref { scope: s0, name: n0, .. },
382                Self::Ref { scope: s1, name: n1, .. },
383            ) if s0 == s1 && n0 == n1 => Ok(true),
384            (t0 @ Self::Ref { .. }, t1) | (t0, t1 @ Self::Ref { .. }) => {
385                let t0 = t0.lookup_ref(env)?;
386                let t1 = t1.lookup_ref(env)?;
387                let t0_addr = (t0 as *const Type).addr();
388                let t1_addr = (t1 as *const Type).addr();
389                match hist.get(&(t0_addr, t1_addr)) {
390                    Some(r) => Ok(*r),
391                    None => {
392                        hist.insert((t0_addr, t1_addr), true);
393                        match t0.could_match_int(env, hist, t1) {
394                            Ok(r) => {
395                                hist.insert((t0_addr, t1_addr), r);
396                                Ok(r)
397                            }
398                            Err(e) => {
399                                hist.remove(&(t0_addr, t1_addr));
400                                Err(e)
401                            }
402                        }
403                    }
404                }
405            }
406            (t0, Self::Primitive(s)) => {
407                for t1 in s.iter() {
408                    if t0.contains_int(env, hist, &Type::Primitive(t1.into()))? {
409                        return Ok(true);
410                    }
411                }
412                Ok(false)
413            }
414            (t0, Self::Set(ts)) => {
415                for t1 in ts.iter() {
416                    if t0.contains_int(env, hist, t1)? {
417                        return Ok(true);
418                    }
419                }
420                Ok(false)
421            }
422            (Type::TVar(t0), t1) => match &*t0.read().typ.read() {
423                Some(t0) => t0.could_match_int(env, hist, t1),
424                None => Ok(false),
425            },
426            (t0, Type::TVar(t1)) => match &*t1.read().typ.read() {
427                Some(t1) => t0.could_match_int(env, hist, t1),
428                None => Ok(false),
429            },
430            (t0, t1) => t0.contains_int(env, hist, t1),
431        }
432    }
433
434    pub fn could_match<R: Rt, E: UserEvent>(
435        &self,
436        env: &Env<R, E>,
437        t: &Self,
438    ) -> Result<bool> {
439        thread_local! {
440            static HIST: RefCell<FxHashMap<(usize, usize), bool>> = RefCell::new(HashMap::default());
441        }
442        HIST.with_borrow_mut(|hist| {
443            hist.clear();
444            self.could_match_int(env, hist, t)
445        })
446    }
447
448    fn union_int<R: Rt, E: UserEvent>(
449        &self,
450        env: &Env<R, E>,
451        hist: &mut FxHashMap<(usize, usize), Type>,
452        t: &Self,
453    ) -> Result<Self> {
454        match (self, t) {
455            (
456                Type::Ref { name: n0, scope: s0, .. },
457                Type::Ref { scope: s1, name: n1, .. },
458            ) if n0 == n1 && s0 == s1 => Ok(self.clone()),
459            (tr @ Type::Ref { .. }, t) => {
460                let t0 = tr.lookup_ref(env)?;
461                let t0_addr = (t0 as *const Type).addr();
462                let t_addr = (t as *const Type).addr();
463                match hist.get(&(t0_addr, t_addr)) {
464                    Some(t) => Ok(t.clone()),
465                    None => {
466                        hist.insert((t0_addr, t_addr), tr.clone());
467                        let r = t0.union_int(env, hist, t)?;
468                        hist.insert((t0_addr, t_addr), r.clone());
469                        Ok(r)
470                    }
471                }
472            }
473            (t, tr @ Type::Ref { .. }) => {
474                let t1 = tr.lookup_ref(env)?;
475                let t1_addr = (t1 as *const Type).addr();
476                let t_addr = (t as *const Type).addr();
477                match hist.get(&(t_addr, t1_addr)) {
478                    Some(t) => Ok(t.clone()),
479                    None => {
480                        hist.insert((t_addr, t1_addr), tr.clone());
481                        let r = t.union_int(env, hist, t1)?;
482                        hist.insert((t_addr, t1_addr), r.clone());
483                        Ok(r)
484                    }
485                }
486            }
487            (Type::Bottom, t) | (t, Type::Bottom) => Ok(t.clone()),
488            (Type::Any, _) | (_, Type::Any) => Ok(Type::Any),
489            (Type::Primitive(p), t) | (t, Type::Primitive(p)) if p.is_empty() => {
490                Ok(t.clone())
491            }
492            (Type::Primitive(s0), Type::Primitive(s1)) => {
493                let mut s = *s0;
494                s.insert(*s1);
495                Ok(Type::Primitive(s))
496            }
497            (
498                Type::Primitive(p),
499                Type::Array(_) | Type::Struct(_) | Type::Tuple(_) | Type::Variant(_, _),
500            )
501            | (
502                Type::Array(_) | Type::Struct(_) | Type::Tuple(_) | Type::Variant(_, _),
503                Type::Primitive(p),
504            ) if p.contains(Typ::Array) => Ok(Type::Primitive(*p)),
505            (Type::Primitive(p), Type::Array(t))
506            | (Type::Array(t), Type::Primitive(p)) => Ok(Type::Set(Arc::from_iter([
507                Type::Primitive(*p),
508                Type::Array(t.clone()),
509            ]))),
510            (t @ Type::Array(t0), u @ Type::Array(t1)) => {
511                if t0 == t1 {
512                    Ok(Type::Array(t0.clone()))
513                } else {
514                    Ok(Type::Set(Arc::from_iter([u.clone(), t.clone()])))
515                }
516            }
517            (t @ Type::ByRef(t0), u @ Type::ByRef(t1)) => {
518                if t0 == t1 {
519                    Ok(Type::ByRef(t0.clone()))
520                } else {
521                    Ok(Type::Set(Arc::from_iter([u.clone(), t.clone()])))
522                }
523            }
524            (Type::Set(s0), Type::Set(s1)) => Ok(Type::Set(Arc::from_iter(
525                s0.iter().cloned().chain(s1.iter().cloned()),
526            ))),
527            (Type::Set(s), t) | (t, Type::Set(s)) => Ok(Type::Set(Arc::from_iter(
528                s.iter().cloned().chain(iter::once(t.clone())),
529            ))),
530            (u @ Type::Struct(t0), t @ Type::Struct(t1)) => {
531                if t0.len() == t1.len() && t0 == t1 {
532                    Ok(u.clone())
533                } else {
534                    Ok(Type::Set(Arc::from_iter([u.clone(), t.clone()])))
535                }
536            }
537            (u @ Type::Struct(_), t) | (t, u @ Type::Struct(_)) => {
538                Ok(Type::Set(Arc::from_iter([u.clone(), t.clone()])))
539            }
540            (u @ Type::Tuple(t0), t @ Type::Tuple(t1)) => {
541                if t0 == t1 {
542                    Ok(u.clone())
543                } else {
544                    Ok(Type::Set(Arc::from_iter([u.clone(), t.clone()])))
545                }
546            }
547            (u @ Type::Tuple(_), t) | (t, u @ Type::Tuple(_)) => {
548                Ok(Type::Set(Arc::from_iter([u.clone(), t.clone()])))
549            }
550            (u @ Type::Variant(tg0, t0), t @ Type::Variant(tg1, t1)) => {
551                if tg0 == tg1 && t0.len() == t1.len() {
552                    let typs = t0
553                        .iter()
554                        .zip(t1.iter())
555                        .map(|(t0, t1)| t0.union_int(env, hist, t1))
556                        .collect::<Result<SmallVec<[_; 8]>>>()?;
557                    Ok(Type::Variant(tg0.clone(), Arc::from_iter(typs.into_iter())))
558                } else {
559                    Ok(Type::Set(Arc::from_iter([u.clone(), t.clone()])))
560                }
561            }
562            (u @ Type::Variant(_, _), t) | (t, u @ Type::Variant(_, _)) => {
563                Ok(Type::Set(Arc::from_iter([u.clone(), t.clone()])))
564            }
565            (Type::Fn(f0), Type::Fn(f1)) => {
566                if f0 == f1 {
567                    Ok(Type::Fn(f0.clone()))
568                } else {
569                    Ok(Type::Set(Arc::from_iter([
570                        Type::Fn(f0.clone()),
571                        Type::Fn(f1.clone()),
572                    ])))
573                }
574            }
575            (f @ Type::Fn(_), t) | (t, f @ Type::Fn(_)) => {
576                Ok(Type::Set(Arc::from_iter([f.clone(), t.clone()])))
577            }
578            (t0 @ Type::TVar(_), t1 @ Type::TVar(_)) => {
579                if t0 == t1 {
580                    Ok(t0.clone())
581                } else {
582                    Ok(Type::Set(Arc::from_iter([t0.clone(), t1.clone()])))
583                }
584            }
585            (t0 @ Type::TVar(_), t1) | (t1, t0 @ Type::TVar(_)) => {
586                Ok(Type::Set(Arc::from_iter([t0.clone(), t1.clone()])))
587            }
588            (t @ Type::ByRef(_), u) | (u, t @ Type::ByRef(_)) => {
589                Ok(Type::Set(Arc::from_iter([t.clone(), u.clone()])))
590            }
591        }
592    }
593
594    pub fn union<R: Rt, E: UserEvent>(&self, env: &Env<R, E>, t: &Self) -> Result<Self> {
595        thread_local! {
596            static HIST: RefCell<FxHashMap<(usize, usize), Type>> = RefCell::new(HashMap::default());
597        }
598        HIST.with_borrow_mut(|hist| {
599            hist.clear();
600            Ok(self.union_int(env, hist, t)?.normalize())
601        })
602    }
603
604    fn diff_int<R: Rt, E: UserEvent>(
605        &self,
606        env: &Env<R, E>,
607        hist: &mut FxHashMap<(usize, usize), Type>,
608        t: &Self,
609    ) -> Result<Self> {
610        match (self, t) {
611            (Type::Any, _) => Ok(Type::Any),
612            (_, Type::Any) => Ok(Type::Primitive(BitFlags::empty())),
613            (
614                Type::Ref { scope: s0, name: n0, .. },
615                Type::Ref { scope: s1, name: n1, .. },
616            ) if s0 == s1 && n0 == n1 => Ok(Type::Primitive(BitFlags::empty())),
617            (t0 @ Type::Ref { .. }, t1) | (t0, t1 @ Type::Ref { .. }) => {
618                let t0 = t0.lookup_ref(env)?;
619                let t1 = t1.lookup_ref(env)?;
620                let t0_addr = (t0 as *const Type).addr();
621                let t1_addr = (t1 as *const Type).addr();
622                match hist.get(&(t0_addr, t1_addr)) {
623                    Some(r) => Ok(r.clone()),
624                    None => {
625                        let r = Type::Primitive(BitFlags::empty());
626                        hist.insert((t0_addr, t1_addr), r);
627                        match t0.diff_int(env, hist, &t1) {
628                            Ok(r) => {
629                                hist.insert((t0_addr, t1_addr), r.clone());
630                                Ok(r)
631                            }
632                            Err(e) => {
633                                hist.remove(&(t0_addr, t1_addr));
634                                Err(e)
635                            }
636                        }
637                    }
638                }
639            }
640            (Type::Bottom, t) | (t, Type::Bottom) => Ok(t.clone()),
641            (Type::Primitive(s0), Type::Primitive(s1)) => {
642                let mut s = *s0;
643                s.remove(*s1);
644                Ok(Type::Primitive(s))
645            }
646            (
647                Type::Primitive(p),
648                Type::Array(_) | Type::Struct(_) | Type::Tuple(_) | Type::Variant(_, _),
649            ) => {
650                // CR estokes: is this correct? It's a bit odd.
651                let mut s = *p;
652                s.remove(Typ::Array);
653                Ok(Type::Primitive(s))
654            }
655            (
656                Type::Array(_) | Type::Struct(_) | Type::Tuple(_) | Type::Variant(_, _),
657                Type::Primitive(p),
658            ) => {
659                if p.contains(Typ::Array) {
660                    Ok(Type::Primitive(BitFlags::empty()))
661                } else {
662                    Ok(self.clone())
663                }
664            }
665            (Type::Array(t0), Type::Array(t1)) => {
666                Ok(Type::Array(Arc::new(t0.diff_int(env, hist, t1)?)))
667            }
668            (Type::ByRef(t0), Type::ByRef(t1)) => {
669                Ok(Type::ByRef(Arc::new(t0.diff_int(env, hist, t1)?)))
670            }
671            (Type::Set(s0), Type::Set(s1)) => {
672                let mut s: SmallVec<[Type; 4]> = smallvec![];
673                for i in 0..s0.len() {
674                    s.push(s0[i].clone());
675                    for j in 0..s1.len() {
676                        s[i] = s[i].diff_int(env, hist, &s1[j])?
677                    }
678                }
679                Ok(Self::flatten_set(s.into_iter()))
680            }
681            (Type::Set(s), t) => Ok(Self::flatten_set(
682                s.iter()
683                    .map(|s| s.diff_int(env, hist, t))
684                    .collect::<Result<SmallVec<[_; 8]>>>()?,
685            )),
686            (t, Type::Set(s)) => {
687                let mut t = t.clone();
688                for st in s.iter() {
689                    t = t.diff_int(env, hist, st)?;
690                }
691                Ok(t)
692            }
693            (Type::Tuple(t0), Type::Tuple(t1)) => {
694                if t0 == t1 {
695                    Ok(Type::Primitive(BitFlags::empty()))
696                } else {
697                    Ok(self.clone())
698                }
699            }
700            (Type::Tuple(_), _) | (_, Type::Tuple(_)) => Ok(self.clone()),
701            (Type::Struct(t0), Type::Struct(t1)) => {
702                if t0.len() == t1.len() && t0 == t1 {
703                    Ok(Type::Primitive(BitFlags::empty()))
704                } else {
705                    Ok(self.clone())
706                }
707            }
708            (Type::Struct(_), _) | (_, Type::Struct(_)) => Ok(self.clone()),
709            (Type::ByRef(_), _) | (_, Type::ByRef(_)) => Ok(self.clone()),
710            (Type::Variant(tg0, t0), Type::Variant(tg1, t1)) => {
711                if tg0 == tg1 && t0.len() == t1.len() && t0 == t1 {
712                    Ok(Type::Primitive(BitFlags::empty()))
713                } else {
714                    Ok(self.clone())
715                }
716            }
717            (Type::Variant(_, _), _) | (_, Type::Variant(_, _)) => Ok(self.clone()),
718            (Type::Fn(f0), Type::Fn(f1)) => {
719                if f0 == f1 {
720                    Ok(Type::Primitive(BitFlags::empty()))
721                } else {
722                    Ok(Type::Fn(f0.clone()))
723                }
724            }
725            (f @ Type::Fn(_), _) => Ok(f.clone()),
726            (t, Type::Fn(_)) => Ok(t.clone()),
727            (Type::TVar(tv0), Type::TVar(tv1)) => {
728                if tv0.read().typ.as_ptr() == tv1.read().typ.as_ptr() {
729                    return Ok(Type::Primitive(BitFlags::empty()));
730                }
731                Ok(match (&*tv0.read().typ.read(), &*tv1.read().typ.read()) {
732                    (None, _) | (_, None) => Type::TVar(tv0.clone()),
733                    (Some(t0), Some(t1)) => t0.diff_int(env, hist, t1)?,
734                })
735            }
736            (Type::TVar(tv), t1) => match &*tv.read().typ.read() {
737                None => Ok(Type::TVar(tv.clone())),
738                Some(t0) => t0.diff_int(env, hist, t1),
739            },
740            (t0, Type::TVar(tv)) => match &*tv.read().typ.read() {
741                None => Ok(t0.clone()),
742                Some(t1) => t0.diff_int(env, hist, t1),
743            },
744        }
745    }
746
747    pub fn diff<R: Rt, E: UserEvent>(&self, env: &Env<R, E>, t: &Self) -> Result<Self> {
748        thread_local! {
749            static HIST: RefCell<FxHashMap<(usize, usize), Type>> = RefCell::new(HashMap::default());
750        }
751        HIST.with_borrow_mut(|hist| {
752            hist.clear();
753            Ok(self.diff_int(env, hist, t)?.normalize())
754        })
755    }
756
757    pub fn any() -> Self {
758        Self::Any
759    }
760
761    pub fn boolean() -> Self {
762        Self::Primitive(Typ::Bool.into())
763    }
764
765    pub fn number() -> Self {
766        Self::Primitive(Typ::number())
767    }
768
769    pub fn int() -> Self {
770        Self::Primitive(Typ::integer())
771    }
772
773    pub fn uint() -> Self {
774        Self::Primitive(Typ::unsigned_integer())
775    }
776
777    /// alias type variables with the same name to each other
778    pub fn alias_tvars(&self, known: &mut FxHashMap<ArcStr, TVar>) {
779        match self {
780            Type::Bottom | Type::Any | Type::Primitive(_) => (),
781            Type::Ref { params, .. } => {
782                for t in params.iter() {
783                    t.alias_tvars(known);
784                }
785            }
786            Type::Array(t) => t.alias_tvars(known),
787            Type::ByRef(t) => t.alias_tvars(known),
788            Type::Tuple(ts) => {
789                for t in ts.iter() {
790                    t.alias_tvars(known)
791                }
792            }
793            Type::Struct(ts) => {
794                for (_, t) in ts.iter() {
795                    t.alias_tvars(known)
796                }
797            }
798            Type::Variant(_, ts) => {
799                for t in ts.iter() {
800                    t.alias_tvars(known)
801                }
802            }
803            Type::TVar(tv) => match known.entry(tv.name.clone()) {
804                Entry::Occupied(e) => {
805                    let v = e.get();
806                    v.freeze();
807                    tv.alias(v);
808                }
809                Entry::Vacant(e) => {
810                    e.insert(tv.clone());
811                    ()
812                }
813            },
814            Type::Fn(ft) => ft.alias_tvars(known),
815            Type::Set(s) => {
816                for typ in s.iter() {
817                    typ.alias_tvars(known)
818                }
819            }
820        }
821    }
822
823    pub fn collect_tvars(&self, known: &mut FxHashMap<ArcStr, TVar>) {
824        match self {
825            Type::Bottom | Type::Any | Type::Primitive(_) => (),
826            Type::Ref { params, .. } => {
827                for t in params.iter() {
828                    t.collect_tvars(known);
829                }
830            }
831            Type::Array(t) => t.collect_tvars(known),
832            Type::ByRef(t) => t.collect_tvars(known),
833            Type::Tuple(ts) => {
834                for t in ts.iter() {
835                    t.collect_tvars(known)
836                }
837            }
838            Type::Struct(ts) => {
839                for (_, t) in ts.iter() {
840                    t.collect_tvars(known)
841                }
842            }
843            Type::Variant(_, ts) => {
844                for t in ts.iter() {
845                    t.collect_tvars(known)
846                }
847            }
848            Type::TVar(tv) => match known.entry(tv.name.clone()) {
849                Entry::Occupied(_) => (),
850                Entry::Vacant(e) => {
851                    e.insert(tv.clone());
852                    ()
853                }
854            },
855            Type::Fn(ft) => ft.collect_tvars(known),
856            Type::Set(s) => {
857                for typ in s.iter() {
858                    typ.collect_tvars(known)
859                }
860            }
861        }
862    }
863
864    pub fn check_tvars_declared(&self, declared: &FxHashSet<ArcStr>) -> Result<()> {
865        match self {
866            Type::Bottom | Type::Any | Type::Primitive(_) => Ok(()),
867            Type::Ref { params, .. } => {
868                params.iter().try_for_each(|t| t.check_tvars_declared(declared))
869            }
870            Type::Array(t) => t.check_tvars_declared(declared),
871            Type::ByRef(t) => t.check_tvars_declared(declared),
872            Type::Tuple(ts) => {
873                ts.iter().try_for_each(|t| t.check_tvars_declared(declared))
874            }
875            Type::Struct(ts) => {
876                ts.iter().try_for_each(|(_, t)| t.check_tvars_declared(declared))
877            }
878            Type::Variant(_, ts) => {
879                ts.iter().try_for_each(|t| t.check_tvars_declared(declared))
880            }
881            Type::TVar(tv) => {
882                if !declared.contains(&tv.name) {
883                    bail!("undeclared type variable '{}'", tv.name)
884                } else {
885                    Ok(())
886                }
887            }
888            Type::Set(s) => s.iter().try_for_each(|t| t.check_tvars_declared(declared)),
889            Type::Fn(_) => Ok(()),
890        }
891    }
892
893    pub fn has_unbound(&self) -> bool {
894        match self {
895            Type::Bottom | Type::Any | Type::Primitive(_) => false,
896            Type::Ref { .. } => false,
897            Type::Array(t0) => t0.has_unbound(),
898            Type::ByRef(t0) => t0.has_unbound(),
899            Type::Tuple(ts) => ts.iter().any(|t| t.has_unbound()),
900            Type::Struct(ts) => ts.iter().any(|(_, t)| t.has_unbound()),
901            Type::Variant(_, ts) => ts.iter().any(|t| t.has_unbound()),
902            Type::TVar(tv) => tv.read().typ.read().is_some(),
903            Type::Set(s) => s.iter().any(|t| t.has_unbound()),
904            Type::Fn(ft) => ft.has_unbound(),
905        }
906    }
907
908    /// bind all unbound type variables to the specified type
909    pub fn bind_as(&self, t: &Self) {
910        match self {
911            Type::Bottom | Type::Any | Type::Primitive(_) => (),
912            Type::Ref { .. } => (),
913            Type::Array(t0) => t0.bind_as(t),
914            Type::ByRef(t0) => t0.bind_as(t),
915            Type::Tuple(ts) => {
916                for elt in ts.iter() {
917                    elt.bind_as(t)
918                }
919            }
920            Type::Struct(ts) => {
921                for (_, elt) in ts.iter() {
922                    elt.bind_as(t)
923                }
924            }
925            Type::Variant(_, ts) => {
926                for elt in ts.iter() {
927                    elt.bind_as(t)
928                }
929            }
930            Type::TVar(tv) => {
931                let tv = tv.read();
932                let mut tv = tv.typ.write();
933                if tv.is_none() {
934                    *tv = Some(t.clone());
935                }
936            }
937            Type::Set(s) => {
938                for elt in s.iter() {
939                    elt.bind_as(t)
940                }
941            }
942            Type::Fn(ft) => ft.bind_as(t),
943        }
944    }
945
946    /// return a copy of self with all type variables unbound and
947    /// unaliased. self will not be modified
948    pub fn reset_tvars(&self) -> Type {
949        match self {
950            Type::Bottom => Type::Bottom,
951            Type::Any => Type::Any,
952            Type::Primitive(p) => Type::Primitive(*p),
953            Type::Ref { scope, name, params } => Type::Ref {
954                scope: scope.clone(),
955                name: name.clone(),
956                params: Arc::from_iter(params.iter().map(|t| t.reset_tvars())),
957            },
958            Type::Array(t0) => Type::Array(Arc::new(t0.reset_tvars())),
959            Type::ByRef(t0) => Type::ByRef(Arc::new(t0.reset_tvars())),
960            Type::Tuple(ts) => {
961                Type::Tuple(Arc::from_iter(ts.iter().map(|t| t.reset_tvars())))
962            }
963            Type::Struct(ts) => Type::Struct(Arc::from_iter(
964                ts.iter().map(|(n, t)| (n.clone(), t.reset_tvars())),
965            )),
966            Type::Variant(tag, ts) => Type::Variant(
967                tag.clone(),
968                Arc::from_iter(ts.iter().map(|t| t.reset_tvars())),
969            ),
970            Type::TVar(tv) => Type::TVar(TVar::empty_named(tv.name.clone())),
971            Type::Set(s) => Type::Set(Arc::from_iter(s.iter().map(|t| t.reset_tvars()))),
972            Type::Fn(fntyp) => Type::Fn(Arc::new(fntyp.reset_tvars())),
973        }
974    }
975
976    /// return a copy of self with every TVar named in known replaced
977    /// with the corresponding type
978    pub fn replace_tvars(&self, known: &FxHashMap<ArcStr, Self>) -> Type {
979        match self {
980            Type::TVar(tv) => match known.get(&tv.name) {
981                Some(t) => t.clone(),
982                None => Type::TVar(tv.clone()),
983            },
984            Type::Bottom => Type::Bottom,
985            Type::Any => Type::Any,
986            Type::Primitive(p) => Type::Primitive(*p),
987            Type::Ref { scope, name, params } => Type::Ref {
988                scope: scope.clone(),
989                name: name.clone(),
990                params: Arc::from_iter(params.iter().map(|t| t.replace_tvars(known))),
991            },
992            Type::Array(t0) => Type::Array(Arc::new(t0.replace_tvars(known))),
993            Type::ByRef(t0) => Type::ByRef(Arc::new(t0.replace_tvars(known))),
994            Type::Tuple(ts) => {
995                Type::Tuple(Arc::from_iter(ts.iter().map(|t| t.replace_tvars(known))))
996            }
997            Type::Struct(ts) => Type::Struct(Arc::from_iter(
998                ts.iter().map(|(n, t)| (n.clone(), t.replace_tvars(known))),
999            )),
1000            Type::Variant(tag, ts) => Type::Variant(
1001                tag.clone(),
1002                Arc::from_iter(ts.iter().map(|t| t.replace_tvars(known))),
1003            ),
1004            Type::Set(s) => {
1005                Type::Set(Arc::from_iter(s.iter().map(|t| t.replace_tvars(known))))
1006            }
1007            Type::Fn(fntyp) => Type::Fn(Arc::new(fntyp.replace_tvars(known))),
1008        }
1009    }
1010
1011    /// Unbind any bound tvars, but do not unalias them.
1012    pub(crate) fn unbind_tvars(&self) {
1013        match self {
1014            Type::Bottom | Type::Any | Type::Primitive(_) | Type::Ref { .. } => (),
1015            Type::Array(t0) => t0.unbind_tvars(),
1016            Type::ByRef(t0) => t0.unbind_tvars(),
1017            Type::Tuple(ts) | Type::Variant(_, ts) | Type::Set(ts) => {
1018                for t in ts.iter() {
1019                    t.unbind_tvars()
1020                }
1021            }
1022            Type::Struct(ts) => {
1023                for (_, t) in ts.iter() {
1024                    t.unbind_tvars()
1025                }
1026            }
1027            Type::TVar(tv) => tv.unbind(),
1028            Type::Fn(fntyp) => fntyp.unbind_tvars(),
1029        }
1030    }
1031
1032    fn first_prim_int<R: Rt, E: UserEvent>(
1033        &self,
1034        env: &Env<R, E>,
1035        hist: &mut FxHashSet<usize>,
1036    ) -> Option<Typ> {
1037        match self {
1038            Type::Primitive(p) => p.iter().next(),
1039            Type::Set(s) => s.iter().find_map(|t| t.first_prim_int(env, hist)),
1040            Type::TVar(tv) => {
1041                tv.read().typ.read().as_ref().and_then(|t| t.first_prim_int(env, hist))
1042            }
1043            // array, tuple, and struct casting are handled directly
1044            Type::Bottom
1045            | Type::Any
1046            | Type::Fn(_)
1047            | Type::Array(_)
1048            | Type::Tuple(_)
1049            | Type::Struct(_)
1050            | Type::Variant(_, _)
1051            | Type::ByRef(_) => None,
1052            Type::Ref { .. } => {
1053                let t = self.lookup_ref(env).ok()?;
1054                let t_addr = (t as *const Type).addr();
1055                if hist.contains(&t_addr) {
1056                    None
1057                } else {
1058                    hist.insert(t_addr);
1059                    t.first_prim_int(env, hist)
1060                }
1061            }
1062        }
1063    }
1064
1065    fn first_prim<R: Rt, E: UserEvent>(&self, env: &Env<R, E>) -> Option<Typ> {
1066        thread_local! {
1067            static HIST: RefCell<FxHashSet<usize>> = RefCell::new(HashSet::default());
1068        }
1069        HIST.with_borrow_mut(|hist| {
1070            hist.clear();
1071            self.first_prim_int(env, hist)
1072        })
1073    }
1074
1075    fn check_cast_int<R: Rt, E: UserEvent>(
1076        &self,
1077        env: &Env<R, E>,
1078        hist: &mut FxHashSet<usize>,
1079    ) -> Result<()> {
1080        match self {
1081            Type::Primitive(_) | Type::Any => Ok(()),
1082            Type::Fn(_) => bail!("can't cast a value to a function"),
1083            Type::Bottom => bail!("can't cast a value to bottom"),
1084            Type::Set(s) => Ok(for t in s.iter() {
1085                t.check_cast_int(env, hist)?
1086            }),
1087            Type::TVar(tv) => match &*tv.read().typ.read() {
1088                Some(t) => t.check_cast_int(env, hist),
1089                None => bail!("can't cast a value to a free type variable"),
1090            },
1091            Type::Array(et) => et.check_cast_int(env, hist),
1092            Type::ByRef(_) => bail!("can't cast a reference"),
1093            Type::Tuple(ts) => Ok(for t in ts.iter() {
1094                t.check_cast_int(env, hist)?
1095            }),
1096            Type::Struct(ts) => Ok(for (_, t) in ts.iter() {
1097                t.check_cast_int(env, hist)?
1098            }),
1099            Type::Variant(_, ts) => Ok(for t in ts.iter() {
1100                t.check_cast_int(env, hist)?
1101            }),
1102            Type::Ref { .. } => {
1103                let t = self.lookup_ref(env)?;
1104                let t_addr = (t as *const Type).addr();
1105                if hist.contains(&t_addr) {
1106                    Ok(())
1107                } else {
1108                    hist.insert(t_addr);
1109                    t.check_cast_int(env, hist)
1110                }
1111            }
1112        }
1113    }
1114
1115    pub fn check_cast<R: Rt, E: UserEvent>(&self, env: &Env<R, E>) -> Result<()> {
1116        thread_local! {
1117            static HIST: RefCell<FxHashSet<usize>> = RefCell::new(FxHashSet::default());
1118        }
1119        HIST.with_borrow_mut(|hist| {
1120            hist.clear();
1121            self.check_cast_int(env, hist)
1122        })
1123    }
1124
1125    fn cast_value_int<R: Rt, E: UserEvent>(
1126        &self,
1127        env: &Env<R, E>,
1128        hist: &mut FxHashSet<usize>,
1129        v: Value,
1130    ) -> Result<Value> {
1131        if self.is_a_int(env, hist, &v) {
1132            return Ok(v);
1133        }
1134        match self {
1135            Type::Array(et) => match v {
1136                Value::Array(elts) => {
1137                    let va = elts
1138                        .iter()
1139                        .map(|el| et.cast_value_int(env, hist, el.clone()))
1140                        .collect::<Result<SmallVec<[Value; 8]>>>()?;
1141                    Ok(Value::Array(ValArray::from_iter_exact(va.into_iter())))
1142                }
1143                v => Ok(Value::Array([et.cast_value_int(env, hist, v)?].into())),
1144            },
1145            Type::Tuple(ts) => match v {
1146                Value::Array(elts) => {
1147                    if elts.len() != ts.len() {
1148                        bail!("tuple size mismatch {self} with {}", Value::Array(elts))
1149                    }
1150                    let a = ts
1151                        .iter()
1152                        .zip(elts.iter())
1153                        .map(|(t, el)| t.cast_value_int(env, hist, el.clone()))
1154                        .collect::<Result<SmallVec<[Value; 8]>>>()?;
1155                    Ok(Value::Array(ValArray::from_iter_exact(a.into_iter())))
1156                }
1157                v => bail!("can't cast {v} to {self}"),
1158            },
1159            Type::Struct(ts) => match v {
1160                Value::Array(elts) => {
1161                    if elts.len() != ts.len() {
1162                        bail!("struct size mismatch {self} with {}", Value::Array(elts))
1163                    }
1164                    let is_pairs = elts.iter().all(|v| match v {
1165                        Value::Array(a) if a.len() == 2 => match &a[0] {
1166                            Value::String(_) => true,
1167                            _ => false,
1168                        },
1169                        _ => false,
1170                    });
1171                    if !is_pairs {
1172                        bail!("expected array of pairs, got {}", Value::Array(elts))
1173                    }
1174                    let mut elts_s: SmallVec<[&Value; 16]> = elts.iter().collect();
1175                    elts_s.sort_by_key(|v| match v {
1176                        Value::Array(a) => match &a[0] {
1177                            Value::String(s) => s,
1178                            _ => unreachable!(),
1179                        },
1180                        _ => unreachable!(),
1181                    });
1182                    let (keys_ok, ok) = ts.iter().zip(elts_s.iter()).fold(
1183                        Ok((true, true)),
1184                        |acc: Result<_>, ((fname, t), v)| {
1185                            let (kok, ok) = acc?;
1186                            let (name, v) = match v {
1187                                Value::Array(a) => match (&a[0], &a[1]) {
1188                                    (Value::String(n), v) => (n, v),
1189                                    _ => unreachable!(),
1190                                },
1191                                _ => unreachable!(),
1192                            };
1193                            Ok((
1194                                kok && name == fname,
1195                                ok && kok
1196                                    && t.contains(
1197                                        env,
1198                                        &Type::Primitive(Typ::get(v).into()),
1199                                    )?,
1200                            ))
1201                        },
1202                    )?;
1203                    if ok {
1204                        drop(elts_s);
1205                        return Ok(Value::Array(elts));
1206                    } else if keys_ok {
1207                        let elts = ts
1208                            .iter()
1209                            .zip(elts_s.iter())
1210                            .map(|((n, t), v)| match v {
1211                                Value::Array(a) => {
1212                                    let a = [
1213                                        Value::String(n.clone()),
1214                                        t.cast_value_int(env, hist, a[1].clone())?,
1215                                    ];
1216                                    Ok(Value::Array(ValArray::from_iter_exact(
1217                                        a.into_iter(),
1218                                    )))
1219                                }
1220                                _ => unreachable!(),
1221                            })
1222                            .collect::<Result<SmallVec<[Value; 8]>>>()?;
1223                        Ok(Value::Array(ValArray::from_iter_exact(elts.into_iter())))
1224                    } else {
1225                        drop(elts_s);
1226                        bail!("struct fields mismatch {self}, {}", Value::Array(elts))
1227                    }
1228                }
1229                v => bail!("can't cast {v} to {self}"),
1230            },
1231            Type::Variant(tag, ts) if ts.len() == 0 => match &v {
1232                Value::String(s) if s == tag => Ok(v),
1233                _ => bail!("variant tag mismatch expected {tag} got {v}"),
1234            },
1235            Type::Variant(tag, ts) => match &v {
1236                Value::Array(elts) => {
1237                    if ts.len() + 1 == elts.len() {
1238                        match &elts[0] {
1239                            Value::String(s) if s == tag => (),
1240                            v => bail!("variant tag mismatch expected {tag} got {v}"),
1241                        }
1242                        let a = iter::once(&Type::Primitive(Typ::String.into()))
1243                            .chain(ts.iter())
1244                            .zip(elts.iter())
1245                            .map(|(t, v)| t.cast_value_int(env, hist, v.clone()))
1246                            .collect::<Result<SmallVec<[Value; 8]>>>()?;
1247                        Ok(Value::Array(ValArray::from_iter_exact(a.into_iter())))
1248                    } else if ts.len() == elts.len() {
1249                        let mut a = ts
1250                            .iter()
1251                            .zip(elts.iter())
1252                            .map(|(t, v)| t.cast_value_int(env, hist, v.clone()))
1253                            .collect::<Result<SmallVec<[Value; 8]>>>()?;
1254                        a.insert(0, Value::String(tag.clone()));
1255                        Ok(Value::Array(ValArray::from_iter_exact(a.into_iter())))
1256                    } else {
1257                        bail!("variant length mismatch")
1258                    }
1259                }
1260                v => bail!("can't cast {v} to {self}"),
1261            },
1262            Type::Ref { .. } => self.lookup_ref(env)?.cast_value_int(env, hist, v),
1263            t => match t.first_prim(env) {
1264                None => bail!("empty or non primitive cast"),
1265                Some(t) => Ok(v
1266                    .clone()
1267                    .cast(t)
1268                    .ok_or_else(|| anyhow!("can't cast {v} to {t}"))?),
1269            },
1270        }
1271    }
1272
1273    pub fn cast_value<R: Rt, E: UserEvent>(&self, env: &Env<R, E>, v: Value) -> Value {
1274        thread_local! {
1275            static HIST: RefCell<FxHashSet<usize>> = RefCell::new(HashSet::default());
1276        }
1277        HIST.with_borrow_mut(|hist| {
1278            hist.clear();
1279            match self.cast_value_int(env, hist, v) {
1280                Ok(v) => v,
1281                Err(e) => Value::Error(e.to_string().into()),
1282            }
1283        })
1284    }
1285
1286    fn is_a_int<R: Rt, E: UserEvent>(
1287        &self,
1288        env: &Env<R, E>,
1289        hist: &mut FxHashSet<usize>,
1290        v: &Value,
1291    ) -> bool {
1292        match self {
1293            Type::Ref { .. } => match self.lookup_ref(env) {
1294                Err(_) => false,
1295                Ok(t) => {
1296                    let t_addr = (t as *const Type).addr();
1297                    !hist.contains(&t_addr) && {
1298                        hist.insert(t_addr);
1299                        t.is_a_int(env, hist, v)
1300                    }
1301                }
1302            },
1303            Type::Primitive(t) => t.contains(Typ::get(&v)),
1304            Type::Any => true,
1305            Type::Array(et) => match v {
1306                Value::Array(a) => a.iter().all(|v| et.is_a_int(env, hist, v)),
1307                _ => false,
1308            },
1309            Type::ByRef(_) => matches!(v, Value::U64(_) | Value::V64(_)),
1310            Type::Tuple(ts) => match v {
1311                Value::Array(elts) => {
1312                    elts.len() == ts.len()
1313                        && ts
1314                            .iter()
1315                            .zip(elts.iter())
1316                            .all(|(t, v)| t.is_a_int(env, hist, v))
1317                }
1318                _ => false,
1319            },
1320            Type::Struct(ts) => match v {
1321                Value::Array(elts) => {
1322                    elts.len() == ts.len()
1323                        && ts.iter().zip(elts.iter()).all(|((n, t), v)| match v {
1324                            Value::Array(a) if a.len() == 2 => match &a[..] {
1325                                [Value::String(key), v] => {
1326                                    n == key && t.is_a_int(env, hist, v)
1327                                }
1328                                _ => false,
1329                            },
1330                            _ => false,
1331                        })
1332                }
1333                _ => false,
1334            },
1335            Type::Variant(tag, ts) if ts.len() == 0 => match &v {
1336                Value::String(s) => s == tag,
1337                _ => false,
1338            },
1339            Type::Variant(tag, ts) => match &v {
1340                Value::Array(elts) => {
1341                    ts.len() + 1 == elts.len()
1342                        && match &elts[0] {
1343                            Value::String(s) => s == tag,
1344                            _ => false,
1345                        }
1346                        && ts
1347                            .iter()
1348                            .zip(elts[1..].iter())
1349                            .all(|(t, v)| t.is_a_int(env, hist, v))
1350                }
1351                _ => false,
1352            },
1353            Type::TVar(tv) => match &*tv.read().typ.read() {
1354                None => true,
1355                Some(t) => t.is_a_int(env, hist, v),
1356            },
1357            Type::Fn(_) => match v {
1358                Value::U64(_) => true,
1359                _ => false,
1360            },
1361            Type::Bottom => true,
1362            Type::Set(ts) => ts.iter().any(|t| t.is_a_int(env, hist, v)),
1363        }
1364    }
1365
1366    /// return true if v is structurally compatible with the type
1367    pub fn is_a<R: Rt, E: UserEvent>(&self, env: &Env<R, E>, v: &Value) -> bool {
1368        thread_local! {
1369            static HIST: RefCell<FxHashSet<usize>> = RefCell::new(HashSet::default());
1370        }
1371        HIST.with_borrow_mut(|hist| {
1372            hist.clear();
1373            self.is_a_int(env, hist, v)
1374        })
1375    }
1376
1377    pub fn is_bot(&self) -> bool {
1378        match self {
1379            Type::Bottom => true,
1380            Type::Any
1381            | Type::TVar(_)
1382            | Type::Primitive(_)
1383            | Type::Ref { .. }
1384            | Type::Fn(_)
1385            | Type::Array(_)
1386            | Type::ByRef(_)
1387            | Type::Tuple(_)
1388            | Type::Struct(_)
1389            | Type::Variant(_, _)
1390            | Type::Set(_) => false,
1391        }
1392    }
1393
1394    pub fn with_deref<R, F: FnOnce(Option<&Self>) -> R>(&self, f: F) -> R {
1395        match self {
1396            Self::Bottom
1397            | Self::Any
1398            | Self::Primitive(_)
1399            | Self::Fn(_)
1400            | Self::Set(_)
1401            | Self::Array(_)
1402            | Self::ByRef(_)
1403            | Self::Tuple(_)
1404            | Self::Struct(_)
1405            | Self::Variant(_, _)
1406            | Self::Ref { .. } => f(Some(self)),
1407            Self::TVar(tv) => f(tv.read().typ.read().as_ref()),
1408        }
1409    }
1410
    /// Flatten `set` into a single type, expanding nested sets and merging
    /// members where possible.
    ///
    /// - Nested `Type::Set`s, at any depth, are expanded in place using an
    ///   explicit stack of iterators rather than recursion.
    /// - If any member is `Type::Any`, the result is `Type::Any` immediately.
    /// - After each non-set member is pushed, the accumulator is rescanned
    ///   pairwise. Whenever `merge` combines `acc[i]` and `acc[j]`, the merged
    ///   type replaces `acc[i]` and `acc[j]` is removed.
    ///
    /// Returns:
    /// - `Type::Primitive(BitFlags::empty())` when nothing remains;
    /// - the sole member when exactly one remains;
    /// - otherwise a `Type::Set` of the members, sorted by `Ord`.
    pub(crate) fn flatten_set(set: impl IntoIterator<Item = Self>) -> Self {
        let init: Box<dyn Iterator<Item = Self>> = Box::new(set.into_iter());
        // stack of member iterators; the top is the innermost set being expanded
        let mut iters: SmallVec<[Box<dyn Iterator<Item = Self>>; 16]> = smallvec![init];
        let mut acc: SmallVec<[Self; 16]> = smallvec![];
        loop {
            match iters.last_mut() {
                None => break,
                Some(iter) => match iter.next() {
                    None => {
                        // this (possibly nested) set is exhausted; resume its parent
                        iters.pop();
                    }
                    Some(Type::Set(s)) => {
                        // descend into the nested set before continuing the current one
                        let v: SmallVec<[Self; 16]> =
                            s.iter().map(|t| t.clone()).collect();
                        iters.push(Box::new(v.into_iter()))
                    }
                    Some(Type::Any) => return Type::Any,
                    Some(t) => {
                        acc.push(t);
                        // Pairwise merge pass over the whole accumulator. j is
                        // not advanced after a successful merge because
                        // `acc.remove(j)` shifts the next candidate into slot j.
                        // NOTE(review): a full rescan per pushed member is
                        // quadratic in acc's size per insertion; fine for the
                        // small sets this presumably sees — confirm.
                        let mut i = 0;
                        let mut j;
                        while i < acc.len() {
                            j = i + 1;
                            while j < acc.len() {
                                if let Some(t) = acc[i].merge(&acc[j]) {
                                    acc[i] = t;
                                    acc.remove(j);
                                } else {
                                    j += 1;
                                }
                            }
                            i += 1;
                        }
                    }
                },
            }
        }
        // sort so that equal sets produce identical representations
        acc.sort();
        match &*acc {
            [] => Type::Primitive(BitFlags::empty()),
            [t] => t.clone(),
            _ => Type::Set(Arc::from_iter(acc)),
        }
    }
1455
1456    pub(crate) fn normalize(&self) -> Self {
1457        match self {
1458            Type::Bottom | Type::Any | Type::Primitive(_) => self.clone(),
1459            Type::Ref { scope, name, params } => {
1460                let params = Arc::from_iter(params.iter().map(|t| t.normalize()));
1461                Type::Ref { scope: scope.clone(), name: name.clone(), params }
1462            }
1463            Type::TVar(tv) => Type::TVar(tv.normalize()),
1464            Type::Set(s) => Self::flatten_set(s.iter().map(|t| t.normalize())),
1465            Type::Array(t) => Type::Array(Arc::new(t.normalize())),
1466            Type::ByRef(t) => Type::ByRef(Arc::new(t.normalize())),
1467            Type::Tuple(t) => {
1468                Type::Tuple(Arc::from_iter(t.iter().map(|t| t.normalize())))
1469            }
1470            Type::Struct(t) => Type::Struct(Arc::from_iter(
1471                t.iter().map(|(n, t)| (n.clone(), t.normalize())),
1472            )),
1473            Type::Variant(tag, t) => Type::Variant(
1474                tag.clone(),
1475                Arc::from_iter(t.iter().map(|t| t.normalize())),
1476            ),
1477            Type::Fn(ft) => Type::Fn(Arc::new(ft.normalize())),
1478        }
1479    }
1480
1481    fn merge(&self, t: &Self) -> Option<Self> {
1482        match (self, t) {
1483            (
1484                Type::Ref { scope: s0, name: r0, params: a0 },
1485                Type::Ref { scope: s1, name: r1, params: a1 },
1486            ) => {
1487                if s0 == s1 && r0 == r1 && a0 == a1 {
1488                    Some(Type::Ref {
1489                        scope: s0.clone(),
1490                        name: r0.clone(),
1491                        params: a0.clone(),
1492                    })
1493                } else {
1494                    None
1495                }
1496            }
1497            (Type::Ref { .. }, _) | (_, Type::Ref { .. }) => None,
1498            (Type::Bottom, t) | (t, Type::Bottom) => Some(t.clone()),
1499            (Type::Any, _) | (_, Type::Any) => Some(Type::Any),
1500            (Type::Primitive(p), t) | (t, Type::Primitive(p)) if p.is_empty() => {
1501                Some(t.clone())
1502            }
1503            (Type::Primitive(s0), Type::Primitive(s1)) => {
1504                let mut s = *s0;
1505                s.insert(*s1);
1506                Some(Type::Primitive(s))
1507            }
1508            (Type::Fn(f0), Type::Fn(f1)) => {
1509                if f0 == f1 {
1510                    Some(Type::Fn(f0.clone()))
1511                } else {
1512                    None
1513                }
1514            }
1515            (Type::Array(t0), Type::Array(t1)) => {
1516                let t0f = match &**t0 {
1517                    Type::Set(et) => Self::flatten_set(et.iter().cloned()),
1518                    t => t.clone(),
1519                };
1520                let t1f = match &**t1 {
1521                    Type::Set(et) => Self::flatten_set(et.iter().cloned()),
1522                    t => t.clone(),
1523                };
1524                if t0f == t1f {
1525                    Some(Type::Array(t0.clone()))
1526                } else {
1527                    None
1528                }
1529            }
1530            (Type::ByRef(t0), Type::ByRef(t1)) => {
1531                t0.merge(t1).map(|t| Type::ByRef(Arc::new(t)))
1532            }
1533            (Type::ByRef(_), _) | (_, Type::ByRef(_)) => None,
1534            (Type::Array(_), _) | (_, Type::Array(_)) => None,
1535            (Type::Set(s0), Type::Set(s1)) => {
1536                Some(Self::flatten_set(s0.iter().cloned().chain(s1.iter().cloned())))
1537            }
1538            (Type::Set(s), Type::Primitive(p)) | (Type::Primitive(p), Type::Set(s))
1539                if p.is_empty() =>
1540            {
1541                Some(Type::Set(s.clone()))
1542            }
1543            (Type::Set(s), t) | (t, Type::Set(s)) => {
1544                Some(Self::flatten_set(s.iter().cloned().chain(iter::once(t.clone()))))
1545            }
1546            (Type::Tuple(t0), Type::Tuple(t1)) => {
1547                if t0.len() == t1.len() {
1548                    let t = t0
1549                        .iter()
1550                        .zip(t1.iter())
1551                        .map(|(t0, t1)| t0.merge(t1))
1552                        .collect::<Option<SmallVec<[Type; 8]>>>()?;
1553                    Some(Type::Tuple(Arc::from_iter(t)))
1554                } else {
1555                    None
1556                }
1557            }
1558            (Type::Variant(tag0, t0), Type::Variant(tag1, t1)) => {
1559                if tag0 == tag1 && t0.len() == t1.len() {
1560                    let t = t0
1561                        .iter()
1562                        .zip(t1.iter())
1563                        .map(|(t0, t1)| t0.merge(t1))
1564                        .collect::<Option<SmallVec<[Type; 8]>>>()?;
1565                    Some(Type::Variant(tag0.clone(), Arc::from_iter(t)))
1566                } else {
1567                    None
1568                }
1569            }
1570            (Type::Struct(t0), Type::Struct(t1)) => {
1571                if t0.len() == t1.len() {
1572                    let t = t0
1573                        .iter()
1574                        .zip(t1.iter())
1575                        .map(|((n0, t0), (n1, t1))| {
1576                            if n0 != n1 {
1577                                None
1578                            } else {
1579                                t0.merge(t1).map(|t| (n0.clone(), t))
1580                            }
1581                        })
1582                        .collect::<Option<SmallVec<[(ArcStr, Type); 8]>>>()?;
1583                    Some(Type::Struct(Arc::from_iter(t)))
1584                } else {
1585                    None
1586                }
1587            }
1588            (Type::TVar(tv0), Type::TVar(tv1)) if tv0.name == tv1.name && tv0 == tv1 => {
1589                Some(Type::TVar(tv0.clone()))
1590            }
1591            (Type::TVar(tv), t) => {
1592                tv.read().typ.read().as_ref().and_then(|tv| tv.merge(t))
1593            }
1594            (t, Type::TVar(tv)) => {
1595                tv.read().typ.read().as_ref().and_then(|tv| t.merge(tv))
1596            }
1597            (Type::Tuple(_), _)
1598            | (_, Type::Tuple(_))
1599            | (Type::Struct(_), _)
1600            | (_, Type::Struct(_))
1601            | (Type::Variant(_, _), _)
1602            | (_, Type::Variant(_, _))
1603            | (_, Type::Fn(_))
1604            | (Type::Fn(_), _) => None,
1605        }
1606    }
1607
1608    pub fn scope_refs(&self, scope: &ModPath) -> Type {
1609        match self {
1610            Type::Bottom => Type::Bottom,
1611            Type::Any => Type::Any,
1612            Type::Primitive(s) => Type::Primitive(*s),
1613            Type::Array(t0) => Type::Array(Arc::new(t0.scope_refs(scope))),
1614            Type::ByRef(t) => Type::ByRef(Arc::new(t.scope_refs(scope))),
1615            Type::Tuple(ts) => {
1616                let i = ts.iter().map(|t| t.scope_refs(scope));
1617                Type::Tuple(Arc::from_iter(i))
1618            }
1619            Type::Variant(tag, ts) => {
1620                let i = ts.iter().map(|t| t.scope_refs(scope));
1621                Type::Variant(tag.clone(), Arc::from_iter(i))
1622            }
1623            Type::Struct(ts) => {
1624                let i = ts.iter().map(|(n, t)| (n.clone(), t.scope_refs(scope)));
1625                Type::Struct(Arc::from_iter(i))
1626            }
1627            Type::TVar(tv) => match tv.read().typ.read().as_ref() {
1628                None => Type::TVar(TVar::empty_named(tv.name.clone())),
1629                Some(typ) => {
1630                    let typ = typ.scope_refs(scope);
1631                    Type::TVar(TVar::named(tv.name.clone(), typ))
1632                }
1633            },
1634            Type::Ref { scope: _, name, params } => {
1635                let params = Arc::from_iter(params.iter().map(|t| t.scope_refs(scope)));
1636                Type::Ref { scope: scope.clone(), name: name.clone(), params }
1637            }
1638            Type::Set(ts) => {
1639                Type::Set(Arc::from_iter(ts.iter().map(|t| t.scope_refs(scope))))
1640            }
1641            Type::Fn(f) => {
1642                let vargs = f.vargs.as_ref().map(|t| t.scope_refs(scope));
1643                let rtype = f.rtype.scope_refs(scope);
1644                let args = Arc::from_iter(f.args.iter().map(|a| FnArgType {
1645                    label: a.label.clone(),
1646                    typ: a.typ.scope_refs(scope),
1647                }));
1648                let mut cres: SmallVec<[(TVar, Type); 4]> = smallvec![];
1649                for (tv, tc) in f.constraints.read().iter() {
1650                    let tv = tv.scope_refs(scope);
1651                    let tc = tc.scope_refs(scope);
1652                    cres.push((tv, tc));
1653                }
1654                Type::Fn(Arc::new(FnType {
1655                    args,
1656                    rtype,
1657                    constraints: Arc::new(RwLock::new(cres.into_iter().collect())),
1658                    vargs,
1659                }))
1660            }
1661        }
1662    }
1663}
1664
1665impl fmt::Display for Type {
1666    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1667        match self {
1668            Self::Bottom => write!(f, "_"),
1669            Self::Any => write!(f, "Any"),
1670            Self::Ref { scope: _, name, params } => {
1671                write!(f, "{name}")?;
1672                if !params.is_empty() {
1673                    write!(f, "<")?;
1674                    for (i, t) in params.iter().enumerate() {
1675                        write!(f, "{t}")?;
1676                        if i < params.len() - 1 {
1677                            write!(f, ", ")?;
1678                        }
1679                    }
1680                    write!(f, ">")?;
1681                }
1682                Ok(())
1683            }
1684            Self::TVar(tv) => write!(f, "{tv}"),
1685            Self::Fn(t) => write!(f, "{t}"),
1686            Self::Array(t) => write!(f, "Array<{t}>"),
1687            Self::ByRef(t) => write!(f, "&{t}"),
1688            Self::Tuple(ts) => {
1689                write!(f, "(")?;
1690                for (i, t) in ts.iter().enumerate() {
1691                    write!(f, "{t}")?;
1692                    if i < ts.len() - 1 {
1693                        write!(f, ", ")?;
1694                    }
1695                }
1696                write!(f, ")")
1697            }
1698            Self::Variant(tag, ts) if ts.len() == 0 => {
1699                write!(f, "`{tag}")
1700            }
1701            Self::Variant(tag, ts) => {
1702                write!(f, "`{tag}(")?;
1703                for (i, t) in ts.iter().enumerate() {
1704                    write!(f, "{t}")?;
1705                    if i < ts.len() - 1 {
1706                        write!(f, ", ")?
1707                    }
1708                }
1709                write!(f, ")")
1710            }
1711            Self::Struct(ts) => {
1712                write!(f, "{{")?;
1713                for (i, (n, t)) in ts.iter().enumerate() {
1714                    write!(f, "{n}: {t}")?;
1715                    if i < ts.len() - 1 {
1716                        write!(f, ", ")?
1717                    }
1718                }
1719                write!(f, "}}")
1720            }
1721            Self::Set(s) => {
1722                write!(f, "[")?;
1723                for (i, t) in s.iter().enumerate() {
1724                    write!(f, "{t}")?;
1725                    if i < s.len() - 1 {
1726                        write!(f, ", ")?;
1727                    }
1728                }
1729                write!(f, "]")
1730            }
1731            Self::Primitive(s) => {
1732                let replace = PRINT_FLAGS.get().contains(PrintFlag::ReplacePrims);
1733                if replace && *s == Typ::number() {
1734                    write!(f, "Number")
1735                } else if replace && *s == Typ::float() {
1736                    write!(f, "Float")
1737                } else if replace && *s == Typ::real() {
1738                    write!(f, "Real")
1739                } else if replace && *s == Typ::integer() {
1740                    write!(f, "Int")
1741                } else if replace && *s == Typ::unsigned_integer() {
1742                    write!(f, "Uint")
1743                } else if replace && *s == Typ::signed_integer() {
1744                    write!(f, "Sint")
1745                } else if s.len() == 0 {
1746                    write!(f, "[]")
1747                } else if s.len() == 1 {
1748                    write!(f, "{}", s.iter().next().unwrap())
1749                } else {
1750                    let mut s = *s;
1751                    macro_rules! builtin {
1752                        ($set:expr, $name:literal) => {
1753                            if replace && s.contains($set) {
1754                                s.remove($set);
1755                                write!(f, $name)?;
1756                                if !s.is_empty() {
1757                                    write!(f, ", ")?
1758                                }
1759                            }
1760                        };
1761                    }
1762                    write!(f, "[")?;
1763                    builtin!(Typ::number(), "Number");
1764                    builtin!(Typ::real(), "Real");
1765                    builtin!(Typ::float(), "Float");
1766                    builtin!(Typ::integer(), "Int");
1767                    builtin!(Typ::unsigned_integer(), "Uint");
1768                    builtin!(Typ::signed_integer(), "Sint");
1769                    for (i, t) in s.iter().enumerate() {
1770                        write!(f, "{t}")?;
1771                        if i < s.len() - 1 {
1772                            write!(f, ", ")?;
1773                        }
1774                    }
1775                    write!(f, "]")
1776                }
1777            }
1778        }
1779    }
1780}