litex/fact/
chain_fact_order_closure.rs1use crate::common::keywords::{EQUAL, GREATER, GREATER_EQUAL, LESS, LESS_EQUAL};
2use crate::fact::matchable_fact_with_atomic_fact_inside::ChainFact;
3use crate::prelude::*;
4use std::collections::{HashMap, HashSet};
5
/// The order relation written between two neighbouring objects of a chain.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum OrderEdge {
    /// `=`
    Eq,
    /// `<=`
    Le,
    /// `<`
    Lt,
    /// `>=`
    Ge,
    /// `>`
    Gt,
}
14
/// Direction of a chain whose non-equality operators all point the same way:
/// `Up` for `<`/`<=` chains, `Down` for `>`/`>=` chains.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum ChainPolarity {
    /// Ascending: operators drawn from `=`, `<=`, `<`.
    Up,
    /// Descending: operators drawn from `=`, `>=`, `>`.
    Down,
}
20
/// Disjoint-set forest over the indices `0..n`.
struct UnionFind {
    /// `parent[x]` is the next node on the path from `x` to its set's root;
    /// a root is its own parent.
    parent: Vec<usize>,
}

impl UnionFind {
    /// Creates `n` singleton sets, one per index in `0..n`.
    fn new(n: usize) -> Self {
        UnionFind {
            parent: (0..n).collect(),
        }
    }

    /// Returns the root of the set containing `x`, re-pointing every node on
    /// the visited path directly at that root (path compression).
    ///
    /// Implemented iteratively: `union` does no balancing, so a parent chain
    /// can be as long as `n`, and a recursive walk could overflow the stack.
    ///
    /// # Panics
    /// Panics if `x` is not a valid index.
    fn find(&mut self, x: usize) -> usize {
        // First pass: walk up to the root.
        let mut root = x;
        while self.parent[root] != root {
            root = self.parent[root];
        }
        // Second pass: point every node on the path straight at the root.
        let mut node = x;
        while self.parent[node] != root {
            let next = self.parent[node];
            self.parent[node] = root;
            node = next;
        }
        root
    }

    /// Merges the sets containing `a` and `b`; the root of `a`'s set becomes
    /// the root of the merged set. No-op if they are already in the same set.
    ///
    /// # Panics
    /// Panics if `a` or `b` is not a valid index.
    fn union(&mut self, a: usize, b: usize) {
        let ra = self.find(a);
        let rb = self.find(b);
        if ra != rb {
            self.parent[rb] = ra;
        }
    }
}
47
48fn order_edge_from_prop(p: &IdentifierOrIdentifierWithMod) -> Option<OrderEdge> {
49 match p.to_string().as_str() {
50 EQUAL => Some(OrderEdge::Eq),
51 LESS_EQUAL => Some(OrderEdge::Le),
52 LESS => Some(OrderEdge::Lt),
53 GREATER_EQUAL => Some(OrderEdge::Ge),
54 GREATER => Some(OrderEdge::Gt),
55 _ => None,
56 }
57}
58
59fn dedup_atomic_facts(mut facts: Vec<AtomicFact>) -> Vec<AtomicFact> {
60 let mut seen = HashSet::new();
61 facts.retain(|f| seen.insert(f.to_string()));
62 facts
63}
64
65impl ChainFact {
66 pub fn facts_with_order_transitive_closure(&self) -> Result<Vec<AtomicFact>, RuntimeErrorStruct> {
67 let base = self.facts()?;
68 let n = self.objs.len();
69 if n < 2 {
70 return Ok(base);
71 }
72
73 let mut edges: Vec<OrderEdge> = Vec::with_capacity(self.prop_names.len());
74 for p in &self.prop_names {
75 let Some(e) = order_edge_from_prop(p) else {
76 return Ok(base);
77 };
78 edges.push(e);
79 }
80
81 let mut has_up = false;
82 let mut has_down = false;
83 for e in &edges {
84 match e {
85 OrderEdge::Le | OrderEdge::Lt => has_up = true,
86 OrderEdge::Ge | OrderEdge::Gt => has_down = true,
87 OrderEdge::Eq => {}
88 }
89 }
90 if has_up && has_down {
91 return Ok(base);
92 }
93
94 let polarity = if has_up {
95 ChainPolarity::Up
96 } else if has_down {
97 ChainPolarity::Down
98 } else {
99 ChainPolarity::Up
100 };
101
102 let mut uf = UnionFind::new(n);
103 for (k, e) in edges.iter().enumerate() {
104 if *e == OrderEdge::Eq {
105 uf.union(k, k + 1);
106 }
107 }
108
109 let mut quotient: Vec<usize> = Vec::new();
110 for i in 0..n {
111 if i == 0 || uf.find(i) != uf.find(i - 1) {
112 quotient.push(i);
113 }
114 }
115
116 let mut between_strict: Vec<bool> = Vec::new();
117 for k in 0..edges.len() {
118 let ca = uf.find(k);
119 let cb = uf.find(k + 1);
120 if ca != cb {
121 let strict = match polarity {
122 ChainPolarity::Up => matches!(edges[k], OrderEdge::Lt),
123 ChainPolarity::Down => matches!(edges[k], OrderEdge::Gt),
124 };
125 between_strict.push(strict);
126 }
127 }
128
129 if between_strict.len() + 1 != quotient.len() {
130 return Ok(base);
131 }
132
133 let lf = self.line_file.clone();
134 let mut extra: Vec<AtomicFact> = Vec::new();
135
136 let mut members: HashMap<usize, Vec<usize>> = HashMap::new();
137 for i in 0..n {
138 members.entry(uf.find(i)).or_default().push(i);
139 }
140 for mut idxs in members.into_values() {
141 if idxs.len() < 2 {
142 continue;
143 }
144 idxs.sort_unstable();
145 let rep = idxs[0];
146 for &j in idxs.iter().skip(1) {
147 extra.push(AtomicFact::EqualFact(EqualFact::new(
148 self.objs[rep].clone(),
149 self.objs[j].clone(),
150 lf.clone(),
151 )));
152 }
153 }
154
155 for qi in 0..quotient.len() {
156 for qj in qi + 1..quotient.len() {
157 let path_strict = between_strict[qi..qj].iter().any(|&s| s);
158 let left = self.objs[quotient[qi]].clone();
159 let right = self.objs[quotient[qj]].clone();
160 let f = match polarity {
161 ChainPolarity::Up => {
162 if path_strict {
163 AtomicFact::LessFact(LessFact::new(left, right, lf.clone()))
164 } else {
165 AtomicFact::LessEqualFact(LessEqualFact::new(left, right, lf.clone()))
166 }
167 }
168 ChainPolarity::Down => {
169 if path_strict {
170 AtomicFact::GreaterFact(GreaterFact::new(left, right, lf.clone()))
171 } else {
172 AtomicFact::GreaterEqualFact(GreaterEqualFact::new(
173 left, right, lf.clone(),
174 ))
175 }
176 }
177 };
178 extra.push(f);
179 }
180 }
181
182 let mut all = base;
183 all.extend(extra);
184 Ok(dedup_atomic_facts(all))
185 }
186}
187
#[cfg(test)]
mod tests {
    use super::*;

