use super::*;
/// Cache key for a memoised constraint null-space basis.
///
/// A basis is identified by the centers' shape, a 64-bit content hash of the centers (computed
/// by `hash_arrayview2`), and the polynomial order the null space was built for. Lookups trust
/// the hash: no element-wise comparison is performed, so two distinct center sets with the same
/// shape, order and hash would share an entry.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub(crate) struct ConstraintNullspaceCacheKey {
    // Number of center rows.
    pub(crate) centersrows: usize,
    // Number of center columns (spatial dimension).
    pub(crate) centers_cols: usize,
    // `hash_arrayview2` of the centers (shape + exact element bit patterns).
    pub(crate) centers_hash: u64,
    // Which polynomial block the null space was computed for.
    pub(crate) order: ConstraintNullspaceOrderKey,
}
/// Distinguishes the polynomial block behind a cached constraint null space.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub(crate) enum ConstraintNullspaceOrderKey {
    // Duchon polynomial block at the *effective* order, i.e. after any degradation applied by
    // `duchon_effective_nullspace_order`.
    Duchon(DuchonNullspaceOrder),
    // Thin-plate polynomial block (its degree is implied by the center dimension).
    ThinPlate,
}
/// Bounded memo of constraint null-space bases.
///
/// `map` holds the cached bases. `order` records keys in insertion order so that, once `map`
/// exceeds [`CONSTRAINT_NULLSPACE_CACHE_MAX_ENTRIES`], the oldest entries are evicted first
/// (FIFO; cache hits do not refresh an entry's position).
#[derive(Default, Clone, Debug)]
pub(crate) struct ConstraintNullspaceCache {
    // Cached null-space bases, shared via `Arc` so cloning the cache is cheap.
    pub(crate) map: HashMap<ConstraintNullspaceCacheKey, Arc<Array2<f64>>>,
    // Insertion order of the keys in `map`, oldest first.
    pub(crate) order: Vec<ConstraintNullspaceCacheKey>,
}
/// Maximum number of null-space bases kept in a [`ConstraintNullspaceCache`] before the oldest
/// are evicted.
pub(crate) const CONSTRAINT_NULLSPACE_CACHE_MAX_ENTRIES: usize = 32;
/// Identity key for a borrowed 2-D `f64` view: its shape, base pointer and strides.
///
/// This identifies a *memory location and layout*, not the contents stored there. The same key
/// can therefore refer to different data if the buffer is mutated in place, or freed and its
/// address reused.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub(crate) struct OwnedDataCacheKey {
    // Number of rows of the view.
    pub(crate) rows: usize,
    // Number of columns of the view.
    pub(crate) cols: usize,
    // Address of the view's first element.
    pub(crate) ptr: usize,
    // Element stride along axis 0.
    pub(crate) stride0: isize,
    // Element stride along axis 1.
    pub(crate) stride1: isize,
}
/// Caches held by a [`BasisWorkspace`] and reused across basis constructions.
#[derive(Debug)]
pub(crate) struct BasisCacheContext {
    /// Memoised constraint null-space bases, evicted oldest-first once more than
    /// `CONSTRAINT_NULLSPACE_CACHE_MAX_ENTRIES` are stored.
    pub(crate) constraint_nullspace: ConstraintNullspaceCache,
    /// Shared owned copies of input data matrices, sized from the policy's
    /// `max_owned_data_cache_bytes` and `OWNED_DATA_CACHE_MAX_ENTRIES`.
    pub(crate) owned_data:
        crate::solver::resource::ByteLruCache<OwnedDataCacheKey, Arc<Array2<f64>>>,
}
impl BasisCacheContext {
    /// Creates empty caches whose owned-data budget is taken from `policy`.
    pub(crate) fn with_policy(policy: &crate::solver::resource::ResourcePolicy) -> Self {
        use crate::solver::resource::{ByteLruCache, OWNED_DATA_CACHE_MAX_ENTRIES};
        let owned_data = ByteLruCache::with_max_entries(
            policy.max_owned_data_cache_bytes,
            OWNED_DATA_CACHE_MAX_ENTRIES,
        );
        Self {
            constraint_nullspace: ConstraintNullspaceCache::default(),
            owned_data,
        }
    }
}
impl Default for BasisCacheContext {
    /// Builds the caches under `ResourcePolicy::default_library()`.
    fn default() -> Self {
        let library_policy = crate::solver::resource::ResourcePolicy::default_library();
        Self::with_policy(&library_policy)
    }
}
/// Reusable state for basis construction: a resource policy plus caches sized from it.
///
/// Passing the same workspace to several basis builders lets them share memoised results.
#[derive(Debug)]
pub struct BasisWorkspace {
    pub(crate) cache: BasisCacheContext,
    pub(crate) policy: crate::solver::resource::ResourcePolicy,
}
impl BasisWorkspace {
    /// Creates a workspace with the library-default resource policy.
    pub fn new() -> Self {
        Self::default_library()
    }
    /// Creates a workspace governed by `policy`, with caches sized from it.
    pub fn with_policy(policy: crate::solver::resource::ResourcePolicy) -> Self {
        let cache = BasisCacheContext::with_policy(&policy);
        BasisWorkspace { cache, policy }
    }
    /// Creates a workspace using `ResourcePolicy::default_library()`.
    pub fn default_library() -> Self {
        let library_policy = crate::solver::resource::ResourcePolicy::default_library();
        Self::with_policy(library_policy)
    }
    /// The resource policy this workspace enforces.
    pub fn policy(&self) -> &crate::solver::resource::ResourcePolicy {
        &self.policy
    }
}
impl Default for BasisWorkspace {
    /// Equivalent to [`BasisWorkspace::default_library`].
    fn default() -> Self {
        BasisWorkspace::default_library()
    }
}
/// Hashes a 2-D `f64` view from its shape and the exact bit pattern of every element.
///
/// Elements are visited in the view's logical iteration order, so the result does not depend
/// on memory layout or strides. Because raw bits are hashed, `-0.0` and `0.0` (and distinct NaN
/// payloads) hash differently.
pub(crate) fn hash_arrayview2(values: ArrayView2<'_, f64>) -> u64 {
    let mut state = DefaultHasher::new();
    // Shape goes in first so equal flattened data with different shapes hashes differently.
    (values.nrows(), values.ncols()).hash(&mut state);
    values
        .iter()
        .map(|x| x.to_bits())
        .for_each(|bits| bits.hash(&mut state));
    state.finish()
}
/// Returns a shared, owned copy of `data`, reusing the copy stored in `cache` when it still
/// matches.
///
/// The cache is keyed by the view's address, shape and strides (`OwnedDataCacheKey`). A key hit
/// alone does not prove the cached copy is current, for two reasons:
/// - the caller may have mutated the buffer in place;
/// - the buffer may have been freed and its address reused by the allocator.
///
/// Every hit is therefore compared bit-for-bit with `data` before being returned. On a mismatch
/// a fresh copy is made and inserted under the same key.
///
/// The second lookup after copying lets a concurrent caller that inserted first win, so both
/// share one allocation.
pub(crate) fn shared_owned_data_matrix(
    data: ArrayView2<'_, f64>,
    cache: &BasisCacheContext,
) -> Arc<Array2<f64>> {
    let key = OwnedDataCacheKey {
        rows: data.nrows(),
        cols: data.ncols(),
        ptr: data.as_ptr() as usize,
        stride0: data.strides()[0],
        stride1: data.strides()[1],
    };
    // Bitwise comparison: signed zeros and NaN payloads must match exactly, and NaN inputs
    // (which never compare equal as floats) can still be served from the cache.
    let still_current = |candidate: &Array2<f64>| {
        candidate.shape() == data.shape()
            && candidate
                .iter()
                .zip(data.iter())
                .all(|(cached, live)| cached.to_bits() == live.to_bits())
    };
    if let Some(hit) = cache.owned_data.get(&key) {
        if still_current(&*hit) {
            return hit;
        }
    }
    let owned = Arc::new(data.to_owned());
    if let Some(hit) = cache.owned_data.get(&key) {
        if still_current(&*hit) {
            return hit;
        }
    }
    cache.owned_data.insert(key, Arc::clone(&owned));
    owned
}
/// Copies `data` into a newly allocated matrix behind an `Arc`, without consulting any cache.
#[inline]
pub(crate) fn shared_owned_data_matrix_from_view(data: ArrayView2<'_, f64>) -> Arc<Array2<f64>> {
    let copy: Array2<f64> = data.to_owned();
    Arc::from(copy)
}
/// Copies a centers view into a newly allocated matrix behind an `Arc`.
#[inline]
pub(crate) fn shared_owned_centers_matrix_from_view(
    centers: ArrayView2<'_, f64>,
) -> Arc<Array2<f64>> {
    let copy: Array2<f64> = centers.to_owned();
    Arc::from(copy)
}
/// Returns a null-space basis for the Duchon polynomial block at `centers`, memoised in `cache`.
///
/// `order` is first reduced by `duchon_effective_nullspace_order`, which may lower it when there
/// are too few centers. The cache key uses the effective order.
///
/// # Errors
/// Propagates the error from `kernel_constraint_nullspace_from_matrix`. If the order was
/// degraded, that error is wrapped in an `InvalidInput` that names both the requested and the
/// effective order.
pub(crate) fn kernel_constraint_nullspace(
    centers: ArrayView2<'_, f64>,
    order: DuchonNullspaceOrder,
    cache: &mut BasisCacheContext,
) -> Result<Array2<f64>, BasisError> {
    let effective_order = duchon_effective_nullspace_order(centers, order);
    let degraded = effective_order != order;
    let key = ConstraintNullspaceCacheKey {
        centersrows: centers.nrows(),
        centers_cols: centers.ncols(),
        centers_hash: hash_arrayview2(centers),
        order: ConstraintNullspaceOrderKey::Duchon(effective_order),
    };
    let store = &mut cache.constraint_nullspace;
    if let Some(cached) = store.map.get(&key) {
        return Ok(cached.as_ref().clone());
    }
    let polynomial = polynomial_block_from_order(centers, effective_order);
    let nullspace = match kernel_constraint_nullspace_from_matrix(polynomial.view()) {
        Ok(basis) => basis,
        Err(err) if degraded => {
            return Err(BasisError::InvalidInput(format!(
                "Duchon degraded from order={:?} to order={:?} due to insufficient centers ({} in dim={}); order={:?} construction then failed: {err}",
                order,
                effective_order,
                centers.nrows(),
                centers.ncols(),
                effective_order,
            )));
        }
        Err(err) => return Err(err),
    };
    // `cache` is borrowed exclusively, so nothing can have filled `key` since the miss above.
    store.map.insert(key, Arc::new(nullspace.clone()));
    store.order.push(key);
    // FIFO eviction: drop the oldest inserted keys until the cache is back under its cap.
    while store.map.len() > CONSTRAINT_NULLSPACE_CACHE_MAX_ENTRIES && !store.order.is_empty() {
        let oldest = store.order.remove(0);
        store.map.remove(&oldest);
    }
    Ok(nullspace)
}
pub(crate) fn thin_plate_kernel_constraint_nullspace(
centers: ArrayView2<'_, f64>,
cache: &mut BasisCacheContext,
) -> Result<Array2<f64>, BasisError> {
let key = ConstraintNullspaceCacheKey {
centersrows: centers.nrows(),
centers_cols: centers.ncols(),
centers_hash: hash_arrayview2(centers),
order: ConstraintNullspaceOrderKey::ThinPlate,
};
if let Some(hit) = cache.constraint_nullspace.map.get(&key) {
return Ok((**hit).clone());
}
let p_k = thin_plate_polynomial_block(centers);
if centers.nrows() < p_k.ncols() {
crate::bail_invalid_basis!(
"thin-plate spline requires at least {} centers to span the degree-{} polynomial null space in dimension {}; got {}",
p_k.ncols(),
thin_plate_polynomial_degree(centers.ncols()),
centers.ncols(),
centers.nrows()
);
}
let (z, rank) =
rrqr_nullspace_basis(&p_k, default_rrqr_rank_alpha()).map_err(BasisError::LinalgError)?;
if rank != p_k.ncols() {
crate::bail_invalid_basis!(
"thin-plate spline polynomial block is rank deficient at the selected centers: expected rank {}, got {}; choose geometrically independent centers for dimension {}",
p_k.ncols(),
rank,
centers.ncols()
);
}
let z = Arc::new(z);
if let Some(hit) = cache.constraint_nullspace.map.get(&key) {
return Ok((**hit).clone());
}
cache.constraint_nullspace.map.insert(key, z.clone());
cache.constraint_nullspace.order.push(key);
while cache.constraint_nullspace.map.len() > CONSTRAINT_NULLSPACE_CACHE_MAX_ENTRIES {
if cache.constraint_nullspace.order.is_empty() {
break;
}
let oldkey = cache.constraint_nullspace.order.remove(0);
cache.constraint_nullspace.map.remove(&oldkey);
}
Ok((*z).clone())
}
/// Builds the coefficient reparameterization that enforces the requested Matérn
/// identifiability constraint, or `None` when no constraint applies.
///
/// - `CenterSumToZero`: null space of a single all-ones column over the centers.
/// - `CenterLinearOrthogonal`: null space of the linear (possibly degraded) polynomial block.
/// - `FrozenTransform`: the stored transform, after checking its row count matches the centers.
pub(crate) fn matern_identifiability_transform(
    centers: ArrayView2<'_, f64>,
    identifiability: &MaternIdentifiability,
) -> Result<Option<Array2<f64>>, BasisError> {
    let k = centers.nrows();
    match identifiability {
        MaternIdentifiability::None => Ok(None),
        MaternIdentifiability::FrozenTransform { transform, .. } => {
            if transform.nrows() != k {
                crate::bail_dim_basis!(
                    "frozen Matérn identifiability transform mismatch: centers={k}, transform rows={}",
                    transform.nrows()
                );
            }
            Ok(Some(transform.clone()))
        }
        MaternIdentifiability::CenterSumToZero => {
            let ones_column = Array2::<f64>::ones((k, 1));
            kernel_constraint_nullspace_from_matrix(ones_column.view()).map(Some)
        }
        MaternIdentifiability::CenterLinearOrthogonal => {
            let linear_order =
                duchon_effective_nullspace_order(centers, DuchonNullspaceOrder::Linear);
            let linear_block = polynomial_block_from_order(centers, linear_order);
            kernel_constraint_nullspace_from_matrix(linear_block.view()).map(Some)
        }
    }
}
/// Builds operator-based (collocation) penalty candidates for a Matérn basis at `centers`.
///
/// The collocation operators (`d0`, `d1`, `d2`) come from
/// `build_matern_collocation_operator_matrices`, optionally projected through `z_opt` and in
/// anisotropic coordinates. They are combined according to the Matérn penalty spec for `nu` in
/// the centers' dimension.
pub(crate) fn build_matern_operator_penalty_candidates(
    centers: ArrayView2<'_, f64>,
    length_scale: f64,
    nu: MaternNu,
    include_intercept: bool,
    z_opt: Option<&Array2<f64>>,
    aniso_log_scales: Option<&[f64]>,
) -> Result<Vec<PenaltyCandidate>, BasisError> {
    let projection = z_opt.map(|transform| transform.view());
    let operators = build_matern_collocation_operator_matrices(
        centers,
        None,
        length_scale,
        nu,
        include_intercept,
        projection,
        aniso_log_scales,
    )?;
    let dimension = centers.ncols();
    let spec = DuchonOperatorPenaltySpec::matern_for_smoothness(nu, dimension);
    let candidates = operator_penalty_candidates_from_collocation(
        &operators.d0,
        &operators.d1,
        &operators.d2,
        &spec,
    );
    Ok(candidates)
}
/// Returns the primary penalty candidate, plus a null-space shrinkage candidate when one is
/// built, together with whether the shrinkage penalty survived.
///
/// `frozen = Some(false)` skips the shrinkage build entirely. `Some(true)` and `None` both
/// attempt it, and the candidate is added only if `build_nullspace_shrinkage_penalty` returns
/// one.
pub(crate) fn matern_double_penalty_candidates_with_decision(
    primary: &Array2<f64>,
    frozen: Option<bool>,
) -> Result<(Vec<PenaltyCandidate>, bool), BasisError> {
    let mut candidates = Vec::with_capacity(2);
    candidates.push(normalize_penalty_candidate(
        primary.clone(),
        0,
        PenaltySource::Primary,
    ));
    let attempt_shrinkage = frozen.unwrap_or(true);
    let shrinkage = if attempt_shrinkage {
        build_nullspace_shrinkage_penalty(primary)?
    } else {
        None
    };
    let survived = match shrinkage {
        Some(block) => {
            candidates.push(normalize_penalty_candidate(
                block.sym_penalty,
                0,
                PenaltySource::DoublePenaltyNullspace,
            ));
            true
        }
        None => false,
    };
    Ok((candidates, survived))
}
/// Projects the spline's kernel penalty through `full_transform` and builds its double-penalty
/// candidates, honouring a previously frozen shrinkage decision if one is given.
pub(crate) fn build_matern_double_penalty_candidates(
    spline: &MaternSplineBasis,
    full_transform: Option<&Array2<f64>>,
    frozen_nullspace_shrinkage_survived: Option<bool>,
) -> Result<(Vec<PenaltyCandidate>, bool), BasisError> {
    let projected = project_penalty_matrix(&spline.penalty_kernel, full_transform);
    let decision = frozen_nullspace_shrinkage_survived;
    matern_double_penalty_candidates_with_decision(&projected, decision)
}
/// Builds a dense Matérn radial-basis design and its kernel penalty at `centers`.
///
/// The design has `n = data.nrows()` rows. It has one column per center, holding the Matérn
/// kernel of the (optionally anisotropic) distance between the data row and that center. When
/// `include_intercept` is set, a final column of ones follows.
///
/// `penalty_kernel` is the center-by-center kernel matrix embedded in a
/// `total_cols × total_cols` block; the intercept row and column are zero.
/// `penalty_ridge` is the shrinkage block returned by `build_nullspace_shrinkage_penalty` for
/// that penalty, or all zeros when it returns `None`.
///
/// # Errors
/// - The dense design exceeds the workspace policy's `max_single_materialization_bytes`.
/// - `data` has no columns, or there are no centers.
/// - The column counts of `data` and `centers` differ.
/// - Any input value is non-finite.
/// - The length scale is invalid.
/// - `aniso_log_scales` has the wrong length or contains non-finite values.
/// - A kernel evaluation fails.
pub fn create_matern_spline_basiswithworkspace(
    data: ArrayView2<'_, f64>,
    centers: ArrayView2<'_, f64>,
    length_scale: f64,
    nu: MaternNu,
    include_intercept: bool,
    aniso_log_scales: Option<&[f64]>,
    workspace: &mut BasisWorkspace,
) -> Result<MaternSplineBasis, BasisError> {
    let n = data.nrows();
    let d = data.ncols();
    let k = centers.nrows();
    let total_cols = k + usize::from(include_intercept);
    // The cap accounts for exactly one dense n × total_cols allocation. The kernel values are
    // therefore written straight into the final design below, not into a separate n × k buffer
    // that would roughly double peak memory past the cap.
    let dense_bytes = dense_design_bytes(n, total_cols);
    if dense_bytes > workspace.policy().max_single_materialization_bytes {
        crate::bail_invalid_basis!(
            "Matérn basis dense design exceeds resource policy: n={n}, p={total_cols}, dense={:.1} MiB, cap={:.1} MiB",
            dense_bytes as f64 / (1024.0 * 1024.0),
            workspace.policy().max_single_materialization_bytes as f64 / (1024.0 * 1024.0),
        );
    }
    if d == 0 {
        crate::bail_invalid_basis!("Matérn basis requires at least one covariate dimension");
    }
    if k == 0 {
        crate::bail_invalid_basis!("Matérn basis requires at least one center");
    }
    if centers.ncols() != d {
        crate::bail_dim_basis!(
            "Matérn basis dimension mismatch: data has {d} columns, centers have {}",
            centers.ncols()
        );
    }
    if data.iter().any(|v| !v.is_finite()) || centers.iter().any(|v| !v.is_finite()) {
        crate::bail_invalid_basis!("Matérn basis requires finite data and center values");
    }
    validate_matern_length_scale(length_scale)?;
    if let Some(eta) = aniso_log_scales {
        if eta.len() != d {
            crate::bail_dim_basis!(
                "aniso_log_scales length {} does not match data dimension {d}",
                eta.len()
            );
        }
        if eta.iter().any(|v| !v.is_finite()) {
            crate::bail_invalid_basis!("aniso_log_scales must contain finite values");
        }
    }
    // Diagnostic only: warn (at debug level) when κ = 1/length_scale sits far outside the
    // range suggested by the spread of pairwise center distances.
    let warn_bounds = if let Some(eta) = aniso_log_scales {
        let y_centers = points_in_aniso_y_space(centers, eta);
        pairwise_distance_bounds(y_centers.view())
    } else {
        pairwise_distance_bounds(centers)
    };
    if let Some((r_min, r_max)) = warn_bounds {
        let kappa = 1.0 / length_scale.max(1e-300);
        let kappa_lo = 1e-2 / r_max;
        let kappa_hi = 1e2 / r_min;
        if kappa < kappa_lo || kappa > kappa_hi {
            log::debug!(
                "Matérn κ={} is outside recommended range [{}, {}] derived from centers (r_min={}, r_max={}); kernel conditioning may degrade",
                kappa,
                kappa_lo,
                kappa_hi,
                r_min,
                r_max
            );
        }
    }
    let axis_scales = aniso_log_scales.map(aniso_axis_scales);
    // Fill the design row by row in parallel: kernel columns first, then the intercept column.
    let mut basis = Array2::<f64>::zeros((n, total_cols));
    let fill_result: Result<(), BasisError> = basis
        .axis_iter_mut(Axis(0))
        .into_par_iter()
        .enumerate()
        .try_for_each(|(i, mut row)| {
            for j in 0..k {
                let r = if let Some(scales) = axis_scales.as_deref() {
                    aniso_distance_rows_with_scales(data, i, centers, j, scales)
                } else {
                    euclidean_distance_rows(data, i, centers, j)
                };
                row[j] = matern_kernel_from_distance(r, length_scale, nu)?;
            }
            if include_intercept {
                row[k] = 1.0;
            }
            Ok(())
        });
    fill_result?;
    let mut center_kernel = Array2::<f64>::zeros((k, k));
    fill_symmetric_from_row_kernel(&mut center_kernel, |i, j| {
        let r = if let Some(scales) = axis_scales.as_deref() {
            aniso_distance_rows_with_scales(centers, i, centers, j, scales)
        } else {
            euclidean_distance_rows(centers, i, centers, j)
        };
        matern_kernel_from_distance(r, length_scale, nu)
    })?;
    // The intercept (if any) is left unpenalized: its row and column of the penalty stay zero.
    let mut penalty_kernel = Array2::<f64>::zeros((total_cols, total_cols));
    penalty_kernel
        .slice_mut(s![0..k, 0..k])
        .assign(&center_kernel);
    let penalty_ridge = build_nullspace_shrinkage_penalty(&penalty_kernel)?
        .map(|block| block.sym_penalty)
        .unwrap_or_else(|| Array2::<f64>::zeros((total_cols, total_cols)));
    Ok(MaternSplineBasis {
        basis,
        penalty_kernel,
        penalty_ridge,
        num_kernel_basis: k,
        num_polynomial_basis: usize::from(include_intercept),
        dimension: d,
    })
}
/// Checks that `data` is a non-empty two-column (latitude, longitude) matrix.
///
/// Every coordinate must be finite, and every latitude must lie within ±π/2 (when `radians`) or
/// ±90 (degrees). Longitude is not range-checked. `context` prefixes every error message.
#[inline]
pub(crate) fn validate_lat_lon_matrix(
    data: ArrayView2<'_, f64>,
    context: &str,
    radians: bool,
) -> Result<(), BasisError> {
    if data.ncols() != 2 {
        crate::bail_dim_basis!(
            "{context} requires exactly two columns: latitude and longitude; got {}",
            data.ncols()
        );
    }
    if data.nrows() == 0 {
        crate::bail_invalid_basis!("{context} requires at least one row");
    }
    let (lat_lo, lat_hi, unit) = if !radians {
        (-90.0, 90.0, "degrees")
    } else {
        (
            -std::f64::consts::FRAC_PI_2,
            std::f64::consts::FRAC_PI_2,
            "radians",
        )
    };
    let latitude_range = lat_lo..=lat_hi;
    for (i, point) in data.outer_iter().enumerate() {
        let (lat, lon) = (point[0], point[1]);
        let both_finite = lat.is_finite() && lon.is_finite();
        if !both_finite {
            crate::bail_invalid_basis!(
                "{context} requires finite latitude/longitude; row {i} has ({lat}, {lon})"
            );
        }
        if !latitude_range.contains(&lat) {
            crate::bail_invalid_basis!(
                "{context} latitude must be in [{lat_lo}, {lat_hi}] {unit}; row {i} has {lat}"
            );
        }
    }
    Ok(())
}
/// Evaluates the Sobolev-type spherical Wahba kernel between every data point and every center.
///
/// Shorthand for [`spherical_wahba_kernel_matrix_with_kind`] with
/// `SphereWahbaKernel::Sobolev`.
pub fn spherical_wahba_kernel_matrix(
    data: ArrayView2<'_, f64>,
    centers: ArrayView2<'_, f64>,
    penalty_order: usize,
    radians: bool,
) -> Result<Array2<f64>, BasisError> {
    let kind = SphereWahbaKernel::Sobolev;
    spherical_wahba_kernel_matrix_with_kind(data, centers, penalty_order, radians, kind)
}
/// Evaluates a spherical Wahba-type kernel between every data point and every center.
///
/// `data` and `centers` are (latitude, longitude) matrices checked by `validate_lat_lon_matrix`.
/// Angles are in radians when `radians` is true, otherwise in degrees. The result is an `n × k`
/// matrix (`n = data.nrows()`, `k = centers.nrows()`). Entry `(i, j)` is the kernel of kind
/// `kernel` and order `penalty_order`, evaluated at the cosine of the great-circle angle between
/// data row `i` and center `j`.
///
/// # Errors
/// Returns an error when either input fails lat/lon validation, or when a kernel evaluation is
/// flagged: a non-finite SIMD lane, or an `Err` from the scalar kernel.
pub fn spherical_wahba_kernel_matrix_with_kind(
    data: ArrayView2<'_, f64>,
    centers: ArrayView2<'_, f64>,
    penalty_order: usize,
    radians: bool,
    kernel: SphereWahbaKernel,
) -> Result<Array2<f64>, BasisError> {
    validate_lat_lon_matrix(data, "spherical spline data", radians)?;
    validate_lat_lon_matrix(centers, "spherical spline centers", radians)?;
    let n = data.nrows();
    let k = centers.nrows();
    // Multiplier that converts the input angle unit to radians.
    let deg = if radians {
        1.0
    } else {
        std::f64::consts::PI / 180.0
    };
    // Precompute the per-center trig terms once; every data row reuses them.
    let mut sin_lat_c = Vec::<f64>::with_capacity(k);
    let mut cos_lat_c = Vec::<f64>::with_capacity(k);
    let mut sin_lon_c = Vec::<f64>::with_capacity(k);
    let mut cos_lon_c = Vec::<f64>::with_capacity(k);
    for c in centers.outer_iter() {
        let lat = c[0] * deg;
        let lon = c[1] * deg;
        let (s_lat, c_lat) = lat.sin_cos();
        let (s_lon, c_lon) = lon.sin_cos();
        sin_lat_c.push(s_lat);
        cos_lat_c.push(c_lat);
        sin_lon_c.push(s_lon);
        cos_lon_c.push(c_lon);
    }
    let mut out = Array2::<f64>::zeros((n, k));
    // Set by any worker that hits a failed or non-finite kernel value; checked after the
    // parallel loop.
    let err_flag = std::sync::atomic::AtomicBool::new(false);
    // Rows are processed in parallel blocks of 256. `chunk_idx * 256` below recovers each
    // block's first global row index, so both literals must stay equal.
    out.axis_chunks_iter_mut(ndarray::Axis(0), 256)
        .into_par_iter()
        .enumerate()
        .for_each(|(chunk_idx, mut block)| {
            use wide::f64x4;
            let row_offset = chunk_idx * 256;
            // Centers are handled four at a time with SIMD; the `k % 4` remainder goes through
            // the scalar kernel.
            let chunks = k / 4;
            let tail = k % 4;
            for (local_i, mut out_row) in block.outer_iter_mut().enumerate() {
                let i = row_offset + local_i;
                let lat = data[(i, 0)] * deg;
                let lon = data[(i, 1)] * deg;
                let (sin_lat, cos_lat) = lat.sin_cos();
                let (sin_lon, cos_lon) = lon.sin_cos();
                let sin_lat_v = f64x4::from(sin_lat);
                let cos_lat_v = f64x4::from(cos_lat);
                let sin_lon_v = f64x4::from(sin_lon);
                let cos_lon_v = f64x4::from(cos_lon);
                for cidx in 0..chunks {
                    let base = cidx * 4;
                    let sl_c = f64x4::from([
                        sin_lat_c[base],
                        sin_lat_c[base + 1],
                        sin_lat_c[base + 2],
                        sin_lat_c[base + 3],
                    ]);
                    let cl_c = f64x4::from([
                        cos_lat_c[base],
                        cos_lat_c[base + 1],
                        cos_lat_c[base + 2],
                        cos_lat_c[base + 3],
                    ]);
                    let sn_c = f64x4::from([
                        sin_lon_c[base],
                        sin_lon_c[base + 1],
                        sin_lon_c[base + 2],
                        sin_lon_c[base + 3],
                    ]);
                    let cn_c = f64x4::from([
                        cos_lon_c[base],
                        cos_lon_c[base + 1],
                        cos_lon_c[base + 2],
                        cos_lon_c[base + 3],
                    ]);
                    // Spherical law of cosines:
                    //   cos γ = sin φ_i sin φ_j + cos φ_i cos φ_j cos(λ_i − λ_j),
                    // with cos(λ_i − λ_j) expanded as cos λ_i cos λ_j + sin λ_i sin λ_j.
                    // NOTE(review): cos_gamma is not clamped to [-1, 1] here, and rounding can
                    // push it marginally outside; confirm the kernel helpers tolerate that.
                    let dlon_cos = cos_lon_v * cn_c + sin_lon_v * sn_c;
                    let cos_gamma = sin_lat_v * sl_c + cos_lat_v * cl_c * dlon_cos;
                    let vals =
                        wahba_sphere_kernel_from_cos_simd_kind(cos_gamma, penalty_order, kernel);
                    let arr = vals.to_array();
                    for lane in 0..4 {
                        if !arr[lane].is_finite() {
                            err_flag.store(true, std::sync::atomic::Ordering::Relaxed);
                            // Abandons the rest of this 256-row block. The flag turns the whole
                            // call into an error after the loop, so partial output is never
                            // returned.
                            return;
                        }
                        out_row[base + lane] = arr[lane];
                    }
                }
                let tail_start = chunks * 4;
                for t in 0..tail {
                    let j = tail_start + t;
                    let dlon_cos = cos_lon * cos_lon_c[j] + sin_lon * sin_lon_c[j];
                    let cos_gamma = sin_lat * sin_lat_c[j] + cos_lat * cos_lat_c[j] * dlon_cos;
                    // Unlike the SIMD lanes, only `Err` counts as failure here; an `Ok` value is
                    // stored without a finiteness check.
                    match wahba_sphere_kernel_from_cos_kind(cos_gamma, penalty_order, kernel) {
                        Ok(v) => out_row[j] = v,
                        Err(_) => {
                            err_flag.store(true, std::sync::atomic::Ordering::Relaxed);
                            return;
                        }
                    }
                }
            }
        });
    // All workers have joined once `for_each` returns, so a Relaxed load observes every store.
    if err_flag.load(std::sync::atomic::Ordering::Relaxed) {
        crate::bail_invalid_basis!("spherical spline kernel produced a non-finite value");
    }
    Ok(out)
}
/// Builds a reparameterization `Z` whose columns span the coefficient vectors orthogonal to
/// `weights`, i.e. a weighted sum-to-zero constraint.
///
/// The constraint is the single column `weights / ‖weights‖₂`. `Z` is its null-space basis from
/// `rrqr_nullspace_basis`, which for a rank-1 constraint is presumably `k × (k − 1)`; confirm
/// against that helper.
///
/// # Errors
/// - Fewer than two weights.
/// - Any weight is non-finite or negative, or all weights are zero.
/// - The RRQR factorization fails.
/// - The constraint leaves no free coefficients.
pub(crate) fn weighted_coefficient_sum_to_zero_transform(
    weights: ArrayView1<'_, f64>,
) -> Result<Array2<f64>, BasisError> {
    let k = weights.len();
    if k < 2 {
        return Err(BasisError::InsufficientColumnsForConstraint { found: k });
    }
    if weights.iter().any(|w| !w.is_finite() || *w < 0.0) {
        crate::bail_invalid_basis!(
            "sphere coefficient constraint weights must be finite and non-negative"
        );
    }
    // `hypot` accumulates the Euclidean norm without squaring each weight. Squaring would
    // overflow to +inf for very large weights, zeroing the normalized constraint so it is
    // silently dropped. It would also underflow to 0 for very small weights and trigger a
    // spurious "all zero" error.
    let norm = weights.iter().fold(0.0_f64, |acc, &w| acc.hypot(w));
    if norm <= 0.0 {
        crate::bail_invalid_basis!("sphere coefficient constraint weights cannot all be zero");
    }
    let c = Array2::from_shape_vec((k, 1), weights.iter().map(|w| *w / norm).collect())
        .map_err(|e| BasisError::InvalidInput(format!("invalid sphere constraint weights: {e}")))?;
    let (z, rank) =
        rrqr_nullspace_basis(&c, default_rrqr_rank_alpha()).map_err(BasisError::LinalgError)?;
    if rank >= k {
        return Err(BasisError::ConstraintNullspaceCollapsed {
            site: "weighted_coefficient_sum_to_zero_transform",
            cross_rank: rank,
            coeff_dim: k,
            cross_frobenius: 1.0,
            constrained_gram_max_eigenvalue: f64::NAN,
            constrained_gram_min_eigenvalue: f64::NAN,
            spectral_tolerance: f64::NAN,
        });
    }
    Ok(z)
}
pub(crate) fn sphere_area_weights(centers: ArrayView2<'_, f64>, radians: bool) -> Array1<f64> {
let to_rad = if radians {
1.0
} else {
std::f64::consts::PI / 180.0
};
Array1::from_iter(
centers
.outer_iter()
.map(|row| (row[0] * to_rad).cos().max(0.0)),
)
}
#[inline]
pub(crate) fn spherical_chord_distance2(
a: ArrayView1<'_, f64>,
b: ArrayView1<'_, f64>,
radians: bool,
) -> f64 {
let to_rad = if radians {
1.0
} else {
std::f64::consts::PI / 180.0
};
let lat_a = a[0] * to_rad;
let lon_a = a[1] * to_rad;
let lat_b = b[0] * to_rad;
let lon_b = b[1] * to_rad;
let cos_gamma = lat_a.sin() * lat_b.sin() + lat_a.cos() * lat_b.cos() * (lon_a - lon_b).cos();
2.0 * (1.0 - cos_gamma.clamp(-1.0, 1.0))
}
/// Selects `num_centers` rows of a (latitude, longitude) matrix by greedy farthest-point
/// sampling under squared chord distance on the sphere.
///
/// The seed is the lexicographically smallest (latitude, longitude) row; the first such row
/// wins ties. Each later pick is the unchosen row whose distance to the nearest chosen center is
/// largest, with ties broken toward the lower row index. The result is deterministic and its
/// rows are in selection order.
///
/// # Errors
/// Returns an error if `data` fails lat/lon validation, `num_centers` is zero, or `num_centers`
/// exceeds the number of rows.
pub fn select_spherical_farthest_point_centers(
    data: ArrayView2<'_, f64>,
    num_centers: usize,
    radians: bool,
) -> Result<Array2<f64>, BasisError> {
    validate_lat_lon_matrix(data, "spherical farthest-point centers", radians)?;
    let n = data.nrows();
    if num_centers == 0 {
        crate::bail_invalid_basis!("spherical farthest-point center count must be positive");
    }
    if num_centers > n {
        crate::bail_invalid_basis!(
            "requested {} spherical centers but only {} rows are available",
            num_centers,
            n
        );
    }
    // `min_by` keeps the first of equal minima, matching a strict-less linear scan. Validation
    // guarantees finite values, so `partial_cmp` never returns `None`.
    let seed_idx = (0..n)
        .min_by(|&p, &q| {
            let (lat_p, lon_p) = (data[[p, 0]], data[[p, 1]]);
            let (lat_q, lon_q) = (data[[q, 0]], data[[q, 1]]);
            lat_p
                .partial_cmp(&lat_q)
                .unwrap_or(std::cmp::Ordering::Equal)
                .then(lon_p.partial_cmp(&lon_q).unwrap_or(std::cmp::Ordering::Equal))
        })
        .unwrap_or(0);
    let mut chosen = vec![false; n];
    let mut selected = Vec::with_capacity(num_centers);
    // Squared distance from each row to its nearest already-chosen center.
    let mut min_dist2 = vec![f64::INFINITY; n];
    selected.push(seed_idx);
    chosen[seed_idx] = true;
    min_dist2.par_iter_mut().enumerate().for_each(|(i, slot)| {
        *slot = spherical_chord_distance2(data.row(i), data.row(seed_idx), radians);
    });
    min_dist2[seed_idx] = 0.0;
    while selected.len() < num_centers {
        let farthest = min_dist2
            .par_iter()
            .enumerate()
            .filter(|&(i, _)| !chosen[i])
            .map(|(i, &d2)| (i, d2))
            .reduce_with(|best, cand| {
                let farther = cand.1 > best.1;
                let tie_lower_index = cand.1 == best.1 && cand.0 < best.0;
                if farther || tie_lower_index {
                    cand
                } else {
                    best
                }
            });
        let Some((next_idx, _)) = farthest else {
            break;
        };
        selected.push(next_idx);
        chosen[next_idx] = true;
        min_dist2.par_iter_mut().enumerate().for_each(|(i, slot)| {
            if !chosen[i] {
                let d2 = spherical_chord_distance2(data.row(i), data.row(next_idx), radians);
                if d2 < *slot {
                    *slot = d2;
                }
            }
        });
    }
    let centers = Array2::from_shape_fn((selected.len(), 2), |(r, c)| data[[selected[r], c]]);
    Ok(centers)
}
/// Suggests a row-chunk size for streaming a dense `n_rows × n_basis_cols` `f64` design.
///
/// Returns `None` when either dimension is zero, or when the whole design fits in 1 GiB (no
/// streaming needed). Otherwise returns the number of rows that fill about 256 MiB, raised to at
/// least 1024 rows and capped at `n_rows`. Byte counts saturate rather than overflow.
pub fn auto_streaming_chunk_size_for_dense(n_rows: usize, n_basis_cols: usize) -> Option<usize> {
    const DENSE_THRESHOLD_BYTES: usize = 1 << 30;
    const TARGET_CHUNK_BYTES: usize = 1 << 28;
    const MIN_CHUNK_ROWS: usize = 1024;
    const F64_BYTES: usize = std::mem::size_of::<f64>();
    if n_rows == 0 || n_basis_cols == 0 {
        return None;
    }
    let dense_bytes = n_rows
        .saturating_mul(n_basis_cols)
        .saturating_mul(F64_BYTES);
    (dense_bytes > DENSE_THRESHOLD_BYTES).then(|| {
        let row_bytes = n_basis_cols.saturating_mul(F64_BYTES).max(1);
        let target_rows = TARGET_CHUNK_BYTES / row_bytes;
        std::cmp::min(std::cmp::max(target_rows, MIN_CHUNK_ROWS), n_rows)
    })
}
/// Tries to build the spectrally truncated sphere kernel matrix on the GPU.
///
/// Returns `None`, so the caller falls back to the CPU path, in any of these cases:
/// - the kernel is not a truncated kind, or its `lmax` is zero;
/// - `sphere_kernel_decision` declines the GPU for this problem size;
/// - the lat/lon → xyz conversion fails;
/// - the device build fails;
/// - the device-to-host copy fails.
///
/// Build and copy failures are logged at warn level. On success, returns an
/// `n × m` (`data.nrows() × centers.nrows()`) host matrix.
pub(crate) fn try_build_truncated_sphere_design_gpu(
    data: ArrayView2<'_, f64>,
    centers: ArrayView2<'_, f64>,
    kernel: SphereWahbaKernel,
    penalty_order: usize,
    radians: bool,
) -> Option<Array2<f64>> {
    use crate::terms::basis::sphere_gpu as gpu;
    let (lmax_u16, kind) = match kernel {
        SphereWahbaKernel::SobolevTruncated { lmax } => {
            (lmax, gpu::SphereSpectralKernelKind::Sobolev)
        }
        SphereWahbaKernel::PseudoTruncated { lmax } => {
            (lmax, gpu::SphereSpectralKernelKind::Pseudo)
        }
        // Untruncated kernels have no finite spectral expansion to evaluate on the device.
        SphereWahbaKernel::Sobolev | SphereWahbaKernel::Pseudo => return None,
    };
    let lmax = lmax_u16 as usize;
    if lmax == 0 {
        return None;
    }
    let n = data.nrows();
    let m = centers.nrows();
    let decision = gpu::sphere_kernel_decision(n, m, lmax);
    if !decision.use_gpu {
        return None;
    }
    let data_xyz = gpu::latlon_to_xyz_host(data, radians).ok()?;
    let centers_xyz = gpu::latlon_to_xyz_host(centers, radians).ok()?;
    let coeffs = kind.coefficients(lmax, penalty_order);
    let inputs = gpu::S2KernelBuildInputs {
        n,
        m,
        lmax,
        data_xyz: &data_xyz,
        centers_xyz: &centers_xyz,
        coeffs: &coeffs,
        kind,
        layout: gpu::DeviceMatrixLayout::ColumnMajor,
    };
    let dev = match gpu::build_kernel_matrix_device(inputs) {
        Ok(d) => d,
        Err(err) => {
            log::warn!(
                "sphere GPU kernel build fell back to CPU (n={n}, m={m}, lmax={lmax}): {err}"
            );
            return None;
        }
    };
    match dev.to_host_array() {
        Ok(arr) => {
            log::info!(
                "sphere GPU kernel matrix: n={n} m={m} lmax={lmax} kind={}",
                kind.tag()
            );
            Some(arr)
        }
        Err(err) => {
            log::warn!("sphere GPU dtoh fell back to CPU (n={n}, m={m}, lmax={lmax}): {err}");
            None
        }
    }
}