Skip to main content

neopdf/
gridpdf.rs

1use ndarray::{Array1, Dimension};
2use numpy::{PyArrayMethods, PyReadonlyArrayDyn};
3use pyo3::prelude::*;
4
5use neopdf::gridpdf::GridArray;
6use neopdf::subgrid::{GridData, ParamRange, SubGrid};
7
8/// Python wrapper for the `SubGrid` struct.
9#[pyclass(name = "SubGrid")]
10pub struct PySubGrid {
11    pub(crate) subgrid: SubGrid,
12}
13
14#[pymethods]
15impl PySubGrid {
16    /// Constructs a new `SubGrid` instance from the provided axes and grid data.
17    ///
18    /// # Parameters
19    ///
20    /// - `xs`: The x-axis values.
21    /// - `q2s`: The Q^2-axis values.
22    /// - `kts`: The kT-axis values.
23    /// - `xsis`: The Skeweness `\xi`.
24    /// - `deltas`: The total momentum fraction `\Delta`.
25    /// - `nucleons`: The nucleon number axis values.
26    /// - `alphas`: The `alpha_s` axis values.
27    /// - `grid`: The 6D grid data as a `NumPy` array.
28    ///
29    /// # Returns
30    ///
31    /// Returns a new `PySubGrid` instance.
32    ///
33    /// # Panics
34    ///
35    /// Panics if any of the input vectors are empty.
36    ///
37    /// # Errors
38    ///
39    /// Returns a `PyErr` if the grid cannot be constructed from the input data.
40    #[new]
41    #[allow(clippy::too_many_arguments)]
42    #[allow(clippy::needless_pass_by_value)]
43    pub fn new(
44        xs: Vec<f64>,
45        q2s: Vec<f64>,
46        kts: Vec<f64>,
47        xsis: Vec<f64>,
48        deltas: Vec<f64>,
49        nucleons: Vec<f64>,
50        alphas: Vec<f64>,
51        grid: PyReadonlyArrayDyn<f64>,
52    ) -> PyResult<Self> {
53        let alphas_range = ParamRange::new(*alphas.first().unwrap(), *alphas.last().unwrap());
54        let x_range = ParamRange::new(*xs.first().unwrap(), *xs.last().unwrap());
55        let q2_range = ParamRange::new(*q2s.first().unwrap(), *q2s.last().unwrap());
56        let xsi_range = ParamRange::new(*xsis.first().unwrap(), *xsis.last().unwrap());
57        let delta_range = ParamRange::new(*deltas.first().unwrap(), *deltas.last().unwrap());
58        let kt_range = ParamRange::new(*kts.first().unwrap(), *kts.last().unwrap());
59        let nucleons_range = ParamRange::new(*nucleons.first().unwrap(), *nucleons.last().unwrap());
60
61        let subgrid = SubGrid {
62            xs: Array1::from(xs),
63            q2s: Array1::from(q2s),
64            kts: Array1::from(kts),
65            xis: Array1::from(xsis),
66            deltas: Array1::from(deltas),
67            grid: GridData::Grid8D(grid.to_owned_array()),
68            nucleons: Array1::from(nucleons),
69            alphas: Array1::from(alphas),
70            nucleons_range,
71            alphas_range,
72            xi_range: xsi_range,
73            delta_range,
74            kt_range,
75            x_range,
76            q2_range,
77        };
78
79        Ok(Self { subgrid })
80    }
81
82    /// Returns the minimum and maximum values of the `alpha_s` axis.
83    #[must_use]
84    pub const fn alphas_range(&self) -> (f64, f64) {
85        (self.subgrid.alphas_range.min, self.subgrid.alphas_range.max)
86    }
87
88    /// Returns the minimum and maximum values of the momentum fraction `x`.
89    #[must_use]
90    pub const fn x_range(&self) -> (f64, f64) {
91        (self.subgrid.x_range.min, self.subgrid.x_range.max)
92    }
93
94    /// Returns the minimum and maximum values of the momentum scale `Q^2`.
95    #[must_use]
96    pub const fn q2_range(&self) -> (f64, f64) {
97        (self.subgrid.q2_range.min, self.subgrid.q2_range.max)
98    }
99
100    /// Returns the minimum and maximum values of the skeweness `xi`.
101    #[must_use]
102    pub const fn xi_range(&self) -> (f64, f64) {
103        (self.subgrid.xi_range.min, self.subgrid.xi_range.max)
104    }
105
106    /// Returns the minimum and maximum values of the total momentum fraction `delta`.
107    #[must_use]
108    pub const fn delta_range(&self) -> (f64, f64) {
109        (self.subgrid.delta_range.min, self.subgrid.delta_range.max)
110    }
111
112    /// Returns the minimum and maximum values of the Nucleon number `A`.
113    #[must_use]
114    pub const fn nucleons_range(&self) -> (f64, f64) {
115        (
116            self.subgrid.nucleons_range.min,
117            self.subgrid.nucleons_range.max,
118        )
119    }
120
121    /// Returns the minimum and maximum values of the transverse momentum `kT`.
122    #[must_use]
123    pub const fn kt_range(&self) -> (f64, f64) {
124        (self.subgrid.kt_range.min, self.subgrid.kt_range.max)
125    }
126
127    /// Returns the shape of the subgrid
128    ///
129    /// # Panics
130    ///
131    /// TODO
132    #[must_use]
133    pub fn grid_shape(&self) -> Vec<usize> {
134        match &self.subgrid.grid {
135            GridData::Grid6D(grid) => {
136                let (d0, d1, d2, d3, d4, d5) = grid.dim();
137                vec![d0, d1, d2, d3, d4, d5]
138            }
139            GridData::Grid8D(grid) => grid.dim().as_array_view().to_vec(),
140        }
141    }
142}
143
144/// Python wrapper for the `GridArray` struct.
145#[pyclass(name = "GridArray")]
146#[repr(transparent)]
147pub struct PyGridArray {
148    pub(crate) gridarray: GridArray,
149}
150
151#[pymethods]
152impl PyGridArray {
153    /// Constructs a new `GridArray` from a list of particle IDs and subgrids.
154    ///
155    /// # Parameters
156    ///
157    /// - `pids`: The list of particle IDs.
158    /// - `subgrids`: The list of subgrid objects.
159    ///
160    /// # Returns
161    ///
162    /// Returns a new `PyGridArray` instance.
163    #[new]
164    #[must_use]
165    pub fn new(pids: Vec<i32>, subgrids: Vec<PyRef<PySubGrid>>) -> Self {
166        let subgrids = subgrids
167            .into_iter()
168            .map(|py_ref| py_ref.subgrid.clone())
169            .collect();
170
171        let gridarray = GridArray::from_parts(Array1::from(pids), subgrids);
172        Self { gridarray }
173    }
174
175    /// Returns the particle IDs associated with this grid array.
176    #[must_use]
177    pub fn pids(&self) -> Vec<i32> {
178        self.gridarray.pids.to_vec()
179    }
180
181    /// Returns the subgrids contained in this grid array.
182    #[must_use]
183    pub fn subgrids(&self) -> Vec<PySubGrid> {
184        self.gridarray
185            .subgrids
186            .iter()
187            .cloned()
188            .map(|sg| PySubGrid { subgrid: sg })
189            .collect()
190    }
191}
192
193/// Registers the gridpdf module with the parent Python module.
194///
195/// Adds the `gridpdf` submodule to the parent Python module, exposing grid
196/// interpolation utilities to Python.
197///
198/// # Errors
199///
200/// Returns a `PyErr` if the submodule cannot be created or added, or if any
201/// class registration fails.
202pub fn register(parent_module: &Bound<'_, PyModule>) -> PyResult<()> {
203    let m = PyModule::new(parent_module.py(), "gridpdf")?;
204    m.setattr(
205        pyo3::intern!(m.py(), "__doc__"),
206        "GridPDF interpolation interface.",
207    )?;
208    pyo3::py_run!(
209        parent_module.py(),
210        m,
211        "import sys; sys.modules['neopdf.gridpdf'] = m"
212    );
213    m.add_class::<PySubGrid>()?;
214    m.add_class::<PyGridArray>()?;
215    parent_module.add_submodule(&m)
216}