augurs_prophet/
features.rs

1//! Features used by Prophet, such as seasonality, regressors and holidays.
2use std::num::NonZeroU32;
3
4use crate::{
5    positive_float::PositiveFloat, prophet::prep::ONE_DAY_IN_SECONDS_INT, TimestampSeconds,
6};
7
/// How a seasonality, regressor, or holiday is modelled.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum FeatureMode {
    /// Model the feature additively. This is the default mode.
    #[default]
    Additive,
    /// Model the feature multiplicatively.
    Multiplicative,
}
17
18/// An occurrence of a holiday.
19///
20/// Each occurrence has a start and end time represented as
21/// a Unix timestamp. Holiday occurrences are therefore
22/// timestamp-unaware and can therefore span multiple days
23/// or even sub-daily periods.
24///
25/// This differs from the Python and R Prophet implementations,
26/// which require all holidays to be day-long events. Some
27/// convenience methods are provided to create day-long
28/// occurrences: see [`HolidayOccurrence::for_day`] and
29/// [`HolidayOccurrence::for_day_in_tz`].
30///
31/// The caller is responsible for ensuring that the start
32/// and end time provided are in the correct timezone.
33/// One way to do this is to use [`chrono::FixedOffset`][fo]
34/// to create an offset representing the time zone,
35/// [`FixedOffset::with_ymd_and_hms`][wyah] to create a
36/// [`DateTime`][dt] in that time zone, then [`DateTime::timestamp`][ts]
37/// to get the Unix timestamp.
38///
39/// [fo]: https://docs.rs/chrono/latest/chrono/struct.FixedOffset.html
40/// [wyah]: https://docs.rs/chrono/latest/chrono/struct.FixedOffset.html#method.with_ymd_and_hms
41/// [dt]: https://docs.rs/chrono/latest/chrono/struct.DateTime.html
42/// [ts]: https://docs.rs/chrono/latest/chrono/struct.DateTime.html#method.timestamp
43#[derive(Debug, Clone)]
44pub struct HolidayOccurrence {
45    pub(crate) start: TimestampSeconds,
46    pub(crate) end: TimestampSeconds,
47}
48
49impl HolidayOccurrence {
50    /// Create a new holiday occurrence with the given
51    /// start and end timestamp.
52    pub fn new(start: TimestampSeconds, end: TimestampSeconds) -> Self {
53        Self { start, end }
54    }
55
56    /// Create a new holiday encompassing midnight on the day
57    /// of the given timestamp to midnight on the following day,
58    /// in UTC.
59    ///
60    /// This is a convenience method to reproduce the behaviour
61    /// of the Python and R Prophet implementations, which require
62    /// all holidays to be day-long events.
63    ///
64    /// Note that this will _not_ handle daylight saving time
65    /// transitions correctly. To handle this correctly, use
66    /// [`HolidayOccurrence::new`] with the correct start and
67    /// end times, e.g. by calculating them using [`chrono`].
68    ///
69    /// [`chrono`]: https://docs.rs/chrono/latest/chrono
70    pub fn for_day(day: TimestampSeconds) -> Self {
71        Self::for_day_in_tz(day, 0)
72    }
73
74    /// Create a new holiday encompassing midnight on the day
75    /// of the given timestamp to midnight on the following day,
76    /// in a timezone represented by the `utc_offset_seconds`.
77    ///
78    /// The UTC offset can be calculated using, for example,
79    /// [`chrono::FixedOffset::local_minus_utc`][lmu]. Alternatively
80    /// it's the number of seconds to add to convert from the
81    /// local time to UTC, so UTC+1 is represented by `3600`
82    /// and UTC-5 by `-18000`.
83    ///
84    /// This is a convenience method to reproduce the behaviour
85    /// of the Python and R Prophet implementations, which require
86    /// all holidays to be day-long events.
87    ///
88    /// Note that this will _not_ handle daylight saving time
89    /// transitions correctly. To handle this correctly, use
90    /// [`HolidayOccurrence::new`] with the correct start and
91    /// end times, e.g. by calculating them using [`chrono`].
92    ///
93    /// [`chrono`]: https://docs.rs/chrono/latest/chrono
94    /// [lmu]: https://docs.rs/chrono/latest/chrono/struct.FixedOffset.html#method.local_minus_utc
95    pub fn for_day_in_tz(day: TimestampSeconds, utc_offset_seconds: i32) -> Self {
96        let day = floor_day(day, utc_offset_seconds);
97        Self {
98            start: day,
99            end: day + ONE_DAY_IN_SECONDS_INT,
100        }
101    }
102
103    /// Check if the given timestamp is within this occurrence.
104    pub(crate) fn contains(&self, ds: TimestampSeconds) -> bool {
105        self.start <= ds && ds < self.end
106    }
107}
108
109/// A holiday to be considered by the Prophet model.
110#[derive(Debug, Clone)]
111pub struct Holiday {
112    pub(crate) occurrences: Vec<HolidayOccurrence>,
113    pub(crate) prior_scale: Option<PositiveFloat>,
114}
115
116impl Holiday {
117    /// Create a new holiday with the given occurrences.
118    pub fn new(occurrences: Vec<HolidayOccurrence>) -> Self {
119        Self {
120            occurrences,
121            prior_scale: None,
122        }
123    }
124
125    /// Set the prior scale for the holiday.
126    pub fn with_prior_scale(mut self, prior_scale: PositiveFloat) -> Self {
127        self.prior_scale = Some(prior_scale);
128        self
129    }
130}
131
132fn floor_day(ds: TimestampSeconds, offset: i32) -> TimestampSeconds {
133    let adjusted_ds = ds + offset as TimestampSeconds;
134    let remainder =
135        ((adjusted_ds % ONE_DAY_IN_SECONDS_INT) + ONE_DAY_IN_SECONDS_INT) % ONE_DAY_IN_SECONDS_INT;
136    // Adjust the date to the holiday's UTC offset.
137    ds - remainder
138}
139
/// Controls whether a regressor is standardized.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum Standardize {
    /// Decide automatically whether to standardize. This is the default.
    ///
    /// Numeric regressors are standardized, whereas
    /// binary regressors are left as they are.
    #[default]
    Auto,
    /// Always standardize this regressor.
    Yes,
    /// Never standardize this regressor.
    No,
}
154
155impl From<bool> for Standardize {
156    fn from(b: bool) -> Self {
157        if b {
158            Standardize::Yes
159        } else {
160            Standardize::No
161        }
162    }
163}
164
/// The scaling parameters of a regressor.
///
/// If the regressor is standardized, this is inserted
/// into [`Scales::extra_regressors`].
#[derive(Clone, Debug)]
pub(crate) struct RegressorScale {
    /// Whether this regressor should be standardized.
    ///
    /// A plain `bool` is used instead of a `Standardize`
    /// since the automatic choice has already been resolved
    /// by the time this struct is built.
    pub(crate) standardize: bool,
    /// Mean of the regressor.
    pub(crate) mu: f64,
    /// Standard deviation of the regressor.
    pub(crate) std: f64,
}
182
183impl Default for RegressorScale {
184    fn default() -> Self {
185        Self {
186            standardize: false,
187            mu: 0.0,
188            std: 1.0,
189        }
190    }
191}
192
193/// An exogynous regressor.
194///
195/// By default, regressors inherit the `holidays_prior_scale`
196/// configured on the Prophet model as their prior scale.
197#[derive(Debug, Clone, Default)]
198pub struct Regressor {
199    pub(crate) mode: FeatureMode,
200    pub(crate) prior_scale: Option<PositiveFloat>,
201    pub(crate) standardize: Standardize,
202}
203
204impl Regressor {
205    /// Create a new additive regressor.
206    pub fn additive() -> Self {
207        Self {
208            mode: FeatureMode::Additive,
209            ..Default::default()
210        }
211    }
212
213    /// Create a new multiplicative regressor.
214    pub fn multiplicative() -> Self {
215        Self {
216            mode: FeatureMode::Multiplicative,
217            ..Default::default()
218        }
219    }
220
221    /// Set the prior scale of this regressor.
222    ///
223    /// By default, regressors inherit the `holidays_prior_scale`
224    /// configured on the Prophet model as their prior scale.
225    pub fn with_prior_scale(mut self, prior_scale: PositiveFloat) -> Self {
226        self.prior_scale = Some(prior_scale);
227        self
228    }
229
230    /// Set whether to standardize this regressor.
231    pub fn with_standardize(mut self, standardize: Standardize) -> Self {
232        self.standardize = standardize;
233        self
234    }
235}
236
237/// A seasonality to include in the model.
238#[derive(Debug, Clone)]
239pub struct Seasonality {
240    pub(crate) period: PositiveFloat,
241    pub(crate) fourier_order: NonZeroU32,
242    pub(crate) prior_scale: Option<PositiveFloat>,
243    pub(crate) mode: Option<FeatureMode>,
244    pub(crate) condition_name: Option<String>,
245}
246
247impl Seasonality {
248    /// Create a new `Seasonality` with the given period and fourier order.
249    ///
250    /// By default, the prior scale and mode will be inherited from the
251    /// Prophet model config, and the seasonality is assumed to be
252    /// non-conditional.
253    pub fn new(period: PositiveFloat, fourier_order: NonZeroU32) -> Self {
254        Self {
255            period,
256            fourier_order,
257            prior_scale: None,
258            mode: None,
259            condition_name: None,
260        }
261    }
262
263    /// Set the prior scale of this seasonality.
264    ///
265    /// By default, seasonalities inherit the prior scale
266    /// configured on the Prophet model; this allows the
267    /// prior scale to be customised for each seasonality.
268    pub fn with_prior_scale(mut self, prior_scale: PositiveFloat) -> Self {
269        self.prior_scale = Some(prior_scale);
270        self
271    }
272
273    /// Set the mode of this seasonality.
274    ///
275    /// By default, seasonalities inherit the mode
276    /// configured on the Prophet model; this allows the
277    /// mode to be customised for each seasonality.
278    pub fn with_mode(mut self, mode: FeatureMode) -> Self {
279        self.mode = Some(mode);
280        self
281    }
282
283    /// Set this seasonality as conditional.
284    ///
285    /// A column with the provided condition name must be
286    /// present in the data passed to Prophet otherwise
287    /// training will fail. This can be added with
288    /// [`TrainingData::with_seasonality_conditions`](crate::TrainingData::with_seasonality_conditions).
289    pub fn with_condition(mut self, condition_name: String) -> Self {
290        self.condition_name = Some(condition_name);
291        self
292    }
293}
294
#[cfg(test)]
mod test {
    use chrono::{FixedOffset, TimeZone, Utc};

