Skip to main content

brainwires_datasets/
dataset.rs

1use crate::error::{DatasetError, DatasetResult};
2use crate::types::{PreferencePair, TrainingExample};
3
/// Core dataset abstraction: a sized, indexable, iterable collection of items
/// that can be shuffled in place and split into two owned parts.
///
/// Implementors must be `Send + Sync`. Items must be `Clone` because
/// [`Dataset::split`] hands back owned copies rather than borrows.
pub trait Dataset: Send + Sync {
    /// The item type stored in this dataset.
    type Item: Clone;

    /// Return the number of items in the dataset.
    fn len(&self) -> usize;

    /// Return `true` if the dataset holds no items.
    ///
    /// The provided implementation is `self.len() == 0`.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Get a reference to the item at `index`.
    ///
    /// The implementations in this file return `None` when `index` is out of
    /// range instead of panicking.
    fn get(&self, index: usize) -> Option<&Self::Item>;

    /// Return a boxed iterator that borrows every item from the dataset.
    fn iter(&self) -> Box<dyn Iterator<Item = &Self::Item> + '_>;

    /// Shuffle the dataset in place using the given seed.
    ///
    /// The implementations in this file are deterministic: the same seed and
    /// the same starting order always produce the same resulting order.
    fn shuffle(&mut self, seed: u64);

