#![allow(dead_code)]
use crate::error::StatsResult;
use crate::tests::ttest::Alternative;
use scirs2_core::ndarray::{ArrayBase, Data, Ix1};
use scirs2_core::numeric::Float;
/// Outcome of a correlation computation: a coefficient plus an optional p-value.
///
/// The constructors in this file leave `p_value` as `None` (`new`) or set it
/// to `Some` (`with_p_value`).
#[derive(Debug, Clone, Copy)]
pub struct CorrelationResult<F> {
/// The correlation coefficient.
pub coefficient: F,
/// The p-value attached to the coefficient, when one was supplied.
pub p_value: Option<F>,
}
impl<F: Float + std::fmt::Display> CorrelationResult<F> {
    /// Builds a result that carries only a coefficient; `p_value` is `None`.
    pub fn new(coefficient: F) -> Self {
        let p_value = None;
        Self {
            coefficient,
            p_value,
        }
    }

    /// Builds a result that carries both a coefficient and its p-value.
    pub fn with_p_value(_coefficient: F, pvalue: F) -> Self {
        // Start from the coefficient-only result and attach the p-value.
        Self {
            p_value: Some(pvalue),
            ..Self::new(_coefficient)
        }
    }
}
/// Selects which correlation coefficient a routine should compute.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CorrelationMethod {
/// Pearson correlation.
Pearson,
/// Spearman correlation.
Spearman,
/// Kendall's tau correlation.
KendallTau,
}
impl CorrelationMethod {
pub fn from_str(s: &str) -> StatsResult<Self> {
match s.to_lowercase().as_str() {
"pearson" => Ok(CorrelationMethod::Pearson),
"spearman" => Ok(CorrelationMethod::Spearman),
"kendall" | "kendall_tau" | "kendalltau" => Ok(CorrelationMethod::KendallTau),
_ => Err(crate::error::StatsError::InvalidArgument(format!(
"Invalid correlation method: '{}'",
s
))),
}
}
}
/// Hint about which execution strategy a computation should prefer.
///
/// This file only stores the hint inside `StatsConfig`; whether and how it is
/// honored is decided by code outside this file.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OptimizationHint {
/// Value used by `StatsConfig::default()`.
/// NOTE(review): presumably lets the implementation choose — confirm.
Auto,
/// NOTE(review): presumably requests a plain scalar code path — confirm.
Scalar,
/// NOTE(review): presumably requests a SIMD code path — confirm.
Simd,
/// NOTE(review): presumably requests a multi-threaded code path — confirm.
Parallel,
}
/// Options bundle accepted (as `Option<StatsConfig>`) by
/// `CorrelationExt::correlation` and stored by `StatsBuilder`.
#[derive(Debug, Clone)]
pub struct StatsConfig {
/// Execution-strategy hint; `OptimizationHint::Auto` by default.
pub optimization: OptimizationHint,
/// Flag requesting p-value computation; `false` by default.
pub compute_p_value: bool,
/// Alternative hypothesis selector; `Alternative::TwoSided` by default.
pub alternative: Alternative,
}
impl Default for StatsConfig {
fn default() -> Self {
Self {
optimization: OptimizationHint::Auto,
compute_p_value: false,
alternative: Alternative::TwoSided,
}
}
}
impl StatsConfig {
    /// Returns the configuration with `compute_p_value` switched on.
    pub fn with_p_value(self) -> Self {
        Self {
            compute_p_value: true,
            ..self
        }
    }

    /// Returns the configuration with `alternative` replaced.
    pub fn with_alternative(self, alternative: Alternative) -> Self {
        Self {
            alternative,
            ..self
        }
    }

