Skip to main content

oxinum_float/native/
binary_splitting.rs

1//! Binary-splitting engine for hypergeometric-like series.
2//!
3//! This module implements the standard "binary splitting" divide-and-conquer
4//! algorithm for evaluating series of the form
5//!
6//! ```text
7//! S = Σ_{k=lo}^{hi-1} a(k) · P(lo) · P(lo+1) · … · P(k)
8//!                          / (Q(lo) · Q(lo+1) · … · Q(k)
9//!                             · B(lo) · B(lo+1) · … · B(k))
10//! ```
11//!
12//! where `P(k)`, `Q(k)`, `B(k)`, `a(k)` are integer-valued functions of `k`.
13//!
14//! # Algorithm
15//!
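//! Each recursive call returns a `BSSplit { p, q, b, t }` struct where
//! `t / (q · b)` equals the partial sum over `[lo, hi)`. A single-term range
//! `[k, k+1)` yields `p = P(k)`, `q = Q(k)`, `b = B(k)` and `t = a(k) · P(k)`.
//! Two adjacent ranges (left `L`, right `R`) are merged by the combine step: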
18//!
19//! ```text
20//! p  = p_L · p_R
21//! q  = q_L · q_R
22//! b  = b_L · b_R
23//! t  = t_L · q_R · b_R + t_R · p_L
24//! ```
25//!
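//! The range is halved at every level, so the recursion depth is
//! `⌈log₂(hi − lo)⌉`. For typical hypergeometric series the total cost is
//! `O(M(n) log n)`, where `n` is the digit length of the final integers and
//! `M(n)` is the cost of multiplying two n-digit integers (Karatsuba in this
//! implementation).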
28//!
29//! # Usage
30//!
31//! Implement [`BSSeries`] for your series, then call [`binary_split`]:
32//!
//! ```no_run
//! # use oxinum_float::native::binary_splitting::{BSSeries, BSSplit, binary_split};
//! # use oxinum_int::native::BigInt;
//! struct MySeries;
//! impl BSSeries for MySeries {
//!     fn term(&self, _k: u64) -> (BigInt, BigInt, BigInt, BigInt) {
//!         (BigInt::one(), BigInt::one(), BigInt::one(), BigInt::one())
//!     }
//! }
//! let split: BSSplit = binary_split(&MySeries, 0, 10);
//! // The partial sum over [0, 10) is split.t / (split.q * split.b).
//! ```
44
45use oxinum_int::native::BigInt;
46
47// ---------------------------------------------------------------------------
48// Public data type
49// ---------------------------------------------------------------------------
50
51/// Result of binary-splitting over a range `[lo, hi)`.
52///
53/// The partial sum equals `t / (q · b)`.
54pub struct BSSplit {
55    /// Cumulative numerator factor `P(lo) · P(lo+1) · … · P(hi-1)`.
56    pub p: BigInt,
57    /// Cumulative denominator factor `Q(lo) · Q(lo+1) · … · Q(hi-1)`.
58    pub q: BigInt,
59    /// Cumulative denominator factor `B(lo) · B(lo+1) · … · B(hi-1)`.
60    pub b: BigInt,
61    /// Accumulated partial-sum numerator (over shared denominator `q · b`).
62    pub t: BigInt,
63}
64
65// ---------------------------------------------------------------------------
66// Series trait
67// ---------------------------------------------------------------------------
68
69/// Trait that defines the per-term factors of a binary-splittable series.
70///
71/// For term index `k`, implementors return `(p_k, q_k, b_k, a_k)`:
72///
73/// * `p_k` — numerator factor at position `k`.
74/// * `q_k` — denominator factor at position `k`.
75/// * `b_k` — auxiliary denominator factor at position `k` (often `1`).
76/// * `a_k` — coefficient / weight of the `k`-th term (can be negative).
77///
78/// The partial sum is then:
79/// ```text
80/// Σ a(k) · P(0..k) / (Q(0..k) · B(0..k))
81/// ```
82/// where `P(0..k) = p(0)·p(1)·…·p(k)`, etc.
83pub trait BSSeries {
84    /// Returns `(p_k, q_k, b_k, a_k)` for term index `k`.
85    fn term(&self, k: u64) -> (BigInt, BigInt, BigInt, BigInt);
86}
87
88// ---------------------------------------------------------------------------
89// Core engine
90// ---------------------------------------------------------------------------
91
92/// Evaluate `Σ_{k=lo}^{hi-1}` using binary splitting.
93///
94/// `hi` must be strictly greater than `lo`.
95///
96/// # Panics
97///
98/// Panics if `hi <= lo`.
99pub fn binary_split<S: BSSeries>(series: &S, lo: u64, hi: u64) -> BSSplit {
100    assert!(hi > lo, "binary_split: hi ({hi}) must be > lo ({lo})");
101
102    if hi == lo + 1 {
103        // Base case: single term.
104        let (p, q, b, a) = series.term(lo);
105        let t = &a * &p;
106        return BSSplit { p, q, b, t };
107    }
108
109    let mid = lo + (hi - lo) / 2;
110    let l = binary_split(series, lo, mid);
111    let r = binary_split(series, mid, hi);
112
113    // Combine:
114    //   p = p_L · p_R
115    //   q = q_L · q_R
116    //   b = b_L · b_R
117    //   t = t_L · q_R · b_R  +  t_R · p_L
118    let p = &l.p * &r.p;
119    let q = &l.q * &r.q;
120    let b = &l.b * &r.b;
121    let t = &l.t * &r.q * &r.b + &r.t * &l.p;
122
123    BSSplit { p, q, b, t }
124}
125
126// ---------------------------------------------------------------------------
127// Unit tests for the combining rule
128// ---------------------------------------------------------------------------
129
#[cfg(test)]
mod tests {
    use super::*;

    /// Simplest possible series: Σ 1 for k in [0, N).  Sum should be N.
    struct ConstantSeries;
    impl BSSeries for ConstantSeries {
        fn term(&self, _k: u64) -> (BigInt, BigInt, BigInt, BigInt) {
            (BigInt::one(), BigInt::one(), BigInt::one(), BigInt::one())
        }
    }

    #[test]
    fn constant_series_base() {
        let r = binary_split(&ConstantSeries, 0, 1);
        // t=1, q=1, b=1  =>  sum = 1/1 = 1
        assert_eq!(r.t, BigInt::one());
        assert_eq!(r.q, BigInt::one());
    }

    #[test]
    fn constant_series_n() {
        // Σ_{k=0}^{N-1} 1 = N.  sum = t/(q*b).  p_total = 1^N = 1, q = 1, b = 1.
        // With a(k)=1 and p(k)=1, t after binary split should equal N.
        for n in 2u64..=20 {
            let r = binary_split(&ConstantSeries, 0, n);
            let expected_t = BigInt::from(n as i64);
            assert_eq!(r.t, expected_t, "N={n}");
        }
    }

    /// Geometric series Σ_{k=0}^{N-1} (1/2)^(k+1) = 1 − 1/2^N.
    ///
    /// p(k)=1, q(k)=2, b(k)=1, a(k)=1.  Because the denominator is the
    /// cumulative product q(0)·…·q(k), the k-th term is 1/2^(k+1), not 1/2^k.
    /// Binary splitting therefore returns q = 2^N, b = 1, t = 2^N − 1.
    struct GeomHalf;
    impl BSSeries for GeomHalf {
        fn term(&self, _k: u64) -> (BigInt, BigInt, BigInt, BigInt) {
            (
                BigInt::one(),
                BigInt::from(2i64),
                BigInt::one(),
                BigInt::one(),
            )
        }
    }

    #[test]
    fn geometric_half_n4() {
        // With p(k)=1, q(k)=2, b(k)=1, a(k)=1 for k in 0..4:
        //   The series is Σ_{k=0}^{3} (1/q_prefix)  where q_prefix(k) = 2^(k+1).
        //   sum = 1/2 + 1/4 + 1/8 + 1/16 = 15/16.
        //
        // Binary splitting gives t/(q*b):
        //   Q = 2^4 = 16, B = 1, T = 15  →  sum = 15/16.
        let r = binary_split(&GeomHalf, 0, 4);
        let q16 = BigInt::from(16i64);
        let b1 = BigInt::one();
        assert_eq!(r.q, q16, "q should be 2^4 = 16");
        assert_eq!(r.b, b1, "b should be 1");
        assert_eq!(r.t, BigInt::from(15i64), "t should be 15");
    }

    #[test]
    fn geometric_half_general() {
        // Covers both balanced and unbalanced splits (odd N).
        for n in 1u32..=20 {
            let r = binary_split(&GeomHalf, 0, u64::from(n));
            let two_n = 1i64 << n;
            assert_eq!(r.q, BigInt::from(two_n), "N={n}");
            assert_eq!(r.b, BigInt::one(), "N={n}");
            assert_eq!(r.t, BigInt::from(two_n - 1), "N={n}");
        }
    }

    /// Truncated series for e: Σ_{k=0}^{N-1} 1/k!, using q(0)=1 and q(k)=k.
    struct ExpOne;
    impl BSSeries for ExpOne {
        fn term(&self, k: u64) -> (BigInt, BigInt, BigInt, BigInt) {
            let q = i64::try_from(k.max(1)).expect("test index fits in i64");
            (BigInt::one(), BigInt::from(q), BigInt::one(), BigInt::one())
        }
    }

    #[test]
    fn exp_one_n5() {
        // Σ_{k=0}^{4} 1/k! = 65/24; the result is not reduced, so q = 4! = 24.
        let r = binary_split(&ExpOne, 0, 5);
        assert_eq!(r.q, BigInt::from(24i64));
        assert_eq!(r.b, BigInt::one());
        assert_eq!(r.t, BigInt::from(65i64));
    }

    /// Series with non-trivial p, q, b and signed a, used to cross-check the
    /// recursion against a direct evaluation.
    struct Mixed;
    impl BSSeries for Mixed {
        fn term(&self, k: u64) -> (BigInt, BigInt, BigInt, BigInt) {
            let k = i64::try_from(k).expect("test index fits in i64");
            let a = if k % 2 == 0 { k + 1 } else { -(k + 2) };
            (
                BigInt::from(k + 1),
                BigInt::from(2 * k + 3),
                BigInt::from(k % 3 + 1),
                BigInt::from(a),
            )
        }
    }

    /// Direct O(N²) evaluation of `(p, q, b, t)` over `[lo, hi)`, independent
    /// of the combine rule:
    /// `t = Σ_k a(k) · p(lo)…p(k) · q(k+1)…q(hi-1) · b(k+1)…b(hi-1)`,
    /// i.e. every term brought over the common denominator `q · b`.
    fn reference<S: BSSeries>(
        series: &S,
        lo: u64,
        hi: u64,
    ) -> (BigInt, BigInt, BigInt, BigInt) {
        let terms: Vec<_> = (lo..hi).map(|k| series.term(k)).collect();

        let mut p = BigInt::one();
        let mut q = BigInt::one();
        let mut b = BigInt::one();
        for (pk, qk, bk, _) in &terms {
            p = &p * pk;
            q = &q * qk;
            b = &b * bk;
        }

        let mut t = BigInt::from(0i64);
        for (i, (_, _, _, a)) in terms.iter().enumerate() {
            let mut term = BigInt::one();
            for (pj, _, _, _) in &terms[..=i] {
                term = &term * pj;
            }
            term = &term * a;
            for (_, qj, bj, _) in &terms[i + 1..] {
                term = &term * qj;
                term = &term * bj;
            }
            t = t + term;
        }
        (p, q, b, t)
    }

    #[test]
    fn matches_direct_evaluation() {
        // Non-zero `lo`, non-unit `b`, and negative `a` are all exercised here.
        for lo in 0u64..5 {
            for hi in lo + 1..lo + 13 {
                let r = binary_split(&Mixed, lo, hi);
                let (p, q, b, t) = reference(&Mixed, lo, hi);
                assert_eq!(r.p, p, "p over [{lo}, {hi})");
                assert_eq!(r.q, q, "q over [{lo}, {hi})");
                assert_eq!(r.b, b, "b over [{lo}, {hi})");
                assert_eq!(r.t, t, "t over [{lo}, {hi})");
            }
        }
    }

    #[test]
    #[should_panic(expected = "must be > lo")]
    fn empty_range_panics() {
        let _ = binary_split(&ConstantSeries, 3, 3);
    }
}