gam 0.3.125

Generalized penalized likelihood engine
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
//! Device-side sigma-cubature stream-pool dispatch.
//!
//! The live GPU entry is [`try_gpu_sigma_stream_pool_eval`]. It runs each
//! sigma point through the unified PIRLS stream-pool executor, returns the
//! per-point `(H_original^-1, beta_original)` pairs, and hands the shared
//! covariance accumulation back to
//! [`crate::solver::reml::eval::accumulate_sigma_cubature_total_covariance`].

use ndarray::{Array1, Array2, ArrayView1};

use crate::gpu::gpu_error::GpuError;

/// Per-sigma-point GPU PIRLS input: penalty, reparameterisation transform,
/// and prior-mean shifts for one ρ / σ point.
///
/// Built by [`crate::solver::reml::eval::sigma_cubature_evaluate_gpu_stream_pool`]
/// from the reparameterisation engine output before the stream pool is allocated.
/// The shared model data (`X_original`, `y`, `prior_w`, `offset`) is uploaded
/// ONCE into [`crate::solver::gpu::pirls_gpu::PirlsGpuSharedData`]; only the
/// small per-point algebra (p×p Qs, p×p S, length-p shift, scalar) needs
/// uploading per sigma point.
///
/// Shape contract, checked by `validate_sigma_point_inputs` before dispatch:
/// with `p = x_original.ncols()`, `s_transformed` and `qs` must both be
/// `p × p`, `linear_shift` must have length `p`, and `constant_shift` must be
/// finite.
pub struct SigmaPointGpuInput {
    /// `p × p` penalised-Hessian contribution `S_λ` in the transformed basis.
    pub s_transformed: Array2<f64>,
    /// `p × p` reparameterisation matrix `Qs`. Uploaded via
    /// `pirls_gpu::upload_qs_pirls` once per sigma point; also used on the
    /// CPU to map the loop's `β_transformed` and `H_transformed` back to the
    /// original basis so the downstream cubature accumulator receives
    /// `(H_original⁻¹, β_original)`.
    pub qs: Array2<f64>,
    /// Length-p linear shift `b` for the shifted-quadratic penalty
    /// `βᵀSβ − 2βᵀb + c`. All-zero for the default sigma-cubature path.
    pub linear_shift: Array1<f64>,
    /// Scalar constant shift `c`. Zero for the default sigma-cubature path.
    ///
    /// Forwarded to the iterative PIRLS loop only; the Gaussian-identity
    /// bypass does not pass it to the exact PLS solve.
    pub constant_shift: f64,
}

/// Maximum number of concurrent CUDA streams in the sigma-cubature pool.
///
/// This is an upper bound, not a fixed size: `pool_size` caps the pool at
/// `min(8, M)` so we never allocate more streams than sigma points.
/// Eight concurrent streams saturate the SM scheduler on all shipping
/// datacenter GPUs without exhausting the per-context stream limit.
#[cfg(target_os = "linux")]
const STREAM_POOL_MAX: usize = 8;

/// Initial Levenberg-Marquardt damping for each sigma-point PIRLS fit.
///
/// The sigma-point fits are cold-started from a zero β seed, so a small but
/// non-zero seed damping keeps the first Gauss-Newton step well-conditioned
/// when the design's `XᵀWX` is near-singular at β=0; the inner loop's own
/// trust-region logic then grows or shrinks it. `1e-6` is the same seed the
/// stateless CPU sigma-cubature path uses, so the two solvers take an
/// identical first step.
///
/// Forwarded to `pirls_gpu::pirls_loop_on_stream` by the non-Gaussian
/// stream-pool path; the Gaussian-identity bypass never reads it.
#[cfg(target_os = "linux")]
const SIGMA_PIRLS_INITIAL_LM_LAMBDA: f64 = 1e-6;

/// Number of streams to allocate for a batch of `m` sigma points.
///
/// The count is `m` capped at `STREAM_POOL_MAX` and then floored at one, so
/// even `m == 0` yields a single-stream pool. Derived purely from the batch
/// size — no flag, no env var.
#[cfg(target_os = "linux")]
#[inline]
fn pool_size(m: usize) -> usize {
    let capped = std::cmp::min(m, STREAM_POOL_MAX);
    std::cmp::max(capped, 1)
}

