use crate::export::{global_to_json, local_to_json};
use crate::import::{json_to_global, json_to_local, ImportError};
use crate::runner::LeanRunner;
use serde_json::Value;
use telltale_types::{GlobalType, LocalTypeR};
use thiserror::Error;
/// Outcome of a single validation check.
///
/// `Invalid` means the check ran to completion and found a discrepancy; `Error`
/// means the check itself could not be completed (e.g. the round-tripped JSON
/// could not be parsed back).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ValidationResult {
    /// Both sides agree.
    Valid,
    /// The check ran and found a mismatch; the string describes it.
    Invalid(String),
    /// The check could not be carried out; the string describes why.
    Error(String),
}

impl ValidationResult {
    /// Returns `true` for [`ValidationResult::Valid`].
    pub fn is_valid(&self) -> bool {
        matches!(self, ValidationResult::Valid)
    }

    /// Returns `true` for [`ValidationResult::Invalid`].
    pub fn is_invalid(&self) -> bool {
        matches!(self, ValidationResult::Invalid(_))
    }

    /// Returns `true` for [`ValidationResult::Error`].
    ///
    /// Completes the set of predicates so every variant can be tested without
    /// a hand-written `matches!`.
    pub fn is_error(&self) -> bool {
        matches!(self, ValidationResult::Error(_))
    }
}
/// Result of a subtyping check, as a two-valued enum instead of a bare `bool`.
///
/// Built from a `bool` via `From`: `true` maps to [`SubtypingDecision::IsSubtype`],
/// `false` to [`SubtypingDecision::NotSubtype`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SubtypingDecision {
    /// The subtyping relation holds.
    IsSubtype,
    /// The subtyping relation does not hold.
    NotSubtype,
}

impl From<bool> for SubtypingDecision {
    fn from(holds: bool) -> Self {
        match holds {
            true => Self::IsSubtype,
            false => Self::NotSubtype,
        }
    }
}
/// Failures that prevent a validation from producing a [`ValidationResult`].
///
/// A comparison that runs and disagrees is reported as `ValidationResult::Invalid`,
/// not as one of these errors.
#[derive(Debug, Error)]
pub enum ValidateError {
    /// JSON could not be imported into a Rust type (e.g. the Lean-side projection
    /// passed to `Validator::compare_projection`).
    #[error("Import error: {0}")]
    Import(#[from] ImportError),
    /// NOTE(review): not constructed anywhere in this file; presumably reports two
    /// structures that cannot be compared — confirm against other callers.
    #[error("Structure mismatch: {0}")]
    StructureMismatch(String),
    /// The Lean runner could not be constructed, or its `validate` call failed.
    /// Carries the stringified underlying error.
    #[error("Lean execution failed: {0}")]
    LeanExecutionFailed(String),
}
/// Cross-checks Rust-side global/local types against JSON round trips and,
/// optionally, an external Lean runner.
pub struct Validator {
    /// Explicit location of the Lean binary. `None` means `LeanRunner::new()`
    /// is used to find it.
    lean_path: Option<String>,
}

impl Default for Validator {
    /// Same as [`Validator::new`]: no explicit Lean binary path.
    fn default() -> Self {
        Validator::new()
    }
}
impl Validator {
    /// Creates a validator with no explicit Lean binary path.
    ///
    /// Lean-backed checks then use `LeanRunner::new()` to locate the binary.
    pub fn new() -> Self {
        Self { lean_path: None }
    }

    /// Uses `path` as the Lean binary for Lean-backed checks.
    #[must_use]
    pub fn with_lean_path(mut self, path: impl Into<String>) -> Self {
        self.lean_path = Some(path.into());
        self
    }

    /// Exports `g` to JSON, re-imports it, and checks that the result is
    /// structurally identical to `g`.
    ///
    /// Returns `Valid` on a faithful round trip, `Invalid` if the re-imported
    /// type differs, and `Error` if the exported JSON cannot be parsed back.
    pub fn validate_global_roundtrip(&self, g: &GlobalType) -> ValidationResult {
        let json = global_to_json(g);
        match json_to_global(&json) {
            Ok(parsed) if global_types_equal(g, &parsed) => ValidationResult::Valid,
            Ok(_) => {
                ValidationResult::Invalid("Round-trip produced different structure".to_string())
            }
            Err(e) => ValidationResult::Error(format!("Parse error: {}", e)),
        }
    }

