//! survival 1.0.17
//!
//! A high-performance survival analysis library written in Rust, with Python bindings.
use chrono::{NaiveDate, Duration};
use std::cmp::Ordering;

/// How a rate-table dimension maps a subject's value to a cell coordinate.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum DimensionKind {
    /// The value itself (truncated toward zero) is the cell coordinate, and it
    /// stays fixed during follow-up.
    Discrete,
    /// The value is located among the dimension's cutpoints and advances
    /// together with follow-up time.
    Continuous,
}

/// One axis of a rate table.
#[derive(Debug, Clone, PartialEq)]
struct Dimension {
    /// Whether values on this axis are discrete codes or continuous.
    kind: DimensionKind,
    /// Cell boundaries for a continuous dimension. They must be in increasing
    /// order, because cells are found by binary search. Not read for discrete
    /// dimensions.
    cutpoints: Vec<f64>,
}

/// A multi-dimensional table of hazard rates.
///
/// `rates` is stored flattened. The cell with coordinates `[c0, c1, ..., ck]`
/// lives at the offset produced by `compute_linear_index`. This is row-major
/// order, with the last dimension varying fastest.
#[derive(Debug, Clone)]
struct Ratetable {
    /// Number of cells along each dimension.
    dim_sizes: Vec<usize>,
    /// Description of each dimension; expected to be parallel to `dim_sizes`.
    dimensions: Vec<Dimension>,
    /// Rate per unit of follow-up time for every cell, flattened row-major.
    rates: Vec<f64>,
}

/// Output of `ratewalk`: one entry per follow-up segment, in time order.
///
/// The three vectors always have the same length. Entry `k` of each one
/// describes segment `k`.
#[derive(Debug, Clone, PartialEq)]
struct RateWalkResult {
    /// Cell coordinates (one per dimension) used for the segment's rate.
    cells: Vec<Vec<usize>>,
    /// Length of the segment, in the units of the follow-up time.
    days: Vec<f64>,
    /// Hazard accrued over the segment (`length * rate`).
    hazards: Vec<f64>,
}

/// Flattens a multi-dimensional cell coordinate into an offset into a
/// row-major array, where the last dimension varies fastest.
///
/// # Errors
///
/// Returns `Err` when `cell` and `dim_sizes` differ in length, or when a
/// coordinate is not strictly below its dimension's size. Dimensions are
/// checked from last to first, so the reported dimension is the
/// highest-numbered offending one.
fn compute_linear_index(cell: &[usize], dim_sizes: &[usize]) -> Result<usize, String> {
    if cell.len() != dim_sizes.len() {
        return Err("Cell length doesn't match dimensions".to_string());
    }

    // Fold from the innermost (last) dimension outwards, carrying the running
    // offset and the stride of the dimension currently being visited.
    let (offset, _stride) = cell
        .iter()
        .zip(dim_sizes)
        .enumerate()
        .rev()
        .try_fold((0usize, 1usize), |(offset, stride), (dim, (&coord, &size))| {
            if coord >= size {
                Err(format!(
                    "Cell index {} out of bounds for dimension {} (size {})",
                    coord, dim, size
                ))
            } else {
                Ok((offset + coord * stride, stride * size))
            }
        })?;

    Ok(offset)
}

fn ratewalk(start: &[f64], mut futime: f64, ratetable: &Ratable) -> Result<RateWalkResult, String> {
    if start.len() != ratetable.dimensions.len() {
        return Err("Start length doesn't match ratetable dimensions".to_string());
    }
    if futime <= 0.0 {
        return Err("futime must be positive".to_string());
    }

    const EPS: f64 = 1e-8;
    let mut current_start = start.to_vec();
    let mut cells = Vec::new();
    let mut days = Vec::new();
    let mut hazards = Vec::new();

    while futime > 0.0 {
        let mut cell = vec![0; ratetable.dimensions.len()];
        let mut edge = futime;

        for (i, dim) in ratetable.dimensions.iter().enumerate() {
            match dim.kind {
                DimensionKind::Discrete => {
                    cell[i] = current_start[i] as usize;
                }
                DimensionKind::Continuous => {
                    let adj_value = current_start[i] + EPS;
                    let count = dim.cutpoints.partition_point(|cp| cp <= &adj_value);
                    cell[i] = count;
                    
                    if count < dim.cutpoints.len() {
                        let time_to_next = dim.cutpoints[count] - current_start[i];
                        if time_to_next < edge {
                            edge = time_to_next;
                        }
                    }
                }
            }
        }

        let linear_index = compute_linear_index(&cell, &ratetable.dim_sizes)?;
        let rate = *ratetable.rates.get(linear_index).unwrap_or(&0.0);
        let hazard = edge * rate;

        days.push(edge);
        hazards.push(hazard);
        cells.push(cell.clone());

        for (i, dim) in ratetable.dimensions.iter().enumerate() {
            if let DimensionKind::Continuous = dim.kind {
                current_start[i] += edge;
            }
        }
        futime -= edge;
    }

    Ok(RateWalkResult { cells, days, hazards })
}

