// proof_engine/ml/embeddings.rs
//! Embedding space visualization: PCA, t-SNE, UMAP, nearest neighbors.

use glam::Vec2;

/// A collection of embedding vectors with optional labels.
#[derive(Debug, Clone)]
pub struct EmbeddingSpace {
    pub vectors: Vec<Vec<f32>>,
    pub labels: Vec<String>,
    pub dim: usize,
}

impl EmbeddingSpace {
    /// Creates an empty space whose vectors must all have length `dim`.
    pub fn new(dim: usize) -> Self {
        EmbeddingSpace {
            vectors: Vec::new(),
            labels: Vec::new(),
            dim,
        }
    }

    /// Appends a labeled vector.
    ///
    /// # Panics
    /// Panics if `vector.len() != self.dim`.
    pub fn add(&mut self, vector: Vec<f32>, label: String) {
        assert_eq!(vector.len(), self.dim);
        self.vectors.push(vector);
        self.labels.push(label);
    }

    /// Number of stored vectors.
    pub fn len(&self) -> usize {
        self.vectors.len()
    }

    /// True when no vectors have been added yet.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
/// Cosine similarity between two equal-length vectors.
///
/// Returns 0.0 when either vector has (near-)zero norm, avoiding a NaN
/// from dividing by zero.
///
/// # Panics
/// Panics if the slices differ in length.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len());
    // Single pass accumulating the dot product and both squared norms;
    // each accumulator sums the same terms in the same order as before.
    let (mut dot, mut sq_a, mut sq_b) = (0.0f32, 0.0f32, 0.0f32);
    for (&x, &y) in a.iter().zip(b) {
        dot += x * y;
        sq_a += x * x;
        sq_b += y * y;
    }
    let denom = sq_a.sqrt() * sq_b.sqrt();
    if denom < 1e-12 { 0.0 } else { dot / denom }
}
42
43/// Find the k nearest neighbors of a query vector, returning (index, similarity).
44pub fn nearest_neighbors(space: &EmbeddingSpace, query: &[f32], k: usize) -> Vec<(usize, f32)> {
45    assert_eq!(query.len(), space.dim);
46    let mut scored: Vec<(usize, f32)> = space.vectors.iter().enumerate()
47        .map(|(i, v)| (i, cosine_similarity(query, v)))
48        .collect();
49    scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
50    scored.truncate(k);
51    scored
52}
53
54// ── PCA ─────────────────────────────────────────────────────────────────
55
56/// PCA dimensionality reduction via power iteration for top-k eigenvectors.
57pub fn pca(vectors: &[Vec<f32>], target_dim: usize) -> Vec<Vec<f32>> {
58    if vectors.is_empty() { return vec![]; }
59    let n = vectors.len();
60    let d = vectors[0].len();
61    let target_dim = target_dim.min(d);
62
63    // Compute mean
64    let mut mean = vec![0.0f32; d];
65    for v in vectors {
66        for (i, &val) in v.iter().enumerate() {
67            mean[i] += val;
68        }
69    }
70    for m in &mut mean { *m /= n as f32; }
71
72    // Center data
73    let centered: Vec<Vec<f32>> = vectors.iter()
74        .map(|v| v.iter().zip(&mean).map(|(a, b)| a - b).collect())
75        .collect();
76
77    // Compute covariance matrix (d x d) — for efficiency with large d,
78    // we compute X^T X / n which is d x d.
79    let mut cov = vec![vec![0.0f32; d]; d];
80    for v in &centered {
81        for i in 0..d {
82            for j in i..d {
83                let val = v[i] * v[j];
84                cov[i][j] += val;
85                if i != j { cov[j][i] += val; }
86            }
87        }
88    }
89    let nf = n as f32;
90    for row in &mut cov {
91        for val in row.iter_mut() { *val /= nf; }
92    }
93
94    // Power iteration for top eigenvectors
95    let mut components = Vec::with_capacity(target_dim);
96    let mut deflated_cov = cov;
97
98    for _ in 0..target_dim {
99        let eigvec = power_iteration(&deflated_cov, d, 100);
100        // Deflate: C = C - lambda * v * v^T
101        // lambda = v^T C v
102        let mut lambda = 0.0f32;
103        for i in 0..d {
104            let mut row_dot = 0.0f32;
105            for j in 0..d {
106                row_dot += deflated_cov[i][j] * eigvec[j];
107            }
108            lambda += eigvec[i] * row_dot;
109        }
110        for i in 0..d {
111            for j in 0..d {
112                deflated_cov[i][j] -= lambda * eigvec[i] * eigvec[j];
113            }
114        }
115        components.push(eigvec);
116    }
117
118    // Project data onto components
119    centered.iter().map(|v| {
120        components.iter().map(|comp| {
121            v.iter().zip(comp).map(|(a, b)| a * b).sum()
122        }).collect()
123    }).collect()
124}
/// Approximates the dominant eigenvector of a symmetric `d x d` matrix by
/// repeated multiply-and-normalize.
fn power_iteration(matrix: &[Vec<f32>], d: usize, iterations: usize) -> Vec<f32> {
    // Deterministic seed with distinct positive entries so the start vector
    // is unlikely to be orthogonal to the dominant eigenvector. (The
    // original also set v[0] = 1.0 first — a dead store immediately
    // overwritten by this loop — which has been removed.)
    let mut v: Vec<f32> = (0..d).map(|i| 1.0 / (1.0 + i as f32)).collect();
    normalize(&mut v);

    for _ in 0..iterations {
        // v <- normalize(M v)
        let mut next = vec![0.0f32; d];
        for i in 0..d {
            for j in 0..d {
                next[i] += matrix[i][j] * v[j];
            }
        }
        normalize(&mut next);
        v = next;
    }
    v
}

