light_curve_feature/transformers/
composed.rs1use crate::transformers::transformer::*;
2
3use thiserror::Error;
4
5macro_const! {
6 const DOC: &str = r#"
7Transformer composed from a list of transformers
8
9The transformers are stacked in the order they are given in the list, with number of
10features per trasnformer specified.
11"#;
12}
13
14#[derive(Error, Debug)]
15pub enum ComposedTransformerConstructionError {
16 #[error("Size mismatch between transformer size requirements and given feature size")]
17 SizeMismatch,
18}
19
20#[doc = DOC!()]
21#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
22pub struct ComposedTransformer<Tr> {
23 transformers: Vec<(Tr, usize)>,
24 input_size: usize,
25 size_hint: usize,
26}
27
28impl<Tr> ComposedTransformer<Tr>
29where
30 Tr: TransformerPropsTrait,
31{
32 pub fn new(
35 transformers: impl Into<Vec<(Tr, usize)>>,
36 ) -> Result<Self, ComposedTransformerConstructionError> {
37 let transformers = transformers.into();
38 let mut input_size = 0;
39 let mut size_hint = 0;
40 for (tr, size) in &transformers {
41 if !tr.is_size_valid(*size) {
42 return Err(ComposedTransformerConstructionError::SizeMismatch);
43 }
44 input_size += *size;
45 size_hint += tr.size_hint(*size);
46 }
47 Ok(Self {
48 transformers,
49 input_size,
50 size_hint,
51 })
52 }
53
54 pub fn from_transformers(
57 transformers: impl IntoIterator<Item = Tr>,
58 ) -> Result<Self, ComposedTransformerConstructionError> {
59 let transformers = transformers
60 .into_iter()
61 .map(|tr| (tr, 1))
62 .collect::<Vec<_>>();
63 Self::new(transformers)
64 }
65}
66
67impl<Tr> ComposedTransformer<Tr> {
68 pub const fn doc() -> &'static str {
69 DOC
70 }
71}
72
73impl<Tr> TransformerPropsTrait for ComposedTransformer<Tr>
74where
75 Tr: TransformerPropsTrait,
76{
77 fn is_size_valid(&self, input_size: usize) -> bool {
78 self.input_size == input_size
79 }
80
81 fn size_hint(&self, _input_size: usize) -> usize {
82 self.size_hint
83 }
84
85 fn names(&self, input_names: &[&str]) -> Vec<String> {
86 let mut names_iter = input_names.iter();
87 self.transformers
88 .iter()
89 .flat_map(|(tr, size)| {
90 let names_batch = names_iter.by_ref().take(*size).copied().collect::<Vec<_>>();
91 tr.names(&names_batch[..]).into_iter()
92 })
93 .collect()
94 }
95
96 fn descriptions(&self, input_descriptions: &[&str]) -> Vec<String> {
97 let mut desc_iter = input_descriptions.iter();
98 self.transformers
99 .iter()
100 .flat_map(|(tr, size)| {
101 let desc_batch = desc_iter.by_ref().take(*size).copied().collect::<Vec<_>>();
102 tr.descriptions(&desc_batch[..]).into_iter()
103 })
104 .collect()
105 }
106}
107
108impl<T, Tr> TransformerTrait<T> for ComposedTransformer<Tr>
109where
110 T: Float,
111 Tr: TransformerTrait<T>,
112{
113 fn transform(&self, input: Vec<T>) -> Vec<T> {
114 let mut input_iter = input.into_iter();
115 self.transformers
116 .iter()
117 .flat_map(|(tr, size)| {
118 let input_batch = input_iter.by_ref().take(*size).collect::<Vec<_>>();
119 tr.transform(input_batch)
120 })
121 .collect()
122 }
123}
124
#[cfg(test)]
mod tests {
    use super::*;
    use crate::transformers::bazin_fit::BazinFitTransformer;
    use crate::transformers::clipped_lg::ClippedLgTransformer;
    use crate::transformers::identity::IdentityTransformer;
    use crate::transformers::linexp_fit::LinexpFitTransformer;

    /// Feature counts stored next to each transformer, in stacking order.
    fn feature_counts<Tr>(composed: &ComposedTransformer<Tr>) -> Vec<usize> {
        composed
            .transformers
            .iter()
            .map(|(_, feature_count)| *feature_count)
            .collect()
    }

    transformer_check_doc_static_method!(
        check_doc_static_method,
        ComposedTransformer::<Transformer<f32>>
    );
    transformer_check_size_hint!(
        check_size_hint,
        ComposedTransformer::<Transformer<f32>>::new([
            (IdentityTransformer::new().into(), 3),
            (ClippedLgTransformer::default().into(), 2),
            (BazinFitTransformer::default().into(), 6),
            (LinexpFitTransformer::default().into(), 5),
        ])
        .unwrap(),
        ComposedTransformer::<Transformer<f32>>
    );

    #[test]
    fn test_from_transformers() {
        let composed = ComposedTransformer::from_transformers([
            IdentityTransformer::new(),
            IdentityTransformer::new(),
        ])
        .unwrap();

        assert_eq!(composed.transformers.len(), 2);
        assert_eq!(composed.input_size, 2);
        assert_eq!(composed.size_hint, 2);
        // `from_transformers` assigns exactly one feature to every transformer.
        assert_eq!(feature_counts(&composed), vec![1_usize, 1]);
    }

    #[test]
    fn test_from_transformers_size_mismatch() {
        // Every transformer gets a single feature here; construction is expected to be
        // rejected for this combination.
        let outcome = ComposedTransformer::<Transformer<f32>>::from_transformers([
            IdentityTransformer::new().into(),
            BazinFitTransformer::default().into(),
            LinexpFitTransformer::default().into(),
        ]);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_new() {
        let composed = ComposedTransformer::<Transformer<f32>>::new([
            (IdentityTransformer::new().into(), 2),
            (BazinFitTransformer::default().into(), 6),
            (LinexpFitTransformer::default().into(), 5),
        ])
        .unwrap();

        assert_eq!(composed.transformers.len(), 3);
        assert_eq!(composed.input_size, 13);
        assert_eq!(composed.size_hint, 11);
        assert_eq!(feature_counts(&composed), vec![2_usize, 6, 5]);
    }

    #[test]
    fn test_new_size_mismatch() {
        // Three features per transformer; construction is expected to be rejected.
        let outcome = ComposedTransformer::<Transformer<f32>>::new([
            (IdentityTransformer::new().into(), 3),
            (BazinFitTransformer::default().into(), 3),
            (LinexpFitTransformer::default().into(), 3),
        ]);
        assert!(outcome.is_err());
    }
}