1#![allow(dead_code)]
10
11use crate::error::{StatsError, StatsResult};
12use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
13use scirs2_core::numeric::{Float, NumCast};
14use serde::{Deserialize, Serialize};
15use std::collections::HashMap;
16use std::marker::PhantomData;
17
18#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct StandardizedConfig {
21 pub auto_optimize: bool,
23 pub parallel: bool,
25 pub simd: bool,
27 pub memory_limit: Option<usize>,
29 pub confidence_level: f64,
31 pub null_handling: NullHandling,
33 pub output_precision: usize,
35 pub include_metadata: bool,
37}
38
39impl Default for StandardizedConfig {
40 fn default() -> Self {
41 Self {
42 auto_optimize: true,
43 parallel: true,
44 simd: true,
45 memory_limit: None,
46 confidence_level: 0.95,
47 null_handling: NullHandling::Exclude,
48 output_precision: 6,
49 include_metadata: false,
50 }
51 }
52}
53
54#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
56pub enum NullHandling {
57 Exclude,
59 Propagate,
61 Replace(f64),
63 Fail,
65}
66
67#[derive(Debug, Clone, Serialize, Deserialize)]
69pub struct StandardizedResult<T> {
70 pub value: T,
72 pub metadata: ResultMetadata,
74 pub warnings: Vec<String>,
76}
77
78#[derive(Debug, Clone, Serialize, Deserialize)]
80pub struct ResultMetadata {
81 pub samplesize: usize,
83 pub degrees_of_freedom: Option<usize>,
85 pub confidence_level: Option<f64>,
87 pub method: String,
89 pub computation_time_ms: f64,
91 pub memory_usage_bytes: Option<usize>,
93 pub optimized: bool,
95 pub extra: HashMap<String, String>,
97}
98
99pub struct DescriptiveStatsBuilder<F> {
101 config: StandardizedConfig,
102 ddof: Option<usize>,
103 axis: Option<usize>,
104 weights: Option<Array1<F>>,
105 phantom: PhantomData<F>,
106}
107
108pub struct CorrelationBuilder<F> {
110 config: StandardizedConfig,
111 method: CorrelationMethod,
112 min_periods: Option<usize>,
113 controls: Option<Array2<F>>,
115 phantom: PhantomData<F>,
116}
117
118pub struct StatisticalTestBuilder<F> {
120 config: StandardizedConfig,
121 alternative: Alternative,
122 equal_var: bool,
123 phantom: PhantomData<F>,
124}
125
126#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
128pub enum CorrelationMethod {
129 Pearson,
130 Spearman,
131 Kendall,
132 PartialPearson,
133 PartialSpearman,
134}
135
136#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
138pub enum Alternative {
139 TwoSided,
140 Less,
141 Greater,
142}
143
144pub struct StatsAnalyzer<F> {
146 config: StandardizedConfig,
147 phantom: PhantomData<F>,
148}
149
150#[derive(Debug, Clone, Serialize, Deserialize)]
152pub struct DescriptiveStats<F> {
153 pub count: usize,
154 pub mean: F,
155 pub std: F,
156 pub min: F,
157 pub percentile_25: F,
158 pub median: F,
159 pub percentile_75: F,
160 pub max: F,
161 pub variance: F,
162 pub skewness: F,
163 pub kurtosis: F,
164}
165
166#[derive(Debug, Clone, Serialize, Deserialize)]
168pub struct CorrelationResult<F> {
169 pub correlation: F,
170 pub p_value: Option<F>,
171 pub confidence_interval: Option<(F, F)>,
172 pub method: CorrelationMethod,
173}
174
175#[derive(Debug, Clone, Serialize, Deserialize)]
177pub struct TestResult<F> {
178 pub statistic: F,
179 pub p_value: F,
180 pub confidence_interval: Option<(F, F)>,
181 pub effectsize: Option<F>,
182 pub power: Option<F>,
183}
184
185impl<F> DescriptiveStatsBuilder<F>
186where
187 F: Float
188 + NumCast
189 + Clone
190 + scirs2_core::simd_ops::SimdUnifiedOps
191 + std::iter::Sum<F>
192 + std::ops::Div<Output = F>
193 + Sync
194 + Send
195 + std::fmt::Display
196 + std::fmt::Debug
197 + 'static,
198{
199 pub fn new() -> Self {
201 Self {
202 config: StandardizedConfig::default(),
203 ddof: None,
204 axis: None,
205 weights: None,
206 phantom: PhantomData,
207 }
208 }
209
210 pub fn ddof(mut self, ddof: usize) -> Self {
212 self.ddof = Some(ddof);
213 self
214 }
215
216 pub fn axis(mut self, axis: usize) -> Self {
218 self.axis = Some(axis);
219 self
220 }
221
222 pub fn weights(mut self, weights: Array1<F>) -> Self {
224 self.weights = Some(weights);
225 self
226 }
227
228 pub fn parallel(mut self, enable: bool) -> Self {
230 self.config.parallel = enable;
231 self
232 }
233
234 pub fn simd(mut self, enable: bool) -> Self {
236 self.config.simd = enable;
237 self
238 }
239
240 pub fn null_handling(mut self, strategy: NullHandling) -> Self {
242 self.config.null_handling = strategy;
243 self
244 }
245
246 pub fn memory_limit(mut self, limit: usize) -> Self {
248 self.config.memory_limit = Some(limit);
249 self
250 }
251
252 pub fn with_metadata(mut self) -> Self {
254 self.config.include_metadata = true;
255 self
256 }
257
258 pub fn compute(
260 &self,
261 data: ArrayView1<F>,
262 ) -> StatsResult<StandardizedResult<DescriptiveStats<F>>> {
263 let start_time = std::time::Instant::now();
264 let mut warnings = Vec::new();
265
266 if data.is_empty() {
268 return Err(StatsError::InvalidArgument(
269 "Cannot compute statistics for empty array".to_string(),
270 ));
271 }
272
273 let (cleaneddata, samplesize) = self.handle_null_values(&data, &mut warnings)?;
275
276 let stats = if self.config.auto_optimize {
278 self.compute_optimized(&cleaneddata, &mut warnings)?
279 } else {
280 self.compute_standard(&cleaneddata, &mut warnings)?
281 };
282
283 let computation_time = start_time.elapsed().as_secs_f64() * 1000.0;
284
285 let metadata = ResultMetadata {
287 samplesize,
288 degrees_of_freedom: Some(samplesize.saturating_sub(self.ddof.unwrap_or(1))),
289 confidence_level: None,
290 method: self.select_method_name(),
291 computation_time_ms: computation_time,
292 memory_usage_bytes: self.estimate_memory_usage(samplesize),
293 optimized: self.config.simd || self.config.parallel,
294 extra: HashMap::new(),
295 };
296
297 Ok(StandardizedResult {
298 value: stats,
299 metadata,
300 warnings,
301 })
302 }
303
304 fn handle_null_values(
306 &self,
307 data: &ArrayView1<F>,
308 warnings: &mut Vec<String>,
309 ) -> StatsResult<(Array1<F>, usize)> {
310 let finitedata: Vec<F> = data.iter().filter(|&&x| x.is_finite()).cloned().collect();
313
314 if finitedata.len() != data.len() {
315 warnings.push(format!(
316 "Removed {} non-finite values",
317 data.len() - finitedata.len()
318 ));
319 }
320
321 let finite_count = finitedata.len();
322 match self.config.null_handling {
323 NullHandling::Exclude => Ok((Array1::from_vec(finitedata), finite_count)),
324 NullHandling::Fail if finite_count != data.len() => Err(StatsError::InvalidArgument(
325 "Null values encountered with Fail strategy".to_string(),
326 )),
327 _ => Ok((Array1::from_vec(finitedata), finite_count)),
328 }
329 }
330
331 fn compute_optimized(
333 &self,
334 data: &Array1<F>,
335 warnings: &mut Vec<String>,
336 ) -> StatsResult<DescriptiveStats<F>> {
337 let n = data.len();
338
339 if self.config.simd && n > 64 {
341 self.compute_simd_optimized(data, warnings)
342 } else if self.config.parallel && n > 10000 {
343 self.compute_parallel_optimized(data, warnings)
344 } else {
345 self.compute_standard(data, warnings)
346 }
347 }
348
349 fn compute_simd_optimized(
351 &self,
352 data: &Array1<F>,
353 _warnings: &mut Vec<String>,
354 ) -> StatsResult<DescriptiveStats<F>> {
355 let mean = crate::descriptive_simd::mean_simd(&data.view())?;
357 let variance =
358 crate::descriptive_simd::variance_simd(&data.view(), self.ddof.unwrap_or(1))?;
359 let std = variance.sqrt();
360
361 let (min, max) = self.compute_min_max(data);
363 let sorteddata = self.getsorteddata(data);
364 let percentiles = self.compute_percentiles(&sorteddata)?;
365
366 let skewness = crate::descriptive::skew(&data.view(), false, None)?;
368 let kurtosis = crate::descriptive::kurtosis(&data.view(), true, false, None)?;
369
370 Ok(DescriptiveStats {
371 count: data.len(),
372 mean,
373 std,
374 min,
375 percentile_25: percentiles[0],
376 median: percentiles[1],
377 percentile_75: percentiles[2],
378 max,
379 variance,
380 skewness,
381 kurtosis,
382 })
383 }
384
385 fn compute_parallel_optimized(
387 &self,
388 data: &Array1<F>,
389 _warnings: &mut Vec<String>,
390 ) -> StatsResult<DescriptiveStats<F>> {
391 let mean = crate::parallel_stats::mean_parallel(&data.view())?;
393 let variance =
394 crate::parallel_stats::variance_parallel(&data.view(), self.ddof.unwrap_or(1))?;
395 let std = variance.sqrt();
396
397 let (min, max) = self.compute_min_max(data);
399 let sorteddata = self.getsorteddata(data);
400 let percentiles = self.compute_percentiles(&sorteddata)?;
401
402 let skewness = crate::descriptive::skew(&data.view(), false, None)?;
404 let kurtosis = crate::descriptive::kurtosis(&data.view(), true, false, None)?;
405
406 Ok(DescriptiveStats {
407 count: data.len(),
408 mean,
409 std,
410 min,
411 percentile_25: percentiles[0],
412 median: percentiles[1],
413 percentile_75: percentiles[2],
414 max,
415 variance,
416 skewness,
417 kurtosis,
418 })
419 }
420
421 fn compute_standard(
423 &self,
424 data: &Array1<F>,
425 _warnings: &mut Vec<String>,
426 ) -> StatsResult<DescriptiveStats<F>> {
427 let mean = crate::descriptive::mean(&data.view())?;
428 let variance = crate::descriptive::var(&data.view(), self.ddof.unwrap_or(1), None)?;
429 let std = variance.sqrt();
430
431 let (min, max) = self.compute_min_max(data);
432 let sorteddata = self.getsorteddata(data);
433 let percentiles = self.compute_percentiles(&sorteddata)?;
434
435 let skewness = crate::descriptive::skew(&data.view(), false, None)?;
436 let kurtosis = crate::descriptive::kurtosis(&data.view(), true, false, None)?;
437
438 Ok(DescriptiveStats {
439 count: data.len(),
440 mean,
441 std,
442 min,
443 percentile_25: percentiles[0],
444 median: percentiles[1],
445 percentile_75: percentiles[2],
446 max,
447 variance,
448 skewness,
449 kurtosis,
450 })
451 }
452
453 fn compute_min_max(&self, data: &Array1<F>) -> (F, F) {
455 let mut min = data[0];
456 let mut max = data[0];
457
458 for &value in data.iter() {
459 if value < min {
460 min = value;
461 }
462 if value > max {
463 max = value;
464 }
465 }
466
467 (min, max)
468 }
469
470 fn getsorteddata(&self, data: &Array1<F>) -> Vec<F> {
472 let mut sorted = data.to_vec();
473 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
474 sorted
475 }
476
477 fn compute_percentiles(&self, sorteddata: &[F]) -> StatsResult<[F; 3]> {
479 let n = sorteddata.len();
480 if n == 0 {
481 return Err(StatsError::InvalidArgument("Empty data".to_string()));
482 }
483
484 let p25_idx = (n as f64 * 0.25) as usize;
485 let p50_idx = (n as f64 * 0.50) as usize;
486 let p75_idx = (n as f64 * 0.75) as usize;
487
488 Ok([
489 sorteddata[p25_idx.min(n - 1)],
490 sorteddata[p50_idx.min(n - 1)],
491 sorteddata[p75_idx.min(n - 1)],
492 ])
493 }
494
495 fn select_method_name(&self) -> String {
497 if self.config.simd && self.config.parallel {
498 "SIMD+Parallel".to_string()
499 } else if self.config.simd {
500 "SIMD".to_string()
501 } else if self.config.parallel {
502 "Parallel".to_string()
503 } else {
504 "Standard".to_string()
505 }
506 }
507
508 fn estimate_memory_usage(&self, samplesize: usize) -> Option<usize> {
510 if self.config.include_metadata {
511 Some(samplesize * std::mem::size_of::<F>() * 2) } else {
513 None
514 }
515 }
516}
517
518impl<F> CorrelationBuilder<F>
519where
520 F: Float
521 + NumCast
522 + Clone
523 + std::fmt::Debug
524 + std::fmt::Display
525 + scirs2_core::simd_ops::SimdUnifiedOps
526 + std::iter::Sum<F>
527 + std::ops::Div<Output = F>
528 + Send
529 + Sync
530 + 'static,
531{
532 pub fn new() -> Self {
534 Self {
535 config: StandardizedConfig::default(),
536 method: CorrelationMethod::Pearson,
537 min_periods: None,
538 controls: None,
539 phantom: PhantomData,
540 }
541 }
542
543 pub fn with_controls(mut self, controls: Array2<F>) -> Self {
553 self.controls = Some(controls);
554 self
555 }
556
557 pub fn method(mut self, method: CorrelationMethod) -> Self {
559 self.method = method;
560 self
561 }
562
563 pub fn min_periods(mut self, periods: usize) -> Self {
565 self.min_periods = Some(periods);
566 self
567 }
568
569 pub fn confidence_level(mut self, level: f64) -> Self {
571 self.config.confidence_level = level;
572 self
573 }
574
575 pub fn parallel(mut self, enable: bool) -> Self {
577 self.config.parallel = enable;
578 self
579 }
580
581 pub fn simd(mut self, enable: bool) -> Self {
583 self.config.simd = enable;
584 self
585 }
586
587 pub fn with_metadata(mut self) -> Self {
589 self.config.include_metadata = true;
590 self
591 }
592
593 pub fn compute<'a>(
595 &self,
596 x: ArrayView1<'a, F>,
597 y: ArrayView1<'a, F>,
598 ) -> StatsResult<StandardizedResult<CorrelationResult<F>>> {
599 let start_time = std::time::Instant::now();
600 let mut warnings = Vec::new();
601
602 if x.len() != y.len() {
604 return Err(StatsError::DimensionMismatch(
605 "Input arrays must have the same length".to_string(),
606 ));
607 }
608
609 if x.is_empty() {
610 return Err(StatsError::InvalidArgument(
611 "Cannot compute correlation for empty arrays".to_string(),
612 ));
613 }
614
615 if let Some(min_periods) = self.min_periods {
617 if x.len() < min_periods {
618 return Err(StatsError::InvalidArgument(format!(
619 "Insufficient data: {} observations, {} required",
620 x.len(),
621 min_periods
622 )));
623 }
624 }
625
626 let correlation = match self.method {
628 CorrelationMethod::Pearson => {
629 if self.config.simd && x.len() > 64 {
630 crate::correlation_simd::pearson_r_simd(&x, &y)?
631 } else {
632 crate::correlation::pearson_r(&x, &y)?
633 }
634 }
635 CorrelationMethod::Spearman => crate::correlation::spearman_r(&x, &y)?,
636 CorrelationMethod::Kendall => crate::correlation::kendall_tau(&x, &y, "b")?,
637 CorrelationMethod::PartialPearson => match &self.controls {
638 Some(z) => crate::correlation::partial_corr(&x, &y, &z.view())?,
639 None => {
640 warnings.push(
641 "PartialPearson requires control variables; \
642 falling back to Pearson (no controls provided)"
643 .to_string(),
644 );
645 crate::correlation::pearson_r(&x, &y)?
646 }
647 },
648 CorrelationMethod::PartialSpearman => {
649 match &self.controls {
650 Some(z) => {
651 let rx = rank_array(&x)?;
654 let ry = rank_array(&y)?;
655 let (n_rows, n_cols) = z.dim();
656 let mut rz_data = vec![F::zero(); n_rows * n_cols];
657 for j in 0..n_cols {
658 let col = z.column(j);
659 let col_owned: Array1<F> = Array1::from_iter(col.iter().copied());
660 let col_ranks = rank_array(&col_owned.view())?;
661 for (i, &v) in col_ranks.iter().enumerate() {
662 rz_data[i * n_cols + j] = v;
663 }
664 }
665 let rz =
666 Array2::from_shape_vec((n_rows, n_cols), rz_data).map_err(|e| {
667 StatsError::ComputationError(format!(
668 "Failed to build ranked controls matrix: {}",
669 e
670 ))
671 })?;
672 crate::correlation::partial_corr(&rx.view(), &ry.view(), &rz.view())?
673 }
674 None => {
675 warnings.push(
676 "PartialSpearman requires control variables; \
677 falling back to Spearman (no controls provided)"
678 .to_string(),
679 );
680 crate::correlation::spearman_r(&x, &y)?
681 }
682 }
683 }
684 };
685
686 let (p_value, confidence_interval) = if self.config.include_metadata {
688 self.compute_statistical_inference(correlation, x.len(), &mut warnings)?
689 } else {
690 (None, None)
691 };
692
693 let computation_time = start_time.elapsed().as_secs_f64() * 1000.0;
694
695 let result = CorrelationResult {
696 correlation,
697 p_value,
698 confidence_interval,
699 method: self.method,
700 };
701
702 let metadata = ResultMetadata {
703 samplesize: x.len(),
704 degrees_of_freedom: Some(x.len().saturating_sub(2)),
705 confidence_level: Some(self.config.confidence_level),
706 method: format!("{:?}", self.method),
707 computation_time_ms: computation_time,
708 memory_usage_bytes: self.estimate_memory_usage(x.len()),
709 optimized: self.config.simd || self.config.parallel,
710 extra: HashMap::new(),
711 };
712
713 Ok(StandardizedResult {
714 value: result,
715 metadata,
716 warnings,
717 })
718 }
719
720 pub fn compute_matrix(
722 &self,
723 data: ArrayView2<F>,
724 ) -> StatsResult<StandardizedResult<Array2<F>>> {
725 let start_time = std::time::Instant::now();
726 let warnings = Vec::new();
727
728 let correlation_matrix = if self.config.auto_optimize {
730 let mut optimizer = crate::memory_optimization_advanced::MemoryOptimizationSuite::new(
732 crate::memory_optimization_advanced::MemoryOptimizationConfig::default(),
733 );
734 optimizer.optimized_correlation_matrix(data)?
735 } else {
736 crate::correlation::corrcoef(&data, "pearson")?
737 };
738
739 let computation_time = start_time.elapsed().as_secs_f64() * 1000.0;
740
741 let metadata = ResultMetadata {
742 samplesize: data.nrows(),
743 degrees_of_freedom: Some(data.nrows().saturating_sub(2)),
744 confidence_level: Some(self.config.confidence_level),
745 method: format!("Matrix {:?}", self.method),
746 computation_time_ms: computation_time,
747 memory_usage_bytes: self.estimate_memory_usage(data.nrows() * data.ncols()),
748 optimized: self.config.simd || self.config.parallel,
749 extra: HashMap::new(),
750 };
751
752 Ok(StandardizedResult {
753 value: correlation_matrix,
754 metadata,
755 warnings,
756 })
757 }
758
759 fn compute_statistical_inference(
761 &self,
762 correlation: F,
763 n: usize,
764 warnings: &mut Vec<String>,
765 ) -> StatsResult<(Option<F>, Option<(F, F)>)> {
766 let z = ((F::one() + correlation) / (F::one() - correlation)).ln()
768 * F::from(0.5).expect("Failed to convert constant to float");
769 let se_z = F::one() / F::from(n - 3).expect("Failed to convert to float").sqrt();
770
771 let _alpha =
773 F::one() - F::from(self.config.confidence_level).expect("Failed to convert to float");
774 let z_critical = F::from(1.96).expect("Failed to convert constant to float"); let z_lower = z - z_critical * se_z;
777 let z_upper = z + z_critical * se_z;
778
779 let r_lower = (F::from(2.0).expect("Failed to convert constant to float") * z_lower).exp();
781 let r_lower = (r_lower - F::one()) / (r_lower + F::one());
782
783 let r_upper = (F::from(2.0).expect("Failed to convert constant to float") * z_upper).exp();
784 let r_upper = (r_upper - F::one()) / (r_upper + F::one());
785
786 let _t_stat = correlation * F::from(n - 2).expect("Failed to convert to float").sqrt()
788 / (F::one() - correlation * correlation).sqrt();
789 let p_value = F::from(2.0).expect("Failed to convert constant to float")
790 * (F::one() - F::from(0.95).expect("Failed to convert constant to float")); Ok((Some(p_value), Some((r_lower, r_upper))))
793 }
794
795 fn estimate_memory_usage(&self, size: usize) -> Option<usize> {
797 if self.config.include_metadata {
798 Some(size * std::mem::size_of::<F>() * 3) } else {
800 None
801 }
802 }
803}
804
805impl<F> StatsAnalyzer<F>
806where
807 F: Float
808 + NumCast
809 + Clone
810 + scirs2_core::simd_ops::SimdUnifiedOps
811 + std::iter::Sum<F>
812 + std::ops::Div<Output = F>
813 + Sync
814 + Send
815 + std::fmt::Display
816 + std::fmt::Debug
817 + 'static,
818{
819 pub fn new() -> Self {
821 Self {
822 config: StandardizedConfig::default(),
823 phantom: PhantomData,
824 }
825 }
826
827 pub fn configure(mut self, config: StandardizedConfig) -> Self {
829 self.config = config;
830 self
831 }
832
833 pub fn describe(
835 &self,
836 data: ArrayView1<F>,
837 ) -> StatsResult<StandardizedResult<DescriptiveStats<F>>> {
838 DescriptiveStatsBuilder::new()
839 .parallel(self.config.parallel)
840 .simd(self.config.simd)
841 .null_handling(self.config.null_handling)
842 .with_metadata()
843 .compute(data)
844 }
845
846 pub fn correlate<'a>(
848 &self,
849 x: ArrayView1<'a, F>,
850 y: ArrayView1<'a, F>,
851 method: CorrelationMethod,
852 ) -> StatsResult<StandardizedResult<CorrelationResult<F>>> {
853 CorrelationBuilder::new()
854 .method(method)
855 .confidence_level(self.config.confidence_level)
856 .parallel(self.config.parallel)
857 .simd(self.config.simd)
858 .with_metadata()
859 .compute(x, y)
860 }
861
862 pub fn get_config(&self) -> &StandardizedConfig {
864 &self.config
865 }
866}
867
868pub type F64StatsAnalyzer = StatsAnalyzer<f64>;
870pub type F32StatsAnalyzer = StatsAnalyzer<f32>;
871
872pub type F64DescriptiveBuilder = DescriptiveStatsBuilder<f64>;
873pub type F32DescriptiveBuilder = DescriptiveStatsBuilder<f32>;
874
875pub type F64CorrelationBuilder = CorrelationBuilder<f64>;
876pub type F32CorrelationBuilder = CorrelationBuilder<f32>;
877
878impl<F> Default for DescriptiveStatsBuilder<F>
879where
880 F: Float
881 + NumCast
882 + Clone
883 + scirs2_core::simd_ops::SimdUnifiedOps
884 + std::iter::Sum<F>
885 + std::ops::Div<Output = F>
886 + Sync
887 + Send
888 + std::fmt::Display
889 + std::fmt::Debug
890 + 'static,
891{
892 fn default() -> Self {
893 Self::new()
894 }
895}
896
897impl<F> Default for CorrelationBuilder<F>
898where
899 F: Float
900 + NumCast
901 + Clone
902 + std::fmt::Debug
903 + std::fmt::Display
904 + scirs2_core::simd_ops::SimdUnifiedOps
905 + std::iter::Sum<F>
906 + std::ops::Div<Output = F>
907 + Send
908 + Sync
909 + 'static,
910{
911 fn default() -> Self {
912 Self::new()
913 }
914}
915
916impl<F> Default for StatsAnalyzer<F>
917where
918 F: Float
919 + NumCast
920 + Clone
921 + scirs2_core::simd_ops::SimdUnifiedOps
922 + std::iter::Sum<F>
923 + std::ops::Div<Output = F>
924 + Sync
925 + Send
926 + std::fmt::Display
927 + std::fmt::Debug
928 + 'static,
929{
930 fn default() -> Self {
931 Self::new()
932 }
933}
934
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// The builder reports count and mean, and is not flagged as optimized
    /// when both SIMD and parallel are disabled.
    #[test]
    fn test_descriptive_stats_builder() {
        let sample = array![1.0, 2.0, 3.0, 4.0, 5.0];

        let builder = DescriptiveStatsBuilder::new()
            .ddof(1)
            .parallel(false)
            .simd(false)
            .with_metadata();
        let outcome = builder.compute(sample.view()).expect("Operation failed");

        assert_eq!(outcome.value.count, 5);
        assert!((outcome.value.mean - 3.0).abs() < 1e-10);
        assert!(!outcome.metadata.optimized);
    }

    /// A perfectly linear pair yields a correlation of one, and inference
    /// results are present when metadata is enabled.
    #[test]
    fn test_correlation_builder() {
        let x = array![1.0, 2.0, 3.0, 4.0, 5.0];
        let y = array![2.0, 4.0, 6.0, 8.0, 10.0];

        let builder = CorrelationBuilder::new()
            .method(CorrelationMethod::Pearson)
            .confidence_level(0.95)
            .with_metadata();
        let outcome = builder
            .compute(x.view(), y.view())
            .expect("Operation failed");

        assert!((outcome.value.correlation - 1.0).abs() < 1e-10);
        assert!(outcome.value.p_value.is_some());
        assert!(outcome.value.confidence_interval.is_some());
    }

    /// The analyzer front end covers both description and correlation.
    #[test]
    fn test_stats_analyzer() {
        let analyzer = StatsAnalyzer::new();

        let sample = array![1.0, 2.0, 3.0, 4.0, 5.0];
        let described = analyzer.describe(sample.view()).expect("Operation failed");
        assert_eq!(described.value.count, 5);

        let ascending = array![1.0, 2.0, 3.0, 4.0, 5.0];
        let descending = array![5.0, 4.0, 3.0, 2.0, 1.0];
        let correlated = analyzer
            .correlate(
                ascending.view(),
                descending.view(),
                CorrelationMethod::Pearson,
            )
            .expect("Operation failed");
        assert!((correlated.value.correlation + 1.0).abs() < 1e-10);
    }

    /// With `Exclude`, a NaN is dropped from the sample and a warning is recorded.
    #[test]
    fn test_null_handling() {
        let sample = array![1.0, 2.0, f64::NAN, 4.0, 5.0];

        let outcome = DescriptiveStatsBuilder::new()
            .null_handling(NullHandling::Exclude)
            .compute(sample.view())
            .expect("Operation failed");

        assert_eq!(outcome.value.count, 4);
        assert!(!outcome.warnings.is_empty());
    }

    /// Struct-update syntax overrides selected fields of the default config.
    #[test]
    fn test_standardized_config() {
        let custom = StandardizedConfig {
            auto_optimize: false,
            parallel: false,
            simd: true,
            confidence_level: 0.99,
            ..Default::default()
        };

        assert!(!custom.auto_optimize);
        assert!(!custom.parallel);
        assert!(custom.simd);
        assert!((custom.confidence_level - 0.99).abs() < 1e-10);
    }

    /// A fully documented signature validates without errors.
    #[test]
    fn test_api_validation() {
        let data_parameter = ParameterSpec {
            name: "data".to_string(),
            param_type: "ArrayView1<f64>".to_string(),
            optional: false,
            default_value: None,
            description: Some("Input data array".to_string()),
            constraints: vec![ParameterConstraint::Finite],
        };
        let return_type = ReturnTypeSpec {
            type_name: "f64".to_string(),
            result_wrapped: true,
            inner_type: Some("f64".to_string()),
            error_type: Some("StatsError".to_string()),
        };
        let documentation = DocumentationSpec {
            has_doc_comment: true,
            has_param_docs: true,
            has_return_docs: true,
            has_examples: true,
            has_error_docs: true,
            scipy_compatibility: Some("Compatible with scipy.stats".to_string()),
        };
        let performance = PerformanceSpec {
            time_complexity: Some("O(n)".to_string()),
            space_complexity: Some("O(1)".to_string()),
            simd_optimized: true,
            parallel_processing: true,
            cache_efficient: true,
        };
        let signature = APISignature {
            function_name: "test_function".to_string(),
            module_path: "scirs2, _stats::test".to_string(),
            parameters: vec![data_parameter],
            return_type,
            error_types: vec!["StatsError".to_string()],
            documentation,
            performance,
        };

        let framework = APIValidationFramework::new();
        let report = framework.validate_api(&signature);

        assert!(matches!(
            report.overall_status,
            ValidationStatus::Passed | ValidationStatus::PassedWithWarnings
        ));
    }
}
1064
1065#[derive(Debug)]
1067pub struct APIValidationFramework {
1068 validation_rules: HashMap<String, Vec<ValidationRule>>,
1070 compatibility_checkers: HashMap<String, CompatibilityChecker>,
1072 performance_benchmarks: HashMap<String, PerformanceBenchmark>,
1074 error_patterns: HashMap<String, ErrorPattern>,
1076}
1077
1078#[derive(Debug, Clone)]
1080pub struct ValidationRule {
1081 pub id: String,
1083 pub description: String,
1085 pub category: ValidationCategory,
1087 pub severity: ValidationSeverity,
1089}
1090
/// Area of an API that a validation rule addresses.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ValidationCategory {
    /// Naming of parameters.
    ParameterNaming,
    /// Return types.
    ReturnTypes,
    /// Error handling.
    ErrorHandling,
    /// Documentation.
    Documentation,
    /// Performance characteristics.
    Performance,
    /// Compatibility with SciPy.
    ScipyCompatibility,
    /// Thread safety.
    ThreadSafety,
    /// Numerical stability.
    NumericalStability,
}
1111
/// Severity of a validation finding; variants are ordered from least to
/// most severe.
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)]
pub enum ValidationSeverity {
    /// Informational only.
    Info,
    /// Warning.
    Warning,
    /// Error.
    Error,
    /// Critical.
    Critical,
}
1124
1125#[derive(Debug, Clone)]
1127pub struct APISignature {
1128 pub function_name: String,
1130 pub module_path: String,
1132 pub parameters: Vec<ParameterSpec>,
1134 pub return_type: ReturnTypeSpec,
1136 pub error_types: Vec<String>,
1138 pub documentation: DocumentationSpec,
1140 pub performance: PerformanceSpec,
1142}
1143
1144#[derive(Debug, Clone)]
1146pub struct ParameterSpec {
1147 pub name: String,
1149 pub param_type: String,
1151 pub optional: bool,
1153 pub default_value: Option<String>,
1155 pub description: Option<String>,
1157 pub constraints: Vec<ParameterConstraint>,
1159}
1160
/// Constraint placed on a parameter value.
#[derive(Clone, Debug)]
pub enum ParameterConstraint {
    /// Value must be positive.
    Positive,
    /// Value must not be negative.
    NonNegative,
    /// Value must be finite.
    Finite,
    /// Value must lie within the given pair of bounds.
    Range(f64, f64),
    /// Value must be one of the listed options.
    OneOf(Vec<String>),
    /// Shape requirement with one entry per dimension.
    /// NOTE(review): `None` presumably means "any size" — confirm.
    Shape(Vec<Option<usize>>),
    /// Free-form constraint described as text.
    Custom(String),
}
1179
/// Description of a function's return type.
#[derive(Clone, Debug)]
pub struct ReturnTypeSpec {
    /// Name of the return type.
    pub type_name: String,
    /// Whether the value is wrapped in a `Result`.
    pub result_wrapped: bool,
    /// Name of the wrapped inner type, if any.
    pub inner_type: Option<String>,
    /// Name of the error type, if any.
    pub error_type: Option<String>,
}
1192
/// Documentation status of a function.
#[derive(Clone, Debug)]
pub struct DocumentationSpec {
    /// A doc comment is present.
    pub has_doc_comment: bool,
    /// Parameters are documented.
    pub has_param_docs: bool,
    /// The return value is documented.
    pub has_return_docs: bool,
    /// Examples are included.
    pub has_examples: bool,
    /// Errors are documented.
    pub has_error_docs: bool,
    /// Note on SciPy compatibility, if any.
    pub scipy_compatibility: Option<String>,
}
1209
/// Performance characteristics declared for a function.
#[derive(Clone, Debug)]
pub struct PerformanceSpec {
    /// Time complexity, as text (e.g. `"O(n)"`).
    pub time_complexity: Option<String>,
    /// Space complexity, as text.
    pub space_complexity: Option<String>,
    /// Function has a SIMD-optimized implementation.
    pub simd_optimized: bool,
    /// Function supports parallel processing.
    pub parallel_processing: bool,
    /// Function is cache efficient.
    pub cache_efficient: bool,
}
1224
1225#[derive(Debug, Clone)]
1227pub struct ValidationResult {
1228 pub passed: bool,
1230 pub messages: Vec<ValidationMessage>,
1232 pub suggested_fixes: Vec<String>,
1234 pub related_rules: Vec<String>,
1236}
1237
1238#[derive(Debug, Clone)]
1240pub struct ValidationMessage {
1241 pub severity: ValidationSeverity,
1243 pub message: String,
1245 pub location: Option<String>,
1247 pub rule_id: String,
1249}
1250
1251#[derive(Debug, Clone)]
1253pub struct CompatibilityChecker {
1254 pub scipy_function: String,
1256 pub parameter_mapping: HashMap<String, String>,
1258 pub return_type_mapping: HashMap<String, String>,
1260 pub known_differences: Vec<CompatibilityDifference>,
1262}
1263
1264#[derive(Debug, Clone)]
1266pub struct CompatibilityDifference {
1267 pub category: DifferenceCategory,
1269 pub description: String,
1271 pub justification: String,
1273 pub workaround: Option<String>,
1275}
1276
/// Reason a behaviour differs from SciPy.
#[derive(Clone, Copy, Debug)]
pub enum DifferenceCategory {
    /// Deliberate improvement.
    Improvement,
    /// Imposed by a Rust language constraint.
    RustConstraint,
    /// Made for performance.
    Performance,
    /// Made for safety.
    Safety,
    /// Not intended.
    Unintentional,
}
1291
1292#[derive(Debug, Clone)]
1294pub struct PerformanceBenchmark {
1295 pub name: String,
1297 pub expected_complexity: ComplexityClass,
1299 pub memory_usage: MemoryUsagePattern,
1301 pub scalability: ScalabilityRequirement,
1303}
1304
/// Asymptotic complexity class.
#[derive(Clone, Copy, Debug)]
pub enum ComplexityClass {
    /// Constant.
    Constant,
    /// Logarithmic.
    Logarithmic,
    /// Linear.
    Linear,
    /// Log-linear.
    LogLinear,
    /// Quadratic.
    Quadratic,
    /// Cubic.
    Cubic,
    /// Exponential.
    Exponential,
}
1316
/// Memory-usage pattern of a computation.
#[derive(Clone, Copy, Debug)]
pub enum MemoryUsagePattern {
    /// Constant.
    Constant,
    /// Linear.
    Linear,
    /// Quadratic.
    Quadratic,
    /// Streaming.
    Streaming,
    /// Out-of-core.
    OutOfCore,
}
1326
/// Scalability targets of a benchmark.
#[derive(Clone, Debug)]
pub struct ScalabilityRequirement {
    /// Maximum data size to support.
    pub maxdatasize: usize,
    /// Required parallel efficiency.
    /// NOTE(review): scale (fraction vs. percent) is not visible here.
    pub parallel_efficiency: f64,
    /// Required SIMD acceleration.
    /// NOTE(review): presumably a speed-up factor — confirm.
    pub simd_acceleration: f64,
}
1337
1338#[derive(Debug, Clone)]
1340pub struct ErrorPattern {
1341 pub category: ErrorCategory,
1343 pub message_template: String,
1345 pub recovery_suggestions: Vec<String>,
1347 pub related_errors: Vec<String>,
1349}
1350
/// Broad classification of errors.
#[derive(Clone, Copy, Debug)]
pub enum ErrorCategory {
    /// Invalid input.
    InvalidInput,
    /// Numerical problem.
    Numerical,
    /// Memory problem.
    Memory,
    /// Convergence failure.
    Convergence,
    /// Dimension mismatch.
    DimensionMismatch,
    /// Functionality not implemented.
    NotImplemented,
    /// Internal error.
    Internal,
}
1369
1370#[derive(Debug)]
1372pub struct ValidationReport {
1373 pub function_name: String,
1375 pub results: HashMap<String, ValidationResult>,
1377 pub overall_status: ValidationStatus,
1379 pub summary: ValidationSummary,
1381}
1382
/// Overall status of a validation report.
#[derive(Clone, Copy, Debug)]
pub enum ValidationStatus {
    /// Validation passed.
    Passed,
    /// Validation passed, with warnings.
    PassedWithWarnings,
    /// Validation failed.
    Failed,
    /// Validation failed critically.
    Critical,
}
1391
/// Counts summarising a validation report.
#[derive(Clone, Debug)]
pub struct ValidationSummary {
    /// Total number of rules.
    pub total_rules: usize,
    /// Number of passed rules.
    pub passed: usize,
    /// Number of warnings.
    pub warnings: usize,
    /// Number of errors.
    pub errors: usize,
    /// Number of critical findings.
    pub critical: usize,
}
1406
1407impl APIValidationFramework {
1408 pub fn new() -> Self {
1410 let mut framework = Self {
1411 validation_rules: HashMap::new(),
1412 compatibility_checkers: HashMap::new(),
1413 performance_benchmarks: HashMap::new(),
1414 error_patterns: HashMap::new(),
1415 };
1416
1417 framework.initialize_default_rules();
1418 framework
1419 }
1420
1421 fn initialize_default_rules(&mut self) {
1423 self.add_validation_rule(ValidationRule {
1425 id: "param_naming_consistency".to_string(),
1426 description: "Parameter names should follow consistent snake_case conventions"
1427 .to_string(),
1428 category: ValidationCategory::ParameterNaming,
1429 severity: ValidationSeverity::Warning,
1430 });
1431
1432 self.add_validation_rule(ValidationRule {
1434 id: "error_handling_consistency".to_string(),
1435 description: "Functions should return Result<T, StatsError> for consistency"
1436 .to_string(),
1437 category: ValidationCategory::ErrorHandling,
1438 severity: ValidationSeverity::Error,
1439 });
1440
1441 self.add_validation_rule(ValidationRule {
1443 id: "documentation_completeness".to_string(),
1444 description: "All public functions should have complete documentation".to_string(),
1445 category: ValidationCategory::Documentation,
1446 severity: ValidationSeverity::Warning,
1447 });
1448
1449 self.add_validation_rule(ValidationRule {
1451 id: "scipy_compatibility".to_string(),
1452 description: "Functions should maintain SciPy compatibility where possible".to_string(),
1453 category: ValidationCategory::ScipyCompatibility,
1454 severity: ValidationSeverity::Info,
1455 });
1456
1457 self.add_validation_rule(ValidationRule {
1459 id: "performance_characteristics".to_string(),
1460 description: "Functions should document performance characteristics".to_string(),
1461 category: ValidationCategory::Performance,
1462 severity: ValidationSeverity::Info,
1463 });
1464 }
1465
1466 pub fn add_validation_rule(&mut self, rule: ValidationRule) {
1468 let category_key = format!("{:?}", rule.category);
1469 self.validation_rules
1470 .entry(category_key)
1471 .or_default()
1472 .push(rule);
1473 }
1474
1475 pub fn validate_api(&self, signature: &APISignature) -> ValidationReport {
1477 let mut report = ValidationReport::new(signature.function_name.clone());
1478
1479 for rules in self.validation_rules.values() {
1480 for rule in rules {
1481 let result = self.apply_validation_rule(rule, signature);
1482 report.add_result(rule.id.clone(), result);
1483 }
1484 }
1485
1486 report
1487 }
1488
1489 fn apply_validation_rule(
1491 &self,
1492 rule: &ValidationRule,
1493 signature: &APISignature,
1494 ) -> ValidationResult {
1495 match rule.category {
1496 ValidationCategory::ParameterNaming => self.validate_parameter_naming(signature),
1497 ValidationCategory::ErrorHandling => self.validate_error_handling(signature),
1498 ValidationCategory::Documentation => self.validate_documentation(signature),
1499 ValidationCategory::ScipyCompatibility => self.validate_scipy_compatibility(signature),
1500 ValidationCategory::Performance => self.validate_performance(signature),
1501 _ => ValidationResult {
1502 passed: true,
1503 messages: vec![],
1504 suggested_fixes: vec![],
1505 related_rules: vec![],
1506 },
1507 }
1508 }
1509
1510 fn validate_parameter_naming(&self, signature: &APISignature) -> ValidationResult {
1512 let mut messages = Vec::new();
1513 let mut suggested_fixes = Vec::new();
1514
1515 for param in &signature.parameters {
1516 if param.name.contains(char::is_uppercase) || param.name.contains('-') {
1518 messages.push(ValidationMessage {
1519 severity: ValidationSeverity::Warning,
1520 message: format!("Parameter '{}' should use snake_case naming", param.name),
1521 location: Some(format!(
1522 "{}::{}",
1523 signature.module_path, signature.function_name
1524 )),
1525 rule_id: "param_naming_consistency".to_string(),
1526 });
1527 suggested_fixes.push(format!("Rename parameter '{}' to snake_case", param.name));
1528 }
1529 }
1530
1531 ValidationResult {
1532 passed: messages.is_empty(),
1533 messages,
1534 suggested_fixes,
1535 related_rules: vec!["return_type_consistency".to_string()],
1536 }
1537 }
1538
1539 fn validate_error_handling(&self, signature: &APISignature) -> ValidationResult {
1541 let mut messages = Vec::new();
1542 let mut suggested_fixes = Vec::new();
1543
1544 if !signature.return_type.result_wrapped {
1545 messages.push(ValidationMessage {
1546 severity: ValidationSeverity::Error,
1547 message: "Function should return Result<T, StatsError> for consistency".to_string(),
1548 location: Some(format!(
1549 "{}::{}",
1550 signature.module_path, signature.function_name
1551 )),
1552 rule_id: "error_handling_consistency".to_string(),
1553 });
1554 suggested_fixes.push("Wrap return type in Result<T, StatsError>".to_string());
1555 }
1556
1557 if let Some(error_type) = &signature.return_type.error_type {
1558 if error_type != "StatsError" {
1559 messages.push(ValidationMessage {
1560 severity: ValidationSeverity::Warning,
1561 message: format!("Non-standard error type '{}' used", error_type),
1562 location: Some(format!(
1563 "{}::{}",
1564 signature.module_path, signature.function_name
1565 )),
1566 rule_id: "error_handling_consistency".to_string(),
1567 });
1568 suggested_fixes.push("Use StatsError for consistency".to_string());
1569 }
1570 }
1571
1572 ValidationResult {
1573 passed: messages.is_empty(),
1574 messages,
1575 suggested_fixes,
1576 related_rules: vec!["documentation_completeness".to_string()],
1577 }
1578 }
1579
1580 fn validate_documentation(&self, signature: &APISignature) -> ValidationResult {
1582 let mut messages = Vec::new();
1583 let mut suggested_fixes = Vec::new();
1584
1585 if !signature.documentation.has_doc_comment {
1586 messages.push(ValidationMessage {
1587 severity: ValidationSeverity::Warning,
1588 message: "Function lacks documentation comment".to_string(),
1589 location: Some(format!(
1590 "{}::{}",
1591 signature.module_path, signature.function_name
1592 )),
1593 rule_id: "documentation_completeness".to_string(),
1594 });
1595 suggested_fixes.push("Add comprehensive doc comment".to_string());
1596 }
1597
1598 if !signature.documentation.has_examples {
1599 messages.push(ValidationMessage {
1600 severity: ValidationSeverity::Info,
1601 message: "Function lacks usage examples".to_string(),
1602 location: Some(format!(
1603 "{}::{}",
1604 signature.module_path, signature.function_name
1605 )),
1606 rule_id: "documentation_completeness".to_string(),
1607 });
1608 suggested_fixes.push("Add usage examples in # Examples section".to_string());
1609 }
1610
1611 ValidationResult {
1612 passed: messages
1613 .iter()
1614 .all(|m| matches!(m.severity, ValidationSeverity::Info)),
1615 messages,
1616 suggested_fixes,
1617 related_rules: vec!["scipy_compatibility".to_string()],
1618 }
1619 }
1620
1621 fn validate_scipy_compatibility(&self, signature: &APISignature) -> ValidationResult {
1623 let mut messages = Vec::new();
1624 let mut suggested_fixes = Vec::new();
1625
1626 let scipy_standard_params = [
1628 "axis",
1629 "ddof",
1630 "keepdims",
1631 "out",
1632 "dtype",
1633 "method",
1634 "alternative",
1635 ];
1636 let has_scipy_params = signature
1637 .parameters
1638 .iter()
1639 .any(|p| scipy_standard_params.contains(&p.name.as_str()));
1640
1641 if has_scipy_params && signature.documentation.scipy_compatibility.is_none() {
1642 messages.push(ValidationMessage {
1643 severity: ValidationSeverity::Info,
1644 message: "Consider documenting SciPy compatibility status".to_string(),
1645 location: Some(format!(
1646 "{}::{}",
1647 signature.module_path, signature.function_name
1648 )),
1649 rule_id: "scipy_compatibility".to_string(),
1650 });
1651 suggested_fixes.push("Add SciPy compatibility note in documentation".to_string());
1652 }
1653
1654 ValidationResult {
1655 passed: true, messages,
1657 suggested_fixes,
1658 related_rules: vec!["documentation_completeness".to_string()],
1659 }
1660 }
1661
1662 fn validate_performance(&self, signature: &APISignature) -> ValidationResult {
1664 let mut messages = Vec::new();
1665 let mut suggested_fixes = Vec::new();
1666
1667 if signature.performance.time_complexity.is_none() {
1668 messages.push(ValidationMessage {
1669 severity: ValidationSeverity::Info,
1670 message: "Consider documenting time complexity".to_string(),
1671 location: Some(format!(
1672 "{}::{}",
1673 signature.module_path, signature.function_name
1674 )),
1675 rule_id: "performance_characteristics".to_string(),
1676 });
1677 suggested_fixes.push("Add time complexity documentation".to_string());
1678 }
1679
1680 ValidationResult {
1681 passed: true, messages,
1683 suggested_fixes,
1684 related_rules: vec![],
1685 }
1686 }
1687}
1688
1689impl ValidationReport {
1690 pub fn new(_functionname: String) -> Self {
1692 Self {
1693 function_name: _functionname,
1694 results: HashMap::new(),
1695 overall_status: ValidationStatus::Passed,
1696 summary: ValidationSummary {
1697 total_rules: 0,
1698 passed: 0,
1699 warnings: 0,
1700 errors: 0,
1701 critical: 0,
1702 },
1703 }
1704 }
1705
1706 pub fn add_result(&mut self, ruleid: String, result: ValidationResult) {
1708 self.summary.total_rules += 1;
1709
1710 if result.passed {
1711 self.summary.passed += 1;
1712 } else {
1713 let max_severity = result
1714 .messages
1715 .iter()
1716 .map(|m| m.severity)
1717 .max()
1718 .unwrap_or(ValidationSeverity::Info);
1719
1720 match max_severity {
1721 ValidationSeverity::Info => {}
1722 ValidationSeverity::Warning => {
1723 self.summary.warnings += 1;
1724 if matches!(self.overall_status, ValidationStatus::Passed) {
1725 self.overall_status = ValidationStatus::PassedWithWarnings;
1726 }
1727 }
1728 ValidationSeverity::Error => {
1729 self.summary.errors += 1;
1730 if !matches!(self.overall_status, ValidationStatus::Critical) {
1731 self.overall_status = ValidationStatus::Failed;
1732 }
1733 }
1734 ValidationSeverity::Critical => {
1735 self.summary.critical += 1;
1736 self.overall_status = ValidationStatus::Critical;
1737 }
1738 }
1739 }
1740
1741 self.results.insert(ruleid, result);
1742 }
1743
1744 pub fn generate_report(&self) -> String {
1746 let mut report = String::new();
1747 report.push_str(&format!(
1748 "API Validation Report for {}\n",
1749 self.function_name
1750 ));
1751 report.push_str(&format!("Status: {:?}\n", self.overall_status));
1752 report.push_str(&format!(
1753 "Summary: {} passed, {} warnings, {} errors, {} critical\n\n",
1754 self.summary.passed, self.summary.warnings, self.summary.errors, self.summary.critical
1755 ));
1756
1757 for (rule_id, result) in &self.results {
1758 if !result.passed {
1759 report.push_str(&format!("Rule: {}\n", rule_id));
1760 for message in &result.messages {
1761 report.push_str(&format!(" {:?}: {}\n", message.severity, message.message));
1762 }
1763 if !result.suggested_fixes.is_empty() {
1764 report.push_str(" Suggestions:\n");
1765 for fix in &result.suggested_fixes {
1766 report.push_str(&format!(" - {}\n", fix));
1767 }
1768 }
1769 report.push('\n');
1770 }
1771 }
1772
1773 report
1774 }
1775}
1776
1777impl Default for APIValidationFramework {
1778 fn default() -> Self {
1779 Self::new()
1780 }
1781}
1782
1783fn rank_array<F, D>(
1790 data: &scirs2_core::ndarray::ArrayBase<D, scirs2_core::ndarray::Ix1>,
1791) -> StatsResult<Array1<F>>
1792where
1793 F: Float + NumCast,
1794 D: scirs2_core::ndarray::Data<Elem = F>,
1795{
1796 let n = data.len();
1797 if n == 0 {
1798 return Ok(Array1::zeros(0));
1799 }
1800
1801 let mut indexed: Vec<(F, usize)> = data
1803 .iter()
1804 .copied()
1805 .enumerate()
1806 .map(|(i, v)| (v, i))
1807 .collect();
1808 indexed.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
1809
1810 let mut ranks = vec![F::zero(); n];
1811 let mut i = 0;
1812 while i < n {
1813 let current_val = indexed[i].0;
1814 let mut j = i;
1815 while j + 1 < n && indexed[j + 1].0 == current_val {
1817 j += 1;
1818 }
1819 let avg_rank = F::from((i + j) as f64 / 2.0 + 1.0).ok_or_else(|| {
1821 StatsError::ComputationError("rank_array: numeric cast failed".to_string())
1822 })?;
1823 for item in indexed.iter().take(j + 1).skip(i) {
1824 ranks[item.1] = avg_rank;
1825 }
1826 i = j + 1;
1827 }
1828
1829 Ok(Array1::from(ranks))
1830}
1831
#[cfg(test)]
mod partial_corr_tests {
    use super::*;
    use scirs2_core::ndarray::{array, Array2};

