use crate::error::{Result, Error};
use crate::dataframe::DataFrame;
use rand::Rng;
use rand::seq::SliceRandom;
use crate::utils::rand_compat::{thread_rng, GenRangeCompat};
use std::collections::HashMap;
/// Draws a random sample of rows from `df`.
///
/// `fraction` is the proportion of rows to keep (0 < fraction <= 1); the
/// sample size is rounded up. With `replace = true` row indices may repeat;
/// without replacement a shuffled prefix of all row indices is used.
///
/// # Errors
/// Returns an error when `fraction` is outside (0, 1], when `df` has no
/// rows, or (without replacement) when the sample would exceed the row count.
pub(crate) fn sample_impl(
    df: &DataFrame,
    fraction: f64,
    replace: bool
) -> Result<DataFrame> {
    if fraction <= 0.0 || fraction > 1.0 {
        return Err(Error::InvalidInput("Sampling fraction must be between 0 and 1".into()));
    }
    let n_rows = df.nrows();
    if n_rows == 0 {
        return Err(Error::EmptyData("Cannot sample from empty DataFrame".into()));
    }
    let sample_size = (n_rows as f64 * fraction).ceil() as usize;
    let mut rng = thread_rng();

    // Choose the row indices first, then gather every column by them.
    let indices: Vec<usize> = if replace {
        (0..sample_size).map(|_| rng.gen_range(0..n_rows)).collect()
    } else {
        if sample_size > n_rows {
            return Err(Error::InvalidInput(
                format!("Sample size ({}) cannot be larger than DataFrame size ({}) when sampling without replacement",
                sample_size, n_rows)
            ));
        }
        let mut pool: Vec<usize> = (0..n_rows).collect();
        pool.shuffle(&mut rng);
        pool.truncate(sample_size);
        pool
    };

    let mut sample_df = DataFrame::new();
    for col_name in df.columns() {
        let col = df.get_column(&col_name)?;
        sample_df.add_column(col_name.clone(), col.sample(&indices)?)?;
    }
    Ok(sample_df)
}
/// Generates `n_samples` bootstrap resamples of `data`.
///
/// Each resample has the same length as `data` and is drawn uniformly
/// at random with replacement.
///
/// # Errors
/// Returns an error when `data` is empty or `n_samples` is zero.
pub(crate) fn bootstrap_impl(
    data: &[f64],
    n_samples: usize
) -> Result<Vec<Vec<f64>>> {
    if data.is_empty() {
        return Err(Error::EmptyData("Cannot bootstrap from empty data".into()));
    }
    if n_samples == 0 {
        return Err(Error::InvalidInput("Number of bootstrap samples must be greater than 0".into()));
    }
    let n = data.len();
    let mut rng = thread_rng();
    // Each resample picks n indices uniformly with replacement.
    let bootstrap_samples: Vec<Vec<f64>> = (0..n_samples)
        .map(|_| (0..n).map(|_| data[rng.gen_range(0..n)]).collect())
        .collect();
    Ok(bootstrap_samples)
}
/// Samples rows from `df` stratum-by-stratum.
///
/// `fractions` maps each stratum value (read from `strata_column`) to the
/// proportion of that stratum's rows to keep; strata absent from the map
/// are skipped entirely. Per-stratum sizes are rounded up.
///
/// Note: strata are visited in `HashMap` iteration order, so the row order
/// of the result is not deterministic across runs.
///
/// # Errors
/// Fails when `strata_column` is missing, a fraction is outside (0, 1],
/// a stratum sample would exceed its size without replacement, or no rows
/// end up sampled at all.
pub fn stratified_sample(
    df: &DataFrame,
    strata_column: &str,
    fractions: &HashMap<String, f64>,
    replace: bool
) -> Result<DataFrame> {
    if !df.has_column(strata_column) {
        return Err(Error::InvalidColumn(format!("Column '{}' does not exist", strata_column)));
    }
    let strata_series = df.get_column(strata_column)?;
    let strata_values = strata_series.as_str()?;

    // Reject bad fractions up front, before any sampling happens.
    for (stratum, &fraction) in fractions {
        if fraction <= 0.0 || fraction > 1.0 {
            return Err(Error::InvalidInput(
                format!("Sampling fraction for stratum '{}' must be between 0 and 1", stratum)
            ));
        }
    }

    // Bucket row positions by their stratum label.
    let mut strata_indices: HashMap<String, Vec<usize>> = HashMap::new();
    for (row, stratum) in strata_values.iter().enumerate() {
        strata_indices.entry(stratum.to_string()).or_default().push(row);
    }

    let mut rng = thread_rng();
    let mut all_sampled_indices = Vec::new();
    for (stratum, indices) in strata_indices {
        // Strata without a requested fraction are skipped.
        let fraction = match fractions.get(&stratum) {
            Some(&f) => f,
            None => continue,
        };
        let n_rows = indices.len();
        let sample_size = (n_rows as f64 * fraction).ceil() as usize;
        if n_rows == 0 {
            continue;
        }
        if !replace && sample_size > n_rows {
            return Err(Error::InvalidInput(
                format!("Sample size ({}) cannot be larger than stratum size ({}) when sampling without replacement",
                sample_size, n_rows)
            ));
        }
        if replace {
            all_sampled_indices
                .extend((0..sample_size).map(|_| indices[rng.gen_range(0..n_rows)]));
        } else {
            let mut shuffled = indices;
            shuffled.shuffle(&mut rng);
            all_sampled_indices.extend(shuffled.into_iter().take(sample_size));
        }
    }

    if all_sampled_indices.is_empty() {
        return Err(Error::InvalidInput("No rows were sampled. Check if strata names match.".into()));
    }

    let mut sample_df = DataFrame::new();
    for col_name in df.columns() {
        let col = df.get_column(&col_name)?;
        sample_df.add_column(col_name.clone(), col.sample(&all_sampled_indices)?)?;
    }
    Ok(sample_df)
}
/// Computes a percentile bootstrap confidence interval for `statistic_fn`.
///
/// Resamples `data` `n_samples` times, evaluates the statistic on each
/// resample, sorts the results, and returns the empirical
/// `(alpha/2, 1 - alpha/2)` percentile bounds, where
/// `alpha = 1 - confidence_level`.
///
/// # Errors
/// Fails on empty data, fewer than 100 bootstrap samples, or a confidence
/// level outside (0, 1).
pub fn bootstrap_confidence_interval<F>(
    data: &[f64],
    n_samples: usize,
    statistic_fn: &F,
    confidence_level: f64
) -> Result<(f64, f64)>
where
    F: Fn(&[f64]) -> f64
{
    if data.is_empty() {
        return Err(Error::EmptyData("Bootstrap requires data".into()));
    }
    if n_samples < 100 {
        return Err(Error::InvalidInput("Recommended to use at least 100 bootstrap samples".into()));
    }
    if confidence_level <= 0.0 || confidence_level >= 1.0 {
        return Err(Error::InvalidInput("Confidence level must be between 0 and 1".into()));
    }
    let bootstrap_samples = bootstrap_impl(data, n_samples)?;
    let mut bootstrap_statistics: Vec<f64> = bootstrap_samples
        .iter()
        .map(|sample| statistic_fn(sample))
        .collect();
    // Incomparable values (NaN) are treated as equal so the sort cannot panic.
    bootstrap_statistics.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));

    let alpha = 1.0 - confidence_level;
    // BUGFIX: clamp both percentile indices into [0, n_samples - 1] and
    // force lower <= upper. The previous arithmetic could yield an inverted
    // interval (lower_index > upper_index) for small confidence levels,
    // e.g. n_samples = 100, confidence_level = 0.01 gave indices (50, 49).
    let lower_index = (((alpha / 2.0) * n_samples as f64).ceil() as usize).min(n_samples - 1);
    let upper_index = (((1.0 - alpha / 2.0) * n_samples as f64).floor() as usize)
        .saturating_sub(1)
        .clamp(lower_index, n_samples - 1);
    // Indices are clamped in range, so direct indexing cannot panic.
    Ok((bootstrap_statistics[lower_index], bootstrap_statistics[upper_index]))
}
/// Takes every k-th element of `data`, starting at position `offset`.
///
/// # Errors
/// Fails on empty data, `k == 0`, `offset >= k`, or when the stride never
/// lands inside `data` (offset beyond the last element).
pub fn systematic_sample<T: Clone>(
    data: &[T],
    k: usize,
    offset: usize
) -> Result<Vec<T>> {
    if data.is_empty() {
        return Err(Error::EmptyData("Cannot sample from empty data".into()));
    }
    if k == 0 {
        return Err(Error::InvalidInput("Sampling interval must be greater than 0".into()));
    }
    if offset >= k {
        return Err(Error::InvalidInput(
            format!("Offset must be between 0 and k-1 (k={})", k)
        ));
    }
    // Stride through the slice with an iterator instead of a manual index
    // loop; step_by(k) is safe because k > 0 was checked above.
    let sample: Vec<T> = data.iter().skip(offset).step_by(k).cloned().collect();
    if sample.is_empty() {
        return Err(Error::InvalidInput("No samples were selected".into()));
    }
    Ok(sample)
}
/// Draws `size` elements from `data` with probability proportional to
/// `weights`, via inverse-CDF sampling over the normalized cumulative
/// weights.
///
/// Without replacement, when a drawn index is already taken the next unused
/// index is substituted (scanning forward, then wrapping around).
/// NOTE(review): that substitution ignores the weights, so the
/// without-replacement distribution is only approximate and a zero-weight
/// item can be picked as a substitute. If `size > data.len()` without
/// replacement, the result is silently truncated to `data.len()` distinct
/// elements.
///
/// # Errors
/// Fails on empty data, mismatched lengths, any negative weight, or when
/// all weights are zero.
pub fn weighted_sample<T: Clone>(
    data: &[T],
    weights: &[f64],
    size: usize,
    replace: bool
) -> Result<Vec<T>> {
    if data.is_empty() {
        return Err(Error::EmptyData("Cannot sample from empty data".into()));
    }
    if weights.len() != data.len() {
        return Err(Error::DimensionMismatch(
            format!("Weights length ({}) must match data length ({})", weights.len(), data.len())
        ));
    }
    for &w in weights {
        if w < 0.0 {
            return Err(Error::InvalidInput("Weights must be non-negative".into()));
        }
    }
    if weights.iter().all(|&w| w == 0.0) {
        return Err(Error::InvalidInput("At least one weight must be positive".into()));
    }

    // Build the normalized cumulative distribution; pin the final entry to
    // exactly 1.0 to guard against floating-point rounding drift.
    let sum_weights: f64 = weights.iter().sum();
    let mut cum_weights = Vec::with_capacity(weights.len());
    let mut cumulative = 0.0;
    for &w in weights {
        cumulative += w / sum_weights;
        cum_weights.push(cumulative);
    }
    if let Some(last) = cum_weights.last_mut() {
        *last = 1.0;
    }

    let mut rng = thread_rng();
    let mut sample = Vec::with_capacity(size);
    let mut used_indices = std::collections::HashSet::new();
    for _ in 0..size {
        // Without replacement we can hand out at most data.len() elements.
        if !replace && used_indices.len() == data.len() {
            break;
        }
        // Inverse-CDF draw: first index whose cumulative weight covers r.
        let r = rng.gen::<f64>();
        let mut selected_idx = 0;
        while selected_idx < cum_weights.len() - 1 && r > cum_weights[selected_idx] {
            selected_idx += 1;
        }
        if !replace {
            if used_indices.contains(&selected_idx) {
                // Substitute the nearest unused index: forward scan first,
                // then wrap to the start.
                let mut found = false;
                for i in (selected_idx + 1)..data.len() {
                    if !used_indices.contains(&i) {
                        selected_idx = i;
                        found = true;
                        break;
                    }
                }
                if !found {
                    for i in 0..selected_idx {
                        if !used_indices.contains(&i) {
                            selected_idx = i;
                            found = true;
                            break;
                        }
                    }
                }
                if !found {
                    break;
                }
            }
            used_indices.insert(selected_idx);
        }
        sample.push(data[selected_idx].clone());
    }
    // BUGFIX: removed a dead, empty `if !replace && size > data.len() {}`
    // block that followed the loop in the original.
    if sample.is_empty() {
        return Err(Error::InvalidInput("No samples were selected".into()));
    }
    Ok(sample)
}
/// Estimates the standard error of `statistic_fn` via the bootstrap.
///
/// Draws `n_samples` bootstrap resamples, evaluates the statistic on each,
/// and returns the sample standard deviation of those statistics using the
/// unbiased (n - 1) denominator — the conventional bootstrap SE formula.
///
/// # Errors
/// Fails on empty data or fewer than 100 bootstrap samples.
pub fn bootstrap_standard_error<F>(
    data: &[f64],
    n_samples: usize,
    statistic_fn: &F
) -> Result<f64>
where
    F: Fn(&[f64]) -> f64
{
    if data.is_empty() {
        return Err(Error::EmptyData("Bootstrap requires data".into()));
    }
    if n_samples < 100 {
        return Err(Error::InvalidInput("Recommended to use at least 100 bootstrap samples".into()));
    }
    let bootstrap_samples = bootstrap_impl(data, n_samples)?;
    let bootstrap_statistics: Vec<f64> = bootstrap_samples
        .iter()
        .map(|sample| statistic_fn(sample))
        .collect();
    let mean: f64 = bootstrap_statistics.iter().sum::<f64>() / n_samples as f64;
    // BUGFIX: use Bessel's correction (n - 1). The original divided by n,
    // which biases the standard-error estimate low. The n_samples >= 100
    // guard above makes the n - 1 denominator always positive.
    let variance: f64 = bootstrap_statistics.iter()
        .map(|&x| (x - mean).powi(2))
        .sum::<f64>() / (n_samples - 1) as f64;
    Ok(variance.sqrt())
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dataframe::DataFrame;
    use crate::series::Series;

    #[test]
    fn test_sample_impl() {
        let mut df = DataFrame::new();
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        df.add_column("values".to_string(),
            Series::new(values, Some("values".to_string())).expect("operation should succeed")).expect("operation should succeed");
        // ceil(10 * 0.5) = 5 rows either way.
        let sample_with_replacement = sample_impl(&df, 0.5, true).expect("operation should succeed");
        assert_eq!(sample_with_replacement.nrows(), 5);
        let sample_without_replacement = sample_impl(&df, 0.5, false).expect("operation should succeed");
        assert_eq!(sample_without_replacement.nrows(), 5);
        // Fractions above 1.0 are rejected.
        let result_invalid = sample_impl(&df, 1.5, false);
        assert!(result_invalid.is_err());
    }

    #[test]
    fn test_bootstrap_impl() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let bootstrap_samples = bootstrap_impl(&data, 10).expect("operation should succeed");
        assert_eq!(bootstrap_samples.len(), 10);
        // Every resample matches the source length and only contains source values.
        for sample in bootstrap_samples {
            assert_eq!(sample.len(), data.len());
            for &value in &sample {
                assert!(data.contains(&value));
            }
        }
        let result_empty = bootstrap_impl(&[], 10);
        assert!(result_empty.is_err());
        let result_zero = bootstrap_impl(&data, 0);
        assert!(result_zero.is_err());
    }

    #[test]
    fn test_stratified_sample() {
        let mut df = DataFrame::new();
        let strata = vec!["A", "A", "A", "A", "B", "B", "B", "B", "C", "C"]
            .into_iter().map(|s| s.to_string()).collect::<Vec<_>>();
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        df.add_column("strata".to_string(),
            Series::new(strata, Some("strata".to_string())).expect("operation should succeed")).expect("operation should succeed");
        df.add_column("values".to_string(),
            Series::new(values, Some("values".to_string())).expect("operation should succeed")).expect("operation should succeed");
        let mut fractions = HashMap::new();
        fractions.insert("A".to_string(), 0.5);
        fractions.insert("B".to_string(), 0.75);
        fractions.insert("C".to_string(), 1.0);
        let stratified_sample = stratified_sample(&df, "strata", &fractions, false).expect("operation should succeed");
        let sampled_strata = stratified_sample.get_column("strata").expect("operation should succeed");
        // BUGFIX: the original used `ok_or_else(...)?` here, but `#[test]`
        // functions return `()`, so `?` does not compile; `expect` works
        // regardless of whether `as_str` yields Option or Result.
        let strata_values = sampled_strata.as_str()
            .expect("expected string column for strata");
        let mut counts = HashMap::new();
        for stratum in strata_values {
            *counts.entry(stratum.to_string()).or_insert(0) += 1;
        }
        // ceil(4 * 0.5) = 2, ceil(4 * 0.75) = 3, ceil(2 * 1.0) = 2.
        assert_eq!(*counts.get("A").unwrap_or(&0), 2);
        assert_eq!(*counts.get("B").unwrap_or(&0), 3);
        assert_eq!(*counts.get("C").unwrap_or(&0), 2);
    }

    #[test]
    fn test_systematic_sample() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        let sys_sample = systematic_sample(&data, 3, 0).expect("operation should succeed");
        assert_eq!(sys_sample, vec![1.0, 4.0, 7.0, 10.0]);
        let sys_sample_offset = systematic_sample(&data, 3, 1).expect("operation should succeed");
        assert_eq!(sys_sample_offset, vec![2.0, 5.0, 8.0]);
        let result_invalid_k = systematic_sample(&data, 0, 0);
        assert!(result_invalid_k.is_err());
        let result_invalid_offset = systematic_sample(&data, 3, 3);
        assert!(result_invalid_offset.is_err());
    }

    #[test]
    fn test_weighted_sample() {
        let data = vec![1, 2, 3, 4, 5];
        let equal_weights = vec![1.0, 1.0, 1.0, 1.0, 1.0];
        let sample_equal = weighted_sample(&data, &equal_weights, 10, true).expect("operation should succeed");
        assert_eq!(sample_equal.len(), 10);
        // With all mass on the first item, every draw (with replacement) is item 1.
        let biased_weights = vec![1.0, 0.0, 0.0, 0.0, 0.0];
        let sample_biased = weighted_sample(&data, &biased_weights, 5, true).expect("operation should succeed");
        assert_eq!(sample_biased, vec![1, 1, 1, 1, 1]);
        // Without replacement, every drawn element is distinct.
        let sample_no_replace = weighted_sample(&data, &equal_weights, 5, false).expect("operation should succeed");
        assert_eq!(sample_no_replace.len(), 5);
        let mut unique_items = std::collections::HashSet::new();
        for item in &sample_no_replace {
            unique_items.insert(item);
        }
        assert_eq!(unique_items.len(), 5);
        let invalid_weights = vec![1.0, 1.0];
        let result_invalid = weighted_sample(&data, &invalid_weights, 3, true);
        assert!(result_invalid.is_err());
        let negative_weights = vec![-1.0, 1.0, 1.0, 1.0, 1.0];
        let result_negative = weighted_sample(&data, &negative_weights, 3, true);
        assert!(result_negative.is_err());
    }

    #[test]
    fn test_bootstrap_confidence_interval() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        let mean_fn = |x: &[f64]| x.iter().sum::<f64>() / x.len() as f64;
        let (lower, upper) = bootstrap_confidence_interval(&data, 1000, &mean_fn, 0.95).expect("operation should succeed");
        // The 95% CI should bracket the true mean (5.5) with 1000 resamples.
        assert!(lower <= 5.5 && upper >= 5.5);
        assert!(lower < upper);
        let result_empty = bootstrap_confidence_interval(&[], 1000, &mean_fn, 0.95);
        assert!(result_empty.is_err());
        let result_invalid_cl = bootstrap_confidence_interval(&data, 1000, &mean_fn, 1.1);
        assert!(result_invalid_cl.is_err());
    }

    #[test]
    fn test_bootstrap_standard_error() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        let mean_fn = |x: &[f64]| x.iter().sum::<f64>() / x.len() as f64;
        let std_error = bootstrap_standard_error(&data, 1000, &mean_fn).expect("operation should succeed");
        assert!(std_error > 0.0);
        // Theoretical SE of the mean here is sigma/sqrt(n) ~ 0.91.
        assert!((std_error - 0.91).abs() < 0.2);
        let result_empty = bootstrap_standard_error(&[], 1000, &mean_fn);
        assert!(result_empty.is_err());
    }
}