neopdf/
gridpdf.rs

1use ndarray::Array1;
2use numpy::{PyArrayMethods, PyReadonlyArray6};
3use pyo3::prelude::*;
4
5use neopdf::gridpdf::GridArray;
6use neopdf::subgrid::{ParamRange, SubGrid};
7
8/// Python wrapper for the `SubGrid` struct.
9#[pyclass(name = "SubGrid")]
10pub struct PySubGrid {
11    pub(crate) subgrid: SubGrid,
12}
13
14#[pymethods]
15impl PySubGrid {
16    /// Constructs a new `SubGrid` instance from the provided axes and grid data.
17    ///
18    /// # Parameters
19    ///
20    /// - `xs`: The x-axis values.
21    /// - `q2s`: The Q^2-axis values.
22    /// - `kts`: The kT-axis values.
23    /// - `nucleons`: The nucleon number axis values.
24    /// - `alphas`: The alpha_s axis values.
25    /// - `grid`: The 6D grid data as a NumPy array.
26    ///
27    /// # Returns
28    ///
29    /// Returns a new `PySubGrid` instance.
30    ///
31    /// # Panics
32    ///
33    /// Panics if any of the input vectors are empty.
34    ///
35    /// # Errors
36    ///
37    /// Returns a `PyErr` if the grid cannot be constructed from the input data.
38    #[new]
39    #[allow(clippy::needless_pass_by_value)]
40    pub fn new(
41        xs: Vec<f64>,
42        q2s: Vec<f64>,
43        kts: Vec<f64>,
44        nucleons: Vec<f64>,
45        alphas: Vec<f64>,
46        grid: PyReadonlyArray6<f64>,
47    ) -> PyResult<Self> {
48        let alphas_range = ParamRange::new(*alphas.first().unwrap(), *alphas.last().unwrap());
49        let x_range = ParamRange::new(*xs.first().unwrap(), *xs.last().unwrap());
50        let q2_range = ParamRange::new(*q2s.first().unwrap(), *q2s.last().unwrap());
51        let kt_range = ParamRange::new(*kts.first().unwrap(), *kts.last().unwrap());
52        let nucleons_range = ParamRange::new(*nucleons.first().unwrap(), *nucleons.last().unwrap());
53
54        let subgrid = SubGrid {
55            xs: Array1::from(xs),
56            q2s: Array1::from(q2s),
57            kts: Array1::from(kts),
58            grid: grid.to_owned_array(),
59            nucleons: Array1::from(nucleons),
60            alphas: Array1::from(alphas),
61            nucleons_range,
62            alphas_range,
63            kt_range,
64            x_range,
65            q2_range,
66        };
67
68        Ok(Self { subgrid })
69    }
70
71    /// Returns the minimum and maximum values of the alpha_s axis.
72    #[must_use]
73    pub const fn alphas_range(&self) -> (f64, f64) {
74        (self.subgrid.alphas_range.min, self.subgrid.alphas_range.max)
75    }
76
77    /// Returns the minimum and maximum values of the momentum fraction `x`.
78    #[must_use]
79    pub const fn x_range(&self) -> (f64, f64) {
80        (self.subgrid.x_range.min, self.subgrid.x_range.max)
81    }
82
83    /// Returns the minimum and maximum values of the momentum scale `Q^2`.
84    #[must_use]
85    pub const fn q2_range(&self) -> (f64, f64) {
86        (self.subgrid.q2_range.min, self.subgrid.q2_range.max)
87    }
88
89    /// Returns the minimum and maximum values of the Nucleon number `A`.
90    #[must_use]
91    pub const fn nucleons_range(&self) -> (f64, f64) {
92        (
93            self.subgrid.nucleons_range.min,
94            self.subgrid.nucleons_range.max,
95        )
96    }
97
98    /// Returns the minimum and maximum values of the transverse momentum `kT`.
99    #[must_use]
100    pub const fn kt_range(&self) -> (f64, f64) {
101        (self.subgrid.kt_range.min, self.subgrid.kt_range.max)
102    }
103
104    /// Returns the shape of the subgrid
105    #[must_use]
106    pub fn grid_shape(&self) -> (usize, usize, usize, usize, usize, usize) {
107        self.subgrid.grid.dim()
108    }
109}
110
111/// Python wrapper for the `GridArray` struct.
112#[pyclass(name = "GridArray")]
113#[repr(transparent)]
114pub struct PyGridArray {
115    pub(crate) gridarray: GridArray,
116}
117
118#[pymethods]
119impl PyGridArray {
120    /// Constructs a new `GridArray` from a list of particle IDs and subgrids.
121    ///
122    /// # Parameters
123    ///
124    /// - `pids`: The list of particle IDs.
125    /// - `subgrids`: The list of subgrid objects.
126    ///
127    /// # Returns
128    ///
129    /// Returns a new `PyGridArray` instance.
130    #[new]
131    #[must_use]
132    pub fn new(pids: Vec<i32>, subgrids: Vec<PyRef<PySubGrid>>) -> Self {
133        let subgrids = subgrids
134            .into_iter()
135            .map(|py_ref| py_ref.subgrid.clone())
136            .collect();
137
138        let gridarray = GridArray {
139            pids: Array1::from(pids),
140            subgrids,
141        };
142        Self { gridarray }
143    }
144
145    /// Returns the particle IDs associated with this grid array.
146    #[must_use]
147    pub fn pids(&self) -> Vec<i32> {
148        self.gridarray.pids.to_vec()
149    }
150
151    /// Returns the subgrids contained in this grid array.
152    #[must_use]
153    pub fn subgrids(&self) -> Vec<PySubGrid> {
154        self.gridarray
155            .subgrids
156            .iter()
157            .cloned()
158            .map(|sg| PySubGrid { subgrid: sg })
159            .collect()
160    }
161}
162
163/// Registers the gridpdf module with the parent Python module.
164///
165/// Adds the `gridpdf` submodule to the parent Python module, exposing grid
166/// interpolation utilities to Python.
167///
168/// # Errors
169///
170/// Returns a `PyErr` if the submodule cannot be created or added, or if any
171/// class registration fails.
172pub fn register(parent_module: &Bound<'_, PyModule>) -> PyResult<()> {
173    let m = PyModule::new(parent_module.py(), "gridpdf")?;
174    m.setattr(
175        pyo3::intern!(m.py(), "__doc__"),
176        "GridPDF interpolation interface.",
177    )?;
178    pyo3::py_run!(
179        parent_module.py(),
180        m,
181        "import sys; sys.modules['neopdf.gridpdf'] = m"
182    );
183    m.add_class::<PySubGrid>()?;
184    m.add_class::<PyGridArray>()?;
185    parent_module.add_submodule(&m)
186}