1use crate::survival::marginal_slope::row_kernel::RigidRowInputs;
54
/// Flat per-row output buffers of the rigid survival row-jet evaluation.
///
/// The producers in this file size the buffers from the row count `n`:
/// `value` holds `n` entries, `grad` holds `4 * n`, and `hess`, `third` and
/// `fourth` hold `16 * n` each. Entry `(a, b)` of row `r` in a 16-wide
/// channel is stored at index `r * 16 + a * 4 + b`.
#[derive(Debug, Clone, PartialEq)]
pub struct SurvivalRowJetChannels {
    /// Number of rows the buffers were sized for.
    pub n_rows: usize,
    /// Scalar value of each row.
    pub value: Vec<f64>,
    /// Gradient of each row with respect to the four primaries.
    pub grad: Vec<f64>,
    /// 4x4 Hessian block of each row.
    pub hess: Vec<f64>,
    /// 4x4 block per row obtained from the third-order jet seeded with `dir`.
    pub third: Vec<f64>,
    /// 4x4 block per row obtained from the fourth-order jet seeded with
    /// `dir_u` and `dir_v`.
    pub fourth: Vec<f64>,
}
72
/// Inputs describing a single row of the rigid survival row-jet evaluation.
///
/// The scalar fields are forwarded unchanged to the row kernel; `cov_ones`
/// is passed on under the name `covariance_ones`.
#[derive(Debug, Clone)]
pub struct SurvivalRowInputs {
    /// The four primary variables the row is differentiated with respect to.
    /// The device path uploads them, in order, as `q0`, `q1`, `qd1` and `g`.
    pub primaries: [f64; 4],
    /// Forwarded as `wi` (presumably the row weight — confirm against the row kernel).
    pub wi: f64,
    /// Forwarded as `di` (presumably the event indicator — confirm against the row kernel).
    pub di: f64,
    /// Forwarded as `z_sum`.
    pub z_sum: f64,
    /// Forwarded as `covariance_ones`.
    pub cov_ones: f64,
}
85
/// Smallest row count at which [`survival_rigid_row_jets`] attempts the
/// device path (Linux builds only); smaller batches always run on the CPU.
pub const DEVICE_ROW_THRESHOLD: usize = 100_000;
92
93#[must_use]
98pub fn survival_rigid_row_jets_cpu(
99 rows: &[SurvivalRowInputs],
100 probit_scale: f64,
101 dir: &[f64; 4],
102 dir_u: &[f64; 4],
103 dir_v: &[f64; 4],
104) -> SurvivalRowJetChannels {
105 use crate::survival::marginal_slope::row_kernel::{
106 RIGID_LINEAR_MASK, SparseOrder2, rigid_row_nll,
107 };
108 use gam_math::jet_scalar::{JetScalar, OneSeed, TwoSeed};
109 let n = rows.len();
110 let mut value = vec![0.0_f64; n];
111 let mut grad = vec![0.0_f64; n * 4];
112 let mut hess = vec![0.0_f64; n * 16];
113 let mut third = vec![0.0_f64; n * 16];
114 let mut fourth = vec![0.0_f64; n * 16];
115 for (row, inp) in rows.iter().enumerate() {
116 let in_row = RigidRowInputs {
117 row,
118 wi: inp.wi,
119 di: inp.di,
120 z_sum: inp.z_sum,
121 covariance_ones: inp.cov_ones,
122 probit_scale,
123 qd1_lower: f64::NEG_INFINITY,
128 };
129 let p = inp.primaries;
131 let vars: [SparseOrder2<RIGID_LINEAR_MASK>; 4] =
132 std::array::from_fn(|a| SparseOrder2::variable(p[a], a));
133 if let Ok(out) = rigid_row_nll(&vars, &in_row) {
134 value[row] = out.value();
135 grad[row * 4..row * 4 + 4].copy_from_slice(&out.g());
136 let h = out.h();
137 for a in 0..4 {
138 for b in 0..4 {
139 hess[row * 16 + a * 4 + b] = h[a][b];
140 }
141 }
142 }
143 let vars1: [OneSeed<4>; 4] =
145 std::array::from_fn(|a| OneSeed::seed_direction(p[a], a, dir[a]));
146 if let Ok(out1) = rigid_row_nll(&vars1, &in_row) {
147 let t = out1.contracted_third();
148 for a in 0..4 {
149 for b in 0..4 {
150 third[row * 16 + a * 4 + b] = t[a][b];
151 }
152 }
153 }
154 let vars2: [TwoSeed<4>; 4] =
156 std::array::from_fn(|a| TwoSeed::seed(p[a], a, dir_u[a], dir_v[a]));
157 if let Ok(out2) = rigid_row_nll(&vars2, &in_row) {
158 let f = out2.contracted_fourth();
159 for a in 0..4 {
160 for b in 0..4 {
161 fourth[row * 16 + a * 4 + b] = f[a][b];
162 }
163 }
164 }
165 }
166 SurvivalRowJetChannels {
167 n_rows: n,
168 value,
169 grad,
170 hess,
171 third,
172 fourth,
173 }
174}
175
176#[must_use]
182pub fn survival_rigid_row_jets(
183 rows: &[SurvivalRowInputs],
184 probit_scale: f64,
185 dir: &[f64; 4],
186 dir_u: &[f64; 4],
187 dir_v: &[f64; 4],
188) -> SurvivalRowJetChannels {
189 #[cfg(target_os = "linux")]
190 {
191 if rows.len() >= DEVICE_ROW_THRESHOLD {
192 match device::survival_rigid_row_jets_device(rows, probit_scale, dir, dir_u, dir_v) {
193 Ok(out) => return out,
194 Err(e) => {
195 log::info!("[GPU] survival_rowjet device path fell back to CPU: {e}");
199 }
200 }
201 }
202 }
203 survival_rigid_row_jets_cpu(rows, probit_scale, dir, dir_u, dir_v)
204}
205
206#[cfg(target_os = "linux")]
210pub fn survival_rigid_row_jets_device_only(
211 rows: &[SurvivalRowInputs],
212 probit_scale: f64,
213 dir: &[f64; 4],
214 dir_u: &[f64; 4],
215 dir_v: &[f64; 4],
216) -> Result<SurvivalRowJetChannels, String> {
217 device::survival_rigid_row_jets_device(rows, probit_scale, dir, dir_u, dir_v)
218 .map_err(|e| e.to_string())
219}
220
221#[cfg(target_os = "linux")]
225pub const SURVIVAL_ROWJET_SOURCE: &str = include_str!("survival_rowjet_kernel.cu");
226
227#[cfg(target_os = "linux")]
228mod device {
229 use super::{SURVIVAL_ROWJET_SOURCE, SurvivalRowInputs, SurvivalRowJetChannels};
230 use gam_gpu::gpu_error::{GpuError, GpuResultExt};
231 use std::sync::{Arc, Mutex, OnceLock};
232
233 use cudarc::driver::{CudaContext, CudaModule, CudaStream, LaunchConfig, PushKernelArg};
234
235 struct Backend {
236 ctx: Arc<CudaContext>,
237 stream: Arc<CudaStream>,
238 module: Mutex<Option<Arc<CudaModule>>>,
239 }
240
241 fn backend() -> Result<&'static Backend, GpuError> {
242 static BACKEND: OnceLock<Result<Backend, GpuError>> = OnceLock::new();
243 BACKEND
244 .get_or_init(|| {
245 let parts = gam_gpu::backend_probe::probe_cuda_backend("survival_rowjet")?;
246 Ok(Backend {
247 ctx: parts.ctx,
248 stream: parts.stream,
249 module: Mutex::new(None),
250 })
251 })
252 .as_ref()
253 .map_err(GpuError::clone)
254 }
255
256 fn module(b: &Backend) -> Result<Arc<CudaModule>, GpuError> {
257 if let Ok(guard) = b.module.lock() {
258 if let Some(m) = guard.as_ref() {
259 return Ok(m.clone());
260 }
261 }
262 let ptx = cudarc::nvrtc::compile_ptx(SURVIVAL_ROWJET_SOURCE)
263 .gpu_ctx_with(|err| format!("survival_rowjet NVRTC compile: {err}"))?;
264 let m = b
265 .ctx
266 .load_module(ptx)
267 .gpu_ctx("survival_rowjet module load")?;
268 if let Ok(mut guard) = b.module.lock() {
269 guard.get_or_insert_with(|| m.clone());
270 }
271 Ok(m)
272 }
273
274 fn has_nonzero_direction(dir: &[f64; 4]) -> bool {
275 dir.iter().any(|&v| v != 0.0)
276 }
277
278 pub(super) fn survival_rigid_row_jets_device(
279 rows: &[SurvivalRowInputs],
280 probit_scale: f64,
281 dir: &[f64; 4],
282 dir_u: &[f64; 4],
283 dir_v: &[f64; 4],
284 ) -> Result<SurvivalRowJetChannels, GpuError> {
285 let n = rows.len();
286 if n == 0 {
287 return Ok(SurvivalRowJetChannels {
288 n_rows: 0,
289 value: Vec::new(),
290 grad: Vec::new(),
291 hess: Vec::new(),
292 third: Vec::new(),
293 fourth: Vec::new(),
294 });
295 }
296 let b = backend()?;
297 let m = module(b)?;
298 let need_fourth = has_nonzero_direction(dir_u) && has_nonzero_direction(dir_v);
299 let func_name = if need_fourth {
300 "survival_rowjet"
301 } else {
302 "survival_rowjet_no_t4"
303 };
304 let func = m
305 .load_function(func_name)
306 .gpu_ctx_with(|err| format!("survival_rowjet load_function {func_name}: {err}"))?;
307 let stream = b.stream.clone();
308
309 let mut q0 = vec![0.0_f64; n];
311 let mut q1 = vec![0.0_f64; n];
312 let mut qd1 = vec![0.0_f64; n];
313 let mut g = vec![0.0_f64; n];
314 let mut wi = vec![0.0_f64; n];
315 let mut di = vec![0.0_f64; n];
316 let mut zs = vec![0.0_f64; n];
317 let mut cov = vec![0.0_f64; n];
318 for (i, r) in rows.iter().enumerate() {
319 q0[i] = r.primaries[0];
320 q1[i] = r.primaries[1];
321 qd1[i] = r.primaries[2];
322 g[i] = r.primaries[3];
323 wi[i] = r.wi;
324 di[i] = r.di;
325 zs[i] = r.z_sum;
326 cov[i] = r.cov_ones;
327 }
328
329 let q0_d = stream.clone_htod(&q0).gpu_ctx("htod q0")?;
330 let q1_d = stream.clone_htod(&q1).gpu_ctx("htod q1")?;
331 let qd1_d = stream.clone_htod(&qd1).gpu_ctx("htod qd1")?;
332 let g_d = stream.clone_htod(&g).gpu_ctx("htod g")?;
333 let wi_d = stream.clone_htod(&wi).gpu_ctx("htod wi")?;
334 let di_d = stream.clone_htod(&di).gpu_ctx("htod di")?;
335 let zs_d = stream.clone_htod(&zs).gpu_ctx("htod zsum")?;
336 let cov_d = stream.clone_htod(&cov).gpu_ctx("htod cov")?;
337 let dir_d = stream.clone_htod(&dir.to_vec()).gpu_ctx("htod dir")?;
338
339 let mut value_d = stream.alloc_zeros::<f64>(n).gpu_ctx("alloc value")?;
340 let mut grad_d = stream.alloc_zeros::<f64>(n * 4).gpu_ctx("alloc grad")?;
341 let mut hess_d = stream.alloc_zeros::<f64>(n * 16).gpu_ctx("alloc hess")?;
342 let mut third_d = stream.alloc_zeros::<f64>(n * 16).gpu_ctx("alloc third")?;
343 let mut fourth_d = stream.alloc_zeros::<f64>(n * 16).gpu_ctx("alloc fourth")?;
344
345 let n_i32 = i32::try_from(n)
346 .map_err(|_| gam_gpu::gpu_err!("survival_rowjet n={n} overflows i32"))?;
347 const TPB: u32 = 128;
348 let grid = ((n as u32).div_ceil(TPB)).max(1);
349 let cfg = LaunchConfig {
350 grid_dim: (grid, 1, 1),
351 block_dim: (TPB, 1, 1),
352 shared_mem_bytes: 0,
353 };
354 let mut builder = stream.launch_builder(&func);
355 builder
356 .arg(&n_i32)
357 .arg(&q0_d)
358 .arg(&q1_d)
359 .arg(&qd1_d)
360 .arg(&g_d)
361 .arg(&wi_d)
362 .arg(&di_d)
363 .arg(&zs_d)
364 .arg(&cov_d)
365 .arg(&probit_scale)
366 .arg(&dir_d);
367 let diru_d;
368 let dirv_d;
369 if need_fourth {
370 diru_d = stream.clone_htod(&dir_u.to_vec()).gpu_ctx("htod dir_u")?;
371 dirv_d = stream.clone_htod(&dir_v.to_vec()).gpu_ctx("htod dir_v")?;
372 builder.arg(&diru_d).arg(&dirv_d);
373 }
374 builder
375 .arg(&mut value_d)
376 .arg(&mut grad_d)
377 .arg(&mut hess_d)
378 .arg(&mut third_d)
379 .arg(&mut fourth_d);
380 unsafe { builder.launch(cfg) }.gpu_ctx("survival_rowjet kernel launch")?;
385
386 let mut value = vec![0.0_f64; n];
387 let mut grad = vec![0.0_f64; n * 4];
388 let mut hess = vec![0.0_f64; n * 16];
389 let mut third = vec![0.0_f64; n * 16];
390 let mut fourth = vec![0.0_f64; n * 16];
391 stream
392 .memcpy_dtoh(&value_d, &mut value)
393 .gpu_ctx("dtoh value")?;
394 stream
395 .memcpy_dtoh(&grad_d, &mut grad)
396 .gpu_ctx("dtoh grad")?;
397 stream
398 .memcpy_dtoh(&hess_d, &mut hess)
399 .gpu_ctx("dtoh hess")?;
400 stream
401 .memcpy_dtoh(&third_d, &mut third)
402 .gpu_ctx("dtoh third")?;
403 stream
404 .memcpy_dtoh(&fourth_d, &mut fourth)
405 .gpu_ctx("dtoh fourth")?;
406 stream
407 .synchronize()
408 .gpu_ctx("survival_rowjet synchronize")?;
409
410 Ok(SurvivalRowJetChannels {
411 n_rows: n,
412 value,
413 grad,
414 hess,
415 third,
416 fourth,
417 })
418 }
419}
420
#[cfg(test)]
mod tests {
    use super::*;

