1use super::*;
2
3impl CustomFamily for TransformationNormalFamily {
8 fn evaluate(&self, block_states: &[ParameterBlockState]) -> Result<FamilyEvaluation, String> {
9 crate::block_layout::block_count::validate_block_count::<
10 TransformationNormalError,
11 >("TransformationNormalFamily", 1, block_states.len())?;
12 let evaluate_start = std::time::Instant::now();
13 let beta = &block_states[0].beta;
14 let row_q_start = std::time::Instant::now();
15 let row_quantities = self.row_quantities(beta)?;
16 log::info!(
17 "[STAGE] CTN row_quantities (h, h', 1/h', powers) n={} elapsed={:.3}s",
18 row_quantities.h.len(),
19 row_q_start.elapsed().as_secs_f64(),
20 );
21 let h = row_quantities.h.as_ref();
22 let n = h.len();
23
24 let log_likelihood = row_quantities.log_likelihood;
25 let grad_start = std::time::Instant::now();
29 let (grad, hessian) = self.scop_gradient_and_negative_hessian(beta, &row_quantities)?;
30 log::info!(
31 "[STAGE] CTN gradient terms n={} p={} elapsed={:.3}s",
32 n,
33 grad.len(),
34 grad_start.elapsed().as_secs_f64(),
35 );
36
37 let hess_start = std::time::Instant::now();
38 let p_dim = hessian.nrows() as u64;
39 let n_u64 = n as u64;
40 log::info!(
41 "[STAGE] CTN hessian terms (SCOP exact dense) n={} p={} flops~{} elapsed={:.3}s",
42 n,
43 p_dim,
44 n_u64.saturating_mul(p_dim).saturating_mul(p_dim),
45 hess_start.elapsed().as_secs_f64(),
46 );
47 log::info!(
48 "[STAGE] CTN evaluate end n={} p={} elapsed={:.3}s",
49 n,
50 p_dim,
51 evaluate_start.elapsed().as_secs_f64(),
52 );
53
54 Ok(FamilyEvaluation {
55 log_likelihood,
56 blockworking_sets: vec![BlockWorkingSet::ExactNewton {
57 gradient: grad,
58 hessian: SymmetricMatrix::Dense(hessian),
59 }],
60 })
61 }
62
63 fn log_likelihood_only(&self, block_states: &[ParameterBlockState]) -> Result<f64, String> {
64 crate::block_layout::block_count::validate_block_count::<
65 TransformationNormalError,
66 >("TransformationNormalFamily", 1, block_states.len())?;
67 let row_quantities = match self.row_quantities(&block_states[0].beta) {
71 Ok(rq) => rq,
72 Err(_) => return Ok(f64::NEG_INFINITY),
73 };
74 Ok(row_quantities.log_likelihood)
75 }
76
77 fn log_likelihood_only_with_options(
78 &self,
79 block_states: &[ParameterBlockState],
80 options: &BlockwiseFitOptions,
81 ) -> Result<f64, String> {
82 match self.maybe_with_outer_subsample_from_options(options) {
89 Ok(Some(masked)) => masked.log_likelihood_only(block_states),
90 Ok(None) => self.log_likelihood_only(block_states),
91 Err(e) => Err(e.into()),
92 }
93 }
94
95 fn exact_newton_joint_gradient_evaluation(
111 &self,
112 block_states: &[ParameterBlockState],
113 _: &[ParameterBlockSpec],
114 ) -> Result<Option<ExactNewtonJointGradientEvaluation>, String> {
115 crate::block_layout::block_count::validate_block_count::<
116 TransformationNormalError,
117 >("TransformationNormalFamily", 1, block_states.len())?;
118 let beta = &block_states[0].beta;
119 let row_quantities = self.row_quantities(beta)?;
120 let log_likelihood = row_quantities.log_likelihood;
121 let gradient = self.scop_gradient(beta, &row_quantities)?;
122 Ok(Some(ExactNewtonJointGradientEvaluation {
123 log_likelihood,
124 gradient,
125 }))
126 }
127
128 fn exact_newton_joint_hessian_beta_dependent(&self) -> bool {
129 true
131 }
132
133 fn joint_jeffreys_term_required(&self) -> bool {
134 false
153 }
154
155 fn coefficient_hessian_cost(&self, specs: &[ParameterBlockSpec]) -> u64 {
156 let n_usize = self.response_val_basis.nrows();
171 let p_resp = self.response_val_basis.ncols() as u64;
172 let p_cov = self.covariate_design.ncols() as u64;
173 let expected_p_total = p_resp.saturating_mul(p_cov);
174 let p_total = match specs {
186 [] => expected_p_total,
187 [spec] if spec.design.ncols() as u64 == expected_p_total => spec.design.ncols() as u64,
188 _ => return u64::MAX,
189 };
190 let n = n_usize as u64;
191 crate::coefficient_cost::operator_aware_hessian_cost(
198 p_total,
199 n,
200 n.saturating_mul(p_resp.saturating_add(p_cov)),
201 n.saturating_mul(p_total.saturating_mul(p_total)),
202 )
203 }
204
205 fn coefficient_gradient_cost(&self, specs: &[ParameterBlockSpec]) -> u64 {
206 self.coefficient_hessian_cost(specs) / 2
210 }
211
212 fn outer_derivative_policy(
213 &self,
214 specs: &[crate::custom_family::ParameterBlockSpec],
215 psi_dim: usize,
216 options: &crate::custom_family::BlockwiseFitOptions,
217 ) -> crate::custom_family::OuterDerivativePolicy {
218 let capability = self.exact_outer_derivative_order(specs, options);
231 let n = specs.first().map_or(0u128, |s| s.design.nrows() as u128);
232 let p_total: u128 = specs
233 .iter()
234 .map(|s| s.design.ncols() as u128)
235 .fold(0u128, |acc, x| acc.saturating_add(x));
236 let rho_dim: u128 = specs
237 .iter()
238 .map(|s| s.penalties.len() as u128)
239 .fold(0u128, |acc, x| acc.saturating_add(x));
240 let k = rho_dim.saturating_add(psi_dim as u128).max(1);
241 let p_eff = p_total.max(1);
242 let work_grad = n.saturating_mul(k).saturating_mul(p_eff);
244 let dense_hess = work_grad.saturating_mul(p_eff);
250 let mfree_hess = work_grad.saturating_mul(rho_dim.max(1));
251 let work_hess = dense_hess.min(mfree_hess);
252 crate::custom_family::OuterDerivativePolicy {
253 capability,
254 predicted_hessian_work: work_hess,
255 predicted_gradient_work: work_grad,
256 subsample_capable: true,
266 }
267 }
268
269 fn outer_seed_config(&self, n_params: usize) -> gam_solve::seeding::SeedConfig {
270 gam_solve::seeding::SeedConfig {
271 bounds: (-12.0, 12.0),
272 max_seeds: if n_params <= 8 { 1 } else { 2 },
273 seed_budget: 1,
274 screen_max_inner_iterations: 2,
275 risk_profile: gam_solve::seeding::SeedRiskProfile::Gaussian,
276 num_auxiliary_trailing: 0,
277 over_smoothing_probe_rho: None,
278 }
279 }
280
281 fn max_feasible_step_size(
282 &self,
283 block_states: &[ParameterBlockState],
284 block_index: usize,
285 delta: &Array1<f64>,
286 ) -> Result<Option<f64>, String> {
287 if block_index != 0 {
288 return Ok(None);
289 }
290 crate::block_layout::block_count::validate_block_count::<
291 TransformationNormalError,
292 >("TransformationNormalFamily", 1, block_states.len())?;
293 if delta.len() != block_states[0].beta.len() {
294 return Err(TransformationNormalError::InvalidInput {
295 reason: format!(
296 "CTN line-search step length {} != beta length {}",
297 delta.len(),
298 block_states[0].beta.len()
299 ),
300 }
301 .into());
302 }
303 Ok(None)
309 }
310
311 fn block_linear_constraints(
312 &self,
313 _: &[ParameterBlockState],
314 block_index: usize,
315 block_spec: &ParameterBlockSpec,
316 ) -> Result<Option<LinearInequalityConstraints>, String> {
317 assert!(!block_spec.name.is_empty());
318 if block_index != 0 {
319 return Ok(None);
320 }
321 Ok(None)
325 }
326
327 fn exact_newton_hessian_directional_derivative(
328 &self,
329 block_states: &[ParameterBlockState],
330 block_index: usize,
331 d_beta: &Array1<f64>,
332 ) -> Result<Option<Array2<f64>>, String> {
333 if block_index != 0 {
334 return Ok(None);
335 }
336 let beta = &block_states[0].beta;
337 let row_quantities = self.row_quantities(beta)?;
338 let dd = self.scop_hessian_directional_derivative(beta, d_beta, &row_quantities)?;
339 Ok(Some(dd))
340 }
341
342 fn exact_newton_joint_hessian(
343 &self,
344 block_states: &[ParameterBlockState],
345 ) -> Result<Option<Array2<f64>>, String> {
346 let beta = &block_states[0].beta;
348 let row_quantities = self.row_quantities(beta)?;
349 let (_, hessian) = self.scop_gradient_and_negative_hessian(beta, &row_quantities)?;
350 Ok(Some(hessian))
351 }
352
353 fn exact_newton_joint_hessian_directional_derivative(
354 &self,
355 block_states: &[ParameterBlockState],
356 d_beta_flat: &Array1<f64>,
357 ) -> Result<Option<Array2<f64>>, String> {
358 self.exact_newton_hessian_directional_derivative(block_states, 0, d_beta_flat)
359 }
360
361 fn exact_newton_joint_hessiansecond_directional_derivative(
362 &self,
363 block_states: &[ParameterBlockState],
364 d_beta_u_flat: &Array1<f64>,
365 d_beta_v_flat: &Array1<f64>,
366 ) -> Result<Option<Array2<f64>>, String> {
367 let beta = &block_states[0].beta;
368 let row_quantities = self.row_quantities(beta)?;
369 let d2 = self.scop_hessian_second_directional_derivative(
370 beta,
371 d_beta_u_flat,
372 d_beta_v_flat,
373 &row_quantities,
374 )?;
375 Ok(Some(d2))
376 }
377
378 fn exact_newton_joint_psi_terms(
379 &self,
380 block_states: &[ParameterBlockState],
381 _: &[ParameterBlockSpec],
382 psi_derivs: &[Vec<CustomFamilyBlockPsiDerivative>],
383 psi_index: usize,
384 ) -> Result<Option<ExactNewtonJointPsiTerms>, String> {
385 if psi_derivs.is_empty() || psi_index >= psi_derivs[0].len() {
386 return Ok(None);
387 }
388 let psi_first_start = std::time::Instant::now();
389 let deriv = &psi_derivs[0][psi_index];
390 let beta = &block_states[0].beta;
391 let row = self.row_quantities(beta)?;
392 let op = deriv
393 .implicit_operator
394 .as_ref()
395 .and_then(|op| op.as_any().downcast_ref::<TensorKroneckerPsiOperator>())
396 .ok_or_else(|| {
397 "TransformationNormalFamily requires tensor psi derivatives to remain operator-backed"
398 .to_string()
399 })?;
400 let axis = deriv.implicit_axis;
401 let op_arc = Arc::clone(
402 deriv
403 .implicit_operator
404 .as_ref()
405 .expect("validated CTN psi derivative operator disappeared"),
406 );
407 let terms = self.scop_psi_terms(beta, &row, op, op_arc, axis)?;
408
409 log::info!(
410 "[STAGE] CTN psi first-order terms axis={} psi_index={} elapsed={:.3}s",
411 deriv.implicit_axis,
412 psi_index,
413 psi_first_start.elapsed().as_secs_f64(),
414 );
415
416 Ok(Some(terms))
417 }
418
419 fn exact_newton_joint_psisecond_order_terms(
420 &self,
421 block_states: &[ParameterBlockState],
422 _: &[ParameterBlockSpec],
423 psi_derivs: &[Vec<CustomFamilyBlockPsiDerivative>],
424 psi_i: usize,
425 psi_j: usize,
426 ) -> Result<Option<ExactNewtonJointPsiSecondOrderTerms>, String> {
427 if psi_derivs.is_empty() || psi_i >= psi_derivs[0].len() || psi_j >= psi_derivs[0].len() {
428 return Ok(None);
429 }
430 let psi_pair_start = std::time::Instant::now();
431 let deriv_i = &psi_derivs[0][psi_i];
432 let deriv_j = &psi_derivs[0][psi_j];
433 let beta = &block_states[0].beta;
434 let row = self.row_quantities(beta)?;
435 let p_resp = self.response_val_basis.ncols();
436 let p_cov = self.covariate_design.ncols();
437 let p_total = p_resp * p_cov;
438 if beta.len() != p_total {
439 return Err(TransformationNormalError::InvalidInput {
440 reason: format!(
441 "SCOP psi-psi terms beta length {} != p_resp({p_resp}) * p_cov({p_cov})",
442 beta.len()
443 ),
444 }
445 .into());
446 }
447
448 let op = deriv_i
449 .implicit_operator
450 .as_ref()
451 .and_then(|op| op.as_any().downcast_ref::<TensorKroneckerPsiOperator>())
452 .ok_or_else(|| {
453 "TransformationNormalFamily requires tensor psi derivatives to remain operator-backed"
454 .to_string()
455 })?;
456 let axis_i = deriv_i.implicit_axis;
457 let axis_j = deriv_j.implicit_axis;
458
459 let (objective_psi_psi, score_psi_psi, _) = self
460 .scop_psi_psi_value_score_hvp_from_operator(
461 beta,
462 op,
463 axis_i,
464 axis_j,
465 row.gamma.view(),
466 row.h.view(),
467 row.h_prime.view(),
468 row.endpoint_q.as_slice(),
469 None,
470 )?;
471 let hessian_psi_psi_operator: Box<dyn HyperOperator> =
472 Box::new(TransformationNormalPsiPsiHessianOperator::new(
473 Arc::new(self.clone()),
474 beta.clone(),
475 Arc::clone(
476 deriv_i
477 .implicit_operator
478 .as_ref()
479 .expect("validated CTN psi derivative has an implicit operator"),
480 ),
481 axis_i,
482 axis_j,
483 Arc::clone(&row.gamma),
484 Arc::clone(&row.h),
485 Arc::clone(&row.h_prime),
486 Arc::clone(&row.endpoint_q),
487 ));
488
489 if !objective_psi_psi.is_finite() || !score_psi_psi.iter().all(|v| v.is_finite()) {
495 return Err(TransformationNormalError::NonFinite {
496 reason: format!(
497 "TransformationNormalFamily exact ψ-ψ second-order terms produced \
498 non-finite values at psi_i={psi_i}, psi_j={psi_j}: \
499 obj_finite={}, score_all_finite={}. \
500 The outer evaluator should retreat from this trial point.",
501 objective_psi_psi.is_finite(),
502 score_psi_psi.iter().all(|v| v.is_finite()),
503 ),
504 }
505 .into());
506 }
507
508 log::info!(
509 "[STAGE] CTN psi-psi pair (psi_i={}, psi_j={}, axes={},{}) elapsed={:.3}s",
510 psi_i,
511 psi_j,
512 deriv_i.implicit_axis,
513 deriv_j.implicit_axis,
514 psi_pair_start.elapsed().as_secs_f64(),
515 );
516
517 Ok(Some(ExactNewtonJointPsiSecondOrderTerms {
518 objective_psi_psi,
519 score_psi_psi,
520 hessian_psi_psi: Array2::zeros((0, 0)),
521 hessian_psi_psi_operator: Some(hessian_psi_psi_operator),
522 }))
523 }
524
525 fn exact_newton_joint_psihessian_directional_derivative(
526 &self,
527 block_states: &[ParameterBlockState],
528 _: &[ParameterBlockSpec],
529 psi_derivs: &[Vec<CustomFamilyBlockPsiDerivative>],
530 psi_index: usize,
531 d_beta_flat: &Array1<f64>,
532 ) -> Result<Option<Array2<f64>>, String> {
533 if psi_derivs.is_empty() || psi_index >= psi_derivs[0].len() {
534 return Ok(None);
535 }
536 let deriv = &psi_derivs[0][psi_index];
537 let beta = &block_states[0].beta;
538 let op = deriv
539 .implicit_operator
540 .as_ref()
541 .and_then(|op| op.as_any().downcast_ref::<TensorKroneckerPsiOperator>())
542 .ok_or_else(|| {
543 "TransformationNormalFamily requires tensor psi derivatives to remain operator-backed"
544 .to_string()
545 })?;
546 let axis = deriv.implicit_axis;
547 let row = self.row_quantities(beta)?;
548 let hess =
549 self.scop_psi_hessian_directional_derivative(beta, d_beta_flat, &row, op, axis)?;
550 Ok(Some(hess))
551 }
552
553 fn exact_newton_joint_hessian_workspace(
554 &self,
555 block_states: &[ParameterBlockState],
556 specs: &[ParameterBlockSpec],
557 ) -> Result<Option<Arc<dyn ExactNewtonJointHessianWorkspace>>, String> {
558 crate::block_layout::block_count::validate_block_count::<
559 TransformationNormalError,
560 >("TransformationNormalFamily", 1, block_states.len())?;
561 if !self.inner_coefficient_hessian_hvp_available(specs) {
562 return Err(TransformationNormalError::InvalidInput {
563 reason: "TransformationNormalFamily joint Hessian workspace received incompatible block specs"
564 .to_string(),
565 }
566 .into());
567 }
568 let beta = &block_states[0].beta;
569 let row_quantities = self.row_quantities(beta)?;
570 let workspace = TransformationNormalJointHessianWorkspace::new(
574 Arc::new(self.clone()),
575 beta.clone(),
576 row_quantities.clone(),
577 )?;
578 Ok(Some(
579 Arc::new(workspace) as Arc<dyn ExactNewtonJointHessianWorkspace>
580 ))
581 }
582
583 fn exact_newton_joint_psi_workspace(
584 &self,
585 block_states: &[ParameterBlockState],
586 specs: &[ParameterBlockSpec],
587 derivative_blocks: &[Vec<CustomFamilyBlockPsiDerivative>],
588 ) -> Result<Option<Arc<dyn ExactNewtonJointPsiWorkspace>>, String> {
589 if !self.inner_coefficient_hessian_hvp_available(specs) {
590 return Err(TransformationNormalError::InvalidInput {
591 reason: "TransformationNormalFamily joint psi workspace received incompatible block specs"
592 .to_string(),
593 }
594 .into());
595 }
596 Ok(Some(Arc::new(TransformationNormalPsiWorkspace::new(
597 self.clone(),
598 block_states.to_vec(),
599 derivative_blocks.to_vec(),
600 ))))
601 }
602
603 fn exact_newton_joint_hessian_workspace_with_options(
604 &self,
605 block_states: &[ParameterBlockState],
606 specs: &[ParameterBlockSpec],
607 options: &BlockwiseFitOptions,
608 ) -> Result<Option<Arc<dyn ExactNewtonJointHessianWorkspace>>, String> {
609 match self.maybe_with_outer_subsample_from_options(options)? {
617 Some(masked) => masked.exact_newton_joint_hessian_workspace(block_states, specs),
618 None => self.exact_newton_joint_hessian_workspace(block_states, specs),
619 }
620 }
621
622 fn exact_newton_joint_psi_workspace_with_options(
623 &self,
624 block_states: &[ParameterBlockState],
625 specs: &[ParameterBlockSpec],
626 derivative_blocks: &[Vec<CustomFamilyBlockPsiDerivative>],
627 options: &BlockwiseFitOptions,
628 ) -> Result<Option<Arc<dyn ExactNewtonJointPsiWorkspace>>, String> {
629 if !self.inner_coefficient_hessian_hvp_available(specs) {
630 return Err(TransformationNormalError::InvalidInput {
631 reason: "TransformationNormalFamily joint psi workspace received incompatible block specs"
632 .to_string(),
633 }
634 .into());
635 }
636 let family = match self.maybe_with_outer_subsample_from_options(options)? {
649 Some(masked) => masked,
650 None => self.clone(),
651 };
652 Ok(Some(Arc::new(TransformationNormalPsiWorkspace::new(
653 family,
654 block_states.to_vec(),
655 derivative_blocks.to_vec(),
656 ))))
657 }
658
659 fn exact_newton_joint_psi_workspace_for_first_order_terms(&self) -> bool {
660 true
666 }
667
668 fn inner_coefficient_hessian_hvp_available(&self, specs: &[ParameterBlockSpec]) -> bool {
669 matches!(specs, [spec] if spec.design.ncols()
672 == self.response_val_basis.ncols().saturating_mul(self.covariate_design.ncols()))
673 }
674
675 fn outer_hyper_hessian_hvp_available(&self, specs: &[ParameterBlockSpec]) -> bool {
676 self.inner_coefficient_hessian_hvp_available(specs)
677 }
678
679 fn outer_hyper_hessian_dense_available(&self, specs: &[ParameterBlockSpec]) -> bool {
680 self.inner_coefficient_hessian_hvp_available(specs)
684 }
685}