Skip to main content

logicaffeine_kernel/
simp.rs

//! Simplifier Tactic
//!
//! Normalizes equality goals by rewriting with the equational hypotheses found
//! in the goal itself and by evaluating integer arithmetic. This is a
//! general-purpose simplification tactic.
//!
//! # Algorithm
//!
//! The simp tactic works in four steps:
//! 1. **Extract hypotheses**: Peel nested implications and turn each hypothesis
//!    of the form `x = t` (with `x` a variable) into a rewrite `x ↦ t`
//! 2. **Simplify LHS**: Apply the rewrites and arithmetic bottom-up
//! 3. **Simplify RHS**: Apply the rewrites and arithmetic bottom-up
//! 4. **Compare**: The goal succeeds if the simplified LHS is syntactically
//!    equal to the simplified RHS
//!
//! # Supported Simplifications
//!
//! - **Reflexive equalities**: `Eq a a` succeeds because both sides simplify
//!   to the same term
//! - **Constant folding**: `Eq (add 2 3) 5` simplifies to `Eq 5 5`
//! - **Hypothesis substitution**: Given `x = 0`, `x + 1` becomes `0 + 1` then `1`
//!
//! # Arithmetic Operations
//!
//! Supports `add`, `sub`, `mul`, `div`, `mod` on integer literals, using
//! checked 64-bit arithmetic. Arithmetic on non-literals, division or
//! remainder by zero, and operations that would overflow are left
//! unevaluated.
//!
//! # Fuel Limit
//!
//! Simplification carries a fuel counter (1000 by default) that is decremented
//! on every recursive step and bounds the recursion depth, so cyclic rewrites
//! cannot recurse forever. When the fuel runs out, the remaining subterm is
//! returned unsimplified.

30use std::collections::HashMap;
31
32use crate::term::{Literal, Term};
33
// =============================================================================
// SIMPLIFIED SYNTAX REPRESENTATION
// =============================================================================

/// Compact, rewrite-friendly view of a `Syntax` term.
///
/// The kernel encodes object-level syntax as applications of the `SLit`,
/// `SVar`, `SName` and `SApp` constructors inside an ordinary [`Term`].
/// `STerm` mirrors exactly those four shapes so the simplifier can pattern
/// match on them and compare them structurally:
///
/// | Kernel encoding                    | `STerm`     |
/// |------------------------------------|-------------|
/// | `App(Global("SLit"), Lit(n))`      | `Lit(n)`    |
/// | `App(Global("SVar"), Lit(i))`      | `Var(i)`    |
/// | `App(Global("SName"), Lit(s))`     | `Name(s)`   |
/// | `App(App(Global("SApp"), f), a)`   | `App(f, a)` |
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum STerm {
    /// An integer constant (`SLit n`).
    Lit(i64),
    /// A De Bruijn-indexed variable (`SVar i`).
    Var(i64),
    /// A named constant or function symbol (`SName s`).
    Name(String),
    /// Application of a function to a single argument (`SApp f a`).
    App(Box<STerm>, Box<STerm>),
}

/// Rewrites harvested from hypotheses of the form `x = t`.
///
/// Keys are the De Bruijn indices of the rewritten variables; values are the
/// terms they stand for. The simplifier replaces each bound variable by its
/// value wherever it occurs.
pub type Substitution = HashMap<i64, STerm>;

