use std::sync::Arc;
use panproto_expr::{Env, EvalConfig, Expr, Literal, eval};
use panproto_gat::{CoercionClass, DirectedEquation, Theory, ValueKind};
use rustc_hash::FxHashMap;
/// A single failure found while checking a coercion's declared law against
/// sample inputs.
///
/// Serialized as an internally tagged enum (`#[serde(tag = "kind")]`): the
/// variant name is written under a `"kind"` field next to the variant's own
/// fields.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
#[serde(tag = "kind")]
#[non_exhaustive]
pub enum CoercionLawViolation {
    /// `inverse(forward(input))` evaluated successfully but did not equal
    /// `input`.
    Backward {
        /// The sample the round trip started from.
        input: Literal,
        /// Result of evaluating the forward expression on `input`.
        forward_result: Literal,
        /// Result of evaluating the inverse expression on `forward_result`.
        round_tripped: Literal,
    },
    /// `forward(inverse(input))` evaluated successfully but did not equal
    /// `input`.
    Forward {
        /// The sample the round trip started from.
        input: Literal,
        /// Result of evaluating the inverse expression on `input`.
        inverse_result: Literal,
        /// Result of evaluating the forward expression on `inverse_result`.
        round_tripped: Literal,
    },
    /// Two evaluations of the forward expression on the same input produced
    /// different results.
    NonDeterministic {
        /// The sample that was evaluated twice.
        input: Literal,
        /// Result of the first evaluation.
        first: Literal,
        /// Result of the second evaluation.
        second: Literal,
    },
    /// The coercion class requires an inverse expression (as `Iso` and
    /// `Retraction` do in `check_coercion_laws`) but none was supplied.
    MissingInverse {
        /// The declared class that needed the inverse.
        class: CoercionClass,
    },
    /// Evaluating the forward expression failed.
    ForwardEvalError {
        /// The sample that started the check. During a forward round-trip
        /// check the expression was applied to the inverse's result rather
        /// than to this value directly.
        input: Literal,
        /// The evaluator's error, rendered via `Display`.
        error: String,
    },
    /// Evaluating the inverse expression failed.
    InverseEvalError {
        /// The sample that started the check. During a backward round-trip
        /// check the expression was applied to the forward result rather
        /// than to this value directly.
        input: Literal,
        /// The evaluator's error, rendered via `Display`.
        error: String,
    },
    /// The coercion class is not one this checker knows how to verify.
    UnknownClass {
        /// The class rendered with `{:?}`.
        debug_repr: String,
    },
}
/// Checks that `forward` (and `inverse`, when the class needs one) satisfy
/// the law implied by `class` on every sample in `samples`.
///
/// Each sample is bound to `var_name` before evaluation. The laws checked are:
/// - `Iso`: `inverse(forward(x)) == x` and `forward(inverse(x)) == x`;
/// - `Retraction`: `inverse(forward(x)) == x` only;
/// - `Projection`: two evaluations of `forward(x)` agree;
/// - `Opaque`: nothing.
///
/// `Iso` and `Retraction` without an inverse yield a single `MissingInverse`
/// violation and no evaluation is performed. Any other class yields
/// `UnknownClass`. An empty result means every check passed.
#[must_use]
pub fn check_coercion_laws(
    forward: &Expr,
    inverse: Option<&Expr>,
    class: CoercionClass,
    samples: &[Literal],
    var_name: &str,
) -> Vec<CoercionLawViolation> {
    let var: Arc<str> = Arc::from(var_name);
    let config = EvalConfig::default();
    let mut found = Vec::new();
    match class {
        CoercionClass::Iso | CoercionClass::Retraction => {
            let Some(inv) = inverse else {
                return vec![CoercionLawViolation::MissingInverse { class }];
            };
            // Only an isomorphism additionally promises the forward round trip.
            let both_directions = matches!(class, CoercionClass::Iso);
            for sample in samples {
                check_backward(forward, inv, sample, &var, &config, &mut found);
                if both_directions {
                    check_forward(forward, inv, sample, &var, &config, &mut found);
                }
            }
        }
        CoercionClass::Projection => {
            for sample in samples {
                check_deterministic(forward, sample, &var, &config, &mut found);
            }
        }
        CoercionClass::Opaque => {}
        other => found.push(CoercionLawViolation::UnknownClass {
            debug_repr: format!("{other:?}"),
        }),
    }
    found
}
#[must_use]
pub fn check_directed_equation_coercion_law(
deq: &DirectedEquation,
samples: &[Literal],
var_name: &str,
) -> Vec<CoercionLawViolation> {
check_coercion_laws(
&deq.impl_term,
deq.inverse.as_ref(),
deq.coercion_class,
samples,
var_name,
)
}
/// Default string samples: empty, lowercase, mixed-case, uppercase, a string
/// containing a space, and a non-ASCII string.
#[must_use]
pub fn default_samples_for_string_value() -> Vec<Literal> {
    const PROBES: [&str; 6] = ["", "name", "Alice", "ALICE", "hello world", "schön"];
    PROBES
        .iter()
        .map(|&text| Literal::Str(text.to_owned()))
        .collect()
}
/// Checks the backward round trip `inverse(forward(sample)) == sample`.
///
/// Pushes at most one violation onto `violations`:
/// - `ForwardEvalError` if `forward` fails on `sample`;
/// - `InverseEvalError` if `inverse` fails on the forward result (its
///   `input` is still the original `sample`);
/// - `Backward` if the round trip does not reproduce `sample`.
///
/// Each evaluation uses a fresh environment that binds only `var`.
fn check_backward(
    forward: &Expr,
    inverse: &Expr,
    sample: &Literal,
    var: &Arc<str>,
    config: &EvalConfig,
    violations: &mut Vec<CoercionLawViolation>,
) {
    let source_env = Env::new().extend(Arc::clone(var), sample.clone());
    let image = match eval(forward, &source_env, config) {
        Ok(value) => value,
        Err(err) => {
            violations.push(CoercionLawViolation::ForwardEvalError {
                input: sample.clone(),
                error: err.to_string(),
            });
            return;
        }
    };
    let image_env = Env::new().extend(Arc::clone(var), image.clone());
    match eval(inverse, &image_env, config) {
        Err(err) => violations.push(CoercionLawViolation::InverseEvalError {
            input: sample.clone(),
            error: err.to_string(),
        }),
        Ok(round_tripped) if round_tripped != *sample => {
            violations.push(CoercionLawViolation::Backward {
                input: sample.clone(),
                forward_result: image,
                round_tripped,
            });
        }
        Ok(_) => {}
    }
}
/// Checks the forward round trip `forward(inverse(sample)) == sample`.
///
/// Pushes at most one violation onto `violations`:
/// - `InverseEvalError` if `inverse` fails on `sample`;
/// - `ForwardEvalError` if `forward` fails on the inverse result (its
///   `input` is still the original `sample`);
/// - `Forward` if the round trip does not reproduce `sample`.
///
/// Each evaluation uses a fresh environment that binds only `var`.
fn check_forward(
    forward: &Expr,
    inverse: &Expr,
    sample: &Literal,
    var: &Arc<str>,
    config: &EvalConfig,
    violations: &mut Vec<CoercionLawViolation>,
) {
    let target_env = Env::new().extend(Arc::clone(var), sample.clone());
    let preimage = match eval(inverse, &target_env, config) {
        Ok(value) => value,
        Err(err) => {
            violations.push(CoercionLawViolation::InverseEvalError {
                input: sample.clone(),
                error: err.to_string(),
            });
            return;
        }
    };
    let preimage_env = Env::new().extend(Arc::clone(var), preimage.clone());
    match eval(forward, &preimage_env, config) {
        Err(err) => violations.push(CoercionLawViolation::ForwardEvalError {
            input: sample.clone(),
            error: err.to_string(),
        }),
        Ok(round_tripped) if round_tripped != *sample => {
            violations.push(CoercionLawViolation::Forward {
                input: sample.clone(),
                inverse_result: preimage,
                round_tripped,
            });
        }
        Ok(_) => {}
    }
}
/// Evaluates `forward` on `sample` twice and flags differing results.
///
/// Pushes at most one violation onto `violations`:
/// - `ForwardEvalError` from the first failing evaluation (the second is
///   skipped if the first fails);
/// - `NonDeterministic` if both evaluations succeed but disagree.
fn check_deterministic(
    forward: &Expr,
    sample: &Literal,
    var: &Arc<str>,
    config: &EvalConfig,
    violations: &mut Vec<CoercionLawViolation>,
) {
    let env = Env::new().extend(Arc::clone(var), sample.clone());
    let run_once = || {
        eval(forward, &env, config).map_err(|err| CoercionLawViolation::ForwardEvalError {
            input: sample.clone(),
            error: err.to_string(),
        })
    };
    // `and_then` short-circuits, so the second run happens only if the first succeeds.
    let outcome = run_once().and_then(|first| run_once().map(|second| (first, second)));
    match outcome {
        Err(violation) => violations.push(violation),
        Ok((first, second)) if first != second => {
            violations.push(CoercionLawViolation::NonDeterministic {
                input: sample.clone(),
                first,
                second,
            });
        }
        Ok(_) => {}
    }
}
/// Sample inputs for exercising coercion laws, keyed by [`ValueKind`].
///
/// `check_directed_equation_with_registry` looks samples up by an equation's
/// `source_kind`, and uses the `ValueKind::Any` entry when that lookup
/// yields nothing.
#[derive(Debug, Clone)]
pub struct CoercionSampleRegistry {
    /// Registered samples per kind; a kind with no entry has no samples.
    samples: FxHashMap<ValueKind, Vec<Literal>>,
}
impl CoercionSampleRegistry {
    /// Creates an empty registry: every kind reports no samples.
    #[must_use]
    pub fn new() -> Self {
        Self {
            samples: Default::default(),
        }
    }

