use serde::{Deserialize, Serialize};
use super::error::CounterfactualError;
use super::feature_change::FeatureChange;
/// A counterfactual explanation: an original input and its decision, paired with a
/// modified input and the alternative decision associated with it.
///
/// All fields are public, so the derived values (`changes`, `sparsity`, `distance`)
/// are only guaranteed to match the inputs when the value was built by `new`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Counterfactual {
/// Feature vector the original decision refers to.
pub original_input: Vec<f32>,
/// Decision for `original_input` (presumably a class/action index — confirm with callers).
pub original_decision: usize,
/// Confidence of the original decision. `explain` multiplies it by 100 and prints it
/// as a percentage, so it is presumably a fraction in `[0, 1]` — TODO confirm.
pub original_confidence: f32,
/// Modified feature vector; `new` requires it to have the same length as `original_input`.
pub counterfactual_input: Vec<f32>,
/// Decision for `counterfactual_input`.
pub alternative_decision: usize,
/// Confidence of the alternative decision (same scale as `original_confidence`).
pub alternative_confidence: f32,
/// Per-feature differences. `new` records one entry for every feature whose absolute
/// delta exceeds `1e-6`.
pub changes: Vec<FeatureChange>,
/// As computed by `new`: the sum of absolute deltas (L1 norm) over the recorded changes.
pub sparsity: f32,
/// As computed by `new`: the square root of the sum of squared deltas (L2 norm) over
/// the recorded changes.
pub distance: f32,
}
impl Counterfactual {
/// Builds a counterfactual from an original/modified input pair and derives the
/// change list and both distance measures from them.
///
/// A feature is recorded as changed only when the absolute difference between its
/// counterfactual and original value exceeds `1e-6`. `sparsity` is the sum of those
/// absolute differences and `distance` is the square root of the sum of their squares.
///
/// # Panics
///
/// Panics with "Input dimensions must match" when the two input vectors have
/// different lengths.
pub fn new(
    original_input: Vec<f32>,
    original_decision: usize,
    original_confidence: f32,
    counterfactual_input: Vec<f32>,
    alternative_decision: usize,
    alternative_confidence: f32,
) -> Self {
    assert_eq!(
        original_input.len(),
        counterfactual_input.len(),
        "Input dimensions must match"
    );

    let mut changes = Vec::new();
    let mut abs_total = 0.0f32;
    let mut squared_total = 0.0f32;

    // Walk both vectors in lockstep; the accumulation order matches feature order.
    let pairs = original_input
        .iter()
        .copied()
        .zip(counterfactual_input.iter().copied());
    for (idx, (before, after)) in pairs.enumerate() {
        let diff = after - before;
        let magnitude = diff.abs();
        // NaN differences fail this comparison and are therefore not recorded.
        if magnitude > 1e-6 {
            changes.push(FeatureChange::new(idx, before, after));
            abs_total += magnitude;
            squared_total += diff * diff;
        }
    }

    Self {
        original_input,
        original_decision,
        original_confidence,
        counterfactual_input,
        alternative_decision,
        alternative_confidence,
        changes,
        sparsity: abs_total,
        distance: squared_total.sqrt(),
    }
}
/// Renders a multi-line, human-readable summary of this counterfactual.
///
/// The text contains both decisions with their confidences (multiplied by 100 and
/// shown as percentages), at most five changes ordered by descending `abs_delta()`,
/// a count of any further changes, and the stored `sparsity` and `distance` values.
/// Changes without a feature name are labelled `feature[<index>]`.
pub fn explain(&self) -> String {
    use std::fmt::Write as _;

    /// Maximum number of individual changes spelled out in the summary.
    const MAX_LISTED: usize = 5;

    let mut text = String::new();

    // Formatting into a `String` cannot fail, so the `fmt::Result`s are discarded.
    let _ = write!(
        text,
        "Original decision: {} (confidence: {:.1}%)\n",
        self.original_decision,
        self.original_confidence * 100.0
    );
    let _ = write!(
        text,
        "Alternative decision: {} (confidence: {:.1}%)\n",
        self.alternative_decision,
        self.alternative_confidence * 100.0
    );
    let _ = write!(
        text,
        "\nThe decision would have been {} if:\n",
        self.alternative_decision
    );

    // Rank by magnitude, largest first. The sort is stable, so incomparable (NaN)
    // magnitudes keep their relative order.
    let mut ranked: Vec<_> = self.changes.iter().collect();
    ranked.sort_by(|lhs, rhs| {
        rhs.abs_delta()
            .partial_cmp(&lhs.abs_delta())
            .unwrap_or(std::cmp::Ordering::Equal)
    });

    for change in ranked.into_iter().take(MAX_LISTED) {
        let sign = if change.delta >= 0.0 { "+" } else { "" };
        // Only build the fallback label when no explicit name is attached.
        let generated;
        let label = match change.feature_name.as_deref() {
            Some(given) => given,
            None => {
                generated = format!("feature[{}]", change.feature_idx);
                generated.as_str()
            }
        };
        let _ = write!(
            text,
            " - {}: {:.4} → {:.4} ({}{:.4})\n",
            label, change.original_value, change.counterfactual_value, sign, change.delta
        );
    }

    let hidden = self.changes.len().saturating_sub(MAX_LISTED);
    if hidden > 0 {
        let _ = write!(text, " ... and {} more changes\n", hidden);
    }

    let _ = write!(text, "\nSparsity (L1): {:.4}\n", self.sparsity);
    let _ = write!(text, "Distance (L2): {:.4}\n", self.distance);
    text
}
/// Returns the number of entries in `changes`.
pub fn n_changes(&self) -> usize {
self.changes.len()
}
/// Returns `true` when the alternative decision differs from the original decision.
///
/// Only the two decision values are compared; inputs and confidences are not inspected.
pub fn is_valid(&self) -> bool {
self.original_decision != self.alternative_decision
}
/// Attaches human-readable names to the recorded changes and returns `self`.
///
/// Each change looks up `names[feature_idx]`. A change whose index lies outside
/// `names` keeps whatever name it already had.
pub fn with_feature_names(mut self, names: &[String]) -> Self {
    for entry in self.changes.iter_mut() {
        if let Some(label) = names.get(entry.feature_idx) {
            entry.feature_name = Some(label.clone());
        }
    }
    self
}
/// Serializes this counterfactual into the version-1 little-endian binary layout.
///
/// Layout, in order:
/// 1. version byte (`1`)
/// 2. original decision (`u32`), original confidence (`f32`)
/// 3. alternative decision (`u32`), alternative confidence (`f32`)
/// 4. feature count (`u32`, taken from `original_input.len()`), then every
///    `original_input` value followed by every `counterfactual_input` value (`f32` each)
/// 5. change count (`u32`), then per change: feature index (`u32`), original value,
///    counterfactual value, delta (`f32` each), name length in bytes (`u32`) and the
///    UTF-8 name bytes; a change without a name writes a length of `0`
/// 6. sparsity (`f32`), distance (`f32`)
///
/// Integer fields are narrowed with `as u32`, so values above `u32::MAX` are
/// truncated rather than rejected.
pub fn to_bytes(&self) -> Vec<u8> {
    /// Appends `value` as four little-endian bytes.
    fn put_u32(buf: &mut Vec<u8>, value: u32) {
        buf.extend_from_slice(&value.to_le_bytes());
    }

    /// Appends `value` as four little-endian bytes.
    fn put_f32(buf: &mut Vec<u8>, value: f32) {
        buf.extend_from_slice(&value.to_le_bytes());
    }

    // Starts with the format version byte.
    let mut out = vec![1u8];

    put_u32(&mut out, self.original_decision as u32);
    put_f32(&mut out, self.original_confidence);
    put_u32(&mut out, self.alternative_decision as u32);
    put_f32(&mut out, self.alternative_confidence);

    // A single count prefixes both input vectors.
    put_u32(&mut out, self.original_input.len() as u32);
    let all_inputs = self
        .original_input
        .iter()
        .chain(self.counterfactual_input.iter());
    for &value in all_inputs {
        put_f32(&mut out, value);
    }

    put_u32(&mut out, self.changes.len() as u32);
    for change in &self.changes {
        put_u32(&mut out, change.feature_idx as u32);
        put_f32(&mut out, change.original_value);
        put_f32(&mut out, change.counterfactual_value);
        put_f32(&mut out, change.delta);

        // A missing name is encoded exactly like an empty one: length 0, no bytes.
        let name = change.feature_name.as_deref().unwrap_or("");
        put_u32(&mut out, name.len() as u32);
        out.extend_from_slice(name.as_bytes());
    }

    put_f32(&mut out, self.sparsity);
    put_f32(&mut out, self.distance);
    out
}
/// Decodes a counterfactual from the version-1 little-endian layout produced by
/// `to_bytes`.
///
/// The element counts embedded in the buffer are treated as untrusted: they are
/// checked against the buffer size before any memory is reserved, so a short,
/// corrupted buffer cannot trigger an enormous allocation. Bytes after the encoded
/// value are ignored.
///
/// # Errors
///
/// - `CounterfactualError::InsufficientData` when the buffer is shorter than the
///   21-byte fixed header, too short for the declared number of input features, or
///   ends before all declared fields have been read.
/// - `CounterfactualError::VersionMismatch` when the leading version byte is not `1`.
pub fn from_bytes(bytes: &[u8]) -> Result<Self, CounterfactualError> {
    /// Version byte + two decisions + two confidences + feature count.
    const HEADER_LEN: usize = 21;
    /// Encoded size of one `f32`.
    const F32_LEN: usize = 4;
    /// Smallest encoded change: index, three `f32`s and a zero name length.
    const MIN_CHANGE_LEN: usize = 20;

    if bytes.len() < HEADER_LEN {
        return Err(CounterfactualError::InsufficientData {
            expected: HEADER_LEN,
            actual: bytes.len(),
        });
    }

    let mut reader = ByteReader::new(bytes);

    let version = reader.read_u8()?;
    if version != 1 {
        return Err(CounterfactualError::VersionMismatch { expected: 1, actual: version });
    }

    let original_decision = reader.read_u32_as_usize()?;
    let original_confidence = reader.read_f32()?;
    let alternative_decision = reader.read_u32_as_usize()?;
    let alternative_confidence = reader.read_f32()?;

    // Both input vectors follow the header. Reject a declared feature count the
    // buffer cannot possibly satisfy before reading (and allocating for) them.
    // Saturating arithmetic keeps the check sound on 32-bit targets.
    let n_features = reader.read_u32_as_usize()?;
    let inputs_end = n_features
        .saturating_mul(2 * F32_LEN)
        .saturating_add(HEADER_LEN);
    if inputs_end > bytes.len() {
        return Err(CounterfactualError::InsufficientData {
            expected: inputs_end,
            actual: bytes.len(),
        });
    }
    let original_input = reader.read_f32_vec_n(n_features)?;
    let counterfactual_input = reader.read_f32_vec_n(n_features)?;

    // Never reserve room for more changes than the buffer could encode; a truncated
    // list is still reported by the reader as `InsufficientData`.
    let n_changes = reader.read_u32_as_usize()?;
    let mut changes = Vec::with_capacity(n_changes.min(bytes.len() / MIN_CHANGE_LEN));
    for _ in 0..n_changes {
        changes.push(reader.read_feature_change()?);
    }

    let sparsity = reader.read_f32()?;
    let distance = reader.read_f32()?;

    Ok(Self {
        original_input,
        original_decision,
        original_confidence,
        counterfactual_input,
        alternative_decision,
        alternative_confidence,
        changes,
        sparsity,
        distance,
    })
}
}
/// Forward-only cursor over a byte slice, used by `Counterfactual::from_bytes` to
/// decode little-endian fields.
///
/// Invariant: `offset <= data.len()`. The offset only advances after a successful
/// bounds check, so a failed read leaves the cursor where it was.
struct ByteReader<'a> {
    /// The complete buffer being decoded.
    data: &'a [u8],
    /// Index of the next unread byte.
    offset: usize,
}

