1use anyhow::Result;
2use candle_core::{Device, Tensor};
3use rand::Rng;
4use rand::seq::SliceRandom;
5use serde::Deserialize;
6use std::io::{self, BufRead, BufReader, Read};
7use std::path::Path;
8
9use crate::io as file_io;
10use crate::tokenizer::Tokenizer;
11
12#[derive(Deserialize)]
13struct JsonlRecord {
14 text: String,
15}
16
/// A flat stream of token ids that is sampled in fixed-length windows.
#[derive(Debug, Clone)]
pub struct Dataset {
    /// All token ids, concatenated in load order.
    tokens: Vec<u32>,
    /// Number of tokens in each input (and target) window.
    seq_len: usize,
}
21
22impl Dataset {
23 pub fn new(tokens: Vec<u32>, seq_len: usize) -> Self {
24 Self { tokens, seq_len }
25 }
26
27 fn from_reader<R: Read>(reader: R, tokenizer: &Tokenizer) -> Result<Vec<u32>> {
28 let reader = BufReader::new(reader);
29 let mut all_tokens = Vec::new();
30
31 for line in reader.lines() {
32 let line = line?;
33 if line.is_empty() {
34 continue;
35 }
36 let record: JsonlRecord = serde_json::from_str(&line)?;
37 if !record.text.is_empty() {
38 let tokens = tokenizer.encode(&record.text, false)?;
39 all_tokens.extend(tokens);
40 all_tokens.push(tokenizer.eos_token_id());
41 }
42 }
43
44 Ok(all_tokens)
45 }
46
47 pub fn from_file<P: AsRef<Path>>(
50 path: P,
51 tokenizer: &Tokenizer,
52 seq_len: usize,
53 ) -> Result<Self> {
54 let reader = file_io::open_file(path)?;
55 let tokens = Self::from_reader(reader, tokenizer)?;
56 Ok(Self::new(tokens, seq_len))
57 }
58
59 pub fn from_stdin(tokenizer: &Tokenizer, seq_len: usize) -> Result<Self> {
61 let stdin = io::stdin().lock();
62 let tokens = Self::from_reader(stdin, tokenizer)?;
63 Ok(Self::new(tokens, seq_len))
64 }
65
66 pub fn from_files<P: AsRef<Path>>(
69 paths: &[P],
70 tokenizer: &Tokenizer,
71 seq_len: usize,
72 ) -> Result<Self> {
73 let mut all_tokens = Vec::new();
74
75 for path in paths {
76 let reader = file_io::open_file(path)?;
77
78 for line in reader.lines() {
79 let line = line?;
80 if line.is_empty() {
81 continue;
82 }
83 let record: JsonlRecord = serde_json::from_str(&line)?;
84 if !record.text.is_empty() {
85 let tokens = tokenizer.encode(&record.text, false)?;
86 all_tokens.extend(tokens);
87 all_tokens.push(tokenizer.eos_token_id());
88 }
89 }
90 }
91
92 Ok(Self::new(all_tokens, seq_len))
93 }
94
95 pub fn len(&self) -> usize {
96 if self.tokens.len() <= self.seq_len {
97 0
98 } else {
99 self.tokens.len() - self.seq_len
100 }
101 }
102
103 pub fn is_empty(&self) -> bool {
104 self.len() == 0
105 }
106
107 pub fn get_batch(&self, indices: &[usize], device: &Device) -> Result<(Tensor, Tensor)> {
108 let batch_size = indices.len();
109 let mut input_data = Vec::with_capacity(batch_size * self.seq_len);
110 let mut target_data = Vec::with_capacity(batch_size * self.seq_len);
111
112 for &idx in indices {
113 let start = idx;
114 let end = start + self.seq_len;
115
116 for i in start..end {
117 input_data.push(self.tokens[i]);
118 target_data.push(self.tokens[i + 1]);
119 }
120 }
121
122 let input = Tensor::new(input_data, device)?
123 .reshape((batch_size, self.seq_len))?
124 .to_dtype(candle_core::DType::U32)?;
125 let target = Tensor::new(target_data, device)?
126 .reshape((batch_size, self.seq_len))?
127 .to_dtype(candle_core::DType::U32)?;
128
129 Ok((input, target))
130 }
131
132 pub fn tokens(&self) -> &[u32] {
133 &self.tokens
134 }
135}
136
137pub struct DataLoader {
138 dataset: Dataset,
139 batch_size: usize,
140 shuffle: bool,
141 indices: Vec<usize>,
142 current_pos: usize,
143 rank: usize,
144 world_size: usize,
145 batches_yielded: usize,
146 max_batches: usize,
147 shuffle_seed: u64,
149}
150
151impl DataLoader {
152 pub fn new(dataset: Dataset, batch_size: usize, shuffle: bool) -> Self {
153 Self::new_distributed(dataset, batch_size, shuffle, 0, 1)
154 }
155
156 pub fn new_distributed(
159 dataset: Dataset,
160 batch_size: usize,
161 shuffle: bool,
162 rank: usize,
163 world_size: usize,
164 ) -> Self {
165 let len = dataset.len();
166 let indices: Vec<usize> = (0..len).collect();
167 let total_batches = len / batch_size;
169 let max_batches = total_batches / world_size;
170 Self {
171 dataset,
172 batch_size,
173 shuffle,
174 indices,
175 current_pos: 0,
176 rank,
177 world_size,
178 batches_yielded: 0,
179 max_batches,
180 shuffle_seed: 42,
181 }
182 }
183
184 pub fn reset(&mut self) {
185 self.current_pos = 0;
186 self.batches_yielded = 0;
187 if self.shuffle {
188 use rand::SeedableRng;
190 let mut rng = rand::rngs::StdRng::seed_from_u64(self.shuffle_seed);
191 self.indices.shuffle(&mut rng);
192 }
193 }
194
195 pub fn reset_with_seed(&mut self, seed: u64) {
197 self.shuffle_seed = seed;
198 self.reset();
199 }
200
201 pub fn position(&self) -> usize {
203 self.current_pos
204 }
205
206 pub fn set_position(&mut self, pos: usize) {
208 self.current_pos = pos;
209 self.batches_yielded = pos / self.batch_size / self.world_size;
210 }
211
212 pub fn num_batches(&self) -> usize {
213 (self.dataset.len() / self.batch_size) / self.world_size
215 }
216
217 pub fn next_batch(&mut self, device: &Device) -> Result<Option<(Tensor, Tensor)>> {
218 if self.batches_yielded >= self.max_batches {
220 return Ok(None);
221 }
222
223 loop {
225 if self.current_pos + self.batch_size > self.indices.len() {
226 return Ok(None);
227 }
228
229 let batch_num = self.current_pos / self.batch_size;
230 let batch_indices: Vec<usize> =
231 self.indices[self.current_pos..self.current_pos + self.batch_size].to_vec();
232 self.current_pos += self.batch_size;
233
234 if batch_num % self.world_size == self.rank {
236 self.batches_yielded += 1;
237 let (input, target) = self.dataset.get_batch(&batch_indices, device)?;
238 return Ok(Some((input, target)));
239 }
240 }
242 }
243
244 pub fn iter<'a>(&'a mut self, device: &'a Device) -> DataLoaderIterator<'a> {
245 self.reset();
246 DataLoaderIterator {
247 loader: self,
248 device,
249 }
250 }
251}
252
253pub struct DataLoaderIterator<'a> {
254 loader: &'a mut DataLoader,
255 device: &'a Device,
256}
257
258impl<'a> Iterator for DataLoaderIterator<'a> {
259 type Item = Result<(Tensor, Tensor)>;
260
261 fn next(&mut self) -> Option<Self::Item> {
262 match self.loader.next_batch(self.device) {
263 Ok(Some(batch)) => Some(Ok(batch)),
264 Ok(None) => None,
265 Err(e) => Some(Err(e)),
266 }
267 }
268}
269
270pub fn generate_random_batch(
271 batch_size: usize,
272 seq_len: usize,
273 vocab_size: usize,
274 device: &Device,
275) -> Result<(Tensor, Tensor)> {
276 let mut rng = rand::rng();
277 let input_data: Vec<u32> = (0..batch_size * seq_len)
278 .map(|_| rng.random_range(0..vocab_size as u32))
279 .collect();
280 let target_data: Vec<u32> = (0..batch_size * seq_len)
281 .map(|_| rng.random_range(0..vocab_size as u32))
282 .collect();
283
284 let input = Tensor::new(input_data, device)?
285 .reshape((batch_size, seq_len))?
286 .to_dtype(candle_core::DType::U32)?;
287 let target = Tensor::new(target_data, device)?
288 .reshape((batch_size, seq_len))?
289 .to_dtype(candle_core::DType::U32)?;
290
291 Ok((input, target))
292}