    /// Probit scale shared by every test.
    const PROBIT_SCALE: f64 = 0.7;
    /// Tolerance for CPU-vs-CPU oracle comparisons.
    const ORACLE_TOL: f64 = 1e-12;

    /// Deterministic batch of `n` rows with smoothly varying inputs.
    fn fixture(n: usize) -> Vec<SurvivalRowInputs> {
        (0..n)
            .map(|i| {
                let t = i as f64 / n as f64;
                SurvivalRowInputs {
                    primaries: [
                        -2.5 + 5.0 * (12.0 * t).sin(),
                        -1.5 + 4.0 * (9.0 * t + 0.3).cos(),
                        0.2 + 1.8 * (0.5 + 0.5 * (7.0 * t).sin()),
                        -1.0 + 2.0 * (5.0 * t + 1.1).sin(),
                    ],
                    wi: 1.0,
                    di: if i % 3 == 0 { 1.0 } else { 0.0 },
                    z_sum: 0.5 * (3.0 * t).cos(),
                    cov_ones: 0.4 + 0.3 * (0.5 + 0.5 * (2.0 * t).sin()),
                }
            })
            .collect()
    }

    /// Builds the row-kernel inputs the same way the CPU path does.
    fn kernel_inputs(row: usize, inp: &SurvivalRowInputs) -> RigidRowInputs {
        RigidRowInputs {
            row,
            wi: inp.wi,
            di: inp.di,
            z_sum: inp.z_sum,
            covariance_ones: inp.cov_ones,
            probit_scale: PROBIT_SCALE,
            qd1_lower: f64::NEG_INFINITY,
        }
    }

