Skip to main content

scm_bisect/
search.rs

//! A search algorithm for directed acyclic graphs to find the nodes which
//! "flip" from passing to failing a predicate.
3
4use std::collections::{HashSet, VecDeque};
5use std::fmt::Debug;
6use std::hash::Hash;
7
8use indexmap::IndexMap;
9use tracing::{debug, instrument};
10
11/// The set of nodes compromising a directed acyclic graph to be searched.
12pub trait Graph: Debug {
13    /// The type of nodes in the graph. This should be cheap to clone.
14    type Node: Clone + Debug + Hash + Eq;
15
16    /// An error type.
17    type Error: std::error::Error;
18
19    /// Return whether or not `node` is an ancestor of `descendant`. A node `X``
20    /// is said to be an "ancestor" of node `Y` if one of the following is true:
21    ///
22    /// - `X == Y`
23    /// - `X` is an immediate parent of `Y`.
24    /// - `X` is an ancestor of an immediate parent of `Y` (defined
25    ///   recursively).
26    fn is_ancestor(
27        &self,
28        ancestor: Self::Node,
29        descendant: Self::Node,
30    ) -> Result<bool, Self::Error>;
31
32    /// Filter `nodes` to only include nodes that are not ancestors of any other
33    /// node in `nodes`. This is not strictly necessary, but it improves
34    /// performance as some operations are linear in the size of the success
35    /// bounds, and it can make the intermediate results more sensible.
36    ///
37    /// This operation is called `heads` in e.g. Mercurial.
38    #[instrument]
39    fn simplify_success_bounds(
40        &self,
41        nodes: HashSet<Self::Node>,
42    ) -> Result<HashSet<Self::Node>, Self::Error> {
43        Ok(nodes)
44    }
45
46    /// Filter `nodes` to only include nodes that are not descendants of any
47    /// other node in `nodes`. This is not strictly necessary, but it improves
48    /// performance as some operations are linear in the size of the failure
49    /// bounds, and it can make the intermediate results more sensible.
50    ///
51    /// This operation is called `roots` in e.g. Mercurial.
52    #[instrument]
53    fn simplify_failure_bounds(
54        &self,
55        nodes: HashSet<Self::Node>,
56    ) -> Result<HashSet<Self::Node>, Self::Error> {
57        Ok(nodes)
58    }
59}
60
/// Every state a node can be in over the course of a search.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Status {
    /// Not yet tested. Every node in a search begins in this state.
    Untested,

    /// Tested, and the caller-defined predicate holds. From then on, the
    /// search assumes the predicate also holds for every ancestor of the node.
    Success,

    /// Tested, and the caller-defined predicate does not hold. From then on,
    /// the search assumes the predicate also fails for every descendant of the
    /// node.
    Failure,

    /// Tested, but whether the caller-defined predicate holds could not be
    /// determined. Later searches are meant to skip such a node.
    Indeterminate,
}
81
/// The success and failure bounds of the search.
#[derive(Debug, Eq, PartialEq)]
pub struct Bounds<Node: Debug + Eq + Hash> {
    /// The success ("upper") bounds of the search. These nodes and all of
    /// their ancestors have (or are assumed to have) `Status::Success`.
    pub success: HashSet<Node>,

    /// The failure ("lower") bounds of the search. These nodes and all of
    /// their descendants have (or are assumed to have) `Status::Failure`.
    pub failure: HashSet<Node>,
}

