use super::*;
use gam_math::jet_scalar::{JetScalar, OneSeed, TwoSeed};
/// Bitmask with bits 0, 1 and 2 set. It is used as the `LIN` parameter of
/// `SparseOrder2` in the rigid row kernel, where primary axes 0..=2 carry
/// `q0`, `q1` and `qd1` and axis 3 (the log-slope predictor) is left unset.
pub(crate) const RIGID_LINEAR_MASK: u32 = (1 << 0) | (1 << 1) | (1 << 2);
/// Reports whether bit `a` of `mask` is set, i.e. whether axis `a` is
/// declared linear by the mask.
#[inline(always)]
const fn axis_is_linear(mask: u32, a: usize) -> bool {
    (mask & (1u32 << a)) != 0
}
/// Second-order jet over four axes: a value, a gradient and a dense 4x4
/// Hessian. `LIN` is a bitmask of axes declared linear; the `mul` and
/// `compose_unary` implementations in this file skip the operands' Hessian
/// terms for pairs of linear axes and assert that those entries are zero.
#[derive(Clone, Copy, Debug)]
pub(crate) struct SparseOrder2<const LIN: u32> {
// Function value.
v: f64,
// First derivatives with respect to the four axes.
grad: [f64; 4],
// Second derivatives, indexed `[i][j]`.
hess: [[f64; 4]; 4],
}
impl<const LIN: u32> SparseOrder2<LIN> {
    /// Returns a copy of the gradient array.
    #[inline]
    pub(crate) fn g(&self) -> [f64; 4] {
        self.grad
    }

    /// Returns a copy of the Hessian array.
    #[inline]
    pub(crate) fn h(&self) -> [[f64; 4]; 4] {
        self.hess
    }

    /// Panics if any Hessian entry whose row and column axes are both marked
    /// linear in `LIN` is not exactly zero.
    #[inline(always)]
    fn check_contract(&self) {
        for (i, hess_row) in self.hess.iter().enumerate() {
            if !axis_is_linear(LIN, i) {
                continue;
            }
            for (j, &entry) in hess_row.iter().enumerate() {
                if !axis_is_linear(LIN, j) {
                    continue;
                }
                assert!(
                    entry == 0.0,
                    "static-sparsity contract violated: linear×linear Hessian h[{i}][{j}]={} != 0 (axes {i},{j} both declared linear but the program forms curvature between them)",
                    entry
                );
            }
        }
    }
}
impl<const LIN: u32> JetScalar<4> for SparseOrder2<LIN> {
    /// Jet of a constant: value `c`, zero gradient and zero Hessian.
    fn constant(c: f64) -> Self {
        Self {
            v: c,
            grad: [0.0; 4],
            hess: [[0.0; 4]; 4],
        }
    }

    /// Jet of the independent variable on `axis`: value `x`, unit gradient on
    /// that axis, zero Hessian. Panics if `axis >= 4`.
    fn variable(x: f64, axis: usize) -> Self {
        let mut seeded = Self::constant(x);
        seeded.grad[axis] = 1.0;
        seeded
    }

    /// Returns the value component.
    fn value(&self) -> f64 {
        self.v
    }

    /// Component-wise sum of two jets.
    fn add(&self, o: &Self) -> Self {
        Self {
            v: self.v + o.v,
            grad: std::array::from_fn(|i| self.grad[i] + o.grad[i]),
            hess: std::array::from_fn(|i| {
                std::array::from_fn(|j| self.hess[i][j] + o.hess[i][j])
            }),
        }
    }

    /// Difference, formed as `self + (-o)`.
    fn sub(&self, o: &Self) -> Self {
        let negated = o.neg();
        self.add(&negated)
    }

    /// Negation, formed as a scale by `-1.0`.
    fn neg(&self) -> Self {
        self.scale(-1.0)
    }

    /// Multiplies every component by the scalar `s`.
    fn scale(&self, s: f64) -> Self {
        Self {
            v: self.v * s,
            grad: std::array::from_fn(|i| self.grad[i] * s),
            hess: std::array::from_fn(|i| std::array::from_fn(|j| self.hess[i][j] * s)),
        }
    }

