use crate::ArraySamplerStrategy;
/// Selects which index-sampling routine [`choose_indices`] dispatches to.
///
/// The type is a plain fieldless enum, so it is cheap to copy, compare and
/// use as a map/set key.
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, Hash)]
pub enum ArraySamplerKind {
    /// Head-biased sampling via [`choose_indices_default`]; this is the
    /// variant produced by `ArraySamplerKind::default()`.
    #[default]
    Default,
    /// Keep the leading indices only, via [`choose_indices_head`].
    Head,
    /// Keep the trailing indices only, via [`choose_indices_tail`].
    Tail,
}
/// One-to-one conversion: every `ArraySamplerStrategy` variant maps to the
/// `ArraySamplerKind` variant of the same name.
impl From<ArraySamplerStrategy> for ArraySamplerKind {
    /// Translates `strategy` into the identically named sampler kind.
    fn from(strategy: ArraySamplerStrategy) -> Self {
        // Deliberately no catch-all arm: a new strategy variant should fail
        // to compile here until it is given a mapping.
        match strategy {
            ArraySamplerStrategy::Default => Self::Default,
            ArraySamplerStrategy::Head => Self::Head,
            ArraySamplerStrategy::Tail => Self::Tail,
        }
    }
}
/// Value XOR-ed into an index before it is hashed by `accept_index`.
const RANDOM_ACCEPT_SEED: u64 = 0x9e37_79b9_7f4a_7c15;

/// `accept_index` keeps an index when the upper 32 bits of its hash are
/// strictly below this value. `1 << 31` (`0x8000_0000`) is exactly half of
/// the `u32` range.
const RANDOM_ACCEPT_THRESHOLD: u32 = 1 << 31;

/// Upper bound on how many leading indices `choose_indices_default` keeps
/// unconditionally.
const KEEP_FIRST_COUNT: usize = 3;

