1use crate::export::{global_to_json, local_to_json};
7use crate::import::{json_to_global, json_to_local, ImportError};
8use crate::runner::LeanRunner;
9use serde_json::Value;
10use telltale_types::{GlobalType, LocalTypeR};
11use thiserror::Error;
12
/// Outcome of a single validation check.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ValidationResult {
    /// The compared values agree.
    Valid,
    /// The check completed but the compared values disagree; carries a description.
    Invalid(String),
    /// The check could not be completed (e.g. exported JSON failed to parse back);
    /// carries the error text.
    Error(String),
}

impl ValidationResult {
    /// Returns `true` only for [`ValidationResult::Valid`].
    #[must_use]
    pub fn is_valid(&self) -> bool {
        matches!(self, ValidationResult::Valid)
    }

    /// Returns `true` only for [`ValidationResult::Invalid`].
    #[must_use]
    pub fn is_invalid(&self) -> bool {
        matches!(self, ValidationResult::Invalid(_))
    }

    /// Returns `true` only for [`ValidationResult::Error`].
    ///
    /// Completes the set of predicates so callers can distinguish "could not check"
    /// from "checked and disagreed" without matching on the enum.
    #[must_use]
    pub fn is_error(&self) -> bool {
        matches!(self, ValidationResult::Error(_))
    }
}
35
/// Verdict of a subtyping check, as reported by either the Rust or the Lean side.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SubtypingDecision {
    /// The subtyping relation holds.
    IsSubtype,
    /// The subtyping relation does not hold.
    NotSubtype,
}

