light_curve_feature/transformers/
composed.rs

1use crate::transformers::transformer::*;
2
3use thiserror::Error;
4
5macro_const! {
6    const DOC: &str = r#"
7Transformer composed from a list of transformers
8
9The transformers are stacked in the order they are given in the list, with number of
10features per trasnformer specified.
11"#;
12}
13
14#[derive(Error, Debug)]
15pub enum ComposedTransformerConstructionError {
16    #[error("Size mismatch between transformer size requirements and given feature size")]
17    SizeMismatch,
18}
19
#[doc = DOC!()]
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
// NOTE(review): `Deserialize` is derived, so a deserialized value does not go through
// `ComposedTransformer::new`; `input_size` and `size_hint` are trusted as stored and the
// per-transformer sizes are not re-checked with `is_size_valid`. Confirm this is acceptable.
pub struct ComposedTransformer<Tr> {
    // Component transformers, each paired with the number of consecutive input features it
    // receives, in application order.
    transformers: Vec<(Tr, usize)>,
    // Total number of input features: the sum of the per-transformer sizes, computed in `new`.
    input_size: usize,
    // Total number of output features: the sum of each component's `size_hint` for its
    // paired input size, computed in `new`.
    size_hint: usize,
}
27
28impl<Tr> ComposedTransformer<Tr>
29where
30    Tr: TransformerPropsTrait,
31{
32    /// Create a new composed transformer from a list of transformers and the number of features
33    /// they take as input.
34    pub fn new(
35        transformers: impl Into<Vec<(Tr, usize)>>,
36    ) -> Result<Self, ComposedTransformerConstructionError> {
37        let transformers = transformers.into();
38        let mut input_size = 0;
39        let mut size_hint = 0;
40        for (tr, size) in &transformers {
41            if !tr.is_size_valid(*size) {
42                return Err(ComposedTransformerConstructionError::SizeMismatch);
43            }
44            input_size += *size;
45            size_hint += tr.size_hint(*size);
46        }
47        Ok(Self {
48            transformers,
49            input_size,
50            size_hint,
51        })
52    }
53
54    /// Create a new composed transformer from a list of transformers assumed that all of them
55    /// may take a single feature as input.
56    pub fn from_transformers(
57        transformers: impl IntoIterator<Item = Tr>,
58    ) -> Result<Self, ComposedTransformerConstructionError> {
59        let transformers = transformers
60            .into_iter()
61            .map(|tr| (tr, 1))
62            .collect::<Vec<_>>();
63        Self::new(transformers)
64    }
65}
66
67impl<Tr> ComposedTransformer<Tr> {
68    pub const fn doc() -> &'static str {
69        DOC
70    }
71}
72
73impl<Tr> TransformerPropsTrait for ComposedTransformer<Tr>
74where
75    Tr: TransformerPropsTrait,
76{
77    fn is_size_valid(&self, input_size: usize) -> bool {
78        self.input_size == input_size
79    }
80
81    fn size_hint(&self, _input_size: usize) -> usize {
82        self.size_hint
83    }
84
85    fn names(&self, input_names: &[&str]) -> Vec<String> {
86        let mut names_iter = input_names.iter();
87        self.transformers
88            .iter()
89            .flat_map(|(tr, size)| {
90                let names_batch = names_iter.by_ref().take(*size).copied().collect::<Vec<_>>();
91                tr.names(&names_batch[..]).into_iter()
92            })
93            .collect()
94    }
95
96    fn descriptions(&self, input_descriptions: &[&str]) -> Vec<String> {
97        let mut desc_iter = input_descriptions.iter();
98        self.transformers
99            .iter()
100            .flat_map(|(tr, size)| {
101                let desc_batch = desc_iter.by_ref().take(*size).copied().collect::<Vec<_>>();
102                tr.descriptions(&desc_batch[..]).into_iter()
103            })
104            .collect()
105    }
106}
107
108impl<T, Tr> TransformerTrait<T> for ComposedTransformer<Tr>
109where
110    T: Float,
111    Tr: TransformerTrait<T>,
112{
113    fn transform(&self, input: Vec<T>) -> Vec<T> {
114        let mut input_iter = input.into_iter();
115        self.transformers
116            .iter()
117            .flat_map(|(tr, size)| {
118                let input_batch = input_iter.by_ref().take(*size).collect::<Vec<_>>();
119                tr.transform(input_batch)
120            })
121            .collect()
122    }
123}
124
#[cfg(test)]
mod tests {
    use super::*;
    use crate::transformers::bazin_fit::BazinFitTransformer;
    use crate::transformers::clipped_lg::ClippedLgTransformer;
    use crate::transformers::identity::IdentityTransformer;
    use crate::transformers::linexp_fit::LinexpFitTransformer;

