use oxinum_int::native::BigInt;
/// Partial result of binary splitting over a half-open term range `lo..hi`.
///
/// `p`, `q` and `b` are plain products of the per-term factors over the range.
/// `t` is the scaled numerator of the range's partial sum, so that the sum over
/// the range equals `t / (b * q)`.
#[derive(Debug)]
pub struct BSSplit {
    /// Product of `p(k)` over the range.
    pub p: BigInt,
    /// Product of `q(k)` over the range.
    pub q: BigInt,
    /// Product of `b(k)` over the range.
    pub b: BigInt,
    /// Scaled partial sum: the range's sum equals `t / (b * q)`.
    pub t: BigInt,
}
/// A series whose terms can be evaluated with [`binary_split`].
///
/// Over a range `lo..hi`, term `k` contributes
/// `a(k) / b(k) * (p(lo) * ... * p(k)) / (q(lo) * ... * q(k))` to the sum.
pub trait BSSeries {
    /// Returns the factors of term `index` as the tuple `(p, q, b, a)`, in that order.
    fn term(&self, index: u64) -> (BigInt, BigInt, BigInt, BigInt);
}
/// Merges the results for two adjacent ranges, `l` for `lo..mid` and `r` for
/// `mid..hi`, into the result for `lo..hi`.
///
/// `p`, `q` and `b` are plain products. The scaled sum follows the
/// binary-splitting recurrence
///
/// ```text
/// T = B_r * Q_r * T_l + B_l * P_l * T_r
/// ```
///
/// It comes from `S = S_l + (P_l / Q_l) * S_r`, where `S_x = T_x / (B_x * Q_x)`.
/// Both sides are multiplied by the common denominator `B_l * B_r * Q_l * Q_r`.
#[inline]
fn combine(l: BSSplit, r: BSSplit) -> BSSplit {
    let p = &l.p * &r.p;
    let q = &l.q * &r.q;
    let b = &l.b * &r.b;
    // The right half is rescaled by the left prefix P_l/Q_l and by the left
    // denominator B_l. Dropping B_l is only harmless when every b(k) == 1.
    let t = &l.t * &r.q * &r.b + &r.t * &l.p * &l.b;
    BSSplit { p, q, b, t }
}
/// Evaluates the series over the half-open term range `lo..hi` by recursive
/// binary splitting. This is the sequential build, used when the `parallel`
/// feature is disabled.
///
/// A single term `k` becomes the leaf `p = p(k)`, `q = q(k)`, `b = b(k)`,
/// `t = a(k) * p(k)`. Longer ranges are cut at their midpoint, and the two
/// halves are merged with `combine`.
///
/// # Panics
///
/// Panics if `hi <= lo`.
#[cfg(not(feature = "parallel"))]
pub fn binary_split<S: BSSeries>(series: &S, lo: u64, hi: u64) -> BSSplit {
    assert!(hi > lo, "binary_split: hi ({hi}) must be > lo ({lo})");
    let width = hi - lo;
    if width > 1 {
        // Midpoint written as an offset from `lo` so it cannot overflow.
        let split_at = lo + width / 2;
        let left = binary_split(series, lo, split_at);
        let right = binary_split(series, split_at, hi);
        return combine(left, right);
    }
    // Leaf: exactly one term, at index `lo`.
    let (p, q, b, a) = series.term(lo);
    let t = &a * &p;
    BSSplit { p, q, b, t }
}
/// Ranges with at least this many terms evaluate their two halves through
/// `rayon::join`. Shorter ranges recurse sequentially.
#[cfg(feature = "parallel")]
const BS_PARALLEL_MIN: u64 = 64;

/// Evaluates the series over the half-open term range `lo..hi` by recursive
/// binary splitting. This is the parallel build, used when the `parallel`
/// feature is enabled.
///
/// It uses the same leaves and the same midpoint splits as the sequential
/// build. For ranges of at least `BS_PARALLEL_MIN` terms, the two halves are
/// handed to `rayon::join`. `S: Sync` is needed because both halves read
/// `series` through a shared reference, possibly from different threads.
///
/// # Panics
///
/// Panics if `hi <= lo`.
#[cfg(feature = "parallel")]
pub fn binary_split<S: BSSeries + Sync>(series: &S, lo: u64, hi: u64) -> BSSplit {
    assert!(hi > lo, "binary_split: hi ({hi}) must be > lo ({lo})");
    let width = hi - lo;
    if width == 1 {
        // Leaf: exactly one term, at index `lo`.
        let (p, q, b, a) = series.term(lo);
        let t = &a * &p;
        return BSSplit { p, q, b, t };
    }
    // Midpoint written as an offset from `lo` so it cannot overflow.
    let split_at = lo + width / 2;
    let left_half = || binary_split(series, lo, split_at);
    let right_half = || binary_split(series, split_at, hi);
    let (left, right) = if width < BS_PARALLEL_MIN {
        (left_half(), right_half())
    } else {
        rayon::join(left_half, right_half)
    };
    combine(left, right)
}
#[cfg(feature = "parallel")]
use rayon;
#[cfg(test)]
mod tests {
    use super::*;

