Skip to main content

oxiz_sat/
gate.rs

//! Gate detection and optimization
//!
//! This module implements gate detection and extraction from CNF formulas.
//! Many real-world SAT instances contain encoded logical gates (AND, OR, XOR, ITE).
//! Detecting these gates enables specialized reasoning and optimizations.
//!
//! References:
//! - "Gate Extraction for CNF Formulas" (Eén et al.)
//! - "Effective Preprocessing with Hyper-Resolution and Equality Reduction" (Bacchus & Winter)
//! - "Recognition of Nested Gates in CNF Formulas" (Iser, Manthey & Sinz)
11
12use crate::clause::ClauseDatabase;
13use crate::literal::Lit;
14#[allow(unused_imports)]
15use crate::prelude::*;
16
/// Statistics for gate detection
#[derive(Debug, Clone, Default)]
pub struct GateStats {
    /// Number of AND gates detected
    pub and_gates: usize,
    /// Number of OR gates detected
    pub or_gates: usize,
    /// Number of XOR gates detected
    pub xor_gates: usize,
    /// Number of ITE (if-then-else) gates detected
    pub ite_gates: usize,
    /// Number of MUX gates detected
    pub mux_gates: usize,
    /// Number of equivalent gates merged
    pub gates_merged: usize,
    /// Number of clauses removed via gate substitution
    pub clauses_removed: usize,
}

impl GateStats {
    /// Print the statistics to standard output, one counter per line.
    ///
    /// The text is exactly the `Display` rendering of `self`, so the same report can also be
    /// written to a log or any other `fmt::Write`/`io::Write` sink via `format!`/`write!`.
    pub fn display(&self) {
        // One `print!` of the whole report instead of one stdout lock per line.
        print!("{self}");
    }
}