    transformer_check_doc_static_method!(
        check_doc_static_method,
        ComposedTransformer::<Transformer<f32>>
    );
    transformer_check_size_hint!(
        check_size_hint,
        ComposedTransformer::<Transformer<f32>>::new([
            (IdentityTransformer::new().into(), 3),
            (ClippedLgTransformer::default().into(), 2),
            (BazinFitTransformer::default().into(), 6),
            (LinexpFitTransformer::default().into(), 5),
        ])
        .unwrap(),
        ComposedTransformer::<Transformer<f32>>
    );

    #[test]
    fn test_from_transformers() {
        let tr = ComposedTransformer::from_transformers([
            IdentityTransformer::new(),
            IdentityTransformer::new(),
        ])
        .unwrap();
        assert_eq!(tr.transformers.len(), 2);
        assert_eq!(tr.input_size, 2);
        assert_eq!(tr.size_hint, 2);
        assert_eq!(tr.transformers[0].1, 1);
        assert_eq!(tr.transformers[1].1, 1);
    }

    #[test]
    fn test_from_transformers_size_mismatch() {
        let result = ComposedTransformer::<Transformer<f32>>::from_transformers([
            IdentityTransformer::new().into(),
            BazinFitTransformer::default().into(), // requires more than one feature
            LinexpFitTransformer::default().into(), // requires more than one feature
        ]);
        assert!(matches!(
            result,
            Err(ComposedTransformerConstructionError::SizeMismatch)
        ));
    }

    #[test]
    fn test_new() {
        let tr = ComposedTransformer::<Transformer<f32>>::new([
            (IdentityTransformer::new().into(), 2),
            (BazinFitTransformer::default().into(), 6),
            (LinexpFitTransformer::default().into(), 5),
        ])
        .unwrap();
        assert_eq!(tr.transformers.len(), 3);
        assert_eq!(tr.input_size, 13);
        assert_eq!(tr.size_hint, 11);
        assert_eq!(tr.transformers[0].1, 2);
        assert_eq!(tr.transformers[1].1, 6);
        assert_eq!(tr.transformers[2].1, 5);
    }

    #[test]
    fn test_new_size_mismatch() {
        let result = ComposedTransformer::<Transformer<f32>>::new([
            (IdentityTransformer::new().into(), 3),
            (BazinFitTransformer::default().into(), 3), // requires six features
            (LinexpFitTransformer::default().into(), 3), // requires five features
        ]);
        assert!(matches!(
            result,
            Err(ComposedTransformerConstructionError::SizeMismatch)
        ));
    }

    #[test]
    fn test_new_empty() {
        // An empty composition is accepted and maps zero inputs to zero outputs.
        let tr =
            ComposedTransformer::<Transformer<f32>>::new(Vec::<(Transformer<f32>, usize)>::new())
                .unwrap();
        assert_eq!(tr.input_size, 0);
        assert_eq!(tr.size_hint, 0);
        assert!(tr.is_size_valid(0));
        assert!(!tr.is_size_valid(1));
    }

    #[test]
    fn test_transform_keeps_batch_order() {
        // Identity components make it easy to check that batches are taken consecutively
        // and concatenated in the same order.
        let tr = ComposedTransformer::<Transformer<f32>>::new([
            (IdentityTransformer::new().into(), 2),
            (IdentityTransformer::new().into(), 1),
        ])
        .unwrap();
        let input = vec![1.0_f32, 2.0, 3.0];
        assert_eq!(tr.transform(input.clone()), input);
    }
}