1use crate::{
2 env::Env,
3 format_with_flags,
4 typ::{AbstractId, AndAc, RefHist, Type},
5 PrintFlag,
6};
7use anyhow::{bail, Result};
8use enumflags2::BitFlags;
9use fxhash::{FxHashMap, FxHashSet};
10use netidx_value::Typ;
11use poolshark::local::LPooled;
12
13impl Type {
14 fn could_match_int(
15 &self,
16 env: &Env,
17 hist: &mut RefHist<FxHashMap<(Option<usize>, Option<usize>), bool>>,
18 t: &Self,
19 ) -> Result<bool> {
20 let fl = BitFlags::empty();
21 match (self, t) {
22 (
23 Self::Ref { scope: s0, name: n0, params: p0 },
24 Self::Ref { scope: s1, name: n1, params: p1 },
25 ) if s0 == s1 && n0 == n1 => Ok(p0.len() == p1.len()
26 && p0
27 .iter()
28 .zip(p1.iter())
29 .map(|(t0, t1)| t0.could_match_int(env, hist, t1))
30 .collect::<Result<AndAc>>()?
31 .0),
32 (t0 @ Self::Ref { .. }, t1) | (t0, t1 @ Self::Ref { .. }) => {
33 let t0_id = hist.ref_id(t0, env);
34 let t1_id = hist.ref_id(t1, env);
35 let t0 = t0.lookup_ref(env)?;
36 let t1 = t1.lookup_ref(env)?;
37 match hist.get(&(t0_id, t1_id)) {
38 Some(r) => Ok(*r),
39 None => {
40 hist.insert((t0_id, t1_id), true);
41 let r = t0.could_match_int(env, hist, &t1);
42 hist.remove(&(t0_id, t1_id));
43 r
44 }
45 }
46 }
47 (t0, Self::Primitive(s)) => {
48 for t1 in s.iter() {
49 if t0.contains_int(fl, env, hist, &Type::Primitive(t1.into()))? {
50 return Ok(true);
51 }
52 }
53 Ok(false)
54 }
55 (Type::Primitive(p), Type::Error(_)) => Ok(p.contains(Typ::Error)),
56 (Type::Error(t0), Type::Error(t1)) => t0.could_match_int(env, hist, t1),
57 (Type::Array(t0), Type::Array(t1)) => t0.could_match_int(env, hist, t1),
58 (Type::Primitive(p), Type::Array(_)) => Ok(p.contains(Typ::Array)),
59 (Type::Map { key: k0, value: v0 }, Type::Map { key: k1, value: v1 }) => {
60 Ok(k0.could_match_int(env, hist, k1)?
61 && v0.could_match_int(env, hist, v1)?)
62 }
63 (Type::Primitive(p), Type::Map { .. }) => Ok(p.contains(Typ::Map)),
64 (Type::Tuple(ts0), Type::Tuple(ts1)) => Ok(ts0.len() == ts1.len()
65 && ts0
66 .iter()
67 .zip(ts1.iter())
68 .map(|(t0, t1)| t0.could_match_int(env, hist, t1))
69 .collect::<Result<AndAc>>()?
70 .0),
71 (Type::Struct(ts0), Type::Struct(ts1)) => Ok(ts0.len() == ts1.len()
72 && ts0
73 .iter()
74 .zip(ts1.iter())
75 .map(|((n0, t0), (n1, t1))| {
76 Ok(n0 == n1 && t0.could_match_int(env, hist, t1)?)
77 })
78 .collect::<Result<AndAc>>()?
79 .0),
80 (Type::Variant(n0, ts0), Type::Variant(n1, ts1)) => Ok(ts0.len()
81 == ts1.len()
82 && n0 == n1
83 && ts0
84 .iter()
85 .zip(ts1.iter())
86 .map(|(t0, t1)| t0.could_match_int(env, hist, t1))
87 .collect::<Result<AndAc>>()?
88 .0),
89 (Type::ByRef(t0), Type::ByRef(t1)) => t0.could_match_int(env, hist, t1),
90 (t0, Self::Set(ts)) => {
91 for t1 in ts.iter() {
92 if t0.could_match_int(env, hist, t1)? {
93 return Ok(true);
94 }
95 }
96 Ok(false)
97 }
98 (Type::Set(ts), t1) => {
99 for t0 in ts.iter() {
100 if t0.could_match_int(env, hist, t1)? {
101 return Ok(true);
102 }
103 }
104 Ok(false)
105 }
106 (Type::TVar(t0), t1) => match &*t0.read().typ.read() {
107 Some(t0) => t0.could_match_int(env, hist, t1),
108 None => Ok(true),
109 },
110 (t0, Type::TVar(t1)) => match &*t1.read().typ.read() {
111 Some(t1) => t0.could_match_int(env, hist, t1),
112 None => Ok(true),
113 },
114 (
115 Type::Abstract { id: id0, params: p0 },
116 Type::Abstract { id: id1, params: p1 },
117 ) => Ok(id0 == id1
118 && p0.len() == p1.len()
119 && p0
120 .iter()
121 .zip(p1.iter())
122 .map(|(t0, t1)| t0.could_match_int(env, hist, t1))
123 .collect::<Result<AndAc>>()?
124 .0),
125 (_, Type::Bottom) => Ok(true),
126 (Type::Bottom, _) => Ok(false),
127 (Type::Any, _) | (_, Type::Any) => Ok(true),
128 (Type::Abstract { .. }, _)
129 | (_, Type::Abstract { .. })
130 | (Type::Fn(_), _)
131 | (_, Type::Fn(_))
132 | (Type::Tuple(_), _)
133 | (_, Type::Tuple(_))
134 | (Type::Struct(_), _)
135 | (_, Type::Struct(_))
136 | (Type::Variant(_, _), _)
137 | (_, Type::Variant(_, _))
138 | (Type::ByRef(_), _)
139 | (_, Type::ByRef(_))
140 | (Type::Array(_), _)
141 | (_, Type::Array(_))
142 | (_, Type::Map { .. })
143 | (Type::Map { .. }, _) => Ok(false),
144 }
145 }
146
147 pub fn could_match(&self, env: &Env, t: &Self) -> Result<bool> {
148 self.could_match_int(env, &mut RefHist::new(LPooled::take()), t)
149 }
150
151 pub fn sig_matches(
152 &self,
153 env: &Env,
154 impl_type: &Self,
155 adts: &FxHashMap<AbstractId, Type>,
156 ) -> Result<()> {
157 self.sig_matches_int(
158 env,
159 impl_type,
160 &mut LPooled::take(),
161 &mut RefHist::new(LPooled::take()),
162 adts,
163 )
164 }
165
166 pub(super) fn sig_matches_int(
167 &self,
168 env: &Env,
169 impl_type: &Self,
170 tvar_map: &mut FxHashMap<usize, Type>,
171 hist: &mut RefHist<FxHashSet<(Option<usize>, Option<usize>)>>,
172 adts: &FxHashMap<AbstractId, Type>,
173 ) -> Result<()> {
174 if (self as *const Type) == (impl_type as *const Type) {
175 return Ok(());
176 }
177 match (self, impl_type) {
178 (Self::Bottom, Self::Bottom) => Ok(()),
179 (Self::Any, Self::Any) => Ok(()),
180 (Self::Primitive(p0), Self::Primitive(p1)) if p0 == p1 => Ok(()),
181 (
182 Self::Ref { scope: s0, name: n0, params: p0 },
183 Self::Ref { scope: s1, name: n1, params: p1 },
184 ) if s0 == s1 && n0 == n1 && p0.len() == p1.len() => {
185 for (t0, t1) in p0.iter().zip(p1.iter()) {
186 t0.sig_matches_int(env, t1, tvar_map, hist, adts)?;
187 }
188 Ok(())
189 }
190 (t0 @ Self::Ref { .. }, t1) | (t0, t1 @ Self::Ref { .. }) => {
191 let t0_id = hist.ref_id(t0, env);
192 let t1_id = hist.ref_id(t1, env);
193 let t0 = t0.lookup_ref(env)?;
194 let t1 = t1.lookup_ref(env)?;
195 if hist.contains(&(t0_id, t1_id)) {
196 Ok(())
197 } else {
198 hist.insert((t0_id, t1_id));
199 let r = t0.sig_matches_int(env, &t1, tvar_map, hist, adts);
200 hist.remove(&(t0_id, t1_id));
201 r
202 }
203 }
204 (Self::Fn(f0), Self::Fn(f1)) => {
205 f0.sig_matches_int(env, f1, tvar_map, hist, adts)
206 }
207 (Self::Set(s0), Self::Set(s1)) if s0.len() == s1.len() => {
208 for (t0, t1) in s0.iter().zip(s1.iter()) {
209 t0.sig_matches_int(env, t1, tvar_map, hist, adts)?;
210 }
211 Ok(())
212 }
213 (Self::Error(e0), Self::Error(e1)) => {
214 e0.sig_matches_int(env, e1, tvar_map, hist, adts)
215 }
216 (Self::Array(a0), Self::Array(a1)) => {
217 a0.sig_matches_int(env, a1, tvar_map, hist, adts)
218 }
219 (Self::ByRef(b0), Self::ByRef(b1)) => {
220 b0.sig_matches_int(env, b1, tvar_map, hist, adts)
221 }
222 (Self::Tuple(t0), Self::Tuple(t1)) if t0.len() == t1.len() => {
223 for (t0, t1) in t0.iter().zip(t1.iter()) {
224 t0.sig_matches_int(env, t1, tvar_map, hist, adts)?;
225 }
226 Ok(())
227 }
228 (Self::Struct(s0), Self::Struct(s1)) if s0.len() == s1.len() => {
229 for ((n0, t0), (n1, t1)) in s0.iter().zip(s1.iter()) {
230 if n0 != n1 {
231 format_with_flags(PrintFlag::DerefTVars, || {
232 bail!("struct field name mismatch: {n0} vs {n1}")
233 })?
234 }
235 t0.sig_matches_int(env, t1, tvar_map, hist, adts)?;
236 }
237 Ok(())
238 }
239 (Self::Variant(tag0, t0), Self::Variant(tag1, t1))
240 if tag0 == tag1 && t0.len() == t1.len() =>
241 {
242 for (t0, t1) in t0.iter().zip(t1.iter()) {
243 t0.sig_matches_int(env, t1, tvar_map, hist, adts)?;
244 }
245 Ok(())
246 }
247 (Self::Map { key: k0, value: v0 }, Self::Map { key: k1, value: v1 }) => {
248 k0.sig_matches_int(env, k1, tvar_map, hist, adts)?;
249 v0.sig_matches_int(env, v1, tvar_map, hist, adts)
250 }
251 (Self::Abstract { .. }, Self::Abstract { .. }) => {
252 bail!("abstract types must have a concrete definition in the implementation")
253 }
254 (Self::Abstract { id, params: _ }, t0) => match adts.get(id) {
255 None => bail!("undefined abstract type"),
256 Some(t1) => {
257 if t0 != t1 {
258 format_with_flags(PrintFlag::DerefTVars, || {
259 bail!("abstract type mismatch {t0} != {t1}")
260 })?
261 }
262 Ok(())
263 }
264 },
265 (Self::TVar(sig_tv), Self::TVar(impl_tv)) if sig_tv != impl_tv => {
266 format_with_flags(PrintFlag::DerefTVars, || {
267 bail!("signature type variable {sig_tv} does not match implementation {impl_tv}")
268 })
269 }
270 (sig_type, Self::TVar(impl_tv)) => {
271 let impl_tv_addr = impl_tv.inner_addr();
272 match tvar_map.get(&impl_tv_addr) {
273 Some(prev_sig_type) => {
274 let matches = match (sig_type, prev_sig_type) {
275 (Type::TVar(tv0), Type::TVar(tv1)) => {
276 tv0.inner_addr() == tv1.inner_addr()
277 }
278 _ => sig_type == prev_sig_type,
279 };
280 if matches {
281 Ok(())
282 } else {
283 format_with_flags(PrintFlag::DerefTVars, || {
284 bail!(
285 "type variable usage mismatch: expected {prev_sig_type}, got {sig_type}"
286 )
287 })
288 }
289 }
290 None => {
291 tvar_map.insert(impl_tv_addr, sig_type.clone());
292 Ok(())
293 }
294 }
295 }
296 (Self::TVar(sig_tv), impl_type) => {
297 format_with_flags(PrintFlag::DerefTVars, || {
298 bail!("signature has type variable '{sig_tv} where implementation has {impl_type}")
299 })
300 }
301 (sig_type, impl_type) => format_with_flags(PrintFlag::DerefTVars, || {
302 bail!("type mismatch: signature has {sig_type}, implementation has {impl_type}")
303 }),
304 }
305 }
306}