1use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
2use crate::{
3 data::SurvivalData,
4 error::{CoxError, Result},
5 optimization::{CoxOptimizer, OptimizationConfig, OptimizerType},
6};
7
8#[derive(Debug, Clone)]
10pub struct CoxModel {
11 coefficients: Option<Array1<f64>>, l1_penalty: f64, l2_penalty: f64, max_iterations: usize, tolerance: f64, fitted: bool, feature_names: Option<Vec<String>>, optimizer_type: OptimizerType, learning_rate: f64, beta1: f64, beta2: f64, epsilon: f64, }
24
25impl Default for CoxModel {
26 fn default() -> Self {
27 Self {
28 coefficients: None,
29 l1_penalty: 0.0,
30 l2_penalty: 0.0,
31 max_iterations: 1000,
32 tolerance: 1e-6,
33 fitted: false,
34 feature_names: None,
35 optimizer_type: OptimizerType::NewtonRaphson,
36 learning_rate: 0.001,
37 beta1: 0.9,
38 beta2: 0.999,
39 epsilon: 1e-8,
40 }
41 }
42}
43
44impl CoxModel {
45 pub fn new() -> Self {
47 Self::default()
48 }
49
50 pub fn with_l1_penalty(mut self, penalty: f64) -> Self {
52 self.l1_penalty = penalty.max(0.0);
53 self
54 }
55
56 pub fn with_l2_penalty(mut self, penalty: f64) -> Self {
58 self.l2_penalty = penalty.max(0.0);
59 self
60 }
61
62 pub fn with_elastic_net(mut self, alpha: f64, penalty: f64) -> Self {
64 if alpha < 0.0 || alpha > 1.0 {
65 panic!("alpha must be in [0,1]");
66 }
67 self.l1_penalty = alpha * penalty; self.l2_penalty = (1.0 - alpha) * penalty; self
70 }
71
72 pub fn with_max_iterations(mut self, max_iter: usize) -> Self {
74 self.max_iterations = max_iter;
75 self
76 }
77
78 pub fn with_tolerance(mut self, tol: f64) -> Self {
80 self.tolerance = tol;
81 self
82 }
83
84 pub fn with_feature_names(mut self, names: Vec<String>) -> Self {
86 self.feature_names = Some(names);
87 self
88 }
89
90 pub fn with_optimizer(mut self, optimizer: OptimizerType) -> Self {
92 self.optimizer_type = optimizer;
93 self
94 }
95
96 pub fn with_learning_rate(mut self, lr: f64) -> Self {
98 self.learning_rate = lr.max(0.0);
99 self
100 }
101
102 pub fn with_adam_params(mut self, beta1: f64, beta2: f64) -> Self {
104 self.beta1 = beta1.clamp(0.0, 1.0);
105 self.beta2 = beta2.clamp(0.0, 1.0);
106 self
107 }
108
109 pub fn with_epsilon(mut self, eps: f64) -> Self {
111 self.epsilon = eps.max(0.0);
112 self
113 }
114
115 pub fn fit(&mut self, data: &SurvivalData) -> Result<&mut Self> {
117 let config = OptimizationConfig {
118 l1_penalty: self.l1_penalty,
119 l2_penalty: self.l2_penalty,
120 max_iterations: self.max_iterations,
121 tolerance: self.tolerance,
122 optimizer_type: self.optimizer_type,
123 learning_rate: self.learning_rate,
124 beta1: self.beta1,
125 beta2: self.beta2,
126 epsilon: self.epsilon,
127 };
128
129 let mut optimizer = CoxOptimizer::new(config);
130 self.coefficients = Some(optimizer.optimize(data)?);
131 self.fitted = true;
132
133 Ok(self)
134 }
135
136 pub fn coefficients(&self) -> Result<ArrayView1<'_, f64>> {
138 match &self.coefficients {
139 Some(coefs) => Ok(coefs.view()),
140 None => Err(CoxError::ModelNotFitted),
141 }
142 }
143
144 pub fn predict(&self, covariates: ArrayView2<f64>) -> Result<Array1<f64>> {
146 let coefs = self.coefficients()?;
147
148 if covariates.ncols() != coefs.len() {
149 return Err(CoxError::invalid_dimensions(
150 format!("feature count mismatch: expected {}, got {}",
151 coefs.len(), covariates.ncols())
152 ));
153 }
154
155 Ok(covariates.dot(&coefs)) }
157
158 pub fn predict_hazard_ratios(&self, covariates: ArrayView2<f64>) -> Result<Array1<f64>> {
160 let linear_predictors = self.predict(covariates)?;
161 Ok(linear_predictors.mapv(f64::exp))
162 }
163
164 pub fn predict_survival(&self, covariates: ArrayView2<f64>, times: ArrayView1<f64>) -> Result<Array2<f64>> {
166 let risk_scores = self.predict(covariates)?;
167 let n_samples = covariates.nrows();
168 let n_times = times.len();
169
170 let mut survival_probs = Array2::zeros((n_samples, n_times));
172
173 for (i, &time) in times.iter().enumerate() {
174 for j in 0..n_samples {
175 let hazard_ratio = risk_scores[j].exp();
176 let baseline_hazard = 0.1; survival_probs[[j, i]] = (-baseline_hazard * hazard_ratio * time).exp();
178 }
179 }
180
181 Ok(survival_probs)
182 }
183
184 pub fn feature_importance(&self) -> Result<Array1<f64>> {
186 let coefs = self.coefficients()?;
187 Ok(coefs.mapv(f64::abs))
188 }
189
190 pub fn summary(&self) -> Result<CoxModelSummary> {
192 if !self.fitted {
193 return Err(CoxError::ModelNotFitted);
194 }
195
196 let coefs = self.coefficients()?.to_owned();
197 let hazard_ratios = coefs.mapv(f64::exp);
198
199 Ok(CoxModelSummary {
200 coefficients: coefs,
201 hazard_ratios,
202 l1_penalty: self.l1_penalty,
203 l2_penalty: self.l2_penalty,
204 feature_names: self.feature_names.clone(),
205 })
206 }
207
208 pub fn is_fitted(&self) -> bool {
210 self.fitted
211 }
212
213 pub fn regularization_params(&self) -> (f64, f64) {
215 (self.l1_penalty, self.l2_penalty) }
217}
218
219#[derive(Debug, Clone)]
221pub struct CoxModelSummary {
222 pub coefficients: Array1<f64>, pub hazard_ratios: Array1<f64>, pub l1_penalty: f64, pub l2_penalty: f64, pub feature_names: Option<Vec<String>>, }
228
229impl CoxModelSummary {
230 pub fn print(&self) {
232 println!("cox proportional hazards model summary");
233 println!("=====================================");
234 println!("l1 penalty (lasso): {:.6}", self.l1_penalty);
235 println!("l2 penalty (ridge): {:.6}", self.l2_penalty);
236 println!("");
237
238 println!("{:<20} {:>12} {:>12}", "feature", "coefficient", "hazard ratio");
239 println!("{:-<44}", "");
240
241 for i in 0..self.coefficients.len() {
242 let default_name = format!("x{}", i);
243 let feature_name = match &self.feature_names {
244 Some(names) => names.get(i).map(|s| s.as_str()).unwrap_or(&default_name),
245 None => &default_name,
246 };
247
248 println!("{:<20} {:>12.6} {:>12.6}",
249 feature_name,
250 self.coefficients[i],
251 self.hazard_ratios[i]);
252 }
253 }
254}
255
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::Array2;

    /// Builds a small fixture: 8 subjects, 3 covariates, mixed events and
    /// censoring.
    fn create_test_data() -> SurvivalData {
        let covariates = Array2::from_shape_vec(
            (8, 3),
            vec![
                1.0, 0.0, 0.5, //
                0.0, 1.0, -0.5, //
                1.0, 1.0, 0.0, //
                -1.0, 0.0, 1.0, //
                0.0, -1.0, -1.0, //
                1.0, -1.0, 0.5, //
                -1.0, 1.0, -0.5, //
                0.0, 0.0, 0.0, //
            ],
        )
        .unwrap();
        let events = vec![true, false, true, true, false, true, true, false];
        let times = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];

        SurvivalData::new(times, events, covariates).unwrap()
    }

    /// Asserts the model is fitted and holds `expected_len` finite
    /// coefficients.
    fn assert_fitted_with_finite_coefficients(model: &CoxModel, expected_len: usize) {
        assert!(model.is_fitted());

        let coefs = model.coefficients().unwrap();
        assert_eq!(coefs.len(), expected_len);
        assert!(coefs.iter().all(|&x| x.is_finite()));
    }

    #[test]
    fn test_model_creation() {
        let model = CoxModel::new()
            .with_l1_penalty(0.1)
            .with_l2_penalty(0.05)
            .with_max_iterations(500);

        assert!(!model.is_fitted());
        assert_eq!(model.l1_penalty, 0.1);
        assert_eq!(model.l2_penalty, 0.05);
        assert_eq!(model.max_iterations, 500);
    }

    #[test]
    fn test_elastic_net_parameters() {
        // alpha = 0.5 splits the total penalty evenly between L1 and L2.
        let model = CoxModel::new().with_elastic_net(0.5, 1.0);

        assert_relative_eq!(model.l1_penalty, 0.5, epsilon = 1e-10);
        assert_relative_eq!(model.l2_penalty, 0.5, epsilon = 1e-10);
    }

    #[test]
    fn test_model_not_fitted_error() {
        let unfitted = CoxModel::new();
        let covariates = Array2::<f64>::zeros((5, 3));

        assert!(unfitted.coefficients().is_err());
        assert!(unfitted.summary().is_err());
        assert!(unfitted.predict(covariates.view()).is_err());
    }

    #[test]
    fn test_feature_names() {
        let names = vec![
            "age".to_string(),
            "gender".to_string(),
            "treatment".to_string(),
        ];

        let model = CoxModel::new().with_feature_names(names.clone());

        assert_eq!(model.feature_names.as_deref(), Some(names.as_slice()));
    }

    #[test]
    fn test_prediction_dimension_mismatch() {
        let data = create_test_data();
        let mut model = CoxModel::new();
        model.fit(&data).unwrap();

        // The fixture has 3 covariates; 2 columns must be rejected.
        let two_column_input = Array2::<f64>::zeros((5, 2));
        assert!(model.predict(two_column_input.view()).is_err());
    }

    #[test]
    fn test_adam_optimizer() {
        let data = create_test_data();
        let mut model = CoxModel::new()
            .with_optimizer(OptimizerType::Adam)
            .with_learning_rate(0.1)
            .with_adam_params(0.9, 0.999)
            .with_tolerance(1e-4)
            .with_max_iterations(500);

        assert!(model.fit(&data).is_ok());
        assert_fitted_with_finite_coefficients(&model, 3);
    }

    #[test]
    fn test_adam_with_regularization() {
        let data = create_test_data();
        let mut model = CoxModel::new()
            .with_optimizer(OptimizerType::Adam)
            .with_learning_rate(0.05)
            .with_l1_penalty(0.01)
            .with_l2_penalty(0.01)
            .with_tolerance(1e-4)
            .with_max_iterations(800);

        assert!(model.fit(&data).is_ok());
        assert_fitted_with_finite_coefficients(&model, 3);

        let zero_covariates = Array2::<f64>::zeros((2, 3));
        assert!(model.predict(zero_covariates.view()).is_ok());
    }
}