use scirs2_core::ndarray::{Array2, ArrayView2};
use scirs2_core::numeric::{Float, FromPrimitive};
use std::collections::HashMap;
use std::fmt::Debug;
use crate::error::{ClusteringError, Result};
use crate::hierarchy::{LinkageMethod, Metric};
/// Symmetric pairwise-distance matrix stored in coordinate (triplet) form.
///
/// Entry `idx` describes one unordered pair: `rows[idx]`, `cols[idx]` are the
/// two point indices and `data[idx]` is their distance. The constructors and
/// `add_distance` in this file always store the smaller index in `rows`.
/// Pairs that have no stored entry are reported as `default_value`.
#[derive(Debug, Clone)]
pub struct SparseDistanceMatrix<F: Float> {
/// First point index of every stored pair.
rows: Vec<usize>,
/// Second point index of every stored pair (parallel to `rows`).
cols: Vec<usize>,
/// Distance of every stored pair (parallel to `rows` and `cols`).
data: Vec<F>,
/// Number of points the matrix describes; valid indices are `0..n_samples`.
n_samples: usize,
/// Distance reported for pairs without a stored entry.
default_value: F,
}
impl<F: Float + FromPrimitive> SparseDistanceMatrix<F> {
    /// Creates an empty matrix describing `n_samples` points.
    ///
    /// Pairs that are never stored report `default_value` from
    /// [`Self::get_distance`] and [`Self::to_dense`].
    pub fn new(n_samples: usize, default_value: F) -> Self {
        Self {
            rows: Vec::new(),
            cols: Vec::new(),
            data: Vec::new(),
            n_samples,
            default_value,
        }
    }

    /// Builds a sparse matrix from a dense one.
    ///
    /// Only the strict upper triangle (`j > i`) of `dense` is read. An entry is
    /// stored when it is strictly greater than `threshold`; every other pair
    /// falls back to a default value of zero.
    ///
    /// # Panics
    ///
    /// The number of samples is taken from the first axis, so indexing panics
    /// if `dense` has fewer columns than rows.
    pub fn from_dense(dense: ArrayView2<F>, threshold: F) -> Self {
        let n_samples = dense.shape()[0];
        let mut matrix = Self::new(n_samples, F::zero());
        for i in 0..n_samples {
            for j in (i + 1)..n_samples {
                let distance = dense[[i, j]];
                if distance > threshold {
                    matrix.rows.push(i);
                    matrix.cols.push(j);
                    matrix.data.push(distance);
                }
            }
        }
        matrix
    }

    /// Records the distance between points `i` and `j`.
    ///
    /// The pair is unordered. If it is already stored, the smaller of the
    /// existing and the new distance is kept. A self-pair (`i == j`) is stored
    /// as given, although [`Self::get_distance`] always reports zero for it.
    ///
    /// # Errors
    ///
    /// Returns `ClusteringError::InvalidInput` when either index is not below
    /// the number of samples.
    pub fn add_distance(&mut self, i: usize, j: usize, distance: F) -> Result<()> {
        if i >= self.n_samples || j >= self.n_samples {
            return Err(ClusteringError::InvalidInput("Index out of bounds".into()));
        }
        let (row, col) = Self::ordered_pair(i, j);
        match self.position_of(row, col) {
            Some(idx) => {
                if distance < self.data[idx] {
                    self.data[idx] = distance;
                }
            }
            None => {
                self.rows.push(row);
                self.cols.push(col);
                self.data.push(distance);
            }
        }
        Ok(())
    }

    /// Returns the distance between `i` and `j`.
    ///
    /// Yields zero when `i == j`, the stored value when the pair is present,
    /// and the default value otherwise. Indices are not range-checked.
    pub fn get_distance(&self, i: usize, j: usize) -> F {
        if i == j {
            return F::zero();
        }
        let (row, col) = Self::ordered_pair(i, j);
        self.position_of(row, col)
            .map_or(self.default_value, |idx| self.data[idx])
    }

    /// Returns every stored neighbor of `point` whose distance is at most
    /// `maxdistance`, in storage order, as `(neighbor, distance)` pairs.
    pub fn neighbors_within_distance(&self, point: usize, maxdistance: F) -> Vec<(usize, F)> {
        self.incident_entries(point)
            .filter(|&(_, distance)| distance <= maxdistance)
            .collect()
    }

