1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
//! Minimal tokenized-shard reader for MODEL-2 pretrain MVP (task #111).
//!
//! Reads a directory of `.bin` files containing little-endian u32 tokens,
//! chunks them into `seq_length + 1` sequences, and yields `LMBatch`es of
//! `batch_size` sequences. No licensing filter, no MinHash dedup, no PII
//! scrub — those belong to `apr-corpus-ingest run`.
//!
//! Contract: `contracts/dataset-thestack-python-v1.yaml` (shard format).
use crate::train::transformer_trainer::LMBatch;
use std::fs::File;
use std::io::{self, BufReader, Read};
use std::path::{Path, PathBuf};
/// Streaming iterator over `LMBatch`es produced from a directory of
/// `.bin` token shards (little-endian u32).
///
/// Default behaviour matches the historical contract: when the last
/// shard is exhausted, `next()` returns `None`. For training paths
/// where the corpus is finite but the run extends beyond a single
/// epoch, opt in to loop-on-exhaust via `with_wrap_around(true)`.
/// This is the standard PyTorch/HuggingFace behaviour; `apr pretrain`
/// uses it in the real-corpus drive paths so that step-count budgets
/// are honoured even on small datasets.
///
/// Sequences never span two shards: bytes at the end of a shard that do
/// not fill a whole `seq_length + 1` sequence are discarded.
pub struct ShardBatchIter {
    /// Paths of all `.bin` shards, sorted by path.
    shards: Vec<PathBuf>,
    /// Index into `shards` of the shard being read (or the next one to open).
    /// Equal to `shards.len()` once every shard has been consumed.
    cursor_shard: usize,
    /// Open reader for `shards[cursor_shard]`; `None` until the shard is
    /// opened and again after it hits end-of-file.
    cursor_reader: Option<BufReader<File>>,
    /// Maximum number of sequences per emitted batch.
    batch_size: usize,
    /// Tokens per sequence, i.e. `seq_length + 1`.
    seq_plus_one: usize,
    /// Passed through to `LMBatch::from_sequences`.
    pad_id: u32,
    /// Passed through to `LMBatch::from_sequences`.
    eos_id: u32,
    /// When true, restart from shard 0 after the last shard is exhausted.
    wrap_around: bool,
    /// Number of wrap-around resets performed so far.
    epochs_completed: u64,
    /// SPEC §82 P2-B: emit `eprintln!` when wrap-around fires.
    /// Helps operators detect data starvation (corpus too small for step budget).
    warn_on_wrap_around: bool,
}

impl ShardBatchIter {
    /// Build an iterator that yields `LMBatch` with `batch_size` sequences
    /// of length `seq_length + 1` (for causal shift).
    ///
    /// Directory entries that cannot be inspected are skipped; the remaining
    /// `.bin` paths are sorted so iteration order is deterministic.
    ///
    /// # Errors
    ///
    /// * `InvalidInput` if `seq_length + 1` tokens cannot be expressed as a
    ///   byte count in `usize`.
    /// * Whatever `std::fs::read_dir` reports if `dataset_dir` cannot be read
    ///   (for example because it is missing).
    /// * `NotFound` if the directory contains no `.bin` shards.
    pub fn new(
        dataset_dir: &Path,
        batch_size: usize,
        seq_length: usize,
        pad_id: u32,
        eos_id: u32,
    ) -> io::Result<Self> {
        // Reject sizes whose token count or byte count would overflow; a
        // wrapped value of 0 would otherwise yield endless empty sequences.
        let seq_plus_one = seq_length
            .checked_add(1)
            .filter(|n| n.checked_mul(4).is_some())
            .ok_or_else(|| {
                io::Error::new(
                    io::ErrorKind::InvalidInput,
                    format!("seq_length {seq_length} is too large"),
                )
            })?;

        let mut shards: Vec<PathBuf> = std::fs::read_dir(dataset_dir)?
            .filter_map(|e| e.ok())
            .map(|e| e.path())
            .filter(|p| p.extension().is_some_and(|ext| ext == "bin"))
            .collect();
        shards.sort();
        if shards.is_empty() {
            return Err(io::Error::new(
                io::ErrorKind::NotFound,
                format!("no .bin shards in {}", dataset_dir.display()),
            ));
        }

        Ok(Self {
            shards,
            cursor_shard: 0,
            cursor_reader: None,
            batch_size,
            seq_plus_one,
            pad_id,
            eos_id,
            wrap_around: false,
            epochs_completed: 0,
            warn_on_wrap_around: false,
        })
    }

