brainwires_datasets/
sampling.rs1use crate::dataset::{Dataset, InstructDataset, PreferenceDataset};
2use crate::types::{PreferencePair, TrainingExample};
3
/// Multiplier of the 64-bit linear congruential step used by `sample_n` and
/// `preference_sample_n` (`state = state * MULTIPLIER + INCREMENT`, wrapping).
///
/// This is the PCG / Knuth MMIX default multiplier (decimal 6_364_136_223_846_793_005).
/// Only the LCG state transition is used in this file — no PCG output permutation is applied.
const PCG_MULTIPLIER: u64 = 0x5851_F42D_4C95_7F2D;
/// Additive increment of the same LCG step (decimal 1_442_695_040_888_963_407; odd, as an
/// LCG increment must be for a full period modulo 2^64).
const PCG_INCREMENT: u64 = 0x1405_7B7E_F767_814F;
8
/// Parameters controlling how a dataset is divided into training and evaluation parts.
///
/// Defaults (see the `Default` impl): `train_ratio = 0.9`, `seed = 42`, `shuffle = true`.
#[derive(Debug, Clone, PartialEq)]
pub struct SplitConfig {
    /// Fraction of items destined for the training split. It is handed to the dataset's
    /// `split`; how fractional counts are rounded is decided there.
    pub train_ratio: f32,
    /// Seed forwarded to the dataset's `shuffle` when `shuffle` is `true`.
    pub seed: u64,
    /// Whether to shuffle (deterministically, using `seed`) before splitting.
    pub shuffle: bool,
}

