computable 0.1.0

Computable real numbers with guaranteed correctness via interval refinement
Documentation
//! The main Computable type representing computable real numbers.
//!
//! A `Computable` is a real number that can be refined to arbitrary precision.
//! It is backed by a computation graph where leaf nodes contain user-defined
//! state and refinement logic, and interior nodes represent arithmetic operations.

use std::num::NonZeroU32;
use std::sync::Arc;

use num_bigint::BigInt;
use num_traits::One;

use crate::binary::Bounds;
use crate::binary::{Binary, XBinary};
use crate::error::ComputableError;
use crate::node::{BaseNode, Node, TypedBaseNode};
use crate::ops::{AddOp, BaseOp, InvOp, MulOp, NegOp, NthRootOp, PiOp, PowOp, SinOp};
use crate::refinement::{RefinementGraph, XUsize, bounds_width_leq};

use parking_lot::RwLock;

/// Default refinement budget for inverse computations: 64 in debug builds and
/// 4096 in release builds. NOTE(review): this constant is consumed outside this
/// file (presumably by the inverse op) — confirm its exact use there.
pub const DEFAULT_INV_MAX_REFINES: usize = if cfg!(debug_assertions) { 64 } else { 4096 };

/// Default iteration limit used by [`Computable::refine_to_default`]: 64 in
/// debug builds and 4096 in release builds.
pub const DEFAULT_MAX_REFINEMENT_ITERATIONS: usize =
    if cfg!(debug_assertions) { 64 } else { 4096 };

/// A computable number backed by a shared node graph.
#[derive(Clone)]
pub struct Computable {
    node: Arc<Node>,
}

impl Computable {
    /// Creates a new computable from user-defined state and refinement logic.
    ///
    /// # Arguments
    /// * `state` - Initial state for this computable
    /// * `bounds` - Function to compute bounds from the current state
    /// * `refine` - Function to refine the state to a more precise version
    pub fn new<X, B, F>(state: X, bounds: B, refine: F) -> Self
    where
        X: Eq + Clone + Send + Sync + 'static,
        B: Fn(&X) -> Result<Bounds, ComputableError> + Send + Sync + 'static,
        F: Fn(X) -> Result<X, ComputableError> + Send + Sync + 'static,
    {
        let base_node_struct = TypedBaseNode::new(state, bounds, refine);
        let base_node: Arc<dyn BaseNode> = Arc::new(base_node_struct);
        let node = Node::new(Arc::new(BaseOp { base: base_node }));
        Self { node }
    }

    /// Creates a Computable from a pre-built Node.
    pub(crate) fn from_node(node: Arc<Node>) -> Self {
        Self { node }
    }

    /// Returns the current bounds for this computable.
    pub fn bounds(&self) -> Result<Bounds, ComputableError> {
        self.node.get_bounds()
    }

    /// Refines this computable until the bounds width is at most 2^(-tolerance_exp).
    ///
    /// # Arguments
    /// * `tolerance_exp` - Tolerance exponent. `Finite(n)` requests width ≤ 2^(-n).
    ///   `Inf` requests exact bounds (width = 0).
    ///
    /// # Type Parameters
    /// * `MAX_REFINEMENT_ITERATIONS` - Maximum number of refinement iterations
    ///
    /// # Warning
    /// Using `XUsize::Inf` (epsilon = 0) will only succeed for values that can be
    /// represented exactly in binary (e.g., integers, dyadic rationals like 1/2 or 3/4).
    /// For values that cannot be exactly represented (e.g., 1/3, sqrt(2), pi),
    /// refinement will never achieve zero width and will return
    /// [`ComputableError::MaxRefinementIterations`] after exhausting the iteration limit.
    pub fn refine_to<const MAX_REFINEMENT_ITERATIONS: usize>(
        &self,
        tolerance_exp: XUsize,
    ) -> Result<Bounds, ComputableError> {
        loop {
            let bounds = self.node.get_bounds()?;
            if bounds_width_leq(&bounds, &tolerance_exp) {
                return Ok(bounds);
            }

            let mut state_guard = self.node.refinement.state.lock();
            if !state_guard.active {
                state_guard.active = true;
                drop(state_guard);

                let graph = RefinementGraph::new(Arc::clone(&self.node))?;
                let result = graph.refine_to::<MAX_REFINEMENT_ITERATIONS>(&tolerance_exp);

                let mut completion_guard = self.node.refinement.state.lock();
                completion_guard.active = false;
                self.node.refinement.condvar.notify_all();
                return result;
            }

            let observed_epoch = state_guard.epoch;
            self.node
                .refinement
                .condvar
                .wait_while(&mut state_guard, |guard| {
                    guard.active && guard.epoch == observed_epoch
                });
        }
    }

    /// Refines this computable using the default maximum iterations.
    pub fn refine_to_default(&self, tolerance_exp: XUsize) -> Result<Bounds, ComputableError> {
        self.refine_to::<DEFAULT_MAX_REFINEMENT_ITERATIONS>(tolerance_exp)
    }

