1use crate::{env::Env, typ::Type};
2use anyhow::Result;
3use enumflags2::BitFlags;
4use fxhash::FxHashMap;
5use netidx::publisher::Typ;
6use poolshark::local::LPooled;
7use std::iter;
8use triomphe::Arc;
9
10impl Type {
11 fn union_int(
12 &self,
13 env: &Env,
14 hist: &mut FxHashMap<(usize, usize), Type>,
15 t: &Self,
16 ) -> Result<Self> {
17 match (self, t) {
18 (
19 Type::Ref { scope: s0, name: n0, params: p0 },
20 Type::Ref { scope: s1, name: n1, params: p1 },
21 ) if n0 == n1 && s0 == s1 && p0.len() == p1.len() => {
22 let mut params = p0
23 .iter()
24 .zip(p1.iter())
25 .map(|(p0, p1)| p0.union_int(env, hist, p1))
26 .collect::<Result<LPooled<Vec<_>>>>()?;
27 let params = Arc::from_iter(params.drain(..));
28 Ok(Self::Ref { scope: s0.clone(), name: n0.clone(), params })
29 }
30 (tr @ Type::Ref { .. }, t) => {
31 let t0 = tr.lookup_ref(env)?;
32 let t0_addr = (t0 as *const Type).addr();
33 let t_addr = (t as *const Type).addr();
34 match hist.get(&(t0_addr, t_addr)) {
35 Some(t) => Ok(t.clone()),
36 None => {
37 hist.insert((t0_addr, t_addr), tr.clone());
38 let r = t0.union_int(env, hist, t)?;
39 hist.insert((t0_addr, t_addr), r.clone());
40 Ok(r)
41 }
42 }
43 }
44 (t, tr @ Type::Ref { .. }) => {
45 let t1 = tr.lookup_ref(env)?;
46 let t1_addr = (t1 as *const Type).addr();
47 let t_addr = (t as *const Type).addr();
48 match hist.get(&(t_addr, t1_addr)) {
49 Some(t) => Ok(t.clone()),
50 None => {
51 hist.insert((t_addr, t1_addr), tr.clone());
52 let r = t.union_int(env, hist, t1)?;
53 hist.insert((t_addr, t1_addr), r.clone());
54 Ok(r)
55 }
56 }
57 }
58 (
59 Type::Abstract { id: id0, params: p0 },
60 Type::Abstract { id: id1, params: p1 },
61 ) if id0 == id1 && p0 == p1 => Ok(self.clone()),
62 (t0 @ Type::Abstract { .. }, t1) | (t0, t1 @ Type::Abstract { .. }) => {
63 Ok(Type::Set(Arc::from_iter([t0.clone(), t1.clone()])))
64 }
65 (Type::Bottom, t) | (t, Type::Bottom) => Ok(t.clone()),
66 (Type::Any, _) | (_, Type::Any) => Ok(Type::Any),
67 (Type::Primitive(p), t) | (t, Type::Primitive(p)) if p.is_empty() => {
68 Ok(t.clone())
69 }
70 (Type::Primitive(s0), Type::Primitive(s1)) => {
71 let mut s = *s0;
72 s.insert(*s1);
73 Ok(Type::Primitive(s))
74 }
75 (
76 Type::Primitive(p),
77 Type::Array(_) | Type::Struct(_) | Type::Tuple(_) | Type::Variant(_, _),
78 )
79 | (
80 Type::Array(_) | Type::Struct(_) | Type::Tuple(_) | Type::Variant(_, _),
81 Type::Primitive(p),
82 ) if p.contains(Typ::Array) => Ok(Type::Primitive(*p)),
83 (Type::Primitive(p), Type::Array(t))
84 | (Type::Array(t), Type::Primitive(p)) => Ok(Type::Set(Arc::from_iter([
85 Type::Primitive(*p),
86 Type::Array(t.clone()),
87 ]))),
88 (t @ Type::Array(t0), u @ Type::Array(t1)) => {
89 if t0 == t1 {
90 Ok(Type::Array(t0.clone()))
91 } else {
92 Ok(Type::Set(Arc::from_iter([t.clone(), u.clone()])))
93 }
94 }
95 (Type::Primitive(p), Type::Map { .. })
96 | (Type::Map { .. }, Type::Primitive(p))
97 if p.contains(Typ::Map) =>
98 {
99 Ok(Type::Primitive(*p))
100 }
101 (Type::Primitive(p), Type::Map { key, value })
102 | (Type::Map { key, value }, Type::Primitive(p)) => {
103 Ok(Type::Set(Arc::from_iter([
104 Type::Primitive(*p),
105 Type::Map { key: key.clone(), value: value.clone() },
106 ])))
107 }
108 (
109 t @ Type::Map { key: k0, value: v0 },
110 u @ Type::Map { key: k1, value: v1 },
111 ) => {
112 if k0 == k1 && v0 == v1 {
113 Ok(Type::Map { key: k0.clone(), value: v0.clone() })
114 } else {
115 Ok(Type::Set(Arc::from_iter([t.clone(), u.clone()])))
116 }
117 }
118 (t @ Type::Map { .. }, u) | (u, t @ Type::Map { .. }) => {
119 Ok(Type::Set(Arc::from_iter([t.clone(), u.clone()])))
120 }
121 (Type::Primitive(p), Type::Error(_))
122 | (Type::Error(_), Type::Primitive(p))
123 if p.contains(Typ::Error) =>
124 {
125 Ok(Type::Primitive(*p))
126 }
127 (Type::Error(e0), Type::Error(e1)) => {
128 Ok(Type::Error(Arc::new(e0.union_int(env, hist, e1)?)))
129 }
130 (e @ Type::Error(_), t) | (t, e @ Type::Error(_)) => {
131 Ok(Type::Set(Arc::from_iter([e.clone(), t.clone()])))
132 }
133 (t @ Type::ByRef(t0), u @ Type::ByRef(t1)) => {
134 if t0 == t1 {
135 Ok(Type::ByRef(t0.clone()))
136 } else {
137 Ok(Type::Set(Arc::from_iter([u.clone(), t.clone()])))
138 }
139 }
140 (Type::Set(s0), Type::Set(s1)) => Ok(Type::Set(Arc::from_iter(
141 s0.iter().cloned().chain(s1.iter().cloned()),
142 ))),
143 (Type::Set(s), t) | (t, Type::Set(s)) => Ok(Type::Set(Arc::from_iter(
144 s.iter().cloned().chain(iter::once(t.clone())),
145 ))),
146 (u @ Type::Struct(t0), t @ Type::Struct(t1)) => {
147 if t0.len() == t1.len() && t0 == t1 {
148 Ok(u.clone())
149 } else {
150 Ok(Type::Set(Arc::from_iter([u.clone(), t.clone()])))
151 }
152 }
153 (u @ Type::Struct(_), t) | (t, u @ Type::Struct(_)) => {
154 Ok(Type::Set(Arc::from_iter([u.clone(), t.clone()])))
155 }
156 (u @ Type::Tuple(t0), t @ Type::Tuple(t1)) => {
157 if t0 == t1 {
158 Ok(u.clone())
159 } else {
160 Ok(Type::Set(Arc::from_iter([u.clone(), t.clone()])))
161 }
162 }
163 (u @ Type::Tuple(_), t) | (t, u @ Type::Tuple(_)) => {
164 Ok(Type::Set(Arc::from_iter([u.clone(), t.clone()])))
165 }
166 (u @ Type::Variant(tg0, t0), t @ Type::Variant(tg1, t1)) => {
167 if tg0 == tg1 && t0.len() == t1.len() {
168 let mut typs = t0
169 .iter()
170 .zip(t1.iter())
171 .map(|(t0, t1)| t0.union_int(env, hist, t1))
172 .collect::<Result<LPooled<Vec<_>>>>()?;
173 Ok(Type::Variant(tg0.clone(), Arc::from_iter(typs.drain(..))))
174 } else {
175 Ok(Type::Set(Arc::from_iter([u.clone(), t.clone()])))
176 }
177 }
178 (u @ Type::Variant(_, _), t) | (t, u @ Type::Variant(_, _)) => {
179 Ok(Type::Set(Arc::from_iter([u.clone(), t.clone()])))
180 }
181 (Type::Fn(f0), Type::Fn(f1)) => {
182 if f0 == f1 {
183 Ok(Type::Fn(f0.clone()))
184 } else {
185 Ok(Type::Set(Arc::from_iter([
186 Type::Fn(f0.clone()),
187 Type::Fn(f1.clone()),
188 ])))
189 }
190 }
191 (f @ Type::Fn(_), t) | (t, f @ Type::Fn(_)) => {
192 Ok(Type::Set(Arc::from_iter([f.clone(), t.clone()])))
193 }
194 (t0 @ Type::TVar(_), t1 @ Type::TVar(_)) => {
195 if t0 == t1 {
196 Ok(t0.clone())
197 } else {
198 Ok(Type::Set(Arc::from_iter([t0.clone(), t1.clone()])))
199 }
200 }
201 (t0 @ Type::TVar(_), t1) | (t1, t0 @ Type::TVar(_)) => {
202 Ok(Type::Set(Arc::from_iter([t0.clone(), t1.clone()])))
203 }
204 (t @ Type::ByRef(_), u) | (u, t @ Type::ByRef(_)) => {
205 Ok(Type::Set(Arc::from_iter([t.clone(), u.clone()])))
206 }
207 }
208 }
209
210 pub fn union(&self, env: &Env, t: &Self) -> Result<Self> {
211 Ok(self.union_int(env, &mut LPooled::take(), t)?.normalize())
212 }
213
214 fn diff_int(
215 &self,
216 env: &Env,
217 hist: &mut FxHashMap<(usize, usize), Type>,
218 t: &Self,
219 ) -> Result<Self> {
220 match (self, t) {
221 (
222 Type::Ref { scope: s0, name: n0, .. },
223 Type::Ref { scope: s1, name: n1, .. },
224 ) if s0 == s1 && n0 == n1 => Ok(Type::Primitive(BitFlags::empty())),
225 (t0 @ Type::Ref { .. }, t1) | (t0, t1 @ Type::Ref { .. }) => {
226 let t0 = t0.lookup_ref(env)?;
227 let t1 = t1.lookup_ref(env)?;
228 let t0_addr = (t0 as *const Type).addr();
229 let t1_addr = (t1 as *const Type).addr();
230 match hist.get(&(t0_addr, t1_addr)) {
231 Some(r) => Ok(r.clone()),
232 None => {
233 let r = Type::Primitive(BitFlags::empty());
234 hist.insert((t0_addr, t1_addr), r);
235 match t0.diff_int(env, hist, &t1) {
236 Ok(r) => {
237 hist.insert((t0_addr, t1_addr), r.clone());
238 Ok(r)
239 }
240 Err(e) => {
241 hist.remove(&(t0_addr, t1_addr));
242 Err(e)
243 }
244 }
245 }
246 }
247 }
248 (Type::Set(s0), Type::Set(s1)) => {
249 let mut s: LPooled<Vec<Type>> = LPooled::take();
250 for i in 0..s0.len() {
251 s.push(s0[i].clone());
252 for j in 0..s1.len() {
253 s[i] = s[i].diff_int(env, hist, &s1[j])?
254 }
255 }
256 Ok(Self::flatten_set(s.drain(..)))
257 }
258 (Type::Set(s), t) => Ok(Self::flatten_set(
259 s.iter()
260 .map(|s| s.diff_int(env, hist, t))
261 .collect::<Result<LPooled<Vec<_>>>>()?
262 .drain(..),
263 )),
264 (t, Type::Set(s)) => {
265 let mut t = t.clone();
266 for st in s.iter() {
267 t = t.diff_int(env, hist, st)?;
268 }
269 Ok(t)
270 }
271 (Type::Tuple(t0), Type::Tuple(t1)) => {
272 if t0 == t1 {
273 Ok(Type::Primitive(BitFlags::empty()))
274 } else {
275 Ok(self.clone())
276 }
277 }
278 (Type::Struct(t0), Type::Struct(t1)) => {
279 if t0.len() == t1.len() && t0 == t1 {
280 Ok(Type::Primitive(BitFlags::empty()))
281 } else {
282 Ok(self.clone())
283 }
284 }
285 (Type::Variant(tg0, t0), Type::Variant(tg1, t1)) => {
286 if tg0 == tg1 && t0.len() == t1.len() && t0 == t1 {
287 Ok(Type::Primitive(BitFlags::empty()))
288 } else {
289 Ok(self.clone())
290 }
291 }
292 (Type::Map { key: k0, value: v0 }, Type::Map { key: k1, value: v1 }) => {
293 if k0 == k1 && v0 == v1 {
294 Ok(Type::Primitive(BitFlags::empty()))
295 } else {
296 Ok(self.clone())
297 }
298 }
299 (Type::Map { .. }, Type::Primitive(p)) => {
300 if p.contains(Typ::Map) {
301 Ok(Type::Primitive(BitFlags::empty()))
302 } else {
303 Ok(self.clone())
304 }
305 }
306 (Type::Primitive(p), Type::Map { key, value }) => {
307 if **key == Type::Any && **value == Type::Any {
308 let mut p = *p;
309 p.remove(Typ::Map);
310 Ok(Type::Primitive(p))
311 } else {
312 Ok(Type::Primitive(*p))
313 }
314 }
315 (Type::Fn(f0), Type::Fn(f1)) => {
316 if f0 == f1 {
317 Ok(Type::Primitive(BitFlags::empty()))
318 } else {
319 Ok(Type::Fn(f0.clone()))
320 }
321 }
322 (Type::TVar(tv0), Type::TVar(tv1)) => {
323 if tv0.read().typ.as_ptr() == tv1.read().typ.as_ptr() {
324 return Ok(Type::Primitive(BitFlags::empty()));
325 }
326 Ok(match (&*tv0.read().typ.read(), &*tv1.read().typ.read()) {
327 (None, _) | (_, None) => Type::TVar(tv0.clone()),
328 (Some(t0), Some(t1)) => t0.diff_int(env, hist, t1)?,
329 })
330 }
331 (Type::TVar(tv), t) => Ok(match &*tv.read().typ.read() {
332 Some(tv) => tv.diff_int(env, hist, t)?,
333 None => self.clone(),
334 }),
335 (t, Type::TVar(tv)) => Ok(match &*tv.read().typ.read() {
336 Some(tv) => t.diff_int(env, hist, tv)?,
337 None => self.clone(),
338 }),
339 (Type::Array(t0), Type::Array(t1)) => {
340 if t0 == t1 {
341 Ok(Type::Primitive(BitFlags::empty()))
342 } else {
343 Ok(Type::Array(Arc::new(t0.diff_int(env, hist, t1)?)))
344 }
345 }
346 (Type::Primitive(p), Type::Array(t)) => {
347 if &**t == &Type::Any {
348 let mut s = *p;
349 s.remove(Typ::Array);
350 Ok(Type::Primitive(s))
351 } else {
352 Ok(Type::Primitive(*p))
353 }
354 }
355 (
356 Type::Array(_) | Type::Struct(_) | Type::Tuple(_) | Type::Variant(_, _),
357 Type::Primitive(p),
358 ) => {
359 if p.contains(Typ::Array) {
360 Ok(Type::Primitive(BitFlags::empty()))
361 } else {
362 Ok(self.clone())
363 }
364 }
365 (_, Type::Any) => Ok(Type::Primitive(BitFlags::empty())),
366 (Type::Any, _) => Ok(Type::Any),
367 (Type::Primitive(s0), Type::Primitive(s1)) => {
368 let mut s = *s0;
369 s.remove(*s1);
370 Ok(Type::Primitive(s))
371 }
372 (Type::Primitive(p), Type::Error(e)) => {
373 if &**e == &Type::Any {
374 let mut s = *p;
375 s.remove(Typ::Error);
376 Ok(Type::Primitive(s))
377 } else {
378 Ok(Type::Primitive(*p))
379 }
380 }
381 (Type::Error(_), Type::Primitive(p)) => {
382 if p.contains(Typ::Error) {
383 Ok(Type::Primitive(BitFlags::empty()))
384 } else {
385 Ok(self.clone())
386 }
387 }
388 (Type::Error(e0), Type::Error(e1)) => {
389 if e0 == e1 {
390 Ok(Type::Primitive(BitFlags::empty()))
391 } else {
392 Ok(Type::Error(Arc::new(e0.diff_int(env, hist, e1)?)))
393 }
394 }
395 (Type::ByRef(t0), Type::ByRef(t1)) => {
396 Ok(Type::ByRef(Arc::new(t0.diff_int(env, hist, t1)?)))
397 }
398 (
399 Type::Abstract { id: id0, params: p0 },
400 Type::Abstract { id: id1, params: p1 },
401 ) if id0 == id1 && p0 == p1 => Ok(Type::Primitive(BitFlags::empty())),
402 (Type::Abstract { .. }, _)
403 | (_, Type::Abstract { .. })
404 | (Type::Fn(_), _)
405 | (_, Type::Fn(_))
406 | (Type::Array(_), _)
407 | (_, Type::Array(_))
408 | (Type::Tuple(_), _)
409 | (_, Type::Tuple(_))
410 | (Type::Struct(_), _)
411 | (_, Type::Struct(_))
412 | (Type::Variant(_, _), _)
413 | (_, Type::Variant(_, _))
414 | (Type::ByRef(_), _)
415 | (_, Type::ByRef(_))
416 | (Type::Error(_), _)
417 | (_, Type::Error(_))
418 | (Type::Primitive(_), _)
419 | (_, Type::Primitive(_))
420 | (Type::Bottom, _)
421 | (Type::Map { .. }, _) => Ok(self.clone()),
422 }
423 }
424
425 pub fn diff(&self, env: &Env, t: &Self) -> Result<Self> {
426 Ok(self.diff_int(env, &mut LPooled::take(), t)?.normalize())
427 }
428}