use crate::error::{DatasetsError, Result};
use scirs2_core::ndarray::{Array1, Array2};
/// Minimal 64-bit linear congruential generator used for deterministic,
/// dependency-free shuffling. Multiplier/increment are Knuth's MMIX
/// constants.
struct Lcg64 {
    state: u64,
}
impl Lcg64 {
    /// Seed the generator; the seed is offset by one before use.
    fn new(seed: u64) -> Self {
        Self {
            state: seed.wrapping_add(1),
        }
    }
    /// Advance the LCG one step and return the new 64-bit state.
    fn next_u64(&mut self) -> u64 {
        let advanced = self
            .state
            .wrapping_mul(6_364_136_223_846_793_005)
            .wrapping_add(1_442_695_040_888_963_407);
        self.state = advanced;
        advanced
    }
    /// Value in `0..n` (`0` when `n == 0`). Plain modulo reduction; the
    /// bias is negligible for n far below 2^64.
    fn next_usize(&mut self, n: usize) -> usize {
        match n {
            0 => 0,
            _ => (self.next_u64() % n as u64) as usize,
        }
    }
    /// Float in `[0, 1)` built from the top 53 bits of the next output.
    fn next_f64(&mut self) -> f64 {
        (self.next_u64() >> 11) as f64 / (1u64 << 53) as f64
    }
}
/// How samples are assigned to shards.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Default)]
pub enum ShardStrategy {
    /// Contiguous index ranges (optionally over a shuffled order).
    #[default]
    Index,
    /// Round-robin by sample index (`index % n_shards`).
    Hash,
    /// Keep each class's proportion equal across shards; requires labels,
    /// so `ShardedDataset::new` rejects it in favor of `new_stratified`.
    Stratified {
        /// Name of the label column the caller stratifies on.
        label_column: String,
    },
    /// Target a maximum shard size in bytes.
    /// NOTE(review): not implemented yet — `ShardedDataset::new` currently
    /// ignores the size and falls back to index sharding.
    Size {
        /// Upper bound on a shard's serialized size, in bytes.
        shard_size_bytes: usize,
    },
}
/// Configuration for building a [`ShardedDataset`].
#[derive(Debug, Clone)]
pub struct ShardingConfig {
    /// Number of shards to produce (must be >= 1).
    pub n_shards: usize,
    /// Strategy used to assign samples to shards.
    pub strategy: ShardStrategy,
    /// Shuffle the sample order deterministically before splitting.
    pub shuffle: bool,
    /// Seed driving the deterministic shuffle.
    pub seed: u64,
}
impl Default for ShardingConfig {
    /// 8 shards, index strategy, shuffling enabled with seed 42.
    fn default() -> Self {
        Self {
            n_shards: 8,
            strategy: ShardStrategy::default(),
            shuffle: true,
            seed: 42,
        }
    }
}
/// One shard of a dataset, holding only the sample indices it owns.
#[derive(Debug, Clone)]
pub struct DataShard {
    /// Position of this shard in `0..n_shards`.
    pub shard_id: usize,
    /// Total number of shards in the partition.
    pub n_shards: usize,
    /// Indices of the samples assigned to this shard.
    pub indices: Vec<usize>,
    /// Whether the shard belongs to the training split.
    pub is_train: bool,
}
impl DataShard {
    /// Build shard `shard_id` out of `total_shards` over `n_samples`
    /// samples via index-based partitioning (see [`shard_by_index`]).
    ///
    /// An out-of-range `shard_id` yields an empty training shard rather
    /// than panicking.
    pub fn new(
        shard_id: usize,
        total_shards: usize,
        n_samples: usize,
        config: &ShardConfig,
    ) -> Self {
        let all_shards = shard_by_index(
            n_samples,
            total_shards,
            config.shuffle,
            config.seed.unwrap_or(0),
        );
        // A matching element already *is* the requested `DataShard`; no
        // need to destructure it and rebuild it field by field.
        all_shards
            .into_iter()
            .find(|s| s.shard_id == shard_id)
            .unwrap_or_else(|| Self {
                shard_id,
                n_shards: total_shards,
                indices: Vec::new(),
                is_train: true,
            })
    }
    /// Select this shard's rows from a 2-D array, skipping any index that
    /// is out of range for `data`.
    ///
    /// Degenerate shapes preserve the row count: selecting `k` rows from a
    /// zero-column array returns a `k x 0` array (previously this
    /// collapsed to `0 x 0`, losing the row count).
    pub fn apply_2d<T: Clone + Default>(&self, data: &Array2<T>) -> Array2<T> {
        let n_cols = data.ncols();
        let valid_indices: Vec<usize> = self
            .indices
            .iter()
            .copied()
            .filter(|&i| i < data.nrows())
            .collect();
        let n_rows = valid_indices.len();
        if n_rows == 0 || n_cols == 0 {
            // Empty in either dimension: return the correctly-shaped empty
            // array instead of always 0 x n_cols.
            return Array2::default((n_rows, n_cols));
        }
        let mut flat = Vec::with_capacity(n_rows * n_cols);
        // Fast path: copy whole rows when each row view is contiguous
        // (`as_slice()` returns `Some`).
        for &row_idx in &valid_indices {
            flat.extend_from_slice(data.row(row_idx).as_slice().unwrap_or(&[]));
        }
        // Fallback: if any row was non-contiguous the fast path came up
        // short; redo the copy element by element.
        if flat.len() != n_rows * n_cols {
            flat.clear();
            for &row_idx in &valid_indices {
                for col in 0..n_cols {
                    flat.push(data[[row_idx, col]].clone());
                }
            }
        }
        Array2::from_shape_vec((n_rows, n_cols), flat)
            .unwrap_or_else(|_| Array2::default((0, n_cols)))
    }
    /// Select this shard's elements from a 1-D array, skipping
    /// out-of-range indices.
    pub fn apply_1d<T: Clone>(&self, data: &Array1<T>) -> Array1<T> {
        let selected: Vec<T> = self
            .indices
            .iter()
            .copied()
            .filter(|&i| i < data.len())
            .map(|i| data[i].clone())
            .collect();
        Array1::from_vec(selected)
    }
    /// Number of samples assigned to this shard.
    pub fn len(&self) -> usize {
        self.indices.len()
    }
    /// True when the shard holds no samples.
    pub fn is_empty(&self) -> bool {
        self.indices.is_empty()
    }
}
/// Per-shard construction options consumed by [`DataShard::new`].
#[derive(Debug, Clone)]
pub struct ShardConfig {
    /// Total number of shards.
    /// NOTE(review): not read by `DataShard::new`, which takes
    /// `total_shards` as a separate argument — confirm whether this field
    /// is still needed.
    pub n_shards: usize,
    /// Shuffle the sample order before partitioning.
    pub shuffle: bool,
    /// Optional shuffle seed; `None` behaves like `Some(0)`.
    pub seed: Option<u64>,
}
/// A dataset partitioned into shards according to a [`ShardingConfig`].
#[derive(Debug, Clone)]
pub struct ShardedDataset {
    /// The shards, ordered by `shard_id`.
    pub shards: Vec<DataShard>,
    /// Number of samples in the original dataset.
    pub total_size: usize,
    /// Configuration the shards were built from.
    pub config: ShardingConfig,
}
/// Deterministic Fisher-Yates permutation of `0..n` driven by [`Lcg64`].
/// The same `(n, seed)` pair always produces the same ordering, across
/// runs and machines.
pub fn consistent_shuffle(n: usize, seed: u64) -> Vec<usize> {
    let mut rng = Lcg64::new(seed);
    let mut perm: Vec<usize> = (0..n).collect();
    // Walk i from n-1 down to 1, swapping with a random j in 0..=i.
    let mut i = n;
    while i > 1 {
        i -= 1;
        let j = rng.next_usize(i + 1);
        perm.swap(i, j);
    }
    perm
}
/// Split `n_samples` indices into `n_shards` contiguous, nearly equal
/// partitions, optionally over a deterministic shuffle of the indices.
/// The first `n_samples % n_shards` shards receive one extra sample.
/// Returns an empty vector when either count is zero.
pub fn shard_by_index(
    n_samples: usize,
    n_shards: usize,
    shuffle: bool,
    seed: u64,
) -> Vec<DataShard> {
    if n_shards == 0 || n_samples == 0 {
        return Vec::new();
    }
    let order: Vec<usize> = if shuffle {
        consistent_shuffle(n_samples, seed)
    } else {
        (0..n_samples).collect()
    };
    let base = n_samples / n_shards;
    let remainder = n_samples % n_shards;
    let mut start = 0usize;
    (0..n_shards)
        .map(|shard_id| {
            let len = base + usize::from(shard_id < remainder);
            let slice = order[start..start + len].to_vec();
            start += len;
            DataShard {
                shard_id,
                n_shards,
                indices: slice,
                is_train: true,
            }
        })
        .collect()
}
/// Assign sample `i` to shard `i % n_shards` (round-robin; "hash" of the
/// index is just the index itself). Returns an empty vector when either
/// count is zero.
pub fn shard_by_hash(n_samples: usize, n_shards: usize) -> Vec<DataShard> {
    if n_shards == 0 || n_samples == 0 {
        return Vec::new();
    }
    // Shard `k` owns exactly k, k + n_shards, k + 2*n_shards, … — generate
    // each bucket directly instead of dealing into intermediate vectors.
    (0..n_shards)
        .map(|shard_id| DataShard {
            shard_id,
            n_shards,
            indices: (shard_id..n_samples).step_by(n_shards).collect(),
            is_train: true,
        })
        .collect()
}
/// Partition samples so every shard receives a proportional share of each
/// class, optionally shuffling within each class first (deterministically,
/// with a per-class seed derived from `seed`).
/// Returns an empty vector when `n_shards == 0` or `labels` is empty.
pub fn shard_stratified(
    labels: &[usize],
    n_shards: usize,
    shuffle: bool,
    seed: u64,
) -> Vec<DataShard> {
    if n_shards == 0 || labels.is_empty() {
        return Vec::new();
    }
    // Group sample positions by class label.
    let n_classes = labels.iter().copied().max().unwrap_or(0) + 1;
    let mut by_class: Vec<Vec<usize>> = vec![Vec::new(); n_classes];
    for (sample, &label) in labels.iter().enumerate() {
        by_class[label].push(sample);
    }
    if shuffle {
        for (class, members) in by_class.iter_mut().enumerate() {
            // Golden-ratio stride decorrelates the per-class seeds.
            let class_seed = seed.wrapping_add(class as u64 * 0x9e37_79b9_7f4a_7c15);
            let perm = consistent_shuffle(members.len(), class_seed);
            let reordered: Vec<usize> = perm.iter().map(|&p| members[p]).collect();
            *members = reordered;
        }
    }
    // Deal each class's members round-robin across the shards.
    let mut shard_indices: Vec<Vec<usize>> = vec![Vec::new(); n_shards];
    for members in by_class {
        for (pos, sample) in members.into_iter().enumerate() {
            shard_indices[pos % n_shards].push(sample);
        }
    }
    shard_indices
        .into_iter()
        .enumerate()
        .map(|(shard_id, indices)| DataShard {
            shard_id,
            n_shards,
            indices,
            is_train: true,
        })
        .collect()
}
impl ShardedDataset {
    /// Shard `n_samples` samples according to `config`.
    ///
    /// # Errors
    /// Returns `InvalidFormat` when `n_shards == 0`, when `n_samples == 0`,
    /// or when the strategy is `Stratified` (labels are required; use
    /// [`ShardedDataset::new_stratified`] instead).
    pub fn new(n_samples: usize, config: ShardingConfig) -> Result<Self> {
        if config.n_shards == 0 {
            return Err(DatasetsError::InvalidFormat("n_shards must be >= 1".into()));
        }
        if n_samples == 0 {
            return Err(DatasetsError::InvalidFormat(
                "n_samples must be >= 1".into(),
            ));
        }
        let shards = match &config.strategy {
            ShardStrategy::Index => {
                shard_by_index(n_samples, config.n_shards, config.shuffle, config.seed)
            }
            ShardStrategy::Hash => shard_by_hash(n_samples, config.n_shards),
            ShardStrategy::Stratified { .. } => {
                return Err(DatasetsError::InvalidFormat(
                    "Use ShardedDataset::new_stratified for Stratified strategy".into(),
                ));
            }
            // Size-based sharding is not implemented yet; the configured
            // `shard_size_bytes` is deliberately ignored and index sharding
            // is used as a fallback.
            ShardStrategy::Size { .. } => {
                shard_by_index(n_samples, config.n_shards, config.shuffle, config.seed)
            }
        };
        Ok(Self {
            shards,
            total_size: n_samples,
            config,
        })
    }
    /// Stratified sharding over explicit class labels
    /// (see [`shard_stratified`]).
    ///
    /// # Errors
    /// Returns `InvalidFormat` when `n_shards == 0` or `labels` is empty.
    pub fn new_stratified(labels: &[usize], config: ShardingConfig) -> Result<Self> {
        if config.n_shards == 0 {
            return Err(DatasetsError::InvalidFormat("n_shards must be >= 1".into()));
        }
        if labels.is_empty() {
            return Err(DatasetsError::InvalidFormat(
                "labels must not be empty".into(),
            ));
        }
        let shards = shard_stratified(labels, config.n_shards, config.shuffle, config.seed);
        let total_size = labels.len();
        Ok(Self {
            shards,
            total_size,
            config,
        })
    }
    /// The shard with the given id, or `None` when out of range.
    pub fn get_shard(&self, shard_id: usize) -> Option<&DataShard> {
        self.shards.get(shard_id)
    }
    /// Split shard ids into `(train, validation)`, reserving
    /// `ceil(n * val_fraction)` shards from the tail for validation.
    pub fn train_shards(&self, val_fraction: f64) -> (Vec<usize>, Vec<usize>) {
        let n = self.shards.len();
        if n == 0 {
            return (Vec::new(), Vec::new());
        }
        // `as usize` saturates, so a negative or NaN fraction yields 0
        // validation shards; `min(n)` caps fractions above 1.0.
        let n_val = ((n as f64 * val_fraction).ceil() as usize).min(n);
        let n_train = n - n_val;
        let train_ids: Vec<usize> = (0..n_train).collect();
        let val_ids: Vec<usize> = (n_train..n).collect();
        (train_ids, val_ids)
    }
    /// Iterate the sample indices of one shard; empty for an
    /// out-of-range id.
    pub fn shard_iter(&self, shard_id: usize) -> impl Iterator<Item = usize> + '_ {
        let slice: &[usize] = match self.shards.get(shard_id) {
            Some(shard) => &shard.indices,
            None => &[],
        };
        slice.iter().copied()
    }
    /// Number of shards.
    pub fn n_shards(&self) -> usize {
        self.shards.len()
    }
    /// Total samples across all shards.
    pub fn total_samples(&self) -> usize {
        self.shards.iter().map(|s| s.indices.len()).sum()
    }
}
/// A shard with materialized feature rows and labels (as produced by
/// [`shard_dataset`] / [`stratified_shard`]); `ShardedLoader::get_shard`
/// leaves `data`/`labels` empty and fills only `indices`.
#[derive(Debug, Clone)]
pub struct DatasetShard {
    pub shard_id: usize,
    pub total_shards: usize,
    pub indices: Vec<usize>,
    pub data: Vec<Vec<f64>>,
    pub labels: Vec<usize>,
}
impl DatasetShard {
    /// Number of samples assigned to this shard.
    pub fn len(&self) -> usize {
        self.indices.len()
    }
    /// True when no samples are assigned.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
    /// Select the rows of `data` named by this shard's indices,
    /// silently skipping indices that fall outside `data`.
    pub fn apply_f64(&self, data: &[Vec<f64>]) -> Vec<Vec<f64>> {
        self.indices
            .iter()
            .filter_map(|&i| data.get(i).cloned())
            .collect()
    }
    /// Select the labels named by this shard's indices,
    /// silently skipping out-of-range indices.
    pub fn apply_labels(&self, labels: &[usize]) -> Vec<usize> {
        self.indices
            .iter()
            .filter_map(|&i| labels.get(i).copied())
            .collect()
    }
}
/// Lazily materializes shards sliced from a deterministic global
/// permutation; shards are disjoint, nearly equal in size, and fully
/// reproducible from `seed`.
#[derive(Debug, Clone)]
pub struct ShardedLoader {
    /// Total number of samples being sharded.
    pub total_samples: usize,
    /// Number of shards the samples are divided into.
    pub n_shards: usize,
    /// Seed for the deterministic global permutation.
    pub seed: u64,
}
impl ShardedLoader {
    /// Create a loader over `total_samples` samples split into `n_shards`.
    pub fn new(total_samples: usize, n_shards: usize, seed: u64) -> Self {
        Self {
            total_samples,
            n_shards,
            seed,
        }
    }
    /// The deterministic permutation of `0..total_samples` every shard is
    /// sliced from.
    pub fn global_permutation(&self) -> Vec<usize> {
        consistent_shuffle(self.total_samples, self.seed)
    }
    /// `(offset, len)` of `shard_id`'s slice within the global permutation.
    /// The first `total_samples % n_shards` shards get one extra sample.
    /// Caller guarantees `n_shards > 0` and `shard_id < n_shards`.
    fn shard_bounds(&self, shard_id: usize) -> (usize, usize) {
        let base = self.total_samples / self.n_shards;
        let remainder = self.total_samples % self.n_shards;
        // Closed form replaces the former O(n_shards) accumulation loop:
        // each of the first min(shard_id, remainder) shards contributed one
        // extra sample to the offset.
        let offset = shard_id * base + shard_id.min(remainder);
        let len = base + usize::from(shard_id < remainder);
        (offset, len)
    }
    /// Materialize the index list for one shard (`data`/`labels` stay
    /// empty). Out-of-range ids and empty loaders yield an empty shard.
    pub fn get_shard(&self, shard_id: usize) -> DatasetShard {
        if self.n_shards == 0 || self.total_samples == 0 || shard_id >= self.n_shards {
            return DatasetShard {
                shard_id,
                total_shards: self.n_shards,
                indices: Vec::new(),
                data: Vec::new(),
                labels: Vec::new(),
            };
        }
        let permuted = self.global_permutation();
        let (offset, len) = self.shard_bounds(shard_id);
        DatasetShard {
            shard_id,
            total_shards: self.n_shards,
            indices: permuted[offset..offset + len].to_vec(),
            data: Vec::new(),
            labels: Vec::new(),
        }
    }
    /// Check that the shards exactly partition `0..total_samples`
    /// (no out-of-range indices, no duplicates, no gaps).
    pub fn verify_coverage(&self) -> bool {
        if self.n_shards == 0 || self.total_samples == 0 {
            // Zero shards cover nothing, which is only fine when there is
            // nothing to cover.
            return self.total_samples == 0;
        }
        // Compute the permutation once; going through `get_shard` would
        // recompute it per shard, i.e. O(total_samples * n_shards) work.
        let permuted = self.global_permutation();
        let mut seen = vec![false; self.total_samples];
        for shard_id in 0..self.n_shards {
            let (offset, len) = self.shard_bounds(shard_id);
            for &idx in &permuted[offset..offset + len] {
                if idx >= self.total_samples || seen[idx] {
                    return false;
                }
                seen[idx] = true;
            }
        }
        seen.iter().all(|&v| v)
    }
}
/// Deterministically shuffle-and-split `data`/`labels` into `n_shards`
/// materialized shards.
///
/// # Errors
/// Returns `InvalidFormat` when the two lengths disagree or
/// `n_shards == 0`. Empty input yields an empty shard list.
pub fn shard_dataset(
    data: &[Vec<f64>],
    labels: &[usize],
    n_shards: usize,
    seed: u64,
) -> Result<Vec<DatasetShard>> {
    if data.len() != labels.len() {
        return Err(DatasetsError::InvalidFormat(format!(
            "data length ({}) != labels length ({})",
            data.len(),
            labels.len()
        )));
    }
    if n_shards == 0 {
        return Err(DatasetsError::InvalidFormat("n_shards must be >= 1".into()));
    }
    if data.is_empty() {
        return Ok(Vec::new());
    }
    let index_shards = shard_by_index(data.len(), n_shards, true, seed);
    Ok(build_dataset_shards(data, labels, &index_shards))
}
/// Split `data`/`labels` into `n_shards` materialized shards with each
/// class represented proportionally (no shuffling; fixed seed 0).
///
/// # Errors
/// Returns `InvalidFormat` when the two lengths disagree or
/// `n_shards == 0`. Empty input yields an empty shard list.
pub fn stratified_shard(
    data: &[Vec<f64>],
    labels: &[usize],
    n_shards: usize,
) -> Result<Vec<DatasetShard>> {
    if data.len() != labels.len() {
        return Err(DatasetsError::InvalidFormat(format!(
            "data length ({}) != labels length ({})",
            data.len(),
            labels.len()
        )));
    }
    if n_shards == 0 {
        return Err(DatasetsError::InvalidFormat("n_shards must be >= 1".into()));
    }
    if data.is_empty() {
        return Ok(Vec::new());
    }
    let index_shards = shard_stratified(labels, n_shards, false, 0);
    Ok(build_dataset_shards(data, labels, &index_shards))
}
/// Shuffled sharding of `data`/`labels`; an alias for [`shard_dataset`],
/// which always shuffles.
///
/// # Errors
/// Same as [`shard_dataset`]: length mismatch or `n_shards == 0`.
pub fn shuffled_shard(
    data: &[Vec<f64>],
    labels: &[usize],
    n_shards: usize,
    seed: u64,
) -> Result<Vec<DatasetShard>> {
    shard_dataset(data, labels, n_shards, seed)
}
/// Recombine materialized shards into `(data, labels)`, restored to the
/// original sample order.
///
/// Rows are paired positionally: `indices[k]` corresponds to `data[k]` and
/// `labels[k]`. The `zip`s truncate to the shortest of the three, so shards
/// whose `data`/`labels` were never materialized (e.g. produced by
/// `ShardedLoader::get_shard`) contribute nothing instead of causing an
/// out-of-bounds panic, as the previous positional indexing did.
pub fn merge_shards(shards: &[DatasetShard]) -> (Vec<Vec<f64>>, Vec<usize>) {
    // (original_index, row, label) triples gathered from every shard.
    let mut entries: Vec<(usize, &Vec<f64>, usize)> = Vec::new();
    for shard in shards {
        for ((&idx, row), &label) in shard
            .indices
            .iter()
            .zip(&shard.data)
            .zip(&shard.labels)
        {
            entries.push((idx, row, label));
        }
    }
    // Original indices are unique, so an unstable sort is safe.
    entries.sort_unstable_by_key(|&(idx, _, _)| idx);
    let data: Vec<Vec<f64>> = entries.iter().map(|&(_, row, _)| row.clone()).collect();
    let labels: Vec<usize> = entries.iter().map(|&(_, _, label)| label).collect();
    (data, labels)
}
/// Materialize `DatasetShard`s by copying the rows and labels named by
/// each index shard. Every index is assumed to be in range, which the
/// sharding functions in this module guarantee.
fn build_dataset_shards(
    data: &[Vec<f64>],
    labels: &[usize],
    index_shards: &[DataShard],
) -> Vec<DatasetShard> {
    let mut out = Vec::with_capacity(index_shards.len());
    for shard in index_shards {
        let mut shard_data = Vec::with_capacity(shard.indices.len());
        let mut shard_labels = Vec::with_capacity(shard.indices.len());
        for &i in &shard.indices {
            shard_data.push(data[i].clone());
            shard_labels.push(labels[i]);
        }
        out.push(DatasetShard {
            shard_id: shard.shard_id,
            total_shards: shard.n_shards,
            indices: shard.indices.clone(),
            data: shard_data,
            labels: shard_labels,
        });
    }
    out
}
#[cfg(test)]
mod tests {
    use super::*;

