Skip to main content

litex/fact/
chain_fact_order_closure.rs

1use crate::prelude::*;
2use std::collections::{HashMap, HashSet};
3
/// A single link of an order chain, as produced by `order_edge_from_prop`.
///
/// Each variant corresponds to one predicate-name constant:
/// `Eq` ↔ `EQUAL`, `Le` ↔ `LESS_EQUAL`, `Lt` ↔ `LESS`,
/// `Ge` ↔ `GREATER_EQUAL`, `Gt` ↔ `GREATER`.
#[derive(Copy, Clone, Eq, PartialEq)]
enum OrderEdge {
    /// Link named by `EQUAL`.
    Eq,
    /// Link named by `LESS_EQUAL`.
    Le,
    /// Link named by `LESS`.
    Lt,
    /// Link named by `GREATER_EQUAL`.
    Ge,
    /// Link named by `GREATER`.
    Gt,
}
12
/// Direction of a non-mixed order chain.
///
/// `Up` is used for chains whose links are `Le`/`Lt` (or only `Eq`);
/// `Down` is used for chains whose links are `Ge`/`Gt`.
#[derive(Copy, Clone, Eq, PartialEq)]
enum ChainPolarity {
    /// Increasing direction (`Le` / `Lt` links).
    Up,
    /// Decreasing direction (`Ge` / `Gt` links).
    Down,
}
18
/// Disjoint-set forest over the indices `0..n`, with path compression.
struct UnionFind {
    /// `parent[x]` is the parent of `x`; a root is its own parent.
    parent: Vec<usize>,
}

impl UnionFind {
    /// Creates `n` singleton sets, one per index in `0..n`.
    fn new(n: usize) -> Self {
        UnionFind {
            parent: (0..n).collect(),
        }
    }

    /// Returns the root of the set containing `x`, compressing the path so
    /// every node visited afterwards points directly at that root.
    ///
    /// Implemented iteratively (two passes) so deep parent chains cannot
    /// overflow the call stack.
    ///
    /// # Panics
    ///
    /// Panics if `x` is not a valid index.
    fn find(&mut self, x: usize) -> usize {
        // Pass 1: walk up to the root.
        let mut root = x;
        while self.parent[root] != root {
            root = self.parent[root];
        }
        // Pass 2: re-point every node on the path directly at the root.
        let mut cur = x;
        while self.parent[cur] != root {
            let next = self.parent[cur];
            self.parent[cur] = root;
            cur = next;
        }
        root
    }

