Skip to main content

graphix_compiler/typ/
mod.rs

1use crate::{env::{Env, TypeDef}, expr::ModPath, format_with_flags, PrintFlag, PRINT_FLAGS};
2use anyhow::{anyhow, bail, Result};
3use arcstr::ArcStr;
4use enumflags2::BitFlags;
5use fxhash::{FxHashMap, FxHashSet};
6use netidx::{publisher::Typ, utils::Either};
7use poolshark::{local::LPooled, IsoPoolable};
8use smallvec::SmallVec;
9use std::{
10    cmp::{Eq, PartialEq},
11    fmt::Debug,
12    iter,
13    ops::{Deref, DerefMut},
14};
15use triomphe::Arc;
16
17mod cast;
18mod contains;
19mod fntyp;
20mod matches;
21mod normalize;
22mod print;
23mod setops;
24mod tval;
25mod tvar;
26
27pub use fntyp::{FnArgType, FnType};
28pub use tval::TVal;
29pub use tvar::TVar;
30
/// Boolean accumulator that can be `collect`ed from an iterator of `bool`s.
///
/// The wrapped value is `true` when every item produced by the iterator is
/// `true` (which includes the empty iterator) and `false` otherwise.
struct AndAc(bool);

impl FromIterator<bool> for AndAc {
    /// Combine the items with logical AND. Iteration stops at the first
    /// `false`; later items are never pulled from the iterator.
    fn from_iter<T: IntoIterator<Item = bool>>(iter: T) -> Self {
        let saw_false = iter.into_iter().any(|item| !item);
        AndAc(!saw_false)
    }
}
38
39struct RefHist<H: IsoPoolable> {
40    inner: LPooled<H>,
41    ref_ids: LPooled<FxHashMap<usize, SmallVec<[(Arc<[Type]>, usize); 2]>>>,
42    next_id: usize,
43}
44
45impl<H: IsoPoolable> Deref for RefHist<H> {
46    type Target = H;
47    fn deref(&self) -> &H {
48        &*self.inner
49    }
50}
51
52impl<H: IsoPoolable> DerefMut for RefHist<H> {
53    fn deref_mut(&mut self) -> &mut H {
54        &mut *self.inner
55    }
56}
57
58impl<H: IsoPoolable> RefHist<H> {
59    fn new(inner: LPooled<H>) -> Self {
60        RefHist { inner, ref_ids: LPooled::take(), next_id: 0 }
61    }
62
63    /// Return a stable ID for a Ref type based on (typedef identity, params).
64    /// Returns None for non-Ref types — cycle detection is driven by the
65    /// Ref side, and None collapses all non-Ref types to the same key.
66    fn ref_id(&mut self, t: &Type, env: &Env) -> Option<usize> {
67        match t {
68            Type::Ref { scope, name, params } => match env.lookup_typedef(scope, name) {
69                Some(def) => {
70                    let def_addr = (def as *const TypeDef).addr();
71                    let entries = self.ref_ids.entry(def_addr).or_default();
72                    for &(ref p, id) in entries.iter() {
73                        if **p == **params {
74                            return Some(id);
75                        }
76                    }
77                    let id = self.next_id;
78                    self.next_id += 1;
79                    entries.push((params.clone(), id));
80                    Some(id)
81                }
82                None => None,
83            },
84            _ => None,
85        }
86    }
87}
88
89atomic_id!(AbstractId);
90
91#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
92pub enum Type {
93    Bottom,
94    Any,
95    Primitive(BitFlags<Typ>),
96    Ref { scope: ModPath, name: ModPath, params: Arc<[Type]> },
97    Fn(Arc<FnType>),
98    Set(Arc<[Type]>),
99    TVar(TVar),
100    Error(Arc<Type>),
101    Array(Arc<Type>),
102    ByRef(Arc<Type>),
103    Tuple(Arc<[Type]>),
104    Struct(Arc<[(ArcStr, Type)]>),
105    Variant(ArcStr, Arc<[Type]>),
106    Map { key: Arc<Type>, value: Arc<Type> },
107    Abstract { id: AbstractId, params: Arc<[Type]> },
108}
109
110impl Default for Type {
111    fn default() -> Self {
112        Self::Bottom
113    }
114}
115
116impl Type {
117    pub fn empty_tvar() -> Self {
118        Type::TVar(TVar::default())
119    }
120
121    fn iter_prims(&self) -> impl Iterator<Item = Self> {
122        match self {
123            Self::Primitive(p) => {
124                Either::Left(p.iter().map(|t| Type::Primitive(t.into())))
125            }
126            t => Either::Right(iter::once(t.clone())),
127        }
128    }
129
130    pub fn is_defined(&self) -> bool {
131        match self {
132            Self::Bottom
133            | Self::Any
134            | Self::Primitive(_)
135            | Self::Fn(_)
136            | Self::Set(_)
137            | Self::Error(_)
138            | Self::Array(_)
139            | Self::ByRef(_)
140            | Self::Tuple(_)
141            | Self::Struct(_)
142            | Self::Variant(_, _)
143            | Self::Ref { .. }
144            | Self::Map { .. }
145            | Self::Abstract { .. } => true,
146            Self::TVar(tv) => tv.read().typ.read().is_some(),
147        }
148    }
149
150    pub fn lookup_ref(&self, env: &Env) -> Result<Type> {
151        match self {
152            Self::Ref { scope, name, params } => {
153                let def = env
154                    .lookup_typedef(scope, name)
155                    .ok_or_else(|| anyhow!("undefined type {name} in {scope}"))?;
156                if def.params.len() != params.len() {
157                    bail!("{} expects {} type parameters", name, def.params.len());
158                }
159                let mut known: LPooled<FxHashMap<ArcStr, Type>> = LPooled::take();
160                for ((tv, ct), arg) in def.params.iter().zip(params.iter()) {
161                    if let Some(ct) = ct {
162                        ct.check_contains(env, arg)?;
163                    }
164                    known.insert(tv.name.clone(), arg.clone());
165                }
166                Ok(def.typ.replace_tvars(&known))
167            }
168            t => Ok(t.clone()),
169        }
170    }
171
172    pub fn any() -> Self {
173        Self::Any
174    }
175
176    pub fn boolean() -> Self {
177        Self::Primitive(Typ::Bool.into())
178    }
179
180    pub fn number() -> Self {
181        Self::Primitive(Typ::number())
182    }
183
184    pub fn int() -> Self {
185        Self::Primitive(Typ::integer())
186    }
187
188    pub fn uint() -> Self {
189        Self::Primitive(Typ::unsigned_integer())
190    }
191
192    fn strip_error_int(
193        &self,
194        env: &Env,
195        hist: &mut RefHist<FxHashSet<Option<usize>>>,
196    ) -> Option<Type> {
197        match self {
198            Type::Error(t) => match t.strip_error_int(env, hist) {
199                Some(t) => Some(t),
200                None => Some((**t).clone()),
201            },
202            Type::TVar(tv) => {
203                tv.read().typ.read().as_ref().and_then(|t| t.strip_error_int(env, hist))
204            }
205            Type::Primitive(p) => {
206                if *p == BitFlags::from(Typ::Error) {
207                    Some(Type::Any)
208                } else {
209                    None
210                }
211            }
212            Type::Ref { .. } => {
213                let id = hist.ref_id(self, env);
214                let t = self.lookup_ref(env).ok()?;
215                if hist.insert(id) {
216                    t.strip_error_int(env, hist)
217                } else {
218                    None
219                }
220            }
221            Type::Set(s) => {
222                let r = Self::flatten_set(
223                    s.iter().filter_map(|t| t.strip_error_int(env, hist)),
224                );
225                match r {
226                    Type::Primitive(p) if p.is_empty() => None,
227                    t => Some(t),
228                }
229            }
230            Type::Array(_)
231            | Type::Map { .. }
232            | Type::ByRef(_)
233            | Type::Tuple(_)
234            | Type::Struct(_)
235            | Type::Variant(_, _)
236            | Type::Fn(_)
237            | Type::Any
238            | Type::Bottom
239            | Type::Abstract { .. } => None,
240        }
241    }
242
243    /// remove the outer error type and return the inner payload, fail if self
244    /// isn't an error or contains non error types
245    pub fn strip_error(&self, env: &Env) -> Option<Self> {
246        self.strip_error_int(env, &mut RefHist::<FxHashSet<Option<usize>>>::new(LPooled::take()))
247    }
248
249    pub fn is_bot(&self) -> bool {
250        match self {
251            Type::Bottom => true,
252            Type::Any
253            | Type::Abstract { .. }
254            | Type::TVar(_)
255            | Type::Primitive(_)
256            | Type::Ref { .. }
257            | Type::Fn(_)
258            | Type::Error(_)
259            | Type::Array(_)
260            | Type::ByRef(_)
261            | Type::Tuple(_)
262            | Type::Struct(_)
263            | Type::Variant(_, _)
264            | Type::Set(_)
265            | Type::Map { .. } => false,
266        }
267    }
268
269    pub fn with_deref<R, F: FnOnce(Option<&Self>) -> R>(&self, f: F) -> R {
270        match self {
271            Self::Bottom
272            | Self::Abstract { .. }
273            | Self::Any
274            | Self::Primitive(_)
275            | Self::Fn(_)
276            | Self::Set(_)
277            | Self::Error(_)
278            | Self::Array(_)
279            | Self::ByRef(_)
280            | Self::Tuple(_)
281            | Self::Struct(_)
282            | Self::Variant(_, _)
283            | Self::Ref { .. }
284            | Self::Map { .. } => f(Some(self)),
285            Self::TVar(tv) => match tv.read().typ.read().as_ref() {
286                Some(t) => t.with_deref(f),
287                None => f(None),
288            },
289        }
290    }
291
292    pub fn scope_refs(&self, scope: &ModPath) -> Type {
293        match self {
294            Type::Bottom => Type::Bottom,
295            Type::Any => Type::Any,
296            Type::Primitive(s) => Type::Primitive(*s),
297            Type::Abstract { id, params } => Type::Abstract {
298                id: *id,
299                params: Arc::from_iter(params.iter().map(|t| t.scope_refs(scope))),
300            },
301            Type::Error(t0) => Type::Error(Arc::new(t0.scope_refs(scope))),
302            Type::Array(t0) => Type::Array(Arc::new(t0.scope_refs(scope))),
303            Type::Map { key, value } => {
304                let key = Arc::new(key.scope_refs(scope));
305                let value = Arc::new(value.scope_refs(scope));
306                Type::Map { key, value }
307            }
308            Type::ByRef(t) => Type::ByRef(Arc::new(t.scope_refs(scope))),
309            Type::Tuple(ts) => {
310                let i = ts.iter().map(|t| t.scope_refs(scope));
311                Type::Tuple(Arc::from_iter(i))
312            }
313            Type::Variant(tag, ts) => {
314                let i = ts.iter().map(|t| t.scope_refs(scope));
315                Type::Variant(tag.clone(), Arc::from_iter(i))
316            }
317            Type::Struct(ts) => {
318                let i = ts.iter().map(|(n, t)| (n.clone(), t.scope_refs(scope)));
319                Type::Struct(Arc::from_iter(i))
320            }
321            Type::TVar(tv) => match tv.read().typ.read().as_ref() {
322                None => Type::TVar(TVar::empty_named(tv.name.clone())),
323                Some(typ) => {
324                    let typ = typ.scope_refs(scope);
325                    Type::TVar(TVar::named(tv.name.clone(), typ))
326                }
327            },
328            Type::Ref { scope: _, name, params } => {
329                let params = Arc::from_iter(params.iter().map(|t| t.scope_refs(scope)));
330                Type::Ref { scope: scope.clone(), name: name.clone(), params }
331            }
332            Type::Set(ts) => {
333                Type::Set(Arc::from_iter(ts.iter().map(|t| t.scope_refs(scope))))
334            }
335            Type::Fn(f) => Type::Fn(Arc::new(f.scope_refs(scope))),
336        }
337    }
338}