Skip to main content

gam_models/survival/marginal_slope/
block_jacobians.rs

1//! Per-block effective Jacobians: the family scalars and the rigid / flex /
2//! time-wiggle `BlockEffectiveJacobian` implementations (time, marginal,
3//! logslope, score-warp, link-dev) plus the primary->joint row chain.
4
5use super::*;
6
/// Per-row scalars used to evaluate the survival marginal-slope Jacobians at a
/// fixed β.
///
/// Every `*_i` vector is indexed by data row (length n):
/// - `q0_i`: entry-time probit argument
/// - `q1_i`: exit-time probit argument
/// - `qd1_i`: derivative probit argument
/// - `g_i`: log-slope value `g = logslope_design · β_logslope`
/// - `c_i`: hyperbolic scale `sqrt(1 + (s·g_i)²)`, derived from `g_i` and `s`
/// - `z_i`: covariate score
///
/// `s` is the scalar probit scale (`probit_frailty_scale()`).
pub struct SurvivalMarginalSlopeFamilyScalars {
    pub q0_i: Vec<f64>,
    pub q1_i: Vec<f64>,
    pub qd1_i: Vec<f64>,
    pub g_i: Vec<f64>,
    pub c_i: Vec<f64>,
    pub s: f64,
    pub z_i: Vec<f64>,
}

