use serde::{Deserialize, Serialize};
use super::traits::{ByteReader, DecisionPath, PathError};
/// Explanation data recorded for a single neural-network prediction.
///
/// Only `input_gradient`, `prediction` and `confidence` are mandatory; the
/// remaining attributions are optional and are attached with the `with_*`
/// builder methods.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct NeuralPath {
/// Gradient values indexed by input feature; used to rank salient inputs.
pub input_gradient: Vec<f32>,
/// Optional nested activation values.
// NOTE(review): presumably one inner vector per layer — confirm with the producer.
pub activations: Option<Vec<Vec<f32>>>,
/// Optional nested attention weights.
// NOTE(review): presumably one inner vector per layer/head — confirm with the producer.
pub attention_weights: Option<Vec<Vec<f32>>>,
/// Optional integrated-gradients attribution; when present it is what
/// `feature_contributions` returns instead of `input_gradient`.
pub integrated_gradients: Option<Vec<f32>>,
/// The model output being explained.
pub prediction: f32,
/// Confidence in the prediction; `explain` multiplies it by 100 and prints it
/// as a percentage.
// NOTE(review): presumably a fraction in [0, 1] — not validated anywhere in this file.
pub confidence: f32,
}
impl NeuralPath {
    /// Creates a path holding only the mandatory data.
    ///
    /// `activations`, `attention_weights` and `integrated_gradients` start as
    /// `None`; attach them with the `with_*` builder methods.
    #[must_use]
    pub fn new(input_gradient: Vec<f32>, prediction: f32, confidence: f32) -> Self {
        Self {
            input_gradient,
            activations: None,
            attention_weights: None,
            integrated_gradients: None,
            prediction,
            confidence,
        }
    }

    /// Attaches activation values, replacing any previously attached ones.
    #[must_use]
    pub fn with_activations(mut self, activations: Vec<Vec<f32>>) -> Self {
        self.activations = Some(activations);
        self
    }

    /// Attaches attention weights, replacing any previously attached ones.
    #[must_use]
    pub fn with_attention(mut self, attention: Vec<Vec<f32>>) -> Self {
        self.attention_weights = Some(attention);
        self
    }

    /// Attaches integrated gradients, replacing any previously attached ones.
    #[must_use]
    pub fn with_integrated_gradients(mut self, ig: Vec<f32>) -> Self {
        self.integrated_gradients = Some(ig);
        self
    }

    /// Returns up to `k` `(index, gradient)` pairs from `input_gradient`,
    /// ordered by descending absolute gradient.
    ///
    /// Entries with equal magnitude keep their original index order (the sort
    /// is stable). NaN gradients carry no usable magnitude, so they are ranked
    /// after every real value instead of being compared as "equal" to
    /// everything, which would make the ranking depend on where the NaN sits
    /// and hand `sort_by` an inconsistent comparator.
    #[must_use]
    pub fn top_salient_features(&self, k: usize) -> Vec<(usize, f32)> {
        // Sort key: the magnitude, with NaN mapped below every real magnitude
        // (all of which are >= 0) so that the key forms a total order.
        fn salience(gradient: f32) -> f32 {
            if gradient.is_nan() {
                f32::NEG_INFINITY
            } else {
                gradient.abs()
            }
        }

        let mut ranked: Vec<(usize, f32)> =
            self.input_gradient.iter().copied().enumerate().collect();
        // Arguments are swapped (`b` before `a`) to get descending order.
        ranked.sort_by(|a, b| salience(b.1).total_cmp(&salience(a.1)));
        ranked.truncate(k);
        ranked
    }
}
impl DecisionPath for NeuralPath {
    /// Renders a human-readable, multi-line summary of the prediction.
    ///
    /// The summary contains the prediction (4 decimals), the confidence as a
    /// percentage, the five most salient input gradients with an explicit
    /// sign, and a note for each optional attribution that is present
    /// (integrated gradients, attention weights).
    fn explain(&self) -> String {
        let mut explanation = format!(
            "Neural Network Prediction: {:.4} (confidence: {:.1}%)\n",
            self.prediction,
            self.confidence * 100.0
        );
        explanation.push_str("\nTop salient input features (by gradient):\n");
        for (idx, grad) in self.top_salient_features(5) {
            // The `+` flag lets the formatter choose the sign, so negative
            // zero renders as "-0.000000" rather than "+-0.000000"; NaN is
            // still printed without a sign.
            explanation.push_str(&format!(" input[{idx}]: {grad:+.6}\n"));
        }
        if let Some(ig) = &self.integrated_gradients {
            explanation.push_str(&format!(
                "\nIntegrated gradients available ({} features)\n",
                ig.len()
            ));
        }
        if self.attention_weights.is_some() {
            explanation.push_str("\nAttention weights available\n");
        }
        explanation
    }

