1use super::boc::Kinematics;
4use super::boc::Scales;
5use super::grid::Grid;
6use super::pids;
7use super::subgrid::{self, Subgrid, SubgridEnum};
8use itertools::izip;
9use rustc_hash::FxHashMap;
10use serde::{Deserialize, Serialize};
11
// Slots of the per-scale tables (`ConvolutionCache::mu2`, `GridConvCache::imu2`). The order is
// the one in which the three scales and the three components of a scale-variation triple are
// converted into arrays.

/// Slot of the `ren` scale; its values are the ones the `alphas` callback is evaluated at.
const REN_IDX: usize = 0;
/// Slot of the `fac` scale, used for the PDF convolution types.
const FAC_IDX: usize = 1;
/// Slot of the `frg` scale, used for the FF convolution types.
const FRG_IDX: usize = 2;
/// Number of scale slots, i.e. one more than the last slot index.
const SCALES_CNT: usize = FRG_IDX + 1;
16
17struct ConvCache1d<'a> {
18 xfx: &'a mut dyn FnMut(i32, f64, f64) -> f64,
19 cache: FxHashMap<(i32, usize, usize), f64>,
20 conv: Conv,
21}
22
23pub struct ConvolutionCache<'a> {
26 caches: Vec<ConvCache1d<'a>>,
27 alphas: &'a mut dyn FnMut(f64) -> f64,
28 alphas_cache: Vec<f64>,
29 mu2: [Vec<f64>; SCALES_CNT],
30 x_grid: Vec<f64>,
31}
32
33impl<'a> ConvolutionCache<'a> {
34 pub fn new(
36 convolutions: Vec<Conv>,
37 xfx: Vec<&'a mut dyn FnMut(i32, f64, f64) -> f64>,
38 alphas: &'a mut dyn FnMut(f64) -> f64,
39 ) -> Self {
40 Self {
41 caches: xfx
42 .into_iter()
43 .zip(convolutions)
44 .map(|(xfx, conv)| ConvCache1d {
45 xfx,
46 cache: FxHashMap::default(),
47 conv,
48 })
49 .collect(),
50 alphas,
51 alphas_cache: Vec::new(),
52 mu2: [const { Vec::new() }; SCALES_CNT],
53 x_grid: Vec::new(),
54 }
55 }
56
57 pub(crate) fn new_grid_conv_cache<'b>(
58 &'b mut self,
59 grid: &Grid,
60 xi: &[(f64, f64, f64)],
61 ) -> GridConvCache<'a, 'b> {
62 self.clear();
64
65 let scales: [_; SCALES_CNT] = grid.scales().into();
66 let xi: Vec<_> = (0..SCALES_CNT)
67 .map(|idx| {
68 let mut vars: Vec<_> = xi
69 .iter()
70 .map(|&x| <[_; SCALES_CNT]>::from(x)[idx])
71 .collect();
72 vars.sort_by(f64::total_cmp);
73 vars.dedup();
74 vars
75 })
76 .collect();
77
78 for (result, scale, xi) in izip!(&mut self.mu2, scales, xi) {
79 result.clear();
80 result.extend(
81 grid.subgrids()
82 .iter()
83 .filter(|subgrid| !subgrid.is_empty())
84 .flat_map(|subgrid| {
85 scale
86 .calc(&subgrid.node_values(), grid.kinematics())
87 .into_owned()
88 })
89 .flat_map(|scale| xi.iter().map(move |&xi| xi * xi * scale)),
90 );
91 result.sort_by(f64::total_cmp);
92 result.dedup();
93 }
94
95 let mut x_grid: Vec<_> = grid
96 .subgrids()
97 .iter()
98 .filter(|subgrid| !subgrid.is_empty())
99 .flat_map(|subgrid| {
100 grid.kinematics()
101 .iter()
102 .zip(subgrid.node_values())
103 .filter(|(kin, _)| matches!(kin, Kinematics::X(_)))
104 .flat_map(|(_, node_values)| node_values)
105 })
106 .collect();
107 x_grid.sort_by(f64::total_cmp);
108 x_grid.dedup();
109
110 self.alphas_cache = self.mu2[REN_IDX]
111 .iter()
112 .map(|&mur2| (self.alphas)(mur2))
113 .collect();
114 self.x_grid = x_grid;
115
116 let perm = grid
117 .convolutions()
118 .iter()
119 .enumerate()
120 .map(|(max_idx, grid_conv)| {
121 self.caches
122 .iter()
123 .take(max_idx + 1)
124 .enumerate()
125 .rev()
126 .find_map(|(idx, ConvCache1d { conv, .. })| {
127 if grid_conv == conv {
128 Some((idx, false))
129 } else if *grid_conv == conv.cc() {
130 Some((idx, true))
131 } else {
132 None
133 }
134 })
135 .unwrap_or_else(|| {
137 panic!(
138 "couldn't match {grid_conv:?} with a convolution function from cache {:?}",
139 self.caches
140 .iter()
141 .map(|cache| cache.conv.clone())
142 .collect::<Vec<_>>()
143 )
144 })
145 })
146 .collect();
147
148 GridConvCache {
149 cache: self,
150 perm,
151 imu2: [const { Vec::new() }; SCALES_CNT],
152 scales: grid.scales().clone(),
153 ix: Vec::new(),
154 scale_dims: Vec::new(),
155 }
156 }
157
158 pub fn clear(&mut self) {
160 self.alphas_cache.clear();
161 for xfx_cache in &mut self.caches {
162 xfx_cache.cache.clear();
163 }
164 for scales in &mut self.mu2 {
165 scales.clear();
166 }
167 self.x_grid.clear();
168 }
169}
170
171pub struct GridConvCache<'a, 'b> {
173 cache: &'b mut ConvolutionCache<'a>,
174 perm: Vec<(usize, bool)>,
175 imu2: [Vec<usize>; SCALES_CNT],
176 scales: Scales,
177 ix: Vec<Vec<usize>>,
178 scale_dims: Vec<usize>,
179}
180
181impl GridConvCache<'_, '_> {
182 pub fn as_fx_prod(&mut self, pdg_ids: &[i32], as_order: u8, indices: &[usize]) -> f64 {
184 let x_start = indices.len() - pdg_ids.len();
190 let indices_scales = &indices[0..x_start];
191 let indices_x = &indices[x_start..];
192
193 let ix = self.ix.iter().zip(indices_x).map(|(ix, &index)| ix[index]);
194 let idx_pid = self.perm.iter().zip(pdg_ids).map(|(&(idx, cc), &pdg_id)| {
195 (
196 idx,
197 if cc {
198 pids::charge_conjugate_pdg_pid(pdg_id)
199 } else {
200 pdg_id
201 },
202 )
203 });
204
205 let fx_prod: f64 = ix
206 .zip(idx_pid)
207 .map(|(ix, (idx, pid))| {
208 let ConvCache1d { xfx, cache, conv } = &mut self.cache.caches[idx];
209
210 let (scale, scale_idx) = match conv.conv_type() {
211 ConvType::UnpolPDF | ConvType::PolPDF => (
212 FAC_IDX,
213 self.scales.fac.idx(indices_scales, &self.scale_dims),
214 ),
215 ConvType::UnpolFF | ConvType::PolFF => (
216 FRG_IDX,
217 self.scales.frg.idx(indices_scales, &self.scale_dims),
218 ),
219 };
220
221 let imu2 = self.imu2[scale][scale_idx];
222 let mu2 = self.cache.mu2[scale][imu2];
223
224 *cache.entry((pid, ix, imu2)).or_insert_with(|| {
225 let x = self.cache.x_grid[ix];
226 xfx(pid, x, mu2) / x
227 })
228 })
229 .product();
230 let alphas_powers = if as_order != 0 {
231 let ren_scale_idx = self.scales.ren.idx(indices_scales, &self.scale_dims);
232 self.cache.alphas_cache[self.imu2[REN_IDX][ren_scale_idx]].powi(as_order.into())
233 } else {
234 1.0
235 };
236
237 fx_prod * alphas_powers
238 }
239
240 pub fn set_grids(&mut self, grid: &Grid, subgrid: &SubgridEnum, xi: (f64, f64, f64)) {
242 let node_values = subgrid.node_values();
243 let kinematics = grid.kinematics();
244 let scales: [_; SCALES_CNT] = grid.scales().into();
245 let xi: [_; SCALES_CNT] = xi.into();
246
247 for (result, values, scale, xi) in izip!(&mut self.imu2, &self.cache.mu2, scales, xi) {
248 result.clear();
249 result.extend(scale.calc(&node_values, kinematics).iter().map(|s| {
250 values
251 .iter()
252 .position(|&value| subgrid::node_value_eq(value, xi * xi * s))
253 .unwrap_or_else(|| unreachable!())
255 }));
256 }
257
258 self.ix = (0..grid.convolutions().len())
259 .map(|idx| {
260 kinematics
261 .iter()
262 .zip(&node_values)
263 .find_map(|(kin, node_values)| {
264 matches!(kin, &Kinematics::X(index) if index == idx).then_some(node_values)
265 })
266 .unwrap_or_else(|| unreachable!())
268 .iter()
269 .map(|&xd| {
270 self.cache
271 .x_grid
272 .iter()
273 .position(|&x| subgrid::node_value_eq(xd, x))
274 .unwrap_or_else(|| unreachable!())
275 })
276 .collect()
277 })
278 .collect();
279
280 self.scale_dims = grid
281 .kinematics()
282 .iter()
283 .zip(node_values)
284 .filter_map(|(kin, node_values)| {
285 matches!(kin, Kinematics::Scale(_)).then_some(node_values.len())
286 })
287 .collect();
288 }
289}
290
/// Kind of a convolution function, distinguished by two properties: whether it is polarized
/// and whether it is time-like (see [`ConvType::new`]).
///
/// The enum is `#[repr(C)]` without explicit discriminants, so the declaration order of the
/// variants determines their discriminant values.
#[repr(C)]
#[derive(Clone, Copy, Debug, Deserialize, Eq, PartialEq, Serialize)]
pub enum ConvType {
    /// Not polarized, not time-like.
    UnpolPDF,
    /// Polarized, not time-like.
    PolPDF,
    /// Not polarized, time-like.
    UnpolFF,
    /// Polarized, time-like.
    PolFF,
}
304
305impl ConvType {
306 #[must_use]
308 pub const fn new(polarized: bool, time_like: bool) -> Self {
309 match (polarized, time_like) {
310 (false, false) => Self::UnpolPDF,
311 (false, true) => Self::UnpolFF,
312 (true, false) => Self::PolPDF,
313 (true, true) => Self::PolFF,
314 }
315 }
316
317 #[must_use]
319 pub const fn is_pdf(&self) -> bool {
320 matches!(self, Self::UnpolPDF | Self::PolPDF)
321 }
322
323 #[must_use]
325 pub const fn is_ff(&self) -> bool {
326 matches!(self, Self::UnpolFF | Self::PolFF)
327 }
328}
329
/// Describes a convolution function by its [`ConvType`] and a particle ID.
///
/// Two values are equal only if both the type and the particle ID agree.
#[repr(C)]
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
pub struct Conv {
    /// Kind of the convolution function.
    conv_type: ConvType,
    /// Particle ID; [`Conv::cc`] treats it as a PDG ID when charge conjugating.
    pid: i32,
}
337
338impl Conv {
339 #[must_use]
341 pub const fn new(conv_type: ConvType, pid: i32) -> Self {
342 Self { conv_type, pid }
343 }
344
345 #[must_use]
347 pub const fn cc(&self) -> Self {
348 Self {
349 conv_type: self.conv_type,
350 pid: pids::charge_conjugate_pdg_pid(self.pid),
351 }
352 }
353
354 #[must_use]
356 pub const fn pid(&self) -> i32 {
357 self.pid
358 }
359
360 #[must_use]
362 pub const fn conv_type(&self) -> ConvType {
363 self.conv_type
364 }
365}
366
#[cfg(test)]
mod tests {
    use super::*;

    /// Every variant of [`ConvType`], so that each test covers all of them.
    const ALL_CONV_TYPES: [ConvType; 4] = [
        ConvType::UnpolPDF,
        ConvType::PolPDF,
        ConvType::UnpolFF,
        ConvType::PolFF,
    ];

    /// Charge conjugation flips the sign of the ID 2212 and keeps the convolution type.
    #[test]
    fn conv_cc() {
        for conv_type in ALL_CONV_TYPES {
            assert_eq!(
                Conv::new(conv_type, 2212).cc(),
                Conv::new(conv_type, -2212)
            );
        }
    }

    /// The particle ID is returned unchanged for every convolution type.
    #[test]
    fn conv_pid() {
        for conv_type in ALL_CONV_TYPES {
            assert_eq!(Conv::new(conv_type, 2212).pid(), 2212);
        }
    }
}