//! Streaming loader for pre-tokenized parquet shards: deterministic
//! per-rank file sharding, seeded per-epoch shuffling, and checkpoint
//! resumption.
#![allow(dead_code)]
use std::collections::VecDeque;
use std::path::{Path, PathBuf};
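/// Identifies one worker's slice of the global file list. `seed` must be the
/// same on every rank so all workers derive identical per-epoch shuffles.
///
/// A construction sketch (hypothetical values; rank and world size normally
/// come from the distributed launcher's environment):
///
/// ```ignore
/// let cfg = ShardConfig { rank: 2, world_size: 4, seed: 42 };
/// ```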
#[derive(Debug, Clone)]
pub struct ShardConfig {
pub rank: usize,
pub world_size: usize,
pub seed: u64,
}
impl ShardConfig {
pub fn single() -> Self {
Self { rank: 0, world_size: 1, seed: 42 }
}
}
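/// Streams `LMBatch`es from this rank's deterministic subset of parquet
/// files: files are discovered once, sorted for a stable global order,
/// assigned round-robin by rank, and reshuffled each epoch.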
#[derive(Debug)]
pub struct StreamingParquetLoader {
    /// Global sorted file list (identical on every rank).
    all_files: Vec<PathBuf>,
    /// This rank's round-robin share of `all_files`.
    my_files: Vec<PathBuf>,
    shard_config: ShardConfig,
    /// Maximum number of sequences per `LMBatch`.
    batch_size: usize,
    /// Target sequence length; exposed via `seq_len()` but not yet enforced
    /// when packing batches.
    seq_len: usize,
    /// Carry-over sequences between files. Reserved: `next_batches` never
    /// fills it yet, but `is_epoch_exhausted` already accounts for it.
    buffer: VecDeque<Vec<u32>>,
    /// Index into `my_files` of the next file to read.
    next_file_idx: usize,
    epoch: usize,
}
impl StreamingParquetLoader {
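    /// Builds a loader over every `*.parquet` file directly inside
    /// `data_dir`, keeping only this rank's round-robin share of the sorted
    /// file list. Fails with `C-SHARD-001` when there are fewer files than
    /// workers.
    ///
    /// A minimal single-worker driver sketch (hypothetical data path;
    /// `next_batches` additionally needs the `parquet` feature):
    ///
    /// ```ignore
    /// let mut loader = StreamingParquetLoader::new(
    ///     Path::new("data/pretokenized"), // hypothetical shard directory
    ///     ShardConfig::single(),
    ///     8,    // batch_size
    ///     2048, // seq_len
    /// )?;
    /// for epoch in 0..3 {
    ///     loader.reset_epoch(epoch);
    ///     while let Some(batches) = loader.next_batches()? {
    ///         for _batch in batches {
    ///             // one optimizer step per LMBatch
    ///         }
    ///     }
    /// }
    /// ```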
pub fn new(
data_dir: &Path,
shard_config: ShardConfig,
batch_size: usize,
seq_len: usize,
) -> Result<Self, String> {
let mut all_files = discover_parquet_files(data_dir)?;
all_files.sort();
if all_files.len() < shard_config.world_size {
return Err(format!(
"insufficient files for sharding: {} files < {} workers (C-SHARD-001)",
all_files.len(),
shard_config.world_size,
));
}
let my_files = shard_files(&all_files, shard_config.rank, shard_config.world_size);
Ok(Self {
all_files,
my_files,
shard_config,
batch_size,
seq_len,
buffer: VecDeque::new(),
next_file_idx: 0,
epoch: 0,
})
}
    /// Number of files assigned to this rank.
    pub fn num_files(&self) -> usize {
        self.my_files.len()
    }
    /// Number of files across all ranks.
    pub fn total_files(&self) -> usize {
        self.all_files.len()
    }
    /// This rank's files, in the current epoch's order.
    pub fn my_files(&self) -> &[PathBuf] {
        &self.my_files
    }
    pub fn reset_epoch(&mut self, epoch: usize) {
        self.epoch = epoch;
        self.next_file_idx = 0;
        self.buffer.clear();
        // Restore the canonical sorted order before shuffling so the
        // permutation is a pure function of (seed, epoch) rather than of
        // shuffle history; a restarted run that calls reset_epoch(e) once
        // gets the same order the original run had, keeping resume_from()
        // indices valid.
        self.my_files.sort();
        shuffle_files(&mut self.my_files, self.shard_config.seed, epoch);
    }
pub fn batch_size(&self) -> usize {
self.batch_size
}
pub fn seq_len(&self) -> usize {
self.seq_len
}
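    /// Reads the next file in this rank's shard and packs its sequences into
    /// `LMBatch`es of up to `batch_size` sequences each. Returns `Ok(None)`
    /// once this epoch's files are exhausted, while an empty file yields
    /// `Ok(Some(vec![]))`, so callers can tell "empty file" from "epoch done".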
#[cfg(all(not(target_arch = "wasm32"), feature = "parquet"))]
pub fn next_batches(
&mut self,
) -> std::result::Result<Option<Vec<crate::train::LMBatch>>, String> {
use crate::train::LMBatch;
if self.next_file_idx >= self.my_files.len() {
return Ok(None);
}
let path = &self.my_files[self.next_file_idx];
self.next_file_idx += 1;
let sequences = load_pretokenized_from_parquet(path)?;
if sequences.is_empty() {
return Ok(Some(Vec::new()));
}
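        // Special-token ids are fixed here rather than read from a tokenizer
        // config; they must match the ids used during pre-tokenization.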
let pad_id = 0u32;
let eos_id = 2u32;
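        // chunks() leaves a trailing partial chunk, so the final batch may
        // hold fewer than batch_size sequences.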
let num_batches = sequences.len().div_ceil(self.batch_size);
let mut batches = Vec::with_capacity(num_batches);
for chunk in sequences.chunks(self.batch_size) {
batches.push(LMBatch::from_sequences(chunk, pad_id, eos_id));
}
Ok(Some(batches))
}
pub fn is_epoch_exhausted(&self) -> bool {
self.next_file_idx >= self.my_files.len() && self.buffer.is_empty()
}
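    /// Skips ahead to `file_idx` within the current epoch's shuffled order,
    /// clamping to the shard size; used when restoring from a checkpoint.
    /// Only meaningful if `reset_epoch` has reproduced the original order.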
pub fn resume_from(&mut self, file_idx: usize) {
self.next_file_idx = file_idx.min(self.my_files.len());
self.buffer.clear();
}
pub fn current_file_idx(&self) -> usize {
self.next_file_idx
}
pub fn current_epoch(&self) -> usize {
self.epoch
}
}
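/// Lists the `*.parquet` files directly inside `dir` (non-recursive),
/// erroring if the directory is missing or contains no parquet files.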
fn discover_parquet_files(dir: &Path) -> Result<Vec<PathBuf>, String> {
if !dir.exists() {
return Err(format!("data directory does not exist: {}", dir.display()));
}
let mut files = Vec::new();
let entries = std::fs::read_dir(dir)
.map_err(|e| format!("failed to read directory {}: {e}", dir.display()))?;
for entry in entries {
let entry = entry.map_err(|e| format!("failed to read dir entry: {e}"))?;
let path = entry.path();
if path.extension().and_then(|e| e.to_str()) == Some("parquet") {
files.push(path);
}
}
if files.is_empty() {
return Err(format!("no .parquet files found in {}", dir.display()));
}
Ok(files)
}
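/// Round-robin assignment: rank `r` takes the files at indices `r`,
/// `r + world_size`, `r + 2 * world_size`, ... Shards are disjoint, cover
/// every file, and differ in size by at most one.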
fn shard_files(all_files: &[PathBuf], rank: usize, world_size: usize) -> Vec<PathBuf> {
all_files
.iter()
.enumerate()
.filter(|(i, _)| i % world_size == rank)
.map(|(_, f)| f.clone())
.collect()
}
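/// In-place Fisher-Yates shuffle driven by a 64-bit LCG (Knuth's MMIX
/// multiplier and increment), so the permutation is a pure function of
/// `(base_seed, epoch)` and identical across processes without pulling in
/// an RNG crate. The high bits (`>> 33`) are used because they are the
/// strongest bits of an LCG state.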
fn shuffle_files(files: &mut [PathBuf], base_seed: u64, epoch: usize) {
let mut rng_state = base_seed.wrapping_add(epoch as u64);
for i in (1..files.len()).rev() {
rng_state = rng_state
.wrapping_mul(6_364_136_223_846_793_005)
.wrapping_add(1_442_695_040_888_963_407);
let j = (rng_state >> 33) as usize % (i + 1);
files.swap(i, j);
}
}
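/// Reads every row's token list from the `input_ids` (or `token_ids`)
/// column, skipping null and empty rows.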
#[cfg(all(not(target_arch = "wasm32"), feature = "parquet"))]
fn load_pretokenized_from_parquet(path: &Path) -> std::result::Result<Vec<Vec<u32>>, String> {
use alimentar::{ArrowDataset, Dataset};
use arrow::array::{Array, ListArray};
let dataset = ArrowDataset::from_parquet(path)
.map_err(|e| format!("Failed to load parquet {}: {e}", path.display()))?;
let schema = dataset.schema();
let column_names: Vec<&str> = schema.fields().iter().map(|f| f.name().as_str()).collect();
let token_col = column_names.iter().find(|&&n| n == "input_ids" || n == "token_ids").copied();
    let token_col = token_col.ok_or_else(|| {
        format!("No pre-tokenized column (input_ids/token_ids) in {}", path.display())
    })?;
let col_idx = schema.index_of(token_col).map_err(|e| format!("Column index error: {e}"))?;
let mut sequences = Vec::with_capacity(dataset.len());
for batch in dataset.iter() {
let col = batch.column(col_idx);
if let Some(list_arr) = col.as_any().downcast_ref::<ListArray>() {
for i in 0..list_arr.len() {
if list_arr.is_null(i) {
continue;
}
let values = list_arr.value(i);
let seq = extract_u32_values(&*values);
if !seq.is_empty() {
sequences.push(seq);
}
}
}
}
Ok(sequences)
}
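/// Copies a row's token values out as `u32`, accepting `UInt32`, `Int32`,
/// or `Int64` arrays. Signed values are cast with `as u32`, so token ids
/// are assumed non-negative (and below 2^32 for `Int64`); an unsupported
/// array type yields an empty vec, which the caller drops.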
#[cfg(all(not(target_arch = "wasm32"), feature = "parquet"))]
fn extract_u32_values(array: &dyn arrow::array::Array) -> Vec<u32> {
use arrow::array::{Int32Array, Int64Array, UInt32Array};
if let Some(arr) = array.as_any().downcast_ref::<UInt32Array>() {
arr.values().to_vec()
} else if let Some(arr) = array.as_any().downcast_ref::<Int32Array>() {
arr.values().iter().map(|&v| v as u32).collect()
} else if let Some(arr) = array.as_any().downcast_ref::<Int64Array>() {
arr.values().iter().map(|&v| v as u32).collect()
} else {
Vec::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::fs;
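    // The written files contain junk bytes: these tests exercise discovery,
    // sharding, and ordering, never actual parquet parsing.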
fn create_temp_dir_with_files(n: usize) -> (tempfile::TempDir, Vec<PathBuf>) {
let dir = tempfile::tempdir().expect("create temp dir");
let mut files = Vec::new();
for i in 0..n {
let path = dir.path().join(format!("shard_{i:04}.parquet"));
fs::write(&path, format!("fake parquet {i}")).expect("write file");
files.push(path);
}
(dir, files)
}
#[test]
fn test_shard_files_disjointness() {
let files: Vec<PathBuf> = (0..10).map(|i| PathBuf::from(format!("f{i}.parquet"))).collect();
let s0 = shard_files(&files, 0, 3);
let s1 = shard_files(&files, 1, 3);
let s2 = shard_files(&files, 2, 3);
for f in &s0 {
assert!(!s1.contains(f));
assert!(!s2.contains(f));
}
for f in &s1 {
assert!(!s2.contains(f));
}
assert_eq!(s0.len() + s1.len() + s2.len(), 10);
}
#[test]
fn test_shard_files_assignment() {
let files: Vec<PathBuf> = (0..10).map(|i| PathBuf::from(format!("f{i}.parquet"))).collect();
let s0 = shard_files(&files, 0, 3);
        assert_eq!(s0.len(), 4);
        let s1 = shard_files(&files, 1, 3);
        assert_eq!(s1.len(), 3);
        let s2 = shard_files(&files, 2, 3);
        assert_eq!(s2.len(), 3);
    }
#[test]
fn test_shard_files_two_workers() {
let files: Vec<PathBuf> = (0..7).map(|i| PathBuf::from(format!("f{i}.parquet"))).collect();
let s0 = shard_files(&files, 0, 2);
let s1 = shard_files(&files, 1, 2);
        assert_eq!(s0.len(), 4);
        assert_eq!(s1.len(), 3);
    }
#[test]
fn test_discover_parquet_files() {
let (dir, _) = create_temp_dir_with_files(5);
fs::write(dir.path().join("readme.txt"), "not parquet").expect("write");
let found = discover_parquet_files(dir.path()).expect("discover");
assert_eq!(found.len(), 5);
}
#[test]
fn test_discover_parquet_files_empty_dir() {
let dir = tempfile::tempdir().expect("create temp dir");
let result = discover_parquet_files(dir.path());
assert!(result.is_err());
assert!(result.unwrap_err().contains("no .parquet files"));
}
#[test]
fn test_streaming_loader_insufficient_files() {
let (dir, _) = create_temp_dir_with_files(1);
let config = ShardConfig { rank: 0, world_size: 2, seed: 42 };
let result = StreamingParquetLoader::new(dir.path(), config, 4, 2048);
assert!(result.is_err());
assert!(result.unwrap_err().contains("insufficient files"));
}
#[test]
fn test_streaming_loader_basic() {
let (dir, _) = create_temp_dir_with_files(4);
let config = ShardConfig { rank: 0, world_size: 2, seed: 42 };
let loader =
StreamingParquetLoader::new(dir.path(), config, 4, 2048).expect("create loader");
assert_eq!(loader.num_files(), 2);
assert_eq!(loader.total_files(), 4);
}
#[test]
fn test_shuffle_files_deterministic() {
let mut a: Vec<PathBuf> = (0..10).map(|i| PathBuf::from(format!("f{i}"))).collect();
let mut b = a.clone();
shuffle_files(&mut a, 42, 0);
shuffle_files(&mut b, 42, 0);
assert_eq!(a, b, "same seed + epoch must produce same order");
}
#[test]
fn test_shuffle_files_different_epochs() {
let mut a: Vec<PathBuf> = (0..10).map(|i| PathBuf::from(format!("f{i}"))).collect();
let mut b = a.clone();
shuffle_files(&mut a, 42, 0);
shuffle_files(&mut b, 42, 1);
assert_ne!(a, b, "different epochs must produce different orders");
}
#[test]
fn test_reset_epoch() {
let (dir, _) = create_temp_dir_with_files(4);
let config = ShardConfig { rank: 0, world_size: 2, seed: 42 };
let mut loader = StreamingParquetLoader::new(dir.path(), config, 4, 2048).expect("create");
let files_epoch0 = loader.my_files().to_vec();
loader.reset_epoch(1);
let files_epoch1 = loader.my_files().to_vec();
let mut s0 = files_epoch0.clone();
let mut s1 = files_epoch1.clone();
s0.sort();
s1.sort();
assert_eq!(s0, s1, "same files assigned across epochs");
}
#[test]
fn test_resume_from_skips_files() {
let (dir, _files) = create_temp_dir_with_files(5);
let mut loader =
StreamingParquetLoader::new(dir.path(), ShardConfig::single(), 4, 128).unwrap();
assert_eq!(loader.current_file_idx(), 0);
loader.resume_from(3);
assert_eq!(loader.current_file_idx(), 3);
loader.resume_from(100);
assert_eq!(loader.current_file_idx(), loader.num_files());
assert!(loader.is_epoch_exhausted());
}
}