1use anyhow::{anyhow, Result};
8use rayon::prelude::*;
9use scirs2_core::ndarray_ext::Array1;
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use tracing::info;
13
/// Strategy used to explain what an entity's embedding encodes.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum InterpretationMethod {
    /// Rank all other entities by cosine similarity to the target entity.
    SimilarityAnalysis,
    /// Score embedding dimensions by how far the entity deviates from corpus statistics.
    FeatureImportance,
    /// Describe the per-dimension changes needed to turn one embedding into another
    /// (requires a target entity, so it is not usable via `analyze_entity`).
    Counterfactual,
    /// List the closest entities by Euclidean distance, optionally clustered.
    NearestNeighbors,
}
26
/// Configuration for [`InterpretabilityAnalyzer`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InterpretabilityConfig {
    /// Interpretation method that `analyze_entity` dispatches to.
    pub method: InterpretationMethod,
    /// How many top results (similar entities, features, changes, neighbors) to keep.
    pub top_k: usize,
    // NOTE(review): not read by any method in this file — confirm intended use
    // before relying on it.
    pub similarity_threshold: f32,
    /// When true, nearest-neighbor analysis additionally clusters the neighbors.
    pub detailed: bool,
}
39
40impl Default for InterpretabilityConfig {
41 fn default() -> Self {
42 Self {
43 method: InterpretationMethod::SimilarityAnalysis,
44 top_k: 10,
45 similarity_threshold: 0.7,
46 detailed: false,
47 }
48 }
49}
50
/// Result of a cosine-similarity analysis for one entity.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SimilarityAnalysis {
    /// Entity that was analyzed.
    pub entity: String,
    /// Top-k most similar entities as (name, cosine similarity), best first.
    pub similar_entities: Vec<(String, f32)>,
    /// Bottom-k entities as (name, cosine similarity), ordered from higher to
    /// lower similarity (same orientation as `similar_entities`).
    pub dissimilar_entities: Vec<(String, f32)>,
    /// Mean cosine similarity of the entity against all other entities.
    pub avg_similarity: f32,
}
63
/// Result of a feature-importance analysis for one entity.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FeatureImportance {
    /// Entity that was analyzed.
    pub entity: String,
    /// Top-k dimensions as (dimension index, |z-score| importance), highest first.
    pub important_features: Vec<(usize, f32)>,
    /// Per-dimension statistics over the whole embedding corpus.
    pub feature_stats: FeatureStats,
}
74
/// Per-dimension statistics computed across all embeddings; each `Vec` has one
/// entry per embedding dimension.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FeatureStats {
    /// Per-dimension mean.
    pub mean: Vec<f32>,
    /// Per-dimension (population) standard deviation.
    pub std: Vec<f32>,
    /// Per-dimension minimum.
    pub min: Vec<f32>,
    /// Per-dimension maximum.
    pub max: Vec<f32>,
}
87
/// Explanation of how one entity's embedding would need to change to match another's.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CounterfactualExplanation {
    /// Entity whose embedding is being modified.
    pub original: String,
    /// Entity whose embedding is the target.
    pub target: String,
    /// Largest required changes as (dimension index, original value, target value).
    pub required_changes: Vec<(usize, f32, f32)>,
    /// Total change magnitude normalized by the original embedding's norm,
    /// clamped to [0, 1]; higher means harder.
    pub difficulty: f32,
}
100
/// Result of a nearest-neighbor analysis for one entity.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NearestNeighborsAnalysis {
    /// Entity that was analyzed.
    pub entity: String,
    /// Top-k neighbors as (name, Euclidean distance), closest first.
    pub neighbors: Vec<(String, f32)>,
    /// Greedy distance-based clusters of the neighbors; empty unless
    /// `InterpretabilityConfig::detailed` is set.
    pub neighbor_clusters: Vec<Vec<String>>,
}
111
/// Produces human-readable explanations (similarity, feature importance,
/// counterfactuals, nearest neighbors) for entity embeddings.
pub struct InterpretabilityAnalyzer {
    // Settings controlling method choice and output sizes.
    config: InterpretabilityConfig,
}
116
117impl InterpretabilityAnalyzer {
118 pub fn new(config: InterpretabilityConfig) -> Self {
120 info!(
121 "Initialized interpretability analyzer: method={:?}, top_k={}",
122 config.method, config.top_k
123 );
124
125 Self { config }
126 }
127
128 pub fn analyze_entity(
130 &self,
131 entity: &str,
132 embeddings: &HashMap<String, Array1<f32>>,
133 ) -> Result<String> {
134 if !embeddings.contains_key(entity) {
135 return Err(anyhow!("Entity not found: {}", entity));
136 }
137
138 match self.config.method {
139 InterpretationMethod::SimilarityAnalysis => {
140 let analysis = self.similarity_analysis(entity, embeddings)?;
141 Ok(serde_json::to_string_pretty(&analysis)?)
142 }
143 InterpretationMethod::FeatureImportance => {
144 let importance = self.feature_importance(entity, embeddings)?;
145 Ok(serde_json::to_string_pretty(&importance)?)
146 }
147 InterpretationMethod::NearestNeighbors => {
148 let neighbors = self.nearest_neighbors_analysis(entity, embeddings)?;
149 Ok(serde_json::to_string_pretty(&neighbors)?)
150 }
151 InterpretationMethod::Counterfactual => {
152 Err(anyhow!("Counterfactual requires target entity"))
153 }
154 }
155 }
156
157 pub fn similarity_analysis(
159 &self,
160 entity: &str,
161 embeddings: &HashMap<String, Array1<f32>>,
162 ) -> Result<SimilarityAnalysis> {
163 let entity_emb = &embeddings[entity];
164
165 let mut similarities: Vec<(String, f32)> = embeddings
167 .par_iter()
168 .filter(|(e, _)| *e != entity)
169 .map(|(other, other_emb)| {
170 let sim = self.cosine_similarity(entity_emb, other_emb);
171 (other.clone(), sim)
172 })
173 .collect();
174
175 similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
177
178 let similar_entities: Vec<(String, f32)> = similarities
180 .iter()
181 .take(self.config.top_k)
182 .cloned()
183 .collect();
184
185 let mut dissimilar_entities: Vec<(String, f32)> = similarities
187 .iter()
188 .rev()
189 .take(self.config.top_k)
190 .cloned()
191 .collect();
192 dissimilar_entities.reverse();
193
194 let avg_similarity =
196 similarities.iter().map(|(_, sim)| sim).sum::<f32>() / similarities.len() as f32;
197
198 info!(
199 "Similarity analysis for '{}': avg_similarity={:.4}",
200 entity, avg_similarity
201 );
202
203 Ok(SimilarityAnalysis {
204 entity: entity.to_string(),
205 similar_entities,
206 dissimilar_entities,
207 avg_similarity,
208 })
209 }
210
211 pub fn feature_importance(
213 &self,
214 entity: &str,
215 embeddings: &HashMap<String, Array1<f32>>,
216 ) -> Result<FeatureImportance> {
217 let entity_emb = &embeddings[entity];
218 let dim = entity_emb.len();
219
220 let feature_stats = self.compute_feature_stats(embeddings);
222
223 let mut important_features: Vec<(usize, f32)> = (0..dim)
225 .map(|i| {
226 let value = entity_emb[i];
227 let mean = feature_stats.mean[i];
228 let std = feature_stats.std[i];
229
230 let importance = if std > 0.0 {
232 ((value - mean) / std).abs()
233 } else {
234 0.0
235 };
236
237 (i, importance)
238 })
239 .collect();
240
241 important_features
243 .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
244
245 important_features.truncate(self.config.top_k);
247
248 info!(
249 "Feature importance for '{}': top feature has importance {:.4}",
250 entity,
251 important_features
252 .first()
253 .map(|(_, imp)| *imp)
254 .unwrap_or(0.0)
255 );
256
257 Ok(FeatureImportance {
258 entity: entity.to_string(),
259 important_features,
260 feature_stats,
261 })
262 }
263
264 pub fn counterfactual_explanation(
266 &self,
267 original: &str,
268 target: &str,
269 embeddings: &HashMap<String, Array1<f32>>,
270 ) -> Result<CounterfactualExplanation> {
271 let original_emb = embeddings
272 .get(original)
273 .ok_or_else(|| anyhow!("Original entity not found"))?;
274
275 let target_emb = embeddings
276 .get(target)
277 .ok_or_else(|| anyhow!("Target entity not found"))?;
278
279 let mut required_changes = Vec::new();
281 let mut total_change = 0.0;
282
283 for i in 0..original_emb.len() {
284 let diff = (target_emb[i] - original_emb[i]).abs();
285 if diff > 0.1 {
286 required_changes.push((i, original_emb[i], target_emb[i]));
288 total_change += diff;
289 }
290 }
291
292 required_changes.sort_by(|a, b| {
294 let diff_a = (a.2 - a.1).abs();
295 let diff_b = (b.2 - b.1).abs();
296 diff_b
297 .partial_cmp(&diff_a)
298 .unwrap_or(std::cmp::Ordering::Equal)
299 });
300
301 required_changes.truncate(self.config.top_k);
303
304 let norm = original_emb.dot(original_emb).sqrt();
306 let difficulty = if norm > 0.0 {
307 (total_change / norm).min(1.0)
308 } else {
309 1.0
310 };
311
312 info!(
313 "Counterfactual '{}' -> '{}': {} changes, difficulty={:.4}",
314 original,
315 target,
316 required_changes.len(),
317 difficulty
318 );
319
320 Ok(CounterfactualExplanation {
321 original: original.to_string(),
322 target: target.to_string(),
323 required_changes,
324 difficulty,
325 })
326 }
327
328 pub fn nearest_neighbors_analysis(
330 &self,
331 entity: &str,
332 embeddings: &HashMap<String, Array1<f32>>,
333 ) -> Result<NearestNeighborsAnalysis> {
334 let entity_emb = &embeddings[entity];
335
336 let mut distances: Vec<(String, f32)> = embeddings
338 .par_iter()
339 .filter(|(e, _)| *e != entity)
340 .map(|(other, other_emb)| {
341 let dist = self.euclidean_distance(entity_emb, other_emb);
342 (other.clone(), dist)
343 })
344 .collect();
345
346 distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
348
349 let neighbors: Vec<(String, f32)> =
351 distances.iter().take(self.config.top_k).cloned().collect();
352
353 let neighbor_clusters = if self.config.detailed {
355 self.cluster_neighbors(&neighbors, embeddings)
356 } else {
357 vec![]
358 };
359
360 info!(
361 "Nearest neighbors for '{}': closest neighbor at distance {:.4}",
362 entity,
363 neighbors.first().map(|(_, d)| *d).unwrap_or(0.0)
364 );
365
366 Ok(NearestNeighborsAnalysis {
367 entity: entity.to_string(),
368 neighbors,
369 neighbor_clusters,
370 })
371 }
372
373 pub fn batch_analysis(
375 &self,
376 entities: &[String],
377 embeddings: &HashMap<String, Array1<f32>>,
378 ) -> Result<HashMap<String, String>> {
379 let results: HashMap<String, String> = entities
380 .par_iter()
381 .filter_map(|entity| {
382 self.analyze_entity(entity, embeddings)
383 .ok()
384 .map(|analysis| (entity.clone(), analysis))
385 })
386 .collect();
387
388 Ok(results)
389 }
390
391 fn compute_feature_stats(&self, embeddings: &HashMap<String, Array1<f32>>) -> FeatureStats {
393 let n = embeddings.len() as f32;
394 let dim = embeddings
395 .values()
396 .next()
397 .expect("embeddings should not be empty")
398 .len();
399
400 let mut mean = vec![0.0; dim];
401 let mut m2 = vec![0.0; dim]; let mut min = vec![f32::INFINITY; dim];
403 let mut max = vec![f32::NEG_INFINITY; dim];
404
405 for (count, emb) in embeddings.values().enumerate() {
407 let count_f = (count + 1) as f32;
408
409 for i in 0..dim {
410 let value = emb[i];
411
412 min[i] = min[i].min(value);
414 max[i] = max[i].max(value);
415
416 let delta = value - mean[i];
418 mean[i] += delta / count_f;
419 let delta2 = value - mean[i];
420 m2[i] += delta * delta2;
421 }
422 }
423
424 let std: Vec<f32> = m2.iter().map(|&m2_val| (m2_val / n).sqrt()).collect();
426
427 FeatureStats {
428 mean,
429 std,
430 min,
431 max,
432 }
433 }
434
435 fn cluster_neighbors(
437 &self,
438 neighbors: &[(String, f32)],
439 embeddings: &HashMap<String, Array1<f32>>,
440 ) -> Vec<Vec<String>> {
441 if neighbors.len() < 2 {
442 return vec![neighbors.iter().map(|(e, _)| e.clone()).collect()];
443 }
444
445 let mut clusters: Vec<Vec<String>> = Vec::new();
447 let distance_threshold = 0.5; for (entity, _) in neighbors {
450 let entity_emb = &embeddings[entity];
451 let mut assigned = false;
452
453 for cluster in &mut clusters {
455 let cluster_center = cluster
456 .first()
457 .expect("collection validated to be non-empty");
458 let center_emb = &embeddings[cluster_center];
459 let dist = self.euclidean_distance(entity_emb, center_emb);
460
461 if dist <= distance_threshold {
462 cluster.push(entity.clone());
463 assigned = true;
464 break;
465 }
466 }
467
468 if !assigned {
470 clusters.push(vec![entity.clone()]);
471 }
472 }
473
474 clusters
475 }
476
477 fn cosine_similarity(&self, a: &Array1<f32>, b: &Array1<f32>) -> f32 {
479 let dot = a.dot(b);
480 let norm_a = a.dot(a).sqrt();
481 let norm_b = b.dot(b).sqrt();
482
483 if norm_a == 0.0 || norm_b == 0.0 {
484 0.0
485 } else {
486 dot / (norm_a * norm_b)
487 }
488 }
489
490 fn euclidean_distance(&self, a: &Array1<f32>, b: &Array1<f32>) -> f32 {
492 let diff = a - b;
493 diff.dot(&diff).sqrt()
494 }
495
496 pub fn generate_report(
498 &self,
499 entity: &str,
500 embeddings: &HashMap<String, Array1<f32>>,
501 ) -> Result<String> {
502 let mut report = String::new();
503
504 report.push_str(&format!("# Interpretability Report for '{}'\n\n", entity));
505
506 if let Ok(sim_analysis) = self.similarity_analysis(entity, embeddings) {
508 report.push_str("## Similarity Analysis\n\n");
509 report.push_str(&format!(
510 "Average similarity: {:.4}\n\n",
511 sim_analysis.avg_similarity
512 ));
513
514 report.push_str("### Most Similar Entities:\n");
515 for (i, (other, score)) in sim_analysis.similar_entities.iter().enumerate() {
516 report.push_str(&format!(
517 "{}. {} (similarity: {:.4})\n",
518 i + 1,
519 other,
520 score
521 ));
522 }
523
524 report.push_str("\n### Least Similar Entities:\n");
525 for (i, (other, score)) in sim_analysis.dissimilar_entities.iter().enumerate() {
526 report.push_str(&format!(
527 "{}. {} (similarity: {:.4})\n",
528 i + 1,
529 other,
530 score
531 ));
532 }
533 report.push('\n');
534 }
535
536 if let Ok(feat_importance) = self.feature_importance(entity, embeddings) {
538 report.push_str("## Feature Importance\n\n");
539 report.push_str("### Top Important Features:\n");
540 for (i, (feature_idx, importance)) in
541 feat_importance.important_features.iter().enumerate()
542 {
543 report.push_str(&format!(
544 "{}. Dimension {} (importance: {:.4})\n",
545 i + 1,
546 feature_idx,
547 importance
548 ));
549 }
550 report.push('\n');
551 }
552
553 if let Ok(nn_analysis) = self.nearest_neighbors_analysis(entity, embeddings) {
555 report.push_str("## Nearest Neighbors\n\n");
556 for (i, (neighbor, distance)) in nn_analysis.neighbors.iter().enumerate() {
557 report.push_str(&format!(
558 "{}. {} (distance: {:.4})\n",
559 i + 1,
560 neighbor,
561 distance
562 ));
563 }
564 report.push('\n');
565 }
566
567 Ok(report)
568 }
569}
570
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray_ext::array;

    /// Collects (name, vector) pairs into an owned embedding map.
    fn embeddings_from(pairs: Vec<(&str, Array1<f32>)>) -> HashMap<String, Array1<f32>> {
        pairs
            .into_iter()
            .map(|(name, emb)| (name.to_string(), emb))
            .collect()
    }

    #[test]
    fn test_similarity_analysis() {
        let embeddings = embeddings_from(vec![
            ("e1", array![1.0, 0.0, 0.0]),
            ("e2", array![0.9, 0.1, 0.0]),
            ("e3", array![0.0, 1.0, 0.0]),
        ]);

        let analyzer = InterpretabilityAnalyzer::new(InterpretabilityConfig {
            method: InterpretationMethod::SimilarityAnalysis,
            top_k: 2,
            ..Default::default()
        });

        let analysis = analyzer.similarity_analysis("e1", &embeddings).unwrap();

        assert_eq!(analysis.entity, "e1");
        assert_eq!(analysis.similar_entities.len(), 2);
        // e2 points almost the same way as e1, so it must rank first.
        assert_eq!(analysis.similar_entities[0].0, "e2");
    }

    #[test]
    fn test_feature_importance() {
        let embeddings = embeddings_from(vec![
            ("e1", array![1.0, 0.0, 0.0]),
            ("e2", array![0.0, 1.0, 0.0]),
            ("e3", array![0.0, 0.0, 1.0]),
            ("e4", array![5.0, 0.0, 0.0]),
        ]);

        let analyzer = InterpretabilityAnalyzer::new(InterpretabilityConfig {
            method: InterpretationMethod::FeatureImportance,
            top_k: 3,
            ..Default::default()
        });

        let importance = analyzer.feature_importance("e4", &embeddings).unwrap();

        assert_eq!(importance.entity, "e4");
        assert!(!importance.important_features.is_empty());
        // Dimension 0 is where e4 deviates most from the corpus.
        assert_eq!(importance.important_features[0].0, 0);
    }

    #[test]
    fn test_counterfactual() {
        let embeddings = embeddings_from(vec![
            ("e1", array![1.0, 0.0, 0.0]),
            ("e2", array![0.0, 1.0, 0.0]),
        ]);

        let analyzer = InterpretabilityAnalyzer::new(InterpretabilityConfig::default());

        let cf = analyzer
            .counterfactual_explanation("e1", "e2", &embeddings)
            .unwrap();

        assert_eq!(cf.original, "e1");
        assert_eq!(cf.target, "e2");
        // Orthogonal unit vectors differ in two dimensions by a full unit each.
        assert!(!cf.required_changes.is_empty());
        assert!(cf.difficulty > 0.0);
    }

    #[test]
    fn test_nearest_neighbors() {
        let embeddings = embeddings_from(vec![
            ("e1", array![1.0, 0.0]),
            ("e2", array![1.1, 0.1]),
            ("e3", array![5.0, 5.0]),
        ]);

        let analyzer = InterpretabilityAnalyzer::new(InterpretabilityConfig {
            method: InterpretationMethod::NearestNeighbors,
            top_k: 2,
            ..Default::default()
        });

        let nn = analyzer
            .nearest_neighbors_analysis("e1", &embeddings)
            .unwrap();

        assert_eq!(nn.entity, "e1");
        assert_eq!(nn.neighbors.len(), 2);
        // e2 is much closer to e1 than e3 is.
        assert_eq!(nn.neighbors[0].0, "e2");
    }

    #[test]
    fn test_generate_report() {
        let embeddings = embeddings_from(vec![
            ("e1", array![1.0, 0.0, 0.0]),
            ("e2", array![0.9, 0.1, 0.0]),
        ]);

        let analyzer = InterpretabilityAnalyzer::new(InterpretabilityConfig::default());

        let report = analyzer.generate_report("e1", &embeddings).unwrap();

        // Every section should be present for a well-formed input.
        assert!(report.contains("Interpretability Report"));
        assert!(report.contains("Similarity Analysis"));
        assert!(report.contains("Feature Importance"));
        assert!(report.contains("Nearest Neighbors"));
    }
}