use num_bigint::BigInt;
use num_rational::{BigRational, Ratio};
use num_traits::{FromPrimitive, Signed, Zero};
use super::numtraits::{IntRing, ZSigned};
type Q = BigRational;
fn q_int_i128(n: i128) -> Q {
Ratio::from_integer(BigInt::from(n))
}
fn q_int_usize(n: usize) -> Q {
Ratio::from_integer(BigInt::from(n))
}
/// Dense univariate polynomial over `Q`; `coeffs[i]` is the coefficient of `x^i`.
///
/// After `normalize` the last coefficient (if any) is non-zero, so the zero
/// polynomial is represented by an empty coefficient vector.
#[derive(Clone, Debug)]
struct PolyQ {
    coeffs: Vec<Q>,
}
impl PolyQ {
    /// The zero polynomial.
    fn zero() -> Self {
        PolyQ { coeffs: Vec::new() }
    }
    /// `true` when there are no coefficients (the normalized zero polynomial).
    fn is_zero(&self) -> bool {
        self.coeffs.is_empty()
    }
    /// Degree of the polynomial; `-1` for the zero polynomial.
    fn deg(&self) -> i32 {
        (self.coeffs.len() as i32) - 1
    }
    /// Leading (highest-degree) coefficient; must not be called on the zero polynomial.
    fn lc(&self) -> &Q {
        debug_assert!(!self.coeffs.is_empty());
        self.coeffs.last().unwrap()
    }
    /// Pops trailing zero coefficients so the leading coefficient is non-zero.
    fn normalize(&mut self) {
        while let Some(top) = self.coeffs.last() {
            if !top.is_zero() {
                break;
            }
            self.coeffs.pop();
        }
    }
    /// Builds a normalized polynomial from little-endian integer coefficients.
    fn from_int_coeffs(c: &[i64]) -> Self {
        let coeffs: Vec<Q> = c.iter().map(|&x| q_int_i128(i128::from(x))).collect();
        let mut poly = PolyQ { coeffs };
        poly.normalize();
        poly
    }
    /// Exact value at the integer `x`, by Horner's rule.
    fn eval_at_int(&self, x: i128) -> Q {
        let point = q_int_i128(x);
        self.coeffs
            .iter()
            .rev()
            .fold(Q::zero(), |acc, c| &acc * &point + c)
    }
    /// Negates every coefficient in place (moving each value out instead of cloning it).
    fn neg_in_place(&mut self) {
        self.coeffs
            .iter_mut()
            .for_each(|c| *c = -std::mem::replace(c, Q::zero()));
    }
}
/// Product `a · b` by schoolbook convolution; the result is normalized.
fn mul(a: &PolyQ, b: &PolyQ) -> PolyQ {
    if a.is_zero() || b.is_zero() {
        return PolyQ::zero();
    }
    let mut acc = vec![Q::zero(); a.coeffs.len() + b.coeffs.len() - 1];
    for (i, ai) in a.coeffs.iter().enumerate() {
        // acc[i + j] += a[i] * b[j] for every j; the slice is long enough for all of b.
        for (slot, bj) in acc[i..].iter_mut().zip(&b.coeffs) {
            *slot = &*slot + ai * bj;
        }
    }
    let mut product = PolyQ { coeffs: acc };
    product.normalize();
    product
}
/// Formal derivative `p'`; the zero polynomial for constants and for zero.
fn derivative(p: &PolyQ) -> PolyQ {
    let len = p.coeffs.len();
    if len < 2 {
        return PolyQ::zero();
    }
    // d/dx (c_i x^i) = i·c_i x^(i-1)
    let mut out = PolyQ {
        coeffs: (1..len).map(|i| &p.coeffs[i] * q_int_usize(i)).collect(),
    };
    out.normalize();
    out
}
/// Replaces `rem` by `rem mod divisor` using exact long division over `Q`.
///
/// `divisor` must be non-zero. Each step cancels the leading term of `rem`
/// exactly, so its degree strictly decreases and the loop terminates.
fn poly_rem_in_place(rem: &mut PolyQ, divisor: &PolyQ) {
    debug_assert!(!divisor.is_zero(), "poly_rem_in_place: zero divisor");
    let divisor_deg = divisor.deg();
    let divisor_lc = divisor.lc().clone();
    while !rem.is_zero() && rem.deg() >= divisor_deg {
        let shift = (rem.deg() - divisor_deg) as usize;
        let factor = rem.lc() / &divisor_lc;
        // rem -= factor · x^shift · divisor; this window has exactly divisor.len() entries.
        for (r, dc) in rem.coeffs[shift..].iter_mut().zip(&divisor.coeffs) {
            *r = &*r - &factor * dc;
        }
        rem.normalize();
    }
}
/// Number of sign changes along `seq`, skipping zero entries (Sturm-style variation count).
#[inline]
fn variation_count(seq: &[Q]) -> usize {
    let mut nonzero_signs = seq
        .iter()
        .filter(|v| !v.is_zero())
        .map(|v| v.is_positive());
    let Some(mut current) = nonzero_signs.next() else {
        return 0;
    };
    let mut changes = 0usize;
    for positive in nonzero_signs {
        if positive != current {
            changes += 1;
            current = positive;
        }
    }
    changes
}
/// Exact sign of `f` at the root of `p` isolated in `(lo, hi)`, via a generalized
/// Sturm sequence over `Q`.
///
/// Coefficients are little-endian integers. The sequence is `S0 = p`,
/// `S1 = (p'·f) mod p`, `S(i+1) = −(S(i−1) mod S(i))`, stopping at the first zero
/// remainder. The difference of sign variations `V(lo) − V(hi)` is the Tarski query:
/// the sum of `sign f(x)` over roots `x` of `p` in `(lo, hi)`. With exactly one root
/// isolated this is `sign f(θ)`. The count is only meaningful when neither endpoint
/// is a root of `p`.
fn sturm_sign_at_root(p_coeffs: &[i64], f_coeffs: &[i64], lo: i64, hi: i64) -> i8 {
    let p = PolyQ::from_int_coeffs(p_coeffs);
    let f = PolyQ::from_int_coeffs(f_coeffs);
    if f.is_zero() {
        return 0;
    }
    let mut seed = mul(&derivative(&p), &f);
    poly_rem_in_place(&mut seed, &p);
    let mut chain = vec![p, seed];
    loop {
        let len = chain.len();
        if chain[len - 1].is_zero() {
            // The terminating zero remainder is not part of the sequence.
            chain.pop();
            break;
        }
        let mut next = chain[len - 2].clone();
        poly_rem_in_place(&mut next, &chain[len - 1]);
        next.neg_in_place();
        chain.push(next);
    }
    let variations_at = |x: i64| -> i32 {
        let values: Vec<Q> = chain
            .iter()
            .map(|s| s.eval_at_int(i128::from(x)))
            .collect();
        variation_count(&values) as i32
    };
    (variations_at(lo) - variations_at(hi)) as i8
}
/// Debug-only counters for how large the coefficients reaching the cubic sign
/// routine are; prints a summary whenever the call count hits a power of two >= 1024.
#[cfg(feature = "debug")]
mod profile {
    use std::sync::atomic::{AtomicU64, Ordering};
    // Total calls recorded.
    static TOTAL: AtomicU64 = AtomicU64::new(0);
    // Calls whose largest |coefficient| was at least 1024.
    static BIG: AtomicU64 = AtomicU64::new(0);
    // Largest |coefficient| seen so far.
    static MAX_C: AtomicU64 = AtomicU64::new(0);
    /// Records one call whose largest absolute coefficient is `c_max` (non-negative).
    pub fn record(c_max: i128) {
        TOTAL.fetch_add(1, Ordering::Relaxed);
        if c_max >= 1024 {
            BIG.fetch_add(1, Ordering::Relaxed);
        }
        MAX_C.fetch_max(c_max as u64, Ordering::Relaxed);
        let t = TOTAL.load(Ordering::Relaxed);
        if t >= 1024 && t.is_power_of_two() {
            let bb = BIG.load(Ordering::Relaxed);
            let big_pct = bb as f64 / t as f64 * 100.0;
            eprintln!(
                "[sign_profile] total={t} big={bb} ({:.1}%) max_c={}",
                big_pct,
                MAX_C.load(Ordering::Relaxed)
            );
        }
    }
}
/// Returns the sign (`-1`, `0` or `1` in `T`) of `a·√m + b·√n`, for `m, n >= 0`.
///
/// Arithmetic is only needed when the two terms have strictly opposite signs.
/// In that case the result is the sign of the term with the larger square,
/// i.e. the sign of `sgn(a)·a²·m + sgn(b)·b²·n`.
pub fn signum_sum_sqrt_expr_2<T: IntRing + ZSigned>(a: T, m: T, b: T, n: T) -> T {
    let (sign_a, sign_b) = (a.signum(), b.signum());
    if sign_a == sign_b || b.is_zero() {
        return sign_a;
    }
    if a.is_zero() {
        return sign_b;
    }
    let signed_sq_a = sign_a * a * a * m;
    let signed_sq_b = sign_b * b * b * n;
    (signed_sq_a + signed_sq_b).signum()
}
/// Returns the sign (`-1`, `0` or `1` in `T`) of `a·√k + b·√m + c·√n + d·√l`,
/// assuming all radicands are non-negative.
///
/// The terms are grouped as `P = a·√k + d·√l` and `Q = b·√m + c·√n`, and each
/// group is signed with [`signum_sum_sqrt_expr_2`]. When the signs agree, or one
/// group vanishes, the answer is immediate. Otherwise it is `sign(Q)·sign(Q² − P²)`.
/// Here `P² = a²k + d²l + 2ad·√(kl)` and `Q² = b²m + c²n + 2bc·√(mn)`, so
/// the difference lies in `Z + Z·√(kl)` whenever `k·l == m·n`.
/// The common case `k = 1, l = m·n` is one instance of this.
///
/// # Panics
/// Panics with `"Unhandled general case!"` when the group signs disagree and
/// `k·l != m·n`.
#[allow(clippy::too_many_arguments)]
pub fn signum_sum_sqrt_expr_4<T: IntRing + ZSigned + FromPrimitive>(
    a: T,
    k: T,
    b: T,
    m: T,
    c: T,
    n: T,
    d: T,
    l: T,
) -> T {
    let sgn_ad_terms = signum_sum_sqrt_expr_2(a, k, d, l);
    let sgn_bc_terms = signum_sum_sqrt_expr_2(b, m, c, n);
    if sgn_bc_terms == sgn_ad_terms {
        return sgn_ad_terms;
    }
    if sgn_ad_terms.is_zero() {
        return sgn_bc_terms;
    }
    if sgn_bc_terms.is_zero() {
        return sgn_ad_terms;
    }
    // Both cross terms must share one radical for the squared comparison to close.
    let kl = k * l;
    if kl != m * n {
        panic!("Unhandled general case!");
    }
    let four = T::from_i8(4).unwrap();
    // Q² − P² = lhs − 2·(ad − bc)·√(kl)
    let lhs = (b * b * m) + (c * c * n) - (d * d * l) - (a * a * k);
    // Compare lhs with 2·(ad − bc)·√(kl) through the monotone map t ↦ sgn(t)·t².
    let sq_lhs = lhs.signum() * lhs * lhs;
    let ad_m_bc = (a * d) - (b * c);
    let sq_rhs = four * kl * ad_m_bc.signum() * ad_m_bc * ad_m_bc;
    // sgn_bc_terms is already ±1 here.
    sgn_bc_terms * (sq_lhs - sq_rhs).signum()
}
/// Returns the sign of `p + q·y` with `p = a + b·√2`, `q = c + d·√2` and `y > 0`.
///
/// From the expansion below, `y² = 2 + √2`. When `p` and `q` have opposite signs,
/// the result is `sign(p)·sign(p² − q²·y²)`. The square difference is written as
/// `rational + irrational·√2` and signed with [`signum_sum_sqrt_expr_2`].
pub fn signum_sum_sqrt_expr_4_zz16<T: IntRing + ZSigned + FromPrimitive>(
    a: T,
    b: T,
    c: T,
    d: T,
) -> T {
    let sign_p = signum_sum_sqrt_expr_2(a, T::one(), b, T::from_i8(2).unwrap());
    let sign_q = signum_sum_sqrt_expr_2(c, T::one(), d, T::from_i8(2).unwrap());
    if sign_p == sign_q || sign_q.is_zero() {
        return sign_p;
    }
    if sign_p.is_zero() {
        return sign_q;
    }
    let two = T::from_i8(2).unwrap();
    let four = T::from_i8(4).unwrap();
    let (a_sq, b_sq, c_sq, d_sq) = (a * a, b * b, c * c, d * d);
    let c_d = c * d;
    let a_b = a * b;
    // p² − q²·(2 + √2) = rational + irrational·√2
    let rational = a_sq + two * b_sq - two * c_sq - four * d_sq - four * c_d;
    let irrational = two * a_b - c_sq - two * d_sq - four * c_d;
    // sign(p) = −sign(q) here.
    -sign_q * signum_sum_sqrt_expr_2(rational, T::one(), irrational, two)
}
/// Returns the sign of `a + b·√5 + c·√y + d·√(5y)` with `y = 10 − 2·√5`.
///
/// This equals `p + q·√y` with `p = a + b·√5` and `q = c + d·√5`. When `p` and `q`
/// have opposite non-zero signs, the result is `sign(p)·sign(p² − q²·y)`. That square
/// difference is `alpha + beta·√5` and is signed with [`signum_sum_sqrt_expr_2`].
pub fn signum_sum_sqrt_expr_4_pentagonal<T: IntRing + ZSigned + FromPrimitive>(
    a: T,
    b: T,
    c: T,
    d: T,
) -> T {
    let sp = signum_sum_sqrt_expr_2(a, T::one(), b, T::from_i8(5).unwrap());
    let sq = signum_sum_sqrt_expr_2(c, T::one(), d, T::from_i8(5).unwrap());
    if sp == sq {
        return sp;
    }
    if sq.is_zero() {
        return sp;
    }
    // As in the sibling helpers: p == 0 means the value is q·√y, whose sign is sign(q).
    // Returning here also skips the squared products below, which can overflow for
    // large c, d even though the answer is already known.
    if sp.is_zero() {
        return sq;
    }
    let int2 = T::from_i8(2).unwrap();
    let int5 = T::from_i8(5).unwrap();
    let int10 = T::from_i8(10).unwrap();
    let int20 = T::from_i8(20).unwrap();
    let int50 = T::from_i8(50).unwrap();
    let aa = a * a;
    let bb = b * b;
    let cc = c * c;
    let dd = d * d;
    // p² − q²·(10 − 2√5) = alpha + beta·√5
    let alpha = aa + int5 * bb - int10 * cc + int20 * c * d - int50 * dd;
    let beta = int2 * a * b + int2 * cc - int20 * c * d + int10 * dd;
    let spq = signum_sum_sqrt_expr_2(alpha, T::one(), beta, int5);
    // sign(p) = −sign(q) here.
    -sq * spq
}
/// Returns the sign of the element `p + q·z` described by eight integer coordinates.
///
/// Reading off the algebra below (with `r` the positive number satisfying `r² = 2 + √2`,
/// as in [`signum_sum_sqrt_expr_4_zz16`], and `z > 0` with `z² = 2 + r`):
/// `p = a0 + a1·√2 + a2·r + a4·√2·r` and `q = a3 + a5·√2 + a6·r + a7·√2·r`.
/// Note the non-sequential argument order.
///
/// `p` and `q` are each signed with [`signum_sum_sqrt_expr_4_zz16`]. When their signs
/// are opposite and non-zero, the result is `sign(p)·sign(p² − q²·(2 + r))`. That square
/// difference is expanded in the basis `(1, √2, r, √2·r)`, using `r² = 2 + √2`, and
/// signed again with the zz16 helper.
#[allow(clippy::too_many_arguments)]
pub fn signum_sum_sqrt_expr_8_zz32<T: IntRing + ZSigned + FromPrimitive>(
    a0: T,
    a1: T,
    a2: T,
    a3: T,
    a4: T,
    a5: T,
    a6: T,
    a7: T,
) -> T {
    let int2 = T::from_i8(2).unwrap();
    let int4 = T::from_i8(4).unwrap();
    // sign(p) and sign(q); see the coordinate mapping in the doc comment.
    let sp = signum_sum_sqrt_expr_4_zz16(a0, a1, a2, a4);
    let sq = signum_sum_sqrt_expr_4_zz16(a3, a5, a6, a7);
    if sp == sq {
        return sp;
    }
    if sq.is_zero() {
        return sp;
    }
    if sp.is_zero() {
        return sq;
    }
    // p² in the basis (1, √2, r, √2·r).
    let p0 = a0 * a0 + int2 * a1 * a1 + int2 * a2 * a2 + int4 * a4 * a4 + int4 * a2 * a4;
    let p1 = a2 * a2 + int2 * a4 * a4 + int2 * a0 * a1 + int4 * a2 * a4;
    let p2 = int2 * a0 * a2 + int4 * a1 * a4;
    let p3 = int2 * a0 * a4 + int2 * a1 * a2;
    // q² in the same basis.
    let q0 = a3 * a3 + int2 * a5 * a5 + int2 * a6 * a6 + int4 * a7 * a7 + int4 * a6 * a7;
    let q1 = a6 * a6 + int2 * a7 * a7 + int2 * a3 * a5 + int4 * a6 * a7;
    let q2 = int2 * a3 * a6 + int4 * a5 * a7;
    let q3 = int2 * a3 * a7 + int2 * a5 * a6;
    // q²·(2 + r): multiplying by r maps 1 → r, √2 → √2·r, r → 2 + √2, √2·r → 2 + 2√2.
    let qy_0 = int2 * q0 + int2 * q2 + int2 * q3;
    let qy_1 = int2 * q1 + q2 + int2 * q3;
    let qy_2 = int2 * q2 + q0;
    let qy_3 = int2 * q3 + q1;
    // p² − q²·(2 + r) = alpha + beta·√2 + (gamma + delta·√2)·r
    let alpha = p0 - qy_0;
    let beta = p1 - qy_1;
    let gamma = p2 - qy_2;
    let delta = p3 - qy_3;
    let spq = signum_sum_sqrt_expr_4_zz16(alpha, beta, gamma, delta);
    // sign(p) = −sign(q) here.
    -sq * spq
}
/// Returns the sign of `p + q·w` for `p, q ∈ Z[√3, √5]` and `w > 0` with `w² = 10 − 2·√5`
/// (the last relation is read off the `qy_*` expansion below).
///
/// The coordinates are `p = a + b·√3 + c·√5 + e·√15` and
/// `q = d + f·√3 + g·√5 + h·√15`. Note the argument order: `e` belongs to `p` and `d`
/// to `q`. Each half is signed with [`signum_sum_sqrt_expr_4`] (radicands 1, 3, 5, 15).
/// When the signs are opposite and non-zero, the result is `sign(p)·sign(p² − q²·w²)`.
/// That square difference is expanded in the basis `(1, √3, √5, √15)`.
///
/// # Panics
/// Only through [`signum_sum_sqrt_expr_4`]; its radicand condition holds here since 1·15 == 3·5.
#[allow(clippy::too_many_arguments)]
pub fn signum_sum_sqrt_expr_8_zz60<T: IntRing + ZSigned + FromPrimitive>(
    a: T,
    b: T,
    c: T,
    d: T,
    e: T,
    f: T,
    g: T,
    h: T,
) -> T {
    let int1 = T::one();
    let int2 = T::from_i8(2).unwrap();
    let int3 = T::from_i8(3).unwrap();
    let int5 = T::from_i8(5).unwrap();
    let int6 = T::from_i8(6).unwrap();
    let int10 = T::from_i8(10).unwrap();
    let int15 = T::from_i8(15).unwrap();
    // sign(p) and sign(q).
    let sp = signum_sum_sqrt_expr_4(a, int1, b, int3, c, int5, e, int15);
    let sq = signum_sum_sqrt_expr_4(d, int1, f, int3, g, int5, h, int15);
    if sp == sq {
        return sp;
    }
    if sq.is_zero() {
        return sp;
    }
    if sp.is_zero() {
        return sq;
    }
    // p² in the basis (1, √3, √5, √15).
    let p2_0 = a * a + int3 * b * b + int5 * c * c + int15 * e * e;
    let p2_1 = int2 * a * b + int10 * c * e;
    let p2_2 = int2 * a * c + int6 * b * e;
    let p2_3 = int2 * a * e + int2 * b * c;
    // q² in the same basis.
    let q2_0 = d * d + int3 * f * f + int5 * g * g + int15 * h * h;
    let q2_1 = int2 * d * f + int10 * g * h;
    let q2_2 = int2 * d * g + int6 * f * h;
    let q2_3 = int2 * d * h + int2 * f * g;
    // q²·(10 − 2√5), using √5·√5 = 5 and √15·√5 = 5·√3.
    let qy_0 = int10 * q2_0 - int10 * q2_2;
    let qy_1 = int10 * q2_1 - int10 * q2_3;
    let qy_2 = int10 * q2_2 - int2 * q2_0;
    let qy_3 = int10 * q2_3 - int2 * q2_1;
    // p² − q²·w² = alpha + beta·√3 + gamma·√5 + delta·√15
    let alpha = p2_0 - qy_0;
    let beta = p2_1 - qy_1;
    let gamma = p2_2 - qy_2;
    let delta = p2_3 - qy_3;
    let spq = signum_sum_sqrt_expr_4(alpha, int1, beta, int3, gamma, int5, delta, int15);
    // sign(p) = −sign(q) here.
    -sq * spq
}
#[cfg(test)]
mod tests {
    // Tests for the `signum_sum_sqrt_expr_*` helpers: hand-picked cases plus
    // exhaustive small grids cross-checked against a plain f64 evaluation.
    use super::*;
    // sign(a·√m + b·√n): all-zero, single-term, and opposite-sign cases where the
    // magnitude comparison decides.
    #[test]
    fn test_sum_root_expr_sign_2() {
        assert_eq!(signum_sum_sqrt_expr_2(0, 2, 0, 3), 0);
        assert_eq!(signum_sum_sqrt_expr_2(1, 2, 0, 3), 1);
        assert_eq!(signum_sum_sqrt_expr_2(0, 2, -1, 3), -1);
        assert_eq!(signum_sum_sqrt_expr_2(2, 2, -1, 3), 1);
        assert_eq!(signum_sum_sqrt_expr_2(-5, 2, 4, 3), -1);
        assert_eq!(signum_sum_sqrt_expr_2(-5, 2, 5, 3), 1);
    }
    // sign(a + b·√2 + c·√3 + d·√6).
    #[test]
    fn test_sum_root_expr_sign_4() {
        let sign_zz24 = |a, b, c, d| signum_sum_sqrt_expr_4(a, 1, b, 2, c, 3, d, 6);
        assert_eq!(sign_zz24(0, 0, 0, 0), 0);
        assert_eq!(sign_zz24(1, 1, 1, 1), 1);
        assert_eq!(sign_zz24(-1, -1, -1, -1), -1);
        assert_eq!(sign_zz24(1, 0, 0, 0), 1);
        assert_eq!(sign_zz24(0, -1, 0, 0), -1);
        assert_eq!(sign_zz24(0, 0, 1, 0), 1);
        assert_eq!(sign_zz24(0, 0, 0, -1), -1);
        assert_eq!(sign_zz24(5, 7, 11, -13), 1);
        assert_eq!(sign_zz24(5, 7, 11, -14), -1);
        assert_eq!(sign_zz24(17, -11, 9, -7), -1);
        assert_eq!(sign_zz24(18, -11, 9, -7), 1);
        assert_eq!(sign_zz24(18, -11, 8, -7), -1);
        assert_eq!(sign_zz24(18, -11, 8, -6), 1);
        {
            // a, b·√2, c·√3 and d·√6 are all ≈ 130 here, so every pattern with two
            // negative terms nearly cancels.
            let (a, b, c, d) = (130, 92, 75, 53);
            assert_eq!(sign_zz24(-a, -b, c, d), -1);
            assert_eq!(sign_zz24(-a, b, -c, d), 1);
            assert_eq!(sign_zz24(-a, b, c, -d), 1);
            assert_eq!(sign_zz24(a, -b, -c, d), -1);
            assert_eq!(sign_zz24(a, -b, c, -d), -1);
            assert_eq!(sign_zz24(a, b, -c, -d), 1);
        }
        {
            // Same idea with all four terms ≈ 485.
            let (a, b, c, d) = (485, 343, 280, 198);
            assert_eq!(sign_zz24(-a, -b, c, d), -1);
            assert_eq!(sign_zz24(-a, b, -c, d), 1);
            assert_eq!(sign_zz24(-a, b, c, -d), 1);
            assert_eq!(sign_zz24(a, -b, -c, d), -1);
            assert_eq!(sign_zz24(a, -b, c, -d), -1);
            assert_eq!(sign_zz24(a, b, -c, -d), 1);
        }
    }
    // sign(a + b·√5 + c·√y + d·√(5y)), y = 10 − 2√5 (same value as the f64 stress test).
    #[test]
    fn test_sum_root_expr_sign_4_pentagonal() {
        let sign_zz10 = signum_sum_sqrt_expr_4_pentagonal::<i64>;
        assert_eq!(sign_zz10(0, 0, 0, 0), 0);
        assert_eq!(sign_zz10(1, 0, 0, 0), 1);
        assert_eq!(sign_zz10(-1, 0, 0, 0), -1);
        assert_eq!(sign_zz10(0, 1, 0, 0), 1);
        assert_eq!(sign_zz10(0, -1, 0, 0), -1);
        assert_eq!(sign_zz10(0, 0, 1, 0), 1);
        assert_eq!(sign_zz10(0, 0, -1, 0), -1);
        assert_eq!(sign_zz10(0, 0, 0, 1), 1);
        assert_eq!(sign_zz10(0, 0, 0, -1), -1);
        assert_eq!(sign_zz10(1, 1, 0, 0), 1);
        assert_eq!(sign_zz10(-1, -1, 0, 0), -1);
        assert_eq!(sign_zz10(0, 0, 1, 1), 1);
        assert_eq!(sign_zz10(0, 0, -1, -1), -1);
        assert_eq!(sign_zz10(-3, 0, 1, 0), -1);
        assert_eq!(sign_zz10(3, 0, -1, 0), 1);
        assert_eq!(sign_zz10(1, 0, 0, -2), -1);
        assert_eq!(sign_zz10(-1, 0, 0, 2), 1);
    }
    // sign(a + b·√2 + (c + d·√2)·y) with y² = 2 + √2: axis vectors, then
    // opposite-sign pairs on either side of zero.
    #[test]
    fn test_sum_root_expr_sign_4_zz16() {
        let sign_zz16 = signum_sum_sqrt_expr_4_zz16::<i64>;
        assert_eq!(sign_zz16(0, 0, 0, 0), 0);
        assert_eq!(sign_zz16(1, 0, 0, 0), 1);
        assert_eq!(sign_zz16(-1, 0, 0, 0), -1);
        assert_eq!(sign_zz16(0, 1, 0, 0), 1);
        assert_eq!(sign_zz16(0, -1, 0, 0), -1);
        assert_eq!(sign_zz16(0, 0, 1, 0), 1);
        assert_eq!(sign_zz16(0, 0, -1, 0), -1);
        assert_eq!(sign_zz16(0, 0, 0, 1), 1);
        assert_eq!(sign_zz16(0, 0, 0, -1), -1);
        assert_eq!(sign_zz16(1, 1, 0, 0), 1);
        assert_eq!(sign_zz16(-1, -1, 0, 0), -1);
        assert_eq!(sign_zz16(0, 0, 1, 1), 1);
        assert_eq!(sign_zz16(0, 0, -1, -1), -1);
        assert_eq!(sign_zz16(1, -1, 0, 0), -1);
        assert_eq!(sign_zz16(2, -1, 0, 0), 1);
        assert_eq!(sign_zz16(1, 0, -1, 0), -1);
        assert_eq!(sign_zz16(2, 0, -1, 0), 1);
        assert_eq!(sign_zz16(0, 2, 0, -1), 1);
        assert_eq!(sign_zz16(0, 1, 0, -1), -1);
        assert_eq!(sign_zz16(0, 0, 1, 1), 1);
        assert_eq!(sign_zz16(1, 1, 0, 0), 1);
        assert_eq!(sign_zz16(1, 1, 1, 1), 1);
        assert_eq!(sign_zz16(-1, -1, -1, -1), -1);
    }
    // sign(p + q·w) in the zz60 coordinates (see `signum_sum_sqrt_expr_8_zz60`).
    #[test]
    fn test_sum_root_expr_sign_8_zz60() {
        let s = signum_sum_sqrt_expr_8_zz60::<i64>;
        assert_eq!(s(0, 0, 0, 0, 0, 0, 0, 0), 0);
        assert_eq!(s(1, 0, 0, 0, 0, 0, 0, 0), 1);
        assert_eq!(s(-1, 0, 0, 0, 0, 0, 0, 0), -1);
        assert_eq!(s(0, 1, 0, 0, 0, 0, 0, 0), 1);
        assert_eq!(s(0, -1, 0, 0, 0, 0, 0, 0), -1);
        assert_eq!(s(0, 0, 1, 0, 0, 0, 0, 0), 1);
        assert_eq!(s(0, 0, -1, 0, 0, 0, 0, 0), -1);
        assert_eq!(s(0, 0, 0, 1, 0, 0, 0, 0), 1);
        assert_eq!(s(0, 0, 0, -1, 0, 0, 0, 0), -1);
        assert_eq!(s(0, 0, 0, 0, 1, 0, 0, 0), 1);
        assert_eq!(s(0, 0, 0, 0, 0, 1, 0, 0), 1);
        assert_eq!(s(0, 0, 0, 0, 0, 0, 1, 0), 1);
        assert_eq!(s(0, 0, 0, 0, 0, 0, 0, 1), 1);
        assert_eq!(s(1, 1, 1, 1, 1, 1, 1, 1), 1);
        assert_eq!(s(-1, -1, -1, -1, -1, -1, -1, -1), -1);
        assert_eq!(s(1, -1, 0, 0, 0, 0, 0, 0), -1);
        assert_eq!(s(2, -1, 0, 0, 0, 0, 0, 0), 1);
        // −√3 − √5 + √15 ≈ −0.095
        assert_eq!(s(0, -1, -1, 0, 1, 0, 0, 0), -1);
        assert_eq!(s(0, 0, -1, 0, 1, 0, 0, 0), 1);
        assert_eq!(s(0, 0, 0, 1, 0, 0, 0, 0), 1);
        assert_eq!(s(0, 0, 0, -1, 0, 0, 0, 0), -1);
        assert_eq!(s(-1, 0, 0, 1, 0, 0, 0, 0), 1);
        assert_eq!(s(-3, 0, 0, 1, 0, 0, 0, 0), -1);
        assert_eq!(s(3, 0, 0, -1, 0, 0, 0, 0), 1);
    }
    // zz32: every axis vector must take the sign of its single non-zero entry.
    #[test]
    fn test_sum_root_expr_sign_8_zz32() {
        let s = signum_sum_sqrt_expr_8_zz32::<i64>;
        assert_eq!(s(0, 0, 0, 0, 0, 0, 0, 0), 0);
        for i in 0..8 {
            let mut v = [0i64; 8];
            v[i] = 1;
            assert_eq!(
                s(v[0], v[1], v[2], v[3], v[4], v[5], v[6], v[7]),
                1,
                "axis {i} positive should be +1"
            );
            v[i] = -1;
            assert_eq!(
                s(v[0], v[1], v[2], v[3], v[4], v[5], v[6], v[7]),
                -1,
                "axis {i} negative should be -1"
            );
        }
        assert_eq!(s(1, 1, 1, 1, 1, 1, 1, 1), 1);
        assert_eq!(s(-1, -1, -1, -1, -1, -1, -1, -1), -1);
        assert_eq!(s(1, -1, 0, 0, 0, 0, 0, 0), -1);
        assert_eq!(s(2, -1, 0, 0, 0, 0, 0, 0), 1);
        assert_eq!(s(0, 0, 0, -1, 1, 0, 0, 0), 1);
        assert_eq!(s(0, 0, 0, 1, -1, 0, 0, 0), -1);
        assert_eq!(s(-1, 0, 0, 1, 0, 0, 0, 0), 1);
        assert_eq!(s(-2, 0, 0, 1, 0, 0, 0, 0), -1);
        assert_eq!(s(-3, 0, 0, 1, 0, 0, 0, 0), -1);
        assert_eq!(s(3, 0, 0, -1, 0, 0, 0, 0), 1);
    }
    // Sign of an f64 as the i64 values returned by the helpers (exact 0.0 → 0).
    fn sign_f64(x: f64) -> i64 {
        if x == 0.0 {
            0
        } else if x > 0.0 {
            1
        } else {
            -1
        }
    }
    // Exhaustive small grid against f64; relies on f64 being accurate enough at these
    // coefficient sizes that no non-zero value rounds across zero.
    #[test]
    fn stress_signum_sum_sqrt_expr_2_matches_f64() {
        for a in -10..=10 {
            for b in -10..=10 {
                for (m, n) in [(2f64, 3f64), (2f64, 5f64), (3f64, 7f64)] {
                    let got = signum_sum_sqrt_expr_2(a, m as i64, b, n as i64);
                    let exp = sign_f64((a as f64) * m.sqrt() + (b as f64) * n.sqrt());
                    assert_eq!(got, exp, "a={a} b={b} m={m} n={n}");
                }
            }
        }
    }
    // Exhaustive 13^4 grid of a + b√2 + c√3 + d√6 against f64.
    #[test]
    fn stress_signum_sum_sqrt_expr_4_matches_f64() {
        for a in -6..=6 {
            for b in -6..=6 {
                for c in -6..=6 {
                    for d in -6..=6 {
                        let got = signum_sum_sqrt_expr_4(a, 1, b, 2, c, 3, d, 6);
                        let val = (a as f64)
                            + (b as f64) * 2f64.sqrt()
                            + (c as f64) * 3f64.sqrt()
                            + (d as f64) * 6f64.sqrt();
                        let exp = sign_f64(val);
                        assert_eq!(got, exp, "a={a} b={b} c={c} d={d}");
                    }
                }
            }
        }
    }
    // Exhaustive 13^4 grid of the pentagonal expression against f64.
    #[test]
    fn stress_signum_sum_sqrt_expr_4_pentagonal_matches_f64() {
        let sqrt5: f64 = 2.236_067_977_499_79;
        let y: f64 = 10.0 - 2.0 * sqrt5;
        for a in -6..=6 {
            for b in -6..=6 {
                for c in -6..=6 {
                    for d in -6..=6 {
                        let got = signum_sum_sqrt_expr_4_pentagonal(a, b, c, d);
                        let val = (a as f64)
                            + (b as f64) * sqrt5
                            + (c as f64) * y.sqrt()
                            + (d as f64) * (5.0 * y).sqrt();
                        let exp = sign_f64(val);
                        assert_eq!(got, exp, "a={a} b={b} c={c} d={d}");
                    }
                }
            }
        }
    }
}
pub fn sign_at_cubic_root_in_interval(
coeffs: [i64; 3],
minpoly: [i64; 4],
lo: (i64, i64),
hi: (i64, i64),
) -> i8 {
if coeffs == [0, 0, 0] {
return 0;
}
debug_assert!(minpoly[3] == 1, "minpoly must be monic");
debug_assert!(lo.1 > 0 && hi.1 > 0, "denominators must be positive");
let a = coeffs[0] as i128;
let b = coeffs[1] as i128;
let d = coeffs[2] as i128;
let m0 = minpoly[0] as i128;
let m1 = minpoly[1] as i128;
let m2 = minpoly[2] as i128;
let m3 = minpoly[3] as i128;
let f_abs_num = |n: i128, dn: i128| -> i128 {
a.checked_mul(dn)
.and_then(|x| x.checked_mul(dn))
.and_then(|aa| {
b.checked_mul(n)
.and_then(|x| x.checked_mul(dn))
.map(|bb| aa + bb)
})
.and_then(|ab| {
d.checked_mul(n)
.and_then(|x| x.checked_mul(n))
.map(|dd| ab + dd)
})
.expect("sign_at_cubic_root: i128 overflow in f_abs_num")
};
let p_sign = |n: i128, dn: i128| -> i8 {
let v = (m0 * dn * dn * dn)
.checked_add(m1 * n * dn * dn)
.and_then(|x| x.checked_add(m2 * n * n * dn))
.and_then(|x| x.checked_add(m3 * n * n * n))
.expect("sign_at_cubic_root: i128 overflow in p_sign");
v.signum() as i8
};
let mut lo_n = lo.0 as i128;
let mut lo_d = lo.1 as i128;
let mut hi_n = hi.0 as i128;
let mut hi_d = hi.1 as i128;
let p_sign_lo = p_sign(lo_n, lo_d);
debug_assert!(
p_sign_lo != 0,
"sign_at_cubic_root: lo endpoint is a rational root of minpoly"
);
debug_assert!(
p_sign(hi_n, hi_d) == -p_sign_lo,
"sign_at_cubic_root: minpoly doesn't change sign across the isolating interval"
);
const MAX_BISECTIONS: usize = 38;
let l_bound = b.abs() + 4 * d.abs();
let c_max = a.abs().max(b.abs()).max(d.abs());
#[cfg(feature = "debug")]
profile::record(c_max);
if c_max >= 4096 {
debug_assert!(
lo.1 == 1 && hi.1 == 1,
"Sturm fallback expects integer-endpoint isolating intervals"
);
return sturm_sign_at_root(&minpoly, &coeffs, lo.0, hi.0);
}
const P: u32 = 50;
let fast = match minpoly {
[1, -2, -1, 1] => Some((
2284227452366899633255690010624i128,
4116035643581264728217479023376i128,
)),
[-1, -3, 0, 1] => Some((
2382403829538589549223439499264i128,
4477454596699003705844035366321i128,
)),
_ => None,
};
let fast_sign = fast.and_then(|(c_two_p, c_sq)| {
let v = (1i128 << (2 * P))
.checked_mul(a)
.and_then(|t| b.checked_mul(c_two_p).and_then(|x| t.checked_add(x)))
.and_then(|t| d.checked_mul(c_sq).and_then(|x| t.checked_add(x)))?;
let thresh = l_bound.checked_mul(1i128 << P)?;
(v.abs() > thresh).then_some(v.signum() as i8)
});
if let Some(s) = fast_sign {
return s;
}
for _ in 0..MAX_BISECTIONS {
let v_lo = f_abs_num(lo_n, lo_d);
let v_hi = f_abs_num(hi_n, hi_d);
let s_lo = v_lo.signum() as i8;
let s_hi = v_hi.signum() as i8;
if s_lo != 0 && s_lo == s_hi {
let width_num = hi_n
.checked_mul(lo_d)
.and_then(|x| x.checked_sub(lo_n.checked_mul(hi_d).unwrap_or(0)))
.expect("sign_at_cubic_root: width num overflow");
debug_assert!(width_num > 0);
let lhs_lo = v_lo.unsigned_abs() as i128 * hi_d;
let rhs_lo = (l_bound)
.checked_mul(width_num)
.and_then(|x| x.checked_mul(lo_d))
.expect("sign_at_cubic_root: width-check lo overflow");
let lhs_hi = v_hi.unsigned_abs() as i128 * lo_d;
let rhs_hi = (l_bound)
.checked_mul(width_num)
.and_then(|x| x.checked_mul(hi_d))
.expect("sign_at_cubic_root: width-check hi overflow");
if lhs_lo > rhs_lo && lhs_hi > rhs_hi {
return s_lo;
}
}
let m_num_raw = lo_n
.checked_mul(hi_d)
.and_then(|x| x.checked_add(hi_n.checked_mul(lo_d).unwrap_or(i128::MAX)))
.expect("sign_at_cubic_root: midpoint num overflow");
let m_den_raw = 2i128
.checked_mul(lo_d)
.and_then(|x| x.checked_mul(hi_d))
.expect("sign_at_cubic_root: midpoint den overflow");
let g = {
let mut x = m_num_raw.unsigned_abs();
let mut y = m_den_raw.unsigned_abs();
while y != 0 {
let t = x % y;
x = y;
y = t;
}
x as i128
};
let m_num = m_num_raw / g;
let m_den = m_den_raw / g;
let s_p_mid = p_sign(m_num, m_den);
if s_p_mid == 0 {
debug_assert!(
false,
"sign_at_cubic_root: midpoint is a rational root of minpoly"
);
return 0;
}
if s_p_mid == p_sign_lo {
lo_n = m_num;
lo_d = m_den;
} else {
hi_n = m_num;
hi_d = m_den;
}
}
debug_assert!(
lo.1 == 1 && hi.1 == 1,
"Sturm fallback expects integer-endpoint isolating intervals"
);
sturm_sign_at_root(&minpoly, &coeffs, lo.0, hi.0)
}
#[cfg(test)]
mod cubic_root_tests {
    // Tests for `sign_at_cubic_root_in_interval` on the two cubic fields whose minimal
    // polynomials have a hard-coded fast path, with the root isolated in (1, 2).
    use super::sign_at_cubic_root_in_interval;
    // x^3 - x^2 - 2x + 1; the fuzz tests evaluate its root in (1, 2) as 2cos(π/7).
    const ZZ14_MINPOLY: [i64; 4] = [1, -2, -1, 1];
    const ZZ14_ISO_LO: (i64, i64) = (1, 1);
    const ZZ14_ISO_HI: (i64, i64) = (2, 1);
    // x^3 - 3x - 1; the fuzz tests evaluate its root in (1, 2) as 2cos(π/9).
    const ZZ18_MINPOLY: [i64; 4] = [-1, -3, 0, 1];
    const ZZ18_ISO_LO: (i64, i64) = (1, 1);
    const ZZ18_ISO_HI: (i64, i64) = (2, 1);
    // sign(a + b·θ + d·θ²) in the ZZ14 field.
    fn s14(a: i64, b: i64, d: i64) -> i8 {
        sign_at_cubic_root_in_interval([a, b, d], ZZ14_MINPOLY, ZZ14_ISO_LO, ZZ14_ISO_HI)
    }
    // sign(a + b·θ + d·θ²) in the ZZ18 field.
    fn s18(a: i64, b: i64, d: i64) -> i8 {
        sign_at_cubic_root_in_interval([a, b, d], ZZ18_MINPOLY, ZZ18_ISO_LO, ZZ18_ISO_HI)
    }
    // Every small coefficient triple: fast path / bisection must agree with Sturm.
    #[test]
    fn cubic_sign_exhaustive_small_grid_matches_sturm() {
        for (minpoly, lo, hi) in [
            (ZZ14_MINPOLY, ZZ14_ISO_LO, ZZ14_ISO_HI),
            (ZZ18_MINPOLY, ZZ18_ISO_LO, ZZ18_ISO_HI),
        ] {
            for a in -6..=6 {
                for b in -6..=6 {
                    for d in -6..=6 {
                        let got = sign_at_cubic_root_in_interval([a, b, d], minpoly, lo, hi);
                        let want = super::sturm_sign_at_root(&minpoly, &[a, b, d], lo.0, hi.0);
                        assert_eq!(
                            got, want,
                            "minpoly {minpoly:?} coeffs [{a},{b},{d}]: fast/bisect={got} sturm={want}"
                        );
                    }
                }
            }
        }
    }
    #[test]
    fn zero_input_is_zero() {
        assert_eq!(s14(0, 0, 0), 0);
        assert_eq!(s18(0, 0, 0), 0);
    }
    // With one non-zero coefficient the sign is that coefficient's sign (θ > 0).
    #[test]
    fn single_component_vectors() {
        assert_eq!(s14(5, 0, 0), 1);
        assert_eq!(s14(-3, 0, 0), -1);
        assert_eq!(s14(0, 7, 0), 1);
        assert_eq!(s14(0, -2, 0), -1);
        assert_eq!(s18(0, 7, 0), 1);
        assert_eq!(s14(0, 0, 4), 1);
        assert_eq!(s14(0, 0, -9), -1);
        assert_eq!(s18(0, 0, 4), 1);
    }
    // Mixed-sign inputs chosen by the original author as near-zero cases.
    #[test]
    fn cubic_root_adversarial_near_zero() {
        assert_eq!(s14(-17, -23, 18), 1);
        assert_eq!(s14(17, 23, -18), -1); assert_eq!(s14(-18, 19, -5), 1);
        assert_eq!(s14(-22, 5, 4), -1);
        assert_eq!(s14(-5, 28, -14), -1);
        assert_eq!(s18(-7, 15, -6), -1);
        assert_eq!(s18(7, -15, 6), 1);
        assert_eq!(s18(-29, -9, 13), 1);
        assert_eq!(s18(-6, -25, 15), -1);
        assert_eq!(s14(127, -178, 59), -1);
    }
    // Expected signs recorded from an external (SymPy) oracle, per the test name.
    #[test]
    fn sympy_oracle_zz14() {
        const CASES: &[(i64, i64, i64, i8)] = &[
            (3, -15, 7, -1),
            (20, -16, -15, -1),
            (-11, -10, -13, -1),
            (14, 15, 16, 1),
            (16, -18, -15, -1),
            (-10, -1, -2, -1),
            (6, -19, 6, -1),
            (-2, -4, -11, -1),
            (-5, 6, -14, -1),
            (-5, 6, -20, -1),
        ];
        for &(a, b, d, expected) in CASES {
            let got = s14(a, b, d);
            assert_eq!(
                got, expected,
                "ZZ14 sign({a} + {b}*c + {d}*c^2) = got {got}, expected {expected}"
            );
        }
    }
    #[test]
    fn sympy_oracle_zz18() {
        const CASES: &[(i64, i64, i64, i8)] = &[
            (7, 13, -8, 1),
            (-3, 19, -7, 1),
            (6, 4, 15, 1),
            (-18, 18, 18, 1),
            (-17, 2, -16, -1),
            (16, 18, 13, 1),
            (-4, -2, 6, 1),
            (-8, 6, -17, -1),
            (14, -4, 18, 1),
            (-17, 19, 15, 1),
        ];
        for &(a, b, d, expected) in CASES {
            let got = s18(a, b, d);
            assert_eq!(
                got, expected,
                "ZZ18 sign({a} + {b}*c + {d}*c^2) = got {got}, expected {expected}"
            );
        }
    }
}
/// Returns the sign of `s·X − k`, where `X = x[0] + x[1]·θ + x[2]·θ²`, `θ` is the
/// root of `minpoly` isolated by `(iso_lo, iso_hi)`, and `s > 0` satisfies
/// `s² = 4 − θ²`.
///
/// The sign of `X` is computed first. When `s·X` and `k` have opposite signs (or one
/// vanishes) the answer is immediate. Otherwise the magnitudes are compared through
/// `s²·X² − k²`, reduced modulo `minpoly` to a quadratic in `θ`.
///
/// # Panics
/// If an intermediate value overflows, or the reduced coefficients do not fit in `i64`.
pub fn sign_at_s_times_x_minus_k(
    x: [i64; 3],
    k: i64,
    minpoly: [i64; 4],
    iso_lo: (i64, i64),
    iso_hi: (i64, i64),
) -> i8 {
    let sign_x = sign_at_cubic_root_in_interval(x, minpoly, iso_lo, iso_hi);
    let sign_k = k.signum() as i8;
    match (sign_x, sign_k) {
        (0, _) => -sign_k,
        (_, 0) => sign_x,
        _ if sign_x != sign_k => sign_x,
        _ => {
            // θ³ = −minpoly[0] − minpoly[1]·θ − minpoly[2]·θ² (minpoly is monic).
            let reduction: [i128; 3] = [0, 1, 2].map(|i| -(minpoly[i] as i128));
            let x_wide: [i128; 3] = x.map(i128::from);
            let x_sq = poly_mul_deg2_mod_cubic(x_wide, x_wide, reduction);
            // s² = 4 − θ²
            let s_sq_x_sq = poly_mul_deg2_mod_cubic([4, 0, -1], x_sq, reduction);
            let k_wide = k as i128;
            let k_sq = k_wide
                .checked_mul(k_wide)
                .expect("sign_at_s_times_x_minus_k: K^2 overflow");
            let constant = s_sq_x_sq[0]
                .checked_sub(k_sq)
                .expect("sign_at_s_times_x_minus_k: const subtraction overflow");
            let narrowed: [i64; 3] = [
                constant
                    .try_into()
                    .expect("sign_at_s_times_x_minus_k: result[0] exceeds i64"),
                s_sq_x_sq[1]
                    .try_into()
                    .expect("sign_at_s_times_x_minus_k: result[1] exceeds i64"),
                s_sq_x_sq[2]
                    .try_into()
                    .expect("sign_at_s_times_x_minus_k: result[2] exceeds i64"),
            ];
            let sign_diff = sign_at_cubic_root_in_interval(narrowed, minpoly, iso_lo, iso_hi);
            // Same-sign case: sign(sX − k) = sign(X) · sign(s²X² − k²).
            (i16::from(sign_x) * i16::from(sign_diff)) as i8
        }
    }
}
/// Multiplies two elements `a0 + a1·θ + a2·θ²` and `b0 + b1·θ + b2·θ²` of a cubic
/// field, where `θ³ = c3[0] + c3[1]·θ + c3[2]·θ²`. Returns the product reduced to the
/// same `[1, θ, θ²]` coordinates.
///
/// Uses `θ⁴ = θ·θ³ = c3[2]·c3[0] + (c3[0] + c3[2]·c3[1])·θ + (c3[1] + c3[2]²)·θ²`.
///
/// # Panics
/// On i128 overflow. Without this check, release builds would wrap silently and the
/// caller's later `i64` range check could accept a wrong value.
fn poly_mul_deg2_mod_cubic(a: [i128; 3], b: [i128; 3], c3: [i128; 3]) -> [i128; 3] {
    fn checked_product(x: i128, y: i128) -> i128 {
        x.checked_mul(y)
            .expect("poly_mul_deg2_mod_cubic: i128 overflow")
    }
    fn checked_total(terms: &[i128]) -> i128 {
        terms
            .iter()
            .try_fold(0i128, |acc, &t| acc.checked_add(t))
            .expect("poly_mul_deg2_mod_cubic: i128 overflow")
    }
    let mul = checked_product;
    let c4: [i128; 3] = [
        mul(c3[2], c3[0]),
        checked_total(&[c3[0], mul(c3[2], c3[1])]),
        checked_total(&[c3[1], mul(c3[2], c3[2])]),
    ];
    // Coefficients of θ³ and θ⁴ in the unreduced product.
    let coef_c3 = checked_total(&[mul(a[1], b[2]), mul(a[2], b[1])]);
    let coef_c4 = mul(a[2], b[2]);
    [
        checked_total(&[
            mul(a[0], b[0]),
            mul(coef_c3, c3[0]),
            mul(coef_c4, c4[0]),
        ]),
        checked_total(&[
            mul(a[0], b[1]),
            mul(a[1], b[0]),
            mul(coef_c3, c3[1]),
            mul(coef_c4, c4[1]),
        ]),
        checked_total(&[
            mul(a[0], b[2]),
            mul(a[1], b[1]),
            mul(a[2], b[0]),
            mul(coef_c3, c3[2]),
            mul(coef_c4, c4[2]),
        ]),
    ]
}
#[cfg(test)]
mod s_times_x_minus_k_tests {
    // Tests for `sign_at_s_times_x_minus_k`, i.e. sign(s·X − K) with
    // X = n0 + n1·θ + n2·θ² and s² = 4 − θ² (the fuzz tests use s = 2·sin(π/7) / 2·sin(π/9)).
    use super::sign_at_s_times_x_minus_k;
    const ZZ14_MINPOLY: [i64; 4] = [1, -2, -1, 1];
    const ZZ14_ISO_LO: (i64, i64) = (1, 1);
    const ZZ14_ISO_HI: (i64, i64) = (2, 1);
    const ZZ18_MINPOLY: [i64; 4] = [-1, -3, 0, 1];
    const ZZ18_ISO_LO: (i64, i64) = (1, 1);
    const ZZ18_ISO_HI: (i64, i64) = (2, 1);
    fn s14(n0: i64, n1: i64, n2: i64, k: i64) -> i8 {
        sign_at_s_times_x_minus_k([n0, n1, n2], k, ZZ14_MINPOLY, ZZ14_ISO_LO, ZZ14_ISO_HI)
    }
    fn s18(n0: i64, n1: i64, n2: i64, k: i64) -> i8 {
        sign_at_s_times_x_minus_k([n0, n1, n2], k, ZZ18_MINPOLY, ZZ18_ISO_LO, ZZ18_ISO_HI)
    }
    // X == 0 gives −sign(K); K == 0 gives sign(X).
    #[test]
    fn zero_x_and_zero_k_special_cases() {
        assert_eq!(s14(0, 0, 0, 0), 0);
        assert_eq!(s14(0, 0, 0, 5), -1);
        assert_eq!(s14(0, 0, 0, -5), 1);
        assert_eq!(s14(1, 0, 0, 0), 1);
        assert_eq!(s14(-1, 0, 0, 0), -1);
    }
    // Expected signs recorded from an external (SymPy) oracle, per the test name.
    #[test]
    fn sympy_oracle_zz14_s_times_x_minus_k() {
        const CASES: &[(i64, i64, i64, i64, i8)] = &[
            (-5, 2, 9, -11, 1),
            (5, -4, 4, 8, 1),
            (-1, 9, -10, -2, -1),
            (-2, -7, 6, -4, 1),
            (8, 8, -6, -2, 1),
            (-1, -8, -6, -8, -1),
            (1, 6, -10, 7, -1),
            (-6, 6, -2, -14, 1),
            (8, 3, -6, 9, -1),
            (-3, -5, -9, 3, -1),
        ];
        for &(n0, n1, n2, k, expected) in CASES {
            let got = s14(n0, n1, n2, k);
            assert_eq!(
                got, expected,
                "ZZ14 sign(s*({n0} + {n1}*c + {n2}*c^2) - {k}) = got {got}, expected {expected}"
            );
        }
    }
    #[test]
    fn sympy_oracle_zz18_s_times_x_minus_k() {
        const CASES: &[(i64, i64, i64, i64, i8)] = &[
            (10, 5, 0, -3, 1),
            (3, 3, 0, -11, 1),
            (-7, 0, 9, 5, 1),
            (7, 1, 3, -3, 1),
            (-3, 9, -5, -11, 1),
            (0, 6, 0, -12, 1),
            (5, -2, 8, -7, 1),
            (-1, -2, -9, -14, -1),
            (-5, -8, -6, -11, -1),
            (1, -8, 4, -8, 1),
        ];
        for &(n0, n1, n2, k, expected) in CASES {
            let got = s18(n0, n1, n2, k);
            assert_eq!(
                got, expected,
                "ZZ18 sign(s*({n0} + {n1}*c + {n2}*c^2) - {k}) = got {got}, expected {expected}"
            );
        }
    }
    // Constant X: s ≈ 0.868 for ZZ14, so s·5 ≈ 4.34 and s·10 ≈ 8.68 straddle these K.
    #[test]
    fn magnitude_comparison_close_cases() {
        assert_eq!(s14(5, 0, 0, 4), 1);
        assert_eq!(s14(5, 0, 0, 5), -1);
        assert_eq!(s14(10, 0, 0, 8), 1);
        assert_eq!(s14(10, 0, 0, 9), -1);
        assert_eq!(s14(-5, 0, 0, -4), -1);
        assert_eq!(s14(-5, 0, 0, -5), 1);
    }
}
#[cfg(test)]
mod fuzz_tests {
    // Randomized cross-checks: the exact helpers against f64 (skipping values too close
    // to zero for f64 to decide), and the bisection path against the Sturm query.
    use super::*;
    // Minimal xorshift PRNG so the fuzz runs are reproducible without external crates.
    // A zero seed would stay zero forever; all seeds used below are non-zero.
    struct Xorshift64(u64);
    impl Xorshift64 {
        fn new(seed: u64) -> Self {
            Self(seed)
        }
        fn next(&mut self) -> u64 {
            let mut x = self.0;
            x ^= x << 13;
            x ^= x >> 7;
            x ^= x << 17;
            self.0 = x;
            x
        }
        // Value in [lo, hi] by modulo reduction (slightly biased; fine for fuzzing).
        fn next_i64_in(&mut self, lo: i64, hi: i64) -> i64 {
            let range = (hi - lo + 1) as u64;
            lo + (self.next() % range) as i64
        }
    }
    // f64 value of a + b·c + d·c², and a margin scaled by the largest term below which
    // the f64 sign is not trusted.
    fn f64_value_and_safety(a: i64, b: i64, d: i64, c_f64: f64) -> (f64, f64) {
        let af = a as f64;
        let bf = b as f64;
        let df = d as f64;
        let cc = c_f64 * c_f64;
        let value = af + bf * c_f64 + df * cc;
        let max_term = af.abs().max((bf * c_f64).abs()).max((df * cc).abs());
        let safety = 1e-12 * (1.0 + max_term);
        (value, safety)
    }
    // Checks `sign_at_cubic_root_in_interval` against f64 on random coefficients and
    // requires that more than 10% of the cases were decidable in f64.
    #[allow(clippy::too_many_arguments)]
    fn fuzz_cubic_root<F>(
        seed: u64,
        iterations: u64,
        coeff_lo: i64,
        coeff_hi: i64,
        minpoly: [i64; 4],
        iso_lo: (i64, i64),
        iso_hi: (i64, i64),
        c_f64: F,
        label: &str,
    ) where
        F: Fn() -> f64,
    {
        let c = c_f64();
        let mut rng = Xorshift64::new(seed);
        let mut checked = 0u64;
        let mut skipped = 0u64;
        for _ in 0..iterations {
            let a = rng.next_i64_in(coeff_lo, coeff_hi);
            let b = rng.next_i64_in(coeff_lo, coeff_hi);
            let d = rng.next_i64_in(coeff_lo, coeff_hi);
            let exact = sign_at_cubic_root_in_interval([a, b, d], minpoly, iso_lo, iso_hi);
            let (value, safety) = f64_value_and_safety(a, b, d, c);
            if a == 0 && b == 0 && d == 0 {
                assert_eq!(exact, 0, "{label}: zero input gave nonzero sign");
                checked += 1;
                continue;
            }
            if value.abs() > safety {
                let f64_sign = if value > 0.0 { 1i8 } else { -1 };
                assert_eq!(
                    exact, f64_sign,
                    "{label}: f64 says {f64_sign} for a + b*c + d*c^2 = {value} \
                     (a={a}, b={b}, d={d}), exact helper says {exact}",
                );
                checked += 1;
            } else {
                skipped += 1;
            }
        }
        assert!(
            checked > iterations / 10,
            "{label}: too few cases verified ({checked}/{iterations}); safety margin too wide?",
        );
        eprintln!("{label}: {checked} cases verified, {skipped} skipped (close-to-zero)",);
    }
    const ZZ14_MINPOLY: [i64; 4] = [1, -2, -1, 1];
    const ZZ18_MINPOLY: [i64; 4] = [-1, -3, 0, 1];
    // Shared isolating interval (1, 2) for both fields.
    const ISO: ((i64, i64), (i64, i64)) = ((1, 1), (2, 1));
    #[test]
    fn fuzz_cubic_root_zz14() {
        fuzz_cubic_root(
            0xDEADBEEFCAFEBABE,
            5000,
            -1_000,
            1_000,
            ZZ14_MINPOLY,
            ISO.0,
            ISO.1,
            || 2.0 * (std::f64::consts::PI / 7.0).cos(),
            "ZZ14",
        );
    }
    #[test]
    fn fuzz_cubic_root_zz18() {
        fuzz_cubic_root(
            0xC0FFEEDEADBEEF77,
            5000,
            -1_000,
            1_000,
            ZZ18_MINPOLY,
            ISO.0,
            ISO.1,
            || 2.0 * (std::f64::consts::PI / 9.0).cos(),
            "ZZ18",
        );
    }
    // Checks `sign_at_s_times_x_minus_k` against f64 evaluation of s·X − K.
    #[allow(clippy::too_many_arguments)]
    fn fuzz_s_minus_k<F, G>(
        seed: u64,
        iterations: u64,
        coeff_lo: i64,
        coeff_hi: i64,
        minpoly: [i64; 4],
        iso_lo: (i64, i64),
        iso_hi: (i64, i64),
        c_f64: F,
        s_f64: G,
        label: &str,
    ) where
        F: Fn() -> f64,
        G: Fn() -> f64,
    {
        let c = c_f64();
        let s = s_f64();
        let mut rng = Xorshift64::new(seed);
        let mut checked = 0u64;
        let mut skipped = 0u64;
        for _ in 0..iterations {
            let a = rng.next_i64_in(coeff_lo, coeff_hi);
            let b = rng.next_i64_in(coeff_lo, coeff_hi);
            let d = rng.next_i64_in(coeff_lo, coeff_hi);
            let k = rng.next_i64_in(coeff_lo, coeff_hi);
            let exact = sign_at_s_times_x_minus_k([a, b, d], k, minpoly, iso_lo, iso_hi);
            let x_value = a as f64 + b as f64 * c + d as f64 * c * c;
            let value = s * x_value - k as f64;
            // Margin scaled by the operand magnitudes, as in `f64_value_and_safety`.
            let max_op = ((a as f64).abs() + (b as f64 * c).abs() + (d as f64 * c * c).abs()) * s
                + (k as f64).abs();
            let safety = 1e-12 * (1.0 + max_op);
            if value.abs() > safety {
                let f64_sign = if value > 0.0 { 1i8 } else { -1 };
                assert_eq!(
                    exact, f64_sign,
                    "{label}: f64 says {f64_sign} for s*X - K = {value} \
                     (a={a}, b={b}, d={d}, K={k}), exact says {exact}",
                );
                checked += 1;
            } else {
                skipped += 1;
            }
        }
        assert!(
            checked > iterations / 10,
            "{label}: too few cases verified ({checked}/{iterations})",
        );
        eprintln!("{label}: {checked} cases verified, {skipped} skipped");
    }
    #[test]
    fn fuzz_s_times_x_minus_k_zz14() {
        fuzz_s_minus_k(
            0xABCDEF0123456789,
            5000,
            -1_000,
            1_000,
            ZZ14_MINPOLY,
            ISO.0,
            ISO.1,
            || 2.0 * (std::f64::consts::PI / 7.0).cos(),
            || 2.0 * (std::f64::consts::PI / 7.0).sin(),
            "ZZ14",
        );
    }
    #[test]
    fn fuzz_s_times_x_minus_k_zz18() {
        fuzz_s_minus_k(
            0x9876543210ABCDEF,
            5000,
            -1_000,
            1_000,
            ZZ18_MINPOLY,
            ISO.0,
            ISO.1,
            || 2.0 * (std::f64::consts::PI / 9.0).cos(),
            || 2.0 * (std::f64::consts::PI / 9.0).sin(),
            "ZZ18",
        );
    }
    // Compares the public entry point (fast path / bisection, or Sturm for large
    // coefficients) with a direct Sturm query on integer-endpoint intervals.
    #[allow(clippy::too_many_arguments)]
    fn sturm_matches_bisection(
        seed: u64,
        iterations: u64,
        coeff_lo: i64,
        coeff_hi: i64,
        minpoly: [i64; 4],
        iso_lo: i64,
        iso_hi: i64,
        label: &str,
    ) {
        let mut rng = Xorshift64::new(seed);
        for _ in 0..iterations {
            let a = rng.next_i64_in(coeff_lo, coeff_hi);
            let b = rng.next_i64_in(coeff_lo, coeff_hi);
            let d = rng.next_i64_in(coeff_lo, coeff_hi);
            let bis = sign_at_cubic_root_in_interval([a, b, d], minpoly, (iso_lo, 1), (iso_hi, 1));
            let sturm = sturm_sign_at_root(&minpoly, &[a, b, d], iso_lo, iso_hi);
            assert_eq!(
                bis, sturm,
                "{label}: bisection={bis} Sturm={sturm} for f(x) = {a} + {b}*x + {d}*x^2"
            );
        }
    }
    #[test]
    fn sturm_matches_bisection_zz14() {
        sturm_matches_bisection(
            0x11223344AABBCCDD,
            5000,
            -1_000,
            1_000,
            ZZ14_MINPOLY,
            1,
            2,
            "ZZ14",
        );
    }
    #[test]
    fn sturm_matches_bisection_zz18() {
        sturm_matches_bisection(
            0x55667788EEFF0011,
            5000,
            -1_000,
            1_000,
            ZZ18_MINPOLY,
            1,
            2,
            "ZZ18",
        );
    }
    // Coefficients up to 5000: inputs with a |coefficient| >= 4096 take the direct Sturm
    // route inside the entry point, the rest still exercise the fast path / bisection.
    #[test]
    fn sturm_matches_bisection_large_coeffs() {
        sturm_matches_bisection(
            0xFEEDFACEDEADBEEF,
            1000,
            -5_000,
            5_000,
            ZZ14_MINPOLY,
            1,
            2,
            "ZZ14 large",
        );
        sturm_matches_bisection(
            0xBADCAFEBABE00011,
            1000,
            -5_000,
            5_000,
            ZZ18_MINPOLY,
            1,
            2,
            "ZZ18 large",
        );
    }
}