1use crate::error::EvalResult;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10
/// Raw input for embedding-readiness analysis.
#[derive(Debug, Clone)]
pub struct EmbeddingInput {
    /// Row-major feature matrix: one inner `Vec` per sample.
    pub feature_matrix: Vec<Vec<f64>>,
    /// Class label per sample; assumed parallel to `feature_matrix`
    /// (same length) — not validated by the analyzer. TODO confirm callers
    /// guarantee this.
    pub labels: Vec<String>,
}
19
/// Tunable pass/fail thresholds for [`EmbeddingReadinessAnalyzer`].
#[derive(Debug, Clone)]
pub struct EmbeddingReadinessThresholds {
    /// Minimum composite readiness score (in `[0, 1]`) required to pass.
    pub min_embedding_readiness: f64,
}
26
27impl Default for EmbeddingReadinessThresholds {
28 fn default() -> Self {
29 Self {
30 min_embedding_readiness: 0.50,
31 }
32 }
33}
34
/// Result of an embedding-readiness analysis run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingReadinessAnalysis {
    /// Composite readiness score in `[0, 1]` (weighted blend of
    /// dimensionality, contrastive viability, and class balance).
    pub embedding_readiness_score: f64,
    /// Estimated number of principal components covering ~95% of variance.
    pub effective_dimensionality: usize,
    /// Width of the feature matrix (taken from the first row).
    pub total_dimensions: usize,
    /// True when there are at least 2 classes with at least 2 samples each.
    pub contrastive_learning_viable: bool,
    /// Sample count of the smallest class (0 when there are no labels).
    pub min_class_count: usize,
    /// True when no issues were recorded during analysis.
    pub passes: bool,
    /// Human-readable diagnostics explaining any failure.
    pub issues: Vec<String>,
}
53
/// Analyzer that scores a labeled feature matrix for embedding training.
pub struct EmbeddingReadinessAnalyzer {
    /// Pass/fail thresholds applied to the composite score.
    thresholds: EmbeddingReadinessThresholds,
}
58
59impl EmbeddingReadinessAnalyzer {
60 pub fn new() -> Self {
62 Self {
63 thresholds: EmbeddingReadinessThresholds::default(),
64 }
65 }
66
67 pub fn with_thresholds(thresholds: EmbeddingReadinessThresholds) -> Self {
69 Self { thresholds }
70 }
71
72 pub fn analyze(&self, input: &EmbeddingInput) -> EvalResult<EmbeddingReadinessAnalysis> {
74 let mut issues = Vec::new();
75
76 if input.feature_matrix.is_empty() {
77 return Ok(EmbeddingReadinessAnalysis {
78 embedding_readiness_score: 0.0,
79 effective_dimensionality: 0,
80 total_dimensions: 0,
81 contrastive_learning_viable: false,
82 min_class_count: 0,
83 passes: true,
84 issues: vec!["No samples provided".to_string()],
85 });
86 }
87
88 let total_dimensions = input
89 .feature_matrix
90 .first()
91 .map(std::vec::Vec::len)
92 .unwrap_or(0);
93
94 if total_dimensions == 0 {
95 return Ok(EmbeddingReadinessAnalysis {
96 embedding_readiness_score: 0.0,
97 effective_dimensionality: 0,
98 total_dimensions: 0,
99 contrastive_learning_viable: false,
100 min_class_count: 0,
101 passes: false,
102 issues: vec!["Zero-dimensional features".to_string()],
103 });
104 }
105
106 let effective_dimensionality =
108 self.compute_effective_dimensionality(&input.feature_matrix, total_dimensions);
109
110 let mut class_counts: HashMap<&str, usize> = HashMap::new();
112 for label in &input.labels {
113 *class_counts.entry(label.as_str()).or_insert(0) += 1;
114 }
115
116 let min_class_count = class_counts.values().copied().min().unwrap_or(0);
117 let num_classes = class_counts.len();
118 let contrastive_learning_viable = min_class_count >= 2 && num_classes >= 2;
119
120 if !contrastive_learning_viable {
121 issues.push(format!(
122 "Contrastive learning not viable: {num_classes} classes, min count = {min_class_count}"
123 ));
124 }
125
126 let dim_ratio = if total_dimensions > 0 {
128 effective_dimensionality as f64 / total_dimensions as f64
129 } else {
130 0.0
131 };
132 let dim_score = (1.0 - dim_ratio).clamp(0.0, 1.0);
134 let contrastive_score = if contrastive_learning_viable {
135 1.0
136 } else {
137 0.0
138 };
139 let class_balance_score = if num_classes >= 2 && min_class_count > 0 {
140 let max_count = class_counts.values().copied().max().unwrap_or(1);
141 (min_class_count as f64 / max_count as f64).clamp(0.0, 1.0)
142 } else {
143 0.0
144 };
145
146 let embedding_readiness_score =
147 (dim_score * 0.4 + contrastive_score * 0.3 + class_balance_score * 0.3).clamp(0.0, 1.0);
148
149 if embedding_readiness_score < self.thresholds.min_embedding_readiness {
150 issues.push(format!(
151 "Embedding readiness score {:.4} < {:.4} (threshold)",
152 embedding_readiness_score, self.thresholds.min_embedding_readiness
153 ));
154 }
155
156 let passes = issues.is_empty();
157
158 Ok(EmbeddingReadinessAnalysis {
159 embedding_readiness_score,
160 effective_dimensionality,
161 total_dimensions,
162 contrastive_learning_viable,
163 min_class_count,
164 passes,
165 issues,
166 })
167 }
168
169 fn compute_effective_dimensionality(
174 &self,
175 feature_matrix: &[Vec<f64>],
176 total_dims: usize,
177 ) -> usize {
178 let n = feature_matrix.len();
179 if n < 2 || total_dims == 0 {
180 return total_dims;
181 }
182
183 let mut means = vec![0.0; total_dims];
185 for row in feature_matrix {
186 for (j, &val) in row.iter().enumerate().take(total_dims) {
187 means[j] += val;
188 }
189 }
190 for m in &mut means {
191 *m /= n as f64;
192 }
193
194 let dim = total_dims.min(50); let mut cov = vec![vec![0.0; dim]; dim];
197
198 for row in feature_matrix {
199 for i in 0..dim {
200 let di = if i < row.len() {
201 row[i] - means[i]
202 } else {
203 0.0
204 };
205 for j in i..dim {
206 let dj = if j < row.len() {
207 row[j] - means[j]
208 } else {
209 0.0
210 };
211 cov[i][j] += di * dj;
212 }
213 }
214 }
215
216 #[allow(clippy::needless_range_loop)]
218 for i in 0..dim {
219 for j in i..dim {
220 cov[i][j] /= (n - 1) as f64;
221 cov[j][i] = cov[i][j];
222 }
223 }
224
225 let max_eigenvalues = dim;
227 let mut eigenvalues = Vec::new();
228 let mut work_cov = cov.clone();
229
230 for _ in 0..max_eigenvalues {
231 let (eigenvalue, eigenvector) = self.power_iteration(&work_cov, dim);
232 if eigenvalue.abs() < 1e-12 {
233 break;
234 }
235 eigenvalues.push(eigenvalue);
236
237 for i in 0..dim {
239 for j in 0..dim {
240 work_cov[i][j] -= eigenvalue * eigenvector[i] * eigenvector[j];
241 }
242 }
243 }
244
245 if eigenvalues.is_empty() {
246 return total_dims;
247 }
248
249 let total_variance: f64 = eigenvalues.iter().filter(|&&v| v > 0.0).sum();
251 if total_variance < 1e-12 {
252 return total_dims;
253 }
254
255 let target = 0.95 * total_variance;
256 let mut cumulative = 0.0;
257 let mut effective = 0;
258
259 for &ev in &eigenvalues {
260 if ev <= 0.0 {
261 continue;
262 }
263 cumulative += ev;
264 effective += 1;
265 if cumulative >= target {
266 break;
267 }
268 }
269
270 effective.max(1)
271 }
272
273 fn power_iteration(&self, matrix: &[Vec<f64>], dim: usize) -> (f64, Vec<f64>) {
275 let max_iter = 100;
276 let tolerance = 1e-10;
277
278 let mut v = vec![1.0 / (dim as f64).sqrt(); dim];
280 let mut eigenvalue = 0.0;
281
282 for _ in 0..max_iter {
283 let mut w = vec![0.0; dim];
285 for i in 0..dim {
286 for j in 0..dim {
287 w[i] += matrix[i][j] * v[j];
288 }
289 }
290
291 let new_eigenvalue: f64 = v.iter().zip(w.iter()).map(|(vi, wi)| vi * wi).sum();
293
294 let norm: f64 = w.iter().map(|x| x * x).sum::<f64>().sqrt();
296 if norm < 1e-15 {
297 break;
298 }
299 for x in &mut w {
300 *x /= norm;
301 }
302
303 if (new_eigenvalue - eigenvalue).abs() < tolerance {
305 eigenvalue = new_eigenvalue;
306 v = w;
307 break;
308 }
309
310 eigenvalue = new_eigenvalue;
311 v = w;
312 }
313
314 (eigenvalue, v)
315 }
316}
317
318impl Default for EmbeddingReadinessAnalyzer {
319 fn default() -> Self {
320 Self::new()
321 }
322}
323
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    // Three well-separated classes with two samples each: should be viable
    // for contrastive learning and produce a positive readiness score.
    #[test]
    fn test_valid_embedding_input() {
        let input = EmbeddingInput {
            feature_matrix: vec![
                vec![1.0, 0.0, 0.0, 0.5],
                vec![0.9, 0.1, 0.0, 0.6],
                vec![0.0, 1.0, 0.0, 0.2],
                vec![0.1, 0.9, 0.1, 0.3],
                vec![0.0, 0.0, 1.0, 0.8],
                vec![0.0, 0.1, 0.9, 0.7],
            ],
            labels: vec![
                "A".into(),
                "A".into(),
                "B".into(),
                "B".into(),
                "C".into(),
                "C".into(),
            ],
        };

        let analyzer = EmbeddingReadinessAnalyzer::new();
        let result = analyzer.analyze(&input).unwrap();

        assert_eq!(result.total_dimensions, 4);
        // Effective dimensionality is a heuristic estimate; only bound it.
        assert!(result.effective_dimensionality > 0);
        assert!(result.effective_dimensionality <= 4);
        assert!(result.contrastive_learning_viable);
        assert_eq!(result.min_class_count, 2);
        assert!(result.embedding_readiness_score > 0.0);
    }

    // A single class can never form negative pairs, so the analysis must
    // flag contrastive learning as not viable and fail overall.
    #[test]
    fn test_invalid_single_class() {
        let input = EmbeddingInput {
            feature_matrix: vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]],
            labels: vec!["A".into(), "A".into(), "A".into()],
        };

        let analyzer = EmbeddingReadinessAnalyzer::new();
        let result = analyzer.analyze(&input).unwrap();

        assert!(!result.contrastive_learning_viable);
        assert!(!result.passes);
    }

    // Empty input takes the early-return path: all counts zero, not viable.
    #[test]
    fn test_empty_input() {
        let input = EmbeddingInput {
            feature_matrix: Vec::new(),
            labels: Vec::new(),
        };

        let analyzer = EmbeddingReadinessAnalyzer::new();
        let result = analyzer.analyze(&input).unwrap();

        assert_eq!(result.total_dimensions, 0);
        assert_eq!(result.effective_dimensionality, 0);
        assert!(!result.contrastive_learning_viable);
    }
}