1use crate::{
2 env::Env,
3 format_with_flags,
4 typ::{AbstractId, AndAc, Type},
5 PrintFlag,
6};
7use anyhow::{bail, Result};
8use enumflags2::BitFlags;
9use fxhash::{FxHashMap, FxHashSet};
10use netidx_value::Typ;
11use poolshark::local::LPooled;
12
impl Type {
    /// Recursive worker behind [`Type::could_match`].
    ///
    /// Returns `Ok(true)` when the match arms below consider `self` and `t`
    /// possibly compatible, and `Ok(false)` otherwise. Examples of `true`:
    /// - `t` is a primitive set and `self` contains at least one of its members;
    /// - either side is `Any`;
    /// - a type variable on either side is unbound.
    ///
    /// * `env`  - environment used to resolve `Type::Ref` via `lookup_ref`.
    /// * `hist` - memo of results, keyed by the addresses of the two resolved
    ///   types (see the `Ref` arm); it also guarantees termination on
    ///   recursive type references.
    /// * `t`    - the type being compared against `self`.
    ///
    /// # Errors
    ///
    /// Propagates errors from `lookup_ref` and `contains_int`.
    fn could_match_int(
        &self,
        env: &Env,
        hist: &mut FxHashMap<(usize, usize), bool>,
        t: &Self,
    ) -> Result<bool> {
        // Empty flag set passed to `contains_int` in the primitive arm.
        // NOTE(review): confirm that `contains_int` with no flags leaves type
        // variables untouched. Nothing else in this function writes to a type
        // variable (only `.read()` is used below).
        let fl = BitFlags::empty();
        // NOTE: arm order is significant. Several patterns overlap: the `Ref`,
        // `Primitive`, `Set` and `TVar` arms accept any type on the other side,
        // and the first matching arm wins.
        match (self, t) {
            // The same named reference on both sides: compare the type arguments
            // pairwise without resolving the reference.
            // `collect::<Result<_>>` stops at the first error. `AndAc` is
            // presumably an AND-accumulator over the per-argument results (its
            // definition is not visible here).
            (
                Self::Ref { scope: s0, name: n0, params: p0 },
                Self::Ref { scope: s1, name: n1, params: p1 },
            ) if s0 == s1 && n0 == n1 => Ok(p0.len() == p1.len()
                && p0
                    .iter()
                    .zip(p1.iter())
                    .map(|(t0, t1)| t0.could_match_int(env, hist, t1))
                    .collect::<Result<AndAc>>()?
                    .0),
            // At least one side is a reference. Resolve both sides and memoize on
            // the addresses of the resolved types.
            // NOTE(review): presumably `lookup_ref` returns a non-`Ref` type
            // unchanged. Confirm.
            (t0 @ Self::Ref { .. }, t1) | (t0, t1 @ Self::Ref { .. }) => {
                let t0 = t0.lookup_ref(env)?;
                let t1 = t1.lookup_ref(env)?;
                let t0_addr = (t0 as *const Type).addr();
                let t1_addr = (t1 as *const Type).addr();
                match hist.get(&(t0_addr, t1_addr)) {
                    Some(r) => Ok(*r),
                    None => {
                        // Provisionally assume a match, so a recursive type that
                        // revisits this pair terminates. The real result then
                        // overwrites the entry; on error the entry is removed.
                        // NOTE(review): other pairs cached while this provisional
                        // `true` is in place are not revisited if this pair
                        // finally comes out `false`.
                        hist.insert((t0_addr, t1_addr), true);
                        match t0.could_match_int(env, hist, t1) {
                            Ok(r) => {
                                hist.insert((t0_addr, t1_addr), r);
                                Ok(r)
                            }
                            Err(e) => {
                                hist.remove(&(t0_addr, t1_addr));
                                Err(e)
                            }
                        }
                    }
                }
            }
            // `t` is a primitive set: `self` could match if it contains at least
            // one individual primitive from that set. An empty set yields `false`.
            (t0, Self::Primitive(s)) => {
                for t1 in s.iter() {
                    if t0.contains_int(fl, env, hist, &Type::Primitive(t1.into()))? {
                        return Ok(true);
                    }
                }
                Ok(false)
            }
            // A primitive set on the left matches a composite on the right only
            // if the set includes the corresponding `Typ` tag.
            (Type::Primitive(p), Type::Error(_)) => Ok(p.contains(Typ::Error)),
            (Type::Error(t0), Type::Error(t1)) => t0.could_match_int(env, hist, t1),
            (Type::Array(t0), Type::Array(t1)) => t0.could_match_int(env, hist, t1),
            (Type::Primitive(p), Type::Array(_)) => Ok(p.contains(Typ::Array)),
            // Both key and value types must be able to match.
            (Type::Map { key: k0, value: v0 }, Type::Map { key: k1, value: v1 }) => {
                Ok(k0.could_match_int(env, hist, k1)?
                    && v0.could_match_int(env, hist, v1)?)
            }
            (Type::Primitive(p), Type::Map { .. }) => Ok(p.contains(Typ::Map)),
            // Tuples must have equal arity; elements are compared pairwise.
            (Type::Tuple(ts0), Type::Tuple(ts1)) => Ok(ts0.len() == ts1.len()
                && ts0
                    .iter()
                    .zip(ts1.iter())
                    .map(|(t0, t1)| t0.could_match_int(env, hist, t1))
                    .collect::<Result<AndAc>>()?
                    .0),
            // Structs must have the same number of fields. Fields are compared
            // positionally, and the names must agree at each position.
            // NOTE(review): assumes both structs store fields in the same (e.g.
            // sorted) order. Confirm.
            (Type::Struct(ts0), Type::Struct(ts1)) => Ok(ts0.len() == ts1.len()
                && ts0
                    .iter()
                    .zip(ts1.iter())
                    .map(|((n0, t0), (n1, t1))| {
                        Ok(n0 == n1 && t0.could_match_int(env, hist, t1)?)
                    })
                    .collect::<Result<AndAc>>()?
                    .0),
            // Variants need the same tag and arity; payloads are compared pairwise.
            (Type::Variant(n0, ts0), Type::Variant(n1, ts1)) => Ok(ts0.len()
                == ts1.len()
                && n0 == n1
                && ts0
                    .iter()
                    .zip(ts1.iter())
                    .map(|(t0, t1)| t0.could_match_int(env, hist, t1))
                    .collect::<Result<AndAc>>()?
                    .0),
            (Type::ByRef(t0), Type::ByRef(t1)) => t0.could_match_int(env, hist, t1),
            // A set on either side could match if any of its members could match.
            (t0, Self::Set(ts)) => {
                for t1 in ts.iter() {
                    if t0.could_match_int(env, hist, t1)? {
                        return Ok(true);
                    }
                }
                Ok(false)
            }
            (Type::Set(ts), t1) => {
                for t0 in ts.iter() {
                    if t0.could_match_int(env, hist, t1)? {
                        return Ok(true);
                    }
                }
                Ok(false)
            }
            // A bound type variable is compared through its binding. An unbound
            // one could match anything.
            (Type::TVar(t0), t1) => match &*t0.read().typ.read() {
                Some(t0) => t0.could_match_int(env, hist, t1),
                None => Ok(true),
            },
            (t0, Type::TVar(t1)) => match &*t1.read().typ.read() {
                Some(t1) => t0.could_match_int(env, hist, t1),
                None => Ok(true),
            },
            // Abstract types need the same id and arity; parameters are compared
            // pairwise.
            (
                Type::Abstract { id: id0, params: p0 },
                Type::Abstract { id: id1, params: p1 },
            ) => Ok(id0 == id1
                && p0.len() == p1.len()
                && p0
                    .iter()
                    .zip(p1.iter())
                    .map(|(t0, t1)| t0.could_match_int(env, hist, t1))
                    .collect::<Result<AndAc>>()?
                    .0),
            // `Bottom` on the right matches. The remaining cases with `Bottom` on
            // the left do not; this includes `(Bottom, Any)`, because this arm is
            // checked before the `Any` arm.
            (_, Type::Bottom) => Ok(true),
            (Type::Bottom, _) => Ok(false),
            // `Any` on either side matches, unless an earlier arm already decided.
            (Type::Any, _) | (_, Type::Any) => Ok(true),
            // Every remaining combination does not match. This covers mismatched
            // constructors, and also `(Fn, Fn)`: function types have no dedicated
            // arm in this function.
            (Type::Abstract { .. }, _)
            | (_, Type::Abstract { .. })
            | (Type::Fn(_), _)
            | (_, Type::Fn(_))
            | (Type::Tuple(_), _)
            | (_, Type::Tuple(_))
            | (Type::Struct(_), _)
            | (_, Type::Struct(_))
            | (Type::Variant(_, _), _)
            | (_, Type::Variant(_, _))
            | (Type::ByRef(_), _)
            | (_, Type::ByRef(_))
            | (Type::Array(_), _)
            | (_, Type::Array(_))
            | (_, Type::Map { .. })
            | (Type::Map { .. }, _) => Ok(false),
        }
    }

