use std::collections::{HashMap, VecDeque};
use std::sync::Arc;
use std::thread;
use crossbeam_channel::{Receiver, Sender, unbounded};
use num_bigint::{BigInt, BigUint};
use num_traits::Zero;
use crate::binary::Bounds;
use crate::binary::{UBinary, UXBinary, XBinary};
use crate::concurrency::StopFlag;
use crate::error::ComputableError;
use crate::node::Node;
/// A bit count that may be unbounded.
///
/// Used as a tolerance exponent by `bounds_width_leq`: `Finite(e)` accepts a
/// width of at most `2^-e`, while `Inf` accepts only a zero-width interval.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum XUsize {
    /// A concrete number of bits.
    Finite(usize),
    /// No finite bound.
    Inf,
}

/// Instruction sent from the coordinator to a refiner worker thread.
#[derive(Clone, Copy)]
pub enum RefineCommand {
    /// Perform one refinement step, forwarding `precision_bits` to
    /// `Node::refine_step`.
    Step { precision_bits: usize },
    /// Ask the worker to exit its loop.
    Stop,
}

/// Coordinator-side handle to a single refiner worker thread.
pub struct RefinerHandle {
    /// Channel on which `RefineCommand`s are delivered to the worker.
    pub sender: Sender<RefineCommand>,
}

/// Freshly computed bounds for the node identified by `node_id`.
#[derive(Clone)]
pub struct NodeUpdate {
    /// Id of the node whose bounds changed.
    pub node_id: usize,
    /// The node's new bounds.
    pub bounds: Bounds,
}

/// Why a refiner stopped accepting further steps.
enum ExhaustionReason {
    /// `refine_step` reported that no further progress is possible (`Ok(false)`).
    Converged,
    /// `refine_step` failed with `ComputableError::StateUnchanged`.
    StateUnchanged,
}

/// A worker's reply to one `RefineCommand::Step`.
enum RefinerMessage {
    /// The step made progress; carries the refiner's recomputed bounds.
    Update(NodeUpdate),
    /// The refiner is finished; carries its final bounds and the reason.
    Exhausted {
        update: NodeUpdate,
        reason: ExhaustionReason,
    },
    /// The step failed with an error the coordinator should propagate.
    Error(ComputableError),
}

/// The expression DAG reachable from `root`, indexed for refinement.
pub struct RefinementGraph {
    /// The node whose bounds `refine_to` tightens.
    pub root: Arc<Node>,
    /// Every node reachable from `root`, keyed by node id.
    pub nodes: HashMap<usize, Arc<Node>>,
    /// Child id -> ids of the nodes that list it among their children.
    pub parents: HashMap<usize, Vec<usize>>,
    /// Reachable nodes for which `is_refiner()` is true.
    pub refiners: Vec<Arc<Node>>,
}
impl RefinementGraph {
    /// Builds the refinement graph for everything reachable from `root`.
    ///
    /// Traverses the DAG with an explicit stack, visiting each node id once, and records:
    /// - `nodes`: every reachable node, keyed by id;
    /// - `parents`: for each child id, the id of every node that lists it as a
    ///   child (one entry per occurrence in that node's `children()`);
    /// - `refiners`: reachable nodes whose `is_refiner()` is true, in visit order.
    ///
    /// # Errors
    /// Currently never fails; the `Result` is part of the public signature.
    pub fn new(root: Arc<Node>) -> Result<Self, ComputableError> {
        let mut nodes = HashMap::new();
        let mut parents: HashMap<usize, Vec<usize>> = HashMap::new();
        let mut refiners = Vec::new();
        let mut stack = vec![Arc::clone(&root)];
        while let Some(node) = stack.pop() {
            // A node reachable along several paths is processed only once.
            if nodes.contains_key(&node.id) {
                continue;
            }
            let node_id = node.id;
            nodes.insert(node_id, Arc::clone(&node));
            if node.is_refiner() {
                refiners.push(Arc::clone(&node));
            }
            for child in node.children() {
                parents.entry(child.id).or_default().push(node_id);
                stack.push(child);
            }
        }
        let graph = Self {
            root,
            nodes,
            parents,
            refiners,
        };
        Ok(graph)
    }