    use super::floor_day;

    /// Unix timestamp of the given local date and time in the `tz` fixed offset.
    fn local_timestamp(
        tz: FixedOffset,
        (year, month, day): (i32, u32, u32),
        (hour, min, sec): (u32, u32, u32),
    ) -> i64 {
        tz.with_ymd_and_hms(year, month, day, hour, min, sec)
            .unwrap()
            .timestamp()
    }

    /// The date used by most of the tests below.
    const DATE: (i32, u32, u32) = (2024, 11, 21);

    #[test]
    fn floor_day_no_offset() {
        let midnight = Utc
            .with_ymd_and_hms(2024, 11, 21, 0, 0, 0)
            .unwrap()
            .timestamp();
        let afternoon = Utc
            .with_ymd_and_hms(2024, 11, 21, 15, 3, 12)
            .unwrap()
            .timestamp();

        // Midnight is already the start of the day.
        assert_eq!(floor_day(midnight, 0), midnight);
        assert_eq!(floor_day(afternoon, 0), midnight);
    }

    #[test]
    fn floor_day_positive_offset() {
        let tz = FixedOffset::east_opt(60 * 60 * 4).unwrap();
        let midnight = local_timestamp(tz, DATE, (0, 0, 0));
        let afternoon = local_timestamp(tz, DATE, (15, 3, 12));

        assert_eq!(floor_day(midnight, tz.local_minus_utc()), midnight);
        assert_eq!(floor_day(afternoon, tz.local_minus_utc()), midnight);
    }

