use super::*;
/// Builds a spherical spline basis (design matrix plus penalties) for latitude/longitude data.
///
/// Dispatch, in order:
/// * `SphereMethod::Harmonic` is delegated to `build_spherical_harmonic_basis`.
/// * `SphereWahbaKernel::Pseudo` is rerouted to the harmonic builder using a cloned spec with
///   `penalty_order = 2` and a `max_degree` taken from the spec or, when absent, from
///   `harmonic_degree_for_wahba_basis_width`.
/// * Otherwise a Wahba kernel basis is built on the selected centers.
///
/// # Errors
/// Returns an error when `penalty_order` is outside `1..=4`, when fewer than two centers are
/// selected, when a frozen identifiability transform has the wrong number of rows, or when any
/// helper called here fails.
pub fn build_spherical_spline_basis(
data: ArrayView2<'_, f64>,
spec: &SphericalSplineBasisSpec,
) -> Result<BasisBuildResult, BasisError> {
if matches!(spec.method, SphereMethod::Harmonic) {
return build_spherical_harmonic_basis(data, spec);
}
// The pseudo kernel is served by the harmonic builder with a fixed penalty order of 2.
if matches!(spec.wahba_kernel, SphereWahbaKernel::Pseudo) {
let mut harmonic_spec = spec.clone();
harmonic_spec.method = SphereMethod::Harmonic;
harmonic_spec.penalty_order = 2;
harmonic_spec.max_degree = Some(
spec.max_degree
.unwrap_or_else(|| harmonic_degree_for_wahba_basis_width(spec, data.nrows())),
);
return build_spherical_harmonic_basis(data, &harmonic_spec);
}
validate_lat_lon_matrix(data, "spherical spline", spec.radians)?;
if !(1..=4).contains(&spec.penalty_order) {
crate::bail_invalid_basis!(
"spherical spline penalty_order must be one of 1, 2, 3, 4; got {}",
spec.penalty_order
);
}
// Farthest-point selection uses the sphere-specific routine; every other strategy goes
// through the generic selector.
let centers = match realized_center_strategy(&spec.center_strategy) {
CenterStrategy::FarthestPoint { num_centers } => {
select_spherical_farthest_point_centers(data, *num_centers, spec.radians)?
}
_ => select_centers_by_strategy(data, &spec.center_strategy)?,
};
validate_lat_lon_matrix(centers.view(), "spherical spline centers", spec.radians)?;
if centers.nrows() < 2 {
return Err(BasisError::InsufficientColumnsForConstraint {
found: centers.nrows(),
});
}
// Kernel evaluated center-against-center; feeds both the decomposition and the penalty.
let center_kernel = spherical_wahba_kernel_matrix_with_kind(
centers.view(),
centers.view(),
spec.penalty_order,
spec.radians,
spec.wahba_kernel,
)?;
let decomposition =
wahba_low_degree_decomposition(centers.view(), spec.radians, center_kernel.view())?;
// Kernel evaluated data-against-centers; this is the raw design before decomposition.
let raw_kernel_design = spherical_wahba_kernel_matrix_with_kind(
data,
centers.view(),
spec.penalty_order,
spec.radians,
spec.wahba_kernel,
)?;
let raw_design =
build_wahba_decomposed_design(raw_kernel_design.view(), data, spec.radians, &decomposition);
let mut raw_penalty = build_wahba_decomposed_penalty(center_kernel.view(), &decomposition);
let kernel_rank = decomposition.kernel_basis.ncols();
// Mean absolute diagonal of the kernel block of the penalty.
let diag_scale = if kernel_rank > 0 {
(0..kernel_rank)
.map(|i| raw_penalty[[i, i]].abs())
.sum::<f64>()
/ kernel_rank as f64
} else {
0.0
};
// Inflate the kernel-block diagonal by ten times that mean; the trailing low-degree columns
// are left untouched.
if diag_scale.is_finite() && diag_scale > 0.0 {
for i in 0..kernel_rank {
raw_penalty[[i, i]] += 10.0 * diag_scale;
}
}
let raw_width = raw_design.ncols();
// Identifiability transform: a frozen one must match the raw width; otherwise identity.
let z = match &spec.identifiability {
SphericalSplineIdentifiability::FrozenTransform { transform } => {
if transform.nrows() != raw_width {
crate::bail_dim_basis!(
"frozen spherical identifiability transform mismatch: {} raw basis columns but transform has {} rows",
raw_width,
transform.nrows()
);
}
transform.clone()
}
SphericalSplineIdentifiability::CenterSumToZero => Array2::<f64>::eye(raw_width),
};
let gauge = gam_problem::Gauge::from_block_transforms(&[z.clone()]);
let penalty = gauge.restrict_penalty(&raw_penalty);
let design = DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(
gauge.restrict_design(&raw_design),
));
// Symmetrize the restricted penalty before normalizing it.
let (penalty_norm, c_primary) = normalize_penalty(&((&penalty + &penalty.t()) * 0.5));
let mut candidates = vec![PenaltyCandidate {
matrix: penalty_norm,
nullspace_dim_hint: 0,
source: PenaltySource::Primary,
normalization_scale: c_primary,
kronecker_factors: None,
op: None,
}];
if spec.double_penalty {
// Shrinkage penalty derived from the raw penalty; falls back to an all-zero matrix when
// the helper returns `None`.
let null_shrinkage = build_nullspace_shrinkage_penalty(&raw_penalty)?
.map(|block| block.sym_penalty)
.unwrap_or_else(|| Array2::<f64>::zeros((raw_width, raw_width)));
let ridge = gauge.restrict_penalty(&null_shrinkage);
let (ridge_norm, c_ridge) = normalize_penalty(&ridge);
candidates.push(PenaltyCandidate {
matrix: ridge_norm,
nullspace_dim_hint: 0,
source: PenaltySource::DoublePenaltyNullspace,
normalization_scale: c_ridge,
kronecker_factors: None,
op: None,
});
}
let (penalties, nullspace_dims, penaltyinfo, null_eigenvectors, ops) =
filter_active_penalty_candidates_with_ops(candidates)?;
Ok(BasisBuildResult {
design,
penalties,
nullspace_dims,
penaltyinfo,
metadata: BasisMetadata::Sphere {
centers,
penalty_order: spec.penalty_order,
method: SphereMethod::Wahba,
max_degree: None,
wahba_kernel: spec.wahba_kernel,
constraint_transform: Some(z),
},
kronecker_factored: None,
ops,
null_eigenvectors,
joint_null_rotation: None,
})
}
/// Highest spherical-harmonic degree kept as an explicit low-degree block by the Wahba
/// decomposition in this file; a degree of `L` contributes `L * (L + 2)` columns, so the
/// current value yields three.
pub(crate) const SPHERE_UNPENALIZED_LOW_DEGREE: usize = 1;
/// Chooses a spherical-harmonic degree whose column count `l * (l + 2)` is at least the
/// number of centers requested by `spec.center_strategy`.
///
/// The search covers degrees `1..=32`; when none is wide enough the default degree for
/// `n_rows` is used instead. The result is never smaller than 8.
pub(crate) fn harmonic_degree_for_wahba_basis_width(
    spec: &SphericalSplineBasisSpec,
    n_rows: usize,
) -> usize {
    let requested_width = match &spec.center_strategy {
        // `Auto` wraps another strategy: read the wrapped request, and use the default
        // center count when the wrapped value is itself `Auto`.
        CenterStrategy::Auto(wrapped) => match wrapped.as_ref() {
            CenterStrategy::Auto(_) => default_num_centers(n_rows, 2),
            CenterStrategy::UserProvided(points) => points.nrows(),
            CenterStrategy::UniformGrid { points_per_dim } => points_per_dim.saturating_pow(2),
            CenterStrategy::FarthestPoint { num_centers }
            | CenterStrategy::EqualMass { num_centers }
            | CenterStrategy::EqualMassCovarRepresentative { num_centers }
            | CenterStrategy::KMeans { num_centers, .. } => *num_centers,
        },
        CenterStrategy::UserProvided(points) => points.nrows(),
        CenterStrategy::UniformGrid { points_per_dim } => points_per_dim.saturating_pow(2),
        CenterStrategy::FarthestPoint { num_centers }
        | CenterStrategy::EqualMass { num_centers }
        | CenterStrategy::EqualMassCovarRepresentative { num_centers }
        | CenterStrategy::KMeans { num_centers, .. } => *num_centers,
    };
    let target = requested_width.max(1);
    let degree = match (1..=32).find(|&l| l * (l + 2) >= target) {
        Some(l) => l,
        None => default_spherical_harmonic_degree(n_rows),
    };
    degree.max(8)
}
/// Evaluates the real spherical harmonics of degrees `1..=max_degree` at every row of `data`.
///
/// Column 0 of `data` is read as latitude and column 1 as longitude; both are multiplied by
/// `PI / 180` unless `radians` is set. Latitude is clamped to `[-PI/2, PI/2]` after scaling.
/// The result has `max_degree * (max_degree + 2)` columns.
fn real_spherical_harmonic_design_up_to_degree(
    data: ArrayView2<'_, f64>,
    max_degree: usize,
    radians: bool,
) -> Array2<f64> {
    let width = max_degree * (max_degree + 2);
    let angle_scale = if !radians {
        std::f64::consts::PI / 180.0
    } else {
        1.0
    };
    let norm_table = precompute_harmonic_norms(max_degree);
    let stride = max_degree + 1;
    let mut design = Array2::<f64>::zeros((data.nrows(), width));
    // One scratch buffer shared by all rows; the row filler resets it on every call.
    let mut legendre_scratch = vec![0.0_f64; stride * stride];
    for point in 0..data.nrows() {
        let latitude = (data[(point, 0)] * angle_scale)
            .clamp(-std::f64::consts::FRAC_PI_2, std::f64::consts::FRAC_PI_2);
        let longitude = data[(point, 1)] * angle_scale;
        fill_real_spherical_harmonics_row(
            latitude,
            longitude,
            max_degree,
            legendre_scratch.as_mut_slice(),
            norm_table.as_slice(),
            design.row_mut(point),
        );
    }
    design
}
/// Orthonormalizes the columns of `matrix` by Gram-Schmidt with two projection passes per
/// column, dropping columns whose residual norm does not exceed
/// `rel_tol * max(largest column norm, 1)`.
///
/// Returns a matrix with the same number of rows and one column per retained direction.
fn orthonormal_column_basis(matrix: ArrayView2<'_, f64>, rel_tol: f64) -> Array2<f64> {
    fn inner(a: &[f64], b: &[f64]) -> f64 {
        a.iter().zip(b.iter()).map(|(x, y)| x * y).sum::<f64>()
    }

    let rows = matrix.nrows();
    let largest_norm = matrix
        .columns()
        .into_iter()
        .map(|column| column.iter().map(|v| v * v).sum::<f64>().sqrt())
        .fold(0.0_f64, f64::max);
    let threshold = rel_tol * largest_norm.max(1.0);

    let mut accepted: Vec<Vec<f64>> = Vec::new();
    for column in matrix.columns() {
        let mut residual = column.to_vec();
        // Two sweeps against the directions accepted so far.
        for _pass in 0..2 {
            for direction in &accepted {
                let overlap = inner(&residual, direction);
                residual
                    .iter_mut()
                    .zip(direction.iter())
                    .for_each(|(r, d)| *r -= overlap * d);
            }
        }
        let length = inner(&residual, &residual).sqrt();
        if length > threshold {
            residual.iter_mut().for_each(|r| *r /= length);
            accepted.push(residual);
        }
    }

    Array2::from_shape_fn((rows, accepted.len()), |(i, j)| accepted[j][i])
}
/// Builds orthonormal directions orthogonal to the columns of `q`.
///
/// Each standard unit vector is projected, in two passes, against the columns of `q` and
/// then against the directions already kept; it is normalized and kept when its residual
/// norm exceeds `max(rel_tol, 0)`.
fn orthonormal_complement(q: ArrayView2<'_, f64>, rel_tol: f64) -> Array2<f64> {
    let dim = q.nrows();
    let threshold = rel_tol.max(0.0);
    let mut complement: Vec<Vec<f64>> = Vec::new();

    for pivot in 0..dim {
        let mut residual: Vec<f64> = (0..dim)
            .map(|k| if k == pivot { 1.0 } else { 0.0 })
            .collect();
        for _pass in 0..2 {
            // First remove the span of `q` ...
            for spanned in q.columns() {
                let overlap = residual
                    .iter()
                    .zip(spanned.iter())
                    .map(|(a, b)| a * b)
                    .sum::<f64>();
                residual
                    .iter_mut()
                    .zip(spanned.iter())
                    .for_each(|(r, s)| *r -= overlap * s);
            }
            // ... then the complement directions found so far.
            for found in &complement {
                let overlap = residual
                    .iter()
                    .zip(found.iter())
                    .map(|(a, b)| a * b)
                    .sum::<f64>();
                residual
                    .iter_mut()
                    .zip(found.iter())
                    .for_each(|(r, f)| *r -= overlap * f);
            }
        }
        let length = residual.iter().map(|x| x * x).sum::<f64>().sqrt();
        if length > threshold {
            residual.iter_mut().for_each(|r| *r /= length);
            complement.push(residual);
        }
    }

    Array2::from_shape_fn((dim, complement.len()), |(i, j)| complement[j][i])
}
/// Result of splitting a Wahba kernel basis on a set of centers into a kernel block and an
/// optional explicit low-degree spherical-harmonic block.
pub(crate) struct WahbaLowDegreeDecomposition {
/// Matrix that right-multiplies kernel designs; it has one row per center and is the
/// identity when no low-degree block is split off.
pub(crate) kernel_basis: Array2<f64>,
/// Low-degree harmonics evaluated at the centers, or `None` when no low-degree block is used.
pub(crate) low_degree_centers: Option<Array2<f64>>,
/// Coefficients used to subtract the low-degree part from the projected kernel columns, or
/// `None` when no low-degree block is used.
pub(crate) kernel_low_projection: Option<Array2<f64>>,
/// Number of low-degree columns appended after the kernel columns (0 when absent).
pub(crate) low_degree_cols: usize,
}
/// Splits the Wahba kernel basis on `centers` into a kernel block and an explicit
/// low-degree spherical-harmonic block.
///
/// * `centers` - latitude/longitude centers, one per row.
/// * `radians` - whether `centers` is already in radians.
/// * `center_kernel` - kernel evaluated center-against-center (square, one row per center).
///
/// When there are no more centers than low-degree columns, or when the complement of the
/// low-degree coefficient space turns out to be the full space, the returned decomposition
/// carries no low-degree block (`low_degree_cols == 0`).
///
/// # Errors
/// Propagates failures of the ridged SPD solves.
pub(crate) fn wahba_low_degree_decomposition(
    centers: ArrayView2<'_, f64>,
    radians: bool,
    center_kernel: ArrayView2<'_, f64>,
) -> Result<WahbaLowDegreeDecomposition, BasisError> {
    let low_cols = SPHERE_UNPENALIZED_LOW_DEGREE * (SPHERE_UNPENALIZED_LOW_DEGREE + 2);
    // Too few centers to carve out a low-degree block: keep the raw kernel basis.
    if centers.nrows() <= low_cols {
        return Ok(WahbaLowDegreeDecomposition {
            kernel_basis: Array2::<f64>::eye(centers.nrows()),
            low_degree_centers: None,
            kernel_low_projection: None,
            low_degree_cols: 0,
        });
    }
    let harmonics = real_spherical_harmonic_design_up_to_degree(
        centers,
        SPHERE_UNPENALIZED_LOW_DEGREE,
        radians,
    );
    // Kernel coefficients that reproduce the low-degree harmonics at the centers; the kernel
    // block is taken orthogonal to their span.
    let low_degree_coefficients = solve_spd_columns_ridged(center_kernel, harmonics.view())?;
    let low_coeff_basis = orthonormal_column_basis(low_degree_coefficients.view(), 1e-10);
    let kernel_basis = orthonormal_complement(low_coeff_basis.view(), 1e-10);
    let low_degree_centers = harmonics;
    if low_degree_centers.ncols() == 0 || kernel_basis.ncols() == centers.nrows() {
        return Ok(WahbaLowDegreeDecomposition {
            kernel_basis,
            low_degree_centers: None,
            kernel_low_projection: None,
            low_degree_cols: 0,
        });
    }
    // Ridged least-squares coefficients of the projected kernel columns on the low-degree
    // harmonics at the centers.
    let center_kernel_reduced = center_kernel.dot(&kernel_basis);
    let low_normal = low_degree_centers.t().dot(&low_degree_centers);
    let low_cross = low_degree_centers.t().dot(&center_kernel_reduced);
    let kernel_low_projection = solve_spd_columns_ridged(low_normal.view(), low_cross.view())?;
    let low_degree_cols = low_degree_centers.ncols();
    Ok(WahbaLowDegreeDecomposition {
        kernel_basis,
        low_degree_centers: Some(low_degree_centers),
        kernel_low_projection: Some(kernel_low_projection),
        low_degree_cols,
    })
}
/// Concatenates `left` and `right` side by side into a new dense matrix.
///
/// # Panics
/// Panics when the two inputs have different row counts.
fn hstack_dense(left: ArrayView2<'_, f64>, right: ArrayView2<'_, f64>) -> Array2<f64> {
    let rows = left.nrows();
    assert_eq!(right.nrows(), rows);
    let split = left.ncols();
    let mut stacked = Array2::<f64>::zeros((rows, split + right.ncols()));
    stacked.slice_mut(s![.., 0..split]).assign(&left);
    stacked.slice_mut(s![.., split..]).assign(&right);
    stacked
}
/// Turns a raw data-against-centers kernel design into the decomposed design.
///
/// The raw design is projected onto `decomposition.kernel_basis`. When the decomposition
/// carries a low-degree block, the low-degree harmonics evaluated at `data`, multiplied by
/// the stored projection, are subtracted from the kernel columns and the harmonics are
/// appended as trailing columns.
///
/// # Panics
/// Panics if the low-degree design at `data` and at the centers differ in width.
fn build_wahba_decomposed_design(
    raw_kernel_design: ArrayView2<'_, f64>,
    data: ArrayView2<'_, f64>,
    radians: bool,
    decomposition: &WahbaLowDegreeDecomposition,
) -> Array2<f64> {
    let projected = raw_kernel_design.dot(&decomposition.kernel_basis);
    if let (Some(low_degree_centers), Some(kernel_low_projection)) = (
        &decomposition.low_degree_centers,
        &decomposition.kernel_low_projection,
    ) {
        let low_design = real_spherical_harmonic_design_up_to_degree(
            data,
            SPHERE_UNPENALIZED_LOW_DEGREE,
            radians,
        );
        assert_eq!(
            low_design.ncols(),
            low_degree_centers.ncols(),
            "low-degree spherical harmonic design width must match its centers"
        );
        let corrected = projected - low_design.dot(kernel_low_projection);
        hstack_dense(corrected.view(), low_design.view())
    } else {
        projected
    }
}
/// Applies the Wahba decomposition to a kernel jet with two slices along its last axis.
///
/// Each slice of `raw_kernel_jet` is projected onto `decomposition.kernel_basis`. When the
/// decomposition has a low-degree block and `low_jet` is supplied, the matching slice of
/// `low_jet` times the stored projection is subtracted from the kernel columns and the
/// low-degree slice itself fills the trailing columns.
pub(crate) fn build_wahba_decomposed_jet(
    raw_kernel_jet: &Array3<f64>,
    low_jet: Option<&Array3<f64>>,
    decomposition: &WahbaLowDegreeDecomposition,
) -> Array3<f64> {
    let rows = raw_kernel_jet.shape()[0];
    let kernel_cols = decomposition.kernel_basis.ncols();
    // The low-degree block takes part only when all three pieces are available.
    let low_block = match (
        &decomposition.kernel_low_projection,
        low_jet,
        &decomposition.low_degree_centers,
    ) {
        (Some(projection), Some(jet), Some(_)) => Some((projection, jet)),
        _ => None,
    };
    let total_cols = if low_block.is_some() {
        kernel_cols + decomposition.low_degree_cols
    } else {
        kernel_cols
    };
    let mut decomposed = Array3::<f64>::zeros((rows, total_cols, 2));
    for axis in 0..2 {
        let raw_slice = raw_kernel_jet.index_axis(ndarray::Axis(2), axis);
        match low_block {
            Some((projection, low)) => {
                let low_slice = low.index_axis(ndarray::Axis(2), axis);
                let corrected =
                    raw_slice.dot(&decomposition.kernel_basis) - low_slice.dot(projection);
                decomposed
                    .slice_mut(s![.., 0..kernel_cols, axis])
                    .assign(&corrected);
                decomposed
                    .slice_mut(s![.., kernel_cols.., axis])
                    .assign(&low_slice);
            }
            None => {
                let projected = raw_slice.dot(&decomposition.kernel_basis);
                decomposed.slice_mut(s![.., .., axis]).assign(&projected);
            }
        }
    }
    decomposed
}
/// Builds the penalty matrix matching the decomposed Wahba design.
///
/// The leading block is `kernel_basis^T * center_kernel * kernel_basis`. The trailing
/// `decomposition.low_degree_cols` rows and columns stay zero, so the low-degree columns
/// carry no penalty. The result is symmetrized before it is returned.
fn build_wahba_decomposed_penalty(
    center_kernel: ArrayView2<'_, f64>,
    decomposition: &WahbaLowDegreeDecomposition,
) -> Array2<f64> {
    let kernel_penalty = decomposition
        .kernel_basis
        .t()
        .dot(&center_kernel.dot(&decomposition.kernel_basis));
    let p = kernel_penalty.nrows() + decomposition.low_degree_cols;
    let mut out = Array2::<f64>::zeros((p, p));
    out.slice_mut(s![0..kernel_penalty.nrows(), 0..kernel_penalty.ncols()])
        .assign(&kernel_penalty);
    // Average with the transpose to remove round-off asymmetry.
    (&out + &out.t()) * 0.5
}
/// Solves `(A + ridge * I) X = B` through a Cholesky (LLT) factorization.
///
/// The ridge is `1e-8` times the mean absolute diagonal of `a`, or `1e-12` when that trace
/// is zero or not finite.
///
/// # Errors
/// Fails when `a` is empty or not square, when `b` has a different row count, when the
/// factorization fails, or when the solution contains non-finite entries.
fn solve_spd_columns_ridged(
    a: ArrayView2<'_, f64>,
    b: ArrayView2<'_, f64>,
) -> Result<Array2<f64>, BasisError> {
    use faer::Side;
    use gam_linalg::faer_ndarray::{FaerArrayView, FaerLlt, FaerSolve};
    let n = a.nrows();
    let shapes_ok = n > 0 && a.ncols() == n && b.nrows() == n;
    if !shapes_ok {
        crate::bail_dim_basis!(
            "ridged SPD solve needs square A and matching RHS rows, got A={}x{} and B={}x{}",
            a.nrows(),
            a.ncols(),
            b.nrows(),
            b.ncols()
        );
    }
    let abs_trace: f64 = a.diag().iter().map(|entry| entry.abs()).sum();
    let ridge = if !abs_trace.is_finite() || abs_trace <= 0.0 {
        1e-12
    } else {
        1e-8 * abs_trace / n as f64
    };
    let mut shifted = a.to_owned();
    for entry in shifted.diag_mut() {
        *entry += ridge;
    }
    let shifted_view = FaerArrayView::new(&shifted);
    let factor = FaerLlt::new(shifted_view.as_ref(), Side::Lower).map_err(|err| {
        BasisError::InvalidInput(format!(
            "sphere Wahba low-degree Gram solve failed after ridge {ridge:.3e}: {err:?}"
        ))
    })?;
    let rhs_owned = b.to_owned();
    let rhs = FaerArrayView::new(&rhs_owned);
    let solved = factor.solve(rhs.as_ref());
    // Copy the faer result back into an ndarray matrix.
    let solution = Array2::from_shape_fn((n, b.ncols()), |(i, j)| solved[(i, j)]);
    if solution.iter().any(|value| !value.is_finite()) {
        return Err(BasisError::InvalidInput(
            "sphere Wahba low-degree Gram solve produced non-finite coefficients".to_string(),
        ));
    }
    Ok(solution)
}
/// Precomputes normalization factors for real spherical harmonics up to `max_degree`.
///
/// The table has `(max_degree + 1)^2` entries laid out as `l * (max_degree + 1) + m`.
/// Entry `(l, 0)` is `sqrt((2l + 1) / (4 * PI))`; entry `(l, m)` for `1 <= m <= l` is
/// `sqrt(2) * sqrt((2l + 1) / (4 * PI) * (l - m)! / (l + m)!)`. Slots with `m > l` stay zero.
pub(crate) fn precompute_harmonic_norms(max_degree: usize) -> Vec<f64> {
    let stride = max_degree + 1;
    let mut table = vec![0.0_f64; stride * stride];
    for (degree, row) in table.chunks_exact_mut(stride).enumerate() {
        let base = ((2 * degree + 1) as f64) / (4.0 * std::f64::consts::PI);
        row[0] = base.sqrt();
        // Running value of (l - m)! / (l + m)!, updated one order at a time.
        let mut factorial_ratio = 1.0_f64;
        for order in 1..=degree {
            factorial_ratio /= ((degree - order + 1) * (degree + order)) as f64;
            row[order] = std::f64::consts::SQRT_2 * (base * factorial_ratio).sqrt();
        }
    }
    table
}
/// Writes one row of real spherical-harmonic basis values for degrees `1..=max_degree`.
///
/// `lat` and `lon` are passed straight to `sin` / `sin_cos`, so they are in radians.
/// `p_buf` is scratch of length `(max_degree + 1)^2` and is overwritten; `norms` is a table of
/// the same length indexed as `l * (max_degree + 1) + m`. For each degree `l`, `row` receives
/// the sine terms for `m = l, ..., 1`, then the `m = 0` term, then the cosine terms for
/// `m = 1, ..., l`.
///
/// # Panics
/// Panics if either buffer has the wrong length, if `row` is too short, or if
/// `max_degree > 32` (the trigonometric buffers hold 33 entries).
pub(crate) fn fill_real_spherical_harmonics_row(
lat: f64,
lon: f64,
max_degree: usize,
p_buf: &mut [f64],
norms: &[f64],
mut row: ndarray::ArrayViewMut1<'_, f64>,
) {
let l_cap = max_degree + 1;
assert_eq!(p_buf.len(), l_cap * l_cap);
assert_eq!(norms.len(), l_cap * l_cap);
// x = sin(lat); somx2 = sqrt(1 - x^2), guarded against a slightly negative argument.
let x = lat.sin();
let somx2 = (1.0 - x * x).max(0.0).sqrt();
for slot in p_buf.iter_mut() {
*slot = 0.0;
}
let idx = |l: usize, m: usize| l * l_cap + m;
// Associated Legendre values: seed the diagonal (l == m) entries first.
p_buf[idx(0, 0)] = 1.0;
for m in 1..=max_degree {
p_buf[idx(m, m)] = -((2 * m - 1) as f64) * somx2 * p_buf[idx(m - 1, m - 1)];
}
// First off-diagonal entries (l == m + 1).
for m in 0..max_degree {
p_buf[idx(m + 1, m)] = ((2 * m + 1) as f64) * x * p_buf[idx(m, m)];
}
// Remaining entries via the three-term recurrence in l for each fixed m.
for m in 0..=max_degree {
for l in (m + 2)..=max_degree {
p_buf[idx(l, m)] = (((2 * l - 1) as f64) * x * p_buf[idx(l - 1, m)]
- ((l + m - 1) as f64) * p_buf[idx(l - 2, m)])
/ ((l - m) as f64);
}
}
// sin(m * lon) and cos(m * lon) for m = 0..=max_degree via the recurrence
// f(m) = 2 cos(lon) f(m - 1) - f(m - 2).
let (sin1, cos1) = lon.sin_cos();
let mut sin_buf = [0.0_f64; 33];
let mut cos_buf = [0.0_f64; 33];
sin_buf[0] = 0.0;
cos_buf[0] = 1.0;
if max_degree >= 1 {
sin_buf[1] = sin1;
cos_buf[1] = cos1;
}
let two_cos1 = 2.0 * cos1;
for m in 2..=max_degree {
sin_buf[m] = two_cos1 * sin_buf[m - 1] - sin_buf[m - 2];
cos_buf[m] = two_cos1 * cos_buf[m - 1] - cos_buf[m - 2];
}
// Emit columns degree by degree; degree 0 is skipped.
let mut col = 0usize;
for l in 1..=max_degree {
for m_pos in (1..=l).rev() {
row[col] = norms[idx(l, m_pos)] * p_buf[idx(l, m_pos)] * sin_buf[m_pos];
col += 1;
}
row[col] = norms[idx(l, 0)] * p_buf[idx(l, 0)];
col += 1;
for m in 1..=l {
row[col] = norms[idx(l, m)] * p_buf[idx(l, m)] * cos_buf[m];
col += 1;
}
}
}
/// Picks a default spherical-harmonic degree from the number of data rows.
///
/// The target column count is a quarter of `n_rows`, limited to `[3, 50]`. The result is the
/// smallest degree `l` in `1..12` with `l * (l + 2)` at least that target (12 if none
/// qualifies), and never less than 2.
pub fn default_spherical_harmonic_degree(n_rows: usize) -> usize {
    let target_cols = ((n_rows as f64) * 0.25).clamp(3.0, 50.0);
    let degree = (1usize..12)
        .find(|&l| (l as f64) * (l as f64 + 2.0) >= target_cols)
        .unwrap_or(12);
    degree.max(2)
}
/// Builds a real spherical-harmonic basis of degrees `1..=l_max` for latitude/longitude data.
///
/// `l_max` comes from `spec.max_degree`, or from `default_spherical_harmonic_degree` when it is
/// unset. The design has `l_max * (l_max + 2)` columns before the identifiability transform.
/// The primary penalty is diagonal with `(l * (l + 1))^penalty_order` on each of the `2l + 1`
/// columns of degree `l`; with `spec.double_penalty` a normalized identity ridge (restricted
/// by the same transform) is added as a second candidate.
///
/// # Errors
/// Returns an error when `l_max` is outside `1..=32`, when `penalty_order` is outside `1..=4`,
/// when a frozen transform has the wrong number of rows, or when a helper called here fails.
pub(crate) fn build_spherical_harmonic_basis(
data: ArrayView2<'_, f64>,
spec: &SphericalSplineBasisSpec,
) -> Result<BasisBuildResult, BasisError> {
validate_lat_lon_matrix(data, "spherical-harmonic", spec.radians)?;
let n = data.nrows();
let l_max = spec
.max_degree
.unwrap_or_else(|| default_spherical_harmonic_degree(n));
if l_max < 1 {
crate::bail_invalid_basis!("spherical-harmonic max_degree must be >= 1");
}
if l_max > 32 {
crate::bail_invalid_basis!("spherical-harmonic max_degree {l_max} too large; cap is 32");
}
if !(1..=4).contains(&spec.penalty_order) {
crate::bail_invalid_basis!(
"spherical-harmonic penalty_order must be one of 1, 2, 3, 4; got {}",
spec.penalty_order
);
}
let p = l_max * (l_max + 2);
let to_rad = if spec.radians {
1.0
} else {
std::f64::consts::PI / 180.0
};
let norms = precompute_harmonic_norms(l_max);
let l_cap = l_max + 1;
let mut design = Array2::<f64>::zeros((n, p));
{
// Fill the design in blocks of 1024 rows through `par_iter_mut`; each block allocates its
// own Legendre scratch buffer. `chunk_size` must stay equal to the chunk length passed to
// `axis_chunks_iter_mut` so that `row_offset` addresses the right rows of `data`.
let mut row_blocks = design
.axis_chunks_iter_mut(ndarray::Axis(0), 1024)
.collect::<Vec<_>>();
let chunks_iter = row_blocks.par_iter_mut().enumerate();
let chunk_size = 1024usize;
chunks_iter.try_for_each(|(chunk_idx, block)| -> Result<(), BasisError> {
let mut p_buf = vec![0.0_f64; l_cap * l_cap];
let row_offset = chunk_idx * chunk_size;
for (local_i, mut out_row) in block.outer_iter_mut().enumerate() {
let i = row_offset + local_i;
let lat_raw = data[(i, 0)] * to_rad;
let lat = lat_raw.clamp(-std::f64::consts::FRAC_PI_2, std::f64::consts::FRAC_PI_2);
let lon = data[(i, 1)] * to_rad;
fill_real_spherical_harmonics_row(
lat,
lon,
l_max,
p_buf.as_mut_slice(),
norms.as_slice(),
out_row.view_mut(),
);
}
Ok(())
})?;
}
// Diagonal penalty: (l * (l + 1))^penalty_order repeated over the 2l + 1 columns of degree l.
let mut penalty = Array2::<f64>::zeros((p, p));
let mut col = 0usize;
for l in 1..=l_max {
let laplace = l as f64 * (l as f64 + 1.0);
let eig = laplace.powi(spec.penalty_order as i32);
for _ in 0..(2 * l + 1) {
penalty[[col, col]] = eig;
col += 1;
}
}
// Identity matrix used as the second (ridge) penalty when double penalization is requested.
let mut ridge = Array2::<f64>::zeros((p, p));
for c in 0..p {
ridge[[c, c]] = 1.0;
}
// Identifiability transform: a frozen one must have `p` rows; otherwise identity.
let transform = match &spec.identifiability {
SphericalSplineIdentifiability::FrozenTransform { transform } => {
if transform.nrows() != p {
crate::bail_dim_basis!(
"frozen spherical-harmonic identifiability transform mismatch: {p} basis columns but transform has {} rows",
transform.nrows()
);
}
transform.clone()
}
SphericalSplineIdentifiability::CenterSumToZero => Array2::<f64>::eye(p),
};
let gauge = gam_problem::Gauge::from_block_transforms(&[transform.clone()]);
let design = gauge.restrict_design(&design);
let penalty = gauge.restrict_penalty(&penalty);
let ridge = gauge.restrict_penalty(&ridge);
// The primary penalty is passed on without normalization (scale recorded as 1.0).
let mut candidates = vec![PenaltyCandidate {
matrix: penalty,
nullspace_dim_hint: 0,
source: PenaltySource::Primary,
normalization_scale: 1.0,
kronecker_factors: None,
op: None,
}];
if spec.double_penalty {
let (ridge_norm, c_ridge) = normalize_penalty(&ridge);
candidates.push(PenaltyCandidate {
matrix: ridge_norm,
nullspace_dim_hint: 0,
source: PenaltySource::DoublePenaltyNullspace,
normalization_scale: c_ridge,
kronecker_factors: None,
op: None,
});
}
let (penalties, nullspace_dims, penaltyinfo, null_eigenvectors, ops) =
filter_active_penalty_candidates_with_ops(candidates)?;
Ok(BasisBuildResult {
design: DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(design)),
penalties,
nullspace_dims,
penaltyinfo,
metadata: BasisMetadata::Sphere {
// The harmonic basis has no centers; an empty 0x2 matrix is stored.
centers: Array2::<f64>::zeros((0, 2)),
penalty_order: spec.penalty_order,
method: SphereMethod::Harmonic,
max_degree: Some(l_max),
wahba_kernel: spec.wahba_kernel,
constraint_transform: Some(transform),
},
kronecker_factored: None,
ops,
null_eigenvectors,
joint_null_rotation: None,
})
}
/// Builds a Matérn basis using a freshly created default [`BasisWorkspace`].
///
/// Convenience wrapper around `build_matern_basiswithworkspace`.
pub fn build_matern_basis(
    data: ArrayView2<'_, f64>,
    spec: &MaternBasisSpec,
) -> Result<BasisBuildResult, BasisError> {
    build_matern_basiswithworkspace(data, spec, &mut BasisWorkspace::default())
}
/// Builds a Matérn basis with `AnisoSeedMode::Literal`, using a freshly created default
/// [`BasisWorkspace`].
pub fn build_matern_basis_literal_aniso(
    data: ArrayView2<'_, f64>,
    spec: &MaternBasisSpec,
) -> Result<BasisBuildResult, BasisError> {
    let seed_mode = AnisoSeedMode::Literal;
    build_matern_basis_seeded(data, spec, &mut BasisWorkspace::default(), seed_mode)
}
/// Builds a Matérn basis with a caller-supplied workspace.
///
/// Forwards to `build_matern_basis_seeded` with `AnisoSeedMode::AutoSeedFromGeometry`.
pub fn build_matern_basiswithworkspace(
data: ArrayView2<'_, f64>,
spec: &MaternBasisSpec,
workspace: &mut BasisWorkspace,
) -> Result<BasisBuildResult, BasisError> {
build_matern_basis_seeded(data, spec, workspace, AnisoSeedMode::AutoSeedFromGeometry)
}
/// Builds a Matérn radial basis (design plus penalty candidates) for `data`.
///
/// Steps, in order:
/// 1. Select centers; unless the identifiability is a frozen transform, rank-reduce them.
/// 2. Expand the centers for periodic axes and resolve the anisotropy log-scales.
/// 3. Build the identifiability transform, optionally extended with an intercept column.
/// 4. Pick one of three design representations: a streaming evaluator (when an auto chunk size
///    is returned), a lazy chunked kernel operator, or a dense matrix.
/// 5. Build penalty candidates: the double-penalty set when `spec.double_penalty` is set,
///    otherwise the operator-penalty set.
///
/// The returned metadata stores the centers as they were before periodic expansion.
///
/// # Errors
/// Propagates any failure of the helpers called here.
///
/// # Panics
/// The lazy kernel closures call `expect` on `matern_kernel_from_distance`.
pub(crate) fn build_matern_basis_seeded(
data: ArrayView2<'_, f64>,
spec: &MaternBasisSpec,
workspace: &mut BasisWorkspace,
aniso_seed_mode: AnisoSeedMode,
) -> Result<BasisBuildResult, BasisError> {
let selected_centers = select_centers_by_strategy(data, &spec.center_strategy)?;
// A frozen transform keeps the selected centers as they are; otherwise the centers go through
// rank reduction using the anisotropy resolved on the selected centers.
let original_centers = if matches!(
spec.identifiability,
MaternIdentifiability::FrozenTransform { .. }
) {
selected_centers
} else {
let reduce_aniso = resolve_matern_forward_aniso(
aniso_seed_mode,
selected_centers.view(),
spec.aniso_log_scales.as_deref(),
);
matern_rank_reduce_centers(
data,
&selected_centers,
spec.length_scale,
spec.nu,
reduce_aniso.as_deref(),
)?
};
let centers = expand_periodic_centers(&original_centers, spec.periodic.as_deref())?;
// Anisotropy is resolved again on the final (possibly expanded) centers.
let aniso = resolve_matern_forward_aniso(
aniso_seed_mode,
centers.view(),
spec.aniso_log_scales.as_deref(),
);
let z_opt = matern_identifiability_transform(centers.view(), &spec.identifiability)?;
let identifiability_transform = z_opt.clone();
// Transform acting on the full coefficient vector, including the intercept when requested.
let full_transform = z_opt.as_ref().map(|z| {
if spec.include_intercept {
append_intercept_to_transform(z)
} else {
z.clone()
}
});
// Decision carried by a frozen spec, if any; passed to the double-penalty candidate builders.
let frozen_nullspace_shrinkage_survived = match &spec.identifiability {
MaternIdentifiability::FrozenTransform {
nullspace_shrinkage_survived,
..
} => *nullspace_shrinkage_survived,
_ => None,
};
let mut realized_nullspace_shrinkage_survived = false;
let design_cols =
z_opt.as_ref().map_or(centers.nrows(), Array2::ncols) + usize::from(spec.include_intercept);
let dense_bytes = dense_design_bytes(data.nrows(), design_cols);
let matern_auto_chunk = auto_streaming_chunk_size_for_dense(data.nrows(), design_cols);
let use_streaming = matern_auto_chunk.is_some();
let use_lazy = !use_streaming
&& should_use_lazy_spatial_design(data.nrows(), design_cols, workspace.policy());
// Path 1: streaming evaluator.
let (design, candidates) = if let Some(chunk) = matern_auto_chunk {
log::info!(
"Matérn basis auto-streaming evaluator: n={} p={} chunk_size={}",
data.nrows(),
design_cols,
chunk,
);
let shared_data = shared_owned_data_matrix(data, &workspace.cache);
let op = StreamingMaternEvaluator::new(
shared_data,
Arc::new(centers.clone()),
spec.length_scale,
spec.nu,
aniso.clone(),
z_opt.as_ref().map(|z| Arc::new(z.clone())),
spec.include_intercept,
Some(chunk),
)
.map_err(BasisError::InvalidInput)?;
let design = DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(Arc::new(op)));
let candidates = if spec.double_penalty {
let penalty_kernel = build_matern_kernel_penalty(
centers.view(),
spec.length_scale,
spec.nu,
spec.include_intercept,
aniso.as_deref(),
)?;
let primary = project_penalty_matrix(&penalty_kernel, full_transform.as_ref());
let (candidates, survived) = matern_double_penalty_candidates_with_decision(
&primary,
frozen_nullspace_shrinkage_survived,
)?;
realized_nullspace_shrinkage_survived = survived;
candidates
} else {
build_matern_operator_penalty_candidates(
centers.view(),
spec.length_scale,
spec.nu,
spec.include_intercept,
z_opt.as_ref(),
aniso.as_deref(),
)?
};
(design, candidates)
// Path 2: lazy chunked kernel operator.
} else if use_lazy {
log::info!(
"Matérn basis switching to lazy chunked design: n={} p={} ({:.1} MiB dense)",
data.nrows(),
design_cols,
dense_bytes as f64 / (1024.0 * 1024.0),
);
let shared_data = shared_owned_data_matrix(data, &workspace.cache);
let d = data.ncols();
let length_scale = spec.length_scale;
let nu = spec.nu;
// The intercept is represented as a column of ones.
let poly_basis = if spec.include_intercept {
Some(Arc::new(Array2::<f64>::ones((data.nrows(), 1))))
} else {
None
};
let kernel_gauge = z_opt
.as_ref()
.map(|z| Arc::new(gam_problem::Gauge::from_block_transforms(&[z.clone()])));
let design = if let Some(eta) = aniso.as_ref() {
// Anisotropic distance: each squared axis difference is weighted by exp(2 * eta).
let metric_weights = eta.iter().map(|&v| (2.0 * v).exp()).collect::<Vec<_>>();
let kernel = move |data_row: &[f64], center_row: &[f64]| -> f64 {
let mut q = 0.0f64;
for axis in 0..data_row.len() {
let delta = data_row[axis] - center_row[axis];
q += metric_weights[axis] * delta * delta;
}
matern_kernel_from_distance(q.sqrt(), length_scale, nu)
.expect("validated Matérn inputs should not fail")
};
let op = ChunkedKernelDesignOperator::new(
shared_data.clone(),
Arc::new(centers.clone()),
kernel,
kernel_gauge.clone(),
poly_basis.clone(),
)
.map_err(BasisError::InvalidInput)?;
DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(Arc::new(op)))
} else {
// Isotropic distance over the first `d` coordinates.
let kernel = move |data_row: &[f64], center_row: &[f64]| -> f64 {
let r = stable_euclidean_norm((0..d).map(|axis| data_row[axis] - center_row[axis]));
matern_kernel_from_distance(r, length_scale, nu)
.expect("validated Matérn inputs should not fail")
};
let op = ChunkedKernelDesignOperator::new(
shared_data,
Arc::new(centers.clone()),
kernel,
kernel_gauge,
poly_basis,
)
.map_err(BasisError::InvalidInput)?;
DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(Arc::new(op)))
};
let candidates = if spec.double_penalty {
let penalty_kernel = build_matern_kernel_penalty(
centers.view(),
spec.length_scale,
spec.nu,
spec.include_intercept,
aniso.as_deref(),
)?;
let primary = project_penalty_matrix(&penalty_kernel, full_transform.as_ref());
let (candidates, survived) = matern_double_penalty_candidates_with_decision(
&primary,
frozen_nullspace_shrinkage_survived,
)?;
realized_nullspace_shrinkage_survived = survived;
candidates
} else {
build_matern_operator_penalty_candidates(
centers.view(),
spec.length_scale,
spec.nu,
spec.include_intercept,
z_opt.as_ref(),
aniso.as_deref(),
)?
};
(design, candidates)
// Path 3: dense design matrix.
} else {
let m = create_matern_spline_basiswithworkspace(
data,
centers.view(),
spec.length_scale,
spec.nu,
spec.include_intercept,
aniso.as_deref(),
workspace,
)?;
let design = if let Some(transform) = full_transform.as_ref() {
DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(fast_ab(
&m.basis, transform,
)))
} else {
DesignMatrix::Dense(gam_linalg::matrix::DenseDesignMatrix::from(m.basis.clone()))
};
let candidates = if spec.double_penalty {
let (candidates, survived) = build_matern_double_penalty_candidates(
&m,
full_transform.as_ref(),
frozen_nullspace_shrinkage_survived,
)?;
realized_nullspace_shrinkage_survived = survived;
candidates
} else {
build_matern_operator_penalty_candidates(
centers.view(),
spec.length_scale,
spec.nu,
spec.include_intercept,
z_opt.as_ref(),
aniso.as_deref(),
)?
};
(design, candidates)
};
let (penalties, nullspace_dims, penaltyinfo, null_eigenvectors, ops) =
filter_active_penalty_candidates_with_ops(candidates)?;
Ok(BasisBuildResult {
design,
penalties,
nullspace_dims,
penaltyinfo,
metadata: BasisMetadata::Matern {
// Centers before periodic expansion.
centers: original_centers,
length_scale: spec.length_scale,
periodic: spec.periodic.clone(),
nu: spec.nu,
include_intercept: spec.include_intercept,
identifiability_transform,
input_scales: None,
aniso_log_scales: aniso,
nullspace_shrinkage_survived: realized_nullspace_shrinkage_survived,
},
kronecker_factored: None,
ops,
null_eigenvectors,
joint_null_rotation: None,
})
}
/// Evaluates the polynomial `sum(coeffs[i] * a^i)` together with its first and second
/// derivatives with respect to `a`.
///
/// Returns `(value, first_derivative, second_derivative)`; an empty coefficient slice
/// yields three zeros.
#[inline(always)]
pub(crate) fn eval_polywith_derivatives(coeffs: &[f64], a: f64) -> (f64, f64, f64) {
    let mut value = 0.0;
    let mut first = 0.0;
    let mut second = 0.0;
    for (power, &coeff) in coeffs.iter().enumerate() {
        value += coeff * a.powi(power as i32);
        // The constant term contributes to neither derivative.
        if power == 0 {
            continue;
        }
        first += (power as f64) * coeff * a.powi((power - 1) as i32);
        // The linear term contributes nothing to the second derivative.
        if power == 1 {
            continue;
        }
        second += (power as f64) * ((power - 1) as f64) * coeff * a.powi((power - 2) as i32);
    }
    (value, first, second)
}
/// Evaluates a half-integer Matérn kernel `exp(-a) * P(a)` at distance `r`, plus two
/// derivative-like companions.
///
/// Here `a = s * r` with `s = sqrt(2 * nu) / length_scale` and `P` the polynomial for the
/// chosen `nu`. The second and third returned values are `a * d/da` applied once and twice
/// to the kernel value. All three are zero when `a > 700`.
///
/// # Errors
/// Propagates the error from `validate_matern_inputs`.
#[inline(always)]
pub(crate) fn maternvalue_psi_triplet(
    r: f64,
    length_scale: f64,
    nu: MaternNu,
) -> Result<(f64, f64, f64), BasisError> {
    validate_matern_inputs(r, length_scale)?;
    let inv_length = 1.0 / length_scale;
    let (rate, poly): (f64, &[f64]) = match nu {
        MaternNu::Half => (inv_length, &[1.0]),
        MaternNu::ThreeHalves => (3.0_f64.sqrt() * inv_length, &[1.0, 1.0]),
        MaternNu::FiveHalves => (5.0_f64.sqrt() * inv_length, &[1.0, 1.0, 1.0 / 3.0]),
        MaternNu::SevenHalves => (
            7.0_f64.sqrt() * inv_length,
            &[1.0, 1.0, 2.0 / 5.0, 1.0 / 15.0],
        ),
        MaternNu::NineHalves => (
            9.0_f64.sqrt() * inv_length,
            &[1.0, 1.0, 3.0 / 7.0, 2.0 / 21.0, 1.0 / 105.0],
        ),
    };
    let scaled = rate * r;
    // Past this point exp(-a) is negligible; return exact zeros.
    if scaled > 700.0 {
        return Ok((0.0, 0.0, 0.0));
    }
    let decay = (-scaled).exp();
    let (p0, p1, p2) = eval_polywith_derivatives(poly, scaled);
    let kernel = decay * p0;
    let kernel_psi = decay * scaled * (p1 - p0);
    let kernel_psi_psi =
        decay * (scaled * (p1 - p0) + scaled * scaled * (p2 - 2.0 * p1 + p0));
    Ok((kernel, kernel_psi, kernel_psi_psi))
}
/// Evaluates `y = scalar * s^2 * exp(-a) * P(a)` and two companion values, where `P` is the
/// polynomial with coefficients `coeffs`.
///
/// The companions are the results of applying the operator `2 + a * d/da` once and twice to
/// `exp(-a) * P(a)`, scaled by `scalar * s^2`. All three values are zero when `a > 700`.
#[inline(always)]
pub(crate) fn exp_poly_scaled_s2_psi_triplet(
    s: f64,
    a: f64,
    coeffs: &[f64],
    scalar: f64,
) -> (f64, f64, f64) {
    if a > 700.0 {
        return (0.0, 0.0, 0.0);
    }
    let (p0, p1, p2) = eval_polywith_derivatives(coeffs, a);
    let slope = p1 - p0;
    // Common factor scalar * s^2 * exp(-a), multiplied in the same left-to-right order.
    let prefactor = scalar * s * s * (-a).exp();
    (
        prefactor * p0,
        prefactor * (2.0 * p0 + a * slope),
        prefactor * (4.0 * p0 + 5.0 * a * slope + a * a * (p2 - 2.0 * p1 + p0)),
    )
}
/// Evaluates a half-integer Matérn kernel at distance `r` together with the radial operator
/// quantities built from it, each with two psi-derivative companions.
///
/// The nine returned values are, in order: the kernel value triplet from
/// `maternvalue_psi_triplet`, the triplet for `phi_r / r`, and the triplet for
/// `lap = phi_rr + (dimension - 1) * phi_r / r`.
///
/// # Errors
/// Propagates the error from `maternvalue_psi_triplet`, and fails when any of the returned
/// values is not finite.
#[inline(always)]
pub(crate) fn matern_operator_psi_triplet(
r: f64,
length_scale: f64,
nu: MaternNu,
dimension: usize,
) -> Result<
(
f64, // phi
f64, // phi_psi
f64, // phi_psi_psi
f64, // phi_r_over_r
f64, // derivative of phi_r_over_r with respect to psi
f64, // second derivative of phi_r_over_r with respect to psi
f64, // lap
f64, // lap_psi
f64, // lap_psi_psi
),
BasisError,
> {
let (phi, phi_psi, phi_psi_psi) = maternvalue_psi_triplet(r, length_scale, nu)?;
let kappa = 1.0 / length_scale;
let d = dimension as f64;
// Per-nu tables: `s` is the distance scaling, `q` holds the polynomial used (with scalar -1)
// for the phi_r / r term, and `rr` holds the polynomial used (with scalar +1) for phi_rr.
let (s, q, rr): (f64, &[f64], &[f64]) = match nu {
MaternNu::Half => (kappa, &[1.0], &[1.0]),
MaternNu::ThreeHalves => (3.0_f64.sqrt() * kappa, &[1.0], &[-1.0, 1.0]),
MaternNu::FiveHalves => (
5.0_f64.sqrt() * kappa,
&[1.0 / 3.0, 1.0 / 3.0],
&[-1.0 / 3.0, -1.0 / 3.0, 1.0 / 3.0],
),
MaternNu::SevenHalves => (
7.0_f64.sqrt() * kappa,
&[1.0 / 5.0, 1.0 / 5.0, 1.0 / 15.0],
&[-1.0 / 5.0, -1.0 / 5.0, 0.0, 1.0 / 15.0],
),
MaternNu::NineHalves => (
9.0_f64.sqrt() * kappa,
&[1.0 / 7.0, 1.0 / 7.0, 2.0 / 35.0, 1.0 / 105.0],
&[
-1.0 / 7.0,
-1.0 / 7.0,
-1.0 / 35.0,
2.0 / 105.0,
1.0 / 105.0,
],
),
};
let a = s * r;
let (phi_rr, phi_rr_psi, phi_rr_psi_psi) = exp_poly_scaled_s2_psi_triplet(s, a, rr, 1.0);
// For nu = 1/2 the ratio is computed directly as -(s / r) * exp(-a), with `r` floored at
// 1e-12 to avoid dividing by zero; the `q` table is not used on this branch.
let (ratio, ratio_psi, ratio_psi_psi) = if matches!(nu, MaternNu::Half) {
let r_eff = r.max(1e-12);
let e_eff = (-a).exp();
let g = -(s / r_eff) * e_eff;
let g_psi = -(s / r_eff) * e_eff * (1.0 - a);
let g_psi_psi = -(s / r_eff) * e_eff * (1.0 - 3.0 * a + a * a);
(g, g_psi, g_psi_psi)
} else {
exp_poly_scaled_s2_psi_triplet(s, a, q, -1.0)
};
let lap = phi_rr + (d - 1.0) * ratio;
let lap_psi = phi_rr_psi + (d - 1.0) * ratio_psi;
let lap_psi_psi = phi_rr_psi_psi + (d - 1.0) * ratio_psi_psi;
// Reject the result if any returned quantity is NaN or infinite.
if !phi.is_finite()
|| !phi_psi.is_finite()
|| !phi_psi_psi.is_finite()
|| !ratio.is_finite()
|| !ratio_psi.is_finite()
|| !ratio_psi_psi.is_finite()
|| !lap.is_finite()
|| !lap_psi.is_finite()
|| !lap_psi_psi.is_finite()
{
crate::bail_invalid_basis!(
"non-finite Matérn psi-derivative operator values at r={r}, length_scale={length_scale}, nu={nu:?}"
);
}
Ok((
phi,
phi_psi,
phi_psi_psi,
ratio,
ratio_psi,
ratio_psi_psi,
lap,
lap_psi,
lap_psi_psi,
))
}
/// Forms the Gram matrix `D^T D` of an operator matrix and its first and second
/// psi-derivatives by the product rule, symmetrizing each result.
///
/// `d_psi` and `d_psi_psi` are the first and second psi-derivatives of `d`.
pub(crate) fn gram_and_psi_derivatives_from_operator(
    d: &Array2<f64>,
    d_psi: &Array2<f64>,
    d_psi_psi: &Array2<f64>,
) -> (Array2<f64>, Array2<f64>, Array2<f64>) {
    let gram = symmetrize(&fast_ata(d));
    let first_order = d_psi.t().dot(d) + d.t().dot(d_psi);
    let second_order =
        d_psi_psi.t().dot(d) + d.t().dot(d_psi_psi) + 2.0 * d_psi.t().dot(d_psi);
    (gram, symmetrize(&first_order), symmetrize(&second_order))
}
/// Forms the mixed second derivative of the Gram matrix `D^T D` with respect to two
/// parameters, given `D`, its two first derivatives `d_a` and `d_b`, and its mixed
/// derivative `d_ab`. The result is symmetrized.
pub(crate) fn gram_cross_psi_derivative_from_operator(
    d: &Array2<f64>,
    d_a: &Array2<f64>,
    d_b: &Array2<f64>,
    d_ab: &Array2<f64>,
) -> Array2<f64> {
    let mixed = d_ab.t().dot(d) + d.t().dot(d_ab) + d_a.t().dot(d_b) + d_b.t().dot(d_a);
    symmetrize(&mixed)
}
/// Mixed second derivative of the normalized penalty `S / c` with respect to two parameters.
///
/// `s_a`, `s_b` and `s_ab` are the first and mixed derivatives of `s`, and `c` is the
/// normalization scale. The derivatives of `c` are rebuilt from trace inner products of the
/// supplied matrices. Returns a zero matrix when `c` is not finite or is at most `1e-12`.
pub(crate) fn normalize_penalty_cross_psi_derivative(
    s: &Array2<f64>,
    s_a: &Array2<f64>,
    s_b: &Array2<f64>,
    s_ab: &Array2<f64>,
    c: f64,
) -> Array2<f64> {
    let usable_scale = c.is_finite() && c > 1e-12;
    if !usable_scale {
        return Array2::<f64>::zeros(s.raw_dim());
    }
    let c2 = c * c;
    let c3 = c2 * c;
    // First derivatives of the scale along each parameter.
    let c_a = trace_of_product(s, s_a) / c;
    let c_b = trace_of_product(s, s_b) / c;
    // Mixed derivative of the scale.
    let cross_val = trace_of_product(s_a, s_b) + trace_of_product(s, s_ab);
    let c_ab = cross_val / c - c_a * c_b / c;
    let weight_s = 2.0 * c_a * c_b / c3 - c_ab / c2;
    let weight_s_b = c_a / c2;
    let weight_s_a = c_b / c2;
    s_ab.mapv(|v| v / c) - s_b.mapv(|v| weight_s_b * v) - s_a.mapv(|v| weight_s_a * v)
        + s.mapv(|v| weight_s * v)
}
/// Returns `trace(aᵀ · b)`.
///
/// The diagonal entry `j` of `aᵀ · b` is the dot product of column `j` of `a`
/// with column `j` of `b`, so only those column dot products are evaluated
/// instead of materializing the whole matrix product. When the column counts
/// differ, only the leading `min(a.ncols(), b.ncols())` columns contribute,
/// which matches the length of the diagonal of `aᵀ · b`.
///
/// # Panics
/// Panics (inside `dot`) when `a` and `b` have different row counts and at
/// least one shared column exists.
#[inline(always)]
pub(crate) fn trace_of_product(a: &Array2<f64>, b: &Array2<f64>) -> f64 {
    let shared_cols = a.ncols().min(b.ncols());
    (0..shared_cols)
        .map(|j| a.column(j).dot(&b.column(j)))
        .sum()
}
/// Scales a penalty by `c = sqrt(trace_of_product(s, s))` and differentiates
/// the scaled penalty `s / c` twice with respect to psi.
///
/// Returns `(s / c, d(s / c)/dpsi, d²(s / c)/dpsi², c)`.
///
/// When `c` is non-finite or not larger than `1e-12`, the penalty is returned
/// unscaled, both derivatives are zero matrices and the reported scale is
/// `1.0`.
pub(crate) fn normalize_penaltywith_psi_derivatives(
    s: &Array2<f64>,
    s_psi: &Array2<f64>,
    s_psi_psi: &Array2<f64>,
) -> (Array2<f64>, Array2<f64>, Array2<f64>, f64) {
    let c = trace_of_product(s, s).sqrt();
    if !(c.is_finite() && c > 1e-12) {
        let zero_like_s = || Array2::<f64>::zeros(s.raw_dim());
        return (s.clone(), zero_like_s(), zero_like_s(), 1.0);
    }

    // c² = <s, s>  =>  c·c' = <s, s'>  and  c'² + c·c'' = <s', s'> + <s, s''>.
    let a = trace_of_product(s, s_psi);
    let b = trace_of_product(s_psi, s_psi) + trace_of_product(s, s_psi_psi);
    let c_psi = a / c;
    let c_psi_psi = b / c - (a * a) / (c * c * c);

    // Scalar weights of the quotient-rule expansion.
    let first_weight_s = c_psi / (c * c);
    let second_weight_s_psi = 2.0 * c_psi / (c * c);
    let second_weight_s = (2.0 * c_psi * c_psi) / (c * c * c) - c_psi_psi / (c * c);

    let normalized = s.mapv(|v| v / c);
    let normalized_psi = s_psi.mapv(|v| v / c) - s.mapv(|v| first_weight_s * v);
    let normalized_psi_psi = s_psi_psi.mapv(|v| v / c)
        - s_psi.mapv(|v| second_weight_s_psi * v)
        + s.mapv(|v| second_weight_s * v);

    (normalized, normalized_psi, normalized_psi_psi, c)
}
pub fn build_matern_operator_penalty_psi_derivatives(
centers: ArrayView2<'_, f64>,
length_scale: f64,
nu: MaternNu,
include_intercept: bool,
z_opt: Option<&Array2<f64>>,
aniso_log_scales: Option<&[f64]>,
) -> Result<(Vec<Array2<f64>>, Vec<Array2<f64>>), BasisError> {
let p = centers.nrows();
let d = centers.ncols();
let mut d0_raw = Array2::<f64>::zeros((p, p));
let mut d1_raw = Array2::<f64>::zeros((p * d, p));
let mut d2_raw = Array2::<f64>::zeros((p * d * d, p));
let mut d0_raw_psi = Array2::<f64>::zeros((p, p));
let mut d1_raw_psi = Array2::<f64>::zeros((p * d, p));
let mut d2_raw_psi = Array2::<f64>::zeros((p * d * d, p));
let mut d0_raw_psi_psi = Array2::<f64>::zeros((p, p));
let mut d1_raw_psi_psi = Array2::<f64>::zeros((p * d, p));
let mut d2_raw_psi_psi = Array2::<f64>::zeros((p * d * d, p));
let metric_weights = aniso_log_scales
.map(centered_aniso_metric_weights)
.unwrap_or_else(|| vec![1.0; d]);
for k in 0..p {
for j in 0..p {
let (r, _s_vec) = if let Some(eta) = aniso_log_scales {
aniso_distance_and_components(
centers.row(k).as_slice().unwrap(),
centers.row(j).as_slice().unwrap(),
eta,
)
} else {
(
stable_euclidean_norm((0..d).map(|c| centers[[k, c]] - centers[[j, c]])),
(0..d)
.map(|c| {
let h = centers[[k, c]] - centers[[j, c]];
h * h
})
.collect(),
)
};
let (
phi,
phi_psi,
phi_psi_psi,
ratio,
ratio_psi,
ratio_psi_psi,
_lap,
lap_psi,
lap_psi_psi,
) = matern_operator_psi_triplet(r, length_scale, nu, d)?;
let (_, _q_shape, t, _t_r, _t_rr) =
matern_aniso_extended_radial_scalars(r, length_scale, nu)?;
let q = ratio;
let q_psi = ratio_psi;
let q_psi_psi = ratio_psi_psi;
let (t_psi, t_psi_psi) = if r < 1e-14 {
(0.0, 0.0)
} else {
let r2 = r * r;
let d_f64 = d as f64;
(
(lap_psi - d_f64 * ratio_psi) / r2,
(lap_psi_psi - d_f64 * ratio_psi_psi) / r2,
)
};
d0_raw[[k, j]] = phi;
d0_raw_psi[[k, j]] = phi_psi;
d0_raw_psi_psi[[k, j]] = phi_psi_psi;
for axis in 0..d {
let delta = centers[[k, axis]] - centers[[j, axis]];
let axis_scale = metric_weights[axis];
let row = k * d + axis;
d1_raw[[row, j]] = ratio * axis_scale * delta;
d1_raw_psi[[row, j]] = ratio_psi * axis_scale * delta;
d1_raw_psi_psi[[row, j]] = ratio_psi_psi * axis_scale * delta;
}
for b in 0..d {
let h_b = centers[[k, b]] - centers[[j, b]];
let w_b = metric_weights[b];
for c in 0..d {
let h_c = centers[[k, c]] - centers[[j, c]];
let w_c = metric_weights[c];
let row = (k * d + b) * d + c;
d2_raw[[row, j]] = hessian_operator_entry(q, t, h_b, h_c, w_b, w_c, b, c);
d2_raw_psi[[row, j]] =
hessian_operator_entry(q_psi, t_psi, h_b, h_c, w_b, w_c, b, c);
d2_raw_psi_psi[[row, j]] =
hessian_operator_entry(q_psi_psi, t_psi_psi, h_b, h_c, w_b, w_c, b, c);
}
}
}
}
let coefficient_gauge = z_opt.map(|z| gam_problem::Gauge::from_block_transforms(&[z.clone()]));
let project = |mat: Array2<f64>| {
if let Some(gauge) = coefficient_gauge.as_ref() {
gauge.restrict_design(&mat)
} else {
mat
}
};
let d0_kernel = project(d0_raw);
let d0_kernel_psi = project(d0_raw_psi);
let d0_kernel_psi_psi = project(d0_raw_psi_psi);
let d1_kernel = project(d1_raw);
let d1_kernel_psi = project(d1_raw_psi);
let d1_kernel_psi_psi = project(d1_raw_psi_psi);
let d2_kernel = project(d2_raw);
let d2_kernel_psi = project(d2_raw_psi);
let d2_kernel_psi_psi = project(d2_raw_psi_psi);
let kernel_cols = d0_kernel.ncols();
let total_cols = kernel_cols + usize::from(include_intercept);
let mut d0 = Array2::<f64>::zeros((p, total_cols));
let mut d1 = Array2::<f64>::zeros((p * d, total_cols));
let mut d2 = Array2::<f64>::zeros((p * d * d, total_cols));
let mut d0_psi = Array2::<f64>::zeros((p, total_cols));
let mut d1_psi = Array2::<f64>::zeros((p * d, total_cols));
let mut d2_psi = Array2::<f64>::zeros((p * d * d, total_cols));
let mut d0_psi_psi = Array2::<f64>::zeros((p, total_cols));
let mut d1_psi_psi = Array2::<f64>::zeros((p * d, total_cols));
let mut d2_psi_psi = Array2::<f64>::zeros((p * d * d, total_cols));
d0.slice_mut(s![.., 0..kernel_cols]).assign(&d0_kernel);
d1.slice_mut(s![.., 0..kernel_cols]).assign(&d1_kernel);
d2.slice_mut(s![.., 0..kernel_cols]).assign(&d2_kernel);
d0_psi
.slice_mut(s![.., 0..kernel_cols])
.assign(&d0_kernel_psi);
d1_psi
.slice_mut(s![.., 0..kernel_cols])
.assign(&d1_kernel_psi);
d2_psi
.slice_mut(s![.., 0..kernel_cols])
.assign(&d2_kernel_psi);
d0_psi_psi
.slice_mut(s![.., 0..kernel_cols])
.assign(&d0_kernel_psi_psi);
d1_psi_psi
.slice_mut(s![.., 0..kernel_cols])
.assign(&d1_kernel_psi_psi);
d2_psi_psi
.slice_mut(s![.., 0..kernel_cols])
.assign(&d2_kernel_psi_psi);
if include_intercept {
d0.column_mut(kernel_cols).fill(1.0);
}
let (s0, s0_psi, s0_psi_psi) =
gram_and_psi_derivatives_from_operator(&d0, &d0_psi, &d0_psi_psi);
let (s1, s1_psi, s1_psi_psi) =
gram_and_psi_derivatives_from_operator(&d1, &d1_psi, &d1_psi_psi);
let (s2, s2_psi, s2_psi_psi) =
gram_and_psi_derivatives_from_operator(&d2, &d2_psi, &d2_psi_psi);
let (s0_norm, s0_norm_psi, s0_norm_psi_psi, c0) =
normalize_penaltywith_psi_derivatives(&s0, &s0_psi, &s0_psi_psi);
let (s1_norm, s1_norm_psi, s1_norm_psi_psi, c1) =
normalize_penaltywith_psi_derivatives(&s1, &s1_psi, &s1_psi_psi);
let (s2_norm, s2_norm_psi, s2_norm_psi_psi, c2) =
normalize_penaltywith_psi_derivatives(&s2, &s2_psi, &s2_psi_psi);
let matern_spec = DuchonOperatorPenaltySpec::matern_for_smoothness(nu, d);
let mut candidates = Vec::with_capacity(3);
for (spec_gate, source, matrix, normalization_scale) in [
(&matern_spec.mass, PenaltySource::OperatorMass, s0_norm, c0),
(
&matern_spec.tension,
PenaltySource::OperatorTension,
s1_norm,
c1,
),
(
&matern_spec.stiffness,
PenaltySource::OperatorStiffness,
s2_norm,
c2,
),
] {
if !matches!(spec_gate, OperatorPenaltySpec::Active { .. }) {
continue;
}
candidates.push(PenaltyCandidate {
matrix,
nullspace_dim_hint: 0,
source,
normalization_scale,
kronecker_factors: None,
op: None,
});
}
let (_, _, penaltyinfo) = filter_active_penalty_candidates(candidates)?;
let penalties_derivative = active_operator_penalty_derivatives(
&penaltyinfo,
&[s0_norm_psi, s1_norm_psi, s2_norm_psi],
"Matérn",
)?;
let penaltiessecond_derivative = active_operator_penalty_derivatives(
&penaltyinfo,
&[s0_norm_psi_psi, s1_norm_psi_psi, s2_norm_psi_psi],
"Matérn",
)?;
Ok((penalties_derivative, penaltiessecond_derivative))
}
/// Accumulator for the operator blocks used by the Duchon psi-derivative
/// penalty builder.
///
/// The `d0*`, `d1*` and `d2*` fields have `p`, `p * d` and `p * d * d` rows
/// respectively (see [`Self::zeros`]). The `_psi` / `_psi_psi` suffixes hold
/// the first and second psi-derivative of the corresponding block.
pub(crate) struct DuchonRawPenaltyPsiDerivativeBlocks {
    /// Order-0 block (`p` rows).
    pub(crate) d0: Array2<f64>,
    /// Order-1 block (`p * d` rows).
    pub(crate) d1: Array2<f64>,
    /// Order-2 block (`p * d * d` rows).
    pub(crate) d2: Array2<f64>,
    /// First psi-derivative of `d0`.
    pub(crate) d0_psi: Array2<f64>,
    /// First psi-derivative of `d1`.
    pub(crate) d1_psi: Array2<f64>,
    /// First psi-derivative of `d2`.
    pub(crate) d2_psi: Array2<f64>,
    /// Second psi-derivative of `d0`.
    pub(crate) d0_psi_psi: Array2<f64>,
    /// Second psi-derivative of `d1`.
    pub(crate) d1_psi_psi: Array2<f64>,
    /// Second psi-derivative of `d2`.
    pub(crate) d2_psi_psi: Array2<f64>,
}