/// Scales `v` to unit Euclidean length in place; near-zero vectors are left
/// untouched to avoid division by zero.
fn normalize(v: &mut [f32]) {
    let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm > 1e-12 {
        for x in v.iter_mut() {
            *x /= norm;
        }
    }
}
153
154// ── t-SNE ───────────────────────────────────────────────────────────────
155
156/// Simplified t-SNE implementation.
157pub fn tsne(vectors: &[Vec<f32>], target_dim: usize, perplexity: f32, iterations: usize) -> Vec<Vec2> {
158    let n = vectors.len();
159    if n == 0 { return vec![]; }
160    let target_dim = target_dim.min(2); // we output Vec2 so max 2
161    let _ = target_dim;
162
163    // Compute pairwise squared distances
164    let mut dist2 = vec![vec![0.0f32; n]; n];
165    for i in 0..n {
166        for j in i + 1..n {
167            let d: f32 = vectors[i].iter().zip(&vectors[j])
168                .map(|(a, b)| (a - b) * (a - b)).sum();
169            dist2[i][j] = d;
170            dist2[j][i] = d;
171        }
172    }
173
174    // Compute conditional probabilities P(j|i) using binary search for sigma
175    let mut p = vec![vec![0.0f32; n]; n];
176    let target_entropy = perplexity.ln();
177
178    for i in 0..n {
179        let mut sigma = 1.0f32;
180        // Binary search for sigma that matches target perplexity
181        let mut lo = 1e-10f32;
182        let mut hi = 1e4f32;
183        for _ in 0..50 {
184            sigma = (lo + hi) / 2.0;
185            let mut sum_exp = 0.0f32;
186            for j in 0..n {
187                if j != i {
188                    sum_exp += (-dist2[i][j] / (2.0 * sigma * sigma)).exp();
189                }
190            }
191            if sum_exp < 1e-12 { lo = sigma; continue; }
192            let mut entropy = 0.0f32;
193            for j in 0..n {
194                if j != i {
195                    let pj = (-dist2[i][j] / (2.0 * sigma * sigma)).exp() / sum_exp;
196                    if pj > 1e-12 { entropy -= pj * pj.ln(); }
197                }
198            }
199            if entropy > target_entropy { hi = sigma; } else { lo = sigma; }
200        }
201        // Set probabilities with this sigma
202        let mut sum_exp = 0.0f32;
203        for j in 0..n {
204            if j != i {
205                sum_exp += (-dist2[i][j] / (2.0 * sigma * sigma)).exp();
206            }
207        }
208        if sum_exp > 1e-12 {
209            for j in 0..n {
210                if j != i {
211                    p[i][j] = (-dist2[i][j] / (2.0 * sigma * sigma)).exp() / sum_exp;
212                }
213            }
214        }
215    }
216
217    // Symmetrize: P = (P + P^T) / (2N)
218    for i in 0..n {
219        for j in i + 1..n {
220            let sym = (p[i][j] + p[j][i]) / (2.0 * n as f32);
221            p[i][j] = sym.max(1e-12);
222            p[j][i] = sym.max(1e-12);
223        }
224    }
225
226    // Initialize embedding with small random values
227    let mut y: Vec<[f32; 2]> = Vec::with_capacity(n);
228    let mut rng = 42u64;
229    for _ in 0..n {
230        rng ^= rng << 13; rng ^= rng >> 7; rng ^= rng << 17;
231        let x = (rng as u32 as f32 / u32::MAX as f32 - 0.5) * 0.01;
232        rng ^= rng << 13; rng ^= rng >> 7; rng ^= rng << 17;
233        let y_val = (rng as u32 as f32 / u32::MAX as f32 - 0.5) * 0.01;
234        y.push([x, y_val]);
235    }
236
237    let lr = 200.0f32;
238    let momentum = 0.8f32;
239    let mut gains = vec![[1.0f32; 2]; n];
240    let mut vy = vec![[0.0f32; 2]; n];
241
242    for _iter in 0..iterations {
243        // Compute Q distribution (Student-t with 1 DOF)
244        let mut q_unnorm = vec![vec![0.0f32; n]; n];
245        let mut q_sum = 0.0f32;
246        for i in 0..n {
247            for j in i + 1..n {
248                let dx = y[i][0] - y[j][0];
249                let dy = y[i][1] - y[j][1];
250                let val = 1.0 / (1.0 + dx * dx + dy * dy);
251                q_unnorm[i][j] = val;
252                q_unnorm[j][i] = val;
253                q_sum += 2.0 * val;
254            }
255        }
256        if q_sum < 1e-12 { q_sum = 1e-12; }
257
258        // Compute gradients
259        let mut grad = vec![[0.0f32; 2]; n];
260        for i in 0..n {
261            for j in 0..n {
262                if i == j { continue; }
263                let q_ij = q_unnorm[i][j] / q_sum;
264                let mult = 4.0 * (p[i][j] - q_ij) * q_unnorm[i][j];
265                grad[i][0] += mult * (y[i][0] - y[j][0]);
266                grad[i][1] += mult * (y[i][1] - y[j][1]);
267            }
268        }
269
270        // Update
271        for i in 0..n {
272            for d in 0..2 {
273                // Adaptive gains
274                if (grad[i][d] > 0.0) != (vy[i][d] > 0.0) {
275                    gains[i][d] = (gains[i][d] + 0.2).min(10.0);
276                } else {
277                    gains[i][d] = (gains[i][d] * 0.8).max(0.01);
278                }
279                vy[i][d] = momentum * vy[i][d] - lr * gains[i][d] * grad[i][d];
280                y[i][d] += vy[i][d];
281            }
282        }
283    }
284
285    y.iter().map(|p| Vec2::new(p[0], p[1])).collect()
286}
287
288// ── UMAP (simplified) ───────────────────────────────────────────────────
289
290/// Simplified UMAP: approximate using nearest-neighbor graph + force-directed layout.
291pub fn umap(vectors: &[Vec<f32>], n_neighbors: usize, min_dist: f32, target_dim: usize) -> Vec<Vec2> {
292    let n = vectors.len();
293    if n == 0 { return vec![]; }
294    let _ = target_dim; // we always produce 2D
295
296    // Compute pairwise distances and build k-NN graph
297    let mut knn: Vec<Vec<(usize, f32)>> = Vec::with_capacity(n);
298    for i in 0..n {
299        let mut dists: Vec<(usize, f32)> = (0..n)
300            .filter(|&j| j != i)
301            .map(|j| {
302                let d: f32 = vectors[i].iter().zip(&vectors[j])
303                    .map(|(a, b)| (a - b) * (a - b)).sum::<f32>().sqrt();
304                (j, d)
305            })
306            .collect();
307        dists.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
308        dists.truncate(n_neighbors);
309        knn.push(dists);
310    }
311
312    // Build symmetric adjacency weights
313    let mut weights = vec![vec![0.0f32; n]; n];
314    for i in 0..n {
315        let sigma = knn[i].last().map(|&(_, d)| d.max(1e-6)).unwrap_or(1.0);
316        for &(j, d) in &knn[i] {
317            let w = (-d / sigma).exp();
318            weights[i][j] = weights[i][j].max(w);
319            weights[j][i] = weights[j][i].max(w);
320        }
321    }
322
323    // Initialize layout
324    let mut pos: Vec<[f32; 2]> = Vec::with_capacity(n);
325    let mut rng = 123u64;
326    for _ in 0..n {
327        rng ^= rng << 13; rng ^= rng >> 7; rng ^= rng << 17;
328        let x = (rng as u32 as f32 / u32::MAX as f32 - 0.5) * 10.0;
329        rng ^= rng << 13; rng ^= rng >> 7; rng ^= rng << 17;
330        let y = (rng as u32 as f32 / u32::MAX as f32 - 0.5) * 10.0;
331        pos.push([x, y]);
332    }
333
334    // Optimize with attractive/repulsive forces
335    let epochs = 200;
336    let initial_lr = 1.0f32;
337
338    for epoch in 0..epochs {
339        let lr = initial_lr * (1.0 - epoch as f32 / epochs as f32);
340        let mut forces = vec![[0.0f32; 2]; n];
341
342        for i in 0..n {
343            for &(j, _) in &knn[i] {
344                let dx = pos[i][0] - pos[j][0];
345                let dy = pos[i][1] - pos[j][1];
346                let dist = (dx * dx + dy * dy).sqrt().max(min_dist);
347                // Attractive force
348                let attract = -2.0 * weights[i][j] * (dist - min_dist) / dist;
349                forces[i][0] += attract * dx;
350                forces[i][1] += attract * dy;
351            }
352
353            // Repulsive: sample a few random non-neighbors
354            let mut neg_rng = rng.wrapping_add(i as u64);
355            for _ in 0..5.min(n) {
356                neg_rng ^= neg_rng << 13;
357                neg_rng ^= neg_rng >> 7;
358                neg_rng ^= neg_rng << 17;
359                let j = neg_rng as usize % n;
360                if j == i { continue; }
361                let dx = pos[i][0] - pos[j][0];
362                let dy = pos[i][1] - pos[j][1];
363                let dist2 = dx * dx + dy * dy;
364                let repel = 2.0 / (dist2 + 0.01);
365                forces[i][0] += repel * dx;
366                forces[i][1] += repel * dy;
367            }
368        }
369
370        for i in 0..n {
371            pos[i][0] += lr * forces[i][0].clamp(-4.0, 4.0);
372            pos[i][1] += lr * forces[i][1].clamp(-4.0, 4.0);
373        }
374        rng ^= rng << 13; rng ^= rng >> 7; rng ^= rng << 17;
375    }
376
377    pos.iter().map(|p| Vec2::new(p[0], p[1])).collect()
378}
379
380// ── Embedding Renderer ──────────────────────────────────────────────────
381
382/// Renders embedding points as colored glyphs for 2D visualization.
383pub struct EmbeddingRenderer {
384    pub point_size: f32,
385    pub color_by_label: bool,
386}
387
388impl EmbeddingRenderer {
389    pub fn new() -> Self {
390        Self { point_size: 3.0, color_by_label: true }
391    }
392
393    /// Render 2D points to a list of (position, color_rgba).
394    pub fn render(&self, points: &[Vec2], labels: &[String]) -> Vec<(Vec2, [f32; 4])> {
395        let unique_labels: Vec<&String> = {
396            let mut u: Vec<&String> = labels.iter().collect();
397            u.sort();
398            u.dedup();
399            u
400        };
401
402        points.iter().zip(labels).map(|(&pos, label)| {
403            let color = if self.color_by_label {
404                let idx = unique_labels.iter().position(|l| *l == label).unwrap_or(0);
405                label_to_color(idx, unique_labels.len())
406            } else {
407                [1.0, 1.0, 1.0, 1.0]
408            };
409            (pos, color)
410        }).collect()
411    }
412}
/// Maps a label index to an opaque RGBA color by spacing hues evenly around
/// the color wheel (HSV with S = 0.8, V = 0.9). Returns white when `total`
/// is zero.
fn label_to_color(index: usize, total: usize) -> [f32; 4] {
    if total == 0 {
        return [1.0, 1.0, 1.0, 1.0];
    }
    let hue = index as f32 / total as f32;
    let h = hue * 6.0; // hue sector position in [0, 6)
    let c = 0.9 * 0.8; // chroma = V * S
    let x = c * (1.0 - ((h % 2.0) - 1.0).abs());
    let m = 0.9 - c; // offset lifting chroma to the target value
    // One (r, g, b) pattern per integer hue sector.
    let [r, g, b] = [
        [c, x, 0.0],
        [x, c, 0.0],
        [0.0, c, x],
        [0.0, x, c],
        [x, 0.0, c],
        [c, 0.0, x],
    ][(h as usize).min(5)];
    [r + m, g + m, b + m, 1.0]
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cosine_similarity_identical() {
        let a = vec![1.0, 2.0, 3.0];
        assert!((cosine_similarity(&a, &a) - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_cosine_similarity_orthogonal() {
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        assert!(cosine_similarity(&a, &b).abs() < 1e-5);
    }

    #[test]
    fn test_cosine_similarity_opposite() {
        let a = vec![1.0, 0.0];
        let b = vec![-1.0, 0.0];
        assert!((cosine_similarity(&a, &b) + 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_nearest_neighbors() {
        let mut space = EmbeddingSpace::new(3);
        space.add(vec![1.0, 0.0, 0.0], "a".into());
        space.add(vec![0.9, 0.1, 0.0], "b".into());
        space.add(vec![0.0, 1.0, 0.0], "c".into());
        space.add(vec![0.0, 0.0, 1.0], "d".into());

        let query = vec![1.0, 0.0, 0.0];
        let nn = nearest_neighbors(&space, &query, 2);
        assert_eq!(nn.len(), 2);
        assert_eq!(nn[0].0, 0); // self is most similar
        assert_eq!(nn[1].0, 1); // "b" is next closest
    }

    // New: k larger than the space must return everything, not panic.
    #[test]
    fn test_nearest_neighbors_k_exceeds_len() {
        let mut space = EmbeddingSpace::new(2);
        space.add(vec![1.0, 0.0], "a".into());
        space.add(vec![0.0, 1.0], "b".into());
        let nn = nearest_neighbors(&space, &[1.0, 0.0], 10);
        assert_eq!(nn.len(), 2);
    }

    #[test]
    fn test_pca_reduces_dimensions() {
        let vectors = vec![
            vec![1.0, 0.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0, 0.0],
            vec![1.0, 1.0, 0.0, 0.0],
            vec![0.0, 0.0, 1.0, 0.0],
            vec![1.0, 0.0, 1.0, 0.0],
        ];
        let reduced = pca(&vectors, 2);
        assert_eq!(reduced.len(), 5);
        assert_eq!(reduced[0].len(), 2);
    }

    #[test]
    fn test_pca_preserves_variance_ordering() {
        // Data with more variance along first dimension
        let vectors: Vec<Vec<f32>> = (0..20).map(|i| {
            vec![i as f32 * 10.0, (i % 3) as f32, 0.5]
        }).collect();
        let reduced = pca(&vectors, 2);
        // First component should capture more variance
        let var1: f32 = reduced.iter().map(|v| v[0] * v[0]).sum::<f32>() / reduced.len() as f32;
        let var2: f32 = reduced.iter().map(|v| v[1] * v[1]).sum::<f32>() / reduced.len() as f32;
        assert!(var1 > var2, "PCA first component variance ({var1}) should exceed second ({var2})");
    }

    // New: empty-input edge cases must return empty, not panic.
    #[test]
    fn test_empty_inputs() {
        assert!(pca(&[], 2).is_empty());
        assert!(tsne(&[], 2, 2.0, 10).is_empty());
        assert!(umap(&[], 2, 0.1, 2).is_empty());
    }

    #[test]
    fn test_tsne_output_size() {
        let vectors = vec![
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.0, 0.0, 1.0],
            vec![1.0, 1.0, 0.0],
        ];
        let result = tsne(&vectors, 2, 2.0, 50);
        assert_eq!(result.len(), 4);
    }

    #[test]
    fn test_umap_output_size() {
        let vectors = vec![
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.0, 0.0, 1.0],
            vec![1.0, 1.0, 0.0],
            vec![0.0, 1.0, 1.0],
        ];
        let result = umap(&vectors, 2, 0.1, 2);
        assert_eq!(result.len(), 5);
    }

    #[test]
    fn test_embedding_renderer() {
        let renderer = EmbeddingRenderer::new();
        let points = vec![Vec2::new(0.0, 0.0), Vec2::new(1.0, 1.0)];
        let labels = vec!["a".to_string(), "b".to_string()];
        let result = renderer.render(&points, &labels);
        assert_eq!(result.len(), 2);
        // Colors should differ for different labels
        assert_ne!(result[0].1, result[1].1);
    }

    #[test]
    fn test_embedding_space() {
        let mut space = EmbeddingSpace::new(4);
        assert!(space.is_empty());
        space.add(vec![1.0, 2.0, 3.0, 4.0], "test".into());
        assert_eq!(space.len(), 1);
    }
}