//! # Probabilistic Bisector
//!
//! `probabilistic_bisector` provides a probabilistic bisection algorithm for
//! locating roots of scalar objective functions observed in noise.
//!
//! Classical bisection assumes that evaluating an objective function gives an
//! exact sign. In many numerical, simulation, and experimental settings this is
//! not true: repeated evaluations at the same point may produce different
//! values, and the sign of the objective may only be inferable statistically.
//!
//! This crate is designed for that setting.
//!
//! ## Motivation
//!
//! Suppose we want to find a root `x*` of a scalar objective function `f`, that
//! is, a point satisfying
//!
//! ```text
//! f(x*) = 0,
//! ```
//!
//! but each call to the objective produces a noisy observation. A single
//! evaluation may not reliably tell us whether `f(x)` is positive or negative.
//!
//! Instead of treating each sign observation as exact, this crate maintains a
//! posterior distribution over the location of the root. Each observation
//! updates that distribution, and the algorithm returns a confidence interval
//! for the root.
//!
//! ## Theoretical basis
//!
//! The implementation follows the probabilistic bisection framework described
//! by Waeber.
//!
//! The algorithm maintains a posterior distribution over the root location on a
//! fixed internal domain. At each step:
//!
//! 1. A query point is selected from the current posterior.
//! 2. The objective is evaluated repeatedly at that point.
//! 3. A curved-boundary sequential sign test determines the sign of the
//!    objective at the requested confidence level.
//! 4. The posterior mass is updated according to whether the root is inferred
//!    to lie to the left or right of the query point.
//! 5. A sequential confidence interval is updated by intersecting the previous
//!    interval with the current Waeber-style confidence region.
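//!
//! Steps 2 and 3 can be illustrated with a minimal, self-contained sketch of a
//! sequential sign test. It is not this crate's implementation: the boundary
//! `c_n = sqrt((n + 1) * (ln(n + 1) + 2 * ln(1 / alpha)))` on the running sum of
//! observed signs is an assumed form, `alpha` is an assumed error level, and
//! `Sign`, `boundary` and `sign_test` are hypothetical names. Returning `None`
//! once the evaluation budget is exhausted corresponds to the indeterminate
//! outcome discussed under Termination.
//!
//! ```rust
//! /// Sign reported by the hypothetical sequential test.
//! #[derive(Debug, PartialEq)]
//! enum Sign {
//!     Positive,
//!     Negative,
//! }
//!
//! /// Assumed curved boundary for the running sum of +/-1 sign observations
//! /// after `n` evaluations at error level `alpha`.
//! fn boundary(n: usize, alpha: f64) -> f64 {
//!     let m = (n + 1) as f64;
//!     (m * (m.ln() + 2.0 * (1.0 / alpha).ln())).sqrt()
//! }
//!
//! /// Evaluates `f` at `x` repeatedly. Returns the sign once the running sum of
//! /// observed signs reaches the boundary, or `None` (indeterminate) if `budget`
//! /// evaluations are used up first.
//! fn sign_test<F: FnMut(f64) -> f64>(mut f: F, x: f64, alpha: f64, budget: usize) -> Option<Sign> {
//!     let mut sum: i64 = 0;
//!     for n in 1..=budget {
//!         let y = f(x);
//!         if y > 0.0 {
//!             sum += 1;
//!         } else if y < 0.0 {
//!             sum -= 1;
//!         }
//!         if (sum.abs() as f64) >= boundary(n, alpha) {
//!             return Some(if sum > 0 { Sign::Positive } else { Sign::Negative });
//!         }
//!     }
//!     None
//! }
//!
//! fn main() {
//!     let alpha = 0.05;
//!     // Noise-free objective with root 2.5: the sign at 3.0 is resolved quickly.
//!     println!("{:?}", sign_test(|x| x - 2.5, 3.0, alpha, 1000));
//!     // Perturbation that flips the observed sign on every call: never resolved.
//!     let mut calls = 0u32;
//!     let flipping = move |x: f64| {
//!         calls += 1;
//!         x - 2.5 + if calls % 2 == 0 { 1.0 } else { -1.0 }
//!     };
//!     println!("{:?}", sign_test(flipping, 3.0, alpha, 1000));
//! }
//! ```
//!
//! The boundary grows only like `sqrt(n ln n)`, while a consistent sign bias
//! makes the running sum grow linearly in `n`, so any fixed bias is eventually
//! detected; how many evaluations that may take is what `max_sign_evaluations`
//! bounds.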
//!
//! Internally, posterior mass is stored over a partition of the search domain.
//! The posterior is represented in log-space for numerical stability.
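//!
//! The partition, the log-space representation, and steps 1 and 4 can be
//! sketched on a fixed grid of equal-width cells over `[0, 1]`. This is a
//! simplified illustration under stated assumptions, not the crate's
//! `PosteriorDistribution`: the query point is taken to be the posterior median
//! snapped to a cell edge, each reported sign is assumed correct with
//! probability `p`, and the grid is never refined (the crate's partition is
//! instead bounded by `max_knots`). `LogPosterior` is a hypothetical name.
//!
//! ```rust
//! /// Hypothetical log-space posterior over `n` equal-width cells of `[0, 1]`.
//! struct LogPosterior {
//!     log_mass: Vec<f64>,
//! }
//!
//! impl LogPosterior {
//!     /// Uniform prior: every cell has mass `1 / n`.
//!     fn uniform(n: usize) -> Self {
//!         Self { log_mass: vec![-(n as f64).ln(); n] }
//!     }
//!
//!     fn cell_width(&self) -> f64 {
//!         1.0 / self.log_mass.len() as f64
//!     }
//!
//!     /// Assumed query rule: the first cell edge where cumulative mass reaches 1/2.
//!     fn median(&self) -> f64 {
//!         let mut cumulative = 0.0;
//!         for (i, lm) in self.log_mass.iter().enumerate() {
//!             cumulative += lm.exp();
//!             if cumulative >= 0.5 {
//!                 return (i + 1) as f64 * self.cell_width();
//!             }
//!         }
//!         1.0
//!     }
//!
//!     /// Multiplies the mass of cells on the side indicated by the sign test by
//!     /// `p` and the rest by `1 - p`, then renormalises with log-sum-exp.
//!     fn update(&mut self, x: f64, root_is_right: bool, p: f64) {
//!         let w = self.cell_width();
//!         for (i, lm) in self.log_mass.iter_mut().enumerate() {
//!             let centre = (i as f64 + 0.5) * w;
//!             let on_indicated_side = (centre > x) == root_is_right;
//!             *lm += if on_indicated_side { p.ln() } else { (1.0 - p).ln() };
//!         }
//!         let max = self.log_mass.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
//!         let log_total = max + self.log_mass.iter().map(|lm| (lm - max).exp()).sum::<f64>().ln();
//!         for lm in &mut self.log_mass {
//!             *lm -= log_total;
//!         }
//!     }
//! }
//!
//! fn main() {
//!     let root = 0.314; // unknown to the algorithm; only used to answer sign queries
//!     let p = 0.8; // assumed probability that each reported sign is correct
//!     let mut posterior = LogPosterior::uniform(100);
//!     for _ in 0..30 {
//!         let x = posterior.median();
//!         // Noise-free oracle for f(x) = x - root: f(x) < 0 means the root lies right of x.
//!         let root_is_right = x - root < 0.0;
//!         posterior.update(x, root_is_right, p);
//!     }
//!     println!("query point after 30 updates: {:.3}", posterior.median());
//! }
//! ```
//!
//! Renormalising with log-sum-exp after every update keeps the stored values
//! finite even after many multiplicative updates, where masses stored directly
//! would eventually underflow to zero.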
//!
//! ## Coordinate scaling
//!
//! The posterior is represented on a numerically convenient internal domain,
//! usually `[0, 1]`. User objective functions are evaluated on their original
//! raw domain.
//!
//! A `Scaler` maps between these coordinate systems. For strictly positive
//! domains spanning several orders of magnitude, logarithmic scaling may be used
//! so that posterior resolution is distributed multiplicatively rather than
//! linearly.
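//!
//! The two coordinate maps can be sketched as follows. `Scale` is a
//! hypothetical stand-in for the crate's internal `Scaler`, whose actual
//! interface may differ, and the sketch assumes the internal domain is
//! `[0, 1]`. Under logarithmic scaling the internal midpoint `0.5` maps to the
//! geometric mean of the raw bounds rather than their arithmetic mean.
//!
//! ```rust
//! /// Hypothetical stand-in for the crate's internal `Scaler`: maps a raw
//! /// interval `[lo, hi]` onto `[0, 1]` and back.
//! enum Scale {
//!     Linear { lo: f64, hi: f64 },
//!     /// Logarithmic map; assumes `0 < lo < hi`.
//!     Log { lo: f64, hi: f64 },
//! }
//!
//! impl Scale {
//!     /// Raw coordinate to internal coordinate in `[0, 1]`.
//!     fn to_internal(&self, x: f64) -> f64 {
//!         match *self {
//!             Scale::Linear { lo, hi } => (x - lo) / (hi - lo),
//!             Scale::Log { lo, hi } => (x.ln() - lo.ln()) / (hi.ln() - lo.ln()),
//!         }
//!     }
//!
//!     /// Internal coordinate in `[0, 1]` to raw coordinate.
//!     fn to_raw(&self, u: f64) -> f64 {
//!         match *self {
//!             Scale::Linear { lo, hi } => lo + u * (hi - lo),
//!             Scale::Log { lo, hi } => (lo.ln() + u * (hi.ln() - lo.ln())).exp(),
//!         }
//!     }
//! }
//!
//! fn main() {
//!     let linear = Scale::Linear { lo: 1e-3, hi: 1e3 };
//!     let logarithmic = Scale::Log { lo: 1e-3, hi: 1e3 };
//!     // The internal midpoint lands near 500 under the linear map but near 1
//!     // (the geometric mean of the bounds) under the logarithmic one.
//!     println!("linear: {}", linear.to_raw(0.5));
//!     println!("logarithmic: {}", logarithmic.to_raw(0.5));
//! }
//! ```
//!
//! Round-tripping a point through `to_internal` and `to_raw` recovers it up to
//! floating-point rounding.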
//!
//! ## Implementing an objective
//!
//! Problems are represented by implementing [`RootOracle`].
//!
//! The oracle provides noisy evaluations of the objective. Implementations may
//! be deterministic or stochastic.
//!
//! ```rust
//! use probabilistic_bisector::{run, BisectorConfig, ConfidenceLevel, RootOracle};
//!
//! struct Linear {
//!     root: f64,
//! }
//!
//! impl RootOracle<f64> for Linear {
//!     fn evaluate(&mut self, x: f64) -> f64 {
//!         x - self.root
//!     }
//! }
//!
//! fn main() -> Result<(), Box<dyn std::error::Error>> {
//!     let config = BisectorConfig {
//!         max_observations: 10000,    // Maximum solver iterations (observations)
//!         max_knots: 1000,            // Maximum knots in the posterior distribution
//!         max_sign_evaluations: 1000, // Maximum objective calls per sign test
//!         rel_tol: 1e-5,              // Target relative tolerance
//!         tolerance_window: 10,       // Iterations over which the tolerance must remain stable
//!     };
//!
//!     let result = run(
//!         0.0..10.0,                  // Search range
//!         ConfidenceLevel::new(0.8)?, // Required confidence level
//!         Linear { root: 2.5 },       // Problem
//!         config,
//!     )?;
//!
//!     println!("root interval: {:?}", result.interval);
//!     println!("termination: {:?}", result.termination);
//!     Ok(())
//! }
//! ```
//!
//! ## Stochastic objectives
//!
//! A `RootOracle` may use internal randomness. Since [`RootOracle::evaluate`]
//! takes `&mut self`, objectives can store their own random number generator,
//! allowing reproducible seeded tests.
//!
//! ```rust
//! use probabilistic_bisector::RootOracle;
//!
//! struct NoisyLinear {
//!     root: f64,
//!     index: usize,
//! }
//!
//! impl RootOracle<f64> for NoisyLinear {
//!     fn evaluate(&mut self, x: f64) -> f64 {
//!         // Deterministic alternating perturbation of +/-0.001.
//!         let noise = if self.index % 2 == 0 { 0.001 } else { -0.001 };
//!         self.index += 1;
//!         x - self.root + noise
//!     }
//! }
//! ```
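//!
//! The example above uses a deterministic alternating perturbation. A sketch of
//! an objective that holds its own seeded generator follows. To stay
//! self-contained it uses a small hypothetical linear congruential generator
//! (`Lcg`, with Knuth's MMIX constants) instead of an external crate such as
//! `rand`, and the hypothetical `SeededNoisyLinear` exposes an inherent
//! `evaluate` method with the same shape as `RootOracle::evaluate` rather than
//! implementing the trait.
//!
//! ```rust
//! /// Hypothetical minimal 64-bit linear congruential generator (Knuth's MMIX
//! /// constants). Adequate for a reproducible illustration, not for serious
//! /// statistical work.
//! struct Lcg(u64);
//!
//! impl Lcg {
//!     /// Advances the state and returns a sample in `[0, 1)` built from the
//!     /// top 53 bits.
//!     fn next_f64(&mut self) -> f64 {
//!         self.0 = self
//!             .0
//!             .wrapping_mul(6364136223846793005)
//!             .wrapping_add(1442695040888963407);
//!         (self.0 >> 11) as f64 / (1u64 << 53) as f64
//!     }
//! }
//!
//! /// Linear objective `x - root` plus uniform noise in `[-scale, scale)`.
//! struct SeededNoisyLinear {
//!     root: f64,
//!     scale: f64,
//!     rng: Lcg,
//! }
//!
//! impl SeededNoisyLinear {
//!     fn new(root: f64, scale: f64, seed: u64) -> Self {
//!         Self { root, scale, rng: Lcg(seed) }
//!     }
//!
//!     /// Same shape as `RootOracle::evaluate`: `&mut self` lets the generator advance.
//!     fn evaluate(&mut self, x: f64) -> f64 {
//!         let noise = (2.0 * self.rng.next_f64() - 1.0) * self.scale;
//!         x - self.root + noise
//!     }
//! }
//!
//! fn main() {
//!     // Two oracles built from the same seed produce identical observations.
//!     let mut a = SeededNoisyLinear::new(2.5, 0.1, 42);
//!     let mut b = SeededNoisyLinear::new(2.5, 0.1, 42);
//!     for _ in 0..5 {
//!         assert_eq!(a.evaluate(3.0), b.evaluate(3.0));
//!     }
//! }
//! ```
//!
//! Changing the seed gives a different but equally reproducible noise
//! sequence; in a real objective the same struct would implement
//! `RootOracle<f64>` instead of providing an inherent method.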
//!
//! ## Termination
//!
//! The solver may terminate because:
//!
//! - the requested tolerance was reached,
//! - the maximum observation budget (`max_observations`) was reached,
//! - or the objective sign became indeterminate at all useful query points.
//!
//! Sign indeterminacy is not necessarily an error. It usually means that the
//! algorithm has reached the noise floor of the objective: additional samples at
//! nearby points do not determine the sign reliably within the configured
//! evaluation budget (`max_sign_evaluations`). Given the noise in the objective
//! function, the requested tolerance may then be unachievable.
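//!
//! The sequential interval (step 5) and a tolerance-based stopping rule can be
//! sketched together. This is a hypothetical illustration, not the crate's
//! logic: the crate delegates stopping to `trellis_runner`'s
//! `TargetValuePolicy`, whose exact criterion is not shown here. The sketch
//! assumes that relative tolerance means interval width divided by the
//! magnitude of the interval midpoint, and that a disjoint confidence region
//! leaves the running interval unchanged. `Bracket` and `ToleranceTracker` are
//! hypothetical names.
//!
//! ```rust
//! /// Closed interval `[lo, hi]` (hypothetical helper).
//! #[derive(Clone, Copy, Debug, PartialEq)]
//! struct Bracket {
//!     lo: f64,
//!     hi: f64,
//! }
//!
//! impl Bracket {
//!     /// Intersection of two intervals, or `None` if they are disjoint.
//!     fn meet(self, other: Bracket) -> Option<Bracket> {
//!         let lo = self.lo.max(other.lo);
//!         let hi = self.hi.min(other.hi);
//!         if lo <= hi {
//!             Some(Bracket { lo, hi })
//!         } else {
//!             None
//!         }
//!     }
//!
//!     /// Width relative to the magnitude of the midpoint (assumed definition).
//!     fn relative_width(self) -> f64 {
//!         let mid = 0.5 * (self.lo + self.hi);
//!         (self.hi - self.lo) / mid.abs().max(f64::MIN_POSITIVE)
//!     }
//! }
//!
//! /// Hypothetical tracker: intersects each new confidence region into a running
//! /// interval and stops once the relative width has been at most `rel_tol` for
//! /// `window` consecutive updates.
//! struct ToleranceTracker {
//!     current: Bracket,
//!     rel_tol: f64,
//!     window: usize,
//!     streak: usize,
//! }
//!
//! impl ToleranceTracker {
//!     /// Returns `true` once the stopping rule is satisfied.
//!     fn observe(&mut self, region: Bracket) -> bool {
//!         // Assumption: a disjoint region is ignored rather than treated as an error.
//!         if let Some(next) = self.current.meet(region) {
//!             self.current = next;
//!         }
//!         if self.current.relative_width() <= self.rel_tol {
//!             self.streak += 1;
//!         } else {
//!             self.streak = 0;
//!         }
//!         self.streak >= self.window
//!     }
//! }
//!
//! fn main() {
//!     let mut tracker = ToleranceTracker {
//!         current: Bracket { lo: 0.0, hi: 10.0 },
//!         rel_tol: 1e-2,
//!         window: 3,
//!         streak: 0,
//!     };
//!     // Made-up regions that halve in width around a root at 2.5 on each step.
//!     for step in 1..=20 {
//!         let half = 5.0 * 0.5_f64.powi(step);
//!         if tracker.observe(Bracket { lo: 2.5 - half, hi: 2.5 + half }) {
//!             println!("stopped at step {step} with {:?}", tracker.current);
//!             break;
//!         }
//!     }
//! }
//! ```
//!
//! With these made-up regions the relative width first drops to `1e-2` or
//! below at step 9, so the three-step window ends the run at step 11.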
//!
//! ## Output
//!
//! Successful runs return a [`ProbabilisticBisectionResult`] containing:
//!
//! - a confidence interval for the root,
//! - execution summary information,
//! - and the reason the solver terminated.
//!
//! Errors, returned as [`ProbabilisticBisectionError`], are reserved for invalid
//! inputs, invalid oracle values, or internal invariant violations.

mod distribution;
mod error;
mod evolution;
pub(crate) mod intervals;
mod result;
mod root;
mod scaling;
mod solver;
pub(crate) mod support;

use evolution::InferenceState;

pub use confi::ConfidenceLevel;
pub use root::{ObjectiveSign, RootError, RootOracle, RootSide};

use distribution::{PosteriorDistribution, PosteriorError};
pub use error::PBError;
pub(crate) use evolution::BisectionError;

use intervals::{Interval, IntervalError, MeetError, SequentialInterval};
use scaling::{Scaler, ScalerError};
use support::SupportSet;

use solver::RootFinder;

pub use result::{ProbabilisticBisectionError, ProbabilisticBisectionResult, SolverDiagnostics};

use num_traits::{Float, FromPrimitive};
use std::ops::Range;
use trellis_runner::{
    GenerateBuilderFallible, MaxIterationPolicy, TargetValuePolicy, TrellisFloat,
};

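/// Budgets and stopping criteria for [`run`].
pub struct BisectorConfig<T> {
    /// Maximum number of solver iterations (observations).
    pub max_observations: usize,
    /// Maximum number of knots in the posterior distribution.
    pub max_knots: usize,
    /// Maximum number of objective evaluations spent on a single sign test.
    pub max_sign_evaluations: usize,
    /// Target relative tolerance.
    pub rel_tol: T,
    /// Number of iterations over which the relative tolerance must remain stable.
    pub tolerance_window: usize,
}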

/// Runs probabilistic bisection for a root of `problem` on `domain`.
///
/// On success, returns a confidence interval for the root at
/// `confidence_level`, together with the termination reason and a run summary.
/// Errors are returned for invalid inputs and for failures raised while the
/// solver is running.
#[allow(clippy::result_large_err)]
pub fn run<T, P>(
    domain: Range<T>,
    confidence_level: ConfidenceLevel<T>,
    problem: P,
    config: BisectorConfig<T>,
) -> Result<ProbabilisticBisectionResult<T>, ProbabilisticBisectionError<T>>
where
    T: TrellisFloat
        + Float
        + FromPrimitive
        + std::ops::AddAssign
        + std::iter::Sum
        + Send
        + Sync
        + 'static,
    P: RootOracle<T>,
{
    // Construction can fail; errors convert into `ProbabilisticBisectionError`.
    let root_finder = RootFinder::new(domain, confidence_level, config.max_sign_evaluations)?;
    // Initial posterior state on the root finder's scaled (internal) domain.
    let state = InferenceState::new(root_finder.scaled_domain().clone(), config.max_knots)?;

    let engine = <RootFinder<T> as GenerateBuilderFallible>::build_for(root_finder, problem)
        // Stop after at most `max_observations` iterations...
        .and_policy(MaxIterationPolicy::new(config.max_observations))
        // ...or when the target-value policy (relative tolerance over
        // `tolerance_window` iterations) is satisfied.
        .and_policy(TargetValuePolicy::new(
            T::zero(),
            config.rel_tol,
            config.tolerance_window,
        ))
        .with_initial_state(state)
        .finalise();

    match engine.run() {
        Ok(output) => Ok(ProbabilisticBisectionResult {
            interval: output.result.current,
            termination: output.termination,
            summary: output.summary,
        }),
        Err(trellis_runner::EngineFailure::Procedure { error, state }) => {
            Err(ProbabilisticBisectionError::Running {
                error,
                summary: state.run_summary(),
                state: state.user,
            })
        }
    }
}