/// GPU stream-pool sigma-cubature executor.
///
/// Validates the inputs and then, on Linux with a GPU runtime, hands the
/// batch to `linux_impl::stream_pool_eval`. That executor uploads the shared
/// model data (`x_original`, `y`, `prior_w`, `offset`) once, allocates
/// `N_streams = min(8, M)` workspace pairs (`SigmaPirlsGpuWorkspace` +
/// `PirlsLoopWorkspace`) against it, and assigns sigma points to the pool
/// with `stream_idx = point_idx % N_streams`. Per point only `Qs` is uploaded
/// before `pirls_loop_on_stream` runs; the loop outcome's
/// `(β_transformed, penalized_hessian)` is then mapped to
/// `(H_original⁻¹, β_original)` on the CPU. Gaussian-identity batches skip the
/// PIRLS loop entirely and use the exact GPU PLS solve per point.
///
/// # Arguments
///
/// * `x_original` — original (pre-reparameterization) dense design matrix,
///   shape n × p. Uploaded to device once and reused across all sigma points.
/// * `y`, `prior_w`, `offset` — length-n response, prior weights and offset.
/// * `per_sigma` — per-point algebra; every entry must match the design
///   width `p` (see `validate_sigma_point_inputs`).
/// * `admission` — supplies the family and curvature kinds to run.
/// * `gamma_shape` — active Gamma dispersion shape (α > 0). Pass `1.0` for
///   non-Gamma families.
/// * `convergence_tol`, `max_iter` — forwarded to the iterative PIRLS loop;
///   not used by the Gaussian-identity bypass.
///
/// # Returns
///
/// * `Ok(Some(results))` — one entry per sigma point, in input order. An
///   entry is `None` when that point's PIRLS loop failed (a warning is
///   logged). An empty `per_sigma` yields `Ok(Some(vec![]))` on every target.
/// * `Ok(None)` — the GPU path was not taken: non-Linux target, no GPU
///   runtime, or `admission.family` is `None`.
///
/// # Errors
///
/// Returns `Err` when `y`, `prior_w` or `offset` does not have
/// `x_original.nrows()` elements, when a sigma point fails shape validation,
/// when the family has no row kernel, or when the Linux executor reports a
/// driver / upload / inversion failure.
pub fn try_gpu_sigma_stream_pool_eval(
    x_original: ndarray::ArrayView2<'_, f64>,
    y: ArrayView1<'_, f64>,
    prior_w: ArrayView1<'_, f64>,
    offset: ArrayView1<'_, f64>,
    per_sigma: &[SigmaPointGpuInput],
    admission: crate::gpu::policy::PirlsLoopAdmission,
    gamma_shape: f64,
    convergence_tol: f64,
    max_iter: usize,
) -> Result<Option<Vec<Option<(ndarray::Array2<f64>, ndarray::Array1<f64>)>>>, GpuError> {
    if per_sigma.is_empty() {
        return Ok(Some(Vec::new()));
    }

    // The row-indexed vectors must line up with the design matrix. Without
    // this check the Gaussian bypass's host-side arithmetic would panic on a
    // mismatch (or silently broadcast a length-1 vector) instead of reporting
    // a shape failure through the `Result`.
    let n = x_original.nrows();
    for (name, len) in [
        ("y", y.len()),
        ("prior_w", prior_w.len()),
        ("offset", offset.len()),
    ] {
        if len != n {
            return Err(crate::gpu_err!(
                "sigma stream pool: {name} len {len} != design rows {n}"
            ));
        }
    }
    validate_sigma_point_inputs(x_original.ncols(), per_sigma)?;

    #[cfg(target_os = "linux")]
    {
        if crate::gpu::device_runtime::GpuRuntime::global().is_none() {
            return Ok(None);
        }
        let Some(family_kind) = admission.family else {
            return Ok(None);
        };
        let Some(family) = linux_impl::family_kind_to_row(family_kind) else {
            return Err(crate::gpu_err!(
                "sigma stream pool: family not in JIT-cached set"
            ));
        };
        let curvature = linux_impl::curvature_kind_to_row(admission.curvature);
        return linux_impl::stream_pool_eval(
            x_original,
            y,
            prior_w,
            offset,
            per_sigma,
            family,
            curvature,
            gamma_shape,
            convergence_tol,
            max_iter,
        );
    }

    #[cfg(not(target_os = "linux"))]
    {
        // Non-Linux: no CUDA runtime. Consume every parameter in a single
        // trace log so each binding is read once, satisfying -D warnings
        // without an #[allow(unused_variables)] suppression.
        log::trace!(
            "[sigma stream pool] non-Linux target: skipping dispatch \
             (x_original={}x{}, y_len={}, prior_w_len={}, offset_len={}, \
              n_sigma={}, family={:?}, curvature={:?}, gamma_shape={}, \
              tol={}, max_iter={})",
            x_original.nrows(),
            x_original.ncols(),
            y.len(),
            prior_w.len(),
            offset.len(),
            per_sigma.len(),
            admission.family,
            admission.curvature,
            gamma_shape,
            convergence_tol,
            max_iter,
        );
        Ok(None)
    }
}

