1use crate::{
2 env::Env,
3 typ::{RefHist, Type},
4};
5use anyhow::Result;
6use enumflags2::BitFlags;
7use fxhash::FxHashMap;
8use netidx::publisher::Typ;
9use poolshark::local::LPooled;
10use std::iter;
11use triomphe::Arc;
12
13impl Type {
14 fn union_int(
15 &self,
16 env: &Env,
17 hist: &mut RefHist<FxHashMap<(Option<usize>, Option<usize>), Type>>,
18 t: &Self,
19 ) -> Result<Self> {
20 match (self, t) {
21 (
22 Type::Ref { scope: s0, name: n0, params: p0 },
23 Type::Ref { scope: s1, name: n1, params: p1 },
24 ) if n0 == n1 && s0 == s1 && p0.len() == p1.len() => {
25 let mut params = p0
26 .iter()
27 .zip(p1.iter())
28 .map(|(p0, p1)| p0.union_int(env, hist, p1))
29 .collect::<Result<LPooled<Vec<_>>>>()?;
30 let params = Arc::from_iter(params.drain(..));
31 Ok(Self::Ref { scope: s0.clone(), name: n0.clone(), params })
32 }
33 (tr @ Type::Ref { .. }, t) => {
34 let t0_id = hist.ref_id(tr, env);
35 let t_id = hist.ref_id(t, env);
36 let t0 = tr.lookup_ref(env)?;
37 match hist.get(&(t0_id, t_id)) {
38 Some(t) => Ok(t.clone()),
39 None => {
40 hist.insert((t0_id, t_id), tr.clone());
41 let r = t0.union_int(env, hist, t);
42 hist.remove(&(t0_id, t_id));
43 r
44 }
45 }
46 }
47 (t, tr @ Type::Ref { .. }) => {
48 let t_id = hist.ref_id(t, env);
49 let t1_id = hist.ref_id(tr, env);
50 let t1 = tr.lookup_ref(env)?;
51 match hist.get(&(t_id, t1_id)) {
52 Some(t) => Ok(t.clone()),
53 None => {
54 hist.insert((t_id, t1_id), tr.clone());
55 let r = t.union_int(env, hist, &t1);
56 hist.remove(&(t_id, t1_id));
57 r
58 }
59 }
60 }
61 (
62 Type::Abstract { id: id0, params: p0 },
63 Type::Abstract { id: id1, params: p1 },
64 ) if id0 == id1 && p0 == p1 => Ok(self.clone()),
65 (t0 @ Type::Abstract { .. }, t1) | (t0, t1 @ Type::Abstract { .. }) => {
66 Ok(Type::Set(Arc::from_iter([t0.clone(), t1.clone()])))
67 }
68 (Type::Bottom, t) | (t, Type::Bottom) => Ok(t.clone()),
69 (Type::Any, _) | (_, Type::Any) => Ok(Type::Any),
70 (Type::Primitive(p), t) | (t, Type::Primitive(p)) if p.is_empty() => {
71 Ok(t.clone())
72 }
73 (Type::Primitive(s0), Type::Primitive(s1)) => {
74 let mut s = *s0;
75 s.insert(*s1);
76 Ok(Type::Primitive(s))
77 }
78 (
79 Type::Primitive(p),
80 Type::Array(_) | Type::Struct(_) | Type::Tuple(_) | Type::Variant(_, _),
81 )
82 | (
83 Type::Array(_) | Type::Struct(_) | Type::Tuple(_) | Type::Variant(_, _),
84 Type::Primitive(p),
85 ) if p.contains(Typ::Array) => Ok(Type::Primitive(*p)),
86 (Type::Primitive(p), Type::Array(t))
87 | (Type::Array(t), Type::Primitive(p)) => Ok(Type::Set(Arc::from_iter([
88 Type::Primitive(*p),
89 Type::Array(t.clone()),
90 ]))),
91 (t @ Type::Array(t0), u @ Type::Array(t1)) => {
92 if t0 == t1 {
93 Ok(Type::Array(t0.clone()))
94 } else {
95 Ok(Type::Set(Arc::from_iter([t.clone(), u.clone()])))
96 }
97 }
98 (Type::Primitive(p), Type::Map { .. })
99 | (Type::Map { .. }, Type::Primitive(p))
100 if p.contains(Typ::Map) =>
101 {
102 Ok(Type::Primitive(*p))
103 }
104 (Type::Primitive(p), Type::Map { key, value })
105 | (Type::Map { key, value }, Type::Primitive(p)) => {
106 Ok(Type::Set(Arc::from_iter([
107 Type::Primitive(*p),
108 Type::Map { key: key.clone(), value: value.clone() },
109 ])))
110 }
111 (
112 t @ Type::Map { key: k0, value: v0 },
113 u @ Type::Map { key: k1, value: v1 },
114 ) => {
115 if k0 == k1 && v0 == v1 {
116 Ok(Type::Map { key: k0.clone(), value: v0.clone() })
117 } else {
118 Ok(Type::Set(Arc::from_iter([t.clone(), u.clone()])))
119 }
120 }
121 (t @ Type::Map { .. }, u) | (u, t @ Type::Map { .. }) => {
122 Ok(Type::Set(Arc::from_iter([t.clone(), u.clone()])))
123 }
124 (Type::Primitive(p), Type::Error(_))
125 | (Type::Error(_), Type::Primitive(p))
126 if p.contains(Typ::Error) =>
127 {
128 Ok(Type::Primitive(*p))
129 }
130 (Type::Error(e0), Type::Error(e1)) => {
131 Ok(Type::Error(Arc::new(e0.union_int(env, hist, e1)?)))
132 }
133 (e @ Type::Error(_), t) | (t, e @ Type::Error(_)) => {
134 Ok(Type::Set(Arc::from_iter([e.clone(), t.clone()])))
135 }
136 (t @ Type::ByRef(t0), u @ Type::ByRef(t1)) => {
137 if t0 == t1 {
138 Ok(Type::ByRef(t0.clone()))
139 } else {
140 Ok(Type::Set(Arc::from_iter([u.clone(), t.clone()])))
141 }
142 }
143 (Type::Set(s0), Type::Set(s1)) => Ok(Type::Set(Arc::from_iter(
144 s0.iter().cloned().chain(s1.iter().cloned()),
145 ))),
146 (Type::Set(s), t) | (t, Type::Set(s)) => Ok(Type::Set(Arc::from_iter(
147 s.iter().cloned().chain(iter::once(t.clone())),
148 ))),
149 (u @ Type::Struct(t0), t @ Type::Struct(t1)) => {
150 if t0.len() == t1.len() && t0 == t1 {
151 Ok(u.clone())
152 } else {
153 Ok(Type::Set(Arc::from_iter([u.clone(), t.clone()])))
154 }
155 }
156 (u @ Type::Struct(_), t) | (t, u @ Type::Struct(_)) => {
157 Ok(Type::Set(Arc::from_iter([u.clone(), t.clone()])))
158 }
159 (u @ Type::Tuple(t0), t @ Type::Tuple(t1)) => {
160 if t0 == t1 {
161 Ok(u.clone())
162 } else {
163 Ok(Type::Set(Arc::from_iter([u.clone(), t.clone()])))
164 }
165 }
166 (u @ Type::Tuple(_), t) | (t, u @ Type::Tuple(_)) => {
167 Ok(Type::Set(Arc::from_iter([u.clone(), t.clone()])))
168 }
169 (u @ Type::Variant(tg0, t0), t @ Type::Variant(tg1, t1)) => {
170 if tg0 == tg1 && t0.len() == t1.len() {
171 let mut typs = t0
172 .iter()
173 .zip(t1.iter())
174 .map(|(t0, t1)| t0.union_int(env, hist, t1))
175 .collect::<Result<LPooled<Vec<_>>>>()?;
176 Ok(Type::Variant(tg0.clone(), Arc::from_iter(typs.drain(..))))
177 } else {
178 Ok(Type::Set(Arc::from_iter([u.clone(), t.clone()])))
179 }
180 }
181 (u @ Type::Variant(_, _), t) | (t, u @ Type::Variant(_, _)) => {
182 Ok(Type::Set(Arc::from_iter([u.clone(), t.clone()])))
183 }
184 (Type::Fn(f0), Type::Fn(f1)) => {
185 if f0 == f1 {
186 Ok(Type::Fn(f0.clone()))
187 } else {
188 Ok(Type::Set(Arc::from_iter([
189 Type::Fn(f0.clone()),
190 Type::Fn(f1.clone()),
191 ])))
192 }
193 }
194 (f @ Type::Fn(_), t) | (t, f @ Type::Fn(_)) => {
195 Ok(Type::Set(Arc::from_iter([f.clone(), t.clone()])))
196 }
197 (t0 @ Type::TVar(_), t1 @ Type::TVar(_)) => {
198 if t0 == t1 {
199 Ok(t0.clone())
200 } else {
201 Ok(Type::Set(Arc::from_iter([t0.clone(), t1.clone()])))
202 }
203 }
204 (t0 @ Type::TVar(_), t1) | (t1, t0 @ Type::TVar(_)) => {
205 Ok(Type::Set(Arc::from_iter([t0.clone(), t1.clone()])))
206 }
207 (t @ Type::ByRef(_), u) | (u, t @ Type::ByRef(_)) => {
208 Ok(Type::Set(Arc::from_iter([t.clone(), u.clone()])))
209 }
210 }
211 }
212
213 pub fn union(&self, env: &Env, t: &Self) -> Result<Self> {
214 Ok(self.union_int(env, &mut RefHist::new(LPooled::take()), t)?.normalize())
215 }
216
217 fn diff_int(
218 &self,
219 env: &Env,
220 hist: &mut RefHist<FxHashMap<(Option<usize>, Option<usize>), Type>>,
221 t: &Self,
222 ) -> Result<Self> {
223 match (self, t) {
224 (
225 Type::Ref { scope: s0, name: n0, .. },
226 Type::Ref { scope: s1, name: n1, .. },
227 ) if s0 == s1 && n0 == n1 => Ok(Type::Primitive(BitFlags::empty())),
228 (t0 @ Type::Ref { .. }, t1) | (t0, t1 @ Type::Ref { .. }) => {
229 let t0_id = hist.ref_id(t0, env);
230 let t1_id = hist.ref_id(t1, env);
231 let t0 = t0.lookup_ref(env)?;
232 let t1 = t1.lookup_ref(env)?;
233 match hist.get(&(t0_id, t1_id)) {
234 Some(r) => Ok(r.clone()),
235 None => {
236 let r = Type::Primitive(BitFlags::empty());
237 hist.insert((t0_id, t1_id), r);
238 let r = t0.diff_int(env, hist, &t1);
239 hist.remove(&(t0_id, t1_id));
240 r
241 }
242 }
243 }
244 (Type::Set(s0), Type::Set(s1)) => {
245 let mut s: LPooled<Vec<Type>> = LPooled::take();
246 for i in 0..s0.len() {
247 s.push(s0[i].clone());
248 for j in 0..s1.len() {
249 s[i] = s[i].diff_int(env, hist, &s1[j])?
250 }
251 }
252 Ok(Self::flatten_set(s.drain(..)))
253 }
254 (Type::Set(s), t) => Ok(Self::flatten_set(
255 s.iter()
256 .map(|s| s.diff_int(env, hist, t))
257 .collect::<Result<LPooled<Vec<_>>>>()?
258 .drain(..),
259 )),
260 (t, Type::Set(s)) => {
261 let mut t = t.clone();
262 for st in s.iter() {
263 t = t.diff_int(env, hist, st)?;
264 }
265 Ok(t)
266 }
267 (Type::Tuple(t0), Type::Tuple(t1)) => {
268 if t0 == t1 {
269 Ok(Type::Primitive(BitFlags::empty()))
270 } else {
271 Ok(self.clone())
272 }
273 }
274 (Type::Struct(t0), Type::Struct(t1)) => {
275 if t0.len() == t1.len() && t0 == t1 {
276 Ok(Type::Primitive(BitFlags::empty()))
277 } else {
278 Ok(self.clone())
279 }
280 }
281 (Type::Variant(tg0, t0), Type::Variant(tg1, t1)) => {
282 if tg0 == tg1 && t0.len() == t1.len() && t0 == t1 {
283 Ok(Type::Primitive(BitFlags::empty()))
284 } else {
285 Ok(self.clone())
286 }
287 }
288 (Type::Map { key: k0, value: v0 }, Type::Map { key: k1, value: v1 }) => {
289 if k0 == k1 && v0 == v1 {
290 Ok(Type::Primitive(BitFlags::empty()))
291 } else {
292 Ok(self.clone())
293 }
294 }
295 (Type::Map { .. }, Type::Primitive(p)) => {
296 if p.contains(Typ::Map) {
297 Ok(Type::Primitive(BitFlags::empty()))
298 } else {
299 Ok(self.clone())
300 }
301 }
302 (Type::Primitive(p), Type::Map { key, value }) => {
303 if **key == Type::Any && **value == Type::Any {
304 let mut p = *p;
305 p.remove(Typ::Map);
306 Ok(Type::Primitive(p))
307 } else {
308 Ok(Type::Primitive(*p))
309 }
310 }
311 (Type::Fn(f0), Type::Fn(f1)) => {
312 if f0 == f1 {
313 Ok(Type::Primitive(BitFlags::empty()))
314 } else {
315 Ok(Type::Fn(f0.clone()))
316 }
317 }
318 (Type::TVar(tv0), t1 @ Type::TVar(tv1)) => {
319 if Arc::ptr_eq(&tv0.read().typ, &tv1.read().typ) {
320 return Ok(Type::Primitive(BitFlags::empty()));
321 }
322 Ok(match (&*tv0.read().typ.read(), &*tv1.read().typ.read()) {
323 (None, _) => Type::TVar(tv0.clone()),
324 (Some(t0), None) => t0.diff_int(env, hist, t1)?,
325 (Some(t0), Some(t1)) => t0.diff_int(env, hist, t1)?,
326 })
327 }
328 (Type::TVar(tv), t) => Ok(match &*tv.read().typ.read() {
329 Some(tv) => tv.diff_int(env, hist, t)?,
330 None => self.clone(),
331 }),
332 (t, Type::TVar(tv)) => Ok(match &*tv.read().typ.read() {
333 Some(tv) => t.diff_int(env, hist, tv)?,
334 None => self.clone(),
335 }),
336 (Type::Array(t0), Type::Array(t1)) => {
337 if t0 == t1 {
338 Ok(Type::Primitive(BitFlags::empty()))
339 } else {
340 Ok(Type::Array(Arc::new(t0.diff_int(env, hist, t1)?)))
341 }
342 }
343 (Type::Primitive(p), Type::Array(t)) => {
344 if &**t == &Type::Any {
345 let mut s = *p;
346 s.remove(Typ::Array);
347 Ok(Type::Primitive(s))
348 } else {
349 Ok(Type::Primitive(*p))
350 }
351 }
352 (
353 Type::Array(_) | Type::Struct(_) | Type::Tuple(_) | Type::Variant(_, _),
354 Type::Primitive(p),
355 ) => {
356 if p.contains(Typ::Array) {
357 Ok(Type::Primitive(BitFlags::empty()))
358 } else {
359 Ok(self.clone())
360 }
361 }
362 (_, Type::Any) => Ok(Type::Primitive(BitFlags::empty())),
363 (Type::Any, _) => Ok(Type::Any),
364 (Type::Primitive(s0), Type::Primitive(s1)) => {
365 let mut s = *s0;
366 s.remove(*s1);
367 Ok(Type::Primitive(s))
368 }
369 (Type::Primitive(p), Type::Error(e)) => {
370 if &**e == &Type::Any {
371 let mut s = *p;
372 s.remove(Typ::Error);
373 Ok(Type::Primitive(s))
374 } else {
375 Ok(Type::Primitive(*p))
376 }
377 }
378 (Type::Error(_), Type::Primitive(p)) => {
379 if p.contains(Typ::Error) {
380 Ok(Type::Primitive(BitFlags::empty()))
381 } else {
382 Ok(self.clone())
383 }
384 }
385 (Type::Error(e0), Type::Error(e1)) => {
386 if e0 == e1 {
387 Ok(Type::Primitive(BitFlags::empty()))
388 } else {
389 Ok(Type::Error(Arc::new(e0.diff_int(env, hist, e1)?)))
390 }
391 }
392 (Type::ByRef(t0), Type::ByRef(t1)) => {
393 Ok(Type::ByRef(Arc::new(t0.diff_int(env, hist, t1)?)))
394 }
395 (
396 Type::Abstract { id: id0, params: p0 },
397 Type::Abstract { id: id1, params: p1 },
398 ) if id0 == id1 && p0 == p1 => Ok(Type::Primitive(BitFlags::empty())),
399 (Type::Abstract { .. }, _)
400 | (_, Type::Abstract { .. })
401 | (Type::Fn(_), _)
402 | (_, Type::Fn(_))
403 | (Type::Array(_), _)
404 | (_, Type::Array(_))
405 | (Type::Tuple(_), _)
406 | (_, Type::Tuple(_))
407 | (Type::Struct(_), _)
408 | (_, Type::Struct(_))
409 | (Type::Variant(_, _), _)
410 | (_, Type::Variant(_, _))
411 | (Type::ByRef(_), _)
412 | (_, Type::ByRef(_))
413 | (Type::Error(_), _)
414 | (_, Type::Error(_))
415 | (Type::Primitive(_), _)
416 | (_, Type::Primitive(_))
417 | (Type::Bottom, _)
418 | (Type::Map { .. }, _) => Ok(self.clone()),
419 }
420 }
421
422 pub fn diff(&self, env: &Env, t: &Self) -> Result<Self> {
423 Ok(self.diff_int(env, &mut RefHist::new(LPooled::take()), t)?.normalize())
424 }
425}