1use std::collections::{HashMap, HashSet};
2use std::fmt::Debug;
3use std::hash::Hash;
4
5use crate::graph::{DepGraph, GraphNode};
6use crate::{dep::*, Fact};
7
8#[derive(Debug, derive_more::From)]
9pub struct TraversalError<'c, F: Fact> {
10 pub inner: TraversalInnerError<F>,
11 pub graph: DepGraph<'c, F>,
12}
13
14#[derive(Debug, derive_more::From)]
15pub enum TraversalInnerError<F: Fact> {
16 Dep(DepError<F>),
17 }
20
21#[derive(Debug, derive_more::From)]
22pub struct Traversal<'c, T: Fact> {
23 pub(crate) root_check_passed: bool,
24 pub(crate) graph: DepGraph<'c, T>,
25 pub(crate) terminals: HashSet<Dep<T>>,
26 pub(crate) ctx: &'c T::Context,
27}
28
29impl<T: Fact> Traversal<'_, T> {}
30
31pub type TraversalResult<'c, F> = Result<Traversal<'c, F>, TraversalError<'c, F>>;
32
/// Selects which check outcome makes a fact a terminal of the traversal.
///
/// The default mode is [`TraversalMode::TraverseFails`].
#[derive(Debug, Clone, Copy, Default)]
pub enum TraversalMode {
    /// Keep traversing through facts whose check fails; a fact whose check
    /// passes terminates the traversal.
    #[default]
    TraverseFails,
    /// Keep traversing through facts whose check passes; a fact whose check
    /// fails terminates the traversal.
    TraversePasses,
}

