Skip to main content

oxiz_theories/set/
operations.rs

1//! Set Operations Implementation
2//!
3//! Implements symbolic set operations:
4//! - Union (∪)
5//! - Intersection (∩)
6//! - Difference (\)
7//! - Complement (¬)
8//! - Cartesian product (×)
9//! - Symmetric difference (△)
10
11#![allow(missing_docs)]
12#![allow(dead_code)]
13
14use super::{SetConflict, SetVarId};
15use rustc_hash::{FxHashMap, FxHashSet};
16use smallvec::SmallVec;
17
/// The kind of a two-operand set operation.
///
/// Each variant names the operator applied as `lhs OP rhs`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum SetBinOp {
    /// `lhs ∪ rhs`: elements found in either operand.
    Union,
    /// `lhs ∩ rhs`: elements found in both operands.
    Intersection,
    /// `lhs \ rhs`: elements of `lhs` that are absent from `rhs`.
    Difference,
    /// `lhs △ rhs`: elements found in exactly one operand.
    SymmetricDiff,
}
30
31/// Set operation builder for complex expressions
32#[derive(Debug, Clone)]
33pub struct SetOpBuilder {
34    /// Operations to apply
35    ops: Vec<SetOp>,
36    /// Intermediate results
37    intermediates: Vec<SetVarId>,
38}
39
40impl SetOpBuilder {
41    /// Create a new builder
42    pub fn new() -> Self {
43        Self {
44            ops: Vec::new(),
45            intermediates: Vec::new(),
46        }
47    }
48
49    /// Add a union operation
50    pub fn union(mut self, lhs: SetVarId, rhs: SetVarId, result: SetVarId) -> Self {
51        self.ops.push(SetOp::Binary {
52            op: SetBinOp::Union,
53            lhs,
54            rhs,
55            result,
56        });
57        self
58    }
59
60    /// Add an intersection operation
61    pub fn intersection(mut self, lhs: SetVarId, rhs: SetVarId, result: SetVarId) -> Self {
62        self.ops.push(SetOp::Binary {
63            op: SetBinOp::Intersection,
64            lhs,
65            rhs,
66            result,
67        });
68        self
69    }
70
71    /// Add a difference operation
72    pub fn difference(mut self, lhs: SetVarId, rhs: SetVarId, result: SetVarId) -> Self {
73        self.ops.push(SetOp::Binary {
74            op: SetBinOp::Difference,
75            lhs,
76            rhs,
77            result,
78        });
79        self
80    }
81
82    /// Add a complement operation
83    pub fn complement(mut self, set: SetVarId, result: SetVarId) -> Self {
84        self.ops.push(SetOp::Complement { set, result });
85        self
86    }
87
88    /// Build the operation sequence
89    pub fn build(self) -> Vec<SetOp> {
90        self.ops
91    }
92}
93
94impl Default for SetOpBuilder {
95    fn default() -> Self {
96        Self::new()
97    }
98}
99
100/// Set operation
101#[derive(Debug, Clone)]
102pub enum SetOp {
103    /// Binary operation: result = lhs op rhs
104    Binary {
105        op: SetBinOp,
106        lhs: SetVarId,
107        rhs: SetVarId,
108        result: SetVarId,
109    },
110    /// Complement: result = ¬set
111    Complement { set: SetVarId, result: SetVarId },
112}
113
114/// Set union implementation
115#[derive(Debug, Clone)]
116pub struct SetUnion {
117    /// Left operand
118    pub lhs: SetVarId,
119    /// Right operand
120    pub rhs: SetVarId,
121    /// Result variable
122    pub result: SetVarId,
123}
124
125impl SetUnion {
126    /// Create a new union operation
127    pub fn new(lhs: SetVarId, rhs: SetVarId, result: SetVarId) -> Self {
128        Self { lhs, rhs, result }
129    }
130
131    /// Propagate union constraints
132    ///
133    /// For result = lhs ∪ rhs:
134    /// - x ∈ result ⟺ x ∈ lhs ∨ x ∈ rhs
135    /// - x ∉ result ⟹ x ∉ lhs ∧ x ∉ rhs
136    /// - x ∈ lhs ⟹ x ∈ result
137    /// - x ∈ rhs ⟹ x ∈ result
138    pub fn propagate(
139        &self,
140        lhs_members: &FxHashSet<u32>,
141        lhs_non_members: &FxHashSet<u32>,
142        rhs_members: &FxHashSet<u32>,
143        rhs_non_members: &FxHashSet<u32>,
144    ) -> (FxHashSet<u32>, FxHashSet<u32>) {
145        let mut result_must = FxHashSet::default();
146        let mut result_must_not = FxHashSet::default();
147
148        // x ∈ lhs ⟹ x ∈ result
149        for &elem in lhs_members {
150            result_must.insert(elem);
151        }
152
153        // x ∈ rhs ⟹ x ∈ result
154        for &elem in rhs_members {
155            result_must.insert(elem);
156        }
157
158        // x ∉ lhs ∧ x ∉ rhs ⟹ x ∉ result
159        for &elem in lhs_non_members {
160            if rhs_non_members.contains(&elem) {
161                result_must_not.insert(elem);
162            }
163        }
164
165        (result_must, result_must_not)
166    }
167
168    /// Backward propagation from result to operands
169    pub fn propagate_backward(
170        &self,
171        _result_members: &FxHashSet<u32>,
172        result_non_members: &FxHashSet<u32>,
173    ) -> (
174        FxHashSet<u32>,
175        FxHashSet<u32>,
176        FxHashSet<u32>,
177        FxHashSet<u32>,
178    ) {
179        let mut lhs_must_not = FxHashSet::default();
180        let mut rhs_must_not = FxHashSet::default();
181        let lhs_must = FxHashSet::default();
182        let rhs_must = FxHashSet::default();
183
184        // x ∉ result ⟹ x ∉ lhs ∧ x ∉ rhs
185        for &elem in result_non_members {
186            lhs_must_not.insert(elem);
187            rhs_must_not.insert(elem);
188        }
189
190        (lhs_must, lhs_must_not, rhs_must, rhs_must_not)
191    }
192
193    /// Compute cardinality bounds for union
194    ///
195    /// max(|lhs|, |rhs|) ≤ |result| ≤ |lhs| + |rhs|
196    pub fn cardinality_bounds(
197        &self,
198        lhs_card: (i64, Option<i64>),
199        rhs_card: (i64, Option<i64>),
200    ) -> (i64, Option<i64>) {
201        let lower = lhs_card.0.max(rhs_card.0);
202        let upper = match (lhs_card.1, rhs_card.1) {
203            (Some(l), Some(r)) => Some(l + r),
204            _ => None,
205        };
206        (lower, upper)
207    }
208}
209
210/// Set intersection implementation
211#[derive(Debug, Clone)]
212pub struct SetIntersection {
213    /// Left operand
214    pub lhs: SetVarId,
215    /// Right operand
216    pub rhs: SetVarId,
217    /// Result variable
218    pub result: SetVarId,
219}
220
221impl SetIntersection {
222    /// Create a new intersection operation
223    pub fn new(lhs: SetVarId, rhs: SetVarId, result: SetVarId) -> Self {
224        Self { lhs, rhs, result }
225    }
226
227    /// Propagate intersection constraints
228    ///
229    /// For result = lhs ∩ rhs:
230    /// - x ∈ result ⟺ x ∈ lhs ∧ x ∈ rhs
231    /// - x ∉ result ⟹ x ∉ lhs ∨ x ∉ rhs
232    /// - x ∈ result ⟹ x ∈ lhs ∧ x ∈ rhs
233    /// - x ∉ lhs ⟹ x ∉ result
234    /// - x ∉ rhs ⟹ x ∉ result
235    pub fn propagate(
236        &self,
237        lhs_members: &FxHashSet<u32>,
238        lhs_non_members: &FxHashSet<u32>,
239        rhs_members: &FxHashSet<u32>,
240        rhs_non_members: &FxHashSet<u32>,
241    ) -> (FxHashSet<u32>, FxHashSet<u32>) {
242        let mut result_must = FxHashSet::default();
243        let mut result_must_not = FxHashSet::default();
244
245        // x ∈ lhs ∧ x ∈ rhs ⟹ x ∈ result
246        for &elem in lhs_members {
247            if rhs_members.contains(&elem) {
248                result_must.insert(elem);
249            }
250        }
251
252        // x ∉ lhs ⟹ x ∉ result
253        for &elem in lhs_non_members {
254            result_must_not.insert(elem);
255        }
256
257        // x ∉ rhs ⟹ x ∉ result
258        for &elem in rhs_non_members {
259            result_must_not.insert(elem);
260        }
261
262        (result_must, result_must_not)
263    }
264
265    /// Backward propagation from result to operands
266    pub fn propagate_backward(
267        &self,
268        result_members: &FxHashSet<u32>,
269        _result_non_members: &FxHashSet<u32>,
270    ) -> (
271        FxHashSet<u32>,
272        FxHashSet<u32>,
273        FxHashSet<u32>,
274        FxHashSet<u32>,
275    ) {
276        let mut lhs_must = FxHashSet::default();
277        let mut rhs_must = FxHashSet::default();
278        let lhs_must_not = FxHashSet::default();
279        let rhs_must_not = FxHashSet::default();
280
281        // x ∈ result ⟹ x ∈ lhs ∧ x ∈ rhs
282        for &elem in result_members {
283            lhs_must.insert(elem);
284            rhs_must.insert(elem);
285        }
286
287        (lhs_must, lhs_must_not, rhs_must, rhs_must_not)
288    }
289
290    /// Compute cardinality bounds for intersection
291    ///
292    /// 0 ≤ |result| ≤ min(|lhs|, |rhs|)
293    pub fn cardinality_bounds(
294        &self,
295        lhs_card: (i64, Option<i64>),
296        rhs_card: (i64, Option<i64>),
297    ) -> (i64, Option<i64>) {
298        let lower = 0;
299        let upper = match (lhs_card.1, rhs_card.1) {
300            (Some(l), Some(r)) => Some(l.min(r)),
301            (Some(l), None) => Some(l),
302            (None, Some(r)) => Some(r),
303            _ => None,
304        };
305        (lower, upper)
306    }
307}
308
309/// Set difference implementation
310#[derive(Debug, Clone)]
311pub struct SetDifference {
312    /// Left operand
313    pub lhs: SetVarId,
314    /// Right operand
315    pub rhs: SetVarId,
316    /// Result variable
317    pub result: SetVarId,
318}
319
320impl SetDifference {
321    /// Create a new difference operation
322    pub fn new(lhs: SetVarId, rhs: SetVarId, result: SetVarId) -> Self {
323        Self { lhs, rhs, result }
324    }
325
326    /// Propagate difference constraints
327    ///
328    /// For result = lhs \ rhs:
329    /// - x ∈ result ⟺ x ∈ lhs ∧ x ∉ rhs
330    /// - x ∈ result ⟹ x ∈ lhs
331    /// - x ∈ result ⟹ x ∉ rhs
332    /// - x ∉ lhs ⟹ x ∉ result
333    /// - x ∈ rhs ⟹ x ∉ result
334    pub fn propagate(
335        &self,
336        lhs_members: &FxHashSet<u32>,
337        lhs_non_members: &FxHashSet<u32>,
338        rhs_members: &FxHashSet<u32>,
339        rhs_non_members: &FxHashSet<u32>,
340    ) -> (FxHashSet<u32>, FxHashSet<u32>) {
341        let mut result_must = FxHashSet::default();
342        let mut result_must_not = FxHashSet::default();
343
344        // x ∈ lhs ∧ x ∉ rhs ⟹ x ∈ result
345        for &elem in lhs_members {
346            if rhs_non_members.contains(&elem) {
347                result_must.insert(elem);
348            }
349        }
350
351        // x ∉ lhs ⟹ x ∉ result
352        for &elem in lhs_non_members {
353            result_must_not.insert(elem);
354        }
355
356        // x ∈ rhs ⟹ x ∉ result
357        for &elem in rhs_members {
358            result_must_not.insert(elem);
359        }
360
361        (result_must, result_must_not)
362    }
363
364    /// Backward propagation from result to operands
365    pub fn propagate_backward(
366        &self,
367        result_members: &FxHashSet<u32>,
368        _result_non_members: &FxHashSet<u32>,
369    ) -> (
370        FxHashSet<u32>,
371        FxHashSet<u32>,
372        FxHashSet<u32>,
373        FxHashSet<u32>,
374    ) {
375        let mut lhs_must = FxHashSet::default();
376        let mut rhs_must_not = FxHashSet::default();
377        let lhs_must_not = FxHashSet::default();
378        let rhs_must = FxHashSet::default();
379
380        // x ∈ result ⟹ x ∈ lhs
381        for &elem in result_members {
382            lhs_must.insert(elem);
383            rhs_must_not.insert(elem);
384        }
385
386        // x ∉ result ∧ x ∈ lhs ⟹ x ∈ rhs
387        // (This is weaker, we don't propagate it here)
388
389        (lhs_must, lhs_must_not, rhs_must, rhs_must_not)
390    }
391
392    /// Compute cardinality bounds for difference
393    ///
394    /// 0 ≤ |result| ≤ |lhs|
395    pub fn cardinality_bounds(
396        &self,
397        lhs_card: (i64, Option<i64>),
398        _rhs_card: (i64, Option<i64>),
399    ) -> (i64, Option<i64>) {
400        let lower = 0;
401        let upper = lhs_card.1;
402        (lower, upper)
403    }
404}
405
406/// Set complement implementation
407#[derive(Debug, Clone)]
408pub struct SetComplement {
409    /// Set to complement
410    pub set: SetVarId,
411    /// Result variable
412    pub result: SetVarId,
413    /// Universe (if known)
414    pub universe: Option<FxHashSet<u32>>,
415}
416
417impl SetComplement {
418    /// Create a new complement operation
419    pub fn new(set: SetVarId, result: SetVarId, universe: Option<FxHashSet<u32>>) -> Self {
420        Self {
421            set,
422            result,
423            universe,
424        }
425    }
426
427    /// Propagate complement constraints
428    ///
429    /// For result = ¬set:
430    /// - x ∈ result ⟺ x ∉ set
431    /// - x ∉ result ⟺ x ∈ set
432    pub fn propagate(
433        &self,
434        set_members: &FxHashSet<u32>,
435        set_non_members: &FxHashSet<u32>,
436    ) -> (FxHashSet<u32>, FxHashSet<u32>) {
437        let mut result_must = FxHashSet::default();
438        let mut result_must_not = FxHashSet::default();
439
440        // x ∉ set ⟹ x ∈ result
441        for &elem in set_non_members {
442            result_must.insert(elem);
443        }
444
445        // x ∈ set ⟹ x ∉ result
446        for &elem in set_members {
447            result_must_not.insert(elem);
448        }
449
450        // If universe is known, we can be more precise
451        if let Some(univ) = &self.universe {
452            for &elem in univ {
453                if !set_members.contains(&elem) && !set_non_members.contains(&elem) {
454                    // Element is unknown in set, so unknown in result
455                } else if set_members.contains(&elem) {
456                    result_must_not.insert(elem);
457                } else {
458                    result_must.insert(elem);
459                }
460            }
461        }
462
463        (result_must, result_must_not)
464    }
465
466    /// Backward propagation from result to set
467    pub fn propagate_backward(
468        &self,
469        result_members: &FxHashSet<u32>,
470        result_non_members: &FxHashSet<u32>,
471    ) -> (FxHashSet<u32>, FxHashSet<u32>) {
472        let mut set_must = FxHashSet::default();
473        let mut set_must_not = FxHashSet::default();
474
475        // x ∈ result ⟹ x ∉ set
476        for &elem in result_members {
477            set_must_not.insert(elem);
478        }
479
480        // x ∉ result ⟹ x ∈ set
481        for &elem in result_non_members {
482            set_must.insert(elem);
483        }
484
485        (set_must, set_must_not)
486    }
487
488    /// Compute cardinality bounds for complement
489    ///
490    /// If universe size is known: |result| = |universe| - |set|
491    pub fn cardinality_bounds(
492        &self,
493        set_card: (i64, Option<i64>),
494        universe_card: Option<i64>,
495    ) -> (i64, Option<i64>) {
496        if let Some(univ_size) = universe_card {
497            match set_card.1 {
498                Some(set_upper) => {
499                    let lower = (univ_size - set_upper).max(0);
500                    let upper = Some(univ_size - set_card.0);
501                    (lower, upper)
502                }
503                None => (0, Some(univ_size - set_card.0)),
504            }
505        } else {
506            // Universe is infinite or unknown
507            (0, None)
508        }
509    }
510}
511
512/// Set operation result
513pub type SetOpResult<T> = std::result::Result<T, SetConflict>;
514
/// Counters describing how many set operations of each kind were registered.
#[derive(Clone, Debug, Default)]
pub struct SetOpStats {
    /// How many unions have been registered.
    pub num_unions: usize,
    /// How many intersections have been registered.
    pub num_intersections: usize,
    /// How many differences have been registered.
    pub num_differences: usize,
    /// How many complements have been registered.
    pub num_complements: usize,
    /// Count of propagations.
    pub num_propagations: usize,
}
529
530/// Set operation manager
531#[derive(Debug)]
532pub struct SetOpManager {
533    /// Union operations
534    unions: Vec<SetUnion>,
535    /// Intersection operations
536    intersections: Vec<SetIntersection>,
537    /// Difference operations
538    differences: Vec<SetDifference>,
539    /// Complement operations
540    complements: Vec<SetComplement>,
541    /// Statistics
542    stats: SetOpStats,
543}
544
545impl SetOpManager {
546    /// Create a new operation manager
547    pub fn new() -> Self {
548        Self {
549            unions: Vec::new(),
550            intersections: Vec::new(),
551            differences: Vec::new(),
552            complements: Vec::new(),
553            stats: SetOpStats::default(),
554        }
555    }
556
557    /// Add a union operation
558    pub fn add_union(&mut self, lhs: SetVarId, rhs: SetVarId, result: SetVarId) {
559        self.unions.push(SetUnion::new(lhs, rhs, result));
560        self.stats.num_unions += 1;
561    }
562
563    /// Add an intersection operation
564    pub fn add_intersection(&mut self, lhs: SetVarId, rhs: SetVarId, result: SetVarId) {
565        self.intersections
566            .push(SetIntersection::new(lhs, rhs, result));
567        self.stats.num_intersections += 1;
568    }
569
570    /// Add a difference operation
571    pub fn add_difference(&mut self, lhs: SetVarId, rhs: SetVarId, result: SetVarId) {
572        self.differences.push(SetDifference::new(lhs, rhs, result));
573        self.stats.num_differences += 1;
574    }
575
576    /// Add a complement operation
577    pub fn add_complement(
578        &mut self,
579        set: SetVarId,
580        result: SetVarId,
581        universe: Option<FxHashSet<u32>>,
582    ) {
583        self.complements
584            .push(SetComplement::new(set, result, universe));
585        self.stats.num_complements += 1;
586    }
587
588    /// Get statistics
589    pub fn stats(&self) -> &SetOpStats {
590        &self.stats
591    }
592
593    /// Reset the manager
594    pub fn reset(&mut self) {
595        self.unions.clear();
596        self.intersections.clear();
597        self.differences.clear();
598        self.complements.clear();
599        self.stats = SetOpStats::default();
600    }
601}
602
603impl Default for SetOpManager {
604    fn default() -> Self {
605        Self::new()
606    }
607}
608
609/// Symbolic set expression evaluator
610#[derive(Debug)]
611pub struct SetExprEvaluator {
612    /// Cache of evaluated subexpressions
613    cache: FxHashMap<SetVarId, (FxHashSet<u32>, FxHashSet<u32>)>,
614}
615
616impl SetExprEvaluator {
617    /// Create a new evaluator
618    pub fn new() -> Self {
619        Self {
620            cache: FxHashMap::default(),
621        }
622    }
623
624    /// Evaluate a union expression
625    pub fn eval_union(
626        &mut self,
627        lhs: SetVarId,
628        rhs: SetVarId,
629        lhs_val: (FxHashSet<u32>, FxHashSet<u32>),
630        rhs_val: (FxHashSet<u32>, FxHashSet<u32>),
631    ) -> (FxHashSet<u32>, FxHashSet<u32>) {
632        let union_op = SetUnion::new(lhs, rhs, SetVarId(0));
633        union_op.propagate(&lhs_val.0, &lhs_val.1, &rhs_val.0, &rhs_val.1)
634    }
635
636    /// Evaluate an intersection expression
637    pub fn eval_intersection(
638        &mut self,
639        lhs: SetVarId,
640        rhs: SetVarId,
641        lhs_val: (FxHashSet<u32>, FxHashSet<u32>),
642        rhs_val: (FxHashSet<u32>, FxHashSet<u32>),
643    ) -> (FxHashSet<u32>, FxHashSet<u32>) {
644        let inter_op = SetIntersection::new(lhs, rhs, SetVarId(0));
645        inter_op.propagate(&lhs_val.0, &lhs_val.1, &rhs_val.0, &rhs_val.1)
646    }
647
648    /// Evaluate a difference expression
649    pub fn eval_difference(
650        &mut self,
651        lhs: SetVarId,
652        rhs: SetVarId,
653        lhs_val: (FxHashSet<u32>, FxHashSet<u32>),
654        rhs_val: (FxHashSet<u32>, FxHashSet<u32>),
655    ) -> (FxHashSet<u32>, FxHashSet<u32>) {
656        let diff_op = SetDifference::new(lhs, rhs, SetVarId(0));
657        diff_op.propagate(&lhs_val.0, &lhs_val.1, &rhs_val.0, &rhs_val.1)
658    }
659
660    /// Evaluate a complement expression
661    pub fn eval_complement(
662        &mut self,
663        set: SetVarId,
664        set_val: (FxHashSet<u32>, FxHashSet<u32>),
665        universe: Option<FxHashSet<u32>>,
666    ) -> (FxHashSet<u32>, FxHashSet<u32>) {
667        let comp_op = SetComplement::new(set, SetVarId(0), universe);
668        comp_op.propagate(&set_val.0, &set_val.1)
669    }
670
671    /// Clear the cache
672    pub fn clear_cache(&mut self) {
673        self.cache.clear();
674    }
675}
676
677impl Default for SetExprEvaluator {
678    fn default() -> Self {
679        Self::new()
680    }
681}
682
683/// N-ary set operations for optimization
684#[derive(Debug)]
685pub struct NarySetOp {
686    /// Operation kind
687    pub op: SetBinOp,
688    /// Operands
689    pub operands: SmallVec<[SetVarId; 8]>,
690    /// Result
691    pub result: SetVarId,
692}
693
694impl NarySetOp {
695    /// Create a new n-ary operation
696    pub fn new(op: SetBinOp, operands: SmallVec<[SetVarId; 8]>, result: SetVarId) -> Self {
697        Self {
698            op,
699            operands,
700            result,
701        }
702    }
703
704    /// Flatten nested operations of the same kind
705    #[allow(dead_code)]
706    pub fn flatten(_ops: Vec<SetOp>) -> Vec<NarySetOp> {
707        // TODO: Implement flattening logic
708        Vec::new()
709    }
710
711    /// Propagate n-ary union
712    pub fn propagate_nary_union(
713        &self,
714        operand_vals: &[(FxHashSet<u32>, FxHashSet<u32>)],
715    ) -> (FxHashSet<u32>, FxHashSet<u32>) {
716        let mut result_must = FxHashSet::default();
717        let mut result_must_not = FxHashSet::default();
718
719        // Collect all must-members from all operands
720        for (must, _) in operand_vals {
721            for &elem in must {
722                result_must.insert(elem);
723            }
724        }
725
726        // Element is must-not if it's must-not in all operands
727        if !operand_vals.is_empty() {
728            let first_must_not = &operand_vals[0].1;
729            for &elem in first_must_not {
730                if operand_vals
731                    .iter()
732                    .all(|(_, must_not)| must_not.contains(&elem))
733                {
734                    result_must_not.insert(elem);
735                }
736            }
737        }
738
739        (result_must, result_must_not)
740    }
741
742    /// Propagate n-ary intersection
743    pub fn propagate_nary_intersection(
744        &self,
745        operand_vals: &[(FxHashSet<u32>, FxHashSet<u32>)],
746    ) -> (FxHashSet<u32>, FxHashSet<u32>) {
747        let mut result_must = FxHashSet::default();
748        let mut result_must_not = FxHashSet::default();
749
750        // Element is must if it's must in all operands
751        if !operand_vals.is_empty() {
752            let first_must = &operand_vals[0].0;
753            for &elem in first_must {
754                if operand_vals.iter().all(|(must, _)| must.contains(&elem)) {
755                    result_must.insert(elem);
756                }
757            }
758        }
759
760        // Collect all must-not-members from any operand
761        for (_, must_not) in operand_vals {
762            for &elem in must_not {
763                result_must_not.insert(elem);
764            }
765        }
766
767        (result_must, result_must_not)
768    }
769}
770
771/// Cartesian product operation
772#[derive(Debug, Clone)]
773pub struct CartesianProduct {
774    /// First set
775    #[allow(dead_code)]
776    pub lhs: SetVarId,
777    /// Second set
778    #[allow(dead_code)]
779    pub rhs: SetVarId,
780    /// Result set (of pairs)
781    #[allow(dead_code)]
782    pub result: SetVarId,
783}
784
785impl CartesianProduct {
786    /// Create a new cartesian product
787    pub fn new(lhs: SetVarId, rhs: SetVarId, result: SetVarId) -> Self {
788        Self { lhs, rhs, result }
789    }
790
791    /// Compute cardinality of cartesian product
792    ///
793    /// |lhs × rhs| = |lhs| * |rhs|
794    pub fn cardinality_bounds(
795        &self,
796        lhs_card: (i64, Option<i64>),
797        rhs_card: (i64, Option<i64>),
798    ) -> (i64, Option<i64>) {
799        let lower = lhs_card.0 * rhs_card.0;
800        let upper = match (lhs_card.1, rhs_card.1) {
801            (Some(l), Some(r)) => Some(l * r),
802            _ => None,
803        };
804        (lower, upper)
805    }
806
807    /// Check if a pair is in the cartesian product
808    #[allow(dead_code)]
809    pub fn contains_pair(
810        &self,
811        pair: (u32, u32),
812        lhs_members: &FxHashSet<u32>,
813        rhs_members: &FxHashSet<u32>,
814    ) -> Option<bool> {
815        let lhs_contains = lhs_members.contains(&pair.0);
816        let rhs_contains = rhs_members.contains(&pair.1);
817
818        if lhs_contains && rhs_contains {
819            Some(true)
820        } else if !lhs_contains || !rhs_contains {
821            Some(false)
822        } else {
823            None
824        }
825    }
826}
827
828/// Symmetric difference operation
829#[derive(Debug, Clone)]
830pub struct SymmetricDifference {
831    /// Left operand
832    #[allow(dead_code)]
833    pub lhs: SetVarId,
834    /// Right operand
835    #[allow(dead_code)]
836    pub rhs: SetVarId,
837    /// Result variable
838    #[allow(dead_code)]
839    pub result: SetVarId,
840}
841
842impl SymmetricDifference {
843    /// Create a new symmetric difference operation
844    pub fn new(lhs: SetVarId, rhs: SetVarId, result: SetVarId) -> Self {
845        Self { lhs, rhs, result }
846    }
847
848    /// Propagate symmetric difference constraints
849    ///
850    /// For result = lhs △ rhs = (lhs \ rhs) ∪ (rhs \ lhs):
851    /// - x ∈ result ⟺ (x ∈ lhs ∧ x ∉ rhs) ∨ (x ∈ rhs ∧ x ∉ lhs)
852    /// - x ∈ result ⟺ x ∈ lhs ⊕ x ∈ rhs
853    pub fn propagate(
854        &self,
855        lhs_members: &FxHashSet<u32>,
856        lhs_non_members: &FxHashSet<u32>,
857        rhs_members: &FxHashSet<u32>,
858        rhs_non_members: &FxHashSet<u32>,
859    ) -> (FxHashSet<u32>, FxHashSet<u32>) {
860        let mut result_must = FxHashSet::default();
861        let mut result_must_not = FxHashSet::default();
862
863        // x ∈ lhs ∧ x ∉ rhs ⟹ x ∈ result
864        for &elem in lhs_members {
865            if rhs_non_members.contains(&elem) {
866                result_must.insert(elem);
867            }
868        }
869
870        // x ∈ rhs ∧ x ∉ lhs ⟹ x ∈ result
871        for &elem in rhs_members {
872            if lhs_non_members.contains(&elem) {
873                result_must.insert(elem);
874            }
875        }
876
877        // x ∈ lhs ∧ x ∈ rhs ⟹ x ∉ result
878        for &elem in lhs_members {
879            if rhs_members.contains(&elem) {
880                result_must_not.insert(elem);
881            }
882        }
883
884        // x ∉ lhs ∧ x ∉ rhs ⟹ x ∉ result
885        for &elem in lhs_non_members {
886            if rhs_non_members.contains(&elem) {
887                result_must_not.insert(elem);
888            }
889        }
890
891        (result_must, result_must_not)
892    }
893
894    /// Compute cardinality bounds for symmetric difference
895    ///
896    /// |result| = |lhs| + |rhs| - 2|lhs ∩ rhs|
897    #[allow(dead_code)]
898    pub fn cardinality_bounds(
899        &self,
900        lhs_card: (i64, Option<i64>),
901        rhs_card: (i64, Option<i64>),
902        intersection_card: (i64, Option<i64>),
903    ) -> (i64, Option<i64>) {
904        let lower = (lhs_card.0 + rhs_card.0 - 2 * intersection_card.1.unwrap_or(0)).max(0);
905        let upper = match (lhs_card.1, rhs_card.1) {
906            (Some(l), Some(r)) => Some(l + r - 2 * intersection_card.0),
907            _ => None,
908        };
909        (lower, upper)
910    }
911}
912
913#[cfg(test)]
914mod tests {
915    use super::*;
916
917    #[test]
918    fn test_union_propagation() {
919        let lhs = SetVarId(0);
920        let rhs = SetVarId(1);
921        let result = SetVarId(2);
922        let union = SetUnion::new(lhs, rhs, result);
923
924        let mut lhs_members = FxHashSet::default();
925        lhs_members.insert(1);
926        lhs_members.insert(2);
927
928        let mut rhs_members = FxHashSet::default();
929        rhs_members.insert(3);
930
931        let lhs_non = FxHashSet::default();
932        let rhs_non = FxHashSet::default();
933
934        let (result_must, _) = union.propagate(&lhs_members, &lhs_non, &rhs_members, &rhs_non);
935
936        assert!(result_must.contains(&1));
937        assert!(result_must.contains(&2));
938        assert!(result_must.contains(&3));
939        assert_eq!(result_must.len(), 3);
940    }
941
942    #[test]
943    fn test_intersection_propagation() {
944        let lhs = SetVarId(0);
945        let rhs = SetVarId(1);
946        let result = SetVarId(2);
947        let intersection = SetIntersection::new(lhs, rhs, result);
948
949        let mut lhs_members = FxHashSet::default();
950        lhs_members.insert(1);
951        lhs_members.insert(2);
952        lhs_members.insert(3);
953
954        let mut rhs_members = FxHashSet::default();
955        rhs_members.insert(2);
956        rhs_members.insert(3);
957        rhs_members.insert(4);
958
959        let lhs_non = FxHashSet::default();
960        let rhs_non = FxHashSet::default();
961
962        let (result_must, _) =
963            intersection.propagate(&lhs_members, &lhs_non, &rhs_members, &rhs_non);
964
965        assert!(!result_must.contains(&1));
966        assert!(result_must.contains(&2));
967        assert!(result_must.contains(&3));
968        assert!(!result_must.contains(&4));
969        assert_eq!(result_must.len(), 2);
970    }
971
972    #[test]
973    fn test_difference_propagation() {
974        let lhs = SetVarId(0);
975        let rhs = SetVarId(1);
976        let result = SetVarId(2);
977        let difference = SetDifference::new(lhs, rhs, result);
978
979        let mut lhs_members = FxHashSet::default();
980        lhs_members.insert(1);
981        lhs_members.insert(2);
982        lhs_members.insert(3);
983
984        let mut rhs_members = FxHashSet::default();
985        rhs_members.insert(2);
986        rhs_members.insert(4);
987
988        let mut rhs_non = FxHashSet::default();
989        rhs_non.insert(1);
990        rhs_non.insert(3);
991
992        let lhs_non = FxHashSet::default();
993
994        let (result_must, result_must_not) =
995            difference.propagate(&lhs_members, &lhs_non, &rhs_members, &rhs_non);
996
997        // 1 ∈ lhs and 1 ∉ rhs, so 1 ∈ result
998        assert!(result_must.contains(&1));
999        // 3 ∈ lhs and 3 ∉ rhs, so 3 ∈ result
1000        assert!(result_must.contains(&3));
1001        // 2 ∈ rhs, so 2 ∉ result
1002        assert!(result_must_not.contains(&2));
1003    }
1004
1005    #[test]
1006    fn test_complement_propagation() {
1007        let set = SetVarId(0);
1008        let result = SetVarId(1);
1009
1010        let mut universe = FxHashSet::default();
1011        for i in 1..=5 {
1012            universe.insert(i);
1013        }
1014
1015        let complement = SetComplement::new(set, result, Some(universe));
1016
1017        let mut set_members = FxHashSet::default();
1018        set_members.insert(1);
1019        set_members.insert(2);
1020
1021        let mut set_non = FxHashSet::default();
1022        set_non.insert(4);
1023        set_non.insert(5);
1024
1025        let (result_must, result_must_not) = complement.propagate(&set_members, &set_non);
1026
1027        // 4, 5 ∉ set, so they ∈ result
1028        assert!(result_must.contains(&4));
1029        assert!(result_must.contains(&5));
1030        // 1, 2 ∈ set, so they ∉ result
1031        assert!(result_must_not.contains(&1));
1032        assert!(result_must_not.contains(&2));
1033    }
1034
1035    #[test]
1036    fn test_symmetric_difference() {
1037        let lhs = SetVarId(0);
1038        let rhs = SetVarId(1);
1039        let result = SetVarId(2);
1040        let symdiff = SymmetricDifference::new(lhs, rhs, result);
1041
1042        let mut lhs_members = FxHashSet::default();
1043        lhs_members.insert(1);
1044        lhs_members.insert(2);
1045        lhs_members.insert(3);
1046
1047        let mut rhs_members = FxHashSet::default();
1048        rhs_members.insert(2);
1049        rhs_members.insert(3);
1050        rhs_members.insert(4);
1051
1052        let mut lhs_non = FxHashSet::default();
1053        lhs_non.insert(4);
1054        lhs_non.insert(5);
1055
1056        let mut rhs_non = FxHashSet::default();
1057        rhs_non.insert(1);
1058        rhs_non.insert(5);
1059
1060        let (result_must, result_must_not) =
1061            symdiff.propagate(&lhs_members, &lhs_non, &rhs_members, &rhs_non);
1062
1063        // 1 ∈ lhs, 1 ∉ rhs => 1 ∈ result
1064        assert!(result_must.contains(&1));
1065        // 4 ∈ rhs, 4 ∉ lhs => 4 ∈ result
1066        assert!(result_must.contains(&4));
1067        // 2 ∈ lhs, 2 ∈ rhs => 2 ∉ result
1068        assert!(result_must_not.contains(&2));
1069        // 3 ∈ lhs, 3 ∈ rhs => 3 ∉ result
1070        assert!(result_must_not.contains(&3));
1071        // 5 ∉ lhs, 5 ∉ rhs => 5 ∉ result
1072        assert!(result_must_not.contains(&5));
1073    }
1074
1075    #[test]
1076    fn test_union_cardinality_bounds() {
1077        let union = SetUnion::new(SetVarId(0), SetVarId(1), SetVarId(2));
1078
1079        let lhs_card = (2, Some(5));
1080        let rhs_card = (3, Some(4));
1081
1082        let (lower, upper) = union.cardinality_bounds(lhs_card, rhs_card);
1083
1084        assert_eq!(lower, 3); // max(2, 3)
1085        assert_eq!(upper, Some(9)); // 5 + 4
1086    }
1087
1088    #[test]
1089    fn test_intersection_cardinality_bounds() {
1090        let intersection = SetIntersection::new(SetVarId(0), SetVarId(1), SetVarId(2));
1091
1092        let lhs_card = (2, Some(5));
1093        let rhs_card = (3, Some(4));
1094
1095        let (lower, upper) = intersection.cardinality_bounds(lhs_card, rhs_card);
1096
1097        assert_eq!(lower, 0);
1098        assert_eq!(upper, Some(4)); // min(5, 4)
1099    }
1100
1101    #[test]
1102    fn test_cartesian_product_cardinality() {
1103        let product = CartesianProduct::new(SetVarId(0), SetVarId(1), SetVarId(2));
1104
1105        let lhs_card = (2, Some(3));
1106        let rhs_card = (4, Some(5));
1107
1108        let (lower, upper) = product.cardinality_bounds(lhs_card, rhs_card);
1109
1110        assert_eq!(lower, 8); // 2 * 4
1111        assert_eq!(upper, Some(15)); // 3 * 5
1112    }
1113
1114    #[test]
1115    fn test_set_op_builder() {
1116        let builder = SetOpBuilder::new()
1117            .union(SetVarId(0), SetVarId(1), SetVarId(2))
1118            .intersection(SetVarId(2), SetVarId(3), SetVarId(4));
1119
1120        let ops = builder.build();
1121        assert_eq!(ops.len(), 2);
1122    }
1123
1124    #[test]
1125    fn test_nary_union_propagation() {
1126        let op = NarySetOp::new(
1127            SetBinOp::Union,
1128            SmallVec::from_vec(vec![SetVarId(0), SetVarId(1), SetVarId(2)]),
1129            SetVarId(3),
1130        );
1131
1132        let mut operand_vals = Vec::new();
1133
1134        let mut set1 = (FxHashSet::default(), FxHashSet::default());
1135        set1.0.insert(1);
1136        operand_vals.push(set1);
1137
1138        let mut set2 = (FxHashSet::default(), FxHashSet::default());
1139        set2.0.insert(2);
1140        operand_vals.push(set2);
1141
1142        let mut set3 = (FxHashSet::default(), FxHashSet::default());
1143        set3.0.insert(3);
1144        operand_vals.push(set3);
1145
1146        let (result_must, _) = op.propagate_nary_union(&operand_vals);
1147
1148        assert!(result_must.contains(&1));
1149        assert!(result_must.contains(&2));
1150        assert!(result_must.contains(&3));
1151        assert_eq!(result_must.len(), 3);
1152    }
1153
1154    #[test]
1155    fn test_nary_intersection_propagation() {
1156        let op = NarySetOp::new(
1157            SetBinOp::Intersection,
1158            SmallVec::from_vec(vec![SetVarId(0), SetVarId(1), SetVarId(2)]),
1159            SetVarId(3),
1160        );
1161
1162        let mut operand_vals = Vec::new();
1163
1164        let mut set1 = (FxHashSet::default(), FxHashSet::default());
1165        set1.0.insert(1);
1166        set1.0.insert(2);
1167        operand_vals.push(set1);
1168
1169        let mut set2 = (FxHashSet::default(), FxHashSet::default());
1170        set2.0.insert(2);
1171        set2.0.insert(3);
1172        operand_vals.push(set2);
1173
1174        let mut set3 = (FxHashSet::default(), FxHashSet::default());
1175        set3.0.insert(2);
1176        set3.0.insert(4);
1177        operand_vals.push(set3);
1178
1179        let (result_must, _) = op.propagate_nary_intersection(&operand_vals);
1180
1181        assert!(!result_must.contains(&1));
1182        assert!(result_must.contains(&2)); // 2 is in all three sets
1183        assert!(!result_must.contains(&3));
1184        assert!(!result_must.contains(&4));
1185        assert_eq!(result_must.len(), 1);
1186    }
1187}