//! brainwires_datasets — sampling.rs
//!
//! Deterministic splitting, ordering, and sampling utilities for
//! instruct and preference datasets.
1use crate::dataset::{Dataset, InstructDataset, PreferenceDataset};
2use crate::types::{PreferencePair, TrainingExample};
3
/// Multiplier for the seeded 64-bit linear congruential step used by
/// `sample_n` / `preference_sample_n` (the LCG multiplier from the PCG family).
const PCG_MULTIPLIER: u64 = 6_364_136_223_846_793_005;
/// Increment for the same LCG step; paired with `PCG_MULTIPLIER` above.
/// NOTE(review): only the raw LCG state update is used here (no PCG output
/// permutation) — the names reflect where the constants come from.
const PCG_INCREMENT: u64 = 1_442_695_040_888_963_407;
8
/// Split configuration for train/eval datasets.
///
/// Consumed by [`train_eval_split`] and [`preference_train_eval_split`];
/// defaults to a shuffled 90/10 split with seed 42 (see `Default` impl).
#[derive(Debug, Clone)]
pub struct SplitConfig {
    /// Fraction of data for training (0.0 - 1.0).
    /// NOTE(review): not validated in this module — out-of-range values are
    /// passed straight to `Dataset::split`; confirm its behavior there.
    pub train_ratio: f32,
    /// Random seed for reproducible splits.
    pub seed: u64,
    /// Whether to shuffle before splitting.
    pub shuffle: bool,
}
19
20impl Default for SplitConfig {
21    fn default() -> Self {
22        Self {
23            train_ratio: 0.9,
24            seed: 42,
25            shuffle: true,
26        }
27    }
28}
29
/// Split result containing train and eval datasets.
///
/// Produced by [`train_eval_split`]; the two datasets together contain
/// every input example exactly once.
pub struct SplitResult {
    /// The training split.
    pub train: InstructDataset,
    /// The evaluation split.
    pub eval: InstructDataset,
}
37
38/// Split a dataset into train/eval sets.
39pub fn train_eval_split(examples: &[TrainingExample], config: &SplitConfig) -> SplitResult {
40    let mut dataset = InstructDataset::new(examples.to_vec());
41
42    if config.shuffle {
43        dataset.shuffle(config.seed);
44    }
45
46    let (train, eval) = dataset.split(config.train_ratio);
47
48    tracing::debug!("Split dataset: {} train, {} eval", train.len(), eval.len());
49
50    SplitResult {
51        train: InstructDataset::new(train),
52        eval: InstructDataset::new(eval),
53    }
54}
55
56/// Sort examples by token count (ascending) for curriculum learning.
57pub fn curriculum_order(examples: &mut [TrainingExample]) {
58    examples.sort_by_key(|e| e.estimated_tokens());
59}
60
61/// Sort examples by token count (descending) for anti-curriculum.
62pub fn anti_curriculum_order(examples: &mut [TrainingExample]) {
63    examples.sort_by_key(|b| std::cmp::Reverse(b.estimated_tokens()));
64}
65
66/// Sample `n` examples uniformly (with seed for reproducibility).
67pub fn sample_n(examples: &[TrainingExample], n: usize, seed: u64) -> Vec<TrainingExample> {
68    if n >= examples.len() {
69        return examples.to_vec();
70    }
71
72    // Fisher-Yates partial shuffle
73    let mut indices: Vec<usize> = (0..examples.len()).collect();
74    let mut state = seed;
75    for i in 0..n {
76        state = state
77            .wrapping_mul(PCG_MULTIPLIER)
78            .wrapping_add(PCG_INCREMENT);
79        let j = i + ((state >> 33) as usize % (examples.len() - i));
80        indices.swap(i, j);
81    }
82
83    indices[..n].iter().map(|&i| examples[i].clone()).collect()
84}
85
/// Split result for preference datasets.
///
/// Produced by [`preference_train_eval_split`]; mirrors [`SplitResult`]
/// for preference-pair data.
pub struct PreferenceSplitResult {
    /// The training split.
    pub train: PreferenceDataset,
    /// The evaluation split.
    pub eval: PreferenceDataset,
}
93
94/// Split preference pairs into train/eval sets.
95pub fn preference_train_eval_split(
96    pairs: &[PreferencePair],
97    config: &SplitConfig,
98) -> PreferenceSplitResult {
99    let mut dataset = PreferenceDataset::new(pairs.to_vec());
100
101    if config.shuffle {
102        dataset.shuffle(config.seed);
103    }
104
105    let (train, eval) = dataset.split(config.train_ratio);
106
107    tracing::debug!(
108        "Split preference dataset: {} train, {} eval",
109        train.len(),
110        eval.len()
111    );
112
113    PreferenceSplitResult {
114        train: PreferenceDataset::new(train),
115        eval: PreferenceDataset::new(eval),
116    }
117}
118
119/// Sort preference pairs by total token count (ascending) for curriculum learning.
120pub fn preference_curriculum_order(pairs: &mut [PreferencePair]) {
121    pairs.sort_by_key(|p| p.estimated_tokens());
122}
123
124/// Sample `n` preference pairs uniformly (with seed for reproducibility).
125pub fn preference_sample_n(pairs: &[PreferencePair], n: usize, seed: u64) -> Vec<PreferencePair> {
126    if n >= pairs.len() {
127        return pairs.to_vec();
128    }
129
130    let mut indices: Vec<usize> = (0..pairs.len()).collect();
131    let mut state = seed;
132    for i in 0..n {
133        state = state
134            .wrapping_mul(PCG_MULTIPLIER)
135            .wrapping_add(PCG_INCREMENT);
136        let j = i + ((state >> 33) as usize % (pairs.len() - i));
137        indices.swap(i, j);
138    }
139
140    indices[..n].iter().map(|&i| pairs[i].clone()).collect()
141}
142
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::TrainingMessage;

    /// Build `n` two-message examples whose user prompt grows with the
    /// index (`i * 10` filler chars), so token estimates increase with `i`.
    fn sample_examples(n: usize) -> Vec<TrainingExample> {
        (0..n)
            .map(|i| {
                TrainingExample::with_id(
                    format!("ex-{i}"),
                    vec![
                        TrainingMessage::user(format!("Q{}: {}", i, "x".repeat(i * 10))),
                        TrainingMessage::assistant(format!("A{}", i)),
                    ],
                )
            })
            .collect()
    }

    #[test]
    fn test_train_eval_split() {
        // Default config is a 0.9 ratio, so 100 examples -> 90/10.
        let examples = sample_examples(100);
        let result = train_eval_split(&examples, &SplitConfig::default());
        assert_eq!(result.train.len(), 90);
        assert_eq!(result.eval.len(), 10);
    }

    #[test]
    fn test_curriculum_order() {
        // After ordering, token counts must be non-decreasing.
        let mut examples = sample_examples(10);
        curriculum_order(&mut examples);
        for i in 1..examples.len() {
            assert!(examples[i].estimated_tokens() >= examples[i - 1].estimated_tokens());
        }
    }

    #[test]
    fn test_sample_n() {
        let examples = sample_examples(100);
        let sampled = sample_n(&examples, 10, 42);
        assert_eq!(sampled.len(), 10);

        // Deterministic: same seed must yield the same ids in the same order.
        let sampled2 = sample_n(&examples, 10, 42);
        for (a, b) in sampled.iter().zip(sampled2.iter()) {
            assert_eq!(a.id, b.id);
        }
    }

    #[test]
    fn test_sample_n_larger_than_dataset() {
        // Requesting more than available returns the whole dataset.
        let examples = sample_examples(5);
        let sampled = sample_n(&examples, 100, 42);
        assert_eq!(sampled.len(), 5);
    }

    #[test]
    fn test_preference_train_eval_split() {
        use crate::types::PreferencePair;
        let pairs: Vec<PreferencePair> = (0..100)
            .map(|i| {
                PreferencePair::new(
                    vec![TrainingMessage::user(format!("Q{}", i))],
                    vec![TrainingMessage::assistant("Good")],
                    vec![TrainingMessage::assistant("Bad")],
                )
            })
            .collect();
        let result = preference_train_eval_split(&pairs, &SplitConfig::default());
        assert_eq!(result.train.len(), 90);
        assert_eq!(result.eval.len(), 10);
    }

    #[test]
    fn test_preference_sample_n() {
        use crate::types::PreferencePair;
        let pairs: Vec<PreferencePair> = (0..50)
            .map(|i| {
                PreferencePair::new(
                    vec![TrainingMessage::user(format!("Q{}", i))],
                    vec![TrainingMessage::assistant("Good")],
                    vec![TrainingMessage::assistant("Bad")],
                )
            })
            .collect();
        let sampled = preference_sample_n(&pairs, 10, 42);
        assert_eq!(sampled.len(), 10);
        // Deterministic: same seed must select the same prompts in order.
        let sampled2 = preference_sample_n(&pairs, 10, 42);
        for (a, b) in sampled.iter().zip(sampled2.iter()) {
            assert_eq!(a.prompt[0].content, b.prompt[0].content);
        }
    }
}