Skip to main content

augurs_prophet/prophet/
options.rs

1//! Options to configure a Prophet model.
2//!
3//! These correspond very closely to the options in the Python
4//! implementation, but are not identical; some have been updated
5//! to be more idiomatic Rust.
6
7use std::{collections::HashMap, num::NonZeroU32};
8
9use crate::{Error, FeatureMode, Holiday, PositiveFloat, TimestampSeconds, TrendIndicator};
10
/// The shape of the growth (trend) component of the model.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum GrowthType {
    /// The trend grows linearly. This is the default.
    #[default]
    Linear,
    /// The trend follows logistic growth.
    Logistic,
    /// The trend is flat.
    Flat,
}
22
23impl From<GrowthType> for TrendIndicator {
24    fn from(value: GrowthType) -> Self {
25        match value {
26            GrowthType::Linear => TrendIndicator::Linear,
27            GrowthType::Logistic => TrendIndicator::Logistic,
28            GrowthType::Flat => TrendIndicator::Flat,
29        }
30    }
31}
32
/// Controls whether a particular seasonality is included in the model,
/// and how it is specified when it is.
#[derive(Debug, Clone, Copy, Default)]
pub enum SeasonalityOption {
    /// Let the model decide whether this seasonality is included.
    ///
    /// The automatic rules are:
    ///
    /// - yearly seasonality is included if there is >=2 years of history;
    /// - weekly seasonality is included if there is >=2 weeks of history
    ///   and the spacing between the dates in the data is <7 days;
    /// - daily seasonality is included if there is >=2 days of history
    ///   and the spacing between the dates in the data is <1 day.
    #[default]
    Auto,
    /// Explicitly state whether this seasonality is included.
    Manual(bool),
    /// Include this seasonality, using the given number of Fourier terms.
    Fourier(NonZeroU32),
}
55
/// The scaling applied to the data before the model is fitted.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum Scaling {
    /// Abs-max scaling. This is the default.
    #[default]
    AbsMax,
    /// Min-max scaling.
    MinMax,
}
65
/// The method used to estimate the model parameters.
///
/// Note: only MLE/MAP estimation is supported for now; MCMC sampling
/// is not available yet but will be added in the future. The enum is
/// marked as `non_exhaustive` until then.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
#[non_exhaustive]
pub enum EstimationMode {
    /// Maximum likelihood estimation (MLE). This is the default.
    #[default]
    Mle,
    /// Maximum a posteriori (MAP) estimation.
    Map,
    // Not implemented yet: a new `Sampler` trait has to be added and
    // implemented, and prediction has to handle the different number of
    // outputs, before the variant below can be enabled.
    //
    // /// Do full Bayesian inference with the specified number of MCMC samples.
    // Mcmc(u32),
}
86
87/// The width of the uncertainty intervals.
88///
89/// Must be between `0.0` and `1.0`. Common values are
90/// `0.8` (80%), `0.9` (90%) and `0.95` (95%).
91///
92/// Defaults to `0.8` for 80% intervals.
93#[derive(Debug, Clone, Copy, PartialEq, PartialOrd)]
94pub struct IntervalWidth(f64);
95
96impl Default for IntervalWidth {
97    fn default() -> Self {
98        Self(0.8)
99    }
100}
101
102impl IntervalWidth {
103    /// Attempt to create a new `IntervalWidth `.
104    ///
105    /// # Errors
106    ///
107    /// Returns an error if the provided float is less than 0.0.
108    pub fn try_new(f: f64) -> Result<Self, Error> {
109        if !f.is_finite() || f <= 0.0 {
110            return Err(Error::InvalidIntervalWidth(f));
111        }
112        Ok(Self(f))
113    }
114}
115
116impl TryFrom<f64> for IntervalWidth {
117    type Error = Error;
118    fn try_from(value: f64) -> Result<Self, Self::Error> {
119        Self::try_new(value)
120    }
121}
122
123impl std::ops::Deref for IntervalWidth {
124    type Target = f64;
125    fn deref(&self) -> &Self::Target {
126        &self.0
127    }
128}
129
130impl From<IntervalWidth> for f64 {
131    fn from(value: IntervalWidth) -> Self {
132        value.0
133    }
134}
135
136// TODO: consider getting rid of this? It's a bit weird, but it might
137// make it easier for users of the crate...
138/// Optional version of Prophet's options, before applying any defaults.
139#[derive(Default, Debug, Clone)]
140pub struct OptProphetOptions {
141    /// The type of growth (trend) to use.
142    pub growth: Option<GrowthType>,
143
144    /// An optional list of changepoints.
145    ///
146    /// If not provided, changepoints will be automatically selected.
147    pub changepoints: Option<Vec<TimestampSeconds>>,
148
149    /// The number of potential changepoints to include.
150    ///
151    /// Not used if `changepoints` is provided.
152    ///
153    /// If provided and `changepoints` is not provided, then
154    /// `n_changepoints` potential changepoints will be selected
155    /// uniformly from the first `changepoint_range` proportion of
156    /// the history.
157    pub n_changepoints: Option<u32>,
158
159    /// The proportion of the history to consider for potential changepoints.
160    ///
161    /// Not used if `changepoints` is provided.
162    pub changepoint_range: Option<PositiveFloat>,
163
164    /// How to fit yearly seasonality.
165    pub yearly_seasonality: Option<SeasonalityOption>,
166    /// How to fit weekly seasonality.
167    pub weekly_seasonality: Option<SeasonalityOption>,
168    /// How to fit daily seasonality.
169    pub daily_seasonality: Option<SeasonalityOption>,
170
171    /// How to model seasonality.
172    pub seasonality_mode: Option<FeatureMode>,
173
174    /// The prior scale for seasonality.
175    ///
176    /// This modulates the strength of seasonality,
177    /// with larger values allowing the model to fit
178    /// larger seasonal fluctuations and smaller values
179    /// dampening the seasonality.
180    ///
181    /// Can be specified for individual seasonalities
182    /// using [`Prophet::add_seasonality`](crate::Prophet::add_seasonality).
183    pub seasonality_prior_scale: Option<PositiveFloat>,
184
185    /// The prior scale for changepoints.
186    ///
187    /// This modulates the flexibility of the automatic
188    /// changepoint selection. Large values will allow many
189    /// changepoints, while small values will allow few
190    /// changepoints.
191    pub changepoint_prior_scale: Option<PositiveFloat>,
192
193    /// How to perform parameter estimation.
194    ///
195    /// When [`EstimationMode::Mle`] or [`EstimationMode::Map`]
196    /// are used then no MCMC samples are taken.
197    pub estimation: Option<EstimationMode>,
198
199    /// The width of the uncertainty intervals.
200    ///
201    /// Must be between `0.0` and `1.0`. Common values are
202    /// `0.8` (80%), `0.9` (90%) and `0.95` (95%).
203    pub interval_width: Option<IntervalWidth>,
204
205    /// The number of simulated draws used to estimate uncertainty intervals.
206    ///
207    /// Setting this value to `0` will disable uncertainty
208    /// estimation and speed up the calculation.
209    pub uncertainty_samples: Option<u32>,
210
211    /// How to scale the data prior to fitting the model.
212    pub scaling: Option<Scaling>,
213
214    /// Holidays to include in the model.
215    pub holidays: Option<HashMap<String, Holiday>>,
216    /// Prior scale for holidays.
217    ///
218    /// This parameter modulates the strength of the holiday
219    /// components model, unless overridden in each individual
220    /// holiday's input.
221    pub holidays_prior_scale: Option<PositiveFloat>,
222
223    /// How to model holidays.
224    pub holidays_mode: Option<FeatureMode>,
225}
226
227impl From<OptProphetOptions> for ProphetOptions {
228    fn from(value: OptProphetOptions) -> Self {
229        let defaults = ProphetOptions::default();
230        ProphetOptions {
231            growth: value.growth.unwrap_or(defaults.growth),
232            changepoints: value.changepoints,
233            n_changepoints: value.n_changepoints.unwrap_or(defaults.n_changepoints),
234            changepoint_range: value
235                .changepoint_range
236                .unwrap_or(defaults.changepoint_range),
237            yearly_seasonality: value
238                .yearly_seasonality
239                .unwrap_or(defaults.yearly_seasonality),
240            weekly_seasonality: value
241                .weekly_seasonality
242                .unwrap_or(defaults.weekly_seasonality),
243            daily_seasonality: value
244                .daily_seasonality
245                .unwrap_or(defaults.daily_seasonality),
246            seasonality_mode: value.seasonality_mode.unwrap_or(defaults.seasonality_mode),
247            seasonality_prior_scale: value
248                .seasonality_prior_scale
249                .unwrap_or(defaults.seasonality_prior_scale),
250            changepoint_prior_scale: value
251                .changepoint_prior_scale
252                .unwrap_or(defaults.changepoint_prior_scale),
253            estimation: value.estimation.unwrap_or(defaults.estimation),
254            interval_width: value.interval_width.unwrap_or(defaults.interval_width),
255            uncertainty_samples: value
256                .uncertainty_samples
257                .unwrap_or(defaults.uncertainty_samples),
258            scaling: value.scaling.unwrap_or(defaults.scaling),
259            holidays: value.holidays.unwrap_or(defaults.holidays),
260            holidays_prior_scale: value
261                .holidays_prior_scale
262                .unwrap_or(defaults.holidays_prior_scale),
263            holidays_mode: value.holidays_mode,
264        }
265    }
266}
267
268/// Options for Prophet, after applying defaults.
269#[derive(Debug, Clone)]
270pub struct ProphetOptions {
271    /// The type of growth (trend) to use.
272    ///
273    /// Defaults to [`GrowthType::Linear`].
274    pub growth: GrowthType,
275
276    /// An optional list of changepoints.
277    ///
278    /// If not provided, changepoints will be automatically selected.
279    pub changepoints: Option<Vec<TimestampSeconds>>,
280
281    /// The number of potential changepoints to include.
282    ///
283    /// Not used if `changepoints` is provided.
284    ///
285    /// If provided and `changepoints` is not provided, then
286    /// `n_changepoints` potential changepoints will be selected
287    /// uniformly from the first `changepoint_range` proportion of
288    /// the history.
289    ///
290    /// Defaults to 25.
291    pub n_changepoints: u32,
292
293    /// The proportion of the history to consider for potential changepoints.
294    ///
295    /// Not used if `changepoints` is provided.
296    ///
297    /// Defaults to `0.8` for the first 80% of the data.
298    pub changepoint_range: PositiveFloat,
299
300    /// How to fit yearly seasonality.
301    ///
302    /// Defaults to [`SeasonalityOption::Auto`].
303    pub yearly_seasonality: SeasonalityOption,
304    /// How to fit weekly seasonality.
305    ///
306    /// Defaults to [`SeasonalityOption::Auto`].
307    pub weekly_seasonality: SeasonalityOption,
308    /// How to fit daily seasonality.
309    ///
310    /// Defaults to [`SeasonalityOption::Auto`].
311    pub daily_seasonality: SeasonalityOption,
312
313    /// How to model seasonality.
314    ///
315    /// Defaults to [`FeatureMode::Additive`].
316    pub seasonality_mode: FeatureMode,
317
318    /// The prior scale for seasonality.
319    ///
320    /// This modulates the strength of seasonality,
321    /// with larger values allowing the model to fit
322    /// larger seasonal fluctuations and smaller values
323    /// dampening the seasonality.
324    ///
325    /// Can be specified for individual seasonalities
326    /// using [`Prophet::add_seasonality`](crate::Prophet::add_seasonality).
327    ///
328    /// Defaults to `10.0`.
329    pub seasonality_prior_scale: PositiveFloat,
330
331    /// The prior scale for changepoints.
332    ///
333    /// This modulates the flexibility of the automatic
334    /// changepoint selection. Large values will allow many
335    /// changepoints, while small values will allow few
336    /// changepoints.
337    ///
338    /// Defaults to `0.05`.
339    pub changepoint_prior_scale: PositiveFloat,
340
341    /// How to perform parameter estimation.
342    ///
343    /// When [`EstimationMode::Mle`] or [`EstimationMode::Map`]
344    /// are used then no MCMC samples are taken.
345    ///
346    /// Defaults to [`EstimationMode::Mle`].
347    pub estimation: EstimationMode,
348
349    /// The width of the uncertainty intervals.
350    ///
351    /// Must be between `0.0` and `1.0`. Common values are
352    /// `0.8` (80%), `0.9` (90%) and `0.95` (95%).
353    ///
354    /// Defaults to `0.8` for 80% intervals.
355    pub interval_width: IntervalWidth,
356
357    /// The number of simulated draws used to estimate uncertainty intervals.
358    ///
359    /// Setting this value to `0` will disable uncertainty
360    /// estimation and speed up the calculation.
361    ///
362    /// Defaults to `1000`.
363    pub uncertainty_samples: u32,
364
365    /// How to scale the data prior to fitting the model.
366    ///
367    /// Defaults to [`Scaling::AbsMax`].
368    pub scaling: Scaling,
369
370    /// Holidays to include in the model.
371    pub holidays: HashMap<String, Holiday>,
372    /// Prior scale for holidays.
373    ///
374    /// This parameter modulates the strength of the holiday
375    /// components model, unless overridden in each individual
376    /// holiday's input.
377    ///
378    /// Defaults to `100.0`.
379    pub holidays_prior_scale: PositiveFloat,
380
381    /// How to model holidays.
382    ///
383    /// Defaults to the same value as [`ProphetOptions::seasonality_mode`].
384    pub holidays_mode: Option<FeatureMode>,
385}
386
387impl Default for ProphetOptions {
388    fn default() -> Self {
389        Self {
390            growth: GrowthType::Linear,
391            changepoints: None,
392            n_changepoints: 25,
393            changepoint_range: 0.8.try_into().unwrap(),
394            yearly_seasonality: SeasonalityOption::default(),
395            weekly_seasonality: SeasonalityOption::default(),
396            daily_seasonality: SeasonalityOption::default(),
397            seasonality_mode: FeatureMode::Additive,
398            seasonality_prior_scale: 10.0.try_into().unwrap(),
399            changepoint_prior_scale: 0.05.try_into().unwrap(),
400            estimation: EstimationMode::Mle,
401            interval_width: IntervalWidth::default(),
402            uncertainty_samples: 1000,
403            scaling: Scaling::AbsMax,
404            holidays: HashMap::new(),
405            holidays_prior_scale: 100.0.try_into().unwrap(),
406            holidays_mode: None,
407        }
408    }
409}