#[cfg(not(feature = "std"))]
use alloc::{collections::HashMap, vec, vec::Vec};
#[cfg(feature = "std")]
use std::collections::HashMap;
use super::core::{rng_utils, Sampler, SamplerIterator};
use scirs2_core::rand_prelude::SliceRandom;
use scirs2_core::random::Random;
use scirs2_core::RngExt;
/// Sampler that draws dataset indices with probability proportional to
/// per-index weights, with or without replacement.
#[derive(Debug, Clone)]
pub struct WeightedRandomSampler {
    /// Per-index sampling weights; non-negative, at least one positive.
    weights: Vec<f32>,
    /// When true, indices may repeat; when false, each index appears once.
    replacement: bool,
    /// Optional RNG seed for reproducible sampling.
    generator: Option<u64>,
    /// Lazily-built alias table used for O(1) draws with replacement.
    alias_table: Option<AliasTable>,
}

impl WeightedRandomSampler {
    /// Creates a sampler over `weights`.
    ///
    /// # Panics
    /// Panics if `weights` is empty, contains a negative or non-finite
    /// value, or has no positive entry.
    pub fn new(weights: Vec<f32>, replacement: bool) -> Self {
        assert!(!weights.is_empty(), "Weights vector cannot be empty");
        // Negative or NaN weights would silently corrupt the alias-table
        // construction (and the weighted shuffle), so reject them up front
        // instead of producing a skewed distribution.
        assert!(
            weights.iter().all(|&w| w >= 0.0 && w.is_finite()),
            "Weights must be non-negative and finite"
        );
        assert!(
            weights.iter().any(|&w| w > 0.0),
            "At least one weight must be positive"
        );
        Self {
            weights,
            replacement,
            generator: None,
            alias_table: None,
        }
    }

    /// Sets a fixed RNG seed so repeated iterations are reproducible.
    pub fn with_generator(mut self, seed: u64) -> Self {
        self.generator = Some(seed);
        self
    }

    /// The sampling weights.
    pub fn weights(&self) -> &[f32] {
        &self.weights
    }

    /// Whether sampling is done with replacement.
    pub fn uses_replacement(&self) -> bool {
        self.replacement
    }

    /// The configured RNG seed, if any.
    pub fn generator_seed(&self) -> Option<u64> {
        self.generator
    }

    /// Builds the alias table on first use and caches it.
    fn build_alias_table(&mut self) {
        if self.alias_table.is_none() {
            self.alias_table = Some(AliasTable::new(&self.weights));
        }
    }

    /// Draws `count` indices with replacement via the alias table.
    fn generate_indices(&mut self, count: usize) -> Vec<usize> {
        self.build_alias_table();
        let alias_table = self
            .alias_table
            .as_ref()
            .expect("alias table should be built");
        let mut rng = rng_utils::create_rng(self.generator);
        let mut indices = Vec::with_capacity(count);
        for _ in 0..count {
            indices.push(alias_table.sample(&mut rng));
        }
        indices
    }
}
impl Sampler for WeightedRandomSampler {
    type Iter = SamplerIterator;

    /// Produces one epoch of sampled indices.
    ///
    /// With replacement, indices are drawn independently via the alias table.
    /// Without replacement, a weighted shuffle produces a permutation of all
    /// indices (each appears exactly once).
    fn iter(&self) -> Self::Iter {
        // One epoch always covers `len()` indices regardless of mode.
        // (The original branched here, but both arms were identical.)
        let count = self.weights.len();
        let indices = if self.replacement {
            // Clone only on this path: `generate_indices` needs `&mut self`
            // to lazily build/cache the alias table.
            let mut sampler = self.clone();
            sampler.generate_indices(count)
        } else {
            // Weighted shuffle: repeatedly pick a weighted element from the
            // unplaced prefix [0..=i] and swap it into slot i.
            let mut pool: Vec<(usize, f32)> =
                self.weights.iter().copied().enumerate().collect();
            let mut rng = rng_utils::create_rng(self.generator);
            for i in (1..pool.len()).rev() {
                // Recompute the prefix total each round rather than keeping a
                // running total, avoiding accumulated floating-point drift.
                let total: f32 = pool[..=i].iter().map(|(_, w)| w).sum();
                let mut target = rng.random::<f32>() * total;
                let mut chosen = 0;
                for (j, (_, weight)) in pool[..=i].iter().enumerate() {
                    target -= weight;
                    if target <= 0.0 {
                        chosen = j;
                        break;
                    }
                }
                pool.swap(i, chosen);
            }
            pool.into_iter().map(|(idx, _)| idx).collect()
        };
        SamplerIterator::new(indices)
    }

