Skip to main content

graphix_compiler/typ/
matches.rs

1use crate::{
2    env::Env,
3    format_with_flags,
4    typ::{AbstractId, AndAc, RefHist, Type},
5    PrintFlag,
6};
7use anyhow::{bail, Result};
8use enumflags2::BitFlags;
9use fxhash::{FxHashMap, FxHashSet};
10use netidx_value::Typ;
11use poolshark::local::LPooled;
12
13impl Type {
14    fn could_match_int(
15        &self,
16        env: &Env,
17        hist: &mut RefHist<FxHashMap<(Option<usize>, Option<usize>), bool>>,
18        t: &Self,
19    ) -> Result<bool> {
20        let fl = BitFlags::empty();
21        match (self, t) {
22            (
23                Self::Ref { scope: s0, name: n0, params: p0 },
24                Self::Ref { scope: s1, name: n1, params: p1 },
25            ) if s0 == s1 && n0 == n1 => Ok(p0.len() == p1.len()
26                && p0
27                    .iter()
28                    .zip(p1.iter())
29                    .map(|(t0, t1)| t0.could_match_int(env, hist, t1))
30                    .collect::<Result<AndAc>>()?
31                    .0),
32            (t0 @ Self::Ref { .. }, t1) | (t0, t1 @ Self::Ref { .. }) => {
33                let t0_id = hist.ref_id(t0, env);
34                let t1_id = hist.ref_id(t1, env);
35                let t0 = t0.lookup_ref(env)?;
36                let t1 = t1.lookup_ref(env)?;
37                match hist.get(&(t0_id, t1_id)) {
38                    Some(r) => Ok(*r),
39                    None => {
40                        hist.insert((t0_id, t1_id), true);
41                        let r = t0.could_match_int(env, hist, &t1);
42                        hist.remove(&(t0_id, t1_id));
43                        r
44                    }
45                }
46            }
47            (t0, Self::Primitive(s)) => {
48                for t1 in s.iter() {
49                    if t0.contains_int(fl, env, hist, &Type::Primitive(t1.into()))? {
50                        return Ok(true);
51                    }
52                }
53                Ok(false)
54            }
55            (Type::Primitive(p), Type::Error(_)) => Ok(p.contains(Typ::Error)),
56            (Type::Error(t0), Type::Error(t1)) => t0.could_match_int(env, hist, t1),
57            (Type::Array(t0), Type::Array(t1)) => t0.could_match_int(env, hist, t1),
58            (Type::Primitive(p), Type::Array(_)) => Ok(p.contains(Typ::Array)),
59            (Type::Map { key: k0, value: v0 }, Type::Map { key: k1, value: v1 }) => {
60                Ok(k0.could_match_int(env, hist, k1)?
61                    && v0.could_match_int(env, hist, v1)?)
62            }
63            (Type::Primitive(p), Type::Map { .. }) => Ok(p.contains(Typ::Map)),
64            (Type::Tuple(ts0), Type::Tuple(ts1)) => Ok(ts0.len() == ts1.len()
65                && ts0
66                    .iter()
67                    .zip(ts1.iter())
68                    .map(|(t0, t1)| t0.could_match_int(env, hist, t1))
69                    .collect::<Result<AndAc>>()?
70                    .0),
71            (Type::Struct(ts0), Type::Struct(ts1)) => Ok(ts0.len() == ts1.len()
72                && ts0
73                    .iter()
74                    .zip(ts1.iter())
75                    .map(|((n0, t0), (n1, t1))| {
76                        Ok(n0 == n1 && t0.could_match_int(env, hist, t1)?)
77                    })
78                    .collect::<Result<AndAc>>()?
79                    .0),
80            (Type::Variant(n0, ts0), Type::Variant(n1, ts1)) => Ok(ts0.len()
81                == ts1.len()
82                && n0 == n1
83                && ts0
84                    .iter()
85                    .zip(ts1.iter())
86                    .map(|(t0, t1)| t0.could_match_int(env, hist, t1))
87                    .collect::<Result<AndAc>>()?
88                    .0),
89            (Type::ByRef(t0), Type::ByRef(t1)) => t0.could_match_int(env, hist, t1),
90            (t0, Self::Set(ts)) => {
91                for t1 in ts.iter() {
92                    if t0.could_match_int(env, hist, t1)? {
93                        return Ok(true);
94                    }
95                }
96                Ok(false)
97            }
98            (Type::Set(ts), t1) => {
99                for t0 in ts.iter() {
100                    if t0.could_match_int(env, hist, t1)? {
101                        return Ok(true);
102                    }
103                }
104                Ok(false)
105            }
106            (Type::TVar(t0), t1) => match &*t0.read().typ.read() {
107                Some(t0) => t0.could_match_int(env, hist, t1),
108                None => Ok(true),
109            },
110            (t0, Type::TVar(t1)) => match &*t1.read().typ.read() {
111                Some(t1) => t0.could_match_int(env, hist, t1),
112                None => Ok(true),
113            },
114            (
115                Type::Abstract { id: id0, params: p0 },
116                Type::Abstract { id: id1, params: p1 },
117            ) => Ok(id0 == id1
118                && p0.len() == p1.len()
119                && p0
120                    .iter()
121                    .zip(p1.iter())
122                    .map(|(t0, t1)| t0.could_match_int(env, hist, t1))
123                    .collect::<Result<AndAc>>()?
124                    .0),
125            (_, Type::Bottom) => Ok(true),
126            (Type::Bottom, _) => Ok(false),
127            (Type::Any, _) | (_, Type::Any) => Ok(true),
128            (Type::Abstract { .. }, _)
129            | (_, Type::Abstract { .. })
130            | (Type::Fn(_), _)
131            | (_, Type::Fn(_))
132            | (Type::Tuple(_), _)
133            | (_, Type::Tuple(_))
134            | (Type::Struct(_), _)
135            | (_, Type::Struct(_))
136            | (Type::Variant(_, _), _)
137            | (_, Type::Variant(_, _))
138            | (Type::ByRef(_), _)
139            | (_, Type::ByRef(_))
140            | (Type::Array(_), _)
141            | (_, Type::Array(_))
142            | (_, Type::Map { .. })
143            | (Type::Map { .. }, _) => Ok(false),
144        }
145    }
146
147    pub fn could_match(&self, env: &Env, t: &Self) -> Result<bool> {
148        self.could_match_int(env, &mut RefHist::new(LPooled::take()), t)
149    }
150
151    pub fn sig_matches(
152        &self,
153        env: &Env,
154        impl_type: &Self,
155        adts: &FxHashMap<AbstractId, Type>,
156    ) -> Result<()> {
157        self.sig_matches_int(
158            env,
159            impl_type,
160            &mut LPooled::take(),
161            &mut RefHist::new(LPooled::take()),
162            adts,
163        )
164    }
165
166    pub(super) fn sig_matches_int(
167        &self,
168        env: &Env,
169        impl_type: &Self,
170        tvar_map: &mut FxHashMap<usize, Type>,
171        hist: &mut RefHist<FxHashSet<(Option<usize>, Option<usize>)>>,
172        adts: &FxHashMap<AbstractId, Type>,
173    ) -> Result<()> {
174        if (self as *const Type) == (impl_type as *const Type) {
175            return Ok(());
176        }
177        match (self, impl_type) {
178            (Self::Bottom, Self::Bottom) => Ok(()),
179            (Self::Any, Self::Any) => Ok(()),
180            (Self::Primitive(p0), Self::Primitive(p1)) if p0 == p1 => Ok(()),
181            (
182                Self::Ref { scope: s0, name: n0, params: p0 },
183                Self::Ref { scope: s1, name: n1, params: p1 },
184            ) if s0 == s1 && n0 == n1 && p0.len() == p1.len() => {
185                for (t0, t1) in p0.iter().zip(p1.iter()) {
186                    t0.sig_matches_int(env, t1, tvar_map, hist, adts)?;
187                }
188                Ok(())
189            }
190            (t0 @ Self::Ref { .. }, t1) | (t0, t1 @ Self::Ref { .. }) => {
191                let t0_id = hist.ref_id(t0, env);
192                let t1_id = hist.ref_id(t1, env);
193                let t0 = t0.lookup_ref(env)?;
194                let t1 = t1.lookup_ref(env)?;
195                if hist.contains(&(t0_id, t1_id)) {
196                    Ok(())
197                } else {
198                    hist.insert((t0_id, t1_id));
199                    let r = t0.sig_matches_int(env, &t1, tvar_map, hist, adts);
200                    hist.remove(&(t0_id, t1_id));
201                    r
202                }
203            }
204            (Self::Fn(f0), Self::Fn(f1)) => {
205                f0.sig_matches_int(env, f1, tvar_map, hist, adts)
206            }
207            (Self::Set(s0), Self::Set(s1)) if s0.len() == s1.len() => {
208                for (t0, t1) in s0.iter().zip(s1.iter()) {
209                    t0.sig_matches_int(env, t1, tvar_map, hist, adts)?;
210                }
211                Ok(())
212            }
213            (Self::Error(e0), Self::Error(e1)) => {
214                e0.sig_matches_int(env, e1, tvar_map, hist, adts)
215            }
216            (Self::Array(a0), Self::Array(a1)) => {
217                a0.sig_matches_int(env, a1, tvar_map, hist, adts)
218            }
219            (Self::ByRef(b0), Self::ByRef(b1)) => {
220                b0.sig_matches_int(env, b1, tvar_map, hist, adts)
221            }
222            (Self::Tuple(t0), Self::Tuple(t1)) if t0.len() == t1.len() => {
223                for (t0, t1) in t0.iter().zip(t1.iter()) {
224                    t0.sig_matches_int(env, t1, tvar_map, hist, adts)?;
225                }
226                Ok(())
227            }
228            (Self::Struct(s0), Self::Struct(s1)) if s0.len() == s1.len() => {
229                for ((n0, t0), (n1, t1)) in s0.iter().zip(s1.iter()) {
230                    if n0 != n1 {
231                        format_with_flags(PrintFlag::DerefTVars, || {
232                            bail!("struct field name mismatch: {n0} vs {n1}")
233                        })?
234                    }
235                    t0.sig_matches_int(env, t1, tvar_map, hist, adts)?;
236                }
237                Ok(())
238            }
239            (Self::Variant(tag0, t0), Self::Variant(tag1, t1))
240                if tag0 == tag1 && t0.len() == t1.len() =>
241            {
242                for (t0, t1) in t0.iter().zip(t1.iter()) {
243                    t0.sig_matches_int(env, t1, tvar_map, hist, adts)?;
244                }
245                Ok(())
246            }
247            (Self::Map { key: k0, value: v0 }, Self::Map { key: k1, value: v1 }) => {
248                k0.sig_matches_int(env, k1, tvar_map, hist, adts)?;
249                v0.sig_matches_int(env, v1, tvar_map, hist, adts)
250            }
251            (Self::Abstract { .. }, Self::Abstract { .. }) => {
252                bail!("abstract types must have a concrete definition in the implementation")
253            }
254            (Self::Abstract { id, params: _ }, t0) => match adts.get(id) {
255                None => bail!("undefined abstract type"),
256                Some(t1) => {
257                    if t0 != t1 {
258                        format_with_flags(PrintFlag::DerefTVars, || {
259                            bail!("abstract type mismatch {t0} != {t1}")
260                        })?
261                    }
262                    Ok(())
263                }
264            },
265            (Self::TVar(sig_tv), Self::TVar(impl_tv)) if sig_tv != impl_tv => {
266                format_with_flags(PrintFlag::DerefTVars, || {
267                    bail!("signature type variable {sig_tv} does not match implementation {impl_tv}")
268                })
269            }
270            (sig_type, Self::TVar(impl_tv)) => {
271                let impl_tv_addr = impl_tv.inner_addr();
272                match tvar_map.get(&impl_tv_addr) {
273                    Some(prev_sig_type) => {
274                        let matches = match (sig_type, prev_sig_type) {
275                            (Type::TVar(tv0), Type::TVar(tv1)) => {
276                                tv0.inner_addr() == tv1.inner_addr()
277                            }
278                            _ => sig_type == prev_sig_type,
279                        };
280                        if matches {
281                            Ok(())
282                        } else {
283                            format_with_flags(PrintFlag::DerefTVars, || {
284                                bail!(
285                                    "type variable usage mismatch: expected {prev_sig_type}, got {sig_type}"
286                                )
287                            })
288                        }
289                    }
290                    None => {
291                        tvar_map.insert(impl_tv_addr, sig_type.clone());
292                        Ok(())
293                    }
294                }
295            }
296            (Self::TVar(sig_tv), impl_type) => {
297                format_with_flags(PrintFlag::DerefTVars, || {
298                    bail!("signature has type variable '{sig_tv} where implementation has {impl_type}")
299                })
300            }
301            (sig_type, impl_type) => format_with_flags(PrintFlag::DerefTVars, || {
302                bail!("type mismatch: signature has {sig_type}, implementation has {impl_type}")
303            }),
304        }
305    }
306}