1use crate::{
2 diagnostics::{Diagnostics, WarningType},
3 fun::{builtins, Adts, Constructors, Ctx, Definition, FanKind, Name, Num, Pattern, Rule, Tag, Term},
4 maybe_grow,
5};
6use itertools::Itertools;
7use std::collections::{BTreeSet, HashSet};
8
/// Problems found while compiling a definition's pattern-matching rules into a single term.
///
/// `Ctx::desugar_match_defs` reports `AdtNotExhaustive`, `NumMissingDefault` and
/// `TypeMismatch` as errors, and `RepeatedBind` and `UnreachableRule` as warnings.
pub enum DesugarMatchDefErr {
  /// No rule covers constructor `ctr` of the type `adt`.
  AdtNotExhaustive { adt: Name, ctr: Name },
  /// A column matched on number literals has no rule with a wildcard (default) pattern.
  NumMissingDefault,
  /// Patterns in the same column imply different types; `pat` is the first conflicting one.
  TypeMismatch { expected: Type, found: Type, pat: Pattern },
  /// A variable name is bound more than once among one rule's patterns; the earlier
  /// occurrence was erased.
  RepeatedBind { bind: Name },
  /// Rule number `idx` of definition `nam` (with patterns `pats`) is never selected.
  UnreachableRule { idx: usize, nam: Name, pats: Vec<Pattern> },
}
16
17impl Ctx<'_> {
18 pub fn desugar_match_defs(&mut self) -> Result<(), Diagnostics> {
20 for (def_name, def) in self.book.defs.iter_mut() {
21 let errs = def.desugar_match_def(&self.book.ctrs, &self.book.adts);
22 for err in errs {
23 match err {
24 DesugarMatchDefErr::AdtNotExhaustive { .. }
25 | DesugarMatchDefErr::NumMissingDefault
26 | DesugarMatchDefErr::TypeMismatch { .. } => {
27 self.info.add_function_error(err, def_name.clone(), def.source.clone())
28 }
29 DesugarMatchDefErr::RepeatedBind { .. } => self.info.add_function_warning(
30 err,
31 WarningType::RepeatedBind,
32 def_name.clone(),
33 def.source.clone(),
34 ),
35 DesugarMatchDefErr::UnreachableRule { .. } => self.info.add_function_warning(
36 err,
37 WarningType::UnreachableMatch,
38 def_name.clone(),
39 def.source.clone(),
40 ),
41 }
42 }
43 }
44
45 self.info.fatal(())
46 }
47}
48
49impl Definition {
50 pub fn desugar_match_def(&mut self, ctrs: &Constructors, adts: &Adts) -> Vec<DesugarMatchDefErr> {
51 let mut errs = vec![];
52 for rule in self.rules.iter_mut() {
53 desugar_inner_match_defs(&mut rule.body, ctrs, adts, &mut errs);
54 }
55 let repeated_bind_errs = fix_repeated_binds(&mut self.rules);
56 errs.extend(repeated_bind_errs);
57
58 let args = (0..self.arity()).map(|i| Name::new(format!("%arg{i}"))).collect::<Vec<_>>();
59 let rules = std::mem::take(&mut self.rules);
60 let idx = (0..rules.len()).collect::<Vec<_>>();
61 let mut used = BTreeSet::new();
62 match simplify_rule_match(args.clone(), rules.clone(), idx.clone(), vec![], &mut used, ctrs, adts) {
63 Ok(body) => {
64 let body = Term::rfold_lams(body, args.into_iter().map(Some));
65 self.rules = vec![Rule { pats: vec![], body }];
66 for i in idx {
67 if !used.contains(&i) {
68 let e = DesugarMatchDefErr::UnreachableRule {
69 idx: i,
70 nam: self.name.clone(),
71 pats: rules[i].pats.clone(),
72 };
73 errs.push(e);
74 }
75 }
76 }
77 Err(e) => errs.push(e),
78 }
79 errs
80 }
81}
82
83fn desugar_inner_match_defs(
84 term: &mut Term,
85 ctrs: &Constructors,
86 adts: &Adts,
87 errs: &mut Vec<DesugarMatchDefErr>,
88) {
89 maybe_grow(|| match term {
90 Term::Def { def, nxt } => {
91 errs.extend(def.desugar_match_def(ctrs, adts));
92 desugar_inner_match_defs(nxt, ctrs, adts, errs);
93 }
94 _ => {
95 for child in term.children_mut() {
96 desugar_inner_match_defs(child, ctrs, adts, errs);
97 }
98 }
99 })
100}
101
102fn fix_repeated_binds(rules: &mut [Rule]) -> Vec<DesugarMatchDefErr> {
113 let mut errs = vec![];
114 for rule in rules {
115 let mut binds = HashSet::new();
116 rule.pats.iter_mut().flat_map(|p| p.binds_mut()).rev().for_each(|nam| {
117 if binds.contains(nam) {
118 if let Some(nam) = nam {
120 errs.push(DesugarMatchDefErr::RepeatedBind { bind: nam.clone() });
121 }
122 *nam = None;
123 } else {
125 binds.insert(&*nam);
126 }
127 });
128 }
129 errs
130}
131
132fn simplify_rule_match(
152 args: Vec<Name>,
153 rules: Vec<Rule>,
154 idx: Vec<usize>,
155 with: Vec<Name>,
156 used: &mut BTreeSet<usize>,
157 ctrs: &Constructors,
158 adts: &Adts,
159) -> Result<Term, DesugarMatchDefErr> {
160 if args.is_empty() {
161 used.insert(idx[0]);
162 Ok(rules.into_iter().next().unwrap().body)
163 } else if rules[0].pats.iter().all(|p| p.is_wildcard()) {
164 Ok(irrefutable_fst_row_rule(args, rules.into_iter().next().unwrap(), idx[0], used))
165 } else {
166 let typ = Type::infer_from_def_arg(&rules, 0, ctrs)?;
167 match typ {
168 Type::Any => var_rule(args, rules, idx, with, used, ctrs, adts),
169 Type::Fan(fan, tag, tup_len) => fan_rule(args, rules, idx, with, used, fan, tag, tup_len, ctrs, adts),
170 Type::Num => num_rule(args, rules, idx, with, used, ctrs, adts),
171 Type::Adt(adt_name) => switch_rule(args, rules, idx, with, adt_name, used, ctrs, adts),
172 }
173 }
174}
175
176fn irrefutable_fst_row_rule(args: Vec<Name>, rule: Rule, idx: usize, used: &mut BTreeSet<usize>) -> Term {
180 let mut term = rule.body;
181 for (arg, pat) in args.into_iter().zip(rule.pats.into_iter()) {
182 match pat {
183 Pattern::Var(None) => {}
184 Pattern::Var(Some(var)) => {
185 term = Term::Use { nam: Some(var), val: Box::new(Term::Var { nam: arg }), nxt: Box::new(term) };
186 }
187 Pattern::Chn(var) => {
188 term = Term::Let {
189 pat: Box::new(Pattern::Chn(var)),
190 val: Box::new(Term::Var { nam: arg }),
191 nxt: Box::new(term),
192 };
193 }
194 _ => unreachable!(),
195 }
196 }
197 used.insert(idx);
198 term
199}
200
201fn var_rule(
206 mut args: Vec<Name>,
207 rules: Vec<Rule>,
208 idx: Vec<usize>,
209 mut with: Vec<Name>,
210 used: &mut BTreeSet<usize>,
211 ctrs: &Constructors,
212 adts: &Adts,
213) -> Result<Term, DesugarMatchDefErr> {
214 let arg = args[0].clone();
215 let new_args = args.split_off(1);
216
217 let mut new_rules = vec![];
218 for mut rule in rules {
219 let new_pats = rule.pats.split_off(1);
220 let pat = rule.pats.pop().unwrap();
221
222 if let Pattern::Var(Some(nam)) = &pat {
223 rule.body = Term::Use {
224 nam: Some(nam.clone()),
225 val: Box::new(Term::Var { nam: arg.clone() }),
226 nxt: Box::new(std::mem::take(&mut rule.body)),
227 };
228 }
229
230 let new_rule = Rule { pats: new_pats, body: rule.body };
231 new_rules.push(new_rule);
232 }
233
234 with.push(arg);
235
236 simplify_rule_match(new_args, new_rules, idx, with, used, ctrs, adts)
237}
238
239#[allow(clippy::too_many_arguments)]
255fn fan_rule(
256 mut args: Vec<Name>,
257 rules: Vec<Rule>,
258 idx: Vec<usize>,
259 with: Vec<Name>,
260 used: &mut BTreeSet<usize>,
261 fan: FanKind,
262 tag: Tag,
263 len: usize,
264 ctrs: &Constructors,
265 adts: &Adts,
266) -> Result<Term, DesugarMatchDefErr> {
267 let arg = args[0].clone();
268 let old_args = args.split_off(1);
269 let new_args = (0..len).map(|i| Name::new(format!("{arg}.{i}")));
270
271 let mut new_rules = vec![];
272 for mut rule in rules {
273 let pat = rule.pats[0].clone();
274 let old_pats = rule.pats.split_off(1);
275
276 let mut new_pats = match pat {
278 Pattern::Fan(.., sub_pats) => sub_pats,
279 Pattern::Var(var) => {
280 if let Some(var) = var {
281 let tup =
283 Term::Fan { fan, tag: tag.clone(), els: new_args.clone().map(|nam| Term::Var { nam }).collect() };
284 rule.body =
285 Term::Use { nam: Some(var), val: Box::new(tup), nxt: Box::new(std::mem::take(&mut rule.body)) };
286 }
287 new_args.clone().map(|nam| Pattern::Var(Some(nam))).collect()
288 }
289 _ => unreachable!(),
290 };
291 new_pats.extend(old_pats);
292
293 let new_rule = Rule { pats: new_pats, body: rule.body };
294 new_rules.push(new_rule);
295 }
296
297 let bnd = new_args.clone().map(|x| Pattern::Var(Some(x))).collect();
298 let args = new_args.chain(old_args).collect();
299 let nxt = simplify_rule_match(args, new_rules, idx, with, used, ctrs, adts)?;
300 let term = Term::Let {
301 pat: Box::new(Pattern::Fan(fan, tag.clone(), bnd)),
302 val: Box::new(Term::Var { nam: arg }),
303 nxt: Box::new(nxt),
304 };
305
306 Ok(term)
307}
308
309fn num_rule(
310 mut args: Vec<Name>,
311 rules: Vec<Rule>,
312 idx: Vec<usize>,
313 with: Vec<Name>,
314 used: &mut BTreeSet<usize>,
315 ctrs: &Constructors,
316 adts: &Adts,
317) -> Result<Term, DesugarMatchDefErr> {
318 if !rules.iter().any(|r| r.pats[0].is_wildcard()) {
320 return Err(DesugarMatchDefErr::NumMissingDefault);
321 }
322
323 let arg = args[0].clone();
324 let args = args.split_off(1);
325
326 let pred_var = Name::new(format!("{arg}-1"));
327
328 let nums = rules
331 .iter()
332 .filter_map(|r| if let Pattern::Num(n) = r.pats[0] { Some(n) } else { None })
333 .collect::<BTreeSet<_>>()
334 .into_iter()
335 .collect::<Vec<_>>();
336
337 let mut num_bodies = vec![];
339 for num in nums.iter() {
340 let mut new_rules = vec![];
341 let mut new_idx = vec![];
342 for (rule, &idx) in rules.iter().zip(&idx) {
343 match &rule.pats[0] {
344 Pattern::Num(n) if n == num => {
345 let body = rule.body.clone();
346 let rule = Rule { pats: rule.pats[1..].to_vec(), body };
347 new_rules.push(rule);
348 new_idx.push(idx);
349 }
350 Pattern::Var(var) => {
351 let mut body = rule.body.clone();
352 if let Some(var) = var {
353 body = Term::Use {
354 nam: Some(var.clone()),
355 val: Box::new(Term::Num { val: Num::U24(*num) }),
356 nxt: Box::new(std::mem::take(&mut body)),
357 };
358 }
359 let rule = Rule { pats: rule.pats[1..].to_vec(), body };
360 new_rules.push(rule);
361 new_idx.push(idx);
362 }
363 _ => (),
364 }
365 }
366 let body = simplify_rule_match(args.clone(), new_rules, new_idx, with.clone(), used, ctrs, adts)?;
367 num_bodies.push(body);
368 }
369
370 let mut new_rules = vec![];
372 let mut new_idx = vec![];
373 for (rule, &idx) in rules.into_iter().zip(&idx) {
374 if let Pattern::Var(var) = &rule.pats[0] {
375 let mut body = rule.body.clone();
376 if let Some(var) = var {
377 let last_num = *nums.last().unwrap();
378 let cur_num = 1 + last_num;
379 let var_recovered = Term::add_num(Term::Var { nam: pred_var.clone() }, Num::U24(cur_num));
380 body = Term::Use { nam: Some(var.clone()), val: Box::new(var_recovered), nxt: Box::new(body) };
381 fast_pred_access(&mut body, cur_num, var, &pred_var);
382 }
383 let rule = Rule { pats: rule.pats[1..].to_vec(), body };
384 new_rules.push(rule);
385 new_idx.push(idx);
386 }
387 }
388 let mut default_with = with.clone();
389 default_with.push(pred_var.clone());
390 let default_body = simplify_rule_match(args.clone(), new_rules, new_idx, default_with, used, ctrs, adts)?;
391
392 let with = with.into_iter().chain(args).collect::<Vec<_>>();
394 let with_bnd = with.iter().cloned().map(Some).collect::<Vec<_>>();
395 let with_arg = with.iter().cloned().map(|nam| Term::Var { nam }).collect::<Vec<_>>();
396
397 let term = num_bodies.into_iter().enumerate().rfold(default_body, |term, (i, body)| {
398 let val = if i > 0 {
399 Term::sub_num(Term::Var { nam: pred_var.clone() }, Num::U24(nums[i] - 1 - nums[i - 1]))
402 } else {
403 Term::sub_num(Term::Var { nam: arg.clone() }, Num::U24(nums[i]))
405 };
406
407 Term::Swt {
408 arg: Box::new(val),
409 bnd: Some(arg.clone()),
410 with_bnd: with_bnd.clone(),
411 with_arg: with_arg.clone(),
412 pred: Some(pred_var.clone()),
413 arms: vec![body, term],
414 }
415 });
416
417 Ok(term)
418}
419
420fn fast_pred_access(body: &mut Term, cur_num: u32, var: &Name, pred_var: &Name) {
423 maybe_grow(|| {
424 if let Term::Oper { opr: crate::fun::Op::SUB, fst, snd } = body {
425 if let Term::Num { val: crate::fun::Num::U24(val) } = &**snd {
426 if let Term::Var { nam } = &**fst {
427 if nam == var && *val == cur_num {
428 *body = Term::Var { nam: pred_var.clone() };
429 }
430 }
431 }
432 }
433 for child in body.children_mut() {
434 fast_pred_access(child, cur_num, var, pred_var)
435 }
436 })
437}
438
439#[allow(clippy::too_many_arguments)]
490fn switch_rule(
491 mut args: Vec<Name>,
492 rules: Vec<Rule>,
493 idx: Vec<usize>,
494 with: Vec<Name>,
495 adt_name: Name,
496 used: &mut BTreeSet<usize>,
497 ctrs: &Constructors,
498 adts: &Adts,
499) -> Result<Term, DesugarMatchDefErr> {
500 let arg = args[0].clone();
501 let old_args = args.split_off(1);
502
503 let mut new_arms = vec![];
504 for (ctr_nam, ctr) in &adts[&adt_name].ctrs {
505 let new_args = ctr.fields.iter().map(|f| Name::new(format!("{}.{}", arg, f.nam)));
506 let args = new_args.clone().chain(old_args.clone()).collect();
507
508 let mut new_rules = vec![];
509 let mut new_idx = vec![];
510 for (rule, &idx) in rules.iter().zip(&idx) {
511 let old_pats = rule.pats[1..].to_vec();
512 match &rule.pats[0] {
513 Pattern::Ctr(found_ctr, new_pats) if ctr_nam == found_ctr => {
518 let pats = new_pats.iter().cloned().chain(old_pats).collect();
519 let body = rule.body.clone();
520 let rule = Rule { pats, body };
521 new_rules.push(rule);
522 new_idx.push(idx);
523 }
524 Pattern::Var(var) => {
530 let new_pats = new_args.clone().map(|n| Pattern::Var(Some(n)));
531 let pats = new_pats.chain(old_pats.clone()).collect();
532 let mut body = rule.body.clone();
533 let reconstructed_var =
534 Term::call(Term::Ref { nam: ctr_nam.clone() }, new_args.clone().map(|nam| Term::Var { nam }));
535 if let Some(var) = var {
536 body =
537 Term::Use { nam: Some(var.clone()), val: Box::new(reconstructed_var), nxt: Box::new(body) };
538 }
539 let rule = Rule { pats, body };
540 new_rules.push(rule);
541 new_idx.push(idx);
542 }
543 _ => (),
544 }
545 }
546
547 if new_rules.is_empty() {
548 return Err(DesugarMatchDefErr::AdtNotExhaustive { adt: adt_name, ctr: ctr_nam.clone() });
549 }
550
551 let body = simplify_rule_match(args, new_rules, new_idx, with.clone(), used, ctrs, adts)?;
552 new_arms.push((Some(ctr_nam.clone()), new_args.map(Some).collect(), body));
553 }
554
555 let with = with.into_iter().chain(old_args).collect::<Vec<_>>();
557 let with_bnd = with.iter().cloned().map(Some).collect::<Vec<_>>();
558 let with_arg = with.iter().cloned().map(|nam| Term::Var { nam }).collect::<Vec<_>>();
559
560 let term = Term::Mat {
561 arg: Box::new(Term::Var { nam: arg.clone() }),
562 bnd: Some(arg.clone()),
563 with_bnd,
564 with_arg,
565 arms: new_arms,
566 };
567 Ok(term)
568}
569
/// The type of a pattern column, as inferred from the patterns that appear in it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Type {
  /// Only variable or channel patterns: nothing is known about the value.
  Any,
  /// A tuple or dup with the given tag and number of elements.
  Fan(FanKind, Tag, usize),
  /// Number literal patterns.
  Num,
  /// Constructor patterns of the named ADT (list and string patterns map to the builtin
  /// list/string types).
  Adt(Name),
}
582
583impl Type {
584 fn infer_from_def_arg(
586 rules: &[Rule],
587 arg_idx: usize,
588 ctrs: &Constructors,
589 ) -> Result<Type, DesugarMatchDefErr> {
590 let pats = rules.iter().map(|r| &r.pats[arg_idx]);
591 let mut arg_type = Type::Any;
592 for pat in pats {
593 arg_type = match (arg_type, pat.to_type(ctrs)) {
594 (Type::Any, found) => found,
595 (expected, Type::Any) => expected,
596
597 (expected, found) if expected == found => expected,
598
599 (expected, found) => {
600 return Err(DesugarMatchDefErr::TypeMismatch { expected, found, pat: pat.clone() });
601 }
602 };
603 }
604 Ok(arg_type)
605 }
606}
607
608impl Pattern {
609 fn to_type(&self, ctrs: &Constructors) -> Type {
610 match self {
611 Pattern::Var(_) | Pattern::Chn(_) => Type::Any,
612 Pattern::Ctr(ctr_nam, _) => {
613 let adt_nam = ctrs.get(ctr_nam).unwrap_or_else(|| panic!("Unknown constructor '{ctr_nam}'"));
614 Type::Adt(adt_nam.clone())
615 }
616 Pattern::Fan(is_tup, tag, args) => Type::Fan(*is_tup, tag.clone(), args.len()),
617 Pattern::Num(_) => Type::Num,
618 Pattern::Lst(..) => Type::Adt(Name::new(builtins::LIST)),
619 Pattern::Str(..) => Type::Adt(Name::new(builtins::STRING)),
620 }
621 }
622}
623
624impl std::fmt::Display for Type {
625 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
626 match self {
627 Type::Any => write!(f, "any"),
628 Type::Fan(FanKind::Tup, tag, n) => write!(f, "{}{n}-tuple", tag.display_padded()),
629 Type::Fan(FanKind::Dup, tag, n) => write!(f, "{}{n}-dup", tag.display_padded()),
630 Type::Num => write!(f, "number"),
631 Type::Adt(nam) => write!(f, "{nam}"),
632 }
633 }
634}
635
636impl std::fmt::Display for DesugarMatchDefErr {
637 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
638 match self {
639 DesugarMatchDefErr::AdtNotExhaustive { adt, ctr } => {
640 write!(f, "Non-exhaustive pattern matching rule. Constructor '{ctr}' of type '{adt}' not covered")
641 }
642 DesugarMatchDefErr::TypeMismatch { expected, found, pat } => {
643 write!(
644 f,
645 "Type mismatch in pattern matching rule. Expected a constructor of type '{}', found '{}' with type '{}'.",
646 expected, pat, found
647 )
648 }
649 DesugarMatchDefErr::NumMissingDefault => {
650 write!(f, "Non-exhaustive pattern matching rule. Default case of number type not covered.")
651 }
652 DesugarMatchDefErr::RepeatedBind { bind } => {
653 write!(f, "Repeated bind in pattern matching rule: '{bind}'.")
654 }
655 DesugarMatchDefErr::UnreachableRule { idx, nam, pats } => {
656 write!(
657 f,
658 "Unreachable pattern matching rule '({}{})' (rule index {idx}).",
659 nam,
660 pats.iter().map(|p| format!(" {p}")).join("")
661 )
662 }
663 }
664 }
665}