    /// Number of indices produced per epoch.
    fn len(&self) -> usize {
        self.weights.len()
    }
}
/// Walker/Vose alias table for O(1) weighted sampling with replacement.
#[derive(Debug, Clone)]
struct AliasTable {
    /// Acceptance probability for each slot (in [0, 1] after construction).
    prob: Vec<f32>,
    /// Fallback index used when a slot's acceptance test fails.
    alias: Vec<usize>,
}

impl AliasTable {
    /// Builds the table from raw (unnormalized) weights in O(n).
    ///
    /// # Panics
    /// Panics if the weights do not sum to a positive value.
    fn new(weights: &[f32]) -> Self {
        let n = weights.len();
        let sum: f32 = weights.iter().sum();
        assert!(sum > 0.0, "Total weight must be positive");
        let mut alias = vec![0; n];
        // Scale so the average bucket weight is exactly 1.0.
        let mut prob: Vec<f32> = weights.iter().map(|&w| w * n as f32 / sum).collect();
        let mut small = Vec::new();
        let mut large = Vec::new();
        for (i, &p) in prob.iter().enumerate() {
            if p < 1.0 {
                small.push(i);
            } else {
                large.push(i);
            }
        }
        // Pair each under-full bucket with an over-full one.
        //
        // BUG FIX: the previous
        //   `while let (Some(l), Some(g)) = (small.pop(), large.pop())`
        // evaluated BOTH pops before matching, so when one stack emptied
        // first the element popped from the other stack was silently
        // discarded. That bucket then kept a stale prob < 1.0 with a default
        // alias of 0, skewing samples toward index 0. Check emptiness first.
        while !small.is_empty() && !large.is_empty() {
            let l = small.pop().expect("small checked non-empty");
            let g = large.pop().expect("large checked non-empty");
            alias[l] = g;
            // The over-full bucket donates (1 - prob[l]) of its mass to slot l.
            prob[g] = prob[g] + prob[l] - 1.0;
            if prob[g] < 1.0 {
                small.push(g);
            } else {
                large.push(g);
            }
        }
        // Leftovers exist only due to floating-point rounding; they always accept.
        while let Some(g) = large.pop() {
            prob[g] = 1.0;
        }
        while let Some(l) = small.pop() {
            prob[l] = 1.0;
        }
        Self { prob, alias }
    }

    /// Draws one index: pick a uniform slot, then accept it or take its alias.
    fn sample(&self, rng: &mut Random<scirs2_core::rngs::StdRng>) -> usize {
        let i = rng.gen_range(0..self.prob.len());
        let coin_flip = rng.random::<f32>();
        if coin_flip < self.prob[i] {
            i
        } else {
            self.alias[i]
        }
    }
}
/// Sampler that yields dataset indices grouped by a user-supplied key
/// function, optionally shuffling group order and/or order within groups.
#[derive(Debug)]
pub struct GroupedSampler<F> {
    // Indices partitioned by group, ordered by ascending group key.
    groups: Vec<Vec<usize>>,
    // Whether to shuffle the order in which groups are emitted.
    shuffle_groups: bool,
    // Whether to shuffle index order inside each group.
    shuffle_within_groups: bool,
    // Optional RNG seed for reproducible shuffling.
    generator: Option<u64>,
    // Ties the otherwise-unused grouping-function type `F` to the struct.
    _phantom: std::marker::PhantomData<F>,
}
impl<F> GroupedSampler<F>
where
    F: Fn(usize) -> usize + Send,
{
    /// Builds the sampler by bucketing every dataset index through
    /// `group_fn`, then ordering the buckets by ascending group key.
    pub fn new<D>(dataset: &D, group_fn: F) -> Self
    where
        D: crate::dataset::Dataset,
    {
        let mut buckets: HashMap<usize, Vec<usize>> = HashMap::new();
        for index in 0..dataset.len() {
            buckets.entry(group_fn(index)).or_default().push(index);
        }
        // Sort by key so group order is deterministic despite HashMap.
        let mut keyed: Vec<(usize, Vec<usize>)> = buckets.into_iter().collect();
        keyed.sort_by_key(|&(key, _)| key);
        let groups = keyed.into_iter().map(|(_, members)| members).collect();
        Self {
            groups,
            shuffle_groups: false,
            shuffle_within_groups: false,
            generator: None,
            _phantom: std::marker::PhantomData,
        }
    }

    /// Enables or disables shuffling of the group emission order.
    pub fn with_shuffle_groups(mut self, shuffle: bool) -> Self {
        self.shuffle_groups = shuffle;
        self
    }

    /// Enables or disables shuffling of indices inside each group.
    pub fn with_shuffle_within_groups(mut self, shuffle: bool) -> Self {
        self.shuffle_within_groups = shuffle;
        self
    }

    /// Sets a fixed RNG seed so repeated iterations are reproducible.
    pub fn with_generator(mut self, seed: u64) -> Self {
        self.generator = Some(seed);
        self
    }

    /// Number of distinct groups.
    pub fn num_groups(&self) -> usize {
        self.groups.len()
    }

    /// Size of every group, in group-key order.
    pub fn group_sizes(&self) -> Vec<usize> {
        self.groups.iter().map(Vec::len).collect()
    }

    /// Whether group order is shuffled.
    pub fn shuffles_groups(&self) -> bool {
        self.shuffle_groups
    }

    /// Whether indices within each group are shuffled.
    pub fn shuffles_within_groups(&self) -> bool {
        self.shuffle_within_groups
    }
}
impl<F: Send> Sampler for GroupedSampler<F> {
    type Iter = SamplerIterator;

