entrenar/monitor/inference/path/
neural.rs

1//! Neural network decision path (gradient-based)
2
3use super::traits::{DecisionPath, PathError};
4use serde::{Deserialize, Serialize};
5
6/// Decision path for neural networks (gradient-based)
7#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct NeuralPath {
    /// Raw input gradient (saliency map), one value per input feature.
    pub input_gradient: Vec<f32>,
    /// Per-layer activations; `None` unless explicitly attached
    /// (kept optional because activations can dominate memory usage).
    pub activations: Option<Vec<Vec<f32>>>,
    /// Attention weights per layer (for transformer models); optional.
    pub attention_weights: Option<Vec<Vec<f32>>>,
    /// Integrated-gradients attribution; when present it is preferred
    /// over `input_gradient` as the feature-contribution source.
    pub integrated_gradients: Option<Vec<f32>>,
    /// Final model prediction (scalar output).
    pub prediction: f32,
    /// Confidence of the prediction (softmax probability, 0.0..=1.0 expected).
    pub confidence: f32,
}
22
23impl NeuralPath {
24    /// Create a new neural path
25    pub fn new(input_gradient: Vec<f32>, prediction: f32, confidence: f32) -> Self {
26        Self {
27            input_gradient,
28            activations: None,
29            attention_weights: None,
30            integrated_gradients: None,
31            prediction,
32            confidence,
33        }
34    }
35
36    /// Set layer activations
37    pub fn with_activations(mut self, activations: Vec<Vec<f32>>) -> Self {
38        self.activations = Some(activations);
39        self
40    }
41
42    /// Set attention weights
43    pub fn with_attention(mut self, attention: Vec<Vec<f32>>) -> Self {
44        self.attention_weights = Some(attention);
45        self
46    }
47
48    /// Set integrated gradients
49    pub fn with_integrated_gradients(mut self, ig: Vec<f32>) -> Self {
50        self.integrated_gradients = Some(ig);
51        self
52    }
53
54    /// Get top salient features by absolute gradient
55    pub fn top_salient_features(&self, k: usize) -> Vec<(usize, f32)> {
56        let mut indexed: Vec<(usize, f32)> = self
57            .input_gradient
58            .iter()
59            .enumerate()
60            .map(|(i, &g)| (i, g))
61            .collect();
62
63        indexed.sort_by(|a, b| {
64            b.1.abs()
65                .partial_cmp(&a.1.abs())
66                .unwrap_or(std::cmp::Ordering::Equal)
67        });
68        indexed.truncate(k);
69        indexed
70    }
71}
72
73impl DecisionPath for NeuralPath {
74    fn explain(&self) -> String {
75        let mut explanation = format!(
76            "Neural Network Prediction: {:.4} (confidence: {:.1}%)\n",
77            self.prediction,
78            self.confidence * 100.0
79        );
80
81        explanation.push_str("\nTop salient input features (by gradient):\n");
82        for (idx, grad) in self.top_salient_features(5) {
83            let sign = if grad >= 0.0 { "+" } else { "" };
84            explanation.push_str(&format!("  input[{idx}]: {sign}{grad:.6}\n"));
85        }
86
87        if let Some(ig) = &self.integrated_gradients {
88            explanation.push_str("\nIntegrated gradients available (");
89            let len = ig.len();
90            explanation.push_str(&format!("{len} features)\n"));
91        }
92
93        if self.attention_weights.is_some() {
94            explanation.push_str("\nAttention weights available\n");
95        }
96
97        explanation
98    }
99
100    fn feature_contributions(&self) -> &[f32] {
101        self.integrated_gradients
102            .as_deref()
103            .unwrap_or(&self.input_gradient)
104    }
105
106    fn confidence(&self) -> f32 {
107        self.confidence
108    }
109
110    fn to_bytes(&self) -> Vec<u8> {
111        let mut bytes = Vec::new();
112        bytes.push(1); // version
113
114        // Input gradient
115        bytes.extend_from_slice(&(self.input_gradient.len() as u32).to_le_bytes());
116        for g in &self.input_gradient {
117            bytes.extend_from_slice(&g.to_le_bytes());
118        }
119
120        // Prediction and confidence
121        bytes.extend_from_slice(&self.prediction.to_le_bytes());
122        bytes.extend_from_slice(&self.confidence.to_le_bytes());
123
124        // Activations
125        let has_activations = self.activations.is_some();
126        bytes.push(u8::from(has_activations));
127        if let Some(activations) = &self.activations {
128            bytes.extend_from_slice(&(activations.len() as u32).to_le_bytes());
129            for layer in activations {
130                bytes.extend_from_slice(&(layer.len() as u32).to_le_bytes());
131                for a in layer {
132                    bytes.extend_from_slice(&a.to_le_bytes());
133                }
134            }
135        }
136
137        // Attention weights
138        let has_attention = self.attention_weights.is_some();
139        bytes.push(u8::from(has_attention));
140        if let Some(attention) = &self.attention_weights {
141            bytes.extend_from_slice(&(attention.len() as u32).to_le_bytes());
142            for layer in attention {
143                bytes.extend_from_slice(&(layer.len() as u32).to_le_bytes());
144                for a in layer {
145                    bytes.extend_from_slice(&a.to_le_bytes());
146                }
147            }
148        }
149
150        // Integrated gradients
151        let has_ig = self.integrated_gradients.is_some();
152        bytes.push(u8::from(has_ig));
153        if let Some(ig) = &self.integrated_gradients {
154            bytes.extend_from_slice(&(ig.len() as u32).to_le_bytes());
155            for g in ig {
156                bytes.extend_from_slice(&g.to_le_bytes());
157            }
158        }
159
160        bytes
161    }
162
163    fn from_bytes(bytes: &[u8]) -> Result<Self, PathError> {
164        if bytes.len() < 5 {
165            return Err(PathError::InsufficientData {
166                expected: 5,
167                actual: bytes.len(),
168            });
169        }
170
171        let version = bytes[0];
172        if version != 1 {
173            return Err(PathError::VersionMismatch {
174                expected: 1,
175                actual: version,
176            });
177        }
178
179        let mut offset = 1;
180
181        // Input gradient
182        let n_grad = u32::from_le_bytes([
183            bytes[offset],
184            bytes[offset + 1],
185            bytes[offset + 2],
186            bytes[offset + 3],
187        ]) as usize;
188        offset += 4;
189
190        let mut input_gradient = Vec::with_capacity(n_grad);
191        for _ in 0..n_grad {
192            if offset + 4 > bytes.len() {
193                return Err(PathError::InsufficientData {
194                    expected: offset + 4,
195                    actual: bytes.len(),
196                });
197            }
198            let g = f32::from_le_bytes([
199                bytes[offset],
200                bytes[offset + 1],
201                bytes[offset + 2],
202                bytes[offset + 3],
203            ]);
204            offset += 4;
205            input_gradient.push(g);
206        }
207
208        // Prediction and confidence
209        if offset + 8 > bytes.len() {
210            return Err(PathError::InsufficientData {
211                expected: offset + 8,
212                actual: bytes.len(),
213            });
214        }
215        let prediction = f32::from_le_bytes([
216            bytes[offset],
217            bytes[offset + 1],
218            bytes[offset + 2],
219            bytes[offset + 3],
220        ]);
221        offset += 4;
222
223        let confidence = f32::from_le_bytes([
224            bytes[offset],
225            bytes[offset + 1],
226            bytes[offset + 2],
227            bytes[offset + 3],
228        ]);
229        offset += 4;
230
231        // Activations
232        if offset + 1 > bytes.len() {
233            return Err(PathError::InsufficientData {
234                expected: offset + 1,
235                actual: bytes.len(),
236            });
237        }
238        let has_activations = bytes[offset] != 0;
239        offset += 1;
240
241        let activations = if has_activations {
242            if offset + 4 > bytes.len() {
243                return Err(PathError::InsufficientData {
244                    expected: offset + 4,
245                    actual: bytes.len(),
246                });
247            }
248            let n_layers = u32::from_le_bytes([
249                bytes[offset],
250                bytes[offset + 1],
251                bytes[offset + 2],
252                bytes[offset + 3],
253            ]) as usize;
254            offset += 4;
255
256            let mut layers = Vec::with_capacity(n_layers);
257            for _ in 0..n_layers {
258                if offset + 4 > bytes.len() {
259                    return Err(PathError::InsufficientData {
260                        expected: offset + 4,
261                        actual: bytes.len(),
262                    });
263                }
264                let layer_len = u32::from_le_bytes([
265                    bytes[offset],
266                    bytes[offset + 1],
267                    bytes[offset + 2],
268                    bytes[offset + 3],
269                ]) as usize;
270                offset += 4;
271
272                let mut layer = Vec::with_capacity(layer_len);
273                for _ in 0..layer_len {
274                    if offset + 4 > bytes.len() {
275                        return Err(PathError::InsufficientData {
276                            expected: offset + 4,
277                            actual: bytes.len(),
278                        });
279                    }
280                    let a = f32::from_le_bytes([
281                        bytes[offset],
282                        bytes[offset + 1],
283                        bytes[offset + 2],
284                        bytes[offset + 3],
285                    ]);
286                    offset += 4;
287                    layer.push(a);
288                }
289                layers.push(layer);
290            }
291            Some(layers)
292        } else {
293            None
294        };
295
296        // Attention weights (similar pattern)
297        if offset + 1 > bytes.len() {
298            return Err(PathError::InsufficientData {
299                expected: offset + 1,
300                actual: bytes.len(),
301            });
302        }
303        let has_attention = bytes[offset] != 0;
304        offset += 1;
305
306        let attention_weights = if has_attention {
307            if offset + 4 > bytes.len() {
308                return Err(PathError::InsufficientData {
309                    expected: offset + 4,
310                    actual: bytes.len(),
311                });
312            }
313            let n_layers = u32::from_le_bytes([
314                bytes[offset],
315                bytes[offset + 1],
316                bytes[offset + 2],
317                bytes[offset + 3],
318            ]) as usize;
319            offset += 4;
320
321            let mut layers = Vec::with_capacity(n_layers);
322            for _ in 0..n_layers {
323                if offset + 4 > bytes.len() {
324                    return Err(PathError::InsufficientData {
325                        expected: offset + 4,
326                        actual: bytes.len(),
327                    });
328                }
329                let layer_len = u32::from_le_bytes([
330                    bytes[offset],
331                    bytes[offset + 1],
332                    bytes[offset + 2],
333                    bytes[offset + 3],
334                ]) as usize;
335                offset += 4;
336
337                let mut layer = Vec::with_capacity(layer_len);
338                for _ in 0..layer_len {
339                    if offset + 4 > bytes.len() {
340                        return Err(PathError::InsufficientData {
341                            expected: offset + 4,
342                            actual: bytes.len(),
343                        });
344                    }
345                    let a = f32::from_le_bytes([
346                        bytes[offset],
347                        bytes[offset + 1],
348                        bytes[offset + 2],
349                        bytes[offset + 3],
350                    ]);
351                    offset += 4;
352                    layer.push(a);
353                }
354                layers.push(layer);
355            }
356            Some(layers)
357        } else {
358            None
359        };
360
361        // Integrated gradients
362        if offset + 1 > bytes.len() {
363            return Err(PathError::InsufficientData {
364                expected: offset + 1,
365                actual: bytes.len(),
366            });
367        }
368        let has_ig = bytes[offset] != 0;
369        offset += 1;
370
371        let integrated_gradients = if has_ig {
372            if offset + 4 > bytes.len() {
373                return Err(PathError::InsufficientData {
374                    expected: offset + 4,
375                    actual: bytes.len(),
376                });
377            }
378            let n_ig = u32::from_le_bytes([
379                bytes[offset],
380                bytes[offset + 1],
381                bytes[offset + 2],
382                bytes[offset + 3],
383            ]) as usize;
384            offset += 4;
385
386            let mut ig = Vec::with_capacity(n_ig);
387            for _ in 0..n_ig {
388                if offset + 4 > bytes.len() {
389                    return Err(PathError::InsufficientData {
390                        expected: offset + 4,
391                        actual: bytes.len(),
392                    });
393                }
394                let g = f32::from_le_bytes([
395                    bytes[offset],
396                    bytes[offset + 1],
397                    bytes[offset + 2],
398                    bytes[offset + 3],
399                ]);
400                offset += 4;
401                ig.push(g);
402            }
403            Some(ig)
404        } else {
405            None
406        };
407
408        Ok(Self {
409            input_gradient,
410            activations,
411            attention_weights,
412            integrated_gradients,
413            prediction,
414            confidence,
415        })
416    }
417}
418
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_neural_path_new() {
        // Constructor copies the gradient and scalars through unchanged.
        let p = NeuralPath::new(vec![0.1, -0.2, 0.3], 0.87, 0.92);
        assert_eq!(p.input_gradient.len(), 3);
        assert_eq!(p.prediction, 0.87);
        assert_eq!(p.confidence, 0.92);
    }

    #[test]
    fn test_neural_path_top_salient() {
        let p = NeuralPath::new(vec![0.1, -0.5, 0.3], 0.0, 0.0);
        let ranked = p.top_salient_features(2);
        // Ranking is by |gradient|: -0.5 first, 0.3 second.
        assert_eq!(ranked[0].0, 1);
        assert_eq!(ranked[1].0, 2);
    }

    #[test]
    fn test_neural_path_serialization_roundtrip() {
        // Fully-populated path: every optional attribution attached.
        let original = NeuralPath::new(vec![0.1, -0.2, 0.3], 0.87, 0.92)
            .with_activations(vec![vec![0.5, 0.6], vec![0.7, 0.8]])
            .with_attention(vec![vec![0.1, 0.9]])
            .with_integrated_gradients(vec![0.15, -0.25, 0.35]);

        let restored =
            NeuralPath::from_bytes(&original.to_bytes()).expect("Failed to deserialize");

        assert_eq!(original.input_gradient.len(), restored.input_gradient.len());
        assert!((original.prediction - restored.prediction).abs() < 1e-6);
        assert!((original.confidence - restored.confidence).abs() < 1e-6);
        assert!(restored.activations.is_some());
        assert!(restored.attention_weights.is_some());
        assert!(restored.integrated_gradients.is_some());
    }

    #[test]
    fn test_neural_path_feature_contributions() {
        // Without integrated gradients, contributions fall back to the raw gradient.
        let plain = NeuralPath::new(vec![0.1, -0.2, 0.3], 0.0, 0.0);
        assert_eq!(plain.feature_contributions(), &[0.1, -0.2, 0.3]);

        // With integrated gradients attached, they take precedence.
        let attributed = NeuralPath::new(vec![0.1, -0.2, 0.3], 0.0, 0.0)
            .with_integrated_gradients(vec![0.5, 0.5]);
        assert_eq!(attributed.feature_contributions(), &[0.5, 0.5]);
    }

    #[test]
    fn test_neural_path_invalid_version() {
        // Version byte 2 is unknown; only version 1 is accepted.
        let outcome = NeuralPath::from_bytes(&[2u8, 0, 0, 0, 0]);
        assert!(matches!(outcome, Err(PathError::VersionMismatch { .. })));
    }

    #[test]
    fn test_neural_path_insufficient_data() {
        // Three bytes is shorter than the minimum header (version + length).
        let outcome = NeuralPath::from_bytes(&[1u8, 0, 0]);
        assert!(matches!(outcome, Err(PathError::InsufficientData { .. })));
    }

    #[test]
    fn test_neural_path_explain_with_ig() {
        let p =
            NeuralPath::new(vec![0.1], 0.5, 0.9).with_integrated_gradients(vec![0.2, 0.3, 0.5]);
        let text = p.explain();
        assert!(text.contains("Integrated gradients"));
        assert!(text.contains("3 features"));
    }

    #[test]
    fn test_neural_path_explain_with_attention() {
        let p = NeuralPath::new(vec![0.1], 0.5, 0.9).with_attention(vec![vec![0.5, 0.5]]);
        assert!(p.explain().contains("Attention weights"));
    }

    #[test]
    fn test_neural_path_serialization_minimal() {
        // No optional sections: all three should round-trip as None.
        let restored =
            NeuralPath::from_bytes(&NeuralPath::new(vec![0.1, 0.2], 0.5, 0.9).to_bytes())
                .expect("Failed to deserialize");
        assert!(restored.activations.is_none());
        assert!(restored.attention_weights.is_none());
        assert!(restored.integrated_gradients.is_none());
    }

    #[test]
    fn test_neural_path_with_activations() {
        let p = NeuralPath::new(vec![0.1], 0.5, 0.9)
            .with_activations(vec![vec![1.0, 2.0], vec![3.0, 4.0]]);
        assert!(p.activations.is_some());
        let layers = p.activations.unwrap();
        assert_eq!(layers.len(), 2);
        assert_eq!(layers[0], vec![1.0, 2.0]);
        assert_eq!(layers[1], vec![3.0, 4.0]);
    }

    #[test]
    fn test_neural_path_confidence_method() {
        // The trait accessor mirrors the public field.
        let p = NeuralPath::new(vec![0.1], 0.5, 0.85);
        assert_eq!(p.confidence(), 0.85);
    }

    #[test]
    fn test_neural_path_explain_basic() {
        let text = NeuralPath::new(vec![0.1, -0.2, 0.3], 0.75, 0.90).explain();
        assert!(text.contains("Neural Network Prediction"));
        assert!(text.contains("0.75"));
        assert!(text.contains("90.0%"));
        assert!(text.contains("Top salient input features"));
    }

    #[test]
    fn test_neural_path_top_salient_features_empty() {
        // Empty gradient yields no salient features regardless of k.
        let p = NeuralPath::new(vec![], 0.5, 0.9);
        assert!(p.top_salient_features(5).is_empty());
    }

    #[test]
    fn test_neural_path_top_salient_features_more_than_available() {
        // Asking for more features than exist truncates to what's available.
        let p = NeuralPath::new(vec![0.1, 0.2], 0.5, 0.9);
        assert_eq!(p.top_salient_features(10).len(), 2);
    }
}