use crate::solver::warm_start_artifact::{
FitArtifact, FitDescriptor, RHO_SATURATION, TermIdentityKey, TransferProvenance,
};
use ndarray::Array1;
/// Describes one term of the fit being built, as seen by the warm-start transfer.
///
/// `build_warm_start` searches the parent artifact for a term whose identity
/// equals `identity`; when one is found (and its ρ count matches
/// `rho_slots.len()`), the parent's ρ values are written into the positions
/// listed in `rho_slots`.
#[derive(Clone, Debug)]
pub struct TermBuildContext {
/// Key compared for equality against each parent term's identity.
pub identity: TermIdentityKey,
/// Indices into the ρ vector that belong to this term. They are paired
/// positionally with the matching parent term's `rho_for_term` entries;
/// indices at or beyond the ρ vector's length are skipped by the transfer.
pub rho_slots: Vec<usize>,
}
/// Successful outcome of `build_warm_start`.
#[derive(Clone, Debug)]
pub struct TransferResult {
/// Starting ρ vector: a clone of the caller's defaults in which every slot
/// that received a parent value has been overwritten.
pub rho: Array1<f64>,
/// One entry per element of the `new_terms` slice, in the same order,
/// recording whether that term received any parent ρ value.
pub provenance: Vec<TransferProvenance>,
}
/// Thresholds that control which parent ρ values are transferred and how.
#[derive(Clone, Copy, Debug)]
pub struct TransferConfig {
/// Parent values with `|ρ| >= rho_saturation` are not copied.
pub rho_saturation: f64,
/// Copied values are clamped to `[-rho_interior_clamp, rho_interior_clamp]`.
/// NOTE(review): the clamp uses `f64::clamp`, which panics when the lower
/// bound exceeds the upper bound or either is NaN, so this is presumably
/// expected to be a non-negative, non-NaN number — confirm with callers.
pub rho_interior_clamp: f64,
}
impl Default for TransferConfig {
fn default() -> Self {
Self {
rho_saturation: RHO_SATURATION,
rho_interior_clamp: RHO_SATURATION - 1.0,
}
}
}
/// Reasons a warm-start transfer is refused before any ρ value is copied.
///
/// Implements [`std::fmt::Display`] and [`std::error::Error`] so it can be
/// propagated with `?` into boxed or otherwise generic error types.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum TransferError {
    /// The parent artifact's `is_usable()` check returned `false`.
    ParentUnusable,
    /// The parent's descriptor key differs from the new fit's descriptor key.
    DescriptorMismatch,
}

impl std::fmt::Display for TransferError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive on purpose: a new variant must get its own message.
        let message = match self {
            Self::ParentUnusable => "parent fit artifact is not usable for warm start",
            Self::DescriptorMismatch => {
                "parent fit descriptor does not match the new fit descriptor"
            }
        };
        f.write_str(message)
    }
}

