use crate::linalg::utils::splitmix64_hash;
/// Size of the `u64` hash space, i.e. 2^64, expressed as an `f64`.
///
/// Multiplying a fraction in `(0, 1)` by this value gives the hash threshold
/// below which a row counts as sampled. 2^64 is exactly representable in
/// `f64`, so no rounding is involved.
const HASH_SPACE: f64 = (1u128 << 64) as f64;
/// Row-sampling plan for a single step of a [`RhoCascadeSchedule`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct RhoStepPlan {
    /// Index of the step this plan belongs to.
    pub step: usize,
    /// Fraction of rows to sample; passed to [`row_in_fraction`] for
    /// non-full-pass plans.
    pub fraction: f64,
    /// Weight reported for each sampled row. `RhoCascadeSchedule::new` sets it
    /// to `1.0 / fraction`.
    pub importance_weight: f64,
    /// When `true`, every row is accepted with weight `1.0` and `fraction` is
    /// not consulted.
    pub is_full_pass: bool,
}

impl RhoStepPlan {
    /// Returns the weight to apply to `row_id` at this step, or `None` when
    /// the row is not part of this step's sample.
    ///
    /// Full-pass plans accept every row with weight `1.0`; otherwise
    /// membership is decided by [`row_in_fraction`] and sampled rows receive
    /// `importance_weight`.
    #[inline]
    pub fn includes(&self, row_id: u64) -> Option<f64> {
        let weight = if self.is_full_pass {
            1.0
        } else if row_in_fraction(row_id, self.fraction) {
            self.importance_weight
        } else {
            return None;
        };
        Some(weight)
    }
}
/// Decides whether `row_id` belongs to the subsample of size `fraction`.
///
/// * `fraction >= 1.0` accepts every row.
/// * `fraction <= 0.0` rejects every row.
/// * Otherwise the row is accepted when its `splitmix64_hash` is strictly
///   below `fraction * HASH_SPACE` truncated to `u64`.
///
/// Because the threshold grows with `fraction` and the hash depends only on
/// `row_id`, a row accepted at some fraction is also accepted at every larger
/// fraction (assuming `splitmix64_hash` is a pure function — confirm against
/// its definition).
///
/// A NaN `fraction` matches neither guard; the float-to-int cast turns the NaN
/// product into `0`, so no row is accepted.
#[inline]
pub fn row_in_fraction(row_id: u64, fraction: f64) -> bool {
    match fraction {
        f if f >= 1.0 => true,
        f if f <= 0.0 => false,
        f => {
            // `f` is strictly inside (0, 1) here (or NaN), so the product
            // stays below 2^64 and the cast cannot overflow.
            let cutoff = (f * HASH_SPACE) as u64;
            splitmix64_hash(row_id) < cutoff
        }
    }
}
/// Precomputed sequence of [`RhoStepPlan`]s for a run over `total_rows` rows.
///
/// Early steps sample a small fraction of the rows; the fraction grows
/// geometrically from a floor toward `1.0`, and the final
/// `FULL_PASS_TAIL_STEPS` steps always cover every row.
#[derive(Debug, Clone)]
pub struct RhoCascadeSchedule {
    /// One plan per step, stored in step order.
    steps: Vec<RhoStepPlan>,
    /// Row count the schedule was built for.
    total_rows: u64,
}

/// Lower bound on the sampling fraction of any step.
const MIN_FRACTION: f64 = 1.0 / 64.0;

/// Number of trailing steps that are forced to be full passes.
const FULL_PASS_TAIL_STEPS: usize = 2;

/// Row count used to derive the floor fraction: the floor is
/// `min(MIN_SUBSAMPLE_ROWS, total_rows) / total_rows`, raised to at least
/// `MIN_FRACTION`.
const MIN_SUBSAMPLE_ROWS: u64 = 4096;

impl RhoCascadeSchedule {
    /// Builds a schedule of `n_steps` steps (at least one) over `total_rows`
    /// rows.
    ///
    /// Steps before the full-pass tail use the fraction `floor^(1 - t)`, where
    /// `t = step / first_full_step` and `floor` is the floor fraction described
    /// on `MIN_SUBSAMPLE_ROWS`. A step whose fraction reaches `1.0` is marked
    /// as a full pass. Each plan's importance weight is `1.0 / fraction`.
    pub fn new(total_rows: u64, n_steps: usize) -> Self {
        let step_count = n_steps.max(1);

        let floor_fraction = match total_rows {
            // An empty corpus has nothing to subsample.
            0 => 1.0,
            rows => {
                let min_rows = MIN_SUBSAMPLE_ROWS.min(rows);
                (min_rows as f64 / rows as f64).max(MIN_FRACTION).min(1.0)
            }
        };

        let first_full_step = step_count.saturating_sub(FULL_PASS_TAIL_STEPS);
        // Both values are the same for every step, so compute them once.
        let ramp_len = first_full_step.max(1) as f64;
        let ln_floor = floor_fraction.ln();

        let steps = (0..step_count)
            .map(|step| {
                let fraction = if step >= first_full_step {
                    1.0
                } else {
                    // Log-linear interpolation between the floor (t = 0) and 1.0 (t = 1).
                    let t = step as f64 / ramp_len;
                    (ln_floor * (1.0 - t)).exp().clamp(floor_fraction, 1.0)
                };
                RhoStepPlan {
                    step,
                    fraction,
                    importance_weight: if fraction > 0.0 { 1.0 / fraction } else { 1.0 },
                    is_full_pass: fraction >= 1.0,
                }
            })
            .collect();

        Self { steps, total_rows }
    }

