1use super::AndAc;
2use crate::{
3 env::Env,
4 expr::{
5 print::{PrettyBuf, PrettyDisplay},
6 ModPath,
7 },
8 typ::{ContainsFlags, TVar, Type},
9 Rt, UserEvent,
10};
11use anyhow::{bail, Context, Result};
12use arcstr::ArcStr;
13use enumflags2::BitFlags;
14use fxhash::{FxHashMap, FxHashSet};
15use parking_lot::RwLock;
16use poolshark::local::LPooled;
17use smallvec::{smallvec, SmallVec};
18use std::{
19 cmp::{Eq, Ordering, PartialEq},
20 fmt::{self, Debug, Write},
21};
22use triomphe::Arc;
23
/// One parameter of a function signature.
///
/// Field order matters: the derived `PartialOrd`/`Ord` compare `label`
/// before `typ`.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct FnArgType {
    /// `Some((name, optional))` for a labeled argument, `None` for a
    /// positional one. Labeled arguments are printed as `#name: T`, or
    /// `?#name: T` when the flag is `true`; `FnType::contains` lets a
    /// function omit a labeled argument only when this flag is set.
    pub label: Option<(ArcStr, bool)>,
    /// The argument's type.
    pub typ: Type,
}
29
/// A function type: its arguments, optional variadic argument, return type,
/// type-variable constraints and throws type.
///
/// `PartialEq`/`PartialOrd`/`Ord` are implemented by hand rather than
/// derived, because `constraints` sits behind a lock and `explicit_throws`
/// is excluded from comparison.
#[derive(Debug, Clone)]
pub struct FnType {
    /// Argument types in declaration order. Code in this file treats only the
    /// leading run of labeled arguments as labeled (it stops scanning at the
    /// first unlabeled argument).
    pub args: Arc<[FnArgType]>,
    /// Type of the variadic `@args` parameter, if the function has one.
    pub vargs: Option<Type>,
    /// Return type.
    pub rtype: Type,
    /// `(type variable, constraint type)` pairs. Because this is an
    /// `Arc<RwLock<..>>`, the derived `Clone` shares the list (and its lock)
    /// with the original, and `constrain_known` appends to it through `&self`.
    pub constraints: Arc<RwLock<LPooled<Vec<(TVar, Type)>>>>,
    /// Type of the errors the function may throw. The one-line `Display`
    /// output omits the throws clause when this is `Bottom`.
    pub throws: Type,
    /// When set, the pretty printer shows the throws clause even if it is
    /// `Bottom`; ignored by equality and ordering. Presumably set when the
    /// source spelled out a `throws` clause — TODO confirm.
    pub explicit_throws: bool,
}
39
40impl PartialEq for FnType {
41 fn eq(&self, other: &Self) -> bool {
42 let Self {
43 args: args0,
44 vargs: vargs0,
45 rtype: rtype0,
46 constraints: constraints0,
47 throws: th0,
48 explicit_throws: _,
49 } = self;
50 let Self {
51 args: args1,
52 vargs: vargs1,
53 rtype: rtype1,
54 constraints: constraints1,
55 throws: th1,
56 explicit_throws: _,
57 } = other;
58 args0 == args1
59 && vargs0 == vargs1
60 && rtype0 == rtype1
61 && &*constraints0.read() == &*constraints1.read()
62 && th0 == th1
63 }
64}
65
// Marks the hand-written `PartialEq` (all components except
// `explicit_throws`) as a full equivalence relation. NOTE(review): this
// relies on equality of the component types (`Type`, `TVar`, the constraint
// list) being reflexive — confirm if those impls change.
impl Eq for FnType {}
67
68impl PartialOrd for FnType {
69 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
70 use std::cmp::Ordering;
71 let Self {
72 args: args0,
73 vargs: vargs0,
74 rtype: rtype0,
75 constraints: constraints0,
76 throws: th0,
77 explicit_throws: _,
78 } = self;
79 let Self {
80 args: args1,
81 vargs: vargs1,
82 rtype: rtype1,
83 constraints: constraints1,
84 throws: th1,
85 explicit_throws: _,
86 } = other;
87 match args0.partial_cmp(&args1) {
88 Some(Ordering::Equal) => match vargs0.partial_cmp(vargs1) {
89 Some(Ordering::Equal) => match rtype0.partial_cmp(rtype1) {
90 Some(Ordering::Equal) => {
91 match constraints0.read().partial_cmp(&*constraints1.read()) {
92 Some(Ordering::Equal) => th0.partial_cmp(th1),
93 r => r,
94 }
95 }
96 r => r,
97 },
98 r => r,
99 },
100 r => r,
101 }
102 }
103}
104
105impl Ord for FnType {
106 fn cmp(&self, other: &Self) -> Ordering {
107 self.partial_cmp(other).unwrap()
108 }
109}
110
111impl Default for FnType {
112 fn default() -> Self {
113 Self {
114 args: Arc::from_iter([]),
115 vargs: None,
116 rtype: Default::default(),
117 constraints: Arc::new(RwLock::new(LPooled::take())),
118 throws: Default::default(),
119 explicit_throws: false,
120 }
121 }
122}
123
124impl FnType {
125 pub(super) fn normalize(&self) -> Self {
126 let Self { args, vargs, rtype, constraints, throws, explicit_throws } = self;
127 let args = Arc::from_iter(
128 args.iter()
129 .map(|a| FnArgType { label: a.label.clone(), typ: a.typ.normalize() }),
130 );
131 let vargs = vargs.as_ref().map(|t| t.normalize());
132 let rtype = rtype.normalize();
133 let constraints = Arc::new(RwLock::new(
134 constraints
135 .read()
136 .iter()
137 .map(|(tv, t)| (tv.clone(), t.normalize()))
138 .collect(),
139 ));
140 let throws = throws.normalize();
141 let explicit_throws = *explicit_throws;
142 FnType { args, vargs, rtype, constraints, throws, explicit_throws }
143 }
144
145 pub fn unbind_tvars(&self) {
146 let FnType { args, vargs, rtype, constraints, throws, explicit_throws: _ } = self;
147 for arg in args.iter() {
148 arg.typ.unbind_tvars()
149 }
150 if let Some(t) = vargs {
151 t.unbind_tvars()
152 }
153 rtype.unbind_tvars();
154 for (tv, _) in constraints.read().iter() {
155 tv.unbind();
156 }
157 throws.unbind_tvars();
158 }
159
160 pub fn constrain_known(&self) {
161 let mut known = LPooled::take();
162 self.collect_tvars(&mut known);
163 let mut constraints = self.constraints.write();
164 for (name, tv) in known.drain() {
165 if let Some(t) = tv.read().typ.read().as_ref()
166 && t != &Type::Bottom
167 && t != &Type::Any
168 {
169 if !constraints.iter().any(|(tv, _)| tv.name == name) {
170 t.bind_as(&Type::Any);
171 constraints.push((tv.clone(), t.normalize()));
172 }
173 }
174 }
175 }
176
177 pub fn reset_tvars(&self) -> Self {
178 let FnType { args, vargs, rtype, constraints, throws, explicit_throws } = self;
179 let args = Arc::from_iter(
180 args.iter()
181 .map(|a| FnArgType { label: a.label.clone(), typ: a.typ.reset_tvars() }),
182 );
183 let vargs = vargs.as_ref().map(|t| t.reset_tvars());
184 let rtype = rtype.reset_tvars();
185 let constraints = Arc::new(RwLock::new(
186 constraints
187 .read()
188 .iter()
189 .map(|(tv, tc)| (TVar::empty_named(tv.name.clone()), tc.reset_tvars()))
190 .collect(),
191 ));
192 let throws = throws.reset_tvars();
193 let explicit_throws = *explicit_throws;
194 FnType { args, vargs, rtype, constraints, throws, explicit_throws }
195 }
196
197 pub fn replace_tvars(&self, known: &FxHashMap<ArcStr, Type>) -> Self {
198 let FnType { args, vargs, rtype, constraints, throws, explicit_throws } = self;
199 let args = Arc::from_iter(args.iter().map(|a| FnArgType {
200 label: a.label.clone(),
201 typ: a.typ.replace_tvars(known),
202 }));
203 let vargs = vargs.as_ref().map(|t| t.replace_tvars(known));
204 let rtype = rtype.replace_tvars(known);
205 let constraints = constraints.clone();
206 let throws = throws.replace_tvars(known);
207 let explicit_throws = *explicit_throws;
208 FnType { args, vargs, rtype, constraints, throws, explicit_throws }
209 }
210
211 pub fn replace_auto_constrained(&self) -> Self {
215 let mut known: LPooled<FxHashMap<ArcStr, Type>> = LPooled::take();
216 let Self { args, vargs, rtype, constraints, throws, explicit_throws } = self;
217 let constraints: LPooled<Vec<(TVar, Type)>> = constraints
218 .read()
219 .iter()
220 .filter_map(|(tv, ct)| {
221 if tv.name.starts_with("_") {
222 known.insert(tv.name.clone(), ct.clone());
223 None
224 } else {
225 Some((tv.clone(), ct.clone()))
226 }
227 })
228 .collect();
229 let constraints = Arc::new(RwLock::new(constraints));
230 let args = Arc::from_iter(args.iter().map(|FnArgType { label, typ }| {
231 FnArgType { label: label.clone(), typ: typ.replace_tvars(&known) }
232 }));
233 let vargs = vargs.as_ref().map(|t| t.replace_tvars(&known));
234 let rtype = rtype.replace_tvars(&known);
235 let throws = throws.replace_tvars(&known);
236 let explicit_throws = *explicit_throws;
237 Self { args, vargs, rtype, constraints, throws, explicit_throws }
238 }
239
240 pub fn has_unbound(&self) -> bool {
241 let FnType { args, vargs, rtype, constraints, throws, explicit_throws: _ } = self;
242 args.iter().any(|a| a.typ.has_unbound())
243 || vargs.as_ref().map(|t| t.has_unbound()).unwrap_or(false)
244 || rtype.has_unbound()
245 || constraints
246 .read()
247 .iter()
248 .any(|(tv, tc)| tv.read().typ.read().is_none() || tc.has_unbound())
249 || throws.has_unbound()
250 }
251
252 pub fn bind_as(&self, t: &Type) {
253 let FnType { args, vargs, rtype, constraints, throws, explicit_throws: _ } = self;
254 for a in args.iter() {
255 a.typ.bind_as(t)
256 }
257 if let Some(va) = vargs.as_ref() {
258 va.bind_as(t)
259 }
260 rtype.bind_as(t);
261 for (tv, tc) in constraints.read().iter() {
262 let tv = tv.read();
263 let mut tv = tv.typ.write();
264 if tv.is_none() {
265 *tv = Some(t.clone())
266 }
267 tc.bind_as(t)
268 }
269 throws.bind_as(t);
270 }
271
272 pub fn alias_tvars(&self, known: &mut FxHashMap<ArcStr, TVar>) {
273 let FnType { args, vargs, rtype, constraints, throws, explicit_throws: _ } = self;
274 for arg in args.iter() {
275 arg.typ.alias_tvars(known)
276 }
277 if let Some(vargs) = vargs {
278 vargs.alias_tvars(known)
279 }
280 rtype.alias_tvars(known);
281 for (tv, tc) in constraints.read().iter() {
282 Type::TVar(tv.clone()).alias_tvars(known);
283 tc.alias_tvars(known);
284 }
285 throws.alias_tvars(known);
286 }
287
288 pub fn collect_tvars(&self, known: &mut FxHashMap<ArcStr, TVar>) {
289 let FnType { args, vargs, rtype, constraints, throws, explicit_throws: _ } = self;
290 for arg in args.iter() {
291 arg.typ.collect_tvars(known)
292 }
293 if let Some(vargs) = vargs {
294 vargs.collect_tvars(known)
295 }
296 rtype.collect_tvars(known);
297 for (tv, tc) in constraints.read().iter() {
298 Type::TVar(tv.clone()).collect_tvars(known);
299 tc.collect_tvars(known);
300 }
301 throws.collect_tvars(known);
302 }
303
304 pub fn contains<R: Rt, E: UserEvent>(
305 &self,
306 env: &Env<R, E>,
307 t: &Self,
308 ) -> Result<bool> {
309 self.contains_int(
310 ContainsFlags::AliasTVars | ContainsFlags::InitTVars,
311 env,
312 &mut LPooled::take(),
313 t,
314 )
315 }
316
317 pub(super) fn contains_int<R: Rt, E: UserEvent>(
318 &self,
319 flags: BitFlags<ContainsFlags>,
320 env: &Env<R, E>,
321 hist: &mut FxHashMap<(usize, usize), bool>,
322 t: &Self,
323 ) -> Result<bool> {
324 let mut sul = 0;
325 let mut tul = 0;
326 for (i, a) in self.args.iter().enumerate() {
327 sul = i;
328 match &a.label {
329 None => {
330 break;
331 }
332 Some((l, _)) => match t
333 .args
334 .iter()
335 .find(|a| a.label.as_ref().map(|a| &a.0) == Some(l))
336 {
337 None => return Ok(false),
338 Some(o) => {
339 if !o.typ.contains_int(flags, env, hist, &a.typ)? {
340 return Ok(false);
341 }
342 }
343 },
344 }
345 }
346 for (i, a) in t.args.iter().enumerate() {
347 tul = i;
348 match &a.label {
349 None => {
350 break;
351 }
352 Some((l, opt)) => match self
353 .args
354 .iter()
355 .find(|a| a.label.as_ref().map(|a| &a.0) == Some(l))
356 {
357 Some(_) => (),
358 None => {
359 if !opt {
360 return Ok(false);
361 }
362 }
363 },
364 }
365 }
366 let slen = self.args.len() - sul;
367 let tlen = t.args.len() - tul;
368 Ok(slen == tlen
369 && t.args[tul..]
370 .iter()
371 .zip(self.args[sul..].iter())
372 .map(|(t, s)| t.typ.contains_int(flags, env, hist, &s.typ))
373 .collect::<Result<AndAc>>()?
374 .0
375 && match (&t.vargs, &self.vargs) {
376 (Some(tv), Some(sv)) => tv.contains_int(flags, env, hist, sv)?,
377 (None, None) => true,
378 (_, _) => false,
379 }
380 && self.rtype.contains_int(flags, env, hist, &t.rtype)?
381 && self
382 .constraints
383 .read()
384 .iter()
385 .map(|(tv, tc)| {
386 tc.contains_int(flags, env, hist, &Type::TVar(tv.clone()))
387 })
388 .collect::<Result<AndAc>>()?
389 .0
390 && t.constraints
391 .read()
392 .iter()
393 .map(|(tv, tc)| {
394 tc.contains_int(flags, env, hist, &Type::TVar(tv.clone()))
395 })
396 .collect::<Result<AndAc>>()?
397 .0
398 && self.throws.contains_int(flags, env, hist, &t.throws)?)
399 }
400
401 pub fn check_contains<R: Rt, E: UserEvent>(
402 &self,
403 env: &Env<R, E>,
404 other: &Self,
405 ) -> Result<()> {
406 if !self.contains(env, other)? {
407 bail!("Fn type mismatch {self} does not contain {other}")
408 }
409 Ok(())
410 }
411
412 pub fn sig_contains<R: Rt, E: UserEvent>(
415 &self,
416 env: &Env<R, E>,
417 other: &Self,
418 ) -> Result<bool> {
419 let Self {
420 args: args0,
421 vargs: vargs0,
422 rtype: rtype0,
423 constraints: constraints0,
424 throws: tr0,
425 explicit_throws: _,
426 } = self;
427 let Self {
428 args: args1,
429 vargs: vargs1,
430 rtype: rtype1,
431 constraints: constraints1,
432 throws: tr1,
433 explicit_throws: _,
434 } = other;
435 Ok(args0.len() == args1.len()
436 && args0
437 .iter()
438 .zip(args1.iter())
439 .map(
440 |(a0, a1)| Ok(a0.label == a1.label && a0.typ.contains(env, &a1.typ)?),
441 )
442 .collect::<Result<AndAc>>()?
443 .0
444 && match (vargs0, vargs1) {
445 (None, None) => true,
446 (None, _) | (_, None) => false,
447 (Some(t0), Some(t1)) => t0.contains(env, t1)?,
448 }
449 && rtype0.contains(env, rtype1)?
450 && constraints0
451 .read()
452 .iter()
453 .map(|(tv, tc)| tc.contains(env, &Type::TVar(tv.clone())))
454 .collect::<Result<AndAc>>()?
455 .0
456 && constraints1
457 .read()
458 .iter()
459 .map(|(tv, tc)| tc.contains(env, &Type::TVar(tv.clone())))
460 .collect::<Result<AndAc>>()?
461 .0
462 && tr0.contains(env, tr1)?)
463 }
464
465 pub fn check_sig_contains<R: Rt, E: UserEvent>(
466 &self,
467 env: &Env<R, E>,
468 other: &Self,
469 ) -> Result<()> {
470 if !self.sig_contains(env, other)? {
471 bail!("Fn signature {self} does not contain {other}")
472 }
473 Ok(())
474 }
475
476 pub fn sig_matches<R: Rt, E: UserEvent>(
477 &self,
478 env: &Env<R, E>,
479 impl_fn: &Self,
480 ) -> Result<()> {
481 self.sig_matches_int(env, impl_fn, &mut LPooled::take(), &mut LPooled::take())
482 }
483
484 pub(super) fn sig_matches_int<R: Rt, E: UserEvent>(
485 &self,
486 env: &Env<R, E>,
487 impl_fn: &Self,
488 tvar_map: &mut FxHashMap<usize, Type>,
489 hist: &mut FxHashSet<(usize, usize)>,
490 ) -> Result<()> {
491 let Self {
492 args: sig_args,
493 vargs: sig_vargs,
494 rtype: sig_rtype,
495 constraints: sig_constraints,
496 throws: sig_throws,
497 explicit_throws: _,
498 } = self;
499 let Self {
500 args: impl_args,
501 vargs: impl_vargs,
502 rtype: impl_rtype,
503 constraints: impl_constraints,
504 throws: impl_throws,
505 explicit_throws: _,
506 } = impl_fn;
507 if sig_args.len() != impl_args.len() {
508 bail!(
509 "argument count mismatch: signature has {}, implementation has {}",
510 sig_args.len(),
511 impl_args.len()
512 );
513 }
514 for (i, (sig_arg, impl_arg)) in sig_args.iter().zip(impl_args.iter()).enumerate()
515 {
516 if sig_arg.label != impl_arg.label {
517 bail!(
518 "argument {} label mismatch: signature has {:?}, implementation has {:?}",
519 i,
520 sig_arg.label,
521 impl_arg.label
522 );
523 }
524 sig_arg
525 .typ
526 .sig_matches_int(env, &impl_arg.typ, tvar_map, hist)
527 .with_context(|| format!("in argument {i}"))?;
528 }
529 match (sig_vargs, impl_vargs) {
530 (None, None) => (),
531 (Some(sig_va), Some(impl_va)) => {
532 sig_va
533 .sig_matches_int(env, impl_va, tvar_map, hist)
534 .context("in variadic argument")?;
535 }
536 (None, Some(_)) => {
537 bail!("signature has no variadic args but implementation does")
538 }
539 (Some(_), None) => {
540 bail!("signature has variadic args but implementation does not")
541 }
542 }
543 sig_rtype
544 .sig_matches_int(env, impl_rtype, tvar_map, hist)
545 .context("in return type")?;
546 sig_throws
547 .sig_matches_int(env, impl_throws, tvar_map, hist)
548 .context("in throws clause")?;
549 let sig_cons = sig_constraints.read();
550 let impl_cons = impl_constraints.read();
551 for (sig_tv, sig_tc) in sig_cons.iter() {
552 if !impl_cons
553 .iter()
554 .any(|(impl_tv, impl_tc)| sig_tv == impl_tv && sig_tc == impl_tc)
555 {
556 bail!("missing constraint {sig_tv}: {sig_tc} in implementation")
557 }
558 }
559 for (impl_tv, impl_tc) in impl_cons.iter() {
560 match tvar_map.get(&impl_tv.inner_addr()).cloned() {
561 None | Some(Type::TVar(_)) => (),
562 Some(sig_type) => {
563 sig_type.sig_matches_int(env, impl_tc, tvar_map, hist).with_context(|| {
564 format!(
565 "signature has concrete type {sig_type}, implementation constraint is {impl_tc}"
566 )
567 })?;
568 }
569 }
570 }
571 Ok(())
572 }
573
574 pub fn map_argpos(
575 &self,
576 other: &Self,
577 ) -> LPooled<FxHashMap<ArcStr, (Option<usize>, Option<usize>)>> {
578 let mut tbl: LPooled<FxHashMap<ArcStr, (Option<usize>, Option<usize>)>> =
579 LPooled::take();
580 for (i, a) in self.args.iter().enumerate() {
581 match &a.label {
582 None => break,
583 Some((n, _)) => tbl.entry(n.clone()).or_default().0 = Some(i),
584 }
585 }
586 for (i, a) in other.args.iter().enumerate() {
587 match &a.label {
588 None => break,
589 Some((n, _)) => tbl.entry(n.clone()).or_default().1 = Some(i),
590 }
591 }
592 tbl
593 }
594
595 pub fn scope_refs(&self, scope: &ModPath) -> Self {
596 let vargs = self.vargs.as_ref().map(|t| t.scope_refs(scope));
597 let rtype = self.rtype.scope_refs(scope);
598 let args =
599 Arc::from_iter(self.args.iter().map(|a| FnArgType {
600 label: a.label.clone(),
601 typ: a.typ.scope_refs(scope),
602 }));
603 let mut cres: SmallVec<[(TVar, Type); 4]> = smallvec![];
604 for (tv, tc) in self.constraints.read().iter() {
605 let tv = tv.scope_refs(scope);
606 let tc = tc.scope_refs(scope);
607 cres.push((tv, tc));
608 }
609 let throws = self.throws.scope_refs(scope);
610 FnType {
611 args,
612 rtype,
613 constraints: Arc::new(RwLock::new(cres.into_iter().collect())),
614 vargs,
615 throws,
616 explicit_throws: self.explicit_throws,
617 }
618 }
619}
620
621impl fmt::Display for FnType {
622 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
623 let constraints = self.constraints.read();
624 if constraints.len() == 0 {
625 write!(f, "fn(")?;
626 } else {
627 write!(f, "fn<")?;
628 for (i, (tv, t)) in constraints.iter().enumerate() {
629 write!(f, "{tv}: {t}")?;
630 if i < constraints.len() - 1 {
631 write!(f, ", ")?;
632 }
633 }
634 write!(f, ">(")?;
635 }
636 for (i, a) in self.args.iter().enumerate() {
637 match &a.label {
638 Some((l, true)) => write!(f, "?#{l}: ")?,
639 Some((l, false)) => write!(f, "#{l}: ")?,
640 None => (),
641 }
642 write!(f, "{}", a.typ)?;
643 if i < self.args.len() - 1 || self.vargs.is_some() {
644 write!(f, ", ")?;
645 }
646 }
647 if let Some(vargs) = &self.vargs {
648 write!(f, "@args: {}", vargs)?;
649 }
650 match &self.rtype {
651 Type::Fn(ft) => write!(f, ") -> ({ft})")?,
652 Type::ByRef(t) => match &**t {
653 Type::Fn(ft) => write!(f, ") -> &({ft})")?,
654 t => write!(f, ") -> &{t}")?,
655 },
656 t => write!(f, ") -> {t}")?,
657 }
658 match &self.throws {
659 Type::Bottom => Ok(()),
660 Type::TVar(tv) if *tv.read().typ.read() == Some(Type::Bottom) => Ok(()),
661 t => write!(f, " throws {t}"),
662 }
663 }
664}
665
impl PrettyDisplay for FnType {
    /// Multi-line rendering of the signature into a `PrettyBuf`.
    ///
    /// The layout is as follows:
    /// * `fn(`, or `fn<` followed by one `tv: T` constraint per line and
    ///   then `>(`.
    /// * The arguments (and a trailing `@args: ` variadic) inside a
    ///   `with_indent(2, ..)` section.
    /// * The return type after `) -> `. Function return types and
    ///   by-reference function return types are wrapped in parentheses.
    /// * ` throws T`. Unlike the single-line `Display`, the throws clause is
    ///   still printed when it is `Bottom` (directly, or via a type variable
    ///   bound to `Bottom`) if `explicit_throws` is set.
    fn fmt_pretty_inner(&self, buf: &mut PrettyBuf) -> fmt::Result {
        let constraints = self.constraints.read();
        if constraints.is_empty() {
            writeln!(buf, "fn(")?;
        } else {
            writeln!(buf, "fn<")?;
            buf.with_indent(2, |buf| {
                for (i, (tv, t)) in constraints.iter().enumerate() {
                    write!(buf, "{tv}: ")?;
                    buf.with_indent(2, |buf| t.fmt_pretty(buf))?;
                    if i < constraints.len() - 1 {
                        // Presumably removes the newline left by `fmt_pretty`
                        // so the separator lands on the same line — verify
                        // against `PrettyBuf::kill_newline`.
                        buf.kill_newline();
                        writeln!(buf, ",")?;
                    }
                }
                Ok(())
            })?;
            writeln!(buf, ">(")?;
        }
        buf.with_indent(2, |buf| {
            for (i, a) in self.args.iter().enumerate() {
                match &a.label {
                    Some((l, true)) => write!(buf, "?#{l}: ")?,
                    Some((l, false)) => write!(buf, "#{l}: ")?,
                    None => (),
                }
                buf.with_indent(2, |buf| a.typ.fmt_pretty(buf))?;
                // A separator follows every argument except the last, unless
                // a variadic argument is still to come.
                if i < self.args.len() - 1 || self.vargs.is_some() {
                    buf.kill_newline();
                    writeln!(buf, ",")?;
                }
            }
            if let Some(vargs) = &self.vargs {
                write!(buf, "@args: ")?;
                buf.with_indent(2, |buf| vargs.fmt_pretty(buf))?;
            }
            Ok(())
        })?;
        match &self.rtype {
            // Function-typed return values are parenthesized so the inner
            // `-> ...` is not confused with this signature's own.
            Type::Fn(ft) => {
                write!(buf, ") -> (")?;
                ft.fmt_pretty(buf)?;
                buf.kill_newline();
                writeln!(buf, ")")?;
            }
            Type::ByRef(t) => match &**t {
                Type::Fn(ft) => {
                    write!(buf, ") -> &(")?;
                    ft.fmt_pretty(buf)?;
                    buf.kill_newline();
                    writeln!(buf, ")")?;
                }
                t => {
                    write!(buf, ") -> &")?;
                    t.fmt_pretty(buf)?;
                }
            },
            t => {
                write!(buf, ") -> ")?;
                t.fmt_pretty(buf)?;
            }
        }
        match &self.throws {
            // Omit a `Bottom` throws clause unless it was marked explicit.
            Type::Bottom if !self.explicit_throws => Ok(()),
            Type::TVar(tv)
                if *tv.read().typ.read() == Some(Type::Bottom)
                    && !self.explicit_throws =>
            {
                Ok(())
            }
            t => {
                // Join ` throws ...` onto the line holding the return type.
                buf.kill_newline();
                write!(buf, " throws ")?;
                t.fmt_pretty(buf)
            }
        }
    }
}