//! `litex/fact/chain_fact_order_closure.rs`: transitive closure of order chains
//! such as `a <= b = c < d` for `ChainFact`.

1use crate::common::keywords::{EQUAL, GREATER, GREATER_EQUAL, LESS, LESS_EQUAL};
2use crate::fact::matchable_fact_with_atomic_fact_inside::ChainFact;
3use crate::prelude::*;
4use std::collections::{HashMap, HashSet};
5
/// A single relation between two neighbouring objects of a chain, as produced by
/// `order_edge_from_prop` from the chain's relation keyword.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum OrderEdge {
    /// `EQUAL` (`=`).
    Eq,
    /// `LESS_EQUAL` (`<=`).
    Le,
    /// `LESS` (`<`).
    Lt,
    /// `GREATER_EQUAL` (`>=`).
    Ge,
    /// `GREATER` (`>`).
    Gt,
}
14
/// Direction of an order chain: `Up` for chains built from `=`, `<=`, `<`, and
/// `Down` for chains built from `=`, `>=`, `>`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum ChainPolarity {
    Up,
    Down,
}
20
/// Disjoint-set forest over the indices `0..n`, with path compression in `find`.
struct UnionFind {
    /// `parent[i]` is the parent of `i`; a root is its own parent.
    parent: Vec<usize>,
}

impl UnionFind {
    /// Creates `n` singleton sets `{0}, {1}, ..., {n - 1}`.
    fn new(n: usize) -> Self {
        let parent = (0..n).collect();
        Self { parent }
    }

    /// Returns the root of the set containing `x`. Every node on the walked path
    /// is re-pointed directly at that root.
    ///
    /// Panics if `x` is out of range.
    fn find(&mut self, x: usize) -> usize {
        // First pass: walk up to the root.
        let mut root = x;
        while self.parent[root] != root {
            root = self.parent[root];
        }
        // Second pass: compress the path just walked.
        let mut node = x;
        while node != root {
            let next = self.parent[node];
            self.parent[node] = root;
            node = next;
        }
        root
    }

    /// Merges the sets containing `a` and `b`; `a`'s root becomes the root of the union.
    fn union(&mut self, a: usize, b: usize) {
        let root_a = self.find(a);
        let root_b = self.find(b);
        if root_a == root_b {
            return;
        }
        self.parent[root_b] = root_a;
    }
}
47
48fn order_edge_from_prop(p: &IdentifierOrIdentifierWithMod) -> Option<OrderEdge> {
49    match p.to_string().as_str() {
50        EQUAL => Some(OrderEdge::Eq),
51        LESS_EQUAL => Some(OrderEdge::Le),
52        LESS => Some(OrderEdge::Lt),
53        GREATER_EQUAL => Some(OrderEdge::Ge),
54        GREATER => Some(OrderEdge::Gt),
55        _ => None,
56    }
57}
58
59fn dedup_atomic_facts(mut facts: Vec<AtomicFact>) -> Vec<AtomicFact> {
60    let mut seen = HashSet::new();
61    facts.retain(|f| seen.insert(f.to_string()));
62    facts
63}
64
65impl ChainFact {
66    pub fn facts_with_order_transitive_closure(&self) -> Result<Vec<AtomicFact>, RuntimeErrorStruct> {
67        let base = self.facts()?;
68        let n = self.objs.len();
69        if n < 2 {
70            return Ok(base);
71        }
72
73        let mut edges: Vec<OrderEdge> = Vec::with_capacity(self.prop_names.len());
74        for p in &self.prop_names {
75            let Some(e) = order_edge_from_prop(p) else {
76                return Ok(base);
77            };
78            edges.push(e);
79        }
80
81        let mut has_up = false;
82        let mut has_down = false;
83        for e in &edges {
84            match e {
85                OrderEdge::Le | OrderEdge::Lt => has_up = true,
86                OrderEdge::Ge | OrderEdge::Gt => has_down = true,
87                OrderEdge::Eq => {}
88            }
89        }
90        if has_up && has_down {
91            return Ok(base);
92        }
93
94        let polarity = if has_up {
95            ChainPolarity::Up
96        } else if has_down {
97            ChainPolarity::Down
98        } else {
99            ChainPolarity::Up
100        };
101
102        let mut uf = UnionFind::new(n);
103        for (k, e) in edges.iter().enumerate() {
104            if *e == OrderEdge::Eq {
105                uf.union(k, k + 1);
106            }
107        }
108
109        let mut quotient: Vec<usize> = Vec::new();
110        for i in 0..n {
111            if i == 0 || uf.find(i) != uf.find(i - 1) {
112                quotient.push(i);
113            }
114        }
115
116        let mut between_strict: Vec<bool> = Vec::new();
117        for k in 0..edges.len() {
118            let ca = uf.find(k);
119            let cb = uf.find(k + 1);
120            if ca != cb {
121                let strict = match polarity {
122                    ChainPolarity::Up => matches!(edges[k], OrderEdge::Lt),
123                    ChainPolarity::Down => matches!(edges[k], OrderEdge::Gt),
124                };
125                between_strict.push(strict);
126            }
127        }
128
129        if between_strict.len() + 1 != quotient.len() {
130            return Ok(base);
131        }
132
133        let lf = self.line_file.clone();
134        let mut extra: Vec<AtomicFact> = Vec::new();
135
136        let mut members: HashMap<usize, Vec<usize>> = HashMap::new();
137        for i in 0..n {
138            members.entry(uf.find(i)).or_default().push(i);
139        }
140        for mut idxs in members.into_values() {
141            if idxs.len() < 2 {
142                continue;
143            }
144            idxs.sort_unstable();
145            let rep = idxs[0];
146            for &j in idxs.iter().skip(1) {
147                extra.push(AtomicFact::EqualFact(EqualFact::new(
148                    self.objs[rep].clone(),
149                    self.objs[j].clone(),
150                    lf.clone(),
151                )));
152            }
153        }
154
155        for qi in 0..quotient.len() {
156            for qj in qi + 1..quotient.len() {
157                let path_strict = between_strict[qi..qj].iter().any(|&s| s);
158                let left = self.objs[quotient[qi]].clone();
159                let right = self.objs[quotient[qj]].clone();
160                let f = match polarity {
161                    ChainPolarity::Up => {
162                        if path_strict {
163                            AtomicFact::LessFact(LessFact::new(left, right, lf.clone()))
164                        } else {
165                            AtomicFact::LessEqualFact(LessEqualFact::new(left, right, lf.clone()))
166                        }
167                    }
168                    ChainPolarity::Down => {
169                        if path_strict {
170                            AtomicFact::GreaterFact(GreaterFact::new(left, right, lf.clone()))
171                        } else {
172                            AtomicFact::GreaterEqualFact(GreaterEqualFact::new(
173                                left, right, lf.clone(),
174                            ))
175                        }
176                    }
177                };
178                extra.push(f);
179            }
180        }
181
182        let mut all = base;
183        all.extend(extra);
184        Ok(dedup_atomic_facts(all))
185    }
186}
187
#[cfg(test)]
mod tests {
    use super::*;