    /// Emits all indices group by group, applying the configured shuffles.
    fn iter(&self) -> Self::Iter {
        let mut rng = rng_utils::create_rng(self.generator);
        let mut ordered = self.groups.clone();
        // Shuffle inside each group first, then (optionally) the group order,
        // so the two options compose independently.
        if self.shuffle_within_groups {
            ordered.iter_mut().for_each(|members| members.shuffle(&mut rng));
        }
        if self.shuffle_groups {
            ordered.shuffle(&mut rng);
        }
        let flattened: Vec<usize> = ordered.into_iter().flatten().collect();
        SamplerIterator::new(flattened)
    }

    /// Total number of indices across all groups.
    fn len(&self) -> usize {
        self.groups.iter().map(Vec::len).sum()
    }
}
/// Sampler that draws indices stratum-by-stratum (e.g. per class label),
/// allocating samples either proportionally to stratum size or as an equal
/// share per stratum, with a configurable per-stratum minimum.
#[derive(Debug, Clone)]
pub struct StratifiedSampler {
    /// Indices of each stratum, keyed by stratum id (e.g. class label).
    strata: HashMap<usize, Vec<usize>>,
    /// Proportional allocation (true) vs equal allocation per stratum (false).
    proportional: bool,
    /// Lower bound on the number of samples drawn from any stratum.
    min_samples_per_stratum: usize,
    /// Optional RNG seed for reproducible sampling.
    generator: Option<u64>,
}

impl StratifiedSampler {
    /// Builds strata from per-index class labels (`class_labels[i]` is the
    /// stratum of dataset index `i`).
    pub fn new(class_labels: Vec<usize>) -> Self {
        let mut strata: HashMap<usize, Vec<usize>> = HashMap::new();
        for (idx, &class) in class_labels.iter().enumerate() {
            strata.entry(class).or_default().push(idx);
        }
        Self {
            strata,
            proportional: true,
            min_samples_per_stratum: 1,
            generator: None,
        }
    }

    /// Builds the sampler from an explicit stratum-id → indices map.
    pub fn from_strata(strata: HashMap<usize, Vec<usize>>) -> Self {
        Self {
            strata,
            proportional: true,
            min_samples_per_stratum: 1,
            generator: None,
        }
    }

    /// Chooses proportional (default) or equal-share allocation.
    pub fn with_proportional(mut self, proportional: bool) -> Self {
        self.proportional = proportional;
        self
    }

    /// Sets the minimum number of samples drawn from every stratum.
    pub fn with_min_samples_per_stratum(mut self, min_samples: usize) -> Self {
        self.min_samples_per_stratum = min_samples;
        self
    }

    /// Sets a fixed RNG seed so repeated iterations are reproducible.
    pub fn with_generator(mut self, seed: u64) -> Self {
        self.generator = Some(seed);
        self
    }

    /// Number of distinct strata.
    pub fn num_strata(&self) -> usize {
        self.strata.len()
    }

