Skip to main content

graphix_compiler/typ/
setops.rs

1use crate::{
2    env::Env,
3    typ::{RefHist, Type, TypeRef},
4};
5use anyhow::Result;
6use enumflags2::BitFlags;
7use fxhash::FxHashMap;
8use netidx::publisher::Typ;
9use poolshark::local::LPooled;
10use std::iter;
11use triomphe::Arc;
12
13impl Type {
14    fn union_int(
15        &self,
16        env: &Env,
17        hist: &mut RefHist<FxHashMap<(Option<usize>, Option<usize>), Type>>,
18        t: &Self,
19    ) -> Result<Self> {
20        match (self, t) {
21            (Type::Ref(t0), Type::Ref(t1))
22                if t0.name == t1.name
23                    && t0.scope == t1.scope
24                    && t0.params.len() == t1.params.len() =>
25            {
26                let mut params = t0
27                    .params
28                    .iter()
29                    .zip(t1.params.iter())
30                    .map(|(p0, p1)| p0.union_int(env, hist, p1))
31                    .collect::<Result<LPooled<Vec<_>>>>()?;
32                let params = Arc::from_iter(params.drain(..));
33                Ok(Self::Ref(TypeRef { params, ..t0.clone() }))
34            }
35            (tr @ Type::Ref (TypeRef { .. }), t) => {
36                let t0_id = hist.ref_id(tr, env);
37                let t_id = hist.ref_id(t, env);
38                let t0 = tr.lookup_ref(env)?;
39                match hist.get(&(t0_id, t_id)) {
40                    Some(t) => Ok(t.clone()),
41                    None => {
42                        hist.insert((t0_id, t_id), tr.clone());
43                        let r = t0.union_int(env, hist, t);
44                        hist.remove(&(t0_id, t_id));
45                        r
46                    }
47                }
48            }
49            (t, tr @ Type::Ref (TypeRef { .. })) => {
50                let t_id = hist.ref_id(t, env);
51                let t1_id = hist.ref_id(tr, env);
52                let t1 = tr.lookup_ref(env)?;
53                match hist.get(&(t_id, t1_id)) {
54                    Some(t) => Ok(t.clone()),
55                    None => {
56                        hist.insert((t_id, t1_id), tr.clone());
57                        let r = t.union_int(env, hist, &t1);
58                        hist.remove(&(t_id, t1_id));
59                        r
60                    }
61                }
62            }
63            (
64                Type::Abstract { id: id0, params: p0 },
65                Type::Abstract { id: id1, params: p1 },
66            ) if id0 == id1 && p0 == p1 => Ok(self.clone()),
67            (t0 @ Type::Abstract { .. }, t1) | (t0, t1 @ Type::Abstract { .. }) => {
68                Ok(Type::Set(Arc::from_iter([t0.clone(), t1.clone()])))
69            }
70            (Type::Bottom, t) | (t, Type::Bottom) => Ok(t.clone()),
71            (Type::Any, _) | (_, Type::Any) => Ok(Type::Any),
72            (Type::Primitive(p), t) | (t, Type::Primitive(p)) if p.is_empty() => {
73                Ok(t.clone())
74            }
75            (Type::Primitive(s0), Type::Primitive(s1)) => {
76                let mut s = *s0;
77                s.insert(*s1);
78                Ok(Type::Primitive(s))
79            }
80            (
81                Type::Primitive(p),
82                Type::Array(_) | Type::Struct(_) | Type::Tuple(_) | Type::Variant(_, _),
83            )
84            | (
85                Type::Array(_) | Type::Struct(_) | Type::Tuple(_) | Type::Variant(_, _),
86                Type::Primitive(p),
87            ) if p.contains(Typ::Array) => Ok(Type::Primitive(*p)),
88            (Type::Primitive(p), Type::Array(t))
89            | (Type::Array(t), Type::Primitive(p)) => Ok(Type::Set(Arc::from_iter([
90                Type::Primitive(*p),
91                Type::Array(t.clone()),
92            ]))),
93            (t @ Type::Array(t0), u @ Type::Array(t1)) => {
94                if t0 == t1 {
95                    Ok(Type::Array(t0.clone()))
96                } else {
97                    Ok(Type::Set(Arc::from_iter([t.clone(), u.clone()])))
98                }
99            }
100            (Type::Primitive(p), Type::Map { .. })
101            | (Type::Map { .. }, Type::Primitive(p))
102                if p.contains(Typ::Map) =>
103            {
104                Ok(Type::Primitive(*p))
105            }
106            (Type::Primitive(p), Type::Map { key, value })
107            | (Type::Map { key, value }, Type::Primitive(p)) => {
108                Ok(Type::Set(Arc::from_iter([
109                    Type::Primitive(*p),
110                    Type::Map { key: key.clone(), value: value.clone() },
111                ])))
112            }
113            (
114                t @ Type::Map { key: k0, value: v0 },
115                u @ Type::Map { key: k1, value: v1 },
116            ) => {
117                if k0 == k1 && v0 == v1 {
118                    Ok(Type::Map { key: k0.clone(), value: v0.clone() })
119                } else {
120                    Ok(Type::Set(Arc::from_iter([t.clone(), u.clone()])))
121                }
122            }
123            (t @ Type::Map { .. }, u) | (u, t @ Type::Map { .. }) => {
124                Ok(Type::Set(Arc::from_iter([t.clone(), u.clone()])))
125            }
126            (Type::Primitive(p), Type::Error(_))
127            | (Type::Error(_), Type::Primitive(p))
128                if p.contains(Typ::Error) =>
129            {
130                Ok(Type::Primitive(*p))
131            }
132            (Type::Error(e0), Type::Error(e1)) => {
133                Ok(Type::Error(Arc::new(e0.union_int(env, hist, e1)?)))
134            }
135            (e @ Type::Error(_), t) | (t, e @ Type::Error(_)) => {
136                Ok(Type::Set(Arc::from_iter([e.clone(), t.clone()])))
137            }
138            (t @ Type::ByRef(t0), u @ Type::ByRef(t1)) => {
139                if t0 == t1 {
140                    Ok(Type::ByRef(t0.clone()))
141                } else {
142                    Ok(Type::Set(Arc::from_iter([u.clone(), t.clone()])))
143                }
144            }
145            (Type::Set(s0), Type::Set(s1)) => Ok(Type::Set(Arc::from_iter(
146                s0.iter().cloned().chain(s1.iter().cloned()),
147            ))),
148            (Type::Set(s), t) | (t, Type::Set(s)) => Ok(Type::Set(Arc::from_iter(
149                s.iter().cloned().chain(iter::once(t.clone())),
150            ))),
151            (u @ Type::Struct(t0), t @ Type::Struct(t1)) => {
152                if t0.len() == t1.len() && t0 == t1 {
153                    Ok(u.clone())
154                } else {
155                    Ok(Type::Set(Arc::from_iter([u.clone(), t.clone()])))
156                }
157            }
158            (u @ Type::Struct(_), t) | (t, u @ Type::Struct(_)) => {
159                Ok(Type::Set(Arc::from_iter([u.clone(), t.clone()])))
160            }
161            (u @ Type::Tuple(t0), t @ Type::Tuple(t1)) => {
162                if t0 == t1 {
163                    Ok(u.clone())
164                } else {
165                    Ok(Type::Set(Arc::from_iter([u.clone(), t.clone()])))
166                }
167            }
168            (u @ Type::Tuple(_), t) | (t, u @ Type::Tuple(_)) => {
169                Ok(Type::Set(Arc::from_iter([u.clone(), t.clone()])))
170            }
171            (u @ Type::Variant(tg0, t0), t @ Type::Variant(tg1, t1)) => {
172                if tg0 == tg1 && t0.len() == t1.len() {
173                    let mut typs = t0
174                        .iter()
175                        .zip(t1.iter())
176                        .map(|(t0, t1)| t0.union_int(env, hist, t1))
177                        .collect::<Result<LPooled<Vec<_>>>>()?;
178                    Ok(Type::Variant(tg0.clone(), Arc::from_iter(typs.drain(..))))
179                } else {
180                    Ok(Type::Set(Arc::from_iter([u.clone(), t.clone()])))
181                }
182            }
183            (u @ Type::Variant(_, _), t) | (t, u @ Type::Variant(_, _)) => {
184                Ok(Type::Set(Arc::from_iter([u.clone(), t.clone()])))
185            }
186            (Type::Fn(f0), Type::Fn(f1)) => {
187                if f0 == f1 {
188                    Ok(Type::Fn(f0.clone()))
189                } else {
190                    Ok(Type::Set(Arc::from_iter([
191                        Type::Fn(f0.clone()),
192                        Type::Fn(f1.clone()),
193                    ])))
194                }
195            }
196            (f @ Type::Fn(_), t) | (t, f @ Type::Fn(_)) => {
197                Ok(Type::Set(Arc::from_iter([f.clone(), t.clone()])))
198            }
199            (t0 @ Type::TVar(_), t1 @ Type::TVar(_)) => {
200                if t0 == t1 {
201                    Ok(t0.clone())
202                } else {
203                    Ok(Type::Set(Arc::from_iter([t0.clone(), t1.clone()])))
204                }
205            }
206            (t0 @ Type::TVar(_), t1) | (t1, t0 @ Type::TVar(_)) => {
207                Ok(Type::Set(Arc::from_iter([t0.clone(), t1.clone()])))
208            }
209            (t @ Type::ByRef(_), u) | (u, t @ Type::ByRef(_)) => {
210                Ok(Type::Set(Arc::from_iter([t.clone(), u.clone()])))
211            }
212        }
213    }
214
215    pub fn union(&self, env: &Env, t: &Self) -> Result<Self> {
216        Ok(self.union_int(env, &mut RefHist::new(LPooled::take()), t)?.normalize())
217    }
218
219    fn diff_int(
220        &self,
221        env: &Env,
222        hist: &mut RefHist<FxHashMap<(Option<usize>, Option<usize>), Type>>,
223        t: &Self,
224    ) -> Result<Self> {
225        match (self, t) {
226            (
227                Type::Ref (TypeRef { scope: s0, name: n0, .. }),
228                Type::Ref (TypeRef { scope: s1, name: n1, .. }),
229            ) if s0 == s1 && n0 == n1 => Ok(Type::Primitive(BitFlags::empty())),
230            (t0 @ Type::Ref (TypeRef { .. }), t1) | (t0, t1 @ Type::Ref (TypeRef { .. })) => {
231                let t0_id = hist.ref_id(t0, env);
232                let t1_id = hist.ref_id(t1, env);
233                let t0 = t0.lookup_ref(env)?;
234                let t1 = t1.lookup_ref(env)?;
235                match hist.get(&(t0_id, t1_id)) {
236                    Some(r) => Ok(r.clone()),
237                    None => {
238                        let r = Type::Primitive(BitFlags::empty());
239                        hist.insert((t0_id, t1_id), r);
240                        let r = t0.diff_int(env, hist, &t1);
241                        hist.remove(&(t0_id, t1_id));
242                        r
243                    }
244                }
245            }
246            (Type::Set(s0), Type::Set(s1)) => {
247                let mut s: LPooled<Vec<Type>> = LPooled::take();
248                for i in 0..s0.len() {
249                    s.push(s0[i].clone());
250                    for j in 0..s1.len() {
251                        s[i] = s[i].diff_int(env, hist, &s1[j])?
252                    }
253                }
254                Ok(Self::flatten_set(s.drain(..)))
255            }
256            (Type::Set(s), t) => Ok(Self::flatten_set(
257                s.iter()
258                    .map(|s| s.diff_int(env, hist, t))
259                    .collect::<Result<LPooled<Vec<_>>>>()?
260                    .drain(..),
261            )),
262            (t, Type::Set(s)) => {
263                let mut t = t.clone();
264                for st in s.iter() {
265                    t = t.diff_int(env, hist, st)?;
266                }
267                Ok(t)
268            }
269            (Type::Tuple(t0), Type::Tuple(t1)) => {
270                if t0 == t1 {
271                    Ok(Type::Primitive(BitFlags::empty()))
272                } else {
273                    Ok(self.clone())
274                }
275            }
276            (Type::Struct(t0), Type::Struct(t1)) => {
277                if t0.len() == t1.len() && t0 == t1 {
278                    Ok(Type::Primitive(BitFlags::empty()))
279                } else {
280                    Ok(self.clone())
281                }
282            }
283            (Type::Variant(tg0, t0), Type::Variant(tg1, t1)) => {
284                if tg0 == tg1 && t0.len() == t1.len() && t0 == t1 {
285                    Ok(Type::Primitive(BitFlags::empty()))
286                } else {
287                    Ok(self.clone())
288                }
289            }
290            (Type::Map { key: k0, value: v0 }, Type::Map { key: k1, value: v1 }) => {
291                if k0 == k1 && v0 == v1 {
292                    Ok(Type::Primitive(BitFlags::empty()))
293                } else {
294                    Ok(self.clone())
295                }
296            }
297            (Type::Map { .. }, Type::Primitive(p)) => {
298                if p.contains(Typ::Map) {
299                    Ok(Type::Primitive(BitFlags::empty()))
300                } else {
301                    Ok(self.clone())
302                }
303            }
304            (Type::Primitive(p), Type::Map { key, value }) => {
305                if **key == Type::Any && **value == Type::Any {
306                    let mut p = *p;
307                    p.remove(Typ::Map);
308                    Ok(Type::Primitive(p))
309                } else {
310                    Ok(Type::Primitive(*p))
311                }
312            }
313            (Type::Fn(f0), Type::Fn(f1)) => {
314                if f0 == f1 {
315                    Ok(Type::Primitive(BitFlags::empty()))
316                } else {
317                    Ok(Type::Fn(f0.clone()))
318                }
319            }
320            (Type::TVar(tv0), t1 @ Type::TVar(tv1)) => {
321                if Arc::ptr_eq(&tv0.read().typ, &tv1.read().typ) {
322                    return Ok(Type::Primitive(BitFlags::empty()));
323                }
324                Ok(match (&*tv0.read().typ.read(), &*tv1.read().typ.read()) {
325                    (None, _) => Type::TVar(tv0.clone()),
326                    (Some(t0), None) => t0.diff_int(env, hist, t1)?,
327                    (Some(t0), Some(t1)) => t0.diff_int(env, hist, t1)?,
328                })
329            }
330            (Type::TVar(tv), t) => Ok(match &*tv.read().typ.read() {
331                Some(tv) => tv.diff_int(env, hist, t)?,
332                None => self.clone(),
333            }),
334            (t, Type::TVar(tv)) => Ok(match &*tv.read().typ.read() {
335                Some(tv) => t.diff_int(env, hist, tv)?,
336                None => self.clone(),
337            }),
338            (Type::Array(t0), Type::Array(t1)) => {
339                if t0 == t1 {
340                    Ok(Type::Primitive(BitFlags::empty()))
341                } else {
342                    Ok(Type::Array(Arc::new(t0.diff_int(env, hist, t1)?)))
343                }
344            }
345            (Type::Primitive(p), Type::Array(t)) => {
346                if &**t == &Type::Any {
347                    let mut s = *p;
348                    s.remove(Typ::Array);
349                    Ok(Type::Primitive(s))
350                } else {
351                    Ok(Type::Primitive(*p))
352                }
353            }
354            (
355                Type::Array(_) | Type::Struct(_) | Type::Tuple(_) | Type::Variant(_, _),
356                Type::Primitive(p),
357            ) => {
358                if p.contains(Typ::Array) {
359                    Ok(Type::Primitive(BitFlags::empty()))
360                } else {
361                    Ok(self.clone())
362                }
363            }
364            (_, Type::Any) => Ok(Type::Primitive(BitFlags::empty())),
365            (Type::Any, _) => Ok(Type::Any),
366            (Type::Primitive(s0), Type::Primitive(s1)) => {
367                let mut s = *s0;
368                s.remove(*s1);
369                Ok(Type::Primitive(s))
370            }
371            (Type::Primitive(p), Type::Error(e)) => {
372                if &**e == &Type::Any {
373                    let mut s = *p;
374                    s.remove(Typ::Error);
375                    Ok(Type::Primitive(s))
376                } else {
377                    Ok(Type::Primitive(*p))
378                }
379            }
380            (Type::Error(_), Type::Primitive(p)) => {
381                if p.contains(Typ::Error) {
382                    Ok(Type::Primitive(BitFlags::empty()))
383                } else {
384                    Ok(self.clone())
385                }
386            }
387            (Type::Error(e0), Type::Error(e1)) => {
388                if e0 == e1 {
389                    Ok(Type::Primitive(BitFlags::empty()))
390                } else {
391                    Ok(Type::Error(Arc::new(e0.diff_int(env, hist, e1)?)))
392                }
393            }
394            (Type::ByRef(t0), Type::ByRef(t1)) => {
395                Ok(Type::ByRef(Arc::new(t0.diff_int(env, hist, t1)?)))
396            }
397            (
398                Type::Abstract { id: id0, params: p0 },
399                Type::Abstract { id: id1, params: p1 },
400            ) if id0 == id1 && p0 == p1 => Ok(Type::Primitive(BitFlags::empty())),
401            (Type::Abstract { .. }, _)
402            | (_, Type::Abstract { .. })
403            | (Type::Fn(_), _)
404            | (_, Type::Fn(_))
405            | (Type::Array(_), _)
406            | (_, Type::Array(_))
407            | (Type::Tuple(_), _)
408            | (_, Type::Tuple(_))
409            | (Type::Struct(_), _)
410            | (_, Type::Struct(_))
411            | (Type::Variant(_, _), _)
412            | (_, Type::Variant(_, _))
413            | (Type::ByRef(_), _)
414            | (_, Type::ByRef(_))
415            | (Type::Error(_), _)
416            | (_, Type::Error(_))
417            | (Type::Primitive(_), _)
418            | (_, Type::Primitive(_))
419            | (Type::Bottom, _)
420            | (Type::Map { .. }, _) => Ok(self.clone()),
421        }
422    }
423
424    pub fn diff(&self, env: &Env, t: &Self) -> Result<Self> {
425        Ok(self.diff_int(env, &mut RefHist::new(LPooled::take()), t)?.normalize())
426    }
427}