    /// Returns the integrated gradients when present, otherwise the raw input
    /// gradient.
    fn feature_contributions(&self) -> &[f32] {
        self.integrated_gradients
            .as_deref()
            .unwrap_or(&self.input_gradient)
    }

    /// Returns the stored confidence value unchanged.
    fn confidence(&self) -> f32 {
        self.confidence
    }

    /// Serializes the path into the version-1 binary layout.
    ///
    /// Layout (all integers and floats little-endian):
    ///
    /// 1. version byte (`1`)
    /// 2. `input_gradient`: `u32` length, then that many `f32`
    /// 3. `prediction`: `f32`
    /// 4. `confidence`: `f32`
    /// 5. `activations`: presence byte (`0`/`1`); if present, `u32` outer
    ///    length, then each inner vector as `u32` length + `f32` values
    /// 6. `attention_weights`: same encoding as `activations`
    /// 7. `integrated_gradients`: presence byte; if present, `u32` length +
    ///    `f32` values
    ///
    /// # Panics
    ///
    /// Panics if any vector holds more than `u32::MAX` elements: its length
    /// cannot be represented in the prefix, and silently truncating it would
    /// produce a payload that cannot be decoded.
    fn to_bytes(&self) -> Vec<u8> {
        // Appends `len` as the `u32` little-endian length prefix.
        fn put_len(out: &mut Vec<u8>, len: usize) {
            let prefix = <u32 as std::convert::TryFrom<usize>>::try_from(len)
                .expect("NeuralPath sequence length does not fit in a u32 length prefix");
            out.extend_from_slice(&prefix.to_le_bytes());
        }

        // Appends a length-prefixed run of `f32` values.
        fn put_f32s(out: &mut Vec<u8>, values: &[f32]) {
            put_len(out, values.len());
            out.reserve(values.len() * 4);
            for value in values {
                out.extend_from_slice(&value.to_le_bytes());
            }
        }

        // Appends a length-prefixed list of length-prefixed `f32` runs.
        fn put_nested(out: &mut Vec<u8>, rows: &[Vec<f32>]) {
            put_len(out, rows.len());
            for row in rows {
                put_f32s(out, row);
            }
        }

        let mut bytes: Vec<u8> = Vec::new();
        // Format version; `from_bytes` rejects anything else.
        bytes.push(1);
        put_f32s(&mut bytes, &self.input_gradient);
        bytes.extend_from_slice(&self.prediction.to_le_bytes());
        bytes.extend_from_slice(&self.confidence.to_le_bytes());

        // Each optional section is a presence byte followed by its payload.
        bytes.push(u8::from(self.activations.is_some()));
        if let Some(activations) = &self.activations {
            put_nested(&mut bytes, activations);
        }
        bytes.push(u8::from(self.attention_weights.is_some()));
        if let Some(attention) = &self.attention_weights {
            put_nested(&mut bytes, attention);
        }
        bytes.push(u8::from(self.integrated_gradients.is_some()));
        if let Some(ig) = &self.integrated_gradients {
            put_f32s(&mut bytes, ig);
        }
        bytes
    }