/// The capacity left after the leading indices is divided by this value to
/// size the contiguous run that `choose_indices_default` takes next.
const GREEDY_PORTION_DIVISOR: usize = 2;
/// Scrambles a 64-bit value with three xor-shift steps interleaved with two
/// wrapping multiplications.
///
/// The shifts (30, 27, 31) and multipliers are those of the SplitMix64
/// output finalizer. The function is pure and `mix64(0) == 0`.
fn mix64(x: u64) -> u64 {
    const FIRST_MULTIPLIER: u64 = 0xbf58_476d_1ce4_e5b9;
    const SECOND_MULTIPLIER: u64 = 0x94d0_49bb_1331_11eb;

    // Wrapping multiplication: overflow is the intended behaviour here.
    let stage_one = (x ^ (x >> 30)).wrapping_mul(FIRST_MULTIPLIER);
    let stage_two = (stage_one ^ (stage_one >> 27)).wrapping_mul(SECOND_MULTIPLIER);
    stage_two ^ (stage_two >> 31)
}
/// Deterministic keep/skip decision for index `i`.
///
/// The index is XOR-ed with `RANDOM_ACCEPT_SEED`, hashed with `mix64`, and
/// accepted when the upper 32 bits of the hash fall below
/// `RANDOM_ACCEPT_THRESHOLD`. The same `i` always yields the same answer.
fn accept_index(i: u64) -> bool {
    let hashed = mix64(RANDOM_ACCEPT_SEED ^ i);
    // After the shift only 32 significant bits remain, so the cast is lossless.
    let upper_bits = (hashed >> 32) as u32;
    upper_bits < RANDOM_ACCEPT_THRESHOLD
}
/// Picks at most `cap` indices out of `0..total`, biased towards the front.
///
/// * `cap == 0` or `total == 0` yields an empty vector.
/// * `cap >= total` yields every index.
/// * Otherwise the result is built in three ascending phases:
///   1. the first `KEEP_FIRST_COUNT` indices (fewer when `cap` is smaller);
///   2. a contiguous run directly after them, sized as the remaining
///      capacity divided by `GREEDY_PORTION_DIVISOR`;
///   3. later indices for which `accept_index` returns `true`, taken in
///      order until `cap` is reached or `total` is exhausted.
///
/// The returned indices are strictly increasing. Because phase 3 skips
/// indices, the result can hold fewer than `cap` entries even when
/// `total > cap`.
#[allow(
    clippy::cognitive_complexity,
    reason = "Single function mirrors JSON streaming sampler phases"
)]
pub fn choose_indices_default(total: usize, cap: usize) -> Vec<usize> {
    if total == 0 || cap == 0 {
        return Vec::new();
    }
    if total <= cap {
        return (0..total).collect();
    }

    // Phase 1: unconditional head.
    let limit = cap.min(total);
    let head_end = KEEP_FIRST_COUNT.min(limit);
    // The initial reservation is bounded so a huge `cap` does not allocate up front.
    let mut picked: Vec<usize> = Vec::with_capacity(cap.min(4096));
    picked.extend(0..head_end);
    if picked.len() >= limit {
        picked.truncate(limit);
        return picked;
    }

    // Phase 2: contiguous run after the head. The run is clamped so that it
    // can neither exceed `cap` entries overall nor step past `total`.
    let greedy_quota = cap.saturating_sub(head_end) / GREEDY_PORTION_DIVISOR;
    let greedy_len = greedy_quota
        .min(cap - picked.len())
        .min(total - head_end);
    let greedy_end = head_end + greedy_len;
    picked.extend(head_end..greedy_end);

    // Phase 3: hash-selected indices fill whatever capacity is left. A zero
    // budget or an empty range simply adds nothing.
    let budget = cap.saturating_sub(picked.len());
    picked.extend(
        (greedy_end..total)
            .filter(|&candidate| accept_index(candidate as u64))
            .take(budget),
    );
    picked
}
/// Returns the first `cap` indices of `0..total` in ascending order, or all
/// of them when `total <= cap`. Empty when either argument is zero.
pub fn choose_indices_head(total: usize, cap: usize) -> Vec<usize> {
    (0..total).take(cap).collect()
}
/// Returns the last `cap` indices of `0..total` in ascending order, or all
/// of them when `total <= cap`. Empty when either argument is zero.
pub fn choose_indices_tail(total: usize, cap: usize) -> Vec<usize> {
    // Saturating: with `cap >= total` the window starts at 0; with
    // `cap == 0` it starts at `total`, giving an empty range.
    let first_kept = total.saturating_sub(cap);
    (first_kept..total).collect()
}
pub fn choose_indices(
kind: ArraySamplerKind,
total: usize,
cap: usize,
) -> Vec<usize> {
match kind {
ArraySamplerKind::Default => choose_indices_default(total, cap),
ArraySamplerKind::Head => choose_indices_head(total, cap),
ArraySamplerKind::Tail => choose_indices_tail(total, cap),
}
}
/// Adds to `sampled` every index in `0..total` that `must_include` demands
/// and that is not already present.
///
/// # Parameters
/// * `sampled` - indices chosen so far. It is expected to be in ascending
///   order; the merge below relies on that to keep the result sorted. An
///   unsorted input still yields every index, but not in sorted order.
/// * `total` - length of the index space; `must_include` is only consulted
///   for indices in `0..total`.
/// * `must_include` - predicate naming the indices that have to appear in
///   the result. It is called once, in ascending order, for each index of
///   `0..total` that is not in `sampled`.
///
/// # Returns
/// `sampled` itself when nothing has to be added, otherwise a new vector
/// holding the entries of `sampled` merged with the missing required
/// indices. No index is added twice.
///
/// Entries of `sampled` that are `>= total` are passed through unchanged
/// rather than causing a panic: they lie outside the range the predicate is
/// asked about, so they can never collide with a required index.
#[allow(
    clippy::cognitive_complexity,
    reason = "Linear collect-and-merge logic reads clearest as a single function"
)]
pub fn merge_required(
    sampled: Vec<usize>,
    total: usize,
    must_include: &impl Fn(usize) -> bool,
) -> Vec<usize> {
    // Membership table over `0..total`, so the scan below is linear.
    let mut seen = vec![false; total];
    for &i in &sampled {
        // `get_mut` instead of `seen[i]`: an out-of-range index must not
        // abort the whole merge.
        if let Some(slot) = seen.get_mut(i) {
            *slot = true;
        }
    }

    // Required indices that the sample missed, in ascending order.
    let extra: Vec<usize> = seen
        .iter()
        .enumerate()
        .filter(|&(i, &already)| !already && must_include(i))
        .map(|(i, _)| i)
        .collect();
    if extra.is_empty() {
        return sampled;
    }

    // Two-way merge of the two ascending lists. The lists are disjoint, so
    // the `<=` tie-break never actually sees equal values.
    let mut result = Vec::with_capacity(sampled.len() + extra.len());
    let (mut si, mut ei) = (0, 0);
    while si < sampled.len() && ei < extra.len() {
        if sampled[si] <= extra[ei] {
            result.push(sampled[si]);
            si += 1;
        } else {
            result.push(extra[ei]);
            ei += 1;
        }
    }
    // At most one of these tails is non-empty.
    result.extend_from_slice(&sampled[si..]);
    result.extend_from_slice(&extra[ei..]);
    result
}
/// Unit tests for the index samplers and for `merge_required`.
#[cfg(test)]
mod tests {
    use super::*;

