Skip to main content

augurs_prophet/
data.rs

1use std::collections::HashMap;
2
3use crate::{Error, TimestampSeconds};
4
5/// The data needed to train a Prophet model.
6///
7/// Create a `TrainingData` object with the `new` method, which
8/// takes a vector of dates and a vector of values.
9///
10/// Optionally, you can add seasonality conditions, regressors,
11/// floor and cap columns.
12#[derive(Clone, Debug)]
13pub struct TrainingData {
14    pub(crate) n: usize,
15    pub(crate) ds: Vec<TimestampSeconds>,
16    pub(crate) y: Vec<f64>,
17    pub(crate) cap: Option<Vec<f64>>,
18    pub(crate) floor: Option<Vec<f64>>,
19    pub(crate) seasonality_conditions: HashMap<String, Vec<bool>>,
20    pub(crate) x: HashMap<String, Vec<f64>>,
21}
22
23impl TrainingData {
24    /// Create some input data for Prophet.
25    ///
26    /// # Errors
27    ///
28    /// Returns an error if the lengths of `ds` and `y` differ.
29    pub fn new(ds: Vec<TimestampSeconds>, y: Vec<f64>) -> Result<Self, Error> {
30        if ds.len() != y.len() {
31            return Err(Error::MismatchedLengths {
32                a_name: "ds".to_string(),
33                a: ds.len(),
34                b_name: "y".to_string(),
35                b: y.len(),
36            });
37        }
38        Ok(Self {
39            n: ds.len(),
40            ds,
41            y,
42            cap: None,
43            floor: None,
44            seasonality_conditions: HashMap::new(),
45            x: HashMap::new(),
46        })
47    }
48
49    /// Add the cap for logistic growth.
50    ///
51    /// # Errors
52    ///
53    /// Returns an error if the lengths of `ds` and `cap` differ.
54    pub fn with_cap(mut self, cap: Vec<f64>) -> Result<Self, Error> {
55        if self.n != cap.len() {
56            return Err(Error::MismatchedLengths {
57                a_name: "ds".to_string(),
58                a: self.ds.len(),
59                b_name: "cap".to_string(),
60                b: cap.len(),
61            });
62        }
63        self.cap = Some(cap);
64        Ok(self)
65    }
66
67    /// Add the floor for logistic growth.
68    ///
69    /// # Errors
70    ///
71    /// Returns an error if the lengths of `ds` and `floor` differ.
72    pub fn with_floor(mut self, floor: Vec<f64>) -> Result<Self, Error> {
73        if self.n != floor.len() {
74            return Err(Error::MismatchedLengths {
75                a_name: "ds".to_string(),
76                a: self.ds.len(),
77                b_name: "floor".to_string(),
78                b: floor.len(),
79            });
80        }
81        self.floor = Some(floor);
82        Ok(self)
83    }
84
85    /// Add condition columns for conditional seasonalities.
86    ///
87    /// # Errors
88    ///
89    /// Returns an error if the lengths of `ds` and any of the seasonality
90    /// condition columns differ.
91    pub fn with_seasonality_conditions(
92        mut self,
93        seasonality_conditions: HashMap<String, Vec<bool>>,
94    ) -> Result<Self, Error> {
95        for (name, cond) in seasonality_conditions.iter() {
96            if self.n != cond.len() {
97                return Err(Error::MismatchedLengths {
98                    a_name: "ds".to_string(),
99                    a: self.ds.len(),
100                    b_name: name.clone(),
101                    b: cond.len(),
102                });
103            }
104        }
105        self.seasonality_conditions = seasonality_conditions;
106        Ok(self)
107    }
108
109    /// Add regressors.
110    ///
111    /// # Errors
112    ///
113    /// Returns an error if the lengths of `ds` and any of the regressor
114    /// columns differ.
115    pub fn with_regressors(mut self, x: HashMap<String, Vec<f64>>) -> Result<Self, Error> {
116        for (name, reg) in x.iter() {
117            if self.n != reg.len() {
118                return Err(Error::MismatchedLengths {
119                    a_name: "ds".to_string(),
120                    a: self.ds.len(),
121                    b_name: name.clone(),
122                    b: reg.len(),
123                });
124            }
125            if reg.iter().any(|x| x.is_nan()) {
126                return Err(Error::NaNValue {
127                    column: name.clone(),
128                });
129            }
130        }
131        self.x = x;
132        Ok(self)
133    }
134
135    /// Remove any NaN values from the `y` column, and the corresponding values
136    /// in the other columns.
137    ///
138    /// This handles updating all columns and `n` appropriately.
139    ///
140    /// NaN values in other columns are retained.
141    pub(crate) fn filter_nans(mut self) -> Self {
142        let mut n = self.n;
143        let mut keep = vec![true; self.n];
144        self.y = self
145            .y
146            .into_iter()
147            .zip(keep.iter_mut())
148            .filter_map(|(y, keep)| {
149                if y.is_nan() {
150                    *keep = false;
151                    n -= 1;
152                    None
153                } else {
154                    Some(y)
155                }
156            })
157            .collect();
158
159        fn retain<T>(v: &mut Vec<T>, keep: &[bool]) {
160            let mut iter = keep.iter();
161            v.retain(|_| *iter.next().unwrap());
162        }
163
164        self.n = n;
165        retain(&mut self.ds, &keep);
166        if let Some(cap) = self.cap.as_mut() {
167            retain(cap, &keep);
168        }
169        if let Some(floor) = self.floor.as_mut() {
170            retain(floor, &keep);
171        }
172        for v in self.x.values_mut() {
173            retain(v, &keep);
174        }
175        for v in self.seasonality_conditions.values_mut() {
176            retain(v, &keep);
177        }
178        self
179    }
180
181    #[cfg(test)]
182    pub(crate) fn head(mut self, n: usize) -> Self {
183        self.n = n;
184        self.ds.truncate(n);
185        self.y.truncate(n);
186        if let Some(cap) = self.cap.as_mut() {
187            cap.truncate(n);
188        }
189        if let Some(floor) = self.floor.as_mut() {
190            floor.truncate(n);
191        }
192        for (_, v) in self.x.iter_mut() {
193            v.truncate(n);
194        }
195        for (_, v) in self.seasonality_conditions.iter_mut() {
196            v.truncate(n);
197        }
198        self
199    }
200
201    #[cfg(test)]
202    pub(crate) fn tail(mut self, n: usize) -> Self {
203        let split = self.ds.len() - n;
204        self.n = n;
205        self.ds = self.ds.split_off(split);
206        self.y = self.y.split_off(split);
207        if let Some(cap) = self.cap.as_mut() {
208            *cap = cap.split_off(split);
209        }
210        if let Some(floor) = self.floor.as_mut() {
211            *floor = floor.split_off(split);
212        }
213        for (_, v) in self.x.iter_mut() {
214            *v = v.split_off(split);
215        }
216        for (_, v) in self.seasonality_conditions.iter_mut() {
217            *v = v.split_off(split);
218        }
219        self
220    }
221
222    #[cfg(test)]
223    pub(crate) fn len(&self) -> usize {
224        self.n
225    }
226}
227
228/// The data needed to predict with a Prophet model.
229///
230/// The structure of the prediction data must be the same as the
231/// training data used to train the model, with the exception of
232/// `y` (which is being predicted).
233///
234/// That is, if your model used certain seasonality conditions or
235/// regressors, you must include them in the prediction data.
236#[derive(Clone, Debug)]
237pub struct PredictionData {
238    /// The number of time points in the prediction data.
239    pub n: usize,
240
241    /// The timestamps of the time series.
242    ///
243    /// These should be in seconds since the epoch.
244    pub ds: Vec<TimestampSeconds>,
245
246    /// Optionally, an upper bound (cap) on the values of the time series.
247    ///
248    /// Only used if the model's growth type is `logistic`.
249    pub cap: Option<Vec<f64>>,
250
251    /// Optionally, a lower bound (floor) on the values of the time series.
252    ///
253    /// Only used if the model's growth type is `logistic`.
254    pub floor: Option<Vec<f64>>,
255
256    /// Indicator variables for conditional seasonalities.
257    ///
258    /// The keys of the map are the names of the seasonality components,
259    /// and the values are boolean arrays of length `n` where `true` indicates
260    /// that the component is active for the corresponding time point.
261    pub seasonality_conditions: HashMap<String, Vec<bool>>,
262
263    /// Exogenous regressors.
264    ///
265    /// The keys of the map are the names of the regressors,
266    /// and the values are arrays of length `n` containing the regressor values
267    /// for each time point.
268    pub x: HashMap<String, Vec<f64>>,
269}
270
271impl PredictionData {
272    /// Create some data to be used for predictions.
273    ///
274    /// Predictions will be made for each of the dates in `ds`.
275    pub fn new(ds: Vec<TimestampSeconds>) -> Self {
276        Self {
277            n: ds.len(),
278            ds,
279            cap: None,
280            floor: None,
281            seasonality_conditions: HashMap::new(),
282            x: HashMap::new(),
283        }
284    }
285
286    /// Add the cap for logistic growth.
287    ///
288    /// # Errors
289    ///
290    /// Returns an error if the lengths of `ds` and `cap` are not equal.
291    pub fn with_cap(mut self, cap: Vec<f64>) -> Result<Self, Error> {
292        if self.n != cap.len() {
293            return Err(Error::MismatchedLengths {
294                a_name: "ds".to_string(),
295                a: self.ds.len(),
296                b_name: "cap".to_string(),
297                b: cap.len(),
298            });
299        }
300        self.cap = Some(cap);
301        Ok(self)
302    }
303
304    /// Add the floor for logistic growth.
305    ///
306    /// # Errors
307    ///
308    /// Returns an error if the lengths of `ds` and `floor` are not equal.
309    pub fn with_floor(mut self, floor: Vec<f64>) -> Result<Self, Error> {
310        if self.n != floor.len() {
311            return Err(Error::MismatchedLengths {
312                a_name: "ds".to_string(),
313                a: self.ds.len(),
314                b_name: "floor".to_string(),
315                b: floor.len(),
316            });
317        }
318        self.floor = Some(floor);
319        Ok(self)
320    }
321
322    /// Add condition columns for conditional seasonalities.
323    ///
324    /// # Errors
325    ///
326    /// Returns an error if the lengths of any of the seasonality conditions
327    /// are not equal to the length of `ds`.
328    pub fn with_seasonality_conditions(
329        mut self,
330        seasonality_conditions: HashMap<String, Vec<bool>>,
331    ) -> Result<Self, Error> {
332        for (name, cond) in seasonality_conditions.iter() {
333            if self.n != cond.len() {
334                return Err(Error::MismatchedLengths {
335                    a_name: "ds".to_string(),
336                    a: self.ds.len(),
337                    b_name: name.clone(),
338                    b: cond.len(),
339                });
340            }
341        }
342        self.seasonality_conditions = seasonality_conditions;
343        Ok(self)
344    }
345
346    /// Add regressors.
347    ///
348    /// # Errors
349    ///
350    /// Returns an error if the lengths of any of the regressors
351    /// are not equal to the length of `ds`.
352    pub fn with_regressors(mut self, x: HashMap<String, Vec<f64>>) -> Result<Self, Error> {
353        for (name, reg) in x.iter() {
354            if self.n != reg.len() {
355                return Err(Error::MismatchedLengths {
356                    a_name: "ds".to_string(),
357                    a: self.ds.len(),
358                    b_name: name.clone(),
359                    b: reg.len(),
360                });
361            }
362            if reg.iter().any(|x| x.is_nan()) {
363                return Err(Error::NaNValue {
364                    column: name.clone(),
365                });
366            }
367        }
368        self.x = x;
369        Ok(self)
370    }
371}
372
#[cfg(test)]
mod test {
    use std::collections::HashMap;

