rsdiff_core/traits/
propagate.rs1use crate::error::PredictError;
6
/// Types that expose a backward pass.
///
/// Implementors choose what a backward pass yields through the associated
/// [`Backward::Output`] type; the trait places no other constraint on it.
pub trait Backward {
    /// Value returned by [`Backward::backward`].
    type Output;

    /// Performs the backward pass and returns its result.
    ///
    /// Takes no input besides `&self`.
    fn backward(&self) -> Self::Output;
}
15
16pub trait Module<T>: Forward<T> + Backward {
17 type Config;
18 type Params;
19
20 fn new(config: Self::Config) -> Self;
21
22 fn config(&self) -> &Self::Config;
23
24 fn config_mut(&mut self) -> &mut Self::Config;
25
26 fn parameters(&self) -> Self::Params;
27}
28
29pub trait Forward<T> {
31 type Output;
32
33 fn forward(&self, args: &T) -> Result<Self::Output, PredictError>;
34}
35
36pub trait ForwardIter<T> {
37 type Item: Forward<T, Output = T>;
38
39 fn forward_iter(self, args: &T) -> Result<<Self::Item as Forward<T>>::Output, PredictError>;
40}
41
42mod impls {
44 use super::*;
45
46 impl<I, M, T> ForwardIter<T> for I
47 where
48 I: IntoIterator<Item = M>,
49 M: Forward<T, Output = T>,
50 T: Clone,
51 {
52 type Item = M;
53
54 fn forward_iter(self, args: &T) -> Result<M::Output, PredictError> {
55 let mut result = args.clone();
56 for i in self {
57 result = i.forward(&result)?;
58 }
59 Ok(result)
60 }
61 }
62
63 impl<S, T> Forward<T> for Option<S>
64 where
65 S: Forward<T, Output = T>,
66 T: Clone,
67 {
68 type Output = T;
69
70 fn forward(&self, args: &T) -> Result<Self::Output, PredictError> {
71 match self {
72 Some(s) => s.forward(args),
73 None => Ok(args.clone()),
74 }
75 }
76 }
77
78 impl<S, T> Forward<T> for S
79 where
80 S: AsRef<dyn Forward<T, Output = T>>,
81 T: Clone,
82 {
83 type Output = T;
84
85 fn forward(&self, args: &T) -> Result<Self::Output, PredictError> {
86 self.as_ref().forward(args)
87 }
88 }
89}