    /// Returns up to `k` stored neighbors of `point`, closest first.
    ///
    /// Only stored pairs are considered, so fewer than `k` results are
    /// returned when `point` has fewer stored neighbors. NaN distances sort
    /// after every other value instead of aborting the sort.
    pub fn k_nearest_neighbors(&self, point: usize, k: usize) -> Vec<(usize, F)> {
        let mut all_neighbors: Vec<(usize, F)> = self.incident_entries(point).collect();
        all_neighbors.sort_by(|a, b| Self::compare_distances(&a.1, &b.1));
        all_neighbors.truncate(k);
        all_neighbors
    }

    /// Expands the matrix into a dense symmetric `n_samples x n_samples` array.
    ///
    /// Cells start at the default value, the diagonal is zeroed, and then each
    /// stored entry is written to both `[i, j]` and `[j, i]`.
    pub fn to_dense(&self) -> Array2<F> {
        let mut dense = Array2::from_elem((self.n_samples, self.n_samples), self.default_value);
        for i in 0..self.n_samples {
            dense[[i, i]] = F::zero();
        }
        for ((&i, &j), &distance) in self.rows.iter().zip(&self.cols).zip(&self.data) {
            dense[[i, j]] = distance;
            dense[[j, i]] = distance;
        }
        dense
    }

    /// Number of explicitly stored pairs.
    pub fn nnz(&self) -> usize {
        self.data.len()
    }

    /// Fraction of the `n * (n - 1) / 2` possible pairs that are not stored.
    ///
    /// With fewer than two samples there are no possible pairs; the matrix is
    /// then reported as fully sparse (`1.0`) rather than dividing by zero.
    pub fn sparsity(&self) -> f64 {
        // saturating_sub keeps n_samples == 0 from underflowing.
        let total_entries = self.n_samples * self.n_samples.saturating_sub(1) / 2;
        if total_entries == 0 {
            return 1.0;
        }
        1.0 - (self.nnz() as f64 / total_entries as f64)
    }

    /// Number of points the matrix describes.
    pub fn n_samples(&self) -> usize {
        self.n_samples
    }

    /// Orders an unordered pair so the smaller index comes first.
    fn ordered_pair(i: usize, j: usize) -> (usize, usize) {
        if i < j {
            (i, j)
        } else {
            (j, i)
        }
    }

    /// Linear search for the storage slot of the already-ordered pair `(row, col)`.
    fn position_of(&self, row: usize, col: usize) -> Option<usize> {
        self.rows
            .iter()
            .zip(&self.cols)
            .position(|(&r, &c)| r == row && c == col)
    }

