use std::sync::Arc;
use crate::outer_subsample::OuterScoreSubsample;
/// Selects which rows participate in a computation, tagged with an id.
///
/// A `mask` of `None` stands for "all rows" (see [`RowSubsampleMask::full_data`]).
/// A `Some` value carries a shared [`OuterScoreSubsample`] (see
/// [`RowSubsampleMask::subsample`]).
#[derive(Debug, Clone)]
pub struct RowSubsampleMask {
    /// Non-zero identifier produced by the constructors' hash helpers.
    pub id: u64,
    /// The attached subsample, or `None` for the full-data case.
    pub mask: Option<Arc<OuterScoreSubsample>>,
}
impl RowSubsampleMask {
    /// Builds the "no subsample" mask for a dataset of `n` rows.
    ///
    /// The id depends only on `n`, so equal `n` always yields equal ids.
    pub fn full_data(n: usize) -> Self {
        let id = hash_full(n);
        Self { mask: None, id }
    }

    /// Wraps a shared subsample.
    ///
    /// The id is derived from this particular `Arc`, including its allocation
    /// address. Two separately allocated subsamples with identical contents
    /// therefore get different ids.
    pub fn subsample(mask: Arc<OuterScoreSubsample>) -> Self {
        Self {
            id: hash_subsample(&mask),
            mask: Some(mask),
        }
    }

