netidx_bscript/typ/mod.rs

1use crate::{env::Env, expr::ModPath, Ctx, UserEvent};
2use anyhow::{anyhow, bail, Result};
3use arcstr::ArcStr;
4use enumflags2::{bitflags, BitFlags};
5use fxhash::{FxHashMap, FxHashSet};
6use netidx::{
7    publisher::{Typ, Value},
8    utils::Either,
9};
10use netidx_netproto::valarray::ValArray;
11use parking_lot::RwLock;
12use smallvec::{smallvec, SmallVec};
13use std::{
14    cell::{Cell, RefCell},
15    cmp::{Eq, PartialEq},
16    collections::{hash_map::Entry, HashMap, HashSet},
17    fmt::{self, Debug},
18    iter,
19};
20use triomphe::Arc;
21
22mod fntyp;
23mod tval;
24mod tvar;
25
26pub use fntyp::{FnArgType, FnType};
27pub use tval::TVal;
28use tvar::would_cycle_inner;
29pub use tvar::TVar;
30
/// Collects an iterator of `bool`s into their logical AND.
///
/// Collection stops at the first `false`, so later items of the source
/// iterator (and any side effects they carry) are never produced. An empty
/// iterator yields `true`. Combined with `collect::<Result<AndAc>>()` this
/// also stops at the first `Err`.
struct AndAc(bool);

impl FromIterator<bool> for AndAc {
    fn from_iter<I: IntoIterator<Item = bool>>(items: I) -> Self {
        let mut items = items.into_iter();
        // "no element is false" == "every element is true"
        let saw_false = items.any(|b| !b);
        AndAc(!saw_false)
    }
}
38
/// Options that control how types are printed.
///
/// The set in effect lives in the thread-local `PRINT_FLAGS` and can be
/// changed temporarily with `format_with_flags`. NOTE(review): the
/// formatter that reads these flags is not part of this section.
///
/// The variants have no explicit discriminants, so `enumflags2` assigns
/// their bit values (presumably by declaration order). Reordering or
/// inserting variants may therefore change those values.
#[derive(Debug, Clone, Copy)]
#[bitflags]
#[repr(u64)]
pub enum PrintFlag {
    /// Dereference type variables and print both the tvar name and
    /// the bound type or "unbound".
    DerefTVars,
    /// Replace common primitives with shorter type names as defined
    /// in core. e.g. Any, instead of the set of every primitive type.
    ReplacePrims,
}
50
thread_local! {
    /// The `PrintFlag`s in effect on the current thread. Each thread starts
    /// with only `ReplacePrims` set; `format_with_flags` swaps the value for
    /// the duration of a closure.
    static PRINT_FLAGS: Cell<BitFlags<PrintFlag>> = Cell::new(PrintFlag::ReplacePrims.into());
}
54
55/// For the duration of the closure F change the way type variables
56/// are formatted (on this thread only) according to the specified
57/// flags.
58pub fn format_with_flags<R, F: FnOnce() -> R>(flags: BitFlags<PrintFlag>, f: F) -> R {
59    let prev = PRINT_FLAGS.replace(flags);
60    let res = f();
61    PRINT_FLAGS.set(prev);
62    res
63}
64
/// A bscript type.
///
/// The derived `PartialOrd`/`Ord` depend on the declaration order of the
/// variants, so do not reorder them casually.
///
/// Variants:
/// - `Bottom`: the default type. `contains` accepts it on either side
///   (after `Ref`/`TVar` handling), and `union` with it is the identity.
/// - `Primitive`: a set of primitive value types (`Typ` flags).
/// - `Ref`: a reference to a named type definition. It is resolved with
///   `Env::lookup_typedef(scope, name)` and instantiated with `params`
///   (see `lookup_ref`).
/// - `Fn`: a function type.
/// - `Set`: a union of types (what `union` builds when it cannot merge its
///   operands).
/// - `TVar`: a type variable that may be bound to a type.
/// - `Array`: an array whose elements have the inner type.
/// - `ByRef`: a by-reference wrapper around the inner type. NOTE(review):
///   the exact runtime meaning is defined elsewhere.
/// - `Tuple`: a fixed-arity tuple, compared element-wise.
/// - `Struct`: named fields, expected to be sorted by field name.
///   `contains_int` zips two field lists positionally and relies on this.
/// - `Variant`: a tag plus its argument types.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub enum Type {
    Bottom,
    Primitive(BitFlags<Typ>),
    Ref { scope: ModPath, name: ModPath, params: Arc<[Type]> },
    Fn(Arc<FnType>),
    Set(Arc<[Type]>),
    TVar(TVar),
    Array(Arc<Type>),
    ByRef(Arc<Type>),
    Tuple(Arc<[Type]>),
    Struct(Arc<[(ArcStr, Type)]>),
    Variant(ArcStr, Arc<[Type]>),
}
79
80impl Default for Type {
81    fn default() -> Self {
82        Self::Bottom
83    }
84}
85
86impl Type {
87    pub fn empty_tvar() -> Self {
88        Type::TVar(TVar::default())
89    }
90
91    fn iter_prims(&self) -> impl Iterator<Item = Self> {
92        match self {
93            Self::Primitive(p) => {
94                Either::Left(p.iter().map(|t| Type::Primitive(t.into())))
95            }
96            t => Either::Right(iter::once(t.clone())),
97        }
98    }
99
100    pub fn is_defined(&self) -> bool {
101        match self {
102            Self::Bottom
103            | Self::Primitive(_)
104            | Self::Fn(_)
105            | Self::Set(_)
106            | Self::Array(_)
107            | Self::ByRef(_)
108            | Self::Tuple(_)
109            | Self::Struct(_)
110            | Self::Variant(_, _) => true,
111            Self::TVar(tv) => tv.read().typ.read().is_some(),
112            Self::Ref { .. } => true,
113        }
114    }
115
116    pub fn lookup_ref<'a, C: Ctx, E: UserEvent>(
117        &'a self,
118        env: &'a Env<C, E>,
119    ) -> Result<&'a Type> {
120        match self {
121            Self::Ref { scope, name, params } => {
122                let def = env
123                    .lookup_typedef(scope, name)
124                    .ok_or_else(|| anyhow!("undefined type {scope}::{name}"))?;
125                if def.params.len() != params.len() {
126                    bail!("{} expects {} type parameters", name, def.params.len());
127                }
128                def.typ.unbind_tvars();
129                for ((tv, ct), arg) in def.params.iter().zip(params.iter()) {
130                    if let Some(ct) = ct {
131                        ct.check_contains(env, arg)?;
132                    }
133                    if !tv.would_cycle(arg) {
134                        *tv.read().typ.write() = Some(arg.clone());
135                    }
136                }
137                Ok(&def.typ)
138            }
139            t => Ok(t),
140        }
141    }
142
143    pub fn check_contains<C: Ctx, E: UserEvent>(
144        &self,
145        env: &Env<C, E>,
146        t: &Self,
147    ) -> Result<()> {
148        if self.contains(env, t)? {
149            Ok(())
150        } else {
151            format_with_flags(PrintFlag::DerefTVars | PrintFlag::ReplacePrims, || {
152                bail!("type mismatch {self} does not contain {t}")
153            })
154        }
155    }
156
    /// Worker for `contains`: does type `self` contain type `t`?
    ///
    /// This is not a pure query. Along the way it binds unbound type
    /// variables (the `TVar` arms write into `typ`) and aliases or copies
    /// pairs of type variables, so a successful check also unifies.
    ///
    /// `hist` memoizes results for pairs of resolved `Ref` bodies, keyed by
    /// their addresses. Before recursing, an entry is seeded with `true`,
    /// so a recursive type that reaches the same pair again is assumed to
    /// hold. The entry is removed if the recursion fails.
    ///
    /// Match arm order is significant. For example, the `(TVar, Bottom)`
    /// arm must come before the generic `Bottom` arm so the variable gets
    /// bound.
    fn contains_int<C: Ctx, E: UserEvent>(
        &self,
        env: &Env<C, E>,
        hist: &mut FxHashMap<(usize, usize), bool>,
        t: &Self,
    ) -> Result<bool> {
        match (self, t) {
            // NOTE(review): two Refs to the same definition are accepted
            // without comparing their type parameters; confirm intended.
            (
                Self::Ref { scope: s0, name: n0, .. },
                Self::Ref { scope: s1, name: n1, .. },
            ) if s0 == s1 && n0 == n1 => Ok(true),
            (t0 @ Self::Ref { .. }, t1) | (t0, t1 @ Self::Ref { .. }) => {
                let t0 = t0.lookup_ref(env)?;
                let t1 = t1.lookup_ref(env)?;
                // NOTE(review): the cache key is the address of the shared
                // definition body and ignores the instantiating params;
                // confirm a result cannot be reused for a different
                // instantiation.
                let t0_addr = (t0 as *const Type).addr();
                let t1_addr = (t1 as *const Type).addr();
                match hist.get(&(t0_addr, t1_addr)) {
                    Some(r) => Ok(*r),
                    None => {
                        // assume success while recursing so cycles terminate
                        hist.insert((t0_addr, t1_addr), true);
                        match t0.contains_int(env, hist, t1) {
                            Ok(r) => {
                                hist.insert((t0_addr, t1_addr), r);
                                Ok(r)
                            }
                            Err(e) => {
                                hist.remove(&(t0_addr, t1_addr));
                                Err(e)
                            }
                        }
                    }
                }
            }
            // always true; an unbound variable is bound to Bottom first
            (Self::TVar(t0), Self::Bottom) => {
                if let Some(_) = &*t0.read().typ.read() {
                    return Ok(true);
                }
                *t0.read().typ.write() = Some(Self::Bottom);
                Ok(true)
            }
            (Self::Bottom, _) | (_, Self::Bottom) => Ok(true),
            (Self::Primitive(p0), Self::Primitive(p1)) => Ok(p0.contains(*p1)),
            // structured values are all accepted by the Array primitive flag
            (
                Self::Primitive(p),
                Self::Array(_) | Self::Tuple(_) | Self::Struct(_) | Self::Variant(_, _),
            ) => Ok(p.contains(Typ::Array)),
            (Self::Array(t0), Self::Array(t1)) => t0.contains_int(env, hist, t1),
            // a bare Array primitive is treated as an array of any primitive
            (Self::Array(t0), Self::Primitive(p)) if *p == BitFlags::from(Typ::Array) => {
                t0.contains_int(env, hist, &Type::Primitive(BitFlags::all()))
            }
            // AndAc stops at the first false, so later elements (and any
            // type variables in them) are not visited
            (Self::Tuple(t0), Self::Tuple(t1)) => Ok(t0.len() == t1.len()
                && t0
                    .iter()
                    .zip(t1.iter())
                    .map(|(t0, t1)| t0.contains_int(env, hist, t1))
                    .collect::<Result<AndAc>>()?
                    .0),
            (Self::Struct(t0), Self::Struct(t1)) => {
                Ok(t0.len() == t1.len() && {
                    // struct types are always sorted by field name
                    t0.iter()
                        .zip(t1.iter())
                        .map(|((n0, t0), (n1, t1))| {
                            Ok(n0 == n1 && t0.contains_int(env, hist, t1)?)
                        })
                        .collect::<Result<AndAc>>()?
                        .0
                })
            }
            (Self::Variant(tg0, t0), Self::Variant(tg1, t1)) => Ok(tg0 == tg1
                && t0.len() == t1.len()
                && t0
                    .iter()
                    .zip(t1.iter())
                    .map(|(t0, t1)| t0.contains_int(env, hist, t1))
                    .collect::<Result<AndAc>>()?
                    .0),
            (Self::ByRef(t0), Self::ByRef(t1)) => t0.contains_int(env, hist, t1),
            // mismatched structural shapes never contain each other
            (Self::Tuple(_), Self::Array(_))
            | (Self::Tuple(_), Self::Primitive(_))
            | (Self::Tuple(_), Self::Struct(_))
            | (Self::Tuple(_), Self::Variant(_, _))
            | (Self::Array(_), Self::Primitive(_))
            | (Self::Array(_), Self::Tuple(_))
            | (Self::Array(_), Self::Struct(_))
            | (Self::Array(_), Self::Variant(_, _))
            | (Self::Struct(_), Self::Primitive(_))
            | (Self::Struct(_), Self::Array(_))
            | (Self::Struct(_), Self::Tuple(_))
            | (Self::Struct(_), Self::Variant(_, _))
            | (Self::Variant(_, _), Self::Array(_))
            | (Self::Variant(_, _), Self::Struct(_))
            | (Self::Variant(_, _), Self::Primitive(_))
            | (Self::Variant(_, _), Self::Tuple(_)) => Ok(false),
            (Self::TVar(t0), tt1 @ Self::TVar(t1)) => {
                // What to do with the two variables once the locks are
                // released. alias/copy semantics live in the tvar module.
                #[derive(Debug)]
                enum Act {
                    RightCopy,
                    LeftAlias,
                    LeftCopy,
                }
                let act = {
                    let t0 = t0.read();
                    let t1 = t1.read();
                    // both variables already share one binding cell
                    let addr = Arc::as_ptr(&t0.typ).addr();
                    if addr == Arc::as_ptr(&t1.typ).addr() {
                        return Ok(true);
                    }
                    let t0i = t0.typ.read();
                    let t1i = t1.typ.read();
                    match (&*t0i, &*t1i) {
                        (Some(t0), Some(t1)) => return t0.contains_int(env, hist, &*t1),
                        (None, None) => {
                            if would_cycle_inner(addr, tt1) {
                                return Ok(true);
                            }
                            Act::LeftAlias
                        }
                        (Some(_), None) => {
                            if would_cycle_inner(addr, tt1) {
                                return Ok(true);
                            }
                            Act::RightCopy
                        }
                        (None, Some(_)) => {
                            if would_cycle_inner(addr, tt1) {
                                return Ok(true);
                            }
                            Act::LeftCopy
                        }
                    }
                };
                match act {
                    Act::RightCopy => t1.copy(t0),
                    Act::LeftAlias => t0.alias(t1),
                    Act::LeftCopy => t0.copy(t1),
                }
                Ok(true)
            }
            // bound: compare through the binding; unbound: bind to t1
            (Self::TVar(t0), t1) if !t0.would_cycle(t1) => {
                if let Some(t0) = &*t0.read().typ.read() {
                    return t0.contains_int(env, hist, t1);
                }
                *t0.read().typ.write() = Some(t1.clone());
                Ok(true)
            }
            // bound: compare through the binding; unbound: bind to t0
            (t0, Self::TVar(t1)) if !t1.would_cycle(t0) => {
                if let Some(t1) = &*t1.read().typ.read() {
                    return t0.contains_int(env, hist, t1);
                }
                *t1.read().typ.write() = Some(t0.clone());
                Ok(true)
            }
            // t0 must contain every member of the set
            (t0, Self::Set(s)) => Ok(s
                .iter()
                .map(|t1| t0.contains_int(env, hist, t1))
                .collect::<Result<AndAc>>()?
                .0),
            // Either some member contains t as a whole, or t's primitive
            // flags can each be covered by some member individually.
            (Self::Set(s), t) => Ok(s
                .iter()
                .fold(Ok::<_, anyhow::Error>(false), |acc, t0| {
                    Ok(acc? || t0.contains_int(env, hist, t)?)
                })?
                || t.iter_prims().fold(Ok::<_, anyhow::Error>(true), |acc, t1| {
                    Ok(acc?
                        && s.iter().fold(Ok::<_, anyhow::Error>(false), |acc, t0| {
                            Ok(acc? || t0.contains_int(env, hist, &t1)?)
                        })?)
                })?),
            (Self::Fn(f0), Self::Fn(f1)) => {
                Ok(f0.as_ptr() == f1.as_ptr() || f0.contains_int(env, hist, f1)?)
            }
            // remaining pairs, including TVars whose binding would cycle
            (_, Self::TVar(_))
            | (Self::TVar(_), _)
            | (Self::Fn(_), _)
            | (Self::ByRef(_), _)
            | (_, Self::ByRef(_))
            | (_, Self::Fn(_)) => Ok(false),
        }
    }