    // Unshuffled index sharding: 100 samples over 4 shards gives four
    // blocks of 25 that cover every index exactly once.
    #[test]
    fn test_shard_by_index_no_shuffle() {
        let shards = shard_by_index(100, 4, false, 0);
        assert_eq!(shards.len(), 4);
        for shard in &shards {
            assert_eq!(shard.indices.len(), 25);
        }
        let mut seen = [false; 100];
        for shard in &shards {
            for &i in &shard.indices {
                assert!(!seen[i], "index {i} appears twice");
                seen[i] = true;
            }
        }
        assert!(seen.iter().all(|&v| v));
    }

    // Shuffling must not change the total number of assigned samples.
    #[test]
    fn test_shard_by_index_shuffle() {
        let shards = shard_by_index(100, 4, true, 42);
        assert_eq!(shards.len(), 4);
        let total: usize = shards.iter().map(|s| s.len()).sum();
        assert_eq!(total, 100);
    }

    // Same seed -> identical order; different seed -> different order.
    #[test]
    fn test_consistent_shuffle_determinism() {
        let a = consistent_shuffle(50, 12345);
        let b = consistent_shuffle(50, 12345);
        assert_eq!(a, b);
        let c = consistent_shuffle(50, 99999);
        assert_ne!(a, c);
    }