impl Default for SplitConfig {
    /// A 90/10 train/eval split, shuffled with seed 42.
    fn default() -> Self {
        Self {
            train_ratio: 0.9,
            seed: 42,
            shuffle: true,
        }
    }
}
29
30pub struct SplitResult {
32 pub train: InstructDataset,
34 pub eval: InstructDataset,
36}
37
38pub fn train_eval_split(examples: &[TrainingExample], config: &SplitConfig) -> SplitResult {
40 let mut dataset = InstructDataset::new(examples.to_vec());
41
42 if config.shuffle {
43 dataset.shuffle(config.seed);
44 }
45
46 let (train, eval) = dataset.split(config.train_ratio);
47
48 tracing::debug!("Split dataset: {} train, {} eval", train.len(), eval.len());
49
50 SplitResult {
51 train: InstructDataset::new(train),
52 eval: InstructDataset::new(eval),
53 }
54}
55
56pub fn curriculum_order(examples: &mut [TrainingExample]) {
58 examples.sort_by_key(|e| e.estimated_tokens());
59}
60
61pub fn anti_curriculum_order(examples: &mut [TrainingExample]) {
63 examples.sort_by_key(|b| std::cmp::Reverse(b.estimated_tokens()));
64}
65
66pub fn sample_n(examples: &[TrainingExample], n: usize, seed: u64) -> Vec<TrainingExample> {
68 if n >= examples.len() {
69 return examples.to_vec();
70 }
71
72 let mut indices: Vec<usize> = (0..examples.len()).collect();
74 let mut state = seed;
75 for i in 0..n {
76 state = state
77 .wrapping_mul(PCG_MULTIPLIER)
78 .wrapping_add(PCG_INCREMENT);
79 let j = i + ((state >> 33) as usize % (examples.len() - i));
80 indices.swap(i, j);
81 }
82
83 indices[..n].iter().map(|&i| examples[i].clone()).collect()
84}
85
86pub struct PreferenceSplitResult {
88 pub train: PreferenceDataset,
90 pub eval: PreferenceDataset,
92}
93
94pub fn preference_train_eval_split(
96 pairs: &[PreferencePair],
97 config: &SplitConfig,
98) -> PreferenceSplitResult {
99 let mut dataset = PreferenceDataset::new(pairs.to_vec());
100
101 if config.shuffle {
102 dataset.shuffle(config.seed);
103 }
104
105 let (train, eval) = dataset.split(config.train_ratio);
106
107 tracing::debug!(
108 "Split preference dataset: {} train, {} eval",
109 train.len(),
110 eval.len()
111 );
112
113 PreferenceSplitResult {
114 train: PreferenceDataset::new(train),
115 eval: PreferenceDataset::new(eval),
116 }
117}
118
119pub fn preference_curriculum_order(pairs: &mut [PreferencePair]) {
121 pairs.sort_by_key(|p| p.estimated_tokens());
122}
123
124pub fn preference_sample_n(pairs: &[PreferencePair], n: usize, seed: u64) -> Vec<PreferencePair> {
126 if n >= pairs.len() {
127 return pairs.to_vec();
128 }
129
130 let mut indices: Vec<usize> = (0..pairs.len()).collect();
131 let mut state = seed;
132 for i in 0..n {
133 state = state
134 .wrapping_mul(PCG_MULTIPLIER)
135 .wrapping_add(PCG_INCREMENT);
136 let j = i + ((state >> 33) as usize % (pairs.len() - i));
137 indices.swap(i, j);
138 }
139
140 indices[..n].iter().map(|&i| pairs[i].clone()).collect()
141}
142
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::TrainingMessage;

    /// Builds `n` examples whose user turn grows by 10 characters per index, so their
    /// (presumably length-based) token estimates vary.
    fn sample_examples(n: usize) -> Vec<TrainingExample> {
        (0..n)
            .map(|i| {
                TrainingExample::with_id(
                    format!("ex-{i}"),
                    vec![
                        TrainingMessage::user(format!("Q{}: {}", i, "x".repeat(i * 10))),
                        TrainingMessage::assistant(format!("A{}", i)),
                    ],
                )
            })
            .collect()
    }

    /// Builds `n` preference pairs that differ only in their prompt text (`Q0`, `Q1`, ...).
    fn sample_pairs(n: usize) -> Vec<PreferencePair> {
        (0..n)
            .map(|i| {
                PreferencePair::new(
                    vec![TrainingMessage::user(format!("Q{}", i))],
                    vec![TrainingMessage::assistant("Good")],
                    vec![TrainingMessage::assistant("Bad")],
                )
            })
            .collect()
    }

    #[test]
    fn test_train_eval_split() {
        let examples = sample_examples(100);
        let result = train_eval_split(&examples, &SplitConfig::default());
        assert_eq!(result.train.len(), 90);
        assert_eq!(result.eval.len(), 10);
    }

    #[test]
    fn test_curriculum_order() {
        let mut examples = sample_examples(10);
        curriculum_order(&mut examples);
        for i in 1..examples.len() {
            assert!(examples[i].estimated_tokens() >= examples[i - 1].estimated_tokens());
        }
    }

    #[test]
    fn test_anti_curriculum_order() {
        let mut examples = sample_examples(10);
        anti_curriculum_order(&mut examples);
        for pair in examples.windows(2) {
            assert!(pair[0].estimated_tokens() >= pair[1].estimated_tokens());
        }
    }

    #[test]
    fn test_sample_n() {
        let examples = sample_examples(100);
        let sampled = sample_n(&examples, 10, 42);
        assert_eq!(sampled.len(), 10);

        // Same seed must reproduce the same sample.
        let sampled2 = sample_n(&examples, 10, 42);
        for (a, b) in sampled.iter().zip(sampled2.iter()) {
            assert_eq!(a.id, b.id);
        }
    }

    #[test]
    fn test_sample_n_returns_distinct_examples() {
        let examples = sample_examples(30);
        let sampled = sample_n(&examples, 29, 7);
        assert_eq!(sampled.len(), 29);
        for (i, a) in sampled.iter().enumerate() {
            for b in &sampled[i + 1..] {
                assert_ne!(a.id, b.id);
            }
        }
    }

    #[test]
    fn test_sample_n_zero() {
        let examples = sample_examples(5);
        assert!(sample_n(&examples, 0, 42).is_empty());
    }

    #[test]
    fn test_sample_n_larger_than_dataset() {
        let examples = sample_examples(5);
        let sampled = sample_n(&examples, 100, 42);
        assert_eq!(sampled.len(), 5);
    }

    #[test]
    fn test_preference_train_eval_split() {
        let pairs = sample_pairs(100);
        let result = preference_train_eval_split(&pairs, &SplitConfig::default());
        assert_eq!(result.train.len(), 90);
        assert_eq!(result.eval.len(), 10);
    }

    #[test]
    fn test_preference_curriculum_order() {
        let mut pairs = sample_pairs(20);
        pairs.reverse();
        preference_curriculum_order(&mut pairs);
        for window in pairs.windows(2) {
            assert!(window[0].estimated_tokens() <= window[1].estimated_tokens());
        }
    }

    #[test]
    fn test_preference_sample_n() {
        let pairs = sample_pairs(50);
        let sampled = preference_sample_n(&pairs, 10, 42);
        assert_eq!(sampled.len(), 10);
        let sampled2 = preference_sample_n(&pairs, 10, 42);
        for (a, b) in sampled.iter().zip(sampled2.iter()) {
            assert_eq!(a.prompt[0].content, b.prompt[0].content);
        }
    }
}