1use super::*;
2
/// Cached geometry for the exact joint transformation fit.
///
/// `fit_transformation_normal` builds one of these for a frozen covariate spec
/// and keeps reusing it for as long as the spatial geometry key is unchanged.
#[derive(Clone)]
pub(crate) struct TransformationExactGeometryCache {
    /// Key produced by `transformation_spatial_geometry_key`; a differing key
    /// on a later call triggers a rebuild of the whole cache.
    pub(crate) key: Vec<u64>,
    /// Frozen covariate spec that `covariate_design` was built from.
    pub(crate) covariate_spec_resolved: TermCollectionSpec,
    /// Covariate design built from `covariate_spec_resolved`.
    pub(crate) covariate_design: TermCollectionDesign,
    /// Transformation family constructed on top of `covariate_design`.
    pub(crate) family: TransformationNormalFamily,
    /// Parameter block specs handed to the fitting routines. The first entry's
    /// `initial_log_lambdas` is rewritten by `update_initial_log_lambdas`.
    pub(crate) blocks: Vec<ParameterBlockSpec>,
    /// Psi-derivative lists, one inner `Vec` per block, passed to the joint
    /// hyper-parameter evaluators.
    pub(crate) derivative_blocks: Vec<Vec<CustomFamilyBlockPsiDerivative>>,
}
12
/// Warm start remembered between evaluations of the exact joint objective,
/// together with the hyper-parameter point at which it was produced.
#[derive(Clone)]
pub(crate) struct TransformationExactWarmStart {
    /// Outer hyper-parameter vector at which `warm_start` was stored; used by
    /// `is_compatible_with` to measure how far a new point has moved.
    pub(crate) theta: Array1<f64>,
    /// Warm-start payload returned by the evaluation at `theta`.
    pub(crate) warm_start: CustomFamilyWarmStart,
}
18
19impl TransformationExactWarmStart {
20 pub(crate) fn is_compatible_with(&self, theta: &Array1<f64>, rho: &Array1<f64>) -> bool {
21 const MAX_THETA_DISTANCE: f64 = 1.5;
22
23 self.theta.len() == theta.len()
24 && self
25 .theta
26 .iter()
27 .zip(theta.iter())
28 .all(|(&a, &b)| (a - b).abs() <= MAX_THETA_DISTANCE)
29 && self.warm_start.compatible_with_rho(rho)
30 }
31}
32
33impl TransformationExactGeometryCache {
34 pub(crate) fn update_initial_log_lambdas(
35 &mut self,
36 log_lambdas: &Array1<f64>,
37 ) -> Result<(), String> {
38 let spec = self
39 .blocks
40 .first_mut()
41 .ok_or_else(|| "missing transformation block spec".to_string())?;
42 if log_lambdas.len() != spec.initial_log_lambdas.len() {
43 return Err(TransformationNormalError::InvalidInput {
44 reason: format!(
45 "transformation final fit rho length mismatch: got {}, expected {}",
46 log_lambdas.len(),
47 spec.initial_log_lambdas.len()
48 ),
49 }
50 .into());
51 }
52 spec.initial_log_lambdas = log_lambdas.clone();
53 Ok(())
54 }
55}
56
57pub(crate) fn transformation_spatial_geometry_key(
58 spec: &TermCollectionSpec,
59 spatial_terms: &[usize],
60) -> Result<Vec<u64>, String> {
61 let mut key = Vec::new();
62 key.push(spatial_terms.len() as u64);
63 for &term_idx in spatial_terms {
64 let term = spec.smooth_terms.get(term_idx).ok_or_else(|| {
65 format!(
66 "transformation spatial geometry key term index {term_idx} out of range for {} smooth terms",
67 spec.smooth_terms.len()
68 )
69 })?;
70 key.push(term_idx as u64);
71
72 let payload = serde_json::to_vec(term).map_err(|err| {
80 format!("failed to serialize transformation spatial geometry term {term_idx}: {err}")
81 })?;
82 key.push(payload.len() as u64);
83 for chunk in payload.chunks(8) {
84 let mut bytes = [0u8; 8];
85 for (dst, src) in bytes.iter_mut().zip(chunk.iter().copied()) {
86 *dst = src;
87 }
88 key.push(u64::from_le_bytes(bytes));
89 }
90 }
91 Ok(key)
92}
93
/// Everything `fit_transformation_normal` hands back to its caller.
#[derive(Clone)]
pub struct TransformationNormalFitResult {
    /// Family instance the final fit was computed with.
    pub family: TransformationNormalFamily,
    /// Fit result, as returned by `calibrate_transformation_scores`.
    pub fit: UnifiedFitResult,
    /// Frozen covariate spec matching `covariate_design`.
    pub covariate_spec_resolved: TermCollectionSpec,
    /// Covariate design the family was built on.
    pub covariate_design: TermCollectionDesign,
    /// Score calibration returned alongside the calibrated fit.
    pub score_calibration: TransformationScoreCalibration,
}
107
/// Fits a transformation-normal model, optionally optimizing spatial length
/// scales jointly with the smoothing parameters.
///
/// Two paths are visible in the body:
/// * When the covariate spec has no spatial length-scale terms, or
///   `kappa_options.enabled` is false, the family is built once on the
///   bootstrap covariate design and fitted with `fit_custom_family`.
/// * Otherwise the work is delegated to
///   `optimize_spatial_length_scale_exact_joint`, which is driven by closures
///   (fixed-rho fit, joint objective evaluation, EFS evaluation) that share a
///   geometry cache and a warm-start slot.
///
/// In both paths the returned fit has been passed through
/// `calibrate_transformation_scores`.
///
/// # Errors
/// Returns `Err(String)` when a covariate design cannot be built or frozen,
/// when the response basis or the family cannot be constructed, when an inner
/// fit or evaluation fails or does not converge, when analytic spatial psi
/// derivatives are unavailable on the spatial path, or when score calibration
/// fails.
///
/// # Panics
/// The fixed-rho fit closure panics if locking the family's
/// `row_quantity_cache` fails (the `expect` message attributes this to a
/// poisoned mutex).
pub fn fit_transformation_normal(
    response: &Array1<f64>,
    weights: &Array1<f64>,
    offset: &Array1<f64>,
    covariate_data: ArrayView2<'_, f64>,
    covariate_spec: &TermCollectionSpec,
    config: &TransformationNormalConfig,
    options: &BlockwiseFitOptions,
    kappa_options: &SpatialLengthScaleOptimizationOptions,
    warm_start: Option<&TransformationWarmStart>,
) -> Result<TransformationNormalFitResult, String> {
    // Work on a private copy so the inner-cycle cap below never leaks back to
    // the caller's options.
    let mut options = options.clone();
    let covariate_spec = covariate_spec.clone();

    // Bootstrap geometry: build a design from the caller's spec, then freeze
    // the spec against that design.
    let boot_design = build_term_collection_design(covariate_data, &covariate_spec)
        .map_err(|e| format!("failed to build bootstrap covariate design: {e}"))?;
    let boot_spec = freeze_term_collection_from_design(&covariate_spec, &boot_design)
        .map_err(|e| format!("failed to freeze bootstrap covariate spatial basis centers: {e}"))?;
    let mut effective_config = config.clone();
    // Unless the caller pinned the knot count, derive it from the config, the
    // sample size, the covariate design width and the response itself.
    if !config.response_num_internal_knots_pinned {
        effective_config.response_num_internal_knots = effective_response_num_internal_knots(
            config,
            response.len(),
            boot_design.design.ncols(),
            response.view(),
        );
    }

    let (resp_val, resp_deriv, resp_penalties, resp_knots, resp_transform) =
        build_response_basis(response, &effective_config)?;

    // Cap the inner cycle budget at base + per-dimension allowance, bounded by
    // a ceiling. `min` means the caller's limit can only be lowered.
    let realized_p_total = resp_val.ncols().saturating_mul(boot_design.design.ncols());
    let ctn_inner_cap = CTN_INNER_MAX_CYCLES_BASE
        .saturating_add(realized_p_total.saturating_mul(CTN_INNER_MAX_CYCLES_PER_DIM))
        .min(CTN_INNER_MAX_CYCLES_CEILING);
    options.inner_max_cycles = options.inner_max_cycles.min(ctn_inner_cap);

    let spatial_terms = spatial_length_scale_term_indices(&covariate_spec);

    // Non-spatial path: nothing to optimize over kappa, so fit once on the
    // bootstrap design and return.
    if spatial_terms.is_empty() || !kappa_options.enabled {
        let cov_design = boot_design;
        let cov_spec_resolved = boot_spec;

        let family = TransformationNormalFamily::from_prebuilt_response_basis(
            response,
            resp_val,
            resp_deriv,
            resp_penalties,
            resp_knots.clone(),
            effective_config.response_degree,
            resp_transform,
            weights,
            offset,
            cov_design.design.clone(),
            cov_design
                .penalties
                .iter()
                .map(|bp| bp.to_penalty_matrix(cov_design.design.ncols()))
                .collect(),
            &effective_config,
            warm_start,
        )?;
        let blocks = vec![family.block_spec()];
        let fit = fit_custom_family(&family, &blocks, &options)
            .map_err(|e| format!("transformation fit failed: {e}"))?;
        let (fit, score_calibration) = calibrate_transformation_scores(&family, fit)?;

        return Ok(TransformationNormalFitResult {
            family,
            fit,
            covariate_spec_resolved: cov_spec_resolved,
            covariate_design: cov_design,
            score_calibration,
        });
    }

    // Spatial path. Initial log-kappa coordinates come from the spec's length
    // scales, are reseeded from the data, and are finally clamped into the
    // data-derived bounds computed below.
    let kappa0 = SpatialLogKappaCoords::from_length_scales_aniso(
        &covariate_spec,
        &spatial_terms,
        kappa_options,
    )
    .reseed_from_data(
        covariate_data,
        &covariate_spec,
        &spatial_terms,
        kappa_options,
    );
    let kappa_dims = kappa0.dims_per_term().to_vec();
    let kappa_lower = SpatialLogKappaCoords::lower_bounds_aniso_from_data(
        covariate_data,
        &covariate_spec,
        &spatial_terms,
        &kappa_dims,
        kappa_options,
    );
    let kappa_upper = SpatialLogKappaCoords::upper_bounds_aniso_from_data(
        covariate_data,
        &covariate_spec,
        &spatial_terms,
        &kappa_dims,
        kappa_options,
    );
    let kappa0 = kappa0.clamp_to_bounds(&kappa_lower, &kappa_upper);

    // Record whether psi derivatives can be built for the bootstrap geometry;
    // the hard requirement is enforced further down, after the setup logging.
    let analytic_psi_available =
        build_block_spatial_psi_derivatives(covariate_data, &boot_spec, &boot_design)?.is_some();

    // Probe design/family on the frozen spec. Below they supply the rho seed
    // and dimension, the outer-derivative capability query and the outer
    // derivative policy.
    let probe_design = build_term_collection_design(covariate_data, &boot_spec)
        .map_err(|e| format!("failed to rebuild frozen probe covariate design: {e}"))?;

    let probe_family = TransformationNormalFamily::from_prebuilt_response_basis(
        response,
        resp_val.clone(),
        resp_deriv.clone(),
        resp_penalties.clone(),
        resp_knots.clone(),
        effective_config.response_degree,
        resp_transform.clone(),
        weights,
        offset,
        probe_design.design.clone(),
        probe_design
            .penalties
            .iter()
            .map(|bp| bp.to_penalty_matrix(probe_design.design.ncols()))
            .collect(),
        &effective_config,
        warm_start,
    )?;
    let probe_block = probe_family.block_spec();
    let n_penalties = probe_block.initial_log_lambdas.len();
    log::info!(
        "[transformation-normal] exact joint setup: rho_dim={} log_kappa_dim={} dims_per_term={:?}",
        n_penalties,
        kappa0.len(),
        kappa_dims,
    );
    // Rho starts at the probe block's seed and is boxed into [-12, 12].
    let rho0 = probe_block.initial_log_lambdas.clone();
    let rho_floor = -12.0;
    let rho_lower = Array1::<f64>::from_elem(n_penalties, rho_floor);
    let rho_upper = Array1::<f64>::from_elem(n_penalties, 12.0);
    let probe_blocks = vec![probe_block.clone()];
    let (_, cap_hessian) = crate::custom_family::custom_family_outer_derivatives(
        &probe_family,
        &probe_blocks,
        &options,
    );
    let analytic_gradient = analytic_psi_available;
    let analytic_hessian_supported = analytic_gradient && cap_hessian.is_analytic();
    // The analytic outer Hessian is deliberately switched off here even when
    // supported; the log message below states the reason.
    let analytic_hessian = false;
    if analytic_hessian_supported {
        log::info!(
            "[transformation-normal] CTN exact joint analytic outer Hessian is available but disabled for spatial kappa optimization; using analytic-gradient outer solves to avoid callback logdet trace work"
        );
    }

    // Range of the rho seed, used for logging only; (0, 0) when there are no
    // penalties.
    let (rho0_min, rho0_max) = if rho0.is_empty() {
        (0.0, 0.0)
    } else {
        (
            rho0.iter().copied().fold(f64::INFINITY, f64::min),
            rho0.iter().copied().fold(f64::NEG_INFINITY, f64::max),
        )
    };
    log::info!(
        "[transformation-normal] skipping baseline custom-family prefit before exact joint optimization \
         (rho_dim={}, log_kappa_dim={}, rho0_range=[{:.3}, {:.3}]); using CTN warm start and penalty-scale rho seed",
        n_penalties,
        kappa0.len(),
        rho0_min,
        rho0_max,
    );

    // The joint optimization below cannot proceed without analytic psi
    // derivatives.
    if !analytic_psi_available {
        return Err(
            "transformation-normal spatial length-scale optimization requires analytic spatial psi derivatives"
                .to_string(),
        );
    }

    // Warm-start slot shared by the evaluation closures.
    let exact_warm_start: RefCell<Option<TransformationExactWarmStart>> = RefCell::new(None);

    let joint_setup =
        ExactJointHyperSetup::new(rho0, rho_lower, rho_upper, kappa0, kappa_lower, kappa_upper);

    // Owned copies of the response-basis pieces, captured by `make_family`.
    let rv = resp_val.clone();
    let rd = resp_deriv.clone();
    let rp = resp_penalties.clone();
    let rk = resp_knots.clone();
    let rt = resp_transform.clone();
    let rdeg = effective_config.response_degree;
    let cfg = effective_config.clone();
    let ws = warm_start.cloned();

    // Builds a family for a given covariate design, reusing the prebuilt
    // response basis.
    let make_family =
        |cov_design: &TermCollectionDesign| -> Result<TransformationNormalFamily, String> {
            TransformationNormalFamily::from_prebuilt_response_basis(
                response,
                rv.clone(),
                rd.clone(),
                rp.clone(),
                rk.clone(),
                rdeg,
                rt.clone(),
                weights,
                offset,
                cov_design.design.clone(),
                cov_design
                    .penalties
                    .iter()
                    .map(|bp| bp.to_penalty_matrix(cov_design.design.ncols()))
                    .collect(),
                &cfg,
                ws.as_ref(),
            )
        };

    // A single block is optimized: the frozen bootstrap spec with its spatial
    // term indices.
    let block_specs_slice = [boot_spec.clone()];
    let block_term_indices_slice = [spatial_terms.clone()];
    let exact_geometry_cache: RefCell<Option<TransformationExactGeometryCache>> =
        RefCell::new(None);
    let spatial_terms_for_cache = spatial_terms.clone();

    // Makes sure the geometry cache matches `spec`/`design`: freezes the spec,
    // computes its geometry key, and rebuilds the design, family and psi
    // derivatives only when the key differs from the cached one.
    let ensure_exact_geometry = |spec: &TermCollectionSpec,
                                 design: &TermCollectionDesign|
     -> Result<(), String> {
        let effective_spec = freeze_term_collection_from_design(spec, design)
            .map_err(|e| format!("failed to freeze transformation geometry key: {e}"))?;
        let key = transformation_spatial_geometry_key(&effective_spec, &spatial_terms_for_cache)?;
        let needs_rebuild = exact_geometry_cache
            .borrow()
            .as_ref()
            .map(|cached| cached.key != key)
            .unwrap_or(true);
        if !needs_rebuild {
            return Ok(());
        }

        let geom_start = std::time::Instant::now();
        let exact_design = build_term_collection_design(covariate_data, &effective_spec)
            .map_err(|e| format!("failed to rebuild frozen transformation geometry: {e}"))?;
        let family = make_family(&exact_design)?;
        let cov_psi_derivs =
            build_block_spatial_psi_derivatives(covariate_data, &effective_spec, &exact_design)?
                .ok_or_else(|| {
                    "missing covariate spatial psi derivatives for transformation model".to_string()
                })?;
        let tensor_derivs = build_tensor_psi_derivatives(&family, &cov_psi_derivs)?;

        log::debug!(
            "[transformation-normal] rebuilt exact geometry cache for {} spatial terms in {:.3}s",
            spatial_terms_for_cache.len(),
            geom_start.elapsed().as_secs_f64(),
        );

        exact_geometry_cache.replace(Some(TransformationExactGeometryCache {
            key,
            covariate_spec_resolved: effective_spec,
            covariate_design: exact_design,
            blocks: vec![family.block_spec()],
            family,
            derivative_blocks: vec![tensor_derivs],
        }));
        Ok(())
    };

    // Returns a clone of the stored warm start when it is compatible with the
    // requested point, otherwise `None`.
    let compatible_warm_start =
        |theta: &Array1<f64>, rho: &Array1<f64>| -> Option<CustomFamilyWarmStart> {
            exact_warm_start
                .borrow()
                .as_ref()
                .filter(|warm| warm.is_compatible_with(theta, rho))
                .map(|warm| warm.warm_start.clone())
        };
    // Replaces the stored warm start with one tagged by the current `theta`.
    let store_warm_start = |theta: &Array1<f64>, warm_start: CustomFamilyWarmStart| {
        exact_warm_start
            .borrow_mut()
            .replace(TransformationExactWarmStart {
                theta: theta.clone(),
                warm_start,
            });
    };

    log::info!(
        "[transformation-normal] entering exact joint outer optimization \
         (analytic_gradient={}, analytic_hessian={})",
        analytic_gradient,
        analytic_hessian,
    );
    let outer_derivative_policy =
        probe_family.outer_derivative_policy(&probe_blocks, joint_setup.log_kappa_dim(), &options);

    let solved = optimize_spatial_length_scale_exact_joint(
        covariate_data,
        &block_specs_slice,
        &block_term_indices_slice,
        kappa_options,
        &joint_setup,
        gam_solve::seeding::SeedRiskProfile::Gaussian,
        analytic_gradient,
        analytic_hessian,
        true,
        None,
        outer_derivative_policy,
        // Fixed-rho fit: fits the cached family with rho taken from the
        // leading `rho_dim()` entries of `theta`.
        |theta, specs: &[TermCollectionSpec], designs: &[TermCollectionDesign]| {
            ensure_exact_geometry(&specs[0], &designs[0])?;
            let mut cache_ref = exact_geometry_cache.borrow_mut();
            let geometry = cache_ref
                .as_mut()
                .ok_or_else(|| "missing transformation exact geometry cache".to_string())?;
            let rho = theta.slice(s![..joint_setup.rho_dim()]).to_owned();
            geometry.update_initial_log_lambdas(&rho)?;
            let warm_start = compatible_warm_start(theta, &rho);
            let fit = fit_custom_family_fixed_log_lambdas(
                &geometry.family,
                &geometry.blocks,
                &options,
                warm_start.as_ref(),
                0,
                None,
                true,
            )
            .map_err(|e| format!("transformation fit_fn: {e}"))?;
            if let Some(block) = fit.block_states.first() {
                // Reset the family's row-quantity cache to `None` before
                // asking for row quantities at the fitted coefficients
                // (presumably to force a fresh computation; confirm against
                // `row_quantities`).
                *geometry
                    .family
                    .row_quantity_cache
                    .lock()
                    .expect("CTN row quantity cache mutex poisoned") = None;
                let final_rows = geometry.family.row_quantities(&block.beta)?;
                let max_abs_h = final_rows
                    .h
                    .iter()
                    .copied()
                    .map(f64::abs)
                    .fold(0.0, f64::max);
                // Materialize all covariate rows; a failure here aborts the
                // fit, and the magnitudes are only logged.
                let cov_chunk = geometry
                    .family
                    .covariate_design
                    .try_row_chunk(0..response.len())
                    .map_err(|err| {
                        format!("final CTN covariate design validation failed: {err}")
                    })?;
                let max_abs_cov = cov_chunk.iter().copied().map(f64::abs).fold(0.0, f64::max);
                log::info!(
                    "[transformation-normal] final fixed-rho CTN validation: max_abs_h={:.6e} max_abs_covariate_basis={:.6e}",
                    max_abs_h,
                    max_abs_cov
                );
            }
            // `score_calibration` is a placeholder here; it is overwritten
            // after the optimizer returns, at the end of this function.
            Ok(TransformationNormalFitResult {
                family: geometry.family.clone(),
                fit,
                covariate_spec_resolved: geometry.covariate_spec_resolved.clone(),
                covariate_design: geometry.covariate_design.clone(),
                score_calibration: TransformationScoreCalibration::finite_support_pit(),
            })
        },
        // Joint objective evaluation: returns objective, gradient and outer
        // Hessian for the requested evaluation mode.
        |theta,
         specs: &[TermCollectionSpec],
         designs: &[TermCollectionDesign],
         eval_mode,
         _row_set| {
            ensure_exact_geometry(&specs[0], &designs[0])?;
            let mut cache_ref = exact_geometry_cache.borrow_mut();
            let geometry = cache_ref
                .as_mut()
                .ok_or_else(|| "missing transformation exact geometry cache".to_string())?;
            let rho = theta.slice(s![..joint_setup.rho_dim()]).to_owned();
            let warm_start = compatible_warm_start(theta, &rho);

            let eval = evaluate_custom_family_joint_hyper(
                &geometry.family,
                &geometry.blocks,
                &options,
                &rho,
                &geometry.derivative_blocks,
                warm_start.as_ref(),
                eval_mode,
            )
            .map_err(|e| format!("transformation exact_fn: {e}"))?;

            if !eval.objective.is_finite() {
                log::warn!(
                    "transformation exact joint returned non-finite objective: eval_mode={:?} rho={:?} gradient_len={}",
                    eval_mode,
                    rho,
                    eval.gradient.len(),
                );
            }

            // Only remember the warm start when both the objective and every
            // gradient entry are finite. This happens before the convergence
            // check so a non-converged evaluation can still seed a retry.
            if eval.objective.is_finite() && eval.gradient.iter().all(|value| value.is_finite()) {
                store_warm_start(theta, eval.warm_start.clone());
            }

            if !eval.inner_converged {
                return Err(format!(
                    "transformation exact joint inner solve did not converge for eval_mode={eval_mode:?}; cached warm start for retry"
                ));
            }

            Ok((eval.objective, eval.gradient, eval.outer_hessian))
        },
        // EFS evaluation: same setup as above, but the warm start is stored
        // unconditionally before the convergence check.
        |theta, specs: &[TermCollectionSpec], designs: &[TermCollectionDesign]| {
            ensure_exact_geometry(&specs[0], &designs[0])?;
            let mut cache_ref = exact_geometry_cache.borrow_mut();
            let geometry = cache_ref
                .as_mut()
                .ok_or_else(|| "missing transformation exact geometry cache".to_string())?;
            let rho = theta.slice(s![..joint_setup.rho_dim()]).to_owned();
            let warm_start = compatible_warm_start(theta, &rho);
            let eval = evaluate_custom_family_joint_hyper_efs(
                &geometry.family,
                &geometry.blocks,
                &options,
                &rho,
                &geometry.derivative_blocks,
                warm_start.as_ref(),
            )
            .map_err(|e| format!("transformation exact_efs_fn: {e}"))?;
            store_warm_start(theta, eval.warm_start.clone());
            if !eval.inner_converged {
                return Err(
                    "transformation exact joint EFS inner solve did not converge; cached warm start for retry"
                        .to_string(),
                );
            }
            Ok(eval.efs_eval)
        },
        // Seeding callback: always reports `NoSlot`.
        |_beta: &Array1<f64>| Ok(gam_solve::rho_optimizer::SeedOutcome::NoSlot),
    )?;

    // Replace the placeholder calibration from the fit closure with the
    // calibrated fit and its score calibration.
    let mut fit = solved.fit;
    let (calibrated_fit, score_calibration) =
        calibrate_transformation_scores(&fit.family, fit.fit.clone())?;
    fit.fit = calibrated_fit;
    fit.score_calibration = score_calibration;
    Ok(fit)
}