use ndarray::{Array1, Array2};
use crate::{
    data::SurvivalData,
    error::{CoxError, Result},
};

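/// Optimization algorithms available for fitting the Cox proportional hazards model.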
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum OptimizerType {
    NewtonRaphson,
    CoordinateDescent,
    Adam,
    RMSprop,
}

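/// Configuration for Cox model fitting: regularization penalties, convergence
/// criteria, and hyperparameters for the gradient-based optimizers.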
#[derive(Debug, Clone)]
pub struct OptimizationConfig {
    /// L1 (lasso) penalty weight; values > 0 encourage sparse coefficients.
    pub l1_penalty: f64,
    /// L2 (ridge) penalty weight.
    pub l2_penalty: f64,
    pub max_iterations: usize,
    pub tolerance: f64,
    pub optimizer_type: OptimizerType,
    pub learning_rate: f64,
    /// Adam first-moment decay rate.
    pub beta1: f64,
    /// Adam second-moment decay rate; also used as the RMSprop decay rate.
    pub beta2: f64,
    /// Small constant added to denominators for numerical stability.
    pub epsilon: f64,
}

impl Default for OptimizationConfig {
    fn default() -> Self {
        Self {
            l1_penalty: 0.0,
            l2_penalty: 0.0,
            max_iterations: 1000,
            tolerance: 1e-6,
            optimizer_type: OptimizerType::NewtonRaphson,
            learning_rate: 0.001,
            beta1: 0.9,
            beta2: 0.999,
            epsilon: 1e-8,
        }
    }
}

/// Running state for the Adam optimizer: first- and second-moment estimates
/// and the current step count.
#[derive(Debug, Clone)]
struct AdamState {
    m: Array1<f64>,
    v: Array1<f64>,
    t: usize,
}

impl AdamState {
    fn new(n_features: usize) -> Self {
        Self {
            m: Array1::zeros(n_features),
            v: Array1::zeros(n_features),
            t: 0,
        }
    }
}

/// Running state for RMSprop: the exponentially weighted average of squared gradients.
#[derive(Debug, Clone)]
struct RMSpropState {
    v: Array1<f64>,
}

impl RMSpropState {
    fn new(n_features: usize) -> Self {
        Self {
            v: Array1::zeros(n_features),
        }
    }
}

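/// Fits Cox proportional hazards coefficients by maximizing the (optionally
/// penalized) partial log-likelihood with the configured algorithm.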
pub struct CoxOptimizer {
    config: OptimizationConfig,
    adam_state: Option<AdamState>,
    rmsprop_state: Option<RMSpropState>,
}

impl CoxOptimizer {
    pub fn new(config: OptimizationConfig) -> Self {
        Self {
            config,
            adam_state: None,
            rmsprop_state: None,
        }
    }

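    /// Estimates the coefficient vector for `data`, starting from zeros and
    /// dispatching on the configured optimizer. Newton-Raphson is swapped for
    /// coordinate descent when an L1 penalty is set, since the L1 term is not
    /// differentiable at zero.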
    pub fn optimize(&mut self, data: &SurvivalData) -> Result<Array1<f64>> {
        let n_features = data.n_features();
        let mut beta = Array1::zeros(n_features);

        match self.config.optimizer_type {
            OptimizerType::Adam => {
                self.adam_optimize(data, &mut beta)?;
            }
            OptimizerType::RMSprop => {
                self.rmsprop_optimize(data, &mut beta)?;
            }
            OptimizerType::CoordinateDescent => {
                self.coordinate_descent_optimize(data, &mut beta)?;
            }
            OptimizerType::NewtonRaphson => {
                if self.config.l1_penalty > 0.0 {
                    self.coordinate_descent_optimize(data, &mut beta)?;
                } else {
                    self.newton_raphson_optimize(data, &mut beta)?;
                }
            }
        }

        Ok(beta)
    }

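    /// Adam ascent on the penalized partial log-likelihood. Per iteration, with
    /// gradient `g`:
    ///   m = beta1 * m + (1 - beta1) * g
    ///   v = beta2 * v + (1 - beta2) * g^2
    ///   beta += lr * m_hat / (sqrt(v_hat) + epsilon)
    /// where `m_hat`, `v_hat` are the bias-corrected moments. Updates are clipped
    /// to [-1, 1] and coefficients clamped to [-10, 10] to guard against divergence.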
    fn adam_optimize(&mut self, data: &SurvivalData, beta: &mut Array1<f64>) -> Result<()> {
        let n_features = data.n_features();

        if self.adam_state.is_none() {
            self.adam_state = Some(AdamState::new(n_features));
        }

        let mut prev_loglik = f64::NEG_INFINITY;
        let mut best_loglik = f64::NEG_INFINITY;
        let mut no_improvement_count = 0;
        let max_no_improvement = 50;

        for _iteration in 0..self.config.max_iterations {
            let gradient = self.compute_cox_gradient(data, beta)?;

            if gradient.iter().any(|&g| !g.is_finite()) {
                break;
            }

            // Penalized gradient: L2 term plus an L1 subgradient away from zero.
            let mut regularized_gradient = gradient.clone();

            if self.config.l2_penalty > 0.0 {
                regularized_gradient = &regularized_gradient - &(self.config.l2_penalty * &*beta);
            }

            if self.config.l1_penalty > 0.0 {
                for i in 0..n_features {
                    if beta[i].abs() > 1e-10 {
                        regularized_gradient[i] -= self.config.l1_penalty * beta[i].signum();
                    }
                }
            }

            if let Some(ref mut adam_state) = self.adam_state {
                adam_state.t += 1;

                adam_state.m = &(self.config.beta1 * &adam_state.m)
                    + &((1.0 - self.config.beta1) * &regularized_gradient);

                adam_state.v = &(self.config.beta2 * &adam_state.v)
                    + &((1.0 - self.config.beta2) * &regularized_gradient.mapv(|x| x * x));

                let m_hat = &adam_state.m / (1.0 - self.config.beta1.powi(adam_state.t as i32));
                let v_hat = &adam_state.v / (1.0 - self.config.beta2.powi(adam_state.t as i32));

                for i in 0..n_features {
                    let update = self.config.learning_rate * m_hat[i] / (v_hat[i].sqrt() + self.config.epsilon);
                    let clipped_update = update.max(-1.0).min(1.0);
                    beta[i] += clipped_update;

                    beta[i] = beta[i].max(-10.0).min(10.0);
                }
            }

            if beta.iter().any(|&b| !b.is_finite()) {
                break;
            }

            let loglik = self.compute_log_likelihood(data, beta)?;
            let penalized_loglik = loglik
                - 0.5 * self.config.l2_penalty * beta.dot(beta)
                - self.config.l1_penalty * beta.mapv(f64::abs).sum();

            if (penalized_loglik - prev_loglik).abs() < self.config.tolerance {
                break;
            }

            if penalized_loglik > best_loglik {
                best_loglik = penalized_loglik;
                no_improvement_count = 0;
            } else {
                no_improvement_count += 1;
                if no_improvement_count >= max_no_improvement {
                    break;
                }
            }

            prev_loglik = penalized_loglik;
        }

        Ok(())
    }

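    /// RMSprop ascent on the penalized partial log-likelihood. It keeps only a
    /// running average of squared gradients, v = beta2 * v + (1 - beta2) * g^2,
    /// and steps by lr * g / (sqrt(v) + epsilon); the config's `beta2` doubles as
    /// the RMSprop decay rate. The same clipping and clamping as Adam is applied.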
    fn rmsprop_optimize(&mut self, data: &SurvivalData, beta: &mut Array1<f64>) -> Result<()> {
        let n_features = data.n_features();

        if self.rmsprop_state.is_none() {
            self.rmsprop_state = Some(RMSpropState::new(n_features));
        }

        let mut prev_loglik = f64::NEG_INFINITY;
        let mut best_loglik = f64::NEG_INFINITY;
        let mut no_improvement_count = 0;
        let max_no_improvement = 50;

        for _iteration in 0..self.config.max_iterations {
            let gradient = self.compute_cox_gradient(data, beta)?;

            if gradient.iter().any(|&g| !g.is_finite()) {
                break;
            }

            let mut regularized_gradient = gradient.clone();

            if self.config.l2_penalty > 0.0 {
                regularized_gradient = &regularized_gradient - &(self.config.l2_penalty * &*beta);
            }

            if self.config.l1_penalty > 0.0 {
                for i in 0..n_features {
                    if beta[i].abs() > 1e-10 {
                        regularized_gradient[i] -= self.config.l1_penalty * beta[i].signum();
                    }
                }
            }

            if let Some(ref mut rmsprop_state) = self.rmsprop_state {
                rmsprop_state.v = &(self.config.beta2 * &rmsprop_state.v)
                    + &((1.0 - self.config.beta2) * &regularized_gradient.mapv(|x| x * x));

                for i in 0..n_features {
                    let update = self.config.learning_rate * regularized_gradient[i]
                        / (rmsprop_state.v[i].sqrt() + self.config.epsilon);
                    let clipped_update = update.max(-1.0).min(1.0);
                    beta[i] += clipped_update;

                    beta[i] = beta[i].max(-10.0).min(10.0);
                }
            }

            if beta.iter().any(|&b| !b.is_finite()) {
                break;
            }

            let loglik = self.compute_log_likelihood(data, beta)?;
            let penalized_loglik = loglik
                - 0.5 * self.config.l2_penalty * beta.dot(beta)
                - self.config.l1_penalty * beta.mapv(f64::abs).sum();

            if (penalized_loglik - prev_loglik).abs() < self.config.tolerance {
                break;
            }

            if penalized_loglik > best_loglik {
                best_loglik = penalized_loglik;
                no_improvement_count = 0;
            } else {
                no_improvement_count += 1;
                if no_improvement_count >= max_no_improvement {
                    break;
                }
            }

            prev_loglik = penalized_loglik;
        }

        Ok(())
    }

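    /// Newton-Raphson on the (optionally L2-penalized) partial log-likelihood:
    /// solves `H_pen * step = g_pen` and updates `beta -= step`, falling back to
    /// a small gradient-ascent step when the Hessian is numerically singular.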
    fn newton_raphson_optimize(&self, data: &SurvivalData, beta: &mut Array1<f64>) -> Result<()> {
        let mut prev_loglik = f64::NEG_INFINITY;

        for iteration in 0..self.config.max_iterations {
            let (loglik, gradient, hessian) = self.compute_likelihood_derivatives(data, beta)?;

            let penalized_loglik = loglik - 0.5 * self.config.l2_penalty * beta.dot(beta);

            if (penalized_loglik - prev_loglik).abs() < self.config.tolerance {
                break;
            }

            if iteration == self.config.max_iterations - 1 {
                return Err(CoxError::optimization_failed(
                    "Newton-Raphson failed to converge",
                ));
            }

            let penalized_gradient = &gradient - &(self.config.l2_penalty * &*beta);
            let mut penalized_hessian = hessian.clone();
            for i in 0..beta.len() {
                penalized_hessian[[i, i]] -= self.config.l2_penalty;
            }

            match self.solve_linear_system(&penalized_hessian, &penalized_gradient) {
                Ok(step) => {
                    *beta = beta.clone() - step;
                }
                Err(_) => {
                    // Singular Hessian: fall back to a fixed-size gradient step.
                    let step_size = 0.01;
                    *beta = beta.clone() + step_size * &penalized_gradient;
                }
            }

            prev_loglik = penalized_loglik;
        }

        Ok(())
    }

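    /// Cyclic coordinate descent for lasso and elastic-net fits: each coefficient
    /// takes a one-dimensional Newton step, is soft-thresholded for the L1
    /// penalty, and is shrunk for the L2 penalty.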
    fn coordinate_descent_optimize(&self, data: &SurvivalData, beta: &mut Array1<f64>) -> Result<()> {
        let n_features = data.n_features();

        for iteration in 0..self.config.max_iterations {
            let mut converged = true;

            for j in 0..n_features {
                let beta_old_j = beta[j];

                let partial_gradient = self.compute_partial_gradient(data, beta, j)?;
                let partial_hessian = self.compute_partial_hessian(data, beta, j)?;

                let raw_update = beta[j] + partial_gradient / partial_hessian.abs().max(1e-8);
                beta[j] = self.soft_threshold(raw_update, self.config.l1_penalty / partial_hessian.abs().max(1e-8));

                if self.config.l2_penalty > 0.0 {
                    beta[j] /= 1.0 + self.config.l2_penalty / partial_hessian.abs().max(1e-8);
                }

                if (beta[j] - beta_old_j).abs() > self.config.tolerance {
                    converged = false;
                }
            }

            if converged {
                break;
            }

            if iteration == self.config.max_iterations - 1 {
                return Err(CoxError::optimization_failed(
                    "Coordinate descent failed to converge",
                ));
            }
        }

        Ok(())
    }

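    /// Soft-thresholding operator S(x, lambda) = sign(x) * max(|x| - lambda, 0),
    /// the proximal map of the L1 penalty used by coordinate descent.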
    fn soft_threshold(&self, x: f64, lambda: f64) -> f64 {
        if x > lambda {
            x - lambda
        } else if x < -lambda {
            x + lambda
        } else {
            0.0
        }
    }

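    /// First partial derivative of the partial log-likelihood with respect to
    /// coefficient `j`: for each event, the observed covariate value minus its
    /// exp(x'beta)-weighted mean over the risk set.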
    fn compute_partial_gradient(&self, data: &SurvivalData, beta: &Array1<f64>, j: usize) -> Result<f64> {
        let mut gradient = 0.0;
        let event_times = data.event_times();

        for &event_time in &event_times {
            let events_at_time: Vec<usize> = (0..data.n_samples())
                .filter(|&i| data.times()[i] == event_time && data.events()[i])
                .collect();

            if events_at_time.is_empty() {
                continue;
            }

            let risk_set: Vec<usize> = (0..data.n_samples())
                .filter(|&i| data.times()[i] >= event_time)
                .collect();

            if risk_set.is_empty() {
                continue;
            }

            let mut risk_sum = 0.0;
            let mut weighted_covariate_sum = 0.0;

            for &i in &risk_set {
                let linear_pred = data.covariates().row(i).dot(beta);
                let exp_pred = linear_pred.exp();
                risk_sum += exp_pred;
                weighted_covariate_sum += data.covariates()[[i, j]] * exp_pred;
            }

            if risk_sum <= 0.0 {
                return Err(CoxError::numerical_error("Risk set sum is non-positive"));
            }

            for &event_idx in &events_at_time {
                gradient += data.covariates()[[event_idx, j]] - weighted_covariate_sum / risk_sum;
            }
        }

        Ok(gradient)
    }

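    /// Second partial derivative with respect to coefficient `j`: for each event
    /// time, minus the number of events times the exp(x'beta)-weighted variance
    /// of covariate `j` over the risk set.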
    fn compute_partial_hessian(&self, data: &SurvivalData, beta: &Array1<f64>, j: usize) -> Result<f64> {
        let mut hessian = 0.0;
        let event_times = data.event_times();

        for &event_time in &event_times {
            let events_at_time: Vec<usize> = (0..data.n_samples())
                .filter(|&i| data.times()[i] == event_time && data.events()[i])
                .collect();

            if events_at_time.is_empty() {
                continue;
            }

            let risk_set: Vec<usize> = (0..data.n_samples())
                .filter(|&i| data.times()[i] >= event_time)
                .collect();

            if risk_set.is_empty() {
                continue;
            }

            let mut risk_sum = 0.0;
            let mut weighted_covariate_sum = 0.0;
            let mut weighted_covariate_squared_sum = 0.0;

            for &i in &risk_set {
                let linear_pred = data.covariates().row(i).dot(beta);
                let exp_pred = linear_pred.exp();
                let covariate_j = data.covariates()[[i, j]];

                risk_sum += exp_pred;
                weighted_covariate_sum += covariate_j * exp_pred;
                weighted_covariate_squared_sum += covariate_j * covariate_j * exp_pred;
            }

            if risk_sum <= 0.0 {
                return Err(CoxError::numerical_error("Risk set sum is non-positive"));
            }

            let first_moment = weighted_covariate_sum / risk_sum;
            let second_moment = weighted_covariate_squared_sum / risk_sum;

            hessian -= events_at_time.len() as f64 * (second_moment - first_moment * first_moment);
        }

        Ok(hessian)
    }

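    /// Computes the partial log-likelihood together with its gradient and Hessian.
    /// Tied event times use the Breslow approximation: every event at a tied time
    /// is evaluated against the same full risk set.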
    fn compute_likelihood_derivatives(
        &self,
        data: &SurvivalData,
        beta: &Array1<f64>,
    ) -> Result<(f64, Array1<f64>, Array2<f64>)> {
        let n_features = data.n_features();
        let mut loglik = 0.0;
        let mut gradient = Array1::zeros(n_features);
        let mut hessian = Array2::zeros((n_features, n_features));

        let event_times = data.event_times();

        for &event_time in &event_times {
            let events_at_time: Vec<usize> = (0..data.n_samples())
                .filter(|&i| data.times()[i] == event_time && data.events()[i])
                .collect();

            if events_at_time.is_empty() {
                continue;
            }

            let risk_set: Vec<usize> = (0..data.n_samples())
                .filter(|&i| data.times()[i] >= event_time)
                .collect();

            if risk_set.is_empty() {
                continue;
            }

            let (log_sum, weighted_mean, weighted_variance) =
                self.compute_risk_set_statistics(data, beta, &risk_set)?;

            for &event_idx in &events_at_time {
                let event_linear_pred = data.covariates().row(event_idx).dot(beta);
                loglik += event_linear_pred - log_sum;

                let event_covariates = data.covariates().row(event_idx).to_owned();
                gradient += &(&event_covariates - &weighted_mean);

                hessian -= &weighted_variance;
            }
        }

        Ok((loglik, gradient, hessian))
    }

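    /// For a given risk set, returns the log of the sum of exp(x'beta), the
    /// weighted mean covariate vector, and the weighted covariance matrix of the
    /// covariates.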
    fn compute_risk_set_statistics(
        &self,
        data: &SurvivalData,
        beta: &Array1<f64>,
        risk_set: &[usize],
    ) -> Result<(f64, Array1<f64>, Array2<f64>)> {
        let n_features = data.n_features();
        let mut risk_sum = 0.0;
        let mut weighted_covariate_sum = Array1::zeros(n_features);
        let mut weighted_covariate_outer_sum = Array2::zeros((n_features, n_features));

        for &i in risk_set {
            let linear_pred = data.covariates().row(i).dot(beta);
            let exp_pred = linear_pred.exp();

            if !exp_pred.is_finite() || exp_pred <= 0.0 {
                return Err(CoxError::numerical_error(
                    format!("Invalid exponential prediction: {}", exp_pred)
                ));
            }

            risk_sum += exp_pred;
            let covariates_i = data.covariates().row(i).to_owned();
            weighted_covariate_sum += &(exp_pred * &covariates_i);

            for j in 0..n_features {
                for k in 0..n_features {
                    weighted_covariate_outer_sum[[j, k]] +=
                        exp_pred * covariates_i[j] * covariates_i[k];
                }
            }
        }

        if risk_sum <= 0.0 {
            return Err(CoxError::numerical_error("Risk set sum is non-positive"));
        }

        let log_sum = risk_sum.ln();
        let weighted_mean = &weighted_covariate_sum / risk_sum;

        let mut weighted_variance = weighted_covariate_outer_sum / risk_sum;
        for i in 0..n_features {
            for j in 0..n_features {
                weighted_variance[[i, j]] -= weighted_mean[i] * weighted_mean[j];
            }
        }

        Ok((log_sum, weighted_mean, weighted_variance))
    }

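    /// Solves `a * x = b` by Gaussian elimination with partial pivoting; returns
    /// an error when the matrix is numerically singular.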
    fn solve_linear_system(&self, a: &Array2<f64>, b: &Array1<f64>) -> Result<Array1<f64>> {
        let n = a.nrows();
        if n != a.ncols() || n != b.len() {
            return Err(CoxError::invalid_dimensions("Matrix dimensions mismatch"));
        }

        let mut a_copy = a.clone();
        let mut b_copy = b.clone();

        // Forward elimination with partial pivoting.
        for i in 0..n {
            let mut max_row = i;
            for k in i + 1..n {
                if a_copy[[k, i]].abs() > a_copy[[max_row, i]].abs() {
                    max_row = k;
                }
            }

            if a_copy[[max_row, i]].abs() < 1e-12 {
                return Err(CoxError::numerical_error("Matrix is singular"));
            }

            if max_row != i {
                for j in 0..n {
                    let temp = a_copy[[i, j]];
                    a_copy[[i, j]] = a_copy[[max_row, j]];
                    a_copy[[max_row, j]] = temp;
                }
                let temp = b_copy[i];
                b_copy[i] = b_copy[max_row];
                b_copy[max_row] = temp;
            }

            for k in i + 1..n {
                let factor = a_copy[[k, i]] / a_copy[[i, i]];
                for j in i..n {
                    a_copy[[k, j]] -= factor * a_copy[[i, j]];
                }
                b_copy[k] -= factor * b_copy[i];
            }
        }

        // Back substitution.
        let mut x = Array1::zeros(n);
        for i in (0..n).rev() {
            x[i] = b_copy[i];
            for j in i + 1..n {
                x[i] -= a_copy[[i, j]] * x[j];
            }
            x[i] /= a_copy[[i, i]];
        }

        Ok(x)
    }

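    /// Gradient of the unpenalized partial log-likelihood, used by the
    /// first-order optimizers; penalty terms are added by the callers.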
    fn compute_cox_gradient(&self, data: &SurvivalData, beta: &Array1<f64>) -> Result<Array1<f64>> {
        let n_features = data.n_features();
        let mut gradient = Array1::zeros(n_features);
        let event_times = data.event_times();

        for &event_time in &event_times {
            let events_at_time: Vec<usize> = (0..data.n_samples())
                .filter(|&i| data.times()[i] == event_time && data.events()[i])
                .collect();

            if events_at_time.is_empty() {
                continue;
            }

            let risk_set: Vec<usize> = (0..data.n_samples())
                .filter(|&i| data.times()[i] >= event_time)
                .collect();

            if risk_set.is_empty() {
                continue;
            }

            let (_, weighted_mean, _) = self.compute_risk_set_statistics(data, beta, &risk_set)?;

            for &event_idx in &events_at_time {
                let event_covariates = data.covariates().row(event_idx).to_owned();
                gradient += &(&event_covariates - &weighted_mean);
            }
        }

        Ok(gradient)
    }

    fn compute_log_likelihood(&self, data: &SurvivalData, beta: &Array1<f64>) -> Result<f64> {
        let mut loglik = 0.0;
        let event_times = data.event_times();

        for &event_time in &event_times {
            let events_at_time: Vec<usize> = (0..data.n_samples())
                .filter(|&i| data.times()[i] == event_time && data.events()[i])
                .collect();

            if events_at_time.is_empty() {
                continue;
            }

            let risk_set: Vec<usize> = (0..data.n_samples())
                .filter(|&i| data.times()[i] >= event_time)
                .collect();

            if risk_set.is_empty() {
                continue;
            }

            let (log_sum, _, _) = self.compute_risk_set_statistics(data, beta, &risk_set)?;

            for &event_idx in &events_at_time {
                let event_linear_pred = data.covariates().row(event_idx).dot(beta);
                loglik += event_linear_pred - log_sum;
            }
        }

        Ok(loglik)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;
    use approx::assert_relative_eq;

    fn create_test_data() -> SurvivalData {
        let times = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let events = vec![true, true, true, true, true];
        let covariates = Array2::from_shape_vec((5, 2), vec![
            1.0, 0.0,
            0.0, 1.0,
            1.0, 1.0,
            -1.0, 0.0,
            0.0, -1.0,
        ]).unwrap();

        SurvivalData::new(times, events, covariates).unwrap()
    }

    #[test]
    fn test_optimizer_creation() {
        let config = OptimizationConfig::default();
        let optimizer = CoxOptimizer::new(config.clone());
        assert_eq!(optimizer.config.l1_penalty, config.l1_penalty);
        assert_eq!(optimizer.config.l2_penalty, config.l2_penalty);
    }

    #[test]
    fn test_soft_threshold() {
        let config = OptimizationConfig::default();
        let optimizer = CoxOptimizer::new(config);

        assert_relative_eq!(optimizer.soft_threshold(2.0, 1.0), 1.0, epsilon = 1e-10);
        assert_relative_eq!(optimizer.soft_threshold(-2.0, 1.0), -1.0, epsilon = 1e-10);
        assert_relative_eq!(optimizer.soft_threshold(0.5, 1.0), 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_optimization_no_regularization() {
        let data = create_test_data();
        let config = OptimizationConfig::default();
        let mut optimizer = CoxOptimizer::new(config);

        let result = optimizer.optimize(&data);
        assert!(result.is_ok());

        let beta = result.unwrap();
        assert_eq!(beta.len(), 2);
    }

    #[test]
    fn test_optimization_with_ridge() {
        let data = create_test_data();
        let config = OptimizationConfig {
            l1_penalty: 0.0,
            l2_penalty: 0.1,
            max_iterations: 100,
            tolerance: 1e-6,
            ..Default::default()
        };
        let mut optimizer = CoxOptimizer::new(config);

        let result = optimizer.optimize(&data);
        assert!(result.is_ok());

        let beta = result.unwrap();
        assert_eq!(beta.len(), 2);
    }

    #[test]
    fn test_optimization_with_lasso() {
        let data = create_test_data();
        let config = OptimizationConfig {
            l1_penalty: 0.1,
            l2_penalty: 0.0,
            max_iterations: 100,
            tolerance: 1e-6,
            ..Default::default()
        };
        let mut optimizer = CoxOptimizer::new(config);

        let result = optimizer.optimize(&data);
        assert!(result.is_ok());

        let beta = result.unwrap();
        assert_eq!(beta.len(), 2);
    }

    #[test]
    fn test_optimization_with_elastic_net() {
        let data = create_test_data();
        let config = OptimizationConfig {
            l1_penalty: 0.05,
            l2_penalty: 0.05,
            max_iterations: 100,
            tolerance: 1e-6,
            optimizer_type: OptimizerType::CoordinateDescent,
            learning_rate: 0.001,
            beta1: 0.9,
            beta2: 0.999,
            epsilon: 1e-8,
        };
        let mut optimizer = CoxOptimizer::new(config);

        let result = optimizer.optimize(&data);
        assert!(result.is_ok());

        let beta = result.unwrap();
        assert_eq!(beta.len(), 2);
    }

    #[test]
    fn test_adam_optimizer() {
        let data = create_test_data();
        let config = OptimizationConfig {
            l1_penalty: 0.0,
            l2_penalty: 0.0,
            max_iterations: 500,
            tolerance: 1e-4,
            optimizer_type: OptimizerType::Adam,
            learning_rate: 0.1,
            beta1: 0.9,
            beta2: 0.999,
            epsilon: 1e-8,
        };
        let mut optimizer = CoxOptimizer::new(config);

        let result = optimizer.optimize(&data);
        if let Err(ref e) = result {
            println!("Adam optimizer failed with error: {:?}", e);
        }
        assert!(result.is_ok());

        let beta = result.unwrap();
        assert_eq!(beta.len(), 2);
        assert!(beta.iter().all(|&x| x.is_finite()));
    }

    #[test]
    fn test_adam_with_regularization() {
        let data = create_test_data();
        let config = OptimizationConfig {
            l1_penalty: 0.01,
            l2_penalty: 0.01,
            max_iterations: 800,
            tolerance: 1e-4,
            optimizer_type: OptimizerType::Adam,
            learning_rate: 0.05,
            beta1: 0.9,
            beta2: 0.999,
            epsilon: 1e-8,
        };
        let mut optimizer = CoxOptimizer::new(config);

        let result = optimizer.optimize(&data);
        assert!(result.is_ok());

        let beta = result.unwrap();
        assert_eq!(beta.len(), 2);
        assert!(beta.iter().all(|&x| x.is_finite()));
    }

    #[test]
    fn test_optimizer_type_enum() {
        let config1 = OptimizationConfig {
            optimizer_type: OptimizerType::Adam,
            ..Default::default()
        };

        let config2 = OptimizationConfig {
            optimizer_type: OptimizerType::NewtonRaphson,
            ..Default::default()
        };

        let config3 = OptimizationConfig {
            optimizer_type: OptimizerType::CoordinateDescent,
            ..Default::default()
        };

        let config4 = OptimizationConfig {
            optimizer_type: OptimizerType::RMSprop,
            ..Default::default()
        };

        assert_eq!(config1.optimizer_type, OptimizerType::Adam);
        assert_eq!(config2.optimizer_type, OptimizerType::NewtonRaphson);
        assert_eq!(config3.optimizer_type, OptimizerType::CoordinateDescent);
        assert_eq!(config4.optimizer_type, OptimizerType::RMSprop);
    }

    #[test]
    fn test_rmsprop_optimizer() {
        let data = create_test_data();
        let config = OptimizationConfig {
            l1_penalty: 0.0,
            l2_penalty: 0.0,
            max_iterations: 500,
            tolerance: 1e-4,
            optimizer_type: OptimizerType::RMSprop,
            learning_rate: 0.1,
            beta1: 0.9,
            beta2: 0.9,
            epsilon: 1e-8,
        };
        let mut optimizer = CoxOptimizer::new(config);

        let result = optimizer.optimize(&data);
        if let Err(ref e) = result {
            println!("RMSprop optimizer failed with error: {:?}", e);
        }
        assert!(result.is_ok());

        let beta = result.unwrap();
        assert_eq!(beta.len(), 2);
        assert!(beta.iter().all(|&x| x.is_finite()));
    }

    #[test]
    fn test_rmsprop_with_regularization() {
        let data = create_test_data();
        let config = OptimizationConfig {
            l1_penalty: 0.01,
            l2_penalty: 0.01,
            max_iterations: 800,
            tolerance: 1e-4,
            optimizer_type: OptimizerType::RMSprop,
            learning_rate: 0.05,
            beta1: 0.9,
            beta2: 0.9,
            epsilon: 1e-8,
        };
        let mut optimizer = CoxOptimizer::new(config);

        let result = optimizer.optimize(&data);
        assert!(result.is_ok());

        let beta = result.unwrap();
        assert_eq!(beta.len(), 2);
        assert!(beta.iter().all(|&x| x.is_finite()));
    }
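
    // Supplementary sanity checks (illustrative additions) for the in-module
    // numerical helpers; the test module can reach these private methods directly.

    #[test]
    fn test_solve_linear_system_small() {
        // The Gaussian-elimination solver should recover the exact solution of a
        // small well-conditioned system: [[2, 1], [1, 3]] x = [3, 5] => x = [0.8, 1.4].
        let optimizer = CoxOptimizer::new(OptimizationConfig::default());
        let a = ndarray::arr2(&[[2.0, 1.0], [1.0, 3.0]]);
        let b = ndarray::arr1(&[3.0, 5.0]);

        let x = optimizer.solve_linear_system(&a, &b).unwrap();
        assert_relative_eq!(x[0], 0.8, epsilon = 1e-10);
        assert_relative_eq!(x[1], 1.4, epsilon = 1e-10);
    }

    #[test]
    fn test_log_likelihood_at_zero_beta() {
        // At beta = 0 every exp(x'beta) equals 1, so with the all-event test data
        // (distinct times 1..5) the partial log-likelihood reduces to
        // -(ln 5 + ln 4 + ln 3 + ln 2) = -ln 120. This assumes `event_times()`
        // yields each distinct event time once, as the other helpers do.
        let data = create_test_data();
        let optimizer = CoxOptimizer::new(OptimizationConfig::default());
        let beta = ndarray::Array1::zeros(2);

        let loglik = optimizer.compute_log_likelihood(&data, &beta).unwrap();
        assert_relative_eq!(loglik, -(120.0_f64.ln()), epsilon = 1e-10);
    }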
}