impl DuchonRawPenaltyPsiDerivativeBlocks {
    /// Creates all nine blocks filled with zeros for `p` points in `d`
    /// dimensions and `cols` coefficient columns.
    pub(crate) fn zeros(p: usize, d: usize, cols: usize) -> Self {
        let order0 = || Array2::<f64>::zeros((p, cols));
        let order1 = || Array2::<f64>::zeros((p * d, cols));
        let order2 = || Array2::<f64>::zeros((p * d * d, cols));
        Self {
            d0: order0(),
            d1: order1(),
            d2: order2(),
            d0_psi: order0(),
            d1_psi: order1(),
            d2_psi: order2(),
            d0_psi_psi: order0(),
            d1_psi_psi: order1(),
            d2_psi_psi: order2(),
        }
    }

    /// Adds every block of `rhs` element-wise into the matching block of
    /// `self`.
    pub(crate) fn add_assign(&mut self, rhs: &Self) {
        let block_pairs = [
            (&mut self.d0, &rhs.d0),
            (&mut self.d1, &rhs.d1),
            (&mut self.d2, &rhs.d2),
            (&mut self.d0_psi, &rhs.d0_psi),
            (&mut self.d1_psi, &rhs.d1_psi),
            (&mut self.d2_psi, &rhs.d2_psi),
            (&mut self.d0_psi_psi, &rhs.d0_psi_psi),
            (&mut self.d1_psi_psi, &rhs.d1_psi_psi),
            (&mut self.d2_psi_psi, &rhs.d2_psi_psi),
        ];
        for (target, addend) in block_pairs {
            *target += addend;
        }
    }
}
pub fn build_duchon_operator_penalty_psi_derivatives(
collocation_points: ArrayView2<'_, f64>,
centers: ArrayView2<'_, f64>,
spec: &DuchonBasisSpec,
identifiability_transform: Option<&Array2<f64>>,
workspace: &mut BasisWorkspace,
) -> Result<(Vec<PenaltySource>, Vec<Array2<f64>>, Vec<Array2<f64>>), BasisError> {
let length_scale = spec.length_scale.ok_or_else(|| {
BasisError::InvalidInput(
"exact Duchon log-kappa derivatives require hybrid Duchon with length_scale"
.to_string(),
)
})?;
let effective_nullspace_order = duchon_effective_nullspace_order(centers, spec.nullspace_order);
let p_order = duchon_p_from_nullspace_order(effective_nullspace_order);
let s_order = spec.power_as_usize();
let dim = centers.ncols();
let two_pps = 2.0 * (p_order as f64 + spec.power);
let mut effective_operator_penalties = spec.operator_penalties.clone();
if two_pps <= dim as f64 + 1.0 {
effective_operator_penalties.tension = OperatorPenaltySpec::Disabled;
}
if two_pps <= dim as f64 + 2.0 {
effective_operator_penalties.stiffness = OperatorPenaltySpec::Disabled;
}
let max_derivative_order =
duchon_max_active_operator_derivative_order(&effective_operator_penalties);
if max_derivative_order == 0
&& !matches!(
effective_operator_penalties.mass,
OperatorPenaltySpec::Active { .. }
)
{
return Ok((Vec::new(), Vec::new(), Vec::new()));
}
validate_duchon_collocation_orders(
Some(length_scale),
p_order,
s_order as f64,
dim,
max_derivative_order,
)?;
let coeffs = duchon_partial_fraction_coeffs(p_order, s_order, 1.0 / length_scale);
let z_kernel =
kernel_constraint_nullspace(centers, effective_nullspace_order, &mut workspace.cache)?;
let n_basis = centers.nrows();
if collocation_points.ncols() != dim {
crate::bail_dim_basis!(
"Duchon psi-derivative collocation dim {} != centers dim {dim}",
collocation_points.ncols()
);
}
let p_colloc = collocation_points.nrows();
let d = dim;
let kernel_cols = z_kernel.ncols();
let aniso = spec.aniso_log_scales.as_deref();
if let Some(eta) = aniso
&& eta.len() != d
{
crate::bail_dim_basis!(
"Duchon anisotropy dimension mismatch: got {}, expected {d}",
eta.len()
);
}
let metric_weights: Option<Vec<f64>> = aniso.map(centered_aniso_metric_weights);
let need_d1 = max_derivative_order >= 1;
let need_d2 = max_derivative_order >= 2;
let chunk_count = rayon::current_num_threads().max(1);
let chunk_size = p_colloc.div_ceil(chunk_count).max(1);
let chunks: Vec<(usize, usize)> = (0..p_colloc)
.step_by(chunk_size)
.map(|start| (start, (start + chunk_size).min(p_colloc)))
.collect();
let partial_blocks = chunks
.into_par_iter()
.map(
|(start, end)| -> Result<DuchonRawPenaltyPsiDerivativeBlocks, BasisError> {
let mut local =
DuchonRawPenaltyPsiDerivativeBlocks::zeros(p_colloc, d, kernel_cols);
for i in start..end {
for j in 0..n_basis {
let r = if let Some(eta) = aniso {
let row_i: Vec<f64> =
(0..d).map(|a| collocation_points[[i, a]]).collect();
let row_j: Vec<f64> = (0..d).map(|a| centers[[j, a]]).collect();
let (r, _) = aniso_distance_and_components(&row_i, &row_j, eta);
r
} else {
stable_euclidean_norm(
(0..d)
.map(|axis| collocation_points[[i, axis]] - centers[[j, axis]]),
)
};
let core = duchon_radial_core_psi_triplet(
r,
length_scale,
p_order,
s_order,
d,
&coeffs,
)?;
for col in 0..kernel_cols {
let z_jc = z_kernel[[j, col]];
local.d0[[i, col]] += core.phi.value * z_jc;
local.d0_psi[[i, col]] += core.phi.psi * z_jc;
local.d0_psi_psi[[i, col]] += core.phi.psi_psi * z_jc;
}
if !need_d1 && !need_d2 {
continue;
}
if r > 1e-10 {
let jets =
duchon_radial_jets(r, length_scale, p_order, s_order, d, &coeffs)?;
let q = jets.q;
let (q_psi, q_psi_psi) =
duchon_q_psi_triplet_from_jets(&jets, p_order, s_order, d, r);
let t_exponent = duchon_scaling_exponent(p_order, s_order, d) + 4.0;
let (t_psi, t_psi_psi) = scaled_log_kappa_derivatives(
jets.t, jets.t_r, jets.t_rr, t_exponent, r,
);
if need_d1 {
for axis in 0..d {
let delta = collocation_points[[i, axis]] - centers[[j, axis]];
let axis_scale = metric_weights
.as_ref()
.map(|weights| weights[axis])
.unwrap_or(1.0);
let row = i * d + axis;
for col in 0..kernel_cols {
let z_jc = z_kernel[[j, col]];
local.d1[[row, col]] += q * axis_scale * delta * z_jc;
local.d1_psi[[row, col]] +=
q_psi * axis_scale * delta * z_jc;
local.d1_psi_psi[[row, col]] +=
q_psi_psi * axis_scale * delta * z_jc;
}
}
}
if need_d2 {
for col in 0..kernel_cols {
let z_jc = z_kernel[[j, col]];
for axis_b in 0..d {
let h_b =
collocation_points[[i, axis_b]] - centers[[j, axis_b]];
let w_b = metric_weights
.as_ref()
.map(|weights| weights[axis_b])
.unwrap_or(1.0);
for axis_c in 0..d {
let h_c = collocation_points[[i, axis_c]]
- centers[[j, axis_c]];
let w_c = metric_weights
.as_ref()
.map(|weights| weights[axis_c])
.unwrap_or(1.0);
let row = (i * d + axis_b) * d + axis_c;
local.d2[[row, col]] += hessian_operator_entry(
q, jets.t, h_b, h_c, w_b, w_c, axis_b, axis_c,
) * z_jc;
local.d2_psi[[row, col]] += hessian_operator_entry(
q_psi, t_psi, h_b, h_c, w_b, w_c, axis_b, axis_c,
) * z_jc;
local.d2_psi_psi[[row, col]] += hessian_operator_entry(
q_psi_psi, t_psi_psi, h_b, h_c, w_b, w_c, axis_b,
axis_c,
) * z_jc;
}
}
}
}
} else if need_d2 {
let (phi_rr, phi_rr_psi, phi_rr_psi_psi) =
duchonphi_rr_collision_psi_triplet(
length_scale,
p_order,
s_order,
d,
&coeffs,
)?;
for col in 0..kernel_cols {
let z_jc = z_kernel[[j, col]];
for axis in 0..d {
let w_axis = metric_weights
.as_ref()
.map(|weights| weights[axis])
.unwrap_or(1.0);
let row = (i * d + axis) * d + axis;
local.d2[[row, col]] += w_axis * phi_rr * z_jc;
local.d2_psi[[row, col]] += w_axis * phi_rr_psi * z_jc;
local.d2_psi_psi[[row, col]] += w_axis * phi_rr_psi_psi * z_jc;
}
}
}
}
}
Ok(local)
},
)
.collect::<Result<Vec<_>, BasisError>>()?;
let mut raw = DuchonRawPenaltyPsiDerivativeBlocks::zeros(p_colloc, d, kernel_cols);
for partial in &partial_blocks {
raw.add_assign(partial);
}
let d0_raw = raw.d0;
let d1_raw = raw.d1;
let d2_raw = raw.d2;
let d0_raw_psi = raw.d0_psi;
let d1_raw_psi = raw.d1_psi;
let d2_raw_psi = raw.d2_psi;
let d0_raw_psi_psi = raw.d0_psi_psi;
let d1_raw_psi_psi = raw.d1_psi_psi;
let d2_raw_psi_psi = raw.d2_psi_psi;
let poly = polynomial_block_from_order(centers, effective_nullspace_order);
let kernel_cols = d0_raw.ncols();
let poly_cols = poly.ncols();
let total_cols = kernel_cols + poly_cols;
let mut d0 = Array2::<f64>::zeros((p_colloc, total_cols));
let mut d1 = Array2::<f64>::zeros((p_colloc * d, total_cols));
let mut d2 = Array2::<f64>::zeros((p_colloc * d * d, total_cols));
let mut d0_psi = Array2::<f64>::zeros((p_colloc, total_cols));
let mut d1_psi = Array2::<f64>::zeros((p_colloc * d, total_cols));
let mut d2_psi = Array2::<f64>::zeros((p_colloc * d * d, total_cols));
let mut d0_psi_psi = Array2::<f64>::zeros((p_colloc, total_cols));
let mut d1_psi_psi = Array2::<f64>::zeros((p_colloc * d, total_cols));
let mut d2_psi_psi = Array2::<f64>::zeros((p_colloc * d * d, total_cols));
d0.slice_mut(s![.., 0..kernel_cols]).assign(&d0_raw);
d1.slice_mut(s![.., 0..kernel_cols]).assign(&d1_raw);
d2.slice_mut(s![.., 0..kernel_cols]).assign(&d2_raw);
d0_psi.slice_mut(s![.., 0..kernel_cols]).assign(&d0_raw_psi);
d1_psi.slice_mut(s![.., 0..kernel_cols]).assign(&d1_raw_psi);
d2_psi.slice_mut(s![.., 0..kernel_cols]).assign(&d2_raw_psi);
d0_psi_psi
.slice_mut(s![.., 0..kernel_cols])
.assign(&d0_raw_psi_psi);
d1_psi_psi
.slice_mut(s![.., 0..kernel_cols])
.assign(&d1_raw_psi_psi);
d2_psi_psi
.slice_mut(s![.., 0..kernel_cols])
.assign(&d2_raw_psi_psi);
let coefficient_gauge =
identifiability_transform.map(|z| gam_problem::Gauge::from_block_transforms(&[z.clone()]));
let project = |mat: Array2<f64>| {
if let Some(gauge) = coefficient_gauge.as_ref() {
gauge.restrict_design(&mat)
} else {
mat
}
};
let d0 = project(d0);
let d1 = project(d1);
let d2 = project(d2);
let d0_psi = project(d0_psi);
let d1_psi = project(d1_psi);
let d2_psi = project(d2_psi);
let d0_psi_psi = project(d0_psi_psi);
let d1_psi_psi = project(d1_psi_psi);
let d2_psi_psi = project(d2_psi_psi);
let (s0, s0_psi, s0_psi_psi) =
centered_operator_gram_and_psi_derivatives(&d0, &d0_psi, &d0_psi_psi);
let (mut s1, mut s1_psi, mut s1_psi_psi) =
gram_and_psi_derivatives_from_operator(&d1, &d1_psi, &d1_psi_psi);
let (mut s2, mut s2_psi, mut s2_psi_psi) =
gram_and_psi_derivatives_from_operator(&d2, &d2_psi, &d2_psi_psi);
let kappa = 1.0 / length_scale.max(1e-300);
let aniso = spec.aniso_log_scales.as_deref();
if duchon_closed_form_operator_penalty_converges(1, p_order, s_order as f64, d) {
let (cf_s, cf_s_psi, cf_s_psi_psi) = closed_form_psi_derivatives_in_total_basis(
centers,
1,
p_order,
s_order,
kappa,
aniso,
Some(&z_kernel),
poly_cols,
identifiability_transform,
);
s1 = cf_s;
s1_psi = cf_s_psi;
s1_psi_psi = cf_s_psi_psi;
}
if duchon_closed_form_operator_penalty_converges(2, p_order, s_order as f64, d) {
let (cf_s, cf_s_psi, cf_s_psi_psi) = closed_form_psi_derivatives_in_total_basis(
centers,
2,
p_order,
s_order,
kappa,
aniso,
Some(&z_kernel),
poly_cols,
identifiability_transform,
);
s2 = cf_s;
s2_psi = cf_s_psi;
s2_psi_psi = cf_s_psi_psi;
}
let (s0_norm, s0_norm_psi, s0_norm_psi_psi, c0) =
normalize_penaltywith_psi_derivatives(&s0, &s0_psi, &s0_psi_psi);
let (s1_norm, s1_norm_psi, s1_norm_psi_psi, c1) =
normalize_penaltywith_psi_derivatives(&s1, &s1_psi, &s1_psi_psi);
let (s2_norm, s2_norm_psi, s2_norm_psi_psi, c2) =
normalize_penaltywith_psi_derivatives(&s2, &s2_psi, &s2_psi_psi);
let candidates = vec![
PenaltyCandidate {
matrix: s0_norm,
nullspace_dim_hint: 0,
source: PenaltySource::OperatorMass,
normalization_scale: c0,
kronecker_factors: None,
op: None,
},
PenaltyCandidate {
matrix: s1_norm,
nullspace_dim_hint: 0,
source: PenaltySource::OperatorTension,
normalization_scale: c1,
kronecker_factors: None,
op: None,
},
PenaltyCandidate {
matrix: s2_norm,
nullspace_dim_hint: 0,
source: PenaltySource::OperatorStiffness,
normalization_scale: c2,
kronecker_factors: None,
op: None,
},
];
let candidates = operator_penalty_candidates_from_derivative_candidates(
candidates,
&effective_operator_penalties,
);
let first_derivs = vec![s0_norm_psi, s1_norm_psi, s2_norm_psi];
let second_derivs = vec![s0_norm_psi_psi, s1_norm_psi_psi, s2_norm_psi_psi];
let (_, _, penaltyinfo) = filter_active_penalty_candidates(candidates)?;
let active_sources = penaltyinfo
.iter()
.filter(|info| info.active)
.map(|info| info.source.clone())
.collect::<Vec<_>>();
let penalties_derivative =
active_operator_penalty_derivatives(&penaltyinfo, &first_derivs, "Duchon")?;
let penaltiessecond_derivative =
active_operator_penalty_derivatives(&penaltyinfo, &second_derivs, "Duchon")?;
Ok((
active_sources,
penalties_derivative,
penaltiessecond_derivative,
))
}
pub(crate) fn operator_penalty_candidates_from_derivative_candidates(
candidates: Vec<PenaltyCandidate>,
spec: &DuchonOperatorPenaltySpec,
) -> Vec<PenaltyCandidate> {
candidates
.into_iter()
.filter(|candidate| match candidate.source {
PenaltySource::OperatorMass => matches!(spec.mass, OperatorPenaltySpec::Active { .. }),
PenaltySource::OperatorTension => {
matches!(spec.tension, OperatorPenaltySpec::Active { .. })
}
PenaltySource::OperatorStiffness => {
matches!(spec.stiffness, OperatorPenaltySpec::Active { .. })
}
_ => true,
})
.collect()
}
pub fn build_duchon_native_penalty_psi_derivatives(
centers: ArrayView2<'_, f64>,
spec: &DuchonBasisSpec,
identifiability_transform: Option<&Array2<f64>>,
workspace: &mut BasisWorkspace,
) -> Result<(Vec<PenaltySource>, Vec<Array2<f64>>, Vec<Array2<f64>>), BasisError> {
let length_scale = spec.length_scale.ok_or_else(|| {
BasisError::InvalidInput(
"exact Duchon native penalty log-kappa derivatives require hybrid Duchon with length_scale"
.to_string(),
)
})?;
let effective_nullspace_order = duchon_effective_nullspace_order(centers, spec.nullspace_order);
let p_order = duchon_p_from_nullspace_order(effective_nullspace_order);
let s_order = spec.power_as_usize();
let dim = centers.ncols();
let mut z =
kernel_constraint_nullspace(centers, effective_nullspace_order, &mut workspace.cache)?;
if let Some(v) = spec.radial_reparam.as_ref() {
if v.nrows() != z.ncols() {
crate::bail_dim_basis!(
"Duchon frozen radial reparam shape {:?} does not match constrained kernel dimension {}",
v.dim(),
z.ncols()
);
}
z = fast_ab(&z, v);
}
let kernel_cols = z.ncols();
let poly_cols = polynomial_block_from_order(centers, effective_nullspace_order).ncols();
let total_cols = kernel_cols + poly_cols;
let coeffs = duchon_partial_fraction_coeffs(p_order, s_order, 1.0 / length_scale.max(1e-300));
let kernel_amp = duchon_kernel_amplification(
centers,
Some(length_scale),
p_order,
s_order,
dim,
spec.aniso_log_scales.as_deref(),
Some(&coeffs),
None,
);
let axis_scales = spec.aniso_log_scales.as_deref().map(aniso_axis_scales);
let n_centers = centers.nrows();
let mut kernel = Array2::<f64>::zeros((n_centers, n_centers));
let mut kernel_psi = Array2::<f64>::zeros((n_centers, n_centers));
let mut kernel_psi_psi = Array2::<f64>::zeros((n_centers, n_centers));
for i in 0..n_centers {
for j in i..n_centers {
let r = if let Some(scales) = axis_scales.as_deref() {
aniso_distance_rows_with_scales(centers, i, centers, j, scales)
} else {
euclidean_distance_rows(centers, i, centers, j)
};
let core =
duchon_radial_core_psi_triplet(r, length_scale, p_order, s_order, dim, &coeffs)?;
kernel[[i, j]] = core.phi.value;
kernel[[j, i]] = core.phi.value;
kernel_psi[[i, j]] = core.phi.psi;
kernel_psi[[j, i]] = core.phi.psi;
kernel_psi_psi[[i, j]] = core.phi.psi_psi;
kernel_psi_psi[[j, i]] = core.phi.psi_psi;
}
}
let amp2 = kernel_amp * kernel_amp;
let kernel_gauge = gam_problem::Gauge::from_block_transforms(&[z.clone()]);
let project_kernel = |k: &Array2<f64>| kernel_gauge.restrict_penalty(k).mapv(|v| v * amp2);
let omega = project_kernel(&kernel);
let omega_psi = project_kernel(&kernel_psi);
let omega_psi_psi = project_kernel(&kernel_psi_psi);
let outer_gauge =
identifiability_transform.map(|z| gam_problem::Gauge::from_block_transforms(&[z.clone()]));
let embed = |block: Array2<f64>| {
let mut out = Array2::<f64>::zeros((total_cols, total_cols));
out.slice_mut(s![..kernel_cols, ..kernel_cols])
.assign(&block);
match outer_gauge.as_ref() {
Some(gauge) => symmetrize(&gauge.restrict_penalty(&out)),
None => symmetrize(&out),
}
};
let primary = embed(omega);
let primary_psi = embed(omega_psi);
let primary_psi_psi = embed(omega_psi_psi);
let (_, primary_psi_norm, primary_psi_psi_norm, _) =
normalize_penaltywith_psi_derivatives(&primary, &primary_psi, &primary_psi_psi);
let candidates = duchon_native_penalty_candidates(
centers,
spec.length_scale,
spec.power,
effective_nullspace_order,
spec.aniso_log_scales.as_deref(),
&z,
identifiability_transform,
poly_cols,
)?;
let (_, _, penaltyinfo) = filter_active_penalty_candidates(candidates)?;
let mut sources = Vec::new();
let mut first = Vec::new();
let mut second = Vec::new();
for info in penaltyinfo.iter().filter(|info| info.active) {
sources.push(info.source.clone());
match info.source {
PenaltySource::Primary => {
first.push(primary_psi_norm.clone());
second.push(primary_psi_psi_norm.clone());
}
PenaltySource::DoublePenaltyNullspace => {
first.push(Array2::<f64>::zeros(primary_psi_norm.raw_dim()));
second.push(Array2::<f64>::zeros(primary_psi_psi_norm.raw_dim()));
}
ref other => {
crate::bail_invalid_basis!(
"unexpected Duchon native penalty source in derivative path: {other:?}"
);
}
}
}
Ok((sources, first, second))
}
/// Prepares the inputs shared by the Duchon derivative builders.
///
/// Selects centers according to `spec.center_strategy`, expands them for the
/// periodic axes in `spec.periodic`, checks them against the large-scale
/// center cap, builds the raw Duchon design and derives the identifiability
/// transform from that design.
///
/// Returns `(expanded_centers, identifiability_transform)`.
pub(crate) fn prepare_duchon_derivative_contextwithworkspace(
    data: ArrayView2<'_, f64>,
    spec: &DuchonBasisSpec,
    workspace: &mut BasisWorkspace,
) -> Result<(Array2<f64>, Option<Array2<f64>>), BasisError> {
    let selected = select_centers_by_strategy(data, &spec.center_strategy)?;
    let centers = expand_periodic_centers(&selected, spec.periodic.as_deref())?;
    assert_spatial_centers_below_large_scale_cap(data.ncols(), centers.view())?;

    // The raw design is only needed to derive the identifiability transform.
    let design = build_duchon_basis_designwithworkspace(
        data,
        centers.view(),
        spec.length_scale,
        spec.power,
        spec.nullspace_order,
        spec.aniso_log_scales.as_deref(),
        None,
        workspace,
    )?;
    let transform = spatial_identifiability_transform_from_design(
        data,
        design.basis.view(),
        &spec.identifiability,
        "Duchon",
    )?;

    Ok((centers, transform))
}
/// Convenience wrapper around
/// `prepare_periodic_duchon_centers_1d_with_period` that passes no explicit
/// period.
///
/// Returns `(centers, left, period)` exactly as the wrapped function does.
pub(crate) fn prepare_periodic_duchon_centers_1d(
    centers: Array2<f64>,
) -> Result<(Array2<f64>, f64, f64), BasisError> {
    let no_explicit_period: Option<f64> = None;
    prepare_periodic_duchon_centers_1d_with_period(centers, no_explicit_period)
}
/// Validates one-dimensional periodic Duchon centers and determines the
/// period.
///
/// * `centers` – must have exactly one column.
/// * `explicit_period` – when given, must be finite, positive and not smaller
///   than the center span (up to a small relative tolerance); when absent the
///   center span itself is used.
///
/// Returns `(centers, left, period)`, where `left` is the smallest center and
/// `centers` has been passed through `collapse_periodic_endpoint`.
///
/// # Errors
/// * invalid-basis error for a column count other than one or an unusable
///   explicit period;
/// * `BasisError::InvalidRange` when the center range is non-finite or empty.
pub(crate) fn prepare_periodic_duchon_centers_1d_with_period(
    centers: Array2<f64>,
    explicit_period: Option<f64>,
) -> Result<(Array2<f64>, f64, f64), BasisError> {
    if centers.ncols() != 1 {
        crate::bail_invalid_basis!(
            "periodic Duchon smooths currently require exactly one covariate"
        );
    }

    // Single pass over the only column to find both extremes.
    let (left, right) = centers.column(0).iter().fold(
        (f64::INFINITY, f64::NEG_INFINITY),
        |(lowest, highest), &value| (lowest.min(value), highest.max(value)),
    );
    let range_is_valid = left.is_finite() && right.is_finite() && left < right;
    if !range_is_valid {
        return Err(BasisError::InvalidRange(left, right));
    }

    let span = right - left;
    let period = if let Some(p) = explicit_period {
        if !p.is_finite() || p <= 0.0 {
            crate::bail_invalid_basis!(
                "periodic Duchon period must be finite and positive; got {p}"
            );
        }
        // Allow the period to undershoot the span by a tiny relative amount.
        if p < span - 1.0e-10 * span.max(1.0) {
            crate::bail_invalid_basis!(
                "periodic Duchon period ({p}) is smaller than the center span ({span}); \
                 every center must lie within a single period"
            );
        }
        p
    } else {
        span
    };

    let centers = collapse_periodic_endpoint(centers, left, period);
    Ok((centers, left, period))
}
/// Evaluates the one-dimensional periodic Duchon kernel and its first and
/// second psi-derivatives between every row of `rows` and every center.
///
/// Only column 0 of `rows` and `centers` is read. Each row value is wrapped
/// into the period with `wrap_to_period` before the periodic distance to each
/// center is taken.
///
/// Returns `(kernel, kernel_psi, kernel_psi_psi)`, each of shape
/// `(rows.nrows(), centers.nrows())`.
pub(crate) fn fill_periodic_duchon_kernel_psi_matrices(
    rows: ArrayView2<'_, f64>,
    centers: ArrayView2<'_, f64>,
    left: f64,
    period: f64,
    length_scale: f64,
    p_order: usize,
    s_order: usize,
    coeffs: &DuchonPartialFractionCoeffs,
) -> Result<(Array2<f64>, Array2<f64>, Array2<f64>), BasisError> {
    let row_count = rows.nrows();
    let center_count = centers.nrows();
    let blank = || Array2::<f64>::zeros((row_count, center_count));
    let (mut values, mut first, mut second) = (blank(), blank(), blank());

    for row in 0..row_count {
        let wrapped = wrap_to_period(rows[[row, 0]], left, period);
        for col in 0..center_count {
            let distance = periodic_distance_1d(wrapped, centers[[col, 0]], period);
            let core = duchon_radial_core_psi_triplet(
                distance,
                length_scale,
                p_order,
                s_order,
                1,
                coeffs,
            )?;
            values[[row, col]] = core.phi.value;
            first[[row, col]] = core.phi.psi;
            second[[row, col]] = core.phi.psi_psi;
        }
    }

    Ok((values, first, second))
}
/// Builds the periodic 1-D Duchon basis and returns the identifiability
/// transform recorded in its metadata.
///
/// # Errors
/// Propagates any error from `build_periodic_duchon_basis_1d`, and returns
/// `BasisError::InvalidInput` when the builder reports metadata of a kind
/// other than `BasisMetadata::Duchon`.
pub(crate) fn periodic_duchon_identifiability_transformwithworkspace(
    data: ArrayView2<'_, f64>,
    spec: &DuchonBasisSpec,
    centers: Array2<f64>,
    workspace: &mut BasisWorkspace,
) -> Result<Option<Array2<f64>>, BasisError> {
    let basis = build_periodic_duchon_basis_1d(data, spec, centers, workspace)?;
    match basis.metadata {
        BasisMetadata::Duchon {
            identifiability_transform,
            ..
        } => Ok(identifiability_transform),
        unexpected => {
            let message = format!(
                "periodic Duchon builder must return Duchon metadata, got {:?}",
                std::mem::discriminant(&unexpected)
            );
            Err(BasisError::InvalidInput(message))
        }
    }
}
pub(crate) fn build_periodic_duchon_basis_log_kappa_derivativeswithworkspace(
data: ArrayView2<'_, f64>,
spec: &DuchonBasisSpec,
workspace: &mut BasisWorkspace,
) -> Result<BasisPsiDerivativeBundle, BasisError> {
if data.ncols() != 1 {
crate::bail_invalid_basis!(
"periodic Duchon log-kappa derivatives require exactly one covariate"
);
}
let length_scale = spec.length_scale.ok_or_else(|| {
BasisError::InvalidInput(
"periodic Duchon log-kappa derivatives require hybrid Duchon with length_scale"
.to_string(),
)
})?;
let centers = select_centers_by_strategy(data, &spec.center_strategy)?;
assert_spatial_centers_below_large_scale_cap(data.ncols(), centers.view())?;
let (centers, left, period) = prepare_periodic_duchon_centers_1d(centers)?;
let effective_nullspace_order = DuchonNullspaceOrder::Zero;
let p_order = duchon_p_from_nullspace_order(effective_nullspace_order);
let s_order = spec.power_as_usize();
validate_duchon_kernel_orders(Some(length_scale), p_order, s_order as f64, 1)?;
let coeffs = duchon_partial_fraction_coeffs(p_order, s_order, 1.0 / length_scale.max(1e-300));
let z_kernel = kernel_constraint_nullspace(
centers.view(),
effective_nullspace_order,
&mut workspace.cache,
)?;
let identifiability_transform = periodic_duchon_identifiability_transformwithworkspace(
data,
spec,
centers.clone(),
workspace,
)?;
let (_data_kernel, data_kernel_psi, data_kernel_psi_psi) =
fill_periodic_duchon_kernel_psi_matrices(
data,
centers.view(),
left,
period,
length_scale,
p_order,
s_order,
&coeffs,
)?;
let kernel_amp = duchon_kernel_amplification(
centers.view(),
Some(length_scale),
p_order,
s_order,
1,
None,
Some(&coeffs),
None,
);
let kernel_cols = z_kernel.ncols();
let total_cols = kernel_cols + 1;
let kernel_gauge = gam_problem::Gauge::from_block_transforms(&[z_kernel.clone()]);
let mut design_first = Array2::<f64>::zeros((data.nrows(), total_cols));
let mut design_second = Array2::<f64>::zeros((data.nrows(), total_cols));
design_first
.slice_mut(s![.., 0..kernel_cols])
.assign(&(kernel_gauge.restrict_design(&data_kernel_psi) * kernel_amp));
design_second
.slice_mut(s![.., 0..kernel_cols])
.assign(&(kernel_gauge.restrict_design(&data_kernel_psi_psi) * kernel_amp));
if let Some(gauge) = identifiability_transform
.as_ref()
.map(|z| gam_problem::Gauge::from_block_transforms(&[z.clone()]))
{
design_first = gauge.restrict_design(&design_first);
design_second = gauge.restrict_design(&design_second);
}
let (center_kernel, center_kernel_psi, center_kernel_psi_psi) =
fill_periodic_duchon_kernel_psi_matrices(
centers.view(),
centers.view(),
left,
period,
length_scale,
p_order,
s_order,
&coeffs,
)?;
let omega = kernel_gauge.restrict_penalty(¢er_kernel);
let omega_psi = kernel_gauge.restrict_penalty(¢er_kernel_psi);
let omega_psi_psi = kernel_gauge.restrict_penalty(¢er_kernel_psi_psi);
let mut penalty = Array2::<f64>::zeros((total_cols, total_cols));
let mut penalty_psi = Array2::<f64>::zeros((total_cols, total_cols));
let mut penalty_psi_psi = Array2::<f64>::zeros((total_cols, total_cols));
penalty
.slice_mut(s![0..kernel_cols, 0..kernel_cols])
.assign(&omega);
penalty_psi
.slice_mut(s![0..kernel_cols, 0..kernel_cols])
.assign(&omega_psi);
penalty_psi_psi
.slice_mut(s![0..kernel_cols, 0..kernel_cols])
.assign(&omega_psi_psi);
if let Some(gauge) = identifiability_transform
.as_ref()
.map(|z| gam_problem::Gauge::from_block_transforms(&[z.clone()]))
{
penalty = gauge.restrict_penalty(&penalty);
penalty_psi = gauge.restrict_penalty(&penalty_psi);
penalty_psi_psi = gauge.restrict_penalty(&penalty_psi_psi);
}
let (penalty_norm, penalty_norm_psi, penalty_norm_psi_psi, normalization_scale) =
normalize_penaltywith_psi_derivatives(
&symmetrize(&penalty),
&symmetrize(&penalty_psi),
&symmetrize(&penalty_psi_psi),
);
let (_, _, penaltyinfo) = filter_active_penalty_candidates(vec![PenaltyCandidate {
matrix: penalty_norm,
nullspace_dim_hint: 1,
source: PenaltySource::Primary,
normalization_scale,
kronecker_factors: None,
op: None,
}])?;
let mut penalties_derivative = Vec::new();
let mut penaltiessecond_derivative = Vec::new();
for info in penaltyinfo.iter().filter(|info| info.active) {
match info.source {
PenaltySource::Primary => {
penalties_derivative.push(penalty_norm_psi.clone());
penaltiessecond_derivative.push(penalty_norm_psi_psi.clone());
}
ref other => {
crate::bail_invalid_basis!(
"unexpected periodic Duchon penalty source in derivative path: {other:?}"
);
}
}
}
Ok(BasisPsiDerivativeBundle {
first: BasisPsiDerivativeResult {
design_derivative: design_first,
penalties_derivative,
implicit_operator: None,
},
second: BasisPsiSecondDerivativeResult {
designsecond_derivative: design_second,
penaltiessecond_derivative,
implicit_operator: None,
},
implicit_operator: None,
})
}
/// Builds the scalar ψ-derivatives of a Matérn design matrix.
///
/// This is a thin adapter over `build_scalar_design_psi_derivatives_shared`
/// that selects the `RadialScalarKind::Matern` radial kind.
///
/// * `data` / `centers` – evaluation points and kernel centers.
/// * `length_scale`, `nu` – Matérn parameters forwarded into the radial kind.
/// * `include_intercept` – adds one extra column to the column count.
/// * `z_opt` – optional transform; when present its column count replaces the
///   number of centers as the kernel-column count.
/// * `aniso_log_scales` – optional per-axis log scales, forwarded unchanged.
///
/// # Errors
/// Propagates any error from the shared builder.
pub(crate) fn build_matern_design_psi_derivatives(
    data: ArrayView2<'_, f64>,
    centers: ArrayView2<'_, f64>,
    length_scale: f64,
    nu: MaternNu,
    include_intercept: bool,
    z_opt: Option<&Array2<f64>>,
    aniso_log_scales: Option<&[f64]>,
) -> Result<ScalarDesignPsiDerivatives, BasisError> {
    let intercept_cols = usize::from(include_intercept);
    // Kernel block width: columns of the transform if present, else one per center.
    let kernel_cols = match z_opt {
        Some(z) => z.ncols(),
        None => centers.nrows(),
    };
    let radial_kind = RadialScalarKind::Matern { length_scale, nu };
    build_scalar_design_psi_derivatives_shared(
        data,
        centers,
        aniso_log_scales,
        kernel_cols + intercept_cols,
        z_opt.cloned(),
        None,
        intercept_cols,
        radial_kind,
        0.0,
    )
}
/// Builds the primary Matérn double-penalty matrix and its first and second
/// log-κ derivatives (per the names of the kernel helpers called below).
///
/// The center-by-center kernel matrix and its two derivative matrices are
/// filled symmetrically, optionally restricted through the gauge built from
/// `z_opt`, zero-padded to `total_cols x total_cols` (the intercept column,
/// when requested, stays unpenalized) and finally normalized.
///
/// Returns `(s_norm, s_norm_psi, s_norm_psi_psi, c, s, s_psi, s_psi_psi)`:
/// the three normalized matrices, the fourth value produced by
/// `normalize_penaltywith_psi_derivatives`, and the three un-normalized padded
/// matrices.
///
/// # Errors
/// Propagates any error from the Matérn kernel / derivative evaluators.
pub(crate) fn build_matern_double_penalty_primarywith_psi_derivatives(
    centers: ArrayView2<'_, f64>,
    length_scale: f64,
    nu: MaternNu,
    include_intercept: bool,
    z_opt: Option<&Array2<f64>>,
    aniso_log_scales: Option<&[f64]>,
) -> Result<
    (
        Array2<f64>,
        Array2<f64>,
        Array2<f64>,
        f64,
        Array2<f64>,
        Array2<f64>,
        Array2<f64>,
    ),
    BasisError,
