augurs/
lib.rs

1//! Python bindings for the augurs time series framework.
2//!
3//! These bindings are intended to be used from Python, and are not useful from Rust.
4//! The documentation here is useful for understanding the Python API, however.
5//!
6//! See the crate README for information on Python API usage and installation.
7#![warn(
8    missing_docs,
9    missing_debug_implementations,
10    rust_2018_idioms,
11    unreachable_pub
12)]
13
14use numpy::{PyArray1, ToPyArray};
15use pyo3::prelude::*;
16
17pub mod ets;
18pub mod mstl;
19pub mod trend;
20
/// Forecasts produced by augurs models.
///
/// A Python-facing wrapper around `augurs_core::Forecast`, holding a point
/// forecast and, optionally, prediction intervals (a level plus lower and
/// upper bounds).
#[derive(Debug, Clone)]
#[pyclass]
pub struct Forecast {
    // The wrapped Rust forecast; converted to/from via the `From` impls below
    // and exposed to Python through the `#[pymethods]` accessors.
    inner: augurs_core::Forecast,
}
27
28impl From<augurs_core::Forecast> for Forecast {
29    fn from(inner: augurs_core::Forecast) -> Self {
30        Self { inner }
31    }
32}
33
34impl From<Forecast> for augurs_core::Forecast {
35    fn from(forecast: Forecast) -> Self {
36        forecast.inner
37    }
38}
39
40#[pymethods]
41impl Forecast {
42    #[new]
43    fn new(
44        py: Python<'_>,
45        point: Py<PyArray1<f64>>,
46        level: Option<f64>,
47        lower: Option<Py<PyArray1<f64>>>,
48        upper: Option<Py<PyArray1<f64>>>,
49    ) -> pyo3::PyResult<Self> {
50        Ok(Self {
51            inner: augurs_core::Forecast {
52                point: point.extract(py)?,
53                intervals: level
54                    .zip(lower)
55                    .zip(upper)
56                    .map(|((level, lower), upper)| {
57                        Ok::<_, PyErr>(augurs_core::ForecastIntervals {
58                            level,
59                            lower: lower.extract(py)?,
60                            upper: upper.extract(py)?,
61                        })
62                    })
63                    .transpose()?,
64            },
65        })
66    }
67
68    fn __repr__(&self) -> String {
69        let intervals = self.inner.intervals.as_ref();
70        format!(
71            "Forecast(point={:?}, level={:?}, lower={:?}, upper={:?})",
72            self.inner.point,
73            intervals.map(|x| x.level),
74            intervals.map(|x| &x.lower),
75            intervals.map(|x| &x.upper)
76        )
77    }
78
79    /// Get the point forecast.
80    fn point(&self, py: Python<'_>) -> Py<PyArray1<f64>> {
81        // Use `to_pyarray` to allocate a new array on the Python heap.
82        // We could also use `into_pyarray` to construct the
83        // numpy arrays in the Rust heap; let's see which ends up being
84        // faster and more convenient.
85        self.inner.point.to_pyarray(py).into()
86    }
87
88    /// Get the lower prediction interval.
89    fn lower(&self, py: Python<'_>) -> Option<Py<PyArray1<f64>>> {
90        self.inner
91            .intervals
92            .as_ref()
93            .map(|x| x.lower.to_pyarray(py).into())
94    }
95
96    /// Get the upper prediction interval.
97    fn upper(&self, py: Python<'_>) -> Option<Py<PyArray1<f64>>> {
98        self.inner
99            .intervals
100            .as_ref()
101            .map(|x| x.upper.to_pyarray(py).into())
102    }
103}
104
105/// Python bindings for the augurs time series framework.
106#[pymodule]
107fn augurs(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
108    pyo3_log::init();
109    m.add_class::<ets::AutoETS>()?;
110    m.add_class::<mstl::MSTL>()?;
111    m.add_class::<trend::PyTrendModel>()?;
112    m.add_class::<Forecast>()?;
113    Ok(())
114}