    /// Enable corpus wrap-around: when the last shard is exhausted,
    /// reset the cursor to shard 0 and continue.
    ///
    /// This is the standard ML-training behaviour. Without it, a 5K-step
    /// run with batch=16 seq=512 needs more than two passes over an
    /// 18M-token corpus, so the data runs out mid-run and the upstream
    /// `StepFn` falls back to returning placeholder loss `(1.0, 1.0)` —
    /// silently producing garbage data that breaks convergence. See spec
    /// §22 (PR #1073) for the corpus-bottleneck investigation.
    ///
    /// If a complete pass over the shard set produces no sequence at all
    /// (every shard is unreadable or shorter than one sequence), iteration
    /// ends instead of cycling forever.
    #[must_use]
    pub fn with_wrap_around(mut self, wrap_around: bool) -> Self {
        self.wrap_around = wrap_around;
        self
    }

    /// Number of times the iterator has cycled through the entire
    /// shard set. Increments each time the last shard is exhausted
    /// AND `wrap_around` was true (so a reset happened).
    #[must_use]
    pub fn epochs_completed(&self) -> u64 {
        self.epochs_completed
    }

    /// SPEC §82 P2-B: when wrap-around fires, emit a stderr line so operators
    /// can detect data starvation (corpus too small for the requested step
    /// budget). Default off for backward compatibility with tests.
    #[must_use]
    pub fn with_warn_on_wrap_around(mut self, warn: bool) -> Self {
        self.warn_on_wrap_around = warn;
        self
    }

    /// Make sure `cursor_reader` holds an open shard.
    ///
    /// Returns `Ok(true)` when a reader is available and `Ok(false)` when
    /// `cursor_shard` has run past the last shard. Shards that fail to open
    /// are skipped (best effort) by advancing `cursor_shard`.
    fn ensure_reader(&mut self) -> io::Result<bool> {
        if self.cursor_reader.is_some() {
            return Ok(true);
        }
        while let Some(path) = self.shards.get(self.cursor_shard) {
            match File::open(path) {
                Ok(file) => {
                    self.cursor_reader = Some(BufReader::new(file));
                    return Ok(true);
                }
                Err(_) => self.cursor_shard += 1,
            }
        }
        Ok(false)
    }

