// formality/term.rs

// Warning: needs polishing to avoid countless unnecessary clones.
2
3use std;
4use std::collections::*;
5
/// A term of the calculus. Variables are de Bruijn indices: `Var { idx: 0 }` refers to the
/// innermost enclosing binder. The `nam` of `All`/`Lam`/`Cpy` is ignored by `equals`.
///
/// Each variant documents how many binders its sub-terms sit under; `uses`, `shift` and
/// `subs` all follow that layout. Variant and field order are significant: they determine
/// the derived `Ord` and `Hash`.
#[derive(Clone, PartialEq, PartialOrd, Eq, Ord, Debug, Hash)]
pub enum Term {
    /// Forall (dependent function type); `bod` is under one binder of type `typ`.
    All { nam: Vec<u8>, typ: Box<Term>, bod: Box<Term> },
    /// Lambda; `bod` is under one binder of type `typ`.
    Lam { nam: Vec<u8>, typ: Box<Term>, bod: Box<Term> },
    /// Variable, as a de Bruijn index.
    Var { idx: i32 },
    /// Application of `fun` to `arg`.
    App { fun: Box<Term>, arg: Box<Term> },
    /// Inductive data type. `arg` holds the values given for the parameters `par`; `typ` is
    /// under `par.len()` binders and each constructor type in `ctr` under `par.len() + 1`,
    /// the innermost of which stands for the datatype itself.
    Idt {
        nam: Vec<u8>,
        arg: Vec<Box<Term>>,
        par: Vec<(Vec<u8>, Box<Term>)>,
        typ: Box<Term>,
        ctr: Vec<(Vec<u8>, Box<Term>)>,
    },
    /// Instantiation of `idt`; `bod` is under one binder per constructor name in `ctr` and
    /// picks a constructor by applying one of those binders to its fields.
    New { idt: Box<Term>, ctr: Vec<Vec<u8>>, bod: Box<Term> },
    /// Pattern-match on `val`. Each case `(name, fields, body)` has `body` under
    /// `1 + fields.len()` binders, the outermost being the recursive fold; the return type
    /// `ret.1` is under `1 + ret.0.len()` binders.
    Cas { val: Box<Term>, cas: Vec<(Vec<u8>, Vars, Box<Term>)>, ret: (Vars, Box<Term>) },
    /// Reference to a definition in `Defs`, by name.
    Ref { nam: Vec<u8> },
    /// Copy: `bod` is under two binders, both standing for `val`.
    Cpy { nam: (Vec<u8>, Vec<u8>), val: Box<Term>, bod: Box<Term> },
    /// The type of types.
    Set,
}
use self::Term::*;

