// graphix_compiler/typ/matches.rs

use crate::{
    env::Env,
    format_with_flags,
    typ::{AbstractId, AndAc, Type},
    PrintFlag,
};
use anyhow::{bail, Result};
use enumflags2::BitFlags;
use fxhash::{FxHashMap, FxHashSet};
use netidx_value::Typ;
use poolshark::local::LPooled;
use std::ptr;
12
13impl Type {
14    fn could_match_int(
15        &self,
16        env: &Env,
17        hist: &mut FxHashMap<(usize, usize), bool>,
18        t: &Self,
19    ) -> Result<bool> {
20        let fl = BitFlags::empty();
21        match (self, t) {
22            (
23                Self::Ref { scope: s0, name: n0, params: p0 },
24                Self::Ref { scope: s1, name: n1, params: p1 },
25            ) if s0 == s1 && n0 == n1 => Ok(p0.len() == p1.len()
26                && p0
27                    .iter()
28                    .zip(p1.iter())
29                    .map(|(t0, t1)| t0.could_match_int(env, hist, t1))
30                    .collect::<Result<AndAc>>()?
31                    .0),
32            (t0 @ Self::Ref { .. }, t1) | (t0, t1 @ Self::Ref { .. }) => {
33                let t0 = t0.lookup_ref(env)?;
34                let t1 = t1.lookup_ref(env)?;
35                let t0_addr = (t0 as *const Type).addr();
36                let t1_addr = (t1 as *const Type).addr();
37                match hist.get(&(t0_addr, t1_addr)) {
38                    Some(r) => Ok(*r),
39                    None => {
40                        hist.insert((t0_addr, t1_addr), true);
41                        match t0.could_match_int(env, hist, t1) {
42                            Ok(r) => {
43                                hist.insert((t0_addr, t1_addr), r);
44                                Ok(r)
45                            }
46                            Err(e) => {
47                                hist.remove(&(t0_addr, t1_addr));
48                                Err(e)
49                            }
50                        }
51                    }
52                }
53            }
54            (t0, Self::Primitive(s)) => {
55                for t1 in s.iter() {
56                    if t0.contains_int(fl, env, hist, &Type::Primitive(t1.into()))? {
57                        return Ok(true);
58                    }
59                }
60                Ok(false)
61            }
62            (Type::Primitive(p), Type::Error(_)) => Ok(p.contains(Typ::Error)),
63            (Type::Error(t0), Type::Error(t1)) => t0.could_match_int(env, hist, t1),
64            (Type::Array(t0), Type::Array(t1)) => t0.could_match_int(env, hist, t1),
65            (Type::Primitive(p), Type::Array(_)) => Ok(p.contains(Typ::Array)),
66            (Type::Map { key: k0, value: v0 }, Type::Map { key: k1, value: v1 }) => {
67                Ok(k0.could_match_int(env, hist, k1)?
68                    && v0.could_match_int(env, hist, v1)?)
69            }
70            (Type::Primitive(p), Type::Map { .. }) => Ok(p.contains(Typ::Map)),
71            (Type::Tuple(ts0), Type::Tuple(ts1)) => Ok(ts0.len() == ts1.len()
72                && ts0
73                    .iter()
74                    .zip(ts1.iter())
75                    .map(|(t0, t1)| t0.could_match_int(env, hist, t1))
76                    .collect::<Result<AndAc>>()?
77                    .0),
78            (Type::Struct(ts0), Type::Struct(ts1)) => Ok(ts0.len() == ts1.len()
79                && ts0
80                    .iter()
81                    .zip(ts1.iter())
82                    .map(|((n0, t0), (n1, t1))| {
83                        Ok(n0 == n1 && t0.could_match_int(env, hist, t1)?)
84                    })
85                    .collect::<Result<AndAc>>()?
86                    .0),
87            (Type::Variant(n0, ts0), Type::Variant(n1, ts1)) => Ok(ts0.len()
88                == ts1.len()
89                && n0 == n1
90                && ts0
91                    .iter()
92                    .zip(ts1.iter())
93                    .map(|(t0, t1)| t0.could_match_int(env, hist, t1))
94                    .collect::<Result<AndAc>>()?
95                    .0),
96            (Type::ByRef(t0), Type::ByRef(t1)) => t0.could_match_int(env, hist, t1),
97            (t0, Self::Set(ts)) => {
98                for t1 in ts.iter() {
99                    if t0.could_match_int(env, hist, t1)? {
100                        return Ok(true);
101                    }
102                }
103                Ok(false)
104            }
105            (Type::Set(ts), t1) => {
106                for t0 in ts.iter() {
107                    if t0.could_match_int(env, hist, t1)? {
108                        return Ok(true);
109                    }
110                }
111                Ok(false)
112            }
113            (Type::TVar(t0), t1) => match &*t0.read().typ.read() {
114                Some(t0) => t0.could_match_int(env, hist, t1),
115                None => Ok(true),
116            },
117            (t0, Type::TVar(t1)) => match &*t1.read().typ.read() {
118                Some(t1) => t0.could_match_int(env, hist, t1),
119                None => Ok(true),
120            },
121            (
122                Type::Abstract { id: id0, params: p0 },
123                Type::Abstract { id: id1, params: p1 },
124            ) => Ok(id0 == id1
125                && p0.len() == p1.len()
126                && p0
127                    .iter()
128                    .zip(p1.iter())
129                    .map(|(t0, t1)| t0.could_match_int(env, hist, t1))
130                    .collect::<Result<AndAc>>()?
131                    .0),
132            (_, Type::Bottom) => Ok(true),
133            (Type::Bottom, _) => Ok(false),
134            (Type::Any, _) | (_, Type::Any) => Ok(true),
135            (Type::Abstract { .. }, _)
136            | (_, Type::Abstract { .. })
137            | (Type::Fn(_), _)
138            | (_, Type::Fn(_))
139            | (Type::Tuple(_), _)
140            | (_, Type::Tuple(_))
141            | (Type::Struct(_), _)
142            | (_, Type::Struct(_))
143            | (Type::Variant(_, _), _)
144            | (_, Type::Variant(_, _))
145            | (Type::ByRef(_), _)
146            | (_, Type::ByRef(_))
147            | (Type::Array(_), _)
148            | (_, Type::Array(_))
149            | (_, Type::Map { .. })
150            | (Type::Map { .. }, _) => Ok(false),
151        }
152    }
153
154    pub fn could_match(&self, env: &Env, t: &Self) -> Result<bool> {
155        self.could_match_int(env, &mut LPooled::take(), t)
156    }
157
158    pub fn sig_matches(
159        &self,
160        env: &Env,
161        impl_type: &Self,
162        adts: &FxHashMap<AbstractId, Type>,
163    ) -> Result<()> {
164        self.sig_matches_int(
165            env,
166            impl_type,
167            &mut LPooled::take(),
168            &mut LPooled::take(),
169            adts,
170        )
171    }
172
173    pub(super) fn sig_matches_int(
174        &self,
175        env: &Env,
176        impl_type: &Self,
177        tvar_map: &mut FxHashMap<usize, Type>,
178        hist: &mut FxHashSet<(usize, usize)>,
179        adts: &FxHashMap<AbstractId, Type>,
180    ) -> Result<()> {
181        if (self as *const Type) == (impl_type as *const Type) {
182            return Ok(());
183        }
184        match (self, impl_type) {
185            (Self::Bottom, Self::Bottom) => Ok(()),
186            (Self::Any, Self::Any) => Ok(()),
187            (Self::Primitive(p0), Self::Primitive(p1)) if p0 == p1 => Ok(()),
188            (
189                Self::Ref { scope: s0, name: n0, params: p0 },
190                Self::Ref { scope: s1, name: n1, params: p1 },
191            ) if s0 == s1 && n0 == n1 && p0.len() == p1.len() => {
192                for (t0, t1) in p0.iter().zip(p1.iter()) {
193                    t0.sig_matches_int(env, t1, tvar_map, hist, adts)?;
194                }
195                Ok(())
196            }
197            (t0 @ Self::Ref { .. }, t1) | (t0, t1 @ Self::Ref { .. }) => {
198                let t0 = t0.lookup_ref(env)?;
199                let t1 = t1.lookup_ref(env)?;
200                let t0_addr = (t0 as *const Type).addr();
201                let t1_addr = (t1 as *const Type).addr();
202                if hist.contains(&(t0_addr, t1_addr)) {
203                    Ok(())
204                } else {
205                    hist.insert((t0_addr, t1_addr));
206                    t0.sig_matches_int(env, t1, tvar_map, hist, adts)
207                }
208            }
209            (Self::Fn(f0), Self::Fn(f1)) => {
210                f0.sig_matches_int(env, f1, tvar_map, hist, adts)
211            }
212            (Self::Set(s0), Self::Set(s1)) if s0.len() == s1.len() => {
213                for (t0, t1) in s0.iter().zip(s1.iter()) {
214                    t0.sig_matches_int(env, t1, tvar_map, hist, adts)?;
215                }
216                Ok(())
217            }
218            (Self::Error(e0), Self::Error(e1)) => {
219                e0.sig_matches_int(env, e1, tvar_map, hist, adts)
220            }
221            (Self::Array(a0), Self::Array(a1)) => {
222                a0.sig_matches_int(env, a1, tvar_map, hist, adts)
223            }
224            (Self::ByRef(b0), Self::ByRef(b1)) => {
225                b0.sig_matches_int(env, b1, tvar_map, hist, adts)
226            }
227            (Self::Tuple(t0), Self::Tuple(t1)) if t0.len() == t1.len() => {
228                for (t0, t1) in t0.iter().zip(t1.iter()) {
229                    t0.sig_matches_int(env, t1, tvar_map, hist, adts)?;
230                }
231                Ok(())
232            }
233            (Self::Struct(s0), Self::Struct(s1)) if s0.len() == s1.len() => {
234                for ((n0, t0), (n1, t1)) in s0.iter().zip(s1.iter()) {
235                    if n0 != n1 {
236                        format_with_flags(PrintFlag::DerefTVars, || {
237                            bail!("struct field name mismatch: {n0} vs {n1}")
238                        })?
239                    }
240                    t0.sig_matches_int(env, t1, tvar_map, hist, adts)?;
241                }
242                Ok(())
243            }
244            (Self::Variant(tag0, t0), Self::Variant(tag1, t1))
245                if tag0 == tag1 && t0.len() == t1.len() =>
246            {
247                for (t0, t1) in t0.iter().zip(t1.iter()) {
248                    t0.sig_matches_int(env, t1, tvar_map, hist, adts)?;
249                }
250                Ok(())
251            }
252            (Self::Map { key: k0, value: v0 }, Self::Map { key: k1, value: v1 }) => {
253                k0.sig_matches_int(env, k1, tvar_map, hist, adts)?;
254                v0.sig_matches_int(env, v1, tvar_map, hist, adts)
255            }
256            (Self::Abstract { .. }, Self::Abstract { .. }) => {
257                bail!("abstract types must have a concrete definition in the implementation")
258            }
259            (Self::Abstract { id, params: _ }, t0) => match adts.get(id) {
260                None => bail!("undefined abstract type"),
261                Some(t1) => {
262                    if t0 != t1 {
263                        format_with_flags(PrintFlag::DerefTVars, || {
264                            bail!("abstract type mismatch {t0} != {t1}")
265                        })?
266                    }
267                    Ok(())
268                }
269            },
270            (Self::TVar(sig_tv), Self::TVar(impl_tv)) if sig_tv != impl_tv => {
271                format_with_flags(PrintFlag::DerefTVars, || {
272                    bail!("signature type variable {sig_tv} does not match implementation {impl_tv}")
273                })
274            }
275            (sig_type, Self::TVar(impl_tv)) => {
276                let impl_tv_addr = impl_tv.inner_addr();
277                match tvar_map.get(&impl_tv_addr) {
278                    Some(prev_sig_type) => {
279                        let matches = match (sig_type, prev_sig_type) {
280                            (Type::TVar(tv0), Type::TVar(tv1)) => {
281                                tv0.inner_addr() == tv1.inner_addr()
282                            }
283                            _ => sig_type == prev_sig_type,
284                        };
285                        if matches {
286                            Ok(())
287                        } else {
288                            format_with_flags(PrintFlag::DerefTVars, || {
289                                bail!(
290                                    "type variable usage mismatch: expected {prev_sig_type}, got {sig_type}"
291                                )
292                            })
293                        }
294                    }
295                    None => {
296                        tvar_map.insert(impl_tv_addr, sig_type.clone());
297                        Ok(())
298                    }
299                }
300            }
301            (Self::TVar(sig_tv), impl_type) => {
302                format_with_flags(PrintFlag::DerefTVars, || {
303                    bail!("signature has type variable '{sig_tv} where implementation has {impl_type}")
304                })
305            }
306            (sig_type, impl_type) => format_with_flags(PrintFlag::DerefTVars, || {
307                bail!("type mismatch: signature has {sig_type}, implementation has {impl_type}")
308            }),
309        }
310    }
311}