    /// Creates a registry pre-populated with built-in samples for the
    /// primitive value kinds.
    ///
    /// `ValueKind::Token` samples are stored as `Literal::Str` values. The
    /// `ValueKind::Any` entry is the concatenation of every other kind's
    /// samples, in the order produced by `ValueKind::all()`, so it is
    /// identical across calls.
    #[must_use]
    pub fn with_defaults() -> Self {
        let builtin: [(ValueKind, Vec<Literal>); 7] = [
            (
                ValueKind::Bool,
                vec![Literal::Bool(false), Literal::Bool(true)],
            ),
            (
                ValueKind::Int,
                vec![
                    Literal::Int(0),
                    Literal::Int(1),
                    Literal::Int(-1),
                    Literal::Int(42),
                    Literal::Int(i64::MAX),
                    Literal::Int(i64::MIN),
                ],
            ),
            (
                ValueKind::Float,
                vec![
                    Literal::Float(0.0),
                    Literal::Float(1.0),
                    Literal::Float(-1.0),
                    Literal::Float(3.5),
                    Literal::Float(-2.25),
                ],
            ),
            (ValueKind::Str, default_samples_for_string_value()),
            (
                ValueKind::Bytes,
                vec![
                    Literal::Bytes(Vec::new()),
                    Literal::Bytes(b"abc".to_vec()),
                    Literal::Bytes(vec![0, 255, 7]),
                ],
            ),
            (ValueKind::Null, vec![Literal::Null]),
            (
                ValueKind::Token,
                vec![
                    Literal::Str("token".to_owned()),
                    Literal::Str("id_42".to_owned()),
                    Literal::Str("ns:name".to_owned()),
                ],
            ),
        ];
        let mut registry = Self::new();
        for (kind, values) in builtin {
            registry.register(kind, values);
        }
        // Walk `ValueKind::all()` rather than the hash map so the `Any`
        // concatenation order is fixed.
        let mut any = Vec::new();
        for kind in ValueKind::all() {
            if !matches!(kind, ValueKind::Any) {
                any.extend_from_slice(registry.samples_for(*kind));
            }
        }
        registry.register(ValueKind::Any, any);
        registry
    }

