entrenar/monitor/inference/counterfactual/
explanation.rs1use serde::{Deserialize, Serialize};
4
5use super::error::CounterfactualError;
6use super::feature_change::FeatureChange;
7
8#[derive(Clone, Debug, Serialize, Deserialize)]
12pub struct Counterfactual {
13 pub original_input: Vec<f32>,
15 pub original_decision: usize,
17 pub original_confidence: f32,
19 pub counterfactual_input: Vec<f32>,
21 pub alternative_decision: usize,
23 pub alternative_confidence: f32,
25 pub changes: Vec<FeatureChange>,
27 pub sparsity: f32,
29 pub distance: f32,
31}
32
33impl Counterfactual {
34 pub fn new(
36 original_input: Vec<f32>,
37 original_decision: usize,
38 original_confidence: f32,
39 counterfactual_input: Vec<f32>,
40 alternative_decision: usize,
41 alternative_confidence: f32,
42 ) -> Self {
43 assert_eq!(original_input.len(), counterfactual_input.len(), "Input dimensions must match");
44
45 let mut changes = Vec::new();
46 let mut l1 = 0.0f32;
47 let mut l2 = 0.0f32;
48
49 for i in 0..original_input.len() {
50 let delta = counterfactual_input[i] - original_input[i];
51 if delta.abs() > 1e-6 {
52 changes.push(FeatureChange::new(i, original_input[i], counterfactual_input[i]));
53 l1 += delta.abs();
54 l2 += delta * delta;
55 }
56 }
57
58 Self {
59 original_input,
60 original_decision,
61 original_confidence,
62 counterfactual_input,
63 alternative_decision,
64 alternative_confidence,
65 changes,
66 sparsity: l1,
67 distance: l2.sqrt(),
68 }
69 }
70
71 pub fn explain(&self) -> String {
73 let mut explanation = format!(
74 "Original decision: {} (confidence: {:.1}%)\n",
75 self.original_decision,
76 self.original_confidence * 100.0
77 );
78 explanation.push_str(&format!(
79 "Alternative decision: {} (confidence: {:.1}%)\n",
80 self.alternative_decision,
81 self.alternative_confidence * 100.0
82 ));
83 explanation.push_str(&format!(
84 "\nThe decision would have been {} if:\n",
85 self.alternative_decision
86 ));
87
88 let mut sorted_changes = self.changes.clone();
90 sorted_changes.sort_by(|a, b| {
91 b.abs_delta().partial_cmp(&a.abs_delta()).unwrap_or(std::cmp::Ordering::Equal)
92 });
93
94 for change in sorted_changes.iter().take(5) {
95 let sign = if change.delta >= 0.0 { "+" } else { "" };
96 let default_name = format!("feature[{}]", change.feature_idx);
97 let name = change.feature_name.as_deref().unwrap_or(&default_name);
98 explanation.push_str(&format!(
99 " - {}: {:.4} → {:.4} ({}{:.4})\n",
100 name, change.original_value, change.counterfactual_value, sign, change.delta
101 ));
102 }
103
104 if self.changes.len() > 5 {
105 explanation.push_str(&format!(" ... and {} more changes\n", self.changes.len() - 5));
106 }
107
108 explanation.push_str(&format!("\nSparsity (L1): {:.4}\n", self.sparsity));
109 explanation.push_str(&format!("Distance (L2): {:.4}\n", self.distance));
110
111 explanation
112 }
113
114 pub fn n_changes(&self) -> usize {
116 self.changes.len()
117 }
118
119 pub fn is_valid(&self) -> bool {
121 self.original_decision != self.alternative_decision
122 }
123
124 pub fn with_feature_names(mut self, names: &[String]) -> Self {
126 for change in &mut self.changes {
127 if change.feature_idx < names.len() {
128 change.feature_name = Some(names[change.feature_idx].clone());
129 }
130 }
131 self
132 }
133
134 pub fn to_bytes(&self) -> Vec<u8> {
136 let mut bytes = Vec::new();
137 bytes.push(1); bytes.extend_from_slice(&(self.original_decision as u32).to_le_bytes());
141 bytes.extend_from_slice(&self.original_confidence.to_le_bytes());
142 bytes.extend_from_slice(&(self.alternative_decision as u32).to_le_bytes());
143 bytes.extend_from_slice(&self.alternative_confidence.to_le_bytes());
144
145 bytes.extend_from_slice(&(self.original_input.len() as u32).to_le_bytes());
147 for v in &self.original_input {
148 bytes.extend_from_slice(&v.to_le_bytes());
149 }
150
151 for v in &self.counterfactual_input {
153 bytes.extend_from_slice(&v.to_le_bytes());
154 }
155
156 bytes.extend_from_slice(&(self.changes.len() as u32).to_le_bytes());
158 for change in &self.changes {
159 bytes.extend_from_slice(&(change.feature_idx as u32).to_le_bytes());
160 bytes.extend_from_slice(&change.original_value.to_le_bytes());
161 bytes.extend_from_slice(&change.counterfactual_value.to_le_bytes());
162 bytes.extend_from_slice(&change.delta.to_le_bytes());
163
164 if let Some(name) = &change.feature_name {
166 bytes.extend_from_slice(&(name.len() as u32).to_le_bytes());
167 bytes.extend_from_slice(name.as_bytes());
168 } else {
169 bytes.extend_from_slice(&0u32.to_le_bytes());
170 }
171 }
172
173 bytes.extend_from_slice(&self.sparsity.to_le_bytes());
175 bytes.extend_from_slice(&self.distance.to_le_bytes());
176
177 bytes
178 }
179
180 pub fn from_bytes(bytes: &[u8]) -> Result<Self, CounterfactualError> {
182 if bytes.len() < 21 {
183 return Err(CounterfactualError::InsufficientData {
184 expected: 21,
185 actual: bytes.len(),
186 });
187 }
188
189 let mut reader = ByteReader::new(bytes);
190
191 let version = reader.read_u8()?;
192 if version != 1 {
193 return Err(CounterfactualError::VersionMismatch { expected: 1, actual: version });
194 }
195
196 let original_decision = reader.read_u32_as_usize()?;
197 let original_confidence = reader.read_f32()?;
198 let alternative_decision = reader.read_u32_as_usize()?;
199 let alternative_confidence = reader.read_f32()?;
200 let n_features = reader.read_u32_as_usize()?;
201
202 let original_input = reader.read_f32_vec_n(n_features)?;
203 let counterfactual_input = reader.read_f32_vec_n(n_features)?;
204
205 let n_changes = reader.read_u32_as_usize()?;
206 let mut changes = Vec::with_capacity(n_changes);
207 for _ in 0..n_changes {
208 changes.push(reader.read_feature_change()?);
209 }
210
211 let sparsity = reader.read_f32()?;
212 let distance = reader.read_f32()?;
213
214 Ok(Self {
215 original_input,
216 original_decision,
217 original_confidence,
218 counterfactual_input,
219 alternative_decision,
220 alternative_confidence,
221 changes,
222 sparsity,
223 distance,
224 })
225 }
226}
227
/// Forward-only cursor used to decode the binary layout written by `Counterfactual::to_bytes`.
struct ByteReader<'buf> {
    /// The complete serialized buffer being decoded.
    data: &'buf [u8],
    /// Position of the next unread byte within `data`.
    offset: usize,
}