/// Builds a calendar date from month, day and year.
///
/// Any year below 100 (including negative values) has 1900 added, so
/// two-digit years are read as 19xx (`36` -> 1936). Other years are used as
/// given.
///
/// # Panics
///
/// Panics with the message `Invalid date` if the month/day/year combination
/// is not a valid calendar date.
fn mdy_date(m: u32, d: u32, y: i32) -> NaiveDate {
    let full_year = match y {
        short if short < 100 => short + 1900,
        long => long,
    };
    match NaiveDate::from_ymd_opt(full_year, m, d) {
        Some(date) => date,
        None => panic!("Invalid date"),
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Constant daily hazard equivalent to an annual death probability `q`,
    /// using 365.25 days per year.
    fn daily_rate(q: f64) -> f64 {
        -(1.0 - q).ln() / 365.25
    }

    /// Builds an age x sex x calendar-year table, with ages and years in days.
    /// All rates are zero except at the listed cells.
    fn build_table(populated: &[([usize; 3], f64)]) -> Ratetable {
        let epoch = NaiveDate::from_ymd_opt(1960, 1, 1).unwrap();
        // A cutpoint at every birthday from 0 to 100 years, in days.
        let age_cutpoints: Vec<f64> = (0..=100).map(|i| i as f64 * 365.25).collect();
        // A cutpoint at every 1 January from 1960 to 2000, in days since 1960-01-01.
        let year_cutpoints: Vec<f64> = (1960..=2000)
            .map(|y| (NaiveDate::from_ymd_opt(y, 1, 1).unwrap() - epoch).num_days() as f64)
            .collect();

        // ratewalk maps a continuous value to the count of cutpoints at or
        // below it (0..=len), so each continuous dimension needs len + 1 cells.
        let dim_sizes = vec![age_cutpoints.len() + 1, 2, year_cutpoints.len() + 1];
        let dimensions = vec![
            Dimension {
                kind: DimensionKind::Continuous,
                cutpoints: age_cutpoints,
            },
            Dimension {
                kind: DimensionKind::Discrete,
                cutpoints: vec![],
            },
            Dimension {
                kind: DimensionKind::Continuous,
                cutpoints: year_cutpoints,
            },
        ];

        let mut rates = vec![0.0; dim_sizes.iter().product()];
        for (cell, rate) in populated {
            rates[compute_linear_index(cell, &dim_sizes).unwrap()] = *rate;
        }

        Ratetable {
            dim_sizes,
            dimensions,
            rates,
        }
    }

    #[test]
    fn test_ratewalk_simple() {
        // Born 1936-01-01, entering 1960-09-07, followed for 200 days.
        let birth_date = mdy_date(1, 1, 36);
        let entry_date = mdy_date(9, 7, 60);
        let age_days = (entry_date - birth_date).num_days() as f64;
        let year_days =
            (entry_date - NaiveDate::from_ymd_opt(1960, 1, 1).unwrap()).num_days() as f64;
        assert_eq!(age_days, 9016.0);
        assert_eq!(year_days, 250.0);

        // Expected path:
        // * Start in age cell 25 (between the 24th and 25th birthdays) and
        //   year cell 1 (1960).
        // * The 25th birthday (day 9131.25) arrives after 115.25 days, moving
        //   to age cell 26.
        // * 1961-01-01 (year day 366) arrives 0.75 days later, moving to
        //   year cell 2.
        // * The remaining 84 days stay in [26, 1, 2].
        let rate_a = daily_rate(0.00169);
        let rate_b = daily_rate(0.00180);
        let rate_c = daily_rate(0.00175);
        let ratetable = build_table(&[
            ([25, 1, 1], rate_a),
            ([26, 1, 1], rate_b),
            ([26, 1, 2], rate_c),
        ]);

        // The discrete sex value 1.0 is used directly as cell index 1.
        let start = vec![age_days, 1.0, year_days];
        let result = ratewalk(&start, 200.0, &ratetable).unwrap();

        assert_eq!(
            result.cells,
            vec![vec![25, 1, 1], vec![26, 1, 1], vec![26, 1, 2]]
        );

        let expected_days = [115.25, 0.75, 84.0];
        assert_eq!(result.days.len(), expected_days.len());
        for (got, want) in result.days.iter().zip(expected_days.iter()) {
            assert!((got - want).abs() < 1e-6, "segment length {} != {}", got, want);
        }
        let total: f64 = result.days.iter().sum();
        assert!((total - 200.0).abs() < 1e-9);

        let expected_hazards = [115.25 * rate_a, 0.75 * rate_b, 84.0 * rate_c];
        assert_eq!(result.hazards.len(), expected_hazards.len());
        for (got, want) in result.hazards.iter().zip(expected_hazards.iter()) {
            assert!((got - want).abs() < 1e-12, "hazard {} != {}", got, want);
        }

        let expected_first = -(115.25 / 365.25) * (1.0 - 0.00169).ln();
        assert!((result.hazards[0] - expected_first).abs() < 1e-12);
    }

    #[test]
    fn test_ratewalk_rejects_invalid_input() {
        let ratetable = build_table(&[]);
        // Wrong number of start coordinates.
        assert!(ratewalk(&[0.0, 0.0], 10.0, &ratetable).is_err());
        // Non-positive follow-up time.
        assert!(ratewalk(&[0.0, 0.0, 0.0], 0.0, &ratetable).is_err());
        assert!(ratewalk(&[0.0, 0.0, 0.0], -1.0, &ratetable).is_err());
        // Discrete value 2 is outside the two-level sex dimension.
        assert!(ratewalk(&[0.0, 2.0, 0.0], 10.0, &ratetable).is_err());
    }

    #[test]
    fn test_compute_linear_index_row_major() {
        assert_eq!(compute_linear_index(&[1, 2], &[3, 4]), Ok(6));
        assert!(compute_linear_index(&[3, 0], &[3, 4]).is_err());
        assert!(compute_linear_index(&[0], &[3, 4]).is_err());
    }

    #[test]
    fn test_mdy_date_two_digit_years() {
        assert_eq!(mdy_date(9, 7, 60), NaiveDate::from_ymd_opt(1960, 9, 7).unwrap());
        assert_eq!(mdy_date(1, 1, 2001), NaiveDate::from_ymd_opt(2001, 1, 1).unwrap());
    }
}