    /// Returns the multiplicative inverse of this computable.
    pub fn inv(self) -> Self {
        let node = Node::new(Arc::new(InvOp {
            inner: Arc::clone(&self.node),
            newton_state: RwLock::new(None),
        }));
        Self { node }
    }

    /// Computes the sine of this computable number.
    ///
    /// Uses Taylor series with provably correct error bounds.
    /// The implementation uses directed rounding throughout: the lower bound computation
    /// rounds toward negative infinity and the upper bound rounds toward positive infinity.
    /// The error bound |x|^(2n+1)/(2n+1)! is also computed conservatively (rounded up)
    /// to ensure the true value is always contained within the returned bounds.
    pub fn sin(self) -> Self {
        let pi_node = Node::new(Arc::new(PiOp {
            num_terms: RwLock::new(crate::ops::pi::INITIAL_PI_TERMS),
        }));
        let node = Node::new(Arc::new(SinOp {
            inner: Arc::clone(&self.node),
            pi_node,
            num_terms: RwLock::new(BigInt::one()),
        }));
        Self { node }
    }

    /// Computes the n-th root of this computable number.
    ///
    /// Uses binary search (bisection) for guaranteed convergence with provably
    /// correct bounds. For each refinement step, the interval is halved.
    ///
    /// # Arguments
    /// * `degree` - The root degree (n in x^(1/n)). Must be >= 1, enforced by the type system.
    ///
    /// # Constraints
    /// - For even degrees (2, 4, 6, ...): requires non-negative input
    /// - For odd degrees (3, 5, 7, ...): supports all real inputs
    ///
    /// # Examples
    /// - `nth_root(NonZeroU32::new(2).unwrap())` computes the square root
    /// - `nth_root(NonZeroU32::new(3).unwrap())` computes the cube root
    /// - `nth_root(NonZeroU32::new(4).unwrap())` computes the fourth root
    pub fn nth_root(self, degree: NonZeroU32) -> Self {
        let node = Node::new(Arc::new(NthRootOp {
            inner: Arc::clone(&self.node),
            degree,
            bisection_state: RwLock::new(None),
        }));
        Self { node }
    }

    /// Raises this computable number to an integer power.
    ///
    /// Computes x^n for non-negative integer exponents. This is more efficient than
    /// repeated multiplication because it computes bounds directly using the
    /// monotonicity properties of power functions.
    ///
    /// # Arguments
    /// * `exponent` - The power to raise to (n in x^n).
    ///
    /// # Bounds Computation
    /// - For n=0: returns constant 1 (including 0^0 = 1 by convention)
    /// - For odd n: x^n is monotonically increasing, so bounds are [lower^n, upper^n]
    /// - For even n: x^n has a minimum at 0
    ///   - If interval is non-negative: [lower^n, upper^n]
    ///   - If interval is non-positive: [upper^n, lower^n]
    ///   - If interval spans zero: [0, max(|lower|^n, |upper|^n)]
    ///
    /// # Examples
    /// - `pow(0)` returns constant 1
    /// - `pow(2)` computes the square
    /// - `pow(3)` computes the cube
    pub fn pow(self, exponent: u32) -> Self {
        match std::num::NonZeroU32::new(exponent) {
            None => {
                // x^0 = 1 for all x, including 0^0 = 1 by convention
                // Check for infinite bounds - infinity^0 is an indeterminate form.
                if let Ok(bounds) = self.node.get_bounds() {
                    let has_infinite = matches!(bounds.small(), XBinary::NegInf | XBinary::PosInf)
                        || matches!(&bounds.large(), XBinary::NegInf | XBinary::PosInf);
                    if has_infinite {
                        crate::detected_computable_with_infinite_value!(
                            "input has infinite bounds for x^0 (infinity^0 is an indeterminate form)"
                        );
                    }
                }
                Computable::constant(Binary::new(
                    num_bigint::BigInt::from(1),
                    num_bigint::BigInt::from(0),
                ))
            }
            Some(nonzero_exp) => {
                let node = Node::new(Arc::new(PowOp {
                    inner: Arc::clone(&self.node),
                    exponent: nonzero_exp,
                }));
                Self { node }
            }
        }
    }

    /// Creates a constant computable with exact bounds.
    pub fn constant(value: Binary) -> Self {
        fn bounds(value: &Binary) -> Result<Bounds, ComputableError> {
            Ok(Bounds::new(
                XBinary::Finite(value.clone()),
                XBinary::Finite(value.clone()),
            ))
        }

        fn refine(value: Binary) -> Result<Binary, ComputableError> {
            Ok(value)
        }

        Computable::new(value, bounds, refine)
    }
}

impl From<Binary> for Computable {
    fn from(value: Binary) -> Self {
        Computable::constant(value)
    }
}

/// Negation builds a new `NegOp` node over this computable's node.
impl std::ops::Neg for Computable {
    type Output = Self;

