use numpy::{PyArray1, PyReadonlyArray1};
use pyo3::prelude::*;
use crate::inference::{ConformalPredictor, Prediction};
/// Python-facing wrapper around the Rust `Prediction` type, exposed to Python
/// under the class name `Prediction`.
#[pyclass(name = "Prediction")]
#[derive(Clone)]
pub struct PyPrediction {
// The wrapped Rust value; the Python-visible methods read from this field.
inner: Prediction,
}
#[pymethods]
impl PyPrediction {
    /// Build a prediction that carries only a point estimate (no interval).
    #[staticmethod]
    fn point(value: f32) -> Self {
        Self {
            inner: Prediction::point(value),
        }
    }

    /// Build a prediction with a point estimate and explicit lower/upper bounds.
    ///
    /// The values are forwarded unchanged to `Prediction::with_interval`; this
    /// binding performs no `lower <= point <= upper` ordering check itself.
    #[staticmethod]
    fn with_interval(point: f32, lower: f32, upper: f32) -> Self {
        Self {
            inner: Prediction::with_interval(point, lower, upper),
        }
    }

    /// The point estimate.
    // NOTE(review): PyO3 strips the `get_` prefix, so this getter is published to
    // Python as `point` -- the same attribute name as the `point` static
    // constructor above. CPython registers `tp_methods` before `tp_getset` and
    // keeps the first entry for a name, so `pred.point` likely resolves to the
    // constructor rather than this value. Confirm, and consider exposing the
    // value under a distinct Python name.
    #[getter]
    fn get_point(&self) -> f32 {
        self.inner.point
    }

    /// Lower bound of the interval, or `None` when it is not set.
    #[getter]
    fn lower(&self) -> Option<f32> {
        self.inner.lower
    }

    /// Upper bound of the interval, or `None` when it is not set.
    #[getter]
    fn upper(&self) -> Option<f32> {
        self.inner.upper
    }

    /// Whether this prediction carries an interval (delegates to `Prediction::has_interval`).
    #[getter]
    fn has_interval(&self) -> bool {
        self.inner.has_interval()
    }

    /// Interval width as reported by `Prediction::interval_width`, or `None`.
    #[getter]
    fn interval_width(&self) -> Option<f32> {
        self.inner.interval_width()
    }

    /// Render as `Prediction(point=..., lower=..., upper=...)` with 4 decimals,
    /// omitting the bounds when the interval is not fully specified.
    fn __repr__(&self) -> String {
        // Destructure both bounds instead of `has_interval()` + `unwrap()`: a
        // prediction with only one bound set can no longer panic inside `repr`.
        match (self.inner.lower, self.inner.upper) {
            (Some(lower), Some(upper)) => format!(
                "Prediction(point={:.4}, lower={:.4}, upper={:.4})",
                self.inner.point, lower, upper
            ),
            _ => format!("Prediction(point={:.4})", self.inner.point),
        }
    }
}
impl From<Prediction> for PyPrediction {
    /// Wrap a Rust `Prediction` so it can be handed back to Python.
    fn from(inner: Prediction) -> Self {
        Self { inner }
    }
}
/// Python-facing wrapper around the Rust `ConformalPredictor`, exposed to Python
/// under the class name `ConformalPredictor`.
#[pyclass(name = "ConformalPredictor")]
#[derive(Clone)]
pub struct PyConformalPredictor {
// The wrapped Rust predictor; the Python-visible methods delegate to it.
inner: ConformalPredictor,
}
/// Return `ValueError` unless `coverage` lies strictly inside (0, 1).
///
/// NaN fails every ordered comparison, so a plain `<= 0.0 || >= 1.0` test would
/// let it through; it is rejected explicitly here.
fn validate_coverage(coverage: f32) -> PyResult<()> {
    if coverage.is_nan() || coverage <= 0.0 || coverage >= 1.0 {
        return Err(pyo3::exceptions::PyValueError::new_err(
            "coverage must be in (0, 1)",
        ));
    }
    Ok(())
}