// Implemented by hand rather than derived: `#[derive(Default)]` would add a
// `Node: Default` bound, which node types have no reason to satisfy.
impl<Node: Debug + Eq + Hash> Default for Bounds<Node> {
    /// Return empty bounds, i.e. nothing is known about any node yet.
    fn default() -> Self {
        Self {
            success: HashSet::new(),
            failure: HashSet::new(),
        }
    }
}
102
103/// A search strategy to select the next node to search in the graph.
104pub trait Strategy<G: Graph>: Debug {
105    /// An error type.
106    type Error: std::error::Error;
107
108    /// Return a "midpoint" for the search. Such a midpoint lies between the
109    /// success bounds and failure bounds, for some meaning of "lie between",
110    /// which depends on the strategy details.
111    ///
112    /// If `None` is returned, then the search exits.
113    ///
114    /// For example, linear search would return a node immediately "after"
115    /// the node(s) in `success_bounds`, while binary search would return the
116    /// node in the middle of `success_bounds` and `failure_bounds`.`
117    ///
118    /// NOTE: This must not return a value that has already been included in the
119    /// success or failure bounds, since then you would search it again in a
120    /// loop indefinitely. In that case, you must return `None` instead.
121    fn midpoint(
122        &self,
123        graph: &G,
124        bounds: &Bounds<G::Node>,
125        statuses: &IndexMap<G::Node, Status>,
126    ) -> Result<Option<G::Node>, Self::Error>;
127}
128
129/// The results of the search so far. The search is complete if `next_to_search` is empty.
130pub struct LazySolution<'a, TNode: Debug + Eq + Hash + 'a, TError> {
131    /// The bounds of the search so far.
132    pub bounds: Bounds<TNode>,
133
134    /// The next nodes to search in a suggested order. Normally, you would only
135    /// consume the first node in this iterator and then call `Search::notify`
136    /// with the result. However, if you want to parallelize or speculate on
137    /// further nodes, you can consume more nodes from this iterator.
138    ///
139    /// This will be empty when the bounds are as tight as possible, i.e. the
140    /// search is complete.
141    pub next_to_search: Box<dyn Iterator<Item = Result<TNode, TError>> + 'a>,
142}
143
144impl<'a, TNode: Debug + Eq + Hash + 'a, TError> LazySolution<'a, TNode, TError> {
145    /// Convenience function to call `EagerSolution::from` on this `LazySolution`.
146    pub fn into_eager(self) -> Result<EagerSolution<TNode>, TError> {
147        let LazySolution {
148            bounds,
149            next_to_search,
150        } = self;
151        Ok(EagerSolution {
152            bounds,
153            next_to_search: next_to_search.collect::<Result<Vec<_>, TError>>()?,
154        })
155    }
156}
157
158/// A `LazySolution` with a `Vec<Node>` for `next_to_search`. This is primarily
159/// for debugging.
160#[derive(Debug, Eq, PartialEq)]
161pub struct EagerSolution<Node: Debug + Hash + Eq> {
162    pub(crate) bounds: Bounds<Node>,
163    pub(crate) next_to_search: Vec<Node>,
164}
165
166#[allow(missing_docs)]
167#[derive(Debug, thiserror::Error)]
168pub enum SearchError<TNode, TGraphError, TStrategyError> {
169    #[error(
170        "node {node:?} has already been classified as a {status:?} node, but was returned as a new midpoint to search; this would loop indefinitely"
171    )]
172    AlreadySearchedMidpoint { node: TNode, status: Status },
173
174    #[error(transparent)]
175    Graph(TGraphError),
176
177    #[error(transparent)]
178    Strategy(TStrategyError),
179}
180
181/// The error type for the search.
182#[allow(missing_docs)]
183#[derive(Debug, thiserror::Error)]
184pub enum NotifyError<TNode, TGraphError> {
185    #[error(
186        "inconsistent state transition: {ancestor_node:?} ({ancestor_status:?}) was marked as an ancestor of {descendant_node:?} ({descendant_status:?}"
187    )]
188    InconsistentStateTransition {
189        ancestor_node: TNode,
190        ancestor_status: Status,
191        descendant_node: TNode,
192        descendant_status: Status,
193    },
194
195    #[error("illegal state transition for {node:?}: {from:?} -> {to:?}")]
196    IllegalStateTransition {
197        node: TNode,
198        from: Status,
199        to: Status,
200    },
201
202    #[error(transparent)]
203    Graph(TGraphError),
204}
205
206/// The search algorithm.
207#[derive(Clone, Debug)]
208pub struct Search<G: Graph> {
209    graph: G,
210    nodes: IndexMap<G::Node, Status>,
211}
212
213impl<G: Graph> Search<G> {
214    /// Construct a new search. The provided `graph` represents the universe of
215    /// all nodes, and `nodes` represents a subset of that universe to search
216    /// in. Only elements from `nodes` will be returned by `success_bounds` and
217    /// `failure_bounds`.
218    ///
219    /// For example, `graph` might correspond to the entire source control
220    /// directed acyclic graph, and `nodes` might correspond to a recent range
221    /// of commits where the first one is passing and the last one is failing.
222    pub fn new(graph: G, search_nodes: impl IntoIterator<Item = G::Node>) -> Self {
223        let nodes = search_nodes
224            .into_iter()
225            .map(|node| (node, Status::Untested))
226            .collect();
227        Self { graph, nodes }
228    }
229
230    /// Get the currently known bounds on the success nodes.
231    ///
232    /// FIXME: O(n) complexity.
233    #[instrument]
234    pub fn success_bounds(&self) -> Result<HashSet<G::Node>, G::Error> {
235        let success_nodes = self
236            .nodes
237            .iter()
238            .filter_map(|(node, status)| match status {
239                Status::Success => Some(node.clone()),
240                Status::Untested | Status::Failure | Status::Indeterminate => None,
241            })
242            .collect::<HashSet<_>>();
243        let success_bounds = self.graph.simplify_success_bounds(success_nodes)?;
244        Ok(success_bounds)
245    }
246
247    /// Get the currently known bounds on the failure nodes.
248    ///
249    /// FIXME: O(n) complexity.
250    #[instrument]
251    pub fn failure_bounds(&self) -> Result<HashSet<G::Node>, G::Error> {
252        let failure_nodes = self
253            .nodes
254            .iter()
255            .filter_map(|(node, status)| match status {
256                Status::Failure => Some(node.clone()),
257                Status::Untested | Status::Success | Status::Indeterminate => None,
258            })
259            .collect::<HashSet<_>>();
260        let failure_bounds = self.graph.simplify_failure_bounds(failure_nodes)?;
261        Ok(failure_bounds)
262    }
263
264    /// Summarize the current search progress and suggest the next node(s) to
265    /// search. The caller is responsible for calling `notify` with the result.
266    #[instrument]
267    #[allow(clippy::type_complexity)]
268    pub fn search<'a, S: Strategy<G>>(
269        &'a self,
270        strategy: &'a S,
271    ) -> Result<
272        LazySolution<'a, G::Node, SearchError<G::Node, G::Error, S::Error>>,
273        SearchError<G::Node, G::Error, S::Error>,
274    > {
275        let success_bounds = self.success_bounds().map_err(SearchError::Graph)?;
276        let failure_bounds = self.failure_bounds().map_err(SearchError::Graph)?;
277
278        #[derive(Debug)]
279        struct State<G: Graph> {
280            bounds: Bounds<G::Node>,
281            statuses: IndexMap<G::Node, Status>,
282        }
283
284        struct Iter<'a, G: Graph, S: Strategy<G>> {
285            graph: &'a G,
286            strategy: &'a S,
287            seen: HashSet<G::Node>,
288            states: VecDeque<State<G>>,
289        }
290
291        impl<G: Graph, S: Strategy<G>> Iterator for Iter<'_, G, S> {
292            type Item = Result<G::Node, SearchError<G::Node, G::Error, S::Error>>;
293
294            fn next(&mut self) -> Option<Self::Item> {
295                while let Some(state) = self.states.pop_front() {
296                    debug!(?state, "Popped speculation state");
297                    let State { bounds, statuses } = state;
298
299                    let node = match self.strategy.midpoint(self.graph, &bounds, &statuses) {
300                        Ok(Some(node)) => node,
301                        Ok(None) => continue,
302                        Err(err) => return Some(Err(SearchError::Strategy(err))),
303                    };
304
305                    let Bounds { success, failure } = bounds;
306                    for success_node in success.iter() {
307                        match self.graph.is_ancestor(node.clone(), success_node.clone()) {
308                            Ok(true) => {
309                                return Some(Err(SearchError::AlreadySearchedMidpoint {
310                                    node,
311                                    status: Status::Success,
312                                }));
313                            }
314                            Ok(false) => (),
315                            Err(err) => return Some(Err(SearchError::Graph(err))),
316                        }
317                    }
318                    for failure_node in failure.iter() {
319                        match self.graph.is_ancestor(failure_node.clone(), node.clone()) {
320                            Ok(true) => {
321                                return Some(Err(SearchError::AlreadySearchedMidpoint {
322                                    node,
323                                    status: Status::Failure,
324                                }));
325                            }
326                            Ok(false) => (),
327                            Err(err) => return Some(Err(SearchError::Graph(err))),
328                        }
329                    }
330
331                    // Speculate failure:
332                    self.states.push_back(State {
333                        bounds: Bounds {
334                            success: success.clone(),
335                            failure: {
336                                let mut failure_bounds = failure.clone();
337                                failure_bounds.insert(node.clone());
338                                match self.graph.simplify_failure_bounds(failure_bounds) {
339                                    Ok(bounds) => bounds,
340                                    Err(err) => return Some(Err(SearchError::Graph(err))),
341                                }
342                            },
343                        },
344                        statuses: {
345                            let mut statuses = statuses.clone();
346                            statuses.insert(node.clone(), Status::Failure);
347                            statuses
348                        },
349                    });
350
351                    // Speculate success:
352                    self.states.push_back(State {
353                        bounds: Bounds {
354                            success: {
355                                let mut success_bounds = success.clone();
356                                success_bounds.insert(node.clone());
357                                match self.graph.simplify_success_bounds(success_bounds) {
358                                    Ok(bounds) => bounds,
359                                    Err(err) => return Some(Err(SearchError::Graph(err))),
360                                }
361                            },
362                            failure: failure.clone(),
363                        },
364                        statuses: {
365                            let mut statuses = statuses.clone();
366                            statuses.insert(node.clone(), Status::Success);
367                            statuses
368                        },
369                    });
370
371                    if self.seen.insert(node.clone()) {
372                        return Some(Ok(node));
373                    }
374                }
375                None
376            }
377        }
378
379        let initial_state = State {
380            bounds: Bounds {
381                success: success_bounds.clone(),
382                failure: failure_bounds.clone(),
383            },
384            statuses: self.nodes.clone(),
385        };
386        let iter = Iter {
387            graph: &self.graph,
388            strategy,
389            seen: Default::default(),
390            states: [initial_state].into_iter().collect(),
391        };
392
393        Ok(LazySolution {
394            bounds: Bounds {
395                success: success_bounds,
396                failure: failure_bounds,
397            },
398            next_to_search: Box::new(iter),
399        })
400    }
401
402    /// Update the search state with the result of a search.
403    #[instrument]
404    pub fn notify(
405        &mut self,
406        node: G::Node,
407        status: Status,
408    ) -> Result<(), NotifyError<G::Node, G::Error>> {
409        match self.nodes.get(&node) {
410            Some(existing_status @ (Status::Success | Status::Failure))
411                if existing_status != &status =>
412            {
413                return Err(NotifyError::IllegalStateTransition {
414                    node,
415                    from: *existing_status,
416                    to: status,
417                });
418            }
419            _ => {}
420        }
421
422        match status {
423            Status::Untested | Status::Indeterminate => {}
424
425            Status::Success => {
426                for failure_node in self.failure_bounds().map_err(NotifyError::Graph)? {
427                    if self
428                        .graph
429                        .is_ancestor(failure_node.clone(), node.clone())
430                        .map_err(NotifyError::Graph)?
431                    {
432                        return Err(NotifyError::InconsistentStateTransition {
433                            ancestor_node: failure_node,
434                            ancestor_status: Status::Failure,
435                            descendant_node: node,
436                            descendant_status: Status::Success,
437                        });
438                    }
439                }
440            }
441
442            Status::Failure => {
443                for success_node in self.success_bounds().map_err(NotifyError::Graph)? {
444                    if self
445                        .graph
446                        .is_ancestor(node.clone(), success_node.clone())
447                        .map_err(NotifyError::Graph)?
448                    {
449                        return Err(NotifyError::InconsistentStateTransition {
450                            ancestor_node: node,
451                            ancestor_status: Status::Failure,
452                            descendant_node: success_node,
453                            descendant_status: Status::Success,
454                        });
455                    }
456                }
457            }
458        }
459
460        self.nodes.insert(node, status);
461        Ok(())
462    }
463}