mli/
chain.rs

1use crate::{Backward, Forward, Train};
2
/// Sequential composition of two stages.
///
/// The first field (`T`) runs first, and its output is handed to the second
/// field (`U`). This file implements `Forward`, `Backward` and `Train` for
/// `Chain` whenever the two stages' types line up.
#[derive(Debug, Clone)]
pub struct Chain<T, U>(pub T, pub U);
5
6impl<T, U> Forward for Chain<T, U>
7where
8    T: Forward,
9    U: Forward<Input = T::Output>,
10{
11    type Input = T::Input;
12    type Internal = (T::Internal, T::Output, U::Internal);
13    type Output = U::Output;
14
15    fn forward(&self, input: &T::Input) -> (Self::Internal, Self::Output) {
16        let (t_internal, t_output) = self.0.forward(input);
17        let (u_internal, u_output) = self.1.forward(&t_output);
18        ((t_internal, t_output, u_internal), u_output)
19    }
20}
21
22impl<T, U, O> Backward for Chain<T, U>
23where
24    T: Backward<OutputDelta = U::InputDelta> + Forward<Output = O>,
25    U: Backward + Forward<Input = O>,
26{
27    type OutputDelta = U::OutputDelta;
28    type InputDelta = T::InputDelta;
29    type TrainDelta = (T::TrainDelta, U::TrainDelta);
30
31    fn backward(
32        &self,
33        input: &T::Input,
34        internal: &Self::Internal,
35        output_delta: &U::OutputDelta,
36    ) -> (Self::InputDelta, Self::TrainDelta) {
37        let (t_internal, t_output, u_internal) = internal;
38        let (u_input_delta, u_train_delta) = self.1.backward(t_output, u_internal, output_delta);
39        let (t_input_delta, t_train_delta) = self.0.backward(input, t_internal, &u_input_delta);
40        (t_input_delta, (t_train_delta, u_train_delta))
41    }
42}
43
44impl<T, U, O> Train for Chain<T, U>
45where
46    T: Train + Backward<OutputDelta = U::InputDelta> + Forward<Output = O>,
47    U: Train + Backward + Forward<Input = O>,
48{
49    fn train(&mut self, train_delta: &Self::TrainDelta) {
50        let (t_train_delta, u_train_delta) = train_delta;
51        self.0.train(t_train_delta);
52        self.1.train(u_train_delta);
53    }
54}