#[pymethods]
impl PyConformalPredictor {
    /// Build a predictor from a 1-D float32 array of residuals.
    ///
    /// Raises `ValueError` if `coverage` is not strictly inside (0, 1) or is NaN.
    #[staticmethod]
    #[pyo3(signature = (residuals, coverage=0.9))]
    fn from_residuals<'py>(
        residuals: PyReadonlyArray1<'py, f32>,
        coverage: f32,
    ) -> PyResult<Self> {
        validate_coverage(coverage)?;
        // NOTE(review): an empty `residuals` array is forwarded as-is; how
        // `ConformalPredictor::from_residuals` handles it is not visible here.
        let residuals_vec: Vec<f32> = residuals.as_array().to_vec();
        Ok(Self {
            inner: ConformalPredictor::from_residuals(&residuals_vec, coverage),
        })
    }

    /// Build a predictor directly from a precomputed quantile.
    ///
    /// Raises `ValueError` if `coverage` is not strictly inside (0, 1) or is NaN,
    /// or if `quantile` is negative or NaN.
    #[staticmethod]
    #[pyo3(signature = (quantile, coverage=0.9))]
    fn from_quantile(quantile: f32, coverage: f32) -> PyResult<Self> {
        validate_coverage(coverage)?;
        // `quantile < 0.0` is false for NaN, so NaN must be checked explicitly.
        if quantile.is_nan() || quantile < 0.0 {
            return Err(pyo3::exceptions::PyValueError::new_err(
                "quantile must be non-negative",
            ));
        }
        Ok(Self {
            inner: ConformalPredictor::from_quantile(quantile, coverage),
        })
    }

    /// Predict for a single point (delegates to `ConformalPredictor::predict`).
    fn predict(&self, point: f32) -> PyPrediction {
        self.inner.predict(point).into()
    }

    /// Predict for each element of a 1-D float32 array, returning a list of `Prediction`.
    fn predict_batch<'py>(&self, points: PyReadonlyArray1<'py, f32>) -> Vec<PyPrediction> {
        let points_vec: Vec<f32> = points.as_array().to_vec();
        self.inner
            .predict_batch(&points_vec)
            .into_iter()
            .map(Into::into)
            .collect()
    }

    /// Return a new float32 array holding `point - quantile` for each input point.
    fn predict_batch_lower<'py>(
        &self,
        py: Python<'py>,
        points: PyReadonlyArray1<'py, f32>,
    ) -> Bound<'py, PyArray1<f32>> {
        // The quantile is loop-invariant; read it once instead of per element.
        let q = self.inner.quantile();
        let lower: Vec<f32> = points.as_array().iter().map(|&p| p - q).collect();
        PyArray1::from_vec(py, lower)
    }

    /// Return a new float32 array holding `point + quantile` for each input point.
    fn predict_batch_upper<'py>(
        &self,
        py: Python<'py>,
        points: PyReadonlyArray1<'py, f32>,
    ) -> Bound<'py, PyArray1<f32>> {
        // The quantile is loop-invariant; read it once instead of per element.
        let q = self.inner.quantile();
        let upper: Vec<f32> = points.as_array().iter().map(|&p| p + q).collect();
        PyArray1::from_vec(py, upper)
    }

    /// The quantile used to widen point predictions into intervals.
    #[getter]
    fn quantile(&self) -> f32 {
        self.inner.quantile()
    }

    /// The coverage level this predictor was constructed with.
    #[getter]
    fn coverage(&self) -> f32 {
        self.inner.coverage()
    }

    /// Empirical coverage of `predictions` against `true_values`, as computed by
    /// `ConformalPredictor::empirical_coverage`.
    ///
    /// Raises `ValueError` if the two arrays differ in length.
    fn empirical_coverage<'py>(
        &self,
        true_values: PyReadonlyArray1<'py, f32>,
        predictions: PyReadonlyArray1<'py, f32>,
    ) -> PyResult<f32> {
        let true_arr = true_values.as_array();
        let pred_arr = predictions.as_array();
        if true_arr.len() != pred_arr.len() {
            return Err(pyo3::exceptions::PyValueError::new_err(format!(
                "true_values length {} doesn't match predictions length {}",
                true_arr.len(),
                pred_arr.len()
            )));
        }
        let true_vec: Vec<f32> = true_arr.to_vec();
        let pred_vec: Vec<f32> = pred_arr.to_vec();
        Ok(self.inner.empirical_coverage(&true_vec, &pred_vec))
    }

    /// Render as `ConformalPredictor(quantile=..., coverage=...)`.
    fn __repr__(&self) -> String {
        format!(
            "ConformalPredictor(quantile={:.4}, coverage={:.2})",
            self.inner.quantile(),
            self.inner.coverage()
        )
    }
}
impl From<ConformalPredictor> for PyConformalPredictor {
    /// Wrap a Rust `ConformalPredictor` so it can be handed back to Python.
    fn from(inner: ConformalPredictor) -> Self {
        Self { inner }
    }
}
/// Add the `Prediction` and `ConformalPredictor` classes to the Python module `m`.
///
/// Registration stops at, and returns, the first error raised by `add_class`.
pub fn register(m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_class::<PyPrediction>()?;
    m.add_class::<PyConformalPredictor>()
}