68// =============================================================================
69// TERM CONVERSION
70// =============================================================================
71
72/// Convert a kernel Term (representing Syntax) to our simplified STerm
73fn term_to_sterm(term: &Term) -> Option<STerm> {
74    // SLit n
75    if let Some(n) = extract_slit(term) {
76        return Some(STerm::Lit(n));
77    }
78
79    // SVar i
80    if let Some(i) = extract_svar(term) {
81        return Some(STerm::Var(i));
82    }
83
84    // SName s
85    if let Some(s) = extract_sname(term) {
86        return Some(STerm::Name(s));
87    }
88
89    // SApp f a
90    if let Some((f, a)) = extract_sapp(term) {
91        let sf = term_to_sterm(&f)?;
92        let sa = term_to_sterm(&a)?;
93        return Some(STerm::App(Box::new(sf), Box::new(sa)));
94    }
95
96    None
97}
98
99/// Convert STerm back to kernel Term (Syntax encoding)
100fn sterm_to_term(st: &STerm) -> Term {
101    match st {
102        STerm::Lit(n) => Term::App(
103            Box::new(Term::Global("SLit".to_string())),
104            Box::new(Term::Lit(Literal::Int(*n))),
105        ),
106        STerm::Var(i) => Term::App(
107            Box::new(Term::Global("SVar".to_string())),
108            Box::new(Term::Lit(Literal::Int(*i))),
109        ),
110        STerm::Name(s) => Term::App(
111            Box::new(Term::Global("SName".to_string())),
112            Box::new(Term::Lit(Literal::Text(s.clone()))),
113        ),
114        STerm::App(f, a) => Term::App(
115            Box::new(Term::App(
116                Box::new(Term::Global("SApp".to_string())),
117                Box::new(sterm_to_term(f)),
118            )),
119            Box::new(sterm_to_term(a)),
120        ),
121    }
122}
123
124// =============================================================================
125// SIMPLIFICATION ENGINE
126// =============================================================================
127
128/// Simplify an STerm using the given substitution (from hypotheses)
129/// and arithmetic evaluation.
130fn simplify_sterm(term: &STerm, subst: &Substitution, fuel: usize) -> STerm {
131    if fuel == 0 {
132        return term.clone();
133    }
134
135    match term {
136        // Variables: apply substitution if bound
137        STerm::Var(i) => {
138            if let Some(replacement) = subst.get(i) {
139                // Re-simplify the replacement (may enable more rewrites)
140                simplify_sterm(replacement, subst, fuel - 1)
141            } else {
142                term.clone()
143            }
144        }
145
146        // Literals and names are already simplified
147        STerm::Lit(_) => term.clone(),
148        STerm::Name(_) => term.clone(),
149
150        // Applications: simplify children first, then try arithmetic
151        STerm::App(f, a) => {
152            let sf = simplify_sterm(f, subst, fuel - 1);
153            let sa = simplify_sterm(a, subst, fuel - 1);
154
155            // Try arithmetic simplification on the simplified application
156            if let Some(result) = try_arithmetic(&sf, &sa) {
157                return simplify_sterm(&result, subst, fuel - 1);
158            }
159
160            STerm::App(Box::new(sf), Box::new(sa))
161        }
162    }
163}
164
165/// Try to evaluate arithmetic operations on literals.
166/// Handles: add, sub, mul, div, mod
167fn try_arithmetic(func: &STerm, arg: &STerm) -> Option<STerm> {
168    // Pattern: (add x) y, (sub x) y, (mul x) y, etc.
169    // func = App(Name("add"), x)
170    // arg = y
171    if let STerm::App(op_box, x_box) = func {
172        if let STerm::Name(op) = op_box.as_ref() {
173            if let (STerm::Lit(x), STerm::Lit(y)) = (x_box.as_ref(), arg) {
174                let result = match op.as_str() {
175                    "add" => x.checked_add(*y)?,
176                    "sub" => x.checked_sub(*y)?,
177                    "mul" => x.checked_mul(*y)?,
178                    "div" if *y != 0 => x.checked_div(*y)?,
179                    "mod" if *y != 0 => x.checked_rem(*y)?,
180                    _ => return None,
181                };
182                return Some(STerm::Lit(result));
183            }
184        }
185    }
186    None
187}
188
189// =============================================================================
190// GOAL DECOMPOSITION
191// =============================================================================
192
193/// Extract hypotheses and conclusion from a goal.
194/// Handles nested implications: h1 -> h2 -> ... -> conclusion
195/// Returns (substitution from hypotheses, conclusion)
196fn decompose_goal(goal: &Term) -> (Substitution, Term) {
197    let mut subst = HashMap::new();
198    let mut current = goal.clone();
199
200    // Peel off nested implications
201    while let Some((hyp, rest)) = extract_implication(&current) {
202        // Extract equality from hypothesis
203        if let Some((lhs, rhs)) = extract_equality(&hyp) {
204            // Convert LHS to check if it's a variable
205            if let Some(st_lhs) = term_to_sterm(&lhs) {
206                if let STerm::Var(i) = st_lhs {
207                    // Variable on LHS: add substitution i → rhs
208                    if let Some(st_rhs) = term_to_sterm(&rhs) {
209                        subst.insert(i, st_rhs);
210                    }
211                }
212            }
213        }
214        current = rest;
215    }
216
217    (subst, current)
218}
219
220/// Check if a goal is provable by simplification.
221///
222/// This is the main entry point for the simp tactic. It extracts
223/// hypothesis equalities as rewrites, simplifies both sides of the
224/// conclusion equality, and checks for syntactic equality.
225///
226/// # Supported Goals
227///
228/// - Bare equalities: `Eq a b` where `a` simplifies to `b`
229/// - Implications: `(Eq x 0) -> (Eq (add x 1) 1)` using hypothesis as rewrite
230/// - Nested implications with multiple hypothesis rewrites
231///
232/// # Returns
233///
234/// `true` if the simplified LHS equals the simplified RHS, `false` otherwise.
235pub fn check_goal(goal: &Term) -> bool {
236    let (subst, conclusion) = decompose_goal(goal);
237
238    // Conclusion must be an equality
239    let (lhs, rhs) = match extract_equality(&conclusion) {
240        Some(eq) => eq,
241        None => return false,
242    };
243
244    // Convert to STerm
245    let st_lhs = match term_to_sterm(&lhs) {
246        Some(t) => t,
247        None => return false,
248    };
249
250    let st_rhs = match term_to_sterm(&rhs) {
251        Some(t) => t,
252        None => return false,
253    };
254
255    // Simplify both sides
256    const FUEL: usize = 1000;
257    let simp_lhs = simplify_sterm(&st_lhs, &subst, FUEL);
258    let simp_rhs = simplify_sterm(&st_rhs, &subst, FUEL);
259
260    // Check if they're equal
261    simp_lhs == simp_rhs
262}
263
264// =============================================================================
265// HELPER EXTRACTORS (same pattern as cc.rs)
266// =============================================================================
267
268/// Extract integer from SLit n
269fn extract_slit(term: &Term) -> Option<i64> {
270    if let Term::App(ctor, arg) = term {
271        if let Term::Global(name) = ctor.as_ref() {
272            if name == "SLit" {
273                if let Term::Lit(Literal::Int(n)) = arg.as_ref() {
274                    return Some(*n);
275                }
276            }
277        }
278    }
279    None
280}
281
282/// Extract variable index from SVar i
283fn extract_svar(term: &Term) -> Option<i64> {
284    if let Term::App(ctor, arg) = term {
285        if let Term::Global(name) = ctor.as_ref() {
286            if name == "SVar" {
287                if let Term::Lit(Literal::Int(i)) = arg.as_ref() {
288                    return Some(*i);
289                }
290            }
291        }
292    }
293    None
294}
295
296/// Extract name from SName "x"
297fn extract_sname(term: &Term) -> Option<String> {
298    if let Term::App(ctor, arg) = term {
299        if let Term::Global(name) = ctor.as_ref() {
300            if name == "SName" {
301                if let Term::Lit(Literal::Text(s)) = arg.as_ref() {
302                    return Some(s.clone());
303                }
304            }
305        }
306    }
307    None
308}
309
310/// Extract unary application: SApp f a
311fn extract_sapp(term: &Term) -> Option<(Term, Term)> {
312    if let Term::App(outer, arg) = term {
313        if let Term::App(sapp, func) = outer.as_ref() {
314            if let Term::Global(ctor) = sapp.as_ref() {
315                if ctor == "SApp" {
316                    return Some((func.as_ref().clone(), arg.as_ref().clone()));
317                }
318            }
319        }
320    }
321    None
322}
323
324/// Extract implication: SApp (SApp (SName "implies") hyp) concl
325fn extract_implication(term: &Term) -> Option<(Term, Term)> {
326    if let Some((op, hyp, concl)) = extract_binary_app(term) {
327        if op == "implies" {
328            return Some((hyp, concl));
329        }
330    }
331    None
332}
333
334/// Extract equality: SApp (SApp (SName "Eq") lhs) rhs
335/// Also handles: SApp (SApp (SApp (SName "Eq") ty) lhs) rhs
336fn extract_equality(term: &Term) -> Option<(Term, Term)> {
337    // Try binary Eq first (no type annotation)
338    if let Some((op, lhs, rhs)) = extract_binary_app(term) {
339        if op == "Eq" || op == "eq" {
340            return Some((lhs, rhs));
341        }
342    }
343
344    // Try ternary Eq (with type annotation): (Eq T) lhs rhs
345    if let Some((lhs, rhs)) = extract_ternary_eq(term) {
346        return Some((lhs, rhs));
347    }
348
349    None
350}
351
352/// Extract ternary equality: SApp (SApp (SApp (SName "Eq") ty) lhs) rhs
353fn extract_ternary_eq(term: &Term) -> Option<(Term, Term)> {
354    // term = SApp func rhs, where func = SApp (SApp (SName "Eq") ty) lhs
355    let (func, rhs) = extract_sapp(term)?;
356
357    // func = SApp func2 lhs, where func2 = SApp (SName "Eq") ty
358    let (func2, lhs) = extract_sapp(&func)?;
359
360    // func2 = SApp eq_name ty, where eq_name = SName "Eq"
361    let (eq_name, _ty) = extract_sapp(&func2)?;
362
363    // Check that eq_name is SName "Eq"
364    let name = extract_sname(&eq_name)?;
365    if name == "Eq" {
366        return Some((lhs, rhs));
367    }
368
369    None
370}
371
372/// Extract binary application: SApp (SApp (SName "op") a) b
373fn extract_binary_app(term: &Term) -> Option<(String, Term, Term)> {
374    if let Term::App(outer, b) = term {
375        if let Term::App(sapp_outer, inner) = outer.as_ref() {
376            if let Term::Global(ctor) = sapp_outer.as_ref() {
377                if ctor == "SApp" {
378                    if let Term::App(partial, a) = inner.as_ref() {
379                        if let Term::App(sapp_inner, op_term) = partial.as_ref() {
380                            if let Term::Global(ctor2) = sapp_inner.as_ref() {
381                                if ctor2 == "SApp" {
382                                    if let Some(op) = extract_sname(op_term) {
383                                        return Some((
384                                            op,
385                                            a.as_ref().clone(),
386                                            b.as_ref().clone(),
387                                        ));
388                                    }
389                                }
390                            }
391                        }
392                    }
393                }
394            }
395        }
396    }
397    None
398}
399
// =============================================================================
// UNIT TESTS
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    /// Helper to build SName "s"
    fn make_sname(s: &str) -> Term {
        Term::App(
            Box::new(Term::Global("SName".to_string())),
            Box::new(Term::Lit(Literal::Text(s.to_string()))),
        )
    }

    /// Helper to build SVar i
    fn make_svar(i: i64) -> Term {
        Term::App(
            Box::new(Term::Global("SVar".to_string())),
            Box::new(Term::Lit(Literal::Int(i))),
        )
    }

    /// Helper to build SLit n
    fn make_slit(n: i64) -> Term {
        Term::App(
            Box::new(Term::Global("SLit".to_string())),
            Box::new(Term::Lit(Literal::Int(n))),
        )
    }

    /// Helper to build SApp f a
    fn make_sapp(f: Term, a: Term) -> Term {
        Term::App(
            Box::new(Term::App(
                Box::new(Term::Global("SApp".to_string())),
                Box::new(f),
            )),
            Box::new(a),
        )
    }

    /// Helper to build the curried binary application `op a b`
    fn make_binop(op: &str, a: Term, b: Term) -> Term {
        make_sapp(make_sapp(make_sname(op), a), b)
    }

    /// Helper to build the STerm for the curried binary application `op a b`
    fn sbinop(op: &str, a: STerm, b: STerm) -> STerm {
        STerm::App(
            Box::new(STerm::App(Box::new(STerm::Name(op.to_string())), Box::new(a))),
            Box::new(b),
        )
    }

    #[test]
    fn test_term_to_sterm_lit() {
        let term = make_slit(42);
        let result = term_to_sterm(&term);
        assert_eq!(result, Some(STerm::Lit(42)));
    }

    #[test]
    fn test_term_to_sterm_var() {
        let term = make_svar(0);
        let result = term_to_sterm(&term);
        assert_eq!(result, Some(STerm::Var(0)));
    }

    #[test]
    fn test_term_to_sterm_name() {
        let term = make_sname("add");
        let result = term_to_sterm(&term);
        assert_eq!(result, Some(STerm::Name("add".to_string())));
    }

    #[test]
    fn test_term_to_sterm_app() {
        // (add 2 3) = SApp (SApp (SName "add") (SLit 2)) (SLit 3)
        let add_2 = make_sapp(make_sname("add"), make_slit(2));
        let add_2_3 = make_sapp(add_2, make_slit(3));
        let result = term_to_sterm(&add_2_3);

        let expected = STerm::App(
            Box::new(STerm::App(
                Box::new(STerm::Name("add".to_string())),
                Box::new(STerm::Lit(2)),
            )),
            Box::new(STerm::Lit(3)),
        );
        assert_eq!(result, Some(expected));
    }

    #[test]
    fn test_term_to_sterm_rejects_unknown_shape() {
        // A bare constructor is not a Syntax term.
        let term = Term::Global("SLit".to_string());
        assert_eq!(term_to_sterm(&term), None);
    }

    #[test]
    fn test_sterm_roundtrip() {
        // Encoding then decoding must give back the same STerm.
        let st = sbinop("add", STerm::Var(0), STerm::Name("c".to_string()));
        let st = sbinop("mul", st, STerm::Lit(-4));
        assert_eq!(term_to_sterm(&sterm_to_term(&st)), Some(st));
    }

    #[test]
    fn test_arithmetic_add() {
        // (add 2) applied to 3
        let func = STerm::App(
            Box::new(STerm::Name("add".to_string())),
            Box::new(STerm::Lit(2)),
        );
        let arg = STerm::Lit(3);
        let result = try_arithmetic(&func, &arg);
        assert_eq!(result, Some(STerm::Lit(5)));
    }

    #[test]
    fn test_arithmetic_mul() {
        let func = STerm::App(
            Box::new(STerm::Name("mul".to_string())),
            Box::new(STerm::Lit(4)),
        );
        let arg = STerm::Lit(5);
        let result = try_arithmetic(&func, &arg);
        assert_eq!(result, Some(STerm::Lit(20)));
    }

    #[test]
    fn test_arithmetic_sub() {
        let func = STerm::App(
            Box::new(STerm::Name("sub".to_string())),
            Box::new(STerm::Lit(10)),
        );
        let arg = STerm::Lit(3);
        let result = try_arithmetic(&func, &arg);
        assert_eq!(result, Some(STerm::Lit(7)));
    }

    #[test]
    fn test_arithmetic_div_mod() {
        let div = STerm::App(
            Box::new(STerm::Name("div".to_string())),
            Box::new(STerm::Lit(7)),
        );
        let rem = STerm::App(
            Box::new(STerm::Name("mod".to_string())),
            Box::new(STerm::Lit(7)),
        );
        assert_eq!(try_arithmetic(&div, &STerm::Lit(2)), Some(STerm::Lit(3)));
        assert_eq!(try_arithmetic(&rem, &STerm::Lit(2)), Some(STerm::Lit(1)));
        // Division and remainder by zero are left unevaluated.
        assert_eq!(try_arithmetic(&div, &STerm::Lit(0)), None);
        assert_eq!(try_arithmetic(&rem, &STerm::Lit(0)), None);
    }

    #[test]
    fn test_arithmetic_overflow_not_folded() {
        // i64::MAX + 1 overflows, so the application must stay as written.
        let term = sbinop("add", STerm::Lit(i64::MAX), STerm::Lit(1));
        assert_eq!(simplify_sterm(&term, &HashMap::new(), 100), term);
    }

    #[test]
    fn test_simplify_constant_addition() {
        // 2 + 3 should simplify to 5
        let term = STerm::App(
            Box::new(STerm::App(
                Box::new(STerm::Name("add".to_string())),
                Box::new(STerm::Lit(2)),
            )),
            Box::new(STerm::Lit(3)),
        );
        let result = simplify_sterm(&term, &HashMap::new(), 100);
        assert_eq!(result, STerm::Lit(5));
    }

    #[test]
    fn test_simplify_nested_arithmetic() {
        // (1 + 1) * 3 = 6
        let one_plus_one = STerm::App(
            Box::new(STerm::App(
                Box::new(STerm::Name("add".to_string())),
                Box::new(STerm::Lit(1)),
            )),
            Box::new(STerm::Lit(1)),
        );
        let term = STerm::App(
            Box::new(STerm::App(
                Box::new(STerm::Name("mul".to_string())),
                Box::new(one_plus_one),
            )),
            Box::new(STerm::Lit(3)),
        );
        let result = simplify_sterm(&term, &HashMap::new(), 100);
        assert_eq!(result, STerm::Lit(6));
    }

    #[test]
    fn test_simplify_with_substitution() {
        // x + 1 with x = 0 should give 1
        let x_plus_1 = STerm::App(
            Box::new(STerm::App(
                Box::new(STerm::Name("add".to_string())),
                Box::new(STerm::Var(0)),
            )),
            Box::new(STerm::Lit(1)),
        );
        let mut subst = HashMap::new();
        subst.insert(0, STerm::Lit(0));

        let result = simplify_sterm(&x_plus_1, &subst, 100);
        assert_eq!(result, STerm::Lit(1));
    }

    #[test]
    fn test_simplify_zero_fuel_is_identity() {
        // With no fuel, nothing is simplified.
        let term = sbinop("add", STerm::Lit(2), STerm::Lit(3));
        assert_eq!(simplify_sterm(&term, &HashMap::new(), 0), term);
    }

    #[test]
    fn test_check_goal_reflexive() {
        // (Eq x x) should be provable
        let x = make_svar(0);
        let goal = make_sapp(make_sapp(make_sname("Eq"), x.clone()), x);
        assert!(check_goal(&goal), "simp should prove x = x");
    }

    #[test]
    fn test_check_goal_constant() {
        // (Eq (add 2 3) 5) should be provable
        let add_2_3 = make_sapp(make_sapp(make_sname("add"), make_slit(2)), make_slit(3));
        let goal = make_sapp(make_sapp(make_sname("Eq"), add_2_3), make_slit(5));
        assert!(check_goal(&goal), "simp should prove 2+3 = 5");
    }

    #[test]
    fn test_check_goal_typed_equality() {
        // ((Eq Int) (add 2 3)) 5: equality carrying a type argument
        let add_2_3 = make_binop("add", make_slit(2), make_slit(3));
        let eq_int = make_sapp(make_sname("Eq"), make_sname("Int"));
        let goal = make_sapp(make_sapp(eq_int, add_2_3), make_slit(5));
        assert!(check_goal(&goal), "simp should prove the typed equality 2+3 = 5");
    }

    #[test]
    fn test_check_goal_with_hypothesis() {
        // (implies (Eq x 0) (Eq (add x 1) 1)) should be provable
        let x = make_svar(0);
        let zero = make_slit(0);
        let one = make_slit(1);

        let x_plus_1 = make_sapp(make_sapp(make_sname("add"), x.clone()), one.clone());
        let hyp = make_sapp(make_sapp(make_sname("Eq"), x), zero);
        let concl = make_sapp(make_sapp(make_sname("Eq"), x_plus_1), one);
        let goal = make_sapp(make_sapp(make_sname("implies"), hyp), concl);

        assert!(check_goal(&goal), "simp should prove x=0 -> x+1=1");
    }

    #[test]
    fn test_check_goal_multiple_hypotheses() {
        // x = 2 -> y = 3 -> x + y = 5
        let x = make_svar(0);
        let y = make_svar(1);
        let concl = make_binop("Eq", make_binop("add", x.clone(), y.clone()), make_slit(5));
        let inner = make_binop("implies", make_binop("Eq", y, make_slit(3)), concl);
        let goal = make_binop("implies", make_binop("Eq", x, make_slit(2)), inner);
        assert!(check_goal(&goal), "simp should prove x=2 -> y=3 -> x+y=5");
    }

    #[test]
    fn test_check_goal_false_equality() {
        // (Eq 2 3) should NOT be provable
        let goal = make_sapp(make_sapp(make_sname("Eq"), make_slit(2)), make_slit(3));
        assert!(!check_goal(&goal), "simp should NOT prove 2 = 3");
    }

    #[test]
    fn test_check_goal_rejects_non_equality() {
        // A conclusion that is not an equality is never provable by simp.
        assert!(!check_goal(&make_slit(5)), "a bare literal is not an equality goal");
    }
}