    /// Product rule to second order. Both operands are contract-checked
    /// first; for a pair of linear axes only the gradient outer-product term
    /// is formed and the operands' Hessian entries are not read.
    fn mul(&self, o: &Self) -> Self {
        self.check_contract();
        o.check_contract();
        let (lhs, rhs) = (self, o);
        let grad: [f64; 4] = std::array::from_fn(|i| lhs.v * rhs.grad[i] + lhs.grad[i] * rhs.v);
        let hess: [[f64; 4]; 4] = std::array::from_fn(|i| {
            std::array::from_fn(|j| {
                let outer = lhs.grad[i] * rhs.grad[j] + lhs.grad[j] * rhs.grad[i];
                if axis_is_linear(LIN, i) && axis_is_linear(LIN, j) {
                    outer
                } else {
                    outer + (lhs.v * rhs.hess[i][j] + lhs.hess[i][j] * rhs.v)
                }
            })
        });
        Self {
            v: lhs.v * rhs.v,
            grad,
            hess,
        }
    }

    /// Chain rule to second order for a unary function whose value, first and
    /// second derivatives are `d[0]`, `d[1]`, `d[2]`; `d[3]` and `d[4]` are
    /// not used here. The operand is contract-checked first, and its Hessian
    /// is not read for pairs of linear axes.
    fn compose_unary(&self, d: [f64; 5]) -> Self {
        self.check_contract();
        let [f0, f1, f2, ..] = d;
        let grad: [f64; 4] = std::array::from_fn(|i| f1 * self.grad[i]);
        let hess: [[f64; 4]; 4] = std::array::from_fn(|i| {
            std::array::from_fn(|j| {
                let outer = f2 * self.grad[i] * self.grad[j];
                if axis_is_linear(LIN, i) && axis_is_linear(LIN, j) {
                    outer
                } else {
                    outer + f1 * self.hess[i][j]
                }
            })
        });
        Self { v: f0, grad, hess }
    }
}
/// Row kernel for the rigid survival marginal-slope family: bundles the
/// family, the current per-block parameter states, and the coefficient
/// slices derived from them by `block_slices` in `new`.
pub(crate) struct SurvivalMarginalSlopeRowKernel {
// Model family providing designs, weights, events and helper queries.
pub(crate) family: SurvivalMarginalSlopeFamily,
// Parameter block states; index 2 supplies the fourth primary (`eta[row]`).
pub(crate) block_states: Vec<ParameterBlockState>,
// Coefficient ranges (`time`, `marginal`, `logslope`) and their `total`.
pub(crate) slices: BlockSlices,
}
impl SurvivalMarginalSlopeRowKernel {
    /// Builds a kernel, deriving the coefficient slices from `family` and
    /// `block_states` via `block_slices` before taking ownership of both.
    pub(crate) fn new(
        family: SurvivalMarginalSlopeFamily,
        block_states: Vec<ParameterBlockState>,
    ) -> Self {
        Self {
            slices: block_slices(&family, &block_states),
            family,
            block_states,
        }
    }
}
/// Returns the four primary values for `row`: `q0`, `q1` and `qd1` from
/// `family.row_dynamic_q_values`, followed by `block_states[2].eta[row]`.
///
/// # Errors
/// Propagates any error from `row_dynamic_q_values`.
pub(crate) fn rigid_row_kernel_primaries(
    family: &SurvivalMarginalSlopeFamily,
    block_states: &[ParameterBlockState],
    row: usize,
) -> Result<[f64; 4], String> {
    let geometry = family.row_dynamic_q_values(row, block_states)?;
    let slope_eta = block_states[2].eta[row];
    Ok([geometry.q0, geometry.q1, geometry.qd1, slope_eta])
}
/// Per-row scalar inputs consumed by `rigid_row_nll`, as gathered by
/// `rigid_row_inputs`.
pub(crate) struct RigidRowInputs {
// Row index; used in error messages.
pub(crate) row: usize,
// `family.weights[row]`.
pub(crate) wi: f64,
// `family.event[row]`.
pub(crate) di: f64,
// First component of `family.exact_shared_score_summary`.
pub(crate) z_sum: f64,
// Second component of `family.exact_shared_score_summary`.
pub(crate) covariance_ones: f64,
// `family.probit_frailty_scale()`.
pub(crate) probit_scale: f64,
// `family.time_derivative_lower_bound()`; the monotonicity guard threshold.
pub(crate) qd1_lower: f64,
}
/// Gathers the per-row scalars needed by `rigid_row_nll`.
///
/// `context` is forwarded to `exact_shared_score_summary`.
///
/// # Errors
/// Propagates any error from `exact_shared_score_summary`.
pub(crate) fn rigid_row_inputs(
    family: &SurvivalMarginalSlopeFamily,
    block_states: &[ParameterBlockState],
    row: usize,
    context: &str,
) -> Result<RigidRowInputs, String> {
    let (z_sum, covariance_ones) = family.exact_shared_score_summary(row, block_states, context)?;
    let wi = family.weights[row];
    let di = family.event[row];
    let probit_scale = family.probit_frailty_scale();
    let qd1_lower = family.time_derivative_lower_bound();
    Ok(RigidRowInputs {
        row,
        wi,
        di,
        z_sum,
        covariance_ones,
        probit_scale,
        qd1_lower,
    })
}
/// Evaluates one row's contribution as a jet of type `S`, so the same
/// program yields whatever derivative order `S` carries.
///
/// `vars` holds the four primaries in the order `q0`, `q1`, `qd1`, `g`
/// (the order produced by `rigid_row_kernel_primaries`).
///
/// # Errors
/// Returns a `MonotonicityViolation` when
/// `survival_derivative_guard_violated(qd1, qd1_lower)` holds, and a
/// `NumericalFailure` when a signed margin with non-zero weight is NaN or
/// negative infinity.
pub(crate) fn rigid_row_nll<S: JetScalar<4>>(
vars: &[S; 4],
inputs: &RigidRowInputs,
) -> Result<S, String> {
let RigidRowInputs {
row,
wi,
di,
z_sum,
covariance_ones,
probit_scale,
qd1_lower,
} = *inputs;
let q0 = &vars[0];
let q1 = &vars[1];
let qd1 = &vars[2];
let g = &vars[3];
// observed_g = probit_scale * g
let observed_g = g.scale(probit_scale);
// one_plus_b2 = covariance_ones * observed_g^2 + 1
let one_plus_b2 = observed_g
.mul(&observed_g)
.scale(covariance_ones)
.add(&S::constant(1.0));
// c = one_plus_b2 pushed through the derivative stack from
// `unary_derivatives_sqrt` (presumably sqrt — confirm against that helper).
let c = one_plus_b2.compose_unary(unary_derivatives_sqrt(one_plus_b2.value()));
let observed_gz = observed_g.scale(z_sum);
// eta0 = q0 * c + observed_g * z_sum; eta1 likewise with q1; ad1 = qd1 * c.
let eta0 = q0.mul(&c).add(&observed_gz);
let eta1 = q1.mul(&c).add(&observed_gz);
let ad1 = qd1.mul(&c);
// The guard tests the raw `qd1` value; `ad1` is only reported in the message.
if survival_derivative_guard_violated(qd1.value(), qd1_lower) {
return Err(SurvivalMarginalSlopeError::MonotonicityViolation {
reason: format!(
"survival marginal-slope monotonicity violated at row {row}: raw time derivative={:.3e} must be at least derivative_guard={:.3e}; transformed time derivative={:.3e}",
qd1.value(), qd1_lower, ad1.value()
),
}
.into());
}
// Rejects NaN and -inf margins when the weight is non-zero; +inf is allowed.
let reject_nonfinite_margin = |margin: f64, weight: f64| -> Result<(), String> {
if weight != 0.0 && margin != f64::INFINITY && !margin.is_finite() {
Err(SurvivalMarginalSlopeError::NumericalFailure {
reason: format!(
"non-finite signed margin in rigid survival marginal-slope row tower at row {row}: {margin}"
),
}
.into())
} else {
Ok(())
}
};
// Entry term: `unary_derivatives_neglog_phi` at -eta0 with weight `wi`,
// then sign-flipped by the trailing `scale(-1.0)`.
let neg_eta0 = eta0.neg();
reject_nonfinite_margin(neg_eta0.value(), wi)?;
let entry = neg_eta0
.compose_unary(unary_derivatives_neglog_phi(neg_eta0.value(), wi))
.scale(-1.0);
// Exit term: same helper at -eta1 with weight `wi * (1 - di)`.
let neg_eta1 = eta1.neg();
reject_nonfinite_margin(neg_eta1.value(), wi * (1.0 - di))?;
let exit = neg_eta1.compose_unary(unary_derivatives_neglog_phi(
neg_eta1.value(),
wi * (1.0 - di),
));
// Both event terms are formed only when `di > 0`, each scaled by `-wi * di`.
let event_density = if di > 0.0 {
eta1.compose_unary(unary_derivatives_log_normal_pdf(eta1.value()))
.scale(-wi * di)
} else {
S::constant(0.0)
};
let time_deriv = if di > 0.0 {
ad1.compose_unary(unary_derivatives_log(ad1.value()))
.scale(-wi * di)
} else {
S::constant(0.0)
};
// Summation order: exit + entry + event_density + time_deriv.
Ok(exit.add(&entry).add(&event_density).add(&time_deriv))
}
/// Evaluates `rigid_row_nll` for `row` on `Tower4<4>` primaries, gathering
/// the row inputs first.
///
/// # Errors
/// Propagates errors from `rigid_row_inputs` and `rigid_row_nll`.
pub(crate) fn rigid_row_kernel_nll_tower(
    family: &SurvivalMarginalSlopeFamily,
    block_states: &[ParameterBlockState],
    row: usize,
    p: &[gam_math::jet_tower::Tower4<4>; 4],
    context: &str,
) -> Result<gam_math::jet_tower::Tower4<4>, String> {
    rigid_row_inputs(family, block_states, row, context)
        .and_then(|row_inputs| rigid_row_nll(p, &row_inputs))
}
impl gam_math::jet_tower::RowNllProgram<4> for SurvivalMarginalSlopeRowKernel {
    /// Row count, read from `family.n`.
    fn n_rows(&self) -> usize {
        self.family.n
    }

