1use crate::error::EvalResult;
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8
9#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct LabelAnalysis {
12 pub total_samples: usize,
14 pub labeled_samples: usize,
16 pub label_coverage: f64,
18 pub anomaly_rate: f64,
20 pub class_distribution: Vec<LabelDistribution>,
22 pub imbalance_ratio: f64,
24 pub anomaly_types: HashMap<String, usize>,
26 pub quality_score: f64,
28 pub issues: Vec<String>,
30}
31
32#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct LabelDistribution {
35 pub class_name: String,
37 pub count: usize,
39 pub percentage: f64,
41}
42
/// Raw labels fed to [`LabelAnalyzer::analyze`].
///
/// `None` marks a missing label. The three vectors are not required to have the same length;
/// the analyzer sizes the dataset by the longer of `binary_labels` and `multiclass_labels`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct LabelData {
    /// Binary anomaly flags; `true` is counted as an anomaly.
    pub binary_labels: Vec<Option<bool>>,
    /// Class names used to build the class distribution.
    pub multiclass_labels: Vec<Option<String>>,
    /// Anomaly-type names, tallied into the analysis' `anomaly_types` map.
    pub anomaly_types: Vec<Option<String>>,
}
53
/// Checks labels against anomaly-rate and class-imbalance thresholds.
#[derive(Debug, Clone)]
pub struct LabelAnalyzer {
    /// Anomaly rates below this (with binary labels present) are reported as too low.
    min_anomaly_rate: f64,
    /// Anomaly rates above this are reported as too high.
    max_anomaly_rate: f64,
    /// Imbalance ratios above this are reported as high class imbalance.
    max_imbalance_ratio: f64,
}
63
64impl LabelAnalyzer {
65 pub fn new() -> Self {
67 Self {
68 min_anomaly_rate: 0.01,
69 max_anomaly_rate: 0.20,
70 max_imbalance_ratio: 100.0,
71 }
72 }
73
74 pub fn analyze(&self, data: &LabelData) -> EvalResult<LabelAnalysis> {
76 let mut issues = Vec::new();
77
78 let total_samples = data.binary_labels.len().max(data.multiclass_labels.len());
80
81 let (anomaly_rate, labeled_binary) = if !data.binary_labels.is_empty() {
83 let present: Vec<bool> = data.binary_labels.iter().filter_map(|v| *v).collect();
84 let anomalies = present.iter().filter(|v| **v).count();
85 let rate = if !present.is_empty() {
86 anomalies as f64 / present.len() as f64
87 } else {
88 0.0
89 };
90 (rate, present.len())
91 } else {
92 (0.0, 0)
93 };
94
95 let (class_distribution, labeled_multi) = if !data.multiclass_labels.is_empty() {
97 let present: Vec<&String> = data
98 .multiclass_labels
99 .iter()
100 .filter_map(|v| v.as_ref())
101 .collect();
102
103 let mut counts: HashMap<&str, usize> = HashMap::new();
104 for label in &present {
105 *counts.entry(label.as_str()).or_insert(0) += 1;
106 }
107
108 let total = present.len();
109 let distribution: Vec<LabelDistribution> = counts
110 .iter()
111 .map(|(name, count)| LabelDistribution {
112 class_name: name.to_string(),
113 count: *count,
114 percentage: if total > 0 {
115 *count as f64 / total as f64
116 } else {
117 0.0
118 },
119 })
120 .collect();
121
122 (distribution, present.len())
123 } else {
124 (Vec::new(), 0)
125 };
126
127 let labeled_samples = labeled_binary.max(labeled_multi);
129 let label_coverage = if total_samples > 0 {
130 labeled_samples as f64 / total_samples as f64
131 } else {
132 1.0
133 };
134
135 let imbalance_ratio = if !class_distribution.is_empty() {
137 let max_count = class_distribution
138 .iter()
139 .map(|d| d.count)
140 .max()
141 .unwrap_or(1);
142 let min_count = class_distribution
143 .iter()
144 .map(|d| d.count)
145 .filter(|c| *c > 0)
146 .min()
147 .unwrap_or(1);
148 max_count as f64 / min_count as f64
149 } else if labeled_binary > 0 {
150 let anomalies = (anomaly_rate * labeled_binary as f64) as usize;
151 let normals = labeled_binary - anomalies;
152 if anomalies > 0 && normals > 0 {
153 (anomalies.max(normals) as f64) / (anomalies.min(normals) as f64)
154 } else {
155 f64::INFINITY
156 }
157 } else {
158 1.0
159 };
160
161 let mut anomaly_types: HashMap<String, usize> = HashMap::new();
163 for atype in data.anomaly_types.iter().filter_map(|v| v.as_ref()) {
164 *anomaly_types.entry(atype.clone()).or_insert(0) += 1;
165 }
166
167 if label_coverage < 0.99 {
169 issues.push(format!(
170 "Low label coverage: {:.2}%",
171 label_coverage * 100.0
172 ));
173 }
174
175 if anomaly_rate < self.min_anomaly_rate && labeled_binary > 0 {
176 issues.push(format!(
177 "Anomaly rate too low: {:.2}% (min: {:.2}%)",
178 anomaly_rate * 100.0,
179 self.min_anomaly_rate * 100.0
180 ));
181 }
182
183 if anomaly_rate > self.max_anomaly_rate {
184 issues.push(format!(
185 "Anomaly rate too high: {:.2}% (max: {:.2}%)",
186 anomaly_rate * 100.0,
187 self.max_anomaly_rate * 100.0
188 ));
189 }
190
191 if imbalance_ratio > self.max_imbalance_ratio {
192 issues.push(format!("High class imbalance: {:.1}:1", imbalance_ratio));
193 }
194
195 let mut quality_factors = Vec::new();
197
198 quality_factors.push(label_coverage);
200
201 if labeled_binary > 0 {
203 let rate_score =
204 if anomaly_rate >= self.min_anomaly_rate && anomaly_rate <= self.max_anomaly_rate {
205 1.0
206 } else if anomaly_rate < self.min_anomaly_rate {
207 anomaly_rate / self.min_anomaly_rate
208 } else {
209 self.max_anomaly_rate / anomaly_rate
210 };
211 quality_factors.push(rate_score.min(1.0));
212 }
213
214 if imbalance_ratio > 1.0 {
216 let balance_score = (1.0 / imbalance_ratio.sqrt()).min(1.0);
217 quality_factors.push(balance_score);
218 }
219
220 let quality_score = if quality_factors.is_empty() {
221 1.0
222 } else {
223 quality_factors.iter().sum::<f64>() / quality_factors.len() as f64
224 };
225
226 Ok(LabelAnalysis {
227 total_samples,
228 labeled_samples,
229 label_coverage,
230 anomaly_rate,
231 class_distribution,
232 imbalance_ratio,
233 anomaly_types,
234 quality_score,
235 issues,
236 })
237 }
238}
239
240impl Default for LabelAnalyzer {
241 fn default() -> Self {
242 Self::new()
243 }
244}
245
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs a default analyzer over `data`, panicking if the analysis fails.
    fn run(data: &LabelData) -> LabelAnalysis {
        LabelAnalyzer::new().analyze(data).unwrap()
    }

    #[test]
    fn test_balanced_labels() {
        // Eight normal samples followed by two anomalies: a 20% anomaly rate.
        let binary_labels: Vec<Option<bool>> = std::iter::repeat(Some(false))
            .take(8)
            .chain(std::iter::repeat(Some(true)).take(2))
            .collect();
        let data = LabelData {
            binary_labels,
            ..LabelData::default()
        };

        let analysis = run(&data);

        assert_eq!(analysis.total_samples, 10);
        assert_eq!(analysis.label_coverage, 1.0);
        assert!((analysis.anomaly_rate - 0.2).abs() < 0.01);
    }

    #[test]
    fn test_multiclass_labels() {
        let multiclass_labels = ["A", "A", "B", "C"]
            .iter()
            .map(|&name| Some(name.to_owned()))
            .collect();
        let data = LabelData {
            multiclass_labels,
            ..LabelData::default()
        };

        let analysis = run(&data);

        assert_eq!(analysis.class_distribution.len(), 3);
        assert!(analysis.imbalance_ratio >= 1.0);
    }

    #[test]
    fn test_missing_labels() {
        let data = LabelData {
            binary_labels: vec![Some(true), None, Some(false), None],
            ..LabelData::default()
        };

        let analysis = run(&data);

        assert_eq!(analysis.labeled_samples, 2);
        assert!(analysis.label_coverage < 1.0);
    }
}