1use crate::{Backward, Forward, Train};
2
/// Sequential composition of two stages.
///
/// The trait implementations in this file run field `0` first and hand its
/// output to field `1`; the backward pass visits them in the opposite order.
#[derive(Debug, Clone)]
pub struct Chain<T, U>(
    /// First stage; receives the chain's input.
    pub T,
    /// Second stage; receives the first stage's output.
    pub U,
);
5
6impl<T, U> Forward for Chain<T, U>
7where
8 T: Forward,
9 U: Forward<Input = T::Output>,
10{
11 type Input = T::Input;
12 type Internal = (T::Internal, T::Output, U::Internal);
13 type Output = U::Output;
14
15 fn forward(&self, input: &T::Input) -> (Self::Internal, Self::Output) {
16 let (t_internal, t_output) = self.0.forward(input);
17 let (u_internal, u_output) = self.1.forward(&t_output);
18 ((t_internal, t_output, u_internal), u_output)
19 }
20}
21
22impl<T, U, O> Backward for Chain<T, U>
23where
24 T: Backward<OutputDelta = U::InputDelta> + Forward<Output = O>,
25 U: Backward + Forward<Input = O>,
26{
27 type OutputDelta = U::OutputDelta;
28 type InputDelta = T::InputDelta;
29 type TrainDelta = (T::TrainDelta, U::TrainDelta);
30
31 fn backward(
32 &self,
33 input: &T::Input,
34 internal: &Self::Internal,
35 output_delta: &U::OutputDelta,
36 ) -> (Self::InputDelta, Self::TrainDelta) {
37 let (t_internal, t_output, u_internal) = internal;
38 let (u_input_delta, u_train_delta) = self.1.backward(t_output, u_internal, output_delta);
39 let (t_input_delta, t_train_delta) = self.0.backward(input, t_internal, &u_input_delta);
40 (t_input_delta, (t_train_delta, u_train_delta))
41 }
42}
43
44impl<T, U, O> Train for Chain<T, U>
45where
46 T: Train + Backward<OutputDelta = U::InputDelta> + Forward<Output = O>,
47 U: Train + Backward + Forward<Input = O>,
48{
49 fn train(&mut self, train_delta: &Self::TrainDelta) {
50 let (t_train_delta, u_train_delta) = train_delta;
51 self.0.train(t_train_delta);
52 self.1.train(u_train_delta);
53 }
54}