lambdas/
types.rs

1use core::panic;
2use std::{collections::VecDeque};
3use crate::parse_type;
4use crate::expr::{Expr,Lambda};
5use crate::dsl::Domain;
6use egg::{Symbol,Id};
7
8
9
/// Reasons a unification attempt can fail.
#[derive(Debug, Clone)]
#[derive(PartialEq, Eq)]
pub enum UnifyErr {
    /// Binding a variable would create a recursive type, e.g. `t0 = t0 -> int`.
    Occurs,
    /// Two fully concrete types were compared and are not equal.
    ConcreteSubtree,
    /// Two terms have different constructors or a different number of arguments.
    Production,
}

/// Outcome of a unification: `Ok(())` on success, otherwise why it failed.
pub type UnifyResult = Result<(), UnifyErr>;
17
18#[derive(Debug, Clone, PartialEq, Eq, Hash)]
19pub enum Type {
20    Var(usize), // type variable like t0 t1 etc
21    Term(Symbol, Vec<Type>), // symbol is the name like "int" or "list" or "->" and Vec<Type> is the args which is empty list for things like int etc
22    // Arrow(Box<Type>,Box<Type>)
23}
24
25
26/// int
27/// [Term("int",None)]
28/// 
29/// (list int)
30/// [Term("int",None),Term("list",0)]
31/// 
32/// (-> int int)
33/// [Term("int",None), Term("int",None), ArgCons(0,1), Term("->", Some(2))]
34
35
36#[derive(Debug, Clone, PartialEq, Eq, Hash)]
37pub enum TNode {
38    Var(usize), // type variable like t0 t1 etc
39    Term(Symbol, Option<RawTypeRef>), // symbol is the name like "int" or "list" or "->" and Option<usize> is the index of an ArgCons
40    ArgCons(RawTypeRef,Option<RawTypeRef>),
41}
42
/// A [`RawTypeRef`] viewed through a variable shift: a node variable `ti` stands for
/// `t(i + shift)`. This lets one stored type be used with fresh variables without copying nodes.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq, Hash)]
pub struct TypeRef {
    pub raw: RawTypeRef,
    pub shift: usize,
}

