Skip to main content

entrenar/config/train/batches/
streaming.rs

1#![allow(dead_code)]
2//! Streaming Parquet data loader with file-level sharding for distributed training.
3//!
4//! # Architecture
5//!
6//! For DDP pretraining, each worker loads a disjoint subset of Parquet files.
7//! File-level sharding avoids duplicate samples across workers and is simpler
8//! than sequence-level sharding (no coordination needed).
9//!
//! Worker `rank` loads the files whose index `i` in the sorted file list satisfies `i % world_size == rank`.
11//!
12//! # Contract
13//!
14//! C-SHARD-001: Disjointness — no file is assigned to two workers.
15//! C-SHARD-001: Completeness — every file is assigned to exactly one worker.
16
17use std::collections::VecDeque;
18use std::path::{Path, PathBuf};
19
/// Configuration for data sharding across distributed workers.
///
/// A valid configuration has `world_size >= 1` and `rank < world_size`.
/// `seed` is shared by all workers so epoch shuffles are reproducible.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ShardConfig {
    /// This worker's global rank (`0..world_size`)
    pub rank: usize,
    /// Total number of workers
    pub world_size: usize,
    /// Base random seed for epoch shuffling
    pub seed: u64,
}

impl ShardConfig {
    /// Create a single-worker (no sharding) config: rank 0 of 1, seed 42.
    #[must_use]
    pub fn single() -> Self {
        Self { rank: 0, world_size: 1, seed: 42 }
    }
}
37
38/// Streaming Parquet data loader with prefetch and file-level sharding.
39///
40/// Loads Parquet files lazily, keeping only a bounded buffer of batches
41/// in memory. Supports epoch-level reshuffling while maintaining shard
42/// assignment invariants.
43///
44/// # Example
45///
46/// ```ignore
47/// let loader = StreamingParquetLoader::new(
48///     &data_dir,
49///     ShardConfig { rank: 0, world_size: 2, seed: 42 },
50///     4,    // batch_size
51///     2048, // seq_len
52/// )?;
53/// ```
54#[derive(Debug)]
55pub struct StreamingParquetLoader {
56    /// All Parquet files discovered in the data directory
57    all_files: Vec<PathBuf>,
58    /// Files assigned to this worker (after sharding)
59    my_files: Vec<PathBuf>,
60    /// Shard configuration
61    shard_config: ShardConfig,
62    /// Batch size for LMBatch construction
63    batch_size: usize,
64    /// Sequence length
65    seq_len: usize,
66    /// Buffer of pre-loaded sequences (token ID vectors)
67    buffer: VecDeque<Vec<u32>>,
68    /// Index of next file to load from `my_files`
69    next_file_idx: usize,
70    /// Current epoch (for shuffling)
71    epoch: usize,
72}
73
74impl StreamingParquetLoader {
75    /// Create a new streaming loader.
76    ///
77    /// Discovers all `.parquet` files in `data_dir`, assigns a subset to
78    /// this worker based on `shard_config`, and prepares for iteration.
79    ///
80    /// # Errors
81    ///
82    /// Returns `Err` if:
83    /// - `data_dir` doesn't exist or is unreadable
84    /// - Fewer files than `world_size` (C-SHARD-001 violation)
85    pub fn new(
86        data_dir: &Path,
87        shard_config: ShardConfig,
88        batch_size: usize,
89        seq_len: usize,
90    ) -> Result<Self, String> {
91        let mut all_files = discover_parquet_files(data_dir)?;
92        all_files.sort(); // Deterministic ordering
93
94        if all_files.len() < shard_config.world_size {
95            return Err(format!(
96                "insufficient files for sharding: {} files < {} workers (C-SHARD-001)",
97                all_files.len(),
98                shard_config.world_size,
99            ));
100        }
101
102        let my_files = shard_files(&all_files, shard_config.rank, shard_config.world_size);
103
104        Ok(Self {
105            all_files,
106            my_files,
107            shard_config,
108            batch_size,
109            seq_len,
110            buffer: VecDeque::new(),
111            next_file_idx: 0,
112            epoch: 0,
113        })
114    }
115
116    /// Number of files assigned to this worker.
117    pub fn num_files(&self) -> usize {
118        self.my_files.len()
119    }
120
121    /// Total number of files across all workers.
122    pub fn total_files(&self) -> usize {
123        self.all_files.len()
124    }
125
126    /// Get the files assigned to this worker.
127    pub fn my_files(&self) -> &[PathBuf] {
128        &self.my_files
129    }
130
131    /// Reset for a new epoch, reshuffling file order.
132    pub fn reset_epoch(&mut self, epoch: usize) {
133        self.epoch = epoch;
134        self.next_file_idx = 0;
135        self.buffer.clear();
136        // Shuffle file order using epoch-specific seed
137        shuffle_files(&mut self.my_files, self.shard_config.seed, epoch);
138    }
139
140    /// Get batch size.
141    pub fn batch_size(&self) -> usize {
142        self.batch_size
143    }
144
145    /// Get sequence length.
146    pub fn seq_len(&self) -> usize {
147        self.seq_len
148    }
149
150    /// Load the next shard file and return its `LMBatch`es.
151    ///
152    /// Returns `Ok(None)` when all shards for this epoch are exhausted.
153    /// Each call loads one Parquet file, extracts pre-tokenized sequences,
154    /// creates `LMBatch`es, and drops the raw data. Peak memory = one shard.
155    #[cfg(all(not(target_arch = "wasm32"), feature = "parquet"))]
156    pub fn next_batches(
157        &mut self,
158    ) -> std::result::Result<Option<Vec<crate::train::LMBatch>>, String> {
159        use crate::train::LMBatch;
160
161        if self.next_file_idx >= self.my_files.len() {
162            return Ok(None);
163        }
164
165        let path = &self.my_files[self.next_file_idx];
166        self.next_file_idx += 1;
167
168        // Load and extract pre-tokenized sequences from this shard
169        let sequences = load_pretokenized_from_parquet(path)?;
170
171        if sequences.is_empty() {
172            return Ok(Some(Vec::new()));
173        }
174
175        // Create LMBatches from sequences (sequences dropped after this)
176        let pad_id = 0u32;
177        let eos_id = 2u32;
178        let num_batches = sequences.len().div_ceil(self.batch_size);
179        let mut batches = Vec::with_capacity(num_batches);
180        for chunk in sequences.chunks(self.batch_size) {
181            batches.push(LMBatch::from_sequences(chunk, pad_id, eos_id));
182        }
183
184        Ok(Some(batches))
185    }
186
187    /// Check if all files have been consumed for this epoch.
188    pub fn is_epoch_exhausted(&self) -> bool {
189        self.next_file_idx >= self.my_files.len() && self.buffer.is_empty()
190    }
191
192    /// Resume data loading from a specific file index (ALB-120).
193    ///
194    /// After checkpoint restore, call this to skip to the correct position.
195    /// Contract: C-DATARESUME-001.
196    pub fn resume_from(&mut self, file_idx: usize) {
197        self.next_file_idx = file_idx.min(self.my_files.len());
198        self.buffer.clear();
199    }
200
201    /// Current file index (for checkpointing).
202    pub fn current_file_idx(&self) -> usize {
203        self.next_file_idx
204    }
205
206    /// Current epoch (for checkpointing).
207    pub fn current_epoch(&self) -> usize {
208        self.epoch
209    }
210}
211
/// Discover all `.parquet` files in a directory (non-recursive).
///
/// Only regular files (symlinks are followed) whose extension is exactly
/// `parquet` are returned. Directories whose names end in `.parquet` are
/// skipped, since they cannot be loaded as a single file. The order is
/// whatever `read_dir` yields; callers must sort for determinism.
///
/// # Errors
///
/// Returns `Err` if `dir` does not exist, cannot be read, an entry cannot be
/// read, or no matching files are found.
fn discover_parquet_files(dir: &Path) -> Result<Vec<PathBuf>, String> {
    if !dir.exists() {
        return Err(format!("data directory does not exist: {}", dir.display()));
    }

    let entries = std::fs::read_dir(dir)
        .map_err(|e| format!("failed to read directory {}: {e}", dir.display()))?;

    let mut files = Vec::new();
    for entry in entries {
        let entry = entry.map_err(|e| format!("failed to read dir entry: {e}"))?;
        let path = entry.path();
        let has_parquet_ext = path.extension().and_then(|e| e.to_str()) == Some("parquet");
        // `Path::is_file` follows symlinks, so symlinked shards are still found.
        if has_parquet_ext && path.is_file() {
            files.push(path);
        }
    }

    if files.is_empty() {
        return Err(format!("no .parquet files found in {}", dir.display()));
    }

    Ok(files)
}
236
/// Select the files that belong to worker `rank` under modular sharding.
///
/// The file at position `idx` in `all_files` is owned by worker
/// `idx % world_size`; the returned files keep their input order.
///
/// # Contract (C-SHARD-001)
///
/// - Disjointness: `shard_files(_, r1, N) ∩ shard_files(_, r2, N) == ∅` for r1 ≠ r2
/// - Completeness: `∪_{r=0}^{N-1} shard_files(_, r, N) == all_files`
///
/// # Panics
///
/// Panics (division by zero) if `world_size == 0` and `all_files` is non-empty.
fn shard_files(all_files: &[PathBuf], rank: usize, world_size: usize) -> Vec<PathBuf> {
    let mut assigned = Vec::new();
    for (idx, file) in all_files.iter().enumerate() {
        if idx % world_size == rank {
            assigned.push(file.clone());
        }
    }
    assigned
}
253
/// Permute `files` in place, reproducibly, from `base_seed + epoch`.
///
/// Runs a Fisher-Yates pass from the back of the slice. Each swap partner is
/// drawn from the high bits of a 64-bit LCG (Knuth's MMIX constants), so the
/// same `(base_seed, epoch)` always yields the same permutation.
fn shuffle_files(files: &mut [PathBuf], base_seed: u64, epoch: usize) {
    const LCG_MULTIPLIER: u64 = 6_364_136_223_846_793_005;
    const LCG_INCREMENT: u64 = 1_442_695_040_888_963_407;

    let mut state = base_seed.wrapping_add(epoch as u64);
    let mut upper = files.len();
    // Visits upper = len-1, len-2, ..., 1; slices of length 0 or 1 are untouched.
    while upper > 1 {
        upper -= 1;
        state = state.wrapping_mul(LCG_MULTIPLIER).wrapping_add(LCG_INCREMENT);
        let pick = (state >> 33) as usize % (upper + 1);
        files.swap(upper, pick);
    }
}
268
/// Load pre-tokenized sequences from a single Parquet file.
///
/// Uses the first `input_ids` or `token_ids` column (in schema order), which
/// must be a `List` or `LargeList` of integers. Null rows and empty sequences
/// are skipped. Returns the sequences as `Vec<Vec<u32>>`. The `ArrowDataset`
/// is dropped before returning, so only the extracted token IDs remain in memory.
///
/// # Errors
///
/// Returns `Err` if the file cannot be loaded, has no token column, the token
/// column is not a list type, or any token ID cannot be represented exactly as
/// a `u32` (see [`extract_u32_values`]).
#[cfg(all(not(target_arch = "wasm32"), feature = "parquet"))]
fn load_pretokenized_from_parquet(path: &Path) -> std::result::Result<Vec<Vec<u32>>, String> {
    use alimentar::{ArrowDataset, Dataset};
    use arrow::array::{Array, GenericListArray, LargeListArray, ListArray, OffsetSizeTrait};

    /// Append every non-null, non-empty row of a (Large)List column to `out`.
    fn push_rows<O: OffsetSizeTrait>(
        list: &GenericListArray<O>,
        out: &mut Vec<Vec<u32>>,
    ) -> std::result::Result<(), String> {
        for row in 0..list.len() {
            if list.is_null(row) {
                continue;
            }
            let seq = extract_u32_values(&*list.value(row))?;
            if !seq.is_empty() {
                out.push(seq);
            }
        }
        Ok(())
    }

    let dataset = ArrowDataset::from_parquet(path)
        .map_err(|e| format!("Failed to load parquet {}: {e}", path.display()))?;

    let schema = dataset.schema();
    let token_col = schema
        .fields()
        .iter()
        .map(|f| f.name().as_str())
        .find(|&n| n == "input_ids" || n == "token_ids")
        .ok_or_else(|| {
            format!("No pre-tokenized column (input_ids/token_ids) in {}", path.display())
        })?;

    let col_idx = schema.index_of(token_col).map_err(|e| format!("Column index error: {e}"))?;

    let mut sequences = Vec::with_capacity(dataset.len());

    for batch in dataset.iter() {
        let col = batch.column(col_idx);
        // Previously a non-`ListArray` column silently produced zero sequences,
        // dropping the whole shard; now `LargeList` is supported and anything
        // else is reported.
        let pushed = if let Some(list) = col.as_any().downcast_ref::<ListArray>() {
            push_rows(list, &mut sequences)
        } else if let Some(list) = col.as_any().downcast_ref::<LargeListArray>() {
            push_rows(list, &mut sequences)
        } else {
            Err(format!("expected a List or LargeList column, found {:?}", col.data_type()))
        };
        pushed.map_err(|e| format!("Invalid column '{token_col}' in {}: {e}", path.display()))?;
    }
    // dataset dropped here — Arrow memory freed

    Ok(sequences)
}

