1use crate::{
2 env::Env,
3 format_with_flags,
4 typ::{AbstractId, AndAc, RefHist, Type, TypeRef},
5 PrintFlag,
6};
7use ahash::{AHashMap, AHashSet};
8use anyhow::{bail, Result};
9use enumflags2::BitFlags;
10use netidx_value::Typ;
11use nohash::IntMap;
12use poolshark::local::LPooled;
13
impl Type {
    /// Recursive worker behind [`Type::could_match`].
    ///
    /// Returns `Ok(true)` when the arms below cannot rule out that `self` and
    /// `t` overlap, and `Ok(false)` when they can. Errors are propagated from
    /// `lookup_ref` and `contains_int`, including through recursive calls.
    ///
    /// `hist` holds the pairs of reference ids (from `RefHist::ref_id`) whose
    /// comparison is currently in progress. A pair is recorded as `true`
    /// before recursing and removed afterwards. A cyclic reference that comes
    /// back to the same pair is therefore assumed to match instead of
    /// recursing forever. Only `true` is ever inserted here.
    ///
    /// NOTE(review): `AndAc` is a project accumulator whose `.0` is used as the
    /// result. It presumably ANDs the collected booleans; confirm whether it
    /// stops at the first `false`.
    ///
    /// Match-arm order is significant: several patterns overlap (e.g. any type
    /// against `Primitive` is handled before the `Set`, `TVar` and `Any` arms),
    /// so reordering arms changes results.
    fn could_match_int(
        &self,
        env: &Env,
        hist: &mut RefHist<AHashMap<(Option<usize>, Option<usize>), bool>>,
        t: &Self,
    ) -> Result<bool> {
        // No flags are passed to `contains_int` from this function.
        let fl = BitFlags::empty();
        match (self, t) {
            // Same named reference on both sides: compare the type arguments
            // pairwise without resolving the reference; differing arity is false.
            (
                Self::Ref(TypeRef { scope: s0, name: n0, params: p0, .. }),
                Self::Ref(TypeRef { scope: s1, name: n1, params: p1, .. }),
            ) if s0 == s1 && n0 == n1 => Ok(p0.len() == p1.len()
                && p0
                    .iter()
                    .zip(p1.iter())
                    .map(|(t0, t1)| t0.could_match_int(env, hist, t1))
                    .collect::<Result<AndAc>>()?
                    .0),
            // Any other pairing involving a reference: resolve both sides with
            // `lookup_ref` (presumably a no-op for the non-Ref side — confirm)
            // and recurse, using `hist` to cut cycles through recursive types.
            (t0 @ Self::Ref(TypeRef { .. }), t1)
            | (t0, t1 @ Self::Ref(TypeRef { .. })) => {
                let t0_id = hist.ref_id(t0, env);
                let t1_id = hist.ref_id(t1, env);
                let t0 = t0.lookup_ref(env)?;
                let t1 = t1.lookup_ref(env)?;
                match hist.get(&(t0_id, t1_id)) {
                    Some(r) => Ok(*r),
                    None => {
                        hist.insert((t0_id, t1_id), true);
                        let r = t0.could_match_int(env, hist, &t1);
                        // Drop the in-progress marker so sibling comparisons
                        // of the same pair are evaluated normally.
                        hist.remove(&(t0_id, t1_id));
                        r
                    }
                }
            }
            // Primitive set on the right: true if `self` contains at least one
            // of its members, each checked as a single-member primitive type.
            (t0, Self::Primitive(s)) => {
                for t1 in s.iter() {
                    if t0.contains_int(fl, env, hist, &Type::Primitive(t1.into()))? {
                        return Ok(true);
                    }
                }
                Ok(false)
            }
            // Primitive set on the left vs a compound type: true if the set
            // includes the corresponding `Typ` tag.
            (Type::Primitive(p), Type::Error(_)) => Ok(p.contains(Typ::Error)),
            (Type::Error(t0), Type::Error(t1)) => t0.could_match_int(env, hist, t1),
            (Type::Array(t0), Type::Array(t1)) => t0.could_match_int(env, hist, t1),
            (Type::Primitive(p), Type::Array(_)) => Ok(p.contains(Typ::Array)),
            (Type::Map { key: k0, value: v0 }, Type::Map { key: k1, value: v1 }) => {
                Ok(k0.could_match_int(env, hist, k1)?
                    && v0.could_match_int(env, hist, v1)?)
            }
            (Type::Primitive(p), Type::Map { .. }) => Ok(p.contains(Typ::Map)),
            // Tuples: same arity and every element pair could match.
            (Type::Tuple(ts0), Type::Tuple(ts1)) => Ok(ts0.len() == ts1.len()
                && ts0
                    .iter()
                    .zip(ts1.iter())
                    .map(|(t0, t1)| t0.could_match_int(env, hist, t1))
                    .collect::<Result<AndAc>>()?
                    .0),
            // Structs: same field count, and fields compared positionally with
            // equal names and field types that could match.
            (Type::Struct(ts0), Type::Struct(ts1)) => Ok(ts0.len() == ts1.len()
                && ts0
                    .iter()
                    .zip(ts1.iter())
                    .map(|((n0, t0), (n1, t1))| {
                        Ok(n0 == n1 && t0.could_match_int(env, hist, t1)?)
                    })
                    .collect::<Result<AndAc>>()?
                    .0),
            // Variants: same tag, same arity, payloads pairwise could match.
            (Type::Variant(n0, ts0), Type::Variant(n1, ts1)) => Ok(ts0.len()
                == ts1.len()
                && n0 == n1
                && ts0
                    .iter()
                    .zip(ts1.iter())
                    .map(|(t0, t1)| t0.could_match_int(env, hist, t1))
                    .collect::<Result<AndAc>>()?
                    .0),
            (Type::ByRef(t0), Type::ByRef(t1)) => t0.could_match_int(env, hist, t1),
            // Set on the right: true if `self` could match any member. This is
            // checked before a Set on the left, so Set-vs-Set lands here first.
            (t0, Self::Set(ts)) => {
                for t1 in ts.iter() {
                    if t0.could_match_int(env, hist, t1)? {
                        return Ok(true);
                    }
                }
                Ok(false)
            }
            // Set on the left: true if any member could match `t`.
            (Type::Set(ts), t1) => {
                for t0 in ts.iter() {
                    if t0.could_match_int(env, hist, t1)? {
                        return Ok(true);
                    }
                }
                Ok(false)
            }
            // Type variables: follow a bound variable; an unbound one could
            // match anything.
            (Type::TVar(t0), t1) => match &*t0.read().typ.read() {
                Some(t0) => t0.could_match_int(env, hist, t1),
                None => Ok(true),
            },
            (t0, Type::TVar(t1)) => match &*t1.read().typ.read() {
                Some(t1) => t0.could_match_int(env, hist, t1),
                None => Ok(true),
            },
            // Abstract types: same id and type arguments pairwise could match.
            (
                Type::Abstract { id: id0, params: p0 },
                Type::Abstract { id: id1, params: p1 },
            ) => Ok(id0 == id1
                && p0.len() == p1.len()
                && p0
                    .iter()
                    .zip(p1.iter())
                    .map(|(t0, t1)| t0.could_match_int(env, hist, t1))
                    .collect::<Result<AndAc>>()?
                    .0),
            // Bottom on the right is accepted. This arm precedes the next two,
            // so (Bottom, Bottom) and (Any, Bottom) are true.
            (_, Type::Bottom) => Ok(true),
            // Bottom on the left is rejected here; because this precedes the
            // `Any` arm, (Bottom, Any) is false.
            (Type::Bottom, _) => Ok(false),
            (Type::Any, _) | (_, Type::Any) => Ok(true),
            // Every remaining combination is a structural mismatch. This
            // includes Fn vs Fn, which never could-match.
            (Type::Abstract { .. }, _)
            | (_, Type::Abstract { .. })
            | (Type::Fn(_), _)
            | (_, Type::Fn(_))
            | (Type::Tuple(_), _)
            | (_, Type::Tuple(_))
            | (Type::Struct(_), _)
            | (_, Type::Struct(_))
            | (Type::Variant(_, _), _)
            | (_, Type::Variant(_, _))
            | (Type::ByRef(_), _)
            | (_, Type::ByRef(_))
            | (Type::Array(_), _)
            | (_, Type::Array(_))
            | (_, Type::Map { .. })
            | (Type::Map { .. }, _) => Ok(false),
        }
    }

