use crate::error::{ClusteringError, Result};
use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use scirs2_core::numeric::Float;
use std::fmt::Debug;
/// Single-pass leader clustering with a caller-supplied distance function.
///
/// Samples (rows of `data`) are visited in order. Each sample is compared with
/// every leader found so far; if the smallest distance exceeds `threshold` (or
/// no leader exists yet) the sample becomes a new leader, otherwise it is
/// assigned to its closest leader. The result therefore depends on row order.
///
/// # Arguments
///
/// * `data` - matrix whose rows are the samples to cluster.
/// * `threshold` - maximum distance at which a sample still joins an existing
///   leader; must be strictly positive and not NaN.
/// * `metric` - distance function applied to `(sample, leader)`.
///
/// # Returns
///
/// A tuple `(leaders, labels)`: `leaders` holds one leader per row, and
/// `labels[i]` is the row index in `leaders` assigned to sample `i`.
///
/// # Errors
///
/// Returns `ClusteringError::InvalidInput` when `data` has no elements or when
/// `threshold` is NaN, zero or negative.
#[allow(dead_code)]
pub fn leader_clustering<F, D>(
    data: ArrayView2<F>,
    threshold: F,
    metric: D,
) -> Result<(Array2<F>, Array1<usize>)>
where
    F: Float + Debug,
    D: Fn(ArrayView1<F>, ArrayView1<F>) -> F,
{
    if data.is_empty() {
        return Err(ClusteringError::InvalidInput(
            "Input data is empty".to_string(),
        ));
    }
    // NaN fails every ordered comparison, so `threshold <= 0` alone would let
    // it through and `min_distance > threshold` would then never be true.
    if threshold.is_nan() || threshold <= F::zero() {
        return Err(ClusteringError::InvalidInput(
            "Threshold must be positive".to_string(),
        ));
    }

    let n_features = data.ncols();
    let mut leaders: Vec<Array1<F>> = Vec::new();
    let mut labels = Array1::zeros(data.nrows());

    for (i, sample) in data.rows().into_iter().enumerate() {
        // Linear scan for the closest leader discovered so far.
        let mut min_distance = F::infinity();
        let mut closest_leader = 0;
        for (j, leader) in leaders.iter().enumerate() {
            let distance = metric(sample, leader.view());
            if distance < min_distance {
                min_distance = distance;
                closest_leader = j;
            }
        }

        // The explicit emptiness test matters for an infinite threshold, where
        // `infinity > threshold` is false even though no leader exists yet.
        labels[i] = if leaders.is_empty() || min_distance > threshold {
            leaders.push(sample.to_owned());
            leaders.len() - 1
        } else {
            closest_leader
        };
    }

    // Pack the collected leader vectors into one matrix, one leader per row.
    let mut leaders_array = Array2::zeros((leaders.len(), n_features));
    for (i, leader) in leaders.iter().enumerate() {
        leaders_array.row_mut(i).assign(leader);
    }
    Ok((leaders_array, labels))
}
/// Euclidean (L2) distance between two vectors.
///
/// Elements are paired positionally; if the views differ in length, the
/// surplus elements of the longer one are ignored because the pairing stops
/// at the shorter view.
#[allow(dead_code)]
pub fn euclidean_distance<F: Float>(a: ArrayView1<F>, b: ArrayView1<F>) -> F {
    let mut sum_of_squares = F::zero();
    for (&lhs, &rhs) in a.iter().zip(b.iter()) {
        sum_of_squares = sum_of_squares + (lhs - rhs) * (lhs - rhs);
    }
    sum_of_squares.sqrt()
}
/// Manhattan (L1) distance between two vectors.
///
/// Elements are paired positionally; if the views differ in length, the
/// surplus elements of the longer one are ignored because the pairing stops
/// at the shorter view.
#[allow(dead_code)]
pub fn manhattan_distance<F: Float>(a: ArrayView1<F>, b: ArrayView1<F>) -> F {
    let mut total = F::zero();
    for (&lhs, &rhs) in a.iter().zip(b.iter()) {
        total = total + (lhs - rhs).abs();
    }
    total
}
/// Leader clustering model that stores the leaders found by `fit` and assigns
/// new samples to the nearest stored leader using Euclidean distance.
#[derive(Debug, Clone)]
pub struct LeaderClustering<F: Float> {
    /// Distance above which a sample starts a new cluster during fitting.
    threshold: F,
    /// Leader vectors in order of discovery; empty until the model is fitted.
    leaders: Vec<Array1<F>>,
}
impl<F: Float + Debug> LeaderClustering<F> {
pub fn new(threshold: F) -> Result<Self> {
if threshold <= F::zero() {
return Err(ClusteringError::InvalidInput(
"Threshold must be positive".to_string(),
));
}
Ok(Self {
threshold,
leaders: Vec::new(),
})
}
pub fn fit(&mut self, data: ArrayView2<F>) -> Result<()> {
self.leaders.clear();
for sample in data.rows() {
let mut min_distance = F::infinity();
for leader in &self.leaders {
let distance = euclidean_distance(sample, leader.view());
if distance < min_distance {
min_distance = distance;
}
}
if self.leaders.is_empty() || min_distance > self.threshold {
self.leaders.push(sample.to_owned());
}
}
Ok(())
}
pub fn predict(&self, data: ArrayView2<F>) -> Result<Array1<usize>> {
if self.leaders.is_empty() {
return Err(ClusteringError::InvalidState(
"Model has not been fitted yet".to_string(),
));
}
let n_samples = data.nrows();
let mut labels = Array1::zeros(n_samples);
for (i, sample) in data.rows().into_iter().enumerate() {
let mut min_distance = F::infinity();
let mut closest_leader = 0;
for (j, leader) in self.leaders.iter().enumerate() {
let distance = euclidean_distance(sample, leader.view());
if distance < min_distance {
min_distance = distance;
closest_leader = j;
}
}
labels[i] = closest_leader;
}
Ok(labels)
}
pub fn fit_predict(&mut self, data: ArrayView2<F>) -> Result<Array1<usize>> {
self.fit(data)?;
self.predict(data)
}
pub fn get_leaders(&self) -> Array2<F> {
if self.leaders.is_empty() {
return Array2::zeros((0, 0));
}
let n_leaders = self.leaders.len();
let n_features = self.leaders[0].len();
let mut leaders_array = Array2::zeros((n_leaders, n_features));
for (i, leader) in self.leaders.iter().enumerate() {
leaders_array.row_mut(i).assign(leader);
}
leaders_array
}
pub fn n_clusters(&self) -> usize {
self.leaders.len()
}
}
/// Hierarchy of leader clusters built by repeatedly re-clustering each
/// cluster's members with successive thresholds.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct LeaderTree<F: Float> {
    /// Top-level clusters, one node per leader found with the first threshold.
    pub roots: Vec<LeaderNode<F>>,
    /// Threshold used for the top level (the first entry of the threshold
    /// list given to `build_hierarchical`).
    pub threshold: F,
}
/// One cluster in a `LeaderTree`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct LeaderNode<F: Float> {
    /// Leader vector representing this cluster.
    pub leader: Array1<F>,
    /// Sub-clusters obtained by re-clustering `members` with the next
    /// threshold; empty when no further threshold was applied.
    pub children: Vec<LeaderNode<F>>,
    /// Row indices, into the data passed to `build_hierarchical`, of the
    /// samples assigned to this cluster.
    pub members: Vec<usize>,
}
impl<F: Float + Debug> LeaderTree<F> {
pub fn build_hierarchical(data: ArrayView2<F>, thresholds: &[F]) -> Result<Self> {
if thresholds.is_empty() {
return Err(ClusteringError::InvalidInput(
"At least one threshold is required".to_string(),
));
}
let current_threshold = thresholds[0];
let (leaders, labels) = leader_clustering(data, current_threshold, euclidean_distance)?;
let mut roots = Vec::new();
for i in 0..leaders.nrows() {
let mut members = Vec::new();
for (j, &label) in labels.iter().enumerate() {
if label == i {
members.push(j);
}
}
roots.push(LeaderNode {
leader: leaders.row(i).to_owned(),
children: Vec::new(),
members,
});
}
if thresholds.len() > 1 {
for root in &mut roots {
Self::build_subtree(data, root, &thresholds[1..])?;
}
}
Ok(LeaderTree {
roots,
threshold: current_threshold,
})
}
fn build_subtree(
data: ArrayView2<F>,
parent: &mut LeaderNode<F>,
thresholds: &[F],
) -> Result<()> {
if thresholds.is_empty() || parent.members.len() <= 1 {
return Ok(());
}
let n_features = data.ncols();
let mut cluster_data = Array2::zeros((parent.members.len(), n_features));
for (i, &idx) in parent.members.iter().enumerate() {
cluster_data.row_mut(i).assign(&data.row(idx));
}
let (sub_leaders, sub_labels) =
leader_clustering(cluster_data.view(), thresholds[0], euclidean_distance)?;
for i in 0..sub_leaders.nrows() {
let mut members = Vec::new();
for (j, &label) in sub_labels.iter().enumerate() {
if label == i {
members.push(parent.members[j]);
}
}
let mut child = LeaderNode {
leader: sub_leaders.row(i).to_owned(),
children: Vec::new(),
members,
};
if thresholds.len() > 1 {
Self::build_subtree(data, &mut child, &thresholds[1..])?;
}
parent.children.push(child);
}
Ok(())
}
pub fn node_count(&self) -> usize {
self.roots.iter().map(|root| Self::count_nodes(root)).sum()
}
fn count_nodes(node: &LeaderNode<F>) -> usize {
1 + node
.children
.iter()
.map(|child| Self::count_nodes(child))
.sum::<usize>()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Two tight pairs of points far apart from each other form two clusters.
    #[test]
    fn test_leader_clustering_basic() {
        let points = array![[1.0, 2.0], [1.2, 1.8], [5.0, 4.0], [5.2, 4.1]];
        let (leaders, labels) =
            leader_clustering(points.view(), 1.0, euclidean_distance).expect("Operation failed");

        assert_eq!(leaders.nrows(), 2);
        assert_eq!(labels.len(), 4);
        // Points within a pair share a label; the pairs themselves differ.
        assert_eq!(labels[0], labels[1]);
        assert_eq!(labels[2], labels[3]);
        assert_ne!(labels[0], labels[2]);
    }

    /// A threshold larger than the spread of the data yields a single cluster.
    #[test]
    fn test_leader_clustering_single_cluster() {
        let points = array![[1.0, 2.0], [1.2, 1.8], [1.1, 2.1], [0.9, 1.9]];
        let (leaders, labels) =
            leader_clustering(points.view(), 2.0, euclidean_distance).expect("Operation failed");

        assert_eq!(leaders.nrows(), 1);
        assert!(labels.iter().all(|&label| label == 0));
    }

    /// The struct-based API fits, reports the cluster count and labels new data.
    #[test]
    fn test_leader_class() {
        let training = array![[1.0, 2.0], [1.2, 1.8], [5.0, 4.0], [5.2, 4.1]];
        let mut model = LeaderClustering::new(1.0).expect("Operation failed");
        let labels = model.fit_predict(training.view()).expect("Operation failed");

        assert_eq!(model.n_clusters(), 2);
        assert_eq!(labels.len(), 4);

        // Unseen points near each group must receive that group's label.
        let unseen = array![[1.1, 1.9], [5.1, 4.05]];
        let unseen_labels = model.predict(unseen.view()).expect("Operation failed");
        assert_eq!(unseen_labels[0], labels[0]);
        assert_eq!(unseen_labels[1], labels[2]);
    }

    /// A coarse threshold followed by a fine one produces nested clusters.
    #[test]
    fn test_hierarchical_leader_tree() {
        let points = array![
            [1.0, 2.0],
            [1.2, 1.8],
            [5.0, 4.0],
            [5.2, 4.1],
            [10.0, 10.0],
            [10.2, 9.8],
        ];
        let thresholds = vec![6.0, 1.0];
        let tree =
            LeaderTree::build_hierarchical(points.view(), &thresholds).expect("Operation failed");

        assert!(tree.roots.len() <= 3);
        // The second level must add nodes beyond the roots.
        assert!(tree.node_count() > tree.roots.len());
    }

    /// Negative thresholds are rejected by both entry points.
    #[test]
    fn test_invalid_threshold() {
        let points = array![[1.0, 2.0]];
        let from_function = leader_clustering(points.view(), -1.0, euclidean_distance);
        assert!(from_function.is_err());

        let from_constructor = LeaderClustering::<f64>::new(-1.0);
        assert!(from_constructor.is_err());
    }

    /// Clustering a matrix with no rows is an error.
    #[test]
    fn test_empty_data() {
        let no_rows: Array2<f64> = Array2::zeros((0, 2));
        let outcome = leader_clustering(no_rows.view(), 1.0, euclidean_distance);
        assert!(outcome.is_err());
    }
}