oxinum_float/native/binary_splitting.rs
1//! Binary-splitting engine for hypergeometric-like series.
2//!
3//! This module implements the standard "binary splitting" divide-and-conquer
4//! algorithm for evaluating series of the form
5//!
6//! ```text
7//! S = Σ_{k=lo}^{hi-1} a(k) · P(lo) · P(lo+1) · … · P(k)
8//! / (Q(lo) · Q(lo+1) · … · Q(k)
9//! · B(lo) · B(lo+1) · … · B(k))
10//! ```
11//!
12//! where `P(k)`, `Q(k)`, `B(k)`, `a(k)` are integer-valued functions of `k`.
13//!
14//! # Algorithm
15//!
16//! Each recursive call returns a `BSSplit { p, q, b, t }` struct where
17//! `t / (q · b)` equals the partial sum over `[lo, hi)`. The combine step is:
18//!
19//! ```text
20//! p = p_L · p_R
21//! q = q_L · q_R
22//! b = b_L · b_R
23//! t = t_L · q_R · b_R + t_R · p_L
24//! ```
25//!
26//! This is `O(M(n) log n)` where `M(n)` is the cost of multiplying two n-digit
27//! integers (Karatsuba in this implementation).
28//!
29//! # Usage
30//!
31//! Implement [`BSSeries`] for your series, then call [`binary_split`]:
32//!
33//! ```no_run
34//! # use oxinum_float::native::binary_splitting::{BSSeries, BSSplit, binary_split};
35//! # use oxinum_int::native::BigInt;
36//! struct MySeries;
37//! impl BSSeries for MySeries {
//!     fn term(&self, _k: u64) -> (BigInt, BigInt, BigInt, BigInt) {
39//! (BigInt::one(), BigInt::one(), BigInt::one(), BigInt::one())
40//! }
41//! }
42//! let split = binary_split(&MySeries, 0, 10);
43//! ```
44
45use oxinum_int::native::BigInt;
46
47// ---------------------------------------------------------------------------
48// Public data type
49// ---------------------------------------------------------------------------
50
51/// Result of binary-splitting over a range `[lo, hi)`.
52///
53/// The partial sum equals `t / (q · b)`.
54pub struct BSSplit {
55 /// Cumulative numerator factor `P(lo) · P(lo+1) · … · P(hi-1)`.
56 pub p: BigInt,
57 /// Cumulative denominator factor `Q(lo) · Q(lo+1) · … · Q(hi-1)`.
58 pub q: BigInt,
59 /// Cumulative denominator factor `B(lo) · B(lo+1) · … · B(hi-1)`.
60 pub b: BigInt,
61 /// Accumulated partial-sum numerator (over shared denominator `q · b`).
62 pub t: BigInt,
63}
64
65// ---------------------------------------------------------------------------
66// Series trait
67// ---------------------------------------------------------------------------
68
69/// Trait that defines the per-term factors of a binary-splittable series.
70///
71/// For term index `k`, implementors return `(p_k, q_k, b_k, a_k)`:
72///
73/// * `p_k` — numerator factor at position `k`.
74/// * `q_k` — denominator factor at position `k`.
75/// * `b_k` — auxiliary denominator factor at position `k` (often `1`).
76/// * `a_k` — coefficient / weight of the `k`-th term (can be negative).
77///
78/// The partial sum is then:
79/// ```text
80/// Σ a(k) · P(0..k) / (Q(0..k) · B(0..k))
81/// ```
82/// where `P(0..k) = p(0)·p(1)·…·p(k)`, etc.
83pub trait BSSeries {
84 /// Returns `(p_k, q_k, b_k, a_k)` for term index `k`.
85 fn term(&self, k: u64) -> (BigInt, BigInt, BigInt, BigInt);
86}
87
// ---------------------------------------------------------------------------
// Core engine: the combine helper (shared between the serial and parallel
// implementations) and the `binary_split` entry points
// ---------------------------------------------------------------------------
95
96/// Combine two adjacent binary-split sub-results into one.
97///
98/// The algebraic identity is:
99/// ```text
100/// p = p_L · p_R
101/// q = q_L · q_R
102/// b = b_L · b_R
103/// t = t_L · q_R · b_R + t_R · p_L
104/// ```
105#[inline]
106fn combine(l: BSSplit, r: BSSplit) -> BSSplit {
107 let p = &l.p * &r.p;
108 let q = &l.q * &r.q;
109 let b = &l.b * &r.b;
110 let t = &l.t * &r.q * &r.b + &r.t * &l.p;
111 BSSplit { p, q, b, t }
112}
113
114/// Evaluate `Σ_{k=lo}^{hi-1}` using binary splitting.
115///
116/// `hi` must be strictly greater than `lo`.
117///
118/// # Panics
119///
120/// Panics if `hi <= lo`.
121#[cfg(not(feature = "parallel"))]
122pub fn binary_split<S: BSSeries>(series: &S, lo: u64, hi: u64) -> BSSplit {
123 assert!(hi > lo, "binary_split: hi ({hi}) must be > lo ({lo})");
124
125 if hi == lo + 1 {
126 // Base case: single term.
127 let (p, q, b, a) = series.term(lo);
128 let t = &a * &p;
129 return BSSplit { p, q, b, t };
130 }
131
132 let mid = lo + (hi - lo) / 2;
133 let l = binary_split(series, lo, mid);
134 let r = binary_split(series, mid, hi);
135 combine(l, r)
136}
137
/// Smallest range length (`hi - lo`) at which the parallel `binary_split`
/// evaluates its two halves through `rayon::join`; shorter ranges recurse on
/// the current thread.
#[cfg(feature = "parallel")]
const BS_PARALLEL_MIN: u64 = 64;