> {
    let k = centers.nrows();
    let dim = centers.ncols();
    let kernel_cols = z_opt.map(|z| z.ncols()).unwrap_or(k);
    let total_cols = kernel_cols + usize::from(include_intercept);
    let mut kernel = Array2::<f64>::zeros((k, k));
    let mut kernel_psi = Array2::<f64>::zeros((k, k));
    let mut kernel_psi_psi = Array2::<f64>::zeros((k, k));
    // Copy each center row into an owned buffer once. Borrowing rows with
    // `as_slice().unwrap()` panics whenever the view's rows are not contiguous
    // in memory (e.g. a column-major or strided view); element-wise copies work
    // for every layout. Only the anisotropic distance needs slices.
    let center_rows: Vec<Vec<f64>> = if aniso_log_scales.is_some() {
        (0..k)
            .map(|i| (0..dim).map(|axis| centers[[i, axis]]).collect())
            .collect()
    } else {
        Vec::new()
    };
    for i in 0..k {
        // Only the upper triangle is evaluated; values are mirrored below.
        for j in i..k {
            let r = match aniso_log_scales {
                Some(eta) => {
                    aniso_distance(center_rows[i].as_slice(), center_rows[j].as_slice(), eta)
                }
                None => stable_euclidean_norm(
                    (0..dim).map(|axis| centers[[i, axis]] - centers[[j, axis]]),
                ),
            };
            let value = matern_kernel_from_distance(r, length_scale, nu)?;
            let d1 = matern_kernel_log_kappa_derivative_from_distance(r, length_scale, nu)?;
            let d2 = matern_kernel_log_kappasecond_derivative_from_distance(r, length_scale, nu)?;
            kernel[[i, j]] = value;
            kernel[[j, i]] = value;
            kernel_psi[[i, j]] = d1;
            kernel_psi[[j, i]] = d1;
            kernel_psi_psi[[i, j]] = d2;
            kernel_psi_psi[[j, i]] = d2;
        }
    }
    // Restrict all three matrices through the same gauge when a transform is given.
    let (kernel, kernel_psi, kernel_psi_psi) = if let Some(gauge) =
        z_opt.map(|z| gam_problem::Gauge::from_block_transforms(&[z.clone()]))
    {
        (
            gauge.restrict_penalty(&kernel),
            gauge.restrict_penalty(&kernel_psi),
            gauge.restrict_penalty(&kernel_psi_psi),
        )
    } else {
        (kernel, kernel_psi, kernel_psi_psi)
    };
    // Embed the kernel blocks in the top-left corner; remaining rows/columns stay zero.
    let mut s = Array2::<f64>::zeros((total_cols, total_cols));
    let mut s_psi = Array2::<f64>::zeros((total_cols, total_cols));
    let mut s_psi_psi = Array2::<f64>::zeros((total_cols, total_cols));
    s.slice_mut(s![0..kernel_cols, 0..kernel_cols])
        .assign(&kernel);
    s_psi
        .slice_mut(s![0..kernel_cols, 0..kernel_cols])
        .assign(&kernel_psi);
    s_psi_psi
        .slice_mut(s![0..kernel_cols, 0..kernel_cols])
        .assign(&kernel_psi_psi);
    let (s_norm, s_norm_psi, s_norm_psi_psi, c) =
        normalize_penaltywith_psi_derivatives(&s, &s_psi, &s_psi_psi);
    Ok((s_norm, s_norm_psi, s_norm_psi_psi, c, s, s_psi, s_psi_psi))
}
/// Spectral frame of a symmetric matrix, used to differentiate the projector
/// onto its numerical null space.
///
/// Instances are produced by `ShrinkageProjectorFrame::build`, which returns
/// `None` when the matrix is empty or has no eigenvalue within tolerance of zero.
pub(crate) struct ShrinkageProjectorFrame {
/// Eigenvector matrix returned by `spectral_summary` for the input matrix.
pub(crate) u: Array1x2Placeholder,
}
impl ShrinkageProjectorFrame {
    /// Builds the frame from `a_raw`.
    ///
    /// Returns `Ok(None)` when `a_raw` has no rows or when no eigenvalue has
    /// magnitude within the spectral tolerance (i.e. the null space is empty).
    ///
    /// # Errors
    /// Propagates any error from `spectral_summary`.
    pub(crate) fn build(a_raw: &Array2<f64>) -> Result<Option<Self>, BasisError> {
        if a_raw.nrows() == 0 {
            return Ok(None);
        }
        let (sym, evals, evecs) = spectral_summary(a_raw)?;
        let tol = spectral_tolerance(&sym, &evals);
        // Flag each eigenvalue as null (1.0) or not (0.0) and count the nulls.
        let mut in_null: Vec<f64> = Vec::with_capacity(evals.len());
        let mut null_dim = 0usize;
        for &ev in evals.iter() {
            if ev.abs() <= tol {
                in_null.push(1.0);
                null_dim += 1;
            } else {
                in_null.push(0.0);
            }
        }
        if null_dim == 0 {
            return Ok(None);
        }
        Ok(Some(Self {
            u: evecs,
            evals,
            in_null,
            null_dim,
            gap_floor: tol.max(f64::MIN_POSITIVE),
        }))
    }

