scm_bisect/
search.rs

//! A search algorithm for directed acyclic graphs to find the nodes which
//! "flip" from passing to failing a predicate.
3
4use std::collections::{HashSet, VecDeque};
5use std::fmt::Debug;
6use std::hash::Hash;
7
8use indexmap::IndexMap;
9use tracing::{debug, instrument};
10
11/// The set of nodes compromising a directed acyclic graph to be searched.
12pub trait Graph: Debug {
13    /// The type of nodes in the graph. This should be cheap to clone.
14    type Node: Clone + Debug + Hash + Eq;
15
16    /// An error type.
17    type Error: std::error::Error;
18
19    /// Return whether or not `node` is an ancestor of `descendant`. A node `X``
20    /// is said to be an "ancestor" of node `Y` if one of the following is true:
21    ///
22    /// - `X == Y`
23    /// - `X` is an immediate parent of `Y`.
24    /// - `X` is an ancestor of an immediate parent of `Y` (defined
25    ///   recursively).
26    fn is_ancestor(
27        &self,
28        ancestor: Self::Node,
29        descendant: Self::Node,
30    ) -> Result<bool, Self::Error>;
31
32    /// Filter `nodes` to only include nodes that are not ancestors of any other
33    /// node in `nodes`. This is not strictly necessary, but it improves
34    /// performance as some operations are linear in the size of the success
35    /// bounds, and it can make the intermediate results more sensible.
36    ///
37    /// This operation is called `heads` in e.g. Mercurial.
38    #[instrument]
39    fn simplify_success_bounds(
40        &self,
41        nodes: HashSet<Self::Node>,
42    ) -> Result<HashSet<Self::Node>, Self::Error> {
43        Ok(nodes)
44    }
45
46    /// Filter `nodes` to only include nodes that are not descendants of any
47    /// other node in `nodes`. This is not strictly necessary, but it improves
48    /// performance as some operations are linear in the size of the failure
49    /// bounds, and it can make the intermediate results more sensible.
50    ///
51    /// This operation is called `roots` in e.g. Mercurial.
52    #[instrument]
53    fn simplify_failure_bounds(
54        &self,
55        nodes: HashSet<Self::Node>,
56    ) -> Result<HashSet<Self::Node>, Self::Error> {
57        Ok(nodes)
58    }
59}
60
/// The state of a single node with respect to the search.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Status {
    /// No test result has been recorded for the node. Every node starts a
    /// search in this state.
    Untested,

    /// The node was tested and satisfies the caller-defined predicate. From
    /// then on, the search assumes that every ancestor of this node satisfies
    /// the predicate as well.
    Success,

    /// The node was tested and does not satisfy the caller-defined predicate.
    /// From then on, the search assumes that every descendant of this node
    /// fails the predicate as well.
    Failure,

    /// The node was tested, but the test could not establish whether it
    /// satisfies the caller-defined predicate. It will be skipped in future
    /// searches.
    Indeterminate,
}
81
/// The two frontiers that delimit the part of the graph still being searched.
#[derive(Debug, Eq, PartialEq)]
pub struct Bounds<Node: Debug + Eq + Hash> {
    /// The success-side bound of the search. The ancestors of this set have
    /// (or are assumed to have) `Status::Success`.
    pub success: HashSet<Node>,

    /// The failure-side bound of the search. The descendants of this set have
    /// (or are assumed to have) `Status::Failure`.
    pub failure: HashSet<Node>,
}

