use pyo3::prelude::*;
use pyo3::exceptions::PyValueError;
use numpy::{PyArray1, PyArrayMethods, PyReadonlyArray1};
use spliny::SplineCurve;
use crate::{CubicSplineFit, evaluate, ops};
fn to_vec(arr: &PyReadonlyArray1<f64>) -> PyResult<Vec<f64>> {
arr.to_vec()
.map_err(|e| PyValueError::new_err(format!("failed to convert array: {e}")))
}
/// Reads the `x` and `y` sample arrays into owned vectors and checks that they can be fitted.
///
/// NOTE(review): the 4-point minimum presumably corresponds to degree + 1 for the cubic
/// `SplineCurve<3, 1>` fits built from this data. Confirm against `CubicSplineFit`.
///
/// # Errors
///
/// Returns `ValueError` in three cases:
/// - either array cannot be copied;
/// - the two lengths differ;
/// - fewer than 4 samples are given.
///
/// A length mismatch is reported in preference to the minimum-size check.
fn extract_xy(
    x: &PyReadonlyArray1<f64>,
    y: &PyReadonlyArray1<f64>,
) -> PyResult<(Vec<f64>, Vec<f64>)> {
    let xs = to_vec(x)?;
    let ys = to_vec(y)?;
    let (nx, ny) = (xs.len(), ys.len());
    match (nx == ny, nx >= 4) {
        (true, true) => Ok((xs, ys)),
        (false, _) => Err(PyValueError::new_err(format!(
            "x and y must have the same length, got {} and {}",
            nx, ny
        ))),
        (true, false) => Err(PyValueError::new_err(format!(
            "need at least 4 data points, got {}",
            nx
        ))),
    }
}
/// A one-dimensional cubic spline fitted to sampled `(x, y)` data.
///
/// Build instances with one of the static constructors (`smoothing`, `interpolating`,
/// `cardinal`). The fitted curve can then be evaluated, integrated, and searched for roots.
#[pyclass]
pub struct CubicSpline {
    // NOTE(review): the const parameters presumably mean "degree 3, one output dimension".
    // Confirm against `spliny::SplineCurve`.
    inner: SplineCurve<3, 1>,
}

#[pymethods]
impl CubicSpline {
    /// Fits a smoothing cubic spline to `(x, y)`.
    ///
    /// `rms` is forwarded to `CubicSplineFit::smoothing_spline` as the smoothing target.
    ///
    /// Raises `ValueError` in any of these cases:
    /// - `rms` is negative or not finite;
    /// - `x` and `y` differ in length;
    /// - fewer than 4 points are given;
    /// - the fit itself fails.
    #[staticmethod]
    fn smoothing(x: PyReadonlyArray1<f64>, y: PyReadonlyArray1<f64>, rms: f64) -> PyResult<Self> {
        // Reject meaningless targets (NaN, inf, negative) at the Python boundary instead of
        // handing them to the fitter. NaN is non-finite, so it is caught by the first test.
        if !rms.is_finite() || rms < 0.0 {
            return Err(PyValueError::new_err(format!(
                "rms must be finite and non-negative, got {rms}"
            )));
        }
        let (xv, yv) = extract_xy(&x, &y)?;
        let inner = CubicSplineFit::new(xv, yv)
            .smoothing_spline(rms)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok(CubicSpline { inner })
    }

    /// Fits an interpolating cubic spline through `(x, y)`.
    ///
    /// Raises `ValueError` in any of these cases:
    /// - `x` and `y` differ in length;
    /// - fewer than 4 points are given;
    /// - the fit itself fails.
    #[staticmethod]
    fn interpolating(x: PyReadonlyArray1<f64>, y: PyReadonlyArray1<f64>) -> PyResult<Self> {
        let (xv, yv) = extract_xy(&x, &y)?;
        let inner = CubicSplineFit::new(xv, yv)
            .interpolating_spline()
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok(CubicSpline { inner })
    }

    /// Fits a cardinal cubic spline to `(x, y)`.
    ///
    /// `dt` is forwarded to `CubicSplineFit::cardinal_spline` (presumably the uniform knot
    /// spacing).
    ///
    /// Raises `ValueError` in any of these cases:
    /// - `dt` is not a finite positive number;
    /// - `x` and `y` differ in length;
    /// - fewer than 4 points are given;
    /// - the fit itself fails.
    #[staticmethod]
    fn cardinal(x: PyReadonlyArray1<f64>, y: PyReadonlyArray1<f64>, dt: f64) -> PyResult<Self> {
        // A zero, negative, or NaN spacing cannot describe a knot grid. Fail fast with a
        // ValueError rather than letting the fitter misbehave on it.
        if !dt.is_finite() || dt <= 0.0 {
            return Err(PyValueError::new_err(format!(
                "dt must be finite and positive, got {dt}"
            )));
        }
        let (xv, yv) = extract_xy(&x, &y)?;
        let inner = CubicSplineFit::new(xv, yv)
            .cardinal_spline(dt)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok(CubicSpline { inner })
    }

    /// Evaluates the spline at the points in `x` and returns the values as a new array.
    ///
    /// Raises `ValueError` if evaluation fails.
    fn evaluate<'py>(
        &self,
        py: Python<'py>,
        x: PyReadonlyArray1<f64>,
    ) -> PyResult<Bound<'py, PyArray1<f64>>> {
        // Contiguous input is borrowed without copying. Strided views (e.g. `x[::2]`) make
        // `as_slice` fail, so they are copied into `owned` instead of being rejected.
        let owned: Vec<f64>;
        let x_slice: &[f64] = match x.as_slice() {
            Ok(s) => s,
            Err(_) => {
                owned = x.as_array().to_vec();
                &owned
            }
        };
        let y = evaluate::evaluate(&self.inner, x_slice)
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok(PyArray1::from_vec(py, y))
    }

    /// Integral of the spline between `a` and `b`, as computed by `ops::integral`.
    fn integral(&self, a: f64, b: f64) -> f64 {
        ops::integral(&self.inner, a, b)
    }

    /// Roots of the spline, as computed by `ops::roots`, returned as a new array.
    ///
    /// Raises `ValueError` if root finding fails.
    fn roots<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyArray1<f64>>> {
        let z = ops::roots(&self.inner).map_err(|e| PyValueError::new_err(e.to_string()))?;
        Ok(PyArray1::from_vec(py, z))
    }

    /// Returns a copy of the knot vector.
    fn knots<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<f64>> {
        PyArray1::from_vec(py, self.inner.t.clone())
    }

    /// Returns a copy of the coefficient vector.
    fn coefficients<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<f64>> {
        PyArray1::from_vec(py, self.inner.c.clone())
    }

    /// Number of entries in the knot vector (read-only property).
    #[getter]
    fn num_knots(&self) -> usize {
        self.inner.t.len()
    }

    /// Shows the knot count and the first/last knot as the domain.
    ///
    /// An empty knot vector shows both bounds as 0.
    fn __repr__(&self) -> String {
        format!(
            "CubicSpline(num_knots={}, domain=[{:.6}, {:.6}])",
            self.inner.t.len(),
            self.inner.t.first().unwrap_or(&0.0),
            self.inner.t.last().unwrap_or(&0.0),
        )
    }
}
/// Python module entry point: registers the `CubicSpline` class with the `splinefit` module.
#[pymodule]
pub fn splinefit(module: &Bound<'_, PyModule>) -> PyResult<()> {
    module.add_class::<CubicSpline>()
}