Skip to main content

scirs2_datasets/sharding/
mod.rs

1//! Dataset sharding API for distributed and parallel data loading.
2//!
3//! This module provides utilities to split datasets into shards for distributed
4//! training, cross-validation, and parallel processing workflows.
5//!
6//! ## Key types
7//!
8//! - [`ShardingConfig`] — configuration driving how sharding is performed.
9//! - [`ShardStrategy`] — enumeration of available sharding strategies.
10//! - [`DataShard`] — a single shard containing a set of sample indices.
11//! - [`ShardedDataset`] — the complete collection of shards over a dataset.
12
13use crate::error::{DatasetsError, Result};
14use scirs2_core::ndarray::{Array1, Array2};
15
16// ─────────────────────────────────────────────────────────────────────────────
17// LCG helpers (avoids pulling in the `rand` crate)
18// ─────────────────────────────────────────────────────────────────────────────
19
/// Tiny 64-bit linear congruential generator used for deterministic shuffling.
///
/// The recurrence is `state = state * 6364136223846793005 + 1442695040888963407`
/// with wrapping (mod `2^64`) arithmetic. It lets this module shuffle reproducibly
/// without depending on the `rand` crate; it is not meant for anything
/// security-sensitive.
struct Lcg64 {
    state: u64,
}

impl Lcg64 {
    /// Seed the generator; the stored state is `seed + 1` (wrapping).
    fn new(seed: u64) -> Self {
        let state = seed.wrapping_add(1);
        Self { state }
    }

    /// Advance the recurrence once and return the new state.
    fn next_u64(&mut self) -> u64 {
        const MULTIPLIER: u64 = 6_364_136_223_846_793_005;
        const INCREMENT: u64 = 1_442_695_040_888_963_407;
        self.state = self.state.wrapping_mul(MULTIPLIER).wrapping_add(INCREMENT);
        self.state
    }

    /// Draw a value in `[0, n)` by reducing the next output modulo `n`.
    ///
    /// Returns `0` without advancing the generator when `n == 0`.
    ///
    /// NOTE(review): `%` keeps the low bits of a power-of-two-modulus LCG, which have
    /// short periods (with an odd multiplier and odd increment the lowest bit simply
    /// alternates). Switching to the high bits would change every seeded shuffle, so
    /// the reduction is intentionally left as is.
    fn next_usize(&mut self, n: usize) -> usize {
        match n {
            0 => 0,
            bound => (self.next_u64() % bound as u64) as usize,
        }
    }

    /// Draw an `f64` in `[0, 1)` from the top 53 bits of the next output.
    fn next_f64(&mut self) -> f64 {
        let top_bits = self.next_u64() >> 11;
        top_bits as f64 / (1u64 << 53) as f64
    }
}
54
55// ─────────────────────────────────────────────────────────────────────────────
56// Public types
57// ─────────────────────────────────────────────────────────────────────────────
58
/// How samples are assigned to shards.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Default)]
pub enum ShardStrategy {
    /// Equal-size contiguous slices of the index range, shuffled first when the
    /// configuration asks for it. This is the default strategy.
    #[default]
    Index,
    /// Modulo assignment: sample `i` goes to shard `i % n_shards`.
    Hash,
    /// Class-stratified assignment that keeps each shard's label mix close to the
    /// overall label distribution.
    Stratified {
        /// Name of the label column. Purely descriptive: the labels themselves must be
        /// passed separately by the caller.
        label_column: String,
    },
    /// Aim for a target shard size in bytes.
    ///
    /// NOTE: `ShardedDataset::new` currently has no per-row size to work with and falls
    /// back to index sharding for this variant.
    Size {
        /// Desired number of bytes per shard.
        shard_size_bytes: usize,
    },
}

/// Parameters controlling how a dataset is split into shards.
#[derive(Debug, Clone)]
pub struct ShardingConfig {
    /// How many shards to create.
    pub n_shards: usize,
    /// Which assignment strategy to use.
    pub strategy: ShardStrategy,
    /// Shuffle sample indices before partitioning.
    pub shuffle: bool,
    /// Seed for the deterministic shuffle; only consulted when `shuffle` is `true`.
    pub seed: u64,
}

