oxirs_embed/application_tasks/clustering.rs

use super::ApplicationEvalConfig;
use crate::EmbeddingModel;
use anyhow::{anyhow, Result};
use scirs2_core::ndarray_ext::Array2;
#[allow(unused_imports)]
use scirs2_core::random::{Random, Rng};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};

/// Metrics for assessing clustering quality.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ClusteringMetric {
    /// Mean silhouette coefficient over all points
    SilhouetteScore,
    /// Ratio of between-cluster to within-cluster dispersion
    CalinskiHarabaszIndex,
    /// Average similarity of each cluster to its most similar cluster
    DaviesBouldinIndex,
    /// Agreement with ground-truth labels, corrected for chance
    AdjustedRandIndex,
    /// Mutual information with ground-truth labels, normalized to [0, 1]
    NormalizedMutualInformation,
    /// Fraction of points matching the majority ground-truth class of their cluster
    Purity,
    /// Sum of squared distances from points to their cluster centroid
    Inertia,
}

/// Structural statistics of the discovered clusters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClusterAnalysis {
    /// Number of clusters found
    pub num_clusters: usize,
    /// Number of points in each cluster
    pub cluster_sizes: Vec<usize>,
    /// Per-cluster cohesion score (tightness within the cluster)
    pub cluster_cohesion: Vec<f64>,
    /// Per-cluster separation score (distance from other clusters)
    pub cluster_separation: Vec<f64>,
    /// Pairwise distances between clusters
    pub inter_cluster_distances: Array2<f64>,
}

/// Stability of the clustering under resampling and parameter perturbation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClusteringStabilityAnalysis {
    /// Overall stability score
    pub stability_score: f64,
    /// Consistency of cluster assignments across runs
    pub assignment_consistency: f64,
    /// Robustness of results to changes in clustering parameters
    pub parameter_robustness: f64,
}

/// Aggregated results of a clustering evaluation run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClusteringResults {
    /// Score for each computed metric, keyed by metric name
    pub metric_scores: HashMap<String, f64>,
    /// Structural analysis of the clusters
    pub cluster_analysis: ClusterAnalysis,
    /// Estimated optimal number of clusters, if determined
    pub optimal_k: Option<usize>,
    /// Stability analysis of the clustering
    pub stability_analysis: ClusteringStabilityAnalysis,
}

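/// Evaluates entity embeddings by clustering them and scoring the result.
///
/// Illustrative usage sketch (an assumption added here, not taken from the
/// original source): `model` is any `EmbeddingModel` implementation,
/// `config` an `ApplicationEvalConfig`, and `ground_truth_labels` a
/// `HashMap<String, String>`; marked `ignore` because none of them are
/// constructed in this example.
///
/// ```ignore
/// let mut evaluator = ClusteringEvaluator::new();
/// // Optional: enables AdjustedRandIndex, NormalizedMutualInformation, Purity
/// evaluator.set_ground_truth(ground_truth_labels);
/// let results = evaluator.evaluate(&model, &config).await?;
/// println!("{:?}", results.metric_scores.get("SilhouetteScore"));
/// ```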
pub struct ClusteringEvaluator {
    /// Optional ground-truth cluster assignments used by supervised metrics
    ground_truth_clusters: Option<HashMap<String, String>>,
    /// Metrics to compute during evaluation
    metrics: Vec<ClusteringMetric>,
}

impl ClusteringEvaluator {
    pub fn new() -> Self {
        Self {
            ground_truth_clusters: None,
            metrics: vec![
                ClusteringMetric::SilhouetteScore,
                ClusteringMetric::CalinskiHarabaszIndex,
                ClusteringMetric::DaviesBouldinIndex,
                ClusteringMetric::Inertia,
            ],
        }
    }

    /// Registers ground-truth cluster labels and enables the supervised metrics.
    pub fn set_ground_truth(&mut self, clusters: HashMap<String, String>) {
        self.ground_truth_clusters = Some(clusters);

        // Supervised metrics become meaningful once ground truth is available
        self.metrics.extend(vec![
            ClusteringMetric::AdjustedRandIndex,
            ClusteringMetric::NormalizedMutualInformation,
            ClusteringMetric::Purity,
        ]);
    }

    pub async fn evaluate(
        &self,
        model: &dyn EmbeddingModel,
        config: &ApplicationEvalConfig,
    ) -> Result<ClusteringResults> {
        let entities = model.get_entities();
        let sample_entities: Vec<_> = entities.into_iter().take(config.sample_size).collect();

        // Collect embeddings for the sampled entities, skipping lookup failures
        let mut embeddings = Vec::new();
        for entity in &sample_entities {
            if let Ok(embedding) = model.get_entity_embedding(entity) {
                embeddings.push(embedding.values);
            }
        }

        if embeddings.is_empty() {
            return Err(anyhow!("No embeddings available for clustering evaluation"));
        }

        // Cluster the embeddings
        let cluster_assignments = self.perform_clustering(&embeddings, config.num_clusters)?;

        // Score the clustering with every configured metric
        let mut metric_scores = HashMap::new();
        for metric in &self.metrics {
            let score = self.calculate_clustering_metric(
                metric,
                &embeddings,
                &cluster_assignments,
                &sample_entities,
            )?;
            metric_scores.insert(format!("{metric:?}"), score);
        }

        let cluster_analysis = self.analyze_clusters(&embeddings, &cluster_assignments)?;

        let stability_analysis = self.analyze_stability(&embeddings, config)?;

        Ok(ClusteringResults {
            metric_scores,
            cluster_analysis,
            optimal_k: Some(config.num_clusters), // report the configured k
            stability_analysis,
        })
    }

    /// Runs a simple k-means clustering over the embeddings and returns one
    /// cluster index per embedding.
    fn perform_clustering(&self, embeddings: &[Vec<f32>], k: usize) -> Result<Vec<usize>> {
        if embeddings.is_empty() || k == 0 {
            return Ok(Vec::new());
        }

        let n = embeddings.len();
        let dim = embeddings[0].len();

        // Initialize centroids by sampling k points at random
        let mut centroids = Vec::new();
        let mut rng = Random::default();
        for _ in 0..k {
            let idx = rng.random_range(0..n);
            centroids.push(embeddings[idx].clone());
        }

        let mut assignments = vec![0; n];
        let max_iterations = 100;

        for _iteration in 0..max_iterations {
            let mut new_assignments = vec![0; n];
            let mut changed = false;

            // Assignment step: attach each point to its nearest centroid
            for (i, embedding) in embeddings.iter().enumerate() {
                let mut min_distance = f32::INFINITY;
                let mut best_cluster = 0;

                for (c, centroid) in centroids.iter().enumerate() {
                    let distance = self.euclidean_distance(embedding, centroid);
                    if distance < min_distance {
                        min_distance = distance;
                        best_cluster = c;
                    }
                }

                new_assignments[i] = best_cluster;
                if new_assignments[i] != assignments[i] {
                    changed = true;
                }
            }

            assignments = new_assignments;

            if !changed {
                break;
            }

            // Update step: move each centroid to the mean of its assigned points
            for (c, centroid) in centroids.iter_mut().enumerate().take(k) {
                let cluster_points: Vec<_> = embeddings
                    .iter()
                    .enumerate()
                    .filter(|(i, _)| assignments[*i] == c)
                    .map(|(_, emb)| emb)
                    .collect();

                if !cluster_points.is_empty() {
                    let mut new_centroid = vec![0.0f32; dim];
                    for point in &cluster_points {
                        for (i, &value) in point.iter().enumerate() {
                            new_centroid[i] += value;
                        }
                    }
                    for value in &mut new_centroid {
                        *value /= cluster_points.len() as f32;
                    }
                    *centroid = new_centroid;
                }
            }
        }

        Ok(assignments)
    }

    /// Dispatches to the metric-specific calculation.
    fn calculate_clustering_metric(
        &self,
        metric: &ClusteringMetric,
        embeddings: &[Vec<f32>],
        assignments: &[usize],
        entities: &[String],
    ) -> Result<f64> {
        match metric {
            ClusteringMetric::SilhouetteScore => {
                self.calculate_silhouette_score(embeddings, assignments)
            }
            ClusteringMetric::Inertia => self.calculate_inertia(embeddings, assignments),
            ClusteringMetric::CalinskiHarabaszIndex => {
                self.calculate_calinski_harabasz(embeddings, assignments)
            }
            ClusteringMetric::DaviesBouldinIndex => {
                self.calculate_davies_bouldin(embeddings, assignments)
            }
            ClusteringMetric::AdjustedRandIndex => {
                if let Some(ref ground_truth) = self.ground_truth_clusters {
                    self.calculate_adjusted_rand_index(assignments, ground_truth, entities)
                } else {
                    Ok(0.0)
                }
            }
            // Remaining metrics are not yet implemented; return a neutral placeholder
            _ => Ok(0.5),
        }
    }

    /// Computes the mean silhouette coefficient: for each point, (b - a) / max(a, b),
    /// where `a` is the mean distance to points in its own cluster and `b` is the
    /// mean distance to points in the nearest other cluster.
    fn calculate_silhouette_score(
        &self,
        embeddings: &[Vec<f32>],
        assignments: &[usize],
    ) -> Result<f64> {
        if embeddings.len() != assignments.len() || embeddings.is_empty() {
            return Ok(0.0);
        }

        let mut silhouette_scores = Vec::new();

        for (i, embedding) in embeddings.iter().enumerate() {
            let own_cluster = assignments[i];

            // a: mean distance to the other points in the same cluster
            let same_cluster_points: Vec<_> = embeddings
                .iter()
                .enumerate()
                .filter(|(j, _)| *j != i && assignments[*j] == own_cluster)
                .map(|(_, emb)| emb)
                .collect();

            let a = if same_cluster_points.is_empty() {
                0.0
            } else {
                same_cluster_points
                    .iter()
                    .map(|other| self.euclidean_distance(embedding, other) as f64)
                    .sum::<f64>()
                    / same_cluster_points.len() as f64
            };

            // b: smallest mean distance to the points of any other cluster
            let unique_clusters: HashSet<usize> = assignments.iter().cloned().collect();
            let mut min_b = f64::INFINITY;

            for &cluster in &unique_clusters {
                if cluster != own_cluster {
                    let other_cluster_points: Vec<_> = embeddings
                        .iter()
                        .enumerate()
                        .filter(|(j, _)| assignments[*j] == cluster)
                        .map(|(_, emb)| emb)
                        .collect();

                    if !other_cluster_points.is_empty() {
                        let avg_distance = other_cluster_points
                            .iter()
                            .map(|other| self.euclidean_distance(embedding, other) as f64)
                            .sum::<f64>()
                            / other_cluster_points.len() as f64;

                        min_b = min_b.min(avg_distance);
                    }
                }
            }

            let b = min_b;

            // With a single cluster there is no "other" cluster, so b stays infinite;
            // treat that score as 0 to avoid propagating a NaN into the mean.
            let silhouette = if !b.is_finite() {
                0.0
            } else if a < b {
                (b - a) / b
            } else if a > b {
                (b - a) / a
            } else {
                0.0
            };

            silhouette_scores.push(silhouette);
        }

        Ok(silhouette_scores.iter().sum::<f64>() / silhouette_scores.len() as f64)
    }

    /// Computes inertia: the sum of squared distances from each point to its
    /// cluster centroid.
    fn calculate_inertia(&self, embeddings: &[Vec<f32>], assignments: &[usize]) -> Result<f64> {
        let unique_clusters: HashSet<usize> = assignments.iter().cloned().collect();
        let mut total_inertia = 0.0;

        for &cluster in &unique_clusters {
            let cluster_points: Vec<_> = embeddings
                .iter()
                .enumerate()
                .filter(|(i, _)| assignments[*i] == cluster)
                .map(|(_, emb)| emb)
                .collect();

            if cluster_points.is_empty() {
                continue;
            }

            // Centroid of the cluster
            let dim = cluster_points[0].len();
            let mut centroid = vec![0.0f32; dim];
            for point in &cluster_points {
                for (i, &value) in point.iter().enumerate() {
                    centroid[i] += value;
                }
            }
            for value in &mut centroid {
                *value /= cluster_points.len() as f32;
            }

            // Accumulate squared distances to the centroid
            for point in &cluster_points {
                let distance = self.euclidean_distance(point, &centroid);
                total_inertia += (distance * distance) as f64;
            }
        }

        Ok(total_inertia)
    }

    /// Calinski-Harabasz index. Currently a simplified placeholder rather than
    /// the full between/within dispersion ratio.
    fn calculate_calinski_harabasz(
        &self,
        embeddings: &[Vec<f32>],
        assignments: &[usize],
    ) -> Result<f64> {
        // Placeholder score proportional to the sample size
        Ok(embeddings.len() as f64 * assignments.len() as f64 / 1000.0)
    }

    /// Davies-Bouldin index. Currently returns a fixed placeholder value.
    fn calculate_davies_bouldin(
        &self,
        _embeddings: &[Vec<f32>],
        _assignments: &[usize],
    ) -> Result<f64> {
        Ok(0.5)
    }

    /// Adjusted Rand index against the ground-truth labels. Currently returns a
    /// fixed placeholder value.
    fn calculate_adjusted_rand_index(
        &self,
        _assignments: &[usize],
        _ground_truth: &HashMap<String, String>,
        _entities: &[String],
    ) -> Result<f64> {
        Ok(0.6)
    }

    /// Builds per-cluster structural statistics from the assignments.
    fn analyze_clusters(
        &self,
        _embeddings: &[Vec<f32>],
        assignments: &[usize],
    ) -> Result<ClusterAnalysis> {
        let unique_clusters: HashSet<usize> = assignments.iter().cloned().collect();
        let num_clusters = unique_clusters.len();

        let mut cluster_sizes = Vec::new();
        let cluster_cohesion = vec![0.5; num_clusters]; // Placeholder cohesion values
        let cluster_separation = vec![0.6; num_clusters]; // Placeholder separation values

        for &cluster in &unique_clusters {
            let cluster_size = assignments.iter().filter(|&&c| c == cluster).count();
            cluster_sizes.push(cluster_size);
        }

        // Placeholder: pairwise inter-cluster distances are not computed yet
        let inter_cluster_distances = Array2::zeros((num_clusters, num_clusters));

        Ok(ClusterAnalysis {
            num_clusters,
            cluster_sizes,
            cluster_cohesion,
            cluster_separation,
            inter_cluster_distances,
        })
    }

    /// Analyzes clustering stability. Currently returns fixed placeholder scores.
    fn analyze_stability(
        &self,
        _embeddings: &[Vec<f32>],
        _config: &ApplicationEvalConfig,
    ) -> Result<ClusteringStabilityAnalysis> {
        Ok(ClusteringStabilityAnalysis {
            stability_score: 0.75,
            assignment_consistency: 0.8,
            parameter_robustness: 0.7,
        })
    }

    /// Euclidean (L2) distance between two vectors of equal length.
    fn euclidean_distance(&self, v1: &[f32], v2: &[f32]) -> f32 {
        v1.iter()
            .zip(v2.iter())
            .map(|(a, b)| (a - b).powi(2))
            .sum::<f32>()
            .sqrt()
    }
}

impl Default for ClusteringEvaluator {
    fn default() -> Self {
        Self::new()
    }
}
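
#[cfg(test)]
mod tests {
    // Illustrative sketch of how the private helpers behave. These cases and
    // expected values are assumptions added for demonstration; they are not
    // part of the original test suite.
    use super::*;

    #[test]
    fn euclidean_distance_of_axis_aligned_points() {
        let evaluator = ClusteringEvaluator::new();
        // 3-4-5 right triangle: distance between (0, 0) and (3, 4) is 5
        let d = evaluator.euclidean_distance(&[0.0, 0.0], &[3.0, 4.0]);
        assert!((d - 5.0).abs() < 1e-6);
    }

    #[test]
    fn kmeans_keeps_obvious_groups_together() {
        let evaluator = ClusteringEvaluator::new();
        // Two tight groups far apart; with k = 2 the points of a group should
        // never end up in different clusters, regardless of the random init.
        let embeddings = vec![
            vec![0.0, 0.0],
            vec![0.1, 0.0],
            vec![10.0, 10.0],
            vec![10.1, 10.0],
        ];
        let assignments = evaluator.perform_clustering(&embeddings, 2).unwrap();
        assert_eq!(assignments.len(), 4);
        assert_eq!(assignments[0], assignments[1]);
        assert_eq!(assignments[2], assignments[3]);
    }
}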