use super::traits::{DecisionPath, PathError};
use serde::{Deserialize, Serialize};

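/// Gradient-based explanation record for a single neural-network prediction:
/// the raw input gradient plus optional per-layer activations, attention
/// weights, and integrated-gradients attributions.
///
/// A minimal construction sketch using the builder-style `with_*` methods
/// (values are illustrative only):
///
/// ```ignore
/// let path = NeuralPath::new(vec![0.1, -0.2, 0.3], 0.87, 0.92)
///     .with_integrated_gradients(vec![0.15, -0.25, 0.35]);
/// println!("{}", path.explain());
/// ```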
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct NeuralPath {
    /// Gradient of the model output with respect to each input feature.
    pub input_gradient: Vec<f32>,
    /// Optional per-layer activations captured during the forward pass.
    pub activations: Option<Vec<Vec<f32>>>,
    /// Optional per-layer attention weights.
    pub attention_weights: Option<Vec<Vec<f32>>>,
    /// Optional integrated-gradients attributions; used by
    /// `feature_contributions` in preference to the raw gradient when present.
    pub integrated_gradients: Option<Vec<f32>>,
    /// Model output value.
    pub prediction: f32,
    /// Model confidence, as a fraction in `[0, 1]`.
    pub confidence: f32,
}

impl NeuralPath {
    /// Creates a path from the raw input gradient, the model prediction, and
    /// its confidence.
    pub fn new(input_gradient: Vec<f32>, prediction: f32, confidence: f32) -> Self {
        Self {
            input_gradient,
            activations: None,
            attention_weights: None,
            integrated_gradients: None,
            prediction,
            confidence,
        }
    }

    /// Attaches per-layer activations.
    pub fn with_activations(mut self, activations: Vec<Vec<f32>>) -> Self {
        self.activations = Some(activations);
        self
    }

    /// Attaches per-layer attention weights.
    pub fn with_attention(mut self, attention: Vec<Vec<f32>>) -> Self {
        self.attention_weights = Some(attention);
        self
    }

    /// Attaches integrated-gradients attributions.
    pub fn with_integrated_gradients(mut self, ig: Vec<f32>) -> Self {
        self.integrated_gradients = Some(ig);
        self
    }

    /// Returns up to `k` input features ranked by absolute gradient magnitude,
    /// as `(index, gradient)` pairs in descending order of salience.
    pub fn top_salient_features(&self, k: usize) -> Vec<(usize, f32)> {
        let mut indexed: Vec<(usize, f32)> = self
            .input_gradient
            .iter()
            .enumerate()
            .map(|(i, &g)| (i, g))
            .collect();

        indexed.sort_by(|a, b| {
            b.1.abs()
                .partial_cmp(&a.1.abs())
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        indexed.truncate(k);
        indexed
    }
}
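
// `DecisionPath` implementation: `explain` builds a human-readable summary,
// `feature_contributions` returns integrated gradients when available and the
// raw input gradient otherwise, and `to_bytes`/`from_bytes` implement the
// versioned little-endian encoding documented above `to_bytes`.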
impl DecisionPath for NeuralPath {
    fn explain(&self) -> String {
        let mut explanation = format!(
            "Neural Network Prediction: {:.4} (confidence: {:.1}%)\n",
            self.prediction,
            self.confidence * 100.0
        );

        explanation.push_str("\nTop salient input features (by gradient):\n");
        for (idx, grad) in self.top_salient_features(5) {
            let sign = if grad >= 0.0 { "+" } else { "" };
            explanation.push_str(&format!(" input[{idx}]: {sign}{grad:.6}\n"));
        }

        if let Some(ig) = &self.integrated_gradients {
            explanation.push_str(&format!(
                "\nIntegrated gradients available ({} features)\n",
                ig.len()
            ));
        }

        if self.attention_weights.is_some() {
            explanation.push_str("\nAttention weights available\n");
        }

        explanation
    }

    fn feature_contributions(&self) -> &[f32] {
        self.integrated_gradients
            .as_deref()
            .unwrap_or(&self.input_gradient)
    }

    fn confidence(&self) -> f32 {
        self.confidence
    }

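    // Binary layout produced by `to_bytes` (all multi-byte values are
    // little-endian):
    //   version: u8 (currently 1)
    //   input_gradient: u32 length, then that many f32 values
    //   prediction: f32, confidence: f32
    //   activations: u8 presence flag; if set, u32 layer count, then per layer
    //     a u32 length followed by that many f32 values
    //   attention_weights: u8 presence flag; same nested layout as activations
    //   integrated_gradients: u8 presence flag; if set, u32 length, then that
    //     many f32 values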
    fn to_bytes(&self) -> Vec<u8> {
        let mut bytes = Vec::new();
        // Format version byte.
        bytes.push(1);

        // Input gradient: length followed by the values.
        bytes.extend_from_slice(&(self.input_gradient.len() as u32).to_le_bytes());
        for g in &self.input_gradient {
            bytes.extend_from_slice(&g.to_le_bytes());
        }

        // Prediction and confidence.
        bytes.extend_from_slice(&self.prediction.to_le_bytes());
        bytes.extend_from_slice(&self.confidence.to_le_bytes());

        // Optional per-layer activations.
        let has_activations = self.activations.is_some();
        bytes.push(u8::from(has_activations));
        if let Some(activations) = &self.activations {
            bytes.extend_from_slice(&(activations.len() as u32).to_le_bytes());
            for layer in activations {
                bytes.extend_from_slice(&(layer.len() as u32).to_le_bytes());
                for a in layer {
                    bytes.extend_from_slice(&a.to_le_bytes());
                }
            }
        }

        // Optional per-layer attention weights.
        let has_attention = self.attention_weights.is_some();
        bytes.push(u8::from(has_attention));
        if let Some(attention) = &self.attention_weights {
            bytes.extend_from_slice(&(attention.len() as u32).to_le_bytes());
            for layer in attention {
                bytes.extend_from_slice(&(layer.len() as u32).to_le_bytes());
                for a in layer {
                    bytes.extend_from_slice(&a.to_le_bytes());
                }
            }
        }

        // Optional integrated gradients.
        let has_ig = self.integrated_gradients.is_some();
        bytes.push(u8::from(has_ig));
        if let Some(ig) = &self.integrated_gradients {
            bytes.extend_from_slice(&(ig.len() as u32).to_le_bytes());
            for g in ig {
                bytes.extend_from_slice(&g.to_le_bytes());
            }
        }

        bytes
    }

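    // Decodes the layout written by `to_bytes`: the version byte is validated
    // first, and every subsequent read is bounds-checked so that truncated
    // input surfaces as `PathError::InsufficientData` instead of a panic.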
    fn from_bytes(bytes: &[u8]) -> Result<Self, PathError> {
        if bytes.len() < 5 {
            return Err(PathError::InsufficientData {
                expected: 5,
                actual: bytes.len(),
            });
        }

        let version = bytes[0];
        if version != 1 {
            return Err(PathError::VersionMismatch {
                expected: 1,
                actual: version,
            });
        }

        let mut offset = 1;

        let n_grad = u32::from_le_bytes([
            bytes[offset],
            bytes[offset + 1],
            bytes[offset + 2],
            bytes[offset + 3],
        ]) as usize;
        offset += 4;

        let mut input_gradient = Vec::with_capacity(n_grad);
        for _ in 0..n_grad {
            if offset + 4 > bytes.len() {
                return Err(PathError::InsufficientData {
                    expected: offset + 4,
                    actual: bytes.len(),
                });
            }
            let g = f32::from_le_bytes([
                bytes[offset],
                bytes[offset + 1],
                bytes[offset + 2],
                bytes[offset + 3],
            ]);
            offset += 4;
            input_gradient.push(g);
        }

        if offset + 8 > bytes.len() {
            return Err(PathError::InsufficientData {
                expected: offset + 8,
                actual: bytes.len(),
            });
        }
        let prediction = f32::from_le_bytes([
            bytes[offset],
            bytes[offset + 1],
            bytes[offset + 2],
            bytes[offset + 3],
        ]);
        offset += 4;

        let confidence = f32::from_le_bytes([
            bytes[offset],
            bytes[offset + 1],
            bytes[offset + 2],
            bytes[offset + 3],
        ]);
        offset += 4;

        if offset + 1 > bytes.len() {
            return Err(PathError::InsufficientData {
                expected: offset + 1,
                actual: bytes.len(),
            });
        }
        let has_activations = bytes[offset] != 0;
        offset += 1;

        let activations = if has_activations {
            if offset + 4 > bytes.len() {
                return Err(PathError::InsufficientData {
                    expected: offset + 4,
                    actual: bytes.len(),
                });
            }
            let n_layers = u32::from_le_bytes([
                bytes[offset],
                bytes[offset + 1],
                bytes[offset + 2],
                bytes[offset + 3],
            ]) as usize;
            offset += 4;

            let mut layers = Vec::with_capacity(n_layers);
            for _ in 0..n_layers {
                if offset + 4 > bytes.len() {
                    return Err(PathError::InsufficientData {
                        expected: offset + 4,
                        actual: bytes.len(),
                    });
                }
                let layer_len = u32::from_le_bytes([
                    bytes[offset],
                    bytes[offset + 1],
                    bytes[offset + 2],
                    bytes[offset + 3],
                ]) as usize;
                offset += 4;

                let mut layer = Vec::with_capacity(layer_len);
                for _ in 0..layer_len {
                    if offset + 4 > bytes.len() {
                        return Err(PathError::InsufficientData {
                            expected: offset + 4,
                            actual: bytes.len(),
                        });
                    }
                    let a = f32::from_le_bytes([
                        bytes[offset],
                        bytes[offset + 1],
                        bytes[offset + 2],
                        bytes[offset + 3],
                    ]);
                    offset += 4;
                    layer.push(a);
                }
                layers.push(layer);
            }
            Some(layers)
        } else {
            None
        };

        if offset + 1 > bytes.len() {
            return Err(PathError::InsufficientData {
                expected: offset + 1,
                actual: bytes.len(),
            });
        }
        let has_attention = bytes[offset] != 0;
        offset += 1;

        let attention_weights = if has_attention {
            if offset + 4 > bytes.len() {
                return Err(PathError::InsufficientData {
                    expected: offset + 4,
                    actual: bytes.len(),
                });
            }
            let n_layers = u32::from_le_bytes([
                bytes[offset],
                bytes[offset + 1],
                bytes[offset + 2],
                bytes[offset + 3],
            ]) as usize;
            offset += 4;

            let mut layers = Vec::with_capacity(n_layers);
            for _ in 0..n_layers {
                if offset + 4 > bytes.len() {
                    return Err(PathError::InsufficientData {
                        expected: offset + 4,
                        actual: bytes.len(),
                    });
                }
                let layer_len = u32::from_le_bytes([
                    bytes[offset],
                    bytes[offset + 1],
                    bytes[offset + 2],
                    bytes[offset + 3],
                ]) as usize;
                offset += 4;

                let mut layer = Vec::with_capacity(layer_len);
                for _ in 0..layer_len {
                    if offset + 4 > bytes.len() {
                        return Err(PathError::InsufficientData {
                            expected: offset + 4,
                            actual: bytes.len(),
                        });
                    }
                    let a = f32::from_le_bytes([
                        bytes[offset],
                        bytes[offset + 1],
                        bytes[offset + 2],
                        bytes[offset + 3],
                    ]);
                    offset += 4;
                    layer.push(a);
                }
                layers.push(layer);
            }
            Some(layers)
        } else {
            None
        };

        if offset + 1 > bytes.len() {
            return Err(PathError::InsufficientData {
                expected: offset + 1,
                actual: bytes.len(),
            });
        }
        let has_ig = bytes[offset] != 0;
        offset += 1;

        let integrated_gradients = if has_ig {
            if offset + 4 > bytes.len() {
                return Err(PathError::InsufficientData {
                    expected: offset + 4,
                    actual: bytes.len(),
                });
            }
            let n_ig = u32::from_le_bytes([
                bytes[offset],
                bytes[offset + 1],
                bytes[offset + 2],
                bytes[offset + 3],
            ]) as usize;
            offset += 4;

            let mut ig = Vec::with_capacity(n_ig);
            for _ in 0..n_ig {
                if offset + 4 > bytes.len() {
                    return Err(PathError::InsufficientData {
                        expected: offset + 4,
                        actual: bytes.len(),
                    });
                }
                let g = f32::from_le_bytes([
                    bytes[offset],
                    bytes[offset + 1],
                    bytes[offset + 2],
                    bytes[offset + 3],
                ]);
                offset += 4;
                ig.push(g);
            }
            Some(ig)
        } else {
            None
        };

        Ok(Self {
            input_gradient,
            activations,
            attention_weights,
            integrated_gradients,
            prediction,
            confidence,
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_neural_path_new() {
        let path = NeuralPath::new(vec![0.1, -0.2, 0.3], 0.87, 0.92);
        assert_eq!(path.input_gradient.len(), 3);
        assert_eq!(path.prediction, 0.87);
        assert_eq!(path.confidence, 0.92);
    }

    #[test]
    fn test_neural_path_top_salient() {
        let path = NeuralPath::new(vec![0.1, -0.5, 0.3], 0.0, 0.0);
        let top = path.top_salient_features(2);
        assert_eq!(top[0].0, 1);
        assert_eq!(top[1].0, 2);
    }

    #[test]
    fn test_neural_path_serialization_roundtrip() {
        let path = NeuralPath::new(vec![0.1, -0.2, 0.3], 0.87, 0.92)
            .with_activations(vec![vec![0.5, 0.6], vec![0.7, 0.8]])
            .with_attention(vec![vec![0.1, 0.9]])
            .with_integrated_gradients(vec![0.15, -0.25, 0.35]);

        let bytes = path.to_bytes();
        let restored = NeuralPath::from_bytes(&bytes).expect("Failed to deserialize");

        assert_eq!(path.input_gradient.len(), restored.input_gradient.len());
        assert!((path.prediction - restored.prediction).abs() < 1e-6);
        assert!((path.confidence - restored.confidence).abs() < 1e-6);
        assert!(restored.activations.is_some());
        assert!(restored.attention_weights.is_some());
        assert!(restored.integrated_gradients.is_some());
    }

    #[test]
    fn test_neural_path_feature_contributions() {
        let path = NeuralPath::new(vec![0.1, -0.2, 0.3], 0.0, 0.0);
        assert_eq!(path.feature_contributions(), &[0.1, -0.2, 0.3]);

        let path_with_ig = NeuralPath::new(vec![0.1, -0.2, 0.3], 0.0, 0.0)
            .with_integrated_gradients(vec![0.5, 0.5]);
        assert_eq!(path_with_ig.feature_contributions(), &[0.5, 0.5]);
    }

    #[test]
    fn test_neural_path_invalid_version() {
        let result = NeuralPath::from_bytes(&[2u8, 0, 0, 0, 0]);
        assert!(matches!(result, Err(PathError::VersionMismatch { .. })));
    }

    #[test]
    fn test_neural_path_insufficient_data() {
        let result = NeuralPath::from_bytes(&[1u8, 0, 0]);
        assert!(matches!(result, Err(PathError::InsufficientData { .. })));
    }

    #[test]
    fn test_neural_path_explain_with_ig() {
        let path =
            NeuralPath::new(vec![0.1], 0.5, 0.9).with_integrated_gradients(vec![0.2, 0.3, 0.5]);
        let explanation = path.explain();
        assert!(explanation.contains("Integrated gradients"));
        assert!(explanation.contains("3 features"));
    }

    #[test]
    fn test_neural_path_explain_with_attention() {
        let path = NeuralPath::new(vec![0.1], 0.5, 0.9).with_attention(vec![vec![0.5, 0.5]]);
        let explanation = path.explain();
        assert!(explanation.contains("Attention weights"));
    }

    #[test]
    fn test_neural_path_serialization_minimal() {
        let path = NeuralPath::new(vec![0.1, 0.2], 0.5, 0.9);
        let bytes = path.to_bytes();
        let restored = NeuralPath::from_bytes(&bytes).expect("Failed to deserialize");
        assert!(restored.activations.is_none());
        assert!(restored.attention_weights.is_none());
        assert!(restored.integrated_gradients.is_none());
    }

    #[test]
    fn test_neural_path_with_activations() {
        let path = NeuralPath::new(vec![0.1], 0.5, 0.9)
            .with_activations(vec![vec![1.0, 2.0], vec![3.0, 4.0]]);
        assert!(path.activations.is_some());
        let activations = path.activations.unwrap();
        assert_eq!(activations.len(), 2);
        assert_eq!(activations[0], vec![1.0, 2.0]);
        assert_eq!(activations[1], vec![3.0, 4.0]);
    }

    #[test]
    fn test_neural_path_confidence_method() {
        let path = NeuralPath::new(vec![0.1], 0.5, 0.85);
        assert_eq!(path.confidence(), 0.85);
    }

    #[test]
    fn test_neural_path_explain_basic() {
        let path = NeuralPath::new(vec![0.1, -0.2, 0.3], 0.75, 0.90);
        let explanation = path.explain();
        assert!(explanation.contains("Neural Network Prediction"));
        assert!(explanation.contains("0.75"));
        assert!(explanation.contains("90.0%"));
        assert!(explanation.contains("Top salient input features"));
    }

    #[test]
    fn test_neural_path_top_salient_features_empty() {
        let path = NeuralPath::new(vec![], 0.5, 0.9);
        let top = path.top_salient_features(5);
        assert!(top.is_empty());
    }

    #[test]
    fn test_neural_path_top_salient_features_more_than_available() {
        let path = NeuralPath::new(vec![0.1, 0.2], 0.5, 0.9);
        let top = path.top_salient_features(10);
        assert_eq!(top.len(), 2);
    }
}