1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
//! Cached piecewise-cubic partition build.
//!
//! Builds the cells + per-cell moment states + fixed partials once per
//! `(a, b, β_h, β_w)` so the first-order / full / directional / bidirectional
//! integration passes (F, D, D_uv) all share one partition. The cell-table and
//! per-cell fixed-partials assemblies route through the GPU-shaped `try_device_*`
//! seams, falling back to the CPU implementation on decline.
use super::*;
impl SurvivalMarginalSlopeFamily {
    /// Build a cached partition: partition cells, per-cell moment states and
    /// per-cell fixed partials, computed once per `(a, b, β_h, β_w)` and reused
    /// across the integration passes (F, D, D_uv).
    ///
    /// The cell-table assembly and the per-cell primary-fixed-partials assembly
    /// are first offered to the GPU-shaped `try_device_*` seams in
    /// [`crate::survival::marginal_slope::gpu_prep`]. A seam "declines" by
    /// returning `Ok(None)` or a result whose shape does not match this single
    /// row (exactly one row; for the fixed partials, exactly `n` cells). On a
    /// decline the CPU implementation is used instead.
    /// NOTE(review): whether either seam currently produces results depends on
    /// the `gpu_prep` implementation, which is not visible from this file.
    ///
    /// The partition-cell seam takes `β_h`/`β_w` as plain slices. If either
    /// vector is present but not contiguous in memory (no slice view), the
    /// seam is skipped rather than handed `None`. Handing it `None` would
    /// describe a model without that vector.
    ///
    /// # Arguments
    /// * `primary`: primary-slice layout. `primary.total` and `primary.g` are
    ///   forwarded to the device layout and must fit in `u32`. `primary.total`
    ///   is also the flat width used when unpacking device partials.
    /// * `a`, `b`: forwarded to the partition builders. Each cell's probe
    ///   point `z_mid` is mapped to `u_mid = a + b * z_mid`.
    /// * `beta_h`, `beta_w`: optional coefficient vectors forwarded to the
    ///   partition builders.
    /// * `moment_order`: order passed to `exact_kernel::evaluate_cell_moments`
    ///   for every cell.
    ///
    /// # Errors
    /// Returns `Err(String)` if a device seam reports an error, if
    /// `primary.total` or `primary.g` exceeds `u32::MAX`, or if any per-cell
    /// step fails. The per-cell steps are probe point, moment evaluation,
    /// fixed-partial assembly and device-buffer unpacking. A device-buffer
    /// unpacking error is propagated; it does not fall back to the CPU path.
    pub(crate) fn build_cached_partition_with_moment_order(
        &self,
        primary: &FlexPrimarySlices,
        a: f64,
        b: f64,
        beta_h: Option<&Array1<f64>>,
        beta_w: Option<&Array1<f64>>,
        moment_order: usize,
    ) -> Result<CachedPartitionCells, String> {
        use crate::survival::marginal_slope::gpu_prep;

        // ── 1. partition cells via the device seam, CPU fallback on decline ──
        let beta_h_view = beta_h.map(|beta| beta.as_slice());
        let beta_w_view = beta_w.map(|beta| beta.as_slice());
        let device_cells = match (beta_h_view, beta_w_view) {
            // A present-but-non-contiguous vector has no slice view. Forwarding
            // `None` would silently drop it, so use the CPU path directly.
            (Some(None), _) | (_, Some(None)) => None,
            (h_view, w_view) => {
                let row_input = gpu_prep::PartitionCellsRowInputs {
                    a,
                    b,
                    beta_h: h_view.flatten(),
                    beta_w: w_view.flatten(),
                };
                gpu_prep::try_device_partition_cells(std::slice::from_ref(&row_input))
                    .map_err(|e| e.to_string())?
                    // We submitted exactly one row; any other shape is a decline.
                    .filter(|by_row| by_row.len() == 1)
                    .and_then(|mut by_row| by_row.pop())
            }
        };
        let raw_cells = match device_cells {
            Some(cells) => cells,
            None => self.denested_partition_cells(a, b, beta_h, beta_w)?,
        };

        // ── 2. per-cell prelude (z_mid, u_mid, moment state) ──
        let n = raw_cells.len();
        // (z_mid, u_mid) per cell; only the CPU fixed-partials path consumes it.
        let mut midpoints = Vec::with_capacity(n);
        let mut states = Vec::with_capacity(n);
        let mut fp_inputs = Vec::with_capacity(n);
        for partition_cell in &raw_cells {
            let cell = partition_cell.cell;
            let z_mid = exact_kernel::interval_probe_point(cell.left, cell.right)?;
            let state = exact_kernel::evaluate_cell_moments(cell, moment_order)?;
            midpoints.push((z_mid, a + b * z_mid));
            states.push(state);
            fp_inputs.push(gpu_prep::CellPrimaryFixedPartialsCellInputs {
                score_span: partition_cell.score_span,
                link_span: partition_cell.link_span,
            });
        }

        // ── 3. per-cell fixed partials via the device seam, CPU fallback ──
        let layout = gpu_prep::FlexPrimaryLayout {
            r: u32::try_from(primary.total).map_err(|_| {
                format!(
                    "build_cached_partition_with_moment_order: primary.total={} exceeds u32",
                    primary.total
                )
            })?,
            g_slot: u32::try_from(primary.g).map_err(|_| {
                format!(
                    "build_cached_partition_with_moment_order: primary.g={} exceeds u32",
                    primary.g
                )
            })?,
        };
        let row_fp_input = gpu_prep::CellPrimaryFixedPartialsRowInputs {
            cells: &fp_inputs,
            layout,
        };
        let dev_fixed =
            gpu_prep::try_device_cell_primary_fixed_partials(std::slice::from_ref(&row_fp_input))
                .map_err(|e| e.to_string())?;

        // Moment states are moved (not cloned) into the cache entries; the
        // `states` vector is not needed after this point.
        let mut cells = Vec::with_capacity(n);
        if let Some(out) = dev_fixed.as_ref()
            && out.partials.len() == 1
            && out.partials[0].len() == n
        {
            // Device path: rebuild each cell's `DenestedCellPrimaryFixedPartials`
            // from its flat-packed buffer of width `primary.total`.
            let device_row = &out.partials[0];
            for (idx, (partition_cell, state)) in raw_cells.into_iter().zip(states).enumerate() {
                let fixed = DenestedCellPrimaryFixedPartials::from_flat_slice(
                    device_row[idx].as_slice(),
                    primary.total,
                )?;
                cells.push(CachedCellEntry {
                    partition_cell,
                    state,
                    fixed,
                });
            }
        } else {
            // CPU path: the seam declined (`None` or an unexpected shape).
            for ((partition_cell, state), (z_mid, u_mid)) in
                raw_cells.into_iter().zip(states).zip(midpoints)
            {
                let fixed = self.denested_cell_primary_fixed_partials(
                    primary,
                    a,
                    b,
                    partition_cell.score_span,
                    partition_cell.link_span,
                    z_mid,
                    u_mid,
                )?;
                cells.push(CachedCellEntry {
                    partition_cell,
                    state,
                    fixed,
                });
            }
        }
        Ok(CachedPartitionCells { cells })
    }