    use crate::testdata::daily_univariate_ts;

    #[test]
    fn filter_nans() {
        let mut data = daily_univariate_ts();
        let expected_len = data.n - 1;
        let original_ds = data.ds.clone();
        data.y[10] = f64::NAN;
        let data = data.filter_nans();
        assert_eq!(data.n, expected_len);
        assert_eq!(data.y.len(), expected_len);
        assert_eq!(data.ds.len(), expected_len);
        assert!(data.y.iter().all(|y| !y.is_nan()));
        // Row 10 must be the one removed: earlier rows are untouched and later
        // rows shift left by one.
        assert_eq!(&data.ds[..10], &original_ds[..10]);
        assert_eq!(&data.ds[10..], &original_ds[11..]);
    }

    #[test]
    fn filter_nans_filters_optional_columns() {
        let data = daily_univariate_ts();
        let n = data.n;
        let regressor: Vec<f64> = (0..n).map(|i| i as f64).collect();
        let conditions: Vec<bool> = (0..n).map(|i| i % 2 == 0).collect();
        let mut data = data
            .with_cap(vec![100.0; n])
            .unwrap()
            .with_floor(vec![0.0; n])
            .unwrap()
            .with_regressors(HashMap::from([("r".to_string(), regressor)]))
            .unwrap()
            .with_seasonality_conditions(HashMap::from([("c".to_string(), conditions)]))
            .unwrap();
        data.y[0] = f64::NAN;
        let data = data.filter_nans();
        assert_eq!(data.n, n - 1);
        assert_eq!(data.cap.as_ref().map(Vec::len), Some(n - 1));
        assert_eq!(data.floor.as_ref().map(Vec::len), Some(n - 1));
        assert_eq!(data.x["r"].len(), n - 1);
        assert_eq!(data.seasonality_conditions["c"].len(), n - 1);
        // Dropping row 0 means the remaining columns start at the old row 1.
        assert_eq!(data.x["r"][0], 1.0);
        assert!(!data.seasonality_conditions["c"][0]);
    }
}