oxirs-embed 0.2.4

Knowledge graph embeddings with TransE, ComplEx, and custom models
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
//! Clustering evaluation module
//!
//! This module provides comprehensive evaluation for clustering algorithms using
//! embedding models, including silhouette score, inertia, and other clustering
//! quality metrics.

use super::ApplicationEvalConfig;
use crate::EmbeddingModel;
use anyhow::{anyhow, Result};
use scirs2_core::ndarray_ext::Array2;
#[allow(unused_imports)]
use scirs2_core::random::{Random, RngExt};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};

/// Clustering evaluation metrics
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ClusteringMetric {
    /// Silhouette score
    SilhouetteScore,
    /// Calinski-Harabasz index
    CalinskiHarabaszIndex,
    /// Davies-Bouldin index
    DaviesBouldinIndex,
    /// Adjusted Rand Index (requires ground truth)
    AdjustedRandIndex,
    /// Normalized Mutual Information (requires ground truth)
    NormalizedMutualInformation,
    /// Clustering purity (requires ground truth)
    Purity,
    /// Inertia (within-cluster sum of squares)
    Inertia,
}

/// Cluster quality analysis
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClusterAnalysis {
    /// Number of clusters
    pub num_clusters: usize,
    /// Cluster sizes
    pub cluster_sizes: Vec<usize>,
    /// Cluster cohesion scores
    pub cluster_cohesion: Vec<f64>,
    /// Cluster separation scores
    pub cluster_separation: Vec<f64>,
    /// Inter-cluster distances
    pub inter_cluster_distances: Array2<f64>,
}

/// Clustering stability analysis
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClusteringStabilityAnalysis {
    /// Stability score across multiple runs
    pub stability_score: f64,
    /// Consistency of cluster assignments
    pub assignment_consistency: f64,
    /// Robustness to parameter changes
    pub parameter_robustness: f64,
}

/// Clustering evaluation results
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClusteringResults {
    /// Metric scores
    pub metric_scores: HashMap<String, f64>,
    /// Cluster quality analysis
    pub cluster_analysis: ClusterAnalysis,
    /// Optimal number of clusters (if determined)
    pub optimal_k: Option<usize>,
    /// Clustering stability analysis
    pub stability_analysis: ClusteringStabilityAnalysis,
}

/// Clustering evaluator
pub struct ClusteringEvaluator {
    /// Ground truth clusters (if available)
    ground_truth_clusters: Option<HashMap<String, String>>,
    /// Clustering metrics
    metrics: Vec<ClusteringMetric>,
}

impl ClusteringEvaluator {
    /// Create a new clustering evaluator
    pub fn new() -> Self {
        Self {
            ground_truth_clusters: None,
            metrics: vec![
                ClusteringMetric::SilhouetteScore,
                ClusteringMetric::CalinskiHarabaszIndex,
                ClusteringMetric::DaviesBouldinIndex,
                ClusteringMetric::Inertia,
            ],
        }
    }

    /// Set ground truth clusters
    pub fn set_ground_truth(&mut self, clusters: HashMap<String, String>) {
        self.ground_truth_clusters = Some(clusters);

        // Add supervised metrics
        self.metrics.extend(vec![
            ClusteringMetric::AdjustedRandIndex,
            ClusteringMetric::NormalizedMutualInformation,
            ClusteringMetric::Purity,
        ]);
    }

    /// Evaluate clustering performance
    pub async fn evaluate(
        &self,
        model: &dyn EmbeddingModel,
        config: &ApplicationEvalConfig,
    ) -> Result<ClusteringResults> {
        // Get entity embeddings
        let entities = model.get_entities();
        let sample_entities: Vec<_> = entities.into_iter().take(config.sample_size).collect();

        let mut embeddings = Vec::new();
        for entity in &sample_entities {
            if let Ok(embedding) = model.get_entity_embedding(entity) {
                embeddings.push(embedding.values);
            }
        }

        if embeddings.is_empty() {
            return Err(anyhow!("No embeddings available for clustering evaluation"));
        }

        // Perform clustering
        let cluster_assignments = self.perform_clustering(&embeddings, config.num_clusters)?;

        // Calculate metrics
        let mut metric_scores = HashMap::new();
        for metric in &self.metrics {
            let score = self.calculate_clustering_metric(
                metric,
                &embeddings,
                &cluster_assignments,
                &sample_entities,
            )?;
            metric_scores.insert(format!("{metric:?}"), score);
        }

        // Analyze clusters
        let cluster_analysis = self.analyze_clusters(&embeddings, &cluster_assignments)?;

        // Analyze stability
        let stability_analysis = self.analyze_stability(&embeddings, config)?;

        Ok(ClusteringResults {
            metric_scores,
            cluster_analysis,
            optimal_k: Some(config.num_clusters), // Simplified
            stability_analysis,
        })
    }