    /// Split the dataset by `ratio` into two owned vectors `(first, second)`.
    ///
    /// The implementations in this file clamp `ratio` to `[0.0, 1.0]` and put
    /// roughly `len * ratio` leading items into the first vector and the
    /// remainder into the second, preserving the current order.
    fn split(&self, ratio: f32) -> (Vec<Self::Item>, Vec<Self::Item>);
}
24
25/// Instruction-tuning dataset (multi-turn conversations).
26#[derive(Debug, Clone)]
27pub struct InstructDataset {
28    examples: Vec<TrainingExample>,
29}
30
31impl InstructDataset {
32    /// Create a new instruct dataset from a vector of examples.
33    pub fn new(examples: Vec<TrainingExample>) -> Self {
34        Self { examples }
35    }
36
37    /// Create a new instruct dataset from an iterator of examples.
38    pub fn from_examples(examples: impl IntoIterator<Item = TrainingExample>) -> Self {
39        Self {
40            examples: examples.into_iter().collect(),
41        }
42    }
43
44    /// Append a single example to the dataset.
45    pub fn push(&mut self, example: TrainingExample) {
46        self.examples.push(example);
47    }
48
49    /// Extend the dataset with an iterator of examples.
50    pub fn extend(&mut self, examples: impl IntoIterator<Item = TrainingExample>) {
51        self.examples.extend(examples);
52    }
53
54    /// Remove and return the example at the given index.
55    pub fn remove(&mut self, index: usize) -> DatasetResult<TrainingExample> {
56        if index >= self.examples.len() {
57            return Err(DatasetError::IndexOutOfBounds {
58                index,
59                len: self.examples.len(),
60            });
61        }
62        Ok(self.examples.remove(index))
63    }
64
65    /// Total estimated tokens across all examples.
66    pub fn total_estimated_tokens(&self) -> usize {
67        self.examples.iter().map(|e| e.estimated_tokens()).sum()
68    }
69
70    /// Filter examples by a predicate.
71    pub fn filter<F>(&self, predicate: F) -> Self
72    where
73        F: Fn(&TrainingExample) -> bool,
74    {
75        Self {
76            examples: self
77                .examples
78                .iter()
79                .filter(|e| predicate(e))
80                .cloned()
81                .collect(),
82        }
83    }
84
85    /// Get all examples as a slice.
86    pub fn as_slice(&self) -> &[TrainingExample] {
87        &self.examples
88    }
89
90    /// Consume self and return the underlying Vec.
91    pub fn into_inner(self) -> Vec<TrainingExample> {
92        self.examples
93    }
94}
95
96impl Dataset for InstructDataset {
97    type Item = TrainingExample;
98
99    fn len(&self) -> usize {
100        self.examples.len()
101    }
102
103    fn get(&self, index: usize) -> Option<&TrainingExample> {
104        self.examples.get(index)
105    }
106
107    fn iter(&self) -> Box<dyn Iterator<Item = &TrainingExample> + '_> {
108        Box::new(self.examples.iter())
109    }
110
111    fn shuffle(&mut self, seed: u64) {
112        // Simple Fisher-Yates shuffle with deterministic seed
113        let len = self.examples.len();
114        if len <= 1 {
115            return;
116        }
117        let mut state = seed;
118        for i in (1..len).rev() {
119            // Simple LCG for deterministic shuffle
120            state = state
121                .wrapping_mul(6364136223846793005)
122                .wrapping_add(1442695040888963407);
123            let j = (state >> 33) as usize % (i + 1);
124            self.examples.swap(i, j);
125        }
126    }
127
128    fn split(&self, ratio: f32) -> (Vec<TrainingExample>, Vec<TrainingExample>) {
129        let ratio = ratio.clamp(0.0, 1.0);
130        let split_idx = (self.examples.len() as f32 * ratio) as usize;
131        let train = self.examples[..split_idx].to_vec();
132        let eval = self.examples[split_idx..].to_vec();
133        (train, eval)
134    }
135}
136
137/// Preference dataset for DPO/ORPO training.
138#[derive(Debug, Clone)]
139pub struct PreferenceDataset {
140    pairs: Vec<PreferencePair>,
141}
142
143impl PreferenceDataset {
144    /// Create a new preference dataset from a vector of pairs.
145    pub fn new(pairs: Vec<PreferencePair>) -> Self {
146        Self { pairs }
147    }
148
149    /// Append a single preference pair to the dataset.
150    pub fn push(&mut self, pair: PreferencePair) {
151        self.pairs.push(pair);
152    }
153
154    /// Create a new preference dataset from an iterator of pairs.
155    pub fn from_pairs(pairs: impl IntoIterator<Item = PreferencePair>) -> Self {
156        Self {
157            pairs: pairs.into_iter().collect(),
158        }
159    }
160
161    /// Extend the dataset with an iterator of pairs.
162    pub fn extend(&mut self, pairs: impl IntoIterator<Item = PreferencePair>) {
163        self.pairs.extend(pairs);
164    }
165
166    /// Remove and return the pair at the given index.
167    pub fn remove(&mut self, index: usize) -> DatasetResult<PreferencePair> {
168        if index >= self.pairs.len() {
169            return Err(DatasetError::IndexOutOfBounds {
170                index,
171                len: self.pairs.len(),
172            });
173        }
174        Ok(self.pairs.remove(index))
175    }
176
177    /// Filter pairs by a predicate.
178    pub fn filter<F>(&self, predicate: F) -> Self
179    where
180        F: Fn(&PreferencePair) -> bool,
181    {
182        Self {
183            pairs: self
184                .pairs
185                .iter()
186                .filter(|p| predicate(p))
187                .cloned()
188                .collect(),
189        }
190    }
191
192    /// Total estimated tokens across all preference pairs.
193    pub fn total_estimated_tokens(&self) -> usize {
194        self.pairs.iter().map(|p| p.estimated_tokens()).sum()
195    }
196
197    /// Get all pairs as a slice.
198    pub fn as_slice(&self) -> &[PreferencePair] {
199        &self.pairs
200    }
201
202    /// Consume self and return the underlying vector of pairs.
203    pub fn into_inner(self) -> Vec<PreferencePair> {
204        self.pairs
205    }
206}
207
208impl Dataset for PreferenceDataset {
209    type Item = PreferencePair;
210
211    fn len(&self) -> usize {
212        self.pairs.len()
213    }
214
215    fn get(&self, index: usize) -> Option<&PreferencePair> {
216        self.pairs.get(index)
217    }
218
219    fn iter(&self) -> Box<dyn Iterator<Item = &PreferencePair> + '_> {
220        Box::new(self.pairs.iter())
221    }
222
223    fn shuffle(&mut self, seed: u64) {
224        let len = self.pairs.len();
225        if len <= 1 {
226            return;
227        }
228        let mut state = seed;
229        for i in (1..len).rev() {
230            state = state
231                .wrapping_mul(6364136223846793005)
232                .wrapping_add(1442695040888963407);
233            let j = (state >> 33) as usize % (i + 1);
234            self.pairs.swap(i, j);
235        }
236    }
237
238    fn split(&self, ratio: f32) -> (Vec<PreferencePair>, Vec<PreferencePair>) {
239        let ratio = ratio.clamp(0.0, 1.0);
240        let split_idx = (self.pairs.len() as f32 * ratio) as usize;
241        let train = self.pairs[..split_idx].to_vec();
242        let eval = self.pairs[split_idx..].to_vec();
243        (train, eval)
244    }
245}
246
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::TrainingMessage;

    /// Build `n` two-message examples with ids `ex-0` through `ex-{n-1}`.
    fn sample_examples(n: usize) -> Vec<TrainingExample> {
        (0..n)
            .map(|i| {
                TrainingExample::with_id(
                    format!("ex-{i}"),
                    vec![
                        TrainingMessage::user(format!("Question {i}")),
                        TrainingMessage::assistant(format!("Answer {i}")),
                    ],
                )
            })
            .collect()
    }

    /// Build a preference pair with one user prompt message, one chosen
    /// assistant reply and one rejected assistant reply.
    fn sample_pair(
        prompt: &'static str,
        chosen: &'static str,
        rejected: &'static str,
    ) -> PreferencePair {
        PreferencePair::new(
            vec![TrainingMessage::user(prompt)],
            vec![TrainingMessage::assistant(chosen)],
            vec![TrainingMessage::assistant(rejected)],
        )
    }

    #[test]
    fn test_instruct_dataset_basics() {
        let ds = InstructDataset::new(sample_examples(10));
        assert_eq!(ds.len(), 10);
        assert!(!ds.is_empty());
        assert!(ds.get(0).is_some());
        assert!(ds.get(10).is_none());
    }

    #[test]
    fn test_instruct_dataset_split() {
        let ds = InstructDataset::new(sample_examples(10));
        let (train, eval) = ds.split(0.8);
        assert_eq!(train.len(), 8);
        assert_eq!(eval.len(), 2);
    }

    #[test]
    fn test_instruct_dataset_split_clamps_ratio() {
        let ds = InstructDataset::new(sample_examples(10));

        // Ratios above 1.0 behave like 1.0: everything goes to the first part.
        let (train, eval) = ds.split(1.5);
        assert_eq!(train.len(), 10);
        assert!(eval.is_empty());

        // Ratios below 0.0 behave like 0.0: everything goes to the second part.
        let (train, eval) = ds.split(-0.5);
        assert!(train.is_empty());
        assert_eq!(eval.len(), 10);
    }

    #[test]
    fn test_instruct_dataset_shuffle_deterministic() {
        let mut ds1 = InstructDataset::new(sample_examples(20));
        let mut ds2 = InstructDataset::new(sample_examples(20));
        ds1.shuffle(42);
        ds2.shuffle(42);
        for (a, b) in ds1.iter().zip(ds2.iter()) {
            assert_eq!(a.id, b.id);
        }
    }

    #[test]
    fn test_instruct_dataset_shuffle_preserves_examples() {
        let mut ds = InstructDataset::new(sample_examples(20));
        ds.shuffle(7);
        assert_eq!(ds.len(), 20);
        // Every original id must still be present exactly once.
        for i in 0..20 {
            let id = format!("ex-{i}");
            assert_eq!(ds.iter().filter(|e| e.id == id).count(), 1);
        }
    }

    #[test]
    fn test_instruct_dataset_filter() {
        let ds = InstructDataset::new(sample_examples(10));
        let filtered = ds.filter(|e| e.id.ends_with('5') || e.id.ends_with('7'));
        assert_eq!(filtered.len(), 2);
    }

    #[test]
    fn test_preference_dataset() {
        let ds = PreferenceDataset::new(vec![sample_pair("Q", "Good", "Bad")]);
        assert_eq!(ds.len(), 1);
        assert!(ds.get(0).is_some());
    }

    #[test]
    fn test_preference_dataset_from_pairs() {
        let pairs = vec![
            sample_pair("Q1", "Good1", "Bad1"),
            sample_pair("Q2", "Good2", "Bad2"),
        ];
        let ds = PreferenceDataset::from_pairs(pairs);
        assert_eq!(ds.len(), 2);
    }

    #[test]
    fn test_preference_dataset_extend() {
        let mut ds = PreferenceDataset::new(vec![]);
        ds.extend(vec![sample_pair("Q", "Good", "Bad")]);
        assert_eq!(ds.len(), 1);
    }

    #[test]
    fn test_preference_dataset_remove() {
        let mut ds = PreferenceDataset::new(vec![sample_pair("Q", "Good", "Bad")]);
        let removed = ds.remove(0).unwrap();
        assert_eq!(removed.prompt.len(), 1);
        assert!(ds.is_empty());
        // Removing from an empty dataset reports an error instead of panicking.
        assert!(ds.remove(0).is_err());
    }

    #[test]
    fn test_preference_dataset_split() {
        let ds = PreferenceDataset::new(vec![
            sample_pair("Q1", "Good1", "Bad1"),
            sample_pair("Q2", "Good2", "Bad2"),
        ]);
        let (train, eval) = ds.split(0.5);
        assert_eq!(train.len(), 1);
        assert_eq!(eval.len(), 1);
    }

    #[test]
    fn test_preference_dataset_filter() {
        let ds = PreferenceDataset::new(vec![
            sample_pair("short", "a", "b"),
            sample_pair("this is a longer prompt message", "good", "bad"),
        ]);
        let filtered = ds.filter(|p| p.estimated_tokens() > 5);
        assert_eq!(filtered.len(), 1);
    }
}