// datasynth_eval/ml/embedding_readiness.rs
1//! Embedding readiness evaluation.
2//!
3//! Validates prerequisites for representation learning by checking effective
4//! dimensionality (via eigendecomposition), contrastive learning viability
5//! (minimum class counts), and feature overlap between classes.
6
7use crate::error::EvalResult;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10
/// Input for embedding readiness analysis.
#[derive(Debug, Clone)]
pub struct EmbeddingInput {
    /// Feature matrix: rows are samples, columns are features. All rows are
    /// expected to have the same length; the first row defines the total
    /// dimensionality used by the analyzer.
    pub feature_matrix: Vec<Vec<f64>>,
    /// Class labels, one per sample (parallel to `feature_matrix` rows).
    pub labels: Vec<String>,
}
19
/// Thresholds for embedding readiness analysis.
///
/// Derives `Copy`/`PartialEq` since this is a plain value type holding a
/// single `f64` threshold.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct EmbeddingReadinessThresholds {
    /// Minimum acceptable overall embedding readiness score (0.0-1.0).
    pub min_embedding_readiness: f64,
}

impl Default for EmbeddingReadinessThresholds {
    /// Defaults to requiring a readiness score of at least 0.50.
    fn default() -> Self {
        Self {
            min_embedding_readiness: 0.50,
        }
    }
}
34
/// Results of embedding readiness analysis.
///
/// Serializable so results can be embedded in evaluation reports.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingReadinessAnalysis {
    /// Overall embedding readiness score (0.0-1.0); a weighted blend of
    /// compressibility, contrastive viability, and class balance.
    pub embedding_readiness_score: f64,
    /// Number of leading eigen-dimensions needed to capture 95% of variance.
    pub effective_dimensionality: usize,
    /// Total number of feature dimensions (columns of the input matrix).
    pub total_dimensions: usize,
    /// Whether contrastive learning is viable (>= 2 classes, each with
    /// at least 2 samples).
    pub contrastive_learning_viable: bool,
    /// Minimum number of samples in any class (0 when there are no labels).
    pub min_class_count: usize,
    /// Whether the analysis completed without raising any issues.
    pub passes: bool,
    /// Human-readable issues found during analysis.
    pub issues: Vec<String>,
}
53
/// Analyzer for embedding readiness.
///
/// Construct with [`EmbeddingReadinessAnalyzer::new`] for default thresholds
/// or [`EmbeddingReadinessAnalyzer::with_thresholds`] for custom ones.
pub struct EmbeddingReadinessAnalyzer {
    /// Thresholds applied when deciding whether an analysis passes.
    thresholds: EmbeddingReadinessThresholds,
}
58
59impl EmbeddingReadinessAnalyzer {
60    /// Create a new analyzer with default thresholds.
61    pub fn new() -> Self {
62        Self {
63            thresholds: EmbeddingReadinessThresholds::default(),
64        }
65    }
66
67    /// Create an analyzer with custom thresholds.
68    pub fn with_thresholds(thresholds: EmbeddingReadinessThresholds) -> Self {
69        Self { thresholds }
70    }
71
72    /// Analyze embedding readiness.
73    pub fn analyze(&self, input: &EmbeddingInput) -> EvalResult<EmbeddingReadinessAnalysis> {
74        let mut issues = Vec::new();
75
76        if input.feature_matrix.is_empty() {
77            return Ok(EmbeddingReadinessAnalysis {
78                embedding_readiness_score: 0.0,
79                effective_dimensionality: 0,
80                total_dimensions: 0,
81                contrastive_learning_viable: false,
82                min_class_count: 0,
83                passes: true,
84                issues: vec!["No samples provided".to_string()],
85            });
86        }
87
88        let total_dimensions = input
89            .feature_matrix
90            .first()
91            .map(std::vec::Vec::len)
92            .unwrap_or(0);
93
94        if total_dimensions == 0 {
95            return Ok(EmbeddingReadinessAnalysis {
96                embedding_readiness_score: 0.0,
97                effective_dimensionality: 0,
98                total_dimensions: 0,
99                contrastive_learning_viable: false,
100                min_class_count: 0,
101                passes: false,
102                issues: vec!["Zero-dimensional features".to_string()],
103            });
104        }
105
106        // Compute effective dimensionality
107        let effective_dimensionality =
108            self.compute_effective_dimensionality(&input.feature_matrix, total_dimensions);
109
110        // Check contrastive learning viability
111        let mut class_counts: HashMap<&str, usize> = HashMap::new();
112        for label in &input.labels {
113            *class_counts.entry(label.as_str()).or_insert(0) += 1;
114        }
115
116        let min_class_count = class_counts.values().copied().min().unwrap_or(0);
117        let num_classes = class_counts.len();
118        let contrastive_learning_viable = min_class_count >= 2 && num_classes >= 2;
119
120        if !contrastive_learning_viable {
121            issues.push(format!(
122                "Contrastive learning not viable: {num_classes} classes, min count = {min_class_count}"
123            ));
124        }
125
126        // Composite readiness score
127        let dim_ratio = if total_dimensions > 0 {
128            effective_dimensionality as f64 / total_dimensions as f64
129        } else {
130            0.0
131        };
132        // Lower effective dimensionality ratio = better (more compressible)
133        let dim_score = (1.0 - dim_ratio).clamp(0.0, 1.0);
134        let contrastive_score = if contrastive_learning_viable {
135            1.0
136        } else {
137            0.0
138        };
139        let class_balance_score = if num_classes >= 2 && min_class_count > 0 {
140            let max_count = class_counts.values().copied().max().unwrap_or(1);
141            (min_class_count as f64 / max_count as f64).clamp(0.0, 1.0)
142        } else {
143            0.0
144        };
145
146        let embedding_readiness_score =
147            (dim_score * 0.4 + contrastive_score * 0.3 + class_balance_score * 0.3).clamp(0.0, 1.0);
148
149        if embedding_readiness_score < self.thresholds.min_embedding_readiness {
150            issues.push(format!(
151                "Embedding readiness score {:.4} < {:.4} (threshold)",
152                embedding_readiness_score, self.thresholds.min_embedding_readiness
153            ));
154        }
155
156        let passes = issues.is_empty();
157
158        Ok(EmbeddingReadinessAnalysis {
159            embedding_readiness_score,
160            effective_dimensionality,
161            total_dimensions,
162            contrastive_learning_viable,
163            min_class_count,
164            passes,
165            issues,
166        })
167    }
168
169    /// Compute effective dimensionality using power iteration on the covariance matrix.
170    ///
171    /// Finds the top eigenvalues and counts how many are needed to reach 95%
172    /// of total variance.
173    fn compute_effective_dimensionality(
174        &self,
175        feature_matrix: &[Vec<f64>],
176        total_dims: usize,
177    ) -> usize {
178        let n = feature_matrix.len();
179        if n < 2 || total_dims == 0 {
180            return total_dims;
181        }
182
183        // Compute column means
184        let mut means = vec![0.0; total_dims];
185        for row in feature_matrix {
186            for (j, &val) in row.iter().enumerate().take(total_dims) {
187                means[j] += val;
188            }
189        }
190        for m in &mut means {
191            *m /= n as f64;
192        }
193
194        // Compute covariance matrix (total_dims x total_dims)
195        let dim = total_dims.min(50); // Cap for computational feasibility
196        let mut cov = vec![vec![0.0; dim]; dim];
197
198        for row in feature_matrix {
199            for i in 0..dim {
200                let di = if i < row.len() {
201                    row[i] - means[i]
202                } else {
203                    0.0
204                };
205                for j in i..dim {
206                    let dj = if j < row.len() {
207                        row[j] - means[j]
208                    } else {
209                        0.0
210                    };
211                    cov[i][j] += di * dj;
212                }
213            }
214        }
215
216        // Symmetrize and normalize
217        #[allow(clippy::needless_range_loop)]
218        for i in 0..dim {
219            for j in i..dim {
220                cov[i][j] /= (n - 1) as f64;
221                cov[j][i] = cov[i][j];
222            }
223        }
224
225        // Extract eigenvalues via repeated power iteration with deflation
226        let max_eigenvalues = dim;
227        let mut eigenvalues = Vec::new();
228        let mut work_cov = cov.clone();
229
230        for _ in 0..max_eigenvalues {
231            let (eigenvalue, eigenvector) = self.power_iteration(&work_cov, dim);
232            if eigenvalue.abs() < 1e-12 {
233                break;
234            }
235            eigenvalues.push(eigenvalue);
236
237            // Deflate: A = A - lambda * v * v^T
238            for i in 0..dim {
239                for j in 0..dim {
240                    work_cov[i][j] -= eigenvalue * eigenvector[i] * eigenvector[j];
241                }
242            }
243        }
244
245        if eigenvalues.is_empty() {
246            return total_dims;
247        }
248
249        // Count dimensions for 95% of total variance
250        let total_variance: f64 = eigenvalues.iter().filter(|&&v| v > 0.0).sum();
251        if total_variance < 1e-12 {
252            return total_dims;
253        }
254
255        let target = 0.95 * total_variance;
256        let mut cumulative = 0.0;
257        let mut effective = 0;
258
259        for &ev in &eigenvalues {
260            if ev <= 0.0 {
261                continue;
262            }
263            cumulative += ev;
264            effective += 1;
265            if cumulative >= target {
266                break;
267            }
268        }
269
270        effective.max(1)
271    }
272
273    /// Power iteration to find the largest eigenvalue and corresponding eigenvector.
274    fn power_iteration(&self, matrix: &[Vec<f64>], dim: usize) -> (f64, Vec<f64>) {
275        let max_iter = 100;
276        let tolerance = 1e-10;
277
278        // Initialize with a non-zero vector
279        let mut v = vec![1.0 / (dim as f64).sqrt(); dim];
280        let mut eigenvalue = 0.0;
281
282        for _ in 0..max_iter {
283            // w = A * v
284            let mut w = vec![0.0; dim];
285            for i in 0..dim {
286                for j in 0..dim {
287                    w[i] += matrix[i][j] * v[j];
288                }
289            }
290
291            // Compute eigenvalue as v^T * w
292            let new_eigenvalue: f64 = v.iter().zip(w.iter()).map(|(vi, wi)| vi * wi).sum();
293
294            // Normalize w
295            let norm: f64 = w.iter().map(|x| x * x).sum::<f64>().sqrt();
296            if norm < 1e-15 {
297                break;
298            }
299            for x in &mut w {
300                *x /= norm;
301            }
302
303            // Check convergence
304            if (new_eigenvalue - eigenvalue).abs() < tolerance {
305                eigenvalue = new_eigenvalue;
306                v = w;
307                break;
308            }
309
310            eigenvalue = new_eigenvalue;
311            v = w;
312        }
313
314        (eigenvalue, v)
315    }
316}
317
318impl Default for EmbeddingReadinessAnalyzer {
319    fn default() -> Self {
320        Self::new()
321    }
322}
323
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    // Three well-separated classes, two samples each: the analyzer should
    // report contrastive viability and a positive readiness score.
    #[test]
    fn test_valid_embedding_input() {
        let input = EmbeddingInput {
            feature_matrix: vec![
                vec![1.0, 0.0, 0.0, 0.5],
                vec![0.9, 0.1, 0.0, 0.6],
                vec![0.0, 1.0, 0.0, 0.2],
                vec![0.1, 0.9, 0.1, 0.3],
                vec![0.0, 0.0, 1.0, 0.8],
                vec![0.0, 0.1, 0.9, 0.7],
            ],
            labels: vec![
                "A".into(),
                "A".into(),
                "B".into(),
                "B".into(),
                "C".into(),
                "C".into(),
            ],
        };

        let analyzer = EmbeddingReadinessAnalyzer::new();
        let result = analyzer.analyze(&input).unwrap();

        assert_eq!(result.total_dimensions, 4);
        // Effective dimensionality must be positive and bounded by the
        // number of feature columns.
        assert!(result.effective_dimensionality > 0);
        assert!(result.effective_dimensionality <= 4);
        assert!(result.contrastive_learning_viable);
        assert_eq!(result.min_class_count, 2);
        assert!(result.embedding_readiness_score > 0.0);
    }

    // A single class cannot support contrastive learning (no negative
    // pairs), so viability and overall pass must both be false.
    #[test]
    fn test_invalid_single_class() {
        let input = EmbeddingInput {
            feature_matrix: vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]],
            labels: vec!["A".into(), "A".into(), "A".into()],
        };

        let analyzer = EmbeddingReadinessAnalyzer::new();
        let result = analyzer.analyze(&input).unwrap();

        assert!(!result.contrastive_learning_viable);
        assert!(!result.passes);
    }

    // Empty input is handled gracefully: zeroed-out analysis, no panic.
    #[test]
    fn test_empty_input() {
        let input = EmbeddingInput {
            feature_matrix: Vec::new(),
            labels: Vec::new(),
        };

        let analyzer = EmbeddingReadinessAnalyzer::new();
        let result = analyzer.analyze(&input).unwrap();

        assert_eq!(result.total_dimensions, 0);
        assert_eq!(result.effective_dimensionality, 0);
        assert!(!result.contrastive_learning_viable);
    }
}
389}