aitia/
traversal.rs

1use std::collections::{HashMap, HashSet};
2use std::fmt::Debug;
3use std::hash::Hash;
4
5use crate::graph::{DepGraph, GraphNode};
6use crate::{dep::*, Fact};
7
/// Error returned by [`traverse`]: the underlying failure together with the
/// graph that could be produced from whatever had been recorded in the
/// traversal table at the time of the failure.
#[derive(Debug, derive_more::From)]
pub struct TraversalError<'c, F: Fact> {
    /// The failure which interrupted the traversal.
    pub inner: TraversalInnerError<F>,
    /// Graph produced from the traversal table after the failure.
    pub graph: DepGraph<'c, F>,
}
13
/// The kinds of failure that can interrupt a traversal.
///
/// The `From` derive lets a `DepError` be converted with `?`, which is how
/// `traverse_fact` propagates a failure from `fact.dep(ctx)`.
#[derive(Debug, derive_more::From)]
pub enum TraversalInnerError<F: Fact> {
    /// Wraps a `DepError`.
    Dep(DepError<F>),
    // TODO: eventually allow errors in checks
    // Check(CheckError<F>),
}
20
/// The successful outcome of [`traverse`].
#[derive(Debug, derive_more::From)]
pub struct Traversal<'c, T: Fact> {
    /// Result of calling `check` on the root fact before the traversal began.
    pub(crate) root_check_passed: bool,
    /// Graph built by [`produce_graph`] from the recorded traversal steps.
    pub(crate) graph: DepGraph<'c, T>,
    /// Deps reachable from the root whose recorded step was
    /// [`TraversalStep::Terminate`].
    pub(crate) terminals: HashSet<Dep<T>>,
    /// The context the traversal was evaluated against.
    pub(crate) ctx: &'c T::Context,
}

// No inherent methods are defined for `Traversal` in this block.
impl<T: Fact> Traversal<'_, T> {}
30
/// Result of [`traverse`]: a [`Traversal`] on success, or a
/// [`TraversalError`] (which still carries a graph) on failure.
pub type TraversalResult<'c, F> = Result<Traversal<'c, F>, TraversalError<'c, F>>;
32
/// Different modes of traversing the graph.
///
/// [`traverse`] picks the mode from the root fact's check: a passing root is
/// traversed with [`TraversalMode::TraversePasses`], a failing root with
/// [`TraversalMode::TraverseFails`].
#[derive(Debug, Clone, Copy, Default)]
pub enum TraversalMode {
    /// The default mode, which terminates traversal along a branch whenever a true fact is encountered.
    #[default]
    TraverseFails,
    /// Traverses the entire graph, expecting the entire traversal to consist of true facts.
    /// Useful for self-checking your model by running it against scenarios which are known to succeed.
    TraversePasses,
}

