use std::collections::HashMap;
use crate::ast::{Expr, Literal, Spanned, VerifyGiven, VerifyGivenDomain, VerifyLaw};
use crate::ast_rewrite::rewrite_idents_scoped;
use crate::types::checker::hostile_values::boundary_exprs;
use crate::types::parse_type_str_strict;
/// Selects which candidate values `expand_law_cases` substitutes for each `given`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ExpansionMode {
    /// Only the values the user declared in each given's domain.
    Declared,
    /// The declared values plus boundary values for the given's type (when its type
    /// name parses), with declared combinations emitted first.
    Hostile,
}
/// One concrete instance of a verify law after binding every `given` to a value.
#[derive(Clone, Debug)]
pub struct ExpandedCase {
    /// The law's left-hand side with bound identifiers replaced by their values.
    pub lhs: Spanned<Expr>,
    /// The law's right-hand side with the same substitution applied.
    pub rhs: Spanned<Expr>,
    /// The law's `when` guard after substitution, or `None` if the law has no guard.
    pub guard: Option<Spanned<Expr>>,
    /// The `(given name, value)` pairs used for this case, in the order of the givens.
    pub bindings: Vec<(String, Spanned<Expr>)>,
    /// `true` if this combination was added by hostile expansion rather than being
    /// one of the declared-value combinations.
    pub from_hostile: bool,
}
/// Expands `law` into concrete cases, one per combination of given values.
///
/// The declared cases always come first, and all have `from_hostile == false`. They are
/// the cartesian product of every given's declared values.
///
/// In [`ExpansionMode::Hostile`] they are followed by the extra combinations that appear
/// once boundary values are added. Any combination whose rendered signature matches a
/// declared one is skipped. The rest are marked `from_hostile == true`.
pub fn expand_law_cases(law: &VerifyLaw, mode: ExpansionMode) -> Vec<ExpandedCase> {
    let declared_combos = collect_combinations(law, ExpansionMode::Declared);
    let mut cases: Vec<ExpandedCase> = declared_combos
        .iter()
        .map(|combo| materialize_case(law, combo, false))
        .collect();
    match mode {
        ExpansionMode::Declared => cases,
        ExpansionMode::Hostile => {
            // Signatures of the declared combinations, so they are not emitted twice.
            let seen: std::collections::HashSet<Vec<String>> = declared_combos
                .iter()
                .map(|combo| combination_signature(combo))
                .collect();
            cases.extend(
                collect_combinations(law, ExpansionMode::Hostile)
                    .into_iter()
                    .filter(|combo| !seen.contains(&combination_signature(combo)))
                    .map(|combo| materialize_case(law, &combo, true)),
            );
            cases
        }
    }
}
/// Builds every `(given name, value)` assignment for `law`'s givens under `mode`.
///
/// The candidate values for each given come from `values_for_given`. They are then
/// combined with a cartesian product, in the order the givens are listed.
fn collect_combinations(law: &VerifyLaw, mode: ExpansionMode) -> Vec<Vec<(String, Spanned<Expr>)>> {
    let mut per_given: Vec<Vec<Spanned<Expr>>> = Vec::with_capacity(law.givens.len());
    for given in &law.givens {
        per_given.push(values_for_given(given, mode));
    }
    cartesian(&law.givens, &per_given)
}
/// Returns the candidate values for a single `given` under `mode`.
///
/// The declared values always come first, in the order written, with any duplicates
/// kept. In [`ExpansionMode::Hostile`], the boundary values for the given's type are
/// appended, but only when `type_name` parses.
///
/// A boundary candidate is skipped if its rendering (via `render_expr`) matches a
/// value already in the list. That covers both the declared values and earlier
/// boundary candidates.
fn values_for_given(g: &VerifyGiven, mode: ExpansionMode) -> Vec<Spanned<Expr>> {
    let mut values = declared_values(&g.domain);
    if mode == ExpansionMode::Hostile
        && let Ok(ty) = parse_type_str_strict(&g.type_name)
    {
        let mut seen: std::collections::HashSet<String> =
            values.iter().map(render_expr).collect();
        for candidate in boundary_exprs(&ty) {
            // `insert` returns false for a rendering already seen. This dedupes
            // boundary candidates against each other as well as against the declared
            // values, so no combination is emitted twice.
            if seen.insert(render_expr(&candidate)) {
                values.push(candidate);
            }
        }
    }
    values
}
/// Turns a given's declared domain into its list of concrete values.
///
/// An explicit list is copied as-is. An integer range is inclusive at both ends, so it
/// is empty when `start > end`; each integer becomes an unspanned `Int` literal.
fn declared_values(domain: &VerifyGivenDomain) -> Vec<Spanned<Expr>> {
    match domain {
        VerifyGivenDomain::IntRange { start, end } => {
            let (low, high) = (*start, *end);
            (low..=high)
                .map(|n| Spanned::bare(Expr::Literal(Literal::Int(n))))
                .collect()
        }
        VerifyGivenDomain::Explicit(values) => values.to_vec(),
    }
}
/// Computes the cartesian product of the per-given value lists.
///
/// `givens[i]` is paired with `per_given[i]`; `zip` ignores any surplus entries in the
/// longer slice. Each combination lists the givens in order, and earlier givens vary
/// slowest.
///
/// With no givens the result is a single empty combination. A given with no values
/// makes the result empty.
fn cartesian(
    givens: &[VerifyGiven],
    per_given: &[Vec<Spanned<Expr>>],
) -> Vec<Vec<(String, Spanned<Expr>)>> {
    let seed: Vec<Vec<(String, Spanned<Expr>)>> = vec![Vec::new()];
    givens
        .iter()
        .zip(per_given)
        .fold(seed, |acc, (given, choices)| {
            acc.iter()
                .flat_map(|prefix| {
                    choices.iter().map(move |value| {
                        let mut row = prefix.clone();
                        row.push((given.name.clone(), value.clone()));
                        row
                    })
                })
                .collect()
        })
}
/// Instantiates `law` for one combination of given values.
///
/// The bound names are substituted (via `substitute`) into the law's left-hand side,
/// right-hand side and optional `when` guard. If a name occurs more than once in
/// `bindings`, its last occurrence wins.
fn materialize_case(
    law: &VerifyLaw,
    bindings: &[(String, Spanned<Expr>)],
    from_hostile: bool,
) -> ExpandedCase {
    let mut lookup: HashMap<&str, &Spanned<Expr>> = HashMap::with_capacity(bindings.len());
    for (name, value) in bindings {
        lookup.insert(name.as_str(), value);
    }
    let lhs = substitute(&law.lhs, &lookup);
    let rhs = substitute(&law.rhs, &lookup);
    let guard = law.when.as_ref().map(|cond| substitute(cond, &lookup));
    ExpandedCase {
        lhs,
        rhs,
        guard,
        bindings: bindings.to_vec(),
        from_hostile,
    }
}
/// Returns a copy of `expr` in which identifiers found in `bindings` are replaced by
/// clones of their bound values.
///
/// Which occurrences count (for example, whether shadowed names are skipped) is
/// decided by `rewrite_idents_scoped`.
fn substitute(expr: &Spanned<Expr>, bindings: &HashMap<&str, &Spanned<Expr>>) -> Spanned<Expr> {
    rewrite_idents_scoped(expr, |name| bindings.get(name).copied().cloned())
}
/// Renders a combination as a list of `name=value` strings, one per binding and in the
/// same order. Used as a hashable key to recognise combinations that were already
/// produced.
fn combination_signature(combo: &[(String, Spanned<Expr>)]) -> Vec<String> {
    let mut parts = Vec::with_capacity(combo.len());
    for (name, value) in combo {
        parts.push(format!("{name}={}", render_expr(value)));
    }
    parts
}
/// Renders a value as a string key used for deduplication.
///
/// Literals get a compact form:
/// - integers and booleans via `Display`;
/// - floats and strings via `Debug`, so strings are quoted;
/// - unit as `Unit`.
///
/// Any other expression falls back to the `Debug` output of its node.
///
/// NOTE(review): if `Expr`'s `Debug` output includes span information for nested
/// nodes, structurally equal non-literal values with different spans will not
/// dedupe. Confirm this against the AST.
fn render_expr(expr: &Spanned<Expr>) -> String {
    match &expr.node {
        Expr::Literal(Literal::Unit) => String::from("Unit"),
        Expr::Literal(Literal::Bool(flag)) => flag.to_string(),
        Expr::Literal(Literal::Int(n)) => n.to_string(),
        Expr::Literal(Literal::Float(x)) => format!("{x:?}"),
        Expr::Literal(Literal::Str(text)) => format!("{text:?}"),
        node => format!("{node:?}"),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an unspanned integer literal expression.
    fn lit_int(n: i64) -> Spanned<Expr> {
        Spanned::bare(Expr::Literal(Literal::Int(n)))
    }

    /// Builds a law `name == name` with a single given of type `ty` whose declared
    /// domain is the explicit integer list `declared`.
    fn law_one_int(name: &str, ty: &str, declared: Vec<i64>) -> VerifyLaw {
        VerifyLaw {
            name: "test".to_string(),
            givens: vec![VerifyGiven {
                name: name.to_string(),
                type_name: ty.to_string(),
                domain: VerifyGivenDomain::Explicit(declared.into_iter().map(lit_int).collect()),
            }],
            when: None,
            lhs: Spanned::bare(Expr::Ident(name.to_string())),
            rhs: Spanned::bare(Expr::Ident(name.to_string())),
            sample_guards: vec![],
        }
    }

    /// Builds a law `name == name` with a single `Int` given over the inclusive range
    /// `start..=end`.
    fn law_int_range(name: &str, start: i64, end: i64) -> VerifyLaw {
        VerifyLaw {
            name: "range".to_string(),
            givens: vec![VerifyGiven {
                name: name.to_string(),
                type_name: "Int".to_string(),
                domain: VerifyGivenDomain::IntRange { start, end },
            }],
            when: None,
            lhs: Spanned::bare(Expr::Ident(name.to_string())),
            rhs: Spanned::bare(Expr::Ident(name.to_string())),
            sample_guards: vec![],
        }
    }

    #[test]
    fn declared_mode_returns_only_user_values() {
        let law = law_one_int("x", "Int", vec![1, 2]);
        let cases = expand_law_cases(&law, ExpansionMode::Declared);
        assert_eq!(cases.len(), 2);
        assert!(cases.iter().all(|c| !c.from_hostile));
    }

    #[test]
    fn hostile_mode_appends_boundary_set() {
        let law = law_one_int("x", "Int", vec![5]);
        let declared = expand_law_cases(&law, ExpansionMode::Declared);
        let hostile = expand_law_cases(&law, ExpansionMode::Hostile);
        assert_eq!(declared.len(), 1);
        // The expected count depends on how many Int boundary values `boundary_exprs`
        // yields that differ from 5.
        assert_eq!(hostile.len(), 1 + 5);
        assert!(!hostile[0].from_hostile);
        assert!(hostile[1..].iter().all(|c| c.from_hostile));
    }

    #[test]
    fn hostile_dedupes_against_declared() {
        let law = law_one_int("x", "Int", vec![0]);
        let hostile = expand_law_cases(&law, ExpansionMode::Hostile);
        let zero_count = hostile
            .iter()
            .filter(|c| matches!(&c.lhs.node, Expr::Literal(Literal::Int(0))))
            .count();
        assert_eq!(zero_count, 1, "zero should appear exactly once");
    }

    #[test]
    fn cartesian_two_givens() {
        let law = VerifyLaw {
            name: "two".to_string(),
            givens: vec![
                VerifyGiven {
                    name: "x".to_string(),
                    type_name: "Bool".to_string(),
                    domain: VerifyGivenDomain::Explicit(vec![Spanned::bare(Expr::Literal(
                        Literal::Bool(true),
                    ))]),
                },
                VerifyGiven {
                    name: "y".to_string(),
                    type_name: "Int".to_string(),
                    domain: VerifyGivenDomain::Explicit(vec![lit_int(1)]),
                },
            ],
            when: None,
            lhs: Spanned::bare(Expr::Ident("x".to_string())),
            rhs: Spanned::bare(Expr::Ident("y".to_string())),
            sample_guards: vec![],
        };
        let declared = expand_law_cases(&law, ExpansionMode::Declared);
        assert_eq!(declared.len(), 1);
        let hostile = expand_law_cases(&law, ExpansionMode::Hostile);
        assert_eq!(hostile.iter().filter(|c| !c.from_hostile).count(), 1);
        assert_eq!(hostile.len(), 10);
    }

    #[test]
    fn unknown_type_falls_back_to_declared_only() {
        let law = law_one_int("x", "Shape", vec![1]);
        let hostile = expand_law_cases(&law, ExpansionMode::Hostile);
        assert_eq!(hostile.len(), 1);
        assert!(!hostile[0].from_hostile);
    }

    #[test]
    fn int_range_domain_is_inclusive() {
        let law = law_int_range("n", 1, 3);
        let cases = expand_law_cases(&law, ExpansionMode::Declared);
        let values: Vec<i64> = cases
            .iter()
            .map(|c| match &c.lhs.node {
                Expr::Literal(Literal::Int(n)) => *n,
                other => panic!("expected int literal, got {other:?}"),
            })
            .collect();
        assert_eq!(values, vec![1, 2, 3]);
    }

    #[test]
    fn empty_int_range_yields_no_cases() {
        let law = law_int_range("n", 3, 1);
        assert!(expand_law_cases(&law, ExpansionMode::Declared).is_empty());
    }

    #[test]
    fn guard_is_substituted() {
        let mut law = law_one_int("x", "Int", vec![7]);
        law.when = Some(Spanned::bare(Expr::Ident("x".to_string())));
        let cases = expand_law_cases(&law, ExpansionMode::Declared);
        assert_eq!(cases.len(), 1);
        assert!(matches!(
            cases[0].guard.as_ref().map(|g| &g.node),
            Some(Expr::Literal(Literal::Int(7)))
        ));
    }

    #[test]
    fn no_givens_yields_single_empty_case() {
        let law = VerifyLaw {
            name: "closed".to_string(),
            givens: vec![],
            when: None,
            lhs: lit_int(1),
            rhs: lit_int(1),
            sample_guards: vec![],
        };
        // The product over zero givens is one empty combination. In hostile mode that
        // combination is already declared, so it must not be emitted twice.
        let cases = expand_law_cases(&law, ExpansionMode::Hostile);
        assert_eq!(cases.len(), 1);
        assert!(cases[0].bindings.is_empty());
        assert!(!cases[0].from_hostile);
    }
}