    /// Map from stratum id to the number of indices it contains.
    pub fn stratum_sizes(&self) -> HashMap<usize, usize> {
        self.strata.iter().map(|(&k, v)| (k, v.len())).collect()
    }

    /// Whether proportional allocation is in effect.
    pub fn uses_proportional(&self) -> bool {
        self.proportional
    }

    /// Decides how many samples to draw from each stratum.
    ///
    /// Returns one entry per stratum; each value is at least
    /// `min_samples_per_stratum`.
    fn calculate_stratum_samples(&self, total_samples: usize) -> HashMap<usize, usize> {
        let mut stratum_samples = HashMap::new();
        // Guard: with no strata there is nothing to allocate, and the
        // branches below would otherwise divide by zero.
        if self.strata.is_empty() {
            return stratum_samples;
        }
        let total_stratum_size: usize = self.strata.values().map(|v| v.len()).sum();
        if self.proportional {
            if total_stratum_size == 0 {
                // All strata empty: fall back to the per-stratum minimum so
                // the returned map still covers every stratum id.
                for &stratum_id in self.strata.keys() {
                    stratum_samples.insert(stratum_id, self.min_samples_per_stratum);
                }
                return stratum_samples;
            }
            for (&stratum_id, indices) in &self.strata {
                // Floored integer-proportional share, never below the minimum.
                let share = (indices.len() * total_samples) / total_stratum_size;
                stratum_samples.insert(stratum_id, share.max(self.min_samples_per_stratum));
            }
        } else {
            let samples_per_stratum = total_samples / self.strata.len();
            for &stratum_id in self.strata.keys() {
                stratum_samples.insert(
                    stratum_id,
                    samples_per_stratum.max(self.min_samples_per_stratum),
                );
            }
        }
        stratum_samples
    }
}
impl Sampler for StratifiedSampler {
    type Iter = SamplerIterator;

    /// Draws the allocated number of samples from every stratum, then
    /// shuffles the combined result.
    fn iter(&self) -> Self::Iter {
        let total_samples: usize = self.strata.values().map(|v| v.len()).sum();
        let stratum_samples = self.calculate_stratum_samples(total_samples);
        let mut rng = rng_utils::create_rng(self.generator);
        // Visit strata in sorted key order: HashMap iteration order is
        // unspecified, which previously consumed the seeded RNG in a
        // platform/run-dependent order and made seeded runs irreproducible.
        let mut stratum_ids: Vec<usize> = self.strata.keys().copied().collect();
        stratum_ids.sort_unstable();
        let mut all_indices = Vec::new();
        for stratum_id in stratum_ids {
            let indices = &self.strata[&stratum_id];
            if indices.is_empty() {
                // Nothing to draw; also avoids `gen_range(0..0)` panicking
                // on the oversampling path below.
                continue;
            }
            let target_samples = stratum_samples[&stratum_id];
            let mut stratum_indices = indices.clone();
            stratum_indices.shuffle(&mut rng);
            if target_samples <= indices.len() {
                // Undersample: take a random subset.
                all_indices.extend(&stratum_indices[..target_samples]);
            } else {
                // Oversample: take everything, then draw the remainder with
                // replacement.
                all_indices.extend(&stratum_indices);
                for _ in indices.len()..target_samples {
                    let idx = rng.gen_range(0..indices.len());
                    all_indices.push(indices[idx]);
                }
            }
        }
        all_indices.shuffle(&mut rng);
        SamplerIterator::new(all_indices)
    }

    /// Total allocated samples across all strata (may exceed the dataset
    /// size when `min_samples_per_stratum` forces oversampling).
    fn len(&self) -> usize {
        let total_samples: usize = self.strata.values().map(|v| v.len()).sum();
        self.calculate_stratum_samples(total_samples).values().sum()
    }
}
/// Sampler that favors examples with high importance scores via a
/// temperature-controlled softmax over the scores.
#[derive(Debug, Clone)]
pub struct ImportanceSampler {
    /// Current per-index importance scores.
    importance_scores: Vec<f32>,
    /// Softmax temperature; lower values sharpen the distribution.
    temperature: f32,
    /// Optional RNG seed for reproducible sampling.
    generator: Option<u64>,
    /// Whether `update_importance_scores` blends in new scores.
    adaptive: bool,
    /// Exponential-moving-average factor in [0, 1] used when adaptive.
    update_rate: f32,
}

