// oxirs_embed/interpretability.rs

1//! Model Interpretability Tools
2//!
3//! This module provides tools for understanding and interpreting knowledge graph
4//! embeddings, including attention analysis, embedding similarity, feature importance,
5//! and counterfactual explanations.
6
7use anyhow::{anyhow, Result};
8use rayon::prelude::*;
9use scirs2_core::ndarray_ext::Array1;
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use tracing::info;
13
14/// Interpretation method
15#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
16pub enum InterpretationMethod {
17    /// Analyze embedding similarities
18    SimilarityAnalysis,
19    /// Feature importance (gradient-based)
20    FeatureImportance,
21    /// Counterfactual explanations
22    Counterfactual,
23    /// Nearest neighbors analysis
24    NearestNeighbors,
25}
26
27/// Interpretability configuration
28#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct InterpretabilityConfig {
30    /// Interpretation method
31    pub method: InterpretationMethod,
32    /// Top-K most important features/neighbors
33    pub top_k: usize,
34    /// Similarity threshold
35    pub similarity_threshold: f32,
36    /// Enable detailed analysis
37    pub detailed: bool,
38}
39
40impl Default for InterpretabilityConfig {
41    fn default() -> Self {
42        Self {
43            method: InterpretationMethod::SimilarityAnalysis,
44            top_k: 10,
45            similarity_threshold: 0.7,
46            detailed: false,
47        }
48    }
49}
50
51/// Similarity analysis result
52#[derive(Debug, Clone, Serialize, Deserialize)]
53pub struct SimilarityAnalysis {
54    /// Entity being analyzed
55    pub entity: String,
56    /// Most similar entities with scores
57    pub similar_entities: Vec<(String, f32)>,
58    /// Least similar entities with scores
59    pub dissimilar_entities: Vec<(String, f32)>,
60    /// Average similarity to all other entities
61    pub avg_similarity: f32,
62}
63
64/// Feature importance result
65#[derive(Debug, Clone, Serialize, Deserialize)]
66pub struct FeatureImportance {
67    /// Entity being analyzed
68    pub entity: String,
69    /// Feature indices and their importance scores
70    pub important_features: Vec<(usize, f32)>,
71    /// Feature statistics
72    pub feature_stats: FeatureStats,
73}
74
75/// Feature statistics
76#[derive(Debug, Clone, Serialize, Deserialize)]
77pub struct FeatureStats {
78    /// Mean feature values
79    pub mean: Vec<f32>,
80    /// Standard deviation of features
81    pub std: Vec<f32>,
82    /// Min feature values
83    pub min: Vec<f32>,
84    /// Max feature values
85    pub max: Vec<f32>,
86}
87
88/// Counterfactual explanation
89#[derive(Debug, Clone, Serialize, Deserialize)]
90pub struct CounterfactualExplanation {
91    /// Original entity
92    pub original: String,
93    /// Target entity (for comparison)
94    pub target: String,
95    /// Dimensions that need to change
96    pub required_changes: Vec<(usize, f32, f32)>, // (dim, from, to)
97    /// Estimated difficulty (0-1, higher is harder)
98    pub difficulty: f32,
99}
100
101/// Nearest neighbors analysis
102#[derive(Debug, Clone, Serialize, Deserialize)]
103pub struct NearestNeighborsAnalysis {
104    /// Entity being analyzed
105    pub entity: String,
106    /// Nearest neighbors with distances
107    pub neighbors: Vec<(String, f32)>,
108    /// Neighbor clusters (if detected)
109    pub neighbor_clusters: Vec<Vec<String>>,
110}
111
112/// Model interpretability analyzer
113pub struct InterpretabilityAnalyzer {
114    config: InterpretabilityConfig,
115}
116
117impl InterpretabilityAnalyzer {
118    /// Create new interpretability analyzer
119    pub fn new(config: InterpretabilityConfig) -> Self {
120        info!(
121            "Initialized interpretability analyzer: method={:?}, top_k={}",
122            config.method, config.top_k
123        );
124
125        Self { config }
126    }
127
128    /// Analyze a specific entity
129    pub fn analyze_entity(
130        &self,
131        entity: &str,
132        embeddings: &HashMap<String, Array1<f32>>,
133    ) -> Result<String> {
134        if !embeddings.contains_key(entity) {
135            return Err(anyhow!("Entity not found: {}", entity));
136        }
137
138        match self.config.method {
139            InterpretationMethod::SimilarityAnalysis => {
140                let analysis = self.similarity_analysis(entity, embeddings)?;
141                Ok(serde_json::to_string_pretty(&analysis)?)
142            }
143            InterpretationMethod::FeatureImportance => {
144                let importance = self.feature_importance(entity, embeddings)?;
145                Ok(serde_json::to_string_pretty(&importance)?)
146            }
147            InterpretationMethod::NearestNeighbors => {
148                let neighbors = self.nearest_neighbors_analysis(entity, embeddings)?;
149                Ok(serde_json::to_string_pretty(&neighbors)?)
150            }
151            InterpretationMethod::Counterfactual => {
152                Err(anyhow!("Counterfactual requires target entity"))
153            }
154        }
155    }
156
157    /// Analyze similarity between entities
158    pub fn similarity_analysis(
159        &self,
160        entity: &str,
161        embeddings: &HashMap<String, Array1<f32>>,
162    ) -> Result<SimilarityAnalysis> {
163        let entity_emb = &embeddings[entity];
164
165        // Compute similarities to all other entities
166        let mut similarities: Vec<(String, f32)> = embeddings
167            .par_iter()
168            .filter(|(e, _)| *e != entity)
169            .map(|(other, other_emb)| {
170                let sim = self.cosine_similarity(entity_emb, other_emb);
171                (other.clone(), sim)
172            })
173            .collect();
174
175        // Sort by similarity descending
176        similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
177
178        // Get top-K most similar
179        let similar_entities: Vec<(String, f32)> = similarities
180            .iter()
181            .take(self.config.top_k)
182            .cloned()
183            .collect();
184
185        // Get top-K least similar
186        let mut dissimilar_entities: Vec<(String, f32)> = similarities
187            .iter()
188            .rev()
189            .take(self.config.top_k)
190            .cloned()
191            .collect();
192        dissimilar_entities.reverse();
193
194        // Compute average similarity
195        let avg_similarity =
196            similarities.iter().map(|(_, sim)| sim).sum::<f32>() / similarities.len() as f32;
197
198        info!(
199            "Similarity analysis for '{}': avg_similarity={:.4}",
200            entity, avg_similarity
201        );
202
203        Ok(SimilarityAnalysis {
204            entity: entity.to_string(),
205            similar_entities,
206            dissimilar_entities,
207            avg_similarity,
208        })
209    }
210
211    /// Analyze feature importance for an entity
212    pub fn feature_importance(
213        &self,
214        entity: &str,
215        embeddings: &HashMap<String, Array1<f32>>,
216    ) -> Result<FeatureImportance> {
217        let entity_emb = &embeddings[entity];
218        let dim = entity_emb.len();
219
220        // Compute global feature statistics
221        let feature_stats = self.compute_feature_stats(embeddings);
222
223        // Compute importance as deviation from mean
224        let mut important_features: Vec<(usize, f32)> = (0..dim)
225            .map(|i| {
226                let value = entity_emb[i];
227                let mean = feature_stats.mean[i];
228                let std = feature_stats.std[i];
229
230                // Z-score based importance
231                let importance = if std > 0.0 {
232                    ((value - mean) / std).abs()
233                } else {
234                    0.0
235                };
236
237                (i, importance)
238            })
239            .collect();
240
241        // Sort by importance descending
242        important_features
243            .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
244
245        // Keep top-K
246        important_features.truncate(self.config.top_k);
247
248        info!(
249            "Feature importance for '{}': top feature has importance {:.4}",
250            entity,
251            important_features
252                .first()
253                .map(|(_, imp)| *imp)
254                .unwrap_or(0.0)
255        );
256
257        Ok(FeatureImportance {
258            entity: entity.to_string(),
259            important_features,
260            feature_stats,
261        })
262    }
263
264    /// Generate counterfactual explanation
265    pub fn counterfactual_explanation(
266        &self,
267        original: &str,
268        target: &str,
269        embeddings: &HashMap<String, Array1<f32>>,
270    ) -> Result<CounterfactualExplanation> {
271        let original_emb = embeddings
272            .get(original)
273            .ok_or_else(|| anyhow!("Original entity not found"))?;
274
275        let target_emb = embeddings
276            .get(target)
277            .ok_or_else(|| anyhow!("Target entity not found"))?;
278
279        // Identify dimensions that differ significantly
280        let mut required_changes = Vec::new();
281        let mut total_change = 0.0;
282
283        for i in 0..original_emb.len() {
284            let diff = (target_emb[i] - original_emb[i]).abs();
285            if diff > 0.1 {
286                // Threshold for significance
287                required_changes.push((i, original_emb[i], target_emb[i]));
288                total_change += diff;
289            }
290        }
291
292        // Sort by magnitude of change
293        required_changes.sort_by(|a, b| {
294            let diff_a = (a.2 - a.1).abs();
295            let diff_b = (b.2 - b.1).abs();
296            diff_b
297                .partial_cmp(&diff_a)
298                .unwrap_or(std::cmp::Ordering::Equal)
299        });
300
301        // Keep top-K most important changes
302        required_changes.truncate(self.config.top_k);
303
304        // Compute difficulty (normalized by embedding norm)
305        let norm = original_emb.dot(original_emb).sqrt();
306        let difficulty = if norm > 0.0 {
307            (total_change / norm).min(1.0)
308        } else {
309            1.0
310        };
311
312        info!(
313            "Counterfactual '{}' -> '{}': {} changes, difficulty={:.4}",
314            original,
315            target,
316            required_changes.len(),
317            difficulty
318        );
319
320        Ok(CounterfactualExplanation {
321            original: original.to_string(),
322            target: target.to_string(),
323            required_changes,
324            difficulty,
325        })
326    }
327
328    /// Analyze nearest neighbors
329    pub fn nearest_neighbors_analysis(
330        &self,
331        entity: &str,
332        embeddings: &HashMap<String, Array1<f32>>,
333    ) -> Result<NearestNeighborsAnalysis> {
334        let entity_emb = &embeddings[entity];
335
336        // Find nearest neighbors
337        let mut distances: Vec<(String, f32)> = embeddings
338            .par_iter()
339            .filter(|(e, _)| *e != entity)
340            .map(|(other, other_emb)| {
341                let dist = self.euclidean_distance(entity_emb, other_emb);
342                (other.clone(), dist)
343            })
344            .collect();
345
346        // Sort by distance ascending
347        distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
348
349        // Get top-K nearest neighbors
350        let neighbors: Vec<(String, f32)> =
351            distances.iter().take(self.config.top_k).cloned().collect();
352
353        // Attempt to cluster neighbors (simple distance-based clustering)
354        let neighbor_clusters = if self.config.detailed {
355            self.cluster_neighbors(&neighbors, embeddings)
356        } else {
357            vec![]
358        };
359
360        info!(
361            "Nearest neighbors for '{}': closest neighbor at distance {:.4}",
362            entity,
363            neighbors.first().map(|(_, d)| *d).unwrap_or(0.0)
364        );
365
366        Ok(NearestNeighborsAnalysis {
367            entity: entity.to_string(),
368            neighbors,
369            neighbor_clusters,
370        })
371    }
372
373    /// Batch analysis for multiple entities
374    pub fn batch_analysis(
375        &self,
376        entities: &[String],
377        embeddings: &HashMap<String, Array1<f32>>,
378    ) -> Result<HashMap<String, String>> {
379        let results: HashMap<String, String> = entities
380            .par_iter()
381            .filter_map(|entity| {
382                self.analyze_entity(entity, embeddings)
383                    .ok()
384                    .map(|analysis| (entity.clone(), analysis))
385            })
386            .collect();
387
388        Ok(results)
389    }
390
391    /// Compute global feature statistics
392    fn compute_feature_stats(&self, embeddings: &HashMap<String, Array1<f32>>) -> FeatureStats {
393        let n = embeddings.len() as f32;
394        let dim = embeddings.values().next().unwrap().len();
395
396        let mut mean = vec![0.0; dim];
397        let mut m2 = vec![0.0; dim]; // For variance calculation
398        let mut min = vec![f32::INFINITY; dim];
399        let mut max = vec![f32::NEG_INFINITY; dim];
400
401        // Welford's online algorithm for mean and variance
402        for (count, emb) in embeddings.values().enumerate() {
403            let count_f = (count + 1) as f32;
404
405            for i in 0..dim {
406                let value = emb[i];
407
408                // Update min/max
409                min[i] = min[i].min(value);
410                max[i] = max[i].max(value);
411
412                // Update mean and M2
413                let delta = value - mean[i];
414                mean[i] += delta / count_f;
415                let delta2 = value - mean[i];
416                m2[i] += delta * delta2;
417            }
418        }
419
420        // Compute standard deviation
421        let std: Vec<f32> = m2.iter().map(|&m2_val| (m2_val / n).sqrt()).collect();
422
423        FeatureStats {
424            mean,
425            std,
426            min,
427            max,
428        }
429    }
430
431    /// Cluster neighbors based on distance
432    fn cluster_neighbors(
433        &self,
434        neighbors: &[(String, f32)],
435        embeddings: &HashMap<String, Array1<f32>>,
436    ) -> Vec<Vec<String>> {
437        if neighbors.len() < 2 {
438            return vec![neighbors.iter().map(|(e, _)| e.clone()).collect()];
439        }
440
441        // Simple single-linkage clustering
442        let mut clusters: Vec<Vec<String>> = Vec::new();
443        let distance_threshold = 0.5; // Threshold for clustering
444
445        for (entity, _) in neighbors {
446            let entity_emb = &embeddings[entity];
447            let mut assigned = false;
448
449            // Try to assign to existing cluster
450            for cluster in &mut clusters {
451                let cluster_center = cluster.first().unwrap();
452                let center_emb = &embeddings[cluster_center];
453                let dist = self.euclidean_distance(entity_emb, center_emb);
454
455                if dist <= distance_threshold {
456                    cluster.push(entity.clone());
457                    assigned = true;
458                    break;
459                }
460            }
461
462            // Create new cluster if not assigned
463            if !assigned {
464                clusters.push(vec![entity.clone()]);
465            }
466        }
467
468        clusters
469    }
470
471    /// Cosine similarity between two embeddings
472    fn cosine_similarity(&self, a: &Array1<f32>, b: &Array1<f32>) -> f32 {
473        let dot = a.dot(b);
474        let norm_a = a.dot(a).sqrt();
475        let norm_b = b.dot(b).sqrt();
476
477        if norm_a == 0.0 || norm_b == 0.0 {
478            0.0
479        } else {
480            dot / (norm_a * norm_b)
481        }
482    }
483
484    /// Euclidean distance between two embeddings
485    fn euclidean_distance(&self, a: &Array1<f32>, b: &Array1<f32>) -> f32 {
486        let diff = a - b;
487        diff.dot(&diff).sqrt()
488    }
489
490    /// Generate interpretation report
491    pub fn generate_report(
492        &self,
493        entity: &str,
494        embeddings: &HashMap<String, Array1<f32>>,
495    ) -> Result<String> {
496        let mut report = String::new();
497
498        report.push_str(&format!("# Interpretability Report for '{}'\n\n", entity));
499
500        // Similarity analysis
501        if let Ok(sim_analysis) = self.similarity_analysis(entity, embeddings) {
502            report.push_str("## Similarity Analysis\n\n");
503            report.push_str(&format!(
504                "Average similarity: {:.4}\n\n",
505                sim_analysis.avg_similarity
506            ));
507
508            report.push_str("### Most Similar Entities:\n");
509            for (i, (other, score)) in sim_analysis.similar_entities.iter().enumerate() {
510                report.push_str(&format!(
511                    "{}. {} (similarity: {:.4})\n",
512                    i + 1,
513                    other,
514                    score
515                ));
516            }
517
518            report.push_str("\n### Least Similar Entities:\n");
519            for (i, (other, score)) in sim_analysis.dissimilar_entities.iter().enumerate() {
520                report.push_str(&format!(
521                    "{}. {} (similarity: {:.4})\n",
522                    i + 1,
523                    other,
524                    score
525                ));
526            }
527            report.push('\n');
528        }
529
530        // Feature importance
531        if let Ok(feat_importance) = self.feature_importance(entity, embeddings) {
532            report.push_str("## Feature Importance\n\n");
533            report.push_str("### Top Important Features:\n");
534            for (i, (feature_idx, importance)) in
535                feat_importance.important_features.iter().enumerate()
536            {
537                report.push_str(&format!(
538                    "{}. Dimension {} (importance: {:.4})\n",
539                    i + 1,
540                    feature_idx,
541                    importance
542                ));
543            }
544            report.push('\n');
545        }
546
547        // Nearest neighbors
548        if let Ok(nn_analysis) = self.nearest_neighbors_analysis(entity, embeddings) {
549            report.push_str("## Nearest Neighbors\n\n");
550            for (i, (neighbor, distance)) in nn_analysis.neighbors.iter().enumerate() {
551                report.push_str(&format!(
552                    "{}. {} (distance: {:.4})\n",
553                    i + 1,
554                    neighbor,
555                    distance
556                ));
557            }
558            report.push('\n');
559        }
560
561        Ok(report)
562    }
563}
564
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray_ext::array;

    #[test]
    fn test_similarity_analysis() {
        let mut emb = HashMap::new();
        emb.insert("e1".to_string(), array![1.0, 0.0, 0.0]);
        emb.insert("e2".to_string(), array![0.9, 0.1, 0.0]);
        emb.insert("e3".to_string(), array![0.0, 1.0, 0.0]);

        let analyzer = InterpretabilityAnalyzer::new(InterpretabilityConfig {
            method: InterpretationMethod::SimilarityAnalysis,
            top_k: 2,
            ..Default::default()
        });

        let result = analyzer.similarity_analysis("e1", &emb).unwrap();

        assert_eq!(result.entity, "e1");
        assert_eq!(result.similar_entities.len(), 2);
        // e2 is nearly collinear with e1, so it must rank first.
        assert_eq!(result.similar_entities[0].0, "e2");
    }

    #[test]
    fn test_feature_importance() {
        let mut emb = HashMap::new();
        emb.insert("e1".to_string(), array![1.0, 0.0, 0.0]);
        emb.insert("e2".to_string(), array![0.0, 1.0, 0.0]);
        emb.insert("e3".to_string(), array![0.0, 0.0, 1.0]);
        // e4 is an outlier in dimension 0.
        emb.insert("e4".to_string(), array![5.0, 0.0, 0.0]);

        let analyzer = InterpretabilityAnalyzer::new(InterpretabilityConfig {
            method: InterpretationMethod::FeatureImportance,
            top_k: 3,
            ..Default::default()
        });

        let result = analyzer.feature_importance("e4", &emb).unwrap();

        assert_eq!(result.entity, "e4");
        assert!(!result.important_features.is_empty());
        // The outlier dimension must rank as most important.
        assert_eq!(result.important_features[0].0, 0);
    }

    #[test]
    fn test_counterfactual() {
        let mut emb = HashMap::new();
        emb.insert("e1".to_string(), array![1.0, 0.0, 0.0]);
        emb.insert("e2".to_string(), array![0.0, 1.0, 0.0]);

        let analyzer = InterpretabilityAnalyzer::new(InterpretabilityConfig::default());
        let cf = analyzer
            .counterfactual_explanation("e1", "e2", &emb)
            .unwrap();

        assert_eq!(cf.original, "e1");
        assert_eq!(cf.target, "e2");
        assert!(!cf.required_changes.is_empty());
        assert!(cf.difficulty > 0.0);
    }

    #[test]
    fn test_nearest_neighbors() {
        let mut emb = HashMap::new();
        emb.insert("e1".to_string(), array![1.0, 0.0]);
        emb.insert("e2".to_string(), array![1.1, 0.1]);
        emb.insert("e3".to_string(), array![5.0, 5.0]);

        let analyzer = InterpretabilityAnalyzer::new(InterpretabilityConfig {
            method: InterpretationMethod::NearestNeighbors,
            top_k: 2,
            ..Default::default()
        });

        let result = analyzer.nearest_neighbors_analysis("e1", &emb).unwrap();

        assert_eq!(result.entity, "e1");
        assert_eq!(result.neighbors.len(), 2);
        // e2 is closest to e1 in Euclidean distance.
        assert_eq!(result.neighbors[0].0, "e2");
    }

    #[test]
    fn test_generate_report() {
        let mut emb = HashMap::new();
        emb.insert("e1".to_string(), array![1.0, 0.0, 0.0]);
        emb.insert("e2".to_string(), array![0.9, 0.1, 0.0]);

        let analyzer = InterpretabilityAnalyzer::new(InterpretabilityConfig::default());
        let report = analyzer.generate_report("e1", &emb).unwrap();

        // All three sections should appear for a known entity.
        assert!(report.contains("Interpretability Report"));
        assert!(report.contains("Similarity Analysis"));
        assert!(report.contains("Feature Importance"));
        assert!(report.contains("Nearest Neighbors"));
    }
}