1use anyhow::{anyhow, Result};
7use scirs2_core::ndarray_ext::{Array1, Array2};
8use scirs2_core::random::prelude::{Normal, Random};
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use tracing::{debug, info};
12
/// Dimensionality-reduction algorithm used to project embeddings into 2D/3D.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ReductionMethod {
    /// Principal component analysis (linear; reports explained variance).
    PCA,
    /// t-distributed stochastic neighbor embedding (iterative, non-linear).
    TSNE,
    /// Approximate UMAP: PCA initialization refined by a k-NN attraction pass.
    UMAP,
    /// Gaussian random projection (fast, non-iterative).
    RandomProjection,
}
25
/// Configuration for dimensionality reduction of entity embeddings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VisualizationConfig {
    /// Reduction algorithm to apply.
    pub method: ReductionMethod,
    /// Output dimensionality; must be 2 or 3 (enforced by `visualize`).
    pub target_dims: usize,
    /// t-SNE perplexity.
    /// NOTE(review): used directly as the Gaussian bandwidth in
    /// `compute_affinities`, not via the standard per-point sigma search.
    pub tsne_perplexity: f32,
    /// Gradient-descent step size for t-SNE.
    pub tsne_learning_rate: f32,
    /// Number of gradient-descent iterations for t-SNE.
    pub max_iterations: usize,
    /// Optional RNG seed.
    /// NOTE(review): currently ignored — `EmbeddingVisualizer::new` always
    /// constructs `Random::default()`.
    pub random_seed: Option<u64>,
    /// Neighborhood size for the approximate-UMAP k-NN refinement.
    pub umap_n_neighbors: usize,
    /// UMAP minimum distance.
    /// NOTE(review): currently unused by the approximate UMAP implementation.
    pub umap_min_dist: f32,
}
46
impl Default for VisualizationConfig {
    /// Defaults: 2-D PCA with conventional t-SNE/UMAP hyperparameters
    /// (perplexity 30, learning rate 200, 1000 iterations, 15 neighbors).
    fn default() -> Self {
        Self {
            method: ReductionMethod::PCA,
            target_dims: 2,
            tsne_perplexity: 30.0,
            tsne_learning_rate: 200.0,
            max_iterations: 1000,
            random_seed: None,
            umap_n_neighbors: 15,
            umap_min_dist: 0.1,
        }
    }
}
61
/// Output of a dimensionality-reduction run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VisualizationResult {
    /// Low-dimensional coordinates keyed by entity name; each vector has
    /// `dimensions` entries.
    pub coordinates: HashMap<String, Vec<f32>>,
    /// Length of every coordinate vector (2 or 3).
    pub dimensions: usize,
    /// Method that produced these coordinates.
    pub method: ReductionMethod,
    /// Per-component explained-variance ratios (set by PCA; `None` otherwise).
    pub explained_variance: Option<Vec<f32>>,
    /// Final KL-divergence loss (set by t-SNE; `None` otherwise).
    pub final_loss: Option<f32>,
}
76
/// Projects high-dimensional entity embeddings into 2D/3D for visualization.
pub struct EmbeddingVisualizer {
    /// Reduction settings.
    config: VisualizationConfig,
    /// RNG used for random initialization and projection matrices.
    rng: Random,
}
82
83impl EmbeddingVisualizer {
84 pub fn new(config: VisualizationConfig) -> Self {
86 let rng = Random::default();
87
88 info!(
89 "Initialized embedding visualizer: method={:?}, target_dims={}",
90 config.method, config.target_dims
91 );
92
93 Self { config, rng }
94 }
95
96 pub fn visualize(
98 &mut self,
99 embeddings: &HashMap<String, Array1<f32>>,
100 ) -> Result<VisualizationResult> {
101 if embeddings.is_empty() {
102 return Err(anyhow!("No embeddings to visualize"));
103 }
104
105 if self.config.target_dims != 2 && self.config.target_dims != 3 {
106 return Err(anyhow!("Target dimensions must be 2 or 3"));
107 }
108
109 info!("Visualizing {} embeddings", embeddings.len());
110
111 match self.config.method {
112 ReductionMethod::PCA => self.pca(embeddings),
113 ReductionMethod::TSNE => self.tsne(embeddings),
114 ReductionMethod::UMAP => self.umap_approximate(embeddings),
115 ReductionMethod::RandomProjection => self.random_projection(embeddings),
116 }
117 }
118
119 fn pca(&mut self, embeddings: &HashMap<String, Array1<f32>>) -> Result<VisualizationResult> {
121 let entity_list: Vec<String> = embeddings.keys().cloned().collect();
122 let n = entity_list.len();
123 let d = embeddings
124 .values()
125 .next()
126 .expect("embeddings should not be empty")
127 .len();
128
129 let mut data_matrix = Array2::zeros((n, d));
131 for (i, entity) in entity_list.iter().enumerate() {
132 let emb = &embeddings[entity];
133 for j in 0..d {
134 data_matrix[[i, j]] = emb[j];
135 }
136 }
137
138 let mean = self.compute_mean(&data_matrix);
140 for i in 0..n {
141 for j in 0..d {
142 data_matrix[[i, j]] -= mean[j];
143 }
144 }
145
146 let cov_matrix = self.compute_covariance(&data_matrix);
148
149 let (eigenvectors, eigenvalues) =
151 self.power_iteration_top_k(&cov_matrix, self.config.target_dims)?;
152
153 let mut coordinates = HashMap::new();
155 for (i, entity) in entity_list.iter().enumerate() {
156 let mut projected = vec![0.0; self.config.target_dims];
157 for k in 0..self.config.target_dims {
158 let mut dot_product = 0.0;
159 for j in 0..d {
160 dot_product += data_matrix[[i, j]] * eigenvectors[[j, k]];
161 }
162 projected[k] = dot_product;
163 }
164 coordinates.insert(entity.clone(), projected);
165 }
166
167 let total_variance: f32 = eigenvalues.iter().sum();
169 let explained_variance: Vec<f32> =
170 eigenvalues.iter().map(|&ev| ev / total_variance).collect();
171
172 info!(
173 "PCA complete: explained variance = {:?}",
174 explained_variance
175 );
176
177 Ok(VisualizationResult {
178 coordinates,
179 dimensions: self.config.target_dims,
180 method: ReductionMethod::PCA,
181 explained_variance: Some(explained_variance),
182 final_loss: None,
183 })
184 }
185
186 fn tsne(&mut self, embeddings: &HashMap<String, Array1<f32>>) -> Result<VisualizationResult> {
188 let entity_list: Vec<String> = embeddings.keys().cloned().collect();
189 let n = entity_list.len();
190
191 let dist = Normal::new(0.0, 0.01)
193 .expect("Normal distribution should be valid for these parameters");
194 let mut y = Array2::from_shape_fn((n, self.config.target_dims), |_| self.rng.sample(dist));
195
196 let p = self.compute_affinities(embeddings, &entity_list);
198
199 let mut final_loss = 0.0;
201 for iteration in 0..self.config.max_iterations {
202 let q = self.compute_low_dim_affinities(&y);
204
205 let grad = self.compute_tsne_gradient(&y, &p, &q);
207
208 for i in 0..n {
210 for j in 0..self.config.target_dims {
211 y[[i, j]] -= self.config.tsne_learning_rate * grad[[i, j]];
212 }
213 }
214
215 if iteration % 100 == 0 {
217 final_loss = self.compute_kl_divergence(&p, &q);
218 debug!("t-SNE iteration {}: loss = {:.6}", iteration, final_loss);
219 }
220 }
221
222 let mut coordinates = HashMap::new();
224 for (i, entity) in entity_list.iter().enumerate() {
225 let mut coords = vec![0.0; self.config.target_dims];
226 for j in 0..self.config.target_dims {
227 coords[j] = y[[i, j]];
228 }
229 coordinates.insert(entity.clone(), coords);
230 }
231
232 info!("t-SNE complete: final loss = {:.6}", final_loss);
233
234 Ok(VisualizationResult {
235 coordinates,
236 dimensions: self.config.target_dims,
237 method: ReductionMethod::TSNE,
238 explained_variance: None,
239 final_loss: Some(final_loss),
240 })
241 }
242
243 fn umap_approximate(
245 &mut self,
246 embeddings: &HashMap<String, Array1<f32>>,
247 ) -> Result<VisualizationResult> {
248 info!("Using approximate UMAP (PCA + refinement)");
252
253 let mut result = self.pca(embeddings)?;
255
256 let entity_list: Vec<String> = embeddings.keys().cloned().collect();
258 let n = entity_list.len();
259
260 let knn_graph =
262 self.build_knn_graph(embeddings, &entity_list, self.config.umap_n_neighbors);
263
264 for _iteration in 0..100 {
266 for i in 0..n {
267 let entity = &entity_list[i];
268 let pos = &result.coordinates[entity].clone();
269
270 let mut force = vec![0.0; self.config.target_dims];
272 for &neighbor_idx in &knn_graph[i] {
273 let neighbor = &entity_list[neighbor_idx];
274 let neighbor_pos = &result.coordinates[neighbor];
275
276 for d in 0..self.config.target_dims {
277 let diff = neighbor_pos[d] - pos[d];
278 force[d] += diff * 0.01; }
280 }
281
282 let coords = result
284 .coordinates
285 .get_mut(entity)
286 .expect("entity should exist in coordinates");
287 for d in 0..self.config.target_dims {
288 coords[d] += force[d];
289 }
290 }
291 }
292
293 result.method = ReductionMethod::UMAP;
294 info!("Approximate UMAP complete");
295
296 Ok(result)
297 }
298
299 fn random_projection(
301 &mut self,
302 embeddings: &HashMap<String, Array1<f32>>,
303 ) -> Result<VisualizationResult> {
304 let entity_list: Vec<String> = embeddings.keys().cloned().collect();
305 let d = embeddings
306 .values()
307 .next()
308 .expect("embeddings should not be empty")
309 .len();
310
311 let dist = Normal::new(0.0, 1.0)
313 .expect("Normal distribution should be valid for these parameters");
314 let projection_matrix =
315 Array2::from_shape_fn((d, self.config.target_dims), |_| self.rng.sample(dist));
316
317 let mut coordinates = HashMap::new();
319 for entity in &entity_list {
320 let emb = &embeddings[entity];
321 let mut projected = vec![0.0; self.config.target_dims];
322
323 for k in 0..self.config.target_dims {
324 let mut dot_product = 0.0;
325 for j in 0..d {
326 dot_product += emb[j] * projection_matrix[[j, k]];
327 }
328 projected[k] = dot_product;
329 }
330
331 coordinates.insert(entity.clone(), projected);
332 }
333
334 info!("Random projection complete");
335
336 Ok(VisualizationResult {
337 coordinates,
338 dimensions: self.config.target_dims,
339 method: ReductionMethod::RandomProjection,
340 explained_variance: None,
341 final_loss: None,
342 })
343 }
344
345 fn compute_mean(&self, data: &Array2<f32>) -> Vec<f32> {
347 let n = data.nrows();
348 let d = data.ncols();
349 let mut mean = vec![0.0; d];
350
351 for j in 0..d {
352 for i in 0..n {
353 mean[j] += data[[i, j]];
354 }
355 mean[j] /= n as f32;
356 }
357
358 mean
359 }
360
361 fn compute_covariance(&self, data: &Array2<f32>) -> Array2<f32> {
363 let n = data.nrows() as f32;
364 let d = data.ncols();
365 let mut cov = Array2::zeros((d, d));
366
367 for i in 0..d {
368 for j in 0..d {
369 let mut sum = 0.0;
370 for k in 0..data.nrows() {
371 sum += data[[k, i]] * data[[k, j]];
372 }
373 cov[[i, j]] = sum / (n - 1.0);
374 }
375 }
376
377 cov
378 }
379
380 fn power_iteration_top_k(
382 &mut self,
383 matrix: &Array2<f32>,
384 k: usize,
385 ) -> Result<(Array2<f32>, Vec<f32>)> {
386 let d = matrix.nrows();
387 let mut eigenvectors = Array2::zeros((d, k));
388 let mut eigenvalues = Vec::new();
389
390 let mut working_matrix = matrix.clone();
391
392 for component in 0..k {
393 let dist = Normal::new(0.0f32, 1.0f32)
395 .expect("Normal distribution should be valid for these parameters");
396 let mut v = Array1::from_shape_fn(d, |_| self.rng.sample(dist));
397
398 for _ in 0..100 {
400 let mut new_v = Array1::<f32>::zeros(d);
402 for i in 0..d {
403 for j in 0..d {
404 new_v[i] += working_matrix[[i, j]] * v[j];
405 }
406 }
407
408 let norm = new_v.dot(&new_v).sqrt();
410 if norm > 0.0 {
411 v = new_v / norm;
412 }
413 }
414
415 let mut av = Array1::<f32>::zeros(d);
417 for i in 0..d {
418 for j in 0..d {
419 av[i] += working_matrix[[i, j]] * v[j];
420 }
421 }
422 let eigenvalue = v.dot(&av);
423 eigenvalues.push(eigenvalue);
424
425 for i in 0..d {
427 eigenvectors[[i, component]] = v[i];
428 }
429
430 for i in 0..d {
432 for j in 0..d {
433 working_matrix[[i, j]] -= eigenvalue * v[i] * v[j];
434 }
435 }
436 }
437
438 Ok((eigenvectors, eigenvalues))
439 }
440
441 fn compute_affinities(
443 &self,
444 embeddings: &HashMap<String, Array1<f32>>,
445 entity_list: &[String],
446 ) -> Array2<f32> {
447 let n = entity_list.len();
448 let mut p = Array2::zeros((n, n));
449
450 for i in 0..n {
452 for j in 0..n {
453 if i != j {
454 let dist = self.euclidean_distance(
455 &embeddings[&entity_list[i]],
456 &embeddings[&entity_list[j]],
457 );
458 p[[i, j]] = (-dist * dist / (2.0 * self.config.tsne_perplexity)).exp();
460 }
461 }
462
463 let row_sum: f32 = (0..n).map(|j| p[[i, j]]).sum();
465 if row_sum > 0.0 {
466 for j in 0..n {
467 p[[i, j]] /= row_sum;
468 }
469 }
470 }
471
472 for i in 0..n {
474 for j in 0..n {
475 p[[i, j]] = (p[[i, j]] + p[[j, i]]) / (2.0 * n as f32);
476 }
477 }
478
479 p
480 }
481
482 fn compute_low_dim_affinities(&self, y: &Array2<f32>) -> Array2<f32> {
484 let n = y.nrows();
485 let mut q = Array2::zeros((n, n));
486
487 for i in 0..n {
488 for j in 0..n {
489 if i != j {
490 let mut dist_sq = 0.0;
491 for k in 0..y.ncols() {
492 let diff = y[[i, k]] - y[[j, k]];
493 dist_sq += diff * diff;
494 }
495 q[[i, j]] = 1.0 / (1.0 + dist_sq);
496 }
497 }
498 }
499
500 let sum: f32 = q.iter().sum();
502 if sum > 0.0 {
503 q /= sum;
504 }
505
506 q
507 }
508
509 fn compute_tsne_gradient(
511 &self,
512 y: &Array2<f32>,
513 p: &Array2<f32>,
514 q: &Array2<f32>,
515 ) -> Array2<f32> {
516 let n = y.nrows();
517 let d = y.ncols();
518 let mut grad = Array2::zeros((n, d));
519
520 for i in 0..n {
521 for j in 0..n {
522 if i != j {
523 let pq_diff = p[[i, j]] - q[[i, j]];
524 let q_val = q[[i, j]];
525
526 for k in 0..d {
527 let y_diff = y[[i, k]] - y[[j, k]];
528 grad[[i, k]] += 4.0 * pq_diff * y_diff * q_val;
529 }
530 }
531 }
532 }
533
534 grad
535 }
536
537 fn compute_kl_divergence(&self, p: &Array2<f32>, q: &Array2<f32>) -> f32 {
539 let mut kl = 0.0;
540 for i in 0..p.nrows() {
541 for j in 0..p.ncols() {
542 if p[[i, j]] > 0.0 && q[[i, j]] > 0.0 {
543 kl += p[[i, j]] * (p[[i, j]] / q[[i, j]]).ln();
544 }
545 }
546 }
547 kl
548 }
549
550 fn euclidean_distance(&self, a: &Array1<f32>, b: &Array1<f32>) -> f32 {
552 let diff = a - b;
553 diff.dot(&diff).sqrt()
554 }
555
556 fn build_knn_graph(
558 &self,
559 embeddings: &HashMap<String, Array1<f32>>,
560 entity_list: &[String],
561 k: usize,
562 ) -> Vec<Vec<usize>> {
563 let n = entity_list.len();
564 let mut knn_graph = Vec::new();
565
566 for i in 0..n {
567 let entity = &entity_list[i];
568 let emb = &embeddings[entity];
569
570 let mut distances: Vec<(usize, f32)> = (0..n)
572 .filter(|&j| j != i)
573 .map(|j| {
574 let other_emb = &embeddings[&entity_list[j]];
575 let dist = self.euclidean_distance(emb, other_emb);
576 (j, dist)
577 })
578 .collect();
579
580 distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
582 let neighbors: Vec<usize> = distances.iter().take(k).map(|(idx, _)| *idx).collect();
583 knn_graph.push(neighbors);
584 }
585
586 knn_graph
587 }
588
589 pub fn export_json(&self, result: &VisualizationResult) -> Result<String> {
591 serde_json::to_string_pretty(result)
592 .map_err(|e| anyhow!("Failed to serialize visualization: {}", e))
593 }
594
595 pub fn export_csv(&self, result: &VisualizationResult) -> Result<String> {
597 let mut csv = String::from("entity");
598 for i in 0..result.dimensions {
599 csv.push_str(&format!(",dim{}", i + 1));
600 }
601 csv.push('\n');
602
603 for (entity, coords) in &result.coordinates {
604 csv.push_str(entity);
605 for coord in coords {
606 csv.push_str(&format!(",{}", coord));
607 }
608 csv.push('\n');
609 }
610
611 Ok(csv)
612 }
613}
614
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray_ext::array;

    #[test]
    fn test_pca_visualization() {
        // Four one-hot 4-D embeddings: a maximally spread toy dataset.
        let embeddings: HashMap<String, Array1<f32>> = [
            ("e1", array![1.0, 0.0, 0.0, 0.0]),
            ("e2", array![0.0, 1.0, 0.0, 0.0]),
            ("e3", array![0.0, 0.0, 1.0, 0.0]),
            ("e4", array![0.0, 0.0, 0.0, 1.0]),
        ]
        .into_iter()
        .map(|(name, emb)| (name.to_string(), emb))
        .collect();

        let mut visualizer = EmbeddingVisualizer::new(VisualizationConfig {
            method: ReductionMethod::PCA,
            target_dims: 2,
            ..Default::default()
        });
        let outcome = visualizer.visualize(&embeddings).unwrap();

        assert_eq!(outcome.coordinates.len(), 4);
        assert_eq!(outcome.dimensions, 2);
        assert!(outcome.explained_variance.is_some());
    }

    #[test]
    fn test_random_projection() {
        // Ten 100-D embeddings, each filled with a single constant value.
        let embeddings: HashMap<String, Array1<f32>> = (0..10)
            .map(|i| (format!("e{}", i), Array1::from_vec(vec![i as f32; 100])))
            .collect();

        let mut visualizer = EmbeddingVisualizer::new(VisualizationConfig {
            method: ReductionMethod::RandomProjection,
            target_dims: 3,
            ..Default::default()
        });
        let outcome = visualizer.visualize(&embeddings).unwrap();

        assert_eq!(outcome.coordinates.len(), 10);
        assert_eq!(outcome.dimensions, 3);
    }

    #[test]
    fn test_export_csv() {
        // Build a result by hand; export_csv never touches the RNG or config.
        let coordinates: HashMap<String, Vec<f32>> = [
            ("e1".to_string(), vec![1.0, 2.0]),
            ("e2".to_string(), vec![3.0, 4.0]),
        ]
        .into_iter()
        .collect();

        let outcome = VisualizationResult {
            coordinates,
            dimensions: 2,
            method: ReductionMethod::PCA,
            explained_variance: None,
            final_loss: None,
        };

        let visualizer = EmbeddingVisualizer::new(VisualizationConfig::default());
        let csv = visualizer.export_csv(&outcome).unwrap();

        assert!(csv.contains("entity,dim1,dim2"));
        assert!(csv.contains("e1,1,2"));
    }
}