use crate::T;
/// Exponent `φ(k)` attached to update value `k` for precision `p`.
///
/// `φ(0) = T`; for `k ≥ 1`, `φ(k) = T + 1 + ⌊(k − 1) / 2^T⌋`. Both cases are
/// capped at `64 − p`. Consecutive runs of `2^T` update values share one
/// exponent.
///
/// NOTE(review): presumably `2^-φ(k)` is the relative frequency of update value
/// `k` produced by `hash_to_register_k` (where `k = (nlz << T) + low_T + 1`).
/// Confirm against the sketch's design notes before relying on that reading.
pub(crate) fn phi(k: u32, p: u32) -> u32 {
    let uncapped = match k {
        0 => T,
        // `(k - 1) >> T` is the index of the 2^T-wide group containing k.
        _ => T + 1 + ((k - 1) >> T),
    };
    uncapped.min(64 - p)
}
/// Closed form of `Σ_{k > u} 2^-φ(k, p)`, summed over update values from
/// `u + 1` up to the maximum `(65 − p − T)·2^T`.
///
/// The `2^T·(g + 2) − u` values left in `u`'s own group and all later groups
/// each weigh `2^-φ(u)` once the geometric tail is folded in. The cap in `phi`
/// makes the last two groups share the same exponent, so the formula also holds
/// at the top of the range. This gives `omega(0, p) = 1`, and `omega` is `0` at
/// the maximum update value.
pub(crate) fn omega(u: u32, p: u32) -> f64 {
    let exponent = phi(u, p);
    let group_width = f64::from(1u32 << T);
    let groups_ahead = 1.0 + f64::from(exponent) - f64::from(T);
    let weighted_count = group_width * groups_ahead - f64::from(u);
    weighted_count * pow2_neg(exponent)
}
/// Returns `2^-x` exactly, for every `x: u32`.
///
/// For `x ≤ 1022` the result is a normal `f64`, built directly from its
/// IEEE-754 bit pattern (biased exponent `1023 − x`, zero mantissa). This is
/// the path the estimator loops hit, since their exponents never exceed 64.
/// Larger exponents are handled explicitly rather than through the same bit
/// trick, because `1023 − x` would hit the zero/subnormal exponent field or
/// wrap around:
/// - `1023..=1074` yields the subnormal `2^-x`;
/// - anything beyond yields `0.0`, the correctly rounded value.
#[inline]
pub(crate) fn pow2_neg(x: u32) -> f64 {
    match x {
        // Normal range: biased exponent field 1023 - x lies in [1, 1023].
        0..=1022 => f64::from_bits((1023 - u64::from(x)) << 52),
        // Subnormal range: 2^-x = 2^-1074 * 2^(1074 - x), a single mantissa bit.
        1023..=1074 => f64::from_bits(1u64 << (1074 - x)),
        // 2^-1075 and smaller round (to even) to zero.
        _ => 0.0,
    }
}
/// Contribution of a single register `r` to the α coefficient, divided by `m = 2^p`.
///
/// The register holds its maximum update value `u` in the bits above `d`. The
/// low `d` bits form a window: bit `d − j` records whether value `u − j` was
/// seen. The contribution is `ω(u)` plus `2^-φ(k)` for every `k` in
/// `max(u − d, 1)..u` whose window bit is clear. This is the same per-register
/// quantity that `compute_alpha_beta` accumulates (without the `1/m` factor).
pub(crate) fn h(r: u32, p: u32, d: u32) -> f64 {
    let window_mask = (1u32 << d) - 1;
    let max_value = r >> d;
    let window = r & window_mask;
    let m = (1u64 << p) as f64;
    // Values below this are outside the window (or are the unused value 0).
    let first = max_value.saturating_sub(d).max(1);
    let total = (first..max_value).fold(omega(max_value, p), |acc, k| {
        let seen = (window >> (d - (max_value - k))) & 1;
        acc + (1.0 - seen as f64) * pow2_neg(phi(k, p))
    });
    total / m
}
/// Accumulates the coefficients `(α, β)` over all registers for precision `p`
/// and window width `d`.
///
/// For each register with maximum value `u` and window bits `bitmap`:
/// - `α` gains `ω(u)`, plus `2^-φ(k)` for every in-window `k < u` whose bit is clear;
/// - `β[φ − T − 1]` is incremented for `u` itself (when `u ≥ 1`) and for every
///   in-window `k < u` whose bit is set.
///
/// `β` has `64 − p − T` slots, covering exponents `T + 1 ..= 64 − p`. That is
/// the same indexing `g` uses (`exponent = index + T + 1`).
pub(crate) fn compute_alpha_beta<I>(registers: I, p: u32, d: u32) -> (f64, Vec<u32>)
where
    I: Iterator<Item = u32>,
{
    let window_mask = (1u32 << d) - 1;
    let mut beta = vec![0u32; (64 - p - T) as usize];
    let mut alpha = 0.0_f64;
    // Position of exponent `e` (in T+1 ..= 64-p) inside `beta`.
    let slot = |e: u32| (e - T - 1) as usize;

    for register in registers {
        let max_value = register >> d;
        let window = register & window_mask;
        alpha += omega(max_value, p);
        if max_value == 0 {
            // Empty register: only the ω(0) term applies.
            continue;
        }
        beta[slot(phi(max_value, p))] += 1;
        // For max_value == 1 this range is empty.
        let first = max_value.saturating_sub(d).max(1);
        for k in first..max_value {
            let exponent = phi(k, p);
            let seen = (window >> (d - (max_value - k))) & 1 == 1;
            if seen {
                beta[slot(exponent)] += 1;
            } else {
                alpha += pow2_neg(exponent);
            }
        }
    }
    (alpha, beta)
}
/// Evaluates `Σ β[i] · s / (exp(y·s) − 1)`, where `s = 2^-(i + T + 1)`.
///
/// Empty slots are skipped, and so are terms whose denominator is non-finite
/// or exactly zero. Each term decreases as `y` grows, so the sum is
/// monotonically decreasing in `y > 0`; `solve_ml` relies on this for its
/// bisection.
fn g(y: f64, beta: &[u32]) -> f64 {
    beta.iter()
        .zip((T + 1)..)
        .filter(|&(&count, _)| count != 0)
        .filter_map(|(&count, exponent)| {
            let scale = pow2_neg(exponent);
            // exp_m1 keeps precision when y * scale is tiny.
            let denom = (y * scale).exp_m1();
            let usable = denom.is_finite() && denom != 0.0;
            usable.then(|| f64::from(count) * scale / denom)
        })
        .fold(0.0, |total, term| total + term)
}
/// Solves `g(y) = α` for `y` and returns `m · y`, where `m = 2^p`.
///
/// The name suggests this is the maximum-likelihood estimate built from the
/// coefficients of `compute_alpha_beta`; this function only shows the
/// root-finding. The root is found by bisection on `log2(y)` over
/// `[-200, 200]`. `g` is decreasing, so `g > α` means the root lies higher.
/// The bisection stops after 200 steps or once the bracket is narrower than
/// `1e-13`.
///
/// Special cases:
/// - every `β` slot is zero → `0.0`;
/// - `α ≤ 0` → `f64::INFINITY` (a non-empty `g` is positive, so there is no finite root).
pub(crate) fn solve_ml(alpha: f64, beta: &[u32], p: u32) -> f64 {
    if !beta.iter().any(|&count| count != 0) {
        return 0.0;
    }
    if alpha <= 0.0 {
        return f64::INFINITY;
    }

    let to_linear = |log2_y: f64| (log2_y * std::f64::consts::LN_2).exp();
    let (mut lo, mut hi) = (-200.0_f64, 200.0_f64);
    for _ in 0..200 {
        let mid = 0.5 * (lo + hi);
        let root_is_above = g(to_linear(mid), beta) > alpha;
        if root_is_above {
            lo = mid;
        } else {
            hi = mid;
        }
        if hi - lo < 1e-13 {
            break;
        }
    }

    let root = to_linear(0.5 * (lo + hi));
    let m = (1u64 << p) as f64;
    m * root
}
/// Inserts update value `k` into register `r`, whose window is `d` bits wide.
///
/// Layout: `r >> d` is the largest value seen so far (`u`). Bit `d − j` of the
/// low `d` bits records that value `u − j` was seen, for `1 ≤ j ≤ d`.
/// - `k > u`: `k` becomes the new maximum. The old window, with `u`'s own
///   indicator bit on top, slides down by `k − u`. Bits that fall off the
///   bottom are dropped.
/// - `u − d ≤ k < u`: the bit for `k` is set.
/// - otherwise (`k == u`, or `k` is too far below `u`): `r` is returned unchanged.
#[inline]
pub(crate) fn apply_insert(r: u32, k: u32, d: u32) -> u32 {
    let window_mask = (1u32 << d) - 1;
    let current = r >> d;
    match k.cmp(&current) {
        std::cmp::Ordering::Greater => {
            let shift = u64::from(k - current);
            // Bit d carries the old maximum so it lands at its new window slot.
            let with_indicator = (1u64 << d) | u64::from(r & window_mask);
            // Shifting by more than d + 1 would clear everything anyway; the
            // guard also keeps the shift amount below 64.
            let slid = if shift > u64::from(d) + 1 {
                0
            } else {
                with_indicator >> shift
            };
            (k << d) | (slid as u32 & window_mask)
        }
        std::cmp::Ordering::Less => {
            let gap = current - k;
            if gap > d {
                r
            } else {
                r | (1u32 << (d - gap))
            }
        }
        std::cmp::Ordering::Equal => r,
    }
}
/// Merges two registers that use the same `d`-bit window layout as `apply_insert`.
///
/// If the maxima are equal, or either register is empty (maximum `0`), the two
/// are simply OR-ed. Otherwise the register with the larger maximum is kept.
/// The other register's window, topped by its own maximum's indicator bit, is
/// slid down by the difference of the maxima and OR-ed into the kept window.
/// Differences larger than `d + 1` contribute nothing.
pub(crate) fn merge_register(r: u32, r2: u32, d: u32) -> u32 {
    let window_mask = (1u32 << d) - 1;
    let (top, top2) = (r >> d, r2 >> d);
    if top == top2 || top == 0 || top2 == 0 {
        return r | r2;
    }

    let (larger, smaller, gap) = if top > top2 {
        (r, r2, top - top2)
    } else {
        (r2, r, top2 - top)
    };
    let with_indicator = (1u64 << d) | u64::from(smaller & window_mask);
    let carried = if gap > d + 1 {
        0
    } else {
        with_indicator >> gap
    };
    larger | (carried as u32 & window_mask)
}
/// Splits a 64-bit hash into `(register index, update value k)` for precision `p`.
///
/// - Bits `T .. T + p` select the register.
/// - The low `T` bits give `k`'s offset within its group.
/// - The number of leading zeros gives the group. It is computed after forcing
///   the low `p + T` bits to one, so it never exceeds `64 − p − T`.
///
/// The result is `k = (nlz << T) + low_T + 1`, which lies in `1 ..= (65 − p − T) << T`.
#[inline]
pub(crate) fn hash_to_register_k(hash: u64, p: u32) -> (usize, u32) {
    let reserved_bits = p + T;
    let index = ((hash >> T) & ((1u64 << p) - 1)) as usize;
    let leading = u64::from((hash | ((1u64 << reserved_bits) - 1)).leading_zeros());
    let offset = hash & ((1u64 << T) - 1);
    let k = ((leading << T) + offset + 1) as u32;
    debug_assert!(k >= 1);
    debug_assert!(k as u64 <= ((65 - p as u64 - T as u64) << T));
    (index, k)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{MAX_P, MIN_P};

    #[test]
    fn phi_matches_definition() {
        let p = 8;
        assert_eq!(phi(0, p), T);
        assert_eq!(phi(1, p), T + 1);
        for k in 1u32..1000 {
            let expected = (T + 1 + (k - 1) / (1 << T)).min(64 - p);
            assert_eq!(phi(k, p), expected);
        }
    }

    #[test]
    fn omega_zero_is_one() {
        for p in MIN_P..=18 {
            let w = omega(0, p);
            assert!((w - 1.0).abs() < 1e-12, "ω(0) for p={p} = {w}");
        }
    }

    #[test]
    fn pow2_neg_matches_powi() {
        for x in 0u32..64 {
            let fast = pow2_neg(x);
            let reference = 2.0_f64.powi(-(x as i32));
            assert!((fast / reference - 1.0).abs() < 1e-15, "x={x}: {fast} vs {reference}");
        }
    }

    #[test]
    fn h_of_zero_is_one_over_m() {
        for p in MIN_P..=18 {
            let m = (1u64 << p) as f64;
            for d in [20u32, 24] {
                let value = h(0, p, d);
                assert!((value - 1.0 / m).abs() < 1e-15);
            }
        }
    }

    /// Every valid precision maps hashes to an in-range register index and an
    /// update value whose `β` slot exists.
    #[test]
    fn hash_to_register_k_stays_in_range() {
        let hashes = [
            0u64,
            1,
            u64::MAX,
            1 << 63,
            0x9E37_79B9_7F4A_7C15,
            0x0123_4567_89AB_CDEF,
        ];
        for p in MIN_P..=MAX_P {
            let k_max = (65 - p - T) << T;
            let beta_len = 64 - p - T;
            for &hash in &hashes {
                let (i, k) = hash_to_register_k(hash, p);
                assert!(i < 1usize << p, "hash={hash:#x} p={p} i={i}");
                assert!((1..=k_max).contains(&k), "hash={hash:#x} p={p} k={k}");
                assert!(phi(k, p) - T - 1 < beta_len, "hash={hash:#x} p={p} k={k}");
            }
        }
    }

    #[test]
    fn insert_and_merge_examples() {
        let d = 4;
        // Register layout: value u in bits >= d; bit d - j marks value u - j.
        assert_eq!(apply_insert(0, 3, d), (3 << 4) | 0b0010);
        assert_eq!(apply_insert((3 << 4) | 0b0010, 5, d), (5 << 4) | 0b0100);
        assert_eq!(apply_insert((5 << 4) | 0b0100, 4, d), (5 << 4) | 0b1100);
        assert_eq!(apply_insert((5 << 4) | 0b0100, 20, d), 20 << 4);
        assert_eq!(merge_register(0, (5 << 4) | 0b0100, d), (5 << 4) | 0b0100);
        assert_eq!(
            merge_register((4 << 4) | 0b0001, (5 << 4) | 0b0100, d),
            (5 << 4) | 0b1100
        );
    }

    /// The register state depends only on the set of inserted values, so merging
    /// two partial registers must equal inserting everything into one.
    #[test]
    fn merge_agrees_with_sequential_inserts() {
        let d = 20;
        let mut state = 0x2545_F491_4F6C_DD1Du64;
        for _ in 0..200 {
            let (mut left, mut right, mut all) = (0u32, 0u32, 0u32);
            for _ in 0..8 {
                // xorshift64: deterministic, dependency-free pseudo-randomness.
                state ^= state << 13;
                state ^= state >> 7;
                state ^= state << 17;
                let k = (state % 48) as u32 + 1;
                if state & (1 << 32) == 0 {
                    left = apply_insert(left, k, d);
                } else {
                    right = apply_insert(right, k, d);
                }
                all = apply_insert(all, k, d);
            }
            assert_eq!(merge_register(left, right, d), all);
            assert_eq!(merge_register(right, left, d), all);
        }
    }
}