    /// Sets the samples for `kind`, replacing any previously registered list.
    pub fn register(&mut self, kind: ValueKind, samples: Vec<Literal>) {
        *self.samples.entry(kind).or_default() = samples;
    }

    /// Returns the samples registered for `kind`, or an empty slice if none.
    #[must_use]
    pub fn samples_for(&self, kind: ValueKind) -> &[Literal] {
        match self.samples.get(&kind) {
            Some(values) => values.as_slice(),
            None => &[],
        }
    }

    /// Returns `true` when no kind has any sample registered.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        !self.samples.values().any(|values| !values.is_empty())
    }
}
impl Default for CoercionSampleRegistry {
fn default() -> Self {
Self::new()
}
}
/// Checks `deq`'s coercion law using samples drawn from `registry`.
///
/// Samples are looked up by `deq.source_kind`. When the equation has no
/// source kind, or the registry has no samples for it, the `ValueKind::Any`
/// samples are used instead.
///
/// Violations that do not depend on samples are reported even when no
/// samples are available. These are a missing inverse for `Iso`/`Retraction`
/// and an unrecognised coercion class. As a result, an empty registry cannot
/// make a malformed declaration look clean.
#[must_use]
pub fn check_directed_equation_with_registry(
    deq: &DirectedEquation,
    registry: &CoercionSampleRegistry,
    var_name: &str,
) -> Vec<CoercionLawViolation> {
    let primary = deq
        .source_kind
        .map_or(&[] as &[Literal], |k| registry.samples_for(k));
    let samples: &[Literal] = if primary.is_empty() {
        registry.samples_for(ValueKind::Any)
    } else {
        primary
    };
    // No early return on empty `samples`: with nothing to evaluate,
    // `check_coercion_laws` still reports `MissingInverse` / `UnknownClass`,
    // which previously were silently dropped here.
    check_directed_equation_coercion_law(deq, samples, var_name)
}
/// Per-equation results of running coercion-law checks over a theory.
#[derive(Debug, Clone, Default)]
pub struct TheoryCoercionReport {
    /// As built by `check_theory_with_var`: one `(equation name, violations)`
    /// entry per directed equation, in the theory's `directed_eqs` order.
    /// The violation list is empty for an equation that passed.
    pub per_equation: Vec<(Arc<str>, Vec<CoercionLawViolation>)>,
}
impl TheoryCoercionReport {
    /// Returns `true` when no equation in the report has any violation.
    #[must_use]
    pub fn is_clean(&self) -> bool {
        !self.per_equation.iter().any(|(_, found)| !found.is_empty())
    }

