Skip to main content

graphix_compiler/typ/
contains.rs

1use crate::{
2    env::Env,
3    format_with_flags,
4    typ::{tvar::would_cycle_inner, AndAc, Type},
5    PrintFlag,
6};
7use anyhow::{bail, Result};
8use enumflags2::bitflags;
9use enumflags2::BitFlags;
10use fxhash::FxHashMap;
11use netidx::publisher::Typ;
12use poolshark::local::LPooled;
13use std::fmt::Debug;
14use triomphe::Arc;
15
16#[derive(Debug, Clone, Copy)]
17#[bitflags]
18#[repr(u8)]
19pub enum ContainsFlags {
20    AliasTVars,
21    InitTVars,
22}
23
24impl Type {
25    pub fn check_contains(&self, env: &Env, t: &Self) -> Result<()> {
26        if self.contains(env, t)? {
27            Ok(())
28        } else {
29            format_with_flags(PrintFlag::DerefTVars | PrintFlag::ReplacePrims, || {
30                bail!("type mismatch {self} does not contain {t}")
31            })
32        }
33    }
34
35    pub(super) fn contains_int(
36        &self,
37        flags: BitFlags<ContainsFlags>,
38        env: &Env,
39        hist: &mut FxHashMap<(usize, usize), bool>,
40        t: &Self,
41    ) -> Result<bool> {
42        if (self as *const Type) == (t as *const Type) {
43            return Ok(true);
44        }
45        match (self, t) {
46            (
47                Self::Ref { scope: s0, name: n0, params: p0 },
48                Self::Ref { scope: s1, name: n1, params: p1 },
49            ) if s0 == s1 && n0 == n1 => Ok(p0.len() == p1.len()
50                && p0
51                    .iter()
52                    .zip(p1.iter())
53                    .map(|(t0, t1)| t0.contains_int(flags, env, hist, t1))
54                    .collect::<Result<AndAc>>()?
55                    .0),
56            (t0 @ Self::Ref { .. }, t1) | (t0, t1 @ Self::Ref { .. }) => {
57                let t0 = t0.lookup_ref(env)?;
58                let t1 = t1.lookup_ref(env)?;
59                let t0_addr = (t0 as *const Type).addr();
60                let t1_addr = (t1 as *const Type).addr();
61                match hist.get(&(t0_addr, t1_addr)) {
62                    Some(r) => Ok(*r),
63                    None => {
64                        hist.insert((t0_addr, t1_addr), true);
65                        match t0.contains_int(flags, env, hist, t1) {
66                            Ok(r) => {
67                                hist.insert((t0_addr, t1_addr), r);
68                                Ok(r)
69                            }
70                            Err(e) => {
71                                hist.remove(&(t0_addr, t1_addr));
72                                Err(e)
73                            }
74                        }
75                    }
76                }
77            }
78            (Self::TVar(t0), Self::Bottom) => {
79                if let Some(_) = &*t0.read().typ.read() {
80                    return Ok(true);
81                }
82                if flags.contains(ContainsFlags::InitTVars) {
83                    *t0.read().typ.write() = Some(Self::Bottom);
84                }
85                Ok(true)
86            }
87            (Self::Bottom, Self::TVar(t0)) => {
88                if let Some(Type::Bottom) = &*t0.read().typ.read() {
89                    return Ok(true);
90                }
91                if flags.contains(ContainsFlags::InitTVars) {
92                    *t0.read().typ.write() = Some(Self::Bottom);
93                    return Ok(true);
94                }
95                Ok(false)
96            }
97            (Self::Bottom, Self::Bottom) => Ok(true),
98            (Self::Bottom, _) => Ok(false),
99            (_, Self::Bottom) => Ok(true),
100            (Self::TVar(t0), Self::Any) => {
101                if let Some(t0) = &*t0.read().typ.read() {
102                    return t0.contains_int(flags, env, hist, t);
103                }
104                if flags.contains(ContainsFlags::InitTVars) {
105                    *t0.read().typ.write() = Some(Self::Any);
106                }
107                Ok(true)
108            }
109            (Self::Any, _) => Ok(true),
110            (
111                Self::Abstract { id: id0, params: p0 },
112                Self::Abstract { id: id1, params: p1 },
113            ) => Ok(id0 == id1
114                && p0.len() == p1.len()
115                && p0
116                    .iter()
117                    .zip(p1.iter())
118                    .map(|(t0, t1)| t0.contains_int(flags, env, hist, t1))
119                    .collect::<Result<AndAc>>()?
120                    .0),
121            (Self::Primitive(p0), Self::Primitive(p1)) => Ok(p0.contains(*p1)),
122            (
123                Self::Primitive(p),
124                Self::Array(_) | Self::Tuple(_) | Self::Struct(_) | Self::Variant(_, _),
125            ) => Ok(p.contains(Typ::Array)),
126            (Self::Array(t0), Self::Array(t1)) => t0.contains_int(flags, env, hist, t1),
127            (Self::Array(t0), Self::Primitive(p)) if *p == BitFlags::from(Typ::Array) => {
128                t0.contains_int(flags, env, hist, &Type::Any)
129            }
130            (Self::Map { key: k0, value: v0 }, Self::Map { key: k1, value: v1 }) => {
131                Ok(k0.contains_int(flags, env, hist, k1)?
132                    && v0.contains_int(flags, env, hist, v1)?)
133            }
134            (Self::Primitive(p), Self::Map { .. }) => Ok(p.contains(Typ::Map)),
135            (Self::Map { key, value }, Self::Primitive(p))
136                if *p == BitFlags::from(Typ::Map) =>
137            {
138                Ok(key.contains_int(flags, env, hist, &Type::Any)?
139                    && value.contains_int(flags, env, hist, &Type::Any)?)
140            }
141            (Self::Primitive(p0), Self::Error(_)) => Ok(p0.contains(Typ::Error)),
142            (Self::Error(e), Self::Primitive(p)) if *p == BitFlags::from(Typ::Error) => {
143                e.contains_int(flags, env, hist, &Type::Any)
144            }
145            (Self::Error(e0), Self::Error(e1)) => e0.contains_int(flags, env, hist, e1),
146            (Self::Tuple(t0), Self::Tuple(t1))
147                if t0.as_ptr().addr() == t1.as_ptr().addr() =>
148            {
149                Ok(true)
150            }
151            (Self::Tuple(t0), Self::Tuple(t1)) => Ok(t0.len() == t1.len()
152                && t0
153                    .iter()
154                    .zip(t1.iter())
155                    .map(|(t0, t1)| t0.contains_int(flags, env, hist, t1))
156                    .collect::<Result<AndAc>>()?
157                    .0),
158            (Self::Struct(t0), Self::Struct(t1))
159                if t0.as_ptr().addr() == t1.as_ptr().addr() =>
160            {
161                Ok(true)
162            }
163            (Self::Struct(t0), Self::Struct(t1)) => {
164                Ok(t0.len() == t1.len() && {
165                    // struct types are always sorted by field name
166                    t0.iter()
167                        .zip(t1.iter())
168                        .map(|((n0, t0), (n1, t1))| {
169                            Ok(n0 == n1 && t0.contains_int(flags, env, hist, t1)?)
170                        })
171                        .collect::<Result<AndAc>>()?
172                        .0
173                })
174            }
175            (Self::Variant(tg0, t0), Self::Variant(tg1, t1))
176                if tg0.as_ptr() == tg1.as_ptr()
177                    && t0.as_ptr().addr() == t1.as_ptr().addr() =>
178            {
179                Ok(true)
180            }
181            (Self::Variant(tg0, t0), Self::Variant(tg1, t1)) => Ok(tg0 == tg1
182                && t0.len() == t1.len()
183                && t0
184                    .iter()
185                    .zip(t1.iter())
186                    .map(|(t0, t1)| t0.contains_int(flags, env, hist, t1))
187                    .collect::<Result<AndAc>>()?
188                    .0),
189            (Self::ByRef(t0), Self::ByRef(t1)) => t0.contains_int(flags, env, hist, t1),
190            (Self::TVar(t0), Self::TVar(t1))
191                if t0.addr() == t1.addr() || t0.read().id == t1.read().id =>
192            {
193                Ok(true)
194            }
195            (tt0 @ Self::TVar(t0), tt1 @ Self::TVar(t1)) => {
196                #[derive(Debug)]
197                enum Act {
198                    RightCopy,
199                    RightAlias,
200                    LeftAlias,
201                    LeftCopy,
202                }
203                if t0.would_cycle(tt1) || t1.would_cycle(tt0) {
204                    return Ok(true);
205                }
206                let act = {
207                    let t0 = t0.read();
208                    let t1 = t1.read();
209                    let addr0 = Arc::as_ptr(&t0.typ).addr();
210                    let addr1 = Arc::as_ptr(&t1.typ).addr();
211                    if addr0 == addr1 {
212                        return Ok(true);
213                    }
214                    if would_cycle_inner(addr0, tt1) || would_cycle_inner(addr1, tt0) {
215                        return Ok(true);
216                    }
217                    let t0i = t0.typ.read();
218                    let t1i = t1.typ.read();
219                    match (&*t0i, &*t1i) {
220                        (Some(t0), Some(t1)) => {
221                            return t0.contains_int(flags, env, hist, &*t1)
222                        }
223                        (None, None) => {
224                            if t0.frozen && t1.frozen {
225                                return Ok(true);
226                            }
227                            if t0.frozen {
228                                Act::RightAlias
229                            } else {
230                                Act::LeftAlias
231                            }
232                        }
233                        (Some(_), None) => Act::RightCopy,
234                        (None, Some(_)) => Act::LeftCopy,
235                    }
236                };
237                match act {
238                    Act::RightCopy if flags.contains(ContainsFlags::InitTVars) => {
239                        t1.copy(t0)
240                    }
241                    Act::RightAlias if flags.contains(ContainsFlags::AliasTVars) => {
242                        t1.alias(t0)
243                    }
244                    Act::LeftAlias if flags.contains(ContainsFlags::AliasTVars) => {
245                        t0.alias(t1)
246                    }
247                    Act::LeftCopy if flags.contains(ContainsFlags::InitTVars) => {
248                        t0.copy(t1)
249                    }
250                    Act::RightCopy | Act::RightAlias | Act::LeftAlias | Act::LeftCopy => {
251                        ()
252                    }
253                }
254                Ok(true)
255            }
256            (Self::TVar(t0), t1) if !t0.would_cycle(t1) => {
257                if let Some(t0) = &*t0.read().typ.read() {
258                    return t0.contains_int(flags, env, hist, t1);
259                }
260                if flags.contains(ContainsFlags::InitTVars) {
261                    *t0.read().typ.write() = Some(t1.clone());
262                }
263                Ok(true)
264            }
265            (t0, Self::TVar(t1)) if !t1.would_cycle(t0) => {
266                if let Some(t1) = &*t1.read().typ.read() {
267                    return t0.contains_int(flags, env, hist, t1);
268                }
269                if flags.contains(ContainsFlags::InitTVars) {
270                    *t1.read().typ.write() = Some(t0.clone());
271                }
272                Ok(true)
273            }
274            (Self::Set(s0), Self::Set(s1))
275                if s0.as_ptr().addr() == s1.as_ptr().addr() =>
276            {
277                Ok(true)
278            }
279            (t0, Self::Set(s)) => Ok(s
280                .iter()
281                .map(|t1| t0.contains_int(flags, env, hist, t1))
282                .collect::<Result<AndAc>>()?
283                .0),
284            (Self::Set(s), t) => Ok(s
285                .iter()
286                .fold(Ok::<_, anyhow::Error>(false), |acc, t0| {
287                    Ok(acc? || t0.contains_int(flags, env, hist, t)?)
288                })?
289                || t.iter_prims().fold(Ok::<_, anyhow::Error>(true), |acc, t1| {
290                    Ok(acc?
291                        && s.iter().fold(Ok::<_, anyhow::Error>(false), |acc, t0| {
292                            Ok(acc? || t0.contains_int(flags, env, hist, &t1)?)
293                        })?)
294                })?),
295            (Self::Fn(f0), Self::Fn(f1)) => {
296                Ok(f0.as_ptr() == f1.as_ptr() || f0.contains_int(flags, env, hist, f1)?)
297            }
298            (_, Self::Any)
299            | (Self::Abstract { .. }, _)
300            | (_, Self::Abstract { .. })
301            | (_, Self::TVar(_))
302            | (Self::TVar(_), _)
303            | (Self::Fn(_), _)
304            | (Self::ByRef(_), _)
305            | (_, Self::ByRef(_))
306            | (_, Self::Fn(_))
307            | (Self::Tuple(_), Self::Array(_))
308            | (Self::Tuple(_), Self::Primitive(_))
309            | (Self::Tuple(_), Self::Struct(_))
310            | (Self::Tuple(_), Self::Variant(_, _))
311            | (Self::Tuple(_), Self::Error(_))
312            | (Self::Tuple(_), Self::Map { .. })
313            | (Self::Array(_), Self::Primitive(_))
314            | (Self::Array(_), Self::Tuple(_))
315            | (Self::Array(_), Self::Struct(_))
316            | (Self::Array(_), Self::Variant(_, _))
317            | (Self::Array(_), Self::Error(_))
318            | (Self::Array(_), Self::Map { .. })
319            | (Self::Struct(_), Self::Array(_))
320            | (Self::Struct(_), Self::Primitive(_))
321            | (Self::Struct(_), Self::Tuple(_))
322            | (Self::Struct(_), Self::Variant(_, _))
323            | (Self::Struct(_), Self::Error(_))
324            | (Self::Struct(_), Self::Map { .. })
325            | (Self::Variant(_, _), Self::Array(_))
326            | (Self::Variant(_, _), Self::Struct(_))
327            | (Self::Variant(_, _), Self::Primitive(_))
328            | (Self::Variant(_, _), Self::Tuple(_))
329            | (Self::Variant(_, _), Self::Error(_))
330            | (Self::Variant(_, _), Self::Map { .. })
331            | (Self::Error(_), Self::Array(_))
332            | (Self::Error(_), Self::Primitive(_))
333            | (Self::Error(_), Self::Struct(_))
334            | (Self::Error(_), Self::Variant(_, _))
335            | (Self::Error(_), Self::Tuple(_))
336            | (Self::Error(_), Self::Map { .. })
337            | (Self::Map { .. }, Self::Array(_))
338            | (Self::Map { .. }, Self::Primitive(_))
339            | (Self::Map { .. }, Self::Struct(_))
340            | (Self::Map { .. }, Self::Variant(_, _))
341            | (Self::Map { .. }, Self::Tuple(_))
342            | (Self::Map { .. }, Self::Error(_)) => Ok(false),
343        }
344    }
345
346    pub fn contains(&self, env: &Env, t: &Self) -> Result<bool> {
347        self.contains_int(
348            ContainsFlags::AliasTVars | ContainsFlags::InitTVars,
349            env,
350            &mut LPooled::take(),
351            t,
352        )
353    }
354
355    pub fn contains_with_flags(
356        &self,
357        flags: BitFlags<ContainsFlags>,
358        env: &Env,
359        t: &Self,
360    ) -> Result<bool> {
361        self.contains_int(flags, env, &mut LPooled::take(), t)
362    }
363}