1use crate::custom_family::{CustomFamilyBlockPsiDerivative, ParameterBlockSpec};
33use crate::cubic_cell_kernel::{self, DenestedPartitionCell, LocalSpanCubic};
34use crate::outer_subsample::{OuterScoreSubsample, WeightedOuterRow};
35use gam_math::jet_partitions::MultiDirJet;
36use ndarray::{Array1, Array2, Axis};
37use std::ops::Range;
38use std::sync::Arc;
39
40pub fn make_beta_seed_validator(
55 pending: &std::cell::RefCell<Option<Array1<f64>>>,
56) -> impl FnMut(
57 &Array1<f64>,
58)
59 -> Result<gam_solve::rho_optimizer::SeedOutcome, crate::model_types::EstimationError>
60+ '_ {
61 move |beta: &Array1<f64>| {
62 bail_if_cached_beta_non_finite(beta)?;
63 pending.replace(Some(beta.clone()));
69 Ok(gam_solve::rho_optimizer::SeedOutcome::Installed)
70 }
71}
72
73pub use gam_problem::bail_if_cached_beta_non_finite;
81
/// Evaluates the cubic `c0 + c1*z + c2*z^2 + c3*z^3` at `z` using Horner's rule.
///
/// `coefficients` is ordered from the constant term upwards.
#[inline]
pub const fn eval_coeff4_at(coefficients: &[f64; 4], z: f64) -> f64 {
    let [c0, c1, c2, c3] = *coefficients;
    ((c3 * z + c2) * z + c1) * z + c0
}
86
/// Accumulates `scale * source` into `target`, entry by entry.
#[inline]
pub fn add_scaled_coeff4(target: &mut [f64; 4], source: &[f64; 4], scale: f64) {
    for (acc, &term) in target.iter_mut().zip(source.iter()) {
        *acc += scale * term;
    }
}
93
/// Dot product of two coefficient quadruples.
///
/// The products are summed left to right (index 0 first) so the floating-point
/// result does not depend on how the expression is written.
#[inline]
fn coeff4_dot(left: &[f64; 4], right: &[f64; 4]) -> f64 {
    let [l0, l1, l2, l3] = *left;
    let [r0, r1, r2, r3] = *right;
    l0 * r0 + l1 * r1 + l2 * r2 + l3 * r3
}
98
/// Returns a copy of `source` with every entry multiplied by `scale`.
#[inline]
pub const fn scale_coeff4(source: [f64; 4], scale: f64) -> [f64; 4] {
    let [c0, c1, c2, c3] = source;
    [c0 * scale, c1 * scale, c2 * scale, c3 * scale]
}
108
109pub fn probit_frailty_scale(gaussian_frailty_sd: Option<f64>) -> f64 {
110 let sigma = gaussian_frailty_sd.unwrap_or(0.0);
111 if sigma <= 0.0 {
112 1.0
113 } else {
114 crate::survival::lognormal_kernel::ProbitFrailtyScaleJet::from_log_sigma(
115 sigma.ln(),
116 )
117 .s
118 }
119}
120
121pub(crate) fn probit_frailty_scale_multi_dir_jet(
122 gaussian_frailty_sd: Option<f64>,
123 missing_sigma_message: &str,
124 n_dirs: usize,
125 first_masks: &[usize],
126 second_masks: &[usize],
127) -> Result<MultiDirJet, String> {
128 let sigma = gaussian_frailty_sd.ok_or_else(|| missing_sigma_message.to_string())?;
129 let jet = crate::survival::lognormal_kernel::ProbitFrailtyScaleJet::from_log_sigma(
130 sigma.ln(),
131 );
132 let mut coeffs = Vec::with_capacity(1 + first_masks.len() + second_masks.len());
133 coeffs.push((0usize, jet.s));
134 coeffs.extend(first_masks.iter().copied().map(|mask| (mask, jet.ds)));
135 coeffs.extend(second_masks.iter().copied().map(|mask| (mask, jet.d2s)));
136 Ok(MultiDirJet::with_coeffs(n_dirs, &coeffs))
137}
138
139#[derive(Clone)]
150pub(crate) struct DirectionalScaleJets {
151 pub(crate) obj: Option<MultiDirJet>,
152 pub(crate) grad: MultiDirJet,
153 pub(crate) hess: MultiDirJet,
154}
155
156pub(crate) struct DirectionalPrimaryTerms {
159 pub(crate) objective: f64,
160 pub(crate) grad: Array1<f64>,
161 pub(crate) hess: Array2<f64>,
162}
163
164pub(crate) fn directional_obj_grad_hess<Eval>(
182 primary_dim: usize,
183 leading: &[&Array1<f64>],
184 scales: &DirectionalScaleJets,
185 eval: Eval,
186) -> Result<DirectionalPrimaryTerms, String>
187where
188 Eval: Fn(&[&Array1<f64>], &MultiDirJet) -> Result<f64, String>,
189{
190 let objective = if let Some(scale_obj) = scales.obj.as_ref() {
191 eval(leading, scale_obj)?
192 } else {
193 0.0
194 };
195
196 let unit = |a: usize| -> Array1<f64> {
197 let mut da = Array1::<f64>::zeros(primary_dim);
198 da[a] = 1.0;
199 da
200 };
201
202 let units: Vec<Array1<f64>> = (0..primary_dim).map(unit).collect();
203
204 let mut grad = Array1::<f64>::zeros(primary_dim);
205 let mut dirs: Vec<&Array1<f64>> = Vec::with_capacity(leading.len() + 2);
206 for a in 0..primary_dim {
207 dirs.clear();
208 dirs.extend_from_slice(leading);
209 dirs.push(&units[a]);
210 grad[a] = eval(&dirs, &scales.grad)?;
211 }
212
213 let mut hess = Array2::<f64>::zeros((primary_dim, primary_dim));
214 for a in 0..primary_dim {
215 for b in a..primary_dim {
216 dirs.clear();
217 dirs.extend_from_slice(leading);
218 dirs.push(&units[a]);
219 dirs.push(&units[b]);
220 let value = eval(&dirs, &scales.hess)?;
221 hess[[a, b]] = value;
222 hess[[b, a]] = value;
223 }
224 }
225
226 Ok(DirectionalPrimaryTerms {
227 objective,
228 grad,
229 hess,
230 })
231}
232
233fn zero_local_span_cubic() -> LocalSpanCubic {
234 LocalSpanCubic {
235 left: 0.0,
236 right: 1.0,
237 c0: 0.0,
238 c1: 0.0,
239 c2: 0.0,
240 c3: 0.0,
241 }
242}
243
244pub(crate) fn build_denested_partition_cells(
245 a: f64,
246 b: f64,
247 score_warp: Option<&crate::bms::DeviationRuntime>,
248 beta_h: Option<&Array1<f64>>,
249 link_dev: Option<&crate::bms::DeviationRuntime>,
250 beta_w: Option<&Array1<f64>>,
251 scale: f64,
252) -> Result<Vec<DenestedPartitionCell>, String> {
253 let score_breaks = score_warp
254 .map(|runtime| runtime.breakpoints().to_vec())
255 .unwrap_or_default();
256 let link_breaks = link_dev
257 .map(|runtime| runtime.breakpoints().to_vec())
258 .unwrap_or_default();
259
260 let mut cells = cubic_cell_kernel::build_denested_partition_cells_with_tails(
261 a,
262 b,
263 &score_breaks,
264 &link_breaks,
265 |z| {
266 if let (Some(runtime), Some(beta)) = (score_warp, beta_h) {
267 runtime.local_cubic_at(beta, z)
268 } else {
269 Ok(zero_local_span_cubic())
270 }
271 },
272 |u| {
273 if let (Some(runtime), Some(beta)) = (link_dev, beta_w) {
274 runtime.local_cubic_at(beta, u)
275 } else {
276 Ok(zero_local_span_cubic())
277 }
278 },
279 )?;
280 if scale != 1.0 {
281 for partition_cell in &mut cells {
282 partition_cell.cell.c0 *= scale;
283 partition_cell.cell.c1 *= scale;
284 partition_cell.cell.c2 *= scale;
285 partition_cell.cell.c3 *= scale;
286 }
287 }
288 Ok(cells)
289}
290
/// Cubic coefficients of the de-nested cell at an observation, together with
/// their partial derivatives in `a` and `b` up to third order.
///
/// Every field is a coefficient quadruple. As produced by
/// [`observed_denested_cell_partials`], all of them already include the
/// caller's `scale` factor.
pub(crate) struct ObservedDenestedCellPartials {
    /// Cell coefficients.
    pub(crate) coeff: [f64; 4],
    /// First partial in `a`.
    pub(crate) dc_da: [f64; 4],
    /// First partial in `b`.
    pub(crate) dc_db: [f64; 4],
    /// Second partial in `a, a`.
    pub(crate) dc_daa: [f64; 4],
    /// Mixed second partial in `a, b`.
    pub(crate) dc_dab: [f64; 4],
    /// Second partial in `b, b`.
    pub(crate) dc_dbb: [f64; 4],
    /// Third partial in `a, a, a`.
    pub(crate) dc_daaa: [f64; 4],
    /// Third partial in `a, a, b`.
    pub(crate) dc_daab: [f64; 4],
    /// Third partial in `a, b, b`.
    pub(crate) dc_dabb: [f64; 4],
    /// Third partial in `b, b, b`.
    pub(crate) dc_dbbb: [f64; 4],
}
303
304pub(crate) fn observed_denested_cell_partials(
305 z_obs: f64,
306 a: f64,
307 b: f64,
308 score_warp: Option<&crate::bms::DeviationRuntime>,
309 beta_h: Option<&Array1<f64>>,
310 link_dev: Option<&crate::bms::DeviationRuntime>,
311 beta_w: Option<&Array1<f64>>,
312 scale: f64,
313) -> Result<ObservedDenestedCellPartials, String> {
314 let zero_score_span = zero_local_span_cubic();
315 let zero_link_span = zero_local_span_cubic();
316 let u_obs = a + b * z_obs;
317 let score_span_obs = if let (Some(runtime), Some(beta_h)) = (score_warp, beta_h) {
318 runtime.local_cubic_at(beta_h, z_obs)?
319 } else {
320 zero_score_span
321 };
322 let link_span_obs = if let (Some(runtime), Some(beta_w)) = (link_dev, beta_w) {
323 runtime.local_cubic_at(beta_w, u_obs)?
324 } else {
325 zero_link_span
326 };
327 let coeff = scale_coeff4(
328 cubic_cell_kernel::denested_cell_coefficients(score_span_obs, link_span_obs, a, b),
329 scale,
330 );
331 let (dc_da_raw, dc_db_raw) =
332 cubic_cell_kernel::denested_cell_coefficient_partials(score_span_obs, link_span_obs, a, b);
333 let (dc_daa_raw, dc_dab_raw, dc_dbb_raw) =
334 cubic_cell_kernel::denested_cell_second_partials(score_span_obs, link_span_obs, a, b);
335 let (dc_daaa, dc_daab, dc_dabb, dc_dbbb) =
336 cubic_cell_kernel::denested_cell_third_partials(link_span_obs);
337 Ok(ObservedDenestedCellPartials {
338 coeff,
339 dc_da: scale_coeff4(dc_da_raw, scale),
340 dc_db: scale_coeff4(dc_db_raw, scale),
341 dc_daa: scale_coeff4(dc_daa_raw, scale),
342 dc_dab: scale_coeff4(dc_dab_raw, scale),
343 dc_dbb: scale_coeff4(dc_dbb_raw, scale),
344 dc_daaa: scale_coeff4(dc_daaa, scale),
345 dc_daab: scale_coeff4(dc_daab, scale),
346 dc_dabb: scale_coeff4(dc_dabb, scale),
347 dc_dbbb: scale_coeff4(dc_dbbb, scale),
348 })
349}
350
351pub(crate) fn add_two_surface_psi_outer(
352 block_i: usize,
353 psi_row_i: &Array1<f64>,
354 block_j: usize,
355 psi_row_j: &Array1<f64>,
356 alpha: f64,
357 marginal_block: usize,
358 logslope_block: usize,
359 h_mm: &mut Array2<f64>,
360 h_gg: &mut Array2<f64>,
361 h_mg: &mut Array2<f64>,
362) {
363 if alpha == 0.0 {
364 return;
365 }
366 let col_i = psi_row_i.view().insert_axis(Axis(1));
367 let row_j = psi_row_j.view().insert_axis(Axis(0));
368
369 if block_i == block_j {
370 let col_j = psi_row_j.view().insert_axis(Axis(1));
371 let row_i = psi_row_i.view().insert_axis(Axis(0));
372 let target = match block_i {
373 b if b == marginal_block => h_mm,
374 b if b == logslope_block => h_gg,
375 _ => return,
376 };
377 ndarray::linalg::general_mat_mul(alpha, &col_i, &row_j, 1.0, target);
378 ndarray::linalg::general_mat_mul(alpha, &col_j, &row_i, 1.0, target);
379 } else {
380 let (marginal_row, logslope_row) = if block_i == marginal_block {
381 (psi_row_i, psi_row_j)
382 } else {
383 (psi_row_j, psi_row_i)
384 };
385 let m_col = marginal_row.view().insert_axis(Axis(1));
386 let g_row = logslope_row.view().insert_axis(Axis(0));
387 ndarray::linalg::general_mat_mul(alpha, &m_col, &g_row, 1.0, h_mg);
388 }
389}
390
391pub(crate) fn add_optional_vector(left: &mut Option<Array1<f64>>, right: &Option<Array1<f64>>) {
392 if let (Some(left), Some(right)) = (left.as_mut(), right.as_ref()) {
393 *left += right;
394 }
395}
396
397pub(crate) fn add_optional_matrix(left: &mut Option<Array2<f64>>, right: &Option<Array2<f64>>) {
398 if let (Some(left), Some(right)) = (left.as_mut(), right.as_ref()) {
399 *left += right;
400 }
401}
402
403pub(crate) fn psi_derivative_location(
404 derivative_blocks: &[Vec<CustomFamilyBlockPsiDerivative>],
405 psi_index: usize,
406) -> Option<(usize, usize)> {
407 let mut cursor = 0usize;
408 for (block_idx, block) in derivative_blocks.iter().enumerate() {
409 if psi_index < cursor + block.len() {
410 return Some((block_idx, psi_index - cursor));
411 }
412 cursor += block.len();
413 }
414 None
415}
416
417pub(crate) fn is_sigma_aux_index(
418 gaussian_frailty_sd: Option<f64>,
419 derivative_blocks: &[Vec<CustomFamilyBlockPsiDerivative>],
420 psi_index: usize,
421) -> bool {
422 let total = derivative_blocks.iter().map(Vec::len).sum::<usize>();
423 if gaussian_frailty_sd.is_none() || total == 0 || psi_index != total - 1 {
424 return false;
425 }
426 let Some((block_idx, local_idx)) = psi_derivative_location(derivative_blocks, psi_index) else {
427 return false;
428 };
429 let deriv = &derivative_blocks[block_idx][local_idx];
430 deriv.penalty_index.is_none()
431 && deriv.x_psi.is_empty()
432 && deriv.s_psi.is_empty()
433 && deriv.s_psi_components.is_none()
434 && deriv.x_psi_psi.is_none()
435 && deriv.s_psi_psi.is_none()
436}
437
438#[inline]
442pub(crate) fn parameter_block_specs_match_rows(
443 specs: &[ParameterBlockSpec],
444 expected_n: usize,
445) -> bool {
446 !specs.is_empty()
447 && specs
448 .iter()
449 .all(|spec| spec.design.nrows() == expected_n && spec.offset.len() == expected_n)
450}
451
/// Selects which parameter groups contribute to a coefficient contraction:
/// the primary coordinate, the `h` block, and/or the `w` block.
#[derive(Clone, Copy)]
pub(crate) struct CoeffSupport {
    /// Include the primary coordinate.
    pub(crate) include_primary: bool,
    /// Include indices in the `h` range.
    pub(crate) include_h: bool,
    /// Include indices in the `w` range.
    pub(crate) include_w: bool,
}

