augurs_prophet/prophet/options.rs
1//! Options to configure a Prophet model.
2//!
3//! These correspond very closely to the options in the Python
4//! implementation, but are not identical; some have been updated
5//! to be more idiomatic Rust.
6
7use std::{collections::HashMap, num::NonZeroU32};
8
9use crate::{Error, FeatureMode, Holiday, PositiveFloat, TimestampSeconds, TrendIndicator};
10
/// Which trend (growth) model to fit.
///
/// When nothing else is requested, [`GrowthType::Linear`] is used.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum GrowthType {
    /// A linear trend. This is the default.
    #[default]
    Linear,
    /// A logistic trend.
    Logistic,
    /// A flat trend.
    Flat,
}
22
23impl From<GrowthType> for TrendIndicator {
24 fn from(value: GrowthType) -> Self {
25 match value {
26 GrowthType::Linear => TrendIndicator::Linear,
27 GrowthType::Logistic => TrendIndicator::Logistic,
28 GrowthType::Flat => TrendIndicator::Flat,
29 }
30 }
31}
32
/// Define whether to include a specific seasonality, and how it should be specified.
///
/// Like the other option enums in this module, this type can be compared
/// with `==`, which lets callers inspect a configured value.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum SeasonalityOption {
    /// Automatically determine whether to include this seasonality.
    ///
    /// Yearly seasonality is automatically included if there is >=2
    /// years of history.
    ///
    /// Weekly seasonality is automatically included if there is >=2
    /// weeks of history, and the spacing between the dates in the
    /// data is <7 days.
    ///
    /// Daily seasonality is automatically included if there is >=2
    /// days of history, and the spacing between the dates in the
    /// data is <1 day.
    #[default]
    Auto,
    /// Manually specify whether to include this seasonality.
    Manual(bool),
    /// Enable this seasonality and use the provided number of Fourier terms.
    Fourier(NonZeroU32),
}
55
/// The scaling applied to the data before the model is fitted.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum Scaling {
    /// Abs-max scaling. This is the default.
    #[default]
    AbsMax,
    /// Min-max scaling.
    MinMax,
}
65
/// The parameter estimation strategy.
///
/// Only point estimation (MLE or MAP) is currently available; full MCMC
/// sampling is not supported yet. Until it is, the enum is marked
/// `#[non_exhaustive]` so that a sampling variant can be added later
/// without a breaking change.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum EstimationMode {
    /// Maximum likelihood estimation (MLE). This is the default.
    #[default]
    Mle,
    /// Maximum a posteriori (MAP) estimation.
    Map,
    // Planned but not yet implemented: a variant such as `Mcmc(u32)` for full
    // Bayesian inference with the given number of MCMC samples. Enabling it
    // requires a new `Sampler` trait plus handling of the different number of
    // outputs produced when predicting.
}
86
/// Width of the uncertainty intervals.
///
/// The value must lie between `0.0` and `1.0`. Typical choices are
/// `0.8` (80%), `0.9` (90%) and `0.95` (95%).
///
/// The default is `0.8`, i.e. 80% intervals.
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd)]
pub struct IntervalWidth(f64);

