// oxirs_embed/interpretability.rs

//! Model Interpretability Tools
//!
//! This module provides tools for understanding and interpreting knowledge graph
//! embeddings, including embedding similarity analysis, feature importance,
//! counterfactual explanations, and nearest-neighbor analysis.
7use anyhow::{anyhow, Result};
8use rayon::prelude::*;
9use scirs2_core::ndarray_ext::Array1;
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use tracing::info;
13
/// Interpretation method selecting which analysis `InterpretabilityAnalyzer` runs.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum InterpretationMethod {
    /// Analyze embedding similarities (cosine similarity against all other entities)
    SimilarityAnalysis,
    /// Feature importance (absolute z-score against global per-dimension statistics)
    FeatureImportance,
    /// Counterfactual explanations (requires a target entity; not supported by `analyze_entity`)
    Counterfactual,
    /// Nearest neighbors analysis (Euclidean distance)
    NearestNeighbors,
}
26
/// Interpretability configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InterpretabilityConfig {
    /// Interpretation method used by `analyze_entity`
    pub method: InterpretationMethod,
    /// Top-K most important features/neighbors/similar entities to report
    pub top_k: usize,
    /// Similarity threshold
    // NOTE(review): this field is not read anywhere in this module —
    // presumably reserved for future filtering; confirm before relying on it.
    pub similarity_threshold: f32,
    /// Enable detailed analysis (gates neighbor clustering in
    /// `nearest_neighbors_analysis`)
    pub detailed: bool,
}
39
40impl Default for InterpretabilityConfig {
41    fn default() -> Self {
42        Self {
43            method: InterpretationMethod::SimilarityAnalysis,
44            top_k: 10,
45            similarity_threshold: 0.7,
46            detailed: false,
47        }
48    }
49}
50
/// Similarity analysis result produced by `InterpretabilityAnalyzer::similarity_analysis`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SimilarityAnalysis {
    /// Entity being analyzed
    pub entity: String,
    /// Top-K most similar entities with cosine-similarity scores (descending)
    pub similar_entities: Vec<(String, f32)>,
    /// Top-K least similar entities with cosine-similarity scores
    /// (ordered from higher to lower similarity)
    pub dissimilar_entities: Vec<(String, f32)>,
    /// Average similarity to all other entities
    pub avg_similarity: f32,
}
63
/// Feature importance result: embedding dimensions ranked by absolute
/// z-score against global per-dimension statistics.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FeatureImportance {
    /// Entity being analyzed
    pub entity: String,
    /// Top-K feature indices with their importance scores (descending)
    pub important_features: Vec<(usize, f32)>,
    /// Global feature statistics the importance scores were computed against
    pub feature_stats: FeatureStats,
}
74
/// Per-dimension statistics computed across all embeddings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FeatureStats {
    /// Mean feature values
    pub mean: Vec<f32>,
    /// Standard deviation of features (population std: divided by N, not N-1)
    pub std: Vec<f32>,
    /// Min feature values
    pub min: Vec<f32>,
    /// Max feature values
    pub max: Vec<f32>,
}
87
/// Counterfactual explanation: which embedding dimensions would have to
/// change to move `original` toward `target`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CounterfactualExplanation {
    /// Original entity
    pub original: String,
    /// Target entity (for comparison)
    pub target: String,
    /// Top-K dimensions that need to change, largest change first
    pub required_changes: Vec<(usize, f32, f32)>, // (dim, from, to)
    /// Estimated difficulty (0-1, higher is harder): total change magnitude
    /// normalized by the original embedding's norm
    pub difficulty: f32,
}
100
/// Nearest neighbors analysis result.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NearestNeighborsAnalysis {
    /// Entity being analyzed
    pub entity: String,
    /// Top-K nearest neighbors with Euclidean distances (ascending)
    pub neighbors: Vec<(String, f32)>,
    /// Neighbor clusters (populated only when `InterpretabilityConfig::detailed` is set)
    pub neighbor_clusters: Vec<Vec<String>>,
}
111
/// Model interpretability analyzer for knowledge-graph embeddings.
pub struct InterpretabilityAnalyzer {
    /// Analysis configuration (method, top-K, flags)
    config: InterpretabilityConfig,
}
116
117impl InterpretabilityAnalyzer {
118    /// Create new interpretability analyzer
119    pub fn new(config: InterpretabilityConfig) -> Self {
120        info!(
121            "Initialized interpretability analyzer: method={:?}, top_k={}",
122            config.method, config.top_k
123        );
124
125        Self { config }
126    }
127
128    /// Analyze a specific entity
129    pub fn analyze_entity(
130        &self,
131        entity: &str,
132        embeddings: &HashMap<String, Array1<f32>>,
133    ) -> Result<String> {
134        if !embeddings.contains_key(entity) {
135            return Err(anyhow!("Entity not found: {}", entity));
136        }
137
138        match self.config.method {
139            InterpretationMethod::SimilarityAnalysis => {
140                let analysis = self.similarity_analysis(entity, embeddings)?;
141                Ok(serde_json::to_string_pretty(&analysis)?)
142            }
143            InterpretationMethod::FeatureImportance => {
144                let importance = self.feature_importance(entity, embeddings)?;
145                Ok(serde_json::to_string_pretty(&importance)?)
146            }
147            InterpretationMethod::NearestNeighbors => {
148                let neighbors = self.nearest_neighbors_analysis(entity, embeddings)?;
149                Ok(serde_json::to_string_pretty(&neighbors)?)
150            }
151            InterpretationMethod::Counterfactual => {
152                Err(anyhow!("Counterfactual requires target entity"))
153            }
154        }
155    }
156
157    /// Analyze similarity between entities
158    pub fn similarity_analysis(
159        &self,
160        entity: &str,
161        embeddings: &HashMap<String, Array1<f32>>,
162    ) -> Result<SimilarityAnalysis> {
163        let entity_emb = &embeddings[entity];
164
165        // Compute similarities to all other entities
166        let mut similarities: Vec<(String, f32)> = embeddings
167            .par_iter()
168            .filter(|(e, _)| *e != entity)
169            .map(|(other, other_emb)| {
170                let sim = self.cosine_similarity(entity_emb, other_emb);
171                (other.clone(), sim)
172            })
173            .collect();
174
175        // Sort by similarity descending
176        similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
177
178        // Get top-K most similar
179        let similar_entities: Vec<(String, f32)> = similarities
180            .iter()
181            .take(self.config.top_k)
182            .cloned()
183            .collect();
184
185        // Get top-K least similar
186        let mut dissimilar_entities: Vec<(String, f32)> = similarities
187            .iter()
188            .rev()
189            .take(self.config.top_k)
190            .cloned()
191            .collect();
192        dissimilar_entities.reverse();
193
194        // Compute average similarity
195        let avg_similarity =
196            similarities.iter().map(|(_, sim)| sim).sum::<f32>() / similarities.len() as f32;
197
198        info!(
199            "Similarity analysis for '{}': avg_similarity={:.4}",
200            entity, avg_similarity
201        );
202
203        Ok(SimilarityAnalysis {
204            entity: entity.to_string(),
205            similar_entities,
206            dissimilar_entities,
207            avg_similarity,
208        })
209    }
210
211    /// Analyze feature importance for an entity
212    pub fn feature_importance(
213        &self,
214        entity: &str,
215        embeddings: &HashMap<String, Array1<f32>>,
216    ) -> Result<FeatureImportance> {
217        let entity_emb = &embeddings[entity];
218        let dim = entity_emb.len();
219
220        // Compute global feature statistics
221        let feature_stats = self.compute_feature_stats(embeddings);
222
223        // Compute importance as deviation from mean
224        let mut important_features: Vec<(usize, f32)> = (0..dim)
225            .map(|i| {
226                let value = entity_emb[i];
227                let mean = feature_stats.mean[i];
228                let std = feature_stats.std[i];
229
230                // Z-score based importance
231                let importance = if std > 0.0 {
232                    ((value - mean) / std).abs()
233                } else {
234                    0.0
235                };
236
237                (i, importance)
238            })
239            .collect();
240
241        // Sort by importance descending
242        important_features
243            .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
244
245        // Keep top-K
246        important_features.truncate(self.config.top_k);
247
248        info!(
249            "Feature importance for '{}': top feature has importance {:.4}",
250            entity,
251            important_features
252                .first()
253                .map(|(_, imp)| *imp)
254                .unwrap_or(0.0)
255        );
256
257        Ok(FeatureImportance {
258            entity: entity.to_string(),
259            important_features,
260            feature_stats,
261        })
262    }
263
264    /// Generate counterfactual explanation
265    pub fn counterfactual_explanation(
266        &self,
267        original: &str,
268        target: &str,
269        embeddings: &HashMap<String, Array1<f32>>,
270    ) -> Result<CounterfactualExplanation> {
271        let original_emb = embeddings
272            .get(original)
273            .ok_or_else(|| anyhow!("Original entity not found"))?;
274
275        let target_emb = embeddings
276            .get(target)
277            .ok_or_else(|| anyhow!("Target entity not found"))?;
278
279        // Identify dimensions that differ significantly
280        let mut required_changes = Vec::new();
281        let mut total_change = 0.0;
282
283        for i in 0..original_emb.len() {
284            let diff = (target_emb[i] - original_emb[i]).abs();
285            if diff > 0.1 {
286                // Threshold for significance
287                required_changes.push((i, original_emb[i], target_emb[i]));
288                total_change += diff;
289            }
290        }
291
292        // Sort by magnitude of change
293        required_changes.sort_by(|a, b| {
294            let diff_a = (a.2 - a.1).abs();
295            let diff_b = (b.2 - b.1).abs();
296            diff_b
297                .partial_cmp(&diff_a)
298                .unwrap_or(std::cmp::Ordering::Equal)
299        });
300
301        // Keep top-K most important changes
302        required_changes.truncate(self.config.top_k);
303
304        // Compute difficulty (normalized by embedding norm)
305        let norm = original_emb.dot(original_emb).sqrt();
306        let difficulty = if norm > 0.0 {
307            (total_change / norm).min(1.0)
308        } else {
309            1.0
310        };
311
312        info!(
313            "Counterfactual '{}' -> '{}': {} changes, difficulty={:.4}",
314            original,
315            target,
316            required_changes.len(),
317            difficulty
318        );
319
320        Ok(CounterfactualExplanation {
321            original: original.to_string(),
322            target: target.to_string(),
323            required_changes,
324            difficulty,
325        })
326    }
327
328    /// Analyze nearest neighbors
329    pub fn nearest_neighbors_analysis(
330        &self,
331        entity: &str,
332        embeddings: &HashMap<String, Array1<f32>>,
333    ) -> Result<NearestNeighborsAnalysis> {
334        let entity_emb = &embeddings[entity];
335
336        // Find nearest neighbors
337        let mut distances: Vec<(String, f32)> = embeddings
338            .par_iter()
339            .filter(|(e, _)| *e != entity)
340            .map(|(other, other_emb)| {
341                let dist = self.euclidean_distance(entity_emb, other_emb);
342                (other.clone(), dist)
343            })
344            .collect();
345
346        // Sort by distance ascending
347        distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
348
349        // Get top-K nearest neighbors
350        let neighbors: Vec<(String, f32)> =
351            distances.iter().take(self.config.top_k).cloned().collect();
352
353        // Attempt to cluster neighbors (simple distance-based clustering)
354        let neighbor_clusters = if self.config.detailed {
355            self.cluster_neighbors(&neighbors, embeddings)
356        } else {
357            vec![]
358        };
359
360        info!(
361            "Nearest neighbors for '{}': closest neighbor at distance {:.4}",
362            entity,
363            neighbors.first().map(|(_, d)| *d).unwrap_or(0.0)
364        );
365
366        Ok(NearestNeighborsAnalysis {
367            entity: entity.to_string(),
368            neighbors,
369            neighbor_clusters,
370        })
371    }
372
373    /// Batch analysis for multiple entities
374    pub fn batch_analysis(
375        &self,
376        entities: &[String],
377        embeddings: &HashMap<String, Array1<f32>>,
378    ) -> Result<HashMap<String, String>> {
379        let results: HashMap<String, String> = entities
380            .par_iter()
381            .filter_map(|entity| {
382                self.analyze_entity(entity, embeddings)
383                    .ok()
384                    .map(|analysis| (entity.clone(), analysis))
385            })
386            .collect();
387
388        Ok(results)
389    }
390
    /// Compute global per-dimension statistics (mean, population std, min, max)
    /// across all embeddings in a single pass.
    ///
    /// Panics if `embeddings` is empty; callers in this module only invoke it
    /// after validating that the analyzed entity exists in the map.
    fn compute_feature_stats(&self, embeddings: &HashMap<String, Array1<f32>>) -> FeatureStats {
        let n = embeddings.len() as f32;
        // Dimensionality taken from an arbitrary embedding; assumes all
        // embeddings share the same length — TODO confirm upstream invariant.
        let dim = embeddings
            .values()
            .next()
            .expect("embeddings should not be empty")
            .len();

        let mut mean = vec![0.0; dim];
        let mut m2 = vec![0.0; dim]; // For variance calculation (Welford's M2)
        let mut min = vec![f32::INFINITY; dim];
        let mut max = vec![f32::NEG_INFINITY; dim];

        // Welford's online algorithm for mean and variance: numerically
        // stable single pass. The statement order below (delta before the
        // mean update, delta2 after) is essential to the algorithm.
        for (count, emb) in embeddings.values().enumerate() {
            let count_f = (count + 1) as f32;

            for i in 0..dim {
                let value = emb[i];

                // Update min/max
                min[i] = min[i].min(value);
                max[i] = max[i].max(value);

                // Update mean and M2
                let delta = value - mean[i];
                mean[i] += delta / count_f;
                let delta2 = value - mean[i];
                m2[i] += delta * delta2;
            }
        }

        // Population standard deviation (divide by n, not n - 1).
        let std: Vec<f32> = m2.iter().map(|&m2_val| (m2_val / n).sqrt()).collect();

        FeatureStats {
            mean,
            std,
            min,
            max,
        }
    }
434
435    /// Cluster neighbors based on distance
436    fn cluster_neighbors(
437        &self,
438        neighbors: &[(String, f32)],
439        embeddings: &HashMap<String, Array1<f32>>,
440    ) -> Vec<Vec<String>> {
441        if neighbors.len() < 2 {
442            return vec![neighbors.iter().map(|(e, _)| e.clone()).collect()];
443        }
444
445        // Simple single-linkage clustering
446        let mut clusters: Vec<Vec<String>> = Vec::new();
447        let distance_threshold = 0.5; // Threshold for clustering
448
449        for (entity, _) in neighbors {
450            let entity_emb = &embeddings[entity];
451            let mut assigned = false;
452
453            // Try to assign to existing cluster
454            for cluster in &mut clusters {
455                let cluster_center = cluster
456                    .first()
457                    .expect("collection validated to be non-empty");
458                let center_emb = &embeddings[cluster_center];
459                let dist = self.euclidean_distance(entity_emb, center_emb);
460
461                if dist <= distance_threshold {
462                    cluster.push(entity.clone());
463                    assigned = true;
464                    break;
465                }
466            }
467
468            // Create new cluster if not assigned
469            if !assigned {
470                clusters.push(vec![entity.clone()]);
471            }
472        }
473
474        clusters
475    }
476
477    /// Cosine similarity between two embeddings
478    fn cosine_similarity(&self, a: &Array1<f32>, b: &Array1<f32>) -> f32 {
479        let dot = a.dot(b);
480        let norm_a = a.dot(a).sqrt();
481        let norm_b = b.dot(b).sqrt();
482
483        if norm_a == 0.0 || norm_b == 0.0 {
484            0.0
485        } else {
486            dot / (norm_a * norm_b)
487        }
488    }
489
490    /// Euclidean distance between two embeddings
491    fn euclidean_distance(&self, a: &Array1<f32>, b: &Array1<f32>) -> f32 {
492        let diff = a - b;
493        diff.dot(&diff).sqrt()
494    }
495
496    /// Generate interpretation report
497    pub fn generate_report(
498        &self,
499        entity: &str,
500        embeddings: &HashMap<String, Array1<f32>>,
501    ) -> Result<String> {
502        let mut report = String::new();
503
504        report.push_str(&format!("# Interpretability Report for '{}'\n\n", entity));
505
506        // Similarity analysis
507        if let Ok(sim_analysis) = self.similarity_analysis(entity, embeddings) {
508            report.push_str("## Similarity Analysis\n\n");
509            report.push_str(&format!(
510                "Average similarity: {:.4}\n\n",
511                sim_analysis.avg_similarity
512            ));
513
514            report.push_str("### Most Similar Entities:\n");
515            for (i, (other, score)) in sim_analysis.similar_entities.iter().enumerate() {
516                report.push_str(&format!(
517                    "{}. {} (similarity: {:.4})\n",
518                    i + 1,
519                    other,
520                    score
521                ));
522            }
523
524            report.push_str("\n### Least Similar Entities:\n");
525            for (i, (other, score)) in sim_analysis.dissimilar_entities.iter().enumerate() {
526                report.push_str(&format!(
527                    "{}. {} (similarity: {:.4})\n",
528                    i + 1,
529                    other,
530                    score
531                ));
532            }
533            report.push('\n');
534        }
535
536        // Feature importance
537        if let Ok(feat_importance) = self.feature_importance(entity, embeddings) {
538            report.push_str("## Feature Importance\n\n");
539            report.push_str("### Top Important Features:\n");
540            for (i, (feature_idx, importance)) in
541                feat_importance.important_features.iter().enumerate()
542            {
543                report.push_str(&format!(
544                    "{}. Dimension {} (importance: {:.4})\n",
545                    i + 1,
546                    feature_idx,
547                    importance
548                ));
549            }
550            report.push('\n');
551        }
552
553        // Nearest neighbors
554        if let Ok(nn_analysis) = self.nearest_neighbors_analysis(entity, embeddings) {
555            report.push_str("## Nearest Neighbors\n\n");
556            for (i, (neighbor, distance)) in nn_analysis.neighbors.iter().enumerate() {
557                report.push_str(&format!(
558                    "{}. {} (distance: {:.4})\n",
559                    i + 1,
560                    neighbor,
561                    distance
562                ));
563            }
564            report.push('\n');
565        }
566
567        Ok(report)
568    }
569}
570
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray_ext::array;

    // e2 is a near-duplicate of e1 while e3 is orthogonal, so e2 must
    // rank first and top_k must cap the result size.
    #[test]
    fn test_similarity_analysis() {
        let mut embeddings = HashMap::new();
        embeddings.insert("e1".to_string(), array![1.0, 0.0, 0.0]);
        embeddings.insert("e2".to_string(), array![0.9, 0.1, 0.0]);
        embeddings.insert("e3".to_string(), array![0.0, 1.0, 0.0]);

        let config = InterpretabilityConfig {
            method: InterpretationMethod::SimilarityAnalysis,
            top_k: 2,
            ..Default::default()
        };

        let analyzer = InterpretabilityAnalyzer::new(config);
        let analysis = analyzer.similarity_analysis("e1", &embeddings).unwrap();

        assert_eq!(analysis.entity, "e1");
        assert_eq!(analysis.similar_entities.len(), 2);
        // e2 should be most similar to e1
        assert_eq!(analysis.similar_entities[0].0, "e2");
    }

    // e4 is an outlier in dimension 0, so that dimension's |z-score|
    // should dominate its importance ranking.
    #[test]
    fn test_feature_importance() {
        let mut embeddings = HashMap::new();
        embeddings.insert("e1".to_string(), array![1.0, 0.0, 0.0]);
        embeddings.insert("e2".to_string(), array![0.0, 1.0, 0.0]);
        embeddings.insert("e3".to_string(), array![0.0, 0.0, 1.0]);
        embeddings.insert("e4".to_string(), array![5.0, 0.0, 0.0]); // Outlier in dim 0

        let config = InterpretabilityConfig {
            method: InterpretationMethod::FeatureImportance,
            top_k: 3,
            ..Default::default()
        };

        let analyzer = InterpretabilityAnalyzer::new(config);
        let importance = analyzer.feature_importance("e4", &embeddings).unwrap();

        assert_eq!(importance.entity, "e4");
        assert!(!importance.important_features.is_empty());
        // Dimension 0 should be most important for e4 (outlier)
        assert_eq!(importance.important_features[0].0, 0);
    }

    // e1 and e2 differ in two dimensions by more than the significance
    // threshold, so changes must be reported with positive difficulty.
    #[test]
    fn test_counterfactual() {
        let mut embeddings = HashMap::new();
        embeddings.insert("e1".to_string(), array![1.0, 0.0, 0.0]);
        embeddings.insert("e2".to_string(), array![0.0, 1.0, 0.0]);

        let config = InterpretabilityConfig::default();
        let analyzer = InterpretabilityAnalyzer::new(config);

        let cf = analyzer
            .counterfactual_explanation("e1", "e2", &embeddings)
            .unwrap();

        assert_eq!(cf.original, "e1");
        assert_eq!(cf.target, "e2");
        assert!(!cf.required_changes.is_empty());
        assert!(cf.difficulty > 0.0);
    }

    // e2 is closest to e1 in Euclidean distance; e3 is far away.
    #[test]
    fn test_nearest_neighbors() {
        let mut embeddings = HashMap::new();
        embeddings.insert("e1".to_string(), array![1.0, 0.0]);
        embeddings.insert("e2".to_string(), array![1.1, 0.1]);
        embeddings.insert("e3".to_string(), array![5.0, 5.0]);

        let config = InterpretabilityConfig {
            method: InterpretationMethod::NearestNeighbors,
            top_k: 2,
            ..Default::default()
        };

        let analyzer = InterpretabilityAnalyzer::new(config);
        let nn = analyzer
            .nearest_neighbors_analysis("e1", &embeddings)
            .unwrap();

        assert_eq!(nn.entity, "e1");
        assert_eq!(nn.neighbors.len(), 2);
        // e2 should be nearest to e1
        assert_eq!(nn.neighbors[0].0, "e2");
    }

    // With all analyses succeeding, every section header must appear in
    // the generated Markdown report.
    #[test]
    fn test_generate_report() {
        let mut embeddings = HashMap::new();
        embeddings.insert("e1".to_string(), array![1.0, 0.0, 0.0]);
        embeddings.insert("e2".to_string(), array![0.9, 0.1, 0.0]);

        let config = InterpretabilityConfig::default();
        let analyzer = InterpretabilityAnalyzer::new(config);

        let report = analyzer.generate_report("e1", &embeddings).unwrap();

        assert!(report.contains("Interpretability Report"));
        assert!(report.contains("Similarity Analysis"));
        assert!(report.contains("Feature Importance"));
        assert!(report.contains("Nearest Neighbors"));
    }
}