use crate::config::UmapConfig;
use crate::distances::EuclideanMetric;
use crate::manifold::LearnedManifold;
use crate::metric::Metric;
use crate::optimizer::Optimizer;
use crate::umap::find_ab_params::find_ab_params;
use crate::umap::fuzzy_simplicial_set::FuzzySimplicialSet;
use crate::umap::raise_disconnected_warning::raise_disconnected_warning;
use dashmap::DashSet;
use ndarray::Array2;
use ndarray::ArrayView2;
use rayon::iter::IntoParallelIterator;
use rayon::iter::ParallelIterator;
use serde::Deserialize;
use serde::Serialize;
use std::time::Instant;
use tracing::info;
/// UMAP estimator: bundles the hyperparameters with the two distance metrics
/// used while learning the manifold and optimizing the embedding.
pub struct Umap {
/// Hyperparameters read during validation, graph construction and optimization.
config: UmapConfig,
/// Input-space metric; supplies the fallback disconnection threshold when the
/// configuration does not set `disconnection_distance`.
metric: Box<dyn Metric>,
/// Embedding-space metric; its `metric_type()` and the metric itself are
/// handed to the optimizer in `fit`.
output_metric: Box<dyn Metric>,
}
impl Umap {
    /// Creates a `Umap` that uses [`EuclideanMetric`] for both the input space
    /// and the output (embedding) space.
    pub fn new(config: UmapConfig) -> Self {
        Self::with_metrics(
            config,
            Box::new(EuclideanMetric),
            Box::new(EuclideanMetric),
        )
    }

    /// Creates a `Umap` with caller-supplied metrics.
    ///
    /// `metric` provides the default disconnection threshold used by
    /// [`Umap::learn_manifold`]; `output_metric` is passed to the optimizer in
    /// [`Umap::fit`].
    pub fn with_metrics(
        config: UmapConfig,
        metric: Box<dyn Metric>,
        output_metric: Box<dyn Metric>,
    ) -> Self {
        Self {
            config,
            metric,
            output_metric,
        }
    }

    /// Builds the fuzzy simplicial set from a precomputed k-nearest-neighbor
    /// graph and returns it together with the curve parameters `a` and `b`.
    ///
    /// Steps, in order:
    /// 1. validate the configuration and the KNN array shapes;
    /// 2. take `a`/`b` from the configuration when both are set, otherwise
    ///    derive them with `find_ab_params(spread, min_dist)`;
    /// 3. record every `(row, column)` of `knn_dists` whose distance is at or
    ///    above the disconnection distance (configured value, else the input
    ///    metric's `disconnection_threshold()`);
    /// 4. run [`FuzzySimplicialSet`];
    /// 5. count graph rows whose stored weights sum to zero and report them
    ///    through `raise_disconnected_warning`.
    ///
    /// `data` is used only for its row count; distances are taken from
    /// `knn_dists`.
    ///
    /// # Panics
    ///
    /// Panics when the configuration or the KNN array shapes are invalid:
    /// `set_op_mix_ratio` outside `[0.0, 1.0]`, `n_neighbors < 2`, a negative
    /// `repulsion_strength`, `min_dist` or `learning_rate`,
    /// `min_dist > spread`, `n_components < 1`, mismatched KNN shapes, or
    /// `n_samples <= n_neighbors`. NaN values for the float parameters are
    /// rejected as well.
    pub fn learn_manifold(
        &self,
        data: ArrayView2<f32>,
        knn_indices: ArrayView2<u32>,
        knn_dists: ArrayView2<f32>,
    ) -> LearnedManifold {
        let n_samples = data.shape()[0];
        self.validate_parameters(n_samples, &knn_indices, &knn_dists);

        let manifold_cfg = &self.config.manifold;
        let graph_cfg = &self.config.graph;

        // Explicit `a`/`b` win only when both are provided.
        let (a, b) = match (manifold_cfg.a, manifold_cfg.b) {
            (Some(a_val), Some(b_val)) => (a_val, b_val),
            _ => find_ab_params(manifold_cfg.spread, manifold_cfg.min_dist),
        };

        let disconnection_distance = graph_cfg
            .disconnection_distance
            .unwrap_or_else(|| self.metric.disconnection_threshold());

        // Collect the (row, column) positions of neighbors that are too far
        // away to count as connected. Rows are scanned in parallel, hence the
        // concurrent set.
        let started = Instant::now();
        let knn_disconnections = DashSet::new();
        (0..n_samples).into_par_iter().for_each(|row_no| {
            let dists = knn_dists.row(row_no);
            for (col_no, &dist) in dists.iter().enumerate() {
                if dist >= disconnection_distance {
                    knn_disconnections.insert((row_no, col_no));
                }
            }
        });
        let edges_removed = knn_disconnections.len();
        info!(
            duration_ms = started.elapsed().as_millis(),
            edges_removed, "disconnection detection complete"
        );

        info!(
            n_samples,
            n_neighbors = graph_cfg.n_neighbors,
            "starting fuzzy simplicial set"
        );
        let started = Instant::now();
        let (graph, sigmas, rhos) = FuzzySimplicialSet::builder()
            .n_samples(n_samples)
            .n_neighbors(graph_cfg.n_neighbors)
            .knn_indices(knn_indices)
            .knn_dists(knn_dists)
            .knn_disconnections(&knn_disconnections)
            .local_connectivity(graph_cfg.local_connectivity)
            .set_op_mix_ratio(graph_cfg.set_op_mix_ratio)
            .apply_set_operations(graph_cfg.symmetrize)
            .build()
            .exec();
        info!(
            duration_ms = started.elapsed().as_millis(),
            "fuzzy simplicial set complete"
        );

        // A vertex counts as disconnected when all weights stored in its row
        // sum to exactly zero.
        let vertices_disconnected = graph
            .outer_iterator()
            .filter(|row| {
                let weight_sum: f32 = row.data().iter().sum();
                weight_sum == 0.0
            })
            .count();
        // NOTE(review): the trailing 0.1 is presumably the warning threshold
        // understood by `raise_disconnected_warning`; confirm against that helper.
        raise_disconnected_warning(
            edges_removed,
            vertices_disconnected,
            disconnection_distance,
            n_samples,
            0.1,
        );

        LearnedManifold {
            graph,
            sigmas,
            rhos,
            n_vertices: n_samples,
            a,
            b,
        }
    }

