// oxilean_std/term_rewriting/types.rs

//! Auto-generated module
//!
//! 🤖 Generated with [SplitRS](https://github.com/cool-japan/splitrs)

use std::collections::{HashMap, HashSet, VecDeque};

use super::functions::*;
/// A rewriting logic theory (equational + rewrite rules).
#[allow(dead_code)]
#[derive(Debug, Clone, Default)]
pub struct RewritingLogicTheory {
    /// Equational axioms (identity up to).
    pub equations: Vec<(String, String)>,
    /// Rewrite rules (labeled, one-directional).
    pub rw_rules: Vec<(String, String, String)>,
    /// Sort hierarchy.
    pub sorts: Vec<String>,
}
#[allow(dead_code)]
impl RewritingLogicTheory {
    /// Creates an empty theory.
    pub fn new() -> Self {
        // `Default` derives the all-empty state, so `new` just delegates.
        Self::default()
    }
    /// Adds a sort, ignoring duplicates (first occurrence wins).
    pub fn add_sort(&mut self, sort: &str) {
        // `iter().any` compares against `&str` directly, avoiding the
        // temporary `String` that `contains(&sort.to_string())` allocated.
        if !self.sorts.iter().any(|s| s == sort) {
            self.sorts.push(sort.to_string());
        }
    }
    /// Adds an equation `lhs = rhs`.
    pub fn add_equation(&mut self, lhs: &str, rhs: &str) {
        self.equations.push((lhs.to_string(), rhs.to_string()));
    }
    /// Adds a labeled rewrite rule `label: lhs → rhs`.
    pub fn add_rw_rule(&mut self, label: &str, lhs: &str, rhs: &str) {
        self.rw_rules
            .push((label.to_string(), lhs.to_string(), rhs.to_string()));
    }
    /// Returns the signature size (number of equations plus rewrite rules).
    pub fn signature_size(&self) -> usize {
        self.equations.len() + self.rw_rules.len()
    }
    /// Checks if the theory is a pure equational theory (no rw rules).
    pub fn is_equational(&self) -> bool {
        self.rw_rules.is_empty()
    }
    /// Generates a human-readable summary of the theory's contents.
    pub fn entailment_description(&self) -> String {
        format!(
            "Rewriting logic theory with {} sorts, {} equations, {} rules",
            self.sorts.len(),
            self.equations.len(),
            self.rw_rules.len()
        )
    }
}
63/// A rewrite rule `lhs → rhs`.
64///
65/// Variables in `lhs` range over `Term::Var(i)`.  The rule is valid when
66/// every variable in `rhs` also occurs in `lhs`.
67#[derive(Debug, Clone)]
68pub struct Rule {
69    pub lhs: Term,
70    pub rhs: Term,
71}
72impl Rule {
73    /// Creates a new rewrite rule.
74    pub fn new(lhs: Term, rhs: Term) -> Self {
75        Rule { lhs, rhs }
76    }
77    /// Returns `true` if every variable in `rhs` occurs in `lhs`.
78    pub fn is_valid(&self) -> bool {
79        self.rhs.vars().is_subset(&self.lhs.vars())
80    }
81    /// Returns `true` if `lhs` is linear (each variable occurs at most once).
82    pub fn is_left_linear(&self) -> bool {
83        let mut seen = HashSet::new();
84        fn check(t: &Term, seen: &mut HashSet<u32>) -> bool {
85            match t {
86                Term::Var(i) => seen.insert(*i),
87                Term::Fun(_, args) => args.iter().all(|a| check(a, seen)),
88            }
89        }
90        check(&self.lhs, &mut seen)
91    }
92    /// Returns a renamed copy of this rule with variables shifted by `offset`.
93    pub fn rename(&self, offset: u32) -> Rule {
94        fn shift(t: &Term, off: u32) -> Term {
95            match t {
96                Term::Var(i) => Term::Var(i + off),
97                Term::Fun(f, args) => {
98                    Term::Fun(f.clone(), args.iter().map(|a| shift(a, off)).collect())
99                }
100            }
101        }
102        Rule {
103            lhs: shift(&self.lhs, offset),
104            rhs: shift(&self.rhs, offset),
105        }
106    }
107}
108/// An E-graph for equality saturation.
109#[allow(dead_code)]
110#[derive(Debug, Clone)]
111pub struct EGraph {
112    /// E-classes by ID.
113    pub classes: Vec<EClass>,
114    /// Union-find parent array.
115    pub parent: Vec<usize>,
116    /// Number of e-nodes total.
117    pub total_nodes: usize,
118}
119#[allow(dead_code)]
120impl EGraph {
121    /// Creates an empty E-graph.
122    pub fn new() -> Self {
123        EGraph {
124            classes: Vec::new(),
125            parent: Vec::new(),
126            total_nodes: 0,
127        }
128    }
129    /// Adds a new e-class with a single node.
130    pub fn add_node(&mut self, node: &str) -> usize {
131        let id = self.classes.len();
132        self.classes.push(EClass {
133            id,
134            nodes: vec![node.to_string()],
135            size: 1,
136        });
137        self.parent.push(id);
138        self.total_nodes += 1;
139        id
140    }
141    /// Finds the canonical class ID (path-compressed union-find).
142    pub fn find(&mut self, id: usize) -> usize {
143        if self.parent[id] == id {
144            return id;
145        }
146        let root = self.find(self.parent[id]);
147        self.parent[id] = root;
148        root
149    }
150    /// Merges two e-classes.
151    pub fn union(&mut self, id1: usize, id2: usize) {
152        let r1 = self.find(id1);
153        let r2 = self.find(id2);
154        if r1 == r2 {
155            return;
156        }
157        self.parent[r2] = r1;
158        let nodes2 = self.classes[r2].nodes.clone();
159        let size2 = self.classes[r2].size;
160        self.classes[r1].nodes.extend(nodes2);
161        self.classes[r1].size += size2;
162    }
163    /// Checks if two nodes are in the same e-class.
164    pub fn are_equal(&mut self, id1: usize, id2: usize) -> bool {
165        self.find(id1) == self.find(id2)
166    }
167    /// Returns the number of e-classes.
168    pub fn num_classes(&self) -> usize {
169        self.classes.len()
170    }
171}
172/// A first-order term over a signature.
173#[derive(Debug, Clone, PartialEq, Eq, Hash)]
174pub enum Term {
175    /// A variable `x_i`.
176    Var(u32),
177    /// A function application `f(t1, ..., tn)`.
178    Fun(String, Vec<Term>),
179}
180impl Term {
181    /// Returns the set of variable indices occurring in this term.
182    pub fn vars(&self) -> HashSet<u32> {
183        match self {
184            Term::Var(i) => {
185                let mut s = HashSet::new();
186                s.insert(*i);
187                s
188            }
189            Term::Fun(_, args) => args.iter().flat_map(|a| a.vars()).collect(),
190        }
191    }
192    /// Returns `true` if this term is ground (contains no variables).
193    pub fn is_ground(&self) -> bool {
194        self.vars().is_empty()
195    }
196    /// Returns the depth of this term.
197    pub fn depth(&self) -> usize {
198        match self {
199            Term::Var(_) => 0,
200            Term::Fun(_, args) => 1 + args.iter().map(|a| a.depth()).max().unwrap_or(0),
201        }
202    }
203    /// Apply a substitution to this term.
204    pub fn apply(&self, subst: &Substitution) -> Term {
205        match self {
206            Term::Var(i) => subst.map.get(i).cloned().unwrap_or(Term::Var(*i)),
207            Term::Fun(f, args) => {
208                Term::Fun(f.clone(), args.iter().map(|a| a.apply(subst)).collect())
209            }
210        }
211    }
212    /// Returns `true` if the term `other` occurs in this term.
213    pub fn contains(&self, other: &Term) -> bool {
214        if self == other {
215            return true;
216        }
217        match self {
218            Term::Var(_) => false,
219            Term::Fun(_, args) => args.iter().any(|a| a.contains(other)),
220        }
221    }
222    /// Returns the subterm at the given position (empty = root).
223    pub fn subterm_at(&self, pos: &[usize]) -> Option<&Term> {
224        if pos.is_empty() {
225            return Some(self);
226        }
227        match self {
228            Term::Fun(_, args) => {
229                let idx = pos[0];
230                args.get(idx)?.subterm_at(&pos[1..])
231            }
232            Term::Var(_) => None,
233        }
234    }
235    /// Replace the subterm at `pos` with `replacement`, returning the new term.
236    pub fn replace_at(&self, pos: &[usize], replacement: Term) -> Term {
237        if pos.is_empty() {
238            return replacement;
239        }
240        match self {
241            Term::Fun(f, args) => {
242                let idx = pos[0];
243                let mut new_args = args.clone();
244                if idx < new_args.len() {
245                    new_args[idx] = new_args[idx].replace_at(&pos[1..], replacement);
246                }
247                Term::Fun(f.clone(), new_args)
248            }
249            Term::Var(_) => self.clone(),
250        }
251    }
252}
253/// A narrowing system: computes all possible narrowing steps of a term w.r.t. a TRS.
254///
255/// Narrowing generalises rewriting to terms with variables: a term `t` narrows
256/// to `t'` if there is a substitution σ and a rule `l → r` such that
257/// `t[σ]_p = l[σ]` for some non-variable position `p`, and `t' = t[r]_p[σ]`.
258#[derive(Debug, Clone)]
259pub struct NarrowingSystem {
260    /// The underlying TRS used for narrowing steps.
261    pub trs: Trs,
262    /// Variable counter for generating fresh variable names.
263    pub var_counter: u32,
264}
265impl NarrowingSystem {
266    /// Create a new narrowing system wrapping the given TRS.
267    pub fn new(trs: Trs) -> Self {
268        NarrowingSystem {
269            trs,
270            var_counter: 10000,
271        }
272    }
273    /// Fresh variable index (offset to avoid clashing with term variables).
274    fn fresh_var(&mut self) -> u32 {
275        let v = self.var_counter;
276        self.var_counter += 1;
277        v
278    }
279    /// Collect all non-variable positions in term `t`.
280    fn non_var_positions(t: &Term) -> Vec<Vec<usize>> {
281        match t {
282            Term::Var(_) => vec![],
283            Term::Fun(_, args) => {
284                let mut out = vec![vec![]];
285                for (i, a) in args.iter().enumerate() {
286                    for mut p in Self::non_var_positions(a) {
287                        let mut full = vec![i];
288                        full.append(&mut p);
289                        out.push(full);
290                    }
291                }
292                out
293            }
294        }
295    }
296    /// Compute one level of narrowing steps from term `t`.
297    ///
298    /// Returns a list of `(substitution, narrowed_term)` pairs.
299    pub fn narrow_step(&mut self, t: &Term) -> Vec<(Substitution, Term)> {
300        let mut results = Vec::new();
301        let positions = Self::non_var_positions(t);
302        for pos in &positions {
303            if let Some(sub) = t.subterm_at(pos) {
304                for rule in self.trs.rules.clone() {
305                    let offset = self.fresh_var();
306                    let renamed = rule.rename(offset);
307                    if let Some(sigma) = unify(sub, &renamed.lhs) {
308                        let new_term = t.replace_at(pos, renamed.rhs.apply(&sigma)).apply(&sigma);
309                        results.push((sigma, new_term));
310                    }
311                }
312            }
313        }
314        results
315    }
316    /// Basic narrowing: perform up to `depth` levels of narrowing from `t`.
317    ///
318    /// Returns all reachable (substitution, term) pairs.
319    pub fn basic_narrow(&mut self, t: &Term, depth: usize) -> Vec<(Substitution, Term)> {
320        if depth == 0 {
321            return vec![(Substitution::new(), t.clone())];
322        }
323        let steps = self.narrow_step(t);
324        if steps.is_empty() {
325            return vec![(Substitution::new(), t.clone())];
326        }
327        let mut all = Vec::new();
328        for (sigma, t2) in steps {
329            let deeper = self.basic_narrow(&t2, depth - 1);
330            for (sigma2, t3) in deeper {
331                let combined = sigma2.compose(&sigma);
332                all.push((combined, t3));
333            }
334        }
335        all
336    }
337    /// Unification via narrowing: tries to unify `s` and `t` by narrowing `s` toward `t`.
338    pub fn narrowing_unify(&mut self, s: &Term, t: &Term, depth: usize) -> Option<Substitution> {
339        let narrowings = self.basic_narrow(s, depth);
340        for (sigma, s_narrowed) in narrowings {
341            if let Some(sigma2) = unify(&s_narrowed, t) {
342                return Some(sigma2.compose(&sigma));
343            }
344        }
345        None
346    }
347}
/// Knuth-Bendix completion state.
///
/// Holds the evolving rule set and the queue of unoriented equations still to
/// be processed by [`KBState::complete`].
pub struct KBState {
    /// Current set of rules.
    pub rules: Vec<Rule>,
    /// Pending equations to orient/simplify.
    pub equations: VecDeque<(Term, Term)>,
    /// Maximum completion steps.
    pub max_steps: usize,
}
impl KBState {
    /// Creates a new KB state from an initial set of equations.
    pub fn new(equations: Vec<(Term, Term)>, max_steps: usize) -> Self {
        KBState {
            rules: Vec::new(),
            equations: VecDeque::from(equations),
            max_steps,
        }
    }
    /// Runs the Knuth-Bendix completion algorithm.
    ///
    /// Returns `Ok(trs)` if completion succeeds, `Err(msg)` otherwise.
    /// `order` is treated as a reduction order: an equation `s = t` becomes
    /// the rule `s → t` when `order(&s, &t)` is `Greater` (and `t → s` when
    /// `Less`); `Equal` means the equation cannot be oriented and completion
    /// fails.
    pub fn complete(&mut self, order: TermOrdering) -> Result<Trs, String> {
        let mut steps = 0;
        while let Some((s, t)) = self.equations.pop_front() {
            if steps >= self.max_steps {
                return Err("KB completion: exceeded max steps".into());
            }
            steps += 1;
            // Phase 1 — simplify: normalize both sides with the current rule
            // set (capped at 200 reduction steps per side).
            let trs = Trs {
                rules: self.rules.clone(),
            };
            let s = trs.normalize_innermost(&s, 200);
            let t = trs.normalize_innermost(&t, 200);
            if s == t {
                // Already joinable — the equation carries no new information.
                continue;
            }
            // Phase 2 — orient the simplified equation into a rule.
            // NOTE(review): the `format!` below uses `{}` on `Term`, so a
            // `Display` impl must exist elsewhere (e.g. super::functions) —
            // confirm.
            let (lhs, rhs) = match order(&s, &t) {
                std::cmp::Ordering::Greater => (s, t),
                std::cmp::Ordering::Less => (t, s),
                std::cmp::Ordering::Equal => {
                    return Err(format!("KB completion: cannot orient {} = {}", s, t));
                }
            };
            let new_rule = Rule::new(lhs.clone(), rhs.clone());
            // Phase 3 — inter-reduce: existing rules that the new rule can
            // simplify are demoted back to equations; untouched rules are kept.
            let mut new_rules: Vec<Rule> = Vec::new();
            let mut deferred: Vec<(Term, Term)> = Vec::new();
            for r in &self.rules {
                let nr = Trs {
                    rules: vec![new_rule.clone()],
                };
                let lhs2 = nr.normalize_innermost(&r.lhs, 200);
                let rhs2 = nr.normalize_innermost(&r.rhs, 200);
                if lhs2 == r.lhs && rhs2 == r.rhs {
                    new_rules.push(r.clone());
                } else {
                    deferred.push((lhs2, rhs2));
                }
            }
            self.rules = new_rules;
            self.rules.push(new_rule.clone());
            // Phase 4 — deduce: critical pairs between the new rule and every
            // rule, in both orientations, become new equations.  The numeric
            // offsets keep renamed rule variables from clashing across pairs.
            let all_rules = self.rules.clone();
            for (i, r) in all_rules.iter().enumerate() {
                for pair in critical_pairs(&new_rule, r, (i * 2000 + 1) as u32) {
                    self.equations.push_back(pair);
                }
                for pair in critical_pairs(r, &new_rule, (i * 2000 + 3001) as u32) {
                    self.equations.push_back(pair);
                }
            }
            // Re-queue the rules that were simplified away in phase 3.
            for eq in deferred {
                self.equations.push_back(eq);
            }
        }
        Ok(Trs {
            rules: self.rules.clone(),
        })
    }
}
/// A node in the dependency pair graph.
///
/// Derives `Eq` and `Hash` so nodes can be used directly as keys or set
/// members when building the graph.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct DependencyPairNode {
    /// The left-hand side of the dependency pair (root symbol of a subterm call).
    pub lhs_root: String,
    /// The right-hand side root symbol (the called function).
    pub rhs_root: String,
}
434/// A Term Rewriting System: a list of rewrite rules.
435#[derive(Debug, Clone, Default)]
436pub struct Trs {
437    pub rules: Vec<Rule>,
438}
439impl Trs {
440    /// Creates an empty TRS.
441    pub fn new() -> Self {
442        Trs { rules: Vec::new() }
443    }
444    /// Adds a rule to the TRS.
445    pub fn add_rule(&mut self, rule: Rule) {
446        self.rules.push(rule);
447    }
448    /// Returns `true` if all rules are left-linear.
449    pub fn is_left_linear(&self) -> bool {
450        self.rules.iter().all(|r| r.is_left_linear())
451    }
452    /// One-step innermost reduction: reduces the leftmost-innermost redex.
453    pub fn reduce_innermost(&self, term: &Term) -> Option<Term> {
454        if let Term::Fun(f, args) = term {
455            for (i, arg) in args.iter().enumerate() {
456                if let Some(reduced) = self.reduce_innermost(arg) {
457                    let mut new_args = args.clone();
458                    new_args[i] = reduced;
459                    return Some(Term::Fun(f.clone(), new_args));
460                }
461            }
462        }
463        for rule in &self.rules {
464            let renamed = rule.rename(1000);
465            if let Some(subst) = unify(&renamed.lhs, term) {
466                return Some(renamed.rhs.apply(&subst));
467            }
468        }
469        None
470    }
471    /// One-step outermost reduction: reduces the leftmost-outermost redex.
472    pub fn reduce_outermost(&self, term: &Term) -> Option<Term> {
473        for rule in &self.rules {
474            let renamed = rule.rename(1000);
475            if let Some(subst) = unify(&renamed.lhs, term) {
476                return Some(renamed.rhs.apply(&subst));
477            }
478        }
479        if let Term::Fun(f, args) = term {
480            for (i, arg) in args.iter().enumerate() {
481                if let Some(reduced) = self.reduce_outermost(arg) {
482                    let mut new_args = args.clone();
483                    new_args[i] = reduced;
484                    return Some(Term::Fun(f.clone(), new_args));
485                }
486            }
487        }
488        None
489    }
490    /// Fully reduces a term to normal form under innermost strategy (up to `limit` steps).
491    pub fn normalize_innermost(&self, term: &Term, limit: usize) -> Term {
492        let mut current = term.clone();
493        for _ in 0..limit {
494            match self.reduce_innermost(&current) {
495                Some(next) => current = next,
496                None => break,
497            }
498        }
499        current
500    }
501    /// Fully reduces a term to normal form under outermost strategy (up to `limit` steps).
502    pub fn normalize_outermost(&self, term: &Term, limit: usize) -> Term {
503        let mut current = term.clone();
504        for _ in 0..limit {
505            match self.reduce_outermost(&current) {
506                Some(next) => current = next,
507                None => break,
508            }
509        }
510        current
511    }
512    /// Checks whether `t` is a normal form (no rule applies at any position).
513    pub fn is_normal_form(&self, t: &Term) -> bool {
514        self.reduce_outermost(t).is_none()
515    }
516}
/// Available reduction strategies for a TRS.
///
/// NOTE(review): only the innermost/outermost variants have corresponding
/// `Trs::normalize_*` helpers in this module; `Parallel` and `Lazy` appear to
/// have no implementation here — confirm they are handled elsewhere.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Strategy {
    /// Innermost (leftmost-innermost): reduce innermost redexes first.
    Innermost,
    /// Outermost (leftmost-outermost): reduce outermost redexes first.
    Outermost,
    /// Parallel: reduce all outermost redexes simultaneously.
    Parallel,
    /// Lazy: outermost with sharing (call-by-need approximation).
    Lazy,
}
529/// A polynomial interpretation for proving termination of a TRS.
530///
531/// Each function symbol `f` of arity `n` is mapped to a polynomial
532/// p_f(x_1, …, x_n) with natural-number coefficients.  A TRS is terminating
533/// if for every rule `l → r`, the polynomial interpretation satisfies
534/// Pol(l) > Pol(r) (as natural numbers for all assignments ≥ 0).
535#[derive(Debug, Clone)]
536pub struct PolynomialInterpretation {
537    /// Maps function symbol names to their polynomial coefficients.
538    ///
539    /// For a unary symbol the vector `[c0, c1]` represents c0 + c1 * x1.
540    /// For a binary symbol `[c0, c1, c2]` represents c0 + c1*x1 + c2*x2.
541    /// Constants (arity 0) are represented as `[c0]`.
542    pub interpretations: HashMap<String, Vec<i64>>,
543}
544impl PolynomialInterpretation {
545    /// Create an empty polynomial interpretation.
546    pub fn new() -> Self {
547        PolynomialInterpretation {
548            interpretations: HashMap::new(),
549        }
550    }
551    /// Register an interpretation for symbol `f`: coefficients `[c0, c1, …]`.
552    pub fn set(&mut self, symbol: impl Into<String>, coefficients: Vec<i64>) {
553        self.interpretations.insert(symbol.into(), coefficients);
554    }
555    /// Evaluate the polynomial interpretation of term `t` given variable assignment.
556    ///
557    /// Variable `Var(i)` maps to `assignment[i]`.
558    pub fn eval(&self, t: &Term, assignment: &[i64]) -> i64 {
559        match t {
560            Term::Var(i) => assignment.get(*i as usize).copied().unwrap_or(0),
561            Term::Fun(f, args) => {
562                let arg_vals: Vec<i64> = args.iter().map(|a| self.eval(a, assignment)).collect();
563                if let Some(coeffs) = self.interpretations.get(f) {
564                    let mut result = coeffs.first().copied().unwrap_or(0);
565                    for (k, &c) in coeffs.iter().skip(1).enumerate() {
566                        result += c * arg_vals.get(k).copied().unwrap_or(0);
567                    }
568                    result
569                } else {
570                    0
571                }
572            }
573        }
574    }
575    /// Check whether rule `lhs → rhs` is oriented by this interpretation for
576    /// all variable assignments in `[0, max_val]^n`.
577    pub fn orients_rule(&self, rule: &Rule, max_val: i64) -> bool {
578        let vars: HashSet<u32> = rule.lhs.vars().union(&rule.rhs.vars()).copied().collect();
579        let n = vars.iter().copied().max().map(|m| m + 1).unwrap_or(0) as usize;
580        let mut assignment = vec![0i64; n];
581        loop {
582            let lv = self.eval(&rule.lhs, &assignment);
583            let rv = self.eval(&rule.rhs, &assignment);
584            if lv <= rv {
585                return false;
586            }
587            let mut carry = true;
588            for a in assignment.iter_mut() {
589                if carry {
590                    *a += 1;
591                    if *a > max_val {
592                        *a = 0;
593                    } else {
594                        carry = false;
595                    }
596                }
597            }
598            if carry {
599                break;
600            }
601        }
602        true
603    }
604    /// Check whether this interpretation proves termination of the given TRS
605    /// (all rules are strictly oriented for assignments in [0, max_val]^n).
606    pub fn proves_termination(&self, trs: &Trs, max_val: i64) -> bool {
607        trs.rules.iter().all(|r| self.orients_rule(r, max_val))
608    }
609}
/// Knuth-Bendix completion data.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct KnuthBendixData {
    /// Current set of rules.
    pub rules: Vec<(String, String)>,
    /// Critical pairs found.
    pub critical_pairs: Vec<(String, String)>,
    /// Whether the system is confluent.
    pub is_confluent: bool,
    /// Termination order description.
    pub order: String,
}
#[allow(dead_code)]
impl KnuthBendixData {
    /// Creates an empty completion record using the named termination order.
    pub fn new(order: &str) -> Self {
        Self {
            rules: Vec::new(),
            critical_pairs: Vec::new(),
            is_confluent: false,
            order: order.to_string(),
        }
    }
    /// Records an already-oriented rule `lhs → rhs`.
    pub fn add_oriented_rule(&mut self, lhs: &str, rhs: &str) {
        let rule = (lhs.to_string(), rhs.to_string());
        self.rules.push(rule);
    }
    /// Registers a critical pair.
    pub fn add_critical_pair(&mut self, left: &str, right: &str) {
        let pair = (left.to_string(), right.to_string());
        self.critical_pairs.push(pair);
    }
    /// Marks the system as confluent (after resolving all critical pairs).
    pub fn mark_confluent(&mut self) {
        self.is_confluent = true;
    }
    /// Returns a short textual summary of the completion state.
    pub fn description(&self) -> String {
        format!(
            "KB({} rules, {} crit pairs, confluent={})",
            self.rules.len(),
            self.critical_pairs.len(),
            self.is_confluent
        )
    }
    /// Returns the number of recorded rules.
    pub fn num_rules(&self) -> usize {
        self.rules.len()
    }
}
/// Represents a reduction strategy for a term rewriting system.
#[allow(dead_code)]
#[derive(Debug, Clone, PartialEq)]
pub enum ReductionStrategy {
    /// Leftmost-outermost (normal order).
    LeftmostOutermost,
    /// Leftmost-innermost (applicative order).
    LeftmostInnermost,
    /// Rightmost-outermost.
    RightmostOutermost,
    /// Parallel outermost.
    ParallelOutermost,
    /// Needed reduction (lazy evaluation).
    Needed,
}
#[allow(dead_code)]
impl ReductionStrategy {
    /// Returns the human-readable name of the strategy.
    pub fn name(&self) -> &str {
        match self {
            Self::LeftmostOutermost => "Leftmost-Outermost (Normal Order)",
            Self::LeftmostInnermost => "Leftmost-Innermost (Applicative Order)",
            Self::RightmostOutermost => "Rightmost-Outermost",
            Self::ParallelOutermost => "Parallel Outermost",
            Self::Needed => "Needed Reduction (Lazy)",
        }
    }
    /// Checks if this strategy is complete (finds a normal form if one exists).
    pub fn is_complete(&self) -> bool {
        matches!(self, Self::LeftmostOutermost | Self::Needed)
    }
    /// Checks if this strategy is normalizing for orthogonal TRSs.
    pub fn normalizing_for_orthogonal(&self) -> bool {
        matches!(
            self,
            Self::LeftmostOutermost | Self::Needed | Self::ParallelOutermost
        )
    }
    /// Returns the corresponding lambda calculus evaluation order.
    pub fn lambda_calculus_analog(&self) -> &str {
        match self {
            Self::LeftmostOutermost => "Normal order reduction",
            Self::LeftmostInnermost => "Call-by-value",
            Self::Needed => "Call-by-need (lazy)",
            _ => "No direct analog",
        }
    }
}
714/// A substitution: partial map from variable indices to terms.
715#[derive(Debug, Clone, Default)]
716pub struct Substitution {
717    pub map: HashMap<u32, Term>,
718}
719impl Substitution {
720    /// Creates an empty substitution.
721    pub fn new() -> Self {
722        Substitution {
723            map: HashMap::new(),
724        }
725    }
726    /// Binds variable `v` to `t`.
727    pub fn bind(&mut self, v: u32, t: Term) {
728        self.map.insert(v, t);
729    }
730    /// Compose `self` after `other`: `(self ∘ other)(x) = self(other(x))`.
731    pub fn compose(&self, other: &Substitution) -> Substitution {
732        let mut result = Substitution::new();
733        for (&v, t) in &other.map {
734            result.bind(v, t.apply(self));
735        }
736        for (&v, t) in &self.map {
737            if !other.map.contains_key(&v) {
738                result.bind(v, t.clone());
739            }
740        }
741        result
742    }
743}
/// Represents an E-class in an E-graph.
///
/// E-nodes are stored as plain strings; `size` is maintained by
/// `EGraph::add_node` and `EGraph::union` alongside the node list.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct EClass {
    /// Canonical ID of this e-class.
    pub id: usize,
    /// E-nodes (terms) in this e-class.
    pub nodes: Vec<String>,
    /// Size of the e-class.
    pub size: usize,
}
755/// A (bottom-up) tree automaton for recognizing regular tree languages.
756///
757/// States are `usize` indices.  Transitions are of the form
758/// `f(q_1, …, q_n) → q` meaning: if children of `f` are in states q_1, …, q_n
759/// then `f(…)` can be assigned state `q`.
760#[derive(Debug, Clone)]
761pub struct TreeAutomaton {
762    /// Number of states.
763    pub num_states: usize,
764    /// Set of accepting (final) states.
765    pub final_states: HashSet<usize>,
766    /// Transitions: maps `(symbol, child_states)` to a set of target states.
767    pub transitions: HashMap<(String, Vec<usize>), HashSet<usize>>,
768}
769impl TreeAutomaton {
770    /// Create an automaton with `num_states` states.
771    pub fn new(num_states: usize) -> Self {
772        TreeAutomaton {
773            num_states,
774            final_states: HashSet::new(),
775            transitions: HashMap::new(),
776        }
777    }
778    /// Mark state `q` as a final (accepting) state.
779    pub fn add_final(&mut self, q: usize) {
780        self.final_states.insert(q);
781    }
782    /// Add a transition: when `f` is applied to children in states `child_states`,
783    /// allow reaching state `target`.
784    pub fn add_transition(
785        &mut self,
786        symbol: impl Into<String>,
787        child_states: Vec<usize>,
788        target: usize,
789    ) {
790        self.transitions
791            .entry((symbol.into(), child_states))
792            .or_default()
793            .insert(target);
794    }
795    /// Run the automaton bottom-up on term `t`.
796    ///
797    /// Returns the set of states reachable at the root.
798    pub fn run(&self, t: &Term) -> HashSet<usize> {
799        match t {
800            Term::Var(_) => (0..self.num_states).collect(),
801            Term::Fun(f, args) => {
802                let arg_states: Vec<HashSet<usize>> = args.iter().map(|a| self.run(a)).collect();
803                let mut result = HashSet::new();
804                if args.is_empty() {
805                    if let Some(targets) = self.transitions.get(&(f.clone(), vec![])) {
806                        result.extend(targets);
807                    }
808                } else {
809                    let combinations = Self::cartesian(&arg_states);
810                    for combo in combinations {
811                        if let Some(targets) = self.transitions.get(&(f.clone(), combo)) {
812                            result.extend(targets);
813                        }
814                    }
815                }
816                result
817            }
818        }
819    }
820    /// Cartesian product of sets for computing state combinations.
821    fn cartesian(sets: &[HashSet<usize>]) -> Vec<Vec<usize>> {
822        if sets.is_empty() {
823            return vec![vec![]];
824        }
825        let mut result = vec![vec![]];
826        for set in sets {
827            let mut new_result = Vec::new();
828            for combo in &result {
829                let mut sorted: Vec<usize> = set.iter().copied().collect();
830                sorted.sort_unstable();
831                for &state in &sorted {
832                    let mut new_combo = combo.clone();
833                    new_combo.push(state);
834                    new_result.push(new_combo);
835                }
836            }
837            result = new_result;
838        }
839        result
840    }
841    /// Check whether term `t` is accepted (root reachable in a final state).
842    pub fn accepts(&self, t: &Term) -> bool {
843        let states = self.run(t);
844        states.iter().any(|s| self.final_states.contains(s))
845    }
846    /// Returns `true` if the language is empty (no ground term is accepted).
847    ///
848    /// This is a simple fixpoint check: compute the set of states reachable
849    /// by any ground term, then check if any final state is reachable.
850    pub fn is_empty(&self) -> bool {
851        let mut reachable: HashSet<usize> = HashSet::new();
852        let mut changed = true;
853        while changed {
854            changed = false;
855            for ((_, child_req), targets) in &self.transitions {
856                if child_req.iter().all(|s| reachable.contains(s)) {
857                    for &t in targets {
858                        if reachable.insert(t) {
859                            changed = true;
860                        }
861                    }
862                }
863            }
864        }
865        !reachable.iter().any(|s| self.final_states.contains(s))
866    }
867}
/// A string rewriting system rule: `lhs → rhs` over a string alphabet.
#[derive(Debug, Clone)]
pub struct SrsRule {
    /// Left-hand side: the substring to be replaced.
    pub lhs: String,
    /// Right-hand side: the replacement string.
    pub rhs: String,
}
/// A string rewriting system (monoid presentation).
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct StringRewritingSystem {
    /// Alphabet.
    pub alphabet: Vec<char>,
    /// Rules: (lhs, rhs) as strings over the alphabet.
    pub rules: Vec<(String, String)>,
}
#[allow(dead_code)]
impl StringRewritingSystem {
    /// Creates a new SRS over `alphabet` with no rules.
    pub fn new(alphabet: Vec<char>) -> Self {
        StringRewritingSystem {
            alphabet,
            rules: Vec::new(),
        }
    }
    /// Adds a rule `lhs → rhs`.
    pub fn add_rule(&mut self, lhs: &str, rhs: &str) {
        self.rules.push((lhs.to_owned(), rhs.to_owned()));
    }
    /// Applies one step of rewriting to a string (leftmost first match).
    ///
    /// Rules are tried in insertion order; the first rule whose lhs occurs
    /// in `s` is applied at its leftmost occurrence.
    pub fn rewrite_step(&self, s: &str) -> Option<String> {
        self.rules.iter().find_map(|(lhs, rhs)| {
            s.find(lhs.as_str())
                .map(|pos| format!("{}{}{}", &s[..pos], rhs, &s[pos + lhs.len()..]))
        })
    }
    /// Applies rewriting until normal form (limit iterations).
    ///
    /// Returns the string as-is after `max_steps` even if not yet normal.
    pub fn normalize(&self, s: &str, max_steps: usize) -> String {
        let mut current = s.to_string();
        for _ in 0..max_steps {
            if let Some(next) = self.rewrite_step(&current) {
                current = next;
            } else {
                break;
            }
        }
        current
    }
    /// Checks if two strings are equal modulo rewriting (up to max_steps).
    pub fn are_equal_modulo(&self, s1: &str, s2: &str, max_steps: usize) -> bool {
        self.normalize(s1, max_steps) == self.normalize(s2, max_steps)
    }
    /// Returns the number of rules.
    pub fn num_rules(&self) -> usize {
        self.rules.len()
    }
}
928/// An equational theory presented by equations.
929#[derive(Debug, Clone, Default)]
930pub struct EquationalTheory {
931    /// Equational axioms `(lhs, rhs)`.
932    pub axioms: Vec<(Term, Term)>,
933}
934impl EquationalTheory {
935    /// Creates an empty equational theory.
936    pub fn new() -> Self {
937        EquationalTheory { axioms: Vec::new() }
938    }
939    /// Adds an axiom `lhs = rhs` to the theory.
940    pub fn add_axiom(&mut self, lhs: Term, rhs: Term) {
941        self.axioms.push((lhs, rhs));
942    }
943    /// Naive E-unification by closure: attempts to unify `s` and `t` modulo
944    /// the equational theory by rewriting.  Returns a substitution if found.
945    pub fn e_unify(&self, s: &Term, t: &Term, depth_limit: usize) -> Option<Substitution> {
946        let mut trs = Trs::new();
947        for (lhs, rhs) in &self.axioms {
948            trs.add_rule(Rule::new(lhs.clone(), rhs.clone()));
949            trs.add_rule(Rule::new(rhs.clone(), lhs.clone()));
950        }
951        let s_nf = trs.normalize_innermost(s, depth_limit);
952        let t_nf = trs.normalize_innermost(t, depth_limit);
953        unify(&s_nf, &t_nf)
954    }
955}
956/// A String Rewriting System.
957#[derive(Debug, Clone)]
958pub struct Srs {
959    pub rules: Vec<SrsRule>,
960}
961impl Srs {
962    /// Creates an empty SRS.
963    pub fn new() -> Self {
964        Srs { rules: Vec::new() }
965    }
966    /// Adds a rule `lhs → rhs`.
967    pub fn add_rule(&mut self, lhs: impl Into<String>, rhs: impl Into<String>) {
968        self.rules.push(SrsRule {
969            lhs: lhs.into(),
970            rhs: rhs.into(),
971        });
972    }
973    /// One-step reduction: applies the first applicable rule at any position.
974    pub fn step(&self, s: &str) -> Option<String> {
975        for rule in &self.rules {
976            if let Some(pos) = s.find(&rule.lhs) {
977                let result = format!("{}{}{}", &s[..pos], rule.rhs, &s[pos + rule.lhs.len()..]);
978                return Some(result);
979            }
980        }
981        None
982    }
983    /// Fully reduce a string to its normal form (up to `limit` steps).
984    pub fn normalize(&self, s: &str, limit: usize) -> String {
985        let mut current = s.to_owned();
986        for _ in 0..limit {
987            match self.step(&current) {
988                Some(next) => current = next,
989                None => break,
990            }
991        }
992        current
993    }
994    /// Checks whether two strings are equivalent under the congruence generated
995    /// by this SRS (compares normal forms).
996    pub fn word_equivalent(&self, s: &str, t: &str, limit: usize) -> bool {
997        self.normalize(s, limit) == self.normalize(t, limit)
998    }
999}
/// A dependency pair graph for a TRS — used to prove termination.
///
/// In the dependency pair method a "dependency pair" is derived from each rule
/// `f(l) → C[g(r)]` where `g` is a defined symbol.  Termination is equivalent
/// to the non-existence of infinite chains in the dependency pair graph.
#[derive(Debug, Clone, Default)]
pub struct DependencyPairGraph {
    /// The dependency pairs (nodes of the graph).
    pub pairs: Vec<DependencyPairNode>,
    /// Edges: `edges[i]` is the set of indices j such that pair i may precede pair j.
    ///
    /// Invariant: `edges.len() == pairs.len()` (each `add_pair` pushes an
    /// empty edge set alongside the new pair).
    pub edges: Vec<HashSet<usize>>,
}
1012impl DependencyPairGraph {
1013    /// Create an empty dependency pair graph.
1014    pub fn new() -> Self {
1015        DependencyPairGraph {
1016            pairs: Vec::new(),
1017            edges: Vec::new(),
1018        }
1019    }
1020    /// Add a dependency pair (lhs_root, rhs_root) and return its index.
1021    pub fn add_pair(&mut self, lhs_root: impl Into<String>, rhs_root: impl Into<String>) -> usize {
1022        let idx = self.pairs.len();
1023        self.pairs.push(DependencyPairNode {
1024            lhs_root: lhs_root.into(),
1025            rhs_root: rhs_root.into(),
1026        });
1027        self.edges.push(HashSet::new());
1028        idx
1029    }
1030    /// Add an edge from pair `from` to pair `to`.
1031    pub fn add_edge(&mut self, from: usize, to: usize) {
1032        if from < self.edges.len() {
1033            self.edges[from].insert(to);
1034        }
1035    }
1036    /// Find all strongly connected components (SCCs) via a simple iterative DFS.
1037    ///
1038    /// Returns a list of SCCs, each represented as a list of pair indices.
1039    pub fn sccs(&self) -> Vec<Vec<usize>> {
1040        let n = self.pairs.len();
1041        let mut visited = vec![false; n];
1042        let mut finish_order: Vec<usize> = Vec::new();
1043        for start in 0..n {
1044            if !visited[start] {
1045                let mut stack: Vec<(usize, bool)> = vec![(start, false)];
1046                while let Some((node, done)) = stack.pop() {
1047                    if done {
1048                        finish_order.push(node);
1049                        continue;
1050                    }
1051                    if visited[node] {
1052                        continue;
1053                    }
1054                    visited[node] = true;
1055                    stack.push((node, true));
1056                    for &next in &self.edges[node] {
1057                        if !visited[next] {
1058                            stack.push((next, false));
1059                        }
1060                    }
1061                }
1062            }
1063        }
1064        let mut rev_edges: Vec<HashSet<usize>> = vec![HashSet::new(); n];
1065        for i in 0..n {
1066            for &j in &self.edges[i] {
1067                rev_edges[j].insert(i);
1068            }
1069        }
1070        let mut component_id = vec![usize::MAX; n];
1071        let mut sccs: Vec<Vec<usize>> = Vec::new();
1072        let mut visited2 = vec![false; n];
1073        for &start in finish_order.iter().rev() {
1074            if visited2[start] {
1075                continue;
1076            }
1077            let scc_id = sccs.len();
1078            let mut component: Vec<usize> = Vec::new();
1079            let mut stack: Vec<usize> = vec![start];
1080            while let Some(node) = stack.pop() {
1081                if visited2[node] {
1082                    continue;
1083                }
1084                visited2[node] = true;
1085                component_id[node] = scc_id;
1086                component.push(node);
1087                for &prev in &rev_edges[node] {
1088                    if !visited2[prev] {
1089                        stack.push(prev);
1090                    }
1091                }
1092            }
1093            sccs.push(component);
1094        }
1095        let _ = component_id;
1096        sccs
1097    }
1098    /// Returns `true` if all SCCs are trivial (size ≤ 1 and no self-loops).
1099    ///
1100    /// A TRS is terminating iff its dependency pair graph has no infinite chains,
1101    /// which holds when all non-trivial SCCs can be removed by a reduction pair.
1102    pub fn all_sccs_trivial(&self) -> bool {
1103        for scc in self.sccs() {
1104            if scc.len() > 1 {
1105                return false;
1106            }
1107            if let Some(&node) = scc.first() {
1108                if self.edges[node].contains(&node) {
1109                    return false;
1110                }
1111            }
1112        }
1113        true
1114    }
1115    /// Derive dependency pairs from a Trs by inspecting rule root symbols.
1116    pub fn from_trs(trs: &Trs) -> Self {
1117        let mut graph = DependencyPairGraph::new();
1118        let defined: HashSet<String> = trs
1119            .rules
1120            .iter()
1121            .filter_map(|r| {
1122                if let Term::Fun(f, _) = &r.lhs {
1123                    Some(f.clone())
1124                } else {
1125                    None
1126                }
1127            })
1128            .collect();
1129        fn collect_calls(t: &Term, defined: &HashSet<String>, calls: &mut Vec<String>) {
1130            if let Term::Fun(f, args) = t {
1131                if defined.contains(f.as_str()) {
1132                    calls.push(f.clone());
1133                }
1134                for a in args {
1135                    collect_calls(a, defined, calls);
1136                }
1137            }
1138        }
1139        for rule in &trs.rules {
1140            if let Term::Fun(lhs_f, _) = &rule.lhs {
1141                if defined.contains(lhs_f.as_str()) {
1142                    let mut calls = Vec::new();
1143                    collect_calls(&rule.rhs, &defined, &mut calls);
1144                    for rhs_f in calls {
1145                        graph.add_pair(lhs_f.clone(), rhs_f);
1146                    }
1147                }
1148            }
1149        }
1150        let n = graph.pairs.len();
1151        for i in 0..n {
1152            for j in 0..n {
1153                if graph.pairs[i].rhs_root == graph.pairs[j].lhs_root {
1154                    graph.add_edge(i, j);
1155                }
1156            }
1157        }
1158        graph
1159    }
1160}