    /// Fails the calling test unless `indices` is strictly increasing.
    fn assert_strictly_increasing(indices: &[usize]) {
        for w in indices.windows(2) {
            assert!(w[0] < w[1], "indices should be sorted: {indices:?}");
        }
    }

    // A cap larger than the array keeps every index.
    #[test]
    fn default_sampler_returns_all_when_cap_not_binding() {
        let everything: Vec<usize> = (0..10).collect();
        assert_eq!(choose_indices_default(10, 15), everything);
    }

    // A binding cap is never exceeded.
    #[test]
    fn default_sampler_respects_cap_when_smaller() {
        let cap = 3;
        let picked = choose_indices_default(10, cap);
        assert!(picked.len() <= cap);
    }

    // A required index outside the sample is added; the head survives.
    #[test]
    fn merge_required_adds_missing_indices() {
        let indices =
            merge_required(choose_indices_default(20, 3), 20, &|i| i == 15);
        assert!(
            indices.contains(&15),
            "must_include index should be present: {indices:?}"
        );
        assert!(indices.contains(&0), "head items should be present");
    }

    // Merged output stays strictly increasing.
    #[test]
    fn merge_required_preserves_sorted_order() {
        let indices = merge_required(
            choose_indices_default(100, 5),
            100,
            &|i| i == 50 || i == 90,
        );
        assert_strictly_increasing(&indices);
        assert!(indices.contains(&50));
        assert!(indices.contains(&90));
    }

    // With nothing sampled, only the required indices remain.
    #[test]
    fn merge_required_with_zero_cap() {
        let indices =
            merge_required(choose_indices_default(10, 0), 10, &|i| i == 3 || i == 7);
        assert_eq!(indices, vec![3, 7]);
    }

    // A required index that was already sampled is not duplicated.
    #[test]
    fn merge_required_no_duplicates_when_already_sampled() {
        let indices =
            merge_required(choose_indices_default(10, 10), 10, &|i| i == 0);
        assert_eq!(indices, (0..10).collect::<Vec<_>>());
    }

    // Head sampling plus a required index past the cap.
    #[test]
    fn head_sampler_merge_includes_required_beyond_cap() {
        let indices = merge_required(choose_indices_head(20, 3), 20, &|i| i == 17);
        assert_eq!(&indices[..3], &[0, 1, 2]);
        assert!(
            indices.contains(&17),
            "must_include index should be present: {indices:?}"
        );
        assert_strictly_increasing(&indices);
    }

    // Head sampling where the required index is already in the head.
    #[test]
    fn head_sampler_merge_no_duplicates_when_already_sampled() {
        let indices = merge_required(choose_indices_head(10, 5), 10, &|i| i == 2);
        assert_eq!(indices, (0..5).collect::<Vec<_>>());
    }

    // Tail sampling plus a required index before the tail window.
    #[test]
    fn tail_sampler_merge_includes_required_beyond_cap() {
        let indices = merge_required(choose_indices_tail(20, 3), 20, &|i| i == 2);
        assert!(indices.contains(&2), "must_include index should be present");
        assert!(indices.contains(&17));
        assert_eq!(indices, vec![2, 17, 18, 19]);
    }

    // Tail sampling where the required index is already in the tail.
    #[test]
    fn tail_sampler_merge_no_duplicates_when_already_sampled() {
        let indices = merge_required(choose_indices_tail(10, 5), 10, &|i| i == 7);
        assert_eq!(indices, (5..10).collect::<Vec<_>>());
    }

    // Zero-cap tail sampling leaves only the required indices.
    #[test]
    fn tail_sampler_merge_with_zero_cap_returns_only_required() {
        let indices =
            merge_required(choose_indices_tail(10, 0), 10, &|i| i == 4 || i == 8);
        assert_eq!(indices, vec![4, 8]);
    }

    // Zero-cap head sampling leaves only the required indices.
    #[test]
    fn head_sampler_merge_with_zero_cap_returns_only_required() {
        let indices =
            merge_required(choose_indices_head(10, 0), 10, &|i| i == 4 || i == 8);
        assert_eq!(indices, vec![4, 8]);
    }
}