1use anyhow::{anyhow, Result};
8use rayon::prelude::*;
9use scirs2_core::ndarray_ext::Array1;
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use tracing::info;
13
/// Selects which analysis `InterpretabilityAnalyzer::analyze_entity` runs.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum InterpretationMethod {
    /// Rank the other entities by cosine similarity to the analyzed entity.
    SimilarityAnalysis,
    /// Score each embedding dimension by how far the entity deviates from the
    /// per-dimension mean, in units of standard deviation.
    FeatureImportance,
    /// Compare an entity against a target entity. This needs a second entity, so
    /// `analyze_entity` rejects it; call `counterfactual_explanation` directly.
    Counterfactual,
    /// List the closest entities by Euclidean distance.
    NearestNeighbors,
}
26
/// Settings that drive `InterpretabilityAnalyzer`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InterpretabilityConfig {
    /// Analysis dispatched by `analyze_entity`.
    pub method: InterpretationMethod,
    /// Maximum number of entries kept in each ranked result list.
    pub top_k: usize,
    /// NOTE(review): no code visible in this file reads this field — confirm the
    /// intended use (presumably a cosine-similarity cut-off).
    pub similarity_threshold: f32,
    /// When `true`, nearest-neighbor analysis also groups the neighbors into clusters.
    pub detailed: bool,
}
39
40impl Default for InterpretabilityConfig {
41 fn default() -> Self {
42 Self {
43 method: InterpretationMethod::SimilarityAnalysis,
44 top_k: 10,
45 similarity_threshold: 0.7,
46 detailed: false,
47 }
48 }
49}
50
/// Result of `InterpretabilityAnalyzer::similarity_analysis`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SimilarityAnalysis {
    /// Key of the analyzed entity.
    pub entity: String,
    /// Up to `top_k` `(entity, cosine similarity)` pairs, most similar first.
    pub similar_entities: Vec<(String, f32)>,
    /// The up-to-`top_k` lowest-similarity entities, listed in descending
    /// similarity (the least similar entity is last).
    pub dissimilar_entities: Vec<(String, f32)>,
    /// Mean cosine similarity between the entity and every other entity.
    pub avg_similarity: f32,
}
63
/// Result of `InterpretabilityAnalyzer::feature_importance`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FeatureImportance {
    /// Key of the analyzed entity.
    pub entity: String,
    /// Up to `top_k` `(dimension index, importance)` pairs, highest importance
    /// first. Importance is `|value - mean| / std` for that dimension, or `0.0`
    /// when the dimension's standard deviation is not positive.
    pub important_features: Vec<(usize, f32)>,
    /// Per-dimension statistics computed over all embeddings in the input map.
    pub feature_stats: FeatureStats,
}
74
/// Per-dimension statistics over a set of embeddings; index `i` of every vector
/// describes embedding dimension `i`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FeatureStats {
    /// Mean value of each dimension.
    pub mean: Vec<f32>,
    /// Population standard deviation of each dimension (sum of squared deviations
    /// divided by the number of embeddings).
    pub std: Vec<f32>,
    /// Smallest value seen in each dimension.
    pub min: Vec<f32>,
    /// Largest value seen in each dimension.
    pub max: Vec<f32>,
}
87
/// Result of `InterpretabilityAnalyzer::counterfactual_explanation`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CounterfactualExplanation {
    /// Key of the starting entity.
    pub original: String,
    /// Key of the entity the original is compared against.
    pub target: String,
    /// Up to `top_k` `(dimension index, original value, target value)` triples for
    /// dimensions whose absolute difference exceeds 0.1, largest difference first.
    pub required_changes: Vec<(usize, f32, f32)>,
    /// Sum of all above-threshold absolute differences divided by the L2 norm of
    /// the original embedding, capped at `1.0` (`1.0` when that norm is zero).
    pub difficulty: f32,
}
100
/// Result of `InterpretabilityAnalyzer::nearest_neighbors_analysis`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NearestNeighborsAnalysis {
    /// Key of the analyzed entity.
    pub entity: String,
    /// Up to `top_k` `(entity, Euclidean distance)` pairs, closest first.
    pub neighbors: Vec<(String, f32)>,
    /// Groups of neighbor keys; empty unless the config's `detailed` flag is set.
    pub neighbor_clusters: Vec<Vec<String>>,
}
111
112pub struct InterpretabilityAnalyzer {
114 config: InterpretabilityConfig,
115}
116
117impl InterpretabilityAnalyzer {
118 pub fn new(config: InterpretabilityConfig) -> Self {
120 info!(
121 "Initialized interpretability analyzer: method={:?}, top_k={}",
122 config.method, config.top_k
123 );
124
125 Self { config }
126 }
127
128 pub fn analyze_entity(
130 &self,
131 entity: &str,
132 embeddings: &HashMap<String, Array1<f32>>,
133 ) -> Result<String> {
134 if !embeddings.contains_key(entity) {
135 return Err(anyhow!("Entity not found: {}", entity));
136 }
137
138 match self.config.method {
139 InterpretationMethod::SimilarityAnalysis => {
140 let analysis = self.similarity_analysis(entity, embeddings)?;
141 Ok(serde_json::to_string_pretty(&analysis)?)
142 }
143 InterpretationMethod::FeatureImportance => {
144 let importance = self.feature_importance(entity, embeddings)?;
145 Ok(serde_json::to_string_pretty(&importance)?)
146 }
147 InterpretationMethod::NearestNeighbors => {
148 let neighbors = self.nearest_neighbors_analysis(entity, embeddings)?;
149 Ok(serde_json::to_string_pretty(&neighbors)?)
150 }
151 InterpretationMethod::Counterfactual => {
152 Err(anyhow!("Counterfactual requires target entity"))
153 }
154 }
155 }
156
157 pub fn similarity_analysis(
159 &self,
160 entity: &str,
161 embeddings: &HashMap<String, Array1<f32>>,
162 ) -> Result<SimilarityAnalysis> {
163 let entity_emb = &embeddings[entity];
164
165 let mut similarities: Vec<(String, f32)> = embeddings
167 .par_iter()
168 .filter(|(e, _)| *e != entity)
169 .map(|(other, other_emb)| {
170 let sim = self.cosine_similarity(entity_emb, other_emb);
171 (other.clone(), sim)
172 })
173 .collect();
174
175 similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
177
178 let similar_entities: Vec<(String, f32)> = similarities
180 .iter()
181 .take(self.config.top_k)
182 .cloned()
183 .collect();
184
185 let mut dissimilar_entities: Vec<(String, f32)> = similarities
187 .iter()
188 .rev()
189 .take(self.config.top_k)
190 .cloned()
191 .collect();
192 dissimilar_entities.reverse();
193
194 let avg_similarity =
196 similarities.iter().map(|(_, sim)| sim).sum::<f32>() / similarities.len() as f32;
197
198 info!(
199 "Similarity analysis for '{}': avg_similarity={:.4}",
200 entity, avg_similarity
201 );
202
203 Ok(SimilarityAnalysis {
204 entity: entity.to_string(),
205 similar_entities,
206 dissimilar_entities,
207 avg_similarity,
208 })
209 }
210
211 pub fn feature_importance(
213 &self,
214 entity: &str,
215 embeddings: &HashMap<String, Array1<f32>>,
216 ) -> Result<FeatureImportance> {
217 let entity_emb = &embeddings[entity];
218 let dim = entity_emb.len();
219
220 let feature_stats = self.compute_feature_stats(embeddings);
222
223 let mut important_features: Vec<(usize, f32)> = (0..dim)
225 .map(|i| {
226 let value = entity_emb[i];
227 let mean = feature_stats.mean[i];
228 let std = feature_stats.std[i];
229
230 let importance = if std > 0.0 {
232 ((value - mean) / std).abs()
233 } else {
234 0.0
235 };
236
237 (i, importance)
238 })
239 .collect();
240
241 important_features
243 .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
244
245 important_features.truncate(self.config.top_k);
247
248 info!(
249 "Feature importance for '{}': top feature has importance {:.4}",
250 entity,
251 important_features
252 .first()
253 .map(|(_, imp)| *imp)
254 .unwrap_or(0.0)
255 );
256
257 Ok(FeatureImportance {
258 entity: entity.to_string(),
259 important_features,
260 feature_stats,
261 })
262 }
263
264 pub fn counterfactual_explanation(
266 &self,
267 original: &str,
268 target: &str,
269 embeddings: &HashMap<String, Array1<f32>>,
270 ) -> Result<CounterfactualExplanation> {
271 let original_emb = embeddings
272 .get(original)
273 .ok_or_else(|| anyhow!("Original entity not found"))?;
274
275 let target_emb = embeddings
276 .get(target)
277 .ok_or_else(|| anyhow!("Target entity not found"))?;
278
279 let mut required_changes = Vec::new();
281 let mut total_change = 0.0;
282
283 for i in 0..original_emb.len() {
284 let diff = (target_emb[i] - original_emb[i]).abs();
285 if diff > 0.1 {
286 required_changes.push((i, original_emb[i], target_emb[i]));
288 total_change += diff;
289 }
290 }
291
292 required_changes.sort_by(|a, b| {
294 let diff_a = (a.2 - a.1).abs();
295 let diff_b = (b.2 - b.1).abs();
296 diff_b
297 .partial_cmp(&diff_a)
298 .unwrap_or(std::cmp::Ordering::Equal)
299 });
300
301 required_changes.truncate(self.config.top_k);
303
304 let norm = original_emb.dot(original_emb).sqrt();
306 let difficulty = if norm > 0.0 {
307 (total_change / norm).min(1.0)
308 } else {
309 1.0
310 };
311
312 info!(
313 "Counterfactual '{}' -> '{}': {} changes, difficulty={:.4}",
314 original,
315 target,
316 required_changes.len(),
317 difficulty
318 );
319
320 Ok(CounterfactualExplanation {
321 original: original.to_string(),
322 target: target.to_string(),
323 required_changes,
324 difficulty,
325 })
326 }
327
328 pub fn nearest_neighbors_analysis(
330 &self,
331 entity: &str,
332 embeddings: &HashMap<String, Array1<f32>>,
333 ) -> Result<NearestNeighborsAnalysis> {
334 let entity_emb = &embeddings[entity];
335
336 let mut distances: Vec<(String, f32)> = embeddings
338 .par_iter()
339 .filter(|(e, _)| *e != entity)
340 .map(|(other, other_emb)| {
341 let dist = self.euclidean_distance(entity_emb, other_emb);
342 (other.clone(), dist)
343 })
344 .collect();
345
346 distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
348
349 let neighbors: Vec<(String, f32)> =
351 distances.iter().take(self.config.top_k).cloned().collect();
352
353 let neighbor_clusters = if self.config.detailed {
355 self.cluster_neighbors(&neighbors, embeddings)
356 } else {
357 vec![]
358 };
359
360 info!(
361 "Nearest neighbors for '{}': closest neighbor at distance {:.4}",
362 entity,
363 neighbors.first().map(|(_, d)| *d).unwrap_or(0.0)
364 );
365
366 Ok(NearestNeighborsAnalysis {
367 entity: entity.to_string(),
368 neighbors,
369 neighbor_clusters,
370 })
371 }
372
373 pub fn batch_analysis(
375 &self,
376 entities: &[String],
377 embeddings: &HashMap<String, Array1<f32>>,
378 ) -> Result<HashMap<String, String>> {
379 let results: HashMap<String, String> = entities
380 .par_iter()
381 .filter_map(|entity| {
382 self.analyze_entity(entity, embeddings)
383 .ok()
384 .map(|analysis| (entity.clone(), analysis))
385 })
386 .collect();
387
388 Ok(results)
389 }
390
391 fn compute_feature_stats(&self, embeddings: &HashMap<String, Array1<f32>>) -> FeatureStats {
393 let n = embeddings.len() as f32;
394 let dim = embeddings.values().next().unwrap().len();
395
396 let mut mean = vec![0.0; dim];
397 let mut m2 = vec![0.0; dim]; let mut min = vec![f32::INFINITY; dim];
399 let mut max = vec![f32::NEG_INFINITY; dim];
400
401 for (count, emb) in embeddings.values().enumerate() {
403 let count_f = (count + 1) as f32;
404
405 for i in 0..dim {
406 let value = emb[i];
407
408 min[i] = min[i].min(value);
410 max[i] = max[i].max(value);
411
412 let delta = value - mean[i];
414 mean[i] += delta / count_f;
415 let delta2 = value - mean[i];
416 m2[i] += delta * delta2;
417 }
418 }
419
420 let std: Vec<f32> = m2.iter().map(|&m2_val| (m2_val / n).sqrt()).collect();
422
423 FeatureStats {
424 mean,
425 std,
426 min,
427 max,
428 }
429 }
430
431 fn cluster_neighbors(
433 &self,
434 neighbors: &[(String, f32)],
435 embeddings: &HashMap<String, Array1<f32>>,
436 ) -> Vec<Vec<String>> {
437 if neighbors.len() < 2 {
438 return vec![neighbors.iter().map(|(e, _)| e.clone()).collect()];
439 }
440
441 let mut clusters: Vec<Vec<String>> = Vec::new();
443 let distance_threshold = 0.5; for (entity, _) in neighbors {
446 let entity_emb = &embeddings[entity];
447 let mut assigned = false;
448
449 for cluster in &mut clusters {
451 let cluster_center = cluster.first().unwrap();
452 let center_emb = &embeddings[cluster_center];
453 let dist = self.euclidean_distance(entity_emb, center_emb);
454
455 if dist <= distance_threshold {
456 cluster.push(entity.clone());
457 assigned = true;
458 break;
459 }
460 }
461
462 if !assigned {
464 clusters.push(vec![entity.clone()]);
465 }
466 }
467
468 clusters
469 }
470
471 fn cosine_similarity(&self, a: &Array1<f32>, b: &Array1<f32>) -> f32 {
473 let dot = a.dot(b);
474 let norm_a = a.dot(a).sqrt();
475 let norm_b = b.dot(b).sqrt();
476
477 if norm_a == 0.0 || norm_b == 0.0 {
478 0.0
479 } else {
480 dot / (norm_a * norm_b)
481 }
482 }
483
484 fn euclidean_distance(&self, a: &Array1<f32>, b: &Array1<f32>) -> f32 {
486 let diff = a - b;
487 diff.dot(&diff).sqrt()
488 }
489
490 pub fn generate_report(
492 &self,
493 entity: &str,
494 embeddings: &HashMap<String, Array1<f32>>,
495 ) -> Result<String> {
496 let mut report = String::new();
497
498 report.push_str(&format!("# Interpretability Report for '{}'\n\n", entity));
499
500 if let Ok(sim_analysis) = self.similarity_analysis(entity, embeddings) {
502 report.push_str("## Similarity Analysis\n\n");
503 report.push_str(&format!(
504 "Average similarity: {:.4}\n\n",
505 sim_analysis.avg_similarity
506 ));
507
508 report.push_str("### Most Similar Entities:\n");
509 for (i, (other, score)) in sim_analysis.similar_entities.iter().enumerate() {
510 report.push_str(&format!(
511 "{}. {} (similarity: {:.4})\n",
512 i + 1,
513 other,
514 score
515 ));
516 }
517
518 report.push_str("\n### Least Similar Entities:\n");
519 for (i, (other, score)) in sim_analysis.dissimilar_entities.iter().enumerate() {
520 report.push_str(&format!(
521 "{}. {} (similarity: {:.4})\n",
522 i + 1,
523 other,
524 score
525 ));
526 }
527 report.push('\n');
528 }
529
530 if let Ok(feat_importance) = self.feature_importance(entity, embeddings) {
532 report.push_str("## Feature Importance\n\n");
533 report.push_str("### Top Important Features:\n");
534 for (i, (feature_idx, importance)) in
535 feat_importance.important_features.iter().enumerate()
536 {
537 report.push_str(&format!(
538 "{}. Dimension {} (importance: {:.4})\n",
539 i + 1,
540 feature_idx,
541 importance
542 ));
543 }
544 report.push('\n');
545 }
546
547 if let Ok(nn_analysis) = self.nearest_neighbors_analysis(entity, embeddings) {
549 report.push_str("## Nearest Neighbors\n\n");
550 for (i, (neighbor, distance)) in nn_analysis.neighbors.iter().enumerate() {
551 report.push_str(&format!(
552 "{}. {} (distance: {:.4})\n",
553 i + 1,
554 neighbor,
555 distance
556 ));
557 }
558 report.push('\n');
559 }
560
561 Ok(report)
562 }
563}
564
565#[cfg(test)]
566mod tests {
567 use super::*;
568 use scirs2_core::ndarray_ext::array;
569
/// `e2` points almost the same way as `e1`, so it must rank as the most similar.
#[test]
fn test_similarity_analysis() {
    let embeddings = HashMap::from([
        ("e1".to_string(), array![1.0, 0.0, 0.0]),
        ("e2".to_string(), array![0.9, 0.1, 0.0]),
        ("e3".to_string(), array![0.0, 1.0, 0.0]),
    ]);

    let analyzer = InterpretabilityAnalyzer::new(InterpretabilityConfig {
        method: InterpretationMethod::SimilarityAnalysis,
        top_k: 2,
        ..Default::default()
    });

    let result = analyzer.similarity_analysis("e1", &embeddings).unwrap();

    assert_eq!(result.entity, "e1");
    assert_eq!(result.similar_entities.len(), 2);
    assert_eq!(result.similar_entities[0].0, "e2");
}
591
/// `e4` is an outlier along dimension 0, so that dimension must rank first.
#[test]
fn test_feature_importance() {
    let embeddings = HashMap::from([
        ("e1".to_string(), array![1.0, 0.0, 0.0]),
        ("e2".to_string(), array![0.0, 1.0, 0.0]),
        ("e3".to_string(), array![0.0, 0.0, 1.0]),
        ("e4".to_string(), array![5.0, 0.0, 0.0]),
    ]);

    let analyzer = InterpretabilityAnalyzer::new(InterpretabilityConfig {
        method: InterpretationMethod::FeatureImportance,
        top_k: 3,
        ..Default::default()
    });

    let result = analyzer.feature_importance("e4", &embeddings).unwrap();

    assert_eq!(result.entity, "e4");
    assert!(!result.important_features.is_empty());
    assert_eq!(result.important_features[0].0, 0);
}
614
/// Two orthogonal unit vectors differ in two dimensions, so changes are reported
/// and the difficulty is positive.
#[test]
fn test_counterfactual() {
    let embeddings = HashMap::from([
        ("e1".to_string(), array![1.0, 0.0, 0.0]),
        ("e2".to_string(), array![0.0, 1.0, 0.0]),
    ]);

    let analyzer = InterpretabilityAnalyzer::new(InterpretabilityConfig::default());

    let explanation = analyzer
        .counterfactual_explanation("e1", "e2", &embeddings)
        .unwrap();

    assert_eq!(explanation.original, "e1");
    assert_eq!(explanation.target, "e2");
    assert!(!explanation.required_changes.is_empty());
    assert!(explanation.difficulty > 0.0);
}
633
/// `e2` sits right next to `e1` while `e3` is far away, so `e2` must come first.
#[test]
fn test_nearest_neighbors() {
    let embeddings = HashMap::from([
        ("e1".to_string(), array![1.0, 0.0]),
        ("e2".to_string(), array![1.1, 0.1]),
        ("e3".to_string(), array![5.0, 5.0]),
    ]);

    let analyzer = InterpretabilityAnalyzer::new(InterpretabilityConfig {
        method: InterpretationMethod::NearestNeighbors,
        top_k: 2,
        ..Default::default()
    });

    let result = analyzer
        .nearest_neighbors_analysis("e1", &embeddings)
        .unwrap();

    assert_eq!(result.entity, "e1");
    assert_eq!(result.neighbors.len(), 2);
    assert_eq!(result.neighbors[0].0, "e2");
}
657
/// The report must contain its title and all three section headings.
#[test]
fn test_generate_report() {
    let embeddings = HashMap::from([
        ("e1".to_string(), array![1.0, 0.0, 0.0]),
        ("e2".to_string(), array![0.9, 0.1, 0.0]),
    ]);

    let analyzer = InterpretabilityAnalyzer::new(InterpretabilityConfig::default());
    let report = analyzer.generate_report("e1", &embeddings).unwrap();

    for heading in [
        "Interpretability Report",
        "Similarity Analysis",
        "Feature Importance",
        "Nearest Neighbors",
    ] {
        assert!(report.contains(heading));
    }
}
674}