    /// Local-type counterpart of [`Validator::validate_global_roundtrip`].
    ///
    /// Returns `Valid` on a faithful round trip, `Invalid` if the re-imported
    /// type differs, and `Error` if the exported JSON cannot be parsed back.
    pub fn validate_local_roundtrip(&self, lt: &LocalTypeR) -> ValidationResult {
        let json = local_to_json(lt);
        match json_to_local(&json) {
            Ok(parsed) if local_types_equal(lt, &parsed) => ValidationResult::Valid,
            Ok(_) => {
                ValidationResult::Invalid("Round-trip produced different structure".to_string())
            }
            Err(e) => ValidationResult::Error(format!("Parse error: {}", e)),
        }
    }

    /// Compares a projection computed in Rust with one supplied as Lean JSON.
    ///
    /// # Errors
    ///
    /// Returns [`ValidateError::Import`] if `lean_json` cannot be imported as a
    /// local type.
    pub fn compare_projection(
        &self,
        rust_result: &LocalTypeR,
        lean_json: &Value,
    ) -> Result<ValidationResult, ValidateError> {
        let lean_result = json_to_local(lean_json)?;
        if local_types_equal(rust_result, &lean_result) {
            Ok(ValidationResult::Valid)
        } else {
            Ok(ValidationResult::Invalid(format!(
                "Projection mismatch:\n Rust: {:?}\n Lean: {:?}",
                rust_result, lean_result
            )))
        }
    }

    /// Returns `Valid` when both sides reach the same subtyping decision,
    /// otherwise `Invalid` naming both decisions.
    pub fn compare_subtyping(
        &self,
        rust_result: SubtypingDecision,
        lean_result: SubtypingDecision,
    ) -> ValidationResult {
        if rust_result == lean_result {
            ValidationResult::Valid
        } else {
            ValidationResult::Invalid(format!(
                "Subtyping mismatch: Rust={:?}, Lean={:?}",
                rust_result, lean_result
            ))
        }
    }

    /// Asks the Lean runner to validate `program_json` against `choreography_json`.
    ///
    /// A failed check is reported as `Invalid` with the runner's message, or
    /// `"projection mismatch"` when the runner supplied no message.
    ///
    /// # Errors
    ///
    /// Returns [`ValidateError::LeanExecutionFailed`] if the runner cannot be
    /// constructed or its `validate` call fails.
    pub fn validate_projection_with_lean(
        &self,
        choreography_json: &Value,
        program_json: &Value,
    ) -> Result<ValidationResult, ValidateError> {
        let runner = match &self.lean_path {
            Some(path) => LeanRunner::with_binary_path(path)
                .map_err(|e| ValidateError::LeanExecutionFailed(e.to_string()))?,
            None => {
                LeanRunner::new().map_err(|e| ValidateError::LeanExecutionFailed(e.to_string()))?
            }
        };
        let result = runner
            .validate(choreography_json, program_json)
            .map_err(|e| ValidateError::LeanExecutionFailed(e.to_string()))?;
        if result.success {
            return Ok(ValidationResult::Valid);
        }
        let msg = if result.message.is_empty() {
            "projection mismatch".to_string()
        } else {
            result.message
        };
        Ok(ValidationResult::Invalid(msg))
    }

