oxigdal_algorithms/vector/clustering/
hierarchical.rs1use crate::error::{AlgorithmError, Result};
6use crate::vector::clustering::dbscan::{DistanceMetric, calculate_distance};
7use oxigdal_core::vector::Point;
8use std::collections::HashMap;
9
10#[derive(Debug, Clone)]
12pub struct HierarchicalOptions {
13 pub num_clusters: usize,
15 pub linkage: LinkageMethod,
17 pub metric: DistanceMetric,
19 pub distance_threshold: Option<f64>,
21}
22
23impl Default for HierarchicalOptions {
24 fn default() -> Self {
25 Self {
26 num_clusters: 3,
27 linkage: LinkageMethod::Average,
28 metric: DistanceMetric::Euclidean,
29 distance_threshold: None,
30 }
31 }
32}
33
/// Criterion used to turn point-to-point distances into a cluster-to-cluster distance.
///
/// The exact evaluation of each variant lives in `cluster_distance`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LinkageMethod {
    /// Single linkage: the smallest distance between any cross-cluster pair of points.
    Single,
    /// Complete linkage: the largest distance between any cross-cluster pair of points.
    Complete,
    /// Average linkage: the mean distance over all cross-cluster pairs of points.
    Average,
    /// Ward linkage; see `cluster_distance` for how it is evaluated.
    Ward,
}
46
47#[derive(Debug, Clone)]
49pub struct HierarchicalResult {
50 pub labels: Vec<usize>,
52 pub dendrogram: Vec<Merge>,
54 pub num_clusters: usize,
56 pub cluster_sizes: HashMap<usize, usize>,
58}
59
/// One merge step recorded in the dendrogram produced by [`hierarchical_cluster`].
#[derive(Debug, Clone)]
pub struct Merge {
    /// Identifier of the first cluster of the merged pair, as recorded at merge time.
    pub cluster1: usize,
    /// Identifier of the second cluster of the merged pair, as recorded at merge time.
    pub cluster2: usize,
    /// Linkage distance between the two clusters when they were merged.
    pub distance: f64,
    /// Identifier assigned to the cluster created by this merge.
    pub new_cluster: usize,
}
72
73pub fn hierarchical_cluster(
109 points: &[Point],
110 options: &HierarchicalOptions,
111) -> Result<HierarchicalResult> {
112 if points.is_empty() {
113 return Err(AlgorithmError::InvalidInput(
114 "Cannot cluster empty point set".to_string(),
115 ));
116 }
117
118 let n = points.len();
119
120 let mut clusters: Vec<Vec<usize>> = (0..n).map(|i| vec![i]).collect();
122 let mut dendrogram = Vec::new();
123
124 let mut distances = compute_distance_matrix(points, options.metric);
126
127 let target_clusters = options.num_clusters.max(1);
129
130 while clusters.len() > target_clusters {
131 let (i, j, dist) = find_closest_clusters(&clusters, &distances, options.linkage)?;
133
134 if let Some(threshold) = options.distance_threshold {
136 if dist >= threshold {
137 break;
138 }
139 }
140
141 let new_cluster_id = clusters.len();
143 let merged = merge_clusters(&mut clusters, i, j);
144
145 dendrogram.push(Merge {
146 cluster1: i,
147 cluster2: j,
148 distance: dist,
149 new_cluster: new_cluster_id,
150 });
151
152 update_distances(&mut distances, i, j, &merged, points, options)?;
154 }
155
156 let mut labels = vec![0; n];
158 for (cluster_id, cluster) in clusters.iter().enumerate() {
159 for &point_idx in cluster {
160 labels[point_idx] = cluster_id;
161 }
162 }
163
164 let mut cluster_sizes: HashMap<usize, usize> = HashMap::new();
166 for &label in &labels {
167 *cluster_sizes.entry(label).or_insert(0) += 1;
168 }
169
170 Ok(HierarchicalResult {
171 labels,
172 dendrogram,
173 num_clusters: clusters.len(),
174 cluster_sizes,
175 })
176}
177
178fn compute_distance_matrix(points: &[Point], metric: DistanceMetric) -> Vec<Vec<f64>> {
180 let n = points.len();
181 let mut distances = vec![vec![0.0; n]; n];
182
183 for i in 0..n {
184 for j in (i + 1)..n {
185 let dist = calculate_distance(&points[i], &points[j], metric);
186 distances[i][j] = dist;
187 distances[j][i] = dist;
188 }
189 }
190
191 distances
192}
193
194fn find_closest_clusters(
196 clusters: &[Vec<usize>],
197 distances: &[Vec<f64>],
198 linkage: LinkageMethod,
199) -> Result<(usize, usize, f64)> {
200 let mut min_dist = f64::INFINITY;
201 let mut best_i = 0;
202 let mut best_j = 1;
203
204 for i in 0..clusters.len() {
205 for j in (i + 1)..clusters.len() {
206 let dist = cluster_distance(&clusters[i], &clusters[j], distances, linkage);
207
208 if dist < min_dist {
209 min_dist = dist;
210 best_i = i;
211 best_j = j;
212 }
213 }
214 }
215
216 if min_dist.is_infinite() {
217 return Err(AlgorithmError::ComputationError(
218 "No valid cluster pair found".to_string(),
219 ));
220 }
221
222 Ok((best_i, best_j, min_dist))
223}
224
225fn cluster_distance(
227 cluster1: &[usize],
228 cluster2: &[usize],
229 distances: &[Vec<f64>],
230 linkage: LinkageMethod,
231) -> f64 {
232 match linkage {
233 LinkageMethod::Single => {
234 cluster1
236 .iter()
237 .flat_map(|&i| cluster2.iter().map(move |&j| distances[i][j]))
238 .fold(f64::INFINITY, f64::min)
239 }
240 LinkageMethod::Complete => {
241 cluster1
243 .iter()
244 .flat_map(|&i| cluster2.iter().map(move |&j| distances[i][j]))
245 .fold(f64::NEG_INFINITY, f64::max)
246 }
247 LinkageMethod::Average => {
248 let sum: f64 = cluster1
250 .iter()
251 .flat_map(|&i| cluster2.iter().map(move |&j| distances[i][j]))
252 .sum();
253 let count = (cluster1.len() * cluster2.len()) as f64;
254 if count > 0.0 {
255 sum / count
256 } else {
257 f64::INFINITY
258 }
259 }
260 LinkageMethod::Ward => {
261 let sum: f64 = cluster1
263 .iter()
264 .flat_map(|&i| cluster2.iter().map(move |&j| distances[i][j]))
265 .sum();
266 let count = (cluster1.len() * cluster2.len()) as f64;
267 if count > 0.0 {
268 sum / count
269 } else {
270 f64::INFINITY
271 }
272 }
273 }
274}
275
/// Merges the clusters at positions `i` and `j` of `clusters`.
///
/// Both clusters are removed from the list (the other clusters keep their relative
/// order) and their union is appended at the end, with the members of the
/// lower-positioned cluster first. A copy of the merged cluster is returned.
///
/// Panics if a position is out of bounds, as `Vec::remove` does.
fn merge_clusters(clusters: &mut Vec<Vec<usize>>, i: usize, j: usize) -> Vec<usize> {
    let (lower, upper) = (i.min(j), i.max(j));

    // Remove the higher position first so the lower one is still valid afterwards.
    let absorbed = clusters.remove(upper);
    let combined: Vec<usize> = clusters
        .remove(lower)
        .into_iter()
        .chain(absorbed)
        .collect();

    clusters.push(combined.clone());
    combined
}
292
293fn update_distances(
295 _distances: &mut Vec<Vec<f64>>,
296 _i: usize,
297 _j: usize,
298 _merged: &[usize],
299 _points: &[Point],
300 _options: &HierarchicalOptions,
301) -> Result<()> {
302 Ok(())
305}
306
#[cfg(test)]
mod tests {
    use super::*;

