// datasynth_eval/ml/embedding_readiness.rs
//! Embedding readiness evaluation.
//!
//! Validates prerequisites for representation learning by checking effective
//! dimensionality (via eigendecomposition), contrastive learning viability
//! (minimum class counts), and feature overlap between classes.
use std::collections::HashMap;

use serde::{Deserialize, Serialize};

use crate::error::EvalResult;
/// Input for embedding readiness analysis.
#[derive(Debug, Clone)]
pub struct EmbeddingInput {
    /// Sample-by-feature matrix: each inner vector holds one sample's features.
    pub feature_matrix: Vec<Vec<f64>>,
    /// Class label for each sample, parallel to `feature_matrix` rows.
    pub labels: Vec<String>,
}
/// Thresholds for embedding readiness analysis.
#[derive(Debug, Clone)]
pub struct EmbeddingReadinessThresholds {
    /// Smallest composite readiness score (0.0-1.0) considered acceptable.
    pub min_embedding_readiness: f64,
}
27impl Default for EmbeddingReadinessThresholds {
28    fn default() -> Self {
29        Self {
30            min_embedding_readiness: 0.50,
31        }
32    }
33}
34
35/// Results of embedding readiness analysis.
36#[derive(Debug, Clone, Serialize, Deserialize)]
37pub struct EmbeddingReadinessAnalysis {
38    /// Overall embedding readiness score (0.0-1.0).
39    pub embedding_readiness_score: f64,
40    /// Number of dimensions needed to capture 95% of variance.
41    pub effective_dimensionality: usize,
42    /// Total number of feature dimensions.
43    pub total_dimensions: usize,
44    /// Whether contrastive learning is viable (min class count >= 2).
45    pub contrastive_learning_viable: bool,
46    /// Minimum number of samples in any class.
47    pub min_class_count: usize,
48    /// Whether the analysis passes all thresholds.
49    pub passes: bool,
50    /// Issues found during analysis.
51    pub issues: Vec<String>,
52}
53
54/// Analyzer for embedding readiness.
55pub struct EmbeddingReadinessAnalyzer {
56    thresholds: EmbeddingReadinessThresholds,
57}
58
59impl EmbeddingReadinessAnalyzer {
60    /// Create a new analyzer with default thresholds.
61    pub fn new() -> Self {
62        Self {
63            thresholds: EmbeddingReadinessThresholds::default(),
64        }
65    }
66
67    /// Create an analyzer with custom thresholds.
68    pub fn with_thresholds(thresholds: EmbeddingReadinessThresholds) -> Self {
69        Self { thresholds }
70    }
71
72    /// Analyze embedding readiness.
73    pub fn analyze(&self, input: &EmbeddingInput) -> EvalResult<EmbeddingReadinessAnalysis> {
74        let mut issues = Vec::new();
75
76        if input.feature_matrix.is_empty() {
77            return Ok(EmbeddingReadinessAnalysis {
78                embedding_readiness_score: 0.0,
79                effective_dimensionality: 0,
80                total_dimensions: 0,
81                contrastive_learning_viable: false,
82                min_class_count: 0,
83                passes: true,
84                issues: vec!["No samples provided".to_string()],
85            });
86        }
87
88        let total_dimensions = input.feature_matrix.first().map(|r| r.len()).unwrap_or(0);
89
90        if total_dimensions == 0 {
91            return Ok(EmbeddingReadinessAnalysis {
92                embedding_readiness_score: 0.0,
93                effective_dimensionality: 0,
94                total_dimensions: 0,
95                contrastive_learning_viable: false,
96                min_class_count: 0,
97                passes: false,
98                issues: vec!["Zero-dimensional features".to_string()],
99            });
100        }
101
102        // Compute effective dimensionality
103        let effective_dimensionality =
104            self.compute_effective_dimensionality(&input.feature_matrix, total_dimensions);
105
106        // Check contrastive learning viability
107        let mut class_counts: HashMap<&str, usize> = HashMap::new();
108        for label in &input.labels {
109            *class_counts.entry(label.as_str()).or_insert(0) += 1;
110        }
111
112        let min_class_count = class_counts.values().copied().min().unwrap_or(0);
113        let num_classes = class_counts.len();
114        let contrastive_learning_viable = min_class_count >= 2 && num_classes >= 2;
115
116        if !contrastive_learning_viable {
117            issues.push(format!(
118                "Contrastive learning not viable: {} classes, min count = {}",
119                num_classes, min_class_count
120            ));
121        }
122
123        // Composite readiness score
124        let dim_ratio = if total_dimensions > 0 {
125            effective_dimensionality as f64 / total_dimensions as f64
126        } else {
127            0.0
128        };
129        // Lower effective dimensionality ratio = better (more compressible)
130        let dim_score = (1.0 - dim_ratio).clamp(0.0, 1.0);
131        let contrastive_score = if contrastive_learning_viable {
132            1.0
133        } else {
134            0.0
135        };
136        let class_balance_score = if num_classes >= 2 && min_class_count > 0 {
137            let max_count = class_counts.values().copied().max().unwrap_or(1);
138            (min_class_count as f64 / max_count as f64).clamp(0.0, 1.0)
139        } else {
140            0.0
141        };
142
143        let embedding_readiness_score =
144            (dim_score * 0.4 + contrastive_score * 0.3 + class_balance_score * 0.3).clamp(0.0, 1.0);
145
146        if embedding_readiness_score < self.thresholds.min_embedding_readiness {
147            issues.push(format!(
148                "Embedding readiness score {:.4} < {:.4} (threshold)",
149                embedding_readiness_score, self.thresholds.min_embedding_readiness
150            ));
151        }
152
153        let passes = issues.is_empty();
154
155        Ok(EmbeddingReadinessAnalysis {
156            embedding_readiness_score,
157            effective_dimensionality,
158            total_dimensions,
159            contrastive_learning_viable,
160            min_class_count,
161            passes,
162            issues,
163        })
164    }
165
166    /// Compute effective dimensionality using power iteration on the covariance matrix.
167    ///
168    /// Finds the top eigenvalues and counts how many are needed to reach 95%
169    /// of total variance.
170    fn compute_effective_dimensionality(
171        &self,
172        feature_matrix: &[Vec<f64>],
173        total_dims: usize,
174    ) -> usize {
175        let n = feature_matrix.len();
176        if n < 2 || total_dims == 0 {
177            return total_dims;
178        }
179
180        // Compute column means
181        let mut means = vec![0.0; total_dims];
182        for row in feature_matrix {
183            for (j, &val) in row.iter().enumerate().take(total_dims) {
184                means[j] += val;
185            }
186        }
187        for m in &mut means {
188            *m /= n as f64;
189        }
190
191        // Compute covariance matrix (total_dims x total_dims)
192        let dim = total_dims.min(50); // Cap for computational feasibility
193        let mut cov = vec![vec![0.0; dim]; dim];
194
195        for row in feature_matrix {
196            for i in 0..dim {
197                let di = if i < row.len() {
198                    row[i] - means[i]
199                } else {
200                    0.0
201                };
202                for j in i..dim {
203                    let dj = if j < row.len() {
204                        row[j] - means[j]
205                    } else {
206                        0.0
207                    };
208                    cov[i][j] += di * dj;
209                }
210            }
211        }
212
213        // Symmetrize and normalize
214        #[allow(clippy::needless_range_loop)]
215        for i in 0..dim {
216            for j in i..dim {
217                cov[i][j] /= (n - 1) as f64;
218                cov[j][i] = cov[i][j];
219            }
220        }
221
222        // Extract eigenvalues via repeated power iteration with deflation
223        let max_eigenvalues = dim;
224        let mut eigenvalues = Vec::new();
225        let mut work_cov = cov.clone();
226
227        for _ in 0..max_eigenvalues {
228            let (eigenvalue, eigenvector) = self.power_iteration(&work_cov, dim);
229            if eigenvalue.abs() < 1e-12 {
230                break;
231            }
232            eigenvalues.push(eigenvalue);
233
234            // Deflate: A = A - lambda * v * v^T
235            for i in 0..dim {
236                for j in 0..dim {
237                    work_cov[i][j] -= eigenvalue * eigenvector[i] * eigenvector[j];
238                }
239            }
240        }
241
242        if eigenvalues.is_empty() {
243            return total_dims;
244        }
245
246        // Count dimensions for 95% of total variance
247        let total_variance: f64 = eigenvalues.iter().filter(|&&v| v > 0.0).sum();
248        if total_variance < 1e-12 {
249            return total_dims;
250        }
251
252        let target = 0.95 * total_variance;
253        let mut cumulative = 0.0;
254        let mut effective = 0;
255
256        for &ev in &eigenvalues {
257            if ev <= 0.0 {
258                continue;
259            }
260            cumulative += ev;
261            effective += 1;
262            if cumulative >= target {
263                break;
264            }
265        }
266
267        effective.max(1)
268    }
269
270    /// Power iteration to find the largest eigenvalue and corresponding eigenvector.
271    fn power_iteration(&self, matrix: &[Vec<f64>], dim: usize) -> (f64, Vec<f64>) {
272        let max_iter = 100;
273        let tolerance = 1e-10;
274
275        // Initialize with a non-zero vector
276        let mut v = vec![1.0 / (dim as f64).sqrt(); dim];
277        let mut eigenvalue = 0.0;
278
279        for _ in 0..max_iter {
280            // w = A * v
281            let mut w = vec![0.0; dim];
282            for i in 0..dim {
283                for j in 0..dim {
284                    w[i] += matrix[i][j] * v[j];
285                }
286            }
287
288            // Compute eigenvalue as v^T * w
289            let new_eigenvalue: f64 = v.iter().zip(w.iter()).map(|(vi, wi)| vi * wi).sum();
290
291            // Normalize w
292            let norm: f64 = w.iter().map(|x| x * x).sum::<f64>().sqrt();
293            if norm < 1e-15 {
294                break;
295            }
296            for x in &mut w {
297                *x /= norm;
298            }
299
300            // Check convergence
301            if (new_eigenvalue - eigenvalue).abs() < tolerance {
302                eigenvalue = new_eigenvalue;
303                v = w;
304                break;
305            }
306
307            eigenvalue = new_eigenvalue;
308            v = w;
309        }
310
311        (eigenvalue, v)
312    }
313}
314
315impl Default for EmbeddingReadinessAnalyzer {
316    fn default() -> Self {
317        Self::new()
318    }
319}
320
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Convenience: build owned labels from string slices.
    fn labels_of(names: &[&str]) -> Vec<String> {
        names.iter().map(|s| (*s).to_string()).collect()
    }

    #[test]
    fn test_valid_embedding_input() {
        let feature_matrix = vec![
            vec![1.0, 0.0, 0.0, 0.5],
            vec![0.9, 0.1, 0.0, 0.6],
            vec![0.0, 1.0, 0.0, 0.2],
            vec![0.1, 0.9, 0.1, 0.3],
            vec![0.0, 0.0, 1.0, 0.8],
            vec![0.0, 0.1, 0.9, 0.7],
        ];
        let input = EmbeddingInput {
            feature_matrix,
            labels: labels_of(&["A", "A", "B", "B", "C", "C"]),
        };

        let report = EmbeddingReadinessAnalyzer::new().analyze(&input).unwrap();

        assert_eq!(report.total_dimensions, 4);
        assert!(report.effective_dimensionality > 0);
        assert!(report.effective_dimensionality <= 4);
        assert!(report.contrastive_learning_viable);
        assert_eq!(report.min_class_count, 2);
        assert!(report.embedding_readiness_score > 0.0);
    }

    #[test]
    fn test_invalid_single_class() {
        // All samples share one label, so contrastive learning cannot work.
        let input = EmbeddingInput {
            feature_matrix: vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]],
            labels: labels_of(&["A", "A", "A"]),
        };

        let report = EmbeddingReadinessAnalyzer::new().analyze(&input).unwrap();

        assert!(!report.contrastive_learning_viable);
        assert!(!report.passes);
    }

    #[test]
    fn test_empty_input() {
        let input = EmbeddingInput {
            feature_matrix: Vec::new(),
            labels: Vec::new(),
        };

        let report = EmbeddingReadinessAnalyzer::new().analyze(&input).unwrap();

        assert_eq!(report.total_dimensions, 0);
        assert_eq!(report.effective_dimensionality, 0);
        assert!(!report.contrastive_learning_viable);
    }
}