Skip to main content

scirs2_stats/
advanced_integration.rs

1//! Advanced Statistical Analysis Integration
2//!
3//! This module provides high-level interfaces that integrate multiple advanced
4//! statistical methods for comprehensive data analysis workflows.
5
6use crate::error::{StatsError, StatsResult as Result};
7use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
8use scirs2_core::validation::*;
9
10use crate::bayesian::{BayesianLinearRegression, BayesianRegressionResult};
11use crate::mcmc::{GibbsSampler, MultivariateNormalGibbs};
12use crate::multivariate::{FactorAnalysis, FactorAnalysisResult, PCAResult, PCA};
13use crate::qmc::{halton, latin_hypercube, sobol};
14use crate::survival::{CoxPHModel, KaplanMeierEstimator};
15
/// Configuration for a comprehensive Bayesian linear-regression analysis.
///
/// Construct with `BayesianAnalysisWorkflow::new()` (or `Default`), adjust with the
/// `with_mcmc` / `without_mcmc` / `with_seed` builder methods, then call `analyze`.
#[derive(Debug, Clone)]
pub struct BayesianAnalysisWorkflow {
    /// Whether `analyze` draws MCMC posterior samples and posterior predictive samples
    /// (default `true`).
    pub use_mcmc: bool,
    /// Number of retained MCMC draws; also the number of posterior predictive draws
    /// (default 1000).
    pub n_mcmc_samples: usize,
    /// Number of initial MCMC steps discarded as burn-in (default 100).
    pub mcmc_burnin: usize,
    /// Seed for the sampling RNGs; `None` derives a seed from the system clock
    /// (default `None`).
    pub random_seed: Option<u64>,
}
28
29impl Default for BayesianAnalysisWorkflow {
30    fn default() -> Self {
31        Self {
32            use_mcmc: true,
33            n_mcmc_samples: 1000,
34            mcmc_burnin: 100,
35            random_seed: None,
36        }
37    }
38}
39
/// Output of `BayesianAnalysisWorkflow::analyze`.
#[derive(Debug, Clone)]
pub struct BayesianAnalysisResult {
    /// Fitted Bayesian linear regression (posterior over the coefficients).
    pub regression: BayesianRegressionResult,
    /// Posterior coefficient draws from the Gibbs sampler; `None` when MCMC is disabled.
    pub mcmc_samples: Option<Array2<f64>>,
    /// Posterior predictive draws at the training inputs: one row per draw, one column per
    /// observation. `None` when MCMC is disabled.
    pub predictive_samples: Option<Array2<f64>>,
    /// Simplified model-comparison criteria.
    pub model_metrics: BayesianModelMetrics,
}
52
/// Bayesian model comparison metrics.
///
/// Apart from the log marginal likelihood, these are simplified approximations computed
/// from the log marginal likelihood and the parameter count on the deviance scale
/// (`-2 * log likelihood` plus a complexity penalty). They are not the full
/// posterior-sample estimators.
#[derive(Debug, Clone)]
pub struct BayesianModelMetrics {
    /// Log marginal likelihood (model evidence) reported by the regression fit.
    pub log_marginal_likelihood: f64,
    /// Deviance Information Criterion (simplified).
    pub dic: f64,
    /// Widely Applicable Information Criterion (simplified).
    pub waic: f64,
    /// Leave-One-Out Cross-Validation information criterion (simplified, with a
    /// small-sample correction).
    pub loo_ic: f64,
}
65
66impl BayesianAnalysisWorkflow {
67    /// Create new Bayesian analysis workflow
68    pub fn new() -> Self {
69        Self::default()
70    }
71
72    /// Configure MCMC settings
73    pub fn with_mcmc(mut self, n_samples_: usize, burnin: usize) -> Self {
74        self.use_mcmc = true;
75        self.n_mcmc_samples = n_samples_;
76        self.mcmc_burnin = burnin;
77        self
78    }
79
80    /// Disable MCMC sampling
81    pub fn without_mcmc(mut self) -> Self {
82        self.use_mcmc = false;
83        self
84    }
85
86    /// Set random seed
87    pub fn with_seed(mut self, seed: u64) -> Self {
88        self.random_seed = Some(seed);
89        self
90    }
91
92    /// Perform comprehensive Bayesian analysis
93    pub fn analyze(
94        &self,
95        x: ArrayView2<f64>,
96        y: ArrayView1<f64>,
97    ) -> Result<BayesianAnalysisResult> {
98        checkarray_finite(&x, "x")?;
99        checkarray_finite(&y, "y")?;
100
101        let (n_samples_, n_features) = x.dim();
102        if y.len() != n_samples_ {
103            return Err(StatsError::DimensionMismatch(format!(
104                "y length ({}) must match x rows ({})",
105                y.len(),
106                n_samples_
107            )));
108        }
109
110        // Perform Bayesian linear regression
111        let bayesian_reg = BayesianLinearRegression::new(n_features, true)?;
112        let regression = bayesian_reg.fit(x, y)?;
113
114        // MCMC sampling if requested
115        let mcmc_samples = if self.use_mcmc {
116            Some(self.perform_mcmc_sampling(&regression, n_features)?)
117        } else {
118            None
119        };
120
121        // Generate predictive samples
122        let predictive_samples = if self.use_mcmc {
123            Some(self.generate_predictive_samples(&bayesian_reg, &regression, x)?)
124        } else {
125            None
126        };
127
128        // Compute model metrics
129        let model_metrics = self.compute_model_metrics(&regression, x, y)?;
130
131        Ok(BayesianAnalysisResult {
132            regression,
133            mcmc_samples,
134            predictive_samples,
135            model_metrics,
136        })
137    }
138
139    /// Perform MCMC sampling from posterior
140    fn perform_mcmc_sampling(
141        &self,
142        regression: &BayesianRegressionResult,
143        _n_features: usize,
144    ) -> Result<Array2<f64>> {
145        use scirs2_core::random::{rngs::StdRng, SeedableRng};
146
147        let mut rng = match self.random_seed {
148            Some(seed) => StdRng::seed_from_u64(seed),
149            None => {
150                use std::time::{SystemTime, UNIX_EPOCH};
151                let seed = SystemTime::now()
152                    .duration_since(UNIX_EPOCH)
153                    .unwrap_or_default()
154                    .as_secs();
155                StdRng::seed_from_u64(seed)
156            }
157        };
158
159        // Use Gibbs sampling for multivariate normal posterior
160        let gibbs_sampler = MultivariateNormalGibbs::from_precision(
161            regression.posterior_mean.clone(),
162            regression.posterior_covariance.clone(),
163        )?;
164
165        let mut sampler = GibbsSampler::new(gibbs_sampler, regression.posterior_mean.clone())?;
166
167        // Burn-in
168        for _ in 0..self.mcmc_burnin {
169            sampler.step(&mut rng)?;
170        }
171
172        // Collect samples
173        let samples = sampler.sample(self.n_mcmc_samples, &mut rng)?;
174        Ok(samples)
175    }
176
177    /// Generate posterior predictive samples
178    fn generate_predictive_samples(
179        &self,
180        bayesian_reg: &BayesianLinearRegression,
181        regression: &BayesianRegressionResult,
182        x_test: ArrayView2<f64>,
183    ) -> Result<Array2<f64>> {
184        use scirs2_core::random::{rngs::StdRng, SeedableRng};
185        use scirs2_core::random::{Distribution, Normal};
186
187        let mut rng = match self.random_seed {
188            Some(seed) => StdRng::seed_from_u64(seed),
189            None => {
190                use std::time::{SystemTime, UNIX_EPOCH};
191                let seed = SystemTime::now()
192                    .duration_since(UNIX_EPOCH)
193                    .unwrap_or_default()
194                    .as_secs();
195                StdRng::seed_from_u64(seed)
196            }
197        };
198
199        let n_test = x_test.nrows();
200        let mut predictive_samples = Array2::zeros((self.n_mcmc_samples, n_test));
201
202        // Generate predictive samples
203        for i in 0..self.n_mcmc_samples {
204            // Sample from posterior parameter distribution
205            let mut beta_sample = Array1::zeros(regression.posterior_mean.len());
206            for j in 0..beta_sample.len() {
207                let var = regression.posterior_covariance[[j, j]];
208                let normal =
209                    Normal::new(regression.posterior_mean[j], var.sqrt()).map_err(|e| {
210                        StatsError::ComputationError(format!("Failed to create normal: {}", e))
211                    })?;
212                beta_sample[j] = normal.sample(&mut rng);
213            }
214
215            // Generate predictions with this parameter sample
216            let pred_result = bayesian_reg.predict(x_test, regression)?;
217
218            // Add noise
219            let noise_std = (regression.posterior_beta / regression.posterior_alpha).sqrt();
220            let noise_normal = Normal::new(0.0, noise_std).map_err(|e| {
221                StatsError::ComputationError(format!("Failed to create noise normal: {}", e))
222            })?;
223
224            for j in 0..n_test {
225                let noise = noise_normal.sample(&mut rng);
226                predictive_samples[[i, j]] = pred_result.mean[j] + noise;
227            }
228        }
229
230        Ok(predictive_samples)
231    }
232
233    /// Compute Bayesian model comparison metrics
234    fn compute_model_metrics(
235        &self,
236        regression: &BayesianRegressionResult,
237        x: ArrayView2<f64>,
238        _y: ArrayView1<f64>,
239    ) -> Result<BayesianModelMetrics> {
240        let n_samples_ = x.nrows() as f64;
241        let n_params = regression.posterior_mean.len() as f64;
242
243        // Log marginal likelihood (already computed)
244        let log_marginal_likelihood = regression.log_marginal_likelihood;
245
246        // Simplified DIC calculation
247        let deviance = -2.0 * log_marginal_likelihood;
248        let effective_params = n_params; // Simplified
249        let dic = deviance + 2.0 * effective_params;
250
251        // Simplified WAIC (Watanabe-Akaike Information Criterion)
252        let waic = -2.0 * log_marginal_likelihood + 2.0 * effective_params;
253
254        // Simplified LOO-IC (Leave-One-Out Information Criterion)
255        let loo_ic = -2.0 * log_marginal_likelihood
256            + 2.0 * effective_params * n_samples_ / (n_samples_ - n_params - 1.0);
257
258        Ok(BayesianModelMetrics {
259            log_marginal_likelihood,
260            dic,
261            waic,
262            loo_ic,
263        })
264    }
265}
266
/// Configuration for a combined PCA / factor-analysis dimensionality study.
///
/// Construct with `DimensionalityAnalysisWorkflow::new()` (or `Default`), configure with the
/// `with_*` builder methods, then call `analyze`.
#[derive(Debug, Clone)]
pub struct DimensionalityAnalysisWorkflow {
    /// Number of PCA components to keep; `None` keeps `min(n_features, n_samples)`.
    pub n_pca_components: Option<usize>,
    /// Number of factors for factor analysis; `None` skips factor analysis entirely.
    pub n_factors: Option<usize>,
    /// Request incremental PCA when the data has more rows than `pca_batchsize`.
    /// NOTE: incremental PCA currently falls back to standard PCA.
    pub use_incremental_pca: bool,
    /// Incremental-PCA batch size, also used as the row-count threshold for it
    /// (default 1000).
    pub pca_batchsize: usize,
    /// Seed forwarded to PCA, factor analysis and parallel analysis; `None` passes no seed.
    pub random_seed: Option<u64>,
}
281
282impl Default for DimensionalityAnalysisWorkflow {
283    fn default() -> Self {
284        Self {
285            n_pca_components: None,
286            n_factors: None,
287            use_incremental_pca: false,
288            pca_batchsize: 1000,
289            random_seed: None,
290        }
291    }
292}
293
/// Output of `DimensionalityAnalysisWorkflow::analyze`.
#[derive(Debug, Clone)]
pub struct DimensionalityAnalysisResult {
    /// PCA fit (always `Some` when produced by `analyze`).
    pub pca: Option<PCAResult>,
    /// Factor analysis fit; `None` unless `n_factors` was configured.
    pub factor_analysis: Option<FactorAnalysisResult>,
    /// Recommended number of components/factors.
    pub recommendations: DimensionalityRecommendations,
    /// Covariance-eigenvalue, KMO and Bartlett diagnostics.
    pub comparison_metrics: DimensionalityMetrics,
}
306
/// Recommendations for dimensionality reduction.
#[derive(Debug, Clone)]
pub struct DimensionalityRecommendations {
    /// Optimal number of PCA components by the Kaiser criterion: the count of leading PCA
    /// explained-variance values that are at least 1.
    pub optimal_pca_components: usize,
    /// Optimal number of factors, as returned by parallel analysis.
    pub optimal_factors: usize,
    /// Sum of the PCA explained-variance ratios of the recommended components.
    pub explained_variance_ratio: f64,
}
317
/// Diagnostic metrics for comparing dimensionality-reduction choices.
#[derive(Debug, Clone)]
pub struct DimensionalityMetrics {
    /// Eigenvalues of the sample covariance matrix in descending order (scree-plot data).
    pub eigenvalues: Array1<f64>,
    /// Cumulative fraction of total variance explained by the first k eigenvalues.
    pub cumulative_variance: Array1<f64>,
    /// Kaiser-Meyer-Olkin (KMO) measure, as returned by `efa::kmo_test`.
    pub kmo_measure: f64,
    /// Bartlett's test result as `(statistic, p_value)`, from `efa::bartlett_test`.
    pub bartlett_test: (f64, f64),
}
330
331impl DimensionalityAnalysisWorkflow {
332    /// Create new dimensionality analysis workflow
333    pub fn new() -> Self {
334        Self::default()
335    }
336
337    /// Set PCA configuration
338    pub fn with_pca(
339        mut self,
340        n_components: Option<usize>,
341        incremental: bool,
342        batchsize: usize,
343    ) -> Self {
344        self.n_pca_components = n_components;
345        self.use_incremental_pca = incremental;
346        self.pca_batchsize = batchsize;
347        self
348    }
349
350    /// Set factor analysis configuration
351    pub fn with_factor_analysis(mut self, n_factors: Option<usize>) -> Self {
352        self.n_factors = n_factors;
353        self
354    }
355
356    /// Set random seed
357    pub fn with_seed(mut self, seed: u64) -> Self {
358        self.random_seed = Some(seed);
359        self
360    }
361
362    /// Perform comprehensive dimensionality analysis
363    pub fn analyze(&self, data: ArrayView2<f64>) -> Result<DimensionalityAnalysisResult> {
364        checkarray_finite(&data, "data")?;
365        let (n_samples_, n_features) = data.dim();
366
367        if n_samples_ < 3 {
368            return Err(StatsError::InvalidArgument(
369                "Need at least 3 samples for analysis".to_string(),
370            ));
371        }
372
373        // Perform PCA analysis
374        let pca = if self.use_incremental_pca && n_samples_ > self.pca_batchsize {
375            Some(self.perform_incremental_pca(data)?)
376        } else {
377            Some(self.perform_standard_pca(data)?)
378        };
379
380        // Perform factor analysis if requested
381        let factor_analysis = if self.n_factors.is_some() {
382            Some(self.perform_factor_analysis(data)?)
383        } else {
384            None
385        };
386
387        // Generate recommendations
388        let recommendations = self.generate_recommendations(data, &pca)?;
389
390        // Compute comparison metrics
391        let comparison_metrics = self.compute_metrics(data)?;
392
393        Ok(DimensionalityAnalysisResult {
394            pca,
395            factor_analysis,
396            recommendations,
397            comparison_metrics,
398        })
399    }
400
401    /// Perform standard PCA
402    fn perform_standard_pca(&self, data: ArrayView2<f64>) -> Result<PCAResult> {
403        let n_components = self
404            .n_pca_components
405            .unwrap_or(data.ncols().min(data.nrows()));
406
407        let pca = PCA::new()
408            .with_n_components(n_components)
409            .with_center(true)
410            .with_scale(false);
411
412        if let Some(seed) = self.random_seed {
413            pca.with_random_state(seed).fit(data)
414        } else {
415            pca.fit(data)
416        }
417    }
418
419    /// Perform incremental PCA for large datasets
420    fn perform_incremental_pca(&self, data: ArrayView2<f64>) -> Result<PCAResult> {
421        // For now, fall back to standard PCA since IncrementalPCA fields are private
422        // This would need to be implemented with public accessors in the actual IncrementalPCA
423        self.perform_standard_pca(data)
424    }
425
426    /// Perform factor analysis
427    fn perform_factor_analysis(&self, data: ArrayView2<f64>) -> Result<FactorAnalysisResult> {
428        use crate::multivariate::RotationType;
429
430        let n_factors = self.n_factors.unwrap_or(2);
431
432        let mut fa = FactorAnalysis::new(n_factors)?
433            .with_rotation(RotationType::Varimax)
434            .with_max_iter(1000)
435            .with_tolerance(1e-6);
436
437        if let Some(seed) = self.random_seed {
438            fa = fa.with_random_state(seed);
439        }
440
441        fa.fit(data)
442    }
443
444    /// Generate dimensionality recommendations
445    fn generate_recommendations(
446        &self,
447        data: ArrayView2<f64>,
448        pca: &Option<PCAResult>,
449    ) -> Result<DimensionalityRecommendations> {
450        use crate::multivariate::{efa::parallel_analysis, mle_components};
451
452        // Kaiser criterion for PCA (eigenvalues > 1)
453        let optimal_pca_components = if let Some(ref pca_result) = pca {
454            pca_result
455                .explained_variance
456                .iter()
457                .position(|&ev| ev < 1.0)
458                .unwrap_or(pca_result.explained_variance.len())
459        } else {
460            mle_components(data, None)?
461        };
462
463        // Parallel analysis for factor analysis
464        let optimal_factors = parallel_analysis(data, 100, 95.0, self.random_seed)?;
465
466        // Explained variance ratio
467        let explained_variance_ratio = if let Some(ref pca_result) = pca {
468            pca_result
469                .explained_variance_ratio
470                .slice(scirs2_core::ndarray::s![..optimal_pca_components])
471                .sum()
472        } else {
473            0.0
474        };
475
476        Ok(DimensionalityRecommendations {
477            optimal_pca_components,
478            optimal_factors,
479            explained_variance_ratio,
480        })
481    }
482
483    /// Compute comparison metrics
484    fn compute_metrics(&self, data: ArrayView2<f64>) -> Result<DimensionalityMetrics> {
485        use crate::multivariate::efa::{bartlett_test, kmo_test};
486
487        // Compute covariance matrix for eigenvalues
488        let mean = data
489            .mean_axis(scirs2_core::ndarray::Axis(0))
490            .expect("Operation failed");
491        let mut centered = data.to_owned();
492        for mut row in centered.rows_mut() {
493            row -= &mean;
494        }
495
496        let cov = centered.t().dot(&centered) / (data.nrows() - 1) as f64;
497
498        // Compute eigenvalues using scirs2_linalg
499        let (eigenvalues_unsorted, _eigenvectors) = scirs2_linalg::eigh_f64_lapack(&cov.view())
500            .map_err(|e| {
501                StatsError::ComputationError(format!("Eigenvalue decomposition failed: {}", e))
502            })?;
503
504        // Sort eigenvalues in descending order
505        let mut sorted_eigenvalues: Vec<f64> = eigenvalues_unsorted.to_vec();
506        sorted_eigenvalues.sort_by(|a: &f64, b: &f64| b.partial_cmp(a).expect("Operation failed"));
507        let eigenvalues = Array1::from_vec(sorted_eigenvalues);
508
509        // Cumulative variance
510        let total_variance = eigenvalues.sum();
511        let mut cumulative_variance = Array1::zeros(eigenvalues.len());
512        let mut cumsum = 0.0;
513        for i in 0..eigenvalues.len() {
514            cumsum += eigenvalues[i];
515            cumulative_variance[i] = cumsum / total_variance;
516        }
517
518        // KMO measure
519        let kmo_measure = kmo_test(data)?;
520
521        // Bartlett's test
522        let bartlett_test = bartlett_test(data)?;
523
524        Ok(DimensionalityMetrics {
525            eigenvalues,
526            cumulative_variance,
527            kmo_measure,
528            bartlett_test,
529        })
530    }
531}
532
/// Quasi-Monte Carlo sampling workflow: generates a low-discrepancy point set and assesses
/// its quality.
#[derive(Debug, Clone)]
pub struct QMCWorkflow {
    /// Sequence type (default Sobol).
    pub sequence_type: QMCSequenceType,
    /// Whether Sobol/Halton points are scrambled (default `true`); not passed to the
    /// Latin hypercube generator.
    pub scrambling: bool,
    /// Number of dimensions of each point (default 2).
    pub dimensions: usize,
    /// Number of points to generate (default 1000).
    pub n_samples_: usize,
    /// Seed forwarded to the generator; `None` passes no seed.
    pub random_seed: Option<u64>,
}
547
/// Point-set generator used by `QMCWorkflow`.
#[derive(Debug, Clone, Copy)]
pub enum QMCSequenceType {
    /// Sobol low-discrepancy sequence (optionally scrambled).
    Sobol,
    /// Halton low-discrepancy sequence (optionally scrambled).
    Halton,
    /// Latin hypercube sampling.
    LatinHypercube,
}
558
/// Output of `QMCWorkflow::generate`.
#[derive(Debug, Clone)]
pub struct QMCResult {
    /// Generated points: one row per point, one column per dimension.
    pub samples: Array2<f64>,
    /// Sequence type used to generate `samples`.
    pub sequence_type: QMCSequenceType,
    /// Quality metrics of the generated point set.
    pub quality_metrics: QMCQualityMetrics,
}
569
/// Quality metrics for a QMC point set.
#[derive(Debug, Clone)]
pub struct QMCQualityMetrics {
    /// Star discrepancy, as computed by `crate::qmc::star_discrepancy`.
    pub star_discrepancy: f64,
    /// Inverse coefficient of variation (mean / standard deviation) of the
    /// nearest-neighbour distances; larger values indicate more even spacing.
    pub uniformity: f64,
    /// Fraction of cells, in a regular grid with roughly `n^(1/d)` bins per axis, that
    /// contain at least one point.
    pub coverage_efficiency: f64,
}
580
581impl Default for QMCWorkflow {
582    fn default() -> Self {
583        Self {
584            sequence_type: QMCSequenceType::Sobol,
585            scrambling: true,
586            dimensions: 2,
587            n_samples_: 1000,
588            random_seed: None,
589        }
590    }
591}
592
593impl QMCWorkflow {
594    /// Create new QMC workflow
595    pub fn new(dimensions: usize, n_samples_: usize) -> Self {
596        Self {
597            dimensions,
598            n_samples_,
599            ..Default::default()
600        }
601    }
602
603    /// Set sequence type
604    pub fn with_sequence_type(mut self, sequence_type: QMCSequenceType) -> Self {
605        self.sequence_type = sequence_type;
606        self
607    }
608
609    /// Enable or disable scrambling
610    pub fn with_scrambling(mut self, scrambling: bool) -> Self {
611        self.scrambling = scrambling;
612        self
613    }
614
615    /// Set random seed
616    pub fn with_seed(mut self, seed: u64) -> Self {
617        self.random_seed = Some(seed);
618        self
619    }
620
621    /// Generate QMC samples with quality assessment
622    pub fn generate(&self) -> Result<QMCResult> {
623        check_positive(self.dimensions, "dimensions")?;
624        check_positive(self.n_samples_, "n_samples_")?;
625
626        // Generate samples based on sequence type
627        let samples = match self.sequence_type {
628            QMCSequenceType::Sobol => sobol(
629                self.n_samples_,
630                self.dimensions,
631                self.scrambling,
632                self.random_seed,
633            )?,
634            QMCSequenceType::Halton => halton(
635                self.n_samples_,
636                self.dimensions,
637                self.scrambling,
638                self.random_seed,
639            )?,
640            QMCSequenceType::LatinHypercube => {
641                latin_hypercube(self.n_samples_, self.dimensions, self.random_seed)?
642            }
643        };
644
645        // Compute quality metrics
646        let quality_metrics = self.compute_quality_metrics(&samples)?;
647
648        Ok(QMCResult {
649            samples,
650            sequence_type: self.sequence_type,
651            quality_metrics,
652        })
653    }
654
655    /// Compute quality metrics for the sequence
656    fn compute_quality_metrics(&self, samples: &Array2<f64>) -> Result<QMCQualityMetrics> {
657        use crate::qmc::star_discrepancy;
658
659        // Convert to format expected by star_discrepancy
660        let sample_points: Vec<Array1<f64>> = samples
661            .rows()
662            .into_iter()
663            .map(|row| row.to_owned())
664            .collect();
665
666        let samples_view = Array1::from_vec(sample_points);
667        let star_discrepancy = star_discrepancy(&samples_view.view())?;
668
669        // Compute uniformity measure (coefficient of variation of nearest neighbor distances)
670        let uniformity = self.compute_uniformity(samples)?;
671
672        // Compute coverage efficiency
673        let coverage_efficiency = self.compute_coverage_efficiency(samples)?;
674
675        Ok(QMCQualityMetrics {
676            star_discrepancy,
677            uniformity,
678            coverage_efficiency,
679        })
680    }
681
682    /// Compute uniformity measure
683    fn compute_uniformity(&self, samples: &Array2<f64>) -> Result<f64> {
684        let n_samples_ = samples.nrows();
685        let mut min_distances = Array1::zeros(n_samples_);
686
687        // Compute minimum distance to other points for each sample
688        for i in 0..n_samples_ {
689            let mut min_dist = f64::INFINITY;
690            for j in 0..n_samples_ {
691                if i != j {
692                    let mut dist = 0.0;
693                    for k in 0..self.dimensions {
694                        let diff = samples[[i, k]] - samples[[j, k]];
695                        dist += diff * diff;
696                    }
697                    dist = dist.sqrt();
698                    if dist < min_dist {
699                        min_dist = dist;
700                    }
701                }
702            }
703            min_distances[i] = min_dist;
704        }
705
706        // Coefficient of variation of minimum distances
707        let mean_dist = min_distances.mean().expect("Operation failed");
708        let var_dist = min_distances.var(1.0);
709        let uniformity = 1.0 / (var_dist.sqrt() / mean_dist); // Inverse CV
710
711        Ok(uniformity)
712    }
713
714    /// Compute coverage efficiency
715    fn compute_coverage_efficiency(&self, samples: &Array2<f64>) -> Result<f64> {
716        // Simple approximation: ratio of actual coverage to expected coverage
717        let n_bins = (self.n_samples_ as f64)
718            .powf(1.0 / self.dimensions as f64)
719            .ceil() as usize;
720        let mut occupied_bins = std::collections::HashSet::new();
721
722        for i in 0..samples.nrows() {
723            let mut bin_id = Vec::new();
724            for j in 0..self.dimensions {
725                let bin = (samples[[i, j]] * n_bins as f64).floor() as usize;
726                bin_id.push(bin.min(n_bins - 1));
727            }
728            occupied_bins.insert(bin_id);
729        }
730
731        let total_bins = n_bins.pow(self.dimensions as u32);
732        let coverage_efficiency = occupied_bins.len() as f64 / total_bins as f64;
733
734        Ok(coverage_efficiency)
735    }
736}
737
/// Configuration for a survival analysis combining Kaplan-Meier estimation with an optional
/// Cox proportional hazards model.
#[derive(Debug, Clone)]
pub struct SurvivalAnalysisWorkflow {
    /// Confidence level passed to the Kaplan-Meier estimator (default 0.95).
    pub confidence_level: f64,
    /// Whether to fit a Cox model when covariates are supplied (default `true`).
    pub fit_cox_model: bool,
    /// Maximum iterations for the Cox model fit (default 100).
    pub cox_max_iter: usize,
    /// Convergence tolerance for the Cox model fit (default 1e-6).
    pub cox_tolerance: f64,
}
750
751impl Default for SurvivalAnalysisWorkflow {
752    fn default() -> Self {
753        Self {
754            confidence_level: 0.95,
755            fit_cox_model: true,
756            cox_max_iter: 100,
757            cox_tolerance: 1e-6,
758        }
759    }
760}
761
/// Output of `SurvivalAnalysisWorkflow::analyze`.
#[derive(Debug, Clone)]
pub struct SurvivalAnalysisResult {
    /// Fitted Kaplan-Meier estimator.
    pub kaplan_meier: crate::survival::KaplanMeierEstimator,
    /// Cox proportional hazards model; `None` unless Cox fitting is enabled and covariates
    /// were provided.
    pub cox_model: Option<crate::survival::CoxPHModel>,
    /// Summary statistics derived from the data and the Kaplan-Meier fit.
    pub summary_stats: SurvivalSummaryStats,
}
772
/// Summary statistics for survival analysis.
#[derive(Debug, Clone)]
pub struct SurvivalSummaryStats {
    /// Median survival time, as reported by the Kaplan-Meier estimator.
    pub median_survival: Option<f64>,
    /// 25th percentile survival time: the first event time with estimated survival
    /// <= 0.75, or `None` if survival never drops that low.
    pub q25_survival: Option<f64>,
    /// 75th percentile survival time: the first event time with estimated survival
    /// <= 0.25, or `None` if survival never drops that low.
    pub q75_survival: Option<f64>,
    /// Fraction of observations whose event was observed.
    pub event_rate: f64,
    /// Fraction of censored observations (`1 - event_rate`).
    pub censoring_rate: f64,
}
787
788impl SurvivalAnalysisWorkflow {
789    /// Create new survival analysis workflow
790    pub fn new() -> Self {
791        Self::default()
792    }
793
794    /// Set confidence level
795    pub fn with_confidence_level(mut self, level: f64) -> Self {
796        self.confidence_level = level;
797        self
798    }
799
800    /// Configure Cox model fitting
801    pub fn with_cox_model(mut self, max_iter: usize, tolerance: f64) -> Self {
802        self.fit_cox_model = true;
803        self.cox_max_iter = max_iter;
804        self.cox_tolerance = tolerance;
805        self
806    }
807
808    /// Disable Cox model fitting
809    pub fn without_cox_model(mut self) -> Self {
810        self.fit_cox_model = false;
811        self
812    }
813
814    /// Perform comprehensive survival analysis
815    pub fn analyze(
816        &self,
817        durations: ArrayView1<f64>,
818        events: ArrayView1<bool>,
819        covariates: Option<ArrayView2<f64>>,
820    ) -> Result<SurvivalAnalysisResult> {
821        checkarray_finite(&durations, "durations")?;
822
823        if durations.len() != events.len() {
824            return Err(StatsError::DimensionMismatch(format!(
825                "durations length ({}) must match events length ({})",
826                durations.len(),
827                events.len()
828            )));
829        }
830
831        // Fit Kaplan-Meier estimator
832        let kaplan_meier =
833            KaplanMeierEstimator::fit(durations, events, Some(self.confidence_level))?;
834
835        // Fit Cox model if requested and covariates provided
836        let cox_model = if self.fit_cox_model {
837            if let Some(cov) = covariates {
838                Some(CoxPHModel::fit(
839                    durations,
840                    events,
841                    cov,
842                    Some(self.cox_max_iter),
843                    Some(self.cox_tolerance),
844                )?)
845            } else {
846                None
847            }
848        } else {
849            None
850        };
851
852        // Compute summary statistics
853        let summary_stats = self.compute_summary_stats(&durations, &events, &kaplan_meier)?;
854
855        Ok(SurvivalAnalysisResult {
856            kaplan_meier,
857            cox_model,
858            summary_stats,
859        })
860    }
861
862    /// Compute survival summary statistics
863    fn compute_summary_stats(
864        &self,
865        _durations: &ArrayView1<f64>,
866        events: &ArrayView1<bool>,
867        km: &KaplanMeierEstimator,
868    ) -> Result<SurvivalSummaryStats> {
869        // Event and censoring rates
870        let total_events: usize = events.iter().map(|&e| if e { 1 } else { 0 }).sum();
871        let total_observations = events.len();
872        let event_rate = total_events as f64 / total_observations as f64;
873        let censoring_rate = 1.0 - event_rate;
874
875        // Median survival time (already computed in KM estimator)
876        let median_survival = km.median_survival_time;
877
878        // Percentile survival times
879        let q25_survival = self.find_survival_percentile(km, 0.75)?; // 75% survival = 25th percentile time
880        let q75_survival = self.find_survival_percentile(km, 0.25)?; // 25% survival = 75th percentile time
881
882        Ok(SurvivalSummaryStats {
883            median_survival,
884            q25_survival,
885            q75_survival,
886            event_rate,
887            censoring_rate,
888        })
889    }
890
891    /// Find time at which survival probability equals target
892    fn find_survival_percentile(
893        &self,
894        km: &KaplanMeierEstimator,
895        target_survival: f64,
896    ) -> Result<Option<f64>> {
897        for i in 0..km.survival_function.len() {
898            if km.survival_function[i] <= target_survival {
899                return Ok(Some(km.event_times[i]));
900            }
901        }
902        Ok(None) // Target survival not reached
903    }
904}