augurs_prophet/features.rs
1//! Features used by Prophet, such as seasonality, regressors and holidays.
2use std::num::NonZeroU32;
3
4use crate::{
5 positive_float::PositiveFloat, prophet::prep::ONE_DAY_IN_SECONDS_INT, TimestampSeconds,
6};
7
/// The mode in which a seasonality, regressor, or holiday is applied.
///
/// [`FeatureMode::Additive`] is the default.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum FeatureMode {
    /// Apply the feature in additive mode (the default).
    #[default]
    Additive,
    /// Apply the feature in multiplicative mode.
    Multiplicative,
}
17
18/// An occurrence of a holiday.
19///
20/// Each occurrence has a start and end time represented as
21/// a Unix timestamp. Holiday occurrences are therefore
22/// timestamp-unaware and can therefore span multiple days
23/// or even sub-daily periods.
24///
25/// This differs from the Python and R Prophet implementations,
26/// which require all holidays to be day-long events. Some
27/// convenience methods are provided to create day-long
28/// occurrences: see [`HolidayOccurrence::for_day`] and
29/// [`HolidayOccurrence::for_day_in_tz`].
30///
31/// The caller is responsible for ensuring that the start
32/// and end time provided are in the correct timezone.
33/// One way to do this is to use [`chrono::FixedOffset`][fo]
34/// to create an offset representing the time zone,
35/// [`FixedOffset::with_ymd_and_hms`][wyah] to create a
36/// [`DateTime`][dt] in that time zone, then [`DateTime::timestamp`][ts]
37/// to get the Unix timestamp.
38///
39/// [fo]: https://docs.rs/chrono/latest/chrono/struct.FixedOffset.html
40/// [wyah]: https://docs.rs/chrono/latest/chrono/struct.FixedOffset.html#method.with_ymd_and_hms
41/// [dt]: https://docs.rs/chrono/latest/chrono/struct.DateTime.html
42/// [ts]: https://docs.rs/chrono/latest/chrono/struct.DateTime.html#method.timestamp
43#[derive(Debug, Clone)]
44pub struct HolidayOccurrence {
45 pub(crate) start: TimestampSeconds,
46 pub(crate) end: TimestampSeconds,
47}
48
49impl HolidayOccurrence {
50 /// Create a new holiday occurrence with the given
51 /// start and end timestamp.
52 pub fn new(start: TimestampSeconds, end: TimestampSeconds) -> Self {
53 Self { start, end }
54 }
55
56 /// Create a new holiday encompassing midnight on the day
57 /// of the given timestamp to midnight on the following day,
58 /// in UTC.
59 ///
60 /// This is a convenience method to reproduce the behaviour
61 /// of the Python and R Prophet implementations, which require
62 /// all holidays to be day-long events.
63 ///
64 /// Note that this will _not_ handle daylight saving time
65 /// transitions correctly. To handle this correctly, use
66 /// [`HolidayOccurrence::new`] with the correct start and
67 /// end times, e.g. by calculating them using [`chrono`].
68 ///
69 /// [`chrono`]: https://docs.rs/chrono/latest/chrono
70 pub fn for_day(day: TimestampSeconds) -> Self {
71 Self::for_day_in_tz(day, 0)
72 }
73
74 /// Create a new holiday encompassing midnight on the day
75 /// of the given timestamp to midnight on the following day,
76 /// in a timezone represented by the `utc_offset_seconds`.
77 ///
78 /// The UTC offset can be calculated using, for example,
79 /// [`chrono::FixedOffset::local_minus_utc`][lmu]. Alternatively
80 /// it's the number of seconds to add to convert from the
81 /// local time to UTC, so UTC+1 is represented by `3600`
82 /// and UTC-5 by `-18000`.
83 ///
84 /// This is a convenience method to reproduce the behaviour
85 /// of the Python and R Prophet implementations, which require
86 /// all holidays to be day-long events.
87 ///
88 /// Note that this will _not_ handle daylight saving time
89 /// transitions correctly. To handle this correctly, use
90 /// [`HolidayOccurrence::new`] with the correct start and
91 /// end times, e.g. by calculating them using [`chrono`].
92 ///
93 /// [`chrono`]: https://docs.rs/chrono/latest/chrono
94 /// [lmu]: https://docs.rs/chrono/latest/chrono/struct.FixedOffset.html#method.local_minus_utc
95 pub fn for_day_in_tz(day: TimestampSeconds, utc_offset_seconds: i32) -> Self {
96 let day = floor_day(day, utc_offset_seconds);
97 Self {
98 start: day,
99 end: day + ONE_DAY_IN_SECONDS_INT,
100 }
101 }
102
103 /// Check if the given timestamp is within this occurrence.
104 pub(crate) fn contains(&self, ds: TimestampSeconds) -> bool {
105 self.start <= ds && ds < self.end
106 }
107}
108
109/// A holiday to be considered by the Prophet model.
110#[derive(Debug, Clone)]
111pub struct Holiday {
112 pub(crate) occurrences: Vec<HolidayOccurrence>,
113 pub(crate) prior_scale: Option<PositiveFloat>,
114}
115
116impl Holiday {
117 /// Create a new holiday with the given occurrences.
118 pub fn new(occurrences: Vec<HolidayOccurrence>) -> Self {
119 Self {
120 occurrences,
121 prior_scale: None,
122 }
123 }
124
125 /// Set the prior scale for the holiday.
126 pub fn with_prior_scale(mut self, prior_scale: PositiveFloat) -> Self {
127 self.prior_scale = Some(prior_scale);
128 self
129 }
130}
131
132fn floor_day(ds: TimestampSeconds, offset: i32) -> TimestampSeconds {
133 let adjusted_ds = ds + offset as TimestampSeconds;
134 let remainder =
135 ((adjusted_ds % ONE_DAY_IN_SECONDS_INT) + ONE_DAY_IN_SECONDS_INT) % ONE_DAY_IN_SECONDS_INT;
136 // Adjust the date to the holiday's UTC offset.
137 ds - remainder
138}
139
/// Controls whether a regressor is standardized.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum Standardize {
    /// Decide automatically whether to standardize (the default).
    ///
    /// Numeric regressors will be standardized while
    /// binary regressors will not.
    #[default]
    Auto,
    /// Always standardize this regressor.
    Yes,
    /// Never standardize this regressor.
    No,
}

