use super::boc::Kinematics;
use super::boc::Scales;
use super::grid::Grid;
use super::pids;
use super::subgrid::{self, Subgrid, SubgridEnum};
use itertools::izip;
use rustc_hash::FxHashMap;
use serde::{Deserialize, Serialize};
/// Position of the `ren` scale in the per-scale arrays (`mu2`, `imu2`); `alphas` is evaluated at
/// these scales.
const REN_IDX: usize = 0;
/// Position of the `fac` scale in the per-scale arrays; used for PDF-type convolutions.
const FAC_IDX: usize = REN_IDX + 1;
/// Position of the `frg` scale in the per-scale arrays; used for FF-type convolutions.
const FRG_IDX: usize = FAC_IDX + 1;
/// Number of scale types, i.e. the length of every per-scale array.
///
/// NOTE(review): assumes `<[_; SCALES_CNT]>::from(grid.scales())` yields `[ren, fac, frg]` in
/// exactly this order — confirm against the `From<&Scales>` implementation.
const SCALES_CNT: usize = FRG_IDX + 1;
/// Memoized values of a single convolution function.
struct ConvCache1d<'a> {
    /// The convolution described by `xfx`; grid convolutions are matched against it directly
    /// or via charge conjugation.
    conv: Conv,
    /// The convolution function, called as `xfx(pid, x, mu2)`.
    xfx: &'a mut dyn FnMut(i32, f64, f64) -> f64,
    /// Memoized values of `xfx(pid, x, mu2) / x`, keyed by the particle id and by the positions
    /// of `x` and `mu2` in the grids of the owning `ConvolutionCache`.
    cache: FxHashMap<(i32, usize, usize), f64>,
}
/// Memoizes the values of convolution functions and of the coupling `alphas` while the grids
/// are being convolved.
pub struct ConvolutionCache<'a> {
    /// Coupling function, evaluated at the squared `ren` scales of the current grid.
    alphas: &'a mut dyn FnMut(f64) -> f64,
    /// `alphas` evaluated at every entry of `mu2[REN_IDX]`, in the same order.
    alphas_cache: Vec<f64>,
    /// One memoization table per convolution function.
    caches: Vec<ConvCache1d<'a>>,
    /// For every scale type, the sorted and deduplicated squared scales of the current grid.
    mu2: [Vec<f64>; SCALES_CNT],
    /// Sorted and deduplicated `x` node values of the current grid.
    x_grid: Vec<f64>,
}
impl<'a> ConvolutionCache<'a> {
    /// Creates a cache that pairs the `i`-th entry of `convolutions` with the `i`-th
    /// convolution function in `xfx`; `alphas` provides the coupling values.
    ///
    /// Nothing is evaluated here. `alphas` is called when a grid cache is created with
    /// `new_grid_conv_cache`. The convolution functions are called as `xfx(pid, x, mu2)` on
    /// demand by [`GridConvCache::as_fx_prod`].
    ///
    /// # Panics
    ///
    /// Panics if `convolutions` and `xfx` have different lengths.
    pub fn new(
        convolutions: Vec<Conv>,
        xfx: Vec<&'a mut dyn FnMut(i32, f64, f64) -> f64>,
        alphas: &'a mut dyn FnMut(f64) -> f64,
    ) -> Self {
        // `zip` below would silently drop the surplus entries of the longer vector, leaving a
        // convolution without a function (or the other way around) and causing confusing
        // failures later on
        assert_eq!(
            convolutions.len(),
            xfx.len(),
            "the number of convolutions must match the number of convolution functions"
        );

        Self {
            caches: xfx
                .into_iter()
                .zip(convolutions)
                .map(|(xfx, conv)| ConvCache1d {
                    xfx,
                    cache: FxHashMap::default(),
                    conv,
                })
                .collect(),
            alphas,
            alphas_cache: Vec::new(),
            mu2: [const { Vec::new() }; SCALES_CNT],
            x_grid: Vec::new(),
        }
    }

