litex/fact/
chain_fact_order_closure.rs1use crate::prelude::*;
2use std::collections::{HashMap, HashSet};
3
/// One link of a chain, classified as an order relation between two neighbouring
/// objects (see `order_edge_from_prop` for the symbol each variant comes from).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum OrderEdge {
    /// `EQUAL` (`=`).
    Eq,
    /// `LESS_EQUAL` (`<=`).
    Le,
    /// `LESS` (`<`).
    Lt,
    /// `GREATER_EQUAL` (`>=`).
    Ge,
    /// `GREATER` (`>`).
    Gt,
}
12
/// Direction of a chain whose non-`=` links all point the same way.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum ChainPolarity {
    /// Ascending chain: links are `=`, `<=` or `<`.
    Up,
    /// Descending chain: links are `=`, `>=` or `>`.
    Down,
}
18
/// Disjoint-set forest over the indices `0..n`, with path compression.
struct UnionFind {
    /// `parent[i]` is the parent of `i`; a root is its own parent.
    parent: Vec<usize>,
}

impl UnionFind {
    /// Creates `n` singleton sets `{0}, {1}, …, {n - 1}`.
    fn new(n: usize) -> Self {
        UnionFind {
            parent: (0..n).collect(),
        }
    }

    /// Returns the root of the set containing `x`, re-pointing every node on the
    /// walked path directly at that root.
    ///
    /// Implemented iteratively: `union` does no balancing, so parent chains can grow
    /// as long as the number of elements, and a recursive walk could overflow the stack.
    ///
    /// # Panics
    ///
    /// Panics if `x` is not a valid index (`x >= n`).
    fn find(&mut self, x: usize) -> usize {
        // First pass: locate the root.
        let mut root = x;
        while self.parent[root] != root {
            root = self.parent[root];
        }
        // Second pass: path compression.
        let mut cur = x;
        while self.parent[cur] != root {
            let next = self.parent[cur];
            self.parent[cur] = root;
            cur = next;
        }
        root
    }

    /// Merges the sets containing `a` and `b`; the root of `a`'s set becomes the
    /// root of the merged set.
    fn union(&mut self, a: usize, b: usize) {
        let ra = self.find(a);
        let rb = self.find(b);
        if ra != rb {
            self.parent[rb] = ra;
        }
    }
}
45
46fn order_edge_from_prop(p: &IdentifierOrIdentifierWithMod) -> Option<OrderEdge> {
47 match p.to_string().as_str() {
48 EQUAL => Some(OrderEdge::Eq),
49 LESS_EQUAL => Some(OrderEdge::Le),
50 LESS => Some(OrderEdge::Lt),
51 GREATER_EQUAL => Some(OrderEdge::Ge),
52 GREATER => Some(OrderEdge::Gt),
53 _ => None,
54 }
55}
56
57fn dedup_atomic_facts(mut facts: Vec<AtomicFact>) -> Vec<AtomicFact> {
58 let mut seen = HashSet::new();
59 facts.retain(|f| seen.insert(f.to_string()));
60 facts
61}
62
63impl ChainFact {
64 pub fn facts_with_order_transitive_closure(&self) -> Result<Vec<AtomicFact>, RuntimeError> {
65 let base = self.facts()?;
66 let n = self.objs.len();
67 if n < 2 {
68 return Ok(base);
69 }
70
71 let mut edges: Vec<OrderEdge> = Vec::with_capacity(self.prop_names.len());
72 for p in &self.prop_names {
73 let Some(e) = order_edge_from_prop(p) else {
74 return Ok(base);
75 };
76 edges.push(e);
77 }
78
79 let mut has_up = false;
80 let mut has_down = false;
81 for e in &edges {
82 match e {
83 OrderEdge::Le | OrderEdge::Lt => has_up = true,
84 OrderEdge::Ge | OrderEdge::Gt => has_down = true,
85 OrderEdge::Eq => {}
86 }
87 }
88 if has_up && has_down {
89 return Ok(base);
90 }
91
92 let polarity = if has_up {
93 ChainPolarity::Up
94 } else if has_down {
95 ChainPolarity::Down
96 } else {
97 ChainPolarity::Up
98 };
99
100 let mut uf = UnionFind::new(n);
101 for (k, e) in edges.iter().enumerate() {
102 if *e == OrderEdge::Eq {
103 uf.union(k, k + 1);
104 }
105 }
106
107 let mut quotient: Vec<usize> = Vec::new();
108 for i in 0..n {
109 if i == 0 || uf.find(i) != uf.find(i - 1) {
110 quotient.push(i);
111 }
112 }
113
114 let mut between_strict: Vec<bool> = Vec::new();
115 for k in 0..edges.len() {
116 let ca = uf.find(k);
117 let cb = uf.find(k + 1);
118 if ca != cb {
119 let strict = match polarity {
120 ChainPolarity::Up => matches!(edges[k], OrderEdge::Lt),
121 ChainPolarity::Down => matches!(edges[k], OrderEdge::Gt),
122 };
123 between_strict.push(strict);
124 }
125 }
126
127 if between_strict.len() + 1 != quotient.len() {
128 return Ok(base);
129 }
130
131 let lf = self.line_file.clone();
132 let mut extra: Vec<AtomicFact> = Vec::new();
133
134 let mut members: HashMap<usize, Vec<usize>> = HashMap::new();
135 for i in 0..n {
136 members.entry(uf.find(i)).or_default().push(i);
137 }
138 for mut idxs in members.into_values() {
139 if idxs.len() < 2 {
140 continue;
141 }
142 idxs.sort_unstable();
143 for ii in 0..idxs.len() {
144 for jj in ii + 1..idxs.len() {
145 let i = idxs[ii];
146 let j = idxs[jj];
147 extra.push(
148 EqualFact::new(self.objs[i].clone(), self.objs[j].clone(), lf.clone())
149 .into(),
150 );
151 }
152 }
153 }
154
155 for qi in 0..quotient.len() {
156 for qj in qi + 1..quotient.len() {
157 let path_strict = between_strict[qi..qj].iter().any(|&s| s);
158 let left = self.objs[quotient[qi]].clone();
159 let right = self.objs[quotient[qj]].clone();
160 let f = match polarity {
161 ChainPolarity::Up => {
162 if path_strict {
163 LessFact::new(left, right, lf.clone()).into()
164 } else {
165 LessEqualFact::new(left, right, lf.clone()).into()
166 }
167 }
168 ChainPolarity::Down => {
169 if path_strict {
170 GreaterFact::new(left, right, lf.clone()).into()
171 } else {
172 GreaterEqualFact::new(left, right, lf.clone()).into()
173 }
174 }
175 };
176 extra.push(f);
177 }
178 }
179
180 let mut all = base;
181 all.extend(extra);
182 Ok(dedup_atomic_facts(all))
183 }
184}
185
#[cfg(test)]
mod tests {
    use super::*;

