cox_hazards/model.rs

use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use crate::{
    data::SurvivalData,
    error::{CoxError, Result},
    optimization::{CoxOptimizer, OptimizationConfig, OptimizerType},
};

/// Cox proportional hazards model with elastic net regularization.
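///
/// A minimal usage sketch (illustrative only; it assumes the crate is exported as
/// `cox_hazards` with `CoxModel` and `SurvivalData` re-exported at the crate root,
/// so the block is marked `ignore`):
///
/// ```ignore
/// use cox_hazards::{CoxModel, SurvivalData};
/// use ndarray::Array2;
///
/// let times = vec![2.0, 5.0, 3.5, 7.0];
/// let events = vec![true, false, true, true];
/// let covariates = Array2::from_shape_vec((4, 2), vec![
///     0.5, 1.0,
///     -0.3, 0.0,
///     1.2, -1.0,
///     0.0, 0.5,
/// ]).unwrap();
/// let data = SurvivalData::new(times, events, covariates).unwrap();
///
/// let mut model = CoxModel::new()
///     .with_elastic_net(0.5, 0.1)
///     .with_max_iterations(500);
/// model.fit(&data).unwrap();
///
/// // hazard ratios for two new patients (same feature count as the training data)
/// let new_patients = Array2::from_shape_vec((2, 2), vec![0.8, -0.2, -0.1, 0.4]).unwrap();
/// let hazard_ratios = model.predict_hazard_ratios(new_patients.view()).unwrap();
/// ```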
#[derive(Debug, Clone)]
pub struct CoxModel {
    coefficients: Option<Array1<f64>>,  // fitted coefficients
    l1_penalty: f64,                    // lasso penalty
    l2_penalty: f64,                    // ridge penalty
    max_iterations: usize,              // optimization limit
    tolerance: f64,                     // convergence threshold
    fitted: bool,                       // have we been fit yet?
    feature_names: Option<Vec<String>>, // optional feature labels
    optimizer_type: OptimizerType,      // which optimizer to use
    learning_rate: f64,                 // learning rate for Adam/RMSprop
    beta1: f64,                         // Adam momentum parameter
    beta2: f64,                         // Adam/RMSprop decay parameter
    epsilon: f64,                       // Adam/RMSprop numerical stability
}

impl Default for CoxModel {
    fn default() -> Self {
        Self {
            coefficients: None,
            l1_penalty: 0.0,
            l2_penalty: 0.0,
            max_iterations: 1000,
            tolerance: 1e-6,
            fitted: false,
            feature_names: None,
            optimizer_type: OptimizerType::NewtonRaphson,
            learning_rate: 0.001,
            beta1: 0.9,
            beta2: 0.999,
            epsilon: 1e-8,
        }
    }
}

impl CoxModel {
    /// Creates a new Cox model with default settings.
    pub fn new() -> Self {
        Self::default()
    }

    /// Adds a lasso (L1) penalty, which encourages sparse coefficients.
    pub fn with_l1_penalty(mut self, penalty: f64) -> Self {
        self.l1_penalty = penalty.max(0.0);
        self
    }

    /// Adds a ridge (L2) penalty, which shrinks coefficients toward zero.
    pub fn with_l2_penalty(mut self, penalty: f64) -> Self {
        self.l2_penalty = penalty.max(0.0);
        self
    }

    /// Sets elastic net regularization: `alpha = 0` is pure ridge, `alpha = 1` is pure lasso.
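    ///
    /// For example, `with_elastic_net(0.25, 1.0)` sets `l1_penalty = 0.25` and
    /// `l2_penalty = 0.75`.
    ///
    /// # Panics
    ///
    /// Panics if `alpha` is outside `[0, 1]`.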
    pub fn with_elastic_net(mut self, alpha: f64, penalty: f64) -> Self {
        if alpha < 0.0 || alpha > 1.0 {
            panic!("alpha must be in [0, 1], got {}", alpha);
        }
        self.l1_penalty = alpha * penalty;         // lasso component
        self.l2_penalty = (1.0 - alpha) * penalty; // ridge component
        self
    }

    /// Sets the maximum number of optimization iterations.
    pub fn with_max_iterations(mut self, max_iter: usize) -> Self {
        self.max_iterations = max_iter;
        self
    }

    /// Sets the convergence tolerance for the optimizer.
    pub fn with_tolerance(mut self, tol: f64) -> Self {
        self.tolerance = tol;
        self
    }

    /// Attaches feature names so that summaries are easier to read.
    pub fn with_feature_names(mut self, names: Vec<String>) -> Self {
        self.feature_names = Some(names);
        self
    }

    /// Selects the optimizer used for fitting.
    pub fn with_optimizer(mut self, optimizer: OptimizerType) -> Self {
        self.optimizer_type = optimizer;
        self
    }

    /// Sets the learning rate for the Adam/RMSprop optimizers.
    pub fn with_learning_rate(mut self, lr: f64) -> Self {
        self.learning_rate = lr.max(0.0);
        self
    }

    /// Sets the Adam/RMSprop parameters (`beta1` for Adam momentum, `beta2` for the decay rate).
    pub fn with_adam_params(mut self, beta1: f64, beta2: f64) -> Self {
        self.beta1 = beta1.clamp(0.0, 1.0);
        self.beta2 = beta2.clamp(0.0, 1.0);
        self
    }

    /// Sets the Adam/RMSprop numerical stability constant.
    pub fn with_epsilon(mut self, eps: f64) -> Self {
        self.epsilon = eps.max(0.0);
        self
    }

    /// Fits the model to the survival data.
    pub fn fit(&mut self, data: &SurvivalData) -> Result<&mut Self> {
        let config = OptimizationConfig {
            l1_penalty: self.l1_penalty,
            l2_penalty: self.l2_penalty,
            max_iterations: self.max_iterations,
            tolerance: self.tolerance,
            optimizer_type: self.optimizer_type,
            learning_rate: self.learning_rate,
            beta1: self.beta1,
            beta2: self.beta2,
            epsilon: self.epsilon,
        };

        let mut optimizer = CoxOptimizer::new(config);
        self.coefficients = Some(optimizer.optimize(data)?);
        self.fitted = true;

        Ok(self)
    }

    /// Returns the fitted coefficients (betas).
    pub fn coefficients(&self) -> Result<ArrayView1<'_, f64>> {
        match &self.coefficients {
            Some(coefs) => Ok(coefs.view()),
            None => Err(CoxError::ModelNotFitted),
        }
    }

    /// Predicts risk scores (linear predictors) for new patients.
    pub fn predict(&self, covariates: ArrayView2<f64>) -> Result<Array1<f64>> {
        let coefs = self.coefficients()?;

        if covariates.ncols() != coefs.len() {
            return Err(CoxError::invalid_dimensions(
                format!("feature count mismatch: expected {}, got {}",
                        coefs.len(), covariates.ncols())
            ));
        }

        Ok(covariates.dot(&coefs)) // linear combination
    }

    /// Predicts hazard ratios, i.e. the exponential of the risk scores.
    pub fn predict_hazard_ratios(&self, covariates: ArrayView2<f64>) -> Result<Array1<f64>> {
        let linear_predictors = self.predict(covariates)?;
        Ok(linear_predictors.mapv(f64::exp))
    }

    /// Predicts survival probabilities at the given time points.
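    ///
    /// This is a simplified sketch: it assumes a constant baseline hazard of `0.1`, so each
    /// probability is computed as `exp(-0.1 * exp(risk_score) * t)`. A full implementation
    /// would estimate the baseline hazard from the data (e.g. with the Breslow estimator).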
    pub fn predict_survival(&self, covariates: ArrayView2<f64>, times: ArrayView1<f64>) -> Result<Array2<f64>> {
        let risk_scores = self.predict(covariates)?;
        let n_samples = covariates.nrows();
        let n_times = times.len();

        // simplified survival estimation (in practice, use the Breslow estimator)
        let mut survival_probs = Array2::zeros((n_samples, n_times));

        for (i, &time) in times.iter().enumerate() {
            for j in 0..n_samples {
                let hazard_ratio = risk_scores[j].exp();
                let baseline_hazard = 0.1; // rough approximation
                survival_probs[[j, i]] = (-baseline_hazard * hazard_ratio * time).exp();
            }
        }

        Ok(survival_probs)
    }

    /// Feature importance, defined here as the absolute value of each coefficient.
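    ///
    /// Note: absolute coefficients are only comparable across features when the covariates
    /// are on similar scales (e.g. standardized before fitting).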
    pub fn feature_importance(&self) -> Result<Array1<f64>> {
        let coefs = self.coefficients()?;
        Ok(coefs.mapv(f64::abs))
    }

    /// Returns a summary of the fitted model.
    pub fn summary(&self) -> Result<CoxModelSummary> {
        if !self.fitted {
            return Err(CoxError::ModelNotFitted);
        }

        let coefs = self.coefficients()?.to_owned();
        let hazard_ratios = coefs.mapv(f64::exp);

        Ok(CoxModelSummary {
            coefficients: coefs,
            hazard_ratios,
            l1_penalty: self.l1_penalty,
            l2_penalty: self.l2_penalty,
            feature_names: self.feature_names.clone(),
        })
    }

    /// Returns `true` if the model has been fitted to data.
    pub fn is_fitted(&self) -> bool {
        self.fitted
    }

    /// Returns the regularization penalties as `(l1_penalty, l2_penalty)`.
    pub fn regularization_params(&self) -> (f64, f64) {
        (self.l1_penalty, self.l2_penalty) // (lasso, ridge)
    }
}

/// Summary of a fitted Cox proportional hazards model.
#[derive(Debug, Clone)]
pub struct CoxModelSummary {
    pub coefficients: Array1<f64>,          // the betas
    pub hazard_ratios: Array1<f64>,         // exp(betas)
    pub l1_penalty: f64,                    // lasso penalty used
    pub l2_penalty: f64,                    // ridge penalty used
    pub feature_names: Option<Vec<String>>, // optional labels
}

impl CoxModelSummary {
    /// Prints the fitted coefficients and hazard ratios as a simple table.
    pub fn print(&self) {
        println!("cox proportional hazards model summary");
        println!("=====================================");
        println!("l1 penalty (lasso): {:.6}", self.l1_penalty);
        println!("l2 penalty (ridge): {:.6}", self.l2_penalty);
        println!();

        println!("{:<20} {:>12} {:>12}", "feature", "coefficient", "hazard ratio");
        println!("{:-<44}", "");

        for i in 0..self.coefficients.len() {
            let default_name = format!("x{}", i);
            let feature_name = match &self.feature_names {
                Some(names) => names.get(i).map(String::as_str).unwrap_or(default_name.as_str()),
                None => default_name.as_str(),
            };

            println!("{:<20} {:>12.6} {:>12.6}",
                     feature_name,
                     self.coefficients[i],
                     self.hazard_ratios[i]);
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;
    use approx::assert_relative_eq;

    fn create_test_data() -> SurvivalData {
        let times = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let events = vec![true, false, true, true, false, true, true, false];
        let covariates = Array2::from_shape_vec((8, 3), vec![
            1.0, 0.0, 0.5,
            0.0, 1.0, -0.5,
            1.0, 1.0, 0.0,
            -1.0, 0.0, 1.0,
            0.0, -1.0, -1.0,
            1.0, -1.0, 0.5,
            -1.0, 1.0, -0.5,
            0.0, 0.0, 0.0,
        ]).unwrap();

        SurvivalData::new(times, events, covariates).unwrap()
    }

    #[test]
    fn test_model_creation() {
        let model = CoxModel::new()
            .with_l1_penalty(0.1)
            .with_l2_penalty(0.05)
            .with_max_iterations(500);

        assert_eq!(model.l1_penalty, 0.1);
        assert_eq!(model.l2_penalty, 0.05);
        assert_eq!(model.max_iterations, 500);
        assert!(!model.is_fitted());
    }

    #[test]
    fn test_elastic_net_parameters() {
        let model = CoxModel::new().with_elastic_net(0.5, 1.0);
        assert_relative_eq!(model.l1_penalty, 0.5, epsilon = 1e-10);
        assert_relative_eq!(model.l2_penalty, 0.5, epsilon = 1e-10);
    }

    #[test]
    fn test_model_not_fitted_error() {
        let model = CoxModel::new();
        assert!(model.coefficients().is_err());
        assert!(model.summary().is_err());

        let covariates = Array2::zeros((5, 3));
        assert!(model.predict(covariates.view()).is_err());
    }

    #[test]
    fn test_feature_names() {
        let names = vec!["age".to_string(), "gender".to_string(), "treatment".to_string()];
        let model = CoxModel::new().with_feature_names(names.clone());
        assert_eq!(model.feature_names.unwrap(), names);
    }

    #[test]
    fn test_prediction_dimension_mismatch() {
        let data = create_test_data();
        let mut model = CoxModel::new();
        model.fit(&data).unwrap();

        // Wrong number of features
        let wrong_covariates = Array2::zeros((5, 2)); // Should be 3 features
        assert!(model.predict(wrong_covariates.view()).is_err());
    }

    #[test]
    fn test_adam_optimizer() {
        let data = create_test_data();
        let mut model = CoxModel::new()
            .with_optimizer(OptimizerType::Adam)
            .with_learning_rate(0.1)
            .with_adam_params(0.9, 0.999)
            .with_tolerance(1e-4)
            .with_max_iterations(500);

        let result = model.fit(&data);
        assert!(result.is_ok());
        assert!(model.is_fitted());

        let coefs = model.coefficients().unwrap();
        assert_eq!(coefs.len(), 3);
        assert!(coefs.iter().all(|&x| x.is_finite()));
    }

    #[test]
    fn test_adam_with_regularization() {
        let data = create_test_data();
        let mut model = CoxModel::new()
            .with_optimizer(OptimizerType::Adam)
            .with_learning_rate(0.05)
            .with_l1_penalty(0.01)
            .with_l2_penalty(0.01)
            .with_tolerance(1e-4)
            .with_max_iterations(800);

        let result = model.fit(&data);
        assert!(result.is_ok());
        assert!(model.is_fitted());

        let coefs = model.coefficients().unwrap();
        assert_eq!(coefs.len(), 3);
        assert!(coefs.iter().all(|&x| x.is_finite()));

        // Test predictions still work
        let test_covariates = Array2::zeros((2, 3));
        let predictions = model.predict(test_covariates.view());
        assert!(predictions.is_ok());
    }
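
    // Illustrative additional check (assumes, like the other tests, that `fit` converges on
    // the small synthetic dataset): hazard ratios should be the elementwise exponential of
    // the linear predictors returned by `predict`.
    #[test]
    fn test_hazard_ratios_match_exp_of_linear_predictors() {
        let data = create_test_data();
        let mut model = CoxModel::new().with_l2_penalty(0.1);
        model.fit(&data).unwrap();

        let covariates = Array2::from_shape_vec((2, 3), vec![
            0.5, -0.5, 1.0,
            -1.0, 0.0, 0.5,
        ]).unwrap();
        let linear = model.predict(covariates.view()).unwrap();
        let ratios = model.predict_hazard_ratios(covariates.view()).unwrap();

        for (lp, hr) in linear.iter().zip(ratios.iter()) {
            assert_relative_eq!(*hr, lp.exp(), epsilon = 1e-12);
        }
    }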
}