Skip to main content

oxirs_graphrag/
explainability.rs

1//! Explainability engine for graph-based RAG — attention weights, path explanation, provenance.
2//!
3//! Provides human-interpretable explanations of why a set of triples was
4//! retrieved for a given query.  Three explanation strategies:
5//!
6//! 1. **Attention weights** — normalized relevance scores per triple
7//! 2. **Path explanation** — BFS shortest-path from query entity to result entity
8//! 3. **Provenance** — source documents / IRIs that contributed to the answer
9
10use serde::{Deserialize, Serialize};
11use std::collections::{HashMap, HashSet, VecDeque};
12
13// ─────────────────────────────────────────────────────────────────────────────
14// Core types
15// ─────────────────────────────────────────────────────────────────────────────
16
17/// A knowledge graph triple with an associated relevance score.
18#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct ScoredTriple {
20    pub subject: String,
21    pub predicate: String,
22    pub object: String,
23    /// Relevance score in [0, 1].
24    pub score: f64,
25    /// Optional provenance source IRI.
26    pub source: Option<String>,
27}
28
29impl ScoredTriple {
30    pub fn new(
31        subject: impl Into<String>,
32        predicate: impl Into<String>,
33        object: impl Into<String>,
34        score: f64,
35    ) -> Self {
36        Self {
37            subject: subject.into(),
38            predicate: predicate.into(),
39            object: object.into(),
40            score: score.clamp(0.0, 1.0),
41            source: None,
42        }
43    }
44
45    pub fn with_source(mut self, source: impl Into<String>) -> Self {
46        self.source = Some(source.into());
47        self
48    }
49}
50
51// ─────────────────────────────────────────────────────────────────────────────
52// Attention weights
53// ─────────────────────────────────────────────────────────────────────────────
54
55/// Normalized attention weights over a set of triples.
56#[derive(Debug, Clone, Serialize, Deserialize)]
57pub struct AttentionExplanation {
58    /// Triples with softmax-normalized attention weights.
59    pub weighted_triples: Vec<ScoredTriple>,
60    /// The query that generated these weights.
61    pub query: String,
62    /// Entropy of the weight distribution (higher = more diffuse attention).
63    pub attention_entropy: f64,
64}
65
66impl AttentionExplanation {
67    /// Compute softmax-normalized attention weights for the given triples.
68    ///
69    /// The `raw_scores` slice must have the same length as `triples`.
70    pub fn compute(query: &str, triples: &[ScoredTriple], raw_scores: &[f64]) -> Self {
71        assert_eq!(triples.len(), raw_scores.len(), "lengths must match");
72        let weights = softmax(raw_scores);
73        let entropy = shannon_entropy(&weights);
74
75        let weighted_triples = triples
76            .iter()
77            .zip(weights.iter())
78            .map(|(t, &w)| {
79                let mut wt = t.clone();
80                wt.score = w;
81                wt
82            })
83            .collect();
84
85        Self {
86            weighted_triples,
87            query: query.to_string(),
88            attention_entropy: entropy,
89        }
90    }
91
92    /// Return the top-k triples by attention weight.
93    pub fn top_k(&self, k: usize) -> Vec<&ScoredTriple> {
94        let mut sorted: Vec<&ScoredTriple> = self.weighted_triples.iter().collect();
95        sorted.sort_by(|a, b| {
96            b.score
97                .partial_cmp(&a.score)
98                .unwrap_or(std::cmp::Ordering::Equal)
99        });
100        sorted.into_iter().take(k).collect()
101    }
102}
103
/// Numerically stable softmax: maps real-valued scores to a probability
/// distribution of the same length (non-negative weights summing to 1).
///
/// Edge cases:
/// - empty input → empty output;
/// - NaN scores are treated as `-inf`, i.e. they receive zero weight;
/// - if any score is `+inf`, all mass is split evenly among the `+inf`
///   entries (the limit of softmax as those scores grow without bound);
/// - if every score is `-inf` or NaN, the result is uniform.
///
/// Subtracting the maximum alone is not enough for infinite inputs
/// (`inf - inf` is NaN), so those cases are handled explicitly.
fn softmax(scores: &[f64]) -> Vec<f64> {
    let n = scores.len();
    if n == 0 {
        return Vec::new();
    }

    // NaN carries no ranking information; map it to -inf so it gets zero weight.
    let sanitize = |s: f64| if s.is_nan() { f64::NEG_INFINITY } else { s };

    let max = scores
        .iter()
        .map(|&s| sanitize(s))
        .fold(f64::NEG_INFINITY, f64::max);

    if max == f64::NEG_INFINITY {
        // Nothing has a usable score: fall back to uniform attention.
        return vec![1.0 / n as f64; n];
    }

    if max == f64::INFINITY {
        let winners = scores.iter().filter(|&&s| s == f64::INFINITY).count();
        let share = 1.0 / winners as f64;
        return scores
            .iter()
            .map(|&s| if s == f64::INFINITY { share } else { 0.0 })
            .collect();
    }

    // Shift by the (finite) max before exponentiating to avoid overflow.
    // The max entry contributes exp(0) = 1, so `sum >= 1` and the division
    // below is always well defined.
    let exps: Vec<f64> = scores.iter().map(|&s| (sanitize(s) - max).exp()).collect();
    let sum: f64 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}