    /// Refines until the root's bounds have width at most `2^-tolerance_exp`
    /// (exactly zero for `XUsize::Inf`) and returns those bounds.
    ///
    /// One scoped worker thread is spawned per refiner. Each of at most
    /// `MAX_REFINEMENT_ITERATIONS` rounds:
    /// 1. returns the root's bounds if they already meet the tolerance;
    /// 2. sends a `Step` to every active refiner whose cached bounds do not
    ///    already meet the per-refiner budget from `compute_demand_budget`;
    /// 3. if step 2 skipped every active refiner, falls back to stepping each
    ///    active refiner whose width is not dominated (`is_width_dominated`)
    ///    by the widest active refiner;
    /// 4. waits for exactly one reply per `Step` sent, propagating each reply
    ///    through the graph and returning as soon as the root is precise enough.
    ///
    /// After the last round the root is checked once more, so a root that is
    /// already precise (for example with a limit of zero rounds) is returned
    /// instead of being reported as a failure. Workers are always told to stop
    /// before this method returns.
    ///
    /// # Errors
    /// - `StateUnchanged` if every refiner became inactive and each one reported
    ///   an unchanged state;
    /// - `MaxRefinementIterations` if the round limit is reached, or if every
    ///   refiner became inactive and at least one had converged, without the
    ///   tolerance being met;
    /// - `RefinementChannelClosed` if a worker channel closes unexpectedly;
    /// - any error reported by a refiner or by bounds computation/propagation.
    pub fn refine_to<const MAX_REFINEMENT_ITERATIONS: usize>(
        &self,
        tolerance_exp: &XUsize,
    ) -> Result<Bounds, ComputableError> {
        let mut outcome = None;
        thread::scope(|scope| {
            let stop_flag = Arc::new(StopFlag::new());
            let mut refiners = Vec::new();
            let (update_tx, update_rx) = unbounded();
            for node in &self.refiners {
                refiners.push(spawn_refiner(
                    scope,
                    Arc::clone(node),
                    Arc::clone(&stop_flag),
                    update_tx.clone(),
                ));
            }
            // Only the workers hold update senders from here on, so `recv`
            // reports disconnection once all of them have exited.
            drop(update_tx);
            let shutdown_refiners =
                |handles: Vec<RefinerHandle>, stop_signal: &Arc<StopFlag>| {
                    stop_signal.stop();
                    for refiner in &handles {
                        let _shutdown = refiner.sender.send(RefineCommand::Stop);
                    }
                };
            let result = (|| {
                let num_refiners = refiners.len();
                // `active[i]` says whether `self.refiners[i]` (driven through the
                // handle `refiners[i]`) may still be stepped.
                let mut active = vec![true; num_refiners];
                // Stays true only while every exhaustion seen so far was
                // `ExhaustionReason::StateUnchanged`.
                let mut all_state_unchanged = true;
                let precision_met =
                    |root: &Arc<Node>, tol: &XUsize| -> Result<bool, ComputableError> {
                        let bounds = root.get_bounds()?;
                        Ok(bounds_width_leq(&bounds, tol))
                    };
                for _iteration in 0..MAX_REFINEMENT_ITERATIONS {
                    if precision_met(&self.root, tolerance_exp)? {
                        return self.root.get_bounds();
                    }
                    let active_count = active.iter().filter(|&&a| a).count();
                    if active_count == 0 {
                        return if all_state_unchanged {
                            Err(ComputableError::StateUnchanged)
                        } else {
                            Err(ComputableError::MaxRefinementIterations {
                                max: MAX_REFINEMENT_ITERATIONS,
                            })
                        };
                    }
                    let demand_budget = compute_demand_budget(tolerance_exp, active_count);
                    let precision_bits = match demand_budget {
                        XUsize::Finite(n) => n,
                        XUsize::Inf => usize::MAX,
                    };
                    // Number of replies to wait for this round (one per Step sent).
                    let mut expected = 0_usize;
                    for (i, refiner) in refiners.iter().enumerate() {
                        if !active[i] {
                            continue;
                        }
                        // Refiners already within their share of the budget sit out.
                        let skip = self.refiners[i]
                            .cached_bounds()
                            .is_some_and(|b| bounds_width_leq(&b, &demand_budget));
                        if !skip {
                            refiner
                                .sender
                                .send(RefineCommand::Step { precision_bits })
                                .map_err(|_send_err| ComputableError::RefinementChannelClosed)?;
                            expected = expected.checked_add(1).unwrap_or_else(|| {
                                unreachable!("expected <= refiners.len(), cannot overflow usize")
                            });
                        }
                    }
                    // Safety valve: every active refiner met the budget, yet the root
                    // is still not precise enough. Step every refiner that is not
                    // dominated by the widest one, so the round is not wasted.
                    if expected == 0 {
                        let max_width = (0..refiners.len())
                            .filter(|&i| active[i])
                            .map(|i| {
                                self.refiners[i]
                                    .cached_bounds()
                                    .map_or(UXBinary::Inf, |b| b.width().clone())
                            })
                            .max();
                        for (i, refiner) in refiners.iter().enumerate() {
                            if !active[i] {
                                continue;
                            }
                            let dominated = max_width.as_ref().is_some_and(|max_w| {
                                self.refiners[i]
                                    .cached_bounds()
                                    .is_some_and(|b| is_width_dominated(b.width(), max_w))
                            });
                            if !dominated {
                                refiner
                                    .sender
                                    .send(RefineCommand::Step { precision_bits })
                                    .map_err(|_send_err| {
                                        ComputableError::RefinementChannelClosed
                                    })?;
                                expected = expected.checked_add(1).unwrap_or_else(|| {
                                    unreachable!(
                                        "expected <= refiners.len(), cannot overflow usize"
                                    )
                                });
                            }
                        }
                    }
                    for _ in 0..expected {
                        let message = match update_rx.recv() {
                            Ok(msg) => msg,
                            Err(_) => return Err(ComputableError::RefinementChannelClosed),
                        };
                        match message {
                            RefinerMessage::Update(update) => {
                                self.apply_update(update)?;
                                if precision_met(&self.root, tolerance_exp)? {
                                    return self.root.get_bounds();
                                }
                            }
                            RefinerMessage::Exhausted { update, reason } => {
                                let exhausted_node_id = update.node_id;
                                self.apply_update(update)?;
                                // The worker has exited; never step it again.
                                if let Some(i) = self
                                    .refiners
                                    .iter()
                                    .position(|refiner_node| refiner_node.id == exhausted_node_id)
                                {
                                    active[i] = false;
                                }
                                if !matches!(reason, ExhaustionReason::StateUnchanged) {
                                    all_state_unchanged = false;
                                }
                                if precision_met(&self.root, tolerance_exp)? {
                                    return self.root.get_bounds();
                                }
                            }
                            RefinerMessage::Error(error) => {
                                return Err(error);
                            }
                        }
                    }
                }
                // With a limit of zero no round runs at all, so check the root
                // here rather than failing a request that is already satisfied.
                if precision_met(&self.root, tolerance_exp)? {
                    return self.root.get_bounds();
                }
                Err(ComputableError::MaxRefinementIterations {
                    max: MAX_REFINEMENT_ITERATIONS,
                })
            })();
            shutdown_refiners(refiners, &stop_flag);
            outcome = Some(result);
        });
        match outcome {
            Some(result) => result,
            None => Err(ComputableError::RefinementChannelClosed),
        }
    }

