1use crate::{
2 env::Env,
3 expr::ModPath,
4 typ::{TVar, Type},
5 Ctx, UserEvent,
6};
7use anyhow::{bail, Result};
8use arcstr::ArcStr;
9use fxhash::FxHashMap;
10use parking_lot::RwLock;
11use std::{
12 cell::RefCell,
13 cmp::{Eq, Ordering, PartialEq},
14 collections::HashMap,
15 fmt::{self, Debug},
16};
17use triomphe::Arc;
18
19use super::AndAc;
20
21#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
22pub struct FnArgType {
23 pub label: Option<(ArcStr, bool)>,
24 pub typ: Type,
25}
26
27#[derive(Debug, Clone)]
28pub struct FnType {
29 pub args: Arc<[FnArgType]>,
30 pub vargs: Option<Type>,
31 pub rtype: Type,
32 pub constraints: Arc<RwLock<Vec<(TVar, Type)>>>,
33}
34
35impl PartialEq for FnType {
36 fn eq(&self, other: &Self) -> bool {
37 let Self { args: args0, vargs: vargs0, rtype: rtype0, constraints: constraints0 } =
38 self;
39 let Self { args: args1, vargs: vargs1, rtype: rtype1, constraints: constraints1 } =
40 other;
41 args0 == args1
42 && vargs0 == vargs1
43 && rtype0 == rtype1
44 && &*constraints0.read() == &*constraints1.read()
45 }
46}
47
48impl Eq for FnType {}
49
50impl PartialOrd for FnType {
51 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
52 use std::cmp::Ordering;
53 let Self { args: args0, vargs: vargs0, rtype: rtype0, constraints: constraints0 } =
54 self;
55 let Self { args: args1, vargs: vargs1, rtype: rtype1, constraints: constraints1 } =
56 other;
57 match args0.partial_cmp(&args1) {
58 Some(Ordering::Equal) => match vargs0.partial_cmp(vargs1) {
59 Some(Ordering::Equal) => match rtype0.partial_cmp(rtype1) {
60 Some(Ordering::Equal) => {
61 constraints0.read().partial_cmp(&*constraints1.read())
62 }
63 r => r,
64 },
65 r => r,
66 },
67 r => r,
68 }
69 }
70}
71
72impl Ord for FnType {
73 fn cmp(&self, other: &Self) -> Ordering {
74 self.partial_cmp(other).unwrap()
75 }
76}
77
78impl Default for FnType {
79 fn default() -> Self {
80 Self {
81 args: Arc::from_iter([]),
82 vargs: None,
83 rtype: Default::default(),
84 constraints: Arc::new(RwLock::new(vec![])),
85 }
86 }
87}
88
89impl FnType {
90 pub fn scope_refs(&self, scope: &ModPath) -> FnType {
91 let typ = Type::Fn(Arc::new(self.clone()));
92 match typ.scope_refs(scope) {
93 Type::Fn(f) => (*f).clone(),
94 _ => unreachable!(),
95 }
96 }
97
98 pub(super) fn normalize(&self) -> Self {
99 let Self { args, vargs, rtype, constraints } = self;
100 let args = Arc::from_iter(
101 args.iter()
102 .map(|a| FnArgType { label: a.label.clone(), typ: a.typ.normalize() }),
103 );
104 let vargs = vargs.as_ref().map(|t| t.normalize());
105 let rtype = rtype.normalize();
106 let constraints = Arc::new(RwLock::new(
107 constraints
108 .read()
109 .iter()
110 .map(|(tv, t)| (tv.clone(), t.normalize()))
111 .collect(),
112 ));
113 FnType { args, vargs, rtype, constraints }
114 }
115
116 pub fn unbind_tvars(&self) {
117 let FnType { args, vargs, rtype, constraints } = self;
118 for arg in args.iter() {
119 arg.typ.unbind_tvars()
120 }
121 if let Some(t) = vargs {
122 t.unbind_tvars()
123 }
124 rtype.unbind_tvars();
125 for (tv, tc) in constraints.read().iter() {
126 tv.unbind();
127 tc.unbind_tvars()
128 }
129 }
130
131 pub fn constrain_known(&self) {
132 thread_local! {
133 static KNOWN: RefCell<FxHashMap<ArcStr, TVar>> = RefCell::new(HashMap::default());
134 }
135 KNOWN.with_borrow_mut(|known| {
136 known.clear();
137 self.collect_tvars(known);
138 let mut constraints = self.constraints.write();
139 for (name, tv) in known.drain() {
140 if let Some(t) = tv.read().typ.read().as_ref() {
141 if !constraints.iter().any(|(tv, _)| tv.name == name) {
142 t.bind_as(&Type::Any);
143 constraints.push((tv.clone(), t.normalize()));
144 }
145 }
146 }
147 });
148 }
149
150 pub fn reset_tvars(&self) -> Self {
151 let FnType { args, vargs, rtype, constraints } = self;
152 let args = Arc::from_iter(
153 args.iter()
154 .map(|a| FnArgType { label: a.label.clone(), typ: a.typ.reset_tvars() }),
155 );
156 let vargs = vargs.as_ref().map(|t| t.reset_tvars());
157 let rtype = rtype.reset_tvars();
158 let constraints = Arc::new(RwLock::new(
159 constraints
160 .read()
161 .iter()
162 .map(|(tv, tc)| (TVar::empty_named(tv.name.clone()), tc.reset_tvars()))
163 .collect(),
164 ));
165 FnType { args, vargs, rtype, constraints }
166 }
167
168 pub fn replace_tvars(&self, known: &FxHashMap<ArcStr, Type>) -> Self {
169 let FnType { args, vargs, rtype, constraints } = self;
170 let args = Arc::from_iter(args.iter().map(|a| FnArgType {
171 label: a.label.clone(),
172 typ: a.typ.replace_tvars(known),
173 }));
174 let vargs = vargs.as_ref().map(|t| t.replace_tvars(known));
175 let rtype = rtype.replace_tvars(known);
176 let constraints = constraints.clone();
177 FnType { args, vargs, rtype, constraints }
178 }
179
180 pub fn replace_auto_constrained(&self) -> Self {
184 thread_local! {
185 static KNOWN: RefCell<FxHashMap<ArcStr, Type>> = RefCell::new(HashMap::default());
186 }
187 KNOWN.with_borrow_mut(|known| {
188 known.clear();
189 let Self { args, vargs, rtype, constraints } = self;
190 let constraints: Vec<(TVar, Type)> = constraints
191 .read()
192 .iter()
193 .filter_map(|(tv, ct)| {
194 if tv.name.starts_with("_") {
195 known.insert(tv.name.clone(), ct.clone());
196 None
197 } else {
198 Some((tv.clone(), ct.clone()))
199 }
200 })
201 .collect();
202 let constraints = Arc::new(RwLock::new(constraints));
203 let args = Arc::from_iter(args.iter().map(|FnArgType { label, typ }| {
204 FnArgType { label: label.clone(), typ: typ.replace_tvars(&known) }
205 }));
206 let vargs = vargs.as_ref().map(|t| t.replace_tvars(&known));
207 let rtype = rtype.replace_tvars(&known);
208 Self { args, vargs, rtype, constraints }
209 })
210 }
211
212 pub fn has_unbound(&self) -> bool {
213 let FnType { args, vargs, rtype, constraints } = self;
214 args.iter().any(|a| a.typ.has_unbound())
215 || vargs.as_ref().map(|t| t.has_unbound()).unwrap_or(false)
216 || rtype.has_unbound()
217 || constraints
218 .read()
219 .iter()
220 .any(|(tv, tc)| tv.read().typ.read().is_none() || tc.has_unbound())
221 }
222
223 pub fn bind_as(&self, t: &Type) {
224 let FnType { args, vargs, rtype, constraints } = self;
225 for a in args.iter() {
226 a.typ.bind_as(t)
227 }
228 if let Some(va) = vargs.as_ref() {
229 va.bind_as(t)
230 }
231 rtype.bind_as(t);
232 for (tv, tc) in constraints.read().iter() {
233 let tv = tv.read();
234 let mut tv = tv.typ.write();
235 if tv.is_none() {
236 *tv = Some(t.clone())
237 }
238 tc.bind_as(t)
239 }
240 }
241
242 pub fn alias_tvars(&self, known: &mut FxHashMap<ArcStr, TVar>) {
243 let FnType { args, vargs, rtype, constraints } = self;
244 for arg in args.iter() {
245 arg.typ.alias_tvars(known)
246 }
247 if let Some(vargs) = vargs {
248 vargs.alias_tvars(known)
249 }
250 rtype.alias_tvars(known);
251 for (tv, tc) in constraints.read().iter() {
252 Type::TVar(tv.clone()).alias_tvars(known);
253 tc.alias_tvars(known);
254 }
255 }
256
257 pub fn collect_tvars(&self, known: &mut FxHashMap<ArcStr, TVar>) {
258 let FnType { args, vargs, rtype, constraints } = self;
259 for arg in args.iter() {
260 arg.typ.collect_tvars(known)
261 }
262 if let Some(vargs) = vargs {
263 vargs.collect_tvars(known)
264 }
265 rtype.collect_tvars(known);
266 for (tv, tc) in constraints.read().iter() {
267 Type::TVar(tv.clone()).collect_tvars(known);
268 tc.collect_tvars(known);
269 }
270 }
271
272 pub fn contains<C: Ctx, E: UserEvent>(
273 &self,
274 env: &Env<C, E>,
275 t: &Self,
276 ) -> Result<bool> {
277 thread_local! {
278 static HIST: RefCell<FxHashMap<(usize, usize), bool>> = RefCell::new(HashMap::default());
279 }
280 HIST.with_borrow_mut(|hist| self.contains_int(env, hist, t))
281 }
282
283 pub(super) fn contains_int<C: Ctx, E: UserEvent>(
284 &self,
285 env: &Env<C, E>,
286 hist: &mut FxHashMap<(usize, usize), bool>,
287 t: &Self,
288 ) -> Result<bool> {
289 let mut sul = 0;
290 let mut tul = 0;
291 for (i, a) in self.args.iter().enumerate() {
292 sul = i;
293 match &a.label {
294 None => {
295 break;
296 }
297 Some((l, _)) => match t
298 .args
299 .iter()
300 .find(|a| a.label.as_ref().map(|a| &a.0) == Some(l))
301 {
302 None => return Ok(false),
303 Some(o) => {
304 if !o.typ.contains_int(env, hist, &a.typ)? {
305 return Ok(false);
306 }
307 }
308 },
309 }
310 }
311 for (i, a) in t.args.iter().enumerate() {
312 tul = i;
313 match &a.label {
314 None => {
315 break;
316 }
317 Some((l, opt)) => match self
318 .args
319 .iter()
320 .find(|a| a.label.as_ref().map(|a| &a.0) == Some(l))
321 {
322 Some(_) => (),
323 None => {
324 if !opt {
325 return Ok(false);
326 }
327 }
328 },
329 }
330 }
331 let slen = self.args.len() - sul;
332 let tlen = t.args.len() - tul;
333 Ok(slen == tlen
334 && self
335 .constraints
336 .read()
337 .iter()
338 .map(|(tv, tc)| tc.contains_int(env, hist, &Type::TVar(tv.clone())))
339 .collect::<Result<AndAc>>()?
340 .0
341 && t.constraints
342 .read()
343 .iter()
344 .map(|(tv, tc)| tc.contains_int(env, hist, &Type::TVar(tv.clone())))
345 .collect::<Result<AndAc>>()?
346 .0
347 && t.args[tul..]
348 .iter()
349 .zip(self.args[sul..].iter())
350 .map(|(t, s)| t.typ.contains_int(env, hist, &s.typ))
351 .collect::<Result<AndAc>>()?
352 .0
353 && match (&t.vargs, &self.vargs) {
354 (Some(tv), Some(sv)) => tv.contains_int(env, hist, sv)?,
355 (None, None) => true,
356 (_, _) => false,
357 }
358 && self.rtype.contains_int(env, hist, &t.rtype)?)
359 }
360
361 pub fn check_contains<C: Ctx, E: UserEvent>(
362 &self,
363 env: &Env<C, E>,
364 other: &Self,
365 ) -> Result<()> {
366 if !self.contains(env, other)? {
367 bail!("Fn type mismatch {self} does not contain {other}")
368 }
369 Ok(())
370 }
371
372 pub fn sigmatch<C: Ctx, E: UserEvent>(
375 &self,
376 env: &Env<C, E>,
377 other: &Self,
378 ) -> Result<bool> {
379 let Self { args: args0, vargs: vargs0, rtype: rtype0, constraints: constraints0 } =
380 self;
381 let Self { args: args1, vargs: vargs1, rtype: rtype1, constraints: constraints1 } =
382 other;
383 Ok(args0.len() == args1.len()
384 && args0
385 .iter()
386 .zip(args1.iter())
387 .map(
388 |(a0, a1)| Ok(a0.label == a1.label && a0.typ.contains(env, &a1.typ)?),
389 )
390 .collect::<Result<AndAc>>()?
391 .0
392 && match (vargs0, vargs1) {
393 (None, None) => true,
394 (None, _) | (_, None) => false,
395 (Some(t0), Some(t1)) => t0.contains(env, t1)?,
396 }
397 && rtype0.contains(env, rtype1)?
398 && constraints0
399 .read()
400 .iter()
401 .map(|(tv, tc)| tc.contains(env, &Type::TVar(tv.clone())))
402 .collect::<Result<AndAc>>()?
403 .0
404 && constraints1
405 .read()
406 .iter()
407 .map(|(tv, tc)| tc.contains(env, &Type::TVar(tv.clone())))
408 .collect::<Result<AndAc>>()?
409 .0)
410 }
411
412 pub fn check_sigmatch<C: Ctx, E: UserEvent>(
413 &self,
414 env: &Env<C, E>,
415 other: &Self,
416 ) -> Result<()> {
417 if !self.sigmatch(env, other)? {
418 bail!("Fn signatures do not match {self} does not match {other}")
419 }
420 Ok(())
421 }
422
423 pub fn map_argpos(
424 &self,
425 other: &Self,
426 ) -> FxHashMap<ArcStr, (Option<usize>, Option<usize>)> {
427 let mut tbl: FxHashMap<ArcStr, (Option<usize>, Option<usize>)> =
428 FxHashMap::default();
429 for (i, a) in self.args.iter().enumerate() {
430 match &a.label {
431 None => break,
432 Some((n, _)) => tbl.entry(n.clone()).or_default().0 = Some(i),
433 }
434 }
435 for (i, a) in other.args.iter().enumerate() {
436 match &a.label {
437 None => break,
438 Some((n, _)) => tbl.entry(n.clone()).or_default().1 = Some(i),
439 }
440 }
441 tbl
442 }
443}
444
445impl fmt::Display for FnType {
446 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
447 let constraints = self.constraints.read();
448 if constraints.len() == 0 {
449 write!(f, "fn(")?;
450 } else {
451 write!(f, "fn<")?;
452 for (i, (tv, t)) in constraints.iter().enumerate() {
453 write!(f, "{tv}: {t}")?;
454 if i < constraints.len() - 1 {
455 write!(f, ", ")?;
456 }
457 }
458 write!(f, ">(")?;
459 }
460 for (i, a) in self.args.iter().enumerate() {
461 match &a.label {
462 Some((l, true)) => write!(f, "?#{l}: ")?,
463 Some((l, false)) => write!(f, "#{l}: ")?,
464 None => (),
465 }
466 write!(f, "{}", a.typ)?;
467 if i < self.args.len() - 1 || self.vargs.is_some() {
468 write!(f, ", ")?;
469 }
470 }
471 if let Some(vargs) = &self.vargs {
472 write!(f, "@args: {}", vargs)?;
473 }
474 write!(f, ") -> {}", self.rtype)
475 }
476}