    // The shuffle must be a true permutation of 0..n.
    #[test]
    fn test_consistent_shuffle_permutation() {
        let n = 200;
        let shuffled = consistent_shuffle(n, 7);
        assert_eq!(shuffled.len(), n);
        let mut sorted = shuffled.clone();
        sorted.sort_unstable();
        assert_eq!(sorted, (0..n).collect::<Vec<_>>());
    }

    // Hash sharding assigns index i to shard i % n_shards.
    #[test]
    fn test_shard_by_hash() {
        let shards = shard_by_hash(100, 4);
        assert_eq!(shards.len(), 4);
        assert!(shards[0].indices.iter().all(|&i| i % 4 == 0));
        let total: usize = shards.iter().map(|s| s.len()).sum();
        assert_eq!(total, 100);
    }

    // 30 class-0 + 20 class-1 over 5 shards -> 10 samples per shard.
    #[test]
    fn test_stratified_class_proportions() {
        let mut labels = vec![0usize; 30];
        labels.extend(vec![1usize; 20]);
        let shards = shard_stratified(&labels, 5, false, 0);
        assert_eq!(shards.len(), 5);
        for shard in &shards {
            assert_eq!(shard.indices.len(), 10);
        }
    }

    #[test]
    fn test_sharded_dataset_new() {
        let config = ShardingConfig {
            n_shards: 4,
            strategy: ShardStrategy::Index,
            shuffle: false,
            seed: 0,
        };
        let ds = ShardedDataset::new(100, config).expect("should succeed");
        assert_eq!(ds.n_shards(), 4);
        assert_eq!(ds.total_samples(), 100);
    }