    #[test]
    fn floor_day_negative_offset() {
        let tz = FixedOffset::west_opt(60 * 60 * 3).unwrap();
        let midnight = local_timestamp(tz, DATE, (0, 0, 0));
        let afternoon = local_timestamp(tz, DATE, (15, 3, 12));

        assert_eq!(floor_day(midnight, tz.local_minus_utc()), midnight);
        assert_eq!(floor_day(afternoon, tz.local_minus_utc()), midnight);
    }

    #[test]
    fn floor_day_edge_cases() {
        // The maximum valid offset (UTC+14).
        let tz = FixedOffset::east_opt(14 * 60 * 60).unwrap();
        let midnight = local_timestamp(tz, DATE, (0, 0, 0));
        let noon = local_timestamp(tz, DATE, (12, 0, 0));
        assert_eq!(floor_day(noon, tz.local_minus_utc()), midnight);

        // One second before the next day boundary.
        let tz = FixedOffset::east_opt(60).unwrap();
        let midnight = local_timestamp(tz, DATE, (0, 0, 0));
        let last_second = local_timestamp(tz, DATE, (23, 59, 59));
        assert_eq!(floor_day(last_second, tz.local_minus_utc()), midnight);

        // A day before the Unix epoch.
        let tz = FixedOffset::west_opt(3600).unwrap();
        let midnight = local_timestamp(tz, (1969, 1, 1), (0, 0, 0));
        let half_past = local_timestamp(tz, (1969, 1, 1), (0, 30, 0));
        assert_eq!(floor_day(half_past, tz.local_minus_utc()), midnight);
    }
}
417}