1use super::AndAc;
2use crate::{
3 env::Env,
4 expr::ModPath,
5 typ::{ContainsFlags, TVar, Type},
6 Rt, UserEvent,
7};
8use anyhow::{bail, Result};
9use arcstr::ArcStr;
10use enumflags2::BitFlags;
11use fxhash::FxHashMap;
12use parking_lot::RwLock;
13use poolshark::local::LPooled;
14use smallvec::{smallvec, SmallVec};
15use std::{
16 cmp::{Eq, Ordering, PartialEq},
17 fmt::{self, Debug},
18};
19use triomphe::Arc;
20
21#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
22pub struct FnArgType {
23 pub label: Option<(ArcStr, bool)>,
24 pub typ: Type,
25}
26
27#[derive(Debug, Clone)]
28pub struct FnType {
29 pub args: Arc<[FnArgType]>,
30 pub vargs: Option<Type>,
31 pub rtype: Type,
32 pub constraints: Arc<RwLock<LPooled<Vec<(TVar, Type)>>>>,
33 pub throws: Type,
34}
35
36impl PartialEq for FnType {
37 fn eq(&self, other: &Self) -> bool {
38 let Self {
39 args: args0,
40 vargs: vargs0,
41 rtype: rtype0,
42 constraints: constraints0,
43 throws: th0,
44 } = self;
45 let Self {
46 args: args1,
47 vargs: vargs1,
48 rtype: rtype1,
49 constraints: constraints1,
50 throws: th1,
51 } = other;
52 args0 == args1
53 && vargs0 == vargs1
54 && rtype0 == rtype1
55 && &*constraints0.read() == &*constraints1.read()
56 && th0 == th1
57 }
58}
59
60impl Eq for FnType {}
61
62impl PartialOrd for FnType {
63 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
64 use std::cmp::Ordering;
65 let Self {
66 args: args0,
67 vargs: vargs0,
68 rtype: rtype0,
69 constraints: constraints0,
70 throws: th0,
71 } = self;
72 let Self {
73 args: args1,
74 vargs: vargs1,
75 rtype: rtype1,
76 constraints: constraints1,
77 throws: th1,
78 } = other;
79 match args0.partial_cmp(&args1) {
80 Some(Ordering::Equal) => match vargs0.partial_cmp(vargs1) {
81 Some(Ordering::Equal) => match rtype0.partial_cmp(rtype1) {
82 Some(Ordering::Equal) => {
83 match constraints0.read().partial_cmp(&*constraints1.read()) {
84 Some(Ordering::Equal) => th0.partial_cmp(th1),
85 r => r,
86 }
87 }
88 r => r,
89 },
90 r => r,
91 },
92 r => r,
93 }
94 }
95}
96
97impl Ord for FnType {
98 fn cmp(&self, other: &Self) -> Ordering {
99 self.partial_cmp(other).unwrap()
100 }
101}
102
103impl Default for FnType {
104 fn default() -> Self {
105 Self {
106 args: Arc::from_iter([]),
107 vargs: None,
108 rtype: Default::default(),
109 constraints: Arc::new(RwLock::new(LPooled::take())),
110 throws: Default::default(),
111 }
112 }
113}
114
115impl FnType {
116 pub(super) fn normalize(&self) -> Self {
117 let Self { args, vargs, rtype, constraints, throws } = self;
118 let args = Arc::from_iter(
119 args.iter()
120 .map(|a| FnArgType { label: a.label.clone(), typ: a.typ.normalize() }),
121 );
122 let vargs = vargs.as_ref().map(|t| t.normalize());
123 let rtype = rtype.normalize();
124 let constraints = Arc::new(RwLock::new(
125 constraints
126 .read()
127 .iter()
128 .map(|(tv, t)| (tv.clone(), t.normalize()))
129 .collect(),
130 ));
131 let throws = throws.normalize();
132 FnType { args, vargs, rtype, constraints, throws }
133 }
134
135 pub fn unbind_tvars(&self) {
136 let FnType { args, vargs, rtype, constraints, throws } = self;
137 for arg in args.iter() {
138 arg.typ.unbind_tvars()
139 }
140 if let Some(t) = vargs {
141 t.unbind_tvars()
142 }
143 rtype.unbind_tvars();
144 for (tv, tc) in constraints.read().iter() {
145 tv.unbind();
146 tc.unbind_tvars()
147 }
148 throws.unbind_tvars();
149 }
150
151 pub fn constrain_known(&self) {
152 let mut known = LPooled::take();
153 self.collect_tvars(&mut known);
154 let mut constraints = self.constraints.write();
155 for (name, tv) in known.drain() {
156 if let Some(t) = tv.read().typ.read().as_ref() {
157 if !constraints.iter().any(|(tv, _)| tv.name == name) {
158 t.bind_as(&Type::Any);
159 constraints.push((tv.clone(), t.normalize()));
160 }
161 }
162 }
163 }
164
165 pub fn reset_tvars(&self) -> Self {
166 let FnType { args, vargs, rtype, constraints, throws } = self;
167 let args = Arc::from_iter(
168 args.iter()
169 .map(|a| FnArgType { label: a.label.clone(), typ: a.typ.reset_tvars() }),
170 );
171 let vargs = vargs.as_ref().map(|t| t.reset_tvars());
172 let rtype = rtype.reset_tvars();
173 let constraints = Arc::new(RwLock::new(
174 constraints
175 .read()
176 .iter()
177 .map(|(tv, tc)| (TVar::empty_named(tv.name.clone()), tc.reset_tvars()))
178 .collect(),
179 ));
180 let throws = throws.reset_tvars();
181 FnType { args, vargs, rtype, constraints, throws }
182 }
183
184 pub fn replace_tvars(&self, known: &FxHashMap<ArcStr, Type>) -> Self {
185 let FnType { args, vargs, rtype, constraints, throws } = self;
186 let args = Arc::from_iter(args.iter().map(|a| FnArgType {
187 label: a.label.clone(),
188 typ: a.typ.replace_tvars(known),
189 }));
190 let vargs = vargs.as_ref().map(|t| t.replace_tvars(known));
191 let rtype = rtype.replace_tvars(known);
192 let constraints = constraints.clone();
193 let throws = throws.replace_tvars(known);
194 FnType { args, vargs, rtype, constraints, throws }
195 }
196
197 pub fn replace_auto_constrained(&self) -> Self {
201 let mut known: LPooled<FxHashMap<ArcStr, Type>> = LPooled::take();
202 let Self { args, vargs, rtype, constraints, throws } = self;
203 let constraints: LPooled<Vec<(TVar, Type)>> = constraints
204 .read()
205 .iter()
206 .filter_map(|(tv, ct)| {
207 if tv.name.starts_with("_") {
208 known.insert(tv.name.clone(), ct.clone());
209 None
210 } else {
211 Some((tv.clone(), ct.clone()))
212 }
213 })
214 .collect();
215 let constraints = Arc::new(RwLock::new(constraints));
216 let args = Arc::from_iter(args.iter().map(|FnArgType { label, typ }| {
217 FnArgType { label: label.clone(), typ: typ.replace_tvars(&known) }
218 }));
219 let vargs = vargs.as_ref().map(|t| t.replace_tvars(&known));
220 let rtype = rtype.replace_tvars(&known);
221 let throws = throws.replace_tvars(&known);
222 Self { args, vargs, rtype, constraints, throws }
223 }
224
225 pub fn has_unbound(&self) -> bool {
226 let FnType { args, vargs, rtype, constraints, throws } = self;
227 args.iter().any(|a| a.typ.has_unbound())
228 || vargs.as_ref().map(|t| t.has_unbound()).unwrap_or(false)
229 || rtype.has_unbound()
230 || constraints
231 .read()
232 .iter()
233 .any(|(tv, tc)| tv.read().typ.read().is_none() || tc.has_unbound())
234 || throws.has_unbound()
235 }
236
237 pub fn bind_as(&self, t: &Type) {
238 let FnType { args, vargs, rtype, constraints, throws } = self;
239 for a in args.iter() {
240 a.typ.bind_as(t)
241 }
242 if let Some(va) = vargs.as_ref() {
243 va.bind_as(t)
244 }
245 rtype.bind_as(t);
246 for (tv, tc) in constraints.read().iter() {
247 let tv = tv.read();
248 let mut tv = tv.typ.write();
249 if tv.is_none() {
250 *tv = Some(t.clone())
251 }
252 tc.bind_as(t)
253 }
254 throws.bind_as(t);
255 }
256
257 pub fn alias_tvars(&self, known: &mut FxHashMap<ArcStr, TVar>) {
258 let FnType { args, vargs, rtype, constraints, throws } = self;
259 for arg in args.iter() {
260 arg.typ.alias_tvars(known)
261 }
262 if let Some(vargs) = vargs {
263 vargs.alias_tvars(known)
264 }
265 rtype.alias_tvars(known);
266 for (tv, tc) in constraints.read().iter() {
267 Type::TVar(tv.clone()).alias_tvars(known);
268 tc.alias_tvars(known);
269 }
270 throws.alias_tvars(known);
271 }
272
273 pub fn collect_tvars(&self, known: &mut FxHashMap<ArcStr, TVar>) {
274 let FnType { args, vargs, rtype, constraints, throws } = self;
275 for arg in args.iter() {
276 arg.typ.collect_tvars(known)
277 }
278 if let Some(vargs) = vargs {
279 vargs.collect_tvars(known)
280 }
281 rtype.collect_tvars(known);
282 for (tv, tc) in constraints.read().iter() {
283 Type::TVar(tv.clone()).collect_tvars(known);
284 tc.collect_tvars(known);
285 }
286 throws.collect_tvars(known);
287 }
288
289 pub fn contains<R: Rt, E: UserEvent>(
290 &self,
291 env: &Env<R, E>,
292 t: &Self,
293 ) -> Result<bool> {
294 self.contains_int(
295 ContainsFlags::AliasTVars | ContainsFlags::InitTVars,
296 env,
297 &mut LPooled::take(),
298 t,
299 )
300 }
301
302 pub(super) fn contains_int<R: Rt, E: UserEvent>(
303 &self,
304 flags: BitFlags<ContainsFlags>,
305 env: &Env<R, E>,
306 hist: &mut FxHashMap<(usize, usize), bool>,
307 t: &Self,
308 ) -> Result<bool> {
309 let mut sul = 0;
310 let mut tul = 0;
311 for (i, a) in self.args.iter().enumerate() {
312 sul = i;
313 match &a.label {
314 None => {
315 break;
316 }
317 Some((l, _)) => match t
318 .args
319 .iter()
320 .find(|a| a.label.as_ref().map(|a| &a.0) == Some(l))
321 {
322 None => return Ok(false),
323 Some(o) => {
324 if !o.typ.contains_int(flags, env, hist, &a.typ)? {
325 return Ok(false);
326 }
327 }
328 },
329 }
330 }
331 for (i, a) in t.args.iter().enumerate() {
332 tul = i;
333 match &a.label {
334 None => {
335 break;
336 }
337 Some((l, opt)) => match self
338 .args
339 .iter()
340 .find(|a| a.label.as_ref().map(|a| &a.0) == Some(l))
341 {
342 Some(_) => (),
343 None => {
344 if !opt {
345 return Ok(false);
346 }
347 }
348 },
349 }
350 }
351 let slen = self.args.len() - sul;
352 let tlen = t.args.len() - tul;
353 Ok(slen == tlen
354 && t.args[tul..]
355 .iter()
356 .zip(self.args[sul..].iter())
357 .map(|(t, s)| t.typ.contains_int(flags, env, hist, &s.typ))
358 .collect::<Result<AndAc>>()?
359 .0
360 && match (&t.vargs, &self.vargs) {
361 (Some(tv), Some(sv)) => tv.contains_int(flags, env, hist, sv)?,
362 (None, None) => true,
363 (_, _) => false,
364 }
365 && self.rtype.contains_int(flags, env, hist, &t.rtype)?
366 && self
367 .constraints
368 .read()
369 .iter()
370 .map(|(tv, tc)| {
371 tc.contains_int(flags, env, hist, &Type::TVar(tv.clone()))
372 })
373 .collect::<Result<AndAc>>()?
374 .0
375 && t.constraints
376 .read()
377 .iter()
378 .map(|(tv, tc)| {
379 tc.contains_int(flags, env, hist, &Type::TVar(tv.clone()))
380 })
381 .collect::<Result<AndAc>>()?
382 .0
383 && self.throws.contains_int(flags, env, hist, &t.throws)?)
384 }
385
386 pub fn check_contains<R: Rt, E: UserEvent>(
387 &self,
388 env: &Env<R, E>,
389 other: &Self,
390 ) -> Result<()> {
391 if !self.contains(env, other)? {
392 bail!("Fn type mismatch {self} does not contain {other}")
393 }
394 Ok(())
395 }
396
397 pub fn sigmatch<R: Rt, E: UserEvent>(
400 &self,
401 env: &Env<R, E>,
402 other: &Self,
403 ) -> Result<bool> {
404 let Self {
405 args: args0,
406 vargs: vargs0,
407 rtype: rtype0,
408 constraints: constraints0,
409 throws: tr0,
410 } = self;
411 let Self {
412 args: args1,
413 vargs: vargs1,
414 rtype: rtype1,
415 constraints: constraints1,
416 throws: tr1,
417 } = other;
418 Ok(args0.len() == args1.len()
419 && args0
420 .iter()
421 .zip(args1.iter())
422 .map(
423 |(a0, a1)| Ok(a0.label == a1.label && a0.typ.contains(env, &a1.typ)?),
424 )
425 .collect::<Result<AndAc>>()?
426 .0
427 && match (vargs0, vargs1) {
428 (None, None) => true,
429 (None, _) | (_, None) => false,
430 (Some(t0), Some(t1)) => t0.contains(env, t1)?,
431 }
432 && rtype0.contains(env, rtype1)?
433 && constraints0
434 .read()
435 .iter()
436 .map(|(tv, tc)| tc.contains(env, &Type::TVar(tv.clone())))
437 .collect::<Result<AndAc>>()?
438 .0
439 && constraints1
440 .read()
441 .iter()
442 .map(|(tv, tc)| tc.contains(env, &Type::TVar(tv.clone())))
443 .collect::<Result<AndAc>>()?
444 .0
445 && tr0.contains(env, tr1)?)
446 }
447
448 pub fn check_sigmatch<R: Rt, E: UserEvent>(
449 &self,
450 env: &Env<R, E>,
451 other: &Self,
452 ) -> Result<()> {
453 if !self.sigmatch(env, other)? {
454 bail!("Fn signatures do not match {self} does not match {other}")
455 }
456 Ok(())
457 }
458
459 pub fn map_argpos(
460 &self,
461 other: &Self,
462 ) -> LPooled<FxHashMap<ArcStr, (Option<usize>, Option<usize>)>> {
463 let mut tbl: LPooled<FxHashMap<ArcStr, (Option<usize>, Option<usize>)>> =
464 LPooled::take();
465 for (i, a) in self.args.iter().enumerate() {
466 match &a.label {
467 None => break,
468 Some((n, _)) => tbl.entry(n.clone()).or_default().0 = Some(i),
469 }
470 }
471 for (i, a) in other.args.iter().enumerate() {
472 match &a.label {
473 None => break,
474 Some((n, _)) => tbl.entry(n.clone()).or_default().1 = Some(i),
475 }
476 }
477 tbl
478 }
479
480 pub fn scope_refs(&self, scope: &ModPath) -> Self {
481 let vargs = self.vargs.as_ref().map(|t| t.scope_refs(scope));
482 let rtype = self.rtype.scope_refs(scope);
483 let args =
484 Arc::from_iter(self.args.iter().map(|a| FnArgType {
485 label: a.label.clone(),
486 typ: a.typ.scope_refs(scope),
487 }));
488 let mut cres: SmallVec<[(TVar, Type); 4]> = smallvec![];
489 for (tv, tc) in self.constraints.read().iter() {
490 let tv = tv.scope_refs(scope);
491 let tc = tc.scope_refs(scope);
492 cres.push((tv, tc));
493 }
494 let throws = self.throws.scope_refs(scope);
495 FnType {
496 args,
497 rtype,
498 constraints: Arc::new(RwLock::new(cres.into_iter().collect())),
499 vargs,
500 throws,
501 }
502 }
503}
504
505impl fmt::Display for FnType {
506 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
507 let constraints = self.constraints.read();
508 if constraints.len() == 0 {
509 write!(f, "fn(")?;
510 } else {
511 write!(f, "fn<")?;
512 for (i, (tv, t)) in constraints.iter().enumerate() {
513 write!(f, "{tv}: {t}")?;
514 if i < constraints.len() - 1 {
515 write!(f, ", ")?;
516 }
517 }
518 write!(f, ">(")?;
519 }
520 for (i, a) in self.args.iter().enumerate() {
521 match &a.label {
522 Some((l, true)) => write!(f, "?#{l}: ")?,
523 Some((l, false)) => write!(f, "#{l}: ")?,
524 None => (),
525 }
526 write!(f, "{}", a.typ)?;
527 if i < self.args.len() - 1 || self.vargs.is_some() {
528 write!(f, ", ")?;
529 }
530 }
531 if let Some(vargs) = &self.vargs {
532 write!(f, "@args: {}", vargs)?;
533 }
534 match &self.rtype {
535 Type::Fn(ft) => write!(f, ") -> ({ft})")?,
536 Type::ByRef(t) => match &**t {
537 Type::Fn(ft) => write!(f, ") -> &({ft})")?,
538 t => write!(f, ") -> &{t}")?,
539 },
540 t => write!(f, ") -> {t}")?,
541 }
542 match &self.throws {
543 Type::Bottom => Ok(()),
544 Type::TVar(tv) if *tv.read().typ.read() == Some(Type::Bottom) => Ok(()),
545 t => write!(f, " throws {t}"),
546 }
547 }
548}