Skip to main content

litex/fact/
chain_fact_order_closure.rs

1use crate::prelude::*;
2use std::collections::{HashMap, HashSet};
3
/// One relation between two neighbouring objects of a chain fact, decoded from the
/// relation's proposition name.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum OrderEdge {
    /// Decoded from the `EQUAL` proposition.
    Eq,
    /// Decoded from the `LESS_EQUAL` proposition (non-strict, upward).
    Le,
    /// Decoded from the `LESS` proposition (strict, upward).
    Lt,
    /// Decoded from the `GREATER_EQUAL` proposition (non-strict, downward).
    Ge,
    /// Decoded from the `GREATER` proposition (strict, downward).
    Gt,
}
12
/// Overall direction of a chain whose order relations all point the same way.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ChainPolarity {
    /// The chain uses only `LESS_EQUAL` / `LESS` (plus `EQUAL`) relations.
    Up,
    /// The chain uses only `GREATER_EQUAL` / `GREATER` (plus `EQUAL`) relations.
    Down,
}
18
/// Minimal disjoint-set (union-find) structure over the indices `0..n`.
///
/// `find` compresses the paths it walks; `union` makes the root of `a`'s set the
/// parent of the root of `b`'s set.
struct UnionFind {
    /// `parent[i]` is the parent of `i`; a root is its own parent.
    parent: Vec<usize>,
}

impl UnionFind {
    /// Creates `n` singleton sets, one for each index in `0..n`.
    fn new(n: usize) -> Self {
        UnionFind {
            parent: (0..n).collect(),
        }
    }

    /// Returns the root of the set containing `x`, pointing every node on the walked
    /// path directly at that root.
    ///
    /// Implemented iteratively so an arbitrarily long parent chain cannot overflow the
    /// call stack, as the previous recursive version could.
    ///
    /// # Panics
    ///
    /// Panics if `x` is not a valid index (`x >= n`).
    fn find(&mut self, x: usize) -> usize {
        // First pass: locate the root.
        let mut root = x;
        while self.parent[root] != root {
            root = self.parent[root];
        }
        // Second pass: re-point every node on the path straight at the root.
        let mut cur = x;
        while self.parent[cur] != root {
            let next = self.parent[cur];
            self.parent[cur] = root;
            cur = next;
        }
        root
    }

    /// Merges the sets containing `a` and `b`; `a`'s root becomes the representative.
    ///
    /// # Panics
    ///
    /// Panics if either index is out of range.
    fn union(&mut self, a: usize, b: usize) {
        let ra = self.find(a);
        let rb = self.find(b);
        if ra != rb {
            self.parent[rb] = ra;
        }
    }
}
45
46fn order_edge_from_prop(p: &AtomicName) -> Option<OrderEdge> {
47    match p.to_string().as_str() {
48        EQUAL => Some(OrderEdge::Eq),
49        LESS_EQUAL => Some(OrderEdge::Le),
50        LESS => Some(OrderEdge::Lt),
51        GREATER_EQUAL => Some(OrderEdge::Ge),
52        GREATER => Some(OrderEdge::Gt),
53        _ => None,
54    }
55}
56
57fn dedup_atomic_facts(mut facts: Vec<AtomicFact>) -> Vec<AtomicFact> {
58    let mut seen = HashSet::new();
59    facts.retain(|f| seen.insert(f.to_string()));
60    facts
61}
62
63impl ChainFact {
64    pub fn facts_with_order_transitive_closure(&self) -> Result<Vec<AtomicFact>, RuntimeError> {
65        let base = self.facts()?;
66        let n = self.objs.len();
67        if n < 2 {
68            return Ok(base);
69        }
70
71        let mut edges: Vec<OrderEdge> = Vec::with_capacity(self.prop_names.len());
72        for p in &self.prop_names {
73            let Some(e) = order_edge_from_prop(p) else {
74                return Ok(base);
75            };
76            edges.push(e);
77        }
78
79        let mut has_up = false;
80        let mut has_down = false;
81        for e in &edges {
82            match e {
83                OrderEdge::Le | OrderEdge::Lt => has_up = true,
84                OrderEdge::Ge | OrderEdge::Gt => has_down = true,
85                OrderEdge::Eq => {}
86            }
87        }
88        if has_up && has_down {
89            return Ok(base);
90        }
91
92        let polarity = if has_up {
93            ChainPolarity::Up
94        } else if has_down {
95            ChainPolarity::Down
96        } else {
97            ChainPolarity::Up
98        };
99
100        let mut uf = UnionFind::new(n);
101        for (k, e) in edges.iter().enumerate() {
102            if *e == OrderEdge::Eq {
103                uf.union(k, k + 1);
104            }
105        }
106
107        let mut quotient: Vec<usize> = Vec::new();
108        for i in 0..n {
109            if i == 0 || uf.find(i) != uf.find(i - 1) {
110                quotient.push(i);
111            }
112        }
113
114        let mut between_strict: Vec<bool> = Vec::new();
115        for k in 0..edges.len() {
116            let ca = uf.find(k);
117            let cb = uf.find(k + 1);
118            if ca != cb {
119                let strict = match polarity {
120                    ChainPolarity::Up => matches!(edges[k], OrderEdge::Lt),
121                    ChainPolarity::Down => matches!(edges[k], OrderEdge::Gt),
122                };
123                between_strict.push(strict);
124            }
125        }
126
127        if between_strict.len() + 1 != quotient.len() {
128            return Ok(base);
129        }
130
131        let lf = self.line_file.clone();
132        let mut extra: Vec<AtomicFact> = Vec::new();
133
134        let mut members: HashMap<usize, Vec<usize>> = HashMap::new();
135        for i in 0..n {
136            members.entry(uf.find(i)).or_default().push(i);
137        }
138        for mut indexes in members.into_values() {
139            if indexes.len() < 2 {
140                continue;
141            }
142            indexes.sort_unstable();
143            for ii in 0..indexes.len() {
144                for jj in ii + 1..indexes.len() {
145                    let i = indexes[ii];
146                    let j = indexes[jj];
147                    extra.push(
148                        EqualFact::new(self.objs[i].clone(), self.objs[j].clone(), lf.clone())
149                            .into(),
150                    );
151                }
152            }
153        }
154
155        for qi in 0..quotient.len() {
156            for qj in qi + 1..quotient.len() {
157                let path_strict = between_strict[qi..qj].iter().any(|&s| s);
158                let left = self.objs[quotient[qi]].clone();
159                let right = self.objs[quotient[qj]].clone();
160                let f = match polarity {
161                    ChainPolarity::Up => {
162                        if path_strict {
163                            LessFact::new(left, right, lf.clone()).into()
164                        } else {
165                            LessEqualFact::new(left, right, lf.clone()).into()
166                        }
167                    }
168                    ChainPolarity::Down => {
169                        if path_strict {
170                            GreaterFact::new(left, right, lf.clone()).into()
171                        } else {
172                            GreaterEqualFact::new(left, right, lf.clone()).into()
173                        }
174                    }
175                };
176                extra.push(f);
177            }
178        }
179
180        let mut all = base;
181        all.extend(extra);
182        Ok(dedup_atomic_facts(all))
183    }
184}