    /// Dimension of the frame (row count of the eigenvector matrix).
    pub(crate) fn dim(&self) -> usize {
        self.u.nrows()
    }

    /// Rotates the symmetrized direction `a_dir` into the eigenbasis and
    /// derives the eigenvector connection.
    ///
    /// Returns `(omega, b_hat, lam_prime)` where `b_hat = Uᵀ sym(a_dir) U`,
    /// `omega[m, k] = b_hat[m, k] / (evals[k] - evals[m])` for off-diagonal
    /// pairs whose gap exceeds `gap_floor` (zero otherwise), and `lam_prime`
    /// is the diagonal of `b_hat`.
    pub(crate) fn connection(&self, a_dir: &Array2<f64>) -> (Array2<f64>, Array2<f64>, Vec<f64>) {
        let p = self.dim();
        let rotated = fast_ab(&symmetrize(a_dir), &self.u);
        let b_hat = fast_atb(&self.u, &rotated);
        let omega = Array2::<f64>::from_shape_fn((p, p), |(m, k)| {
            if m == k {
                return 0.0;
            }
            let gap = self.evals[k] - self.evals[m];
            if gap.abs() > self.gap_floor {
                b_hat[[m, k]] / gap
            } else {
                // (Near-)degenerate pair: leave the connection entry at zero.
                0.0
            }
        });
        let lam_prime: Vec<f64> = (0..p).map(|k| b_hat[[k, k]]).collect();
        (omega, b_hat, lam_prime)
    }

