1use crate::{
2 diagnostics::{Diagnostics, WarningType, ERR_INDENT_SIZE},
3 fun::{Adts, Constructors, CtrField, Ctx, MatchRule, Name, Num, Term},
4 maybe_grow,
5};
6use std::collections::HashMap;
7
8enum FixMatchErr {
9 AdtMismatch { expected: Name, found: Name, ctr: Name },
10 NonExhaustiveMatch { typ: Name, missing: Name },
11 IrrefutableMatch { var: Option<Name> },
12 UnreachableMatchArms { var: Option<Name> },
13 RedundantArm { ctr: Name },
14}
15
16impl Ctx<'_> {
17 pub fn fix_match_terms(&mut self) -> Result<(), Diagnostics> {
49 for def in self.book.defs.values_mut() {
50 for rule in def.rules.iter_mut() {
51 let errs = rule.body.fix_match_terms(&self.book.ctrs, &self.book.adts);
52
53 for err in errs {
54 match err {
55 FixMatchErr::AdtMismatch { .. } | FixMatchErr::NonExhaustiveMatch { .. } => {
56 self.info.add_function_error(err, def.name.clone(), def.source.clone())
57 }
58 FixMatchErr::IrrefutableMatch { .. } => self.info.add_function_warning(
59 err,
60 WarningType::IrrefutableMatch,
61 def.name.clone(),
62 def.source.clone(),
63 ),
64 FixMatchErr::UnreachableMatchArms { .. } => self.info.add_function_warning(
65 err,
66 WarningType::UnreachableMatch,
67 def.name.clone(),
68 def.source.clone(),
69 ),
70 FixMatchErr::RedundantArm { .. } => self.info.add_function_warning(
71 err,
72 WarningType::RedundantMatch,
73 def.name.clone(),
74 def.source.clone(),
75 ),
76 }
77 }
78 }
79 }
80
81 self.info.fatal(())
82 }
83}
84
85impl Term {
86 fn fix_match_terms(&mut self, ctrs: &Constructors, adts: &Adts) -> Vec<FixMatchErr> {
87 maybe_grow(|| {
88 let mut errs = Vec::new();
89
90 for child in self.children_mut() {
91 let mut e = child.fix_match_terms(ctrs, adts);
92 errs.append(&mut e);
93 }
94
95 if matches!(self, Term::Mat { .. } | Term::Fold { .. }) {
96 self.fix_match(&mut errs, ctrs, adts);
97 }
98 match self {
99 Term::Def { def, nxt } => {
100 for rule in def.rules.iter_mut() {
101 errs.extend(rule.body.fix_match_terms(ctrs, adts));
102 }
103 errs.extend(nxt.fix_match_terms(ctrs, adts));
104 }
105 Term::Mat { arg: _, bnd, with_bnd: _, with_arg: _, arms }
107 | Term::Fold { bnd, arg: _, with_bnd: _, with_arg: _, arms } => {
108 for (ctr, fields, body) in arms {
109 if let Some(ctr) = ctr {
110 *body = Term::Use {
111 nam: bnd.clone(),
112 val: Box::new(Term::call(
113 Term::Ref { nam: ctr.clone() },
114 fields.iter().flatten().cloned().map(|nam| Term::Var { nam }),
115 )),
116 nxt: Box::new(std::mem::take(body)),
117 };
118 }
119 }
120 }
121 Term::Swt { arg: _, bnd, with_bnd: _, with_arg: _, pred, arms } => {
122 let n_nums = arms.len() - 1;
123 for (i, arm) in arms.iter_mut().enumerate() {
124 let orig = if i == n_nums {
125 Term::add_num(Term::Var { nam: pred.clone().unwrap() }, Num::U24(i as u32))
126 } else {
127 Term::Num { val: Num::U24(i as u32) }
128 };
129 *arm = Term::Use { nam: bnd.clone(), val: Box::new(orig), nxt: Box::new(std::mem::take(arm)) };
130 }
131 }
132 _ => {}
133 }
134
135 match self {
137 Term::Mat { bnd, .. } | Term::Swt { bnd, .. } | Term::Fold { bnd, .. } => *bnd = None,
138 _ => {}
139 }
140
141 errs
142 })
143 }
144
145 fn fix_match(&mut self, errs: &mut Vec<FixMatchErr>, ctrs: &Constructors, adts: &Adts) {
146 let (Term::Mat { bnd, arg, with_bnd, with_arg, arms }
147 | Term::Fold { bnd, arg, with_bnd, with_arg, arms }) = self
148 else {
149 unreachable!()
150 };
151 let bnd = bnd.clone().unwrap();
152
153 if let Some(ctr_nam) = &arms[0].0 {
155 if let Some(adt_nam) = ctrs.get(ctr_nam) {
156 let adt_ctrs = &adts[adt_nam].ctrs;
158
159 let mut bodies = fixed_match_arms(&bnd, arms, adt_nam, adt_ctrs.keys(), ctrs, adts, errs);
161
162 let mut new_rules = vec![];
164 for (ctr_nam, ctr) in adt_ctrs.iter() {
165 let fields = ctr.fields.iter().map(|f| Some(match_field(&bnd, &f.nam))).collect::<Vec<_>>();
166 let body = if let Some(Some(body)) = bodies.remove(ctr_nam) {
167 body
168 } else {
169 errs.push(FixMatchErr::NonExhaustiveMatch { typ: adt_nam.clone(), missing: ctr_nam.clone() });
170 Term::Err
171 };
172 new_rules.push((Some(ctr_nam.clone()), fields, body));
173 }
174 *arms = new_rules;
175 return;
176 }
177 }
178
179 errs.push(FixMatchErr::IrrefutableMatch { var: arms[0].0.clone() });
181 let match_var = arms[0].0.take();
182 let arg = std::mem::take(arg);
183 let with_bnd = std::mem::take(with_bnd);
184 let with_arg = std::mem::take(with_arg);
185
186 *self = std::mem::take(&mut arms[0].2);
188
189 *self = Term::rfold_lams(std::mem::take(self), with_bnd.into_iter());
193 *self = Term::call(std::mem::take(self), with_arg);
194
195 if let Some(var) = match_var {
196 *self = Term::Use {
197 nam: Some(bnd.clone()),
198 val: arg,
199 nxt: Box::new(Term::Use {
200 nam: Some(var),
201 val: Box::new(Term::Var { nam: bnd }),
202 nxt: Box::new(std::mem::take(self)),
203 }),
204 }
205 }
206 }
207}
208
209fn fixed_match_arms<'a>(
215 bnd: &Name,
216 rules: &mut Vec<MatchRule>,
217 adt_nam: &Name,
218 adt_ctrs: impl Iterator<Item = &'a Name>,
219 ctrs: &Constructors,
220 adts: &Adts,
221 errs: &mut Vec<FixMatchErr>,
222) -> HashMap<&'a Name, Option<Term>> {
223 let mut bodies = HashMap::<&Name, Option<Term>>::from_iter(adt_ctrs.map(|ctr| (ctr, None)));
224 for rule_idx in 0..rules.len() {
225 if let Some(ctr_nam) = &rules[rule_idx].0 {
227 if let Some(found_adt) = ctrs.get(ctr_nam) {
228 if found_adt == adt_nam {
229 let body = bodies.get_mut(ctr_nam).unwrap();
230 if body.is_none() {
231 *body = Some(rules[rule_idx].2.clone());
233 } else {
234 errs.push(FixMatchErr::RedundantArm { ctr: ctr_nam.clone() });
235 }
236 } else {
237 errs.push(FixMatchErr::AdtMismatch {
238 expected: adt_nam.clone(),
239 found: found_adt.clone(),
240 ctr: ctr_nam.clone(),
241 })
242 }
243 continue;
244 }
245 }
246 for (ctr, body) in bodies.iter_mut() {
248 if body.is_none() {
249 let mut new_body = rules[rule_idx].2.clone();
250 if let Some(var) = &rules[rule_idx].0 {
251 new_body = Term::Use {
252 nam: Some(var.clone()),
253 val: Box::new(rebuild_ctr(bnd, ctr, &adts[adt_nam].ctrs[&**ctr].fields)),
254 nxt: Box::new(new_body),
255 };
256 }
257 *body = Some(new_body);
258 }
259 }
260 if rule_idx != rules.len() - 1 {
261 errs.push(FixMatchErr::UnreachableMatchArms { var: rules[rule_idx].0.clone() });
262 rules.truncate(rule_idx + 1);
263 }
264 break;
265 }
266
267 bodies
268}
269
270fn match_field(arg: &Name, field: &Name) -> Name {
271 Name::new(format!("{arg}.{field}"))
272}
273
274fn rebuild_ctr(arg: &Name, ctr: &Name, fields: &[CtrField]) -> Term {
275 let ctr = Term::Ref { nam: ctr.clone() };
276 let fields = fields.iter().map(|f| Term::Var { nam: match_field(arg, &f.nam) });
277 Term::call(ctr, fields)
278}
279
280impl std::fmt::Display for FixMatchErr {
281 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
282 match self {
283 FixMatchErr::AdtMismatch { expected, found, ctr } => write!(
284 f,
285 "Type mismatch in 'match' expression: Expected a constructor of type '{expected}', found '{ctr}' of type '{found}'"
286 ),
287 FixMatchErr::NonExhaustiveMatch { typ, missing } => {
288 write!(f, "Non-exhaustive 'match' expression of type '{typ}'. Case '{missing}' not covered.")
289 }
290 FixMatchErr::IrrefutableMatch { var } => {
291 writeln!(
292 f,
293 "Irrefutable 'match' expression. All cases after variable pattern '{}' will be ignored.",
294 var.as_ref().unwrap_or(&Name::new("*")),
295 )?;
296 writeln!(
297 f,
298 "{:ERR_INDENT_SIZE$}Note that to use a 'match' expression, the matched constructors need to be defined in a 'data' definition.",
299 "",
300 )?;
301 write!(
302 f,
303 "{:ERR_INDENT_SIZE$}If this is not a mistake, consider using a 'let' expression instead.",
304 ""
305 )
306 }
307
308 FixMatchErr::UnreachableMatchArms { var } => write!(
309 f,
310 "Unreachable arms in 'match' expression. All cases after '{}' will be ignored.",
311 var.as_ref().unwrap_or(&Name::new("*"))
312 ),
313 FixMatchErr::RedundantArm { ctr } => {
314 write!(f, "Redundant arm in 'match' expression. Case '{ctr}' appears more than once.")
315 }
316 }
317 }
318}