1use std::collections::HashMap;
2
3use crate::{Error, TimestampSeconds};
4
5#[derive(Clone, Debug)]
13pub struct TrainingData {
14 pub(crate) n: usize,
15 pub(crate) ds: Vec<TimestampSeconds>,
16 pub(crate) y: Vec<f64>,
17 pub(crate) cap: Option<Vec<f64>>,
18 pub(crate) floor: Option<Vec<f64>>,
19 pub(crate) seasonality_conditions: HashMap<String, Vec<bool>>,
20 pub(crate) x: HashMap<String, Vec<f64>>,
21}
22
23impl TrainingData {
24 pub fn new(ds: Vec<TimestampSeconds>, y: Vec<f64>) -> Result<Self, Error> {
30 if ds.len() != y.len() {
31 return Err(Error::MismatchedLengths {
32 a_name: "ds".to_string(),
33 a: ds.len(),
34 b_name: "y".to_string(),
35 b: y.len(),
36 });
37 }
38 Ok(Self {
39 n: ds.len(),
40 ds,
41 y,
42 cap: None,
43 floor: None,
44 seasonality_conditions: HashMap::new(),
45 x: HashMap::new(),
46 })
47 }
48
49 pub fn with_cap(mut self, cap: Vec<f64>) -> Result<Self, Error> {
55 if self.n != cap.len() {
56 return Err(Error::MismatchedLengths {
57 a_name: "ds".to_string(),
58 a: self.ds.len(),
59 b_name: "cap".to_string(),
60 b: cap.len(),
61 });
62 }
63 self.cap = Some(cap);
64 Ok(self)
65 }
66
67 pub fn with_floor(mut self, floor: Vec<f64>) -> Result<Self, Error> {
73 if self.n != floor.len() {
74 return Err(Error::MismatchedLengths {
75 a_name: "ds".to_string(),
76 a: self.ds.len(),
77 b_name: "floor".to_string(),
78 b: floor.len(),
79 });
80 }
81 self.floor = Some(floor);
82 Ok(self)
83 }
84
85 pub fn with_seasonality_conditions(
92 mut self,
93 seasonality_conditions: HashMap<String, Vec<bool>>,
94 ) -> Result<Self, Error> {
95 for (name, cond) in seasonality_conditions.iter() {
96 if self.n != cond.len() {
97 return Err(Error::MismatchedLengths {
98 a_name: "ds".to_string(),
99 a: self.ds.len(),
100 b_name: name.clone(),
101 b: cond.len(),
102 });
103 }
104 }
105 self.seasonality_conditions = seasonality_conditions;
106 Ok(self)
107 }
108
109 pub fn with_regressors(mut self, x: HashMap<String, Vec<f64>>) -> Result<Self, Error> {
116 for (name, reg) in x.iter() {
117 if self.n != reg.len() {
118 return Err(Error::MismatchedLengths {
119 a_name: "ds".to_string(),
120 a: self.ds.len(),
121 b_name: name.clone(),
122 b: reg.len(),
123 });
124 }
125 if reg.iter().any(|x| x.is_nan()) {
126 return Err(Error::NaNValue {
127 column: name.clone(),
128 });
129 }
130 }
131 self.x = x;
132 Ok(self)
133 }
134
135 pub(crate) fn filter_nans(mut self) -> Self {
142 let mut n = self.n;
143 let mut keep = vec![true; self.n];
144 self.y = self
145 .y
146 .into_iter()
147 .zip(keep.iter_mut())
148 .filter_map(|(y, keep)| {
149 if y.is_nan() {
150 *keep = false;
151 n -= 1;
152 None
153 } else {
154 Some(y)
155 }
156 })
157 .collect();
158
159 fn retain<T>(v: &mut Vec<T>, keep: &[bool]) {
160 let mut iter = keep.iter();
161 v.retain(|_| *iter.next().unwrap());
162 }
163
164 self.n = n;
165 retain(&mut self.ds, &keep);
166 if let Some(cap) = self.cap.as_mut() {
167 retain(cap, &keep);
168 }
169 if let Some(floor) = self.floor.as_mut() {
170 retain(floor, &keep);
171 }
172 for v in self.x.values_mut() {
173 retain(v, &keep);
174 }
175 for v in self.seasonality_conditions.values_mut() {
176 retain(v, &keep);
177 }
178 self
179 }
180
181 #[cfg(test)]
182 pub(crate) fn head(mut self, n: usize) -> Self {
183 self.n = n;
184 self.ds.truncate(n);
185 self.y.truncate(n);
186 if let Some(cap) = self.cap.as_mut() {
187 cap.truncate(n);
188 }
189 if let Some(floor) = self.floor.as_mut() {
190 floor.truncate(n);
191 }
192 for (_, v) in self.x.iter_mut() {
193 v.truncate(n);
194 }
195 for (_, v) in self.seasonality_conditions.iter_mut() {
196 v.truncate(n);
197 }
198 self
199 }
200
201 #[cfg(test)]
202 pub(crate) fn tail(mut self, n: usize) -> Self {
203 let split = self.ds.len() - n;
204 self.n = n;
205 self.ds = self.ds.split_off(split);
206 self.y = self.y.split_off(split);
207 if let Some(cap) = self.cap.as_mut() {
208 *cap = cap.split_off(split);
209 }
210 if let Some(floor) = self.floor.as_mut() {
211 *floor = floor.split_off(split);
212 }
213 for (_, v) in self.x.iter_mut() {
214 *v = v.split_off(split);
215 }
216 for (_, v) in self.seasonality_conditions.iter_mut() {
217 *v = v.split_off(split);
218 }
219 self
220 }
221
222 #[cfg(test)]
223 pub(crate) fn len(&self) -> usize {
224 self.n
225 }
226}
227
228#[derive(Clone, Debug)]
237pub struct PredictionData {
238 pub n: usize,
240
241 pub ds: Vec<TimestampSeconds>,
245
246 pub cap: Option<Vec<f64>>,
250
251 pub floor: Option<Vec<f64>>,
255
256 pub seasonality_conditions: HashMap<String, Vec<bool>>,
262
263 pub x: HashMap<String, Vec<f64>>,
269}
270
271impl PredictionData {
272 pub fn new(ds: Vec<TimestampSeconds>) -> Self {
276 Self {
277 n: ds.len(),
278 ds,
279 cap: None,
280 floor: None,
281 seasonality_conditions: HashMap::new(),
282 x: HashMap::new(),
283 }
284 }
285
286 pub fn with_cap(mut self, cap: Vec<f64>) -> Result<Self, Error> {
292 if self.n != cap.len() {
293 return Err(Error::MismatchedLengths {
294 a_name: "ds".to_string(),
295 a: self.ds.len(),
296 b_name: "cap".to_string(),
297 b: cap.len(),
298 });
299 }
300 self.cap = Some(cap);
301 Ok(self)
302 }
303
304 pub fn with_floor(mut self, floor: Vec<f64>) -> Result<Self, Error> {
310 if self.n != floor.len() {
311 return Err(Error::MismatchedLengths {
312 a_name: "ds".to_string(),
313 a: self.ds.len(),
314 b_name: "floor".to_string(),
315 b: floor.len(),
316 });
317 }
318 self.floor = Some(floor);
319 Ok(self)
320 }
321
322 pub fn with_seasonality_conditions(
329 mut self,
330 seasonality_conditions: HashMap<String, Vec<bool>>,
331 ) -> Result<Self, Error> {
332 for (name, cond) in seasonality_conditions.iter() {
333 if self.n != cond.len() {
334 return Err(Error::MismatchedLengths {
335 a_name: "ds".to_string(),
336 a: self.ds.len(),
337 b_name: name.clone(),
338 b: cond.len(),
339 });
340 }
341 }
342 self.seasonality_conditions = seasonality_conditions;
343 Ok(self)
344 }
345
346 pub fn with_regressors(mut self, x: HashMap<String, Vec<f64>>) -> Result<Self, Error> {
353 for (name, reg) in x.iter() {
354 if self.n != reg.len() {
355 return Err(Error::MismatchedLengths {
356 a_name: "ds".to_string(),
357 a: self.ds.len(),
358 b_name: name.clone(),
359 b: reg.len(),
360 });
361 }
362 if reg.iter().any(|x| x.is_nan()) {
363 return Err(Error::NaNValue {
364 column: name.clone(),
365 });
366 }
367 }
368 self.x = x;
369 Ok(self)
370 }
371}
372
#[cfg(test)]
mod test {
    use crate::testdata::daily_univariate_ts;

    /// A single NaN in `y` removes exactly that row from every column, and
    /// the rows after it shift down into its place.
    #[test]
    fn filter_nans() {
        let mut data = daily_univariate_ts();
        let expected_len = data.n - 1;
        // Remember the value after the NaN slot; it should move into index 10.
        let next_y = data.y[11];
        data.y[10] = f64::NAN;
        let data = data.filter_nans();
        assert_eq!(data.n, expected_len);
        assert_eq!(data.y.len(), expected_len);
        assert_eq!(data.ds.len(), expected_len);
        assert!(data.y.iter().all(|y| !y.is_nan()));
        assert_eq!(data.y[10], next_y);
    }

    /// Data without NaNs passes through `filter_nans` unchanged in length.
    #[test]
    fn filter_nans_noop() {
        let data = daily_univariate_ts();
        let n = data.n;
        let data = data.filter_nans();
        assert_eq!(data.n, n);
        assert_eq!(data.y.len(), n);
        assert_eq!(data.ds.len(), n);
    }
}