impl TraversalMode {
    /// When traversing in this mode, when a Check comes back with this value, terminate that branch.
    ///
    /// Returns `true` for [`TraversalMode::TraverseFails`] (a branch ends at
    /// the first fact whose check is true) and `false` for
    /// [`TraversalMode::TraversePasses`] (a branch ends at the first fact
    /// whose check is false).
    pub fn terminal_check_value(&self) -> bool {
        // Exhaustive on purpose: adding a variant must force a decision here.
        match self {
            TraversalMode::TraverseFails => true,
            TraversalMode::TraversePasses => false,
        }
    }
}
58
/// The outcome recorded for a single dep once it has been evaluated.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum TraversalStep<T: Fact> {
    /// This node terminates the traversal due to its check() status.
    Terminate,
    /// The traversal should continue with the following nodes.
    /// The list may be empty (a fact with no dep records `Continue(vec![])`).
    Continue(Vec<Dep<T>>),
}
66
67impl<T: Fact> TraversalStep<T> {
68    pub fn is_pass(&self) -> bool {
69        matches!(self, TraversalStep::Terminate)
70    }
71}
72
/// Memo table used during traversal.
///
/// A key mapped to `None` has been visited but not (yet) determined; a key
/// mapped to `Some(step)` has a recorded [`TraversalStep`]; an absent key has
/// not been visited.
pub type TraversalMap<T> = HashMap<Dep<T>, Option<TraversalStep<T>>>;
74
75/// Traverse the causal graph implied by the specified Dep.
76///
77/// The Traversal is recorded as a sparse adjacency matrix.
78/// Each dep which is visited in the traversal gets added as a node in the graph,
79/// initially with no edges.
80/// For each dep with a failing "check", we recursively visit its dep(s).
81/// Any time we encounter a dep with a passing "check", we backtrack and add edges
82/// to add this path to the graph.
83/// If a path ends in a failing check, or if it forms a loop without encountering
84/// a passing check, we don't add that path to the graph.
85#[cfg_attr(feature = "instrument", tracing::instrument(skip(ctx)))]
86pub fn traverse<F: Fact>(fact: F, ctx: &F::Context) -> TraversalResult<F> {
87    let mut table = TraversalMap::default();
88
89    let root_check_passed = fact.check(ctx);
90    let mode = if root_check_passed {
91        TraversalMode::TraversePasses
92    } else {
93        TraversalMode::TraverseFails
94    };
95
96    let res = traverse_fact(&fact, ctx, &mut table, mode);
97    let dep = Dep::from(fact);
98
99    match res {
100        Ok(check) => {
101            table.insert(dep.clone(), Some(check.clone()));
102            let (graph, terminals) = produce_graph(&table, &dep, ctx);
103
104            Ok(Traversal {
105                root_check_passed,
106                graph,
107                terminals,
108                ctx,
109            })
110        }
111        Err(inner) => {
112            table.insert(
113                dep.clone(),
114                Some(TraversalStep::Continue(vec![dep.clone()])),
115            );
116            let (graph, _) = produce_graph(&table, &dep, ctx);
117
118            Err(TraversalError { graph, inner })
119        }
120    }
121}
122
/// Recursively evaluates `dep`, memoizing the outcome in `table`.
///
/// Returns:
/// - `Ok(Some(step))` when a step could be determined (or was already cached),
/// - `Ok(None)` when `dep` is already being processed further up the call
///   stack (a loop), or when none of the children of an `Any`/`Every` yielded
///   a step,
/// - `Err(_)` when a nested traversal failed. In that case the `Any`/`Every`
///   arms record `Continue(all children)` for `dep` before propagating.
///
/// Note that on the "All loops" early returns the placeholder `None` inserted
/// at the top of this function is left in `table` for `dep`.
#[cfg_attr(feature = "instrument", tracing::instrument(skip(ctx, table)))]
fn traverse_inner<F: Fact>(
    dep: &Dep<F>,
    ctx: &F::Context,
    table: &mut TraversalMap<F>,
    mode: TraversalMode,
) -> Result<Option<TraversalStep<F>>, TraversalInnerError<F>> {
    tracing::trace!("enter {:?}", dep);

    match table.get(dep) {
        None => {
            tracing::trace!("marked visited");
            // Mark this node as visited but undetermined in case the traversal leads to a loop
            table.insert(dep.clone(), None);
        }
        Some(None) => {
            tracing::trace!("loop encountered");
            // We're currently processing a traversal that started from this dep.
            // It is unclear whether this is even valid, but in any case
            // we certainly can't say anything about this traversal.
            return Ok(None);
        }
        Some(Some(check)) => {
            tracing::trace!("return cached: {:?}", check);
            return Ok(Some(check.clone()));
        }
    }

    // Evaluates every child in order, keeping only those which produced a
    // step (children answering `None` are dropped). Stops at the first error.
    #[allow(clippy::type_complexity)]
    let mut recursive_checks =
        |cs: &[Dep<F>]| -> Result<Vec<(Dep<F>, TraversalStep<F>)>, TraversalInnerError<F>> {
            let mut checks = vec![];
            for c in cs {
                if let Some(check) = traverse_inner(c, ctx, table, mode)? {
                    checks.push((c.clone(), check));
                }
            }
            Ok(checks)
        };

    let check = match dep {
        Dep::Fact(f) => {
            // A fact ends its branch when its check equals the mode's terminal value.
            let terminate = f.check(ctx) == mode.terminal_check_value();
            if terminate {
                tracing::trace!("fact terminate");
                TraversalStep::Terminate
            } else {
                traverse_fact(f, ctx, table, mode)?
            }
        }
        Dep::Any(_, cs) => {
            let checks = recursive_checks(cs).map_err(|err| {
                // Continue constructing the graph while we bubble up errors
                tracing::error!("traversal ending due to error: {err:?}");
                table.insert(dep.clone(), Some(TraversalStep::Continue(cs.clone())));
                err
            })?;
            tracing::trace!("Any. checks: {:?}", checks);
            if checks.is_empty() {
                // All loops
                tracing::debug!("All loops");
                return Ok(None);
            }
            let num_checks = checks.len();
            // Children whose step is not `Terminate`.
            let fails: Vec<_> = checks
                .into_iter()
                .filter_map(|(dep, check)| (!check.is_pass()).then_some(dep))
                .collect();
            tracing::trace!("Any. fails: {:?}", fails);
            // At least one child terminated => the `Any` terminates;
            // otherwise continue through every (non-terminating) child.
            if fails.len() < num_checks {
                TraversalStep::Terminate
            } else {
                TraversalStep::Continue(fails)
            }
        }
        Dep::Every(_, cs) => {
            let checks = recursive_checks(cs).map_err(|err| {
                // Continue constructing the graph while we bubble up errors
                tracing::error!("traversal ending due to error: {err:?}");
                table.insert(dep.clone(), Some(TraversalStep::Continue(cs.clone())));
                err
            })?;

            tracing::trace!("Every. checks: {:?}", checks);
            if checks.is_empty() {
                // All loops
                tracing::debug!("All loops");
                return Ok(None);
            }
            let fails = checks.iter().filter(|(_, check)| !check.is_pass()).count();
            // Unlike `Any`, the continuation keeps every child that produced a
            // step, terminating or not.
            let deps: Vec<_> = checks.into_iter().map(|(dep, _)| dep).collect();
            tracing::trace!("Every. num fails: {}", fails);
            // Every child terminated => the `Every` terminates.
            if fails == 0 {
                TraversalStep::Terminate
            } else {
                TraversalStep::Continue(deps)
            }
        }
    };
    // Replace the "in progress" placeholder with the determined step.
    table.insert(dep.clone(), Some(check.clone()));
    tracing::trace!("exit. check: {:?}", check);
    Ok(Some(check))
}
226
227#[cfg_attr(feature = "instrument", tracing::instrument(skip(ctx, table)))]
228fn traverse_fact<F: Fact>(
229    fact: &F,
230    ctx: &F::Context,
231    table: &mut TraversalMap<F>,
232    mode: TraversalMode,
233) -> Result<TraversalStep<F>, TraversalInnerError<F>> {
234    if let Some(sub_dep) = fact.dep(ctx)? {
235        tracing::trace!("traversing fact");
236
237        let check = traverse_inner(&sub_dep, ctx, table, mode).map_err(|err| {
238            // Continue constructing the graph while we bubble up errors
239            table.insert(
240                Dep::from(fact.clone()),
241                Some(TraversalStep::Continue(vec![sub_dep.clone()])),
242            );
243            tracing::error!("traversal ending due to error: {err:?}");
244            err
245        })?;
246        tracing::trace!("traversal done, check: {:?}", check);
247        Ok(TraversalStep::Continue(vec![sub_dep]))
248    } else {
249        tracing::trace!("fact fail with no dep, terminating");
250        Ok(TraversalStep::Continue(vec![]))
251    }
252}
253
254/// Prune away any extraneous nodes or edges from a Traversal.
255/// After pruning, the graph contains either all true edges or all false edges,
256/// with paths terminating at a transition point.
257///
258/// Terminal facts are returned separately.
259#[allow(clippy::type_complexity)]
260fn prune_traversal<'a, 'b: 'a, T: Fact + Eq + Hash>(
261    table: &'a TraversalMap<T>,
262    start: &'b Dep<T>,
263) -> (HashMap<&'a Dep<T>, &'a [Dep<T>]>, Vec<&'a Dep<T>>) {
264    let mut sub = HashMap::<&Dep<T>, &[Dep<T>]>::new();
265    let mut terminals = vec![];
266    let mut to_add = vec![start];
267
268    while let Some(next) = to_add.pop() {
269        if let Some(step) = table.get(next) {
270            match step.as_ref() {
271                Some(TraversalStep::Continue(deps)) => {
272                    let old = sub.insert(next, deps.as_slice());
273                    if let Some(old) = old {
274                        assert_eq!(
275                            old, deps,
276                            "Looped back to same node, but with different children?"
277                        );
278                    } else {
279                        to_add.extend(deps.iter());
280                    }
281                }
282                Some(TraversalStep::Terminate) => {
283                    terminals.push(next);
284                }
285                None => {}
286            }
287        } else {
288            // Still include this as an orphan node.
289            // This should only ever apply to the starting node.
290            sub.insert(next, &[]);
291        }
292    }
293    (sub, terminals)
294}
295
296pub fn produce_graph<'a, 'b: 'a, 'c, T: Fact + Eq + Hash>(
297    table: &'a TraversalMap<T>,
298    start: &'b Dep<T>,
299    ctx: &'c T::Context,
300) -> (DepGraph<'c, T>, HashSet<Dep<T>>) {
301    let mut g = DepGraph::default();
302
303    let (sub, passes) = prune_traversal(table, start);
304
305    let rows: HashSet<_> = sub.into_iter().collect();
306    let mut nodemap = HashMap::new();
307    for (i, (k, _)) in rows.iter().enumerate() {
308        let id = g.add_node(GraphNode {
309            dep: (*k).to_owned(),
310            ctx,
311        });
312        nodemap.insert(k, id);
313        assert_eq!(id.index(), i);
314    }
315
316    for (k, v) in rows.iter() {
317        for c in v.iter() {
318            if let (Some(k), Some(c)) = (nodemap.get(k), nodemap.get(&c)) {
319                g.add_edge(*k, *c, ());
320            }
321        }
322    }
323
324    (g, passes.into_iter().cloned().collect())
325}