impl core::fmt::Display for GateStats {
    /// Render the statistics as a multi-line report; every line ends with `\n`.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        writeln!(f, "Gate Detection Statistics:")?;
        writeln!(f, "  AND gates: {}", self.and_gates)?;
        writeln!(f, "  OR gates: {}", self.or_gates)?;
        writeln!(f, "  XOR gates: {}", self.xor_gates)?;
        writeln!(f, "  ITE gates: {}", self.ite_gates)?;
        writeln!(f, "  MUX gates: {}", self.mux_gates)?;
        writeln!(f, "  Gates merged: {}", self.gates_merged)?;
        writeln!(f, "  Clauses removed: {}", self.clauses_removed)
    }
}
49
50/// Type of logical gate
51#[derive(Debug, Clone, PartialEq, Eq, Hash)]
52pub enum GateType {
53    /// AND gate: output = a ∧ b
54    And {
55        /// Output literal of the AND gate
56        output: Lit,
57        /// Input literals
58        inputs: Vec<Lit>,
59    },
60    /// OR gate: output = a ∨ b
61    Or {
62        /// Output literal of the OR gate
63        output: Lit,
64        /// Input literals
65        inputs: Vec<Lit>,
66    },
67    /// XOR gate: output = a ⊕ b
68    Xor {
69        /// Output literal of the XOR gate
70        output: Lit,
71        /// Input literals
72        inputs: Vec<Lit>,
73    },
74    /// ITE gate: output = if c then t else e
75    Ite {
76        /// Output literal of the ITE gate
77        output: Lit,
78        /// Condition literal
79        condition: Lit,
80        /// Then-value literal
81        then_val: Lit,
82        /// Else-value literal
83        else_val: Lit,
84    },
85    /// MUX gate: generalized ITE with multiple conditions
86    Mux {
87        /// Output literal of the MUX gate
88        output: Lit,
89        /// Select literals
90        select: Vec<Lit>,
91        /// Input literals
92        inputs: Vec<Lit>,
93    },
94}
95
96impl GateType {
97    /// Get the output literal of this gate
98    #[must_use]
99    pub fn output(&self) -> Lit {
100        match self {
101            GateType::And { output, .. }
102            | GateType::Or { output, .. }
103            | GateType::Xor { output, .. }
104            | GateType::Ite { output, .. }
105            | GateType::Mux { output, .. } => *output,
106        }
107    }
108
109    /// Get all input literals of this gate
110    #[must_use]
111    pub fn inputs(&self) -> Vec<Lit> {
112        match self {
113            GateType::And { inputs, .. }
114            | GateType::Or { inputs, .. }
115            | GateType::Xor { inputs, .. } => inputs.clone(),
116            GateType::Ite {
117                condition,
118                then_val,
119                else_val,
120                ..
121            } => vec![*condition, *then_val, *else_val],
122            GateType::Mux { select, inputs, .. } => {
123                let mut all = select.clone();
124                all.extend(inputs);
125                all
126            }
127        }
128    }
129}
130
131/// Gate detector and optimizer
132#[derive(Debug)]
133pub struct GateDetector {
134    /// Detected gates indexed by output literal
135    gates: HashMap<Lit, GateType>,
136    /// Clauses that define each literal
137    definitions: HashMap<Lit, Vec<Vec<Lit>>>,
138    /// Statistics
139    stats: GateStats,
140}
141
142impl Default for GateDetector {
143    fn default() -> Self {
144        Self::new()
145    }
146}
147
148impl GateDetector {
149    /// Create a new gate detector
150    #[must_use]
151    pub fn new() -> Self {
152        Self {
153            gates: HashMap::new(),
154            definitions: HashMap::new(),
155            stats: GateStats::default(),
156        }
157    }
158
159    /// Build clause definitions for each literal
160    pub fn build_definitions(&mut self, clauses: &ClauseDatabase) {
161        self.definitions.clear();
162
163        for cid in clauses.iter_ids() {
164            if let Some(clause) = clauses.get(cid) {
165                let lits: Vec<_> = clause.lits.to_vec();
166
167                // For each literal in the clause, record this clause as a definition
168                for &lit in &lits {
169                    self.definitions.entry(lit).or_default().push(lits.clone());
170                }
171            }
172        }
173    }
174
175    /// Detect AND gates
176    ///
177    /// Pattern: output = a ∧ b ∧ ...
178    /// CNF: (~a ∨ ~b ∨ output) ∧ (a ∨ ~output) ∧ (b ∨ ~output) ...
179    pub fn detect_and_gates(&mut self) {
180        let outputs: Vec<_> = self.definitions.keys().copied().collect();
181
182        for output in outputs {
183            if self.gates.contains_key(&output) {
184                continue; // Already detected a gate for this literal
185            }
186
187            if let Some(gate_inputs) = self.try_extract_and_gate(output) {
188                self.gates.insert(
189                    output,
190                    GateType::And {
191                        output,
192                        inputs: gate_inputs,
193                    },
194                );
195                self.stats.and_gates += 1;
196            }
197        }
198    }
199
200    /// Try to extract an AND gate with the given output
201    fn try_extract_and_gate(&self, output: Lit) -> Option<Vec<Lit>> {
202        // For AND gate: out = a ∧ b
203        // CNF: (~a ∨ ~b ∨ out) ∧ (a ∨ ~out) ∧ (b ∨ ~out)
204
205        // Look for binary clauses (input ∨ ~output) in definitions[~output]
206        let neg_output_defs = self.definitions.get(&!output)?;
207        let mut inputs = Vec::new();
208
209        for clause in neg_output_defs {
210            if clause.len() == 2 && clause.contains(&!output) {
211                // Binary clause (input ∨ ~output)
212                let input = clause.iter().find(|&&lit| lit != !output)?;
213                inputs.push(*input);
214            }
215        }
216
217        if inputs.is_empty() {
218            return None;
219        }
220
221        // Check for the main AND clause: (~a ∨ ~b ∨ ... ∨ output) in definitions[output]
222        let output_defs = self.definitions.get(&output)?;
223
224        for clause in output_defs {
225            if clause.contains(&output) && clause.len() == inputs.len() + 1 {
226                // Check if all other literals are negations of inputs
227                let other_lits: Vec<_> = clause
228                    .iter()
229                    .filter(|&&lit| lit != output)
230                    .copied()
231                    .collect();
232
233                let expected_lits: HashSet<_> = inputs.iter().map(|&lit| !lit).collect();
234                let actual_lits: HashSet<_> = other_lits.iter().copied().collect();
235
236                if expected_lits == actual_lits {
237                    return Some(inputs);
238                }
239            }
240        }
241
242        None
243    }
244
245    /// Detect OR gates
246    ///
247    /// Pattern: output = a ∨ b ∨ ...
248    /// CNF: (a ∨ b ∨ ~output) ∧ (~a ∨ output) ∧ (~b ∨ output) ...
249    pub fn detect_or_gates(&mut self) {
250        let outputs: Vec<_> = self.definitions.keys().copied().collect();
251
252        for output in outputs {
253            if self.gates.contains_key(&output) {
254                continue;
255            }
256
257            if let Some(gate_inputs) = self.try_extract_or_gate(output) {
258                self.gates.insert(
259                    output,
260                    GateType::Or {
261                        output,
262                        inputs: gate_inputs,
263                    },
264                );
265                self.stats.or_gates += 1;
266            }
267        }
268    }
269
270    /// Try to extract an OR gate with the given output
271    fn try_extract_or_gate(&self, output: Lit) -> Option<Vec<Lit>> {
272        // For OR gate: out = a ∨ b
273        // CNF: (a ∨ b ∨ ~out) ∧ (~a ∨ out) ∧ (~b ∨ out)
274
275        // Look for binary clauses (~input ∨ output) in definitions[output]
276        let output_defs = self.definitions.get(&output)?;
277        let mut inputs = Vec::new();
278
279        for clause in output_defs {
280            if clause.len() == 2 && clause.contains(&output) {
281                // Binary clause (~input ∨ output)
282                let neg_input = clause.iter().find(|&&lit| lit != output)?;
283                inputs.push(!*neg_input);
284            }
285        }
286
287        if inputs.is_empty() {
288            return None;
289        }
290
291        // Check for the main OR clause: (a ∨ b ∨ ... ∨ ~output) in definitions[~output]
292        let neg_output_defs = self.definitions.get(&!output)?;
293
294        for clause in neg_output_defs {
295            if clause.contains(&!output) && clause.len() == inputs.len() + 1 {
296                // Check if all other literals match inputs
297                let other_lits: Vec<_> = clause
298                    .iter()
299                    .filter(|&&lit| lit != !output)
300                    .copied()
301                    .collect();
302
303                let expected_lits: HashSet<_> = inputs.iter().copied().collect();
304                let actual_lits: HashSet<_> = other_lits.iter().copied().collect();
305
306                if expected_lits == actual_lits {
307                    return Some(inputs);
308                }
309            }
310        }
311
312        None
313    }
314
315    /// Detect ITE (if-then-else) gates
316    ///
317    /// Pattern: output = if c then t else e
318    /// CNF: (~c ∨ ~t ∨ output) ∧ (c ∨ ~e ∨ output) ∧ (~c ∨ t ∨ ~output) ∧ (c ∨ e ∨ ~output)
319    pub fn detect_ite_gates(&mut self) {
320        let outputs: Vec<_> = self.definitions.keys().copied().collect();
321
322        for output in outputs {
323            if self.gates.contains_key(&output) {
324                continue;
325            }
326
327            if let Some((condition, then_val, else_val)) = self.try_extract_ite_gate(output) {
328                self.gates.insert(
329                    output,
330                    GateType::Ite {
331                        output,
332                        condition,
333                        then_val,
334                        else_val,
335                    },
336                );
337                self.stats.ite_gates += 1;
338            }
339        }
340    }
341
342    /// Try to extract an ITE gate with the given output
343    fn try_extract_ite_gate(&self, output: Lit) -> Option<(Lit, Lit, Lit)> {
344        let output_defs = self.definitions.get(&output)?;
345        let neg_output_defs = self.definitions.get(&!output)?;
346
347        // Look for the characteristic clauses of an ITE
348        // This is a simplified detection - a full implementation would be more thorough
349        for clause1 in output_defs {
350            if clause1.len() == 3 {
351                for clause2 in neg_output_defs {
352                    if clause2.len() == 3 {
353                        // Try to match ITE pattern
354                        // This is a placeholder - full pattern matching would be complex
355                        // For now, we'll just detect simple cases
356                    }
357                }
358            }
359        }
360
361        None // Simplified - would need full pattern matching
362    }
363
364    /// Detect all gate types
365    pub fn detect_all(&mut self, clauses: &ClauseDatabase) {
366        self.build_definitions(clauses);
367        self.detect_and_gates();
368        self.detect_or_gates();
369        self.detect_ite_gates();
370    }
371
372    /// Get all detected gates
373    #[must_use]
374    pub fn gates(&self) -> &HashMap<Lit, GateType> {
375        &self.gates
376    }
377
378    /// Get gate for a specific literal
379    #[must_use]
380    pub fn get_gate(&self, lit: Lit) -> Option<&GateType> {
381        self.gates.get(&lit)
382    }
383
384    /// Check if a literal is a gate output
385    #[must_use]
386    pub fn is_gate_output(&self, lit: Lit) -> bool {
387        self.gates.contains_key(&lit)
388    }
389
390    /// Get statistics
391    #[must_use]
392    pub fn stats(&self) -> &GateStats {
393        &self.stats
394    }
395
396    /// Reset statistics
397    pub fn reset_stats(&mut self) {
398        self.stats = GateStats::default();
399    }
400
401    /// Clear all detected gates
402    pub fn clear(&mut self) {
403        self.gates.clear();
404        self.definitions.clear();
405    }
406}
407
#[cfg(test)]
mod tests {
    use super::*;
    use crate::literal::Var;