    /// Prepares the cache for evaluating `grid` with the scale-variation factors `xi` and
    /// returns a [`GridConvCache`] for it.
    ///
    /// First, all previously memoized values are discarded. Then, for every scale type, the
    /// scale values of all non-empty subgrids are multiplied by `xi * xi` for every distinct
    /// factor and collected. The union of the `x` node values of all non-empty subgrids is
    /// collected too. Both collections are sorted and deduplicated. Finally, `alphas` is
    /// evaluated once for every collected `ren` scale.
    ///
    /// The grid convolution with index `i` is matched against the convolution functions with
    /// indices `0..=i`, preferring the highest index, either directly or via charge
    /// conjugation.
    ///
    /// # Panics
    ///
    /// Panics if a convolution of `grid` can't be matched this way.
    pub(crate) fn new_grid_conv_cache<'b>(
        &'b mut self,
        grid: &Grid,
        xi: &[(f64, f64, f64)],
    ) -> GridConvCache<'a, 'b> {
        // memoized values are keyed by positions in the grids computed below, which are only
        // valid for a single grid
        self.clear();

        let scales: [_; SCALES_CNT] = grid.scales().into();
        // for every scale type, the sorted and deduplicated variation factors
        let xi: Vec<_> = (0..SCALES_CNT)
            .map(|idx| {
                let mut vars: Vec<_> = xi
                    .iter()
                    .map(|&x| <[_; SCALES_CNT]>::from(x)[idx])
                    .collect();
                vars.sort_by(f64::total_cmp);
                vars.dedup();
                vars
            })
            .collect();

        // `self.clear()` above has already emptied every vector in `self.mu2`
        for (result, scale, xi) in izip!(&mut self.mu2, scales, xi) {
            result.extend(
                grid.subgrids()
                    .iter()
                    .filter(|subgrid| !subgrid.is_empty())
                    .flat_map(|subgrid| {
                        scale
                            .calc(&subgrid.node_values(), grid.kinematics())
                            .into_owned()
                    })
                    .flat_map(|scale| xi.iter().map(move |&xi| xi * xi * scale)),
            );
            result.sort_by(f64::total_cmp);
            result.dedup();
        }

        let mut x_grid: Vec<_> = grid
            .subgrids()
            .iter()
            .filter(|subgrid| !subgrid.is_empty())
            .flat_map(|subgrid| {
                grid.kinematics()
                    .iter()
                    .zip(subgrid.node_values())
                    .filter(|(kin, _)| matches!(kin, Kinematics::X(_)))
                    .flat_map(|(_, node_values)| node_values)
            })
            .collect();
        x_grid.sort_by(f64::total_cmp);
        x_grid.dedup();

        self.alphas_cache = self.mu2[REN_IDX]
            .iter()
            .map(|&mur2| (self.alphas)(mur2))
            .collect();
        self.x_grid = x_grid;

        // for every grid convolution: (index into `self.caches`, needs charge conjugation)
        let perm = grid
            .convolutions()
            .iter()
            .enumerate()
            .map(|(max_idx, grid_conv)| {
                self.caches
                    .iter()
                    .take(max_idx + 1)
                    .enumerate()
                    .rev()
                    .find_map(|(idx, ConvCache1d { conv, .. })| {
                        if grid_conv == conv {
                            Some((idx, false))
                        } else if *grid_conv == conv.cc() {
                            Some((idx, true))
                        } else {
                            None
                        }
                    })
                    .unwrap_or_else(|| {
                        panic!(
                            "couldn't match {grid_conv:?} with a convolution function from cache {:?}",
                            self.caches
                                .iter()
                                .map(|cache| cache.conv.clone())
                                .collect::<Vec<_>>()
                        )
                    })
            })
            .collect();

        GridConvCache {
            cache: self,
            perm,
            imu2: [const { Vec::new() }; SCALES_CNT],
            scales: grid.scales().clone(),
            ix: Vec::new(),
            scale_dims: Vec::new(),
        }
    }

    /// Discards all memoized values: the coupling values, the values of every convolution
    /// function, the per-scale grids and the `x` grid.
    pub fn clear(&mut self) {
        self.alphas_cache.clear();
        for xfx_cache in &mut self.caches {
            xfx_cache.cache.clear();
        }
        for scales in &mut self.mu2 {
            scales.clear();
        }
        self.x_grid.clear();
    }
}
/// View of a [`ConvolutionCache`] that has been prepared for a single grid with
/// `ConvolutionCache::new_grid_conv_cache`.
pub struct GridConvCache<'a, 'b> {
    /// Parent cache that owns the memoized values and the combined grids.
    cache: &'b mut ConvolutionCache<'a>,
    /// For every grid convolution, the index of the matching entry in `cache.caches` and
    /// whether particle ids must be charge conjugated before evaluation.
    perm: Vec<(usize, bool)>,
    /// Scale functions of the grid.
    scales: Scales,
    /// For every scale type, the position of each scale of the current subgrid in `cache.mu2`.
    imu2: [Vec<usize>; SCALES_CNT],
    /// For every convolution, the position of each `x` node of the current subgrid in
    /// `cache.x_grid`.
    ix: Vec<Vec<usize>>,
    /// Number of nodes of every scale dimension of the current subgrid.
    scale_dims: Vec<usize>,
}
impl GridConvCache<'_, '_> {
    /// Returns the product of the convolution-function values for `pdg_ids` at the node
    /// addressed by `indices`. If `as_order` is non-zero, the product also includes the cached
    /// coupling raised to the power `as_order`.
    ///
    /// `indices` holds the scale indices of the current subgrid first, followed by one `x`
    /// index per entry of `pdg_ids`. Each function value is `xfx(pid, x, mu2) / x` and is
    /// memoized. PDF-type convolutions use the `fac` scale, FF-type ones the `frg` scale. For
    /// convolutions matched via charge conjugation, the particle id is conjugated first.
    ///
    /// # Panics
    ///
    /// Panics if `indices` has fewer entries than `pdg_ids` or if an index is out of range.
    pub fn as_fx_prod(&mut self, pdg_ids: &[i32], as_order: u8, indices: &[usize]) -> f64 {
        // the scale indices come first, the `x` indices (one per convolution) last
        let x_start = indices.len() - pdg_ids.len();
        let scale_indices = &indices[..x_start];
        let x_indices = &indices[x_start..];

        let mut fx_prod = 1.0;

        for ((x_map, &x_index), (&(cache_idx, conjugate), &pdg_id)) in self
            .ix
            .iter()
            .zip(x_indices)
            .zip(self.perm.iter().zip(pdg_ids))
        {
            // position of this node in the combined `x` grid of the parent cache
            let ix = x_map[x_index];
            let pid = if conjugate {
                pids::charge_conjugate_pdg_pid(pdg_id)
            } else {
                pdg_id
            };

            let ConvCache1d {
                xfx,
                cache: memo,
                conv,
            } = &mut self.cache.caches[cache_idx];

            let (scale, scale_func) = match conv.conv_type() {
                ConvType::UnpolPDF | ConvType::PolPDF => (FAC_IDX, &self.scales.fac),
                ConvType::UnpolFF | ConvType::PolFF => (FRG_IDX, &self.scales.frg),
            };
            let imu2 = self.imu2[scale][scale_func.idx(scale_indices, &self.scale_dims)];
            let mu2 = self.cache.mu2[scale][imu2];
            let x_grid = &self.cache.x_grid;

            fx_prod *= *memo.entry((pid, ix, imu2)).or_insert_with(|| {
                let x = x_grid[ix];
                xfx(pid, x, mu2) / x
            });
        }

        let alphas_powers = match as_order {
            0 => 1.0,
            order => {
                let ren_idx = self.scales.ren.idx(scale_indices, &self.scale_dims);
                self.cache.alphas_cache[self.imu2[REN_IDX][ren_idx]].powi(i32::from(order))
            }
        };

        fx_prod * alphas_powers
    }

