// laddu_python/math.rs
1use laddu_core::{math::Histogram, LadduError};
2use numpy::{PyArray1, PyReadonlyArray1};
3use pyo3::{
4    exceptions::PyValueError,
5    prelude::*,
6    types::{PyBytes, PyTuple},
7};
8
9fn extract_f64_vec(value: &Bound<'_, PyAny>, name: &str) -> PyResult<Vec<f64>> {
10    if let Ok(array) = value.extract::<PyReadonlyArray1<'_, f64>>() {
11        return Ok(array.as_slice()?.to_vec());
12    }
13    value
14        .extract::<Vec<f64>>()
15        .map_err(|_| PyValueError::new_err(format!("{name} must be a one-dimensional float array")))
16}
17
18/// A weighted histogram with explicit bin edges.
19///
20/// Parameters
21/// ----------
22/// bin_edges : array_like
23///     Strictly increasing bin edges. The length must be one greater than ``counts``.
24/// counts : array_like
25///     Finite nonnegative weighted counts. The total weight must be positive.
26///
27#[pyclass(name = "Histogram", module = "laddu", from_py_object)]
28#[derive(Clone, Debug)]
29pub struct PyHistogram(pub Histogram);
30
31#[pymethods]
32impl PyHistogram {
33    #[new]
34    fn new(bin_edges: &Bound<'_, PyAny>, counts: &Bound<'_, PyAny>) -> PyResult<Self> {
35        Ok(Self(Histogram::new(
36            extract_f64_vec(bin_edges, "bin_edges")?,
37            extract_f64_vec(counts, "counts")?,
38        )?))
39    }
40
41    /// Construct a histogram from NumPy-compatible arrays.
42    #[staticmethod]
43    fn from_numpy(bin_edges: &Bound<'_, PyAny>, counts: &Bound<'_, PyAny>) -> PyResult<Self> {
44        Self::new(bin_edges, counts)
45    }
46
47    /// Bin edges as a NumPy array.
48    #[getter]
49    fn bin_edges<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<f64>> {
50        PyArray1::from_slice(py, self.0.bin_edges())
51    }
52
53    /// Weighted counts as a NumPy array.
54    #[getter]
55    fn counts<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<f64>> {
56        PyArray1::from_slice(py, self.0.counts())
57    }
58
59    /// Total histogram weight.
60    #[getter]
61    fn total_weight(&self) -> f64 {
62        self.0.total_weight()
63    }
64
65    /// Return ``(bin_edges, counts)`` as NumPy arrays.
66    fn to_numpy<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyTuple>> {
67        PyTuple::new(
68            py,
69            [self.bin_edges(py).into_any(), self.counts(py).into_any()],
70        )
71    }
72
73    fn __getstate__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
74        Ok(PyBytes::new(
75            py,
76            serde_pickle::to_vec(&self.0, serde_pickle::SerOptions::new())
77                .map_err(LadduError::PickleError)?
78                .as_slice(),
79        ))
80    }
81
82    fn __getnewargs__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyTuple>> {
83        PyTuple::new(py, [self.0.bin_edges().to_vec(), self.0.counts().to_vec()])
84    }
85
86    fn __setstate__(&mut self, state: Bound<'_, PyBytes>) -> PyResult<()> {
87        *self = Self(
88            serde_pickle::from_slice(state.as_bytes(), serde_pickle::DeOptions::new())
89                .map_err(LadduError::PickleError)?,
90        );
91        Ok(())
92    }
93
94    fn __repr__(&self) -> String {
95        format!("{:?}", self.0)
96    }
97
98    fn __str__(&self) -> String {
99        self.__repr__()
100    }
101}