    /// A freshly created detector has counted no gates.
    #[test]
    fn test_gate_detector_creation() {
        assert_eq!(GateDetector::new().stats().and_gates, 0);
    }

    /// Every literal of an added clause gets an entry in the occurrence index.
    #[test]
    fn test_build_definitions() {
        let a = Lit::pos(Var::new(0));
        let b = Lit::pos(Var::new(1));

        let mut db = ClauseDatabase::new();
        db.add_original(vec![a, b]);

        let mut detector = GateDetector::new();
        detector.build_definitions(&db);

        for lit in [a, b] {
            assert!(detector.definitions.contains_key(&lit));
        }
    }

    /// `out = a ∧ b`, encoded as `(~a ∨ ~b ∨ out) ∧ (a ∨ ~out) ∧ (b ∨ ~out)`, is reported as
    /// an AND gate on `out` with inputs `a` and `b`.
    #[test]
    fn test_detect_and_gate() {
        let a = Lit::pos(Var::new(0));
        let b = Lit::pos(Var::new(1));
        let out = Lit::pos(Var::new(2));

        let mut db = ClauseDatabase::new();
        for clause in [vec![!a, !b, out], vec![a, !out], vec![b, !out]] {
            db.add_original(clause);
        }

        let mut detector = GateDetector::new();
        detector.detect_all(&db);

        assert!(detector.is_gate_output(out));
        match detector.get_gate(out) {
            Some(GateType::And { inputs, .. }) => {
                assert_eq!(inputs.len(), 2);
                assert!(inputs.contains(&a));
                assert!(inputs.contains(&b));
            }
            _ => panic!("Expected AND gate"),
        }
    }

