1use gam_linalg::faer_ndarray::{FaerCholesky, fast_ata, fast_atb};
19use crate::warm_start_artifact::{
20 FitArtifact, FitDescriptor, RHO_SATURATION, TermIdentityKey, TransferProvenance,
21};
22use faer::Side;
23use ndarray::{Array1, Array2};
24
/// Largest absolute value tolerated in any projected coefficient.
///
/// `project_raw_beta_to_reduced` rejects (returns `None` for) a solution that
/// contains an entry whose magnitude exceeds this bound.
const PROJECTED_BETA_CLAMP: f64 = 1.0e6;
30
/// Description of one term of the new fit, as consumed by `build_warm_start`.
#[derive(Clone, Debug)]
pub struct TermBuildContext {
    /// Identity compared against each parent term's `identity` to find the
    /// matching parent term.
    pub identity: TermIdentityKey,
    /// Indices into the ρ vector that belong to this term. Parent ρ values are
    /// only copied when the parent term carries the same number of entries.
    pub rho_slots: Vec<usize>,
    /// Length of this term's coefficient block in the transfer result.
    pub reduced_width: usize,
    /// Optional matrix used to project the parent's raw coefficients onto this
    /// term's coefficients. The projection is only attempted when its row count
    /// equals the parent's raw coefficient count and its column count equals
    /// `reduced_width`.
    pub gauge_t_block: Option<Array2<f64>>,
}
55
/// Output of `build_warm_start`.
#[derive(Clone, Debug)]
pub struct TransferResult {
    /// Copy of the caller's default ρ vector with the transferable parent
    /// values written into the matched terms' slots.
    pub rho: Array1<f64>,
    /// One coefficient vector per new term (same order as the input terms);
    /// all zeros unless the parent coefficients were projected successfully.
    pub block_beta: Vec<Array1<f64>>,
    /// One provenance tag per new term (same order as the input terms).
    pub provenance: Vec<TransferProvenance>,
}
72
73fn project_raw_beta_to_reduced(
90 t_block: &Array2<f64>,
91 raw_beta_parent: &[f64],
92 reduced_width: usize,
93) -> Option<Array1<f64>> {
94 let (raw_rows, red_cols) = t_block.dim();
95 if red_cols != reduced_width || raw_rows != raw_beta_parent.len() {
96 return None;
97 }
98 if reduced_width == 0 {
99 return Some(Array1::zeros(0));
100 }
101 if raw_beta_parent.iter().any(|v| !v.is_finite()) || t_block.iter().any(|v| !v.is_finite()) {
102 return None;
103 }
104 let mut gram = fast_ata(t_block); let trace: f64 = (0..reduced_width).map(|i| gram[[i, i]]).sum();
108 let eps = (1.0e-8 * trace / (reduced_width as f64)).max(1.0e-12);
109 for i in 0..reduced_width {
110 gram[[i, i]] += eps;
111 }
112 let rhs_col = Array2::from_shape_vec((raw_rows, 1), raw_beta_parent.to_vec()).ok()?;
113 let rhs = fast_atb(t_block, &rhs_col); let rhs_vec = rhs.column(0).to_owned();
115 let factor = gram.cholesky(Side::Lower).ok()?;
116 let theta = factor.solvevec(&rhs_vec);
117 if theta.len() != reduced_width
118 || theta
119 .iter()
120 .any(|v| !v.is_finite() || v.abs() > PROJECTED_BETA_CLAMP)
121 {
122 return None;
123 }
124 Some(theta)
125}
126
/// Tunables for copying parent ρ values in `build_warm_start`.
#[derive(Clone, Copy, Debug)]
pub struct TransferConfig {
    /// A parent ρ whose absolute value is at or above this threshold is not
    /// copied; the slot keeps its default value.
    pub rho_saturation: f64,
    /// Copied ρ values are clamped to
    /// `[-rho_interior_clamp, rho_interior_clamp]`.
    pub rho_interior_clamp: f64,
}
138
139impl Default for TransferConfig {
140 fn default() -> Self {
141 Self {
142 rho_saturation: RHO_SATURATION,
143 rho_interior_clamp: RHO_SATURATION - 1.0,
147 }
148 }
149}
150
/// Reasons `build_warm_start` refuses to transfer anything from a parent fit.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum TransferError {
    /// The parent artifact reported itself as not usable (`is_usable()` was
    /// false).
    ParentUnusable,
    /// The parent's descriptor key differs from the new fit's descriptor key.
    DescriptorMismatch,
}
161
162pub fn build_warm_start(
183 new_descriptor: &FitDescriptor,
184 new_terms: &[TermBuildContext],
185 rho_default: &Array1<f64>,
186 parent: &FitArtifact,
187 cfg: TransferConfig,
188) -> Result<TransferResult, TransferError> {
189 if !parent.is_usable() {
191 return Err(TransferError::ParentUnusable);
192 }
193 if parent.descriptor.descriptor_key() != new_descriptor.descriptor_key() {
195 return Err(TransferError::DescriptorMismatch);
196 }
197
198 let mut rho = rho_default.clone();
199 let mut provenance = vec![TransferProvenance::Cold; new_terms.len()];
200 let mut block_beta: Vec<Array1<f64>> = new_terms
203 .iter()
204 .map(|t| Array1::<f64>::zeros(t.reduced_width))
205 .collect();
206
207 for (term_idx, new_term) in new_terms.iter().enumerate() {
208 let Some(parent_term) = parent
210 .terms
211 .iter()
212 .find(|p| p.identity == new_term.identity)
213 else {
214 continue;
216 };
217
218 let mut beta_projected = false;
222 if let Some(t_block) = new_term.gauge_t_block.as_ref()
223 && let Some(theta) =
224 project_raw_beta_to_reduced(t_block, &parent_term.raw_beta, new_term.reduced_width)
225 {
226 block_beta[term_idx] = theta;
227 beta_projected = true;
228 }
229
230 let mut copied_any = false;
234 if parent_term.rho_for_term.len() == new_term.rho_slots.len() {
235 for (slot, &parent_rho) in new_term
236 .rho_slots
237 .iter()
238 .zip(parent_term.rho_for_term.iter())
239 {
240 if *slot >= rho.len() {
241 continue;
243 }
244 if !parent_rho.is_finite() {
245 continue;
246 }
247 if parent_rho.abs() >= cfg.rho_saturation {
250 continue;
251 }
252 let clamped = parent_rho.clamp(-cfg.rho_interior_clamp, cfg.rho_interior_clamp);
254 rho[*slot] = clamped;
255 copied_any = true;
256 }
257 }
258
259 provenance[term_idx] = if beta_projected {
260 TransferProvenance::Projected
261 } else if copied_any {
262 TransferProvenance::RhoOnly
263 } else {
264 TransferProvenance::Cold
265 };
266 }
267
268 Ok(TransferResult {
269 rho,
270 block_beta,
271 provenance,
272 })
273}
274
#[cfg(test)]
mod tests {
    use super::*;
    use crate::warm_start_artifact::{
        FIT_ARTIFACT_SCHEMA, GlobalFitSummary, ResponseSig, SerializableBasisMeta, TermArtifact,
        TermRole, term_identity_from_block,
    };
    use ndarray::{Array1, Array2};