    fn id(name: &str) -> Obj {
        Identifier::mk(name.to_string())
    }

    fn prop(s: &str) -> IdentifierOrIdentifierWithMod {
        IdentifierOrIdentifierWithMod::Identifier(Identifier::new(s.to_string()))
    }

    /// Runs the closure and checks the invariants every result must satisfy: the
    /// adjacent facts from `facts()` lead the result unchanged, and no fact repeats.
    /// Returns the display strings of the closed fact list.
    fn closed_and_checked(chain: &ChainFact) -> Vec<String> {
        let base: Vec<String> = chain
            .facts()
            .unwrap()
            .iter()
            .map(|f| f.to_string())
            .collect();
        let closed: Vec<String> = chain
            .facts_with_order_transitive_closure()
            .unwrap()
            .iter()
            .map(|f| f.to_string())
            .collect();
        assert!(
            closed.starts_with(&base),
            "adjacent facts {:?} should lead the closure {:?}",
            base,
            closed
        );
        let unique: std::collections::HashSet<&String> = closed.iter().collect();
        assert_eq!(unique.len(), closed.len(), "duplicate facts in {:?}", closed);
        closed
    }

    #[test]
    fn le_eq_lt_chain_adds_transitive_facts() {
        let lf = default_line_file();
        let chain = ChainFact::new(
            vec![id("a"), id("b"), id("c"), id("d")],
            vec![prop("<="), prop("="), prop("<")],
            lf,
        );
        let displayed = closed_and_checked(&chain);
        assert!(displayed.iter().any(|s| s.contains("a") && s.contains("d")));
        assert!(
            displayed.len() >= 5,
            "expected transitive closure beyond three adjacent facts, got {:?}",
            displayed
        );
    }

    #[test]
    fn mixed_up_down_chain_skips_closure_beyond_base() {
        let lf = default_line_file();
        let chain = ChainFact::new(
            vec![id("a"), id("b"), id("c")],
            vec![prop("<="), prop(">")],
            lf,
        );
        assert_eq!(closed_and_checked(&chain).len(), 2);
    }

    #[test]
    fn ge_eq_gt_chain_adds_transitive_greater_facts() {
        let lf = default_line_file();
        let chain = ChainFact::new(
            vec![id("d"), id("c"), id("b"), id("a")],
            vec![prop(">="), prop("="), prop(">")],
            lf,
        );
        let displayed = closed_and_checked(&chain);
        assert!(displayed.iter().any(|s| s.contains("d") && s.contains("a")));
        assert!(displayed.len() >= 5);
    }

    #[test]
    fn all_equal_chain_emits_only_equalities() {
        let lf = default_line_file();
        let chain = ChainFact::new(
            vec![id("a"), id("b"), id("c")],
            vec![prop("="), prop("=")],
            lf,
        );
        let displayed = closed_and_checked(&chain);
        assert!(displayed.len() >= 3, "expected a = c as well, got {:?}", displayed);
        assert!(
            displayed.iter().all(|s| !s.contains('<') && !s.contains('>')),
            "an all-`=` chain must not produce order facts, got {:?}",
            displayed
        );
    }
}