use super::*;
/// Row-wise kernel for the survival marginal-slope family.
///
/// Bundles the family, the current per-block parameter states, and the coefficient layout
/// derived from them, so the `RowKernel` methods can map between the flat coefficient
/// vector and the time / marginal / log-slope blocks.
pub(crate) struct SurvivalMarginalSlopeRowKernel {
    /// Model family holding the designs, offsets and row count.
    pub(crate) family: SurvivalMarginalSlopeFamily,
    /// Per-block parameter states. `row_kernel` reads the time coefficients from index 0
    /// (`beta`) and the marginal linear predictor from index 1 (`eta`).
    pub(crate) block_states: Vec<ParameterBlockState>,
    /// Time, marginal and log-slope coefficient ranges, plus the total coefficient count.
    pub(crate) slices: BlockSlices,
}
impl SurvivalMarginalSlopeRowKernel {
    /// Builds a kernel, deriving the block layout from `family` and `block_states`.
    pub(crate) fn new(
        family: SurvivalMarginalSlopeFamily,
        block_states: Vec<ParameterBlockState>,
    ) -> Self {
        Self {
            // Struct-literal fields evaluate in written order, so this borrow of
            // `family` / `block_states` ends before they are moved in below.
            slices: block_slices(&family, &block_states),
            family,
            block_states,
        }
    }
}
impl RowKernel<4> for SurvivalMarginalSlopeRowKernel {
    /// Number of observation rows (`family.n`).
    fn n_rows(&self) -> usize {
        self.family.n
    }

    /// Length of the flat coefficient vector spanning all blocks (`slices.total`).
    fn n_coefficients(&self) -> usize {
        self.slices.total
    }

    /// Evaluates one row at the current block states.
    ///
    /// Assembles the primary predictors for `row`:
    /// - `q0`: entry design · time coefficients + entry offset + marginal eta,
    /// - `q1`: exit design · time coefficients + exit offset + marginal eta,
    /// - `qd1`: derivative-at-exit design · time coefficients + derivative offset,
    ///
    /// then delegates to `row_primary_closed_form_rigid` with the full block states and the
    /// family's probit frailty scale. The returned tuple is presumably
    /// (value, primary gradient, primary Hessian) — confirm against `RowKernel`.
    ///
    /// # Errors
    /// Propagates any error string from `row_primary_closed_form_rigid`.
    fn row_kernel(&self, row: usize) -> Result<(f64, [f64; 4], [[f64; 4]; 4]), String> {
        let family = &self.family;
        let beta_time = &self.block_states[0].beta;
        // The marginal predictor shifts the entry and exit predictors identically.
        let marginal_eta = self.block_states[1].eta[row];
        let q0 = family.design_entry.dot_row(row, beta_time)
            + family.offset_entry[row]
            + marginal_eta;
        let q1 = family.design_exit.dot_row(row, beta_time)
            + family.offset_exit[row]
            + marginal_eta;
        let qd1 = family.design_derivative_exit.dot_row(row, beta_time)
            + family.derivative_offset_exit[row];
        family.row_primary_closed_form_rigid(
            row,
            q0,
            q1,
            qd1,
            &self.block_states,
            family.probit_frailty_scale(),
        )
    }

    /// Applies the row Jacobian to a coefficient direction: returns `J_row · d_beta`.
    ///
    /// Primaries 0 and 1 both receive the time-block contribution (entry / exit design)
    /// plus the shared marginal contribution. Primary 2 is the derivative design on the
    /// time block only, and primary 3 is the log-slope design on the log-slope block.
    fn jacobian_action(&self, row: usize, d_beta: &[f64]) -> [f64; 4] {
        let d_beta = ndarray::ArrayView1::from(d_beta);
        let d_time = d_beta.slice(s![self.slices.time.clone()]);
        let d_marginal = d_beta.slice(s![self.slices.marginal.clone()]);
        let d_logslope = d_beta.slice(s![self.slices.logslope.clone()]);
        // Same marginal term feeds both q0 and q1; compute it once.
        let d_marginal_eta = self.family.marginal_design.dot_row_view(row, d_marginal);
        [
            self.family.design_entry.dot_row_view(row, d_time) + d_marginal_eta,
            self.family.design_exit.dot_row_view(row, d_time) + d_marginal_eta,
            self.family.design_derivative_exit.dot_row_view(row, d_time),
            self.family.logslope_design.dot_row_view(row, d_logslope),
        ]
    }