    /// Read the next `seq_plus_one` tokens as one sequence.
    ///
    /// Returns `Ok(None)` when the corpus is exhausted: either wrap-around is
    /// off, or a full pass starting at shard 0 produced no sequence. I/O
    /// errors other than end-of-file are returned to the caller.
    fn read_one_sequence(&mut self) -> io::Result<Option<Vec<u32>>> {
        let tokens_per_seq = self.seq_plus_one;
        let mut buf = vec![0u8; tokens_per_seq * 4];
        // Set once this call has reset to shard 0. Hitting exhaustion again
        // in the same call means a complete pass yielded nothing, so another
        // reset could only spin forever.
        let mut wrapped_this_call = false;
        loop {
            if !self.ensure_reader()? {
                // All shards exhausted. If wrap-around is on, reset cursor
                // and start over; else return None as before.
                if !self.wrap_around || wrapped_this_call {
                    return Ok(None);
                }
                wrapped_this_call = true;
                self.epochs_completed += 1;
                if self.warn_on_wrap_around {
                    eprintln!(
                        "[P2-B] corpus wrap-around #{}: dataset_dir of {} shards exhausted; \
                         cycling. If observed early in run, corpus is too small for the \
                         requested step budget — extend corpus per Chinchilla D ≈ 20·N or \
                         reduce --num-steps.",
                        self.epochs_completed,
                        self.shards.len(),
                    );
                }
                self.cursor_shard = 0;
                self.cursor_reader = None;
                if !self.ensure_reader()? {
                    // Still no readable shard after reset — give up
                    // to avoid infinite loop on a broken shard set.
                    return Ok(None);
                }
            }
            let reader = self.cursor_reader.as_mut().expect("reader set above");
            match reader.read_exact(&mut buf) {
                Ok(()) => {
                    let seq: Vec<u32> = buf
                        .chunks_exact(4)
                        .map(|c| u32::from_le_bytes([c[0], c[1], c[2], c[3]]))
                        .collect();
                    return Ok(Some(seq));
                }
                Err(err) if err.kind() == io::ErrorKind::UnexpectedEof => {
                    // Fewer than one sequence left in this shard: drop the
                    // remainder and move on to the next shard.
                    self.cursor_reader = None;
                    self.cursor_shard += 1;
                }
                Err(err) => return Err(err),
            }
        }
    }
}
impl Iterator for ShardBatchIter {
    type Item = LMBatch;

