Skip to main content

pineappl/
convolutions.rs

1//! Module for everything related to convolution functions.
2
3use super::boc::Kinematics;
4use super::boc::Scales;
5use super::grid::Grid;
6use super::pids;
7use super::subgrid::{self, Subgrid, SubgridEnum};
8use itertools::izip;
9use rustc_hash::FxHashMap;
10use serde::{Deserialize, Serialize};
11
/// Position of the renormalization scale in the per-scale arrays (`mu2`, `imu2`, ...).
const REN_IDX: usize = 0;
/// Position of the factorization scale; used for PDF-type convolutions.
const FAC_IDX: usize = 1;
/// Position of the fragmentation scale; used for fragmentation-function convolutions.
const FRG_IDX: usize = 2;
/// Number of scale kinds tracked per grid, one slot for each of the indices above.
const SCALES_CNT: usize = FRG_IDX + 1;
16
17struct ConvCache1d<'a> {
18    xfx: &'a mut dyn FnMut(i32, f64, f64) -> f64,
19    cache: FxHashMap<(i32, usize, usize), f64>,
20    conv: Conv,
21}
22
23/// A cache for evaluating PDFs. Methods like [`Grid::convolve`] accept instances of this `struct`
24/// instead of the PDFs themselves.
25pub struct ConvolutionCache<'a> {
26    caches: Vec<ConvCache1d<'a>>,
27    alphas: &'a mut dyn FnMut(f64) -> f64,
28    alphas_cache: Vec<f64>,
29    mu2: [Vec<f64>; SCALES_CNT],
30    x_grid: Vec<f64>,
31}
32
33impl<'a> ConvolutionCache<'a> {
34    /// TODO
35    pub fn new(
36        convolutions: Vec<Conv>,
37        xfx: Vec<&'a mut dyn FnMut(i32, f64, f64) -> f64>,
38        alphas: &'a mut dyn FnMut(f64) -> f64,
39    ) -> Self {
40        Self {
41            caches: xfx
42                .into_iter()
43                .zip(convolutions)
44                .map(|(xfx, conv)| ConvCache1d {
45                    xfx,
46                    cache: FxHashMap::default(),
47                    conv,
48                })
49                .collect(),
50            alphas,
51            alphas_cache: Vec::new(),
52            mu2: [const { Vec::new() }; SCALES_CNT],
53            x_grid: Vec::new(),
54        }
55    }
56
57    pub(crate) fn new_grid_conv_cache<'b>(
58        &'b mut self,
59        grid: &Grid,
60        xi: &[(f64, f64, f64)],
61    ) -> GridConvCache<'a, 'b> {
62        // TODO: try to avoid calling clear
63        self.clear();
64
65        let scales: [_; SCALES_CNT] = grid.scales().into();
66        let xi: Vec<_> = (0..SCALES_CNT)
67            .map(|idx| {
68                let mut vars: Vec<_> = xi
69                    .iter()
70                    .map(|&x| <[_; SCALES_CNT]>::from(x)[idx])
71                    .collect();
72                vars.sort_by(f64::total_cmp);
73                vars.dedup();
74                vars
75            })
76            .collect();
77
78        for (result, scale, xi) in izip!(&mut self.mu2, scales, xi) {
79            result.clear();
80            result.extend(
81                grid.subgrids()
82                    .iter()
83                    .filter(|subgrid| !subgrid.is_empty())
84                    .flat_map(|subgrid| {
85                        scale
86                            .calc(&subgrid.node_values(), grid.kinematics())
87                            .into_owned()
88                    })
89                    .flat_map(|scale| xi.iter().map(move |&xi| xi * xi * scale)),
90            );
91            result.sort_by(f64::total_cmp);
92            result.dedup();
93        }
94
95        let mut x_grid: Vec<_> = grid
96            .subgrids()
97            .iter()
98            .filter(|subgrid| !subgrid.is_empty())
99            .flat_map(|subgrid| {
100                grid.kinematics()
101                    .iter()
102                    .zip(subgrid.node_values())
103                    .filter(|(kin, _)| matches!(kin, Kinematics::X(_)))
104                    .flat_map(|(_, node_values)| node_values)
105            })
106            .collect();
107        x_grid.sort_by(f64::total_cmp);
108        x_grid.dedup();
109
110        self.alphas_cache = self.mu2[REN_IDX]
111            .iter()
112            .map(|&mur2| (self.alphas)(mur2))
113            .collect();
114        self.x_grid = x_grid;
115
116        let perm = grid
117            .convolutions()
118            .iter()
119            .enumerate()
120            .map(|(max_idx, grid_conv)| {
121                self.caches
122                    .iter()
123                    .take(max_idx + 1)
124                    .enumerate()
125                    .rev()
126                    .find_map(|(idx, ConvCache1d { conv, .. })| {
127                        if grid_conv == conv {
128                            Some((idx, false))
129                        } else if *grid_conv == conv.cc() {
130                            Some((idx, true))
131                        } else {
132                            None
133                        }
134                    })
135                    // TODO: convert `unwrap` to `Err`
136                    .unwrap_or_else(|| {
137                        panic!(
138                        "couldn't match {grid_conv:?} with a convolution function from cache {:?}",
139                        self.caches
140                            .iter()
141                            .map(|cache| cache.conv.clone())
142                            .collect::<Vec<_>>()
143                    )
144                    })
145            })
146            .collect();
147
148        GridConvCache {
149            cache: self,
150            perm,
151            imu2: [const { Vec::new() }; SCALES_CNT],
152            scales: grid.scales().clone(),
153            ix: Vec::new(),
154            scale_dims: Vec::new(),
155        }
156    }
157
158    /// Clears the cache.
159    pub fn clear(&mut self) {
160        self.alphas_cache.clear();
161        for xfx_cache in &mut self.caches {
162            xfx_cache.cache.clear();
163        }
164        for scales in &mut self.mu2 {
165            scales.clear();
166        }
167        self.x_grid.clear();
168    }
169}
170
171/// TODO
172pub struct GridConvCache<'a, 'b> {
173    cache: &'b mut ConvolutionCache<'a>,
174    perm: Vec<(usize, bool)>,
175    imu2: [Vec<usize>; SCALES_CNT],
176    scales: Scales,
177    ix: Vec<Vec<usize>>,
178    scale_dims: Vec<usize>,
179}
180
181impl GridConvCache<'_, '_> {
182    /// TODO
183    pub fn as_fx_prod(&mut self, pdg_ids: &[i32], as_order: u8, indices: &[usize]) -> f64 {
184        // TODO: here we assume that
185        // - indices[0] is the (squared) factorization scale,
186        // - indices[1] is x1 and
187        // - indices[2] is x2.
188        // Lift this restriction!
189        let x_start = indices.len() - pdg_ids.len();
190        let indices_scales = &indices[0..x_start];
191        let indices_x = &indices[x_start..];
192
193        let ix = self.ix.iter().zip(indices_x).map(|(ix, &index)| ix[index]);
194        let idx_pid = self.perm.iter().zip(pdg_ids).map(|(&(idx, cc), &pdg_id)| {
195            (
196                idx,
197                if cc {
198                    pids::charge_conjugate_pdg_pid(pdg_id)
199                } else {
200                    pdg_id
201                },
202            )
203        });
204
205        let fx_prod: f64 = ix
206            .zip(idx_pid)
207            .map(|(ix, (idx, pid))| {
208                let ConvCache1d { xfx, cache, conv } = &mut self.cache.caches[idx];
209
210                let (scale, scale_idx) = match conv.conv_type() {
211                    ConvType::UnpolPDF | ConvType::PolPDF => (
212                        FAC_IDX,
213                        self.scales.fac.idx(indices_scales, &self.scale_dims),
214                    ),
215                    ConvType::UnpolFF | ConvType::PolFF => (
216                        FRG_IDX,
217                        self.scales.frg.idx(indices_scales, &self.scale_dims),
218                    ),
219                };
220
221                let imu2 = self.imu2[scale][scale_idx];
222                let mu2 = self.cache.mu2[scale][imu2];
223
224                *cache.entry((pid, ix, imu2)).or_insert_with(|| {
225                    let x = self.cache.x_grid[ix];
226                    xfx(pid, x, mu2) / x
227                })
228            })
229            .product();
230        let alphas_powers = if as_order != 0 {
231            let ren_scale_idx = self.scales.ren.idx(indices_scales, &self.scale_dims);
232            self.cache.alphas_cache[self.imu2[REN_IDX][ren_scale_idx]].powi(as_order.into())
233        } else {
234            1.0
235        };
236
237        fx_prod * alphas_powers
238    }
239
240    /// Set the grids.
241    pub fn set_grids(&mut self, grid: &Grid, subgrid: &SubgridEnum, xi: (f64, f64, f64)) {
242        let node_values = subgrid.node_values();
243        let kinematics = grid.kinematics();
244        let scales: [_; SCALES_CNT] = grid.scales().into();
245        let xi: [_; SCALES_CNT] = xi.into();
246
247        for (result, values, scale, xi) in izip!(&mut self.imu2, &self.cache.mu2, scales, xi) {
248            result.clear();
249            result.extend(scale.calc(&node_values, kinematics).iter().map(|s| {
250                values
251                    .iter()
252                    .position(|&value| subgrid::node_value_eq(value, xi * xi * s))
253                    // UNWRAP: if this fails, `new_grid_conv_cache` hasn't been called properly
254                    .unwrap_or_else(|| unreachable!())
255            }));
256        }
257
258        self.ix = (0..grid.convolutions().len())
259            .map(|idx| {
260                kinematics
261                    .iter()
262                    .zip(&node_values)
263                    .find_map(|(kin, node_values)| {
264                        matches!(kin, &Kinematics::X(index) if index == idx).then_some(node_values)
265                    })
266                    // UNWRAP: guaranteed by the grid constructor
267                    .unwrap_or_else(|| unreachable!())
268                    .iter()
269                    .map(|&xd| {
270                        self.cache
271                            .x_grid
272                            .iter()
273                            .position(|&x| subgrid::node_value_eq(xd, x))
274                            .unwrap_or_else(|| unreachable!())
275                    })
276                    .collect()
277            })
278            .collect();
279
280        self.scale_dims = grid
281            .kinematics()
282            .iter()
283            .zip(node_values)
284            .filter_map(|(kin, node_values)| {
285                matches!(kin, Kinematics::Scale(_)).then_some(node_values.len())
286            })
287            .collect();
288    }
289}
290
291/// TODO
292#[repr(C)]
293#[derive(Clone, Copy, Debug, Deserialize, Eq, PartialEq, Serialize)]
294pub enum ConvType {
295    /// Unpolarized parton distribution function.
296    UnpolPDF,
297    /// Polarized parton distribution function.
298    PolPDF,
299    /// Unpolarized fragmentation function.
300    UnpolFF,
301    /// Polarized fragmentation function.
302    PolFF,
303}
304
305impl ConvType {
306    /// TODO
307    #[must_use]
308    pub const fn new(polarized: bool, time_like: bool) -> Self {
309        match (polarized, time_like) {
310            (false, false) => Self::UnpolPDF,
311            (false, true) => Self::UnpolFF,
312            (true, false) => Self::PolPDF,
313            (true, true) => Self::PolFF,
314        }
315    }
316
317    /// TODO
318    #[must_use]
319    pub const fn is_pdf(&self) -> bool {
320        matches!(self, Self::UnpolPDF | Self::PolPDF)
321    }
322
323    /// TODO
324    #[must_use]
325    pub const fn is_ff(&self) -> bool {
326        matches!(self, Self::UnpolFF | Self::PolFF)
327    }
328}
329
330/// Data type that indentifies different types of convolutions.
331#[repr(C)]
332#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
333pub struct Conv {
334    conv_type: ConvType,
335    pid: i32,
336}
337
338impl Conv {
339    /// Constructor.
340    #[must_use]
341    pub const fn new(conv_type: ConvType, pid: i32) -> Self {
342        Self { conv_type, pid }
343    }
344
345    /// Return the convolution if the PID is charged conjugated.
346    #[must_use]
347    pub const fn cc(&self) -> Self {
348        Self {
349            conv_type: self.conv_type,
350            pid: pids::charge_conjugate_pdg_pid(self.pid),
351        }
352    }
353
354    /// Return the PID of the convolution.
355    #[must_use]
356    pub const fn pid(&self) -> i32 {
357        self.pid
358    }
359
360    /// Return the convolution type of this convolution.
361    #[must_use]
362    pub const fn conv_type(&self) -> ConvType {
363        self.conv_type
364    }
365}
366
#[cfg(test)]
mod tests {
    use super::*;

    /// Every convolution type; `Conv` must behave identically for all of them.
    const ALL_TYPES: [ConvType; 4] = [
        ConvType::UnpolPDF,
        ConvType::PolPDF,
        ConvType::UnpolFF,
        ConvType::PolFF,
    ];

    #[test]
    fn conv_cc() {
        for conv_type in ALL_TYPES {
            assert_eq!(
                Conv::new(conv_type, 2212).cc(),
                Conv::new(conv_type, -2212)
            );
        }
    }

    #[test]
    fn conv_pid() {
        for conv_type in ALL_TYPES {
            assert_eq!(Conv::new(conv_type, 2212).pid(), 2212);
        }
    }
}