    /// Reports whether Lean-backed checks can plausibly run.
    ///
    /// With an explicit path, the path must name a regular file (symlinks are
    /// followed). Otherwise the check is delegated to `LeanRunner::is_available`.
    #[must_use]
    pub fn lean_available(&self) -> bool {
        match &self.lean_path {
            // `exists()` would also accept a directory, which cannot be the Lean
            // binary handed to `LeanRunner::with_binary_path`.
            Some(path) => std::path::Path::new(path).is_file(),
            None => LeanRunner::is_available(),
        }
    }
}
/// Structural equality on global types.
///
/// Participant and variable names are compared with `==`, and labels with
/// `labels_equal`. Branches are compared pairwise in order. Differing variants
/// are never equal.
fn global_types_equal(g1: &GlobalType, g2: &GlobalType) -> bool {
    match (g1, g2) {
        (GlobalType::End, GlobalType::End) => true,
        (GlobalType::Var(left), GlobalType::Var(right)) => left == right,
        (
            GlobalType::Mu { var: left_var, body: left_body },
            GlobalType::Mu { var: right_var, body: right_body },
        ) => left_var == right_var && global_types_equal(left_body, right_body),
        (
            GlobalType::Comm { sender: from_a, receiver: to_a, branches: arms_a },
            GlobalType::Comm { sender: from_b, receiver: to_b, branches: arms_b },
        ) => {
            if from_a != from_b || to_a != to_b || arms_a.len() != arms_b.len() {
                return false;
            }
            arms_a
                .iter()
                .zip(arms_b)
                .all(|((label_a, next_a), (label_b, next_b))| {
                    labels_equal(label_a, label_b) && global_types_equal(next_a, next_b)
                })
        }
        _ => false,
    }
}
/// Structural equality on local types.
///
/// Partners, payload annotations, and variable names are compared with `==`,
/// and labels with `labels_equal`. Branches are compared pairwise in order.
/// A `Send` never equals a `Recv`, and differing variants are never equal.
fn local_types_equal(lt1: &LocalTypeR, lt2: &LocalTypeR) -> bool {
    match (lt1, lt2) {
        (LocalTypeR::End, LocalTypeR::End) => true,
        (LocalTypeR::Var(left), LocalTypeR::Var(right)) => left == right,
        (
            LocalTypeR::Mu { var: left_var, body: left_body },
            LocalTypeR::Mu { var: right_var, body: right_body },
        ) => left_var == right_var && local_types_equal(left_body, right_body),
        // Send and Recv have the same shape, so one arm handles both; a
        // Send/Recv pairing does not match here and falls through to `_`.
        (
            LocalTypeR::Send { partner: peer_a, branches: arms_a },
            LocalTypeR::Send { partner: peer_b, branches: arms_b },
        )
        | (
            LocalTypeR::Recv { partner: peer_a, branches: arms_a },
            LocalTypeR::Recv { partner: peer_b, branches: arms_b },
        ) => {
            if peer_a != peer_b || arms_a.len() != arms_b.len() {
                return false;
            }
            arms_a
                .iter()
                .zip(arms_b)
                .all(|((label_a, payload_a, next_a), (label_b, payload_b, next_b))| {
                    labels_equal(label_a, label_b)
                        && payload_a == payload_b
                        && local_types_equal(next_a, next_b)
                })
        }
        _ => false,
    }
}
/// Labels are equal when both their `name` and their `sort` match.
fn labels_equal(l1: &telltale_types::Label, l2: &telltale_types::Label) -> bool {
    (&l1.name, &l1.sort) == (&l2.name, &l2.sort)
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use telltale_types::{Label, PayloadSort, ValType};

    /// Lean-style JSON for a projection that sends one label of the given
    /// name/sort to `B` and then ends.
    fn lean_send_to_b(label_name: &str, label_sort: &str) -> Value {
        json!({
            "kind": "send",
            "partner": "B",
            "branches": [{
                "label": { "name": label_name, "sort": label_sort },
                "continuation": { "kind": "end" }
            }]
        })
    }

    #[test]
    fn test_global_roundtrip_valid() {
        let single_message =
            GlobalType::comm("A", "B", vec![(Label::new("msg"), GlobalType::End)]);
        assert!(Validator::new()
            .validate_global_roundtrip(&single_message)
            .is_valid());
    }

    #[test]
    fn test_local_roundtrip_valid() {
        let send_hello = LocalTypeR::send("B", Label::new("hello"), LocalTypeR::End);
        assert!(Validator::new()
            .validate_local_roundtrip(&send_hello)
            .is_valid());
    }

    #[test]
    fn test_recursive_roundtrip() {
        let ping_body =
            GlobalType::comm("A", "B", vec![(Label::new("ping"), GlobalType::var("X"))]);
        let ping_loop = GlobalType::mu("X", ping_body);
        assert!(Validator::new()
            .validate_global_roundtrip(&ping_loop)
            .is_valid());
    }

    #[test]
    fn test_compare_subtyping_match() {
        let outcome = Validator::new()
            .compare_subtyping(SubtypingDecision::IsSubtype, SubtypingDecision::IsSubtype);
        assert!(outcome.is_valid());
    }

    #[test]
    fn test_compare_subtyping_mismatch() {
        let outcome = Validator::new()
            .compare_subtyping(SubtypingDecision::IsSubtype, SubtypingDecision::NotSubtype);
        assert!(outcome.is_invalid());
    }

    #[test]
    fn test_compare_projection() {
        let rust_side = LocalTypeR::send("B", Label::new("msg"), LocalTypeR::End);
        let outcome = Validator::new()
            .compare_projection(&rust_side, &lean_send_to_b("msg", "unit"))
            .unwrap();
        assert!(outcome.is_valid());
    }

    #[test]
    fn test_compare_projection_rejects_payload_annotation_mismatch() {
        // The Rust side carries an explicit `Some(ValType::Nat)` payload annotation;
        // the Lean JSON has no payload-type field, so a mismatch is expected.
        let annotated = LocalTypeR::Send {
            partner: "B".to_string(),
            branches: vec![(
                Label::with_sort("msg", PayloadSort::Nat),
                Some(ValType::Nat),
                LocalTypeR::End,
            )],
        };
        let outcome = Validator::new()
            .compare_projection(&annotated, &lean_send_to_b("msg", "nat"))
            .unwrap();
        assert!(outcome.is_invalid());
    }
}