234impl<'a> ByteReader<'a> {
235 fn new(data: &'a [u8]) -> Self {
236 Self { data, offset: 0 }
237 }
238
239 fn read_u8(&mut self) -> Result<u8, CounterfactualError> {
240 self.ensure_available(1)?;
241 let val = self.data[self.offset];
242 self.offset += 1;
243 Ok(val)
244 }
245
246 fn read_u32(&mut self) -> Result<u32, CounterfactualError> {
247 self.ensure_available(4)?;
248 let o = self.offset;
249 let val = u32::from_le_bytes([
250 self.data[o],
251 self.data[o + 1],
252 self.data[o + 2],
253 self.data[o + 3],
254 ]);
255 self.offset += 4;
256 Ok(val)
257 }
258
259 fn read_u32_as_usize(&mut self) -> Result<usize, CounterfactualError> {
260 Ok(self.read_u32()? as usize)
261 }
262
263 fn read_f32(&mut self) -> Result<f32, CounterfactualError> {
264 self.ensure_available(4)?;
265 let o = self.offset;
266 let val = f32::from_le_bytes([
267 self.data[o],
268 self.data[o + 1],
269 self.data[o + 2],
270 self.data[o + 3],
271 ]);
272 self.offset += 4;
273 Ok(val)
274 }
275
276 fn read_f32_vec_n(&mut self, n: usize) -> Result<Vec<f32>, CounterfactualError> {
277 let mut vec = Vec::with_capacity(n);
278 for _ in 0..n {
279 vec.push(self.read_f32()?);
280 }
281 Ok(vec)
282 }
283
284 fn read_string(&mut self, len: usize) -> Result<String, CounterfactualError> {
285 self.ensure_available(len)?;
286 let s = String::from_utf8_lossy(&self.data[self.offset..self.offset + len]).to_string();
287 self.offset += len;
288 Ok(s)
289 }
290
291 fn read_feature_change(&mut self) -> Result<FeatureChange, CounterfactualError> {
292 let feature_idx = self.read_u32_as_usize()?;
293 let original_value = self.read_f32()?;
294 let counterfactual_value = self.read_f32()?;
295 let delta = self.read_f32()?;
296 let name_len = self.read_u32_as_usize()?;
297 let feature_name = if name_len > 0 { Some(self.read_string(name_len)?) } else { None };
298 Ok(FeatureChange { feature_idx, feature_name, original_value, counterfactual_value, delta })
299 }
300
301 fn ensure_available(&self, needed: usize) -> Result<(), CounterfactualError> {
302 if self.offset + needed > self.data.len() {
303 return Err(CounterfactualError::InsufficientData {
304 expected: self.offset + needed,
305 actual: self.data.len(),
306 });
307 }
308 Ok(())
309 }
310}