oxinum_float/native/binary_splitting.rs
1//! Binary-splitting engine for hypergeometric-like series.
2//!
3//! This module implements the standard "binary splitting" divide-and-conquer
4//! algorithm for evaluating series of the form
5//!
6//! ```text
7//! S = Σ_{k=lo}^{hi-1} a(k) · P(lo) · P(lo+1) · … · P(k)
8//! / (Q(lo) · Q(lo+1) · … · Q(k)
9//! · B(lo) · B(lo+1) · … · B(k))
10//! ```
11//!
12//! where `P(k)`, `Q(k)`, `B(k)`, `a(k)` are integer-valued functions of `k`.
13//!
14//! # Algorithm
15//!
16//! Each recursive call returns a `BSSplit { p, q, b, t }` struct where
17//! `t / (q · b)` equals the partial sum over `[lo, hi)`. The combine step is:
18//!
19//! ```text
20//! p = p_L · p_R
21//! q = q_L · q_R
22//! b = b_L · b_R
23//! t = t_L · q_R · b_R + t_R · p_L
24//! ```
25//!
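//! For series whose per-term factors have `O(log k)` bits, the total cost is
//! `O(M(n) log n)`, where `n` is the size in digits of the final `p`, `q`, `b`,
//! `t` and `M(n)` is the cost of multiplying two n-digit integers (Karatsuba in
//! this implementation).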
28//!
29//! # Usage
30//!
31//! Implement [`BSSeries`] for your series, then call [`binary_split`]:
32//!
33//! ```no_run
34//! # use oxinum_float::native::binary_splitting::{BSSeries, BSSplit, binary_split};
35//! # use oxinum_int::native::BigInt;
//! struct MySeries;
//! impl BSSeries for MySeries {
//!     fn term(&self, _k: u64) -> (BigInt, BigInt, BigInt, BigInt) {
//!         (BigInt::one(), BigInt::one(), BigInt::one(), BigInt::one())
//!     }
//! }
//! // Σ_{k=0}^{9} 1: the partial sum is split.t / (split.q * split.b) = 10.
//! let split = binary_split(&MySeries, 0, 10);
43//! ```
44
45use oxinum_int::native::BigInt;
46
47// ---------------------------------------------------------------------------
48// Public data type
49// ---------------------------------------------------------------------------
50
51/// Result of binary-splitting over a range `[lo, hi)`.
52///
53/// The partial sum equals `t / (q · b)`.
54pub struct BSSplit {
55 /// Cumulative numerator factor `P(lo) · P(lo+1) · … · P(hi-1)`.
56 pub p: BigInt,
57 /// Cumulative denominator factor `Q(lo) · Q(lo+1) · … · Q(hi-1)`.
58 pub q: BigInt,
59 /// Cumulative denominator factor `B(lo) · B(lo+1) · … · B(hi-1)`.
60 pub b: BigInt,
61 /// Accumulated partial-sum numerator (over shared denominator `q · b`).
62 pub t: BigInt,
63}
64
65// ---------------------------------------------------------------------------
66// Series trait
67// ---------------------------------------------------------------------------
68
69/// Trait that defines the per-term factors of a binary-splittable series.
70///
71/// For term index `k`, implementors return `(p_k, q_k, b_k, a_k)`:
72///
73/// * `p_k` — numerator factor at position `k`.
74/// * `q_k` — denominator factor at position `k`.
75/// * `b_k` — auxiliary denominator factor at position `k` (often `1`).
76/// * `a_k` — coefficient / weight of the `k`-th term (can be negative).
77///
78/// The partial sum is then:
79/// ```text
80/// Σ a(k) · P(0..k) / (Q(0..k) · B(0..k))
81/// ```
82/// where `P(0..k) = p(0)·p(1)·…·p(k)`, etc.
83pub trait BSSeries {
84 /// Returns `(p_k, q_k, b_k, a_k)` for term index `k`.
85 fn term(&self, k: u64) -> (BigInt, BigInt, BigInt, BigInt);
86}
87
88// ---------------------------------------------------------------------------
89// Core engine
90// ---------------------------------------------------------------------------
91
92/// Evaluate `Σ_{k=lo}^{hi-1}` using binary splitting.
93///
94/// `hi` must be strictly greater than `lo`.
95///
96/// # Panics
97///
98/// Panics if `hi <= lo`.
99pub fn binary_split<S: BSSeries>(series: &S, lo: u64, hi: u64) -> BSSplit {
100 assert!(hi > lo, "binary_split: hi ({hi}) must be > lo ({lo})");
101
102 if hi == lo + 1 {
103 // Base case: single term.
104 let (p, q, b, a) = series.term(lo);
105 let t = &a * &p;
106 return BSSplit { p, q, b, t };
107 }
108
109 let mid = lo + (hi - lo) / 2;
110 let l = binary_split(series, lo, mid);
111 let r = binary_split(series, mid, hi);
112
113 // Combine:
114 // p = p_L · p_R
115 // q = q_L · q_R
116 // b = b_L · b_R
117 // t = t_L · q_R · b_R + t_R · p_L
118 let p = &l.p * &r.p;
119 let q = &l.q * &r.q;
120 let b = &l.b * &r.b;
121 let t = &l.t * &r.q * &r.b + &r.t * &l.p;
122
123 BSSplit { p, q, b, t }
124}
125
126// ---------------------------------------------------------------------------
127// Unit tests for the combining rule
128// ---------------------------------------------------------------------------
129
#[cfg(test)]
mod tests {
    use super::*;

    /// Simplest possible series: Σ 1 for k in [lo, hi). Sum should be hi - lo.
    struct ConstantSeries;
    impl BSSeries for ConstantSeries {
        fn term(&self, _k: u64) -> (BigInt, BigInt, BigInt, BigInt) {
            (BigInt::one(), BigInt::one(), BigInt::one(), BigInt::one())
        }
    }

    #[test]
    fn constant_series_base() {
        let r = binary_split(&ConstantSeries, 0, 1);
        // Single term with all factors 1: t = a·p = 1 and p = q = b = 1.
        assert_eq!(r.t, BigInt::one());
        assert_eq!(r.p, BigInt::one());
        assert_eq!(r.q, BigInt::one());
        assert_eq!(r.b, BigInt::one());
    }

    #[test]
    fn constant_series_n() {
        // Σ_{k=0}^{N-1} 1 = N. sum = t/(q*b) with p = q = b = 1,
        // so t after binary splitting must equal N.
        for n in 2u64..=20 {
            let r = binary_split(&ConstantSeries, 0, n);
            let expected_t = BigInt::from(n as i64);
            assert_eq!(r.t, expected_t, "N={n}");
        }
    }

    #[test]
    fn constant_series_nonzero_lo() {
        // For a constant series only the number of terms matters: 12 - 5 = 7.
        let r = binary_split(&ConstantSeries, 5, 12);
        assert_eq!(r.t, BigInt::from(7i64));
        assert_eq!(r.q, BigInt::one());
    }

    /// Σ_{k=lo}^{hi-1} k: p = q = b = 1, a(k) = k.
    /// Checks that `term` is evaluated at exactly the indices in the range.
    struct IndexSeries;
    impl BSSeries for IndexSeries {
        fn term(&self, k: u64) -> (BigInt, BigInt, BigInt, BigInt) {
            (
                BigInt::one(),
                BigInt::one(),
                BigInt::one(),
                BigInt::from(k as i64),
            )
        }
    }

    #[test]
    fn index_series_sum() {
        // 3 + 4 + 5 + 6 + 7 + 8 + 9 = 42
        let r = binary_split(&IndexSeries, 3, 10);
        assert_eq!(r.t, BigInt::from(42i64));
    }

    /// Alternating series: a(k) = (-1)^k, all other factors 1.
    struct AlternatingSeries;
    impl BSSeries for AlternatingSeries {
        fn term(&self, k: u64) -> (BigInt, BigInt, BigInt, BigInt) {
            let a = if k % 2 == 0 { 1i64 } else { -1i64 };
            (BigInt::one(), BigInt::one(), BigInt::one(), BigInt::from(a))
        }
    }

    #[test]
    fn alternating_series_negative_coefficients() {
        // 1 - 1 + 1 - 1 + 1 = 1
        let r = binary_split(&AlternatingSeries, 0, 5);
        assert_eq!(r.t, BigInt::one());
        // -1 + 1 - 1 = -1
        let r = binary_split(&AlternatingSeries, 1, 4);
        assert_eq!(r.t, BigInt::from(-1i64));
    }

    /// Series with p(k)=1, q(k)=2, b(k)=1, a(k)=1.
    /// The q-product is inclusive of k, so the k-th term is 1/2^(k+1).
    /// Over N terms the sum is 1 - 1/2^N, i.e. t = 2^N - 1 and q = 2^N.
    struct GeomHalf;
    impl BSSeries for GeomHalf {
        fn term(&self, _k: u64) -> (BigInt, BigInt, BigInt, BigInt) {
            (
                BigInt::one(),
                BigInt::from(2i64),
                BigInt::one(),
                BigInt::one(),
            )
        }
    }

    #[test]
    fn geometric_half_n4() {
        // With p(k)=1, q(k)=2, b(k)=1, a(k)=1 for k in 0..4:
        // the series is Σ_{k=0}^{3} 1/2^(k+1), so
        // sum = 1/2 + 1/4 + 1/8 + 1/16 = 15/16.
        //
        // Binary splitting gives t/(q*b):
        // Q = 2^4 = 16, B = 1, T = 15 → sum = 15/16.
        let r = binary_split(&GeomHalf, 0, 4);
        let q16 = BigInt::from(16i64);
        let b1 = BigInt::one();
        assert_eq!(r.q, q16, "q should be 2^4 = 16");
        assert_eq!(r.b, b1, "b should be 1");
        assert_eq!(r.t, BigInt::from(15i64), "t should be 15");
    }

    #[test]
    fn geometric_half_n() {
        // Uneven split points (odd N) must give the same closed form.
        for n in 1u64..=20 {
            let r = binary_split(&GeomHalf, 0, n);
            let pow = 1i64 << n;
            assert_eq!(r.q, BigInt::from(pow), "N={n}");
            assert_eq!(r.b, BigInt::one(), "N={n}");
            assert_eq!(r.t, BigInt::from(pow - 1), "N={n}");
        }
    }

    /// p(k)=2, q(k)=3, b(k)=1, a(k)=1: term k is (2/3)^(k+1).
    /// Exercises the `t_R · p_L` part of the combine rule.
    struct GeomTwoThirds;
    impl BSSeries for GeomTwoThirds {
        fn term(&self, _k: u64) -> (BigInt, BigInt, BigInt, BigInt) {
            (
                BigInt::from(2i64),
                BigInt::from(3i64),
                BigInt::one(),
                BigInt::one(),
            )
        }
    }

    #[test]
    fn p_factor_scales_right_half() {
        // 2/3 + 4/9 + 8/27 = 38/27.
        let r = binary_split(&GeomTwoThirds, 0, 3);
        assert_eq!(r.p, BigInt::from(8i64));
        assert_eq!(r.q, BigInt::from(27i64));
        assert_eq!(r.b, BigInt::one());
        assert_eq!(r.t, BigInt::from(38i64));
    }

    /// p(k)=1, q(k)=1, b(k)=3, a(k)=1: term k is 1/3^(k+1) via the b-product.
    /// Exercises accumulation of `b` and the `q_R · b_R` factor.
    struct GeomThirdViaB;
    impl BSSeries for GeomThirdViaB {
        fn term(&self, _k: u64) -> (BigInt, BigInt, BigInt, BigInt) {
            (
                BigInt::one(),
                BigInt::one(),
                BigInt::from(3i64),
                BigInt::one(),
            )
        }
    }

    #[test]
    fn b_factor_accumulates() {
        // 1/3 + 1/9 + 1/27 = 13/27.
        let r = binary_split(&GeomThirdViaB, 0, 3);
        assert_eq!(r.q, BigInt::one());
        assert_eq!(r.b, BigInt::from(27i64));
        assert_eq!(r.t, BigInt::from(13i64));
    }

    #[test]
    #[should_panic(expected = "must be > lo")]
    fn empty_range_panics() {
        let _ = binary_split(&ConstantSeries, 3, 3);
    }
}