use crate::cubic_cell_kernel as exact_kernel;
use crate::marginal_slope_shared::{
CoeffSupport, ObservedDenestedCellPartials, eval_coeff4_at, scale_coeff4,
};
use gam_math::jet_partitions::MultiDirJet;
use ndarray::{Array1, Array2};
use super::{CachedCellEntry, FlexPrimarySlices, SurvivalMarginalSlopeFamily};
/// Builds a copy of the cached entry's cubic cell with every polynomial
/// coefficient (`c0`..`c3`) negated.
///
/// The interval endpoints `left` and `right` are carried over unchanged.
pub(crate) fn neg_cell_of(entry: &CachedCellEntry) -> exact_kernel::DenestedCubicCell {
    let original = &entry.partition_cell.cell;
    let (left, right) = (original.left, original.right);
    // Only the coefficients flip sign; the cell bounds stay as they are.
    let (c0, c1, c2, c3) = (-original.c0, -original.c1, -original.c2, -original.c3);
    exact_kernel::DenestedCubicCell {
        left,
        right,
        c0,
        c1,
        c2,
        c3,
    }
}
/// Adds two jet-coefficient sequences term by term.
///
/// The result has `max(lhs.len(), rhs.len())` entries. Wherever one input is
/// shorter than the other, its missing term is replaced by
/// `MultiDirJet::zero(2)` before the addition.
pub(crate) fn poly_add_jets(lhs: &[MultiDirJet], rhs: &[MultiDirJet]) -> Vec<MultiDirJet> {
    // Shared padding term for whichever side runs out first.
    let padding = MultiDirJet::zero(2);
    let term_count = lhs.len().max(rhs.len());
    (0..term_count)
        .map(|position| {
            let left_term = lhs.get(position).unwrap_or(&padding);
            let right_term = rhs.get(position).unwrap_or(&padding);
            left_term.add(right_term)
        })
        .collect()
}
/// Multiplies every entry of `poly` by the jet `scale` (entry on the left,
/// `scale` on the right) and returns the products in the same order.
pub(crate) fn poly_scale_jets(poly: &[MultiDirJet], scale: &MultiDirJet) -> Vec<MultiDirJet> {
    let mut scaled = Vec::with_capacity(poly.len());
    for term in poly {
        scaled.push(term.mul(scale));
    }
    scaled
}
pub(crate) fn poly_mul_jets(lhs: &[MultiDirJet], rhs: &[MultiDirJet]) -> Vec<MultiDirJet> {
if lhs.is_empty() || rhs.is_empty() {
return Vec::new();
}
let mut out = vec![MultiDirJet::zero(2); lhs.len() + rhs.len() - 1];
for (i, left) in lhs.iter().enumerate() {
for (j, right) in rhs.iter().enumerate() {
let prod = left.mul(right);
out[i + j] = out[i + j].add(&prod);
}
}
out
}
/// Reads `coeff(mask)` from every jet in `poly`, returning the values in the
/// same order as the input.
pub(crate) fn poly_coeff_mask(poly: &[MultiDirJet], mask: usize) -> Vec<f64> {
    let mut values = Vec::with_capacity(poly.len());
    for jet in poly {
        values.push(jet.coeff(mask));
    }
    values
}
/// `CoeffSupport` value with all three flags enabled: `include_primary`,
/// `include_h` and `include_w`.
///
/// NOTE(review): the "GHW" suffix presumably names the g / h / w coefficient
/// groups; confirm against the `CoeffSupport` definition.
pub(crate) const COEFF_SUPPORT_GHW: CoeffSupport = CoeffSupport {
include_primary: true,
include_h: true,
include_w: true,
};
/// `CoeffSupport` value with `include_primary` and `include_w` enabled and
/// `include_h` disabled.
///
/// NOTE(review): the "GW" suffix presumably names the g / w coefficient
/// groups; confirm against the `CoeffSupport` definition.
pub(crate) const COEFF_SUPPORT_GW: CoeffSupport = CoeffSupport {
include_primary: true,
include_h: false,
include_w: true,
};
/// Assembles a bilinear jet for a scalar that depends on two directions both
/// directly (the `fixed_*` terms) and through an intermediate quantity `a`.
///
/// The arguments passed to `MultiDirJet::bilinear` are:
/// * value: `base`
/// * first direction: `da * ad1 + fixed_d1`
/// * second direction: `da * ad2 + fixed_d2`
/// * mixed term: `da * ad12 + daa * ad1 * ad2 + da_d1 * ad2 + da_d2 * ad1 + fixed_d12`
///
/// NOTE(review): this reads as a chain rule, with `da`/`daa` the first and
/// second derivatives with respect to `a`, `ad1`/`ad2`/`ad12` the directional
/// derivatives of `a`, and `da_d1`/`da_d2` the directional derivatives of
/// `da`. Confirm against the callers.
pub(crate) fn scalar_composite_bilinear(
    base: f64,
    da: f64,
    daa: f64,
    fixed_d1: f64,
    fixed_d2: f64,
    fixed_d12: f64,
    da_d1: f64,
    da_d2: f64,
    ad1: f64,
    ad2: f64,
    ad12: f64,
) -> MultiDirJet {
    let first_dir = da * ad1 + fixed_d1;
    let second_dir = da * ad2 + fixed_d2;

    // The mixed term is summed left to right in this exact order so the
    // floating-point result does not change.
    let through_a_mixed = da * ad12;
    let curvature = daa * ad1 * ad2;
    let cross_first = da_d1 * ad2;
    let cross_second = da_d2 * ad1;
    let mixed = through_a_mixed + curvature + cross_first + cross_second + fixed_d12;

    MultiDirJet::bilinear(base, first_dir, second_dir, mixed)
}
/// Turns four parallel 4-element coefficient arrays into four bilinear jets.
///
/// Entry `k` of the result is
/// `MultiDirJet::bilinear(base[k], d1[k], d2[k], d12[k])`.
pub(crate) fn coeff4_fixed_bilinear(
    base: &[f64; 4],
    d1: &[f64; 4],
    d2: &[f64; 4],
    d12: &[f64; 4],
) -> Vec<MultiDirJet> {
    let mut jets = Vec::with_capacity(4);
    for (((&value, &first), &second), &mixed) in base.iter().zip(d1).zip(d2).zip(d12) {
        jets.push(MultiDirJet::bilinear(value, first, second, mixed));
    }
    jets
}
/// Applies [`scalar_composite_bilinear`] to each of the four coefficient
/// slots.
///
/// The array arguments are indexed by slot `k` in `0..4`. The scalars `ad1`,
/// `ad2` and `ad12` are shared by all four slots. The returned vector holds
/// the four jets in slot order.
pub(crate) fn coeff4_composite_bilinear(
    base: &[f64; 4],
    da: &[f64; 4],
    daa: &[f64; 4],
    fixed_d1: &[f64; 4],
    fixed_d2: &[f64; 4],
    fixed_d12: &[f64; 4],
    da_d1: &[f64; 4],
    da_d2: &[f64; 4],
    ad1: f64,
    ad2: f64,
    ad12: f64,
) -> Vec<MultiDirJet> {
    let jet_for_slot = |slot: usize| {
        scalar_composite_bilinear(
            base[slot],
            da[slot],
            daa[slot],
            fixed_d1[slot],
            fixed_d2[slot],
            fixed_d12[slot],
            da_d1[slot],
            da_d2[slot],
            ad1,
            ad2,
            ad12,
        )
    };
    [0usize, 1, 2, 3].into_iter().map(jet_for_slot).collect()
}
/// Bundle of vector- and matrix-valued results sharing a `_dir` suffix.
///
/// NOTE(review): only the declaration is visible here. Judging by the names,
/// these look like derivatives of `eta`, `chi`, `d` and `a` contracted with
/// one direction, where `_u_` fields hold one entry per index and `_uv_`
/// fields one entry per index pair. Confirm against the code that fills this
/// struct.
pub(crate) struct SurvivalFlexTimepointDirectionalExact {
// Two-index `eta` block (matrix).
pub(crate) eta_uv_dir: Array2<f64>,
// One-index `eta` block (vector).
pub(crate) eta_u_dir: Array1<f64>,
// One-index `chi` block (vector).
pub(crate) chi_u_dir: Array1<f64>,
// Two-index `chi` block (matrix).
pub(crate) chi_uv_dir: Array2<f64>,
// One-index `d` block (vector).
pub(crate) d_u_dir: Array1<f64>,
// Two-index `d` block (matrix).
pub(crate) d_uv_dir: Array2<f64>,
// Two-index `a` block (matrix).
pub(crate) a_uv_dir: Array2<f64>,
}
/// Bundle of three matrix-valued results sharing a `_uv_uv` suffix.
///
/// NOTE(review): only the declaration is visible here. Judging by the names,
/// these look like two-index blocks of `eta`, `chi` and `d` contracted with
/// two directions. Confirm against the code that fills this struct.
pub(crate) struct SurvivalFlexTimepointBiDirectionalExact {
// Two-index `eta` block (matrix).
pub(crate) eta_uv_uv: Array2<f64>,
// Two-index `chi` block (matrix).
pub(crate) chi_uv_uv: Array2<f64>,
// Two-index `d` block (matrix).
pub(crate) d_uv_uv: Array2<f64>,
}
impl SurvivalMarginalSlopeFamily {
    /// Selects the coefficient row for the index pair `(u, v)`.
    ///
    /// With `g = primary.g`: returns `coeff_bu[v]` when `u == g`, otherwise
    /// `coeff_bu[u]` when `v == g`, otherwise all zeros.
    pub(crate) fn cell_pair_second_coeff(
        &self,
        primary: &FlexPrimarySlices,
        coeff_bu: &[[f64; 4]],
        u: usize,
        v: usize,
    ) -> [f64; 4] {
        let g = primary.g;
        match (u == g, v == g) {
            (true, _) => coeff_bu[v],
            (false, true) => coeff_bu[u],
            (false, false) => [0.0; 4],
        }
    }