    // ---- fixtures -------------------------------------------------------

    /// Identity key of a `TermRole::Mean` block named `block_name`.
    fn block_id(block_name: &str) -> TermIdentityKey {
        term_identity_from_block(TermRole::Mean, block_name, &[None], &[1], 10)
    }

    /// New-fit term with no projection matrix and zero reduced width, so only
    /// ρ can be transferred.
    fn rho_only_ctx(identity: TermIdentityKey, rho_slots: Vec<usize>) -> TermBuildContext {
        TermBuildContext {
            gauge_t_block: None,
            reduced_width: 0,
            rho_slots,
            identity,
        }
    }

    /// New-fit term that carries a projection matrix for β transfer.
    fn beta_ctx(
        identity: TermIdentityKey,
        rho_slots: Vec<usize>,
        reduced_width: usize,
        t_block: Array2<f64>,
    ) -> TermBuildContext {
        TermBuildContext {
            reduced_width,
            gauge_t_block: Some(t_block),
            ..rho_only_ctx(identity, rho_slots)
        }
    }

    /// Placeholder basis metadata for parent term artifacts.
    fn basis_meta_stub() -> SerializableBasisMeta {
        SerializableBasisMeta {
            kind: "block-spec".to_string(),
            n_centers: Some(5),
            degree: None,
            num_knots: None,
            nullspace_order: None,
            matern_nu: None,
            periodic: false,
        }
    }

