gam_models/gpu_kernels/cubic_cell/mod.rs
1//! GPU substrate for de-nested cubic-cell **derivative moments**.
2//!
3//! This module is the shared GPU evaluator for the de-nested cubic transport
4//! kernel that currently lives in `src/families/cubic_cell_kernel.rs`. For
5//! each partition cell `(left, right, c_0, c_1, c_2, c_3)` it computes the
6//! derivative-moment vector
7//!
8//! ```text
9//! M_k = ∫_{left}^{right} z^k · exp(-q(z)) dz, k = 0..=max_degree,
10//! q(z) = 0.5 · (z² + η(z)²),
11//! η(z) = c_0 + c_1·z + c_2·z² + c_3·z³.
12//! ```
13//!
//! Three branches feed into the same substrate entry point (today only the
//! non-affine finite branch is evaluated on the device; see [`device`]):
15//!
16//! * **Affine** (`c_2 = c_3 = 0`, finite interval): closed-form via the
17//! `T_n(a,b)` recurrence used by `affine_anchor_moment_vector_into`.
18//! * **Non-affine finite**: fixed 384-point Gauss–Legendre on the cell.
19//! * **Affine tail**: closed-form on a semi-infinite (or whole-line) interval.
20//!
21//! This is **distinct** from `src/gpu/cubic_bspline_moments.rs`, which
22//! computes tensor B-spline cell moments. The two modules share neither math
23//! nor data layout: do not conflate them.
24//!
25//! ## Layout
26//!
//! * [`branch`] — host-side branch classifier; mirrors
//!   `cubic_cell_kernel::branch_cell` and the semi-infinite tail logic of
//!   `evaluate_cell_state_dispatched`.
30//! * [`host_substrate`] — CPU-resident implementation. Works on every
31//! platform and is the parity reference for the device kernel.
32//! * [`kernel_src`] — NVRTC-compilable CUDA C++ source as Rust string
33//! constants (D9 / D15 / D21 specializations).
//! * [`device`] — Linux+CUDA dispatcher that compiles and launches the
//!   NVRTC kernel for the NonAffineFinite bucket and gathers its results;
//!   the Affine / AffineTail buckets stay on CPU until Stage-2.
37
38pub(crate) mod branch;
39pub(crate) mod device;
40pub(crate) mod host_substrate;
41pub(crate) mod kernel_src;
42
43use gam_gpu::gpu_error::GpuError;
44
45pub(crate) use host_substrate::build_host_cell_status;
46
/// Highest derivative-moment degree `k` the substrate will evaluate.
///
/// Requests with a larger `max_degree` are rejected up front. The bound is
/// the largest high-water mark among the current consumers:
///
/// | consumer                                      | degree |
/// |-----------------------------------------------|--------|
/// | Bernoulli flex Hessian                        | 9      |
/// | BMS outer higher-derivative reuse             | 21     |
/// | Survival flex Hessian (with `D_uv` cross terms) | 24   |
pub(crate) const MAX_SUPPORTED_DEGREE: usize = 24;
54
/// A single de-nested cubic-cell payload in the layout the device kernels
/// consume. Matches the CPU layout in `cubic_cell_kernel.rs`: the cubic
/// correction `η(z) = c_0 + c_1·z + c_2·z² + c_3·z³` evaluated over
/// `[left, right]`.
///
/// `#[repr(C)]` pins the field order to declaration order: six consecutive
/// `f64`s with no padding. The layout promised above therefore holds no
/// matter how the struct is copied into device-facing buffers. The default
/// Rust representation makes no such guarantee.
#[repr(C)]
#[derive(Clone, Copy, Debug, PartialEq)]
pub(crate) struct GpuDenestedCubicCell {
    /// Lower endpoint of the cell interval.
    pub left: f64,
    /// Upper endpoint of the cell interval.
    pub right: f64,
    /// Constant coefficient `c_0` of `η`.
    pub c0: f64,
    /// Linear coefficient `c_1` of `η`.
    pub c1: f64,
    /// Quadratic coefficient `c_2` of `η` (zero on affine cells).
    pub c2: f64,
    /// Cubic coefficient `c_3` of `η` (zero on affine cells).
    pub c3: f64,
}
68
/// Which evaluation branch a cell belongs to.
///
/// The device dispatcher groups cells by this tag and launches one
/// specialized kernel per group, so threads in a warp never diverge on the
/// branch decision.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum GpuCellBranchTag {
    /// Finite interval with `c_2 = c_3 = 0`; evaluated in closed form with the
    /// `T_n` recurrence at the affine anchor.
    Affine,
    /// Finite interval where `c_2` or `c_3` is non-zero; evaluated with a
    /// fixed 384-point Gauss–Legendre rule over the cell.
    NonAffineFinite,
    /// Semi-infinite or whole-line interval with `c_2 = c_3 = 0`; evaluated in
    /// closed form over the tail.
    AffineTail,
}
84
/// Where the caller wants the substrate's results materialized.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum CubicCellMomentResidency {
    /// Host (CPU) path. Available on every platform. The substrate returns
    /// only the per-cell classifier status codes for this residency; it
    /// does not return a moment matrix.
    Host,
    /// Request device-resident moments, delivered as a `CudaSlice<f64>` on
    /// the shared cubic-cell context.
    ///
    /// This variant exists only on Linux. If no usable CUDA runtime is
    /// available, the request degrades to a `Host`-shaped output produced by
    /// the host substrate. The substrate never claims a device result it
    /// does not have.
    #[cfg(target_os = "linux")]
    Device,
}
98
/// Status byte recorded for each cell by the substrate.
///
/// The discriminants are part of the kernel ABI. The device kernel emits the
/// same numeric codes, so the host and GPU paths produce byte-identical
/// `Vec<u8>` status vectors.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum CubicCellMomentStatus {
    /// The cell was accepted and evaluated.
    Ok = 0,
    /// The cell was rejected for one of these reasons:
    /// * it is a finite cell with `right <= left`;
    /// * the caller's branch tag disagrees with the classifier;
    /// * the CPU classifier refused the cell.
    InvalidInterval = 1,
    /// Semi-infinite cell whose `c_2` or `c_3` is material.
    NonAffineInfiniteInterval = 2,
    /// At least one coefficient among `c_0..c_3` is NaN or infinite.
    NonFiniteCoefficient = 3,
    /// Evaluation yielded a non-finite moment, e.g. `q` overflowed on a
    /// pathological cell. The row is zeroed. This is the GPU-side
    /// counterpart of a CPU `Err`.
    NonFiniteEvaluation = 4,
}
118
119/// Host-side input view for `try_build_cubic_cell_derivative_moments`.
120/// The substrate borrows cell data from the caller; it does not own the
121/// CPU partition. `branches` is parallel to `cells`.
122pub(crate) struct CubicCellDerivativeMomentHostView<'a> {
123 pub cells: &'a [GpuDenestedCubicCell],
124 pub branches: &'a [GpuCellBranchTag],
125 pub max_degree: usize,
126 pub residency: CubicCellMomentResidency,
127}
128
129/// Output of `try_build_cubic_cell_derivative_moments`.
130#[derive(Debug)]
131pub(crate) enum CubicCellDerivativeMomentOutput {
132 /// Per-cell substrate status codes from the host (CPU) path. Production
133 /// callers only consume the per-cell classifier verdicts here; the actual
134 /// moments stay on the upstream `HostMomentBatch` and are accessed via
135 /// `build_host_moments` directly when needed (e.g. parity tests). The
136 /// Device variant below carries the moments because the device kernel
137 /// hands them back as a residency-bound buffer.
138 Host { status: Vec<u8> },
139 /// Device-resident moments on the cubic-cell backend's shared CUDA
140 /// context. Linux-only — non-Linux callers see the `Host` variant even
141 /// when they request `Device` residency. Layout matches `Host` so
142 /// `d_moments` is a row-major `[n_cells, stride]` `CudaSlice<f64>`. The
143 /// host-side `status` vector mirrors the per-cell device status so
144 /// downstream branching decisions never have to round-trip from the
145 /// device.
146 #[cfg(target_os = "linux")]
147 Device {
148 d_moments: cudarc::driver::CudaSlice<f64>,
149 status: Vec<u8>,
150 stride: usize,
151 n_cells: usize,
152 },
153}
154
155/// Try to build derivative moments via the substrate.
156///
157/// * `Host` residency: routes through the CPU classifier and returns
158/// per-cell status only. Production consumers read the verdict and feed
159/// moments from their own evaluator (LRU cache for BMS row-primary
160/// Hessian, dedicated host buffer for survival-flex). The full moment
161/// matrix used to be returned here as well, but no production caller
162/// ever read it — the parity reference path that compares CPU moments to
163/// the device kernel now lives next to the host substrate's own unit
164/// tests.
165/// * `Device` residency: on Linux+CUDA with a probed runtime, the device
166/// dispatcher launches the NVRTC kernel for the NonAffineFinite bucket
167/// and CPU-evaluates the Affine/AffineTail buckets, packing both back
168/// into a `Device { … }` output for the caller. When the runtime is
169/// unavailable the caller receives a `Host { status }` output instead —
170/// no silent device claim.
171///
172/// Returns `Ok(None)` only when the workload is empty.
173///
174pub(crate) fn try_build_cubic_cell_derivative_moments(
175 input: CubicCellDerivativeMomentHostView<'_>,
176) -> Result<Option<CubicCellDerivativeMomentOutput>, GpuError> {
177 if input.cells.len() != input.branches.len() {
178 gam_gpu::gpu_bail!(
179 "gpu cubic-cell substrate: cells.len()={} != branches.len()={}",
180 input.cells.len(),
181 input.branches.len()
182 );
183 }
184 if input.max_degree > MAX_SUPPORTED_DEGREE {
185 gam_gpu::gpu_bail!(
186 "gpu cubic-cell substrate: max_degree={} exceeds MAX_SUPPORTED_DEGREE={}",
187 input.max_degree,
188 MAX_SUPPORTED_DEGREE
189 );
190 }
191 if input.cells.is_empty() {
192 return Ok(None);
193 }
194
195 match input.residency {
196 CubicCellMomentResidency::Host => {
197 let status = build_host_cell_status(&input)
198 .map_err(|reason| GpuError::DriverCallFailed { reason })?;
199 Ok(Some(CubicCellDerivativeMomentOutput::Host { status }))
200 }
201 #[cfg(target_os = "linux")]
202 CubicCellMomentResidency::Device => {
203 if let Some(device_batch) = device::try_device_moments_resident(&input)? {
204 return Ok(Some(device_batch));
205 }
206 // Non-Linux, or no usable runtime: fall back to the host shape so
207 // the caller has a parity-shaped result instead of a phantom
208 // device claim.
209 let status = build_host_cell_status(&input)
210 .map_err(|reason| GpuError::DriverCallFailed { reason })?;
211 Ok(Some(CubicCellDerivativeMomentOutput::Host { status }))
212 }
213 }
214}
215
#[cfg(test)]
mod tests {
    use super::*;