/// Evaluate `series` over the index range `[lo, hi)` by binary splitting,
/// fanning out with `rayon` (this variant is compiled only with the `parallel`
/// feature).
///
/// The returned [`BSSplit`] satisfies
/// `Σ_{k=lo}^{hi-1} a(k) · P(lo..=k) / (Q(lo..=k) · B(lo..=k)) = t / (q · b)`.
///
/// Ranges of at least `BS_PARALLEL_MIN` terms evaluate their two halves via
/// `rayon::join`. Shorter ranges recurse sequentially, left half first.
///
/// The `Sync` bound lets `series` be borrowed by several worker threads,
/// which may call [`BSSeries::term`] concurrently.
///
/// # Panics
///
/// Panics if `hi <= lo`.
#[cfg(feature = "parallel")]
pub fn binary_split<S: BSSeries + Sync>(series: &S, lo: u64, hi: u64) -> BSSplit {
    assert!(hi > lo, "binary_split: hi ({hi}) must be > lo ({lo})");

    let len = hi - lo;
    if len == 1 {
        // Leaf: the single term k = lo.
        let (p, q, b, a) = series.term(lo);
        return BSSplit { t: &a * &p, p, q, b };
    }

    let mid = lo + len / 2;
    let left_half = || binary_split(series, lo, mid);
    let right_half = || binary_split(series, mid, hi);

    // Below the threshold, task-scheduling overhead outweighs the work saved.
    let (left, right) = if len < BS_PARALLEL_MIN {
        (left_half(), right_half())
    } else {
        rayon::join(left_half, right_half)
    };
    combine(left, right)
}
174
175#[cfg(feature = "parallel")]
176use rayon;
177
// ---------------------------------------------------------------------------
// Unit tests for the binary-splitting engine (leaf case and combining rule)
// ---------------------------------------------------------------------------
181
#[cfg(test)]
mod tests {
    use super::*;

    /// Every factor is 1, so each term contributes exactly 1 and
    /// `Σ_{k=0}^{N-1} 1 = N`.
    struct ConstantSeries;
    impl BSSeries for ConstantSeries {
        fn term(&self, _k: u64) -> (BigInt, BigInt, BigInt, BigInt) {
            (BigInt::one(), BigInt::one(), BigInt::one(), BigInt::one())
        }
    }