impl CoeffSupport {
    /// Same support with the primary coordinate switched off.
    #[inline]
    pub(crate) fn without_primary(mut self) -> Self {
        self.include_primary = false;
        self
    }
}
468
/// Borrowed view over per-parameter coefficient families with a sparse layout:
/// one primary parameter plus optional contiguous `h` and `w` index ranges.
///
/// Each family is a slice of coefficient quadruples indexed by parameter
/// index. The suffixes (`a`, `b`, `aa`, ...) mirror the partial-derivative
/// naming used elsewhere in this file — NOTE(review): the exact meaning of
/// each family is defined by the code that fills these slices; confirm there.
pub(crate) struct SparsePrimaryCoeffJetView<'a> {
    /// Index of the primary parameter.
    primary_index: usize,
    /// Index range of the `h` block, if present.
    h_range: Option<Range<usize>>,
    /// Index range of the `w` block, if present.
    w_range: Option<Range<usize>>,
    pub(crate) first: &'a [[f64; 4]],
    pub(crate) a_first: &'a [[f64; 4]],
    pub(crate) b_first: &'a [[f64; 4]],
    pub(crate) aa_first: &'a [[f64; 4]],
    pub(crate) ab_first: &'a [[f64; 4]],
    pub(crate) bb_first: &'a [[f64; 4]],
    pub(crate) aaa_first: &'a [[f64; 4]],
    pub(crate) aab_first: &'a [[f64; 4]],
    pub(crate) abb_first: &'a [[f64; 4]],
    pub(crate) bbb_first: &'a [[f64; 4]],
}
484
485impl<'a> SparsePrimaryCoeffJetView<'a> {
486 pub(crate) fn new(
487 primary_index: usize,
488 h_range: Option<&Range<usize>>,
489 w_range: Option<&Range<usize>>,
490 first: &'a [[f64; 4]],
491 a_first: &'a [[f64; 4]],
492 b_first: &'a [[f64; 4]],
493 aa_first: &'a [[f64; 4]],
494 ab_first: &'a [[f64; 4]],
495 bb_first: &'a [[f64; 4]],
496 aaa_first: &'a [[f64; 4]],
497 aab_first: &'a [[f64; 4]],
498 abb_first: &'a [[f64; 4]],
499 bbb_first: &'a [[f64; 4]],
500 ) -> Self {
501 Self {
502 primary_index,
503 h_range: h_range.cloned(),
504 w_range: w_range.cloned(),
505 first,
506 a_first,
507 b_first,
508 aa_first,
509 ab_first,
510 bb_first,
511 aaa_first,
512 aab_first,
513 abb_first,
514 bbb_first,
515 }
516 }
517
518 #[inline]
519 fn in_h_range(&self, idx: usize) -> bool {
520 self.h_range
521 .as_ref()
522 .map(|range| range.contains(&idx))
523 .unwrap_or(false)
524 }
525
526 #[inline]
527 fn in_w_range(&self, idx: usize) -> bool {
528 self.w_range
529 .as_ref()
530 .map(|range| range.contains(&idx))
531 .unwrap_or(false)
532 }
533
534 #[inline]
535 fn param_supported(&self, idx: usize, support: CoeffSupport) -> bool {
536 (support.include_primary && idx == self.primary_index)
537 || (support.include_h && self.in_h_range(idx))
538 || (support.include_w && self.in_w_range(idx))
539 }
540
541 pub(crate) fn directional_family(
542 &self,
543 family: &[[f64; 4]],
544 dir: &Array1<f64>,
545 support: CoeffSupport,
546 ) -> [f64; 4] {
547 let mut out = [0.0; 4];
548 if support.include_primary {
549 add_scaled_coeff4(
550 &mut out,
551 &family[self.primary_index],
552 dir[self.primary_index],
553 );
554 }
555 if support.include_h
556 && let Some(h_range) = self.h_range.as_ref()
557 {
558 for idx in h_range.clone() {
559 add_scaled_coeff4(&mut out, &family[idx], dir[idx]);
560 }
561 }
562 if support.include_w
563 && let Some(w_range) = self.w_range.as_ref()
564 {
565 for idx in w_range.clone() {
566 add_scaled_coeff4(&mut out, &family[idx], dir[idx]);
567 }
568 }
569 out
570 }
571
572 pub(crate) fn add_directional_family_adjoint(
573 &self,
574 family: &[[f64; 4]],
575 coeff_adjoint: &[f64; 4],
576 support: CoeffSupport,
577 direction_adjoint: &mut [f64],
578 ) {
579 assert!(direction_adjoint.len() > self.primary_index);
580 if support.include_primary {
581 direction_adjoint[self.primary_index] +=
582 coeff4_dot(coeff_adjoint, &family[self.primary_index]);
583 }
584 if support.include_h
585 && let Some(h_range) = self.h_range.as_ref()
586 {
587 for idx in h_range.clone() {
588 direction_adjoint[idx] += coeff4_dot(coeff_adjoint, &family[idx]);
589 }
590 }
591 if support.include_w
592 && let Some(w_range) = self.w_range.as_ref()
593 {
594 for idx in w_range.clone() {
595 direction_adjoint[idx] += coeff4_dot(coeff_adjoint, &family[idx]);
596 }
597 }
598 }
599
600 pub(crate) fn mixed_directional_from_b_family(
601 &self,
602 family: &[[f64; 4]],
603 dir_u: &Array1<f64>,
604 dir_v: &Array1<f64>,
605 support: CoeffSupport,
606 ) -> [f64; 4] {
607 let mut out = [0.0; 4];
608 let dir_u_primary = dir_u[self.primary_index];
609 let dir_v_primary = dir_v[self.primary_index];
610 if support.include_primary {
611 add_scaled_coeff4(
612 &mut out,
613 &family[self.primary_index],
614 dir_u_primary * dir_v_primary,
615 );
616 }
617 if support.include_h
618 && let Some(h_range) = self.h_range.as_ref()
619 {
620 for idx in h_range.clone() {
621 add_scaled_coeff4(
622 &mut out,
623 &family[idx],
624 dir_u_primary * dir_v[idx] + dir_v_primary * dir_u[idx],
625 );
626 }
627 }
628 if support.include_w
629 && let Some(w_range) = self.w_range.as_ref()
630 {
631 for idx in w_range.clone() {
632 add_scaled_coeff4(
633 &mut out,
634 &family[idx],
635 dir_u_primary * dir_v[idx] + dir_v_primary * dir_u[idx],
636 );
637 }
638 }
639 out
640 }
641
642 pub(crate) fn param_directional_from_b_family(
643 &self,
644 family: &[[f64; 4]],
645 param: usize,
646 dir: &Array1<f64>,
647 support: CoeffSupport,
648 ) -> [f64; 4] {
649 if param == self.primary_index {
650 return self.directional_family(family, dir, support);
651 }
652 if self.param_supported(param, support.without_primary()) {
653 let mut out = [0.0; 4];
654 add_scaled_coeff4(&mut out, &family[param], dir[self.primary_index]);
655 return out;
656 }
657 [0.0; 4]
658 }
659
660 pub(crate) fn add_param_directional_from_b_family_adjoint(
661 &self,
662 family: &[[f64; 4]],
663 param: usize,
664 coeff_adjoint: &[f64; 4],
665 support: CoeffSupport,
666 direction_adjoint: &mut [f64],
667 ) {
668 assert!(direction_adjoint.len() > self.primary_index);
669 if param == self.primary_index {
670 self.add_directional_family_adjoint(family, coeff_adjoint, support, direction_adjoint);
671 } else if self.param_supported(param, support.without_primary()) {
672 direction_adjoint[self.primary_index] += coeff4_dot(coeff_adjoint, &family[param]);
673 }
674 }
675
676 pub(crate) fn param_mixed_from_bb_family(
677 &self,
678 family: &[[f64; 4]],
679 param: usize,
680 dir_u: &Array1<f64>,
681 dir_v: &Array1<f64>,
682 support: CoeffSupport,
683 ) -> [f64; 4] {
684 if param == self.primary_index {
685 return self.mixed_directional_from_b_family(family, dir_u, dir_v, support);
686 }
687 if self.param_supported(param, support.without_primary()) {
688 let mut out = [0.0; 4];
689 add_scaled_coeff4(
690 &mut out,
691 &family[param],
692 dir_u[self.primary_index] * dir_v[self.primary_index],
693 );
694 return out;
695 }
696 [0.0; 4]
697 }
698
699 pub(crate) fn pair_from_b_family(
700 &self,
701 family: &[[f64; 4]],
702 u: usize,
703 v: usize,
704 support: CoeffSupport,
705 ) -> [f64; 4] {
706 if u == self.primary_index && v == self.primary_index {
707 if support.include_primary {
708 return family[self.primary_index];
709 }
710 return [0.0; 4];
711 }
712 if u == self.primary_index && self.param_supported(v, support.without_primary()) {
713 return family[v];
714 }
715 if v == self.primary_index && self.param_supported(u, support.without_primary()) {
716 return family[u];
717 }
718 [0.0; 4]
719 }
720
721 pub(crate) fn pair_directional_from_bb_family(
722 &self,
723 family: &[[f64; 4]],
724 u: usize,
725 v: usize,
726 dir: &Array1<f64>,
727 support: CoeffSupport,
728 ) -> [f64; 4] {
729 if u == self.primary_index && v == self.primary_index {
730 return self.directional_family(family, dir, support);
731 }
732 if u == self.primary_index && self.param_supported(v, support.without_primary()) {
733 let mut out = [0.0; 4];
734 add_scaled_coeff4(&mut out, &family[v], dir[self.primary_index]);
735 return out;
736 }
737 if v == self.primary_index && self.param_supported(u, support.without_primary()) {
738 let mut out = [0.0; 4];
739 add_scaled_coeff4(&mut out, &family[u], dir[self.primary_index]);
740 return out;
741 }
742 [0.0; 4]
743 }
744
745 pub(crate) fn add_pair_directional_from_bb_family_adjoint(
746 &self,
747 family: &[[f64; 4]],
748 u: usize,
749 v: usize,
750 coeff_adjoint: &[f64; 4],
751 support: CoeffSupport,
752 direction_adjoint: &mut [f64],
753 ) {
754 assert!(direction_adjoint.len() > self.primary_index);
755 if u == self.primary_index && v == self.primary_index {
756 self.add_directional_family_adjoint(family, coeff_adjoint, support, direction_adjoint);
757 } else if u == self.primary_index && self.param_supported(v, support.without_primary()) {
758 direction_adjoint[self.primary_index] += coeff4_dot(coeff_adjoint, &family[v]);
759 } else if v == self.primary_index && self.param_supported(u, support.without_primary()) {
760 direction_adjoint[self.primary_index] += coeff4_dot(coeff_adjoint, &family[u]);
761 }
762 }
763
764 pub(crate) fn pair_mixed_from_bbb_family(
765 &self,
766 family: &[[f64; 4]],
767 u: usize,
768 v: usize,
769 dir_u: &Array1<f64>,
770 dir_v: &Array1<f64>,
771 support: CoeffSupport,
772 ) -> [f64; 4] {
773 if u == self.primary_index && v == self.primary_index {
774 return self.mixed_directional_from_b_family(family, dir_u, dir_v, support);
775 }
776 if u == self.primary_index && self.param_supported(v, support.without_primary()) {
777 let mut out = [0.0; 4];
778 add_scaled_coeff4(
779 &mut out,
780 &family[v],
781 dir_u[self.primary_index] * dir_v[self.primary_index],
782 );
783 return out;
784 }
785 if v == self.primary_index && self.param_supported(u, support.without_primary()) {
786 let mut out = [0.0; 4];
787 add_scaled_coeff4(
788 &mut out,
789 &family[u],
790 dir_u[self.primary_index] * dir_v[self.primary_index],
791 );
792 return out;
793 }
794 [0.0; 4]
795 }
796}
797
798#[inline]
817const fn splitmix64(state: &mut u64) -> u64 {
818 gam_linalg::utils::splitmix64(state)
819}
820
/// Tuning knobs for choosing the automatic outer-score subsample size.
///
/// See `AutoOuterSubsampleOptions::target_k_detailed` for how the fields combine.
#[derive(Clone, Debug)]
pub struct AutoOuterSubsampleOptions {
    /// Subsampling is skipped when the row count is below this threshold.
    pub min_n_for_auto: usize,
    /// Lower bound applied to the noise-driven size `round(n * target_fraction)`.
    pub min_k: usize,
    /// Fraction of the rows targeted by the noise-driven size.
    pub target_fraction: f64,
    /// Seed forwarded to the subsample builder.
    pub seed: u64,
    /// Work units per subsampled row; the work budget is divided by this value
    /// (treated as at least 1) to obtain the work-driven size.
    pub outer_work_per_k_unit: u64,
    /// Final lower bound on the chosen size, applied after the noise/work minimum.
    pub min_k_floor: usize,
}
890
/// Total work budget; dividing it by the per-row work estimate gives the
/// work-driven cap on the automatic subsample size.
pub const AUTO_OUTER_WORK_BUDGET: u64 = 500_000_000;

