Skip to main content

graphix_compiler/typ/
mod.rs

1use crate::{
2    env::{Env, TypeDef},
3    expr::ModPath,
4    format_with_flags, PrintFlag, PRINT_FLAGS,
5};
6use anyhow::{anyhow, bail, Result};
7use arcstr::ArcStr;
8use enumflags2::BitFlags;
9use fxhash::{FxHashMap, FxHashSet};
10use netidx::{publisher::Typ, utils::Either};
11use poolshark::{local::LPooled, IsoPoolable};
12use smallvec::SmallVec;
13use std::{
14    cmp::{Eq, PartialEq},
15    fmt::Debug,
16    iter,
17    ops::{Deref, DerefMut},
18};
19use triomphe::Arc;
20
21mod cast;
22mod contains;
23mod fntyp;
24mod matches;
25mod normalize;
26mod print;
27mod setops;
28mod tval;
29mod tvar;
30
31pub use fntyp::{FnArgType, FnType};
32pub use tval::TVal;
33pub use tvar::TVar;
34
/// Accumulates an iterator of `bool`s into their logical AND.
///
/// Collecting an empty iterator gives `AndAc(true)`, and collection stops
/// pulling items as soon as a `false` is seen.
struct AndAc(bool);

impl FromIterator<bool> for AndAc {
    fn from_iter<I: IntoIterator<Item = bool>>(iter: I) -> Self {
        // "all true" is the same as "no false"; `any` stops at the first hit.
        let saw_false = iter.into_iter().any(|b| !b);
        Self(!saw_false)
    }
}
42
43struct RefHist<H: IsoPoolable> {
44    inner: LPooled<H>,
45    ref_ids: LPooled<FxHashMap<usize, SmallVec<[(Arc<[Type]>, usize); 2]>>>,
46    next_id: usize,
47}
48
49impl<H: IsoPoolable> Deref for RefHist<H> {
50    type Target = H;
51    fn deref(&self) -> &H {
52        &*self.inner
53    }
54}
55
56impl<H: IsoPoolable> DerefMut for RefHist<H> {
57    fn deref_mut(&mut self) -> &mut H {
58        &mut *self.inner
59    }
60}
61
62impl<H: IsoPoolable> RefHist<H> {
63    fn new(inner: LPooled<H>) -> Self {
64        RefHist { inner, ref_ids: LPooled::take(), next_id: 0 }
65    }
66
67    /// Return a stable ID for a Ref type based on (typedef identity, params).
68    /// Returns None for non-Ref types — cycle detection is driven by the
69    /// Ref side, and None collapses all non-Ref types to the same key.
70    fn ref_id(&mut self, t: &Type, env: &Env) -> Option<usize> {
71        match t {
72            Type::Ref { scope, name, params } => match env.lookup_typedef(scope, name) {
73                Some(def) => {
74                    let def_addr = (def as *const TypeDef).addr();
75                    let entries = self.ref_ids.entry(def_addr).or_default();
76                    for &(ref p, id) in entries.iter() {
77                        if **p == **params {
78                            return Some(id);
79                        }
80                    }
81                    let id = self.next_id;
82                    self.next_id += 1;
83                    entries.push((params.clone(), id));
84                    Some(id)
85                }
86                None => None,
87            },
88            _ => None,
89        }
90    }
91}
92
93atomic_id!(AbstractId);
94
95#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
96pub enum Type {
97    Bottom,
98    Any,
99    Primitive(BitFlags<Typ>),
100    Ref { scope: ModPath, name: ModPath, params: Arc<[Type]> },
101    Fn(Arc<FnType>),
102    Set(Arc<[Type]>),
103    TVar(TVar),
104    Error(Arc<Type>),
105    Array(Arc<Type>),
106    ByRef(Arc<Type>),
107    Tuple(Arc<[Type]>),
108    Struct(Arc<[(ArcStr, Type)]>),
109    Variant(ArcStr, Arc<[Type]>),
110    Map { key: Arc<Type>, value: Arc<Type> },
111    Abstract { id: AbstractId, params: Arc<[Type]> },
112}
113
114impl Default for Type {
115    fn default() -> Self {
116        Self::Bottom
117    }
118}
119
120impl Type {
121    pub fn empty_tvar() -> Self {
122        Type::TVar(TVar::default())
123    }
124
125    fn iter_prims(&self) -> impl Iterator<Item = Self> {
126        match self {
127            Self::Primitive(p) => {
128                Either::Left(p.iter().map(|t| Type::Primitive(t.into())))
129            }
130            t => Either::Right(iter::once(t.clone())),
131        }
132    }
133
134    pub fn is_defined(&self) -> bool {
135        match self {
136            Self::Bottom
137            | Self::Any
138            | Self::Primitive(_)
139            | Self::Fn(_)
140            | Self::Set(_)
141            | Self::Error(_)
142            | Self::Array(_)
143            | Self::ByRef(_)
144            | Self::Tuple(_)
145            | Self::Struct(_)
146            | Self::Variant(_, _)
147            | Self::Ref { .. }
148            | Self::Map { .. }
149            | Self::Abstract { .. } => true,
150            Self::TVar(tv) => tv.read().typ.read().is_some(),
151        }
152    }
153
154    pub fn lookup_ref(&self, env: &Env) -> Result<Type> {
155        match self {
156            Self::Ref { scope, name, params } => {
157                let def = env
158                    .lookup_typedef(scope, name)
159                    .ok_or_else(|| anyhow!("undefined type {name} in {scope}"))?;
160                if def.params.len() != params.len() {
161                    bail!("{} expects {} type parameters", name, def.params.len());
162                }
163                let mut known: LPooled<FxHashMap<ArcStr, Type>> = LPooled::take();
164                for ((tv, ct), arg) in def.params.iter().zip(params.iter()) {
165                    if let Some(ct) = ct {
166                        ct.check_contains(env, arg)?;
167                    }
168                    known.insert(tv.name.clone(), arg.clone());
169                }
170                Ok(def.typ.replace_tvars(&known))
171            }
172            t => Ok(t.clone()),
173        }
174    }
175
176    pub fn any() -> Self {
177        Self::Any
178    }
179
180    pub fn boolean() -> Self {
181        Self::Primitive(Typ::Bool.into())
182    }
183
184    pub fn number() -> Self {
185        Self::Primitive(Typ::number())
186    }
187
188    pub fn int() -> Self {
189        Self::Primitive(Typ::integer())
190    }
191
192    pub fn uint() -> Self {
193        Self::Primitive(Typ::unsigned_integer())
194    }
195
196    fn strip_error_int(
197        &self,
198        env: &Env,
199        hist: &mut RefHist<FxHashSet<Option<usize>>>,
200    ) -> Option<Type> {
201        match self {
202            Type::Error(t) => match t.strip_error_int(env, hist) {
203                Some(t) => Some(t),
204                None => Some((**t).clone()),
205            },
206            Type::TVar(tv) => {
207                tv.read().typ.read().as_ref().and_then(|t| t.strip_error_int(env, hist))
208            }
209            Type::Primitive(p) => {
210                if *p == BitFlags::from(Typ::Error) {
211                    Some(Type::Any)
212                } else {
213                    None
214                }
215            }
216            Type::Ref { .. } => {
217                let id = hist.ref_id(self, env);
218                let t = self.lookup_ref(env).ok()?;
219                if hist.insert(id) {
220                    t.strip_error_int(env, hist)
221                } else {
222                    None
223                }
224            }
225            Type::Set(s) => {
226                let r = Self::flatten_set(
227                    s.iter().filter_map(|t| t.strip_error_int(env, hist)),
228                );
229                match r {
230                    Type::Primitive(p) if p.is_empty() => None,
231                    t => Some(t),
232                }
233            }
234            Type::Array(_)
235            | Type::Map { .. }
236            | Type::ByRef(_)
237            | Type::Tuple(_)
238            | Type::Struct(_)
239            | Type::Variant(_, _)
240            | Type::Fn(_)
241            | Type::Any
242            | Type::Bottom
243            | Type::Abstract { .. } => None,
244        }
245    }
246
247    /// remove the outer error type and return the inner payload, fail if self
248    /// isn't an error or contains non error types
249    pub fn strip_error(&self, env: &Env) -> Option<Self> {
250        self.strip_error_int(
251            env,
252            &mut RefHist::<FxHashSet<Option<usize>>>::new(LPooled::take()),
253        )
254    }
255
256    pub fn is_bot(&self) -> bool {
257        match self {
258            Type::Bottom => true,
259            Type::Any
260            | Type::Abstract { .. }
261            | Type::TVar(_)
262            | Type::Primitive(_)
263            | Type::Ref { .. }
264            | Type::Fn(_)
265            | Type::Error(_)
266            | Type::Array(_)
267            | Type::ByRef(_)
268            | Type::Tuple(_)
269            | Type::Struct(_)
270            | Type::Variant(_, _)
271            | Type::Set(_)
272            | Type::Map { .. } => false,
273        }
274    }
275
276    pub fn with_deref<R, F: FnOnce(Option<&Self>) -> R>(&self, f: F) -> R {
277        match self {
278            Self::Bottom
279            | Self::Abstract { .. }
280            | Self::Any
281            | Self::Primitive(_)
282            | Self::Fn(_)
283            | Self::Set(_)
284            | Self::Error(_)
285            | Self::Array(_)
286            | Self::ByRef(_)
287            | Self::Tuple(_)
288            | Self::Struct(_)
289            | Self::Variant(_, _)
290            | Self::Ref { .. }
291            | Self::Map { .. } => f(Some(self)),
292            Self::TVar(tv) => match tv.read().typ.read().as_ref() {
293                Some(t) => t.with_deref(f),
294                None => f(None),
295            },
296        }
297    }
298
299    pub fn scope_refs(&self, scope: &ModPath) -> Type {
300        match self {
301            Type::Bottom => Type::Bottom,
302            Type::Any => Type::Any,
303            Type::Primitive(s) => Type::Primitive(*s),
304            Type::Abstract { id, params } => Type::Abstract {
305                id: *id,
306                params: Arc::from_iter(params.iter().map(|t| t.scope_refs(scope))),
307            },
308            Type::Error(t0) => Type::Error(Arc::new(t0.scope_refs(scope))),
309            Type::Array(t0) => Type::Array(Arc::new(t0.scope_refs(scope))),
310            Type::Map { key, value } => {
311                let key = Arc::new(key.scope_refs(scope));
312                let value = Arc::new(value.scope_refs(scope));
313                Type::Map { key, value }
314            }
315            Type::ByRef(t) => Type::ByRef(Arc::new(t.scope_refs(scope))),
316            Type::Tuple(ts) => {
317                let i = ts.iter().map(|t| t.scope_refs(scope));
318                Type::Tuple(Arc::from_iter(i))
319            }
320            Type::Variant(tag, ts) => {
321                let i = ts.iter().map(|t| t.scope_refs(scope));
322                Type::Variant(tag.clone(), Arc::from_iter(i))
323            }
324            Type::Struct(ts) => {
325                let i = ts.iter().map(|(n, t)| (n.clone(), t.scope_refs(scope)));
326                Type::Struct(Arc::from_iter(i))
327            }
328            Type::TVar(tv) => match tv.read().typ.read().as_ref() {
329                None => Type::TVar(TVar::empty_named(tv.name.clone())),
330                Some(typ) => {
331                    let typ = typ.scope_refs(scope);
332                    Type::TVar(TVar::named(tv.name.clone(), typ))
333                }
334            },
335            Type::Ref { scope: _, name, params } => {
336                let params = Arc::from_iter(params.iter().map(|t| t.scope_refs(scope)));
337                Type::Ref { scope: scope.clone(), name: name.clone(), params }
338            }
339            Type::Set(ts) => {
340                Type::Set(Arc::from_iter(ts.iter().map(|t| t.scope_refs(scope))))
341            }
342            Type::Fn(f) => Type::Fn(Arc::new(f.scope_refs(scope))),
343        }
344    }
345}