/// Maps a boolean verdict onto a decision: `true` means the relation holds.
impl From<bool> for SubtypingDecision {
    fn from(holds: bool) -> Self {
        match holds {
            true => Self::IsSubtype,
            false => Self::NotSubtype,
        }
    }
}
54
55#[derive(Debug, Error)]
57pub enum ValidateError {
58 #[error("Import error: {0}")]
59 Import(#[from] ImportError),
60
61 #[error("Structure mismatch: {0}")]
62 StructureMismatch(String),
63
64 #[error("Lean execution failed: {0}")]
65 LeanExecutionFailed(String),
66}
67
/// Cross-checks Rust-side session-type results against their JSON round trips
/// and against the Lean reference implementation.
///
/// By default the Lean binary is located by `LeanRunner` itself; call
/// `with_lean_path` to point at a specific binary instead.
#[derive(Debug, Clone)]
pub struct Validator {
    /// Explicit path to the Lean binary; `None` defers to `LeanRunner`'s own lookup.
    lean_path: Option<String>,
}
73
74impl Default for Validator {
75 fn default() -> Self {
76 Self::new()
77 }
78}
79
80impl Validator {
81 pub fn new() -> Self {
83 Self { lean_path: None }
84 }
85
86 #[must_use]
88 pub fn with_lean_path(mut self, path: impl Into<String>) -> Self {
89 self.lean_path = Some(path.into());
90 self
91 }
92
93 pub fn validate_global_roundtrip(&self, g: &GlobalType) -> ValidationResult {
95 let json = global_to_json(g);
96 match json_to_global(&json) {
97 Ok(parsed) => {
98 if global_types_equal(g, &parsed) {
99 ValidationResult::Valid
100 } else {
101 ValidationResult::Invalid("Round-trip produced different structure".to_string())
102 }
103 }
104 Err(e) => ValidationResult::Error(format!("Parse error: {}", e)),
105 }
106 }
107
108 pub fn validate_local_roundtrip(&self, lt: &LocalTypeR) -> ValidationResult {
110 let json = local_to_json(lt);
111 match json_to_local(&json) {
112 Ok(parsed) => {
113 if local_types_equal(lt, &parsed) {
114 ValidationResult::Valid
115 } else {
116 ValidationResult::Invalid("Round-trip produced different structure".to_string())
117 }
118 }
119 Err(e) => ValidationResult::Error(format!("Parse error: {}", e)),
120 }
121 }
122
123 pub fn compare_projection(
125 &self,
126 rust_result: &LocalTypeR,
127 lean_json: &Value,
128 ) -> Result<ValidationResult, ValidateError> {
129 let lean_result = json_to_local(lean_json)?;
130
131 if local_types_equal(rust_result, &lean_result) {
132 Ok(ValidationResult::Valid)
133 } else {
134 Ok(ValidationResult::Invalid(format!(
135 "Projection mismatch:\n Rust: {:?}\n Lean: {:?}",
136 rust_result, lean_result
137 )))
138 }
139 }
140
141 pub fn compare_subtyping(
143 &self,
144 rust_result: SubtypingDecision,
145 lean_result: SubtypingDecision,
146 ) -> ValidationResult {
147 if rust_result == lean_result {
148 ValidationResult::Valid
149 } else {
150 ValidationResult::Invalid(format!(
151 "Subtyping mismatch: Rust={:?}, Lean={:?}",
152 rust_result, lean_result
153 ))
154 }
155 }
156
157 pub fn validate_projection_with_lean(
171 &self,
172 choreography_json: &Value,
173 program_json: &Value,
174 ) -> Result<ValidationResult, ValidateError> {
175 let runner = match &self.lean_path {
176 Some(path) => LeanRunner::with_binary_path(path)
177 .map_err(|e| ValidateError::LeanExecutionFailed(e.to_string()))?,
178 None => {
179 LeanRunner::new().map_err(|e| ValidateError::LeanExecutionFailed(e.to_string()))?
180 }
181 };
182
183 match runner.validate(choreography_json, program_json) {
184 Ok(result) => {
185 if result.success {
186 Ok(ValidationResult::Valid)
187 } else {
188 let msg = if result.message.is_empty() {
189 "projection mismatch".to_string()
190 } else {
191 result.message
192 };
193 Ok(ValidationResult::Invalid(msg))
194 }
195 }
196 Err(e) => Err(ValidateError::LeanExecutionFailed(e.to_string())),
197 }
198 }
199
200 #[must_use]
204 pub fn lean_available(&self) -> bool {
205 match &self.lean_path {
206 Some(path) => std::path::Path::new(path).exists(),
207 None => LeanRunner::is_available(),
208 }
209 }
210}
211
212fn global_types_equal(g1: &GlobalType, g2: &GlobalType) -> bool {
214 match (g1, g2) {
215 (GlobalType::End, GlobalType::End) => true,
216
217 (
218 GlobalType::Comm {
219 sender: s1,
220 receiver: r1,
221 branches: b1,
222 },
223 GlobalType::Comm {
224 sender: s2,
225 receiver: r2,
226 branches: b2,
227 },
228 ) => {
229 s1 == s2
230 && r1 == r2
231 && b1.len() == b2.len()
232 && b1
233 .iter()
234 .zip(b2.iter())
235 .all(|((l1, c1), (l2, c2))| labels_equal(l1, l2) && global_types_equal(c1, c2))
236 }
237
238 (GlobalType::Mu { var: v1, body: b1 }, GlobalType::Mu { var: v2, body: b2 }) => {
239 v1 == v2 && global_types_equal(b1, b2)
240 }
241
242 (GlobalType::Var(n1), GlobalType::Var(n2)) => n1 == n2,
243
244 _ => false,
245 }
246}
247
248fn local_types_equal(lt1: &LocalTypeR, lt2: &LocalTypeR) -> bool {
250 match (lt1, lt2) {
251 (LocalTypeR::End, LocalTypeR::End) => true,
252
253 (
254 LocalTypeR::Send {
255 partner: p1,
256 branches: b1,
257 },
258 LocalTypeR::Send {
259 partner: p2,
260 branches: b2,
261 },
262 ) => {
263 p1 == p2
264 && b1.len() == b2.len()
265 && b1
266 .iter()
267 .zip(b2.iter())
268 .all(|((l1, vt1, c1), (l2, vt2, c2))| {
269 labels_equal(l1, l2) && vt1 == vt2 && local_types_equal(c1, c2)
270 })
271 }
272
273 (
274 LocalTypeR::Recv {
275 partner: p1,
276 branches: b1,
277 },
278 LocalTypeR::Recv {
279 partner: p2,
280 branches: b2,
281 },
282 ) => {
283 p1 == p2
284 && b1.len() == b2.len()
285 && b1
286 .iter()
287 .zip(b2.iter())
288 .all(|((l1, vt1, c1), (l2, vt2, c2))| {
289 labels_equal(l1, l2) && vt1 == vt2 && local_types_equal(c1, c2)
290 })
291 }
292
293 (LocalTypeR::Mu { var: v1, body: b1 }, LocalTypeR::Mu { var: v2, body: b2 }) => {
294 v1 == v2 && local_types_equal(b1, b2)
295 }
296
297 (LocalTypeR::Var(n1), LocalTypeR::Var(n2)) => n1 == n2,
298
299 _ => false,
300 }
301}
302
303fn labels_equal(l1: &telltale_types::Label, l2: &telltale_types::Label) -> bool {
305 l1.name == l2.name && l1.sort == l2.sort
306}
307
#[cfg(test)]
mod tests {
    use super::*;
    use telltale_types::Label;
    use telltale_types::{PayloadSort, ValType};

    #[test]
    fn test_global_roundtrip_valid() {
        let validator = Validator::new();
        let g = GlobalType::comm("A", "B", vec![(Label::new("msg"), GlobalType::End)]);

        assert!(validator.validate_global_roundtrip(&g).is_valid());
    }

    #[test]
    fn test_global_roundtrip_multiple_branches() {
        let validator = Validator::new();
        let g = GlobalType::comm(
            "A",
            "B",
            vec![
                (Label::new("yes"), GlobalType::End),
                (Label::new("no"), GlobalType::End),
            ],
        );

        assert!(validator.validate_global_roundtrip(&g).is_valid());
    }

    #[test]
    fn test_local_roundtrip_valid() {
        let validator = Validator::new();
        let local = LocalTypeR::send("B", Label::new("hello"), LocalTypeR::End);

        assert!(validator.validate_local_roundtrip(&local).is_valid());
    }

    #[test]
    fn test_local_recv_roundtrip_valid() {
        let validator = Validator::new();
        let local = LocalTypeR::Recv {
            partner: "A".to_string(),
            branches: vec![(Label::new("hello"), None, LocalTypeR::End)],
        };

        assert!(validator.validate_local_roundtrip(&local).is_valid());
    }

    #[test]
    fn test_recursive_roundtrip() {
        let validator = Validator::new();
        let g = GlobalType::mu(
            "X",
            GlobalType::comm("A", "B", vec![(Label::new("ping"), GlobalType::var("X"))]),
        );

        assert!(validator.validate_global_roundtrip(&g).is_valid());
    }

    #[test]
    fn test_subtyping_decision_from_bool() {
        assert_eq!(SubtypingDecision::from(true), SubtypingDecision::IsSubtype);
        assert_eq!(SubtypingDecision::from(false), SubtypingDecision::NotSubtype);
    }

    #[test]
    fn test_compare_subtyping_match() {
        let validator = Validator::new();
        let result =
            validator.compare_subtyping(SubtypingDecision::IsSubtype, SubtypingDecision::IsSubtype);
        assert!(result.is_valid());
    }

    #[test]
    fn test_compare_subtyping_mismatch() {
        let validator = Validator::new();
        let result = validator
            .compare_subtyping(SubtypingDecision::IsSubtype, SubtypingDecision::NotSubtype);
        assert!(result.is_invalid());
    }

    #[test]
    fn test_compare_projection() {
        use serde_json::json;

        let validator = Validator::new();
        let rust_result = LocalTypeR::send("B", Label::new("msg"), LocalTypeR::End);
        let lean_json = json!({
            "kind": "send",
            "partner": "B",
            "branches": [{
                "label": { "name": "msg", "sort": "unit" },
                "continuation": { "kind": "end" }
            }]
        });

        let result = validator
            .compare_projection(&rust_result, &lean_json)
            .unwrap();
        assert!(result.is_valid());
    }

    #[test]
    fn test_compare_projection_rejects_partner_mismatch() {
        use serde_json::json;

        let validator = Validator::new();
        let rust_result = LocalTypeR::send("B", Label::new("msg"), LocalTypeR::End);
        let lean_json = json!({
            "kind": "send",
            "partner": "C",
            "branches": [{
                "label": { "name": "msg", "sort": "unit" },
                "continuation": { "kind": "end" }
            }]
        });

        let result = validator
            .compare_projection(&rust_result, &lean_json)
            .unwrap();
        assert!(result.is_invalid());
    }

    #[test]
    fn test_compare_projection_rejects_payload_annotation_mismatch() {
        use serde_json::json;

        let validator = Validator::new();
        let rust_result = LocalTypeR::Send {
            partner: "B".to_string(),
            branches: vec![(
                Label::with_sort("msg", PayloadSort::Nat),
                Some(ValType::Nat),
                LocalTypeR::End,
            )],
        };
        let lean_json = json!({
            "kind": "send",
            "partner": "B",
            "branches": [{
                "label": { "name": "msg", "sort": "nat" },
                "continuation": { "kind": "end" }
            }]
        });

        let result = validator
            .compare_projection(&rust_result, &lean_json)
            .unwrap();
        assert!(result.is_invalid());
    }

    #[test]
    fn test_lean_available_false_for_missing_path() {
        let validator = Validator::new().with_lean_path("/nonexistent/telltale/lean/binary");
        assert!(!validator.lean_available());
    }
}