/// Check every sigma point's per-point algebra against the design width `p`.
///
/// For each point, in order: `s_transformed` must be `p × p`, `qs` must be
/// `p × p`, `linear_shift` must have length `p`, and `constant_shift` must be
/// finite. The first violation is returned as a [`GpuError`] naming the point
/// index; an empty slice is trivially valid.
fn validate_sigma_point_inputs(p: usize, per_sigma: &[SigmaPointGpuInput]) -> Result<(), GpuError> {
    let square = [p, p];
    per_sigma
        .iter()
        .enumerate()
        .try_for_each(|(idx, pt)| -> Result<(), GpuError> {
            let s_shape = pt.s_transformed.shape();
            if s_shape != square {
                return Err(crate::gpu_err!(
                    "sigma stream pool: point[{idx}] S shape {:?} != [{p}, {p}]",
                    s_shape
                ));
            }

            let qs_shape = pt.qs.shape();
            if qs_shape != square {
                return Err(crate::gpu_err!(
                    "sigma stream pool: point[{idx}] Qs shape {:?} != [{p}, {p}]",
                    qs_shape
                ));
            }

            let shift_len = pt.linear_shift.len();
            if shift_len != p {
                return Err(crate::gpu_err!(
                    "sigma stream pool: point[{idx}] linear shift len {} != {p}",
                    shift_len
                ));
            }

            if pt.constant_shift.is_finite() {
                Ok(())
            } else {
                Err(crate::gpu_err!(
                    "sigma stream pool: point[{idx}] non-finite constant shift {}",
                    pt.constant_shift
                ))
            }
        })
}

#[cfg(target_os = "linux")]
mod linux_impl {
    use crate::gpu::kernels::pirls_row::{CurvatureMode, PirlsRowFamily};
    use crate::gpu::kernels::sigma_cubature::SigmaPointGpuInput;
    use crate::gpu::policy::{PirlsLoopCurvatureKind, PirlsLoopFamilyKind};
    use crate::linalg::utils::matrix_inversewith_regularization;
    use ndarray::{Array1, Array2, ArrayView1};
    type SigmaPointResult = Option<(Array2<f64>, Array1<f64>)>;

    /// Translate the admission policy's family tag into the row-kernel family.
    ///
    /// Every `PirlsLoopFamilyKind` variant handled here maps to the same-named
    /// `PirlsRowFamily` variant, so the result is currently always `Some`. The
    /// match is exhaustive on purpose: adding a policy variant forces a
    /// decision here at compile time.
    pub(super) fn family_kind_to_row(f: PirlsLoopFamilyKind) -> Option<PirlsRowFamily> {
        Some(match f {
            // Binary-response families.
            PirlsLoopFamilyKind::BernoulliLogit => PirlsRowFamily::BernoulliLogit,
            PirlsLoopFamilyKind::BernoulliProbit => PirlsRowFamily::BernoulliProbit,
            PirlsLoopFamilyKind::BernoulliCLogLog => PirlsRowFamily::BernoulliCLogLog,
            // Count and continuous families.
            PirlsLoopFamilyKind::PoissonLog => PirlsRowFamily::PoissonLog,
            PirlsLoopFamilyKind::GaussianIdentity => PirlsRowFamily::GaussianIdentity,
            PirlsLoopFamilyKind::GammaLog => PirlsRowFamily::GammaLog,
        })
    }

