// augurs_prophet/optimizer.rs
//! Methods for optimizing the Prophet model.
//!
//! This module contains the `Optimize` trait, which represents
//! a way of finding the optimal parameters for the Prophet model
//! given the data.
//!
//! The original Prophet library uses Stan for this; specifically,
//! it uses the `optimize` command of Stan to find the maximum
//! likelihood estimate (or maximum a-priori estimates) of the
//! parameters.
//!
//! The `cmdstan` feature of this crate provides an implementation
//! of the `Optimize` trait that uses `cmdstan` to do the same.
//! This requires a working installation of `cmdstan`.
//!
//! The `libstan` feature uses FFI calls to call out to the Stan
//! C++ library to do the same. This requires a C++ compiler.
//!
// TODO: actually add these features.
// TODO: come up with a way of doing something in WASM. Maybe
//       WASM Components?
// TODO: write a pure Rust optimizer for the default case.
use std::{fmt, sync::Arc};

use crate::positive_float::PositiveFloat;
/// The initial parameters for the optimization.
///
/// These are the starting values the optimizer refines; they mirror the
/// parameters of the Prophet Stan model (trend, regressors, noise).
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct InitialParams {
    /// Base trend growth rate.
    pub k: f64,
    /// Trend offset.
    pub m: f64,
    /// Trend rate adjustments, length s in data.
    pub delta: Vec<f64>,
    /// Regressor coefficients, length k in data.
    pub beta: Vec<f64>,
    /// Observation noise.
    pub sigma_obs: PositiveFloat,
}
43
/// The type of trend to use.
///
/// Serialized to/from the integer codes 0, 1 and 2 (see the manual
/// serde impls below), matching the `trend_indicator` data input of
/// the Prophet Stan model.
#[derive(Clone, Debug, Copy, Eq, PartialEq)]
pub enum TrendIndicator {
    /// Linear trend (default).
    Linear,
    /// Logistic trend.
    Logistic,
    /// Flat trend.
    Flat,
}
54
#[cfg(feature = "serde")]
impl serde::Serialize for TrendIndicator {
    /// Serialize the trend indicator as its integer code
    /// (0 = linear, 1 = logistic, 2 = flat), as the Stan model expects.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let code: u8 = match self {
            TrendIndicator::Linear => 0,
            TrendIndicator::Logistic => 1,
            TrendIndicator::Flat => 2,
        };
        serializer.serialize_u8(code)
    }
}
68
#[cfg(feature = "serde")]
impl<'de> serde::Deserialize<'de> for TrendIndicator {
    /// Deserialize a trend indicator from its integer code
    /// (0 = linear, 1 = logistic, 2 = flat).
    ///
    /// Any other value is rejected with an error naming the bad value.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        // NOTE: the previous `D::Error: serde::de::Error` bound was
        // redundant — `serde::Deserializer` already requires it.
        D: serde::Deserializer<'de>,
    {
        let value = u8::deserialize(deserializer)?;
        match value {
            0 => Ok(Self::Linear),
            1 => Ok(Self::Logistic),
            2 => Ok(Self::Flat),
            // Include the offending value so bad inputs are debuggable.
            other => Err(serde::de::Error::custom(format!(
                "invalid trend indicator: {}",
                other
            ))),
        }
    }
}
85
/// Data for the Prophet model.
///
/// Field names deliberately match the data block of the Prophet Stan
/// model (hence the `non_snake_case` allowance for `T`, `S`, `K`, `X`).
#[derive(Clone, Debug, PartialEq)]
#[allow(non_snake_case)]
pub struct Data {
    /// Number of time periods.
    pub T: i32,
    /// Time series, length n.
    pub y: Vec<f64>,
    /// Time, length n.
    pub t: Vec<f64>,
    /// Capacities for logistic trend, length n.
    pub cap: Vec<f64>,
    /// Number of changepoints.
    pub S: i32,
    /// Times of trend changepoints, length s.
    pub t_change: Vec<f64>,
    /// The type of trend to use.
    pub trend_indicator: TrendIndicator,
    /// Number of regressors.
    /// Must be greater than or equal to 1.
    pub K: i32,
    /// Indicator of additive features, length k.
    pub s_a: Vec<i32>,
    /// Indicator of multiplicative features, length k.
    pub s_m: Vec<i32>,
    /// Regressors, shape (n, k), stored row-major.
    ///
    /// This is stored as a `Vec<f64>` rather than a nested `Vec<Vec<f64>>`
    /// because passing such a struct by reference is tricky in Rust, since
    /// it can't be dereferenced to a `&[&[f64]]` (which would be ideal).
    ///
    /// However, when serialized to JSON, it is converted to a nested array
    /// of arrays, which is what cmdstan expects.
    pub X: Vec<f64>,
    /// Scale on seasonality prior.
    pub sigmas: Vec<PositiveFloat>,
    /// Scale on changepoints prior.
    /// Must be greater than 0.
    pub tau: PositiveFloat,
}
126
#[cfg(feature = "serde")]
impl serde::Serialize for Data {
    /// Serialize the data in the shape cmdstan expects: all fields in
    /// Stan declaration order, with the flat `X` matrix expanded into a
    /// nested array of rows of length `K`.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        use serde::ser::{SerializeSeq, SerializeStruct};