    /// Dense Jacobian-times-factor for all rows.
    ///
    /// Returns `None` when `factor` does not have exactly one row per coefficient
    /// (`slices.total`).
    fn jacobian_action_matrix(&self, factor: ArrayView2<'_, f64>) -> Option<Array2<f64>> {
        if factor.nrows() != self.slices.total {
            return None;
        }
        let n_rows = self.family.n;
        Some(self.assemble_jf(factor, n_rows, |design, factor_block| {
            crate::families::row_kernel::row_kernel_design_jf(design, factor_block, n_rows)
        }))
    }

    /// Jacobian-times-factor restricted to rows `start..end`.
    ///
    /// When `factor` does not have one row per coefficient, this falls back to the generic
    /// per-row implementation instead of the design-level fast path.
    fn jacobian_action_matrix_rows(
        &self,
        factor: ArrayView2<'_, f64>,
        start: usize,
        end: usize,
    ) -> Array2<f64> {
        if factor.nrows() != self.slices.total {
            return crate::families::row_kernel::row_kernel_jacobian_action_matrix_generic_rows(
                self, factor, start, end,
            );
        }
        let b = end.saturating_sub(start);
        self.assemble_jf(factor, b, |design, factor_block| {
            crate::families::row_kernel::row_kernel_design_jf_rows(design, factor_block, start, end)
        })
    }

    /// Accumulates `J_rowᵀ · v` into `out`, block by block.
    ///
    /// Time block: entry, exit and derivative design rows weighted by `v[0]`, `v[1]`, `v[2]`.
    /// Marginal block: marginal design row weighted by `v[0] + v[1]`.
    /// Log-slope block: log-slope design row weighted by `v[3]`.
    ///
    /// # Panics
    /// Panics if a design's column count does not match its block slice in `out`.
    fn jacobian_transpose_action(&self, row: usize, v: &[f64; 4], out: &mut [f64]) {
        {
            let mut time = ndarray::ArrayViewMut1::from(&mut out[self.slices.time.clone()]);
            self.family
                .design_entry
                .axpy_row_into(row, v[0], &mut time)
                .expect("time entry axpy dim mismatch");
            self.family
                .design_exit
                .axpy_row_into(row, v[1], &mut time)
                .expect("time exit axpy dim mismatch");
            self.family
                .design_derivative_exit
                .axpy_row_into(row, v[2], &mut time)
                .expect("time deriv axpy dim mismatch");
        }
        {
            let mut marginal = ndarray::ArrayViewMut1::from(&mut out[self.slices.marginal.clone()]);
            self.family
                .marginal_design
                .axpy_row_into(row, v[0] + v[1], &mut marginal)
                .expect("marginal axpy dim mismatch");
        }
        {
            let mut logslope = ndarray::ArrayViewMut1::from(&mut out[self.slices.logslope.clone()]);
            self.family
                .logslope_design
                .axpy_row_into(row, v[3], &mut logslope)
                .expect("logslope axpy dim mismatch");
        }
    }

    /// Adds the row's pulled-back primary Hessian (`J_rowᵀ h J_row`) into `target`.
    ///
    /// The work is done by the family's `add_pullback_primary_hessian`. This method only
    /// converts the fixed-size array into an `Array2`.
    fn add_pullback_hessian(&self, row: usize, h: &[[f64; 4]; 4], target: &mut Array2<f64>) {
        let h_arr = Array2::<f64>::from_shape_fn((4, 4), |(a, b)| h[a][b]);
        self.family
            .add_pullback_primary_hessian(target, row, &self.slices, &h_arr);
    }

