/// Distribution family assumed for the frailty term of the model.
///
/// The enum is `Copy` and compares by variant. It is marked
/// `#[non_exhaustive]`, so code outside this crate that `match`es on it must
/// include a wildcard arm; new families can then be added without a
/// breaking change.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum FrailtyDistribution {
    /// Gamma-distributed frailty.
    Gamma,
    /// Log-normal frailty.
    LogNormal,
    /// Inverse-Gaussian frailty.
    InverseGaussian,
}
/// Settings that control how a frailty model is fitted.
///
/// Start from [`FrailtyConfig::default`] and override individual fields with
/// struct-update syntax, e.g.
/// `FrailtyConfig { max_iterations: 500, ..Default::default() }`.
#[derive(Clone, Debug)]
pub struct FrailtyConfig {
    /// Frailty distribution family to assume.
    pub distribution: FrailtyDistribution,
    /// Iteration cap for the fitting routine (presumably an upper bound; the
    /// routine that reads it is not shown here).
    pub max_iterations: usize,
    /// Convergence tolerance. How it is compared (absolute vs. relative, and
    /// on which quantity) is decided by the fitting code.
    pub tolerance: f64,
    /// Starting value for the frailty variance.
    pub initial_variance: f64,
}

impl Default for FrailtyConfig {
    /// Gamma frailty, at most 200 iterations, tolerance `1e-6`, and an
    /// initial frailty variance of `1.0`.
    fn default() -> Self {
        const DEFAULT_MAX_ITERATIONS: usize = 200;
        const DEFAULT_TOLERANCE: f64 = 1e-6;
        const DEFAULT_INITIAL_VARIANCE: f64 = 1.0;

        FrailtyConfig {
            distribution: FrailtyDistribution::Gamma,
            max_iterations: DEFAULT_MAX_ITERATIONS,
            tolerance: DEFAULT_TOLERANCE,
            initial_variance: DEFAULT_INITIAL_VARIANCE,
        }
    }
}
/// Output of a frailty-model fit.
///
/// This type is a plain data carrier. The code that fills it in is not part of
/// this definition, so the per-field notes below marked "presumably" should be
/// confirmed against the producer.
#[derive(Clone, Debug)]
pub struct FrailtyResult {
    /// Estimated regression coefficients.
    pub coefficients: Vec<f64>,
    /// Estimated variance of the frailty distribution.
    pub frailty_variance: f64,
    /// Frailty estimates; presumably one per cluster, in cluster order.
    /// TODO confirm.
    pub frailty_estimates: Vec<f64>,
    /// Log-likelihood values recorded during fitting; presumably one entry per
    /// iteration, oldest first.
    pub log_likelihood_history: Vec<f64>,
    /// Whether the fit reported convergence.
    pub converged: bool,
    /// Number of iterations performed.
    pub iterations: usize,
    /// Baseline hazard as `(f64, f64)` pairs; presumably `(time, hazard)`.
    /// TODO confirm.
    pub baseline_hazard: Vec<(f64, f64)>,
}
/// Membership list and event count for one cluster of subjects.
#[derive(Clone, Debug)]
pub struct ClusterInfo {
    /// Identifier of this cluster.
    pub cluster_id: usize,
    /// Indices of the subjects that belong to this cluster.
    pub subject_indices: Vec<usize>,
    /// Number of entries in `subject_indices` whose position in the `events`
    /// slice given to [`ClusterInfo::new`] holds `true`.
    pub n_events: usize,
}

impl ClusterInfo {
    /// Builds a cluster record and tallies its events.
    ///
    /// `n_events` counts the indices in `subject_indices` whose entry in
    /// `events` is `true`. Indices at or beyond `events.len()` are skipped
    /// instead of panicking. A repeated index is counted once per occurrence.
    pub fn new(cluster_id: usize, subject_indices: Vec<usize>, events: &[bool]) -> Self {
        // `get` returns `None` for out-of-range indices, which silently drops them.
        let n_events = subject_indices
            .iter()
            .filter_map(|&idx| events.get(idx))
            .filter(|&&had_event| had_event)
            .count();

        ClusterInfo {
            cluster_id,
            subject_indices,
            n_events,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_frailty_config_default() {
        let config = FrailtyConfig::default();
        // Float fields are checked against a tight absolute tolerance.
        let close = |actual: f64, expected: f64| (actual - expected).abs() < 1e-15;
        assert_eq!(config.distribution, FrailtyDistribution::Gamma);
        assert_eq!(config.max_iterations, 200);
        assert!(close(config.tolerance, 1e-6));
        assert!(close(config.initial_variance, 1.0));
    }

    #[test]
    fn test_cluster_info_event_count() {
        let events = [true, false, true, false, true];
        let cluster = ClusterInfo::new(0, vec![0, 2, 4], &events);
        assert_eq!(cluster.n_events, 3);
        assert_eq!(cluster.cluster_id, 0);
    }

    #[test]
    fn test_cluster_info_out_of_bounds_indices() {
        // Index 5 lies past the end of `events`; it must be ignored rather than panic.
        let cluster = ClusterInfo::new(1, vec![0, 5], &[true, false]);
        assert_eq!(cluster.n_events, 1);
    }

    #[test]
    fn test_frailty_distribution_variants() {
        let variants = [
            FrailtyDistribution::Gamma,
            FrailtyDistribution::LogNormal,
            FrailtyDistribution::InverseGaussian,
        ];
        // Every unordered pair of distinct variants must compare unequal.
        for (pos, left) in variants.iter().enumerate() {
            for right in &variants[pos + 1..] {
                assert_ne!(left, right);
            }
        }
    }
}