/// A list of binder names.
pub type Vars = Vec<Vec<u8>>;
/// Top-level definitions, keyed by name.
pub type Defs = HashMap<Vec<u8>, Term>;
66
67#[derive(Clone, PartialEq, PartialOrd, Eq, Ord, Debug, Hash)]
68pub enum TypeError {
69    AppTypeMismatch {
70        expect: Term,
71        actual: Term,
72        argval: Term,
73        term: Term,
74        vars: Vars
75    },
76    AppNotAll {
77        funval: Term,
78        funtyp: Term,
79        term: Term,
80        vars: Vars
81    },
82    ForallNotAType {
83        typtyp: Term,
84        bodtyp: Term,
85        term: Term,
86        vars: Vars
87    },
88    Unbound {
89        name: Vec<u8>,
90        vars: Vars
91    },
92    NewTypeMismatch {
93        expect: Term,
94        actual: Term,
95        term: Term,
96        vars: Vars
97    },
98    MatchNotIDT {
99        actual: Term,
100        term: Term,
101        vars: Vars
102    },
103    WrongMatchIndexCount {
104        expect: usize,
105        actual: usize,
106        term: Term,
107        vars: Vars
108    },
109    WrongMatchReturnArity {
110        expect: usize,
111        actual: usize,
112        term: Term,
113        vars: Vars
114    },
115    WrongMatchCaseCount {
116        expect: usize,
117        actual: usize,
118        term: Term,
119        vars: Vars
120    },
121    WrongCaseName {
122        expect: Vec<u8>,
123        actual: Vec<u8>,
124        term: Term,
125        vars: Vars
126    },
127    WrongCaseArity {
128        expect: usize,
129        actual: usize,
130        name: Vec<u8>,
131        term: Term,
132        vars: Vars
133    },
134    WrongCaseType {
135        expect: Term,
136        actual: Term,
137        name: Vec<u8>,
138        term: Term,
139        vars: Vars
140    },
141}
142use self::TypeError::{*};
143
144// Adds an unique name to a Vars vector, properly renaming if shadowed.
145pub fn rename(nam : &Vec<u8>, vars : &Vars) -> Vec<u8> {
146    let mut new_nam = nam.clone();
147    if new_nam.len() > 0 {
148        for var_nam in vars.iter() {
149            if var_nam == nam {
150                new_nam.extend_from_slice(b"'");
151            }
152        }
153    }
154    new_nam
155}
156
157// Returns how many time the variable at depth `dpt` is used.
158pub fn uses(term : &Term, dpt : i32) -> u32 {
159    match term {
160        &App{ref fun, ref arg} => {
161            uses(fun, dpt) +
162            uses(arg, dpt)
163        },
164        &Lam{nam: _, ref typ, ref bod} => {
165            uses(typ, dpt) +
166            uses(bod, dpt + 1)
167        },
168        &All{nam: _, ref typ, ref bod} => {
169            uses(typ, dpt) +
170            uses(bod, dpt + 1)
171        },
172        &Var{idx} => {
173            if idx == dpt { 1 } else { 0 }
174        },
175        &Ref{nam: _} => 0,
176        &Idt{nam: _, ref arg, ref par, ref typ, ref ctr} => {
177            arg.iter().fold(0, |res, val| res + uses(&val, dpt)) +
178            par.iter().fold(0, |res, val| res + uses(&val.1, dpt)) +
179            uses(typ, dpt + par.len() as i32) +
180            ctr.iter().fold(0, |res, val| res + uses(&val.1, dpt + par.len() as i32 + 1))
181        },
182        &New{ref idt, ref ctr, ref bod} => {
183            uses(idt, dpt) +
184            uses(bod, dpt + ctr.len() as i32)
185        },
186        &Cas{ref val, ref ret, ref cas} => {
187            uses(val, dpt) +
188            uses(&ret.1, dpt + 1 + ret.0.len() as i32) +
189            cas.iter().fold(0, |res, val| res + uses(&val.2, dpt + 1 + val.1.len() as i32))
190        },
191        &Cpy{nam: _, ref val, ref bod} => {
192            uses(val, dpt) +
193            uses(bod, dpt + 2)
194        },
195        &Set => 0
196    }
197}
198
199// Increases the index of all free variables of a term, assuming `cut` enclosing lambdas, by `inc`.
200pub fn shift(term : &mut Term, inc : i32, cut : i32) {
201    match term {
202        &mut App{ref mut fun, ref mut arg} => {
203            shift(fun, inc, cut);
204            shift(arg, inc, cut);
205        },
206        &mut Lam{nam: _, ref mut typ, ref mut bod} => {
207            shift(typ, inc, cut);
208            shift(bod, inc, cut + 1);
209        },
210        &mut All{nam: _, ref mut typ, ref mut bod} => {
211            shift(typ, inc, cut);
212            shift(bod, inc, cut + 1);
213        },
214        &mut Var{ref mut idx} => {
215            *idx = if *idx < cut { *idx } else { *idx + inc };
216        },
217        &mut Ref{nam: _} => {},
218        &mut Idt{nam: _, ref mut arg, ref mut par, ref mut typ, ref mut ctr} => {
219            for arg_val in arg {
220                shift(arg_val, inc, cut);
221            }
222            for i in 0..par.len() { 
223                shift(&mut par[i].1, inc, cut);
224            }
225            shift(typ, inc, cut + par.len() as i32);
226            for (_,ctr_typ) in ctr {
227                shift(ctr_typ, inc, cut + par.len() as i32 + 1);
228            }
229        },
230        &mut New{ref mut idt, ref mut ctr, ref mut bod} => {
231            shift(idt, inc, cut);
232            shift(bod, inc, cut + ctr.len() as i32);
233        },
234        &mut Cas{ref mut val, ref mut ret, ref mut cas} => {
235            shift(val, inc, cut);
236            shift(&mut ret.1, inc, cut + 1 + ret.0.len() as i32);
237            for (_, cas_arg, cas_bod) in cas {
238                shift(cas_bod, inc, cut + 1 + cas_arg.len() as i32);
239            }
240        },
241        &mut Cpy{nam: _, ref mut val, ref mut bod} => {
242            shift(val, inc, cut);
243            shift(bod, inc, cut + 2);
244        },
245        &mut Set => {}
246    }
247}
248
249// Immutable shift.
250pub fn shifted(term : &Term, inc : i32, cut : i32) -> Term {
251    let mut term_copy = term.clone();
252    shift(&mut term_copy, inc, cut);
253    term_copy
254}
255
256// Substitutes the variable at given depth in term by value.
257pub fn subs(term : &mut Term, value : &Term, dpt : i32) {
258    let mut new_term : Option<Term> = None;
259    match term {
260        &mut App{ref mut fun, ref mut arg} => {
261            subs(fun, value, dpt);
262            subs(arg, value, dpt);
263        },
264        &mut Lam{nam: ref mut _nam, ref mut typ, ref mut bod} => {
265            subs(typ, value, dpt);
266            subs(bod, value, dpt + 1);
267        },
268        &mut All{nam: ref mut _nam, ref mut typ, ref mut bod} => {
269            subs(typ, value, dpt);
270            subs(bod, value, dpt + 1);
271        },
272        &mut Var{idx} => {
273            if dpt == idx {
274                let mut val = value.clone();
275                shift(&mut val, dpt as i32, 0);
276                new_term = Some(val);
277            } else if dpt < idx {
278                new_term = Some(Var{idx: idx - 1})
279            }
280        },
281        &mut Ref{nam: _} => {},
282        &mut Idt{nam: ref mut _nam, ref mut arg, ref mut par, ref mut typ, ref mut ctr} => {
283            for arg_val in arg {
284                subs(arg_val, value, dpt);
285            }
286            for i in 0..par.len() {
287                subs(&mut par[i].1, value, dpt);
288            }
289            subs(typ, value, dpt + par.len() as i32);
290            for (_,ctr_typ) in ctr {
291                subs(ctr_typ, value, dpt + par.len() as i32 + 1);
292            }
293        },
294        &mut New{ref mut idt, ref mut ctr, ref mut bod} => {
295            subs(idt, value, dpt);
296            subs(bod, value, dpt + ctr.len() as i32);
297        },
298        &mut Cas{ref mut val, ref mut ret, ref mut cas} => {
299            subs(val, value, dpt);
300            subs(&mut ret.1, value, dpt + 1 + ret.0.len() as i32);
301            for (_, cas_arg, cas_bod) in cas {
302                subs(cas_bod, value, dpt + 1 + cas_arg.len() as i32);
303            }
304        },
305        &mut Cpy{nam: _, ref mut val, ref mut bod} => {
306            subs(val, value, dpt);
307            subs(bod, value, dpt + 2);
308        },
309        _ => {}
310    };
311    // Because couldn't modify Var inside its own case
312    match new_term {
313        Some(new_term) => *term = new_term,
314        None => {}
315    };
316}
317
318// Extracts the function and a list of arguments from a curried f(x, y, z) expression.
319pub fn get_fun_args(term : &Term) -> (&Term, Vec<&Term>) {
320    let mut term : &Term = term;
321    let mut args : Vec<&Term> = Vec::new();
322    loop {
323        match term {
324            App{ref fun, ref arg} => {
325                args.push(arg);
326                term = fun;
327            },
328            _ => break
329        }
330    }
331    args.reverse();
332    (term, args)
333}
334
335// Extracts the names, types and body from a curried `(x : A) => (y : B) => c` expression.
336pub fn get_nams_typs_bod(term : &Term) -> (Vec<&Vec<u8>>, Vec<&Term>, &Term) {
337    let mut term : &Term = term;
338    let mut nams : Vec<&Vec<u8>> = Vec::new();
339    let mut typs : Vec<&Term> = Vec::new();
340    loop {
341        match term {
342            Lam{ref nam, ref typ, ref bod} => {
343                nams.push(nam);
344                typs.push(typ);
345                term = bod;
346            },
347            All{ref nam, ref typ, ref bod} => {
348                nams.push(nam);
349                typs.push(typ);
350                term = bod;
351            },
352            _ => break
353        }
354    }
355    (nams, typs, term)
356}
357
358// Reduces an expression if it is a redex, returns true if was.
359pub fn redex(term : &mut Term, defs : &Defs, deref : bool) -> bool {
360    let mut changed = false;
361    let tmp_term = std::mem::replace(term, Set);
362    let new_term : Term = match tmp_term {
363        App{mut fun, mut arg} => {
364            let tmp_fun : Term = *fun;
365            match tmp_fun {
366                Lam{nam: _, typ: _, mut bod} => {
367                    subs(&mut bod, &arg, 0);
368                    changed = true;
369                    *bod
370                },
371                t => {
372                    App{fun: Box::new(t), arg}
373                }
374            }
375        },
376        Cas{mut val, mut ret, mut cas} => {
377            let tmp_val = *val;
378            match tmp_val {
379                New{idt, ctr, bod} => {
380                    let (ctr_choice, args) = get_fun_args(&bod);
381                    match ctr_choice {
382                        Var{idx} => {
383                            changed = true;
384                            // Creates the folding function
385                            let mut new_ret = ret.clone();
386                            let mut new_cas = cas.clone();
387                            shift(&mut new_ret.1, 1, 1 + new_ret.0.len() as i32);
388                            for (_, ref mut new_cas_arg, ref mut new_cas_bod) in &mut new_cas {
389                                shift(new_cas_bod, 1, 1 + new_cas_arg.len() as i32);
390                            }
391                            let mut fold_fun = Lam{
392                                nam: b"X".to_vec(),
393                                typ: Box::new(Set),
394                                bod: Box::new(Cas{
395                                    val: Box::new(Var{idx: 0}),
396                                    ret: new_ret,
397                                    cas: new_cas
398                                })
399                            };
400                            // Finds matching constructor and substitutes
401                            let mut bod : Term = Set;
402                            for i in 0..cas.len() {
403                                let cas_nam = &cas[i].0;
404                                let cas_bod = &cas[i].2;
405                                if *cas_nam == ctr[cas.len() - *idx as usize - 1] {
406                                    bod = *cas_bod.clone();
407                                }
408                            }
409                            subs(&mut bod, &fold_fun, args.len() as i32);
410                            for i in 0..args.len() {
411                                let mut new_arg = args[i].clone();
412                                shift(&mut new_arg, ctr.len() as i32 * -1, 0);
413                                subs(&mut bod, &new_arg, (args.len() - i - 1) as i32);
414                            }
415
416                            bod
417                        },
418                        _ => {
419                            let idt = idt.clone();
420                            let ctr = ctr.clone();
421                            let bod = bod.clone();
422                            let val = Box::new(New{idt, ctr, bod});
423                            Cas{val, ret, cas}
424                        }
425                    }
426                },
427                _ => {
428                    Cas{val: Box::new(tmp_val.clone()), ret, cas}
429                }
430            }
431        },
432        Cpy{nam: _, mut val, mut bod} => {
433            let mut bod = bod.clone();
434            subs(&mut bod, &val, 1);
435            subs(&mut bod, &val, 0);
436            changed = true;
437            *bod
438        },
439        Ref{nam} => {
440            if deref {
441                match defs.get(&nam) {
442                    Some(val) => {
443                        changed = true;
444                        val.clone()
445                    },
446                    None => {
447                        panic!(format!("Unbound variable: {}.", String::from_utf8_lossy(&nam)))
448                    }
449                }
450            } else {
451                Ref{nam}
452            }
453        },
454        t => t
455    };
456    std::mem::replace(term, new_term);
457    changed
458}
459
460// Performs a global parallel reduction step.
461pub fn global_reduce_step(term : &mut Term, defs : &Defs, deref : bool) -> bool {
462    let changed_below = match term {
463        App{ref mut fun, ref mut arg} => {
464            let fun = global_reduce_step(fun, defs, deref);
465            let arg = global_reduce_step(arg, defs, deref);
466            fun || arg
467        },
468        Lam{nam: _, ref mut typ, ref mut bod} => {
469            let typ = global_reduce_step(typ, defs, deref);
470            let bod = global_reduce_step(bod, defs, deref);
471            typ || bod
472        },
473        All{nam: _, ref mut typ, ref mut bod} => {
474            let typ = global_reduce_step(typ, defs, deref);
475            let bod = global_reduce_step(bod, defs, deref);
476            typ || bod
477        },
478        Idt{nam: _, ref mut arg, ref mut par, ref mut typ, ref mut ctr} => {
479            let mut changed_arg = false;
480            for i in 0..arg.len() {
481                changed_arg = changed_arg || global_reduce_step(&mut arg[i], defs, deref);
482            }
483            let mut changed_par = false;
484            for i in 0..par.len() {
485                changed_par = changed_par || global_reduce_step(&mut par[i].1, defs, deref);
486            }
487            let mut changed_ctr = false;
488            for i in 0..ctr.len() {
489                changed_ctr = changed_ctr || global_reduce_step(&mut ctr[i].1, defs, deref);
490            }
491            let typ = global_reduce_step(typ, defs, deref);
492            changed_arg || changed_par || changed_ctr || typ
493        },
494        New{ref mut idt, ctr: _, ref mut bod} => {
495            let idt = global_reduce_step(idt, defs, deref);
496            let bod = global_reduce_step(bod, defs, deref);
497            idt || bod
498        },
499        Cas{ref mut val, ref mut ret, ref mut cas} => {
500            let mut changed_cas = false;
501            for i in 0..cas.len() {
502                let cas_ret = &mut cas[i].2;
503                changed_cas = changed_cas || global_reduce_step(cas_ret, defs, deref);
504            }
505            let val = global_reduce_step(val, defs, deref);
506            let ret = global_reduce_step(&mut ret.1, defs, deref);
507            changed_cas || val || ret
508        },
509        Cpy{nam: _, ref mut val, ref mut bod} => {
510            global_reduce_step(val, defs, deref) ||
511            global_reduce_step(bod, defs, deref)
512        },
513        _ => false
514    };
515    let changed_self = redex(term, defs, deref);
516    changed_below || changed_self
517}
518
519// Performs a global parallel weak reduction step.
520pub fn weak_global_reduce_step(term : &mut Term, defs : &Defs, deref : bool) -> bool {
521    let changed_below = match term {
522        App{ref mut fun, arg: _} => {
523            weak_global_reduce_step(fun, defs, deref)
524        },
525        Cas{ref mut val, ret: _, cas: _} => {
526            weak_global_reduce_step(val, defs, deref)
527        },
528        _ => false
529    };
530    let changed_self = redex(term, defs, deref);
531    changed_below || changed_self
532}
533
534// Reduces a term to weak head normal form.
535pub fn weak_reduce(term : &mut Term, defs : &Defs, deref : bool) -> bool {
536    let mut changed = false;
537    while weak_global_reduce_step(term, defs, deref) {
538        changed = true;
539    }
540    changed
541}
542
543// Immutable weak_reduce.
544pub fn weak_reduced(term : &Term, defs : &Defs, deref : bool) -> Term {
545    let mut term_copy = term.clone();
546    weak_reduce(&mut term_copy, defs, deref);
547    term_copy
548}
549
550// Reduces a term to normal form.
551pub fn reduce(term : &mut Term, defs : &Defs, deref : bool) -> bool {
552    let mut changed = false;
553    loop {
554        // Reduces as much as possible without ref expansions
555        while global_reduce_step(term, defs, false) {
556            changed = true;
557        }
558        // Reduces once with ref expansion
559        if deref && global_reduce_step(term, defs, true) {
560            changed = true;
561        // If nothing changed, halt
562        } else {
563            break;
564        }
565    }
566    changed
567}
568
569// Immutable reduce.
570pub fn reduced(term : &Term, defs : &Defs, deref : bool) -> Term {
571    let mut term_copy = term.clone();
572    reduce(&mut term_copy, defs, deref); 
573    term_copy
574}
575
576// Performs an equality test.
577pub fn equals(a : &Term, b : &Term) -> bool {
578    // Check if the heads are equal.
579    match (a, b) {
580        (&App{fun: ref a_fun, arg: ref a_arg},
581         &App{fun: ref b_fun, arg: ref b_arg}) => {
582            equals(a_fun, b_fun) &&
583            equals(a_arg, b_arg)
584        },
585        (&Lam{nam: _, typ: ref a_typ, bod: ref a_bod},
586         &Lam{nam: _, typ: ref b_typ, bod: ref b_bod}) => {
587            equals(a_typ, b_typ) &&
588            equals(a_bod, b_bod)
589        },
590        (&All{nam: _, typ: ref a_typ, bod: ref a_bod},
591         &All{nam: _, typ: ref b_typ, bod: ref b_bod}) => {
592            equals(a_typ, b_typ) &&
593            equals(a_bod, b_bod)
594        },
595        (&Var{idx: ref a_idx},
596         &Var{idx: ref b_idx}) => {
597            a_idx == b_idx
598        },
599        (&Ref{nam: ref a_nam},
600         &Ref{nam: ref b_nam}) => {
601            a_nam == b_nam
602         },
603        (&Idt{nam: _, arg: ref a_arg, par: ref a_par, typ: ref a_typ, ctr: ref a_ctr},
604         &Idt{nam: _, arg: ref b_arg, par: ref b_par, typ: ref b_typ, ctr: ref b_ctr}) => {
605            let mut eql_arg = true;
606            if a_arg.len() != b_arg.len() {
607                return false;
608            }
609            for i in 0..a_arg.len() {
610                let a_arg_val = a_arg[i].clone();
611                let b_arg_val = b_arg[i].clone();
612                eql_arg = eql_arg && equals(&a_arg_val, &b_arg_val);
613            }
614            let mut eql_par = true;
615            if a_par.len() != b_par.len() {
616                return false;
617            }
618            for i in 0..a_par.len() {
619                let (a_par_nam, a_par_typ) = a_par[i].clone();
620                let (b_par_nam, b_par_typ) = b_par[i].clone();
621                eql_par = eql_par && a_par_nam == b_par_nam && equals(&a_par_typ, &b_par_typ);
622            }
623            let mut eql_ctr = true;
624            if a_ctr.len() != b_ctr.len() {
625                return false;
626            }
627            for i in 0..a_ctr.len() {
628                let (a_ctr_nam, a_ctr_typ) = a_ctr[i].clone();
629                let (b_ctr_nam, b_ctr_typ) = b_ctr[i].clone();
630                eql_ctr = eql_ctr && a_ctr_nam == b_ctr_nam && equals(&a_ctr_typ, &b_ctr_typ);
631            }
632            eql_arg && eql_par && equals(a_typ, b_typ) && eql_ctr
633        },
634        (&New{idt: ref a_idt, ctr: _, bod: ref a_bod},
635         &New{idt: ref b_idt, ctr: _, bod: ref b_bod}) => {
636            equals(a_idt, b_idt) && equals(a_bod, b_bod)
637        },
638        (&Cas{val: ref a_val, ret: ref a_ret, cas: ref a_cas},
639         &Cas{val: ref b_val, ret: ref b_ret, cas: ref b_cas}) => {
640            let mut eql_cas = true;
641            for i in 0..a_cas.len() {
642                let (_, _, ref a_cas_bod) = a_cas[i];
643                let (_, _, ref b_cas_bod) = b_cas[i];
644                eql_cas = eql_cas && equals(&a_cas_bod, &b_cas_bod);
645            }
646            eql_cas &&
647            equals(a_val, b_val) &&
648            equals(&a_ret.1, &b_ret.1)
649        },
650        (&Cpy{nam: _, val: ref a_val, bod: ref a_bod},
651         &Cpy{nam: _, val: ref b_val, bod: ref b_bod}) => {
652            equals(a_val, b_val) &&
653            equals(a_bod, b_bod)
654        },
655        (Set, Set) => true,
656        _ => false
657    }
658}
659
660// Performs an equality test after normalization.
661pub fn equals_reduced(a : &Term, b : &Term, defs : &Defs) -> bool {
662    let mut a_nf = a.clone();
663    let mut b_nf = b.clone();
664    reduce(&mut a_nf, defs, true);
665    reduce(&mut b_nf, defs, true);
666    equals(&a_nf, &b_nf)
667}
668
669// A Context is a vector of type assignments.
670pub type Context<'a> = Vec<Term>;
671
672// Extends a context.
673pub fn extend_context<'a>(val : &Term, ctx : &'a mut Context<'a>) -> &'a mut Context<'a> {
674    for i in 0..ctx.len() {
675        shift(&mut ctx[i], 1, 0);
676    }
677    ctx.push(val.clone());
678    ctx
679}
680
681// Narrows a context.
682pub fn narrow_context<'a>(ctx : &'a mut Context<'a>) -> &'a mut Context<'a> {
683    ctx.pop();
684    for i in 0..ctx.len() {
685        shift(&mut ctx[i], -1, 0);
686    }
687    ctx
688}
689
690// Returns the type of an IDT and its constructors, with parameters substituted.
691pub fn apply_idt_args(idt : &Term) -> (Term, Vec<(Vec<u8>, Term)>) {
692    match idt {
693        Idt{nam:_, ref arg, par: _, ref typ, ref ctr} => {
694            let mut typ = *typ.clone();
695            for j in 0..arg.len() {
696                subs(&mut typ, &arg[j], (arg.len() - j - 1) as i32);
697            }
698            let mut ctr_typs = Vec::new();
699            for i in 0..ctr.len() {
700                let ctr_nam = ctr[i].0.clone();
701                let mut ctr_typ = *ctr[i].1.clone();
702                for j in 0..arg.len() {
703                    subs(&mut ctr_typ, &arg[j], (arg.len() + 1 - j - 1) as i32);
704                }
705                subs(&mut ctr_typ, &idt, 0);
706                ctr_typs.push((ctr_nam, ctr_typ));
707            }
708            (typ, ctr_typs)
709        },
710        _ => (Set, Vec::new())
711    }
712}
713
714// Infers the type.
715pub fn do_infer<'a>(term : &Term, vars : &mut Vars, defs : &Defs, ctx : &mut Context, checked : bool) -> Result<Term, TypeError> {
716    match term {
717        App{fun, arg} => {
718            let fun_t = weak_reduced(&do_infer(fun, vars, defs, ctx, checked)?, defs, true);
719            match fun_t {
720                All{nam: _f_nam, typ: f_typ, bod: f_bod} => {
721                    let mut arg_n = arg.clone();
722                    if !checked {
723                        let arg_t = do_infer(arg, vars, defs, ctx, checked)?;
724                        if !equals_reduced(&f_typ, &arg_t, defs) {
725                            return Err(AppTypeMismatch{
726                                expect: *f_typ.clone(), 
727                                actual: arg_t.clone(),
728                                argval: *arg.clone(),
729                                term: term.clone(),
730                                vars: vars.clone()
731                            });
732                        }
733                    }
734                    let mut new_bod = f_bod.clone();
735                    subs(&mut new_bod, &arg_n, 0);
736                    Ok(*new_bod)
737                },
738                _ => {
739                    Err(AppNotAll{
740                        funval: *fun.clone(),
741                        funtyp: fun_t.clone(),
742                        term: term.clone(),
743                        vars: vars.clone()
744                    })
745                }
746            }
747        },
748        Lam{nam, typ, bod} => {
749            let nam = rename(&nam, vars);
750            vars.push(nam.to_vec());
751            extend_context(&shifted(&typ,1,0), ctx);
752            let bod_t = Box::new(do_infer(bod, vars, defs, ctx, checked)?);
753            vars.pop();
754            narrow_context(ctx);
755            if !checked {
756                let nam = nam.clone();
757                let typ = typ.clone();
758                let bod = bod_t.clone();
759                do_infer(&All{nam,typ,bod}, vars, defs, ctx, checked)?;
760            }
761            Ok(All{nam: nam.clone(), typ: typ.clone(), bod: bod_t})
762        },
763        All{nam, typ, bod} => {
764            if !checked {
765                let nam = rename(&nam, vars);
766                let typ_t = Box::new(do_infer(typ, vars, defs, ctx, checked)?);
767                vars.push(nam.to_vec());
768                extend_context(&shifted(&typ,1,0), ctx);
769                let bod_t = Box::new(do_infer(bod, vars, defs, ctx, checked)?);
770                if !equals_reduced(&typ_t, &Set, defs) || !equals_reduced(&bod_t, &Set, defs) {
771                    return Err(ForallNotAType{
772                        typtyp: *typ_t.clone(),
773                        bodtyp: *bod_t.clone(),
774                        term: term.clone(),
775                        vars: vars.clone()
776                    });
777                }
778                vars.pop();
779                narrow_context(ctx);
780            }
781            Ok(Set)
782        },
783        Var{idx} => {
784            Ok(ctx[ctx.len() - (*idx as usize) - 1].clone())
785        },
786        Ref{nam} => {
787            match defs.get(nam) {
788                Some(val) => infer(val, &defs, true),
789                None => Err(Unbound{name: nam.clone(), vars: vars.clone()})
790            }
791        },
792        Idt{nam: _, arg, par: _, typ, ctr: _} => {
793            // TODO: IDT isn't checked
794            let mut typ = typ.clone();
795            for i in 0..arg.len() {
796                subs(&mut typ, &arg[i], (arg.len() - i - 1) as i32);
797            }
798            Ok(*typ)
799        },
800        New{idt, ctr: _, bod} => {
801            let idt = weak_reduced(&idt, defs, true);
802            let (_, idt_ctr) = apply_idt_args(&idt);
803            for i in 0..idt_ctr.len() {
804                vars.push(idt_ctr[i].0.clone());
805                extend_context(&shifted(&idt_ctr[i].1, (i + 1) as i32, 0), ctx); // TODO: (i+1) or ctr.len()?
806            }
807
808            let mut bod_typ = do_infer(bod, vars, defs, ctx, checked)?;
809            shift(&mut bod_typ, idt_ctr.len() as i32 * -1, 0);
810
811            for _ in 0..idt_ctr.len() {
812                vars.pop();
813                narrow_context(ctx);
814            }
815
816            // TODO: check if body has right type
817            Ok(bod_typ)
818        },
819        Cas{val, ret, cas} => {
820            // Gets datatype and applied indices
821            let val_typ = do_infer(val, vars, defs, ctx, checked)?;
822            let mut idt_fxs = get_fun_args(&val_typ);
823            let mut idt = weak_reduced(&idt_fxs.0, defs, true);
824            let mut idx = idt_fxs.1;
825
826            // Gets datatype type and constructors
827            let (typ, ctr) = apply_idt_args(&idt);
828            
829            // Builds the match return type
830            let mut ret_typ : Term = *ret.1.clone();
831            subs(&mut ret_typ, &val, idx.len() as i32);
832            for i in 0..idx.len() {
833                subs(&mut ret_typ, &idx[i], (idx.len() - i - 1) as i32);
834            }
835
836            // Creates the fold type
837            let mut fold_ret_typ : Term = *ret.1.clone();
838            for i in 0..idx.len() {
839                subs(&mut fold_ret_typ, &idx[i], (idx.len() - i - 1) as i32);
840            }
841            let mut fold_typ : Term = All{
842                nam: b"X".to_vec(),
843                typ: Box::new(val_typ.clone()),
844                bod: Box::new(fold_ret_typ.clone())
845            };
846
847            if !checked {
848                let (_, expect_idx_typ, _) = get_nams_typs_bod(&typ); 
849
850                // Checks if number of indices match
851                if idx.len() != expect_idx_typ.len() {
852                    return Err(WrongMatchIndexCount{
853                        expect: expect_idx_typ.len(),
854                        actual: idx.len(),
855                        term: term.clone(),
856                        vars: vars.clone()
857                    });
858                }
859
860                // Check if return type has expected arity
861                if ret.0.len() != idx.len() {
862                    return Err(WrongMatchReturnArity{
863                        expect: idx.len(),
864                        actual: ret.0.len(),
865                        term: term.clone(),
866                        vars: vars.clone()
867                    });
868                }
869
870                // Checks if number of cases matches number of constructors
871                if cas.len() != ctr.len() {
872                    return Err(WrongMatchCaseCount{
873                        expect: ctr.len(),
874                        actual: cas.len(),
875                        term: term.clone(),
876                        vars: vars.clone()
877                    });
878                } 
879
880                // For each case of the pattern-match
881                for i in 0..cas.len() {
882                    // Get its name, variables, body and type
883                    let cas_nam = &cas[i].0;
884                    let cas_arg = &cas[i].1;
885                    let cas_bod = &cas[i].2;
886                    let cas_typ = &ctr[i].1;
887
888                    // Checks if case name matches the constructor's name
889                    if cas_nam != &ctr[i].0 {
890                        return Err(WrongCaseName{
891                            expect: ctr[i].0.clone(),
892                            actual: cas_nam.clone(),
893                            term: term.clone(),
894                            vars: vars.clone()
895                        });
896                    }
897
898                    // Gets argument types and body type
899                    let mut cas_typ = cas_typ.clone();
900                    shift(&mut cas_typ, 1, 0); // because of fold
901                    let (_, cas_arg_typ, cas_bod_typ) = get_nams_typs_bod(&cas_typ);
902
903                    // Gets the datatype indices
904                    let (_, cas_idx) = get_fun_args(cas_bod_typ);
905
906                    // Checks if case field count matches constructor field count
907                    if cas_arg_typ.len() != cas_arg.len() {
908                        return Err(WrongCaseArity{
909                            expect: cas_arg_typ.len(),
910                            actual: cas_arg.len(),
911                            name: cas_nam.clone(),
912                            term: term.clone(),
913                            vars: vars.clone()
914                        });
915                    }
916
917                    // Initializes the witness
918                    let mut wit = Var{idx: (ctr.len() - i - 1) as i32};
919                    let mut idt = idt.clone();
920
921                    // Initializes the expected case return type
922                    let mut expect_cas_ret_typ = ret.1.clone();
923
924                    // Extends the context with the fold type
925                    extend_context(&shifted(&fold_typ,1,0), ctx);
926                    vars.push(b"fold".to_vec());
927                    shift(&mut expect_cas_ret_typ, 1, 1 + ret.0.len() as i32);
928                    shift(&mut idt, 1, 0);
929
930                    // For each field of this case
931                    for j in 0..cas_arg.len() {
932                        // Extends context with the field's type
933                        extend_context(&shifted(&cas_arg_typ[j],1,0), ctx);
934                        vars.push(cas_arg[j].clone());
935                        shift(&mut expect_cas_ret_typ, 1, 1 + ret.0.len() as i32);
936                        shift(&mut idt, 1, 0);
937
938                        // Appends field variable to the witness
939                        shift(&mut wit, 1, ctr.len() as i32);
940                        wit = App{
941                            fun: Box::new(wit),
942                            arg: Box::new(Var{idx: ctr.len() as i32})
943                        };
944                    }
945
946                    // Completes witness
947                    wit = New{
948                        idt: Box::new(idt),
949                        ctr: ctr.iter().map(|c| c.0.clone()).collect(),
950                        bod: Box::new(wit)
951                    };
952
953                    // Applies the witness to the expected case return type
954                    subs(&mut expect_cas_ret_typ, &wit, cas_idx.len() as i32);
955
956                    // Applies each index to the expected case return type
957                    for i in 0..cas_idx.len() {
958                        subs(&mut expect_cas_ret_typ, cas_idx[i], (cas_idx.len() - i - 1) as i32);
959                    }
960
961                    // Infers the actual case return type
962                    let actual_cas_ret_typ = do_infer(cas_bod, vars, defs, ctx, checked)?;
963
964                    // Checks if expected case return type matches actual case return type
965                    if !equals_reduced(&expect_cas_ret_typ, &actual_cas_ret_typ, defs) {
966                        return Err(WrongCaseType{
967                            expect: *expect_cas_ret_typ.clone(),
968                            actual: actual_cas_ret_typ.clone(),
969                            name: cas_nam.clone(),
970                            term: term.clone(),
971                            vars: vars.clone()
972                        });
973                    }
974
975                    // Cleans up fold var
976                    narrow_context(ctx);
977                    vars.pop();
978
979                    // Cleans up constructor vars
980                    for _ in 0..cas_arg.len() {
981                        narrow_context(ctx);
982                        vars.pop();
983                    }
984                }
985            }
986
987            Ok(ret_typ)
988        },
989        Cpy{nam, val, bod} => {
990            let nam_0 = rename(&nam.0, vars);
991            let nam_1 = rename(&nam.1, vars);
992            let val_typ = do_infer(val, vars, defs, ctx, checked)?;
993
994            extend_context(&shifted(&val_typ, 1, 0), ctx);
995            vars.push(nam_0.clone());
996
997            extend_context(&shifted(&val_typ, 2, 0), ctx);
998            vars.push(nam_1.clone());
999
1000            let mut bod_typ = do_infer(bod, vars, defs, ctx, checked)?;
1001            shift(&mut bod_typ, -2, 0);
1002
1003            narrow_context(ctx);
1004            vars.pop();
1005
1006            narrow_context(ctx);
1007            vars.pop();
1008
1009            Ok(bod_typ)
1010        },
1011        Set => {
1012            Ok(Set)
1013        },
1014    }
1015}
1016
1017// Convenience
1018pub fn infer(term : &Term, defs : &Defs, checked : bool) -> Result<Term, TypeError> {
1019    do_infer(term, &mut Vec::new(), defs, &mut Vec::new(), checked)
1020}