// rssn/symbolic/logic.rs

//! # Symbolic Logic Module
//!
//! This module provides functions for symbolic manipulation of logical expressions.
//! It includes capabilities for simplifying logical formulas, converting them to
//! normal forms (CNF, DNF), and a basic DPLL-based SAT solver for quantifier-free formulas.
6use crate::symbolic::core::Expr;
7use crate::symbolic::simplify_dag::simplify;
8use std::collections::{BTreeSet, HashMap, HashSet};
9use std::sync::Arc;
10/// Checks if a variable occurs freely in an expression.
11pub(crate) fn free_vars(expr: &Expr, free: &mut BTreeSet<String>, bound: &mut BTreeSet<String>) {
12    match expr {
13        Expr::Dag(node) => {
14            free_vars(&node.to_expr().expect("Free Vars"), free, bound);
15        }
16        Expr::Variable(s) => {
17            if !bound.contains(s) {
18                free.insert(s.clone());
19            }
20        }
21        Expr::Add(a, b)
22        | Expr::Sub(a, b)
23        | Expr::Mul(a, b)
24        | Expr::Div(a, b)
25        | Expr::Power(a, b)
26        | Expr::Eq(a, b)
27        | Expr::Lt(a, b)
28        | Expr::Gt(a, b)
29        | Expr::Le(a, b)
30        | Expr::Ge(a, b)
31        | Expr::Xor(a, b)
32        | Expr::Implies(a, b)
33        | Expr::Equivalent(a, b) => {
34            free_vars(a, free, bound);
35            free_vars(b, free, bound);
36        }
37        Expr::Neg(a) | Expr::Not(a) => {
38            free_vars(a, free, bound);
39        }
40        Expr::And(v) | Expr::Or(v) => {
41            for sub_expr in v {
42                free_vars(sub_expr, free, bound);
43            }
44        }
45        Expr::ForAll(var, body) | Expr::Exists(var, body) => {
46            bound.insert(var.clone());
47            free_vars(body, free, bound);
48            bound.remove(var);
49        }
50        Expr::Predicate { args, .. } => {
51            for arg in args {
52                free_vars(arg, free, bound);
53            }
54        }
55        _ => {}
56    }
57}
58/// Helper to check if an expression contains a specific free variable.
59pub(crate) fn has_free_var(expr: &Expr, var: &str) -> bool {
60    let mut free = BTreeSet::new();
61    let mut bound = BTreeSet::new();
62    free_vars(expr, &mut free, &mut bound);
63    free.contains(var)
64}
65/// Simplifies a logical expression by applying a set of transformation rules.
66///
67/// This function recursively traverses the expression tree and applies rules such as:
68/// - **Double Negation**: `Not(Not(P))` -> `P`
69/// - **De Morgan's Laws**: `Not(ForAll(x, P(x)))` -> `Exists(x, Not(P(x)))`
70/// - **Constant Folding**: `A And False` -> `False`, `A Or True` -> `True`
71/// - **Identity and Idempotence**: `A And True` -> `A`, `A Or A` -> `A`
72/// - **Contradiction/Tautology**: `A And Not(A)` -> `False`, `A Or Not(A)` -> `True`
73/// - **Quantifier Reduction**: Removes redundant quantifiers where the variable is not free in the body.
74/// - **Quantifier Pushing**: Moves quantifiers inwards to narrow their scope, e.g.,
75///   `ForAll(x, P(x) And Q(y))` -> `(ForAll(x, P(x))) And Q(y)`.
76///
77/// # Arguments
78/// * `expr` - The logical expression to simplify.
79///
80/// # Returns
81/// A new, simplified logical expression.
82pub fn simplify_logic(expr: &Expr) -> Expr {
83    match expr {
84        Expr::Dag(node) => simplify_logic(&node.to_expr().expect("Simplify Logic")),
85        Expr::Not(inner) => match simplify_logic(inner) {
86            Expr::Boolean(b) => Expr::Boolean(!b),
87            Expr::Not(sub) => (*sub).clone(),
88            Expr::ForAll(var, body) => {
89                Expr::Exists(var, Arc::new(simplify_logic(&Expr::new_not(body))))
90            }
91            Expr::Exists(var, body) => {
92                Expr::ForAll(var, Arc::new(simplify_logic(&Expr::new_not(body))))
93            }
94            simplified_inner => Expr::new_not(simplified_inner),
95        },
96        Expr::And(v) => {
97            let mut new_terms = Vec::new();
98            for term in v {
99                let simplified = simplify_logic(term);
100                match simplified {
101                    Expr::Boolean(false) => return Expr::Boolean(false),
102                    Expr::Boolean(true) => continue,
103                    Expr::And(mut sub_terms) => new_terms.append(&mut sub_terms),
104                    _ => new_terms.push(simplified),
105                }
106            }
107            let mut unique_terms = BTreeSet::new();
108            for term in new_terms {
109                unique_terms.insert(term);
110            }
111            for term in &unique_terms {
112                if unique_terms.contains(&Expr::new_not(term.clone())) {
113                    return Expr::Boolean(false);
114                }
115            }
116            if unique_terms.is_empty() {
117                Expr::Boolean(true)
118            } else if unique_terms.len() == 1 {
119                unique_terms
120                    .into_iter()
121                    .next()
122                    .expect("Unique Term Parsing Failed")
123            } else {
124                Expr::And(unique_terms.into_iter().collect())
125            }
126        }
127        Expr::Or(v) => {
128            let mut new_terms = Vec::new();
129            for term in v {
130                let simplified = simplify_logic(term);
131                match simplified {
132                    Expr::Boolean(true) => return Expr::Boolean(true),
133                    Expr::Boolean(false) => continue,
134                    Expr::Or(mut sub_terms) => new_terms.append(&mut sub_terms),
135                    _ => new_terms.push(simplified),
136                }
137            }
138            let mut unique_terms = BTreeSet::new();
139            for term in new_terms {
140                unique_terms.insert(term);
141            }
142            for term in &unique_terms {
143                if unique_terms.contains(&Expr::new_not(term.clone())) {
144                    return Expr::Boolean(true);
145                }
146            }
147            if unique_terms.is_empty() {
148                Expr::Boolean(false)
149            } else if unique_terms.len() == 1 {
150                unique_terms
151                    .into_iter()
152                    .next()
153                    .expect("Unique Term Parsing Failed")
154            } else {
155                Expr::Or(unique_terms.into_iter().collect())
156            }
157        }
158        Expr::Implies(a, b) => simplify_logic(&Expr::Or(vec![
159            Expr::Not(Arc::new(a.as_ref().clone())),
160            b.as_ref().clone(),
161        ])),
162        Expr::Equivalent(a, b) => simplify_logic(&Expr::And(vec![
163            Expr::Implies(a.clone(), b.clone()),
164            Expr::Implies(b.clone(), a.clone()),
165        ])),
166        Expr::Xor(a, b) => simplify_logic(&Expr::And(vec![
167            Expr::Or(vec![a.as_ref().clone(), b.as_ref().clone()]),
168            Expr::Not(Arc::new(Expr::And(vec![
169                a.as_ref().clone(),
170                b.as_ref().clone(),
171            ]))),
172        ])),
173        Expr::ForAll(var, body) => {
174            let simplified_body = simplify_logic(body);
175            if !has_free_var(&simplified_body, var) {
176                return simplified_body;
177            }
178            if let Expr::And(terms) = &simplified_body {
179                let mut with_var = vec![];
180                let mut without_var = vec![];
181                for term in terms {
182                    if has_free_var(term, var) {
183                        with_var.push(term.clone());
184                    } else {
185                        without_var.push(term.clone());
186                    }
187                }
188                if !without_var.is_empty() {
189                    let forall_part = if with_var.is_empty() {
190                        Expr::Boolean(true)
191                    } else {
192                        Expr::ForAll(var.clone(), Arc::new(Expr::And(with_var)))
193                    };
194                    without_var.push(simplify_logic(&forall_part));
195                    return simplify_logic(&Expr::And(without_var));
196                }
197            }
198            Expr::ForAll(var.clone(), Arc::new(simplified_body))
199        }
200        Expr::Exists(var, body) => {
201            let simplified_body = simplify_logic(body);
202            if !has_free_var(&simplified_body, var) {
203                return simplified_body;
204            }
205            if let Expr::Or(terms) = &simplified_body {
206                let mut with_var = vec![];
207                let mut without_var = vec![];
208                for term in terms {
209                    if has_free_var(term, var) {
210                        with_var.push(term.clone());
211                    } else {
212                        without_var.push(term.clone());
213                    }
214                }
215                if !without_var.is_empty() {
216                    let exists_part = if with_var.is_empty() {
217                        Expr::Boolean(false)
218                    } else {
219                        Expr::Exists(var.clone(), Arc::new(Expr::Or(with_var)))
220                    };
221                    without_var.push(simplify_logic(&exists_part));
222                    return simplify_logic(&Expr::Or(without_var));
223                }
224            }
225            Expr::Exists(var.clone(), Arc::new(simplified_body))
226        }
227        Expr::Predicate { name, args } => Expr::Predicate {
228            name: name.clone(),
229            args: args
230                .iter()
231                .map(|expr: &Expr| simplify(&expr.clone()))
232                .collect(),
233        },
234        _ => expr.clone(),
235    }
236}
237pub(crate) fn to_basic_logic_ops(expr: &Expr) -> Expr {
238    match expr {
239        Expr::Dag(node) => to_basic_logic_ops(&node.to_expr().expect("To Basic Logic Ops")),
240        Expr::Implies(a, b) => Expr::Or(vec![
241            Expr::Not(Arc::new(to_basic_logic_ops(a))),
242            to_basic_logic_ops(b),
243        ]),
244        Expr::Equivalent(a, b) => Expr::And(vec![
245            Expr::Or(vec![
246                Expr::Not(Arc::new(to_basic_logic_ops(a))),
247                to_basic_logic_ops(b),
248            ]),
249            Expr::Or(vec![
250                Expr::Not(Arc::new(to_basic_logic_ops(b))),
251                to_basic_logic_ops(a),
252            ]),
253        ]),
254        Expr::Xor(a, b) => Expr::And(vec![
255            Expr::Or(vec![to_basic_logic_ops(a), to_basic_logic_ops(b)]),
256            Expr::Not(Arc::new(Expr::And(vec![
257                to_basic_logic_ops(a),
258                to_basic_logic_ops(b),
259            ]))),
260        ]),
261        Expr::And(v) => Expr::And(v.iter().map(to_basic_logic_ops).collect()),
262        Expr::Or(v) => Expr::Or(v.iter().map(to_basic_logic_ops).collect()),
263        Expr::Not(a) => Expr::new_not(to_basic_logic_ops(a)),
264        _ => expr.clone(),
265    }
266}
267pub(crate) fn move_not_inwards(expr: &Expr) -> Expr {
268    match expr {
269        Expr::Dag(node) => move_not_inwards(&node.to_expr().expect("Move not Inwards")),
270        Expr::Not(a) => match &**a {
271            Expr::And(v) => Expr::Or(
272                v.iter()
273                    .map(|e| move_not_inwards(&Expr::new_not(e.clone())))
274                    .collect(),
275            ),
276            Expr::Or(v) => Expr::And(
277                v.iter()
278                    .map(|e| move_not_inwards(&Expr::new_not(e.clone())))
279                    .collect(),
280            ),
281            Expr::Not(b) => move_not_inwards(b),
282            Expr::ForAll(var, body) => Expr::Exists(
283                var.clone(),
284                Arc::new(move_not_inwards(&Expr::new_not(body.clone()))),
285            ),
286            Expr::Exists(var, body) => Expr::ForAll(
287                var.clone(),
288                Arc::new(move_not_inwards(&Expr::new_not(body.clone()))),
289            ),
290            _ => expr.clone(),
291        },
292        Expr::And(v) => Expr::And(v.iter().map(move_not_inwards).collect()),
293        Expr::Or(v) => Expr::Or(v.iter().map(move_not_inwards).collect()),
294        _ => expr.clone(),
295    }
296}
297pub(crate) fn distribute_or_over_and(expr: &Expr) -> Expr {
298    match expr {
299        Expr::Dag(node) => distribute_or_over_and(&node.to_expr().expect("Distribute or Over")),
300        Expr::Or(v) => {
301            let v_dist: Vec<Expr> = v.iter().map(distribute_or_over_and).collect();
302            if let Some(pos) = v_dist.iter().position(|e| matches!(e, Expr::And(_))) {
303                let and_clause = v_dist[pos].clone();
304                let other_terms: Vec<Expr> = v_dist
305                    .iter()
306                    .enumerate()
307                    .filter(|(i, _)| *i != pos)
308                    .map(|(_, e)| e.clone())
309                    .collect();
310                if let Expr::And(and_terms) = and_clause {
311                    let new_clauses = and_terms
312                        .iter()
313                        .map(|term| {
314                            let mut new_or_list = other_terms.clone();
315                            new_or_list.push(term.clone());
316                            distribute_or_over_and(&Expr::Or(new_or_list))
317                        })
318                        .collect();
319                    return Expr::And(new_clauses);
320                }
321            }
322            Expr::Or(v_dist)
323        }
324        Expr::And(v) => Expr::And(v.iter().map(distribute_or_over_and).collect()),
325        _ => expr.clone(),
326    }
327}
328/// Converts a logical expression into Conjunctive Normal Form (CNF).
329///
330/// CNF is a standardized representation of a logical formula which is a conjunction
331/// of one or more clauses, where each clause is a disjunction of literals.
332/// The conversion process involves three main steps:
333/// 1.  Eliminating complex logical operators like `Implies`, `Equivalent`, and `Xor`.
334/// 2.  Moving all `Not` operators inwards using De Morgan's laws.
335/// 3.  Distributing `Or` over `And` to achieve the final CNF structure.
336///
337/// # Arguments
338/// * `expr` - The logical expression to convert.
339///
340/// # Returns
341/// An equivalent expression in Conjunctive Normal Form.
342pub fn to_cnf(expr: &Expr) -> Expr {
343    let simplified = simplify_logic(expr);
344    let basic_ops = to_basic_logic_ops(&simplified);
345    let not_inwards = move_not_inwards(&basic_ops);
346    let distributed = distribute_or_over_and(&not_inwards);
347    simplify_logic(&distributed)
348}
349/// Converts a logical expression into Disjunctive Normal Form (DNF).
350///
351/// DNF is a standardized representation of a logical formula which is a disjunction
352/// of one or more clauses, where each clause is a conjunction of literals.
353/// This implementation cleverly achieves the conversion by using the `to_cnf` function:
354/// 1.  The input expression `expr` is negated: `Not(expr)`.
355/// 2.  The negated expression is converted to CNF: `cnf(Not(expr))`.
356/// 3.  The resulting CNF is negated again, and De Morgan's laws are applied implicitly
357///     by `simplify_logic`, resulting in the DNF of the original expression.
358///
359/// # Arguments
360/// * `expr` - The logical expression to convert.
361///
362/// # Returns
363/// An equivalent expression in Disjunctive Normal Form.
364pub fn to_dnf(expr: &Expr) -> Expr {
365    let not_expr = simplify_logic(&Expr::new_not(expr.clone()));
366    let cnf_of_not = to_cnf(&not_expr);
367    simplify_logic(&Expr::new_not(cnf_of_not))
368}
369/// Determines if a quantifier-free logical formula is satisfiable using the DPLL algorithm.
370///
371/// This function first checks if the expression contains any quantifiers (`ForAll`, `Exists`).
372/// If it does, the problem is generally undecidable, and the function returns `None`.
373///
374/// For quantifier-free formulas, it proceeds by:
375/// 1.  Converting the expression to Conjunctive Normal Form (CNF).
376/// 2.  Applying the recursive DPLL (Davis-Putnam-Logemann-Loveland) algorithm to the CNF clauses.
377///
378/// The DPLL algorithm attempts to find a satisfying assignment for the propositional variables
379/// (in this case, predicate instances like `P(x)`) by using unit propagation, pure literal
380/// elimination (implicitly), and recursive branching on variable assignments.
381///
382/// # Arguments
383/// * `expr` - A logical expression, which should be quantifier-free for a definitive result.
384///
385/// # Returns
386/// * `Some(true)` if the formula is satisfiable.
387/// * `Some(false)` if the formula is unsatisfiable.
388/// * `None` if the formula contains quantifiers, as this solver does not handle them.
389pub fn is_satisfiable(expr: &Expr) -> Option<bool> {
390    if contains_quantifier(expr) {
391        return None;
392    }
393    let cnf = to_cnf(expr);
394    if let Expr::Boolean(b) = cnf {
395        return Some(b);
396    }
397    let mut clauses = extract_clauses(&cnf);
398    let mut assignments = HashMap::new();
399    Some(dpll(&mut clauses, &mut assignments))
400}
401/// A literal is an atomic proposition (e.g., P(x)) or its negation.
402#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
403pub enum Literal {
404    Positive(Expr),
405    Negative(Expr),
406}
407pub(crate) const fn get_atom(literal: &Literal) -> &Expr {
408    match literal {
409        Literal::Positive(atom) => atom,
410        Literal::Negative(atom) => atom,
411    }
412}
413pub(crate) fn extract_clauses(cnf_expr: &Expr) -> Vec<HashSet<Literal>> {
414    let mut clauses = Vec::new();
415    if let Expr::And(conjuncts) = cnf_expr {
416        for clause_expr in conjuncts {
417            clauses.push(extract_literals_from_clause(clause_expr));
418        }
419    } else {
420        clauses.push(extract_literals_from_clause(cnf_expr));
421    }
422    clauses
423}
424pub(crate) fn extract_literals_from_clause(clause_expr: &Expr) -> HashSet<Literal> {
425    let mut literals = HashSet::new();
426    if let Expr::Or(disjuncts) = clause_expr {
427        for literal_expr in disjuncts {
428            if let Expr::Not(atom) = literal_expr {
429                literals.insert(Literal::Negative(atom.as_ref().clone()));
430            } else {
431                literals.insert(Literal::Positive(literal_expr.clone()));
432            }
433        }
434    } else if let Expr::Not(atom) = clause_expr {
435        literals.insert(Literal::Negative(atom.as_ref().clone()));
436    } else {
437        literals.insert(Literal::Positive(clause_expr.clone()));
438    }
439    literals
440}
441pub(crate) fn dpll(
442    clauses: &mut Vec<HashSet<Literal>>,
443    assignments: &mut HashMap<Expr, bool>,
444) -> bool {
445    while let Some(unit_literal) = find_unit_clause(clauses) {
446        let (atom, value) = match unit_literal {
447            Literal::Positive(a) => (a, true),
448            Literal::Negative(a) => (a, false),
449        };
450        assignments.insert(atom.clone(), value);
451        simplify_clauses(clauses, &atom, value);
452        if clauses.is_empty() {
453            return true;
454        }
455        if clauses.iter().any(HashSet::is_empty) {
456            return false;
457        }
458    }
459    if clauses.is_empty() {
460        return true;
461    }
462    let atom_to_branch = match get_unassigned_atom(clauses, assignments) {
463        Some(v) => v,
464        _none => return true,
465    };
466    let mut clauses_true = clauses.clone();
467    let mut assignments_true = assignments.clone();
468    assignments_true.insert(atom_to_branch.clone(), true);
469    simplify_clauses(&mut clauses_true, &atom_to_branch, true);
470    if dpll(&mut clauses_true, &mut assignments_true) {
471        return true;
472    }
473    let mut clauses_false = clauses.clone();
474    let mut assignments_false = assignments.clone();
475    assignments_false.insert(atom_to_branch.clone(), false);
476    simplify_clauses(&mut clauses_false, &atom_to_branch, false);
477    if dpll(&mut clauses_false, &mut assignments_false) {
478        return true;
479    }
480    false
481}
482pub(crate) fn find_unit_clause(clauses: &[HashSet<Literal>]) -> Option<Literal> {
483    clauses
484        .iter()
485        .find(|c| c.len() == 1)
486        .and_then(|c| c.iter().next().cloned())
487}
488pub(crate) fn simplify_clauses(clauses: &mut Vec<HashSet<Literal>>, atom: &Expr, value: bool) {
489    clauses.retain(|clause| {
490        !clause.iter().any(|lit| match lit {
491            Literal::Positive(a) => a == atom && value,
492            Literal::Negative(a) => a == atom && !value,
493        })
494    });
495    let opposite_literal = if value {
496        Literal::Negative(atom.clone())
497    } else {
498        Literal::Positive(atom.clone())
499    };
500    for clause in clauses {
501        clause.remove(&opposite_literal);
502    }
503}
504pub(crate) fn get_unassigned_atom(
505    clauses: &[HashSet<Literal>],
506    assignments: &HashMap<Expr, bool>,
507) -> Option<Expr> {
508    for clause in clauses {
509        for literal in clause {
510            let atom = get_atom(literal);
511            if !assignments.contains_key(atom) {
512                return Some(atom.clone());
513            }
514        }
515    }
516    None
517}
518pub(crate) fn contains_quantifier(expr: &Expr) -> bool {
519    match expr {
520        Expr::Dag(node) => contains_quantifier(&node.to_expr().expect("Contains Quantifier")),
521        Expr::ForAll(_, _) | Expr::Exists(_, _) => true,
522        Expr::Add(a, b)
523        | Expr::Sub(a, b)
524        | Expr::Mul(a, b)
525        | Expr::Div(a, b)
526        | Expr::Power(a, b)
527        | Expr::Eq(a, b)
528        | Expr::Lt(a, b)
529        | Expr::Gt(a, b)
530        | Expr::Le(a, b)
531        | Expr::Ge(a, b)
532        | Expr::Xor(a, b)
533        | Expr::Implies(a, b)
534        | Expr::Equivalent(a, b) => contains_quantifier(a) || contains_quantifier(b),
535        Expr::Neg(a) | Expr::Not(a) => contains_quantifier(a),
536        Expr::And(v) | Expr::Or(v) => v.iter().any(contains_quantifier),
537        Expr::Predicate { args, .. } => args.iter().any(contains_quantifier),
538        _ => false,
539    }
540}