entrenar/train/
shard_reader.rs1use crate::train::transformer_trainer::LMBatch;
11use std::fs::File;
12use std::io::{self, BufReader, Read};
13use std::path::{Path, PathBuf};
14
/// Streams fixed-length token sequences out of a directory of `.bin` shards
/// and groups them into batches.
///
/// Each shard is read as a flat stream of little-endian `u32` tokens, consumed
/// in records of `seq_length + 1` tokens. A trailing partial record at the end
/// of a shard is dropped and reading continues with the next shard.
#[derive(Debug)]
pub struct ShardBatchIter {
    /// Paths of every `.bin` file found in the dataset directory, sorted.
    shards: Vec<PathBuf>,
    /// Index into `shards` of the shard being read (or the next one to open).
    cursor_shard: usize,
    /// Buffered reader over the current shard; `None` when no shard is open.
    cursor_reader: Option<BufReader<File>>,
    /// Maximum number of sequences gathered into one batch.
    batch_size: usize,
    /// Tokens per on-disk record: `seq_length + 1`.
    /// NOTE(review): presumably the extra token lets inputs and shifted targets
    /// be derived from one record -- confirm against `LMBatch::from_sequences`.
    seq_plus_one: usize,
    /// Padding token id handed to the batch constructor.
    pad_id: u32,
    /// End-of-sequence token id handed to the batch constructor.
    eos_id: u32,
    /// When `true`, reading restarts at the first shard after the last one ends.
    wrap_around: bool,
    /// Number of times the cursor has been reset to the first shard.
    epochs_completed: u64,
}

impl ShardBatchIter {
    /// Size in bytes of one serialized token (`u32`).
    const TOKEN_BYTES: usize = std::mem::size_of::<u32>();

    /// Builds an iterator over every `*.bin` file directly inside `dataset_dir`.
    ///
    /// Shards are visited in sorted path order. Wrap-around is disabled by
    /// default; enable it with [`ShardBatchIter::with_wrap_around`].
    ///
    /// # Errors
    ///
    /// Returns the underlying error if `dataset_dir` cannot be listed, and
    /// `io::ErrorKind::NotFound` if it contains no `.bin` file. Individual
    /// directory entries that fail to read are skipped.
    pub fn new(
        dataset_dir: &Path,
        batch_size: usize,
        seq_length: usize,
        pad_id: u32,
        eos_id: u32,
    ) -> io::Result<Self> {
        let mut shards: Vec<PathBuf> = std::fs::read_dir(dataset_dir)?
            .filter_map(Result::ok)
            .map(|entry| entry.path())
            .filter(|path| path.extension().is_some_and(|ext| ext == "bin"))
            .collect();
        // Sorting makes the shard order deterministic regardless of the order
        // in which the OS lists directory entries.
        shards.sort();
        if shards.is_empty() {
            return Err(io::Error::new(
                io::ErrorKind::NotFound,
                format!("no .bin shards in {}", dataset_dir.display()),
            ));
        }
        Ok(Self {
            shards,
            cursor_shard: 0,
            cursor_reader: None,
            batch_size,
            seq_plus_one: seq_length + 1,
            pad_id,
            eos_id,
            wrap_around: false,
            epochs_completed: 0,
        })
    }

    /// Sets whether reading restarts from the first shard once every shard has
    /// been consumed, and returns the updated iterator.
    #[must_use]
    pub fn with_wrap_around(mut self, wrap_around: bool) -> Self {
        self.wrap_around = wrap_around;
        self
    }

    /// Number of wrap-around resets performed so far.
    #[must_use]
    pub fn epochs_completed(&self) -> u64 {
        self.epochs_completed
    }

    /// Makes sure a shard is open, advancing past shards that cannot be opened.
    ///
    /// Returns `Ok(true)` when `cursor_reader` is populated and `Ok(false)`
    /// when the cursor has moved past the last shard.
    fn ensure_reader(&mut self) -> io::Result<bool> {
        if self.cursor_reader.is_some() {
            return Ok(true);
        }
        while let Some(path) = self.shards.get(self.cursor_shard) {
            match File::open(path) {
                Ok(file) => {
                    self.cursor_reader = Some(BufReader::new(file));
                    return Ok(true);
                }
                // Best effort: a shard that fails to open is skipped.
                Err(_) => self.cursor_shard += 1,
            }
        }
        Ok(false)
    }