impl ImportanceSampler {
    /// Creates a sampler from per-index importance scores.
    ///
    /// # Panics
    /// Panics if `importance_scores` is empty.
    pub fn new(importance_scores: Vec<f32>) -> Self {
        assert!(
            !importance_scores.is_empty(),
            "Importance scores cannot be empty"
        );
        Self {
            importance_scores,
            temperature: 1.0,
            generator: None,
            adaptive: false,
            update_rate: 0.1,
        }
    }

    /// Sets the softmax temperature.
    ///
    /// # Panics
    /// Panics if `temperature` is not strictly positive.
    pub fn with_temperature(mut self, temperature: f32) -> Self {
        assert!(temperature > 0.0, "Temperature must be positive");
        self.temperature = temperature;
        self
    }

    /// Enables/disables adaptive score updates with the given EMA rate.
    ///
    /// # Panics
    /// Panics if `update_rate` is outside [0, 1].
    pub fn with_adaptive(mut self, adaptive: bool, update_rate: f32) -> Self {
        assert!(
            update_rate >= 0.0 && update_rate <= 1.0,
            "Update rate must be in [0, 1]"
        );
        self.adaptive = adaptive;
        self.update_rate = update_rate;
        self
    }

    /// Sets a fixed RNG seed so repeated iterations are reproducible.
    pub fn with_generator(mut self, seed: u64) -> Self {
        self.generator = Some(seed);
        self
    }

    /// The current importance scores.
    pub fn importance_scores(&self) -> &[f32] {
        &self.importance_scores
    }

    /// The softmax temperature.
    pub fn temperature(&self) -> f32 {
        self.temperature
    }

    /// Whether adaptive updates are enabled.
    pub fn is_adaptive(&self) -> bool {
        self.adaptive
    }

    /// Blends `new_scores` into the current scores via EMA; a no-op unless
    /// adaptive mode is on and the lengths match.
    pub fn update_importance_scores(&mut self, new_scores: Vec<f32>) {
        if self.adaptive && new_scores.len() == self.importance_scores.len() {
            for (old, &new) in self.importance_scores.iter_mut().zip(new_scores.iter()) {
                *old = (1.0 - self.update_rate) * *old + self.update_rate * new;
            }
        }
    }

    /// Softmax of `scores / temperature`, as a probability vector.
    fn compute_probabilities(&self) -> Vec<f32> {
        let scaled: Vec<f32> = self
            .importance_scores
            .iter()
            .map(|&score| score / self.temperature)
            .collect();
        // Subtract the maximum before exponentiating (standard log-sum-exp
        // stabilization): mathematically a no-op for softmax, but it prevents
        // `exp` overflowing to infinity for large scores, which previously
        // produced NaN/zero probabilities.
        let max_scaled = scaled.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let exp_scores: Vec<f32> = scaled.iter().map(|&s| (s - max_scaled).exp()).collect();
        let total: f32 = exp_scores.iter().sum();
        if total > 0.0 && total.is_finite() {
            exp_scores.iter().map(|&e| e / total).collect()
        } else {
            // Degenerate input (e.g. NaN scores): fall back to uniform.
            vec![1.0 / self.importance_scores.len() as f32; self.importance_scores.len()]
        }
    }
}
impl Sampler for ImportanceSampler {
    type Iter = SamplerIterator;

    /// Samples by delegating to a `WeightedRandomSampler` (without
    /// replacement) built from the softmax probabilities.
    fn iter(&self) -> Self::Iter {
        let base = WeightedRandomSampler::new(self.compute_probabilities(), false);
        let delegate = match self.generator {
            Some(seed) => base.with_generator(seed),
            None => base,
        };
        delegate.iter()
    }