    #[test]
    fn constant_series_base() {
        let r = binary_split(&ConstantSeries, 0, 1);
        // Single leaf: p = q = b = 1 and t = a·p = 1, so the sum is 1/(1·1) = 1.
        assert_eq!(r.p, BigInt::one());
        assert_eq!(r.q, BigInt::one());
        assert_eq!(r.b, BigInt::one());
        assert_eq!(r.t, BigInt::one());
    }

    #[test]
    fn constant_series_n() {
        // All products stay 1 (q · b = 1), so t must equal the term count N.
        for n in 2u64..=20 {
            let r = binary_split(&ConstantSeries, 0, n);
            let expected_t = BigInt::from(n as i64);
            assert_eq!(r.t, expected_t, "N={n}");
        }
    }

    /// Halving series: p(k) = 1, q(k) = 2, b(k) = 1, a(k) = 1.
    ///
    /// The `Q` product includes `q(k)` itself, so term `k` equals
    /// `1 / 2^(k+1)`. The first N terms therefore sum to
    /// `1 - 1/2^N = (2^N - 1) / 2^N`.
    struct GeomHalf;
    impl BSSeries for GeomHalf {
        fn term(&self, _k: u64) -> (BigInt, BigInt, BigInt, BigInt) {
            (
                BigInt::one(),
                BigInt::from(2i64),
                BigInt::one(),
                BigInt::one(),
            )
        }
    }

    #[test]
    fn geometric_half_n4() {
        // 1/2 + 1/4 + 1/8 + 1/16 = 15/16.
        // Binary splitting gives t/(q·b) with Q = 2^4 = 16, B = 1, T = 15.
        let r = binary_split(&GeomHalf, 0, 4);
        assert_eq!(r.q, BigInt::from(16i64), "q should be 2^4 = 16");
        assert_eq!(r.b, BigInt::one(), "b should be 1");
        assert_eq!(r.t, BigInt::from(15i64), "t should be 15");
    }

    /// p(k) = 1, q(k) = k + 1, b(k) = 1, a(k) = 1, so term `k` is `1/(k+1)!`.
    struct InverseFactorial;
    impl BSSeries for InverseFactorial {
        fn term(&self, k: u64) -> (BigInt, BigInt, BigInt, BigInt) {
            let q = BigInt::from(k as i64 + 1);
            (BigInt::one(), q, BigInt::one(), BigInt::one())
        }
    }

    #[test]
    fn inverse_factorial_n4() {
        // 1/1! + 1/2! + 1/3! + 1/4! = (24 + 12 + 4 + 1) / 24 = 41/24.
        let r = binary_split(&InverseFactorial, 0, 4);
        assert_eq!(r.p, BigInt::one());
        assert_eq!(r.q, BigInt::from(24i64));
        assert_eq!(r.b, BigInt::one());
        assert_eq!(r.t, BigInt::from(41i64));
    }

    /// p(k) = 3, q(k) = 1, b(k) = 2, a(k) = (-1)^k.
    ///
    /// Exercises a non-unit numerator product, a non-unit auxiliary
    /// denominator and negative weights.
    struct AlternatingThreeHalves;
    impl BSSeries for AlternatingThreeHalves {
        fn term(&self, k: u64) -> (BigInt, BigInt, BigInt, BigInt) {
            let sign: i64 = if k % 2 == 0 { 1 } else { -1 };
            (
                BigInt::from(3i64),
                BigInt::one(),
                BigInt::from(2i64),
                BigInt::from(sign),
            )
        }
    }

    #[test]
    fn alternating_with_nonzero_lo() {
        // Products restart at lo = 1:
        // -3/2 + 9/4 - 27/8 = (-12 + 18 - 27) / 8 = -21/8.
        let r = binary_split(&AlternatingThreeHalves, 1, 4);
        assert_eq!(r.p, BigInt::from(27i64));
        assert_eq!(r.q, BigInt::one());
        assert_eq!(r.b, BigInt::from(8i64));
        assert_eq!(r.t, BigInt::from(-21i64));
    }

    #[test]
    #[should_panic(expected = "must be > lo")]
    fn empty_range_panics() {
        let _ = binary_split(&ConstantSeries, 5, 5);
    }
}