/// Default for `AutoOuterSubsampleOptions::min_k_floor`.
pub const AUTO_OUTER_MIN_K_FLOOR: usize = 1_000;

/// Two outer rho keys whose L2 distance is at or below this tolerance are
/// treated as the same outer step.
const AUTO_OUTER_DISTINCT_STEP_L2_TOL: f64 = 1e-10;
910
/// Which constraint determined the automatic subsample size.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum AutoOuterCapReason {
    /// The noise-driven size was the binding one.
    Noise,
    /// The work budget capped the size below the noise-driven value.
    Work,
    /// The size was raised to the configured floor.
    Floor,
    /// The size was clamped to the full row count.
    NFull,
}

impl AutoOuterCapReason {
    /// Short lowercase label used in log output.
    pub fn as_str(self) -> &'static str {
        match self {
            Self::Noise => "noise",
            Self::Work => "work",
            Self::Floor => "floor",
            Self::NFull => "n",
        }
    }
}
933
934impl Default for AutoOuterSubsampleOptions {
935 fn default() -> Self {
936 Self {
937 min_n_for_auto: 30_000,
938 min_k: 10_000,
939 target_fraction: 0.10,
940 seed: 0xA075_8A8B_1ED5_5B5C,
941 outer_work_per_k_unit: 1,
942 min_k_floor: AUTO_OUTER_MIN_K_FLOOR,
943 }
944 }
945}
946
947#[derive(Clone, Copy, Debug)]
951pub struct AutoOuterKChoice {
952 pub k: usize,
953 pub k_noise: usize,
954 pub k_work: usize,
955 pub cap_reason: AutoOuterCapReason,
956}
957
958impl AutoOuterSubsampleOptions {
959 pub fn target_k(&self, n: usize) -> Option<usize> {
962 self.target_k_detailed(n).map(|choice| choice.k)
963 }
964
965 pub fn target_k_detailed(&self, n: usize) -> Option<AutoOuterKChoice> {
970 if n < self.min_n_for_auto {
971 return None;
972 }
973 let k_noise_raw = ((n as f64) * self.target_fraction).round() as usize;
974 let k_noise = k_noise_raw.max(self.min_k);
975 let work_per_k = self.outer_work_per_k_unit.max(1);
980 let k_work_u64 = AUTO_OUTER_WORK_BUDGET / work_per_k;
981 let k_work = usize::try_from(k_work_u64).unwrap_or(usize::MAX);
982 let mut k = k_noise.min(k_work);
985 let mut cap_reason = if k_work < k_noise {
986 AutoOuterCapReason::Work
987 } else {
988 AutoOuterCapReason::Noise
989 };
990 if k < self.min_k_floor {
991 k = self.min_k_floor;
992 cap_reason = AutoOuterCapReason::Floor;
993 }
994 if k > n {
995 k = n;
996 cap_reason = AutoOuterCapReason::NFull;
997 }
998 if k >= n {
999 return None;
1002 }
1003 Some(AutoOuterKChoice {
1004 k,
1005 k_noise,
1006 k_work,
1007 cap_reason,
1008 })
1009 }
1010}
1011
1012pub fn auto_outer_score_subsample(
1025 z: &[f64],
1026 stratum_secondary: Option<&[u8]>,
1027 options: &AutoOuterSubsampleOptions,
1028) -> Option<OuterScoreSubsample> {
1029 let n = z.len();
1030 let k = options.target_k(n)?;
1031 let secondary_storage;
1032 let secondary: &[u8] = if let Some(s) = stratum_secondary {
1033 if s.len() != n {
1034 return None;
1036 }
1037 s
1038 } else {
1039 secondary_storage = vec![0u8; n];
1040 &secondary_storage
1041 };
1042 Some(build_outer_score_subsample(z, secondary, k, options.seed))
1043}
1044
1045pub fn maybe_install_auto_outer_subsample(
1070 options: &crate::custom_family::BlockwiseFitOptions,
1071 z: &[f64],
1072 stratum_secondary: Option<&[u8]>,
1073 outer_rho_key: &[f64],
1074 phase_counter: &Arc<std::sync::atomic::AtomicUsize>,
1075 last_rho: &Arc<std::sync::Mutex<Option<Array1<f64>>>>,
1076 phase1_budget: usize,
1077 family_label: &'static str,
1078 outer_work_per_k_unit: u64,
1079 min_n_for_auto: usize,
1080 min_k: usize,
1081 min_k_floor: usize,
1082) -> Option<crate::custom_family::BlockwiseFitOptions> {
1083 if options.outer_score_subsample.is_some() || !options.auto_outer_subsample {
1084 return None;
1085 }
1086 let phase_idx = {
1087 let mut guard = last_rho
1088 .lock()
1089 .expect("auto_subsample_last_rho mutex poisoned");
1090 let new_step = match guard.as_ref() {
1091 None => true,
1092 Some(prev) if prev.len() != outer_rho_key.len() => true,
1093 Some(prev) => {
1094 let mut sq = 0.0_f64;
1095 for (a, b) in outer_rho_key.iter().zip(prev.iter()) {
1096 let d = a - b;
1097 sq += d * d;
1098 }
1099 sq.sqrt() > AUTO_OUTER_DISTINCT_STEP_L2_TOL
1100 }
1101 };
1102 if new_step {
1103 *guard = Some(Array1::from(outer_rho_key.to_vec()));
1104 phase_counter.fetch_add(1, std::sync::atomic::Ordering::SeqCst)
1105 } else {
1106 phase_counter
1107 .load(std::sync::atomic::Ordering::SeqCst)
1108 .saturating_sub(1)
1109 }
1110 };
1111 if phase_idx >= phase1_budget {
1112 if phase_idx == phase1_budget {
1113 log::info!(
1114 "[{family_label} auto-subsample] Phase 1 budget exhausted after {} evals; \
1115 Phase 2 (full data) for remaining iterations",
1116 phase1_budget
1117 );
1118 }
1119 return None;
1120 }
1121 let auto_options = AutoOuterSubsampleOptions {
1130 min_n_for_auto,
1131 min_k,
1132 min_k_floor,
1133 outer_work_per_k_unit: outer_work_per_k_unit.max(1),
1134 ..AutoOuterSubsampleOptions::default()
1135 };
1136 let choice = auto_options.target_k_detailed(z.len())?;
1140 let mask = auto_outer_score_subsample(z, stratum_secondary, &auto_options)?;
1141 let n_full = mask.n_full;
1142 let k = mask.len();
1143 log::info!(
1144 "[{family_label} auto-subsample] phase=1 eval={}/{} n={} K={} fraction={:.3} expected_grad_noise={:.2}% work_per_k_unit={} k_noise={} k_work={} cap_reason={}",
1145 phase_idx + 1,
1146 phase1_budget,
1147 n_full,
1148 k,
1149 k as f64 / n_full.max(1) as f64,
1150 100.0 * (1.0 / (k as f64).sqrt()) * (1.0 - k as f64 / n_full.max(1) as f64).sqrt(),
1151 outer_work_per_k_unit,
1152 choice.k_noise,
1153 choice.k_work,
1154 choice.cap_reason.as_str(),
1155 );
1156 let mut cloned = options.clone();
1157 cloned.outer_score_subsample = Some(Arc::new(mask));
1158 Some(cloned)
1159}
1160
1161pub fn build_outer_score_subsample(
1178 z: &[f64],
1179 stratum_secondary: &[u8],
1180 k: usize,
1181 seed: u64,
1182) -> OuterScoreSubsample {
1183 let n = z.len();
1184 assert_eq!(
1185 n,
1186 stratum_secondary.len(),
1187 "build_outer_score_subsample: z and stratum_secondary must have equal length",
1188 );
1189
1190 if n == 0 {
1191 return OuterScoreSubsample::with_uniform_weight(Vec::new(), 0, seed, 1.0);
1192 }
1193
1194 if k >= n {
1198 let mask: Vec<usize> = (0..n).collect();
1199 return OuterScoreSubsample::with_uniform_weight(mask, n, seed, 1.0);
1200 }
1201
1202 const Q: usize = 100;
1204 let mut z_order: Vec<usize> = (0..n).collect();
1205 z_order.sort_by(|&a, &b| z[a].partial_cmp(&z[b]).unwrap_or(std::cmp::Ordering::Equal));
1206 let mut decile = vec![0u16; n];
1208 for (rank, &row) in z_order.iter().enumerate() {
1209 let bin = (rank * Q) / n;
1212 let bin = bin.min(Q - 1);
1213 decile[row] = bin as u16;
1214 }
1215
1216 let mut distinct_secondary: Vec<u8> = stratum_secondary.to_vec();
1219 distinct_secondary.sort_unstable();
1220 distinct_secondary.dedup();
1221 let mut secondary_rank = vec![0u16; 256];
1224 for (rank, &val) in distinct_secondary.iter().enumerate() {
1225 secondary_rank[val as usize] = rank as u16;
1226 }
1227 let n_strata = distinct_secondary.len() * Q;
1228
1229 let mut strata: Vec<Vec<usize>> = vec![Vec::new(); n_strata];
1231 for i in 0..n {
1232 let s = secondary_rank[stratum_secondary[i] as usize] as usize * Q + decile[i] as usize;
1233 strata[s].push(i);
1234 }
1235
1236 let mut picked: Vec<WeightedOuterRow> = Vec::with_capacity(k + n_strata);
1239 for (stratum_id, rows) in strata.iter().enumerate() {
1240 if rows.is_empty() {
1241 continue;
1242 }
1243 let take = (k as u128 * rows.len() as u128).div_ceil(n as u128) as usize;
1244 let take = take.max(1).min(rows.len());
1245 let w_h = rows.len() as f64 / take as f64;
1248 let stratum_tag = stratum_id as u32;
1249
1250 let mut state = seed ^ (stratum_id as u64).wrapping_mul(0x9E3779B97F4A7C15);
1252 splitmix64(&mut state);
1254
1255 if take == rows.len() {
1256 for &index in rows.iter() {
1257 picked.push(WeightedOuterRow {
1258 index,
1259 weight: w_h,
1260 stratum: stratum_tag,
1261 });
1262 }
1263 } else {
1264 let mut buf: Vec<usize> = rows.clone();
1266 let m = buf.len();
1267 for i in 0..take {
1268 let r = splitmix64(&mut state);
1269 let j = i + (r as usize) % (m - i);
1270 buf.swap(i, j);
1271 }
1272 for &index in &buf[..take] {
1273 picked.push(WeightedOuterRow {
1274 index,
1275 weight: w_h,
1276 stratum: stratum_tag,
1277 });
1278 }
1279 }
1280 }
1281
1282 OuterScoreSubsample::from_weighted_rows(picked, n, seed)
1286}
1287
/// Set of row indices visited by an outer-score pass: either every row
/// `0..n` or an explicit shared index list.
#[derive(Debug, Clone)]
pub enum OuterRowIter {
    /// All rows `0..n`.
    All { n: usize },
    /// Only the rows listed in `mask`.
    Subset { mask: Arc<Vec<usize>> },
}