    /// First derivative of the null-space indicator projector in the
    /// eigenbasis: entry `(i, j)` is `omega[i, j] * (in_null[j] - in_null[i])`,
    /// so only pairs that straddle the null / non-null split are non-zero.
    pub(crate) fn projector_first_hat(&self, omega: &Array2<f64>) -> Array2<f64> {
        let p = self.dim();
        Array2::<f64>::from_shape_fn((p, p), |(i, j)| {
            let coeff = self.in_null[j] - self.in_null[i];
            if coeff != 0.0 {
                omega[[i, j]] * coeff
            } else {
                0.0
            }
        })
    }

    /// First directional derivative of the scaled projector along `a_dir`,
    /// expressed in the original ("lab") coordinates.
    pub(crate) fn first(&self, a_dir: &Array2<f64>) -> Array2<f64> {
        let (omega, _, _) = self.connection(a_dir);
        self.to_lab(&self.projector_first_hat(&omega))
    }

    /// Second directional derivative of the scaled projector.
    ///
    /// `a_dir_a` and `a_dir_b` are the two first-order directions and
    /// `a_cross` is the mixed second derivative of the underlying matrix.
    /// The result is returned in lab coordinates.
    pub(crate) fn second(
        &self,
        a_dir_a: &Array2<f64>,
        a_dir_b: &Array2<f64>,
        a_cross: &Array2<f64>,
    ) -> Array2<f64> {
        let p = self.dim();
        let (omega_a, b_hat_a, _) = self.connection(a_dir_a);
        let (omega_b, _, lam_prime_b) = self.connection(a_dir_b);
        let p1a_hat = self.projector_first_hat(&omega_a);
        let c_hat = fast_atb(&self.u, &fast_ab(&symmetrize(a_cross), &self.u));
        // Derivative of b_hat_a along direction b: the rotated cross term plus
        // the commutator induced by the moving eigenbasis.
        let commutator = fast_ab(&b_hat_a, &omega_b) - fast_ab(&omega_b, &b_hat_a);
        let b_hat_a_prime = &c_hat + &commutator;
        // Quotient rule applied to omega_a[m, k] = b_hat_a[m, k] / gap.
        let omega_a_db = Array2::<f64>::from_shape_fn((p, p), |(m, k)| {
            if m == k {
                return 0.0;
            }
            let gap = self.evals[k] - self.evals[m];
            if gap.abs() > self.gap_floor {
                b_hat_a_prime[[m, k]] / gap
                    - b_hat_a[[m, k]] * (lam_prime_b[k] - lam_prime_b[m]) / (gap * gap)
            } else {
                0.0
            }
        });
        let mut p2_hat = fast_ab(&omega_b, &p1a_hat) - fast_ab(&p1a_hat, &omega_b);
        for ((i, j), slot) in p2_hat.indexed_iter_mut() {
            let coeff = self.in_null[j] - self.in_null[i];
            if coeff != 0.0 {
                *slot += omega_a_db[[i, j]] * coeff;
            }
        }
        self.to_lab(&p2_hat)
    }

