concision_core/traits/
predict.rs

/*
    Appellation: predict <mod>
    Contrib: FL03 <jo3mccain@icloud.com>
*/
5use crate::error::PredictError;
6
/// [`Forward`] describes an object capable of forward propagation.
///
/// The type parameter `T` is the input type; it is handed to
/// [`Forward::forward`] by shared reference.
pub trait Forward<T> {
    /// The value produced by a forward pass.
    type Output;

    /// Runs a forward pass over `args` and returns the result.
    fn forward(&self, args: &T) -> Self::Output;
}
13
14/// [ForwardIter] describes any iterators whose elements implement [Forward].
15/// This trait is typically used in deep neural networks who need to forward propagate
16/// across a number of layers.
17pub trait ForwardIter<T> {
18    type Item: Forward<T, Output = T>;
19
20    fn forward_iter(self, args: &T) -> <Self::Item as Forward<T>>::Output;
21}
22
23pub trait Predict<T> {
24    type Output;
25
26    fn predict(&self, args: &T) -> Result<Self::Output, PredictError>;
27}
28
/*
 ********* Implementations *********
*/
32impl<X, Y, S> Forward<X> for S
33where
34    S: Predict<X, Output = Y>,
35{
36    type Output = Y;
37
38    fn forward(&self, args: &X) -> Self::Output {
39        if let Ok(y) = self.predict(args) {
40            y
41        } else {
42            panic!("Error in forward propagation")
43        }
44    }
45}
46
47impl<I, M, T> ForwardIter<T> for I
48where
49    I: IntoIterator<Item = M>,
50    M: Forward<T, Output = T>,
51    T: Clone,
52{
53    type Item = M;
54
55    fn forward_iter(self, args: &T) -> M::Output {
56        self.into_iter()
57            .fold(args.clone(), |acc, m| m.forward(&acc))
58    }
59}
60
61impl<S, T> Predict<T> for Option<S>
62where
63    S: Predict<T, Output = T>,
64    T: Clone,
65{
66    type Output = T;
67
68    fn predict(&self, args: &T) -> Result<Self::Output, PredictError> {
69        match self {
70            Some(s) => s.predict(args),
71            None => Ok(args.clone()),
72        }
73    }
74}