337
338    pub fn contains<C: Ctx, E: UserEvent>(
339        &self,
340        env: &Env<C, E>,
341        t: &Self,
342    ) -> Result<bool> {
343        thread_local! {
344            static HIST: RefCell<FxHashMap<(usize, usize), bool>> = RefCell::new(HashMap::default());
345        }
346        HIST.with_borrow_mut(|hist| {
347            hist.clear();
348            self.contains_int(env, hist, t)
349        })
350    }
351
    /// Structural union of two types, without normalization. The public
    /// `union` normalizes the result.
    ///
    /// Arms are order-sensitive. When two types cannot be merged, the result
    /// is a `Set` of both operands. `Set`s are concatenated as-is, and any
    /// duplicates or nesting are left for `normalize` to deal with.
    /// NOTE(review): `normalize` is not visible here.
    fn union_int(&self, t: &Self) -> Self {
        match (self, t) {
            (Type::Bottom, t) | (t, Type::Bottom) => t.clone(),
            // The empty primitive set is the identity. The guard is tried
            // against each alternative of the or-pattern in turn.
            (Type::Primitive(p), t) | (t, Type::Primitive(p)) if p.is_empty() => {
                t.clone()
            }
            (Type::Primitive(s0), Type::Primitive(s1)) => {
                let mut s = *s0;
                s.insert(*s1);
                Type::Primitive(s)
            }
            // a primitive set with the Array flag absorbs structured types
            (
                Type::Primitive(p),
                Type::Array(_) | Type::Struct(_) | Type::Tuple(_) | Type::Variant(_, _),
            )
            | (
                Type::Array(_) | Type::Struct(_) | Type::Tuple(_) | Type::Variant(_, _),
                Type::Primitive(p),
            ) if p.contains(Typ::Array) => Type::Primitive(*p),
            (Type::Primitive(p), Type::Array(t))
            | (Type::Array(t), Type::Primitive(p)) => {
                Type::Set(Arc::from_iter([Type::Primitive(*p), Type::Array(t.clone())]))
            }
            (t @ Type::Array(t0), u @ Type::Array(t1)) => {
                if t0 == t1 {
                    Type::Array(t0.clone())
                } else {
                    Type::Set(Arc::from_iter([u.clone(), t.clone()]))
                }
            }
            (t @ Type::ByRef(t0), u @ Type::ByRef(t1)) => {
                if t0 == t1 {
                    Type::ByRef(t0.clone())
                } else {
                    Type::Set(Arc::from_iter([u.clone(), t.clone()]))
                }
            }
            (Type::Set(s0), Type::Set(s1)) => {
                Type::Set(Arc::from_iter(s0.iter().cloned().chain(s1.iter().cloned())))
            }
            (Type::Set(s), t) | (t, Type::Set(s)) => {
                Type::Set(Arc::from_iter(s.iter().cloned().chain(iter::once(t.clone()))))
            }
            (u @ Type::Struct(t0), t @ Type::Struct(t1)) => {
                if t0.len() == t1.len() && t0 == t1 {
                    u.clone()
                } else {
                    Type::Set(Arc::from_iter([u.clone(), t.clone()]))
                }
            }
            (u @ Type::Struct(_), t) | (t, u @ Type::Struct(_)) => {
                Type::Set(Arc::from_iter([u.clone(), t.clone()]))
            }
            (u @ Type::Tuple(t0), t @ Type::Tuple(t1)) => {
                if t0 == t1 {
                    u.clone()
                } else {
                    Type::Set(Arc::from_iter([u.clone(), t.clone()]))
                }
            }
            (u @ Type::Tuple(_), t) | (t, u @ Type::Tuple(_)) => {
                Type::Set(Arc::from_iter([u.clone(), t.clone()]))
            }
            // Unlike tuples and structs, variants with the same tag and
            // arity are merged element-wise instead of forming a Set.
            (u @ Type::Variant(tg0, t0), t @ Type::Variant(tg1, t1)) => {
                if tg0 == tg1 && t0.len() == t1.len() {
                    let typs = t0.iter().zip(t1.iter()).map(|(t0, t1)| t0.union_int(t1));
                    Type::Variant(tg0.clone(), Arc::from_iter(typs))
                } else {
                    Type::Set(Arc::from_iter([u.clone(), t.clone()]))
                }
            }
            (u @ Type::Variant(_, _), t) | (t, u @ Type::Variant(_, _)) => {
                Type::Set(Arc::from_iter([u.clone(), t.clone()]))
            }
            (Type::Fn(f0), Type::Fn(f1)) => {
                if f0 == f1 {
                    Type::Fn(f0.clone())
                } else {
                    Type::Set(Arc::from_iter([
                        Type::Fn(f0.clone()),
                        Type::Fn(f1.clone()),
                    ]))
                }
            }
            (f @ Type::Fn(_), t) | (t, f @ Type::Fn(_)) => {
                Type::Set(Arc::from_iter([f.clone(), t.clone()]))
            }
            (t0 @ Type::TVar(_), t1 @ Type::TVar(_)) => {
                if t0 == t1 {
                    t0.clone()
                } else {
                    Type::Set(Arc::from_iter([t0.clone(), t1.clone()]))
                }
            }
            (t0 @ Type::TVar(_), t1) | (t1, t0 @ Type::TVar(_)) => {
                Type::Set(Arc::from_iter([t0.clone(), t1.clone()]))
            }
            (t @ Type::ByRef(_), u) | (u, t @ Type::ByRef(_)) => {
                Type::Set(Arc::from_iter([t.clone(), u.clone()]))
            }
            // Refs are kept symbolic; they are not resolved here
            (tr @ Type::Ref { .. }, t) | (t, tr @ Type::Ref { .. }) => {
                Type::Set(Arc::from_iter([tr.clone(), t.clone()]))
            }
        }
    }
