augurs/
trend.rs

//! Bindings for trend models implemented in Python.
//!
//! This module provides the [`PyTrendModel`] struct, which wraps an instance of
//! a Python class implementing a trend model. This allows users to implement
//! their own trend models in Python and use them in the MSTL algorithm via
//! [`MSTL::custom_trend`][crate::mstl::MSTL::custom_trend].
//!
//! The Python class must implement the following methods:
//!
//! - `fit(self, y: np.ndarray) -> None`
//! - `predict(self, horizon: int, level: float | None = None) -> augurs.Forecast`
//! - `predict_in_sample(self, level: float | None = None) -> augurs.Forecast`
13use numpy::ToPyArray;
14use pyo3::{exceptions::PyException, prelude::*};
15
16use augurs_mstl::TrendModel;
17
18use crate::Forecast;
19
20/// A Python wrapper for a trend model.
21///
22/// This allows users to implement their own trend models in Python and use
23/// them in the MSTL algorithm using [`MSTL::custom_trend`][crate::mstl::MSTL::custom_trend].
24///
25/// The Python class must implement the following methods:
26///
27/// - `fit(self, y: np.ndarray) -> None`
28/// - `predict(self, horizon: int, level: float | None = None) -> augurs.Forecast`
29/// - `predict_in_sample(self, level: float | None = None) -> augurs.Forecast`
30#[pyclass(name = "TrendModel")]
31#[derive(Clone, Debug)]
32pub struct PyTrendModel {
33    model: Py<PyAny>,
34}
35
36#[pymethods]
37impl PyTrendModel {
38    fn __repr__(&self) -> String {
39        format!("PyTrendModel(model=\"{}\")", self.name())
40    }
41
42    /// Wrap a trend model implemented in Python into a PyTrendModel.
43    ///
44    /// The returned PyTrendModel can be used in MSTL models using the
45    /// `custom_trend` method of the MSTL class.
46    #[new]
47    pub fn new(model: Py<PyAny>) -> Self {
48        Self { model }
49    }
50}
51
52impl TrendModel for PyTrendModel {
53    fn name(&self) -> std::borrow::Cow<'_, str> {
54        Python::with_gil(|py| {
55            self.model
56                .as_ref(py)
57                .get_type()
58                .name()
59                .map(|s| s.to_owned().into())
60        })
61        .unwrap_or_else(|_| "unknown Python class".into())
62    }
63
64    fn fit(&mut self, y: &[f64]) -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
65        Python::with_gil(|py| {
66            let np = y.to_pyarray(py);
67            self.model.call_method1(py, "fit", (np,))
68        })?;
69        Ok(())
70    }
71
72    fn predict(
73        &self,
74        horizon: usize,
75        level: Option<f64>,
76    ) -> Result<augurs_core::Forecast, Box<dyn std::error::Error + Send + Sync + 'static>> {
77        Python::with_gil(|py| {
78            let preds = self
79                .model
80                .call_method1(py, "predict", (horizon, level))
81                .map_err(|e| Box::new(PyException::new_err(format!("error predicting: {e}"))))?;
82            let preds: Forecast = preds.extract(py)?;
83            Ok(preds.into())
84        })
85    }
86
87    fn predict_in_sample(
88        &self,
89        level: Option<f64>,
90    ) -> Result<augurs_core::Forecast, Box<dyn std::error::Error + Send + Sync + 'static>> {
91        Python::with_gil(|py| {
92            let preds = self
93                .model
94                .call_method1(py, "predict_in_sample", (level,))
95                .map_err(|e| {
96                    Box::new(PyException::new_err(format!(
97                        "error predicting in-sample: {e}"
98                    )))
99                })?;
100            let preds: Forecast = preds.extract(py)?;
101            Ok(preds.into())
102        })
103    }
104}