Skip to main content

oxinum_float/native/
binary_splitting.rs

1//! Binary-splitting engine for hypergeometric-like series.
2//!
3//! This module implements the standard "binary splitting" divide-and-conquer
4//! algorithm for evaluating series of the form
5//!
6//! ```text
7//! S = Σ_{k=lo}^{hi-1} a(k) · P(lo) · P(lo+1) · … · P(k)
8//!                          / (Q(lo) · Q(lo+1) · … · Q(k)
9//!                             · B(lo) · B(lo+1) · … · B(k))
10//! ```
11//!
12//! where `P(k)`, `Q(k)`, `B(k)`, `a(k)` are integer-valued functions of `k`.
13//!
14//! # Algorithm
15//!
16//! Each recursive call returns a `BSSplit { p, q, b, t }` struct where
17//! `t / (q · b)` equals the partial sum over `[lo, hi)`. The combine step is:
18//!
19//! ```text
20//! p  = p_L · p_R
21//! q  = q_L · q_R
22//! b  = b_L · b_R
23//! t  = t_L · q_R · b_R + t_R · p_L
24//! ```
25//!
26//! This is `O(M(n) log n)` where `M(n)` is the cost of multiplying two n-digit
27//! integers (Karatsuba in this implementation).
28//!
29//! # Usage
30//!
31//! Implement [`BSSeries`] for your series, then call [`binary_split`]:
32//!
33//! ```no_run
34//! # use oxinum_float::native::binary_splitting::{BSSeries, BSSplit, binary_split};
35//! # use oxinum_int::native::BigInt;
36//! struct MySeries;
37//! impl BSSeries for MySeries {
38//!     fn term(&self, k: u64) -> (BigInt, BigInt, BigInt, BigInt) {
39//!         (BigInt::one(), BigInt::one(), BigInt::one(), BigInt::one())
40//!     }
41//! }
42//! let split = binary_split(&MySeries, 0, 10);
43//! ```
44
45use oxinum_int::native::BigInt;
46
47// ---------------------------------------------------------------------------
48// Public data type
49// ---------------------------------------------------------------------------
50
51/// Result of binary-splitting over a range `[lo, hi)`.
52///
53/// The partial sum equals `t / (q · b)`.
54pub struct BSSplit {
55    /// Cumulative numerator factor `P(lo) · P(lo+1) · … · P(hi-1)`.
56    pub p: BigInt,
57    /// Cumulative denominator factor `Q(lo) · Q(lo+1) · … · Q(hi-1)`.
58    pub q: BigInt,
59    /// Cumulative denominator factor `B(lo) · B(lo+1) · … · B(hi-1)`.
60    pub b: BigInt,
61    /// Accumulated partial-sum numerator (over shared denominator `q · b`).
62    pub t: BigInt,
63}
64
65// ---------------------------------------------------------------------------
66// Series trait
67// ---------------------------------------------------------------------------
68
69/// Trait that defines the per-term factors of a binary-splittable series.
70///
71/// For term index `k`, implementors return `(p_k, q_k, b_k, a_k)`:
72///
73/// * `p_k` — numerator factor at position `k`.
74/// * `q_k` — denominator factor at position `k`.
75/// * `b_k` — auxiliary denominator factor at position `k` (often `1`).
76/// * `a_k` — coefficient / weight of the `k`-th term (can be negative).
77///
78/// The partial sum is then:
79/// ```text
80/// Σ a(k) · P(0..k) / (Q(0..k) · B(0..k))
81/// ```
82/// where `P(0..k) = p(0)·p(1)·…·p(k)`, etc.
83pub trait BSSeries {
84    /// Returns `(p_k, q_k, b_k, a_k)` for term index `k`.
85    fn term(&self, k: u64) -> (BigInt, BigInt, BigInt, BigInt);
86}
87
88// ---------------------------------------------------------------------------
89// Core engine
90// ---------------------------------------------------------------------------
91
92// ---------------------------------------------------------------------------
93// Combine helper (shared between serial and parallel implementations)
94// ---------------------------------------------------------------------------
95
96/// Combine two adjacent binary-split sub-results into one.
97///
98/// The algebraic identity is:
99/// ```text
100/// p  = p_L · p_R
101/// q  = q_L · q_R
102/// b  = b_L · b_R
103/// t  = t_L · q_R · b_R  +  t_R · p_L
104/// ```
105#[inline]
106fn combine(l: BSSplit, r: BSSplit) -> BSSplit {
107    let p = &l.p * &r.p;
108    let q = &l.q * &r.q;
109    let b = &l.b * &r.b;
110    let t = &l.t * &r.q * &r.b + &r.t * &l.p;
111    BSSplit { p, q, b, t }
112}
113
114/// Evaluate `Σ_{k=lo}^{hi-1}` using binary splitting.
115///
116/// `hi` must be strictly greater than `lo`.
117///
118/// # Panics
119///
120/// Panics if `hi <= lo`.
121#[cfg(not(feature = "parallel"))]
122pub fn binary_split<S: BSSeries>(series: &S, lo: u64, hi: u64) -> BSSplit {
123    assert!(hi > lo, "binary_split: hi ({hi}) must be > lo ({lo})");
124
125    if hi == lo + 1 {
126        // Base case: single term.
127        let (p, q, b, a) = series.term(lo);
128        let t = &a * &p;
129        return BSSplit { p, q, b, t };
130    }
131
132    let mid = lo + (hi - lo) / 2;
133    let l = binary_split(series, lo, mid);
134    let r = binary_split(series, mid, hi);
135    combine(l, r)
136}
137
/// Sub-range size, in number of terms (`hi - lo`), at or above which the two
/// halves are evaluated concurrently via `rayon::join`.
#[cfg(feature = "parallel")]
const BS_PARALLEL_MIN: u64 = 64;

