concision_traits/
predict.rs1use crate::Forward;
6
7pub trait Predict<Rhs> {
16 type Output;
17
18 private!();
19
20 fn predict(&self, input: &Rhs) -> Self::Output;
21}
22
23pub trait PredictWithConfidence<Rhs>: Predict<Rhs> {
26 type Confidence;
27
28 fn predict_with_confidence(&self, input: &Rhs) -> Option<(Self::Output, Self::Confidence)>;
29}
30
31use ndarray::{Array, Dimension, ScalarOperand};
36use num_traits::{Float, FromPrimitive};
37
38impl<M, U, V> Predict<U> for M
39where
40 M: Forward<U, Output = V>,
41{
42 type Output = V;
43
44 seal!();
45
46 fn predict(&self, input: &U) -> Self::Output {
47 self.forward(input)
48 }
49}
50
51impl<M, U, A, D> PredictWithConfidence<U> for M
52where
53 A: Float + FromPrimitive + ScalarOperand,
54 D: Dimension,
55 Self: Predict<U, Output = Array<A, D>>,
56{
57 type Confidence = A;
58
59 fn predict_with_confidence(&self, input: &U) -> Option<(Self::Output, Self::Confidence)> {
60 let prediction = Predict::predict(self, input);
62 let shape = prediction.shape();
63 let batch_size = shape[0];
66
67 let mut variance_sum = A::zero();
68
69 for sample in prediction.rows() {
70 let variance = sample.var(A::one());
72 variance_sum = variance_sum + variance;
73 }
74
75 let avg_variance = variance_sum / A::from_usize(batch_size)?;
77 let confidence = (A::one() + avg_variance).recip();
79
80 Some((prediction, confidence))
81 }
82}