    fn id(name: &str) -> Obj {
        Identifier::mk(name.to_string())
    }

    fn prop(s: &str) -> IdentifierOrIdentifierWithMod {
        IdentifierOrIdentifierWithMod::Identifier(Identifier::new(s.to_string()))
    }

    /// Builds a chain from object names and link symbols and returns its closure,
    /// each fact rendered with `to_string`.
    fn closure_of(objs: &[&str], props: &[&str]) -> Vec<String> {
        let objs: Vec<Obj> = objs.iter().map(|&name| id(name)).collect();
        let props: Vec<IdentifierOrIdentifierWithMod> =
            props.iter().map(|&symbol| prop(symbol)).collect();
        let chain = ChainFact::new(objs, props, default_line_file());
        chain
            .facts_with_order_transitive_closure()
            .unwrap()
            .iter()
            .map(|f| f.to_string())
            .collect()
    }

    #[test]
    fn eq_chain_adds_all_pairs_between_equal_nodes() {
        let displayed = closure_of(&["x", "y", "z"], &["=", "="]);
        // `x = y` and `y = z` are already base facts; only the closure relates x and z.
        assert!(
            displayed
                .iter()
                .any(|s| s.contains('x') && s.contains('z') && s.contains('=')),
            "expected x = z from equality clique, got {:?}",
            displayed
        );
    }

    #[test]
    fn le_eq_lt_chain_adds_transitive_facts() {
        let displayed = closure_of(&["a", "b", "c", "d"], &["<=", "=", "<"]);
        assert!(displayed.iter().any(|s| s.contains('a') && s.contains('d')));
        assert!(
            displayed.len() >= 5,
            "expected transitive closure beyond three adjacent facts, got {:?}",
            displayed
        );
    }

    #[test]
    fn mixed_up_down_chain_skips_closure_beyond_base() {
        let displayed = closure_of(&["a", "b", "c"], &["<=", ">"]);
        assert_eq!(displayed.len(), 2);
    }

    #[test]
    fn ge_eq_gt_chain_adds_transitive_greater_facts() {
        let displayed = closure_of(&["d", "c", "b", "a"], &[">=", "=", ">"]);
        assert!(displayed.iter().any(|s| s.contains('d') && s.contains('a')));
        assert!(displayed.len() >= 5);
    }
}