use ndarray::Array3;
use crate::linalg::matrix::DesignMatrix;
/// Numeric identifier for a coefficient block.
///
/// `SurvivalBlock::tag` and `BernoulliBlock::tag` build tags from their enum
/// discriminants. Both enums start at 0 (`SurvivalBlock::Time` and
/// `BernoulliBlock::Marginal` both map to tag 0), so a tag is only meaningful
/// to the Jacobian of the family that produced it.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub struct BlockTag(pub u32);
/// Describes how a family's coefficient blocks feed its primary channels.
///
/// Implemented in this file by `SurvivalPrimaryJacobian` and
/// `BernoulliPrimaryJacobian`.
pub trait BlockPrimaryJacobian {
/// Number of primary channels. Both implementations in this file return a
/// fixed per-family `COUNT` constant.
fn n_channels(&self) -> usize;
/// Number of rows. Both implementations in this file return a stored
/// `n_rows` field.
fn n_rows(&self) -> usize;
/// Per-channel design contributions of `block`.
///
/// Both implementations in this file return one slot per channel (the
/// vector has length `n_channels()`), leave `None` in channels the block
/// does not feed, and return an all-`None` vector for tags they do not
/// recognize.
fn channel_contributions(&self, block: BlockTag) -> Vec<Option<DesignMatrix>>;
/// Row-wise channel metric.
///
/// NOTE(review): the Bernoulli implementation returns shape
/// `(n_rows, 1, 1)`; the general contract is presumably
/// `(n_rows, n_channels, n_channels)` — confirm against callers.
fn row_channel_metric(&self) -> Array3<f64>;
}
/// Primary channels of the survival family.
///
/// The explicit discriminants double as slot positions in the vector returned
/// by `SurvivalPrimaryJacobian::channel_contributions`, so they must stay
/// dense and start at 0.
#[repr(u8)]
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub enum SurvivalPrimaryChannel {
    EntryLocation = 0,
    ExitLocation = 1,
    ExitDerivative = 2,
    Logslope = 3,
    EtaScalar = 4,
}

impl SurvivalPrimaryChannel {
    /// Number of channels; must match the number of enum variants.
    pub const COUNT: usize = 5;

    /// Every channel, listed in ascending discriminant order, so that
    /// `ALL[c.index()] == c` for each channel `c`.
    pub const ALL: [SurvivalPrimaryChannel; Self::COUNT] = [
        Self::EntryLocation,
        Self::ExitLocation,
        Self::ExitDerivative,
        Self::Logslope,
        Self::EtaScalar,
    ];

    /// Zero-based slot position of this channel (its discriminant).
    #[inline]
    pub const fn index(self) -> usize {
        // Go through the declared `u8` representation, then widen.
        let discriminant = self as u8;
        discriminant as usize
    }
}
/// Coefficient blocks of the survival family.
///
/// The `u32` discriminant of each variant is the raw value of the `BlockTag`
/// returned by [`SurvivalBlock::tag`].
#[repr(u32)]
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub enum SurvivalBlock {
    Time = 0,
    Marginal = 1,
    Logslope = 2,
    ScoreWarp = 3,
    LinkDev = 4,
}

