Skip to main content

oxigdal_algorithms/vector/clustering/
hierarchical.rs

1//! Hierarchical (agglomerative) clustering
2//!
3//! Build a dendrogram by progressively merging clusters.
4
5use crate::error::{AlgorithmError, Result};
6use crate::vector::clustering::dbscan::{DistanceMetric, calculate_distance};
7use oxigdal_core::vector::Point;
8use std::collections::HashMap;
9
/// Options for hierarchical clustering
///
/// Merging stops at whichever limit is reached first: the `num_clusters`
/// target or the `distance_threshold` (when one is set).
#[derive(Debug, Clone)]
pub struct HierarchicalOptions {
    /// Number of clusters to stop at: merging continues while more clusters
    /// than this remain. A value of `0` is treated as `1`.
    pub num_clusters: usize,
    /// Linkage method used to measure the distance between two clusters
    pub linkage: LinkageMethod,
    /// Distance metric used between individual points
    pub metric: DistanceMetric,
    /// Optional distance threshold, applied together with `num_clusters` (not
    /// instead of it): merging stops as soon as the closest pair of clusters is
    /// at least this far apart, even if more than `num_clusters` clusters
    /// remain. Set `num_clusters` to `1` to cut by distance alone.
    pub distance_threshold: Option<f64>,
}
22
23impl Default for HierarchicalOptions {
24    fn default() -> Self {
25        Self {
26            num_clusters: 3,
27            linkage: LinkageMethod::Average,
28            metric: DistanceMetric::Euclidean,
29            distance_threshold: None,
30        }
31    }
32}
33
/// Linkage method for hierarchical clustering
///
/// Determines how the distance between two clusters is derived from the
/// pairwise distances between their points.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LinkageMethod {
    /// Single linkage: minimum pairwise distance between the two clusters
    Single,
    /// Complete linkage: maximum pairwise distance between the two clusters
    Complete,
    /// Average linkage: mean of all pairwise distances between the two clusters
    Average,
    /// Ward linkage (intended to minimize the increase in within-cluster variance)
    Ward,
}
46
/// Result of hierarchical clustering
#[derive(Debug, Clone)]
pub struct HierarchicalResult {
    /// Cluster assignment for each input point, in input order; every label
    /// lies in `0..num_clusters`
    pub labels: Vec<usize>,
    /// Dendrogram (merge history), in the order the merges were performed
    pub dendrogram: Vec<Merge>,
    /// Number of clusters remaining when merging stopped
    pub num_clusters: usize,
    /// Number of points carrying each label (keyed by label)
    pub cluster_sizes: HashMap<usize, usize>,
}
59
/// A merge operation in the dendrogram
#[derive(Debug, Clone)]
pub struct Merge {
    /// Identifier of the first cluster being merged
    pub cluster1: usize,
    /// Identifier of the second cluster being merged
    pub cluster2: usize,
    /// Linkage distance between the two clusters at the moment they were merged
    pub distance: f64,
    /// Identifier assigned to the cluster produced by this merge
    pub new_cluster: usize,
}
72
73/// Perform hierarchical clustering
74///
75/// # Arguments
76///
77/// * `points` - Points to cluster
78/// * `options` - Hierarchical clustering options
79///
80/// # Returns
81///
82/// Clustering result with dendrogram and labels
83///
84/// # Examples
85///
86/// ```
87/// use oxigdal_algorithms::vector::clustering::{hierarchical_cluster, HierarchicalOptions};
88/// use oxigdal_algorithms::Point;
89/// # use oxigdal_algorithms::error::Result;
90///
91/// # fn main() -> Result<()> {
92/// let points = vec![
93///     Point::new(0.0, 0.0),
94///     Point::new(0.1, 0.1),
95///     Point::new(5.0, 5.0),
96/// ];
97///
98/// let options = HierarchicalOptions {
99///     num_clusters: 2,
100///     ..Default::default()
101/// };
102///
103/// let result = hierarchical_cluster(&points, &options)?;
104/// assert_eq!(result.num_clusters, 2);
105/// # Ok(())
106/// # }
107/// ```
108pub fn hierarchical_cluster(
109    points: &[Point],
110    options: &HierarchicalOptions,
111) -> Result<HierarchicalResult> {
112    if points.is_empty() {
113        return Err(AlgorithmError::InvalidInput(
114            "Cannot cluster empty point set".to_string(),
115        ));
116    }
117
118    let n = points.len();
119
120    // Initialize each point as its own cluster
121    let mut clusters: Vec<Vec<usize>> = (0..n).map(|i| vec![i]).collect();
122    let mut dendrogram = Vec::new();
123
124    // Build distance matrix
125    let mut distances = compute_distance_matrix(points, options.metric);
126
127    // Merge clusters until we reach the desired number
128    let target_clusters = options.num_clusters.max(1);
129
130    while clusters.len() > target_clusters {
131        // Find pair of clusters with minimum distance
132        let (i, j, dist) = find_closest_clusters(&clusters, &distances, options.linkage)?;
133
134        // Check distance threshold BEFORE merging
135        if let Some(threshold) = options.distance_threshold {
136            if dist >= threshold {
137                break;
138            }
139        }
140
141        // Merge clusters
142        let new_cluster_id = clusters.len();
143        let merged = merge_clusters(&mut clusters, i, j);
144
145        dendrogram.push(Merge {
146            cluster1: i,
147            cluster2: j,
148            distance: dist,
149            new_cluster: new_cluster_id,
150        });
151
152        // Update distance matrix
153        update_distances(&mut distances, i, j, &merged, points, options)?;
154    }
155
156    // Extract final labels
157    let mut labels = vec![0; n];
158    for (cluster_id, cluster) in clusters.iter().enumerate() {
159        for &point_idx in cluster {
160            labels[point_idx] = cluster_id;
161        }
162    }
163
164    // Calculate cluster sizes
165    let mut cluster_sizes: HashMap<usize, usize> = HashMap::new();
166    for &label in &labels {
167        *cluster_sizes.entry(label).or_insert(0) += 1;
168    }
169
170    Ok(HierarchicalResult {
171        labels,
172        dendrogram,
173        num_clusters: clusters.len(),
174        cluster_sizes,
175    })
176}
177
178/// Compute pairwise distance matrix
179fn compute_distance_matrix(points: &[Point], metric: DistanceMetric) -> Vec<Vec<f64>> {
180    let n = points.len();
181    let mut distances = vec![vec![0.0; n]; n];
182
183    for i in 0..n {
184        for j in (i + 1)..n {
185            let dist = calculate_distance(&points[i], &points[j], metric);
186            distances[i][j] = dist;
187            distances[j][i] = dist;
188        }
189    }
190
191    distances
192}
193
194/// Find the pair of clusters with minimum distance
195fn find_closest_clusters(
196    clusters: &[Vec<usize>],
197    distances: &[Vec<f64>],
198    linkage: LinkageMethod,
199) -> Result<(usize, usize, f64)> {
200    let mut min_dist = f64::INFINITY;
201    let mut best_i = 0;
202    let mut best_j = 1;
203
204    for i in 0..clusters.len() {
205        for j in (i + 1)..clusters.len() {
206            let dist = cluster_distance(&clusters[i], &clusters[j], distances, linkage);
207
208            if dist < min_dist {
209                min_dist = dist;
210                best_i = i;
211                best_j = j;
212            }
213        }
214    }
215
216    if min_dist.is_infinite() {
217        return Err(AlgorithmError::ComputationError(
218            "No valid cluster pair found".to_string(),
219        ));
220    }
221
222    Ok((best_i, best_j, min_dist))
223}
224
225/// Calculate distance between two clusters
226fn cluster_distance(
227    cluster1: &[usize],
228    cluster2: &[usize],
229    distances: &[Vec<f64>],
230    linkage: LinkageMethod,
231) -> f64 {
232    match linkage {
233        LinkageMethod::Single => {
234            // Minimum distance
235            cluster1
236                .iter()
237                .flat_map(|&i| cluster2.iter().map(move |&j| distances[i][j]))
238                .fold(f64::INFINITY, f64::min)
239        }
240        LinkageMethod::Complete => {
241            // Maximum distance
242            cluster1
243                .iter()
244                .flat_map(|&i| cluster2.iter().map(move |&j| distances[i][j]))
245                .fold(f64::NEG_INFINITY, f64::max)
246        }
247        LinkageMethod::Average => {
248            // Average distance
249            let sum: f64 = cluster1
250                .iter()
251                .flat_map(|&i| cluster2.iter().map(move |&j| distances[i][j]))
252                .sum();
253            let count = (cluster1.len() * cluster2.len()) as f64;
254            if count > 0.0 {
255                sum / count
256            } else {
257                f64::INFINITY
258            }
259        }
260        LinkageMethod::Ward => {
261            // Ward linkage (simplified as average for now)
262            let sum: f64 = cluster1
263                .iter()
264                .flat_map(|&i| cluster2.iter().map(move |&j| distances[i][j]))
265                .sum();
266            let count = (cluster1.len() * cluster2.len()) as f64;
267            if count > 0.0 {
268                sum / count
269            } else {
270                f64::INFINITY
271            }
272        }
273    }
274}
275
/// Combine the clusters stored at positions `i` and `j`.
///
/// Both entries are removed from `clusters` (the higher position first, so the
/// lower one is still valid), and the order of the remaining entries is kept.
/// The combined cluster lists the lower-positioned cluster's members followed
/// by the higher-positioned one's. It is appended to the end of `clusters`,
/// and a copy is also returned.
///
/// Panics if either index is out of bounds. NOTE(review): callers are expected
/// to pass `i != j`; equal indices would remove two different clusters.
fn merge_clusters(clusters: &mut Vec<Vec<usize>>, i: usize, j: usize) -> Vec<usize> {
    let low = i.min(j);
    let high = i.max(j);

    let tail = clusters.remove(high);
    let head = clusters.remove(low);
    let combined: Vec<usize> = head.into_iter().chain(tail).collect();

    clusters.push(combined.clone());
    combined
}
292
/// Update distance matrix after merge
///
/// Currently a no-op that always returns `Ok(())`; none of the arguments are
/// read. The matrix does not need updating because `find_closest_clusters`
/// recomputes every cluster-to-cluster distance from the point-level matrix on
/// each call. The parameters are kept so an incremental update scheme could be
/// added later without changing callers.
fn update_distances(
    _distances: &mut Vec<Vec<f64>>,
    _i: usize,
    _j: usize,
    _merged: &[usize],
    _points: &[Point],
    _options: &HierarchicalOptions,
) -> Result<()> {
    // Simplified: distances remain unchanged as we recalculate cluster distances on-the-fly
    // A full implementation would update the distance matrix here
    Ok(())
}
306
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_hierarchical_simple() {
        let points = vec![
            Point::new(0.0, 0.0),
            Point::new(0.1, 0.1),
            Point::new(5.0, 5.0),
        ];

        let options = HierarchicalOptions {
            num_clusters: 2,
            ..Default::default()
        };

        let clustering = hierarchical_cluster(&points, &options).expect("Clustering failed");
        assert_eq!(clustering.num_clusters, 2);
        assert_eq!(clustering.labels.len(), points.len());
        // The two nearby points share a cluster; the distant one is alone.
        assert_eq!(clustering.labels[0], clustering.labels[1]);
        assert_ne!(clustering.labels[0], clustering.labels[2]);
        assert_eq!(clustering.dendrogram.len(), 1);
        assert_eq!(
            clustering.cluster_sizes.get(&clustering.labels[0]),
            Some(&2)
        );
        assert_eq!(
            clustering.cluster_sizes.get(&clustering.labels[2]),
            Some(&1)
        );
    }

    #[test]
    fn test_linkage_methods() {
        let points = vec![
            Point::new(0.0, 0.0),
            Point::new(1.0, 0.0),
            Point::new(10.0, 0.0),
        ];

        for linkage in [
            LinkageMethod::Single,
            LinkageMethod::Complete,
            LinkageMethod::Average,
            LinkageMethod::Ward,
        ] {
            let options = HierarchicalOptions {
                num_clusters: 2,
                linkage,
                ..Default::default()
            };

            let clustering = hierarchical_cluster(&points, &options)
                .unwrap_or_else(|e| panic!("{linkage:?} failed: {e:?}"));
            assert_eq!(clustering.num_clusters, 2, "{linkage:?}");
            assert_eq!(clustering.dendrogram.len(), 1, "{linkage:?}");
            // Every linkage merges the two closest points first.
            assert_eq!(clustering.labels[0], clustering.labels[1], "{linkage:?}");
            assert_ne!(clustering.labels[0], clustering.labels[2], "{linkage:?}");
        }
    }

    #[test]
    fn test_distance_threshold() {
        let points = vec![
            Point::new(0.0, 0.0),
            Point::new(0.5, 0.0),
            Point::new(10.0, 0.0),
        ];

        let options = HierarchicalOptions {
            num_clusters: 1,
            distance_threshold: Some(2.0),
            ..Default::default()
        };

        let clustering = hierarchical_cluster(&points, &options).expect("Clustering failed");
        // Should stop merging due to threshold: only the 0.5-apart pair merges.
        assert_eq!(clustering.num_clusters, 2);
        assert_eq!(clustering.dendrogram.len(), 1);
        assert!(clustering.dendrogram[0].distance < 2.0);
        assert_eq!(clustering.labels[0], clustering.labels[1]);
        assert_ne!(clustering.labels[0], clustering.labels[2]);
    }

    #[test]
    fn test_empty_input_is_rejected() {
        let result = hierarchical_cluster(&[], &HierarchicalOptions::default());
        assert!(result.is_err());
    }

    #[test]
    fn test_single_point() {
        let points = vec![Point::new(1.0, 2.0)];
        let clustering = hierarchical_cluster(&points, &HierarchicalOptions::default())
            .expect("Clustering failed");
        assert_eq!(clustering.num_clusters, 1);
        assert_eq!(clustering.labels, vec![0]);
        assert!(clustering.dendrogram.is_empty());
        assert_eq!(clustering.cluster_sizes.get(&0), Some(&1));
    }

    #[test]
    fn test_more_clusters_than_points() {
        let points = vec![Point::new(0.0, 0.0), Point::new(3.0, 4.0)];
        let options = HierarchicalOptions {
            num_clusters: 5,
            ..Default::default()
        };

        let clustering = hierarchical_cluster(&points, &options).expect("Clustering failed");
        assert_eq!(clustering.num_clusters, 2);
        assert!(clustering.dendrogram.is_empty());
        assert_ne!(clustering.labels[0], clustering.labels[1]);
    }

    #[test]
    fn test_zero_clusters_treated_as_one() {
        let points = vec![
            Point::new(0.0, 0.0),
            Point::new(0.5, 0.0),
            Point::new(10.0, 0.0),
        ];
        let options = HierarchicalOptions {
            num_clusters: 0,
            ..Default::default()
        };

        let clustering = hierarchical_cluster(&points, &options).expect("Clustering failed");
        assert_eq!(clustering.num_clusters, 1);
        assert_eq!(clustering.labels, vec![0, 0, 0]);
        assert_eq!(clustering.dendrogram.len(), 2);
        // Average linkage merge heights never decrease.
        assert!(clustering.dendrogram[0].distance <= clustering.dendrogram[1].distance);
        assert_eq!(clustering.cluster_sizes.get(&0), Some(&3));
    }
}