    /// Prepares the cache for evaluating `subgrid` of `grid` with the scale-variation factors
    /// `xi`.
    ///
    /// The method does three things. For every scale type, it stores the position of each
    /// subgrid scale (multiplied by the squared factor) in the parent cache's scale grid. For
    /// every convolution, it stores the position of each `x` node in the parent cache's `x`
    /// grid. Finally, it records the number of nodes of every scale dimension of `subgrid`.
    ///
    /// # Panics
    ///
    /// Panics if a node value is missing from the parent cache's grids. Presumably this only
    /// happens if `grid`, `subgrid` or `xi` were not part of the preceding
    /// `new_grid_conv_cache` call. It also panics if `grid` lacks an `x` kinematic for one of
    /// its convolutions.
    pub fn set_grids(&mut self, grid: &Grid, subgrid: &SubgridEnum, xi: (f64, f64, f64)) {
        let node_values = subgrid.node_values();
        let kinematics = grid.kinematics();
        let scale_funcs: [_; SCALES_CNT] = grid.scales().into();
        let factors: [_; SCALES_CNT] = xi.into();

        for (positions, mu2_grid, scale_func, factor) in
            izip!(&mut self.imu2, &self.cache.mu2, scale_funcs, factors)
        {
            positions.clear();
            for &mu2 in scale_func.calc(&node_values, kinematics).iter() {
                let target = factor * factor * mu2;
                let position = mu2_grid
                    .iter()
                    .position(|&value| subgrid::node_value_eq(value, target))
                    .unwrap_or_else(|| unreachable!());
                positions.push(position);
            }
        }

        let x_grid = &self.cache.x_grid;
        self.ix = (0..grid.convolutions().len())
            .map(|conv_idx| {
                let (_, x_nodes) = kinematics
                    .iter()
                    .zip(&node_values)
                    .find(|&(kin, _)| matches!(kin, &Kinematics::X(index) if index == conv_idx))
                    .unwrap_or_else(|| unreachable!());

                x_nodes
                    .iter()
                    .map(|&x_node| {
                        x_grid
                            .iter()
                            .position(|&x| subgrid::node_value_eq(x_node, x))
                            .unwrap_or_else(|| unreachable!())
                    })
                    .collect()
            })
            .collect();

        self.scale_dims = kinematics
            .iter()
            .zip(node_values)
            .filter(|(kin, _)| matches!(kin, Kinematics::Scale(_)))
            .map(|(_, values)| values.len())
            .collect();
    }
}
/// Type of a convolution function, distinguished by polarization and by whether it is a PDF
/// (space-like) or an FF (time-like); see [`ConvType::new`].
///
/// When evaluated, PDF types use the `fac` scale and FF types the `frg` scale.
///
/// NOTE(review): the enum is `#[repr(C)]` and (de)serialized with serde, so the variant order
/// is presumably part of the C and serialization interfaces — confirm before reordering.
#[repr(C)]
#[derive(Clone, Copy, Debug, Deserialize, Eq, PartialEq, Serialize)]
pub enum ConvType {
    /// Unpolarized PDF, `ConvType::new(false, false)`.
    UnpolPDF,
    /// Polarized PDF, `ConvType::new(true, false)`.
    PolPDF,
    /// Unpolarized FF, `ConvType::new(false, true)`.
    UnpolFF,
    /// Polarized FF, `ConvType::new(true, true)`.
    PolFF,
}
impl ConvType {
    /// Returns the convolution type for the given polarization and time-likeness.
    ///
    /// Space-like types (`time_like == false`) are PDFs; time-like ones are FFs.
    #[must_use]
    pub const fn new(polarized: bool, time_like: bool) -> Self {
        if time_like {
            if polarized {
                Self::PolFF
            } else {
                Self::UnpolFF
            }
        } else if polarized {
            Self::PolPDF
        } else {
            Self::UnpolPDF
        }
    }

