use crate::linalg::faer_ndarray::{FaerCholesky, fast_ata, fast_atb};
use crate::solver::warm_start_artifact::{
FitArtifact, FitDescriptor, RHO_SATURATION, TermIdentityKey, TransferProvenance,
};
use faer::Side;
use ndarray::{Array1, Array2};
/// Largest coefficient magnitude accepted from a projected parent β; a projection with
/// any entry beyond this bound is rejected as a whole.
const PROJECTED_BETA_CLAMP: f64 = 1e6;
/// Per-term description of the *new* fit, as consumed by `build_warm_start`.
#[derive(Debug, Clone)]
pub struct TermBuildContext {
    /// Key used to look up the matching term in the parent artifact.
    pub identity: TermIdentityKey,
    /// Positions in the global ρ vector owned by this term. They are paired positionally
    /// with the parent's per-term ρ values.
    pub rho_slots: Vec<usize>,
    /// Number of coefficients this term has in the new reduced parameterisation.
    pub reduced_width: usize,
    /// Optional raw → reduced transform with shape `(raw coefficients, reduced_width)`.
    /// Without it, no coefficients are transferred for this term.
    pub gauge_t_block: Option<Array2<f64>>,
}
/// Output of `build_warm_start`: the starting point handed to the new fit.
#[derive(Debug, Clone)]
pub struct TransferResult {
    /// Full ρ vector: a copy of the caller's default with transferred slots overwritten.
    pub rho: Array1<f64>,
    /// One coefficient vector per new term, each of that term's `reduced_width`.
    pub block_beta: Vec<Array1<f64>>,
    /// Per-term record of what (if anything) was inherited from the parent.
    pub provenance: Vec<TransferProvenance>,
}
/// Least-squares projection of a parent's raw coefficients onto a reduced basis.
///
/// Solves the ridge-stabilised normal equations `(TᵀT + εI) θ = Tᵀβ`. Here `T` is
/// `t_block` (raw rows × reduced columns) and `β` is `raw_beta_parent`. The ridge is
/// `ε = max(1e-8 · trace(TᵀT) / reduced_width, 1e-12)`.
///
/// Returns `None` in any of these cases:
/// * `t_block` is not `raw_beta_parent.len() × reduced_width`;
/// * any entry of `raw_beta_parent` or `t_block` is non-finite;
/// * the Cholesky factorisation fails;
/// * the solution has the wrong length, or any entry is non-finite or larger in
///   magnitude than `PROJECTED_BETA_CLAMP`.
///
/// Once the shapes agree, a zero `reduced_width` yields an empty vector. This happens
/// before any finiteness check.
fn project_raw_beta_to_reduced(
    t_block: &Array2<f64>,
    raw_beta_parent: &[f64],
    reduced_width: usize,
) -> Option<Array1<f64>> {
    let (n_raw, n_reduced) = t_block.dim();
    let shapes_agree = n_reduced == reduced_width && n_raw == raw_beta_parent.len();
    if !shapes_agree {
        return None;
    }
    if reduced_width == 0 {
        return Some(Array1::zeros(0));
    }
    let inputs_finite = raw_beta_parent.iter().all(|v| v.is_finite())
        && t_block.iter().all(|v| v.is_finite());
    if !inputs_finite {
        return None;
    }

    // Normal matrix TᵀT plus a small ridge scaled to its mean diagonal. The ridge lets the
    // Cholesky factorisation succeed when T has dependent columns.
    let mut normal = fast_ata(t_block);
    let trace: f64 = (0..reduced_width).map(|d| normal[[d, d]]).sum();
    let ridge = (1.0e-8 * trace / (reduced_width as f64)).max(1.0e-12);
    for d in 0..reduced_width {
        normal[[d, d]] += ridge;
    }

    // Tᵀβ, computed through the matrix-matrix kernel with β as a single column.
    let beta_column = Array2::from_shape_vec((n_raw, 1), raw_beta_parent.to_vec()).ok()?;
    let projected_rhs = fast_atb(t_block, &beta_column).column(0).to_owned();

    let chol = normal.cholesky(Side::Lower).ok()?;
    let coeffs = chol.solvevec(&projected_rhs);
    let acceptable = coeffs.len() == reduced_width
        && coeffs
            .iter()
            .all(|c| c.is_finite() && c.abs() <= PROJECTED_BETA_CLAMP);
    acceptable.then_some(coeffs)
}
/// Tuning knobs for how parent ρ values are inherited.
#[derive(Debug, Clone, Copy)]
pub struct TransferConfig {
    /// Parent ρ with `|ρ|` at or above this value is not copied.
    pub rho_saturation: f64,
    /// Copied ρ values are clamped into `[-rho_interior_clamp, rho_interior_clamp]`.
    pub rho_interior_clamp: f64,
}
impl Default for TransferConfig {
fn default() -> Self {
Self {
rho_saturation: RHO_SATURATION,
rho_interior_clamp: RHO_SATURATION - 1.0,
}
}
}
/// Reasons `build_warm_start` refuses to use a parent artifact at all.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TransferError {
    /// `FitArtifact::is_usable` returned `false` for the parent.
    ParentUnusable,
    /// The parent's descriptor key differs from the new fit's descriptor key.
    DescriptorMismatch,
}
/// Builds warm-start values (ρ and per-block β) for a new fit from a parent fit artifact.
///
/// The result starts from a copy of `rho_default` and an all-zero coefficient block of
/// `reduced_width` for every entry of `new_terms`. Each new term is then matched against
/// the parent's terms by `identity`, and the first match wins. For a matched term:
///
/// * **β**: when the term has a `gauge_t_block`, the parent's `raw_beta` is projected onto
///   the reduced basis with `project_raw_beta_to_reduced`. If the projection is rejected,
///   the block stays at zeros.
/// * **ρ**: transferred only when the parent stores exactly one ρ per entry of
///   `rho_slots`. Each parent ρ is copied into its slot unless one of these holds:
///   the slot is out of range for `rho_default`, the value is non-finite, or
///   `|ρ| >= cfg.rho_saturation`. Copied values are clamped into
///   `[-cfg.rho_interior_clamp, cfg.rho_interior_clamp]`.
///
/// Each term's provenance is `Projected` if β was projected. Otherwise it is `RhoOnly` if
/// at least one ρ was copied, and `Cold` in all other cases. Unmatched terms are `Cold`.
///
/// If `cfg.rho_interior_clamp` is negative or NaN, no ρ is copied and every slot keeps its
/// default. Such a bound cannot describe an interval, and `f64::clamp` would panic on it.
///
/// # Errors
///
/// * [`TransferError::ParentUnusable`] when `parent.is_usable()` returns `false`.
/// * [`TransferError::DescriptorMismatch`] when the parent's descriptor key differs from
///   the key of `new_descriptor`.
pub fn build_warm_start(
    new_descriptor: &FitDescriptor,
    new_terms: &[TermBuildContext],
    rho_default: &Array1<f64>,
    parent: &FitArtifact,
    cfg: TransferConfig,
) -> Result<TransferResult, TransferError> {
    if !parent.is_usable() {
        return Err(TransferError::ParentUnusable);
    }
    if parent.descriptor.descriptor_key() != new_descriptor.descriptor_key() {
        return Err(TransferError::DescriptorMismatch);
    }

    // `f64::clamp(-c, c)` panics when `c` is NaN or negative (min > max). Degrade to the
    // cold default ρ instead of aborting the fit. `>= 0.0` is false for NaN.
    let interior_clamp = cfg.rho_interior_clamp;
    let rho_copy_enabled = interior_clamp >= 0.0;

    let mut rho = rho_default.clone();
    let mut provenance = vec![TransferProvenance::Cold; new_terms.len()];
    let mut block_beta: Vec<Array1<f64>> = new_terms
        .iter()
        .map(|t| Array1::<f64>::zeros(t.reduced_width))
        .collect();

    for (term_idx, new_term) in new_terms.iter().enumerate() {
        let Some(parent_term) = parent
            .terms
            .iter()
            .find(|p| p.identity == new_term.identity)
        else {
            continue;
        };

        let mut beta_projected = false;
        if let Some(t_block) = new_term.gauge_t_block.as_ref()
            && let Some(theta) =
                project_raw_beta_to_reduced(t_block, &parent_term.raw_beta, new_term.reduced_width)
        {
            block_beta[term_idx] = theta;
            beta_projected = true;
        }

        let mut copied_any = false;
        // ρ entries pair with slots by position, so they only transfer when the parent
        // stores exactly one value per slot of the new term.
        if rho_copy_enabled && parent_term.rho_for_term.len() == new_term.rho_slots.len() {
            for (&slot, &parent_rho) in new_term.rho_slots.iter().zip(&parent_term.rho_for_term) {
                if slot >= rho.len() || !parent_rho.is_finite() {
                    continue;
                }
                // A parent ρ at or beyond the saturation bound is not transferred.
                if parent_rho.abs() >= cfg.rho_saturation {
                    continue;
                }
                rho[slot] = parent_rho.clamp(-interior_clamp, interior_clamp);
                copied_any = true;
            }
        }

        provenance[term_idx] = if beta_projected {
            TransferProvenance::Projected
        } else if copied_any {
            TransferProvenance::RhoOnly
        } else {
            TransferProvenance::Cold
        };
    }

    Ok(TransferResult {
        rho,
        block_beta,
        provenance,
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::solver::warm_start_artifact::{
        FIT_ARTIFACT_SCHEMA, GlobalFitSummary, ResponseSig, SerializableBasisMeta, TermArtifact,
        TermRole, term_identity_from_block,
    };
    use ndarray::{Array1, Array2};

    /// Identity key of a single-column mean block named `block_name`.
    fn block_id(block_name: &str) -> TermIdentityKey {
        term_identity_from_block(TermRole::Mean, block_name, &[None], &[1], 10)
    }

    /// Build context for a term that transfers only ρ (zero width, no gauge block).
    fn rho_only_ctx(identity: TermIdentityKey, rho_slots: Vec<usize>) -> TermBuildContext {
        TermBuildContext {
            identity,
            rho_slots,
            reduced_width: 0,
            gauge_t_block: None,
        }
    }

    /// Placeholder basis metadata for test artifacts.
    fn basis_meta_stub() -> SerializableBasisMeta {
        SerializableBasisMeta {
            kind: "block-spec".to_string(),
            degree: None,
            num_knots: None,
            n_centers: Some(5),
            nullspace_order: None,
            matern_nu: None,
            periodic: false,
        }
    }

    /// Converged single-term Gaussian parent with five zero raw coefficients and the
    /// given per-term ρ.
    fn parent_with(identity: TermIdentityKey, rho_for_term: Vec<f64>) -> FitArtifact {
        FitArtifact {
            schema: FIT_ARTIFACT_SCHEMA,
            created_unix_secs: 0,
            descriptor: FitDescriptor {
                family_kind: "gaussian".to_string(),
                term_identities: vec![identity],
                response_signature: ResponseSig {
                    family_kind: "gaussian".to_string(),
                    n_response_channels: 1,
                },
                row_population: None,
            },
            terms: vec![TermArtifact {
                identity,
                role: TermRole::Mean,
                basis_meta: basis_meta_stub(),
                joint_null_rotation: None,
                raw_beta: vec![0.0; 5],
                rho_for_term,
            }],
            global: GlobalFitSummary {
                outer_objective: -10.0,
                converged: true,
                n_rows: 1000,
            },
        }
    }

    /// Descriptor identical to the one `parent_with` builds for the same identity.
    fn new_descriptor(identity: TermIdentityKey) -> FitDescriptor {
        FitDescriptor {
            family_kind: "gaussian".to_string(),
            term_identities: vec![identity],
            response_signature: ResponseSig {
                family_kind: "gaussian".to_string(),
                n_response_channels: 1,
            },
            row_population: None,
        }
    }

    #[test]
    fn matched_term_copies_parent_rho() {
        let id = block_id("s(x)");
        let parent = parent_with(id, vec![2.5]);
        let new_terms = vec![rho_only_ctx(id, vec![0])];
        let rho_default = Array1::from_vec(vec![0.0]);
        let res = build_warm_start(
            &new_descriptor(id),
            &new_terms,
            &rho_default,
            &parent,
            TransferConfig::default(),
        )
        .expect("transfer builds");
        assert_eq!(res.rho[0], 2.5, "matched term must inherit parent ρ");
        assert_eq!(res.provenance[0], TransferProvenance::RhoOnly);
    }

    #[test]
    fn unmatched_term_keeps_default() {
        let parent_id = block_id("s(x)");
        let new_id = block_id("s(z)");
        let new_terms = vec![rho_only_ctx(new_id, vec![0])];
        let rho_default = Array1::from_vec(vec![-1.3]);
        let mut parent = parent_with(new_id, vec![2.5]);
        parent.terms[0].identity = parent_id;
        let res = build_warm_start(
            &new_descriptor(new_id),
            &new_terms,
            &rho_default,
            &parent,
            TransferConfig::default(),
        )
        .expect("transfer builds");
        assert_eq!(res.rho[0], -1.3, "unmatched term keeps the new default ρ");
        assert_eq!(res.provenance[0], TransferProvenance::Cold);
    }

    #[test]
    fn saturated_parent_rho_not_copied() {
        let id = block_id("s(x)");
        let parent = parent_with(id, vec![12.0]);
        let new_terms = vec![rho_only_ctx(id, vec![0])];
        let rho_default = Array1::from_vec(vec![0.7]);
        let res = build_warm_start(
            &new_descriptor(id),
            &new_terms,
            &rho_default,
            &parent,
            TransferConfig::default(),
        )
        .expect("transfer builds");
        assert_eq!(res.rho[0], 0.7, "saturated parent ρ must not be copied");
        assert_eq!(res.provenance[0], TransferProvenance::Cold);
    }

    #[test]
    fn negative_saturated_parent_rho_not_copied() {
        let id = block_id("s(x)");
        let parent = parent_with(id, vec![-12.0]);
        let new_terms = vec![rho_only_ctx(id, vec![0])];
        let rho_default = Array1::from_vec(vec![0.7]);
        let res = build_warm_start(
            &new_descriptor(id),
            &new_terms,
            &rho_default,
            &parent,
            TransferConfig::default(),
        )
        .expect("transfer builds");
        assert_eq!(res.rho[0], 0.7, "saturation applies to |ρ|, including negative ρ");
        assert_eq!(res.provenance[0], TransferProvenance::Cold);
    }

    #[test]
    fn near_box_parent_rho_is_interior_clamped() {
        let id = block_id("s(x)");
        let parent = parent_with(id, vec![8.7]);
        let new_terms = vec![rho_only_ctx(id, vec![0])];
        let rho_default = Array1::from_vec(vec![0.0]);
        let cfg = TransferConfig::default();
        let res = build_warm_start(&new_descriptor(id), &new_terms, &rho_default, &parent, cfg)
            .expect("transfer builds");
        assert!(res.rho[0] <= cfg.rho_interior_clamp);
        assert_eq!(res.rho[0], cfg.rho_interior_clamp);
        assert_eq!(res.provenance[0], TransferProvenance::RhoOnly);
    }

    #[test]
    fn out_of_range_rho_slot_is_skipped() {
        let id = block_id("s(x)");
        let parent = parent_with(id, vec![2.5]);
        let new_terms = vec![rho_only_ctx(id, vec![3])];
        let rho_default = Array1::from_vec(vec![-0.5]);
        let res = build_warm_start(
            &new_descriptor(id),
            &new_terms,
            &rho_default,
            &parent,
            TransferConfig::default(),
        )
        .expect("transfer builds");
        assert_eq!(res.rho.len(), 1, "ρ keeps the default length");
        assert_eq!(res.rho[0], -0.5, "out-of-range slot must not touch other slots");
        assert_eq!(res.provenance[0], TransferProvenance::Cold);
    }

    #[test]
    fn nonfinite_parent_is_rejected() {
        let id = block_id("s(x)");
        let mut parent = parent_with(id, vec![2.0]);
        parent.terms[0].raw_beta[0] = f64::NAN;
        let new_terms = vec![rho_only_ctx(id, vec![0])];
        let rho_default = Array1::from_vec(vec![0.42]);
        let err = build_warm_start(
            &new_descriptor(id),
            &new_terms,
            &rho_default,
            &parent,
            TransferConfig::default(),
        )
        .unwrap_err();
        assert_eq!(err, TransferError::ParentUnusable);
    }

    #[test]
    fn rho_only_transfer_leaves_unrelated_slots_at_default() {
        let id = block_id("s(x)");
        let parent = parent_with(id, vec![3.3]);
        let new_terms = vec![rho_only_ctx(id, vec![0])];
        let rho_default = Array1::from_vec(vec![0.0, -2.0]);
        let res = build_warm_start(
            &new_descriptor(id),
            &new_terms,
            &rho_default,
            &parent,
            TransferConfig::default(),
        )
        .expect("transfer builds");
        assert_eq!(res.rho[0], 3.3, "matched slot warm-starts");
        assert_eq!(res.rho[1], -2.0, "unrelated slot keeps the default");
    }

    #[test]
    fn descriptor_mismatch_rejected() {
        let id_a = block_id("s(x)");
        let id_b = block_id("s(z)");
        let parent = parent_with(id_a, vec![2.0]);
        let new_terms = vec![rho_only_ctx(id_b, vec![0])];
        let rho_default = Array1::from_vec(vec![0.0]);
        let err = build_warm_start(
            &new_descriptor(id_b),
            &new_terms,
            &rho_default,
            &parent,
            TransferConfig::default(),
        )
        .unwrap_err();
        assert_eq!(err, TransferError::DescriptorMismatch);
    }

    /// `parent_with`, but with explicit raw coefficients for the single term.
    fn parent_with_raw_beta(
        identity: TermIdentityKey,
        raw_beta: Vec<f64>,
        rho_for_term: Vec<f64>,
    ) -> FitArtifact {
        let mut p = parent_with(identity, rho_for_term);
        p.terms[0].raw_beta = raw_beta;
        p
    }

    /// Build context for a term that transfers β through the gauge block `t_block`.
    fn beta_ctx(
        identity: TermIdentityKey,
        rho_slots: Vec<usize>,
        reduced_width: usize,
        t_block: Array2<f64>,
    ) -> TermBuildContext {
        TermBuildContext {
            identity,
            rho_slots,
            reduced_width,
            gauge_t_block: Some(t_block),
        }
    }

    #[test]
    fn beta_projects_to_reduced_width() {
        let id = block_id("s(x)");
        let raw = vec![1.0, -2.0, 3.5];
        let parent = parent_with_raw_beta(id, raw.clone(), vec![1.0]);
        let t = Array2::<f64>::eye(3);
        let new_terms = vec![beta_ctx(id, vec![0], 3, t)];
        let rho_default = Array1::from_vec(vec![0.0]);
        let res = build_warm_start(
            &new_descriptor(id),
            &new_terms,
            &rho_default,
            &parent,
            TransferConfig::default(),
        )
        .expect("transfer builds");
        assert_eq!(res.block_beta[0].len(), 3, "β must be at the reduced width");
        for (got, want) in res.block_beta[0].iter().zip(raw.iter()) {
            assert!((got - want).abs() < 1e-6, "identity projection ≈ parent β");
        }
        assert_eq!(res.provenance[0], TransferProvenance::Projected);
    }

    #[test]
    fn cross_width_loso_case_transfers_beta() {
        let id = block_id("s(x)");
        let raw = vec![0.5, 0.5, 1.0, -1.0];
        let parent = parent_with_raw_beta(id, raw, vec![1.0]);
        // Each reduced column spans two raw rows, so the least-squares θ is the per-pair
        // mean of the raw coefficients: [(0.5 + 0.5) / 2, (1.0 - 1.0) / 2].
        let t =
            Array2::from_shape_vec((4, 2), vec![1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0]).unwrap();
        let new_terms = vec![beta_ctx(id, vec![0], 2, t)];
        let rho_default = Array1::from_vec(vec![0.0]);
        let res = build_warm_start(
            &new_descriptor(id),
            &new_terms,
            &rho_default,
            &parent,
            TransferConfig::default(),
        )
        .expect("transfer builds");
        assert_eq!(
            res.block_beta[0].len(),
            2,
            "cross-width LOSO must project to the new reduced width, not skip"
        );
        assert!(res.block_beta[0].iter().all(|v| v.is_finite()));
        let expected = [0.5, 0.0];
        for (got, want) in res.block_beta[0].iter().zip(expected.iter()) {
            assert!((got - want).abs() < 1e-6, "projection must average paired raw rows");
        }
        assert_eq!(res.provenance[0], TransferProvenance::Projected);
    }

    #[test]
    fn beta_dimension_anomaly_falls_back_to_cold() {
        let id = block_id("s(x)");
        let parent = parent_with_raw_beta(id, vec![1.0, 2.0, 3.0, 4.0, 5.0], vec![1.0]);
        let t = Array2::<f64>::eye(3);
        let new_terms = vec![beta_ctx(id, vec![0], 3, t)];
        let rho_default = Array1::from_vec(vec![0.0]);
        let res = build_warm_start(
            &new_descriptor(id),
            &new_terms,
            &rho_default,
            &parent,
            TransferConfig::default(),
        )
        .expect("transfer builds");
        assert_eq!(
            res.block_beta[0].len(),
            3,
            "cold β still at the reduced width"
        );
        assert!(
            res.block_beta[0].iter().all(|&v| v == 0.0),
            "dimension anomaly must yield cold zeros"
        );
        assert_eq!(res.provenance[0], TransferProvenance::RhoOnly);
    }

    #[test]
    fn beta_nonfinite_parent_is_globally_rejected() {
        let id = block_id("s(x)");
        let mut parent = parent_with_raw_beta(id, vec![1.0, 0.0, 3.0], vec![1.0]);
        parent.terms[0].raw_beta[1] = f64::NAN;
        let t = Array2::<f64>::eye(3);
        let new_terms = vec![beta_ctx(id, vec![0], 3, t)];
        let rho_default = Array1::from_vec(vec![0.0]);
        let err = build_warm_start(
            &new_descriptor(id),
            &new_terms,
            &rho_default,
            &parent,
            TransferConfig::default(),
        )
        .unwrap_err();
        assert_eq!(err, TransferError::ParentUnusable);
    }

    #[test]
    fn projection_helper_identity_is_exact() {
        let raw = vec![2.0, -1.0, 0.0, 4.0];
        let t = Array2::<f64>::eye(4);
        let theta = project_raw_beta_to_reduced(&t, &raw, 4).expect("projects");
        for (g, w) in theta.iter().zip(raw.iter()) {
            assert!((g - w).abs() < 1e-7);
        }
    }

    #[test]
    fn projection_helper_rejects_shape_mismatch() {
        let t = Array2::<f64>::eye(3);
        assert!(
            project_raw_beta_to_reduced(&t, &[1.0, 2.0], 3).is_none(),
            "raw length must match the gauge block's rows"
        );
        assert!(
            project_raw_beta_to_reduced(&t, &[1.0, 2.0, 3.0], 2).is_none(),
            "reduced width must match the gauge block's columns"
        );
    }

    #[test]
    fn projection_helper_zero_width_is_empty() {
        let t = Array2::<f64>::zeros((3, 0));
        let theta =
            project_raw_beta_to_reduced(&t, &[1.0, 2.0, 3.0], 0).expect("zero width projects");
        assert_eq!(theta.len(), 0);
    }
}