use ndarray::{Array1, Array3, ArrayView2};
use crate::terms::basis::{BasisError, MaternNu};
/// Radial kernel specification for kernels evaluated on input locations.
///
/// Each variant stores the parameters of one scalar radial kernel family; this
/// file's test module maps them onto `RadialScalarKind::{Matern, Duchon,
/// PureDuchon, ThinPlate}` respectively.
#[derive(Debug, Clone)]
pub enum RadialInputKernel {
/// Matérn kernel with length scale `length_scale` and Matérn ν given by `nu`.
Matern { length_scale: f64, nu: MaternNu },
/// Duchon kernel that carries a finite `length_scale`.
///
/// The test module derives `kappa = 1 / length_scale` from it when building the
/// partial-fraction coefficients. NOTE(review): `p_order` / `s_order` are presumably
/// the Duchon polynomial and smoothness orders, and `dim` the input-space
/// dimension — confirm against `RadialScalarKind::Duchon`.
DuchonHybrid {
length_scale: f64,
p_order: usize,
s_order: usize,
dim: usize,
},
/// Duchon kernel without a length scale.
///
/// NOTE(review): the meaning of `block_order` is not visible in this file — confirm
/// against `RadialScalarKind::PureDuchon`.
DuchonPure {
block_order: usize,
p_order: usize,
s_order: usize,
dim: usize,
},
/// Thin-plate spline kernel with length scale `length_scale`.
///
/// NOTE(review): `dim` is presumably the input-space dimension (the tests exercise
/// `dim = 2`) — confirm.
ThinPlate { length_scale: f64, dim: usize },
}
impl RadialInputKernel {
pub const fn dim(&self) -> usize {
match self {
RadialInputKernel::Matern { .. } => {
0
}
RadialInputKernel::DuchonHybrid { dim, .. }
| RadialInputKernel::DuchonPure { dim, .. }
| RadialInputKernel::ThinPlate { dim, .. } => *dim,
}
}
}
/// Contracts `grad_phi` against `jet` over the center axis.
///
/// For every observation `n` and coordinate `a` this computes
/// `out[n * d + a] = Σ_k grad_phi[[n, k]] * jet[[n, k, a]]`, i.e. the einsum
/// `"nk,nka->na"` flattened row-major into a vector of length `n_obs * d`.
/// The sum over `k` is accumulated in ascending order starting from `0.0`.
///
/// * `grad_phi` — shape `(n_obs, n_centers)`. NOTE(review): presumably the gradient
///   of some objective with respect to each basis value — confirm with callers.
/// * `jet` — shape `(n_obs, n_centers, d)`. NOTE(review): presumably the derivative
///   of basis value `k` at observation `n` with respect to input coordinate `a` —
///   confirm with callers.
///
/// # Errors
///
/// Returns [`BasisError::DimensionMismatch`] when `grad_phi` does not have the
/// `(n_obs, n_centers)` shape implied by the first two axes of `jet`.
pub fn contract_input_loc_gradient(
    grad_phi: ArrayView2<'_, f64>,
    jet: &Array3<f64>,
) -> Result<Array1<f64>, BasisError> {
    let (n_obs, n_centers, d) = jet.dim();
    let expected = [n_obs, n_centers];
    if grad_phi.shape() != expected {
        return Err(BasisError::DimensionMismatch(format!(
            "contract_input_loc_gradient: grad_phi shape {:?} != expected {:?}",
            grad_phi.shape(),
            expected
        )));
    }

    // Observations are visited in order and coordinates inner-most, which yields
    // exactly the row-major `n * d + a` layout.
    let mut flat = Vec::with_capacity(n_obs * d);
    for (weights, slab) in grad_phi.outer_iter().zip(jet.outer_iter()) {
        // `weights` has length `n_centers`; `slab` is `(n_centers, d)`.
        for a in 0..d {
            let column = slab.column(a);
            // Explicit fold keeps the original `0.0` start and left-to-right order.
            let dot = weights
                .iter()
                .zip(column.iter())
                .fold(0.0_f64, |acc, (&w, &j)| acc + w * j);
            flat.push(dot);
        }
    }
    Ok(Array1::from_vec(flat))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::terms::basis::{RadialScalarKind, duchon_partial_fraction_coeffs};
    use ndarray::{Array2, array};

    /// Builds the `RadialScalarKind` corresponding to `kernel`.
    ///
    /// For `DuchonHybrid`, `kappa = 1 / length_scale` (with the length scale clamped
    /// from below at `1e-300`) is passed to `duchon_partial_fraction_coeffs`; the other
    /// variants copy their fields across unchanged.
    fn to_scalar_kind(kernel: &RadialInputKernel) -> RadialScalarKind {
        match kernel {
            RadialInputKernel::Matern { length_scale, nu } => RadialScalarKind::Matern {
                length_scale: *length_scale,
                nu: *nu,
            },
            RadialInputKernel::DuchonHybrid {
                length_scale,
                p_order,
                s_order,
                dim,
            } => {
                let kappa = 1.0 / length_scale.max(1e-300);
                let coeffs = duchon_partial_fraction_coeffs(*p_order, *s_order, kappa);
                RadialScalarKind::Duchon {
                    length_scale: *length_scale,
                    p_order: *p_order,
                    s_order: *s_order,
                    dim: *dim,
                    coeffs,
                }
            }
            RadialInputKernel::DuchonPure {
                block_order,
                p_order,
                s_order,
                dim,
            } => RadialScalarKind::PureDuchon {
                block_order: *block_order,
                p_order: *p_order,
                s_order: *s_order,
                dim: *dim,
            },
            RadialInputKernel::ThinPlate { length_scale, dim } => RadialScalarKind::ThinPlate {
                length_scale: *length_scale,
                dim: *dim,
            },
        }
    }

    #[test]
    fn contract_input_loc_gradient_matches_einsum() {
        let jet = Array3::from_shape_vec(
            (2, 3, 2),
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0],
        )
        .expect("12 values fill a 2x3x2 array");
        let grad_phi = array![[1.0_f64, 0.0, 1.0], [0.0, 1.0, 0.0]];
        let out = contract_input_loc_gradient(grad_phi.view(), &jet).unwrap();
        // Row 0 selects centers 0 and 2: [1 + 5, 2 + 6]; row 1 selects center 1: [9, 10].
        // Comparing the whole array also pins the output length (n_obs * d = 4).
        assert_eq!(out, array![6.0, 8.0, 9.0, 10.0]);
    }

    #[test]
    fn contract_input_loc_gradient_with_no_centers_is_zero() {
        let jet = Array3::<f64>::zeros((2, 0, 3));
        let grad_phi = Array2::<f64>::zeros((2, 0));
        let out = contract_input_loc_gradient(grad_phi.view(), &jet)
            .expect("(2, 0) grad_phi matches a (2, 0, 3) jet");
        assert_eq!(out, Array1::<f64>::zeros(6));
    }

    #[test]
    fn contract_input_loc_gradient_with_no_observations_is_empty() {
        let jet = Array3::<f64>::zeros((0, 3, 2));
        let grad_phi = Array2::<f64>::zeros((0, 3));
        let out = contract_input_loc_gradient(grad_phi.view(), &jet)
            .expect("(0, 3) grad_phi matches a (0, 3, 2) jet");
        assert_eq!(out.len(), 0);
    }

    #[test]
    fn contract_input_loc_gradient_rejects_mismatched_grad_phi() {
        let jet = Array3::<f64>::zeros((2, 3, 2));
        // Two columns instead of the three centers implied by `jet`.
        let grad_phi = Array2::<f64>::zeros((2, 2));
        let err = contract_input_loc_gradient(grad_phi.view(), &jet)
            .expect_err("grad_phi with 2 columns must not match a jet with 3 centers");
        assert!(matches!(err, BasisError::DimensionMismatch(_)));
    }

    // NOTE(review): the two tests below assume the second component `q` of
    // `eval_design_triplet` grows large as r -> 0 for these kernels — confirm against
    // `RadialScalarKind::eval_design_triplet`.
    #[test]
    fn matern_half_finite_difference_diverges_near_collision() {
        let kernel = RadialInputKernel::Matern {
            length_scale: 1.0,
            nu: MaternNu::Half,
        };
        let kind = to_scalar_kind(&kernel);
        let eps = 1e-8_f64;
        let (_, q, _) = kind
            .eval_design_triplet(eps)
            .expect("ν=1/2 at r=ε is finite");
        assert!(
            q.abs() > 1e6,
            "expected divergent q for Matérn ν=1/2 near r=0, got {q}"
        );
    }

    #[test]
    fn thin_plate_collision_2d_finite_difference_diverges() {
        let kernel = RadialInputKernel::ThinPlate {
            length_scale: 1.0,
            dim: 2,
        };
        let kind = to_scalar_kind(&kernel);
        let eps = 1e-10_f64;
        let (_, q, _) = kind
            .eval_design_triplet(eps)
            .expect("TPS dim=2 at r=ε is finite");
        assert!(
            q.abs() > 10.0,
            "expected large |q| for TPS dim=2 near r=0, got {q}"
        );
    }
}