    /// Returns `true` for the polarized and unpolarized PDF types.
    #[must_use]
    pub const fn is_pdf(&self) -> bool {
        matches!(self, Self::PolPDF | Self::UnpolPDF)
    }

    /// Returns `true` for the polarized and unpolarized FF types, i.e. exactly when
    /// [`Self::is_pdf`] returns `false`.
    #[must_use]
    pub const fn is_ff(&self) -> bool {
        !self.is_pdf()
    }
}
/// Description of a convolution: the [`ConvType`] of the convolution function and the PDG id of
/// the particle it belongs to.
///
/// NOTE(review): `#[repr(C)]` suggests this layout is shared with C code — keep the field order.
#[repr(C)]
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
pub struct Conv {
    /// Type of the convolution function.
    conv_type: ConvType,
    /// PDG id of the particle.
    pid: i32,
}
impl Conv {
    /// Creates a convolution of type `conv_type` for the particle with PDG id `pid`.
    #[must_use]
    pub const fn new(conv_type: ConvType, pid: i32) -> Self {
        Self { pid, conv_type }
    }

    /// Returns the charge-conjugated convolution. It has the same type, with the PDG id
    /// replaced by that of the charge-conjugated particle.
    #[must_use]
    pub const fn cc(&self) -> Self {
        let conjugated = pids::charge_conjugate_pdg_pid(self.pid);
        Self::new(self.conv_type, conjugated)
    }