    /// Translate the admission policy's curvature tag into the row-kernel
    /// `CurvatureMode` of the same name.
    pub(super) fn curvature_kind_to_row(c: PirlsLoopCurvatureKind) -> CurvatureMode {
        match c {
            PirlsLoopCurvatureKind::Observed => CurvatureMode::Observed,
            PirlsLoopCurvatureKind::Fisher => CurvatureMode::Fisher,
        }
    }

    /// Rotate a Hessian expressed in the `Qs` basis back to the original basis.
    ///
    /// Evaluates `(Qs · H_transformed) · Qsᵀ` and passes the product through
    /// `symmetrize_dense_in_place` before returning it (presumably to remove
    /// round-off asymmetry left by the two matrix products — confirm against
    /// that helper).
    fn hessian_to_original(
        h_transformed: &ndarray::Array2<f64>,
        qs: &ndarray::Array2<f64>,
    ) -> ndarray::Array2<f64> {
        let mut rotated = qs.dot(h_transformed).dot(&qs.t());
        crate::families::custom_family::symmetrize_dense_in_place(&mut rotated);
        rotated
    }

    /// Linux executor behind `try_gpu_sigma_stream_pool_eval`.
    ///
    /// Steps:
    /// 1. Defensively re-check each point's `S`, `Qs` and `linear_shift`
    ///    dimensions against the design width `p = x_original.ncols()`.
    /// 2. For `PirlsRowFamily::GaussianIdentity`, delegate to
    ///    `gaussian_sigma_pool_eval` and return its result.
    /// 3. Otherwise upload `x_original`, `y`, `prior_w`, `offset` once,
    ///    allocate `pool_size(m)` workspace pairs, and for each sigma point
    ///    upload its `Qs` into the pair at `idx % n_streams`, run
    ///    `pirls_loop_on_stream` from a zero β seed, and map the outcome back
    ///    to the original basis on the CPU.
    ///
    /// # Returns
    ///
    /// `Ok(Some(outcomes))` with one entry per sigma point in input order. An
    /// entry is `Some((H_original⁻¹, β_original))` on success and `None` when
    /// that point's `pirls_loop_on_stream` call returned an error (logged at
    /// `warn`).
    ///
    /// # Errors
    ///
    /// Returns `Err` on a per-point dimension mismatch, on failure of the
    /// shared upload, workspace allocation or `Qs` upload, and when the
    /// regularised inverse of a point's Hessian is unavailable.
    pub(super) fn stream_pool_eval(
        x_original: ndarray::ArrayView2<'_, f64>,
        y: ArrayView1<'_, f64>,
        prior_w: ArrayView1<'_, f64>,
        offset: ArrayView1<'_, f64>,
        per_sigma: &[SigmaPointGpuInput],
        family: PirlsRowFamily,
        curvature: CurvatureMode,
        gamma_shape: f64,
        convergence_tol: f64,
        max_iter: usize,
    ) -> Result<Option<Vec<SigmaPointResult>>, crate::gpu::GpuError> {
        use crate::gpu::kernels::sigma_cubature::pool_size;
        use crate::solver::gpu::pirls_gpu;

        let m = per_sigma.len();
        let p = x_original.ncols();

        // Defensive re-check so this executor is safe even if reached without
        // the public entry's validation. Every per-point operand that is
        // forwarded to the device or multiplied on the host below must match
        // the design width `p`, including the length-p linear shift.
        for (idx, pt) in per_sigma.iter().enumerate() {
            if pt.s_transformed.shape() != [p, p]
                || pt.qs.shape() != [p, p]
                || pt.linear_shift.len() != p
            {
                return Err(crate::gpu_err!(
                    "sigma stream pool: point[{idx}] shape mismatch against design width {p}"
                ));
            }
        }

        // Gaussian-identity exact PLS bypass (#272).
        //
        // For Gaussian-identity the working weight is prior_w (constant across
        // all sigma points). XᵀWX and XᵀW(y−offset) are therefore the same for
        // every point. Compute them once, then for each sigma point call the
        // exact Gaussian PLS solver with the per-point (Qs, S_transformed,
        // linear_shift) — no row-kernel PIRLS loop, no iterative solver.
        if family == PirlsRowFamily::GaussianIdentity {
            return gaussian_sigma_pool_eval(x_original, y, prior_w, offset, per_sigma, p);
        }

        // Upload X_original, y, prior_w, offset once — shared across all sigma points.
        // Per sigma point, only Qs changes; it gets uploaded via upload_qs_pirls.
        let bootstrap_shared =
            pirls_gpu::upload_shared_pirls_gpu(x_original, y, prior_w, offset)
                .map_err(|e| crate::gpu_err!("sigma stream pool bootstrap upload: {e}"))?;

        let n_streams = pool_size(m);

        // Allocate N_streams workspace pairs bound to independent streams.
        let mut workspace_pairs: Vec<(
            crate::solver::gpu::pirls_gpu::SigmaPirlsGpuWorkspace,
            crate::solver::gpu::pirls_gpu::cuda::PirlsLoopWorkspace,
        )> = Vec::with_capacity(n_streams);
        for _ in 0..n_streams {
            let ws = pirls_gpu::allocate_sigma_pirls_workspace(&bootstrap_shared)
                .map_err(|e| crate::gpu_err!("sigma stream pool alloc workspace: {e}"))?;
            let loop_ws = pirls_gpu::allocate_pirls_loop_workspace(&bootstrap_shared, &ws)
                .map_err(|e| crate::gpu_err!("sigma stream pool alloc loop_ws: {e}"))?;
            workspace_pairs.push((ws, loop_ws));
        }

        // Zero-initialised beta seed (length p). The sigma-point PIRLS fits
        // have no warm-start; a zero seed matches the stateless CPU path.
        let beta0: Array1<f64> = Array1::zeros(p);

        // For each sigma point, upload Qs (small p×p) then run pirls_loop.
        // X_original, y, prior_w, offset stay in bootstrap_shared throughout.
        let mut outcomes: Vec<SigmaPointResult> = Vec::with_capacity(m);
        for (idx, pt) in per_sigma.iter().enumerate() {
            // Round-robin assignment of points to workspace pairs.
            let stream_idx = idx % n_streams;

            let (ws, loop_ws) = &mut workspace_pairs[stream_idx];
            // Upload this sigma point's Qs matrix to the workspace.
            pirls_gpu::upload_qs_pirls(ws, pt.qs.view())
                .map_err(|e| crate::gpu_err!("sigma stream pool upload Qs pt[{idx}]: {e}"))?;
            let shared = &bootstrap_shared;

            // Use per-point linear_shift and constant_shift from SigmaPointGpuInput (#260).
            // NOTE(review): the literal `0.0` and trailing `None` are positional
            // arguments whose meaning is defined by `pirls_loop_on_stream`,
            // which is not visible from this file.
            let outcome = pirls_gpu::pirls_loop_on_stream(
                shared,
                ws,
                loop_ws,
                family,
                curvature,
                gamma_shape,
                beta0.view(),
                pt.s_transformed.view(),
                pt.linear_shift.view(),
                pt.constant_shift,
                super::SIGMA_PIRLS_INITIAL_LM_LAMBDA,
                0.0,
                max_iter,
                convergence_tol,
                None,
            );

            // NOTE(review): a failed PIRLS loop degrades to a per-point `None`,
            // whereas a failed Hessian inverse aborts the whole batch via `?`.
            // Confirm this asymmetry is intended by the downstream accumulator.
            let sigma_result = match outcome {
                Ok(loop_out) => {
                    // Map H_transformed → H_original, invert, map β_transformed
                    // → β_original. Mirrors the CPU path's post-processing.
                    let h_orig = hessian_to_original(&loop_out.penalized_hessian, &pt.qs);
                    let cov = matrix_inversewith_regularization(&h_orig, "gpu sigma point")
                        .ok_or_else(|| {
                            crate::gpu_err!(
                                "gpu sigma point: penalised Hessian inverse not well-defined"
                            )
                        })?;
                    let beta_orig = pt.qs.dot(&loop_out.beta);
                    Some((cov, beta_orig))
                }
                Err(e) => {
                    log::warn!(
                        "[sigma-cubature gpu] point[{idx}] pirls_loop_on_stream failed: {e}"
                    );
                    None
                }
            };

            outcomes.push(sigma_result);
        }

        Ok(Some(outcomes))
    }