// Implemented by hand rather than derived: `#[derive(Default)]` would add a
// `Node: Default` bound, which empty sets do not need.
impl<Node: Debug + Eq + Hash> Default for Bounds<Node> {
    /// Create bounds with no known success or failure nodes.
    fn default() -> Self {
        Self {
            success: HashSet::new(),
            failure: HashSet::new(),
        }
    }
}
102
103/// A search strategy to select the next node to search in the graph.
104pub trait Strategy<G: Graph>: Debug {
105    /// An error type.
106    type Error: std::error::Error;
107
108    /// Return a "midpoint" for the search. Such a midpoint lies between the
109    /// success bounds and failure bounds, for some meaning of "lie between",
110    /// which depends on the strategy details.
111    ///
112    /// If `None` is returned, then the search exits.
113    ///
114    /// For example, linear search would return a node immediately "after"
115    /// the node(s) in `success_bounds`, while binary search would return the
116    /// node in the middle of `success_bounds` and `failure_bounds`.`
117    ///
118    /// NOTE: This must not return a value that has already been included in the
119    /// success or failure bounds, since then you would search it again in a
120    /// loop indefinitely. In that case, you must return `None` instead.
121    fn midpoint(
122        &self,
123        graph: &G,
124        bounds: &Bounds<G::Node>,
125        statuses: &IndexMap<G::Node, Status>,
126    ) -> Result<Option<G::Node>, Self::Error>;
127}
128
129/// The results of the search so far. The search is complete if `next_to_search` is empty.
130pub struct LazySolution<'a, TNode: Debug + Eq + Hash + 'a, TError> {
131    /// The bounds of the search so far.
132    pub bounds: Bounds<TNode>,
133
134    /// The next nodes to search in a suggested order. Normally, you would only
135    /// consume the first node in this iterator and then call `Search::notify`
136    /// with the result. However, if you want to parallelize or speculate on
137    /// further nodes, you can consume more nodes from this iterator.
138    ///
139    /// This will be empty when the bounds are as tight as possible, i.e. the
140    /// search is complete.
141    pub next_to_search: Box<dyn Iterator<Item = Result<TNode, TError>> + 'a>,
142}
143
144impl<'a, TNode: Debug + Eq + Hash + 'a, TError> LazySolution<'a, TNode, TError> {
145    /// Convenience function to call `EagerSolution::from` on this `LazySolution`.
146    pub fn into_eager(self) -> Result<EagerSolution<TNode>, TError> {
147        let LazySolution {
148            bounds,
149            next_to_search,
150        } = self;
151        Ok(EagerSolution {
152            bounds,
153            next_to_search: next_to_search.collect::<Result<Vec<_>, TError>>()?,
154        })
155    }
156}
157
158/// A `LazySolution` with a `Vec<Node>` for `next_to_search`. This is primarily
159/// for debugging.
160#[derive(Debug, Eq, PartialEq)]
161pub struct EagerSolution<Node: Debug + Hash + Eq> {
162    pub(crate) bounds: Bounds<Node>,
163    pub(crate) next_to_search: Vec<Node>,
164}
165
166#[allow(missing_docs)]
167#[derive(Debug, thiserror::Error)]
168pub enum SearchError<TNode, TGraphError, TStrategyError> {
169    #[error("node {node:?} has already been classified as a {status:?} node, but was returned as a new midpoint to search; this would loop indefinitely")]
170    AlreadySearchedMidpoint { node: TNode, status: Status },
171
172    #[error(transparent)]
173    Graph(TGraphError),
174
175    #[error(transparent)]
176    Strategy(TStrategyError),
177}
178
179/// The error type for the search.
180#[allow(missing_docs)]
181#[derive(Debug, thiserror::Error)]
182pub enum NotifyError<TNode, TGraphError> {
183    #[error("inconsistent state transition: {ancestor_node:?} ({ancestor_status:?}) was marked as an ancestor of {descendant_node:?} ({descendant_status:?}")]
184    InconsistentStateTransition {
185        ancestor_node: TNode,
186        ancestor_status: Status,
187        descendant_node: TNode,
188        descendant_status: Status,
189    },
190
191    #[error("illegal state transition for {node:?}: {from:?} -> {to:?}")]
192    IllegalStateTransition {
193        node: TNode,
194        from: Status,
195        to: Status,
196    },
197
198    #[error(transparent)]
199    Graph(TGraphError),
200}
201
202/// The search algorithm.
203#[derive(Clone, Debug)]
204pub struct Search<G: Graph> {
205    graph: G,
206    nodes: IndexMap<G::Node, Status>,
207}
208
209impl<G: Graph> Search<G> {
210    /// Construct a new search. The provided `graph` represents the universe of
211    /// all nodes, and `nodes` represents a subset of that universe to search
212    /// in. Only elements from `nodes` will be returned by `success_bounds` and
213    /// `failure_bounds`.
214    ///
215    /// For example, `graph` might correspond to the entire source control
216    /// directed acyclic graph, and `nodes` might correspond to a recent range
217    /// of commits where the first one is passing and the last one is failing.
218    pub fn new(graph: G, search_nodes: impl IntoIterator<Item = G::Node>) -> Self {
219        let nodes = search_nodes
220            .into_iter()
221            .map(|node| (node, Status::Untested))
222            .collect();
223        Self { graph, nodes }
224    }
225
226    /// Get the currently known bounds on the success nodes.
227    ///
228    /// FIXME: O(n) complexity.
229    #[instrument]
230    pub fn success_bounds(&self) -> Result<HashSet<G::Node>, G::Error> {
231        let success_nodes = self
232            .nodes
233            .iter()
234            .filter_map(|(node, status)| match status {
235                Status::Success => Some(node.clone()),
236                Status::Untested | Status::Failure | Status::Indeterminate => None,
237            })
238            .collect::<HashSet<_>>();
239        let success_bounds = self.graph.simplify_success_bounds(success_nodes)?;
240        Ok(success_bounds)
241    }
242
243    /// Get the currently known bounds on the failure nodes.
244    ///
245    /// FIXME: O(n) complexity.
246    #[instrument]
247    pub fn failure_bounds(&self) -> Result<HashSet<G::Node>, G::Error> {
248        let failure_nodes = self
249            .nodes
250            .iter()
251            .filter_map(|(node, status)| match status {
252                Status::Failure => Some(node.clone()),
253                Status::Untested | Status::Success | Status::Indeterminate => None,
254            })
255            .collect::<HashSet<_>>();
256        let failure_bounds = self.graph.simplify_failure_bounds(failure_nodes)?;
257        Ok(failure_bounds)
258    }
259
260    /// Summarize the current search progress and suggest the next node(s) to
261    /// search. The caller is responsible for calling `notify` with the result.
262    #[instrument]
263    #[allow(clippy::type_complexity)]
264    pub fn search<'a, S: Strategy<G>>(
265        &'a self,
266        strategy: &'a S,
267    ) -> Result<
268        LazySolution<G::Node, SearchError<G::Node, G::Error, S::Error>>,
269        SearchError<G::Node, G::Error, S::Error>,
270    > {
271        let success_bounds = self.success_bounds().map_err(SearchError::Graph)?;
272        let failure_bounds = self.failure_bounds().map_err(SearchError::Graph)?;
273
274        #[derive(Debug)]
275        struct State<G: Graph> {
276            bounds: Bounds<G::Node>,
277            statuses: IndexMap<G::Node, Status>,
278        }
279
280        struct Iter<'a, G: Graph, S: Strategy<G>> {
281            graph: &'a G,
282            strategy: &'a S,
283            seen: HashSet<G::Node>,
284            states: VecDeque<State<G>>,
285        }
286
287        impl<'a, G: Graph, S: Strategy<G>> Iterator for Iter<'a, G, S> {
288            type Item = Result<G::Node, SearchError<G::Node, G::Error, S::Error>>;
289
290            fn next(&mut self) -> Option<Self::Item> {
291                while let Some(state) = self.states.pop_front() {
292                    debug!(?state, "Popped speculation state");
293                    let State { bounds, statuses } = state;
294
295                    let node = match self.strategy.midpoint(self.graph, &bounds, &statuses) {
296                        Ok(Some(node)) => node,
297                        Ok(None) => continue,
298                        Err(err) => return Some(Err(SearchError::Strategy(err))),
299                    };
300
301                    let Bounds { success, failure } = bounds;
302                    for success_node in success.iter() {
303                        match self.graph.is_ancestor(node.clone(), success_node.clone()) {
304                            Ok(true) => {
305                                return Some(Err(SearchError::AlreadySearchedMidpoint {
306                                    node,
307                                    status: Status::Success,
308                                }));
309                            }
310                            Ok(false) => (),
311                            Err(err) => return Some(Err(SearchError::Graph(err))),
312                        }
313                    }
314                    for failure_node in failure.iter() {
315                        match self.graph.is_ancestor(failure_node.clone(), node.clone()) {
316                            Ok(true) => {
317                                return Some(Err(SearchError::AlreadySearchedMidpoint {
318                                    node,
319                                    status: Status::Failure,
320                                }));
321                            }
322                            Ok(false) => (),
323                            Err(err) => return Some(Err(SearchError::Graph(err))),
324                        }
325                    }
326
327                    // Speculate failure:
328                    self.states.push_back(State {
329                        bounds: Bounds {
330                            success: success.clone(),
331                            failure: {
332                                let mut failure_bounds = failure.clone();
333                                failure_bounds.insert(node.clone());
334                                match self.graph.simplify_failure_bounds(failure_bounds) {
335                                    Ok(bounds) => bounds,
336                                    Err(err) => return Some(Err(SearchError::Graph(err))),
337                                }
338                            },
339                        },
340                        statuses: {
341                            let mut statuses = statuses.clone();
342                            statuses.insert(node.clone(), Status::Failure);
343                            statuses
344                        },
345                    });
346
347                    // Speculate success:
348                    self.states.push_back(State {
349                        bounds: Bounds {
350                            success: {
351                                let mut success_bounds = success.clone();
352                                success_bounds.insert(node.clone());
353                                match self.graph.simplify_success_bounds(success_bounds) {
354                                    Ok(bounds) => bounds,
355                                    Err(err) => return Some(Err(SearchError::Graph(err))),
356                                }
357                            },
358                            failure: failure.clone(),
359                        },
360                        statuses: {
361                            let mut statuses = statuses.clone();
362                            statuses.insert(node.clone(), Status::Success);
363                            statuses
364                        },
365                    });
366
367                    if self.seen.insert(node.clone()) {
368                        return Some(Ok(node));
369                    }
370                }
371                None
372            }
373        }
374
375        let initial_state = State {
376            bounds: Bounds {
377                success: success_bounds.clone(),
378                failure: failure_bounds.clone(),
379            },
380            statuses: self.nodes.clone(),
381        };
382        let iter = Iter {
383            graph: &self.graph,
384            strategy,
385            seen: Default::default(),
386            states: [initial_state].into_iter().collect(),
387        };
388
389        Ok(LazySolution {
390            bounds: Bounds {
391                success: success_bounds,
392                failure: failure_bounds,
393            },
394            next_to_search: Box::new(iter),
395        })
396    }
397
398    /// Update the search state with the result of a search.
399    #[instrument]
400    pub fn notify(
401        &mut self,
402        node: G::Node,
403        status: Status,
404    ) -> Result<(), NotifyError<G::Node, G::Error>> {
405        match self.nodes.get(&node) {
406            Some(existing_status @ (Status::Success | Status::Failure))
407                if existing_status != &status =>
408            {
409                return Err(NotifyError::IllegalStateTransition {
410                    node,
411                    from: *existing_status,
412                    to: status,
413                })
414            }
415            _ => {}
416        }
417
418        match status {
419            Status::Untested | Status::Indeterminate => {}
420
421            Status::Success => {
422                for failure_node in self.failure_bounds().map_err(NotifyError::Graph)? {
423                    if self
424                        .graph
425                        .is_ancestor(failure_node.clone(), node.clone())
426                        .map_err(NotifyError::Graph)?
427                    {
428                        return Err(NotifyError::InconsistentStateTransition {
429                            ancestor_node: failure_node,
430                            ancestor_status: Status::Failure,
431                            descendant_node: node,
432                            descendant_status: Status::Success,
433                        });
434                    }
435                }
436            }
437
438            Status::Failure => {
439                for success_node in self.success_bounds().map_err(NotifyError::Graph)? {
440                    if self
441                        .graph
442                        .is_ancestor(node.clone(), success_node.clone())
443                        .map_err(NotifyError::Graph)?
444                    {
445                        return Err(NotifyError::InconsistentStateTransition {
446                            ancestor_node: node,
447                            ancestor_status: Status::Failure,
448                            descendant_node: success_node,
449                            descendant_status: Status::Success,
450                        });
451                    }
452                }
453            }
454        }
455
456        self.nodes.insert(node, status);
457        Ok(())
458    }
459}