    /// Returns the PDG id of the particle.
    #[must_use]
    pub const fn pid(&self) -> i32 {
        self.pid
    }

    /// Returns the type of the convolution function.
    #[must_use]
    pub const fn conv_type(&self) -> ConvType {
        self.conv_type
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every convolution type, in declaration order.
    const ALL_CONV_TYPES: [ConvType; 4] = [
        ConvType::UnpolPDF,
        ConvType::PolPDF,
        ConvType::UnpolFF,
        ConvType::PolFF,
    ];

    #[test]
    fn conv_cc() {
        assert_eq!(
            Conv::new(ConvType::UnpolPDF, 2212).cc(),
            Conv::new(ConvType::UnpolPDF, -2212)
        );
        assert_eq!(
            Conv::new(ConvType::PolPDF, 2212).cc(),
            Conv::new(ConvType::PolPDF, -2212)
        );
        assert_eq!(
            Conv::new(ConvType::UnpolFF, 2212).cc(),
            Conv::new(ConvType::UnpolFF, -2212)
        );
        assert_eq!(
            Conv::new(ConvType::PolFF, 2212).cc(),
            Conv::new(ConvType::PolFF, -2212)
        );
    }

    #[test]
    fn conv_pid() {
        assert_eq!(Conv::new(ConvType::UnpolPDF, 2212).pid(), 2212);
        assert_eq!(Conv::new(ConvType::PolPDF, 2212).pid(), 2212);
        assert_eq!(Conv::new(ConvType::UnpolFF, 2212).pid(), 2212);
        assert_eq!(Conv::new(ConvType::PolFF, 2212).pid(), 2212);
    }

    #[test]
    fn conv_conv_type() {
        for conv_type in ALL_CONV_TYPES {
            assert_eq!(Conv::new(conv_type, 2212).conv_type(), conv_type);
            // charge conjugation must not change the type
            assert_eq!(Conv::new(conv_type, 2212).cc().conv_type(), conv_type);
        }
    }

    #[test]
    fn conv_type_new() {
        assert_eq!(ConvType::new(false, false), ConvType::UnpolPDF);
        assert_eq!(ConvType::new(true, false), ConvType::PolPDF);
        assert_eq!(ConvType::new(false, true), ConvType::UnpolFF);
        assert_eq!(ConvType::new(true, true), ConvType::PolFF);
    }

    #[test]
    fn conv_type_is_pdf_is_ff() {
        assert!(ConvType::UnpolPDF.is_pdf());
        assert!(ConvType::PolPDF.is_pdf());
        assert!(!ConvType::UnpolFF.is_pdf());
        assert!(!ConvType::PolFF.is_pdf());

        // every type is exactly one of PDF or FF
        for conv_type in ALL_CONV_TYPES {
            assert_ne!(conv_type.is_pdf(), conv_type.is_ff());
        }
    }
}