    /// Two nearby points and one distant point split into two clusters, with the
    /// nearby pair sharing a label.
    #[test]
    fn test_hierarchical_simple() {
        let points = vec![
            Point::new(0.0, 0.0),
            Point::new(0.1, 0.1),
            Point::new(5.0, 5.0),
        ];

        let options = HierarchicalOptions {
            num_clusters: 2,
            ..Default::default()
        };

        let clustering = hierarchical_cluster(&points, &options).expect("Clustering failed");

        assert_eq!(clustering.num_clusters, 2);
        assert_eq!(clustering.labels.len(), points.len());
        assert_eq!(clustering.labels[0], clustering.labels[1]);
        assert_ne!(clustering.labels[0], clustering.labels[2]);
        assert_eq!(clustering.dendrogram.len(), 1);
        assert_eq!(
            clustering.cluster_sizes.values().sum::<usize>(),
            points.len()
        );
    }

    /// An empty input is rejected instead of producing an empty clustering.
    #[test]
    fn test_empty_input_is_rejected() {
        let result = hierarchical_cluster(&[], &HierarchicalOptions::default());
        assert!(result.is_err());
    }

    /// Every linkage method groups the two adjacent points and isolates the far one.
    #[test]
    fn test_linkage_methods() {
        let points = vec![
            Point::new(0.0, 0.0),
            Point::new(1.0, 0.0),
            Point::new(10.0, 0.0),
        ];

        for linkage in [
            LinkageMethod::Single,
            LinkageMethod::Complete,
            LinkageMethod::Average,
            LinkageMethod::Ward,
        ] {
            let options = HierarchicalOptions {
                num_clusters: 2,
                linkage,
                ..Default::default()
            };

            let clustering =
                hierarchical_cluster(&points, &options).expect("Clustering failed");

            assert_eq!(clustering.num_clusters, 2, "linkage {linkage:?}");
            assert_eq!(
                clustering.labels[0], clustering.labels[1],
                "linkage {linkage:?}"
            );
            assert_ne!(
                clustering.labels[0], clustering.labels[2],
                "linkage {linkage:?}"
            );
        }
    }

    /// The distance threshold stops merging before `num_clusters` is reached.
    #[test]
    fn test_distance_threshold() {
        let points = vec![
            Point::new(0.0, 0.0),
            Point::new(0.5, 0.0),
            Point::new(10.0, 0.0),
        ];

        let options = HierarchicalOptions {
            num_clusters: 1,
            distance_threshold: Some(2.0),
            ..Default::default()
        };

        let clustering = hierarchical_cluster(&points, &options).expect("Clustering failed");

        // The close pair (0.5 apart) merges; the far point is beyond the threshold.
        assert_eq!(clustering.num_clusters, 2);
        assert_eq!(clustering.dendrogram.len(), 1);
        assert_eq!(clustering.labels[0], clustering.labels[1]);
        assert_ne!(clustering.labels[0], clustering.labels[2]);
    }
}