use std::sync::Arc;
use sim_kernel::{
Cx, Diagnostic, Error, Expr, MatchScore, Result, Shape, ShapeBindings, ShapeMatch, Symbol,
Value,
};
use crate::{AlgebraicDataType, VariantConstructor};
/// One arm of a pattern match: a labelled shape, optionally tagged with the
/// variant it is declared to cover.
#[derive(Clone)]
pub struct MatchArm {
    /// Name copied into the resulting [`PatternMatch`] when this arm matches.
    label: Symbol,
    /// Shape that candidate values are checked against.
    shape: Arc<dyn Shape>,
    /// Variant this arm is counted for by [`exhaustiveness_diagnostics`];
    /// `None` means the arm is not counted for any variant.
    covered_variant: Option<Symbol>,
}

impl MatchArm {
    /// Creates an arm from a label and a shape, with no covered variant.
    pub fn new(label: Symbol, shape: Arc<dyn Shape>) -> Self {
        let covered_variant = None;
        Self {
            label,
            shape,
            covered_variant,
        }
    }

    /// Creates an arm from a variant constructor: the constructor's variant
    /// symbol is used both as the label and as the covered variant, and the
    /// constructor's shape is used as the arm's shape.
    pub fn for_constructor(constructor: &VariantConstructor) -> Self {
        Self::new(constructor.variant().clone(), constructor.shape())
            .with_covered_variant(constructor.variant().clone())
    }

    /// Returns this arm with its covered variant replaced by `variant`.
    pub fn with_covered_variant(self, variant: Symbol) -> Self {
        Self {
            covered_variant: Some(variant),
            ..self
        }
    }

    /// The arm's label.
    pub fn label(&self) -> &Symbol {
        &self.label
    }

    /// The shape this arm checks values against.
    pub fn shape(&self) -> &Arc<dyn Shape> {
        &self.shape
    }

    /// The variant this arm is declared to cover, if one was set.
    pub fn covered_variant(&self) -> Option<&Symbol> {
        self.covered_variant.as_ref()
    }
}
/// Description of the arm that accepted a value in [`match_value`].
#[derive(Clone, Debug)]
pub struct PatternMatch {
    /// Position of the accepting arm within the slice of arms that was tried.
    arm_index: usize,
    /// Label cloned from the accepting arm.
    label: Symbol,
    /// Captures reported by the accepting arm's shape check.
    captures: ShapeBindings,
    /// Score reported by the accepting arm's shape check.
    score: MatchScore,
}

impl PatternMatch {
    /// Index of the arm that matched.
    pub fn arm_index(&self) -> usize {
        self.arm_index
    }

    /// Label of the arm that matched.
    pub fn label(&self) -> &Symbol {
        &self.label
    }

    /// Bindings captured by the matching shape.
    pub fn captures(&self) -> &ShapeBindings {
        &self.captures
    }

    /// Score the matching shape assigned to the value.
    pub fn score(&self) -> MatchScore {
        self.score
    }
}
pub fn match_value(cx: &mut Cx, value: Value, arms: &[MatchArm]) -> Result<PatternMatch> {
let mut diagnostics = Vec::new();
for (index, arm) in arms.iter().enumerate() {
let matched = arm.shape().check_value(cx, value.clone())?;
if matched.accepted {
return Ok(PatternMatch {
arm_index: index,
label: arm.label().clone(),
captures: matched.captures,
score: matched.score,
});
}
diagnostics.extend(matched.diagnostics);
}
Err(Error::Eval(format!(
"no pattern arm matched: {}",
diagnostic_summary(&diagnostics)
)))
}
/// Checks `value` against a single `shape` and returns the shape's match
/// result unchanged.
///
/// # Errors
///
/// Returns whatever error `shape.check_value` returns.
pub fn destructure_value(cx: &mut Cx, value: Value, shape: &dyn Shape) -> Result<ShapeMatch> {
    shape.check_value(cx, value)
}
/// Checks `expr` against a single `shape` and returns the shape's match
/// result unchanged.
///
/// # Errors
///
/// Returns whatever error `shape.check_expr` returns.
pub fn destructure_expr(cx: &mut Cx, expr: &Expr, shape: &dyn Shape) -> Result<ShapeMatch> {
    shape.check_expr(cx, expr)
}
pub fn exhaustiveness_diagnostics(adt: &AlgebraicDataType, arms: &[MatchArm]) -> Vec<Diagnostic> {
let covered = arms
.iter()
.filter_map(MatchArm::covered_variant)
.collect::<std::collections::BTreeSet<_>>();
let missing = adt
.variants()
.filter(|variant| !covered.contains(variant.symbol()))
.map(|variant| variant.symbol().to_string())
.collect::<Vec<_>>();
if missing.is_empty() {
return Vec::new();
}
let mut diagnostic = Diagnostic::error(format!(
"non-exhaustive match for {}: missing {}",
adt.symbol(),
missing.join(", ")
));
diagnostic.code = Some(Symbol::qualified("pattern", "non-exhaustive"));
vec![diagnostic]
}
/// Produces the text appended to a "no arm matched" error: the message of the
/// first diagnostic, or a fixed fallback when there are no diagnostics.
fn diagnostic_summary(diagnostics: &[Diagnostic]) -> String {
    match diagnostics {
        [first, ..] => first.message.clone(),
        [] => "all pattern shapes rejected the value".to_owned(),
    }
}