graphix_compiler/typ/
mod.rs

1use crate::{
2    env::Env,
3    errf,
4    expr::{
5        print::{PrettyBuf, PrettyDisplay},
6        ModPath,
7    },
8    format_with_flags, PrintFlag, Rt, UserEvent, CAST_ERR_TAG, PRINT_FLAGS,
9};
10use anyhow::{anyhow, bail, Result};
11use arcstr::ArcStr;
12use enumflags2::bitflags;
13use enumflags2::BitFlags;
14use fxhash::{FxHashMap, FxHashSet};
15use immutable_chunkmap::map::Map;
16use netidx::{
17    publisher::{Typ, Value},
18    utils::Either,
19};
20use netidx_value::ValArray;
21use poolshark::local::LPooled;
22use smallvec::{smallvec, SmallVec};
23use std::{
24    cmp::{Eq, PartialEq},
25    collections::hash_map::Entry,
26    fmt::{self, Debug, Write},
27    iter,
28};
29use triomphe::Arc;
30
31mod fntyp;
32mod tval;
33mod tvar;
34
35pub use fntyp::{FnArgType, FnType};
36pub use tval::TVal;
37use tvar::would_cycle_inner;
38pub use tvar::TVar;
39
/// Flags selecting which type-variable side effects `Type::contains_int`
/// is allowed to perform while it checks containment.
#[derive(Debug, Clone, Copy)]
#[bitflags]
#[repr(u8)]
pub enum ContainsFlags {
    /// Allow an unbound type variable to be aliased to another unbound
    /// type variable (the `alias` calls in the TVar/TVar case).
    AliasTVars,
    /// Allow an unbound type variable to be initialized, either by a direct
    /// write of a type or by a `copy` from another type variable.
    InitTVars,
}
47
/// Accumulator that folds an iterator of `bool` into their logical AND.
///
/// It exists so that `Iterator<Item = Result<bool>>` can be collected into
/// `Result<AndAc>`. Consumption stops at the first `false`, so later items
/// are never pulled from the iterator.
struct AndAc(bool);

impl FromIterator<bool> for AndAc {
    fn from_iter<T: IntoIterator<Item = bool>>(iter: T) -> Self {
        for item in iter {
            if !item {
                // Short circuit: the remaining items are left unevaluated.
                return AndAc(false);
            }
        }
        // Empty input, or every item was true.
        AndAc(true)
    }
}
55
/// A type in the compiler's type system.
///
/// Compound variants hold their components behind `Arc`, so cloning a
/// `Type` does not deep-copy nested types.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub enum Type {
    /// The bottom type. Every type contains it; it contains only itself
    /// (see `contains_int`). Also the `Default` value.
    Bottom,
    /// The top type; it contains every type.
    Any,
    /// A set of primitive netidx value types.
    Primitive(BitFlags<Typ>),
    /// A named reference to a type definition, resolved through
    /// `lookup_ref` against the environment. `params` are the type
    /// arguments supplied to the definition.
    Ref { scope: ModPath, name: ModPath, params: Arc<[Type]> },
    /// A function type.
    Fn(Arc<FnType>),
    /// A set (union) of alternative types.
    Set(Arc<[Type]>),
    /// A type variable, which may or may not currently be bound to a type.
    TVar(TVar),
    /// An error type wrapping the type of the error payload.
    Error(Arc<Type>),
    /// An array with the given element type.
    Array(Arc<Type>),
    /// A by-reference wrapper around the inner type.
    ByRef(Arc<Type>),
    /// A tuple of positional element types.
    Tuple(Arc<[Type]>),
    /// A struct as `(field name, field type)` pairs; per the note in
    /// `contains_int`, fields are always sorted by field name.
    Struct(Arc<[(ArcStr, Type)]>),
    /// A tagged variant: the tag and its positional argument types.
    Variant(ArcStr, Arc<[Type]>),
    /// A map with the given key and value types.
    Map { key: Arc<Type>, value: Arc<Type> },
}
73
74impl Default for Type {
75    fn default() -> Self {
76        Self::Bottom
77    }
78}
79
80impl Type {
81    pub fn empty_tvar() -> Self {
82        Type::TVar(TVar::default())
83    }
84
85    fn iter_prims(&self) -> impl Iterator<Item = Self> {
86        match self {
87            Self::Primitive(p) => {
88                Either::Left(p.iter().map(|t| Type::Primitive(t.into())))
89            }
90            t => Either::Right(iter::once(t.clone())),
91        }
92    }
93
94    pub fn is_defined(&self) -> bool {
95        match self {
96            Self::Bottom
97            | Self::Any
98            | Self::Primitive(_)
99            | Self::Fn(_)
100            | Self::Set(_)
101            | Self::Error(_)
102            | Self::Array(_)
103            | Self::ByRef(_)
104            | Self::Tuple(_)
105            | Self::Struct(_)
106            | Self::Variant(_, _)
107            | Self::Ref { .. }
108            | Self::Map { .. } => true,
109            Self::TVar(tv) => tv.read().typ.read().is_some(),
110        }
111    }
112
113    pub fn lookup_ref<'a, R: Rt, E: UserEvent>(
114        &'a self,
115        env: &'a Env<R, E>,
116    ) -> Result<&'a Type> {
117        match self {
118            Self::Ref { scope, name, params } => {
119                let def = env
120                    .lookup_typedef(scope, name)
121                    .ok_or_else(|| anyhow!("undefined type {name} in {scope}"))?;
122                if def.params.len() != params.len() {
123                    bail!("{} expects {} type parameters", name, def.params.len());
124                }
125                def.typ.unbind_tvars();
126                for ((tv, ct), arg) in def.params.iter().zip(params.iter()) {
127                    if let Some(ct) = ct {
128                        ct.check_contains(env, arg)?;
129                    }
130                    if !tv.would_cycle(arg) {
131                        match arg {
132                            Type::TVar(arg_tv) => match &*arg_tv.read().typ.read() {
133                                None => *tv.read().typ.write() = Some(arg.clone()),
134                                Some(t) => *tv.read().typ.write() = Some(t.clone()),
135                            },
136                            _ => {
137                                *tv.read().typ.write() = Some(arg.clone());
138                            }
139                        }
140                    }
141                }
142                Ok(&def.typ)
143            }
144            t => Ok(t),
145        }
146    }
147
148    pub fn check_contains<R: Rt, E: UserEvent>(
149        &self,
150        env: &Env<R, E>,
151        t: &Self,
152    ) -> Result<()> {
153        if self.contains(env, t)? {
154            Ok(())
155        } else {
156            format_with_flags(PrintFlag::DerefTVars | PrintFlag::ReplacePrims, || {
157                bail!("type mismatch {self} does not contain {t}")
158            })
159        }
160    }
161
    /// Core containment check: does `self` contain `t`?
    ///
    /// * `flags` - which type-variable side effects are allowed. Writes into
    ///   unbound type variables are gated on `InitTVars` and aliasing of
    ///   type variables is gated on `AliasTVars`.
    /// * `env` - environment used to resolve `Ref` types via `lookup_ref`.
    /// * `hist` - memo of results keyed by the addresses of the two resolved
    ///   types, consulted only in the `Ref` resolution case.
    /// * `t` - the candidate contained type.
    ///
    /// Returns `Ok(true)` / `Ok(false)` for the containment result, or an
    /// error propagated from `lookup_ref` or from a recursive check.
    ///
    /// The order of the match arms is significant: earlier arms shadow the
    /// broader patterns that follow them.
    fn contains_int<R: Rt, E: UserEvent>(
        &self,
        flags: BitFlags<ContainsFlags>,
        env: &Env<R, E>,
        hist: &mut FxHashMap<(usize, usize), bool>,
        t: &Self,
    ) -> Result<bool> {
        // Fast path: the very same object trivially contains itself.
        if (self as *const Type) == (t as *const Type) {
            return Ok(true);
        }
        match (self, t) {
            // Two references to the same typedef: compare the type
            // arguments pairwise without resolving the definition.
            (
                Self::Ref { scope: s0, name: n0, params: p0 },
                Self::Ref { scope: s1, name: n1, params: p1 },
            ) if s0 == s1 && n0 == n1 => Ok(p0.len() == p1.len()
                && p0
                    .iter()
                    .zip(p1.iter())
                    .map(|(t0, t1)| t0.contains_int(flags, env, hist, t1))
                    .collect::<Result<AndAc>>()?
                    .0),
            // At least one side is a reference: resolve both sides (a
            // non-Ref resolves to itself) and memoize on the address pair.
            (t0 @ Self::Ref { .. }, t1) | (t0, t1 @ Self::Ref { .. }) => {
                let t0 = t0.lookup_ref(env)?;
                let t1 = t1.lookup_ref(env)?;
                let t0_addr = (t0 as *const Type).addr();
                let t1_addr = (t1 as *const Type).addr();
                match hist.get(&(t0_addr, t1_addr)) {
                    Some(r) => Ok(*r),
                    None => {
                        // Record a provisional `true` before recursing, so
                        // meeting the same pair again during the recursion
                        // returns immediately instead of recursing forever.
                        hist.insert((t0_addr, t1_addr), true);
                        match t0.contains_int(flags, env, hist, t1) {
                            Ok(r) => {
                                hist.insert((t0_addr, t1_addr), r);
                                Ok(r)
                            }
                            Err(e) => {
                                // Do not leave the provisional entry behind on failure.
                                hist.remove(&(t0_addr, t1_addr));
                                Err(e)
                            }
                        }
                    }
                }
            }
            // A type variable always contains Bottom; an unbound one is
            // initialized to Bottom when InitTVars is set.
            (Self::TVar(t0), Self::Bottom) => {
                if let Some(_) = &*t0.read().typ.read() {
                    return Ok(true);
                }
                if flags.contains(ContainsFlags::InitTVars) {
                    *t0.read().typ.write() = Some(Self::Bottom);
                }
                Ok(true)
            }
            // Bottom contains a type variable only if that variable is
            // bound to Bottom, or InitTVars is set, in which case the
            // variable is overwritten with Bottom.
            (Self::Bottom, Self::TVar(t0)) => {
                if let Some(Type::Bottom) = &*t0.read().typ.read() {
                    return Ok(true);
                }
                if flags.contains(ContainsFlags::InitTVars) {
                    *t0.read().typ.write() = Some(Self::Bottom);
                    return Ok(true);
                }
                Ok(false)
            }
            (Self::Bottom, Self::Bottom) => Ok(true),
            (Self::Bottom, _) => Ok(false),
            (_, Self::Bottom) => Ok(true),
            // A bound type variable defers to its binding; an unbound one
            // is initialized to Any when InitTVars is set.
            (Self::TVar(t0), Self::Any) => {
                if let Some(t0) = &*t0.read().typ.read() {
                    return t0.contains_int(flags, env, hist, t);
                }
                if flags.contains(ContainsFlags::InitTVars) {
                    *t0.read().typ.write() = Some(Self::Any);
                }
                Ok(true)
            }
            (Self::Any, _) => Ok(true),
            (Self::Primitive(p0), Self::Primitive(p1)) => Ok(p0.contains(*p1)),
            // Arrays, tuples, structs and variants are all accepted by a
            // primitive set that includes `Typ::Array`.
            (
                Self::Primitive(p),
                Self::Array(_) | Self::Tuple(_) | Self::Struct(_) | Self::Variant(_, _),
            ) => Ok(p.contains(Typ::Array)),
            (Self::Array(t0), Self::Array(t1)) => t0.contains_int(flags, env, hist, t1),
            // An array type contains the bare `Typ::Array` primitive only
            // if its element type contains Any.
            (Self::Array(t0), Self::Primitive(p)) if *p == BitFlags::from(Typ::Array) => {
                t0.contains_int(flags, env, hist, &Type::Any)
            }
            (Self::Map { key: k0, value: v0 }, Self::Map { key: k1, value: v1 }) => {
                Ok(k0.contains_int(flags, env, hist, k1)?
                    && v0.contains_int(flags, env, hist, v1)?)
            }
            (Self::Primitive(p), Self::Map { .. }) => Ok(p.contains(Typ::Map)),
            // Same rule as arrays: key and value must both contain Any.
            (Self::Map { key, value }, Self::Primitive(p))
                if *p == BitFlags::from(Typ::Map) =>
            {
                Ok(key.contains_int(flags, env, hist, &Type::Any)?
                    && value.contains_int(flags, env, hist, &Type::Any)?)
            }
            (Self::Primitive(p0), Self::Error(_)) => Ok(p0.contains(Typ::Error)),
            (Self::Error(e), Self::Primitive(p)) if *p == BitFlags::from(Typ::Error) => {
                e.contains_int(flags, env, hist, &Type::Any)
            }
            (Self::Error(e0), Self::Error(e1)) => e0.contains_int(flags, env, hist, e1),
            // Tuples sharing the same backing allocation are identical.
            (Self::Tuple(t0), Self::Tuple(t1))
                if t0.as_ptr().addr() == t1.as_ptr().addr() =>
            {
                Ok(true)
            }
            (Self::Tuple(t0), Self::Tuple(t1)) => Ok(t0.len() == t1.len()
                && t0
                    .iter()
                    .zip(t1.iter())
                    .map(|(t0, t1)| t0.contains_int(flags, env, hist, t1))
                    .collect::<Result<AndAc>>()?
                    .0),
            (Self::Struct(t0), Self::Struct(t1))
                if t0.as_ptr().addr() == t1.as_ptr().addr() =>
            {
                Ok(true)
            }
            (Self::Struct(t0), Self::Struct(t1)) => {
                Ok(t0.len() == t1.len() && {
                    // struct types are always sorted by field name
                    t0.iter()
                        .zip(t1.iter())
                        .map(|((n0, t0), (n1, t1))| {
                            Ok(n0 == n1 && t0.contains_int(flags, env, hist, t1)?)
                        })
                        .collect::<Result<AndAc>>()?
                        .0
                })
            }
            (Self::Variant(tg0, t0), Self::Variant(tg1, t1))
                if tg0.as_ptr() == tg1.as_ptr()
                    && t0.as_ptr().addr() == t1.as_ptr().addr() =>
            {
                Ok(true)
            }
            (Self::Variant(tg0, t0), Self::Variant(tg1, t1)) => Ok(tg0 == tg1
                && t0.len() == t1.len()
                && t0
                    .iter()
                    .zip(t1.iter())
                    .map(|(t0, t1)| t0.contains_int(flags, env, hist, t1))
                    .collect::<Result<AndAc>>()?
                    .0),
            (Self::ByRef(t0), Self::ByRef(t1)) => t0.contains_int(flags, env, hist, t1),
            // The same type variable, by address or by id.
            (Self::TVar(t0), Self::TVar(t1))
                if t0.addr() == t1.addr() || t0.read().id == t1.read().id =>
            {
                Ok(true)
            }
            // Two distinct type variables: decide what to do while the read
            // guards are held, then perform the mutation after they are
            // released.
            (tt0 @ Self::TVar(t0), tt1 @ Self::TVar(t1)) => {
                #[derive(Debug)]
                enum Act {
                    RightCopy,
                    RightAlias,
                    LeftAlias,
                    LeftCopy,
                }
                // Linking the two would create a cycle: report containment
                // without modifying either variable.
                if t0.would_cycle(tt1) || t1.would_cycle(tt0) {
                    return Ok(true);
                }
                let act = {
                    let t0 = t0.read();
                    let t1 = t1.read();
                    let addr0 = Arc::as_ptr(&t0.typ).addr();
                    let addr1 = Arc::as_ptr(&t1.typ).addr();
                    // Both variables already share the same binding cell.
                    if addr0 == addr1 {
                        return Ok(true);
                    }
                    if would_cycle_inner(addr0, tt1) || would_cycle_inner(addr1, tt0) {
                        return Ok(true);
                    }
                    let t0i = t0.typ.read();
                    let t1i = t1.typ.read();
                    match (&*t0i, &*t1i) {
                        // Both bound: compare the bound types.
                        (Some(t0), Some(t1)) => {
                            return t0.contains_int(flags, env, hist, &*t1)
                        }
                        // Neither bound: alias the non-frozen one to the
                        // other, unless both are frozen.
                        (None, None) => {
                            if t0.frozen && t1.frozen {
                                return Ok(true);
                            }
                            if t0.frozen {
                                Act::RightAlias
                            } else {
                                Act::LeftAlias
                            }
                        }
                        // Exactly one bound: copy into the unbound side.
                        (Some(_), None) => Act::RightCopy,
                        (None, Some(_)) => Act::LeftCopy,
                    }
                };
                // Apply the chosen action only if the matching flag is set;
                // otherwise fall through to the no-op arm.
                match act {
                    Act::RightCopy if flags.contains(ContainsFlags::InitTVars) => {
                        t1.copy(t0)
                    }
                    Act::RightAlias if flags.contains(ContainsFlags::AliasTVars) => {
                        t1.alias(t0)
                    }
                    Act::LeftAlias if flags.contains(ContainsFlags::AliasTVars) => {
                        t0.alias(t1)
                    }
                    Act::LeftCopy if flags.contains(ContainsFlags::InitTVars) => {
                        t0.copy(t1)
                    }
                    Act::RightCopy | Act::RightAlias | Act::LeftAlias | Act::LeftCopy => {
                        ()
                    }
                }
                Ok(true)
            }
            // Type variable on the left: defer to its binding, or bind it
            // to `t1` when unbound and InitTVars is set.
            (Self::TVar(t0), t1) if !t0.would_cycle(t1) => {
                if let Some(t0) = &*t0.read().typ.read() {
                    return t0.contains_int(flags, env, hist, t1);
                }
                if flags.contains(ContainsFlags::InitTVars) {
                    *t0.read().typ.write() = Some(t1.clone());
                }
                Ok(true)
            }
            // Type variable on the right: the mirror image of the arm above.
            (t0, Self::TVar(t1)) if !t1.would_cycle(t0) => {
                if let Some(t1) = &*t1.read().typ.read() {
                    return t0.contains_int(flags, env, hist, t1);
                }
                if flags.contains(ContainsFlags::InitTVars) {
                    *t1.read().typ.write() = Some(t0.clone());
                }
                Ok(true)
            }
            (Self::Set(s0), Self::Set(s1))
                if s0.as_ptr().addr() == s1.as_ptr().addr() =>
            {
                Ok(true)
            }
            // `t0` contains a set only if it contains every member.
            (t0, Self::Set(s)) => Ok(s
                .iter()
                .map(|t1| t0.contains_int(flags, env, hist, t1))
                .collect::<Result<AndAc>>()?
                .0),
            // A set contains `t` if some member contains all of `t`, or if
            // every primitive component of `t` (see `iter_prims`) is
            // contained by some member.
            (Self::Set(s), t) => Ok(s
                .iter()
                .fold(Ok::<_, anyhow::Error>(false), |acc, t0| {
                    Ok(acc? || t0.contains_int(flags, env, hist, t)?)
                })?
                || t.iter_prims().fold(Ok::<_, anyhow::Error>(true), |acc, t1| {
                    Ok(acc?
                        && s.iter().fold(Ok::<_, anyhow::Error>(false), |acc, t0| {
                            Ok(acc? || t0.contains_int(flags, env, hist, &t1)?)
                        })?)
                })?),
            (Self::Fn(f0), Self::Fn(f1)) => {
                Ok(f0.as_ptr() == f1.as_ptr() || f0.contains_int(flags, env, hist, f1)?)
            }
            // Every remaining combination is a mismatch. This includes the
            // type-variable cases rejected by the `would_cycle` guards above.
            (_, Self::Any)
            | (_, Self::TVar(_))
            | (Self::TVar(_), _)
            | (Self::Fn(_), _)
            | (Self::ByRef(_), _)
            | (_, Self::ByRef(_))
            | (_, Self::Fn(_))
            | (Self::Tuple(_), Self::Array(_))
            | (Self::Tuple(_), Self::Primitive(_))
            | (Self::Tuple(_), Self::Struct(_))
            | (Self::Tuple(_), Self::Variant(_, _))
            | (Self::Tuple(_), Self::Error(_))
            | (Self::Tuple(_), Self::Map { .. })
            | (Self::Array(_), Self::Primitive(_))
            | (Self::Array(_), Self::Tuple(_))
            | (Self::Array(_), Self::Struct(_))
            | (Self::Array(_), Self::Variant(_, _))
            | (Self::Array(_), Self::Error(_))
            | (Self::Array(_), Self::Map { .. })
            | (Self::Struct(_), Self::Array(_))
            | (Self::Struct(_), Self::Primitive(_))
            | (Self::Struct(_), Self::Tuple(_))
            | (Self::Struct(_), Self::Variant(_, _))
            | (Self::Struct(_), Self::Error(_))
            | (Self::Struct(_), Self::Map { .. })
            | (Self::Variant(_, _), Self::Array(_))
            | (Self::Variant(_, _), Self::Struct(_))
            | (Self::Variant(_, _), Self::Primitive(_))
            | (Self::Variant(_, _), Self::Tuple(_))
            | (Self::Variant(_, _), Self::Error(_))
            | (Self::Variant(_, _), Self::Map { .. })
            | (Self::Error(_), Self::Array(_))
            | (Self::Error(_), Self::Primitive(_))
            | (Self::Error(_), Self::Struct(_))
            | (Self::Error(_), Self::Variant(_, _))
            | (Self::Error(_), Self::Tuple(_))
            | (Self::Error(_), Self::Map { .. })
            | (Self::Map { .. }, Self::Array(_))
            | (Self::Map { .. }, Self::Primitive(_))
            | (Self::Map { .. }, Self::Struct(_))
            | (Self::Map { .. }, Self::Variant(_, _))
            | (Self::Map { .. }, Self::Tuple(_))
            | (Self::Map { .. }, Self::Error(_)) => Ok(false),
        }
    }