    /// One emitted index per importance score.
    fn len(&self) -> usize {
        self.importance_scores.len()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal in-memory dataset whose item at index `i` is just `i`.
    struct MockDataset {
        size: usize,
    }
    impl crate::dataset::Dataset for MockDataset {
        type Item = usize;
        fn get(&self, index: usize) -> torsh_core::error::Result<Self::Item> {
            if index < self.size {
                Ok(index)
            } else {
                Err(torsh_core::error::TorshError::IndexOutOfBounds {
                    index,
                    size: self.size,
                })
            }
        }
        fn len(&self) -> usize {
            self.size
        }
    }

    // Accessors, epoch length, and index-range sanity for the weighted sampler.
    #[test]
    fn test_weighted_random_sampler() {
        let weights = vec![0.1, 0.3, 0.6];
        let sampler = WeightedRandomSampler::new(weights.clone(), true).with_generator(42);
        assert_eq!(sampler.len(), 3);
        assert_eq!(sampler.weights(), &weights);
        assert!(sampler.uses_replacement());
        assert_eq!(sampler.generator_seed(), Some(42));
        let indices: Vec<usize> = sampler.iter().collect();
        assert_eq!(indices.len(), 3);
        // Every sampled index must be in range.
        assert!(indices.iter().all(|&i| i < 3));
    }

    // Two samplers with the same seed must produce identical sequences.
    #[test]
    fn test_weighted_sampler_deterministic() {
        let weights = vec![1.0, 2.0, 3.0];
        let sampler1 = WeightedRandomSampler::new(weights.clone(), true).with_generator(123);
        let sampler2 = WeightedRandomSampler::new(weights, true).with_generator(123);
        let indices1: Vec<usize> = sampler1.iter().collect();
        let indices2: Vec<usize> = sampler2.iter().collect();
        assert_eq!(indices1, indices2);
    }

    // Statistical smoke test: heavier weights should be sampled more often.
    #[test]
    fn test_alias_table() {
        let weights = vec![1.0, 2.0, 3.0];
        let table = AliasTable::new(&weights);
        assert_eq!(table.prob.len(), 3);
        assert_eq!(table.alias.len(), 3);
        let mut rng = rng_utils::create_rng(Some(42));
        let mut counts = vec![0; 3];
        for _ in 0..1000 {
            let sample = table.sample(&mut rng);
            assert!(sample < 3);
            counts[sample] += 1;
        }
        // Expected frequency ordering follows the weight ordering.
        assert!(counts[2] > counts[1]);
        assert!(counts[1] > counts[0]);
    }

    // With no shuffling, groups are emitted in ascending key order.
    #[test]
    fn test_grouped_sampler() {
        let dataset = MockDataset { size: 12 };
        let group_fn = |idx: usize| idx % 3;
        let sampler = GroupedSampler::new(&dataset, group_fn)
            .with_shuffle_groups(false)
            .with_shuffle_within_groups(false);
        assert_eq!(sampler.len(), 12);
        assert_eq!(sampler.num_groups(), 3);
        assert_eq!(sampler.group_sizes(), vec![4, 4, 4]);
        let indices: Vec<usize> = sampler.iter().collect();
        assert_eq!(indices.len(), 12);
    }

    // Seeded shuffling is reproducible and still yields a full permutation.
    #[test]
    fn test_grouped_sampler_with_shuffling() {
        let dataset = MockDataset { size: 9 };
        let group_fn = |idx: usize| idx % 3;
        let sampler = GroupedSampler::new(&dataset, group_fn)
            .with_shuffle_groups(true)
            .with_shuffle_within_groups(true)
            .with_generator(42);
        let indices1: Vec<usize> = sampler.iter().collect();
        let indices2: Vec<usize> = sampler.iter().collect();
        // Same seed => same order across calls.
        assert_eq!(indices1, indices2);
        assert_eq!(indices1.len(), 9);
        let mut sorted_indices = indices1;
        sorted_indices.sort();
        // Shuffling must not drop or duplicate any index.
        assert_eq!(sorted_indices, (0..9).collect::<Vec<_>>());
    }

    // Proportional stratified sampling over three class labels.
    #[test]
    fn test_stratified_sampler() {
        let class_labels = vec![0, 0, 1, 1, 1, 2];
        let sampler = StratifiedSampler::new(class_labels)
            .with_proportional(true)
            .with_generator(42);
        assert_eq!(sampler.num_strata(), 3);
        assert!(sampler.uses_proportional());
        let stratum_sizes = sampler.stratum_sizes();
        assert_eq!(stratum_sizes[&0], 2);
        assert_eq!(stratum_sizes[&1], 3);
        assert_eq!(stratum_sizes[&2], 1);
        let indices: Vec<usize> = sampler.iter().collect();
        assert!(!indices.is_empty());
    }

    // Equal-share (balanced) allocation with a per-stratum minimum of 2.
    #[test]
    fn test_stratified_sampler_balanced() {
        let class_labels = vec![0, 0, 1, 1, 1, 2];
        let sampler = StratifiedSampler::new(class_labels)
            .with_proportional(false)
            .with_min_samples_per_stratum(2)
            .with_generator(42);
        assert!(!sampler.uses_proportional());
        let indices: Vec<usize> = sampler.iter().collect();
        assert!(!indices.is_empty());
    }

    // Construction from an explicit stratum-id -> indices map.
    #[test]
    fn test_stratified_sampler_from_strata() {
        let mut strata = HashMap::new();
        strata.insert(0, vec![0, 1]);
        strata.insert(1, vec![2, 3, 4]);
        strata.insert(2, vec![5]);
        let sampler = StratifiedSampler::from_strata(strata);
        assert_eq!(sampler.num_strata(), 3);
        let indices: Vec<usize> = sampler.iter().collect();
        assert!(!indices.is_empty());
    }

    // Accessors and epoch behavior for the importance sampler.
    #[test]
    fn test_importance_sampler() {
        let scores = vec![0.1, 0.8, 0.3, 0.9, 0.2];
        let sampler = ImportanceSampler::new(scores.clone())
            .with_temperature(1.0)
            .with_generator(42);
        assert_eq!(sampler.len(), 5);
        assert_eq!(sampler.importance_scores(), &scores);
        assert_eq!(sampler.temperature(), 1.0);
        assert!(!sampler.is_adaptive());
        let indices: Vec<usize> = sampler.iter().collect();
        assert_eq!(indices.len(), 5);
        assert!(indices.iter().all(|&i| i < 5));
    }

    // Smoke test: both extreme temperatures must produce valid epochs.
    #[test]
    fn test_importance_sampler_temperature() {
        let scores = vec![0.1, 1.0, 0.1];
        let low_temp_sampler = ImportanceSampler::new(scores.clone())
            .with_temperature(0.1)
            .with_generator(42);
        let high_temp_sampler = ImportanceSampler::new(scores)
            .with_temperature(10.0)
            .with_generator(42);
        let _low_indices: Vec<usize> = low_temp_sampler.iter().collect();
        let _high_indices: Vec<usize> = high_temp_sampler.iter().collect();
    }

    // Adaptive updates must blend new scores into every entry.
    #[test]
    fn test_importance_sampler_adaptive() {
        let scores = vec![0.1, 0.5, 0.3];
        let mut sampler = ImportanceSampler::new(scores)
            .with_adaptive(true, 0.2)
            .with_generator(42);
        assert!(sampler.is_adaptive());
        let original_scores = sampler.importance_scores().to_vec();
        let new_scores = vec![0.2, 0.8, 0.1];
        sampler.update_importance_scores(new_scores);
        let updated_scores = sampler.importance_scores().to_vec();
        assert_ne!(original_scores, updated_scores);
        for i in 0..3 {
            assert!(updated_scores[i] != original_scores[i]);
        }
    }

    // Constructor precondition: weights must be non-empty.
    #[test]
    #[should_panic(expected = "Weights vector cannot be empty")]
    fn test_weighted_sampler_empty_weights() {
        WeightedRandomSampler::new(vec![], true);
    }

    // Constructor precondition: at least one positive weight.
    #[test]
    #[should_panic(expected = "At least one weight must be positive")]
    fn test_weighted_sampler_zero_weights() {
        WeightedRandomSampler::new(vec![0.0, 0.0, 0.0], true);
    }

    // Constructor precondition: temperature strictly positive.
    #[test]
    #[should_panic(expected = "Temperature must be positive")]
    fn test_importance_sampler_zero_temperature() {
        let scores = vec![0.1, 0.2, 0.3];
        ImportanceSampler::new(scores).with_temperature(0.0);
    }

    // Constructor precondition: scores must be non-empty.
    #[test]
    #[should_panic(expected = "Importance scores cannot be empty")]
    fn test_importance_sampler_empty_scores() {
        ImportanceSampler::new(vec![]);
    }

    // Softmax output: sums to 1 and preserves the score ordering.
    #[test]
    fn test_importance_sampler_probabilities() {
        let scores = vec![1.0, 2.0, 3.0];
        let sampler = ImportanceSampler::new(scores).with_temperature(1.0);
        let probabilities = sampler.compute_probabilities();
        assert_eq!(probabilities.len(), 3);
        let sum: f32 = probabilities.iter().sum();
        assert!((sum - 1.0).abs() < 1e-6);
        assert!(probabilities[2] > probabilities[1]);
        assert!(probabilities[1] > probabilities[0]);
    }
}