// litex/fact/chain_fact_order_closure.rs
use crate::prelude::*;
2use std::collections::{HashMap, HashSet};
3
/// Relation carried by a single link of a chain, as classified by
/// `order_edge_from_prop`.
#[derive(Copy, Clone, Eq, PartialEq)]
enum OrderEdge {
    /// Equality link.
    Eq,
    /// Non-strict "less than or equal" link.
    Le,
    /// Strict "less than" link.
    Lt,
    /// Non-strict "greater than or equal" link.
    Ge,
    /// Strict "greater than" link.
    Gt,
}
12
/// Common direction shared by all non-equality links of a chain.
#[derive(Copy, Clone, Eq, PartialEq)]
enum ChainPolarity {
    /// The chain is built from `Le` / `Lt` links (or from equalities only).
    Up,
    /// The chain is built from `Ge` / `Gt` links.
    Down,
}
18
/// Disjoint-set forest over the indices `0..n`.
struct UnionFind {
    /// `parent[i]` is the parent of node `i`; a root is its own parent.
    parent: Vec<usize>,
}

impl UnionFind {
    /// Creates `n` singleton sets, one for each index in `0..n`.
    fn new(n: usize) -> Self {
        let mut parent = Vec::with_capacity(n);
        parent.extend(0..n);
        Self { parent }
    }

    /// Returns the root of the set containing `x` and re-points every node
    /// visited on the way directly at that root.
    ///
    /// Panics if `x` is not a valid index.
    fn find(&mut self, x: usize) -> usize {
        // First pass: climb until we reach a node that is its own parent.
        let mut root = x;
        while self.parent[root] != root {
            root = self.parent[root];
        }

        // Second pass: attach every node on the walked path to the root.
        let mut node = x;
        while node != root {
            let next = self.parent[node];
            self.parent[node] = root;
            node = next;
        }

        root
    }

    /// Merges the sets containing `a` and `b`; the root of `a`'s set becomes
    /// the root of the merged set.
    fn union(&mut self, a: usize, b: usize) {
        let (root_a, root_b) = (self.find(a), self.find(b));
        if root_a == root_b {
            return;
        }
        self.parent[root_b] = root_a;
    }
}
45
46fn order_edge_from_prop(p: &AtomicName) -> Option<OrderEdge> {
47 match p.to_string().as_str() {
48 EQUAL => Some(OrderEdge::Eq),
49 LESS_EQUAL => Some(OrderEdge::Le),
50 LESS => Some(OrderEdge::Lt),
51 GREATER_EQUAL => Some(OrderEdge::Ge),
52 GREATER => Some(OrderEdge::Gt),
53 _ => None,
54 }
55}
56
57fn dedup_atomic_facts(mut facts: Vec<AtomicFact>) -> Vec<AtomicFact> {
58 let mut seen = HashSet::new();
59 facts.retain(|f| seen.insert(f.to_string()));
60 facts
61}
62
63impl ChainFact {
64 pub fn facts_with_order_transitive_closure(&self) -> Result<Vec<AtomicFact>, RuntimeError> {
65 let base = self.facts()?;
66 let n = self.objs.len();
67 if n < 2 {
68 return Ok(base);
69 }
70
71 let mut edges: Vec<OrderEdge> = Vec::with_capacity(self.prop_names.len());
72 for p in &self.prop_names {
73 let Some(e) = order_edge_from_prop(p) else {
74 return Ok(base);
75 };
76 edges.push(e);
77 }
78
79 let mut has_up = false;
80 let mut has_down = false;
81 for e in &edges {
82 match e {
83 OrderEdge::Le | OrderEdge::Lt => has_up = true,
84 OrderEdge::Ge | OrderEdge::Gt => has_down = true,
85 OrderEdge::Eq => {}
86 }
87 }
88 if has_up && has_down {
89 return Ok(base);
90 }
91
92 let polarity = if has_up {
93 ChainPolarity::Up
94 } else if has_down {
95 ChainPolarity::Down
96 } else {
97 ChainPolarity::Up
98 };
99
100 let mut uf = UnionFind::new(n);
101 for (k, e) in edges.iter().enumerate() {
102 if *e == OrderEdge::Eq {
103 uf.union(k, k + 1);
104 }
105 }
106
107 let mut quotient: Vec<usize> = Vec::new();
108 for i in 0..n {
109 if i == 0 || uf.find(i) != uf.find(i - 1) {
110 quotient.push(i);
111 }
112 }
113
114 let mut between_strict: Vec<bool> = Vec::new();
115 for k in 0..edges.len() {
116 let ca = uf.find(k);
117 let cb = uf.find(k + 1);
118 if ca != cb {
119 let strict = match polarity {
120 ChainPolarity::Up => matches!(edges[k], OrderEdge::Lt),
121 ChainPolarity::Down => matches!(edges[k], OrderEdge::Gt),
122 };
123 between_strict.push(strict);
124 }
125 }
126
127 if between_strict.len() + 1 != quotient.len() {
128 return Ok(base);
129 }
130
131 let lf = self.line_file.clone();
132 let mut extra: Vec<AtomicFact> = Vec::new();
133
134 let mut members: HashMap<usize, Vec<usize>> = HashMap::new();
135 for i in 0..n {
136 members.entry(uf.find(i)).or_default().push(i);
137 }
138 for mut indexes in members.into_values() {
139 if indexes.len() < 2 {
140 continue;
141 }
142 indexes.sort_unstable();
143 for ii in 0..indexes.len() {
144 for jj in ii + 1..indexes.len() {
145 let i = indexes[ii];
146 let j = indexes[jj];
147 extra.push(
148 EqualFact::new(self.objs[i].clone(), self.objs[j].clone(), lf.clone())
149 .into(),
150 );
151 }
152 }
153 }
154
155 for qi in 0..quotient.len() {
156 for qj in qi + 1..quotient.len() {
157 let path_strict = between_strict[qi..qj].iter().any(|&s| s);
158 let left = self.objs[quotient[qi]].clone();
159 let right = self.objs[quotient[qj]].clone();
160 let f = match polarity {
161 ChainPolarity::Up => {
162 if path_strict {
163 LessFact::new(left, right, lf.clone()).into()
164 } else {
165 LessEqualFact::new(left, right, lf.clone()).into()
166 }
167 }
168 ChainPolarity::Down => {
169 if path_strict {
170 GreaterFact::new(left, right, lf.clone()).into()
171 } else {
172 GreaterEqualFact::new(left, right, lf.clone()).into()
173 }
174 }
175 };
176 extra.push(f);
177 }
178 }
179
180 let mut all = base;
181 all.extend(extra);
182 Ok(dedup_atomic_facts(all))
183 }
184}