Skip to main content

graphix_compiler/typ/
mod.rs

1use crate::{
2    env::{Env, TypeDef},
3    expr::ModPath,
4    format_with_flags, PrintFlag, PRINT_FLAGS,
5};
6use ahash::{AHashMap, AHashSet};
7use anyhow::{anyhow, bail, Result};
8use arcstr::ArcStr;
9use enumflags2::BitFlags;
10use netidx::{publisher::Typ, utils::Either};
11use nohash::IntMap;
12use poolshark::{local::LPooled, IsoPoolable};
13use smallvec::SmallVec;
14use std::{
15    cmp::{Eq, PartialEq},
16    fmt::Debug,
17    iter,
18    ops::{Deref, DerefMut},
19};
20use triomphe::Arc;
21
22mod cast;
23mod contains;
24mod fntyp;
25mod matches;
26mod normalize;
27mod print;
28mod setops;
29mod tval;
30mod tvar;
31
32pub use fntyp::{FnArgKind, FnArgType, FnType};
33pub use tval::TVal;
34pub use tvar::TVar;
35
/// Boolean "and" accumulator.
///
/// Collecting an iterator of `bool`s into an `AndAc` produces `AndAc(true)`
/// exactly when every item is `true` (vacuously `true` for an empty
/// iterator). Consumption stops at the first `false`.
struct AndAc(bool);