    /// Stores `update.bounds` on the updated node and propagates the change upward.
    ///
    /// Parents are recomputed in FIFO (breadth-first) order with `compute_bounds`.
    /// A parent is re-enqueued only when its recomputed bounds differ from its
    /// cached bounds. An update whose `node_id` is not in `nodes` is ignored.
    ///
    /// # Errors
    /// Returns `RefinementChannelClosed` if a recorded parent id is missing from
    /// `nodes`, or any error produced by a parent's `compute_bounds`.
    fn apply_update(&self, update: NodeUpdate) -> Result<(), ComputableError> {
        let mut queue = VecDeque::new();
        if let Some(node) = self.nodes.get(&update.node_id) {
            node.set_bounds(update.bounds);
            queue.push_back(node.id);
        }
        while let Some(changed_id) = queue.pop_front() {
            let Some(parents) = self.parents.get(&changed_id) else {
                continue;
            };
            for parent_id in parents {
                let parent = self
                    .nodes
                    .get(parent_id)
                    .ok_or(ComputableError::RefinementChannelClosed)?;
                let next_bounds = parent.compute_bounds()?;
                // Stop propagating along this path once nothing changes.
                if parent.cached_bounds().as_ref() != Some(&next_bounds) {
                    parent.set_bounds(next_bounds);
                    queue.push_back(*parent_id);
                }
            }
        }
        Ok(())
    }
}
fn spawn_refiner<'scope, 'env>(
scope: &'scope thread::Scope<'scope, 'env>,
node: Arc<Node>,
stop: Arc<StopFlag>,
updates: Sender<RefinerMessage>,
) -> RefinerHandle {
let (command_tx, command_rx) = unbounded();
scope.spawn(move || {
refiner_loop(node, stop, command_rx, updates);
});
RefinerHandle { sender: command_tx }
}
fn refiner_loop(
node: Arc<Node>,
stop: Arc<StopFlag>,
commands: Receiver<RefineCommand>,
updates: Sender<RefinerMessage>,
) {
while !stop.is_stopped() {
match commands.recv() {
Ok(RefineCommand::Step { precision_bits }) => match node.refine_step(precision_bits) {
Ok(true) => {
let bounds = match node.compute_bounds() {
Ok(b) => b,
Err(e) => {
let _send = updates.send(RefinerMessage::Error(e));
break;
}
};
node.set_bounds(bounds.clone());
if updates
.send(RefinerMessage::Update(NodeUpdate {
node_id: node.id,
bounds,
}))
.is_err()
{
break;
}
}
Ok(false) => {
let bounds = match node.compute_bounds() {
Ok(b) => b,
Err(e) => {
let _send = updates.send(RefinerMessage::Error(e));
break;
}
};
node.set_bounds(bounds.clone());
let _send = updates.send(RefinerMessage::Exhausted {
update: NodeUpdate {
node_id: node.id,
bounds,
},
reason: ExhaustionReason::Converged,
});
break;
}
Err(ComputableError::StateUnchanged) => {
let bounds = node
.cached_bounds()
.unwrap_or_else(|| Bounds::new(XBinary::NegInf, XBinary::PosInf));
let _send = updates.send(RefinerMessage::Exhausted {
update: NodeUpdate {
node_id: node.id,
bounds,
},
reason: ExhaustionReason::StateUnchanged,
});
break;
}
Err(error) => {
let _send = updates.send(RefinerMessage::Error(error));
break;
}
},
Ok(RefineCommand::Stop) | Err(_) => break,
}
}
}
/// Precision demand handed to each active refiner for one round.
///
/// An infinite tolerance exponent yields an infinite demand. A finite exponent
/// `e` yields `e + bit_length(num_active)`, where `bit_length(n)` is
/// `usize::BITS - n.leading_zeros()` (1 -> 1, 2 -> 2, 3 -> 2, 4 -> 3).
/// The addition goes through `crate::sane_arithmetic!`. The extra bits grow
/// with the number of active refiners; presumably this gives each refiner a
/// proportionally smaller share of the tolerance.
fn compute_demand_budget(tolerance_exp: &XUsize, num_active: usize) -> XUsize {
    debug_assert!(
        num_active >= 1,
        "compute_demand_budget called with 0 active refiners"
    );
    let XUsize::Finite(exp) = tolerance_exp else {
        return XUsize::Inf;
    };
    let Some(bit_length) = usize::BITS.checked_sub(num_active.leading_zeros()) else {
        unreachable!("leading_zeros() is always <= usize::BITS")
    };
    #[allow(clippy::as_conversions)] // u32 -> usize never truncates on supported targets
    let shift = bit_length as usize;
    let tolerance = *exp;
    let budget = crate::sane_arithmetic!(tolerance, shift; tolerance + shift);
    XUsize::Finite(budget)
}
/// Binary orders of magnitude (a factor of `2^4 = 16`) by which a refiner's
/// width must undercut the widest active width before the safety valve in
/// `refine_to` leaves that refiner idle.
const SAFETY_VALVE_SKIP_SHIFT: i64 = 4;