impl SurvivalBlock {
    /// Wraps this block's discriminant in a `BlockTag`.
    #[inline]
    pub const fn tag(self) -> BlockTag {
        let raw = self as u32;
        BlockTag(raw)
    }
}
/// Designs and row metric backing the survival family's
/// `BlockPrimaryJacobian` implementation.
///
/// NOTE(review): nothing in this file checks that the designs or `row_metric`
/// actually have `n_rows` rows; that is presumably the constructor's job.
#[derive(Clone, Debug)]
pub struct SurvivalPrimaryJacobian {
/// Value reported by `n_rows()`.
pub n_rows: usize,
/// `Time` block design routed to the `EntryLocation` channel.
pub time_design_entry: DesignMatrix,
/// `Time` block design routed to the `ExitLocation` channel.
pub time_design_exit: DesignMatrix,
/// `Time` block design routed to the `ExitDerivative` channel.
pub time_design_derivative_exit: DesignMatrix,
/// `Marginal` block design; routed to both `EntryLocation` and `ExitLocation`.
pub marginal_design: DesignMatrix,
/// `Logslope` block design; routed to the `Logslope` channel.
pub logslope_design: DesignMatrix,
/// Optional `ScoreWarp` block design; routed to `EtaScalar` when present.
pub score_warp_design: Option<DesignMatrix>,
/// Optional `LinkDev` block design; routed to `EtaScalar` when present.
pub link_dev_design: Option<DesignMatrix>,
/// Metric returned (as a clone) by `row_channel_metric()`.
/// The unit test in this file builds it with shape
/// `(n_rows, COUNT, COUNT)` — presumably the expected shape; confirm.
pub row_metric: Array3<f64>,
}
impl BlockPrimaryJacobian for SurvivalPrimaryJacobian {
#[inline]
fn n_channels(&self) -> usize {
SurvivalPrimaryChannel::COUNT
}
#[inline]
fn n_rows(&self) -> usize {
self.n_rows
}
fn channel_contributions(&self, block: BlockTag) -> Vec<Option<DesignMatrix>> {
let mut out: Vec<Option<DesignMatrix>> =
(0..SurvivalPrimaryChannel::COUNT).map(|_| None).collect();
match block {
t if t == SurvivalBlock::Time.tag() => {
out[SurvivalPrimaryChannel::EntryLocation.index()] =
Some(self.time_design_entry.clone());
out[SurvivalPrimaryChannel::ExitLocation.index()] =
Some(self.time_design_exit.clone());
out[SurvivalPrimaryChannel::ExitDerivative.index()] =
Some(self.time_design_derivative_exit.clone());
}
t if t == SurvivalBlock::Marginal.tag() => {
out[SurvivalPrimaryChannel::EntryLocation.index()] =
Some(self.marginal_design.clone());
out[SurvivalPrimaryChannel::ExitLocation.index()] =
Some(self.marginal_design.clone());
}
t if t == SurvivalBlock::Logslope.tag() => {
out[SurvivalPrimaryChannel::Logslope.index()] = Some(self.logslope_design.clone());
}
t if t == SurvivalBlock::ScoreWarp.tag() => {
if let Some(d) = self.score_warp_design.as_ref() {
out[SurvivalPrimaryChannel::EtaScalar.index()] = Some(d.clone());
}
}
t if t == SurvivalBlock::LinkDev.tag() => {
if let Some(d) = self.link_dev_design.as_ref() {
out[SurvivalPrimaryChannel::EtaScalar.index()] = Some(d.clone());
}
}
_ => {}
}
out
}
#[inline]
fn row_channel_metric(&self) -> Array3<f64> {
self.row_metric.clone()
}
}
/// Primary channels of the Bernoulli family: a single `EtaScalar` channel.
///
/// The discriminant doubles as the slot position in the vector returned by
/// `BernoulliPrimaryJacobian::channel_contributions`.
#[repr(u8)]
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub enum BernoulliPrimaryChannel {
    EtaScalar = 0,
}

impl BernoulliPrimaryChannel {
    /// Number of channels; must match the number of enum variants.
    pub const COUNT: usize = 1;

    /// Every channel in ascending discriminant order, mirroring
    /// `SurvivalPrimaryChannel::ALL` so generic code can enumerate the
    /// channels of either family the same way.
    pub const ALL: [BernoulliPrimaryChannel; Self::COUNT] = [Self::EtaScalar];

    /// Zero-based slot position of this channel (its discriminant).
    #[inline]
    pub const fn index(self) -> usize {
        self as usize
    }
}
/// Coefficient blocks of the Bernoulli family.
///
/// The `u32` discriminant of each variant is the raw value of the `BlockTag`
/// returned by [`BernoulliBlock::tag`]. The numbering differs from
/// `SurvivalBlock` (there is no `Time` block here), so tags from the two
/// families must not be mixed.
#[repr(u32)]
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub enum BernoulliBlock {
    Marginal = 0,
    Logslope = 1,
    ScoreWarp = 2,
    LinkDev = 3,
}