116
/// Shannon entropy `H(p) = -Σ p·ln p`, measured in nats, of a probability vector.
///
/// Entries that are not strictly positive (zero, negative, NaN) are skipped,
/// matching the convention `0·ln 0 = 0`.
fn shannon_entropy(probs: &[f64]) -> f64 {
    probs
        .iter()
        .copied()
        .filter_map(|p| if p > 0.0 { Some(-p * p.ln()) } else { None })
        .sum()
}
124
125// ─────────────────────────────────────────────────────────────────────────────
126// Path explanation
127// ─────────────────────────────────────────────────────────────────────────────
128
129/// One hop in a graph path.
130#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
131pub struct PathHop {
132    pub from: String,
133    pub predicate: String,
134    pub to: String,
135}
136
137/// A BFS shortest-path explanation from a query entity to a result entity.
138#[derive(Debug, Clone, Serialize, Deserialize)]
139pub struct PathExplanation {
140    pub from: String,
141    pub to: String,
142    /// Ordered sequence of hops forming the shortest path.
143    pub hops: Vec<PathHop>,
144    /// Total path length (number of hops).
145    pub path_length: usize,
146}
147
148impl PathExplanation {
149    /// Find the BFS shortest path in a set of triples from `from` to `to`.
150    ///
151    /// Returns `None` if no path exists.
152    pub fn find(triples: &[ScoredTriple], from: &str, to: &str) -> Option<Self> {
153        if from == to {
154            return Some(Self {
155                from: from.to_string(),
156                to: to.to_string(),
157                hops: Vec::new(),
158                path_length: 0,
159            });
160        }
161
162        // Build adjacency list
163        let mut adj: HashMap<&str, Vec<(&str, &str)>> = HashMap::new();
164        for t in triples {
165            adj.entry(&t.subject)
166                .or_default()
167                .push((&t.predicate, &t.object));
168        }
169
170        // BFS
171        let mut queue: VecDeque<(&str, Vec<PathHop>)> = VecDeque::new();
172        let mut visited: HashSet<&str> = HashSet::new();
173        queue.push_back((from, Vec::new()));
174        visited.insert(from);
175
176        while let Some((node, path)) = queue.pop_front() {
177            if let Some(neighbors) = adj.get(node) {
178                for &(pred, obj) in neighbors {
179                    if visited.contains(obj) {
180                        continue;
181                    }
182                    let mut new_path = path.clone();
183                    new_path.push(PathHop {
184                        from: node.to_string(),
185                        predicate: pred.to_string(),
186                        to: obj.to_string(),
187                    });
188                    if obj == to {
189                        let length = new_path.len();
190                        return Some(Self {
191                            from: from.to_string(),
192                            to: to.to_string(),
193                            hops: new_path,
194                            path_length: length,
195                        });
196                    }
197                    visited.insert(obj);
198                    queue.push_back((obj, new_path));
199                }
200            }
201        }
202        None
203    }
204}
205
206// ─────────────────────────────────────────────────────────────────────────────
207// Provenance
208// ─────────────────────────────────────────────────────────────────────────────
209
210/// Source provenance for a set of triples.
211#[derive(Debug, Clone, Serialize, Deserialize)]
212pub struct ProvenanceReport {
213    /// Mapping from source IRI to the triples sourced from it.
214    pub sources: HashMap<String, Vec<ScoredTriple>>,
215    /// Number of triples with no provenance information.
216    pub unknown_count: usize,
217}
218
219impl ProvenanceReport {
220    /// Build a provenance report from a set of scored triples.
221    pub fn from_triples(triples: &[ScoredTriple]) -> Self {
222        let mut sources: HashMap<String, Vec<ScoredTriple>> = HashMap::new();
223        let mut unknown_count = 0;
224        for t in triples {
225            match &t.source {
226                Some(src) => sources.entry(src.clone()).or_default().push(t.clone()),
227                None => unknown_count += 1,
228            }
229        }
230        Self {
231            sources,
232            unknown_count,
233        }
234    }
235
236    /// Return all distinct source IRIs.
237    pub fn source_iris(&self) -> Vec<&str> {
238        self.sources.keys().map(|s| s.as_str()).collect()
239    }
240}
241
242// ─────────────────────────────────────────────────────────────────────────────
243// ExplainabilityEngine
244// ─────────────────────────────────────────────────────────────────────────────
245
246/// Top-level explainability engine — wraps all explanation strategies.
247pub struct ExplainabilityEngine;
248
249impl ExplainabilityEngine {
250    pub fn new() -> Self {
251        Self
252    }
253
254    /// Compute attention-based explanation.
255    pub fn explain_attention(
256        &self,
257        query: &str,
258        triples: &[ScoredTriple],
259        raw_scores: &[f64],
260    ) -> AttentionExplanation {
261        AttentionExplanation::compute(query, triples, raw_scores)
262    }
263
264    /// Find the shortest explanatory path from `from` to `to`.
265    pub fn explain_path(
266        &self,
267        triples: &[ScoredTriple],
268        from: &str,
269        to: &str,
270    ) -> Option<PathExplanation> {
271        PathExplanation::find(triples, from, to)
272    }
273
274    /// Build a provenance report for the given triples.
275    pub fn explain_provenance(&self, triples: &[ScoredTriple]) -> ProvenanceReport {
276        ProvenanceReport::from_triples(triples)
277    }
278}
279
280impl Default for ExplainabilityEngine {
281    fn default() -> Self {
282        Self::new()
283    }
284}
285
286// ─────────────────────────────────────────────────────────────────────────────
287// Tests
288// ─────────────────────────────────────────────────────────────────────────────
289
#[cfg(test)]
mod tests {
    use super::*;