    /// Identifier object with the given name.
    fn id(name: &str) -> Obj {
        Identifier::mk(name.to_string())
    }

    /// Plain (unmodded) chain operator spelled `s`, e.g. `"<="`.
    fn prop(s: &str) -> IdentifierOrIdentifierWithMod {
        let ident = Identifier::new(s.to_string());
        IdentifierOrIdentifierWithMod::Identifier(ident)
    }

    /// String rendering of each fact, in order.
    fn rendered(facts: &[AtomicFact]) -> Vec<String> {
        facts.iter().map(|fact| fact.to_string()).collect()
    }

    /// Whether any rendered fact mentions both `x` and `y`.
    fn some_fact_mentions(lines: &[String], x: &str, y: &str) -> bool {
        lines.iter().any(|line| line.contains(x) && line.contains(y))
    }

    #[test]
    fn le_eq_lt_chain_adds_transitive_facts() {
        let chain = ChainFact::new(
            vec![id("a"), id("b"), id("c"), id("d")],
            vec![prop("<="), prop("="), prop("<")],
            default_line_file(),
        );
        let closure = chain.facts_with_order_transitive_closure().unwrap();
        let lines = rendered(&closure);
        assert!(some_fact_mentions(&lines, "a", "d"));
        assert!(
            closure.len() >= 5,
            "expected transitive closure beyond three adjacent facts, got {:?}",
            lines
        );
    }

    #[test]
    fn mixed_up_down_chain_skips_closure_beyond_base() {
        let chain = ChainFact::new(
            vec![id("a"), id("b"), id("c")],
            vec![prop("<="), prop(">")],
            default_line_file(),
        );
        let closure = chain.facts_with_order_transitive_closure().unwrap();
        assert_eq!(closure.len(), 2);
    }

    #[test]
    fn ge_eq_gt_chain_adds_transitive_greater_facts() {
        let chain = ChainFact::new(
            vec![id("d"), id("c"), id("b"), id("a")],
            vec![prop(">="), prop("="), prop(">")],
            default_line_file(),
        );
        let closure = chain.facts_with_order_transitive_closure().unwrap();
        let lines = rendered(&closure);
        assert!(some_fact_mentions(&lines, "d", "a"));
        assert!(closure.len() >= 5);
    }
}
243}