impl BernoulliBlock {
    /// Wraps this block's discriminant in a `BlockTag`.
    #[inline]
    pub const fn tag(self) -> BlockTag {
        let raw = self as u32;
        BlockTag(raw)
    }
}
/// Designs and per-row weights backing the Bernoulli family's
/// `BlockPrimaryJacobian` implementation.
#[derive(Clone, Debug)]
pub struct BernoulliPrimaryJacobian {
/// Value reported by `n_rows()`; `row_channel_metric()` panics unless
/// `pilot_irls_w` has exactly this length.
pub n_rows: usize,
/// `Marginal` block design; routed to the single `EtaScalar` channel.
pub marginal_design: DesignMatrix,
/// `Logslope` block design; routed to the single `EtaScalar` channel.
pub logslope_design: DesignMatrix,
/// Optional `ScoreWarp` block design; routed to `EtaScalar` when present.
pub score_warp_design: Option<DesignMatrix>,
/// Optional `LinkDev` block design; routed to `EtaScalar` when present.
pub link_dev_design: Option<DesignMatrix>,
/// Per-row weights copied verbatim into the `(n_rows, 1, 1)` metric.
/// Presumably IRLS working weights from a pilot fit — confirm with the caller.
pub pilot_irls_w: ndarray::Array1<f64>,
}
/// Channel routing for the Bernoulli family.
impl BlockPrimaryJacobian for BernoulliPrimaryJacobian {
    /// Always `BernoulliPrimaryChannel::COUNT`.
    #[inline]
    fn n_channels(&self) -> usize {
        BernoulliPrimaryChannel::COUNT
    }

    /// Returns the stored `n_rows` field.
    #[inline]
    fn n_rows(&self) -> usize {
        self.n_rows
    }

    /// Returns a one-slot vector holding a clone of the design `block`
    /// contributes to the `EtaScalar` channel.
    ///
    /// Every known block feeds that single channel. The slot is `None` when
    /// the block's optional design is absent or the tag is not a
    /// `BernoulliBlock` tag.
    fn channel_contributions(&self, block: BlockTag) -> Vec<Option<DesignMatrix>> {
        // Pick the design first; cloning happens once, at the end.
        let selected: Option<&DesignMatrix> = if block == BernoulliBlock::Marginal.tag() {
            Some(&self.marginal_design)
        } else if block == BernoulliBlock::Logslope.tag() {
            Some(&self.logslope_design)
        } else if block == BernoulliBlock::ScoreWarp.tag() {
            self.score_warp_design.as_ref()
        } else if block == BernoulliBlock::LinkDev.tag() {
            self.link_dev_design.as_ref()
        } else {
            None
        };

        let mut slots: Vec<Option<DesignMatrix>> = vec![None; BernoulliPrimaryChannel::COUNT];
        slots[BernoulliPrimaryChannel::EtaScalar.index()] = selected.cloned();
        slots
    }

    /// Builds the `(n_rows, 1, 1)` metric whose `[i, 0, 0]` entry is
    /// `pilot_irls_w[i]`.
    ///
    /// # Panics
    ///
    /// Panics if `pilot_irls_w.len() != n_rows`.
    fn row_channel_metric(&self) -> Array3<f64> {
        let expected_rows = self.n_rows;
        assert_eq!(
            self.pilot_irls_w.len(),
            expected_rows,
            "BernoulliPrimaryJacobian: pilot_irls_w length {} does not match n_rows {}",
            self.pilot_irls_w.len(),
            expected_rows,
        );
        Array3::from_shape_fn((expected_rows, 1, 1), |(row, _, _)| self.pilot_irls_w[row])
    }
}
/// A `RowJacobianOperator` that can also report which block it belongs to.
///
/// Implemented in this file by `ScoreWarpEffectiveJacobian` and
/// `LinkDevEffectiveJacobian`.
pub trait BlockEffectiveJacobian:
crate::families::identifiability_compiler::RowJacobianOperator
{
/// Label of the block. The implementations in this file return fixed
/// string literals.
fn block_label(&self) -> &str;
}
/// Effective Jacobian for the score-warp block.
///
/// A thin wrapper around a `QChannelBlockOperator`; every operator method is
/// forwarded to the wrapped value.
pub struct ScoreWarpEffectiveJacobian {
    inner: crate::families::survival_marginal_slope_identifiability::QChannelBlockOperator,
}