    // ceil(8 * 0.25) = 2 shards reserved for validation.
    #[test]
    fn test_train_shards_split() {
        let config = ShardingConfig {
            n_shards: 8,
            strategy: ShardStrategy::Index,
            shuffle: false,
            seed: 0,
        };
        let ds = ShardedDataset::new(80, config).expect("should succeed");
        let (train, val) = ds.train_shards(0.25);
        assert_eq!(train.len() + val.len(), 8);
        assert_eq!(val.len(), 2);
    }

    // Without shuffling, shard 0 of 40 samples / 4 shards holds 0..10.
    #[test]
    fn test_shard_iter() {
        let config = ShardingConfig {
            n_shards: 4,
            strategy: ShardStrategy::Index,
            shuffle: false,
            seed: 0,
        };
        let ds = ShardedDataset::new(40, config).expect("should succeed");
        let collected: Vec<usize> = ds.shard_iter(0).collect();
        assert_eq!(collected.len(), 10);
        assert_eq!(collected, (0..10).collect::<Vec<_>>());
    }

    // An out-of-range shard id yields an empty iterator, not a panic.
    #[test]
    fn test_shard_iter_out_of_bounds() {
        let config = ShardingConfig::default();
        let ds = ShardedDataset::new(10, config).expect("should succeed");
        let empty: Vec<usize> = ds.shard_iter(999).collect();
        assert!(empty.is_empty());
    }