    /// Small fixture: Alice -knows-> Bob -worksAt-> Acme, Alice -livesIn-> Tokyo.
    fn make_triples() -> Vec<ScoredTriple> {
        vec![
            ScoredTriple::new("Alice", "knows", "Bob", 0.9).with_source("doc:1"),
            ScoredTriple::new("Bob", "worksAt", "Acme", 0.7).with_source("doc:2"),
            ScoredTriple::new("Alice", "livesIn", "Tokyo", 0.5).with_source("doc:1"),
        ]
    }

    #[test]
    fn test_scored_triple_clamps_score() {
        assert_eq!(ScoredTriple::new("a", "p", "b", 1.5).score, 1.0);
        assert_eq!(ScoredTriple::new("a", "p", "b", -0.2).score, 0.0);
        assert!(ScoredTriple::new("a", "p", "b", 0.3).source.is_none());
    }

    #[test]
    fn test_attention_softmax_sum_to_one() {
        let triples = make_triples();
        let raw = vec![1.0, 2.0, 0.5];
        let expl = AttentionExplanation::compute("Who does Alice know?", &triples, &raw);
        let total: f64 = expl.weighted_triples.iter().map(|t| t.score).sum();
        assert!((total - 1.0).abs() < 1e-9, "weights must sum to 1");
    }

    #[test]
    fn test_attention_top_k() {
        let triples = make_triples();
        let raw = vec![3.0, 1.0, 2.0];
        let expl = AttentionExplanation::compute("q", &triples, &raw);
        let top1 = expl.top_k(1);
        assert_eq!(top1[0].subject, "Alice");
        assert_eq!(top1[0].predicate, "knows");
    }

