augurs/trend.rs
//! Bindings for trend models implemented in Python.
//!
//! This module provides the [`PyTrendModel`] struct, which wraps an instance of
//! a Python class that implements a trend model. This allows users to implement
//! their own trend models in Python and use them in the MSTL algorithm via
//! [`MSTL::custom_trend`][crate::mstl::MSTL::custom_trend].
//!
//! The Python class must implement the following methods:
//!
//! - `fit(self, y: np.ndarray) -> None`
//! - `predict(self, horizon: int, level: float | None = None) -> augurs.Forecast`
//! - `predict_in_sample(self, level: float | None = None) -> augurs.Forecast`
13use numpy::ToPyArray;
14use pyo3::{exceptions::PyException, prelude::*};
15
16use augurs_mstl::TrendModel;
17
18use crate::Forecast;
19
20/// A Python wrapper for a trend model.
21///
22/// This allows users to implement their own trend models in Python and use
23/// them in the MSTL algorithm using [`MSTL::custom_trend`][crate::mstl::MSTL::custom_trend].
24///
25/// The Python class must implement the following methods:
26///
27/// - `fit(self, y: np.ndarray) -> None`
28/// - `predict(self, horizon: int, level: float | None = None) -> augurs.Forecast`
29/// - `predict_in_sample(self, level: float | None = None) -> augurs.Forecast`
30#[pyclass(name = "TrendModel")]
31#[derive(Clone, Debug)]
32pub struct PyTrendModel {
33 model: Py<PyAny>,
34}
35
36#[pymethods]
37impl PyTrendModel {
38 fn __repr__(&self) -> String {
39 format!("PyTrendModel(model=\"{}\")", self.name())
40 }
41
42 /// Wrap a trend model implemented in Python into a PyTrendModel.
43 ///
44 /// The returned PyTrendModel can be used in MSTL models using the
45 /// `custom_trend` method of the MSTL class.
46 #[new]
47 pub fn new(model: Py<PyAny>) -> Self {
48 Self { model }
49 }
50}
51
52impl TrendModel for PyTrendModel {
53 fn name(&self) -> std::borrow::Cow<'_, str> {
54 Python::with_gil(|py| {
55 self.model
56 .as_ref(py)
57 .get_type()
58 .name()
59 .map(|s| s.to_owned().into())
60 })
61 .unwrap_or_else(|_| "unknown Python class".into())
62 }
63
64 fn fit(&mut self, y: &[f64]) -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
65 Python::with_gil(|py| {
66 let np = y.to_pyarray(py);
67 self.model.call_method1(py, "fit", (np,))
68 })?;
69 Ok(())
70 }
71
72 fn predict(
73 &self,
74 horizon: usize,
75 level: Option<f64>,
76 ) -> Result<augurs_core::Forecast, Box<dyn std::error::Error + Send + Sync + 'static>> {
77 Python::with_gil(|py| {
78 let preds = self
79 .model
80 .call_method1(py, "predict", (horizon, level))
81 .map_err(|e| Box::new(PyException::new_err(format!("error predicting: {e}"))))?;
82 let preds: Forecast = preds.extract(py)?;
83 Ok(preds.into())
84 })
85 }
86
87 fn predict_in_sample(
88 &self,
89 level: Option<f64>,
90 ) -> Result<augurs_core::Forecast, Box<dyn std::error::Error + Send + Sync + 'static>> {
91 Python::with_gil(|py| {
92 let preds = self
93 .model
94 .call_method1(py, "predict_in_sample", (level,))
95 .map_err(|e| {
96 Box::new(PyException::new_err(format!(
97 "error predicting in-sample: {e}"
98 )))
99 })?;
100 let preds: Forecast = preds.extract(py)?;
101 Ok(preds.into())
102 })
103 }
104}