    /// All plans of the schedule, in step order.
    pub fn steps(&self) -> &[RhoStepPlan] {
        &self.steps
    }

    /// Plan for `step`; steps past the end of the schedule get a full-pass
    /// plan with weight `1.0`.
    pub fn step_plan(&self, step: usize) -> RhoStepPlan {
        self.steps.get(step).copied().unwrap_or(RhoStepPlan {
            step,
            fraction: 1.0,
            importance_weight: 1.0,
            is_full_pass: true,
        })
    }

    /// Row count this schedule was built for.
    pub fn total_rows(&self) -> u64 {
        self.total_rows
    }

    /// Expected number of rows visited at `step`: all rows for a full pass,
    /// otherwise `fraction * total_rows` rounded to the nearest integer.
    pub fn expected_rows(&self, step: usize) -> u64 {
        let RhoStepPlan {
            fraction,
            is_full_pass,
            ..
        } = self.step_plan(step);
        if is_full_pass {
            return self.total_rows;
        }
        (fraction * self.total_rows as f64).round() as u64
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fractions at or above 1.0 admit every row id.
    #[test]
    fn full_fraction_includes_every_row() {
        assert!((0..1000u64).all(|id| row_in_fraction(id, 1.0) && row_in_fraction(id, 2.0)));
    }

    /// Fractions at or below 0.0 admit no row id.
    #[test]
    fn zero_fraction_excludes_every_row() {
        assert!(!(0..1000u64).any(|id| row_in_fraction(id, 0.0) || row_in_fraction(id, -0.5)));
    }

    /// Asking twice about the same row and fraction gives the same answer.
    #[test]
    fn subsample_is_deterministic() {
        for id in 0..10_000u64 {
            assert_eq!(row_in_fraction(id, 0.25), row_in_fraction(id, 0.25));
        }
    }

    /// The share of accepted rows stays within one percentage point of the
    /// requested fraction.
    #[test]
    fn subsample_fraction_is_approximately_realized() {
        let n = 200_000u64;
        let frac = 0.1;
        let mut included = 0u64;
        for id in 0..n {
            if row_in_fraction(id, frac) {
                included += 1;
            }
        }
        let realized = included as f64 / n as f64;
        assert!(
            (realized - frac).abs() < 0.01,
            "realized fraction {realized} too far from {frac}"
        );
    }

    /// Weighting each sampled row by `1 / fraction` recovers the full-data
    /// total to within 2%.
    #[test]
    fn importance_weight_unbiases_subsample() {
        let n = 100_000u64;
        let frac = 0.2;
        let weight = 1.0 / frac;
        let full_sum = n as f64;
        let mut sub_sum = 0.0_f64;
        for id in 0..n {
            if row_in_fraction(id, frac) {
                sub_sum += weight;
            }
        }
        let rel_err = (sub_sum - full_sum).abs() / full_sum;
        assert!(
            rel_err < 0.02,
            "weighted subsample {sub_sum} vs full {full_sum}"
        );
    }

    /// The last two steps of a schedule are full passes with unit weight.
    #[test]
    fn schedule_ends_in_full_passes() {
        let sched = RhoCascadeSchedule::new(10_000_000, 8);
        let steps = sched.steps();
        let last = steps.last().expect("nonempty");
        assert!(last.is_full_pass);
        assert_eq!(last.importance_weight, 1.0);
        for plan in &steps[steps.len() - 2..] {
            assert!(plan.is_full_pass);
        }
    }

    /// Fractions never shrink from one step to the next, and the first step
    /// is a true subsample.
    #[test]
    fn schedule_fraction_is_monotone_nondecreasing() {
        let sched = RhoCascadeSchedule::new(10_000_000, 8);
        let fracs: Vec<f64> = sched.steps().iter().map(|s| s.fraction).collect();
        for pair in fracs.windows(2) {
            assert!(pair[1] >= pair[0] - 1e-12, "fractions not monotone: {fracs:?}");
        }
        let first = sched.steps()[0];
        assert!(first.fraction < 1.0);
        assert!(first.importance_weight > 1.0);
    }

    /// `RhoStepPlan::includes` agrees with `row_in_fraction` for the plan's
    /// own fraction.
    #[test]
    fn step_plan_includes_consistent_with_fraction() {
        let sched = RhoCascadeSchedule::new(10_000_000, 8);
        let plan = sched.step_plan(0);
        for id in 0..1000u64 {
            if let Some(w) = plan.includes(id) {
                assert!((w - plan.importance_weight).abs() < 1e-12);
                assert!(row_in_fraction(id, plan.fraction) || plan.is_full_pass);
            } else {
                assert!(!row_in_fraction(id, plan.fraction));
            }
        }
    }

    /// A corpus smaller than the minimum subsample size is never subsampled.
    #[test]
    fn tiny_corpus_is_all_full_passes() {
        let sched = RhoCascadeSchedule::new(100, 5);
        for plan in sched.steps() {
            assert!(plan.is_full_pass);
            assert_eq!(plan.fraction, 1.0);
        }
    }
}