Skip to main content

litex/fact/
chain_fact_order_closure.rs

1use crate::prelude::*;
2use std::collections::{HashMap, HashSet};
3
/// Relation carried by one link of a chain, i.e. between two consecutive
/// chain objects.
///
/// Produced by `order_edge_from_prop` from the `EQUAL`, `LESS_EQUAL`, `LESS`,
/// `GREATER_EQUAL` and `GREATER` property names respectively.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum OrderEdge {
    /// Left and right are equal; such links are merged into one class.
    Eq,
    /// Non-strict ascending link.
    Le,
    /// Strict ascending link.
    Lt,
    /// Non-strict descending link.
    Ge,
    /// Strict descending link.
    Gt,
}
12
/// Overall direction of a chain whose links are all order relations.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ChainPolarity {
    /// Ascending: the non-equality links are `Le` / `Lt`.
    Up,
    /// Descending: the non-equality links are `Ge` / `Gt`.
    Down,
}
18
/// Disjoint-set forest over the indices `0..n`.
struct UnionFind {
    /// `parent[i]` is the parent of node `i`; a root is its own parent.
    parent: Vec<usize>,
}

impl UnionFind {
    /// Builds `n` singleton sets, one for each index in `0..n`.
    fn new(n: usize) -> Self {
        let parent = (0..n).collect();
        Self { parent }
    }

    /// Returns the root of the set containing `x` and re-points every node
    /// visited on the way directly at that root.
    ///
    /// Panics if `x` is not a valid index.
    fn find(&mut self, x: usize) -> usize {
        // First pass: walk up to the root.
        let mut root = x;
        while self.parent[root] != root {
            root = self.parent[root];
        }
        // Second pass: compress the path that was just walked.
        let mut node = x;
        while node != root {
            let next = self.parent[node];
            self.parent[node] = root;
            node = next;
        }
        root
    }

