use scirs2_core::ndarray::{s, Array1, Array2, ArrayView1, ArrayView2, Axis};
use scirs2_core::numeric::{Float, FromPrimitive};
use std::collections::HashMap;
use std::fmt::Debug;
use serde::{Deserialize, Serialize};
use crate::error::{ClusteringError, Result};
pub mod animation;
pub mod export;
pub mod interactive;
pub use animation::{
AnimationFrame, ConvergenceInfo, IterativeAnimationConfig, IterativeAnimationRecorder,
StreamingConfig, StreamingFrame, StreamingStats, StreamingVisualizer,
};
pub use export::{
export_animation_to_file, export_scatter_2d_to_file, export_scatter_2d_to_html,
export_scatter_2d_to_json, export_scatter_3d_to_file, export_scatter_3d_to_html,
export_scatter_3d_to_json, save_visualization_to_file, ExportConfig, ExportFormat,
};
pub use interactive::{
BoundingBox3D, CameraState, ClusterStats, InteractiveConfig, InteractiveState,
InteractiveVisualizer, KeyCode, MouseButton, ViewMode,
};
/// Options controlling how clustering results are rendered.
#[derive(Debug, Clone)]
pub struct VisualizationConfig {
    /// Palette used to color clusters.
    pub color_scheme: ColorScheme,
    /// Marker size for each plotted point.
    pub point_size: f32,
    /// Marker opacity; presumably in `[0.0, 1.0]` — not validated here.
    pub point_opacity: f32,
    /// Whether cluster centroids are drawn.
    pub show_centroids: bool,
    /// Whether cluster boundaries are drawn.
    pub show_boundaries: bool,
    /// Shape used when boundaries are drawn.
    pub boundary_type: BoundaryType,
    /// Whether interactive controls are enabled.
    pub interactive: bool,
    /// Optional animation settings; `None` disables animation.
    pub animation: Option<AnimationConfig>,
    /// Projection applied before plotting high-dimensional data.
    pub dimensionality_reduction: DimensionalityReduction,
}
/// Built-in color palettes for cluster rendering.
///
/// Concrete hex values for each scheme are defined in `generate_cluster_colors`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ColorScheme {
    /// Default 10-color palette.
    Default,
    /// Palette chosen for colorblind accessibility.
    ColorblindFriendly,
    /// High-contrast palette (black/white/primaries).
    HighContrast,
    /// Soft pastel palette.
    Pastel,
    /// Viridis-style sequential palette.
    Viridis,
    /// Plasma-style sequential palette.
    Plasma,
    /// User-supplied palette; currently maps to a single fallback color.
    Custom,
}
/// Shape used to outline a cluster when boundaries are enabled.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BoundaryType {
    /// Convex hull of the cluster's points.
    ConvexHull,
    /// Best-fit ellipse.
    Ellipse,
    /// Alpha shape (concave hull).
    AlphaShape,
    /// Draw no boundary.
    None,
}
/// Projection method applied to high-dimensional data before plotting.
///
/// NOTE(review): in this module TSNE/UMAP/MDS currently fall back to PCA
/// (see `apply_tsne_2d` and friends).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DimensionalityReduction {
    /// Principal component analysis.
    PCA,
    /// t-distributed stochastic neighbor embedding.
    TSNE,
    /// Uniform manifold approximation and projection.
    UMAP,
    /// Multidimensional scaling.
    MDS,
    /// Take the first two feature columns verbatim.
    First2D,
    /// Take the first three feature columns verbatim.
    First3D,
    /// Use the data as-is; requires it to already match the target dimension.
    None,
}
/// Settings for animated visualizations.
#[derive(Debug, Clone)]
pub struct AnimationConfig {
    /// Total animation duration in milliseconds.
    pub duration_ms: u32,
    /// Number of frames to render over the duration.
    pub frames: u32,
    /// Easing curve applied to frame timing.
    pub easing: EasingFunction,
    /// Whether the animation restarts after finishing.
    pub loop_animation: bool,
}
/// Easing curves for animation timing.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EasingFunction {
    /// Constant speed.
    Linear,
    /// Slow start, fast finish.
    EaseIn,
    /// Fast start, slow finish.
    EaseOut,
    /// Slow start and finish.
    EaseInOut,
    /// Bouncing settle at the end.
    Bounce,
    /// Spring-like overshoot.
    Elastic,
}
impl Default for VisualizationConfig {
    // Defaults: default palette, PCA projection, centroids shown, boundaries
    // hidden (convex hull when enabled), interactivity on, no animation.
    fn default() -> Self {
        Self {
            color_scheme: ColorScheme::Default,
            point_size: 5.0,
            point_opacity: 0.8,
            show_centroids: true,
            show_boundaries: false,
            boundary_type: BoundaryType::ConvexHull,
            interactive: true,
            animation: None,
            dimensionality_reduction: DimensionalityReduction::PCA,
        }
    }
}
/// Fully resolved 2-D scatter plot, ready for export or rendering.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScatterPlot2D {
    /// Point coordinates, shape `(n_samples, 2)`.
    pub points: Array2<f64>,
    /// Cluster label per point (same length as `points` rows).
    pub labels: Array1<i32>,
    /// Projected centroid coordinates, if centroids were supplied.
    pub centroids: Option<Array2<f64>>,
    /// Hex color string per point.
    pub colors: Vec<String>,
    /// Marker size per point.
    pub sizes: Vec<f32>,
    /// Optional text label per point.
    pub point_labels: Option<Vec<String>>,
    /// Padded axis bounds as `(x_min, x_max, y_min, y_max)`.
    pub bounds: (f64, f64, f64, f64),
    /// One legend entry per cluster.
    pub legend: Vec<LegendEntry>,
}
/// Fully resolved 3-D scatter plot, ready for export or rendering.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScatterPlot3D {
    /// Point coordinates, shape `(n_samples, 3)`.
    pub points: Array2<f64>,
    /// Cluster label per point (same length as `points` rows).
    pub labels: Array1<i32>,
    /// Projected centroid coordinates, if centroids were supplied.
    pub centroids: Option<Array2<f64>>,
    /// Hex color string per point.
    pub colors: Vec<String>,
    /// Marker size per point.
    pub sizes: Vec<f32>,
    /// Optional text label per point.
    pub point_labels: Option<Vec<String>>,
    /// Padded axis bounds as `(x_min, x_max, y_min, y_max, z_min, z_max)`.
    pub bounds: (f64, f64, f64, f64, f64, f64),
    /// One legend entry per cluster.
    pub legend: Vec<LegendEntry>,
}
/// One row of a plot legend, describing a single cluster.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LegendEntry {
    /// Cluster id this entry describes.
    pub cluster_id: i32,
    /// Hex color assigned to the cluster.
    pub color: String,
    /// Display text (e.g. "Cluster 3").
    pub label: String,
    /// Number of points belonging to the cluster.
    pub count: usize,
}
/// Polyline outlining one cluster's boundary.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClusterBoundary {
    /// Cluster id the boundary belongs to.
    pub cluster_id: i32,
    /// Boundary vertex coordinates.
    pub boundary_points: Array2<f64>,
    /// Boundary kind as free text — NOTE(review): consider using
    /// `BoundaryType` instead of a string; verify serialization needs first.
    pub boundary_type: String,
    /// Hex color for drawing the boundary.
    pub color: String,
}
/// Builds a 2-D scatter-plot description for clustered data.
///
/// Input coordinates are used verbatim when they are already 2-D and no
/// reduction is requested; otherwise they (and the centroids, if any) are
/// projected to 2-D with the configured reduction method.
///
/// # Errors
///
/// Returns `ClusteringError::InvalidInput` when `labels` and `data` disagree
/// on the number of samples, or when the projection fails.
#[allow(dead_code)]
pub fn create_scatter_plot_2d<F: Float + FromPrimitive + Debug>(
    data: ArrayView2<F>,
    labels: &Array1<i32>,
    centroids: Option<&Array2<F>>,
    config: &VisualizationConfig,
) -> Result<ScatterPlot2D> {
    if labels.len() != data.nrows() {
        return Err(ClusteringError::InvalidInput(
            "Number of labels must match number of samples".to_string(),
        ));
    }
    let no_reduction = config.dimensionality_reduction == DimensionalityReduction::None;
    // Raw coordinates are usable only for 2-D input with reduction disabled.
    let points = if data.ncols() == 2 && no_reduction {
        data.mapv(|v| v.to_f64().unwrap_or(0.0))
    } else {
        apply_dimensionality_reduction_2d(data, config.dimensionality_reduction)?
    };
    // NOTE(review): centroids are projected independently of the sample data,
    // so data-dependent reductions may not place them in the same basis as
    // the points — confirm whether aligned projections are required.
    let centroid_points = match centroids {
        Some(c) if c.ncols() == 2 && no_reduction => Some(c.mapv(|v| v.to_f64().unwrap_or(0.0))),
        Some(c) => Some(apply_dimensionality_reduction_2d(
            c.view(),
            config.dimensionality_reduction,
        )?),
        None => None,
    };
    // Sorted, de-duplicated cluster ids.
    let unique_labels: Vec<i32> = labels
        .iter()
        .cloned()
        .collect::<std::collections::BTreeSet<i32>>()
        .into_iter()
        .collect();
    let cluster_colors = generate_cluster_colors(&unique_labels, config.color_scheme);
    // Per-point color, falling back to black for labels missing from the map.
    let point_colors: Vec<String> = labels
        .iter()
        .map(|&id| {
            cluster_colors
                .get(&id)
                .cloned()
                .unwrap_or_else(|| "#000000".to_string())
        })
        .collect();
    let legend = create_legend(&unique_labels, &cluster_colors, labels);
    let bounds = calculate_2d_bounds(&points);
    Ok(ScatterPlot2D {
        points,
        labels: labels.clone(),
        centroids: centroid_points,
        colors: point_colors,
        sizes: vec![config.point_size; labels.len()],
        point_labels: None,
        bounds,
        legend,
    })
}
/// Builds a 3-D scatter-plot description for clustered data.
///
/// Input coordinates are used verbatim when they are already 3-D and no
/// reduction is requested; otherwise they (and the centroids, if any) are
/// projected to 3-D with the configured reduction method.
///
/// # Errors
///
/// Returns `ClusteringError::InvalidInput` when `labels` and `data` disagree
/// on the number of samples, or when the projection fails.
#[allow(dead_code)]
pub fn create_scatter_plot_3d<F: Float + FromPrimitive + Debug>(
    data: ArrayView2<F>,
    labels: &Array1<i32>,
    centroids: Option<&Array2<F>>,
    config: &VisualizationConfig,
) -> Result<ScatterPlot3D> {
    if labels.len() != data.nrows() {
        return Err(ClusteringError::InvalidInput(
            "Number of labels must match number of samples".to_string(),
        ));
    }
    let no_reduction = config.dimensionality_reduction == DimensionalityReduction::None;
    // Raw coordinates are usable only for 3-D input with reduction disabled.
    let points = if data.ncols() == 3 && no_reduction {
        data.mapv(|v| v.to_f64().unwrap_or(0.0))
    } else {
        apply_dimensionality_reduction_3d(data, config.dimensionality_reduction)?
    };
    // NOTE(review): centroids are projected independently of the sample data,
    // so data-dependent reductions may not place them in the same basis as
    // the points — confirm whether aligned projections are required.
    let centroid_points = match centroids {
        Some(c) if c.ncols() == 3 && no_reduction => Some(c.mapv(|v| v.to_f64().unwrap_or(0.0))),
        Some(c) => Some(apply_dimensionality_reduction_3d(
            c.view(),
            config.dimensionality_reduction,
        )?),
        None => None,
    };
    // Sorted, de-duplicated cluster ids.
    let unique_labels: Vec<i32> = labels
        .iter()
        .cloned()
        .collect::<std::collections::BTreeSet<i32>>()
        .into_iter()
        .collect();
    let cluster_colors = generate_cluster_colors(&unique_labels, config.color_scheme);
    // Per-point color, falling back to black for labels missing from the map.
    let point_colors: Vec<String> = labels
        .iter()
        .map(|&id| {
            cluster_colors
                .get(&id)
                .cloned()
                .unwrap_or_else(|| "#000000".to_string())
        })
        .collect();
    let legend = create_legend(&unique_labels, &cluster_colors, labels);
    let bounds = calculate_3d_bounds(&points);
    Ok(ScatterPlot3D {
        points,
        labels: labels.clone(),
        centroids: centroid_points,
        colors: point_colors,
        sizes: vec![config.point_size; labels.len()],
        point_labels: None,
        bounds,
        legend,
    })
}
/// Converts `data` to `f64` and projects it to two columns with `method`.
///
/// `First2D` truncates to the first two columns, `None` requires the data to
/// already be 2-D, and every other method (including any not handled
/// explicitly) currently routes through a PCA-style projection.
///
/// # Errors
///
/// Returns `ClusteringError::InvalidInput` when the data's width is
/// incompatible with `First2D` or `None`.
#[allow(dead_code)]
fn apply_dimensionality_reduction_2d<F: Float + FromPrimitive + Debug>(
    data: ArrayView2<F>,
    method: DimensionalityReduction,
) -> Result<Array2<f64>> {
    let numeric = data.mapv(|v| v.to_f64().unwrap_or(0.0));
    let width = numeric.ncols();
    match method {
        DimensionalityReduction::First2D if width >= 2 => {
            Ok(numeric.slice(s![.., 0..2]).to_owned())
        }
        DimensionalityReduction::First2D => Err(ClusteringError::InvalidInput(
            "Data must have at least 2 dimensions for First2D".to_string(),
        )),
        DimensionalityReduction::None if width == 2 => Ok(numeric),
        DimensionalityReduction::None => Err(ClusteringError::InvalidInput(
            "Data must be 2D when no dimensionality reduction is specified".to_string(),
        )),
        DimensionalityReduction::TSNE => apply_tsne_2d(&numeric),
        DimensionalityReduction::UMAP => apply_umap_2d(&numeric),
        DimensionalityReduction::MDS => apply_mds_2d(&numeric),
        // PCA, plus any 3-D-only variant, falls back to 2-D PCA.
        _ => apply_pca_2d(&numeric),
    }
}
/// Converts `data` to `f64` and projects it to three columns with `method`.
///
/// `First3D` truncates to the first three columns, `None` requires the data
/// to already be 3-D, and every other method currently routes through a
/// PCA-style projection.
///
/// # Errors
///
/// Returns `ClusteringError::InvalidInput` when the data's width is
/// incompatible with `First3D` or `None`.
#[allow(dead_code)]
fn apply_dimensionality_reduction_3d<F: Float + FromPrimitive + Debug>(
    data: ArrayView2<F>,
    method: DimensionalityReduction,
) -> Result<Array2<f64>> {
    let numeric = data.mapv(|v| v.to_f64().unwrap_or(0.0));
    let width = numeric.ncols();
    match method {
        DimensionalityReduction::First3D if width >= 3 => {
            Ok(numeric.slice(s![.., 0..3]).to_owned())
        }
        DimensionalityReduction::First3D => Err(ClusteringError::InvalidInput(
            "Data must have at least 3 dimensions for First3D".to_string(),
        )),
        DimensionalityReduction::None if width == 3 => Ok(numeric),
        DimensionalityReduction::None => Err(ClusteringError::InvalidInput(
            "Data must be 3D when no dimensionality reduction is specified".to_string(),
        )),
        // PCA, TSNE, UMAP, MDS and 2-D-only variants fall back to 3-D PCA.
        _ => apply_pca_3d(&numeric),
    }
}
#[allow(dead_code)]
fn apply_pca_2d(data: &Array2<f64>) -> Result<Array2<f64>> {
let n_samples = data.nrows();
let n_features = data.ncols();
if n_features < 2 {
return Err(ClusteringError::InvalidInput(
"Need at least 2 features for PCA".to_string(),
));
}
let mean = data.mean_axis(Axis(0)).expect("Operation failed");
let centered = data - &mean;
let cov = centered.t().dot(¢ered) / (n_samples - 1) as f64;
let n_features = centered.ncols();
let eigenvectors_ = Array2::eye(n_features)
.slice(s![.., 0..2.min(n_features)])
.to_owned();
let projected = centered.dot(&eigenvectors_);
Ok(projected)
}
#[allow(dead_code)]
fn apply_pca_3d(data: &Array2<f64>) -> Result<Array2<f64>> {
let n_samples = data.nrows();
let n_features = data.ncols();
if n_features < 3 {
return Err(ClusteringError::InvalidInput(
"Need at least 3 features for 3D PCA".to_string(),
));
}
let mean = data.mean_axis(Axis(0)).expect("Operation failed");
let centered = data - &mean;
let cov = centered.t().dot(¢ered) / (n_samples - 1) as f64;
let n_features = centered.ncols();
let eigenvectors_ = Array2::eye(n_features)
.slice(s![.., 0..3.min(n_features)])
.to_owned();
let projected = centered.dot(&eigenvectors_);
Ok(projected)
}
/// t-SNE projection to 2-D.
///
/// NOTE(review): placeholder — falls back to `apply_pca_2d`; no actual t-SNE
/// embedding is computed.
#[allow(dead_code)]
fn apply_tsne_2d(data: &Array2<f64>) -> Result<Array2<f64>> {
    apply_pca_2d(data)
}
/// UMAP projection to 2-D.
///
/// NOTE(review): placeholder — falls back to `apply_pca_2d`; no actual UMAP
/// embedding is computed.
#[allow(dead_code)]
fn apply_umap_2d(data: &Array2<f64>) -> Result<Array2<f64>> {
    apply_pca_2d(data)
}
/// MDS projection to 2-D.
///
/// NOTE(review): placeholder — falls back to `apply_pca_2d`; no actual MDS
/// embedding is computed.
#[allow(dead_code)]
fn apply_mds_2d(data: &Array2<f64>) -> Result<Array2<f64>> {
    apply_pca_2d(data)
}
/// Extracts the `num_components` largest eigenpairs of a symmetric `matrix`
/// via power iteration with Gram–Schmidt deflation.
///
/// The orthogonalization against previously found eigenvectors must happen
/// inside the power loop: the original code orthogonalized only the initial
/// guess, so rounding error re-introduced the dominant direction and every
/// component after the first collapsed back onto the first eigenvector.
///
/// Returns `(eigenvectors, eigenvalues)` where eigenvectors are the columns
/// of an `n x k` matrix, `k = num_components.min(n)`.
#[allow(dead_code)]
fn compute_top_eigenvectors(
    matrix: &Array2<f64>,
    num_components: usize,
) -> Result<(Array2<f64>, Array1<f64>)> {
    let n = matrix.nrows();
    let k = num_components.min(n);
    let mut eigenvectors = Array2::zeros((n, k));
    let mut eigenvalues = Array1::zeros(k);
    for i in 0..k {
        let mut v = Array1::from_elem(n, 1.0 / (n as f64).sqrt());
        for _ in 0..100 {
            let mut w = matrix.dot(&v);
            // Project out every previously found eigenvector on each step so
            // the iteration stays in their orthogonal complement.
            for j in 0..i {
                let prev_eigenvector = eigenvectors.column(j);
                let dot_product = w.dot(&prev_eigenvector);
                w = &w - &(&prev_eigenvector * dot_product);
            }
            let norm = (w.dot(&w)).sqrt();
            if norm > 1e-10 {
                v = w / norm;
            } else {
                // Remaining spectrum is (numerically) zero; keep current v.
                break;
            }
        }
        // Rayleigh quotient gives the eigenvalue estimate for v.
        eigenvalues[i] = v.dot(&matrix.dot(&v));
        for j in 0..n {
            eigenvectors[[j, i]] = v[j];
        }
    }
    Ok((eigenvectors, eigenvalues))
}
/// Maps each distinct label to a hex color drawn from `scheme`'s palette.
///
/// Colors are assigned in first-appearance order and wrap around when there
/// are more clusters than palette entries. The palette index advances once
/// per *distinct* label: the original code indexed by slice position, so any
/// duplicate entries in `labels` caused palette colors to be skipped.
#[allow(dead_code)]
fn generate_cluster_colors(labels: &[i32], scheme: ColorScheme) -> HashMap<i32, String> {
    let color_palette = match scheme {
        ColorScheme::Default => vec![
            "#1f77b4", "#ff7f0e", "#2ca02c", "#d62728", "#9467bd", "#8c564b", "#e377c2", "#7f7f7f",
            "#bcbd22", "#17becf",
        ],
        ColorScheme::ColorblindFriendly => vec![
            "#377eb8", "#4daf4a", "#984ea3", "#ff7f00", "#ffff33", "#a65628", "#f781bf", "#999999",
        ],
        ColorScheme::HighContrast => vec![
            "#000000", "#ffffff", "#ff0000", "#00ff00", "#0000ff", "#ffff00", "#ff00ff", "#00ffff",
        ],
        ColorScheme::Pastel => vec![
            "#aec7e8", "#ffbb78", "#98df8a", "#ff9896", "#c5b0d5", "#c49c94", "#f7b6d3", "#c7c7c7",
            "#dbdb8d", "#9edae5",
        ],
        ColorScheme::Viridis => vec![
            "#440154", "#482777", "#3f4a8a", "#31678e", "#26838f", "#1f9d8a", "#6cce5a", "#b6de2b",
            "#fee825",
        ],
        ColorScheme::Plasma => vec![
            "#0c0887", "#5302a3", "#8b0aa5", "#b83289", "#db5c68", "#f48849", "#febd2a", "#f0f921",
        ],
        ColorScheme::Custom => vec!["#333333"],
    };
    let mut colors: HashMap<i32, String> = HashMap::new();
    for &label in labels {
        // One palette slot per distinct cluster id, wrapping when exhausted.
        let next_color = color_palette[colors.len() % color_palette.len()];
        colors
            .entry(label)
            .or_insert_with(|| next_color.to_string());
    }
    colors
}
/// Axis-aligned bounds `(x_min, x_max, y_min, y_max)` of 2-D points, expanded
/// by 5% of each axis range on both sides.
///
/// Returns the unit square for degenerate input: no rows, or fewer than two
/// columns. The original `is_empty()` guard only caught a zero element count,
/// so a non-empty single-column array reached `data.column(1)` and panicked.
#[allow(dead_code)]
fn calculate_2d_bounds(data: &Array2<f64>) -> (f64, f64, f64, f64) {
    if data.nrows() == 0 || data.ncols() < 2 {
        return (0.0, 1.0, 0.0, 1.0);
    }
    // Min/max of one column; the input is non-empty so the folds see values.
    let extent = |axis: usize| {
        let column = data.column(axis);
        let lo = column.iter().fold(f64::INFINITY, |acc, &v| acc.min(v));
        let hi = column.iter().fold(f64::NEG_INFINITY, |acc, &v| acc.max(v));
        (lo, hi)
    };
    let (x_min, x_max) = extent(0);
    let (y_min, y_max) = extent(1);
    let padding = 0.05;
    (
        x_min - (x_max - x_min) * padding,
        x_max + (x_max - x_min) * padding,
        y_min - (y_max - y_min) * padding,
        y_max + (y_max - y_min) * padding,
    )
}
/// Axis-aligned bounds `(x_min, x_max, y_min, y_max, z_min, z_max)` of 3-D
/// points, expanded by 5% of each axis range on both sides.
///
/// Returns the unit cube for degenerate input: no rows, or fewer than three
/// columns. The original `is_empty()` guard only caught a zero element count,
/// so a non-empty array with one or two columns panicked in `data.column(_)`.
#[allow(dead_code)]
fn calculate_3d_bounds(data: &Array2<f64>) -> (f64, f64, f64, f64, f64, f64) {
    if data.nrows() == 0 || data.ncols() < 3 {
        return (0.0, 1.0, 0.0, 1.0, 0.0, 1.0);
    }
    // Min/max of one column; the input is non-empty so the folds see values.
    let extent = |axis: usize| {
        let column = data.column(axis);
        let lo = column.iter().fold(f64::INFINITY, |acc, &v| acc.min(v));
        let hi = column.iter().fold(f64::NEG_INFINITY, |acc, &v| acc.max(v));
        (lo, hi)
    };
    let (x_min, x_max) = extent(0);
    let (y_min, y_max) = extent(1);
    let (z_min, z_max) = extent(2);
    let padding = 0.05;
    (
        x_min - (x_max - x_min) * padding,
        x_max + (x_max - x_min) * padding,
        y_min - (y_max - y_min) * padding,
        y_max + (y_max - y_min) * padding,
        z_min - (z_max - z_min) * padding,
        z_max + (z_max - z_min) * padding,
    )
}
/// Builds one `LegendEntry` per label in `labels`, counting occurrences in
/// `data_labels`, and returns the entries sorted by cluster id.
///
/// A label missing from `colors` gets the black fallback color.
#[allow(dead_code)]
fn create_legend(
    labels: &[i32],
    colors: &HashMap<i32, String>,
    data_labels: &Array1<i32>,
) -> Vec<LegendEntry> {
    let mut legend: Vec<LegendEntry> = labels
        .iter()
        .map(|&cluster_id| LegendEntry {
            cluster_id,
            color: colors
                .get(&cluster_id)
                .cloned()
                .unwrap_or_else(|| "#000000".to_string()),
            label: format!("Cluster {}", cluster_id),
            count: data_labels.iter().filter(|&&l| l == cluster_id).count(),
        })
        .collect();
    legend.sort_by_key(|entry| entry.cluster_id);
    legend
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    // Default config (PCA reduction) on already-2-D data: shapes, color count,
    // and one legend entry per distinct label.
    #[test]
    fn test_create_scatter_plot_2d() {
        let data = Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
            .expect("Operation failed");
        let labels = Array1::from_vec(vec![0, 0, 1, 1]);
        let config = VisualizationConfig::default();
        let plot =
            create_scatter_plot_2d(data.view(), &labels, None, &config).expect("Operation failed");
        assert_eq!(plot.points.nrows(), 4);
        assert_eq!(plot.points.ncols(), 2);
        assert_eq!(plot.labels.len(), 4);
        assert_eq!(plot.colors.len(), 4);
        assert_eq!(plot.legend.len(), 2);
    }

    // Same sanity checks for the 3-D entry point.
    #[test]
    fn test_create_scatter_plot_3d() {
        let data = Array2::from_shape_vec(
            (4, 3),
            vec![
                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
            ],
        )
        .expect("Operation failed");
        let labels = Array1::from_vec(vec![0, 0, 1, 1]);
        let config = VisualizationConfig::default();
        let plot =
            create_scatter_plot_3d(data.view(), &labels, None, &config).expect("Operation failed");
        assert_eq!(plot.points.nrows(), 4);
        assert_eq!(plot.points.ncols(), 3);
        assert_eq!(plot.labels.len(), 4);
    }

    // PCA reduction of 5-feature data produces the requested output width.
    // Only shapes are asserted, not projection values.
    #[test]
    fn test_dimensionality_reduction() {
        let data = Array2::from_shape_vec((10, 5), (0..50).map(|x| x as f64).collect())
            .expect("Operation failed");
        let result_2d =
            apply_dimensionality_reduction_2d(data.view(), DimensionalityReduction::PCA)
                .expect("Operation failed");
        assert_eq!(result_2d.ncols(), 2);
        let result_3d =
            apply_dimensionality_reduction_3d(data.view(), DimensionalityReduction::PCA)
                .expect("Operation failed");
        assert_eq!(result_3d.ncols(), 3);
    }

    // Each distinct label receives a color entry.
    #[test]
    fn test_color_generation() {
        let labels = vec![0, 1, 2];
        let colors = generate_cluster_colors(&labels, ColorScheme::Default);
        assert_eq!(colors.len(), 3);
        assert!(colors.contains_key(&0));
        assert!(colors.contains_key(&1));
        assert!(colors.contains_key(&2));
    }
}