impl OuterRowIter {
    /// Number of rows in the set.
    #[inline]
    pub fn len(&self) -> usize {
        match self {
            Self::All { n } => *n,
            Self::Subset { mask } => mask.len(),
        }
    }

    /// Whether the set contains no rows.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Materialises the row indices into a freshly allocated vector.
    pub fn to_vec(&self) -> Vec<usize> {
        match self {
            Self::All { n } => Vec::from_iter(0..*n),
            Self::Subset { mask } => mask.to_vec(),
        }
    }
}
1332
1333pub fn outer_row_indices(
1342 opts: &crate::custom_family::BlockwiseFitOptions,
1343 n: usize,
1344) -> OuterRowIter {
1345 match opts.outer_score_subsample.as_ref() {
1346 Some(s) => OuterRowIter::Subset {
1347 mask: Arc::clone(&s.mask),
1348 },
1349 None => OuterRowIter::All { n },
1350 }
1351}
1352
1353pub fn outer_weighted_rows(
1357 opts: &crate::custom_family::BlockwiseFitOptions,
1358 n: usize,
1359) -> Vec<WeightedOuterRow> {
1360 match opts.outer_score_subsample.as_ref() {
1361 Some(s) => s.rows.as_ref().clone(),
1362 None => (0..n)
1363 .map(|index| WeightedOuterRow {
1364 index,
1365 weight: 1.0,
1366 stratum: 0,
1367 })
1368 .collect(),
1369 }
1370}
1371
1372pub fn outer_row_weights_by_index(
1377 opts: &crate::custom_family::BlockwiseFitOptions,
1378 n: usize,
1379) -> Vec<f64> {
1380 match opts.outer_score_subsample.as_ref() {
1381 Some(s) => {
1382 let mut weights = vec![1.0; n];
1383 for r in s.rows.iter() {
1384 if r.index < n {
1385 weights[r.index] = r.weight;
1386 }
1387 }
1388 weights
1389 }
1390 None => vec![1.0; n],
1391 }
1392}
1393
1394pub fn feasible_step_fraction<E>(
1411 constraints: &gam_problem::LinearInequalityConstraints,
1412 beta: &Array1<f64>,
1413 direction: &Array1<f64>,
1414 map_dim_err: impl Fn(usize, usize, usize) -> E,
1415 map_violation_err: impl Fn(usize, f64) -> E,
1416) -> Result<f64, E> {
1417 if beta.len() != constraints.a.ncols() || direction.len() != constraints.a.ncols() {
1418 return Err(map_dim_err(
1419 beta.len(),
1420 direction.len(),
1421 constraints.a.ncols(),
1422 ));
1423 }
1424 const FEASIBLE_STEP_VIOLATION_TOL: f64 = 1e-8;
1435 const FEASIBLE_STEP_BOUNDARY_BACKOFF: f64 = 0.995;
1440 let mut alpha = 1.0f64;
1441 for row in 0..constraints.a.nrows() {
1442 let a_row = constraints.a.row(row);
1443 let raw_slack = a_row.dot(beta) - constraints.b[row];
1444 if raw_slack < -FEASIBLE_STEP_VIOLATION_TOL {
1445 return Err(map_violation_err(row, raw_slack));
1446 }
1447 let slack = raw_slack.max(0.0);
1450 let drift = a_row.dot(direction);
1451 if drift < 0.0 {
1452 alpha = alpha.min((slack / -drift).clamp(0.0, 1.0));
1453 }
1454 }
1455 if alpha >= 1.0 {
1456 Ok(1.0)
1457 } else {
1458 Ok((FEASIBLE_STEP_BOUNDARY_BACKOFF * alpha).clamp(0.0, 1.0))
1459 }
1460}
1461
1462pub trait MarginalSlopePsiFamily: Send + Sync {
1481 fn is_sigma_aux(&self, psi_index: usize) -> bool;
1484
1485 fn sigma_first_order_terms(
1487 &self,
1488 ) -> Result<Option<crate::custom_family::ExactNewtonJointPsiTerms>, String>;
1489
1490 fn psi_first_order_terms(
1492 &self,
1493 psi_index: usize,
1494 ) -> Result<Option<crate::custom_family::ExactNewtonJointPsiTerms>, String>;
1495
1496 fn psi_first_order_terms_all(
1501 &self,
1502 ) -> Result<Option<Vec<crate::custom_family::ExactNewtonJointPsiTerms>>, String>;
1503
1504 fn both_sigma_aux_second_order(&self, psi_i: usize, psi_j: usize) -> bool;
1509
1510 fn sigma_second_order_terms(
1512 &self,
1513 ) -> Result<Option<crate::custom_family::ExactNewtonJointPsiSecondOrderTerms>, String>;
1514
1515 fn mixed_sigma_aux_second_order(
1519 &self,
1520 ) -> Result<Option<crate::custom_family::ExactNewtonJointPsiSecondOrderTerms>, String>;
1521
1522 fn psi_second_order_terms(
1524 &self,
1525 psi_i: usize,
1526 psi_j: usize,
1527 ) -> Result<Option<crate::custom_family::ExactNewtonJointPsiSecondOrderTerms>, String>;
1528
1529 fn psi_second_order_terms_contracted(
1542 &self,
1543 _: &[f64],
1544 ) -> Result<Option<crate::custom_family::ExactNewtonJointPsiSecondOrderContracted>, String>
1545 {
1546 Ok(None)
1548 }
1549
1550 fn sigma_hessian_directional_derivative(
1554 &self,
1555 d_beta_flat: &Array1<f64>,
1556 ) -> Result<Option<Array2<f64>>, String>;
1557
1558 fn psi_hessian_directional_derivative(
1562 &self,
1563 psi_index: usize,
1564 d_beta_flat: &Array1<f64>,
1565 ) -> Result<Option<Arc<dyn gam_problem::HyperOperator>>, String>;
1566}
1567
1568pub struct MarginalSlopeExactNewtonPsiWorkspace<F: MarginalSlopePsiFamily> {
1572 family: F,
1573}
1574
1575impl<F: MarginalSlopePsiFamily> MarginalSlopeExactNewtonPsiWorkspace<F> {
1576 pub fn new(family: F) -> Self {
1577 Self { family }
1578 }
1579}
1580
1581impl<F: MarginalSlopePsiFamily> crate::custom_family::ExactNewtonJointPsiWorkspace
1582 for MarginalSlopeExactNewtonPsiWorkspace<F>
1583{
1584 fn first_order_terms(
1585 &self,
1586 psi_index: usize,
1587 ) -> Result<Option<crate::custom_family::ExactNewtonJointPsiTerms>, String> {
1588 if self.family.is_sigma_aux(psi_index) {
1589 return self.family.sigma_first_order_terms();
1590 }
1591 self.family.psi_first_order_terms(psi_index)
1592 }
1593
1594 fn first_order_terms_all(
1595 &self,
1596 ) -> Result<Option<Vec<crate::custom_family::ExactNewtonJointPsiTerms>>, String> {
1597 self.family.psi_first_order_terms_all()
1598 }
1599
1600 fn second_order_terms(
1601 &self,
1602 psi_i: usize,
1603 psi_j: usize,
1604 ) -> Result<Option<crate::custom_family::ExactNewtonJointPsiSecondOrderTerms>, String> {
1605 if self.family.is_sigma_aux(psi_i) || self.family.is_sigma_aux(psi_j) {
1606 if self.family.both_sigma_aux_second_order(psi_i, psi_j) {
1607 return self.family.sigma_second_order_terms();
1608 }
1609 return self.family.mixed_sigma_aux_second_order();
1610 }
1611 self.family.psi_second_order_terms(psi_i, psi_j)
1612 }
1613
1614 fn second_order_terms_contracted(
1615 &self,
1616 alpha_psi: &[f64],
1617 ) -> Result<Option<crate::custom_family::ExactNewtonJointPsiSecondOrderContracted>, String>
1618 {
1619 for (j, &weight) in alpha_psi.iter().enumerate() {
1628 if weight != 0.0 && self.family.is_sigma_aux(j) {
1629 return Ok(None);
1630 }
1631 }
1632 self.family.psi_second_order_terms_contracted(alpha_psi)
1633 }
1634
1635 fn hessian_directional_derivative(
1636 &self,
1637 psi_index: usize,
1638 d_beta_flat: &Array1<f64>,
1639 ) -> Result<Option<gam_problem::DriftDerivResult>, String> {
1640 if self.family.is_sigma_aux(psi_index) {
1641 return self
1642 .family
1643 .sigma_hessian_directional_derivative(d_beta_flat)
1644 .map(|result| result.map(gam_problem::DriftDerivResult::Dense));
1645 }
1646 self.family
1647 .psi_hessian_directional_derivative(psi_index, d_beta_flat)
1648 .map(|result| result.map(gam_problem::DriftDerivResult::Operator))
1649 }
1650}
1651
1652pub(crate) fn chunked_row_reduction<Item, Acc, Init, Process, Combine>(
1674 rows: &[Item],
1675 init: Init,
1676 process_row: Process,
1677 mut combine: Combine,
1678) -> Result<Acc, String>
1679where
1680 Item: Sync + Copy,
1681 Acc: Send,
1682 Init: Fn() -> Acc + Sync,
1683 Process: Fn(Item, &mut Acc) -> Result<(), String> + Sync,
1684 Combine: FnMut(&mut Acc, Acc),
1685{
1686 use rayon::iter::{IntoParallelIterator, ParallelIterator};
1687 let n = rows.len();
1688 if n == 0 {
1689 return Ok(init());
1690 }
1691 const CHUNKS_PER_WORKER: usize = 4;
1703 const MIN_CHUNK_COUNT: usize = 32;
1704 const MIN_ROWS_PER_CHUNK: usize = 64;
1705 let workers = rayon::current_num_threads().max(1);
1706 let target_chunk_count = workers
1707 .saturating_mul(CHUNKS_PER_WORKER)
1708 .max(MIN_CHUNK_COUNT);
1709 let chunk_count = target_chunk_count
1712 .min(n.div_ceil(MIN_ROWS_PER_CHUNK))
1713 .max(1);
1714 let chunk_size = n.div_ceil(chunk_count).max(1);
1715 let n_chunks = n.div_ceil(chunk_size);
1716 let chunk_states: Vec<Acc> = (0..n_chunks)
1721 .into_par_iter()
1722 .map(|chunk_idx| -> Result<Acc, String> {
1723 let start = chunk_idx * chunk_size;
1724 let end = (start + chunk_size).min(n);
1725 let mut acc = init();
1726 for &item in &rows[start..end] {
1727 process_row(item, &mut acc)?;
1728 }
1729 Ok(acc)
1730 })
1731 .collect::<Result<Vec<Acc>, String>>()?;
1732 let mut total = init();
1733 for chunk in chunk_states {
1734 combine(&mut total, chunk);
1735 }
1736 Ok(total)
1737}
1738
1739#[cfg(test)]
1740mod tests {
1741 use super::*;
1742
1743 use gam_math::jet_partitions::MultiDirJet;
1766
1767 struct Lcg(u64);
1770 impl Lcg {
1771 fn next_f64(&mut self) -> f64 {
1772 self.0 = self.0.wrapping_mul(6364136223846793005).wrapping_add(1);
1773 ((self.0 >> 11) as f64) / ((1u64 << 53) as f64) * 2.0 - 1.0
1774 }
1775 }
1776
1777 fn synthetic_row_eval(
1783 bases: &[f64],
1784 weight: f64,
1785 dirs: &[&Array1<f64>],
1786 scale: &MultiDirJet,
1787 ) -> Result<f64, String> {
1788 let k = dirs.len();
1789 if k > 4 {
1790 return Err(format!("synthetic eval expects 0..=4 directions, got {k}"));
1791 }
1792 if scale.coeffs.len() != (1usize << k) {
1793 return Err(format!(
1794 "synthetic eval scale jet dimension mismatch: coeffs={}, dirs={k}",
1795 scale.coeffs.len()
1796 ));
1797 }
1798 let primary_dim = bases.len();
1799 let first = |dir: &Array1<f64>| -> Vec<f64> {
1803 (0..k).map(|j| dir[j % primary_dim]).collect::<Vec<f64>>()
1804 };
1805 let mut product = MultiDirJet::constant(k, 1.0);
1806 for (slot, dir) in dirs.iter().enumerate() {
1807 let base = bases[slot % primary_dim] + 0.25 * slot as f64;
1808 let comps: Vec<f64> = (0..primary_dim)
1809 .map(|p| dir[p] * (1.0 + 0.5 * p as f64))
1810 .collect();
1811 let lin = MultiDirJet::linear(k, base, &first(&Array1::from(comps)));
1812 product = product.mul(&lin);
1813 }
1814 let scaled = product.mul(scale);
1815 let x = scaled.coeff(0);
1818 let denom = 1.0 + x * x;
1819 let d1 = weight * (2.0 * x) / denom;
1820 let d2 = weight * (2.0 * (1.0 - x * x)) / (denom * denom);
1821 let d3 = weight * (-4.0 * x * (3.0 - x * x)) / (denom * denom * denom);
1822 let d4 = weight * (-12.0 * (1.0 - 6.0 * x * x + x * x * x * x))
1823 / (denom * denom * denom * denom);
1824 let phi = weight * denom.ln();
1825 Ok(scaled
1826 .compose_unary([phi, d1, d2, d3, d4])
1827 .coeff((1usize << k) - 1))
1828 }
1829
1830 fn reference_obj_grad_hess<Eval>(
1834 primary_dim: usize,
1835 leading: &[&Array1<f64>],
1836 scales: &DirectionalScaleJets,
1837 eval: Eval,
1838 ) -> Result<(f64, Array1<f64>, Array2<f64>), String>
1839 where
1840 Eval: Fn(&[&Array1<f64>], &MultiDirJet) -> Result<f64, String>,
1841 {
1842 let unit = |a: usize| -> Array1<f64> {
1843 let mut da = Array1::<f64>::zeros(primary_dim);
1844 da[a] = 1.0;
1845 da
1846 };
1847 let objective = if let Some(scale_obj) = scales.obj.as_ref() {
1848 eval(leading, scale_obj)?
1849 } else {
1850 0.0
1851 };
1852 let mut grad = Array1::<f64>::zeros(primary_dim);
1853 for a in 0..primary_dim {
1854 let da = unit(a);
1855 let mut dirs: Vec<&Array1<f64>> = leading.to_vec();
1856 dirs.push(&da);
1857 grad[a] = eval(&dirs, &scales.grad)?;
1858 }
1859 let mut hess = Array2::<f64>::zeros((primary_dim, primary_dim));
1860 for a in 0..primary_dim {
1861 let da = unit(a);
1862 for b in a..primary_dim {
1863 let db = unit(b);
1864 let mut dirs: Vec<&Array1<f64>> = leading.to_vec();
1865 dirs.push(&da);
1866 dirs.push(&db);
1867 let value = eval(&dirs, &scales.hess)?;
1868 hess[[a, b]] = value;
1869 hess[[b, a]] = value;
1870 }
1871 }
1872 Ok((objective, grad, hess))
1873 }
1874
1875 fn random_scale_jet(
1880 rng: &mut Lcg,
1881 n_dirs: usize,
1882 first_masks: &[usize],
1883 second_masks: &[usize],
1884 ) -> MultiDirJet {
1885 let mut coeffs: Vec<(usize, f64)> = vec![(0usize, 1.0 + 0.1 * rng.next_f64())];
1886 for &m in first_masks {
1887 coeffs.push((1usize << m, rng.next_f64()));
1888 }
1889 for &m in second_masks {
1890 coeffs.push(((1usize << m) | 1usize, rng.next_f64()));
1891 }
1892 MultiDirJet::with_coeffs(n_dirs, &coeffs)
1893 }
1894
1895 #[test]
1896 fn directional_obj_grad_hess_matches_reference_loop_nest() {
1897 let primary_dim = 4usize;
1898 let mut rng = Lcg(0x5EED_1234_ABCD_0001);
1899 for trial in 0..32 {
1903 let bases: Vec<f64> = (0..primary_dim).map(|_| rng.next_f64()).collect();
1904 let weight = 0.5 + 0.5 * (rng.next_f64() + 1.0);
1905 let eval = |dirs: &[&Array1<f64>], scale: &MultiDirJet| {
1906 synthetic_row_eval(&bases, weight, dirs, scale)
1907 };
1908
1909 let zero = Array1::<f64>::zeros(primary_dim);
1910 let row_dir: Array1<f64> =
1911 Array1::from((0..primary_dim).map(|_| rng.next_f64()).collect::<Vec<_>>());
1912
1913 let cases: Vec<(Vec<&Array1<f64>>, DirectionalScaleJets)> = vec![
1914 (
1915 vec![&zero],
1916 DirectionalScaleJets {
1917 obj: Some(random_scale_jet(&mut rng, 1, &[], &[])),
1918 grad: random_scale_jet(&mut rng, 2, &[0], &[]),
1919 hess: random_scale_jet(&mut rng, 3, &[0], &[]),
1920 },
1921 ),
1922 (
1923 vec![&zero, &zero],
1924 DirectionalScaleJets {
1925 obj: Some(random_scale_jet(&mut rng, 2, &[0, 1], &[])),
1926 grad: random_scale_jet(&mut rng, 3, &[0, 1], &[]),
1927 hess: random_scale_jet(&mut rng, 4, &[0, 1], &[]),
1928 },
1929 ),
1930 (
1931 vec![&zero, &row_dir],
1932 DirectionalScaleJets {
1933 obj: None,
1934 grad: random_scale_jet(&mut rng, 3, &[0], &[]),
1935 hess: random_scale_jet(&mut rng, 4, &[0], &[]),
1936 },
1937 ),
1938 ];
1939
1940 for (leading, scales) in &cases {
1941 let shared =
1942 directional_obj_grad_hess(primary_dim, leading, scales, eval).expect("shared");
1943 let (ref_obj, ref_grad, ref_hess) =
1944 reference_obj_grad_hess(primary_dim, leading, scales, eval).expect("reference");
1945
1946 assert_eq!(
1947 shared.objective, ref_obj,
1948 "trial {trial}: objective drift {} vs {}",
1949 shared.objective, ref_obj
1950 );
1951 for a in 0..primary_dim {
1952 assert_eq!(
1953 shared.grad[a], ref_grad[a],
1954 "trial {trial}: grad[{a}] drift {} vs {}",
1955 shared.grad[a], ref_grad[a]
1956 );
1957 for b in 0..primary_dim {
1958 assert_eq!(
1959 shared.hess[[a, b]],
1960 ref_hess[[a, b]],
1961 "trial {trial}: hess[{a},{b}] drift {} vs {}",
1962 shared.hess[[a, b]],
1963 ref_hess[[a, b]]
1964 );
1965 }
1966 }
1967 for a in 0..primary_dim {
1971 for b in 0..primary_dim {
1972 assert_eq!(
1973 shared.hess[[a, b]],
1974 shared.hess[[b, a]],
1975 "trial {trial}: hess asymmetric at ({a},{b})"
1976 );
1977 }
1978 }
1979 }
1980 }
1981 }
1982
1983 #[test]
1984 fn auto_outer_score_subsample_skips_small_problems() {
1985 let n = 1000;
1986 let z: Vec<f64> = (0..n).map(|i| i as f64).collect();
1987 let opts = AutoOuterSubsampleOptions::default();
1988 assert!(
1989 auto_outer_score_subsample(&z, None, &opts).is_none(),
1990 "n={n} below default min_n_for_auto=30000 should not subsample"
1991 );
1992 }
1993
1994 #[test]
1995 fn auto_outer_score_subsample_returns_target_k_above_threshold() {
1996 let n = 60_000;
1997 let z: Vec<f64> = (0..n).map(|i| (i as f64).sin()).collect();
1998 let opts = AutoOuterSubsampleOptions::default();
1999 let mask = auto_outer_score_subsample(&z, None, &opts)
2000 .expect("n=60000 should auto-subsample with default options");
2001 assert_eq!(mask.n_full, n);
2003 assert!(
2004 mask.len() >= 9_900 && mask.len() <= 10_200,
2005 "expected K≈10_000, got {}",
2006 mask.len()
2007 );
2008 let weight_sum: f64 = mask.rows.iter().map(|r| r.weight).sum();
2011 let rel_err = (weight_sum - n as f64).abs() / n as f64;
2012 assert!(
2013 rel_err < 0.02,
2014 "HT weight sum {weight_sum:.3} should ≈ n_full={n}, rel_err={rel_err:.4}"
2015 );
2016 }
2017
2018 #[test]
2019 fn auto_outer_score_subsample_horvitz_thompson_unbiased() {
2020 let n = 50_000;
2026 let z: Vec<f64> = (0..n)
2027 .map(|i| ((i as f64) / n as f64) * 2.0 - 1.0)
2028 .collect();
2029 let stratum: Vec<u8> = (0..n).map(|i| if i % 3 == 0 { 1 } else { 0 }).collect();
2030 let opts = AutoOuterSubsampleOptions {
2031 seed: 0xC0FFEE,
2032 ..AutoOuterSubsampleOptions::default()
2033 };
2034 let t: Vec<f64> = z.iter().map(|zi| zi * zi + 1.0).collect();
2035 let exact: f64 = t.iter().sum();
2036 let mask = auto_outer_score_subsample(&z, Some(&stratum), &opts)
2037 .expect("n=50000 should auto-subsample");
2038 let estimate: f64 = mask.rows.iter().map(|r| r.weight * t[r.index]).sum();
2039 let k = mask.len();
2043 let predicted_se =
2044 exact * 0.4 * (1.0 / (k as f64).sqrt()) * (1.0 - k as f64 / n as f64).sqrt();
2045 let observed_err = (estimate - exact).abs();
2046 assert!(
2047 observed_err < 5.0 * predicted_se.max(1.0),
2048 "HT estimate {estimate:.3} vs exact {exact:.3}: err={observed_err:.3} exceeds 5×predicted_se={:.3}",
2049 predicted_se
2050 );
2051 }
2052
2053 #[test]
2054 fn subsample_full_n_equals_no_subsample() {
2055 let n: usize = 1024;
2059 let z: Vec<f64> = (0..n).map(|i| i as f64).collect();
2060 let secondary: Vec<u8> = (0..n).map(|i| (i % 2) as u8).collect();
2061 let s = build_outer_score_subsample(&z, &secondary, n, 0xDEADBEEF);
2062 assert_eq!(s.len(), n);
2063 assert!((s.weight_scale - 1.0).abs() < 1e-12);
2064
2065 let mut full = crate::custom_family::BlockwiseFitOptions::default();
2066 let from_none = outer_row_indices(&full, n).to_vec();
2067 full.outer_score_subsample = Some(Arc::new(s));
2068 let from_some = outer_row_indices(&full, n).to_vec();
2069
2070 let mut a = from_none.clone();
2071 let mut b = from_some.clone();
2072 a.sort_unstable();
2073 b.sort_unstable();
2074 assert_eq!(a, b);
2075 assert_eq!(a, (0..n).collect::<Vec<_>>());
2076 }
2077
2078 #[test]
2079 fn stratification_covers_all_strata() {
2080 let n: usize = 20_000;
2083 let z: Vec<f64> = (0..n).map(|i| (i as f64) * 0.001).collect();
2084 let secondary: Vec<u8> = (0..n).map(|i| (i % 2) as u8).collect();
2085 let k = 2_000;
2086 let s = build_outer_score_subsample(&z, &secondary, k, 12345);
2087 assert!(s.len() >= k, "subsample size {} < k {}", s.len(), k);
2088
2089 let mut order: Vec<usize> = (0..n).collect();
2091 order.sort_by(|&a, &b| z[a].partial_cmp(&z[b]).unwrap());
2092 let mut decile = vec![0usize; n];
2093 for (rank, &row) in order.iter().enumerate() {
2094 decile[row] = ((rank * 100) / n).min(99);
2095 }
2096 let mut covered = [false; 200];
2098 for &row in s.mask.iter() {
2099 let stratum = secondary[row] as usize * 100 + decile[row];
2100 covered[stratum] = true;
2101 }
2102 for (stratum, &c) in covered.iter().enumerate() {
2105 assert!(c, "stratum {} uncovered", stratum);
2106 }
2107 }
2108
2109 #[test]
2110 fn deterministic_seed() {
2111 let n: usize = 5_000;
2115 let z: Vec<f64> = (0..n).map(|i| (i as f64).sin()).collect();
2116 let secondary: Vec<u8> = (0..n).map(|i| (i % 2) as u8).collect();
2117 let k = 800;
2118 let a = build_outer_score_subsample(&z, &secondary, k, 0xABCDEF);
2119 let b = build_outer_score_subsample(&z, &secondary, k, 0xABCDEF);
2120 let c = build_outer_score_subsample(&z, &secondary, k, 0xFEDCBA);
2121 assert_eq!(a.mask.as_ref(), b.mask.as_ref());
2122 assert_ne!(a.mask.as_ref(), c.mask.as_ref());
2123 }
2124
2125 #[test]
2126 fn weight_scale_correct() {
2127 let n: usize = 10_000;
2130 let z: Vec<f64> = (0..n).map(|i| i as f64).collect();
2131 let secondary: Vec<u8> = (0..n).map(|i| (i % 2) as u8).collect();
2132 let k = 2_000;
2133 let s = build_outer_score_subsample(&z, &secondary, k, 7);
2134 assert!(s.len() >= k);
2135 assert!(
2138 s.len() <= k + 200,
2139 "subsample {} much larger than expected",
2140 s.len()
2141 );
2142 let scale = s.weight_scale;
2143 assert!(
2145 (scale - 5.0).abs() < 0.5,
2146 "weight_scale {} not near 5.0",
2147 scale
2148 );
2149 }
2150}