    /// Selects the coefficient row of `coeff_abu` for the index pair `(u, v)`.
    ///
    /// With `g = primary.g`: returns `coeff_abu[v]` when `u == g`, otherwise
    /// `coeff_abu[u]` when `v == g`, otherwise all zeros.
    pub(crate) fn cell_pair_third_coeff_a(
        &self,
        primary: &FlexPrimarySlices,
        coeff_abu: &[[f64; 4]],
        u: usize,
        v: usize,
    ) -> [f64; 4] {
        let g = primary.g;
        // The row to read is indexed by whichever index is paired with `g`.
        let partner = if u == g {
            Some(v)
        } else if v == g {
            Some(u)
        } else {
            None
        };
        partner.map_or([0.0; 4], |idx| coeff_abu[idx])
    }

    /// Accumulates a direction-weighted, signed combination of `coeff_bbu`
    /// rows into `out`.
    ///
    /// With `g = primary.g`:
    /// * `u == g && v == g`: for every index `c` with `dir[c] != 0`, adds
    ///   `sign * coeff_bbu[c][k] * dir[c]` to each `out[k]`.
    /// * exactly one of `u`, `v` equals `g`: reads `dir[g]` and, if it is
    ///   non-zero, adds `sign * coeff_bbu[other][k] * dir[g]`, where `other`
    ///   is the index that is not `g`.
    /// * neither equals `g`: leaves `out` unchanged.
    pub(crate) fn add_cell_pair_third_coeff_dir(
        &self,
        primary: &FlexPrimarySlices,
        coeff_bbu: &[[f64; 4]],
        u: usize,
        v: usize,
        dir: &Array1<f64>,
        sign: f64,
        out: &mut [f64; 4],
    ) {
        let g = primary.g;
        // Adds `sign * row[k] * weight` to every slot of `out`.
        let mut accumulate = |row: &[f64; 4], weight: f64| {
            for (slot, &coeff) in out.iter_mut().zip(row.iter()) {
                *slot += sign * coeff * weight;
            }
        };
        match (u == g, v == g) {
            (true, true) => {
                let active = dir.iter().enumerate().filter(|&(_, &weight)| weight != 0.0);
                for (c, &weight) in active {
                    accumulate(&coeff_bbu[c], weight);
                }
            }
            (true, false) | (false, true) => {
                let other = if u == g { v } else { u };
                let weight = dir[g];
                if weight != 0.0 {
                    accumulate(&coeff_bbu[other], weight);
                }
            }
            (false, false) => {}
        }
    }

