1use super::*;
2
3pub(crate) fn beta_bits_match(cached: &Array1<f64>, candidate: &Array1<f64>) -> bool {
4 cached.len() == candidate.len()
5 && cached
6 .iter()
7 .zip(candidate.iter())
8 .all(|(&left, &right)| left.to_bits() == right.to_bits())
9}
10
/// Optional starting values accepted by `TransformationNormalFamily::new` and
/// `TransformationNormalFamily::from_prebuilt_response_basis`, which forward them to
/// `compute_warm_start` when building the initial coefficient vector.
#[derive(Clone, Debug)]
pub struct TransformationWarmStart {
    /// Location values for the warm start.
    /// NOTE(review): expected length/meaning is defined by `compute_warm_start` — confirm there.
    pub location: Array1<f64>,
    /// Scale values for the warm start.
    /// NOTE(review): expected length/meaning is defined by `compute_warm_start` — confirm there.
    pub scale: Array1<f64>,
}
20
/// Transformation model whose per-row log-likelihood is
/// `w * (-0.5 * h^2 + ln(h') - log_z)` (see `build_transformation_row_derived`), where
/// `h` is a monotone transformation of the response built from a response basis
/// crossed with a covariate design.
#[derive(Clone)]
pub struct TransformationNormalFamily {
    /// Row-wise Khatri-Rao product of the response value basis and the covariate
    /// design; it has `p_resp * p_cov` columns (asserted by the constructors).
    pub(crate) x_val_kron: KroneckerDesign,
    /// Row-wise Khatri-Rao product of the response derivative basis and the
    /// covariate design; same column count as `x_val_kron`.
    pub(crate) x_deriv_kron: KroneckerDesign,
    /// Response value basis: one row per observation, `p_resp` columns.
    pub(crate) response_val_basis: Array2<f64>,
    /// Value basis at the lower response endpoint, from `response_endpoint_value_bases`.
    pub(crate) response_lower_basis: Array1<f64>,
    /// Value basis at the upper response endpoint, from `response_endpoint_value_bases`.
    pub(crate) response_upper_basis: Array1<f64>,
    /// Response derivative basis: one row per observation, `p_resp` columns.
    pub(crate) response_deriv_basis: Array2<f64>,

    /// Covariate design with one row per observation and `p_cov` columns.
    pub(crate) covariate_design: DesignMatrix,
    /// Lazily materialized dense copy of `covariate_design`; filled on the first
    /// `covariate_dense_arc` call and shared with `with_outer_subsample` copies.
    pub(crate) covariate_dense_cache: Arc<Mutex<Option<Arc<Array2<f64>>>>>,
    /// Base per-row weights (the constructors require finite, non-negative values).
    pub(crate) weights: Arc<Array1<f64>>,
    /// Per-row offset added to `h` and to both endpoint values (required finite).
    pub(crate) offset: Arc<Array1<f64>>,
    /// Penalties on the tensor coefficients, from `build_tensor_penalties_kronecker`.
    pub(crate) tensor_penalties: Vec<PenaltyMatrix>,

    /// Starting coefficient vector produced by `compute_warm_start`.
    pub(crate) initial_beta: Array1<f64>,
    /// Starting log smoothing parameters from `ctn_penalty_scale_log_lambdas`.
    pub(crate) initial_log_lambdas: Array1<f64>,

    /// Parameter block name reported by `block_spec` (the constructors use "transformation").
    pub(crate) block_name: String,

    /// Knots of the response spline; reused by `evaluate_response_value_basis`.
    pub(crate) response_knots: Array1<f64>,
    /// Transform matrix the endpoint bases were derived from
    /// (`from_prebuilt_response_basis` requires `p_resp - 1` columns).
    pub(crate) response_transform: Array2<f64>,
    /// Spline degree used when rebuilding the response basis.
    pub(crate) response_degree: usize,
    /// Median of the training response at construction time.
    pub(crate) response_median: f64,
    /// Per-row floor offset from `response_floor_offsets`; added to `h` in
    /// `row_quantities` and to the offset reported by `block_spec`.
    pub(crate) response_floor_offset: Arc<Array1<f64>>,
    /// Floor offset added to every lower-endpoint value.
    pub(crate) response_lower_floor_offset: f64,
    /// Floor offset added to every upper-endpoint value.
    pub(crate) response_upper_floor_offset: f64,