    /// Maps an eigenbasis matrix back to lab coordinates (`U · p_hat · Uᵀ`),
    /// symmetrizes it and scales by `1 / sqrt(null_dim)`.
    pub(crate) fn to_lab(&self, p_hat: &Array2<f64>) -> Array2<f64> {
        let inv_norm = 1.0 / (self.null_dim as f64).sqrt();
        let lab = fast_ab(&self.u, &fast_abt(p_hat, &self.u));
        symmetrize(&lab).mapv(|v| v * inv_norm)
    }
}
/// First and second ψ-derivatives of the null-space shrinkage penalty built
/// from `a_raw`.
///
/// `a_raw_psi` is used as both first-order directions and `a_raw_psi_psi` as
/// the cross term of `ShrinkageProjectorFrame::second`. When no frame can be
/// built (empty matrix or empty null space) both derivatives are zero matrices
/// of the same size as `a_raw`.
///
/// # Errors
/// Propagates any error from `ShrinkageProjectorFrame::build`.
pub(crate) fn matern_nullspace_shrinkage_psi_derivatives(
    a_raw: &Array2<f64>,
    a_raw_psi: &Array2<f64>,
    a_raw_psi_psi: &Array2<f64>,
) -> Result<(Array2<f64>, Array2<f64>), BasisError> {
    match ShrinkageProjectorFrame::build(a_raw)? {
        Some(frame) => {
            let first = frame.first(a_raw_psi);
            let second = frame.second(a_raw_psi, a_raw_psi, a_raw_psi_psi);
            Ok((first, second))
        }
        None => {
            let p = a_raw.nrows();
            Ok((Array2::<f64>::zeros((p, p)), Array2::<f64>::zeros((p, p))))
        }
    }
}
/// Selects, for every *active* penalty in `penaltyinfo` (in order), the
/// derivative matrix that matches its source.
///
/// `Primary` entries receive a clone of `primary_derivative`,
/// `DoublePenaltyNullspace` entries a clone of `shrinkage_derivative`.
/// Inactive entries are skipped without inspecting their source.
///
/// # Errors
/// Returns `BasisError::InvalidInput` at the first active entry with any
/// other source.
pub(crate) fn active_matern_double_penalty_derivatives(
    penaltyinfo: &[PenaltyInfo],
    primary_derivative: &Array2<f64>,
    shrinkage_derivative: &Array2<f64>,
) -> Result<Vec<Array2<f64>>, BasisError> {
    let mut derivatives = Vec::new();
    for info in penaltyinfo {
        if !info.active {
            continue;
        }
        let matrix = match &info.source {
            PenaltySource::Primary => primary_derivative.clone(),
            PenaltySource::DoublePenaltyNullspace => shrinkage_derivative.clone(),
            other => {
                return Err(BasisError::InvalidInput(format!(
                    "unexpected Matérn penalty source in double-penalty path: {other:?}"
                )));
            }
        };
        derivatives.push(matrix);
    }
    Ok(derivatives)
}
/// Convenience wrapper around
/// `build_matern_basis_log_kappa_derivativewithworkspace` that uses a fresh,
/// default-constructed `BasisWorkspace`.
///
/// # Errors
/// Propagates any error from the workspace-based builder.
pub fn build_matern_basis_log_kappa_derivative(
    data: ArrayView2<'_, f64>,
    spec: &MaternBasisSpec,
) -> Result<BasisPsiDerivativeResult, BasisError> {
    build_matern_basis_log_kappa_derivativewithworkspace(
        data,
        spec,
        &mut BasisWorkspace::default(),
    )
}
pub fn build_matern_basis_log_kappa_derivativewithworkspace(
data: ArrayView2<'_, f64>,
spec: &MaternBasisSpec,
workspace: &mut BasisWorkspace,
) -> Result<BasisPsiDerivativeResult, BasisError> {
let mut bundle = build_matern_basis_log_kappa_derivativeswithworkspace(data, spec, workspace)?;
bundle.first.implicit_operator = bundle.implicit_operator;
Ok(bundle.first)
}
/// Convenience wrapper around
/// `build_matern_basis_log_kappa_derivativeswithworkspace` that uses a fresh,
/// default-constructed `BasisWorkspace`.
///
/// # Errors
/// Propagates any error from the workspace-based builder.
pub fn build_matern_basis_log_kappa_derivatives(
    data: ArrayView2<'_, f64>,
    spec: &MaternBasisSpec,
) -> Result<BasisPsiDerivativeBundle, BasisError> {
    build_matern_basis_log_kappa_derivativeswithworkspace(
        data,
        spec,
        &mut BasisWorkspace::default(),
    )
}
/// Builds the first and second ψ-derivatives of a Matérn basis (design matrix
/// and penalties) for `spec`, reusing `workspace` for the value build.
///
/// The basis itself is built first so that the derivative pass uses the same
/// centers, identifiability transform and anisotropy scales that the value
/// build recorded in its metadata. Penalty derivatives follow either the
/// double-penalty path (primary plus optional null-space shrinkage) or the
/// operator-penalty path, depending on `spec.double_penalty`.
///
/// # Errors
/// Returns `BasisError::InvalidInput` if the value build did not produce
/// Matérn metadata, and propagates errors from every builder it calls.
pub fn build_matern_basis_log_kappa_derivativeswithworkspace(
    data: ArrayView2<'_, f64>,
    spec: &MaternBasisSpec,
    workspace: &mut BasisWorkspace,
) -> Result<BasisPsiDerivativeBundle, BasisError> {
    let base = build_matern_basiswithworkspace(data, spec, workspace)?;
    let BasisMetadata::Matern {
        centers: base_centers,
        identifiability_transform,
        aniso_log_scales,
        ..
    } = &base.metadata
    else {
        return Err(BasisError::InvalidInput(format!(
            "Matérn ψ-derivative build expected Matérn metadata, got {:?}",
            std::mem::discriminant(&base.metadata)
        )));
    };
    let centers = expand_periodic_centers(base_centers, spec.periodic.as_deref())?;
    let z_opt = identifiability_transform.as_ref();
    let aniso = aniso_log_scales.as_deref();

    let design_derivatives = build_matern_design_psi_derivatives(
        data,
        centers.view(),
        spec.length_scale,
        spec.nu,
        spec.include_intercept,
        z_opt,
        aniso,
    )?;

    let (penalties_derivative, penaltiessecond_derivative) = if spec.double_penalty {
        let (_, primary_first, primary_second, _, a_raw, a_raw_psi, a_raw_psi_psi) =
            build_matern_double_penalty_primarywith_psi_derivatives(
                centers.view(),
                spec.length_scale,
                spec.nu,
                spec.include_intercept,
                z_opt,
                aniso,
            )?;
        // The shrinkage derivatives are only needed when the value build kept
        // the null-space penalty active; otherwise zero placeholders suffice.
        let shrinkage_active = base.penaltyinfo.iter().any(|info| {
            info.active && matches!(info.source, PenaltySource::DoublePenaltyNullspace)
        });
        let (shrinkage_first, shrinkage_second) = if shrinkage_active {
            matern_nullspace_shrinkage_psi_derivatives(&a_raw, &a_raw_psi, &a_raw_psi_psi)?
        } else {
            let zeros = Array2::<f64>::zeros(a_raw.raw_dim());
            (zeros.clone(), zeros)
        };
        let first = active_matern_double_penalty_derivatives(
            &base.penaltyinfo,
            &primary_first,
            &shrinkage_first,
        )?;
        let second = active_matern_double_penalty_derivatives(
            &base.penaltyinfo,
            &primary_second,
            &shrinkage_second,
        )?;
        (first, second)
    } else {
        build_matern_operator_penalty_psi_derivatives(
            centers.view(),
            spec.length_scale,
            spec.nu,
            spec.include_intercept,
            z_opt,
            aniso,
        )?
    };

    let first = BasisPsiDerivativeResult {
        design_derivative: design_derivatives.design_first,
        penalties_derivative,
        implicit_operator: None,
    };
    let second = BasisPsiSecondDerivativeResult {
        designsecond_derivative: design_derivatives.design_second_diag,
        penaltiessecond_derivative,
        implicit_operator: None,
    };
    Ok(BasisPsiDerivativeBundle {
        first,
        second,
        implicit_operator: design_derivatives.implicit_operator,
    })
}
/// Convenience wrapper around
/// `build_matern_basis_log_kappasecond_derivativewithworkspace` that uses a
/// fresh, default-constructed `BasisWorkspace`.
///
/// # Errors
/// Propagates any error from the workspace-based builder.
pub fn build_matern_basis_log_kappasecond_derivative(
    data: ArrayView2<'_, f64>,
    spec: &MaternBasisSpec,
) -> Result<BasisPsiSecondDerivativeResult, BasisError> {
    build_matern_basis_log_kappasecond_derivativewithworkspace(
        data,
        spec,
        &mut BasisWorkspace::default(),
    )
}
pub fn build_matern_basis_log_kappasecond_derivativewithworkspace(
data: ArrayView2<'_, f64>,
spec: &MaternBasisSpec,
workspace: &mut BasisWorkspace,
) -> Result<BasisPsiSecondDerivativeResult, BasisError> {
let mut bundle = build_matern_basis_log_kappa_derivativeswithworkspace(data, spec, workspace)?;
bundle.second.implicit_operator = bundle.implicit_operator;
Ok(bundle.second)
}
/// Builds the per-axis (anisotropic) ψ-derivatives of a Matérn design matrix.
///
/// Thin adapter over `build_aniso_design_psi_derivatives_shared` that selects
/// the `RadialScalarKind::Matern` radial kind. The smooth's column count is
/// the kernel-column count (columns of `z_opt` if given, otherwise one per
/// center) plus one when `include_intercept` is set.
///
/// # Errors
/// Propagates any error from the shared builder.
pub(crate) fn build_matern_design_psi_aniso_derivatives(
    data: ArrayView2<'_, f64>,
    centers: ArrayView2<'_, f64>,
    length_scale: f64,
    nu: MaternNu,
    eta: &[f64],
    include_intercept: bool,
    z_opt: Option<&Array2<f64>>,
) -> Result<AnisoBasisPsiDerivatives, BasisError> {
    let n_poly = usize::from(include_intercept);
    let p_constrained = match z_opt {
        Some(z) => z.ncols(),
        None => centers.nrows(),
    };
    let radial_kind = RadialScalarKind::Matern { length_scale, nu };
    build_aniso_design_psi_derivatives_shared(
        data,
        centers,
        eta,
        p_constrained + n_poly,
        z_opt.cloned(),
        None,
        n_poly,
        radial_kind,
    )
}
/// Builds, for every axis, the raw (un-projected) center-by-center matrices of
/// the first and diagonal-second anisotropic derivatives of the Matérn kernel.
///
/// For each center pair the anisotropic distance `r` and per-axis components
/// `s` are obtained from `aniso_distance_and_components`, and the scalars
/// `q`, `t` from `matern_aniso_extended_radial_scalars`. Axis `a` then gets
/// `q * s[a]` (first) and `2 * q * s[a] + t * s[a]^2` (second diagonal).
/// Upper-triangle rows are computed in parallel and mirrored afterwards.
///
/// Returns `(raw_first, raw_second_diag)`, each holding one `k x k` matrix
/// per axis.
///
/// # Errors
/// Propagates any error from `matern_aniso_extended_radial_scalars`.
pub(crate) fn build_matern_aniso_primary_raw_derivative_matrices(
    centers: ArrayView2<'_, f64>,
    eta: &[f64],
    length_scale: f64,
    nu: MaternNu,
) -> Result<(Vec<Array2<f64>>, Vec<Array2<f64>>), BasisError> {
    let k = centers.nrows();
    let dim = centers.ncols();
    let row_blocks: Result<Vec<_>, BasisError> = (0..k)
        .into_par_iter()
        .map(|i| {
            let remaining = k - i;
            let ci: Vec<f64> = (0..dim).map(|a| centers[[i, a]]).collect();
            let mut first_by_axis: Vec<Vec<f64>> =
                (0..dim).map(|_| Vec::with_capacity(remaining)).collect();
            let mut second_diag_by_axis: Vec<Vec<f64>> =
                (0..dim).map(|_| Vec::with_capacity(remaining)).collect();
            for j in i..k {
                let cj: Vec<f64> = (0..dim).map(|a| centers[[j, a]]).collect();
                let (r, s_vec) = aniso_distance_and_components(&ci, &cj, eta);
                let (_, q, t, _, _) = matern_aniso_extended_radial_scalars(r, length_scale, nu)?;
                let per_axis = first_by_axis
                    .iter_mut()
                    .zip(second_diag_by_axis.iter_mut())
                    .enumerate();
                for (a, (first_col, second_col)) in per_axis {
                    let s_a = s_vec[a];
                    first_col.push(q * s_a);
                    second_col.push(2.0 * q * s_a + t * s_a * s_a);
                }
            }
            Ok((first_by_axis, second_diag_by_axis))
        })
        .collect();
    let row_blocks = row_blocks?;

    // Scatter each row's upper-triangle values into both triangles.
    let mut raw_first = vec![Array2::<f64>::zeros((k, k)); dim];
    let mut raw_second_diag = vec![Array2::<f64>::zeros((k, k)); dim];
    for (i, (first_by_axis, second_diag_by_axis)) in row_blocks.into_iter().enumerate() {
        for a in 0..dim {
            let pairs = first_by_axis[a].iter().zip(second_diag_by_axis[a].iter());
            for (offset, (&d1, &d2)) in pairs.enumerate() {
                let j = i + offset;
                raw_first[a][[i, j]] = d1;
                raw_first[a][[j, i]] = d1;
                raw_second_diag[a][[i, j]] = d2;
                raw_second_diag[a][[j, i]] = d2;
            }
        }
    }
    Ok((raw_first, raw_second_diag))
}
/// Builds the raw (un-projected) `k x k` matrix of the mixed anisotropic
/// derivative of the Matérn kernel for the axis pair (`axis_a`, `axis_b`).
///
/// Each entry is `t * s[axis_a] * s[axis_b]`, with `s` taken from
/// `aniso_distance_and_components` and `t` from
/// `matern_aniso_extended_radial_scalars`. Upper-triangle rows are computed in
/// parallel and mirrored afterwards.
///
/// # Errors
/// Propagates any error from `matern_aniso_extended_radial_scalars`.
pub(crate) fn build_matern_aniso_raw_cross_derivative_matrix(
    centers: ArrayView2<'_, f64>,
    eta: &[f64],
    length_scale: f64,
    nu: MaternNu,
    axis_a: usize,
    axis_b: usize,
) -> Result<Array2<f64>, BasisError> {
    let k = centers.nrows();
    let dim = centers.ncols();
    let row_blocks: Result<Vec<_>, BasisError> = (0..k)
        .into_par_iter()
        .map(|i| {
            let ci: Vec<f64> = (0..dim).map(|ax| centers[[i, ax]]).collect();
            (i..k)
                .map(|j| -> Result<f64, BasisError> {
                    let cj: Vec<f64> = (0..dim).map(|ax| centers[[j, ax]]).collect();
                    let (r, s_vec) = aniso_distance_and_components(&ci, &cj, eta);
                    let (_, _, t_val, _, _) =
                        matern_aniso_extended_radial_scalars(r, length_scale, nu)?;
                    Ok(t_val * s_vec[axis_a] * s_vec[axis_b])
                })
                .collect::<Result<Vec<f64>, BasisError>>()
        })
        .collect();
    let row_blocks = row_blocks?;

    // Scatter each row's upper-triangle values into both triangles.
    let mut raw_cross = Array2::<f64>::zeros((k, k));
    for (i, values) in row_blocks.into_iter().enumerate() {
        for (offset, value) in values.into_iter().enumerate() {
            let j = i + offset;
            raw_cross[[i, j]] = value;
            raw_cross[[j, i]] = value;
        }
    }
    Ok(raw_cross)
}
/// Builds the per-axis anisotropic ψ-derivatives of a Matérn basis: design
/// derivatives plus first, diagonal-second and (lazily, via a provider)
/// cross-axis penalty derivatives.
///
/// The basis is built first so the derivative pass reuses the centers,
/// identifiability transform and resolved anisotropy scales recorded in its
/// metadata. With `spec.double_penalty` the penalties are the primary kernel
/// penalty plus, when active in the value build, the null-space shrinkage
/// penalty; otherwise the operator-penalty builder is used.
///
/// # Errors
/// * `BasisError::InvalidInput` if `spec.aniso_log_scales` is unset, if the
///   value build did not yield Matérn metadata, or if it recorded no resolved
///   anisotropy scales.
/// * A dimension error (via `bail_dim_basis!`) if the resolved scales do not
///   have one entry per data column.
/// * Any error propagated from the builders it calls.
pub fn build_matern_basis_log_kappa_aniso_derivatives(
    data: ArrayView2<'_, f64>,
    spec: &MaternBasisSpec,
) -> Result<AnisoBasisPsiDerivatives, BasisError> {
    if spec.aniso_log_scales.is_none() {
        return Err(BasisError::InvalidInput(
            "aniso derivatives require aniso_log_scales to be set".to_string(),
        ));
    }
    let dim = data.ncols();
    let base = build_matern_basiswithworkspace(data, spec, &mut BasisWorkspace::default())?;
    let (base_centers, z_opt, base_aniso) = match &base.metadata {
        BasisMetadata::Matern {
            centers,
            identifiability_transform,
            aniso_log_scales,
            ..
        } => (
            centers.clone(),
            identifiability_transform.clone(),
            aniso_log_scales.clone(),
        ),
        other => {
            return Err(BasisError::InvalidInput(format!(
                "Matérn aniso ψ-derivative build expected Matérn metadata, got {:?}",
                std::mem::discriminant(other)
            )));
        }
    };
    let centers = expand_periodic_centers(&base_centers, spec.periodic.as_deref())?;
    let eta = base_aniso.as_deref().ok_or_else(|| {
        BasisError::InvalidInput(
            "aniso derivatives require resolved aniso_log_scales from the value build".to_string(),
        )
    })?;
    if eta.len() != dim {
        crate::bail_dim_basis!(
            "resolved aniso_log_scales length {} != data dimension {dim}",
            eta.len()
        );
    }
    let mut result = build_matern_design_psi_aniso_derivatives(
        data,
        centers.view(),
        spec.length_scale,
        spec.nu,
        eta,
        spec.include_intercept,
        z_opt.as_ref(),
    )?;
    if spec.double_penalty {
        let k = centers.nrows();
        let kernel_cols = z_opt.as_ref().map(|z| z.ncols()).unwrap_or(k);
        let total_cols = kernel_cols + usize::from(spec.include_intercept);
        let mut primary_first = vec![Array2::<f64>::zeros((total_cols, total_cols)); dim];
        let mut primary_second_diag = vec![Array2::<f64>::zeros((total_cols, total_cols)); dim];
        let coefficient_gauge = z_opt
            .as_ref()
            .map(|z| gam_problem::Gauge::from_block_transforms(&[z.clone()]));
        let (mut raw_first, mut raw_second_diag) =
            build_matern_aniso_primary_raw_derivative_matrices(
                centers.view(),
                eta,
                spec.length_scale,
                spec.nu,
            )?;
        // Project each raw per-axis matrix through the gauge (if any) and embed
        // it in the top-left kernel block of the padded penalty derivative.
        for a in 0..dim {
            let projected_first = if let Some(gauge) = coefficient_gauge.as_ref() {
                gauge.restrict_penalty(&raw_first[a])
            } else {
                std::mem::take(&mut raw_first[a])
            };
            let projected_second = if let Some(gauge) = coefficient_gauge.as_ref() {
                gauge.restrict_penalty(&raw_second_diag[a])
            } else {
                std::mem::take(&mut raw_second_diag[a])
            };
            primary_first[a]
                .slice_mut(s![0..kernel_cols, 0..kernel_cols])
                .assign(&projected_first);
            primary_second_diag[a]
                .slice_mut(s![0..kernel_cols, 0..kernel_cols])
                .assign(&projected_second);
        }
        // All unordered axis pairs (a < b) for which a cross derivative exists.
        let mut dp_cross_pairs: Vec<(usize, usize)> = Vec::new();
        for a in 0..dim {
            for b in (a + 1)..dim {
                dp_cross_pairs.push((a, b));
            }
        }
        let has_shrinkage = base.penaltyinfo.iter().any(|info| {
            info.active && matches!(info.source, PenaltySource::DoublePenaltyNullspace)
        });
        // Spectral frame of the padded primary penalty. It is built once here
        // and shared by the per-axis derivatives and the cross-pair provider.
        let shrinkage_frame = if has_shrinkage {
            let kernel = build_matern_kernel_penalty(
                centers.view(),
                spec.length_scale,
                spec.nu,
                spec.include_intercept,
                Some(eta),
            )?;
            let kblock = kernel.slice(s![0..k, 0..k]).to_owned();
            let mut a_raw = Array2::<f64>::zeros((total_cols, total_cols));
            let projected = if let Some(gauge) = coefficient_gauge.as_ref() {
                gauge.restrict_penalty(&kblock)
            } else {
                kblock
            };
            a_raw
                .slice_mut(s![0..kernel_cols, 0..kernel_cols])
                .assign(&projected);
            ShrinkageProjectorFrame::build(&a_raw)?
        } else {
            None
        };
        let shrinkage_first: Vec<Array2<f64>> = (0..dim)
            .map(|a| match &shrinkage_frame {
                Some(frame) => frame.first(&primary_first[a]),
                None => Array2::<f64>::zeros((total_cols, total_cols)),
            })
            .collect();
        let shrinkage_second_diag: Vec<Array2<f64>> = (0..dim)
            .map(|a| match &shrinkage_frame {
                Some(frame) => frame.second(
                    &primary_first[a],
                    &primary_first[a],
                    &primary_second_diag[a],
                ),
                None => Array2::<f64>::zeros((total_cols, total_cols)),
            })
            .collect();
        result.penalties_first = Vec::with_capacity(dim);
        result.penalties_second_diag = Vec::with_capacity(dim);
        for a in 0..dim {
            let pf = active_matern_double_penalty_derivatives(
                &base.penaltyinfo,
                &primary_first[a],
                &shrinkage_first[a],
            )?;
            let ps = active_matern_double_penalty_derivatives(
                &base.penaltyinfo,
                &primary_second_diag[a],
                &shrinkage_second_diag[a],
            )?;
            result.penalties_first.push(pf);
            result.penalties_second_diag.push(ps);
        }
        result.penalties_cross_pairs = dp_cross_pairs;
        let centers_owned = centers.to_owned();
        let eta_owned = eta.to_vec();
        let penaltyinfo = base.penaltyinfo.clone();
        let length_scale = spec.length_scale;
        let nu = spec.nu;
        // The closure takes ownership of `coefficient_gauge`, `primary_first`
        // and `shrinkage_frame`, none of which is used again in this function.
        // Reusing the frame avoids rebuilding the kernel penalty and repeating
        // the eigendecomposition on every cross-pair request; the rebuilt frame
        // would be identical because its inputs do not change.
        result.penalties_cross_provider = Some(AnisoPenaltyCrossProvider::new(
            move |axis_a: usize, axis_b: usize| {
                // Normalize to a <= b; the cross derivative is symmetric in the axes.
                let (a, b) = if axis_a < axis_b {
                    (axis_a, axis_b)
                } else {
                    (axis_b, axis_a)
                };
                if a == b || b >= eta_owned.len() {
                    return Ok(Vec::new());
                }
                let raw_cross = build_matern_aniso_raw_cross_derivative_matrix(
                    centers_owned.view(),
                    &eta_owned,
                    length_scale,
                    nu,
                    a,
                    b,
                )?;
                let projected: Array2<f64> = if let Some(gauge) = coefficient_gauge.as_ref() {
                    gauge.restrict_penalty(&raw_cross)
                } else {
                    raw_cross
                };
                let mut padded = Array2::<f64>::zeros((total_cols, total_cols));
                padded
                    .slice_mut(s![0..kernel_cols, 0..kernel_cols])
                    .assign(&projected);
                // `shrinkage_frame` is `None` both when shrinkage is inactive
                // and when the penalty has no null space; the cross term is
                // zero in either case.
                let shrinkage_cross = match shrinkage_frame.as_ref() {
                    Some(frame) => frame.second(&primary_first[a], &primary_first[b], &padded),
                    None => Array2::<f64>::zeros((total_cols, total_cols)),
                };
                active_matern_double_penalty_derivatives(&penaltyinfo, &padded, &shrinkage_cross)
            },
        ));
    } else {
        let (per_axis, cross_pairs, cross_provider) =
            build_matern_operator_penalty_aniso_derivatives(
                centers.view(),
                spec.length_scale,
                spec.nu,
                spec.include_intercept,
                z_opt.as_ref(),
                eta,
            )?;
        result.penalties_first = Vec::with_capacity(dim);
        result.penalties_second_diag = Vec::with_capacity(dim);
        for (pen_first, pen_second) in per_axis {
            result.penalties_first.push(pen_first);
            result.penalties_second_diag.push(pen_second);
        }
        result.penalties_cross_pairs = cross_pairs;
        result.penalties_cross_provider = Some(cross_provider);
    }
    Ok(result)
}
#[cfg(test)]
mod harmonic_penalty_invariants_tests {
    use super::*;
    use ndarray::array;