    /// Evaluates one entry `(u, v)` of a second-partial quantity at `z_obs`.
    ///
    /// With `g = primary.g`:
    /// * `(g, g)`: `obs.dc_dbb` evaluated at `z_obs`.
    /// * `(g, v)` with `v` inside `primary.h`: the coefficients returned by
    ///   `observed_score_basis_coefficients(row, v - h.start, z_obs, 1.0)`,
    ///   scaled by `probit_frailty_scale()` and evaluated at `z_obs`.
    /// * `(g, v)` with `v` inside `primary.w`: the second component of
    ///   `link_basis_cell_coefficient_partials` for the basis span at
    ///   `u_obs`, scaled the same way and evaluated at `z_obs`.
    /// * `(u, g)` with `u != g`: the value for the swapped pair `(g, u)`.
    /// * anything else: `0.0`.
    ///
    /// # Errors
    /// Returns `"missing survival link runtime"` when a `w` entry is requested
    /// but `self.link_dev` is `None`, and propagates errors from
    /// `observed_score_basis_coefficients` and `basis_cubic_at`.
    pub(crate) fn observed_fixed_eta_second_partial(
        &self,
        primary: &FlexPrimarySlices,
        obs: &ObservedDenestedCellPartials,
        row: usize,
        u: usize,
        v: usize,
        z_obs: f64,
        u_obs: f64,
        a: f64,
        b: f64,
    ) -> Result<f64, String> {
        let scale = self.probit_frailty_scale();
        let g = primary.g;

        if u != g {
            // Only pairs that involve `g` are non-zero; mirror `(u, g)` onto `(g, u)`.
            return if v == g {
                self.observed_fixed_eta_second_partial(primary, obs, row, v, u, z_obs, u_obs, a, b)
            } else {
                Ok(0.0)
            };
        }
        if v == g {
            return Ok(eval_coeff4_at(&obs.dc_dbb, z_obs));
        }

        let h_hit = primary
            .h
            .as_ref()
            .filter(|range| v >= range.start && v < range.end);
        if let Some(h_range) = h_hit {
            let local_idx = v - h_range.start;
            let raw = self.observed_score_basis_coefficients(row, local_idx, z_obs, 1.0)?;
            let scaled = scale_coeff4(raw, scale);
            return Ok(eval_coeff4_at(&scaled, z_obs));
        }

        let w_hit = primary
            .w
            .as_ref()
            .filter(|range| v >= range.start && v < range.end);
        if let Some(w_range) = w_hit {
            let local_idx = v - w_range.start;
            let Some(runtime) = self.link_dev.as_ref() else {
                return Err("missing survival link runtime".to_string());
            };
            let basis_span = runtime.basis_cubic_at(local_idx, u_obs)?;
            let (_, dc_bw) = exact_kernel::link_basis_cell_coefficient_partials(basis_span, a, b);
            let scaled = scale_coeff4(dc_bw, scale);
            return Ok(eval_coeff4_at(&scaled, z_obs));
        }

        Ok(0.0)
    }