    /// Learns the manifold and optimizes `init` into the final embedding.
    ///
    /// The number of epochs comes from the configuration; when unset it
    /// defaults to 500 for at most 10 000 samples and 200 otherwise. After
    /// optimization, every vertex whose graph row sums to zero has its
    /// embedding row overwritten with NaN.
    ///
    /// # Panics
    ///
    /// Panics if `init` does not have `n_components` columns or does not have
    /// one row per row of `data`, and under every condition listed for
    /// [`Umap::learn_manifold`].
    pub fn fit(
        &self,
        data: ArrayView2<f32>,
        knn_indices: ArrayView2<u32>,
        knn_dists: ArrayView2<f32>,
        init: ArrayView2<f32>,
    ) -> FittedUmap {
        let n_samples = data.shape()[0];
        let (init_rows, init_cols) = init.dim();
        assert!(
            init_cols == self.config.n_components,
            "init has {} components but n_components is {}",
            init_cols,
            self.config.n_components
        );
        assert!(
            init_rows == n_samples,
            "init has {} samples but data has {} samples",
            init_rows,
            n_samples
        );

        let manifold = self.learn_manifold(data, knn_indices, knn_dists);

        let total_epochs = self
            .config
            .optimization
            .n_epochs
            .unwrap_or(if n_samples <= 10_000 { 500 } else { 200 });

        let metric_type = self.output_metric.metric_type();
        let mut optimizer = Optimizer::new(
            manifold,
            init.to_owned(),
            total_epochs,
            &self.config,
            metric_type,
        );
        optimizer.step_epochs(total_epochs, self.output_metric.as_ref());

        let mut fitted = optimizer.into_fitted(self.config.clone());
        // Disconnected vertices have no meaningful position: mark their whole
        // embedding row as NaN.
        for (i, row) in fitted.manifold.graph.outer_iterator().enumerate() {
            let weight_sum: f32 = row.data().iter().sum();
            if weight_sum == 0.0 {
                fitted.embedding.row_mut(i).fill(f32::NAN);
            }
        }
        fitted
    }