impl Default for IntervalWidth {
    /// Returns an 80% interval width.
    fn default() -> Self {
        const DEFAULT_WIDTH: f64 = 0.8;
        Self(DEFAULT_WIDTH)
    }
}
101
102impl IntervalWidth {
103 /// Attempt to create a new `IntervalWidth `.
104 ///
105 /// # Errors
106 ///
107 /// Returns an error if the provided float is less than 0.0.
108 pub fn try_new(f: f64) -> Result<Self, Error> {
109 if !f.is_finite() || f <= 0.0 {
110 return Err(Error::InvalidIntervalWidth(f));
111 }
112 Ok(Self(f))
113 }
114}
115
116impl TryFrom<f64> for IntervalWidth {
117 type Error = Error;
118 fn try_from(value: f64) -> Result<Self, Self::Error> {
119 Self::try_new(value)
120 }
121}
122
123impl std::ops::Deref for IntervalWidth {
124 type Target = f64;
125 fn deref(&self) -> &Self::Target {
126 &self.0
127 }
128}
129
130impl From<IntervalWidth> for f64 {
131 fn from(value: IntervalWidth) -> Self {
132 value.0
133 }
134}
135
136// TODO: consider getting rid of this? It's a bit weird, but it might
137// make it easier for users of the crate...
138/// Optional version of Prophet's options, before applying any defaults.
139#[derive(Default, Debug, Clone)]
140pub struct OptProphetOptions {
141 /// The type of growth (trend) to use.
142 pub growth: Option<GrowthType>,
143
144 /// An optional list of changepoints.
145 ///
146 /// If not provided, changepoints will be automatically selected.
147 pub changepoints: Option<Vec<TimestampSeconds>>,
148
149 /// The number of potential changepoints to include.
150 ///
151 /// Not used if `changepoints` is provided.
152 ///
153 /// If provided and `changepoints` is not provided, then
154 /// `n_changepoints` potential changepoints will be selected
155 /// uniformly from the first `changepoint_range` proportion of
156 /// the history.
157 pub n_changepoints: Option<u32>,
158
159 /// The proportion of the history to consider for potential changepoints.
160 ///
161 /// Not used if `changepoints` is provided.
162 pub changepoint_range: Option<PositiveFloat>,
163
164 /// How to fit yearly seasonality.
165 pub yearly_seasonality: Option<SeasonalityOption>,
166 /// How to fit weekly seasonality.
167 pub weekly_seasonality: Option<SeasonalityOption>,
168 /// How to fit daily seasonality.
169 pub daily_seasonality: Option<SeasonalityOption>,
170
171 /// How to model seasonality.
172 pub seasonality_mode: Option<FeatureMode>,
173
174 /// The prior scale for seasonality.
175 ///
176 /// This modulates the strength of seasonality,
177 /// with larger values allowing the model to fit
178 /// larger seasonal fluctuations and smaller values
179 /// dampening the seasonality.
180 ///
181 /// Can be specified for individual seasonalities
182 /// using [`Prophet::add_seasonality`](crate::Prophet::add_seasonality).
183 pub seasonality_prior_scale: Option<PositiveFloat>,
184
185 /// The prior scale for changepoints.
186 ///
187 /// This modulates the flexibility of the automatic
188 /// changepoint selection. Large values will allow many
189 /// changepoints, while small values will allow few
190 /// changepoints.
191 pub changepoint_prior_scale: Option<PositiveFloat>,
192
193 /// How to perform parameter estimation.
194 ///
195 /// When [`EstimationMode::Mle`] or [`EstimationMode::Map`]
196 /// are used then no MCMC samples are taken.
197 pub estimation: Option<EstimationMode>,
198
199 /// The width of the uncertainty intervals.
200 ///
201 /// Must be between `0.0` and `1.0`. Common values are
202 /// `0.8` (80%), `0.9` (90%) and `0.95` (95%).
203 pub interval_width: Option<IntervalWidth>,
204
205 /// The number of simulated draws used to estimate uncertainty intervals.
206 ///
207 /// Setting this value to `0` will disable uncertainty
208 /// estimation and speed up the calculation.
209 pub uncertainty_samples: Option<u32>,
210
211 /// How to scale the data prior to fitting the model.
212 pub scaling: Option<Scaling>,
213
214 /// Holidays to include in the model.
215 pub holidays: Option<HashMap<String, Holiday>>,
216 /// Prior scale for holidays.
217 ///
218 /// This parameter modulates the strength of the holiday
219 /// components model, unless overridden in each individual
220 /// holiday's input.
221 pub holidays_prior_scale: Option<PositiveFloat>,
222
223 /// How to model holidays.
224 pub holidays_mode: Option<FeatureMode>,
225}
226
227impl From<OptProphetOptions> for ProphetOptions {
228 fn from(value: OptProphetOptions) -> Self {
229 let defaults = ProphetOptions::default();
230 ProphetOptions {
231 growth: value.growth.unwrap_or(defaults.growth),
232 changepoints: value.changepoints,
233 n_changepoints: value.n_changepoints.unwrap_or(defaults.n_changepoints),
234 changepoint_range: value
235 .changepoint_range
236 .unwrap_or(defaults.changepoint_range),
237 yearly_seasonality: value
238 .yearly_seasonality
239 .unwrap_or(defaults.yearly_seasonality),
240 weekly_seasonality: value
241 .weekly_seasonality
242 .unwrap_or(defaults.weekly_seasonality),
243 daily_seasonality: value
244 .daily_seasonality
245 .unwrap_or(defaults.daily_seasonality),
246 seasonality_mode: value.seasonality_mode.unwrap_or(defaults.seasonality_mode),
247 seasonality_prior_scale: value
248 .seasonality_prior_scale
249 .unwrap_or(defaults.seasonality_prior_scale),
250 changepoint_prior_scale: value
251 .changepoint_prior_scale
252 .unwrap_or(defaults.changepoint_prior_scale),
253 estimation: value.estimation.unwrap_or(defaults.estimation),
254 interval_width: value.interval_width.unwrap_or(defaults.interval_width),
255 uncertainty_samples: value
256 .uncertainty_samples
257 .unwrap_or(defaults.uncertainty_samples),
258 scaling: value.scaling.unwrap_or(defaults.scaling),
259 holidays: value.holidays.unwrap_or(defaults.holidays),
260 holidays_prior_scale: value
261 .holidays_prior_scale
262 .unwrap_or(defaults.holidays_prior_scale),
263 holidays_mode: value.holidays_mode,
264 }
265 }
266}
267
268/// Options for Prophet, after applying defaults.
269#[derive(Debug, Clone)]
270pub struct ProphetOptions {
271 /// The type of growth (trend) to use.
272 ///
273 /// Defaults to [`GrowthType::Linear`].
274 pub growth: GrowthType,
275
276 /// An optional list of changepoints.
277 ///
278 /// If not provided, changepoints will be automatically selected.
279 pub changepoints: Option<Vec<TimestampSeconds>>,
280
281 /// The number of potential changepoints to include.
282 ///
283 /// Not used if `changepoints` is provided.
284 ///
285 /// If provided and `changepoints` is not provided, then
286 /// `n_changepoints` potential changepoints will be selected
287 /// uniformly from the first `changepoint_range` proportion of
288 /// the history.
289 ///
290 /// Defaults to 25.
291 pub n_changepoints: u32,
292
293 /// The proportion of the history to consider for potential changepoints.
294 ///
295 /// Not used if `changepoints` is provided.
296 ///
297 /// Defaults to `0.8` for the first 80% of the data.
298 pub changepoint_range: PositiveFloat,
299
300 /// How to fit yearly seasonality.
301 ///
302 /// Defaults to [`SeasonalityOption::Auto`].
303 pub yearly_seasonality: SeasonalityOption,
304 /// How to fit weekly seasonality.
305 ///
306 /// Defaults to [`SeasonalityOption::Auto`].
307 pub weekly_seasonality: SeasonalityOption,
308 /// How to fit daily seasonality.
309 ///
310 /// Defaults to [`SeasonalityOption::Auto`].
311 pub daily_seasonality: SeasonalityOption,
312
313 /// How to model seasonality.
314 ///
315 /// Defaults to [`FeatureMode::Additive`].
316 pub seasonality_mode: FeatureMode,
317
318 /// The prior scale for seasonality.
319 ///
320 /// This modulates the strength of seasonality,
321 /// with larger values allowing the model to fit
322 /// larger seasonal fluctuations and smaller values
323 /// dampening the seasonality.
324 ///
325 /// Can be specified for individual seasonalities
326 /// using [`Prophet::add_seasonality`](crate::Prophet::add_seasonality).
327 ///
328 /// Defaults to `10.0`.
329 pub seasonality_prior_scale: PositiveFloat,
330
331 /// The prior scale for changepoints.
332 ///
333 /// This modulates the flexibility of the automatic
334 /// changepoint selection. Large values will allow many
335 /// changepoints, while small values will allow few
336 /// changepoints.
337 ///
338 /// Defaults to `0.05`.
339 pub changepoint_prior_scale: PositiveFloat,
340
341 /// How to perform parameter estimation.
342 ///
343 /// When [`EstimationMode::Mle`] or [`EstimationMode::Map`]
344 /// are used then no MCMC samples are taken.
345 ///
346 /// Defaults to [`EstimationMode::Mle`].
347 pub estimation: EstimationMode,
348
349 /// The width of the uncertainty intervals.
350 ///
351 /// Must be between `0.0` and `1.0`. Common values are
352 /// `0.8` (80%), `0.9` (90%) and `0.95` (95%).
353 ///
354 /// Defaults to `0.8` for 80% intervals.
355 pub interval_width: IntervalWidth,
356
357 /// The number of simulated draws used to estimate uncertainty intervals.
358 ///
359 /// Setting this value to `0` will disable uncertainty
360 /// estimation and speed up the calculation.
361 ///
362 /// Defaults to `1000`.
363 pub uncertainty_samples: u32,
364
365 /// How to scale the data prior to fitting the model.
366 ///
367 /// Defaults to [`Scaling::AbsMax`].
368 pub scaling: Scaling,
369
370 /// Holidays to include in the model.
371 pub holidays: HashMap<String, Holiday>,
372 /// Prior scale for holidays.
373 ///
374 /// This parameter modulates the strength of the holiday
375 /// components model, unless overridden in each individual
376 /// holiday's input.
377 ///
378 /// Defaults to `100.0`.
379 pub holidays_prior_scale: PositiveFloat,
380
381 /// How to model holidays.
382 ///
383 /// Defaults to the same value as [`ProphetOptions::seasonality_mode`].
384 pub holidays_mode: Option<FeatureMode>,
385}
386
387impl Default for ProphetOptions {
388 fn default() -> Self {
389 Self {
390 growth: GrowthType::Linear,
391 changepoints: None,
392 n_changepoints: 25,
393 changepoint_range: 0.8.try_into().unwrap(),
394 yearly_seasonality: SeasonalityOption::default(),
395 weekly_seasonality: SeasonalityOption::default(),
396 daily_seasonality: SeasonalityOption::default(),
397 seasonality_mode: FeatureMode::Additive,
398 seasonality_prior_scale: 10.0.try_into().unwrap(),
399 changepoint_prior_scale: 0.05.try_into().unwrap(),
400 estimation: EstimationMode::Mle,
401 interval_width: IntervalWidth::default(),
402 uncertainty_samples: 1000,
403 scaling: Scaling::AbsMax,
404 holidays: HashMap::new(),
405 holidays_prior_scale: 100.0.try_into().unwrap(),
406 holidays_mode: None,
407 }
408 }
409}