    /// Exact-solve path for Gaussian-identity sigma cubature (#272).
    ///
    /// With a Gaussian-identity family the working weight is the prior weight,
    /// so `XᵀWX` and `XᵀW(y − offset)` do not depend on the sigma point. Both
    /// are formed once from `x_original`, `prior_w`, `y` and `offset`; each
    /// sigma point then only needs one call to the exact GPU PLS solver with
    /// its own `(Qs, S, linear_shift)`. No row-kernel PIRLS loop runs, matching
    /// the single-fit `try_gpu_gaussian_pls_dispatch` bypass.
    ///
    /// Every entry of the returned vector is `Some`; any per-point solver or
    /// inversion failure aborts the whole batch with `Err`.
    fn gaussian_sigma_pool_eval(
        x_original: ndarray::ArrayView2<'_, f64>,
        y: ArrayView1<'_, f64>,
        prior_w: ArrayView1<'_, f64>,
        offset: ArrayView1<'_, f64>,
        per_sigma: &[SigmaPointGpuInput],
        p: usize,
    ) -> Result<Option<Vec<SigmaPointResult>>, crate::gpu::GpuError> {
        use ndarray::Array1;

        // Gram matrix Xᵀ·diag(prior_w)·X, shared by all sigma points and
        // formed on the device.
        let gram = crate::solver::gpu::pirls_gpu::weighted_crossprod_gpu(x_original, prior_w)
            .map_err(|e| crate::gpu_err!("gaussian sigma: XᵀWX gpu failed: {e}"))?;

        // Right-hand side Xᵀ·(prior_w ∘ (y − offset)), formed on the host: it
        // is only an n-vector reduction, so no dedicated kernel is used.
        let mut weighted_residual = y.to_owned();
        weighted_residual -= &offset;
        weighted_residual *= &prior_w;
        let rhs: Array1<f64> = x_original.t().dot(&weighted_residual);

        let zero_prior_mean: Array1<f64> = Array1::zeros(p);

        let solved = per_sigma
            .iter()
            .enumerate()
            .map(|(idx, pt)| -> Result<SigmaPointResult, crate::gpu::GpuError> {
                let fit = crate::solver::gpu::pirls_gpu::solve_gaussian_pls_gpu(
                    gram.view(),
                    rhs.view(),
                    pt.s_transformed.view(),
                    pt.linear_shift.view(),
                    zero_prior_mean.view(),
                    0.0,
                    Some(pt.qs.view()),
                )
                .map_err(|e| {
                    crate::gpu_err!("gaussian sigma pool: point[{idx}] pls failed: {e}")
                })?;

                // Same CPU post-processing as the iterative path: rotate the
                // Hessian back, invert it, and rotate β back.
                let h_orig = hessian_to_original(&fit.penalized_hessian, &pt.qs);
                let cov = matrix_inversewith_regularization(&h_orig, "gaussian sigma point")
                    .ok_or_else(|| {
                        crate::gpu_err!(
                            "gaussian sigma point: penalised Hessian inverse not well-defined"
                        )
                    })?;
                let beta_orig = pt.qs.dot(&fit.beta);
                Ok(Some((cov, beta_orig)))
            })
            .collect::<Result<Vec<SigmaPointResult>, crate::gpu::GpuError>>()?;

        Ok(Some(solved))
    }
}