Skip to main content

entrenar/train/
shard_reader.rs

1//! Minimal tokenized-shard reader for MODEL-2 pretrain MVP (task #111).
2//!
3//! Reads a directory of `.bin` files containing little-endian u32 tokens,
4//! chunks them into `seq_length + 1` sequences, and yields `LMBatch`es of
5//! `batch_size` sequences. No licensing filter, no MinHash dedup, no PII
6//! scrub — those belong to `apr-corpus-ingest run`.
7//!
8//! Contract: `contracts/dataset-thestack-python-v1.yaml` (shard format).
9
10use crate::train::transformer_trainer::LMBatch;
11use std::fs::File;
12use std::io::{self, BufReader, Read};
13use std::path::{Path, PathBuf};
14
/// Streaming iterator over `LMBatch`es produced from a directory of
/// `.bin` token shards (little-endian u32).
///
/// Default behaviour matches the historical contract: when the last
/// shard is exhausted, `next()` returns `None`. For training paths
/// where the corpus is finite but the run extends beyond a single
/// epoch, opt in to loop-on-exhaust via `with_wrap_around(true)`.
/// This is the standard PyTorch/HuggingFace behaviour; `apr pretrain`
/// uses it in the real-corpus drive paths so that step-count budgets
/// are honoured even on small datasets.
///
/// Even with wrap-around enabled the iterator terminates when a full
/// pass over the shard set produces no complete sequence (every shard
/// is empty, shorter than one sequence, or unreadable); cycling such a
/// corpus could never yield data.
#[derive(Debug)]
pub struct ShardBatchIter {
    /// Shard paths in lexical order, fixed at construction time.
    shards: Vec<PathBuf>,
    /// Index into `shards` of the shard currently being read, or of the
    /// next shard to open when `cursor_reader` is `None`.
    cursor_shard: usize,
    /// Open handle on `shards[cursor_shard]`; `None` between shards.
    cursor_reader: Option<BufReader<File>>,
    /// Maximum number of sequences per yielded batch.
    batch_size: usize,
    /// Tokens per sequence, i.e. `seq_length + 1`.
    seq_plus_one: usize,
    /// Forwarded to `LMBatch::from_sequences`.
    pad_id: u32,
    /// Forwarded to `LMBatch::from_sequences`.
    eos_id: u32,
    /// When true, restart from shard 0 once the last shard is exhausted.
    wrap_around: bool,
    /// Number of wrap-around resets performed so far.
    epochs_completed: u64,
    /// SPEC §82 P2-B: emit `eprintln!` when wrap-around fires.
    /// Helps operators detect data starvation (corpus too small for step budget).
    warn_on_wrap_around: bool,
    /// True once at least one full sequence has been read since the
    /// iterator was created or last wrapped. A pass that ends with this
    /// still false proves the corpus cannot yield data, so wrapping again
    /// would spin forever.
    yielded_since_wrap: bool,
}

impl ShardBatchIter {
    /// On-disk size of one token (little-endian `u32`).
    const BYTES_PER_TOKEN: usize = std::mem::size_of::<u32>();

