Skip to main content

graphix_compiler/typ/
setops.rs

1use crate::{env::Env, typ::{RefHist, Type}};
2use anyhow::Result;
3use enumflags2::BitFlags;
4use fxhash::FxHashMap;
5use netidx::publisher::Typ;
6use poolshark::local::LPooled;
7use std::iter;
8use triomphe::Arc;
9
10impl Type {
11    fn union_int(
12        &self,
13        env: &Env,
14        hist: &mut RefHist<FxHashMap<(Option<usize>, Option<usize>), Type>>,
15        t: &Self,
16    ) -> Result<Self> {
17        match (self, t) {
18            (
19                Type::Ref { scope: s0, name: n0, params: p0 },
20                Type::Ref { scope: s1, name: n1, params: p1 },
21            ) if n0 == n1 && s0 == s1 && p0.len() == p1.len() => {
22                let mut params = p0
23                    .iter()
24                    .zip(p1.iter())
25                    .map(|(p0, p1)| p0.union_int(env, hist, p1))
26                    .collect::<Result<LPooled<Vec<_>>>>()?;
27                let params = Arc::from_iter(params.drain(..));
28                Ok(Self::Ref { scope: s0.clone(), name: n0.clone(), params })
29            }
30            (tr @ Type::Ref { .. }, t) => {
31                let t0_id = hist.ref_id(tr, env);
32                let t_id = hist.ref_id(t, env);
33                let t0 = tr.lookup_ref(env)?;
34                match hist.get(&(t0_id, t_id)) {
35                    Some(t) => Ok(t.clone()),
36                    None => {
37                        hist.insert((t0_id, t_id), tr.clone());
38                        let r = t0.union_int(env, hist, t);
39                        hist.remove(&(t0_id, t_id));
40                        r
41                    }
42                }
43            }
44            (t, tr @ Type::Ref { .. }) => {
45                let t_id = hist.ref_id(t, env);
46                let t1_id = hist.ref_id(tr, env);
47                let t1 = tr.lookup_ref(env)?;
48                match hist.get(&(t_id, t1_id)) {
49                    Some(t) => Ok(t.clone()),
50                    None => {
51                        hist.insert((t_id, t1_id), tr.clone());
52                        let r = t.union_int(env, hist, &t1);
53                        hist.remove(&(t_id, t1_id));
54                        r
55                    }
56                }
57            }
58            (
59                Type::Abstract { id: id0, params: p0 },
60                Type::Abstract { id: id1, params: p1 },
61            ) if id0 == id1 && p0 == p1 => Ok(self.clone()),
62            (t0 @ Type::Abstract { .. }, t1) | (t0, t1 @ Type::Abstract { .. }) => {
63                Ok(Type::Set(Arc::from_iter([t0.clone(), t1.clone()])))
64            }
65            (Type::Bottom, t) | (t, Type::Bottom) => Ok(t.clone()),
66            (Type::Any, _) | (_, Type::Any) => Ok(Type::Any),
67            (Type::Primitive(p), t) | (t, Type::Primitive(p)) if p.is_empty() => {
68                Ok(t.clone())
69            }
70            (Type::Primitive(s0), Type::Primitive(s1)) => {
71                let mut s = *s0;
72                s.insert(*s1);
73                Ok(Type::Primitive(s))
74            }
75            (
76                Type::Primitive(p),
77                Type::Array(_) | Type::Struct(_) | Type::Tuple(_) | Type::Variant(_, _),
78            )
79            | (
80                Type::Array(_) | Type::Struct(_) | Type::Tuple(_) | Type::Variant(_, _),
81                Type::Primitive(p),
82            ) if p.contains(Typ::Array) => Ok(Type::Primitive(*p)),
83            (Type::Primitive(p), Type::Array(t))
84            | (Type::Array(t), Type::Primitive(p)) => Ok(Type::Set(Arc::from_iter([
85                Type::Primitive(*p),
86                Type::Array(t.clone()),
87            ]))),
88            (t @ Type::Array(t0), u @ Type::Array(t1)) => {
89                if t0 == t1 {
90                    Ok(Type::Array(t0.clone()))
91                } else {
92                    Ok(Type::Set(Arc::from_iter([t.clone(), u.clone()])))
93                }
94            }
95            (Type::Primitive(p), Type::Map { .. })
96            | (Type::Map { .. }, Type::Primitive(p))
97                if p.contains(Typ::Map) =>
98            {
99                Ok(Type::Primitive(*p))
100            }
101            (Type::Primitive(p), Type::Map { key, value })
102            | (Type::Map { key, value }, Type::Primitive(p)) => {
103                Ok(Type::Set(Arc::from_iter([
104                    Type::Primitive(*p),
105                    Type::Map { key: key.clone(), value: value.clone() },
106                ])))
107            }
108            (
109                t @ Type::Map { key: k0, value: v0 },
110                u @ Type::Map { key: k1, value: v1 },
111            ) => {
112                if k0 == k1 && v0 == v1 {
113                    Ok(Type::Map { key: k0.clone(), value: v0.clone() })
114                } else {
115                    Ok(Type::Set(Arc::from_iter([t.clone(), u.clone()])))
116                }
117            }
118            (t @ Type::Map { .. }, u) | (u, t @ Type::Map { .. }) => {
119                Ok(Type::Set(Arc::from_iter([t.clone(), u.clone()])))
120            }
121            (Type::Primitive(p), Type::Error(_))
122            | (Type::Error(_), Type::Primitive(p))
123                if p.contains(Typ::Error) =>
124            {
125                Ok(Type::Primitive(*p))
126            }
127            (Type::Error(e0), Type::Error(e1)) => {
128                Ok(Type::Error(Arc::new(e0.union_int(env, hist, e1)?)))
129            }
130            (e @ Type::Error(_), t) | (t, e @ Type::Error(_)) => {
131                Ok(Type::Set(Arc::from_iter([e.clone(), t.clone()])))
132            }
133            (t @ Type::ByRef(t0), u @ Type::ByRef(t1)) => {
134                if t0 == t1 {
135                    Ok(Type::ByRef(t0.clone()))
136                } else {
137                    Ok(Type::Set(Arc::from_iter([u.clone(), t.clone()])))
138                }
139            }
140            (Type::Set(s0), Type::Set(s1)) => Ok(Type::Set(Arc::from_iter(
141                s0.iter().cloned().chain(s1.iter().cloned()),
142            ))),
143            (Type::Set(s), t) | (t, Type::Set(s)) => Ok(Type::Set(Arc::from_iter(
144                s.iter().cloned().chain(iter::once(t.clone())),
145            ))),
146            (u @ Type::Struct(t0), t @ Type::Struct(t1)) => {
147                if t0.len() == t1.len() && t0 == t1 {
148                    Ok(u.clone())
149                } else {
150                    Ok(Type::Set(Arc::from_iter([u.clone(), t.clone()])))
151                }
152            }
153            (u @ Type::Struct(_), t) | (t, u @ Type::Struct(_)) => {
154                Ok(Type::Set(Arc::from_iter([u.clone(), t.clone()])))
155            }
156            (u @ Type::Tuple(t0), t @ Type::Tuple(t1)) => {
157                if t0 == t1 {
158                    Ok(u.clone())
159                } else {
160                    Ok(Type::Set(Arc::from_iter([u.clone(), t.clone()])))
161                }
162            }
163            (u @ Type::Tuple(_), t) | (t, u @ Type::Tuple(_)) => {
164                Ok(Type::Set(Arc::from_iter([u.clone(), t.clone()])))
165            }
166            (u @ Type::Variant(tg0, t0), t @ Type::Variant(tg1, t1)) => {
167                if tg0 == tg1 && t0.len() == t1.len() {
168                    let mut typs = t0
169                        .iter()
170                        .zip(t1.iter())
171                        .map(|(t0, t1)| t0.union_int(env, hist, t1))
172                        .collect::<Result<LPooled<Vec<_>>>>()?;
173                    Ok(Type::Variant(tg0.clone(), Arc::from_iter(typs.drain(..))))
174                } else {
175                    Ok(Type::Set(Arc::from_iter([u.clone(), t.clone()])))
176                }
177            }
178            (u @ Type::Variant(_, _), t) | (t, u @ Type::Variant(_, _)) => {
179                Ok(Type::Set(Arc::from_iter([u.clone(), t.clone()])))
180            }
181            (Type::Fn(f0), Type::Fn(f1)) => {
182                if f0 == f1 {
183                    Ok(Type::Fn(f0.clone()))
184                } else {
185                    Ok(Type::Set(Arc::from_iter([
186                        Type::Fn(f0.clone()),
187                        Type::Fn(f1.clone()),
188                    ])))
189                }
190            }
191            (f @ Type::Fn(_), t) | (t, f @ Type::Fn(_)) => {
192                Ok(Type::Set(Arc::from_iter([f.clone(), t.clone()])))
193            }
194            (t0 @ Type::TVar(_), t1 @ Type::TVar(_)) => {
195                if t0 == t1 {
196                    Ok(t0.clone())
197                } else {
198                    Ok(Type::Set(Arc::from_iter([t0.clone(), t1.clone()])))
199                }
200            }
201            (t0 @ Type::TVar(_), t1) | (t1, t0 @ Type::TVar(_)) => {
202                Ok(Type::Set(Arc::from_iter([t0.clone(), t1.clone()])))
203            }
204            (t @ Type::ByRef(_), u) | (u, t @ Type::ByRef(_)) => {
205                Ok(Type::Set(Arc::from_iter([t.clone(), u.clone()])))
206            }
207        }
208    }
209
210    pub fn union(&self, env: &Env, t: &Self) -> Result<Self> {
211        Ok(self.union_int(env, &mut RefHist::new(LPooled::take()), t)?.normalize())
212    }
213
214    fn diff_int(
215        &self,
216        env: &Env,
217        hist: &mut RefHist<FxHashMap<(Option<usize>, Option<usize>), Type>>,
218        t: &Self,
219    ) -> Result<Self> {
220        match (self, t) {
221            (
222                Type::Ref { scope: s0, name: n0, .. },
223                Type::Ref { scope: s1, name: n1, .. },
224            ) if s0 == s1 && n0 == n1 => Ok(Type::Primitive(BitFlags::empty())),
225            (t0 @ Type::Ref { .. }, t1) | (t0, t1 @ Type::Ref { .. }) => {
226                let t0_id = hist.ref_id(t0, env);
227                let t1_id = hist.ref_id(t1, env);
228                let t0 = t0.lookup_ref(env)?;
229                let t1 = t1.lookup_ref(env)?;
230                match hist.get(&(t0_id, t1_id)) {
231                    Some(r) => Ok(r.clone()),
232                    None => {
233                        let r = Type::Primitive(BitFlags::empty());
234                        hist.insert((t0_id, t1_id), r);
235                        let r = t0.diff_int(env, hist, &t1);
236                        hist.remove(&(t0_id, t1_id));
237                        r
238                    }
239                }
240            }
241            (Type::Set(s0), Type::Set(s1)) => {
242                let mut s: LPooled<Vec<Type>> = LPooled::take();
243                for i in 0..s0.len() {
244                    s.push(s0[i].clone());
245                    for j in 0..s1.len() {
246                        s[i] = s[i].diff_int(env, hist, &s1[j])?
247                    }
248                }
249                Ok(Self::flatten_set(s.drain(..)))
250            }
251            (Type::Set(s), t) => Ok(Self::flatten_set(
252                s.iter()
253                    .map(|s| s.diff_int(env, hist, t))
254                    .collect::<Result<LPooled<Vec<_>>>>()?
255                    .drain(..),
256            )),
257            (t, Type::Set(s)) => {
258                let mut t = t.clone();
259                for st in s.iter() {
260                    t = t.diff_int(env, hist, st)?;
261                }
262                Ok(t)
263            }
264            (Type::Tuple(t0), Type::Tuple(t1)) => {
265                if t0 == t1 {
266                    Ok(Type::Primitive(BitFlags::empty()))
267                } else {
268                    Ok(self.clone())
269                }
270            }
271            (Type::Struct(t0), Type::Struct(t1)) => {
272                if t0.len() == t1.len() && t0 == t1 {
273                    Ok(Type::Primitive(BitFlags::empty()))
274                } else {
275                    Ok(self.clone())
276                }
277            }
278            (Type::Variant(tg0, t0), Type::Variant(tg1, t1)) => {
279                if tg0 == tg1 && t0.len() == t1.len() && t0 == t1 {
280                    Ok(Type::Primitive(BitFlags::empty()))
281                } else {
282                    Ok(self.clone())
283                }
284            }
285            (Type::Map { key: k0, value: v0 }, Type::Map { key: k1, value: v1 }) => {
286                if k0 == k1 && v0 == v1 {
287                    Ok(Type::Primitive(BitFlags::empty()))
288                } else {
289                    Ok(self.clone())
290                }
291            }
292            (Type::Map { .. }, Type::Primitive(p)) => {
293                if p.contains(Typ::Map) {
294                    Ok(Type::Primitive(BitFlags::empty()))
295                } else {
296                    Ok(self.clone())
297                }
298            }
299            (Type::Primitive(p), Type::Map { key, value }) => {
300                if **key == Type::Any && **value == Type::Any {
301                    let mut p = *p;
302                    p.remove(Typ::Map);
303                    Ok(Type::Primitive(p))
304                } else {
305                    Ok(Type::Primitive(*p))
306                }
307            }
308            (Type::Fn(f0), Type::Fn(f1)) => {
309                if f0 == f1 {
310                    Ok(Type::Primitive(BitFlags::empty()))
311                } else {
312                    Ok(Type::Fn(f0.clone()))
313                }
314            }
315            (Type::TVar(tv0), Type::TVar(tv1)) => {
316                if tv0.read().typ.as_ptr() == tv1.read().typ.as_ptr() {
317                    return Ok(Type::Primitive(BitFlags::empty()));
318                }
319                Ok(match (&*tv0.read().typ.read(), &*tv1.read().typ.read()) {
320                    (None, _) | (_, None) => Type::TVar(tv0.clone()),
321                    (Some(t0), Some(t1)) => t0.diff_int(env, hist, t1)?,
322                })
323            }
324            (Type::TVar(tv), t) => Ok(match &*tv.read().typ.read() {
325                Some(tv) => tv.diff_int(env, hist, t)?,
326                None => self.clone(),
327            }),
328            (t, Type::TVar(tv)) => Ok(match &*tv.read().typ.read() {
329                Some(tv) => t.diff_int(env, hist, tv)?,
330                None => self.clone(),
331            }),
332            (Type::Array(t0), Type::Array(t1)) => {
333                if t0 == t1 {
334                    Ok(Type::Primitive(BitFlags::empty()))
335                } else {
336                    Ok(Type::Array(Arc::new(t0.diff_int(env, hist, t1)?)))
337                }
338            }
339            (Type::Primitive(p), Type::Array(t)) => {
340                if &**t == &Type::Any {
341                    let mut s = *p;
342                    s.remove(Typ::Array);
343                    Ok(Type::Primitive(s))
344                } else {
345                    Ok(Type::Primitive(*p))
346                }
347            }
348            (
349                Type::Array(_) | Type::Struct(_) | Type::Tuple(_) | Type::Variant(_, _),
350                Type::Primitive(p),
351            ) => {
352                if p.contains(Typ::Array) {
353                    Ok(Type::Primitive(BitFlags::empty()))
354                } else {
355                    Ok(self.clone())
356                }
357            }
358            (_, Type::Any) => Ok(Type::Primitive(BitFlags::empty())),
359            (Type::Any, _) => Ok(Type::Any),
360            (Type::Primitive(s0), Type::Primitive(s1)) => {
361                let mut s = *s0;
362                s.remove(*s1);
363                Ok(Type::Primitive(s))
364            }
365            (Type::Primitive(p), Type::Error(e)) => {
366                if &**e == &Type::Any {
367                    let mut s = *p;
368                    s.remove(Typ::Error);
369                    Ok(Type::Primitive(s))
370                } else {
371                    Ok(Type::Primitive(*p))
372                }
373            }
374            (Type::Error(_), Type::Primitive(p)) => {
375                if p.contains(Typ::Error) {
376                    Ok(Type::Primitive(BitFlags::empty()))
377                } else {
378                    Ok(self.clone())
379                }
380            }
381            (Type::Error(e0), Type::Error(e1)) => {
382                if e0 == e1 {
383                    Ok(Type::Primitive(BitFlags::empty()))
384                } else {
385                    Ok(Type::Error(Arc::new(e0.diff_int(env, hist, e1)?)))
386                }
387            }
388            (Type::ByRef(t0), Type::ByRef(t1)) => {
389                Ok(Type::ByRef(Arc::new(t0.diff_int(env, hist, t1)?)))
390            }
391            (
392                Type::Abstract { id: id0, params: p0 },
393                Type::Abstract { id: id1, params: p1 },
394            ) if id0 == id1 && p0 == p1 => Ok(Type::Primitive(BitFlags::empty())),
395            (Type::Abstract { .. }, _)
396            | (_, Type::Abstract { .. })
397            | (Type::Fn(_), _)
398            | (_, Type::Fn(_))
399            | (Type::Array(_), _)
400            | (_, Type::Array(_))
401            | (Type::Tuple(_), _)
402            | (_, Type::Tuple(_))
403            | (Type::Struct(_), _)
404            | (_, Type::Struct(_))
405            | (Type::Variant(_, _), _)
406            | (_, Type::Variant(_, _))
407            | (Type::ByRef(_), _)
408            | (_, Type::ByRef(_))
409            | (Type::Error(_), _)
410            | (_, Type::Error(_))
411            | (Type::Primitive(_), _)
412            | (_, Type::Primitive(_))
413            | (Type::Bottom, _)
414            | (Type::Map { .. }, _) => Ok(self.clone()),
415        }
416    }
417
418    pub fn diff(&self, env: &Env, t: &Self) -> Result<Self> {
419        Ok(self.diff_int(env, &mut RefHist::new(LPooled::take()), t)?.normalize())
420    }
421}