impl ScoreWarpEffectiveJacobian {
    /// Wraps a `QChannelBlockOperator` built from `dq` and `dqd1`.
    ///
    /// Both arrays are handed to `QChannelBlockOperator::new` untouched.
    /// NOTE(review): their expected shapes and meaning are defined by that
    /// constructor, which is not visible from this file.
    pub fn new(dq: ndarray::Array2<f64>, dqd1: ndarray::Array2<f64>) -> Self {
        use crate::families::survival_marginal_slope_identifiability::QChannelBlockOperator;

        let inner = QChannelBlockOperator::new(dq, dqd1);
        Self { inner }
    }
}
/// Pure delegation: every method forwards its arguments unchanged to the
/// wrapped `QChannelBlockOperator` and returns its result.
impl crate::families::identifiability_compiler::RowJacobianOperator for ScoreWarpEffectiveJacobian {
/// Forwards to `inner.k()`.
fn k(&self) -> usize {
self.inner.k()
}
/// Forwards to `inner.ncols()`.
fn ncols(&self) -> usize {
self.inner.ncols()
}
/// Forwards to `inner.nrows()`.
fn nrows(&self) -> usize {
self.inner.nrows()
}
/// Forwards to `inner.apply_row(row, delta_beta, out)`; any output is written
/// into `out` by the wrapped operator.
fn apply_row(&self, row: usize, delta_beta: &[f64], out: &mut [f64]) {
self.inner.apply_row(row, delta_beta, out);
}
/// Forwards to `inner.evaluate_full()`.
fn evaluate_full(&self) -> ndarray::Array3<f64> {
self.inner.evaluate_full()
}
}
impl BlockEffectiveJacobian for ScoreWarpEffectiveJacobian {
/// Returns the fixed label `"score_warp_dev"`.
fn block_label(&self) -> &str {
"score_warp_dev"
}
}
/// Effective Jacobian for the link-deviation block.
///
/// A thin wrapper around a `QChannelBlockOperator`; every operator method is
/// forwarded to the wrapped value.
pub struct LinkDevEffectiveJacobian {
    inner: crate::families::survival_marginal_slope_identifiability::QChannelBlockOperator,
}

