Skip to main content

hermes_llm/
data.rs

1use anyhow::Result;
2use candle_core::{Device, Tensor};
3use rand::Rng;
4use rand::seq::SliceRandom;
5use serde::Deserialize;
6use std::io::{self, BufRead, BufReader, Read};
7use std::path::Path;
8
9use crate::io as file_io;
10use crate::tokenizer::Tokenizer;
11
12#[derive(Deserialize)]
13struct JsonlRecord {
14    text: String,
15}
16
17pub struct Dataset {
18    tokens: Vec<u32>,
19    seq_len: usize,
20}
21
22impl Dataset {
23    pub fn new(tokens: Vec<u32>, seq_len: usize) -> Self {
24        Self { tokens, seq_len }
25    }
26
27    fn from_reader<R: Read>(reader: R, tokenizer: &Tokenizer) -> Result<Vec<u32>> {
28        let reader = BufReader::new(reader);
29        let mut all_tokens = Vec::new();
30
31        for line in reader.lines() {
32            let line = line?;
33            if line.is_empty() {
34                continue;
35            }
36            let record: JsonlRecord = serde_json::from_str(&line)?;
37            if !record.text.is_empty() {
38                let tokens = tokenizer.encode(&record.text, false)?;
39                all_tokens.extend(tokens);
40                all_tokens.push(tokenizer.eos_token_id());
41            }
42        }
43
44        Ok(all_tokens)
45    }
46
47    /// Load dataset from a JSONL file where each line has a "text" field.
48    /// Supports .gz and .zst/.zstd compressed files.
49    pub fn from_file<P: AsRef<Path>>(
50        path: P,
51        tokenizer: &Tokenizer,
52        seq_len: usize,
53    ) -> Result<Self> {
54        let reader = file_io::open_file(path)?;
55        let tokens = Self::from_reader(reader, tokenizer)?;
56        Ok(Self::new(tokens, seq_len))
57    }
58
59    /// Load dataset from stdin (JSONL format).
60    pub fn from_stdin(tokenizer: &Tokenizer, seq_len: usize) -> Result<Self> {
61        let stdin = io::stdin().lock();
62        let tokens = Self::from_reader(stdin, tokenizer)?;
63        Ok(Self::new(tokens, seq_len))
64    }
65
66    /// Load dataset from multiple JSONL files.
67    /// Supports .gz and .zst/.zstd compressed files.
68    pub fn from_files<P: AsRef<Path>>(
69        paths: &[P],
70        tokenizer: &Tokenizer,
71        seq_len: usize,
72    ) -> Result<Self> {
73        let mut all_tokens = Vec::new();
74
75        for path in paths {
76            let reader = file_io::open_file(path)?;
77
78            for line in reader.lines() {
79                let line = line?;
80                if line.is_empty() {
81                    continue;
82                }
83                let record: JsonlRecord = serde_json::from_str(&line)?;
84                if !record.text.is_empty() {
85                    let tokens = tokenizer.encode(&record.text, false)?;
86                    all_tokens.extend(tokens);
87                    all_tokens.push(tokenizer.eos_token_id());
88                }
89            }
90        }
91
92        Ok(Self::new(all_tokens, seq_len))
93    }
94
95    pub fn len(&self) -> usize {
96        if self.tokens.len() <= self.seq_len {
97            0
98        } else {
99            self.tokens.len() - self.seq_len
100        }
101    }
102
103    pub fn is_empty(&self) -> bool {
104        self.len() == 0
105    }
106
107    pub fn get_batch(&self, indices: &[usize], device: &Device) -> Result<(Tensor, Tensor)> {
108        let batch_size = indices.len();
109        let mut input_data = Vec::with_capacity(batch_size * self.seq_len);
110        let mut target_data = Vec::with_capacity(batch_size * self.seq_len);
111
112        for &idx in indices {
113            let start = idx;
114            let end = start + self.seq_len;
115
116            for i in start..end {
117                input_data.push(self.tokens[i]);
118                target_data.push(self.tokens[i + 1]);
119            }
120        }
121
122        let input = Tensor::new(input_data, device)?
123            .reshape((batch_size, self.seq_len))?
124            .to_dtype(candle_core::DType::U32)?;
125        let target = Tensor::new(target_data, device)?
126            .reshape((batch_size, self.seq_len))?
127            .to_dtype(candle_core::DType::U32)?;
128
129        Ok((input, target))
130    }
131
132    pub fn tokens(&self) -> &[u32] {
133        &self.tokens
134    }
135}
136
137pub struct DataLoader {
138    dataset: Dataset,
139    batch_size: usize,
140    shuffle: bool,
141    indices: Vec<usize>,
142    current_pos: usize,
143    rank: usize,
144    world_size: usize,
145    batches_yielded: usize,
146    max_batches: usize,
147    /// Seed used for shuffling (for reproducibility on resume)
148    shuffle_seed: u64,
149}
150
151impl DataLoader {
152    pub fn new(dataset: Dataset, batch_size: usize, shuffle: bool) -> Self {
153        Self::new_distributed(dataset, batch_size, shuffle, 0, 1)
154    }
155
156    /// Create a distributed data loader that shards data across ranks
157    /// Each rank processes 1/world_size of the data
158    pub fn new_distributed(
159        dataset: Dataset,
160        batch_size: usize,
161        shuffle: bool,
162        rank: usize,
163        world_size: usize,
164    ) -> Self {
165        let len = dataset.len();
166        let indices: Vec<usize> = (0..len).collect();
167        // Ensure all ranks process exactly the same number of batches
168        let total_batches = len / batch_size;
169        let max_batches = total_batches / world_size;
170        Self {
171            dataset,
172            batch_size,
173            shuffle,
174            indices,
175            current_pos: 0,
176            rank,
177            world_size,
178            batches_yielded: 0,
179            max_batches,
180            shuffle_seed: 42,
181        }
182    }
183
184    pub fn reset(&mut self) {
185        self.current_pos = 0;
186        self.batches_yielded = 0;
187        if self.shuffle {
188            // Use deterministic seed for reproducible shuffle across all ranks
189            use rand::SeedableRng;
190            let mut rng = rand::rngs::StdRng::seed_from_u64(self.shuffle_seed);
191            self.indices.shuffle(&mut rng);
192        }
193    }
194
195    /// Reset with a specific seed (for reproducibility across epochs)
196    pub fn reset_with_seed(&mut self, seed: u64) {
197        self.shuffle_seed = seed;
198        self.reset();
199    }
200
201    /// Get current position for checkpointing
202    pub fn position(&self) -> usize {
203        self.current_pos
204    }
205
206    /// Set position (for resuming from checkpoint)
207    pub fn set_position(&mut self, pos: usize) {
208        self.current_pos = pos;
209        self.batches_yielded = pos / self.batch_size / self.world_size;
210    }
211
212    pub fn num_batches(&self) -> usize {
213        // Each rank processes 1/world_size of batches
214        (self.dataset.len() / self.batch_size) / self.world_size
215    }
216
217    pub fn next_batch(&mut self, device: &Device) -> Result<Option<(Tensor, Tensor)>> {
218        // Stop if we've yielded max_batches (ensures all ranks process same count)
219        if self.batches_yielded >= self.max_batches {
220            return Ok(None);
221        }
222
223        // In distributed mode, each rank processes every world_size-th batch
224        loop {
225            if self.current_pos + self.batch_size > self.indices.len() {
226                return Ok(None);
227            }
228
229            let batch_num = self.current_pos / self.batch_size;
230            let batch_indices: Vec<usize> =
231                self.indices[self.current_pos..self.current_pos + self.batch_size].to_vec();
232            self.current_pos += self.batch_size;
233
234            // Only process batches assigned to this rank
235            if batch_num % self.world_size == self.rank {
236                self.batches_yielded += 1;
237                let (input, target) = self.dataset.get_batch(&batch_indices, device)?;
238                return Ok(Some((input, target)));
239            }
240            // Skip batches assigned to other ranks
241        }
242    }
243
244    pub fn iter<'a>(&'a mut self, device: &'a Device) -> DataLoaderIterator<'a> {
245        self.reset();
246        DataLoaderIterator {
247            loader: self,
248            device,
249        }
250    }
251}
252
253pub struct DataLoaderIterator<'a> {
254    loader: &'a mut DataLoader,
255    device: &'a Device,
256}
257
258impl<'a> Iterator for DataLoaderIterator<'a> {
259    type Item = Result<(Tensor, Tensor)>;
260
261    fn next(&mut self) -> Option<Self::Item> {
262        match self.loader.next_batch(self.device) {
263            Ok(Some(batch)) => Some(Ok(batch)),
264            Ok(None) => None,
265            Err(e) => Some(Err(e)),
266        }
267    }
268}
269
270pub fn generate_random_batch(
271    batch_size: usize,
272    seq_len: usize,
273    vocab_size: usize,
274    device: &Device,
275) -> Result<(Tensor, Tensor)> {
276    let mut rng = rand::rng();
277    let input_data: Vec<u32> = (0..batch_size * seq_len)
278        .map(|_| rng.random_range(0..vocab_size as u32))
279        .collect();
280    let target_data: Vec<u32> = (0..batch_size * seq_len)
281        .map(|_| rng.random_range(0..vocab_size as u32))
282        .collect();
283
284    let input = Tensor::new(input_data, device)?
285        .reshape((batch_size, seq_len))?
286        .to_dtype(candle_core::DType::U32)?;
287    let target = Tensor::new(target_data, device)?
288        .reshape((batch_size, seq_len))?
289        .to_dtype(candle_core::DType::U32)?;
290
291    Ok((input, target))
292}