    /// Returns `(indices, weights)` for a dataset of `n` rows.
    ///
    /// - Without a subsample: `indices` is `0..n` and every weight is `1.0`.
    /// - With a subsample: `indices` is a copy of the subsample's `mask`.
    ///   `weights` has length `n`, starts at `1.0` everywhere, and is then
    ///   overwritten at each entry of `rows` (`weights[row.index] = row.weight`).
    ///   Entries whose `index` is `>= n` are ignored.
    ///
    /// # Panics
    ///
    /// Panics if a subsample is attached and its `n_full` differs from `n`.
    pub fn indices_and_weights(&self, n: usize) -> (Vec<usize>, Vec<f64>) {
        let sub = match self.mask.as_deref() {
            None => return ((0..n).collect(), vec![1.0_f64; n]),
            Some(sub) => sub,
        };

        assert_eq!(
            sub.n_full, n,
            "RowSubsampleMask n_full ({}) must match caller n ({})",
            sub.n_full, n
        );

        let indices: Vec<usize> = sub.mask.as_ref().clone();

        let mut weights = vec![1.0_f64; n];
        // Later entries win if the same index appears more than once.
        for row in sub.rows.iter().filter(|row| row.index < n) {
            weights[row.index] = row.weight;
        }

        (indices, weights)
    }
}
/// Thin local alias for `gam_linalg::utils::splitmix64_hash`.
///
/// This is the mixing function used by the row-subsample id helpers.
fn splitmix64(value: u64) -> u64 {
    use gam_linalg::utils::splitmix64_hash;
    splitmix64_hash(value)
}
/// Domain-separation constant XOR-ed with `n` before hashing full-data ids.
///
/// It keeps these ids from coinciding with a plain `splitmix64(n)`.
const FULL_DATA_ROW_SUBSAMPLE_SENTINEL: u64 = 0xA5A5_5A5A_DEAD_BEEF;
/// Computes the id of the full-data mask for `n` rows.
///
/// The result is never 0; a zero hash is remapped to a fixed fallback. The
/// reason 0 is avoided is not visible here; presumably callers treat 0 as
/// "no id" (confirm).
fn hash_full(n: usize) -> u64 {
    let seed = FULL_DATA_ROW_SUBSAMPLE_SENTINEL ^ (n as u64);
    match splitmix64(seed) {
        0 => 0x1234_5678_9ABC_DEF0,
        mixed => mixed,
    }
}
/// Computes the id of a subsample mask.
///
/// Identity is anchored on the `Arc` allocation address. Two distinct `Arc`s
/// with equal contents therefore get different ids, while clones of one `Arc`
/// share an id. `n_full`, `len()`, `seed` and `weight_scale` are folded in as
/// well. Presumably this guards against an address being reused by a new
/// allocation after the original `Arc` is dropped (confirm with callers that
/// cache by id).
///
/// Components are mixed *sequentially* (`h = splitmix64(h ^ part)`). They are
/// not hashed independently and XOR-ed together, because that is commutative
/// and self-cancelling. For example, `n_full == len` made those two terms
/// vanish, and swapped field values collided.
///
/// The result is never 0; a zero hash is remapped to a fixed fallback.
fn hash_subsample(mask: &Arc<OuterScoreSubsample>) -> u64 {
    let parts: [u64; 5] = [
        Arc::as_ptr(mask) as u64,
        mask.n_full as u64,
        mask.len() as u64,
        mask.seed,
        // Tag the float bits so they do not collide with an equal-valued seed.
        mask.weight_scale.to_bits() ^ 0xC0FF_EE00_0000_0000,
    ];
    let h = parts
        .iter()
        .fold(0_u64, |acc, &part| splitmix64(acc ^ part));
    if h == 0 {
        0xDEAD_BEEF_FEED_FACE
    } else {
        h
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::outer_subsample::OuterScoreSubsample;

    #[test]
    fn full_data_id_is_stable_per_n() {
        let a = RowSubsampleMask::full_data(100);
        let b = RowSubsampleMask::full_data(100);
        let c = RowSubsampleMask::full_data(101);
        assert_eq!(a.id, b.id);
        assert_ne!(a.id, c.id);
        assert_ne!(a.id, 0);
        assert!(a.mask.is_none());
    }

    #[test]
    fn subsample_id_matches_for_same_arc() {
        let s = Arc::new(OuterScoreSubsample::from_uniform_inclusion_mask(
            vec![1, 3, 5],
            10,
            42,
        ));
        let a = RowSubsampleMask::subsample(Arc::clone(&s));
        let b = RowSubsampleMask::subsample(Arc::clone(&s));
        assert_eq!(a.id, b.id);
        assert_ne!(a.id, 0);
    }

    #[test]
    fn subsample_id_differs_for_different_arcs() {
        let s1 = Arc::new(OuterScoreSubsample::from_uniform_inclusion_mask(
            vec![1, 3, 5],
            10,
            42,
        ));
        let s2 = Arc::new(OuterScoreSubsample::from_uniform_inclusion_mask(
            vec![1, 3, 5],
            10,
            42,
        ));
        // Both masks stay alive, so the two allocations cannot share an address.
        let a = RowSubsampleMask::subsample(s1);
        let b = RowSubsampleMask::subsample(s2);
        assert_ne!(a.id, b.id);
    }

    #[test]
    fn indices_and_weights_full_data() {
        let rm = RowSubsampleMask::full_data(4);
        let (idx, w) = rm.indices_and_weights(4);
        assert_eq!(idx, vec![0, 1, 2, 3]);
        assert_eq!(w, vec![1.0, 1.0, 1.0, 1.0]);
    }

    #[test]
    fn full_data_handles_zero_rows() {
        let rm = RowSubsampleMask::full_data(0);
        assert_ne!(rm.id, 0);
        let (idx, w) = rm.indices_and_weights(0);
        assert!(idx.is_empty());
        assert!(w.is_empty());
    }

    #[test]
    fn indices_and_weights_subsample() {
        let s = Arc::new(OuterScoreSubsample::from_uniform_inclusion_mask(
            vec![0, 2],
            4,
            7,
        ));
        let rm = RowSubsampleMask::subsample(s);
        let (idx, w) = rm.indices_and_weights(4);
        assert_eq!(idx, vec![0, 2]);
        assert_eq!(w.len(), 4);
        assert!(w[0] > 0.0 && w[2] > 0.0);
    }

    #[test]
    #[should_panic(expected = "must match caller n")]
    fn indices_and_weights_rejects_mismatched_n() {
        // Same construction as `indices_and_weights_subsample`, where n_full == 4.
        let s = Arc::new(OuterScoreSubsample::from_uniform_inclusion_mask(
            vec![0, 2],
            4,
            7,
        ));
        let rm = RowSubsampleMask::subsample(s);
        let _ = rm.indices_and_weights(5);
    }
}