1use crate::{
2 env::Env,
3 format_with_flags,
4 typ::{tvar::would_cycle_inner, AndAc, Type},
5 PrintFlag,
6};
7use anyhow::{bail, Result};
8use enumflags2::bitflags;
9use enumflags2::BitFlags;
10use fxhash::FxHashMap;
11use netidx::publisher::Typ;
12use poolshark::local::LPooled;
13use std::fmt::Debug;
14use triomphe::Arc;
15
16#[derive(Debug, Clone, Copy)]
17#[bitflags]
18#[repr(u8)]
19pub enum ContainsFlags {
20 AliasTVars,
21 InitTVars,
22}
23
24impl Type {
25 pub fn check_contains(&self, env: &Env, t: &Self) -> Result<()> {
26 if self.contains(env, t)? {
27 Ok(())
28 } else {
29 format_with_flags(PrintFlag::DerefTVars | PrintFlag::ReplacePrims, || {
30 bail!("type mismatch {self} does not contain {t}")
31 })
32 }
33 }
34
35 pub(super) fn contains_int(
36 &self,
37 flags: BitFlags<ContainsFlags>,
38 env: &Env,
39 hist: &mut FxHashMap<(usize, usize), bool>,
40 t: &Self,
41 ) -> Result<bool> {
42 if (self as *const Type) == (t as *const Type) {
43 return Ok(true);
44 }
45 match (self, t) {
46 (
47 Self::Ref { scope: s0, name: n0, params: p0 },
48 Self::Ref { scope: s1, name: n1, params: p1 },
49 ) if s0 == s1 && n0 == n1 => Ok(p0.len() == p1.len()
50 && p0
51 .iter()
52 .zip(p1.iter())
53 .map(|(t0, t1)| t0.contains_int(flags, env, hist, t1))
54 .collect::<Result<AndAc>>()?
55 .0),
56 (t0 @ Self::Ref { .. }, t1) | (t0, t1 @ Self::Ref { .. }) => {
57 let t0 = t0.lookup_ref(env)?;
58 let t1 = t1.lookup_ref(env)?;
59 let t0_addr = (t0 as *const Type).addr();
60 let t1_addr = (t1 as *const Type).addr();
61 match hist.get(&(t0_addr, t1_addr)) {
62 Some(r) => Ok(*r),
63 None => {
64 hist.insert((t0_addr, t1_addr), true);
65 match t0.contains_int(flags, env, hist, t1) {
66 Ok(r) => {
67 hist.insert((t0_addr, t1_addr), r);
68 Ok(r)
69 }
70 Err(e) => {
71 hist.remove(&(t0_addr, t1_addr));
72 Err(e)
73 }
74 }
75 }
76 }
77 }
78 (Self::TVar(t0), Self::Bottom) => {
79 if let Some(_) = &*t0.read().typ.read() {
80 return Ok(true);
81 }
82 if flags.contains(ContainsFlags::InitTVars) {
83 *t0.read().typ.write() = Some(Self::Bottom);
84 }
85 Ok(true)
86 }
87 (Self::Bottom, Self::TVar(t0)) => {
88 if let Some(Type::Bottom) = &*t0.read().typ.read() {
89 return Ok(true);
90 }
91 if flags.contains(ContainsFlags::InitTVars) {
92 *t0.read().typ.write() = Some(Self::Bottom);
93 return Ok(true);
94 }
95 Ok(false)
96 }
97 (Self::Bottom, Self::Bottom) => Ok(true),
98 (Self::Bottom, _) => Ok(false),
99 (_, Self::Bottom) => Ok(true),
100 (Self::TVar(t0), Self::Any) => {
101 if let Some(t0) = &*t0.read().typ.read() {
102 return t0.contains_int(flags, env, hist, t);
103 }
104 if flags.contains(ContainsFlags::InitTVars) {
105 *t0.read().typ.write() = Some(Self::Any);
106 }
107 Ok(true)
108 }
109 (Self::Any, _) => Ok(true),
110 (
111 Self::Abstract { id: id0, params: p0 },
112 Self::Abstract { id: id1, params: p1 },
113 ) => Ok(id0 == id1
114 && p0.len() == p1.len()
115 && p0
116 .iter()
117 .zip(p1.iter())
118 .map(|(t0, t1)| t0.contains_int(flags, env, hist, t1))
119 .collect::<Result<AndAc>>()?
120 .0),
121 (Self::Primitive(p0), Self::Primitive(p1)) => Ok(p0.contains(*p1)),
122 (
123 Self::Primitive(p),
124 Self::Array(_) | Self::Tuple(_) | Self::Struct(_) | Self::Variant(_, _),
125 ) => Ok(p.contains(Typ::Array)),
126 (Self::Array(t0), Self::Array(t1)) => t0.contains_int(flags, env, hist, t1),
127 (Self::Array(t0), Self::Primitive(p)) if *p == BitFlags::from(Typ::Array) => {
128 t0.contains_int(flags, env, hist, &Type::Any)
129 }
130 (Self::Map { key: k0, value: v0 }, Self::Map { key: k1, value: v1 }) => {
131 Ok(k0.contains_int(flags, env, hist, k1)?
132 && v0.contains_int(flags, env, hist, v1)?)
133 }
134 (Self::Primitive(p), Self::Map { .. }) => Ok(p.contains(Typ::Map)),
135 (Self::Map { key, value }, Self::Primitive(p))
136 if *p == BitFlags::from(Typ::Map) =>
137 {
138 Ok(key.contains_int(flags, env, hist, &Type::Any)?
139 && value.contains_int(flags, env, hist, &Type::Any)?)
140 }
141 (Self::Primitive(p0), Self::Error(_)) => Ok(p0.contains(Typ::Error)),
142 (Self::Error(e), Self::Primitive(p)) if *p == BitFlags::from(Typ::Error) => {
143 e.contains_int(flags, env, hist, &Type::Any)
144 }
145 (Self::Error(e0), Self::Error(e1)) => e0.contains_int(flags, env, hist, e1),
146 (Self::Tuple(t0), Self::Tuple(t1))
147 if t0.as_ptr().addr() == t1.as_ptr().addr() =>
148 {
149 Ok(true)
150 }
151 (Self::Tuple(t0), Self::Tuple(t1)) => Ok(t0.len() == t1.len()
152 && t0
153 .iter()
154 .zip(t1.iter())
155 .map(|(t0, t1)| t0.contains_int(flags, env, hist, t1))
156 .collect::<Result<AndAc>>()?
157 .0),
158 (Self::Struct(t0), Self::Struct(t1))
159 if t0.as_ptr().addr() == t1.as_ptr().addr() =>
160 {
161 Ok(true)
162 }
163 (Self::Struct(t0), Self::Struct(t1)) => {
164 Ok(t0.len() == t1.len() && {
165 t0.iter()
167 .zip(t1.iter())
168 .map(|((n0, t0), (n1, t1))| {
169 Ok(n0 == n1 && t0.contains_int(flags, env, hist, t1)?)
170 })
171 .collect::<Result<AndAc>>()?
172 .0
173 })
174 }
175 (Self::Variant(tg0, t0), Self::Variant(tg1, t1))
176 if tg0.as_ptr() == tg1.as_ptr()
177 && t0.as_ptr().addr() == t1.as_ptr().addr() =>
178 {
179 Ok(true)
180 }
181 (Self::Variant(tg0, t0), Self::Variant(tg1, t1)) => Ok(tg0 == tg1
182 && t0.len() == t1.len()
183 && t0
184 .iter()
185 .zip(t1.iter())
186 .map(|(t0, t1)| t0.contains_int(flags, env, hist, t1))
187 .collect::<Result<AndAc>>()?
188 .0),
189 (Self::ByRef(t0), Self::ByRef(t1)) => t0.contains_int(flags, env, hist, t1),
190 (Self::TVar(t0), Self::TVar(t1))
191 if t0.addr() == t1.addr() || t0.read().id == t1.read().id =>
192 {
193 Ok(true)
194 }
195 (tt0 @ Self::TVar(t0), tt1 @ Self::TVar(t1)) => {
196 #[derive(Debug)]
197 enum Act {
198 RightCopy,
199 RightAlias,
200 LeftAlias,
201 LeftCopy,
202 }
203 if t0.would_cycle(tt1) || t1.would_cycle(tt0) {
204 return Ok(true);
205 }
206 let act = {
207 let t0 = t0.read();
208 let t1 = t1.read();
209 let addr0 = Arc::as_ptr(&t0.typ).addr();
210 let addr1 = Arc::as_ptr(&t1.typ).addr();
211 if addr0 == addr1 {
212 return Ok(true);
213 }
214 if would_cycle_inner(addr0, tt1) || would_cycle_inner(addr1, tt0) {
215 return Ok(true);
216 }
217 let t0i = t0.typ.read();
218 let t1i = t1.typ.read();
219 match (&*t0i, &*t1i) {
220 (Some(t0), Some(t1)) => {
221 return t0.contains_int(flags, env, hist, &*t1)
222 }
223 (None, None) => {
224 if t0.frozen && t1.frozen {
225 return Ok(true);
226 }
227 if t0.frozen {
228 Act::RightAlias
229 } else {
230 Act::LeftAlias
231 }
232 }
233 (Some(_), None) => Act::RightCopy,
234 (None, Some(_)) => Act::LeftCopy,
235 }
236 };
237 match act {
238 Act::RightCopy if flags.contains(ContainsFlags::InitTVars) => {
239 t1.copy(t0)
240 }
241 Act::RightAlias if flags.contains(ContainsFlags::AliasTVars) => {
242 t1.alias(t0)
243 }
244 Act::LeftAlias if flags.contains(ContainsFlags::AliasTVars) => {
245 t0.alias(t1)
246 }
247 Act::LeftCopy if flags.contains(ContainsFlags::InitTVars) => {
248 t0.copy(t1)
249 }
250 Act::RightCopy | Act::RightAlias | Act::LeftAlias | Act::LeftCopy => {
251 ()
252 }
253 }
254 Ok(true)
255 }
256 (Self::TVar(t0), t1) if !t0.would_cycle(t1) => {
257 if let Some(t0) = &*t0.read().typ.read() {
258 return t0.contains_int(flags, env, hist, t1);
259 }
260 if flags.contains(ContainsFlags::InitTVars) {
261 *t0.read().typ.write() = Some(t1.clone());
262 }
263 Ok(true)
264 }
265 (t0, Self::TVar(t1)) if !t1.would_cycle(t0) => {
266 if let Some(t1) = &*t1.read().typ.read() {
267 return t0.contains_int(flags, env, hist, t1);
268 }
269 if flags.contains(ContainsFlags::InitTVars) {
270 *t1.read().typ.write() = Some(t0.clone());
271 }
272 Ok(true)
273 }
274 (Self::Set(s0), Self::Set(s1))
275 if s0.as_ptr().addr() == s1.as_ptr().addr() =>
276 {
277 Ok(true)
278 }
279 (t0, Self::Set(s)) => Ok(s
280 .iter()
281 .map(|t1| t0.contains_int(flags, env, hist, t1))
282 .collect::<Result<AndAc>>()?
283 .0),
284 (Self::Set(s), t) => Ok(s
285 .iter()
286 .fold(Ok::<_, anyhow::Error>(false), |acc, t0| {
287 Ok(acc? || t0.contains_int(flags, env, hist, t)?)
288 })?
289 || t.iter_prims().fold(Ok::<_, anyhow::Error>(true), |acc, t1| {
290 Ok(acc?
291 && s.iter().fold(Ok::<_, anyhow::Error>(false), |acc, t0| {
292 Ok(acc? || t0.contains_int(flags, env, hist, &t1)?)
293 })?)
294 })?),
295 (Self::Fn(f0), Self::Fn(f1)) => {
296 Ok(f0.as_ptr() == f1.as_ptr() || f0.contains_int(flags, env, hist, f1)?)
297 }
298 (_, Self::Any)
299 | (Self::Abstract { .. }, _)
300 | (_, Self::Abstract { .. })
301 | (_, Self::TVar(_))
302 | (Self::TVar(_), _)
303 | (Self::Fn(_), _)
304 | (Self::ByRef(_), _)
305 | (_, Self::ByRef(_))
306 | (_, Self::Fn(_))
307 | (Self::Tuple(_), Self::Array(_))
308 | (Self::Tuple(_), Self::Primitive(_))
309 | (Self::Tuple(_), Self::Struct(_))
310 | (Self::Tuple(_), Self::Variant(_, _))
311 | (Self::Tuple(_), Self::Error(_))
312 | (Self::Tuple(_), Self::Map { .. })
313 | (Self::Array(_), Self::Primitive(_))
314 | (Self::Array(_), Self::Tuple(_))
315 | (Self::Array(_), Self::Struct(_))
316 | (Self::Array(_), Self::Variant(_, _))
317 | (Self::Array(_), Self::Error(_))
318 | (Self::Array(_), Self::Map { .. })
319 | (Self::Struct(_), Self::Array(_))
320 | (Self::Struct(_), Self::Primitive(_))
321 | (Self::Struct(_), Self::Tuple(_))
322 | (Self::Struct(_), Self::Variant(_, _))
323 | (Self::Struct(_), Self::Error(_))
324 | (Self::Struct(_), Self::Map { .. })
325 | (Self::Variant(_, _), Self::Array(_))
326 | (Self::Variant(_, _), Self::Struct(_))
327 | (Self::Variant(_, _), Self::Primitive(_))
328 | (Self::Variant(_, _), Self::Tuple(_))
329 | (Self::Variant(_, _), Self::Error(_))
330 | (Self::Variant(_, _), Self::Map { .. })
331 | (Self::Error(_), Self::Array(_))
332 | (Self::Error(_), Self::Primitive(_))
333 | (Self::Error(_), Self::Struct(_))
334 | (Self::Error(_), Self::Variant(_, _))
335 | (Self::Error(_), Self::Tuple(_))
336 | (Self::Error(_), Self::Map { .. })
337 | (Self::Map { .. }, Self::Array(_))
338 | (Self::Map { .. }, Self::Primitive(_))
339 | (Self::Map { .. }, Self::Struct(_))
340 | (Self::Map { .. }, Self::Variant(_, _))
341 | (Self::Map { .. }, Self::Tuple(_))
342 | (Self::Map { .. }, Self::Error(_)) => Ok(false),
343 }
344 }
345
346 pub fn contains(&self, env: &Env, t: &Self) -> Result<bool> {
347 self.contains_int(
348 ContainsFlags::AliasTVars | ContainsFlags::InitTVars,
349 env,
350 &mut LPooled::take(),
351 t,
352 )
353 }
354
355 pub fn contains_with_flags(
356 &self,
357 flags: BitFlags<ContainsFlags>,
358 env: &Env,
359 t: &Self,
360 ) -> Result<bool> {
361 self.contains_int(flags, env, &mut LPooled::take(), t)
362 }
363}