entrenar/config/train/batches/
streaming.rs1#![allow(dead_code)]
2use std::collections::VecDeque;
18use std::path::{Path, PathBuf};
19
/// Describes this worker's position in a data-parallel job.
///
/// Files are assigned round-robin: the worker with rank `r` receives every
/// file whose index in the sorted file list satisfies `index % world_size == r`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ShardConfig {
    /// Zero-based index of this worker; should be less than `world_size`.
    pub rank: usize,
    /// Total number of workers; should be at least 1.
    pub world_size: usize,
    /// Base seed for the per-epoch file shuffle.
    pub seed: u64,
}

impl ShardConfig {
    /// Configuration for a single, non-distributed worker: rank 0 of 1, seed 42.
    pub fn single() -> Self {
        Self { rank: 0, world_size: 1, seed: 42 }
    }
}

impl Default for ShardConfig {
    /// Same as [`ShardConfig::single`].
    fn default() -> Self {
        Self::single()
    }
}
37
38#[derive(Debug)]
55pub struct StreamingParquetLoader {
56 all_files: Vec<PathBuf>,
58 my_files: Vec<PathBuf>,
60 shard_config: ShardConfig,
62 batch_size: usize,
64 seq_len: usize,
66 buffer: VecDeque<Vec<u32>>,
68 next_file_idx: usize,
70 epoch: usize,
72}
73
74impl StreamingParquetLoader {
75 pub fn new(
86 data_dir: &Path,
87 shard_config: ShardConfig,
88 batch_size: usize,
89 seq_len: usize,
90 ) -> Result<Self, String> {
91 let mut all_files = discover_parquet_files(data_dir)?;
92 all_files.sort(); if all_files.len() < shard_config.world_size {
95 return Err(format!(
96 "insufficient files for sharding: {} files < {} workers (C-SHARD-001)",
97 all_files.len(),
98 shard_config.world_size,
99 ));
100 }
101
102 let my_files = shard_files(&all_files, shard_config.rank, shard_config.world_size);
103
104 Ok(Self {
105 all_files,
106 my_files,
107 shard_config,
108 batch_size,
109 seq_len,
110 buffer: VecDeque::new(),
111 next_file_idx: 0,
112 epoch: 0,
113 })
114 }
115
116 pub fn num_files(&self) -> usize {
118 self.my_files.len()
119 }
120
121 pub fn total_files(&self) -> usize {
123 self.all_files.len()
124 }
125
126 pub fn my_files(&self) -> &[PathBuf] {
128 &self.my_files
129 }
130
131 pub fn reset_epoch(&mut self, epoch: usize) {
133 self.epoch = epoch;
134 self.next_file_idx = 0;
135 self.buffer.clear();
136 shuffle_files(&mut self.my_files, self.shard_config.seed, epoch);
138 }
139
140 pub fn batch_size(&self) -> usize {
142 self.batch_size
143 }
144
145 pub fn seq_len(&self) -> usize {
147 self.seq_len
148 }
149
150 #[cfg(all(not(target_arch = "wasm32"), feature = "parquet"))]
156 pub fn next_batches(
157 &mut self,
158 ) -> std::result::Result<Option<Vec<crate::train::LMBatch>>, String> {
159 use crate::train::LMBatch;
160
161 if self.next_file_idx >= self.my_files.len() {
162 return Ok(None);
163 }
164
165 let path = &self.my_files[self.next_file_idx];
166 self.next_file_idx += 1;
167
168 let sequences = load_pretokenized_from_parquet(path)?;
170
171 if sequences.is_empty() {
172 return Ok(Some(Vec::new()));
173 }
174
175 let pad_id = 0u32;
177 let eos_id = 2u32;
178 let num_batches = sequences.len().div_ceil(self.batch_size);
179 let mut batches = Vec::with_capacity(num_batches);
180 for chunk in sequences.chunks(self.batch_size) {
181 batches.push(LMBatch::from_sequences(chunk, pad_id, eos_id));
182 }
183
184 Ok(Some(batches))
185 }
186
187 pub fn is_epoch_exhausted(&self) -> bool {
189 self.next_file_idx >= self.my_files.len() && self.buffer.is_empty()
190 }
191
192 pub fn resume_from(&mut self, file_idx: usize) {
197 self.next_file_idx = file_idx.min(self.my_files.len());
198 self.buffer.clear();
199 }
200
201 pub fn current_file_idx(&self) -> usize {
203 self.next_file_idx
204 }
205
206 pub fn current_epoch(&self) -> usize {
208 self.epoch
209 }
210}
211
/// Lists the regular `.parquet` files directly inside `dir` (non-recursive).
///
/// The returned order is whatever `read_dir` yields, which is platform
/// dependent; callers that need a stable order must sort.
///
/// # Errors
///
/// Returns an error if:
/// - `dir` does not exist;
/// - `dir` or one of its entries cannot be read;
/// - no regular `.parquet` file is found.
fn discover_parquet_files(dir: &Path) -> Result<Vec<PathBuf>, String> {
    if !dir.exists() {
        return Err(format!("data directory does not exist: {}", dir.display()));
    }

    let entries = std::fs::read_dir(dir)
        .map_err(|e| format!("failed to read directory {}: {e}", dir.display()))?;

    let mut files = Vec::new();
    for entry in entries {
        let path = entry.map_err(|e| format!("failed to read dir entry: {e}"))?.path();
        let is_parquet = path.extension().and_then(|e| e.to_str()) == Some("parquet");
        // Skip directories that merely carry a `.parquet` suffix (a layout some
        // writers use for partitioned datasets). They cannot be opened as a
        // single Parquet file. `is_file` follows symlinks, so linked files are kept.
        if is_parquet && path.is_file() {
            files.push(path);
        }
    }

    if files.is_empty() {
        return Err(format!("no .parquet files found in {}", dir.display()));
    }

    Ok(files)
}
236
/// Returns the files assigned to `rank` under round-robin sharding.
///
/// The file at position `i` in `all_files` belongs to the worker with
/// `i % world_size == rank`. The subsets for ranks `0..world_size` are
/// therefore disjoint and together cover every file. A `rank` of
/// `world_size` or more matches nothing. A `world_size` of 0 panics
/// (division by zero) as soon as `all_files` is non-empty.
fn shard_files(all_files: &[PathBuf], rank: usize, world_size: usize) -> Vec<PathBuf> {
    let mut assigned = Vec::new();
    for (position, file) in all_files.iter().enumerate() {
        if position % world_size == rank {
            assigned.push(file.clone());
        }
    }
    assigned
}
253
/// Deterministically shuffles `files` in place with a Fisher–Yates pass.
///
/// Random indices come from a 64-bit linear congruential generator seeded
/// with `base_seed + epoch` (wrapping), using the upper bits of each state.
/// The same `(files, base_seed, epoch)` therefore always yields the same
/// permutation.
fn shuffle_files(files: &mut [PathBuf], base_seed: u64, epoch: usize) {
    const LCG_MULTIPLIER: u64 = 6_364_136_223_846_793_005;
    const LCG_INCREMENT: u64 = 1_442_695_040_888_963_407;

    let mut state = base_seed.wrapping_add(epoch as u64);
    let mut upper = files.len();
    // Walk the "unshuffled" boundary down from the end, swapping each slot
    // with a randomly chosen slot at or below it.
    while upper > 1 {
        upper -= 1;
        state = state.wrapping_mul(LCG_MULTIPLIER).wrapping_add(LCG_INCREMENT);
        let pick = ((state >> 33) as usize) % (upper + 1);
        files.swap(upper, pick);
    }
}
268
#[cfg(all(not(target_arch = "wasm32"), feature = "parquet"))]
/// Reads every non-empty token sequence from the pre-tokenized column of `path`.
///
/// The token column is the first schema field named `input_ids` or
/// `token_ids`. It may be an Arrow `List` or `LargeList` column; writers differ
/// in which one they emit. Null rows are skipped. So are rows that produce no
/// ids, including rows whose element type `extract_u32_values` does not support.
///
/// # Errors
///
/// Returns an error if:
/// - the file cannot be loaded;
/// - neither token column is present;
/// - the token column is not a list column.
fn load_pretokenized_from_parquet(path: &Path) -> std::result::Result<Vec<Vec<u32>>, String> {
    use alimentar::{ArrowDataset, Dataset};
    use arrow::array::{Array, GenericListArray, LargeListArray, ListArray, OffsetSizeTrait};

    /// Appends each non-null row of `list_arr` that yields at least one id to `out`.
    fn push_rows<O: OffsetSizeTrait>(list_arr: &GenericListArray<O>, out: &mut Vec<Vec<u32>>) {
        for i in 0..list_arr.len() {
            if list_arr.is_null(i) {
                continue;
            }
            let seq = extract_u32_values(&*list_arr.value(i));
            if !seq.is_empty() {
                out.push(seq);
            }
        }
    }

    let dataset = ArrowDataset::from_parquet(path)
        .map_err(|e| format!("Failed to load parquet {}: {e}", path.display()))?;

    let schema = dataset.schema();
    let token_col = schema
        .fields()
        .iter()
        .map(|f| f.name().as_str())
        .find(|&n| n == "input_ids" || n == "token_ids")
        .ok_or_else(|| {
            format!("No pre-tokenized column (input_ids/token_ids) in {}", path.display())
        })?;

    let col_idx = schema.index_of(token_col).map_err(|e| format!("Column index error: {e}"))?;

    let mut sequences = Vec::with_capacity(dataset.len());
    for batch in dataset.iter() {
        let col = batch.column(col_idx);
        let any = col.as_any();
        if let Some(list_arr) = any.downcast_ref::<ListArray>() {
            push_rows(list_arr, &mut sequences);
        } else if let Some(list_arr) = any.downcast_ref::<LargeListArray>() {
            push_rows(list_arr, &mut sequences);
        } else {
            // Fail loudly: treating the file as empty would silently drop all
            // of its data from training.
            return Err(format!(
                "Unsupported type {:?} for column {token_col} in {} (expected a list of integer token ids)",
                col.data_type(),
                path.display()
            ));
        }
    }
    Ok(sequences)
}
320
#[cfg(all(not(target_arch = "wasm32"), feature = "parquet"))]
/// Converts a primitive integer Arrow array of token ids into a `Vec<u32>`.
///
/// Supports `UInt32`, `Int32` and `Int64` element types; any other type yields
/// an empty vector. Signed values are converted with `as`, so negative ids are
/// reinterpreted and `Int64` ids above `u32::MAX` are truncated rather than
/// rejected.
///
/// NOTE(review): the raw values buffer is read directly, so null slots inside
/// a row are not filtered out. Confirm inputs never contain them.
fn extract_u32_values(array: &dyn arrow::array::Array) -> Vec<u32> {
    use arrow::array::{Int32Array, Int64Array, UInt32Array};

    let any = array.as_any();
    if let Some(ids) = any.downcast_ref::<UInt32Array>() {
        return ids.values().iter().copied().collect();
    }
    if let Some(ids) = any.downcast_ref::<Int32Array>() {
        return ids.values().iter().map(|&id| id as u32).collect();
    }
    match any.downcast_ref::<Int64Array>() {
        Some(ids) => ids.values().iter().map(|&id| id as u32).collect(),
        None => Vec::new(),
    }
}
336
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;

    /// Creates a temp dir holding `n` placeholder files named `shard_NNNN.parquet`.
    ///
    /// The files are not valid Parquet; they only exercise discovery and sharding.
    fn create_temp_dir_with_files(n: usize) -> (tempfile::TempDir, Vec<PathBuf>) {
        let dir = tempfile::tempdir().expect("create temp dir");
        let files: Vec<PathBuf> = (0..n)
            .map(|i| {
                let path = dir.path().join(format!("shard_{i:04}.parquet"));
                fs::write(&path, format!("fake parquet {i}")).expect("write file");
                path
            })
            .collect();
        (dir, files)
    }

    /// In-memory paths `f0.parquet` .. `f{n-1}.parquet`; nothing is written to disk.
    fn fake_paths(n: usize) -> Vec<PathBuf> {
        (0..n).map(|i| PathBuf::from(format!("f{i}.parquet"))).collect()
    }

    #[test]
    fn test_shard_files_disjointness() {
        let files = fake_paths(10);
        let shards: Vec<Vec<PathBuf>> = (0..3).map(|rank| shard_files(&files, rank, 3)).collect();

        // No file may be assigned to two different ranks...
        for (a, shard_a) in shards.iter().enumerate() {
            for shard_b in &shards[a + 1..] {
                assert!(shard_a.iter().all(|f| !shard_b.contains(f)));
            }
        }
        // ...and together the ranks must cover every file.
        assert_eq!(shards.iter().map(Vec::len).sum::<usize>(), 10);
    }

    #[test]
    fn test_shard_files_assignment() {
        let files = fake_paths(10);
        assert_eq!(shard_files(&files, 0, 3).len(), 4);
        assert_eq!(shard_files(&files, 1, 3).len(), 3);
        assert_eq!(shard_files(&files, 2, 3).len(), 3);
    }

    #[test]
    fn test_shard_files_two_workers() {
        let files = fake_paths(7);
        assert_eq!(shard_files(&files, 0, 2).len(), 4);
        assert_eq!(shard_files(&files, 1, 2).len(), 3);
    }

    #[test]
    fn test_discover_parquet_files() {
        let (dir, _) = create_temp_dir_with_files(5);
        fs::write(dir.path().join("readme.txt"), "not parquet").expect("write");
        let found = discover_parquet_files(dir.path()).expect("discover");
        assert_eq!(found.len(), 5);
    }

    #[test]
    fn test_discover_parquet_files_empty_dir() {
        let dir = tempfile::tempdir().expect("create temp dir");
        let err = discover_parquet_files(dir.path()).expect_err("empty dir must fail");
        assert!(err.contains("no .parquet files"));
    }

    #[test]
    fn test_discover_parquet_files_missing_dir() {
        let dir = tempfile::tempdir().expect("create temp dir");
        let missing = dir.path().join("does_not_exist");
        let err = discover_parquet_files(&missing).expect_err("missing dir must fail");
        assert!(err.contains("does not exist"));
    }

    #[test]
    fn test_streaming_loader_insufficient_files() {
        let (dir, _) = create_temp_dir_with_files(1);
        let config = ShardConfig { rank: 0, world_size: 2, seed: 42 };
        let err = StreamingParquetLoader::new(dir.path(), config, 4, 2048)
            .expect_err("one file cannot feed two workers");
        assert!(err.contains("insufficient files"));
    }

    #[test]
    fn test_streaming_loader_basic() {
        let (dir, _) = create_temp_dir_with_files(4);
        let config = ShardConfig { rank: 0, world_size: 2, seed: 42 };
        let loader =
            StreamingParquetLoader::new(dir.path(), config, 4, 2048).expect("create loader");
        assert_eq!(loader.num_files(), 2);
        assert_eq!(loader.total_files(), 4);
    }

    #[test]
    fn test_shard_config_single() {
        let config = ShardConfig::single();
        assert_eq!((config.rank, config.world_size, config.seed), (0, 1, 42));
    }

    #[test]
    fn test_shuffle_files_deterministic() {
        let mut a: Vec<PathBuf> = (0..10).map(|i| PathBuf::from(format!("f{i}"))).collect();
        let mut b = a.clone();
        shuffle_files(&mut a, 42, 0);
        shuffle_files(&mut b, 42, 0);
        assert_eq!(a, b, "same seed + epoch must produce same order");
    }

    #[test]
    fn test_shuffle_files_different_epochs() {
        let mut a: Vec<PathBuf> = (0..10).map(|i| PathBuf::from(format!("f{i}"))).collect();
        let mut b = a.clone();
        shuffle_files(&mut a, 42, 0);
        shuffle_files(&mut b, 42, 1);
        assert_ne!(a, b, "different epochs must produce different orders");
    }

    #[test]
    fn test_shuffle_files_is_permutation() {
        let original = fake_paths(13);
        let mut shuffled = original.clone();
        shuffle_files(&mut shuffled, 7, 3);
        shuffled.sort();
        let mut expected = original;
        expected.sort();
        assert_eq!(shuffled, expected, "shuffle must neither drop nor duplicate files");
    }

    #[test]
    fn test_reset_epoch() {
        let (dir, _) = create_temp_dir_with_files(4);
        let config = ShardConfig { rank: 0, world_size: 2, seed: 42 };
        let mut loader = StreamingParquetLoader::new(dir.path(), config, 4, 2048).expect("create");
        let mut before = loader.my_files().to_vec();
        loader.reset_epoch(1);
        let mut after = loader.my_files().to_vec();
        before.sort();
        after.sort();
        assert_eq!(before, after, "same files assigned across epochs");
        assert_eq!(loader.current_epoch(), 1);
        assert_eq!(loader.current_file_idx(), 0);
    }

    #[test]
    fn test_resume_from_skips_files() {
        let (dir, _files) = create_temp_dir_with_files(5);
        let mut loader = StreamingParquetLoader::new(dir.path(), ShardConfig::single(), 4, 128)
            .expect("create loader");
        assert_eq!(loader.current_file_idx(), 0);
        loader.resume_from(3);
        assert_eq!(loader.current_file_idx(), 3);
        loader.resume_from(100);
        assert_eq!(loader.current_file_idx(), loader.num_files());
        assert!(loader.is_epoch_exhausted());
    }
}