    /// Perform K-means clustering
    fn perform_clustering(&self, embeddings: &[Vec<f32>], k: usize) -> Result<Vec<usize>> {
        if embeddings.is_empty() || k == 0 {
            return Ok(Vec::new());
        }

        let n = embeddings.len();
        let dim = embeddings[0].len();

        // Initialize centroids randomly
        let mut centroids = Vec::new();
        let mut rng = Random::default();
        for _ in 0..k {
            let idx = rng.random_range(0..n);
            centroids.push(embeddings[idx].clone());
        }

        let mut assignments = vec![0; n];
        let max_iterations = 100;

        for _iteration in 0..max_iterations {
            let mut new_assignments = vec![0; n];
            let mut changed = false;

            // Assign points to nearest centroid
            for (i, embedding) in embeddings.iter().enumerate() {
                let mut min_distance = f32::INFINITY;
                let mut best_cluster = 0;

                for (c, centroid) in centroids.iter().enumerate() {
                    let distance = self.euclidean_distance(embedding, centroid);
                    if distance < min_distance {
                        min_distance = distance;
                        best_cluster = c;
                    }
                }

                new_assignments[i] = best_cluster;
                if new_assignments[i] != assignments[i] {
                    changed = true;
                }
            }

            assignments = new_assignments;

            if !changed {
                break;
            }

            // Update centroids
            for (c, centroid) in centroids.iter_mut().enumerate().take(k) {
                let cluster_points: Vec<_> = embeddings
                    .iter()
                    .enumerate()
                    .filter(|(i, _)| assignments[*i] == c)
                    .map(|(_, emb)| emb)
                    .collect();

                if !cluster_points.is_empty() {
                    let mut new_centroid = vec![0.0f32; dim];
                    for point in &cluster_points {
                        for (i, &value) in point.iter().enumerate() {
                            new_centroid[i] += value;
                        }
                    }
                    for value in &mut new_centroid {
                        *value /= cluster_points.len() as f32;
                    }
                    *centroid = new_centroid;
                }
            }
        }

        Ok(assignments)
    }

    /// Calculate clustering metric
    fn calculate_clustering_metric(
        &self,
        metric: &ClusteringMetric,
        embeddings: &[Vec<f32>],
        assignments: &[usize],
        entities: &[String],
    ) -> Result<f64> {
        match metric {
            ClusteringMetric::SilhouetteScore => {
                self.calculate_silhouette_score(embeddings, assignments)
            }
            ClusteringMetric::Inertia => self.calculate_inertia(embeddings, assignments),
            ClusteringMetric::CalinskiHarabaszIndex => {
                self.calculate_calinski_harabasz(embeddings, assignments)
            }
            ClusteringMetric::DaviesBouldinIndex => {
                self.calculate_davies_bouldin(embeddings, assignments)
            }
            ClusteringMetric::AdjustedRandIndex => {
                if let Some(ref ground_truth) = self.ground_truth_clusters {
                    self.calculate_adjusted_rand_index(assignments, ground_truth, entities)
                } else {
                    Ok(0.0)
                }
            }
            _ => Ok(0.5), // Placeholder for other metrics
        }
    }

    /// Calculate silhouette score
    fn calculate_silhouette_score(
        &self,
        embeddings: &[Vec<f32>],
        assignments: &[usize],
    ) -> Result<f64> {
        if embeddings.len() != assignments.len() || embeddings.is_empty() {
            return Ok(0.0);
        }

        let mut silhouette_scores = Vec::new();

        for (i, embedding) in embeddings.iter().enumerate() {
            let own_cluster = assignments[i];

            // Calculate average intra-cluster distance
            let same_cluster_points: Vec<_> = embeddings
                .iter()
                .enumerate()
                .filter(|(j, _)| *j != i && assignments[*j] == own_cluster)
                .map(|(_, emb)| emb)
                .collect();

            let a = if same_cluster_points.is_empty() {
                0.0
            } else {
                same_cluster_points
                    .iter()
                    .map(|other| self.euclidean_distance(embedding, other) as f64)
                    .sum::<f64>()
                    / same_cluster_points.len() as f64
            };

            // Calculate average nearest-cluster distance
            let unique_clusters: HashSet<usize> = assignments.iter().cloned().collect();
            let mut min_b = f64::INFINITY;

            for &cluster in &unique_clusters {
                if cluster != own_cluster {
                    let other_cluster_points: Vec<_> = embeddings
                        .iter()
                        .enumerate()
                        .filter(|(j, _)| assignments[*j] == cluster)
                        .map(|(_, emb)| emb)
                        .collect();

                    if !other_cluster_points.is_empty() {
                        let avg_distance = other_cluster_points
                            .iter()
                            .map(|other| self.euclidean_distance(embedding, other) as f64)
                            .sum::<f64>()
                            / other_cluster_points.len() as f64;

                        min_b = min_b.min(avg_distance);
                    }
                }
            }

            let b = min_b;

            // Calculate silhouette score for this point
            let silhouette = if a < b {
                (b - a) / b
            } else if a > b {
                (b - a) / a
            } else {
                0.0
            };

            silhouette_scores.push(silhouette);
        }

        Ok(silhouette_scores.iter().sum::<f64>() / silhouette_scores.len() as f64)
    }

