use super::*;
impl SurvivalMarginalSlopeFamily {
    /// Builds the cached partition cells for the affine map `u = a + b * z`,
    /// evaluating cell moments with the given `moment_order`.
    ///
    /// The work happens in three stages:
    ///
    /// 1. Obtain the raw partition cells, trying the device helper
    ///    (`gpu_prep::try_device_partition_cells`) first and falling back to
    ///    `self.denested_partition_cells` when the helper returns `None` or a
    ///    result that does not contain exactly one row.
    /// 2. For every cell, compute its probe point, the mapped probe
    ///    `a + b * z_mid`, and its moment state.
    /// 3. Compute the per-cell primary fixed partials, again trying the device
    ///    helper first and falling back to
    ///    `self.denested_cell_primary_fixed_partials` when the device output is
    ///    missing or does not have exactly one row with one entry per cell.
    ///
    /// # Errors
    ///
    /// Returns the stringified error of either device helper, any error raised
    /// by the host fallbacks or by `exact_kernel`, and a descriptive message
    /// when `primary.total` or `primary.g` does not fit in a `u32`.
    pub(crate) fn build_cached_partition_with_moment_order(
        &self,
        primary: &FlexPrimarySlices,
        a: f64,
        b: f64,
        beta_h: Option<&Array1<f64>>,
        beta_w: Option<&Array1<f64>>,
        moment_order: usize,
    ) -> Result<CachedPartitionCells, String> {
        use crate::survival::marginal_slope::gpu_prep;

        let h_slice = beta_h.and_then(|coeffs| coeffs.as_slice());
        let w_slice = beta_w.and_then(|coeffs| coeffs.as_slice());
        // `as_slice` yields `None` when the array has no contiguous slice view.
        // Forwarding that `None` would make the device request look as if the
        // coefficients had not been supplied at all, so in that case the
        // device attempt is skipped and the host path (which takes the arrays
        // directly) is used instead.
        let slices_faithful =
            beta_h.is_some() == h_slice.is_some() && beta_w.is_some() == w_slice.is_some();

        let device_cells = if slices_faithful {
            let row_input = gpu_prep::PartitionCellsRowInputs {
                a,
                b,
                beta_h: h_slice,
                beta_w: w_slice,
            };
            gpu_prep::try_device_partition_cells(std::slice::from_ref(&row_input))
                .map_err(|e| e.to_string())?
        } else {
            None
        };

        // Exactly one row was submitted, so only a one-row answer is accepted.
        let raw_cells = match device_cells {
            Some(mut by_row) if by_row.len() == 1 => by_row.remove(0),
            _ => self.denested_partition_cells(a, b, beta_h, beta_w)?,
        };

        let n = raw_cells.len();
        // `(z_mid, u_mid)` per cell; only read by the host fallback below.
        let mut probes = Vec::with_capacity(n);
        let mut states = Vec::with_capacity(n);
        let mut fp_inputs =
            Vec::<gpu_prep::CellPrimaryFixedPartialsCellInputs>::with_capacity(n);
        for partition_cell in &raw_cells {
            let cell = partition_cell.cell;
            let z_mid = exact_kernel::interval_probe_point(cell.left, cell.right)?;
            let state = exact_kernel::evaluate_cell_moments(cell, moment_order)?;
            probes.push((z_mid, a + b * z_mid));
            states.push(state);
            fp_inputs.push(gpu_prep::CellPrimaryFixedPartialsCellInputs {
                score_span: partition_cell.score_span,
                link_span: partition_cell.link_span,
            });
        }

        let layout = gpu_prep::FlexPrimaryLayout {
            r: u32::try_from(primary.total).map_err(|_| {
                format!(
                    "build_cached_partition_with_moment_order: primary.total={} exceeds u32",
                    primary.total
                )
            })?,
            g_slot: u32::try_from(primary.g).map_err(|_| {
                format!(
                    "build_cached_partition_with_moment_order: primary.g={} exceeds u32",
                    primary.g
                )
            })?,
        };
        let row_fp_input = gpu_prep::CellPrimaryFixedPartialsRowInputs {
            cells: &fp_inputs,
            layout,
        };
        let dev_fixed = gpu_prep::try_device_cell_primary_fixed_partials(std::slice::from_ref(
            &row_fp_input,
        ))
        .map_err(|e| e.to_string())?;

        // Device path: accepted only when the output has exactly one row with
        // one entry per cell; anything else falls through to the host path.
        if let Some(out) = dev_fixed.as_ref()
            && out.partials.len() == 1
            && out.partials[0].len() == n
        {
            let mut cells = Vec::with_capacity(n);
            // Cells and states are moved into the cache entries, not cloned.
            for (idx, (partition_cell, state)) in
                raw_cells.into_iter().zip(states).enumerate()
            {
                let flat = &out.partials[0][idx];
                let fixed = DenestedCellPrimaryFixedPartials::from_flat_slice(
                    flat.as_slice(),
                    primary.total,
                )?;
                cells.push(CachedCellEntry {
                    partition_cell,
                    state,
                    fixed,
                });
            }
            return Ok(CachedPartitionCells { cells });
        }

        // Host fallback: evaluate the fixed partials cell by cell.
        let mut cells = Vec::with_capacity(n);
        for ((partition_cell, state), (z_mid, u_mid)) in
            raw_cells.into_iter().zip(states).zip(probes)
        {
            let fixed = self.denested_cell_primary_fixed_partials(
                primary,
                a,
                b,
                partition_cell.score_span,
                partition_cell.link_span,
                z_mid,
                u_mid,
            )?;
            cells.push(CachedCellEntry {
                partition_cell,
                state,
                fixed,
            });
        }
        Ok(CachedPartitionCells { cells })
    }

    /// Builds the cached partition cells with the default moment order.
    ///
    /// Delegates to [`Self::build_cached_partition_with_moment_order`] and
    /// shares its error behaviour.
    pub(crate) fn build_cached_partition(
        &self,
        primary: &FlexPrimarySlices,
        a: f64,
        b: f64,
        beta_h: Option<&Array1<f64>>,
        beta_w: Option<&Array1<f64>>,
    ) -> Result<CachedPartitionCells, String> {
        /// Moment order used when the caller does not choose one.
        const DEFAULT_MOMENT_ORDER: usize = 32;
        self.build_cached_partition_with_moment_order(
            primary,
            a,
            b,
            beta_h,
            beta_w,
            DEFAULT_MOMENT_ORDER,
        )
    }
}