    /// `out = a ∨ b`, encoded as `(a ∨ b ∨ ~out) ∧ (~a ∨ out) ∧ (~b ∨ out)`, is reported as
    /// an OR gate on `out` with inputs `a` and `b`.
    #[test]
    fn test_detect_or_gate() {
        let a = Lit::pos(Var::new(0));
        let b = Lit::pos(Var::new(1));
        let out = Lit::pos(Var::new(2));

        let mut db = ClauseDatabase::new();
        for clause in [vec![a, b, !out], vec![!a, out], vec![!b, out]] {
            db.add_original(clause);
        }

        let mut detector = GateDetector::new();
        detector.detect_all(&db);

        assert!(detector.is_gate_output(out));
        match detector.get_gate(out) {
            Some(GateType::Or { inputs, .. }) => {
                assert_eq!(inputs.len(), 2);
                assert!(inputs.contains(&a));
                assert!(inputs.contains(&b));
            }
            _ => panic!("Expected OR gate"),
        }
    }

    /// `GateType` accessors expose the stored output and inputs.
    #[test]
    fn test_gate_type_output() {
        let out = Lit::pos(Var::new(2));
        let gate = GateType::And {
            output: out,
            inputs: vec![Lit::pos(Var::new(0)), Lit::pos(Var::new(1))],
        };

        assert_eq!(gate.output(), out);
        assert_eq!(gate.inputs().len(), 2);
    }

    /// `clear` empties both the occurrence index and the gate table.
    #[test]
    fn test_clear() {
        let mut db = ClauseDatabase::new();
        db.add_original(vec![Lit::pos(Var::new(0)), Lit::pos(Var::new(1))]);

        let mut detector = GateDetector::new();
        detector.build_definitions(&db);
        assert!(!detector.definitions.is_empty());

        detector.clear();
        assert!(detector.definitions.is_empty());
        assert!(detector.gates.is_empty());
    }
}