Skip to main content

graphix_compiler/typ/
matches.rs

1use crate::{
2    env::Env,
3    format_with_flags,
4    typ::{AbstractId, AndAc, RefHist, Type, TypeRef},
5    PrintFlag,
6};
7use ahash::{AHashMap, AHashSet};
8use anyhow::{bail, Result};
9use enumflags2::BitFlags;
10use netidx_value::Typ;
11use nohash::IntMap;
12use poolshark::local::LPooled;
13
14impl Type {
15    fn could_match_int(
16        &self,
17        env: &Env,
18        hist: &mut RefHist<AHashMap<(Option<usize>, Option<usize>), bool>>,
19        t: &Self,
20    ) -> Result<bool> {
21        let fl = BitFlags::empty();
22        match (self, t) {
23            (
24                Self::Ref(TypeRef { scope: s0, name: n0, params: p0, .. }),
25                Self::Ref(TypeRef { scope: s1, name: n1, params: p1, .. }),
26            ) if s0 == s1 && n0 == n1 => Ok(p0.len() == p1.len()
27                && p0
28                    .iter()
29                    .zip(p1.iter())
30                    .map(|(t0, t1)| t0.could_match_int(env, hist, t1))
31                    .collect::<Result<AndAc>>()?
32                    .0),
33            (t0 @ Self::Ref(TypeRef { .. }), t1)
34            | (t0, t1 @ Self::Ref(TypeRef { .. })) => {
35                let t0_id = hist.ref_id(t0, env);
36                let t1_id = hist.ref_id(t1, env);
37                let t0 = t0.lookup_ref(env)?;
38                let t1 = t1.lookup_ref(env)?;
39                match hist.get(&(t0_id, t1_id)) {
40                    Some(r) => Ok(*r),
41                    None => {
42                        hist.insert((t0_id, t1_id), true);
43                        let r = t0.could_match_int(env, hist, &t1);
44                        hist.remove(&(t0_id, t1_id));
45                        r
46                    }
47                }
48            }
49            (t0, Self::Primitive(s)) => {
50                for t1 in s.iter() {
51                    if t0.contains_int(fl, env, hist, &Type::Primitive(t1.into()))? {
52                        return Ok(true);
53                    }
54                }
55                Ok(false)
56            }
57            (Type::Primitive(p), Type::Error(_)) => Ok(p.contains(Typ::Error)),
58            (Type::Error(t0), Type::Error(t1)) => t0.could_match_int(env, hist, t1),
59            (Type::Array(t0), Type::Array(t1)) => t0.could_match_int(env, hist, t1),
60            (Type::Primitive(p), Type::Array(_)) => Ok(p.contains(Typ::Array)),
61            (Type::Map { key: k0, value: v0 }, Type::Map { key: k1, value: v1 }) => {
62                Ok(k0.could_match_int(env, hist, k1)?
63                    && v0.could_match_int(env, hist, v1)?)
64            }
65            (Type::Primitive(p), Type::Map { .. }) => Ok(p.contains(Typ::Map)),
66            (Type::Tuple(ts0), Type::Tuple(ts1)) => Ok(ts0.len() == ts1.len()
67                && ts0
68                    .iter()
69                    .zip(ts1.iter())
70                    .map(|(t0, t1)| t0.could_match_int(env, hist, t1))
71                    .collect::<Result<AndAc>>()?
72                    .0),
73            (Type::Struct(ts0), Type::Struct(ts1)) => Ok(ts0.len() == ts1.len()
74                && ts0
75                    .iter()
76                    .zip(ts1.iter())
77                    .map(|((n0, t0), (n1, t1))| {
78                        Ok(n0 == n1 && t0.could_match_int(env, hist, t1)?)
79                    })
80                    .collect::<Result<AndAc>>()?
81                    .0),
82            (Type::Variant(n0, ts0), Type::Variant(n1, ts1)) => Ok(ts0.len()
83                == ts1.len()
84                && n0 == n1
85                && ts0
86                    .iter()
87                    .zip(ts1.iter())
88                    .map(|(t0, t1)| t0.could_match_int(env, hist, t1))
89                    .collect::<Result<AndAc>>()?
90                    .0),
91            (Type::ByRef(t0), Type::ByRef(t1)) => t0.could_match_int(env, hist, t1),
92            (t0, Self::Set(ts)) => {
93                for t1 in ts.iter() {
94                    if t0.could_match_int(env, hist, t1)? {
95                        return Ok(true);
96                    }
97                }
98                Ok(false)
99            }
100            (Type::Set(ts), t1) => {
101                for t0 in ts.iter() {
102                    if t0.could_match_int(env, hist, t1)? {
103                        return Ok(true);
104                    }
105                }
106                Ok(false)
107            }
108            (Type::TVar(t0), t1) => match &*t0.read().typ.read() {
109                Some(t0) => t0.could_match_int(env, hist, t1),
110                None => Ok(true),
111            },
112            (t0, Type::TVar(t1)) => match &*t1.read().typ.read() {
113                Some(t1) => t0.could_match_int(env, hist, t1),
114                None => Ok(true),
115            },
116            (
117                Type::Abstract { id: id0, params: p0 },
118                Type::Abstract { id: id1, params: p1 },
119            ) => Ok(id0 == id1
120                && p0.len() == p1.len()
121                && p0
122                    .iter()
123                    .zip(p1.iter())
124                    .map(|(t0, t1)| t0.could_match_int(env, hist, t1))
125                    .collect::<Result<AndAc>>()?
126                    .0),
127            (_, Type::Bottom) => Ok(true),
128            (Type::Bottom, _) => Ok(false),
129            (Type::Any, _) | (_, Type::Any) => Ok(true),
130            (Type::Abstract { .. }, _)
131            | (_, Type::Abstract { .. })
132            | (Type::Fn(_), _)
133            | (_, Type::Fn(_))
134            | (Type::Tuple(_), _)
135            | (_, Type::Tuple(_))
136            | (Type::Struct(_), _)
137            | (_, Type::Struct(_))
138            | (Type::Variant(_, _), _)
139            | (_, Type::Variant(_, _))
140            | (Type::ByRef(_), _)
141            | (_, Type::ByRef(_))
142            | (Type::Array(_), _)
143            | (_, Type::Array(_))
144            | (_, Type::Map { .. })
145            | (Type::Map { .. }, _) => Ok(false),
146        }
147    }
148
149    pub fn could_match(&self, env: &Env, t: &Self) -> Result<bool> {
150        self.could_match_int(env, &mut RefHist::new(LPooled::take()), t)
151    }
152
153    pub fn sig_matches(
154        &self,
155        env: &Env,
156        impl_type: &Self,
157        adts: &IntMap<AbstractId, Type>,
158    ) -> Result<()> {
159        self.sig_matches_int(
160            env,
161            impl_type,
162            &mut LPooled::take(),
163            &mut RefHist::new(LPooled::take()),
164            adts,
165        )
166    }
167
168    pub(super) fn sig_matches_int(
169        &self,
170        env: &Env,
171        impl_type: &Self,
172        tvar_map: &mut IntMap<usize, Type>,
173        hist: &mut RefHist<AHashSet<(Option<usize>, Option<usize>)>>,
174        adts: &IntMap<AbstractId, Type>,
175    ) -> Result<()> {
176        if (self as *const Type) == (impl_type as *const Type) {
177            return Ok(());
178        }
179        match (self, impl_type) {
180            (Self::Bottom, Self::Bottom) => Ok(()),
181            (Self::Any, Self::Any) => Ok(()),
182            (Self::Primitive(p0), Self::Primitive(p1)) if p0 == p1 => Ok(()),
183            (
184                Self::Ref(TypeRef { scope: s0, name: n0, params: p0, .. }),
185                Self::Ref(TypeRef { scope: s1, name: n1, params: p1, .. }),
186            ) if s0 == s1 && n0 == n1 && p0.len() == p1.len() => {
187                for (t0, t1) in p0.iter().zip(p1.iter()) {
188                    t0.sig_matches_int(env, t1, tvar_map, hist, adts)?;
189                }
190                Ok(())
191            }
192            (t0 @ Self::Ref(TypeRef { .. }), t1)
193            | (t0, t1 @ Self::Ref(TypeRef { .. })) => {
194                let t0_id = hist.ref_id(t0, env);
195                let t1_id = hist.ref_id(t1, env);
196                let t0 = t0.lookup_ref(env)?;
197                let t1 = t1.lookup_ref(env)?;
198                if hist.contains(&(t0_id, t1_id)) {
199                    Ok(())
200                } else {
201                    hist.insert((t0_id, t1_id));
202                    let r = t0.sig_matches_int(env, &t1, tvar_map, hist, adts);
203                    hist.remove(&(t0_id, t1_id));
204                    r
205                }
206            }
207            (Self::Fn(f0), Self::Fn(f1)) => {
208                f0.sig_matches_int(env, f1, tvar_map, hist, adts)?;
209                f0.merge_lambda_ids(f1);
210                Ok(())
211            }
212            (Self::Set(s0), Self::Set(s1)) if s0.len() == s1.len() => {
213                for (t0, t1) in s0.iter().zip(s1.iter()) {
214                    t0.sig_matches_int(env, t1, tvar_map, hist, adts)?;
215                }
216                Ok(())
217            }
218            (Self::Error(e0), Self::Error(e1)) => {
219                e0.sig_matches_int(env, e1, tvar_map, hist, adts)
220            }
221            (Self::Array(a0), Self::Array(a1)) => {
222                a0.sig_matches_int(env, a1, tvar_map, hist, adts)
223            }
224            (Self::ByRef(b0), Self::ByRef(b1)) => {
225                b0.sig_matches_int(env, b1, tvar_map, hist, adts)
226            }
227            (Self::Tuple(t0), Self::Tuple(t1)) if t0.len() == t1.len() => {
228                for (t0, t1) in t0.iter().zip(t1.iter()) {
229                    t0.sig_matches_int(env, t1, tvar_map, hist, adts)?;
230                }
231                Ok(())
232            }
233            (Self::Struct(s0), Self::Struct(s1)) if s0.len() == s1.len() => {
234                for ((n0, t0), (n1, t1)) in s0.iter().zip(s1.iter()) {
235                    if n0 != n1 {
236                        format_with_flags(PrintFlag::DerefTVars, || {
237                            bail!("struct field name mismatch: {n0} vs {n1}")
238                        })?
239                    }
240                    t0.sig_matches_int(env, t1, tvar_map, hist, adts)?;
241                }
242                Ok(())
243            }
244            (Self::Variant(tag0, t0), Self::Variant(tag1, t1))
245                if tag0 == tag1 && t0.len() == t1.len() =>
246            {
247                for (t0, t1) in t0.iter().zip(t1.iter()) {
248                    t0.sig_matches_int(env, t1, tvar_map, hist, adts)?;
249                }
250                Ok(())
251            }
252            (Self::Map { key: k0, value: v0 }, Self::Map { key: k1, value: v1 }) => {
253                k0.sig_matches_int(env, k1, tvar_map, hist, adts)?;
254                v0.sig_matches_int(env, v1, tvar_map, hist, adts)
255            }
256            (Self::Abstract { .. }, Self::Abstract { .. }) => Ok(()),
257            (Self::Abstract { id, params: _ }, t0) => match adts.get(id) {
258                None => Ok(()), // it's in another module
259                Some(t1) => {
260                    if t0 != t1 {
261                        format_with_flags(PrintFlag::DerefTVars, || {
262                            bail!("abstract type mismatch {t0} != {t1}")
263                        })?
264                    }
265                    Ok(())
266                }
267            },
268            (Self::TVar(sig_tv), Self::TVar(impl_tv)) if sig_tv != impl_tv => {
269                format_with_flags(PrintFlag::DerefTVars, || {
270                    bail!("signature type variable {sig_tv} does not match implementation {impl_tv}")
271                })
272            }
273            (sig_type, Self::TVar(impl_tv)) => {
274                let impl_tv_addr = impl_tv.inner_addr();
275                match tvar_map.get(&impl_tv_addr) {
276                    Some(prev_sig_type) => {
277                        let matches = match (sig_type, prev_sig_type) {
278                            (Type::TVar(tv0), Type::TVar(tv1)) => {
279                                tv0.inner_addr() == tv1.inner_addr()
280                            }
281                            _ => sig_type == prev_sig_type,
282                        };
283                        if matches {
284                            Ok(())
285                        } else {
286                            format_with_flags(PrintFlag::DerefTVars, || {
287                                bail!(
288                                    "type variable usage mismatch: expected {prev_sig_type}, got {sig_type}"
289                                )
290                            })
291                        }
292                    }
293                    None => {
294                        tvar_map.insert(impl_tv_addr, sig_type.clone());
295                        Ok(())
296                    }
297                }
298            }
299            (Self::TVar(sig_tv), impl_type) => {
300                format_with_flags(PrintFlag::DerefTVars, || {
301                    bail!("signature has type variable '{sig_tv} where implementation has {impl_type}")
302                })
303            }
304            (sig_type, impl_type) => format_with_flags(PrintFlag::DerefTVars, || {
305                bail!("type mismatch: signature has {sig_type}, implementation has {impl_type}")
306            }),
307        }
308    }
309}