459
460    pub fn contains<R: Rt, E: UserEvent>(
461        &self,
462        env: &Env<R, E>,
463        t: &Self,
464    ) -> Result<bool> {
465        self.contains_int(
466            ContainsFlags::AliasTVars | ContainsFlags::InitTVars,
467            env,
468            &mut LPooled::take(),
469            t,
470        )
471    }
472
473    pub fn contains_with_flags<R: Rt, E: UserEvent>(
474        &self,
475        flags: BitFlags<ContainsFlags>,
476        env: &Env<R, E>,
477        t: &Self,
478    ) -> Result<bool> {
479        self.contains_int(flags, env, &mut LPooled::take(), t)
480    }
481
    /// Looser relative of `contains_int`: could a value of type `t`
    /// possibly be matched by `self`?
    ///
    /// Sets succeed if any member could match (on either side), and an
    /// unbound type variable on either side succeeds. Calls into
    /// `contains_int` pass an empty flag set, so the flag-gated binding and
    /// aliasing of type variables there is disabled.
    ///
    /// * `env` - environment used to resolve `Ref` types via `lookup_ref`.
    /// * `hist` - memo of results keyed by the addresses of the two resolved
    ///   types, consulted only in the `Ref` resolution case.
    ///
    /// The order of the match arms is significant.
    fn could_match_int<R: Rt, E: UserEvent>(
        &self,
        env: &Env<R, E>,
        hist: &mut FxHashMap<(usize, usize), bool>,
        t: &Self,
    ) -> Result<bool> {
        // No ContainsFlags set for the contains_int calls below.
        let fl = BitFlags::empty();
        match (self, t) {
            // Two references to the same typedef: compare the type
            // arguments pairwise.
            (
                Self::Ref { scope: s0, name: n0, params: p0 },
                Self::Ref { scope: s1, name: n1, params: p1 },
            ) if s0 == s1 && n0 == n1 => Ok(p0.len() == p1.len()
                && p0
                    .iter()
                    .zip(p1.iter())
                    .map(|(t0, t1)| t0.could_match_int(env, hist, t1))
                    .collect::<Result<AndAc>>()?
                    .0),
            // Resolve references and memoize on the address pair. A
            // provisional `true` is recorded before recursing, so meeting
            // the same pair again during the recursion returns immediately.
            (t0 @ Self::Ref { .. }, t1) | (t0, t1 @ Self::Ref { .. }) => {
                let t0 = t0.lookup_ref(env)?;
                let t1 = t1.lookup_ref(env)?;
                let t0_addr = (t0 as *const Type).addr();
                let t1_addr = (t1 as *const Type).addr();
                match hist.get(&(t0_addr, t1_addr)) {
                    Some(r) => Ok(*r),
                    None => {
                        hist.insert((t0_addr, t1_addr), true);
                        match t0.could_match_int(env, hist, t1) {
                            Ok(r) => {
                                hist.insert((t0_addr, t1_addr), r);
                                Ok(r)
                            }
                            Err(e) => {
                                // Do not leave the provisional entry behind on failure.
                                hist.remove(&(t0_addr, t1_addr));
                                Err(e)
                            }
                        }
                    }
                }
            }
            // Primitive set on the right: succeed if `t0` contains any
            // single primitive from the set.
            (t0, Self::Primitive(s)) => {
                for t1 in s.iter() {
                    if t0.contains_int(fl, env, hist, &Type::Primitive(t1.into()))? {
                        return Ok(true);
                    }
                }
                Ok(false)
            }
            (Type::Primitive(p), Type::Error(_)) => Ok(p.contains(Typ::Error)),
            (Type::Error(t0), Type::Error(t1)) => t0.could_match_int(env, hist, t1),
            (Type::Array(t0), Type::Array(t1)) => t0.could_match_int(env, hist, t1),
            (Type::Primitive(p), Type::Array(_)) => Ok(p.contains(Typ::Array)),
            (Type::Map { key: k0, value: v0 }, Type::Map { key: k1, value: v1 }) => {
                Ok(k0.could_match_int(env, hist, k1)?
                    && v0.could_match_int(env, hist, v1)?)
            }
            (Type::Primitive(p), Type::Map { .. }) => Ok(p.contains(Typ::Map)),
            (Type::Tuple(ts0), Type::Tuple(ts1)) => Ok(ts0.len() == ts1.len()
                && ts0
                    .iter()
                    .zip(ts1.iter())
                    .map(|(t0, t1)| t0.could_match_int(env, hist, t1))
                    .collect::<Result<AndAc>>()?
                    .0),
            // Field names must agree position by position.
            (Type::Struct(ts0), Type::Struct(ts1)) => Ok(ts0.len() == ts1.len()
                && ts0
                    .iter()
                    .zip(ts1.iter())
                    .map(|((n0, t0), (n1, t1))| {
                        Ok(n0 == n1 && t0.could_match_int(env, hist, t1)?)
                    })
                    .collect::<Result<AndAc>>()?
                    .0),
            (Type::Variant(n0, ts0), Type::Variant(n1, ts1)) => Ok(ts0.len()
                == ts1.len()
                && n0 == n1
                && ts0
                    .iter()
                    .zip(ts1.iter())
                    .map(|(t0, t1)| t0.could_match_int(env, hist, t1))
                    .collect::<Result<AndAc>>()?
                    .0),
            (Type::ByRef(t0), Type::ByRef(t1)) => t0.could_match_int(env, hist, t1),
            // Set on the right: any member that could match is enough.
            (t0, Self::Set(ts)) => {
                for t1 in ts.iter() {
                    if t0.could_match_int(env, hist, t1)? {
                        return Ok(true);
                    }
                }
                Ok(false)
            }
            // Set on the left: likewise, any member is enough.
            (Type::Set(ts), t1) => {
                for t0 in ts.iter() {
                    if t0.could_match_int(env, hist, t1)? {
                        return Ok(true);
                    }
                }
                Ok(false)
            }
            // A bound type variable defers to its binding; an unbound one
            // could match anything.
            (Type::TVar(t0), t1) => match &*t0.read().typ.read() {
                Some(t0) => t0.could_match_int(env, hist, t1),
                None => Ok(true),
            },
            (t0, Type::TVar(t1)) => match &*t1.read().typ.read() {
                Some(t1) => t0.could_match_int(env, hist, t1),
                None => Ok(true),
            },
            (_, Type::Bottom) => Ok(true),
            (Type::Bottom, _) => Ok(false),
            (Type::Any, _) | (_, Type::Any) => Ok(true),
            // Every remaining combination cannot match.
            (Type::Fn(_), _)
            | (_, Type::Fn(_))
            | (Type::Tuple(_), _)
            | (_, Type::Tuple(_))
            | (Type::Struct(_), _)
            | (_, Type::Struct(_))
            | (Type::Variant(_, _), _)
            | (_, Type::Variant(_, _))
            | (Type::ByRef(_), _)
            | (_, Type::ByRef(_))
            | (Type::Array(_), _)
            | (_, Type::Array(_))
            | (_, Type::Map { .. })
            | (Type::Map { .. }, _) => Ok(false),
        }
    }