        /// A serializer which serializes X, a flat slice of f64s, as a sequence of
        /// sequences, with each inner sequence having length equal to the second field.
        struct XSerializer<'a>(&'a [f64], usize);

        impl serde::Serialize for XSerializer<'_> {
            fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
            where
                S: serde::Serializer,
            {
                let chunk_size = self.1;
                if chunk_size == 0 {
                    return Err(serde::ser::Error::custom(
                        "Invalid value for K: cannot be zero",
                    ));
                }
                // Guard against a ragged matrix: if the flat data is not an
                // exact multiple of the row width, the declared outer length
                // (len / chunk_size) would be wrong and the final row short,
                // silently producing malformed output for cmdstan.
                if self.0.len() % chunk_size != 0 {
                    return Err(serde::ser::Error::custom(
                        "Invalid value for X: length must be a multiple of K",
                    ));
                }
                let mut outer = serializer.serialize_seq(Some(self.0.len() / chunk_size))?;
                for chunk in self.0.chunks(chunk_size) {
                    outer.serialize_element(&chunk)?;
                }
                outer.end()
            }
        }

        let mut s = serializer.serialize_struct("Data", 13)?;
        // `K` is validated (non-zero) inside `XSerializer::serialize`;
        // a negative `K` also fails there since the cast wraps to a huge
        // chunk size that cannot evenly divide `X`.
        let x = XSerializer(&self.X, self.K as usize);
        s.serialize_field("T", &self.T)?;
        s.serialize_field("y", &self.y)?;
        s.serialize_field("t", &self.t)?;
        s.serialize_field("cap", &self.cap)?;
        s.serialize_field("S", &self.S)?;
        s.serialize_field("t_change", &self.t_change)?;
        s.serialize_field("trend_indicator", &self.trend_indicator)?;
        s.serialize_field("K", &self.K)?;
        s.serialize_field("s_a", &self.s_a)?;
        s.serialize_field("s_m", &self.s_m)?;
        s.serialize_field("X", &x)?;
        s.serialize_field("sigmas", &self.sigmas)?;
        s.serialize_field("tau", &self.tau)?;
        s.end()
    }
}
176
/// The algorithm to use for optimization. One of: 'BFGS', 'LBFGS', 'Newton'.
///
/// Rendered via `Display` as the lowercase name cmdstan expects.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum Algorithm {
    /// Use the Newton algorithm.
    Newton,
    /// Use the Broyden-Fletcher-Goldfarb-Shanno (BFGS) algorithm.
    Bfgs,
    /// Use the Limited-memory BFGS (L-BFGS) algorithm.
    Lbfgs,
}
187
188impl fmt::Display for Algorithm {
189    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
190        let s = match self {
191            Self::Lbfgs => "lbfgs",
192            Self::Newton => "newton",
193            Self::Bfgs => "bfgs",
194        };
195        f.write_str(s)
196    }
197}
198
/// Arguments for optimization.
///
/// Every field is optional; `None` means "use the optimizer's default".
/// These correspond to the arguments of cmdstan's `optimize` command.
#[derive(Default, Debug, Clone)]
pub struct OptimizeOpts {
    /// Algorithm to use.
    pub algorithm: Option<Algorithm>,
    /// The random seed to use for the optimization.
    pub seed: Option<u32>,
    /// The chain id to advance the PRNG.
    pub chain: Option<u32>,
    /// Line search step size for first iteration.
    pub init_alpha: Option<f64>,
    /// Convergence tolerance on changes in objective function value.
    pub tol_obj: Option<f64>,
    /// Convergence tolerance on relative changes in objective function value.
    pub tol_rel_obj: Option<f64>,
    /// Convergence tolerance on the norm of the gradient.
    pub tol_grad: Option<f64>,
    /// Convergence tolerance on the relative norm of the gradient.
    pub tol_rel_grad: Option<f64>,
    /// Convergence tolerance on changes in parameter value.
    pub tol_param: Option<f64>,
    /// Size of the history for LBFGS Hessian approximation. The value should
    /// be less than the dimensionality of the parameter space. 5-10 usually
    /// sufficient.
    pub history_size: Option<u32>,
    /// Total number of iterations.
    pub iter: Option<u32>,
    /// When `true`, use the Jacobian matrix to approximate the Hessian.
    /// Default is `false`.
    pub jacobian: Option<bool>,
    /// How frequently to emit convergence statistics, in number of iterations.
    pub refresh: Option<u32>,
}
232
/// The optimized parameters.
///
/// Returned by [`Optimizer::optimize`] on success; mirrors
/// [`InitialParams`] plus the fitted trend component.
#[derive(Debug, Clone)]
pub struct OptimizedParams {
    /// Base trend growth rate.
    pub k: f64,
    /// Trend offset.
    pub m: f64,
    /// Observation noise.
    pub sigma_obs: PositiveFloat,
    /// Trend rate adjustments.
    pub delta: Vec<f64>,
    /// Regressor coefficients.
    pub beta: Vec<f64>,
    /// Transformed trend.
    pub trend: Vec<f64>,
}
249
/// An error that occurred during the optimization procedure.
///
/// This is an opaque wrapper; construct one via [`Error::static_str`],
/// [`Error::string`] or [`Error::custom`].
#[derive(Debug, thiserror::Error)]
#[error(transparent)]
pub struct Error(
    /// The kind of error that occurred.
    ///
    /// This is a private field so that we can evolve
    /// the `ErrorKind` enum without breaking changes.
    #[from]
    ErrorKind,
);
261
262impl Error {
263    /// A static string error.
264    pub fn static_str(s: &'static str) -> Self {
265        Self(ErrorKind::StaticStr(s))
266    }
267
268    /// A string error.
269    pub fn string(s: String) -> Self {
270        Self(ErrorKind::String(s))
271    }
272
273    /// A custom error, which is any type that implements `std::error::Error`.
274    pub fn custom<E: std::error::Error + 'static>(e: E) -> Self {
275        Self(ErrorKind::Custom(Box::new(e)))
276    }
277}
278
// Private error kind; kept out of the public API so new variants can be
// added without a breaking change (see the `Error` wrapper above).
#[derive(Debug, thiserror::Error)]
enum ErrorKind {
    // An error described by a static string.
    #[error("Error in optimization: {0}")]
    StaticStr(&'static str),
    // An error described by an owned string.
    #[error("Error in optimization: {0}")]
    String(String),
    // An arbitrary boxed error from an `Optimizer` implementation.
    #[error("Error in optimization: {0}")]
    Custom(#[from] Box<dyn std::error::Error>),
}
288
/// A type that can run maximum likelihood estimation optimization
/// for the Prophet model.
pub trait Optimizer: std::fmt::Debug {
    /// Find the maximum likelihood estimate of the parameters given the
    /// data, initial parameters and optimization options.
    ///
    /// # Errors
    ///
    /// Returns an [`Error`] if the underlying optimization procedure fails.
    fn optimize(
        &self,
        init: &InitialParams,
        data: &Data,
        opts: &OptimizeOpts,
    ) -> Result<OptimizedParams, Error>;
}
301
302/// An implementation of `Optimize` which simply delegates to the
303/// `Arc`'s inner type. This enables thread-safe sharing of optimizers
304/// while maintaining the ability to use dynamic dispatch.
305impl Optimizer for Arc<dyn Optimizer> {
306    fn optimize(
307        &self,
308        init: &InitialParams,
309        data: &Data,
310        opts: &OptimizeOpts,
311    ) -> Result<OptimizedParams, Error> {
312        (**self).optimize(init, data, opts)
313    }
314}
315
#[cfg(test)]
pub(crate) mod mock_optimizer {
    use std::cell::RefCell;

    use super::*;

    /// The arguments captured from a single [`Optimizer::optimize`] call.
    #[derive(Debug, Clone)]
    pub(crate) struct OptimizeCall {
        pub init: InitialParams,
        pub data: Data,
        pub _opts: OptimizeOpts,
    }

    /// A mock optimizer that records the optimization call.
    #[derive(Clone, Debug)]
    pub(crate) struct MockOptimizer {
        /// The optimization call.
        ///
        /// This will be updated by the mock optimizer when
        /// [`Optimizer::optimize`] is called.
        // [`Optimizer::optimize`] takes `&self`, so interior mutability
        // (a `RefCell`) is required to record the call.
        pub call: RefCell<Option<OptimizeCall>>,
    }

    impl MockOptimizer {
        /// Create a new mock optimizer with no recorded call.
        pub(crate) fn new() -> Self {
            Self {
                call: RefCell::new(None),
            }
        }

        /// Take the optimization call out of the mock, leaving `None` behind.
        pub(crate) fn take_call(&self) -> Option<OptimizeCall> {
            self.call.take()
        }
    }

    impl Optimizer for MockOptimizer {
        fn optimize(
            &self,
            init: &InitialParams,
            data: &Data,
            opts: &OptimizeOpts,
        ) -> Result<OptimizedParams, Error> {
            // Record the call for later inspection by the test.
            self.call.replace(Some(OptimizeCall {
                init: init.clone(),
                data: data.clone(),
                _opts: opts.clone(),
            }));
            // Echo the initial parameters back as the "optimized" result.
            Ok(OptimizedParams {
                k: init.k,
                m: init.m,
                sigma_obs: init.sigma_obs,
                delta: init.delta.clone(),
                beta: init.beta.clone(),
                trend: vec![],
            })
        }
    }
}
378
#[cfg(test)]
mod tests {

    // Verifies the hand-written `Serialize` impl for `Data`, in particular
    // that the flat `X` vector is expanded into rows of length `K` and that
    // the fields appear in the order the Stan model declares them.
    #[cfg(feature = "serde")]
    #[test]
    fn serialize_data() {
        use super::*;
        let data = Data {
            T: 3,
            y: vec![1.0, 2.0, 3.0],
            t: vec![0.0, 1.0, 2.0],
            // Flat row-major matrix; with K = 2 this is three rows of two.
            X: vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0],
            sigmas: vec![
                1.0.try_into().unwrap(),
                2.0.try_into().unwrap(),
                3.0.try_into().unwrap(),
            ],
            tau: 1.0.try_into().unwrap(),
            K: 2,
            s_a: vec![1, 1, 1],
            s_m: vec![0, 0, 0],
            cap: vec![0.0, 0.0, 0.0],
            S: 2,
            t_change: vec![0.0, 0.0, 0.0],
            trend_indicator: TrendIndicator::Linear,
        };
        let serialized = serde_json::to_string_pretty(&data).unwrap();
        pretty_assertions::assert_eq!(
            serialized,
            r#"{
  "T": 3,
  "y": [
    1.0,
    2.0,
    3.0
  ],
  "t": [
    0.0,
    1.0,
    2.0
  ],
  "cap": [
    0.0,
    0.0,
    0.0
  ],
  "S": 2,
  "t_change": [
    0.0,
    0.0,
    0.0
  ],
  "trend_indicator": 0,
  "K": 2,
  "s_a": [
    1,
    1,
    1
  ],
  "s_m": [
    0,
    0,
    0
  ],
  "X": [
    [
      1.0,
      2.0
    ],
    [
      3.0,
      1.0
    ],
    [
      2.0,
      3.0
    ]
  ],
  "sigmas": [
    1.0,
    2.0,
    3.0
  ],
  "tau": 1.0
}"#
        );
    }
}
466}