    /// Total number of violations across all equations.
    #[must_use]
    pub fn violation_count(&self) -> usize {
        self.per_equation
            .iter()
            .fold(0, |total, (_, found)| total + found.len())
    }
}
/// Checks every directed equation in `theory`, binding samples to the
/// variable name `"x"`. See [`check_theory_with_var`] for a custom name.
#[must_use]
pub fn check_theory(theory: &Theory, registry: &CoercionSampleRegistry) -> TheoryCoercionReport {
    const DEFAULT_VAR: &str = "x";
    check_theory_with_var(theory, registry, DEFAULT_VAR)
}
/// Checks every directed equation in `theory` against `registry`, binding
/// each sample to `var_name`.
///
/// The report holds one entry per equation, in `theory.directed_eqs` order,
/// including equations that passed (with an empty violation list).
#[must_use]
pub fn check_theory_with_var(
    theory: &Theory,
    registry: &CoercionSampleRegistry,
    var_name: &str,
) -> TheoryCoercionReport {
    let per_equation = theory
        .directed_eqs
        .iter()
        .map(|deq| {
            let found = check_directed_equation_with_registry(deq, registry, var_name);
            (Arc::clone(&deq.name), found)
        })
        .collect();
    TheoryCoercionReport { per_equation }
}
/// Validates a declared coercion law against sample inputs.
pub trait CoercionLawValidation {
    /// Runs the coercion-law checks using samples from `registry`, binding
    /// each sample to `var_name`.
    ///
    /// # Errors
    ///
    /// Returns the violations found when any law check fails.
    fn validate_coercion_law(
        &self,
        registry: &CoercionSampleRegistry,
        var_name: &str,
    ) -> Result<(), Vec<CoercionLawViolation>>;
}
impl CoercionLawValidation for DirectedEquation {
    /// Delegates to [`check_directed_equation_with_registry`] and returns
    /// `Ok(())` when it finds nothing, or `Err` with every violation otherwise.
    fn validate_coercion_law(
        &self,
        registry: &CoercionSampleRegistry,
        var_name: &str,
    ) -> Result<(), Vec<CoercionLawViolation>> {
        match check_directed_equation_with_registry(self, registry, var_name) {
            found if found.is_empty() => Ok(()),
            found => Err(found),
        }
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    // Unit tests for the coercion-law checks. Registry-based tests pass the
    // registry as `&reg` (previously mangled to `®`, which did not compile).
    use super::*;
    use panproto_expr::{BuiltinOp, Expr};

    fn upper_expr(var: &str) -> Expr {
        Expr::Builtin(BuiltinOp::Upper, vec![Expr::Var(Arc::from(var))])
    }

    fn identity_expr(var: &str) -> Expr {
        Expr::Var(Arc::from(var))
    }

    #[test]
    fn iso_with_honest_identity_passes() {
        let violations = check_coercion_laws(
            &identity_expr("x"),
            Some(&identity_expr("x")),
            CoercionClass::Iso,
            &default_samples_for_string_value(),
            "x",
        );
        assert!(
            violations.is_empty(),
            "honest identity iso must have no violations, got {violations:?}"
        );
    }

    #[test]
    fn iso_with_lying_identity_inverse_is_flagged() {
        let forward = upper_expr("x");
        let inverse = identity_expr("x");
        let violations = check_coercion_laws(
            &forward,
            Some(&inverse),
            CoercionClass::Iso,
            &[Literal::Str("Alice".to_owned())],
            "x",
        );
        assert!(
            !violations.is_empty(),
            "lying iso declaration must be flagged",
        );
        let has_backward = violations
            .iter()
            .any(|v| matches!(v, CoercionLawViolation::Backward { .. }));
        let has_forward = violations
            .iter()
            .any(|v| matches!(v, CoercionLawViolation::Forward { .. }));
        assert!(
            has_backward,
            "expected Backward violation in {violations:?}"
        );
        assert!(has_forward, "expected Forward violation in {violations:?}");
    }

    #[test]
    fn retraction_checks_only_backward_direction() {
        let forward = upper_expr("x");
        let inverse = Expr::Builtin(BuiltinOp::Lower, vec![Expr::Var(Arc::from("x"))]);
        let violations = check_coercion_laws(
            &forward,
            Some(&inverse),
            CoercionClass::Retraction,
            &[Literal::Str("Alice".to_owned())],
            "x",
        );
        assert!(
            violations
                .iter()
                .any(|v| matches!(v, CoercionLawViolation::Backward { .. })),
            "retraction backward violation expected, got {violations:?}"
        );
    }

    #[test]
    fn projection_checks_determinism() {
        let violations = check_coercion_laws(
            &upper_expr("x"),
            None,
            CoercionClass::Projection,
            &default_samples_for_string_value(),
            "x",
        );
        assert!(
            violations.is_empty(),
            "deterministic projection must pass, got {violations:?}"
        );
    }

    #[test]
    fn opaque_declares_no_law_so_always_passes() {
        let violations = check_coercion_laws(
            &upper_expr("x"),
            None,
            CoercionClass::Opaque,
            &default_samples_for_string_value(),
            "x",
        );
        assert!(violations.is_empty(), "opaque has no laws to violate");
    }

    #[test]
    fn iso_without_inverse_reports_missing_inverse() {
        let violations = check_coercion_laws(
            &upper_expr("x"),
            None,
            CoercionClass::Iso,
            &default_samples_for_string_value(),
            "x",
        );
        assert_eq!(violations.len(), 1);
        assert!(matches!(
            violations[0],
            CoercionLawViolation::MissingInverse {
                class: CoercionClass::Iso,
            }
        ));
    }

    #[test]
    fn retraction_without_inverse_reports_missing_inverse() {
        let violations = check_coercion_laws(
            &upper_expr("x"),
            None,
            CoercionClass::Retraction,
            &default_samples_for_string_value(),
            "x",
        );
        assert_eq!(violations.len(), 1);
        assert!(matches!(
            violations[0],
            CoercionLawViolation::MissingInverse {
                class: CoercionClass::Retraction,
            }
        ));
    }

    #[test]
    fn check_directed_equation_matches_explicit_call() {
        let forward = upper_expr("x");
        let deq = DirectedEquation {
            name: Arc::from("upper_iso_lying"),
            lhs: panproto_gat::Term::var("x"),
            rhs: panproto_gat::Term::app("upper", vec![panproto_gat::Term::var("x")]),
            impl_term: forward.clone(),
            inverse: Some(identity_expr("x")),
            source_kind: Some(panproto_gat::ValueKind::Str),
            target_kind: Some(panproto_gat::ValueKind::Str),
            coercion_class: CoercionClass::Iso,
        };
        let samples = vec![Literal::Str("Alice".to_owned())];
        let direct = check_coercion_laws(
            &forward,
            Some(&identity_expr("x")),
            CoercionClass::Iso,
            &samples,
            "x",
        );
        let via_deq = check_directed_equation_coercion_law(&deq, &samples, "x");
        assert_eq!(direct.len(), via_deq.len());
    }

    fn honest_iso_deq(name: &str) -> DirectedEquation {
        DirectedEquation {
            name: Arc::from(name),
            lhs: panproto_gat::Term::var("x"),
            rhs: panproto_gat::Term::var("x"),
            impl_term: identity_expr("x"),
            inverse: Some(identity_expr("x")),
            source_kind: Some(ValueKind::Str),
            target_kind: Some(ValueKind::Str),
            coercion_class: CoercionClass::Iso,
        }
    }

    fn lying_iso_deq(name: &str) -> DirectedEquation {
        DirectedEquation {
            name: Arc::from(name),
            lhs: panproto_gat::Term::var("x"),
            rhs: panproto_gat::Term::app("upper", vec![panproto_gat::Term::var("x")]),
            impl_term: upper_expr("x"),
            inverse: Some(identity_expr("x")),
            source_kind: Some(ValueKind::Str),
            target_kind: Some(ValueKind::Str),
            coercion_class: CoercionClass::Iso,
        }
    }

    #[test]
    fn registry_defaults_any_union_is_deterministic() {
        let first = CoercionSampleRegistry::with_defaults();
        let first_any: Vec<Literal> = first.samples_for(ValueKind::Any).to_vec();
        for _ in 0..8 {
            let next = CoercionSampleRegistry::with_defaults();
            let next_any: Vec<Literal> = next.samples_for(ValueKind::Any).to_vec();
            assert_eq!(
                first_any, next_any,
                "Any-union must be stable across with_defaults invocations",
            );
        }
        assert!(matches!(first_any.first(), Some(Literal::Bool(false))));
    }

    #[test]
    fn registry_defaults_cover_every_primitive_kind() {
        let reg = CoercionSampleRegistry::with_defaults();
        for kind in ValueKind::all() {
            assert!(
                !reg.samples_for(*kind).is_empty(),
                "kind {kind:?} must have default samples"
            );
        }
        assert!(!reg.is_empty());
    }

    #[test]
    fn registry_check_passes_on_honest_iso() {
        let reg = CoercionSampleRegistry::with_defaults();
        let deq = honest_iso_deq("honest");
        let violations = check_directed_equation_with_registry(&deq, &reg, "x");
        assert!(
            violations.is_empty(),
            "honest iso must pass: {violations:?}"
        );
    }

    #[test]
    fn registry_check_flags_lying_iso() {
        let reg = CoercionSampleRegistry::with_defaults();
        let deq = lying_iso_deq("lying");
        let violations = check_directed_equation_with_registry(&deq, &reg, "x");
        assert!(!violations.is_empty(), "lying iso must be flagged");
    }

    #[test]
    fn registry_check_falls_back_to_any_when_kind_missing() {
        let mut reg = CoercionSampleRegistry::new();
        reg.register(ValueKind::Any, vec![Literal::Str("Alice".to_owned())]);
        let mut deq = lying_iso_deq("lying_untyped");
        deq.source_kind = None;
        let violations = check_directed_equation_with_registry(&deq, &reg, "x");
        assert!(!violations.is_empty());
    }

    #[test]
    fn check_theory_reports_mixed_honesty() {
        let theory = Theory::full(
            "TestCoercionTheory",
            Vec::new(),
            Vec::new(),
            Vec::new(),
            Vec::new(),
            vec![honest_iso_deq("honest"), lying_iso_deq("lying")],
            Vec::new(),
        );
        let reg = CoercionSampleRegistry::with_defaults();
        let report = check_theory(&theory, &reg);
        assert!(!report.is_clean());
        assert_eq!(report.per_equation.len(), 2);
        let (honest_name, honest_violations) = &report.per_equation[0];
        assert_eq!(honest_name.as_ref(), "honest");
        assert!(honest_violations.is_empty());
        let (lying_name, lying_violations) = &report.per_equation[1];
        assert_eq!(lying_name.as_ref(), "lying");
        assert!(!lying_violations.is_empty());
        assert!(report.violation_count() >= 1);
    }

    #[test]
    fn coercion_law_validation_trait_succeeds_on_honest() {
        let reg = CoercionSampleRegistry::with_defaults();
        let deq = honest_iso_deq("honest");
        assert!(deq.validate_coercion_law(&reg, "x").is_ok());
    }

    #[test]
    fn coercion_law_validation_trait_flags_lying() {
        let reg = CoercionSampleRegistry::with_defaults();
        let deq = lying_iso_deq("lying");
        let violations = check_directed_equation_with_registry(&deq, &reg, "x");
        assert!(!violations.is_empty());
        let Err(err) = deq.validate_coercion_law(&reg, "x") else {
            panic!("lying iso must yield Err in every build config");
        };
        assert!(!err.is_empty(), "Err payload must carry the violations");
    }

    #[test]
    fn exhaustive_check_coercion_class() {
        let forward = identity_expr("x");
        let inverse = Some(identity_expr("x"));
        let samples = [Literal::Str("probe".to_owned())];
        for class in CoercionClass::all() {
            let violations = check_coercion_laws(&forward, inverse.as_ref(), *class, &samples, "x");
            for v in &violations {
                assert!(
                    !matches!(v, CoercionLawViolation::UnknownClass { .. }),
                    "check_coercion_laws must handle {class:?} explicitly; \
                     got UnknownClass in {violations:?}",
                );
            }
        }
    }

    #[test]
    fn eval_error_on_wrong_type_is_reported() {
        let violations = check_coercion_laws(
            &upper_expr("x"),
            None,
            CoercionClass::Projection,
            &[Literal::Int(42)],
            "x",
        );
        assert!(
            violations
                .iter()
                .any(|v| matches!(v, CoercionLawViolation::ForwardEvalError { .. })),
            "expected ForwardEvalError, got {violations:?}"
        );
    }
}