cox_hazards/optimization.rs

use ndarray::{Array1, Array2};
use crate::{
    data::SurvivalData,
    error::{CoxError, Result},
};

/// Optimization algorithm types
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum OptimizerType {
    NewtonRaphson,
    CoordinateDescent,
    Adam,
    RMSprop,
}

/// Configuration for Cox model optimization
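///
/// A minimal construction sketch (the values below are illustrative, not
/// recommendations); any field left out falls back to `Default`:
///
/// ```ignore
/// let config = OptimizationConfig {
///     l2_penalty: 0.1,
///     optimizer_type: OptimizerType::NewtonRaphson,
///     ..Default::default()
/// };
/// ```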
#[derive(Debug, Clone)]
pub struct OptimizationConfig {
    pub l1_penalty: f64,
    pub l2_penalty: f64,
    pub max_iterations: usize,
    pub tolerance: f64,
    pub optimizer_type: OptimizerType,
    pub learning_rate: f64,
    pub beta1: f64,  // Adam momentum parameter
    pub beta2: f64,  // Adam/RMSprop decay parameter for second moment
    pub epsilon: f64, // Adam/RMSprop numerical stability
}

impl Default for OptimizationConfig {
    fn default() -> Self {
        Self {
            l1_penalty: 0.0,
            l2_penalty: 0.0,
            max_iterations: 1000,
            tolerance: 1e-6,
            optimizer_type: OptimizerType::NewtonRaphson,
            learning_rate: 0.001,
            beta1: 0.9,
            beta2: 0.999,
            epsilon: 1e-8,
        }
    }
}

/// Adam optimizer state for momentum tracking
#[derive(Debug, Clone)]
struct AdamState {
    m: Array1<f64>,  // First moment estimate
    v: Array1<f64>,  // Second moment estimate
    t: usize,        // Time step
}

impl AdamState {
    fn new(n_features: usize) -> Self {
        Self {
            m: Array1::zeros(n_features),
            v: Array1::zeros(n_features),
            t: 0,
        }
    }
}

/// RMSprop optimizer state for second moment tracking
#[derive(Debug, Clone)]
struct RMSpropState {
    v: Array1<f64>,  // Second moment estimate (moving average of squared gradients)
}

impl RMSpropState {
    fn new(n_features: usize) -> Self {
        Self {
            v: Array1::zeros(n_features),
        }
    }
}

/// Cox proportional hazards optimizer with elastic net regularization
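///
/// A minimal end-to-end sketch (assumes a `SurvivalData` value built elsewhere and
/// that the crate and module are named `cox_hazards::optimization`):
///
/// ```ignore
/// use cox_hazards::optimization::{CoxOptimizer, OptimizationConfig};
///
/// let config = OptimizationConfig::default();
/// let mut optimizer = CoxOptimizer::new(config);
/// let beta = optimizer.optimize(&data)?; // `data: SurvivalData`
/// ```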
pub struct CoxOptimizer {
    config: OptimizationConfig,
    adam_state: Option<AdamState>,
    rmsprop_state: Option<RMSpropState>,
}

impl CoxOptimizer {
    pub fn new(config: OptimizationConfig) -> Self {
        Self {
            config,
            adam_state: None,
            rmsprop_state: None,
        }
    }

    /// Optimize the Cox model using the configured optimizer, returning the fitted coefficients
    pub fn optimize(&mut self, data: &SurvivalData) -> Result<Array1<f64>> {
        let n_features = data.n_features();
        let mut beta = Array1::zeros(n_features);

        match self.config.optimizer_type {
            OptimizerType::Adam => {
                self.adam_optimize(data, &mut beta)?;
            }
            OptimizerType::RMSprop => {
                self.rmsprop_optimize(data, &mut beta)?;
            }
            OptimizerType::CoordinateDescent => {
                self.coordinate_descent_optimize(data, &mut beta)?;
            }
            OptimizerType::NewtonRaphson => {
                // The L1 penalty is non-differentiable, so Newton-Raphson only
                // handles the smooth (unpenalized or Ridge-only) case; otherwise
                // fall back to coordinate descent.
                if self.config.l1_penalty > 0.0 {
                    self.coordinate_descent_optimize(data, &mut beta)?;
                } else {
                    self.newton_raphson_optimize(data, &mut beta)?;
                }
            }
        }

        Ok(beta)
    }

    /// Adam optimization algorithm for Cox regression
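    ///
    /// Performs gradient ascent on the penalized log partial likelihood using the
    /// standard Adam moment estimates (Kingma & Ba, 2014):
    ///
    ///   m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
    ///   v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2
    ///   beta_i += learning_rate * m_hat_i / (sqrt(v_hat_i) + epsilon)
    ///
    /// where m_hat and v_hat are the bias-corrected moments. Each update is clipped
    /// to [-1, 1] and each coefficient to [-10, 10] for numerical stability.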
    fn adam_optimize(&mut self, data: &SurvivalData, beta: &mut Array1<f64>) -> Result<()> {
        let n_features = data.n_features();

        // Initialize Adam state if not already done
        if self.adam_state.is_none() {
            self.adam_state = Some(AdamState::new(n_features));
        }

        let mut prev_loglik = f64::NEG_INFINITY;
        let mut best_loglik = f64::NEG_INFINITY;
        let mut no_improvement_count = 0;
        let max_no_improvement = 50;

        for _iteration in 0..self.config.max_iterations {
            // Compute gradient
            let gradient = self.compute_cox_gradient(data, beta)?;

            // Check if gradient is reasonable
            if gradient.iter().any(|&g| !g.is_finite()) {
                break; // Stop if gradient becomes invalid
            }

            // Apply regularization to gradient
            let mut regularized_gradient = gradient.clone();

            // L2 penalty (Ridge)
            if self.config.l2_penalty > 0.0 {
                regularized_gradient = &regularized_gradient - &(self.config.l2_penalty * &*beta);
            }

            // L1 penalty (Lasso) - use sign of current beta
            if self.config.l1_penalty > 0.0 {
                for i in 0..n_features {
                    if beta[i].abs() > 1e-10 {  // Only apply L1 penalty if beta is not near zero
                        regularized_gradient[i] -= self.config.l1_penalty * beta[i].signum();
                    }
                }
            }

            // Adam update
            if let Some(ref mut adam_state) = self.adam_state {
                adam_state.t += 1;

                // Update biased first moment estimate
                adam_state.m = &(self.config.beta1 * &adam_state.m) + &((1.0 - self.config.beta1) * &regularized_gradient);

                // Update biased second raw moment estimate
                adam_state.v = &(self.config.beta2 * &adam_state.v) +
                    &((1.0 - self.config.beta2) * &regularized_gradient.mapv(|x| x * x));

                // Compute bias-corrected first moment estimate
                let m_hat = &adam_state.m / (1.0 - self.config.beta1.powi(adam_state.t as i32));

                // Compute bias-corrected second raw moment estimate
                let v_hat = &adam_state.v / (1.0 - self.config.beta2.powi(adam_state.t as i32));

                // Update parameters with clipping to prevent exploding gradients
                for i in 0..n_features {
                    let update = self.config.learning_rate * m_hat[i] / (v_hat[i].sqrt() + self.config.epsilon);
                    // Clip updates to reasonable range
                    let clipped_update = update.max(-1.0).min(1.0);
                    beta[i] += clipped_update;

                    // Prevent coefficients from becoming too large
                    beta[i] = beta[i].max(-10.0).min(10.0);
                }
            }

            // Check for numerical issues
            if beta.iter().any(|&b| !b.is_finite()) {
                break; // Stop if beta becomes invalid
            }

            // Check convergence
            let loglik = self.compute_log_likelihood(data, beta)?;
            let penalized_loglik = loglik -
                0.5 * self.config.l2_penalty * beta.dot(beta) -
                self.config.l1_penalty * beta.mapv(f64::abs).sum();

            // Check for convergence or improvement
            if (penalized_loglik - prev_loglik).abs() < self.config.tolerance {
                break;
            }

            // Track best likelihood and early stopping
            if penalized_loglik > best_loglik {
                best_loglik = penalized_loglik;
                no_improvement_count = 0;
            } else {
                no_improvement_count += 1;
                if no_improvement_count >= max_no_improvement {
                    break; // Early stopping if no improvement
                }
            }

            prev_loglik = penalized_loglik;
        }

        Ok(())
    }

    /// RMSprop optimization algorithm for Cox regression
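    ///
    /// Performs gradient ascent on the penalized log partial likelihood using a
    /// moving average of squared gradients with decay rate `beta2`:
    ///
    ///   v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2
    ///   beta_i += learning_rate * g_i / (sqrt(v_i) + epsilon)
    ///
    /// with the same update and coefficient clipping as the Adam path.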
    fn rmsprop_optimize(&mut self, data: &SurvivalData, beta: &mut Array1<f64>) -> Result<()> {
        let n_features = data.n_features();

        // Initialize RMSprop state if not already done
        if self.rmsprop_state.is_none() {
            self.rmsprop_state = Some(RMSpropState::new(n_features));
        }

        let mut prev_loglik = f64::NEG_INFINITY;
        let mut best_loglik = f64::NEG_INFINITY;
        let mut no_improvement_count = 0;
        let max_no_improvement = 50;

        for _iteration in 0..self.config.max_iterations {
            // Compute gradient
            let gradient = self.compute_cox_gradient(data, beta)?;

            // Check if gradient is reasonable
            if gradient.iter().any(|&g| !g.is_finite()) {
                break; // Stop if gradient becomes invalid
            }

            // Apply regularization to gradient
            let mut regularized_gradient = gradient.clone();

            // L2 penalty (Ridge)
            if self.config.l2_penalty > 0.0 {
                regularized_gradient = &regularized_gradient - &(self.config.l2_penalty * &*beta);
            }

            // L1 penalty (Lasso) - use sign of current beta
            if self.config.l1_penalty > 0.0 {
                for i in 0..n_features {
                    if beta[i].abs() > 1e-10 {  // Only apply L1 penalty if beta is not near zero
                        regularized_gradient[i] -= self.config.l1_penalty * beta[i].signum();
                    }
                }
            }

            // RMSprop update
            if let Some(ref mut rmsprop_state) = self.rmsprop_state {
                // Update moving average of squared gradients
                rmsprop_state.v = &(self.config.beta2 * &rmsprop_state.v) +
                    &((1.0 - self.config.beta2) * &regularized_gradient.mapv(|x| x * x));

                // Update parameters with clipping to prevent exploding gradients
                for i in 0..n_features {
                    let update = self.config.learning_rate * regularized_gradient[i] / (rmsprop_state.v[i].sqrt() + self.config.epsilon);
                    // Clip updates to reasonable range
                    let clipped_update = update.max(-1.0).min(1.0);
                    beta[i] += clipped_update;

                    // Prevent coefficients from becoming too large
                    beta[i] = beta[i].max(-10.0).min(10.0);
                }
            }

            // Check for numerical issues
            if beta.iter().any(|&b| !b.is_finite()) {
                break; // Stop if beta becomes invalid
            }

            // Check convergence
            let loglik = self.compute_log_likelihood(data, beta)?;
            let penalized_loglik = loglik -
                0.5 * self.config.l2_penalty * beta.dot(beta) -
                self.config.l1_penalty * beta.mapv(f64::abs).sum();

            // Check for convergence or improvement
            if (penalized_loglik - prev_loglik).abs() < self.config.tolerance {
                break;
            }

            // Track best likelihood and early stopping
            if penalized_loglik > best_loglik {
                best_loglik = penalized_loglik;
                no_improvement_count = 0;
            } else {
                no_improvement_count += 1;
                if no_improvement_count >= max_no_improvement {
                    break; // Early stopping if no improvement
                }
            }

            prev_loglik = penalized_loglik;
        }

        Ok(())
    }

    /// Newton-Raphson optimization, used when there is no L1 penalty (optionally with a Ridge/L2 penalty)
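    ///
    /// Each iteration solves `H * step = g` for the penalized gradient `g` and
    /// Hessian `H` of the log partial likelihood, then updates `beta -= step`;
    /// if the system is singular it falls back to a small gradient step.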
    fn newton_raphson_optimize(&self, data: &SurvivalData, beta: &mut Array1<f64>) -> Result<()> {
        let mut prev_loglik = f64::NEG_INFINITY;

        for iteration in 0..self.config.max_iterations {
            let (loglik, gradient, hessian) = self.compute_likelihood_derivatives(data, beta)?;

            // Add Ridge penalty
            let penalized_loglik = loglik - 0.5 * self.config.l2_penalty * beta.dot(beta);

            // Check for convergence
            if (penalized_loglik - prev_loglik).abs() < self.config.tolerance {
                break;
            }

            if iteration == self.config.max_iterations - 1 {
                return Err(CoxError::optimization_failed(
                    "Newton-Raphson failed to converge"
                ));
            }

            // Add Ridge penalty to gradient and Hessian
            let penalized_gradient = &gradient - self.config.l2_penalty * &*beta;
            let mut penalized_hessian = hessian.clone();
            for i in 0..beta.len() {
                penalized_hessian[[i, i]] -= self.config.l2_penalty;
            }

            // Newton-Raphson step
            match self.solve_linear_system(&penalized_hessian, &penalized_gradient) {
                Ok(step) => {
                    *beta = beta.clone() - step;
                }
                Err(_) => {
                    // Fall back to a small gradient-ascent step on the penalized log-likelihood
                    let step_size = 0.01;
                    *beta = beta.clone() + step_size * &penalized_gradient;
                }
            }

            prev_loglik = penalized_loglik;
        }

        Ok(())
    }

    /// Coordinate descent optimization (for elastic net)
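    ///
    /// Cycles through the coordinates, applying a one-dimensional Newton-style step
    /// with soft thresholding for the L1 penalty and proportional shrinkage for the
    /// L2 penalty, until no coordinate changes by more than `tolerance`.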
    fn coordinate_descent_optimize(&self, data: &SurvivalData, beta: &mut Array1<f64>) -> Result<()> {
        let n_features = data.n_features();

        for iteration in 0..self.config.max_iterations {
            let mut converged = true;

            for j in 0..n_features {
                let beta_old_j = beta[j];

                // Compute the partial gradient and Hessian for coordinate j
                let partial_gradient = self.compute_partial_gradient(data, beta, j)?;
                let partial_hessian = self.compute_partial_hessian(data, beta, j)?;

                // Coordinate-wise update with soft thresholding
                let raw_update = beta[j] + partial_gradient / partial_hessian.abs().max(1e-8);
                beta[j] = self.soft_threshold(raw_update, self.config.l1_penalty / partial_hessian.abs().max(1e-8));

                // Add Ridge penalty
                if self.config.l2_penalty > 0.0 {
                    beta[j] /= 1.0 + self.config.l2_penalty / partial_hessian.abs().max(1e-8);
                }

                if (beta[j] - beta_old_j).abs() > self.config.tolerance {
                    converged = false;
                }
            }

            if converged {
                break;
            }

            if iteration == self.config.max_iterations - 1 {
                return Err(CoxError::optimization_failed(
                    "Coordinate descent failed to converge"
                ));
            }
        }

        Ok(())
    }

    /// Soft thresholding operator for L1 regularization
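    ///
    /// S(x, lambda) = sign(x) * max(|x| - lambda, 0)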
    fn soft_threshold(&self, x: f64, lambda: f64) -> f64 {
        if x > lambda {
            x - lambda
        } else if x < -lambda {
            x + lambda
        } else {
            0.0
        }
    }

    /// Compute partial gradient for coordinate j
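    ///
    /// For each event time t with risk set R(t), this accumulates
    ///   x_ij - sum_{k in R(t)} x_kj * exp(x_k' beta) / sum_{k in R(t)} exp(x_k' beta)
    /// over the events i at t, i.e. Breslow-style handling of tied event times.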
    fn compute_partial_gradient(&self, data: &SurvivalData, beta: &Array1<f64>, j: usize) -> Result<f64> {
        let mut gradient = 0.0;
        let event_times = data.event_times();

        for &event_time in &event_times {
            let events_at_time: Vec<usize> = (0..data.n_samples())
                .filter(|&i| data.times()[i] == event_time && data.events()[i])
                .collect();

            if events_at_time.is_empty() {
                continue;
            }

            let risk_set: Vec<usize> = (0..data.n_samples())
                .filter(|&i| data.times()[i] >= event_time)
                .collect();

            if risk_set.is_empty() {
                continue;
            }

            // Calculate risk set sum
            let mut risk_sum = 0.0;
            let mut weighted_covariate_sum = 0.0;

            for &i in &risk_set {
                let linear_pred = data.covariates().row(i).dot(beta);
                let exp_pred = linear_pred.exp();
                risk_sum += exp_pred;
                weighted_covariate_sum += data.covariates()[[i, j]] * exp_pred;
            }

            if risk_sum <= 0.0 {
                return Err(CoxError::numerical_error("Risk set sum is non-positive"));
            }

            // Add contribution from each event
            for &event_idx in &events_at_time {
                gradient += data.covariates()[[event_idx, j]] - weighted_covariate_sum / risk_sum;
            }
        }

        Ok(gradient)
    }

    /// Compute partial Hessian for coordinate j
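    ///
    /// Accumulates `-d_t * (E[x_j^2] - E[x_j]^2)` over event times t, where `d_t` is
    /// the number of events at t and the expectations are exp(x' beta)-weighted
    /// averages over the risk set; this is the j-th diagonal entry of the Hessian.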
    fn compute_partial_hessian(&self, data: &SurvivalData, beta: &Array1<f64>, j: usize) -> Result<f64> {
        let mut hessian = 0.0;
        let event_times = data.event_times();

        for &event_time in &event_times {
            let events_at_time: Vec<usize> = (0..data.n_samples())
                .filter(|&i| data.times()[i] == event_time && data.events()[i])
                .collect();

            if events_at_time.is_empty() {
                continue;
            }

            let risk_set: Vec<usize> = (0..data.n_samples())
                .filter(|&i| data.times()[i] >= event_time)
                .collect();

            if risk_set.is_empty() {
                continue;
            }

            let mut risk_sum = 0.0;
            let mut weighted_covariate_sum = 0.0;
            let mut weighted_covariate_squared_sum = 0.0;

            for &i in &risk_set {
                let linear_pred = data.covariates().row(i).dot(beta);
                let exp_pred = linear_pred.exp();
                let covariate_j = data.covariates()[[i, j]];

                risk_sum += exp_pred;
                weighted_covariate_sum += covariate_j * exp_pred;
                weighted_covariate_squared_sum += covariate_j * covariate_j * exp_pred;
            }

            if risk_sum <= 0.0 {
                return Err(CoxError::numerical_error("Risk set sum is non-positive"));
            }

            // Second derivative calculation
            let first_moment = weighted_covariate_sum / risk_sum;
            let second_moment = weighted_covariate_squared_sum / risk_sum;

            hessian -= events_at_time.len() as f64 * (second_moment - first_moment * first_moment);
        }

        Ok(hessian)
    }

    /// Compute log partial likelihood and its derivatives
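    ///
    /// Returns the Breslow log partial likelihood together with its gradient and
    /// Hessian:
    ///
    ///   loglik   = sum_events [ x_i' beta - log sum_{k in R(t_i)} exp(x_k' beta) ]
    ///   gradient = sum_events [ x_i - xbar(t_i) ]
    ///   hessian  = -sum_events V(t_i)
    ///
    /// where xbar and V are the exp(x' beta)-weighted mean and covariance of the
    /// covariates over the risk set.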
    fn compute_likelihood_derivatives(
        &self,
        data: &SurvivalData,
        beta: &Array1<f64>,
    ) -> Result<(f64, Array1<f64>, Array2<f64>)> {
        let n_features = data.n_features();
        let mut loglik = 0.0;
        let mut gradient = Array1::zeros(n_features);
        let mut hessian = Array2::zeros((n_features, n_features));

        let event_times = data.event_times();

        for &event_time in &event_times {
            let events_at_time: Vec<usize> = (0..data.n_samples())
                .filter(|&i| data.times()[i] == event_time && data.events()[i])
                .collect();

            if events_at_time.is_empty() {
                continue;
            }

            let risk_set: Vec<usize> = (0..data.n_samples())
                .filter(|&i| data.times()[i] >= event_time)
                .collect();

            if risk_set.is_empty() {
                continue;
            }

            // Compute risk set statistics
            let (log_sum, weighted_mean, weighted_variance) = self.compute_risk_set_statistics(data, beta, &risk_set)?;

            // Update likelihood and derivatives for each event
            for &event_idx in &events_at_time {
                let event_linear_pred = data.covariates().row(event_idx).dot(beta);
                loglik += event_linear_pred - log_sum;

                let event_covariates = data.covariates().row(event_idx).to_owned();
                gradient += &(&event_covariates - &weighted_mean);

                // Hessian update
                hessian -= &weighted_variance;
            }
        }

        Ok((loglik, gradient, hessian))
    }

    /// Compute statistics for a risk set
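    ///
    /// Returns `(log_sum, weighted_mean, weighted_variance)` for the given risk set:
    /// `log sum_k exp(x_k' beta)`, the exp(x' beta)-weighted mean of the covariates,
    /// and the corresponding weighted covariance matrix.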
    fn compute_risk_set_statistics(
        &self,
        data: &SurvivalData,
        beta: &Array1<f64>,
        risk_set: &[usize],
    ) -> Result<(f64, Array1<f64>, Array2<f64>)> {
        let n_features = data.n_features();
        let mut risk_sum = 0.0;
        let mut weighted_covariate_sum = Array1::zeros(n_features);
        let mut weighted_covariate_outer_sum = Array2::zeros((n_features, n_features));

        for &i in risk_set {
            let linear_pred = data.covariates().row(i).dot(beta);
            let exp_pred = linear_pred.exp();

            if !exp_pred.is_finite() || exp_pred <= 0.0 {
                return Err(CoxError::numerical_error(
                    format!("Invalid exponential prediction: {}", exp_pred)
                ));
            }

            risk_sum += exp_pred;
            let covariates_i = data.covariates().row(i).to_owned();
            weighted_covariate_sum += &(exp_pred * &covariates_i);

            // Outer product for Hessian
            for j in 0..n_features {
                for k in 0..n_features {
                    weighted_covariate_outer_sum[[j, k]] +=
                        exp_pred * covariates_i[j] * covariates_i[k];
                }
            }
        }

        if risk_sum <= 0.0 {
            return Err(CoxError::numerical_error("Risk set sum is non-positive"));
        }

        let log_sum = risk_sum.ln();
        let weighted_mean = &weighted_covariate_sum / risk_sum;

        // Compute variance matrix
        let mut weighted_variance = weighted_covariate_outer_sum / risk_sum;
        for i in 0..n_features {
            for j in 0..n_features {
                weighted_variance[[i, j]] -= weighted_mean[i] * weighted_mean[j];
            }
        }

        Ok((log_sum, weighted_mean, weighted_variance))
    }

    /// Solve linear system Ax = b
    fn solve_linear_system(&self, a: &Array2<f64>, b: &Array1<f64>) -> Result<Array1<f64>> {
        // Gaussian elimination with partial pivoting.
        // In practice, a more robust solver might be preferable.
        let n = a.nrows();
        if n != a.ncols() || n != b.len() {
            return Err(CoxError::invalid_dimensions("Matrix dimensions mismatch"));
        }

        let mut a_copy = a.clone();
        let mut b_copy = b.clone();

        // Forward elimination
        for i in 0..n {
            // Find pivot
            let mut max_row = i;
            for k in i + 1..n {
                if a_copy[[k, i]].abs() > a_copy[[max_row, i]].abs() {
                    max_row = k;
                }
            }

            if a_copy[[max_row, i]].abs() < 1e-12 {
                return Err(CoxError::numerical_error("Matrix is singular"));
            }

            // Swap rows
            if max_row != i {
                for j in 0..n {
                    let temp = a_copy[[i, j]];
                    a_copy[[i, j]] = a_copy[[max_row, j]];
                    a_copy[[max_row, j]] = temp;
                }
                let temp = b_copy[i];
                b_copy[i] = b_copy[max_row];
                b_copy[max_row] = temp;
            }

            // Eliminate
            for k in i + 1..n {
                let factor = a_copy[[k, i]] / a_copy[[i, i]];
                for j in i..n {
                    a_copy[[k, j]] -= factor * a_copy[[i, j]];
                }
                b_copy[k] -= factor * b_copy[i];
            }
        }

        // Back substitution
        let mut x = Array1::zeros(n);
        for i in (0..n).rev() {
            x[i] = b_copy[i];
            for j in i + 1..n {
                x[i] -= a_copy[[i, j]] * x[j];
            }
            x[i] /= a_copy[[i, i]];
        }

        Ok(x)
    }

    /// Compute Cox regression gradient
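    ///
    /// Gradient of the log partial likelihood: for each event, the event's covariate
    /// vector minus the exp(x' beta)-weighted mean of the covariates over its risk set.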
    fn compute_cox_gradient(&self, data: &SurvivalData, beta: &Array1<f64>) -> Result<Array1<f64>> {
        let n_features = data.n_features();
        let mut gradient = Array1::zeros(n_features);
        let event_times = data.event_times();

        for &event_time in &event_times {
            let events_at_time: Vec<usize> = (0..data.n_samples())
                .filter(|&i| data.times()[i] == event_time && data.events()[i])
                .collect();

            if events_at_time.is_empty() {
                continue;
            }

            let risk_set: Vec<usize> = (0..data.n_samples())
                .filter(|&i| data.times()[i] >= event_time)
                .collect();

            if risk_set.is_empty() {
                continue;
            }

            // Compute weighted mean of covariates in risk set
            let (_, weighted_mean, _) = self.compute_risk_set_statistics(data, beta, &risk_set)?;

            // Add contribution from each event
            for &event_idx in &events_at_time {
                let event_covariates = data.covariates().row(event_idx).to_owned();
                gradient += &(&event_covariates - &weighted_mean);
            }
        }

        Ok(gradient)
    }

    /// Compute Cox regression log-likelihood
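    ///
    /// Log partial likelihood: the sum over events of
    /// `x_i' beta - log sum_{k in R(t_i)} exp(x_k' beta)`.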
    fn compute_log_likelihood(&self, data: &SurvivalData, beta: &Array1<f64>) -> Result<f64> {
        let mut loglik = 0.0;
        let event_times = data.event_times();

        for &event_time in &event_times {
            let events_at_time: Vec<usize> = (0..data.n_samples())
                .filter(|&i| data.times()[i] == event_time && data.events()[i])
                .collect();

            if events_at_time.is_empty() {
                continue;
            }

            let risk_set: Vec<usize> = (0..data.n_samples())
                .filter(|&i| data.times()[i] >= event_time)
                .collect();

            if risk_set.is_empty() {
                continue;
            }

            let (log_sum, _, _) = self.compute_risk_set_statistics(data, beta, &risk_set)?;

            // Add contribution from each event
            for &event_idx in &events_at_time {
                let event_linear_pred = data.covariates().row(event_idx).dot(beta);
                loglik += event_linear_pred - log_sum;
            }
        }

        Ok(loglik)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;
    use approx::assert_relative_eq;

    fn create_test_data() -> SurvivalData {
        let times = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let events = vec![true, true, true, true, true];
        let covariates = Array2::from_shape_vec((5, 2), vec![
            1.0, 0.0,
            0.0, 1.0,
            1.0, 1.0,
            -1.0, 0.0,
            0.0, -1.0,
        ]).unwrap();

        SurvivalData::new(times, events, covariates).unwrap()
    }

    #[test]
    fn test_optimizer_creation() {
        let config = OptimizationConfig::default();
        let optimizer = CoxOptimizer::new(config.clone());
        assert_eq!(optimizer.config.l1_penalty, config.l1_penalty);
        assert_eq!(optimizer.config.l2_penalty, config.l2_penalty);
    }

    #[test]
    fn test_soft_threshold() {
        let config = OptimizationConfig::default();
        let optimizer = CoxOptimizer::new(config);

        assert_relative_eq!(optimizer.soft_threshold(2.0, 1.0), 1.0, epsilon = 1e-10);
        assert_relative_eq!(optimizer.soft_threshold(-2.0, 1.0), -1.0, epsilon = 1e-10);
        assert_relative_eq!(optimizer.soft_threshold(0.5, 1.0), 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_optimization_no_regularization() {
        let data = create_test_data();
        let config = OptimizationConfig::default();
        let mut optimizer = CoxOptimizer::new(config);

        let result = optimizer.optimize(&data);
        assert!(result.is_ok());

        let beta = result.unwrap();
        assert_eq!(beta.len(), 2);
    }

    #[test]
    fn test_optimization_with_ridge() {
        let data = create_test_data();
        let config = OptimizationConfig {
            l1_penalty: 0.0,
            l2_penalty: 0.1,
            max_iterations: 100,
            tolerance: 1e-6,
            ..Default::default()
        };
        let mut optimizer = CoxOptimizer::new(config);

        let result = optimizer.optimize(&data);
        assert!(result.is_ok());

        let beta = result.unwrap();
        assert_eq!(beta.len(), 2);
    }

    #[test]
    fn test_optimization_with_lasso() {
        let data = create_test_data();
        let config = OptimizationConfig {
            l1_penalty: 0.1,
            l2_penalty: 0.0,
            max_iterations: 100,
            tolerance: 1e-6,
            ..Default::default()
        };
        let mut optimizer = CoxOptimizer::new(config);

        let result = optimizer.optimize(&data);
        assert!(result.is_ok());

        let beta = result.unwrap();
        assert_eq!(beta.len(), 2);
    }

    #[test]
    fn test_optimization_with_elastic_net() {
        let data = create_test_data();
        let config = OptimizationConfig {
            l1_penalty: 0.05,
            l2_penalty: 0.05,
            max_iterations: 100,
            tolerance: 1e-6,
            optimizer_type: OptimizerType::CoordinateDescent,
            learning_rate: 0.001,
            beta1: 0.9,
            beta2: 0.999,
            epsilon: 1e-8,
        };
        let mut optimizer = CoxOptimizer::new(config);

        let result = optimizer.optimize(&data);
        assert!(result.is_ok());

        let beta = result.unwrap();
        assert_eq!(beta.len(), 2);
    }

    #[test]
    fn test_adam_optimizer() {
        let data = create_test_data();
        let config = OptimizationConfig {
            l1_penalty: 0.0,
            l2_penalty: 0.0,
            max_iterations: 500,
            tolerance: 1e-4,  // Looser tolerance for Adam
            optimizer_type: OptimizerType::Adam,
            learning_rate: 0.1,  // Higher learning rate
            beta1: 0.9,
            beta2: 0.999,
            epsilon: 1e-8,
        };
        let mut optimizer = CoxOptimizer::new(config);

        let result = optimizer.optimize(&data);
        if let Err(ref e) = result {
            println!("Adam optimizer failed with error: {:?}", e);
        }
        assert!(result.is_ok());

        let beta = result.unwrap();
        assert_eq!(beta.len(), 2);
        assert!(beta.iter().all(|&x| x.is_finite()));
    }

    #[test]
    fn test_adam_with_regularization() {
        let data = create_test_data();
        let config = OptimizationConfig {
            l1_penalty: 0.01,
            l2_penalty: 0.01,
            max_iterations: 800,
            tolerance: 1e-4,  // Looser tolerance for Adam
            optimizer_type: OptimizerType::Adam,
            learning_rate: 0.05,  // Moderate learning rate
            beta1: 0.9,
            beta2: 0.999,
            epsilon: 1e-8,
        };
        let mut optimizer = CoxOptimizer::new(config);

        let result = optimizer.optimize(&data);
        assert!(result.is_ok());

        let beta = result.unwrap();
        assert_eq!(beta.len(), 2);
        assert!(beta.iter().all(|&x| x.is_finite()));
    }

    #[test]
    fn test_optimizer_type_enum() {
        let config1 = OptimizationConfig {
            optimizer_type: OptimizerType::Adam,
            ..Default::default()
        };

        let config2 = OptimizationConfig {
            optimizer_type: OptimizerType::NewtonRaphson,
            ..Default::default()
        };

        let config3 = OptimizationConfig {
            optimizer_type: OptimizerType::CoordinateDescent,
            ..Default::default()
        };

        let config4 = OptimizationConfig {
            optimizer_type: OptimizerType::RMSprop,
            ..Default::default()
        };

        assert_eq!(config1.optimizer_type, OptimizerType::Adam);
        assert_eq!(config2.optimizer_type, OptimizerType::NewtonRaphson);
        assert_eq!(config3.optimizer_type, OptimizerType::CoordinateDescent);
        assert_eq!(config4.optimizer_type, OptimizerType::RMSprop);
    }

    #[test]
    fn test_rmsprop_optimizer() {
        let data = create_test_data();
        let config = OptimizationConfig {
            l1_penalty: 0.0,
            l2_penalty: 0.0,
            max_iterations: 500,
            tolerance: 1e-4,  // Looser tolerance for RMSprop
            optimizer_type: OptimizerType::RMSprop,
            learning_rate: 0.1,  // Higher learning rate
            beta1: 0.9,  // Not used in RMSprop but kept for consistency
            beta2: 0.9,  // Decay rate for RMSprop
            epsilon: 1e-8,
        };
        let mut optimizer = CoxOptimizer::new(config);

        let result = optimizer.optimize(&data);
        if let Err(ref e) = result {
            println!("RMSprop optimizer failed with error: {:?}", e);
        }
        assert!(result.is_ok());

        let beta = result.unwrap();
        assert_eq!(beta.len(), 2);
        assert!(beta.iter().all(|&x| x.is_finite()));
    }

    #[test]
    fn test_rmsprop_with_regularization() {
        let data = create_test_data();
        let config = OptimizationConfig {
            l1_penalty: 0.01,
            l2_penalty: 0.01,
            max_iterations: 800,
            tolerance: 1e-4,  // Looser tolerance for RMSprop
            optimizer_type: OptimizerType::RMSprop,
            learning_rate: 0.05,  // Moderate learning rate
            beta1: 0.9,  // Not used in RMSprop but kept for consistency
            beta2: 0.9,  // Decay rate for RMSprop
            epsilon: 1e-8,
        };
        let mut optimizer = CoxOptimizer::new(config);

        let result = optimizer.optimize(&data);
        assert!(result.is_ok());

        let beta = result.unwrap();
        assert_eq!(beta.len(), 2);
        assert!(beta.iter().all(|&x| x.is_finite()));
    }
}