608
609    pub fn could_match<R: Rt, E: UserEvent>(
610        &self,
611        env: &Env<R, E>,
612        t: &Self,
613    ) -> Result<bool> {
614        self.could_match_int(env, &mut LPooled::take(), t)
615    }
616
617    pub fn sig_matches<R: Rt, E: UserEvent>(
618        &self,
619        env: &Env<R, E>,
620        impl_type: &Self,
621    ) -> Result<()> {
622        self.sig_matches_int(env, impl_type, &mut LPooled::take(), &mut LPooled::take())
623    }
624
    /// Structurally compare a signature type (`self`) against an
    /// implementation type (`impl_type`).
    ///
    /// * `env` - environment used to resolve `Ref` types via `lookup_ref`.
    /// * `tvar_map` - for each implementation type variable (keyed by its
    ///   `inner_addr`), the signature type first seen opposite it. Later
    ///   occurrences of that variable must face the same signature type.
    /// * `hist` - address pairs of resolved references already visited;
    ///   meeting a pair again is treated as a match.
    ///
    /// Returns `Ok(())` on a match and an error describing the first
    /// mismatch otherwise. Error messages are rendered with the
    /// `DerefTVars` print flag. The order of the match arms is significant.
    pub(crate) fn sig_matches_int<R: Rt, E: UserEvent>(
        &self,
        env: &Env<R, E>,
        impl_type: &Self,
        tvar_map: &mut FxHashMap<usize, Type>,
        hist: &mut FxHashSet<(usize, usize)>,
    ) -> Result<()> {
        // Fast path: the very same object matches itself.
        if (self as *const Type) == (impl_type as *const Type) {
            return Ok(());
        }
        match (self, impl_type) {
            (Self::Bottom, Self::Bottom) => Ok(()),
            (Self::Any, Self::Any) => Ok(()),
            // Primitive sets must be exactly equal, not merely contained.
            (Self::Primitive(p0), Self::Primitive(p1)) if p0 == p1 => Ok(()),
            // Two references to the same typedef with the same arity:
            // compare the type arguments pairwise.
            (
                Self::Ref { scope: s0, name: n0, params: p0 },
                Self::Ref { scope: s1, name: n1, params: p1 },
            ) if s0 == s1 && n0 == n1 && p0.len() == p1.len() => {
                for (t0, t1) in p0.iter().zip(p1.iter()) {
                    t0.sig_matches_int(env, t1, tvar_map, hist)?;
                }
                Ok(())
            }
            // Otherwise resolve references (a non-Ref resolves to itself)
            // and compare the resolved types, visiting each pair only once.
            (t0 @ Self::Ref { .. }, t1) | (t0, t1 @ Self::Ref { .. }) => {
                let t0 = t0.lookup_ref(env)?;
                let t1 = t1.lookup_ref(env)?;
                let t0_addr = (t0 as *const Type).addr();
                let t1_addr = (t1 as *const Type).addr();
                if hist.contains(&(t0_addr, t1_addr)) {
                    Ok(())
                } else {
                    hist.insert((t0_addr, t1_addr));
                    t0.sig_matches_int(env, t1, tvar_map, hist)
                }
            }
            (Self::Fn(f0), Self::Fn(f1)) => f0.sig_matches_int(env, f1, tvar_map, hist),
            // Sets are compared member by member in order.
            (Self::Set(s0), Self::Set(s1)) if s0.len() == s1.len() => {
                for (t0, t1) in s0.iter().zip(s1.iter()) {
                    t0.sig_matches_int(env, t1, tvar_map, hist)?;
                }
                Ok(())
            }
            (Self::Error(e0), Self::Error(e1)) => {
                e0.sig_matches_int(env, e1, tvar_map, hist)
            }
            (Self::Array(a0), Self::Array(a1)) => {
                a0.sig_matches_int(env, a1, tvar_map, hist)
            }
            (Self::ByRef(b0), Self::ByRef(b1)) => {
                b0.sig_matches_int(env, b1, tvar_map, hist)
            }
            (Self::Tuple(t0), Self::Tuple(t1)) if t0.len() == t1.len() => {
                for (t0, t1) in t0.iter().zip(t1.iter()) {
                    t0.sig_matches_int(env, t1, tvar_map, hist)?;
                }
                Ok(())
            }
            // Structs: field names must agree position by position, then
            // the field types are compared.
            (Self::Struct(s0), Self::Struct(s1)) if s0.len() == s1.len() => {
                for ((n0, t0), (n1, t1)) in s0.iter().zip(s1.iter()) {
                    if n0 != n1 {
                        format_with_flags(PrintFlag::DerefTVars, || {
                            bail!("struct field name mismatch: {n0} vs {n1}")
                        })?
                    }
                    t0.sig_matches_int(env, t1, tvar_map, hist)?;
                }
                Ok(())
            }
            (Self::Variant(tag0, t0), Self::Variant(tag1, t1))
                if tag0 == tag1 && t0.len() == t1.len() =>
            {
                for (t0, t1) in t0.iter().zip(t1.iter()) {
                    t0.sig_matches_int(env, t1, tvar_map, hist)?;
                }
                Ok(())
            }
            (Self::Map { key: k0, value: v0 }, Self::Map { key: k1, value: v1 }) => {
                k0.sig_matches_int(env, k1, tvar_map, hist)?;
                v0.sig_matches_int(env, v1, tvar_map, hist)
            }
            // Type variables on both sides that compare unequal are a mismatch.
            (Self::TVar(sig_tv), Self::TVar(impl_tv)) if sig_tv != impl_tv => {
                format_with_flags(PrintFlag::DerefTVars, || {
                    bail!("signature type variable {sig_tv} does not match implementation {impl_tv}")
                })
            }
            // Implementation side is a type variable: it must be used
            // consistently. The first occurrence records the signature
            // type; later occurrences must face the same one.
            (sig_type, Self::TVar(impl_tv)) => {
                let impl_tv_addr = impl_tv.inner_addr();
                match tvar_map.get(&impl_tv_addr) {
                    Some(prev_sig_type) => {
                        // Type variables are compared by inner address,
                        // everything else by equality.
                        let matches = match (sig_type, prev_sig_type) {
                            (Type::TVar(tv0), Type::TVar(tv1)) => {
                                tv0.inner_addr() == tv1.inner_addr()
                            }
                            _ => sig_type == prev_sig_type,
                        };
                        if matches {
                            Ok(())
                        } else {
                            format_with_flags(PrintFlag::DerefTVars, || {
                                bail!(
                                    "type variable usage mismatch: expected {prev_sig_type}, got {sig_type}"
                                )
                            })
                        }
                    }
                    None => {
                        tvar_map.insert(impl_tv_addr, sig_type.clone());
                        Ok(())
                    }
                }
            }
            // Signature has a type variable where the implementation has
            // something else.
            (Self::TVar(sig_tv), impl_type) => {
                format_with_flags(PrintFlag::DerefTVars, || {
                    bail!("signature has type variable '{sig_tv} where implementation has {impl_type}")
                })
            }
            // Every other combination is a mismatch. This includes the
            // guarded arms above whose guard failed (for example differing
            // lengths or tags).
            (sig_type, impl_type) => format_with_flags(PrintFlag::DerefTVars, || {
                bail!("type mismatch: signature has {sig_type}, implementation has {impl_type}")
            }),
        }
    }
