Skip to main content

graphix_compiler/typ/
contains.rs

1use crate::{
2    env::Env,
3    format_with_flags,
4    typ::{tvar::would_cycle_inner, AndAc, RefHist, Type},
5    PrintFlag,
6};
7use anyhow::{bail, Result};
8use enumflags2::bitflags;
9use enumflags2::BitFlags;
10use fxhash::FxHashMap;
11use netidx::publisher::Typ;
12use poolshark::local::LPooled;
13use std::fmt::Debug;
14use triomphe::Arc;
15
/// Side effects on unbound type variables that a containment check
/// (`Type::contains_int`) is permitted to perform. When a flag is absent the
/// corresponding mutation is skipped.
#[derive(Debug, Clone, Copy)]
#[bitflags]
#[repr(u8)]
pub enum ContainsFlags {
    /// Allow two unbound type variables that are compared against each other
    /// to be aliased together (unless both are frozen).
    AliasTVars,
    /// Allow an unbound type variable to be bound: either to the type it is
    /// compared against, or (via `copy`) from a bound type variable on the
    /// other side of the comparison.
    InitTVars,
}
23
24impl Type {
25    pub fn check_contains(&self, env: &Env, t: &Self) -> Result<()> {
26        if self.contains(env, t)? {
27            Ok(())
28        } else {
29            format_with_flags(PrintFlag::DerefTVars | PrintFlag::ReplacePrims, || {
30                bail!("type mismatch {self} does not contain {t}")
31            })
32        }
33    }
34
35    pub(super) fn contains_int(
36        &self,
37        flags: BitFlags<ContainsFlags>,
38        env: &Env,
39        hist: &mut RefHist<FxHashMap<(Option<usize>, Option<usize>), bool>>,
40        t: &Self,
41    ) -> Result<bool> {
42        if (self as *const Type) == (t as *const Type) {
43            return Ok(true);
44        }
45        match (self, t) {
46            (
47                Self::Ref { scope: s0, name: n0, params: p0 },
48                Self::Ref { scope: s1, name: n1, params: p1 },
49            ) if s0 == s1 && n0 == n1 => Ok(p0.len() == p1.len()
50                && p0
51                    .iter()
52                    .zip(p1.iter())
53                    .map(|(t0, t1)| t0.contains_int(flags, env, hist, t1))
54                    .collect::<Result<AndAc>>()?
55                    .0),
56            (t0 @ Self::Ref { .. }, t1) | (t0, t1 @ Self::Ref { .. }) => {
57                let t0_id = hist.ref_id(t0, env);
58                let t1_id = hist.ref_id(t1, env);
59                let t0 = t0.lookup_ref(env)?;
60                let t1 = t1.lookup_ref(env)?;
61                match hist.get(&(t0_id, t1_id)) {
62                    Some(r) => Ok(*r),
63                    None => {
64                        hist.insert((t0_id, t1_id), true);
65                        let r = t0.contains_int(flags, env, hist, &t1);
66                        hist.remove(&(t0_id, t1_id));
67                        r
68                    }
69                }
70            }
71            (Self::TVar(t0), Self::Bottom) => {
72                if let Some(_) = &*t0.read().typ.read() {
73                    return Ok(true);
74                }
75                if flags.contains(ContainsFlags::InitTVars) {
76                    *t0.read().typ.write() = Some(Self::Bottom);
77                }
78                Ok(true)
79            }
80            (Self::Bottom, Self::TVar(t0)) => {
81                if let Some(Type::Bottom) = &*t0.read().typ.read() {
82                    return Ok(true);
83                }
84                if flags.contains(ContainsFlags::InitTVars) {
85                    *t0.read().typ.write() = Some(Self::Bottom);
86                    return Ok(true);
87                }
88                Ok(false)
89            }
90            (Self::Bottom, Self::Bottom) => Ok(true),
91            (Self::Bottom, _) => Ok(false),
92            (_, Self::Bottom) => Ok(true),
93            (Self::TVar(t0), Self::Any) => {
94                if let Some(t0) = &*t0.read().typ.read() {
95                    return t0.contains_int(flags, env, hist, t);
96                }
97                if flags.contains(ContainsFlags::InitTVars) {
98                    *t0.read().typ.write() = Some(Self::Any);
99                }
100                Ok(true)
101            }
102            (Self::Any, _) => Ok(true),
103            (
104                Self::Abstract { id: id0, params: p0 },
105                Self::Abstract { id: id1, params: p1 },
106            ) => Ok(id0 == id1
107                && p0.len() == p1.len()
108                && p0
109                    .iter()
110                    .zip(p1.iter())
111                    .map(|(t0, t1)| t0.contains_int(flags, env, hist, t1))
112                    .collect::<Result<AndAc>>()?
113                    .0),
114            (Self::Primitive(p0), Self::Primitive(p1)) => Ok(p0.contains(*p1)),
115            (
116                Self::Primitive(p),
117                Self::Array(_) | Self::Tuple(_) | Self::Struct(_) | Self::Variant(_, _),
118            ) => Ok(p.contains(Typ::Array)),
119            (Self::Array(t0), Self::Array(t1)) => t0.contains_int(flags, env, hist, t1),
120            (Self::Array(t0), Self::Primitive(p)) if *p == BitFlags::from(Typ::Array) => {
121                t0.contains_int(flags, env, hist, &Type::Any)
122            }
123            (Self::Map { key: k0, value: v0 }, Self::Map { key: k1, value: v1 }) => {
124                Ok(k0.contains_int(flags, env, hist, k1)?
125                    && v0.contains_int(flags, env, hist, v1)?)
126            }
127            (Self::Primitive(p), Self::Map { .. }) => Ok(p.contains(Typ::Map)),
128            (Self::Map { key, value }, Self::Primitive(p))
129                if *p == BitFlags::from(Typ::Map) =>
130            {
131                Ok(key.contains_int(flags, env, hist, &Type::Any)?
132                    && value.contains_int(flags, env, hist, &Type::Any)?)
133            }
134            (Self::Primitive(p0), Self::Error(_)) => Ok(p0.contains(Typ::Error)),
135            (Self::Error(e), Self::Primitive(p)) if *p == BitFlags::from(Typ::Error) => {
136                e.contains_int(flags, env, hist, &Type::Any)
137            }
138            (Self::Error(e0), Self::Error(e1)) => e0.contains_int(flags, env, hist, e1),
139            (Self::Tuple(t0), Self::Tuple(t1))
140                if t0.as_ptr().addr() == t1.as_ptr().addr() =>
141            {
142                Ok(true)
143            }
144            (Self::Tuple(t0), Self::Tuple(t1)) => Ok(t0.len() == t1.len()
145                && t0
146                    .iter()
147                    .zip(t1.iter())
148                    .map(|(t0, t1)| t0.contains_int(flags, env, hist, t1))
149                    .collect::<Result<AndAc>>()?
150                    .0),
151            (Self::Struct(t0), Self::Struct(t1))
152                if t0.as_ptr().addr() == t1.as_ptr().addr() =>
153            {
154                Ok(true)
155            }
156            (Self::Struct(t0), Self::Struct(t1)) => {
157                Ok(t0.len() == t1.len() && {
158                    // struct types are always sorted by field name
159                    t0.iter()
160                        .zip(t1.iter())
161                        .map(|((n0, t0), (n1, t1))| {
162                            Ok(n0 == n1 && t0.contains_int(flags, env, hist, t1)?)
163                        })
164                        .collect::<Result<AndAc>>()?
165                        .0
166                })
167            }
168            (Self::Variant(tg0, t0), Self::Variant(tg1, t1))
169                if tg0.as_ptr() == tg1.as_ptr()
170                    && t0.as_ptr().addr() == t1.as_ptr().addr() =>
171            {
172                Ok(true)
173            }
174            (Self::Variant(tg0, t0), Self::Variant(tg1, t1)) => Ok(tg0 == tg1
175                && t0.len() == t1.len()
176                && t0
177                    .iter()
178                    .zip(t1.iter())
179                    .map(|(t0, t1)| t0.contains_int(flags, env, hist, t1))
180                    .collect::<Result<AndAc>>()?
181                    .0),
182            (Self::ByRef(t0), Self::ByRef(t1)) => t0.contains_int(flags, env, hist, t1),
183            (Self::TVar(t0), Self::TVar(t1))
184                if t0.addr() == t1.addr() || t0.read().id == t1.read().id =>
185            {
186                Ok(true)
187            }
188            (tt0 @ Self::TVar(t0), tt1 @ Self::TVar(t1)) => {
189                #[derive(Debug)]
190                enum Act {
191                    RightCopy,
192                    RightAlias,
193                    LeftAlias,
194                    LeftCopy,
195                }
196                if t0.would_cycle(tt1) || t1.would_cycle(tt0) {
197                    return Ok(true);
198                }
199                let act = {
200                    let t0 = t0.read();
201                    let t1 = t1.read();
202                    let addr0 = Arc::as_ptr(&t0.typ).addr();
203                    let addr1 = Arc::as_ptr(&t1.typ).addr();
204                    if addr0 == addr1 {
205                        return Ok(true);
206                    }
207                    if would_cycle_inner(addr0, tt1) || would_cycle_inner(addr1, tt0) {
208                        return Ok(true);
209                    }
210                    let t0i = t0.typ.read();
211                    let t1i = t1.typ.read();
212                    match (&*t0i, &*t1i) {
213                        (Some(t0), Some(t1)) => {
214                            return t0.contains_int(flags, env, hist, &*t1)
215                        }
216                        (None, None) => {
217                            if t0.frozen && t1.frozen {
218                                return Ok(true);
219                            }
220                            if t0.frozen {
221                                Act::RightAlias
222                            } else {
223                                Act::LeftAlias
224                            }
225                        }
226                        (Some(_), None) => Act::RightCopy,
227                        (None, Some(_)) => Act::LeftCopy,
228                    }
229                };
230                match act {
231                    Act::RightCopy if flags.contains(ContainsFlags::InitTVars) => {
232                        t1.copy(t0)
233                    }
234                    Act::RightAlias if flags.contains(ContainsFlags::AliasTVars) => {
235                        t1.alias(t0)
236                    }
237                    Act::LeftAlias if flags.contains(ContainsFlags::AliasTVars) => {
238                        t0.alias(t1)
239                    }
240                    Act::LeftCopy if flags.contains(ContainsFlags::InitTVars) => {
241                        t0.copy(t1)
242                    }
243                    Act::RightCopy | Act::RightAlias | Act::LeftAlias | Act::LeftCopy => {
244                        ()
245                    }
246                }
247                Ok(true)
248            }
249            (Self::TVar(t0), t1) if !t0.would_cycle(t1) => {
250                if let Some(t0) = &*t0.read().typ.read() {
251                    return t0.contains_int(flags, env, hist, t1);
252                }
253                if flags.contains(ContainsFlags::InitTVars) {
254                    *t0.read().typ.write() = Some(t1.clone());
255                }
256                Ok(true)
257            }
258            (t0, Self::TVar(t1)) if !t1.would_cycle(t0) => {
259                if let Some(t1) = &*t1.read().typ.read() {
260                    return t0.contains_int(flags, env, hist, t1);
261                }
262                if flags.contains(ContainsFlags::InitTVars) {
263                    *t1.read().typ.write() = Some(t0.clone());
264                }
265                Ok(true)
266            }
267            (Self::Set(s0), Self::Set(s1))
268                if s0.as_ptr().addr() == s1.as_ptr().addr() =>
269            {
270                Ok(true)
271            }
272            (t0, Self::Set(s)) => Ok(s
273                .iter()
274                .map(|t1| t0.contains_int(flags, env, hist, t1))
275                .collect::<Result<AndAc>>()?
276                .0),
277            (Self::Set(s), t) => Ok(s
278                .iter()
279                .fold(Ok::<_, anyhow::Error>(false), |acc, t0| {
280                    Ok(acc? || t0.contains_int(flags, env, hist, t)?)
281                })?
282                || t.iter_prims().fold(Ok::<_, anyhow::Error>(true), |acc, t1| {
283                    Ok(acc?
284                        && s.iter().fold(Ok::<_, anyhow::Error>(false), |acc, t0| {
285                            Ok(acc? || t0.contains_int(flags, env, hist, &t1)?)
286                        })?)
287                })?),
288            (Self::Fn(f0), Self::Fn(f1)) => {
289                Ok(f0.as_ptr() == f1.as_ptr() || f0.contains_int(flags, env, hist, f1)?)
290            }
291            (_, Self::Any)
292            | (Self::Abstract { .. }, _)
293            | (_, Self::Abstract { .. })
294            | (_, Self::TVar(_))
295            | (Self::TVar(_), _)
296            | (Self::Fn(_), _)
297            | (Self::ByRef(_), _)
298            | (_, Self::ByRef(_))
299            | (_, Self::Fn(_))
300            | (Self::Tuple(_), Self::Array(_))
301            | (Self::Tuple(_), Self::Primitive(_))
302            | (Self::Tuple(_), Self::Struct(_))
303            | (Self::Tuple(_), Self::Variant(_, _))
304            | (Self::Tuple(_), Self::Error(_))
305            | (Self::Tuple(_), Self::Map { .. })
306            | (Self::Array(_), Self::Primitive(_))
307            | (Self::Array(_), Self::Tuple(_))
308            | (Self::Array(_), Self::Struct(_))
309            | (Self::Array(_), Self::Variant(_, _))
310            | (Self::Array(_), Self::Error(_))
311            | (Self::Array(_), Self::Map { .. })
312            | (Self::Struct(_), Self::Array(_))
313            | (Self::Struct(_), Self::Primitive(_))
314            | (Self::Struct(_), Self::Tuple(_))
315            | (Self::Struct(_), Self::Variant(_, _))
316            | (Self::Struct(_), Self::Error(_))
317            | (Self::Struct(_), Self::Map { .. })
318            | (Self::Variant(_, _), Self::Array(_))
319            | (Self::Variant(_, _), Self::Struct(_))
320            | (Self::Variant(_, _), Self::Primitive(_))
321            | (Self::Variant(_, _), Self::Tuple(_))
322            | (Self::Variant(_, _), Self::Error(_))
323            | (Self::Variant(_, _), Self::Map { .. })
324            | (Self::Error(_), Self::Array(_))
325            | (Self::Error(_), Self::Primitive(_))
326            | (Self::Error(_), Self::Struct(_))
327            | (Self::Error(_), Self::Variant(_, _))
328            | (Self::Error(_), Self::Tuple(_))
329            | (Self::Error(_), Self::Map { .. })
330            | (Self::Map { .. }, Self::Array(_))
331            | (Self::Map { .. }, Self::Primitive(_))
332            | (Self::Map { .. }, Self::Struct(_))
333            | (Self::Map { .. }, Self::Variant(_, _))
334            | (Self::Map { .. }, Self::Tuple(_))
335            | (Self::Map { .. }, Self::Error(_)) => Ok(false),
336        }
337    }
338
339    pub fn contains(&self, env: &Env, t: &Self) -> Result<bool> {
340        self.contains_int(
341            ContainsFlags::AliasTVars | ContainsFlags::InitTVars,
342            env,
343            &mut RefHist::new(LPooled::take()),
344            t,
345        )
346    }
347
348    pub fn contains_with_flags(
349        &self,
350        flags: BitFlags<ContainsFlags>,
351        env: &Env,
352        t: &Self,
353    ) -> Result<bool> {
354        self.contains_int(flags, env, &mut RefHist::new(LPooled::take()), t)
355    }
356}