//! Embedding visualization: project high-dimensional entity embeddings into
//! 2D/3D coordinates via PCA, t-SNE, approximate UMAP, or random projection.

use anyhow::{anyhow, Result};
use scirs2_core::ndarray_ext::{Array1, Array2};
use scirs2_core::random::prelude::{Normal, Random};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use tracing::{debug, info};

/// Dimensionality reduction method used for visualization.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ReductionMethod {
    /// Principal component analysis (linear).
    PCA,
    /// t-distributed stochastic neighbor embedding.
    TSNE,
    /// Approximate UMAP (PCA layout refined with a k-NN attraction pass).
    UMAP,
    /// Gaussian random projection.
    RandomProjection,
}

/// Configuration for embedding visualization.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VisualizationConfig {
    /// Dimensionality reduction method.
    pub method: ReductionMethod,
    /// Target output dimensionality (2 or 3).
    pub target_dims: usize,
    /// t-SNE perplexity.
    pub tsne_perplexity: f32,
    /// t-SNE learning rate.
    pub tsne_learning_rate: f32,
    /// Maximum number of optimization iterations.
    pub max_iterations: usize,
    /// Optional random seed (currently not applied; see `EmbeddingVisualizer::new`).
    pub random_seed: Option<u64>,
    /// UMAP: number of nearest neighbors.
    pub umap_n_neighbors: usize,
    /// UMAP: minimum distance in the low-dimensional layout (currently unused
    /// by the approximate implementation).
    pub umap_min_dist: f32,
}

impl Default for VisualizationConfig {
    fn default() -> Self {
        Self {
            method: ReductionMethod::PCA,
            target_dims: 2,
            tsne_perplexity: 30.0,
            tsne_learning_rate: 200.0,
            max_iterations: 1000,
            random_seed: None,
            umap_n_neighbors: 15,
            umap_min_dist: 0.1,
        }
    }
}

/// Result of a visualization run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VisualizationResult {
    /// Low-dimensional coordinates per entity.
    pub coordinates: HashMap<String, Vec<f32>>,
    /// Output dimensionality.
    pub dimensions: usize,
    /// Reduction method that produced these coordinates.
    pub method: ReductionMethod,
    /// Fraction of variance explained per component (PCA only).
    pub explained_variance: Option<Vec<f32>>,
    /// Final optimization loss (t-SNE only).
    pub final_loss: Option<f32>,
}

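/// Projects high-dimensional entity embeddings down to 2 or 3 dimensions for
/// plotting.
///
/// A minimal usage sketch (doc-test ignored; error handling elided):
///
/// ```ignore
/// use std::collections::HashMap;
/// use scirs2_core::ndarray_ext::Array1;
///
/// let mut embeddings = HashMap::new();
/// embeddings.insert("a".to_string(), Array1::from_vec(vec![1.0, 0.0, 0.0]));
/// embeddings.insert("b".to_string(), Array1::from_vec(vec![0.0, 1.0, 0.0]));
///
/// let mut visualizer = EmbeddingVisualizer::new(VisualizationConfig::default());
/// let result = visualizer.visualize(&embeddings)?;
/// println!("{}", visualizer.export_csv(&result)?);
/// ```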
pub struct EmbeddingVisualizer {
    config: VisualizationConfig,
    rng: Random,
}

impl EmbeddingVisualizer {
    /// Create a new visualizer with the given configuration.
    pub fn new(config: VisualizationConfig) -> Self {
        // Note: `config.random_seed` is not currently wired into the RNG.
        let rng = Random::default();

        info!(
            "Initialized embedding visualizer: method={:?}, target_dims={}",
            config.method, config.target_dims
        );

        Self { config, rng }
    }

    /// Reduce the given embeddings to `target_dims` dimensions using the
    /// configured method. Errors on empty input or if `target_dims` is not
    /// 2 or 3.
    pub fn visualize(
        &mut self,
        embeddings: &HashMap<String, Array1<f32>>,
    ) -> Result<VisualizationResult> {
        if embeddings.is_empty() {
            return Err(anyhow!("No embeddings to visualize"));
        }

        if self.config.target_dims != 2 && self.config.target_dims != 3 {
            return Err(anyhow!("Target dimensions must be 2 or 3"));
        }

        info!("Visualizing {} embeddings", embeddings.len());

        match self.config.method {
            ReductionMethod::PCA => self.pca(embeddings),
            ReductionMethod::TSNE => self.tsne(embeddings),
            ReductionMethod::UMAP => self.umap_approximate(embeddings),
            ReductionMethod::RandomProjection => self.random_projection(embeddings),
        }
    }

    fn pca(&mut self, embeddings: &HashMap<String, Array1<f32>>) -> Result<VisualizationResult> {
        let entity_list: Vec<String> = embeddings.keys().cloned().collect();
        let n = entity_list.len();
        let d = embeddings.values().next().unwrap().len();

        // Stack the embeddings into an (n x d) data matrix.
        let mut data_matrix = Array2::zeros((n, d));
        for (i, entity) in entity_list.iter().enumerate() {
            let emb = &embeddings[entity];
            for j in 0..d {
                data_matrix[[i, j]] = emb[j];
            }
        }

        // Center the data.
        let mean = self.compute_mean(&data_matrix);
        for i in 0..n {
            for j in 0..d {
                data_matrix[[i, j]] -= mean[j];
            }
        }

        let cov_matrix = self.compute_covariance(&data_matrix);

        // Top-k eigenvectors of the covariance matrix via power iteration.
        let (eigenvectors, eigenvalues) =
            self.power_iteration_top_k(&cov_matrix, self.config.target_dims)?;

        // Project the centered data onto the principal components.
        let mut coordinates = HashMap::new();
        for (i, entity) in entity_list.iter().enumerate() {
            let mut projected = vec![0.0; self.config.target_dims];
            for k in 0..self.config.target_dims {
                let mut dot_product = 0.0;
                for j in 0..d {
                    dot_product += data_matrix[[i, j]] * eigenvectors[[j, k]];
                }
                projected[k] = dot_product;
            }
            coordinates.insert(entity.clone(), projected);
        }

        // Explained variance per component: eigenvalue divided by the total
        // variance (the trace of the covariance matrix, i.e. the sum of all
        // eigenvalues), guarding against degenerate zero-variance data.
        let total_variance: f32 = (0..d).map(|i| cov_matrix[[i, i]]).sum();
        let explained_variance: Vec<f32> = if total_variance > 0.0 {
            eigenvalues.iter().map(|&ev| ev / total_variance).collect()
        } else {
            vec![0.0; eigenvalues.len()]
        };

        info!(
            "PCA complete: explained variance = {:?}",
            explained_variance
        );

        Ok(VisualizationResult {
            coordinates,
            dimensions: self.config.target_dims,
            method: ReductionMethod::PCA,
            explained_variance: Some(explained_variance),
            final_loss: None,
        })
    }

    fn tsne(&mut self, embeddings: &HashMap<String, Array1<f32>>) -> Result<VisualizationResult> {
        let entity_list: Vec<String> = embeddings.keys().cloned().collect();
        let n = entity_list.len();

        // Initialize the low-dimensional layout with small Gaussian noise.
        let dist = Normal::new(0.0, 0.01).unwrap();
        let mut y = Array2::from_shape_fn((n, self.config.target_dims), |_| self.rng.sample(dist));

        // High-dimensional pairwise affinities P.
        let p = self.compute_affinities(embeddings, &entity_list);

        let mut final_loss = 0.0;
        for iteration in 0..self.config.max_iterations {
            let q = self.compute_low_dim_affinities(&y);

            let grad = self.compute_tsne_gradient(&y, &p, &q);

            // Plain gradient descent (no momentum or early exaggeration).
            for i in 0..n {
                for j in 0..self.config.target_dims {
                    y[[i, j]] -= self.config.tsne_learning_rate * grad[[i, j]];
                }
            }

            // The loss is only sampled every 100 iterations; `final_loss`
            // holds the most recently sampled value.
            if iteration % 100 == 0 {
                final_loss = self.compute_kl_divergence(&p, &q);
                debug!("t-SNE iteration {}: loss = {:.6}", iteration, final_loss);
            }
        }

        let mut coordinates = HashMap::new();
        for (i, entity) in entity_list.iter().enumerate() {
            let mut coords = vec![0.0; self.config.target_dims];
            for j in 0..self.config.target_dims {
                coords[j] = y[[i, j]];
            }
            coordinates.insert(entity.clone(), coords);
        }

        info!("t-SNE complete: final loss = {:.6}", final_loss);

        Ok(VisualizationResult {
            coordinates,
            dimensions: self.config.target_dims,
            method: ReductionMethod::TSNE,
            explained_variance: None,
            final_loss: Some(final_loss),
        })
    }

    fn umap_approximate(
        &mut self,
        embeddings: &HashMap<String, Array1<f32>>,
    ) -> Result<VisualizationResult> {
        info!("Using approximate UMAP (PCA + refinement)");

        // Start from a PCA layout, then pull k-nearest neighbors together.
        // Note: `config.umap_min_dist` is not used by this approximation.
        let mut result = self.pca(embeddings)?;

        let entity_list: Vec<String> = embeddings.keys().cloned().collect();
        let n = entity_list.len();

        let knn_graph =
            self.build_knn_graph(embeddings, &entity_list, self.config.umap_n_neighbors);

        for _iteration in 0..100 {
            for i in 0..n {
                let entity = &entity_list[i];
                let pos = result.coordinates[entity].clone();

                // Accumulate an attractive force toward each neighbor.
                let mut force = vec![0.0; self.config.target_dims];
                for &neighbor_idx in &knn_graph[i] {
                    let neighbor = &entity_list[neighbor_idx];
                    let neighbor_pos = &result.coordinates[neighbor];

                    for d in 0..self.config.target_dims {
                        let diff = neighbor_pos[d] - pos[d];
                        force[d] += diff * 0.01;
                    }
                }

                let coords = result.coordinates.get_mut(entity).unwrap();
                for d in 0..self.config.target_dims {
                    coords[d] += force[d];
                }
            }
        }

        result.method = ReductionMethod::UMAP;
        info!("Approximate UMAP complete");

        Ok(result)
    }

    fn random_projection(
        &mut self,
        embeddings: &HashMap<String, Array1<f32>>,
    ) -> Result<VisualizationResult> {
        let entity_list: Vec<String> = embeddings.keys().cloned().collect();
        let d = embeddings.values().next().unwrap().len();

        // Gaussian random projection matrix. It is left unscaled; the
        // 1/sqrt(k) Johnson-Lindenstrauss factor only rescales the plot.
        let dist = Normal::new(0.0, 1.0).unwrap();
        let projection_matrix =
            Array2::from_shape_fn((d, self.config.target_dims), |_| self.rng.sample(dist));

        let mut coordinates = HashMap::new();
        for entity in &entity_list {
            let emb = &embeddings[entity];
            let mut projected = vec![0.0; self.config.target_dims];

            for k in 0..self.config.target_dims {
                let mut dot_product = 0.0;
                for j in 0..d {
                    dot_product += emb[j] * projection_matrix[[j, k]];
                }
                projected[k] = dot_product;
            }

            coordinates.insert(entity.clone(), projected);
        }

        info!("Random projection complete");

        Ok(VisualizationResult {
            coordinates,
            dimensions: self.config.target_dims,
            method: ReductionMethod::RandomProjection,
            explained_variance: None,
            final_loss: None,
        })
    }

    fn compute_mean(&self, data: &Array2<f32>) -> Vec<f32> {
        let n = data.nrows();
        let d = data.ncols();
        let mut mean = vec![0.0; d];

        for j in 0..d {
            for i in 0..n {
                mean[j] += data[[i, j]];
            }
            mean[j] /= n as f32;
        }

        mean
    }

    fn compute_covariance(&self, data: &Array2<f32>) -> Array2<f32> {
        let n = data.nrows() as f32;
        let d = data.ncols();
        let mut cov = Array2::zeros((d, d));

        // Sample covariance of the (already centered) data.
        for i in 0..d {
            for j in 0..d {
                let mut sum = 0.0;
                for k in 0..data.nrows() {
                    sum += data[[k, i]] * data[[k, j]];
                }
                cov[[i, j]] = sum / (n - 1.0);
            }
        }

        cov
    }

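    /// Top-k eigenpairs via power iteration with Hotelling deflation: find
    /// the dominant eigenvector of the working matrix, subtract
    /// `lambda * v * v^T`, and repeat for the next component.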
    fn power_iteration_top_k(
        &mut self,
        matrix: &Array2<f32>,
        k: usize,
    ) -> Result<(Array2<f32>, Vec<f32>)> {
        let d = matrix.nrows();
        let mut eigenvectors = Array2::zeros((d, k));
        let mut eigenvalues = Vec::new();

        let mut working_matrix = matrix.clone();

        for component in 0..k {
            // Random starting vector.
            let dist = Normal::new(0.0f32, 1.0f32).unwrap();
            let mut v = Array1::from_shape_fn(d, |_| self.rng.sample(dist));

            // Power iteration: repeatedly apply the matrix and normalize.
            for _ in 0..100 {
                let mut new_v = Array1::<f32>::zeros(d);
                for i in 0..d {
                    for j in 0..d {
                        new_v[i] += working_matrix[[i, j]] * v[j];
                    }
                }

                let norm = new_v.dot(&new_v).sqrt();
                if norm > 0.0 {
                    v = new_v / norm;
                }
            }

            // Rayleigh quotient v^T A v gives the eigenvalue estimate.
            let mut av = Array1::<f32>::zeros(d);
            for i in 0..d {
                for j in 0..d {
                    av[i] += working_matrix[[i, j]] * v[j];
                }
            }
            let eigenvalue = v.dot(&av);
            eigenvalues.push(eigenvalue);

            for i in 0..d {
                eigenvectors[[i, component]] = v[i];
            }

            // Deflate: remove the found component from the working matrix.
            for i in 0..d {
                for j in 0..d {
                    working_matrix[[i, j]] -= eigenvalue * v[i] * v[j];
                }
            }
        }

        Ok((eigenvectors, eigenvalues))
    }

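    /// High-dimensional affinities for t-SNE. The math implemented here is a
    /// simplification of the reference algorithm:
    ///
    ///   p(j|i) = exp(-||x_i - x_j||^2 / (2 sigma^2)) / sum_k exp(-||x_i - x_k||^2 / (2 sigma^2))
    ///   p_ij   = (p(j|i) + p(i|j)) / (2n)
    ///
    /// where `sigma^2` is taken directly from `tsne_perplexity` rather than
    /// tuned per point by the usual binary search on perplexity.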
    fn compute_affinities(
        &self,
        embeddings: &HashMap<String, Array1<f32>>,
        entity_list: &[String],
    ) -> Array2<f32> {
        let n = entity_list.len();
        let mut p = Array2::zeros((n, n));

        for i in 0..n {
            // Unnormalized Gaussian kernel (sigma^2 = tsne_perplexity, simplified).
            for j in 0..n {
                if i != j {
                    let dist = self.euclidean_distance(
                        &embeddings[&entity_list[i]],
                        &embeddings[&entity_list[j]],
                    );
                    p[[i, j]] = (-dist * dist / (2.0 * self.config.tsne_perplexity)).exp();
                }
            }

            // Row-normalize to conditional probabilities p(j|i).
            let row_sum: f32 = (0..n).map(|j| p[[i, j]]).sum();
            if row_sum > 0.0 {
                for j in 0..n {
                    p[[i, j]] /= row_sum;
                }
            }
        }

        // Symmetrize from the original conditionals. Walking the upper
        // triangle and writing both entries keeps the matrix symmetric
        // (updating in place over all (i, j) would reuse already-updated
        // values).
        for i in 0..n {
            for j in (i + 1)..n {
                let sym = (p[[i, j]] + p[[j, i]]) / (2.0 * n as f32);
                p[[i, j]] = sym;
                p[[j, i]] = sym;
            }
        }

        p
    }

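    /// Low-dimensional affinities: the Student-t kernel with one degree of
    /// freedom, q_ij = (1 + ||y_i - y_j||^2)^(-1) / Z, where Z normalizes
    /// over all ordered pairs (i, j), i != j.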
    fn compute_low_dim_affinities(&self, y: &Array2<f32>) -> Array2<f32> {
        let n = y.nrows();
        let mut q = Array2::zeros((n, n));

        for i in 0..n {
            for j in 0..n {
                if i != j {
                    let mut dist_sq = 0.0;
                    for k in 0..y.ncols() {
                        let diff = y[[i, k]] - y[[j, k]];
                        dist_sq += diff * diff;
                    }
                    q[[i, j]] = 1.0 / (1.0 + dist_sq);
                }
            }
        }

        // Normalize over all pairs.
        let sum: f32 = q.iter().sum();
        if sum > 0.0 {
            q /= sum;
        }

        q
    }

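    /// Gradient of KL(P || Q) with respect to the low-dimensional points:
    ///
    ///   dC/dy_i = 4 sum_j (p_ij - q_ij) (y_i - y_j) / (1 + ||y_i - y_j||^2)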
    fn compute_tsne_gradient(
        &self,
        y: &Array2<f32>,
        p: &Array2<f32>,
        q: &Array2<f32>,
    ) -> Array2<f32> {
        let n = y.nrows();
        let d = y.ncols();
        let mut grad = Array2::zeros((n, d));

        for i in 0..n {
            for j in 0..n {
                if i != j {
                    let pq_diff = p[[i, j]] - q[[i, j]];

                    // The gradient uses the unnormalized Student-t kernel
                    // (1 + ||y_i - y_j||^2)^(-1), not the normalized q_ij.
                    let mut dist_sq = 0.0;
                    for k in 0..d {
                        let diff = y[[i, k]] - y[[j, k]];
                        dist_sq += diff * diff;
                    }
                    let kernel = 1.0 / (1.0 + dist_sq);

                    for k in 0..d {
                        let y_diff = y[[i, k]] - y[[j, k]];
                        grad[[i, k]] += 4.0 * pq_diff * y_diff * kernel;
                    }
                }
            }
        }

        grad
    }

    fn compute_kl_divergence(&self, p: &Array2<f32>, q: &Array2<f32>) -> f32 {
        // KL(P || Q) = sum_ij p_ij * ln(p_ij / q_ij), skipping zero entries.
        let mut kl = 0.0;
        for i in 0..p.nrows() {
            for j in 0..p.ncols() {
                if p[[i, j]] > 0.0 && q[[i, j]] > 0.0 {
                    kl += p[[i, j]] * (p[[i, j]] / q[[i, j]]).ln();
                }
            }
        }
        kl
    }

    fn euclidean_distance(&self, a: &Array1<f32>, b: &Array1<f32>) -> f32 {
        let diff = a - b;
        diff.dot(&diff).sqrt()
    }

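    /// Brute-force k-nearest-neighbor graph: O(n^2) distance computations
    /// plus a per-row sort, which is acceptable at visualization scale.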
    fn build_knn_graph(
        &self,
        embeddings: &HashMap<String, Array1<f32>>,
        entity_list: &[String],
        k: usize,
    ) -> Vec<Vec<usize>> {
        let n = entity_list.len();
        let mut knn_graph = Vec::new();

        for i in 0..n {
            let entity = &entity_list[i];
            let emb = &embeddings[entity];

            // Distances from entity i to every other entity.
            let mut distances: Vec<(usize, f32)> = (0..n)
                .filter(|&j| j != i)
                .map(|j| {
                    let other_emb = &embeddings[&entity_list[j]];
                    let dist = self.euclidean_distance(emb, other_emb);
                    (j, dist)
                })
                .collect();

            // Keep the k nearest. `total_cmp` avoids a panic on NaN.
            distances.sort_by(|a, b| a.1.total_cmp(&b.1));
            let neighbors: Vec<usize> = distances.iter().take(k).map(|(idx, _)| *idx).collect();
            knn_graph.push(neighbors);
        }

        knn_graph
    }

    /// Serialize a visualization result as pretty-printed JSON.
    pub fn export_json(&self, result: &VisualizationResult) -> Result<String> {
        serde_json::to_string_pretty(result)
            .map_err(|e| anyhow!("Failed to serialize visualization: {}", e))
    }

    /// Serialize a visualization result as CSV with an `entity,dim1,...`
    /// header. Entity names are written verbatim (no CSV escaping).
    pub fn export_csv(&self, result: &VisualizationResult) -> Result<String> {
        let mut csv = String::from("entity");
        for i in 0..result.dimensions {
            csv.push_str(&format!(",dim{}", i + 1));
        }
        csv.push('\n');

        for (entity, coords) in &result.coordinates {
            csv.push_str(entity);
            for coord in coords {
                csv.push_str(&format!(",{}", coord));
            }
            csv.push('\n');
        }

        Ok(csv)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray_ext::array;

    #[test]
    fn test_pca_visualization() {
        let mut embeddings = HashMap::new();
        embeddings.insert("e1".to_string(), array![1.0, 0.0, 0.0, 0.0]);
        embeddings.insert("e2".to_string(), array![0.0, 1.0, 0.0, 0.0]);
        embeddings.insert("e3".to_string(), array![0.0, 0.0, 1.0, 0.0]);
        embeddings.insert("e4".to_string(), array![0.0, 0.0, 0.0, 1.0]);

        let config = VisualizationConfig {
            method: ReductionMethod::PCA,
            target_dims: 2,
            ..Default::default()
        };

        let mut visualizer = EmbeddingVisualizer::new(config);
        let result = visualizer.visualize(&embeddings).unwrap();

        assert_eq!(result.coordinates.len(), 4);
        assert_eq!(result.dimensions, 2);
        assert!(result.explained_variance.is_some());
    }

    #[test]
    fn test_random_projection() {
        let mut embeddings = HashMap::new();
        for i in 0..10 {
            let emb = Array1::from_vec(vec![i as f32; 100]);
            embeddings.insert(format!("e{}", i), emb);
        }

        let config = VisualizationConfig {
            method: ReductionMethod::RandomProjection,
            target_dims: 3,
            ..Default::default()
        };

        let mut visualizer = EmbeddingVisualizer::new(config);
        let result = visualizer.visualize(&embeddings).unwrap();

        assert_eq!(result.coordinates.len(), 10);
        assert_eq!(result.dimensions, 3);
    }

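    // Smoke test for the t-SNE path: with a tiny iteration budget this only
    // checks output shape and that a loss is reported, not embedding quality.
    #[test]
    fn test_tsne_smoke() {
        let mut embeddings = HashMap::new();
        for i in 0..5 {
            embeddings.insert(format!("e{}", i), Array1::from_vec(vec![i as f32, 1.0, 0.0]));
        }

        let config = VisualizationConfig {
            method: ReductionMethod::TSNE,
            max_iterations: 10,
            ..Default::default()
        };

        let mut visualizer = EmbeddingVisualizer::new(config);
        let result = visualizer.visualize(&embeddings).unwrap();

        assert_eq!(result.coordinates.len(), 5);
        assert_eq!(result.dimensions, 2);
        assert!(result.final_loss.is_some());
    }
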
    #[test]
    fn test_export_csv() {
        let mut coordinates = HashMap::new();
        coordinates.insert("e1".to_string(), vec![1.0, 2.0]);
        coordinates.insert("e2".to_string(), vec![3.0, 4.0]);

        let result = VisualizationResult {
            coordinates,
            dimensions: 2,
            method: ReductionMethod::PCA,
            explained_variance: None,
            final_loss: None,
        };

        let config = VisualizationConfig::default();
        let visualizer = EmbeddingVisualizer::new(config);
        let csv = visualizer.export_csv(&result).unwrap();

        assert!(csv.contains("entity,dim1,dim2"));
        assert!(csv.contains("e1,1,2"));
    }
}