impl FromIterator<bool> for AndAc {
    fn from_iter<I: IntoIterator<Item = bool>>(items: I) -> Self {
        // "no item is false" == "all items are true"; `any` short-circuits
        // on the first `false`, just like `all` would.
        let mut it = items.into_iter();
        Self(!it.any(|b| !b))
    }
}
43
/// History for recursive type walks.
///
/// Pairs a caller-chosen pooled container `H` with a table that hands out
/// stable integer ids for `Type::Ref`s (see `ref_id`). The container is
/// reachable through `Deref`/`DerefMut`. `strip_error` uses an
/// `AHashSet<Option<usize>>` as a set of already-expanded ref ids.
struct RefHist<H: IsoPoolable> {
    /// The caller-chosen history container (e.g. a visited set).
    inner: LPooled<H>,
    /// Ref ids keyed by the address of the resolved `TypeDef`; each entry
    /// lists the type-argument lists seen for that typedef and the id
    /// assigned to each one.
    ref_ids: LPooled<IntMap<usize, SmallVec<[(Arc<[Type]>, usize); 2]>>>,
    /// The id `ref_id` will assign to the next new (typedef, params) pair.
    next_id: usize,
}
49
50impl<H: IsoPoolable> Deref for RefHist<H> {
51    type Target = H;
52    fn deref(&self) -> &H {
53        &*self.inner
54    }
55}
56
57impl<H: IsoPoolable> DerefMut for RefHist<H> {
58    fn deref_mut(&mut self) -> &mut H {
59        &mut *self.inner
60    }
61}
62
63impl<H: IsoPoolable> RefHist<H> {
64    fn new(inner: LPooled<H>) -> Self {
65        RefHist { inner, ref_ids: LPooled::take(), next_id: 0 }
66    }
67
68    /// Return a stable ID for a Ref type based on (typedef identity, params).
69    /// Returns None for non-Ref types — cycle detection is driven by the
70    /// Ref side, and None collapses all non-Ref types to the same key.
71    fn ref_id(&mut self, t: &Type, env: &Env) -> Option<usize> {
72        match t {
73            Type::Ref(TypeRef { scope, name, params, .. }) => {
74                match env.lookup_typedef(scope, name) {
75                    Some(def) => {
76                        let def_addr = (def as *const TypeDef).addr();
77                        let entries = self.ref_ids.entry(def_addr).or_default();
78                        for &(ref p, id) in entries.iter() {
79                            if **p == **params {
80                                return Some(id);
81                            }
82                        }
83                        let id = self.next_id;
84                        self.next_id += 1;
85                        entries.push((params.clone(), id));
86                        Some(id)
87                    }
88                    None => None,
89                }
90            }
91            _ => None,
92        }
93    }
94}
95
// Declare `AbstractId` via the crate's `atomic_id!` macro; it is the
// identifier carried by `Type::Abstract`.
atomic_id!(AbstractId);
97
/// A reference to a named typedef, e.g. `Foo` or `Result<i64, string>`.
/// `pos` and `ori` are IDE metadata recording where this reference
/// was written in source — they're populated by the parser and
/// ignored for type-system equality, ordering and hashing so they
/// don't affect type identity.
#[derive(Debug, Clone)]
pub struct TypeRef {
    /// Module scope from which `name` is resolved; `Type::lookup_ref`
    /// passes it to `Env::find_visible` as the lookup starting point.
    pub scope: ModPath,
    /// Path of the referenced typedef, resolved relative to `scope`.
    pub name: ModPath,
    /// Type arguments. `Type::lookup_ref` requires exactly as many as the
    /// typedef declares.
    pub params: Arc<[Type]>,
    /// Parser-recorded source position of this reference; `None` for
    /// synthetic references.
    pub pos: Option<crate::SourcePosition>,
    /// Parser-recorded origin of this reference; `None` for synthetic
    /// references.
    pub ori: Option<Arc<crate::expr::Origin>>,
}
111
112impl TypeRef {
113    /// Build a `TypeRef` with no source-position info — for synthetic
114    /// type references created during type inference, set operations,
115    /// stdlib type literals, etc.
116    pub fn synthetic(scope: ModPath, name: ModPath, params: Arc<[Type]>) -> Self {
117        Self { scope, name, params, pos: None, ori: None }
118    }
119}
120
121impl Default for TypeRef {
122    fn default() -> Self {
123        Self {
124            scope: ModPath::root(),
125            name: ModPath::root(),
126            params: Arc::from(Vec::<Type>::new()),
127            pos: None,
128            ori: None,
129        }
130    }
131}
132
133impl PartialEq for TypeRef {
134    fn eq(&self, other: &Self) -> bool {
135        self.scope == other.scope
136            && self.name == other.name
137            && self.params == other.params
138    }
139}
140
141impl Eq for TypeRef {}
142
143impl PartialOrd for TypeRef {
144    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
145        Some(self.cmp(other))
146    }
147}
148
149impl Ord for TypeRef {
150    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
151        self.scope
152            .cmp(&other.scope)
153            .then_with(|| self.name.cmp(&other.name))
154            .then_with(|| self.params.cmp(&other.params))
155    }
156}
157
/// The compiler's representation of a type.
///
/// The derived `PartialOrd`/`Ord` order values first by variant
/// declaration order, so reordering the variants changes how types sort.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub enum Type {
    /// The bottom type; the only variant for which `is_bot` is true, and
    /// the `Default` value.
    Bottom,
    /// The `Any` type (see `Type::any`).
    Any,
    /// A set of primitive types expressed as netidx `Typ` flags; a
    /// multi-flag value can be split per flag with `iter_prims`.
    Primitive(BitFlags<Typ>),
    /// A reference to a named typedef; expanded by `lookup_ref`.
    Ref(TypeRef),
    /// A function type.
    Fn(Arc<FnType>),
    /// A set of alternative types (presumably a union — see
    /// `flatten_set`; confirm).
    Set(Arc<[Type]>),
    /// A type variable, bound or unbound (see `is_defined`).
    TVar(TVar),
    /// An error carrying the given payload type (see `strip_error`).
    Error(Arc<Type>),
    /// An array whose elements have the given type.
    Array(Arc<Type>),
    /// A by-reference wrapper around the given type.
    ByRef(Arc<Type>),
    /// A tuple of positional element types.
    Tuple(Arc<[Type]>),
    /// A struct: named fields with their types.
    Struct(Arc<[(ArcStr, Type)]>),
    /// A variant: a tag and its payload types.
    Variant(ArcStr, Arc<[Type]>),
    /// A map from `key` types to `value` types.
    Map { key: Arc<Type>, value: Arc<Type> },
    /// An abstract type identified by `id`, with type arguments `params`.
    Abstract { id: AbstractId, params: Arc<[Type]> },
}
176
177impl Default for Type {
178    fn default() -> Self {
179        Self::Bottom
180    }
181}
182
183impl Type {
184    pub fn empty_tvar() -> Self {
185        Type::TVar(TVar::default())
186    }
187
188    fn iter_prims(&self) -> impl Iterator<Item = Self> {
189        match self {
190            Self::Primitive(p) => {
191                Either::Left(p.iter().map(|t| Type::Primitive(t.into())))
192            }
193            t => Either::Right(iter::once(t.clone())),
194        }
195    }
196
197    pub fn is_defined(&self) -> bool {
198        match self {
199            Self::Bottom
200            | Self::Any
201            | Self::Primitive(_)
202            | Self::Fn(_)
203            | Self::Set(_)
204            | Self::Error(_)
205            | Self::Array(_)
206            | Self::ByRef(_)
207            | Self::Tuple(_)
208            | Self::Struct(_)
209            | Self::Variant(_, _)
210            | Self::Ref(TypeRef { .. })
211            | Self::Map { .. }
212            | Self::Abstract { .. } => true,
213            Self::TVar(tv) => tv.read().typ.read().is_some(),
214        }
215    }
216
217    pub fn lookup_ref(&self, env: &Env) -> Result<Type> {
218        match self {
219            Self::Ref(TypeRef { scope, name, params, pos, ori }) => {
220                let resolved = env
221                    .find_visible(scope, name, |s, n| {
222                        env.typedefs.get(s).and_then(|m| m.get(n)).map(|d| {
223                            let canonical = ModPath(netidx::path::Path::from(
224                                arcstr::ArcStr::from(s),
225                            ));
226                            (
227                                canonical,
228                                d.pos,
229                                d.ori.clone(),
230                                d.params.clone(),
231                                d.typ.clone(),
232                            )
233                        })
234                    })
235                    .ok_or_else(|| anyhow!("undefined type {name} in {scope}"))?;
236                let (canonical_scope, def_pos, def_ori, def_params, def_typ) = resolved;
237                if def_params.len() != params.len() {
238                    bail!("{} expects {} type parameters", name, def_params.len());
239                }
240                if env.lsp_mode {
241                    if let (Some(pos), Some(ori)) = (pos, ori) {
242                        env.push_type_ref(crate::TypeRefSite {
243                            pos: *pos,
244                            ori: ori.clone(),
245                            name: name.clone(),
246                            canonical_scope,
247                            def_pos,
248                            def_ori,
249                        });
250                    }
251                }
252                let mut known: LPooled<AHashMap<ArcStr, Type>> = LPooled::take();
253                for ((tv, ct), arg) in def_params.iter().zip(params.iter()) {
254                    if let Some(ct) = ct {
255                        ct.check_contains(env, arg)?;
256                    }
257                    known.insert(tv.name.clone(), arg.clone());
258                }
259                Ok(def_typ.replace_tvars(&known))
260            }
261            t => Ok(t.clone()),
262        }
263    }
264
265    /// Walk this type tree and, for every `Type::Ref` carrying
266    /// parser-populated `pos`/`ori`, push a `TypeRefSite` to the
267    /// IDE side-channel. Used at typedef-registration time so
268    /// references inside typedef bodies (which the type system
269    /// never auto-derefs) still show up in find-references results.
270    /// Caller is responsible for gating on `env.lsp_mode`; this
271    /// method recurses unconditionally once entered.
272    pub fn record_ide_refs(&self, env: &Env, fallback_scope: &ModPath) {
273        match self {
274            Type::Ref(tr) => {
275                if let (Some(pos), Some(ori)) = (tr.pos, &tr.ori) {
276                    let resolved = env.find_visible(&tr.scope, &tr.name, |s, n| {
277                        env.typedefs.get(s).and_then(|m| m.get(n)).map(|d| {
278                            let canonical = ModPath(netidx::path::Path::from(
279                                arcstr::ArcStr::from(s),
280                            ));
281                            (canonical, d.pos, d.ori.clone())
282                        })
283                    });
284                    let (canonical_scope, def_pos, def_ori) = match resolved {
285                        Some((s, dp, do_)) => (s, dp, do_),
286                        None => (
287                            fallback_scope.clone(),
288                            crate::SourcePosition::default(),
289                            ori.clone(),
290                        ),
291                    };
292                    env.push_type_ref(crate::TypeRefSite {
293                        pos,
294                        ori: ori.clone(),
295                        name: tr.name.clone(),
296                        canonical_scope,
297                        def_pos,
298                        def_ori,
299                    });
300                }
301                for p in tr.params.iter() {
302                    p.record_ide_refs(env, fallback_scope);
303                }
304            }
305            Type::Set(ts) | Type::Tuple(ts) | Type::Variant(_, ts) => {
306                for t in ts.iter() {
307                    t.record_ide_refs(env, fallback_scope);
308                }
309            }
310            Type::Array(t) | Type::Error(t) | Type::ByRef(t) => {
311                t.record_ide_refs(env, fallback_scope)
312            }
313            Type::Map { key, value } => {
314                key.record_ide_refs(env, fallback_scope);
315                value.record_ide_refs(env, fallback_scope);
316            }
317            Type::Struct(fields) => {
318                for (_, t) in fields.iter() {
319                    t.record_ide_refs(env, fallback_scope);
320                }
321            }
322            Type::Fn(ft) => {
323                for arg in ft.args.iter() {
324                    arg.typ.record_ide_refs(env, fallback_scope);
325                }
326                ft.rtype.record_ide_refs(env, fallback_scope);
327                ft.throws.record_ide_refs(env, fallback_scope);
328            }
329            Type::Abstract { params, .. } => {
330                for p in params.iter() {
331                    p.record_ide_refs(env, fallback_scope);
332                }
333            }
334            Type::TVar(tv) => {
335                if let Some(t) = tv.read().typ.read().as_ref() {
336                    t.record_ide_refs(env, fallback_scope);
337                }
338            }
339            Type::Bottom | Type::Any | Type::Primitive(_) => (),
340        }
341    }
342
343    pub fn any() -> Self {
344        Self::Any
345    }
346
347    pub fn boolean() -> Self {
348        Self::Primitive(Typ::Bool.into())
349    }
350
351    pub fn number() -> Self {
352        Self::Primitive(Typ::number())
353    }
354
355    pub fn int() -> Self {
356        Self::Primitive(Typ::integer())
357    }
358
359    pub fn uint() -> Self {
360        Self::Primitive(Typ::unsigned_integer())
361    }
362
363    fn strip_error_int(
364        &self,
365        env: &Env,
366        hist: &mut RefHist<AHashSet<Option<usize>>>,
367    ) -> Option<Type> {
368        match self {
369            Type::Error(t) => match t.strip_error_int(env, hist) {
370                Some(t) => Some(t),
371                None => Some((**t).clone()),
372            },
373            Type::TVar(tv) => {
374                tv.read().typ.read().as_ref().and_then(|t| t.strip_error_int(env, hist))
375            }
376            Type::Primitive(p) => {
377                if *p == BitFlags::from(Typ::Error) {
378                    Some(Type::Any)
379                } else {
380                    None
381                }
382            }
383            Type::Ref(TypeRef { .. }) => {
384                let id = hist.ref_id(self, env);
385                let t = self.lookup_ref(env).ok()?;
386                if hist.insert(id) {
387                    t.strip_error_int(env, hist)
388                } else {
389                    None
390                }
391            }
392            Type::Set(s) => {
393                let r = Self::flatten_set(
394                    s.iter().filter_map(|t| t.strip_error_int(env, hist)),
395                );
396                match r {
397                    Type::Primitive(p) if p.is_empty() => None,
398                    t => Some(t),
399                }
400            }
401            Type::Array(_)
402            | Type::Map { .. }
403            | Type::ByRef(_)
404            | Type::Tuple(_)
405            | Type::Struct(_)
406            | Type::Variant(_, _)
407            | Type::Fn(_)
408            | Type::Any
409            | Type::Bottom
410            | Type::Abstract { .. } => None,
411        }
412    }
413
414    /// remove the outer error type and return the inner payload, fail if self
415    /// isn't an error or contains non error types
416    pub fn strip_error(&self, env: &Env) -> Option<Self> {
417        self.strip_error_int(
418            env,
419            &mut RefHist::<AHashSet<Option<usize>>>::new(LPooled::take()),
420        )
421    }
422
423    pub fn is_bot(&self) -> bool {
424        match self {
425            Type::Bottom => true,
426            Type::Any
427            | Type::Abstract { .. }
428            | Type::TVar(_)
429            | Type::Primitive(_)
430            | Type::Ref(TypeRef { .. })
431            | Type::Fn(_)
432            | Type::Error(_)
433            | Type::Array(_)
434            | Type::ByRef(_)
435            | Type::Tuple(_)
436            | Type::Struct(_)
437            | Type::Variant(_, _)
438            | Type::Set(_)
439            | Type::Map { .. } => false,
440        }
441    }
442
443    pub fn with_deref<R, F: FnOnce(Option<&Self>) -> R>(&self, f: F) -> R {
444        match self {
445            Self::Bottom
446            | Self::Abstract { .. }
447            | Self::Any
448            | Self::Primitive(_)
449            | Self::Fn(_)
450            | Self::Set(_)
451            | Self::Error(_)
452            | Self::Array(_)
453            | Self::ByRef(_)
454            | Self::Tuple(_)
455            | Self::Struct(_)
456            | Self::Variant(_, _)
457            | Self::Ref(TypeRef { .. })
458            | Self::Map { .. } => f(Some(self)),
459            Self::TVar(tv) => match tv.read().typ.read().as_ref() {
460                Some(t) => t.with_deref(f),
461                None => f(None),
462            },
463        }
464    }
465
466    pub fn scope_refs(&self, scope: &ModPath) -> Type {
467        match self {
468            Type::Bottom => Type::Bottom,
469            Type::Any => Type::Any,
470            Type::Primitive(s) => Type::Primitive(*s),
471            Type::Abstract { id, params } => Type::Abstract {
472                id: *id,
473                params: Arc::from_iter(params.iter().map(|t| t.scope_refs(scope))),
474            },
475            Type::Error(t0) => Type::Error(Arc::new(t0.scope_refs(scope))),
476            Type::Array(t0) => Type::Array(Arc::new(t0.scope_refs(scope))),
477            Type::Map { key, value } => {
478                let key = Arc::new(key.scope_refs(scope));
479                let value = Arc::new(value.scope_refs(scope));
480                Type::Map { key, value }
481            }
482            Type::ByRef(t) => Type::ByRef(Arc::new(t.scope_refs(scope))),
483            Type::Tuple(ts) => {
484                let i = ts.iter().map(|t| t.scope_refs(scope));
485                Type::Tuple(Arc::from_iter(i))
486            }
487            Type::Variant(tag, ts) => {
488                let i = ts.iter().map(|t| t.scope_refs(scope));
489                Type::Variant(tag.clone(), Arc::from_iter(i))
490            }
491            Type::Struct(ts) => {
492                let i = ts.iter().map(|(n, t)| (n.clone(), t.scope_refs(scope)));
493                Type::Struct(Arc::from_iter(i))
494            }
495            Type::TVar(tv) => match tv.read().typ.read().as_ref() {
496                None => Type::TVar(TVar::empty_named(tv.name.clone())),
497                Some(typ) => {
498                    let typ = typ.scope_refs(scope);
499                    Type::TVar(TVar::named(tv.name.clone(), typ))
500                }
501            },
502            Type::Ref(tr) => {
503                let params =
504                    Arc::from_iter(tr.params.iter().map(|t| t.scope_refs(scope)));
505                Type::Ref(TypeRef { scope: scope.clone(), params, ..tr.clone() })
506            }
507            Type::Set(ts) => {
508                Type::Set(Arc::from_iter(ts.iter().map(|t| t.scope_refs(scope))))
509            }
510            Type::Fn(f) => Type::Fn(Arc::new(f.scope_refs(scope))),
511        }
512    }
513}