/// Extract `u32` token IDs from an Arrow array (inner values of a list row).
///
/// Accepts `UInt16`/`UInt32`/`UInt64` and `Int16`/`Int32`/`Int64` arrays.
/// Every element must be non-null and lie in `0..=u32::MAX`. Values are never
/// wrapped or truncated, so e.g. a `-100` ignore-label or an out-of-range
/// `Int64` becomes an error instead of a bogus token ID.
///
/// # Errors
///
/// Returns `Err` for an unsupported element type, a null element, or a value
/// that does not fit in `u32`.
#[cfg(all(not(target_arch = "wasm32"), feature = "parquet"))]
fn extract_u32_values(array: &dyn arrow::array::Array) -> std::result::Result<Vec<u32>, String> {
    use arrow::array::{
        Array, Int16Array, Int32Array, Int64Array, PrimitiveArray, UInt16Array, UInt32Array,
        UInt64Array,
    };
    use arrow::datatypes::ArrowPrimitiveType;

    /// Checked conversion of a null-free primitive array into `u32` IDs.
    fn convert<T: ArrowPrimitiveType>(
        arr: &PrimitiveArray<T>,
    ) -> std::result::Result<Vec<u32>, String>
    where
        u32: TryFrom<T::Native>,
    {
        // `values()` ignores the validity bitmap, so nulls must be rejected
        // explicitly or their (arbitrary) backing values would leak through.
        if arr.null_count() > 0 {
            return Err(format!("{} null token id(s) in sequence", arr.null_count()));
        }
        arr.values()
            .iter()
            .map(|&v| u32::try_from(v).map_err(|_| format!("token id {v:?} does not fit in u32")))
            .collect()
    }

    let any = array.as_any();
    if let Some(arr) = any.downcast_ref::<UInt32Array>() {
        convert(arr)
    } else if let Some(arr) = any.downcast_ref::<Int32Array>() {
        convert(arr)
    } else if let Some(arr) = any.downcast_ref::<Int64Array>() {
        convert(arr)
    } else if let Some(arr) = any.downcast_ref::<UInt16Array>() {
        convert(arr)
    } else if let Some(arr) = any.downcast_ref::<UInt64Array>() {
        convert(arr)
    } else if let Some(arr) = any.downcast_ref::<Int16Array>() {
        convert(arr)
    } else {
        Err(format!("unsupported token id type {:?}", array.data_type()))
    }
}
336
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;

    /// Build `n` in-memory paths `f0.parquet` .. `f{n-1}.parquet` (nothing on disk).
    fn fake_paths(n: usize) -> Vec<PathBuf> {
        (0..n).map(|i| PathBuf::from(format!("f{i}.parquet"))).collect()
    }

    /// Create a temp dir holding `n` placeholder `.parquet` files (not valid Parquet).
    fn create_temp_dir_with_files(n: usize) -> (tempfile::TempDir, Vec<PathBuf>) {
        let dir = tempfile::tempdir().expect("create temp dir");
        let mut files = Vec::new();
        for i in 0..n {
            let path = dir.path().join(format!("shard_{i:04}.parquet"));
            fs::write(&path, format!("fake parquet {i}")).expect("write file");
            files.push(path);
        }
        (dir, files)
    }

    /// Return a sorted copy of `files` for order-insensitive comparison.
    fn sorted(files: &[PathBuf]) -> Vec<PathBuf> {
        let mut out = files.to_vec();
        out.sort();
        out
    }

    #[test]
    fn test_shard_files_disjointness() {
        let files = fake_paths(10);
        let shards: Vec<Vec<PathBuf>> = (0..3).map(|r| shard_files(&files, r, 3)).collect();

        // Disjointness: no file appears in two shards
        for (a, shard_a) in shards.iter().enumerate() {
            for shard_b in &shards[a + 1..] {
                assert!(shard_a.iter().all(|f| !shard_b.contains(f)));
            }
        }

        // Completeness: the union of all shards is exactly the input set
        assert_eq!(sorted(&shards.concat()), sorted(&files));
    }

    #[test]
    fn test_shard_files_assignment() {
        let files = fake_paths(10);
        let pick = |idx: &[usize]| idx.iter().map(|&i| files[i].clone()).collect::<Vec<_>>();
        assert_eq!(shard_files(&files, 0, 3), pick(&[0, 3, 6, 9]));
        assert_eq!(shard_files(&files, 1, 3), pick(&[1, 4, 7]));
        assert_eq!(shard_files(&files, 2, 3), pick(&[2, 5, 8]));
    }

    #[test]
    fn test_shard_files_two_workers() {
        let files = fake_paths(7);
        let pick = |idx: &[usize]| idx.iter().map(|&i| files[i].clone()).collect::<Vec<_>>();
        assert_eq!(shard_files(&files, 0, 2), pick(&[0, 2, 4, 6]));
        assert_eq!(shard_files(&files, 1, 2), pick(&[1, 3, 5]));
    }

    #[test]
    fn test_shard_files_single_worker_gets_everything() {
        let files = fake_paths(5);
        assert_eq!(shard_files(&files, 0, 1), files);
    }

    #[test]
    fn test_discover_parquet_files() {
        let (dir, files) = create_temp_dir_with_files(5);
        // Add a non-parquet file
        fs::write(dir.path().join("readme.txt"), "not parquet").expect("write");
        let found = discover_parquet_files(dir.path()).expect("discover");
        assert_eq!(sorted(&found), sorted(&files));
    }

    #[test]
    fn test_discover_parquet_files_empty_dir() {
        let dir = tempfile::tempdir().expect("create temp dir");
        let result = discover_parquet_files(dir.path());
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("no .parquet files"));
    }

    #[test]
    fn test_streaming_loader_insufficient_files() {
        let (dir, _) = create_temp_dir_with_files(1);
        let config = ShardConfig { rank: 0, world_size: 2, seed: 42 };
        let result = StreamingParquetLoader::new(dir.path(), config, 4, 2048);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("insufficient files"));
    }

    #[test]
    fn test_streaming_loader_basic() {
        let (dir, files) = create_temp_dir_with_files(4);
        let config = ShardConfig { rank: 0, world_size: 2, seed: 42 };
        let loader =
            StreamingParquetLoader::new(dir.path(), config, 4, 2048).expect("create loader");
        assert_eq!(loader.num_files(), 2);
        assert_eq!(loader.total_files(), 4);
        // Rank 0 of 2 owns the even-indexed files (order may be shuffled)
        assert_eq!(sorted(loader.my_files()), vec![files[0].clone(), files[2].clone()]);
    }

    #[test]
    fn test_streaming_loader_ranks_partition_files() {
        let (dir, files) = create_temp_dir_with_files(5);
        let load = |rank| {
            let config = ShardConfig { rank, world_size: 2, seed: 7 };
            StreamingParquetLoader::new(dir.path(), config, 4, 128).expect("create loader")
        };
        let (l0, l1) = (load(0), load(1));
        assert!(l0.my_files().iter().all(|f| !l1.my_files().contains(f)));
        let union: Vec<PathBuf> =
            l0.my_files().iter().chain(l1.my_files()).cloned().collect();
        assert_eq!(sorted(&union), sorted(&files));
    }

    #[test]
    fn test_shuffle_files_deterministic() {
        let mut a: Vec<PathBuf> = (0..10).map(|i| PathBuf::from(format!("f{i}"))).collect();
        let mut b = a.clone();
        shuffle_files(&mut a, 42, 0);
        shuffle_files(&mut b, 42, 0);
        assert_eq!(a, b, "same seed + epoch must produce same order");
    }

    #[test]
    fn test_shuffle_files_different_epochs() {
        let mut a: Vec<PathBuf> = (0..10).map(|i| PathBuf::from(format!("f{i}"))).collect();
        let mut b = a.clone();
        shuffle_files(&mut a, 42, 0);
        shuffle_files(&mut b, 42, 1);
        assert_ne!(a, b, "different epochs must produce different orders");
    }

    #[test]
    fn test_shuffle_files_is_permutation() {
        let original = fake_paths(17);
        for epoch in 0..5 {
            let mut shuffled = original.clone();
            shuffle_files(&mut shuffled, 123, epoch);
            assert_eq!(sorted(&shuffled), sorted(&original));
        }
    }

    #[test]
    fn test_shuffle_files_trivial_lengths() {
        let mut empty: Vec<PathBuf> = Vec::new();
        shuffle_files(&mut empty, 42, 3);
        assert!(empty.is_empty());

        let mut one = fake_paths(1);
        shuffle_files(&mut one, 42, 3);
        assert_eq!(one, fake_paths(1));
    }

    #[test]
    fn test_reset_epoch() {
        let (dir, _) = create_temp_dir_with_files(4);
        let config = ShardConfig { rank: 0, world_size: 2, seed: 42 };
        let mut loader = StreamingParquetLoader::new(dir.path(), config, 4, 2048).expect("create");
        let files_epoch0 = loader.my_files().to_vec();
        loader.reset_epoch(1);
        let files_epoch1 = loader.my_files().to_vec();
        // Same set of files, potentially different order
        assert_eq!(sorted(&files_epoch0), sorted(&files_epoch1), "same files across epochs");
        assert_eq!(loader.current_epoch(), 1);
        assert_eq!(loader.current_file_idx(), 0);
    }

    #[test]
    fn test_resume_from_skips_files() {
        let (dir, _files) = create_temp_dir_with_files(5);
        let mut loader =
            StreamingParquetLoader::new(dir.path(), ShardConfig::single(), 4, 128).unwrap();
        assert_eq!(loader.current_file_idx(), 0);
        loader.resume_from(3);
        assert_eq!(loader.current_file_idx(), 3);
        assert!(!loader.is_epoch_exhausted());
        loader.resume_from(100);
        assert_eq!(loader.current_file_idx(), loader.num_files());
        assert!(loader.is_epoch_exhausted());
    }
}