    /// Build a cached partition at the moment order required by the full
    /// `e^{−Δq}` moment closure.
    ///
    /// This is a convenience wrapper around
    /// `build_cached_partition_with_moment_order` with the order fixed at 32.
    ///
    /// # Errors
    /// Propagates every error of `build_cached_partition_with_moment_order`.
    pub(crate) fn build_cached_partition(
        &self,
        primary: &FlexPrimarySlices,
        a: f64,
        b: f64,
        beta_h: Option<&Array1<f64>>,
        beta_w: Option<&Array1<f64>>,
    ) -> Result<CachedPartitionCells, String> {
        // The flex moment closure expands `S(z)=e^{−Δq}=Σ_{k≤4}(−Δq)^k/k!`.
        // With η cubic, `−Δq=½(η²−η₀²)` is degree 6 in z, so the fourth-order
        // `(−Δq)⁴` term reaches `M_{n+24}` (= `M_28` for the `n≤4` base moments).
        // A moment order of 27 silently dropped `M_28`, truncating the
        // contracted-fourth (Jet4) channel. 32 (≥ 28) keeps every order of the
        // e^{−Δq} closure's moment dot complete.
        const FULL_CLOSURE_MOMENT_ORDER: usize = 32;
        self.build_cached_partition_with_moment_order(
            primary,
            a,
            b,
            beta_h,
            beta_w,
            FULL_CLOSURE_MOMENT_ORDER,
        )
    }
}