// graphix_compiler/typ/setops.rs

1use crate::{
2    env::Env,
3    typ::{RefHist, Type, TypeRef},
4};
5use ahash::AHashMap;
6use anyhow::Result;
7use enumflags2::BitFlags;
8use netidx::publisher::Typ;
9use poolshark::local::LPooled;
10use std::iter;
11use triomphe::Arc;
12
13impl Type {
14    fn union_int(
15        &self,
16        env: &Env,
17        hist: &mut RefHist<AHashMap<(Option<usize>, Option<usize>), Type>>,
18        t: &Self,
19    ) -> Result<Self> {
20        match (self, t) {
21            (Type::Ref(t0), Type::Ref(t1))
22                if t0.name == t1.name
23                    && t0.scope == t1.scope
24                    && t0.params.len() == t1.params.len() =>
25            {
26                let mut params = t0
27                    .params
28                    .iter()
29                    .zip(t1.params.iter())
30                    .map(|(p0, p1)| p0.union_int(env, hist, p1))
31                    .collect::<Result<LPooled<Vec<_>>>>()?;
32                let params = Arc::from_iter(params.drain(..));
33                Ok(Self::Ref(TypeRef { params, ..t0.clone() }))
34            }
35            (tr @ Type::Ref(TypeRef { .. }), t) => {
36                let t0_id = hist.ref_id(tr, env);
37                let t_id = hist.ref_id(t, env);
38                let t0 = tr.lookup_ref(env)?;
39                match hist.get(&(t0_id, t_id)) {
40                    Some(t) => Ok(t.clone()),
41                    None => {
42                        hist.insert((t0_id, t_id), tr.clone());
43                        let r = t0.union_int(env, hist, t);
44                        hist.remove(&(t0_id, t_id));
45                        r
46                    }
47                }
48            }
49            (t, tr @ Type::Ref(TypeRef { .. })) => {
50                let t_id = hist.ref_id(t, env);
51                let t1_id = hist.ref_id(tr, env);
52                let t1 = tr.lookup_ref(env)?;
53                match hist.get(&(t_id, t1_id)) {
54                    Some(t) => Ok(t.clone()),
55                    None => {
56                        hist.insert((t_id, t1_id), tr.clone());
57                        let r = t.union_int(env, hist, &t1);
58                        hist.remove(&(t_id, t1_id));
59                        r
60                    }
61                }
62            }
63            (
64                Type::Abstract { id: id0, params: p0 },
65                Type::Abstract { id: id1, params: p1 },
66            ) if id0 == id1 && p0 == p1 => Ok(self.clone()),
67            (t0 @ Type::Abstract { .. }, t1) | (t0, t1 @ Type::Abstract { .. }) => {
68                Ok(Type::Set(Arc::from_iter([t0.clone(), t1.clone()])))
69            }
70            (Type::Bottom, t) | (t, Type::Bottom) => Ok(t.clone()),
71            (Type::Any, _) | (_, Type::Any) => Ok(Type::Any),
72            (Type::Primitive(p), t) | (t, Type::Primitive(p)) if p.is_empty() => {
73                Ok(t.clone())
74            }
75            (Type::Primitive(s0), Type::Primitive(s1)) => {
76                let mut s = *s0;
77                s.insert(*s1);
78                Ok(Type::Primitive(s))
79            }
80            (
81                Type::Primitive(p),
82                Type::Array(_) | Type::Struct(_) | Type::Tuple(_) | Type::Variant(_, _),
83            )
84            | (
85                Type::Array(_) | Type::Struct(_) | Type::Tuple(_) | Type::Variant(_, _),
86                Type::Primitive(p),
87            ) if p.contains(Typ::Array) => Ok(Type::Primitive(*p)),
88            (Type::Primitive(p), Type::Array(t))
89            | (Type::Array(t), Type::Primitive(p)) => Ok(Type::Set(Arc::from_iter([
90                Type::Primitive(*p),
91                Type::Array(t.clone()),
92            ]))),
93            (t @ Type::Array(t0), u @ Type::Array(t1)) => {
94                if t0 == t1 {
95                    Ok(Type::Array(t0.clone()))
96                } else {
97                    Ok(Type::Set(Arc::from_iter([t.clone(), u.clone()])))
98                }
99            }
100            (Type::Primitive(p), Type::Map { .. })
101            | (Type::Map { .. }, Type::Primitive(p))
102                if p.contains(Typ::Map) =>
103            {
104                Ok(Type::Primitive(*p))
105            }
106            (Type::Primitive(p), Type::Map { key, value })
107            | (Type::Map { key, value }, Type::Primitive(p)) => {
108                Ok(Type::Set(Arc::from_iter([
109                    Type::Primitive(*p),
110                    Type::Map { key: key.clone(), value: value.clone() },
111                ])))
112            }
113            (
114                t @ Type::Map { key: k0, value: v0 },
115                u @ Type::Map { key: k1, value: v1 },
116            ) => {
117                if k0 == k1 && v0 == v1 {
118                    Ok(Type::Map { key: k0.clone(), value: v0.clone() })
119                } else {
120                    Ok(Type::Set(Arc::from_iter([t.clone(), u.clone()])))
121                }
122            }
123            (t @ Type::Map { .. }, u) | (u, t @ Type::Map { .. }) => {
124                Ok(Type::Set(Arc::from_iter([t.clone(), u.clone()])))
125            }
126            (Type::Primitive(p), Type::Error(_))
127            | (Type::Error(_), Type::Primitive(p))
128                if p.contains(Typ::Error) =>
129            {
130                Ok(Type::Primitive(*p))
131            }
132            (Type::Error(e0), Type::Error(e1)) => {
133                Ok(Type::Error(Arc::new(e0.union_int(env, hist, e1)?)))
134            }
135            (e @ Type::Error(_), t) | (t, e @ Type::Error(_)) => {
136                Ok(Type::Set(Arc::from_iter([e.clone(), t.clone()])))
137            }
138            (t @ Type::ByRef(t0), u @ Type::ByRef(t1)) => {
139                if t0 == t1 {
140                    Ok(Type::ByRef(t0.clone()))
141                } else {
142                    Ok(Type::Set(Arc::from_iter([u.clone(), t.clone()])))
143                }
144            }
145            (Type::Set(s0), Type::Set(s1)) => Ok(Type::Set(Arc::from_iter(
146                s0.iter().cloned().chain(s1.iter().cloned()),
147            ))),
148            (Type::Set(s), t) | (t, Type::Set(s)) => Ok(Type::Set(Arc::from_iter(
149                s.iter().cloned().chain(iter::once(t.clone())),
150            ))),
151            (u @ Type::Struct(t0), t @ Type::Struct(t1)) => {
152                if t0.len() == t1.len() && t0 == t1 {
153                    Ok(u.clone())
154                } else {
155                    Ok(Type::Set(Arc::from_iter([u.clone(), t.clone()])))
156                }
157            }
158            (u @ Type::Struct(_), t) | (t, u @ Type::Struct(_)) => {
159                Ok(Type::Set(Arc::from_iter([u.clone(), t.clone()])))
160            }
161            (u @ Type::Tuple(t0), t @ Type::Tuple(t1)) => {
162                if t0 == t1 {
163                    Ok(u.clone())
164                } else {
165                    Ok(Type::Set(Arc::from_iter([u.clone(), t.clone()])))
166                }
167            }
168            (u @ Type::Tuple(_), t) | (t, u @ Type::Tuple(_)) => {
169                Ok(Type::Set(Arc::from_iter([u.clone(), t.clone()])))
170            }
171            (u @ Type::Variant(tg0, t0), t @ Type::Variant(tg1, t1)) => {
172                if tg0 == tg1 && t0.len() == t1.len() {
173                    let mut typs = t0
174                        .iter()
175                        .zip(t1.iter())
176                        .map(|(t0, t1)| t0.union_int(env, hist, t1))
177                        .collect::<Result<LPooled<Vec<_>>>>()?;
178                    Ok(Type::Variant(tg0.clone(), Arc::from_iter(typs.drain(..))))
179                } else {
180                    Ok(Type::Set(Arc::from_iter([u.clone(), t.clone()])))
181                }
182            }
183            (u @ Type::Variant(_, _), t) | (t, u @ Type::Variant(_, _)) => {
184                Ok(Type::Set(Arc::from_iter([u.clone(), t.clone()])))
185            }
186            (Type::Fn(f0), Type::Fn(f1)) => {
187                if f0 == f1 {
188                    Ok(Type::Fn(f0.clone()))
189                } else {
190                    Ok(Type::Set(Arc::from_iter([
191                        Type::Fn(f0.clone()),
192                        Type::Fn(f1.clone()),
193                    ])))
194                }
195            }
196            (f @ Type::Fn(_), t) | (t, f @ Type::Fn(_)) => {
197                Ok(Type::Set(Arc::from_iter([f.clone(), t.clone()])))
198            }
199            (t0 @ Type::TVar(_), t1 @ Type::TVar(_)) => {
200                if t0 == t1 {
201                    Ok(t0.clone())
202                } else {
203                    Ok(Type::Set(Arc::from_iter([t0.clone(), t1.clone()])))
204                }
205            }
206            (t0 @ Type::TVar(_), t1) | (t1, t0 @ Type::TVar(_)) => {
207                Ok(Type::Set(Arc::from_iter([t0.clone(), t1.clone()])))
208            }
209            (t @ Type::ByRef(_), u) | (u, t @ Type::ByRef(_)) => {
210                Ok(Type::Set(Arc::from_iter([t.clone(), u.clone()])))
211            }
212        }
213    }
214
215    pub fn union(&self, env: &Env, t: &Self) -> Result<Self> {
216        Ok(self.union_int(env, &mut RefHist::new(LPooled::take()), t)?.normalize())
217    }
218
219    fn diff_int(
220        &self,
221        env: &Env,
222        hist: &mut RefHist<AHashMap<(Option<usize>, Option<usize>), Type>>,
223        t: &Self,
224    ) -> Result<Self> {
225        match (self, t) {
226            (
227                Type::Ref(TypeRef { scope: s0, name: n0, .. }),
228                Type::Ref(TypeRef { scope: s1, name: n1, .. }),
229            ) if s0 == s1 && n0 == n1 => Ok(Type::Primitive(BitFlags::empty())),
230            (t0 @ Type::Ref(TypeRef { .. }), t1)
231            | (t0, t1 @ Type::Ref(TypeRef { .. })) => {
232                let t0_id = hist.ref_id(t0, env);
233                let t1_id = hist.ref_id(t1, env);
234                let t0 = t0.lookup_ref(env)?;
235                let t1 = t1.lookup_ref(env)?;
236                match hist.get(&(t0_id, t1_id)) {
237                    Some(r) => Ok(r.clone()),
238                    None => {
239                        let r = Type::Primitive(BitFlags::empty());
240                        hist.insert((t0_id, t1_id), r);
241                        let r = t0.diff_int(env, hist, &t1);
242                        hist.remove(&(t0_id, t1_id));
243                        r
244                    }
245                }
246            }
247            (Type::Set(s0), Type::Set(s1)) => {
248                let mut s: LPooled<Vec<Type>> = LPooled::take();
249                for i in 0..s0.len() {
250                    s.push(s0[i].clone());
251                    for j in 0..s1.len() {
252                        s[i] = s[i].diff_int(env, hist, &s1[j])?
253                    }
254                }
255                Ok(Self::flatten_set(s.drain(..)))
256            }
257            (Type::Set(s), t) => Ok(Self::flatten_set(
258                s.iter()
259                    .map(|s| s.diff_int(env, hist, t))
260                    .collect::<Result<LPooled<Vec<_>>>>()?
261                    .drain(..),
262            )),
263            (t, Type::Set(s)) => {
264                let mut t = t.clone();
265                for st in s.iter() {
266                    t = t.diff_int(env, hist, st)?;
267                }
268                Ok(t)
269            }
270            (Type::Tuple(t0), Type::Tuple(t1)) => {
271                if t0 == t1 {
272                    Ok(Type::Primitive(BitFlags::empty()))
273                } else {
274                    Ok(self.clone())
275                }
276            }
277            (Type::Struct(t0), Type::Struct(t1)) => {
278                if t0.len() == t1.len() && t0 == t1 {
279                    Ok(Type::Primitive(BitFlags::empty()))
280                } else {
281                    Ok(self.clone())
282                }
283            }
284            (Type::Variant(tg0, t0), Type::Variant(tg1, t1)) => {
285                if tg0 == tg1 && t0.len() == t1.len() && t0 == t1 {
286                    Ok(Type::Primitive(BitFlags::empty()))
287                } else {
288                    Ok(self.clone())
289                }
290            }
291            (Type::Map { key: k0, value: v0 }, Type::Map { key: k1, value: v1 }) => {
292                if k0 == k1 && v0 == v1 {
293                    Ok(Type::Primitive(BitFlags::empty()))
294                } else {
295                    Ok(self.clone())
296                }
297            }
298            (Type::Map { .. }, Type::Primitive(p)) => {
299                if p.contains(Typ::Map) {
300                    Ok(Type::Primitive(BitFlags::empty()))
301                } else {
302                    Ok(self.clone())
303                }
304            }
305            (Type::Primitive(p), Type::Map { key, value }) => {
306                if **key == Type::Any && **value == Type::Any {
307                    let mut p = *p;
308                    p.remove(Typ::Map);
309                    Ok(Type::Primitive(p))
310                } else {
311                    Ok(Type::Primitive(*p))
312                }
313            }
314            (Type::Fn(f0), Type::Fn(f1)) => {
315                if f0 == f1 {
316                    Ok(Type::Primitive(BitFlags::empty()))
317                } else {
318                    Ok(Type::Fn(f0.clone()))
319                }
320            }
321            (Type::TVar(tv0), t1 @ Type::TVar(tv1)) => {
322                if Arc::ptr_eq(&tv0.read().typ, &tv1.read().typ) {
323                    return Ok(Type::Primitive(BitFlags::empty()));
324                }
325                Ok(match (&*tv0.read().typ.read(), &*tv1.read().typ.read()) {
326                    (None, _) => Type::TVar(tv0.clone()),
327                    (Some(t0), None) => t0.diff_int(env, hist, t1)?,
328                    (Some(t0), Some(t1)) => t0.diff_int(env, hist, t1)?,
329                })
330            }
331            (Type::TVar(tv), t) => Ok(match &*tv.read().typ.read() {
332                Some(tv) => tv.diff_int(env, hist, t)?,
333                None => self.clone(),
334            }),
335            (t, Type::TVar(tv)) => Ok(match &*tv.read().typ.read() {
336                Some(tv) => t.diff_int(env, hist, tv)?,
337                None => self.clone(),
338            }),
339            (Type::Array(t0), Type::Array(t1)) => {
340                if t0 == t1 {
341                    Ok(Type::Primitive(BitFlags::empty()))
342                } else {
343                    Ok(Type::Array(Arc::new(t0.diff_int(env, hist, t1)?)))
344                }
345            }
346            (Type::Primitive(p), Type::Array(t)) => {
347                if &**t == &Type::Any {
348                    let mut s = *p;
349                    s.remove(Typ::Array);
350                    Ok(Type::Primitive(s))
351                } else {
352                    Ok(Type::Primitive(*p))
353                }
354            }
355            (
356                Type::Array(_) | Type::Struct(_) | Type::Tuple(_) | Type::Variant(_, _),
357                Type::Primitive(p),
358            ) => {
359                if p.contains(Typ::Array) {
360                    Ok(Type::Primitive(BitFlags::empty()))
361                } else {
362                    Ok(self.clone())
363                }
364            }
365            (_, Type::Any) => Ok(Type::Primitive(BitFlags::empty())),
366            (Type::Any, _) => Ok(Type::Any),
367            (Type::Primitive(s0), Type::Primitive(s1)) => {
368                let mut s = *s0;
369                s.remove(*s1);
370                Ok(Type::Primitive(s))
371            }
372            (Type::Primitive(p), Type::Error(e)) => {
373                if &**e == &Type::Any {
374                    let mut s = *p;
375                    s.remove(Typ::Error);
376                    Ok(Type::Primitive(s))
377                } else {
378                    Ok(Type::Primitive(*p))
379                }
380            }
381            (Type::Error(_), Type::Primitive(p)) => {
382                if p.contains(Typ::Error) {
383                    Ok(Type::Primitive(BitFlags::empty()))
384                } else {
385                    Ok(self.clone())
386                }
387            }
388            (Type::Error(e0), Type::Error(e1)) => {
389                if e0 == e1 {
390                    Ok(Type::Primitive(BitFlags::empty()))
391                } else {
392                    Ok(Type::Error(Arc::new(e0.diff_int(env, hist, e1)?)))
393                }
394            }
395            (Type::ByRef(t0), Type::ByRef(t1)) => {
396                Ok(Type::ByRef(Arc::new(t0.diff_int(env, hist, t1)?)))
397            }
398            (
399                Type::Abstract { id: id0, params: p0 },
400                Type::Abstract { id: id1, params: p1 },
401            ) if id0 == id1 && p0 == p1 => Ok(Type::Primitive(BitFlags::empty())),
402            (Type::Abstract { .. }, _)
403            | (_, Type::Abstract { .. })
404            | (Type::Fn(_), _)
405            | (_, Type::Fn(_))
406            | (Type::Array(_), _)
407            | (_, Type::Array(_))
408            | (Type::Tuple(_), _)
409            | (_, Type::Tuple(_))
410            | (Type::Struct(_), _)
411            | (_, Type::Struct(_))
412            | (Type::Variant(_, _), _)
413            | (_, Type::Variant(_, _))
414            | (Type::ByRef(_), _)
415            | (_, Type::ByRef(_))
416            | (Type::Error(_), _)
417            | (_, Type::Error(_))
418            | (Type::Primitive(_), _)
419            | (_, Type::Primitive(_))
420            | (Type::Bottom, _)
421            | (Type::Map { .. }, _) => Ok(self.clone()),
422        }
423    }
424
425    pub fn diff(&self, env: &Env, t: &Self) -> Result<Self> {
426        Ok(self.diff_int(env, &mut RefHist::new(LPooled::take()), t)?.normalize())
427    }
428}