/// Returns `true` when `width * 2^SAFETY_VALVE_SKIP_SHIFT <= max_width`.
///
/// An infinite `width` is never dominated.
fn is_width_dominated(width: &UXBinary, max_width: &UXBinary) -> bool {
    let UXBinary::Finite(finite) = width else {
        return false;
    };
    let scaled_exponent = finite.exponent() + BigInt::from(SAFETY_VALVE_SKIP_SHIFT);
    let scaled = UBinary::new(finite.mantissa().clone(), scaled_exponent);
    UXBinary::Finite(scaled) <= *max_width
}
/// Returns `true` when the width of `bounds` is at most `2^-tolerance_exp`.
///
/// An infinite width never qualifies. `XUsize::Inf` demands exactness: only a
/// width whose mantissa is zero passes.
pub fn bounds_width_leq(bounds: &Bounds, tolerance_exp: &XUsize) -> bool {
    let width = match bounds.width() {
        UXBinary::Finite(w) => w,
        UXBinary::Inf => return false,
    };
    if let XUsize::Finite(exp) = tolerance_exp {
        let limit = UBinary::new(BigUint::from(1u32), -BigInt::from(*exp));
        *width <= limit
    } else {
        width.mantissa().is_zero()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::binary::XBinary;
    use crate::computable::Computable;
    use crate::error::ComputableError;
    use crate::test_utils::{
        bin, interval_midpoint_computable, interval_noop_computable, interval_refine,
        midpoint_between, unwrap_finite, xbin,
    };
    use std::sync::{Arc, Barrier};
    use std::thread;
    use std::time::Duration;

    type IntervalState = Bounds;

    fn interval_bounds(state: &IntervalState) -> Bounds {
        state.clone()
    }

    // Keeps the lower bound and moves the upper bound to
    // `midpoint_between(lower, upper)`.
    fn interval_refine_strict(state: IntervalState) -> Result<IntervalState, ComputableError> {
        let midpoint = midpoint_between(state.small(), &state.large());
        Ok(Bounds::new(
            state.small().clone(),
            XBinary::Finite(midpoint),
        ))
    }

    fn sqrt_computable(value_int: u64) -> Computable {
        Computable::constant(bin(value_int as i64, 0))
            .nth_root(std::num::NonZeroU32::new(2).expect("2 is non-zero"))
    }

    fn assert_width_nonnegative(bounds: &Bounds) {
        assert!(*bounds.width() >= UXBinary::zero());
    }

    #[test]
    fn refine_to_accepts_zero_epsilon_for_exact_values() {
        let computable = interval_midpoint_computable(0, 2);
        let tolerance_exp = XUsize::Inf;
        let bounds = computable
            .refine_to_default(tolerance_exp)
            .expect("refine_to with epsilon=0 should succeed when bounds converge exactly");
        assert_eq!(bounds.small(), &xbin(1, 0));
        assert_eq!(bounds.large(), xbin(1, 0));
        assert!(matches!(bounds.width(), UXBinary::Finite(w) if w.mantissa().is_zero()));
    }

    #[test]
    fn refine_to_with_zero_epsilon_on_constant_succeeds_immediately() {
        let computable = Computable::constant(bin(42, 0));
        let tolerance_exp = XUsize::Inf;
        let bounds = computable
            .refine_to_default(tolerance_exp)
            .expect("refine_to with epsilon=0 should succeed for constants");
        assert_eq!(bounds.small(), &xbin(42, 0));
        assert_eq!(bounds.large(), xbin(42, 0));
    }

    #[test]
    fn refine_to_with_zero_epsilon_on_non_exact_value_returns_max_iterations() {
        let one = Computable::constant(bin(1, 0));
        let three = Computable::constant(bin(3, 0));
        let one_third = one / three;
        let tolerance_exp = XUsize::Inf;
        let result = one_third.refine_to::<10>(tolerance_exp);
        assert!(
            matches!(
                result,
                Err(ComputableError::MaxRefinementIterations { max: 10 })
            ),
            "expected MaxRefinementIterations error for non-exact value with epsilon=0, got {:?}",
            result
        );
    }

    #[test]
    fn refine_to_returns_refined_state() {
        let computable = interval_midpoint_computable(0, 2);
        let tolerance_exp = XUsize::Finite(1);
        let bounds = computable
            .refine_to_default(tolerance_exp)
            .expect("refine_to should succeed");
        let expected = xbin(1, 0);
        let upper = bounds.large();
        assert!(bounds.small() <= &expected && &expected <= &upper);
        assert!(bounds_width_leq(&bounds, &tolerance_exp));
        let refined_bounds = computable.bounds().expect("bounds should succeed");
        let refined_upper = refined_bounds.large();
        assert!(refined_bounds.small() <= &expected && &expected <= &refined_upper);
    }

    #[test]
    fn refine_to_rejects_unchanged_state() {
        let computable = interval_noop_computable(0, 2);
        let tolerance_exp = XUsize::Finite(2);
        let result = computable.refine_to_default(tolerance_exp);
        assert!(matches!(result, Err(ComputableError::StateUnchanged)));
    }

    #[test]
    fn refine_to_enforces_max_iterations() {
        let computable = Computable::new(
            0usize,
            |_| Ok(Bounds::new(XBinary::NegInf, XBinary::PosInf)),
            |state| Ok(state + 1),
        );
        let tolerance_exp = XUsize::Finite(1);
        let result = computable.refine_to::<5>(tolerance_exp);
        assert!(matches!(
            result,
            Err(ComputableError::MaxRefinementIterations { max: 5 })
        ));
    }

    #[test]
    fn refine_to_handles_non_meeting_bounds() {
        let interval_state = Bounds::new(xbin(0, 0), xbin(4, 0));
        let computable = Computable::new(
            interval_state,
            |inner_state| Ok(interval_bounds(inner_state)),
            interval_refine_strict,
        );
        let tolerance_exp = XUsize::Finite(1);
        let bounds = computable
            .refine_to_default(tolerance_exp)
            .expect("refine_to should succeed");
        let upper = bounds.large();
        assert!(bounds.small() < &upper);
        assert!(bounds_width_leq(&bounds, &tolerance_exp));
        assert_eq!(computable.bounds().expect("bounds should succeed"), bounds);
    }

    #[test]
    fn refine_to_rejects_worsened_bounds() {
        let interval_state = Bounds::new(xbin(0, 0), xbin(1, 0));
        let computable = Computable::new(
            interval_state,
            |inner_state| Ok(interval_bounds(inner_state)),
            |inner_state: IntervalState| {
                let upper = inner_state.large();
                let worse_upper = unwrap_finite(&upper).add(&bin(1, 0));
                Ok(Bounds::new(
                    inner_state.small().clone(),
                    XBinary::Finite(worse_upper),
                ))
            },
        );
        let tolerance_exp = XUsize::Finite(2);
        let result = computable.refine_to_default(tolerance_exp);
        assert!(matches!(result, Err(ComputableError::BoundsWorsened)));
    }

    #[test]
    fn refine_shared_clone_updates_original() {
        let original = sqrt_computable(2);
        let cloned = original.clone();
        let tolerance_exp = XUsize::Finite(12);
        let _bounds = cloned
            .refine_to_default(tolerance_exp)
            .expect("refine_to should succeed");
        let bounds = original.bounds().expect("bounds should succeed");
        assert!(bounds_width_leq(&bounds, &tolerance_exp));
    }

    #[test]
    fn refine_to_propagates_refiner_error() {
        let computable = Computable::new(
            0usize,
            |_| Ok(Bounds::new(XBinary::NegInf, XBinary::PosInf)),
            |_| Err(ComputableError::DomainError),
        );
        let tolerance_exp = XUsize::Finite(4);
        let result = computable.refine_to::<2>(tolerance_exp);
        assert!(matches!(result, Err(ComputableError::DomainError)));
    }

    #[test]
    fn refine_to_max_iterations_multiple_refiners() {
        let left = Computable::new(
            0usize,
            |_| Ok(Bounds::new(XBinary::NegInf, XBinary::PosInf)),
            |state| Ok(state + 1),
        );
        let right = Computable::new(
            0usize,
            |_| Ok(Bounds::new(XBinary::NegInf, XBinary::PosInf)),
            |state| Ok(state + 1),
        );
        let expr = left + right;
        let tolerance_exp = XUsize::Finite(4);
        let result = expr.refine_to::<2>(tolerance_exp);
        assert!(matches!(
            result,
            Err(ComputableError::MaxRefinementIterations { max: 2 })
        ));
    }

    #[test]
    fn refine_to_error_path_stops_refiners() {
        let stable = interval_midpoint_computable(0, 2);
        let faulty = Computable::new(
            Bounds::new(xbin(0, 0), xbin(1, 0)),
            |state| Ok(state.clone()),
            |state| Ok(Bounds::new(state.small().clone(), xbin(2, 0))),
        );
        let expr = stable + faulty;
        let tolerance_exp = XUsize::Finite(4);
        let result = expr.refine_to::<3>(tolerance_exp);
        assert!(matches!(result, Err(ComputableError::BoundsWorsened)));
    }

    #[test]
    fn concurrent_bounds_reads_during_failed_refinement() {
        let computable = Arc::new(Computable::new(
            0usize,
            |_| Ok(Bounds::new(XBinary::NegInf, XBinary::PosInf)),
            |state| Ok(state + 1),
        ));
        let tolerance_exp = XUsize::Finite(6);
        let reader = Arc::clone(&computable);
        let handle = thread::spawn(move || {
            for _ in 0..8 {
                let bounds = reader.bounds().expect("bounds should succeed");
                assert_width_nonnegative(&bounds);
            }
        });
        let result = computable.refine_to::<3>(tolerance_exp);
        assert!(matches!(
            result,
            Err(ComputableError::MaxRefinementIterations { max: 3 })
        ));
        handle.join().expect("reader thread should join");
    }

    #[test]
    fn refinement_parallelizes_multiple_refiners() {
        use std::time::Instant;
        // Each refiner sleeps this long per step. Four sequential steps would
        // take 4x as long, so a 2x bound leaves a full sleep's worth of slack
        // for thread start-up on a loaded machine.
        const SLEEP_MS: u64 = 50;
        let sleep_duration = Duration::from_millis(SLEEP_MS);
        let slow_refiner = || {
            Computable::new(
                0usize,
                |_| Ok(Bounds::new(XBinary::NegInf, XBinary::PosInf)),
                |state| {
                    thread::sleep(Duration::from_millis(SLEEP_MS));
                    Ok(state + 1)
                },
            )
        };
        let expr = slow_refiner() + slow_refiner() + slow_refiner() + slow_refiner();
        let tolerance_exp = XUsize::Finite(6);
        let start = Instant::now();
        let result = expr.refine_to::<1>(tolerance_exp);
        let elapsed = start.elapsed();
        assert!(matches!(
            result,
            Err(ComputableError::MaxRefinementIterations { max: 1 })
        ));
        assert!(
            elapsed >= sleep_duration,
            "expected at least one {SLEEP_MS}ms refinement step to run, elapsed {elapsed:?}"
        );
        assert!(
            elapsed < 2 * sleep_duration,
            "expected parallel refinement under {}ms, elapsed {elapsed:?}",
            2 * SLEEP_MS
        );
    }

    #[test]
    fn concurrent_refine_to_shared_expression() {
        let sqrt2 = sqrt_computable(2);
        let base_expression =
            (sqrt2.clone() + sqrt2.clone()) * (Computable::constant(bin(1, 0)) + sqrt2.clone());
        let expression = Arc::new(base_expression);
        let tolerance_exp = XUsize::Finite(10);
        let barrier = Arc::new(Barrier::new(4));
        let mut handles = Vec::new();
        for _ in 0..3 {
            let shared_expression = Arc::clone(&expression);
            let shared_barrier = Arc::clone(&barrier);
            handles.push(thread::spawn(move || {
                shared_barrier.wait();
                shared_expression.refine_to_default(tolerance_exp)
            }));
        }
        barrier.wait();
        let main_bounds = expression
            .refine_to_default(tolerance_exp)
            .expect("refine_to should succeed");
        let main_upper = main_bounds.large();
        assert!(bounds_width_leq(&main_bounds, &tolerance_exp));
        for handle in handles {
            let bounds = handle
                .join()
                .expect("thread should join")
                .expect("refine_to should succeed");
            let bounds_upper = bounds.large();
            assert_width_nonnegative(&bounds);
            assert!(bounds_width_leq(&bounds, &tolerance_exp));
            assert!(bounds.small() <= &main_upper);
            assert!(main_bounds.small() <= &bounds_upper);
        }
    }

    #[test]
    fn concurrent_refine_to_uses_single_refiner() {
        use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
        let active_refines = Arc::new(AtomicUsize::new(0));
        let saw_overlap = Arc::new(AtomicBool::new(false));
        let shared_active = Arc::clone(&active_refines);
        let shared_overlap = Arc::clone(&saw_overlap);
        let computable = Computable::new(
            Bounds::new(xbin(0, 0), xbin(4, 0)),
            |state| Ok(state.clone()),
            move |state: IntervalState| {
                let prior = shared_active.fetch_add(1, Ordering::SeqCst);
                if prior > 0 {
                    shared_overlap.store(true, Ordering::SeqCst);
                }
                thread::sleep(Duration::from_millis(10));
                let next = interval_refine(state);
                shared_active.fetch_sub(1, Ordering::SeqCst);
                next
            },
        );
        let shared = Arc::new(computable);
        let tolerance_exp = XUsize::Finite(6);
        let barrier = Arc::new(Barrier::new(3));
        let mut handles = Vec::new();
        for _ in 0..2 {
            let shared_value = Arc::clone(&shared);
            let shared_barrier = Arc::clone(&barrier);
            handles.push(thread::spawn(move || {
                shared_barrier.wait();
                shared_value
                    .refine_to_default(tolerance_exp)
                    .expect("refine_to should succeed")
            }));
        }
        barrier.wait();
        let main_bounds = shared
            .refine_to_default(tolerance_exp)
            .expect("refine_to should succeed");
        for handle in handles {
            let bounds = handle.join().expect("thread should join");
            assert_width_nonnegative(&bounds);
        }
        assert!(!saw_overlap.load(Ordering::SeqCst));
        assert!(bounds_width_leq(&main_bounds, &tolerance_exp));
    }

    #[test]
    fn concurrent_bounds_reads_during_refinement() {
        let base_value = interval_midpoint_computable(0, 4);
        let shared_value = Arc::new(base_value);
        let tolerance_exp = XUsize::Finite(8);
        let barrier = Arc::new(Barrier::new(2));
        let reader = {
            let reader_value = Arc::clone(&shared_value);
            let reader_barrier = Arc::clone(&barrier);
            thread::spawn(move || {
                reader_barrier.wait();
                for _ in 0..32 {
                    let bounds = reader_value.bounds().expect("bounds should succeed");
                    assert_width_nonnegative(&bounds);
                }
            })
        };
        barrier.wait();
        let refined = shared_value
            .refine_to_default(tolerance_exp)
            .expect("refine_to should succeed");
        reader.join().expect("reader should join");
        assert_width_nonnegative(&refined);
    }

    #[test]
    fn demand_skipping_unnecessarily_steps_already_precise_refiner() {
        use std::time::Instant;
        // Only `y` sleeps, so total elapsed time >= one sleep iff `y` was stepped.
        // `x` steps are fast. A shorter sleep keeps the test quick without
        // weakening that check.
        const SLOW_STEP_MS: u64 = 200;
        let x = Computable::new(
            Bounds::new(xbin(0, 0), xbin(1024, 0)),
            |state| Ok(state.clone()),
            interval_refine_strict,
        );
        let y = Computable::new(
            Bounds::new(xbin(0, 0), xbin(3, -3)),
            |state| Ok(state.clone()),
            move |state: Bounds| {
                thread::sleep(Duration::from_millis(SLOW_STEP_MS));
                interval_refine(state)
            },
        );
        let sum = x + y;
        let tolerance_exp = XUsize::Finite(1);
        let start = Instant::now();
        let bounds = sum
            .refine_to_default(tolerance_exp)
            .expect("refine_to should succeed");
        let elapsed = start.elapsed();
        assert!(
            bounds_width_leq(&bounds, &tolerance_exp),
            "bounds should meet target precision"
        );
        assert!(
            elapsed >= Duration::from_millis(SLOW_STEP_MS),
            "expected y to be stepped (demand budget flaw), but finished in {elapsed:?}"
        );
    }
}