1use crate::train::transformer_trainer::LMBatch;
11use std::fs::File;
12use std::io::{self, BufReader, Read};
13use std::path::{Path, PathBuf};
14
/// Streams fixed-length token sequences out of a directory of `.bin` shards.
///
/// Every shard is treated as a flat stream of little-endian `u32` token ids that
/// is cut into consecutive records of `seq_length + 1` tokens. Shards are visited
/// in sorted path order; a trailing partial record at the end of a shard is
/// discarded.
pub struct ShardBatchIter {
    /// Sorted paths of all `.bin` files found in the dataset directory.
    shards: Vec<PathBuf>,
    /// Index into `shards` of the shard being read (or the next one to open).
    cursor_shard: usize,
    /// Open reader for `shards[cursor_shard]`; `None` before the first open and
    /// after a shard has been exhausted.
    cursor_reader: Option<BufReader<File>>,
    /// Maximum number of sequences collected into one batch.
    batch_size: usize,
    /// Tokens per record, i.e. `seq_length + 1`.
    seq_plus_one: usize,
    /// Padding token id handed to the batch constructor.
    pad_id: u32,
    /// End-of-sequence token id handed to the batch constructor.
    eos_id: u32,
    /// When `true`, reading restarts at the first shard once all are exhausted.
    wrap_around: bool,
    /// Number of times the reader has rewound to the first shard.
    epochs_completed: u64,
    /// When `true`, each rewind prints a diagnostic to stderr.
    warn_on_wrap_around: bool,
}