746
    /// Recursive worker behind [`Type::union`]: compute the un-normalized
    /// union of `self` and `t`.
    ///
    /// * `env` is used to resolve `Type::Ref`s through `lookup_ref`.
    /// * `hist` maps a pair of type addresses to the result recorded for that
    ///   pair. A pair is entered with a placeholder before recursing and is
    ///   overwritten with the real result afterwards, so meeting the same pair
    ///   again returns the recorded value instead of recursing (presumably to
    ///   terminate on self-referential type definitions; confirm against
    ///   `lookup_ref`).
    /// * `t` is the other operand of the union.
    ///
    /// When the two operands cannot be merged into a single constructor the
    /// result is a `Type::Set` holding both. The public wrapper `union`
    /// normalizes the value returned from here.
    ///
    /// Errors are those returned by `lookup_ref` or by a nested union.
    ///
    /// NOTE: arm order is significant; the specific arms near the top shadow
    /// the broader catch-all style arms further down.
    fn union_int<R: Rt, E: UserEvent>(
        &self,
        env: &Env<R, E>,
        hist: &mut FxHashMap<(usize, usize), Type>,
        t: &Self,
    ) -> Result<Self> {
        match (self, t) {
            // Two references to the same named type with the same number of
            // parameters: union the parameters pairwise and keep the reference.
            (
                Type::Ref { scope: s0, name: n0, params: p0 },
                Type::Ref { scope: s1, name: n1, params: p1 },
            ) if n0 == n1 && s0 == s1 && p0.len() == p1.len() => {
                let mut params = p0
                    .iter()
                    .zip(p1.iter())
                    .map(|(p0, p1)| p0.union_int(env, hist, p1))
                    .collect::<Result<LPooled<Vec<_>>>>()?;
                let params = Arc::from_iter(params.drain(..));
                Ok(Self::Ref { scope: s0.clone(), name: n0.clone(), params })
            }
            // Reference on the left: resolve it, then union the resolved type
            // with `t`, memoizing on the (resolved, other) address pair.
            (tr @ Type::Ref { .. }, t) => {
                let t0 = tr.lookup_ref(env)?;
                let t0_addr = (t0 as *const Type).addr();
                let t_addr = (t as *const Type).addr();
                match hist.get(&(t0_addr, t_addr)) {
                    Some(t) => Ok(t.clone()),
                    None => {
                        // placeholder: the unresolved reference itself
                        hist.insert((t0_addr, t_addr), tr.clone());
                        let r = t0.union_int(env, hist, t)?;
                        hist.insert((t0_addr, t_addr), r.clone());
                        Ok(r)
                    }
                }
            }
            // Reference on the right: mirror image of the arm above.
            (t, tr @ Type::Ref { .. }) => {
                let t1 = tr.lookup_ref(env)?;
                let t1_addr = (t1 as *const Type).addr();
                let t_addr = (t as *const Type).addr();
                match hist.get(&(t_addr, t1_addr)) {
                    Some(t) => Ok(t.clone()),
                    None => {
                        hist.insert((t_addr, t1_addr), tr.clone());
                        let r = t.union_int(env, hist, t1)?;
                        hist.insert((t_addr, t1_addr), r.clone());
                        Ok(r)
                    }
                }
            }
            // Bottom contributes nothing; Any absorbs everything.
            (Type::Bottom, t) | (t, Type::Bottom) => Ok(t.clone()),
            (Type::Any, _) | (_, Type::Any) => Ok(Type::Any),
            // An empty primitive set contributes nothing either.
            (Type::Primitive(p), t) | (t, Type::Primitive(p)) if p.is_empty() => {
                Ok(t.clone())
            }
            (Type::Primitive(s0), Type::Primitive(s1)) => {
                let mut s = *s0;
                s.insert(*s1);
                Ok(Type::Primitive(s))
            }
            // A primitive set containing `Typ::Array` absorbs arrays, structs,
            // tuples and variants.
            (
                Type::Primitive(p),
                Type::Array(_) | Type::Struct(_) | Type::Tuple(_) | Type::Variant(_, _),
            )
            | (
                Type::Array(_) | Type::Struct(_) | Type::Tuple(_) | Type::Variant(_, _),
                Type::Primitive(p),
            ) if p.contains(Typ::Array) => Ok(Type::Primitive(*p)),
            (Type::Primitive(p), Type::Array(t))
            | (Type::Array(t), Type::Primitive(p)) => Ok(Type::Set(Arc::from_iter([
                Type::Primitive(*p),
                Type::Array(t.clone()),
            ]))),
            (t @ Type::Array(t0), u @ Type::Array(t1)) => {
                if t0 == t1 {
                    Ok(Type::Array(t0.clone()))
                } else {
                    Ok(Type::Set(Arc::from_iter([t.clone(), u.clone()])))
                }
            }
            // A primitive set containing `Typ::Map` absorbs any map type.
            (Type::Primitive(p), Type::Map { .. })
            | (Type::Map { .. }, Type::Primitive(p))
                if p.contains(Typ::Map) =>
            {
                Ok(Type::Primitive(*p))
            }
            (Type::Primitive(p), Type::Map { key, value })
            | (Type::Map { key, value }, Type::Primitive(p)) => {
                Ok(Type::Set(Arc::from_iter([
                    Type::Primitive(*p),
                    Type::Map { key: key.clone(), value: value.clone() },
                ])))
            }
            (
                t @ Type::Map { key: k0, value: v0 },
                u @ Type::Map { key: k1, value: v1 },
            ) => {
                if k0 == k1 && v0 == v1 {
                    Ok(Type::Map { key: k0.clone(), value: v0.clone() })
                } else {
                    Ok(Type::Set(Arc::from_iter([t.clone(), u.clone()])))
                }
            }
            (t @ Type::Map { .. }, u) | (u, t @ Type::Map { .. }) => {
                Ok(Type::Set(Arc::from_iter([t.clone(), u.clone()])))
            }
            // A primitive set containing `Typ::Error` absorbs any error type.
            (Type::Primitive(p), Type::Error(_))
            | (Type::Error(_), Type::Primitive(p))
                if p.contains(Typ::Error) =>
            {
                Ok(Type::Primitive(*p))
            }
            // Two error types merge into one error whose payload is the union
            // of the payloads.
            (Type::Error(e0), Type::Error(e1)) => {
                Ok(Type::Error(Arc::new(e0.union_int(env, hist, e1)?)))
            }
            (e @ Type::Error(_), t) | (t, e @ Type::Error(_)) => {
                Ok(Type::Set(Arc::from_iter([e.clone(), t.clone()])))
            }
            (t @ Type::ByRef(t0), u @ Type::ByRef(t1)) => {
                if t0 == t1 {
                    Ok(Type::ByRef(t0.clone()))
                } else {
                    Ok(Type::Set(Arc::from_iter([u.clone(), t.clone()])))
                }
            }
            // Sets are concatenated here without deduplication.
            (Type::Set(s0), Type::Set(s1)) => Ok(Type::Set(Arc::from_iter(
                s0.iter().cloned().chain(s1.iter().cloned()),
            ))),
            (Type::Set(s), t) | (t, Type::Set(s)) => Ok(Type::Set(Arc::from_iter(
                s.iter().cloned().chain(iter::once(t.clone())),
            ))),
            (u @ Type::Struct(t0), t @ Type::Struct(t1)) => {
                if t0.len() == t1.len() && t0 == t1 {
                    Ok(u.clone())
                } else {
                    Ok(Type::Set(Arc::from_iter([u.clone(), t.clone()])))
                }
            }
            (u @ Type::Struct(_), t) | (t, u @ Type::Struct(_)) => {
                Ok(Type::Set(Arc::from_iter([u.clone(), t.clone()])))
            }
            (u @ Type::Tuple(t0), t @ Type::Tuple(t1)) => {
                if t0 == t1 {
                    Ok(u.clone())
                } else {
                    Ok(Type::Set(Arc::from_iter([u.clone(), t.clone()])))
                }
            }
            (u @ Type::Tuple(_), t) | (t, u @ Type::Tuple(_)) => {
                Ok(Type::Set(Arc::from_iter([u.clone(), t.clone()])))
            }
            // Variants with the same tag and arity are merged field by field;
            // anything else becomes a two element set.
            (u @ Type::Variant(tg0, t0), t @ Type::Variant(tg1, t1)) => {
                if tg0 == tg1 && t0.len() == t1.len() {
                    let typs = t0
                        .iter()
                        .zip(t1.iter())
                        .map(|(t0, t1)| t0.union_int(env, hist, t1))
                        .collect::<Result<SmallVec<[_; 8]>>>()?;
                    Ok(Type::Variant(tg0.clone(), Arc::from_iter(typs.into_iter())))
                } else {
                    Ok(Type::Set(Arc::from_iter([u.clone(), t.clone()])))
                }
            }
            (u @ Type::Variant(_, _), t) | (t, u @ Type::Variant(_, _)) => {
                Ok(Type::Set(Arc::from_iter([u.clone(), t.clone()])))
            }
            (Type::Fn(f0), Type::Fn(f1)) => {
                if f0 == f1 {
                    Ok(Type::Fn(f0.clone()))
                } else {
                    Ok(Type::Set(Arc::from_iter([
                        Type::Fn(f0.clone()),
                        Type::Fn(f1.clone()),
                    ])))
                }
            }
            (f @ Type::Fn(_), t) | (t, f @ Type::Fn(_)) => {
                Ok(Type::Set(Arc::from_iter([f.clone(), t.clone()])))
            }
            // Type variables are kept as they are; their bindings are not
            // inspected by this function.
            (t0 @ Type::TVar(_), t1 @ Type::TVar(_)) => {
                if t0 == t1 {
                    Ok(t0.clone())
                } else {
                    Ok(Type::Set(Arc::from_iter([t0.clone(), t1.clone()])))
                }
            }
            (t0 @ Type::TVar(_), t1) | (t1, t0 @ Type::TVar(_)) => {
                Ok(Type::Set(Arc::from_iter([t0.clone(), t1.clone()])))
            }
            (t @ Type::ByRef(_), u) | (u, t @ Type::ByRef(_)) => {
                Ok(Type::Set(Arc::from_iter([t.clone(), u.clone()])))
            }
        }
    }
938
939    pub fn union<R: Rt, E: UserEvent>(&self, env: &Env<R, E>, t: &Self) -> Result<Self> {
940        Ok(self.union_int(env, &mut LPooled::take(), t)?.normalize())
941    }
942
    /// Recursive worker behind [`Type::diff`]: compute the un-normalized
    /// difference `self - t`, i.e. what is left of `self` once `t` has been
    /// removed from it.
    ///
    /// * `env` is used to resolve `Type::Ref`s through `lookup_ref`.
    /// * `hist` maps a pair of resolved type addresses to the result recorded
    ///   for that pair. A pair is seeded with the empty primitive set before
    ///   recursing, overwritten with the real result on success, and removed
    ///   again on failure.
    /// * `t` is the type being subtracted.
    ///
    /// "Nothing left" is represented as `Type::Primitive(BitFlags::empty())`.
    /// The public wrapper `diff` normalizes the value returned from here.
    ///
    /// Errors are those returned by `lookup_ref` or by a nested difference.
    ///
    /// NOTE: arm order is significant; the final arm is a catch-all that
    /// returns `self` unchanged for every combination not handled above it.
    fn diff_int<R: Rt, E: UserEvent>(
        &self,
        env: &Env<R, E>,
        hist: &mut FxHashMap<(usize, usize), Type>,
        t: &Self,
    ) -> Result<Self> {
        match (self, t) {
            // Same named type on both sides: nothing is left. The type
            // parameters are not compared here.
            (
                Type::Ref { scope: s0, name: n0, .. },
                Type::Ref { scope: s1, name: n1, .. },
            ) if s0 == s1 && n0 == n1 => Ok(Type::Primitive(BitFlags::empty())),
            // At least one side is a reference. `lookup_ref` is called on both
            // operands, including whichever one is not a reference.
            (t0 @ Type::Ref { .. }, t1) | (t0, t1 @ Type::Ref { .. }) => {
                let t0 = t0.lookup_ref(env)?;
                let t1 = t1.lookup_ref(env)?;
                let t0_addr = (t0 as *const Type).addr();
                let t1_addr = (t1 as *const Type).addr();
                match hist.get(&(t0_addr, t1_addr)) {
                    Some(r) => Ok(r.clone()),
                    None => {
                        // seed with the empty result so a re-entrant visit of
                        // this pair returns immediately
                        let r = Type::Primitive(BitFlags::empty());
                        hist.insert((t0_addr, t1_addr), r);
                        match t0.diff_int(env, hist, &t1) {
                            Ok(r) => {
                                hist.insert((t0_addr, t1_addr), r.clone());
                                Ok(r)
                            }
                            Err(e) => {
                                // do not leave the seed behind on failure
                                hist.remove(&(t0_addr, t1_addr));
                                Err(e)
                            }
                        }
                    }
                }
            }
            // Set minus set: subtract every member of `s1` from each member of
            // `s0` in turn, then flatten what remains.
            (Type::Set(s0), Type::Set(s1)) => {
                let mut s: SmallVec<[Type; 4]> = smallvec![];
                for i in 0..s0.len() {
                    s.push(s0[i].clone());
                    for j in 0..s1.len() {
                        s[i] = s[i].diff_int(env, hist, &s1[j])?
                    }
                }
                Ok(Self::flatten_set(s.into_iter()))
            }
            // Set minus single type: subtract `t` from every member.
            (Type::Set(s), t) => Ok(Self::flatten_set(
                s.iter()
                    .map(|s| s.diff_int(env, hist, t))
                    .collect::<Result<SmallVec<[_; 8]>>>()?,
            )),
            // Single type minus set: subtract the members one after another.
            (t, Type::Set(s)) => {
                let mut t = t.clone();
                for st in s.iter() {
                    t = t.diff_int(env, hist, st)?;
                }
                Ok(t)
            }
            // Tuples, structs, variants and maps are all-or-nothing: equal
            // operands leave nothing, unequal operands leave `self` intact.
            (Type::Tuple(t0), Type::Tuple(t1)) => {
                if t0 == t1 {
                    Ok(Type::Primitive(BitFlags::empty()))
                } else {
                    Ok(self.clone())
                }
            }
            (Type::Struct(t0), Type::Struct(t1)) => {
                if t0.len() == t1.len() && t0 == t1 {
                    Ok(Type::Primitive(BitFlags::empty()))
                } else {
                    Ok(self.clone())
                }
            }
            (Type::Variant(tg0, t0), Type::Variant(tg1, t1)) => {
                if tg0 == tg1 && t0.len() == t1.len() && t0 == t1 {
                    Ok(Type::Primitive(BitFlags::empty()))
                } else {
                    Ok(self.clone())
                }
            }
            (Type::Map { key: k0, value: v0 }, Type::Map { key: k1, value: v1 }) => {
                if k0 == k1 && v0 == v1 {
                    Ok(Type::Primitive(BitFlags::empty()))
                } else {
                    Ok(self.clone())
                }
            }
            (Type::Map { .. }, Type::Primitive(p)) => {
                if p.contains(Typ::Map) {
                    Ok(Type::Primitive(BitFlags::empty()))
                } else {
                    Ok(self.clone())
                }
            }
            // Only a map of Any to Any removes `Typ::Map` from a primitive
            // set; a more specific map leaves the set untouched.
            (Type::Primitive(p), Type::Map { key, value }) => {
                if **key == Type::Any && **value == Type::Any {
                    let mut p = *p;
                    p.remove(Typ::Map);
                    Ok(Type::Primitive(p))
                } else {
                    Ok(Type::Primitive(*p))
                }
            }
            (Type::Fn(f0), Type::Fn(f1)) => {
                if f0 == f1 {
                    Ok(Type::Primitive(BitFlags::empty()))
                } else {
                    Ok(Type::Fn(f0.clone()))
                }
            }
            // Two type variables sharing the same underlying `typ` cell leave
            // nothing. If either is unbound the left variable is returned,
            // otherwise the difference of the two bound types.
            (Type::TVar(tv0), Type::TVar(tv1)) => {
                if tv0.read().typ.as_ptr() == tv1.read().typ.as_ptr() {
                    return Ok(Type::Primitive(BitFlags::empty()));
                }
                Ok(match (&*tv0.read().typ.read(), &*tv1.read().typ.read()) {
                    (None, _) | (_, None) => Type::TVar(tv0.clone()),
                    (Some(t0), Some(t1)) => t0.diff_int(env, hist, t1)?,
                })
            }
            // A bound type variable is replaced by its binding; an unbound
            // one leaves `self` unchanged.
            (Type::TVar(tv), t) => Ok(match &*tv.read().typ.read() {
                Some(tv) => tv.diff_int(env, hist, t)?,
                None => self.clone(),
            }),
            (t, Type::TVar(tv)) => Ok(match &*tv.read().typ.read() {
                Some(tv) => t.diff_int(env, hist, tv)?,
                None => self.clone(),
            }),
            (Type::Array(t0), Type::Array(t1)) => {
                if t0 == t1 {
                    Ok(Type::Primitive(BitFlags::empty()))
                } else {
                    Ok(Type::Array(Arc::new(t0.diff_int(env, hist, t1)?)))
                }
            }
            // Only an array of Any removes `Typ::Array` from a primitive set.
            (Type::Primitive(p), Type::Array(t)) => {
                if &**t == &Type::Any {
                    let mut s = *p;
                    s.remove(Typ::Array);
                    Ok(Type::Primitive(s))
                } else {
                    Ok(Type::Primitive(*p))
                }
            }
            // A primitive set containing `Typ::Array` removes arrays, structs,
            // tuples and variants entirely.
            (
                Type::Array(_) | Type::Struct(_) | Type::Tuple(_) | Type::Variant(_, _),
                Type::Primitive(p),
            ) => {
                if p.contains(Typ::Array) {
                    Ok(Type::Primitive(BitFlags::empty()))
                } else {
                    Ok(self.clone())
                }
            }
            // Subtracting Any leaves nothing; Any minus anything else handled
            // by this point stays Any.
            (_, Type::Any) => Ok(Type::Primitive(BitFlags::empty())),
            (Type::Any, _) => Ok(Type::Any),
            (Type::Primitive(s0), Type::Primitive(s1)) => {
                let mut s = *s0;
                s.remove(*s1);
                Ok(Type::Primitive(s))
            }
            // Only an error of Any removes `Typ::Error` from a primitive set.
            (Type::Primitive(p), Type::Error(e)) => {
                if &**e == &Type::Any {
                    let mut s = *p;
                    s.remove(Typ::Error);
                    Ok(Type::Primitive(s))
                } else {
                    Ok(Type::Primitive(*p))
                }
            }
            (Type::Error(_), Type::Primitive(p)) => {
                if p.contains(Typ::Error) {
                    Ok(Type::Primitive(BitFlags::empty()))
                } else {
                    Ok(self.clone())
                }
            }
            (Type::Error(e0), Type::Error(e1)) => {
                if e0 == e1 {
                    Ok(Type::Primitive(BitFlags::empty()))
                } else {
                    Ok(Type::Error(Arc::new(e0.diff_int(env, hist, e1)?)))
                }
            }
            // NOTE(review): unlike the Array and Error arms there is no
            // equality shortcut here, so equal operands produce a ByRef
            // wrapping the inner difference rather than the empty set.
            (Type::ByRef(t0), Type::ByRef(t1)) => {
                Ok(Type::ByRef(Arc::new(t0.diff_int(env, hist, t1)?)))
            }
            // Every remaining combination: `t` removes nothing from `self`.
            (Type::Fn(_), _)
            | (_, Type::Fn(_))
            | (Type::Array(_), _)
            | (_, Type::Array(_))
            | (Type::Tuple(_), _)
            | (_, Type::Tuple(_))
            | (Type::Struct(_), _)
            | (_, Type::Struct(_))
            | (Type::Variant(_, _), _)
            | (_, Type::Variant(_, _))
            | (Type::ByRef(_), _)
            | (_, Type::ByRef(_))
            | (Type::Error(_), _)
            | (_, Type::Error(_))
            | (Type::Primitive(_), _)
            | (_, Type::Primitive(_))
            | (Type::Bottom, _)
            | (Type::Map { .. }, _) => Ok(self.clone()),
        }
    }