impl SurvivalMarginalSlopeFamilyScalars {
    /// Builds the scalar bundle. `c_i[k] = sqrt(1 + (s·g_i[k])²)` is derived
    /// from the supplied `g_i` and `s`; every other input is stored unchanged.
    pub fn new(
        q0_i: Vec<f64>,
        q1_i: Vec<f64>,
        qd1_i: Vec<f64>,
        g_i: Vec<f64>,
        s: f64,
        z_i: Vec<f64>,
    ) -> Self {
        let mut c_i = Vec::with_capacity(g_i.len());
        for &g in &g_i {
            let scaled = s * g;
            c_i.push((1.0 + scaled.powi(2)).sqrt());
        }
        Self {
            q0_i,
            q1_i,
            qd1_i,
            g_i,
            c_i,
            s,
            z_i,
        }
    }
}
52
53/// n_outputs=3 stacked Jacobian for the logslope block.
54///
55/// The logslope block contributes `g_i = logslope_design[i] · β` to each row.
56/// The three stacked output rows for row i are:
57///
58/// ```text
59/// ∂η0[i]/∂β = (q0[i] · s²·g[i]/c[i] + s·z[i]) · G[i,:]
60/// ∂η1[i]/∂β = (q1[i] · s²·g[i]/c[i] + s·z[i]) · G[i,:]
61/// ∂ad1[i]/∂β = qd1[i] · s²·g[i]/c[i] · G[i,:]
62/// ```
63///
64/// At g=0 (β=0 init): c=1, s²·g/c=0, so:
65/// ```text
66/// ∂η0[i]/∂β = s·z[i] · G[i,:]
67/// ∂η1[i]/∂β = s·z[i] · G[i,:]
68/// ∂ad1[i]/∂β = 0
69/// ```
70pub struct LogslopeBlockJacobian {
71    /// The logslope basis design (n × p_logslope). Held behind an `Arc` so a
72    /// materialized design is shared with its owner rather than deep-copied —
73    /// at biobank scale each retained `n × p` copy in these construction-time
74    /// callbacks was hundreds of MiB held for the whole fit (#979 OOM).
75    pub(crate) design: Arc<Array2<f64>>,
76    /// Per-row covariate score z_i (length n).
77    pub(crate) z: Vec<f64>,
78    /// Probit scale s.
79    pub(crate) s: f64,
80}
81
82impl LogslopeBlockJacobian {
83    pub fn new(design: impl Into<Arc<Array2<f64>>>, z: Vec<f64>, s: f64) -> Self {
84        Self {
85            design: design.into(),
86            z,
87            s,
88        }
89    }
90}
91
92impl crate::custom_family::BlockEffectiveJacobian for LogslopeBlockJacobian {
93    fn effective_jacobian_rows(
94        &self,
95        state: &crate::custom_family::FamilyLinearizationState<'_>,
96        rows: std::ops::Range<usize>,
97    ) -> Result<Array2<f64>, String> {
98        let n = self.design.nrows();
99        let p = self.design.ncols();
100        let rows = rows.start.min(n)..rows.end.min(n);
101        let chunk = rows.end - rows.start;
102        // Read s_f from the linearization state so that outer-loop σ updates are
103        // reflected without requiring the spec to be rebuilt.  Every construction
104        // site sets probit_frailty_scale = 1.0 when it does not know the family's
105        // σ; `self.s` carries the construction-time value as a fallback.  Use the
106        // state value when positive and finite; fall back to self.s otherwise.
107        // For the no-frailty case both are 1.0 so the choice is immaterial.
108        let s = if state.probit_frailty_scale > 0.0 && state.probit_frailty_scale.is_finite() {
109            state.probit_frailty_scale
110        } else {
111            self.s
112        };
113
114        // Compute per-row g_i = logslope_design[i,:] · β directly from state.beta.
115        // This block owns the logslope design so g is always self-computable without
116        // family_scalars.  Truncate to min(p, beta.len()) to handle the pre-fit
117        // initialisation call where beta may be shorter or empty.
118        let beta = state.beta;
119        let p_use = p.min(beta.len());
120        let mut g_rows = vec![0.0_f64; chunk];
121        for i in rows.clone() {
122            let local_i = i - rows.start;
123            for j in 0..p_use {
124                g_rows[local_i] += self.design[[i, j]] * beta[j];
125            }
126        }
127
128        // Hard contract: when any g_i is nonzero the per-row primary scalars
129        // (q0, q1, qd1) from the time/marginal blocks are required for the correct
130        // hyperbolic formula (q·s²g/c + s·z).  Those scalars live in family_scalars.
131        // A caller operating at non-init β must populate them.
132        let scalars: Option<&SurvivalMarginalSlopeFamilyScalars> = state
133            .family_scalars
134            .as_ref()
135            .and_then(|a| a.downcast_ref::<SurvivalMarginalSlopeFamilyScalars>());
136
137        let any_nonzero_g = g_rows.iter().any(|&gi| gi != 0.0);
138        if any_nonzero_g && scalars.is_none() {
139            return Err("survival marginal-slope logslope block requires \
140                 SurvivalMarginalSlopeFamilyScalars when beta != 0 \
141                 (g_i != 0 for at least one row); got family_scalars: None. \
142                 The caller must compute per-row (q0, q1, qd1) at the current \
143                 beta and pass them via FamilyLinearizationState::family_scalars."
144                .to_string());
145        }
146
147        let mut jac = Array2::<f64>::zeros((3 * chunk, p));
148
149        for i in rows.clone() {
150            let local_i = i - rows.start;
151            // g_i computed from beta above; c_i from family_scalars when present,
152            // otherwise computed from g_i.  q0/q1/qd1 from family_scalars -
153            // guaranteed present by the contract check whenever g_i != 0.
154            let g = g_rows[local_i];
155            let (q0, q1, qd1, c) = match scalars {
156                Some(sc) => (sc.q0_i[i], sc.q1_i[i], sc.qd1_i[i], sc.c_i[i]),
157                None => {
158                    // g == 0.0 here (enforced by contract above), so c = 1.
159                    // The q terms vanish: q * s^2 * 0 / 1 = 0.
160                    (0.0_f64, 0.0_f64, 0.0_f64, 1.0_f64)
161                }
162            };
163            let z_i = self.z[i];
164            let sg_over_c = if g == 0.0 { 0.0 } else { s * s * g / c };
165            let coeff_eta0 = q0 * sg_over_c + s * z_i;
166            let coeff_eta1 = q1 * sg_over_c + s * z_i;
167            let coeff_ad1 = qd1 * sg_over_c;
168
169            for j in 0..p {
170                let g_ij = self.design[[i, j]];
171                jac[[local_i, j]] = coeff_eta0 * g_ij;
172                jac[[chunk + local_i, j]] = coeff_eta1 * g_ij;
173                jac[[2 * chunk + local_i, j]] = coeff_ad1 * g_ij;
174            }
175        }
176        Ok(jac)
177    }
178
179    fn n_outputs(&self) -> usize {
180        3
181    }
182}
183
184/// n_outputs=3 stacked Jacobian for the marginal block.
185///
186/// The marginal block contributes identically to q0 and q1 (both entry and
187/// exit probit arguments) but not to ad1 (the derivative). The stacked Jacobian is:
188///
189/// ```text
190/// ∂η0[i]/∂β = c[i] · M[i,:]
191/// ∂η1[i]/∂β = c[i] · M[i,:]
192/// ∂ad1[i]/∂β = 0
193/// ```
194///
195/// At g=0 (β=0 init): c=1, so each row is just M[i,:].
196pub struct MarginalBlockJacobian {
197    /// The marginal basis design (n × p_marginal), `Arc`-shared with its
198    /// owner (see [`LogslopeBlockJacobian::design`]).
199    pub(crate) design: Arc<Array2<f64>>,
200}
201
202impl MarginalBlockJacobian {
203    pub fn new(design: impl Into<Arc<Array2<f64>>>) -> Self {
204        Self {
205            design: design.into(),
206        }
207    }
208}
209
210impl crate::custom_family::BlockEffectiveJacobian for MarginalBlockJacobian {
211    fn effective_jacobian_rows(
212        &self,
213        state: &crate::custom_family::FamilyLinearizationState<'_>,
214        rows: std::ops::Range<usize>,
215    ) -> Result<Array2<f64>, String> {
216        let n = self.design.nrows();
217        let p = self.design.ncols();
218        let rows = rows.start.min(n)..rows.end.min(n);
219        let chunk = rows.end - rows.start;
220
221        // c_i = sqrt(1 + (s * g_i)^2) depends on the logslope block's g at the
222        // current beta.  This block does not own the logslope design so it cannot
223        // compute c from beta alone.  Hard contract: when state.beta is non-empty
224        // (post-init), family_scalars must carry SurvivalMarginalSlopeFamilyScalars
225        // so the correct c_i is used.  At init (beta empty or all-zero), c_i = 1
226        // exactly and family_scalars may be omitted.
227        let scalars: Option<&SurvivalMarginalSlopeFamilyScalars> = state
228            .family_scalars
229            .as_ref()
230            .and_then(|a| a.downcast_ref::<SurvivalMarginalSlopeFamilyScalars>());
231
232        let beta_nonzero = state.beta.iter().any(|&b| b != 0.0);
233        if beta_nonzero && scalars.is_none() {
234            return Err("survival marginal-slope marginal block requires \
235                 SurvivalMarginalSlopeFamilyScalars when beta != 0 (c_i != 1 in general); \
236                 got family_scalars: None. The caller must populate per-row c_i via \
237                 FamilyLinearizationState::family_scalars."
238                .to_string());
239        }
240
241        let mut jac = Array2::<f64>::zeros((3 * chunk, p));
242
243        for i in rows.clone() {
244            let local_i = i - rows.start;
245            let c = match scalars {
246                Some(sc) => sc.c_i[i],
247                // beta is all-zero here (enforced above), so g = 0 and c = 1.
248                None => 1.0_f64,
249            };
250            for j in 0..p {
251                let m_ij = c * self.design[[i, j]];
252                jac[[local_i, j]] = m_ij;
253                jac[[chunk + local_i, j]] = m_ij;
254                // jac[[2*n + i, j]] = 0 -- ad1 row stays zero
255            }
256        }
257        Ok(jac)
258    }
259
260    fn n_outputs(&self) -> usize {
261        3
262    }
263}
264
265/// n_outputs=3 stacked Jacobian for the time block.
266///
267/// The time block contributes separately to η0 (entry), η1 (exit), and ad1
268/// (derivative) via three distinct design matrices. The stacked Jacobian is:
269///
270/// ```text
271/// ∂η0[i]/∂β = c[i] · T_entry[i,:]
272/// ∂η1[i]/∂β = c[i] · T_exit[i,:]
273/// ∂ad1[i]/∂β = c[i] · T_deriv[i,:]
274/// ```
275///
276/// At g=0 (β=0 init): c=1.
277pub struct TimeBlockJacobian {
278    // `Arc`-shared with their owners (see [`LogslopeBlockJacobian::design`]).
279    pub(crate) design_entry: Arc<Array2<f64>>,
280    pub(crate) design_exit: Arc<Array2<f64>>,
281    pub(crate) design_deriv: Arc<Array2<f64>>,
282}
283
284impl TimeBlockJacobian {
285    pub fn new(
286        design_entry: impl Into<Arc<Array2<f64>>>,
287        design_exit: impl Into<Arc<Array2<f64>>>,
288        design_deriv: impl Into<Arc<Array2<f64>>>,
289    ) -> Self {
290        Self {
291            design_entry: design_entry.into(),
292            design_exit: design_exit.into(),
293            design_deriv: design_deriv.into(),
294        }
295    }
296}
297
298impl crate::custom_family::BlockEffectiveJacobian for TimeBlockJacobian {
299    fn effective_jacobian_rows(
300        &self,
301        state: &crate::custom_family::FamilyLinearizationState<'_>,
302        rows: std::ops::Range<usize>,
303    ) -> Result<Array2<f64>, String> {
304        let n = self.design_entry.nrows();
305        let p = self.design_entry.ncols();
306        let rows = rows.start.min(n)..rows.end.min(n);
307        let chunk = rows.end - rows.start;
308
309        if self.design_exit.nrows() != n || self.design_deriv.nrows() != n {
310            return Err(format!(
311                "TimeBlockJacobian: design row count mismatch \
312                 entry={n} exit={} deriv={}",
313                self.design_exit.nrows(),
314                self.design_deriv.nrows(),
315            ));
316        }
317        if self.design_exit.ncols() != p || self.design_deriv.ncols() != p {
318            return Err(format!(
319                "TimeBlockJacobian: design col count mismatch \
320                 entry={p} exit={} deriv={}",
321                self.design_exit.ncols(),
322                self.design_deriv.ncols(),
323            ));
324        }
325
326        // c_i = sqrt(1 + (s * g_i)^2) depends on the logslope block's g.  This block
327        // does not own the logslope design.  Hard contract: when beta is non-empty/nonzero,
328        // family_scalars must carry SurvivalMarginalSlopeFamilyScalars with the correct c_i.
329        // At init (beta empty or all-zero), c_i = 1 exactly.
330        let scalars: Option<&SurvivalMarginalSlopeFamilyScalars> = state
331            .family_scalars
332            .as_ref()
333            .and_then(|a| a.downcast_ref::<SurvivalMarginalSlopeFamilyScalars>());
334
335        let beta_nonzero = state.beta.iter().any(|&b| b != 0.0);
336        if beta_nonzero && scalars.is_none() {
337            return Err("survival marginal-slope time block requires \
338                 SurvivalMarginalSlopeFamilyScalars when beta != 0 (c_i != 1 in general); \
339                 got family_scalars: None. The caller must populate per-row c_i via \
340                 FamilyLinearizationState::family_scalars."
341                .to_string());
342        }
343
344        let mut jac = Array2::<f64>::zeros((3 * chunk, p));
345
346        for i in rows.clone() {
347            let local_i = i - rows.start;
348            let c = match scalars {
349                Some(sc) => sc.c_i[i],
350                // beta is all-zero here (enforced above), so g = 0 and c = 1.
351                None => 1.0_f64,
352            };
353            for j in 0..p {
354                jac[[local_i, j]] = c * self.design_entry[[i, j]];
355                jac[[chunk + local_i, j]] = c * self.design_exit[[i, j]];
356                jac[[2 * chunk + local_i, j]] = c * self.design_deriv[[i, j]];
357            }
358        }
359        Ok(jac)
360    }
361
362    fn n_outputs(&self) -> usize {
363        3
364    }
365}
366
367// ── Timewiggle-active Jacobians ───────────────────────────────────────
368//
369// When timewiggle is active, (q0, q1, qd1) are nonlinear functions of
370// (β_time, β_marginal) through the composition:
371//
372//   h0 = X_entry_base[i] · β_t_base + offset_entry[i] + M[i] · β_m
373//   q0 = h0 + B(h0) · β_tw           (B = monotone wiggle basis)
374//
375// and analogously for q1 and qd1.  The chain rule gives:
376//
377//   ∂q0/∂β_t[j < p_base] = (1 + B'(h0)·β_tw) · X_entry[i,j]
378//                         = dq_dq0(h0) · X_entry[i,j]
379//   ∂q0/∂β_t[p_base + k] = B_k(h0)
380//   ∂q0/∂β_m[j]          = dq_dq0(h0) · M[i,j]
381//
382// Since η_r = c · q_r + … and ∂η_r/∂β_block = c · ∂q_r/∂β_block,
383// the stacked Jacobian for each block is:
384//
385//   J[i,       j] = c_i · ∂q0/∂β_block[j]
386//   J[n + i,   j] = c_i · ∂q1/∂β_block[j]
387//   J[2*n + i, j] = c_i · ∂qd1/∂β_block[j]
388//
389// where c_i = sqrt(1 + (s · g_i)²) and g_i = G[i] · β_g.
390//
391// At β = 0: dq_dq0 = 1, d²q/dh² = 0, c_i = 1, so both timewiggle
392// callbacks reduce to the rigid-path `TimeBlockJacobian` /
393// `MarginalBlockJacobian` values.
394//
395// Joint β layout (same for both callbacks):
396//   [β_t (p_time) | β_m (p_m) | β_g (p_g) | …]
397//
398// p_time = p_base + p_tw where p_tw = time_wiggle_ncols.
399
400/// n_outputs = 3 stacked Jacobian for the **time** block when timewiggle
401/// is active.  Computes `c_i` from the embedded logslope design and
402/// joint β, so no `family_scalars` are required.
403pub struct SmsTimewiggleTimeJacobian {
404    pub(crate) design_entry: Arc<Array2<f64>>,
405    pub(crate) design_exit: Arc<Array2<f64>>,
406    pub(crate) design_deriv: Arc<Array2<f64>>,
407    pub(crate) design_marginal: Arc<Array2<f64>>,
408    pub(crate) design_logslope: Arc<Array2<f64>>,
409    pub(crate) offset_entry: Arc<Array1<f64>>,
410    pub(crate) offset_exit: Arc<Array1<f64>>,
411    pub(crate) offset_deriv: Arc<Array1<f64>>,
412    /// Fixed marginal-predictor offset. The full marginal predictor entering
413    /// the entry/exit channels is `design_marginal·β_m + marginal_offset`
414    /// (see `row_dynamic_q_values`); this is the β-independent part.
415    pub(crate) marginal_offset: Arc<Array1<f64>>,
416    pub(crate) time_wiggle_knots: Array1<f64>,
417    pub(crate) time_wiggle_degree: usize,
418    /// Full time block width (= design_entry.ncols()).
419    pub(crate) p_time: usize,
420    /// Wiggle tail width.
421    pub(crate) p_tw: usize,
422    /// Marginal block width (for joint β parsing).
423    pub(crate) p_m: usize,
424    /// Logslope block width (for joint β parsing).
425    pub(crate) p_g: usize,
426    /// Probit frailty scale s.
427    pub(crate) probit_scale: f64,
428}
429
430impl SmsTimewiggleTimeJacobian {
431    /// Construct.
432    pub fn new(
433        design_entry: Arc<Array2<f64>>,
434        design_exit: Arc<Array2<f64>>,
435        design_deriv: Arc<Array2<f64>>,
436        design_marginal: Arc<Array2<f64>>,
437        design_logslope: Arc<Array2<f64>>,
438        offset_entry: Arc<Array1<f64>>,
439        offset_exit: Arc<Array1<f64>>,
440        offset_deriv: Arc<Array1<f64>>,
441        marginal_offset: Arc<Array1<f64>>,
442        time_wiggle_knots: Array1<f64>,
443        time_wiggle_degree: usize,
444        p_tw: usize,
445        p_m: usize,
446        p_g: usize,
447        probit_scale: f64,
448    ) -> Self {
449        let p_time = design_entry.ncols();
450        Self {
451            design_entry,
452            design_exit,
453            design_deriv,
454            design_marginal,
455            design_logslope,
456            offset_entry,
457            offset_exit,
458            offset_deriv,
459            marginal_offset,
460            time_wiggle_knots,
461            time_wiggle_degree,
462            p_time,
463            p_tw,
464            p_m,
465            p_g,
466            probit_scale,
467        }
468    }
469}
470
471impl crate::custom_family::BlockEffectiveJacobian for SmsTimewiggleTimeJacobian {
472    fn effective_jacobian_rows(
473        &self,
474        state: &crate::custom_family::FamilyLinearizationState<'_>,
475        rows: std::ops::Range<usize>,
476    ) -> Result<Array2<f64>, String> {
477        let n = self.design_entry.nrows();
478        let p = self.p_time;
479        let rows = rows.start.min(n)..rows.end.min(n);
480        let chunk = rows.end - rows.start;
481        let p_base = p.saturating_sub(self.p_tw);
482
483        let beta = state.beta;
484        // β_t = joint β[0 .. p_time]
485        let beta_t = if beta.len() >= p { &beta[..p] } else { beta };
486        let beta_t_base = &beta_t[..p_base.min(beta_t.len())];
487        // β_tw must always be a length-`p_tw` vector. The timewiggle block
488        // exists whenever `self.p_tw > 0`, independent of how many coefficients
489        // the caller supplied: the identifiability canonicaliser calls this at
490        // the β=0 linearisation point with `beta = &[]` (see
491        // `BlockJacobianAsRowOp::from_callback`), so inferring "no wiggle block"
492        // from an empty slice — the old behaviour — wrongly drove `beta_tw`
493        // empty, made `sms_tw_first_order_geom` return `None`, and zeroed the
494        // wiggle tail columns. That made the time block look structurally
495        // aliased ("block 0 fully aliased") even though ∂q/∂β_tw[j] = B_j(h) ≠ 0
496        // at β=0. Zero-pad to `self.p_tw` so the basis is always evaluated.
497        let zero_tw: Vec<f64>;
498        let beta_tw: &[f64] = if beta_t.len() >= p_base + self.p_tw {
499            &beta_t[p_base..p_base + self.p_tw]
500        } else {
501            zero_tw = vec![0.0; self.p_tw];
502            &zero_tw
503        };
504        // β_m = joint β[p_time .. p_time + p_m]
505        let beta_m = {
506            let s = p;
507            let e = (s + self.p_m).min(beta.len());
508            if e > s { &beta[s..e] } else { &[][..] }
509        };
510        // β_g = joint β[p_time + p_m .. p_time + p_m + p_g]
511        let beta_g = {
512            let s = p + self.p_m;
513            let e = (s + self.p_g).min(beta.len());
514            if e > s { &beta[s..e] } else { &[][..] }
515        };
516
517        let sc = self.probit_scale;
518        let knots = &self.time_wiggle_knots;
519        let degree = self.time_wiggle_degree;
520
521        let mut jac = Array2::<f64>::zeros((3 * chunk, p));
522
523        for i in rows.clone() {
524            let local_i = i - rows.start;
525            // c_i computed directly from logslope design and joint β_g.
526            let g_i: f64 = beta_g
527                .iter()
528                .enumerate()
529                .filter(|&(j, _)| j < self.design_logslope.ncols())
530                .map(|(j, &b)| self.design_logslope[[i, j]] * b)
531                .sum();
532            let c_i = (1.0_f64 + (sc * g_i).powi(2)).sqrt();
533
534            // Base marginal η contribution.
535            let eta_m: f64 = beta_m
536                .iter()
537                .enumerate()
538                .filter(|&(j, _)| j < self.design_marginal.ncols())
539                .map(|(j, &b)| self.design_marginal[[i, j]] * b)
540                .sum();
541
542            // The marginal predictor (coefficient part `eta_m` plus the fixed
543            // `marginal_offset`) enters BOTH entry and exit channels but NOT
544            // the derivative channel — see `row_dynamic_q_values`.
545            let h0: f64 = self.offset_entry[i]
546                + eta_m
547                + self.marginal_offset[i]
548                + (0..p_base.min(beta_t_base.len()).min(self.design_entry.ncols()))
549                    .map(|j| self.design_entry[[i, j]] * beta_t_base[j])
550                    .sum::<f64>();
551            let h1: f64 = self.offset_exit[i]
552                + eta_m
553                + self.marginal_offset[i]
554                + (0..p_base.min(beta_t_base.len()).min(self.design_exit.ncols()))
555                    .map(|j| self.design_exit[[i, j]] * beta_t_base[j])
556                    .sum::<f64>();
557            let d_raw: f64 = self.offset_deriv[i]
558                + (0..p_base.min(beta_t_base.len()).min(self.design_deriv.ncols()))
559                    .map(|j| self.design_deriv[[i, j]] * beta_t_base[j])
560                    .sum::<f64>();
561
562            let beta_tw_view = ndarray::ArrayView1::from(beta_tw);
563            let eg = sms_tw_first_order_geom(
564                ndarray::ArrayView1::from(&[h0][..]),
565                beta_tw_view,
566                knots,
567                degree,
568            )?;
569            let xg = sms_tw_first_order_geom(
570                ndarray::ArrayView1::from(&[h1][..]),
571                beta_tw_view,
572                knots,
573                degree,
574            )?;
575
576            let (entry_dq, exit_dq, exit_d2q, entry_basis, exit_basis, exit_basis_d1) =
577                match (eg, xg) {
578                    (Some(eg), Some(xg)) => (
579                        eg.dq_dq0[0],
580                        xg.dq_dq0[0],
581                        xg.d2q_dq02[0],
582                        Some(eg.basis),
583                        Some(xg.basis),
584                        Some(xg.basis_d1),
585                    ),
586                    _ => (1.0_f64, 1.0_f64, 0.0_f64, None, None, None),
587                };
588
589            // Base columns j < p_base.
590            for j in 0..p_base.min(self.design_entry.ncols()) {
591                let xe = self.design_entry[[i, j]];
592                let xx = self.design_exit[[i, j]];
593                let xd = self.design_deriv[[i, j]];
594                jac[[local_i, j]] = c_i * entry_dq * xe;
595                jac[[chunk + local_i, j]] = c_i * exit_dq * xx;
596                jac[[2 * chunk + local_i, j]] = c_i * (exit_d2q * d_raw * xx + exit_dq * xd);
597            }
598
599            // Wiggle tail columns.
600            for local_idx in 0..self.p_tw {
601                let col = p_base + local_idx;
602                let b0 = entry_basis.as_ref().map_or(0.0, |b| b[[0, local_idx]]);
603                let b1 = exit_basis.as_ref().map_or(0.0, |b| b[[0, local_idx]]);
604                let bd1 = exit_basis_d1.as_ref().map_or(0.0, |b| b[[0, local_idx]]);
605                jac[[local_i, col]] = c_i * b0;
606                jac[[chunk + local_i, col]] = c_i * b1;
607                jac[[2 * chunk + local_i, col]] = c_i * bd1 * d_raw;
608            }
609        }
610        Ok(jac)
611    }
612
613    fn n_outputs(&self) -> usize {
614        3
615    }
616}
617
618/// n_outputs = 3 stacked Jacobian for the **marginal** block when timewiggle
619/// is active.
620pub struct SmsTimewiggleMarginalJacobian {
621    pub(crate) design_entry: Arc<Array2<f64>>,
622    pub(crate) design_exit: Arc<Array2<f64>>,
623    pub(crate) design_deriv: Arc<Array2<f64>>,
624    pub(crate) design_marginal: Arc<Array2<f64>>,
625    pub(crate) design_logslope: Arc<Array2<f64>>,
626    pub(crate) offset_entry: Arc<Array1<f64>>,
627    pub(crate) offset_exit: Arc<Array1<f64>>,
628    pub(crate) offset_deriv: Arc<Array1<f64>>,
629    /// Fixed marginal-predictor offset (β-independent part of the marginal
630    /// predictor entering the entry/exit channels; see `row_dynamic_q_values`).
631    pub(crate) marginal_offset: Arc<Array1<f64>>,
632    pub(crate) time_wiggle_knots: Array1<f64>,
633    pub(crate) time_wiggle_degree: usize,
634    pub(crate) p_time: usize,
635    pub(crate) p_tw: usize,
636    pub(crate) p_g: usize,
637    pub(crate) probit_scale: f64,
638}
639
640impl SmsTimewiggleMarginalJacobian {
641    /// Construct.
642    pub fn new(
643        design_entry: Arc<Array2<f64>>,
644        design_exit: Arc<Array2<f64>>,
645        design_deriv: Arc<Array2<f64>>,
646        design_marginal: Arc<Array2<f64>>,
647        design_logslope: Arc<Array2<f64>>,
648        offset_entry: Arc<Array1<f64>>,
649        offset_exit: Arc<Array1<f64>>,
650        offset_deriv: Arc<Array1<f64>>,
651        marginal_offset: Arc<Array1<f64>>,
652        time_wiggle_knots: Array1<f64>,
653        time_wiggle_degree: usize,
654        p_time: usize,
655        p_tw: usize,
656        p_g: usize,
657        probit_scale: f64,
658    ) -> Self {
659        Self {
660            design_entry,
661            design_exit,
662            design_deriv,
663            design_marginal,
664            design_logslope,
665            offset_entry,
666            offset_exit,
667            offset_deriv,
668            marginal_offset,
669            time_wiggle_knots,
670            time_wiggle_degree,
671            p_time,
672            p_tw,
673            p_g,
674            probit_scale,
675        }
676    }
677}
678
679impl crate::custom_family::BlockEffectiveJacobian for SmsTimewiggleMarginalJacobian {
680    fn effective_jacobian_rows(
681        &self,
682        state: &crate::custom_family::FamilyLinearizationState<'_>,
683        rows: std::ops::Range<usize>,
684    ) -> Result<Array2<f64>, String> {
685        let n = self.design_marginal.nrows();
686        let p_m = self.design_marginal.ncols();
687        let rows = rows.start.min(n)..rows.end.min(n);
688        let chunk = rows.end - rows.start;
689        let p_t = self.p_time;
690        let p_base = p_t.saturating_sub(self.p_tw);
691
692        let beta = state.beta;
693        let beta_t = if beta.len() >= p_t {
694            &beta[..p_t]
695        } else {
696            beta
697        };
698        let beta_t_base = &beta_t[..p_base.min(beta_t.len())];
699        let beta_tw = if beta_t.len() > p_base {
700            &beta_t[p_base..]
701        } else {
702            &[][..]
703        };
704        let beta_m = {
705            let s = p_t;
706            let e = (s + p_m).min(beta.len());
707            if e > s { &beta[s..e] } else { &[][..] }
708        };
709        let beta_g = {
710            let s = p_t + p_m;
711            let e = (s + self.p_g).min(beta.len());
712            if e > s { &beta[s..e] } else { &[][..] }
713        };
714
715        let sc = self.probit_scale;
716        let knots = &self.time_wiggle_knots;
717        let degree = self.time_wiggle_degree;
718
719        let mut jac = Array2::<f64>::zeros((3 * chunk, p_m));
720
721        for i in rows.clone() {
722            let local_i = i - rows.start;
723            let g_i: f64 = beta_g
724                .iter()
725                .enumerate()
726                .filter(|&(j, _)| j < self.design_logslope.ncols())
727                .map(|(j, &b)| self.design_logslope[[i, j]] * b)
728                .sum();
729            let c_i = (1.0_f64 + (sc * g_i).powi(2)).sqrt();
730
731            let eta_m: f64 = beta_m
732                .iter()
733                .enumerate()
734                .filter(|&(j, _)| j < p_m)
735                .map(|(j, &b)| self.design_marginal[[i, j]] * b)
736                .sum();
737
738            // Marginal predictor (eta_m + fixed marginal_offset) enters entry
739            // and exit channels alike (see `row_dynamic_q_values`).
740            let h0: f64 = self.offset_entry[i]
741                + eta_m
742                + self.marginal_offset[i]
743                + (0..p_base.min(beta_t_base.len()).min(self.design_entry.ncols()))
744                    .map(|j| self.design_entry[[i, j]] * beta_t_base[j])
745                    .sum::<f64>();
746            let h1: f64 = self.offset_exit[i]
747                + eta_m
748                + self.marginal_offset[i]
749                + (0..p_base.min(beta_t_base.len()).min(self.design_exit.ncols()))
750                    .map(|j| self.design_exit[[i, j]] * beta_t_base[j])
751                    .sum::<f64>();
752            let d_raw: f64 = self.offset_deriv[i]
753                + (0..p_base.min(beta_t_base.len()).min(self.design_deriv.ncols()))
754                    .map(|j| self.design_deriv[[i, j]] * beta_t_base[j])
755                    .sum::<f64>();
756
757            let beta_tw_view = ndarray::ArrayView1::from(beta_tw);
758            let eg = sms_tw_first_order_geom(
759                ndarray::ArrayView1::from(&[h0][..]),
760                beta_tw_view,
761                knots,
762                degree,
763            )?;
764            let xg = sms_tw_first_order_geom(
765                ndarray::ArrayView1::from(&[h1][..]),
766                beta_tw_view,
767                knots,
768                degree,
769            )?;
770
771            let (entry_dq, exit_dq, exit_d2q) = match (eg, xg) {
772                (Some(eg), Some(xg)) => (eg.dq_dq0[0], xg.dq_dq0[0], xg.d2q_dq02[0]),
773                _ => (1.0_f64, 1.0_f64, 0.0_f64),
774            };
775
776            for j in 0..p_m {
777                let m_ij = self.design_marginal[[i, j]];
778                jac[[local_i, j]] = c_i * entry_dq * m_ij;
779                jac[[chunk + local_i, j]] = c_i * exit_dq * m_ij;
780                jac[[2 * chunk + local_i, j]] = c_i * exit_d2q * d_raw * m_ij;
781            }
782        }
783        Ok(jac)
784    }
785
786    fn n_outputs(&self) -> usize {
787        3
788    }
789}
790
791/// Compute timewiggle first-order geometry at a single evaluation point `h0`.
792///
793/// Returns `Ok(None)` when `beta_tw` is empty (no active wiggle columns).
794/// This is a free-function mirror of
795/// `SurvivalMarginalSlopeFamily::time_wiggle_first_order_geometry` for use in
796/// `BlockEffectiveJacobian` impls that do not hold a family reference.
797pub(crate) fn sms_tw_first_order_geom(
798    h0: ndarray::ArrayView1<'_, f64>,
799    beta_tw: ndarray::ArrayView1<'_, f64>,
800    knots: &Array1<f64>,
801    degree: usize,
802) -> Result<Option<SurvivalTimeWiggleFirstOrderGeometry>, String> {
803    if beta_tw.is_empty() {
804        return Ok(None);
805    }
806    let basis = monotone_wiggle_basis_with_derivative_order(h0, knots, degree, 0)?;
807    let basis_d1 = monotone_wiggle_basis_with_derivative_order(h0, knots, degree, 1)?;
808    let basis_d2 = monotone_wiggle_basis_with_derivative_order(h0, knots, degree, 2)?;
809    if basis.ncols() != beta_tw.len()
810        || basis_d1.ncols() != beta_tw.len()
811        || basis_d2.ncols() != beta_tw.len()
812    {
813        return Err(format!(
814            "sms_tw_first_order_geom: basis/beta_tw width mismatch \
815             B/B'/B''={}/{}/{} beta_tw={}",
816            basis.ncols(),
817            basis_d1.ncols(),
818            basis_d2.ncols(),
819            beta_tw.len(),
820        ));
821    }
822    let dq_dq0 = fast_av(&basis_d1, &beta_tw) + 1.0;
823    let d2q_dq02 = fast_av(&basis_d2, &beta_tw);
824    Ok(Some(SurvivalTimeWiggleFirstOrderGeometry {
825        basis,
826        basis_d1,
827        basis_d2,
828        dq_dq0,
829        d2q_dq02,
830    }))
831}