concision_neural/traits/
predict.rs

/*
    Appellation: predict <module>
    Contrib: @FL03
*/
5use cnc::Forward;
6
7/// The [`Predict`] trait is designed as a _**model-specific**_ interface for making
8/// predictions. In the future, we may consider opening the trait up allowing for an
9/// alternative implementation of the trait, but for now, it is simply implemented for all
10/// implementors of the [`Forward`] trait.
11///
12/// **Note:** The trait is sealed, preventing external implementations, ensuring that only the
13/// library can define how predictions are made. This is to maintain consistency and integrity
14/// across different model implementations.
15pub trait Predict<Rhs> {
16    type Output;
17
18    private!();
19
20    fn predict(&self, input: &Rhs) -> crate::ModelResult<Self::Output>;
21}
22
23/// The [`PredictWithConfidence`] trait is an extension of the [`Predict`] trait, providing
24/// an additional method to obtain predictions along with a confidence score.
25pub trait PredictWithConfidence<Rhs>: Predict<Rhs> {
26    type Confidence;
27
28    fn predict_with_confidence(
29        &self,
30        input: &Rhs,
31    ) -> crate::ModelResult<(Self::Output, Self::Confidence)>;
32}
33
/*
 ************* Implementations *************
*/
37
38use ndarray::{Array, Dimension, ScalarOperand};
39use num_traits::{Float, FromPrimitive};
40
41impl<M, U, V> Predict<U> for M
42where
43    M: Forward<U, Output = V>,
44{
45    type Output = V;
46
47    seal!();
48
49    fn predict(&self, input: &U) -> crate::ModelResult<Self::Output> {
50        self.forward(input).map_err(core::convert::Into::into)
51    }
52}
53
54impl<M, U, A, D> PredictWithConfidence<U> for M
55where
56    A: Float + FromPrimitive + ScalarOperand,
57    D: Dimension,
58    Self: Predict<U, Output = Array<A, D>>,
59{
60    type Confidence = A;
61
62    fn predict_with_confidence(
63        &self,
64        input: &U,
65    ) -> Result<(Self::Output, Self::Confidence), crate::ModelError> {
66        // Get the base prediction
67        let prediction = Predict::predict(self, input)?;
68        let shape = prediction.shape();
69        // Calculate confidence as the inverse of the variance of the output
70        // For each sample, compute the variance across the output dimensions
71        let batch_size = shape[0];
72
73        let mut variance_sum = A::zero();
74
75        for sample in prediction.rows() {
76            // Compute variance
77            let variance = sample.var(A::one());
78            variance_sum = variance_sum + variance;
79        }
80
81        // Average variance across the batch
82        let avg_variance = variance_sum / A::from_usize(batch_size).unwrap();
83        // Confidence: inverse of variance (clipped to avoid division by zero)
84        let confidence = (A::one() + avg_variance).recip();
85
86        Ok((prediction, confidence))
87    }
88}