use std::sync::Arc;
/// A subset of row indices paired with per-row weights.
///
/// `mask` and `rows` describe the same selection: every constructor in this
/// impl produces one [`WeightedOuterRow`] per `mask` entry, in the same order.
#[derive(Debug, Clone)]
pub struct OuterScoreSubsample {
    /// Indices of the selected rows.
    pub mask: Arc<Vec<usize>>,
    /// One weighted entry per element of `mask`.
    pub rows: Arc<Vec<WeightedOuterRow>>,
    /// Stored as supplied; `from_uniform_inclusion_mask` uses it as the
    /// numerator of the uniform weight. NOTE(review): presumably the row count
    /// of the full data set — confirm against callers.
    pub n_full: usize,
    /// The uniform weight, or the mean row weight for variable-weight
    /// subsamples; `1.0` when the subsample is empty.
    pub weight_scale: f64,
    /// Stored as supplied and not read by any method here.
    /// NOTE(review): presumably the RNG seed that produced the selection.
    pub seed: u64,
}

impl OuterScoreSubsample {
    /// Builds a uniformly weighted subsample from a list of selected indices.
    ///
    /// Every row gets the weight `n_full / mask.len()`; an empty mask uses a
    /// weight of `1.0` so no division by zero takes place.
    pub fn from_uniform_inclusion_mask(mask: Vec<usize>, n_full: usize, seed: u64) -> Self {
        let weight = match mask.len() {
            0 => 1.0,
            kept => n_full as f64 / kept as f64,
        };
        Self::with_uniform_weight(mask, n_full, seed, weight)
    }

    /// Builds a subsample in which every selected row carries `weight`.
    ///
    /// `mask` is stored as given (no sorting or de-duplication) and every row
    /// is assigned stratum `0`. `weight_scale` is `weight`, or `1.0` when the
    /// mask is empty.
    pub fn with_uniform_weight(mask: Vec<usize>, n_full: usize, seed: u64, weight: f64) -> Self {
        let mut rows = Vec::with_capacity(mask.len());
        rows.extend(mask.iter().map(|&index| WeightedOuterRow {
            index,
            weight,
            stratum: 0,
        }));
        let weight_scale = if mask.is_empty() { 1.0 } else { weight };
        Self {
            mask: Arc::new(mask),
            rows: Arc::new(rows),
            n_full,
            weight_scale,
            seed,
        }
    }

    /// Builds a subsample from rows that already carry their own weights.
    ///
    /// Rows are stably sorted by index, and for repeated indices only the
    /// first occurrence (in input order) is kept. `weight_scale` is the
    /// arithmetic mean of the surviving weights, or `1.0` when no rows remain.
    pub fn from_weighted_rows(mut rows: Vec<WeightedOuterRow>, n_full: usize, seed: u64) -> Self {
        // Stable sort first so that `dedup_by_key` keeps the earliest duplicate.
        rows.sort_by_key(|row| row.index);
        rows.dedup_by_key(|row| row.index);

        let weight_scale = match rows.len() {
            0 => 1.0,
            kept => {
                let total: f64 = rows.iter().map(|row| row.weight).sum();
                total / kept as f64
            }
        };
        let mask = rows.iter().map(|row| row.index).collect::<Vec<usize>>();

        Self {
            mask: Arc::new(mask),
            rows: Arc::new(rows),
            n_full,
            weight_scale,
            seed,
        }
    }

    /// Number of selected rows.
    #[inline]
    pub fn len(&self) -> usize {
        self.mask.len()
    }

    /// Returns `true` when no rows are selected.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.mask.is_empty()
    }

    /// Returns `true` when at least one row's weight differs from the first
    /// row's weight; empty and single-row subsamples yield `false`.
    pub fn has_variable_weights(&self) -> bool {
        match self.rows.split_first() {
            None => false,
            Some((head, tail)) => tail
                .iter()
                .any(|row| (row.weight - head.weight).abs() > 0.0),
        }
    }
}

/// A single row reference together with its weight and stratum tag.
#[derive(Debug, Clone, Copy)]
pub struct WeightedOuterRow {
    /// Index of the row.
    pub index: usize,
    /// Weight attached to the row.
    pub weight: f64,
    /// Stratum tag; the uniform constructors of `OuterScoreSubsample` set it to `0`.
    pub stratum: u32,
}
/// Chunk size, in rows, used by [`arrow_row_chunk_count`].
pub const ARROW_ROW_CHUNK: usize = 256;