    #[test]
    fn test_attention_top_k_larger_than_len() {
        let triples = make_triples();
        let expl = AttentionExplanation::compute("q", &triples, &[0.1, 0.2, 0.3]);
        let top = expl.top_k(10);
        assert_eq!(top.len(), 3);
        assert_eq!(top[0].object, "Tokyo");
    }

    #[test]
    fn test_attention_empty_input() {
        let expl = AttentionExplanation::compute("q", &[], &[]);
        assert!(expl.weighted_triples.is_empty());
        assert_eq!(expl.attention_entropy, 0.0);
        assert!(expl.top_k(3).is_empty());
    }

    #[test]
    #[should_panic(expected = "lengths must match")]
    fn test_attention_length_mismatch_panics() {
        let triples = make_triples();
        let _ = AttentionExplanation::compute("q", &triples, &[1.0]);
    }

    #[test]
    fn test_attention_entropy_uniform() {
        let triples = make_triples();
        let raw = vec![1.0, 1.0, 1.0]; // uniform → max entropy for 3 items
        let expl = AttentionExplanation::compute("q", &triples, &raw);
        // entropy of [1/3, 1/3, 1/3] = ln(3) ≈ 1.099
        assert!(
            expl.attention_entropy > 1.09,
            "uniform should have high entropy"
        );
    }

    #[test]
    fn test_path_explanation_direct_hop() {
        let triples = make_triples();
        let path = PathExplanation::find(&triples, "Alice", "Bob").unwrap();
        assert_eq!(path.path_length, 1);
        assert_eq!(path.hops[0].predicate, "knows");
    }

    #[test]
    fn test_path_explanation_two_hops() {
        let triples = make_triples();
        let path = PathExplanation::find(&triples, "Alice", "Acme").unwrap();
        assert_eq!(path.path_length, 2);
    }

    #[test]
    fn test_path_explanation_no_path() {
        let triples = make_triples();
        let path = PathExplanation::find(&triples, "Alice", "XYZ");
        assert!(path.is_none(), "no path to unknown node");
    }

    #[test]
    fn test_path_explanation_is_directed() {
        // Edges only run subject -> object, so Bob cannot reach Alice.
        let triples = make_triples();
        assert!(PathExplanation::find(&triples, "Bob", "Alice").is_none());
    }

    #[test]
    fn test_path_explanation_handles_cycles() {
        let triples = vec![
            ScoredTriple::new("A", "p", "B", 1.0),
            ScoredTriple::new("B", "q", "A", 1.0),
            ScoredTriple::new("B", "r", "C", 1.0),
        ];
        let path = PathExplanation::find(&triples, "A", "C").unwrap();
        assert_eq!(path.path_length, 2);
        assert_eq!(
            path.hops,
            vec![
                PathHop {
                    from: "A".into(),
                    predicate: "p".into(),
                    to: "B".into(),
                },
                PathHop {
                    from: "B".into(),
                    predicate: "r".into(),
                    to: "C".into(),
                },
            ]
        );
    }

    #[test]
    fn test_path_explanation_same_node() {
        let triples = make_triples();
        let path = PathExplanation::find(&triples, "Alice", "Alice").unwrap();
        assert_eq!(path.path_length, 0);
        assert!(path.hops.is_empty());
    }

    #[test]
    fn test_provenance_report() {
        let triples = make_triples();
        let report = ProvenanceReport::from_triples(&triples);
        let mut sources = report.source_iris();
        sources.sort();
        assert_eq!(sources, vec!["doc:1", "doc:2"]);
        assert_eq!(report.unknown_count, 0);
    }

    #[test]
    fn test_provenance_unknown_triples() {
        let triples = vec![
            ScoredTriple::new("A", "p", "B", 0.5), // no source
        ];
        let report = ProvenanceReport::from_triples(&triples);
        assert_eq!(report.unknown_count, 1);
        assert!(report.sources.is_empty());
    }

    #[test]
    fn test_explainability_engine_integration() {
        let engine = ExplainabilityEngine::new();
        let triples = make_triples();
        let raw = vec![0.8, 0.6, 0.4];

        let attn = engine.explain_attention("query", &triples, &raw);
        assert!(!attn.weighted_triples.is_empty());

        let path = engine.explain_path(&triples, "Alice", "Acme");
        assert!(path.is_some());

        let prov = engine.explain_provenance(&triples);
        assert!(!prov.sources.is_empty());
    }
}