457
458    pub fn union(&self, t: &Self) -> Self {
459        self.union_int(t).normalize()
460    }
461
462    fn diff_int<C: Ctx, E: UserEvent>(
463        &self,
464        env: &Env<C, E>,
465        hist: &mut FxHashMap<(usize, usize), Type>,
466        t: &Self,
467    ) -> Result<Self> {
468        match (self, t) {
469            (
470                Type::Ref { scope: s0, name: n0, .. },
471                Type::Ref { scope: s1, name: n1, .. },
472            ) if s0 == s1 && n0 == n1 => Ok(Type::Primitive(BitFlags::empty())),
473            (t0 @ Type::Ref { .. }, t1) | (t0, t1 @ Type::Ref { .. }) => {
474                let t0 = t0.lookup_ref(env)?;
475                let t1 = t1.lookup_ref(env)?;
476                let t0_addr = (t0 as *const Type).addr();
477                let t1_addr = (t1 as *const Type).addr();
478                match hist.get(&(t0_addr, t1_addr)) {
479                    Some(r) => Ok(r.clone()),
480                    None => {
481                        let r = Type::Primitive(BitFlags::empty());
482                        hist.insert((t0_addr, t1_addr), r);
483                        match t0.diff_int(env, hist, &t1) {
484                            Ok(r) => {
485                                hist.insert((t0_addr, t1_addr), r.clone());
486                                Ok(r)
487                            }
488                            Err(e) => {
489                                hist.remove(&(t0_addr, t1_addr));
490                                Err(e)
491                            }
492                        }
493                    }
494                }
495            }
496            (Type::Bottom, t) | (t, Type::Bottom) => Ok(t.clone()),
497            (Type::Primitive(s0), Type::Primitive(s1)) => {
498                let mut s = *s0;
499                s.remove(*s1);
500                Ok(Type::Primitive(s))
501            }
502            (
503                Type::Primitive(p),
504                Type::Array(_) | Type::Struct(_) | Type::Tuple(_) | Type::Variant(_, _),
505            ) => {
506                // CR estokes: is this correct? It's a bit odd.
507                let mut s = *p;
508                s.remove(Typ::Array);
509                Ok(Type::Primitive(s))
510            }
511            (
512                Type::Array(_) | Type::Struct(_) | Type::Tuple(_) | Type::Variant(_, _),
513                Type::Primitive(p),
514            ) => {
515                if p.contains(Typ::Array) {
516                    Ok(Type::Primitive(BitFlags::empty()))
517                } else {
518                    Ok(self.clone())
519                }
520            }
521            (Type::Array(t0), Type::Array(t1)) => {
522                Ok(Type::Array(Arc::new(t0.diff_int(env, hist, t1)?)))
523            }
524            (Type::ByRef(t0), Type::ByRef(t1)) => {
525                Ok(Type::ByRef(Arc::new(t0.diff_int(env, hist, t1)?)))
526            }
527            (Type::Set(s0), Type::Set(s1)) => {
528                let mut s: SmallVec<[Type; 4]> = smallvec![];
529                for i in 0..s0.len() {
530                    s.push(s0[i].clone());
531                    for j in 0..s1.len() {
532                        s[i] = s[i].diff_int(env, hist, &s1[j])?
533                    }
534                }
535                Ok(Self::flatten_set(s.into_iter()))
536            }
537            (Type::Set(s), t) => Ok(Self::flatten_set(
538                s.iter()
539                    .map(|s| s.diff_int(env, hist, t))
540                    .collect::<Result<SmallVec<[_; 8]>>>()?,
541            )),
542            (t, Type::Set(s)) => {
543                let mut t = t.clone();
544                for st in s.iter() {
545                    t = t.diff_int(env, hist, st)?;
546                }
547                Ok(t)
548            }
549            (Type::Tuple(t0), Type::Tuple(t1)) => {
550                if t0 == t1 {
551                    Ok(Type::Primitive(BitFlags::empty()))
552                } else {
553                    Ok(self.clone())
554                }
555            }
556            (Type::Tuple(_), _) | (_, Type::Tuple(_)) => Ok(self.clone()),
557            (Type::Struct(t0), Type::Struct(t1)) => {
558                if t0.len() == t1.len() && t0 == t1 {
559                    Ok(Type::Primitive(BitFlags::empty()))
560                } else {
561                    Ok(self.clone())
562                }
563            }
564            (Type::Struct(_), _) | (_, Type::Struct(_)) => Ok(self.clone()),
565            (Type::ByRef(_), _) | (_, Type::ByRef(_)) => Ok(self.clone()),
566            (Type::Variant(tg0, t0), Type::Variant(tg1, t1)) => {
567                if tg0 == tg1 && t0.len() == t1.len() && t0 == t1 {
568                    Ok(Type::Primitive(BitFlags::empty()))
569                } else {
570                    Ok(self.clone())
571                }
572            }
573            (Type::Variant(_, _), _) | (_, Type::Variant(_, _)) => Ok(self.clone()),
574            (Type::Fn(f0), Type::Fn(f1)) => {
575                if f0 == f1 {
576                    Ok(Type::Primitive(BitFlags::empty()))
577                } else {
578                    Ok(Type::Fn(f0.clone()))
579                }
580            }
581            (f @ Type::Fn(_), _) => Ok(f.clone()),
582            (t, Type::Fn(_)) => Ok(t.clone()),
583            (Type::TVar(tv0), Type::TVar(tv1)) => {
584                if tv0.read().typ.as_ptr() == tv1.read().typ.as_ptr() {
585                    return Ok(Type::Primitive(BitFlags::empty()));
586                }
587                Ok(match (&*tv0.read().typ.read(), &*tv1.read().typ.read()) {
588                    (None, _) | (_, None) => Type::TVar(tv0.clone()),
589                    (Some(t0), Some(t1)) => t0.diff_int(env, hist, t1)?,
590                })
591            }
592            (Type::TVar(tv), t1) => match &*tv.read().typ.read() {
593                None => Ok(Type::TVar(tv.clone())),
594                Some(t0) => t0.diff_int(env, hist, t1),
595            },
596            (t0, Type::TVar(tv)) => match &*tv.read().typ.read() {
597                None => Ok(t0.clone()),
598                Some(t1) => t0.diff_int(env, hist, t1),
599            },
600        }
601    }
602
603    pub fn diff<C: Ctx, E: UserEvent>(&self, env: &Env<C, E>, t: &Self) -> Result<Self> {
604        thread_local! {
605            static HIST: RefCell<FxHashMap<(usize, usize), Type>> = RefCell::new(HashMap::default());
606        }
607        HIST.with_borrow_mut(|hist| {
608            hist.clear();
609            Ok(self.diff_int(env, hist, t)?.normalize())
610        })
611    }
612
613    pub fn any() -> Self {
614        Self::Primitive(Typ::any())
615    }
616
617    pub fn boolean() -> Self {
618        Self::Primitive(Typ::Bool.into())
619    }
620
621    pub fn number() -> Self {
622        Self::Primitive(Typ::number())
623    }
624
625    pub fn int() -> Self {
626        Self::Primitive(Typ::integer())
627    }
628
629    pub fn uint() -> Self {
630        Self::Primitive(Typ::unsigned_integer())
631    }
632
633    /// alias type variables with the same name to each other
634    pub fn alias_tvars(&self, known: &mut FxHashMap<ArcStr, TVar>) {
635        match self {
636            Type::Bottom | Type::Primitive(_) => (),
637            Type::Ref { params, .. } => {
638                for t in params.iter() {
639                    t.alias_tvars(known);
640                }
641            }
642            Type::Array(t) => t.alias_tvars(known),
643            Type::ByRef(t) => t.alias_tvars(known),
644            Type::Tuple(ts) => {
645                for t in ts.iter() {
646                    t.alias_tvars(known)
647                }
648            }
649            Type::Struct(ts) => {
650                for (_, t) in ts.iter() {
651                    t.alias_tvars(known)
652                }
653            }
654            Type::Variant(_, ts) => {
655                for t in ts.iter() {
656                    t.alias_tvars(known)
657                }
658            }
659            Type::TVar(tv) => match known.entry(tv.name.clone()) {
660                Entry::Occupied(e) => tv.alias(e.get()),
661                Entry::Vacant(e) => {
662                    e.insert(tv.clone());
663                    ()
664                }
665            },
666            Type::Fn(ft) => ft.alias_tvars(known),
667            Type::Set(s) => {
668                for typ in s.iter() {
669                    typ.alias_tvars(known)
670                }
671            }
672        }
673    }
674
675    pub fn collect_tvars(&self, known: &mut FxHashMap<ArcStr, TVar>) {
676        match self {
677            Type::Bottom | Type::Primitive(_) => (),
678            Type::Ref { params, .. } => {
679                for t in params.iter() {
680                    t.collect_tvars(known);
681                }
682            }
683            Type::Array(t) => t.collect_tvars(known),
684            Type::ByRef(t) => t.collect_tvars(known),
685            Type::Tuple(ts) => {
686                for t in ts.iter() {
687                    t.collect_tvars(known)
688                }
689            }
690            Type::Struct(ts) => {
691                for (_, t) in ts.iter() {
692                    t.collect_tvars(known)
693                }
694            }
695            Type::Variant(_, ts) => {
696                for t in ts.iter() {
697                    t.collect_tvars(known)
698                }
699            }
700            Type::TVar(tv) => match known.entry(tv.name.clone()) {
701                Entry::Occupied(_) => (),
702                Entry::Vacant(e) => {
703                    e.insert(tv.clone());
704                    ()
705                }
706            },
707            Type::Fn(ft) => ft.collect_tvars(known),
708            Type::Set(s) => {
709                for typ in s.iter() {
710                    typ.collect_tvars(known)
711                }
712            }
713        }
714    }
715
716    pub fn check_tvars_declared(&self, declared: &FxHashSet<ArcStr>) -> Result<()> {
717        match self {
718            Type::Bottom | Type::Primitive(_) => Ok(()),
719            Type::Ref { params, .. } => {
720                params.iter().try_for_each(|t| t.check_tvars_declared(declared))
721            }
722            Type::Array(t) => t.check_tvars_declared(declared),
723            Type::ByRef(t) => t.check_tvars_declared(declared),
724            Type::Tuple(ts) => {
725                ts.iter().try_for_each(|t| t.check_tvars_declared(declared))
726            }
727            Type::Struct(ts) => {
728                ts.iter().try_for_each(|(_, t)| t.check_tvars_declared(declared))
729            }
730            Type::Variant(_, ts) => {
731                ts.iter().try_for_each(|t| t.check_tvars_declared(declared))
732            }
733            Type::TVar(tv) => {
734                if !declared.contains(&tv.name) {
735                    bail!("undeclared type variable '{}'", tv.name)
736                } else {
737                    Ok(())
738                }
739            }
740            Type::Set(s) => s.iter().try_for_each(|t| t.check_tvars_declared(declared)),
741            Type::Fn(_) => Ok(()),
742        }
743    }
744
745    pub fn has_unbound(&self) -> bool {
746        match self {
747            Type::Bottom | Type::Primitive(_) => false,
748            Type::Ref { .. } => false,
749            Type::Array(t0) => t0.has_unbound(),
750            Type::ByRef(t0) => t0.has_unbound(),
751            Type::Tuple(ts) => ts.iter().any(|t| t.has_unbound()),
752            Type::Struct(ts) => ts.iter().any(|(_, t)| t.has_unbound()),
753            Type::Variant(_, ts) => ts.iter().any(|t| t.has_unbound()),
754            Type::TVar(tv) => tv.read().typ.read().is_some(),
755            Type::Set(s) => s.iter().any(|t| t.has_unbound()),
756            Type::Fn(ft) => ft.has_unbound(),
757        }
758    }
759
760    /// bind all unbound type variables to the specified type
761    pub fn bind_as(&self, t: &Self) {
762        match self {
763            Type::Bottom | Type::Primitive(_) => (),
764            Type::Ref { .. } => (),
765            Type::Array(t0) => t0.bind_as(t),
766            Type::ByRef(t0) => t0.bind_as(t),
767            Type::Tuple(ts) => {
768                for elt in ts.iter() {
769                    elt.bind_as(t)
770                }
771            }
772            Type::Struct(ts) => {
773                for (_, elt) in ts.iter() {
774                    elt.bind_as(t)
775                }
776            }
777            Type::Variant(_, ts) => {
778                for elt in ts.iter() {
779                    elt.bind_as(t)
780                }
781            }
782            Type::TVar(tv) => {
783                let tv = tv.read();
784                let mut tv = tv.typ.write();
785                if tv.is_none() {
786                    *tv = Some(t.clone());
787                }
788            }
789            Type::Set(s) => {
790                for elt in s.iter() {
791                    elt.bind_as(t)
792                }
793            }
794            Type::Fn(ft) => ft.bind_as(t),
795        }
796    }
797
798    /// return a copy of self with all type variables unbound and
799    /// unaliased. self will not be modified
800    pub fn reset_tvars(&self) -> Type {
801        match self {
802            Type::Bottom => Type::Bottom,
803            Type::Primitive(p) => Type::Primitive(*p),
804            Type::Ref { scope, name, params } => Type::Ref {
805                scope: scope.clone(),
806                name: name.clone(),
807                params: Arc::from_iter(params.iter().map(|t| t.reset_tvars())),
808            },
809            Type::Array(t0) => Type::Array(Arc::new(t0.reset_tvars())),
810            Type::ByRef(t0) => Type::ByRef(Arc::new(t0.reset_tvars())),
811            Type::Tuple(ts) => {
812                Type::Tuple(Arc::from_iter(ts.iter().map(|t| t.reset_tvars())))
813            }
814            Type::Struct(ts) => Type::Struct(Arc::from_iter(
815                ts.iter().map(|(n, t)| (n.clone(), t.reset_tvars())),
816            )),
817            Type::Variant(tag, ts) => Type::Variant(
818                tag.clone(),
819                Arc::from_iter(ts.iter().map(|t| t.reset_tvars())),
820            ),
821            Type::TVar(tv) => Type::TVar(TVar::empty_named(tv.name.clone())),
822            Type::Set(s) => Type::Set(Arc::from_iter(s.iter().map(|t| t.reset_tvars()))),
823            Type::Fn(fntyp) => Type::Fn(Arc::new(fntyp.reset_tvars())),
824        }
825    }
826
827    /// return a copy of self with every TVar named in known replaced
828    /// with the corresponding type
829    pub fn replace_tvars(&self, known: &FxHashMap<ArcStr, Self>) -> Type {
830        match self {
831            Type::TVar(tv) => match known.get(&tv.name) {
832                Some(t) => t.clone(),
833                None => Type::TVar(tv.clone()),
834            },
835            Type::Bottom => Type::Bottom,
836            Type::Primitive(p) => Type::Primitive(*p),
837            Type::Ref { scope, name, params } => Type::Ref {
838                scope: scope.clone(),
839                name: name.clone(),
840                params: Arc::from_iter(params.iter().map(|t| t.replace_tvars(known))),
841            },
842            Type::Array(t0) => Type::Array(Arc::new(t0.replace_tvars(known))),
843            Type::ByRef(t0) => Type::ByRef(Arc::new(t0.replace_tvars(known))),
844            Type::Tuple(ts) => {
845                Type::Tuple(Arc::from_iter(ts.iter().map(|t| t.replace_tvars(known))))
846            }
847            Type::Struct(ts) => Type::Struct(Arc::from_iter(
848                ts.iter().map(|(n, t)| (n.clone(), t.replace_tvars(known))),
849            )),
850            Type::Variant(tag, ts) => Type::Variant(
851                tag.clone(),
852                Arc::from_iter(ts.iter().map(|t| t.replace_tvars(known))),
853            ),
854            Type::Set(s) => {
855                Type::Set(Arc::from_iter(s.iter().map(|t| t.replace_tvars(known))))
856            }
857            Type::Fn(fntyp) => Type::Fn(Arc::new(fntyp.replace_tvars(known))),
858        }
859    }
860
861    /// Unbind any bound tvars, but do not unalias them.
862    fn unbind_tvars(&self) {
863        match self {
864            Type::Bottom | Type::Primitive(_) | Type::Ref { .. } => (),
865            Type::Array(t0) => t0.unbind_tvars(),
866            Type::ByRef(t0) => t0.unbind_tvars(),
867            Type::Tuple(ts) | Type::Variant(_, ts) | Type::Set(ts) => {
868                for t in ts.iter() {
869                    t.unbind_tvars()
870                }
871            }
872            Type::Struct(ts) => {
873                for (_, t) in ts.iter() {
874                    t.unbind_tvars()
875                }
876            }
877            Type::TVar(tv) => tv.unbind(),
878            Type::Fn(fntyp) => fntyp.unbind_tvars(),
879        }
880    }
881
882    fn first_prim_int<C: Ctx, E: UserEvent>(
883        &self,
884        env: &Env<C, E>,
885        hist: &mut FxHashSet<usize>,
886    ) -> Option<Typ> {
887        match self {
888            Type::Primitive(p) => p.iter().next(),
889            Type::Bottom => None,
890            Type::Fn(_) => None,
891            Type::Set(s) => s.iter().find_map(|t| t.first_prim_int(env, hist)),
892            Type::TVar(tv) => {
893                tv.read().typ.read().as_ref().and_then(|t| t.first_prim_int(env, hist))
894            }
895            // array, tuple, and struct casting are handled directly
896            Type::Array(_)
897            | Type::Tuple(_)
898            | Type::Struct(_)
899            | Type::Variant(_, _)
900            | Type::ByRef(_) => None,
901            Type::Ref { .. } => {
902                let t = self.lookup_ref(env).ok()?;
903                let t_addr = (t as *const Type).addr();
904                if hist.contains(&t_addr) {
905                    None
906                } else {
907                    hist.insert(t_addr);
908                    t.first_prim_int(env, hist)
909                }
910            }
911        }
912    }
913
914    fn first_prim<C: Ctx, E: UserEvent>(&self, env: &Env<C, E>) -> Option<Typ> {
915        thread_local! {
916            static HIST: RefCell<FxHashSet<usize>> = RefCell::new(HashSet::default());
917        }
918        HIST.with_borrow_mut(|hist| {
919            hist.clear();
920            self.first_prim_int(env, hist)
921        })
922    }
923
924    fn check_cast_int<C: Ctx, E: UserEvent>(
925        &self,
926        env: &Env<C, E>,
927        hist: &mut FxHashSet<usize>,
928    ) -> Result<()> {
929        match self {
930            Type::Primitive(_) => Ok(()),
931            Type::Fn(_) => bail!("can't cast a value to a function"),
932            Type::Bottom => bail!("can't cast a value to bottom"),
933            Type::Set(s) => Ok(for t in s.iter() {
934                t.check_cast_int(env, hist)?
935            }),
936            Type::TVar(tv) => match &*tv.read().typ.read() {
937                Some(t) => t.check_cast_int(env, hist),
938                None => bail!("can't cast a value to a free type variable"),
939            },
940            Type::Array(et) => et.check_cast_int(env, hist),
941            Type::ByRef(_) => bail!("can't cast a reference"),
942            Type::Tuple(ts) => Ok(for t in ts.iter() {
943                t.check_cast_int(env, hist)?
944            }),
945            Type::Struct(ts) => Ok(for (_, t) in ts.iter() {
946                t.check_cast_int(env, hist)?
947            }),
948            Type::Variant(_, ts) => Ok(for t in ts.iter() {
949                t.check_cast_int(env, hist)?
950            }),
951            Type::Ref { .. } => {
952                let t = self.lookup_ref(env)?;
953                let t_addr = (t as *const Type).addr();
954                if hist.contains(&t_addr) {
955                    Ok(())
956                } else {
957                    hist.insert(t_addr);
958                    t.check_cast_int(env, hist)
959                }
960            }
961        }
962    }
963
964    pub fn check_cast<C: Ctx, E: UserEvent>(&self, env: &Env<C, E>) -> Result<()> {
965        thread_local! {
966            static HIST: RefCell<FxHashSet<usize>> = RefCell::new(FxHashSet::default());
967        }
968        HIST.with_borrow_mut(|hist| {
969            hist.clear();
970            self.check_cast_int(env, hist)
971        })
972    }
973
974    fn check_array<C: Ctx, E: UserEvent>(
975        &self,
976        env: &Env<C, E>,
977        a: &ValArray,
978    ) -> Result<bool> {
979        Ok(a.iter()
980            .map(|elt| match elt {
981                Value::Array(elts) => match self {
982                    Type::Array(et) => et.check_array(env, elts),
983                    _ => Ok(false),
984                },
985                v => self.contains(env, &Type::Primitive(Typ::get(v).into())),
986            })
987            .collect::<Result<AndAc>>()?
988            .0)
989    }
990
991    fn cast_value_int<C: Ctx, E: UserEvent>(
992        &self,
993        env: &Env<C, E>,
994        v: Value,
995    ) -> Result<Value> {
996        if self.contains(env, &Type::Primitive(Typ::get(&v).into()))? {
997            return Ok(v);
998        }
999        match self {
1000            Type::Array(et) => match v {
1001                Value::Array(elts) => {
1002                    if et.check_array(env, &elts)? {
1003                        return Ok(Value::Array(elts));
1004                    }
1005                    let va = elts
1006                        .iter()
1007                        .map(|el| et.cast_value_int(env, el.clone()))
1008                        .collect::<Result<SmallVec<[Value; 8]>>>()?;
1009                    Ok(Value::Array(ValArray::from_iter_exact(va.into_iter())))
1010                }
1011                v => Ok(Value::Array([et.cast_value_int(env, v)?].into())),
1012            },
1013            Type::Tuple(ts) => match v {
1014                Value::Array(elts) => {
1015                    if elts.len() != ts.len() {
1016                        bail!("tuple size mismatch {self} with {}", Value::Array(elts))
1017                    }
1018                    let ok = ts
1019                        .iter()
1020                        .zip(elts.iter())
1021                        .map(|(t, v)| {
1022                            t.contains(env, &Type::Primitive(Typ::get(v).into()))
1023                        })
1024                        .collect::<Result<AndAc>>()?
1025                        .0;
1026                    if ok {
1027                        return Ok(Value::Array(elts));
1028                    }
1029                    let a = ts
1030                        .iter()
1031                        .zip(elts.iter())
1032                        .map(|(t, el)| t.cast_value_int(env, el.clone()))
1033                        .collect::<Result<SmallVec<[Value; 8]>>>()?;
1034                    Ok(Value::Array(ValArray::from_iter_exact(a.into_iter())))
1035                }
1036                v => bail!("can't cast {v} to {self}"),
1037            },
1038            Type::Struct(ts) => match v {
1039                Value::Array(elts) => {
1040                    if elts.len() != ts.len() {
1041                        bail!("struct size mismatch {self} with {}", Value::Array(elts))
1042                    }
1043                    let is_pairs = elts.iter().all(|v| match v {
1044                        Value::Array(a) if a.len() == 2 => match &a[0] {
1045                            Value::String(_) => true,
1046                            _ => false,
1047                        },
1048                        _ => false,
1049                    });
1050                    if !is_pairs {
1051                        bail!("expected array of pairs, got {}", Value::Array(elts))
1052                    }
1053                    let mut elts_s: SmallVec<[&Value; 16]> = elts.iter().collect();
1054                    elts_s.sort_by_key(|v| match v {
1055                        Value::Array(a) => match &a[0] {
1056                            Value::String(s) => s,
1057                            _ => unreachable!(),
1058                        },
1059                        _ => unreachable!(),
1060                    });
1061                    let (keys_ok, ok) = ts.iter().zip(elts_s.iter()).fold(
1062                        Ok((true, true)),
1063                        |acc: Result<_>, ((fname, t), v)| {
1064                            let (kok, ok) = acc?;
1065                            let (name, v) = match v {
1066                                Value::Array(a) => match (&a[0], &a[1]) {
1067                                    (Value::String(n), v) => (n, v),
1068                                    _ => unreachable!(),
1069                                },
1070                                _ => unreachable!(),
1071                            };
1072                            Ok((
1073                                kok && name == fname,
1074                                ok && kok
1075                                    && t.contains(
1076                                        env,
1077                                        &Type::Primitive(Typ::get(v).into()),
1078                                    )?,
1079                            ))
1080                        },
1081                    )?;
1082                    if ok {
1083                        drop(elts_s);
1084                        return Ok(Value::Array(elts));
1085                    } else if keys_ok {
1086                        let elts = ts
1087                            .iter()
1088                            .zip(elts_s.iter())
1089                            .map(|((n, t), v)| match v {
1090                                Value::Array(a) => {
1091                                    let a = [
1092                                        Value::String(n.clone()),
1093                                        t.cast_value_int(env, a[1].clone())?,
1094                                    ];
1095                                    Ok(Value::Array(ValArray::from_iter_exact(
1096                                        a.into_iter(),
1097                                    )))
1098                                }
1099                                _ => unreachable!(),
1100                            })
1101                            .collect::<Result<SmallVec<[Value; 8]>>>()?;
1102                        Ok(Value::Array(ValArray::from_iter_exact(elts.into_iter())))
1103                    } else {
1104                        drop(elts_s);
1105                        bail!("struct fields mismatch {self}, {}", Value::Array(elts))
1106                    }
1107                }
1108                v => bail!("can't cast {v} to {self}"),
1109            },
1110            Type::Variant(tag, ts) if ts.len() == 0 => match &v {
1111                Value::String(s) if s == tag => Ok(v),
1112                _ => bail!("variant tag mismatch expected {tag} got {v}"),
1113            },
1114            Type::Variant(tag, ts) => match &v {
1115                Value::Array(elts) => {
1116                    if ts.len() + 1 == elts.len() {
1117                        match &elts[0] {
1118                            Value::String(s) if s == tag => (),
1119                            v => bail!("variant tag mismatch expected {tag} got {v}"),
1120                        }
1121                        let ok = iter::once(&Type::Primitive(Typ::String.into()))
1122                            .chain(ts.iter())
1123                            .zip(elts.iter())
1124                            .fold(Ok(true), |ok: Result<_>, (t, v)| {
1125                                Ok(ok?
1126                                    && t.contains(
1127                                        env,
1128                                        &Type::Primitive(Typ::get(v).into()),
1129                                    )?)
1130                            })?;
1131                        if ok {
1132                            Ok(v)
1133                        } else {
1134                            let a = iter::once(&Type::Primitive(Typ::String.into()))
1135                                .chain(ts.iter())
1136                                .zip(elts.iter())
1137                                .map(|(t, v)| t.cast_value_int(env, v.clone()))
1138                                .collect::<Result<SmallVec<[Value; 8]>>>()?;
1139                            Ok(Value::Array(ValArray::from_iter_exact(a.into_iter())))
1140                        }
1141                    } else if ts.len() == elts.len() {
1142                        let mut a = ts
1143                            .iter()
1144                            .zip(elts.iter())
1145                            .map(|(t, v)| t.cast_value_int(env, v.clone()))
1146                            .collect::<Result<SmallVec<[Value; 8]>>>()?;
1147                        a.insert(0, Value::String(tag.clone()));
1148                        Ok(Value::Array(ValArray::from_iter_exact(a.into_iter())))
1149                    } else {
1150                        bail!("variant length mismatch")
1151                    }
1152                }
1153                v => bail!("can't cast {v} to {self}"),
1154            },
1155            Type::Ref { .. } => self.lookup_ref(env)?.cast_value_int(env, v),
1156            t => match t.first_prim(env) {
1157                None => bail!("empty or non primitive cast"),
1158                Some(t) => Ok(v
1159                    .clone()
1160                    .cast(t)
1161                    .ok_or_else(|| anyhow!("can't cast {v} to {t}"))?),
1162            },
1163        }
1164    }
1165
1166    pub fn cast_value<C: Ctx, E: UserEvent>(&self, env: &Env<C, E>, v: Value) -> Value {
1167        match self.cast_value_int(env, v) {
1168            Ok(v) => v,
1169            Err(e) => Value::Error(e.to_string().into()),
1170        }
1171    }
1172
1173    fn is_a_int<C: Ctx, E: UserEvent>(
1174        &self,
1175        env: &Env<C, E>,
1176        hist: &mut FxHashSet<usize>,
1177        v: &Value,
1178    ) -> bool {
1179        match self {
1180            Type::Ref { .. } => match self.lookup_ref(env) {
1181                Err(_) => false,
1182                Ok(t) => {
1183                    let t_addr = (t as *const Type).addr();
1184                    !hist.contains(&t_addr) && {
1185                        hist.insert(t_addr);
1186                        t.is_a_int(env, hist, v)
1187                    }
1188                }
1189            },
1190            Type::Primitive(t) => t.contains(Typ::get(&v)),
1191            Type::Array(et) => match v {
1192                Value::Array(a) => a.iter().all(|v| et.is_a_int(env, hist, v)),
1193                _ => false,
1194            },
1195            Type::ByRef(_) => matches!(v, Value::U64(_) | Value::V64(_)),
1196            Type::Tuple(ts) => match v {
1197                Value::Array(elts) => {
1198                    elts.len() == ts.len()
1199                        && ts
1200                            .iter()
1201                            .zip(elts.iter())
1202                            .all(|(t, v)| t.is_a_int(env, hist, v))
1203                }
1204                _ => false,
1205            },
1206            Type::Struct(ts) => match v {
1207                Value::Array(elts) => {
1208                    elts.len() == ts.len()
1209                        && ts.iter().zip(elts.iter()).all(|((n, t), v)| match v {
1210                            Value::Array(a) if a.len() == 2 => match &a[..] {
1211                                [Value::String(key), v] => {
1212                                    n == key && t.is_a_int(env, hist, v)
1213                                }
1214                                _ => false,
1215                            },
1216                            _ => false,
1217                        })
1218                }
1219                _ => false,
1220            },
1221            Type::Variant(tag, ts) if ts.len() == 0 => match &v {
1222                Value::String(s) => s == tag,
1223                _ => false,
1224            },
1225            Type::Variant(tag, ts) => match &v {
1226                Value::Array(elts) => {
1227                    ts.len() + 1 == elts.len()
1228                        && match &elts[0] {
1229                            Value::String(s) => s == tag,
1230                            _ => false,
1231                        }
1232                        && ts
1233                            .iter()
1234                            .zip(elts[1..].iter())
1235                            .all(|(t, v)| t.is_a_int(env, hist, v))
1236                }
1237                _ => false,
1238            },
1239            Type::TVar(tv) => match &*tv.read().typ.read() {
1240                None => true,
1241                Some(t) => t.is_a_int(env, hist, v),
1242            },
1243            Type::Fn(_) => match v {
1244                Value::U64(_) => true,
1245                _ => false,
1246            },
1247            Type::Bottom => true,
1248            Type::Set(ts) => ts.iter().any(|t| t.is_a_int(env, hist, v)),
1249        }
1250    }
1251
1252    /// return true if v is structurally compatible with the type
1253    pub fn is_a<C: Ctx, E: UserEvent>(&self, env: &Env<C, E>, v: &Value) -> bool {
1254        thread_local! {
1255            static HIST: RefCell<FxHashSet<usize>> = RefCell::new(HashSet::default());
1256        }
1257        HIST.with_borrow_mut(|hist| {
1258            hist.clear();
1259            self.is_a_int(env, hist, v)
1260        })
1261    }
1262
1263    pub fn is_bot(&self) -> bool {
1264        match self {
1265            Type::Bottom => true,
1266            Type::TVar(_)
1267            | Type::Primitive(_)
1268            | Type::Ref { .. }
1269            | Type::Fn(_)
1270            | Type::Array(_)
1271            | Type::ByRef(_)
1272            | Type::Tuple(_)
1273            | Type::Struct(_)
1274            | Type::Variant(_, _)
1275            | Type::Set(_) => false,
1276        }
1277    }
1278
1279    pub fn with_deref<R, F: FnOnce(Option<&Self>) -> R>(&self, f: F) -> R {
1280        match self {
1281            Self::Bottom
1282            | Self::Primitive(_)
1283            | Self::Fn(_)
1284            | Self::Set(_)
1285            | Self::Array(_)
1286            | Self::ByRef(_)
1287            | Self::Tuple(_)
1288            | Self::Struct(_)
1289            | Self::Variant(_, _)
1290            | Self::Ref { .. } => f(Some(self)),
1291            Self::TVar(tv) => f(tv.read().typ.read().as_ref()),
1292        }
1293    }
1294
1295    pub(crate) fn flatten_set(set: impl IntoIterator<Item = Self>) -> Self {
1296        let init: Box<dyn Iterator<Item = Self>> = Box::new(set.into_iter());
1297        let mut iters: SmallVec<[Box<dyn Iterator<Item = Self>>; 16]> = smallvec![init];
1298        let mut acc: SmallVec<[Self; 16]> = smallvec![];
1299        loop {
1300            match iters.last_mut() {
1301                None => break,
1302                Some(iter) => match iter.next() {
1303                    None => {
1304                        iters.pop();
1305                    }
1306                    Some(Type::Set(s)) => {
1307                        let v: SmallVec<[Self; 16]> =
1308                            s.iter().map(|t| t.clone()).collect();
1309                        iters.push(Box::new(v.into_iter()))
1310                    }
1311                    Some(t) => {
1312                        let mut merged = false;
1313                        for i in 0..acc.len() {
1314                            if let Some(t) = t.merge(&acc[i]) {
1315                                acc[i] = t;
1316                                merged = true;
1317                                break;
1318                            }
1319                        }
1320                        if !merged {
1321                            acc.push(t);
1322                        }
1323                    }
1324                },
1325            }
1326        }
1327        acc.sort();
1328        match &*acc {
1329            [] => Type::Primitive(BitFlags::empty()),
1330            [t] => t.clone(),
1331            _ => Type::Set(Arc::from_iter(acc)),
1332        }
1333    }
1334
1335    pub(crate) fn normalize(&self) -> Self {
1336        match self {
1337            Type::Bottom | Type::Primitive(_) => self.clone(),
1338            Type::Ref { scope, name, params } => {
1339                let params = Arc::from_iter(params.iter().map(|t| t.normalize()));
1340                Type::Ref { scope: scope.clone(), name: name.clone(), params }
1341            }
1342            Type::TVar(tv) => Type::TVar(tv.normalize()),
1343            Type::Set(s) => Self::flatten_set(s.iter().map(|t| t.normalize())),
1344            Type::Array(t) => Type::Array(Arc::new(t.normalize())),
1345            Type::ByRef(t) => Type::ByRef(Arc::new(t.normalize())),
1346            Type::Tuple(t) => {
1347                Type::Tuple(Arc::from_iter(t.iter().map(|t| t.normalize())))
1348            }
1349            Type::Struct(t) => Type::Struct(Arc::from_iter(
1350                t.iter().map(|(n, t)| (n.clone(), t.normalize())),
1351            )),
1352            Type::Variant(tag, t) => Type::Variant(
1353                tag.clone(),
1354                Arc::from_iter(t.iter().map(|t| t.normalize())),
1355            ),
1356            Type::Fn(ft) => Type::Fn(Arc::new(ft.normalize())),
1357        }
1358    }
1359
    /// Try to combine `self` and `t` into one type covering both; used by
    /// `flatten_set` to collapse set members. Returns `None` when the two
    /// must stay separate members.
    ///
    /// The order of the match arms is significant, because earlier arms
    /// shadow later ones. For example, refs are handled before `Bottom` and
    /// empty primitives, so a ref never merges with either of those.
    fn merge(&self, t: &Self) -> Option<Self> {
        match (self, t) {
            // refs merge only with an identical ref (same scope, name and
            // params); they are not expanded here
            (
                Type::Ref { scope: s0, name: r0, params: a0 },
                Type::Ref { scope: s1, name: r1, params: a1 },
            ) => {
                if s0 == s1 && r0 == r1 && a0 == a1 {
                    Some(Type::Ref {
                        scope: s0.clone(),
                        name: r0.clone(),
                        params: a0.clone(),
                    })
                } else {
                    None
                }
            }
            (Type::Ref { .. }, _) | (_, Type::Ref { .. }) => None,
            // bottom is absorbed by whatever it is paired with
            (Type::Bottom, t) | (t, Type::Bottom) => Some(t.clone()),
            // so is the empty primitive set
            (Type::Primitive(p), t) | (t, Type::Primitive(p)) if p.is_empty() => {
                Some(t.clone())
            }
            // primitives merge by unioning their flag sets
            (Type::Primitive(s0), Type::Primitive(s1)) => {
                let mut s = *s0;
                s.insert(*s1);
                Some(Type::Primitive(s))
            }
            // function types merge only when equal
            (Type::Fn(f0), Type::Fn(f1)) => {
                if f0 == f1 {
                    Some(Type::Fn(f0.clone()))
                } else {
                    None
                }
            }
            // arrays merge when their element types are equal, after any
            // set element type has been passed through flatten_set; the
            // original (unflattened) element type of `self` is kept
            (Type::Array(t0), Type::Array(t1)) => {
                let t0f = match &**t0 {
                    Type::Set(et) => Self::flatten_set(et.iter().cloned()),
                    t => t.clone(),
                };
                let t1f = match &**t1 {
                    Type::Set(et) => Self::flatten_set(et.iter().cloned()),
                    t => t.clone(),
                };
                if t0f == t1f {
                    Some(Type::Array(t0.clone()))
                } else {
                    None
                }
            }
            (Type::ByRef(t0), Type::ByRef(t1)) => {
                t0.merge(t1).map(|t| Type::ByRef(Arc::new(t)))
            }
            (Type::ByRef(_), _) | (_, Type::ByRef(_)) => None,
            (Type::Array(_), _) | (_, Type::Array(_)) => None,
            // sets absorb each other, or a single other type, via flatten_set
            (Type::Set(s0), Type::Set(s1)) => {
                Some(Self::flatten_set(s0.iter().cloned().chain(s1.iter().cloned())))
            }
            // NOTE(review): appears unreachable — the empty-primitive arm
            // above already matches any pair containing an empty primitive
            (Type::Set(s), Type::Primitive(p)) | (Type::Primitive(p), Type::Set(s))
                if p.is_empty() =>
            {
                Some(Type::Set(s.clone()))
            }
            (Type::Set(s), t) | (t, Type::Set(s)) => {
                Some(Self::flatten_set(s.iter().cloned().chain(iter::once(t.clone()))))
            }
            // tuples, variants and structs merge element-wise, and only when
            // their shapes (length, tag, field names) agree and every
            // element pair merges
            (Type::Tuple(t0), Type::Tuple(t1)) => {
                if t0.len() == t1.len() {
                    let t = t0
                        .iter()
                        .zip(t1.iter())
                        .map(|(t0, t1)| t0.merge(t1))
                        .collect::<Option<SmallVec<[Type; 8]>>>()?;
                    Some(Type::Tuple(Arc::from_iter(t)))
                } else {
                    None
                }
            }
            (Type::Variant(tag0, t0), Type::Variant(tag1, t1)) => {
                if tag0 == tag1 && t0.len() == t1.len() {
                    let t = t0
                        .iter()
                        .zip(t1.iter())
                        .map(|(t0, t1)| t0.merge(t1))
                        .collect::<Option<SmallVec<[Type; 8]>>>()?;
                    Some(Type::Variant(tag0.clone(), Arc::from_iter(t)))
                } else {
                    None
                }
            }
            (Type::Struct(t0), Type::Struct(t1)) => {
                if t0.len() == t1.len() {
                    let t = t0
                        .iter()
                        .zip(t1.iter())
                        .map(|((n0, t0), (n1, t1))| {
                            if n0 != n1 {
                                None
                            } else {
                                t0.merge(t1).map(|t| (n0.clone(), t))
                            }
                        })
                        .collect::<Option<SmallVec<[(ArcStr, Type); 8]>>>()?;
                    Some(Type::Struct(Arc::from_iter(t)))
                } else {
                    None
                }
            }
            // the same type variable merges with itself
            (Type::TVar(tv0), Type::TVar(tv1)) if tv0.name == tv1.name && tv0 == tv1 => {
                Some(Type::TVar(tv0.clone()))
            }
            // otherwise a type variable merges through its bound type; if it
            // is unbound the result is None
            (Type::TVar(tv), t) => {
                tv.read().typ.read().as_ref().and_then(|tv| tv.merge(t))
            }
            (t, Type::TVar(tv)) => {
                tv.read().typ.read().as_ref().and_then(|tv| t.merge(tv))
            }
            // any remaining pairing of different kinds does not merge
            (Type::Tuple(_), _)
            | (_, Type::Tuple(_))
            | (Type::Struct(_), _)
            | (_, Type::Struct(_))
            | (Type::Variant(_, _), _)
            | (_, Type::Variant(_, _))
            | (_, Type::Fn(_))
            | (Type::Fn(_), _) => None,
        }
    }
