//! Python bindings for augurs, built with PyO3.
//!
//! The `augurs` Python extension module exposes the `AutoETS`, `MSTL`,
//! `PyTrendModel` and `Forecast` classes defined in this crate.
// Crate-wide lints: `missing_docs` requires every public item (including the
// crate itself) to be documented, and `unreachable_pub` flags `pub` items that
// are not actually reachable from outside the crate.
#![warn(
    missing_docs,
    missing_debug_implementations,
    rust_2018_idioms,
    unreachable_pub
)]
13
14use numpy::{PyArray1, ToPyArray};
15use pyo3::prelude::*;
16
/// Module providing the `AutoETS` Python class.
pub mod ets;
/// Module providing the `MSTL` Python class.
pub mod mstl;
/// Module providing the `PyTrendModel` Python class.
pub mod trend;
20
/// Python-visible wrapper around [`augurs_core::Forecast`].
///
/// Holds a point forecast plus optional intervals (a level with lower and
/// upper bounds). Instances can be built from numpy arrays in Python and are
/// convertible to and from the core type in Rust via `From`.
#[derive(Debug, Clone)]
#[pyclass]
pub struct Forecast {
    /// The core forecast this Python class wraps.
    inner: augurs_core::Forecast,
}
27
28impl From<augurs_core::Forecast> for Forecast {
29 fn from(inner: augurs_core::Forecast) -> Self {
30 Self { inner }
31 }
32}
33
34impl From<Forecast> for augurs_core::Forecast {
35 fn from(forecast: Forecast) -> Self {
36 forecast.inner
37 }
38}
39
40#[pymethods]
41impl Forecast {
42 #[new]
43 fn new(
44 py: Python<'_>,
45 point: Py<PyArray1<f64>>,
46 level: Option<f64>,
47 lower: Option<Py<PyArray1<f64>>>,
48 upper: Option<Py<PyArray1<f64>>>,
49 ) -> pyo3::PyResult<Self> {
50 Ok(Self {
51 inner: augurs_core::Forecast {
52 point: point.extract(py)?,
53 intervals: level
54 .zip(lower)
55 .zip(upper)
56 .map(|((level, lower), upper)| {
57 Ok::<_, PyErr>(augurs_core::ForecastIntervals {
58 level,
59 lower: lower.extract(py)?,
60 upper: upper.extract(py)?,
61 })
62 })
63 .transpose()?,
64 },
65 })
66 }
67
68 fn __repr__(&self) -> String {
69 let intervals = self.inner.intervals.as_ref();
70 format!(
71 "Forecast(point={:?}, level={:?}, lower={:?}, upper={:?})",
72 self.inner.point,
73 intervals.map(|x| x.level),
74 intervals.map(|x| &x.lower),
75 intervals.map(|x| &x.upper)
76 )
77 }
78
79 fn point(&self, py: Python<'_>) -> Py<PyArray1<f64>> {
81 self.inner.point.to_pyarray(py).into()
86 }
87
88 fn lower(&self, py: Python<'_>) -> Option<Py<PyArray1<f64>>> {
90 self.inner
91 .intervals
92 .as_ref()
93 .map(|x| x.lower.to_pyarray(py).into())
94 }
95
96 fn upper(&self, py: Python<'_>) -> Option<Py<PyArray1<f64>>> {
98 self.inner
99 .intervals
100 .as_ref()
101 .map(|x| x.upper.to_pyarray(py).into())
102 }
103}
104
105#[pymodule]
107fn augurs(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
108 pyo3_log::init();
109 m.add_class::<ets::AutoETS>()?;
110 m.add_class::<mstl::MSTL>()?;
111 m.add_class::<trend::PyTrendModel>()?;
112 m.add_class::<Forecast>()?;
113 Ok(())
114}