use crate::hazmat::float::ln;
use core::f64;
use crypto_bigint::{Concat, NonZero, Uint, cpubits};
/// Estimates the prime-counting function `π(x)`.
///
/// The estimate is the asymptotic expansion of the logarithmic integral, truncated after four
/// terms:
///
/// ```text
/// π(x) ≈ x/ln(x) + x/ln²(x) + 2x/ln³(x) + 6x/ln⁴(x)
/// ```
///
/// The powers of `ln(x)` are computed in `f64` and turned into fixed-point integer denominators
/// with `scale_bits` fractional bits. The divisions are integer divisions on a double-width
/// [`Uint`], so the full precision of `x` is kept. The fractional part of the final sum is
/// truncated.
///
/// Returns zero when `ln(x)` is not finite or is at most `1.0`. For an accurate `ln` that is
/// `x <= 2`, where the expansion is meaningless.
///
/// # Panics
///
/// Panics if the de-scaled sum does not fit back into `Uint<LIMBS>`. This is not expected: the
/// estimate only exceeds `x` when `ln(x)` is small, i.e. for small `x`.
pub fn estimate_primecount<const LIMBS: usize, const RHS_LIMBS: usize>(x: &Uint<LIMBS>) -> Uint<LIMBS>
where
    Uint<LIMBS>: Concat<LIMBS, Output = Uint<RHS_LIMBS>>,
{
    // Fractional bits of the fixed-point denominators: half the bit width of `Uint<LIMBS>`,
    // capped at 64 (the scaled denominators are built from a `u128`).
    let scale_bits = {
        cpubits! {
            32 => { (LIMBS as u32 * 16).min(64) }
            64 => { (LIMBS as u32 * 32).min(64) }
        }
    };
    // The numerator carries twice as many fractional bits as the denominators. Each quotient
    // below is therefore `x / ln^k(x)` with exactly `scale_bits` fractional bits.
    let total_scale_bits = 2 * scale_bits;
    let denom_scale_factor = (1u128 << scale_bits) as f64;

    let ln_x = ln(x);
    if !ln_x.is_finite() || ln_x <= 1.0 {
        return Uint::ZERO;
    }
    let ln_x_2 = ln_x * ln_x;
    let ln_x_3 = ln_x_2 * ln_x;
    let ln_x_4 = ln_x_3 * ln_x;

    // Converts `value` to a fixed-point integer with `scale_bits` fractional bits. The result is
    // clamped to at least 1 so it is a valid divisor. Note that the `as u128` cast saturates,
    // so extremely large `ln(x)` values (far beyond any practical `Uint` size) would be clipped.
    let f64_to_scaled_uint = |value: f64| -> NonZero<Uint<RHS_LIMBS>> {
        let scaled = value * denom_scale_factor;
        let scaled = libm::round(scaled).max(1.0) as u128;
        let denom = Uint::<RHS_LIMBS>::from_u128(scaled);
        NonZero::new(denom).expect("max(1.0) ensures value is at least 1")
    };
    let d1 = f64_to_scaled_uint(ln_x);
    let d2 = f64_to_scaled_uint(ln_x_2);
    let d3 = f64_to_scaled_uint(ln_x_3);
    let d4 = f64_to_scaled_uint(ln_x_4);

    let x_wide: Uint<RHS_LIMBS> = x.concat(&Uint::ZERO);
    // `total_scale_bits` never exceeds the bit width of `Uint<LIMBS>`, so no bits of `x` are
    // lost here. For small types (e.g. `U64`), however, the scaled value can occupy the entire
    // wide integer, leaving no headroom for further multiplication.
    let x_wide_scaled = x_wide.wrapping_shl_vartime(total_scale_bits);

    // Divide *before* applying the constant factors 2 and 6. Scaling the numerator first would
    // wrap (`<< 1`) or saturate (`* 6`) whenever `x_wide_scaled` has no spare high bits.
    // Since `ln(x) > 1`, every denominator is at least `2^scale_bits`, so each quotient is at
    // most `x * 2^scale_bits`. That leaves ample room for the factors and the sum below.
    let term1_scaled = x_wide_scaled.wrapping_div(&d1);
    let term2_scaled = x_wide_scaled.wrapping_div(&d2);
    let term3_scaled = x_wide_scaled.wrapping_div(&d3).wrapping_shl_vartime(1);
    let six = Uint::<LIMBS>::from(6u64);
    // Cannot saturate after the division above; kept saturating as a defensive bound.
    let term4_scaled = x_wide_scaled.wrapping_div(&d4).saturating_mul(&six);

    let sum_scaled = term1_scaled
        .wrapping_add(&term2_scaled)
        .wrapping_add(&term3_scaled)
        .wrapping_add(&term4_scaled);
    // Drop the fractional bits (truncating) and narrow back to `LIMBS`.
    let li_x = sum_scaled >> scale_bits;
    li_x.resize_checked()
        .expect("de-scaling should leave the high half zero")
}