1485
1486    pub fn scope_refs(&self, scope: &ModPath) -> Type {
1487        match self {
1488            Type::Bottom => Type::Bottom,
1489            Type::Primitive(s) => Type::Primitive(*s),
1490            Type::Array(t0) => Type::Array(Arc::new(t0.scope_refs(scope))),
1491            Type::ByRef(t) => Type::ByRef(Arc::new(t.scope_refs(scope))),
1492            Type::Tuple(ts) => {
1493                let i = ts.iter().map(|t| t.scope_refs(scope));
1494                Type::Tuple(Arc::from_iter(i))
1495            }
1496            Type::Variant(tag, ts) => {
1497                let i = ts.iter().map(|t| t.scope_refs(scope));
1498                Type::Variant(tag.clone(), Arc::from_iter(i))
1499            }
1500            Type::Struct(ts) => {
1501                let i = ts.iter().map(|(n, t)| (n.clone(), t.scope_refs(scope)));
1502                Type::Struct(Arc::from_iter(i))
1503            }
1504            Type::TVar(tv) => match tv.read().typ.read().as_ref() {
1505                None => Type::TVar(TVar::empty_named(tv.name.clone())),
1506                Some(typ) => {
1507                    let typ = typ.scope_refs(scope);
1508                    Type::TVar(TVar::named(tv.name.clone(), typ))
1509                }
1510            },
1511            Type::Ref { scope: _, name, params } => {
1512                let params = Arc::from_iter(params.iter().map(|t| t.scope_refs(scope)));
1513                Type::Ref { scope: scope.clone(), name: name.clone(), params }
1514            }
1515            Type::Set(ts) => {
1516                Type::Set(Arc::from_iter(ts.iter().map(|t| t.scope_refs(scope))))
1517            }
1518            Type::Fn(f) => {
1519                let vargs = f.vargs.as_ref().map(|t| t.scope_refs(scope));
1520                let rtype = f.rtype.scope_refs(scope);
1521                let args = Arc::from_iter(f.args.iter().map(|a| FnArgType {
1522                    label: a.label.clone(),
1523                    typ: a.typ.scope_refs(scope),
1524                }));
1525                let mut cres: SmallVec<[(TVar, Type); 4]> = smallvec![];
1526                for (tv, tc) in f.constraints.read().iter() {
1527                    let tv = tv.scope_refs(scope);
1528                    let tc = tc.scope_refs(scope);
1529                    cres.push((tv, tc));
1530                }
1531                Type::Fn(Arc::new(FnType {
1532                    args,
1533                    rtype,
1534                    constraints: Arc::new(RwLock::new(cres.into_iter().collect())),
1535                    vargs,
1536                }))
1537            }
1538        }
1539    }
1540}
1541
1542impl fmt::Display for Type {
1543    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1544        match self {
1545            Self::Bottom => write!(f, "_"),
1546            Self::Ref { scope: _, name, params } => {
1547                write!(f, "{name}")?;
1548                if !params.is_empty() {
1549                    write!(f, "<")?;
1550                    for (i, t) in params.iter().enumerate() {
1551                        write!(f, "{t}")?;
1552                        if i < params.len() - 1 {
1553                            write!(f, ", ")?;
1554                        }
1555                    }
1556                    write!(f, ">")?;
1557                }
1558                Ok(())
1559            }
1560            Self::TVar(tv) => write!(f, "{tv}"),
1561            Self::Fn(t) => write!(f, "{t}"),
1562            Self::Array(t) => write!(f, "Array<{t}>"),
1563            Self::ByRef(t) => write!(f, "&{t}"),
1564            Self::Tuple(ts) => {
1565                write!(f, "(")?;
1566                for (i, t) in ts.iter().enumerate() {
1567                    write!(f, "{t}")?;
1568                    if i < ts.len() - 1 {
1569                        write!(f, ", ")?;
1570                    }
1571                }
1572                write!(f, ")")
1573            }
1574            Self::Variant(tag, ts) if ts.len() == 0 => {
1575                write!(f, "`{tag}")
1576            }
1577            Self::Variant(tag, ts) => {
1578                write!(f, "`{tag}(")?;
1579                for (i, t) in ts.iter().enumerate() {
1580                    write!(f, "{t}")?;
1581                    if i < ts.len() - 1 {
1582                        write!(f, ", ")?
1583                    }
1584                }
1585                write!(f, ")")
1586            }
1587            Self::Struct(ts) => {
1588                write!(f, "{{")?;
1589                for (i, (n, t)) in ts.iter().enumerate() {
1590                    write!(f, "{n}: {t}")?;
1591                    if i < ts.len() - 1 {
1592                        write!(f, ", ")?
1593                    }
1594                }
1595                write!(f, "}}")
1596            }
1597            Self::Set(s) => {
1598                write!(f, "[")?;
1599                for (i, t) in s.iter().enumerate() {
1600                    write!(f, "{t}")?;
1601                    if i < s.len() - 1 {
1602                        write!(f, ", ")?;
1603                    }
1604                }
1605                write!(f, "]")
1606            }
1607            Self::Primitive(s) => {
1608                let replace = PRINT_FLAGS.get().contains(PrintFlag::ReplacePrims);
1609                if replace && *s == Typ::any() {
1610                    write!(f, "Any")
1611                } else if replace && *s == Typ::number() {
1612                    write!(f, "Number")
1613                } else if replace && *s == Typ::float() {
1614                    write!(f, "Float")
1615                } else if replace && *s == Typ::real() {
1616                    write!(f, "Real")
1617                } else if replace && *s == Typ::integer() {
1618                    write!(f, "Int")
1619                } else if replace && *s == Typ::unsigned_integer() {
1620                    write!(f, "Uint")
1621                } else if replace && *s == Typ::signed_integer() {
1622                    write!(f, "Sint")
1623                } else if s.len() == 0 {
1624                    write!(f, "[]")
1625                } else if s.len() == 1 {
1626                    write!(f, "{}", s.iter().next().unwrap())
1627                } else {
1628                    let mut s = *s;
1629                    macro_rules! builtin {
1630                        ($set:expr, $name:literal) => {
1631                            if replace && s.contains($set) {
1632                                s.remove($set);
1633                                write!(f, $name)?;
1634                                if !s.is_empty() {
1635                                    write!(f, ", ")?
1636                                }
1637                            }
1638                        };
1639                    }
1640                    write!(f, "[")?;
1641                    builtin!(Typ::number(), "Number");
1642                    builtin!(Typ::real(), "Real");
1643                    builtin!(Typ::float(), "Float");
1644                    builtin!(Typ::integer(), "Int");
1645                    builtin!(Typ::unsigned_integer(), "Uint");
1646                    builtin!(Typ::signed_integer(), "Sint");
1647                    for (i, t) in s.iter().enumerate() {
1648                        write!(f, "{t}")?;
1649                        if i < s.len() - 1 {
1650                            write!(f, ", ")?;
1651                        }
1652                    }
1653                    write!(f, "]")
1654                }
1655            }
1656        }
1657    }
1658}