Skip to main content

graphix_compiler/typ/
mod.rs

1use crate::{
2    env::{Env, TypeDef},
3    expr::ModPath,
4    format_with_flags, PrintFlag, PRINT_FLAGS,
5};
6use anyhow::{anyhow, bail, Result};
7use arcstr::ArcStr;
8use enumflags2::BitFlags;
9use fxhash::{FxHashMap, FxHashSet};
10use netidx::{publisher::Typ, utils::Either};
11use poolshark::{local::LPooled, IsoPoolable};
12use smallvec::SmallVec;
13use std::{
14    cmp::{Eq, PartialEq},
15    fmt::Debug,
16    iter,
17    ops::{Deref, DerefMut},
18};
19use triomphe::Arc;
20
21mod cast;
22mod contains;
23mod fntyp;
24mod matches;
25mod normalize;
26mod print;
27mod setops;
28mod tval;
29mod tvar;
30
31pub use fntyp::{FnArgKind, FnArgType, FnType};
32pub use tval::TVal;
33pub use tvar::TVar;
34
/// Accumulator that folds an iterator of `bool`s into their logical AND.
///
/// Collecting an empty iterator yields `AndAc(true)`.
struct AndAc(bool);

impl FromIterator<bool> for AndAc {
    /// Consume `iter` until the first `false` is seen (or the iterator is
    /// exhausted) and record whether every yielded value was `true`.
    fn from_iter<T: IntoIterator<Item = bool>>(iter: T) -> Self {
        for flag in iter {
            if !flag {
                // Stop pulling items as soon as the conjunction is decided.
                return AndAc(false);
            }
        }
        AndAc(true)
    }
}
42
/// A pooled history collection `H` bundled with a table that hands out
/// small integer IDs for `Type::Ref`s, so that a traversal can recognise a
/// (typedef, params) pair it has already visited.
struct RefHist<H: IsoPoolable> {
    /// The wrapped history collection; reachable through `Deref`/`DerefMut`.
    inner: LPooled<H>,
    /// Maps a typedef's address to the (type parameters, assigned ID) pairs
    /// that have been seen for that typedef so far.
    ref_ids: LPooled<FxHashMap<usize, SmallVec<[(Arc<[Type]>, usize); 2]>>>,
    /// The ID that will be assigned to the next previously unseen pair.
    next_id: usize,
}
48
49impl<H: IsoPoolable> Deref for RefHist<H> {
50    type Target = H;
51    fn deref(&self) -> &H {
52        &*self.inner
53    }
54}
55
56impl<H: IsoPoolable> DerefMut for RefHist<H> {
57    fn deref_mut(&mut self) -> &mut H {
58        &mut *self.inner
59    }
60}
61
62impl<H: IsoPoolable> RefHist<H> {
63    fn new(inner: LPooled<H>) -> Self {
64        RefHist { inner, ref_ids: LPooled::take(), next_id: 0 }
65    }
66
67    /// Return a stable ID for a Ref type based on (typedef identity, params).
68    /// Returns None for non-Ref types — cycle detection is driven by the
69    /// Ref side, and None collapses all non-Ref types to the same key.
70    fn ref_id(&mut self, t: &Type, env: &Env) -> Option<usize> {
71        match t {
72            Type::Ref(TypeRef { scope, name, params, .. }) => {
73                match env.lookup_typedef(scope, name) {
74                    Some(def) => {
75                        let def_addr = (def as *const TypeDef).addr();
76                        let entries = self.ref_ids.entry(def_addr).or_default();
77                        for &(ref p, id) in entries.iter() {
78                            if **p == **params {
79                                return Some(id);
80                            }
81                        }
82                        let id = self.next_id;
83                        self.next_id += 1;
84                        entries.push((params.clone(), id));
85                        Some(id)
86                    }
87                    None => None,
88                }
89            }
90            _ => None,
91        }
92    }
93}
94
// Declares `AbstractId`, the identifier type carried by `Type::Abstract`.
// NOTE(review): `atomic_id!` is defined outside this file; presumably it
// generates a unique-ID newtype — confirm against the macro definition.
atomic_id!(AbstractId);
96
/// A reference to a named typedef, e.g. `Foo` or `Result<i64, string>`.
/// `pos` and `ori` are IDE metadata recording where this reference
/// was written in source — they're populated by the parser and
/// ignored for type-system equality, ordering and hashing so they
/// don't affect type identity.
#[derive(Debug, Clone)]
pub struct TypeRef {
    /// Module scope from which `name` is looked up.
    pub scope: ModPath,
    /// Name of the referenced typedef.
    pub name: ModPath,
    /// Type arguments supplied to the typedef; empty when it takes none.
    pub params: Arc<[Type]>,
    /// Source position of this reference, when it came from the parser.
    pub pos: Option<crate::SourcePosition>,
    /// Origin of the source containing this reference, when it came from
    /// the parser.
    pub ori: Option<Arc<crate::expr::Origin>>,
}
110
111impl TypeRef {
112    /// Build a `TypeRef` with no source-position info — for synthetic
113    /// type references created during type inference, set operations,
114    /// stdlib type literals, etc.
115    pub fn synthetic(scope: ModPath, name: ModPath, params: Arc<[Type]>) -> Self {
116        Self { scope, name, params, pos: None, ori: None }
117    }
118}
119
120impl Default for TypeRef {
121    fn default() -> Self {
122        Self {
123            scope: ModPath::root(),
124            name: ModPath::root(),
125            params: Arc::from(Vec::<Type>::new()),
126            pos: None,
127            ori: None,
128        }
129    }
130}
131
132impl PartialEq for TypeRef {
133    fn eq(&self, other: &Self) -> bool {
134        self.scope == other.scope
135            && self.name == other.name
136            && self.params == other.params
137    }
138}
139
140impl Eq for TypeRef {}
141
142impl PartialOrd for TypeRef {
143    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
144        Some(self.cmp(other))
145    }
146}
147
148impl Ord for TypeRef {
149    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
150        self.scope
151            .cmp(&other.scope)
152            .then_with(|| self.name.cmp(&other.name))
153            .then_with(|| self.params.cmp(&other.params))
154    }
155}
156
/// A type in the graphix type system.
///
/// `PartialOrd`/`Ord` are derived, so the declaration order of the variants
/// below is part of the ordering: reordering variants changes how types
/// compare.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub enum Type {
    /// The `Default` value, and the only variant for which `is_bot` is true.
    Bottom,
    /// NOTE(review): presumably the top type that admits every value — confirm.
    Any,
    /// A set of primitive `Typ` flags.
    Primitive(BitFlags<Typ>),
    /// A reference to a named typedef; resolved against an `Env` by
    /// `lookup_ref`.
    Ref(TypeRef),
    /// A function type.
    Fn(Arc<FnType>),
    /// A collection of member types.
    /// NOTE(review): presumably a union of its members — confirm.
    Set(Arc<[Type]>),
    /// A type variable, which may or may not currently be bound to a type.
    TVar(TVar),
    /// An error type wrapping a payload type.
    Error(Arc<Type>),
    /// An array type; the payload is presumably the element type.
    Array(Arc<Type>),
    /// A by-reference wrapper around another type.
    ByRef(Arc<Type>),
    /// A tuple of positional component types.
    Tuple(Arc<[Type]>),
    /// A struct given as (field name, field type) pairs.
    Struct(Arc<[(ArcStr, Type)]>),
    /// A tagged variant: the tag followed by its argument types.
    Variant(ArcStr, Arc<[Type]>),
    /// A map type with the given key and value types.
    Map { key: Arc<Type>, value: Arc<Type> },
    /// An abstract type identified by `id` and instantiated with `params`.
    Abstract { id: AbstractId, params: Arc<[Type]> },
}
175
176impl Default for Type {
177    fn default() -> Self {
178        Self::Bottom
179    }
180}
181
182impl Type {
183    pub fn empty_tvar() -> Self {
184        Type::TVar(TVar::default())
185    }
186
187    fn iter_prims(&self) -> impl Iterator<Item = Self> {
188        match self {
189            Self::Primitive(p) => {
190                Either::Left(p.iter().map(|t| Type::Primitive(t.into())))
191            }
192            t => Either::Right(iter::once(t.clone())),
193        }
194    }
195
196    pub fn is_defined(&self) -> bool {
197        match self {
198            Self::Bottom
199            | Self::Any
200            | Self::Primitive(_)
201            | Self::Fn(_)
202            | Self::Set(_)
203            | Self::Error(_)
204            | Self::Array(_)
205            | Self::ByRef(_)
206            | Self::Tuple(_)
207            | Self::Struct(_)
208            | Self::Variant(_, _)
209            | Self::Ref(TypeRef { .. })
210            | Self::Map { .. }
211            | Self::Abstract { .. } => true,
212            Self::TVar(tv) => tv.read().typ.read().is_some(),
213        }
214    }
215
216    pub fn lookup_ref(&self, env: &Env) -> Result<Type> {
217        match self {
218            Self::Ref(TypeRef { scope, name, params, pos, ori }) => {
219                let resolved = env
220                    .find_visible(scope, name, |s, n| {
221                        env.typedefs.get(s).and_then(|m| m.get(n)).map(|d| {
222                            let canonical = ModPath(netidx::path::Path::from(
223                                arcstr::ArcStr::from(s),
224                            ));
225                            (
226                                canonical,
227                                d.pos,
228                                d.ori.clone(),
229                                d.params.clone(),
230                                d.typ.clone(),
231                            )
232                        })
233                    })
234                    .ok_or_else(|| anyhow!("undefined type {name} in {scope}"))?;
235                let (canonical_scope, def_pos, def_ori, def_params, def_typ) = resolved;
236                if def_params.len() != params.len() {
237                    bail!("{} expects {} type parameters", name, def_params.len());
238                }
239                if env.lsp_mode {
240                    if let (Some(pos), Some(ori)) = (pos, ori) {
241                        env.push_type_ref(crate::TypeRefSite {
242                            pos: *pos,
243                            ori: ori.clone(),
244                            name: name.clone(),
245                            canonical_scope,
246                            def_pos,
247                            def_ori,
248                        });
249                    }
250                }
251                let mut known: LPooled<FxHashMap<ArcStr, Type>> = LPooled::take();
252                for ((tv, ct), arg) in def_params.iter().zip(params.iter()) {
253                    if let Some(ct) = ct {
254                        ct.check_contains(env, arg)?;
255                    }
256                    known.insert(tv.name.clone(), arg.clone());
257                }
258                Ok(def_typ.replace_tvars(&known))
259            }
260            t => Ok(t.clone()),
261        }
262    }
263
264    /// Walk this type tree and, for every `Type::Ref` carrying
265    /// parser-populated `pos`/`ori`, push a `TypeRefSite` to the
266    /// IDE side-channel. Used at typedef-registration time so
267    /// references inside typedef bodies (which the type system
268    /// never auto-derefs) still show up in find-references results.
269    /// Caller is responsible for gating on `env.lsp_mode`; this
270    /// method recurses unconditionally once entered.
271    pub fn record_ide_refs(&self, env: &Env, fallback_scope: &ModPath) {
272        match self {
273            Type::Ref(tr) => {
274                if let (Some(pos), Some(ori)) = (tr.pos, &tr.ori) {
275                    let resolved = env.find_visible(&tr.scope, &tr.name, |s, n| {
276                        env.typedefs.get(s).and_then(|m| m.get(n)).map(|d| {
277                            let canonical = ModPath(netidx::path::Path::from(
278                                arcstr::ArcStr::from(s),
279                            ));
280                            (canonical, d.pos, d.ori.clone())
281                        })
282                    });
283                    let (canonical_scope, def_pos, def_ori) = match resolved {
284                        Some((s, dp, do_)) => (s, dp, do_),
285                        None => (
286                            fallback_scope.clone(),
287                            crate::SourcePosition::default(),
288                            ori.clone(),
289                        ),
290                    };
291                    env.push_type_ref(crate::TypeRefSite {
292                        pos,
293                        ori: ori.clone(),
294                        name: tr.name.clone(),
295                        canonical_scope,
296                        def_pos,
297                        def_ori,
298                    });
299                }
300                for p in tr.params.iter() {
301                    p.record_ide_refs(env, fallback_scope);
302                }
303            }
304            Type::Set(ts) | Type::Tuple(ts) | Type::Variant(_, ts) => {
305                for t in ts.iter() {
306                    t.record_ide_refs(env, fallback_scope);
307                }
308            }
309            Type::Array(t) | Type::Error(t) | Type::ByRef(t) => {
310                t.record_ide_refs(env, fallback_scope)
311            }
312            Type::Map { key, value } => {
313                key.record_ide_refs(env, fallback_scope);
314                value.record_ide_refs(env, fallback_scope);
315            }
316            Type::Struct(fields) => {
317                for (_, t) in fields.iter() {
318                    t.record_ide_refs(env, fallback_scope);
319                }
320            }
321            Type::Fn(ft) => {
322                for arg in ft.args.iter() {
323                    arg.typ.record_ide_refs(env, fallback_scope);
324                }
325                ft.rtype.record_ide_refs(env, fallback_scope);
326                ft.throws.record_ide_refs(env, fallback_scope);
327            }
328            Type::Abstract { params, .. } => {
329                for p in params.iter() {
330                    p.record_ide_refs(env, fallback_scope);
331                }
332            }
333            Type::TVar(tv) => {
334                if let Some(t) = tv.read().typ.read().as_ref() {
335                    t.record_ide_refs(env, fallback_scope);
336                }
337            }
338            Type::Bottom | Type::Any | Type::Primitive(_) => (),
339        }
340    }
341
342    pub fn any() -> Self {
343        Self::Any
344    }
345
346    pub fn boolean() -> Self {
347        Self::Primitive(Typ::Bool.into())
348    }
349
350    pub fn number() -> Self {
351        Self::Primitive(Typ::number())
352    }
353
354    pub fn int() -> Self {
355        Self::Primitive(Typ::integer())
356    }
357
358    pub fn uint() -> Self {
359        Self::Primitive(Typ::unsigned_integer())
360    }
361
362    fn strip_error_int(
363        &self,
364        env: &Env,
365        hist: &mut RefHist<FxHashSet<Option<usize>>>,
366    ) -> Option<Type> {
367        match self {
368            Type::Error(t) => match t.strip_error_int(env, hist) {
369                Some(t) => Some(t),
370                None => Some((**t).clone()),
371            },
372            Type::TVar(tv) => {
373                tv.read().typ.read().as_ref().and_then(|t| t.strip_error_int(env, hist))
374            }
375            Type::Primitive(p) => {
376                if *p == BitFlags::from(Typ::Error) {
377                    Some(Type::Any)
378                } else {
379                    None
380                }
381            }
382            Type::Ref(TypeRef { .. }) => {
383                let id = hist.ref_id(self, env);
384                let t = self.lookup_ref(env).ok()?;
385                if hist.insert(id) {
386                    t.strip_error_int(env, hist)
387                } else {
388                    None
389                }
390            }
391            Type::Set(s) => {
392                let r = Self::flatten_set(
393                    s.iter().filter_map(|t| t.strip_error_int(env, hist)),
394                );
395                match r {
396                    Type::Primitive(p) if p.is_empty() => None,
397                    t => Some(t),
398                }
399            }
400            Type::Array(_)
401            | Type::Map { .. }
402            | Type::ByRef(_)
403            | Type::Tuple(_)
404            | Type::Struct(_)
405            | Type::Variant(_, _)
406            | Type::Fn(_)
407            | Type::Any
408            | Type::Bottom
409            | Type::Abstract { .. } => None,
410        }
411    }
412
413    /// remove the outer error type and return the inner payload, fail if self
414    /// isn't an error or contains non error types
415    pub fn strip_error(&self, env: &Env) -> Option<Self> {
416        self.strip_error_int(
417            env,
418            &mut RefHist::<FxHashSet<Option<usize>>>::new(LPooled::take()),
419        )
420    }
421
422    pub fn is_bot(&self) -> bool {
423        match self {
424            Type::Bottom => true,
425            Type::Any
426            | Type::Abstract { .. }
427            | Type::TVar(_)
428            | Type::Primitive(_)
429            | Type::Ref(TypeRef { .. })
430            | Type::Fn(_)
431            | Type::Error(_)
432            | Type::Array(_)
433            | Type::ByRef(_)
434            | Type::Tuple(_)
435            | Type::Struct(_)
436            | Type::Variant(_, _)
437            | Type::Set(_)
438            | Type::Map { .. } => false,
439        }
440    }
441
442    pub fn with_deref<R, F: FnOnce(Option<&Self>) -> R>(&self, f: F) -> R {
443        match self {
444            Self::Bottom
445            | Self::Abstract { .. }
446            | Self::Any
447            | Self::Primitive(_)
448            | Self::Fn(_)
449            | Self::Set(_)
450            | Self::Error(_)
451            | Self::Array(_)
452            | Self::ByRef(_)
453            | Self::Tuple(_)
454            | Self::Struct(_)
455            | Self::Variant(_, _)
456            | Self::Ref(TypeRef { .. })
457            | Self::Map { .. } => f(Some(self)),
458            Self::TVar(tv) => match tv.read().typ.read().as_ref() {
459                Some(t) => t.with_deref(f),
460                None => f(None),
461            },
462        }
463    }
464
465    pub fn scope_refs(&self, scope: &ModPath) -> Type {
466        match self {
467            Type::Bottom => Type::Bottom,
468            Type::Any => Type::Any,
469            Type::Primitive(s) => Type::Primitive(*s),
470            Type::Abstract { id, params } => Type::Abstract {
471                id: *id,
472                params: Arc::from_iter(params.iter().map(|t| t.scope_refs(scope))),
473            },
474            Type::Error(t0) => Type::Error(Arc::new(t0.scope_refs(scope))),
475            Type::Array(t0) => Type::Array(Arc::new(t0.scope_refs(scope))),
476            Type::Map { key, value } => {
477                let key = Arc::new(key.scope_refs(scope));
478                let value = Arc::new(value.scope_refs(scope));
479                Type::Map { key, value }
480            }
481            Type::ByRef(t) => Type::ByRef(Arc::new(t.scope_refs(scope))),
482            Type::Tuple(ts) => {
483                let i = ts.iter().map(|t| t.scope_refs(scope));
484                Type::Tuple(Arc::from_iter(i))
485            }
486            Type::Variant(tag, ts) => {
487                let i = ts.iter().map(|t| t.scope_refs(scope));
488                Type::Variant(tag.clone(), Arc::from_iter(i))
489            }
490            Type::Struct(ts) => {
491                let i = ts.iter().map(|(n, t)| (n.clone(), t.scope_refs(scope)));
492                Type::Struct(Arc::from_iter(i))
493            }
494            Type::TVar(tv) => match tv.read().typ.read().as_ref() {
495                None => Type::TVar(TVar::empty_named(tv.name.clone())),
496                Some(typ) => {
497                    let typ = typ.scope_refs(scope);
498                    Type::TVar(TVar::named(tv.name.clone(), typ))
499                }
500            },
501            Type::Ref(tr) => {
502                let params =
503                    Arc::from_iter(tr.params.iter().map(|t| t.scope_refs(scope)));
504                Type::Ref(TypeRef { scope: scope.clone(), params, ..tr.clone() })
505            }
506            Type::Set(ts) => {
507                Type::Set(Arc::from_iter(ts.iter().map(|t| t.scope_refs(scope))))
508            }
509            Type::Fn(f) => Type::Fn(Arc::new(f.scope_refs(scope))),
510        }
511    }
512}