    /// Returns whether `self` could match `t`, starting with an empty
    /// reference history.
    ///
    /// # Errors
    ///
    /// Propagates failures from resolving type references or from
    /// `contains_int` while comparing.
    pub fn could_match(&self, env: &Env, t: &Self) -> Result<bool> {
        self.could_match_int(env, &mut RefHist::new(LPooled::take()), t)
    }

    /// Checks that `impl_type` conforms to the signature type `self`.
    ///
    /// `adts` maps abstract type ids appearing in the signature to the type
    /// the implementation must be equal to at that position. Ids missing from
    /// the map are accepted against any implementation type.
    ///
    /// Starts with an empty type-variable map and an empty reference history.
    ///
    /// # Errors
    ///
    /// Returns an error describing the first mismatch found, or one
    /// propagated from resolving a type reference.
    pub fn sig_matches(
        &self,
        env: &Env,
        impl_type: &Self,
        adts: &IntMap<AbstractId, Type>,
    ) -> Result<()> {
        self.sig_matches_int(
            env,
            impl_type,
            &mut LPooled::take(),
            &mut RefHist::new(LPooled::take()),
            adts,
        )
    }

    /// Recursive worker behind [`Type::sig_matches`].
    ///
    /// * `tvar_map`: keyed by `inner_addr()` of implementation type variables.
    ///   It records the signature type first paired with each one, so later
    ///   occurrences of that variable must pair with an equal signature type.
    /// * `hist`: pairs of reference ids currently being compared. A pair seen
    ///   again while still in progress is accepted (cycle guard).
    /// * `adts`: abstract id -> required implementation type (see
    ///   [`Type::sig_matches`]).
    ///
    /// Match-arm order is significant: arms with guards fall through to later
    /// arms when the guard fails.
    ///
    /// Error messages are built inside `format_with_flags(PrintFlag::DerefTVars, ..)`.
    /// Presumably this makes type variables print as their bindings; confirm.
    pub(super) fn sig_matches_int(
        &self,
        env: &Env,
        impl_type: &Self,
        tvar_map: &mut IntMap<usize, Type>,
        hist: &mut RefHist<AHashSet<(Option<usize>, Option<usize>)>>,
        adts: &IntMap<AbstractId, Type>,
    ) -> Result<()> {
        // The very same object trivially conforms to itself.
        if (self as *const Type) == (impl_type as *const Type) {
            return Ok(());
        }
        match (self, impl_type) {
            (Self::Bottom, Self::Bottom) => Ok(()),
            (Self::Any, Self::Any) => Ok(()),
            (Self::Primitive(p0), Self::Primitive(p1)) if p0 == p1 => Ok(()),
            // Same named reference with the same arity: compare the arguments
            // pairwise without resolving. Otherwise fall through to the
            // resolving arm below.
            (
                Self::Ref(TypeRef { scope: s0, name: n0, params: p0, .. }),
                Self::Ref(TypeRef { scope: s1, name: n1, params: p1, .. }),
            ) if s0 == s1 && n0 == n1 && p0.len() == p1.len() => {
                for (t0, t1) in p0.iter().zip(p1.iter()) {
                    t0.sig_matches_int(env, t1, tvar_map, hist, adts)?;
                }
                Ok(())
            }
            // Any other pairing involving a reference: resolve both sides and
            // recurse, accepting a pair already in progress to cut cycles.
            (t0 @ Self::Ref(TypeRef { .. }), t1)
            | (t0, t1 @ Self::Ref(TypeRef { .. })) => {
                let t0_id = hist.ref_id(t0, env);
                let t1_id = hist.ref_id(t1, env);
                let t0 = t0.lookup_ref(env)?;
                let t1 = t1.lookup_ref(env)?;
                if hist.contains(&(t0_id, t1_id)) {
                    Ok(())
                } else {
                    hist.insert((t0_id, t1_id));
                    let r = t0.sig_matches_int(env, &t1, tvar_map, hist, adts);
                    hist.remove(&(t0_id, t1_id));
                    r
                }
            }
            // Functions: check the signatures, then call `merge_lambda_ids`
            // (presumably to share lambda identity between sig and impl —
            // confirm). It runs only when the check succeeded.
            (Self::Fn(f0), Self::Fn(f1)) => {
                f0.sig_matches_int(env, f1, tvar_map, hist, adts)?;
                f0.merge_lambda_ids(f1);
                Ok(())
            }
            // NOTE(review): set members are paired positionally, which relies on
            // both sets storing members in a consistent order; confirm.
            (Self::Set(s0), Self::Set(s1)) if s0.len() == s1.len() => {
                for (t0, t1) in s0.iter().zip(s1.iter()) {
                    t0.sig_matches_int(env, t1, tvar_map, hist, adts)?;
                }
                Ok(())
            }
            (Self::Error(e0), Self::Error(e1)) => {
                e0.sig_matches_int(env, e1, tvar_map, hist, adts)
            }
            (Self::Array(a0), Self::Array(a1)) => {
                a0.sig_matches_int(env, a1, tvar_map, hist, adts)
            }
            (Self::ByRef(b0), Self::ByRef(b1)) => {
                b0.sig_matches_int(env, b1, tvar_map, hist, adts)
            }
            (Self::Tuple(t0), Self::Tuple(t1)) if t0.len() == t1.len() => {
                for (t0, t1) in t0.iter().zip(t1.iter()) {
                    t0.sig_matches_int(env, t1, tvar_map, hist, adts)?;
                }
                Ok(())
            }
            // Structs: fields compared positionally; names must agree exactly.
            (Self::Struct(s0), Self::Struct(s1)) if s0.len() == s1.len() => {
                for ((n0, t0), (n1, t1)) in s0.iter().zip(s1.iter()) {
                    if n0 != n1 {
                        format_with_flags(PrintFlag::DerefTVars, || {
                            bail!("struct field name mismatch: {n0} vs {n1}")
                        })?
                    }
                    t0.sig_matches_int(env, t1, tvar_map, hist, adts)?;
                }
                Ok(())
            }
            (Self::Variant(tag0, t0), Self::Variant(tag1, t1))
                if tag0 == tag1 && t0.len() == t1.len() =>
            {
                for (t0, t1) in t0.iter().zip(t1.iter()) {
                    t0.sig_matches_int(env, t1, tvar_map, hist, adts)?;
                }
                Ok(())
            }
            (Self::Map { key: k0, value: v0 }, Self::Map { key: k1, value: v1 }) => {
                k0.sig_matches_int(env, k1, tvar_map, hist, adts)?;
                v0.sig_matches_int(env, v1, tvar_map, hist, adts)
            }
            // NOTE(review): any two abstract types are accepted here, without
            // comparing ids or params (`could_match_int` requires equal ids);
            // confirm this leniency is intended.
            (Self::Abstract { .. }, Self::Abstract { .. }) => Ok(()),
            // Abstract in the signature vs a concrete implementation type: if
            // the id is in `adts`, the implementation must be equal (`!=`) to
            // the recorded type; unknown ids are accepted.
            (Self::Abstract { id, params: _ }, t0) => match adts.get(id) {
                None => Ok(()),
                Some(t1) => {
                    if t0 != t1 {
                        format_with_flags(PrintFlag::DerefTVars, || {
                            bail!("abstract type mismatch {t0} != {t1}")
                        })?
                    }
                    Ok(())
                }
            },
            // Two type variables that compare unequal are rejected outright;
            // equal ones fall through to the mapping arm below.
            (Self::TVar(sig_tv), Self::TVar(impl_tv)) if sig_tv != impl_tv => {
                format_with_flags(PrintFlag::DerefTVars, || {
                    bail!("signature type variable {sig_tv} does not match implementation {impl_tv}")
                })
            }
            // Implementation type variable: the first occurrence records the
            // signature type it stands for. Later occurrences must match that
            // record (by `inner_addr` when both are type variables, otherwise
            // by `==`).
            (sig_type, Self::TVar(impl_tv)) => {
                let impl_tv_addr = impl_tv.inner_addr();
                match tvar_map.get(&impl_tv_addr) {
                    Some(prev_sig_type) => {
                        let matches = match (sig_type, prev_sig_type) {
                            (Type::TVar(tv0), Type::TVar(tv1)) => {
                                tv0.inner_addr() == tv1.inner_addr()
                            }
                            _ => sig_type == prev_sig_type,
                        };
                        if matches {
                            Ok(())
                        } else {
                            format_with_flags(PrintFlag::DerefTVars, || {
                                bail!(
                                    "type variable usage mismatch: expected {prev_sig_type}, got {sig_type}"
                                )
                            })
                        }
                    }
                    None => {
                        tvar_map.insert(impl_tv_addr, sig_type.clone());
                        Ok(())
                    }
                }
            }
            // A signature type variable cannot be satisfied by a concrete
            // implementation type. Bound variables are not dereferenced here.
            (Self::TVar(sig_tv), impl_type) => {
                format_with_flags(PrintFlag::DerefTVars, || {
                    bail!("signature has type variable '{sig_tv} where implementation has {impl_type}")
                })
            }
            // Every other combination, including guarded arms above whose
            // guard failed, is a mismatch.
            (sig_type, impl_type) => format_with_flags(PrintFlag::DerefTVars, || {
                bail!("type mismatch: signature has {sig_type}, implementation has {impl_type}")
            }),
        }
    }
}