impl Default for ShardingConfig {
    /// Eight index-based shards, shuffled with seed `42`.
    fn default() -> Self {
        let strategy = ShardStrategy::default();
        Self {
            strategy,
            n_shards: 8,
            seed: 42,
            shuffle: true,
        }
    }
}
103
/// One shard: the sample indices assigned to it plus bookkeeping about its position.
#[derive(Debug, Clone)]
pub struct DataShard {
    /// Position of this shard among its siblings, starting at 0.
    pub shard_id: usize,
    /// Number of shards the dataset was split into.
    pub n_shards: usize,
    /// Indices (into the original dataset) of the samples in this shard.
    pub indices: Vec<usize>,
    /// Marks the shard as a training shard. The sharding functions in this module
    /// always set it to `true`.
    pub is_train: bool,
}
116
117impl DataShard {
118    /// Build a new [`DataShard`] for `shard_id` out of `total_shards` over `n_samples` samples.
119    ///
120    /// When `config.shuffle` is `true` and a seed is provided the global index permutation is
121    /// deterministically computed from that seed, ensuring every shard call with the same
122    /// `(total_shards, n_samples, config)` triple produces consistent, non-overlapping index
123    /// sets.
124    pub fn new(
125        shard_id: usize,
126        total_shards: usize,
127        n_samples: usize,
128        config: &ShardConfig,
129    ) -> Self {
130        let all_shards = shard_by_index(
131            n_samples,
132            total_shards,
133            config.shuffle,
134            config.seed.unwrap_or(0),
135        );
136        // If shard_id is out of range, return an empty shard.
137        match all_shards.into_iter().find(|s| s.shard_id == shard_id) {
138            Some(s) => Self {
139                shard_id: s.shard_id,
140                n_shards: s.n_shards,
141                indices: s.indices,
142                is_train: s.is_train,
143            },
144            None => Self {
145                shard_id,
146                n_shards: total_shards,
147                indices: Vec::new(),
148                is_train: true,
149            },
150        }
151    }
152
153    /// Apply this shard's indices to a 2-D feature matrix.
154    ///
155    /// Returns a new `Array2<T>` containing only the rows selected by this shard,
156    /// in the order given by [`Self::indices`].
157    ///
158    /// # Panics
159    ///
160    /// Does not panic; indices that exceed the data row count are silently skipped.
161    pub fn apply_2d<T: Clone + Default>(&self, data: &Array2<T>) -> Array2<T> {
162        let n_cols = data.ncols();
163        let valid_indices: Vec<usize> = self
164            .indices
165            .iter()
166            .copied()
167            .filter(|&i| i < data.nrows())
168            .collect();
169        let n_rows = valid_indices.len();
170        if n_rows == 0 || n_cols == 0 {
171            return Array2::default((0, n_cols));
172        }
173        let mut flat = Vec::with_capacity(n_rows * n_cols);
174        for &row_idx in &valid_indices {
175            flat.extend_from_slice(data.row(row_idx).as_slice().unwrap_or(&[]));
176        }
177        // If the row wasn't contiguous, fall back to element-wise copy
178        if flat.len() != n_rows * n_cols {
179            flat.clear();
180            for &row_idx in &valid_indices {
181                for col in 0..n_cols {
182                    flat.push(data[[row_idx, col]].clone());
183                }
184            }
185        }
186        Array2::from_shape_vec((n_rows, n_cols), flat)
187            .unwrap_or_else(|_| Array2::default((0, n_cols)))
188    }
189
190    /// Apply this shard's indices to a 1-D target array.
191    ///
192    /// Returns a new `Array1<T>` containing only the elements at positions
193    /// given by [`Self::indices`], in the same order.
194    ///
195    /// # Panics
196    ///
197    /// Does not panic; indices that exceed the data length are silently skipped.
198    pub fn apply_1d<T: Clone>(&self, data: &Array1<T>) -> Array1<T> {
199        let selected: Vec<T> = self
200            .indices
201            .iter()
202            .copied()
203            .filter(|&i| i < data.len())
204            .map(|i| data[i].clone())
205            .collect();
206        Array1::from_vec(selected)
207    }
208
209    /// Number of samples in this shard.
210    pub fn len(&self) -> usize {
211        self.indices.len()
212    }
213
214    /// Returns `true` if this shard contains no samples.
215    pub fn is_empty(&self) -> bool {
216        self.indices.is_empty()
217    }
218}
219
/// Lightweight options consumed by [`DataShard::new`].
///
/// Carries the splitting-related subset of [`ShardingConfig`], except that the seed
/// is optional here.
#[derive(Debug, Clone)]
pub struct ShardConfig {
    /// Requested shard count. Note that [`DataShard::new`] takes its shard count from
    /// its `total_shards` argument rather than from this field.
    pub n_shards: usize,
    /// Shuffle indices before partitioning.
    pub shuffle: bool,
    /// Shuffle seed; [`DataShard::new`] treats `None` like `Some(0)`.
    pub seed: Option<u64>,
}
232
233/// A sharded view over a dataset.
234///
235/// Contains all shards produced according to a [`ShardingConfig`].
236#[derive(Debug, Clone)]
237pub struct ShardedDataset {
238    /// All shards.
239    pub shards: Vec<DataShard>,
240    /// Total number of samples in the underlying dataset.
241    pub total_size: usize,
242    /// Configuration used to build this sharded dataset.
243    pub config: ShardingConfig,
244}
245
246// ─────────────────────────────────────────────────────────────────────────────
247// Sharding functions
248// ─────────────────────────────────────────────────────────────────────────────
249
250/// Perform a Fisher-Yates shuffle of `0..n` using a seeded LCG.
251///
252/// Calling this function with the same `seed` always produces the same ordering.
253pub fn consistent_shuffle(n: usize, seed: u64) -> Vec<usize> {
254    let mut indices: Vec<usize> = (0..n).collect();
255    let mut rng = Lcg64::new(seed);
256    // Fisher-Yates (Knuth) shuffle
257    for i in (1..n).rev() {
258        let j = rng.next_usize(i + 1);
259        indices.swap(i, j);
260    }
261    indices
262}
263
264/// Shard `n_samples` into `n_shards` equal-sized groups by index.
265///
266/// If `shuffle` is `true` the indices are first shuffled with the given `seed`.
267/// The resulting shards cover all indices exactly once.
268pub fn shard_by_index(
269    n_samples: usize,
270    n_shards: usize,
271    shuffle: bool,
272    seed: u64,
273) -> Vec<DataShard> {
274    if n_shards == 0 || n_samples == 0 {
275        return Vec::new();
276    }
277
278    let indices = if shuffle {
279        consistent_shuffle(n_samples, seed)
280    } else {
281        (0..n_samples).collect()
282    };
283
284    let base = n_samples / n_shards;
285    let remainder = n_samples % n_shards;
286
287    let mut shards = Vec::with_capacity(n_shards);
288    let mut offset = 0usize;
289
290    for shard_id in 0..n_shards {
291        let extra = if shard_id < remainder { 1 } else { 0 };
292        let size = base + extra;
293        let shard_indices = indices[offset..offset + size].to_vec();
294        shards.push(DataShard {
295            shard_id,
296            n_shards,
297            indices: shard_indices,
298            is_train: true,
299        });
300        offset += size;
301    }
302
303    shards
304}
305
306/// Shard using consistent hashing: sample `i` always lands in shard `i % n_shards`.
307pub fn shard_by_hash(n_samples: usize, n_shards: usize) -> Vec<DataShard> {
308    if n_shards == 0 || n_samples == 0 {
309        return Vec::new();
310    }
311
312    let mut buckets: Vec<Vec<usize>> = vec![Vec::new(); n_shards];
313    for i in 0..n_samples {
314        buckets[i % n_shards].push(i);
315    }
316
317    buckets
318        .into_iter()
319        .enumerate()
320        .map(|(shard_id, indices)| DataShard {
321            shard_id,
322            n_shards,
323            indices,
324            is_train: true,
325        })
326        .collect()
327}
328
329/// Stratified sharding: distributes each class proportionally across all shards.
330///
331/// `labels` must have length `n_samples` with integer class identifiers.
332/// The caller controls shuffling and seeding.
333pub fn shard_stratified(
334    labels: &[usize],
335    n_shards: usize,
336    shuffle: bool,
337    seed: u64,
338) -> Vec<DataShard> {
339    if n_shards == 0 || labels.is_empty() {
340        return Vec::new();
341    }
342
343    // Group indices by class.
344    let max_class = labels.iter().copied().max().unwrap_or(0);
345    let mut class_indices: Vec<Vec<usize>> = vec![Vec::new(); max_class + 1];
346    for (i, &label) in labels.iter().enumerate() {
347        class_indices[label].push(i);
348    }
349
350    // Optionally shuffle within each class using a per-class seed.
351    if shuffle {
352        for (cls, indices) in class_indices.iter_mut().enumerate() {
353            let class_seed = seed.wrapping_add(cls as u64 * 0x9e37_79b9_7f4a_7c15);
354            let shuffled = consistent_shuffle(indices.len(), class_seed);
355            let original = indices.clone();
356            for (new_pos, &old_pos) in shuffled.iter().enumerate() {
357                indices[new_pos] = original[old_pos];
358            }
359        }
360    }
361
362    // Build shard buckets.
363    let mut buckets: Vec<Vec<usize>> = vec![Vec::new(); n_shards];
364    for class_idx in class_indices {
365        // Round-robin assignment within this class.
366        for (pos, sample_idx) in class_idx.into_iter().enumerate() {
367            buckets[pos % n_shards].push(sample_idx);
368        }
369    }
370
371    buckets
372        .into_iter()
373        .enumerate()
374        .map(|(shard_id, indices)| DataShard {
375            shard_id,
376            n_shards,
377            indices,
378            is_train: true,
379        })
380        .collect()
381}
382
383// ─────────────────────────────────────────────────────────────────────────────
384// ShardedDataset impl
385// ─────────────────────────────────────────────────────────────────────────────
386
387impl ShardedDataset {
388    /// Build a [`ShardedDataset`] from a dataset of `n_samples` samples and the
389    /// given configuration.
390    ///
391    /// For [`ShardStrategy::Stratified`] use [`ShardedDataset::new_stratified`]
392    /// instead, because labels must be supplied externally.
393    pub fn new(n_samples: usize, config: ShardingConfig) -> Result<Self> {
394        if config.n_shards == 0 {
395            return Err(DatasetsError::InvalidFormat("n_shards must be >= 1".into()));
396        }
397        if n_samples == 0 {
398            return Err(DatasetsError::InvalidFormat(
399                "n_samples must be >= 1".into(),
400            ));
401        }
402
403        let shards = match &config.strategy {
404            ShardStrategy::Index => {
405                shard_by_index(n_samples, config.n_shards, config.shuffle, config.seed)
406            }
407            ShardStrategy::Hash => shard_by_hash(n_samples, config.n_shards),
408            ShardStrategy::Stratified { .. } => {
409                return Err(DatasetsError::InvalidFormat(
410                    "Use ShardedDataset::new_stratified for Stratified strategy".into(),
411                ));
412            }
413            ShardStrategy::Size { shard_size_bytes } => {
414                // Estimate: assume each sample is `n_samples`-byte rows (fallback to index).
415                // For a proper size-based split the caller must know the row size.
416                // Here we approximate by treating `shard_size_bytes / (n_samples / n_samples)`
417                // and fall back to uniform index sharding with the configured n_shards.
418                let _ = shard_size_bytes; // informational only at this level
419                shard_by_index(n_samples, config.n_shards, config.shuffle, config.seed)
420            }
421        };
422
423        Ok(Self {
424            shards,
425            total_size: n_samples,
426            config,
427        })
428    }
429
430    /// Build a [`ShardedDataset`] using stratified sharding with explicit `labels`.
431    pub fn new_stratified(labels: &[usize], config: ShardingConfig) -> Result<Self> {
432        if config.n_shards == 0 {
433            return Err(DatasetsError::InvalidFormat("n_shards must be >= 1".into()));
434        }
435        if labels.is_empty() {
436            return Err(DatasetsError::InvalidFormat(
437                "labels must not be empty".into(),
438            ));
439        }
440
441        let shards = shard_stratified(labels, config.n_shards, config.shuffle, config.seed);
442        let total_size = labels.len();
443
444        Ok(Self {
445            shards,
446            total_size,
447            config,
448        })
449    }
450
451    /// Look up a shard by its identifier.
452    pub fn get_shard(&self, shard_id: usize) -> Option<&DataShard> {
453        self.shards.get(shard_id)
454    }
455
456    /// Partition shard identifiers into a (train, validation) split.
457    ///
458    /// The last `ceil(n_shards * val_fraction)` shard IDs are used as the
459    /// validation set; the rest are used for training.
460    pub fn train_shards(&self, val_fraction: f64) -> (Vec<usize>, Vec<usize>) {
461        let n = self.shards.len();
462        if n == 0 {
463            return (Vec::new(), Vec::new());
464        }
465        let n_val = ((n as f64 * val_fraction).ceil() as usize).min(n);
466        let n_train = n - n_val;
467        let train_ids: Vec<usize> = (0..n_train).collect();
468        let val_ids: Vec<usize> = (n_train..n).collect();
469        (train_ids, val_ids)
470    }
471
472    /// Return an iterator over the sample indices of shard `shard_id`.
473    ///
474    /// Returns an empty iterator if `shard_id` is out of range.
475    pub fn shard_iter(&self, shard_id: usize) -> impl Iterator<Item = usize> + '_ {
476        let slice: &[usize] = match self.shards.get(shard_id) {
477            Some(shard) => &shard.indices,
478            None => &[],
479        };
480        slice.iter().copied()
481    }
482
483    /// Total number of shards.
484    pub fn n_shards(&self) -> usize {
485        self.shards.len()
486    }
487
488    /// Total number of samples across all shards.
489    pub fn total_samples(&self) -> usize {
490        self.shards.iter().map(|s| s.indices.len()).sum()
491    }
492}
493
494// ─────────────────────────────────────────────────────────────────────────────
495// Data-carrying shard types and functions
496// ─────────────────────────────────────────────────────────────────────────────
497
/// A shard that carries its samples' feature vectors and labels alongside their indices.
#[derive(Debug, Clone)]
pub struct DatasetShard {
    /// Position of this shard among its siblings, starting at 0.
    pub shard_id: usize,
    /// Number of shards the dataset was split into.
    pub total_shards: usize,
    /// Indices of this shard's samples in the original dataset.
    pub indices: Vec<usize>,
    /// Feature vectors of this shard's samples.
    pub data: Vec<Vec<f64>>,
    /// Labels of this shard's samples.
    pub labels: Vec<usize>,
}

