use crate::{RdRand, random_provider};
/// Strategy used by [`subset`] to fill an index buffer.
///
/// All variants only borrow their data, so a mode is cheap to copy around.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SubsetMode<'a> {
    /// Delegate to the bin-based subset sampler, which fills the *whole* buffer.
    StratifiedCorrect,
    /// Draw every slot independently from `0..max_index`; repeats are possible.
    FastRandom,
    /// Like `FastRandom`, but redraw any value that appears in the given list.
    Exclude(&'a [usize]),
    /// Draw slot `i` from the half-open range `start..end` stored at position
    /// `i % len` of the list, i.e. the ranges are cycled through.
    RangeList(&'a [(usize, usize)]),
}
/// Samples `num_indices` indices below `max_index` and guarantees that `index`
/// is one of them.
///
/// * `index` - The index that must be part of the result.
/// * `max_index` - Exclusive upper bound handed to [`subset`].
/// * `num_indices` - Number of indices to produce.
///
/// The samples come from [`subset`] in `StratifiedCorrect` mode. `index` then
/// replaces the first sample that is not smaller than it; if every sample is
/// smaller, it replaces the last one, so it is never silently dropped. The
/// returned vector is sorted ascending. With `num_indices == 0` the result is
/// empty and cannot contain `index`.
///
/// # Panics
///
/// Panics under the same conditions as [`subset`], e.g. when
/// `max_index < num_indices`.
pub fn create_individual_indexes(index: usize, max_index: usize, num_indices: usize) -> Vec<usize> {
    let mut scratch = vec![0; num_indices];
    subset(
        max_index,
        num_indices,
        &mut scratch,
        SubsetMode::StratifiedCorrect,
    );

    // `checked_sub` doubles as the "is there any slot at all?" test.
    if let Some(last) = scratch.len().checked_sub(1) {
        let slot = scratch
            .iter()
            .position(|&sampled| sampled >= index)
            .unwrap_or(last);
        scratch[slot] = index;
    }

    scratch.sort_unstable();
    scratch
}
/// In-place variant of the individual index sampler: fills `buff` with sampled
/// indices and guarantees that `index` is one of them.
///
/// * `index` - The index that must be part of the result.
/// * `max_index` - Exclusive upper bound handed to [`subset`].
/// * `num_indices` - Requested count, forwarded to [`subset`].
/// * `buff` - Output buffer; it is overwritten and left sorted ascending.
///
/// The samples come from [`subset`] in `StratifiedCorrect` mode. `index` then
/// replaces the first sample that is not smaller than it; if every sample is
/// smaller, it replaces the last one, so it is never silently dropped. An
/// empty `buff` is left empty.
///
/// # Panics
///
/// Panics under the same conditions as [`subset`], e.g. when
/// `max_index < num_indices`.
pub fn individual_indexes(index: usize, max_index: usize, num_indices: usize, buff: &mut [usize]) {
    subset(max_index, num_indices, buff, SubsetMode::StratifiedCorrect);

    // `checked_sub` doubles as the "is there any slot at all?" test.
    if let Some(last) = buff.len().checked_sub(1) {
        let slot = buff
            .iter()
            .position(|&sampled| sampled >= index)
            .unwrap_or(last);
        buff[slot] = index;
    }

    buff.sort_unstable();
}
/// Fills `buffer` with indices according to `mode`.
///
/// * `max_index` - Exclusive upper bound of the indices drawn by
///   `StratifiedCorrect`, `FastRandom` and `Exclude`.
/// * `num_indicies` - Number of leading slots written by `FastRandom`,
///   `Exclude` and `RangeList`. `StratifiedCorrect` hands the complete buffer
///   to the subset sampler, so there the buffer length decides.
/// * `buffer` - Output slice.
/// * `mode` - Sampling strategy, see [`SubsetMode`].
///
/// # Panics
///
/// * if `max_index < num_indicies`;
/// * if indices are requested in `Exclude` mode but the list already contains
///   every value of `0..max_index` (the redraw loop could never finish);
/// * if indices are requested in `RangeList` mode with an empty list;
/// * if `buffer` is shorter than `num_indicies` in the three slot-wise modes.
///
/// NOTE(review): a `RangeList` entry with `start >= end` is passed to the RNG
/// unchanged; how the RNG treats an empty range is not visible here.
pub fn subset(max_index: usize, num_indicies: usize, buffer: &mut [usize], mode: SubsetMode) {
    if max_index < num_indicies {
        panic!("n smaller than k: {} < {}.", max_index, num_indicies);
    }

    // Reject requests that can never be satisfied before touching the RNG, so
    // the failure is a clear panic instead of a hang or an arithmetic error.
    if num_indicies > 0 {
        match &mode {
            SubsetMode::Exclude(exclude) => {
                // Covering 0..max_index needs at least max_index entries, which
                // keeps the full scan off the common path.
                let exhausted = exclude.len() >= max_index
                    && (0..max_index).all(|candidate| exclude.contains(&candidate));
                if exhausted {
                    panic!(
                        "exclude list covers every index in 0..{}; nothing left to sample.",
                        max_index
                    );
                }
            }
            SubsetMode::RangeList(range_list) => {
                if range_list.is_empty() {
                    panic!(
                        "range list is empty; cannot sample {} indices.",
                        num_indicies
                    );
                }
            }
            SubsetMode::StratifiedCorrect | SubsetMode::FastRandom => {}
        }
    }

    random_provider::with_rng(|rand| match mode {
        SubsetMode::StratifiedCorrect => {
            next(max_index, buffer, rand);
        }
        SubsetMode::FastRandom => {
            for slot in &mut buffer[..num_indicies] {
                *slot = rand.range(0..max_index);
            }
        }
        SubsetMode::Exclude(exclude) => {
            for slot in &mut buffer[..num_indicies] {
                // Redraw until the candidate is not excluded.
                *slot = loop {
                    let candidate = rand.range(0..max_index);
                    if !exclude.contains(&candidate) {
                        break candidate;
                    }
                };
            }
        }
        SubsetMode::RangeList(range_list) => {
            // Slot i uses range i % len, i.e. the list is cycled.
            let ranges = range_list.iter().cycle();
            for (slot, &(start, end)) in buffer[..num_indicies].iter_mut().zip(ranges) {
                *slot = rand.range(start..end);
            }
        }
    })
}
fn next(max_index: usize, sub_set: &mut [usize], rand: &mut RdRand<'_>) {
let k = sub_set.len();
if k == max_index {
for i in 0..k {
sub_set[i] = i;
}
return;
}
build_subset(max_index, sub_set, rand);
if k > max_index - k {
invert(max_index, sub_set);
}
}
/// Fills `sub` with `k = sub.len()` distinct indices from `0..max_index`,
/// sorted ascending.
///
/// This is the bin-based random k-subset procedure (RANKSB in Nijenhuis &
/// Wilf, "Combinatorial Algorithms"). The values `1..=n` are split into `k`
/// consecutive bins. `k` points are thrown into the bins with a rejection
/// step that keeps a bin from receiving more points than it has values. Then
/// each bin's points are turned into concrete, distinct values of that bin.
/// All work happens inside `sub`, using one-based values until the final
/// conversion.
///
/// An empty `sub` is left untouched.
///
/// # Panics
///
/// Panics via `check_subset` if `sub.len() > max_index`.
fn build_subset(max_index: usize, sub: &mut [usize], rand: &mut RdRand<'_>) {
    let k = sub.len();
    check_subset(max_index, k);
    if k == 0 {
        return;
    }
    let n = max_index;

    // (A) Start every bin at its "zero point": the largest value before it.
    for (bin, slot) in sub.iter_mut().enumerate() {
        *slot = bin * n / k;
    }

    // (B) Throw k points. A draw x lands in bin (x * k - 1) / n and is accepted
    // only while x exceeds the bin's current fill mark, so a bin never gets
    // more points than it has values.
    for _ in 0..k {
        let bin = loop {
            let x = rand.range(1..n + 1);
            let bin = (x * k - 1) / n;
            if sub[bin] < x {
                break bin;
            }
        };
        sub[bin] += 1;
    }

    // (C) Compact the fill marks of the non-empty bins to the left; zero the rest.
    let mut occupied = 0;
    for bin in 0..k {
        let fill = sub[bin];
        sub[bin] = 0;
        if fill != bin * n / k {
            sub[occupied] = fill;
            occupied += 1;
        }
    }

    // (D) Right to left, reserve one block of slots per non-empty bin (as many
    // slots as the bin holds points) and store the one-based bin number in the
    // block's last slot. `block_end >= pos + 1` always holds because every
    // remaining bin still needs at least one slot.
    let mut block_end = k;
    for pos in (0..occupied).rev() {
        let fill = sub[pos];
        let bin = 1 + (fill * k - 1) / n;
        let count = fill - (bin - 1) * n / k;
        sub[pos] = 0;
        sub[block_end - 1] = bin;
        block_end -= count;
    }

    // (E)-(G) Walk the slots right to left. A non-zero slot is a bin marker and
    // opens a new block; every slot of the block (marker included) then gets a
    // value drawn from the bin's still-free values, inserted in sorted order.
    let mut right = 0; // one-based position of the current block's last slot
    let mut low = 0; // smallest value of the current bin
    let mut free = 0; // values of the current bin not yet handed out
    for l in (1..=k).rev() {
        if sub[l - 1] != 0 {
            right = l;
            low = 1 + (sub[l - 1] - 1) * n / k;
            free = sub[l - 1] * n / k - low + 1;
        }

        // x is the rank among the free values; each already placed value that
        // is <= x pushes x up by one and moves one slot to the left.
        let mut x = rand.range(low..low + free);
        let mut i = l + 1;
        while i <= right && x >= sub[i - 1] {
            x += 1;
            sub[i - 2] = sub[i - 1];
            i += 1;
        }
        sub[i - 2] = x;
        free -= 1;
    }

    // Every stored value is in 1..=n; convert to zero-based indices.
    for slot in sub.iter_mut() {
        *slot -= 1;
    }
}
/// Replaces `a` with the complement of its own prefix.
///
/// With `k = a.len()`, the first `n - k` entries of `a` are read as the set of
/// excluded values. `a` is then overwritten, back to front, with the largest
/// values of `0..n` that are not excluded, which leaves it sorted ascending.
/// When the prefix holds `n - k` distinct values below `n`, the result is
/// exactly the `k` values of `0..n` missing from the prefix.
///
/// An empty `a` is left untouched.
///
/// # Panics
///
/// Panics if `n <= k` or if the prefix length `n - k` exceeds `k`.
fn invert(n: usize, a: &mut [usize]) {
    let k = a.len();
    if k == 0 {
        return;
    }

    let mut v = n - 1;
    // Index of the last excluded entry in the snapshot below.
    let j = n - k - 1;
    // Snapshot: `a` is overwritten while the excluded prefix is still needed.
    let ac = a.to_vec();

    for i in (0..k).rev() {
        while index_of(&ac, j, v).is_some() {
            v -= 1;
        }
        a[i] = v;
        // Step past the value just used, except after the final slot: there `v`
        // may already be 0 and the decrement would underflow.
        if i > 0 {
            v -= 1;
        }
    }
}
/// Looks for `value` in `a[0..=start]`, scanning from `start` down to 0.
///
/// Returns the position of the match closest to `start`, or `None` when the
/// value does not occur in that prefix.
///
/// # Panics
///
/// Panics if `start` is not a valid index into `a`.
fn index_of(a: &[usize], start: usize, value: usize) -> Option<usize> {
    a[..=start]
        .iter()
        .rposition(|&candidate| candidate == value)
}
/// Precondition check for subset sampling: a subset of size `k` can only be
/// taken from `n` elements when `k <= n`.
///
/// # Panics
///
/// Panics with `"n smaller than k: {n} < {k}."` when `n < k`.
fn check_subset(n: usize, k: usize) {
    assert!(n >= k, "n smaller than k: {} < {}.", n, k);
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `FastRandom` fills the requested number of slots with values below `n`.
    #[test]
    fn test_fast_random_subset() {
        let (n, k) = (50, 20);
        let mut picked = vec![0; k];

        subset(n, k, &mut picked, SubsetMode::FastRandom);

        assert_eq!(picked.len(), k);
        assert!(picked.iter().all(|&value| value < n));
    }

    /// `Exclude` never yields a value from the exclusion list.
    #[test]
    fn test_exclude_subset() {
        let (n, k) = (10, 5);
        let blacklist = vec![2, 3, 4];
        let mut picked = vec![0; k];

        subset(n, k, &mut picked, SubsetMode::Exclude(&blacklist));

        assert_eq!(picked.len(), k);
        assert!(picked.iter().all(|value| !blacklist.contains(value)));
    }

    /// `RangeList` only yields values that lie inside one of the given ranges.
    #[test]
    fn test_range_list_subset() {
        let ranges = vec![(0, 5), (10, 15)];
        let mut picked = vec![0; 6];

        subset(20, 6, &mut picked, SubsetMode::RangeList(&ranges));

        assert_eq!(picked.len(), 6);
        assert!(
            picked
                .iter()
                .all(|value| (0..5).contains(value) || (10..15).contains(value))
        );
    }

    /// The individual's own index is part of the result, which is sorted.
    #[test]
    fn test_individual_indexes_includes_index() {
        let indexes = create_individual_indexes(7, 20, 5);

        assert_eq!(indexes.len(), 5);
        assert!(indexes.contains(&7));
        // Equivalent to comparing against a sorted copy.
        assert!(indexes.windows(2).all(|pair| pair[0] <= pair[1]));
    }
}