concision_core/traits/
predict.rs1use crate::error::PredictError;
6
/// A forward pass that maps a borrowed input of type `T` to [`Forward::Output`].
///
/// Unlike [`Predict`], the result is returned directly rather than wrapped in
/// a `Result`.
pub trait Forward<T> {
    /// The value produced by [`forward`](Forward::forward).
    type Output;

    /// Computes the output for `args`, borrowing both `self` and the input.
    fn forward(&self, args: &T) -> Self::Output;
}
13
14pub trait ForwardIter<T> {
18 type Item: Forward<T, Output = T>;
19
20 fn forward_iter(self, args: &T) -> <Self::Item as Forward<T>>::Output;
21}
22
23pub trait Predict<T> {
24 type Output;
25
26 fn predict(&self, args: &T) -> Result<Self::Output, PredictError>;
27}
28
29impl<X, Y, S> Forward<X> for S
33where
34 S: Predict<X, Output = Y>,
35{
36 type Output = Y;
37
38 fn forward(&self, args: &X) -> Self::Output {
39 if let Ok(y) = self.predict(args) {
40 y
41 } else {
42 panic!("Error in forward propagation")
43 }
44 }
45}
46
47impl<I, M, T> ForwardIter<T> for I
48where
49 I: IntoIterator<Item = M>,
50 M: Forward<T, Output = T>,
51 T: Clone,
52{
53 type Item = M;
54
55 fn forward_iter(self, args: &T) -> M::Output {
56 self.into_iter()
57 .fold(args.clone(), |acc, m| m.forward(&acc))
58 }
59}
60
61impl<S, T> Predict<T> for Option<S>
62where
63 S: Predict<T, Output = T>,
64 T: Clone,
65{
66 type Output = T;
67
68 fn predict(&self, args: &T) -> Result<Self::Output, PredictError> {
69 match self {
70 Some(s) => s.predict(args),
71 None => Ok(args.clone()),
72 }
73 }
74}