/// Strategy deciding how many times a sample is trained on; see
/// [`SGBTVariant::train_count`] for the exact mapping of each variant.
///
/// Marked `#[non_exhaustive]`, so matches outside this crate need a wildcard
/// arm and new variants can be added without a breaking change.
#[derive(Debug, Clone, Copy, PartialEq, Default, serde::Serialize, serde::Deserialize)]
#[non_exhaustive]
pub enum SGBTVariant {
/// Train on every sample exactly once (`train_count` always returns `1`).
/// This is the `Default` variant.
#[default]
Standard,
/// Randomly skip samples: `train_count` returns `0` for roughly one call in
/// `k` (driven by the caller-supplied RNG state) and `1` otherwise.
Skip {
/// Inverse skip rate; `k = 1` skips every sample.
k: usize,
},
/// Repeat training in proportion to the hessian's magnitude:
/// `train_count` returns `max(1, ceil(|hessian| * multiplier))`.
MultipleIterations {
/// Scale factor applied to `|hessian|` before rounding up.
multiplier: f64,
},
}
impl std::fmt::Display for SGBTVariant {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Standard => write!(f, "Standard"),
Self::Skip { k } => write!(f, "Skip(k={})", k),
Self::MultipleIterations { multiplier } => {
write!(f, "MultipleIterations(multiplier={})", multiplier)
}
}
}
}
impl SGBTVariant {
    /// Returns how many times the current sample should be trained on.
    ///
    /// * [`SGBTVariant::Standard`]: always `1`.
    /// * [`SGBTVariant::Skip`]: advances `rng_state` by one xorshift64 step
    ///   (shifts 13 / 7 / 17) and returns `0` (skip) when the new state is a
    ///   multiple of `k`, otherwise `1`, so roughly one sample in `k` is
    ///   skipped. `k == 1` skips every sample. `k == 0` disables skipping: it
    ///   returns `1` and leaves `rng_state` untouched.
    /// * [`SGBTVariant::MultipleIterations`]: `ceil(|hessian| * multiplier)`,
    ///   but never less than `1`.
    ///
    /// # Arguments
    ///
    /// * `hessian` – only read by `MultipleIterations`.
    /// * `rng_state` – xorshift64 state, only read and advanced by `Skip`.
    ///   A state of `0` is xorshift's fixed point (it would stay `0` forever
    ///   and make `Skip` drop every sample), so it is first replaced with a
    ///   fixed non-zero seed.
    pub fn train_count(&self, hessian: f64, rng_state: &mut u64) -> usize {
        match *self {
            SGBTVariant::Standard => 1,
            // `x % 0` would panic, so a zero skip period means "never skip".
            SGBTVariant::Skip { k: 0 } => 1,
            SGBTVariant::Skip { k } => {
                // Any non-zero value escapes the all-zero fixed point; this is
                // the usual 64-bit golden-ratio constant.
                const NONZERO_SEED: u64 = 0x9E37_79B9_7F4A_7C15;
                if *rng_state == 0 {
                    *rng_state = NONZERO_SEED;
                }
                let mut x = *rng_state;
                x ^= x << 13;
                x ^= x >> 7;
                x ^= x << 17;
                *rng_state = x;
                // `usize -> u64` is lossless on targets with <= 64-bit usize.
                if x % (k as u64) == 0 {
                    0
                } else {
                    1
                }
            }
            SGBTVariant::MultipleIterations { multiplier } => {
                // Float-to-int `as` saturates: a NaN or non-positive product
                // becomes 0 (then clamped to 1), and +inf becomes usize::MAX.
                // NOTE(review): an infinite or huge hessian therefore requests
                // usize::MAX iterations; callers may want to cap this.
                let count = (hessian.abs() * multiplier).ceil() as usize;
                count.max(1)
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn standard_always_one() {
        let variant = SGBTVariant::Standard;
        let mut rng: u64 = 12345;
        for &h in &[0.0, 0.5, 1.0, -2.0, 100.0, f64::MIN_POSITIVE] {
            assert_eq!(variant.train_count(h, &mut rng), 1);
        }
    }

    #[test]
    fn skip_skips_roughly_one_in_k() {
        let variant = SGBTVariant::Skip { k: 10 };
        let mut rng: u64 = 0xDEAD_BEEF_CAFE_1234;
        let trials = 10_000;
        let mut skip_count = 0usize;
        for _ in 0..trials {
            if variant.train_count(0.0, &mut rng) == 0 {
                skip_count += 1;
            }
        }
        let skip_rate = skip_count as f64 / trials as f64;
        assert!(
            skip_rate > 0.05 && skip_rate < 0.15,
            "expected ~10% skip rate, got {:.2}% ({} / {})",
            skip_rate * 100.0,
            skip_count,
            trials
        );
    }

    #[test]
    fn skip_k_one_always_skips() {
        let variant = SGBTVariant::Skip { k: 1 };
        let mut rng: u64 = 42;
        for _ in 0..100 {
            assert_eq!(variant.train_count(0.0, &mut rng), 0);
        }
    }

    #[test]
    fn mi_hessian_half_multiplier_ten() {
        let variant = SGBTVariant::MultipleIterations { multiplier: 10.0 };
        let mut rng: u64 = 1;
        let count = variant.train_count(0.5, &mut rng);
        assert_eq!(count, 5, "ceil(0.5 * 10) = 5");
    }

    #[test]
    fn mi_negative_hessian_uses_abs() {
        let variant = SGBTVariant::MultipleIterations { multiplier: 10.0 };
        let mut rng: u64 = 1;
        let count = variant.train_count(-0.5, &mut rng);
        assert_eq!(count, 5, "ceil(|-0.5| * 10) = 5");
    }

    #[test]
    fn mi_always_at_least_one() {
        let variant = SGBTVariant::MultipleIterations { multiplier: 1.0 };
        let mut rng: u64 = 1;
        assert_eq!(variant.train_count(0.0, &mut rng), 1);
        assert_eq!(variant.train_count(1e-20, &mut rng), 1);
    }

    #[test]
    fn mi_large_hessian() {
        let variant = SGBTVariant::MultipleIterations { multiplier: 2.0 };
        let mut rng: u64 = 1;
        // ceil(3.7 * 2) = ceil(7.4) = 8
        assert_eq!(variant.train_count(3.7, &mut rng), 8);
    }

    /// A NaN product or a non-positive product saturates to 0 in the cast
    /// and is then clamped up to 1.
    #[test]
    fn mi_nan_or_non_positive_product_clamps_to_one() {
        let mut rng: u64 = 1;
        let negative = SGBTVariant::MultipleIterations { multiplier: -3.0 };
        assert_eq!(negative.train_count(2.0, &mut rng), 1);
        let unit = SGBTVariant::MultipleIterations { multiplier: 1.0 };
        assert_eq!(unit.train_count(f64::NAN, &mut rng), 1);
    }

    /// Only `Skip` consumes randomness; the other variants must leave the
    /// caller's RNG state untouched.
    #[test]
    fn only_skip_advances_rng() {
        let hessian = 0.5;
        let mut rng: u64 = 99;
        let _ = SGBTVariant::Standard.train_count(hessian, &mut rng);
        let _ = SGBTVariant::MultipleIterations { multiplier: 3.0 }.train_count(hessian, &mut rng);
        assert_eq!(rng, 99);
        let _ = SGBTVariant::Skip { k: 3 }.train_count(hessian, &mut rng);
        assert_ne!(rng, 99);
    }

    #[test]
    fn default_is_standard() {
        let variant = SGBTVariant::default();
        assert_eq!(variant, SGBTVariant::Standard);
    }

    #[test]
    fn partial_eq() {
        assert_eq!(SGBTVariant::Standard, SGBTVariant::Standard);
        assert_ne!(SGBTVariant::Standard, SGBTVariant::Skip { k: 10 });
        assert_eq!(SGBTVariant::Skip { k: 5 }, SGBTVariant::Skip { k: 5 });
        assert_ne!(SGBTVariant::Skip { k: 5 }, SGBTVariant::Skip { k: 10 });
        assert_eq!(
            SGBTVariant::MultipleIterations { multiplier: 2.0 },
            SGBTVariant::MultipleIterations { multiplier: 2.0 }
        );
    }

    /// The type is `Copy`, so a plain `let` binding duplicates it; the copy
    /// must compare equal to the original, which remains usable.
    #[test]
    fn clone_preserves_data() {
        let original = SGBTVariant::MultipleIterations { multiplier: 7.5 };
        let cloned = original;
        assert_eq!(original, cloned);
    }

    #[test]
    fn skip_is_non_degenerate() {
        let variant = SGBTVariant::Skip { k: 2 };
        let mut rng: u64 = 0xABCD_1234_5678_EF90;
        let mut saw_zero = false;
        let mut saw_one = false;
        for _ in 0..100 {
            match variant.train_count(0.0, &mut rng) {
                0 => saw_zero = true,
                1 => saw_one = true,
                _ => panic!("Skip should only return 0 or 1"),
            }
            if saw_zero && saw_one {
                break;
            }
        }
        assert!(saw_zero && saw_one, "k=2 should produce both 0 and 1");
    }

    #[test]
    fn display_formats_each_variant() {
        assert_eq!(SGBTVariant::Standard.to_string(), "Standard");
        assert_eq!(SGBTVariant::Skip { k: 10 }.to_string(), "Skip(k=10)");
        assert_eq!(
            SGBTVariant::MultipleIterations { multiplier: 7.5 }.to_string(),
            "MultipleIterations(multiplier=7.5)"
        );
    }
}