    /// Adds the diagonal of `J_rowᵀ h J_row` into `diag`.
    ///
    /// For coefficient `k`, the diagonal entry is the quadratic form `x_kᵀ h x_k`, where `x_k`
    /// is the column of `J_row` for `k`. A quadratic form depends only on the symmetric part
    /// of `h`, so cross terms use `h[a][b] + h[b][a]`. For an exactly symmetric `h` this is
    /// bitwise equal to `2 * h[a][b]`.
    ///
    /// # Panics
    /// Panics if a design's column count does not match its block slice in `diag`.
    fn add_diagonal_quadratic(&self, row: usize, h: &[[f64; 4]; 4], diag: &mut [f64]) {
        // (primary index, time-block design) for the three time-driven primaries, in
        // ascending primary order so `designs[i + 1..]` holds exactly the pairs with pj > pi.
        let designs: [(usize, &DesignMatrix); 3] = [
            (0, &self.family.design_entry),
            (1, &self.family.design_exit),
            (2, &self.family.design_derivative_exit),
        ];
        {
            let mut td = ndarray::ArrayViewMut1::from(&mut diag[self.slices.time.clone()]);
            for (i, &(pi, des)) in designs.iter().enumerate() {
                des.squared_axpy_row_into(row, h[pi][pi], &mut td)
                    .expect("time squared_axpy dim mismatch");
                for &(pj, des_j) in &designs[i + 1..] {
                    des.crossdiag_axpy_row_into(row, des_j, h[pi][pj] + h[pj][pi], &mut td)
                        .expect("time crossdiag dim mismatch");
                }
            }
        }
        {
            // Marginal columns feed primaries 0 and 1 with the same design row, so the
            // weight is the sum of the top-left 2x2 block of `h`.
            let alpha = h[0][0] + (h[0][1] + h[1][0]) + h[1][1];
            let mut md = ndarray::ArrayViewMut1::from(&mut diag[self.slices.marginal.clone()]);
            self.family
                .marginal_design
                .squared_axpy_row_into(row, alpha, &mut md)
                .expect("marginal squared_axpy dim mismatch");
        }
        {
            let mut gd = ndarray::ArrayViewMut1::from(&mut diag[self.slices.logslope.clone()]);
            self.family
                .logslope_design
                .squared_axpy_row_into(row, h[3][3], &mut gd)
                .expect("logslope squared_axpy dim mismatch");
        }
    }

    /// Third derivative of the row objective w.r.t. the primaries, contracted along `dir`.
    ///
    /// Delegates to the family's batched implementation.
    fn row_third_contracted(&self, row: usize, dir: &[f64; 4]) -> Result<[[f64; 4]; 4], String> {
        let dir_view = ndarray::aview1(&dir[..]);
        self.family
            .row_primary_third_contracted_batched(row, &self.block_states, dir_view)
    }

    /// Fourth derivative of the row objective w.r.t. the primaries, contracted along
    /// `dir_u` and `dir_v`.
    ///
    /// Delegates to the family's batched implementation.
    fn row_fourth_contracted(
        &self,
        row: usize,
        dir_u: &[f64; 4],
        dir_v: &[f64; 4],
    ) -> Result<[[f64; 4]; 4], String> {
        let u_view = ndarray::aview1(&dir_u[..]);
        let v_view = ndarray::aview1(&dir_v[..]);
        self.family
            .row_primary_fourth_contracted_batched(row, &self.block_states, u_view, v_view)
    }
}
impl SurvivalMarginalSlopeRowKernel {
    /// Projects a coefficient-space `factor` through each primary axis and packs the
    /// results via `row_kernel_pack_jf_axes`.
    ///
    /// `axis` maps a design matrix and its block of `factor` rows to an `n_out`-row product.
    /// Axes 0 and 1 (entry / exit) each add the shared marginal projection. Axis 2 uses the
    /// derivative design and axis 3 the log-slope design. A factor with zero columns
    /// short-circuits to an `(n_out, 0)` matrix without calling `axis`.
    pub(crate) fn assemble_jf<F>(
        &self,
        factor: ArrayView2<'_, f64>,
        n_out: usize,
        axis: F,
    ) -> Array2<f64>
    where
        F: Fn(&DesignMatrix, ArrayView2<'_, f64>) -> Array2<f64>,
    {
        let n_factor_cols = factor.ncols();
        if n_factor_cols == 0 {
            return Array2::<f64>::zeros((n_out, 0));
        }

        let time_rows = factor.slice(s![self.slices.time.clone(), ..]);
        let marginal_rows = factor.slice(s![self.slices.marginal.clone(), ..]);
        let logslope_rows = factor.slice(s![self.slices.logslope.clone(), ..]);

        let shared_marginal = axis(&self.family.marginal_design, marginal_rows);
        // Entry and exit primaries both carry the marginal contribution.
        let plus_marginal = |design: &DesignMatrix| {
            let mut projected = axis(design, time_rows);
            projected += &shared_marginal;
            projected
        };
        let entry_axis = plus_marginal(&self.family.design_entry);
        let exit_axis = plus_marginal(&self.family.design_exit);
        let derivative_axis = axis(&self.family.design_derivative_exit, time_rows);
        let logslope_axis = axis(&self.family.logslope_design, logslope_rows);

        crate::families::row_kernel::row_kernel_pack_jf_axes::<4>(
            n_out,
            n_factor_cols,
            [
                (0, entry_axis),
                (1, exit_axis),
                (2, derivative_axis),
                (3, logslope_axis),
            ],
        )
    }
}