use crate::dataframe::DataFrame;
use crate::error::{Error, Result};
use rand::prelude::*;
use std::collections::HashMap;
/// Draws a random sample of rows from `df`.
///
/// The number of rows drawn is `ceil(row_count * fraction)`. An empty input
/// frame yields an empty result.
///
/// Only columns that can be read through `get_column::<String>` are copied
/// into the result; any other column is skipped.
///
/// # Arguments
/// * `df` - frame to sample from.
/// * `fraction` - sampling rate; must be finite and greater than zero. Values
///   above `1.0` are only accepted together with `replace = true`.
/// * `replace` - when `true`, rows are drawn independently and may repeat;
///   when `false`, every row is drawn at most once.
///
/// # Errors
/// * `Error::InvalidValue` if `fraction` is NaN, infinite, zero or negative.
/// * `Error::InvalidOperation` if `replace` is `false` and the requested
///   sample size exceeds the number of rows.
/// * Any error reported while building the sampled series or adding it to
///   the result frame.
pub(crate) fn sample_impl(df: &DataFrame, fraction: f64, replace: bool) -> Result<DataFrame> {
    // NaN compares false against everything, so it must be rejected explicitly;
    // otherwise it would quietly turn into a sample size of zero.
    if fraction.is_nan() || fraction <= 0.0 {
        return Err(Error::InvalidValue(
            "Sample rate must be a positive value".into(),
        ));
    }
    // An infinite rate would saturate the sample size to `usize::MAX` below.
    if fraction.is_infinite() {
        return Err(Error::InvalidValue(
            "Sample rate must be a finite value".into(),
        ));
    }

    let n_rows = df.row_count();
    if n_rows == 0 {
        return Ok(DataFrame::new());
    }

    let sample_size = (n_rows as f64 * fraction).ceil() as usize;
    if !replace && sample_size > n_rows {
        return Err(Error::InvalidOperation(
            "For sampling without replacement, sample size must not exceed original data size"
                .into(),
        ));
    }

    let mut rng = rand::rng();
    let indices: Vec<usize> = if replace {
        (0..sample_size)
            .map(|_| rng.random_range(0..n_rows))
            .collect()
    } else {
        // Shuffle every row index and keep the leading `sample_size` of them.
        let mut shuffled: Vec<usize> = (0..n_rows).collect();
        shuffled.shuffle(&mut rng);
        shuffled.truncate(sample_size);
        shuffled
    };

    let mut result = DataFrame::new();
    for col_name in df.column_names() {
        if let Ok(col) = df.get_column::<String>(&col_name) {
            let values = col.values();
            // Indices beyond this column's length are skipped rather than panicking.
            let sampled_values: Vec<String> = indices
                .iter()
                .filter_map(|&idx| values.get(idx).cloned())
                .collect();
            if !sampled_values.is_empty() {
                // Propagate construction failures to the caller instead of panicking.
                let series = crate::series::Series::new(sampled_values, Some(col_name.clone()))?;
                result.add_column(col_name.to_string(), series)?;
            }
        }
    }
    Ok(result)
}
/// Produces `n_samples` bootstrap resamples of `data`.
///
/// Each resample has the same length as `data` and is filled by drawing
/// elements of `data` uniformly at random, with replacement.
///
/// # Errors
/// * `Error::EmptyData` if `data` is empty.
/// * `Error::InvalidValue` if `n_samples` is zero.
pub(crate) fn bootstrap_impl(data: &[f64], n_samples: usize) -> Result<Vec<Vec<f64>>> {
    if data.is_empty() {
        return Err(Error::EmptyData("Bootstrap requires data".into()));
    }
    if n_samples == 0 {
        return Err(Error::InvalidValue(
            "Number of samples must be positive".into(),
        ));
    }

    let len = data.len();
    let mut rng = rand::rng();
    let resamples: Vec<Vec<f64>> = (0..n_samples)
        .map(|_| {
            // One resample: `len` independent draws from `data`.
            let mut draw = Vec::with_capacity(len);
            for _ in 0..len {
                draw.push(data[rng.random_range(0..len)]);
            }
            draw
        })
        .collect();
    Ok(resamples)
}
pub fn stratified_sample_impl(
df: &DataFrame,
strata_column: &str,
fraction: f64,
replace: bool,
) -> Result<DataFrame> {
if !df.contains_column(strata_column) {
return Err(Error::ColumnNotFound(strata_column.to_string()));
}
if fraction <= 0.0 {
return Err(Error::InvalidValue(
"Sample rate must be a positive value".into(),
));
}
let strata_col = match df.get_column::<String>(strata_column) {
Ok(col) => col,
Err(_) => return Err(Error::ColumnNotFound(strata_column.to_string())),
};
let mut strata_values = Vec::new();
for value in strata_col.values() {
if !strata_values.contains(value) {
strata_values.push(value.clone());
}
}
let mut strata_indices: HashMap<String, Vec<usize>> = HashMap::new();
for (i, value) in strata_col.values().iter().enumerate() {
strata_indices
.entry(value.clone())
.or_insert_with(Vec::new)
.push(i);
}
let mut all_sample_indices = Vec::new();
for (_, indices) in strata_indices.iter() {
let sample_size = (indices.len() as f64 * fraction).ceil() as usize;
if sample_size == 0 {
continue;
}
let mut rng = rand::rng();
if replace {
for _ in 0..sample_size {
let idx = indices[rng.random_range(0..indices.len())];
all_sample_indices.push(idx);
}
} else {
if sample_size > indices.len() {
return Err(Error::InvalidOperation(
"For sampling without replacement, sample size must not exceed stratum size"
.into(),
));
}
let mut sampled_indices = indices.clone();
sampled_indices.shuffle(&mut rng);
all_sample_indices.extend_from_slice(&sampled_indices[0..sample_size]);
}
}
all_sample_indices.sort();
let mut result = DataFrame::new();
for col_name in df.column_names() {
if let Ok(col) = df.get_column::<String>(&col_name) {
let sampled_values: Vec<String> = all_sample_indices
.iter()
.filter_map(|&idx| col.values().get(idx).cloned())
.collect();
if !sampled_values.is_empty() {
let series =
crate::series::Series::new(sampled_values, Some(col_name.clone())).unwrap();
result.add_column(col_name.to_string(), series).unwrap();
}
}
}
Ok(result)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dataframe::DataFrame;
    use crate::series::Series;

    /// Exercises the accepted and rejected fraction/replacement combinations
    /// of `sample_impl` on a ten-row frame.
    #[test]
    fn test_simple_sample() {
        let mut df = DataFrame::new();
        let data = Series::new(
            vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
            Some("data".to_string()),
        )
        .unwrap();
        df.add_column("data".to_string(), data).unwrap();

        // NOTE(review): only success or failure is checked here; the contents
        // of the sampled frames are not asserted.
        sample_impl(&df, 0.5, false).expect("50% sample without replacement should succeed");
        sample_impl(&df, 0.3, true).expect("30% sample with replacement should succeed");
        sample_impl(&df, 2.0, true).expect("200% sample with replacement should succeed");

        // Oversampling is impossible when rows cannot repeat.
        let result = sample_impl(&df, 2.0, false);
        assert!(result.is_err());
    }

    /// Zero and negative sampling rates are rejected before any data is read.
    #[test]
    fn test_sample_rejects_non_positive_fraction() {
        let df = DataFrame::new();
        assert!(sample_impl(&df, 0.0, false).is_err());
        assert!(sample_impl(&df, -0.5, true).is_err());
    }

    /// Every bootstrap resample has the input's length and contains only
    /// values taken from the input.
    #[test]
    fn test_bootstrap() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let bootstrap_samples = bootstrap_impl(&data, 10).unwrap();
        assert_eq!(bootstrap_samples.len(), 10);
        for sample in &bootstrap_samples {
            assert_eq!(sample.len(), data.len());
            for value in sample {
                assert!(data.contains(value));
            }
        }
    }

    /// Empty input and a zero resample count are both rejected.
    #[test]
    fn test_bootstrap_rejects_invalid_input() {
        assert!(bootstrap_impl(&[], 3).is_err());
        assert!(bootstrap_impl(&[1.0, 2.0], 0).is_err());
    }

    /// Asking for strata on a column that does not exist is an error.
    #[test]
    fn test_stratified_sample_missing_column() {
        let df = DataFrame::new();
        assert!(stratified_sample_impl(&df, "missing", 0.5, false).is_err());
    }
}