use ndarray::{s, Array1, Array6, ArrayView2};
use serde::{Deserialize, Serialize};
use super::interpolator::InterpolationConfig;
/// A pair of `f64` bounds describing the interval `[min, max]`.
///
/// The bounds are stored exactly as supplied; nothing in this type checks
/// that `min <= max`.
#[derive(Debug, Clone, Copy, Deserialize, Serialize)]
pub struct ParamRange {
/// Lower bound; treated as inclusive by `ParamRange::contains`.
pub min: f64,
/// Upper bound; treated as inclusive by `ParamRange::contains`.
pub max: f64,
}
impl ParamRange {
pub fn new(min: f64, max: f64) -> Self {
Self { min, max }
}
pub fn contains(&self, value: f64) -> bool {
value >= self.min && value <= self.max
}
}
/// The five per-axis [`ParamRange`]s of a subgrid, bundled into one value.
///
/// This is the type returned by `SubGrid::ranges`.
#[derive(Debug, Clone, Copy)]
pub struct RangeParameters {
    /// Range of the `nucleons` axis.
    pub nucleons: ParamRange,
    /// Range of the `alphas` axis.
    pub alphas: ParamRange,
    /// Range of the `kt` axis.
    pub kt: ParamRange,
    /// Range of the `x` axis.
    pub x: ParamRange,
    /// Range of the `q2` axis.
    pub q2: ParamRange,
}
impl RangeParameters {
pub fn new(
nucleons: ParamRange,
alphas: ParamRange,
kt: ParamRange,
x: ParamRange,
q2: ParamRange,
) -> Self {
Self {
nucleons,
alphas,
kt,
x,
q2,
}
}
}
/// A six-dimensional table of values together with the knots of its axes.
///
/// When built through `SubGrid::new`, `grid` has its axes ordered as
/// `(nucleons, alphas, flavor, kt, x, q2)` and every `*_range` field holds the
/// first and last entry of the corresponding knot array.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SubGrid {
/// Knots along the `x` axis.
pub xs: Array1<f64>,
/// Knots along the `q2` axis.
pub q2s: Array1<f64>,
/// Knots along the `kt` axis.
pub kts: Array1<f64>,
/// Tabulated values; see the struct-level note for the axis order.
pub grid: Array6<f64>,
/// Knots along the `nucleons` axis.
pub nucleons: Array1<f64>,
/// Knots along the `alphas` axis.
pub alphas: Array1<f64>,
/// First/last entry of `nucleons` as recorded by `SubGrid::new`.
pub nucleons_range: ParamRange,
/// First/last entry of `alphas` as recorded by `SubGrid::new`.
pub alphas_range: ParamRange,
/// First/last entry of `kts` as recorded by `SubGrid::new`.
pub kt_range: ParamRange,
/// First/last entry of `xs` as recorded by `SubGrid::new`.
pub x_range: ParamRange,
/// First/last entry of `q2s` as recorded by `SubGrid::new`.
pub q2_range: ParamRange,
}
impl SubGrid {
pub fn new(
nucleon_numbers: Vec<f64>,
alphas_values: Vec<f64>,
kt_subgrid: Vec<f64>,
x_subgrid: Vec<f64>,
q2_subgrid: Vec<f64>,
nflav: usize,
grid_data: Vec<f64>,
) -> Self {
let xs_range = ParamRange::new(*x_subgrid.first().unwrap(), *x_subgrid.last().unwrap());
let q2s_range = ParamRange::new(*q2_subgrid.first().unwrap(), *q2_subgrid.last().unwrap());
let kts_range = ParamRange::new(*kt_subgrid.first().unwrap(), *kt_subgrid.last().unwrap());
let ncs_range = ParamRange::new(
*nucleon_numbers.first().unwrap(),
*nucleon_numbers.last().unwrap(),
);
let as_range = ParamRange::new(
*alphas_values.first().unwrap(),
*alphas_values.last().unwrap(),
);
let subgrid = Array6::from_shape_vec(
(
nucleon_numbers.len(),
alphas_values.len(),
kt_subgrid.len(),
x_subgrid.len(),
q2_subgrid.len(),
nflav,
),
grid_data,
)
.expect("Failed to create grid")
.permuted_axes([0, 1, 5, 2, 3, 4])
.as_standard_layout()
.to_owned();
Self {
xs: Array1::from_vec(x_subgrid),
q2s: Array1::from_vec(q2_subgrid),
kts: Array1::from_vec(kt_subgrid),
grid: subgrid,
nucleons: Array1::from_vec(nucleon_numbers),
alphas: Array1::from_vec(alphas_values),
nucleons_range: ncs_range,
alphas_range: as_range,
kt_range: kts_range,
x_range: xs_range,
q2_range: q2s_range,
}
}
pub fn contains_point(&self, points: &[f64]) -> bool {
let (expected_len, ranges) = match self.interpolation_config() {
InterpolationConfig::TwoD => (2, vec![]),
InterpolationConfig::ThreeDNucleons => (3, vec![&self.nucleons_range]),
InterpolationConfig::ThreeDAlphas => (3, vec![&self.alphas_range]),
InterpolationConfig::ThreeDKt => (3, vec![&self.kt_range]),
InterpolationConfig::FourDNucleonsAlphas => {
(4, vec![&self.nucleons_range, &self.alphas_range])
}
InterpolationConfig::FourDNucleonsKt => (4, vec![&self.nucleons_range, &self.kt_range]),
InterpolationConfig::FourDAlphasKt => (4, vec![&self.alphas_range, &self.kt_range]),
InterpolationConfig::FiveD => (
5,
vec![&self.nucleons_range, &self.alphas_range, &self.kt_range],
),
};
points.len() == expected_len
&& self.x_range.contains(points[expected_len - 2])
&& self.q2_range.contains(points[expected_len - 1])
&& ranges
.iter()
.zip(points)
.all(|(range, &point)| range.contains(point))
}
pub fn distance_to_point(&self, points: &[f64]) -> f64 {
self.parameter_ranges()
.iter()
.zip(points)
.map(|(range, &point)| match point {
p if p < range.min => (range.min - p) * (range.min - p),
p if p > range.max => (p - range.max) * (p - range.max),
_ => 0.0,
})
.sum()
}
fn parameter_ranges(&self) -> Vec<ParamRange> {
let mut ranges = match self.interpolation_config() {
InterpolationConfig::TwoD => vec![],
InterpolationConfig::ThreeDNucleons => vec![self.nucleons_range],
InterpolationConfig::ThreeDAlphas => vec![self.alphas_range],
InterpolationConfig::ThreeDKt => vec![self.kt_range],
InterpolationConfig::FourDNucleonsAlphas => {
vec![self.nucleons_range, self.alphas_range]
}
InterpolationConfig::FourDNucleonsKt => vec![self.nucleons_range, self.kt_range],
InterpolationConfig::FourDAlphasKt => vec![self.alphas_range, self.kt_range],
InterpolationConfig::FiveD => {
vec![self.nucleons_range, self.alphas_range, self.kt_range]
}
};
ranges.extend([self.x_range, self.q2_range]);
ranges
}
pub fn interpolation_config(&self) -> InterpolationConfig {
InterpolationConfig::from_dimensions(self.nucleons.len(), self.alphas.len(), self.kts.len())
}
pub fn ranges(&self) -> RangeParameters {
RangeParameters::new(
self.nucleons_range,
self.alphas_range,
self.kt_range,
self.x_range,
self.q2_range,
)
}
pub fn grid_slice(&self, pid_index: usize) -> ArrayView2<f64> {
match self.interpolation_config() {
InterpolationConfig::TwoD => self.grid.slice(s![0, 0, pid_index, 0, .., ..]),
_ => panic!("grid_slice only valid for 2D interpolation"),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Interior values are contained; values above the maximum are not.
    #[test]
    fn test_param_range() {
        let range = ParamRange::new(1.0, 10.0);
        assert!(range.contains(5.0));
        assert!(!range.contains(15.0));
    }

    /// Both end points belong to the range.
    #[test]
    fn test_param_range_bounds_are_inclusive() {
        let range = ParamRange::new(1.0, 10.0);
        assert!(range.contains(1.0));
        assert!(range.contains(10.0));
    }

    /// Values below the minimum, NaN, and anything tested against an
    /// inverted range are rejected.
    #[test]
    fn test_param_range_rejects_outside_values() {
        let range = ParamRange::new(1.0, 10.0);
        assert!(!range.contains(0.5));
        assert!(!range.contains(f64::NAN));
        assert!(!ParamRange::new(10.0, 1.0).contains(5.0));
    }
}