    /// Finite affine cell on `[-1, 1]` with every coefficient zero.
    fn unit_affine_cell() -> GpuDenestedCubicCell {
        GpuDenestedCubicCell {
            left: -1.0,
            right: 1.0,
            c0: 0.0,
            c1: 0.0,
            c2: 0.0,
            c3: 0.0,
        }
    }

    /// Build a `Host`-residency request over the given cells and tags.
    fn host_request<'a>(
        cells: &'a [GpuDenestedCubicCell],
        branches: &'a [GpuCellBranchTag],
        max_degree: usize,
    ) -> CubicCellDerivativeMomentHostView<'a> {
        CubicCellDerivativeMomentHostView {
            residency: CubicCellMomentResidency::Host,
            cells,
            branches,
            max_degree,
        }
    }

    #[test]
    fn host_residency_returns_ok_status_for_valid_cell() {
        // Host residency exposes only per-cell status codes. Production
        // callers (the BMS row-primary Hessian assembler and the
        // survival-flex row evaluator) consume the classifier verdict, not
        // the moments. The moment-emitting parity path is tested next to the
        // host substrate; this test pins the public contract that a valid
        // cell yields an `Ok` status.
        let cells = [unit_affine_cell()];
        let branches = [GpuCellBranchTag::Affine];
        let request = host_request(&cells, &branches, 9);

        let output = try_build_cubic_cell_derivative_moments(request)
            .expect("host substrate succeeds on a valid cell")
            .expect("non-empty input produces output");

        let status = match output {
            CubicCellDerivativeMomentOutput::Host { status } => status,
            #[cfg(target_os = "linux")]
            CubicCellDerivativeMomentOutput::Device { .. } => {
                panic!("host residency request must not yield Device output")
            }
        };
        assert_eq!(status, [CubicCellMomentStatus::Ok as u8]);
    }

    #[test]
    fn empty_input_returns_ok_none() {
        let result = try_build_cubic_cell_derivative_moments(host_request(&[], &[], 9));
        assert!(result.expect("ok").is_none());
    }

    #[test]
    fn rejects_mismatched_lengths() {
        let cells = [unit_affine_cell()];
        let no_branches: [GpuCellBranchTag; 0] = [];

        let message =
            try_build_cubic_cell_derivative_moments(host_request(&cells, &no_branches, 9))
                .unwrap_err()
                .to_string();

        for needle in ["cells.len()", "branches.len()"] {
            assert!(message.contains(needle), "got: {message}");
        }
    }

    #[test]
    fn rejects_degree_above_supported_max() {
        let cells = [unit_affine_cell()];
        let branches = [GpuCellBranchTag::Affine];
        let too_high = MAX_SUPPORTED_DEGREE + 1;

        let message =
            try_build_cubic_cell_derivative_moments(host_request(&cells, &branches, too_high))
                .unwrap_err()
                .to_string();

        assert!(message.contains("MAX_SUPPORTED_DEGREE"));
    }

    #[test]
    fn status_codes_match_kernel_abi() {
        // Each variant's discriminant must equal the byte the device kernel emits.
        let abi = [
            (CubicCellMomentStatus::Ok, 0u8),
            (CubicCellMomentStatus::InvalidInterval, 1),
            (CubicCellMomentStatus::NonAffineInfiniteInterval, 2),
            (CubicCellMomentStatus::NonFiniteCoefficient, 3),
            (CubicCellMomentStatus::NonFiniteEvaluation, 4),
        ];
        for (status, code) in abi {
            assert_eq!(status as u8, code);
        }
    }

    // Phase 4 device-residency parity test lives next to the device backend
    // at `crate::gpu_kernels::cubic_cell::device::tests::cubic_cell_device_residency_matches_cpu_all_branches`
    // so it can use the in-mod `download_moments` helper directly.
}