/// Index of a node in a [`TypeSet`]'s node arena.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq, Hash)]
pub struct RawTypeRef(usize);
51
52
53impl RawTypeRef {
54    pub fn shift(&self, shift: usize) -> TypeRef {
55        TypeRef::new(*self,shift)
56    }
57
58    pub fn resolve<'a>(&self, typeset: &'a TypeSet) -> &'a TNode {
59        &typeset.nodes[self.0]
60    }
61
62    /// convenience method for converting to types. probably super slow but useful for debugging
63    #[inline(never)]
64    pub fn tp(&self, typeset: &TypeSet) -> Type {
65        match self.resolve(typeset) {
66            TNode::Var(i) => Type::Var(*i),
67            TNode::Term(p, _) => {
68                Type::Term(*p, self.iter_term_args(typeset).map(|arg| arg.tp(typeset)).collect())
69            },
70            TNode::ArgCons(_, _) => unreachable!()
71        }
72    }
73
74    pub fn show(&self, typeset: &TypeSet) -> String {
75        self.tp(typeset).to_string()
76    }
77
78    pub fn iter_term_args<'a>(&self, typeset: &'a TypeSet) -> TermArgIter<'a> {
79        if let TNode::Term(_,args) = self.resolve(typeset) {
80            TermArgIter { typeset, curr_idx: args}
81        } else {
82            panic!("cant iterate, not a term")
83        }
84    }
85
86    pub fn as_arrow(&self, typeset: &TypeSet) -> Option<(RawTypeRef, RawTypeRef)> {
87        if let TNode::Term(name,_) = self.resolve(typeset) {
88            if *name != *ARROW_SYM {
89                return None
90            }
91            let mut it = self.iter_term_args(typeset);
92            let left = it.next().unwrap();
93            let right = it.next().unwrap();
94            assert!(it.next().is_none(), "malformed arrow");
95            Some((left,right))
96        } else {
97             None
98        }
99    }
100
101    pub fn is_arrow(&self, typeset: &TypeSet) -> bool {
102        if let TNode::Term(name,_) = self.resolve(typeset) {
103            return *name == *ARROW_SYM
104        }
105        false
106    }
107
108    /// iterates over all nodes in the term of this type
109    pub fn iter_arrows<'a>(&self, typeset: &'a TypeSet) -> ArrowIterTypeRef<'a> {
110        ArrowIterTypeRef { curr: *self, typeset }
111    }
112
113    /// iterates over uncurried argument types of this arrow type
114    pub fn iter_args<'a>(&self, typeset: &'a TypeSet) -> impl Iterator<Item=RawTypeRef> + 'a {
115        self.iter_arrows(typeset).map(|(left,_right)| left)
116    }
117
118    /// arity of this arrow type (zero if not an arrow type)
119    pub fn arity(&self, typeset: &TypeSet) -> usize {
120        self.iter_args(typeset).count()
121    }
122
123    /// return type of this arrow types *after* uncurrying. For a non arrow type
124    /// this just returns the type itself.
125    pub fn return_type(&self, typeset: &TypeSet) -> RawTypeRef {
126        self.iter_arrows(typeset).last().map(|(_left,right)| right).unwrap_or(*self)
127    }
128
129    /// true if there are no type vars in this type
130    pub fn is_concrete(&self, typeset: &TypeSet) -> bool {
131        match self.resolve(typeset) {
132            TNode::Var(_) => false,
133            TNode::Term(_, _) => self.iter_term_args(typeset).all(|ty| ty.is_concrete(typeset)),
134            TNode::ArgCons(_,_) => panic!("is_concrete on an ArgCons")
135        }
136    }
137
138    pub fn max_var(&self, typeset: &TypeSet) -> Option<usize> {
139        match self.resolve(typeset) {
140            TNode::Var(i) => Some(*i),
141            TNode::Term(_, _) => self.iter_term_args(typeset).filter_map(|ty| ty.max_var(typeset)).max(),
142            TNode::ArgCons(_,_) => panic!("is_concrete on an ArgCons")
143        }
144    }
145
146    pub fn instantiate(&self, typeset: &mut TypeSet) -> TypeRef {
147        let shift_by = typeset.next_var;
148        if let Some(max_var) = self.max_var(typeset) {
149            // create a fresh type var for each new variable
150            for _ in 0..=max_var {
151                typeset.fresh_type_var();
152            }
153        }
154        TypeRef::new(*self, shift_by)
155    }
156
157}
158
159
160impl TypeRef {
161    fn new(raw: RawTypeRef, shift: usize) -> TypeRef {
162        TypeRef {raw, shift}
163    }
164
165    /// if `self` is a Var that is bound by our context, return whatever it is bound to 
166    pub fn canonicalize(&self, typeset: &TypeSet) -> TypeRef {
167        if let TNode::Var(i) = self.raw.resolve(typeset) {
168            if let Some(tp_ref) = typeset.get_var(*i + self.shift) {
169                // println!("looked up t{} -> {}", *i + self.shift, tp_ref.show(typeset));
170                return tp_ref.canonicalize(typeset) // recursively resolve the lookup result
171            }
172        }
173        *self
174    }
175
176    /// canonicalizes any toplevel variable away then resolves the resulting raw type ref. Note that
177    /// the TNode returned here will not be shifted
178    pub fn resolve(&self, typeset: &TypeSet) -> TNode {
179        let canonical = self.canonicalize(typeset);
180        let resolved = canonical.raw.resolve(typeset);
181        match resolved {
182            TNode::Var(i) => TNode::Var(i + canonical.shift), // importantly we add canonical.shift here not self.shift
183            _ => resolved.clone()
184        }
185    }
186
187    pub fn tp(&self, typeset: &TypeSet) -> Type {
188        self.raw.tp(typeset)
189    }
190
191    pub fn show(&self, typeset: &TypeSet) -> String {
192        format!("[shift {}] {}", self.shift, self.raw.tp(typeset))
193    }
194
195    pub fn iter_term_args<'a>(&'a self, typeset: &'a TypeSet) -> impl Iterator<Item=TypeRef> + 'a {
196        let canonical = self.canonicalize(typeset);
197        canonical.raw.iter_term_args(typeset).map(move |raw| raw.shift(canonical.shift))
198    }
199
200    pub fn as_arrow(&self, typeset: &TypeSet) -> Option<(TypeRef, TypeRef)> {
201        let canonical = self.canonicalize(typeset);
202        canonical.raw.as_arrow(typeset).map(|(r1,r2)| (r1.shift(canonical.shift),r2.shift(canonical.shift)))
203    }
204
205    pub fn is_arrow(&self, typeset: &TypeSet) -> bool {
206        if let TNode::Term(name,_) = self.resolve(typeset) {
207            return name == *ARROW_SYM
208        }
209        false
210    }
211
212    /// iterates over all nodes in the term of this type
213    pub fn iter_arrows<'a>(&'a self, typeset: &'a TypeSet) -> impl Iterator<Item=(TypeRef,TypeRef)> + 'a {
214        let canonical = self.canonicalize(typeset);
215        canonical.raw.iter_arrows(typeset).map(move |(r1,r2)| (r1.shift(canonical.shift),r2.shift(canonical.shift)))
216    }
217
218    /// iterates over uncurried argument types of this arrow type
219    pub fn iter_args<'a>(&'a self, typeset: &'a TypeSet) -> impl Iterator<Item=TypeRef> + 'a {
220        self.iter_arrows(typeset).map(|(left,_right)| left)
221    }
222
223    /// arity of this arrow type (zero if not an arrow type)
224    pub fn arity(&self, typeset: &TypeSet) -> usize {
225        self.iter_args(typeset).count()
226    }
227
228    /// return type of this arrow types *after* uncurrying. For a non arrow type
229    /// this just returns the type itself.
230    pub fn return_type(&self, typeset: &TypeSet) -> TypeRef {
231        self.iter_arrows(typeset).last().map(|(_left,right)| right).unwrap_or(*self)
232    }
233
234    /// true if there are no type vars in this type
235    pub fn is_concrete(&self, typeset: &TypeSet) -> bool {
236        match self.resolve(typeset) {
237            TNode::Var(_) => false,
238            TNode::Term(_, _) => self.iter_term_args(typeset).all(|ty| ty.is_concrete(typeset)),
239            TNode::ArgCons(_,_) => panic!("is_concrete on an ArgCons")
240        }
241    }
242
243    /// true if type var i occurs in this type (post-shifting of this type)
244    pub fn occurs(&self, i: usize, typeset: &TypeSet) -> bool {
245        // println!("occccc");
246        // todo!() // not sure if need to run substitution here
247        // println!("{:?}", self);
248        // println!("before canonicalizing: {}", self.show(typeset));
249        // println!("canonical: {}", self.canonicalize(typeset).show(typeset));
250        // println!("{:?}", self.resolve(typeset));
251
252        let resolved = self.resolve(typeset);
253
254        // println!("resolved: {:?}", resolved);
255
256        match resolved {
257            TNode::Var(j)  => i == j,
258            TNode::Term(_, _) => {
259                // println!("args: {:?}", self.iter_term_args(typeset).map(|arg|arg.show(typeset)).collect::<Vec<_>>());
260                self.iter_term_args(typeset).any(|ty| ty.occurs(i, typeset))
261            },
262            TNode::ArgCons(_, _) => panic!("occurs() on ArgCons")
263        }
264    }
265
266}
267
268#[derive(Debug, Clone, PartialEq, Eq)]
269pub struct TypeSet {
270    pub nodes: Vec<TNode>,
271    pub subst: Vec<(usize,TypeRef)>,
272    pub next_var: usize,
273}
274
275impl TypeSet {
276    pub fn add_tp(&mut self, tp: &Type) -> RawTypeRef {
277        match tp {
278            Type::Var(i) => {
279                self.nodes.push(TNode::Var(*i));
280                RawTypeRef(self.nodes.len() - 1)
281            }
282            Type::Term(p, args) => {
283                let mut arg_cons = None;
284                for arg in args.iter().rev() {
285                    let arg_hd = self.add_tp(arg);
286                    self.nodes.push(TNode::ArgCons(arg_hd, arg_cons));
287                    arg_cons = Some(RawTypeRef(self.nodes.len() - 1));
288                }
289                self.nodes.push(TNode::Term(*p, arg_cons));
290                RawTypeRef(self.nodes.len() - 1)
291            },
292        }
293    }
294    /// This is the usual way of creating a new Context. The context will be append-only
295    /// meaning you can roll it back to a point by truncating
296    pub fn empty() -> TypeSet {
297        TypeSet {
298            nodes: Default::default(),
299            subst: Default::default(),
300            next_var: 0,
301        }
302    }
303
304    pub fn save_state(&self) -> (usize,usize) {
305        (self.subst.len(), self.next_var)
306    }
307
308    pub fn load_state(&mut self, state: (usize,usize)) {
309        self.subst.truncate(state.0);
310        self.next_var = state.1;
311    }
312
313    fn fresh_type_var(&mut self) -> Type {
314        self.next_var += 1;
315        Type::Var(self.next_var-1)
316    }
317
318    // /// adds new fresh type vars as necessary such that variable Var exists
319    // #[inline(always)]
320    // fn fresh_type_vars(&mut self, var: usize) {
321    //     while var >= self.next_var {
322    //         self.fresh_type_var();
323    //     }
324    // }
325
326    /// a very quick non-allocating check that returns false if it's
327    /// obvious that these types won't unify. This works *even when a type hasnt
328    /// been instantiated() to have new type variables*. First this checks if t1 and t2 have the same constructors
329    /// and if theres an obvious mismatch there it gives up. Then it goes and looks up the types in the ctx
330    /// in case they were typevars, and then again checks if they have th same constructor. It uses apply_immut() to
331    /// avoid mutating the context for this lookup.
332    /// Note the apply_immut version of this was wrong bc thats only safe to do on the hole_tp side and apply_immut
333    /// is already done to the hole before then anyways
334    pub fn might_unify(&self, t1: &RawTypeRef, t2: &TypeRef) -> bool {
335        let node1 = t1.resolve(self);
336        let node2 = t2.resolve(self);
337        match (node1,node2) {
338            (TNode::Var(_), TNode::Var(_)) => true,
339            (TNode::Var(_), TNode::Term(_, _)) => true,
340            (TNode::Term(_, _), TNode::Var(_)) => true,
341            (TNode::Term(x, _), TNode::Term(y, _)) => {
342                *x == y && t1.arity(self) == t2.arity(self) && t1.iter_term_args(self).zip(t2.iter_term_args(self)).all(|(x,y)| self.might_unify(&x,&y))
343            },
344            _ => panic!("attempting to unify ArgCons or some other invalid constructor"),
345        }
346    }
347
348    /// Normal unification. Does not do the amortizing step of the unionfind (but may mutate
349    /// it still). See unify_cached() for amortized unionfind. Note that this is likely not slower
350    /// than unify_cached() in most cases.
351    pub fn unify(&mut self, t1: &TypeRef,  t2: &TypeRef) -> UnifyResult {
352        // println!("\tunify({},{})", t1.show(self), t2.show(self));
353        // println!("\t->({:?},{:?})", t1.resolve(self), t2.resolve(self));
354        // let t1: Type = t1.apply(self);
355        // let t2: Type = t2.apply(self);
356        // println!("\t  ...({},{}) {}", t1, t2, self);
357        // println!("about to resolve");
358
359        match ((t1.resolve(self),t1.canonicalize(self)), (t2.resolve(self),t2.canonicalize(self))) {
360            ((TNode::Var(i), _), (other, tref_other))
361          | ((other, tref_other), (TNode::Var(i),_)) =>
362          {
363                // println!("resolved");
364
365                if other == TNode::Var(i) { return Ok(()) } // unify(t0, t0) -> true
366                // println!("occurs");
367                if tref_other.occurs(i, self) { return Err(UnifyErr::Occurs) } // recursive type  e.g. unify(t0, (t0 -> int)) -> false
368                // println!("occurs done");
369
370                // *** Above is the "occurs" check, which prevents recursive definitions of types. Removing it would allow them.
371
372                assert!(self.get_var(i).is_none());
373                self.set_var(i, tref_other);
374                Ok(())
375            },
376            ((TNode::Term(x, _),tref_x), (TNode::Term(y, _),tref_y)) =>
377            {
378                // println!("resolved");
379                // simply recurse
380                if x != y || tref_x.arity(self) != tref_y.arity(self) {
381                    return Err(UnifyErr::Production)
382                }
383                // todo sad collect() here for borrow checker but might wanna find a way around
384                tref_x.iter_term_args(self).zip(tref_y.iter_term_args(self)).collect::<Vec<_>>().into_iter().try_for_each(|(x,y)| self.unify(&x,&y))
385            }
386            _ => unreachable!()
387        }
388    }
389
390    /// get what a variable is bound to (if anything).
391    #[inline(always)]
392    fn get_var(&self, var: usize) -> Option<&TypeRef> { // todo written in a silly way, rewrite
393        self.subst.iter().rfind(|(i,_)| *i == var).map(|(_,tp)| tp)
394    }
395    /// set what a variable is bound to
396    #[inline(always)]
397    fn set_var(&mut self, var: usize, ty: TypeRef) {
398        self.subst.push((var,ty));
399    }
400}
401
402
403
404
405lazy_static::lazy_static! {
406    static ref ARROW_SYM: egg::Symbol = Symbol::from(Type::ARROW);
407}
408
409impl Type {
410    pub const ARROW: &'static str = "->";
411
412    pub fn base(name: Symbol) -> Type {
413        Type::Term(name, vec![])
414    }
415
416    pub fn arrow(left: Type, right: Type) -> Type {
417        Type::Term(*ARROW_SYM, vec![left, right])
418    }
419
420    pub fn is_arrow(&self) -> bool {
421        match self {
422            Type::Var(_) => false,
423            Type::Term(name, _) => *name == *ARROW_SYM,
424        }
425    }
426
427    pub fn as_arrow(&self) -> Option<(&Type, &Type)> {
428        match self {
429            Type::Term(name,args) => {
430                if *name != *ARROW_SYM {
431                    return None
432                }
433                assert_eq!(args.len(),2);
434                Some((&args[0], &args[1]))
435            },
436            _ => None
437        }
438    }
439
440    /// iterates over all (left_type,right_type) pairs for the chain of arrows
441    /// starting here. Empty iterator if this is not an arrow.
442    // pub fn iter_nodes(&self) -> impl Iterator<Item=&Type> {
443    //     return NodeIter { curr: self }
444    // }
445
446    /// iterates over all nodes in the term of this type
447    pub fn iter_arrows(&self) -> ArrowIter {
448        ArrowIter { curr: self }
449    }
450
451    /// iterates over uncurried argument types of this arrow type
452    pub fn iter_args(&self) -> impl Iterator<Item=&Type> {
453        self.iter_arrows().map(|(left,_right)| left)
454    }
455
456    /// arity of this arrow type (zero if not an arrow type)
457    pub fn arity(&self) -> usize {
458        self.iter_args().count()
459    }
460
461    /// return type of this arrow types *after* uncurrying. For a non arrow type
462    /// this just returns the type itself.
463    pub fn return_type(&self) -> &Type {
464        self.iter_arrows().last().map(|(_left,right)| right).unwrap_or(self)
465    }
466
467    /// true if there are no type vars in this type
468    pub fn is_concrete(&self) -> bool {
469        match self {
470            Type::Var(_) => false,
471            Type::Term(_, args) => args.iter().all(|ty| ty.is_concrete())
472        }
473    }
474
475    /// true if type var i occurs in this type
476    pub fn occurs(&self, i: usize) -> bool {
477        match self {
478            Type::Var(j)  => i == *j,
479            Type::Term(_, args) => args.iter().any(|ty| ty.occurs(i))
480        }
481    }
482
483    pub fn apply_cached(&self, ctx: &mut Context) -> Type {
484        if self.is_concrete() {
485            return self.clone();
486        }
487        match self {
488            Type::Var(i) => {
489                // look up the type var in the ctx to see if its bound
490                if let Some(tp) = ctx.get(*i).cloned() {
491                    // in case it's bound to something that ALSO has variables, we want to track those down too
492                    let tp_applied = tp.apply(ctx);
493                    if tp != tp_applied {
494                        // and to save our work for the future, lets amortize it (union-find style) by saving what we
495                        // found things were bound to. Since bindings will never change this is okay.
496                        ctx.set(*i, tp_applied.clone())
497                    }
498                    tp_applied
499                } else {
500                    self.clone() // t0 is not bound by ctx so we leave it unbound
501                }
502            },
503            Type::Term(name, args) => Type::Term(*name, args.iter().map(|ty| ty.apply_cached(ctx)).collect())
504        }
505    }
506
507    /// same as apply_cached() but doesnt do the unionfind style caching of results, so there's no need to mutate the ctx
508    pub fn apply(&self, ctx: &Context) -> Type {
509        if self.is_concrete() {
510            return self.clone();
511        }
512        match self {
513            Type::Var(i) => {
514                // look up the type var in the ctx to see if its bound
515                if let Some(tp) = ctx.get(*i).cloned() {
516                    // in case it's bound to something that ALSO has variables, we want to track those down too
517                    tp.apply(ctx)
518                } else {
519                    self.clone() // t0 is not bound by ctx so we leave it unbound
520                }
521            },
522            Type::Term(name, args) => Type::Term(*name, args.iter().map(|ty| ty.apply(ctx)).collect())
523        }
524    }
525
526
527    /// shifts all variables in a type such that they are fresh variables in the context, returning a new type
528    pub fn instantiate(&self, ctx: &mut Context) -> Type {
529        if self.is_concrete() {
530            return self.clone()
531        }
532        fn instantiate_aux(ty: &Type, ctx: &mut Context, shift_by: usize) -> Type {
533            match ty {
534                Type::Var(i) => {
535                    let new = i + shift_by;
536                    ctx.fresh_type_vars(new);
537                    assert!(ctx.get(new).is_none());
538                    Type::Var(new)
539                },
540                Type::Term(name, args) => Type::Term(*name, args.iter().map(|t| instantiate_aux(t, ctx, shift_by)).collect()),
541            }
542        }
543        // shift by the highest var that already exists, so that theres no conflict
544        instantiate_aux(self, ctx, ctx.next_var)
545    }
546}
547
548pub struct ArrowIter<'a> {
549    curr: &'a Type
550}
551
552impl<'a> Iterator for ArrowIter<'a> {
553    type Item = (&'a Type, &'a Type);
554
555    fn next(&mut self) -> Option<Self::Item> {
556        if let Some((left,right)) = self.curr.as_arrow() {
557            self.curr = right;
558            Some((left,right))
559        } else {
560            None
561        }
562    }
563}
564
565pub struct ArrowIterTypeRef<'a> {
566    typeset: &'a TypeSet,
567    curr: RawTypeRef,
568}
569
570impl<'a> Iterator for ArrowIterTypeRef<'a> {
571    type Item = (RawTypeRef,RawTypeRef);
572
573    fn next(&mut self) -> Option<Self::Item> {
574        if let Some((left,right)) = self.curr.as_arrow(self.typeset) {
575            self.curr = right;
576            Some((left,right))
577        } else {
578            None
579        }
580    }
581}
582
583pub struct TermArgIter<'a> {
584    typeset: &'a TypeSet,
585    curr_idx: &'a Option<RawTypeRef>,
586}
587
588impl<'a> Iterator for TermArgIter<'a> {
589    type Item = RawTypeRef;
590
591    fn next(&mut self) -> Option<Self::Item> {
592        if let Some(curr_idx) = self.curr_idx {
593            if let TNode::ArgCons(arg,tl) = curr_idx.resolve(self.typeset) {
594                self.curr_idx = tl;
595                Some(*arg)
596            } else {
597                panic!("Cant iterate over something that's not a term")
598            }
599        } else {
600            None
601        }
602    }
603}
604
605
606
607impl std::str::FromStr for Type {
608    type Err = String;
609    fn from_str(s: &str) -> Result<Self, Self::Err> {
610        parse_type::parse(s)
611    }
612}
613
614impl std::fmt::Display for Type {
615    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
616        fn helper(ty: &Type, f: &mut std::fmt::Formatter<'_>, arrow_parens: bool) -> std::fmt::Result {
617            match ty {
618                Type::Var(i) => write!(f,"t{}", i),
619                Type::Term(name, args) => {
620                    if args.is_empty() {
621                        write!(f, "{}", name)
622                    } else if *name == *ARROW_SYM {
623                        assert_eq!(args.len(), 2);
624                        // write!(f, "({} {} {})", &args[0], name, &args[1])
625                        if arrow_parens {
626                            write!(f, "(")?;
627                        }
628                        helper(&args[0], f, true)?;
629                        write!(f, " {} ", Type::ARROW)?;
630                        helper(&args[1], f, false)?;
631                        if arrow_parens {
632                            write!(f, ")")?;
633                        }
634                        Ok(())
635                    } else {
636                        write!(f, "({}", name)?;
637                        for arg in args.iter() {
638                            write!(f, " ")?;
639                            helper(arg, f, true)?;
640                        }
641                        write!(f, ")")
642                    }
643                },
644            }
645        }
646        helper(self, f, true)
647    }
648}
649
650
651#[derive(Debug, Clone, PartialEq, Eq)]
652pub struct Context {
653    subst_unionfind: Vec<Option<Type>>, // todo also try ahashmap tho i just wanted to avoid the allocations
654    subst_append_only: Vec<(usize,Type)>,
655    next_var: usize,
656    append_only: bool,
657}
658
659impl Context {
660
661    /// This is the usual way of creating a new Context. The context will be append-only
662    /// meaning you can roll it back to a point by truncating
663    pub fn empty() -> Context {
664        Context {
665            subst_unionfind: Default::default(),
666            subst_append_only: Default::default(),
667            next_var: 0,
668            append_only: true,
669        }
670    }
671
672    /// instead of an append-only substitution, the context will instead use a unionfind. This is honestly
673    /// likely not noticably faster and doesnt allow rollbacks. It may even be slower.
674    pub fn empty_unionfind() -> Context {
675        Context {
676            subst_unionfind: Default::default(),
677            subst_append_only: Default::default(),
678            next_var: 0,
679            append_only: false,
680        }
681    }
682
683    pub fn save_state(&self) -> (usize,usize) {
684        assert!(self.append_only);
685        (self.subst_append_only.len(), self.next_var)
686    }
687
688    pub fn load_state(&mut self, state: (usize,usize)) {
689        assert!(self.append_only);
690        self.subst_append_only.truncate(state.0);
691        self.next_var = state.1;
692    }
693
694    fn fresh_type_var(&mut self) -> Type {
695        if !self.append_only {
696            self.subst_unionfind.push(None);
697        }
698        self.next_var += 1;
699        Type::Var(self.next_var-1)
700    }
701
702    /// adds new fresh type vars as necessary such that variable Var exists
703    #[inline(always)]
704    fn fresh_type_vars(&mut self, var: usize) {
705        while var >= self.next_var {
706            self.fresh_type_var();
707        }
708    }
709
710    /// a very quick non-allocating check that returns false if it's
711    /// obvious that these types won't unify. This works *even when a type hasnt
712    /// been instantiated() to have new type variables*. First this checks if t1 and t2 have the same constructors
713    /// and if theres an obvious mismatch there it gives up. Then it goes and looks up the types in the ctx
714    /// in case they were typevars, and then again checks if they have th same constructor. It uses apply_immut() to
715    /// avoid mutating the context for this lookup.
716    /// Note the apply_immut version of this was wrong bc thats only safe to do on the hole_tp side and apply_immut
717    /// is already done to the hole before then anyways
718    pub fn might_unify(&self, t1: &Type, t2: &Type) -> bool {
719        match (t1,t2) {
720            (Type::Var(_), Type::Var(_)) => true,
721            (Type::Var(_), Type::Term(_, _)) => true,
722            (Type::Term(_, _), Type::Var(_)) => true,
723            (Type::Term(x, xs), Type::Term(y, ys)) => {
724                x == y && xs.len() == ys.len() && xs.iter().zip(ys.iter()).all(|(x,y)| self.might_unify(x,y))
725            },
726        }
727    }
728
729    /// Normal unification. Does not do the amortizing step of the unionfind (but may mutate
730    /// it still). See unify_cached() for amortized unionfind. Note that this is likely not slower
731    /// than unify_cached() in most cases.
732    pub fn unify(&mut self, t1: &Type,  t2: &Type) -> UnifyResult {
733        // println!("\tunify({},{}) {}", t1, t2, self);
734        let t1: Type = t1.apply(self);
735        let t2: Type = t2.apply(self);
736        // println!("\t  ...({},{}) {}", t1, t2, self);
737        if t1.is_concrete() && t2.is_concrete() {
738            // if both types are concrete, simple equality works because we dont need to do any fancy variable binding
739            if t1 == t2 {
740                return Ok(())
741            } else {
742                return Err(UnifyErr::ConcreteSubtree)
743            }
744        }
745        match (t1, t2) {
746            (Type::Var(i), ty) | (ty, Type::Var(i)) => {
747                if ty == Type::Var(i) { return Ok(()) } // unify(t0, t0) -> true
748                if ty.occurs(i) { return Err(UnifyErr::Occurs) } // recursive type  e.g. unify(t0, (t0 -> int)) -> false
749                // *** Above is the "occurs" check, which prevents recursive definitions of types. Removing it would allow them.
750
751                assert!(self.get(i).is_none());
752                self.set(i, ty);
753                Ok(())
754            },
755            (Type::Term(x, xs), Type::Term(y, ys)) => {
756                // simply recurse
757                if x != y || xs.len() != ys.len() {
758                    return Err(UnifyErr::Production)
759                }
760                xs.iter().zip(ys.iter()).try_for_each(|(x,y)| self.unify(x,y))
761            }
762        }
763    }
764
765    /// [expert mode] like unify() but uses apply_cached() to do amortization step of
766    /// unionfind. Likely not worth using compared to unify().
767    pub fn unify_cached(&mut self, t1: &Type,  t2: &Type) -> UnifyResult {
768        // println!("unify({},{}) {}", t1, t2, self);
769        let t1: Type = t1.apply_cached(self);
770        let t2: Type = t2.apply_cached(self);
771        // println!("  ...({},{}) {}", t1, t2, self);
772        if t1.is_concrete() && t2.is_concrete() {
773            // if both types are concrete, simple equality works because we dont need to do any fancy variable binding
774            if t1 == t2 {
775                return Ok(())
776            } else {
777                return Err(UnifyErr::ConcreteSubtree)
778            }
779        }
780        match (t1, t2) {
781            (Type::Var(i), ty) | (ty, Type::Var(i)) => {
782                if ty == Type::Var(i) { return Ok(()) } // unify(t0, t0) -> true
783                if ty.occurs(i) { return Err(UnifyErr::Occurs) } // recursive type  e.g. unify(t0, (t0 -> int)) -> false
784                // *** Above is the "occurs" check, which prevents recursive definitions of types. Removing it would allow them.
785
786                assert!(self.subst_unionfind.get(i).is_none());
787                self.set(i, ty);
788                Ok(())
789            },
790            (Type::Term(x, xs), Type::Term(y, ys)) => {
791                // simply recurse
792                if x != y || xs.len() != ys.len() {
793                    return Err(UnifyErr::Production)
794                }
795                xs.iter().zip(ys.iter()).try_for_each(|(x,y)| self.unify(x,y))
796            }
797        }
798    }
799
800    /// get what a variable is bound to (if anything).
801    #[inline(always)]
802    fn get(&self, var: usize) -> Option<&Type> { // todo written in a silly way, rewrite
803        if self.append_only {
804            self.subst_append_only.iter().rfind(|(i,_)| *i == var).map(|(_,tp)| tp)
805        } else {
806            self.subst_unionfind[var].as_ref()
807        }
808    }
809    /// set what a variable is bound to
810    #[inline(always)]
811    fn set(&mut self, var: usize, ty: Type) {
812        if self.append_only {
813            self.subst_append_only.push((var,ty));
814        } else {
815            self.subst_unionfind[var] = Some(ty);
816        }
817    }
818
819}
820
821impl std::fmt::Display for Context {
822    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
823        write!(f,"{{")?;
824        let mut first: bool = true;
825        for (i, item) in self.subst_unionfind.iter().enumerate() {
826            if let Some(ty) = item {
827                if !first { write!(f, ", ")? } else { first = false }
828                write!(f, "{}:{}", i, ty)?
829            }
830        }
831        write!(f,"}}")
832    }
833}
834
835
836impl Expr {
837    pub fn infer<D: Domain>(&self, child: Option<Id>, ctx: &mut Context, env: &mut VecDeque<Type>) -> Result<Type,UnifyErr> {
838        // println!("infer({})", self.to_string_uncurried(child));
839        let child = child.unwrap_or_else(||self.root());
840        match &self.nodes[usize::from(child)] {
841            Lambda::App([f,x]) => {
842                let return_tp = ctx.fresh_type_var();
843                let x_tp = self.infer::<D>(Some(*x), ctx, env)?;
844                let f_tp = self.infer::<D>(Some(*f), ctx, env)?;
845                ctx.unify(&f_tp, &Type::arrow(x_tp, return_tp.clone()))?;
846                Ok(return_tp.apply(ctx))
847            },
848            Lambda::Lam([b]) => {
849                let var_tp = ctx.fresh_type_var();
850                // todo maybe optimize by making this a vecdeque for faster insert/remove at the zero index
851                env.push_front(var_tp.clone());
852                let body_tp = self.infer::<D>(Some(*b), ctx, env)?;
853                env.pop_front();
854                Ok(Type::arrow(var_tp, body_tp).apply(ctx))
855            },
856            Lambda::Var(i) => {
857                if (*i as usize) >= env.len() {
858                    panic!("unbound variable encountered during infer(): ${}", i)
859                }
860                Ok(env[*i as usize].apply(ctx))
861            },
862            Lambda::IVar(_i) => {
863                // interesting, I guess we can have this and it'd probably be easy to do
864                unimplemented!();
865            }
866            Lambda::Prim(p) => {
867                Ok(D::type_of_prim(*p).instantiate(ctx))
868            },
869            Lambda::Programs(_) => panic!("trying to infer() type of Programs() node"),
870        }
871    }
872}
873
874