/// Evaluate the series over `[lo, hi)` by binary splitting, using `rayon` to
/// evaluate large halves concurrently.
///
/// This variant replaces the sequential one when the `parallel` feature is
/// enabled. The split points (`mid = lo + (hi - lo) / 2`) and the merge rule
/// are the same; only the scheduling of the two halves differs.
/// * Ranges of at least `BS_PARALLEL_MIN` terms run their halves through
///   `rayon::join`.
/// * Smaller ranges recurse sequentially.
///
/// `S: Sync` is required because `series` is shared by reference between the
/// two `rayon::join` closures.
/// NOTE(review): the sequential variant has no `Sync` bound. Enabling the
/// `parallel` feature can therefore break callers whose series type is not
/// `Sync`.
///
/// # Panics
///
/// Panics if `hi <= lo`.
#[cfg(feature = "parallel")]
pub fn binary_split<S: BSSeries + Sync>(series: &S, lo: u64, hi: u64) -> BSSplit {
    assert!(hi > lo, "binary_split: hi ({hi}) must be > lo ({lo})");

    let width = hi - lo;
    if width == 1 {
        // Single term: the leaf numerator is a_k · p_k.
        let (p, q, b, a) = series.term(lo);
        let t = &a * &p;
        return BSSplit { p, q, b, t };
    }

    let mid = lo + width / 2;
    let left_half = || binary_split(series, lo, mid);
    let right_half = || binary_split(series, mid, hi);

    // Small ranges are not worth the scheduling overhead of a fork.
    let (l, r) = if width < BS_PARALLEL_MIN {
        (left_half(), right_half())
    } else {
        rayon::join(left_half, right_half)
    };
    combine(l, r)
}
174
175#[cfg(feature = "parallel")]
176use rayon;
177
178// ---------------------------------------------------------------------------
179// Unit tests for the combining rule
180// ---------------------------------------------------------------------------
181
#[cfg(test)]
mod tests {
    use super::*;

    /// Simplest possible series: Σ 1 for k in [lo, hi).  Sum should be hi - lo.
    struct ConstantSeries;
    impl BSSeries for ConstantSeries {
        fn term(&self, _k: u64) -> (BigInt, BigInt, BigInt, BigInt) {
            (BigInt::one(), BigInt::one(), BigInt::one(), BigInt::one())
        }
    }

    #[test]
    fn constant_series_base() {
        let r = binary_split(&ConstantSeries, 0, 1);
        // Single term: p=1, q=1, b=1, t = a·p = 1  =>  sum = 1/1 = 1
        assert_eq!(r.p, BigInt::one());
        assert_eq!(r.q, BigInt::one());
        assert_eq!(r.b, BigInt::one());
        assert_eq!(r.t, BigInt::one());
    }

    #[test]
    fn constant_series_n() {
        // Σ_{k=0}^{N-1} 1 = N.  All factors are 1, so p = q = b = 1 and t = N.
        for n in 2u64..=20 {
            let r = binary_split(&ConstantSeries, 0, n);
            let expected_t = BigInt::from(n as i64);
            assert_eq!(r.t, expected_t, "N={n}");
        }
    }

    #[test]
    fn constant_series_nonzero_lo() {
        // Only the number of terms matters for a constant series: [5, 12) has 7.
        let r = binary_split(&ConstantSeries, 5, 12);
        assert_eq!(r.t, BigInt::from(7i64));
    }