    /// Merges the sets containing `a` and `b`; `a`'s root becomes the root of
    /// the merged set.
    ///
    /// # Panics
    ///
    /// Panics if `a` or `b` is not a valid index.
    fn union(&mut self, a: usize, b: usize) {
        let ra = self.find(a);
        let rb = self.find(b);
        if ra != rb {
            self.parent[rb] = ra;
        }
    }
}
45
46fn order_edge_from_prop(p: &AtomicName) -> Option<OrderEdge> {
47    match p.to_string().as_str() {
48        EQUAL => Some(OrderEdge::Eq),
49        LESS_EQUAL => Some(OrderEdge::Le),
50        LESS => Some(OrderEdge::Lt),
51        GREATER_EQUAL => Some(OrderEdge::Ge),
52        GREATER => Some(OrderEdge::Gt),
53        _ => None,
54    }
55}
56
57fn subset_chain_prop(p: &AtomicName) -> Option<&'static str> {
58    match p.to_string().as_str() {
59        SUBSET => Some(SUBSET),
60        SUPERSET => Some(SUPERSET),
61        _ => None,
62    }
63}
64
65fn dedup_atomic_facts(mut facts: Vec<AtomicFact>) -> Vec<AtomicFact> {
66    let mut seen = HashSet::new();
67    facts.retain(|f| seen.insert(f.to_string()));
68    facts
69}
70
71impl ChainFact {
72    pub fn facts_with_order_transitive_closure(&self) -> Result<Vec<AtomicFact>, RuntimeError> {
73        let base = self.facts()?;
74        let n = self.objs.len();
75        if n < 2 {
76            return Ok(base);
77        }
78
79        if let Some(first_prop) = self.prop_names.first().and_then(subset_chain_prop) {
80            let all_same_subset_prop = self
81                .prop_names
82                .iter()
83                .all(|p| subset_chain_prop(p) == Some(first_prop));
84            if all_same_subset_prop {
85                let mut extra = Vec::new();
86                let lf = self.line_file.clone();
87                for i in 0..n {
88                    for j in i + 2..n {
89                        let fact: AtomicFact = if first_prop == SUBSET {
90                            SubsetFact::new(self.objs[i].clone(), self.objs[j].clone(), lf.clone())
91                                .into()
92                        } else {
93                            SupersetFact::new(
94                                self.objs[i].clone(),
95                                self.objs[j].clone(),
96                                lf.clone(),
97                            )
98                            .into()
99                        };
100                        extra.push(fact);
101                    }
102                }
103
104                let mut all = base;
105                all.extend(extra);
106                return Ok(dedup_atomic_facts(all));
107            }
108        }
109
110        let mut edges: Vec<OrderEdge> = Vec::with_capacity(self.prop_names.len());
111        for p in &self.prop_names {
112            let Some(e) = order_edge_from_prop(p) else {
113                return Ok(base);
114            };
115            edges.push(e);
116        }
117
118        let mut has_up = false;
119        let mut has_down = false;
120        for e in &edges {
121            match e {
122                OrderEdge::Le | OrderEdge::Lt => has_up = true,
123                OrderEdge::Ge | OrderEdge::Gt => has_down = true,
124                OrderEdge::Eq => {}
125            }
126        }
127        if has_up && has_down {
128            return Ok(base);
129        }
130
131        let polarity = if has_up {
132            ChainPolarity::Up
133        } else if has_down {
134            ChainPolarity::Down
135        } else {
136            ChainPolarity::Up
137        };
138
139        let mut uf = UnionFind::new(n);
140        for (k, e) in edges.iter().enumerate() {
141            if *e == OrderEdge::Eq {
142                uf.union(k, k + 1);
143            }
144        }
145
146        let mut quotient: Vec<usize> = Vec::new();
147        for i in 0..n {
148            if i == 0 || uf.find(i) != uf.find(i - 1) {
149                quotient.push(i);
150            }
151        }
152
153        let mut between_strict: Vec<bool> = Vec::new();
154        for k in 0..edges.len() {
155            let ca = uf.find(k);
156            let cb = uf.find(k + 1);
157            if ca != cb {
158                let strict = match polarity {
159                    ChainPolarity::Up => matches!(edges[k], OrderEdge::Lt),
160                    ChainPolarity::Down => matches!(edges[k], OrderEdge::Gt),
161                };
162                between_strict.push(strict);
163            }
164        }
165
166        if between_strict.len() + 1 != quotient.len() {
167            return Ok(base);
168        }
169
170        let lf = self.line_file.clone();
171        let mut extra: Vec<AtomicFact> = Vec::new();
172
173        let mut members: HashMap<usize, Vec<usize>> = HashMap::new();
174        for i in 0..n {
175            members.entry(uf.find(i)).or_default().push(i);
176        }
177        for mut indexes in members.into_values() {
178            if indexes.len() < 2 {
179                continue;
180            }
181            indexes.sort_unstable();
182            for ii in 0..indexes.len() {
183                for jj in ii + 1..indexes.len() {
184                    let i = indexes[ii];
185                    let j = indexes[jj];
186                    extra.push(
187                        EqualFact::new(self.objs[i].clone(), self.objs[j].clone(), lf.clone())
188                            .into(),
189                    );
190                }
191            }
192        }
193
194        for qi in 0..quotient.len() {
195            for qj in qi + 1..quotient.len() {
196                let path_strict = between_strict[qi..qj].iter().any(|&s| s);
197                let left = self.objs[quotient[qi]].clone();
198                let right = self.objs[quotient[qj]].clone();
199                let f = match polarity {
200                    ChainPolarity::Up => {
201                        if path_strict {
202                            LessFact::new(left, right, lf.clone()).into()
203                        } else {
204                            LessEqualFact::new(left, right, lf.clone()).into()
205                        }
206                    }
207                    ChainPolarity::Down => {
208                        if path_strict {
209                            GreaterFact::new(left, right, lf.clone()).into()
210                        } else {
211                            GreaterEqualFact::new(left, right, lf.clone()).into()
212                        }
213                    }
214                };
215                extra.push(f);
216            }
217        }
218
219        let mut all = base;
220        all.extend(extra);
221        Ok(dedup_atomic_facts(all))
222    }
223}