use std::sync::Arc;
/// A weighted subset of rows drawn from a larger row set.
///
/// The constructors build `mask` and `rows` so that they list the same row indices in the
/// same order (each is derived from the other). Both live behind `Arc`, so cloning a
/// subsample only bumps reference counts instead of copying the vectors.
#[derive(Debug, Clone)]
pub struct OuterScoreSubsample {
    /// Selected row indices, in the same order as `rows`.
    pub mask: Arc<Vec<usize>>,
    /// Selected rows with their per-row weight and stratum.
    pub rows: Arc<Vec<WeightedOuterRow>>,
    /// Size of the full row set; the uniform-mask constructor divides it by the mask length.
    pub n_full: usize,
    /// Summary weight: the shared weight for uniformly weighted subsamples, the mean row
    /// weight for [`OuterScoreSubsample::from_weighted_rows`], and `1.0` when empty.
    pub weight_scale: f64,
    /// Caller-supplied seed, stored unchanged (not read by the methods below).
    pub seed: u64,
}
impl OuterScoreSubsample {
    /// Builds a uniformly weighted subsample from an inclusion mask.
    ///
    /// Each selected row receives weight `n_full / mask.len()`. An empty mask uses `1.0`
    /// to avoid dividing by zero. The mask order is preserved as given.
    pub fn from_uniform_inclusion_mask(mask: Vec<usize>, n_full: usize, seed: u64) -> Self {
        let per_row_weight = match mask.len() {
            0 => 1.0,
            selected => n_full as f64 / selected as f64,
        };
        Self::with_uniform_weight(mask, n_full, seed, per_row_weight)
    }
    /// Builds a subsample in which every row listed in `mask` carries the same `weight`.
    ///
    /// `mask` is used as-is: it is neither sorted nor deduplicated. Every row gets
    /// stratum `0`. `weight_scale` is `weight`, or `1.0` when `mask` is empty.
    pub fn with_uniform_weight(mask: Vec<usize>, n_full: usize, seed: u64, weight: f64) -> Self {
        let mut rows = Vec::with_capacity(mask.len());
        for &index in &mask {
            rows.push(WeightedOuterRow {
                index,
                weight,
                stratum: 0,
            });
        }
        let weight_scale = if mask.is_empty() { 1.0 } else { weight };
        Self {
            mask: Arc::new(mask),
            rows: Arc::new(rows),
            n_full,
            weight_scale,
            seed,
        }
    }
    /// Builds a subsample from explicitly weighted rows.
    ///
    /// Rows are stably sorted by `index`. For duplicate indices only the first row of each
    /// run survives, so the occurrence that came first in the input wins. `weight_scale` is
    /// the arithmetic mean of the surviving weights, or `1.0` when no rows remain.
    pub fn from_weighted_rows(mut rows: Vec<WeightedOuterRow>, n_full: usize, seed: u64) -> Self {
        // Stable sort: required so the first input occurrence of a duplicate index is kept.
        rows.sort_by(|a, b| a.index.cmp(&b.index));
        rows.dedup_by(|later, earlier| later.index == earlier.index);
        let mask = rows.iter().map(|row| row.index).collect::<Vec<usize>>();
        let weight_scale = match rows.len() {
            0 => 1.0,
            count => rows.iter().map(|row| row.weight).sum::<f64>() / count as f64,
        };
        Self {
            mask: Arc::new(mask),
            rows: Arc::new(rows),
            n_full,
            weight_scale,
            seed,
        }
    }
    /// Number of selected rows (length of `mask`).
    #[inline]
    pub fn len(&self) -> usize {
        self.mask.len()
    }
    /// `true` when no rows are selected (`mask` is empty).
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
    /// Returns `true` when some row's weight differs from the first row's weight.
    ///
    /// Empty and single-row subsamples are never variable. The comparison is
    /// `(w - first).abs() > 0.0`, so a NaN difference (e.g. NaN weights) never counts as
    /// "different".
    pub fn has_variable_weights(&self) -> bool {
        match self.rows.split_first() {
            None => false,
            Some((head, tail)) => {
                let reference = head.weight;
                tail.iter().any(|row| (row.weight - reference).abs() > 0.0)
            }
        }
    }
}
/// One selected row of a subsample together with its weight and stratum label.
#[derive(Debug, Clone, Copy)]
pub struct WeightedOuterRow {
/// Row index, in the same index space that `RowSet::All` iterates (`0..n_total`).
/// Not range-checked here.
pub index: usize,
/// Per-row weight; passed as the `weight` argument to `RowSet` fold closures.
pub weight: f64,
/// Stratum label. The uniform constructors in this file always set it to `0`.
/// NOTE(review): nothing visible here reads it — confirm its meaning with the sampler that sets it.
pub stratum: u32,
}
/// Number of rows handled by a single parallel work item in the `RowSet` folds.
pub const ARROW_ROW_CHUNK: usize = 256;
/// Returns how many [`ARROW_ROW_CHUNK`]-sized chunks are needed to cover `n_rows` rows.
///
/// This is ceiling division: the last chunk may be partial, and zero rows need zero
/// chunks. Computed as "full chunks + one more if there is a remainder", which cannot
/// overflow even for `usize::MAX`.
#[inline]
pub fn arrow_row_chunk_count(n_rows: usize) -> usize {
    let full_chunks = n_rows / ARROW_ROW_CHUNK;
    let has_partial_tail = n_rows % ARROW_ROW_CHUNK != 0;
    full_chunks + usize::from(has_partial_tail)
}
/// Which rows a parallel fold visits, and with what weight.
///
/// * `All` visits every row `0..n_total` with weight `1.0`.
/// * `Subsample` visits only the stored rows, each with its own `weight`.
#[derive(Debug, Clone)]
pub enum RowSet {
    /// Every row of the full set, each with weight `1.0`.
    All,
    /// A weighted subset of rows.
    Subsample {
        /// Rows to visit, in stored order.
        rows: Arc<Vec<WeightedOuterRow>>,
        /// Size of the full row set; not read by the folds below.
        n_full: usize,
    },
}
impl RowSet {
    /// Chunked parallel fold-then-reduce over the rows of this set.
    ///
    /// Rows are split into chunks of [`ARROW_ROW_CHUNK`]. Each chunk is folded on the rayon
    /// pool, starting from a fresh `init()` and calling `fold(acc, row_index, weight)` per row.
    /// The per-chunk accumulators are then combined sequentially, in chunk order, starting
    /// from one more `init()`. Chunk boundaries and combination order are fixed, so the
    /// result does not depend on thread scheduling (this matters for floating-point sums).
    ///
    /// `init()` should be an identity for `reduce`; otherwise its value is counted once per
    /// chunk plus once more. `n_total` is only used by [`RowSet::All`]; `Subsample` rows are
    /// visited regardless of it.
    #[inline]
    pub fn par_reduce_fold<T, I, F, R>(&self, n_total: usize, init: I, fold: F, reduce: R) -> T
    where
        T: Send,
        I: Fn() -> T + Send + Sync,
        F: Fn(T, usize, f64) -> T + Send + Sync,
        R: Fn(T, T) -> T + Send + Sync,
    {
        use std::convert::Infallible;
        // Reuse the fallible implementation with an error type that cannot be constructed,
        // so the chunking and reduction order are defined in exactly one place.
        let result: Result<T, Infallible> = self.par_try_reduce_fold(
            n_total,
            init,
            |acc, row, weight| Ok(fold(acc, row, weight)),
            |lhs, rhs| Ok(reduce(lhs, rhs)),
        );
        match result {
            Ok(total) => total,
            Err(never) => match never {},
        }
    }
    /// Fallible version of [`RowSet::par_reduce_fold`], with the same chunking and
    /// combination order.
    ///
    /// Every chunk is folded. Chunks are not cancelled when another chunk fails, but a
    /// chunk stops at its own first error.
    ///
    /// # Errors
    ///
    /// While combining chunk results in chunk order, returns the first failure encountered.
    /// That is either the error a chunk's `fold` produced, or an error returned by `reduce`.
    #[inline]
    pub fn par_try_reduce_fold<T, E, I, F, R>(
        &self,
        n_total: usize,
        init: I,
        fold: F,
        reduce: R,
    ) -> Result<T, E>
    where
        T: Send,
        E: Send,
        I: Fn() -> T + Send + Sync,
        F: Fn(T, usize, f64) -> Result<T, E> + Send + Sync,
        R: Fn(T, T) -> Result<T, E> + Send + Sync,
    {
        use rayon::iter::{IntoParallelIterator, ParallelIterator};
        use rayon::slice::ParallelSlice;
        // Fold each chunk in parallel; rayon's `collect` into a Vec preserves chunk order.
        let chunk_accumulators: Vec<Result<T, E>> = match self {
            Self::All => (0..arrow_row_chunk_count(n_total))
                .into_par_iter()
                .map(|chunk_idx| {
                    let start = chunk_idx * ARROW_ROW_CHUNK;
                    let end = (start + ARROW_ROW_CHUNK).min(n_total);
                    (start..end).try_fold(init(), |acc, row| fold(acc, row, 1.0))
                })
                .collect(),
            Self::Subsample { rows, .. } => rows
                .par_chunks(ARROW_ROW_CHUNK)
                .map(|chunk| {
                    chunk
                        .iter()
                        .try_fold(init(), |acc, r| fold(acc, r.index, r.weight))
                })
                .collect(),
        };
        // Combine sequentially so the reduction order is independent of thread scheduling.
        chunk_accumulators
            .into_iter()
            .try_fold(init(), |total, acc| reduce(total, acc?))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn chunk_count_zero_rows_is_zero() {
        assert_eq!(arrow_row_chunk_count(0), 0);
    }

    #[test]
    fn chunk_count_one_row_is_one() {
        assert_eq!(arrow_row_chunk_count(1), 1);
    }

    #[test]
    fn chunk_count_exact_multiple() {
        assert_eq!(arrow_row_chunk_count(ARROW_ROW_CHUNK), 1);
        assert_eq!(arrow_row_chunk_count(ARROW_ROW_CHUNK * 3), 3);
    }

    #[test]
    fn chunk_count_just_over_boundary() {
        assert_eq!(arrow_row_chunk_count(ARROW_ROW_CHUNK + 1), 2);
    }

    #[test]
    fn uniform_mask_weight_scale_is_n_full_over_m() {
        let mask = vec![0usize, 2, 4];
        let s = OuterScoreSubsample::from_uniform_inclusion_mask(mask, 6, 0);
        assert_eq!(s.len(), 3);
        assert!((s.weight_scale - 2.0).abs() < 1e-14);
        assert!(s.rows.iter().all(|r| (r.weight - 2.0).abs() < 1e-14));
    }

    #[test]
    fn uniform_mask_empty_has_weight_scale_one() {
        let s = OuterScoreSubsample::from_uniform_inclusion_mask(vec![], 10, 0);
        assert_eq!(s.len(), 0);
        assert!(s.is_empty());
        assert_eq!(s.weight_scale, 1.0);
    }

    #[test]
    fn weighted_rows_sorts_and_deduplicates() {
        let rows = vec![
            WeightedOuterRow { index: 3, weight: 2.0, stratum: 0 },
            WeightedOuterRow { index: 1, weight: 1.0, stratum: 0 },
            WeightedOuterRow { index: 3, weight: 2.0, stratum: 0 },
        ];
        let s = OuterScoreSubsample::from_weighted_rows(rows, 10, 42);
        assert_eq!(s.len(), 2);
        assert_eq!(s.mask[0], 1);
        assert_eq!(s.mask[1], 3);
    }

    #[test]
    fn weighted_rows_weight_scale_is_average_weight() {
        let rows = vec![
            WeightedOuterRow { index: 0, weight: 1.0, stratum: 0 },
            WeightedOuterRow { index: 1, weight: 3.0, stratum: 0 },
        ];
        let s = OuterScoreSubsample::from_weighted_rows(rows, 10, 0);
        assert!((s.weight_scale - 2.0).abs() < 1e-14);
    }

    #[test]
    fn has_variable_weights_false_for_uniform() {
        let s = OuterScoreSubsample::with_uniform_weight(vec![0, 1, 2], 3, 0, 1.5);
        assert!(!s.has_variable_weights());
    }

    #[test]
    fn has_variable_weights_true_for_mixed() {
        let rows = vec![
            WeightedOuterRow { index: 0, weight: 1.0, stratum: 0 },
            WeightedOuterRow { index: 1, weight: 2.0, stratum: 0 },
        ];
        let s = OuterScoreSubsample::from_weighted_rows(rows, 5, 0);
        assert!(s.has_variable_weights());
    }

    #[test]
    fn has_variable_weights_false_for_empty_and_single() {
        let empty = OuterScoreSubsample::from_weighted_rows(vec![], 5, 0);
        assert!(!empty.has_variable_weights());
        assert_eq!(empty.weight_scale, 1.0);
        let single = OuterScoreSubsample::with_uniform_weight(vec![3], 5, 0, 4.0);
        assert!(!single.has_variable_weights());
        assert_eq!(single.weight_scale, 4.0);
    }

    #[test]
    fn row_set_all_sums_indices_zero_to_n() {
        let rs = RowSet::All;
        let sum: f64 = rs.par_reduce_fold(
            5,
            || 0.0_f64,
            |acc, i, w| acc + i as f64 * w,
            |a, b| a + b,
        );
        assert!((sum - 10.0).abs() < 1e-14);
    }

    #[test]
    fn row_set_all_zero_rows_returns_init() {
        let v: i32 = RowSet::All.par_reduce_fold(0, || 7, |acc, _, _| acc + 1, |a, b| a + b);
        assert_eq!(v, 7);
    }

    #[test]
    fn row_set_all_spans_multiple_chunks() {
        let n = 2 * ARROW_ROW_CHUNK + 3;
        let sum: usize =
            RowSet::All.par_reduce_fold(n, || 0usize, |acc, i, _| acc + i, |a, b| a + b);
        assert_eq!(sum, n * (n - 1) / 2);
    }

    #[test]
    fn row_set_subsample_applies_per_row_weight() {
        let rows = Arc::new(vec![
            WeightedOuterRow { index: 2, weight: 3.0, stratum: 0 },
            WeightedOuterRow { index: 5, weight: 2.0, stratum: 0 },
        ]);
        let rs = RowSet::Subsample { rows, n_full: 10 };
        let sum: f64 = rs.par_reduce_fold(
            10,
            || 0.0_f64,
            |acc, i, w| acc + w * i as f64,
            |a, b| a + b,
        );
        assert!((sum - 16.0).abs() < 1e-14);
    }

    #[test]
    fn row_set_subsample_spans_multiple_chunks() {
        let rows: Vec<WeightedOuterRow> = (0..ARROW_ROW_CHUNK + 5)
            .map(|index| WeightedOuterRow { index, weight: 2.0, stratum: 0 })
            .collect();
        let n = rows.len();
        let rs = RowSet::Subsample { rows: Arc::new(rows), n_full: n };
        let total: f64 = rs.par_reduce_fold(n, || 0.0_f64, |acc, _, w| acc + w, |a, b| a + b);
        assert_eq!(total, 2.0 * n as f64);
    }

    #[test]
    fn row_set_try_reduce_ok() {
        let sum: Result<usize, String> = RowSet::All.par_try_reduce_fold(
            10,
            || 0usize,
            |acc, i, _| Ok(acc + i),
            |a, b| Ok(a + b),
        );
        assert_eq!(sum, Ok(45));
    }

    #[test]
    fn row_set_try_reduce_propagates_fold_error() {
        let rows = Arc::new(vec![
            WeightedOuterRow { index: 1, weight: 1.0, stratum: 0 },
            WeightedOuterRow { index: 3, weight: 1.0, stratum: 0 },
        ]);
        let rs = RowSet::Subsample { rows, n_full: 5 };
        let res: Result<usize, String> = rs.par_try_reduce_fold(
            5,
            || 0usize,
            |acc, i, _| {
                if i == 3 {
                    Err(format!("bad row {i}"))
                } else {
                    Ok(acc + i)
                }
            },
            |a, b| Ok(a + b),
        );
        assert_eq!(res, Err("bad row 3".to_string()));
    }

    #[test]
    fn row_set_try_reduce_propagates_reduce_error() {
        let res: Result<usize, &'static str> = RowSet::All.par_try_reduce_fold(
            3,
            || 0usize,
            |acc, _, _| Ok(acc + 1),
            |_, _| Err("reduce failed"),
        );
        assert_eq!(res, Err("reduce failed"));
    }
}