    /// Partial Pearson with one control column must give a finite, in-range
    /// coefficient and no warnings.
    #[test]
    fn test_partial_pearson_via_builder_with_controls() {
        // Control variable: a plain ramp 1..=10.
        let z = array![1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        // x and y each follow the ramp with small perturbations.
        let x = array![1.3f64, 2.1, 3.4, 3.8, 5.2, 5.7, 7.3, 7.9, 9.4, 9.9];
        let y = array![1.1f64, 2.4, 2.8, 4.2, 4.9, 6.1, 6.7, 8.3, 8.8, 10.1];
        let controls = Array2::from_shape_vec((10, 1), z.to_vec()).expect("shape is valid");

        let outcome = CorrelationBuilder::<f64>::new()
            .method(CorrelationMethod::PartialPearson)
            .with_controls(controls)
            .compute(x.view(), y.view())
            .expect("partial pearson should succeed");

        let r = outcome.value.correlation;
        // Allow a tiny numerical overshoot beyond the theoretical [-1, 1] range.
        let lower = -1.0 - 1e-10;
        let upper = 1.0 + 1e-10;

        assert!(r.is_finite(), "partial Pearson must be finite, got {}", r);
        assert!(
            lower <= r && r <= upper,
            "partial Pearson must be in [-1,1], got {}",
            r
        );
        assert!(
            outcome.warnings.is_empty(),
            "unexpected warnings: {:?}",
            outcome.warnings
        );
    }

    /// Without controls, partial Pearson still succeeds, returns 1.0 for
    /// perfectly linear data, and emits a warning.
    #[test]
    fn test_partial_pearson_without_controls_falls_back() {
        // y = 2x exactly.
        let x = array![1.0f64, 2.0, 3.0, 4.0, 5.0];
        let y = array![2.0f64, 4.0, 6.0, 8.0, 10.0];

        let outcome = CorrelationBuilder::<f64>::new()
            .method(CorrelationMethod::PartialPearson)
            .compute(x.view(), y.view())
            .expect("should succeed with fallback");

        let deviation = (outcome.value.correlation - 1.0_f64).abs();
        assert!(
            deviation < 1e-10,
            "fallback Pearson correlation should be 1.0, got {}",
            outcome.value.correlation
        );
        assert!(
            !outcome.warnings.is_empty(),
            "should emit a warning about missing controls"
        );
    }

    /// Partial Spearman with one control column must give a finite, in-range
    /// coefficient and no warnings.
    #[test]
    fn test_partial_spearman_via_builder_with_controls() {
        // Control variable: a plain ramp 1..=10.
        let z = array![1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        // x and y are both permutations of 1..=10.
        let x = array![5.0f64, 1.0, 8.0, 3.0, 7.0, 2.0, 9.0, 4.0, 6.0, 10.0];
        let y = array![3.0f64, 7.0, 1.0, 9.0, 2.0, 8.0, 4.0, 6.0, 10.0, 5.0];
        let controls = Array2::from_shape_vec((10, 1), z.to_vec()).expect("shape is valid");

        let outcome = CorrelationBuilder::<f64>::new()
            .method(CorrelationMethod::PartialSpearman)
            .with_controls(controls)
            .compute(x.view(), y.view())
            .expect("partial spearman should succeed");

        let r = outcome.value.correlation;
        // Allow a tiny numerical overshoot beyond the theoretical [-1, 1] range.
        let lower = -1.0 - 1e-10;
        let upper = 1.0 + 1e-10;

        assert!(
            r.is_finite(),
            "partial Spearman should produce a finite value, got {}",
            r
        );
        assert!(
            lower <= r && r <= upper,
            "partial Spearman must be in [-1,1], got {}",
            r
        );
        assert!(
            outcome.warnings.is_empty(),
            "unexpected warnings: {:?}",
            outcome.warnings
        );
    }

    /// Without controls, partial Spearman still succeeds, returns 1.0 for
    /// strictly increasing data, and emits a warning.
    #[test]
    fn test_partial_spearman_without_controls_falls_back() {
        // y = x^2: non-linear but strictly increasing.
        let x = array![1.0f64, 2.0, 3.0, 4.0, 5.0];
        let y = array![1.0f64, 4.0, 9.0, 16.0, 25.0];

        let outcome = CorrelationBuilder::<f64>::new()
            .method(CorrelationMethod::PartialSpearman)
            .compute(x.view(), y.view())
            .expect("should succeed with fallback");

        let deviation = (outcome.value.correlation - 1.0_f64).abs();
        assert!(
            deviation < 1e-10,
            "fallback Spearman should be 1.0 for monotone data, got {}",
            outcome.value.correlation
        );
        assert!(
            !outcome.warnings.is_empty(),
            "should emit a warning about missing controls"
        );
    }

    /// Tied values share the average of the ranks they span.
    #[test]
    fn test_rank_array_basic() {
        let data = array![3.0f64, 1.0, 4.0, 1.0, 5.0];
        let ranks = rank_array(&data.view()).expect("should succeed");

        // The two 1.0 entries occupy ranks 1 and 2, so both get 1.5.
        let expected = [3.0f64, 1.5, 4.0, 1.5, 5.0];
        for (i, (&r, &e)) in ranks.iter().zip(expected.iter()).enumerate() {
            let diff = (r - e).abs();
            assert!(diff < 1e-10, "rank[{}] = {} expected {}", i, r, e);
        }
    }
}