#[cfg(test)]
mod overflow_tests {
    use super::*;
    use crypto_bigint::U64;

    /// Regression test for small `Uint`s. For `U64` the scaled numerator fills the whole
    /// double-width integer, so scaling it by 2 or 6 before dividing used to wrap/saturate for
    /// `x >= 2^63`. That shrank the `2x/ln³x` and `6x/ln⁴x` terms, for an error of about 2^48.
    /// The fixed estimate is within about 2^42 of π(2^64) = 425656284035217743.
    #[test]
    fn pi_x_u64_max_does_not_overflow() {
        let estimate = estimate_primecount(&U64::MAX);
        let reference = U64::from_u64(425_656_284_035_217_743);
        let delta = if reference > estimate {
            reference - estimate
        } else {
            estimate - reference
        };
        assert!(
            delta < U64::from_u64(1u64 << 45),
            "estimate {estimate} too far from pi(2^64) = {reference}"
        );
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crypto_bigint::{U64, U128, U256, U1024};

    #[test]
    fn pi_x_2_500() {
        let x = Uint::ONE << 500;
        let sage_est = U1024::from_be_hex(concat![
            "0000000000000000000000000000000000000000000000000000000000000000",
            "0000000000000000000000000000000000000000000000000000000000000000",
            "00000bda54dd744907290defac4f74bec507fdafd96e123c49bea56826f73702",
            "a469b67453a13a6abc40e81b760a0a5fd95870dbb8bbe99973c246c49561e101"
        ]);
        let estimate = estimate_primecount(&x);
        assert_bit_difference(estimate, sage_est, 29);
    }

    #[test]
    fn pi_x_max() {
        let x = U1024::MAX;
        let estimate = estimate_primecount(&x);
        let sage_est = Uint::from_be_hex(concat![
            "005c7682fe13533e630c22e716b35439b3dc61f1d4898d78a36dd9c9afc0745a",
            "06d3a0deb93b77423f6d11c107283fcfdb8ae17de22b5197972f37cb480a2737",
            "fe8d0f15202bb43bc1863b05f6d3849f865b95242eaec9789dcf3b40e92504d9",
            "8258f80b394ebec1c63d1186f9552689076f709c2fd8497b5f78d82cea2c2137"
        ]);
        assert_bit_difference(estimate, sage_est, 33);
    }

    #[test]
    fn pi_x_random() {
        let x = Uint::from_be_hex(concat![
            "62A211E0907141403FD3EB60A82EAB701524710BDB024EB68DFF309389258B63",
            "2EB9975D29F028F5137AC9DE870EB622D2D45A0D3A9C5801E8A3109BED220F82",
            "890E108F1778E5523E3E89CCD5DEDB667E6C17E940E9D4C3F58575C86CB76403",
            "017AD59D33AC084D2E58D81F8BB87A61B44677037A7DBDE04814256570DCBD7A"
        ]);
        let sage_est = U1024::from_be_hex(concat![
            "0023ac3184a0c4c8e9025e0ae9b44d7980cee1baacf69032bb898677841fac0e",
            "516fa6bc8c1d1d3bb282622aa62c49f2d8e622d2f9aa80af3140c8c225136301",
            "7c99621943c90ab55a6dd69a678110233254a1a3c50ceb1cdb516e7220a7514a",
            "17b20114c7bef6f316e94cf7c9181187d70e751bda2e18695fa71e8015b8cf1c"
        ]);
        let estimate = estimate_primecount(&x);
        assert_bit_difference(estimate, sage_est, 34);
    }

    #[test]
    fn pi_x_u128() {
        let x = U128::from_u128(1000000000000000000000000);
        let sage_est = U128::from_be_hex("00000000000003e76557786d0933dca8");
        let estimate = estimate_primecount(&x);
        assert_bit_difference(estimate, sage_est, 18);
    }

    #[test]
    fn pi_x_u64() {
        let x = U64::from_u64(10000000);
        let sage_est = U64::from_be_hex("00000000000A2556");
        let estimate = estimate_primecount(&x);
        assert_bit_difference(estimate, sage_est, 9);
    }

    /// Asserts that `candidate` agrees with `reference` to at least `min_bit_diff` leading bits.
    /// That is, `|candidate - reference|` must be at least `min_bit_diff` bits shorter than
    /// `reference`.
    fn assert_bit_difference<const LIMBS: usize>(
        candidate: Uint<LIMBS>,
        reference: Uint<LIMBS>,
        min_bit_diff: u32,
    ) {
        let delta = if reference > candidate {
            reference - candidate
        } else {
            candidate - reference
        };
        let delta_bits = delta.bits_vartime();
        let reference_bits = reference.bits_vartime();
        // `saturating_sub`: when the error is at least as long as the reference, a plain `u32`
        // subtraction would underflow. That panics in debug builds and wraps in release builds,
        // where the assertion would then pass spuriously.
        assert!(
            reference_bits.saturating_sub(delta_bits) >= min_bit_diff,
            "Estimate not close enough: delta has {} bits, reference has {} bits. Difference should be >= {}\nEstimate: {candidate},\nReference: {reference}",
            delta_bits,
            reference_bits,
            min_bit_diff
        );
    }

    #[test]
    fn pi_x_estimates_for_known_values() {
        // Pairs of (π(10^exponent), exponent).
        let pi_xs: &[(u128, u32)] = &[
            (1229, 4),
            (9592, 5),
            (78498, 6),
            (664579, 7),
            (5761455, 8),
            (50847534, 9),
            (455052511, 10),
            (4118054813, 11),
            (37607912018, 12),
            (346065536839, 13),
            (3204941750802, 14),
            (29844570422669, 15),
            (279238341033925, 16),
            (2623557157654233, 17),
            (24739954287740860, 18),
            (234057667276344607, 19),
            (2220819602560918840, 20),
            (21127269486018731928, 21),
            (201467286689315906290, 22),
            (1925320391606803968923, 23),
            (18435599767349200867866, 24),
            (176846309399143769411680, 25),
            (1699246750872437141327603, 26),
            (16352460426841680446427399, 27),
            (157589269275973410412739598, 28),
            (1520698109714272166094258063, 29),
        ];
        for &(pi_x, exponent) in pi_xs {
            let pi_x_wide = U256::from_u128(pi_x);
            let n = U256::from_u128(10u128.pow(exponent));
            let estimate = estimate_primecount(&n);
            let delta = if pi_x_wide > estimate {
                pi_x_wide - estimate
            } else {
                estimate - pi_x_wide
            };
            let delta = uint_to_u128(&delta);
            let estimate_128 = uint_to_u128(&estimate);
            // Relative error in percent.
            let error = (delta as f64 / pi_x as f64) * 100.0;
            assert!(
                error < 2.2,
                "10^{exponent}:\t{pi_x} - {estimate_128} = {delta}, err: {error:.2}"
            );
        }
    }

    /// Converts `x` to a `u128`.
    ///
    /// Panics if `x` has any bit set above bit 127, rather than silently truncating it.
    fn uint_to_u128<const LIMBS: usize>(x: &Uint<LIMBS>) -> u128 {
        assert!(x.bits_vartime() <= 128, "value does not fit in a u128");
        let limbs = x.as_limbs();
        cpubits! {
            32 => {
                (limbs[3].0 as u128) << 96
                | (limbs[2].0 as u128) << 64
                | (limbs[1].0 as u128) << 32
                | limbs[0].0 as u128
            }
            64 => {
                ((limbs[1].0 as u128) << 64) | limbs[0].0 as u128
            }
        }
    }
}