    /// Merges the sets containing `a` and `b`; the root of `a`'s set becomes
    /// the root of the merged set.
    fn union(&mut self, a: usize, b: usize) {
        let kept_root = self.find(a);
        let absorbed_root = self.find(b);
        if kept_root == absorbed_root {
            return;
        }
        self.parent[absorbed_root] = kept_root;
    }
}
45
46fn order_edge_from_prop(p: &IdentifierOrIdentifierWithMod) -> Option<OrderEdge> {
47    match p.to_string().as_str() {
48        EQUAL => Some(OrderEdge::Eq),
49        LESS_EQUAL => Some(OrderEdge::Le),
50        LESS => Some(OrderEdge::Lt),
51        GREATER_EQUAL => Some(OrderEdge::Ge),
52        GREATER => Some(OrderEdge::Gt),
53        _ => None,
54    }
55}
56
57fn dedup_atomic_facts(mut facts: Vec<AtomicFact>) -> Vec<AtomicFact> {
58    let mut seen = HashSet::new();
59    facts.retain(|f| seen.insert(f.to_string()));
60    facts
61}
62
63impl ChainFact {
64    pub fn facts_with_order_transitive_closure(&self) -> Result<Vec<AtomicFact>, RuntimeError> {
65        let base = self.facts()?;
66        let n = self.objs.len();
67        if n < 2 {
68            return Ok(base);
69        }
70
71        let mut edges: Vec<OrderEdge> = Vec::with_capacity(self.prop_names.len());
72        for p in &self.prop_names {
73            let Some(e) = order_edge_from_prop(p) else {
74                return Ok(base);
75            };
76            edges.push(e);
77        }
78
79        let mut has_up = false;
80        let mut has_down = false;
81        for e in &edges {
82            match e {
83                OrderEdge::Le | OrderEdge::Lt => has_up = true,
84                OrderEdge::Ge | OrderEdge::Gt => has_down = true,
85                OrderEdge::Eq => {}
86            }
87        }
88        if has_up && has_down {
89            return Ok(base);
90        }
91
92        let polarity = if has_up {
93            ChainPolarity::Up
94        } else if has_down {
95            ChainPolarity::Down
96        } else {
97            ChainPolarity::Up
98        };
99
100        let mut uf = UnionFind::new(n);
101        for (k, e) in edges.iter().enumerate() {
102            if *e == OrderEdge::Eq {
103                uf.union(k, k + 1);
104            }
105        }
106
107        let mut quotient: Vec<usize> = Vec::new();
108        for i in 0..n {
109            if i == 0 || uf.find(i) != uf.find(i - 1) {
110                quotient.push(i);
111            }
112        }
113
114        let mut between_strict: Vec<bool> = Vec::new();
115        for k in 0..edges.len() {
116            let ca = uf.find(k);
117            let cb = uf.find(k + 1);
118            if ca != cb {
119                let strict = match polarity {
120                    ChainPolarity::Up => matches!(edges[k], OrderEdge::Lt),
121                    ChainPolarity::Down => matches!(edges[k], OrderEdge::Gt),
122                };
123                between_strict.push(strict);
124            }
125        }
126
127        if between_strict.len() + 1 != quotient.len() {
128            return Ok(base);
129        }
130
131        let lf = self.line_file.clone();
132        let mut extra: Vec<AtomicFact> = Vec::new();
133
134        let mut members: HashMap<usize, Vec<usize>> = HashMap::new();
135        for i in 0..n {
136            members.entry(uf.find(i)).or_default().push(i);
137        }
138        for mut idxs in members.into_values() {
139            if idxs.len() < 2 {
140                continue;
141            }
142            idxs.sort_unstable();
143            for ii in 0..idxs.len() {
144                for jj in ii + 1..idxs.len() {
145                    let i = idxs[ii];
146                    let j = idxs[jj];
147                    extra.push(
148                        EqualFact::new(self.objs[i].clone(), self.objs[j].clone(), lf.clone())
149                            .into(),
150                    );
151                }
152            }
153        }
154
155        for qi in 0..quotient.len() {
156            for qj in qi + 1..quotient.len() {
157                let path_strict = between_strict[qi..qj].iter().any(|&s| s);
158                let left = self.objs[quotient[qi]].clone();
159                let right = self.objs[quotient[qj]].clone();
160                let f = match polarity {
161                    ChainPolarity::Up => {
162                        if path_strict {
163                            LessFact::new(left, right, lf.clone()).into()
164                        } else {
165                            LessEqualFact::new(left, right, lf.clone()).into()
166                        }
167                    }
168                    ChainPolarity::Down => {
169                        if path_strict {
170                            GreaterFact::new(left, right, lf.clone()).into()
171                        } else {
172                            GreaterEqualFact::new(left, right, lf.clone()).into()
173                        }
174                    }
175                };
176                extra.push(f);
177            }
178        }
179
180        let mut all = base;
181        all.extend(extra);
182        Ok(dedup_atomic_facts(all))
183    }
184}
185
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a chain object from a bare name.
    fn id(name: &str) -> Obj {
        Identifier::mk(name.to_string())
    }

    /// Builds a chain relation name from its symbol.
    fn prop(s: &str) -> IdentifierOrIdentifierWithMod {
        IdentifierOrIdentifierWithMod::Identifier(Identifier::new(s.to_string()))
    }

    /// Builds the chain `names[0] relations[0] names[1] ...` and returns its
    /// closed fact list, panicking if the closure fails.
    fn closure_of(names: &[&str], relations: &[&str]) -> Vec<AtomicFact> {
        let objs: Vec<Obj> = names.iter().map(|name| id(name)).collect();
        let props: Vec<IdentifierOrIdentifierWithMod> =
            relations.iter().map(|symbol| prop(symbol)).collect();
        ChainFact::new(objs, props, default_line_file())
            .facts_with_order_transitive_closure()
            .unwrap()
    }

    /// Renders every fact with `to_string()`.
    fn rendered(facts: &[AtomicFact]) -> Vec<String> {
        facts.iter().map(|f| f.to_string()).collect()
    }

    #[test]
    fn eq_chain_adds_all_pairs_between_equal_nodes() {
        let facts = closure_of(&["x", "y", "z"], &["=", "="]);
        let displayed = rendered(&facts);
        let has_y_eq_z = displayed
            .iter()
            .any(|s| s.contains("y") && s.contains("z") && s.contains("="));
        assert!(
            has_y_eq_z,
            "expected y = z from equality clique, got {:?}",
            displayed
        );
    }

    #[test]
    fn le_eq_lt_chain_adds_transitive_facts() {
        let facts = closure_of(&["a", "b", "c", "d"], &["<=", "=", "<"]);
        let displayed = rendered(&facts);
        assert!(displayed.iter().any(|s| s.contains("a") && s.contains("d")));
        assert!(
            facts.len() >= 5,
            "expected transitive closure beyond three adjacent facts, got {:?}",
            displayed
        );
    }

    #[test]
    fn mixed_up_down_chain_skips_closure_beyond_base() {
        let facts = closure_of(&["a", "b", "c"], &["<=", ">"]);
        assert_eq!(facts.len(), 2);
    }

    #[test]
    fn ge_eq_gt_chain_adds_transitive_greater_facts() {
        let facts = closure_of(&["d", "c", "b", "a"], &[">=", "=", ">"]);
        let displayed = rendered(&facts);
        assert!(displayed.iter().any(|s| s.contains("d") && s.contains("a")));
        assert!(facts.len() >= 5);
    }
}