    /// Largest absolute difference between two channels.
    ///
    /// Panics on a length mismatch (which `zip` would silently truncate) and
    /// when exactly one side is NaN (which `f64::max` would silently drop).
    /// Entries that are equal, or NaN on both sides, count as agreement.
    #[cfg(target_os = "linux")]
    fn max_abs_diff(label: &str, cpu: &[f64], got: &[f64]) -> f64 {
        assert_eq!(cpu.len(), got.len(), "{label}: channel length mismatch");
        let mut worst = 0.0_f64;
        for (i, (&x, &y)) in cpu.iter().zip(got).enumerate() {
            if x == y || (x.is_nan() && y.is_nan()) {
                continue;
            }
            let diff = (x - y).abs();
            assert!(
                !diff.is_nan(),
                "{label}[{i}]: NaN on one side only (cpu={x}, got={y})"
            );
            worst = worst.max(diff);
        }
        worst
    }

    const DIR: [f64; 4] = [0.31, -0.22, 0.17, 0.44];
    const DIRU: [f64; 4] = [0.13, 0.27, -0.41, 0.05];
    const DIRV: [f64; 4] = [-0.19, 0.33, 0.08, 0.22];

    #[test]
    fn cpu_channels_match_unified_rowkernel() {
        use crate::survival::marginal_slope::row_kernel::rigid_row_nll;
        use gam_math::jet_scalar::{JetScalar, Order2};
        let rows = fixture(7);
        let out = survival_rigid_row_jets_cpu(&rows, PROBIT_SCALE, &DIR, &DIRU, &DIRV);
        for (row, inp) in rows.iter().enumerate() {
            let in_row = kernel_inputs(row, inp);
            let vars: [Order2<4>; 4] =
                std::array::from_fn(|a| Order2::variable(inp.primaries[a], a));
            let dense = rigid_row_nll(&vars, &in_row).expect("dense order2");
            assert!(
                (dense.value() - out.value[row]).abs() <= ORACLE_TOL,
                "value mismatch row {row}"
            );
            for a in 0..4 {
                assert!(
                    (dense.g()[a] - out.grad[row * 4 + a]).abs() <= ORACLE_TOL,
                    "grad mismatch row {row} {a}"
                );
                for b in 0..4 {
                    assert!(
                        (dense.h()[a][b] - out.hess[row * 16 + a * 4 + b]).abs() <= ORACLE_TOL,
                        "hess mismatch row {row} {a},{b}"
                    );
                }
            }
        }
    }

