Skip to main content

entrenar/monitor/inference/counterfactual/
explanation.rs

1//! Counterfactual explanation structure and methods.
2
3use serde::{Deserialize, Serialize};
4
5use super::error::CounterfactualError;
6use super::feature_change::FeatureChange;
7
8/// Counterfactual explanation for a decision
9///
10/// Answers: "What minimal change would have flipped the decision?"
11#[derive(Clone, Debug, Serialize, Deserialize)]
12pub struct Counterfactual {
13    /// Original input that produced the decision
14    pub original_input: Vec<f32>,
15    /// Original decision/class
16    pub original_decision: usize,
17    /// Original confidence
18    pub original_confidence: f32,
19    /// Modified input that would flip the decision
20    pub counterfactual_input: Vec<f32>,
21    /// The alternative decision
22    pub alternative_decision: usize,
23    /// Alternative confidence
24    pub alternative_confidence: f32,
25    /// Which features changed and by how much
26    pub changes: Vec<FeatureChange>,
27    /// L1 distance (sparsity of changes)
28    pub sparsity: f32,
29    /// L2 distance (magnitude of changes)
30    pub distance: f32,
31}
32
33impl Counterfactual {
34    /// Create a new counterfactual
35    pub fn new(
36        original_input: Vec<f32>,
37        original_decision: usize,
38        original_confidence: f32,
39        counterfactual_input: Vec<f32>,
40        alternative_decision: usize,
41        alternative_confidence: f32,
42    ) -> Self {
43        assert_eq!(original_input.len(), counterfactual_input.len(), "Input dimensions must match");
44
45        let mut changes = Vec::new();
46        let mut l1 = 0.0f32;
47        let mut l2 = 0.0f32;
48
49        for i in 0..original_input.len() {
50            let delta = counterfactual_input[i] - original_input[i];
51            if delta.abs() > 1e-6 {
52                changes.push(FeatureChange::new(i, original_input[i], counterfactual_input[i]));
53                l1 += delta.abs();
54                l2 += delta * delta;
55            }
56        }
57
58        Self {
59            original_input,
60            original_decision,
61            original_confidence,
62            counterfactual_input,
63            alternative_decision,
64            alternative_confidence,
65            changes,
66            sparsity: l1,
67            distance: l2.sqrt(),
68        }
69    }
70
71    /// Generate natural language explanation
72    pub fn explain(&self) -> String {
73        let mut explanation = format!(
74            "Original decision: {} (confidence: {:.1}%)\n",
75            self.original_decision,
76            self.original_confidence * 100.0
77        );
78        explanation.push_str(&format!(
79            "Alternative decision: {} (confidence: {:.1}%)\n",
80            self.alternative_decision,
81            self.alternative_confidence * 100.0
82        ));
83        explanation.push_str(&format!(
84            "\nThe decision would have been {} if:\n",
85            self.alternative_decision
86        ));
87
88        // Sort changes by absolute delta (most impactful first)
89        let mut sorted_changes = self.changes.clone();
90        sorted_changes.sort_by(|a, b| {
91            b.abs_delta().partial_cmp(&a.abs_delta()).unwrap_or(std::cmp::Ordering::Equal)
92        });
93
94        for change in sorted_changes.iter().take(5) {
95            let sign = if change.delta >= 0.0 { "+" } else { "" };
96            let default_name = format!("feature[{}]", change.feature_idx);
97            let name = change.feature_name.as_deref().unwrap_or(&default_name);
98            explanation.push_str(&format!(
99                "  - {}: {:.4} → {:.4} ({}{:.4})\n",
100                name, change.original_value, change.counterfactual_value, sign, change.delta
101            ));
102        }
103
104        if self.changes.len() > 5 {
105            explanation.push_str(&format!("  ... and {} more changes\n", self.changes.len() - 5));
106        }
107
108        explanation.push_str(&format!("\nSparsity (L1): {:.4}\n", self.sparsity));
109        explanation.push_str(&format!("Distance (L2): {:.4}\n", self.distance));
110
111        explanation
112    }
113
114    /// Number of features that changed
115    pub fn n_changes(&self) -> usize {
116        self.changes.len()
117    }
118
119    /// Check if this is a valid counterfactual (decision actually flipped)
120    pub fn is_valid(&self) -> bool {
121        self.original_decision != self.alternative_decision
122    }
123
124    /// Set feature names for all changes
125    pub fn with_feature_names(mut self, names: &[String]) -> Self {
126        for change in &mut self.changes {
127            if change.feature_idx < names.len() {
128                change.feature_name = Some(names[change.feature_idx].clone());
129            }
130        }
131        self
132    }
133
134    /// Convert to binary format
135    pub fn to_bytes(&self) -> Vec<u8> {
136        let mut bytes = Vec::new();
137        bytes.push(1); // version
138
139        // Original decision info
140        bytes.extend_from_slice(&(self.original_decision as u32).to_le_bytes());
141        bytes.extend_from_slice(&self.original_confidence.to_le_bytes());
142        bytes.extend_from_slice(&(self.alternative_decision as u32).to_le_bytes());
143        bytes.extend_from_slice(&self.alternative_confidence.to_le_bytes());
144
145        // Original input
146        bytes.extend_from_slice(&(self.original_input.len() as u32).to_le_bytes());
147        for v in &self.original_input {
148            bytes.extend_from_slice(&v.to_le_bytes());
149        }
150
151        // Counterfactual input
152        for v in &self.counterfactual_input {
153            bytes.extend_from_slice(&v.to_le_bytes());
154        }
155
156        // Changes (compact: only store changed indices and deltas)
157        bytes.extend_from_slice(&(self.changes.len() as u32).to_le_bytes());
158        for change in &self.changes {
159            bytes.extend_from_slice(&(change.feature_idx as u32).to_le_bytes());
160            bytes.extend_from_slice(&change.original_value.to_le_bytes());
161            bytes.extend_from_slice(&change.counterfactual_value.to_le_bytes());
162            bytes.extend_from_slice(&change.delta.to_le_bytes());
163
164            // Feature name (length-prefixed)
165            if let Some(name) = &change.feature_name {
166                bytes.extend_from_slice(&(name.len() as u32).to_le_bytes());
167                bytes.extend_from_slice(name.as_bytes());
168            } else {
169                bytes.extend_from_slice(&0u32.to_le_bytes());
170            }
171        }
172
173        // Metrics
174        bytes.extend_from_slice(&self.sparsity.to_le_bytes());
175        bytes.extend_from_slice(&self.distance.to_le_bytes());
176
177        bytes
178    }
179
180    /// Reconstruct from binary format
181    pub fn from_bytes(bytes: &[u8]) -> Result<Self, CounterfactualError> {
182        if bytes.len() < 21 {
183            return Err(CounterfactualError::InsufficientData {
184                expected: 21,
185                actual: bytes.len(),
186            });
187        }
188
189        let mut reader = ByteReader::new(bytes);
190
191        let version = reader.read_u8()?;
192        if version != 1 {
193            return Err(CounterfactualError::VersionMismatch { expected: 1, actual: version });
194        }
195
196        let original_decision = reader.read_u32_as_usize()?;
197        let original_confidence = reader.read_f32()?;
198        let alternative_decision = reader.read_u32_as_usize()?;
199        let alternative_confidence = reader.read_f32()?;
200        let n_features = reader.read_u32_as_usize()?;
201
202        let original_input = reader.read_f32_vec_n(n_features)?;
203        let counterfactual_input = reader.read_f32_vec_n(n_features)?;
204
205        let n_changes = reader.read_u32_as_usize()?;
206        let mut changes = Vec::with_capacity(n_changes);
207        for _ in 0..n_changes {
208            changes.push(reader.read_feature_change()?);
209        }
210
211        let sparsity = reader.read_f32()?;
212        let distance = reader.read_f32()?;
213
214        Ok(Self {
215            original_input,
216            original_decision,
217            original_confidence,
218            counterfactual_input,
219            alternative_decision,
220            alternative_confidence,
221            changes,
222            sparsity,
223            distance,
224        })
225    }
226}
227
228/// Stateful byte reader that tracks offset and validates bounds.
229struct ByteReader<'a> {
230    data: &'a [u8],
231    offset: usize,
232}
233
234impl<'a> ByteReader<'a> {
235    fn new(data: &'a [u8]) -> Self {
236        Self { data, offset: 0 }
237    }
238
239    fn read_u8(&mut self) -> Result<u8, CounterfactualError> {
240        self.ensure_available(1)?;
241        let val = self.data[self.offset];
242        self.offset += 1;
243        Ok(val)
244    }
245
246    fn read_u32(&mut self) -> Result<u32, CounterfactualError> {
247        self.ensure_available(4)?;
248        let o = self.offset;
249        let val = u32::from_le_bytes([
250            self.data[o],
251            self.data[o + 1],
252            self.data[o + 2],
253            self.data[o + 3],
254        ]);
255        self.offset += 4;
256        Ok(val)
257    }
258
259    fn read_u32_as_usize(&mut self) -> Result<usize, CounterfactualError> {
260        Ok(self.read_u32()? as usize)
261    }
262
263    fn read_f32(&mut self) -> Result<f32, CounterfactualError> {
264        self.ensure_available(4)?;
265        let o = self.offset;
266        let val = f32::from_le_bytes([
267            self.data[o],
268            self.data[o + 1],
269            self.data[o + 2],
270            self.data[o + 3],
271        ]);
272        self.offset += 4;
273        Ok(val)
274    }
275
276    fn read_f32_vec_n(&mut self, n: usize) -> Result<Vec<f32>, CounterfactualError> {
277        let mut vec = Vec::with_capacity(n);
278        for _ in 0..n {
279            vec.push(self.read_f32()?);
280        }
281        Ok(vec)
282    }
283
284    fn read_string(&mut self, len: usize) -> Result<String, CounterfactualError> {
285        self.ensure_available(len)?;
286        let s = String::from_utf8_lossy(&self.data[self.offset..self.offset + len]).to_string();
287        self.offset += len;
288        Ok(s)
289    }
290
291    fn read_feature_change(&mut self) -> Result<FeatureChange, CounterfactualError> {
292        let feature_idx = self.read_u32_as_usize()?;
293        let original_value = self.read_f32()?;
294        let counterfactual_value = self.read_f32()?;
295        let delta = self.read_f32()?;
296        let name_len = self.read_u32_as_usize()?;
297        let feature_name = if name_len > 0 { Some(self.read_string(name_len)?) } else { None };
298        Ok(FeatureChange { feature_idx, feature_name, original_value, counterfactual_value, delta })
299    }
300
301    fn ensure_available(&self, needed: usize) -> Result<(), CounterfactualError> {
302        if self.offset + needed > self.data.len() {
303            return Err(CounterfactualError::InsufficientData {
304                expected: self.offset + needed,
305                actual: self.data.len(),
306            });
307        }
308        Ok(())
309    }
310}