impl TraversalMode {
    /// Returns the check result at which a fact is treated as terminal in
    /// this mode: `true` for `TraverseFails`, `false` for `TraversePasses`.
    pub fn terminal_check_value(&self) -> bool {
        // Exhaustive on purpose: adding a variant must force a decision here.
        match self {
            TraversalMode::TraverseFails => true,
            TraversalMode::TraversePasses => false,
        }
    }
}
58
59#[derive(Clone, Debug, PartialEq, Eq)]
60pub enum TraversalStep<T: Fact> {
61 Terminate,
63 Continue(Vec<Dep<T>>),
65}
66
67impl<T: Fact> TraversalStep<T> {
68 pub fn is_pass(&self) -> bool {
69 matches!(self, TraversalStep::Terminate)
70 }
71}
72
73pub type TraversalMap<T> = HashMap<Dep<T>, Option<TraversalStep<T>>>;
74
75#[cfg_attr(feature = "instrument", tracing::instrument(skip(ctx)))]
86pub fn traverse<F: Fact>(fact: F, ctx: &F::Context) -> TraversalResult<F> {
87 let mut table = TraversalMap::default();
88
89 let root_check_passed = fact.check(ctx);
90 let mode = if root_check_passed {
91 TraversalMode::TraversePasses
92 } else {
93 TraversalMode::TraverseFails
94 };
95
96 let res = traverse_fact(&fact, ctx, &mut table, mode);
97 let dep = Dep::from(fact);
98
99 match res {
100 Ok(check) => {
101 table.insert(dep.clone(), Some(check.clone()));
102 let (graph, terminals) = produce_graph(&table, &dep, ctx);
103
104 Ok(Traversal {
105 root_check_passed,
106 graph,
107 terminals,
108 ctx,
109 })
110 }
111 Err(inner) => {
112 table.insert(
113 dep.clone(),
114 Some(TraversalStep::Continue(vec![dep.clone()])),
115 );
116 let (graph, _) = produce_graph(&table, &dep, ctx);
117
118 Err(TraversalError { graph, inner })
119 }
120 }
121}
122
123#[cfg_attr(feature = "instrument", tracing::instrument(skip(ctx, table)))]
124fn traverse_inner<F: Fact>(
125 dep: &Dep<F>,
126 ctx: &F::Context,
127 table: &mut TraversalMap<F>,
128 mode: TraversalMode,
129) -> Result<Option<TraversalStep<F>>, TraversalInnerError<F>> {
130 tracing::trace!("enter {:?}", dep);
131
132 match table.get(dep) {
133 None => {
134 tracing::trace!("marked visited");
135 table.insert(dep.clone(), None);
137 }
138 Some(None) => {
139 tracing::trace!("loop encountered");
140 return Ok(None);
144 }
145 Some(Some(check)) => {
146 tracing::trace!("return cached: {:?}", check);
147 return Ok(Some(check.clone()));
148 }
149 }
150
151 #[allow(clippy::type_complexity)]
152 let mut recursive_checks =
153 |cs: &[Dep<F>]| -> Result<Vec<(Dep<F>, TraversalStep<F>)>, TraversalInnerError<F>> {
154 let mut checks = vec![];
155 for c in cs {
156 if let Some(check) = traverse_inner(c, ctx, table, mode)? {
157 checks.push((c.clone(), check));
158 }
159 }
160 Ok(checks)
161 };
162
163 let check = match dep {
164 Dep::Fact(f) => {
165 let terminate = f.check(ctx) == mode.terminal_check_value();
166 if terminate {
167 tracing::trace!("fact terminate");
168 TraversalStep::Terminate
169 } else {
170 traverse_fact(f, ctx, table, mode)?
171 }
172 }
173 Dep::Any(_, cs) => {
174 let checks = recursive_checks(cs).map_err(|err| {
175 tracing::error!("traversal ending due to error: {err:?}");
177 table.insert(dep.clone(), Some(TraversalStep::Continue(cs.clone())));
178 err
179 })?;
180 tracing::trace!("Any. checks: {:?}", checks);
181 if checks.is_empty() {
182 tracing::debug!("All loops");
184 return Ok(None);
185 }
186 let num_checks = checks.len();
187 let fails: Vec<_> = checks
188 .into_iter()
189 .filter_map(|(dep, check)| (!check.is_pass()).then_some(dep))
190 .collect();
191 tracing::trace!("Any. fails: {:?}", fails);
192 if fails.len() < num_checks {
193 TraversalStep::Terminate
194 } else {
195 TraversalStep::Continue(fails)
196 }
197 }
198 Dep::Every(_, cs) => {
199 let checks = recursive_checks(cs).map_err(|err| {
200 tracing::error!("traversal ending due to error: {err:?}");
202 table.insert(dep.clone(), Some(TraversalStep::Continue(cs.clone())));
203 err
204 })?;
205
206 tracing::trace!("Every. checks: {:?}", checks);
207 if checks.is_empty() {
208 tracing::debug!("All loops");
210 return Ok(None);
211 }
212 let fails = checks.iter().filter(|(_, check)| !check.is_pass()).count();
213 let deps: Vec<_> = checks.into_iter().map(|(dep, _)| dep).collect();
214 tracing::trace!("Every. num fails: {}", fails);
215 if fails == 0 {
216 TraversalStep::Terminate
217 } else {
218 TraversalStep::Continue(deps)
219 }
220 }
221 };
222 table.insert(dep.clone(), Some(check.clone()));
223 tracing::trace!("exit. check: {:?}", check);
224 Ok(Some(check))
225}
226
227#[cfg_attr(feature = "instrument", tracing::instrument(skip(ctx, table)))]
228fn traverse_fact<F: Fact>(
229 fact: &F,
230 ctx: &F::Context,
231 table: &mut TraversalMap<F>,
232 mode: TraversalMode,
233) -> Result<TraversalStep<F>, TraversalInnerError<F>> {
234 if let Some(sub_dep) = fact.dep(ctx)? {
235 tracing::trace!("traversing fact");
236
237 let check = traverse_inner(&sub_dep, ctx, table, mode).map_err(|err| {
238 table.insert(
240 Dep::from(fact.clone()),
241 Some(TraversalStep::Continue(vec![sub_dep.clone()])),
242 );
243 tracing::error!("traversal ending due to error: {err:?}");
244 err
245 })?;
246 tracing::trace!("traversal done, check: {:?}", check);
247 Ok(TraversalStep::Continue(vec![sub_dep]))
248 } else {
249 tracing::trace!("fact fail with no dep, terminating");
250 Ok(TraversalStep::Continue(vec![]))
251 }
252}
253
254#[allow(clippy::type_complexity)]
260fn prune_traversal<'a, 'b: 'a, T: Fact + Eq + Hash>(
261 table: &'a TraversalMap<T>,
262 start: &'b Dep<T>,
263) -> (HashMap<&'a Dep<T>, &'a [Dep<T>]>, Vec<&'a Dep<T>>) {
264 let mut sub = HashMap::<&Dep<T>, &[Dep<T>]>::new();
265 let mut terminals = vec![];
266 let mut to_add = vec![start];
267
268 while let Some(next) = to_add.pop() {
269 if let Some(step) = table.get(next) {
270 match step.as_ref() {
271 Some(TraversalStep::Continue(deps)) => {
272 let old = sub.insert(next, deps.as_slice());
273 if let Some(old) = old {
274 assert_eq!(
275 old, deps,
276 "Looped back to same node, but with different children?"
277 );
278 } else {
279 to_add.extend(deps.iter());
280 }
281 }
282 Some(TraversalStep::Terminate) => {
283 terminals.push(next);
284 }
285 None => {}
286 }
287 } else {
288 sub.insert(next, &[]);
291 }
292 }
293 (sub, terminals)
294}
295
296pub fn produce_graph<'a, 'b: 'a, 'c, T: Fact + Eq + Hash>(
297 table: &'a TraversalMap<T>,
298 start: &'b Dep<T>,
299 ctx: &'c T::Context,
300) -> (DepGraph<'c, T>, HashSet<Dep<T>>) {
301 let mut g = DepGraph::default();
302
303 let (sub, passes) = prune_traversal(table, start);
304
305 let rows: HashSet<_> = sub.into_iter().collect();
306 let mut nodemap = HashMap::new();
307 for (i, (k, _)) in rows.iter().enumerate() {
308 let id = g.add_node(GraphNode {
309 dep: (*k).to_owned(),
310 ctx,
311 });
312 nodemap.insert(k, id);
313 assert_eq!(id.index(), i);
314 }
315
316 for (k, v) in rows.iter() {
317 for c in v.iter() {
318 if let (Some(k), Some(c)) = (nodemap.get(k), nodemap.get(&c)) {
319 g.add_edge(*k, *c, ());
320 }
321 }
322 }
323
324 (g, passes.into_iter().cloned().collect())
325}