use std::collections::{BTreeMap, BTreeSet};
use crate::algebra::signature::Signature;
use crate::ast::*;
use crate::error::CompileError;
use crate::scopes::{Scopes, Symbol};
/// Checks every `match` statement reachable from the body of rule `rid`.
///
/// Match statements are gathered with [`collect_from_block`], which also descends into
/// match-case bodies and branch blocks, and each one is then handed to [`check_match`].
/// Problems are appended to `errors`; nothing is returned.
pub fn check_rule_matches(
    rid: RuleDeclId,
    ast: &Ast,
    scopes: &Scopes,
    signature: &Signature,
    errors: &mut Vec<CompileError>,
) {
    let mut match_stmts: Vec<MatchStmtId> = Vec::new();
    // The AST is only read, so the rule body is borrowed rather than cloned.
    collect_from_block(ast, &ast.rule_decl(rid).body, &mut match_stmts);
    for stmt in match_stmts {
        check_match(stmt, ast, scopes, signature, errors);
    }
}
/// Appends every match statement found in `stmts` to `out`.
///
/// A match statement is pushed before the statements nested in its cases, so `out` ends up
/// in pre-order. The function recurses into the body of every match case and into every
/// block of a branch statement. `If` and `Then` statements are not descended into.
fn collect_from_block(ast: &Ast, stmts: &[StmtId], out: &mut Vec<MatchStmtId>) {
    for &stmt in stmts {
        match *ast.stmt(stmt) {
            Stmt::Match(mid) => {
                out.push(mid);
                // Only `out` is mutated; the AST data is borrowed, not cloned.
                for &case in &ast.match_stmt(mid).cases {
                    collect_from_block(ast, &ast.match_case(case).body, out);
                }
            }
            Stmt::Branch(bid) => {
                for block in &ast.branch_stmt(bid).blocks {
                    collect_from_block(ast, block, out);
                }
            }
            Stmt::If(_) | Stmt::Then(_) => {}
        }
    }
}
/// Resolves the constructor applied by `pattern`, if there is one.
///
/// Returns `Some` only when `pattern` is an application term whose function expression is
/// an ambient name that `scopes` resolves to a [`Symbol::Ctor`]. Any other term shape, any
/// other kind of function expression, a failed lookup, or a non-constructor symbol yields
/// `None`.
fn pattern_ctor(pattern: TermId, ast: &Ast, scopes: &Scopes) -> Option<CtorDeclId> {
    let app = match *ast.term(pattern) {
        Term::App(app) => app,
        _ => return None,
    };
    let func_expr = ast.app_term(app).func;
    let ambient = match *ast.func_expr(func_expr) {
        FuncExpr::Ambient(ambient) => ambient,
        _ => return None,
    };
    // The name is looked up in the scope recorded for the ambient function expression.
    let lookup_scope = scopes.entry(ambient);
    let func_name = &ast.ambient_func_expr(ambient).name;
    let symbol = scopes.lookup(lookup_scope, func_name)?;
    if let Symbol::Ctor(ctor) = symbol {
        Some(ctor)
    } else {
        None
    }
}
/// Finds the enum declaration associated with constructor `ctor`.
///
/// The constructor is first mapped to its function in `signature`; the result is whatever
/// enum declaration `signature` reports for that function's codomain. Returns `None` if
/// either lookup comes back empty.
fn ctor_enum(ctor: CtorDeclId, signature: &Signature) -> Option<EnumDeclId> {
    signature
        .func_for_ctor_decl(ctor)
        .and_then(|func| signature.enum_decl_for_type(signature.func(func).codomain))
}
/// Validates the constructor patterns of a single match statement.
///
/// At most one kind of problem is reported per statement:
///
/// * [`CompileError::MatchConflictingEnum`] when the case patterns use constructors that
///   belong to two or more different enums. The two constructors reported are the
///   first-seen constructors of the first two enums in `EnumDeclId` order (the map is keyed
///   by enum id, so this is not necessarily case order). Exhaustiveness is not checked in
///   this situation.
/// * [`CompileError::MatchNotExhaustive`], once for each constructor of the matched enum
///   that no case pattern uses.
///
/// Cases whose pattern does not resolve to a constructor (see [`pattern_ctor`]) are
/// skipped. A constructor whose enum cannot be determined (see [`ctor_enum`]) still counts
/// as used but contributes no enum. If no case yields an enum, nothing is reported.
fn check_match(
    stmt: MatchStmtId,
    ast: &Ast,
    scopes: &Scopes,
    signature: &Signature,
    errors: &mut Vec<CompileError>,
) {
    // For each enum seen, remember the first constructor that introduced it.
    let mut first_ctor_per_enum: BTreeMap<EnumDeclId, CtorDeclId> = BTreeMap::new();
    let mut used_ctors: BTreeSet<CtorDeclId> = BTreeSet::new();
    for &case in &ast.match_stmt(stmt).cases {
        let Some(ctor) = pattern_ctor(ast.match_case(case).pattern, ast, scopes) else {
            continue;
        };
        used_ctors.insert(ctor);
        if let Some(enum_id) = ctor_enum(ctor, signature) {
            first_ctor_per_enum.entry(enum_id).or_insert(ctor);
        }
    }

    let mut enums = first_ctor_per_enum.iter();
    let Some((&enum_id, &first)) = enums.next() else {
        // No pattern could be attributed to an enum, so there is nothing to check.
        return;
    };
    if let Some((_, &second)) = enums.next() {
        // Constructors from more than one enum: report the conflict and stop.
        errors.push(CompileError::MatchConflictingEnum {
            match_stmt_location: ast.loc(stmt),
            first_ctor_decl_location: ast.loc(first),
            second_ctor_decl_location: ast.loc(second),
        });
        return;
    }

    // Exactly one enum is matched: every one of its constructors must be used.
    for &ctor in &ast.enum_decl(enum_id).ctors {
        if !used_ctors.contains(&ctor) {
            errors.push(CompileError::MatchNotExhaustive {
                match_location: ast.loc(stmt),
                missing_ctor_decl_location: ast.loc(ctor),
            });
        }
    }
}