    /// Reads the next full record of `seq_plus_one` tokens.
    ///
    /// Returns `Ok(None)` when no further record is available: either the
    /// shards are exhausted with wrap-around disabled, or wrap-around is
    /// enabled but a complete pass over all shards produced no full record.
    ///
    /// # Errors
    ///
    /// Propagates any read error other than `UnexpectedEof`, which is treated
    /// as "end of this shard".
    fn read_one_sequence(&mut self) -> io::Result<Option<Vec<u32>>> {
        let tokens_per_seq = self.seq_plus_one;
        let mut buf = vec![0u8; tokens_per_seq * Self::TOKEN_BYTES];
        // Set once this call has wrapped back to shard 0. A successful read
        // returns immediately, so running out of shards a second time means a
        // whole pass yielded nothing; wrapping again would loop forever.
        let mut wrapped_this_call = false;
        loop {
            if !self.ensure_reader()? {
                if !self.wrap_around || wrapped_this_call {
                    return Ok(None);
                }
                self.epochs_completed += 1;
                self.cursor_shard = 0;
                self.cursor_reader = None;
                wrapped_this_call = true;
                if !self.ensure_reader()? {
                    // No shard could be opened at all.
                    return Ok(None);
                }
            }
            let reader = self.cursor_reader.as_mut().expect("reader set above");
            match reader.read_exact(&mut buf) {
                Ok(()) => {
                    let seq: Vec<u32> = buf
                        .chunks_exact(Self::TOKEN_BYTES)
                        .map(|c| u32::from_le_bytes([c[0], c[1], c[2], c[3]]))
                        .collect();
                    return Ok(Some(seq));
                }
                Err(ref e) if e.kind() == io::ErrorKind::UnexpectedEof => {
                    // Shard finished (any trailing partial record is dropped);
                    // move on to the next shard.
                    self.cursor_reader = None;
                    self.cursor_shard += 1;
                }
                Err(e) => return Err(e),
            }
        }
    }
}
153
154impl Iterator for ShardBatchIter {
155 type Item = LMBatch;
156
157 fn next(&mut self) -> Option<LMBatch> {
158 let mut seqs: Vec<Vec<u32>> = Vec::with_capacity(self.batch_size);
159 for _ in 0..self.batch_size {
160 match self.read_one_sequence() {
161 Ok(Some(seq)) => seqs.push(seq),
162 Ok(None) => break,
163 Err(_) => break,
164 }
165 }
166 if seqs.is_empty() {
167 None
168 } else {
169 Some(LMBatch::from_sequences(&seqs, self.pad_id, self.eos_id))
170 }
171 }
172}
173
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Serializes `tokens` as little-endian `u32`s into `dir/name`.
    fn write_shard(dir: &Path, name: &str, tokens: &[u32]) {
        let mut payload: Vec<u8> = Vec::with_capacity(tokens.len() * 4);
        payload.extend(tokens.iter().flat_map(|tok| tok.to_le_bytes()));
        std::fs::write(dir.join(name), payload).expect("shard write");
    }

    /// With wrap-around enabled the iterator keeps producing batches after the
    /// only shard is exhausted, and counts each reset.
    #[test]
    fn wrap_around_continues_past_shard_exhaustion() {
        let tmp = TempDir::new().expect("tempdir");
        // 40 tokens / (seq_length 4 + 1) = 8 sequences = 4 batches per pass.
        let shard_tokens: Vec<u32> = (0u32..40).collect();
        write_shard(tmp.path(), "shard-0.bin", &shard_tokens);
        let mut iter = ShardBatchIter::new(tmp.path(), 2, 4, 0, 0)
            .expect("iter")
            .with_wrap_around(true);
        let batches: Vec<_> = (0..12)
            .map(|_| iter.next().expect("wrap-around must keep yielding"))
            .collect();
        assert_eq!(batches.len(), 12, "12 batches across 3 simulated epochs");
        let epochs = iter.epochs_completed();
        assert!(
            epochs >= 2,
            "epochs_completed = {} should reflect at least 2 wrap resets",
            epochs
        );
    }

    /// Without wrap-around the iterator ends once the shard is consumed.
    #[test]
    fn no_wrap_around_terminates_on_exhaustion() {
        let tmp = TempDir::new().expect("tempdir");
        let shard_tokens: Vec<u32> = (0u32..40).collect();
        write_shard(tmp.path(), "shard-0.bin", &shard_tokens);
        let batch_count = ShardBatchIter::new(tmp.path(), 2, 4, 0, 0)
            .expect("iter")
            .count();
        assert_eq!(batch_count, 4, "default: 8 seqs / batch=2 = 4 batches then None");
    }

    /// A single shard yields the expected number and shape of batches.
    #[test]
    fn single_shard_yields_expected_batch_count() {
        let tmp = TempDir::new().expect("tempdir");
        let shard_tokens: Vec<u32> = (0u32..40).collect();
        write_shard(tmp.path(), "shard-0.bin", &shard_tokens);
        let batches: Vec<_> = ShardBatchIter::new(tmp.path(), 2, 4, 0, 0)
            .expect("iter")
            .collect();
        assert_eq!(batches.len(), 4, "8 seqs / batch_size=2 = 4 batches");
        let head = &batches[0];
        assert_eq!(head.batch_size, 2);
        assert_eq!(head.seq_len, 4);
    }

    /// Constructing over a directory with no shards is an error.
    #[test]
    fn empty_dir_errors() {
        let tmp = TempDir::new().expect("tempdir");
        assert!(
            ShardBatchIter::new(tmp.path(), 2, 4, 0, 0).is_err(),
            "empty dir must error"
        );
    }

    /// Shards are consumed in sorted file-name order.
    #[test]
    fn multi_shard_ordering_is_lexical() {
        let tmp = TempDir::new().expect("tempdir");
        let low: Vec<u32> = (0u32..10).collect();
        let high: Vec<u32> = (100u32..110).collect();
        write_shard(tmp.path(), "shard-0.bin", &low);
        write_shard(tmp.path(), "shard-1.bin", &high);
        let mut iter = ShardBatchIter::new(tmp.path(), 1, 4, 0, 0).expect("iter");
        let first = iter.next().expect("first batch");
        let leading_token = first.get_input(0).expect("input0")[0];
        assert_eq!(leading_token, 0, "shard-0 first");
    }
}