1use crate::{
2 env::Env,
3 format_with_flags,
4 typ::{tvar::would_cycle_inner, AndAc, RefHist, Type},
5 PrintFlag,
6};
7use anyhow::{bail, Result};
8use enumflags2::bitflags;
9use enumflags2::BitFlags;
10use fxhash::FxHashMap;
11use netidx::publisher::Typ;
12use poolshark::local::LPooled;
13use std::fmt::Debug;
14use triomphe::Arc;
15
/// Options controlling which side effects the type containment check may
/// perform on type variables (see [`Type::contains_with_flags`]).
///
/// NOTE: bit values are assigned by `#[bitflags]`; reordering the variants
/// may change them.
#[derive(Debug, Clone, Copy)]
#[bitflags]
#[repr(u8)]
pub enum ContainsFlags {
    /// Allow two unbound type variables to be aliased to one another when
    /// they are compared.
    AliasTVars,
    /// Allow an unbound type variable to be bound, either by writing a type
    /// into it or by copying another variable's binding into it.
    InitTVars,
}
23
24impl Type {
25 pub fn check_contains(&self, env: &Env, t: &Self) -> Result<()> {
26 if self.contains(env, t)? {
27 Ok(())
28 } else {
29 format_with_flags(PrintFlag::DerefTVars | PrintFlag::ReplacePrims, || {
30 bail!("type mismatch {self} does not contain {t}")
31 })
32 }
33 }
34
35 pub(super) fn contains_int(
36 &self,
37 flags: BitFlags<ContainsFlags>,
38 env: &Env,
39 hist: &mut RefHist<FxHashMap<(Option<usize>, Option<usize>), bool>>,
40 t: &Self,
41 ) -> Result<bool> {
42 if (self as *const Type) == (t as *const Type) {
43 return Ok(true);
44 }
45 match (self, t) {
46 (
47 Self::Ref { scope: s0, name: n0, params: p0 },
48 Self::Ref { scope: s1, name: n1, params: p1 },
49 ) if s0 == s1 && n0 == n1 => Ok(p0.len() == p1.len()
50 && p0
51 .iter()
52 .zip(p1.iter())
53 .map(|(t0, t1)| t0.contains_int(flags, env, hist, t1))
54 .collect::<Result<AndAc>>()?
55 .0),
56 (t0 @ Self::Ref { .. }, t1) | (t0, t1 @ Self::Ref { .. }) => {
57 let t0_id = hist.ref_id(t0, env);
58 let t1_id = hist.ref_id(t1, env);
59 let t0 = t0.lookup_ref(env)?;
60 let t1 = t1.lookup_ref(env)?;
61 match hist.get(&(t0_id, t1_id)) {
62 Some(r) => Ok(*r),
63 None => {
64 hist.insert((t0_id, t1_id), true);
65 let r = t0.contains_int(flags, env, hist, &t1);
66 hist.remove(&(t0_id, t1_id));
67 r
68 }
69 }
70 }
71 (Self::TVar(t0), Self::Bottom) => {
72 if let Some(_) = &*t0.read().typ.read() {
73 return Ok(true);
74 }
75 if flags.contains(ContainsFlags::InitTVars) {
76 *t0.read().typ.write() = Some(Self::Bottom);
77 }
78 Ok(true)
79 }
80 (Self::Bottom, Self::TVar(t0)) => {
81 if let Some(Type::Bottom) = &*t0.read().typ.read() {
82 return Ok(true);
83 }
84 if flags.contains(ContainsFlags::InitTVars) {
85 *t0.read().typ.write() = Some(Self::Bottom);
86 return Ok(true);
87 }
88 Ok(false)
89 }
90 (Self::Bottom, Self::Bottom) => Ok(true),
91 (Self::Bottom, _) => Ok(false),
92 (_, Self::Bottom) => Ok(true),
93 (Self::TVar(t0), Self::Any) => {
94 if let Some(t0) = &*t0.read().typ.read() {
95 return t0.contains_int(flags, env, hist, t);
96 }
97 if flags.contains(ContainsFlags::InitTVars) {
98 *t0.read().typ.write() = Some(Self::Any);
99 }
100 Ok(true)
101 }
102 (Self::Any, _) => Ok(true),
103 (
104 Self::Abstract { id: id0, params: p0 },
105 Self::Abstract { id: id1, params: p1 },
106 ) => Ok(id0 == id1
107 && p0.len() == p1.len()
108 && p0
109 .iter()
110 .zip(p1.iter())
111 .map(|(t0, t1)| t0.contains_int(flags, env, hist, t1))
112 .collect::<Result<AndAc>>()?
113 .0),
114 (Self::Primitive(p0), Self::Primitive(p1)) => Ok(p0.contains(*p1)),
115 (
116 Self::Primitive(p),
117 Self::Array(_) | Self::Tuple(_) | Self::Struct(_) | Self::Variant(_, _),
118 ) => Ok(p.contains(Typ::Array)),
119 (Self::Array(t0), Self::Array(t1)) => t0.contains_int(flags, env, hist, t1),
120 (Self::Array(t0), Self::Primitive(p)) if *p == BitFlags::from(Typ::Array) => {
121 t0.contains_int(flags, env, hist, &Type::Any)
122 }
123 (Self::Map { key: k0, value: v0 }, Self::Map { key: k1, value: v1 }) => {
124 Ok(k0.contains_int(flags, env, hist, k1)?
125 && v0.contains_int(flags, env, hist, v1)?)
126 }
127 (Self::Primitive(p), Self::Map { .. }) => Ok(p.contains(Typ::Map)),
128 (Self::Map { key, value }, Self::Primitive(p))
129 if *p == BitFlags::from(Typ::Map) =>
130 {
131 Ok(key.contains_int(flags, env, hist, &Type::Any)?
132 && value.contains_int(flags, env, hist, &Type::Any)?)
133 }
134 (Self::Primitive(p0), Self::Error(_)) => Ok(p0.contains(Typ::Error)),
135 (Self::Error(e), Self::Primitive(p)) if *p == BitFlags::from(Typ::Error) => {
136 e.contains_int(flags, env, hist, &Type::Any)
137 }
138 (Self::Error(e0), Self::Error(e1)) => e0.contains_int(flags, env, hist, e1),
139 (Self::Tuple(t0), Self::Tuple(t1))
140 if t0.as_ptr().addr() == t1.as_ptr().addr() =>
141 {
142 Ok(true)
143 }
144 (Self::Tuple(t0), Self::Tuple(t1)) => Ok(t0.len() == t1.len()
145 && t0
146 .iter()
147 .zip(t1.iter())
148 .map(|(t0, t1)| t0.contains_int(flags, env, hist, t1))
149 .collect::<Result<AndAc>>()?
150 .0),
151 (Self::Struct(t0), Self::Struct(t1))
152 if t0.as_ptr().addr() == t1.as_ptr().addr() =>
153 {
154 Ok(true)
155 }
156 (Self::Struct(t0), Self::Struct(t1)) => {
157 Ok(t0.len() == t1.len() && {
158 t0.iter()
160 .zip(t1.iter())
161 .map(|((n0, t0), (n1, t1))| {
162 Ok(n0 == n1 && t0.contains_int(flags, env, hist, t1)?)
163 })
164 .collect::<Result<AndAc>>()?
165 .0
166 })
167 }
168 (Self::Variant(tg0, t0), Self::Variant(tg1, t1))
169 if tg0.as_ptr() == tg1.as_ptr()
170 && t0.as_ptr().addr() == t1.as_ptr().addr() =>
171 {
172 Ok(true)
173 }
174 (Self::Variant(tg0, t0), Self::Variant(tg1, t1)) => Ok(tg0 == tg1
175 && t0.len() == t1.len()
176 && t0
177 .iter()
178 .zip(t1.iter())
179 .map(|(t0, t1)| t0.contains_int(flags, env, hist, t1))
180 .collect::<Result<AndAc>>()?
181 .0),
182 (Self::ByRef(t0), Self::ByRef(t1)) => t0.contains_int(flags, env, hist, t1),
183 (Self::TVar(t0), Self::TVar(t1))
184 if t0.addr() == t1.addr() || t0.read().id == t1.read().id =>
185 {
186 Ok(true)
187 }
188 (tt0 @ Self::TVar(t0), tt1 @ Self::TVar(t1)) => {
189 #[derive(Debug)]
190 enum Act {
191 RightCopy,
192 RightAlias,
193 LeftAlias,
194 LeftCopy,
195 }
196 if t0.would_cycle(tt1) || t1.would_cycle(tt0) {
197 return Ok(true);
198 }
199 let act = {
200 let t0 = t0.read();
201 let t1 = t1.read();
202 let addr0 = Arc::as_ptr(&t0.typ).addr();
203 let addr1 = Arc::as_ptr(&t1.typ).addr();
204 if addr0 == addr1 {
205 return Ok(true);
206 }
207 if would_cycle_inner(addr0, tt1) || would_cycle_inner(addr1, tt0) {
208 return Ok(true);
209 }
210 let t0i = t0.typ.read();
211 let t1i = t1.typ.read();
212 match (&*t0i, &*t1i) {
213 (Some(t0), Some(t1)) => {
214 return t0.contains_int(flags, env, hist, &*t1)
215 }
216 (None, None) => {
217 if t0.frozen && t1.frozen {
218 return Ok(true);
219 }
220 if t0.frozen {
221 Act::RightAlias
222 } else {
223 Act::LeftAlias
224 }
225 }
226 (Some(_), None) => Act::RightCopy,
227 (None, Some(_)) => Act::LeftCopy,
228 }
229 };
230 match act {
231 Act::RightCopy if flags.contains(ContainsFlags::InitTVars) => {
232 t1.copy(t0)
233 }
234 Act::RightAlias if flags.contains(ContainsFlags::AliasTVars) => {
235 t1.alias(t0)
236 }
237 Act::LeftAlias if flags.contains(ContainsFlags::AliasTVars) => {
238 t0.alias(t1)
239 }
240 Act::LeftCopy if flags.contains(ContainsFlags::InitTVars) => {
241 t0.copy(t1)
242 }
243 Act::RightCopy | Act::RightAlias | Act::LeftAlias | Act::LeftCopy => {
244 ()
245 }
246 }
247 Ok(true)
248 }
249 (Self::TVar(t0), t1) if !t0.would_cycle(t1) => {
250 if let Some(t0) = &*t0.read().typ.read() {
251 return t0.contains_int(flags, env, hist, t1);
252 }
253 if flags.contains(ContainsFlags::InitTVars) {
254 *t0.read().typ.write() = Some(t1.clone());
255 }
256 Ok(true)
257 }
258 (t0, Self::TVar(t1)) if !t1.would_cycle(t0) => {
259 if let Some(t1) = &*t1.read().typ.read() {
260 return t0.contains_int(flags, env, hist, t1);
261 }
262 if flags.contains(ContainsFlags::InitTVars) {
263 *t1.read().typ.write() = Some(t0.clone());
264 }
265 Ok(true)
266 }
267 (Self::Set(s0), Self::Set(s1))
268 if s0.as_ptr().addr() == s1.as_ptr().addr() =>
269 {
270 Ok(true)
271 }
272 (t0, Self::Set(s)) => Ok(s
273 .iter()
274 .map(|t1| t0.contains_int(flags, env, hist, t1))
275 .collect::<Result<AndAc>>()?
276 .0),
277 (Self::Set(s), t) => Ok(s
278 .iter()
279 .fold(Ok::<_, anyhow::Error>(false), |acc, t0| {
280 Ok(acc? || t0.contains_int(flags, env, hist, t)?)
281 })?
282 || t.iter_prims().fold(Ok::<_, anyhow::Error>(true), |acc, t1| {
283 Ok(acc?
284 && s.iter().fold(Ok::<_, anyhow::Error>(false), |acc, t0| {
285 Ok(acc? || t0.contains_int(flags, env, hist, &t1)?)
286 })?)
287 })?),
288 (Self::Fn(f0), Self::Fn(f1)) => {
289 Ok(f0.as_ptr() == f1.as_ptr() || f0.contains_int(flags, env, hist, f1)?)
290 }
291 (_, Self::Any)
292 | (Self::Abstract { .. }, _)
293 | (_, Self::Abstract { .. })
294 | (_, Self::TVar(_))
295 | (Self::TVar(_), _)
296 | (Self::Fn(_), _)
297 | (Self::ByRef(_), _)
298 | (_, Self::ByRef(_))
299 | (_, Self::Fn(_))
300 | (Self::Tuple(_), Self::Array(_))
301 | (Self::Tuple(_), Self::Primitive(_))
302 | (Self::Tuple(_), Self::Struct(_))
303 | (Self::Tuple(_), Self::Variant(_, _))
304 | (Self::Tuple(_), Self::Error(_))
305 | (Self::Tuple(_), Self::Map { .. })
306 | (Self::Array(_), Self::Primitive(_))
307 | (Self::Array(_), Self::Tuple(_))
308 | (Self::Array(_), Self::Struct(_))
309 | (Self::Array(_), Self::Variant(_, _))
310 | (Self::Array(_), Self::Error(_))
311 | (Self::Array(_), Self::Map { .. })
312 | (Self::Struct(_), Self::Array(_))
313 | (Self::Struct(_), Self::Primitive(_))
314 | (Self::Struct(_), Self::Tuple(_))
315 | (Self::Struct(_), Self::Variant(_, _))
316 | (Self::Struct(_), Self::Error(_))
317 | (Self::Struct(_), Self::Map { .. })
318 | (Self::Variant(_, _), Self::Array(_))
319 | (Self::Variant(_, _), Self::Struct(_))
320 | (Self::Variant(_, _), Self::Primitive(_))
321 | (Self::Variant(_, _), Self::Tuple(_))
322 | (Self::Variant(_, _), Self::Error(_))
323 | (Self::Variant(_, _), Self::Map { .. })
324 | (Self::Error(_), Self::Array(_))
325 | (Self::Error(_), Self::Primitive(_))
326 | (Self::Error(_), Self::Struct(_))
327 | (Self::Error(_), Self::Variant(_, _))
328 | (Self::Error(_), Self::Tuple(_))
329 | (Self::Error(_), Self::Map { .. })
330 | (Self::Map { .. }, Self::Array(_))
331 | (Self::Map { .. }, Self::Primitive(_))
332 | (Self::Map { .. }, Self::Struct(_))
333 | (Self::Map { .. }, Self::Variant(_, _))
334 | (Self::Map { .. }, Self::Tuple(_))
335 | (Self::Map { .. }, Self::Error(_)) => Ok(false),
336 }
337 }
338
339 pub fn contains(&self, env: &Env, t: &Self) -> Result<bool> {
340 self.contains_int(
341 ContainsFlags::AliasTVars | ContainsFlags::InitTVars,
342 env,
343 &mut RefHist::new(LPooled::take()),
344 t,
345 )
346 }
347
348 pub fn contains_with_flags(
349 &self,
350 flags: BitFlags<ContainsFlags>,
351 env: &Env,
352 t: &Self,
353 ) -> Result<bool> {
354 self.contains_int(flags, env, &mut RefHist::new(LPooled::take()), t)
355 }
356}