1use std::collections::HashMap;
31
32use crate::term::{Literal, Term};
33
/// Native Rust mirror of a term encoded in the kernel `Term` language with the
/// `SLit` / `SVar` / `SName` / `SApp` constructors.
///
/// Decoded by `term_to_sterm` and encoded back by `sterm_to_term`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum STerm {
    /// Integer literal, encoded as `SLit <Int>`.
    Lit(i64),
    /// Variable identified by an integer index, encoded as `SVar <Int>`.
    /// Hypotheses of the form `SVar i = rhs` bind it in a [`Substitution`].
    Var(i64),
    /// Named constant or operator (e.g. `"add"`, `"Eq"`), encoded as `SName <Text>`.
    Name(String),
    /// Application of a function term to a single argument, encoded as `SApp f a`.
    /// Multi-argument operators are nested, e.g. `App(App(Name("add"), x), y)`.
    App(Box<STerm>, Box<STerm>),
}
60
/// Variable bindings used by the simplifier: maps the index of an `STerm::Var`
/// to the term that variable should be replaced with.
///
/// Built by `decompose_goal` from hypotheses of the form `SVar i = rhs`.
pub type Substitution = HashMap<i64, STerm>;
67
68fn term_to_sterm(term: &Term) -> Option<STerm> {
74 if let Some(n) = extract_slit(term) {
76 return Some(STerm::Lit(n));
77 }
78
79 if let Some(i) = extract_svar(term) {
81 return Some(STerm::Var(i));
82 }
83
84 if let Some(s) = extract_sname(term) {
86 return Some(STerm::Name(s));
87 }
88
89 if let Some((f, a)) = extract_sapp(term) {
91 let sf = term_to_sterm(&f)?;
92 let sa = term_to_sterm(&a)?;
93 return Some(STerm::App(Box::new(sf), Box::new(sa)));
94 }
95
96 None
97}
98
99fn sterm_to_term(st: &STerm) -> Term {
101 match st {
102 STerm::Lit(n) => Term::App(
103 Box::new(Term::Global("SLit".to_string())),
104 Box::new(Term::Lit(Literal::Int(*n))),
105 ),
106 STerm::Var(i) => Term::App(
107 Box::new(Term::Global("SVar".to_string())),
108 Box::new(Term::Lit(Literal::Int(*i))),
109 ),
110 STerm::Name(s) => Term::App(
111 Box::new(Term::Global("SName".to_string())),
112 Box::new(Term::Lit(Literal::Text(s.clone()))),
113 ),
114 STerm::App(f, a) => Term::App(
115 Box::new(Term::App(
116 Box::new(Term::Global("SApp".to_string())),
117 Box::new(sterm_to_term(f)),
118 )),
119 Box::new(sterm_to_term(a)),
120 ),
121 }
122}
123
124fn simplify_sterm(term: &STerm, subst: &Substitution, fuel: usize) -> STerm {
131 if fuel == 0 {
132 return term.clone();
133 }
134
135 match term {
136 STerm::Var(i) => {
138 if let Some(replacement) = subst.get(i) {
139 simplify_sterm(replacement, subst, fuel - 1)
141 } else {
142 term.clone()
143 }
144 }
145
146 STerm::Lit(_) => term.clone(),
148 STerm::Name(_) => term.clone(),
149
150 STerm::App(f, a) => {
152 let sf = simplify_sterm(f, subst, fuel - 1);
153 let sa = simplify_sterm(a, subst, fuel - 1);
154
155 if let Some(result) = try_arithmetic(&sf, &sa) {
157 return simplify_sterm(&result, subst, fuel - 1);
158 }
159
160 STerm::App(Box::new(sf), Box::new(sa))
161 }
162 }
163}
164
165fn try_arithmetic(func: &STerm, arg: &STerm) -> Option<STerm> {
168 if let STerm::App(op_box, x_box) = func {
172 if let STerm::Name(op) = op_box.as_ref() {
173 if let (STerm::Lit(x), STerm::Lit(y)) = (x_box.as_ref(), arg) {
174 let result = match op.as_str() {
175 "add" => x.checked_add(*y)?,
176 "sub" => x.checked_sub(*y)?,
177 "mul" => x.checked_mul(*y)?,
178 "div" if *y != 0 => x.checked_div(*y)?,
179 "mod" if *y != 0 => x.checked_rem(*y)?,
180 _ => return None,
181 };
182 return Some(STerm::Lit(result));
183 }
184 }
185 }
186 None
187}
188
189fn decompose_goal(goal: &Term) -> (Substitution, Term) {
197 let mut subst = HashMap::new();
198 let mut current = goal.clone();
199
200 while let Some((hyp, rest)) = extract_implication(¤t) {
202 if let Some((lhs, rhs)) = extract_equality(&hyp) {
204 if let Some(st_lhs) = term_to_sterm(&lhs) {
206 if let STerm::Var(i) = st_lhs {
207 if let Some(st_rhs) = term_to_sterm(&rhs) {
209 subst.insert(i, st_rhs);
210 }
211 }
212 }
213 }
214 current = rest;
215 }
216
217 (subst, current)
218}
219
220pub fn check_goal(goal: &Term) -> bool {
236 let (subst, conclusion) = decompose_goal(goal);
237
238 let (lhs, rhs) = match extract_equality(&conclusion) {
240 Some(eq) => eq,
241 None => return false,
242 };
243
244 let st_lhs = match term_to_sterm(&lhs) {
246 Some(t) => t,
247 None => return false,
248 };
249
250 let st_rhs = match term_to_sterm(&rhs) {
251 Some(t) => t,
252 None => return false,
253 };
254
255 const FUEL: usize = 1000;
257 let simp_lhs = simplify_sterm(&st_lhs, &subst, FUEL);
258 let simp_rhs = simplify_sterm(&st_rhs, &subst, FUEL);
259
260 simp_lhs == simp_rhs
262}
263
264fn extract_slit(term: &Term) -> Option<i64> {
270 if let Term::App(ctor, arg) = term {
271 if let Term::Global(name) = ctor.as_ref() {
272 if name == "SLit" {
273 if let Term::Lit(Literal::Int(n)) = arg.as_ref() {
274 return Some(*n);
275 }
276 }
277 }
278 }
279 None
280}
281
282fn extract_svar(term: &Term) -> Option<i64> {
284 if let Term::App(ctor, arg) = term {
285 if let Term::Global(name) = ctor.as_ref() {
286 if name == "SVar" {
287 if let Term::Lit(Literal::Int(i)) = arg.as_ref() {
288 return Some(*i);
289 }
290 }
291 }
292 }
293 None
294}
295
296fn extract_sname(term: &Term) -> Option<String> {
298 if let Term::App(ctor, arg) = term {
299 if let Term::Global(name) = ctor.as_ref() {
300 if name == "SName" {
301 if let Term::Lit(Literal::Text(s)) = arg.as_ref() {
302 return Some(s.clone());
303 }
304 }
305 }
306 }
307 None
308}
309
310fn extract_sapp(term: &Term) -> Option<(Term, Term)> {
312 if let Term::App(outer, arg) = term {
313 if let Term::App(sapp, func) = outer.as_ref() {
314 if let Term::Global(ctor) = sapp.as_ref() {
315 if ctor == "SApp" {
316 return Some((func.as_ref().clone(), arg.as_ref().clone()));
317 }
318 }
319 }
320 }
321 None
322}
323
324fn extract_implication(term: &Term) -> Option<(Term, Term)> {
326 if let Some((op, hyp, concl)) = extract_binary_app(term) {
327 if op == "implies" {
328 return Some((hyp, concl));
329 }
330 }
331 None
332}
333
334fn extract_equality(term: &Term) -> Option<(Term, Term)> {
337 if let Some((op, lhs, rhs)) = extract_binary_app(term) {
339 if op == "Eq" || op == "eq" {
340 return Some((lhs, rhs));
341 }
342 }
343
344 if let Some((lhs, rhs)) = extract_ternary_eq(term) {
346 return Some((lhs, rhs));
347 }
348
349 None
350}
351
352fn extract_ternary_eq(term: &Term) -> Option<(Term, Term)> {
354 let (func, rhs) = extract_sapp(term)?;
356
357 let (func2, lhs) = extract_sapp(&func)?;
359
360 let (eq_name, _ty) = extract_sapp(&func2)?;
362
363 let name = extract_sname(&eq_name)?;
365 if name == "Eq" {
366 return Some((lhs, rhs));
367 }
368
369 None
370}
371
372fn extract_binary_app(term: &Term) -> Option<(String, Term, Term)> {
374 if let Term::App(outer, b) = term {
375 if let Term::App(sapp_outer, inner) = outer.as_ref() {
376 if let Term::Global(ctor) = sapp_outer.as_ref() {
377 if ctor == "SApp" {
378 if let Term::App(partial, a) = inner.as_ref() {
379 if let Term::App(sapp_inner, op_term) = partial.as_ref() {
380 if let Term::Global(ctor2) = sapp_inner.as_ref() {
381 if ctor2 == "SApp" {
382 if let Some(op) = extract_sname(op_term) {
383 return Some((
384 op,
385 a.as_ref().clone(),
386 b.as_ref().clone(),
387 ));
388 }
389 }
390 }
391 }
392 }
393 }
394 }
395 }
396 }
397 None
398}
399
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the encoding `SName s`.
    fn make_sname(s: &str) -> Term {
        Term::App(
            Box::new(Term::Global("SName".to_string())),
            Box::new(Term::Lit(Literal::Text(s.to_string()))),
        )
    }

    /// Builds the encoding `SVar i`.
    fn make_svar(i: i64) -> Term {
        Term::App(
            Box::new(Term::Global("SVar".to_string())),
            Box::new(Term::Lit(Literal::Int(i))),
        )
    }

    /// Builds the encoding `SLit n`.
    fn make_slit(n: i64) -> Term {
        Term::App(
            Box::new(Term::Global("SLit".to_string())),
            Box::new(Term::Lit(Literal::Int(n))),
        )
    }

    /// Builds the encoding `SApp f a`.
    fn make_sapp(f: Term, a: Term) -> Term {
        Term::App(
            Box::new(Term::App(
                Box::new(Term::Global("SApp".to_string())),
                Box::new(f),
            )),
            Box::new(a),
        )
    }

    /// Builds the partially applied operator `App(Name(op), Lit(x))`.
    fn partial_op(op: &str, x: i64) -> STerm {
        STerm::App(Box::new(STerm::Name(op.to_string())), Box::new(STerm::Lit(x)))
    }

    #[test]
    fn test_term_to_sterm_lit() {
        let term = make_slit(42);
        let result = term_to_sterm(&term);
        assert_eq!(result, Some(STerm::Lit(42)));
    }

    #[test]
    fn test_term_to_sterm_var() {
        let term = make_svar(0);
        let result = term_to_sterm(&term);
        assert_eq!(result, Some(STerm::Var(0)));
    }

    #[test]
    fn test_term_to_sterm_name() {
        let term = make_sname("add");
        let result = term_to_sterm(&term);
        assert_eq!(result, Some(STerm::Name("add".to_string())));
    }

    #[test]
    fn test_term_to_sterm_app() {
        let add_2 = make_sapp(make_sname("add"), make_slit(2));
        let add_2_3 = make_sapp(add_2, make_slit(3));
        let result = term_to_sterm(&add_2_3);

        let expected = STerm::App(
            Box::new(STerm::App(
                Box::new(STerm::Name("add".to_string())),
                Box::new(STerm::Lit(2)),
            )),
            Box::new(STerm::Lit(3)),
        );
        assert_eq!(result, Some(expected));
    }

    #[test]
    fn test_term_to_sterm_rejects_malformed() {
        // A bare constructor with no argument is not an encoding.
        assert_eq!(term_to_sterm(&Term::Global("SLit".to_string())), None);
        // `SLit` applied to a text literal is malformed.
        let bad = Term::App(
            Box::new(Term::Global("SLit".to_string())),
            Box::new(Term::Lit(Literal::Text("x".to_string()))),
        );
        assert_eq!(term_to_sterm(&bad), None);
    }

    #[test]
    fn test_sterm_round_trip() {
        let st = STerm::App(
            Box::new(STerm::App(
                Box::new(STerm::Name("add".to_string())),
                Box::new(STerm::Var(3)),
            )),
            Box::new(STerm::Lit(-4)),
        );
        assert_eq!(term_to_sterm(&sterm_to_term(&st)), Some(st));
    }

    #[test]
    fn test_arithmetic_add() {
        let func = STerm::App(
            Box::new(STerm::Name("add".to_string())),
            Box::new(STerm::Lit(2)),
        );
        let arg = STerm::Lit(3);
        let result = try_arithmetic(&func, &arg);
        assert_eq!(result, Some(STerm::Lit(5)));
    }

    #[test]
    fn test_arithmetic_mul() {
        let func = STerm::App(
            Box::new(STerm::Name("mul".to_string())),
            Box::new(STerm::Lit(4)),
        );
        let arg = STerm::Lit(5);
        let result = try_arithmetic(&func, &arg);
        assert_eq!(result, Some(STerm::Lit(20)));
    }

    #[test]
    fn test_arithmetic_sub() {
        let func = STerm::App(
            Box::new(STerm::Name("sub".to_string())),
            Box::new(STerm::Lit(10)),
        );
        let arg = STerm::Lit(3);
        let result = try_arithmetic(&func, &arg);
        assert_eq!(result, Some(STerm::Lit(7)));
    }

    #[test]
    fn test_arithmetic_div_and_mod_truncate() {
        // Pins Rust's truncating division / sign-of-dividend remainder.
        assert_eq!(try_arithmetic(&partial_op("div", -7), &STerm::Lit(2)), Some(STerm::Lit(-3)));
        assert_eq!(try_arithmetic(&partial_op("mod", -7), &STerm::Lit(2)), Some(STerm::Lit(-1)));
    }

    #[test]
    fn test_arithmetic_division_by_zero_is_not_folded() {
        for op in ["div", "mod"].iter() {
            assert_eq!(try_arithmetic(&partial_op(op, 1), &STerm::Lit(0)), None);
        }
    }

    #[test]
    fn test_arithmetic_overflow_is_not_folded() {
        assert_eq!(try_arithmetic(&partial_op("add", i64::MAX), &STerm::Lit(1)), None);
        assert_eq!(try_arithmetic(&partial_op("div", i64::MIN), &STerm::Lit(-1)), None);
    }

    #[test]
    fn test_arithmetic_unknown_operator() {
        assert_eq!(try_arithmetic(&partial_op("pow", 2), &STerm::Lit(3)), None);
    }

    #[test]
    fn test_simplify_constant_addition() {
        let term = STerm::App(
            Box::new(STerm::App(
                Box::new(STerm::Name("add".to_string())),
                Box::new(STerm::Lit(2)),
            )),
            Box::new(STerm::Lit(3)),
        );
        let result = simplify_sterm(&term, &HashMap::new(), 100);
        assert_eq!(result, STerm::Lit(5));
    }

    #[test]
    fn test_simplify_nested_arithmetic() {
        let one_plus_one = STerm::App(
            Box::new(STerm::App(
                Box::new(STerm::Name("add".to_string())),
                Box::new(STerm::Lit(1)),
            )),
            Box::new(STerm::Lit(1)),
        );
        let term = STerm::App(
            Box::new(STerm::App(
                Box::new(STerm::Name("mul".to_string())),
                Box::new(one_plus_one),
            )),
            Box::new(STerm::Lit(3)),
        );
        let result = simplify_sterm(&term, &HashMap::new(), 100);
        assert_eq!(result, STerm::Lit(6));
    }

    #[test]
    fn test_simplify_with_substitution() {
        let x_plus_1 = STerm::App(
            Box::new(STerm::App(
                Box::new(STerm::Name("add".to_string())),
                Box::new(STerm::Var(0)),
            )),
            Box::new(STerm::Lit(1)),
        );
        let mut subst = HashMap::new();
        subst.insert(0, STerm::Lit(0));

        let result = simplify_sterm(&x_plus_1, &subst, 100);
        assert_eq!(result, STerm::Lit(1));
    }

    #[test]
    fn test_check_goal_reflexive() {
        let x = make_svar(0);
        let goal = make_sapp(make_sapp(make_sname("Eq"), x.clone()), x);
        assert!(check_goal(&goal), "simp should prove x = x");
    }

    #[test]
    fn test_check_goal_constant() {
        let add_2_3 = make_sapp(make_sapp(make_sname("add"), make_slit(2)), make_slit(3));
        let goal = make_sapp(make_sapp(make_sname("Eq"), add_2_3), make_slit(5));
        assert!(check_goal(&goal), "simp should prove 2+3 = 5");
    }

    #[test]
    fn test_check_goal_typed_equality() {
        let add_2_3 = make_sapp(make_sapp(make_sname("add"), make_slit(2)), make_slit(3));
        let eq_int = make_sapp(make_sname("Eq"), make_sname("Int"));
        let goal = make_sapp(make_sapp(eq_int, add_2_3), make_slit(5));
        assert!(check_goal(&goal), "simp should prove Eq Int (2+3) 5");
    }

    #[test]
    fn test_check_goal_with_hypothesis() {
        let x = make_svar(0);
        let zero = make_slit(0);
        let one = make_slit(1);

        let x_plus_1 = make_sapp(make_sapp(make_sname("add"), x.clone()), one.clone());
        let hyp = make_sapp(make_sapp(make_sname("Eq"), x), zero);
        let concl = make_sapp(make_sapp(make_sname("Eq"), x_plus_1), one);
        let goal = make_sapp(make_sapp(make_sname("implies"), hyp), concl);

        assert!(check_goal(&goal), "simp should prove x=0 -> x+1=1");
    }

    #[test]
    fn test_check_goal_chained_hypotheses() {
        // x = 1 -> y = x -> y + y = 2
        let hyp_x = make_sapp(make_sapp(make_sname("Eq"), make_svar(0)), make_slit(1));
        let hyp_y = make_sapp(make_sapp(make_sname("Eq"), make_svar(1)), make_svar(0));
        let y_plus_y = make_sapp(make_sapp(make_sname("add"), make_svar(1)), make_svar(1));
        let concl = make_sapp(make_sapp(make_sname("Eq"), y_plus_y), make_slit(2));
        let inner = make_sapp(make_sapp(make_sname("implies"), hyp_y), concl);
        let goal = make_sapp(make_sapp(make_sname("implies"), hyp_x), inner);
        assert!(check_goal(&goal), "simp should prove x=1 -> y=x -> y+y=2");
    }

    #[test]
    fn test_check_goal_ignores_non_variable_hypothesis() {
        // `1 = 0 -> 1 + 1 = 1` must not be proved by rewriting with the hypothesis.
        let hyp = make_sapp(make_sapp(make_sname("Eq"), make_slit(1)), make_slit(0));
        let one_plus_one = make_sapp(make_sapp(make_sname("add"), make_slit(1)), make_slit(1));
        let concl = make_sapp(make_sapp(make_sname("Eq"), one_plus_one), make_slit(1));
        let goal = make_sapp(make_sapp(make_sname("implies"), hyp), concl);
        assert!(!check_goal(&goal));
    }

    #[test]
    fn test_check_goal_rejects_non_equality() {
        assert!(!check_goal(&make_sname("True")));
        let lt = make_sapp(make_sapp(make_sname("lt"), make_slit(2)), make_slit(3));
        assert!(!check_goal(&lt));
    }

    #[test]
    fn test_check_goal_false_equality() {
        let goal = make_sapp(make_sapp(make_sname("Eq"), make_slit(2)), make_slit(3));
        assert!(!check_goal(&goal), "simp should NOT prove 2 = 3");
    }
}