impl<'a> ByteReader<'a> {
    /// Creates a reader positioned at the start of `data`.
    fn new(data: &'a [u8]) -> Self {
        Self { data, offset: 0 }
    }

    /// Number of bytes not yet consumed.
    fn remaining(&self) -> usize {
        self.data.len().saturating_sub(self.offset)
    }

    /// Consumes and returns the next `len` bytes, or fails without advancing.
    fn take(&mut self, len: usize) -> Result<&'a [u8], CounterfactualError> {
        self.ensure_available(len)?;
        // Copy the slice reference out so the result borrows the buffer, not `self`.
        let data: &'a [u8] = self.data;
        let start = self.offset;
        let end = start + len;
        self.offset = end;
        Ok(&data[start..end])
    }

    /// Consumes the next four bytes as a fixed-size array.
    fn read_4(&mut self) -> Result<[u8; 4], CounterfactualError> {
        let chunk = self.take(4)?;
        Ok([chunk[0], chunk[1], chunk[2], chunk[3]])
    }

    /// Reads a single byte.
    fn read_u8(&mut self) -> Result<u8, CounterfactualError> {
        let chunk = self.take(1)?;
        Ok(chunk[0])
    }

    /// Reads a little-endian `u32`.
    fn read_u32(&mut self) -> Result<u32, CounterfactualError> {
        Ok(u32::from_le_bytes(self.read_4()?))
    }

    /// Reads a little-endian `u32` and widens it to `usize`.
    fn read_u32_as_usize(&mut self) -> Result<usize, CounterfactualError> {
        Ok(self.read_u32()? as usize)
    }

    /// Reads a little-endian `f32`.
    fn read_f32(&mut self) -> Result<f32, CounterfactualError> {
        Ok(f32::from_le_bytes(self.read_4()?))
    }

    /// Reads `n` consecutive little-endian `f32` values.
    ///
    /// The whole span is bounds-checked before anything is allocated, so an oversized
    /// `n` taken from untrusted data yields `InsufficientData` instead of a huge
    /// allocation or a capacity-overflow panic.
    fn read_f32_vec_n(&mut self, n: usize) -> Result<Vec<f32>, CounterfactualError> {
        // Saturation makes an overflowing byte count fail the bounds check.
        let raw = self.take(n.saturating_mul(4))?;
        Ok(raw
            .chunks_exact(4)
            .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
            .collect())
    }

    /// Reads `len` bytes as text; invalid UTF-8 sequences are replaced with U+FFFD.
    fn read_string(&mut self, len: usize) -> Result<String, CounterfactualError> {
        let raw = self.take(len)?;
        Ok(String::from_utf8_lossy(raw).into_owned())
    }

    /// Reads one encoded feature change: index (`u32`), original value,
    /// counterfactual value and delta (`f32` each), then a `u32` name length followed
    /// by that many name bytes. A name length of `0` decodes to no name.
    fn read_feature_change(&mut self) -> Result<FeatureChange, CounterfactualError> {
        let feature_idx = self.read_u32_as_usize()?;
        let original_value = self.read_f32()?;
        let counterfactual_value = self.read_f32()?;
        let delta = self.read_f32()?;
        let name_len = self.read_u32_as_usize()?;
        let feature_name = if name_len > 0 {
            Some(self.read_string(name_len)?)
        } else {
            None
        };
        Ok(FeatureChange {
            feature_idx,
            feature_name,
            original_value,
            counterfactual_value,
            delta,
        })
    }

    /// Fails with `InsufficientData` unless at least `needed` unread bytes remain.
    ///
    /// The comparison is against the remaining length rather than `offset + needed`,
    /// so it cannot overflow for any `needed`.
    fn ensure_available(&self, needed: usize) -> Result<(), CounterfactualError> {
        if needed > self.remaining() {
            return Err(CounterfactualError::InsufficientData {
                expected: self.offset.saturating_add(needed),
                actual: self.data.len(),
            });
        }
        Ok(())
    }
}