// graphix_compiler/typ/fntyp.rs

1use super::AndAc;
2use crate::{
3    env::Env,
4    expr::ModPath,
5    typ::{ContainsFlags, TVar, Type},
6    Rt, UserEvent,
7};
8use anyhow::{bail, Result};
9use arcstr::ArcStr;
10use enumflags2::BitFlags;
11use fxhash::FxHashMap;
12use parking_lot::RwLock;
13use poolshark::local::LPooled;
14use smallvec::{smallvec, SmallVec};
15use std::{
16    cmp::{Eq, Ordering, PartialEq},
17    fmt::{self, Debug},
18};
19use triomphe::Arc;
20
/// A single parameter of a function type.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct FnArgType {
    /// `Some((name, optional))` for a labeled argument, `None` for a
    /// positional one. When the flag is `true` the label is optional
    /// (rendered as `?#name`), otherwise it is required (rendered as `#name`).
    pub label: Option<(ArcStr, bool)>,
    /// The type of the argument.
    pub typ: Type,
}
26
/// The type of a function: its arguments, optional variadic argument,
/// return type, type-variable constraints and thrown type.
#[derive(Debug, Clone)]
pub struct FnType {
    /// Declared arguments. Only the leading run of labeled arguments is
    /// matched by name in `contains`; the rest are positional.
    pub args: Arc<[FnArgType]>,
    /// Type of the variadic `@args` parameter, if the function has one.
    pub vargs: Option<Type>,
    /// Return type.
    pub rtype: Type,
    /// `(type variable, constraint type)` pairs, rendered as `fn<tv: T>`.
    /// Kept behind a shared lock because methods such as `constrain_known`
    /// append to it through `&self`.
    pub constraints: Arc<RwLock<LPooled<Vec<(TVar, Type)>>>>,
    /// Type of the errors the function may throw (`Bottom` when it throws
    /// nothing; the `throws` clause is then omitted from display).
    pub throws: Type,
}
35
36impl PartialEq for FnType {
37    fn eq(&self, other: &Self) -> bool {
38        let Self {
39            args: args0,
40            vargs: vargs0,
41            rtype: rtype0,
42            constraints: constraints0,
43            throws: th0,
44        } = self;
45        let Self {
46            args: args1,
47            vargs: vargs1,
48            rtype: rtype1,
49            constraints: constraints1,
50            throws: th1,
51        } = other;
52        args0 == args1
53            && vargs0 == vargs1
54            && rtype0 == rtype1
55            && &*constraints0.read() == &*constraints1.read()
56            && th0 == th1
57    }
58}
59
60impl Eq for FnType {}
61
62impl PartialOrd for FnType {
63    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
64        use std::cmp::Ordering;
65        let Self {
66            args: args0,
67            vargs: vargs0,
68            rtype: rtype0,
69            constraints: constraints0,
70            throws: th0,
71        } = self;
72        let Self {
73            args: args1,
74            vargs: vargs1,
75            rtype: rtype1,
76            constraints: constraints1,
77            throws: th1,
78        } = other;
79        match args0.partial_cmp(&args1) {
80            Some(Ordering::Equal) => match vargs0.partial_cmp(vargs1) {
81                Some(Ordering::Equal) => match rtype0.partial_cmp(rtype1) {
82                    Some(Ordering::Equal) => {
83                        match constraints0.read().partial_cmp(&*constraints1.read()) {
84                            Some(Ordering::Equal) => th0.partial_cmp(th1),
85                            r => r,
86                        }
87                    }
88                    r => r,
89                },
90                r => r,
91            },
92            r => r,
93        }
94    }
95}
96
97impl Ord for FnType {
98    fn cmp(&self, other: &Self) -> Ordering {
99        self.partial_cmp(other).unwrap()
100    }
101}
102
103impl Default for FnType {
104    fn default() -> Self {
105        Self {
106            args: Arc::from_iter([]),
107            vargs: None,
108            rtype: Default::default(),
109            constraints: Arc::new(RwLock::new(LPooled::take())),
110            throws: Default::default(),
111        }
112    }
113}
114
115impl FnType {
116    pub(super) fn normalize(&self) -> Self {
117        let Self { args, vargs, rtype, constraints, throws } = self;
118        let args = Arc::from_iter(
119            args.iter()
120                .map(|a| FnArgType { label: a.label.clone(), typ: a.typ.normalize() }),
121        );
122        let vargs = vargs.as_ref().map(|t| t.normalize());
123        let rtype = rtype.normalize();
124        let constraints = Arc::new(RwLock::new(
125            constraints
126                .read()
127                .iter()
128                .map(|(tv, t)| (tv.clone(), t.normalize()))
129                .collect(),
130        ));
131        let throws = throws.normalize();
132        FnType { args, vargs, rtype, constraints, throws }
133    }
134
135    pub fn unbind_tvars(&self) {
136        let FnType { args, vargs, rtype, constraints, throws } = self;
137        for arg in args.iter() {
138            arg.typ.unbind_tvars()
139        }
140        if let Some(t) = vargs {
141            t.unbind_tvars()
142        }
143        rtype.unbind_tvars();
144        for (tv, tc) in constraints.read().iter() {
145            tv.unbind();
146            tc.unbind_tvars()
147        }
148        throws.unbind_tvars();
149    }
150
151    pub fn constrain_known(&self) {
152        let mut known = LPooled::take();
153        self.collect_tvars(&mut known);
154        let mut constraints = self.constraints.write();
155        for (name, tv) in known.drain() {
156            if let Some(t) = tv.read().typ.read().as_ref() {
157                if !constraints.iter().any(|(tv, _)| tv.name == name) {
158                    t.bind_as(&Type::Any);
159                    constraints.push((tv.clone(), t.normalize()));
160                }
161            }
162        }
163    }
164
165    pub fn reset_tvars(&self) -> Self {
166        let FnType { args, vargs, rtype, constraints, throws } = self;
167        let args = Arc::from_iter(
168            args.iter()
169                .map(|a| FnArgType { label: a.label.clone(), typ: a.typ.reset_tvars() }),
170        );
171        let vargs = vargs.as_ref().map(|t| t.reset_tvars());
172        let rtype = rtype.reset_tvars();
173        let constraints = Arc::new(RwLock::new(
174            constraints
175                .read()
176                .iter()
177                .map(|(tv, tc)| (TVar::empty_named(tv.name.clone()), tc.reset_tvars()))
178                .collect(),
179        ));
180        let throws = throws.reset_tvars();
181        FnType { args, vargs, rtype, constraints, throws }
182    }
183
184    pub fn replace_tvars(&self, known: &FxHashMap<ArcStr, Type>) -> Self {
185        let FnType { args, vargs, rtype, constraints, throws } = self;
186        let args = Arc::from_iter(args.iter().map(|a| FnArgType {
187            label: a.label.clone(),
188            typ: a.typ.replace_tvars(known),
189        }));
190        let vargs = vargs.as_ref().map(|t| t.replace_tvars(known));
191        let rtype = rtype.replace_tvars(known);
192        let constraints = constraints.clone();
193        let throws = throws.replace_tvars(known);
194        FnType { args, vargs, rtype, constraints, throws }
195    }
196
197    /// replace automatically constrained type variables with their
198    /// constraint type. This is only useful for making nicer display
199    /// types in IDEs and shells.
200    pub fn replace_auto_constrained(&self) -> Self {
201        let mut known: LPooled<FxHashMap<ArcStr, Type>> = LPooled::take();
202        let Self { args, vargs, rtype, constraints, throws } = self;
203        let constraints: LPooled<Vec<(TVar, Type)>> = constraints
204            .read()
205            .iter()
206            .filter_map(|(tv, ct)| {
207                if tv.name.starts_with("_") {
208                    known.insert(tv.name.clone(), ct.clone());
209                    None
210                } else {
211                    Some((tv.clone(), ct.clone()))
212                }
213            })
214            .collect();
215        let constraints = Arc::new(RwLock::new(constraints));
216        let args = Arc::from_iter(args.iter().map(|FnArgType { label, typ }| {
217            FnArgType { label: label.clone(), typ: typ.replace_tvars(&known) }
218        }));
219        let vargs = vargs.as_ref().map(|t| t.replace_tvars(&known));
220        let rtype = rtype.replace_tvars(&known);
221        let throws = throws.replace_tvars(&known);
222        Self { args, vargs, rtype, constraints, throws }
223    }
224
225    pub fn has_unbound(&self) -> bool {
226        let FnType { args, vargs, rtype, constraints, throws } = self;
227        args.iter().any(|a| a.typ.has_unbound())
228            || vargs.as_ref().map(|t| t.has_unbound()).unwrap_or(false)
229            || rtype.has_unbound()
230            || constraints
231                .read()
232                .iter()
233                .any(|(tv, tc)| tv.read().typ.read().is_none() || tc.has_unbound())
234            || throws.has_unbound()
235    }
236
237    pub fn bind_as(&self, t: &Type) {
238        let FnType { args, vargs, rtype, constraints, throws } = self;
239        for a in args.iter() {
240            a.typ.bind_as(t)
241        }
242        if let Some(va) = vargs.as_ref() {
243            va.bind_as(t)
244        }
245        rtype.bind_as(t);
246        for (tv, tc) in constraints.read().iter() {
247            let tv = tv.read();
248            let mut tv = tv.typ.write();
249            if tv.is_none() {
250                *tv = Some(t.clone())
251            }
252            tc.bind_as(t)
253        }
254        throws.bind_as(t);
255    }
256
257    pub fn alias_tvars(&self, known: &mut FxHashMap<ArcStr, TVar>) {
258        let FnType { args, vargs, rtype, constraints, throws } = self;
259        for arg in args.iter() {
260            arg.typ.alias_tvars(known)
261        }
262        if let Some(vargs) = vargs {
263            vargs.alias_tvars(known)
264        }
265        rtype.alias_tvars(known);
266        for (tv, tc) in constraints.read().iter() {
267            Type::TVar(tv.clone()).alias_tvars(known);
268            tc.alias_tvars(known);
269        }
270        throws.alias_tvars(known);
271    }
272
273    pub fn collect_tvars(&self, known: &mut FxHashMap<ArcStr, TVar>) {
274        let FnType { args, vargs, rtype, constraints, throws } = self;
275        for arg in args.iter() {
276            arg.typ.collect_tvars(known)
277        }
278        if let Some(vargs) = vargs {
279            vargs.collect_tvars(known)
280        }
281        rtype.collect_tvars(known);
282        for (tv, tc) in constraints.read().iter() {
283            Type::TVar(tv.clone()).collect_tvars(known);
284            tc.collect_tvars(known);
285        }
286        throws.collect_tvars(known);
287    }
288
289    pub fn contains<R: Rt, E: UserEvent>(
290        &self,
291        env: &Env<R, E>,
292        t: &Self,
293    ) -> Result<bool> {
294        self.contains_int(
295            ContainsFlags::AliasTVars | ContainsFlags::InitTVars,
296            env,
297            &mut LPooled::take(),
298            t,
299        )
300    }
301
302    pub(super) fn contains_int<R: Rt, E: UserEvent>(
303        &self,
304        flags: BitFlags<ContainsFlags>,
305        env: &Env<R, E>,
306        hist: &mut FxHashMap<(usize, usize), bool>,
307        t: &Self,
308    ) -> Result<bool> {
309        let mut sul = 0;
310        let mut tul = 0;
311        for (i, a) in self.args.iter().enumerate() {
312            sul = i;
313            match &a.label {
314                None => {
315                    break;
316                }
317                Some((l, _)) => match t
318                    .args
319                    .iter()
320                    .find(|a| a.label.as_ref().map(|a| &a.0) == Some(l))
321                {
322                    None => return Ok(false),
323                    Some(o) => {
324                        if !o.typ.contains_int(flags, env, hist, &a.typ)? {
325                            return Ok(false);
326                        }
327                    }
328                },
329            }
330        }
331        for (i, a) in t.args.iter().enumerate() {
332            tul = i;
333            match &a.label {
334                None => {
335                    break;
336                }
337                Some((l, opt)) => match self
338                    .args
339                    .iter()
340                    .find(|a| a.label.as_ref().map(|a| &a.0) == Some(l))
341                {
342                    Some(_) => (),
343                    None => {
344                        if !opt {
345                            return Ok(false);
346                        }
347                    }
348                },
349            }
350        }
351        let slen = self.args.len() - sul;
352        let tlen = t.args.len() - tul;
353        Ok(slen == tlen
354            && t.args[tul..]
355                .iter()
356                .zip(self.args[sul..].iter())
357                .map(|(t, s)| t.typ.contains_int(flags, env, hist, &s.typ))
358                .collect::<Result<AndAc>>()?
359                .0
360            && match (&t.vargs, &self.vargs) {
361                (Some(tv), Some(sv)) => tv.contains_int(flags, env, hist, sv)?,
362                (None, None) => true,
363                (_, _) => false,
364            }
365            && self.rtype.contains_int(flags, env, hist, &t.rtype)?
366            && self
367                .constraints
368                .read()
369                .iter()
370                .map(|(tv, tc)| {
371                    tc.contains_int(flags, env, hist, &Type::TVar(tv.clone()))
372                })
373                .collect::<Result<AndAc>>()?
374                .0
375            && t.constraints
376                .read()
377                .iter()
378                .map(|(tv, tc)| {
379                    tc.contains_int(flags, env, hist, &Type::TVar(tv.clone()))
380                })
381                .collect::<Result<AndAc>>()?
382                .0
383            && self.throws.contains_int(flags, env, hist, &t.throws)?)
384    }
385
386    pub fn check_contains<R: Rt, E: UserEvent>(
387        &self,
388        env: &Env<R, E>,
389        other: &Self,
390    ) -> Result<()> {
391        if !self.contains(env, other)? {
392            bail!("Fn type mismatch {self} does not contain {other}")
393        }
394        Ok(())
395    }
396
397    /// Return true if function signatures match. This is contains,
398    /// but does not allow labeled argument subtyping.
399    pub fn sigmatch<R: Rt, E: UserEvent>(
400        &self,
401        env: &Env<R, E>,
402        other: &Self,
403    ) -> Result<bool> {
404        let Self {
405            args: args0,
406            vargs: vargs0,
407            rtype: rtype0,
408            constraints: constraints0,
409            throws: tr0,
410        } = self;
411        let Self {
412            args: args1,
413            vargs: vargs1,
414            rtype: rtype1,
415            constraints: constraints1,
416            throws: tr1,
417        } = other;
418        Ok(args0.len() == args1.len()
419            && args0
420                .iter()
421                .zip(args1.iter())
422                .map(
423                    |(a0, a1)| Ok(a0.label == a1.label && a0.typ.contains(env, &a1.typ)?),
424                )
425                .collect::<Result<AndAc>>()?
426                .0
427            && match (vargs0, vargs1) {
428                (None, None) => true,
429                (None, _) | (_, None) => false,
430                (Some(t0), Some(t1)) => t0.contains(env, t1)?,
431            }
432            && rtype0.contains(env, rtype1)?
433            && constraints0
434                .read()
435                .iter()
436                .map(|(tv, tc)| tc.contains(env, &Type::TVar(tv.clone())))
437                .collect::<Result<AndAc>>()?
438                .0
439            && constraints1
440                .read()
441                .iter()
442                .map(|(tv, tc)| tc.contains(env, &Type::TVar(tv.clone())))
443                .collect::<Result<AndAc>>()?
444                .0
445            && tr0.contains(env, tr1)?)
446    }
447
448    pub fn check_sigmatch<R: Rt, E: UserEvent>(
449        &self,
450        env: &Env<R, E>,
451        other: &Self,
452    ) -> Result<()> {
453        if !self.sigmatch(env, other)? {
454            bail!("Fn signatures do not match {self} does not match {other}")
455        }
456        Ok(())
457    }
458
459    pub fn map_argpos(
460        &self,
461        other: &Self,
462    ) -> LPooled<FxHashMap<ArcStr, (Option<usize>, Option<usize>)>> {
463        let mut tbl: LPooled<FxHashMap<ArcStr, (Option<usize>, Option<usize>)>> =
464            LPooled::take();
465        for (i, a) in self.args.iter().enumerate() {
466            match &a.label {
467                None => break,
468                Some((n, _)) => tbl.entry(n.clone()).or_default().0 = Some(i),
469            }
470        }
471        for (i, a) in other.args.iter().enumerate() {
472            match &a.label {
473                None => break,
474                Some((n, _)) => tbl.entry(n.clone()).or_default().1 = Some(i),
475            }
476        }
477        tbl
478    }
479
480    pub fn scope_refs(&self, scope: &ModPath) -> Self {
481        let vargs = self.vargs.as_ref().map(|t| t.scope_refs(scope));
482        let rtype = self.rtype.scope_refs(scope);
483        let args =
484            Arc::from_iter(self.args.iter().map(|a| FnArgType {
485                label: a.label.clone(),
486                typ: a.typ.scope_refs(scope),
487            }));
488        let mut cres: SmallVec<[(TVar, Type); 4]> = smallvec![];
489        for (tv, tc) in self.constraints.read().iter() {
490            let tv = tv.scope_refs(scope);
491            let tc = tc.scope_refs(scope);
492            cres.push((tv, tc));
493        }
494        let throws = self.throws.scope_refs(scope);
495        FnType {
496            args,
497            rtype,
498            constraints: Arc::new(RwLock::new(cres.into_iter().collect())),
499            vargs,
500            throws,
501        }
502    }
503}
504
505impl fmt::Display for FnType {
506    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
507        let constraints = self.constraints.read();
508        if constraints.len() == 0 {
509            write!(f, "fn(")?;
510        } else {
511            write!(f, "fn<")?;
512            for (i, (tv, t)) in constraints.iter().enumerate() {
513                write!(f, "{tv}: {t}")?;
514                if i < constraints.len() - 1 {
515                    write!(f, ", ")?;
516                }
517            }
518            write!(f, ">(")?;
519        }
520        for (i, a) in self.args.iter().enumerate() {
521            match &a.label {
522                Some((l, true)) => write!(f, "?#{l}: ")?,
523                Some((l, false)) => write!(f, "#{l}: ")?,
524                None => (),
525            }
526            write!(f, "{}", a.typ)?;
527            if i < self.args.len() - 1 || self.vargs.is_some() {
528                write!(f, ", ")?;
529            }
530        }
531        if let Some(vargs) = &self.vargs {
532            write!(f, "@args: {}", vargs)?;
533        }
534        match &self.rtype {
535            Type::Fn(ft) => write!(f, ") -> ({ft})")?,
536            Type::ByRef(t) => match &**t {
537                Type::Fn(ft) => write!(f, ") -> &({ft})")?,
538                t => write!(f, ") -> &{t}")?,
539            },
540            t => write!(f, ") -> {t}")?,
541        }
542        match &self.throws {
543            Type::Bottom => Ok(()),
544            Type::TVar(tv) if *tv.read().typ.read() == Some(Type::Bottom) => Ok(()),
545            t => write!(f, " throws {t}"),
546        }
547    }
548}