    /// Returns whether `self` and `t` could match. See `could_match_int` for the
    /// rules.
    ///
    /// The memo table comes from `LPooled::take()`, presumably an empty pooled
    /// map.
    ///
    /// # Errors
    ///
    /// Propagates reference-lookup and `contains_int` errors.
    pub fn could_match(&self, env: &Env, t: &Self) -> Result<bool> {
        self.could_match_int(env, &mut LPooled::take(), t)
    }

    /// Checks that `impl_type` (the type found in an implementation) is
    /// compatible with `self` (the type declared in a signature).
    ///
    /// `adts` maps each abstract type id that appears in the signature to the
    /// concrete type the implementation defines for it.
    ///
    /// # Errors
    ///
    /// Returns an error describing the first mismatch found, and propagates
    /// reference-lookup errors.
    pub fn sig_matches(
        &self,
        env: &Env,
        impl_type: &Self,
        adts: &FxHashMap<AbstractId, Type>,
    ) -> Result<()> {
        // Fresh pooled scratch state: the type-variable map and the visited set.
        self.sig_matches_int(
            env,
            impl_type,
            &mut LPooled::take(),
            &mut LPooled::take(),
            adts,
        )
    }

    /// Recursive worker behind [`Type::sig_matches`]. `self` is the signature
    /// type and `impl_type` is the implementation type.
    ///
    /// * `tvar_map` - maps the inner address of each implementation type
    ///   variable to the signature type it was first paired with. Later
    ///   occurrences must pair with the same signature type.
    /// * `hist` - `(resolved sig, resolved impl)` address pairs already entered
    ///   through the `Ref` arm. A revisited pair is treated as matching, so
    ///   recursive type references terminate.
    /// * `adts` - concrete definitions for the signature's abstract types.
    ///
    /// Error messages are rendered inside
    /// `format_with_flags(PrintFlag::DerefTVars, ..)`. Presumably this prints
    /// type variables through their bindings.
    pub(super) fn sig_matches_int(
        &self,
        env: &Env,
        impl_type: &Self,
        tvar_map: &mut FxHashMap<usize, Type>,
        hist: &mut FxHashSet<(usize, usize)>,
        adts: &FxHashMap<AbstractId, Type>,
    ) -> Result<()> {
        // Fast path: the very same type object trivially matches itself.
        if (self as *const Type) == (impl_type as *const Type) {
            return Ok(());
        }
        // NOTE: arm order is significant. The `Ref`, `Abstract` and `TVar` arms
        // overlap with later ones, and the first matching arm wins. Arms guarded
        // by a length or tag check fall through to the final "type mismatch"
        // arm when the guard fails.
        match (self, impl_type) {
            (Self::Bottom, Self::Bottom) => Ok(()),
            (Self::Any, Self::Any) => Ok(()),
            (Self::Primitive(p0), Self::Primitive(p1)) if p0 == p1 => Ok(()),
            // The same named reference with the same arity: compare the type
            // arguments pairwise without resolving the reference.
            (
                Self::Ref { scope: s0, name: n0, params: p0 },
                Self::Ref { scope: s1, name: n1, params: p1 },
            ) if s0 == s1 && n0 == n1 && p0.len() == p1.len() => {
                for (t0, t1) in p0.iter().zip(p1.iter()) {
                    t0.sig_matches_int(env, t1, tvar_map, hist, adts)?;
                }
                Ok(())
            }
            // At least one side is a reference: resolve both sides and recurse,
            // unless this resolved pair has already been entered.
            (t0 @ Self::Ref { .. }, t1) | (t0, t1 @ Self::Ref { .. }) => {
                let t0 = t0.lookup_ref(env)?;
                let t1 = t1.lookup_ref(env)?;
                let t0_addr = (t0 as *const Type).addr();
                let t1_addr = (t1 as *const Type).addr();
                if hist.contains(&(t0_addr, t1_addr)) {
                    Ok(())
                } else {
                    hist.insert((t0_addr, t1_addr));
                    t0.sig_matches_int(env, t1, tvar_map, hist, adts)
                }
            }
            (Self::Fn(f0), Self::Fn(f1)) => {
                f0.sig_matches_int(env, f1, tvar_map, hist, adts)
            }
            // Members are compared positionally.
            // NOTE(review): assumes both sets store members in a consistent
            // order. Confirm.
            (Self::Set(s0), Self::Set(s1)) if s0.len() == s1.len() => {
                for (t0, t1) in s0.iter().zip(s1.iter()) {
                    t0.sig_matches_int(env, t1, tvar_map, hist, adts)?;
                }
                Ok(())
            }
            (Self::Error(e0), Self::Error(e1)) => {
                e0.sig_matches_int(env, e1, tvar_map, hist, adts)
            }
            (Self::Array(a0), Self::Array(a1)) => {
                a0.sig_matches_int(env, a1, tvar_map, hist, adts)
            }
            (Self::ByRef(b0), Self::ByRef(b1)) => {
                b0.sig_matches_int(env, b1, tvar_map, hist, adts)
            }
            (Self::Tuple(t0), Self::Tuple(t1)) if t0.len() == t1.len() => {
                for (t0, t1) in t0.iter().zip(t1.iter()) {
                    t0.sig_matches_int(env, t1, tvar_map, hist, adts)?;
                }
                Ok(())
            }
            // Fields are compared positionally, and their names must be equal at
            // each position.
            (Self::Struct(s0), Self::Struct(s1)) if s0.len() == s1.len() => {
                for ((n0, t0), (n1, t1)) in s0.iter().zip(s1.iter()) {
                    if n0 != n1 {
                        format_with_flags(PrintFlag::DerefTVars, || {
                            bail!("struct field name mismatch: {n0} vs {n1}")
                        })?
                    }
                    t0.sig_matches_int(env, t1, tvar_map, hist, adts)?;
                }
                Ok(())
            }
            (Self::Variant(tag0, t0), Self::Variant(tag1, t1))
                if tag0 == tag1 && t0.len() == t1.len() =>
            {
                for (t0, t1) in t0.iter().zip(t1.iter()) {
                    t0.sig_matches_int(env, t1, tvar_map, hist, adts)?;
                }
                Ok(())
            }
            (Self::Map { key: k0, value: v0 }, Self::Map { key: k1, value: v1 }) => {
                k0.sig_matches_int(env, k1, tvar_map, hist, adts)?;
                v0.sig_matches_int(env, v1, tvar_map, hist, adts)
            }
            // The implementation must not itself leave an abstract type abstract.
            (Self::Abstract { .. }, Self::Abstract { .. }) => {
                bail!("abstract types must have a concrete definition in the implementation")
            }
            // An abstract type in the signature: the implementation type must be
            // exactly (`==`) the definition recorded in `adts`. The abstract
            // type's parameters are not consulted here.
            (Self::Abstract { id, params: _ }, t0) => match adts.get(id) {
                None => bail!("undefined abstract type"),
                Some(t1) => {
                    if t0 != t1 {
                        format_with_flags(PrintFlag::DerefTVars, || {
                            bail!("abstract type mismatch {t0} != {t1}")
                        })?
                    }
                    Ok(())
                }
            },
            // Two type variables that compare unequal (`!=` on the handles;
            // equality is defined elsewhere) are rejected.
            (Self::TVar(sig_tv), Self::TVar(impl_tv)) if sig_tv != impl_tv => {
                format_with_flags(PrintFlag::DerefTVars, || {
                    bail!("signature type variable {sig_tv} does not match implementation {impl_tv}")
                })
            }
            // The implementation has a type variable where the signature has some
            // type (possibly an equal type variable). On first sight, record
            // the pairing. Afterwards, require the same signature type: type
            // variables are compared by inner address, other types by `==`.
            (sig_type, Self::TVar(impl_tv)) => {
                let impl_tv_addr = impl_tv.inner_addr();
                match tvar_map.get(&impl_tv_addr) {
                    Some(prev_sig_type) => {
                        let matches = match (sig_type, prev_sig_type) {
                            (Type::TVar(tv0), Type::TVar(tv1)) => {
                                tv0.inner_addr() == tv1.inner_addr()
                            }
                            _ => sig_type == prev_sig_type,
                        };
                        if matches {
                            Ok(())
                        } else {
                            format_with_flags(PrintFlag::DerefTVars, || {
                                bail!(
                                    "type variable usage mismatch: expected {prev_sig_type}, got {sig_type}"
                                )
                            })
                        }
                    }
                    None => {
                        tvar_map.insert(impl_tv_addr, sig_type.clone());
                        Ok(())
                    }
                }
            }
            // The signature is generic where the implementation is concrete.
            (Self::TVar(sig_tv), impl_type) => {
                format_with_flags(PrintFlag::DerefTVars, || {
                    bail!("signature has type variable '{sig_tv} where implementation has {impl_type}")
                })
            }
            // Anything else, including guarded arms above whose length or tag
            // checks failed, is a mismatch.
            (sig_type, impl_type) => format_with_flags(PrintFlag::DerefTVars, || {
                bail!("type mismatch: signature has {sig_type}, implementation has {impl_type}")
            }),
        }
    }
}