    #[test]
    fn test_sharded_dataset_invalid_config() {
        let bad_config = ShardingConfig {
            n_shards: 0,
            ..Default::default()
        };
        assert!(ShardedDataset::new(100, bad_config).is_err());
    }

    // Shard ids are assigned in order 0..n_shards.
    #[test]
    fn test_shard_id_assignment() {
        let shards = shard_by_index(100, 4, false, 0);
        for (expected_id, shard) in shards.iter().enumerate() {
            assert_eq!(shard.shard_id, expected_id);
            assert_eq!(shard.n_shards, 4);
        }
    }

    #[test]
    fn test_stratified_new_stratified() {
        let labels: Vec<usize> = (0..60).map(|i| i % 3).collect();
        let config = ShardingConfig {
            n_shards: 3,
            strategy: ShardStrategy::Stratified {
                label_column: "class".into(),
            },
            shuffle: false,
            seed: 0,
        };
        let ds = ShardedDataset::new_stratified(&labels, config).expect("ok");
        assert_eq!(ds.n_shards(), 3);
        assert_eq!(ds.total_samples(), 60);
    }

    // Helper: n rows of [i, 2i] with labels cycling 0, 1, 2.
    fn make_test_data(n: usize) -> (Vec<Vec<f64>>, Vec<usize>) {
        let data: Vec<Vec<f64>> = (0..n).map(|i| vec![i as f64, (i * 2) as f64]).collect();
        let labels: Vec<usize> = (0..n).map(|i| i % 3).collect();
        (data, labels)
    }

