brainwires_datasets/quality/
stats.rs1use crate::types::{PreferencePair, TrainingExample, TrainingRole};
2
3#[derive(Debug, Clone)]
5pub struct DatasetStats {
6 pub total_examples: usize,
8 pub total_messages: usize,
10 pub total_estimated_tokens: usize,
12 pub avg_messages_per_example: f64,
14 pub avg_tokens_per_example: f64,
16 pub min_tokens: usize,
18 pub max_tokens: usize,
20 pub examples_with_system: usize,
22 pub role_counts: RoleCounts,
24 pub token_histogram: Vec<HistogramBucket>,
26}
27
/// Number of messages seen for each conversation role.
///
/// All counters start at zero via `Default`. Equality is derived so that
/// tallies can be compared directly (for example with `assert_eq!`).
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct RoleCounts {
    /// Messages with the system role.
    pub system: usize,
    /// Messages with the user role.
    pub user: usize,
    /// Messages with the assistant role.
    pub assistant: usize,
    /// Messages with the tool role.
    pub tool: usize,
}
40
/// One bucket of a token-count histogram.
///
/// Equality is derived so that whole histograms (`Vec<HistogramBucket>`) can
/// be compared directly (for example with `assert_eq!`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HistogramBucket {
    /// Inclusive lower bound of the bucket.
    pub range_start: usize,
    /// Exclusive upper bound of the bucket.
    pub range_end: usize,
    /// Number of samples that fell into this bucket.
    pub count: usize,
}
51
52pub fn compute_stats(examples: &[TrainingExample]) -> DatasetStats {
54 if examples.is_empty() {
55 return DatasetStats {
56 total_examples: 0,
57 total_messages: 0,
58 total_estimated_tokens: 0,
59 avg_messages_per_example: 0.0,
60 avg_tokens_per_example: 0.0,
61 min_tokens: 0,
62 max_tokens: 0,
63 examples_with_system: 0,
64 role_counts: RoleCounts::default(),
65 token_histogram: Vec::new(),
66 };
67 }
68
69 let mut total_messages = 0;
70 let mut total_tokens = 0;
71 let mut min_tokens = usize::MAX;
72 let mut max_tokens = 0;
73 let mut examples_with_system = 0;
74 let mut role_counts = RoleCounts::default();
75 let mut token_counts: Vec<usize> = Vec::with_capacity(examples.len());
76
77 for example in examples {
78 let tokens = example.estimated_tokens();
79 token_counts.push(tokens);
80 total_messages += example.messages.len();
81 total_tokens += tokens;
82 min_tokens = min_tokens.min(tokens);
83 max_tokens = max_tokens.max(tokens);
84
85 if example.has_system_message() {
86 examples_with_system += 1;
87 }
88
89 for msg in &example.messages {
90 match msg.role {
91 TrainingRole::System => role_counts.system += 1,
92 TrainingRole::User => role_counts.user += 1,
93 TrainingRole::Assistant => role_counts.assistant += 1,
94 TrainingRole::Tool => role_counts.tool += 1,
95 }
96 }
97 }
98
99 let n = examples.len();
100 let histogram = build_histogram(&token_counts);
101
102 DatasetStats {
103 total_examples: n,
104 total_messages,
105 total_estimated_tokens: total_tokens,
106 avg_messages_per_example: total_messages as f64 / n as f64,
107 avg_tokens_per_example: total_tokens as f64 / n as f64,
108 min_tokens,
109 max_tokens,
110 examples_with_system,
111 role_counts,
112 token_histogram: histogram,
113 }
114}
115
116#[derive(Debug, Clone)]
118pub struct PreferenceStats {
119 pub total_pairs: usize,
121 pub total_estimated_tokens: usize,
123 pub avg_prompt_tokens: f64,
125 pub avg_chosen_tokens: f64,
127 pub avg_rejected_tokens: f64,
129 pub min_tokens: usize,
131 pub max_tokens: usize,
133 pub chosen_rejected_length_ratio: f64,
135 pub token_histogram: Vec<HistogramBucket>,
137}
138
139pub fn compute_preference_stats(pairs: &[PreferencePair]) -> PreferenceStats {
141 if pairs.is_empty() {
142 return PreferenceStats {
143 total_pairs: 0,
144 total_estimated_tokens: 0,
145 avg_prompt_tokens: 0.0,
146 avg_chosen_tokens: 0.0,
147 avg_rejected_tokens: 0.0,
148 min_tokens: 0,
149 max_tokens: 0,
150 chosen_rejected_length_ratio: 0.0,
151 token_histogram: Vec::new(),
152 };
153 }
154
155 let mut total_tokens = 0;
156 let mut total_prompt_tokens = 0;
157 let mut total_chosen_tokens = 0;
158 let mut total_rejected_tokens = 0;
159 let mut min_tokens = usize::MAX;
160 let mut max_tokens = 0;
161 let mut ratio_sum = 0.0;
162 let mut token_counts: Vec<usize> = Vec::with_capacity(pairs.len());
163
164 for pair in pairs {
165 let prompt_t: usize = pair.prompt.iter().map(|m| m.estimated_tokens()).sum();
166 let chosen_t: usize = pair.chosen.iter().map(|m| m.estimated_tokens()).sum();
167 let rejected_t: usize = pair.rejected.iter().map(|m| m.estimated_tokens()).sum();
168 let pair_tokens = prompt_t + chosen_t + rejected_t;
169
170 token_counts.push(pair_tokens);
171 total_tokens += pair_tokens;
172 total_prompt_tokens += prompt_t;
173 total_chosen_tokens += chosen_t;
174 total_rejected_tokens += rejected_t;
175 min_tokens = min_tokens.min(pair_tokens);
176 max_tokens = max_tokens.max(pair_tokens);
177
178 let chosen_len = chosen_t.max(1) as f64;
179 let rejected_len = rejected_t.max(1) as f64;
180 ratio_sum += chosen_len / rejected_len;
181 }
182
183 let n = pairs.len() as f64;
184 let histogram = build_histogram(&token_counts);
185
186 PreferenceStats {
187 total_pairs: pairs.len(),
188 total_estimated_tokens: total_tokens,
189 avg_prompt_tokens: total_prompt_tokens as f64 / n,
190 avg_chosen_tokens: total_chosen_tokens as f64 / n,
191 avg_rejected_tokens: total_rejected_tokens as f64 / n,
192 min_tokens,
193 max_tokens,
194 chosen_rejected_length_ratio: ratio_sum / n,
195 token_histogram: histogram,
196 }
197}
198
199fn build_histogram(token_counts: &[usize]) -> Vec<HistogramBucket> {
200 if token_counts.is_empty() {
201 return Vec::new();
202 }
203
204 let max = *token_counts.iter().max().unwrap_or(&0);
205 if max == 0 {
206 return vec![HistogramBucket {
207 range_start: 0,
208 range_end: 1,
209 count: token_counts.len(),
210 }];
211 }
212
213 let mut boundaries = vec![0usize];
215 let mut b = 128;
216 while b <= max {
217 boundaries.push(b);
218 b *= 2;
219 }
220 boundaries.push(b);
221
222 let mut buckets: Vec<HistogramBucket> = boundaries
223 .windows(2)
224 .map(|w| HistogramBucket {
225 range_start: w[0],
226 range_end: w[1],
227 count: 0,
228 })
229 .collect();
230
231 for &count in token_counts {
232 for bucket in &mut buckets {
233 if count >= bucket.range_start && count < bucket.range_end {
234 bucket.count += 1;
235 break;
236 }
237 }
238 }
239
240 while buckets.last().is_some_and(|b| b.count == 0) {
242 buckets.pop();
243 }
244
245 buckets
246}
247
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::TrainingMessage;

    /// Fixture with three examples: two that start with a system message and
    /// one plain user/assistant exchange (eight messages in total).
    fn sample_examples() -> Vec<TrainingExample> {
        let greeting = TrainingExample::with_id(
            "1",
            vec![
                TrainingMessage::system("Be helpful"),
                TrainingMessage::user("Hello"),
                TrainingMessage::assistant("Hi there! How can I help?"),
            ],
        );
        let arithmetic = TrainingExample::with_id(
            "2",
            vec![
                TrainingMessage::user("What is 2+2?"),
                TrainingMessage::assistant("4"),
            ],
        );
        let quantum = TrainingExample::with_id(
            "3",
            vec![
                TrainingMessage::system("Expert mode"),
                TrainingMessage::user("Explain quantum computing"),
                TrainingMessage::assistant(
                    "Quantum computing leverages quantum mechanical phenomena...",
                ),
            ],
        );
        vec![greeting, arithmetic, quantum]
    }

    /// Totals, role counts and averages for the three-example fixture.
    #[test]
    fn test_compute_stats() {
        let examples = sample_examples();
        let stats = compute_stats(&examples);

        assert_eq!(stats.total_examples, 3);
        assert_eq!(stats.total_messages, 8);
        assert_eq!(stats.examples_with_system, 2);

        let roles = &stats.role_counts;
        assert_eq!(roles.system, 2);
        assert_eq!(roles.user, 3);
        assert_eq!(roles.assistant, 3);

        assert!(stats.avg_messages_per_example > 2.0);
        assert!(stats.total_estimated_tokens > 0);
    }

    /// An empty dataset produces zeroed statistics.
    #[test]
    fn test_empty_stats() {
        let stats = compute_stats(&[]);
        assert_eq!(stats.total_examples, 0);
        assert_eq!(stats.avg_tokens_per_example, 0.0);
    }

    /// Two preference pairs yield positive totals, averages and length ratio.
    #[test]
    fn test_compute_preference_stats() {
        use crate::types::PreferencePair;

        let first = PreferencePair::new(
            vec![TrainingMessage::user("Question one here")],
            vec![TrainingMessage::assistant("A good answer")],
            vec![TrainingMessage::assistant("Bad")],
        );
        let second = PreferencePair::new(
            vec![TrainingMessage::user("Another question")],
            vec![TrainingMessage::assistant("Another good answer")],
            vec![TrainingMessage::assistant("Another bad answer")],
        );
        let pairs = vec![first, second];

        let stats = compute_preference_stats(&pairs);
        assert_eq!(stats.total_pairs, 2);
        assert!(stats.total_estimated_tokens > 0);
        assert!(stats.avg_prompt_tokens > 0.0);
        assert!(stats.chosen_rejected_length_ratio > 0.0);
    }

    /// An empty pair list produces zeroed statistics.
    #[test]
    fn test_empty_preference_stats() {
        let stats = compute_preference_stats(&[]);
        assert_eq!(stats.total_pairs, 0);
    }
}