Skip to main content

entrenar/train/
shard_reader.rs

//! Minimal tokenized-shard reader for MODEL-2 pretrain MVP (task #111).
//!
//! Reads a directory of `.bin` files containing little-endian u32 tokens,
//! chunks each file into sequences of `seq_length + 1` tokens, and yields
//! `LMBatch`es of up to `batch_size` sequences. No licensing filter, no
//! MinHash dedup, no PII scrub — those belong to `apr-corpus-ingest run`.
//!
//! Contract: `contracts/dataset-thestack-python-v1.yaml` (shard format).
9
10use crate::train::transformer_trainer::LMBatch;
11use std::fs::File;
12use std::io::{self, BufReader, Read};
13use std::path::{Path, PathBuf};
14
/// Streaming iterator over `LMBatch`es produced from a directory of
/// `.bin` token shards (little-endian u32).
///
/// Default behaviour matches the historical contract: when the last
/// shard is exhausted, `next()` returns `None`. For training paths
/// where the corpus is finite but the run extends beyond a single
/// epoch, opt in to loop-on-exhaust via `with_wrap_around(true)`.
/// This is the standard PyTorch/HuggingFace behaviour; `apr pretrain`
/// uses it in the real-corpus drive paths so that step-count budgets
/// are honoured even on small datasets.
///
/// `Debug` is derived so the iterator can be logged or embedded in
/// other `Debug` types while diagnosing data-loading issues.
#[derive(Debug)]
pub struct ShardBatchIter {
    /// Paths of the `.bin` shards, sorted; shards are read in this order.
    shards: Vec<PathBuf>,
    /// Index into `shards` of the shard currently being read. Equals
    /// `shards.len()` once every shard has been consumed.
    cursor_shard: usize,
    /// Buffered reader over `shards[cursor_shard]`, opened lazily.
    /// `None` before the first read and right after a shard is exhausted.
    cursor_reader: Option<BufReader<File>>,
    /// Maximum number of sequences per yielded batch.
    batch_size: usize,
    /// Tokens per sequence: `seq_length + 1` (the extra token is for the
    /// causal shift).
    seq_plus_one: usize,
    /// Padding token id, forwarded to `LMBatch::from_sequences`.
    pad_id: u32,
    /// End-of-sequence token id, forwarded to `LMBatch::from_sequences`.
    eos_id: u32,
    /// When `true`, restart from the first shard instead of ending the
    /// iteration once the last shard is exhausted.
    wrap_around: bool,
    /// Number of wrap-around resets performed so far.
    epochs_completed: u64,
}
36
37impl ShardBatchIter {
38    /// Build an iterator that yields `LMBatch` with `batch_size` sequences
39    /// of length `seq_length + 1` (for causal shift).
40    ///
41    /// Returns `Err` if `dataset_dir` is missing or contains no `.bin` shards.
42    pub fn new(
43        dataset_dir: &Path,
44        batch_size: usize,
45        seq_length: usize,
46        pad_id: u32,
47        eos_id: u32,
48    ) -> io::Result<Self> {
49        let mut shards: Vec<PathBuf> = std::fs::read_dir(dataset_dir)?
50            .filter_map(|e| e.ok())
51            .map(|e| e.path())
52            .filter(|p| p.extension().is_some_and(|ext| ext == "bin"))
53            .collect();
54        shards.sort();
55        if shards.is_empty() {
56            return Err(io::Error::new(
57                io::ErrorKind::NotFound,
58                format!("no .bin shards in {}", dataset_dir.display()),
59            ));
60        }
61        Ok(Self {
62            shards,
63            cursor_shard: 0,
64            cursor_reader: None,
65            batch_size,
66            seq_plus_one: seq_length + 1,
67            pad_id,
68            eos_id,
69            wrap_around: false,
70            epochs_completed: 0,
71        })
72    }
73
74    /// Enable corpus wrap-around: when the last shard is exhausted,
75    /// reset the cursor to shard 0 and continue.
76    ///
77    /// This is the standard ML-training behaviour. Without it, an
78    /// 18M-token corpus exhausts in ~2 epochs of a 5K-step run with
79    /// batch=16 seq=512, and the upstream `StepFn` falls back to
80    /// returning placeholder loss `(1.0, 1.0)` — silently producing
81    /// garbage data that breaks convergence. See spec §22 (PR #1073)
82    /// for the corpus-bottleneck investigation.
83    #[must_use]
84    pub fn with_wrap_around(mut self, wrap_around: bool) -> Self {
85        self.wrap_around = wrap_around;
86        self
87    }
88
89    /// Number of times the iterator has cycled through the entire
90    /// shard set. Increments each time the last shard is exhausted
91    /// AND `wrap_around` was true (so a reset happened).
92    #[must_use]
93    pub fn epochs_completed(&self) -> u64 {
94        self.epochs_completed
95    }
96
97    fn ensure_reader(&mut self) -> io::Result<bool> {
98        if self.cursor_reader.is_some() {
99            return Ok(true);
100        }
101        while self.cursor_shard < self.shards.len() {
102            match File::open(&self.shards[self.cursor_shard]) {
103                Ok(f) => {
104                    self.cursor_reader = Some(BufReader::new(f));
105                    return Ok(true);
106                }
107                Err(_) => {
108                    self.cursor_shard += 1;
109                }
110            }
111        }
112        Ok(false)
113    }
114
115    fn read_one_sequence(&mut self) -> io::Result<Option<Vec<u32>>> {
116        let tokens_per_seq = self.seq_plus_one;
117        let mut buf = vec![0u8; tokens_per_seq * 4];
118        loop {
119            if !self.ensure_reader()? {
120                // All shards exhausted. If wrap-around is on, reset cursor
121                // and start over; else return None as before.
122                if self.wrap_around {
123                    self.epochs_completed += 1;
124                    self.cursor_shard = 0;
125                    self.cursor_reader = None;
126                    if !self.ensure_reader()? {
127                        // Still no readable shard after reset — give up
128                        // to avoid infinite loop on a broken shard set.
129                        return Ok(None);
130                    }
131                } else {
132                    return Ok(None);
133                }
134            }
135            let reader = self.cursor_reader.as_mut().expect("reader set above");
136            match reader.read_exact(&mut buf) {
137                Ok(()) => {
138                    let mut seq = Vec::with_capacity(tokens_per_seq);
139                    for chunk in buf.chunks_exact(4) {
140                        seq.push(u32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]));
141                    }
142                    return Ok(Some(seq));
143                }
144                Err(ref e) if e.kind() == io::ErrorKind::UnexpectedEof => {
145                    self.cursor_reader = None;
146                    self.cursor_shard += 1;
147                }
148                Err(e) => return Err(e),
149            }
150        }
151    }
152}
153
154impl Iterator for ShardBatchIter {
155    type Item = LMBatch;
156
157    fn next(&mut self) -> Option<LMBatch> {
158        let mut seqs: Vec<Vec<u32>> = Vec::with_capacity(self.batch_size);
159        for _ in 0..self.batch_size {
160            match self.read_one_sequence() {
161                Ok(Some(seq)) => seqs.push(seq),
162                Ok(None) => break,
163                Err(_) => break,
164            }
165        }
166        if seqs.is_empty() {
167            None
168        } else {
169            Some(LMBatch::from_sequences(&seqs, self.pad_id, self.eos_id))
170        }
171    }
172}
173
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Encode `tokens` as little-endian u32 and store them as `dir/name`.
    fn write_shard(dir: &Path, name: &str, tokens: &[u32]) {
        let bytes: Vec<u8> = tokens.iter().flat_map(|tok| tok.to_le_bytes()).collect();
        std::fs::write(dir.join(name), bytes).expect("shard write");
    }

    /// Temp dir holding a single shard, `shard-0.bin`, with tokens `0..40`:
    /// exactly 8 sequences of 5 tokens when `seq_length` is 4.
    fn forty_token_corpus() -> TempDir {
        let tmp = TempDir::new().expect("tempdir");
        let tokens: Vec<u32> = (0u32..40).collect();
        write_shard(tmp.path(), "shard-0.bin", &tokens);
        tmp
    }

    /// Regression guard for wrap-around: once `with_wrap_around(true)` is
    /// set, batches MUST keep coming after the shard set has been read
    /// through. This is the SHIP-007-adjacent corpus-bottleneck fix:
    /// without wrap-around an N-token corpus is gone after one epoch and
    /// the Cuda*StepFn falls back to placeholder `(1.0, 1.0)` losses,
    /// silently producing garbage gradients (observed 2026-04-26 on a
    /// 5K-step run that early-stopped at epoch 4 with train_loss=1.0).
    #[test]
    fn wrap_around_continues_past_shard_exhaustion() {
        let tmp = forty_token_corpus();
        let mut iter = ShardBatchIter::new(tmp.path(), 2, 4, 0, 0)
            .expect("iter")
            .with_wrap_around(true);
        // One pass is 4 batches; 12 batches therefore span 3 passes, which
        // is only possible if the iterator wraps.
        let batches: Vec<_> = (0..12)
            .map(|_| iter.next().expect("wrap-around must keep yielding"))
            .collect();
        assert_eq!(batches.len(), 12, "12 batches across 3 simulated epochs");
        let epochs = iter.epochs_completed();
        assert!(
            epochs >= 2,
            "epochs_completed = {} should reflect at least 2 wrap resets",
            epochs
        );
    }

    /// With the default (`wrap_around = false`) the iterator ends once the
    /// corpus is consumed, as the historical contract requires.
    #[test]
    fn no_wrap_around_terminates_on_exhaustion() {
        let tmp = forty_token_corpus();
        let batches: Vec<_> = ShardBatchIter::new(tmp.path(), 2, 4, 0, 0)
            .expect("iter")
            .collect();
        assert_eq!(batches.len(), 4, "default: 8 seqs / batch=2 = 4 batches then None");
    }

    /// 40 tokens = 8 sequences of (seq=4)+1 tokens, i.e. 4 batches of 2.
    #[test]
    fn single_shard_yields_expected_batch_count() {
        let tmp = forty_token_corpus();
        let batches: Vec<_> = ShardBatchIter::new(tmp.path(), 2, 4, 0, 0)
            .expect("iter")
            .collect();
        assert_eq!(batches.len(), 4, "8 seqs / batch_size=2 = 4 batches");
        let first = &batches[0];
        assert_eq!(first.batch_size, 2);
        assert_eq!(first.seq_len, 4);
    }

    /// A directory without any `.bin` file is rejected by the constructor.
    #[test]
    fn empty_dir_errors() {
        let tmp = TempDir::new().expect("tempdir");
        assert!(
            ShardBatchIter::new(tmp.path(), 2, 4, 0, 0).is_err(),
            "empty dir must error"
        );
    }

    /// Shards are consumed in sorted path order, so `shard-0.bin` is read
    /// before `shard-1.bin`.
    #[test]
    fn multi_shard_ordering_is_lexical() {
        let tmp = TempDir::new().expect("tempdir");
        let low: Vec<u32> = (0u32..10).collect();
        let high: Vec<u32> = (100u32..110).collect();
        write_shard(tmp.path(), "shard-0.bin", &low);
        write_shard(tmp.path(), "shard-1.bin", &high);
        let mut iter = ShardBatchIter::new(tmp.path(), 1, 4, 0, 0).expect("iter");
        let first = iter.next().expect("first batch");
        let input = first.get_input(0).expect("input0");
        assert_eq!(input[0], 0, "shard-0 first");
    }
}