graphix_compiler/typ/fntyp.rs

1use crate::{
2    env::Env,
3    expr::ModPath,
4    typ::{TVar, Type},
5    Ctx, UserEvent,
6};
7use anyhow::{bail, Result};
8use arcstr::ArcStr;
9use fxhash::FxHashMap;
10use parking_lot::RwLock;
11use std::{
12    cell::RefCell,
13    cmp::{Eq, Ordering, PartialEq},
14    collections::HashMap,
15    fmt::{self, Debug},
16};
17use triomphe::Arc;
18
19use super::AndAc;
20
21#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
22pub struct FnArgType {
23    pub label: Option<(ArcStr, bool)>,
24    pub typ: Type,
25}
26
27#[derive(Debug, Clone)]
28pub struct FnType {
29    pub args: Arc<[FnArgType]>,
30    pub vargs: Option<Type>,
31    pub rtype: Type,
32    pub constraints: Arc<RwLock<Vec<(TVar, Type)>>>,
33}
34
35impl PartialEq for FnType {
36    fn eq(&self, other: &Self) -> bool {
37        let Self { args: args0, vargs: vargs0, rtype: rtype0, constraints: constraints0 } =
38            self;
39        let Self { args: args1, vargs: vargs1, rtype: rtype1, constraints: constraints1 } =
40            other;
41        args0 == args1
42            && vargs0 == vargs1
43            && rtype0 == rtype1
44            && &*constraints0.read() == &*constraints1.read()
45    }
46}
47
48impl Eq for FnType {}
49
50impl PartialOrd for FnType {
51    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
52        use std::cmp::Ordering;
53        let Self { args: args0, vargs: vargs0, rtype: rtype0, constraints: constraints0 } =
54            self;
55        let Self { args: args1, vargs: vargs1, rtype: rtype1, constraints: constraints1 } =
56            other;
57        match args0.partial_cmp(&args1) {
58            Some(Ordering::Equal) => match vargs0.partial_cmp(vargs1) {
59                Some(Ordering::Equal) => match rtype0.partial_cmp(rtype1) {
60                    Some(Ordering::Equal) => {
61                        constraints0.read().partial_cmp(&*constraints1.read())
62                    }
63                    r => r,
64                },
65                r => r,
66            },
67            r => r,
68        }
69    }
70}
71
72impl Ord for FnType {
73    fn cmp(&self, other: &Self) -> Ordering {
74        self.partial_cmp(other).unwrap()
75    }
76}
77
78impl Default for FnType {
79    fn default() -> Self {
80        Self {
81            args: Arc::from_iter([]),
82            vargs: None,
83            rtype: Default::default(),
84            constraints: Arc::new(RwLock::new(vec![])),
85        }
86    }
87}
88
89impl FnType {
90    pub fn scope_refs(&self, scope: &ModPath) -> FnType {
91        let typ = Type::Fn(Arc::new(self.clone()));
92        match typ.scope_refs(scope) {
93            Type::Fn(f) => (*f).clone(),
94            _ => unreachable!(),
95        }
96    }
97
98    pub(super) fn normalize(&self) -> Self {
99        let Self { args, vargs, rtype, constraints } = self;
100        let args = Arc::from_iter(
101            args.iter()
102                .map(|a| FnArgType { label: a.label.clone(), typ: a.typ.normalize() }),
103        );
104        let vargs = vargs.as_ref().map(|t| t.normalize());
105        let rtype = rtype.normalize();
106        let constraints = Arc::new(RwLock::new(
107            constraints
108                .read()
109                .iter()
110                .map(|(tv, t)| (tv.clone(), t.normalize()))
111                .collect(),
112        ));
113        FnType { args, vargs, rtype, constraints }
114    }
115
116    pub fn unbind_tvars(&self) {
117        let FnType { args, vargs, rtype, constraints } = self;
118        for arg in args.iter() {
119            arg.typ.unbind_tvars()
120        }
121        if let Some(t) = vargs {
122            t.unbind_tvars()
123        }
124        rtype.unbind_tvars();
125        for (tv, tc) in constraints.read().iter() {
126            tv.unbind();
127            tc.unbind_tvars()
128        }
129    }
130
131    pub fn constrain_known(&self) {
132        thread_local! {
133            static KNOWN: RefCell<FxHashMap<ArcStr, TVar>> = RefCell::new(HashMap::default());
134        }
135        KNOWN.with_borrow_mut(|known| {
136            known.clear();
137            self.collect_tvars(known);
138            let mut constraints = self.constraints.write();
139            for (name, tv) in known.drain() {
140                if let Some(t) = tv.read().typ.read().as_ref() {
141                    if !constraints.iter().any(|(tv, _)| tv.name == name) {
142                        t.bind_as(&Type::Any);
143                        constraints.push((tv.clone(), t.normalize()));
144                    }
145                }
146            }
147        });
148    }
149
150    pub fn reset_tvars(&self) -> Self {
151        let FnType { args, vargs, rtype, constraints } = self;
152        let args = Arc::from_iter(
153            args.iter()
154                .map(|a| FnArgType { label: a.label.clone(), typ: a.typ.reset_tvars() }),
155        );
156        let vargs = vargs.as_ref().map(|t| t.reset_tvars());
157        let rtype = rtype.reset_tvars();
158        let constraints = Arc::new(RwLock::new(
159            constraints
160                .read()
161                .iter()
162                .map(|(tv, tc)| (TVar::empty_named(tv.name.clone()), tc.reset_tvars()))
163                .collect(),
164        ));
165        FnType { args, vargs, rtype, constraints }
166    }
167
168    pub fn replace_tvars(&self, known: &FxHashMap<ArcStr, Type>) -> Self {
169        let FnType { args, vargs, rtype, constraints } = self;
170        let args = Arc::from_iter(args.iter().map(|a| FnArgType {
171            label: a.label.clone(),
172            typ: a.typ.replace_tvars(known),
173        }));
174        let vargs = vargs.as_ref().map(|t| t.replace_tvars(known));
175        let rtype = rtype.replace_tvars(known);
176        let constraints = constraints.clone();
177        FnType { args, vargs, rtype, constraints }
178    }
179
180    /// replace automatically constrained type variables with their
181    /// constraint type. This is only useful for making nicer display
182    /// types in IDEs and shells.
183    pub fn replace_auto_constrained(&self) -> Self {
184        thread_local! {
185            static KNOWN: RefCell<FxHashMap<ArcStr, Type>> = RefCell::new(HashMap::default());
186        }
187        KNOWN.with_borrow_mut(|known| {
188            known.clear();
189            let Self { args, vargs, rtype, constraints } = self;
190            let constraints: Vec<(TVar, Type)> = constraints
191                .read()
192                .iter()
193                .filter_map(|(tv, ct)| {
194                    if tv.name.starts_with("_") {
195                        known.insert(tv.name.clone(), ct.clone());
196                        None
197                    } else {
198                        Some((tv.clone(), ct.clone()))
199                    }
200                })
201                .collect();
202            let constraints = Arc::new(RwLock::new(constraints));
203            let args = Arc::from_iter(args.iter().map(|FnArgType { label, typ }| {
204                FnArgType { label: label.clone(), typ: typ.replace_tvars(&known) }
205            }));
206            let vargs = vargs.as_ref().map(|t| t.replace_tvars(&known));
207            let rtype = rtype.replace_tvars(&known);
208            Self { args, vargs, rtype, constraints }
209        })
210    }
211
212    pub fn has_unbound(&self) -> bool {
213        let FnType { args, vargs, rtype, constraints } = self;
214        args.iter().any(|a| a.typ.has_unbound())
215            || vargs.as_ref().map(|t| t.has_unbound()).unwrap_or(false)
216            || rtype.has_unbound()
217            || constraints
218                .read()
219                .iter()
220                .any(|(tv, tc)| tv.read().typ.read().is_none() || tc.has_unbound())
221    }
222
223    pub fn bind_as(&self, t: &Type) {
224        let FnType { args, vargs, rtype, constraints } = self;
225        for a in args.iter() {
226            a.typ.bind_as(t)
227        }
228        if let Some(va) = vargs.as_ref() {
229            va.bind_as(t)
230        }
231        rtype.bind_as(t);
232        for (tv, tc) in constraints.read().iter() {
233            let tv = tv.read();
234            let mut tv = tv.typ.write();
235            if tv.is_none() {
236                *tv = Some(t.clone())
237            }
238            tc.bind_as(t)
239        }
240    }
241
242    pub fn alias_tvars(&self, known: &mut FxHashMap<ArcStr, TVar>) {
243        let FnType { args, vargs, rtype, constraints } = self;
244        for arg in args.iter() {
245            arg.typ.alias_tvars(known)
246        }
247        if let Some(vargs) = vargs {
248            vargs.alias_tvars(known)
249        }
250        rtype.alias_tvars(known);
251        for (tv, tc) in constraints.read().iter() {
252            Type::TVar(tv.clone()).alias_tvars(known);
253            tc.alias_tvars(known);
254        }
255    }
256
257    pub fn collect_tvars(&self, known: &mut FxHashMap<ArcStr, TVar>) {
258        let FnType { args, vargs, rtype, constraints } = self;
259        for arg in args.iter() {
260            arg.typ.collect_tvars(known)
261        }
262        if let Some(vargs) = vargs {
263            vargs.collect_tvars(known)
264        }
265        rtype.collect_tvars(known);
266        for (tv, tc) in constraints.read().iter() {
267            Type::TVar(tv.clone()).collect_tvars(known);
268            tc.collect_tvars(known);
269        }
270    }
271
272    pub fn contains<C: Ctx, E: UserEvent>(
273        &self,
274        env: &Env<C, E>,
275        t: &Self,
276    ) -> Result<bool> {
277        thread_local! {
278            static HIST: RefCell<FxHashMap<(usize, usize), bool>> = RefCell::new(HashMap::default());
279        }
280        HIST.with_borrow_mut(|hist| self.contains_int(env, hist, t))
281    }
282
283    pub(super) fn contains_int<C: Ctx, E: UserEvent>(
284        &self,
285        env: &Env<C, E>,
286        hist: &mut FxHashMap<(usize, usize), bool>,
287        t: &Self,
288    ) -> Result<bool> {
289        let mut sul = 0;
290        let mut tul = 0;
291        for (i, a) in self.args.iter().enumerate() {
292            sul = i;
293            match &a.label {
294                None => {
295                    break;
296                }
297                Some((l, _)) => match t
298                    .args
299                    .iter()
300                    .find(|a| a.label.as_ref().map(|a| &a.0) == Some(l))
301                {
302                    None => return Ok(false),
303                    Some(o) => {
304                        if !o.typ.contains_int(env, hist, &a.typ)? {
305                            return Ok(false);
306                        }
307                    }
308                },
309            }
310        }
311        for (i, a) in t.args.iter().enumerate() {
312            tul = i;
313            match &a.label {
314                None => {
315                    break;
316                }
317                Some((l, opt)) => match self
318                    .args
319                    .iter()
320                    .find(|a| a.label.as_ref().map(|a| &a.0) == Some(l))
321                {
322                    Some(_) => (),
323                    None => {
324                        if !opt {
325                            return Ok(false);
326                        }
327                    }
328                },
329            }
330        }
331        let slen = self.args.len() - sul;
332        let tlen = t.args.len() - tul;
333        Ok(slen == tlen
334            && self
335                .constraints
336                .read()
337                .iter()
338                .map(|(tv, tc)| tc.contains_int(env, hist, &Type::TVar(tv.clone())))
339                .collect::<Result<AndAc>>()?
340                .0
341            && t.constraints
342                .read()
343                .iter()
344                .map(|(tv, tc)| tc.contains_int(env, hist, &Type::TVar(tv.clone())))
345                .collect::<Result<AndAc>>()?
346                .0
347            && t.args[tul..]
348                .iter()
349                .zip(self.args[sul..].iter())
350                .map(|(t, s)| t.typ.contains_int(env, hist, &s.typ))
351                .collect::<Result<AndAc>>()?
352                .0
353            && match (&t.vargs, &self.vargs) {
354                (Some(tv), Some(sv)) => tv.contains_int(env, hist, sv)?,
355                (None, None) => true,
356                (_, _) => false,
357            }
358            && self.rtype.contains_int(env, hist, &t.rtype)?)
359    }
360
361    pub fn check_contains<C: Ctx, E: UserEvent>(
362        &self,
363        env: &Env<C, E>,
364        other: &Self,
365    ) -> Result<()> {
366        if !self.contains(env, other)? {
367            bail!("Fn type mismatch {self} does not contain {other}")
368        }
369        Ok(())
370    }
371
372    /// Return true if function signatures match. This is contains,
373    /// but does not allow labeled argument subtyping.
374    pub fn sigmatch<C: Ctx, E: UserEvent>(
375        &self,
376        env: &Env<C, E>,
377        other: &Self,
378    ) -> Result<bool> {
379        let Self { args: args0, vargs: vargs0, rtype: rtype0, constraints: constraints0 } =
380            self;
381        let Self { args: args1, vargs: vargs1, rtype: rtype1, constraints: constraints1 } =
382            other;
383        Ok(args0.len() == args1.len()
384            && args0
385                .iter()
386                .zip(args1.iter())
387                .map(
388                    |(a0, a1)| Ok(a0.label == a1.label && a0.typ.contains(env, &a1.typ)?),
389                )
390                .collect::<Result<AndAc>>()?
391                .0
392            && match (vargs0, vargs1) {
393                (None, None) => true,
394                (None, _) | (_, None) => false,
395                (Some(t0), Some(t1)) => t0.contains(env, t1)?,
396            }
397            && rtype0.contains(env, rtype1)?
398            && constraints0
399                .read()
400                .iter()
401                .map(|(tv, tc)| tc.contains(env, &Type::TVar(tv.clone())))
402                .collect::<Result<AndAc>>()?
403                .0
404            && constraints1
405                .read()
406                .iter()
407                .map(|(tv, tc)| tc.contains(env, &Type::TVar(tv.clone())))
408                .collect::<Result<AndAc>>()?
409                .0)
410    }
411
412    pub fn check_sigmatch<C: Ctx, E: UserEvent>(
413        &self,
414        env: &Env<C, E>,
415        other: &Self,
416    ) -> Result<()> {
417        if !self.sigmatch(env, other)? {
418            bail!("Fn signatures do not match {self} does not match {other}")
419        }
420        Ok(())
421    }
422
423    pub fn map_argpos(
424        &self,
425        other: &Self,
426    ) -> FxHashMap<ArcStr, (Option<usize>, Option<usize>)> {
427        let mut tbl: FxHashMap<ArcStr, (Option<usize>, Option<usize>)> =
428            FxHashMap::default();
429        for (i, a) in self.args.iter().enumerate() {
430            match &a.label {
431                None => break,
432                Some((n, _)) => tbl.entry(n.clone()).or_default().0 = Some(i),
433            }
434        }
435        for (i, a) in other.args.iter().enumerate() {
436            match &a.label {
437                None => break,
438                Some((n, _)) => tbl.entry(n.clone()).or_default().1 = Some(i),
439            }
440        }
441        tbl
442    }
443}
444
445impl fmt::Display for FnType {
446    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
447        let constraints = self.constraints.read();
448        if constraints.len() == 0 {
449            write!(f, "fn(")?;
450        } else {
451            write!(f, "fn<")?;
452            for (i, (tv, t)) in constraints.iter().enumerate() {
453                write!(f, "{tv}: {t}")?;
454                if i < constraints.len() - 1 {
455                    write!(f, ", ")?;
456                }
457            }
458            write!(f, ">(")?;
459        }
460        for (i, a) in self.args.iter().enumerate() {
461            match &a.label {
462                Some((l, true)) => write!(f, "?#{l}: ")?,
463                Some((l, false)) => write!(f, "#{l}: ")?,
464                None => (),
465            }
466            write!(f, "{}", a.typ)?;
467            if i < self.args.len() - 1 || self.vargs.is_some() {
468                write!(f, ", ")?;
469            }
470        }
471        if let Some(vargs) = &self.vargs {
472            write!(f, "@args: {}", vargs)?;
473        }
474        write!(f, ") -> {}", self.rtype)
475    }
476}