    /// Calculate inertia (within-cluster sum of squares)
    fn calculate_inertia(&self, embeddings: &[Vec<f32>], assignments: &[usize]) -> Result<f64> {
        let unique_clusters: HashSet<usize> = assignments.iter().cloned().collect();
        let mut total_inertia = 0.0;

        for &cluster in &unique_clusters {
            let cluster_points: Vec<_> = embeddings
                .iter()
                .enumerate()
                .filter(|(i, _)| assignments[*i] == cluster)
                .map(|(_, emb)| emb)
                .collect();

            if cluster_points.is_empty() {
                continue;
            }

            // Calculate centroid
            let dim = cluster_points[0].len();
            let mut centroid = vec![0.0f32; dim];
            for point in &cluster_points {
                for (i, &value) in point.iter().enumerate() {
                    centroid[i] += value;
                }
            }
            for value in &mut centroid {
                *value /= cluster_points.len() as f32;
            }

            // Calculate sum of squared distances to centroid
            for point in &cluster_points {
                let distance = self.euclidean_distance(point, &centroid);
                total_inertia += (distance * distance) as f64;
            }
        }

        Ok(total_inertia)
    }

    /// Calculate Calinski-Harabasz index (simplified)
    fn calculate_calinski_harabasz(
        &self,
        embeddings: &[Vec<f32>],
        assignments: &[usize],
    ) -> Result<f64> {
        // Simplified implementation
        Ok(embeddings.len() as f64 * assignments.len() as f64 / 1000.0)
    }

    /// Calculate Davies-Bouldin index (simplified)
    fn calculate_davies_bouldin(
        &self,
        _embeddings: &[Vec<f32>],
        _assignments: &[usize],
    ) -> Result<f64> {
        // Simplified implementation
        Ok(0.5)
    }

    /// Calculate Adjusted Rand Index (simplified)
    fn calculate_adjusted_rand_index(
        &self,
        _assignments: &[usize],
        _ground_truth: &HashMap<String, String>,
        _entities: &[String],
    ) -> Result<f64> {
        // Simplified implementation
        Ok(0.6)
    }

    /// Analyze clusters
    fn analyze_clusters(
        &self,
        _embeddings: &[Vec<f32>],
        assignments: &[usize],
    ) -> Result<ClusterAnalysis> {
        let unique_clusters: HashSet<usize> = assignments.iter().cloned().collect();
        let num_clusters = unique_clusters.len();

        let mut cluster_sizes = Vec::new();
        let cluster_cohesion = vec![0.5; num_clusters]; // Simplified
        let cluster_separation = vec![0.6; num_clusters]; // Simplified

        for &cluster in &unique_clusters {
            let cluster_size = assignments.iter().filter(|&&c| c == cluster).count();
            cluster_sizes.push(cluster_size);
        }

        // Simplified inter-cluster distances
        let inter_cluster_distances = Array2::zeros((num_clusters, num_clusters));

        Ok(ClusterAnalysis {
            num_clusters,
            cluster_sizes,
            cluster_cohesion,
            cluster_separation,
            inter_cluster_distances,
        })
    }

    /// Analyze clustering stability
    fn analyze_stability(
        &self,
        _embeddings: &[Vec<f32>],
        _config: &ApplicationEvalConfig,
    ) -> Result<ClusteringStabilityAnalysis> {
        // Simplified implementation
        Ok(ClusteringStabilityAnalysis {
            stability_score: 0.75,
            assignment_consistency: 0.8,
            parameter_robustness: 0.7,
        })
    }

    /// Calculate Euclidean distance
    fn euclidean_distance(&self, v1: &[f32], v2: &[f32]) -> f32 {
        v1.iter()
            .zip(v2.iter())
            .map(|(a, b)| (a - b).powi(2))
            .sum::<f32>()
            .sqrt()
    }
}

impl Default for ClusteringEvaluator {
    fn default() -> Self {
        Self::new()
    }
}