impl ShardBatchIter {
    /// Builds an iterator over every `*.bin` file directly inside `dataset_dir`.
    ///
    /// Wrap-around and the wrap-around warning both start disabled; enable them
    /// with [`with_wrap_around`](Self::with_wrap_around) and
    /// [`with_warn_on_wrap_around`](Self::with_warn_on_wrap_around).
    ///
    /// # Errors
    ///
    /// * Any error from reading `dataset_dir`.
    /// * [`io::ErrorKind::NotFound`] when the directory holds no `.bin` file.
    /// * [`io::ErrorKind::InvalidInput`] when `seq_length` is so large that the
    ///   per-record byte count `(seq_length + 1) * 4` overflows `usize`.
    pub fn new(
        dataset_dir: &Path,
        batch_size: usize,
        seq_length: usize,
        pad_id: u32,
        eos_id: u32,
    ) -> io::Result<Self> {
        // Directory entries that cannot be read are skipped rather than reported.
        let mut shards: Vec<PathBuf> = std::fs::read_dir(dataset_dir)?
            .filter_map(|e| e.ok())
            .map(|e| e.path())
            .filter(|p| p.extension().is_some_and(|ext| ext == "bin"))
            .collect();
        // `read_dir` order is platform dependent; sorting makes the pass deterministic.
        shards.sort();
        if shards.is_empty() {
            return Err(io::Error::new(
                io::ErrorKind::NotFound,
                format!("no .bin shards in {}", dataset_dir.display()),
            ));
        }
        // Reject sizes whose record length or byte length would overflow instead of
        // panicking (debug) or silently wrapping to a bogus buffer size (release).
        let seq_plus_one = match seq_length.checked_add(1) {
            Some(n) if n.checked_mul(4).is_some() => n,
            _ => {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidInput,
                    format!("seq_length {} is too large", seq_length),
                ));
            }
        };
        Ok(Self {
            shards,
            cursor_shard: 0,
            cursor_reader: None,
            batch_size,
            seq_plus_one,
            pad_id,
            eos_id,
            wrap_around: false,
            epochs_completed: 0,
            warn_on_wrap_around: false,
        })
    }

    /// Enables or disables restarting from the first shard after the last one
    /// has been consumed.
    #[must_use]
    pub fn with_wrap_around(mut self, wrap_around: bool) -> Self {
        self.wrap_around = wrap_around;
        self
    }

    /// Number of times the reader has rewound to the first shard so far.
    #[must_use]
    pub fn epochs_completed(&self) -> u64 {
        self.epochs_completed
    }

    /// Enables or disables the stderr diagnostic printed on every rewind.
    ///
    /// Has no effect unless wrap-around is also enabled.
    #[must_use]
    pub fn with_warn_on_wrap_around(mut self, warn: bool) -> Self {
        self.warn_on_wrap_around = warn;
        self
    }

    /// Makes sure `cursor_reader` points at an open shard.
    ///
    /// Returns `Ok(true)` when a reader is available and `Ok(false)` when no
    /// shard at or after `cursor_shard` could be opened. Shards that fail to open
    /// are skipped (best effort), so this currently never returns `Err`.
    fn ensure_reader(&mut self) -> io::Result<bool> {
        if self.cursor_reader.is_some() {
            return Ok(true);
        }
        while let Some(path) = self.shards.get(self.cursor_shard) {
            match File::open(path) {
                Ok(file) => {
                    self.cursor_reader = Some(BufReader::new(file));
                    return Ok(true);
                }
                // Unopenable shard: move on to the next one.
                Err(_) => self.cursor_shard += 1,
            }
        }
        Ok(false)
    }

    /// Reads the next record of `seq_plus_one` tokens.
    ///
    /// Returns `Ok(None)` when the data is exhausted: either wrap-around is off
    /// and every shard has been consumed, or wrap-around is on but a complete
    /// pass starting at the first shard produced no full record (all shards are
    /// unreadable or shorter than one record). The latter guard prevents an
    /// endless rewind loop.
    ///
    /// # Errors
    ///
    /// Propagates any read error other than `UnexpectedEof`, which is treated as
    /// the end of the current shard.
    fn read_one_sequence(&mut self) -> io::Result<Option<Vec<u32>>> {
        let tokens_per_seq = self.seq_plus_one;
        let mut buf = vec![0u8; tokens_per_seq * 4];
        // Set once this call has rewound to shard 0. Running out of shards a second
        // time within the same call means a full pass yielded nothing, so rewinding
        // again could never make progress.
        let mut rewound = false;
        loop {
            if !self.ensure_reader()? {
                if !self.wrap_around || rewound {
                    return Ok(None);
                }
                rewound = true;
                self.epochs_completed += 1;
                if self.warn_on_wrap_around {
                    eprintln!(
                        "[P2-B] corpus wrap-around #{}: dataset_dir of {} shards exhausted; \
                         cycling. If observed early in run, corpus is too small for the \
                         requested step budget — extend corpus per Chinchilla D ≈ 20·N or \
                         reduce --num-steps.",
                        self.epochs_completed,
                        self.shards.len(),
                    );
                }
                self.cursor_shard = 0;
                self.cursor_reader = None;
                if !self.ensure_reader()? {
                    return Ok(None);
                }
            }
            let reader = self.cursor_reader.as_mut().expect("reader set above");
            match reader.read_exact(&mut buf) {
                Ok(()) => {
                    let seq: Vec<u32> = buf
                        .chunks_exact(4)
                        .map(|c| u32::from_le_bytes([c[0], c[1], c[2], c[3]]))
                        .collect();
                    return Ok(Some(seq));
                }
                // Fewer than one full record left: drop the tail and advance.
                Err(ref e) if e.kind() == io::ErrorKind::UnexpectedEof => {
                    self.cursor_reader = None;
                    self.cursor_shard += 1;
                }
                Err(e) => return Err(e),
            }
        }
    }
}
176
177impl Iterator for ShardBatchIter {
178 type Item = LMBatch;
179
180 fn next(&mut self) -> Option<LMBatch> {
181 let mut seqs: Vec<Vec<u32>> = Vec::with_capacity(self.batch_size);
182 for _ in 0..self.batch_size {
183 match self.read_one_sequence() {
184 Ok(Some(seq)) => seqs.push(seq),
185 Ok(None) => break,
186 Err(_) => break,
187 }
188 }
189 if seqs.is_empty() {
190 None
191 } else {
192 Some(LMBatch::from_sequences(&seqs, self.pad_id, self.eos_id))
193 }
194 }
195}
196
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Serialises `tokens` as little-endian `u32`s into the file `dir/name`.
    fn write_shard(dir: &Path, name: &str, tokens: &[u32]) {
        let bytes: Vec<u8> = tokens.iter().flat_map(|t| t.to_le_bytes()).collect();
        std::fs::write(dir.join(name), bytes).expect("shard write");
    }

    /// Fresh temp dir holding one shard with tokens `0..40`.
    ///
    /// With `seq_length = 4` (5 tokens per record) that is 8 sequences, i.e.
    /// 4 batches of 2 per pass over the data.
    fn forty_token_dir() -> TempDir {
        let tmp = TempDir::new().expect("tempdir");
        let tokens: Vec<u32> = (0u32..40).collect();
        write_shard(tmp.path(), "shard-0.bin", &tokens);
        tmp
    }

    /// Wrap-around keeps producing batches after the only shard runs out.
    #[test]
    fn wrap_around_continues_past_shard_exhaustion() {
        let tmp = forty_token_dir();
        let mut iter = ShardBatchIter::new(tmp.path(), 2, 4, 0, 0)
            .expect("iter")
            .with_wrap_around(true);
        let batches: Vec<_> = (0..12)
            .map(|_| iter.next().expect("wrap-around must keep yielding"))
            .collect();
        assert_eq!(batches.len(), 12, "12 batches across 3 simulated epochs");
        assert!(
            iter.epochs_completed() >= 2,
            "epochs_completed = {} should reflect at least 2 wrap resets",
            iter.epochs_completed()
        );
    }

    /// Turning the warning on must not change what the iterator yields.
    #[test]
    fn warn_on_wrap_around_does_not_break_iteration() {
        let tmp = forty_token_dir();
        let mut iter = ShardBatchIter::new(tmp.path(), 2, 4, 0, 0)
            .expect("iter")
            .with_wrap_around(true)
            .with_warn_on_wrap_around(true);
        let batches: Vec<_> = (0..8)
            .map(|_| iter.next().expect("must keep yielding with wrap"))
            .collect();
        assert_eq!(batches.len(), 8);
        assert!(
            iter.epochs_completed() >= 1,
            "at least one wrap should have fired with 4-batches/epoch × 8 pulls",
        );
    }

    /// The warning flag alone does not enable wrap-around.
    #[test]
    fn warn_without_wrap_is_inert() {
        let tmp = forty_token_dir();
        let iter = ShardBatchIter::new(tmp.path(), 2, 4, 0, 0)
            .expect("iter")
            .with_warn_on_wrap_around(true);
        assert_eq!(iter.count(), 4, "still terminates after one pass without wrap");
    }

    /// Default configuration stops after a single pass.
    #[test]
    fn no_wrap_around_terminates_on_exhaustion() {
        let tmp = forty_token_dir();
        let iter = ShardBatchIter::new(tmp.path(), 2, 4, 0, 0).expect("iter");
        assert_eq!(iter.count(), 4, "default: 8 seqs / batch=2 = 4 batches then None");
    }

    /// One shard produces the expected number and shape of batches.
    #[test]
    fn single_shard_yields_expected_batch_count() {
        let tmp = forty_token_dir();
        let batches: Vec<_> = ShardBatchIter::new(tmp.path(), 2, 4, 0, 0)
            .expect("iter")
            .collect();
        assert_eq!(batches.len(), 4, "8 seqs / batch_size=2 = 4 batches");
        let first = &batches[0];
        assert_eq!(first.batch_size, 2);
        assert_eq!(first.seq_len, 4);
    }

    /// A directory without `.bin` files is rejected at construction time.
    #[test]
    fn empty_dir_errors() {
        let tmp = TempDir::new().expect("tempdir");
        assert!(
            ShardBatchIter::new(tmp.path(), 2, 4, 0, 0).is_err(),
            "empty dir must error"
        );
    }

    /// Shards are consumed in sorted path order, so `shard-0` comes first.
    #[test]
    fn multi_shard_ordering_is_lexical() {
        let tmp = TempDir::new().expect("tempdir");
        let low: Vec<u32> = (0u32..10).collect();
        let high: Vec<u32> = (100u32..110).collect();
        write_shard(tmp.path(), "shard-0.bin", &low);
        write_shard(tmp.path(), "shard-1.bin", &high);
        let mut iter = ShardBatchIter::new(tmp.path(), 1, 4, 0, 0).expect("iter");
        let first = iter.next().expect("first batch");
        assert_eq!(first.get_input(0).expect("input0")[0], 0, "shard-0 first");
    }
}