impl std::error::Error for TransferError {}
pub fn build_warm_start(
new_descriptor: &FitDescriptor,
new_terms: &[TermBuildContext],
rho_default: &Array1<f64>,
parent: &FitArtifact,
cfg: TransferConfig,
) -> Result<TransferResult, TransferError> {
if !parent.is_usable() {
return Err(TransferError::ParentUnusable);
}
if parent.descriptor.descriptor_key() != new_descriptor.descriptor_key() {
return Err(TransferError::DescriptorMismatch);
}
let mut rho = rho_default.clone();
let mut provenance = vec![TransferProvenance::Cold; new_terms.len()];
for (term_idx, new_term) in new_terms.iter().enumerate() {
let Some(parent_term) = parent
.terms
.iter()
.find(|p| p.identity == new_term.identity)
else {
continue;
};
if parent_term.rho_for_term.len() != new_term.rho_slots.len() {
continue;
}
let mut copied_any = false;
for (slot, &parent_rho) in new_term
.rho_slots
.iter()
.zip(parent_term.rho_for_term.iter())
{
if *slot >= rho.len() {
continue;
}
if !parent_rho.is_finite() {
continue;
}
if parent_rho.abs() >= cfg.rho_saturation {
continue;
}
let clamped = parent_rho.clamp(-cfg.rho_interior_clamp, cfg.rho_interior_clamp);
rho[*slot] = clamped;
copied_any = true;
}
provenance[term_idx] = if copied_any {
TransferProvenance::RhoOnly
} else {
TransferProvenance::Cold
};
}
Ok(TransferResult { rho, provenance })
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::solver::warm_start_artifact::{
        FIT_ARTIFACT_SCHEMA, GlobalFitSummary, ResponseSig, SerializableBasisMeta, TermArtifact,
        TermRole, term_identity_from_block,
    };
    use ndarray::Array1;

    // ---------------------------------------------------------------------
    // Fixtures
    // ---------------------------------------------------------------------

    /// Identity key of a mean-role block called `block_name`.
    fn block_id(block_name: &str) -> TermIdentityKey {
        term_identity_from_block(TermRole::Mean, block_name, &[None], &[1])
    }

    /// Minimal basis metadata; only needed to fill the `TermArtifact` field.
    fn basis_meta_stub() -> SerializableBasisMeta {
        SerializableBasisMeta {
            kind: "block-spec".to_string(),
            degree: None,
            num_knots: None,
            n_centers: Some(5),
            nullspace_order: None,
            matern_nu: None,
            periodic: false,
        }
    }

    /// Single-term gaussian descriptor built around `identity`.
    fn new_descriptor(identity: TermIdentityKey) -> FitDescriptor {
        FitDescriptor {
            family_kind: "gaussian".to_string(),
            term_identities: vec![identity],
            response_signature: ResponseSig {
                family_kind: "gaussian".to_string(),
                n_response_channels: 1,
            },
            row_population: None,
        }
    }

    /// Parent artifact holding one term with the given ρ values.
    fn parent_with(identity: TermIdentityKey, rho_for_term: Vec<f64>) -> FitArtifact {
        let only_term = TermArtifact {
            identity,
            role: TermRole::Mean,
            basis_meta: basis_meta_stub(),
            joint_null_rotation: None,
            raw_beta: vec![0.0; 5],
            rho_for_term,
        };
        FitArtifact {
            schema: FIT_ARTIFACT_SCHEMA,
            created_unix_secs: 0,
            descriptor: new_descriptor(identity),
            terms: vec![only_term],
            global: GlobalFitSummary {
                outer_objective: -10.0,
                converged: true,
                n_rows: 1000,
            },
        }
    }

    /// New-fit term list with a single term that owns ρ slot 0.
    fn single_slot_term(identity: TermIdentityKey) -> Vec<TermBuildContext> {
        vec![TermBuildContext {
            identity,
            rho_slots: vec![0],
        }]
    }

    /// Runs `build_warm_start` for a one-term new fit keyed by `target`, using
    /// `defaults` as the default ρ vector and the default transfer config.
    fn run_transfer(
        target: TermIdentityKey,
        defaults: &[f64],
        parent: &FitArtifact,
    ) -> Result<TransferResult, TransferError> {
        let new_terms = single_slot_term(target);
        let rho_default = Array1::from_vec(defaults.to_vec());
        build_warm_start(
            &new_descriptor(target),
            &new_terms,
            &rho_default,
            parent,
            TransferConfig::default(),
        )
    }

    // ---------------------------------------------------------------------
    // Tests
    // ---------------------------------------------------------------------

    #[test]
    fn matched_term_copies_parent_rho() {
        let id = block_id("s(x)");
        let parent = parent_with(id, vec![2.5]);

        let res = run_transfer(id, &[0.0], &parent).expect("transfer builds");

        assert_eq!(res.rho[0], 2.5, "matched term must inherit parent ρ");
        assert_eq!(res.provenance[0], TransferProvenance::RhoOnly);
    }

    #[test]
    fn unmatched_term_keeps_default() {
        let parent_id = block_id("s(x)");
        let new_id = block_id("s(z)");
        // The descriptor is built for `new_id` so the descriptor keys agree;
        // only the stored term's identity is swapped to force a lookup miss.
        let mut parent = parent_with(new_id, vec![2.5]);
        parent.terms[0].identity = parent_id;

        let res = run_transfer(new_id, &[-1.3], &parent).expect("transfer builds");

        assert_eq!(res.rho[0], -1.3, "unmatched term keeps the new default ρ");
        assert_eq!(res.provenance[0], TransferProvenance::Cold);
    }

    #[test]
    fn saturated_parent_rho_not_copied() {
        let id = block_id("s(x)");
        let parent = parent_with(id, vec![12.0]);

        let res = run_transfer(id, &[0.7], &parent).expect("transfer builds");

        assert_eq!(res.rho[0], 0.7, "saturated parent ρ must not be copied");
        assert_eq!(res.provenance[0], TransferProvenance::Cold);
    }

    #[test]
    fn near_box_parent_rho_is_interior_clamped() {
        let id = block_id("s(x)");
        let parent = parent_with(id, vec![8.7]);
        let cfg = TransferConfig::default();

        let res = run_transfer(id, &[0.0], &parent).expect("transfer builds");

        assert!(res.rho[0] <= cfg.rho_interior_clamp);
        assert_eq!(res.rho[0], cfg.rho_interior_clamp);
        assert_eq!(res.provenance[0], TransferProvenance::RhoOnly);
    }

    #[test]
    fn nonfinite_parent_is_rejected() {
        let id = block_id("s(x)");
        let mut parent = parent_with(id, vec![2.0]);
        parent.terms[0].raw_beta[0] = f64::NAN;

        let err = run_transfer(id, &[0.42], &parent).unwrap_err();

        assert_eq!(err, TransferError::ParentUnusable);
    }

    #[test]
    fn rho_only_transfer_leaves_unrelated_slots_at_default() {
        let id = block_id("s(x)");
        let parent = parent_with(id, vec![3.3]);

        let res = run_transfer(id, &[0.0, -2.0], &parent).expect("transfer builds");

        assert_eq!(res.rho[0], 3.3, "matched slot warm-starts");
        assert_eq!(res.rho[1], -2.0, "unrelated slot keeps the default");
    }

    #[test]
    fn descriptor_mismatch_rejected() {
        let id_a = block_id("s(x)");
        let id_b = block_id("s(z)");
        let parent = parent_with(id_a, vec![2.0]);

        let err = run_transfer(id_b, &[0.0], &parent).unwrap_err();

        assert_eq!(err, TransferError::DescriptorMismatch);
    }
}