// graphix_compiler/typ/contains.rs

1use crate::{
2    env::Env,
3    format_with_flags,
4    typ::{tvar::would_cycle_inner, AndAc, RefHist, Type, TypeRef},
5    PrintFlag,
6};
7use ahash::AHashMap;
8use anyhow::{bail, Result};
9use enumflags2::bitflags;
10use enumflags2::BitFlags;
11use netidx::publisher::Typ;
12use poolshark::local::LPooled;
13use std::fmt::Debug;
14use triomphe::Arc;
15
16#[derive(Debug, Clone, Copy)]
17#[bitflags]
18#[repr(u8)]
19pub enum ContainsFlags {
20    AliasTVars,
21    InitTVars,
22}
23
24impl Type {
25    pub fn check_contains(&self, env: &Env, t: &Self) -> Result<()> {
26        if self.contains(env, t)? {
27            Ok(())
28        } else {
29            format_with_flags(PrintFlag::DerefTVars | PrintFlag::ReplacePrims, || {
30                bail!("type mismatch {self} does not contain {t}")
31            })
32        }
33    }
34
35    pub(super) fn contains_int(
36        &self,
37        flags: BitFlags<ContainsFlags>,
38        env: &Env,
39        hist: &mut RefHist<AHashMap<(Option<usize>, Option<usize>), bool>>,
40        t: &Self,
41    ) -> Result<bool> {
42        if (self as *const Type) == (t as *const Type) {
43            return Ok(true);
44        }
45        match (self, t) {
46            (
47                Self::Ref(TypeRef { scope: s0, name: n0, params: p0, .. }),
48                Self::Ref(TypeRef { scope: s1, name: n1, params: p1, .. }),
49            ) if s0 == s1 && n0 == n1 => Ok(p0.len() == p1.len()
50                && p0
51                    .iter()
52                    .zip(p1.iter())
53                    .map(|(t0, t1)| t0.contains_int(flags, env, hist, t1))
54                    .collect::<Result<AndAc>>()?
55                    .0),
56            (t0 @ Self::Ref(TypeRef { .. }), t1)
57            | (t0, t1 @ Self::Ref(TypeRef { .. })) => {
58                let t0_id = hist.ref_id(t0, env);
59                let t1_id = hist.ref_id(t1, env);
60                let t0 = t0.lookup_ref(env)?;
61                let t1 = t1.lookup_ref(env)?;
62                match hist.get(&(t0_id, t1_id)) {
63                    Some(r) => Ok(*r),
64                    None => {
65                        hist.insert((t0_id, t1_id), true);
66                        let r = t0.contains_int(flags, env, hist, &t1);
67                        hist.remove(&(t0_id, t1_id));
68                        r
69                    }
70                }
71            }
72            (Self::TVar(t0), Self::Bottom) => {
73                if let Some(_) = &*t0.read().typ.read() {
74                    return Ok(true);
75                }
76                if flags.contains(ContainsFlags::InitTVars) {
77                    *t0.read().typ.write() = Some(Self::Bottom);
78                }
79                Ok(true)
80            }
81            (Self::Bottom, Self::TVar(t0)) => {
82                if let Some(Type::Bottom) = &*t0.read().typ.read() {
83                    return Ok(true);
84                }
85                if flags.contains(ContainsFlags::InitTVars) {
86                    *t0.read().typ.write() = Some(Self::Bottom);
87                }
88                Ok(true)
89            }
90            (Self::Bottom, Self::Bottom) => Ok(true),
91            (Self::Bottom, _) => Ok(false),
92            (_, Self::Bottom) => Ok(true),
93            (Self::TVar(t0), Self::Any) => {
94                if let Some(t0) = &*t0.read().typ.read() {
95                    return t0.contains_int(flags, env, hist, t);
96                }
97                if flags.contains(ContainsFlags::InitTVars) {
98                    *t0.read().typ.write() = Some(Self::Any);
99                }
100                Ok(true)
101            }
102            (Self::Any, _) => Ok(true),
103            (
104                Self::Abstract { id: id0, params: p0 },
105                Self::Abstract { id: id1, params: p1 },
106            ) => Ok(id0 == id1
107                && p0.len() == p1.len()
108                && p0
109                    .iter()
110                    .zip(p1.iter())
111                    .map(|(t0, t1)| t0.contains_int(flags, env, hist, t1))
112                    .collect::<Result<AndAc>>()?
113                    .0),
114            (Self::Primitive(p0), Self::Primitive(p1)) => Ok(p0.contains(*p1)),
115            (
116                Self::Primitive(p),
117                Self::Array(_) | Self::Tuple(_) | Self::Struct(_) | Self::Variant(_, _),
118            ) => Ok(p.contains(Typ::Array)),
119            (Self::Array(t0), Self::Array(t1)) => t0.contains_int(flags, env, hist, t1),
120            (Self::Array(t0), Self::Primitive(p)) if *p == BitFlags::from(Typ::Array) => {
121                t0.contains_int(flags, env, hist, &Type::Any)
122            }
123            (Self::Map { key: k0, value: v0 }, Self::Map { key: k1, value: v1 }) => {
124                Ok(k0.contains_int(flags, env, hist, k1)?
125                    && v0.contains_int(flags, env, hist, v1)?)
126            }
127            (Self::Primitive(p), Self::Map { .. }) => Ok(p.contains(Typ::Map)),
128            (Self::Map { key, value }, Self::Primitive(p))
129                if *p == BitFlags::from(Typ::Map) =>
130            {
131                Ok(key.contains_int(flags, env, hist, &Type::Any)?
132                    && value.contains_int(flags, env, hist, &Type::Any)?)
133            }
134            (Self::Primitive(p0), Self::Error(_)) => Ok(p0.contains(Typ::Error)),
135            (Self::Error(e), Self::Primitive(p)) if *p == BitFlags::from(Typ::Error) => {
136                e.contains_int(flags, env, hist, &Type::Any)
137            }
138            (Self::Error(e0), Self::Error(e1)) => e0.contains_int(flags, env, hist, e1),
139            (Self::Tuple(t0), Self::Tuple(t1)) if Arc::ptr_eq(t0, t1) => Ok(true),
140            (Self::Tuple(t0), Self::Tuple(t1)) => Ok(t0.len() == t1.len()
141                && t0
142                    .iter()
143                    .zip(t1.iter())
144                    .map(|(t0, t1)| t0.contains_int(flags, env, hist, t1))
145                    .collect::<Result<AndAc>>()?
146                    .0),
147            (Self::Struct(t0), Self::Struct(t1)) if Arc::ptr_eq(t0, t1) => Ok(true),
148            (Self::Struct(t0), Self::Struct(t1)) => {
149                Ok(t0.len() == t1.len() && {
150                    // struct types are always sorted by field name
151                    t0.iter()
152                        .zip(t1.iter())
153                        .map(|((n0, t0), (n1, t1))| {
154                            Ok(n0 == n1 && t0.contains_int(flags, env, hist, t1)?)
155                        })
156                        .collect::<Result<AndAc>>()?
157                        .0
158                })
159            }
160            (Self::Variant(tg0, t0), Self::Variant(tg1, t1))
161                if tg0.as_ptr() == tg1.as_ptr() && Arc::ptr_eq(t0, t1) =>
162            {
163                Ok(true)
164            }
165            (Self::Variant(tg0, t0), Self::Variant(tg1, t1)) => Ok(tg0 == tg1
166                && t0.len() == t1.len()
167                && t0
168                    .iter()
169                    .zip(t1.iter())
170                    .map(|(t0, t1)| t0.contains_int(flags, env, hist, t1))
171                    .collect::<Result<AndAc>>()?
172                    .0),
173            (Self::ByRef(t0), Self::ByRef(t1)) => t0.contains_int(flags, env, hist, t1),
174            (Self::TVar(t0), Self::TVar(t1))
175                if t0.addr() == t1.addr() || t0.read().id == t1.read().id =>
176            {
177                Ok(true)
178            }
179            (tt0 @ Self::TVar(t0), tt1 @ Self::TVar(t1)) => {
180                #[derive(Debug)]
181                enum Act {
182                    RightCopy,
183                    RightAlias,
184                    LeftAlias,
185                    LeftCopy,
186                }
187                if t0.would_cycle(tt1) || t1.would_cycle(tt0) {
188                    return Ok(true);
189                }
190                let act = {
191                    let t0 = t0.read();
192                    let t1 = t1.read();
193                    let addr0 = Arc::as_ptr(&t0.typ).addr();
194                    let addr1 = Arc::as_ptr(&t1.typ).addr();
195                    if addr0 == addr1 {
196                        return Ok(true);
197                    }
198                    if would_cycle_inner(addr0, tt1) || would_cycle_inner(addr1, tt0) {
199                        return Ok(true);
200                    }
201                    let t0i = t0.typ.read();
202                    let t1i = t1.typ.read();
203                    match (&*t0i, &*t1i) {
204                        (Some(t0), Some(t1)) => {
205                            return t0.contains_int(flags, env, hist, &*t1)
206                        }
207                        (None, None) => {
208                            if t0.frozen && t1.frozen {
209                                return Ok(true);
210                            }
211                            if t0.frozen {
212                                Act::RightAlias
213                            } else {
214                                Act::LeftAlias
215                            }
216                        }
217                        (Some(_), None) => Act::RightCopy,
218                        (None, Some(_)) => Act::LeftCopy,
219                    }
220                };
221                match act {
222                    Act::RightCopy if flags.contains(ContainsFlags::InitTVars) => {
223                        t1.copy(t0)
224                    }
225                    Act::RightAlias if flags.contains(ContainsFlags::AliasTVars) => {
226                        t1.alias(t0)
227                    }
228                    Act::LeftAlias if flags.contains(ContainsFlags::AliasTVars) => {
229                        t0.alias(t1)
230                    }
231                    Act::LeftCopy if flags.contains(ContainsFlags::InitTVars) => {
232                        t0.copy(t1)
233                    }
234                    Act::RightCopy | Act::RightAlias | Act::LeftAlias | Act::LeftCopy => {
235                        ()
236                    }
237                }
238                Ok(true)
239            }
240            (Self::TVar(t0), t1) if !t0.would_cycle(t1) => {
241                if let Some(t0) = &*t0.read().typ.read() {
242                    return t0.contains_int(flags, env, hist, t1);
243                }
244                if flags.contains(ContainsFlags::InitTVars) {
245                    *t0.read().typ.write() = Some(t1.clone());
246                }
247                Ok(true)
248            }
249            (t0, Self::TVar(t1)) if !t1.would_cycle(t0) => {
250                if let Some(t1) = &*t1.read().typ.read() {
251                    return t0.contains_int(flags, env, hist, t1);
252                }
253                if flags.contains(ContainsFlags::InitTVars) {
254                    *t1.read().typ.write() = Some(t0.clone());
255                }
256                Ok(true)
257            }
258            (Self::Set(s0), Self::Set(s1)) if Arc::ptr_eq(s0, s1) => Ok(true),
259            (t0 @ Self::Set(_), t1 @ Self::Set(_)) if t0 == t1 => {
260                if flags.contains(ContainsFlags::InitTVars) {
261                    let mut known = LPooled::take();
262                    t0.alias_tvars(&mut known);
263                    t1.alias_tvars(&mut known);
264                }
265                Ok(true)
266            }
267            (t0, Self::Set(s)) => Ok(s
268                .iter()
269                .map(|t1| t0.contains_int(flags, env, hist, t1))
270                .collect::<Result<AndAc>>()?
271                .0),
272            (Self::Set(s), t) => {
273                let probe = BitFlags::empty();
274                let whole_ok =
275                    s.iter().fold(Ok::<_, anyhow::Error>(false), |acc, t0| {
276                        Ok(acc? || t0.contains_int(probe, env, hist, t)?)
277                    })?;
278                let prims_ok =
279                    t.iter_prims().fold(Ok::<_, anyhow::Error>(true), |acc, t1| {
280                        Ok(acc?
281                            && s.iter().fold(
282                                Ok::<_, anyhow::Error>(false),
283                                |acc, t0| {
284                                    Ok(acc? || t0.contains_int(probe, env, hist, &t1)?)
285                                },
286                            )?)
287                    })?;
288                match (whole_ok, prims_ok) {
289                    (false, false) => Ok(false),
290                    // prefer prims when valid — narrowest TVar bindings
291                    (_, true) => Ok(t.iter_prims().fold(
292                        Ok::<_, anyhow::Error>(true),
293                        |acc, t1| {
294                            Ok(acc?
295                                && s.iter().fold(
296                                    Ok::<_, anyhow::Error>(false),
297                                    |acc, t0| {
298                                        Ok(acc?
299                                            || t0.contains_int(flags, env, hist, &t1)?)
300                                    },
301                                )?)
302                        },
303                    )?),
304                    (true, false) => {
305                        Ok(s.iter().fold(Ok::<_, anyhow::Error>(false), |acc, t0| {
306                            Ok(acc? || t0.contains_int(flags, env, hist, t)?)
307                        })?)
308                    }
309                }
310            }
311            (Self::Fn(f0), Self::Fn(f1)) => {
312                let same = Arc::ptr_eq(f0, f1);
313                let r = same || f0.contains_int(flags, env, hist, f1)?;
314                if r && !same && flags.contains(ContainsFlags::InitTVars) {
315                    f0.merge_lambda_ids(f1);
316                }
317                Ok(r)
318            }
319            (_, Self::Any)
320            | (Self::Abstract { .. }, _)
321            | (_, Self::Abstract { .. })
322            | (_, Self::TVar(_))
323            | (Self::TVar(_), _)
324            | (Self::Fn(_), _)
325            | (Self::ByRef(_), _)
326            | (_, Self::ByRef(_))
327            | (_, Self::Fn(_))
328            | (Self::Tuple(_), Self::Array(_))
329            | (Self::Tuple(_), Self::Primitive(_))
330            | (Self::Tuple(_), Self::Struct(_))
331            | (Self::Tuple(_), Self::Variant(_, _))
332            | (Self::Tuple(_), Self::Error(_))
333            | (Self::Tuple(_), Self::Map { .. })
334            | (Self::Array(_), Self::Primitive(_))
335            | (Self::Array(_), Self::Tuple(_))
336            | (Self::Array(_), Self::Struct(_))
337            | (Self::Array(_), Self::Variant(_, _))
338            | (Self::Array(_), Self::Error(_))
339            | (Self::Array(_), Self::Map { .. })
340            | (Self::Struct(_), Self::Array(_))
341            | (Self::Struct(_), Self::Primitive(_))
342            | (Self::Struct(_), Self::Tuple(_))
343            | (Self::Struct(_), Self::Variant(_, _))
344            | (Self::Struct(_), Self::Error(_))
345            | (Self::Struct(_), Self::Map { .. })
346            | (Self::Variant(_, _), Self::Array(_))
347            | (Self::Variant(_, _), Self::Struct(_))
348            | (Self::Variant(_, _), Self::Primitive(_))
349            | (Self::Variant(_, _), Self::Tuple(_))
350            | (Self::Variant(_, _), Self::Error(_))
351            | (Self::Variant(_, _), Self::Map { .. })
352            | (Self::Error(_), Self::Array(_))
353            | (Self::Error(_), Self::Primitive(_))
354            | (Self::Error(_), Self::Struct(_))
355            | (Self::Error(_), Self::Variant(_, _))
356            | (Self::Error(_), Self::Tuple(_))
357            | (Self::Error(_), Self::Map { .. })
358            | (Self::Map { .. }, Self::Array(_))
359            | (Self::Map { .. }, Self::Primitive(_))
360            | (Self::Map { .. }, Self::Struct(_))
361            | (Self::Map { .. }, Self::Variant(_, _))
362            | (Self::Map { .. }, Self::Tuple(_))
363            | (Self::Map { .. }, Self::Error(_)) => Ok(false),
364        }
365    }
366
367    pub fn contains(&self, env: &Env, t: &Self) -> Result<bool> {
368        self.contains_int(
369            ContainsFlags::AliasTVars | ContainsFlags::InitTVars,
370            env,
371            &mut RefHist::new(LPooled::take()),
372            t,
373        )
374    }
375
376    pub fn contains_with_flags(
377        &self,
378        flags: BitFlags<ContainsFlags>,
379        env: &Env,
380        t: &Self,
381    ) -> Result<bool> {
382        self.contains_int(flags, env, &mut RefHist::new(LPooled::take()), t)
383    }
384}