1use super::*;
2
/// Returns the fixed weight `1e-4`.
///
/// NOTE(review): the name suggests this scales a ridge term on the SAS
/// `log_delta` parameter — confirm against the call sites.
pub(crate) fn sas_log_deltaridgeweight() -> f64 {
    const RIDGE_WEIGHT: f64 = 1e-4;
    RIDGE_WEIGHT
}
8
/// Returns the fixed weight `1e-2` used by the `log_delta` edge-barrier
/// cost/gradient helpers in this file.
#[inline]
pub(crate) fn sas_log_delta_edge_barrierweight() -> f64 {
    const BARRIER_WEIGHT: f64 = 1e-2;
    BARRIER_WEIGHT
}
15
16#[inline]
17pub(crate) fn sas_log_delta_bound() -> f64 {
18 crate::mixture_link::SAS_LOG_DELTA_BOUND
19}
20
21#[inline]
22pub(crate) fn sas_log_delta_edge_barriercostgrad(raw_log_delta: f64) -> (f64, f64) {
23 let w = sas_log_delta_edge_barrierweight();
24 if w <= 0.0 || !raw_log_delta.is_finite() {
25 return (0.0, 0.0);
26 }
27 let b = sas_log_delta_bound().max(f64::EPSILON);
28 let t = (raw_log_delta / b).tanh();
29 let one_minus_t2 = (1.0 - t * t).max(1e-12);
30 let cost = -w * one_minus_t2.ln();
31 let grad = (2.0 * w / b) * t;
33 (cost, grad)
34}
35
/// Returns the fixed bound `8.0` used by the effective-epsilon helpers to
/// squash the raw epsilon through `bound * tanh(raw / bound)`.
#[inline]
pub(crate) fn sas_epsilon_bound() -> f64 {
    const EPSILON_BOUND: f64 = 8.0;
    EPSILON_BOUND
}
41
42#[inline]
43pub(crate) fn sas_effective_epsilon(raw_epsilon: f64) -> (f64, f64) {
44 let bound = sas_epsilon_bound().max(f64::EPSILON);
45 let t = (raw_epsilon / bound).tanh();
46 let epsilon = bound * t;
47 let d_epsilon_d_raw = 1.0 - t * t;
48 (epsilon, d_epsilon_d_raw)
49}
50
51#[inline]
52pub(crate) fn sas_effective_epsilon_second(raw_epsilon: f64) -> (f64, f64, f64) {
53 let bound = sas_epsilon_bound().max(f64::EPSILON);
54 let t = (raw_epsilon / bound).tanh();
55 let first = 1.0 - t * t;
56 let second = -2.0 * t * first / bound;
57 (bound * t, first, second)
58}
59
60#[inline]
61pub(crate) fn sas_log_delta_edge_barriercostgradhess(raw_log_delta: f64) -> (f64, f64, f64) {
62 let w = sas_log_delta_edge_barrierweight();
63 if w <= 0.0 || !raw_log_delta.is_finite() {
64 return (0.0, 0.0, 0.0);
65 }
66 let b = sas_log_delta_bound().max(f64::EPSILON);
67 let t = (raw_log_delta / b).tanh();
68 let one_minus_t2 = (1.0 - t * t).max(1e-12);
69 let cost = -w * one_minus_t2.ln();
70 let grad = (2.0 * w / b) * t;
71 let hess = (2.0 * w / (b * b)) * one_minus_t2;
72 (cost, grad, hess)
73}
74
75pub(crate) fn materialize_link_outer_hessian(
76 hessian: gam_problem::HessianResult,
77 theta_dim: usize,
78) -> Result<Array2<f64>, EstimationError> {
79 match hessian.materialize_dense() {
80 Ok(Some(h)) => {
81 if h.nrows() != theta_dim || h.ncols() != theta_dim {
82 crate::bail_invalid_estim!(
83 "unified evaluator Hessian shape {}x{} != theta_dim {}",
84 h.nrows(),
85 h.ncols(),
86 theta_dim
87 );
88 }
89 Ok(h)
90 }
91 Ok(None) => Err(EstimationError::InvalidInput(
92 "unified evaluator returned no analytic Hessian in ValueGradientHessian mode"
93 .to_string(),
94 )),
95 Err(err) => Err(EstimationError::InvalidInput(format!(
96 "failed to materialize analytic link Hessian: {err}"
97 ))),
98 }
99}
100
101pub fn evaluate_externalgradient<X>(
103 y: ArrayView1<'_, f64>,
104 w: ArrayView1<'_, f64>,
105 x: X,
106 offset: ArrayView1<'_, f64>,
107 s_list: &[BlockwisePenalty],
108 opts: &ExternalOptimOptions,
109 rho: &Array1<f64>,
110) -> Result<Array1<f64>, EstimationError>
111where
112 X: Into<DesignMatrix>,
113{
114 let specs: Vec<PenaltySpec> = s_list.iter().map(PenaltySpec::from_blockwise_ref).collect();
115 let x = x.into();
116 if let Some(message) = row_mismatch_message(y.len(), w.len(), x.nrows(), offset.len()) {
117 crate::bail_invalid_estim!("{}", message);
118 }
119
120 let p = x.ncols();
121 validate_penalty_specs(&specs, p, "evaluate_externalgradient")?;
122 let (canonical, active_nullspace_dims) = gam_terms::construction::canonicalize_penalty_specs(
123 &specs,
124 &opts.nullspace_dims,
125 p,
126 "evaluate_externalgradient",
127 )?;
128 if rho.len() != active_nullspace_dims.len() {
129 crate::bail_invalid_estim!(
130 "rho dimension mismatch: rho_dim={}, active_penalties={}",
131 rho.len(),
132 active_nullspace_dims.len()
133 );
134 }
135
136 let (cfg, _) = resolved_external_config(opts)?;
137
138 let y_o = y.to_owned();
139 let w_o = w.to_owned();
140 let offset_o = offset.to_owned();
141 let conditioning = ParametricColumnConditioning::infer_from_penalty_specs(&x, &specs);
142 let x_fit = conditioning.apply_to_design(&x);
143 let fit_linear_constraints =
144 conditioning.transform_linear_constraints_to_internal(opts.linear_constraints.clone());
145
146 let mut reml_state = RemlState::newwith_offset(
147 y_o.view(),
148 x_fit,
149 w_o.view(),
150 offset_o.view(),
151 canonical,
152 p,
153 &cfg,
154 Some(active_nullspace_dims),
155 None,
156 fit_linear_constraints,
157 )?;
158 reml_state.set_penalty_shrinkage_floor(opts.penalty_shrinkage_floor);
159 reml_state.set_rho_prior(opts.rho_prior.clone());
160 reml_state.set_link_states(
161 cfg.link_kind.mixture_state().cloned(),
162 cfg.link_kind.sas_state().copied(),
163 );
164
165 reml_state.compute_gradient(rho)
166}
167
168fn gaussian_identity_inner_residual_norm(
169 y: ArrayView1<'_, f64>,
170 w: ArrayView1<'_, f64>,
171 x: &DesignMatrix,
172 offset: ArrayView1<'_, f64>,
173 canonical_penalties: &[gam_terms::construction::CanonicalPenalty],
174 rho: &Array1<f64>,
175 beta: &Array1<f64>,
176) -> Result<f64, EstimationError> {
177 if beta.len() != x.ncols() {
178 crate::bail_invalid_estim!(
179 "beta dimension mismatch: beta_dim={}, x_cols={}",
180 beta.len(),
181 x.ncols()
182 );
183 }
184 if rho.len() != canonical_penalties.len() {
185 crate::bail_invalid_estim!(
186 "rho dimension mismatch: rho_dim={}, active_penalties={}",
187 rho.len(),
188 canonical_penalties.len()
189 );
190 }
191
192 let mut residual = x.apply(beta);
193 residual += &offset;
194 residual -= &y;
195 residual *= &w;
196 let mut gradient = x.apply_transpose(&residual);
197
198 for (k, cp) in canonical_penalties.iter().enumerate() {
199 let lambda = rho[k].exp();
200 if lambda == 0.0 || cp.rank() == 0 {
201 continue;
202 }
203 let r = cp.col_range.clone();
204 let centered = &beta.slice(s![r.start..r.end]) - &cp.prior_mean;
205 let penalty_grad = cp.local.dot(¢ered) * lambda;
206 gradient
207 .slice_mut(s![r.start..r.end])
208 .scaled_add(1.0, &penalty_grad);
209 }
210
211 Ok(gradient.iter().map(|v| v * v).sum::<f64>().sqrt())
212}
213
214pub fn evaluate_external_ift_residual_at_perturbed_rho<X>(
262 y: ArrayView1<'_, f64>,
263 w: ArrayView1<'_, f64>,
264 x: X,
265 offset: ArrayView1<'_, f64>,
266 s_list: &[BlockwisePenalty],
267 opts: &ExternalOptimOptions,
268 rho: &Array1<f64>,
269 delta_rho: ArrayView1<'_, f64>,
270) -> Result<(f64, f64), EstimationError>
271where
272 X: Into<DesignMatrix>,
273{
274 if !opts.family.is_gaussian_identity() {
275 crate::bail_invalid_estim!(
276 "evaluate_external_ift_residual_at_perturbed_rho currently supports GaussianIdentity"
277 .to_string(),
278 );
279 }
280 if opts.linear_constraints.is_some() {
281 crate::bail_invalid_estim!(
282 "evaluate_external_ift_residual_at_perturbed_rho does not support constrained fits"
283 .to_string(),
284 );
285 }
286
287 let specs: Vec<PenaltySpec> = s_list.iter().map(PenaltySpec::from_blockwise_ref).collect();
288 let x = x.into();
289 if let Some(message) = row_mismatch_message(y.len(), w.len(), x.nrows(), offset.len()) {
290 crate::bail_invalid_estim!("{}", message);
291 }
292
293 let p = x.ncols();
294 validate_penalty_specs(&specs, p, "evaluate_external_ift_residual_at_perturbed_rho")?;
295 let (canonical, active_nullspace_dims) = gam_terms::construction::canonicalize_penalty_specs(
296 &specs,
297 &opts.nullspace_dims,
298 p,
299 "evaluate_external_ift_residual_at_perturbed_rho",
300 )?;
301 if rho.len() != active_nullspace_dims.len() {
302 crate::bail_invalid_estim!(
303 "rho dimension mismatch: rho_dim={}, active_penalties={}",
304 rho.len(),
305 active_nullspace_dims.len()
306 );
307 }
308 if delta_rho.len() != rho.len() {
309 crate::bail_invalid_estim!(
310 "delta_rho dimension mismatch: delta_dim={}, rho_dim={}",
311 delta_rho.len(),
312 rho.len()
313 );
314 }
315
316 let mut tight_opts = opts.clone();
317 tight_opts.tol = 1e-12;
318 let (cfg, _) = resolved_external_config(&tight_opts)?;
319
320 let y_o = y.to_owned();
321 let w_o = w.to_owned();
322 let offset_o = offset.to_owned();
323 let conditioning = ParametricColumnConditioning::infer_from_penalty_specs(&x, &specs);
324 let x_fit = conditioning.apply_to_design(&x);
325 let fit_linear_constraints =
326 conditioning.transform_linear_constraints_to_internal(tight_opts.linear_constraints);
327
328 let mut reml_state = RemlState::newwith_offset(
329 y_o.view(),
330 x_fit.clone(),
331 w_o.view(),
332 offset_o.view(),
333 canonical.clone(),
334 p,
335 &cfg,
336 Some(active_nullspace_dims),
337 None,
338 fit_linear_constraints,
339 )?;
340 reml_state.set_penalty_shrinkage_floor(tight_opts.penalty_shrinkage_floor);
341 reml_state.set_rho_prior(tight_opts.rho_prior.clone());
342 reml_state.set_link_states(
343 cfg.link_kind.mixture_state().cloned(),
344 cfg.link_kind.sas_state().copied(),
345 );
346
347 reml_state.compute_gradient(rho)?;
348 let beta_hat = reml_state
349 .warm_start_beta
350 .read()
351 .unwrap()
352 .as_ref()
353 .map(|beta| beta.0.clone())
354 .ok_or_else(|| {
355 EstimationError::InvalidInput(
356 "PIRLS solve did not populate the warm-start beta cache".to_string(),
357 )
358 })?;
359
360 let rho_perturbed = rho + &delta_rho.to_owned();
361 let beta_pred = reml_state
362 .predict_warm_start_beta_ift_with_outcome(&rho_perturbed)
363 .map(|(beta, _)| beta.as_ref().clone())
364 .ok_or_else(|| {
365 EstimationError::InvalidInput(
366 "IFT warm-start predictor rejected the perturbed rho".to_string(),
367 )
368 })?;
369
370 let ift_residual = gaussian_identity_inner_residual_norm(
371 y_o.view(),
372 w_o.view(),
373 &x_fit,
374 offset_o.view(),
375 &canonical,
376 &rho_perturbed,
377 &beta_pred,
378 )?;
379 let flat_residual = gaussian_identity_inner_residual_norm(
380 y_o.view(),
381 w_o.view(),
382 &x_fit,
383 offset_o.view(),
384 &canonical,
385 &rho_perturbed,
386 &beta_hat,
387 )?;
388
389 Ok((ift_residual, flat_residual))
390}
391
392pub fn evaluate_externalcost_andridge<X>(
395 y: ArrayView1<'_, f64>,
396 w: ArrayView1<'_, f64>,
397 x: X,
398 offset: ArrayView1<'_, f64>,
399 s_list: &[BlockwisePenalty],
400 opts: &ExternalOptimOptions,
401 rho: &Array1<f64>,
402) -> Result<(f64, f64), EstimationError>
403where
404 X: Into<DesignMatrix>,
405{
406 let specs: Vec<PenaltySpec> = s_list.iter().map(PenaltySpec::from_blockwise_ref).collect();
407 let x = x.into();
408 if let Some(message) = row_mismatch_message(y.len(), w.len(), x.nrows(), offset.len()) {
409 crate::bail_invalid_estim!("{}", message);
410 }
411
412 let p = x.ncols();
413 validate_penalty_specs(&specs, p, "evaluate_externalcost_andridge")?;
414 let (canonical, active_nullspace_dims) = gam_terms::construction::canonicalize_penalty_specs(
415 &specs,
416 &opts.nullspace_dims,
417 p,
418 "evaluate_externalcost_andridge",
419 )?;
420 if rho.len() != active_nullspace_dims.len() {
421 crate::bail_invalid_estim!(
422 "rho dimension mismatch: rho_dim={}, active_penalties={}",
423 rho.len(),
424 active_nullspace_dims.len()
425 );
426 }
427
428 let (cfg, _) = resolved_external_config(opts)?;
429
430 let y_o = y.to_owned();
431 let w_o = w.to_owned();
432 let offset_o = offset.to_owned();
433 let conditioning = ParametricColumnConditioning::infer_from_penalty_specs(&x, &specs);
434 let x_fit = conditioning.apply_to_design(&x);
435 let fit_linear_constraints =
436 conditioning.transform_linear_constraints_to_internal(opts.linear_constraints.clone());
437
438 let mut reml_state = RemlState::newwith_offset(
439 y_o.view(),
440 x_fit,
441 w_o.view(),
442 offset_o.view(),
443 canonical,
444 p,
445 &cfg,
446 Some(active_nullspace_dims),
447 None,
448 fit_linear_constraints,
449 )?;
450 reml_state.set_penalty_shrinkage_floor(opts.penalty_shrinkage_floor);
451 reml_state.set_rho_prior(opts.rho_prior.clone());
452 reml_state.set_link_states(
453 cfg.link_kind.mixture_state().cloned(),
454 cfg.link_kind.sas_state().copied(),
455 );
456
457 let cost = reml_state.compute_cost(rho)?;
458 let ridge = reml_state.last_ridge_used().unwrap_or(0.0);
459 Ok((cost, ridge))
460}