impl From<bool> for Standardize {
    /// Map `true` to [`Standardize::Yes`] and `false` to [`Standardize::No`].
    fn from(b: bool) -> Self {
        match b {
            true => Self::Yes,
            false => Self::No,
        }
    }
}
164
/// Scaling information for a single regressor.
///
/// This will be inserted into [`Scales::extra_regressors`]
/// if the regressor is standardized.
#[derive(Clone, Debug)]
pub(crate) struct RegressorScale {
    /// Whether this regressor is standardized.
    ///
    /// A plain `bool` is used instead of a `Standardize`
    /// because the automatic choice has already been
    /// resolved by the time this value is constructed.
    pub(crate) standardize: bool,
    /// Mean of the regressor.
    pub(crate) mu: f64,
    /// Standard deviation of the regressor.
    pub(crate) std: f64,
}

impl Default for RegressorScale {
    /// A scale that is not standardized, with mean `0.0` and
    /// standard deviation `1.0`.
    fn default() -> Self {
        let (mu, std) = (0.0, 1.0);
        RegressorScale {
            standardize: false,
            mu,
            std,
        }
    }
}
192
193/// An exogynous regressor.
194///
195/// By default, regressors inherit the `holidays_prior_scale`
196/// configured on the Prophet model as their prior scale.
197#[derive(Debug, Clone, Default)]
198pub struct Regressor {
199 pub(crate) mode: FeatureMode,
200 pub(crate) prior_scale: Option<PositiveFloat>,
201 pub(crate) standardize: Standardize,
202}
203
204impl Regressor {
205 /// Create a new additive regressor.
206 pub fn additive() -> Self {
207 Self {
208 mode: FeatureMode::Additive,
209 ..Default::default()
210 }
211 }
212
213 /// Create a new multiplicative regressor.
214 pub fn multiplicative() -> Self {
215 Self {
216 mode: FeatureMode::Multiplicative,
217 ..Default::default()
218 }
219 }
220
221 /// Set the prior scale of this regressor.
222 ///
223 /// By default, regressors inherit the `holidays_prior_scale`
224 /// configured on the Prophet model as their prior scale.
225 pub fn with_prior_scale(mut self, prior_scale: PositiveFloat) -> Self {
226 self.prior_scale = Some(prior_scale);
227 self
228 }
229
230 /// Set whether to standardize this regressor.
231 pub fn with_standardize(mut self, standardize: Standardize) -> Self {
232 self.standardize = standardize;
233 self
234 }
235}
236
237/// A seasonality to include in the model.
238#[derive(Debug, Clone)]
239pub struct Seasonality {
240 pub(crate) period: PositiveFloat,
241 pub(crate) fourier_order: NonZeroU32,
242 pub(crate) prior_scale: Option<PositiveFloat>,
243 pub(crate) mode: Option<FeatureMode>,
244 pub(crate) condition_name: Option<String>,
245}
246
247impl Seasonality {
248 /// Create a new `Seasonality` with the given period and fourier order.
249 ///
250 /// By default, the prior scale and mode will be inherited from the
251 /// Prophet model config, and the seasonality is assumed to be
252 /// non-conditional.
253 pub fn new(period: PositiveFloat, fourier_order: NonZeroU32) -> Self {
254 Self {
255 period,
256 fourier_order,
257 prior_scale: None,
258 mode: None,
259 condition_name: None,
260 }
261 }
262
263 /// Set the prior scale of this seasonality.
264 ///
265 /// By default, seasonalities inherit the prior scale
266 /// configured on the Prophet model; this allows the
267 /// prior scale to be customised for each seasonality.
268 pub fn with_prior_scale(mut self, prior_scale: PositiveFloat) -> Self {
269 self.prior_scale = Some(prior_scale);
270 self
271 }
272
273 /// Set the mode of this seasonality.
274 ///
275 /// By default, seasonalities inherit the mode
276 /// configured on the Prophet model; this allows the
277 /// mode to be customised for each seasonality.
278 pub fn with_mode(mut self, mode: FeatureMode) -> Self {
279 self.mode = Some(mode);
280 self
281 }
282
283 /// Set this seasonality as conditional.
284 ///
285 /// A column with the provided condition name must be
286 /// present in the data passed to Prophet otherwise
287 /// training will fail. This can be added with
288 /// [`TrainingData::with_seasonality_conditions`](crate::TrainingData::with_seasonality_conditions).
289 pub fn with_condition(mut self, condition_name: String) -> Self {
290 self.condition_name = Some(condition_name);
291 self
292 }
293}
294
#[cfg(test)]
mod test {
    use chrono::{FixedOffset, TimeZone, Utc};