    /// Single-term gaussian descriptor for `identity`.
    fn new_descriptor(identity: TermIdentityKey) -> FitDescriptor {
        let response_signature = ResponseSig {
            family_kind: "gaussian".to_string(),
            n_response_channels: 1,
        };
        FitDescriptor {
            family_kind: "gaussian".to_string(),
            term_identities: vec![identity],
            response_signature,
            row_population: None,
        }
    }

    /// Converged single-term parent artifact with explicit raw β and ρ.
    fn parent_with_raw_beta(
        identity: TermIdentityKey,
        raw_beta: Vec<f64>,
        rho_for_term: Vec<f64>,
    ) -> FitArtifact {
        let term = TermArtifact {
            identity,
            role: TermRole::Mean,
            basis_meta: basis_meta_stub(),
            joint_null_rotation: None,
            raw_beta,
            rho_for_term,
        };
        FitArtifact {
            schema: FIT_ARTIFACT_SCHEMA,
            created_unix_secs: 0,
            descriptor: new_descriptor(identity),
            terms: vec![term],
            global: GlobalFitSummary {
                outer_objective: -10.0,
                converged: true,
                n_rows: 1000,
            },
        }
    }

    /// Parent artifact whose raw β is five zeros.
    fn parent_with(identity: TermIdentityKey, rho_for_term: Vec<f64>) -> FitArtifact {
        parent_with_raw_beta(identity, vec![0.0; 5], rho_for_term)
    }

    /// Runs `build_warm_start` with a single-term descriptor for `identity`
    /// and the default `TransferConfig`.
    fn transfer(
        identity: TermIdentityKey,
        new_terms: &[TermBuildContext],
        rho_default: Vec<f64>,
        parent: &FitArtifact,
    ) -> Result<TransferResult, TransferError> {
        build_warm_start(
            &new_descriptor(identity),
            new_terms,
            &Array1::from(rho_default),
            parent,
            TransferConfig::default(),
        )
    }

    // ---- ρ transfer -----------------------------------------------------

    #[test]
    fn matched_term_copies_parent_rho() {
        let id = block_id("s(x)");
        let res = transfer(
            id,
            &[rho_only_ctx(id, vec![0])],
            vec![0.0],
            &parent_with(id, vec![2.5]),
        )
        .expect("transfer builds");
        assert_eq!(res.rho[0], 2.5, "matched term must inherit parent ρ");
        assert_eq!(res.provenance[0], TransferProvenance::RhoOnly);
    }

    #[test]
    fn unmatched_term_keeps_default() {
        let parent_id = block_id("s(x)");
        let new_id = block_id("s(z)");
        // The descriptor is built from `new_id` so the descriptor keys agree;
        // only the stored term's identity differs.
        let mut parent = parent_with(new_id, vec![2.5]);
        parent.terms[0].identity = parent_id;
        let res = transfer(
            new_id,
            &[rho_only_ctx(new_id, vec![0])],
            vec![-1.3],
            &parent,
        )
        .expect("transfer builds");
        assert_eq!(res.rho[0], -1.3, "unmatched term keeps the new default ρ");
        assert_eq!(res.provenance[0], TransferProvenance::Cold);
    }

    #[test]
    fn saturated_parent_rho_not_copied() {
        let id = block_id("s(x)");
        let res = transfer(
            id,
            &[rho_only_ctx(id, vec![0])],
            vec![0.7],
            &parent_with(id, vec![12.0]),
        )
        .expect("transfer builds");
        assert_eq!(res.rho[0], 0.7, "saturated parent ρ must not be copied");
        assert_eq!(res.provenance[0], TransferProvenance::Cold);
    }

    #[test]
    fn near_box_parent_rho_is_interior_clamped() {
        let id = block_id("s(x)");
        let interior = TransferConfig::default().rho_interior_clamp;
        let res = transfer(
            id,
            &[rho_only_ctx(id, vec![0])],
            vec![0.0],
            &parent_with(id, vec![8.7]),
        )
        .expect("transfer builds");
        assert!(res.rho[0] <= interior);
        assert_eq!(res.rho[0], interior);
        assert_eq!(res.provenance[0], TransferProvenance::RhoOnly);
    }

    #[test]
    fn nonfinite_parent_is_rejected() {
        let id = block_id("s(x)");
        let parent = parent_with_raw_beta(id, vec![f64::NAN, 0.0, 0.0, 0.0, 0.0], vec![2.0]);
        let err = transfer(id, &[rho_only_ctx(id, vec![0])], vec![0.42], &parent).unwrap_err();
        assert_eq!(err, TransferError::ParentUnusable);
    }

