use crate::families::lognormal_kernel::{FrailtySpec, HazardLoading};
use crate::families::survival_location_scale::TimeBlockInput;
use ndarray::{Array1, ArrayView2};
/// Result of a model's frailty policy (`LatentIntervalModel::frailty_policy`).
///
/// `validate_latent_interval_inputs` returns `sigma` to its caller on success and uses
/// `loading` when checking each row's unloaded hazard components.
#[derive(Clone, Copy, Debug)]
pub struct LatentFrailtyResolution {
/// Optional `sigma` produced by the frailty policy; returned unchanged by
/// `validate_latent_interval_inputs` when validation succeeds.
pub sigma: Option<f64>,
/// Hazard loading passed to `validate_unloaded_components_for_loading` for every row.
pub loading: HazardLoading,
}
/// Borrowed view of the per-row inputs checked by `validate_latent_interval_inputs`.
///
/// Every per-row array (including `unloaded_hazard_exit` when present) must have exactly
/// as many entries as the data matrix has rows.
pub struct LatentIntervalRowView<'a> {
/// Frailty specification handed to the model's `frailty_policy`.
pub frailty: &'a FrailtySpec,
/// Per-row entry ages; must be finite and `>= 0`.
pub age_entry: &'a Array1<f64>,
/// Per-row exit ages; must be finite and `>=` the matching entry age.
pub age_exit: &'a Array1<f64>,
/// Per-row event indicator; only 0 and 1 are accepted.
pub event_target: &'a Array1<u8>,
/// Per-row weights; must be finite and non-negative.
pub weights: &'a Array1<f64>,
/// Per-row unloaded mass at entry; must be finite and non-negative.
pub unloaded_mass_entry: &'a Array1<f64>,
/// Per-row unloaded mass at exit; must be finite and `>=` the entry mass.
pub unloaded_mass_exit: &'a Array1<f64>,
/// Optional per-row unloaded hazard at exit; when present, must be finite and non-negative.
pub unloaded_hazard_exit: Option<&'a Array1<f64>>,
/// Per-row mean offset; only its length is validated here.
pub mean_offset: &'a Array1<f64>,
/// Guard value; must be finite and `>= 0`.
pub derivative_guard: f64,
/// Time-block designs and offsets; row and column counts are validated against the data.
pub time_block: &'a TimeBlockInput,
}
/// Model-specific hooks used by `validate_latent_interval_inputs`.
pub trait LatentIntervalModel {
/// Label prefixed to every validation error message.
fn context() -> &'static str;
/// Resolves (or rejects) a frailty specification into a `sigma` and a hazard loading.
fn frailty_policy(
frailty: &FrailtySpec,
) -> Result<LatentFrailtyResolution, crate::families::latent_survival::LatentSurvivalError>;
}
pub fn validate_latent_interval_inputs<M: LatentIntervalModel>(
data: ArrayView2<'_, f64>,
row: &LatentIntervalRowView<'_>,
) -> Result<Option<f64>, crate::families::latent_survival::LatentSurvivalError> {
use crate::families::latent_survival::{
LatentSurvivalError, validate_unloaded_components_for_loading,
};
let context = M::context();
let resolution = M::frailty_policy(row.frailty)?;
let LatentFrailtyResolution { sigma, loading } = resolution;
let n = data.nrows();
if n == 0 {
return Err(LatentSurvivalError::InvalidDataset {
reason: format!("{context} requires a non-empty dataset"),
});
}
let hazard_lengths_match = match row.unloaded_hazard_exit {
Some(hazard) => hazard.len() == n,
None => true,
};
if row.age_entry.len() != n
|| row.age_exit.len() != n
|| row.event_target.len() != n
|| row.weights.len() != n
|| row.unloaded_mass_entry.len() != n
|| row.unloaded_mass_exit.len() != n
|| !hazard_lengths_match
|| row.mean_offset.len() != n
{
return Err(LatentSurvivalError::InvalidDataset {
reason: size_mismatch_reason(context, n, row),
});
}
if !row.derivative_guard.is_finite() || row.derivative_guard < 0.0 {
return Err(LatentSurvivalError::InvalidDataset {
reason: format!(
"{context} derivative_guard must be finite and >= 0, got {}",
row.derivative_guard
),
});
}
for i in 0..n {
let entry = row.age_entry[i];
let exit = row.age_exit[i];
let event = row.event_target[i];
let weight = row.weights[i];
let unloaded_entry = row.unloaded_mass_entry[i];
let unloaded_exit = row.unloaded_mass_exit[i];
let unloaded_hazard = row.unloaded_hazard_exit.map(|hazard| hazard[i]);
if !entry.is_finite() || !exit.is_finite() {
return Err(LatentSurvivalError::InvalidDataset {
reason: format!(
"{context} row {} has non-finite entry/exit ages: entry={}, exit={}",
i + 1,
entry,
exit
),
});
}
if entry < 0.0 || exit < entry {
return Err(LatentSurvivalError::InvalidDataset {
reason: format!(
"{context} row {} has invalid delayed-entry bounds: entry={}, exit={}",
i + 1,
entry,
exit
),
});
}
if event > 1 {
return Err(LatentSurvivalError::InvalidDataset {
reason: format!(
"{context} row {} has invalid event target {}; expected 0 or 1",
i + 1,
event
),
});
}
if !weight.is_finite() || weight < 0.0 {
return Err(LatentSurvivalError::InvalidDataset {
reason: format!(
"{context} row {} has invalid weight {}; expected a finite non-negative weight",
i + 1,
weight
),
});
}
let masses_invalid = !unloaded_entry.is_finite()
|| !unloaded_exit.is_finite()
|| unloaded_entry < 0.0
|| unloaded_exit < unloaded_entry;
let hazard_invalid =
unloaded_hazard.is_some_and(|hazard| !hazard.is_finite() || hazard < 0.0);
if masses_invalid || hazard_invalid {
return Err(LatentSurvivalError::InvalidDataset {
reason: unloaded_decomposition_reason(
context,
i,
unloaded_entry,
unloaded_exit,
unloaded_hazard,
),
});
}
validate_unloaded_components_for_loading(
context,
i,
loading,
unloaded_entry,
unloaded_exit,
unloaded_hazard,
)?;
}
validate_latent_interval_time_block(context, n, row.time_block)?;
Ok(sigma)
}
/// Builds the error message listing the length of every per-row input next to the
/// data row count `n`. The unloaded-hazard length appears only when a hazard is supplied.
fn size_mismatch_reason(context: &str, n: usize, row: &LatentIntervalRowView<'_>) -> String {
    let entry = row.age_entry.len();
    let exit = row.age_exit.len();
    let event = row.event_target.len();
    let weights = row.weights.len();
    let mass_entry = row.unloaded_mass_entry.len();
    let mass_exit = row.unloaded_mass_exit.len();
    let offset = row.mean_offset.len();

    if let Some(hazard) = row.unloaded_hazard_exit {
        return format!(
            "{context} size mismatch: data has {n} rows, entry={}, exit={}, event={}, weights={}, unloaded_entry={}, unloaded_exit={}, unloaded_hazard={}, offset={}",
            entry,
            exit,
            event,
            weights,
            mass_entry,
            mass_exit,
            hazard.len(),
            offset
        );
    }

    format!(
        "{context} size mismatch: data has {n} rows, entry={}, exit={}, event={}, weights={}, unloaded_entry={}, unloaded_exit={}, offset={}",
        entry, exit, event, weights, mass_entry, mass_exit, offset
    )
}
/// Formats the error message for a row whose unloaded masses (or hazard) failed validation.
///
/// `row_index` is zero-based and is reported one-based. The message names a "hazard
/// decomposition" and includes the exit hazard when `unloaded_hazard` is `Some`; otherwise
/// it names a "mass decomposition".
fn unloaded_decomposition_reason(
    context: &str,
    row_index: usize,
    unloaded_entry: f64,
    unloaded_exit: f64,
    unloaded_hazard: Option<f64>,
) -> String {
    let row_number = row_index + 1;
    let Some(hazard) = unloaded_hazard else {
        return format!(
            "{context} row {} has invalid unloaded mass decomposition: entry_mass={}, exit_mass={}",
            row_number, unloaded_entry, unloaded_exit
        );
    };
    format!(
        "{context} row {} has invalid unloaded hazard decomposition: entry_mass={}, exit_mass={}, exit_hazard={}",
        row_number, unloaded_entry, unloaded_exit, hazard
    )
}
/// Checks that the time-block design matrices and offsets are consistent with `n` rows.
///
/// The check order is as follows:
/// - The entry, exit and derivative design matrices must each have `n` rows.
/// - The entry and derivative matrices must have as many columns as the exit matrix.
/// - The three offset vectors must each have length `n`.
///
/// # Errors
///
/// Returns `LatentSurvivalError::InvalidDataset`, prefixed with `context`, describing the
/// first mismatch found (rows, then columns, then offsets).
fn validate_latent_interval_time_block(
    context: &str,
    n: usize,
    time_block: &TimeBlockInput,
) -> Result<(), crate::families::latent_survival::LatentSurvivalError> {
    use crate::families::latent_survival::LatentSurvivalError;

    let entry_rows = time_block.design_entry.nrows();
    let exit_rows = time_block.design_exit.nrows();
    let derivative_rows = time_block.design_derivative_exit.nrows();
    if [entry_rows, exit_rows, derivative_rows]
        .iter()
        .any(|&rows| rows != n)
    {
        return Err(LatentSurvivalError::InvalidDataset {
            reason: format!(
                "{context} time block row mismatch: n={}, entry_rows={}, exit_rows={}, derivative_rows={}",
                n, entry_rows, exit_rows, derivative_rows
            ),
        });
    }

    // The exit design defines the expected time-basis width.
    let entry_cols = time_block.design_entry.ncols();
    let exit_cols = time_block.design_exit.ncols();
    let derivative_cols = time_block.design_derivative_exit.ncols();
    if entry_cols != exit_cols || derivative_cols != exit_cols {
        return Err(LatentSurvivalError::InvalidDataset {
            reason: format!(
                "{context} time block column mismatch: entry_cols={}, exit_cols={}, derivative_cols={}",
                entry_cols, exit_cols, derivative_cols
            ),
        });
    }

    let entry_offset = time_block.offset_entry.len();
    let exit_offset = time_block.offset_exit.len();
    let derivative_offset = time_block.derivative_offset_exit.len();
    if [entry_offset, exit_offset, derivative_offset]
        .iter()
        .any(|&len| len != n)
    {
        return Err(LatentSurvivalError::InvalidDataset {
            reason: format!(
                "{context} time block offset mismatch: n={}, entry_offset={}, exit_offset={}, derivative_offset={}",
                n, entry_offset, exit_offset, derivative_offset
            ),
        });
    }

    Ok(())
}