1use crate::survival::marginal_slope::row_kernel::RigidRowInputs;
54
55mod survival_rowjet_host_oracle_tests;
60
/// Per-row derivative channels of the rigid survival row kernel (`rigid_row_nll`).
///
/// Every channel is a flat, row-major buffer over `n_rows` rows:
/// - `value[r]`: the kernel value for row `r` (length `n_rows`).
/// - `grad[4 * r + a]`: first derivative with respect to primary `a` (length `4 * n_rows`).
/// - `hess[16 * r + 4 * a + b]`: second derivative with respect to primaries `a`, `b`
///   (length `16 * n_rows`).
/// - `third[16 * r + 4 * a + b]`: third-derivative tensor contracted once with the `dir`
///   direction given to the producer (length `16 * n_rows`).
/// - `fourth[16 * r + 4 * a + b]`: fourth-derivative tensor contracted with the `dir_u` and
///   `dir_v` directions given to the producer (length `16 * n_rows`).
#[derive(Debug, Clone, PartialEq)]
pub struct SurvivalRowJetChannels {
    /// Number of rows represented by every channel below.
    pub n_rows: usize,
    /// Kernel value per row.
    pub value: Vec<f64>,
    /// Gradient, 4 entries per row.
    pub grad: Vec<f64>,
    /// Hessian, 4x4 row-major per row.
    pub hess: Vec<f64>,
    /// Third derivative contracted with `dir`, 4x4 row-major per row.
    pub third: Vec<f64>,
    /// Fourth derivative contracted with `dir_u` and `dir_v`, 4x4 row-major per row.
    pub fourth: Vec<f64>,
}
78
/// Host-side inputs for one row of the rigid survival row kernel.
///
/// The scalar fields are forwarded unchanged into the kernel's `RigidRowInputs` (with
/// `cov_ones` becoming `covariance_ones`).
#[derive(Debug, Clone)]
pub struct SurvivalRowInputs {
    /// The four differentiation variables of the row. The device path uploads them as
    /// `q0`, `q1`, `qd1`, `g` (indices 0, 1, 2, 3 respectively).
    pub primaries: [f64; 4],
    /// Forwarded as `RigidRowInputs::wi` (presumably a row weight — TODO confirm).
    pub wi: f64,
    /// Forwarded as `RigidRowInputs::di` (presumably an event indicator — TODO confirm).
    pub di: f64,
    /// Forwarded as `RigidRowInputs::z_sum`.
    pub z_sum: f64,
    /// Forwarded as `RigidRowInputs::covariance_ones`.
    pub cov_ones: f64,
}
91
/// Minimum row count at which `survival_rigid_row_jets` attempts the CUDA device path
/// (Linux only) before falling back to the CPU implementation.
pub const DEVICE_ROW_THRESHOLD: usize = 100_000;
98
99#[must_use]
104pub fn survival_rigid_row_jets_cpu(
105 rows: &[SurvivalRowInputs],
106 probit_scale: f64,
107 dir: &[f64; 4],
108 dir_u: &[f64; 4],
109 dir_v: &[f64; 4],
110) -> SurvivalRowJetChannels {
111 use crate::survival::marginal_slope::row_kernel::{
112 RIGID_LINEAR_MASK, SparseOrder2, rigid_row_nll,
113 };
114 use gam_math::jet_scalar::{JetScalar, OneSeed, TwoSeed};
115 let n = rows.len();
116 let mut value = vec![0.0_f64; n];
117 let mut grad = vec![0.0_f64; n * 4];
118 let mut hess = vec![0.0_f64; n * 16];
119 let mut third = vec![0.0_f64; n * 16];
120 let mut fourth = vec![0.0_f64; n * 16];
121 for (row, inp) in rows.iter().enumerate() {
122 let in_row = RigidRowInputs {
123 row,
124 wi: inp.wi,
125 di: inp.di,
126 z_sum: inp.z_sum,
127 covariance_ones: inp.cov_ones,
128 probit_scale,
129 qd1_lower: f64::NEG_INFINITY,
134 };
135 let p = inp.primaries;
137 let vars: [SparseOrder2<RIGID_LINEAR_MASK>; 4] =
138 std::array::from_fn(|a| SparseOrder2::variable(p[a], a));
139 if let Ok(out) = rigid_row_nll(&vars, &in_row) {
140 value[row] = out.value();
141 grad[row * 4..row * 4 + 4].copy_from_slice(&out.g());
142 let h = out.h();
143 for a in 0..4 {
144 for b in 0..4 {
145 hess[row * 16 + a * 4 + b] = h[a][b];
146 }
147 }
148 }
149 let vars1: [OneSeed<4>; 4] =
151 std::array::from_fn(|a| OneSeed::seed_direction(p[a], a, dir[a]));
152 if let Ok(out1) = rigid_row_nll(&vars1, &in_row) {
153 let t = out1.contracted_third();
154 for a in 0..4 {
155 for b in 0..4 {
156 third[row * 16 + a * 4 + b] = t[a][b];
157 }
158 }
159 }
160 let vars2: [TwoSeed<4>; 4] =
162 std::array::from_fn(|a| TwoSeed::seed(p[a], a, dir_u[a], dir_v[a]));
163 if let Ok(out2) = rigid_row_nll(&vars2, &in_row) {
164 let f = out2.contracted_fourth();
165 for a in 0..4 {
166 for b in 0..4 {
167 fourth[row * 16 + a * 4 + b] = f[a][b];
168 }
169 }
170 }
171 }
172 SurvivalRowJetChannels {
173 n_rows: n,
174 value,
175 grad,
176 hess,
177 third,
178 fourth,
179 }
180}
181
182#[must_use]
188pub fn survival_rigid_row_jets(
189 rows: &[SurvivalRowInputs],
190 probit_scale: f64,
191 dir: &[f64; 4],
192 dir_u: &[f64; 4],
193 dir_v: &[f64; 4],
194) -> SurvivalRowJetChannels {
195 #[cfg(target_os = "linux")]
196 {
197 if rows.len() >= DEVICE_ROW_THRESHOLD {
198 match device::survival_rigid_row_jets_device(rows, probit_scale, dir, dir_u, dir_v) {
199 Ok(out) => return out,
200 Err(e) => {
201 log::info!("[GPU] survival_rowjet device path fell back to CPU: {e}");
205 }
206 }
207 }
208 }
209 survival_rigid_row_jets_cpu(rows, probit_scale, dir, dir_u, dir_v)
210}
211
212#[cfg(target_os = "linux")]
216pub fn survival_rigid_row_jets_device_only(
217 rows: &[SurvivalRowInputs],
218 probit_scale: f64,
219 dir: &[f64; 4],
220 dir_u: &[f64; 4],
221 dir_v: &[f64; 4],
222) -> Result<SurvivalRowJetChannels, String> {
223 device::survival_rigid_row_jets_device(rows, probit_scale, dir, dir_u, dir_v)
224 .map_err(|e| e.to_string())
225}
226
/// CUDA C source for the row-jet kernels (`survival_rowjet` and `survival_rowjet_no_t4`).
/// It is embedded at compile time and compiled with NVRTC on first use by the device module.
#[cfg(target_os = "linux")]
pub const SURVIVAL_ROWJET_SOURCE: &str = include_str!("survival_rowjet_kernel.cu");
232
233#[cfg(target_os = "linux")]
234mod device {
235 use super::{SURVIVAL_ROWJET_SOURCE, SurvivalRowInputs, SurvivalRowJetChannels};
236 use gam_gpu::gpu_error::{GpuError, GpuResultExt};
237 use std::sync::{Arc, Mutex, OnceLock};
238
239 use cudarc::driver::{CudaContext, CudaModule, CudaStream, LaunchConfig, PushKernelArg};
240
241 struct Backend {
242 ctx: Arc<CudaContext>,
243 stream: Arc<CudaStream>,
244 module: Mutex<Option<Arc<CudaModule>>>,
245 }
246
247 fn backend() -> Result<&'static Backend, GpuError> {
248 static BACKEND: OnceLock<Result<Backend, GpuError>> = OnceLock::new();
249 BACKEND
250 .get_or_init(|| {
251 let parts = gam_gpu::backend_probe::probe_cuda_backend("survival_rowjet")?;
252 Ok(Backend {
253 ctx: parts.ctx,
254 stream: parts.stream,
255 module: Mutex::new(None),
256 })
257 })
258 .as_ref()
259 .map_err(GpuError::clone)
260 }
261
262 fn module(b: &Backend) -> Result<Arc<CudaModule>, GpuError> {
263 if let Ok(guard) = b.module.lock() {
264 if let Some(m) = guard.as_ref() {
265 return Ok(m.clone());
266 }
267 }
268 let ptx = gam_gpu::device_cache::compile_ptx_arch(SURVIVAL_ROWJET_SOURCE)
276 .gpu_ctx_with(|err| format!("survival_rowjet NVRTC compile: {err}"))?;
277 let m = b
278 .ctx
279 .load_module(ptx)
280 .gpu_ctx("survival_rowjet module load")?;
281 if let Ok(mut guard) = b.module.lock() {
282 guard.get_or_insert_with(|| m.clone());
283 }
284 Ok(m)
285 }
286
287 fn has_nonzero_direction(dir: &[f64; 4]) -> bool {
288 dir.iter().any(|&v| v != 0.0)
289 }
290
291 pub(super) fn survival_rigid_row_jets_device(
292 rows: &[SurvivalRowInputs],
293 probit_scale: f64,
294 dir: &[f64; 4],
295 dir_u: &[f64; 4],
296 dir_v: &[f64; 4],
297 ) -> Result<SurvivalRowJetChannels, GpuError> {
298 let n = rows.len();
299 if n == 0 {
300 return Ok(SurvivalRowJetChannels {
301 n_rows: 0,
302 value: Vec::new(),
303 grad: Vec::new(),
304 hess: Vec::new(),
305 third: Vec::new(),
306 fourth: Vec::new(),
307 });
308 }
309 let b = backend()?;
310 let m = module(b)?;
311 let need_fourth = has_nonzero_direction(dir_u) && has_nonzero_direction(dir_v);
312 let func_name = if need_fourth {
313 "survival_rowjet"
314 } else {
315 "survival_rowjet_no_t4"
316 };
317 let func = m
318 .load_function(func_name)
319 .gpu_ctx_with(|err| format!("survival_rowjet load_function {func_name}: {err}"))?;
320 let stream = b.stream.clone();
321
322 let mut q0 = vec![0.0_f64; n];
324 let mut q1 = vec![0.0_f64; n];
325 let mut qd1 = vec![0.0_f64; n];
326 let mut g = vec![0.0_f64; n];
327 let mut wi = vec![0.0_f64; n];
328 let mut di = vec![0.0_f64; n];
329 let mut zs = vec![0.0_f64; n];
330 let mut cov = vec![0.0_f64; n];
331 for (i, r) in rows.iter().enumerate() {
332 q0[i] = r.primaries[0];
333 q1[i] = r.primaries[1];
334 qd1[i] = r.primaries[2];
335 g[i] = r.primaries[3];
336 wi[i] = r.wi;
337 di[i] = r.di;
338 zs[i] = r.z_sum;
339 cov[i] = r.cov_ones;
340 }
341
342 let q0_d = stream.clone_htod(&q0).gpu_ctx("htod q0")?;
343 let q1_d = stream.clone_htod(&q1).gpu_ctx("htod q1")?;
344 let qd1_d = stream.clone_htod(&qd1).gpu_ctx("htod qd1")?;
345 let g_d = stream.clone_htod(&g).gpu_ctx("htod g")?;
346 let wi_d = stream.clone_htod(&wi).gpu_ctx("htod wi")?;
347 let di_d = stream.clone_htod(&di).gpu_ctx("htod di")?;
348 let zs_d = stream.clone_htod(&zs).gpu_ctx("htod zsum")?;
349 let cov_d = stream.clone_htod(&cov).gpu_ctx("htod cov")?;
350 let dir_d = stream.clone_htod(&dir.to_vec()).gpu_ctx("htod dir")?;
351
352 let mut value_d = stream.alloc_zeros::<f64>(n).gpu_ctx("alloc value")?;
353 let mut grad_d = stream.alloc_zeros::<f64>(n * 4).gpu_ctx("alloc grad")?;
354 let mut hess_d = stream.alloc_zeros::<f64>(n * 16).gpu_ctx("alloc hess")?;
355 let mut third_d = stream.alloc_zeros::<f64>(n * 16).gpu_ctx("alloc third")?;
356 let mut fourth_d = stream.alloc_zeros::<f64>(n * 16).gpu_ctx("alloc fourth")?;
357
358 let n_i32 = i32::try_from(n)
359 .map_err(|_| gam_gpu::gpu_err!("survival_rowjet n={n} overflows i32"))?;
360 const TPB: u32 = 128;
361 let grid = ((n as u32).div_ceil(TPB)).max(1);
362 let cfg = LaunchConfig {
363 grid_dim: (grid, 1, 1),
364 block_dim: (TPB, 1, 1),
365 shared_mem_bytes: 0,
366 };
367 let mut builder = stream.launch_builder(&func);
368 builder
369 .arg(&n_i32)
370 .arg(&q0_d)
371 .arg(&q1_d)
372 .arg(&qd1_d)
373 .arg(&g_d)
374 .arg(&wi_d)
375 .arg(&di_d)
376 .arg(&zs_d)
377 .arg(&cov_d)
378 .arg(&probit_scale)
379 .arg(&dir_d);
380 let diru_d;
381 let dirv_d;
382 if need_fourth {
383 diru_d = stream.clone_htod(&dir_u.to_vec()).gpu_ctx("htod dir_u")?;
384 dirv_d = stream.clone_htod(&dir_v.to_vec()).gpu_ctx("htod dir_v")?;
385 builder.arg(&diru_d).arg(&dirv_d);
386 }
387 builder
388 .arg(&mut value_d)
389 .arg(&mut grad_d)
390 .arg(&mut hess_d)
391 .arg(&mut third_d)
392 .arg(&mut fourth_d);
393 unsafe { builder.launch(cfg) }.gpu_ctx("survival_rowjet kernel launch")?;
398
399 let mut value = vec![0.0_f64; n];
400 let mut grad = vec![0.0_f64; n * 4];
401 let mut hess = vec![0.0_f64; n * 16];
402 let mut third = vec![0.0_f64; n * 16];
403 let mut fourth = vec![0.0_f64; n * 16];
404 stream
405 .memcpy_dtoh(&value_d, &mut value)
406 .gpu_ctx("dtoh value")?;
407 stream
408 .memcpy_dtoh(&grad_d, &mut grad)
409 .gpu_ctx("dtoh grad")?;
410 stream
411 .memcpy_dtoh(&hess_d, &mut hess)
412 .gpu_ctx("dtoh hess")?;
413 stream
414 .memcpy_dtoh(&third_d, &mut third)
415 .gpu_ctx("dtoh third")?;
416 stream
417 .memcpy_dtoh(&fourth_d, &mut fourth)
418 .gpu_ctx("dtoh fourth")?;
419 stream
420 .synchronize()
421 .gpu_ctx("survival_rowjet synchronize")?;
422
423 Ok(SurvivalRowJetChannels {
424 n_rows: n,
425 value,
426 grad,
427 hess,
428 third,
429 fourth,
430 })
431 }
432}
433
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic inputs whose primaries vary smoothly with the row index.
    /// `wi` is always 1.0, and `di` is 1.0 on every third row and 0.0 elsewhere.
    fn fixture(n: usize) -> Vec<SurvivalRowInputs> {
        (0..n)
            .map(|i| {
                let t = i as f64 / n as f64;
                SurvivalRowInputs {
                    primaries: [
                        -2.5 + 5.0 * (12.0 * t).sin(),
                        -1.5 + 4.0 * (9.0 * t + 0.3).cos(),
                        0.2 + 1.8 * (0.5 + 0.5 * (7.0 * t).sin()),
                        -1.0 + 2.0 * (5.0 * t + 1.1).sin(),
                    ],
                    wi: 1.0,
                    di: if i % 3 == 0 { 1.0 } else { 0.0 },
                    z_sum: 0.5 * (3.0 * t).cos(),
                    cov_ones: 0.4 + 0.3 * (0.5 + 0.5 * (2.0 * t).sin()),
                }
            })
            .collect()
    }

    /// Builds the kernel inputs for one row exactly as `survival_rigid_row_jets_cpu` does.
    fn row_context(row: usize, inp: &SurvivalRowInputs, probit_scale: f64) -> RigidRowInputs {
        RigidRowInputs {
            row,
            wi: inp.wi,
            di: inp.di,
            z_sum: inp.z_sum,
            covariance_ones: inp.cov_ones,
            probit_scale,
            qd1_lower: f64::NEG_INFINITY,
        }
    }

    const PROBIT_SCALE: f64 = 0.7;
    const DIR: [f64; 4] = [0.31, -0.22, 0.17, 0.44];
    const DIRU: [f64; 4] = [0.13, 0.27, -0.41, 0.05];
    const DIRV: [f64; 4] = [-0.19, 0.33, 0.08, 0.22];

    /// The sparse order-2 pass must agree with a dense `Order2` evaluation of the same kernel.
    #[test]
    fn cpu_channels_match_unified_rowkernel() {
        use crate::survival::marginal_slope::row_kernel::rigid_row_nll;
        use gam_math::jet_scalar::{JetScalar, Order2};
        let rows = fixture(7);
        let out = survival_rigid_row_jets_cpu(&rows, PROBIT_SCALE, &DIR, &DIRU, &DIRV);
        assert_eq!(out.n_rows, rows.len());
        for (row, inp) in rows.iter().enumerate() {
            let in_row = row_context(row, inp, PROBIT_SCALE);
            let vars: [Order2<4>; 4] =
                std::array::from_fn(|a| Order2::variable(inp.primaries[a], a));
            let dense = rigid_row_nll(&vars, &in_row).expect("dense order2");
            assert!(
                (dense.value() - out.value[row]).abs() <= 1e-12,
                "value mismatch row {row}: dense={} sparse={}",
                dense.value(),
                out.value[row]
            );
            let g = dense.g();
            let h = dense.h();
            for a in 0..4 {
                assert!(
                    (g[a] - out.grad[row * 4 + a]).abs() <= 1e-12,
                    "grad mismatch row {row} {a}"
                );
                for b in 0..4 {
                    assert!(
                        (h[a][b] - out.hess[row * 16 + a * 4 + b]).abs() <= 1e-12,
                        "hess mismatch row {row} {a},{b}"
                    );
                }
            }
        }
    }

    /// The seeded third/fourth channels must agree with contractions of a dense `Tower4`.
    #[test]
    fn cpu_third_fourth_match_dense_tower_oracle() {
        use crate::survival::marginal_slope::row_kernel::rigid_row_nll;
        use gam_math::jet_tower::Tower4;
        let rows = fixture(9);
        let out = survival_rigid_row_jets_cpu(&rows, PROBIT_SCALE, &DIR, &DIRU, &DIRV);
        for (row, inp) in rows.iter().enumerate() {
            let in_row = row_context(row, inp, PROBIT_SCALE);
            let vars: [Tower4<4>; 4] =
                std::array::from_fn(|a| Tower4::variable(inp.primaries[a], a));
            let tower = rigid_row_nll(&vars, &in_row).expect("dense tower4");
            let t3 = tower.third_contracted(&DIR);
            let t4 = tower.fourth_contracted(&DIRU, &DIRV);
            for a in 0..4 {
                for b in 0..4 {
                    let k = row * 16 + a * 4 + b;
                    assert!(
                        (t3[a][b] - out.third[k]).abs() <= 1e-12,
                        "third mismatch row {row} {a},{b}: tensor={} seeded={}",
                        t3[a][b],
                        out.third[k]
                    );
                    assert!(
                        (t4[a][b] - out.fourth[k]).abs() <= 1e-12,
                        "fourth mismatch row {row} {a},{b}: tensor={} seeded={}",
                        t4[a][b],
                        out.fourth[k]
                    );
                }
            }
        }
    }

    /// Device output (when a device is available) must match the CPU reference.
    ///
    /// If no device is available, `survival_rigid_row_jets` falls back to the CPU and the
    /// comparison is trivially exact.
    #[cfg(target_os = "linux")]
    #[test]
    fn device_matches_cpu_when_available() {
        /// Absolute difference in which a NaN or infinite mismatch counts as an infinite error.
        /// Plain `(x - y).abs()` folded with `f64::max` would silently ignore NaN.
        fn abs_diff(x: f64, y: f64) -> f64 {
            if x == y || (x.is_nan() && y.is_nan()) {
                return 0.0;
            }
            let d = (x - y).abs();
            if d.is_nan() { f64::INFINITY } else { d }
        }

        let rows = fixture(DEVICE_ROW_THRESHOLD + 1024);
        let cpu = survival_rigid_row_jets_cpu(&rows, PROBIT_SCALE, &DIR, &DIRU, &DIRV);
        let got = survival_rigid_row_jets(&rows, PROBIT_SCALE, &DIR, &DIRU, &DIRV);
        assert_eq!(got.n_rows, cpu.n_rows, "row count mismatch");
        let mut maxabs = 0.0_f64;
        for (name, expected, actual) in [
            ("value", &cpu.value, &got.value),
            ("grad", &cpu.grad, &got.grad),
            ("hess", &cpu.hess, &got.hess),
            ("third", &cpu.third, &got.third),
            ("fourth", &cpu.fourth, &got.fourth),
        ] {
            assert_eq!(expected.len(), actual.len(), "{name} channel length mismatch");
            for (&x, &y) in expected.iter().zip(actual.iter()) {
                maxabs = maxabs.max(abs_diff(x, y));
            }
        }
        assert!(
            maxabs <= 1e-9,
            "survival device vs CPU row-jet max abs diff {maxabs} > 1e-9"
        );
    }
}