Skip to main content

gam_models/transformation_normal/
family.rs

1use super::*;
2
3pub(crate) fn beta_bits_match(cached: &Array1<f64>, candidate: &Array1<f64>) -> bool {
4    cached.len() == candidate.len()
5        && cached
6            .iter()
7            .zip(candidate.iter())
8            .all(|(&left, &right)| left.to_bits() == right.to_bits())
9}
10
11/// Optional warm-start for the transformation model: per-observation location and
12/// scale values from a prior mean/SD normalizer.
13#[derive(Clone, Debug)]
14pub struct TransformationWarmStart {
15    /// μ(x_i): conditional mean of the response at each observation's covariates.
16    pub location: Array1<f64>,
17    /// τ(x_i): conditional standard deviation at each observation's covariates.
18    pub scale: Array1<f64>,
19}
20
21// ---------------------------------------------------------------------------
22// The family
23// ---------------------------------------------------------------------------
24
25/// Conditional transformation model mapping Y|x to N(0,1).
26///
27/// Single-block `CustomFamily`. The block design is `x_val` (tensor product of
28/// response value basis × covariate design). The family internally holds `x_deriv`
29/// (tensor product of response derivative basis × covariate design) for the
30/// Jacobian term in the likelihood.
31#[derive(Clone)]
32pub struct TransformationNormalFamily {
33    // --- Tensor product design matrices ---
34    /// Value design operator: keeps the tensor factors separate and materializes
35    /// only row chunks or explicitly requested dense diagnostics.
36    pub(crate) x_val_kron: KroneckerDesign,
37    /// Derivative design operator: keeps the tensor factors separate.
38    pub(crate) x_deriv_kron: KroneckerDesign,
39    // --- Response-direction basis (fixed, does not depend on κ) ---
40    /// Response value basis: n × p_resp. Columns: [1, I_1(y), ..., I_k(y)].
41    pub(crate) response_val_basis: Array2<f64>,
42    /// Response value basis at the finite lower support endpoint.
43    pub(crate) response_lower_basis: Array1<f64>,
44    /// Response value basis at the finite upper support endpoint.
45    pub(crate) response_upper_basis: Array1<f64>,
46    /// Response derivative basis: n × p_resp. Columns: [0, M_1(y), ..., M_k(y)].
47    pub(crate) response_deriv_basis: Array2<f64>,
48
49    // --- Covariate side (rebuilt on κ change) ---
50    /// Original covariate design used on the right side of the tensor product.
51    pub(crate) covariate_design: DesignMatrix,
52    /// Dense covariate block shared by row-quantity and endpoint evaluations.
53    ///
54    /// CTN row quantities are rebuilt at every accepted/probed β, but the
55    /// covariate design is fixed for the family. Caching this immutable
56    /// `n × p_cov` block avoids repeated chunk materialization and keeps
57    /// large-scale runs from churning large transient allocations.
58    pub(crate) covariate_dense_cache: Arc<Mutex<Option<Arc<Array2<f64>>>>>,
59    /// Optional non-negative row weights folded directly into the likelihood.
60    pub(crate) weights: Arc<Array1<f64>>,
61    /// Additive offset for the transformation linear predictor.
62    pub(crate) offset: Arc<Array1<f64>>,
63    // --- Tensor penalties ---
64    pub(crate) tensor_penalties: Vec<PenaltyMatrix>,
65
66    // --- Initial values ---
67    pub(crate) initial_beta: Array1<f64>,
68    pub(crate) initial_log_lambdas: Array1<f64>,
69
70    // --- Config ---
71    pub(crate) block_name: String,
72
73    // --- Response basis metadata (for reconstruction at predict time) ---
74    pub(crate) response_knots: Array1<f64>,
75    pub(crate) response_transform: Array2<f64>,
76    pub(crate) response_degree: usize,
77    pub(crate) response_median: f64,
78    pub(crate) response_floor_offset: Arc<Array1<f64>>,
79    pub(crate) response_lower_floor_offset: f64,
80    pub(crate) response_upper_floor_offset: f64,
81
82    /// Last row-space transformation quantities for an exact beta vector.
83    ///
84    /// CTN line searches and exact-Newton workspace construction frequently ask
85    /// for likelihood, gradient, and Hessian row factors at the same candidate
86    /// coefficients. This cache keeps the expensive Khatri-Rao forward products
87    /// and reciprocal powers behind a single exact-keyed entry instead of
88    /// recomputing `h`, `h'`, `1/h'`, and derivative powers per call.
89    pub(crate) row_quantity_cache: Arc<Mutex<Option<TransformationNormalRowQuantityCache>>>,
90    /// Optional outer-score Horvitz-Thompson per-row weights.
91    ///
92    /// When present, this is an `n`-vector equal to the original `weights`
93    /// pre-multiplied row-wise by the HT inverse-inclusion multiplier `m_i`
94    /// (`m_i = 1/π_i` on sampled rows, `0.0` on unsampled rows). Assembly
95    /// sites read row weights via [`Self::effective_weights`], which returns
96    /// this array when present and `self.weights` otherwise. Because every
97    /// per-row CTN contribution is linear in `w_i`, masking at this site
98    /// gives `E[Σ_i (m_i · w_i) · f(row_i)] = Σ_i w_i · f(row_i) = full-sum`
99    /// — i.e. an unbiased estimator across log-likelihood, gradient, joint
100    /// Hessian (dense / matvec / diagonal), ψ, and ψ-ψ kernels.
101    ///
102    /// `None` preserves byte-identical legacy behavior (`effective_weights`
103    /// returns the original `weights` array).
104    pub(crate) outer_subsample_weights: Option<Arc<Array1<f64>>>,
105}
106
107#[derive(Clone)]
108pub(crate) struct TransformationNormalRowQuantityCache {
109    pub(crate) beta: Arc<Array1<f64>>,
110    pub(crate) gamma: Arc<Array2<f64>>,
111    pub(crate) h: Arc<Array1<f64>>,
112    pub(crate) h_prime: Arc<Array1<f64>>,
113    pub(crate) h_lower: Arc<Array1<f64>>,
114    pub(crate) h_upper: Arc<Array1<f64>>,
115    pub(crate) endpoint_q: Arc<Vec<LogNormalCdfDiffDerivatives>>,
116    pub(crate) log_likelihood: f64,
117}
118
119#[derive(Debug)]
120pub(crate) struct TransformationNormalRowDerived {
121    pub(crate) log_likelihood: f64,
122    pub(crate) endpoint_q: Vec<LogNormalCdfDiffDerivatives>,
123}
124
125impl TransformationNormalRowQuantityCache {
126    pub(crate) fn matches_beta(&self, beta: &Array1<f64>) -> bool {
127        beta_bits_match(&self.beta, beta)
128    }
129}
130
131pub(crate) fn build_transformation_row_derived(
132    h: &Array1<f64>,
133    h_prime: &Array1<f64>,
134    h_lower: &Array1<f64>,
135    h_upper: &Array1<f64>,
136    weights: &Array1<f64>,
137) -> Result<TransformationNormalRowDerived, String> {
138    let n = h_prime.len();
139    assert_eq!(h.len(), n);
140    assert_eq!(h_lower.len(), n);
141    assert_eq!(h_upper.len(), n);
142    assert_eq!(weights.len(), n);
143
144    if let Some((i, value)) = h
145        .iter()
146        .copied()
147        .enumerate()
148        .find(|(_, value)| !value.is_finite())
149    {
150        return Err(TransformationNormalError::NonFinite {
151            reason: format!(
152                "TransformationNormalFamily row_quantities: h[{i}] = {value} is not finite"
153            ),
154        }
155        .into());
156    }
157    if let Some((i, value)) = weights
158        .iter()
159        .copied()
160        .enumerate()
161        .find(|(_, value)| !value.is_finite())
162    {
163        return Err(TransformationNormalError::NonFinite {
164            reason: format!(
165                "TransformationNormalFamily row_quantities: weight[{i}] = {value} is not finite"
166            ),
167        }
168        .into());
169    }
170
171    // Parallelize the per-row endpoint-normalizer build: each row runs
172    // `log_normal_cdf_diff_derivatives` (two `normal_logcdf` calls, three
173    // 5x5 truncated polynomial multiplies, 32 `signed_normal_pdf_ratio`
174    // calls) which dominates this function's runtime at large scale.
175    // Rows are fully independent — no shared state, no OnceLock guards —
176    // and `LogNormalCdfDiffDerivatives` is a POD struct that's `Send`.
177    // The fast finiteness check rolls all eight derived quantities into
178    // a single short-circuit `||` chain so the named-field error format
179    // only runs on the non-finite slow path.
180    use rayon::iter::{IntoParallelIterator, ParallelIterator};
181    let rows: Vec<(f64, LogNormalCdfDiffDerivatives)> = (0..n)
182        .into_par_iter()
183        .map(|i| -> Result<(f64, LogNormalCdfDiffDerivatives), String> {
184            let hp = h_prime[i];
185            let inv_h_prime = 1.0 / hp;
186            let inv_h_prime_sq = inv_h_prime * inv_h_prime;
187            let inv_h_prime_cu = inv_h_prime_sq * inv_h_prime;
188            let inv_h_prime_qu = inv_h_prime_sq * inv_h_prime_sq;
189            let w_i = weights[i];
190            let h_i = h[i];
191            let weighted_h = w_i * h_i;
192            let weighted_inv_h_prime = w_i * inv_h_prime;
193            let weighted_inv_h_prime_sq = w_i * inv_h_prime_sq;
194            let q = log_normal_cdf_diff_derivatives(h_upper[i], h_lower[i]).map_err(|e| {
195                format!("TransformationNormalFamily row_quantities: row {i} invalid endpoint normalizer: {e}")
196            })?;
197            let log_z = q.log_z;
198            let row_ll = w_i * (-0.5 * h_i * h_i + hp.ln() - log_z);
199            // Fast path: a single short-circuited finiteness check. Only
200            // when something is non-finite do we walk the named-field
201            // table to produce a precise diagnostic.
202            if !(inv_h_prime.is_finite()
203                && inv_h_prime_sq.is_finite()
204                && inv_h_prime_cu.is_finite()
205                && inv_h_prime_qu.is_finite()
206                && weighted_h.is_finite()
207                && weighted_inv_h_prime.is_finite()
208                && weighted_inv_h_prime_sq.is_finite()
209                && log_z.is_finite())
210            {
211                let derived_values = [
212                    ("1/h'", inv_h_prime),
213                    ("1/h'^2", inv_h_prime_sq),
214                    ("1/h'^3", inv_h_prime_cu),
215                    ("1/h'^4", inv_h_prime_qu),
216                    ("w*h", weighted_h),
217                    ("w/h'", weighted_inv_h_prime),
218                    ("w/h'^2", weighted_inv_h_prime_sq),
219                    ("log normalizer", log_z),
220                ];
221                for (name, value) in derived_values {
222                    if !value.is_finite() {
223                        return Err(TransformationNormalError::NonFinite { reason: format!(
224                            "TransformationNormalFamily row_quantities: {name} at row {i} is not finite ({value}); h'={hp} is outside the finite exact-derivative range",
225                        ) }.into());
226                    }
227                }
228                return Err(TransformationNormalError::NonFinite { reason: format!(
229                    "TransformationNormalFamily row_quantities: row {i} entered non-finite branch but no named field was non-finite; h'={hp}",
230                ) }.into());
231            }
232            Ok((row_ll, q))
233        })
234        .collect::<Result<Vec<_>, _>>()?;
235
236    // Sum row contributions in index order so the result is bit-identical
237    // to the previous serial accumulation. The parallel section above only
238    // parallelized the independent per-row computation; the final scalar
239    // reduction stays serial to preserve numerical reproducibility against
240    // existing tests.
241    let mut log_likelihood = 0.0;
242    let mut endpoint_q = Vec::with_capacity(n);
243    for (row_ll, q) in rows {
244        log_likelihood += row_ll;
245        endpoint_q.push(q);
246    }
247    if !log_likelihood.is_finite() {
248        return Err(TransformationNormalError::NonFinite { reason: format!(
249            "TransformationNormalFamily row_quantities: log-likelihood is not finite ({log_likelihood})"
250        ) }.into());
251    }
252
253    Ok(TransformationNormalRowDerived {
254        log_likelihood,
255        endpoint_q,
256    })
257}
258
259impl TransformationNormalFamily {
260    /// Build a transformation model from response values and a pre-built covariate
261    /// design operator with associated penalties.
262    ///
263    /// # Arguments
264    ///
265    /// * `response` - The response variable y (n observations).
266    /// * `covariate_design` - Pre-built covariate-side design operator (n × p_cov).
267    /// * `covariate_penalties` - Penalty matrices for the covariate basis.
268    /// * `config` - Response-direction basis configuration.
269    /// * `warm_start` - Optional location/scale from a prior normalizer.
270    pub fn new(
271        response: &Array1<f64>,
272        weights: &Array1<f64>,
273        offset: &Array1<f64>,
274        covariate_design: DesignMatrix,
275        covariate_penalties: Vec<PenaltyMatrix>,
276        config: &TransformationNormalConfig,
277        warm_start: Option<&TransformationWarmStart>,
278    ) -> Result<Self, String> {
279        let n = response.len();
280        if covariate_design.nrows() != n {
281            return Err(TransformationNormalError::InvalidInput {
282                reason: format!(
283                    "response length {} != covariate design rows {}",
284                    n,
285                    covariate_design.nrows()
286                ),
287            }
288            .into());
289        }
290        let p_cov = covariate_design.ncols();
291        if p_cov == 0 {
292            return Err(TransformationNormalError::DesignDegenerate {
293                reason: "covariate design has zero columns".to_string(),
294            }
295            .into());
296        }
297        if weights.len() != n {
298            return Err(TransformationNormalError::InvalidInput {
299                reason: format!("response length {} != weights length {}", n, weights.len()),
300            }
301            .into());
302        }
303        if offset.len() != n {
304            return Err(TransformationNormalError::InvalidInput {
305                reason: format!("response length {} != offset length {}", n, offset.len()),
306            }
307            .into());
308        }
309        for (i, &weight) in weights.iter().enumerate() {
310            if !weight.is_finite() {
311                return Err(TransformationNormalError::NonFinite {
312                    reason: format!("weights[{i}] is not finite: {weight}"),
313                }
314                .into());
315            }
316            if weight < 0.0 {
317                return Err(TransformationNormalError::InvalidInput {
318                    reason: format!("weights[{i}] must be non-negative: {weight}"),
319                }
320                .into());
321            }
322        }
323        for (i, &value) in offset.iter().enumerate() {
324            if !value.is_finite() {
325                return Err(TransformationNormalError::NonFinite {
326                    reason: format!("offset[{i}] is not finite: {value}"),
327                }
328                .into());
329            }
330        }
331        for (i, sp) in covariate_penalties.iter().enumerate() {
332            let (r, c) = sp.shape();
333            if r != p_cov || c != p_cov {
334                return Err(TransformationNormalError::InvalidInput {
335                    reason: format!(
336                        "covariate penalty {} has shape ({r}, {c}), expected ({p_cov}, {p_cov})",
337                        i,
338                    ),
339                }
340                .into());
341            }
342        }
343
344        // ----- 1. Build response-direction basis -----
345        let (resp_val, resp_deriv, resp_penalties, resp_knots, resp_transform) =
346            build_response_basis(response, config)?;
347        let p_resp = resp_val.ncols();
348        let (response_lower_basis, response_upper_basis) =
349            response_endpoint_value_bases(&resp_transform);
350
351        // ----- 2. Row-wise Kronecker product (operator form) -----
352        let x_val_kron = KroneckerDesign::new_khatri_rao(&resp_val, covariate_design.clone())?;
353        let x_deriv_kron = KroneckerDesign::new_khatri_rao(&resp_deriv, covariate_design.clone())?;
354        let p_total = p_resp * p_cov;
355        assert_eq!(x_val_kron.ncols(), p_total);
356        assert_eq!(x_deriv_kron.ncols(), p_total);
357
358        // ----- 3. Warm start -----
359        let initial_beta = compute_warm_start(
360            response,
361            weights,
362            offset,
363            &x_val_kron,
364            &x_deriv_kron,
365            &covariate_design,
366            &covariate_penalties,
367            p_resp,
368            p_cov,
369            warm_start,
370        )?;
371
372        // ----- 4. Tensor penalties (Kronecker-separable) -----
373        let tensor_penalties = build_tensor_penalties_kronecker(
374            &resp_penalties,
375            covariate_penalties,
376            p_resp,
377            p_cov,
378            config,
379        )?;
380        let policy = ResourcePolicy::default_library();
381        let x_val_weighted_gram = x_val_kron.weighted_gram(weights, &policy);
382
383        // ----- 5. CTN-specific smoothing seed from likelihood/penalty scales -----
384        let initial_log_lambdas =
385            ctn_penalty_scale_log_lambdas(&tensor_penalties, &x_val_weighted_gram);
386
387        // Compute response median for anchoring
388        let mut sorted_resp = response.to_vec();
389        sorted_resp.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
390        let resp_median = if sorted_resp.len() % 2 == 1 {
391            sorted_resp[sorted_resp.len() / 2]
392        } else {
393            0.5 * (sorted_resp[sorted_resp.len() / 2 - 1] + sorted_resp[sorted_resp.len() / 2])
394        };
395        let (response_floor_offset, response_lower_floor_offset, response_upper_floor_offset) =
396            response_floor_offsets(response, &resp_knots, resp_median);
397
398        Ok(Self {
399            x_val_kron,
400            x_deriv_kron,
401            response_val_basis: resp_val,
402            response_lower_basis,
403            response_upper_basis,
404            response_deriv_basis: resp_deriv,
405            covariate_design,
406            weights: Arc::new(weights.clone()),
407            offset: Arc::new(offset.clone()),
408            tensor_penalties,
409            initial_beta,
410            initial_log_lambdas,
411            block_name: "transformation".to_string(),
412            response_knots: resp_knots,
413            response_transform: resp_transform,
414            response_degree: config.response_degree,
415            response_median: resp_median,
416            response_floor_offset: Arc::new(response_floor_offset),
417            response_lower_floor_offset,
418            response_upper_floor_offset,
419            covariate_dense_cache: Arc::new(Mutex::new(None)),
420            row_quantity_cache: Arc::new(Mutex::new(None)),
421            outer_subsample_weights: None,
422        })
423    }
424
425    /// Build from a prebuilt response basis, skipping response basis construction.
426    ///
427    /// For the outer loop where the response basis is precomputed once and reused
428    /// across κ iterations.
429    pub fn from_prebuilt_response_basis(
430        response: &Array1<f64>,
431        response_val_basis: Array2<f64>,
432        response_deriv_basis: Array2<f64>,
433        response_penalties: Vec<Array2<f64>>,
434        response_knots: Array1<f64>,
435        response_degree: usize,
436        response_transform: Array2<f64>,
437        weights: &Array1<f64>,
438        offset: &Array1<f64>,
439        covariate_design: DesignMatrix,
440        covariate_penalties: Vec<PenaltyMatrix>,
441        config: &TransformationNormalConfig,
442        warm_start: Option<&TransformationWarmStart>,
443    ) -> Result<Self, String> {
444        let n = response_val_basis.nrows();
445        if n == 0 {
446            return Err(TransformationNormalError::InvalidInput {
447                reason: "response basis has zero rows".to_string(),
448            }
449            .into());
450        }
451        if response.len() != n {
452            return Err(TransformationNormalError::InvalidInput {
453                reason: format!(
454                    "response length {} != response basis rows {}",
455                    response.len(),
456                    n
457                ),
458            }
459            .into());
460        }
461        if covariate_design.nrows() != n {
462            return Err(TransformationNormalError::InvalidInput {
463                reason: format!(
464                    "response basis rows {} != covariate design rows {}",
465                    n,
466                    covariate_design.nrows()
467                ),
468            }
469            .into());
470        }
471        let p_cov = covariate_design.ncols();
472        if p_cov == 0 {
473            return Err(TransformationNormalError::DesignDegenerate {
474                reason: "covariate design has zero columns".to_string(),
475            }
476            .into());
477        }
478        if weights.len() != n {
479            return Err(TransformationNormalError::InvalidInput {
480                reason: format!(
481                    "response basis rows {} != weights length {}",
482                    n,
483                    weights.len()
484                ),
485            }
486            .into());
487        }
488        if offset.len() != n {
489            return Err(TransformationNormalError::InvalidInput {
490                reason: format!(
491                    "response basis rows {} != offset length {}",
492                    n,
493                    offset.len()
494                ),
495            }
496            .into());
497        }
498        for (i, &weight) in weights.iter().enumerate() {
499            if !weight.is_finite() {
500                return Err(TransformationNormalError::NonFinite {
501                    reason: format!("weights[{i}] is not finite: {weight}"),
502                }
503                .into());
504            }
505            if weight < 0.0 {
506                return Err(TransformationNormalError::InvalidInput {
507                    reason: format!("weights[{i}] must be non-negative: {weight}"),
508                }
509                .into());
510            }
511        }
512        for (i, &value) in offset.iter().enumerate() {
513            if !value.is_finite() {
514                return Err(TransformationNormalError::NonFinite {
515                    reason: format!("offset[{i}] is not finite: {value}"),
516                }
517                .into());
518            }
519        }
520        for (i, sp) in covariate_penalties.iter().enumerate() {
521            let (r, c) = sp.shape();
522            if r != p_cov || c != p_cov {
523                return Err(TransformationNormalError::InvalidInput {
524                    reason: format!(
525                        "covariate penalty {} has shape ({r}, {c}), expected ({p_cov}, {p_cov})",
526                        i,
527                    ),
528                }
529                .into());
530            }
531        }
532
533        let p_resp = response_val_basis.ncols();
534        if response_transform.ncols() + 1 != p_resp {
535            return Err(TransformationNormalError::InvalidInput { reason: format!(
536                "response transform columns {} imply p_resp {}, but response value basis has {} columns",
537                response_transform.ncols(),
538                response_transform.ncols() + 1,
539                p_resp
540            ) }.into());
541        }
542        let (response_lower_basis, response_upper_basis) =
543            response_endpoint_value_bases(&response_transform);
544
545        // Row-wise Kronecker product (operator form).
546        let x_val_kron =
547            KroneckerDesign::new_khatri_rao(&response_val_basis, covariate_design.clone())?;
548        let x_deriv_kron =
549            KroneckerDesign::new_khatri_rao(&response_deriv_basis, covariate_design.clone())?;
550        let p_total = p_resp * p_cov;
551        assert_eq!(x_val_kron.ncols(), p_total);
552        assert_eq!(x_deriv_kron.ncols(), p_total);
553
554        let initial_beta = compute_warm_start(
555            response,
556            weights,
557            offset,
558            &x_val_kron,
559            &x_deriv_kron,
560            &covariate_design,
561            &covariate_penalties,
562            p_resp,
563            p_cov,
564            warm_start,
565        )?;
566
567        // Tensor penalties (Kronecker-separable).
568        let tensor_penalties = build_tensor_penalties_kronecker(
569            &response_penalties,
570            covariate_penalties,
571            p_resp,
572            p_cov,
573            config,
574        )?;
575        let policy = ResourcePolicy::default_library();
576        let x_val_weighted_gram = x_val_kron.weighted_gram(weights, &policy);
577
578        let initial_log_lambdas =
579            ctn_penalty_scale_log_lambdas(&tensor_penalties, &x_val_weighted_gram);
580
581        // Compute response median.
582        let mut sorted_resp = response.to_vec();
583        sorted_resp.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
584        let resp_median = if sorted_resp.len() % 2 == 1 {
585            sorted_resp[sorted_resp.len() / 2]
586        } else {
587            0.5 * (sorted_resp[sorted_resp.len() / 2 - 1] + sorted_resp[sorted_resp.len() / 2])
588        };
589        let (response_floor_offset, response_lower_floor_offset, response_upper_floor_offset) =
590            response_floor_offsets(response, &response_knots, resp_median);
591
592        Ok(Self {
593            x_val_kron,
594            x_deriv_kron,
595            response_val_basis,
596            response_lower_basis,
597            response_upper_basis,
598            response_deriv_basis,
599            covariate_design,
600            weights: Arc::new(weights.clone()),
601            offset: Arc::new(offset.clone()),
602            tensor_penalties,
603            initial_beta,
604            initial_log_lambdas,
605            block_name: "transformation".to_string(),
606            response_knots: response_knots.clone(),
607            response_transform: response_transform.clone(),
608            response_degree,
609            response_median: resp_median,
610            response_floor_offset: Arc::new(response_floor_offset),
611            response_lower_floor_offset,
612            response_upper_floor_offset,
613            covariate_dense_cache: Arc::new(Mutex::new(None)),
614            row_quantity_cache: Arc::new(Mutex::new(None)),
615            outer_subsample_weights: None,
616        })
617    }
618
619    /// Response basis metadata for serialization/prediction.
620    pub fn response_knots(&self) -> &Array1<f64> {
621        &self.response_knots
622    }
623    pub fn response_transform(&self) -> &Array2<f64> {
624        &self.response_transform
625    }
626    pub fn response_degree(&self) -> usize {
627        self.response_degree
628    }
629    pub fn response_median(&self) -> f64 {
630        self.response_median
631    }
632
633    /// Return the `ParameterBlockSpec` for this family (single block).
634    pub fn block_spec(&self) -> ParameterBlockSpec {
635        let offset = self.offset.as_ref() + self.response_floor_offset.as_ref();
636        ParameterBlockSpec {
637            name: self.block_name.clone(),
638            design: DesignMatrix::Dense(DenseDesignMatrix::from(Arc::new(self.x_val_kron.clone()))),
639            offset,
640            penalties: self.tensor_penalties.clone(),
641            nullspace_dims: vec![],
642            initial_log_lambdas: self.initial_log_lambdas.clone(),
643            initial_beta: Some(self.initial_beta.clone()),
644            gauge_priority: 100,
645            jacobian_callback: None,
646            stacked_design: None,
647            stacked_offset: None,
648        }
649    }
650
651    /// Total number of coefficients.
652    pub fn p_total(&self) -> usize {
653        self.x_val_kron.ncols()
654    }
655
656    /// Number of observations.
657    pub fn n_obs(&self) -> usize {
658        self.x_val_kron.nrows()
659    }
660
661    /// Number of response-direction basis columns `p_resp` (`[1, I_1, …, I_K]`).
662    pub(crate) fn p_resp(&self) -> usize {
663        self.response_val_basis.ncols()
664    }
665
666    /// Number of covariate-side design columns `p_cov`.
667    pub(crate) fn p_cov(&self) -> usize {
668        self.covariate_design.ncols()
669    }
670
671    /// Response value basis evaluated at the finite lower support endpoint
672    /// (row-independent; `[1, I_1(y_min), …, I_K(y_min)]`).
673    pub(crate) fn response_lower_basis(&self) -> &Array1<f64> {
674        &self.response_lower_basis
675    }
676
677    /// Response value basis evaluated at the finite upper support endpoint
678    /// (row-independent; `[1, I_1(y_max), …, I_K(y_max)]`).
679    pub(crate) fn response_upper_basis(&self) -> &Array1<f64> {
680        &self.response_upper_basis
681    }
682
683    /// Monotonicity floor offset applied to the lower-endpoint score
684    /// `ε·(y_min − median_y)`.
685    pub(crate) fn response_lower_floor_offset(&self) -> f64 {
686        self.response_lower_floor_offset
687    }
688
689    /// Monotonicity floor offset applied to the upper-endpoint score
690    /// `ε·(y_max − median_y)`.
691    pub(crate) fn response_upper_floor_offset(&self) -> f64 {
692        self.response_upper_floor_offset
693    }
694
695    /// Per-row weight array used by every row-streaming SCOP assembly site.
696    ///
697    /// Returns the masked HT weights when an outer-score subsample is active
698    /// (`outer_subsample_weights = Some(_)`), else the original `weights`.
699    ///
700    /// Math invariant: every CTN per-row contribution to the gradient,
701    /// negative-Hessian, ψ-term, ψ-ψ-term, and log-likelihood is **linear**
702    /// in this scalar — i.e. each `for i in 0..n` step is of the form
703    /// `wᵢ · g(row_quantities_i, β)` with `wᵢ` appearing to the first power
704    /// only. Replacing `wᵢ` with `wᵢ · m_i` (where `m_i = 1/πᵢ` on sampled
705    /// rows and `0` on unsampled) yields an unbiased Horvitz-Thompson
706    /// estimator: `E[Σᵢ mᵢ wᵢ g(row_i)] = Σᵢ wᵢ g(row_i) = full sum`.
707    #[inline]
708    pub(crate) fn effective_weights(&self) -> &Array1<f64> {
709        match self.outer_subsample_weights.as_ref() {
710            Some(w) => w.as_ref(),
711            None => self.weights.as_ref(),
712        }
713    }
714
715    /// Evaluate the response value basis `[1, I_1(y), …, I_K(y)]` (n × p_resp)
716    /// at arbitrary response values using the *fitted* clamped knots and degree.
717    ///
718    /// This is the out-of-sample analogue of the in-sample `response_val_basis`:
719    /// it reuses the exact I-spline kernel and stored knot vector so that the
720    /// score-influence Jacobian (and any other predict-time geometry) evaluates
721    /// `I_k(y)` consistently with how `h` was built during the fit. Knots are
722    /// taken from the family (not re-derived from `response`), so the basis is
723    /// identical to training whenever the response values coincide.
724    pub(crate) fn evaluate_response_value_basis(
725        &self,
726        response: ArrayView1<'_, f64>,
727    ) -> Result<Array2<f64>, String> {
728        let n = response.len();
729        for (i, &v) in response.iter().enumerate() {
730            if !v.is_finite() {
731                return Err(TransformationNormalError::NonFinite {
732                    reason: format!(
733                        "evaluate_response_value_basis: response[{i}] is not finite: {v}"
734                    ),
735                }
736                .into());
737            }
738        }
739        let (i_val_basis, _) = create_basis::<Dense>(
740            response,
741            KnotSource::Provided(self.response_knots.view()),
742            self.response_degree,
743            BasisOptions::i_spline(),
744        )
745        .map_err(|e| format!("evaluate_response_value_basis: I-spline build failed: {e}"))?;
746        let shape_val = i_val_basis.as_ref();
747        let p_shape = shape_val.ncols();
748        let p_resp = self.response_val_basis.ncols();
749        if p_shape + 1 != p_resp {
750            return Err(TransformationNormalError::InvalidInput {
751                reason: format!(
752                    "evaluate_response_value_basis: rebuilt shape columns {p_shape} imply p_resp {}, \
753                     but fitted basis has {p_resp} columns",
754                    p_shape + 1
755                ),
756            }
757            .into());
758        }
759        let mut resp_val = Array2::<f64>::zeros((n, p_resp));
760        resp_val.column_mut(0).fill(1.0);
761        resp_val.slice_mut(s![.., 1..]).assign(shape_val);
762        Ok(resp_val)
763    }
764
765    /// Clone the family with an outer-score Horvitz-Thompson mask installed.
766    ///
767    /// The mask `m` (length `n`) is `1/πᵢ` for sampled rows and `0.0` for
768    /// unsampled. The returned family carries `outer_subsample_weights =
769    /// Some(weights ⊙ m)`. The row-quantity cache and persistent dense
770    /// Hessian cache are reset (they were keyed on β alone; the masked
771    /// family's `log_likelihood` and Hessian differ from the full-data
772    /// build at the same β so they must not alias). The subsample hash is
773    /// computed over `m` so that two distinct masks at the same β never
774    /// share a cache entry.
775    pub(crate) fn with_outer_subsample(
776        &self,
777        mask: &Array1<f64>,
778    ) -> Result<Self, TransformationNormalError> {
779        let n = self.weights.len();
780        if mask.len() != n {
781            bail_invalid_tnorm!(
782                "outer-score subsample mask length {} != n={}",
783                mask.len(),
784                n
785            );
786        }
787        let mut effective = Array1::<f64>::zeros(n);
788        for i in 0..n {
789            let m = mask[i];
790            if !m.is_finite() || m < 0.0 {
791                bail_invalid_tnorm!(
792                    "outer-score subsample mask[{i}] = {m} is invalid (must be finite and >= 0)"
793                );
794            }
795            effective[i] = self.weights[i] * m;
796        }
797        Ok(Self {
798            // Inherit immutable design / response state cheaply via Arc / clone.
799            x_val_kron: self.x_val_kron.clone(),
800            x_deriv_kron: self.x_deriv_kron.clone(),
801            response_val_basis: self.response_val_basis.clone(),
802            response_lower_basis: self.response_lower_basis.clone(),
803            response_upper_basis: self.response_upper_basis.clone(),
804            response_deriv_basis: self.response_deriv_basis.clone(),
805            covariate_design: self.covariate_design.clone(),
806            covariate_dense_cache: Arc::clone(&self.covariate_dense_cache),
807            weights: Arc::clone(&self.weights),
808            offset: Arc::clone(&self.offset),
809            tensor_penalties: self.tensor_penalties.clone(),
810            initial_beta: self.initial_beta.clone(),
811            initial_log_lambdas: self.initial_log_lambdas.clone(),
812            block_name: self.block_name.clone(),
813            response_knots: self.response_knots.clone(),
814            response_transform: self.response_transform.clone(),
815            response_degree: self.response_degree,
816            response_median: self.response_median,
817            response_floor_offset: Arc::clone(&self.response_floor_offset),
818            response_lower_floor_offset: self.response_lower_floor_offset,
819            response_upper_floor_offset: self.response_upper_floor_offset,
820            // Caches must NOT be shared between full-data and subsampled
821            // families: the row-quantity cache stores the LL (mask-dependent),
822            // and the persistent dense Hessian is keyed on β alone.
823            row_quantity_cache: Arc::new(Mutex::new(None)),
824            outer_subsample_weights: Some(Arc::new(effective)),
825        })
826    }
827
828    /// Build an outer-subsample clone from a `BlockwiseFitOptions` row mask,
829    /// returning `None` when no subsample is requested.
830    pub(crate) fn maybe_with_outer_subsample_from_options(
831        &self,
832        options: &BlockwiseFitOptions,
833    ) -> Result<Option<Self>, TransformationNormalError> {
834        let Some(sub) = options.outer_score_subsample.as_ref() else {
835            return Ok(None);
836        };
837        let n = self.weights.len();
838        let mut mask = Array1::<f64>::zeros(n);
839        for row in sub.rows.iter() {
840            if row.index < n {
841                mask[row.index] = row.weight;
842            }
843        }
844        Ok(Some(self.with_outer_subsample(&mask)?))
845    }
846
847    // --- Internal helpers ---
848
849    pub(crate) fn covariate_dense_arc(&self) -> Result<Arc<Array2<f64>>, String> {
850        let mut cache = self
851            .covariate_dense_cache
852            .lock()
853            .expect("CTN covariate dense cache mutex poisoned");
854        if let Some(cached) = cache.as_ref() {
855            return Ok(cached.clone());
856        }
857        let dense = Arc::new(
858            self.covariate_design
859                .try_row_chunk(0..self.response_val_basis.nrows())
860                .map_err(|e| format!("SCOP covariate dense materialization failed: {e}"))?,
861        );
862        *cache = Some(dense.clone());
863        Ok(dense)
864    }
865
866    pub(crate) fn row_quantities(
867        &self,
868        beta: &Array1<f64>,
869    ) -> Result<TransformationNormalRowQuantityCache, String> {
870        {
871            let cache = self
872                .row_quantity_cache
873                .lock()
874                .expect("CTN row quantity cache mutex poisoned");
875            if let Some(cached) = cache.as_ref().filter(|cached| cached.matches_beta(beta)) {
876                return Ok(cached.clone());
877            }
878        }
879
880        let p_resp = self.response_val_basis.ncols();
881        let p_cov = self.covariate_design.ncols();
882        let beta_mat = beta
883            .view()
884            .into_shape_with_order((p_resp, p_cov))
885            .map_err(|e| format!("SCOP endpoint beta reshape failed: {e}"))?;
886        let cov = self.covariate_dense_arc()?;
887
888        // SCOP-CTN: h(y, x) = b(x) + Σ_k γ_k(x)² · I_k(y), with
889        // γ_k(x) = ψ(x)ᵀ Γ_{k,:} and h'(y, x) = Σ_k γ_k(x)² · M_k(y).
890        // Response column 0 is the unconstrained affine/location component
891        // b(x); all remaining response columns are squared shape components.
892        //
893        // The observed value, derivative value, and finite-support endpoints
894        // all depend on the same covariate-side γ_k(x_i).  Compute γ once and
895        // fan it out exactly; the previous path projected β through the same
896        // covariate design three times per row-quantity build.
897        let gamma = fast_abt(cov.as_ref(), &beta_mat);
898        let n = gamma.nrows();
899        let mut h = Array1::<f64>::zeros(n);
900        let mut h_prime = Array1::<f64>::zeros(n);
901        let mut h_lower = Array1::<f64>::zeros(n);
902        let mut h_upper = Array1::<f64>::zeros(n);
903        // Write directly into the four preallocated arrays in parallel; the
904        // previous path collected a `Vec<(f64,f64,f64,f64)>` then serially
905        // scattered into these arrays, costing 32 bytes per row of transient
906        // allocation and a single-threaded post-pass at large scale.
907        ndarray::Zip::indexed(&mut h)
908            .and(&mut h_prime)
909            .and(&mut h_lower)
910            .and(&mut h_upper)
911            .par_for_each(|i, h_i, hp_i, lower_i, upper_i| {
912                let gamma_row = gamma.row(i);
913                let val_row = self.response_val_basis.row(i);
914                let deriv_row = self.response_deriv_basis.row(i);
915                let g0 = gamma_row[0];
916                let offset_i = self.offset[i];
917                let mut h_acc = val_row[0] * g0 + offset_i + self.response_floor_offset[i];
918                let mut hp_acc = deriv_row[0] * g0 + TRANSFORMATION_MONOTONICITY_EPS;
919                let mut lower_acc =
920                    self.response_lower_basis[0] * g0 + offset_i + self.response_lower_floor_offset;
921                let mut upper_acc =
922                    self.response_upper_basis[0] * g0 + offset_i + self.response_upper_floor_offset;
923                for k in 1..p_resp {
924                    let g_sq = gamma_row[k] * gamma_row[k];
925                    h_acc += val_row[k] * g_sq;
926                    hp_acc += deriv_row[k] * g_sq;
927                    lower_acc += self.response_lower_basis[k] * g_sq;
928                    upper_acc += self.response_upper_basis[k] * g_sq;
929                }
930                *h_i = h_acc;
931                *hp_i = hp_acc;
932                *lower_i = lower_acc;
933                *upper_i = upper_acc;
934            });
935        for (i, &value) in h.iter().enumerate() {
936            if !value.is_finite() {
937                return Err(TransformationNormalError::NonFinite {
938                    reason: format!(
939                        "TransformationNormalFamily row_quantities: h[{i}] = {value} is not finite"
940                    ),
941                }
942                .into());
943            }
944            if value.abs() > TRANSFORMATION_NORMAL_H_ABS_MAX {
945                return Err(TransformationNormalError::InvalidInput { reason: format!(
946                    "TransformationNormalFamily row_quantities: h[{i}] = {value:.6e} exceeds the standard-normal domain bound ±{TRANSFORMATION_NORMAL_H_ABS_MAX}"
947                ) }.into());
948            }
949        }
950        // Hard monotonicity / finiteness gate: the reciprocal powers `1/h'^k`
951        // for k ∈ {1,2,3,4} feed the gradient, Hessian, and psi-psi outer
952        // Hessian formulas. A non-finite or non-positive h' produces +∞ /
953        // signed-∞ reciprocals which then collide with zero-valued probe
954        // vectors (`v_*_deriv * weights`) to yield NaN entries throughout the
955        // dense psi-psi block (`hessian_psi_psi`). The likelihood gate in
956        // `evaluate` already rejects such β; surface the same error here so
957        // outer-Hessian probe callsites that call `row_quantities` directly
958        // (psi/psi second-order terms, etc.) produce a clean Err for the
959        // outer evaluator to retreat on, rather than a NaN dense block that
960        // routes a flagrant non-finite Hessian back into the planner.
961        let mut min_hp = f64::INFINITY;
962        let mut nonfinite_idx: Option<usize> = None;
963        for (i, &hp) in h_prime.iter().enumerate() {
964            if !hp.is_finite() {
965                nonfinite_idx = Some(i);
966                break;
967            }
968            if hp < min_hp {
969                min_hp = hp;
970            }
971        }
972        if let Some(i) = nonfinite_idx {
973            return Err(TransformationNormalError::NonFinite {
974                reason: format!(
975                    "TransformationNormalFamily row_quantities: h'[{i}] = {} is not finite",
976                    h_prime[i]
977                ),
978            }
979            .into());
980        }
981        if min_hp <= 0.0 {
982            return Err(TransformationNormalError::MonotonicityViolated { reason: format!(
983                "TransformationNormalFamily row_quantities: h' has non-positive values (min = {min_hp:.6e}). \
984                 Monotonicity constraint may be violated."
985            ) }.into());
986        }
987        // Compute exact f64 row derivatives. If any required reciprocal power
988        // is outside the finite representable range, surface an evaluation
989        // error so the outer solver can retreat; do not clamp or approximate
990        // the analytic Hessian terms.
991        let derived = build_transformation_row_derived(
992            &h,
993            &h_prime,
994            &h_lower,
995            &h_upper,
996            self.effective_weights(),
997        )?;
998        let row_quantities = TransformationNormalRowQuantityCache {
999            beta: Arc::new(beta.clone()),
1000            gamma: Arc::new(gamma),
1001            h: Arc::new(h),
1002            h_prime: Arc::new(h_prime),
1003            h_lower: Arc::new(h_lower),
1004            h_upper: Arc::new(h_upper),
1005            endpoint_q: Arc::new(derived.endpoint_q),
1006            log_likelihood: derived.log_likelihood,
1007        };
1008
1009        let mut cache = self
1010            .row_quantity_cache
1011            .lock()
1012            .expect("CTN row quantity cache mutex poisoned");
1013        *cache = Some(row_quantities.clone());
1014        Ok(row_quantities)
1015    }
1016}