1use ndarray::{Array1, Dimension};
2use numpy::{PyArrayMethods, PyReadonlyArrayDyn};
3use pyo3::prelude::*;
4
5use neopdf::gridpdf::GridArray;
6use neopdf::subgrid::{GridData, ParamRange, SubGrid};
7
/// Python wrapper around a single [`SubGrid`], exposed to Python as `SubGrid`.
///
/// The wrapped value is `pub(crate)` so that other binding types in this crate
/// (e.g. the grid-array wrapper) can read and clone it directly.
#[pyclass(name = "SubGrid")]
pub struct PySubGrid {
    /// The wrapped `neopdf` subgrid.
    pub(crate) subgrid: SubGrid,
}
13
14#[pymethods]
15impl PySubGrid {
16 #[new]
41 #[allow(clippy::too_many_arguments)]
42 #[allow(clippy::needless_pass_by_value)]
43 pub fn new(
44 xs: Vec<f64>,
45 q2s: Vec<f64>,
46 kts: Vec<f64>,
47 xsis: Vec<f64>,
48 deltas: Vec<f64>,
49 nucleons: Vec<f64>,
50 alphas: Vec<f64>,
51 grid: PyReadonlyArrayDyn<f64>,
52 ) -> PyResult<Self> {
53 let alphas_range = ParamRange::new(*alphas.first().unwrap(), *alphas.last().unwrap());
54 let x_range = ParamRange::new(*xs.first().unwrap(), *xs.last().unwrap());
55 let q2_range = ParamRange::new(*q2s.first().unwrap(), *q2s.last().unwrap());
56 let xsi_range = ParamRange::new(*xsis.first().unwrap(), *xsis.last().unwrap());
57 let delta_range = ParamRange::new(*deltas.first().unwrap(), *deltas.last().unwrap());
58 let kt_range = ParamRange::new(*kts.first().unwrap(), *kts.last().unwrap());
59 let nucleons_range = ParamRange::new(*nucleons.first().unwrap(), *nucleons.last().unwrap());
60
61 let subgrid = SubGrid {
62 xs: Array1::from(xs),
63 q2s: Array1::from(q2s),
64 kts: Array1::from(kts),
65 xis: Array1::from(xsis),
66 deltas: Array1::from(deltas),
67 grid: GridData::Grid8D(grid.to_owned_array()),
68 nucleons: Array1::from(nucleons),
69 alphas: Array1::from(alphas),
70 nucleons_range,
71 alphas_range,
72 xi_range: xsi_range,
73 delta_range,
74 kt_range,
75 x_range,
76 q2_range,
77 };
78
79 Ok(Self { subgrid })
80 }
81
82 #[must_use]
84 pub const fn alphas_range(&self) -> (f64, f64) {
85 (self.subgrid.alphas_range.min, self.subgrid.alphas_range.max)
86 }
87
88 #[must_use]
90 pub const fn x_range(&self) -> (f64, f64) {
91 (self.subgrid.x_range.min, self.subgrid.x_range.max)
92 }
93
94 #[must_use]
96 pub const fn q2_range(&self) -> (f64, f64) {
97 (self.subgrid.q2_range.min, self.subgrid.q2_range.max)
98 }
99
100 #[must_use]
102 pub const fn xi_range(&self) -> (f64, f64) {
103 (self.subgrid.xi_range.min, self.subgrid.xi_range.max)
104 }
105
106 #[must_use]
108 pub const fn delta_range(&self) -> (f64, f64) {
109 (self.subgrid.delta_range.min, self.subgrid.delta_range.max)
110 }
111
112 #[must_use]
114 pub const fn nucleons_range(&self) -> (f64, f64) {
115 (
116 self.subgrid.nucleons_range.min,
117 self.subgrid.nucleons_range.max,
118 )
119 }
120
121 #[must_use]
123 pub const fn kt_range(&self) -> (f64, f64) {
124 (self.subgrid.kt_range.min, self.subgrid.kt_range.max)
125 }
126
127 #[must_use]
133 pub fn grid_shape(&self) -> Vec<usize> {
134 match &self.subgrid.grid {
135 GridData::Grid6D(grid) => {
136 let (d0, d1, d2, d3, d4, d5) = grid.dim();
137 vec![d0, d1, d2, d3, d4, d5]
138 }
139 GridData::Grid8D(grid) => grid.dim().as_array_view().to_vec(),
140 }
141 }
142}
143
/// Python wrapper around a [`GridArray`], exposed to Python as `GridArray`.
///
/// `#[repr(transparent)]` gives this wrapper the same memory layout as the
/// single `GridArray` field it contains.
#[pyclass(name = "GridArray")]
#[repr(transparent)]
pub struct PyGridArray {
    /// The wrapped `neopdf` grid array.
    pub(crate) gridarray: GridArray,
}
150
151#[pymethods]
152impl PyGridArray {
153 #[new]
164 #[must_use]
165 pub fn new(pids: Vec<i32>, subgrids: Vec<PyRef<PySubGrid>>) -> Self {
166 let subgrids = subgrids
167 .into_iter()
168 .map(|py_ref| py_ref.subgrid.clone())
169 .collect();
170
171 let gridarray = GridArray::from_parts(Array1::from(pids), subgrids);
172 Self { gridarray }
173 }
174
175 #[must_use]
177 pub fn pids(&self) -> Vec<i32> {
178 self.gridarray.pids.to_vec()
179 }
180
181 #[must_use]
183 pub fn subgrids(&self) -> Vec<PySubGrid> {
184 self.gridarray
185 .subgrids
186 .iter()
187 .cloned()
188 .map(|sg| PySubGrid { subgrid: sg })
189 .collect()
190 }
191}
192
193pub fn register(parent_module: &Bound<'_, PyModule>) -> PyResult<()> {
203 let m = PyModule::new(parent_module.py(), "gridpdf")?;
204 m.setattr(
205 pyo3::intern!(m.py(), "__doc__"),
206 "GridPDF interpolation interface.",
207 )?;
208 pyo3::py_run!(
209 parent_module.py(),
210 m,
211 "import sys; sys.modules['neopdf.gridpdf'] = m"
212 );
213 m.add_class::<PySubGrid>()?;
214 m.add_class::<PyGridArray>()?;
215 parent_module.add_submodule(&m)
216}