    /// With double penalty enabled, the harmonic basis must produce two
    /// penalties whose diagonals split the columns: the leading low-degree
    /// columns are untouched by the primary penalty and penalized by the
    /// shrinkage penalty, and the remaining columns the other way round.
    #[test]
    fn harmonic_double_penalty_targets_primary_nullspace() {
        let data = array![
            [-0.9, 0.0],
            [-0.4, 1.0],
            [0.0, 2.0],
            [0.4, 3.0],
            [0.9, 4.0],
            [0.2, 5.0],
        ];
        let spec = SphericalSplineBasisSpec {
            center_strategy: CenterStrategy::FarthestPoint { num_centers: 6 },
            penalty_order: 2,
            double_penalty: true,
            radians: true,
            method: SphereMethod::Harmonic,
            max_degree: Some(3),
            wahba_kernel: SphereWahbaKernel::Sobolev,
            identifiability: SphericalSplineIdentifiability::CenterSumToZero,
        };
        let built = build_spherical_harmonic_basis(data.view(), &spec).expect("harmonic basis");
        assert_eq!(built.penalties.len(), 2);
        let (primary, shrink) = (&built.penalties[0], &built.penalties[1]);
        // Number of leading columns expected to lie in the primary null space.
        let low_degree_cols = SPHERE_UNPENALIZED_LOW_DEGREE * (SPHERE_UNPENALIZED_LOW_DEGREE + 2);
        for col in 0..primary.ncols() {
            let primary_diag = primary[[col, col]].abs();
            let shrink_diag = shrink[[col, col]].abs();
            if col >= low_degree_cols {
                assert!(
                    primary_diag > 0.0,
                    "higher-degree column {col} must carry roughness"
                );
                assert!(
                    shrink_diag <= 1e-12,
                    "higher-degree column {col} must not be in null shrinkage"
                );
            } else {
                assert!(
                    primary_diag <= 1e-12,
                    "low-degree column {col} must be primary-null"
                );
                assert!(
                    shrink_diag > 0.0,
                    "low-degree column {col} must be shrink-penalized"
                );
            }
        }
    }
}