    /// Evaluates one entry `(u, v)` of a second-partial quantity at `z_obs`,
    /// using the `dc_dabb` / `link_basis_cell_second_partials` coefficients.
    ///
    /// With `g = primary.g`:
    /// * `(g, g)`: `obs.dc_dabb` evaluated at `z_obs`.
    /// * `(g, v)` with `v` inside `primary.w`: the middle component of
    ///   `link_basis_cell_second_partials` for the basis span at `u_obs`,
    ///   scaled by `probit_frailty_scale()` and evaluated at `z_obs`.
    /// * `(u, g)` with `u != g`: the value for the swapped pair `(g, u)`.
    /// * anything else: `0.0`.
    ///
    /// # Errors
    /// Returns `"missing survival link runtime"` when a `w` entry is requested
    /// but `self.link_dev` is `None`, and propagates errors from
    /// `basis_cubic_at`.
    pub(crate) fn observed_fixed_chi_second_partial(
        &self,
        primary: &FlexPrimarySlices,
        obs: &ObservedDenestedCellPartials,
        u: usize,
        v: usize,
        z_obs: f64,
        u_obs: f64,
        a: f64,
        b: f64,
    ) -> Result<f64, String> {
        let scale = self.probit_frailty_scale();
        let g = primary.g;

        if u != g {
            // Only pairs that involve `g` are non-zero; mirror `(u, g)` onto `(g, u)`.
            return if v == g {
                self.observed_fixed_chi_second_partial(primary, obs, v, u, z_obs, u_obs, a, b)
            } else {
                Ok(0.0)
            };
        }
        if v == g {
            return Ok(eval_coeff4_at(&obs.dc_dabb, z_obs));
        }

        let w_hit = primary
            .w
            .as_ref()
            .filter(|range| v >= range.start && v < range.end);
        if let Some(w_range) = w_hit {
            let local_idx = v - w_range.start;
            let Some(runtime) = self.link_dev.as_ref() else {
                return Err("missing survival link runtime".to_string());
            };
            let basis_span = runtime.basis_cubic_at(local_idx, u_obs)?;
            let (_, dc_abw, _) = exact_kernel::link_basis_cell_second_partials(basis_span, a, b);
            let scaled = scale_coeff4(dc_abw, scale);
            return Ok(eval_coeff4_at(&scaled, z_obs));
        }

        Ok(0.0)
    }
}