    /// Checks the hyperparameters and the KNN array shapes before any work is
    /// done.
    ///
    /// Float checks are written in positive form (`value >= 0.0`, range
    /// `contains`) so that NaN, which compares false against everything, is
    /// rejected instead of slipping through.
    ///
    /// # Panics
    ///
    /// Panics with a descriptive message on the first violated precondition.
    fn validate_parameters(
        &self,
        n_samples: usize,
        knn_indices: &ArrayView2<u32>,
        knn_dists: &ArrayView2<f32>,
    ) {
        let graph = &self.config.graph;
        let manifold = &self.config.manifold;
        let optimization = &self.config.optimization;

        assert!(
            (0.0..=1.0).contains(&graph.set_op_mix_ratio),
            "set_op_mix_ratio must be between 0.0 and 1.0, got {}",
            graph.set_op_mix_ratio
        );
        assert!(
            graph.n_neighbors >= 2,
            "n_neighbors must be >= 2, got {}",
            graph.n_neighbors
        );
        assert!(
            optimization.repulsion_strength >= 0.0,
            "repulsion_strength cannot be negative, got {}",
            optimization.repulsion_strength
        );
        assert!(
            manifold.min_dist <= manifold.spread,
            "min_dist ({}) must be <= spread ({})",
            manifold.min_dist,
            manifold.spread
        );
        assert!(
            manifold.min_dist >= 0.0,
            "min_dist cannot be negative, got {}",
            manifold.min_dist
        );
        // NOTE: zero is accepted even though the message says "positive"; both
        // the check and the message are kept as-is for backward compatibility.
        assert!(
            optimization.learning_rate >= 0.0,
            "learning_rate must be positive, got {}",
            optimization.learning_rate
        );
        assert!(
            self.config.n_components >= 1,
            "n_components must be >= 1, got {}",
            self.config.n_components
        );
        assert!(
            knn_dists.shape() == knn_indices.shape(),
            "knn_dists and knn_indices must have the same shape, got {:?} vs {:?}",
            knn_dists.shape(),
            knn_indices.shape()
        );
        assert!(
            knn_dists.shape()[1] == graph.n_neighbors,
            "knn_dists has {} neighbors but n_neighbors is {}",
            knn_dists.shape()[1],
            graph.n_neighbors
        );
        assert!(
            knn_dists.shape()[0] == n_samples,
            "knn_dists has {} samples but data has {} samples",
            knn_dists.shape()[0],
            n_samples
        );
        assert!(
            n_samples > graph.n_neighbors,
            "Number of samples ({}) must be > n_neighbors ({})",
            n_samples,
            graph.n_neighbors
        );
    }
}
/// Result of fitting UMAP: the embedding together with the learned manifold
/// and a configuration, exposed through the read-only accessors on this type.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FittedUmap {
/// Low-dimensional coordinates as a 2-D `f32` array.
pub(crate) embedding: Array2<f32>,
/// The learned manifold (graph and related parameters).
pub(crate) manifold: LearnedManifold,
/// Configuration stored alongside the fit.
// NOTE(review): presumably the configuration the model was fitted with — confirm
// against the code that constructs `FittedUmap`.
pub(crate) config: UmapConfig,
}
impl FittedUmap {
    /// Returns a borrowed, read-only view of the embedding.
    pub fn embedding(&self) -> ArrayView2<'_, f32> {
        let Self { embedding, .. } = self;
        embedding.view()
    }

    /// Consumes the fitted model and returns the owned embedding array.
    pub fn into_embedding(self) -> Array2<f32> {
        let Self { embedding, .. } = self;
        embedding
    }

    /// Returns the learned manifold stored in this model.
    pub fn manifold(&self) -> &LearnedManifold {
        let Self { manifold, .. } = self;
        manifold
    }

    /// Returns the configuration stored in this model.
    pub fn config(&self) -> &UmapConfig {
        let Self { config, .. } = self;
        config
    }

    /// Placeholder for projecting new points into the existing embedding.
    ///
    /// # Panics
    ///
    /// Always panics: the transform is not implemented yet.
    pub fn transform(
        &self,
        new_data: ArrayView2<f32>,
        new_knn_indices: ArrayView2<u32>,
        new_knn_dists: ArrayView2<f32>,
    ) -> Array2<f32> {
        // The arguments are part of the intended signature but unused until the
        // transform exists; discard them explicitly to avoid unused warnings.
        let _ = (new_data, new_knn_indices, new_knn_dists);
        todo!("Transform not yet implemented - contributions welcome!")
    }
}