use crate::error::{DatasetsError, Result};
use scirs2_core::ndarray::Array1;
use scirs2_core::random::prelude::*;
use scirs2_core::random::prelude::*;
use scirs2_core::random::rngs::StdRng;
use scirs2_core::random::seq::SliceRandom;
use scirs2_core::random::Uniform;
use std::collections::HashMap;
/// Draws `sample_size` indices from the range `0..n_samples`.
///
/// When `replace` is `true`, each index is drawn independently and uniformly,
/// so duplicates are possible. When `replace` is `false`, the full index range
/// is shuffled and its first `sample_size` entries are returned, so every index
/// appears at most once.
///
/// `Some(seed)` makes the result reproducible; `None` seeds a fresh `StdRng`
/// from the thread-local RNG.
///
/// # Errors
///
/// Returns [`DatasetsError::InvalidFormat`] if `n_samples` or `sample_size` is
/// zero, or if `sample_size > n_samples` while `replace` is `false`.
#[allow(dead_code)]
pub fn random_sample(
    n_samples: usize,
    sample_size: usize,
    replace: bool,
    random_seed: Option<u64>,
) -> Result<Vec<usize>> {
    if n_samples == 0 {
        return Err(DatasetsError::InvalidFormat(
            "Number of samples must be > 0".to_string(),
        ));
    }
    if sample_size == 0 {
        return Err(DatasetsError::InvalidFormat(
            "Sample size must be > 0".to_string(),
        ));
    }
    if !replace && sample_size > n_samples {
        return Err(DatasetsError::InvalidFormat(format!(
            "Cannot sample {sample_size} items from {n_samples} without replacement"
        )));
    }

    let mut rng = match random_seed {
        Some(seed) => StdRng::seed_from_u64(seed),
        None => {
            let mut r = thread_rng();
            StdRng::seed_from_u64(r.next_u64())
        }
    };

    if replace {
        // Build the distribution once rather than per draw. `n_samples > 0` was
        // checked above, so the range is non-empty and construction cannot fail.
        let dist = Uniform::new(0, n_samples)
            .expect("n_samples > 0 guarantees a non-empty uniform range");
        Ok((0..sample_size).map(|_| rng.sample(&dist)).collect())
    } else {
        // Full shuffle (not a partial one) so seeded results stay identical to
        // earlier versions of this function.
        let mut available: Vec<usize> = (0..n_samples).collect();
        available.shuffle(&mut rng);
        Ok(available[..sample_size].to_vec())
    }
}
/// Draws a class-balanced sample of indices from `targets`.
///
/// Each target value is rounded to the nearest integer and used as a class
/// label. `sample_size` is split as evenly as possible across the distinct
/// classes: every class gets `sample_size / n_classes` indices, and the first
/// `sample_size % n_classes` classes (in ascending label order) get one extra.
/// When `sample_size` is smaller than the number of classes, the trailing
/// classes receive a quota of zero and contribute nothing. Within each class,
/// indices are drawn without replacement; the combined result is shuffled.
///
/// `Some(seed)` makes the result reproducible; `None` seeds a fresh `StdRng`
/// from the thread-local RNG.
///
/// # Errors
///
/// Returns [`DatasetsError::InvalidFormat`] if `targets` is empty, if
/// `sample_size` is zero or exceeds `targets.len()`, or if some class has fewer
/// members than its quota.
#[allow(dead_code)]
pub fn stratified_sample(
    targets: &Array1<f64>,
    sample_size: usize,
    random_seed: Option<u64>,
) -> Result<Vec<usize>> {
    if targets.is_empty() {
        return Err(DatasetsError::InvalidFormat(
            "Targets array cannot be empty".to_string(),
        ));
    }
    if sample_size == 0 {
        return Err(DatasetsError::InvalidFormat(
            "Sample size must be > 0".to_string(),
        ));
    }
    if sample_size > targets.len() {
        return Err(DatasetsError::InvalidFormat(format!(
            "Cannot sample {} items from {} total samples",
            sample_size,
            targets.len()
        )));
    }

    let mut class_indices: HashMap<i64, Vec<usize>> = HashMap::new();
    for (i, &target) in targets.iter().enumerate() {
        // Float-to-int `as` saturates: NaN maps to 0 and ±inf to i64::MIN/MAX.
        class_indices.entry(target.round() as i64).or_default().push(i);
    }

    let mut rng = match random_seed {
        Some(seed) => StdRng::seed_from_u64(seed),
        None => {
            let mut r = thread_rng();
            StdRng::seed_from_u64(r.next_u64())
        }
    };

    // Visit classes in ascending label order so quota assignment and RNG
    // consumption are deterministic for a given seed (HashMap order is not).
    let mut classes: Vec<(i64, Vec<usize>)> = class_indices.into_iter().collect();
    classes.sort_unstable_by_key(|entry| entry.0);

    let n_classes = classes.len();
    let base_samples_per_class = sample_size / n_classes;
    let remainder = sample_size % n_classes;

    let mut stratified_indices = Vec::with_capacity(sample_size);
    for (i, (class, members)) in classes.iter().enumerate() {
        let samples_for_this_class = base_samples_per_class + usize::from(i < remainder);
        if samples_for_this_class == 0 {
            // `random_sample` rejects a zero sample size. Quotas never increase
            // with `i`, so every remaining class would also get zero.
            break;
        }
        if samples_for_this_class > members.len() {
            return Err(DatasetsError::InvalidFormat(format!(
                "Class {} has only {} samples but needs {} for stratified sampling",
                class,
                members.len(),
                samples_for_this_class
            )));
        }
        let picked = random_sample(
            members.len(),
            samples_for_this_class,
            false,
            Some(rng.next_u64()),
        )?;
        // `picked` holds positions within this class; map them back to rows of `targets`.
        stratified_indices.extend(picked.into_iter().map(|pos| members[pos]));
    }

    stratified_indices.shuffle(&mut rng);
    Ok(stratified_indices)
}
/// Draws indices with probability proportional to `weights`.
///
/// With `replace == true`, each draw is independent. With `replace == false`,
/// a drawn entry has its weight set to zero so it cannot be drawn again.
/// Entries whose weight is zero are never selected. When sampling without
/// replacement and fewer than `sample_size` entries have positive weight,
/// sampling stops once they are exhausted. The returned vector is then shorter
/// than `sample_size`.
///
/// `Some(seed)` makes the result reproducible; `None` seeds a fresh `StdRng`
/// from the thread-local RNG.
///
/// # Errors
///
/// Returns [`DatasetsError::InvalidFormat`] in these cases:
/// - `weights` is empty.
/// - `sample_size` is zero.
/// - `sample_size > weights.len()` while `replace` is `false`.
/// - Any weight is negative or non-finite.
/// - The weights do not sum to a positive, finite value.
#[allow(dead_code)]
pub fn importance_sample(
    weights: &Array1<f64>,
    sample_size: usize,
    replace: bool,
    random_seed: Option<u64>,
) -> Result<Vec<usize>> {
    if weights.is_empty() {
        return Err(DatasetsError::InvalidFormat(
            "Weights array cannot be empty".to_string(),
        ));
    }
    if sample_size == 0 {
        return Err(DatasetsError::InvalidFormat(
            "Sample size must be > 0".to_string(),
        ));
    }
    if !replace && sample_size > weights.len() {
        return Err(DatasetsError::InvalidFormat(format!(
            "Cannot sample {} items from {} without replacement",
            sample_size,
            weights.len()
        )));
    }
    for &weight in weights.iter() {
        if weight < 0.0 {
            return Err(DatasetsError::InvalidFormat(
                "All weights must be non-negative".to_string(),
            ));
        }
        // NaN slips past `< 0.0`, and NaN/inf would later make the
        // `random_range` call below panic.
        if !weight.is_finite() {
            return Err(DatasetsError::InvalidFormat(
                "All weights must be finite".to_string(),
            ));
        }
    }
    let weight_sum: f64 = weights.sum();
    if weight_sum <= 0.0 {
        return Err(DatasetsError::InvalidFormat(
            "Sum of weights must be positive".to_string(),
        ));
    }
    // Individually finite weights can still overflow to +inf when summed.
    if !weight_sum.is_finite() {
        return Err(DatasetsError::InvalidFormat(
            "Sum of weights must be finite".to_string(),
        ));
    }

    let mut rng = match random_seed {
        Some(seed) => StdRng::seed_from_u64(seed),
        None => {
            let mut r = thread_rng();
            StdRng::seed_from_u64(r.next_u64())
        }
    };

    let mut indices = Vec::with_capacity(sample_size);
    // Drawn entries are zeroed (not removed) when sampling without replacement,
    // so positions in `available_weights` always equal the original indices.
    let mut available_weights = weights.clone();
    for _ in 0..sample_size {
        let current_sum = available_weights.sum();
        if current_sum <= 0.0 {
            // Every positive-weight entry has already been drawn.
            break;
        }
        let random_value = rng.random_range(0.0..current_sum);

        // Strict `<` so zero-weight entries, which do not advance the running
        // total, can never be selected.
        let mut cumulative_weight = 0.0_f64;
        let selected_idx = available_weights
            .iter()
            .position(|&w| {
                cumulative_weight += w;
                random_value < cumulative_weight
            })
            // The sequential running total can end a few ULPs below
            // `current_sum`; fall back to the last positive-weight entry
            // instead of silently picking index 0.
            .or_else(|| {
                available_weights
                    .iter()
                    .enumerate()
                    .filter(|&(_, &w)| w > 0.0)
                    .map(|(i, _)| i)
                    .last()
            })
            .expect("current_sum > 0 implies at least one positive weight");

        indices.push(selected_idx);
        if !replace {
            available_weights[selected_idx] = 0.0;
        }
    }
    Ok(indices)
}
/// Draws one bootstrap sample: `n_bootstrap_samples` indices taken uniformly
/// *with* replacement from `0..n_samples`.
///
/// This delegates to [`random_sample`] with replacement enabled, forwarding
/// `random_seed` unchanged.
///
/// # Errors
///
/// Fails when `n_samples` or `n_bootstrap_samples` is zero.
#[allow(dead_code)]
pub fn bootstrap_sample(
    n_samples: usize,
    n_bootstrap_samples: usize,
    random_seed: Option<u64>,
) -> Result<Vec<usize>> {
    // Bootstrapping is, by definition, resampling with replacement.
    const WITH_REPLACEMENT: bool = true;
    let draws = n_bootstrap_samples;
    random_sample(n_samples, draws, WITH_REPLACEMENT, random_seed)
}
/// Draws `n_bootstrap_rounds` bootstrap samples. Each one is `sample_size`
/// indices taken uniformly with replacement from `0..n_samples`.
///
/// One master `StdRng` (seeded from `random_seed`, or from the thread-local RNG
/// when `None`) supplies a fresh seed for each round. The whole set of rounds
/// is therefore reproducible from a single seed.
///
/// # Errors
///
/// Returns [`DatasetsError::InvalidFormat`] if `n_bootstrap_rounds` is zero. If
/// `n_samples` or `sample_size` is zero, it propagates the error returned by
/// [`random_sample`].
#[allow(dead_code)]
pub fn multiple_bootstrap_samples(
    n_samples: usize,
    sample_size: usize,
    n_bootstrap_rounds: usize,
    random_seed: Option<u64>,
) -> Result<Vec<Vec<usize>>> {
    if n_bootstrap_rounds == 0 {
        return Err(DatasetsError::InvalidFormat(
            "Number of bootstrap rounds must be > 0".to_string(),
        ));
    }

    let mut rng = match random_seed {
        Some(seed) => StdRng::seed_from_u64(seed),
        None => {
            let mut r = thread_rng();
            StdRng::seed_from_u64(r.next_u64())
        }
    };

    // Rounds are generated in order, and collecting into `Result` stops at the
    // first error, exactly as an explicit loop with `?` would.
    (0..n_bootstrap_rounds)
        .map(|_| random_sample(n_samples, sample_size, true, Some(rng.next_u64())))
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;
    use std::collections::HashSet;

    #[test]
    fn test_random_sample_without_replacement() {
        let indices = random_sample(10, 5, false, Some(42)).expect("Operation failed");
        assert_eq!(indices.len(), 5);
        assert!(indices.iter().all(|&i| i < 10));
        let unique_indices: HashSet<_> = indices.iter().cloned().collect();
        assert_eq!(unique_indices.len(), 5);
    }

    #[test]
    fn test_random_sample_with_replacement() {
        let indices = random_sample(5, 10, true, Some(42)).expect("Operation failed");
        assert_eq!(indices.len(), 10);
        assert!(indices.iter().all(|&i| i < 5));
        // Ten draws from five values must repeat at least one (pigeonhole).
        let unique_indices: HashSet<_> = indices.iter().cloned().collect();
        assert!(unique_indices.len() < indices.len());
    }

    #[test]
    fn test_random_sample_is_reproducible_with_seed() {
        for &replace in &[false, true] {
            let first = random_sample(50, 10, replace, Some(7)).expect("Operation failed");
            let second = random_sample(50, 10, replace, Some(7)).expect("Operation failed");
            assert_eq!(first, second);
        }
    }

    #[test]
    fn test_random_sample_invalid_params() {
        assert!(random_sample(0, 5, false, None).is_err());
        assert!(random_sample(10, 0, false, None).is_err());
        assert!(random_sample(5, 10, false, None).is_err());
    }

    #[test]
    fn test_stratified_sample() {
        let targets = array![0.0, 0.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0];
        let indices = stratified_sample(&targets, 6, Some(42)).expect("Operation failed");
        assert_eq!(indices.len(), 6);

        // Classes are disjoint and each is sampled without replacement.
        let unique_indices: HashSet<_> = indices.iter().cloned().collect();
        assert_eq!(unique_indices.len(), 6);

        // 6 samples over 3 classes -> exactly 2 per class.
        let mut class_counts: HashMap<i32, usize> = HashMap::new();
        for &idx in &indices {
            *class_counts.entry(targets[idx] as i32).or_insert(0) += 1;
        }
        assert_eq!(class_counts.len(), 3);
        assert!(class_counts.values().all(|&count| count == 2));
    }

    #[test]
    fn test_stratified_sample_insufficient_samples() {
        let targets = array![0.0, 1.0];
        assert!(stratified_sample(&targets, 4, Some(42)).is_err());
    }

    #[test]
    fn test_importance_sample() {
        let weights = array![0.1, 0.1, 0.1, 0.8, 0.9, 1.0];
        let indices = importance_sample(&weights, 3, false, Some(42)).expect("Operation failed");
        assert_eq!(indices.len(), 3);
        assert!(indices.iter().all(|&i| i < 6));
        let unique_indices: HashSet<_> = indices.iter().cloned().collect();
        assert_eq!(unique_indices.len(), 3);
    }

    #[test]
    fn test_importance_sample_is_reproducible_with_seed() {
        let weights = array![0.1, 0.1, 0.1, 0.8, 0.9, 1.0];
        let first = importance_sample(&weights, 4, true, Some(3)).expect("Operation failed");
        let second = importance_sample(&weights, 4, true, Some(3)).expect("Operation failed");
        assert_eq!(first, second);
    }

    #[test]
    fn test_importance_sample_skips_zero_weights() {
        let weights = array![0.0, 0.0, 5.0];
        let indices = importance_sample(&weights, 10, true, Some(42)).expect("Operation failed");
        assert_eq!(indices.len(), 10);
        assert!(indices.iter().all(|&i| i == 2));
    }

    #[test]
    fn test_importance_sample_negative_weights() {
        let weights = array![0.5, -0.1, 0.3];
        assert!(importance_sample(&weights, 2, false, None).is_err());
    }

    #[test]
    fn test_importance_sample_zero_weights() {
        let weights = array![0.0, 0.0, 0.0];
        assert!(importance_sample(&weights, 2, false, None).is_err());
    }

    #[test]
    fn test_bootstrap_sample() {
        let indices = bootstrap_sample(20, 20, Some(42)).expect("Operation failed");
        assert_eq!(indices.len(), 20);
        assert!(indices.iter().all(|&i| i < 20));
        // With a fixed seed this is deterministic; an all-distinct bootstrap of
        // 20 draws from 20 values would be astronomically unlikely anyway.
        let unique_indices: HashSet<_> = indices.iter().cloned().collect();
        assert!(unique_indices.len() < 20);
    }

    #[test]
    fn test_multiple_bootstrap_samples() {
        let samples = multiple_bootstrap_samples(10, 8, 5, Some(42)).expect("Operation failed");
        assert_eq!(samples.len(), 5);
        assert!(samples.iter().all(|sample| sample.len() == 8));
        assert!(samples.iter().all(|sample| sample.iter().all(|&i| i < 10)));
        // Each round is seeded independently from the master RNG.
        assert_ne!(samples[0], samples[1]);
    }

    #[test]
    fn test_multiple_bootstrap_samples_invalid_params() {
        assert!(multiple_bootstrap_samples(10, 10, 0, None).is_err());
    }
}