Skip to main content

graphix_compiler/typ/
setops.rs

1use crate::{env::Env, typ::Type};
2use anyhow::Result;
3use enumflags2::BitFlags;
4use fxhash::FxHashMap;
5use netidx::publisher::Typ;
6use poolshark::local::LPooled;
7use std::iter;
8use triomphe::Arc;
9
10impl Type {
    /// Recursive worker behind `union`: builds a type that covers both `self`
    /// and `t`.
    ///
    /// `hist` memoizes work done after a `Type::Ref` has been resolved. It is
    /// keyed by the addresses of the two operands; a placeholder entry is
    /// stored before recursing so that meeting the same address pair again
    /// returns the stored value instead of recursing a second time.
    ///
    /// The match arms are tried top to bottom and many of their patterns
    /// overlap, so the order of the arms is part of the logic.
    ///
    /// The value returned here is not normalized; `union` calls `normalize()`
    /// on it.
    ///
    /// # Errors
    /// Propagates errors from `lookup_ref` and from recursive calls.
    fn union_int(
        &self,
        env: &Env,
        hist: &mut FxHashMap<(usize, usize), Type>,
        t: &Self,
    ) -> Result<Self> {
        match (self, t) {
            // Two refs with the same name, scope and parameter count: keep
            // the ref and union the parameters position by position.
            (
                Type::Ref { scope: s0, name: n0, params: p0 },
                Type::Ref { scope: s1, name: n1, params: p1 },
            ) if n0 == n1 && s0 == s1 && p0.len() == p1.len() => {
                let mut params = p0
                    .iter()
                    .zip(p1.iter())
                    .map(|(p0, p1)| p0.union_int(env, hist, p1))
                    .collect::<Result<LPooled<Vec<_>>>>()?;
                let params = Arc::from_iter(params.drain(..));
                Ok(Self::Ref { scope: s0.clone(), name: n0.clone(), params })
            }
            // Left side is a ref (any other ref/ref pairing also lands here):
            // resolve it through `env` and union the resolved type with `t`.
            (tr @ Type::Ref { .. }, t) => {
                let t0 = tr.lookup_ref(env)?;
                // Key the history on the addresses of the resolved type and
                // of the other operand.
                let t0_addr = (t0 as *const Type).addr();
                let t_addr = (t as *const Type).addr();
                match hist.get(&(t0_addr, t_addr)) {
                    // Note: this `t` shadows the operand; it is the stored result.
                    Some(t) => Ok(t.clone()),
                    None => {
                        // Placeholder: the unresolved ref itself stands in for
                        // the result while the recursion is in progress.
                        hist.insert((t0_addr, t_addr), tr.clone());
                        let r = t0.union_int(env, hist, t)?;
                        // Replace the placeholder with the real result.
                        hist.insert((t0_addr, t_addr), r.clone());
                        Ok(r)
                    }
                }
            }
            // Mirror of the arm above for a ref on the right-hand side; the
            // history key keeps the (left, right) operand order.
            (t, tr @ Type::Ref { .. }) => {
                let t1 = tr.lookup_ref(env)?;
                let t1_addr = (t1 as *const Type).addr();
                let t_addr = (t as *const Type).addr();
                match hist.get(&(t_addr, t1_addr)) {
                    Some(t) => Ok(t.clone()),
                    None => {
                        hist.insert((t_addr, t1_addr), tr.clone());
                        let r = t.union_int(env, hist, t1)?;
                        hist.insert((t_addr, t1_addr), r.clone());
                        Ok(r)
                    }
                }
            }
            // Identical abstract types collapse to one.
            (
                Type::Abstract { id: id0, params: p0 },
                Type::Abstract { id: id1, params: p1 },
            ) if id0 == id1 && p0 == p1 => Ok(self.clone()),
            // An abstract type paired with anything else is kept side by side
            // with it in a two-element set. This arm precedes the `Bottom`
            // and `Any` arms, so those pairings also produce a set here.
            (t0 @ Type::Abstract { .. }, t1) | (t0, t1 @ Type::Abstract { .. }) => {
                Ok(Type::Set(Arc::from_iter([t0.clone(), t1.clone()])))
            }
            // `Bottom` contributes nothing: the other side is the result.
            (Type::Bottom, t) | (t, Type::Bottom) => Ok(t.clone()),
            // `Any` absorbs whatever it is paired with.
            (Type::Any, _) | (_, Type::Any) => Ok(Type::Any),
            // An empty primitive set contributes nothing either.
            (Type::Primitive(p), t) | (t, Type::Primitive(p)) if p.is_empty() => {
                Ok(t.clone())
            }
            // Two primitive sets: merge their flags.
            (Type::Primitive(s0), Type::Primitive(s1)) => {
                let mut s = *s0;
                s.insert(*s1);
                Ok(Type::Primitive(s))
            }
            // A primitive set that already contains `Typ::Array` absorbs
            // arrays, structs, tuples and variants.
            // NOTE(review): presumably because these composites are all
            // carried as array values - confirm.
            (
                Type::Primitive(p),
                Type::Array(_) | Type::Struct(_) | Type::Tuple(_) | Type::Variant(_, _),
            )
            | (
                Type::Array(_) | Type::Struct(_) | Type::Tuple(_) | Type::Variant(_, _),
                Type::Primitive(p),
            ) if p.contains(Typ::Array) => Ok(Type::Primitive(*p)),
            // Otherwise a primitive and an array are kept as separate set
            // members (primitive first, regardless of operand order).
            (Type::Primitive(p), Type::Array(t))
            | (Type::Array(t), Type::Primitive(p)) => Ok(Type::Set(Arc::from_iter([
                Type::Primitive(*p),
                Type::Array(t.clone()),
            ]))),
            // Arrays merge only when their element types are equal; element
            // types are not unioned.
            (t @ Type::Array(t0), u @ Type::Array(t1)) => {
                if t0 == t1 {
                    Ok(Type::Array(t0.clone()))
                } else {
                    Ok(Type::Set(Arc::from_iter([t.clone(), u.clone()])))
                }
            }
            // A primitive set containing `Typ::Map` absorbs any map type.
            (Type::Primitive(p), Type::Map { .. })
            | (Type::Map { .. }, Type::Primitive(p))
                if p.contains(Typ::Map) =>
            {
                Ok(Type::Primitive(*p))
            }
            // Otherwise a primitive and a map are kept as separate set members.
            (Type::Primitive(p), Type::Map { key, value })
            | (Type::Map { key, value }, Type::Primitive(p)) => {
                Ok(Type::Set(Arc::from_iter([
                    Type::Primitive(*p),
                    Type::Map { key: key.clone(), value: value.clone() },
                ])))
            }
            // Maps merge only when both key and value types are equal.
            (
                t @ Type::Map { key: k0, value: v0 },
                u @ Type::Map { key: k1, value: v1 },
            ) => {
                if k0 == k1 && v0 == v1 {
                    Ok(Type::Map { key: k0.clone(), value: v0.clone() })
                } else {
                    Ok(Type::Set(Arc::from_iter([t.clone(), u.clone()])))
                }
            }
            // A map paired with any remaining kind of type: two-element set.
            (t @ Type::Map { .. }, u) | (u, t @ Type::Map { .. }) => {
                Ok(Type::Set(Arc::from_iter([t.clone(), u.clone()])))
            }
            // A primitive set containing `Typ::Error` absorbs any error type.
            (Type::Primitive(p), Type::Error(_))
            | (Type::Error(_), Type::Primitive(p))
                if p.contains(Typ::Error) =>
            {
                Ok(Type::Primitive(*p))
            }
            // Two error types: union their payload types recursively.
            (Type::Error(e0), Type::Error(e1)) => {
                Ok(Type::Error(Arc::new(e0.union_int(env, hist, e1)?)))
            }
            // An error paired with any remaining kind of type: two-element set.
            (e @ Type::Error(_), t) | (t, e @ Type::Error(_)) => {
                Ok(Type::Set(Arc::from_iter([e.clone(), t.clone()])))
            }
            // By-ref types merge only when their targets are equal.
            (t @ Type::ByRef(t0), u @ Type::ByRef(t1)) => {
                if t0 == t1 {
                    Ok(Type::ByRef(t0.clone()))
                } else {
                    Ok(Type::Set(Arc::from_iter([u.clone(), t.clone()])))
                }
            }
            // Two sets: concatenate the members. No deduplication happens at
            // this level.
            (Type::Set(s0), Type::Set(s1)) => Ok(Type::Set(Arc::from_iter(
                s0.iter().cloned().chain(s1.iter().cloned()),
            ))),
            // A set and a single type: append the type to the members.
            (Type::Set(s), t) | (t, Type::Set(s)) => Ok(Type::Set(Arc::from_iter(
                s.iter().cloned().chain(iter::once(t.clone())),
            ))),
            // Structs merge only when they compare equal.
            (u @ Type::Struct(t0), t @ Type::Struct(t1)) => {
                if t0.len() == t1.len() && t0 == t1 {
                    Ok(u.clone())
                } else {
                    Ok(Type::Set(Arc::from_iter([u.clone(), t.clone()])))
                }
            }
            // A struct paired with any remaining kind of type: two-element set.
            (u @ Type::Struct(_), t) | (t, u @ Type::Struct(_)) => {
                Ok(Type::Set(Arc::from_iter([u.clone(), t.clone()])))
            }
            // Tuples merge only when they compare equal.
            (u @ Type::Tuple(t0), t @ Type::Tuple(t1)) => {
                if t0 == t1 {
                    Ok(u.clone())
                } else {
                    Ok(Type::Set(Arc::from_iter([u.clone(), t.clone()])))
                }
            }
            // A tuple paired with any remaining kind of type: two-element set.
            (u @ Type::Tuple(_), t) | (t, u @ Type::Tuple(_)) => {
                Ok(Type::Set(Arc::from_iter([u.clone(), t.clone()])))
            }
            // Variants with the same tag and argument count are merged by
            // unioning their arguments position by position; otherwise both
            // are kept.
            (u @ Type::Variant(tg0, t0), t @ Type::Variant(tg1, t1)) => {
                if tg0 == tg1 && t0.len() == t1.len() {
                    let mut typs = t0
                        .iter()
                        .zip(t1.iter())
                        .map(|(t0, t1)| t0.union_int(env, hist, t1))
                        .collect::<Result<LPooled<Vec<_>>>>()?;
                    Ok(Type::Variant(tg0.clone(), Arc::from_iter(typs.drain(..))))
                } else {
                    Ok(Type::Set(Arc::from_iter([u.clone(), t.clone()])))
                }
            }
            // A variant paired with any remaining kind of type: two-element set.
            (u @ Type::Variant(_, _), t) | (t, u @ Type::Variant(_, _)) => {
                Ok(Type::Set(Arc::from_iter([u.clone(), t.clone()])))
            }
            // Function types merge only when they compare equal.
            (Type::Fn(f0), Type::Fn(f1)) => {
                if f0 == f1 {
                    Ok(Type::Fn(f0.clone()))
                } else {
                    Ok(Type::Set(Arc::from_iter([
                        Type::Fn(f0.clone()),
                        Type::Fn(f1.clone()),
                    ])))
                }
            }
            // A function paired with any remaining kind of type: two-element set.
            (f @ Type::Fn(_), t) | (t, f @ Type::Fn(_)) => {
                Ok(Type::Set(Arc::from_iter([f.clone(), t.clone()])))
            }
            // Type variables merge only when they compare equal.
            (t0 @ Type::TVar(_), t1 @ Type::TVar(_)) => {
                if t0 == t1 {
                    Ok(t0.clone())
                } else {
                    Ok(Type::Set(Arc::from_iter([t0.clone(), t1.clone()])))
                }
            }
            // A type variable paired with any remaining kind of type:
            // two-element set, type variable first.
            (t0 @ Type::TVar(_), t1) | (t1, t0 @ Type::TVar(_)) => {
                Ok(Type::Set(Arc::from_iter([t0.clone(), t1.clone()])))
            }
            // A by-ref type paired with whatever has not been matched above.
            (t @ Type::ByRef(_), u) | (u, t @ Type::ByRef(_)) => {
                Ok(Type::Set(Arc::from_iter([t.clone(), u.clone()])))
            }
        }
    }
209
210    pub fn union(&self, env: &Env, t: &Self) -> Result<Self> {
211        Ok(self.union_int(env, &mut LPooled::take(), t)?.normalize())
212    }
213
    /// Recursive worker behind `diff`: computes what is left of `self` after
    /// removing `t`.
    ///
    /// "Nothing left" is represented throughout as
    /// `Type::Primitive(BitFlags::empty())`.
    ///
    /// `hist` memoizes results for operand pairs that involve a `Type::Ref`,
    /// keyed by the addresses of the two resolved types. A placeholder is
    /// stored before recursing so that meeting the same pair again returns
    /// immediately.
    ///
    /// The match arms are tried top to bottom and many of their patterns
    /// overlap, so the order of the arms is part of the logic.
    ///
    /// The value returned here is not normalized; `diff` calls `normalize()`
    /// on it.
    ///
    /// # Errors
    /// Propagates errors from `lookup_ref` and from recursive calls.
    fn diff_int(
        &self,
        env: &Env,
        hist: &mut FxHashMap<(usize, usize), Type>,
        t: &Self,
    ) -> Result<Self> {
        match (self, t) {
            // Two refs with the same scope and name cancel out completely.
            // NOTE(review): type parameters are ignored here (`..`) - confirm
            // that refs differing only in parameters should cancel.
            (
                Type::Ref { scope: s0, name: n0, .. },
                Type::Ref { scope: s1, name: n1, .. },
            ) if s0 == s1 && n0 == n1 => Ok(Type::Primitive(BitFlags::empty())),
            // At least one side is a ref: resolve through `env`, then diff
            // the resolved types.
            // NOTE(review): `lookup_ref` is called on both operands even
            // though only one is known to be a ref; this presumably relies on
            // it handing back a non-ref unchanged - confirm.
            (t0 @ Type::Ref { .. }, t1) | (t0, t1 @ Type::Ref { .. }) => {
                let t0 = t0.lookup_ref(env)?;
                let t1 = t1.lookup_ref(env)?;
                // Key the history on the addresses of the resolved types.
                let t0_addr = (t0 as *const Type).addr();
                let t1_addr = (t1 as *const Type).addr();
                match hist.get(&(t0_addr, t1_addr)) {
                    Some(r) => Ok(r.clone()),
                    None => {
                        // Placeholder while the recursion is in progress: the
                        // empty type.
                        let r = Type::Primitive(BitFlags::empty());
                        hist.insert((t0_addr, t1_addr), r);
                        match t0.diff_int(env, hist, &t1) {
                            Ok(r) => {
                                // Replace the placeholder with the real result.
                                hist.insert((t0_addr, t1_addr), r.clone());
                                Ok(r)
                            }
                            Err(e) => {
                                // Do not leave the placeholder behind on failure.
                                hist.remove(&(t0_addr, t1_addr));
                                Err(e)
                            }
                        }
                    }
                }
            }
            // Set minus set: every member of `s0` has each member of `s1`
            // subtracted from it in turn; the survivors are combined by
            // `flatten_set`.
            (Type::Set(s0), Type::Set(s1)) => {
                let mut s: LPooled<Vec<Type>> = LPooled::take();
                for i in 0..s0.len() {
                    // `s[i]` starts as a copy of `s0[i]` and is narrowed in place.
                    s.push(s0[i].clone());
                    for j in 0..s1.len() {
                        s[i] = s[i].diff_int(env, hist, &s1[j])?
                    }
                }
                Ok(Self::flatten_set(s.drain(..)))
            }
            // Set minus a single type: subtract `t` from every member.
            (Type::Set(s), t) => Ok(Self::flatten_set(
                s.iter()
                    .map(|s| s.diff_int(env, hist, t))
                    .collect::<Result<LPooled<Vec<_>>>>()?
                    .drain(..),
            )),
            // Single type minus a set: subtract the members one after another.
            (t, Type::Set(s)) => {
                let mut t = t.clone();
                for st in s.iter() {
                    t = t.diff_int(env, hist, st)?;
                }
                Ok(t)
            }
            // Tuples, structs, variants and maps are all-or-nothing: an equal
            // right-hand side removes everything, anything else removes nothing.
            (Type::Tuple(t0), Type::Tuple(t1)) => {
                if t0 == t1 {
                    Ok(Type::Primitive(BitFlags::empty()))
                } else {
                    Ok(self.clone())
                }
            }
            (Type::Struct(t0), Type::Struct(t1)) => {
                if t0.len() == t1.len() && t0 == t1 {
                    Ok(Type::Primitive(BitFlags::empty()))
                } else {
                    Ok(self.clone())
                }
            }
            (Type::Variant(tg0, t0), Type::Variant(tg1, t1)) => {
                if tg0 == tg1 && t0.len() == t1.len() && t0 == t1 {
                    Ok(Type::Primitive(BitFlags::empty()))
                } else {
                    Ok(self.clone())
                }
            }
            (Type::Map { key: k0, value: v0 }, Type::Map { key: k1, value: v1 }) => {
                if k0 == k1 && v0 == v1 {
                    Ok(Type::Primitive(BitFlags::empty()))
                } else {
                    Ok(self.clone())
                }
            }
            // A map is removed entirely by a primitive set containing `Typ::Map`.
            (Type::Map { .. }, Type::Primitive(p)) => {
                if p.contains(Typ::Map) {
                    Ok(Type::Primitive(BitFlags::empty()))
                } else {
                    Ok(self.clone())
                }
            }
            // Only a map of `Any` to `Any` removes the `Map` flag from a
            // primitive set; a more specific map leaves the set untouched.
            (Type::Primitive(p), Type::Map { key, value }) => {
                if **key == Type::Any && **value == Type::Any {
                    let mut p = *p;
                    p.remove(Typ::Map);
                    Ok(Type::Primitive(p))
                } else {
                    Ok(Type::Primitive(*p))
                }
            }
            // Function types: equal removes everything, otherwise nothing.
            (Type::Fn(f0), Type::Fn(f1)) => {
                if f0 == f1 {
                    Ok(Type::Primitive(BitFlags::empty()))
                } else {
                    Ok(Type::Fn(f0.clone()))
                }
            }
            (Type::TVar(tv0), Type::TVar(tv1)) => {
                // Both variables point at the same underlying cell: they
                // cancel out.
                if tv0.read().typ.as_ptr() == tv1.read().typ.as_ptr() {
                    return Ok(Type::Primitive(BitFlags::empty()));
                }
                Ok(match (&*tv0.read().typ.read(), &*tv1.read().typ.read()) {
                    // If either variable is unbound the left one is returned
                    // unchanged.
                    (None, _) | (_, None) => Type::TVar(tv0.clone()),
                    // Both bound: diff the bound types. The read guards taken
                    // in the scrutinee stay alive across this recursive call.
                    (Some(t0), Some(t1)) => t0.diff_int(env, hist, t1)?,
                })
            }
            // Bound type variable on the left: diff what it is bound to;
            // unbound: nothing can be removed.
            (Type::TVar(tv), t) => Ok(match &*tv.read().typ.read() {
                Some(tv) => tv.diff_int(env, hist, t)?,
                None => self.clone(),
            }),
            // Bound type variable on the right: subtract what it is bound to;
            // unbound: `self` is returned unchanged.
            (t, Type::TVar(tv)) => Ok(match &*tv.read().typ.read() {
                Some(tv) => t.diff_int(env, hist, tv)?,
                None => self.clone(),
            }),
            // Arrays: equal element types cancel; otherwise the result is an
            // array of the element-type difference.
            (Type::Array(t0), Type::Array(t1)) => {
                if t0 == t1 {
                    Ok(Type::Primitive(BitFlags::empty()))
                } else {
                    Ok(Type::Array(Arc::new(t0.diff_int(env, hist, t1)?)))
                }
            }
            // Only an array of `Any` removes the `Array` flag from a
            // primitive set.
            (Type::Primitive(p), Type::Array(t)) => {
                if &**t == &Type::Any {
                    let mut s = *p;
                    s.remove(Typ::Array);
                    Ok(Type::Primitive(s))
                } else {
                    Ok(Type::Primitive(*p))
                }
            }
            // Arrays, structs, tuples and variants are removed entirely by a
            // primitive set containing `Typ::Array`.
            (
                Type::Array(_) | Type::Struct(_) | Type::Tuple(_) | Type::Variant(_, _),
                Type::Primitive(p),
            ) => {
                if p.contains(Typ::Array) {
                    Ok(Type::Primitive(BitFlags::empty()))
                } else {
                    Ok(self.clone())
                }
            }
            // Subtracting `Any` leaves nothing (for the pairings that reach
            // this arm).
            (_, Type::Any) => Ok(Type::Primitive(BitFlags::empty())),
            // `Any` minus anything that reaches this arm is still `Any`.
            (Type::Any, _) => Ok(Type::Any),
            // Two primitive sets: clear the right-hand flags from the left.
            (Type::Primitive(s0), Type::Primitive(s1)) => {
                let mut s = *s0;
                s.remove(*s1);
                Ok(Type::Primitive(s))
            }
            // Only an error of `Any` removes the `Error` flag from a
            // primitive set.
            (Type::Primitive(p), Type::Error(e)) => {
                if &**e == &Type::Any {
                    let mut s = *p;
                    s.remove(Typ::Error);
                    Ok(Type::Primitive(s))
                } else {
                    Ok(Type::Primitive(*p))
                }
            }
            // An error type is removed entirely by a primitive set containing
            // `Typ::Error`.
            (Type::Error(_), Type::Primitive(p)) => {
                if p.contains(Typ::Error) {
                    Ok(Type::Primitive(BitFlags::empty()))
                } else {
                    Ok(self.clone())
                }
            }
            // Errors: equal payloads cancel; otherwise the result is an error
            // of the payload difference.
            (Type::Error(e0), Type::Error(e1)) => {
                if e0 == e1 {
                    Ok(Type::Primitive(BitFlags::empty()))
                } else {
                    Ok(Type::Error(Arc::new(e0.diff_int(env, hist, e1)?)))
                }
            }
            // By-ref types: always a by-ref of the target difference.
            // NOTE(review): unlike the `Array` and `Error` arms there is no
            // equality shortcut here, so identical by-ref types yield a
            // `ByRef` wrapping the inner result rather than the empty type -
            // confirm this is intended.
            (Type::ByRef(t0), Type::ByRef(t1)) => {
                Ok(Type::ByRef(Arc::new(t0.diff_int(env, hist, t1)?)))
            }
            // Identical abstract types cancel out.
            (
                Type::Abstract { id: id0, params: p0 },
                Type::Abstract { id: id1, params: p1 },
            ) if id0 == id1 && p0 == p1 => Ok(Type::Primitive(BitFlags::empty())),
            // Every remaining pairing: nothing is removed, `self` comes back
            // unchanged.
            (Type::Abstract { .. }, _)
            | (_, Type::Abstract { .. })
            | (Type::Fn(_), _)
            | (_, Type::Fn(_))
            | (Type::Array(_), _)
            | (_, Type::Array(_))
            | (Type::Tuple(_), _)
            | (_, Type::Tuple(_))
            | (Type::Struct(_), _)
            | (_, Type::Struct(_))
            | (Type::Variant(_, _), _)
            | (_, Type::Variant(_, _))
            | (Type::ByRef(_), _)
            | (_, Type::ByRef(_))
            | (Type::Error(_), _)
            | (_, Type::Error(_))
            | (Type::Primitive(_), _)
            | (_, Type::Primitive(_))
            | (Type::Bottom, _)
            | (Type::Map { .. }, _) => Ok(self.clone()),
        }
    }
424
425    pub fn diff(&self, env: &Env, t: &Self) -> Result<Self> {
426        Ok(self.diff_int(env, &mut LPooled::take(), t)?.normalize())
427    }
428}