    /// Every term is `(p, q, b, a) = (1, 1, 1, 1)`, so the split over `0..n`
    /// should yield `t = n` with all products equal to 1.
    struct ConstantSeries;
    impl BSSeries for ConstantSeries {
        fn term(&self, _k: u64) -> (BigInt, BigInt, BigInt, BigInt) {
            (BigInt::one(), BigInt::one(), BigInt::one(), BigInt::one())
        }
    }

    #[test]
    fn constant_series_base() {
        let r = binary_split(&ConstantSeries, 0, 1);
        assert_eq!(r.t, BigInt::one());
        assert_eq!(r.q, BigInt::one());
    }

    #[test]
    fn constant_series_n() {
        for n in 2u64..=20 {
            let r = binary_split(&ConstantSeries, 0, n);
            let expected_t = BigInt::from(n as i64);
            assert_eq!(r.t, expected_t, "N={n}");
        }
    }

    /// Ranges of 64 or more terms take the `rayon::join` path when the
    /// `parallel` feature is enabled. Without that feature this is just a
    /// longer sequential run.
    #[test]
    fn constant_series_long_range() {
        for n in [64u64, 65, 200] {
            let r = binary_split(&ConstantSeries, 0, n);
            assert_eq!(r.t, BigInt::from(n as i64), "N={n}");
        }
    }

    #[test]
    #[should_panic(expected = "must be > lo")]
    fn empty_range_panics() {
        let _ = binary_split(&ConstantSeries, 3, 3);
    }

    /// `q(k) = 2` with everything else 1, so the sum over `0..n` is
    /// `1/2 + 1/4 + ... + 1/2^n`.
    struct GeomHalf;
    impl BSSeries for GeomHalf {
        fn term(&self, _k: u64) -> (BigInt, BigInt, BigInt, BigInt) {
            (
                BigInt::one(),
                BigInt::from(2i64),
                BigInt::one(),
                BigInt::one(),
            )
        }
    }

    #[test]
    fn geometric_half_n4() {
        let r = binary_split(&GeomHalf, 0, 4);
        let q16 = BigInt::from(16i64);
        let b1 = BigInt::one();
        assert_eq!(r.q, q16, "q should be 2^4 = 16");
        assert_eq!(r.b, b1, "b should be 1");
        assert_eq!(r.t, BigInt::from(15i64), "t should be 15");
    }

    /// `b(k) = k + 1` with everything else 1, so the sum over `0..n` is the
    /// harmonic number `H_n = 1 + 1/2 + ... + 1/n`.
    ///
    /// This is the only series here with non-unit `b`, so it is the one that
    /// checks how `b` enters the merge of two halves. Expected results are
    /// `q = 1`, `b = n!` and `t = n! * H_n = sum_{k=1..n} n!/k`.
    struct Harmonic;
    impl BSSeries for Harmonic {
        fn term(&self, k: u64) -> (BigInt, BigInt, BigInt, BigInt) {
            (
                BigInt::one(),
                BigInt::one(),
                BigInt::from((k + 1) as i64),
                BigInt::one(),
            )
        }
    }

    #[test]
    fn harmonic_matches_closed_form() {
        // 12! fits comfortably in i64, so plain integer arithmetic is enough
        // to compute the expected values.
        for n in 1i64..=12 {
            let r = binary_split(&Harmonic, 0, n as u64);
            let factorial: i64 = (1..=n).product();
            let scaled_sum: i64 = (1..=n).map(|k| factorial / k).sum();
            assert_eq!(r.q, BigInt::one(), "N={n}");
            assert_eq!(r.b, BigInt::from(factorial), "N={n}");
            assert_eq!(r.t, BigInt::from(scaled_sum), "N={n}");
        }
    }
}