/// Returns how many chunks of [`ARROW_ROW_CHUNK`] rows are needed to cover
/// `n_rows` rows (ceiling division); zero rows need zero chunks.
#[inline]
pub fn arrow_row_chunk_count(n_rows: usize) -> usize {
    // `checked_sub` yields `None` only for `n_rows == 0`, which maps to 0
    // chunks; otherwise the chunk holding the last row determines the count.
    n_rows
        .checked_sub(1)
        .map_or(0, |last_row| last_row / ARROW_ROW_CHUNK + 1)
}
/// Selection of row indices visited by the reduce helpers implemented on this type.
#[derive(Clone)]
pub enum RowSet {
/// Every row index in `0..n_total`; the reduce helpers pass a weight of `1.0` for each row.
All,
/// An explicit list of rows, each visited with its own index and weight.
Subsample {
/// Rows to visit, shared through an `Arc`.
rows: Arc<Vec<WeightedOuterRow>>,
/// NOTE(review): presumably the row count of the full data set; the reduce
/// helpers visible in this file do not read it — confirm against callers.
n_full: usize,
},
}
impl RowSet {
    /// Folds every selected row into an accumulator, chunk by chunk in parallel.
    ///
    /// Rows are processed in chunks of `ARROW_ROW_CHUNK`. Each chunk starts
    /// from a fresh `init()` value and applies `fold(acc, row_index, weight)`
    /// to its rows in order. The per-chunk results are then merged
    /// sequentially, in chunk order, with `reduce`, starting from another
    /// `init()` value.
    ///
    /// * `RowSet::All` visits indices `0..n_total`, each with weight `1.0`.
    /// * `RowSet::Subsample` visits the stored rows with their own weights;
    ///   `n_total` is not used in that case.
    #[inline]
    pub fn par_reduce_fold<T, I, F, R>(&self, n_total: usize, init: I, fold: F, reduce: R) -> T
    where
        T: Send,
        I: Fn() -> T + Send + Sync,
        F: Fn(T, usize, f64) -> T + Send + Sync,
        R: Fn(T, T) -> T + Send + Sync,
    {
        use rayon::iter::{IntoParallelIterator, ParallelIterator};
        use rayon::slice::ParallelSlice;

        // Parallel phase: one partial accumulator per chunk, kept in chunk order.
        let partials: Vec<T> = match self {
            Self::All => (0..arrow_row_chunk_count(n_total))
                .into_par_iter()
                .map(|chunk_idx| {
                    let lo = chunk_idx * ARROW_ROW_CHUNK;
                    let hi = n_total.min(lo + ARROW_ROW_CHUNK);
                    (lo..hi).fold(init(), |acc, row| fold(acc, row, 1.0))
                })
                .collect(),
            Self::Subsample { rows, .. } => rows
                .par_chunks(ARROW_ROW_CHUNK)
                .map(|chunk| {
                    chunk
                        .iter()
                        .fold(init(), |acc, entry| fold(acc, entry.index, entry.weight))
                })
                .collect(),
        };

        // Sequential phase: merge left-to-right so the combination order is fixed.
        partials
            .into_iter()
            .fold(init(), |total, part| reduce(total, part))
    }

    /// Fallible counterpart of [`RowSet::par_reduce_fold`].
    ///
    /// Chunking, weights and merge order are the same. A chunk stops at its
    /// first `fold` error, but the other chunks are still evaluated. Errors
    /// surface during the sequential merge, which walks chunks in order and
    /// returns the first error it meets, whether from a chunk or from `reduce`.
    #[inline]
    pub fn par_try_reduce_fold<T, E, I, F, R>(
        &self,
        n_total: usize,
        init: I,
        fold: F,
        reduce: R,
    ) -> Result<T, E>
    where
        T: Send,
        E: Send,
        I: Fn() -> T + Send + Sync,
        F: Fn(T, usize, f64) -> Result<T, E> + Send + Sync,
        R: Fn(T, T) -> Result<T, E> + Send + Sync,
    {
        use rayon::iter::{IntoParallelIterator, ParallelIterator};
        use rayon::slice::ParallelSlice;

        // Parallel phase: one `Result` per chunk, kept in chunk order.
        let partials: Vec<Result<T, E>> = match self {
            Self::All => (0..arrow_row_chunk_count(n_total))
                .into_par_iter()
                .map(|chunk_idx| {
                    let lo = chunk_idx * ARROW_ROW_CHUNK;
                    let hi = n_total.min(lo + ARROW_ROW_CHUNK);
                    (lo..hi).try_fold(init(), |acc, row| fold(acc, row, 1.0))
                })
                .collect(),
            Self::Subsample { rows, .. } => rows
                .par_chunks(ARROW_ROW_CHUNK)
                .map(|chunk| {
                    chunk
                        .iter()
                        .try_fold(init(), |acc, entry| fold(acc, entry.index, entry.weight))
                })
                .collect(),
        };

        // Sequential phase: unwrap each chunk result before merging it, so a
        // chunk error takes precedence over calling `reduce` for that chunk.
        partials
            .into_iter()
            .try_fold(init(), |total, part| part.and_then(|acc| reduce(total, acc)))
    }
}