Skip to main content

datasynth_eval/ml/
labels.rs

//! Label quality analysis.
//!
//! Analyzes label coverage, class balance, and anomaly-type breakdowns for
//! supervised ML tasks.
4
5use crate::error::EvalResult;
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8
9/// Results of label analysis.
10#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct LabelAnalysis {
12    /// Total samples.
13    pub total_samples: usize,
14    /// Samples with labels.
15    pub labeled_samples: usize,
16    /// Label coverage (labeled / total).
17    pub label_coverage: f64,
18    /// Anomaly rate (for binary anomaly detection).
19    pub anomaly_rate: f64,
20    /// Class distribution.
21    pub class_distribution: Vec<LabelDistribution>,
22    /// Imbalance ratio (max class / min class).
23    pub imbalance_ratio: f64,
24    /// Anomaly type breakdown.
25    pub anomaly_types: HashMap<String, usize>,
26    /// Label quality score (0.0-1.0).
27    pub quality_score: f64,
28    /// Issues with labels.
29    pub issues: Vec<String>,
30}
31
32/// Distribution for a single label class.
33#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct LabelDistribution {
35    /// Class name/value.
36    pub class_name: String,
37    /// Count of samples.
38    pub count: usize,
39    /// Percentage of total.
40    pub percentage: f64,
41}
42
/// Input for label analysis.
///
/// Each entry of `binary_labels` and `multiclass_labels` counts as one sample,
/// and `None` marks a sample without a label of that kind. The vectors may have
/// different lengths; the longer one determines the total sample count.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct LabelData {
    /// Binary labels (`true` = anomaly/positive).
    pub binary_labels: Vec<Option<bool>>,
    /// Multi-class labels.
    pub multiclass_labels: Vec<Option<String>>,
    /// Anomaly type labels; tallied independently of the other two vectors.
    pub anomaly_types: Vec<Option<String>>,
}
53
/// Analyzer for label quality.
///
/// Holds the thresholds used to flag issues and to score label quality; see
/// [`LabelAnalyzer::new`] for the defaults.
#[derive(Debug, Clone, PartialEq)]
pub struct LabelAnalyzer {
    /// Minimum acceptable anomaly rate (fraction, 0.0-1.0).
    min_anomaly_rate: f64,
    /// Maximum acceptable anomaly rate (fraction, 0.0-1.0).
    max_anomaly_rate: f64,
    /// Maximum acceptable imbalance ratio (largest class / smallest class).
    max_imbalance_ratio: f64,
}
63
64impl LabelAnalyzer {
65    /// Create a new analyzer.
66    pub fn new() -> Self {
67        Self {
68            min_anomaly_rate: 0.01,
69            max_anomaly_rate: 0.20,
70            max_imbalance_ratio: 100.0,
71        }
72    }
73
74    /// Analyze label quality.
75    pub fn analyze(&self, data: &LabelData) -> EvalResult<LabelAnalysis> {
76        let mut issues = Vec::new();
77
78        // Determine total samples from the larger of binary or multiclass
79        let total_samples = data.binary_labels.len().max(data.multiclass_labels.len());
80
81        // Analyze binary labels
82        let (anomaly_rate, labeled_binary) = if !data.binary_labels.is_empty() {
83            let present: Vec<bool> = data.binary_labels.iter().filter_map(|v| *v).collect();
84            let anomalies = present.iter().filter(|v| **v).count();
85            let rate = if !present.is_empty() {
86                anomalies as f64 / present.len() as f64
87            } else {
88                0.0
89            };
90            (rate, present.len())
91        } else {
92            (0.0, 0)
93        };
94
95        // Analyze multiclass labels
96        let (class_distribution, labeled_multi) = if !data.multiclass_labels.is_empty() {
97            let present: Vec<&String> = data
98                .multiclass_labels
99                .iter()
100                .filter_map(|v| v.as_ref())
101                .collect();
102
103            let mut counts: HashMap<&str, usize> = HashMap::new();
104            for label in &present {
105                *counts.entry(label.as_str()).or_insert(0) += 1;
106            }
107
108            let total = present.len();
109            let distribution: Vec<LabelDistribution> = counts
110                .iter()
111                .map(|(name, count)| LabelDistribution {
112                    class_name: name.to_string(),
113                    count: *count,
114                    percentage: if total > 0 {
115                        *count as f64 / total as f64
116                    } else {
117                        0.0
118                    },
119                })
120                .collect();
121
122            (distribution, present.len())
123        } else {
124            (Vec::new(), 0)
125        };
126
127        // Calculate label coverage
128        let labeled_samples = labeled_binary.max(labeled_multi);
129        let label_coverage = if total_samples > 0 {
130            labeled_samples as f64 / total_samples as f64
131        } else {
132            1.0
133        };
134
135        // Calculate imbalance ratio
136        let imbalance_ratio = if !class_distribution.is_empty() {
137            let max_count = class_distribution
138                .iter()
139                .map(|d| d.count)
140                .max()
141                .unwrap_or(1);
142            let min_count = class_distribution
143                .iter()
144                .map(|d| d.count)
145                .filter(|c| *c > 0)
146                .min()
147                .unwrap_or(1);
148            max_count as f64 / min_count as f64
149        } else if labeled_binary > 0 {
150            let anomalies = (anomaly_rate * labeled_binary as f64) as usize;
151            let normals = labeled_binary - anomalies;
152            if anomalies > 0 && normals > 0 {
153                (anomalies.max(normals) as f64) / (anomalies.min(normals) as f64)
154            } else {
155                f64::INFINITY
156            }
157        } else {
158            1.0
159        };
160
161        // Analyze anomaly types
162        let mut anomaly_types: HashMap<String, usize> = HashMap::new();
163        for atype in data.anomaly_types.iter().filter_map(|v| v.as_ref()) {
164            *anomaly_types.entry(atype.clone()).or_insert(0) += 1;
165        }
166
167        // Check for issues
168        if label_coverage < 0.99 {
169            issues.push(format!(
170                "Low label coverage: {:.2}%",
171                label_coverage * 100.0
172            ));
173        }
174
175        if anomaly_rate < self.min_anomaly_rate && labeled_binary > 0 {
176            issues.push(format!(
177                "Anomaly rate too low: {:.2}% (min: {:.2}%)",
178                anomaly_rate * 100.0,
179                self.min_anomaly_rate * 100.0
180            ));
181        }
182
183        if anomaly_rate > self.max_anomaly_rate {
184            issues.push(format!(
185                "Anomaly rate too high: {:.2}% (max: {:.2}%)",
186                anomaly_rate * 100.0,
187                self.max_anomaly_rate * 100.0
188            ));
189        }
190
191        if imbalance_ratio > self.max_imbalance_ratio {
192            issues.push(format!("High class imbalance: {:.1}:1", imbalance_ratio));
193        }
194
195        // Calculate quality score
196        let mut quality_factors = Vec::new();
197
198        // Coverage factor
199        quality_factors.push(label_coverage);
200
201        // Anomaly rate factor (penalize if outside ideal range)
202        if labeled_binary > 0 {
203            let rate_score =
204                if anomaly_rate >= self.min_anomaly_rate && anomaly_rate <= self.max_anomaly_rate {
205                    1.0
206                } else if anomaly_rate < self.min_anomaly_rate {
207                    anomaly_rate / self.min_anomaly_rate
208                } else {
209                    self.max_anomaly_rate / anomaly_rate
210                };
211            quality_factors.push(rate_score.min(1.0));
212        }
213
214        // Imbalance factor
215        if imbalance_ratio > 1.0 {
216            let balance_score = (1.0 / imbalance_ratio.sqrt()).min(1.0);
217            quality_factors.push(balance_score);
218        }
219
220        let quality_score = if quality_factors.is_empty() {
221            1.0
222        } else {
223            quality_factors.iter().sum::<f64>() / quality_factors.len() as f64
224        };
225
226        Ok(LabelAnalysis {
227            total_samples,
228            labeled_samples,
229            label_coverage,
230            anomaly_rate,
231            class_distribution,
232            imbalance_ratio,
233            anomaly_types,
234            quality_score,
235            issues,
236        })
237    }
238}
239
240impl Default for LabelAnalyzer {
241    fn default() -> Self {
242        Self::new()
243    }
244}
245
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs the default analyzer, panicking on error (it never errors today).
    fn analyze(data: &LabelData) -> LabelAnalysis {
        LabelAnalyzer::new().analyze(data).unwrap()
    }

    #[test]
    fn test_balanced_labels() {
        // 8 normal + 2 anomalous samples: a 20% anomaly rate, at the upper bound.
        let mut binary_labels = vec![Some(false); 8];
        binary_labels.extend([Some(true), Some(true)]);
        let data = LabelData {
            binary_labels,
            ..Default::default()
        };

        let result = analyze(&data);

        assert_eq!(result.total_samples, 10);
        assert_eq!(result.label_coverage, 1.0);
        assert!((result.anomaly_rate - 0.2).abs() < 0.01);
        assert_eq!(result.imbalance_ratio, 4.0);
        assert!(result.issues.is_empty());
    }

    #[test]
    fn test_multiclass_labels() {
        let data = LabelData {
            multiclass_labels: ["A", "A", "B", "C"]
                .iter()
                .map(|s| Some(s.to_string()))
                .collect(),
            ..Default::default()
        };

        let result = analyze(&data);

        assert_eq!(result.class_distribution.len(), 3);
        assert!(result.imbalance_ratio >= 1.0);
        assert_eq!(result.imbalance_ratio, 2.0);
    }

    #[test]
    fn test_missing_labels() {
        let data = LabelData {
            binary_labels: vec![Some(true), None, Some(false), None],
            ..Default::default()
        };

        let result = analyze(&data);

        assert_eq!(result.labeled_samples, 2);
        assert!(result.label_coverage < 1.0);
        assert_eq!(result.label_coverage, 0.5);
    }

    #[test]
    fn test_empty_input() {
        let result = analyze(&LabelData::default());

        assert_eq!(result.total_samples, 0);
        assert_eq!(result.labeled_samples, 0);
        assert_eq!(result.label_coverage, 1.0);
        assert_eq!(result.imbalance_ratio, 1.0);
        assert_eq!(result.quality_score, 1.0);
        assert!(result.issues.is_empty());
    }

    #[test]
    fn test_single_class_binary_labels() {
        let data = LabelData {
            binary_labels: vec![Some(false); 10],
            ..Default::default()
        };

        let result = analyze(&data);

        assert_eq!(result.anomaly_rate, 0.0);
        assert!(result.imbalance_ratio.is_infinite());
        // "Anomaly rate too low" and "High class imbalance".
        assert_eq!(result.issues.len(), 2);
        // Factors: coverage 1.0, rate 0.0, balance 0.0.
        assert!((result.quality_score - 1.0 / 3.0).abs() < 1e-12);
    }

    #[test]
    fn test_anomaly_type_counts() {
        let data = LabelData {
            anomaly_types: vec![
                Some("fraud".to_string()),
                None,
                Some("fraud".to_string()),
                Some("duplicate".to_string()),
            ],
            ..Default::default()
        };

        let result = analyze(&data);

        assert_eq!(result.anomaly_types.len(), 2);
        assert_eq!(result.anomaly_types.get("fraud"), Some(&2));
        assert_eq!(result.anomaly_types.get("duplicate"), Some(&1));
    }
}