use std::sync::Arc;
use crate::custom_family::BlockwiseFitOptions;
use crate::families::marginal_slope_shared::OuterScoreSubsample;
/// Identifies which rows (and with what weights) a computation is evaluated over.
///
/// A measure is either the full dataset (`mask == None`) or a shared
/// `OuterScoreSubsample` (`mask == Some(..)`). `id` is a hash that callers can
/// compare to tell measures apart without inspecting the mask contents.
#[derive(Clone, Debug)]
pub struct RowMeasure {
    /// Hash-derived identifier. For full-data measures it depends only on the row
    /// count; for subsamples it also depends on the address of the shared `Arc`.
    pub id: u64,
    /// `None` means every row with unit weight; `Some` holds the shared subsample.
    pub mask: Option<Arc<OuterScoreSubsample>>,
}
impl RowMeasure {
pub fn full_data(n: usize) -> Self {
Self {
id: hash_full(n),
mask: None,
}
}
pub fn subsample(mask: Arc<OuterScoreSubsample>) -> Self {
let id = hash_subsample(&mask);
Self {
id,
mask: Some(mask),
}
}
pub fn from_options(options: &BlockwiseFitOptions, n: usize) -> Self {
match options.outer_score_subsample.as_ref() {
Some(mask) => Self::subsample(Arc::clone(mask)),
None => Self::full_data(n),
}
}
pub fn indices_and_weights(&self, n: usize) -> (Vec<usize>, Vec<f64>) {
match self.mask.as_ref() {
Some(m) => {
assert_eq!(
m.n_full, n,
"RowMeasure n_full ({}) must match caller n ({})",
m.n_full, n
);
let indices: Vec<usize> = m.mask.as_ref().clone();
let mut weights = vec![1.0_f64; n];
for r in m.rows.iter() {
if r.index < n {
weights[r.index] = r.weight;
}
}
(indices, weights)
}
None => ((0..n).collect(), vec![1.0_f64; n]),
}
}
}
/// One step of the SplitMix64 generator applied to `x` as the state.
///
/// The function adds the golden-ratio increment, then applies the two xor-shift/multiply
/// rounds and a final xor-shift. All arithmetic wraps on overflow.
fn splitmix64(x: u64) -> u64 {
    const GOLDEN_GAMMA: u64 = 0x9E37_79B9_7F4A_7C15;
    const MIX_1: u64 = 0xBF58_476D_1CE4_E5B9;
    const MIX_2: u64 = 0x94D0_49BB_1331_11EB;

    let mut z = x.wrapping_add(GOLDEN_GAMMA);
    z ^= z >> 30;
    z = z.wrapping_mul(MIX_1);
    z ^= z >> 27;
    z = z.wrapping_mul(MIX_2);
    z ^ (z >> 31)
}
/// Salt XOR-ed with the row count before hashing a full-data measure.
///
/// This is presumably a domain separator for full-data ids — confirm.
const FULL_DATA_ROW_MEASURE_SENTINEL: u64 = 0xA5A5_5A5A_DEAD_BEEF;

/// Id for a full-data [`RowMeasure`] over `n` rows.
///
/// The id is deterministic in `n`. A zero result is replaced by a fixed non-zero
/// constant, so this never returns 0.
fn hash_full(n: usize) -> u64 {
    match splitmix64(FULL_DATA_ROW_MEASURE_SENTINEL ^ (n as u64)) {
        0 => 0x1234_5678_9ABC_DEF0,
        h => h,
    }
}
/// Derives the `id` for a subsampled [`RowMeasure`].
///
/// The hash is seeded from the address of the `Arc` allocation. Distinct allocations
/// with identical contents get different ids, while clones of one `Arc` share an id.
/// The subsample's `n_full`, `len()`, `seed` and `weight_scale` bits are then folded in.
///
/// Each component is chained through `splitmix64` (`h = splitmix64(h ^ c)`) instead of
/// being hashed separately and XOR-ed together. Independent XOR lets equal components
/// cancel: for example, `n_full == len()` for a subsample that keeps every row would
/// contribute nothing. It also makes the result blind to which field holds which value.
///
/// NOTE(review): the address is only unique while the `Arc` is alive. An id kept after
/// every clone of its `Arc` has been dropped could collide with a later allocation at the
/// same address that has matching metadata. Confirm no cache keyed by this id outlives
/// the measure.
///
/// A zero result is replaced by a fixed non-zero constant, so this never returns 0.
fn hash_subsample(mask: &Arc<OuterScoreSubsample>) -> u64 {
    let address = Arc::as_ptr(mask) as u64;
    let components = [
        mask.n_full as u64,
        mask.len() as u64,
        mask.seed,
        mask.weight_scale.to_bits(),
    ];
    // Positional chaining: every step is a bijective remix, so components cannot cancel.
    let h = components
        .iter()
        .fold(splitmix64(address), |acc, &c| splitmix64(acc ^ c));
    if h == 0 {
        0xDEAD_BEEF_FEED_FACE
    } else {
        h
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for `RowMeasure` id hashing and index/weight expansion.
    use super::*;
    use crate::families::marginal_slope_shared::OuterScoreSubsample;

    /// Fresh `Arc`-wrapped subsample used by the id tests (same arguments every call).
    fn sample_subsample() -> Arc<OuterScoreSubsample> {
        Arc::new(OuterScoreSubsample::new(vec![1, 3, 5], 10, 42))
    }

    #[test]
    fn full_data_id_is_stable_per_n() {
        let hundred = RowMeasure::full_data(100);
        let hundred_again = RowMeasure::full_data(100);
        let hundred_one = RowMeasure::full_data(101);
        assert!(hundred.mask.is_none());
        assert_eq!(hundred.id, hundred_again.id);
        assert_ne!(hundred.id, hundred_one.id);
    }

    #[test]
    fn subsample_id_matches_for_same_arc() {
        let shared = sample_subsample();
        let ids: Vec<u64> = (0..2)
            .map(|_| RowMeasure::subsample(Arc::clone(&shared)).id)
            .collect();
        assert_eq!(ids[0], ids[1]);
    }

    #[test]
    fn subsample_id_differs_for_different_arcs() {
        // Both measures keep their own Arc alive, so the allocations are distinct.
        let left = RowMeasure::subsample(sample_subsample());
        let right = RowMeasure::subsample(sample_subsample());
        assert_ne!(left.id, right.id);
    }

    #[test]
    fn indices_and_weights_full_data() {
        let (indices, weights) = RowMeasure::full_data(4).indices_and_weights(4);
        assert_eq!(indices, (0..4).collect::<Vec<usize>>());
        assert_eq!(weights, vec![1.0_f64; 4]);
    }

    #[test]
    fn indices_and_weights_subsample() {
        let subsample = Arc::new(OuterScoreSubsample::new(vec![0, 2], 4, 7));
        let (indices, weights) = RowMeasure::subsample(subsample).indices_and_weights(4);
        assert_eq!(indices, vec![0, 2]);
        assert_eq!(weights.len(), 4);
        assert!([0usize, 2].iter().all(|&i| weights[i] > 0.0));
    }
}