    #[test]
    fn test_shard_dataset_total_samples() {
        let (data, labels) = make_test_data(100);
        let shards = shard_dataset(&data, &labels, 4, 42).expect("ok");
        assert_eq!(shards.len(), 4);
        let total: usize = shards.iter().map(|s| s.len()).sum();
        assert_eq!(total, 100);
    }

    // 60/40 class split over 5 shards -> 12 class-0 and 8 class-1 each.
    #[test]
    fn test_stratified_shard_label_proportions() {
        let mut labels = vec![0usize; 60];
        labels.extend(vec![1usize; 40]);
        let data: Vec<Vec<f64>> = (0..100).map(|i| vec![i as f64]).collect();
        let shards = stratified_shard(&data, &labels, 5).expect("ok");
        assert_eq!(shards.len(), 5);
        for shard in &shards {
            let c0 = shard.labels.iter().filter(|&&l| l == 0).count();
            let c1 = shard.labels.iter().filter(|&&l| l == 1).count();
            assert_eq!(c0, 12, "Expected 12 class-0 per shard, got {c0}");
            assert_eq!(c1, 8, "Expected 8 class-1 per shard, got {c1}");
        }
    }

    // Round-trip: sharding then merging reproduces the original order.
    #[test]
    fn test_merge_shards_recovers_data() {
        let (data, labels) = make_test_data(50);
        let shards = shard_dataset(&data, &labels, 5, 99).expect("ok");
        let (merged_data, merged_labels) = merge_shards(&shards);
        assert_eq!(merged_data.len(), 50);
        assert_eq!(merged_labels.len(), 50);
        for i in 0..50 {
            assert_eq!(merged_data[i], data[i], "Data mismatch at index {i}");
            assert_eq!(merged_labels[i], labels[i], "Label mismatch at index {i}");
        }
    }

