bend/fun/transform/
fix_match_terms.rs

1use crate::{
2  diagnostics::{Diagnostics, WarningType, ERR_INDENT_SIZE},
3  fun::{Adts, Constructors, CtrField, Ctx, MatchRule, Name, Num, Term},
4  maybe_grow,
5};
6use std::collections::HashMap;
7
8enum FixMatchErr {
9  AdtMismatch { expected: Name, found: Name, ctr: Name },
10  NonExhaustiveMatch { typ: Name, missing: Name },
11  IrrefutableMatch { var: Option<Name> },
12  UnreachableMatchArms { var: Option<Name> },
13  RedundantArm { ctr: Name },
14}
15
16impl Ctx<'_> {
17  /// Convert all match and switch expressions to a normalized form.
18  /// * For matches, resolve the constructors and create the name of the field variables.
19  /// * For switches, the resolution and name bind is already done during parsing.
20  /// * Check for redundant arms and non-exhaustive matches.
21  /// * Converts the initial bind to an alias on every arm, rebuilding the eliminated constructor
22  /// * Since the bind is not needed anywhere else, it's erased from the term.
23  ///
24  /// Example:
25  /// For the program
26  /// ```hvm
27  /// data MyList = (Cons h t) | Nil
28  /// match x {
29  ///   Cons: (A x.h x.t)
30  ///   Nil: switch %arg = (Foo y) { 0: B; 1: C; _ %arg-2: D }
31  /// }
32  /// ```
33  /// The following AST transformations will be made:
34  /// * The binds `x.h` and `x.t` will be generated and stored in the match term.
35  /// * If it was missing one of the match cases, we'd get an error.
36  /// * If it included one of the cases more than once (including wildcard patterns), we'd get a warning.
37  /// ```hvm
38  /// match * = x {
39  ///   Cons x.h x.t: use x = (Cons x.h x.t); (A x.h x.t)
40  ///   Nil: use x = Nil;
41  ///     switch * = (Foo y) {
42  ///       0: use %arg = 0; B;
43  ///       1: use %arg = 1; C;
44  ///       _: use %arg = (+ %arg-2 2); D
45  ///       }
46  /// }
47  /// ```
48  pub fn fix_match_terms(&mut self) -> Result<(), Diagnostics> {
49    for def in self.book.defs.values_mut() {
50      for rule in def.rules.iter_mut() {
51        let errs = rule.body.fix_match_terms(&self.book.ctrs, &self.book.adts);
52
53        for err in errs {
54          match err {
55            FixMatchErr::AdtMismatch { .. } | FixMatchErr::NonExhaustiveMatch { .. } => {
56              self.info.add_function_error(err, def.name.clone(), def.source.clone())
57            }
58            FixMatchErr::IrrefutableMatch { .. } => self.info.add_function_warning(
59              err,
60              WarningType::IrrefutableMatch,
61              def.name.clone(),
62              def.source.clone(),
63            ),
64            FixMatchErr::UnreachableMatchArms { .. } => self.info.add_function_warning(
65              err,
66              WarningType::UnreachableMatch,
67              def.name.clone(),
68              def.source.clone(),
69            ),
70            FixMatchErr::RedundantArm { .. } => self.info.add_function_warning(
71              err,
72              WarningType::RedundantMatch,
73              def.name.clone(),
74              def.source.clone(),
75            ),
76          }
77        }
78      }
79    }
80
81    self.info.fatal(())
82  }
83}
84
85impl Term {
86  fn fix_match_terms(&mut self, ctrs: &Constructors, adts: &Adts) -> Vec<FixMatchErr> {
87    maybe_grow(|| {
88      let mut errs = Vec::new();
89
90      for child in self.children_mut() {
91        let mut e = child.fix_match_terms(ctrs, adts);
92        errs.append(&mut e);
93      }
94
95      if matches!(self, Term::Mat { .. } | Term::Fold { .. }) {
96        self.fix_match(&mut errs, ctrs, adts);
97      }
98      match self {
99        Term::Def { def, nxt } => {
100          for rule in def.rules.iter_mut() {
101            errs.extend(rule.body.fix_match_terms(ctrs, adts));
102          }
103          errs.extend(nxt.fix_match_terms(ctrs, adts));
104        }
105        // Add a use term to each arm rebuilding the matched variable
106        Term::Mat { arg: _, bnd, with_bnd: _, with_arg: _, arms }
107        | Term::Fold { bnd, arg: _, with_bnd: _, with_arg: _, arms } => {
108          for (ctr, fields, body) in arms {
109            if let Some(ctr) = ctr {
110              *body = Term::Use {
111                nam: bnd.clone(),
112                val: Box::new(Term::call(
113                  Term::Ref { nam: ctr.clone() },
114                  fields.iter().flatten().cloned().map(|nam| Term::Var { nam }),
115                )),
116                nxt: Box::new(std::mem::take(body)),
117              };
118            }
119          }
120        }
121        Term::Swt { arg: _, bnd, with_bnd: _, with_arg: _, pred, arms } => {
122          let n_nums = arms.len() - 1;
123          for (i, arm) in arms.iter_mut().enumerate() {
124            let orig = if i == n_nums {
125              Term::add_num(Term::Var { nam: pred.clone().unwrap() }, Num::U24(i as u32))
126            } else {
127              Term::Num { val: Num::U24(i as u32) }
128            };
129            *arm = Term::Use { nam: bnd.clone(), val: Box::new(orig), nxt: Box::new(std::mem::take(arm)) };
130          }
131        }
132        _ => {}
133      }
134
135      // Remove the bound name
136      match self {
137        Term::Mat { bnd, .. } | Term::Swt { bnd, .. } | Term::Fold { bnd, .. } => *bnd = None,
138        _ => {}
139      }
140
141      errs
142    })
143  }
144
145  fn fix_match(&mut self, errs: &mut Vec<FixMatchErr>, ctrs: &Constructors, adts: &Adts) {
146    let (Term::Mat { bnd, arg, with_bnd, with_arg, arms }
147    | Term::Fold { bnd, arg, with_bnd, with_arg, arms }) = self
148    else {
149      unreachable!()
150    };
151    let bnd = bnd.clone().unwrap();
152
153    // Normalize arms, making one arm for each constructor of the matched adt.
154    if let Some(ctr_nam) = &arms[0].0 {
155      if let Some(adt_nam) = ctrs.get(ctr_nam) {
156        // First arm matches a constructor as expected, so we can normalize the arms.
157        let adt_ctrs = &adts[adt_nam].ctrs;
158
159        // Decide which constructor corresponds to which arm of the match.
160        let mut bodies = fixed_match_arms(&bnd, arms, adt_nam, adt_ctrs.keys(), ctrs, adts, errs);
161
162        // Build the match arms, with all constructors
163        let mut new_rules = vec![];
164        for (ctr_nam, ctr) in adt_ctrs.iter() {
165          let fields = ctr.fields.iter().map(|f| Some(match_field(&bnd, &f.nam))).collect::<Vec<_>>();
166          let body = if let Some(Some(body)) = bodies.remove(ctr_nam) {
167            body
168          } else {
169            errs.push(FixMatchErr::NonExhaustiveMatch { typ: adt_nam.clone(), missing: ctr_nam.clone() });
170            Term::Err
171          };
172          new_rules.push((Some(ctr_nam.clone()), fields, body));
173        }
174        *arms = new_rules;
175        return;
176      }
177    }
178
179    // First arm was not matching a constructor, irrefutable match, convert into a use term.
180    errs.push(FixMatchErr::IrrefutableMatch { var: arms[0].0.clone() });
181    let match_var = arms[0].0.take();
182    let arg = std::mem::take(arg);
183    let with_bnd = std::mem::take(with_bnd);
184    let with_arg = std::mem::take(with_arg);
185
186    // Replaces `self` by its irrefutable arm
187    *self = std::mem::take(&mut arms[0].2);
188
189    // `with` clause desugaring
190    // Performs the same as `Term::linearize_match_with`.
191    // Note that it only wraps the arm with function calls if `with_bnd` and `with_arg` aren't empty.
192    *self = Term::rfold_lams(std::mem::take(self), with_bnd.into_iter());
193    *self = Term::call(std::mem::take(self), with_arg);
194
195    if let Some(var) = match_var {
196      *self = Term::Use {
197        nam: Some(bnd.clone()),
198        val: arg,
199        nxt: Box::new(Term::Use {
200          nam: Some(var),
201          val: Box::new(Term::Var { nam: bnd }),
202          nxt: Box::new(std::mem::take(self)),
203        }),
204      }
205    }
206  }
207}
208
209/// Given the rules of a match term, return the bodies that match
210/// each of the constructors of the matched ADT.
211///
212/// If no rules match a certain constructor, return None in the map,
213/// indicating a non-exhaustive match.
214fn fixed_match_arms<'a>(
215  bnd: &Name,
216  rules: &mut Vec<MatchRule>,
217  adt_nam: &Name,
218  adt_ctrs: impl Iterator<Item = &'a Name>,
219  ctrs: &Constructors,
220  adts: &Adts,
221  errs: &mut Vec<FixMatchErr>,
222) -> HashMap<&'a Name, Option<Term>> {
223  let mut bodies = HashMap::<&Name, Option<Term>>::from_iter(adt_ctrs.map(|ctr| (ctr, None)));
224  for rule_idx in 0..rules.len() {
225    // If Ctr arm, use the body of this rule for this constructor.
226    if let Some(ctr_nam) = &rules[rule_idx].0 {
227      if let Some(found_adt) = ctrs.get(ctr_nam) {
228        if found_adt == adt_nam {
229          let body = bodies.get_mut(ctr_nam).unwrap();
230          if body.is_none() {
231            // Use this rule for this constructor
232            *body = Some(rules[rule_idx].2.clone());
233          } else {
234            errs.push(FixMatchErr::RedundantArm { ctr: ctr_nam.clone() });
235          }
236        } else {
237          errs.push(FixMatchErr::AdtMismatch {
238            expected: adt_nam.clone(),
239            found: found_adt.clone(),
240            ctr: ctr_nam.clone(),
241          })
242        }
243        continue;
244      }
245    }
246    // Otherwise, Var arm, use the body of this rule for all non-covered constructors.
247    for (ctr, body) in bodies.iter_mut() {
248      if body.is_none() {
249        let mut new_body = rules[rule_idx].2.clone();
250        if let Some(var) = &rules[rule_idx].0 {
251          new_body = Term::Use {
252            nam: Some(var.clone()),
253            val: Box::new(rebuild_ctr(bnd, ctr, &adts[adt_nam].ctrs[&**ctr].fields)),
254            nxt: Box::new(new_body),
255          };
256        }
257        *body = Some(new_body);
258      }
259    }
260    if rule_idx != rules.len() - 1 {
261      errs.push(FixMatchErr::UnreachableMatchArms { var: rules[rule_idx].0.clone() });
262      rules.truncate(rule_idx + 1);
263    }
264    break;
265  }
266
267  bodies
268}
269
270fn match_field(arg: &Name, field: &Name) -> Name {
271  Name::new(format!("{arg}.{field}"))
272}
273
274fn rebuild_ctr(arg: &Name, ctr: &Name, fields: &[CtrField]) -> Term {
275  let ctr = Term::Ref { nam: ctr.clone() };
276  let fields = fields.iter().map(|f| Term::Var { nam: match_field(arg, &f.nam) });
277  Term::call(ctr, fields)
278}
279
280impl std::fmt::Display for FixMatchErr {
281  fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
282    match self {
283      FixMatchErr::AdtMismatch { expected, found, ctr } => write!(
284        f,
285        "Type mismatch in 'match' expression: Expected a constructor of type '{expected}', found '{ctr}' of type '{found}'"
286      ),
287      FixMatchErr::NonExhaustiveMatch { typ, missing } => {
288        write!(f, "Non-exhaustive 'match' expression of type '{typ}'. Case '{missing}' not covered.")
289      }
290      FixMatchErr::IrrefutableMatch { var } => {
291        writeln!(
292          f,
293          "Irrefutable 'match' expression. All cases after variable pattern '{}' will be ignored.",
294          var.as_ref().unwrap_or(&Name::new("*")),
295        )?;
296        writeln!(
297          f,
298          "{:ERR_INDENT_SIZE$}Note that to use a 'match' expression, the matched constructors need to be defined in a 'data' definition.",
299          "",
300        )?;
301        write!(
302          f,
303          "{:ERR_INDENT_SIZE$}If this is not a mistake, consider using a 'let' expression instead.",
304          ""
305        )
306      }
307
308      FixMatchErr::UnreachableMatchArms { var } => write!(
309        f,
310        "Unreachable arms in 'match' expression. All cases after '{}' will be ignored.",
311        var.as_ref().unwrap_or(&Name::new("*"))
312      ),
313      FixMatchErr::RedundantArm { ctr } => {
314        write!(f, "Redundant arm in 'match' expression. Case '{ctr}' appears more than once.")
315      }
316    }
317  }
318}