1use serde::{Deserialize, Serialize};
11use std::collections::{HashMap, HashSet, VecDeque};
12
/// A knowledge-graph triple (`subject` -`predicate`-> `object`) annotated with a score.
///
/// When built through [`ScoredTriple::new`] the score is clamped into `[0.0, 1.0]`.
/// Struct literals and deserialization bypass that clamp because the fields are public.
/// [`AttentionExplanation::compute`] reuses the `score` field to carry a normalized
/// attention weight.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScoredTriple {
    /// Subject node of the triple.
    pub subject: String,
    /// Predicate (edge label) connecting `subject` to `object`.
    pub predicate: String,
    /// Object node of the triple.
    pub object: String,
    /// Score attached to the triple (see the type-level docs for how it is populated).
    pub score: f64,
    /// Optional identifier of where the triple came from (e.g. `"doc:1"`); `None` when unknown.
    pub source: Option<String>,
}
28
29impl ScoredTriple {
30 pub fn new(
31 subject: impl Into<String>,
32 predicate: impl Into<String>,
33 object: impl Into<String>,
34 score: f64,
35 ) -> Self {
36 Self {
37 subject: subject.into(),
38 predicate: predicate.into(),
39 object: object.into(),
40 score: score.clamp(0.0, 1.0),
41 source: None,
42 }
43 }
44
45 pub fn with_source(mut self, source: impl Into<String>) -> Self {
46 self.source = Some(source.into());
47 self
48 }
49}
50
/// Attention-style explanation: how much weight each triple received for a query.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AttentionExplanation {
    /// The input triples, each with `score` replaced by its softmax-normalized weight.
    pub weighted_triples: Vec<ScoredTriple>,
    /// The query text this explanation was computed for.
    pub query: String,
    /// Shannon entropy (natural log) of the weights; higher means attention is more spread out.
    pub attention_entropy: f64,
}
65
66impl AttentionExplanation {
67 pub fn compute(query: &str, triples: &[ScoredTriple], raw_scores: &[f64]) -> Self {
71 assert_eq!(triples.len(), raw_scores.len(), "lengths must match");
72 let weights = softmax(raw_scores);
73 let entropy = shannon_entropy(&weights);
74
75 let weighted_triples = triples
76 .iter()
77 .zip(weights.iter())
78 .map(|(t, &w)| {
79 let mut wt = t.clone();
80 wt.score = w;
81 wt
82 })
83 .collect();
84
85 Self {
86 weighted_triples,
87 query: query.to_string(),
88 attention_entropy: entropy,
89 }
90 }
91
92 pub fn top_k(&self, k: usize) -> Vec<&ScoredTriple> {
94 let mut sorted: Vec<&ScoredTriple> = self.weighted_triples.iter().collect();
95 sorted.sort_by(|a, b| {
96 b.score
97 .partial_cmp(&a.score)
98 .unwrap_or(std::cmp::Ordering::Equal)
99 });
100 sorted.into_iter().take(k).collect()
101 }
102}
103
/// Numerically stable softmax: maps real-valued scores to weights that sum to 1.
///
/// The largest score is subtracted before exponentiating to avoid overflow.
/// Non-finite inputs are handled explicitly instead of leaking `NaN` into the result:
/// - `NaN` scores are treated as `-inf` and receive weight 0.
/// - If any score is `+inf`, the weight is split evenly among the `+inf` entries, which
///   is the limit of softmax as those scores grow.
/// - If no score is above `-inf` (all are `-inf` or `NaN`), the weights are uniform.
///
/// Returns an empty vector for empty input.
fn softmax(scores: &[f64]) -> Vec<f64> {
    let n = scores.len();
    if n == 0 {
        return Vec::new();
    }

    // `f64::max` returns the non-NaN operand, so this is the largest non-NaN score
    // (or -inf when every score is NaN / -inf).
    let max = scores.iter().copied().fold(f64::NEG_INFINITY, f64::max);

    if max == f64::INFINITY {
        // exp(inf - inf) would be NaN; in the limit all mass goes to the +inf entries.
        let winners = scores.iter().filter(|&&s| s == f64::INFINITY).count();
        let share = 1.0 / winners as f64;
        return scores
            .iter()
            .map(|&s| if s == f64::INFINITY { share } else { 0.0 })
            .collect();
    }

    if max == f64::NEG_INFINITY {
        // No score carries any evidence: fall back to a uniform distribution.
        return vec![1.0 / n as f64; n];
    }

    // `max` is finite here. The max entry contributes exp(0) = 1 and every other term is
    // in [0, 1], so `sum` lies in [1, n] and the division below is always well defined.
    let exps: Vec<f64> = scores
        .iter()
        .map(|&s| if s.is_nan() { 0.0 } else { (s - max).exp() })
        .collect();
    let sum: f64 = exps.iter().sum();
    exps.into_iter().map(|e| e / sum).collect()
}
116
/// Shannon entropy `H(p) = -sum(p * ln p)`, measured in nats.
///
/// Only strictly positive entries contribute; zeros, negatives and `NaN` are skipped,
/// which matches the usual `0 * ln 0 = 0` convention.
fn shannon_entropy(probs: &[f64]) -> f64 {
    let positive = probs.iter().copied().filter(|p| *p > 0.0);
    positive.map(|p| -(p * p.ln())).sum()
}
124
/// One edge traversed in a [`PathExplanation`]: `from` -`predicate`-> `to`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct PathHop {
    /// Node the hop starts at (the subject of the followed triple).
    pub from: String,
    /// Predicate of the followed triple.
    pub predicate: String,
    /// Node the hop ends at (the object of the followed triple).
    pub to: String,
}
136
/// A chain of triples connecting two nodes, as produced by [`PathExplanation::find`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PathExplanation {
    /// Start node requested by the caller.
    pub from: String,
    /// End node requested by the caller.
    pub to: String,
    /// Hops in traversal order; empty when `from == to`.
    pub hops: Vec<PathHop>,
    /// Number of hops; `find` sets this to `hops.len()`.
    pub path_length: usize,
}
147
148impl PathExplanation {
149 pub fn find(triples: &[ScoredTriple], from: &str, to: &str) -> Option<Self> {
153 if from == to {
154 return Some(Self {
155 from: from.to_string(),
156 to: to.to_string(),
157 hops: Vec::new(),
158 path_length: 0,
159 });
160 }
161
162 let mut adj: HashMap<&str, Vec<(&str, &str)>> = HashMap::new();
164 for t in triples {
165 adj.entry(&t.subject)
166 .or_default()
167 .push((&t.predicate, &t.object));
168 }
169
170 let mut queue: VecDeque<(&str, Vec<PathHop>)> = VecDeque::new();
172 let mut visited: HashSet<&str> = HashSet::new();
173 queue.push_back((from, Vec::new()));
174 visited.insert(from);
175
176 while let Some((node, path)) = queue.pop_front() {
177 if let Some(neighbors) = adj.get(node) {
178 for &(pred, obj) in neighbors {
179 if visited.contains(obj) {
180 continue;
181 }
182 let mut new_path = path.clone();
183 new_path.push(PathHop {
184 from: node.to_string(),
185 predicate: pred.to_string(),
186 to: obj.to_string(),
187 });
188 if obj == to {
189 let length = new_path.len();
190 return Some(Self {
191 from: from.to_string(),
192 to: to.to_string(),
193 hops: new_path,
194 path_length: length,
195 });
196 }
197 visited.insert(obj);
198 queue.push_back((obj, new_path));
199 }
200 }
201 }
202 None
203 }
204}
205
/// Triples grouped by their `source`, plus a count of triples that have none.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProvenanceReport {
    /// Cloned triples keyed by their `source` identifier.
    pub sources: HashMap<String, Vec<ScoredTriple>>,
    /// Number of input triples whose `source` was `None`.
    pub unknown_count: usize,
}
218
219impl ProvenanceReport {
220 pub fn from_triples(triples: &[ScoredTriple]) -> Self {
222 let mut sources: HashMap<String, Vec<ScoredTriple>> = HashMap::new();
223 let mut unknown_count = 0;
224 for t in triples {
225 match &t.source {
226 Some(src) => sources.entry(src.clone()).or_default().push(t.clone()),
227 None => unknown_count += 1,
228 }
229 }
230 Self {
231 sources,
232 unknown_count,
233 }
234 }
235
236 pub fn source_iris(&self) -> Vec<&str> {
238 self.sources.keys().map(|s| s.as_str()).collect()
239 }
240}
241
/// Stateless facade bundling the attention, path and provenance explainers.
///
/// Derives the common traits expected of a public type; being a unit struct it is
/// trivially `Copy`.
#[derive(Debug, Clone, Copy)]
pub struct ExplainabilityEngine;
248
249impl ExplainabilityEngine {
250 pub fn new() -> Self {
251 Self
252 }
253
254 pub fn explain_attention(
256 &self,
257 query: &str,
258 triples: &[ScoredTriple],
259 raw_scores: &[f64],
260 ) -> AttentionExplanation {
261 AttentionExplanation::compute(query, triples, raw_scores)
262 }
263
264 pub fn explain_path(
266 &self,
267 triples: &[ScoredTriple],
268 from: &str,
269 to: &str,
270 ) -> Option<PathExplanation> {
271 PathExplanation::find(triples, from, to)
272 }
273
274 pub fn explain_provenance(&self, triples: &[ScoredTriple]) -> ProvenanceReport {
276 ProvenanceReport::from_triples(triples)
277 }
278}
279
280impl Default for ExplainabilityEngine {
281 fn default() -> Self {
282 Self::new()
283 }
284}
285
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared fixture: Alice -knows-> Bob -worksAt-> Acme, and Alice -livesIn-> Tokyo.
    fn make_triples() -> Vec<ScoredTriple> {
        [
            ("Alice", "knows", "Bob", 0.9, "doc:1"),
            ("Bob", "worksAt", "Acme", 0.7, "doc:2"),
            ("Alice", "livesIn", "Tokyo", 0.5, "doc:1"),
        ]
        .iter()
        .map(|&(s, p, o, score, src)| ScoredTriple::new(s, p, o, score).with_source(src))
        .collect()
    }

    #[test]
    fn test_attention_softmax_sum_to_one() {
        let triples = make_triples();
        let expl =
            AttentionExplanation::compute("Who does Alice know?", &triples, &[1.0, 2.0, 0.5]);
        let total = expl.weighted_triples.iter().map(|t| t.score).sum::<f64>();
        assert!((total - 1.0).abs() < 1e-9, "weights must sum to 1");
    }

    #[test]
    fn test_attention_top_k() {
        let triples = make_triples();
        let expl = AttentionExplanation::compute("q", &triples, &[3.0, 1.0, 2.0]);
        let best = expl.top_k(1);
        assert_eq!(best[0].subject, "Alice");
        assert_eq!(best[0].predicate, "knows");
    }

    #[test]
    fn test_attention_entropy_uniform() {
        let triples = make_triples();
        let expl = AttentionExplanation::compute("q", &triples, &[1.0; 3]);
        // Uniform weights over three items give ln(3) ~= 1.0986.
        assert!(
            expl.attention_entropy > 1.09,
            "uniform should have high entropy"
        );
    }

    #[test]
    fn test_path_explanation_direct_hop() {
        let triples = make_triples();
        let found = PathExplanation::find(&triples, "Alice", "Bob").unwrap();
        assert_eq!(found.path_length, 1);
        assert_eq!(found.hops[0].predicate, "knows");
    }

    #[test]
    fn test_path_explanation_two_hops() {
        let triples = make_triples();
        let found = PathExplanation::find(&triples, "Alice", "Acme").unwrap();
        assert_eq!(found.path_length, 2);
    }

    #[test]
    fn test_path_explanation_no_path() {
        let triples = make_triples();
        let missing = PathExplanation::find(&triples, "Alice", "XYZ");
        assert!(missing.is_none(), "no path to unknown node");
    }

    #[test]
    fn test_path_explanation_same_node() {
        let triples = make_triples();
        let trivial = PathExplanation::find(&triples, "Alice", "Alice").unwrap();
        assert_eq!(trivial.path_length, 0);
        assert!(trivial.hops.is_empty());
    }

    #[test]
    fn test_provenance_report() {
        let triples = make_triples();
        let report = ProvenanceReport::from_triples(&triples);
        let mut iris = report.source_iris();
        iris.sort();
        assert_eq!(iris, vec!["doc:1", "doc:2"]);
        assert_eq!(report.unknown_count, 0);
    }

    #[test]
    fn test_provenance_unknown_triples() {
        let triples = [ScoredTriple::new("A", "p", "B", 0.5)];
        let report = ProvenanceReport::from_triples(&triples);
        assert_eq!(report.unknown_count, 1);
        assert!(report.sources.is_empty());
    }

    #[test]
    fn test_explainability_engine_integration() {
        let engine = ExplainabilityEngine::new();
        let triples = make_triples();

        let attn = engine.explain_attention("query", &triples, &[0.8, 0.6, 0.4]);
        assert!(!attn.weighted_triples.is_empty());

        assert!(engine.explain_path(&triples, "Alice", "Acme").is_some());

        assert!(!engine.explain_provenance(&triples).sources.is_empty());
    }
}