    #[test]
    fn test_shuffled_shard_determinism() {
        let (data, labels) = make_test_data(30);
        let s1 = shuffled_shard(&data, &labels, 3, 42).expect("ok");
        let s2 = shuffled_shard(&data, &labels, 3, 42).expect("ok");
        for (a, b) in s1.iter().zip(s2.iter()) {
            assert_eq!(a.indices, b.indices);
        }
    }

    #[test]
    fn test_shard_dataset_error_on_mismatch() {
        let data = vec![vec![1.0]; 10];
        let labels = vec![0; 5];
        assert!(shard_dataset(&data, &labels, 2, 0).is_err());
    }

    #[test]
    fn test_merge_empty_shards() {
        let (data, labels) = merge_shards(&[]);
        assert!(data.is_empty());
        assert!(labels.is_empty());
    }

    #[test]
    fn test_sharded_loader_verify_coverage() {
        let loader = ShardedLoader::new(100, 4, 42);
        assert!(
            loader.verify_coverage(),
            "all 100 samples should be covered"
        );
    }

    // 101 samples over 4 shards: sizes may differ by at most one.
    #[test]
    fn test_sharded_loader_balanced_sizes() {
        let loader = ShardedLoader::new(101, 4, 7);
        let sizes: Vec<usize> = (0..4).map(|id| loader.get_shard(id).len()).collect();
        let min_size = *sizes.iter().min().expect("non-empty");
        let max_size = *sizes.iter().max().expect("non-empty");
        assert!(
            max_size - min_size <= 1,
            "shard sizes differ by more than 1: {sizes:?}"
        );
        let total: usize = sizes.iter().sum();
        assert_eq!(total, 101, "total should equal n_samples");
    }

