1use crate::{
2 env::Env,
3 format_with_flags,
4 typ::{tvar::would_cycle_inner, AndAc, RefHist, Type, TypeRef},
5 PrintFlag,
6};
7use ahash::AHashMap;
8use anyhow::{bail, Result};
9use enumflags2::bitflags;
10use enumflags2::BitFlags;
11use netidx::publisher::Typ;
12use poolshark::local::LPooled;
13use std::fmt::Debug;
14use triomphe::Arc;
15
16#[derive(Debug, Clone, Copy)]
17#[bitflags]
18#[repr(u8)]
19pub enum ContainsFlags {
20 AliasTVars,
21 InitTVars,
22}
23
24impl Type {
25 pub fn check_contains(&self, env: &Env, t: &Self) -> Result<()> {
26 if self.contains(env, t)? {
27 Ok(())
28 } else {
29 format_with_flags(PrintFlag::DerefTVars | PrintFlag::ReplacePrims, || {
30 bail!("type mismatch {self} does not contain {t}")
31 })
32 }
33 }
34
35 pub(super) fn contains_int(
36 &self,
37 flags: BitFlags<ContainsFlags>,
38 env: &Env,
39 hist: &mut RefHist<AHashMap<(Option<usize>, Option<usize>), bool>>,
40 t: &Self,
41 ) -> Result<bool> {
42 if (self as *const Type) == (t as *const Type) {
43 return Ok(true);
44 }
45 match (self, t) {
46 (
47 Self::Ref(TypeRef { scope: s0, name: n0, params: p0, .. }),
48 Self::Ref(TypeRef { scope: s1, name: n1, params: p1, .. }),
49 ) if s0 == s1 && n0 == n1 => Ok(p0.len() == p1.len()
50 && p0
51 .iter()
52 .zip(p1.iter())
53 .map(|(t0, t1)| t0.contains_int(flags, env, hist, t1))
54 .collect::<Result<AndAc>>()?
55 .0),
56 (t0 @ Self::Ref(TypeRef { .. }), t1)
57 | (t0, t1 @ Self::Ref(TypeRef { .. })) => {
58 let t0_id = hist.ref_id(t0, env);
59 let t1_id = hist.ref_id(t1, env);
60 let t0 = t0.lookup_ref(env)?;
61 let t1 = t1.lookup_ref(env)?;
62 match hist.get(&(t0_id, t1_id)) {
63 Some(r) => Ok(*r),
64 None => {
65 hist.insert((t0_id, t1_id), true);
66 let r = t0.contains_int(flags, env, hist, &t1);
67 hist.remove(&(t0_id, t1_id));
68 r
69 }
70 }
71 }
72 (Self::TVar(t0), Self::Bottom) => {
73 if let Some(_) = &*t0.read().typ.read() {
74 return Ok(true);
75 }
76 if flags.contains(ContainsFlags::InitTVars) {
77 *t0.read().typ.write() = Some(Self::Bottom);
78 }
79 Ok(true)
80 }
81 (Self::Bottom, Self::TVar(t0)) => {
82 if let Some(Type::Bottom) = &*t0.read().typ.read() {
83 return Ok(true);
84 }
85 if flags.contains(ContainsFlags::InitTVars) {
86 *t0.read().typ.write() = Some(Self::Bottom);
87 }
88 Ok(true)
89 }
90 (Self::Bottom, Self::Bottom) => Ok(true),
91 (Self::Bottom, _) => Ok(false),
92 (_, Self::Bottom) => Ok(true),
93 (Self::TVar(t0), Self::Any) => {
94 if let Some(t0) = &*t0.read().typ.read() {
95 return t0.contains_int(flags, env, hist, t);
96 }
97 if flags.contains(ContainsFlags::InitTVars) {
98 *t0.read().typ.write() = Some(Self::Any);
99 }
100 Ok(true)
101 }
102 (Self::Any, _) => Ok(true),
103 (
104 Self::Abstract { id: id0, params: p0 },
105 Self::Abstract { id: id1, params: p1 },
106 ) => Ok(id0 == id1
107 && p0.len() == p1.len()
108 && p0
109 .iter()
110 .zip(p1.iter())
111 .map(|(t0, t1)| t0.contains_int(flags, env, hist, t1))
112 .collect::<Result<AndAc>>()?
113 .0),
114 (Self::Primitive(p0), Self::Primitive(p1)) => Ok(p0.contains(*p1)),
115 (
116 Self::Primitive(p),
117 Self::Array(_) | Self::Tuple(_) | Self::Struct(_) | Self::Variant(_, _),
118 ) => Ok(p.contains(Typ::Array)),
119 (Self::Array(t0), Self::Array(t1)) => t0.contains_int(flags, env, hist, t1),
120 (Self::Array(t0), Self::Primitive(p)) if *p == BitFlags::from(Typ::Array) => {
121 t0.contains_int(flags, env, hist, &Type::Any)
122 }
123 (Self::Map { key: k0, value: v0 }, Self::Map { key: k1, value: v1 }) => {
124 Ok(k0.contains_int(flags, env, hist, k1)?
125 && v0.contains_int(flags, env, hist, v1)?)
126 }
127 (Self::Primitive(p), Self::Map { .. }) => Ok(p.contains(Typ::Map)),
128 (Self::Map { key, value }, Self::Primitive(p))
129 if *p == BitFlags::from(Typ::Map) =>
130 {
131 Ok(key.contains_int(flags, env, hist, &Type::Any)?
132 && value.contains_int(flags, env, hist, &Type::Any)?)
133 }
134 (Self::Primitive(p0), Self::Error(_)) => Ok(p0.contains(Typ::Error)),
135 (Self::Error(e), Self::Primitive(p)) if *p == BitFlags::from(Typ::Error) => {
136 e.contains_int(flags, env, hist, &Type::Any)
137 }
138 (Self::Error(e0), Self::Error(e1)) => e0.contains_int(flags, env, hist, e1),
139 (Self::Tuple(t0), Self::Tuple(t1)) if Arc::ptr_eq(t0, t1) => Ok(true),
140 (Self::Tuple(t0), Self::Tuple(t1)) => Ok(t0.len() == t1.len()
141 && t0
142 .iter()
143 .zip(t1.iter())
144 .map(|(t0, t1)| t0.contains_int(flags, env, hist, t1))
145 .collect::<Result<AndAc>>()?
146 .0),
147 (Self::Struct(t0), Self::Struct(t1)) if Arc::ptr_eq(t0, t1) => Ok(true),
148 (Self::Struct(t0), Self::Struct(t1)) => {
149 Ok(t0.len() == t1.len() && {
150 t0.iter()
152 .zip(t1.iter())
153 .map(|((n0, t0), (n1, t1))| {
154 Ok(n0 == n1 && t0.contains_int(flags, env, hist, t1)?)
155 })
156 .collect::<Result<AndAc>>()?
157 .0
158 })
159 }
160 (Self::Variant(tg0, t0), Self::Variant(tg1, t1))
161 if tg0.as_ptr() == tg1.as_ptr() && Arc::ptr_eq(t0, t1) =>
162 {
163 Ok(true)
164 }
165 (Self::Variant(tg0, t0), Self::Variant(tg1, t1)) => Ok(tg0 == tg1
166 && t0.len() == t1.len()
167 && t0
168 .iter()
169 .zip(t1.iter())
170 .map(|(t0, t1)| t0.contains_int(flags, env, hist, t1))
171 .collect::<Result<AndAc>>()?
172 .0),
173 (Self::ByRef(t0), Self::ByRef(t1)) => t0.contains_int(flags, env, hist, t1),
174 (Self::TVar(t0), Self::TVar(t1))
175 if t0.addr() == t1.addr() || t0.read().id == t1.read().id =>
176 {
177 Ok(true)
178 }
179 (tt0 @ Self::TVar(t0), tt1 @ Self::TVar(t1)) => {
180 #[derive(Debug)]
181 enum Act {
182 RightCopy,
183 RightAlias,
184 LeftAlias,
185 LeftCopy,
186 }
187 if t0.would_cycle(tt1) || t1.would_cycle(tt0) {
188 return Ok(true);
189 }
190 let act = {
191 let t0 = t0.read();
192 let t1 = t1.read();
193 let addr0 = Arc::as_ptr(&t0.typ).addr();
194 let addr1 = Arc::as_ptr(&t1.typ).addr();
195 if addr0 == addr1 {
196 return Ok(true);
197 }
198 if would_cycle_inner(addr0, tt1) || would_cycle_inner(addr1, tt0) {
199 return Ok(true);
200 }
201 let t0i = t0.typ.read();
202 let t1i = t1.typ.read();
203 match (&*t0i, &*t1i) {
204 (Some(t0), Some(t1)) => {
205 return t0.contains_int(flags, env, hist, &*t1)
206 }
207 (None, None) => {
208 if t0.frozen && t1.frozen {
209 return Ok(true);
210 }
211 if t0.frozen {
212 Act::RightAlias
213 } else {
214 Act::LeftAlias
215 }
216 }
217 (Some(_), None) => Act::RightCopy,
218 (None, Some(_)) => Act::LeftCopy,
219 }
220 };
221 match act {
222 Act::RightCopy if flags.contains(ContainsFlags::InitTVars) => {
223 t1.copy(t0)
224 }
225 Act::RightAlias if flags.contains(ContainsFlags::AliasTVars) => {
226 t1.alias(t0)
227 }
228 Act::LeftAlias if flags.contains(ContainsFlags::AliasTVars) => {
229 t0.alias(t1)
230 }
231 Act::LeftCopy if flags.contains(ContainsFlags::InitTVars) => {
232 t0.copy(t1)
233 }
234 Act::RightCopy | Act::RightAlias | Act::LeftAlias | Act::LeftCopy => {
235 ()
236 }
237 }
238 Ok(true)
239 }
240 (Self::TVar(t0), t1) if !t0.would_cycle(t1) => {
241 if let Some(t0) = &*t0.read().typ.read() {
242 return t0.contains_int(flags, env, hist, t1);
243 }
244 if flags.contains(ContainsFlags::InitTVars) {
245 *t0.read().typ.write() = Some(t1.clone());
246 }
247 Ok(true)
248 }
249 (t0, Self::TVar(t1)) if !t1.would_cycle(t0) => {
250 if let Some(t1) = &*t1.read().typ.read() {
251 return t0.contains_int(flags, env, hist, t1);
252 }
253 if flags.contains(ContainsFlags::InitTVars) {
254 *t1.read().typ.write() = Some(t0.clone());
255 }
256 Ok(true)
257 }
258 (Self::Set(s0), Self::Set(s1)) if Arc::ptr_eq(s0, s1) => Ok(true),
259 (t0 @ Self::Set(_), t1 @ Self::Set(_)) if t0 == t1 => {
260 if flags.contains(ContainsFlags::InitTVars) {
261 let mut known = LPooled::take();
262 t0.alias_tvars(&mut known);
263 t1.alias_tvars(&mut known);
264 }
265 Ok(true)
266 }
267 (t0, Self::Set(s)) => Ok(s
268 .iter()
269 .map(|t1| t0.contains_int(flags, env, hist, t1))
270 .collect::<Result<AndAc>>()?
271 .0),
272 (Self::Set(s), t) => {
273 let probe = BitFlags::empty();
274 let whole_ok =
275 s.iter().fold(Ok::<_, anyhow::Error>(false), |acc, t0| {
276 Ok(acc? || t0.contains_int(probe, env, hist, t)?)
277 })?;
278 let prims_ok =
279 t.iter_prims().fold(Ok::<_, anyhow::Error>(true), |acc, t1| {
280 Ok(acc?
281 && s.iter().fold(
282 Ok::<_, anyhow::Error>(false),
283 |acc, t0| {
284 Ok(acc? || t0.contains_int(probe, env, hist, &t1)?)
285 },
286 )?)
287 })?;
288 match (whole_ok, prims_ok) {
289 (false, false) => Ok(false),
290 (_, true) => Ok(t.iter_prims().fold(
292 Ok::<_, anyhow::Error>(true),
293 |acc, t1| {
294 Ok(acc?
295 && s.iter().fold(
296 Ok::<_, anyhow::Error>(false),
297 |acc, t0| {
298 Ok(acc?
299 || t0.contains_int(flags, env, hist, &t1)?)
300 },
301 )?)
302 },
303 )?),
304 (true, false) => {
305 Ok(s.iter().fold(Ok::<_, anyhow::Error>(false), |acc, t0| {
306 Ok(acc? || t0.contains_int(flags, env, hist, t)?)
307 })?)
308 }
309 }
310 }
311 (Self::Fn(f0), Self::Fn(f1)) => {
312 let same = Arc::ptr_eq(f0, f1);
313 let r = same || f0.contains_int(flags, env, hist, f1)?;
314 if r && !same && flags.contains(ContainsFlags::InitTVars) {
315 f0.merge_lambda_ids(f1);
316 }
317 Ok(r)
318 }
319 (_, Self::Any)
320 | (Self::Abstract { .. }, _)
321 | (_, Self::Abstract { .. })
322 | (_, Self::TVar(_))
323 | (Self::TVar(_), _)
324 | (Self::Fn(_), _)
325 | (Self::ByRef(_), _)
326 | (_, Self::ByRef(_))
327 | (_, Self::Fn(_))
328 | (Self::Tuple(_), Self::Array(_))
329 | (Self::Tuple(_), Self::Primitive(_))
330 | (Self::Tuple(_), Self::Struct(_))
331 | (Self::Tuple(_), Self::Variant(_, _))
332 | (Self::Tuple(_), Self::Error(_))
333 | (Self::Tuple(_), Self::Map { .. })
334 | (Self::Array(_), Self::Primitive(_))
335 | (Self::Array(_), Self::Tuple(_))
336 | (Self::Array(_), Self::Struct(_))
337 | (Self::Array(_), Self::Variant(_, _))
338 | (Self::Array(_), Self::Error(_))
339 | (Self::Array(_), Self::Map { .. })
340 | (Self::Struct(_), Self::Array(_))
341 | (Self::Struct(_), Self::Primitive(_))
342 | (Self::Struct(_), Self::Tuple(_))
343 | (Self::Struct(_), Self::Variant(_, _))
344 | (Self::Struct(_), Self::Error(_))
345 | (Self::Struct(_), Self::Map { .. })
346 | (Self::Variant(_, _), Self::Array(_))
347 | (Self::Variant(_, _), Self::Struct(_))
348 | (Self::Variant(_, _), Self::Primitive(_))
349 | (Self::Variant(_, _), Self::Tuple(_))
350 | (Self::Variant(_, _), Self::Error(_))
351 | (Self::Variant(_, _), Self::Map { .. })
352 | (Self::Error(_), Self::Array(_))
353 | (Self::Error(_), Self::Primitive(_))
354 | (Self::Error(_), Self::Struct(_))
355 | (Self::Error(_), Self::Variant(_, _))
356 | (Self::Error(_), Self::Tuple(_))
357 | (Self::Error(_), Self::Map { .. })
358 | (Self::Map { .. }, Self::Array(_))
359 | (Self::Map { .. }, Self::Primitive(_))
360 | (Self::Map { .. }, Self::Struct(_))
361 | (Self::Map { .. }, Self::Variant(_, _))
362 | (Self::Map { .. }, Self::Tuple(_))
363 | (Self::Map { .. }, Self::Error(_)) => Ok(false),
364 }
365 }
366
367 pub fn contains(&self, env: &Env, t: &Self) -> Result<bool> {
368 self.contains_int(
369 ContainsFlags::AliasTVars | ContainsFlags::InitTVars,
370 env,
371 &mut RefHist::new(LPooled::take()),
372 t,
373 )
374 }
375
376 pub fn contains_with_flags(
377 &self,
378 flags: BitFlags<ContainsFlags>,
379 env: &Env,
380 t: &Self,
381 ) -> Result<bool> {
382 self.contains_int(flags, env, &mut RefHist::new(LPooled::take()), t)
383 }
384}