Skip to main content

brainwires_datasets/quality/
stats.rs

1use crate::types::{PreferencePair, TrainingExample, TrainingRole};
2
3/// Statistics about a training dataset.
4#[derive(Debug, Clone)]
5pub struct DatasetStats {
6    /// Total number of training examples.
7    pub total_examples: usize,
8    /// Total number of messages across all examples.
9    pub total_messages: usize,
10    /// Total estimated tokens across all examples.
11    pub total_estimated_tokens: usize,
12    /// Average messages per example.
13    pub avg_messages_per_example: f64,
14    /// Average estimated tokens per example.
15    pub avg_tokens_per_example: f64,
16    /// Minimum tokens in any single example.
17    pub min_tokens: usize,
18    /// Maximum tokens in any single example.
19    pub max_tokens: usize,
20    /// Number of examples that include a system message.
21    pub examples_with_system: usize,
22    /// Message counts per role.
23    pub role_counts: RoleCounts,
24    /// Token count distribution histogram.
25    pub token_histogram: Vec<HistogramBucket>,
26}
27
/// Message counts broken down by role.
///
/// Derives `PartialEq`/`Eq` so counts can be compared directly
/// (for example with `assert_eq!`).
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct RoleCounts {
    /// Number of system messages.
    pub system: usize,
    /// Number of user messages.
    pub user: usize,
    /// Number of assistant messages.
    pub assistant: usize,
    /// Number of tool messages.
    pub tool: usize,
}
40
/// A single bucket in the token count histogram.
///
/// Covers the half-open range `range_start..range_end`. Derives
/// `PartialEq`/`Eq` so whole histograms can be compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HistogramBucket {
    /// Inclusive lower bound of the bucket range.
    pub range_start: usize,
    /// Exclusive upper bound of the bucket range.
    pub range_end: usize,
    /// Number of items (training examples or preference pairs) whose token
    /// count falls in this range.
    pub count: usize,
}
51
52/// Compute statistics for a set of training examples.
53pub fn compute_stats(examples: &[TrainingExample]) -> DatasetStats {
54    if examples.is_empty() {
55        return DatasetStats {
56            total_examples: 0,
57            total_messages: 0,
58            total_estimated_tokens: 0,
59            avg_messages_per_example: 0.0,
60            avg_tokens_per_example: 0.0,
61            min_tokens: 0,
62            max_tokens: 0,
63            examples_with_system: 0,
64            role_counts: RoleCounts::default(),
65            token_histogram: Vec::new(),
66        };
67    }
68
69    let mut total_messages = 0;
70    let mut total_tokens = 0;
71    let mut min_tokens = usize::MAX;
72    let mut max_tokens = 0;
73    let mut examples_with_system = 0;
74    let mut role_counts = RoleCounts::default();
75    let mut token_counts: Vec<usize> = Vec::with_capacity(examples.len());
76
77    for example in examples {
78        let tokens = example.estimated_tokens();
79        token_counts.push(tokens);
80        total_messages += example.messages.len();
81        total_tokens += tokens;
82        min_tokens = min_tokens.min(tokens);
83        max_tokens = max_tokens.max(tokens);
84
85        if example.has_system_message() {
86            examples_with_system += 1;
87        }
88
89        for msg in &example.messages {
90            match msg.role {
91                TrainingRole::System => role_counts.system += 1,
92                TrainingRole::User => role_counts.user += 1,
93                TrainingRole::Assistant => role_counts.assistant += 1,
94                TrainingRole::Tool => role_counts.tool += 1,
95            }
96        }
97    }
98
99    let n = examples.len();
100    let histogram = build_histogram(&token_counts);
101
102    DatasetStats {
103        total_examples: n,
104        total_messages,
105        total_estimated_tokens: total_tokens,
106        avg_messages_per_example: total_messages as f64 / n as f64,
107        avg_tokens_per_example: total_tokens as f64 / n as f64,
108        min_tokens,
109        max_tokens,
110        examples_with_system,
111        role_counts,
112        token_histogram: histogram,
113    }
114}
115
116/// Statistics about a preference training dataset.
117#[derive(Debug, Clone)]
118pub struct PreferenceStats {
119    /// Total number of preference pairs.
120    pub total_pairs: usize,
121    /// Total estimated tokens across all pairs.
122    pub total_estimated_tokens: usize,
123    /// Average tokens in prompt messages.
124    pub avg_prompt_tokens: f64,
125    /// Average tokens in chosen messages.
126    pub avg_chosen_tokens: f64,
127    /// Average tokens in rejected messages.
128    pub avg_rejected_tokens: f64,
129    /// Minimum tokens in any single pair.
130    pub min_tokens: usize,
131    /// Maximum tokens in any single pair.
132    pub max_tokens: usize,
133    /// Average ratio of chosen to rejected length.
134    pub chosen_rejected_length_ratio: f64,
135    /// Token count distribution histogram.
136    pub token_histogram: Vec<HistogramBucket>,
137}
138
139/// Compute statistics for a set of preference pairs.
140pub fn compute_preference_stats(pairs: &[PreferencePair]) -> PreferenceStats {
141    if pairs.is_empty() {
142        return PreferenceStats {
143            total_pairs: 0,
144            total_estimated_tokens: 0,
145            avg_prompt_tokens: 0.0,
146            avg_chosen_tokens: 0.0,
147            avg_rejected_tokens: 0.0,
148            min_tokens: 0,
149            max_tokens: 0,
150            chosen_rejected_length_ratio: 0.0,
151            token_histogram: Vec::new(),
152        };
153    }
154
155    let mut total_tokens = 0;
156    let mut total_prompt_tokens = 0;
157    let mut total_chosen_tokens = 0;
158    let mut total_rejected_tokens = 0;
159    let mut min_tokens = usize::MAX;
160    let mut max_tokens = 0;
161    let mut ratio_sum = 0.0;
162    let mut token_counts: Vec<usize> = Vec::with_capacity(pairs.len());
163
164    for pair in pairs {
165        let prompt_t: usize = pair.prompt.iter().map(|m| m.estimated_tokens()).sum();
166        let chosen_t: usize = pair.chosen.iter().map(|m| m.estimated_tokens()).sum();
167        let rejected_t: usize = pair.rejected.iter().map(|m| m.estimated_tokens()).sum();
168        let pair_tokens = prompt_t + chosen_t + rejected_t;
169
170        token_counts.push(pair_tokens);
171        total_tokens += pair_tokens;
172        total_prompt_tokens += prompt_t;
173        total_chosen_tokens += chosen_t;
174        total_rejected_tokens += rejected_t;
175        min_tokens = min_tokens.min(pair_tokens);
176        max_tokens = max_tokens.max(pair_tokens);
177
178        let chosen_len = chosen_t.max(1) as f64;
179        let rejected_len = rejected_t.max(1) as f64;
180        ratio_sum += chosen_len / rejected_len;
181    }
182
183    let n = pairs.len() as f64;
184    let histogram = build_histogram(&token_counts);
185
186    PreferenceStats {
187        total_pairs: pairs.len(),
188        total_estimated_tokens: total_tokens,
189        avg_prompt_tokens: total_prompt_tokens as f64 / n,
190        avg_chosen_tokens: total_chosen_tokens as f64 / n,
191        avg_rejected_tokens: total_rejected_tokens as f64 / n,
192        min_tokens,
193        max_tokens,
194        chosen_rejected_length_ratio: ratio_sum / n,
195        token_histogram: histogram,
196    }
197}
198
199fn build_histogram(token_counts: &[usize]) -> Vec<HistogramBucket> {
200    if token_counts.is_empty() {
201        return Vec::new();
202    }
203
204    let max = *token_counts.iter().max().unwrap_or(&0);
205    if max == 0 {
206        return vec![HistogramBucket {
207            range_start: 0,
208            range_end: 1,
209            count: token_counts.len(),
210        }];
211    }
212
213    // Use power-of-2 bucket boundaries: 0-128, 128-256, 256-512, etc.
214    let mut boundaries = vec![0usize];
215    let mut b = 128;
216    while b <= max {
217        boundaries.push(b);
218        b *= 2;
219    }
220    boundaries.push(b);
221
222    let mut buckets: Vec<HistogramBucket> = boundaries
223        .windows(2)
224        .map(|w| HistogramBucket {
225            range_start: w[0],
226            range_end: w[1],
227            count: 0,
228        })
229        .collect();
230
231    for &count in token_counts {
232        for bucket in &mut buckets {
233            if count >= bucket.range_start && count < bucket.range_end {
234                bucket.count += 1;
235                break;
236            }
237        }
238    }
239
240    // Remove empty trailing buckets
241    while buckets.last().is_some_and(|b| b.count == 0) {
242        buckets.pop();
243    }
244
245    buckets
246}
247
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::TrainingMessage;

    /// Three examples: two with a system message, one without; 8 messages total.
    fn sample_examples() -> Vec<TrainingExample> {
        vec![
            TrainingExample::with_id(
                "1",
                vec![
                    TrainingMessage::system("Be helpful"),
                    TrainingMessage::user("Hello"),
                    TrainingMessage::assistant("Hi there! How can I help?"),
                ],
            ),
            TrainingExample::with_id(
                "2",
                vec![
                    TrainingMessage::user("What is 2+2?"),
                    TrainingMessage::assistant("4"),
                ],
            ),
            TrainingExample::with_id(
                "3",
                vec![
                    TrainingMessage::system("Expert mode"),
                    TrainingMessage::user("Explain quantum computing"),
                    TrainingMessage::assistant(
                        "Quantum computing leverages quantum mechanical phenomena...",
                    ),
                ],
            ),
        ]
    }

    /// Flatten a histogram into `(start, end, count)` tuples for easy comparison.
    fn spans(histogram: &[HistogramBucket]) -> Vec<(usize, usize, usize)> {
        histogram
            .iter()
            .map(|b| (b.range_start, b.range_end, b.count))
            .collect()
    }

    #[test]
    fn test_compute_stats() {
        let stats = compute_stats(&sample_examples());
        assert_eq!(stats.total_examples, 3);
        assert_eq!(stats.total_messages, 8);
        assert_eq!(stats.examples_with_system, 2);
        assert_eq!(stats.role_counts.system, 2);
        assert_eq!(stats.role_counts.user, 3);
        assert_eq!(stats.role_counts.assistant, 3);
        assert_eq!(stats.role_counts.tool, 0);
        assert!((stats.avg_messages_per_example - 8.0 / 3.0).abs() < 1e-9);
        assert!(stats.total_estimated_tokens > 0);
        assert!(stats.min_tokens <= stats.max_tokens);
        // Every example lands in exactly one histogram bucket.
        let bucketed: usize = stats.token_histogram.iter().map(|b| b.count).sum();
        assert_eq!(bucketed, stats.total_examples);
    }

    #[test]
    fn test_empty_stats() {
        let stats = compute_stats(&[]);
        assert_eq!(stats.total_examples, 0);
        assert_eq!(stats.avg_tokens_per_example, 0.0);
        assert_eq!(stats.min_tokens, 0);
        assert_eq!(stats.max_tokens, 0);
        assert!(stats.token_histogram.is_empty());
    }

    #[test]
    fn test_compute_preference_stats() {
        use crate::types::PreferencePair;
        let pairs = vec![
            PreferencePair::new(
                vec![TrainingMessage::user("Question one here")],
                vec![TrainingMessage::assistant("A good answer")],
                vec![TrainingMessage::assistant("Bad")],
            ),
            PreferencePair::new(
                vec![TrainingMessage::user("Another question")],
                vec![TrainingMessage::assistant("Another good answer")],
                vec![TrainingMessage::assistant("Another bad answer")],
            ),
        ];
        let stats = compute_preference_stats(&pairs);
        assert_eq!(stats.total_pairs, 2);
        assert!(stats.total_estimated_tokens > 0);
        assert!(stats.avg_prompt_tokens > 0.0);
        assert!(stats.avg_chosen_tokens > 0.0);
        assert!(stats.chosen_rejected_length_ratio > 0.0);
        assert!(stats.min_tokens <= stats.max_tokens);
        let bucketed: usize = stats.token_histogram.iter().map(|b| b.count).sum();
        assert_eq!(bucketed, stats.total_pairs);
    }

    #[test]
    fn test_empty_preference_stats() {
        let stats = compute_preference_stats(&[]);
        assert_eq!(stats.total_pairs, 0);
        assert_eq!(stats.chosen_rejected_length_ratio, 0.0);
        assert!(stats.token_histogram.is_empty());
    }

    #[test]
    fn test_histogram_empty_and_all_zero() {
        assert!(build_histogram(&[]).is_empty());
        assert_eq!(spans(&build_histogram(&[0, 0])), vec![(0, 1, 2)]);
    }

    #[test]
    fn test_histogram_power_of_two_buckets() {
        assert_eq!(
            spans(&build_histogram(&[0, 127, 128, 255, 256, 1000])),
            vec![(0, 128, 2), (128, 256, 2), (256, 512, 1), (512, 1024, 1)]
        );
    }

    #[test]
    fn test_histogram_keeps_interior_gaps() {
        assert_eq!(
            spans(&build_histogram(&[1, 600])),
            vec![(0, 128, 1), (128, 256, 0), (256, 512, 0), (512, 1024, 1)]
        );
    }
}