    #[test]
    fn rho_only_transfer_leaves_unrelated_slots_at_default() {
        let id = block_id("s(x)");
        let res = transfer(
            id,
            &[rho_only_ctx(id, vec![0])],
            vec![0.0, -2.0],
            &parent_with(id, vec![3.3]),
        )
        .expect("transfer builds");
        assert_eq!(res.rho[0], 3.3, "matched slot warm-starts");
        assert_eq!(res.rho[1], -2.0, "unrelated slot keeps the default");
    }

    #[test]
    fn descriptor_mismatch_rejected() {
        let id_a = block_id("s(x)");
        let id_b = block_id("s(z)");
        let err = transfer(
            id_b,
            &[rho_only_ctx(id_b, vec![0])],
            vec![0.0],
            &parent_with(id_a, vec![2.0]),
        )
        .unwrap_err();
        assert_eq!(err, TransferError::DescriptorMismatch);
    }

    // ---- β projection ---------------------------------------------------

    #[test]
    fn beta_projects_to_reduced_width() {
        let id = block_id("s(x)");
        let raw = vec![1.0, -2.0, 3.5];
        let parent = parent_with_raw_beta(id, raw.clone(), vec![1.0]);
        let res = transfer(
            id,
            &[beta_ctx(id, vec![0], 3, Array2::<f64>::eye(3))],
            vec![0.0],
            &parent,
        )
        .expect("transfer builds");
        let beta = &res.block_beta[0];
        assert_eq!(beta.len(), 3, "β must be at the reduced width");
        for (got, want) in beta.iter().zip(raw.iter()) {
            assert!((got - want).abs() < 1e-6, "identity projection ≈ parent β");
        }
        assert_eq!(res.provenance[0], TransferProvenance::Projected);
    }

    #[test]
    fn cross_width_loso_case_transfers_beta() {
        let id = block_id("s(x)");
        let parent = parent_with_raw_beta(id, vec![0.5, 0.5, 1.0, -1.0], vec![1.0]);
        // 4 raw coefficients mapped onto 2 reduced coefficients.
        let t =
            Array2::from_shape_vec((4, 2), vec![1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0]).unwrap();
        let res = transfer(id, &[beta_ctx(id, vec![0], 2, t)], vec![0.0], &parent)
            .expect("transfer builds");
        assert_eq!(
            res.block_beta[0].len(),
            2,
            "cross-width LOSO must project to the new reduced width, not skip"
        );
        assert!(res.block_beta[0].iter().all(|v| v.is_finite()));
        assert_eq!(res.provenance[0], TransferProvenance::Projected);
    }

    #[test]
    fn beta_dimension_anomaly_falls_back_to_cold() {
        let id = block_id("s(x)");
        // Five raw coefficients cannot be projected through a 3×3 matrix.
        let parent = parent_with_raw_beta(id, vec![1.0, 2.0, 3.0, 4.0, 5.0], vec![1.0]);
        let res = transfer(
            id,
            &[beta_ctx(id, vec![0], 3, Array2::<f64>::eye(3))],
            vec![0.0],
            &parent,
        )
        .expect("transfer builds");
        assert_eq!(
            res.block_beta[0].len(),
            3,
            "cold β still at the reduced width"
        );
        assert!(
            res.block_beta[0].iter().all(|&v| v == 0.0),
            "dimension anomaly must yield cold zeros"
        );
        // ρ is still copied, so the term is not fully cold.
        assert_eq!(res.provenance[0], TransferProvenance::RhoOnly);
    }

    #[test]
    fn beta_nonfinite_parent_is_globally_rejected() {
        let id = block_id("s(x)");
        let parent = parent_with_raw_beta(id, vec![1.0, f64::NAN, 3.0], vec![1.0]);
        let err = transfer(
            id,
            &[beta_ctx(id, vec![0], 3, Array2::<f64>::eye(3))],
            vec![0.0],
            &parent,
        )
        .unwrap_err();
        assert_eq!(err, TransferError::ParentUnusable);
    }

    #[test]
    fn projection_helper_identity_is_exact() {
        let raw = vec![2.0, -1.0, 0.0, 4.0];
        let identity = Array2::<f64>::eye(4);
        let theta = project_raw_beta_to_reduced(&identity, &raw, 4).expect("projects");
        for (g, w) in theta.iter().zip(raw.iter()) {
            assert!((g - w).abs() < 1e-7);
        }
    }
}