1146
1147    pub fn diff<R: Rt, E: UserEvent>(&self, env: &Env<R, E>, t: &Self) -> Result<Self> {
1148        Ok(self.diff_int(env, &mut LPooled::take(), t)?.normalize())
1149    }
1150
1151    pub fn any() -> Self {
1152        Self::Any
1153    }
1154
1155    pub fn boolean() -> Self {
1156        Self::Primitive(Typ::Bool.into())
1157    }
1158
1159    pub fn number() -> Self {
1160        Self::Primitive(Typ::number())
1161    }
1162
1163    pub fn int() -> Self {
1164        Self::Primitive(Typ::integer())
1165    }
1166
1167    pub fn uint() -> Self {
1168        Self::Primitive(Typ::unsigned_integer())
1169    }
1170
1171    /// alias type variables with the same name to each other
1172    pub fn alias_tvars(&self, known: &mut FxHashMap<ArcStr, TVar>) {
1173        match self {
1174            Type::Bottom | Type::Any | Type::Primitive(_) => (),
1175            Type::Ref { params, .. } => {
1176                for t in params.iter() {
1177                    t.alias_tvars(known);
1178                }
1179            }
1180            Type::Error(t) => t.alias_tvars(known),
1181            Type::Array(t) => t.alias_tvars(known),
1182            Type::Map { key, value } => {
1183                key.alias_tvars(known);
1184                value.alias_tvars(known);
1185            }
1186            Type::ByRef(t) => t.alias_tvars(known),
1187            Type::Tuple(ts) => {
1188                for t in ts.iter() {
1189                    t.alias_tvars(known)
1190                }
1191            }
1192            Type::Struct(ts) => {
1193                for (_, t) in ts.iter() {
1194                    t.alias_tvars(known)
1195                }
1196            }
1197            Type::Variant(_, ts) => {
1198                for t in ts.iter() {
1199                    t.alias_tvars(known)
1200                }
1201            }
1202            Type::TVar(tv) => match known.entry(tv.name.clone()) {
1203                Entry::Occupied(e) => {
1204                    let v = e.get();
1205                    v.freeze();
1206                    tv.alias(v);
1207                }
1208                Entry::Vacant(e) => {
1209                    e.insert(tv.clone());
1210                    ()
1211                }
1212            },
1213            Type::Fn(ft) => ft.alias_tvars(known),
1214            Type::Set(s) => {
1215                for typ in s.iter() {
1216                    typ.alias_tvars(known)
1217                }
1218            }
1219        }
1220    }
1221
1222    pub fn collect_tvars(&self, known: &mut FxHashMap<ArcStr, TVar>) {
1223        match self {
1224            Type::Bottom | Type::Any | Type::Primitive(_) => (),
1225            Type::Ref { params, .. } => {
1226                for t in params.iter() {
1227                    t.collect_tvars(known);
1228                }
1229            }
1230            Type::Error(t) => t.collect_tvars(known),
1231            Type::Array(t) => t.collect_tvars(known),
1232            Type::Map { key, value } => {
1233                key.collect_tvars(known);
1234                value.collect_tvars(known);
1235            }
1236            Type::ByRef(t) => t.collect_tvars(known),
1237            Type::Tuple(ts) => {
1238                for t in ts.iter() {
1239                    t.collect_tvars(known)
1240                }
1241            }
1242            Type::Struct(ts) => {
1243                for (_, t) in ts.iter() {
1244                    t.collect_tvars(known)
1245                }
1246            }
1247            Type::Variant(_, ts) => {
1248                for t in ts.iter() {
1249                    t.collect_tvars(known)
1250                }
1251            }
1252            Type::TVar(tv) => match known.entry(tv.name.clone()) {
1253                Entry::Occupied(_) => (),
1254                Entry::Vacant(e) => {
1255                    e.insert(tv.clone());
1256                    ()
1257                }
1258            },
1259            Type::Fn(ft) => ft.collect_tvars(known),
1260            Type::Set(s) => {
1261                for typ in s.iter() {
1262                    typ.collect_tvars(known)
1263                }
1264            }
1265        }
1266    }
1267
1268    pub fn check_tvars_declared(&self, declared: &FxHashSet<ArcStr>) -> Result<()> {
1269        match self {
1270            Type::Bottom | Type::Any | Type::Primitive(_) => Ok(()),
1271            Type::Ref { params, .. } => {
1272                params.iter().try_for_each(|t| t.check_tvars_declared(declared))
1273            }
1274            Type::Error(t) => t.check_tvars_declared(declared),
1275            Type::Array(t) => t.check_tvars_declared(declared),
1276            Type::Map { key, value } => {
1277                key.check_tvars_declared(declared)?;
1278                value.check_tvars_declared(declared)
1279            }
1280            Type::ByRef(t) => t.check_tvars_declared(declared),
1281            Type::Tuple(ts) => {
1282                ts.iter().try_for_each(|t| t.check_tvars_declared(declared))
1283            }
1284            Type::Struct(ts) => {
1285                ts.iter().try_for_each(|(_, t)| t.check_tvars_declared(declared))
1286            }
1287            Type::Variant(_, ts) => {
1288                ts.iter().try_for_each(|t| t.check_tvars_declared(declared))
1289            }
1290            Type::TVar(tv) => {
1291                if !declared.contains(&tv.name) {
1292                    bail!("undeclared type variable '{}'", tv.name)
1293                } else {
1294                    Ok(())
1295                }
1296            }
1297            Type::Set(s) => s.iter().try_for_each(|t| t.check_tvars_declared(declared)),
1298            Type::Fn(_) => Ok(()),
1299        }
1300    }
1301
1302    pub fn has_unbound(&self) -> bool {
1303        match self {
1304            Type::Bottom | Type::Any | Type::Primitive(_) => false,
1305            Type::Ref { .. } => false,
1306            Type::Error(e) => e.has_unbound(),
1307            Type::Array(t0) => t0.has_unbound(),
1308            Type::Map { key, value } => key.has_unbound() || value.has_unbound(),
1309            Type::ByRef(t0) => t0.has_unbound(),
1310            Type::Tuple(ts) => ts.iter().any(|t| t.has_unbound()),
1311            Type::Struct(ts) => ts.iter().any(|(_, t)| t.has_unbound()),
1312            Type::Variant(_, ts) => ts.iter().any(|t| t.has_unbound()),
1313            Type::TVar(tv) => tv.read().typ.read().is_some(),
1314            Type::Set(s) => s.iter().any(|t| t.has_unbound()),
1315            Type::Fn(ft) => ft.has_unbound(),
1316        }
1317    }
1318
1319    /// bind all unbound type variables to the specified type
1320    pub fn bind_as(&self, t: &Self) {
1321        match self {
1322            Type::Bottom | Type::Any | Type::Primitive(_) => (),
1323            Type::Ref { .. } => (),
1324            Type::Error(t0) => t0.bind_as(t),
1325            Type::Array(t0) => t0.bind_as(t),
1326            Type::Map { key, value } => {
1327                key.bind_as(t);
1328                value.bind_as(t);
1329            }
1330            Type::ByRef(t0) => t0.bind_as(t),
1331            Type::Tuple(ts) => {
1332                for elt in ts.iter() {
1333                    elt.bind_as(t)
1334                }
1335            }
1336            Type::Struct(ts) => {
1337                for (_, elt) in ts.iter() {
1338                    elt.bind_as(t)
1339                }
1340            }
1341            Type::Variant(_, ts) => {
1342                for elt in ts.iter() {
1343                    elt.bind_as(t)
1344                }
1345            }
1346            Type::TVar(tv) => {
1347                let tv = tv.read();
1348                let mut tv = tv.typ.write();
1349                if tv.is_none() {
1350                    *tv = Some(t.clone());
1351                }
1352            }
1353            Type::Set(s) => {
1354                for elt in s.iter() {
1355                    elt.bind_as(t)
1356                }
1357            }
1358            Type::Fn(ft) => ft.bind_as(t),
1359        }
1360    }
1361
1362    /// return a copy of self with all type variables unbound and
1363    /// unaliased. self will not be modified
1364    pub fn reset_tvars(&self) -> Type {
1365        match self {
1366            Type::Bottom => Type::Bottom,
1367            Type::Any => Type::Any,
1368            Type::Primitive(p) => Type::Primitive(*p),
1369            Type::Ref { scope, name, params } => Type::Ref {
1370                scope: scope.clone(),
1371                name: name.clone(),
1372                params: Arc::from_iter(params.iter().map(|t| t.reset_tvars())),
1373            },
1374            Type::Error(t0) => Type::Error(Arc::new(t0.reset_tvars())),
1375            Type::Array(t0) => Type::Array(Arc::new(t0.reset_tvars())),
1376            Type::Map { key, value } => {
1377                let key = Arc::new(key.reset_tvars());
1378                let value = Arc::new(value.reset_tvars());
1379                Type::Map { key, value }
1380            }
1381            Type::ByRef(t0) => Type::ByRef(Arc::new(t0.reset_tvars())),
1382            Type::Tuple(ts) => {
1383                Type::Tuple(Arc::from_iter(ts.iter().map(|t| t.reset_tvars())))
1384            }
1385            Type::Struct(ts) => Type::Struct(Arc::from_iter(
1386                ts.iter().map(|(n, t)| (n.clone(), t.reset_tvars())),
1387            )),
1388            Type::Variant(tag, ts) => Type::Variant(
1389                tag.clone(),
1390                Arc::from_iter(ts.iter().map(|t| t.reset_tvars())),
1391            ),
1392            Type::TVar(tv) => Type::TVar(TVar::empty_named(tv.name.clone())),
1393            Type::Set(s) => Type::Set(Arc::from_iter(s.iter().map(|t| t.reset_tvars()))),
1394            Type::Fn(fntyp) => Type::Fn(Arc::new(fntyp.reset_tvars())),
1395        }
1396    }
1397
1398    /// return a copy of self with every TVar named in known replaced
1399    /// with the corresponding type
1400    pub fn replace_tvars(&self, known: &FxHashMap<ArcStr, Self>) -> Type {
1401        match self {
1402            Type::TVar(tv) => match known.get(&tv.name) {
1403                Some(t) => t.clone(),
1404                None => Type::TVar(tv.clone()),
1405            },
1406            Type::Bottom => Type::Bottom,
1407            Type::Any => Type::Any,
1408            Type::Primitive(p) => Type::Primitive(*p),
1409            Type::Ref { scope, name, params } => Type::Ref {
1410                scope: scope.clone(),
1411                name: name.clone(),
1412                params: Arc::from_iter(params.iter().map(|t| t.replace_tvars(known))),
1413            },
1414            Type::Error(t0) => Type::Error(Arc::new(t0.replace_tvars(known))),
1415            Type::Array(t0) => Type::Array(Arc::new(t0.replace_tvars(known))),
1416            Type::Map { key, value } => {
1417                let key = Arc::new(key.replace_tvars(known));
1418                let value = Arc::new(value.replace_tvars(known));
1419                Type::Map { key, value }
1420            }
1421            Type::ByRef(t0) => Type::ByRef(Arc::new(t0.replace_tvars(known))),
1422            Type::Tuple(ts) => {
1423                Type::Tuple(Arc::from_iter(ts.iter().map(|t| t.replace_tvars(known))))
1424            }
1425            Type::Struct(ts) => Type::Struct(Arc::from_iter(
1426                ts.iter().map(|(n, t)| (n.clone(), t.replace_tvars(known))),
1427            )),
1428            Type::Variant(tag, ts) => Type::Variant(
1429                tag.clone(),
1430                Arc::from_iter(ts.iter().map(|t| t.replace_tvars(known))),
1431            ),
1432            Type::Set(s) => {
1433                Type::Set(Arc::from_iter(s.iter().map(|t| t.replace_tvars(known))))
1434            }
1435            Type::Fn(fntyp) => Type::Fn(Arc::new(fntyp.replace_tvars(known))),
1436        }
1437    }
1438
1439    fn strip_error_int<R: Rt, E: UserEvent>(
1440        &self,
1441        env: &Env<R, E>,
1442        hist: &mut FxHashSet<usize>,
1443    ) -> Option<Type> {
1444        match self {
1445            Type::Error(t) => match t.strip_error_int(env, hist) {
1446                Some(t) => Some(t),
1447                None => Some((**t).clone()),
1448            },
1449            Type::TVar(tv) => {
1450                tv.read().typ.read().as_ref().and_then(|t| t.strip_error_int(env, hist))
1451            }
1452            Type::Primitive(p) => {
1453                if *p == BitFlags::from(Typ::Error) {
1454                    Some(Type::Any)
1455                } else {
1456                    None
1457                }
1458            }
1459            Type::Ref { .. } => {
1460                let t = self.lookup_ref(env).ok()?;
1461                let addr = t as *const Type as usize;
1462                if hist.insert(addr) {
1463                    t.strip_error_int(env, hist)
1464                } else {
1465                    None
1466                }
1467            }
1468            Type::Set(s) => {
1469                let r = Self::flatten_set(
1470                    s.iter().filter_map(|t| t.strip_error_int(env, hist)),
1471                );
1472                match r {
1473                    Type::Primitive(p) if p.is_empty() => None,
1474                    t => Some(t),
1475                }
1476            }
1477            Type::Array(_)
1478            | Type::Map { .. }
1479            | Type::ByRef(_)
1480            | Type::Tuple(_)
1481            | Type::Struct(_)
1482            | Type::Variant(_, _)
1483            | Type::Fn(_)
1484            | Type::Any
1485            | Type::Bottom => None,
1486        }
1487    }
1488
1489    /// remove the outer error type and return the inner payload, fail if self
1490    /// isn't an error or contains non error types
1491    pub fn strip_error<R: Rt, E: UserEvent>(&self, env: &Env<R, E>) -> Option<Self> {
1492        self.strip_error_int(env, &mut LPooled::take())
1493    }
1494
1495    /// Unbind any bound tvars, but do not unalias them.
1496    pub(crate) fn unbind_tvars(&self) {
1497        match self {
1498            Type::Bottom | Type::Any | Type::Primitive(_) | Type::Ref { .. } => (),
1499            Type::Error(t0) => t0.unbind_tvars(),
1500            Type::Array(t0) => t0.unbind_tvars(),
1501            Type::Map { key, value } => {
1502                key.unbind_tvars();
1503                value.unbind_tvars();
1504            }
1505            Type::ByRef(t0) => t0.unbind_tvars(),
1506            Type::Tuple(ts) | Type::Variant(_, ts) | Type::Set(ts) => {
1507                for t in ts.iter() {
1508                    t.unbind_tvars()
1509                }
1510            }
1511            Type::Struct(ts) => {
1512                for (_, t) in ts.iter() {
1513                    t.unbind_tvars()
1514                }
1515            }
1516            Type::TVar(tv) => tv.unbind(),
1517            Type::Fn(fntyp) => fntyp.unbind_tvars(),
1518        }
1519    }
1520
1521    fn check_cast_int<R: Rt, E: UserEvent>(
1522        &self,
1523        env: &Env<R, E>,
1524        hist: &mut FxHashSet<usize>,
1525    ) -> Result<()> {
1526        match self {
1527            Type::Primitive(_) | Type::Any => Ok(()),
1528            Type::Fn(_) => bail!("can't cast a value to a function"),
1529            Type::Bottom => bail!("can't cast a value to bottom"),
1530            Type::Set(s) => Ok(for t in s.iter() {
1531                t.check_cast_int(env, hist)?
1532            }),
1533            Type::TVar(tv) => match &*tv.read().typ.read() {
1534                Some(t) => t.check_cast_int(env, hist),
1535                None => bail!("can't cast a value to a free type variable"),
1536            },
1537            Type::Error(e) => e.check_cast_int(env, hist),
1538            Type::Array(et) => et.check_cast_int(env, hist),
1539            Type::Map { key, value } => {
1540                key.check_cast_int(env, hist)?;
1541                value.check_cast_int(env, hist)
1542            }
1543            Type::ByRef(_) => bail!("can't cast a reference"),
1544            Type::Tuple(ts) => Ok(for t in ts.iter() {
1545                t.check_cast_int(env, hist)?
1546            }),
1547            Type::Struct(ts) => Ok(for (_, t) in ts.iter() {
1548                t.check_cast_int(env, hist)?
1549            }),
1550            Type::Variant(_, ts) => Ok(for t in ts.iter() {
1551                t.check_cast_int(env, hist)?
1552            }),
1553            Type::Ref { .. } => {
1554                let t = self.lookup_ref(env)?;
1555                let t_addr = (t as *const Type).addr();
1556                if hist.contains(&t_addr) {
1557                    Ok(())
1558                } else {
1559                    hist.insert(t_addr);
1560                    t.check_cast_int(env, hist)
1561                }
1562            }
1563        }
1564    }
1565
1566    pub fn check_cast<R: Rt, E: UserEvent>(&self, env: &Env<R, E>) -> Result<()> {
1567        self.check_cast_int(env, &mut LPooled::take())
1568    }
1569
    /// Recursive worker behind `cast_value`: try to convert `v` into a value
    /// of this type.
    ///
    /// If `v` already passes `is_a_int` for this type it is returned
    /// unchanged. Otherwise the conversion depends on the type:
    /// - `Bottom`, `Fn` and `ByRef` can never be cast to.
    /// - `Primitive` tries each primitive type in the set and keeps the first
    ///   cast that succeeds.
    /// - `Error` casts the payload (unwrapping `v` first if it is already an
    ///   error) and wraps the result in an error value.
    /// - `Array` casts each element; a non array value is cast to the element
    ///   type and wrapped in a one element array.
    /// - `Map` accepts a map, or an array of two element arrays.
    /// - `Tuple`, `Struct` and `Variant` accept arrays of the right shape.
    /// - `Ref` is resolved through `env` and the cast retried on the result.
    /// - `Set` returns the first member type the value can be cast to.
    /// - `TVar` casts to the bound type, or returns `v` as is when unbound.
    ///
    /// `hist` is handed to `is_a_int` and to every recursive call.
    fn cast_value_int<R: Rt, E: UserEvent>(
        &self,
        env: &Env<R, E>,
        hist: &mut FxHashSet<(usize, usize)>,
        v: Value,
    ) -> Result<Value> {
        // nothing to do if the value already has the right shape
        if self.is_a_int(env, hist, &v) {
            return Ok(v);
        }
        match self {
            Type::Bottom => bail!("can't cast {v} to Bottom"),
            Type::Fn(_) => bail!("can't cast {v} to a function"),
            Type::ByRef(_) => bail!("can't cast {v} to a reference"),
            // first primitive type in the set that `v` can be cast to wins
            Type::Primitive(s) => s
                .iter()
                .find_map(|t| v.clone().cast(t))
                .ok_or_else(|| anyhow!("can't cast {v} to {self}")),
            Type::Any => Ok(v),
            Type::Error(e) => {
                // cast the payload, not the error wrapper itself
                let v = match v {
                    Value::Error(v) => (*v).clone(),
                    v => v,
                };
                Ok(Value::Error(Arc::new(e.cast_value_int(env, hist, v)?)))
            }
            Type::Array(et) => match v {
                Value::Array(elts) => {
                    let mut va = elts
                        .iter()
                        .map(|el| et.cast_value_int(env, hist, el.clone()))
                        .collect::<Result<LPooled<Vec<Value>>>>()?;
                    Ok(Value::Array(ValArray::from_iter_exact(va.drain(..))))
                }
                // a lone value becomes a single element array
                v => Ok(Value::Array([et.cast_value_int(env, hist, v)?].into())),
            },
            Type::Map { key, value } => match v {
                Value::Map(m) => {
                    let mut m = m
                        .into_iter()
                        .map(|(k, v)| {
                            Ok((
                                key.cast_value_int(env, hist, k.clone())?,
                                value.cast_value_int(env, hist, v.clone())?,
                            ))
                        })
                        .collect::<Result<LPooled<Vec<(Value, Value)>>>>()?;
                    Ok(Value::Map(Map::from_iter(m.drain(..))))
                }
                // an array of [key, value] pairs is also accepted
                Value::Array(a) => {
                    let mut m = a
                        .iter()
                        .map(|a| match a {
                            Value::Array(a) if a.len() == 2 => Ok((
                                key.cast_value_int(env, hist, a[0].clone())?,
                                value.cast_value_int(env, hist, a[1].clone())?,
                            )),
                            _ => bail!("expected an array of pairs"),
                        })
                        .collect::<Result<LPooled<Vec<(Value, Value)>>>>()?;
                    Ok(Value::Map(Map::from_iter(m.drain(..))))
                }
                _ => bail!("can't cast {v} to {self}"),
            },
            Type::Tuple(ts) => match v {
                Value::Array(elts) => {
                    if elts.len() != ts.len() {
                        bail!("tuple size mismatch {self} with {}", Value::Array(elts))
                    }
                    // cast position by position
                    let a = ts
                        .iter()
                        .zip(elts.iter())
                        .map(|(t, el)| t.cast_value_int(env, hist, el.clone()))
                        .collect::<Result<SmallVec<[Value; 8]>>>()?;
                    Ok(Value::Array(ValArray::from_iter_exact(a.into_iter())))
                }
                v => bail!("can't cast {v} to {self}"),
            },
            Type::Struct(ts) => match v {
                Value::Array(elts) => {
                    if elts.len() != ts.len() {
                        bail!("struct size mismatch {self} with {}", Value::Array(elts))
                    }
                    // every element must be a [string, value] pair; the
                    // unreachable!() arms below rely on this check
                    let is_pairs = elts.iter().all(|v| match v {
                        Value::Array(a) if a.len() == 2 => match &a[0] {
                            Value::String(_) => true,
                            _ => false,
                        },
                        _ => false,
                    });
                    if !is_pairs {
                        bail!("expected array of pairs, got {}", Value::Array(elts))
                    }
                    // order the supplied pairs by field name so they can be
                    // compared pairwise with the type's fields (presumably
                    // stored sorted by name -- TODO confirm)
                    let mut elts_s: SmallVec<[&Value; 16]> = elts.iter().collect();
                    elts_s.sort_by_key(|v| match v {
                        Value::Array(a) => match &a[0] {
                            Value::String(s) => s,
                            _ => unreachable!(),
                        },
                        _ => unreachable!(),
                    });
                    // each pair must carry the expected field name, and the
                    // field type must contain the primitive type of the
                    // supplied value
                    let keys_ok = ts.iter().zip(elts_s.iter()).fold(
                        Ok(true),
                        |acc: Result<_>, ((fname, t), v)| {
                            let kok = acc?;
                            let (name, v) = match v {
                                Value::Array(a) => match (&a[0], &a[1]) {
                                    (Value::String(n), v) => (n, v),
                                    _ => unreachable!(),
                                },
                                _ => unreachable!(),
                            };
                            Ok(kok
                                && name == fname
                                && t.contains(env, &Type::Primitive(Typ::get(v).into()))?)
                        },
                    )?;
                    if keys_ok {
                        // rebuild the struct in sorted order, casting each
                        // field value to its declared type
                        let elts = ts
                            .iter()
                            .zip(elts_s.iter())
                            .map(|((n, t), v)| match v {
                                Value::Array(a) => {
                                    let a = [
                                        Value::String(n.clone()),
                                        t.cast_value_int(env, hist, a[1].clone())?,
                                    ];
                                    Ok(Value::Array(ValArray::from_iter_exact(
                                        a.into_iter(),
                                    )))
                                }
                                _ => unreachable!(),
                            })
                            .collect::<Result<SmallVec<[Value; 8]>>>()?;
                        Ok(Value::Array(ValArray::from_iter_exact(elts.into_iter())))
                    } else {
                        // release the borrows of elts before moving it into
                        // the error message
                        drop(elts_s);
                        bail!("struct fields mismatch {self}, {}", Value::Array(elts))
                    }
                }
                v => bail!("can't cast {v} to {self}"),
            },
            // a variant without arguments is just its tag as a string
            Type::Variant(tag, ts) if ts.len() == 0 => match &v {
                Value::String(s) if s == tag => Ok(v),
                _ => bail!("variant tag mismatch expected {tag} got {v}"),
            },
            Type::Variant(tag, ts) => match &v {
                Value::Array(elts) => {
                    if ts.len() + 1 == elts.len() {
                        // tag is present in front of the arguments
                        match &elts[0] {
                            Value::String(s) if s == tag => (),
                            v => bail!("variant tag mismatch expected {tag} got {v}"),
                        }
                        let a = iter::once(&Type::Primitive(Typ::String.into()))
                            .chain(ts.iter())
                            .zip(elts.iter())
                            .map(|(t, v)| t.cast_value_int(env, hist, v.clone()))
                            .collect::<Result<SmallVec<[Value; 8]>>>()?;
                        Ok(Value::Array(ValArray::from_iter_exact(a.into_iter())))
                    } else if ts.len() == elts.len() {
                        // arguments only: cast them and prepend the tag
                        let mut a = ts
                            .iter()
                            .zip(elts.iter())
                            .map(|(t, v)| t.cast_value_int(env, hist, v.clone()))
                            .collect::<Result<SmallVec<[Value; 8]>>>()?;
                        a.insert(0, Value::String(tag.clone()));
                        Ok(Value::Array(ValArray::from_iter_exact(a.into_iter())))
                    } else {
                        bail!("variant length mismatch")
                    }
                }
                v => bail!("can't cast {v} to {self}"),
            },
            Type::Ref { .. } => self.lookup_ref(env)?.cast_value_int(env, hist, v),
            // first member of the set that accepts the value wins; the
            // individual failures are discarded
            Type::Set(ts) => ts
                .iter()
                .find_map(|t| t.cast_value_int(env, hist, v.clone()).ok())
                .ok_or_else(|| anyhow!("can't cast {v} to {self}")),
            Type::TVar(tv) => match &*tv.read().typ.read() {
                Some(t) => t.cast_value_int(env, hist, v.clone()),
                // an unbound type variable places no constraint on the value
                None => Ok(v),
            },
        }
    }
