// oxirs_embed/visualization.rs
1//! Embedding Visualization Tools
2//!
3//! This module provides tools for visualizing knowledge graph embeddings in 2D/3D space
4//! using dimensionality reduction techniques like t-SNE, UMAP, and PCA.
5
6use anyhow::{anyhow, Result};
7use scirs2_core::ndarray_ext::{Array1, Array2};
8use scirs2_core::random::prelude::{Normal, Random};
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use tracing::{debug, info};
12
/// Dimensionality reduction method used to project high-dimensional
/// embeddings down to 2D/3D for visualization.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ReductionMethod {
    /// Principal Component Analysis (exact, linear)
    PCA,
    /// t-Distributed Stochastic Neighbor Embedding (iterative, non-linear)
    TSNE,
    /// Uniform Manifold Approximation and Projection (approximated in this
    /// module by PCA followed by force-directed refinement)
    UMAP,
    /// Random Projection (fast but less accurate)
    RandomProjection,
}
25
/// Visualization configuration.
///
/// Selects the dimensionality-reduction method and its hyperparameters.
/// Method-specific fields are ignored by the other methods.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VisualizationConfig {
    /// Dimensionality reduction method
    pub method: ReductionMethod,
    /// Target dimensions (must be 2 or 3; `visualize` rejects anything else)
    pub target_dims: usize,
    /// Perplexity for t-SNE (typically 5-50); used directly as the Gaussian
    /// kernel variance in `compute_affinities`
    pub tsne_perplexity: f32,
    /// Learning rate for t-SNE gradient descent
    pub tsne_learning_rate: f32,
    /// Number of iterations for iterative methods (t-SNE gradient descent)
    pub max_iterations: usize,
    /// Random seed for reproducibility.
    /// NOTE(review): not currently consumed — `EmbeddingVisualizer::new`
    /// always builds a default RNG, so this field has no effect yet.
    pub random_seed: Option<u64>,
    /// Number of neighbors for the UMAP k-nearest-neighbor graph
    pub umap_n_neighbors: usize,
    /// Minimum distance for UMAP.
    /// NOTE(review): not read by the approximate UMAP implementation.
    pub umap_min_dist: f32,
}
46
47impl Default for VisualizationConfig {
48    fn default() -> Self {
49        Self {
50            method: ReductionMethod::PCA,
51            target_dims: 2,
52            tsne_perplexity: 30.0,
53            tsne_learning_rate: 200.0,
54            max_iterations: 1000,
55            random_seed: None,
56            umap_n_neighbors: 15,
57            umap_min_dist: 0.1,
58        }
59    }
60}
61
/// Visualization result with 2D/3D coordinates.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VisualizationResult {
    /// Entity ID to coordinates mapping (each Vec has `dimensions` entries)
    pub coordinates: HashMap<String, Vec<f32>>,
    /// Number of dimensions (2 or 3)
    pub dimensions: usize,
    /// Method used to produce the coordinates
    pub method: ReductionMethod,
    /// Explained variance ratio per component (populated by PCA, and by the
    /// PCA step inside the approximate UMAP; `None` for the other methods)
    pub explained_variance: Option<Vec<f32>>,
    /// Final KL-divergence loss (t-SNE only; `None` for other methods)
    pub final_loss: Option<f32>,
}
76
/// Embedding visualizer.
///
/// Owns the configuration and the RNG used by the stochastic methods
/// (t-SNE initialization, random projection, power-iteration start vectors).
pub struct EmbeddingVisualizer {
    config: VisualizationConfig,
    // Shared by all methods that need randomness.
    rng: Random,
}
82
83impl EmbeddingVisualizer {
84    /// Create new embedding visualizer
85    pub fn new(config: VisualizationConfig) -> Self {
86        let rng = Random::default();
87
88        info!(
89            "Initialized embedding visualizer: method={:?}, target_dims={}",
90            config.method, config.target_dims
91        );
92
93        Self { config, rng }
94    }
95
96    /// Visualize embeddings
97    pub fn visualize(
98        &mut self,
99        embeddings: &HashMap<String, Array1<f32>>,
100    ) -> Result<VisualizationResult> {
101        if embeddings.is_empty() {
102            return Err(anyhow!("No embeddings to visualize"));
103        }
104
105        if self.config.target_dims != 2 && self.config.target_dims != 3 {
106            return Err(anyhow!("Target dimensions must be 2 or 3"));
107        }
108
109        info!("Visualizing {} embeddings", embeddings.len());
110
111        match self.config.method {
112            ReductionMethod::PCA => self.pca(embeddings),
113            ReductionMethod::TSNE => self.tsne(embeddings),
114            ReductionMethod::UMAP => self.umap_approximate(embeddings),
115            ReductionMethod::RandomProjection => self.random_projection(embeddings),
116        }
117    }
118
119    /// PCA dimensionality reduction
120    fn pca(&mut self, embeddings: &HashMap<String, Array1<f32>>) -> Result<VisualizationResult> {
121        let entity_list: Vec<String> = embeddings.keys().cloned().collect();
122        let n = entity_list.len();
123        let d = embeddings
124            .values()
125            .next()
126            .expect("embeddings should not be empty")
127            .len();
128
129        // Build data matrix (n x d)
130        let mut data_matrix = Array2::zeros((n, d));
131        for (i, entity) in entity_list.iter().enumerate() {
132            let emb = &embeddings[entity];
133            for j in 0..d {
134                data_matrix[[i, j]] = emb[j];
135            }
136        }
137
138        // Center the data
139        let mean = self.compute_mean(&data_matrix);
140        for i in 0..n {
141            for j in 0..d {
142                data_matrix[[i, j]] -= mean[j];
143            }
144        }
145
146        // Compute covariance matrix (d x d)
147        let cov_matrix = self.compute_covariance(&data_matrix);
148
149        // Find principal components using power iteration
150        let (eigenvectors, eigenvalues) =
151            self.power_iteration_top_k(&cov_matrix, self.config.target_dims)?;
152
153        // Project data onto principal components
154        let mut coordinates = HashMap::new();
155        for (i, entity) in entity_list.iter().enumerate() {
156            let mut projected = vec![0.0; self.config.target_dims];
157            for k in 0..self.config.target_dims {
158                let mut dot_product = 0.0;
159                for j in 0..d {
160                    dot_product += data_matrix[[i, j]] * eigenvectors[[j, k]];
161                }
162                projected[k] = dot_product;
163            }
164            coordinates.insert(entity.clone(), projected);
165        }
166
167        // Compute explained variance
168        let total_variance: f32 = eigenvalues.iter().sum();
169        let explained_variance: Vec<f32> =
170            eigenvalues.iter().map(|&ev| ev / total_variance).collect();
171
172        info!(
173            "PCA complete: explained variance = {:?}",
174            explained_variance
175        );
176
177        Ok(VisualizationResult {
178            coordinates,
179            dimensions: self.config.target_dims,
180            method: ReductionMethod::PCA,
181            explained_variance: Some(explained_variance),
182            final_loss: None,
183        })
184    }
185
186    /// t-SNE dimensionality reduction (simplified implementation)
187    fn tsne(&mut self, embeddings: &HashMap<String, Array1<f32>>) -> Result<VisualizationResult> {
188        let entity_list: Vec<String> = embeddings.keys().cloned().collect();
189        let n = entity_list.len();
190
191        // Initialize low-dimensional representation randomly
192        let dist = Normal::new(0.0, 0.01)
193            .expect("Normal distribution should be valid for these parameters");
194        let mut y = Array2::from_shape_fn((n, self.config.target_dims), |_| self.rng.sample(dist));
195
196        // Compute pairwise affinities in high-dimensional space
197        let p = self.compute_affinities(embeddings, &entity_list);
198
199        // Gradient descent
200        let mut final_loss = 0.0;
201        for iteration in 0..self.config.max_iterations {
202            // Compute pairwise affinities in low-dimensional space
203            let q = self.compute_low_dim_affinities(&y);
204
205            // Compute gradient
206            let grad = self.compute_tsne_gradient(&y, &p, &q);
207
208            // Update positions
209            for i in 0..n {
210                for j in 0..self.config.target_dims {
211                    y[[i, j]] -= self.config.tsne_learning_rate * grad[[i, j]];
212                }
213            }
214
215            // Compute KL divergence (loss)
216            if iteration % 100 == 0 {
217                final_loss = self.compute_kl_divergence(&p, &q);
218                debug!("t-SNE iteration {}: loss = {:.6}", iteration, final_loss);
219            }
220        }
221
222        // Extract coordinates
223        let mut coordinates = HashMap::new();
224        for (i, entity) in entity_list.iter().enumerate() {
225            let mut coords = vec![0.0; self.config.target_dims];
226            for j in 0..self.config.target_dims {
227                coords[j] = y[[i, j]];
228            }
229            coordinates.insert(entity.clone(), coords);
230        }
231
232        info!("t-SNE complete: final loss = {:.6}", final_loss);
233
234        Ok(VisualizationResult {
235            coordinates,
236            dimensions: self.config.target_dims,
237            method: ReductionMethod::TSNE,
238            explained_variance: None,
239            final_loss: Some(final_loss),
240        })
241    }
242
243    /// UMAP (simplified/approximate implementation)
244    fn umap_approximate(
245        &mut self,
246        embeddings: &HashMap<String, Array1<f32>>,
247    ) -> Result<VisualizationResult> {
248        // For a full UMAP implementation, we'd need complex graph construction and optimization
249        // This is a simplified approximation using PCA followed by force-directed layout
250
251        info!("Using approximate UMAP (PCA + refinement)");
252
253        // Start with PCA
254        let mut result = self.pca(embeddings)?;
255
256        // Apply force-directed refinement
257        let entity_list: Vec<String> = embeddings.keys().cloned().collect();
258        let n = entity_list.len();
259
260        // Build k-nearest neighbor graph
261        let knn_graph =
262            self.build_knn_graph(embeddings, &entity_list, self.config.umap_n_neighbors);
263
264        // Refine positions using force-directed layout
265        for _iteration in 0..100 {
266            for i in 0..n {
267                let entity = &entity_list[i];
268                let pos = &result.coordinates[entity].clone();
269
270                // Attractive forces from neighbors
271                let mut force = vec![0.0; self.config.target_dims];
272                for &neighbor_idx in &knn_graph[i] {
273                    let neighbor = &entity_list[neighbor_idx];
274                    let neighbor_pos = &result.coordinates[neighbor];
275
276                    for d in 0..self.config.target_dims {
277                        let diff = neighbor_pos[d] - pos[d];
278                        force[d] += diff * 0.01; // Attraction
279                    }
280                }
281
282                // Update position
283                let coords = result
284                    .coordinates
285                    .get_mut(entity)
286                    .expect("entity should exist in coordinates");
287                for d in 0..self.config.target_dims {
288                    coords[d] += force[d];
289                }
290            }
291        }
292
293        result.method = ReductionMethod::UMAP;
294        info!("Approximate UMAP complete");
295
296        Ok(result)
297    }
298
299    /// Random projection (fast dimensionality reduction)
300    fn random_projection(
301        &mut self,
302        embeddings: &HashMap<String, Array1<f32>>,
303    ) -> Result<VisualizationResult> {
304        let entity_list: Vec<String> = embeddings.keys().cloned().collect();
305        let d = embeddings
306            .values()
307            .next()
308            .expect("embeddings should not be empty")
309            .len();
310
311        // Generate random projection matrix
312        let dist = Normal::new(0.0, 1.0)
313            .expect("Normal distribution should be valid for these parameters");
314        let projection_matrix =
315            Array2::from_shape_fn((d, self.config.target_dims), |_| self.rng.sample(dist));
316
317        // Project each embedding
318        let mut coordinates = HashMap::new();
319        for entity in &entity_list {
320            let emb = &embeddings[entity];
321            let mut projected = vec![0.0; self.config.target_dims];
322
323            for k in 0..self.config.target_dims {
324                let mut dot_product = 0.0;
325                for j in 0..d {
326                    dot_product += emb[j] * projection_matrix[[j, k]];
327                }
328                projected[k] = dot_product;
329            }
330
331            coordinates.insert(entity.clone(), projected);
332        }
333
334        info!("Random projection complete");
335
336        Ok(VisualizationResult {
337            coordinates,
338            dimensions: self.config.target_dims,
339            method: ReductionMethod::RandomProjection,
340            explained_variance: None,
341            final_loss: None,
342        })
343    }
344
345    /// Compute mean of each dimension
346    fn compute_mean(&self, data: &Array2<f32>) -> Vec<f32> {
347        let n = data.nrows();
348        let d = data.ncols();
349        let mut mean = vec![0.0; d];
350
351        for j in 0..d {
352            for i in 0..n {
353                mean[j] += data[[i, j]];
354            }
355            mean[j] /= n as f32;
356        }
357
358        mean
359    }
360
361    /// Compute covariance matrix
362    fn compute_covariance(&self, data: &Array2<f32>) -> Array2<f32> {
363        let n = data.nrows() as f32;
364        let d = data.ncols();
365        let mut cov = Array2::zeros((d, d));
366
367        for i in 0..d {
368            for j in 0..d {
369                let mut sum = 0.0;
370                for k in 0..data.nrows() {
371                    sum += data[[k, i]] * data[[k, j]];
372                }
373                cov[[i, j]] = sum / (n - 1.0);
374            }
375        }
376
377        cov
378    }
379
380    /// Power iteration to find top-k eigenvectors
381    fn power_iteration_top_k(
382        &mut self,
383        matrix: &Array2<f32>,
384        k: usize,
385    ) -> Result<(Array2<f32>, Vec<f32>)> {
386        let d = matrix.nrows();
387        let mut eigenvectors = Array2::zeros((d, k));
388        let mut eigenvalues = Vec::new();
389
390        let mut working_matrix = matrix.clone();
391
392        for component in 0..k {
393            // Initialize random vector
394            let dist = Normal::new(0.0f32, 1.0f32)
395                .expect("Normal distribution should be valid for these parameters");
396            let mut v = Array1::from_shape_fn(d, |_| self.rng.sample(dist));
397
398            // Power iteration
399            for _ in 0..100 {
400                // Multiply by matrix
401                let mut new_v = Array1::<f32>::zeros(d);
402                for i in 0..d {
403                    for j in 0..d {
404                        new_v[i] += working_matrix[[i, j]] * v[j];
405                    }
406                }
407
408                // Normalize
409                let norm = new_v.dot(&new_v).sqrt();
410                if norm > 0.0 {
411                    v = new_v / norm;
412                }
413            }
414
415            // Compute eigenvalue
416            let mut av = Array1::<f32>::zeros(d);
417            for i in 0..d {
418                for j in 0..d {
419                    av[i] += working_matrix[[i, j]] * v[j];
420                }
421            }
422            let eigenvalue = v.dot(&av);
423            eigenvalues.push(eigenvalue);
424
425            // Store eigenvector
426            for i in 0..d {
427                eigenvectors[[i, component]] = v[i];
428            }
429
430            // Deflate matrix for next component
431            for i in 0..d {
432                for j in 0..d {
433                    working_matrix[[i, j]] -= eigenvalue * v[i] * v[j];
434                }
435            }
436        }
437
438        Ok((eigenvectors, eigenvalues))
439    }
440
441    /// Compute affinities for t-SNE
442    fn compute_affinities(
443        &self,
444        embeddings: &HashMap<String, Array1<f32>>,
445        entity_list: &[String],
446    ) -> Array2<f32> {
447        let n = entity_list.len();
448        let mut p = Array2::zeros((n, n));
449
450        // Compute pairwise distances
451        for i in 0..n {
452            for j in 0..n {
453                if i != j {
454                    let dist = self.euclidean_distance(
455                        &embeddings[&entity_list[i]],
456                        &embeddings[&entity_list[j]],
457                    );
458                    // Gaussian kernel
459                    p[[i, j]] = (-dist * dist / (2.0 * self.config.tsne_perplexity)).exp();
460                }
461            }
462
463            // Normalize row
464            let row_sum: f32 = (0..n).map(|j| p[[i, j]]).sum();
465            if row_sum > 0.0 {
466                for j in 0..n {
467                    p[[i, j]] /= row_sum;
468                }
469            }
470        }
471
472        // Symmetrize
473        for i in 0..n {
474            for j in 0..n {
475                p[[i, j]] = (p[[i, j]] + p[[j, i]]) / (2.0 * n as f32);
476            }
477        }
478
479        p
480    }
481
482    /// Compute low-dimensional affinities
483    fn compute_low_dim_affinities(&self, y: &Array2<f32>) -> Array2<f32> {
484        let n = y.nrows();
485        let mut q = Array2::zeros((n, n));
486
487        for i in 0..n {
488            for j in 0..n {
489                if i != j {
490                    let mut dist_sq = 0.0;
491                    for k in 0..y.ncols() {
492                        let diff = y[[i, k]] - y[[j, k]];
493                        dist_sq += diff * diff;
494                    }
495                    q[[i, j]] = 1.0 / (1.0 + dist_sq);
496                }
497            }
498        }
499
500        // Normalize
501        let sum: f32 = q.iter().sum();
502        if sum > 0.0 {
503            q /= sum;
504        }
505
506        q
507    }
508
509    /// Compute t-SNE gradient
510    fn compute_tsne_gradient(
511        &self,
512        y: &Array2<f32>,
513        p: &Array2<f32>,
514        q: &Array2<f32>,
515    ) -> Array2<f32> {
516        let n = y.nrows();
517        let d = y.ncols();
518        let mut grad = Array2::zeros((n, d));
519
520        for i in 0..n {
521            for j in 0..n {
522                if i != j {
523                    let pq_diff = p[[i, j]] - q[[i, j]];
524                    let q_val = q[[i, j]];
525
526                    for k in 0..d {
527                        let y_diff = y[[i, k]] - y[[j, k]];
528                        grad[[i, k]] += 4.0 * pq_diff * y_diff * q_val;
529                    }
530                }
531            }
532        }
533
534        grad
535    }
536
537    /// Compute KL divergence
538    fn compute_kl_divergence(&self, p: &Array2<f32>, q: &Array2<f32>) -> f32 {
539        let mut kl = 0.0;
540        for i in 0..p.nrows() {
541            for j in 0..p.ncols() {
542                if p[[i, j]] > 0.0 && q[[i, j]] > 0.0 {
543                    kl += p[[i, j]] * (p[[i, j]] / q[[i, j]]).ln();
544                }
545            }
546        }
547        kl
548    }
549
550    /// Euclidean distance between two embeddings
551    fn euclidean_distance(&self, a: &Array1<f32>, b: &Array1<f32>) -> f32 {
552        let diff = a - b;
553        diff.dot(&diff).sqrt()
554    }
555
556    /// Build k-nearest neighbor graph
557    fn build_knn_graph(
558        &self,
559        embeddings: &HashMap<String, Array1<f32>>,
560        entity_list: &[String],
561        k: usize,
562    ) -> Vec<Vec<usize>> {
563        let n = entity_list.len();
564        let mut knn_graph = Vec::new();
565
566        for i in 0..n {
567            let entity = &entity_list[i];
568            let emb = &embeddings[entity];
569
570            // Compute distances to all other entities
571            let mut distances: Vec<(usize, f32)> = (0..n)
572                .filter(|&j| j != i)
573                .map(|j| {
574                    let other_emb = &embeddings[&entity_list[j]];
575                    let dist = self.euclidean_distance(emb, other_emb);
576                    (j, dist)
577                })
578                .collect();
579
580            // Sort by distance and take top-k
581            distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
582            let neighbors: Vec<usize> = distances.iter().take(k).map(|(idx, _)| *idx).collect();
583            knn_graph.push(neighbors);
584        }
585
586        knn_graph
587    }
588
589    /// Export visualization to JSON
590    pub fn export_json(&self, result: &VisualizationResult) -> Result<String> {
591        serde_json::to_string_pretty(result)
592            .map_err(|e| anyhow!("Failed to serialize visualization: {}", e))
593    }
594
595    /// Export visualization to CSV
596    pub fn export_csv(&self, result: &VisualizationResult) -> Result<String> {
597        let mut csv = String::from("entity");
598        for i in 0..result.dimensions {
599            csv.push_str(&format!(",dim{}", i + 1));
600        }
601        csv.push('\n');
602
603        for (entity, coords) in &result.coordinates {
604            csv.push_str(entity);
605            for coord in coords {
606                csv.push_str(&format!(",{}", coord));
607            }
608            csv.push('\n');
609        }
610
611        Ok(csv)
612    }
613}
614
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray_ext::array;

    /// PCA should reduce four orthogonal 4-D unit vectors to 2-D and report
    /// explained variance.
    #[test]
    fn test_pca_visualization() {
        let mut embeddings = HashMap::new();
        embeddings.insert("e1".to_string(), array![1.0, 0.0, 0.0, 0.0]);
        embeddings.insert("e2".to_string(), array![0.0, 1.0, 0.0, 0.0]);
        embeddings.insert("e3".to_string(), array![0.0, 0.0, 1.0, 0.0]);
        embeddings.insert("e4".to_string(), array![0.0, 0.0, 0.0, 1.0]);

        let mut visualizer = EmbeddingVisualizer::new(VisualizationConfig {
            method: ReductionMethod::PCA,
            target_dims: 2,
            ..Default::default()
        });
        let result = visualizer.visualize(&embeddings).unwrap();

        assert_eq!(result.coordinates.len(), 4);
        assert_eq!(result.dimensions, 2);
        assert!(result.explained_variance.is_some());
    }

    /// Random projection should map ten 100-D embeddings down to 3-D.
    #[test]
    fn test_random_projection() {
        let embeddings: HashMap<String, Array1<f32>> = (0..10)
            .map(|i| (format!("e{}", i), Array1::from_vec(vec![i as f32; 100])))
            .collect();

        let mut visualizer = EmbeddingVisualizer::new(VisualizationConfig {
            method: ReductionMethod::RandomProjection,
            target_dims: 3,
            ..Default::default()
        });
        let result = visualizer.visualize(&embeddings).unwrap();

        assert_eq!(result.coordinates.len(), 10);
        assert_eq!(result.dimensions, 3);
    }

    /// CSV export should emit a header row plus one row per entity.
    #[test]
    fn test_export_csv() {
        let mut coordinates = HashMap::new();
        coordinates.insert("e1".to_string(), vec![1.0, 2.0]);
        coordinates.insert("e2".to_string(), vec![3.0, 4.0]);

        let result = VisualizationResult {
            coordinates,
            dimensions: 2,
            method: ReductionMethod::PCA,
            explained_variance: None,
            final_loss: None,
        };

        let visualizer = EmbeddingVisualizer::new(VisualizationConfig::default());
        let csv = visualizer.export_csv(&result).unwrap();

        assert!(csv.contains("entity,dim1,dim2"));
        assert!(csv.contains("e1,1,2"));
    }
}