1use std::collections::HashSet;
2use std::collections::{HashMap, VecDeque, hash_map::DefaultHasher};
3use std::hash::{Hash, Hasher};
4use bimap::BiMap;
5use std::ops::RangeInclusive;
6use itertools::Itertools;
7use indexmap::{IndexMap, IndexSet};
8
/// Inclusive integer interval `(lower, upper)`.
pub type Bound = (i64, i64);
/// Identifier of a node in a `Pldag` model.
pub type ID = String;
11
/// Hashes a list of `(name, coefficient)` pairs together with `num` into a `u64`.
///
/// Pairs are fed to the hasher one after another, so the result depends on their order:
/// the same pairs in a different order generally hash differently.
///
/// NOTE(review): `DefaultHasher`'s algorithm is unspecified and may change between Rust
/// releases, so ids derived from this value should not be persisted across toolchain
/// upgrades.
fn create_hash(data: &[(String, i64)], num: i64) -> u64 {
    let mut hasher = DefaultHasher::new();

    for (name, coefficient) in data {
        name.hash(&mut hasher);
        coefficient.hash(&mut hasher);
    }

    num.hash(&mut hasher);

    hasher.finish()
}
29
30fn bound_fixed(b: Bound) -> bool {
31    b.0 == b.1
33}
34
35fn bound_bool(b: Bound) -> bool {
36    b.0 == 0 && b.1 == 1
38}
39
40fn bound_add(b1: Bound, b2: Bound) -> Bound {
41    return (b1.0 + b2.0, b1.1 + b2.1);
42}
43
44fn bound_multiply(k: i64, b: Bound) -> Bound {
45    if k < 0 {
46        return (k*b.1, k*b.0);
47    } else {
48        return (k*b.0, k*b.1);
49    }
50}
51
52fn bound_span(b: Bound) -> i64 {
53    return (b.1 - b.0).abs();
55}
56
/// Integer matrix in coordinate (triplet) form: entry `k` is `vals[k]` at
/// `(rows[k], cols[k])`.
///
/// `Default` is the empty `0 × 0` matrix with no stored entries.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct SparseIntegerMatrix {
    /// Row index of each stored entry.
    pub rows: Vec<usize>,
    /// Column index of each stored entry.
    pub cols: Vec<usize>,
    /// Value of each stored entry.
    pub vals: Vec<i64>,
    /// Matrix dimensions as `(rows, columns)`.
    pub shape: (usize, usize),
}
63
/// Row-major dense integer matrix.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DenseIntegerMatrix {
    /// One inner vector per row.
    pub data: Vec<Vec<i64>>,
    /// Matrix dimensions as `(rows, columns)`.
    pub shape: (usize, usize),
}
68
69impl DenseIntegerMatrix {
70    pub fn new(rows: usize, cols: usize) -> DenseIntegerMatrix {
71        DenseIntegerMatrix {
72            data: vec![vec![0; cols]; rows],
73            shape: (rows, cols),
74        }
75    }
76
77    pub fn dot_product(&self, vector: &Vec<i64>) -> Vec<i64> {
78        let mut result = vec![0; self.shape.0];
79        for i in 0..self.shape.0 {
80            for j in 0..self.shape.1 {
81                result[i] += self.data[i][j] * vector[j];
82            }
83        }
84        result
85    }
86}
87
88pub struct DensePolyhedron {
89    pub A: DenseIntegerMatrix,
90    pub b: Vec<i64>,
91    pub columns: Vec<String>,
92    pub integer_columns: Vec<String>,
93}
94
95impl DensePolyhedron {
96    pub fn to_vector(&self, from_interpretation: &HashMap<String, i64>) -> Vec<i64> {
97        let mut vector: Vec<i64> = vec![0; self.columns.len()];
98        for (index, v) in from_interpretation
99            .iter()
100            .filter_map(|(k, v)| self.columns.iter().position(|col| col == k).map(|index| (index, v)))
101        {
102            vector[index] = *v;
103        }
104        vector
105    }
106
107    pub fn assume(&self, values: &HashMap<String, i64>) -> DensePolyhedron {
108        let mut new_A_data   = self.A.data.clone();     let mut new_b        = self.b.clone();          let mut new_columns  = self.columns.clone();    let mut new_int_cols = self.integer_columns.clone();  let mut to_remove: Vec<(usize, String, i64)> = values
117            .iter()
118            .filter_map(|(name, &val)| {
119                self.columns
121                    .iter()
122                    .position(|col| col == name)
123                    .map(|idx| (idx, name.clone(), val))
124            })
125            .collect();
126
127        to_remove.sort_by(|a, b| b.0.cmp(&a.0));
130
131        for (col_idx, col_name, fixed_val) in to_remove {
136            for row in 0..new_A_data.len() {
137                new_b[row] -= new_A_data[row][col_idx] * fixed_val;
139                new_A_data[row].remove(col_idx);
141            }
142            new_columns.remove(col_idx);
144            new_int_cols.retain(|c| c != &col_name);
146        }
147
148        let new_shape = (new_A_data.len(), new_columns.len());
150        let new_A = DenseIntegerMatrix {
151            data: new_A_data,
152            shape: new_shape,
153        };
154
155        DensePolyhedron {
157            A: new_A,
158            b: new_b,
159            columns: new_columns,
160            integer_columns: new_int_cols,
161        }
162    }
163
164    pub fn evaluate(&self, interpretation: &IndexMap<String, Bound>) -> Bound {
165        let mut lower_bounds = HashMap::new();
166        let mut upper_bounds = HashMap::new();
167        for (key, bound) in interpretation {
168            lower_bounds.insert(key.clone(), bound.0);
169            upper_bounds.insert(key.clone(), bound.1);
170        }
171
172        let lower_result = self.A.dot_product(&self.to_vector(&lower_bounds))
173            .iter()
174            .zip(&self.b)
175            .all(|(a, b)| a >= b);
176
177        let upper_result = self.A.dot_product(&self.to_vector(&upper_bounds))
178            .iter()
179            .zip(&self.b)
180            .all(|(a, b)| a >= b);
181
182        (lower_result as i64, upper_result as i64)
183    }
184}
185
186impl From<SparseIntegerMatrix> for DenseIntegerMatrix {
187    fn from(sparse: SparseIntegerMatrix) -> DenseIntegerMatrix {
188        let mut dense = DenseIntegerMatrix::new(sparse.shape.0, sparse.shape.1);
189        for ((&row, &col), &val) in sparse.rows.iter().zip(&sparse.cols).zip(&sparse.vals) {
190            dense.data[row][col] = val;
191        }
192        dense
193    }
194}
195
196impl From<DenseIntegerMatrix> for SparseIntegerMatrix {
197    fn from(dense: DenseIntegerMatrix) -> SparseIntegerMatrix {
198        let mut rows = Vec::new();
199        let mut cols = Vec::new();
200        let mut vals = Vec::new();
201        for (i, row) in dense.data.iter().enumerate() {
202            for (j, &val) in row.iter().enumerate() {
203                if val != 0 {
204                    rows.push(i);
205                    cols.push(j);
206                    vals.push(val);
207                }
208            }
209        }
210        SparseIntegerMatrix {
211            rows,
212            cols,
213            vals,
214            shape: dense.shape,
215        }
216    }
217}
218
219impl SparseIntegerMatrix {
220    pub fn new() -> SparseIntegerMatrix {
221        SparseIntegerMatrix {
222            rows: Vec::new(),
223            cols: Vec::new(),
224            vals: Vec::new(),
225            shape: (0, 0),
226        }
227    }
228}
229
230pub struct SparsePolyhedron {
231    pub A: SparseIntegerMatrix,
233    pub b: Vec<i64>,
234    pub columns: Vec<String>,
235    pub integer_columns: Vec<String>,
236}
237
238impl From<SparsePolyhedron> for DensePolyhedron {
239    fn from(sparse: SparsePolyhedron) -> DensePolyhedron {
240        let mut dense_matrix = DenseIntegerMatrix::new(sparse.A.shape.0, sparse.A.shape.1);
241        for ((&row, &col), &val) in sparse.A.rows.iter().zip(&sparse.A.cols).zip(&sparse.A.vals) {
242            dense_matrix.data[row][col] = val;
243        }
244        DensePolyhedron {
245            A: dense_matrix,
246            b: sparse.b,
247            columns: sparse.columns,
248            integer_columns: sparse.integer_columns,
249        }
250    }
251}
252
253impl From<DensePolyhedron> for SparsePolyhedron {
254    fn from(dense: DensePolyhedron) -> SparsePolyhedron {
255        let mut rows = Vec::new();
256        let mut cols = Vec::new();
257        let mut vals = Vec::new();
258        for (i, row) in dense.A.data.iter().enumerate() {
259            for (j, &val) in row.iter().enumerate() {
260                if val != 0 {
261                    rows.push(i);
262                    cols.push(j);
263                    vals.push(val);
264                }
265            }
266        }
267        SparsePolyhedron {
268            A: SparseIntegerMatrix {
269                rows,
270                cols,
271                vals,
272                shape: dense.A.shape,
273            },
274            b: dense.b,
275            columns: dense.columns,
276            integer_columns: dense.integer_columns,
277        }
278    }
279}
280
281pub type Coefficient = (String, i64);
282pub struct Constraint {
283    pub coefficients: Vec<Coefficient>,
284    pub bias: Bound,
285}
286
287impl Constraint {
288
289    pub fn dot(&self, values: &IndexMap<String, Bound>) -> Bound {
290        self.coefficients.iter().fold((0, 0), |acc, (key, coeff)| {
291            let bound = values.get(key).unwrap_or(&(0, 0));
292            let (min, max) = bound_multiply(*coeff, *bound);
293            (acc.0 + min, acc.1 + max)
294        })
295    }
296
297    pub fn evaluate(&self, values: &IndexMap<String, Bound>) -> Bound {
298        let bound = self.dot(values);
299        return (
300            (bound.0 + self.bias.0 >= 0) as i64,
301            (bound.1 + self.bias.1 >= 0) as i64
302        )
303    }
304
305    pub fn negate(&self) -> Constraint {
306        Constraint {
307            coefficients: self.coefficients.iter().map(|(key, val)| {
308                (key.clone(), -val)
309            }).collect(),
310            bias: (
311                -self.bias.0-1,
312                 -self.bias.1-1
313            ),
314        }
315    }
316}
/// A node of a `Pldag` model.
pub enum BoolExpression {
    /// A linear constraint over other nodes, evaluated as `Σ coefficient · child + bias >= 0`.
    Composite(Constraint),
    /// A leaf variable with an inclusive integer bound `(lower, upper)`.
    Primitive(Bound),
}
321
/// A model of bounded primitive variables and composite linear constraints over other nodes.
pub struct Pldag {
    /// Every node of the model keyed by its id, kept in insertion order (`IndexMap`).
    pub _amat: IndexMap<String, BoolExpression>,
    /// Bidirectional map with node ids on the left and aliases on the right.
    pub _amap: BiMap<String, String>,
}
328
329impl Pldag {
330
331    fn new() -> Pldag {
332        Pldag {
333            _amat: IndexMap::new(),
334            _amap: BiMap::new(),
335        }
336    }
337
338    pub fn transitive_dependencies(&self) -> HashMap<ID, HashSet<ID>> {
340        let mut memo: HashMap<String, HashSet<String>> = HashMap::new();
342        let mut result: HashMap<String, HashSet<String>> = HashMap::new();
343
344        for key in self._amat.keys() {
345            let deps = self._collect_deps(key, &mut memo);
347            result.insert(key.clone(), deps);
348        }
349
350        result
351    }
352
353    fn _collect_deps(&self, node: &ID, memo: &mut HashMap<ID, HashSet<ID>>) -> HashSet<ID> {
355        if let Some(cached) = memo.get(node) {
357            return cached.clone();
358        }
359
360        let mut deps = HashSet::new();
361
362        if let Some(expr) = self._amat.get(node) {
363            if let BoolExpression::Composite(constraint) = expr {
364                for (child_id, _) in &constraint.coefficients {
365                    deps.insert(child_id.clone());
367                    let sub = self._collect_deps(child_id, memo);
369                    deps.extend(sub);
370                }
371            }
372            }
374
375        memo.insert(node.clone(), deps.clone());
377        deps
378    }
379
380    fn primitive_combinations(&self) -> impl Iterator<Item = HashMap<ID, i64>> {
381        let primitives: Vec<(String, (i64, i64))> = self._amat
383            .iter()
384            .filter_map(|(key, expr)| {
385                if let BoolExpression::Primitive(bound) = expr {
386                    Some((key.clone(), *bound))
387                } else {
388                    None
389                }
390            })
391            .collect();
392
393        let keys: Vec<String> = primitives.iter().map(|(k, _)| k.clone()).collect();
395
396        let ranges: Vec<RangeInclusive<i64>> = primitives
398            .iter()
399            .map(|(_, (low, high))| *low..=*high)
400            .collect();
401
402        ranges
404            .into_iter()
405            .map(|r| r.collect::<Vec<_>>())
406            .multi_cartesian_product()
407            .map(move |values| {
409                keys.iter()
410                    .cloned()
411                    .zip(values.into_iter())
412                    .collect::<HashMap<String, i64>>()
413            })
414    }
415
416    pub fn get_id(&self, from_alias: &String) -> ID {
417        if let Some(id) = self._amap.get_by_left(from_alias) {
419            return id.clone();
420        }
421        return from_alias.clone();
423    }
424
425    pub fn get_alias(&self, from_id: &ID) -> String {
426        if let Some(alias) = self._amap.get_by_right(from_id) {
428            return alias.clone();
429        }
430        return from_id.clone();
432    }
433
434    pub fn check_combination(&self, interpretation: &IndexMap<ID, Bound>) -> IndexMap<ID, Bound> {
435
436        let mut result= interpretation.clone();
437
438        for (key, value) in self._amat.iter() {
440            if !result.contains_key(key) {
441                if let BoolExpression::Primitive(bound) = value {
442                    result.insert(key.clone(), *bound);
443                }
444            }
445        }
446
447        let mut S: VecDeque<String> = self._amat
451            .iter()
452            .filter(|(key, constraint)| {
453                match constraint {
454                    BoolExpression::Composite(composite) => {
455                    composite.coefficients.iter().all(|x| {
456                        match self._amat.get(&x.0) {
457                        Some(BoolExpression::Primitive(_)) => true,
458                        _ => false
459                        }
460                    }) && !result.contains_key(&key.to_string())
461                    },
462                    BoolExpression::Primitive(_) => false,
463                }
464            })
465            .map(|(key, _)| key.clone())
466            .collect();
467
468        let mut visited = HashSet::new();
472
473        while let Some(s) = S.pop_front() {
475
476            if visited.contains(&s) {
478                panic!("Cycle detected in the graph");
479            }
480
481            visited.insert(s.clone());
483            
484            match self._amat.get(&s) {
487                Some(BoolExpression::Composite(composite)) => {
488                    result.insert(
489                        s.clone(), 
490                        composite.evaluate(&result)
491                    );
492        
493                    let incoming = self._amat
496                        .iter()
497                        .filter(|(key, sub_constraint)| {
498                            !result.contains_key(&key.to_string()) && match sub_constraint {
499                                BoolExpression::Composite(sub_composite) => {
500                                    sub_composite.coefficients.iter().any(|x| x.0 == s)
501                                },
502                                _ => false
503                            }
504                        })
505                        .map(|(key, _)| key.clone())
506                        .collect::<Vec<String>>();
507        
508                    for incoming_id in incoming {
510                        if !S.contains(&incoming_id) {
511                            S.push_back(incoming_id);
512                        }
513                    }
514                },
515                _ => {}
516            }
517        }
518
519        return result;
520    }
521    
522    pub fn check_combination_default(&self) -> IndexMap<ID, Bound> {
523        let interpretation: IndexMap<String, Bound> = self._amap.iter().filter_map(|(key, value)| {
524            if let Some(bound) = self._amat.get(key) {
525                if let BoolExpression::Primitive(bound) = bound {
526                    Some((value.clone(), *bound))
527                } else {
528                    None
529                }
530            } else {
531                None
532            }
533        }).collect();
534        self.check_combination(&interpretation)
535    }
536    
537    pub fn score_combination_batch(&self, interpretation: &IndexMap<ID, Bound>, weight_sets: &Vec<&IndexMap<ID, f64>>) -> Vec<HashMap<ID, (f64, f64)>> {
538        let trans_deps = self.transitive_dependencies();
539        let mut result = Vec::new();
540        for weights in weight_sets {
541            let mut local_result = HashMap::new();
542    
543            for (variable, dependencies) in trans_deps.iter() {
544                let variable_bounds = interpretation.get(variable.as_str()).unwrap_or(&(0, 1));
545                if dependencies.len() > 0 {
546                    let dependency_weighted_bound = dependencies.iter()
547                        .filter_map(|dep| interpretation.get(dep).map(|bound| {
548                            let weight = weights.get(dep).unwrap_or(&0.0);
549                            (
550                                (bound.0 as f64) * weight,
551                                (bound.1 as f64) * weight,
552                            )
553                        }))
554                        .fold((0.0, 0.0), |acc, (low, high)| {
555                            (acc.0 + low, acc.1 + high)
556                        });
557                        local_result.insert(variable.clone(), (dependency_weighted_bound.0 * variable_bounds.0 as f64, dependency_weighted_bound.1 * variable_bounds.1 as f64));
558                } else {
559                    let weight = weights.get(variable.as_str()).unwrap_or(&0.0);
560                    local_result.insert(variable.clone(), (variable_bounds.0 as f64 * weight, variable_bounds.1 as f64 * weight));
561                }
562            }
563
564            result.push(local_result);
566        }
567        return result;
568    }
569
570    pub fn score_combination(&self, interpretation: &IndexMap<ID, Bound>, weights: &IndexMap<ID, f64>) -> HashMap<ID, (f64, f64)> {
571        return self.score_combination_batch(interpretation, &vec![weights]).get(0).unwrap().clone();
572    }
573    
574    pub fn check_and_score(&self, interpretation: &IndexMap<ID, Bound>, weights: &IndexMap<ID, f64>) -> HashMap<ID, (f64, f64)> {
575        self.score_combination(&self.check_combination(interpretation), weights)
576    }
577
578    pub fn check_and_score_default(&self, weights: &IndexMap<ID, f64>) -> HashMap<ID, (f64, f64)> {
579        self.score_combination(&self.check_combination_default(), weights)
580    }
581
582    pub fn to_sparse_polyhedron(&self, double_binding: bool, integer_constraints: bool, fixed_constraints: bool) -> SparsePolyhedron {
583
584        fn get_coef_bounds(composite: &Constraint, amat: &IndexMap<String, BoolExpression>) -> IndexMap<String, Bound> {
585            let mut coef_bounds: IndexMap<String, Bound> = IndexMap::new();
586            for (coef_key, _) in composite.coefficients.iter() {
587                let coef_exp = amat.get(&coef_key.to_string())
588                    .unwrap_or_else(|| panic!("Coefficient key '{}' not found in _amat", coef_key));
589                match coef_exp {
590                    BoolExpression::Primitive(bound) => {
591                        coef_bounds.insert(coef_key.to_string(), *bound);
592                    },
593                    _ => {coef_bounds.insert(coef_key.to_string(), (0,1));}
594                }
595            }
596            return coef_bounds;
597        }
598
599        let mut A_matrix = SparseIntegerMatrix::new();
601        let mut b_vector: Vec<i64> = Vec::new();
602
603        let primitives: HashMap<&String, Bound> = self._amat.iter()
605            .filter_map(|(key, value)| {
606                if let BoolExpression::Primitive(bound) = value {
607                    Some((key, *bound))
608                } else {
609                    None
610                }
611            })
612            .collect();
613
614        let composites: HashMap<&String, &Constraint> = self._amat.iter()
616            .filter_map(|(key, value)| {
617                if let BoolExpression::Composite(constraint) = value {
618                    Some((key, constraint))
619                } else {
620                    None
621                }
622            })
623            .collect();
624
625        let column_names_map: IndexMap<String, usize> = primitives.keys().chain(composites.keys()).enumerate().map(|(i, key)| (key.to_string(), i)).collect();
627
628        let mut row_i: usize = 0;
630
631        for (key, composite) in composites {
632
633            let ki = *column_names_map.get(key).unwrap();
635
636            let coef_bounds = get_coef_bounds(composite, &self._amat);
638
639            let ib_phi = composite.dot(&coef_bounds);
647
648            let d_pi = std::cmp::max(ib_phi.0.abs(), ib_phi.1.abs());
650            
651            A_matrix.rows.push(row_i);
653            A_matrix.cols.push(ki);
654            A_matrix.vals.push(-d_pi);
655
656            for (coef_key, coef_val) in composite.coefficients.iter() {
658                let ck_index: usize = *column_names_map.get(coef_key).unwrap();
659                A_matrix.rows.push(row_i);
660                A_matrix.cols.push(ck_index);
661                A_matrix.vals.push(*coef_val);
662            }
663
664            let b_phi = composite.bias.0 + d_pi;
666            b_vector.push(-1 * b_phi);
667
668            if double_binding {
669
670                let phi_prim = composite.negate();
679                let phi_prim_ib = phi_prim.dot(&coef_bounds);
680                let d_phi_prim = std::cmp::max(phi_prim_ib.0.abs(), phi_prim_ib.1.abs());
681                let pi_coef = d_phi_prim - phi_prim.bias.0;
682
683                A_matrix.rows.push(row_i + 1);
685                A_matrix.cols.push(ki);
686                A_matrix.vals.push(pi_coef);
687
688                for (phi_coef_key, phi_coef_val) in phi_prim.coefficients.iter() {
690                    let ck_index: usize = *column_names_map.get(phi_coef_key).unwrap();
691                    A_matrix.rows.push(row_i + 1);
692                    A_matrix.cols.push(ck_index);
693                    A_matrix.vals.push(*phi_coef_val);
694                }
695
696                b_vector.push(-1 * phi_prim.bias.0);
698
699                row_i += 1;
701            }
702
703            row_i += 1;
705        }
706
707        if fixed_constraints {
708            let mut fixed_bound_map: HashMap<i64, Vec<usize>> = HashMap::new();
711            for (key, bound) in primitives.iter().filter(|(_, bound)| bound_fixed(**bound)) {
712                fixed_bound_map.entry(bound.0).or_insert_with(Vec::new).push(*column_names_map.get(&key.to_string()).unwrap());
713            }
714    
715            for (v, primitive_ids) in fixed_bound_map.iter() {
716                let b = *v * primitive_ids.len() as i64;
717                for i in vec![-1, 1] {
718                    for primitive_id in primitive_ids {
719                        A_matrix.rows.push(row_i);
720                        A_matrix.cols.push(*primitive_id);
721                        A_matrix.vals.push(i);
722                    }
723                    b_vector.push(i * b);
724                    row_i += 1;
725                }
726            }
727        }
728
729        let mut integer_variables: Vec<String> = Vec::new();
731
732        for (p_key, p_bound) in primitives.iter().filter(|(_, bound)| bound.0 < 0 || bound.1 > 1) {
734            
735            integer_variables.push(p_key.to_string());
737            
738            if integer_constraints {
739                let pi = *column_names_map.get(&p_key.to_string()).unwrap();
741                
742                if p_bound.0 < 0 {
743                    A_matrix.rows.push(row_i);
744                    A_matrix.cols.push(pi);
745                    A_matrix.vals.push(-1);
746                    b_vector.push(-1 * p_bound.0);
747                    row_i += 1;
748                }
749    
750                if p_bound.1 > 1 {
751                    A_matrix.rows.push(row_i);
752                    A_matrix.cols.push(pi);
753                    A_matrix.vals.push(1);
754                    b_vector.push(p_bound.1);
755                    row_i += 1;
756                } 
757            }
758        }
759
760        A_matrix.shape = (row_i, column_names_map.len());
762
763        let polyhedron = SparsePolyhedron {
765            A: A_matrix,
766            b: b_vector,
767            columns: column_names_map.keys().cloned().collect(),
768            integer_columns: integer_variables,
769        };
770
771        return polyhedron;
772    }
773
774    pub fn to_sparse_polyhedron_default(&self) -> SparsePolyhedron {
775        self.to_sparse_polyhedron(true, true, true)
776    }
777
778    pub fn to_dense_polyhedron(&self, double_binding: bool, integer_constraints: bool, fixed_constraints: bool) -> DensePolyhedron {
779        DensePolyhedron::from(self.to_sparse_polyhedron(double_binding, integer_constraints, fixed_constraints))
780    }
781
782    pub fn to_dense_polyhedron_default(&self) -> DensePolyhedron {
783        self.to_dense_polyhedron(true, true, true)
784    }
785    
786    pub fn set_primitive(&mut self, id: ID, bound: Bound) {
787        self._amat.insert(id.clone(), BoolExpression::Primitive(bound));
789        
790        self._amap.insert(id.clone(), id.clone());
792    }
793
794    pub fn set_primitives(&mut self, ids: Vec<ID>, bound: Bound) {
795        let unique_ids: IndexSet<_> = ids.into_iter().collect();
796        for id in unique_ids {
797            self.set_primitive(id, bound);
798        }
799    }
800
801    pub fn set_gelineq(&mut self, coefficient_variables: Vec<Coefficient>, bias: i64, alias: Option<ID>) -> ID {
802        let mut unique_coefficients: IndexMap<ID, i64> = IndexMap::new();
804        for (key, value) in coefficient_variables {
805            *unique_coefficients.entry(key).or_insert(0) += value;
806        }
807        let coefficient_variables: Vec<Coefficient> = unique_coefficients.into_iter().collect();
808
809        let hash = create_hash(&coefficient_variables, bias);
811        
812        let id = hash.to_string();
814
815        self._amat.insert(id.clone(), BoolExpression::Composite(Constraint { coefficients: coefficient_variables, bias: (bias, bias) }));
817
818        if let Some(alias) = alias {
820            self._amap.insert(id.clone(), alias);
821        }
822
823        return id;
824    }
825
826    pub fn set_atleast(&mut self, references: Vec<ID>, value: i64, alias: Option<ID>) -> ID {
827        let unique_references: IndexSet<_> = references.into_iter().collect();
828        self.set_gelineq(unique_references.into_iter().map(|x| (x, 1)).collect(), -value, alias)
829    }
830
831    pub fn set_atmost(&mut self, references: Vec<ID>, value: i64, alias: Option<ID>) -> ID {
832        let unique_references: IndexSet<_> = references.into_iter().collect();
833        self.set_gelineq(unique_references.into_iter().map(|x| (x, -1)).collect(), value, alias)
834    }
835
836    pub fn set_equal(&mut self, references: Vec<ID>, value: i64, alias: Option<ID>) -> ID {
837        let unique_references: IndexSet<_> = references.into_iter().collect();
838        let ub = self.set_atleast(unique_references.clone().into_iter().collect(), value, None);
839        let lb = self.set_atmost(unique_references.into_iter().collect(), value, None);
840        self.set_and(vec![ub, lb], alias)
841    }
842
843    pub fn set_and(&mut self, references: Vec<ID>, alias: Option<ID>) -> ID {
844        let unique_references: IndexSet<_> = references.into_iter().collect();
845        let length = unique_references.len();
846        self.set_atleast(unique_references.into_iter().collect(), length as i64, alias)
847    }
848
849    pub fn set_or(&mut self, references: Vec<ID>, alias: Option<ID>) -> ID {
850        let unique_references: IndexSet<_> = references.into_iter().collect();
851        self.set_atleast(unique_references.into_iter().collect(), 1, alias)
852    }
853
854    pub fn set_nand(&mut self, references: Vec<ID>, alias: Option<ID>) -> ID {
855        let unique_references: IndexSet<_> = references.into_iter().collect();
856        let length = unique_references.len();
857        self.set_atmost(unique_references.into_iter().collect(), length as i64 - 1, alias)
858    }
859    
860    pub fn set_nor(&mut self, references: Vec<ID>, alias: Option<ID>) -> ID {
861        let unique_references: IndexSet<_> = references.into_iter().collect();
862        self.set_atmost(unique_references.into_iter().collect(), 0, alias)
863    }
864
865    pub fn set_not(&mut self, references: Vec<ID>, alias: Option<ID>) -> ID {
866        let unique_references: IndexSet<_> = references.into_iter().collect();
867        self.set_atmost(unique_references.into_iter().collect(), 0, alias)
868    }
869
870    pub fn set_xor(&mut self, references: Vec<ID>, alias: Option<ID>) -> ID {
871        let unique_references: IndexSet<_> = references.into_iter().collect();
872        let atleast = self.set_or(unique_references.clone().into_iter().collect(), None);
873        let atmost = self.set_atmost(unique_references.into_iter().collect(), 1, None);
874        self.set_and(vec![atleast, atmost], alias)
875    }
876
877    pub fn set_xnor(&mut self, references: Vec<ID>, alias: Option<ID>) -> ID {
878        let unique_references: IndexSet<_> = references.into_iter().collect();
879        let atleast = self.set_atleast(unique_references.clone().into_iter().collect(), 2, None);
880        let atmost = self.set_atmost(unique_references.into_iter().collect(), 0, None);
881        self.set_or(vec![atleast, atmost], alias)
882    }
883
884    pub fn set_imply(&mut self, condition: ID, consequence: ID, alias: Option<ID>) -> ID {
885        let not_condition = self.set_not(vec![condition], None);
886        self.set_or(vec![not_condition, consequence], alias)
887    }
888
889    pub fn set_equiv(&mut self, lhs: ID, rhs: ID, alias: Option<ID>) -> ID {
890        let imply_lr = self.set_imply(lhs.clone(), rhs.clone(), None);
891        let imply_rl = self.set_imply(rhs.clone(), lhs.clone(), None);
892        self.set_and(vec![imply_lr, imply_rl], alias)
893    }
894
895    pub fn set_alias(&mut self, id: ID, alias: String) {
896        self._amap.insert(id.clone(), alias.clone());
898    }
899}
900
901#[cfg(test)]
902mod tests {
903    use super::*;
904
905    fn evaluate_model_polyhedron(
912        model: &Pldag,
913        poly: &DensePolyhedron,
914        root: &String
915    ) {
916        for combo in model.primitive_combinations() {
917            let interp = combo.iter()
919                .map(|(k,&v)| (k.clone(), (v,v)))
920                .collect::<IndexMap<String,Bound>>();
921
922            let prop = model.check_combination(&interp);
924            let model_root_val = *prop.get(root).unwrap();
925
926            let mut assumption = HashMap::new();
928            assumption.insert(root.clone(), 1);
929            let shrunk = poly.assume(&assumption);
930
931            let poly_val = shrunk.evaluate(&prop);
933            assert_eq!(
934                poly_val,
935                model_root_val,
936                "Disagreement on {:?}: model={:?}, poly={:?}",
937                combo,
938                model_root_val,
939                poly_val
940            );
941        }
942    }
943
944    #[test]
945    fn test_propagate() {
946        let mut model = Pldag::new();
947        model.set_primitive("x".to_string(), (0, 1));
948        model.set_primitive("y".to_string(), (0, 1));
949        let root = model.set_and(vec!["x".to_string(), "y".to_string()], None);
950
951        let result = model.check_combination(&IndexMap::new());
952        assert_eq!(result.get("x").unwrap(), &(0, 1));
953        assert_eq!(result.get("y").unwrap(), &(0, 1));
954        assert_eq!(result.get(&root).unwrap(), &(0, 1));
955
956        let mut interpretation = IndexMap::new();
957        interpretation.insert("x".to_string(), (1, 1));
958        interpretation.insert("y".to_string(), (1, 1));
959        let result = model.check_combination(&interpretation);
960        assert_eq!(result.get(&root).unwrap(), &(1, 1));
961
962        let mut model = Pldag::new();
963        model.set_primitive("x".to_string(), (0, 1));
964        model.set_primitive("y".to_string(), (0, 1));
965        model.set_primitive("z".to_string(), (0, 1));
966        let root = model.set_xor(vec!["x".to_string(), "y".to_string(), "z".to_string()], None);
967        let result = model.check_combination(&IndexMap::new());
968        assert_eq!(result.get("x").unwrap(), &(0, 1));
969        assert_eq!(result.get("y").unwrap(), &(0, 1));
970        assert_eq!(result.get("z").unwrap(), &(0, 1));
971        assert_eq!(result.get(&root).unwrap(), &(0, 1));
972
973        let mut interpretation = IndexMap::new();
974        interpretation.insert("x".to_string(), (1, 1));
975        interpretation.insert("y".to_string(), (1, 1));
976        interpretation.insert("z".to_string(), (1, 1));
977        let result = model.check_combination(&interpretation);
978        assert_eq!(result.get(&root).unwrap(), &(0, 0));
979        
980        let mut interpretation = IndexMap::new();
981        interpretation.insert("x".to_string(), (0, 1));
982        interpretation.insert("y".to_string(), (1, 1));
983        interpretation.insert("z".to_string(), (1, 1));
984        let result = model.check_combination(&interpretation);
985        assert_eq!(result.get(&root).unwrap(), &(0, 0));
986        
987        let mut interpretation = IndexMap::new();
988        interpretation.insert("x".to_string(), (0, 0));
989        interpretation.insert("y".to_string(), (1, 1));
990        interpretation.insert("z".to_string(), (0, 0));
991        let result = model.check_combination(&interpretation);
992        assert_eq!(result.get(&root).unwrap(), &(1, 1));
993    }
994
995    #[test]
997    fn test_propagate_or_gate() {
998        let mut model = Pldag::new();
999        model.set_primitive("a".into(), (0, 1));
1000        model.set_primitive("b".into(), (0, 1));
1001        let or_root = model.set_or(vec!["a".into(), "b".into()], None);
1002
1003        let res = model.check_combination(&IndexMap::new());
1005        assert_eq!(res["a"], (0, 1));
1006        assert_eq!(res["b"], (0, 1));
1007        assert_eq!(res[&or_root], (0, 1));
1008
1009        let mut interp = IndexMap::new();
1011        interp.insert("a".into(), (1, 1));
1012        let res = model.check_combination(&interp);
1013        assert_eq!(res[&or_root], (1, 1));
1014
1015        let mut interp = IndexMap::new();
1017        interp.insert("a".into(), (0, 0));
1018        interp.insert("b".into(), (0, 0));
1019        let res = model.check_combination(&interp);
1020        assert_eq!(res[&or_root], (0, 0));
1021
1022        let mut interp = IndexMap::new();
1024        interp.insert("b".into(), (0, 0));
1025        let res = model.check_combination(&interp);
1026        assert_eq!(res[&or_root], (0, 1));
1027    }
1028
1029    #[test]
1031    fn test_propagate_not_gate() {
1032        let mut model = Pldag::new();
1033        model.set_primitive("p".into(), (0, 1));
1034        let not_root = model.set_not(vec!["p".into()], None);
1035
1036        let res = model.check_combination(&IndexMap::new());
1038        assert_eq!(res["p"], (0, 1));
1039        assert_eq!(res[¬_root], (0, 1));
1040
1041        let mut interp = IndexMap::new();
1043        interp.insert("p".into(), (0, 0));
1044        let res = model.check_combination(&interp);
1045        assert_eq!(res[¬_root], (1, 1));
1046
1047        let mut interp = IndexMap::new();
1049        interp.insert("p".into(), (1, 1));
1050        let res = model.check_combination(&interp);
1051        assert_eq!(res[¬_root], (0, 0));
1052    }
1053
1054    #[test]
1055    fn test_to_polyhedron_and() {
1056        let mut m = Pldag::new();
1057        m.set_primitive("x".into(), (0,1));
1058        m.set_primitive("y".into(), (0,1));
1059        let root = m.set_and(vec!["x".into(), "y".into()], None);
1060        let poly: DensePolyhedron = m.to_sparse_polyhedron_default().into();
1061        evaluate_model_polyhedron(&m, &poly, &root);
1062    }
1063
1064    #[test]
1065    fn test_to_polyhedron_or() {
1066        let mut m = Pldag::new();
1067        m.set_primitive("a".into(), (0,1));
1068        m.set_primitive("b".into(), (0,1));
1069        m.set_primitive("c".into(), (0,1));
1070        let root = m.set_or(vec!["a".into(), "b".into(), "c".into()], None);
1071        let poly: DensePolyhedron = m.to_sparse_polyhedron_default().into();
1072        evaluate_model_polyhedron(&m, &poly, &root);
1073    }
1074
1075    #[test]
1076    fn test_to_polyhedron_not() {
1077        let mut m = Pldag::new();
1078        m.set_primitive("p".into(), (0,1));
1079        let root = m.set_not(vec!["p".into()], None);
1080        let poly: DensePolyhedron = m.to_sparse_polyhedron_default().into();
1081        evaluate_model_polyhedron(&m, &poly, &root);
1082    }
1083
1084    #[test]
1085    fn test_to_polyhedron_xor() {
1086        let mut m = Pldag::new();
1087        m.set_primitive("x".into(), (0,1));
1088        m.set_primitive("y".into(), (0,1));
1089        m.set_primitive("z".into(), (0,1));
1090        let root = m.set_xor(vec!["x".into(), "y".into(), "z".into()], None);
1091        let poly: DensePolyhedron = m.to_sparse_polyhedron_default().into();
1092        evaluate_model_polyhedron(&m, &poly, &root);
1093    }
1094
1095    #[test]
1096    fn test_to_polyhedron_nested() {
1097        let mut m = Pldag::new();
1100        m.set_primitive("x".into(), (0,1));
1101        m.set_primitive("y".into(), (0,1));
1102        m.set_primitive("z".into(), (0,1));
1103
1104        let w = m.set_and(vec!["x".into(), "y".into()], None);
1105        let nz = m.set_not(vec!["z".into()], None);
1106        let v = m.set_or(vec![w.clone(), nz.clone()], None);
1107
1108        let poly: DensePolyhedron = m.to_sparse_polyhedron_default().into();
1109        evaluate_model_polyhedron(&m, &poly, &v);
1110    }
1111
1112    #[test]
1115    fn test_propagate_nested_composite() {
1116        let mut model = Pldag::new();
1117        model.set_primitive("x".into(), (0, 1));
1118        model.set_primitive("y".into(), (0, 1));
1119        model.set_primitive("z".into(), (0, 1));
1120
1121        let w = model.set_and(vec!["x".into(), "y".into()], None);
1122        let v = model.set_xor(vec![w.clone(), "z".into()], None);
1123
1124        let res = model.check_combination(&IndexMap::new());
1126        for var in &["x","y","z"] {
1127            assert_eq!(res[*var], (0,1), "{}", var);
1128        }
1129        assert_eq!(res[&w], (0,1));
1130        assert_eq!(res[&v], (0,1));
1131
1132        let mut interp = IndexMap::new();
1134        interp.insert("x".into(), (1,1));
1135        interp.insert("y".into(), (1,1));
1136        interp.insert("z".into(), (0,0));
1137        let res = model.check_combination(&interp);
1138        assert_eq!(res[&w], (1,1));
1139        assert_eq!(res[&v], (1,1));
1140
1141        let mut interp = IndexMap::new();
1143        interp.insert("x".into(), (0,0));
1144        interp.insert("y".into(), (1,1));
1145        interp.insert("z".into(), (1,1));
1146        let res = model.check_combination(&interp);
1147        assert_eq!(res[&w], (0,0));
1148        assert_eq!(res[&v], (1,1));
1149
1150        let mut interp = IndexMap::new();
1152        interp.insert("x".into(), (0,0));
1153        interp.insert("y".into(), (0,0));
1154        interp.insert("z".into(), (0,0));
1155        let res = model.check_combination(&interp);
1156        assert_eq!(res[&w], (0,0));
1157        assert_eq!(res[&v], (0,0));
1158    }
1159
1160    #[test]
1164    fn test_propagate_out_of_bounds_does_not_crash() {
1165        let mut model = Pldag::new();
1166        model.set_primitive("u".into(), (0, 1));
1167        let root = model.set_not(vec!["u".into()], None);
1168
1169        let mut interp = IndexMap::new();
1170        interp.insert("u".into(), (5,5));
1172        let res = model.check_combination(&interp);
1173
1174        assert_eq!(res["u"], (5,5));
1176        let _ = res[&root];
1180    }
1181
1182    #[test]
1183    fn test_to_polyhedron() {
1184
1185        fn evaluate_model_polyhedron(model: &Pldag, polyhedron: &DensePolyhedron, root: &String) {
1186            for combination in model.primitive_combinations() {
1187                let interpretation = combination
1188                    .iter()
1189                    .map(|(k, &v)| (k.clone(), (v, v)))
1190                    .collect::<IndexMap<String, Bound>>();
1191                let model_prop = model.check_combination(&interpretation);
1192                let model_eval = *model_prop.get(root).unwrap();
1193                let mut assumption = HashMap::new();
1194                assumption.insert(root.clone(), 1);
1195                let assumed_polyhedron = polyhedron.assume(&assumption);
1196                let assumed_poly_eval = assumed_polyhedron.evaluate(&model_prop);
1197                assert_eq!(assumed_poly_eval, model_eval);
1198            }
1199        }
1200
1201        let mut model: Pldag = Pldag::new();
1202        model.set_primitive("x".to_string(), (0, 1));
1203        model.set_primitive("y".to_string(), (0, 1));
1204        model.set_primitive("z".to_string(), (0, 1));
1205        let root = model.set_xor(vec!["x".to_string(), "y".to_string(), "z".to_string()], None);
1206        let polyhedron: DensePolyhedron = model.to_sparse_polyhedron_default().into();
1207        evaluate_model_polyhedron(&model, &polyhedron, &root);
1208
1209        let mut model = Pldag::new();
1210        model.set_primitive("x".to_string(), (0, 1));
1211        model.set_primitive("y".to_string(), (0, 1));
1212        let root = model.set_and(vec!["x".to_string(), "y".to_string()], None);
1213        let polyhedron = model.to_sparse_polyhedron_default().into();
1214        evaluate_model_polyhedron(&model, &polyhedron, &root);
1215
1216        let mut model: Pldag = Pldag::new();
1217        model.set_primitive("x".to_string(), (0, 1));
1218        model.set_primitive("y".to_string(), (0, 1));
1219        model.set_primitive("z".to_string(), (0, 1));
1220        let root = model.set_xor(vec!["x".to_string(), "y".to_string(), "z".to_string()], None);
1221        let polyhedron = model.to_sparse_polyhedron_default().into();
1222        evaluate_model_polyhedron(&model, &polyhedron, &root);
1223    }
1224
1225    #[test]
1227    fn test_to_polyhedron_single_operand_identity() {
1228        {
1230            let mut m = Pldag::new();
1231            m.set_primitive("x".into(), (0,1));
1232            let root = m.set_and(vec!["x".into()], None);
1233            let poly: DensePolyhedron = m.to_sparse_polyhedron_default().into();
1234            evaluate_model_polyhedron(&m, &poly, &root);
1235        }
1236        {
1238            let mut m = Pldag::new();
1239            m.set_primitive("y".into(), (0,1));
1240            let root = m.set_or(vec!["y".into()], None);
1241            let poly: DensePolyhedron = m.to_sparse_polyhedron_default().into();
1242            evaluate_model_polyhedron(&m, &poly, &root);
1243        }
1244        {
1246            let mut m = Pldag::new();
1247            m.set_primitive("z".into(), (0,1));
1248            let root = m.set_xor(vec!["z".into()], None);
1249            let poly: DensePolyhedron = m.to_sparse_polyhedron_default().into();
1250            evaluate_model_polyhedron(&m, &poly, &root);
1251        }
1252    }
1253
1254    #[test]
1256    fn test_to_polyhedron_duplicate_operands_and() {
1257        let mut m = Pldag::new();
1258        m.set_primitive("x".into(), (0,1));
1259        let root = m.set_and(vec!["x".into(), "x".into()], None);
1260        let poly: DensePolyhedron = m.to_sparse_polyhedron_default().into();
1261        evaluate_model_polyhedron(&m, &poly, &root);
1262    }
1263
1264    #[test]
1270    fn test_to_polyhedron_deeply_nested_chain() {
1271        let mut m = Pldag::new();
1272        for &v in &["a","b","c","d","e"] {
1274            m.set_primitive(v.into(), (0,1));
1275        }
1276        let a = "a".to_string();
1277        let b = "b".to_string();
1278        let c = "c".to_string();
1279        let d = "d".to_string();
1280
1281        let w1 = m.set_and(vec![a.clone(), b.clone()], None);
1282        let w2 = m.set_or(vec![w1.clone(), c.clone()], None);
1283        let w3 = m.set_xor(vec![w2.clone(), d.clone()], None);
1284        let root = m.set_not(vec![w3.clone()], None);
1285
1286        let poly: DensePolyhedron = m.to_sparse_polyhedron_default().into();
1287        evaluate_model_polyhedron(&m, &poly, &root);
1288    }
1289
1290    fn make_simple_dag() -> Pldag {
1292        let mut pldag = Pldag::new();
1293        pldag.set_primitive("b".into(), (0, 1));
1294        pldag.set_primitive("d".into(), (0, 1));
1295        pldag.set_primitive("e".into(), (0, 1));
1296        let c = pldag.set_or(vec!["d".into(), "e".into()], None);
1297        let a = pldag.set_or(vec!["b".into(), c.clone()], None);
1298        pldag.set_alias("a".into(), a.clone());
1299        pldag.set_alias("c".into(), c.clone());
1300        return pldag;
1301    }
1302
1303    #[test]
1304    fn test_simple_dag() {
1305        let pldag = make_simple_dag();
1306        let deps = pldag.transitive_dependencies();
1307
1308        let expect = |xs: &[&str]| {
1309            xs.iter().cloned().map(String::from).collect::<HashSet<_>>()
1310        };
1311
1312        let a = pldag.get_id(&"a".to_string());
1313        let c = pldag.get_id(&"c".to_string());
1314        assert_eq!(deps.get(&a), Some(&expect(&["b", &c.clone(), "d", "e"])));
1315        assert_eq!(deps.get("b"), Some(&expect(&[])));
1316        assert_eq!(deps.get(&c), Some(&expect(&["d", "e"])));
1317        assert_eq!(deps.get("d"), Some(&expect(&[])));
1318        assert_eq!(deps.get("e"), Some(&expect(&[])));
1319    }
1320
1321    #[test]
1322    fn test_chain_dag() {
1323        let mut pldag = Pldag::new();
1325        pldag.set_primitive("z".into(), (0, 0));;
1326        let y = pldag.set_or(vec!["z".into()], None);
1327        let x = pldag.set_or(vec![y.clone()], None);
1328        let deps = pldag.transitive_dependencies();
1329
1330        let expect = |xs: &[&str]| {
1331            xs.iter().cloned().map(String::from).collect::<HashSet<_>>()
1332        };
1333
1334        assert_eq!(deps.get(&x), Some(&expect(&[&y.to_string(), "z"])));
1335        assert_eq!(deps.get(&y), Some(&expect(&["z"])));
1336        assert_eq!(deps.get("z"), Some(&expect(&[])));
1337    }
1338
1339    #[test]
1340    fn test_all_primitives() {
1341        let mut amat = IndexMap::new();
1343        for &name in &["p", "q", "r"] {
1344            amat.insert(name.into(), BoolExpression::Primitive((1, 5)));
1345        }
1346        let mut pldag = Pldag::new();
1347        pldag._amat = amat;
1348        let deps = pldag.transitive_dependencies();
1349
1350        for &name in &["p", "q", "r"] {
1351            assert!(deps.get(name).unwrap().is_empty(), "{} should have no deps", name);
1352        }
1353    }
1354
1355    #[test]
1356    fn test_propagate_weighted() {
1357        let mut model = Pldag::new();
1358        model.set_primitive("x".to_string(), (0, 1));
1359        model.set_primitive("y".to_string(), (0, 1));
1360        let root = model.set_and(vec!["x".to_string(), "y".to_string()], None);
1361        let mut interpretation = IndexMap::new();
1362        interpretation.insert("x".to_string(), (1, 1));
1363        interpretation.insert("y".to_string(), (1, 1));
1364        let mut weights: IndexMap<String, f64> = IndexMap::new();
1365        weights.insert("x".to_string(), 2.0);
1366        weights.insert("y".to_string(), 3.0);
1367        let propagated = model.check_and_score(&interpretation, &weights);
1368        assert_eq!(propagated.get("x").unwrap(), &(2.0, 2.0));
1369        assert_eq!(propagated.get("y").unwrap(), &(3.0, 3.0));
1370        assert_eq!(propagated.get(&root).unwrap(), &(5.0, 5.0));
1371    }
1372
1373    #[test]
1374    fn test_readme_example() {
1375        let mut pldag: Pldag = Pldag::new();
1379
1380        pldag.set_primitive("x".to_string(), (0, 1));
1382        pldag.set_primitive("y".to_string(), (0, 1));
1383        pldag.set_primitive("z".to_string(), (0, 1));
1384
1385        let root = pldag.set_or(vec![
1387            "x".to_string(),
1388            "y".to_string(),
1389            "z".to_string(),
1390        ], None);
1391
1392        let mut inputs: IndexMap<String, Bound> = IndexMap::new();
1394        let validited = pldag.check_combination(&inputs);
1395        println!("Root valid? {}", *validited.get(&root).unwrap() == (1, 1)); inputs.insert("x".to_string(), (0,0));
1401        let revalidited = pldag.check_combination(&inputs);
1402        println!("Root valid? {}", *revalidited.get(&root).unwrap() == (1, 1)); inputs.insert("y".to_string(), (1,1));
1406        inputs.insert("z".to_string(), (1,1));
1407        let revalidited = pldag.check_combination(&inputs);
1408        println!("Root valid? {}", *revalidited.get(&root).unwrap() == (1, 1)); let mut weights: IndexMap<String, f64> = IndexMap::new();
1413        weights.insert("x".to_string(), 1.0);
1414        weights.insert("y".to_string(), 2.0);
1415        weights.insert("z".to_string(), 3.0);
1416        weights.insert(root.clone(), -1.0);
1418        let scores = pldag.check_and_score(&inputs, &weights);
1419        println!("Total score: {:?}", scores.get(&root).unwrap());
1420
1421        inputs.insert("x".to_string(), (0,1));
1423        let scores = pldag.check_and_score(&inputs, &weights);
1424        println!("Total score: {:?}", scores.get(&root).unwrap());
1427
1428        inputs.insert("x".to_string(), (0,0));
1430        let scores = pldag.check_and_score(&inputs, &weights);
1431        println!("Total score: {:?}", scores.get(&root).unwrap());
1432
1433        inputs.insert("y".to_string(), (0,0));
1435        inputs.insert("z".to_string(), (0,0));
1436        let scores = pldag.check_and_score(&inputs, &weights);
1437        println!("Total score: {:?}", scores.get(&root).unwrap());
1438    }
1439}