use super::*;
/// Chooses basis centers from `data` according to `strategy`.
///
/// `Auto` simply delegates to its wrapped strategy. `UserProvided` centers are validated
/// (same column count as `data`, at least one row) and returned as an owned copy. Every other
/// variant dispatches to its dedicated selector with the variant's parameters.
///
/// # Errors
/// A dimension error when user centers and data disagree on column count, an invalid-basis
/// error for an empty user center list, and whatever the delegated selector returns.
pub fn select_centers_by_strategy(
    data: ArrayView2<'_, f64>,
    strategy: &CenterStrategy,
) -> Result<Array2<f64>, BasisError> {
    match strategy {
        CenterStrategy::Auto(delegate) => select_centers_by_strategy(data, delegate.as_ref()),
        CenterStrategy::UserProvided(user_centers) => {
            let center_cols = user_centers.ncols();
            let data_cols = data.ncols();
            if center_cols != data_cols {
                crate::bail_dim_basis!(
                    "user centers have {} columns but data has {}",
                    center_cols,
                    data_cols
                );
            }
            if user_centers.nrows() == 0 {
                crate::bail_invalid_basis!("user-provided center list cannot be empty");
            }
            Ok(user_centers.clone())
        }
        CenterStrategy::EqualMass { num_centers } => select_equal_mass_centers(data, *num_centers),
        CenterStrategy::EqualMassCovarRepresentative { num_centers } => {
            select_equal_mass_covar_representative_centers(data, *num_centers)
        }
        // Farthest-point selection is served by the thin-plate knot picker.
        CenterStrategy::FarthestPoint { num_centers } => {
            select_thin_plate_knots(data, *num_centers)
        }
        CenterStrategy::KMeans {
            num_centers,
            max_iter,
        } => select_kmeans_centers(data, *num_centers, *max_iter),
        CenterStrategy::UniformGrid { points_per_dim } => {
            select_uniform_grid_centers(data, *points_per_dim)
        }
    }
}
pub fn build_bspline_basis_1d(
data: ArrayView1<'_, f64>,
spec: &BSplineBasisSpec,
) -> Result<BasisBuildResult, BasisError> {
if let OneDimensionalBoundary::Cyclic { start, end } = spec.boundary
&& end <= start
{
return Err(BasisError::InvalidRange(start, end));
}
let (spec_owned, auto_shrink_note) = maybe_auto_shrink_bspline_spec(spec, data.len());
let spec = &spec_owned;
let periodic_build = match &spec.knotspec {
BSplineKnotSpec::PeriodicUniform {
data_range,
num_basis,
} => {
if let Some((boundary_start, boundary_end, _)) = spec.boundary.period() {
let scale = (boundary_end - boundary_start).abs().max(1.0);
let tol = 1e-12 * scale;
if (data_range.0 - boundary_start).abs() > tol
|| (data_range.1 - boundary_end).abs() > tol
{
crate::bail_invalid_basis!(
"periodic B-spline knot range ({}, {}) conflicts with cyclic boundary ({}, {})",
data_range.0,
data_range.1,
boundary_start,
boundary_end
);
}
}
Some((data_range.0, data_range.1, *num_basis))
}
_ => spec.boundary.period().map(|(start, end, _)| {
let num_basis = match &spec.knotspec {
BSplineKnotSpec::Generate {
num_internal_knots, ..
} => num_internal_knots + spec.degree + 1,
BSplineKnotSpec::Automatic {
num_internal_knots, ..
} => {
num_internal_knots.unwrap_or_else(|| {
default_internal_knot_count_for_data(data.len(), spec.degree)
}) + spec.degree
+ 1
}
BSplineKnotSpec::Provided(knots) => knots.len().saturating_sub(spec.degree + 1),
BSplineKnotSpec::PeriodicUniform { .. } => {
assert!(
false,
"PeriodicUniform knotspec should have been handled by the outer match arm"
);
0
}
};
(start, end, num_basis)
}),
};
if let Some((start, end, num_basis)) = periodic_build {
if spec.degree != 3 {
crate::bail_invalid_basis!(
"cyclic P-splines currently require cubic degree=3, got degree={}",
spec.degree
);
}
if !spec.boundary_conditions.is_free() {
crate::bail_invalid_basis!(
"periodic B-splines cannot also declare endpoint boundary conditions"
);
}
let knots = cyclic_uniform_knot_vector(start, end, spec.degree, num_basis);
let s_bend_raw = create_cyclic_difference_penalty_matrix(num_basis, spec.penalty_order)?;
let penalties_raw = vec![PenaltyCandidate {
matrix: s_bend_raw.clone(),
nullspace_dim_hint: 1,
source: PenaltySource::Primary,
normalization_scale: 1.0,
kronecker_factors: None,
op: None,
}];
let penalties_raw_mats = penalties_raw
.iter()
.map(|candidate| candidate.matrix.clone())
.collect();
let auto_chunk = auto_streaming_chunk_size_for_dense(data.len(), num_basis);
let (design, transformed_candidates, identifiability_transform) =
if let Some(chunk) = auto_chunk {
log::info!(
"B-spline basis auto-streaming evaluator: n={} p={} chunk_size={}",
data.len(),
num_basis,
chunk,
);
build_streaming_bspline_design_and_candidates(
data,
&knots,
spec.degree,
Some((start, end - start, num_basis)),
&spec.identifiability,
penalties_raw,
penalties_raw_mats,
Some(chunk),
)?
} else {
let (basis, _) =
create_cyclic_bspline_basis_dense(data, start, end, spec.degree, num_basis)?;
let (design_c, penalty_mats, identifiability_transform) =
apply_bspline_identifiability_policy(
basis,
penalties_raw_mats,
&knots,
spec.degree,
&spec.identifiability,
)?;
let transformed_candidates = penalty_mats
.into_iter()
.zip(penalties_raw)
.map(|(matrix, candidate)| PenaltyCandidate {
nullspace_dim_hint: candidate.nullspace_dim_hint,
matrix,
source: candidate.source,
normalization_scale: candidate.normalization_scale,
kronecker_factors: None,
op: None,
})
.collect();
(
DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(design_c)),
transformed_candidates,
identifiability_transform,
)
};
let (penalties, nullspace_dims, penaltyinfo, null_eigenvectors, ops) =
filter_active_penalty_candidates_with_ops(transformed_candidates)?;
return Ok(BasisBuildResult {
design,
penalties,
nullspace_dims,
penaltyinfo,
metadata: BasisMetadata::BSpline1D {
knots,
identifiability_transform,
periodic: Some((start, end - start, num_basis)),
degree: Some(spec.degree),
auto_shrink_note: auto_shrink_note.clone(),
},
kronecker_factored: None,
ops,
null_eigenvectors,
joint_null_rotation: None,
});
}
let auto_chunk_streaming = {
let knots_for_estimate = match &spec.knotspec {
BSplineKnotSpec::Generate {
data_range,
num_internal_knots,
} => Some(internal::generate_full_knot_vector(
*data_range,
*num_internal_knots,
spec.degree,
)?),
BSplineKnotSpec::Provided(knots) => Some(knots.clone()),
BSplineKnotSpec::Automatic {
num_internal_knots,
placement,
} => {
let inferred = num_internal_knots.unwrap_or_else(|| {
default_internal_knot_count_for_data(data.len(), spec.degree)
});
Some(match placement {
BSplineKnotPlacement::Uniform => {
let range = finite_data_range(data)?;
internal::generate_full_knot_vector(range, inferred, spec.degree)?
}
BSplineKnotPlacement::Quantile => {
internal::generate_full_knot_vector_quantile(data, inferred, spec.degree)?
}
})
}
BSplineKnotSpec::PeriodicUniform { .. } => None,
};
match knots_for_estimate {
Some(knots_est) => {
let p_raw_est = knots_est
.len()
.checked_sub(spec.degree + 1)
.ok_or_else(|| {
BasisError::InvalidInput(
"invalid B-spline knot/degree combination".to_string(),
)
})?;
auto_streaming_chunk_size_for_dense(data.len(), p_raw_est)
.map(|chunk| (knots_est, p_raw_est, chunk))
}
None => None,
}
};
if let Some((knots, p_raw, chunk)) = auto_chunk_streaming {
let greville_for_penalty = penalty_greville_abscissae_for_knots(&knots, spec.degree)?;
let s_bend_raw = create_difference_penalty_matrix(
p_raw,
spec.penalty_order,
greville_for_penalty.as_ref().map(|g| g.view()),
)?;
let mut penalties_raw = vec![PenaltyCandidate {
matrix: s_bend_raw.clone(),
nullspace_dim_hint: 0,
source: PenaltySource::Primary,
normalization_scale: 1.0,
kronecker_factors: None,
op: None,
}];
if spec.double_penalty {
penalties_raw.push(PenaltyCandidate {
matrix: build_nullspace_shrinkage_penalty(&s_bend_raw)?
.map(|shrink| shrink.sym_penalty)
.unwrap_or_else(|| Array2::<f64>::zeros(s_bend_raw.raw_dim())),
nullspace_dim_hint: 0,
source: PenaltySource::DoublePenaltyNullspace,
normalization_scale: 1.0,
kronecker_factors: None,
op: None,
});
}
let penalties_raw_mats = penalties_raw
.iter()
.map(|candidate| candidate.matrix.clone())
.collect();
log::info!(
"B-spline basis auto-streaming evaluator: n={} p={} chunk_size={}",
data.len(),
p_raw,
chunk,
);
let (design, transformed_candidates, identifiability_transform) =
build_streaming_bspline_design_and_candidates(
data,
&knots,
spec.degree,
None,
&spec.identifiability,
penalties_raw,
penalties_raw_mats,
Some(chunk),
)?;
let (penalties, nullspace_dims, penaltyinfo, null_eigenvectors, ops) =
filter_active_penalty_candidates_with_ops(transformed_candidates)?;
return Ok(BasisBuildResult {
design,
penalties,
nullspace_dims,
penaltyinfo,
metadata: BasisMetadata::BSpline1D {
knots,
identifiability_transform,
periodic: None,
degree: Some(spec.degree),
auto_shrink_note: auto_shrink_note.clone(),
},
kronecker_factored: None,
ops,
null_eigenvectors,
joint_null_rotation: None,
});
}
let prefer_sparse_design = spec.boundary_conditions.is_free()
&& matches!(
spec.identifiability,
BSplineIdentifiability::None | BSplineIdentifiability::WeightedSumToZero { .. }
);
let (design_sparse_opt, design_dense_opt, knots) = if prefer_sparse_design {
match &spec.knotspec {
BSplineKnotSpec::Generate {
data_range,
num_internal_knots,
} => {
let (basis, knots) = create_basis::<Sparse>(
data,
KnotSource::Generate {
data_range: *data_range,
num_internal_knots: *num_internal_knots,
},
spec.degree,
BasisOptions::value(),
)?;
(Some(basis), None, knots)
}
BSplineKnotSpec::Provided(knots) => {
let (basis, knots) = create_basis::<Sparse>(
data,
KnotSource::Provided(knots.view()),
spec.degree,
BasisOptions::value(),
)?;
(Some(basis), None, knots)
}
BSplineKnotSpec::PeriodicUniform { .. } => {
crate::bail_invalid_basis!(
"periodic B-spline must be handled before storage selection; \
this branch is reserved for non-periodic knot specs"
.to_string(),
);
}
BSplineKnotSpec::Automatic {
num_internal_knots,
placement,
} => {
let inferred = num_internal_knots.unwrap_or_else(|| {
default_internal_knot_count_for_data(data.len(), spec.degree)
});
let knots = match placement {
BSplineKnotPlacement::Uniform => {
let range = finite_data_range(data)?;
internal::generate_full_knot_vector(range, inferred, spec.degree)?
}
BSplineKnotPlacement::Quantile => {
internal::generate_full_knot_vector_quantile(data, inferred, spec.degree)?
}
};
let (basis, knots) = create_basis::<Sparse>(
data,
KnotSource::Provided(knots.view()),
spec.degree,
BasisOptions::value(),
)?;
(Some(basis), None, knots)
}
}
} else {
match &spec.knotspec {
BSplineKnotSpec::Generate {
data_range,
num_internal_knots,
} => {
let (basis, knots) = create_basis::<Dense>(
data,
KnotSource::Generate {
data_range: *data_range,
num_internal_knots: *num_internal_knots,
},
spec.degree,
BasisOptions::value(),
)?;
(None, Some((*basis).clone()), knots)
}
BSplineKnotSpec::Provided(knots) => {
let (basis, knots) = create_basis::<Dense>(
data,
KnotSource::Provided(knots.view()),
spec.degree,
BasisOptions::value(),
)?;
(None, Some((*basis).clone()), knots)
}
BSplineKnotSpec::PeriodicUniform { .. } => {
crate::bail_invalid_basis!(
"periodic B-spline must be handled before storage selection; \
this branch is reserved for non-periodic knot specs"
.to_string(),
);
}
BSplineKnotSpec::Automatic {
num_internal_knots,
placement,
} => {
let inferred = num_internal_knots.unwrap_or_else(|| {
default_internal_knot_count_for_data(data.len(), spec.degree)
});
let knots = match placement {
BSplineKnotPlacement::Uniform => {
let range = finite_data_range(data)?;
internal::generate_full_knot_vector(range, inferred, spec.degree)?
}
BSplineKnotPlacement::Quantile => {
internal::generate_full_knot_vector_quantile(data, inferred, spec.degree)?
}
};
let (basis, knots) = create_basis::<Dense>(
data,
KnotSource::Provided(knots.view()),
spec.degree,
BasisOptions::value(),
)?;
(None, Some((*basis).clone()), knots)
}
}
};
let p_raw = design_sparse_opt
.as_ref()
.map(|basis| basis.ncols())
.or_else(|| design_dense_opt.as_ref().map(Array2::ncols))
.expect("B-spline basis should be present");
let greville_for_penalty = penalty_greville_abscissae_for_knots(&knots, spec.degree)?;
let s_bend_raw = create_difference_penalty_matrix(
p_raw,
spec.penalty_order,
greville_for_penalty.as_ref().map(|g| g.view()),
)?;
let mut penalties_raw = vec![PenaltyCandidate {
matrix: s_bend_raw.clone(),
nullspace_dim_hint: 0,
source: PenaltySource::Primary,
normalization_scale: 1.0,
kronecker_factors: None,
op: None,
}];
if spec.double_penalty {
penalties_raw.push(PenaltyCandidate {
matrix: build_nullspace_shrinkage_penalty(&s_bend_raw)?
.map(|shrink| shrink.sym_penalty)
.unwrap_or_else(|| Array2::<f64>::zeros(s_bend_raw.raw_dim())),
nullspace_dim_hint: 0,
source: PenaltySource::DoublePenaltyNullspace,
normalization_scale: 1.0,
kronecker_factors: None,
op: None,
});
}
let penalties_raw_mats: Vec<Array2<f64>> = penalties_raw
.iter()
.map(|candidate| candidate.matrix.clone())
.collect();
let (design, transformed_candidates, identifiability_transform) = if let Some(sparse_basis) =
design_sparse_opt
{
match &spec.identifiability {
BSplineIdentifiability::None => {
let transformed_candidates = penalties_raw
.into_iter()
.map(|candidate| -> Result<PenaltyCandidate, BasisError> {
Ok(PenaltyCandidate {
nullspace_dim_hint: candidate.nullspace_dim_hint,
matrix: candidate.matrix,
source: candidate.source,
normalization_scale: candidate.normalization_scale,
kronecker_factors: None,
op: None,
})
})
.collect::<Result<Vec<_>, _>>()?;
(
DesignMatrix::Sparse(crate::matrix::SparseDesignMatrix::new(sparse_basis)),
transformed_candidates,
None,
)
}
BSplineIdentifiability::WeightedSumToZero { weights } => {
let (constrained_basis, z) = apply_sum_to_zero_constraint_sparse(
&sparse_basis,
weights.as_ref().map(|w| w.view()),
)?;
let transformed_candidates = penalties_raw
.into_iter()
.map(|candidate| -> Result<PenaltyCandidate, BasisError> {
let zt_s = fast_atb(&z, &candidate.matrix);
let matrix = fast_ab(&zt_s, &z);
Ok(PenaltyCandidate {
nullspace_dim_hint: candidate.nullspace_dim_hint,
matrix,
source: candidate.source,
normalization_scale: candidate.normalization_scale,
kronecker_factors: None,
op: None,
})
})
.collect::<Result<Vec<_>, _>>()?;
(
DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(Arc::new(
constrained_basis,
))),
transformed_candidates,
Some(z),
)
}
BSplineIdentifiability::RemoveLinearTrend
| BSplineIdentifiability::OrthogonalToDesignColumns { .. }
| BSplineIdentifiability::FrozenTransform { .. } => {
crate::bail_invalid_basis!(
"sparse B-spline identifiability only supports None or \
WeightedSumToZero; RemoveLinearTrend, \
OrthogonalToDesignColumns, and FrozenTransform require \
the dense path"
.to_string(),
);
}
}
} else {
let (design, penalties, identifiability_transform) = apply_bspline_identifiability_policy(
design_dense_opt.expect("dense B-spline basis should be present"),
penalties_raw_mats,
&knots,
spec.degree,
&spec.identifiability,
)?;
let transformed_candidates = penalties
.into_iter()
.zip(penalties_raw.into_iter())
.map(
|(matrix, candidate)| -> Result<PenaltyCandidate, BasisError> {
Ok(PenaltyCandidate {
nullspace_dim_hint: candidate.nullspace_dim_hint,
matrix,
source: candidate.source,
normalization_scale: candidate.normalization_scale,
kronecker_factors: None,
op: None,
})
},
)
.collect::<Result<Vec<_>, _>>()?;
(
DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(design)),
transformed_candidates,
identifiability_transform,
)
};
let (penalties, nullspace_dims, penaltyinfo, null_eigenvectors, ops) =
filter_active_penalty_candidates_with_ops(transformed_candidates)?;
Ok(BasisBuildResult {
design,
penalties,
nullspace_dims,
penaltyinfo,
metadata: BasisMetadata::BSpline1D {
knots,
identifiability_transform,
periodic: None,
degree: Some(spec.degree),
auto_shrink_note,
},
kronecker_factored: None,
ops,
null_eigenvectors,
joint_null_rotation: None,
})
}
/// Chains a new coefficient transform onto an optional existing one.
///
/// With no prior transform, `next` is returned as-is. Otherwise the product `prev · next` is
/// returned, so applying the result equals applying `prev` and then `next`.
///
/// # Errors
/// A dimension error when `prev.ncols() != next.nrows()`.
pub(crate) fn compose_bspline_transform(
    existing: Option<Array2<f64>>,
    next: Array2<f64>,
) -> Result<Array2<f64>, BasisError> {
    let Some(prev) = existing else {
        return Ok(next);
    };
    if prev.ncols() != next.nrows() {
        crate::bail_dim_basis!(
            "B-spline streaming transform composition mismatch: previous is {}x{}, next is {}x{}",
            prev.nrows(),
            prev.ncols(),
            next.nrows(),
            next.ncols()
        );
    }
    Ok(fast_ab(&prev, &next))
}
/// Derives a sum-to-zero constraint transform from the cross vector `c` (length `k`).
///
/// `c` is treated as a `k × 1` matrix and passed to `rrqr_nullspace_basis`; the returned basis is
/// the transform. If every `|cᵢ| <= 1e-12` the constraint is vacuous and the `k × k` identity is
/// returned.
///
/// # Errors
/// * `InsufficientColumnsForConstraint` when `k < 2`.
/// * `LinalgError` when the rank-revealing QR fails.
/// * `ConstraintNullspaceCollapsed` when the reported rank reaches `k`.
pub(crate) fn bspline_sum_to_zero_transform_from_cross(
    c: &Array1<f64>,
) -> Result<Array2<f64>, BasisError> {
    let dim = c.len();
    if dim < 2 {
        return Err(BasisError::InsufficientColumnsForConstraint { found: dim });
    }
    let largest_magnitude = c.iter().fold(0.0_f64, |acc, v| acc.max(v.abs()));
    if largest_magnitude <= 1e-12 {
        return Ok(Array2::eye(dim));
    }
    let constraint = Array2::<f64>::from_shape_fn((dim, 1), |(row, _)| c[row]);
    let (nullspace, rank) = rrqr_nullspace_basis(&constraint, default_rrqr_rank_alpha())
        .map_err(BasisError::LinalgError)?;
    if rank < dim {
        return Ok(nullspace);
    }
    Err(BasisError::ConstraintNullspaceCollapsed {
        site: "bspline_sum_to_zero_transform_from_cross",
        cross_rank: rank,
        coeff_dim: dim,
        cross_frobenius: c.iter().map(|v| v * v).sum::<f64>().sqrt(),
        constrained_gram_max_eigenvalue: f64::NAN,
        constrained_gram_min_eigenvalue: f64::NAN,
        spectral_tolerance: f64::NAN,
    })
}
/// Evaluates design rows `start..end` of the raw B-spline basis, right-multiplied by
/// `transform` when one is supplied.
///
/// # Errors
/// Propagates failures from `bspline_raw_row_chunk`.
pub(crate) fn streaming_bspline_current_chunk(
    data: ArrayView1<'_, f64>,
    knots: &Array1<f64>,
    degree: usize,
    periodic: Option<(f64, f64, usize)>,
    transform: Option<&Array2<f64>>,
    start: usize,
    end: usize,
) -> Result<Array2<f64>, BasisError> {
    let raw_rows = bspline_raw_row_chunk(data, knots.view(), degree, periodic, start, end)?;
    let chunk = if let Some(z) = transform {
        fast_ab(&raw_rows, z)
    } else {
        raw_rows
    };
    Ok(chunk)
}
pub(crate) fn streaming_bspline_sum_cross(
data: ArrayView1<'_, f64>,
knots: &Array1<f64>,
degree: usize,
periodic: Option<(f64, f64, usize)>,
transform: Option<&Array2<f64>>,
weights: Option<ArrayView1<'_, f64>>,
chunk_size: usize,
) -> Result<Array1<f64>, BasisError> {
if let Some(w) = weights.as_ref()
&& w.len() != data.len()
{
return Err(BasisError::WeightsDimensionMismatch {
expected: data.len(),
found: w.len(),
});
}
let cols = transform.map(Array2::ncols).unwrap_or(
bspline_raw_column_count(knots, degree, periodic).map_err(BasisError::InvalidInput)?,
);
let mut out = Array1::<f64>::zeros(cols);
for start in (0..data.len()).step_by(chunk_size.max(1)) {
let end = (start + chunk_size.max(1)).min(data.len());
let current =
streaming_bspline_current_chunk(data, knots, degree, periodic, transform, start, end)?;
let w_chunk = match weights.as_ref() {
Some(w) => w.slice(s![start..end]).to_owned(),
None => Array1::<f64>::ones(end - start),
};
out += ¤t.t().dot(&w_chunk);
}
Ok(out)
}
pub(crate) fn streaming_bspline_orthogonality_transform(
data: ArrayView1<'_, f64>,
knots: &Array1<f64>,
degree: usize,
periodic: Option<(f64, f64, usize)>,
transform: Option<&Array2<f64>>,
columns: ArrayView2<'_, f64>,
weights: Option<ArrayView1<'_, f64>>,
chunk_size: usize,
) -> Result<Array2<f64>, BasisError> {
if columns.nrows() != data.len() {
return Err(BasisError::ConstraintMatrixRowMismatch {
basisrows: data.len(),
constraintrows: columns.nrows(),
});
}
if let Some(w) = weights.as_ref()
&& w.len() != data.len()
{
return Err(BasisError::WeightsDimensionMismatch {
expected: data.len(),
found: w.len(),
});
}
let cols = transform.map(Array2::ncols).unwrap_or(
bspline_raw_column_count(knots, degree, periodic).map_err(BasisError::InvalidInput)?,
);
if columns.ncols() == 0 {
return Ok(Array2::eye(cols));
}
let mut cross = Array2::<f64>::zeros((cols, columns.ncols()));
let mut gram = Array2::<f64>::zeros((cols, cols));
for start in (0..data.len()).step_by(chunk_size.max(1)) {
let end = (start + chunk_size.max(1)).min(data.len());
let current =
streaming_bspline_current_chunk(data, knots, degree, periodic, transform, start, end)?;
let mut weighted_constraints = columns.slice(s![start..end, ..]).to_owned();
if let Some(w) = weights.as_ref() {
for (mut row, &weight) in weighted_constraints
.axis_iter_mut(Axis(0))
.zip(w.slice(s![start..end]).iter())
{
row *= weight;
}
}
cross += ¤t.t().dot(&weighted_constraints);
gram += &fast_ata(¤t);
}
orthogonality_transform_from_cross_and_gram(&cross, &gram)
}
pub(crate) fn build_streaming_bspline_design_and_candidates(
data: ArrayView1<'_, f64>,
knots: &Array1<f64>,
degree: usize,
periodic: Option<(f64, f64, usize)>,
identifiability: &BSplineIdentifiability,
penalties_raw: Vec<PenaltyCandidate>,
mut penalty_mats: Vec<Array2<f64>>,
chunk_size: Option<usize>,
) -> Result<(DesignMatrix, Vec<PenaltyCandidate>, Option<Array2<f64>>), BasisError> {
let chunk = chunk_size.unwrap_or(DEFAULT_STREAMING_CHUNK_ROWS).max(1);
let mut transform_opt: Option<Array2<f64>> = None;
match identifiability {
BSplineIdentifiability::None => {}
BSplineIdentifiability::WeightedSumToZero { weights } => {
let cross = streaming_bspline_sum_cross(
data,
knots,
degree,
periodic,
transform_opt.as_ref(),
weights.as_ref().map(|w| w.view()),
chunk,
)?;
let z = bspline_sum_to_zero_transform_from_cross(&cross)?;
penalty_mats = penalty_mats
.into_iter()
.map(|s| project_penalty_matrix(&s, Some(&z)))
.collect();
transform_opt = Some(compose_bspline_transform(transform_opt, z)?);
}
BSplineIdentifiability::RemoveLinearTrend => {
let (z, _) = compute_geometric_constraint_transform(knots, degree, 2)?;
penalty_mats = penalty_mats
.into_iter()
.map(|s| project_penalty_matrix(&s, Some(&z)))
.collect();
transform_opt = Some(compose_bspline_transform(transform_opt, z)?);
}
BSplineIdentifiability::OrthogonalToDesignColumns { columns, weights } => {
let z = streaming_bspline_orthogonality_transform(
data,
knots,
degree,
periodic,
transform_opt.as_ref(),
columns.view(),
weights.as_ref().map(|w| w.view()),
chunk,
)?;
penalty_mats = penalty_mats
.into_iter()
.map(|s| project_penalty_matrix(&s, Some(&z)))
.collect();
transform_opt = Some(compose_bspline_transform(transform_opt, z)?);
}
BSplineIdentifiability::FrozenTransform { transform } => {
let raw_cols = transform_opt.as_ref().map(Array2::ncols).unwrap_or(
bspline_raw_column_count(knots, degree, periodic)
.map_err(BasisError::InvalidInput)?,
);
if raw_cols != transform.nrows() {
crate::bail_dim_basis!(
"frozen identifiability transform mismatch: design has {} columns but transform has {} rows",
raw_cols,
transform.nrows()
);
}
let z = transform.clone();
penalty_mats = penalty_mats
.into_iter()
.map(|s| project_penalty_matrix(&s, Some(&z)))
.collect();
transform_opt = Some(compose_bspline_transform(transform_opt, z)?);
}
}
let transformed_candidates = penalty_mats
.into_iter()
.zip(penalties_raw)
.map(|(matrix, candidate)| PenaltyCandidate {
nullspace_dim_hint: candidate.nullspace_dim_hint,
matrix,
source: candidate.source,
normalization_scale: candidate.normalization_scale,
kronecker_factors: None,
op: None,
})
.collect();
let op = StreamingBSplineEvaluator::new(
Arc::new(data.to_owned()),
Arc::new(knots.clone()),
degree,
periodic,
transform_opt.as_ref().map(|z| Arc::new(z.clone())),
chunk_size,
)
.map_err(BasisError::InvalidInput)?;
Ok((
DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(Arc::new(op))),
transformed_candidates,
transform_opt,
))
}
/// Applies a B-spline identifiability policy to a materialised dense design.
///
/// Each non-`None` policy yields a coefficient transform `Z`; the constrained design is `B Z`
/// (or the constrained basis returned by the constraint helper) and every penalty becomes
/// `Zᵀ S Z`. Returns `(design, penalties, Some(Z))`, or the inputs untouched with `None` when no
/// constraint is requested.
///
/// # Errors
/// Propagates constraint failures; a frozen transform whose row count differs from
/// `design.ncols()` is a dimension error.
pub(crate) fn apply_bspline_identifiability_policy(
    design: Array2<f64>,
    penalties: Vec<Array2<f64>>,
    knots: &Array1<f64>,
    degree: usize,
    identifiability: &BSplineIdentifiability,
) -> Result<(Array2<f64>, Vec<Array2<f64>>, Option<Array2<f64>>), BasisError> {
    let (constrained, z) = match identifiability {
        BSplineIdentifiability::None => return Ok((design, penalties, None)),
        BSplineIdentifiability::WeightedSumToZero { weights } => {
            apply_sum_to_zero_constraint(design.view(), weights.as_ref().map(|w| w.view()))?
        }
        BSplineIdentifiability::RemoveLinearTrend => {
            let (trend_free, _) = compute_geometric_constraint_transform(knots, degree, 2)?;
            (fast_ab(&design, &trend_free), trend_free)
        }
        BSplineIdentifiability::OrthogonalToDesignColumns { columns, weights } => {
            applyweighted_orthogonality_constraint(
                design.view(),
                columns.view(),
                weights.as_ref().map(|w| w.view()),
            )?
        }
        BSplineIdentifiability::FrozenTransform { transform } => {
            if design.ncols() != transform.nrows() {
                crate::bail_dim_basis!(
                    "frozen identifiability transform mismatch: design has {} columns but transform has {} rows",
                    design.ncols(),
                    transform.nrows()
                );
            }
            (fast_ab(&design, transform), transform.clone())
        }
    };
    let projected: Vec<Array2<f64>> = penalties
        .iter()
        .map(|s| {
            let zt_s = fast_atb(&z, s);
            fast_ab(&zt_s, &z)
        })
        .collect();
    Ok((constrained, projected, Some(z)))
}
/// Counts eigenvalues of the symmetrised `penalty` whose magnitude is within
/// [`spectral_tolerance`], i.e. its numerical nullity. An empty matrix has nullity 0.
///
/// # Errors
/// A dimension error for non-square input; propagates eigendecomposition failures.
pub(crate) fn estimate_penalty_nullity(penalty: &Array2<f64>) -> Result<usize, BasisError> {
    let (rows, cols) = penalty.dim();
    if rows != cols {
        crate::bail_dim_basis!("penalty matrix must be square when estimating nullspace");
    }
    if rows == 0 {
        return Ok(0);
    }
    let (sym, evals, _) = spectral_summary(penalty)?;
    let tol = spectral_tolerance(&sym, &evals);
    let nullity = evals.iter().filter(|ev| ev.abs() <= tol).count();
    Ok(nullity)
}
/// Spectral diagnostics of a penalty matrix, as produced by `validate_psd_penalty`.
#[derive(Debug, Clone)]
pub(crate) struct PsdSpectralSummary {
    /// Smallest eigenvalue of the symmetrised penalty (`0.0` for an empty matrix).
    pub(crate) min_eigenvalue: f64,
    /// Largest absolute eigenvalue of the symmetrised penalty (`0.0` for an empty matrix).
    pub(crate) max_abs_eigenvalue: f64,
    /// Numerical tolerance from `spectral_tolerance` (`1e-10` for an empty matrix).
    pub(crate) tolerance: f64,
    /// Number of eigenvalues strictly greater than `tolerance`.
    pub(crate) effective_rank: usize,
}
/// Returns a copy of `penalty` with each off-diagonal pair `(i, j)`/`(j, i)` replaced by its
/// average, i.e. `(S + Sᵀ) / 2`; diagonal entries are left untouched.
///
/// Only the leading `min(nrows, ncols)` square block is symmetrised, so non-square input never
/// indexes out of bounds (a tall matrix previously would). Square input is unaffected by this
/// bound.
pub(crate) fn symmetrize_penalty(penalty: &Array2<f64>) -> Array2<f64> {
    let mut sym = penalty.clone();
    let n = sym.nrows().min(sym.ncols());
    for i in 0..n {
        for j in 0..i {
            let v = 0.5 * (sym[[i, j]] + sym[[j, i]]);
            sym[[i, j]] = v;
            sym[[j, i]] = v;
        }
    }
    sym
}
pub(crate) fn project_penalty_to_psd_cone(matrix: &Array2<f64>) -> Array2<f64> {
let sym = symmetrize_penalty(matrix);
let n = sym.nrows();
if n == 0 || n != sym.ncols() {
return sym;
}
let (evals, evecs) = match FaerEigh::eigh(&sym, Side::Lower) {
Ok(pair) => pair,
Err(_) => return sym,
};
if evals.is_empty() {
return sym;
}
let min_ev = evals.iter().copied().fold(f64::INFINITY, f64::min);
if min_ev >= 0.0 {
return sym;
}
let mut clamped = sym.clone();
for i in 0..n {
for j in 0..n {
let mut acc = 0.0_f64;
for k in 0..evals.len() {
let lam = evals[k];
if lam > 0.0 {
acc += lam * evecs[[i, k]] * evecs[[j, k]];
}
}
clamped[[i, j]] = acc;
}
}
for i in 0..n {
for j in 0..i {
let v = 0.5 * (clamped[[i, j]] + clamped[[j, i]]);
clamped[[i, j]] = v;
clamped[[j, i]] = v;
}
}
clamped
}
/// Relative eigenvalue tolerance: `max(n, 1) · 1e-10 · max |λ|`, where `n = sym.nrows()`.
pub(crate) fn spectral_tolerance(sym: &Array2<f64>, evals: &Array1<f64>) -> f64 {
    let dimension = sym.nrows().max(1) as f64;
    let largest_magnitude = evals.iter().fold(0.0_f64, |acc, &v| acc.max(v.abs()));
    dimension * 1e-10 * largest_magnitude
}
/// Symmetrises `penalty` and eigendecomposes it (lower triangle).
///
/// Returns `(symmetrised matrix, eigenvalues, eigenvectors)`.
///
/// # Errors
/// `BasisError::LinalgError` when the eigendecomposition fails.
pub(crate) fn spectral_summary(
    penalty: &Array2<f64>,
) -> Result<(Array2<f64>, Array1<f64>, Array2<f64>), BasisError> {
    let symmetric = symmetrize_penalty(penalty);
    match FaerEigh::eigh(&symmetric, Side::Lower) {
        Ok((eigenvalues, eigenvectors)) => Ok((symmetric, eigenvalues, eigenvectors)),
        Err(err) => Err(BasisError::LinalgError(err)),
    }
}
/// Checks that `penalty` is positive semidefinite up to [`spectral_tolerance`] and summarises
/// its spectrum.
///
/// An empty matrix yields an all-zero summary with tolerance `1e-10`.
///
/// # Errors
/// * A dimension error (prefixed with `context`) for non-square input.
/// * `IndefinitePenalty` carrying `context` and `guidance` when the smallest eigenvalue is below
///   `-tolerance`.
/// * Eigendecomposition failures.
pub(crate) fn validate_psd_penalty(
    penalty: &Array2<f64>,
    context: &str,
    guidance: &str,
) -> Result<PsdSpectralSummary, BasisError> {
    if penalty.nrows() != penalty.ncols() {
        crate::bail_dim_basis!("{context}: penalty matrix must be square for PSD validation");
    }
    if penalty.nrows() == 0 {
        return Ok(PsdSpectralSummary {
            min_eigenvalue: 0.0,
            max_abs_eigenvalue: 0.0,
            tolerance: 1e-10,
            effective_rank: 0,
        });
    }
    let (sym, evals, _) = spectral_summary(penalty)?;
    let tolerance = spectral_tolerance(&sym, &evals);
    // Single pass: smallest eigenvalue, largest magnitude, and count above tolerance.
    let (min_eigenvalue, max_abs_eigenvalue, effective_rank) = evals.iter().fold(
        (f64::INFINITY, 0.0_f64, 0_usize),
        |(lowest, peak, rank), &ev| {
            (
                lowest.min(ev),
                peak.max(ev.abs()),
                rank + usize::from(ev > tolerance),
            )
        },
    );
    if min_eigenvalue < -tolerance {
        return Err(BasisError::IndefinitePenalty {
            context: context.to_string(),
            min_eigenvalue,
            tolerance,
            guidance: guidance.to_string(),
        });
    }
    Ok(PsdSpectralSummary {
        min_eigenvalue,
        max_abs_eigenvalue,
        tolerance,
        effective_rank,
    })
}
/// Canonicalises a dense penalty with no attached matrix-free operator.
///
/// Equivalent to `analyze_penalty_block_with_op(penalty, None)`.
pub fn analyze_penalty_block(penalty: &Array2<f64>) -> Result<CanonicalPenaltyBlock, BasisError> {
    let no_operator = None;
    analyze_penalty_block_with_op(penalty, no_operator)
}
/// Canonicalises a square penalty: symmetrises it, eigendecomposes it and derives rank,
/// nullity and a zero flag against [`spectral_tolerance`]. `op` is carried through unchanged.
///
/// Rank counts eigenvalues strictly above the tolerance; nullity is the remainder. An empty
/// matrix yields an empty block flagged as zero with tolerance `1e-10`.
///
/// # Errors
/// A dimension error for non-square input; propagates eigendecomposition failures.
pub fn analyze_penalty_block_with_op(
    penalty: &Array2<f64>,
    op: Option<std::sync::Arc<dyn crate::terms::penalty_op::PenaltyOp>>,
) -> Result<CanonicalPenaltyBlock, BasisError> {
    let (rows, cols) = penalty.dim();
    if rows != cols {
        crate::bail_dim_basis!("penalty matrix must be square when analyzing penalty");
    }
    if rows == 0 {
        return Ok(CanonicalPenaltyBlock {
            sym_penalty: Array2::<f64>::zeros((0, 0)),
            eigenvalues: Array1::<f64>::zeros(0),
            eigenvectors: Array2::<f64>::zeros((0, 0)),
            rank: 0,
            nullity: 0,
            tol: 1e-10,
            iszero: true,
            op,
        });
    }
    let (sym, evals, evecs) = spectral_summary(penalty)?;
    let tol = spectral_tolerance(&sym, &evals);
    let (rank, largest_magnitude) = evals
        .iter()
        .fold((0_usize, 0.0_f64), |(count, peak), &ev| {
            (count + usize::from(ev > tol), peak.max(ev.abs()))
        });
    let dimension = sym.nrows();
    Ok(CanonicalPenaltyBlock {
        nullity: dimension.saturating_sub(rank),
        iszero: largest_magnitude <= tol,
        sym_penalty: sym,
        eigenvalues: evals,
        eigenvectors: evecs,
        rank,
        tol,
        op,
    })
}
/// Same as [`filter_active_penalty_candidates_with_ops`] but discards the per-penalty null
/// eigenvectors and operators, returning only penalties, nullspace dimensions and info records.
pub fn filter_active_penalty_candidates(
    candidates: Vec<PenaltyCandidate>,
) -> Result<(Vec<Array2<f64>>, Vec<usize>, Vec<PenaltyInfo>), BasisError> {
    filter_active_penalty_candidates_with_ops(candidates)
        .map(|(penalties, nullspace_dims, penaltyinfo, _, _)| {
            (penalties, nullspace_dims, penaltyinfo)
        })
}
/// Eigenvectors of `block` whose eigenvalues are `<= block.tol`, stacked as columns.
///
/// Returns `None` when the block reports zero nullity or no eigenvalue falls under the
/// tolerance.
pub(crate) fn nullspace_basis_from_block(block: &CanonicalPenaltyBlock) -> Option<Array2<f64>> {
    if block.nullity == 0 {
        return None;
    }
    let null_idx: Vec<usize> = block
        .eigenvalues
        .iter()
        .enumerate()
        .filter(|(_, ev)| **ev <= block.tol)
        .map(|(i, _)| i)
        .collect();
    if null_idx.is_empty() {
        None
    } else {
        Some(block.eigenvectors.select(Axis(1), &null_idx))
    }
}
/// Recomputes the numerical nullspace basis of each penalty, in order.
///
/// # Errors
/// Stops at the first penalty that fails analysis (non-square or eigendecomposition failure).
pub fn recompute_null_eigenvectors(
    penalties: &[Array2<f64>],
) -> Result<Vec<Option<Array2<f64>>>, BasisError> {
    let mut bases = Vec::with_capacity(penalties.len());
    for penalty in penalties {
        let block = analyze_penalty_block_with_op(penalty, None)?;
        bases.push(nullspace_basis_from_block(&block));
    }
    Ok(bases)
}
/// Computes an eigenbasis of `Σₖ Sₖ` that separates the joint range from the joint nullspace.
///
/// All penalties must be `p × p`. Eigenvalues of the symmetrised sum that are `<= tol`
/// (see [`spectral_tolerance`]) are joint-null. The rotation's columns put range directions
/// first and joint-null directions last, each group ordered by descending eigenvalue.
///
/// Returns `Ok(None)` when there are no penalties, `p == 0`, or the joint nullspace is empty.
///
/// # Errors
/// A dimension error when any penalty is not `p × p`; propagates eigendecomposition failures.
pub fn compute_joint_null_rotation(
    penalties: &[Array2<f64>],
) -> Result<Option<JointNullRotation>, BasisError> {
    let Some(first) = penalties.first() else {
        return Ok(None);
    };
    let p = first.nrows();
    if p == 0 {
        return Ok(None);
    }
    for (k, s) in penalties.iter().enumerate() {
        if s.nrows() != p || s.ncols() != p {
            crate::bail_dim_basis!(
                "compute_joint_null_rotation: penalty[{}] is {}×{}, expected {}×{}",
                k,
                s.nrows(),
                s.ncols(),
                p,
                p
            );
        }
    }
    let mut s_sum = Array2::<f64>::zeros((p, p));
    for s in penalties {
        s_sum += s;
    }
    let (sym, evals, evecs) = spectral_summary(&s_sum)?;
    let tol = spectral_tolerance(&sym, &evals);
    let joint_nullity = evals.iter().filter(|&&ev| ev <= tol).count();
    if joint_nullity == 0 {
        return Ok(None);
    }
    let mut order: Vec<usize> = (0..evals.len()).collect();
    order.sort_by(|&a, &b| {
        let (ev_a, ev_b) = (evals[a], evals[b]);
        // `false < true`, so range directions sort ahead of joint-null ones. Within a group,
        // `partial_cmp` gives descending order; `total_cmp` only decides pairs involving NaN so
        // the comparator stays a total order, as `sort_by` requires.
        (ev_a <= tol)
            .cmp(&(ev_b <= tol))
            .then_with(|| {
                ev_b.partial_cmp(&ev_a)
                    .unwrap_or_else(|| ev_b.total_cmp(&ev_a))
            })
    });
    let rotation = evecs.select(Axis(1), &order);
    Ok(Some(JointNullRotation {
        rotation,
        joint_nullity,
    }))
}
/// Analyses every penalty candidate, keeping only those with numerical rank > 0.
///
/// For each retained candidate the symmetrised penalty, its nullity, its nullspace basis and its
/// operator are appended to the parallel output vectors. Every candidate, kept or dropped, gets a
/// `PenaltyInfo` record with its original index, effective rank, drop reason and validated
/// Kronecker factors.
///
/// # Errors
/// Stops at the first candidate whose analysis fails.
pub fn filter_active_penalty_candidates_with_ops(
    candidates: Vec<PenaltyCandidate>,
) -> Result<
    (
        Vec<Array2<f64>>,
        Vec<usize>,
        Vec<PenaltyInfo>,
        Vec<Option<Array2<f64>>>,
        Vec<Option<std::sync::Arc<dyn crate::terms::penalty_op::PenaltyOp>>>,
    ),
    BasisError,