    #[test]
    #[should_panic(expected = "must be > lo")]
    fn empty_range_panics() {
        let _ = binary_split(&ConstantSeries, 3, 3);
    }

    /// Geometric series with ratio 1/2: p(k)=1, q(k)=2, b(k)=1, a(k)=1.
    ///
    /// The `q` product includes index `k`, so term `k` is (1/2)^(k+1). N terms
    /// sum to Σ_{k=0}^{N-1} (1/2)^(k+1) = 1 - 1/2^N = (2^N - 1) / 2^N.
    struct GeomHalf;
    impl BSSeries for GeomHalf {
        fn term(&self, _k: u64) -> (BigInt, BigInt, BigInt, BigInt) {
            (
                BigInt::one(),
                BigInt::from(2i64),
                BigInt::one(),
                BigInt::one(),
            )
        }
    }

    #[test]
    fn geometric_half_n4() {
        // sum = 1/2 + 1/4 + 1/8 + 1/16 = 15/16.
        // Binary splitting gives t/(q*b) with Q = 2^4 = 16, B = 1, T = 15.
        let r = binary_split(&GeomHalf, 0, 4);
        let q16 = BigInt::from(16i64);
        let b1 = BigInt::one();
        assert_eq!(r.q, q16, "q should be 2^4 = 16");
        assert_eq!(r.b, b1, "b should be 1");
        assert_eq!(r.t, BigInt::from(15i64), "t should be 15");
    }

    #[test]
    fn geometric_half_all_n() {
        // For every N: Q = 2^N, B = 1, T = 2^N - 1.
        for n in 1u32..=40 {
            let r = binary_split(&GeomHalf, 0, u64::from(n));
            let pow = 1i64 << n;
            assert_eq!(r.q, BigInt::from(pow), "N={n}");
            assert_eq!(r.b, BigInt::one(), "N={n}");
            assert_eq!(r.t, BigInt::from(pow - 1), "N={n}");
        }
    }

    /// Alternating series Σ (-1)^k · (2/3)^(k+1).
    ///
    /// It exercises a non-unit `p`, which scales the right half when halves
    /// are merged, and a negative coefficient `a(k)`.
    struct AltTwoThirds;
    impl BSSeries for AltTwoThirds {
        fn term(&self, k: u64) -> (BigInt, BigInt, BigInt, BigInt) {
            let sign = if k % 2 == 0 { 1i64 } else { -1i64 };
            (
                BigInt::from(2i64),
                BigInt::from(3i64),
                BigInt::one(),
                BigInt::from(sign),
            )
        }
    }

    #[test]
    fn alternating_two_thirds_n3() {
        // 2/3 - 4/9 + 8/27 = 14/27.
        let r = binary_split(&AltTwoThirds, 0, 3);
        assert_eq!(r.p, BigInt::from(8i64));
        assert_eq!(r.q, BigInt::from(27i64));
        assert_eq!(r.b, BigInt::one());
        assert_eq!(r.t, BigInt::from(14i64));
    }

    /// Σ 1/(k+1)! expressed entirely through the `b` factor:
    /// p(k)=1, q(k)=1, b(k)=k+1, a(k)=1.
    struct InvFactorialViaB;
    impl BSSeries for InvFactorialViaB {
        fn term(&self, k: u64) -> (BigInt, BigInt, BigInt, BigInt) {
            (
                BigInt::one(),
                BigInt::one(),
                BigInt::from(k as i64 + 1),
                BigInt::one(),
            )
        }
    }

    #[test]
    fn inverse_factorial_via_b_n3() {
        // 1/1! + 1/2! + 1/3! = 10/6.
        let r = binary_split(&InvFactorialViaB, 0, 3);
        assert_eq!(r.q, BigInt::one());
        assert_eq!(r.b, BigInt::from(6i64));
        assert_eq!(r.t, BigInt::from(10i64));
    }

    #[test]
    fn split_point_does_not_change_result() {
        // p, q and b are plain products, and t = S·q·b, so every field is
        // independent of where the range is split.  The internal first split
        // of [0, 9) is at 4; merge [0, 5) and [5, 9) by hand and compare.
        let whole = binary_split(&AltTwoThirds, 0, 9);
        let l = binary_split(&AltTwoThirds, 0, 5);
        let r = binary_split(&AltTwoThirds, 5, 9);
        let t = &l.t * &r.q * &r.b + &r.t * &l.p;
        assert_eq!(whole.t, t);
        assert_eq!(whole.p, &l.p * &r.p);
        assert_eq!(whole.q, &l.q * &r.q);
        assert_eq!(whole.b, &l.b * &r.b);
    }
}