1753
1754    pub fn cast_value<R: Rt, E: UserEvent>(&self, env: &Env<R, E>, v: Value) -> Value {
1755        match self.cast_value_int(env, &mut LPooled::take(), v) {
1756            Ok(v) => v,
1757            Err(e) => errf!(CAST_ERR_TAG, "{e:?}"),
1758        }
1759    }
1760
1761    fn is_a_int<R: Rt, E: UserEvent>(
1762        &self,
1763        env: &Env<R, E>,
1764        hist: &mut FxHashSet<(usize, usize)>,
1765        v: &Value,
1766    ) -> bool {
1767        match self {
1768            Type::Ref { .. } => match self.lookup_ref(env) {
1769                Err(_) => false,
1770                Ok(t) => {
1771                    let t_addr = (t as *const Type).addr();
1772                    let v_addr = (v as *const Value).addr();
1773                    !hist.contains(&(t_addr, v_addr)) && {
1774                        hist.insert((t_addr, v_addr));
1775                        t.is_a_int(env, hist, v)
1776                    }
1777                }
1778            },
1779            Type::Primitive(t) => t.contains(Typ::get(&v)),
1780            Type::Any => true,
1781            Type::Array(et) => match v {
1782                Value::Array(a) => a.iter().all(|v| et.is_a_int(env, hist, v)),
1783                _ => false,
1784            },
1785            Type::Map { key, value } => match v {
1786                Value::Map(m) => m.into_iter().all(|(k, v)| {
1787                    key.is_a_int(env, hist, k) && value.is_a_int(env, hist, v)
1788                }),
1789                _ => false,
1790            },
1791            Type::Error(e) => match v {
1792                Value::Error(v) => e.is_a_int(env, hist, v),
1793                _ => false,
1794            },
1795            Type::ByRef(_) => matches!(v, Value::U64(_) | Value::V64(_)),
1796            Type::Tuple(ts) => match v {
1797                Value::Array(elts) => {
1798                    elts.len() == ts.len()
1799                        && ts
1800                            .iter()
1801                            .zip(elts.iter())
1802                            .all(|(t, v)| t.is_a_int(env, hist, v))
1803                }
1804                _ => false,
1805            },
1806            Type::Struct(ts) => match v {
1807                Value::Array(elts) => {
1808                    elts.len() == ts.len()
1809                        && ts.iter().zip(elts.iter()).all(|((n, t), v)| match v {
1810                            Value::Array(a) if a.len() == 2 => match &a[..] {
1811                                [Value::String(key), v] => {
1812                                    n == key && t.is_a_int(env, hist, v)
1813                                }
1814                                _ => false,
1815                            },
1816                            _ => false,
1817                        })
1818                }
1819                _ => false,
1820            },
1821            Type::Variant(tag, ts) if ts.len() == 0 => match &v {
1822                Value::String(s) => s == tag,
1823                _ => false,
1824            },
1825            Type::Variant(tag, ts) => match &v {
1826                Value::Array(elts) => {
1827                    ts.len() + 1 == elts.len()
1828                        && match &elts[0] {
1829                            Value::String(s) => s == tag,
1830                            _ => false,
1831                        }
1832                        && ts
1833                            .iter()
1834                            .zip(elts[1..].iter())
1835                            .all(|(t, v)| t.is_a_int(env, hist, v))
1836                }
1837                _ => false,
1838            },
1839            Type::TVar(tv) => match &*tv.read().typ.read() {
1840                None => true,
1841                Some(t) => t.is_a_int(env, hist, v),
1842            },
1843            Type::Fn(_) => match v {
1844                Value::U64(_) => true,
1845                _ => false,
1846            },
1847            Type::Bottom => true,
1848            Type::Set(ts) => ts.iter().any(|t| t.is_a_int(env, hist, v)),
1849        }
1850    }
1851
1852    /// return true if v is structurally compatible with the type
1853    pub fn is_a<R: Rt, E: UserEvent>(&self, env: &Env<R, E>, v: &Value) -> bool {
1854        self.is_a_int(env, &mut LPooled::take(), v)
1855    }
1856
1857    pub fn is_bot(&self) -> bool {
1858        match self {
1859            Type::Bottom => true,
1860            Type::Any
1861            | Type::TVar(_)
1862            | Type::Primitive(_)
1863            | Type::Ref { .. }
1864            | Type::Fn(_)
1865            | Type::Error(_)
1866            | Type::Array(_)
1867            | Type::ByRef(_)
1868            | Type::Tuple(_)
1869            | Type::Struct(_)
1870            | Type::Variant(_, _)
1871            | Type::Set(_)
1872            | Type::Map { .. } => false,
1873        }
1874    }
1875
1876    pub fn with_deref<R, F: FnOnce(Option<&Self>) -> R>(&self, f: F) -> R {
1877        match self {
1878            Self::Bottom
1879            | Self::Any
1880            | Self::Primitive(_)
1881            | Self::Fn(_)
1882            | Self::Set(_)
1883            | Self::Error(_)
1884            | Self::Array(_)
1885            | Self::ByRef(_)
1886            | Self::Tuple(_)
1887            | Self::Struct(_)
1888            | Self::Variant(_, _)
1889            | Self::Ref { .. }
1890            | Self::Map { .. } => f(Some(self)),
1891            Self::TVar(tv) => match tv.read().typ.read().as_ref() {
1892                Some(t) => t.with_deref(f),
1893                None => f(None),
1894            },
1895        }
1896    }
1897
1898    pub(crate) fn flatten_set(set: impl IntoIterator<Item = Self>) -> Self {
1899        let init: Box<dyn Iterator<Item = Self>> = Box::new(set.into_iter());
1900        let mut iters: LPooled<Vec<Box<dyn Iterator<Item = Self>>>> =
1901            LPooled::from_iter([init]);
1902        let mut acc: LPooled<Vec<Self>> = LPooled::take();
1903        loop {
1904            match iters.last_mut() {
1905                None => break,
1906                Some(iter) => match iter.next() {
1907                    None => {
1908                        iters.pop();
1909                    }
1910                    Some(Type::Set(s)) => {
1911                        let v: SmallVec<[Self; 16]> =
1912                            s.iter().map(|t| t.clone()).collect();
1913                        iters.push(Box::new(v.into_iter()))
1914                    }
1915                    Some(Type::Any) => return Type::Any,
1916                    Some(t) => {
1917                        acc.push(t);
1918                        let mut i = 0;
1919                        let mut j = 0;
1920                        while i < acc.len() {
1921                            while j < acc.len() {
1922                                if j == i {
1923                                    j += 1;
1924                                    continue;
1925                                }
1926                                match acc[i].merge(&acc[j]) {
1927                                    None => j += 1,
1928                                    Some(t) => {
1929                                        acc[i] = t;
1930                                        acc.remove(j);
1931                                        i = 0;
1932                                        j = 0;
1933                                    }
1934                                }
1935                            }
1936                            i += 1;
1937                            j = 0;
1938                        }
1939                    }
1940                },
1941            }
1942        }
1943        acc.sort();
1944        match &**acc {
1945            [] => Type::Primitive(BitFlags::empty()),
1946            [t] => t.clone(),
1947            _ => Type::Set(Arc::from_iter(acc.drain(..))),
1948        }
1949    }
1950
1951    pub(crate) fn normalize(&self) -> Self {
1952        match self {
1953            Type::Bottom | Type::Any | Type::Primitive(_) => self.clone(),
1954            Type::Ref { scope, name, params } => {
1955                let params = Arc::from_iter(params.iter().map(|t| t.normalize()));
1956                Type::Ref { scope: scope.clone(), name: name.clone(), params }
1957            }
1958            Type::TVar(tv) => Type::TVar(tv.normalize()),
1959            Type::Set(s) => Self::flatten_set(s.iter().map(|t| t.normalize())),
1960            Type::Error(t) => Type::Error(Arc::new(t.normalize())),
1961            Type::Array(t) => Type::Array(Arc::new(t.normalize())),
1962            Type::Map { key, value } => {
1963                let key = Arc::new(key.normalize());
1964                let value = Arc::new(value.normalize());
1965                Type::Map { key, value }
1966            }
1967            Type::ByRef(t) => Type::ByRef(Arc::new(t.normalize())),
1968            Type::Tuple(t) => {
1969                Type::Tuple(Arc::from_iter(t.iter().map(|t| t.normalize())))
1970            }
1971            Type::Struct(t) => Type::Struct(Arc::from_iter(
1972                t.iter().map(|(n, t)| (n.clone(), t.normalize())),
1973            )),
1974            Type::Variant(tag, t) => Type::Variant(
1975                tag.clone(),
1976                Arc::from_iter(t.iter().map(|t| t.normalize())),
1977            ),
1978            Type::Fn(ft) => Type::Fn(Arc::new(ft.normalize())),
1979        }
1980    }
1981
1982    fn merge(&self, t: &Self) -> Option<Self> {
1983        macro_rules! flatten {
1984            ($t:expr) => {
1985                match $t {
1986                    Type::Set(et) => Self::flatten_set(et.iter().cloned()),
1987                    t => t.clone(),
1988                }
1989            };
1990        }
1991        match (self, t) {
1992            (
1993                Type::Ref { scope: s0, name: r0, params: a0 },
1994                Type::Ref { scope: s1, name: r1, params: a1 },
1995            ) => {
1996                if s0 == s1 && r0 == r1 && a0 == a1 {
1997                    Some(Type::Ref {
1998                        scope: s0.clone(),
1999                        name: r0.clone(),
2000                        params: a0.clone(),
2001                    })
2002                } else {
2003                    None
2004                }
2005            }
2006            (Type::Ref { .. }, _) | (_, Type::Ref { .. }) => None,
2007            (Type::Bottom, t) | (t, Type::Bottom) => Some(t.clone()),
2008            (Type::Any, _) | (_, Type::Any) => Some(Type::Any),
2009            (Type::Primitive(s0), Type::Primitive(s1)) => {
2010                Some(Type::Primitive(*s0 | *s1))
2011            }
2012            (Type::Primitive(p), t) | (t, Type::Primitive(p)) if p.is_empty() => {
2013                Some(t.clone())
2014            }
2015            (Type::Fn(f0), Type::Fn(f1)) => {
2016                if f0 == f1 {
2017                    Some(Type::Fn(f0.clone()))
2018                } else {
2019                    None
2020                }
2021            }
2022            (Type::Array(t0), Type::Array(t1)) => {
2023                if flatten!(&**t0) == flatten!(&**t1) {
2024                    Some(Type::Array(t0.clone()))
2025                } else {
2026                    None
2027                }
2028            }
2029            (Type::Primitive(p), Type::Array(_))
2030            | (Type::Array(_), Type::Primitive(p)) => {
2031                if p.contains(Typ::Array) {
2032                    Some(Type::Primitive(*p))
2033                } else {
2034                    None
2035                }
2036            }
2037            (Type::Map { key: k0, value: v0 }, Type::Map { key: k1, value: v1 }) => {
2038                if flatten!(&**k0) == flatten!(&**k1)
2039                    && flatten!(&**v0) == flatten!(&**v1)
2040                {
2041                    Some(Type::Map { key: k0.clone(), value: v0.clone() })
2042                } else {
2043                    None
2044                }
2045            }
2046            (Type::Error(t0), Type::Error(t1)) => {
2047                if flatten!(&**t0) == flatten!(&**t1) {
2048                    Some(Type::Error(t0.clone()))
2049                } else {
2050                    None
2051                }
2052            }
2053            (Type::ByRef(t0), Type::ByRef(t1)) => {
2054                t0.merge(t1).map(|t| Type::ByRef(Arc::new(t)))
2055            }
2056            (Type::Set(s0), Type::Set(s1)) => {
2057                Some(Self::flatten_set(s0.iter().cloned().chain(s1.iter().cloned())))
2058            }
2059            (Type::Set(s), Type::Primitive(p)) | (Type::Primitive(p), Type::Set(s))
2060                if p.is_empty() =>
2061            {
2062                Some(Type::Set(s.clone()))
2063            }
2064            (Type::Set(s), t) | (t, Type::Set(s)) => {
2065                Some(Self::flatten_set(s.iter().cloned().chain(iter::once(t.clone()))))
2066            }
2067            (Type::Tuple(t0), Type::Tuple(t1)) => {
2068                if t0.len() == t1.len() {
2069                    let t = t0
2070                        .iter()
2071                        .zip(t1.iter())
2072                        .map(|(t0, t1)| t0.merge(t1))
2073                        .collect::<Option<SmallVec<[Type; 8]>>>()?;
2074                    Some(Type::Tuple(Arc::from_iter(t)))
2075                } else {
2076                    None
2077                }
2078            }
2079            (Type::Variant(tag0, t0), Type::Variant(tag1, t1)) => {
2080                if tag0 == tag1 && t0.len() == t1.len() {
2081                    let t = t0
2082                        .iter()
2083                        .zip(t1.iter())
2084                        .map(|(t0, t1)| t0.merge(t1))
2085                        .collect::<Option<SmallVec<[Type; 8]>>>()?;
2086                    Some(Type::Variant(tag0.clone(), Arc::from_iter(t)))
2087                } else {
2088                    None
2089                }
2090            }
2091            (Type::Struct(t0), Type::Struct(t1)) => {
2092                if t0.len() == t1.len() {
2093                    let t = t0
2094                        .iter()
2095                        .zip(t1.iter())
2096                        .map(|((n0, t0), (n1, t1))| {
2097                            if n0 != n1 {
2098                                None
2099                            } else {
2100                                t0.merge(t1).map(|t| (n0.clone(), t))
2101                            }
2102                        })
2103                        .collect::<Option<SmallVec<[(ArcStr, Type); 8]>>>()?;
2104                    Some(Type::Struct(Arc::from_iter(t)))
2105                } else {
2106                    None
2107                }
2108            }
2109            (Type::TVar(tv0), Type::TVar(tv1)) if tv0.name == tv1.name && tv0 == tv1 => {
2110                Some(Type::TVar(tv0.clone()))
2111            }
2112            (Type::TVar(tv), t) => {
2113                tv.read().typ.read().as_ref().and_then(|tv| tv.merge(t))
2114            }
2115            (t, Type::TVar(tv)) => {
2116                tv.read().typ.read().as_ref().and_then(|tv| t.merge(tv))
2117            }
2118            (Type::ByRef(_), _)
2119            | (_, Type::ByRef(_))
2120            | (Type::Array(_), _)
2121            | (_, Type::Array(_))
2122            | (_, Type::Map { .. })
2123            | (Type::Map { .. }, _)
2124            | (Type::Tuple(_), _)
2125            | (_, Type::Tuple(_))
2126            | (Type::Struct(_), _)
2127            | (_, Type::Struct(_))
2128            | (Type::Variant(_, _), _)
2129            | (_, Type::Variant(_, _))
2130            | (_, Type::Fn(_))
2131            | (Type::Fn(_), _)
2132            | (Type::Error(_), _)
2133            | (_, Type::Error(_)) => None,
2134        }
2135    }
2136
2137    pub fn scope_refs(&self, scope: &ModPath) -> Type {
2138        match self {
2139            Type::Bottom => Type::Bottom,
2140            Type::Any => Type::Any,
2141            Type::Primitive(s) => Type::Primitive(*s),
2142            Type::Error(t0) => Type::Error(Arc::new(t0.scope_refs(scope))),
2143            Type::Array(t0) => Type::Array(Arc::new(t0.scope_refs(scope))),
2144            Type::Map { key, value } => {
2145                let key = Arc::new(key.scope_refs(scope));
2146                let value = Arc::new(value.scope_refs(scope));
2147                Type::Map { key, value }
2148            }
2149            Type::ByRef(t) => Type::ByRef(Arc::new(t.scope_refs(scope))),
2150            Type::Tuple(ts) => {
2151                let i = ts.iter().map(|t| t.scope_refs(scope));
2152                Type::Tuple(Arc::from_iter(i))
2153            }
2154            Type::Variant(tag, ts) => {
2155                let i = ts.iter().map(|t| t.scope_refs(scope));
2156                Type::Variant(tag.clone(), Arc::from_iter(i))
2157            }
2158            Type::Struct(ts) => {
2159                let i = ts.iter().map(|(n, t)| (n.clone(), t.scope_refs(scope)));
2160                Type::Struct(Arc::from_iter(i))
2161            }
2162            Type::TVar(tv) => match tv.read().typ.read().as_ref() {
2163                None => Type::TVar(TVar::empty_named(tv.name.clone())),
2164                Some(typ) => {
2165                    let typ = typ.scope_refs(scope);
2166                    Type::TVar(TVar::named(tv.name.clone(), typ))
2167                }
2168            },
2169            Type::Ref { scope: _, name, params } => {
2170                let params = Arc::from_iter(params.iter().map(|t| t.scope_refs(scope)));
2171                Type::Ref { scope: scope.clone(), name: name.clone(), params }
2172            }
2173            Type::Set(ts) => {
2174                Type::Set(Arc::from_iter(ts.iter().map(|t| t.scope_refs(scope))))
2175            }
2176            Type::Fn(f) => Type::Fn(Arc::new(f.scope_refs(scope))),
2177        }
2178    }
2179}
2180
2181impl fmt::Display for Type {
2182    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
2183        match self {
2184            Self::Bottom => write!(f, "_"),
2185            Self::Any => write!(f, "Any"),
2186            Self::Ref { scope: _, name, params } => {
2187                write!(f, "{name}")?;
2188                if !params.is_empty() {
2189                    write!(f, "<")?;
2190                    for (i, t) in params.iter().enumerate() {
2191                        write!(f, "{t}")?;
2192                        if i < params.len() - 1 {
2193                            write!(f, ", ")?;
2194                        }
2195                    }
2196                    write!(f, ">")?;
2197                }
2198                Ok(())
2199            }
2200            Self::TVar(tv) => write!(f, "{tv}"),
2201            Self::Fn(t) => write!(f, "{t}"),
2202            Self::Error(t) => write!(f, "Error<{t}>"),
2203            Self::Array(t) => write!(f, "Array<{t}>"),
2204            Self::Map { key, value } => write!(f, "Map<{key}, {value}>"),
2205            Self::ByRef(t) => write!(f, "&{t}"),
2206            Self::Tuple(ts) => {
2207                write!(f, "(")?;
2208                for (i, t) in ts.iter().enumerate() {
2209                    write!(f, "{t}")?;
2210                    if i < ts.len() - 1 {
2211                        write!(f, ", ")?;
2212                    }
2213                }
2214                write!(f, ")")
2215            }
2216            Self::Variant(tag, ts) if ts.len() == 0 => {
2217                write!(f, "`{tag}")
2218            }
2219            Self::Variant(tag, ts) => {
2220                write!(f, "`{tag}(")?;
2221                for (i, t) in ts.iter().enumerate() {
2222                    write!(f, "{t}")?;
2223                    if i < ts.len() - 1 {
2224                        write!(f, ", ")?
2225                    }
2226                }
2227                write!(f, ")")
2228            }
2229            Self::Struct(ts) => {
2230                write!(f, "{{")?;
2231                for (i, (n, t)) in ts.iter().enumerate() {
2232                    write!(f, "{n}: {t}")?;
2233                    if i < ts.len() - 1 {
2234                        write!(f, ", ")?
2235                    }
2236                }
2237                write!(f, "}}")
2238            }
2239            Self::Set(s) => {
2240                write!(f, "[")?;
2241                for (i, t) in s.iter().enumerate() {
2242                    write!(f, "{t}")?;
2243                    if i < s.len() - 1 {
2244                        write!(f, ", ")?;
2245                    }
2246                }
2247                write!(f, "]")
2248            }
2249            Self::Primitive(s) => {
2250                let replace = PRINT_FLAGS.get().contains(PrintFlag::ReplacePrims);
2251                if replace && *s == Typ::number() {
2252                    write!(f, "Number")
2253                } else if replace && *s == Typ::float() {
2254                    write!(f, "Float")
2255                } else if replace && *s == Typ::real() {
2256                    write!(f, "Real")
2257                } else if replace && *s == Typ::integer() {
2258                    write!(f, "Int")
2259                } else if replace && *s == Typ::unsigned_integer() {
2260                    write!(f, "Uint")
2261                } else if replace && *s == Typ::signed_integer() {
2262                    write!(f, "Sint")
2263                } else if s.len() == 0 {
2264                    write!(f, "[]")
2265                } else if s.len() == 1 {
2266                    write!(f, "{}", s.iter().next().unwrap())
2267                } else {
2268                    let mut s = *s;
2269                    macro_rules! builtin {
2270                        ($set:expr, $name:literal) => {
2271                            if replace && s.contains($set) {
2272                                s.remove($set);
2273                                write!(f, $name)?;
2274                                if !s.is_empty() {
2275                                    write!(f, ", ")?
2276                                }
2277                            }
2278                        };
2279                    }
2280                    write!(f, "[")?;
2281                    builtin!(Typ::number(), "Number");
2282                    builtin!(Typ::real(), "Real");
2283                    builtin!(Typ::float(), "Float");
2284                    builtin!(Typ::integer(), "Int");
2285                    builtin!(Typ::unsigned_integer(), "Uint");
2286                    builtin!(Typ::signed_integer(), "Sint");
2287                    for (i, t) in s.iter().enumerate() {
2288                        write!(f, "{t}")?;
2289                        if i < s.len() - 1 {
2290                            write!(f, ", ")?;
2291                        }
2292                    }
2293                    write!(f, "]")
2294                }
2295            }
2296        }
2297    }
2298}
2299
2300impl PrettyDisplay for Type {
2301    fn fmt_pretty_inner(&self, buf: &mut PrettyBuf) -> fmt::Result {
2302        match self {
2303            Self::Bottom => writeln!(buf, "_"),
2304            Self::Any => writeln!(buf, "Any"),
2305            Self::Ref { scope: _, name, params } => {
2306                if params.is_empty() {
2307                    writeln!(buf, "{name}")
2308                } else {
2309                    writeln!(buf, "{name}<")?;
2310                    buf.with_indent(2, |buf| {
2311                        for (i, t) in params.iter().enumerate() {
2312                            t.fmt_pretty(buf)?;
2313                            if i < params.len() - 1 {
2314                                buf.kill_newline();
2315                                writeln!(buf, ",")?;
2316                            }
2317                        }
2318                        Ok(())
2319                    })?;
2320                    writeln!(buf, ">")
2321                }
2322            }
2323            Self::TVar(tv) => writeln!(buf, "{tv}"),
2324            Self::Fn(t) => t.fmt_pretty(buf),
2325            Self::Error(t) => {
2326                writeln!(buf, "Error<")?;
2327                buf.with_indent(2, |buf| t.fmt_pretty(buf))?;
2328                writeln!(buf, ">")
2329            }
2330            Self::Array(t) => {
2331                writeln!(buf, "Array<")?;
2332                buf.with_indent(2, |buf| t.fmt_pretty(buf))?;
2333                writeln!(buf, ">")
2334            }
2335            Self::Map { key, value } => {
2336                writeln!(buf, "Map<")?;
2337                buf.with_indent(2, |buf| {
2338                    key.fmt_pretty(buf)?;
2339                    buf.kill_newline();
2340                    writeln!(buf, ",")?;
2341                    value.fmt_pretty(buf)
2342                })?;
2343                writeln!(buf, ">")
2344            }
2345            Self::ByRef(t) => {
2346                write!(buf, "&")?;
2347                t.fmt_pretty(buf)
2348            }
2349            Self::Tuple(ts) => {
2350                writeln!(buf, "(")?;
2351                buf.with_indent(2, |buf| {
2352                    for (i, t) in ts.iter().enumerate() {
2353                        t.fmt_pretty(buf)?;
2354                        if i < ts.len() - 1 {
2355                            buf.kill_newline();
2356                            writeln!(buf, ",")?;
2357                        }
2358                    }
2359                    Ok(())
2360                })?;
2361                writeln!(buf, ")")
2362            }
2363            Self::Variant(tag, ts) if ts.is_empty() => writeln!(buf, "`{tag}"),
2364            Self::Variant(tag, ts) => {
2365                writeln!(buf, "`{tag}(")?;
2366                buf.with_indent(2, |buf| {
2367                    for (i, t) in ts.iter().enumerate() {
2368                        t.fmt_pretty(buf)?;
2369                        if i < ts.len() - 1 {
2370                            buf.kill_newline();
2371                            writeln!(buf, ",")?;
2372                        }
2373                    }
2374                    Ok(())
2375                })?;
2376                writeln!(buf, ")")
2377            }
2378            Self::Struct(ts) => {
2379                writeln!(buf, "{{")?;
2380                buf.with_indent(2, |buf| {
2381                    for (i, (n, t)) in ts.iter().enumerate() {
2382                        write!(buf, "{n}: ")?;
2383                        buf.with_indent(2, |buf| t.fmt_pretty(buf))?;
2384                        if i < ts.len() - 1 {
2385                            buf.kill_newline();
2386                            writeln!(buf, ",")?;
2387                        }
2388                    }
2389                    Ok(())
2390                })?;
2391                writeln!(buf, "}}")
2392            }
2393            Self::Set(s) => {
2394                writeln!(buf, "[")?;
2395                buf.with_indent(2, |buf| {
2396                    for (i, t) in s.iter().enumerate() {
2397                        t.fmt_pretty(buf)?;
2398                        if i < s.len() - 1 {
2399                            buf.kill_newline();
2400                            writeln!(buf, ",")?;
2401                        }
2402                    }
2403                    Ok(())
2404                })?;
2405                writeln!(buf, "]")
2406            }
2407            Self::Primitive(_) => {
2408                // Primitives are simple enough to just use Display
2409                writeln!(buf, "{self}")
2410            }
2411        }
2412    }
2413}