    fn neg(self) -> Self::Output {
        // `self` is consumed, so move its node into the new op instead of cloning
        // the `Arc` and immediately dropping the original.
        let node = Node::new(Arc::new(NegOp { inner: self.node }));
        Self { node }
    }
}

/// Addition builds a new `AddOp` node over both operands' nodes.
impl std::ops::Add for Computable {
    type Output = Self;

    fn add(self, rhs: Self) -> Self::Output {
        // Both operands are consumed, so their nodes are moved rather than having
        // their reference counts bumped and then immediately dropped.
        let node = Node::new(Arc::new(AddOp {
            left: self.node,
            right: rhs.node,
        }));
        Self { node }
    }
}

/// Subtraction is expressed as addition of the negated right-hand side.
impl std::ops::Sub for Computable {
    type Output = Self;

    fn sub(self, rhs: Self) -> Self::Output {
        let negated_rhs = -rhs;
        self + negated_rhs
    }
}

/// Multiplication builds a new `MulOp` node over both operands' nodes.
impl std::ops::Mul for Computable {
    type Output = Self;

    fn mul(self, rhs: Self) -> Self::Output {
        // Both operands are consumed, so their nodes are moved rather than having
        // their reference counts bumped and then immediately dropped.
        let node = Node::new(Arc::new(MulOp {
            left: self.node,
            right: rhs.node,
        }));
        Self { node }
    }
}

// Division is intentionally implemented as `lhs * rhs.inv()`. Clippy flags a `*`
// inside a `Div` impl as a likely typo, so that lint is silenced for this impl.
#[allow(clippy::suspicious_arithmetic_impl)]
impl std::ops::Div for Computable {
    type Output = Self;

    fn div(self, rhs: Self) -> Self::Output {
        let reciprocal = rhs.inv();
        self * reciprocal
    }
}

#[cfg(test)]
mod tests {
    // `use super::*` already brings `XUsize`, `Bounds`, `XBinary` and
    // `NonZeroU32` into scope via the parent module's imports.
    use super::*;
    use crate::test_utils::{bin, epsilon_as_binary, unwrap_finite};

    /// Builds `sqrt(value)` from an integer constant.
    fn sqrt_computable(value: i64) -> Computable {
        Computable::constant(bin(value, 0)).nth_root(NonZeroU32::new(2).expect("2 is non-zero"))
    }

    /// Refines `expr` until its bounds are at most 2^-12 wide, then checks that the
    /// bounds contain `expected` and that each endpoint lies within 2^-12 of it.
    fn assert_refines_near(expr: &Computable, expected: f64) {
        // The literal 12 is used for both the tolerance and epsilon so the two
        // always agree.
        let bounds = expr
            .refine_to_default(XUsize::Finite(12))
            .expect("refine_to should succeed");

        let lower = unwrap_finite(bounds.small());
        let upper_extended = bounds.large();
        let upper = unwrap_finite(&upper_extended);
        let expected_binary =
            XBinary::from_f64(expected).expect("expected value should convert to extended binary");
        let expected_value = unwrap_finite(&expected_binary);
        let eps_binary = epsilon_as_binary(12);

        let lower_plus = lower.add(&eps_binary);
        let upper_minus = upper.sub(&eps_binary);

        assert!(lower <= expected_value && expected_value <= upper);
        assert!(upper_minus <= expected_value && expected_value <= lower_plus);
    }

    #[test]
    fn from_binary_matches_constant_bounds() {
        let value = bin(3, 0);
        let computable: Computable = value.clone().into();

        let bounds = computable.bounds().expect("bounds should succeed");
        assert_eq!(
            bounds,
            Bounds::new(XBinary::Finite(value.clone()), XBinary::Finite(value))
        );
    }

    #[test]
    fn constant_refines_to_exact_bounds() {
        // A constant already has zero-width bounds, so even an exact (`Inf`)
        // tolerance is satisfied immediately.
        let value = bin(3, 0);
        let bounds = Computable::constant(value.clone())
            .refine_to_default(XUsize::Inf)
            .expect("a constant already has exact bounds");
        assert_eq!(
            bounds,
            Bounds::new(XBinary::Finite(value.clone()), XBinary::Finite(value))
        );
    }

    #[test]
    fn integration_sqrt2_expression() {
        let one = Computable::constant(bin(1, 0));
        let sqrt2 = sqrt_computable(2);
        // (sqrt2 + 1) * (sqrt2 - 1) + 1/sqrt2 = 1 + 1/sqrt2
        let expr = (sqrt2.clone() + one.clone()) * (sqrt2.clone() - one) + sqrt2.inv();

        assert_refines_near(&expr, 1.0_f64 + 2.0_f64.sqrt().recip());
    }

    #[test]
    fn shared_operand_in_expression() {
        let shared = sqrt_computable(2);
        let expr = shared.clone() + shared * Computable::constant(bin(1, 0));

        assert_refines_near(&expr, 2.0_f64 * 2.0_f64.sqrt());
    }
}