    /// Assemble the next batch of up to `batch_size` sequences.
    ///
    /// Returns `None` when no sequence could be read. A batch may hold fewer
    /// than `batch_size` sequences if the corpus ends (or a read fails)
    /// part-way through. `Iterator::next` cannot carry an error, so an I/O
    /// failure is reported on stderr rather than being dropped silently, and
    /// the batch is cut short at that point.
    fn next(&mut self) -> Option<LMBatch> {
        let mut seqs: Vec<Vec<u32>> = Vec::with_capacity(self.batch_size);
        while seqs.len() < self.batch_size {
            match self.read_one_sequence() {
                Ok(Some(seq)) => seqs.push(seq),
                Ok(None) => break,
                Err(err) => {
                    // Make the failure visible: otherwise a disk error is
                    // indistinguishable from a normally exhausted corpus.
                    let shard = self
                        .shards
                        .get(self.cursor_shard)
                        .map(|p| p.display().to_string())
                        .unwrap_or_else(|| String::from("<no shard>"));
                    eprintln!(
                        "[shard-reader] I/O error reading {shard}: {err}; ending batch early"
                    );
                    break;
                }
            }
        }
        if seqs.is_empty() {
            None
        } else {
            Some(LMBatch::from_sequences(&seqs, self.pad_id, self.eos_id))
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Serialise `tokens` as little-endian u32 values into `dir/name`.
    fn write_shard(dir: &Path, name: &str, tokens: &[u32]) {
        let bytes: Vec<u8> = tokens.iter().flat_map(|tok| tok.to_le_bytes()).collect();
        std::fs::write(dir.join(name), bytes).expect("shard write");
    }

    /// Fresh temp dir holding one shard of tokens `0..40`, i.e. exactly
    /// 8 sequences of 5 tokens when `seq_length` is 4.
    fn forty_token_corpus() -> TempDir {
        let tmp = TempDir::new().expect("tempdir");
        let tokens: Vec<u32> = (0u32..40).collect();
        write_shard(tmp.path(), "shard-0.bin", &tokens);
        tmp
    }

    /// Wrap-around regression guard: with `with_wrap_around(true)` the
    /// iterator must keep producing batches after the only shard has been
    /// read to the end. Related to the SHIP-007-adjacent corpus-bottleneck
    /// fix, where a run without wrap-around ran out of data and the
    /// Cuda*StepFn fell back to placeholder `(1.0, 1.0)` losses (observed
    /// 2026-04-26 on a 5K-step run that early-stopped at epoch 4 with
    /// train_loss=1.0).
    #[test]
    fn wrap_around_continues_past_shard_exhaustion() {
        let tmp = forty_token_corpus();
        let mut iter = ShardBatchIter::new(tmp.path(), 2, 4, 0, 0)
            .expect("iter")
            .with_wrap_around(true);
        // One pass is 4 batches; pulling 12 therefore spans three passes.
        let batches: Vec<_> = (0..12)
            .map(|_| iter.next().expect("wrap-around must keep yielding"))
            .collect();
        assert_eq!(batches.len(), 12, "12 batches across 3 simulated epochs");
        assert!(
            iter.epochs_completed() >= 2,
            "epochs_completed = {} should reflect at least 2 wrap resets",
            iter.epochs_completed()
        );
    }

    /// SPEC §82 P2-B: enabling the wrap-around warning must not disturb
    /// iteration. Only the wrap counter and batch flow are checked here;
    /// the stderr text itself is not asserted because capturing it is
    /// brittle across test harnesses.
    #[test]
    fn warn_on_wrap_around_does_not_break_iteration() {
        let tmp = forty_token_corpus();
        let mut iter = ShardBatchIter::new(tmp.path(), 2, 4, 0, 0)
            .expect("iter")
            .with_wrap_around(true)
            .with_warn_on_wrap_around(true);
        let batches: Vec<_> = (0..8)
            .map(|_| iter.next().expect("must keep yielding with wrap"))
            .collect();
        assert_eq!(batches.len(), 8);
        assert!(
            iter.epochs_completed() >= 1,
            "at least one wrap should have fired with 4-batches/epoch × 8 pulls",
        );
    }

    /// SPEC §82 P2-B: the warning flag alone does not enable wrapping, so
    /// the iterator still ends after a single pass.
    #[test]
    fn warn_without_wrap_is_inert() {
        let tmp = forty_token_corpus();
        let iter = ShardBatchIter::new(tmp.path(), 2, 4, 0, 0)
            .expect("iter")
            .with_warn_on_wrap_around(true);
        assert_eq!(iter.count(), 4, "still terminates after one pass without wrap");
    }

    /// With the default `wrap_around = false` the iterator returns `None`
    /// once the corpus is consumed (the historical contract).
    #[test]
    fn no_wrap_around_terminates_on_exhaustion() {
        let tmp = forty_token_corpus();
        let iter = ShardBatchIter::new(tmp.path(), 2, 4, 0, 0).expect("iter");
        assert_eq!(iter.count(), 4, "default: 8 seqs / batch=2 = 4 batches then None");
    }

    /// 40 tokens = 8 × (seq=4+1), so batch_size=2 gives 4 batches.
    #[test]
    fn single_shard_yields_expected_batch_count() {
        let tmp = forty_token_corpus();
        let batches: Vec<_> = ShardBatchIter::new(tmp.path(), 2, 4, 0, 0)
            .expect("iter")
            .collect();
        assert_eq!(batches.len(), 4, "8 seqs / batch_size=2 = 4 batches");
        let first = &batches[0];
        assert_eq!(first.batch_size, 2);
        assert_eq!(first.seq_len, 4);
    }

    /// A directory without any `.bin` file is rejected at construction.
    #[test]
    fn empty_dir_errors() {
        let tmp = TempDir::new().expect("tempdir");
        assert!(
            ShardBatchIter::new(tmp.path(), 2, 4, 0, 0).is_err(),
            "empty dir must error"
        );
    }

    /// Shards are consumed in sorted path order, so `shard-0` comes first.
    #[test]
    fn multi_shard_ordering_is_lexical() {
        let tmp = TempDir::new().expect("tempdir");
        let low: Vec<u32> = (0u32..10).collect();
        let high: Vec<u32> = (100u32..110).collect();
        write_shard(tmp.path(), "shard-0.bin", &low);
        write_shard(tmp.path(), "shard-1.bin", &high);
        let mut iter = ShardBatchIter::new(tmp.path(), 1, 4, 0, 0).expect("iter");
        let first = iter.next().expect("first batch");
        assert_eq!(first.get_input(0).expect("input0")[0], 0, "shard-0 first");
    }
}