// oxirs_embed/visualization.rs

1//! Embedding Visualization Tools
2//!
3//! This module provides tools for visualizing knowledge graph embeddings in 2D/3D space
4//! using dimensionality reduction techniques like t-SNE, UMAP, and PCA.
5
6use anyhow::{anyhow, Result};
7use scirs2_core::ndarray_ext::{Array1, Array2};
8use scirs2_core::random::prelude::{Normal, Random};
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use tracing::{debug, info};
12
/// Dimensionality reduction method
///
/// Selects how high-dimensional embeddings are mapped down to 2D/3D
/// coordinates by the visualizer.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ReductionMethod {
    /// Principal Component Analysis
    PCA,
    /// t-Distributed Stochastic Neighbor Embedding
    TSNE,
    /// Uniform Manifold Approximation and Projection
    /// (implemented in this module as an approximation: PCA followed by
    /// force-directed refinement over a k-NN graph)
    UMAP,
    /// Random Projection (fast but less accurate)
    RandomProjection,
}
25
/// Visualization configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VisualizationConfig {
    /// Dimensionality reduction method
    pub method: ReductionMethod,
    /// Target dimensions (2 or 3; `visualize` rejects anything else)
    pub target_dims: usize,
    /// Perplexity for t-SNE (typically 5-50).
    /// NOTE(review): the simplified t-SNE here uses this value directly as
    /// the Gaussian kernel variance rather than solving for per-point
    /// bandwidths — confirm intended semantics.
    pub tsne_perplexity: f32,
    /// Learning rate for t-SNE
    pub tsne_learning_rate: f32,
    /// Number of iterations for iterative methods
    pub max_iterations: usize,
    /// Random seed for reproducibility.
    /// NOTE(review): currently unused — `EmbeddingVisualizer::new` always
    /// constructs a default RNG; confirm whether seeding is implemented.
    pub random_seed: Option<u64>,
    /// Number of neighbors for UMAP
    pub umap_n_neighbors: usize,
    /// Minimum distance for UMAP.
    /// NOTE(review): not referenced by the approximate UMAP implementation
    /// in this module.
    pub umap_min_dist: f32,
}
46
47impl Default for VisualizationConfig {
48    fn default() -> Self {
49        Self {
50            method: ReductionMethod::PCA,
51            target_dims: 2,
52            tsne_perplexity: 30.0,
53            tsne_learning_rate: 200.0,
54            max_iterations: 1000,
55            random_seed: None,
56            umap_n_neighbors: 15,
57            umap_min_dist: 0.1,
58        }
59    }
60}
61
/// Visualization result with 2D/3D coordinates
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VisualizationResult {
    /// Entity ID to coordinates mapping (each `Vec` holds `dimensions` values)
    pub coordinates: HashMap<String, Vec<f32>>,
    /// Number of dimensions (2 or 3)
    pub dimensions: usize,
    /// Method used
    pub method: ReductionMethod,
    /// Explained variance (for PCA; fractions are relative to the sum of the
    /// extracted eigenvalues only, not the full spectrum; `None` otherwise)
    pub explained_variance: Option<Vec<f32>>,
    /// Final stress/loss (for t-SNE; `None` for other methods)
    pub final_loss: Option<f32>,
}
76
/// Embedding visualizer
///
/// Applies the configured dimensionality-reduction method to a set of named
/// embedding vectors and produces 2D/3D coordinates for plotting.
pub struct EmbeddingVisualizer {
    // Reduction method and hyperparameters.
    config: VisualizationConfig,
    // RNG used for random initialization (t-SNE layout, power-iteration
    // start vectors, random projection matrix).
    // NOTE(review): `config.random_seed` is never applied to this RNG, so
    // results are not reproducible even when a seed is supplied — confirm.
    rng: Random,
}
82
83impl EmbeddingVisualizer {
84    /// Create new embedding visualizer
85    pub fn new(config: VisualizationConfig) -> Self {
86        let rng = Random::default();
87
88        info!(
89            "Initialized embedding visualizer: method={:?}, target_dims={}",
90            config.method, config.target_dims
91        );
92
93        Self { config, rng }
94    }
95
96    /// Visualize embeddings
97    pub fn visualize(
98        &mut self,
99        embeddings: &HashMap<String, Array1<f32>>,
100    ) -> Result<VisualizationResult> {
101        if embeddings.is_empty() {
102            return Err(anyhow!("No embeddings to visualize"));
103        }
104
105        if self.config.target_dims != 2 && self.config.target_dims != 3 {
106            return Err(anyhow!("Target dimensions must be 2 or 3"));
107        }
108
109        info!("Visualizing {} embeddings", embeddings.len());
110
111        match self.config.method {
112            ReductionMethod::PCA => self.pca(embeddings),
113            ReductionMethod::TSNE => self.tsne(embeddings),
114            ReductionMethod::UMAP => self.umap_approximate(embeddings),
115            ReductionMethod::RandomProjection => self.random_projection(embeddings),
116        }
117    }
118
    /// PCA dimensionality reduction
    ///
    /// Centers the data, builds the (d x d) sample covariance, extracts the
    /// top `target_dims` eigenvectors via power iteration with deflation,
    /// and projects each centered embedding onto them.
    ///
    /// The reported explained-variance fractions are normalized by the sum
    /// of the *extracted* eigenvalues only, not the full spectrum, so they
    /// always sum to 1 regardless of how much variance was captured.
    ///
    /// Assumes all embeddings share the dimensionality of the first one
    /// (a shorter vector would panic on indexing). Callers guarantee the
    /// map is non-empty via `visualize`.
    /// NOTE(review): with a single embedding the covariance divides by
    /// n - 1 = 0, producing NaNs — confirm upstream guards if n == 1 is
    /// possible.
    fn pca(&mut self, embeddings: &HashMap<String, Array1<f32>>) -> Result<VisualizationResult> {
        let entity_list: Vec<String> = embeddings.keys().cloned().collect();
        let n = entity_list.len();
        let d = embeddings.values().next().unwrap().len();

        // Build data matrix (n x d)
        let mut data_matrix = Array2::zeros((n, d));
        for (i, entity) in entity_list.iter().enumerate() {
            let emb = &embeddings[entity];
            for j in 0..d {
                data_matrix[[i, j]] = emb[j];
            }
        }

        // Center the data (PCA requires zero-mean columns)
        let mean = self.compute_mean(&data_matrix);
        for i in 0..n {
            for j in 0..d {
                data_matrix[[i, j]] -= mean[j];
            }
        }

        // Compute covariance matrix (d x d)
        let cov_matrix = self.compute_covariance(&data_matrix);

        // Find principal components using power iteration
        let (eigenvectors, eigenvalues) =
            self.power_iteration_top_k(&cov_matrix, self.config.target_dims)?;

        // Project data onto principal components: row_i · eigvec_k
        let mut coordinates = HashMap::new();
        for (i, entity) in entity_list.iter().enumerate() {
            let mut projected = vec![0.0; self.config.target_dims];
            for k in 0..self.config.target_dims {
                let mut dot_product = 0.0;
                for j in 0..d {
                    dot_product += data_matrix[[i, j]] * eigenvectors[[j, k]];
                }
                projected[k] = dot_product;
            }
            coordinates.insert(entity.clone(), projected);
        }

        // Compute explained variance (relative to extracted components only)
        let total_variance: f32 = eigenvalues.iter().sum();
        let explained_variance: Vec<f32> =
            eigenvalues.iter().map(|&ev| ev / total_variance).collect();

        info!(
            "PCA complete: explained variance = {:?}",
            explained_variance
        );

        Ok(VisualizationResult {
            coordinates,
            dimensions: self.config.target_dims,
            method: ReductionMethod::PCA,
            explained_variance: Some(explained_variance),
            final_loss: None,
        })
    }
181
182    /// t-SNE dimensionality reduction (simplified implementation)
183    fn tsne(&mut self, embeddings: &HashMap<String, Array1<f32>>) -> Result<VisualizationResult> {
184        let entity_list: Vec<String> = embeddings.keys().cloned().collect();
185        let n = entity_list.len();
186
187        // Initialize low-dimensional representation randomly
188        let dist = Normal::new(0.0, 0.01).unwrap();
189        let mut y = Array2::from_shape_fn((n, self.config.target_dims), |_| self.rng.sample(dist));
190
191        // Compute pairwise affinities in high-dimensional space
192        let p = self.compute_affinities(embeddings, &entity_list);
193
194        // Gradient descent
195        let mut final_loss = 0.0;
196        for iteration in 0..self.config.max_iterations {
197            // Compute pairwise affinities in low-dimensional space
198            let q = self.compute_low_dim_affinities(&y);
199
200            // Compute gradient
201            let grad = self.compute_tsne_gradient(&y, &p, &q);
202
203            // Update positions
204            for i in 0..n {
205                for j in 0..self.config.target_dims {
206                    y[[i, j]] -= self.config.tsne_learning_rate * grad[[i, j]];
207                }
208            }
209
210            // Compute KL divergence (loss)
211            if iteration % 100 == 0 {
212                final_loss = self.compute_kl_divergence(&p, &q);
213                debug!("t-SNE iteration {}: loss = {:.6}", iteration, final_loss);
214            }
215        }
216
217        // Extract coordinates
218        let mut coordinates = HashMap::new();
219        for (i, entity) in entity_list.iter().enumerate() {
220            let mut coords = vec![0.0; self.config.target_dims];
221            for j in 0..self.config.target_dims {
222                coords[j] = y[[i, j]];
223            }
224            coordinates.insert(entity.clone(), coords);
225        }
226
227        info!("t-SNE complete: final loss = {:.6}", final_loss);
228
229        Ok(VisualizationResult {
230            coordinates,
231            dimensions: self.config.target_dims,
232            method: ReductionMethod::TSNE,
233            explained_variance: None,
234            final_loss: Some(final_loss),
235        })
236    }
237
238    /// UMAP (simplified/approximate implementation)
239    fn umap_approximate(
240        &mut self,
241        embeddings: &HashMap<String, Array1<f32>>,
242    ) -> Result<VisualizationResult> {
243        // For a full UMAP implementation, we'd need complex graph construction and optimization
244        // This is a simplified approximation using PCA followed by force-directed layout
245
246        info!("Using approximate UMAP (PCA + refinement)");
247
248        // Start with PCA
249        let mut result = self.pca(embeddings)?;
250
251        // Apply force-directed refinement
252        let entity_list: Vec<String> = embeddings.keys().cloned().collect();
253        let n = entity_list.len();
254
255        // Build k-nearest neighbor graph
256        let knn_graph =
257            self.build_knn_graph(embeddings, &entity_list, self.config.umap_n_neighbors);
258
259        // Refine positions using force-directed layout
260        for _iteration in 0..100 {
261            for i in 0..n {
262                let entity = &entity_list[i];
263                let pos = &result.coordinates[entity].clone();
264
265                // Attractive forces from neighbors
266                let mut force = vec![0.0; self.config.target_dims];
267                for &neighbor_idx in &knn_graph[i] {
268                    let neighbor = &entity_list[neighbor_idx];
269                    let neighbor_pos = &result.coordinates[neighbor];
270
271                    for d in 0..self.config.target_dims {
272                        let diff = neighbor_pos[d] - pos[d];
273                        force[d] += diff * 0.01; // Attraction
274                    }
275                }
276
277                // Update position
278                let coords = result.coordinates.get_mut(entity).unwrap();
279                for d in 0..self.config.target_dims {
280                    coords[d] += force[d];
281                }
282            }
283        }
284
285        result.method = ReductionMethod::UMAP;
286        info!("Approximate UMAP complete");
287
288        Ok(result)
289    }
290
291    /// Random projection (fast dimensionality reduction)
292    fn random_projection(
293        &mut self,
294        embeddings: &HashMap<String, Array1<f32>>,
295    ) -> Result<VisualizationResult> {
296        let entity_list: Vec<String> = embeddings.keys().cloned().collect();
297        let d = embeddings.values().next().unwrap().len();
298
299        // Generate random projection matrix
300        let dist = Normal::new(0.0, 1.0).unwrap();
301        let projection_matrix =
302            Array2::from_shape_fn((d, self.config.target_dims), |_| self.rng.sample(dist));
303
304        // Project each embedding
305        let mut coordinates = HashMap::new();
306        for entity in &entity_list {
307            let emb = &embeddings[entity];
308            let mut projected = vec![0.0; self.config.target_dims];
309
310            for k in 0..self.config.target_dims {
311                let mut dot_product = 0.0;
312                for j in 0..d {
313                    dot_product += emb[j] * projection_matrix[[j, k]];
314                }
315                projected[k] = dot_product;
316            }
317
318            coordinates.insert(entity.clone(), projected);
319        }
320
321        info!("Random projection complete");
322
323        Ok(VisualizationResult {
324            coordinates,
325            dimensions: self.config.target_dims,
326            method: ReductionMethod::RandomProjection,
327            explained_variance: None,
328            final_loss: None,
329        })
330    }
331
332    /// Compute mean of each dimension
333    fn compute_mean(&self, data: &Array2<f32>) -> Vec<f32> {
334        let n = data.nrows();
335        let d = data.ncols();
336        let mut mean = vec![0.0; d];
337
338        for j in 0..d {
339            for i in 0..n {
340                mean[j] += data[[i, j]];
341            }
342            mean[j] /= n as f32;
343        }
344
345        mean
346    }
347
348    /// Compute covariance matrix
349    fn compute_covariance(&self, data: &Array2<f32>) -> Array2<f32> {
350        let n = data.nrows() as f32;
351        let d = data.ncols();
352        let mut cov = Array2::zeros((d, d));
353
354        for i in 0..d {
355            for j in 0..d {
356                let mut sum = 0.0;
357                for k in 0..data.nrows() {
358                    sum += data[[k, i]] * data[[k, j]];
359                }
360                cov[[i, j]] = sum / (n - 1.0);
361            }
362        }
363
364        cov
365    }
366
    /// Power iteration to find top-k eigenvectors
    ///
    /// Classic power method with Hotelling deflation: for each of the `k`
    /// components, iterate v <- M v / ||M v|| for a fixed 100 steps from a
    /// random start vector, take the Rayleigh quotient v' M v as the
    /// eigenvalue, then subtract lambda * v v' from the working matrix
    /// before extracting the next component.
    ///
    /// Notes (review): there is no convergence check — 100 iterations is a
    /// fixed budget; deflation assumes a symmetric input matrix (true for
    /// the covariance produced by `pca`). The `Result` return currently
    /// never carries an error.
    fn power_iteration_top_k(
        &mut self,
        matrix: &Array2<f32>,
        k: usize,
    ) -> Result<(Array2<f32>, Vec<f32>)> {
        let d = matrix.nrows();
        let mut eigenvectors = Array2::zeros((d, k));
        let mut eigenvalues = Vec::new();

        // Work on a copy so deflation does not mutate the caller's matrix.
        let mut working_matrix = matrix.clone();

        for component in 0..k {
            // Initialize random vector
            let dist = Normal::new(0.0f32, 1.0f32).unwrap();
            let mut v = Array1::from_shape_fn(d, |_| self.rng.sample(dist));

            // Power iteration
            for _ in 0..100 {
                // Multiply by matrix
                let mut new_v = Array1::<f32>::zeros(d);
                for i in 0..d {
                    for j in 0..d {
                        new_v[i] += working_matrix[[i, j]] * v[j];
                    }
                }

                // Normalize (skip if the product collapsed to zero)
                let norm = new_v.dot(&new_v).sqrt();
                if norm > 0.0 {
                    v = new_v / norm;
                }
            }

            // Compute eigenvalue as the Rayleigh quotient v' M v (v is unit-norm)
            let mut av = Array1::<f32>::zeros(d);
            for i in 0..d {
                for j in 0..d {
                    av[i] += working_matrix[[i, j]] * v[j];
                }
            }
            let eigenvalue = v.dot(&av);
            eigenvalues.push(eigenvalue);

            // Store eigenvector as column `component`
            for i in 0..d {
                eigenvectors[[i, component]] = v[i];
            }

            // Deflate matrix for next component: M -= lambda * v v'
            for i in 0..d {
                for j in 0..d {
                    working_matrix[[i, j]] -= eigenvalue * v[i] * v[j];
                }
            }
        }

        Ok((eigenvectors, eigenvalues))
    }
426
427    /// Compute affinities for t-SNE
428    fn compute_affinities(
429        &self,
430        embeddings: &HashMap<String, Array1<f32>>,
431        entity_list: &[String],
432    ) -> Array2<f32> {
433        let n = entity_list.len();
434        let mut p = Array2::zeros((n, n));
435
436        // Compute pairwise distances
437        for i in 0..n {
438            for j in 0..n {
439                if i != j {
440                    let dist = self.euclidean_distance(
441                        &embeddings[&entity_list[i]],
442                        &embeddings[&entity_list[j]],
443                    );
444                    // Gaussian kernel
445                    p[[i, j]] = (-dist * dist / (2.0 * self.config.tsne_perplexity)).exp();
446                }
447            }
448
449            // Normalize row
450            let row_sum: f32 = (0..n).map(|j| p[[i, j]]).sum();
451            if row_sum > 0.0 {
452                for j in 0..n {
453                    p[[i, j]] /= row_sum;
454                }
455            }
456        }
457
458        // Symmetrize
459        for i in 0..n {
460            for j in 0..n {
461                p[[i, j]] = (p[[i, j]] + p[[j, i]]) / (2.0 * n as f32);
462            }
463        }
464
465        p
466    }
467
468    /// Compute low-dimensional affinities
469    fn compute_low_dim_affinities(&self, y: &Array2<f32>) -> Array2<f32> {
470        let n = y.nrows();
471        let mut q = Array2::zeros((n, n));
472
473        for i in 0..n {
474            for j in 0..n {
475                if i != j {
476                    let mut dist_sq = 0.0;
477                    for k in 0..y.ncols() {
478                        let diff = y[[i, k]] - y[[j, k]];
479                        dist_sq += diff * diff;
480                    }
481                    q[[i, j]] = 1.0 / (1.0 + dist_sq);
482                }
483            }
484        }
485
486        // Normalize
487        let sum: f32 = q.iter().sum();
488        if sum > 0.0 {
489            q /= sum;
490        }
491
492        q
493    }
494
495    /// Compute t-SNE gradient
496    fn compute_tsne_gradient(
497        &self,
498        y: &Array2<f32>,
499        p: &Array2<f32>,
500        q: &Array2<f32>,
501    ) -> Array2<f32> {
502        let n = y.nrows();
503        let d = y.ncols();
504        let mut grad = Array2::zeros((n, d));
505
506        for i in 0..n {
507            for j in 0..n {
508                if i != j {
509                    let pq_diff = p[[i, j]] - q[[i, j]];
510                    let q_val = q[[i, j]];
511
512                    for k in 0..d {
513                        let y_diff = y[[i, k]] - y[[j, k]];
514                        grad[[i, k]] += 4.0 * pq_diff * y_diff * q_val;
515                    }
516                }
517            }
518        }
519
520        grad
521    }
522
523    /// Compute KL divergence
524    fn compute_kl_divergence(&self, p: &Array2<f32>, q: &Array2<f32>) -> f32 {
525        let mut kl = 0.0;
526        for i in 0..p.nrows() {
527            for j in 0..p.ncols() {
528                if p[[i, j]] > 0.0 && q[[i, j]] > 0.0 {
529                    kl += p[[i, j]] * (p[[i, j]] / q[[i, j]]).ln();
530                }
531            }
532        }
533        kl
534    }
535
536    /// Euclidean distance between two embeddings
537    fn euclidean_distance(&self, a: &Array1<f32>, b: &Array1<f32>) -> f32 {
538        let diff = a - b;
539        diff.dot(&diff).sqrt()
540    }
541
542    /// Build k-nearest neighbor graph
543    fn build_knn_graph(
544        &self,
545        embeddings: &HashMap<String, Array1<f32>>,
546        entity_list: &[String],
547        k: usize,
548    ) -> Vec<Vec<usize>> {
549        let n = entity_list.len();
550        let mut knn_graph = Vec::new();
551
552        for i in 0..n {
553            let entity = &entity_list[i];
554            let emb = &embeddings[entity];
555
556            // Compute distances to all other entities
557            let mut distances: Vec<(usize, f32)> = (0..n)
558                .filter(|&j| j != i)
559                .map(|j| {
560                    let other_emb = &embeddings[&entity_list[j]];
561                    let dist = self.euclidean_distance(emb, other_emb);
562                    (j, dist)
563                })
564                .collect();
565
566            // Sort by distance and take top-k
567            distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
568            let neighbors: Vec<usize> = distances.iter().take(k).map(|(idx, _)| *idx).collect();
569            knn_graph.push(neighbors);
570        }
571
572        knn_graph
573    }
574
575    /// Export visualization to JSON
576    pub fn export_json(&self, result: &VisualizationResult) -> Result<String> {
577        serde_json::to_string_pretty(result)
578            .map_err(|e| anyhow!("Failed to serialize visualization: {}", e))
579    }
580
581    /// Export visualization to CSV
582    pub fn export_csv(&self, result: &VisualizationResult) -> Result<String> {
583        let mut csv = String::from("entity");
584        for i in 0..result.dimensions {
585            csv.push_str(&format!(",dim{}", i + 1));
586        }
587        csv.push('\n');
588
589        for (entity, coords) in &result.coordinates {
590            csv.push_str(entity);
591            for coord in coords {
592                csv.push_str(&format!(",{}", coord));
593            }
594            csv.push('\n');
595        }
596
597        Ok(csv)
598    }
599}
600
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray_ext::array;

    // PCA on four orthonormal 4-D vectors reduced to 2-D: every entity gets
    // coordinates and the explained-variance vector is populated.
    #[test]
    fn test_pca_visualization() {
        let mut embeddings = HashMap::new();
        embeddings.insert("e1".to_string(), array![1.0, 0.0, 0.0, 0.0]);
        embeddings.insert("e2".to_string(), array![0.0, 1.0, 0.0, 0.0]);
        embeddings.insert("e3".to_string(), array![0.0, 0.0, 1.0, 0.0]);
        embeddings.insert("e4".to_string(), array![0.0, 0.0, 0.0, 1.0]);

        let config = VisualizationConfig {
            method: ReductionMethod::PCA,
            target_dims: 2,
            ..Default::default()
        };

        let mut visualizer = EmbeddingVisualizer::new(config);
        let result = visualizer.visualize(&embeddings).unwrap();

        assert_eq!(result.coordinates.len(), 4);
        assert_eq!(result.dimensions, 2);
        assert!(result.explained_variance.is_some());
    }

    // Random projection from 100-D down to 3-D for ten constant vectors:
    // checks entity count and output dimensionality only (projection is
    // random, so coordinate values are not asserted).
    #[test]
    fn test_random_projection() {
        let mut embeddings = HashMap::new();
        for i in 0..10 {
            let emb = Array1::from_vec(vec![i as f32; 100]);
            embeddings.insert(format!("e{}", i), emb);
        }

        let config = VisualizationConfig {
            method: ReductionMethod::RandomProjection,
            target_dims: 3,
            ..Default::default()
        };

        let mut visualizer = EmbeddingVisualizer::new(config);
        let result = visualizer.visualize(&embeddings).unwrap();

        assert_eq!(result.coordinates.len(), 10);
        assert_eq!(result.dimensions, 3);
    }

    // CSV export: header lists dim columns; f32 Display renders whole
    // floats without a trailing ".0" (1.0 -> "1"), hence "e1,1,2".
    #[test]
    fn test_export_csv() {
        let mut coordinates = HashMap::new();
        coordinates.insert("e1".to_string(), vec![1.0, 2.0]);
        coordinates.insert("e2".to_string(), vec![3.0, 4.0]);

        let result = VisualizationResult {
            coordinates,
            dimensions: 2,
            method: ReductionMethod::PCA,
            explained_variance: None,
            final_loss: None,
        };

        let config = VisualizationConfig::default();
        let visualizer = EmbeddingVisualizer::new(config);
        let csv = visualizer.export_csv(&result).unwrap();

        assert!(csv.contains("entity,dim1,dim2"));
        assert!(csv.contains("e1,1,2"));
    }
}