Skip to main content

graphix_compiler/typ/
setops.rs

1use crate::{
2    env::Env,
3    typ::{RefHist, Type},
4};
5use anyhow::Result;
6use enumflags2::BitFlags;
7use fxhash::FxHashMap;
8use netidx::publisher::Typ;
9use poolshark::local::LPooled;
10use std::iter;
11use triomphe::Arc;
12
13impl Type {
14    fn union_int(
15        &self,
16        env: &Env,
17        hist: &mut RefHist<FxHashMap<(Option<usize>, Option<usize>), Type>>,
18        t: &Self,
19    ) -> Result<Self> {
20        match (self, t) {
21            (
22                Type::Ref { scope: s0, name: n0, params: p0 },
23                Type::Ref { scope: s1, name: n1, params: p1 },
24            ) if n0 == n1 && s0 == s1 && p0.len() == p1.len() => {
25                let mut params = p0
26                    .iter()
27                    .zip(p1.iter())
28                    .map(|(p0, p1)| p0.union_int(env, hist, p1))
29                    .collect::<Result<LPooled<Vec<_>>>>()?;
30                let params = Arc::from_iter(params.drain(..));
31                Ok(Self::Ref { scope: s0.clone(), name: n0.clone(), params })
32            }
33            (tr @ Type::Ref { .. }, t) => {
34                let t0_id = hist.ref_id(tr, env);
35                let t_id = hist.ref_id(t, env);
36                let t0 = tr.lookup_ref(env)?;
37                match hist.get(&(t0_id, t_id)) {
38                    Some(t) => Ok(t.clone()),
39                    None => {
40                        hist.insert((t0_id, t_id), tr.clone());
41                        let r = t0.union_int(env, hist, t);
42                        hist.remove(&(t0_id, t_id));
43                        r
44                    }
45                }
46            }
47            (t, tr @ Type::Ref { .. }) => {
48                let t_id = hist.ref_id(t, env);
49                let t1_id = hist.ref_id(tr, env);
50                let t1 = tr.lookup_ref(env)?;
51                match hist.get(&(t_id, t1_id)) {
52                    Some(t) => Ok(t.clone()),
53                    None => {
54                        hist.insert((t_id, t1_id), tr.clone());
55                        let r = t.union_int(env, hist, &t1);
56                        hist.remove(&(t_id, t1_id));
57                        r
58                    }
59                }
60            }
61            (
62                Type::Abstract { id: id0, params: p0 },
63                Type::Abstract { id: id1, params: p1 },
64            ) if id0 == id1 && p0 == p1 => Ok(self.clone()),
65            (t0 @ Type::Abstract { .. }, t1) | (t0, t1 @ Type::Abstract { .. }) => {
66                Ok(Type::Set(Arc::from_iter([t0.clone(), t1.clone()])))
67            }
68            (Type::Bottom, t) | (t, Type::Bottom) => Ok(t.clone()),
69            (Type::Any, _) | (_, Type::Any) => Ok(Type::Any),
70            (Type::Primitive(p), t) | (t, Type::Primitive(p)) if p.is_empty() => {
71                Ok(t.clone())
72            }
73            (Type::Primitive(s0), Type::Primitive(s1)) => {
74                let mut s = *s0;
75                s.insert(*s1);
76                Ok(Type::Primitive(s))
77            }
78            (
79                Type::Primitive(p),
80                Type::Array(_) | Type::Struct(_) | Type::Tuple(_) | Type::Variant(_, _),
81            )
82            | (
83                Type::Array(_) | Type::Struct(_) | Type::Tuple(_) | Type::Variant(_, _),
84                Type::Primitive(p),
85            ) if p.contains(Typ::Array) => Ok(Type::Primitive(*p)),
86            (Type::Primitive(p), Type::Array(t))
87            | (Type::Array(t), Type::Primitive(p)) => Ok(Type::Set(Arc::from_iter([
88                Type::Primitive(*p),
89                Type::Array(t.clone()),
90            ]))),
91            (t @ Type::Array(t0), u @ Type::Array(t1)) => {
92                if t0 == t1 {
93                    Ok(Type::Array(t0.clone()))
94                } else {
95                    Ok(Type::Set(Arc::from_iter([t.clone(), u.clone()])))
96                }
97            }
98            (Type::Primitive(p), Type::Map { .. })
99            | (Type::Map { .. }, Type::Primitive(p))
100                if p.contains(Typ::Map) =>
101            {
102                Ok(Type::Primitive(*p))
103            }
104            (Type::Primitive(p), Type::Map { key, value })
105            | (Type::Map { key, value }, Type::Primitive(p)) => {
106                Ok(Type::Set(Arc::from_iter([
107                    Type::Primitive(*p),
108                    Type::Map { key: key.clone(), value: value.clone() },
109                ])))
110            }
111            (
112                t @ Type::Map { key: k0, value: v0 },
113                u @ Type::Map { key: k1, value: v1 },
114            ) => {
115                if k0 == k1 && v0 == v1 {
116                    Ok(Type::Map { key: k0.clone(), value: v0.clone() })
117                } else {
118                    Ok(Type::Set(Arc::from_iter([t.clone(), u.clone()])))
119                }
120            }
121            (t @ Type::Map { .. }, u) | (u, t @ Type::Map { .. }) => {
122                Ok(Type::Set(Arc::from_iter([t.clone(), u.clone()])))
123            }
124            (Type::Primitive(p), Type::Error(_))
125            | (Type::Error(_), Type::Primitive(p))
126                if p.contains(Typ::Error) =>
127            {
128                Ok(Type::Primitive(*p))
129            }
130            (Type::Error(e0), Type::Error(e1)) => {
131                Ok(Type::Error(Arc::new(e0.union_int(env, hist, e1)?)))
132            }
133            (e @ Type::Error(_), t) | (t, e @ Type::Error(_)) => {
134                Ok(Type::Set(Arc::from_iter([e.clone(), t.clone()])))
135            }
136            (t @ Type::ByRef(t0), u @ Type::ByRef(t1)) => {
137                if t0 == t1 {
138                    Ok(Type::ByRef(t0.clone()))
139                } else {
140                    Ok(Type::Set(Arc::from_iter([u.clone(), t.clone()])))
141                }
142            }
143            (Type::Set(s0), Type::Set(s1)) => Ok(Type::Set(Arc::from_iter(
144                s0.iter().cloned().chain(s1.iter().cloned()),
145            ))),
146            (Type::Set(s), t) | (t, Type::Set(s)) => Ok(Type::Set(Arc::from_iter(
147                s.iter().cloned().chain(iter::once(t.clone())),
148            ))),
149            (u @ Type::Struct(t0), t @ Type::Struct(t1)) => {
150                if t0.len() == t1.len() && t0 == t1 {
151                    Ok(u.clone())
152                } else {
153                    Ok(Type::Set(Arc::from_iter([u.clone(), t.clone()])))
154                }
155            }
156            (u @ Type::Struct(_), t) | (t, u @ Type::Struct(_)) => {
157                Ok(Type::Set(Arc::from_iter([u.clone(), t.clone()])))
158            }
159            (u @ Type::Tuple(t0), t @ Type::Tuple(t1)) => {
160                if t0 == t1 {
161                    Ok(u.clone())
162                } else {
163                    Ok(Type::Set(Arc::from_iter([u.clone(), t.clone()])))
164                }
165            }
166            (u @ Type::Tuple(_), t) | (t, u @ Type::Tuple(_)) => {
167                Ok(Type::Set(Arc::from_iter([u.clone(), t.clone()])))
168            }
169            (u @ Type::Variant(tg0, t0), t @ Type::Variant(tg1, t1)) => {
170                if tg0 == tg1 && t0.len() == t1.len() {
171                    let mut typs = t0
172                        .iter()
173                        .zip(t1.iter())
174                        .map(|(t0, t1)| t0.union_int(env, hist, t1))
175                        .collect::<Result<LPooled<Vec<_>>>>()?;
176                    Ok(Type::Variant(tg0.clone(), Arc::from_iter(typs.drain(..))))
177                } else {
178                    Ok(Type::Set(Arc::from_iter([u.clone(), t.clone()])))
179                }
180            }
181            (u @ Type::Variant(_, _), t) | (t, u @ Type::Variant(_, _)) => {
182                Ok(Type::Set(Arc::from_iter([u.clone(), t.clone()])))
183            }
184            (Type::Fn(f0), Type::Fn(f1)) => {
185                if f0 == f1 {
186                    Ok(Type::Fn(f0.clone()))
187                } else {
188                    Ok(Type::Set(Arc::from_iter([
189                        Type::Fn(f0.clone()),
190                        Type::Fn(f1.clone()),
191                    ])))
192                }
193            }
194            (f @ Type::Fn(_), t) | (t, f @ Type::Fn(_)) => {
195                Ok(Type::Set(Arc::from_iter([f.clone(), t.clone()])))
196            }
197            (t0 @ Type::TVar(_), t1 @ Type::TVar(_)) => {
198                if t0 == t1 {
199                    Ok(t0.clone())
200                } else {
201                    Ok(Type::Set(Arc::from_iter([t0.clone(), t1.clone()])))
202                }
203            }
204            (t0 @ Type::TVar(_), t1) | (t1, t0 @ Type::TVar(_)) => {
205                Ok(Type::Set(Arc::from_iter([t0.clone(), t1.clone()])))
206            }
207            (t @ Type::ByRef(_), u) | (u, t @ Type::ByRef(_)) => {
208                Ok(Type::Set(Arc::from_iter([t.clone(), u.clone()])))
209            }
210        }
211    }
212
213    pub fn union(&self, env: &Env, t: &Self) -> Result<Self> {
214        Ok(self.union_int(env, &mut RefHist::new(LPooled::take()), t)?.normalize())
215    }
216
217    fn diff_int(
218        &self,
219        env: &Env,
220        hist: &mut RefHist<FxHashMap<(Option<usize>, Option<usize>), Type>>,
221        t: &Self,
222    ) -> Result<Self> {
223        match (self, t) {
224            (
225                Type::Ref { scope: s0, name: n0, .. },
226                Type::Ref { scope: s1, name: n1, .. },
227            ) if s0 == s1 && n0 == n1 => Ok(Type::Primitive(BitFlags::empty())),
228            (t0 @ Type::Ref { .. }, t1) | (t0, t1 @ Type::Ref { .. }) => {
229                let t0_id = hist.ref_id(t0, env);
230                let t1_id = hist.ref_id(t1, env);
231                let t0 = t0.lookup_ref(env)?;
232                let t1 = t1.lookup_ref(env)?;
233                match hist.get(&(t0_id, t1_id)) {
234                    Some(r) => Ok(r.clone()),
235                    None => {
236                        let r = Type::Primitive(BitFlags::empty());
237                        hist.insert((t0_id, t1_id), r);
238                        let r = t0.diff_int(env, hist, &t1);
239                        hist.remove(&(t0_id, t1_id));
240                        r
241                    }
242                }
243            }
244            (Type::Set(s0), Type::Set(s1)) => {
245                let mut s: LPooled<Vec<Type>> = LPooled::take();
246                for i in 0..s0.len() {
247                    s.push(s0[i].clone());
248                    for j in 0..s1.len() {
249                        s[i] = s[i].diff_int(env, hist, &s1[j])?
250                    }
251                }
252                Ok(Self::flatten_set(s.drain(..)))
253            }
254            (Type::Set(s), t) => Ok(Self::flatten_set(
255                s.iter()
256                    .map(|s| s.diff_int(env, hist, t))
257                    .collect::<Result<LPooled<Vec<_>>>>()?
258                    .drain(..),
259            )),
260            (t, Type::Set(s)) => {
261                let mut t = t.clone();
262                for st in s.iter() {
263                    t = t.diff_int(env, hist, st)?;
264                }
265                Ok(t)
266            }
267            (Type::Tuple(t0), Type::Tuple(t1)) => {
268                if t0 == t1 {
269                    Ok(Type::Primitive(BitFlags::empty()))
270                } else {
271                    Ok(self.clone())
272                }
273            }
274            (Type::Struct(t0), Type::Struct(t1)) => {
275                if t0.len() == t1.len() && t0 == t1 {
276                    Ok(Type::Primitive(BitFlags::empty()))
277                } else {
278                    Ok(self.clone())
279                }
280            }
281            (Type::Variant(tg0, t0), Type::Variant(tg1, t1)) => {
282                if tg0 == tg1 && t0.len() == t1.len() && t0 == t1 {
283                    Ok(Type::Primitive(BitFlags::empty()))
284                } else {
285                    Ok(self.clone())
286                }
287            }
288            (Type::Map { key: k0, value: v0 }, Type::Map { key: k1, value: v1 }) => {
289                if k0 == k1 && v0 == v1 {
290                    Ok(Type::Primitive(BitFlags::empty()))
291                } else {
292                    Ok(self.clone())
293                }
294            }
295            (Type::Map { .. }, Type::Primitive(p)) => {
296                if p.contains(Typ::Map) {
297                    Ok(Type::Primitive(BitFlags::empty()))
298                } else {
299                    Ok(self.clone())
300                }
301            }
302            (Type::Primitive(p), Type::Map { key, value }) => {
303                if **key == Type::Any && **value == Type::Any {
304                    let mut p = *p;
305                    p.remove(Typ::Map);
306                    Ok(Type::Primitive(p))
307                } else {
308                    Ok(Type::Primitive(*p))
309                }
310            }
311            (Type::Fn(f0), Type::Fn(f1)) => {
312                if f0 == f1 {
313                    Ok(Type::Primitive(BitFlags::empty()))
314                } else {
315                    Ok(Type::Fn(f0.clone()))
316                }
317            }
318            (Type::TVar(tv0), t1 @ Type::TVar(tv1)) => {
319                if Arc::ptr_eq(&tv0.read().typ, &tv1.read().typ) {
320                    return Ok(Type::Primitive(BitFlags::empty()));
321                }
322                Ok(match (&*tv0.read().typ.read(), &*tv1.read().typ.read()) {
323                    (None, _) => Type::TVar(tv0.clone()),
324                    (Some(t0), None) => t0.diff_int(env, hist, t1)?,
325                    (Some(t0), Some(t1)) => t0.diff_int(env, hist, t1)?,
326                })
327            }
328            (Type::TVar(tv), t) => Ok(match &*tv.read().typ.read() {
329                Some(tv) => tv.diff_int(env, hist, t)?,
330                None => self.clone(),
331            }),
332            (t, Type::TVar(tv)) => Ok(match &*tv.read().typ.read() {
333                Some(tv) => t.diff_int(env, hist, tv)?,
334                None => self.clone(),
335            }),
336            (Type::Array(t0), Type::Array(t1)) => {
337                if t0 == t1 {
338                    Ok(Type::Primitive(BitFlags::empty()))
339                } else {
340                    Ok(Type::Array(Arc::new(t0.diff_int(env, hist, t1)?)))
341                }
342            }
343            (Type::Primitive(p), Type::Array(t)) => {
344                if &**t == &Type::Any {
345                    let mut s = *p;
346                    s.remove(Typ::Array);
347                    Ok(Type::Primitive(s))
348                } else {
349                    Ok(Type::Primitive(*p))
350                }
351            }
352            (
353                Type::Array(_) | Type::Struct(_) | Type::Tuple(_) | Type::Variant(_, _),
354                Type::Primitive(p),
355            ) => {
356                if p.contains(Typ::Array) {
357                    Ok(Type::Primitive(BitFlags::empty()))
358                } else {
359                    Ok(self.clone())
360                }
361            }
362            (_, Type::Any) => Ok(Type::Primitive(BitFlags::empty())),
363            (Type::Any, _) => Ok(Type::Any),
364            (Type::Primitive(s0), Type::Primitive(s1)) => {
365                let mut s = *s0;
366                s.remove(*s1);
367                Ok(Type::Primitive(s))
368            }
369            (Type::Primitive(p), Type::Error(e)) => {
370                if &**e == &Type::Any {
371                    let mut s = *p;
372                    s.remove(Typ::Error);
373                    Ok(Type::Primitive(s))
374                } else {
375                    Ok(Type::Primitive(*p))
376                }
377            }
378            (Type::Error(_), Type::Primitive(p)) => {
379                if p.contains(Typ::Error) {
380                    Ok(Type::Primitive(BitFlags::empty()))
381                } else {
382                    Ok(self.clone())
383                }
384            }
385            (Type::Error(e0), Type::Error(e1)) => {
386                if e0 == e1 {
387                    Ok(Type::Primitive(BitFlags::empty()))
388                } else {
389                    Ok(Type::Error(Arc::new(e0.diff_int(env, hist, e1)?)))
390                }
391            }
392            (Type::ByRef(t0), Type::ByRef(t1)) => {
393                Ok(Type::ByRef(Arc::new(t0.diff_int(env, hist, t1)?)))
394            }
395            (
396                Type::Abstract { id: id0, params: p0 },
397                Type::Abstract { id: id1, params: p1 },
398            ) if id0 == id1 && p0 == p1 => Ok(Type::Primitive(BitFlags::empty())),
399            (Type::Abstract { .. }, _)
400            | (_, Type::Abstract { .. })
401            | (Type::Fn(_), _)
402            | (_, Type::Fn(_))
403            | (Type::Array(_), _)
404            | (_, Type::Array(_))
405            | (Type::Tuple(_), _)
406            | (_, Type::Tuple(_))
407            | (Type::Struct(_), _)
408            | (_, Type::Struct(_))
409            | (Type::Variant(_, _), _)
410            | (_, Type::Variant(_, _))
411            | (Type::ByRef(_), _)
412            | (_, Type::ByRef(_))
413            | (Type::Error(_), _)
414            | (_, Type::Error(_))
415            | (Type::Primitive(_), _)
416            | (_, Type::Primitive(_))
417            | (Type::Bottom, _)
418            | (Type::Map { .. }, _) => Ok(self.clone()),
419        }
420    }
421
422    pub fn diff(&self, env: &Env, t: &Self) -> Result<Self> {
423        Ok(self.diff_int(env, &mut RefHist::new(LPooled::take()), t)?.normalize())
424    }
425}