brainwires_datasets/
dataset.rs1use crate::error::{DatasetError, DatasetResult};
2use crate::types::{PreferencePair, TrainingExample};
3
/// Common read / shuffle / split interface over an ordered, in-memory collection of items.
///
/// Implementors must be `Send + Sync`. `Item: Clone` is required because
/// [`Dataset::split`] hands back owned copies of the items.
pub trait Dataset: Send + Sync {
    /// The element type stored in the dataset.
    type Item: Clone;

    /// Number of items currently held.
    fn len(&self) -> usize;

    /// `true` when the dataset holds no items (i.e. `len() == 0`).
    fn is_empty(&self) -> bool {
        matches!(self.len(), 0)
    }

    /// Borrows the item at `index`, or `None` when `index` is out of range.
    fn get(&self, index: usize) -> Option<&Self::Item>;

    /// Iterates over the items by reference, in their current order.
    fn iter(&self) -> Box<dyn Iterator<Item = &Self::Item> + '_>;

    /// Reorders the items in place using `seed` as the source of randomness.
    fn shuffle(&mut self, seed: u64);

    /// Splits the items into two owned vectors according to `ratio`.
    ///
    /// The implementations in this module clamp `ratio` to `[0.0, 1.0]` and put
    /// the leading `ratio` fraction of items into the first vector.
    fn split(&self, ratio: f32) -> (Vec<Self::Item>, Vec<Self::Item>);
}
24
25#[derive(Debug, Clone)]
27pub struct InstructDataset {
28 examples: Vec<TrainingExample>,
29}
30
31impl InstructDataset {
32 pub fn new(examples: Vec<TrainingExample>) -> Self {
34 Self { examples }
35 }
36
37 pub fn from_examples(examples: impl IntoIterator<Item = TrainingExample>) -> Self {
39 Self {
40 examples: examples.into_iter().collect(),
41 }
42 }
43
44 pub fn push(&mut self, example: TrainingExample) {
46 self.examples.push(example);
47 }
48
49 pub fn extend(&mut self, examples: impl IntoIterator<Item = TrainingExample>) {
51 self.examples.extend(examples);
52 }
53
54 pub fn remove(&mut self, index: usize) -> DatasetResult<TrainingExample> {
56 if index >= self.examples.len() {
57 return Err(DatasetError::IndexOutOfBounds {
58 index,
59 len: self.examples.len(),
60 });
61 }
62 Ok(self.examples.remove(index))
63 }
64
65 pub fn total_estimated_tokens(&self) -> usize {
67 self.examples.iter().map(|e| e.estimated_tokens()).sum()
68 }
69
70 pub fn filter<F>(&self, predicate: F) -> Self
72 where
73 F: Fn(&TrainingExample) -> bool,
74 {
75 Self {
76 examples: self
77 .examples
78 .iter()
79 .filter(|e| predicate(e))
80 .cloned()
81 .collect(),
82 }
83 }
84
85 pub fn as_slice(&self) -> &[TrainingExample] {
87 &self.examples
88 }
89
90 pub fn into_inner(self) -> Vec<TrainingExample> {
92 self.examples
93 }
94}
95
96impl Dataset for InstructDataset {
97 type Item = TrainingExample;
98
99 fn len(&self) -> usize {
100 self.examples.len()
101 }
102
103 fn get(&self, index: usize) -> Option<&TrainingExample> {
104 self.examples.get(index)
105 }
106
107 fn iter(&self) -> Box<dyn Iterator<Item = &TrainingExample> + '_> {
108 Box::new(self.examples.iter())
109 }
110
111 fn shuffle(&mut self, seed: u64) {
112 let len = self.examples.len();
114 if len <= 1 {
115 return;
116 }
117 let mut state = seed;
118 for i in (1..len).rev() {
119 state = state
121 .wrapping_mul(6364136223846793005)
122 .wrapping_add(1442695040888963407);
123 let j = (state >> 33) as usize % (i + 1);
124 self.examples.swap(i, j);
125 }
126 }
127
128 fn split(&self, ratio: f32) -> (Vec<TrainingExample>, Vec<TrainingExample>) {
129 let ratio = ratio.clamp(0.0, 1.0);
130 let split_idx = (self.examples.len() as f32 * ratio) as usize;
131 let train = self.examples[..split_idx].to_vec();
132 let eval = self.examples[split_idx..].to_vec();
133 (train, eval)
134 }
135}
136
137#[derive(Debug, Clone)]
139pub struct PreferenceDataset {
140 pairs: Vec<PreferencePair>,
141}
142
143impl PreferenceDataset {
144 pub fn new(pairs: Vec<PreferencePair>) -> Self {
146 Self { pairs }
147 }
148
149 pub fn push(&mut self, pair: PreferencePair) {
151 self.pairs.push(pair);
152 }
153
154 pub fn from_pairs(pairs: impl IntoIterator<Item = PreferencePair>) -> Self {
156 Self {
157 pairs: pairs.into_iter().collect(),
158 }
159 }
160
161 pub fn extend(&mut self, pairs: impl IntoIterator<Item = PreferencePair>) {
163 self.pairs.extend(pairs);
164 }
165
166 pub fn remove(&mut self, index: usize) -> DatasetResult<PreferencePair> {
168 if index >= self.pairs.len() {
169 return Err(DatasetError::IndexOutOfBounds {
170 index,
171 len: self.pairs.len(),
172 });
173 }
174 Ok(self.pairs.remove(index))
175 }
176
177 pub fn filter<F>(&self, predicate: F) -> Self
179 where
180 F: Fn(&PreferencePair) -> bool,
181 {
182 Self {
183 pairs: self
184 .pairs
185 .iter()
186 .filter(|p| predicate(p))
187 .cloned()
188 .collect(),
189 }
190 }
191
192 pub fn total_estimated_tokens(&self) -> usize {
194 self.pairs.iter().map(|p| p.estimated_tokens()).sum()
195 }
196
197 pub fn as_slice(&self) -> &[PreferencePair] {
199 &self.pairs
200 }
201
202 pub fn into_inner(self) -> Vec<PreferencePair> {
204 self.pairs
205 }
206}
207
208impl Dataset for PreferenceDataset {
209 type Item = PreferencePair;
210
211 fn len(&self) -> usize {
212 self.pairs.len()
213 }
214
215 fn get(&self, index: usize) -> Option<&PreferencePair> {
216 self.pairs.get(index)
217 }
218
219 fn iter(&self) -> Box<dyn Iterator<Item = &PreferencePair> + '_> {
220 Box::new(self.pairs.iter())
221 }
222
223 fn shuffle(&mut self, seed: u64) {
224 let len = self.pairs.len();
225 if len <= 1 {
226 return;
227 }
228 let mut state = seed;
229 for i in (1..len).rev() {
230 state = state
231 .wrapping_mul(6364136223846793005)
232 .wrapping_add(1442695040888963407);
233 let j = (state >> 33) as usize % (i + 1);
234 self.pairs.swap(i, j);
235 }
236 }
237
238 fn split(&self, ratio: f32) -> (Vec<PreferencePair>, Vec<PreferencePair>) {
239 let ratio = ratio.clamp(0.0, 1.0);
240 let split_idx = (self.pairs.len() as f32 * ratio) as usize;
241 let train = self.pairs[..split_idx].to_vec();
242 let eval = self.pairs[split_idx..].to_vec();
243 (train, eval)
244 }
245}
246
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::TrainingMessage;

    /// Builds `n` two-turn examples with ids `ex-0`, `ex-1`, ...
    fn sample_examples(n: usize) -> Vec<TrainingExample> {
        (0..n)
            .map(|i| {
                TrainingExample::with_id(
                    format!("ex-{i}"),
                    vec![
                        TrainingMessage::user(format!("Question {i}")),
                        TrainingMessage::assistant(format!("Answer {i}")),
                    ],
                )
            })
            .collect()
    }

    /// Builds a single-turn preference pair from literal prompt / chosen / rejected texts.
    fn sample_pair(
        prompt: &'static str,
        chosen: &'static str,
        rejected: &'static str,
    ) -> PreferencePair {
        PreferencePair::new(
            vec![TrainingMessage::user(prompt)],
            vec![TrainingMessage::assistant(chosen)],
            vec![TrainingMessage::assistant(rejected)],
        )
    }

    #[test]
    fn test_instruct_dataset_basics() {
        let ds = InstructDataset::new(sample_examples(10));
        assert_eq!(ds.len(), 10);
        assert!(!ds.is_empty());
        assert!(ds.get(0).is_some());
        assert!(ds.get(10).is_none());
    }

    #[test]
    fn test_instruct_dataset_split() {
        let ds = InstructDataset::new(sample_examples(10));
        let (train, eval) = ds.split(0.8);
        assert_eq!(train.len(), 8);
        assert_eq!(eval.len(), 2);
    }

    #[test]
    fn test_instruct_dataset_split_clamps_ratio() {
        let ds = InstructDataset::new(sample_examples(4));
        let (train, eval) = ds.split(1.5);
        assert_eq!((train.len(), eval.len()), (4, 0));
        let (train, eval) = ds.split(-0.5);
        assert_eq!((train.len(), eval.len()), (0, 4));
    }

    #[test]
    fn test_instruct_dataset_shuffle_deterministic() {
        let mut ds1 = InstructDataset::new(sample_examples(20));
        let mut ds2 = InstructDataset::new(sample_examples(20));
        ds1.shuffle(42);
        ds2.shuffle(42);
        for (a, b) in ds1.iter().zip(ds2.iter()) {
            assert_eq!(a.id, b.id);
        }
    }

    #[test]
    fn test_instruct_dataset_shuffle_is_permutation() {
        let mut ds = InstructDataset::new(sample_examples(20));
        ds.shuffle(7);
        let mut shuffled: Vec<String> = ds.iter().map(|e| e.id.to_string()).collect();
        let mut original: Vec<String> = sample_examples(20)
            .iter()
            .map(|e| e.id.to_string())
            .collect();
        shuffled.sort();
        original.sort();
        assert_eq!(shuffled, original);
    }

    #[test]
    fn test_instruct_dataset_empty() {
        let mut ds = InstructDataset::new(Vec::new());
        assert!(ds.is_empty());
        ds.shuffle(1);
        let (train, eval) = ds.split(0.5);
        assert!(train.is_empty());
        assert!(eval.is_empty());
        assert_eq!(ds.total_estimated_tokens(), 0);
    }

    #[test]
    fn test_instruct_dataset_remove() {
        let mut ds = InstructDataset::new(sample_examples(3));
        let removed = ds.remove(1).unwrap();
        assert_eq!(removed.id.to_string(), "ex-1");
        assert_eq!(ds.len(), 2);
        assert!(ds.remove(2).is_err());
    }

    #[test]
    fn test_instruct_dataset_filter() {
        let ds = InstructDataset::new(sample_examples(10));
        let filtered = ds.filter(|e| e.id.ends_with('5') || e.id.ends_with('7'));
        assert_eq!(filtered.len(), 2);
    }

    #[test]
    fn test_preference_dataset() {
        let ds = PreferenceDataset::new(vec![sample_pair("Q", "Good", "Bad")]);
        assert_eq!(ds.len(), 1);
        assert!(ds.get(0).is_some());
    }

    #[test]
    fn test_preference_dataset_from_pairs() {
        let pairs = vec![
            sample_pair("Q1", "Good1", "Bad1"),
            sample_pair("Q2", "Good2", "Bad2"),
        ];
        let ds = PreferenceDataset::from_pairs(pairs);
        assert_eq!(ds.len(), 2);
    }

    #[test]
    fn test_preference_dataset_extend() {
        let mut ds = PreferenceDataset::new(vec![]);
        ds.extend(vec![sample_pair("Q", "Good", "Bad")]);
        assert_eq!(ds.len(), 1);
    }

    #[test]
    fn test_preference_dataset_remove() {
        let mut ds = PreferenceDataset::new(vec![sample_pair("Q", "Good", "Bad")]);
        let removed = ds.remove(0).unwrap();
        assert_eq!(removed.prompt.len(), 1);
        assert!(ds.is_empty());
        assert!(ds.remove(0).is_err());
    }

    #[test]
    fn test_preference_dataset_filter() {
        let ds = PreferenceDataset::new(vec![
            sample_pair("short", "a", "b"),
            sample_pair("this is a longer prompt message", "good", "bad"),
        ]);
        let filtered = ds.filter(|p| p.estimated_tokens() > 5);
        assert_eq!(filtered.len(), 1);
    }

    #[test]
    fn test_preference_dataset_shuffle_and_split() {
        let mut ds =
            PreferenceDataset::from_pairs((0..5).map(|_| sample_pair("Q", "Good", "Bad")));
        ds.shuffle(3);
        assert_eq!(ds.len(), 5);
        let (train, eval) = ds.split(0.6);
        assert_eq!((train.len(), eval.len()), (3, 2));
        let (train, eval) = ds.split(2.0);
        assert_eq!((train.len(), eval.len()), (5, 0));
    }
}