    #[test]
    fn test_sharded_loader_disjoint_shards() {
        let loader = ShardedLoader::new(100, 4, 99);
        let shard0 = loader.get_shard(0);
        let shard1 = loader.get_shard(1);
        for &i in &shard0.indices {
            assert!(
                !shard1.indices.contains(&i),
                "index {i} appears in both shard 0 and shard 1"
            );
        }
    }

    #[test]
    fn test_sharded_loader_same_seed_same_permutation() {
        let loader = ShardedLoader::new(100, 4, 12345);
        let p1 = loader.global_permutation();
        let p2 = loader.global_permutation();
        assert_eq!(p1, p2, "same seed should give same permutation");
        let loader2 = ShardedLoader::new(100, 4, 12345);
        let p3 = loader2.global_permutation();
        assert_eq!(p1, p3, "independent loader with same seed should match");
    }

    #[test]
    fn test_dataset_shard_apply_f64() {
        let data: Vec<Vec<f64>> = (0..100).map(|i| vec![i as f64, (i * 2) as f64]).collect();
        let loader = ShardedLoader::new(100, 4, 42);
        let shard = loader.get_shard(0);
        let subset = shard.apply_f64(&data);
        assert_eq!(
            subset.len(),
            shard.len(),
            "apply_f64 should return exactly shard.len() rows"
        );
        for row in &subset {
            assert_eq!(row.len(), 2, "each row should have 2 features");
        }
    }

    #[test]
    fn test_dataset_shard_apply_labels() {
        let labels: Vec<usize> = (0..100).map(|i| i % 3).collect();
        let loader = ShardedLoader::new(100, 4, 42);
        let shard = loader.get_shard(2);
        let subset = shard.apply_labels(&labels);
        assert_eq!(
            subset.len(),
            shard.len(),
            "apply_labels should return exactly shard.len() labels"
        );
    }

    // A single shard must own every sample.
    #[test]
    fn test_sharded_loader_single_shard_coverage() {
        let loader = ShardedLoader::new(50, 1, 0);
        assert!(loader.verify_coverage());
        let shard = loader.get_shard(0);
        assert_eq!(shard.len(), 50);
    }

    #[test]
    fn test_sharded_loader_out_of_range_shard() {
        let loader = ShardedLoader::new(100, 4, 42);
        let empty_shard = loader.get_shard(99);
        assert!(
            empty_shard.is_empty(),
            "out-of-range shard_id should give empty shard"
        );
    }
}