impl DatasetShard {
    /// Number of samples in this shard, counted by `indices` (not by `data` rows).
    pub fn len(&self) -> usize {
        self.indices.len()
    }

    /// `true` when the shard has no indices.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Pick the rows of `data` named by this shard's indices, in index order.
    ///
    /// Indices past the end of `data` are skipped silently.
    pub fn apply_f64(&self, data: &[Vec<f64>]) -> Vec<Vec<f64>> {
        self.indices
            .iter()
            .filter_map(|&i| data.get(i).cloned())
            .collect()
    }

    /// Pick the entries of `labels` named by this shard's indices, in index order.
    ///
    /// Indices past the end of `labels` are skipped silently.
    pub fn apply_labels(&self, labels: &[usize]) -> Vec<usize> {
        self.indices
            .iter()
            .filter_map(|&i| labels.get(i).copied())
            .collect()
    }
}
548
549// ─────────────────────────────────────────────────────────────────────────────
550// ShardedLoader — consistent shuffled sharding for distributed training
551// ─────────────────────────────────────────────────────────────────────────────
552
553/// A loader that partitions a dataset into consistently shuffled shards for
554/// multi-process or multi-node distributed training.
555///
556/// Given the same `seed`, `total_samples`, and `n_shards`, every call to
557/// [`ShardedLoader::get_shard`] with the same arguments will return the same
558/// `DatasetShard`, making the assignment deterministic and reproducible across
559/// independent processes.
560///
561/// ## Example
562///
563/// ```rust
564/// use scirs2_datasets::ShardedLoader;
565///
566/// let loader = ShardedLoader::new(100, 4, 42);
567/// assert!(loader.verify_coverage());
568///
569/// let shard0 = loader.get_shard(0);
570/// let shard1 = loader.get_shard(1);
571/// // No overlap between shards.
572/// for &i in &shard0.indices {
573///     assert!(!shard1.indices.contains(&i));
574/// }
575/// ```
576#[derive(Debug, Clone)]
577pub struct ShardedLoader {
578    /// Total number of samples in the dataset.
579    pub total_samples: usize,
580    /// Number of shards to partition into.
581    pub n_shards: usize,
582    /// Seed used for the deterministic shuffle.
583    pub seed: u64,
584}
585
586impl ShardedLoader {
587    /// Create a new `ShardedLoader`.
588    ///
589    /// * `total_samples` — number of samples in the full dataset.
590    /// * `n_shards` — number of shards to divide the dataset into.
591    /// * `seed` — seed for the deterministic Fisher-Yates shuffle.
592    pub fn new(total_samples: usize, n_shards: usize, seed: u64) -> Self {
593        Self {
594            total_samples,
595            n_shards,
596            seed,
597        }
598    }
599
600    /// Compute the global shuffled permutation of all sample indices.
601    ///
602    /// Calling this function with the same `seed` always returns the same
603    /// permutation, regardless of which process calls it.
604    pub fn global_permutation(&self) -> Vec<usize> {
605        consistent_shuffle(self.total_samples, self.seed)
606    }
607
608    /// Return shard `shard_id` (0-indexed) from the consistently shuffled
609    /// partition of the dataset.
610    ///
611    /// If `shard_id >= self.n_shards` an empty shard is returned.
612    pub fn get_shard(&self, shard_id: usize) -> DatasetShard {
613        if self.n_shards == 0 || self.total_samples == 0 || shard_id >= self.n_shards {
614            return DatasetShard {
615                shard_id,
616                total_shards: self.n_shards,
617                indices: Vec::new(),
618                data: Vec::new(),
619                labels: Vec::new(),
620            };
621        }
622
623        let permuted = self.global_permutation();
624        let base = self.total_samples / self.n_shards;
625        let remainder = self.total_samples % self.n_shards;
626
627        // Determine start and end offsets for this shard.
628        let mut offset = 0usize;
629        for id in 0..shard_id {
630            let extra = if id < remainder { 1 } else { 0 };
631            offset += base + extra;
632        }
633        let extra = if shard_id < remainder { 1 } else { 0 };
634        let size = base + extra;
635
636        let indices = permuted[offset..offset + size].to_vec();
637
638        DatasetShard {
639            shard_id,
640            total_shards: self.n_shards,
641            indices,
642            data: Vec::new(),
643            labels: Vec::new(),
644        }
645    }
646
647    /// Verify that the union of all shard index sets covers every sample index
648    /// exactly once (no gaps, no duplicates).
649    ///
650    /// Returns `true` when coverage is complete and disjoint.
651    pub fn verify_coverage(&self) -> bool {
652        if self.n_shards == 0 || self.total_samples == 0 {
653            return self.total_samples == 0;
654        }
655
656        let mut seen = vec![false; self.total_samples];
657        for shard_id in 0..self.n_shards {
658            let shard = self.get_shard(shard_id);
659            for &idx in &shard.indices {
660                if idx >= self.total_samples || seen[idx] {
661                    return false;
662                }
663                seen[idx] = true;
664            }
665        }
666        seen.iter().all(|&v| v)
667    }
668}
669
670/// Split a dataset into `n_shards` equal parts, optionally shuffled.
671///
672/// Each shard receives a contiguous slice of the (possibly shuffled) index
673/// order, along with the corresponding data and label rows.
674///
675/// # Errors
676///
677/// Returns an error if `data.len() != labels.len()` or `n_shards == 0`.
678pub fn shard_dataset(
679    data: &[Vec<f64>],
680    labels: &[usize],
681    n_shards: usize,
682    seed: u64,
683) -> Result<Vec<DatasetShard>> {
684    let n = data.len();
685    if n != labels.len() {
686        return Err(DatasetsError::InvalidFormat(format!(
687            "data length ({}) != labels length ({})",
688            n,
689            labels.len()
690        )));
691    }
692    if n_shards == 0 {
693        return Err(DatasetsError::InvalidFormat("n_shards must be >= 1".into()));
694    }
695    if n == 0 {
696        return Ok(Vec::new());
697    }
698
699    let index_shards = shard_by_index(n, n_shards, true, seed);
700    Ok(build_dataset_shards(data, labels, &index_shards))
701}
702
703/// Split a dataset into `n_shards` shards that maintain per-shard label
704/// distribution matching the global distribution.
705///
706/// # Errors
707///
708/// Returns an error if `data.len() != labels.len()` or `n_shards == 0`.
709pub fn stratified_shard(
710    data: &[Vec<f64>],
711    labels: &[usize],
712    n_shards: usize,
713) -> Result<Vec<DatasetShard>> {
714    let n = data.len();
715    if n != labels.len() {
716        return Err(DatasetsError::InvalidFormat(format!(
717            "data length ({}) != labels length ({})",
718            n,
719            labels.len()
720        )));
721    }
722    if n_shards == 0 {
723        return Err(DatasetsError::InvalidFormat("n_shards must be >= 1".into()));
724    }
725    if n == 0 {
726        return Ok(Vec::new());
727    }
728
729    let index_shards = shard_stratified(labels, n_shards, false, 0);
730    Ok(build_dataset_shards(data, labels, &index_shards))
731}
732
733/// Split a dataset into `n_shards` shards with consistent random shuffling.
734///
735/// Uses a seeded shuffle so the same seed always produces the same assignment.
736///
737/// # Errors
738///
739/// Returns an error if `data.len() != labels.len()` or `n_shards == 0`.
740pub fn shuffled_shard(
741    data: &[Vec<f64>],
742    labels: &[usize],
743    n_shards: usize,
744    seed: u64,
745) -> Result<Vec<DatasetShard>> {
746    shard_dataset(data, labels, n_shards, seed)
747}
748
749/// Reconstruct a full dataset from a collection of shards.
750///
751/// Samples are reassembled in index order when possible; otherwise they
752/// appear in the order encountered across shards.
753pub fn merge_shards(shards: &[DatasetShard]) -> (Vec<Vec<f64>>, Vec<usize>) {
754    if shards.is_empty() {
755        return (Vec::new(), Vec::new());
756    }
757
758    // Collect all (index, data, label) triples.
759    let mut entries: Vec<(usize, &Vec<f64>, usize)> = Vec::new();
760    for shard in shards {
761        for (pos, &idx) in shard.indices.iter().enumerate() {
762            entries.push((idx, &shard.data[pos], shard.labels[pos]));
763        }
764    }
765
766    // Sort by original index for deterministic reconstruction.
767    entries.sort_by_key(|(idx, _, _)| *idx);
768
769    let data: Vec<Vec<f64>> = entries.iter().map(|(_, d, _)| (*d).clone()).collect();
770    let labels: Vec<usize> = entries.iter().map(|(_, _, l)| *l).collect();
771    (data, labels)
772}
773
774/// Internal helper: convert index-only shards into data-carrying DatasetShards.
775fn build_dataset_shards(
776    data: &[Vec<f64>],
777    labels: &[usize],
778    index_shards: &[DataShard],
779) -> Vec<DatasetShard> {
780    index_shards
781        .iter()
782        .map(|is| {
783            let shard_data: Vec<Vec<f64>> = is.indices.iter().map(|&i| data[i].clone()).collect();
784            let shard_labels: Vec<usize> = is.indices.iter().map(|&i| labels[i]).collect();
785            DatasetShard {
786                shard_id: is.shard_id,
787                total_shards: is.n_shards,
788                indices: is.indices.clone(),
789                data: shard_data,
790                labels: shard_labels,
791            }
792        })
793        .collect()
794}
795
796// ─────────────────────────────────────────────────────────────────────────────
797// Tests
798// ─────────────────────────────────────────────────────────────────────────────
799
800#[cfg(test)]
801mod tests {
802    use super::*;
803
804    #[test]
805    fn test_shard_by_index_no_shuffle() {
806        let shards = shard_by_index(100, 4, false, 0);
807        assert_eq!(shards.len(), 4);
808        for shard in &shards {
809            assert_eq!(shard.indices.len(), 25);
810        }
811        // All indices covered exactly once.
812        let mut seen = [false; 100];
813        for shard in &shards {
814            for &i in &shard.indices {
815                assert!(!seen[i], "index {i} appears twice");
816                seen[i] = true;
817            }
818        }
819        assert!(seen.iter().all(|&v| v));
820    }
821
822    #[test]
823    fn test_shard_by_index_shuffle() {
824        let shards = shard_by_index(100, 4, true, 42);
825        assert_eq!(shards.len(), 4);
826        let total: usize = shards.iter().map(|s| s.len()).sum();
827        assert_eq!(total, 100);
828    }
829
830    #[test]
831    fn test_consistent_shuffle_determinism() {
832        let a = consistent_shuffle(50, 12345);
833        let b = consistent_shuffle(50, 12345);
834        assert_eq!(a, b);
835        // Different seed → different order (with overwhelming probability).
836        let c = consistent_shuffle(50, 99999);
837        assert_ne!(a, c);
838    }
839
840    #[test]
841    fn test_consistent_shuffle_permutation() {
842        let n = 200;
843        let shuffled = consistent_shuffle(n, 7);
844        assert_eq!(shuffled.len(), n);
845        let mut sorted = shuffled.clone();
846        sorted.sort_unstable();
847        assert_eq!(sorted, (0..n).collect::<Vec<_>>());
848    }
849
850    #[test]
851    fn test_shard_by_hash() {
852        let shards = shard_by_hash(100, 4);
853        assert_eq!(shards.len(), 4);
854        // Shard 0 contains indices 0,4,8,...
855        assert!(shards[0].indices.iter().all(|&i| i % 4 == 0));
856        let total: usize = shards.iter().map(|s| s.len()).sum();
857        assert_eq!(total, 100);
858    }
859
860    #[test]
861    fn test_stratified_class_proportions() {
862        // 50 samples: 30 in class 0, 20 in class 1.
863        let mut labels = vec![0usize; 30];
864        labels.extend(vec![1usize; 20]);
865        let shards = shard_stratified(&labels, 5, false, 0);
866        assert_eq!(shards.len(), 5);
867        // Each shard should have 10 samples (6 from class-0, 4 from class-1).
868        for shard in &shards {
869            assert_eq!(shard.indices.len(), 10);
870        }
871    }
872
873    #[test]
874    fn test_sharded_dataset_new() {
875        let config = ShardingConfig {
876            n_shards: 4,
877            strategy: ShardStrategy::Index,
878            shuffle: false,
879            seed: 0,
880        };
881        let ds = ShardedDataset::new(100, config).expect("should succeed");
882        assert_eq!(ds.n_shards(), 4);
883        assert_eq!(ds.total_samples(), 100);
884    }
885
886    #[test]
887    fn test_train_shards_split() {
888        let config = ShardingConfig {
889            n_shards: 8,
890            strategy: ShardStrategy::Index,
891            shuffle: false,
892            seed: 0,
893        };
894        let ds = ShardedDataset::new(80, config).expect("should succeed");
895        let (train, val) = ds.train_shards(0.25);
896        assert_eq!(train.len() + val.len(), 8);
897        assert_eq!(val.len(), 2); // ceil(8 * 0.25) = 2
898    }
899
900    #[test]
901    fn test_shard_iter() {
902        let config = ShardingConfig {
903            n_shards: 4,
904            strategy: ShardStrategy::Index,
905            shuffle: false,
906            seed: 0,
907        };
908        let ds = ShardedDataset::new(40, config).expect("should succeed");
909        let collected: Vec<usize> = ds.shard_iter(0).collect();
910        assert_eq!(collected.len(), 10);
911        // shard 0 contains indices 0..10 (no shuffle).
912        assert_eq!(collected, (0..10).collect::<Vec<_>>());
913    }
914
915    #[test]
916    fn test_shard_iter_out_of_bounds() {
917        let config = ShardingConfig::default();
918        let ds = ShardedDataset::new(10, config).expect("should succeed");
919        let empty: Vec<usize> = ds.shard_iter(999).collect();
920        assert!(empty.is_empty());
921    }
922
923    #[test]
924    fn test_sharded_dataset_invalid_config() {
925        let bad_config = ShardingConfig {
926            n_shards: 0,
927            ..Default::default()
928        };
929        assert!(ShardedDataset::new(100, bad_config).is_err());
930    }
931
932    #[test]
933    fn test_shard_id_assignment() {
934        let shards = shard_by_index(100, 4, false, 0);
935        for (expected_id, shard) in shards.iter().enumerate() {
936            assert_eq!(shard.shard_id, expected_id);
937            assert_eq!(shard.n_shards, 4);
938        }
939    }
940
941    #[test]
942    fn test_stratified_new_stratified() {
943        let labels: Vec<usize> = (0..60).map(|i| i % 3).collect();
944        let config = ShardingConfig {
945            n_shards: 3,
946            strategy: ShardStrategy::Stratified {
947                label_column: "class".into(),
948            },
949            shuffle: false,
950            seed: 0,
951        };
952        let ds = ShardedDataset::new_stratified(&labels, config).expect("ok");
953        assert_eq!(ds.n_shards(), 3);
954        assert_eq!(ds.total_samples(), 60);
955    }
956
957    // ── Data-carrying shard tests ──────────────────────────────────────────
958
959    fn make_test_data(n: usize) -> (Vec<Vec<f64>>, Vec<usize>) {
960        let data: Vec<Vec<f64>> = (0..n).map(|i| vec![i as f64, (i * 2) as f64]).collect();
961        let labels: Vec<usize> = (0..n).map(|i| i % 3).collect();
962        (data, labels)
963    }
964
965    #[test]
966    fn test_shard_dataset_total_samples() {
967        let (data, labels) = make_test_data(100);
968        let shards = shard_dataset(&data, &labels, 4, 42).expect("ok");
969        assert_eq!(shards.len(), 4);
970        let total: usize = shards.iter().map(|s| s.len()).sum();
971        assert_eq!(total, 100);
972    }
973
974    #[test]
975    fn test_stratified_shard_label_proportions() {
976        // 60 class-0, 40 class-1
977        let mut labels = vec![0usize; 60];
978        labels.extend(vec![1usize; 40]);
979        let data: Vec<Vec<f64>> = (0..100).map(|i| vec![i as f64]).collect();
980        let shards = stratified_shard(&data, &labels, 5).expect("ok");
981        assert_eq!(shards.len(), 5);
982        for shard in &shards {
983            let c0 = shard.labels.iter().filter(|&&l| l == 0).count();
984            let c1 = shard.labels.iter().filter(|&&l| l == 1).count();
985            // Each shard should have 12 class-0 and 8 class-1
986            assert_eq!(c0, 12, "Expected 12 class-0 per shard, got {c0}");
987            assert_eq!(c1, 8, "Expected 8 class-1 per shard, got {c1}");
988        }
989    }
990
991    #[test]
992    fn test_merge_shards_recovers_data() {
993        let (data, labels) = make_test_data(50);
994        let shards = shard_dataset(&data, &labels, 5, 99).expect("ok");
995        let (merged_data, merged_labels) = merge_shards(&shards);
996        assert_eq!(merged_data.len(), 50);
997        assert_eq!(merged_labels.len(), 50);
998        // After merge (sorted by index), should match original.
999        for i in 0..50 {
1000            assert_eq!(merged_data[i], data[i], "Data mismatch at index {i}");
1001            assert_eq!(merged_labels[i], labels[i], "Label mismatch at index {i}");
1002        }
1003    }
1004
1005    #[test]
1006    fn test_shuffled_shard_determinism() {
1007        let (data, labels) = make_test_data(30);
1008        let s1 = shuffled_shard(&data, &labels, 3, 42).expect("ok");
1009        let s2 = shuffled_shard(&data, &labels, 3, 42).expect("ok");
1010        for (a, b) in s1.iter().zip(s2.iter()) {
1011            assert_eq!(a.indices, b.indices);
1012        }
1013    }
1014
1015    #[test]
1016    fn test_shard_dataset_error_on_mismatch() {
1017        let data = vec![vec![1.0]; 10];
1018        let labels = vec![0; 5];
1019        assert!(shard_dataset(&data, &labels, 2, 0).is_err());
1020    }
1021
1022    #[test]
1023    fn test_merge_empty_shards() {
1024        let (data, labels) = merge_shards(&[]);
1025        assert!(data.is_empty());
1026        assert!(labels.is_empty());
1027    }
1028
1029    // ── ShardedLoader tests ────────────────────────────────────────────────
1030
1031    /// verify_coverage returns true for 100 samples divided into 4 shards.
1032    #[test]
1033    fn test_sharded_loader_verify_coverage() {
1034        let loader = ShardedLoader::new(100, 4, 42);
1035        assert!(
1036            loader.verify_coverage(),
1037            "all 100 samples should be covered"
1038        );
1039    }
1040
1041    /// Shard sizes should differ by at most 1 (balanced sharding).
1042    #[test]
1043    fn test_sharded_loader_balanced_sizes() {
1044        let loader = ShardedLoader::new(101, 4, 7); // 101 not divisible by 4
1045        let sizes: Vec<usize> = (0..4).map(|id| loader.get_shard(id).len()).collect();
1046        let min_size = *sizes.iter().min().expect("non-empty");
1047        let max_size = *sizes.iter().max().expect("non-empty");
1048        assert!(
1049            max_size - min_size <= 1,
1050            "shard sizes differ by more than 1: {sizes:?}"
1051        );
1052        let total: usize = sizes.iter().sum();
1053        assert_eq!(total, 101, "total should equal n_samples");
1054    }
1055
1056    /// Shard 0 and shard 1 indices must be disjoint.
1057    #[test]
1058    fn test_sharded_loader_disjoint_shards() {
1059        let loader = ShardedLoader::new(100, 4, 99);
1060        let shard0 = loader.get_shard(0);
1061        let shard1 = loader.get_shard(1);
1062        for &i in &shard0.indices {
1063            assert!(
1064                !shard1.indices.contains(&i),
1065                "index {i} appears in both shard 0 and shard 1"
1066            );
1067        }
1068    }
1069
1070    /// Same seed must produce the same permutation on every call.
1071    #[test]
1072    fn test_sharded_loader_same_seed_same_permutation() {
1073        let loader = ShardedLoader::new(100, 4, 12345);
1074        let p1 = loader.global_permutation();
1075        let p2 = loader.global_permutation();
1076        assert_eq!(p1, p2, "same seed should give same permutation");
1077
1078        let loader2 = ShardedLoader::new(100, 4, 12345);
1079        let p3 = loader2.global_permutation();
1080        assert_eq!(p1, p3, "independent loader with same seed should match");
1081    }
1082
1083    /// apply_f64 returns the correct number of rows.
1084    #[test]
1085    fn test_dataset_shard_apply_f64() {
1086        let data: Vec<Vec<f64>> = (0..100).map(|i| vec![i as f64, (i * 2) as f64]).collect();
1087        let loader = ShardedLoader::new(100, 4, 42);
1088        let shard = loader.get_shard(0);
1089        let subset = shard.apply_f64(&data);
1090        assert_eq!(
1091            subset.len(),
1092            shard.len(),
1093            "apply_f64 should return exactly shard.len() rows"
1094        );
1095        // Each row in subset should have 2 features.
1096        for row in &subset {
1097            assert_eq!(row.len(), 2, "each row should have 2 features");
1098        }
1099    }
1100
1101    /// apply_labels returns the correct number of labels.
1102    #[test]
1103    fn test_dataset_shard_apply_labels() {
1104        let labels: Vec<usize> = (0..100).map(|i| i % 3).collect();
1105        let loader = ShardedLoader::new(100, 4, 42);
1106        let shard = loader.get_shard(2);
1107        let subset = shard.apply_labels(&labels);
1108        assert_eq!(
1109            subset.len(),
1110            shard.len(),
1111            "apply_labels should return exactly shard.len() labels"
1112        );
1113    }
1114
1115    /// Verify coverage works for edge case: 1 shard.
1116    #[test]
1117    fn test_sharded_loader_single_shard_coverage() {
1118        let loader = ShardedLoader::new(50, 1, 0);
1119        assert!(loader.verify_coverage());
1120        let shard = loader.get_shard(0);
1121        assert_eq!(shard.len(), 50);
1122    }
1123
1124    /// Out-of-range shard_id returns an empty shard.
1125    #[test]
1126    fn test_sharded_loader_out_of_range_shard() {
1127        let loader = ShardedLoader::new(100, 4, 42);
1128        let empty_shard = loader.get_shard(99);
1129        assert!(
1130            empty_shard.is_empty(),
1131            "out-of-range shard_id should give empty shard"
1132        );
1133    }
1134}