    /// Build an iterator that yields `LMBatch` with `batch_size` sequences
    /// of length `seq_length + 1` (for causal shift).
    ///
    /// Shards are the entries of `dataset_dir` whose extension is `bin`,
    /// visited in lexical path order. Directories that happen to carry a
    /// `.bin` extension are ignored, as are directory entries that cannot
    /// be read.
    ///
    /// # Errors
    ///
    /// Returns `Err` if `dataset_dir` is missing or contains no `.bin`
    /// shards, or (`InvalidInput`) if `seq_length` is so large that the
    /// per-sequence byte size does not fit in `usize`.
    pub fn new(
        dataset_dir: &Path,
        batch_size: usize,
        seq_length: usize,
        pad_id: u32,
        eos_id: u32,
    ) -> io::Result<Self> {
        // Checked arithmetic: in release builds `seq_length + 1` would wrap
        // to 0 and the reader would then yield empty sequences forever.
        let seq_plus_one = seq_length
            .checked_add(1)
            .filter(|tokens| tokens.checked_mul(Self::BYTES_PER_TOKEN).is_some())
            .ok_or_else(|| {
                io::Error::new(
                    io::ErrorKind::InvalidInput,
                    format!("seq_length {} is too large", seq_length),
                )
            })?;

        let mut shards: Vec<PathBuf> = std::fs::read_dir(dataset_dir)?
            .filter_map(Result::ok)
            .map(|entry| entry.path())
            .filter(|path| path.extension().is_some_and(|ext| ext == "bin") && !path.is_dir())
            .collect();
        // `read_dir` order is platform-dependent; sort for determinism.
        shards.sort();
        if shards.is_empty() {
            return Err(io::Error::new(
                io::ErrorKind::NotFound,
                format!("no .bin shards in {}", dataset_dir.display()),
            ));
        }

        Ok(Self {
            shards,
            cursor_shard: 0,
            cursor_reader: None,
            batch_size,
            seq_plus_one,
            pad_id,
            eos_id,
            wrap_around: false,
            epochs_completed: 0,
            warn_on_wrap_around: false,
            yielded_since_wrap: false,
        })
    }

    /// Enable corpus wrap-around: when the last shard is exhausted,
    /// reset the cursor to shard 0 and continue.
    ///
    /// This is the standard ML-training behaviour. Without it, an
    /// 18M-token corpus exhausts in ~2 epochs of a 5K-step run with
    /// batch=16 seq=512, and the upstream `StepFn` falls back to
    /// returning placeholder loss `(1.0, 1.0)` — silently producing
    /// garbage data that breaks convergence. See spec §22 (PR #1073)
    /// for the corpus-bottleneck investigation.
    #[must_use]
    pub fn with_wrap_around(mut self, wrap_around: bool) -> Self {
        self.wrap_around = wrap_around;
        self
    }

    /// Number of times the iterator has cycled through the entire
    /// shard set. Increments each time the last shard is exhausted
    /// AND `wrap_around` was true (so a reset happened). A pass that
    /// produced no sequence at all does not trigger a reset and is
    /// therefore not counted.
    #[must_use]
    pub fn epochs_completed(&self) -> u64 {
        self.epochs_completed
    }

    /// SPEC §82 P2-B: when wrap-around fires, emit a stderr line so operators
    /// can detect data starvation (corpus too small for the requested step
    /// budget). Default off for backward compatibility with tests.
    #[must_use]
    pub fn with_warn_on_wrap_around(mut self, warn: bool) -> Self {
        self.warn_on_wrap_around = warn;
        self
    }

    /// Make sure `cursor_reader` holds an open shard.
    ///
    /// Returns `Ok(true)` when a reader is available and `Ok(false)` when
    /// `cursor_shard` has run past the last shard. Shards that fail to open
    /// are skipped (best effort), so this currently never returns `Err`.
    fn ensure_reader(&mut self) -> io::Result<bool> {
        if self.cursor_reader.is_some() {
            return Ok(true);
        }
        while let Some(path) = self.shards.get(self.cursor_shard) {
            if let Ok(file) = File::open(path) {
                self.cursor_reader = Some(BufReader::new(file));
                return Ok(true);
            }
            // Unopenable shard: move on to the next one.
            self.cursor_shard += 1;
        }
        Ok(false)
    }