    /// Returns the configuration with the `optimization` hint replaced.
    pub fn with_optimization(self, optimization: OptimizationHint) -> Self {
        Self {
            optimization,
            ..self
        }
    }
}
/// Extension trait exposing correlation computations against a
/// one-dimensional array.
///
/// Implementors supply [`correlation`](CorrelationExt::correlation); the
/// per-method shortcuts are provided on top of it.
pub trait CorrelationExt<F, D>
where
    F: Float + std::fmt::Display + std::iter::Sum + Send + Sync,
    D: Data<Elem = F>,
{
    /// Computes the correlation between `self` and `other` using `method`.
    ///
    /// `config` is optional; the shortcut methods of this trait pass `None`.
    fn correlation(
        &self,
        other: &ArrayBase<D, Ix1>,
        method: CorrelationMethod,
        config: Option<StatsConfig>,
    ) -> StatsResult<CorrelationResult<F>>;

    /// Shortcut for [`CorrelationMethod::Pearson`] with no configuration;
    /// returns only the coefficient.
    fn pearson(&self, other: &ArrayBase<D, Ix1>) -> StatsResult<F> {
        let result = self.correlation(other, CorrelationMethod::Pearson, None)?;
        Ok(result.coefficient)
    }

    /// Shortcut for [`CorrelationMethod::Spearman`] with no configuration;
    /// returns only the coefficient.
    fn spearman(&self, other: &ArrayBase<D, Ix1>) -> StatsResult<F> {
        let result = self.correlation(other, CorrelationMethod::Spearman, None)?;
        Ok(result.coefficient)
    }

    /// Shortcut for [`CorrelationMethod::KendallTau`] with no configuration;
    /// returns only the coefficient.
    fn kendall(&self, other: &ArrayBase<D, Ix1>) -> StatsResult<F> {
        let result = self.correlation(other, CorrelationMethod::KendallTau, None)?;
        Ok(result.coefficient)
    }
}
/// Builder that gathers input data together with a [`StatsConfig`].
///
/// Derives `Debug` and `Clone` so that, like the other public types in this
/// file, a builder can be logged and duplicated.
#[derive(Debug, Clone)]
pub struct StatsBuilder<F> {
    /// Input samples; `None` until `data` or `data_unchecked` has been called.
    data: Option<Vec<F>>,
    /// Options accumulated through the builder's setter methods.
    config: StatsConfig,
}
impl<F: Float + std::fmt::Display + std::iter::Sum + Send + Sync> Default for StatsBuilder<F> {
fn default() -> Self {
Self::new()
}
}
impl<F: Float + std::fmt::Display + std::iter::Sum + Send + Sync> StatsBuilder<F> {
pub fn new() -> Self {
Self {
data: None,
config: StatsConfig::default(),
}
}
pub fn data(mut self, data: Vec<F>) -> StatsResult<Self> {
if data.is_empty() {
return Err(crate::error::StatsError::invalid_argument(
"Data cannot be empty",
));
}
self.data = Some(data);
Ok(self)
}
pub fn data_unchecked(mut self, data: Vec<F>) -> Self {
self.data = Some(data);
self
}
pub fn with_p_value(mut self) -> Self {
self.config.compute_p_value = true;
self
}
pub fn alternative(mut self, alt: Alternative) -> Self {
self.config.alternative = alt;
self
}
pub fn optimization(mut self, opt: OptimizationHint) -> Self {
self.config.optimization = opt;
self
}
pub fn validate(&self) -> StatsResult<()> {
if self.data.is_none() {
return Err(crate::error::StatsError::invalid_argument(
"No data provided to builder",
));
}
if let Some(ref data) = self.data {
if data.is_empty() {
return Err(crate::error::StatsError::invalid_argument(
"Data cannot be empty",
));
}
}
Ok(())
}
pub fn getdata(&self) -> Option<&Vec<F>> {
self.data.as_ref()
}
pub fn get_config(&self) -> &StatsConfig {
&self.config
}
}
/// Outcome of a statistical test: a statistic and p-value plus optional extras.
///
/// `TestResult::new` leaves every optional field as `None`; the `with_*`
/// methods fill them in.
#[derive(Debug, Clone)]
pub struct TestResult<F> {
/// The test statistic.
pub statistic: F,
/// The p-value of the test.
pub p_value: F,
/// Optional `df` value (presumably degrees of freedom — confirm with callers).
pub df: Option<F>,
/// Optional effect size.
pub effectsize: Option<F>,
/// Optional confidence interval, stored as `(lower, upper)`.
pub confidence_interval: Option<(F, F)>,
}
impl<F: Float + std::fmt::Display> TestResult<F> {
    /// Builds a result from a statistic and p-value; all optional fields
    /// start as `None`.
    pub fn new(_statistic: F, pvalue: F) -> Self {
        Self {
            confidence_interval: None,
            effectsize: None,
            df: None,
            p_value: pvalue,
            statistic: _statistic,
        }
    }

    /// Returns the result with `df` set.
    pub fn with_df(self, df: F) -> Self {
        Self {
            df: Some(df),
            ..self
        }
    }

    /// Returns the result with `effectsize` set.
    pub fn with_effectsize(self, effectsize: F) -> Self {
        Self {
            effectsize: Some(effectsize),
            ..self
        }
    }

    /// Returns the result with the confidence interval set to `(lower, upper)`.
    pub fn with_confidence_interval(self, lower: F, upper: F) -> Self {
        let bounds = (lower, upper);
        Self {
            confidence_interval: Some(bounds),
            ..self
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Known lowercase names parse to their variants; an unknown name is rejected.
    #[test]
    fn test_correlation_method_from_str() {
        let cases = [
            ("pearson", CorrelationMethod::Pearson),
            ("spearman", CorrelationMethod::Spearman),
            ("kendall", CorrelationMethod::KendallTau),
        ];
        for &(name, expected) in cases.iter() {
            let parsed = CorrelationMethod::from_str(name).expect("Operation failed");
            assert_eq!(parsed, expected);
        }

        let rejected = CorrelationMethod::from_str("invalid");
        assert!(rejected.is_err());
    }

    /// Chained setters change only the fields they target.
    #[test]
    fn test_stats_config_builder() {
        let StatsConfig {
            optimization,
            compute_p_value,
            alternative,
        } = StatsConfig::default()
            .with_p_value()
            .with_alternative(Alternative::Greater);

        assert!(compute_p_value);
        assert_eq!(alternative, Alternative::Greater);
        assert_eq!(optimization, OptimizationHint::Auto);
    }
}