concision_traits/
predict.rs

/*
    Appellation: predict <module>
    Contrib: @FL03
*/
5use crate::Forward;
6
7/// The [`Predict`] trait is designed as a _**model-specific**_ interface for making
8/// predictions. In the future, we may consider opening the trait up allowing for an
9/// alternative implementation of the trait, but for now, it is simply implemented for all
10/// implementors of the [`Forward`] trait.
11///
12/// **Note:** The trait is sealed, preventing external implementations, ensuring that only the
13/// library can define how predictions are made. This is to maintain consistency and integrity
14/// across different model implementations.
15pub trait Predict<Rhs> {
16    type Output;
17
18    private!();
19
20    fn predict(&self, input: &Rhs) -> Self::Output;
21}
22
23/// The [`PredictWithConfidence`] trait is an extension of the [`Predict`] trait, providing
24/// an additional method to obtain predictions along with a confidence score.
25pub trait PredictWithConfidence<Rhs>: Predict<Rhs> {
26    type Confidence;
27
28    fn predict_with_confidence(&self, input: &Rhs) -> Option<(Self::Output, Self::Confidence)>;
29}
30
/*
 ************* Implementations *************
*/
34
35use ndarray::{Array, Dimension, ScalarOperand};
36use num_traits::{Float, FromPrimitive};
37
38impl<M, U, V> Predict<U> for M
39where
40    M: Forward<U, Output = V>,
41{
42    type Output = V;
43
44    seal!();
45
46    fn predict(&self, input: &U) -> Self::Output {
47        self.forward(input)
48    }
49}
50
51impl<M, U, A, D> PredictWithConfidence<U> for M
52where
53    A: Float + FromPrimitive + ScalarOperand,
54    D: Dimension,
55    Self: Predict<U, Output = Array<A, D>>,
56{
57    type Confidence = A;
58
59    fn predict_with_confidence(&self, input: &U) -> Option<(Self::Output, Self::Confidence)> {
60        // Get the base prediction
61        let prediction = Predict::predict(self, input);
62        let shape = prediction.shape();
63        // Calculate confidence as the inverse of the variance of the output
64        // For each sample, compute the variance across the output dimensions
65        let batch_size = shape[0];
66
67        let mut variance_sum = A::zero();
68
69        for sample in prediction.rows() {
70            // Compute variance
71            let variance = sample.var(A::one());
72            variance_sum = variance_sum + variance;
73        }
74
75        // Average variance across the batch
76        let avg_variance = variance_sum / A::from_usize(batch_size)?;
77        // Confidence: inverse of variance (clipped to avoid division by zero)
78        let confidence = (A::one() + avg_variance).recip();
79
80        Some((prediction, confidence))
81    }
82}