    /// Iterates over `(neighbor, distance)` for every stored entry touching `point`.
    fn incident_entries(&self, point: usize) -> impl Iterator<Item = (usize, F)> + '_ {
        self.rows
            .iter()
            .zip(&self.cols)
            .zip(&self.data)
            .filter_map(move |((&row, &col), &distance)| {
                if row == point {
                    Some((col, distance))
                } else if col == point {
                    Some((row, distance))
                } else {
                    None
                }
            })
    }

    /// Total order on distances: the usual order, with NaN placed last.
    fn compare_distances(a: &F, b: &F) -> std::cmp::Ordering {
        a.partial_cmp(b)
            .unwrap_or_else(|| a.is_nan().cmp(&b.is_nan()))
    }
}
/// Hierarchical clustering driven by a [`SparseDistanceMatrix`].
///
/// `fit` builds a spanning tree over the stored distances and converts its
/// edges into a linkage matrix.
pub struct SparseHierarchicalClustering<F: Float> {
/// Pairwise distances the clustering operates on.
sparse_matrix: SparseDistanceMatrix<F>,
/// Requested linkage; the spanning-tree edges are sorted by distance only
/// for `LinkageMethod::Single`.
linkage_method: LinkageMethod,
}
impl<F: Float + FromPrimitive + Debug + PartialOrd> SparseHierarchicalClustering<F> {
pub fn new(sparse_matrix: SparseDistanceMatrix<F>, linkage_method: LinkageMethod) -> Self {
Self {
sparse_matrix,
linkage_method,
}
}
pub fn fit(&self) -> Result<Array2<F>> {
let n_samples = self.sparse_matrix.n_samples();
if n_samples < 2 {
return Err(ClusteringError::InvalidInput(
"Need at least 2 samples for clustering".into(),
));
}
let mst_edges = self.minimum_spanning_tree()?;
self.mst_to_linkage(mst_edges)
}
fn minimum_spanning_tree(&self) -> Result<Vec<(usize, usize, F)>> {
let n_samples = self.sparse_matrix.n_samples();
let mut mst_edges = Vec::new();
let mut visited = vec![false; n_samples];
let mut min_edge: HashMap<usize, (usize, F)> = HashMap::new();
visited[0] = true;
for neighbor_idx in 0..self.sparse_matrix.rows.len() {
let (i, j) = (
self.sparse_matrix.rows[neighbor_idx],
self.sparse_matrix.cols[neighbor_idx],
);
let distance = self.sparse_matrix.data[neighbor_idx];
if i == 0 && !visited[j] {
min_edge.insert(j, (i, distance));
} else if j == 0 && !visited[i] {
min_edge.insert(i, (j, distance));
}
}
for _ in 1..n_samples {
let mut min_dist = F::infinity();
let mut min_vertex = 0;
let mut min_parent = 0;
for (&vertex, &(parent, distance)) in &min_edge {
if !visited[vertex] && distance < min_dist {
min_dist = distance;
min_vertex = vertex;
min_parent = parent;
}
}
if min_dist == F::infinity() {
min_dist = self.sparse_matrix.default_value;
}
mst_edges.push((min_parent, min_vertex, min_dist));
visited[min_vertex] = true;
for neighbor_idx in 0..self.sparse_matrix.rows.len() {
let (i, j) = (
self.sparse_matrix.rows[neighbor_idx],
self.sparse_matrix.cols[neighbor_idx],
);
let distance = self.sparse_matrix.data[neighbor_idx];
let (from_vertex, to_vertex) = if i == min_vertex && !visited[j] {
(i, j)
} else if j == min_vertex && !visited[i] {
(j, i)
} else {
continue;
};
match min_edge.get(&to_vertex) {
Some(&(_, current_dist)) if distance < current_dist => {
min_edge.insert(to_vertex, (from_vertex, distance));
}
None => {
min_edge.insert(to_vertex, (from_vertex, distance));
}
_ => {}
}
}
}
Ok(mst_edges)
}
fn mst_to_linkage(&self, mut mst_edges: Vec<(usize, usize, F)>) -> Result<Array2<F>> {
let n_samples = self.sparse_matrix.n_samples();
match self.linkage_method {
LinkageMethod::Single => {
mst_edges.sort_by(|a, b| a.2.partial_cmp(&b.2).expect("Operation failed"));
}
_ => {
}
}
let mut linkage_matrix = Array2::zeros((n_samples - 1, 4));
let mut cluster_map: HashMap<usize, usize> = HashMap::new();
let mut next_cluster_id = n_samples;
for i in 0..n_samples {
cluster_map.insert(i, i);
}
for (step, (i, j, distance)) in mst_edges.iter().enumerate() {
let cluster_i = cluster_map[i];
let cluster_j = cluster_map[j];
linkage_matrix[[step, 0]] = F::from(cluster_i).expect("Failed to convert to float");
linkage_matrix[[step, 1]] = F::from(cluster_j).expect("Failed to convert to float");
linkage_matrix[[step, 2]] = *distance;
linkage_matrix[[step, 3]] = F::from(2).expect("Failed to convert constant to float");
cluster_map.insert(*i, next_cluster_id);
cluster_map.insert(*j, next_cluster_id);
next_cluster_id += 1;
}
Ok(linkage_matrix)
}
}
/// Builds a sparse k-nearest-neighbor distance graph from row-wise samples.
///
/// For every row of `data` the distances to all other rows are computed with
/// `metric`, and the `k` smallest are stored. Storage is symmetric, so a point
/// may end up with more than `k` stored neighbors when it is among the nearest
/// neighbors of other points. Unstored pairs report infinity.
///
/// NaN distances (for example from NaN input values) sort after every other
/// distance, so they are only stored when fewer than `k` comparable distances
/// exist.
///
/// # Errors
///
/// Returns `ClusteringError::InvalidInput` when `k` is not smaller than the
/// number of samples, or when `metric` is neither `Metric::Euclidean` nor
/// `Metric::Manhattan`.
#[allow(dead_code)]
pub fn sparse_knn_graph<F>(
    data: ArrayView2<F>,
    k: usize,
    metric: Metric,
) -> Result<SparseDistanceMatrix<F>>
where
    F: Float + FromPrimitive + Debug,
{
    let n_samples = data.shape()[0];
    let n_features = data.shape()[1];
    if k >= n_samples {
        return Err(ClusteringError::InvalidInput(
            "k must be less than number of samples".into(),
        ));
    }

    // Distance between rows `i` and `j` under the requested metric.
    let pair_distance = |i: usize, j: usize| -> Result<F> {
        match metric {
            Metric::Euclidean => {
                let squared = (0..n_features).fold(F::zero(), |sum, feature| {
                    let diff = data[[i, feature]] - data[[j, feature]];
                    sum + diff * diff
                });
                Ok(squared.sqrt())
            }
            Metric::Manhattan => Ok((0..n_features).fold(F::zero(), |sum, feature| {
                sum + (data[[i, feature]] - data[[j, feature]]).abs()
            })),
            _ => Err(ClusteringError::InvalidInput(
                "Metric not yet supported for sparse KNN".into(),
            )),
        }
    };

    let mut sparse_matrix = SparseDistanceMatrix::new(n_samples, F::infinity());
    for i in 0..n_samples {
        let mut distances: Vec<(usize, F)> = Vec::with_capacity(n_samples - 1);
        for j in (0..n_samples).filter(|&j| j != i) {
            distances.push((j, pair_distance(i, j)?));
        }
        // Total order (NaN last) so that NaN distances cannot abort the sort.
        distances.sort_by(|a, b| {
            a.1.partial_cmp(&b.1)
                .unwrap_or_else(|| a.1.is_nan().cmp(&b.1.is_nan()))
        });
        distances.truncate(k);
        for (neighbor, distance) in distances {
            sparse_matrix.add_distance(i, neighbor, distance)?;
        }
    }
    Ok(sparse_matrix)
}
/// Builds a sparse epsilon-neighborhood distance graph from row-wise samples.
///
/// Every unordered pair of rows in `data` is measured with `metric`; the pair
/// is stored when its distance is at most `epsilon`. Unstored pairs report
/// infinity.
///
/// # Errors
///
/// Returns `ClusteringError::InvalidInput` when a pair has to be measured with
/// a metric other than `Metric::Euclidean` or `Metric::Manhattan`.
#[allow(dead_code)]
pub fn sparse_epsilon_graph<F>(
    data: ArrayView2<F>,
    epsilon: F,
    metric: Metric,
) -> Result<SparseDistanceMatrix<F>>
where
    F: Float + FromPrimitive + Debug,
{
    let (n_samples, n_features) = (data.shape()[0], data.shape()[1]);
    let mut graph = SparseDistanceMatrix::new(n_samples, F::infinity());

    for first in 0..n_samples {
        for second in (first + 1)..n_samples {
            let separation = match metric {
                Metric::Euclidean => (0..n_features)
                    .map(|feature| data[[first, feature]] - data[[second, feature]])
                    .fold(F::zero(), |acc, delta| acc + delta * delta)
                    .sqrt(),
                Metric::Manhattan => (0..n_features)
                    .map(|feature| (data[[first, feature]] - data[[second, feature]]).abs())
                    .fold(F::zero(), |acc, magnitude| acc + magnitude),
                _ => {
                    return Err(ClusteringError::InvalidInput(
                        "Metric not yet supported for sparse epsilon graph".into(),
                    ));
                }
            };
            // Written as `<=` on purpose: a NaN distance compares false and is skipped.
            if separation <= epsilon {
                graph.add_distance(first, second, separation)?;
            }
        }
    }

    Ok(graph)
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Builds an `f64` matrix with the given default and `(i, j, distance)` entries.
    fn matrix_with_entries(
        n_samples: usize,
        default_value: f64,
        entries: &[(usize, usize, f64)],
    ) -> SparseDistanceMatrix<f64> {
        let mut matrix = SparseDistanceMatrix::new(n_samples, default_value);
        for &(i, j, distance) in entries {
            matrix
                .add_distance(i, j, distance)
                .expect("Operation failed");
        }
        matrix
    }

    /// A fresh matrix stores nothing and is fully sparse.
    #[test]
    fn test_sparse_distance_matrix_creation() {
        let empty = SparseDistanceMatrix::<f64>::new(5, 0.0);
        assert_eq!(empty.n_samples(), 5);
        assert_eq!(empty.nnz(), 0);
        assert_eq!(empty.sparsity(), 1.0);
    }

    /// Stored pairs are symmetric; unstored pairs report the default.
    #[test]
    fn test_sparse_distance_matrix_add_distance() {
        let matrix = matrix_with_entries(3, 0.0, &[(0, 1, 2.0), (1, 2, 3.0)]);
        assert_eq!(matrix.get_distance(0, 1), 2.0);
        assert_eq!(matrix.get_distance(1, 0), 2.0);
        assert_eq!(matrix.get_distance(1, 2), 3.0);
        assert_eq!(matrix.get_distance(0, 2), 0.0);
        assert_eq!(matrix.nnz(), 2);
    }

    /// Only entries strictly above the threshold survive the conversion.
    #[test]
    fn test_sparse_from_dense() {
        let values = vec![0.0, 1.0, 5.0, 1.0, 0.0, 2.0, 5.0, 2.0, 0.0];
        let dense = Array2::from_shape_vec((3, 3), values).expect("Operation failed");
        let sparse = SparseDistanceMatrix::from_dense(dense.view(), 1.5);
        assert_eq!(sparse.nnz(), 2);
        assert_eq!(sparse.get_distance(0, 2), 5.0);
        assert_eq!(sparse.get_distance(1, 2), 2.0);
        assert_eq!(sparse.get_distance(0, 1), 0.0);
    }

    /// The radius query returns exactly the stored neighbors within range.
    #[test]
    fn test_neighbors_within_distance() {
        let matrix =
            matrix_with_entries(4, f64::INFINITY, &[(0, 1, 1.0), (0, 2, 2.5), (0, 3, 0.5)]);
        let neighbors = matrix.neighbors_within_distance(0, 2.0);
        assert_eq!(neighbors.len(), 2);
        let mut found: Vec<f64> = neighbors.iter().map(|&(_, distance)| distance).collect();
        found.sort_by(|a, b| a.partial_cmp(b).expect("Operation failed"));
        assert_eq!(found, vec![0.5, 1.0]);
    }

    /// The k-nearest query returns the closest stored neighbors, nearest first.
    #[test]
    fn test_k_nearest_neighbors() {
        let matrix = matrix_with_entries(
            5,
            f64::INFINITY,
            &[(0, 1, 3.0), (0, 2, 1.0), (0, 3, 2.0), (0, 4, 4.0)],
        );
        let knn = matrix.k_nearest_neighbors(0, 2);
        assert_eq!(knn.len(), 2);
        assert_eq!(knn[0], (2, 1.0));
        assert_eq!(knn[1], (3, 2.0));
    }

    /// A 2-NN graph over four points stores some, but not all, pairs.
    #[test]
    fn test_sparse_knn_graph() {
        let points = vec![0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 5.0, 5.0];
        let data = Array2::from_shape_vec((4, 2), points).expect("Operation failed");
        let graph = sparse_knn_graph(data.view(), 2, Metric::Euclidean).expect("Operation failed");
        assert!(graph.nnz() > 0);
        assert!(graph.sparsity() > 0.0);
    }

    /// The three mutually close points are linked; the outlier is not required.
    #[test]
    fn test_sparse_epsilon_graph() {
        let points = vec![0.0, 0.0, 0.5, 0.0, 0.0, 0.5, 5.0, 5.0];
        let data = Array2::from_shape_vec((4, 2), points).expect("Operation failed");
        let graph =
            sparse_epsilon_graph(data.view(), 1.0, Metric::Euclidean).expect("Operation failed");
        assert!(graph.nnz() >= 3);
        assert!(graph.get_distance(0, 1) <= 1.0);
        assert!(graph.get_distance(0, 2) <= 1.0);
    }

    /// Dense expansion is symmetric, zero on the diagonal, default elsewhere.
    #[test]
    fn test_to_dense() {
        let matrix = matrix_with_entries(3, f64::INFINITY, &[(0, 1, 2.0), (1, 2, 3.0)]);
        let dense = matrix.to_dense();
        assert_eq!(dense.shape(), &[3, 3]);
        assert_eq!(dense[[0, 1]], 2.0);
        assert_eq!(dense[[1, 0]], 2.0);
        assert_eq!(dense[[1, 2]], 3.0);
        assert_eq!(dense[[2, 1]], 3.0);
        assert_eq!(dense[[0, 0]], 0.0);
        assert_eq!(dense[[0, 2]], f64::INFINITY);
    }
}