concision_neural/traits/
predict.rs

1/*
2    Appellation: predict <module>
3    Contrib: @FL03
4*/
5
6/// [Predict] isn't designed to be implemented directly, rather, as a blanket impl for any
7/// entity that implements the [`Forward`](cnc::Forward) trait. This is primarily used to
8/// define the base functionality of the [`Model`](crate::Model) trait.
9pub trait Predict<Rhs> {
10    type Output;
11
12    private!();
13
14    fn predict(&self, input: &Rhs) -> crate::NeuralResult<Self::Output>;
15}
16
17/// This trait extends the [`Predict`] trait to include a confidence score for the prediction.
18/// The confidence score is calculated as the inverse of the variance of the output.
19pub trait PredictWithConfidence<Rhs>: Predict<Rhs> {
20    type Confidence;
21
22    fn predict_with_confidence(
23        &self,
24        input: &Rhs,
25    ) -> crate::NeuralResult<(Self::Output, Self::Confidence)>;
26}
27
28/*
29 ************* Implementations *************
30*/
31
32use cnc::Forward;
33use ndarray::{Array, Dimension, ScalarOperand};
34use num_traits::{Float, FromPrimitive};
35
36impl<M, U, V> Predict<U> for M
37where
38    M: Forward<U, Output = V>,
39{
40    type Output = V;
41
42    seal!();
43
44    fn predict(&self, input: &U) -> crate::NeuralResult<Self::Output> {
45        self.forward(input).map_err(core::convert::Into::into)
46    }
47}
48
49impl<M, U, A, D> PredictWithConfidence<U> for M
50where
51    A: Float + FromPrimitive + ScalarOperand,
52    D: Dimension,
53    Self: Predict<U, Output = Array<A, D>>,
54{
55    type Confidence = A;
56
57    fn predict_with_confidence(
58        &self,
59        input: &U,
60    ) -> Result<(Self::Output, Self::Confidence), crate::NeuralError> {
61        // Get the base prediction
62        let prediction = Predict::predict(self, input)?;
63        let shape = prediction.shape();
64        // Calculate confidence as the inverse of the variance of the output
65        // For each sample, compute the variance across the output dimensions
66        let batch_size = shape[0];
67
68        let mut variance_sum = A::zero();
69
70        for sample in prediction.rows() {
71            // Compute variance
72            let variance = sample.var(A::one());
73            variance_sum = variance_sum + variance;
74        }
75
76        // Average variance across the batch
77        let avg_variance = variance_sum / A::from_usize(batch_size).unwrap();
78        // Confidence: inverse of variance (clipped to avoid division by zero)
79        let confidence = (A::one() + avg_variance).recip();
80
81        Ok((prediction, confidence))
82    }
83}