rsdiff_core/traits/
propagate.rs

/*
    Appellation: prop <mod>
    Contrib: FL03 <jo3mccain@icloud.com>
*/
5use crate::error::PredictError;
6
/// [`Backward`] describes an object capable of backward propagation.
///
/// Implementors run their backward pass through [`Backward::backward`],
/// which yields a value of the associated [`Backward::Output`] type.
pub trait Backward {
    /// The value produced by a backward pass.
    type Output;

    /// Performs the backward pass, borrowing `self` immutably.
    fn backward(&self) -> Self::Output;
}
15
16pub trait Module<T>: Forward<T> + Backward {
17    type Config;
18    type Params;
19
20    fn new(config: Self::Config) -> Self;
21
22    fn config(&self) -> &Self::Config;
23
24    fn config_mut(&mut self) -> &mut Self::Config;
25
26    fn parameters(&self) -> Self::Params;
27}
28
29/// [Forward] describes an object capable of forward propagation.
30pub trait Forward<T> {
31    type Output;
32
33    fn forward(&self, args: &T) -> Result<Self::Output, PredictError>;
34}
35
36pub trait ForwardIter<T> {
37    type Item: Forward<T, Output = T>;
38
39    fn forward_iter(self, args: &T) -> Result<<Self::Item as Forward<T>>::Output, PredictError>;
40}
41
// Trait implementations
43mod impls {
44    use super::*;
45
46    impl<I, M, T> ForwardIter<T> for I
47    where
48        I: IntoIterator<Item = M>,
49        M: Forward<T, Output = T>,
50        T: Clone,
51    {
52        type Item = M;
53
54        fn forward_iter(self, args: &T) -> Result<M::Output, PredictError> {
55            let mut result = args.clone();
56            for i in self {
57                result = i.forward(&result)?;
58            }
59            Ok(result)
60        }
61    }
62
63    impl<S, T> Forward<T> for Option<S>
64    where
65        S: Forward<T, Output = T>,
66        T: Clone,
67    {
68        type Output = T;
69
70        fn forward(&self, args: &T) -> Result<Self::Output, PredictError> {
71            match self {
72                Some(s) => s.forward(args),
73                None => Ok(args.clone()),
74            }
75        }
76    }
77
78    impl<S, T> Forward<T> for S
79    where
80        S: AsRef<dyn Forward<T, Output = T>>,
81        T: Clone,
82    {
83        type Output = T;
84
85        fn forward(&self, args: &T) -> Result<Self::Output, PredictError> {
86            self.as_ref().forward(args)
87        }
88    }
89}