    use super::floor_day;

    /// Unix timestamp of the given wall-clock date and time in `tz`.
    fn timestamp_in<Tz: TimeZone>(tz: &Tz, date: (i32, u32, u32), time: (u32, u32, u32)) -> i64 {
        let (year, month, day) = date;
        let (hour, min, sec) = time;
        tz.with_ymd_and_hms(year, month, day, hour, min, sec)
            .unwrap()
            .timestamp()
    }

    #[test]
    fn floor_day_no_offset() {
        let tz = Utc;
        let midnight = timestamp_in(&tz, (2024, 11, 21), (0, 0, 0));
        let afternoon = timestamp_in(&tz, (2024, 11, 21), (15, 3, 12));

        // Midnight floors to itself; a later time floors back to midnight.
        assert_eq!(floor_day(midnight, 0), midnight);
        assert_eq!(floor_day(afternoon, 0), midnight);
    }

    #[test]
    fn floor_day_positive_offset() {
        let tz = FixedOffset::east_opt(60 * 60 * 4).unwrap();
        let midnight = timestamp_in(&tz, (2024, 11, 21), (0, 0, 0));
        let afternoon = timestamp_in(&tz, (2024, 11, 21), (15, 3, 12));

        assert_eq!(floor_day(midnight, tz.local_minus_utc()), midnight);
        assert_eq!(floor_day(afternoon, tz.local_minus_utc()), midnight);
    }

    #[test]
    fn floor_day_negative_offset() {
        let tz = FixedOffset::west_opt(60 * 60 * 3).unwrap();
        let midnight = timestamp_in(&tz, (2024, 11, 21), (0, 0, 0));
        let afternoon = timestamp_in(&tz, (2024, 11, 21), (15, 3, 12));

        assert_eq!(floor_day(midnight, tz.local_minus_utc()), midnight);
        assert_eq!(floor_day(afternoon, tz.local_minus_utc()), midnight);
    }

    #[test]
    fn floor_day_edge_cases() {
        // Test maximum valid offset (UTC+14)
        {
            let max_offset = 14 * 60 * 60;
            let tz = FixedOffset::east_opt(max_offset).unwrap();
            let midnight = timestamp_in(&tz, (2024, 11, 21), (0, 0, 0));
            let noon = timestamp_in(&tz, (2024, 11, 21), (12, 0, 0));
            assert_eq!(floor_day(noon, tz.local_minus_utc()), midnight);
        }

        // Test near day boundary
        {
            let tz = FixedOffset::east_opt(60).unwrap();
            let midnight = timestamp_in(&tz, (2024, 11, 21), (0, 0, 0));
            let last_second = timestamp_in(&tz, (2024, 11, 21), (23, 59, 59));
            assert_eq!(floor_day(last_second, tz.local_minus_utc()), midnight);
        }

        // Test when the day is before the epoch.
        {
            let tz = FixedOffset::west_opt(3600).unwrap();
            let midnight = timestamp_in(&tz, (1969, 1, 1), (0, 0, 0));
            let half_past = timestamp_in(&tz, (1969, 1, 1), (0, 30, 0));
            assert_eq!(floor_day(half_past, tz.local_minus_utc()), midnight);
        }
    }
}
417}