    /// Most recent `row_quantities` result, reused while `beta` is bit-identical.
    pub(crate) row_quantity_cache: Arc<Mutex<Option<TransformationNormalRowQuantityCache>>>,
    /// When `Some`, replaces `weights` in `effective_weights` (set by `with_outer_subsample`).
    pub(crate) outer_subsample_weights: Option<Arc<Array1<f64>>>,
}
106
/// Per-row quantities computed by `TransformationNormalFamily::row_quantities` for one
/// coefficient vector. Every non-scalar field is behind an `Arc`, so cloning is cheap.
#[derive(Clone)]
pub(crate) struct TransformationNormalRowQuantityCache {
    /// Coefficient vector these quantities were computed for (matched bit-for-bit).
    pub(crate) beta: Arc<Array1<f64>>,
    /// Per-row response coefficients: `fast_abt` of the dense covariates with `beta`
    /// reshaped to `p_resp x p_cov`; one row per observation.
    pub(crate) gamma: Arc<Array2<f64>>,
    /// Transformed response `h`, one value per row.
    pub(crate) h: Arc<Array1<f64>>,
    /// Derivative `h'` per row; `row_quantities` requires it finite and positive.
    pub(crate) h_prime: Arc<Array1<f64>>,
    /// Transformation evaluated at the lower response endpoint, per row.
    pub(crate) h_lower: Arc<Array1<f64>>,
    /// Transformation evaluated at the upper response endpoint, per row.
    pub(crate) h_upper: Arc<Array1<f64>>,
    /// `log_normal_cdf_diff_derivatives(h_upper[i], h_lower[i])` for each row `i`.
    pub(crate) endpoint_q: Arc<Vec<LogNormalCdfDiffDerivatives>>,
    /// Weighted total log-likelihood at `beta`.
    pub(crate) log_likelihood: f64,
}
118
/// Output of `build_transformation_row_derived`: the weighted total log-likelihood and
/// the per-row endpoint normalizer derivatives, in row order.
#[derive(Debug)]
pub(crate) struct TransformationNormalRowDerived {
    /// Sum over rows of `w_i * (-0.5 * h_i^2 + ln(h'_i) - log_z_i)`.
    pub(crate) log_likelihood: f64,
    /// `log_normal_cdf_diff_derivatives(h_upper[i], h_lower[i])` for each row `i`.
    pub(crate) endpoint_q: Vec<LogNormalCdfDiffDerivatives>,
}
124
125impl TransformationNormalRowQuantityCache {
126 pub(crate) fn matches_beta(&self, beta: &Array1<f64>) -> bool {
127 beta_bits_match(&self.beta, beta)
128 }
129}
130
131pub(crate) fn build_transformation_row_derived(
132 h: &Array1<f64>,
133 h_prime: &Array1<f64>,
134 h_lower: &Array1<f64>,
135 h_upper: &Array1<f64>,
136 weights: &Array1<f64>,
137) -> Result<TransformationNormalRowDerived, String> {
138 let n = h_prime.len();
139 assert_eq!(h.len(), n);
140 assert_eq!(h_lower.len(), n);
141 assert_eq!(h_upper.len(), n);
142 assert_eq!(weights.len(), n);
143
144 if let Some((i, value)) = h
145 .iter()
146 .copied()
147 .enumerate()
148 .find(|(_, value)| !value.is_finite())
149 {
150 return Err(TransformationNormalError::NonFinite {
151 reason: format!(
152 "TransformationNormalFamily row_quantities: h[{i}] = {value} is not finite"
153 ),
154 }
155 .into());
156 }
157 if let Some((i, value)) = weights
158 .iter()
159 .copied()
160 .enumerate()
161 .find(|(_, value)| !value.is_finite())
162 {
163 return Err(TransformationNormalError::NonFinite {
164 reason: format!(
165 "TransformationNormalFamily row_quantities: weight[{i}] = {value} is not finite"
166 ),
167 }
168 .into());
169 }
170
171 use rayon::iter::{IntoParallelIterator, ParallelIterator};
181 let rows: Vec<(f64, LogNormalCdfDiffDerivatives)> = (0..n)
182 .into_par_iter()
183 .map(|i| -> Result<(f64, LogNormalCdfDiffDerivatives), String> {
184 let hp = h_prime[i];
185 let inv_h_prime = 1.0 / hp;
186 let inv_h_prime_sq = inv_h_prime * inv_h_prime;
187 let inv_h_prime_cu = inv_h_prime_sq * inv_h_prime;
188 let inv_h_prime_qu = inv_h_prime_sq * inv_h_prime_sq;
189 let w_i = weights[i];
190 let h_i = h[i];
191 let weighted_h = w_i * h_i;
192 let weighted_inv_h_prime = w_i * inv_h_prime;
193 let weighted_inv_h_prime_sq = w_i * inv_h_prime_sq;
194 let q = log_normal_cdf_diff_derivatives(h_upper[i], h_lower[i]).map_err(|e| {
195 format!("TransformationNormalFamily row_quantities: row {i} invalid endpoint normalizer: {e}")
196 })?;
197 let log_z = q.log_z;
198 let row_ll = w_i * (-0.5 * h_i * h_i + hp.ln() - log_z);
199 if !(inv_h_prime.is_finite()
203 && inv_h_prime_sq.is_finite()
204 && inv_h_prime_cu.is_finite()
205 && inv_h_prime_qu.is_finite()
206 && weighted_h.is_finite()
207 && weighted_inv_h_prime.is_finite()
208 && weighted_inv_h_prime_sq.is_finite()
209 && log_z.is_finite())
210 {
211 let derived_values = [
212 ("1/h'", inv_h_prime),
213 ("1/h'^2", inv_h_prime_sq),
214 ("1/h'^3", inv_h_prime_cu),
215 ("1/h'^4", inv_h_prime_qu),
216 ("w*h", weighted_h),
217 ("w/h'", weighted_inv_h_prime),
218 ("w/h'^2", weighted_inv_h_prime_sq),
219 ("log normalizer", log_z),
220 ];
221 for (name, value) in derived_values {
222 if !value.is_finite() {
223 return Err(TransformationNormalError::NonFinite { reason: format!(
224 "TransformationNormalFamily row_quantities: {name} at row {i} is not finite ({value}); h'={hp} is outside the finite exact-derivative range",
225 ) }.into());
226 }
227 }
228 return Err(TransformationNormalError::NonFinite { reason: format!(
229 "TransformationNormalFamily row_quantities: row {i} entered non-finite branch but no named field was non-finite; h'={hp}",
230 ) }.into());
231 }
232 Ok((row_ll, q))
233 })
234 .collect::<Result<Vec<_>, _>>()?;
235
236 let mut log_likelihood = 0.0;
242 let mut endpoint_q = Vec::with_capacity(n);
243 for (row_ll, q) in rows {
244 log_likelihood += row_ll;
245 endpoint_q.push(q);
246 }
247 if !log_likelihood.is_finite() {
248 return Err(TransformationNormalError::NonFinite { reason: format!(
249 "TransformationNormalFamily row_quantities: log-likelihood is not finite ({log_likelihood})"
250 ) }.into());
251 }
252
253 Ok(TransformationNormalRowDerived {
254 log_likelihood,
255 endpoint_q,
256 })
257}
258
259impl TransformationNormalFamily {
260 pub fn new(
271 response: &Array1<f64>,
272 weights: &Array1<f64>,
273 offset: &Array1<f64>,
274 covariate_design: DesignMatrix,
275 covariate_penalties: Vec<PenaltyMatrix>,
276 config: &TransformationNormalConfig,
277 warm_start: Option<&TransformationWarmStart>,
278 ) -> Result<Self, String> {
279 let n = response.len();
280 if covariate_design.nrows() != n {
281 return Err(TransformationNormalError::InvalidInput {
282 reason: format!(
283 "response length {} != covariate design rows {}",
284 n,
285 covariate_design.nrows()
286 ),
287 }
288 .into());
289 }
290 let p_cov = covariate_design.ncols();
291 if p_cov == 0 {
292 return Err(TransformationNormalError::DesignDegenerate {
293 reason: "covariate design has zero columns".to_string(),
294 }
295 .into());
296 }
297 if weights.len() != n {
298 return Err(TransformationNormalError::InvalidInput {
299 reason: format!("response length {} != weights length {}", n, weights.len()),
300 }
301 .into());
302 }
303 if offset.len() != n {
304 return Err(TransformationNormalError::InvalidInput {
305 reason: format!("response length {} != offset length {}", n, offset.len()),
306 }
307 .into());
308 }
309 for (i, &weight) in weights.iter().enumerate() {
310 if !weight.is_finite() {
311 return Err(TransformationNormalError::NonFinite {
312 reason: format!("weights[{i}] is not finite: {weight}"),
313 }
314 .into());
315 }
316 if weight < 0.0 {
317 return Err(TransformationNormalError::InvalidInput {
318 reason: format!("weights[{i}] must be non-negative: {weight}"),
319 }
320 .into());
321 }
322 }
323 for (i, &value) in offset.iter().enumerate() {
324 if !value.is_finite() {
325 return Err(TransformationNormalError::NonFinite {
326 reason: format!("offset[{i}] is not finite: {value}"),
327 }
328 .into());
329 }
330 }
331 for (i, sp) in covariate_penalties.iter().enumerate() {
332 let (r, c) = sp.shape();
333 if r != p_cov || c != p_cov {
334 return Err(TransformationNormalError::InvalidInput {
335 reason: format!(
336 "covariate penalty {} has shape ({r}, {c}), expected ({p_cov}, {p_cov})",
337 i,
338 ),
339 }
340 .into());
341 }
342 }
343
344 let (resp_val, resp_deriv, resp_penalties, resp_knots, resp_transform) =
346 build_response_basis(response, config)?;
347 let p_resp = resp_val.ncols();
348 let (response_lower_basis, response_upper_basis) =
349 response_endpoint_value_bases(&resp_transform);
350
351 let x_val_kron = KroneckerDesign::new_khatri_rao(&resp_val, covariate_design.clone())?;
353 let x_deriv_kron = KroneckerDesign::new_khatri_rao(&resp_deriv, covariate_design.clone())?;
354 let p_total = p_resp * p_cov;
355 assert_eq!(x_val_kron.ncols(), p_total);
356 assert_eq!(x_deriv_kron.ncols(), p_total);
357
358 let initial_beta = compute_warm_start(
360 response,
361 weights,
362 offset,
363 &x_val_kron,
364 &x_deriv_kron,
365 &covariate_design,
366 &covariate_penalties,
367 p_resp,
368 p_cov,
369 warm_start,
370 )?;
371
372 let tensor_penalties = build_tensor_penalties_kronecker(
374 &resp_penalties,
375 covariate_penalties,
376 p_resp,
377 p_cov,
378 config,
379 )?;
380 let policy = ResourcePolicy::default_library();
381 let x_val_weighted_gram = x_val_kron.weighted_gram(weights, &policy);
382
383 let initial_log_lambdas =
385 ctn_penalty_scale_log_lambdas(&tensor_penalties, &x_val_weighted_gram);
386
387 let mut sorted_resp = response.to_vec();
389 sorted_resp.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
390 let resp_median = if sorted_resp.len() % 2 == 1 {
391 sorted_resp[sorted_resp.len() / 2]
392 } else {
393 0.5 * (sorted_resp[sorted_resp.len() / 2 - 1] + sorted_resp[sorted_resp.len() / 2])
394 };
395 let (response_floor_offset, response_lower_floor_offset, response_upper_floor_offset) =
396 response_floor_offsets(response, &resp_knots, resp_median);
397
398 Ok(Self {
399 x_val_kron,
400 x_deriv_kron,
401 response_val_basis: resp_val,
402 response_lower_basis,
403 response_upper_basis,
404 response_deriv_basis: resp_deriv,
405 covariate_design,
406 weights: Arc::new(weights.clone()),
407 offset: Arc::new(offset.clone()),
408 tensor_penalties,
409 initial_beta,
410 initial_log_lambdas,
411 block_name: "transformation".to_string(),
412 response_knots: resp_knots,
413 response_transform: resp_transform,
414 response_degree: config.response_degree,
415 response_median: resp_median,
416 response_floor_offset: Arc::new(response_floor_offset),
417 response_lower_floor_offset,
418 response_upper_floor_offset,
419 covariate_dense_cache: Arc::new(Mutex::new(None)),
420 row_quantity_cache: Arc::new(Mutex::new(None)),
421 outer_subsample_weights: None,
422 })
423 }
424
425 pub fn from_prebuilt_response_basis(
430 response: &Array1<f64>,
431 response_val_basis: Array2<f64>,
432 response_deriv_basis: Array2<f64>,
433 response_penalties: Vec<Array2<f64>>,
434 response_knots: Array1<f64>,
435 response_degree: usize,
436 response_transform: Array2<f64>,
437 weights: &Array1<f64>,
438 offset: &Array1<f64>,
439 covariate_design: DesignMatrix,
440 covariate_penalties: Vec<PenaltyMatrix>,
441 config: &TransformationNormalConfig,
442 warm_start: Option<&TransformationWarmStart>,
443 ) -> Result<Self, String> {
444 let n = response_val_basis.nrows();
445 if n == 0 {
446 return Err(TransformationNormalError::InvalidInput {
447 reason: "response basis has zero rows".to_string(),
448 }
449 .into());
450 }
451 if response.len() != n {
452 return Err(TransformationNormalError::InvalidInput {
453 reason: format!(
454 "response length {} != response basis rows {}",
455 response.len(),
456 n
457 ),
458 }
459 .into());
460 }
461 if covariate_design.nrows() != n {
462 return Err(TransformationNormalError::InvalidInput {
463 reason: format!(
464 "response basis rows {} != covariate design rows {}",
465 n,
466 covariate_design.nrows()
467 ),
468 }
469 .into());
470 }
471 let p_cov = covariate_design.ncols();
472 if p_cov == 0 {
473 return Err(TransformationNormalError::DesignDegenerate {
474 reason: "covariate design has zero columns".to_string(),
475 }
476 .into());
477 }
478 if weights.len() != n {
479 return Err(TransformationNormalError::InvalidInput {
480 reason: format!(
481 "response basis rows {} != weights length {}",
482 n,
483 weights.len()
484 ),
485 }
486 .into());
487 }
488 if offset.len() != n {
489 return Err(TransformationNormalError::InvalidInput {
490 reason: format!(
491 "response basis rows {} != offset length {}",
492 n,
493 offset.len()
494 ),
495 }
496 .into());
497 }
498 for (i, &weight) in weights.iter().enumerate() {
499 if !weight.is_finite() {
500 return Err(TransformationNormalError::NonFinite {
501 reason: format!("weights[{i}] is not finite: {weight}"),
502 }
503 .into());
504 }
505 if weight < 0.0 {
506 return Err(TransformationNormalError::InvalidInput {
507 reason: format!("weights[{i}] must be non-negative: {weight}"),
508 }
509 .into());
510 }
511 }
512 for (i, &value) in offset.iter().enumerate() {
513 if !value.is_finite() {
514 return Err(TransformationNormalError::NonFinite {
515 reason: format!("offset[{i}] is not finite: {value}"),
516 }
517 .into());
518 }
519 }
520 for (i, sp) in covariate_penalties.iter().enumerate() {
521 let (r, c) = sp.shape();
522 if r != p_cov || c != p_cov {
523 return Err(TransformationNormalError::InvalidInput {
524 reason: format!(
525 "covariate penalty {} has shape ({r}, {c}), expected ({p_cov}, {p_cov})",
526 i,
527 ),
528 }
529 .into());
530 }
531 }
532
533 let p_resp = response_val_basis.ncols();
534 if response_transform.ncols() + 1 != p_resp {
535 return Err(TransformationNormalError::InvalidInput { reason: format!(
536 "response transform columns {} imply p_resp {}, but response value basis has {} columns",
537 response_transform.ncols(),
538 response_transform.ncols() + 1,
539 p_resp
540 ) }.into());
541 }
542 let (response_lower_basis, response_upper_basis) =
543 response_endpoint_value_bases(&response_transform);
544
545 let x_val_kron =
547 KroneckerDesign::new_khatri_rao(&response_val_basis, covariate_design.clone())?;
548 let x_deriv_kron =
549 KroneckerDesign::new_khatri_rao(&response_deriv_basis, covariate_design.clone())?;
550 let p_total = p_resp * p_cov;
551 assert_eq!(x_val_kron.ncols(), p_total);
552 assert_eq!(x_deriv_kron.ncols(), p_total);
553
554 let initial_beta = compute_warm_start(
555 response,
556 weights,
557 offset,
558 &x_val_kron,
559 &x_deriv_kron,
560 &covariate_design,
561 &covariate_penalties,
562 p_resp,
563 p_cov,
564 warm_start,
565 )?;
566
567 let tensor_penalties = build_tensor_penalties_kronecker(
569 &response_penalties,
570 covariate_penalties,
571 p_resp,
572 p_cov,
573 config,
574 )?;
575 let policy = ResourcePolicy::default_library();
576 let x_val_weighted_gram = x_val_kron.weighted_gram(weights, &policy);
577
578 let initial_log_lambdas =
579 ctn_penalty_scale_log_lambdas(&tensor_penalties, &x_val_weighted_gram);
580
581 let mut sorted_resp = response.to_vec();
583 sorted_resp.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
584 let resp_median = if sorted_resp.len() % 2 == 1 {
585 sorted_resp[sorted_resp.len() / 2]
586 } else {
587 0.5 * (sorted_resp[sorted_resp.len() / 2 - 1] + sorted_resp[sorted_resp.len() / 2])
588 };
589 let (response_floor_offset, response_lower_floor_offset, response_upper_floor_offset) =
590 response_floor_offsets(response, &response_knots, resp_median);
591
592 Ok(Self {
593 x_val_kron,
594 x_deriv_kron,
595 response_val_basis,
596 response_lower_basis,
597 response_upper_basis,
598 response_deriv_basis,
599 covariate_design,
600 weights: Arc::new(weights.clone()),
601 offset: Arc::new(offset.clone()),
602 tensor_penalties,
603 initial_beta,
604 initial_log_lambdas,
605 block_name: "transformation".to_string(),
606 response_knots: response_knots.clone(),
607 response_transform: response_transform.clone(),
608 response_degree,
609 response_median: resp_median,
610 response_floor_offset: Arc::new(response_floor_offset),
611 response_lower_floor_offset,
612 response_upper_floor_offset,
613 covariate_dense_cache: Arc::new(Mutex::new(None)),
614 row_quantity_cache: Arc::new(Mutex::new(None)),
615 outer_subsample_weights: None,
616 })
617 }
618
619 pub fn response_knots(&self) -> &Array1<f64> {
621 &self.response_knots
622 }
623 pub fn response_transform(&self) -> &Array2<f64> {
624 &self.response_transform
625 }
626 pub fn response_degree(&self) -> usize {
627 self.response_degree
628 }
629 pub fn response_median(&self) -> f64 {
630 self.response_median
631 }
632
633 pub fn block_spec(&self) -> ParameterBlockSpec {
635 let offset = self.offset.as_ref() + self.response_floor_offset.as_ref();
636 ParameterBlockSpec {
637 name: self.block_name.clone(),
638 design: DesignMatrix::Dense(DenseDesignMatrix::from(Arc::new(self.x_val_kron.clone()))),
639 offset,
640 penalties: self.tensor_penalties.clone(),
641 nullspace_dims: vec![],
642 initial_log_lambdas: self.initial_log_lambdas.clone(),
643 initial_beta: Some(self.initial_beta.clone()),
644 gauge_priority: 100,
645 jacobian_callback: None,
646 stacked_design: None,
647 stacked_offset: None,
648 }
649 }
650
651 pub fn p_total(&self) -> usize {
653 self.x_val_kron.ncols()
654 }
655
656 pub fn n_obs(&self) -> usize {
658 self.x_val_kron.nrows()
659 }
660
661 pub(crate) fn p_resp(&self) -> usize {
663 self.response_val_basis.ncols()
664 }
665
666 pub(crate) fn p_cov(&self) -> usize {
668 self.covariate_design.ncols()
669 }
670
671 pub(crate) fn response_lower_basis(&self) -> &Array1<f64> {
674 &self.response_lower_basis
675 }
676
677 pub(crate) fn response_upper_basis(&self) -> &Array1<f64> {
680 &self.response_upper_basis
681 }
682
683 pub(crate) fn response_lower_floor_offset(&self) -> f64 {
686 self.response_lower_floor_offset
687 }
688
689 pub(crate) fn response_upper_floor_offset(&self) -> f64 {
692 self.response_upper_floor_offset
693 }
694
695 #[inline]
708 pub(crate) fn effective_weights(&self) -> &Array1<f64> {
709 match self.outer_subsample_weights.as_ref() {
710 Some(w) => w.as_ref(),
711 None => self.weights.as_ref(),
712 }
713 }
714
715 pub(crate) fn evaluate_response_value_basis(
725 &self,
726 response: ArrayView1<'_, f64>,
727 ) -> Result<Array2<f64>, String> {
728 let n = response.len();
729 for (i, &v) in response.iter().enumerate() {
730 if !v.is_finite() {
731 return Err(TransformationNormalError::NonFinite {
732 reason: format!(
733 "evaluate_response_value_basis: response[{i}] is not finite: {v}"
734 ),
735 }
736 .into());
737 }
738 }
739 let (i_val_basis, _) = create_basis::<Dense>(
740 response,
741 KnotSource::Provided(self.response_knots.view()),
742 self.response_degree,
743 BasisOptions::i_spline(),
744 )
745 .map_err(|e| format!("evaluate_response_value_basis: I-spline build failed: {e}"))?;
746 let shape_val = i_val_basis.as_ref();
747 let p_shape = shape_val.ncols();
748 let p_resp = self.response_val_basis.ncols();
749 if p_shape + 1 != p_resp {
750 return Err(TransformationNormalError::InvalidInput {
751 reason: format!(
752 "evaluate_response_value_basis: rebuilt shape columns {p_shape} imply p_resp {}, \
753 but fitted basis has {p_resp} columns",
754 p_shape + 1
755 ),
756 }
757 .into());
758 }
759 let mut resp_val = Array2::<f64>::zeros((n, p_resp));
760 resp_val.column_mut(0).fill(1.0);
761 resp_val.slice_mut(s![.., 1..]).assign(shape_val);
762 Ok(resp_val)
763 }
764
765 pub(crate) fn with_outer_subsample(
776 &self,
777 mask: &Array1<f64>,
778 ) -> Result<Self, TransformationNormalError> {
779 let n = self.weights.len();
780 if mask.len() != n {
781 bail_invalid_tnorm!(
782 "outer-score subsample mask length {} != n={}",
783 mask.len(),
784 n
785 );
786 }
787 let mut effective = Array1::<f64>::zeros(n);
788 for i in 0..n {
789 let m = mask[i];
790 if !m.is_finite() || m < 0.0 {
791 bail_invalid_tnorm!(
792 "outer-score subsample mask[{i}] = {m} is invalid (must be finite and >= 0)"
793 );
794 }
795 effective[i] = self.weights[i] * m;
796 }
797 Ok(Self {
798 x_val_kron: self.x_val_kron.clone(),
800 x_deriv_kron: self.x_deriv_kron.clone(),
801 response_val_basis: self.response_val_basis.clone(),
802 response_lower_basis: self.response_lower_basis.clone(),
803 response_upper_basis: self.response_upper_basis.clone(),
804 response_deriv_basis: self.response_deriv_basis.clone(),
805 covariate_design: self.covariate_design.clone(),
806 covariate_dense_cache: Arc::clone(&self.covariate_dense_cache),
807 weights: Arc::clone(&self.weights),
808 offset: Arc::clone(&self.offset),
809 tensor_penalties: self.tensor_penalties.clone(),
810 initial_beta: self.initial_beta.clone(),
811 initial_log_lambdas: self.initial_log_lambdas.clone(),
812 block_name: self.block_name.clone(),
813 response_knots: self.response_knots.clone(),
814 response_transform: self.response_transform.clone(),
815 response_degree: self.response_degree,
816 response_median: self.response_median,
817 response_floor_offset: Arc::clone(&self.response_floor_offset),
818 response_lower_floor_offset: self.response_lower_floor_offset,
819 response_upper_floor_offset: self.response_upper_floor_offset,
820 row_quantity_cache: Arc::new(Mutex::new(None)),
824 outer_subsample_weights: Some(Arc::new(effective)),
825 })
826 }
827
828 pub(crate) fn maybe_with_outer_subsample_from_options(
831 &self,
832 options: &BlockwiseFitOptions,
833 ) -> Result<Option<Self>, TransformationNormalError> {
834 let Some(sub) = options.outer_score_subsample.as_ref() else {
835 return Ok(None);
836 };
837 let n = self.weights.len();
838 let mut mask = Array1::<f64>::zeros(n);
839 for row in sub.rows.iter() {
840 if row.index < n {
841 mask[row.index] = row.weight;
842 }
843 }
844 Ok(Some(self.with_outer_subsample(&mask)?))
845 }
846
847 pub(crate) fn covariate_dense_arc(&self) -> Result<Arc<Array2<f64>>, String> {
850 let mut cache = self
851 .covariate_dense_cache
852 .lock()
853 .expect("CTN covariate dense cache mutex poisoned");
854 if let Some(cached) = cache.as_ref() {
855 return Ok(cached.clone());
856 }
857 let dense = Arc::new(
858 self.covariate_design
859 .try_row_chunk(0..self.response_val_basis.nrows())
860 .map_err(|e| format!("SCOP covariate dense materialization failed: {e}"))?,
861 );
862 *cache = Some(dense.clone());
863 Ok(dense)
864 }
865
866 pub(crate) fn row_quantities(
867 &self,
868 beta: &Array1<f64>,
869 ) -> Result<TransformationNormalRowQuantityCache, String> {
870 {
871 let cache = self
872 .row_quantity_cache
873 .lock()
874 .expect("CTN row quantity cache mutex poisoned");
875 if let Some(cached) = cache.as_ref().filter(|cached| cached.matches_beta(beta)) {
876 return Ok(cached.clone());
877 }
878 }
879
880 let p_resp = self.response_val_basis.ncols();
881 let p_cov = self.covariate_design.ncols();
882 let beta_mat = beta
883 .view()
884 .into_shape_with_order((p_resp, p_cov))
885 .map_err(|e| format!("SCOP endpoint beta reshape failed: {e}"))?;
886 let cov = self.covariate_dense_arc()?;
887
888 let gamma = fast_abt(cov.as_ref(), &beta_mat);
898 let n = gamma.nrows();
899 let mut h = Array1::<f64>::zeros(n);
900 let mut h_prime = Array1::<f64>::zeros(n);
901 let mut h_lower = Array1::<f64>::zeros(n);
902 let mut h_upper = Array1::<f64>::zeros(n);
903 ndarray::Zip::indexed(&mut h)
908 .and(&mut h_prime)
909 .and(&mut h_lower)
910 .and(&mut h_upper)
911 .par_for_each(|i, h_i, hp_i, lower_i, upper_i| {
912 let gamma_row = gamma.row(i);
913 let val_row = self.response_val_basis.row(i);
914 let deriv_row = self.response_deriv_basis.row(i);
915 let g0 = gamma_row[0];
916 let offset_i = self.offset[i];
917 let mut h_acc = val_row[0] * g0 + offset_i + self.response_floor_offset[i];
918 let mut hp_acc = deriv_row[0] * g0 + TRANSFORMATION_MONOTONICITY_EPS;
919 let mut lower_acc =
920 self.response_lower_basis[0] * g0 + offset_i + self.response_lower_floor_offset;
921 let mut upper_acc =
922 self.response_upper_basis[0] * g0 + offset_i + self.response_upper_floor_offset;
923 for k in 1..p_resp {
924 let g_sq = gamma_row[k] * gamma_row[k];
925 h_acc += val_row[k] * g_sq;
926 hp_acc += deriv_row[k] * g_sq;
927 lower_acc += self.response_lower_basis[k] * g_sq;
928 upper_acc += self.response_upper_basis[k] * g_sq;
929 }
930 *h_i = h_acc;
931 *hp_i = hp_acc;
932 *lower_i = lower_acc;
933 *upper_i = upper_acc;
934 });
935 for (i, &value) in h.iter().enumerate() {
936 if !value.is_finite() {
937 return Err(TransformationNormalError::NonFinite {
938 reason: format!(
939 "TransformationNormalFamily row_quantities: h[{i}] = {value} is not finite"
940 ),
941 }
942 .into());
943 }
944 if value.abs() > TRANSFORMATION_NORMAL_H_ABS_MAX {
945 return Err(TransformationNormalError::InvalidInput { reason: format!(
946 "TransformationNormalFamily row_quantities: h[{i}] = {value:.6e} exceeds the standard-normal domain bound ±{TRANSFORMATION_NORMAL_H_ABS_MAX}"
947 ) }.into());
948 }
949 }
950 let mut min_hp = f64::INFINITY;
962 let mut nonfinite_idx: Option<usize> = None;
963 for (i, &hp) in h_prime.iter().enumerate() {
964 if !hp.is_finite() {
965 nonfinite_idx = Some(i);
966 break;
967 }
968 if hp < min_hp {
969 min_hp = hp;
970 }
971 }
972 if let Some(i) = nonfinite_idx {
973 return Err(TransformationNormalError::NonFinite {
974 reason: format!(
975 "TransformationNormalFamily row_quantities: h'[{i}] = {} is not finite",
976 h_prime[i]
977 ),
978 }
979 .into());
980 }
981 if min_hp <= 0.0 {
982 return Err(TransformationNormalError::MonotonicityViolated { reason: format!(
983 "TransformationNormalFamily row_quantities: h' has non-positive values (min = {min_hp:.6e}). \
984 Monotonicity constraint may be violated."
985 ) }.into());
986 }
987 let derived = build_transformation_row_derived(
992 &h,
993 &h_prime,
994 &h_lower,
995 &h_upper,
996 self.effective_weights(),
997 )?;
998 let row_quantities = TransformationNormalRowQuantityCache {
999 beta: Arc::new(beta.clone()),
1000 gamma: Arc::new(gamma),
1001 h: Arc::new(h),
1002 h_prime: Arc::new(h_prime),
1003 h_lower: Arc::new(h_lower),
1004 h_upper: Arc::new(h_upper),
1005 endpoint_q: Arc::new(derived.endpoint_q),
1006 log_likelihood: derived.log_likelihood,
1007 };
1008
1009 let mut cache = self
1010 .row_quantity_cache
1011 .lock()
1012 .expect("CTN row quantity cache mutex poisoned");
1013 *cache = Some(row_quantities.clone());
1014 Ok(row_quantities)
1015 }
1016}