    /// Read the next `seq_plus_one` tokens as one sequence.
    ///
    /// Returns `Ok(None)` when the corpus is exhausted: either wrap-around
    /// is off and the last shard has been consumed, or wrap-around is on but
    /// the pass that just ended produced no sequence. A trailing partial
    /// sequence at the end of a shard is discarded.
    ///
    /// # Errors
    ///
    /// Propagates any read error other than end-of-file. The failing shard
    /// is abandoned first, so a later call resumes at the next shard rather
    /// than at an arbitrary (possibly mid-token) offset of the broken one.
    fn read_one_sequence(&mut self) -> io::Result<Option<Vec<u32>>> {
        // Cannot overflow: `new` verified this product fits in `usize`.
        let mut buf = vec![0u8; self.seq_plus_one * Self::BYTES_PER_TOKEN];
        loop {
            if !self.ensure_reader()? {
                // All shards exhausted.
                if !self.wrap_around {
                    return Ok(None);
                }
                if !self.yielded_since_wrap {
                    // A complete pass produced nothing (all shards empty,
                    // too short, or unreadable). Wrapping again would loop
                    // forever, so report exhaustion instead.
                    return Ok(None);
                }
                self.epochs_completed += 1;
                if self.warn_on_wrap_around {
                    eprintln!(
                        "[P2-B] corpus wrap-around #{}: dataset_dir of {} shards exhausted; \
                         cycling. If observed early in run, corpus is too small for the \
                         requested step budget — extend corpus per Chinchilla D ≈ 20·N or \
                         reduce --num-steps.",
                        self.epochs_completed,
                        self.shards.len(),
                    );
                }
                self.cursor_shard = 0;
                self.cursor_reader = None;
                self.yielded_since_wrap = false;
                if !self.ensure_reader()? {
                    // Still no readable shard after reset — give up
                    // to avoid infinite loop on a broken shard set.
                    return Ok(None);
                }
            }
            let reader = self.cursor_reader.as_mut().expect("reader set above");
            match reader.read_exact(&mut buf) {
                Ok(()) => {
                    self.yielded_since_wrap = true;
                    let seq = buf
                        .chunks_exact(Self::BYTES_PER_TOKEN)
                        .map(|c| u32::from_le_bytes([c[0], c[1], c[2], c[3]]))
                        .collect();
                    return Ok(Some(seq));
                }
                Err(err) => {
                    // Either way this shard is finished: on EOF it has no
                    // further full sequence; on a hard error the stream
                    // position is unspecified and must not be reused.
                    self.cursor_reader = None;
                    self.cursor_shard += 1;
                    if err.kind() != io::ErrorKind::UnexpectedEof {
                        return Err(err);
                    }
                }
            }
        }
    }
}
176
177impl Iterator for ShardBatchIter {
178    type Item = LMBatch;
179
180    fn next(&mut self) -> Option<LMBatch> {
181        let mut seqs: Vec<Vec<u32>> = Vec::with_capacity(self.batch_size);
182        for _ in 0..self.batch_size {
183            match self.read_one_sequence() {
184                Ok(Some(seq)) => seqs.push(seq),
185                Ok(None) => break,
186                Err(_) => break,
187            }
188        }
189        if seqs.is_empty() {
190            None
191        } else {
192            Some(LMBatch::from_sequences(&seqs, self.pad_id, self.eos_id))
193        }
194    }
195}
196
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Serialise `tokens` as little-endian u32 and write them to `dir/name`.
    fn write_shard(dir: &Path, name: &str, tokens: &[u32]) {
        let bytes: Vec<u8> = tokens.iter().flat_map(|t| t.to_le_bytes()).collect();
        std::fs::write(dir.join(name), bytes).expect("shard write");
    }

    /// Fresh temp dir holding a single `shard-0.bin` with tokens `0..40`,
    /// i.e. 8 sequences of 5 tokens when `seq_length = 4`.
    fn forty_token_corpus() -> TempDir {
        let tmp = TempDir::new().expect("tempdir");
        let tokens: Vec<u32> = (0u32..40).collect();
        write_shard(tmp.path(), "shard-0.bin", &tokens);
        tmp
    }

    /// Wrap-around regression guard: once `with_wrap_around(true)` is set,
    /// the iterator MUST keep producing batches after the last shard has
    /// been consumed. This is the SHIP-007-adjacent corpus-bottleneck fix —
    /// without wrap-around an N-token corpus exhausts in 1 epoch and
    /// the Cuda*StepFn falls back to placeholder `(1.0, 1.0)` losses,
    /// silently producing garbage gradients (observed 2026-04-26 on a
    /// 5K-step run that early-stopped at epoch 4 with train_loss=1.0).
    #[test]
    fn wrap_around_continues_past_shard_exhaustion() {
        let tmp = forty_token_corpus();
        let mut iter =
            ShardBatchIter::new(tmp.path(), 2, 4, 0, 0).expect("iter").with_wrap_around(true);
        // One pass is 4 batches; pulling 12 therefore spans 3 passes and
        // would hit `None` after the 4th batch if wrap-around were off.
        let batches: Vec<_> =
            (0..12).map(|_| iter.next().expect("wrap-around must keep yielding")).collect();
        assert_eq!(batches.len(), 12, "12 batches across 3 simulated epochs");
        let wraps = iter.epochs_completed();
        assert!(
            wraps >= 2,
            "epochs_completed = {} should reflect at least 2 wrap resets",
            wraps
        );
    }

    /// SPEC §82 P2-B: --warn-on-wrap-around surfaces data starvation by
    /// printing a stderr line each time the corpus cycles. Here we only
    /// check that the wrap counter advances and iteration keeps working;
    /// capturing stderr is brittle across test harnesses, so the literal
    /// text is left to behavioural integration tests.
    #[test]
    fn warn_on_wrap_around_does_not_break_iteration() {
        let tmp = forty_token_corpus();
        let mut iter = ShardBatchIter::new(tmp.path(), 2, 4, 0, 0)
            .expect("iter")
            .with_wrap_around(true)
            .with_warn_on_wrap_around(true);
        let batches: Vec<_> =
            (0..8).map(|_| iter.next().expect("must keep yielding with wrap")).collect();
        assert_eq!(batches.len(), 8);
        assert!(
            iter.epochs_completed() >= 1,
            "at least one wrap should have fired with 4-batches/epoch × 8 pulls",
        );
    }

    /// SPEC §82 P2-B: the warning flag defaults to off, and enabling it
    /// without wrap_around changes nothing (no warning, no wrap, exhaustion
    /// is final).
    #[test]
    fn warn_without_wrap_is_inert() {
        let tmp = forty_token_corpus();
        let iter = ShardBatchIter::new(tmp.path(), 2, 4, 0, 0)
            .expect("iter")
            .with_warn_on_wrap_around(true);
        assert_eq!(iter.count(), 4, "still terminates after one pass without wrap");
    }

    /// The default (wrap_around=false) keeps the historical contract:
    /// `None` once the corpus has been consumed.
    #[test]
    fn no_wrap_around_terminates_on_exhaustion() {
        let tmp = forty_token_corpus();
        let iter = ShardBatchIter::new(tmp.path(), 2, 4, 0, 0).expect("iter");
        assert_eq!(iter.count(), 4, "default: 8 seqs / batch=2 = 4 batches then None");
    }

    /// 40 tokens = 8 × (seq=4+1) sequences, so batch_size=2 gives 4 batches.
    #[test]
    fn single_shard_yields_expected_batch_count() {
        let tmp = forty_token_corpus();
        let batches: Vec<_> =
            ShardBatchIter::new(tmp.path(), 2, 4, 0, 0).expect("iter").collect();
        assert_eq!(batches.len(), 4, "8 seqs / batch_size=2 = 4 batches");
        let first = &batches[0];
        assert_eq!(first.batch_size, 2);
        assert_eq!(first.seq_len, 4);
    }

    /// A directory without any `.bin` file is rejected at construction.
    #[test]
    fn empty_dir_errors() {
        let tmp = TempDir::new().expect("tempdir");
        assert!(
            ShardBatchIter::new(tmp.path(), 2, 4, 0, 0).is_err(),
            "empty dir must error"
        );
    }

    /// Shards are consumed in lexical path order: `shard-0` before `shard-1`.
    #[test]
    fn multi_shard_ordering_is_lexical() {
        let tmp = TempDir::new().expect("tempdir");
        let low: Vec<u32> = (0u32..10).collect();
        let high: Vec<u32> = (100u32..110).collect();
        write_shard(tmp.path(), "shard-0.bin", &low);
        write_shard(tmp.path(), "shard-1.bin", &high);
        let mut iter = ShardBatchIter::new(tmp.path(), 1, 4, 0, 0).expect("iter");
        let first = iter.next().expect("first batch");
        assert_eq!(first.get_input(0).expect("input0")[0], 0, "shard-0 first");
    }
}