impl LinkDevEffectiveJacobian {
    /// Wraps a `QChannelBlockOperator` built from `dq` and `dqd1`.
    ///
    /// Both arrays are handed to `QChannelBlockOperator::new` untouched.
    /// NOTE(review): their expected shapes and meaning are defined by that
    /// constructor, which is not visible from this file.
    pub fn new(dq: ndarray::Array2<f64>, dqd1: ndarray::Array2<f64>) -> Self {
        use crate::families::survival_marginal_slope_identifiability::QChannelBlockOperator;

        let inner = QChannelBlockOperator::new(dq, dqd1);
        Self { inner }
    }
}
/// Pure delegation: every method forwards its arguments unchanged to the
/// wrapped `QChannelBlockOperator` and returns its result.
impl crate::families::identifiability_compiler::RowJacobianOperator for LinkDevEffectiveJacobian {
/// Forwards to `inner.k()`.
fn k(&self) -> usize {
self.inner.k()
}
/// Forwards to `inner.ncols()`.
fn ncols(&self) -> usize {
self.inner.ncols()
}
/// Forwards to `inner.nrows()`.
fn nrows(&self) -> usize {
self.inner.nrows()
}
/// Forwards to `inner.apply_row(row, delta_beta, out)`; any output is written
/// into `out` by the wrapped operator.
fn apply_row(&self, row: usize, delta_beta: &[f64], out: &mut [f64]) {
self.inner.apply_row(row, delta_beta, out);
}
/// Forwards to `inner.evaluate_full()`.
fn evaluate_full(&self) -> ndarray::Array3<f64> {
self.inner.evaluate_full()
}
}
impl BlockEffectiveJacobian for LinkDevEffectiveJacobian {
/// Returns the fixed label `"link_dev"`.
fn block_label(&self) -> &str {
"link_dev"
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{Array1, Array2};

    /// All-zero `rows x cols` design. The routing tests only check which
    /// channels receive a design, never the values inside it.
    fn zero_design(rows: usize, cols: usize) -> DesignMatrix {
        DesignMatrix::from(Array2::<f64>::zeros((rows, cols)))
    }

    /// Bernoulli fixture with two-column marginal/logslope designs.
    fn bernoulli_fixture(
        n_rows: usize,
        weights: Array1<f64>,
        score_warp_design: Option<DesignMatrix>,
        link_dev_design: Option<DesignMatrix>,
    ) -> BernoulliPrimaryJacobian {
        BernoulliPrimaryJacobian {
            n_rows,
            marginal_design: zero_design(n_rows, 2),
            logslope_design: zero_design(n_rows, 2),
            score_warp_design,
            link_dev_design,
            pilot_irls_w: weights,
        }
    }

    /// Survival fixture with a zero `(n, COUNT, COUNT)` row metric.
    fn survival_fixture(
        n: usize,
        score_warp_design: Option<DesignMatrix>,
        link_dev_design: Option<DesignMatrix>,
    ) -> SurvivalPrimaryJacobian {
        let (pt, pm, pg) = (3, 2, 2);
        SurvivalPrimaryJacobian {
            n_rows: n,
            time_design_entry: zero_design(n, pt),
            time_design_exit: zero_design(n, pt),
            time_design_derivative_exit: zero_design(n, pt),
            marginal_design: zero_design(n, pm),
            logslope_design: zero_design(n, pg),
            score_warp_design,
            link_dev_design,
            row_metric: Array3::<f64>::zeros((
                n,
                SurvivalPrimaryChannel::COUNT,
                SurvivalPrimaryChannel::COUNT,
            )),
        }
    }

    /// Indices of the slots that received a design.
    fn filled(contrib: &[Option<DesignMatrix>]) -> Vec<usize> {
        contrib
            .iter()
            .enumerate()
            .filter_map(|(i, slot)| slot.as_ref().map(|_| i))
            .collect()
    }

    #[test]
    fn bernoulli_metric_reduces_to_irls_w() {
        let w = Array1::from(vec![0.25_f64, 0.1875, 0.0625]);
        let bj = bernoulli_fixture(3, w.clone(), None, None);

        let metric = bj.row_channel_metric();
        assert_eq!(metric.shape(), &[3, 1, 1]);
        for i in 0..3 {
            assert!((metric[[i, 0, 0]] - w[i]).abs() < 1e-15);
        }

        let contrib = bj.channel_contributions(BernoulliBlock::ScoreWarp.tag());
        assert!(contrib[0].is_none());
        let contrib = bj.channel_contributions(BernoulliBlock::Marginal.tag());
        assert!(contrib[0].is_some());
    }

    #[test]
    fn bernoulli_channel_routing() {
        let w = Array1::from(vec![0.25_f64, 0.25]);
        let bare = bernoulli_fixture(2, w.clone(), None, None);
        assert_eq!(bare.n_channels(), 1);
        assert_eq!(bare.n_rows(), 2);

        let logslope = bare.channel_contributions(BernoulliBlock::Logslope.tag());
        assert_eq!(filled(&logslope), vec![0]);

        // Absent optional designs leave the slot empty.
        let link_dev = bare.channel_contributions(BernoulliBlock::LinkDev.tag());
        assert_eq!(link_dev.len(), 1);
        assert!(filled(&link_dev).is_empty());

        // Unknown tags produce a correctly sized, all-`None` vector.
        let unknown = bare.channel_contributions(BlockTag(u32::MAX));
        assert_eq!(unknown.len(), 1);
        assert!(filled(&unknown).is_empty());

        let full = bernoulli_fixture(2, w, Some(zero_design(2, 1)), Some(zero_design(2, 1)));
        let score_warp = full.channel_contributions(BernoulliBlock::ScoreWarp.tag());
        assert_eq!(filled(&score_warp), vec![0]);
        let link_dev = full.channel_contributions(BernoulliBlock::LinkDev.tag());
        assert_eq!(filled(&link_dev), vec![0]);
    }

    #[test]
    #[should_panic(expected = "does not match n_rows")]
    fn bernoulli_metric_rejects_length_mismatch() {
        let w = Array1::from(vec![0.25_f64, 0.25]);
        let bj = bernoulli_fixture(3, w, None, None);
        let _ = bj.row_channel_metric();
    }

    #[test]
    fn survival_channel_routing() {
        let n = 4;
        let sj = survival_fixture(n, None, None);
        assert_eq!(sj.n_channels(), 5);
        assert_eq!(sj.n_rows(), n);

        let time = sj.channel_contributions(SurvivalBlock::Time.tag());
        assert!(time[SurvivalPrimaryChannel::EntryLocation.index()].is_some());
        assert!(time[SurvivalPrimaryChannel::ExitLocation.index()].is_some());
        assert!(time[SurvivalPrimaryChannel::ExitDerivative.index()].is_some());
        assert!(time[SurvivalPrimaryChannel::Logslope.index()].is_none());
        assert!(time[SurvivalPrimaryChannel::EtaScalar.index()].is_none());

        let marg = sj.channel_contributions(SurvivalBlock::Marginal.tag());
        assert!(marg[SurvivalPrimaryChannel::EntryLocation.index()].is_some());
        assert!(marg[SurvivalPrimaryChannel::ExitLocation.index()].is_some());
        assert!(marg[SurvivalPrimaryChannel::ExitDerivative.index()].is_none());
        assert!(marg[SurvivalPrimaryChannel::Logslope.index()].is_none());
        assert!(marg[SurvivalPrimaryChannel::EtaScalar.index()].is_none());

        let logslope = sj.channel_contributions(SurvivalBlock::Logslope.tag());
        assert_eq!(
            filled(&logslope),
            vec![SurvivalPrimaryChannel::Logslope.index()]
        );
    }

    #[test]
    fn survival_optional_and_unknown_blocks() {
        let n = 4;

        // Without optional designs, ScoreWarp / LinkDev contribute nothing.
        let bare = survival_fixture(n, None, None);
        for block in [SurvivalBlock::ScoreWarp, SurvivalBlock::LinkDev] {
            let contrib = bare.channel_contributions(block.tag());
            assert_eq!(contrib.len(), SurvivalPrimaryChannel::COUNT);
            assert!(filled(&contrib).is_empty());
        }

        // With optional designs, both route to the EtaScalar channel only.
        let full = survival_fixture(n, Some(zero_design(n, 1)), Some(zero_design(n, 1)));
        for block in [SurvivalBlock::ScoreWarp, SurvivalBlock::LinkDev] {
            let contrib = full.channel_contributions(block.tag());
            assert_eq!(
                filled(&contrib),
                vec![SurvivalPrimaryChannel::EtaScalar.index()]
            );
        }

        // Unknown tags produce a correctly sized, all-`None` vector.
        let unknown = full.channel_contributions(BlockTag(u32::MAX));
        assert_eq!(unknown.len(), SurvivalPrimaryChannel::COUNT);
        assert!(filled(&unknown).is_empty());

        // The metric is handed back unchanged.
        assert_eq!(full.row_channel_metric(), full.row_metric);
    }
}