    /// Primary values for `row`, delegated to `rigid_row_kernel_primaries`.
    fn primaries(&self, row: usize) -> Result<[f64; 4], String> {
        let Self {
            family,
            block_states,
            ..
        } = self;
        rigid_row_kernel_primaries(family, block_states, row)
    }

    /// Row program on `Tower4<4>` primaries, delegated to
    /// `rigid_row_kernel_nll_tower` with a fixed context label.
    fn row_nll(
        &self,
        row: usize,
        p: &[gam_math::jet_tower::Tower4<4>; 4],
    ) -> Result<gam_math::jet_tower::Tower4<4>, String> {
        const CONTEXT: &str = "survival marginal-slope rigid row tower";
        rigid_row_kernel_nll_tower(&self.family, &self.block_states, row, p, CONTEXT)
    }
}
impl RowKernel<4> for SurvivalMarginalSlopeRowKernel {
/// Row count, read from `family.n`.
fn n_rows(&self) -> usize {
self.family.n
}
/// Total coefficient count, read from `slices.total`.
fn n_coefficients(&self) -> usize {
self.slices.total
}
/// Value, gradient and Hessian of the row program with respect to the four
/// primaries, evaluated on `SparseOrder2<RIGID_LINEAR_MASK>` jets.
///
/// # Errors
/// Propagates errors from input gathering, primaries, and `rigid_row_nll`.
fn row_kernel(&self, row: usize) -> Result<(f64, [f64; 4], [[f64; 4]; 4]), String> {
    let family = &self.family;
    let block_states = &self.block_states;
    let inputs = rigid_row_inputs(
        family,
        block_states,
        row,
        "survival marginal-slope rigid row kernel",
    )?;
    let primaries = rigid_row_kernel_primaries(family, block_states, row)?;
    // Seed each primary as the independent variable on its own axis.
    let seeds: [SparseOrder2<RIGID_LINEAR_MASK>; 4] =
        std::array::from_fn(|axis| SparseOrder2::variable(primaries[axis], axis));
    rigid_row_nll(&seeds, &inputs).map(|jet| (jet.value(), jet.g(), jet.h()))
}
/// Batched value/gradient/Hessian for every row via
/// `survival_rigid_row_jets`.
///
/// Returns `None` when gathering any row fails — including the case where
/// the derivative guard is violated — so the caller can take another path.
/// Otherwise the flat `grad` (4 per row) and `hess` (16 per row, row-major)
/// buffers are unpacked into per-row arrays.
fn batched_value_grad_hess_all(
    &self,
) -> Option<Result<(Vec<f64>, Vec<[f64; 4]>, Vec<[[f64; 4]; 4]>), String>> {
    use crate::gpu_kernels::survival_rowjet::{SurvivalRowInputs, survival_rigid_row_jets};
    let n = self.family.n;
    let probit_scale = self.family.probit_frailty_scale();
    let qd1_lower = self.family.time_derivative_lower_bound();
    let gathered: Result<Vec<SurvivalRowInputs>, String> = (0..n)
        .into_par_iter()
        .map(|row| -> Result<SurvivalRowInputs, String> {
            let primaries = rigid_row_kernel_primaries(&self.family, &self.block_states, row)?;
            if survival_derivative_guard_violated(primaries[2], qd1_lower) {
                return Err("monotonicity-violation-fallback".to_string());
            }
            let row_inputs = rigid_row_inputs(
                &self.family,
                &self.block_states,
                row,
                "survival marginal-slope rigid row kernel (batched)",
            )?;
            Ok(SurvivalRowInputs {
                primaries,
                wi: row_inputs.wi,
                di: row_inputs.di,
                z_sum: row_inputs.z_sum,
                cov_ones: row_inputs.covariance_ones,
            })
        })
        .collect();
    // Any gather failure means "no batched result", not an error.
    let rows = gathered.ok()?;
    let zero = [0.0_f64; 4];
    let ch = survival_rigid_row_jets(&rows, probit_scale, &zero, &zero, &zero);
    let grads: Vec<[f64; 4]> = (0..n)
        .map(|row| std::array::from_fn(|a| ch.grad[row * 4 + a]))
        .collect();
    let hesss: Vec<[[f64; 4]; 4]> = (0..n)
        .map(|row| {
            std::array::from_fn(|a| std::array::from_fn(|b| ch.hess[row * 16 + a * 4 + b]))
        })
        .collect();
    Some(Ok((ch.value, grads, hesss)))
}
/// Applies the row Jacobian to a coefficient direction `d_beta`, giving the
/// induced change in the four primaries.
///
/// Axes 0 and 1 are the entry and exit time-design rows plus a shared
/// marginal-design term; axis 2 is the derivative-exit time-design row;
/// axis 3 is the log-slope design row.
fn jacobian_action(&self, row: usize, d_beta: &[f64]) -> [f64; 4] {
    let family = &self.family;
    let d_beta = ndarray::ArrayView1::from(d_beta);
    let d_time = d_beta.slice(s![self.slices.time.clone()]);
    let d_marginal = d_beta.slice(s![self.slices.marginal.clone()]);
    let d_logslope = d_beta.slice(s![self.slices.logslope.clone()]);
    // The marginal term is identical for axes 0 and 1, so it is evaluated
    // once and reused, as `assemble_jf` does with `jf_marginal`.
    let marginal_shift = family.marginal_design.dot_row_view(row, d_marginal);
    [
        family.design_entry.dot_row_view(row, d_time) + marginal_shift,
        family.design_exit.dot_row_view(row, d_time) + marginal_shift,
        family.design_derivative_exit.dot_row_view(row, d_time),
        family.logslope_design.dot_row_view(row, d_logslope),
    ]
}
/// Applies the Jacobian to every column of `factor` for all rows.
///
/// Returns `None` when `factor` does not have `slices.total` rows.
fn jacobian_action_matrix(&self, factor: ArrayView2<'_, f64>) -> Option<Array2<f64>> {
    let n_rows = self.family.n;
    if factor.nrows() == self.slices.total {
        let packed = self.assemble_jf(factor, n_rows, |design, factor_block| {
            crate::row_kernel::row_kernel_design_jf(design, factor_block, n_rows)
        });
        Some(packed)
    } else {
        None
    }
}
/// Applies the Jacobian to every column of `factor` for rows `start..end`.
///
/// Uses the generic row-range routine when `factor` does not have
/// `slices.total` rows; otherwise assembles from the block designs.
fn jacobian_action_matrix_rows(
    &self,
    factor: ArrayView2<'_, f64>,
    start: usize,
    end: usize,
) -> Array2<f64> {
    if factor.nrows() == self.slices.total {
        let row_span = end.saturating_sub(start);
        self.assemble_jf(factor, row_span, |design, factor_block| {
            crate::row_kernel::row_kernel_design_jf_rows(design, factor_block, start, end)
        })
    } else {
        crate::row_kernel::row_kernel_jacobian_action_matrix_generic_rows(
            self, factor, start, end,
        )
    }
}
/// Accumulates the transposed row Jacobian applied to `v` into `out`.
///
/// The time slice receives the entry, exit and derivative-exit design rows
/// weighted by `v[0]`, `v[1]`, `v[2]` (in that order); the marginal slice
/// receives the marginal row weighted by `v[0] + v[1]`; the log-slope slice
/// receives the log-slope row weighted by `v[3]`.
///
/// # Panics
/// Panics if any `axpy_row_into` call reports an error.
fn jacobian_transpose_action(&self, row: usize, v: &[f64; 4], out: &mut [f64]) {
    let family = &self.family;
    let [w_entry, w_exit, w_deriv, w_slope] = *v;
    {
        let mut time_out = ndarray::ArrayViewMut1::from(&mut out[self.slices.time.clone()]);
        family
            .design_entry
            .axpy_row_into(row, w_entry, &mut time_out)
            .expect("time entry axpy dim mismatch");
        family
            .design_exit
            .axpy_row_into(row, w_exit, &mut time_out)
            .expect("time exit axpy dim mismatch");
        family
            .design_derivative_exit
            .axpy_row_into(row, w_deriv, &mut time_out)
            .expect("time deriv axpy dim mismatch");
    }
    {
        let mut marginal_out =
            ndarray::ArrayViewMut1::from(&mut out[self.slices.marginal.clone()]);
        family
            .marginal_design
            .axpy_row_into(row, w_entry + w_exit, &mut marginal_out)
            .expect("marginal axpy dim mismatch");
    }
    {
        let mut logslope_out =
            ndarray::ArrayViewMut1::from(&mut out[self.slices.logslope.clone()]);
        family
            .logslope_design
            .axpy_row_into(row, w_slope, &mut logslope_out)
            .expect("logslope axpy dim mismatch");
    }
}
/// Copies the 4x4 primary Hessian `h` into an `Array2` and hands it to
/// `family.add_pullback_primary_hessian`, which updates `target`.
fn add_pullback_hessian(&self, row: usize, h: &[[f64; 4]; 4], target: &mut Array2<f64>) {
    let primary_hessian = Array2::<f64>::from_shape_fn((4, 4), |(a, b)| h[a][b]);
    self.family
        .add_pullback_primary_hessian(target, row, &self.slices, &primary_hessian);
}
/// Accumulates this row's contribution to the diagonal of the pulled-back
/// quadratic form into `diag`.
///
/// Time slice: squared terms weighted by `h[a][a]` for the three time
/// designs plus cross terms weighted by `2 * h[a][b]` for `a < b`.
/// Marginal slice: squared term weighted by `h[0][0] + 2*h[0][1] + h[1][1]`.
/// Log-slope slice: squared term weighted by `h[3][3]`.
///
/// # Panics
/// Panics if any accumulation helper reports an error.
fn add_diagonal_quadratic(&self, row: usize, h: &[[f64; 4]; 4], diag: &mut [f64]) {
    let family = &self.family;
    {
        let designs: [(usize, &DesignMatrix); 3] = [
            (0, &family.design_entry),
            (1, &family.design_exit),
            (2, &family.design_derivative_exit),
        ];
        let mut time_diag = ndarray::ArrayViewMut1::from(&mut diag[self.slices.time.clone()]);
        for (pos, &(pi, des)) in designs.iter().enumerate() {
            des.squared_axpy_row_into(row, h[pi][pi], &mut time_diag)
                .expect("time squared_axpy dim mismatch");
            // Only pairs with a strictly larger primary index follow `pos`.
            for &(pj, des_j) in &designs[pos + 1..] {
                des.crossdiag_axpy_row_into(row, des_j, 2.0 * h[pi][pj], &mut time_diag)
                    .expect("time crossdiag dim mismatch");
            }
        }
    }
    {
        let alpha = h[0][0] + 2.0 * h[0][1] + h[1][1];
        let mut marginal_diag =
            ndarray::ArrayViewMut1::from(&mut diag[self.slices.marginal.clone()]);
        family
            .marginal_design
            .squared_axpy_row_into(row, alpha, &mut marginal_diag)
            .expect("marginal squared_axpy dim mismatch");
    }
    {
        let mut logslope_diag =
            ndarray::ArrayViewMut1::from(&mut diag[self.slices.logslope.clone()]);
        family
            .logslope_design
            .squared_axpy_row_into(row, h[3][3], &mut logslope_diag)
            .expect("logslope squared_axpy dim mismatch");
    }
}
/// Third derivative of the row program contracted with `dir`, obtained by
/// running `rigid_row_nll` on `OneSeed<4>` jets seeded with that direction.
///
/// # Errors
/// Propagates errors from input gathering, primaries, and `rigid_row_nll`.
fn row_third_contracted(&self, row: usize, dir: &[f64; 4]) -> Result<[[f64; 4]; 4], String> {
    let family = &self.family;
    let block_states = &self.block_states;
    let inputs = rigid_row_inputs(
        family,
        block_states,
        row,
        "survival marginal-slope rigid row third",
    )?;
    let primaries = rigid_row_kernel_primaries(family, block_states, row)?;
    let seeds: [OneSeed<4>; 4] =
        std::array::from_fn(|axis| OneSeed::seed_direction(primaries[axis], axis, dir[axis]));
    rigid_row_nll(&seeds, &inputs).map(|jet| jet.contracted_third())
}
/// Fourth derivative of the row program contracted with `dir_u` and
/// `dir_v`, obtained by running `rigid_row_nll` on `TwoSeed<4>` jets.
///
/// # Errors
/// Propagates errors from input gathering, primaries, and `rigid_row_nll`.
fn row_fourth_contracted(
    &self,
    row: usize,
    dir_u: &[f64; 4],
    dir_v: &[f64; 4],
) -> Result<[[f64; 4]; 4], String> {
    let family = &self.family;
    let block_states = &self.block_states;
    let inputs = rigid_row_inputs(
        family,
        block_states,
        row,
        "survival marginal-slope rigid row fourth",
    )?;
    let primaries = rigid_row_kernel_primaries(family, block_states, row)?;
    let seeds: [TwoSeed<4>; 4] = std::array::from_fn(|axis| {
        TwoSeed::seed(primaries[axis], axis, dir_u[axis], dir_v[axis])
    });
    rigid_row_nll(&seeds, &inputs).map(|jet| jet.contracted_fourth())
}
/// Dense all-axes directional derivative override.
///
/// Returns `Some(Err(..))` when `p` differs from `n_coefficients()`, `None`
/// when `rows` is not `RowSet::All`, and otherwise the result of
/// `directional_derivative_all_axes_build_once`.
fn directional_derivative_all_axes_dense_override(
    &self,
    rows: &crate::row_kernel::RowSet,
    p: usize,
) -> Option<Result<Vec<Array2<f64>>, String>> {
    let expected = self.n_coefficients();
    if p != expected {
        return Some(Err(format!(
            "survival marginal-slope directional_derivative_all_axes_dense_override: \
             axis count {p} disagrees with n_coefficients() {}",
            expected,
        )));
    }
    if matches!(rows, crate::row_kernel::RowSet::All) {
        Some(self.directional_derivative_all_axes_build_once(p))
    } else {
        None
    }
}
/// Dense all-axes second directional derivative override.
///
/// Returns `Some(Err(..))` when `d_beta_u` does not have `n_coefficients()`
/// entries, `None` when `rows` is not `RowSet::All`, and otherwise the
/// result of `second_directional_derivative_all_axes_build_once`.
fn second_directional_derivative_all_axes_dense_override(
    &self,
    rows: &crate::row_kernel::RowSet,
    d_beta_u: &[f64],
) -> Option<Result<Vec<Array2<f64>>, String>> {
    let expected = self.n_coefficients();
    let supplied = d_beta_u.len();
    if supplied != expected {
        return Some(Err(format!(
            "survival marginal-slope second_directional_derivative_all_axes_dense_override: \
             fixed direction has {} entries, expected {}",
            supplied, expected,
        )));
    }
    if matches!(rows, crate::row_kernel::RowSet::All) {
        Some(self.second_directional_derivative_all_axes_build_once(d_beta_u))
    } else {
        None
    }
}
}
impl SurvivalMarginalSlopeRowKernel {
    /// Assembles the per-axis products of the block designs with `factor`
    /// and packs them with `row_kernel_pack_jf_axes`.
    ///
    /// `axis` maps one design and the matching row block of `factor` to an
    /// array; `n_out` is the output row count forwarded to the packer. The
    /// marginal product is computed once and added to both axis 0 (entry)
    /// and axis 1 (exit). When `factor` has no columns, a zero-column array
    /// with `n_out` rows is returned.
    pub(crate) fn assemble_jf<F>(
        &self,
        factor: ArrayView2<'_, f64>,
        n_out: usize,
        axis: F,
    ) -> Array2<f64>
    where
        F: Fn(&DesignMatrix, ArrayView2<'_, f64>) -> Array2<f64>,
    {
        let rank = factor.ncols();
        if rank == 0 {
            return Array2::<f64>::zeros((n_out, 0));
        }
        let family = &self.family;
        let time_block = factor.slice(s![self.slices.time.clone(), ..]);
        let marginal_block = factor.slice(s![self.slices.marginal.clone(), ..]);
        let logslope_block = factor.slice(s![self.slices.logslope.clone(), ..]);

        let shared_marginal = axis(&family.marginal_design, marginal_block);

        let mut entry_axis = axis(&family.design_entry, time_block);
        entry_axis += &shared_marginal;

        let mut exit_axis = axis(&family.design_exit, time_block);
        exit_axis += &shared_marginal;

        let deriv_axis = axis(&family.design_derivative_exit, time_block);
        let slope_axis = axis(&family.logslope_design, logslope_block);

        crate::row_kernel::row_kernel_pack_jf_axes::<4>(
            n_out,
            rank,
            [
                (0, entry_axis),
                (1, exit_axis),
                (2, deriv_axis),
                (3, slope_axis),
            ],
        )
    }
}
impl SurvivalMarginalSlopeRowKernel {
/// Evaluates the row program as a `Tower4<4>` for every row in parallel,
/// returning an error if any row fails.
fn build_row_towers(&self) -> Result<Vec<gam_math::jet_tower::Tower4<4>>, String> {
    let row_count = <Self as RowKernel<4>>::n_rows(self);
    let evaluate_row = |row: usize| gam_math::jet_tower::evaluate_program::<4, Self>(self, row);
    (0..row_count)
        .into_par_iter()
        .map(evaluate_row)
        .collect()
}
/// Sums per-row `p x p` contributions, processing rows in chunks of
/// `ARROW_ROW_CHUNK` in parallel and then adding the chunk partials in
/// chunk order.
///
/// `per_row(row, acc)` adds one row's contribution into `acc`. A chunk
/// stops at its first failing row; the first failing chunk (in chunk order)
/// determines the returned error.
fn chunked_pullback_reduce<F>(&self, p: usize, per_row: F) -> Result<Array2<f64>, String>
where
    F: Fn(usize, &mut Array2<f64>) -> Result<(), String> + Sync,
{
    let n = <Self as RowKernel<4>>::n_rows(self);
    let chunk = crate::outer_subsample::ARROW_ROW_CHUNK;
    let n_chunks = crate::outer_subsample::arrow_row_chunk_count(n);
    let partials: Vec<Result<Array2<f64>, String>> = (0..n_chunks)
        .into_par_iter()
        .map(|chunk_idx| -> Result<Array2<f64>, String> {
            let first = chunk_idx * chunk;
            let last = (first + chunk).min(n);
            let mut partial = Array2::<f64>::zeros((p, p));
            (first..last).try_for_each(|row| per_row(row, &mut partial))?;
            Ok(partial)
        })
        .collect();
    // Sequential reduction keeps the summation order fixed.
    let mut total = Array2::<f64>::zeros((p, p));
    for partial in partials {
        let partial = partial?;
        total += &partial;
    }
    Ok(total)
}
/// Builds the row towers once, then for each coefficient axis `a` in `0..p`
/// accumulates the pulled-back third-order contraction along the unit
/// direction `e_a` over all rows. Returns one `p x p` matrix per axis.
///
/// # Errors
/// Propagates errors from building the row towers or from the reduction.
fn directional_derivative_all_axes_build_once(
    &self,
    p: usize,
) -> Result<Vec<Array2<f64>>, String> {
    let towers = self.build_row_towers()?;
    let reduce_axis = |a: usize| -> Result<Array2<f64>, String> {
        // Unit coefficient direction selecting axis `a`.
        let mut unit = vec![0.0_f64; p];
        unit[a] = 1.0;
        gam_problem::with_nested_parallel(|| {
            self.chunked_pullback_reduce(p, |row, acc| {
                let primary_dir = self.jacobian_action(row, &unit);
                let contracted = towers[row].third_contracted(&primary_dir);
                self.add_pullback_hessian(row, &contracted, acc);
                Ok(())
            })
        })
    };
    (0..p).into_par_iter().map(reduce_axis).collect()
}
/// Builds the row towers once, then for each coefficient axis `a`
/// accumulates the pulled-back fourth-order contraction along the fixed
/// direction `d_beta_u` and the unit direction `e_a` over all rows.
/// Returns one matrix per axis, each `n_coefficients()` square.
///
/// # Errors
/// Propagates errors from building the row towers or from the reduction.
fn second_directional_derivative_all_axes_build_once(
    &self,
    d_beta_u: &[f64],
) -> Result<Vec<Array2<f64>>, String> {
    let p = self.n_coefficients();
    let towers = self.build_row_towers()?;
    let reduce_axis = |a: usize| -> Result<Array2<f64>, String> {
        // Unit coefficient direction selecting axis `a`.
        let mut unit = vec![0.0_f64; p];
        unit[a] = 1.0;
        gam_problem::with_nested_parallel(|| {
            self.chunked_pullback_reduce(p, |row, acc| {
                let fixed_dir = self.jacobian_action(row, d_beta_u);
                let axis_dir = self.jacobian_action(row, &unit);
                let contracted = towers[row].fourth_contracted(&fixed_dir, &axis_dir);
                self.add_pullback_hessian(row, &contracted, acc);
                Ok(())
            })
        })
    };
    (0..p).into_par_iter().map(reduce_axis).collect()
}
}