    #[test]
    fn cpu_third_fourth_match_dense_tower_oracle() {
        use crate::survival::marginal_slope::row_kernel::rigid_row_nll;
        use gam_math::jet_tower::Tower4;
        let rows = fixture(9);
        let out = survival_rigid_row_jets_cpu(&rows, PROBIT_SCALE, &DIR, &DIRU, &DIRV);
        for (row, inp) in rows.iter().enumerate() {
            let in_row = kernel_inputs(row, inp);
            let vars: [Tower4<4>; 4] =
                std::array::from_fn(|a| Tower4::variable(inp.primaries[a], a));
            let tower = rigid_row_nll(&vars, &in_row).expect("dense tower4");
            let t3 = tower.third_contracted(&DIR);
            let t4 = tower.fourth_contracted(&DIRU, &DIRV);
            for a in 0..4 {
                for b in 0..4 {
                    let flat = row * 16 + a * 4 + b;
                    assert!(
                        (t3[a][b] - out.third[flat]).abs() <= ORACLE_TOL,
                        "third mismatch row {row} {a},{b}: tensor={} seeded={}",
                        t3[a][b],
                        out.third[flat]
                    );
                    assert!(
                        (t4[a][b] - out.fourth[flat]).abs() <= ORACLE_TOL,
                        "fourth mismatch row {row} {a},{b}: tensor={} seeded={}",
                        t4[a][b],
                        out.fourth[flat]
                    );
                }
            }
        }
    }

    /// Compares the dispatching entry point with the CPU reference on a batch
    /// large enough to take the device path. Without a usable device the
    /// entry point falls back to the CPU, so the comparison is then trivial.
    #[cfg(target_os = "linux")]
    #[test]
    fn device_matches_cpu_when_available() {
        let rows = fixture(DEVICE_ROW_THRESHOLD + 1024);
        let cpu = survival_rigid_row_jets_cpu(&rows, PROBIT_SCALE, &DIR, &DIRU, &DIRV);
        let got = survival_rigid_row_jets(&rows, PROBIT_SCALE, &DIR, &DIRU, &DIRV);
        assert_eq!(cpu.n_rows, got.n_rows, "row count mismatch");
        let mut maxabs = 0.0_f64;
        for (label, a, b) in [
            ("value", &cpu.value, &got.value),
            ("grad", &cpu.grad, &got.grad),
            ("hess", &cpu.hess, &got.hess),
            ("third", &cpu.third, &got.third),
            ("fourth", &cpu.fourth, &got.fourth),
        ] {
            maxabs = maxabs.max(max_abs_diff(label, a, b));
        }
        assert!(
            maxabs <= 1e-9,
            "survival device vs CPU row-jet max abs diff {maxabs} > 1e-9"
        );
    }
}