    /// Decodes a path from the layout written by `to_bytes`.
    ///
    /// # Errors
    ///
    /// * `PathError::InsufficientData` if `bytes` is shorter than 5 bytes
    ///   (version byte plus the first length prefix).
    /// * `PathError::VersionMismatch` if the version byte is not `1`.
    /// * Any error reported by `ByteReader` while reading the fields.
    fn from_bytes(bytes: &[u8]) -> Result<Self, PathError> {
        if bytes.len() < 5 {
            return Err(PathError::InsufficientData {
                expected: 5,
                actual: bytes.len(),
            });
        }
        let mut reader = ByteReader::new(bytes);
        let version = reader.read_u8()?;
        if version != 1 {
            return Err(PathError::VersionMismatch {
                expected: 1,
                actual: version,
            });
        }
        // Field order below must mirror the write order in `to_bytes`.
        let input_gradient = reader.read_f32_vec()?;
        let prediction = reader.read_f32()?;
        let confidence = reader.read_f32()?;
        let activations = reader.read_optional(ByteReader::read_nested_f32_vecs)?;
        let attention_weights = reader.read_optional(ByteReader::read_nested_f32_vecs)?;
        let integrated_gradients = reader.read_optional(ByteReader::read_f32_vec)?;
        Ok(Self {
            input_gradient,
            activations,
            attention_weights,
            integrated_gradients,
            prediction,
            confidence,
        })
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for `NeuralPath`: construction, saliency ranking, the
    //! textual explanation, and the binary round trip.

    use super::*;

    #[test]
    fn test_neural_path_new() {
        let path = NeuralPath::new(vec![0.1, -0.2, 0.3], 0.87, 0.92);
        assert_eq!(path.input_gradient, vec![0.1, -0.2, 0.3]);
        assert_eq!(path.prediction, 0.87);
        assert_eq!(path.confidence, 0.92);
        // Optional attributions are absent until a `with_*` builder adds them.
        assert!(path.activations.is_none());
        assert!(path.attention_weights.is_none());
        assert!(path.integrated_gradients.is_none());
    }

    #[test]
    fn test_neural_path_top_salient() {
        let path = NeuralPath::new(vec![0.1, -0.5, 0.3], 0.0, 0.0);
        let top = path.top_salient_features(2);
        // Ranked by magnitude, but the signed gradient is what gets returned.
        assert_eq!(top, vec![(1, -0.5), (2, 0.3)]);
    }

    #[test]
    fn test_neural_path_serialization_roundtrip() {
        let path = NeuralPath::new(vec![0.1, -0.2, 0.3], 0.87, 0.92)
            .with_activations(vec![vec![0.5, 0.6], vec![0.7, 0.8]])
            .with_attention(vec![vec![0.1, 0.9]])
            .with_integrated_gradients(vec![0.15, -0.25, 0.35]);
        let bytes = path.to_bytes();
        let restored = NeuralPath::from_bytes(&bytes).expect("Failed to deserialize");
        // Values are stored as raw little-endian bytes, so the round trip is
        // expected to be bit-exact; compare contents, not just presence.
        assert_eq!(restored.input_gradient, path.input_gradient);
        assert_eq!(restored.prediction, path.prediction);
        assert_eq!(restored.confidence, path.confidence);
        assert_eq!(restored.activations, path.activations);
        assert_eq!(restored.attention_weights, path.attention_weights);
        assert_eq!(restored.integrated_gradients, path.integrated_gradients);
    }

    #[test]
    fn test_neural_path_feature_contributions() {
        let path = NeuralPath::new(vec![0.1, -0.2, 0.3], 0.0, 0.0);
        assert_eq!(path.feature_contributions(), &[0.1, -0.2, 0.3]);
        let path_with_ig = NeuralPath::new(vec![0.1, -0.2, 0.3], 0.0, 0.0)
            .with_integrated_gradients(vec![0.5, 0.5]);
        assert_eq!(path_with_ig.feature_contributions(), &[0.5, 0.5]);
    }

    #[test]
    fn test_neural_path_invalid_version() {
        let result = NeuralPath::from_bytes(&[2u8, 0, 0, 0, 0]);
        assert!(matches!(
            result,
            Err(PathError::VersionMismatch {
                expected: 1,
                actual: 2
            })
        ));
    }

    #[test]
    fn test_neural_path_insufficient_data() {
        let result = NeuralPath::from_bytes(&[1u8, 0, 0]);
        assert!(matches!(
            result,
            Err(PathError::InsufficientData {
                expected: 5,
                actual: 3
            })
        ));
    }

    #[test]
    fn test_neural_path_explain_with_ig() {
        let path =
            NeuralPath::new(vec![0.1], 0.5, 0.9).with_integrated_gradients(vec![0.2, 0.3, 0.5]);
        let explanation = path.explain();
        assert!(explanation.contains("Integrated gradients available (3 features)"));
        assert!(!explanation.contains("Attention weights"));
    }

    #[test]
    fn test_neural_path_explain_with_attention() {
        let path = NeuralPath::new(vec![0.1], 0.5, 0.9).with_attention(vec![vec![0.5, 0.5]]);
        let explanation = path.explain();
        assert!(explanation.contains("Attention weights available"));
        assert!(!explanation.contains("Integrated gradients"));
    }

    #[test]
    fn test_neural_path_serialization_minimal() {
        let path = NeuralPath::new(vec![0.1, 0.2], 0.5, 0.9);
        let bytes = path.to_bytes();
        let restored = NeuralPath::from_bytes(&bytes).expect("Failed to deserialize");
        assert_eq!(restored.input_gradient, path.input_gradient);
        assert_eq!(restored.prediction, path.prediction);
        assert_eq!(restored.confidence, path.confidence);
        assert!(restored.activations.is_none());
        assert!(restored.attention_weights.is_none());
        assert!(restored.integrated_gradients.is_none());
    }

    #[test]
    fn test_neural_path_with_activations() {
        let path = NeuralPath::new(vec![0.1], 0.5, 0.9)
            .with_activations(vec![vec![1.0, 2.0], vec![3.0, 4.0]]);
        assert_eq!(
            path.activations,
            Some(vec![vec![1.0, 2.0], vec![3.0, 4.0]])
        );
    }

    #[test]
    fn test_neural_path_confidence_method() {
        let path = NeuralPath::new(vec![0.1], 0.5, 0.85);
        // Fully qualified so the trait method, not the field, is exercised.
        assert_eq!(DecisionPath::confidence(&path), 0.85);
    }

    #[test]
    fn test_neural_path_explain_basic() {
        let path = NeuralPath::new(vec![0.1, -0.2, 0.3], 0.75, 0.90);
        let explanation = path.explain();
        assert!(explanation.contains("Neural Network Prediction"));
        assert!(explanation.contains("0.75"));
        assert!(explanation.contains("90.0%"));
        assert!(explanation.contains("Top salient input features"));
        // Each listed gradient carries an explicit sign.
        assert!(explanation.contains("input[2]: +0.300000"));
        assert!(explanation.contains("input[1]: -0.200000"));
        assert!(explanation.contains("input[0]: +0.100000"));
    }

    #[test]
    fn test_neural_path_top_salient_features_empty() {
        let path = NeuralPath::new(vec![], 0.5, 0.9);
        let top = path.top_salient_features(5);
        assert!(top.is_empty());
    }

    #[test]
    fn test_neural_path_top_salient_features_more_than_available() {
        let path = NeuralPath::new(vec![0.1, 0.2], 0.5, 0.9);
        let top = path.top_salient_features(10);
        assert_eq!(top, vec![(1, 0.2), (0, 0.1)]);
    }
}