1use crate::error::EvalResult;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10
/// Raw input for embedding-readiness analysis.
#[derive(Debug, Clone)]
pub struct EmbeddingInput {
    /// Row-major sample matrix: one inner `Vec<f64>` per sample.
    /// Width is taken from the first row; ragged rows are tolerated
    /// downstream via bounds-checked access.
    pub feature_matrix: Vec<Vec<f64>>,
    /// Class label per sample; presumably parallel to `feature_matrix`
    /// (same length) — not enforced here, TODO confirm at call sites.
    pub labels: Vec<String>,
}
19
/// Pass/fail cutoffs applied during embedding-readiness analysis.
#[derive(Debug, Clone)]
pub struct EmbeddingReadinessThresholds {
    /// Minimum acceptable readiness score (scores are clamped to [0, 1]);
    /// results below this value are recorded as an issue.
    pub min_embedding_readiness: f64,
}
26
27impl Default for EmbeddingReadinessThresholds {
28 fn default() -> Self {
29 Self {
30 min_embedding_readiness: 0.50,
31 }
32 }
33}
34
/// Result of an embedding-readiness analysis.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingReadinessAnalysis {
    /// Overall readiness in [0, 1]: weighted blend of dimensional
    /// compactness (0.4), contrastive viability (0.3), class balance (0.3).
    pub embedding_readiness_score: f64,
    /// Estimated number of principal components covering ~95% of variance.
    pub effective_dimensionality: usize,
    /// Raw feature-vector width, taken from the first sample row.
    pub total_dimensions: usize,
    /// True when there are >= 2 classes and every class has >= 2 samples.
    pub contrastive_learning_viable: bool,
    /// Sample count of the smallest class (0 when no labels were given).
    pub min_class_count: usize,
    /// Overall pass/fail verdict for the analysis.
    pub passes: bool,
    /// Human-readable descriptions of detected problems.
    pub issues: Vec<String>,
}
53
/// Analyzer that scores `EmbeddingInput`s against configurable thresholds.
pub struct EmbeddingReadinessAnalyzer {
    /// Cutoffs consulted by `analyze` when recording threshold issues.
    thresholds: EmbeddingReadinessThresholds,
}
58
59impl EmbeddingReadinessAnalyzer {
60 pub fn new() -> Self {
62 Self {
63 thresholds: EmbeddingReadinessThresholds::default(),
64 }
65 }
66
67 pub fn with_thresholds(thresholds: EmbeddingReadinessThresholds) -> Self {
69 Self { thresholds }
70 }
71
72 pub fn analyze(&self, input: &EmbeddingInput) -> EvalResult<EmbeddingReadinessAnalysis> {
74 let mut issues = Vec::new();
75
76 if input.feature_matrix.is_empty() {
77 return Ok(EmbeddingReadinessAnalysis {
78 embedding_readiness_score: 0.0,
79 effective_dimensionality: 0,
80 total_dimensions: 0,
81 contrastive_learning_viable: false,
82 min_class_count: 0,
83 passes: true,
84 issues: vec!["No samples provided".to_string()],
85 });
86 }
87
88 let total_dimensions = input.feature_matrix.first().map(|r| r.len()).unwrap_or(0);
89
90 if total_dimensions == 0 {
91 return Ok(EmbeddingReadinessAnalysis {
92 embedding_readiness_score: 0.0,
93 effective_dimensionality: 0,
94 total_dimensions: 0,
95 contrastive_learning_viable: false,
96 min_class_count: 0,
97 passes: false,
98 issues: vec!["Zero-dimensional features".to_string()],
99 });
100 }
101
102 let effective_dimensionality =
104 self.compute_effective_dimensionality(&input.feature_matrix, total_dimensions);
105
106 let mut class_counts: HashMap<&str, usize> = HashMap::new();
108 for label in &input.labels {
109 *class_counts.entry(label.as_str()).or_insert(0) += 1;
110 }
111
112 let min_class_count = class_counts.values().copied().min().unwrap_or(0);
113 let num_classes = class_counts.len();
114 let contrastive_learning_viable = min_class_count >= 2 && num_classes >= 2;
115
116 if !contrastive_learning_viable {
117 issues.push(format!(
118 "Contrastive learning not viable: {} classes, min count = {}",
119 num_classes, min_class_count
120 ));
121 }
122
123 let dim_ratio = if total_dimensions > 0 {
125 effective_dimensionality as f64 / total_dimensions as f64
126 } else {
127 0.0
128 };
129 let dim_score = (1.0 - dim_ratio).clamp(0.0, 1.0);
131 let contrastive_score = if contrastive_learning_viable {
132 1.0
133 } else {
134 0.0
135 };
136 let class_balance_score = if num_classes >= 2 && min_class_count > 0 {
137 let max_count = class_counts.values().copied().max().unwrap_or(1);
138 (min_class_count as f64 / max_count as f64).clamp(0.0, 1.0)
139 } else {
140 0.0
141 };
142
143 let embedding_readiness_score =
144 (dim_score * 0.4 + contrastive_score * 0.3 + class_balance_score * 0.3).clamp(0.0, 1.0);
145
146 if embedding_readiness_score < self.thresholds.min_embedding_readiness {
147 issues.push(format!(
148 "Embedding readiness score {:.4} < {:.4} (threshold)",
149 embedding_readiness_score, self.thresholds.min_embedding_readiness
150 ));
151 }
152
153 let passes = issues.is_empty();
154
155 Ok(EmbeddingReadinessAnalysis {
156 embedding_readiness_score,
157 effective_dimensionality,
158 total_dimensions,
159 contrastive_learning_viable,
160 min_class_count,
161 passes,
162 issues,
163 })
164 }
165
166 fn compute_effective_dimensionality(
171 &self,
172 feature_matrix: &[Vec<f64>],
173 total_dims: usize,
174 ) -> usize {
175 let n = feature_matrix.len();
176 if n < 2 || total_dims == 0 {
177 return total_dims;
178 }
179
180 let mut means = vec![0.0; total_dims];
182 for row in feature_matrix {
183 for (j, &val) in row.iter().enumerate().take(total_dims) {
184 means[j] += val;
185 }
186 }
187 for m in &mut means {
188 *m /= n as f64;
189 }
190
191 let dim = total_dims.min(50); let mut cov = vec![vec![0.0; dim]; dim];
194
195 for row in feature_matrix {
196 for i in 0..dim {
197 let di = if i < row.len() {
198 row[i] - means[i]
199 } else {
200 0.0
201 };
202 for j in i..dim {
203 let dj = if j < row.len() {
204 row[j] - means[j]
205 } else {
206 0.0
207 };
208 cov[i][j] += di * dj;
209 }
210 }
211 }
212
213 #[allow(clippy::needless_range_loop)]
215 for i in 0..dim {
216 for j in i..dim {
217 cov[i][j] /= (n - 1) as f64;
218 cov[j][i] = cov[i][j];
219 }
220 }
221
222 let max_eigenvalues = dim;
224 let mut eigenvalues = Vec::new();
225 let mut work_cov = cov.clone();
226
227 for _ in 0..max_eigenvalues {
228 let (eigenvalue, eigenvector) = self.power_iteration(&work_cov, dim);
229 if eigenvalue.abs() < 1e-12 {
230 break;
231 }
232 eigenvalues.push(eigenvalue);
233
234 for i in 0..dim {
236 for j in 0..dim {
237 work_cov[i][j] -= eigenvalue * eigenvector[i] * eigenvector[j];
238 }
239 }
240 }
241
242 if eigenvalues.is_empty() {
243 return total_dims;
244 }
245
246 let total_variance: f64 = eigenvalues.iter().filter(|&&v| v > 0.0).sum();
248 if total_variance < 1e-12 {
249 return total_dims;
250 }
251
252 let target = 0.95 * total_variance;
253 let mut cumulative = 0.0;
254 let mut effective = 0;
255
256 for &ev in &eigenvalues {
257 if ev <= 0.0 {
258 continue;
259 }
260 cumulative += ev;
261 effective += 1;
262 if cumulative >= target {
263 break;
264 }
265 }
266
267 effective.max(1)
268 }
269
270 fn power_iteration(&self, matrix: &[Vec<f64>], dim: usize) -> (f64, Vec<f64>) {
272 let max_iter = 100;
273 let tolerance = 1e-10;
274
275 let mut v = vec![1.0 / (dim as f64).sqrt(); dim];
277 let mut eigenvalue = 0.0;
278
279 for _ in 0..max_iter {
280 let mut w = vec![0.0; dim];
282 for i in 0..dim {
283 for j in 0..dim {
284 w[i] += matrix[i][j] * v[j];
285 }
286 }
287
288 let new_eigenvalue: f64 = v.iter().zip(w.iter()).map(|(vi, wi)| vi * wi).sum();
290
291 let norm: f64 = w.iter().map(|x| x * x).sum::<f64>().sqrt();
293 if norm < 1e-15 {
294 break;
295 }
296 for x in &mut w {
297 *x /= norm;
298 }
299
300 if (new_eigenvalue - eigenvalue).abs() < tolerance {
302 eigenvalue = new_eigenvalue;
303 v = w;
304 break;
305 }
306
307 eigenvalue = new_eigenvalue;
308 v = w;
309 }
310
311 (eigenvalue, v)
312 }
313}
314
315impl Default for EmbeddingReadinessAnalyzer {
316 fn default() -> Self {
317 Self::new()
318 }
319}
320
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Three well-separated classes ("A", "B", "C") of two samples each,
    /// four features wide.
    fn three_class_input() -> EmbeddingInput {
        let feature_matrix = vec![
            vec![1.0, 0.0, 0.0, 0.5],
            vec![0.9, 0.1, 0.0, 0.6],
            vec![0.0, 1.0, 0.0, 0.2],
            vec![0.1, 0.9, 0.1, 0.3],
            vec![0.0, 0.0, 1.0, 0.8],
            vec![0.0, 0.1, 0.9, 0.7],
        ];
        let labels = ["A", "A", "B", "B", "C", "C"]
            .iter()
            .map(|s| s.to_string())
            .collect();
        EmbeddingInput {
            feature_matrix,
            labels,
        }
    }

    #[test]
    fn test_valid_embedding_input() {
        let report = EmbeddingReadinessAnalyzer::new()
            .analyze(&three_class_input())
            .unwrap();

        assert_eq!(report.total_dimensions, 4);
        assert!(report.effective_dimensionality > 0);
        assert!(report.effective_dimensionality <= 4);
        assert!(report.contrastive_learning_viable);
        assert_eq!(report.min_class_count, 2);
        assert!(report.embedding_readiness_score > 0.0);
    }

    #[test]
    fn test_invalid_single_class() {
        let input = EmbeddingInput {
            feature_matrix: vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]],
            labels: vec!["A".to_string(); 3],
        };

        let report = EmbeddingReadinessAnalyzer::new().analyze(&input).unwrap();

        assert!(!report.contrastive_learning_viable);
        assert!(!report.passes);
    }

    #[test]
    fn test_empty_input() {
        let input = EmbeddingInput {
            feature_matrix: Vec::new(),
            labels: Vec::new(),
        };

        let report = EmbeddingReadinessAnalyzer::new().analyze(&input).unwrap();

        assert_eq!(report.total_dimensions, 0);
        assert_eq!(report.effective_dimensionality, 0);
        assert!(!report.contrastive_learning_viable);
    }
}