1use crate::{env::Env, typ::{RefHist, Type}};
2use anyhow::Result;
3use enumflags2::BitFlags;
4use fxhash::FxHashMap;
5use netidx::publisher::Typ;
6use poolshark::local::LPooled;
7use std::iter;
8use triomphe::Arc;
9
10impl Type {
11 fn union_int(
12 &self,
13 env: &Env,
14 hist: &mut RefHist<FxHashMap<(Option<usize>, Option<usize>), Type>>,
15 t: &Self,
16 ) -> Result<Self> {
17 match (self, t) {
18 (
19 Type::Ref { scope: s0, name: n0, params: p0 },
20 Type::Ref { scope: s1, name: n1, params: p1 },
21 ) if n0 == n1 && s0 == s1 && p0.len() == p1.len() => {
22 let mut params = p0
23 .iter()
24 .zip(p1.iter())
25 .map(|(p0, p1)| p0.union_int(env, hist, p1))
26 .collect::<Result<LPooled<Vec<_>>>>()?;
27 let params = Arc::from_iter(params.drain(..));
28 Ok(Self::Ref { scope: s0.clone(), name: n0.clone(), params })
29 }
30 (tr @ Type::Ref { .. }, t) => {
31 let t0_id = hist.ref_id(tr, env);
32 let t_id = hist.ref_id(t, env);
33 let t0 = tr.lookup_ref(env)?;
34 match hist.get(&(t0_id, t_id)) {
35 Some(t) => Ok(t.clone()),
36 None => {
37 hist.insert((t0_id, t_id), tr.clone());
38 let r = t0.union_int(env, hist, t);
39 hist.remove(&(t0_id, t_id));
40 r
41 }
42 }
43 }
44 (t, tr @ Type::Ref { .. }) => {
45 let t_id = hist.ref_id(t, env);
46 let t1_id = hist.ref_id(tr, env);
47 let t1 = tr.lookup_ref(env)?;
48 match hist.get(&(t_id, t1_id)) {
49 Some(t) => Ok(t.clone()),
50 None => {
51 hist.insert((t_id, t1_id), tr.clone());
52 let r = t.union_int(env, hist, &t1);
53 hist.remove(&(t_id, t1_id));
54 r
55 }
56 }
57 }
58 (
59 Type::Abstract { id: id0, params: p0 },
60 Type::Abstract { id: id1, params: p1 },
61 ) if id0 == id1 && p0 == p1 => Ok(self.clone()),
62 (t0 @ Type::Abstract { .. }, t1) | (t0, t1 @ Type::Abstract { .. }) => {
63 Ok(Type::Set(Arc::from_iter([t0.clone(), t1.clone()])))
64 }
65 (Type::Bottom, t) | (t, Type::Bottom) => Ok(t.clone()),
66 (Type::Any, _) | (_, Type::Any) => Ok(Type::Any),
67 (Type::Primitive(p), t) | (t, Type::Primitive(p)) if p.is_empty() => {
68 Ok(t.clone())
69 }
70 (Type::Primitive(s0), Type::Primitive(s1)) => {
71 let mut s = *s0;
72 s.insert(*s1);
73 Ok(Type::Primitive(s))
74 }
75 (
76 Type::Primitive(p),
77 Type::Array(_) | Type::Struct(_) | Type::Tuple(_) | Type::Variant(_, _),
78 )
79 | (
80 Type::Array(_) | Type::Struct(_) | Type::Tuple(_) | Type::Variant(_, _),
81 Type::Primitive(p),
82 ) if p.contains(Typ::Array) => Ok(Type::Primitive(*p)),
83 (Type::Primitive(p), Type::Array(t))
84 | (Type::Array(t), Type::Primitive(p)) => Ok(Type::Set(Arc::from_iter([
85 Type::Primitive(*p),
86 Type::Array(t.clone()),
87 ]))),
88 (t @ Type::Array(t0), u @ Type::Array(t1)) => {
89 if t0 == t1 {
90 Ok(Type::Array(t0.clone()))
91 } else {
92 Ok(Type::Set(Arc::from_iter([t.clone(), u.clone()])))
93 }
94 }
95 (Type::Primitive(p), Type::Map { .. })
96 | (Type::Map { .. }, Type::Primitive(p))
97 if p.contains(Typ::Map) =>
98 {
99 Ok(Type::Primitive(*p))
100 }
101 (Type::Primitive(p), Type::Map { key, value })
102 | (Type::Map { key, value }, Type::Primitive(p)) => {
103 Ok(Type::Set(Arc::from_iter([
104 Type::Primitive(*p),
105 Type::Map { key: key.clone(), value: value.clone() },
106 ])))
107 }
108 (
109 t @ Type::Map { key: k0, value: v0 },
110 u @ Type::Map { key: k1, value: v1 },
111 ) => {
112 if k0 == k1 && v0 == v1 {
113 Ok(Type::Map { key: k0.clone(), value: v0.clone() })
114 } else {
115 Ok(Type::Set(Arc::from_iter([t.clone(), u.clone()])))
116 }
117 }
118 (t @ Type::Map { .. }, u) | (u, t @ Type::Map { .. }) => {
119 Ok(Type::Set(Arc::from_iter([t.clone(), u.clone()])))
120 }
121 (Type::Primitive(p), Type::Error(_))
122 | (Type::Error(_), Type::Primitive(p))
123 if p.contains(Typ::Error) =>
124 {
125 Ok(Type::Primitive(*p))
126 }
127 (Type::Error(e0), Type::Error(e1)) => {
128 Ok(Type::Error(Arc::new(e0.union_int(env, hist, e1)?)))
129 }
130 (e @ Type::Error(_), t) | (t, e @ Type::Error(_)) => {
131 Ok(Type::Set(Arc::from_iter([e.clone(), t.clone()])))
132 }
133 (t @ Type::ByRef(t0), u @ Type::ByRef(t1)) => {
134 if t0 == t1 {
135 Ok(Type::ByRef(t0.clone()))
136 } else {
137 Ok(Type::Set(Arc::from_iter([u.clone(), t.clone()])))
138 }
139 }
140 (Type::Set(s0), Type::Set(s1)) => Ok(Type::Set(Arc::from_iter(
141 s0.iter().cloned().chain(s1.iter().cloned()),
142 ))),
143 (Type::Set(s), t) | (t, Type::Set(s)) => Ok(Type::Set(Arc::from_iter(
144 s.iter().cloned().chain(iter::once(t.clone())),
145 ))),
146 (u @ Type::Struct(t0), t @ Type::Struct(t1)) => {
147 if t0.len() == t1.len() && t0 == t1 {
148 Ok(u.clone())
149 } else {
150 Ok(Type::Set(Arc::from_iter([u.clone(), t.clone()])))
151 }
152 }
153 (u @ Type::Struct(_), t) | (t, u @ Type::Struct(_)) => {
154 Ok(Type::Set(Arc::from_iter([u.clone(), t.clone()])))
155 }
156 (u @ Type::Tuple(t0), t @ Type::Tuple(t1)) => {
157 if t0 == t1 {
158 Ok(u.clone())
159 } else {
160 Ok(Type::Set(Arc::from_iter([u.clone(), t.clone()])))
161 }
162 }
163 (u @ Type::Tuple(_), t) | (t, u @ Type::Tuple(_)) => {
164 Ok(Type::Set(Arc::from_iter([u.clone(), t.clone()])))
165 }
166 (u @ Type::Variant(tg0, t0), t @ Type::Variant(tg1, t1)) => {
167 if tg0 == tg1 && t0.len() == t1.len() {
168 let mut typs = t0
169 .iter()
170 .zip(t1.iter())
171 .map(|(t0, t1)| t0.union_int(env, hist, t1))
172 .collect::<Result<LPooled<Vec<_>>>>()?;
173 Ok(Type::Variant(tg0.clone(), Arc::from_iter(typs.drain(..))))
174 } else {
175 Ok(Type::Set(Arc::from_iter([u.clone(), t.clone()])))
176 }
177 }
178 (u @ Type::Variant(_, _), t) | (t, u @ Type::Variant(_, _)) => {
179 Ok(Type::Set(Arc::from_iter([u.clone(), t.clone()])))
180 }
181 (Type::Fn(f0), Type::Fn(f1)) => {
182 if f0 == f1 {
183 Ok(Type::Fn(f0.clone()))
184 } else {
185 Ok(Type::Set(Arc::from_iter([
186 Type::Fn(f0.clone()),
187 Type::Fn(f1.clone()),
188 ])))
189 }
190 }
191 (f @ Type::Fn(_), t) | (t, f @ Type::Fn(_)) => {
192 Ok(Type::Set(Arc::from_iter([f.clone(), t.clone()])))
193 }
194 (t0 @ Type::TVar(_), t1 @ Type::TVar(_)) => {
195 if t0 == t1 {
196 Ok(t0.clone())
197 } else {
198 Ok(Type::Set(Arc::from_iter([t0.clone(), t1.clone()])))
199 }
200 }
201 (t0 @ Type::TVar(_), t1) | (t1, t0 @ Type::TVar(_)) => {
202 Ok(Type::Set(Arc::from_iter([t0.clone(), t1.clone()])))
203 }
204 (t @ Type::ByRef(_), u) | (u, t @ Type::ByRef(_)) => {
205 Ok(Type::Set(Arc::from_iter([t.clone(), u.clone()])))
206 }
207 }
208 }
209
210 pub fn union(&self, env: &Env, t: &Self) -> Result<Self> {
211 Ok(self.union_int(env, &mut RefHist::new(LPooled::take()), t)?.normalize())
212 }
213
214 fn diff_int(
215 &self,
216 env: &Env,
217 hist: &mut RefHist<FxHashMap<(Option<usize>, Option<usize>), Type>>,
218 t: &Self,
219 ) -> Result<Self> {
220 match (self, t) {
221 (
222 Type::Ref { scope: s0, name: n0, .. },
223 Type::Ref { scope: s1, name: n1, .. },
224 ) if s0 == s1 && n0 == n1 => Ok(Type::Primitive(BitFlags::empty())),
225 (t0 @ Type::Ref { .. }, t1) | (t0, t1 @ Type::Ref { .. }) => {
226 let t0_id = hist.ref_id(t0, env);
227 let t1_id = hist.ref_id(t1, env);
228 let t0 = t0.lookup_ref(env)?;
229 let t1 = t1.lookup_ref(env)?;
230 match hist.get(&(t0_id, t1_id)) {
231 Some(r) => Ok(r.clone()),
232 None => {
233 let r = Type::Primitive(BitFlags::empty());
234 hist.insert((t0_id, t1_id), r);
235 let r = t0.diff_int(env, hist, &t1);
236 hist.remove(&(t0_id, t1_id));
237 r
238 }
239 }
240 }
241 (Type::Set(s0), Type::Set(s1)) => {
242 let mut s: LPooled<Vec<Type>> = LPooled::take();
243 for i in 0..s0.len() {
244 s.push(s0[i].clone());
245 for j in 0..s1.len() {
246 s[i] = s[i].diff_int(env, hist, &s1[j])?
247 }
248 }
249 Ok(Self::flatten_set(s.drain(..)))
250 }
251 (Type::Set(s), t) => Ok(Self::flatten_set(
252 s.iter()
253 .map(|s| s.diff_int(env, hist, t))
254 .collect::<Result<LPooled<Vec<_>>>>()?
255 .drain(..),
256 )),
257 (t, Type::Set(s)) => {
258 let mut t = t.clone();
259 for st in s.iter() {
260 t = t.diff_int(env, hist, st)?;
261 }
262 Ok(t)
263 }
264 (Type::Tuple(t0), Type::Tuple(t1)) => {
265 if t0 == t1 {
266 Ok(Type::Primitive(BitFlags::empty()))
267 } else {
268 Ok(self.clone())
269 }
270 }
271 (Type::Struct(t0), Type::Struct(t1)) => {
272 if t0.len() == t1.len() && t0 == t1 {
273 Ok(Type::Primitive(BitFlags::empty()))
274 } else {
275 Ok(self.clone())
276 }
277 }
278 (Type::Variant(tg0, t0), Type::Variant(tg1, t1)) => {
279 if tg0 == tg1 && t0.len() == t1.len() && t0 == t1 {
280 Ok(Type::Primitive(BitFlags::empty()))
281 } else {
282 Ok(self.clone())
283 }
284 }
285 (Type::Map { key: k0, value: v0 }, Type::Map { key: k1, value: v1 }) => {
286 if k0 == k1 && v0 == v1 {
287 Ok(Type::Primitive(BitFlags::empty()))
288 } else {
289 Ok(self.clone())
290 }
291 }
292 (Type::Map { .. }, Type::Primitive(p)) => {
293 if p.contains(Typ::Map) {
294 Ok(Type::Primitive(BitFlags::empty()))
295 } else {
296 Ok(self.clone())
297 }
298 }
299 (Type::Primitive(p), Type::Map { key, value }) => {
300 if **key == Type::Any && **value == Type::Any {
301 let mut p = *p;
302 p.remove(Typ::Map);
303 Ok(Type::Primitive(p))
304 } else {
305 Ok(Type::Primitive(*p))
306 }
307 }
308 (Type::Fn(f0), Type::Fn(f1)) => {
309 if f0 == f1 {
310 Ok(Type::Primitive(BitFlags::empty()))
311 } else {
312 Ok(Type::Fn(f0.clone()))
313 }
314 }
315 (Type::TVar(tv0), Type::TVar(tv1)) => {
316 if tv0.read().typ.as_ptr() == tv1.read().typ.as_ptr() {
317 return Ok(Type::Primitive(BitFlags::empty()));
318 }
319 Ok(match (&*tv0.read().typ.read(), &*tv1.read().typ.read()) {
320 (None, _) | (_, None) => Type::TVar(tv0.clone()),
321 (Some(t0), Some(t1)) => t0.diff_int(env, hist, t1)?,
322 })
323 }
324 (Type::TVar(tv), t) => Ok(match &*tv.read().typ.read() {
325 Some(tv) => tv.diff_int(env, hist, t)?,
326 None => self.clone(),
327 }),
328 (t, Type::TVar(tv)) => Ok(match &*tv.read().typ.read() {
329 Some(tv) => t.diff_int(env, hist, tv)?,
330 None => self.clone(),
331 }),
332 (Type::Array(t0), Type::Array(t1)) => {
333 if t0 == t1 {
334 Ok(Type::Primitive(BitFlags::empty()))
335 } else {
336 Ok(Type::Array(Arc::new(t0.diff_int(env, hist, t1)?)))
337 }
338 }
339 (Type::Primitive(p), Type::Array(t)) => {
340 if &**t == &Type::Any {
341 let mut s = *p;
342 s.remove(Typ::Array);
343 Ok(Type::Primitive(s))
344 } else {
345 Ok(Type::Primitive(*p))
346 }
347 }
348 (
349 Type::Array(_) | Type::Struct(_) | Type::Tuple(_) | Type::Variant(_, _),
350 Type::Primitive(p),
351 ) => {
352 if p.contains(Typ::Array) {
353 Ok(Type::Primitive(BitFlags::empty()))
354 } else {
355 Ok(self.clone())
356 }
357 }
358 (_, Type::Any) => Ok(Type::Primitive(BitFlags::empty())),
359 (Type::Any, _) => Ok(Type::Any),
360 (Type::Primitive(s0), Type::Primitive(s1)) => {
361 let mut s = *s0;
362 s.remove(*s1);
363 Ok(Type::Primitive(s))
364 }
365 (Type::Primitive(p), Type::Error(e)) => {
366 if &**e == &Type::Any {
367 let mut s = *p;
368 s.remove(Typ::Error);
369 Ok(Type::Primitive(s))
370 } else {
371 Ok(Type::Primitive(*p))
372 }
373 }
374 (Type::Error(_), Type::Primitive(p)) => {
375 if p.contains(Typ::Error) {
376 Ok(Type::Primitive(BitFlags::empty()))
377 } else {
378 Ok(self.clone())
379 }
380 }
381 (Type::Error(e0), Type::Error(e1)) => {
382 if e0 == e1 {
383 Ok(Type::Primitive(BitFlags::empty()))
384 } else {
385 Ok(Type::Error(Arc::new(e0.diff_int(env, hist, e1)?)))
386 }
387 }
388 (Type::ByRef(t0), Type::ByRef(t1)) => {
389 Ok(Type::ByRef(Arc::new(t0.diff_int(env, hist, t1)?)))
390 }
391 (
392 Type::Abstract { id: id0, params: p0 },
393 Type::Abstract { id: id1, params: p1 },
394 ) if id0 == id1 && p0 == p1 => Ok(Type::Primitive(BitFlags::empty())),
395 (Type::Abstract { .. }, _)
396 | (_, Type::Abstract { .. })
397 | (Type::Fn(_), _)
398 | (_, Type::Fn(_))
399 | (Type::Array(_), _)
400 | (_, Type::Array(_))
401 | (Type::Tuple(_), _)
402 | (_, Type::Tuple(_))
403 | (Type::Struct(_), _)
404 | (_, Type::Struct(_))
405 | (Type::Variant(_, _), _)
406 | (_, Type::Variant(_, _))
407 | (Type::ByRef(_), _)
408 | (_, Type::ByRef(_))
409 | (Type::Error(_), _)
410 | (_, Type::Error(_))
411 | (Type::Primitive(_), _)
412 | (_, Type::Primitive(_))
413 | (Type::Bottom, _)
414 | (Type::Map { .. }, _) => Ok(self.clone()),
415 }
416 }
417
418 pub fn diff(&self, env: &Env, t: &Self) -> Result<Self> {
419 Ok(self.diff_int(env, &mut RefHist::new(LPooled::take()), t)?.normalize())
420 }
421}