1use laddu_core::{math::Histogram, LadduError};
2use numpy::{PyArray1, PyReadonlyArray1};
3use pyo3::{
4 exceptions::PyValueError,
5 prelude::*,
6 types::{PyBytes, PyTuple},
7};
8
9fn extract_f64_vec(value: &Bound<'_, PyAny>, name: &str) -> PyResult<Vec<f64>> {
10 if let Ok(array) = value.extract::<PyReadonlyArray1<'_, f64>>() {
11 return Ok(array.as_slice()?.to_vec());
12 }
13 value
14 .extract::<Vec<f64>>()
15 .map_err(|_| PyValueError::new_err(format!("{name} must be a one-dimensional float array")))
16}
17
18#[pyclass(name = "Histogram", module = "laddu", from_py_object)]
28#[derive(Clone, Debug)]
29pub struct PyHistogram(pub Histogram);
30
31#[pymethods]
32impl PyHistogram {
33 #[new]
34 fn new(bin_edges: &Bound<'_, PyAny>, counts: &Bound<'_, PyAny>) -> PyResult<Self> {
35 Ok(Self(Histogram::new(
36 extract_f64_vec(bin_edges, "bin_edges")?,
37 extract_f64_vec(counts, "counts")?,
38 )?))
39 }
40
41 #[staticmethod]
43 fn from_numpy(bin_edges: &Bound<'_, PyAny>, counts: &Bound<'_, PyAny>) -> PyResult<Self> {
44 Self::new(bin_edges, counts)
45 }
46
47 #[getter]
49 fn bin_edges<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<f64>> {
50 PyArray1::from_slice(py, self.0.bin_edges())
51 }
52
53 #[getter]
55 fn counts<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<f64>> {
56 PyArray1::from_slice(py, self.0.counts())
57 }
58
59 #[getter]
61 fn total_weight(&self) -> f64 {
62 self.0.total_weight()
63 }
64
65 fn to_numpy<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyTuple>> {
67 PyTuple::new(
68 py,
69 [self.bin_edges(py).into_any(), self.counts(py).into_any()],
70 )
71 }
72
73 fn __getstate__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
74 Ok(PyBytes::new(
75 py,
76 serde_pickle::to_vec(&self.0, serde_pickle::SerOptions::new())
77 .map_err(LadduError::PickleError)?
78 .as_slice(),
79 ))
80 }
81
82 fn __getnewargs__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyTuple>> {
83 PyTuple::new(py, [self.0.bin_edges().to_vec(), self.0.counts().to_vec()])
84 }
85
86 fn __setstate__(&mut self, state: Bound<'_, PyBytes>) -> PyResult<()> {
87 *self = Self(
88 serde_pickle::from_slice(state.as_bytes(), serde_pickle::DeOptions::new())
89 .map_err(LadduError::PickleError)?,
90 );
91 Ok(())
92 }
93
94 fn __repr__(&self) -> String {
95 format!("{:?}", self.0)
96 }
97
98 fn __str__(&self) -> String {
99 self.__repr__()
100 }
101}