> {
    let total = candidates.len();
    let mut penalties: Vec<Array2<f64>> = Vec::with_capacity(total);
    let mut nullspace_dims: Vec<usize> = Vec::with_capacity(total);
    let mut penaltyinfo: Vec<PenaltyInfo> = Vec::with_capacity(total);
    let mut active_null_eigenvectors: Vec<Option<Array2<f64>>> = Vec::with_capacity(total);
    let mut active_ops: Vec<Option<std::sync::Arc<dyn crate::terms::penalty_op::PenaltyOp>>> =
        Vec::with_capacity(total);

    for (original_index, candidate) in candidates.into_iter().enumerate() {
        let analysis = analyze_penalty_block_with_op(&candidate.matrix, candidate.op.clone())?;
        let effective_rank = analysis.rank;
        let nullity = analysis.nullity;
        // A rank-zero block carries no penalty; distinguish exact zeros from numerical ones.
        let dropped_reason = match (effective_rank, analysis.iszero) {
            (0, true) => Some(PenaltyDropReason::ZeroMatrix),
            (0, false) => Some(PenaltyDropReason::NumericalRankZero),
            _ => None,
        };
        let active = dropped_reason.is_none();
        let kronecker_factors =
            validated_kronecker_factors(candidate.kronecker_factors, &analysis.sym_penalty);
        if active {
            let null_basis = nullspace_basis_from_block(&analysis);
            log::debug!(
                "Retained penalty block source={:?} original_index={} rank={} nullspace_dim_hint={} has_op={} has_null_basis={}",
                candidate.source,
                original_index,
                effective_rank,
                nullity,
                analysis.op.is_some(),
                null_basis.is_some(),
            );
            penalties.push(analysis.sym_penalty);
            nullspace_dims.push(nullity);
            active_null_eigenvectors.push(null_basis);
            active_ops.push(analysis.op);
        } else {
            log::debug!(
                "Dropped inactive penalty block source={:?} original_index={} reason={:?}",
                candidate.source,
                original_index,
                dropped_reason
            );
        }
        penaltyinfo.push(PenaltyInfo {
            source: candidate.source,
            original_index,
            active,
            effective_rank,
            dropped_reason,
            nullspace_dim_hint: nullity,
            normalization_scale: candidate.normalization_scale,
            kronecker_factors,
        });
    }
    Ok((
        penalties,
        nullspace_dims,
        penaltyinfo,
        active_null_eigenvectors,
        active_ops,
    ))
}
/// Keeps `factors` only if their Kronecker product reproduces `matrix`.
///
/// The product must have the same shape as `matrix` and agree entry-wise within
/// `1e-10 · max(1, largest absolute entry of either matrix)`. Missing, empty or mismatching
/// factor lists yield `None`.
pub(crate) fn validated_kronecker_factors(
    factors: Option<Vec<Array2<f64>>>,
    matrix: &Array2<f64>,
) -> Option<Vec<Array2<f64>>> {
    let factors = factors?;
    let (head, tail) = factors.split_first()?;
    let product = tail.iter().fold(head.clone(), |acc, factor| {
        crate::construction::kronecker_product(&acc, factor)
    });
    if product.dim() != matrix.dim() {
        return None;
    }
    let magnitude = product
        .iter()
        .chain(matrix.iter())
        .fold(0.0_f64, |acc, &value| acc.max(value.abs()))
        .max(1.0);
    let worst_gap = product
        .iter()
        .zip(matrix.iter())
        .fold(0.0_f64, |acc, (&lhs, &rhs)| acc.max((lhs - rhs).abs()));
    if worst_gap <= magnitude * 1e-10 {
        Some(factors)
    } else {
        None
    }
}
/// Builds a penalty that acts only on the numerical null space of `penalty`.
///
/// The input is spectrally decomposed via `spectral_summary`. The eigenvectors whose
/// eigenvalue magnitude is at most `spectral_tolerance(..)` are collected as the
/// columns of `Z`. The returned block's `sym_penalty` is `Z Zᵀ` (computed by
/// `fast_abt`), and its `rank` is the number of such columns.
///
/// # Returns
/// - `Ok(None)` for an empty (0×0) penalty.
/// - `Ok(None)` when no eigenvalue falls within the tolerance.
/// - `Ok(Some(block))` otherwise.
///
/// # Errors
/// A dimension error when `penalty` is not square. Errors from `spectral_summary`
/// are propagated.
///
/// NOTE(review): `eigenvalues`, `eigenvectors`, and `tol` are those of the *input*
/// `penalty`, not of `Z Zᵀ`. `nullity` is hard-coded to 0 even though `Z Zᵀ`
/// generally has rank < n. Confirm that downstream consumers expect this.
pub(crate) fn build_nullspace_shrinkage_penalty(
    penalty: &Array2<f64>,
) -> Result<Option<CanonicalPenaltyBlock>, BasisError> {
    let (rows, cols) = penalty.dim();
    if rows != cols {
        crate::bail_dim_basis!(
            "penalty matrix must be square when building nullspace shrinkage penalty"
        );
    }
    if rows == 0 {
        return Ok(None);
    }
    let (symmetric, eigenvalues, eigenvectors) = spectral_summary(penalty)?;
    let tol = spectral_tolerance(&symmetric, &eigenvalues);
    let null_indices: Vec<usize> = eigenvalues
        .iter()
        .enumerate()
        .filter(|(_, ev)| ev.abs() <= tol)
        .map(|(idx, _)| idx)
        .collect();
    if null_indices.is_empty() {
        return Ok(None);
    }
    let null_basis = eigenvectors.select(Axis(1), &null_indices);
    let projector = fast_abt(&null_basis, &null_basis);
    let rank = null_indices.len();
    Ok(Some(CanonicalPenaltyBlock {
        sym_penalty: projector,
        eigenvalues,
        eigenvectors,
        rank,
        nullity: 0,
        tol,
        iszero: false,
        op: None,
    }))
}
/// Heuristic default number of interior B-spline knots for `n` data points.
///
/// The heuristic depends on `n`:
/// - fewer than 8 points: no interior knots;
/// - 8..=15 points: 3 knots;
/// - 16 or more points: `max(n / 4, 3)` knots.
///
/// The result is then capped at 40 and at `n - (degree + 2)` (saturating at 0).
pub(crate) fn default_internal_knot_count_for_data(n: usize, degree: usize) -> usize {
    const MAX_DEFAULT_INTERNAL_KNOTS: usize = 40;
    let heuristic = match n {
        0..=7 => return 0,
        8..=15 => 3,
        _ => (n / 4).max(3),
    };
    let ceiling = n.saturating_sub(degree + 2);
    heuristic.min(MAX_DEFAULT_INTERNAL_KNOTS).min(ceiling)
}
/// Clamps a requested B-spline configuration so it fits `n` data points.
///
/// The degree is first normalized to at least 1. It is then capped at `n - 1`, so
/// that `degree + 1 <= n`. The interior-knot count is capped at `n - 2`.
///
/// # Returns
/// - `None` when `n < 2`.
/// - Otherwise `Some((num_internal_knots, degree, shrunk))`, where `shrunk` is true
///   when either value differs from the (degree-normalized) request.
///
/// NOTE(review): a requested degree of 0 is reported as degree 1 without setting
/// `shrunk`. Confirm that callers intend degree 0 to be promoted.
pub(crate) fn auto_shrink_bspline_config(
    n: usize,
    requested_num_internal_knots: usize,
    requested_degree: usize,
) -> Option<(usize, usize, bool)> {
    if n < 2 {
        return None;
    }
    let normalized_degree = requested_degree.max(1);
    // Closed form of "decrement until degree + 1 <= n". With n >= 2 this always
    // yields a degree >= 1. It runs in O(1) and cannot overflow on `degree + 1`.
    let degree = normalized_degree.min(n - 1);
    let num_internal_knots = requested_num_internal_knots.min(n - 2);
    let shrunk = num_internal_knots != requested_num_internal_knots || degree != normalized_degree;
    Some((num_internal_knots, degree, shrunk))
}
/// Shrinks a B-spline spec's degree and interior-knot count when `n` is too small.
///
/// Only the `Generate` and `Automatic` knot specs are considered. For `Automatic`
/// with no explicit count, the request is
/// `default_internal_knot_count_for_data(n, spec.degree)`.
///
/// When `auto_shrink_bspline_config` reports a shrink, the function returns a copy
/// of `spec` with the reduced degree and knot count (`Automatic` gets the count
/// pinned as `Some(..)`), together with a human-readable note. The note is also
/// logged at info level.
///
/// In every other case, including `Provided`/`PeriodicUniform` specs and `n < 2`,
/// it returns `(spec.clone(), None)`.
pub(crate) fn maybe_auto_shrink_bspline_spec(
    spec: &BSplineBasisSpec,
    n: usize,
) -> (BSplineBasisSpec, Option<String>) {
    let requested_interior = match &spec.knotspec {
        BSplineKnotSpec::Generate {
            num_internal_knots, ..
        } => *num_internal_knots,
        BSplineKnotSpec::Automatic {
            num_internal_knots, ..
        } => num_internal_knots
            .unwrap_or_else(|| default_internal_knot_count_for_data(n, spec.degree)),
        BSplineKnotSpec::Provided(_) | BSplineKnotSpec::PeriodicUniform { .. } => {
            return (spec.clone(), None);
        }
    };
    // Only proceed when a configuration exists AND it actually differs from the request.
    let Some((eff_interior, eff_degree, true)) =
        auto_shrink_bspline_config(n, requested_interior, spec.degree)
    else {
        return (spec.clone(), None);
    };
    let note = format!(
        "auto-shrink (#340): n={n} too small for requested degree={req_deg}, \
         interior_knots={req_ki}; using degree={eff_deg}, interior_knots={eff_ki}",
        n = n,
        req_deg = spec.degree,
        req_ki = requested_interior,
        eff_deg = eff_degree,
        eff_ki = eff_interior,
    );
    let mut shrunk_spec = spec.clone();
    shrunk_spec.degree = eff_degree;
    // Rewrite only the knot count; every other knot-spec field is carried over by the clone.
    match &mut shrunk_spec.knotspec {
        BSplineKnotSpec::Generate {
            num_internal_knots, ..
        } => {
            log::info!("B-spline {note} on Generate knotspec");
            *num_internal_knots = eff_interior;
        }
        BSplineKnotSpec::Automatic {
            num_internal_knots, ..
        } => {
            log::info!("B-spline {note} on Automatic knotspec");
            *num_internal_knots = Some(eff_interior);
        }
        BSplineKnotSpec::Provided(_) | BSplineKnotSpec::PeriodicUniform { .. } => {}
    }
    (shrunk_spec, Some(note))
}
/// Returns the `(min, max)` of `data`.
///
/// # Errors
/// An invalid-basis error when `data` is empty, or when any value is NaN or infinite.
pub(crate) fn finite_data_range(data: ArrayView1<'_, f64>) -> Result<(f64, f64), BasisError> {
    if data.is_empty() {
        crate::bail_invalid_basis!("cannot infer knot range from empty data");
    }
    if !data.iter().all(|v| v.is_finite()) {
        crate::bail_invalid_basis!("automatic knot placement requires finite data values");
    }
    // Strict comparisons keep the first occurrence among equal values (e.g. signed zeros).
    let range = data
        .iter()
        .fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), &x| {
            let lo = if x < lo { x } else { lo };
            let hi = if x > hi { x } else { hi };
            (lo, hi)
        });
    Ok(range)
}
pub(crate) fn expand_periodic_centers(
centers: &Array2<f64>,
periodic: Option<&[Option<f64>]>,
) -> Result<Array2<f64>, BasisError> {
let Some(periodic) = periodic else {
return Ok(centers.clone());
};
if periodic.len() != centers.ncols() {
crate::bail_dim_basis!(
"period vector length {} does not match smooth dimension {}",
periodic.len(),
centers.ncols()
);
}
let active: Vec<(usize, f64)> = periodic
.iter()
.enumerate()
.filter_map(|(i, p)| p.map(|v| (i, v)))
.collect();
if active.is_empty() {
return Ok(centers.clone());
}
for (axis, period) in &active {
if !period.is_finite() || *period <= 0.0 {
crate::bail_invalid_basis!(
"period for axis {axis} must be finite and positive, got {period}"
);
}
}
let shifts = 3usize.pow(active.len() as u32);
let mut out = Array2::<f64>::zeros((centers.nrows() * shifts, centers.ncols()));
let mut row_out = 0usize;
for code in 0..shifts {
let mut tmp = code;
let mut offsets = vec![0.0; centers.ncols()];
for &(axis, period) in &active {
let digit = tmp % 3;
tmp /= 3;
offsets[axis] = match digit {
0 => -period,
1 => 0.0,
_ => period,
};
}
for r in 0..centers.nrows() {
for c in 0..centers.ncols() {
out[[row_out, c]] = centers[[r, c]] + offsets[c];
}
row_out += 1;
}
}
Ok(out)
}