#![allow(dead_code)]
use crate::error::{StatsError, StatsResult};
use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use scirs2_core::numeric::{Float, NumCast};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::marker::PhantomData;
/// Shared configuration consumed by every standardized builder/analyzer
/// in this module.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StandardizedConfig {
    /// Automatically choose SIMD / parallel / standard paths by data size.
    pub auto_optimize: bool,
    /// Allow parallel execution paths.
    pub parallel: bool,
    /// Allow SIMD execution paths.
    pub simd: bool,
    /// Optional memory limit in bytes. Stored but not consulted by the code
    /// visible in this file — TODO confirm where it is enforced.
    pub memory_limit: Option<usize>,
    /// Confidence level for interval estimates (e.g. 0.95).
    pub confidence_level: f64,
    /// Strategy for non-finite (NaN/inf) observations.
    pub null_handling: NullHandling,
    /// Intended output precision in decimal digits; not read in this file —
    /// TODO confirm against formatting code elsewhere.
    pub output_precision: usize,
    /// Whether to compute optional metadata (inference, memory estimates).
    pub include_metadata: bool,
}
impl Default for StandardizedConfig {
fn default() -> Self {
Self {
auto_optimize: true,
parallel: true,
simd: true,
memory_limit: None,
confidence_level: 0.95,
null_handling: NullHandling::Exclude,
output_precision: 6,
include_metadata: false,
}
}
}
/// Strategy for handling non-finite (NaN/inf) observations.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum NullHandling {
    /// Drop non-finite values before computing.
    Exclude,
    /// Keep the data unchanged, letting non-finite values flow into results.
    Propagate,
    /// Substitute non-finite values with the given constant.
    Replace(f64),
    /// Return an error when any non-finite value is present.
    Fail,
}
/// Uniform wrapper returned by all standardized computations: the value
/// plus execution metadata and any non-fatal warnings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StandardizedResult<T> {
    /// The computed statistic(s).
    pub value: T,
    /// Execution details (sample size, timing, chosen method, ...).
    pub metadata: ResultMetadata,
    /// Non-fatal issues encountered (e.g. removed non-finite values).
    pub warnings: Vec<String>,
}
/// Execution metadata attached to every [`StandardizedResult`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResultMetadata {
    /// Number of observations actually used.
    pub samplesize: usize,
    /// Degrees of freedom, when meaningful for the statistic.
    pub degrees_of_freedom: Option<usize>,
    /// Confidence level used for interval estimates, if any.
    pub confidence_level: Option<f64>,
    /// Human-readable name of the computation path (e.g. "SIMD").
    pub method: String,
    /// Wall-clock computation time in milliseconds.
    pub computation_time_ms: f64,
    /// Rough memory-usage estimate in bytes, when metadata is enabled.
    pub memory_usage_bytes: Option<usize>,
    /// True when a SIMD and/or parallel path was enabled in the config.
    pub optimized: bool,
    /// Free-form extra key/value metadata.
    pub extra: HashMap<String, String>,
}
/// Fluent builder for descriptive statistics; run it with `compute`.
pub struct DescriptiveStatsBuilder<F> {
    config: StandardizedConfig,
    // Delta degrees of freedom for variance/std; defaults to 1 when unset.
    ddof: Option<usize>,
    // Requested axis; stored but not read by `compute` (1-D input only).
    axis: Option<usize>,
    // Observation weights; stored but not read by `compute` in this file.
    weights: Option<Array1<F>>,
    phantom: PhantomData<F>,
}
/// Fluent builder for pairwise and matrix correlation computations.
pub struct CorrelationBuilder<F> {
    config: StandardizedConfig,
    method: CorrelationMethod,
    // Minimum number of observations required before computing.
    min_periods: Option<usize>,
    phantom: PhantomData<F>,
}
/// Fluent builder for hypothesis tests.
/// NOTE(review): declared here but no impl is visible in this file.
pub struct StatisticalTestBuilder<F> {
    config: StandardizedConfig,
    alternative: Alternative,
    equal_var: bool,
    phantom: PhantomData<F>,
}
/// Supported correlation estimators.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum CorrelationMethod {
    Pearson,
    Spearman,
    Kendall,
    // Partial-correlation variants. `CorrelationBuilder::compute` currently
    // falls back to Pearson for these and emits a warning.
    PartialPearson,
    PartialSpearman,
}
/// Alternative hypothesis direction for statistical tests.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum Alternative {
    TwoSided,
    Less,
    Greater,
}
/// High-level facade that runs the builders with one shared configuration.
pub struct StatsAnalyzer<F> {
    config: StandardizedConfig,
    phantom: PhantomData<F>,
}
/// Full set of descriptive statistics for a 1-D sample.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DescriptiveStats<F> {
    /// Number of observations used.
    pub count: usize,
    pub mean: F,
    /// Standard deviation (square root of `variance`).
    pub std: F,
    pub min: F,
    /// 25th percentile (lower quartile).
    pub percentile_25: F,
    pub median: F,
    /// 75th percentile (upper quartile).
    pub percentile_75: F,
    pub max: F,
    pub variance: F,
    pub skewness: F,
    pub kurtosis: F,
}
/// Correlation estimate with optional inference results.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CorrelationResult<F> {
    pub correlation: F,
    /// Two-sided p-value, populated only when metadata was requested.
    pub p_value: Option<F>,
    /// (lower, upper) confidence bounds, populated with metadata.
    pub confidence_interval: Option<(F, F)>,
    pub method: CorrelationMethod,
}
/// Generic hypothesis-test outcome.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TestResult<F> {
    pub statistic: F,
    pub p_value: F,
    pub confidence_interval: Option<(F, F)>,
    pub effectsize: Option<F>,
    pub power: Option<F>,
}
impl<F> DescriptiveStatsBuilder<F>
where
F: Float
+ NumCast
+ Clone
+ scirs2_core::simd_ops::SimdUnifiedOps
+ std::iter::Sum<F>
+ std::ops::Div<Output = F>
+ Sync
+ Send
+ std::fmt::Display
+ std::fmt::Debug
+ 'static,
{
pub fn new() -> Self {
Self {
config: StandardizedConfig::default(),
ddof: None,
axis: None,
weights: None,
phantom: PhantomData,
}
}
pub fn ddof(mut self, ddof: usize) -> Self {
self.ddof = Some(ddof);
self
}
pub fn axis(mut self, axis: usize) -> Self {
self.axis = Some(axis);
self
}
pub fn weights(mut self, weights: Array1<F>) -> Self {
self.weights = Some(weights);
self
}
pub fn parallel(mut self, enable: bool) -> Self {
self.config.parallel = enable;
self
}
pub fn simd(mut self, enable: bool) -> Self {
self.config.simd = enable;
self
}
pub fn null_handling(mut self, strategy: NullHandling) -> Self {
self.config.null_handling = strategy;
self
}
pub fn memory_limit(mut self, limit: usize) -> Self {
self.config.memory_limit = Some(limit);
self
}
pub fn with_metadata(mut self) -> Self {
self.config.include_metadata = true;
self
}
pub fn compute(
&self,
data: ArrayView1<F>,
) -> StatsResult<StandardizedResult<DescriptiveStats<F>>> {
let start_time = std::time::Instant::now();
let mut warnings = Vec::new();
if data.is_empty() {
return Err(StatsError::InvalidArgument(
"Cannot compute statistics for empty array".to_string(),
));
}
let (cleaneddata, samplesize) = self.handle_null_values(&data, &mut warnings)?;
let stats = if self.config.auto_optimize {
self.compute_optimized(&cleaneddata, &mut warnings)?
} else {
self.compute_standard(&cleaneddata, &mut warnings)?
};
let computation_time = start_time.elapsed().as_secs_f64() * 1000.0;
let metadata = ResultMetadata {
samplesize,
degrees_of_freedom: Some(samplesize.saturating_sub(self.ddof.unwrap_or(1))),
confidence_level: None,
method: self.select_method_name(),
computation_time_ms: computation_time,
memory_usage_bytes: self.estimate_memory_usage(samplesize),
optimized: self.config.simd || self.config.parallel,
extra: HashMap::new(),
};
Ok(StandardizedResult {
value: stats,
metadata,
warnings,
})
}
fn handle_null_values(
&self,
data: &ArrayView1<F>,
warnings: &mut Vec<String>,
) -> StatsResult<(Array1<F>, usize)> {
let finitedata: Vec<F> = data.iter().filter(|&&x| x.is_finite()).cloned().collect();
if finitedata.len() != data.len() {
warnings.push(format!(
"Removed {} non-finite values",
data.len() - finitedata.len()
));
}
let finite_count = finitedata.len();
match self.config.null_handling {
NullHandling::Exclude => Ok((Array1::from_vec(finitedata), finite_count)),
NullHandling::Fail if finite_count != data.len() => Err(StatsError::InvalidArgument(
"Null values encountered with Fail strategy".to_string(),
)),
_ => Ok((Array1::from_vec(finitedata), finite_count)),
}
}
fn compute_optimized(
&self,
data: &Array1<F>,
warnings: &mut Vec<String>,
) -> StatsResult<DescriptiveStats<F>> {
let n = data.len();
if self.config.simd && n > 64 {
self.compute_simd_optimized(data, warnings)
} else if self.config.parallel && n > 10000 {
self.compute_parallel_optimized(data, warnings)
} else {
self.compute_standard(data, warnings)
}
}
fn compute_simd_optimized(
&self,
data: &Array1<F>,
_warnings: &mut Vec<String>,
) -> StatsResult<DescriptiveStats<F>> {
let mean = crate::descriptive_simd::mean_simd(&data.view())?;
let variance =
crate::descriptive_simd::variance_simd(&data.view(), self.ddof.unwrap_or(1))?;
let std = variance.sqrt();
let (min, max) = self.compute_min_max(data);
let sorteddata = self.getsorteddata(data);
let percentiles = self.compute_percentiles(&sorteddata)?;
let skewness = crate::descriptive::skew(&data.view(), false, None)?;
let kurtosis = crate::descriptive::kurtosis(&data.view(), true, false, None)?;
Ok(DescriptiveStats {
count: data.len(),
mean,
std,
min,
percentile_25: percentiles[0],
median: percentiles[1],
percentile_75: percentiles[2],
max,
variance,
skewness,
kurtosis,
})
}
fn compute_parallel_optimized(
&self,
data: &Array1<F>,
_warnings: &mut Vec<String>,
) -> StatsResult<DescriptiveStats<F>> {
let mean = crate::parallel_stats::mean_parallel(&data.view())?;
let variance =
crate::parallel_stats::variance_parallel(&data.view(), self.ddof.unwrap_or(1))?;
let std = variance.sqrt();
let (min, max) = self.compute_min_max(data);
let sorteddata = self.getsorteddata(data);
let percentiles = self.compute_percentiles(&sorteddata)?;
let skewness = crate::descriptive::skew(&data.view(), false, None)?;
let kurtosis = crate::descriptive::kurtosis(&data.view(), true, false, None)?;
Ok(DescriptiveStats {
count: data.len(),
mean,
std,
min,
percentile_25: percentiles[0],
median: percentiles[1],
percentile_75: percentiles[2],
max,
variance,
skewness,
kurtosis,
})
}
fn compute_standard(
&self,
data: &Array1<F>,
_warnings: &mut Vec<String>,
) -> StatsResult<DescriptiveStats<F>> {
let mean = crate::descriptive::mean(&data.view())?;
let variance = crate::descriptive::var(&data.view(), self.ddof.unwrap_or(1), None)?;
let std = variance.sqrt();
let (min, max) = self.compute_min_max(data);
let sorteddata = self.getsorteddata(data);
let percentiles = self.compute_percentiles(&sorteddata)?;
let skewness = crate::descriptive::skew(&data.view(), false, None)?;
let kurtosis = crate::descriptive::kurtosis(&data.view(), true, false, None)?;
Ok(DescriptiveStats {
count: data.len(),
mean,
std,
min,
percentile_25: percentiles[0],
median: percentiles[1],
percentile_75: percentiles[2],
max,
variance,
skewness,
kurtosis,
})
}
fn compute_min_max(&self, data: &Array1<F>) -> (F, F) {
let mut min = data[0];
let mut max = data[0];
for &value in data.iter() {
if value < min {
min = value;
}
if value > max {
max = value;
}
}
(min, max)
}
fn getsorteddata(&self, data: &Array1<F>) -> Vec<F> {
let mut sorted = data.to_vec();
sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
sorted
}
fn compute_percentiles(&self, sorteddata: &[F]) -> StatsResult<[F; 3]> {
let n = sorteddata.len();
if n == 0 {
return Err(StatsError::InvalidArgument("Empty data".to_string()));
}
let p25_idx = (n as f64 * 0.25) as usize;
let p50_idx = (n as f64 * 0.50) as usize;
let p75_idx = (n as f64 * 0.75) as usize;
Ok([
sorteddata[p25_idx.min(n - 1)],
sorteddata[p50_idx.min(n - 1)],
sorteddata[p75_idx.min(n - 1)],
])
}
fn select_method_name(&self) -> String {
if self.config.simd && self.config.parallel {
"SIMD+Parallel".to_string()
} else if self.config.simd {
"SIMD".to_string()
} else if self.config.parallel {
"Parallel".to_string()
} else {
"Standard".to_string()
}
}
fn estimate_memory_usage(&self, samplesize: usize) -> Option<usize> {
if self.config.include_metadata {
Some(samplesize * std::mem::size_of::<F>() * 2) } else {
None
}
}
}
impl<F> CorrelationBuilder<F>
where
    F: Float
        + NumCast
        + Clone
        + std::fmt::Debug
        + std::fmt::Display
        + scirs2_core::simd_ops::SimdUnifiedOps
        + std::iter::Sum<F>
        + std::ops::Div<Output = F>
        + Send
        + Sync
        + 'static,
{
    /// Creates a builder defaulting to Pearson correlation.
    pub fn new() -> Self {
        Self {
            config: StandardizedConfig::default(),
            method: CorrelationMethod::Pearson,
            min_periods: None,
            phantom: PhantomData,
        }
    }

    /// Selects the correlation estimator.
    pub fn method(mut self, method: CorrelationMethod) -> Self {
        self.method = method;
        self
    }

    /// Requires at least `periods` observations before computing.
    pub fn min_periods(mut self, periods: usize) -> Self {
        self.min_periods = Some(periods);
        self
    }

    /// Sets the confidence level for interval estimates (e.g. 0.95).
    pub fn confidence_level(mut self, level: f64) -> Self {
        self.config.confidence_level = level;
        self
    }

    /// Enables/disables parallel execution paths.
    pub fn parallel(mut self, enable: bool) -> Self {
        self.config.parallel = enable;
        self
    }

    /// Enables/disables SIMD execution paths.
    pub fn simd(mut self, enable: bool) -> Self {
        self.config.simd = enable;
        self
    }

    /// Also computes inference (p-value, confidence interval) and metadata.
    pub fn with_metadata(mut self) -> Self {
        self.config.include_metadata = true;
        self
    }

    /// Computes the correlation between `x` and `y`.
    ///
    /// Partial-correlation methods are not implemented and fall back to
    /// Pearson with a warning.
    ///
    /// # Errors
    /// Returns an error when the inputs differ in length, are empty, or are
    /// shorter than the configured `min_periods`.
    pub fn compute<'a>(
        &self,
        x: ArrayView1<'a, F>,
        y: ArrayView1<'a, F>,
    ) -> StatsResult<StandardizedResult<CorrelationResult<F>>> {
        let start_time = std::time::Instant::now();
        let mut warnings = Vec::new();
        if x.len() != y.len() {
            return Err(StatsError::DimensionMismatch(
                "Input arrays must have the same length".to_string(),
            ));
        }
        if x.is_empty() {
            return Err(StatsError::InvalidArgument(
                "Cannot compute correlation for empty arrays".to_string(),
            ));
        }
        if let Some(min_periods) = self.min_periods {
            if x.len() < min_periods {
                return Err(StatsError::InvalidArgument(format!(
                    "Insufficient data: {} observations, {} required",
                    x.len(),
                    min_periods
                )));
            }
        }
        let correlation = match self.method {
            CorrelationMethod::Pearson => {
                // SIMD kernel only pays off beyond a minimum size.
                if self.config.simd && x.len() > 64 {
                    crate::correlation_simd::pearson_r_simd(&x, &y)?
                } else {
                    crate::correlation::pearson_r(&x, &y)?
                }
            }
            CorrelationMethod::Spearman => crate::correlation::spearman_r(&x, &y)?,
            CorrelationMethod::Kendall => crate::correlation::kendall_tau(&x, &y, "b")?,
            _ => {
                warnings.push("Advanced correlation methods not yet implemented".to_string());
                crate::correlation::pearson_r(&x, &y)?
            }
        };
        let (p_value, confidence_interval) = if self.config.include_metadata {
            self.compute_statistical_inference(correlation, x.len(), &mut warnings)?
        } else {
            (None, None)
        };
        let computation_time = start_time.elapsed().as_secs_f64() * 1000.0;
        let result = CorrelationResult {
            correlation,
            p_value,
            confidence_interval,
            method: self.method,
        };
        let metadata = ResultMetadata {
            samplesize: x.len(),
            degrees_of_freedom: Some(x.len().saturating_sub(2)),
            confidence_level: Some(self.config.confidence_level),
            method: format!("{:?}", self.method),
            computation_time_ms: computation_time,
            memory_usage_bytes: self.estimate_memory_usage(x.len()),
            optimized: self.config.simd || self.config.parallel,
            extra: HashMap::new(),
        };
        Ok(StandardizedResult {
            value: result,
            metadata,
            warnings,
        })
    }

    /// Computes a full correlation matrix for the columns of `data`.
    /// With `auto_optimize` the memory-optimized path is used; otherwise a
    /// plain Pearson `corrcoef` is computed regardless of `self.method`.
    pub fn compute_matrix(
        &self,
        data: ArrayView2<F>,
    ) -> StatsResult<StandardizedResult<Array2<F>>> {
        let start_time = std::time::Instant::now();
        let warnings = Vec::new();
        let correlation_matrix = if self.config.auto_optimize {
            let mut optimizer = crate::memory_optimization_advanced::MemoryOptimizationSuite::new(
                crate::memory_optimization_advanced::MemoryOptimizationConfig::default(),
            );
            optimizer.optimized_correlation_matrix(data)?
        } else {
            crate::correlation::corrcoef(&data, "pearson")?
        };
        let computation_time = start_time.elapsed().as_secs_f64() * 1000.0;
        let metadata = ResultMetadata {
            samplesize: data.nrows(),
            degrees_of_freedom: Some(data.nrows().saturating_sub(2)),
            confidence_level: Some(self.config.confidence_level),
            method: format!("Matrix {:?}", self.method),
            computation_time_ms: computation_time,
            memory_usage_bytes: self.estimate_memory_usage(data.nrows() * data.ncols()),
            optimized: self.config.simd || self.config.parallel,
            extra: HashMap::new(),
        };
        Ok(StandardizedResult {
            value: correlation_matrix,
            metadata,
            warnings,
        })
    }

    /// Large-sample inference for a correlation via the Fisher
    /// z-transformation (normal approximation; not an exact t-test).
    ///
    /// Returns `(p_value, confidence_interval)`; both are `None` when the
    /// sample is too small (`n <= 3`) or the correlation is non-finite.
    ///
    /// Fixes over the previous version: no usize underflow panic for small
    /// `n`, no NaN interval at r = ±1 (the coefficient is clamped), the
    /// configured confidence level is actually used instead of a hard-coded
    /// 1.96, and a real two-sided p-value replaces the constant placeholder.
    fn compute_statistical_inference(
        &self,
        correlation: F,
        n: usize,
        warnings: &mut Vec<String>,
    ) -> StatsResult<(Option<F>, Option<(F, F)>)> {
        if n <= 3 {
            warnings.push(
                "Sample too small (n <= 3) for correlation inference; skipping p-value and confidence interval"
                    .to_string(),
            );
            return Ok((None, None));
        }
        let r = <f64 as NumCast>::from(correlation).unwrap_or(f64::NAN);
        if !r.is_finite() {
            warnings.push("Non-finite correlation; skipping inference".to_string());
            return Ok((None, None));
        }
        // Keep |r| strictly below 1 so the Fisher transform stays finite.
        const MAX_ABS_R: f64 = 1.0 - 1e-12;
        let r_clamped = r.clamp(-MAX_ABS_R, MAX_ABS_R);
        if r_clamped != r {
            warnings.push("Correlation at +/-1; inference values are clamped".to_string());
        }
        // Fisher z-transform and its standard error 1/sqrt(n - 3).
        let z = 0.5 * ((1.0 + r_clamped) / (1.0 - r_clamped)).ln();
        let se = 1.0 / ((n - 3) as f64).sqrt();
        // Critical value from the requested confidence level; fall back to
        // 95% when the configured level is out of (0, 1).
        let level = self.config.confidence_level;
        let level = if level > 0.0 && level < 1.0 {
            level
        } else {
            warnings.push("Invalid confidence level; using 0.95".to_string());
            0.95
        };
        let z_critical = Self::normal_quantile(0.5 + level / 2.0);
        // tanh is the inverse Fisher transform back to correlation space.
        let lower = (z - z_critical * se).tanh();
        let upper = (z + z_critical * se).tanh();
        // Two-sided p-value for H0: rho = 0, from the standardized z stat.
        let p = (1.0 - Self::erf((z / se).abs() / std::f64::consts::SQRT_2)).clamp(0.0, 1.0);
        let ci = match (F::from(lower), F::from(upper)) {
            (Some(lo), Some(hi)) => Some((lo, hi)),
            _ => None,
        };
        Ok((F::from(p), ci))
    }

    /// Error function via the Abramowitz & Stegun 7.1.26 polynomial
    /// approximation (absolute error below 1.5e-7).
    fn erf(x: f64) -> f64 {
        let sign = if x < 0.0 { -1.0 } else { 1.0 };
        let x = x.abs();
        let t = 1.0 / (1.0 + 0.3275911 * x);
        let poly = ((((1.061405429 * t - 1.453152027) * t + 1.421413741) * t - 0.284496736) * t
            + 0.254829592)
            * t;
        sign * (1.0 - poly * (-x * x).exp())
    }

    /// Standard-normal quantile (inverse CDF) via Acklam's rational
    /// approximation. Returns NaN outside the open interval (0, 1).
    fn normal_quantile(p: f64) -> f64 {
        const A: [f64; 6] = [
            -3.969683028665376e1,
            2.209460984245205e2,
            -2.759285104469687e2,
            1.383577518672690e2,
            -3.066479806614716e1,
            2.506628277459239e0,
        ];
        const B: [f64; 5] = [
            -5.447609879822406e1,
            1.615858368580409e2,
            -1.556989798598866e2,
            6.680131188771972e1,
            -1.328068155288572e1,
        ];
        const C: [f64; 6] = [
            -7.784894002430293e-3,
            -3.223964580411365e-1,
            -2.400758277161838e0,
            -2.549732539343734e0,
            4.374664141464968e0,
            2.938163982698783e0,
        ];
        const D: [f64; 4] = [
            7.784695709041462e-3,
            3.224671290700398e-1,
            2.445134137142996e0,
            3.754408661907416e0,
        ];
        const P_LOW: f64 = 0.02425;
        if !(p > 0.0 && p < 1.0) {
            return f64::NAN;
        }
        if p < P_LOW {
            // Lower tail.
            let q = (-2.0 * p.ln()).sqrt();
            (((((C[0] * q + C[1]) * q + C[2]) * q + C[3]) * q + C[4]) * q + C[5])
                / ((((D[0] * q + D[1]) * q + D[2]) * q + D[3]) * q + 1.0)
        } else if p <= 1.0 - P_LOW {
            // Central region.
            let q = p - 0.5;
            let r = q * q;
            (((((A[0] * r + A[1]) * r + A[2]) * r + A[3]) * r + A[4]) * r + A[5]) * q
                / (((((B[0] * r + B[1]) * r + B[2]) * r + B[3]) * r + B[4]) * r + 1.0)
        } else {
            // Upper tail (by symmetry with the lower tail).
            let q = (-2.0 * (1.0 - p).ln()).sqrt();
            -(((((C[0] * q + C[1]) * q + C[2]) * q + C[3]) * q + C[4]) * q + C[5])
                / ((((D[0] * q + D[1]) * q + D[2]) * q + D[3]) * q + 1.0)
        }
    }

    /// Rough memory estimate; only produced when metadata is enabled.
    fn estimate_memory_usage(&self, size: usize) -> Option<usize> {
        if self.config.include_metadata {
            Some(size * std::mem::size_of::<F>() * 3)
        } else {
            None
        }
    }
}
impl<F> StatsAnalyzer<F>
where
F: Float
+ NumCast
+ Clone
+ scirs2_core::simd_ops::SimdUnifiedOps
+ std::iter::Sum<F>
+ std::ops::Div<Output = F>
+ Sync
+ Send
+ std::fmt::Display
+ std::fmt::Debug
+ 'static,
{
pub fn new() -> Self {
Self {
config: StandardizedConfig::default(),
phantom: PhantomData,
}
}
pub fn configure(mut self, config: StandardizedConfig) -> Self {
self.config = config;
self
}
pub fn describe(
&self,
data: ArrayView1<F>,
) -> StatsResult<StandardizedResult<DescriptiveStats<F>>> {
DescriptiveStatsBuilder::new()
.parallel(self.config.parallel)
.simd(self.config.simd)
.null_handling(self.config.null_handling)
.with_metadata()
.compute(data)
}
pub fn correlate<'a>(
&self,
x: ArrayView1<'a, F>,
y: ArrayView1<'a, F>,
method: CorrelationMethod,
) -> StatsResult<StandardizedResult<CorrelationResult<F>>> {
CorrelationBuilder::new()
.method(method)
.confidence_level(self.config.confidence_level)
.parallel(self.config.parallel)
.simd(self.config.simd)
.with_metadata()
.compute(x, y)
}
pub fn get_config(&self) -> &StandardizedConfig {
&self.config
}
}
// Convenience aliases for the common element types.
pub type F64StatsAnalyzer = StatsAnalyzer<f64>;
pub type F32StatsAnalyzer = StatsAnalyzer<f32>;
pub type F64DescriptiveBuilder = DescriptiveStatsBuilder<f64>;
pub type F32DescriptiveBuilder = DescriptiveStatsBuilder<f32>;
pub type F64CorrelationBuilder = CorrelationBuilder<f64>;
pub type F32CorrelationBuilder = CorrelationBuilder<f32>;
// The `Default` implementations below simply delegate to `new()`; the long
// bound lists mirror the corresponding inherent impl blocks.
impl<F> Default for DescriptiveStatsBuilder<F>
where
    F: Float
        + NumCast
        + Clone
        + scirs2_core::simd_ops::SimdUnifiedOps
        + std::iter::Sum<F>
        + std::ops::Div<Output = F>
        + Sync
        + Send
        + std::fmt::Display
        + std::fmt::Debug
        + 'static,
{
    fn default() -> Self {
        Self::new()
    }
}
impl<F> Default for CorrelationBuilder<F>
where
    F: Float
        + NumCast
        + Clone
        + std::fmt::Debug
        + std::fmt::Display
        + scirs2_core::simd_ops::SimdUnifiedOps
        + std::iter::Sum<F>
        + std::ops::Div<Output = F>
        + Send
        + Sync
        + 'static,
{
    fn default() -> Self {
        Self::new()
    }
}
impl<F> Default for StatsAnalyzer<F>
where
    F: Float
        + NumCast
        + Clone
        + scirs2_core::simd_ops::SimdUnifiedOps
        + std::iter::Sum<F>
        + std::ops::Div<Output = F>
        + Sync
        + Send
        + std::fmt::Display
        + std::fmt::Debug
        + 'static,
{
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Builder computes correct basic statistics on a small sample.
    #[test]
    fn test_descriptive_stats_builder() {
        let data = array![1.0, 2.0, 3.0, 4.0, 5.0];
        let result = DescriptiveStatsBuilder::new()
            .ddof(1)
            .parallel(false)
            .simd(false)
            .with_metadata()
            .compute(data.view())
            .expect("Operation failed");
        assert_eq!(result.value.count, 5);
        assert!((result.value.mean - 3.0).abs() < 1e-10);
        assert!(!result.metadata.optimized);
    }

    /// Perfectly linear data yields correlation 1 with inference attached.
    #[test]
    fn test_correlation_builder() {
        let x = array![1.0, 2.0, 3.0, 4.0, 5.0];
        let y = array![2.0, 4.0, 6.0, 8.0, 10.0];
        let result = CorrelationBuilder::new()
            .method(CorrelationMethod::Pearson)
            .confidence_level(0.95)
            .with_metadata()
            .compute(x.view(), y.view())
            .expect("Operation failed");
        assert!((result.value.correlation - 1.0).abs() < 1e-10);
        assert!(result.value.p_value.is_some());
        assert!(result.value.confidence_interval.is_some());
    }

    /// Facade wires the shared config into both builders.
    #[test]
    fn test_stats_analyzer() {
        let analyzer = StatsAnalyzer::new();
        let data = array![1.0, 2.0, 3.0, 4.0, 5.0];
        let desc_result = analyzer.describe(data.view()).expect("Operation failed");
        assert_eq!(desc_result.value.count, 5);
        let x = array![1.0, 2.0, 3.0, 4.0, 5.0];
        let y = array![5.0, 4.0, 3.0, 2.0, 1.0];
        let corr_result = analyzer
            .correlate(x.view(), y.view(), CorrelationMethod::Pearson)
            .expect("Operation failed");
        assert!((corr_result.value.correlation + 1.0).abs() < 1e-10);
    }

    /// `Exclude` drops the NaN and reports it through a warning.
    #[test]
    fn test_null_handling() {
        let data = array![1.0, 2.0, f64::NAN, 4.0, 5.0];
        let result = DescriptiveStatsBuilder::new()
            .null_handling(NullHandling::Exclude)
            .compute(data.view())
            .expect("Operation failed");
        assert_eq!(result.value.count, 4);
        assert!(!result.warnings.is_empty());
    }

    /// Struct-update syntax keeps unspecified fields at their defaults.
    #[test]
    fn test_standardized_config() {
        let config = StandardizedConfig {
            auto_optimize: false,
            parallel: false,
            simd: true,
            confidence_level: 0.99,
            ..Default::default()
        };
        assert!(!config.auto_optimize);
        assert!(!config.parallel);
        assert!(config.simd);
        assert!((config.confidence_level - 0.99).abs() < 1e-10);
    }

    /// A fully documented, well-formed signature passes validation.
    #[test]
    fn test_api_validation() {
        let framework = APIValidationFramework::new();
        let signature = APISignature {
            function_name: "test_function".to_string(),
            // Fixed: was the garbled literal "scirs2, _stats::test".
            module_path: "scirs2_stats::test".to_string(),
            parameters: vec![ParameterSpec {
                name: "data".to_string(),
                param_type: "ArrayView1<f64>".to_string(),
                optional: false,
                default_value: None,
                description: Some("Input data array".to_string()),
                constraints: vec![ParameterConstraint::Finite],
            }],
            return_type: ReturnTypeSpec {
                type_name: "f64".to_string(),
                result_wrapped: true,
                inner_type: Some("f64".to_string()),
                error_type: Some("StatsError".to_string()),
            },
            error_types: vec!["StatsError".to_string()],
            documentation: DocumentationSpec {
                has_doc_comment: true,
                has_param_docs: true,
                has_return_docs: true,
                has_examples: true,
                has_error_docs: true,
                scipy_compatibility: Some("Compatible with scipy.stats".to_string()),
            },
            performance: PerformanceSpec {
                time_complexity: Some("O(n)".to_string()),
                space_complexity: Some("O(1)".to_string()),
                simd_optimized: true,
                parallel_processing: true,
                cache_efficient: true,
            },
        };
        let report = framework.validate_api(&signature);
        assert!(matches!(
            report.overall_status,
            ValidationStatus::Passed | ValidationStatus::PassedWithWarnings
        ));
    }
}
/// Rule-driven framework for validating API signatures for consistency,
/// documentation, performance, and SciPy compatibility.
#[derive(Debug)]
pub struct APIValidationFramework {
    /// Rules grouped by `format!("{:?}", category)` key.
    validation_rules: HashMap<String, Vec<ValidationRule>>,
    /// Registered but not consulted by the code in this file.
    compatibility_checkers: HashMap<String, CompatibilityChecker>,
    /// Registered but not consulted by the code in this file.
    performance_benchmarks: HashMap<String, PerformanceBenchmark>,
    /// Registered but not consulted by the code in this file.
    error_patterns: HashMap<String, ErrorPattern>,
}
/// One validation rule: what it checks and how severe a violation is.
#[derive(Debug, Clone)]
pub struct ValidationRule {
    pub id: String,
    pub description: String,
    pub category: ValidationCategory,
    pub severity: ValidationSeverity,
}
/// Aspect of an API that a rule validates.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ValidationCategory {
    ParameterNaming,
    ReturnTypes,
    ErrorHandling,
    Documentation,
    Performance,
    ScipyCompatibility,
    ThreadSafety,
    NumericalStability,
}
/// Message severity; derived `Ord` makes Info < Warning < Error < Critical,
/// which `ValidationReport::add_result` relies on.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum ValidationSeverity {
    Info,
    Warning,
    Error,
    Critical,
}
/// Complete description of a public API function to validate.
#[derive(Debug, Clone)]
pub struct APISignature {
    pub function_name: String,
    pub module_path: String,
    pub parameters: Vec<ParameterSpec>,
    pub return_type: ReturnTypeSpec,
    pub error_types: Vec<String>,
    pub documentation: DocumentationSpec,
    pub performance: PerformanceSpec,
}
/// Description of a single function parameter.
#[derive(Debug, Clone)]
pub struct ParameterSpec {
    pub name: String,
    pub param_type: String,
    pub optional: bool,
    pub default_value: Option<String>,
    pub description: Option<String>,
    pub constraints: Vec<ParameterConstraint>,
}
/// Declarative constraint on a parameter's value.
#[derive(Debug, Clone)]
pub enum ParameterConstraint {
    Positive,
    NonNegative,
    Finite,
    /// Inclusive (min, max) bounds.
    Range(f64, f64),
    /// Value must be one of the listed strings.
    OneOf(Vec<String>),
    /// Expected array shape; `None` means any size for that dimension.
    Shape(Vec<Option<usize>>),
    /// Free-form constraint description.
    Custom(String),
}
/// Description of a function's return type.
#[derive(Debug, Clone)]
pub struct ReturnTypeSpec {
    pub type_name: String,
    /// Whether the return is wrapped in `Result<_, _>`.
    pub result_wrapped: bool,
    pub inner_type: Option<String>,
    pub error_type: Option<String>,
}
/// Checklist of documentation elements present on a function.
#[derive(Debug, Clone)]
pub struct DocumentationSpec {
    pub has_doc_comment: bool,
    pub has_param_docs: bool,
    pub has_return_docs: bool,
    pub has_examples: bool,
    pub has_error_docs: bool,
    /// Free-text note on SciPy compatibility, if documented.
    pub scipy_compatibility: Option<String>,
}
/// Declared performance characteristics of a function.
#[derive(Debug, Clone)]
pub struct PerformanceSpec {
    pub time_complexity: Option<String>,
    pub space_complexity: Option<String>,
    pub simd_optimized: bool,
    pub parallel_processing: bool,
    pub cache_efficient: bool,
}
/// Outcome of applying one validation rule to a signature.
#[derive(Debug, Clone)]
pub struct ValidationResult {
    pub passed: bool,
    pub messages: Vec<ValidationMessage>,
    pub suggested_fixes: Vec<String>,
    /// Ids of rules related to this result (cross-references).
    pub related_rules: Vec<String>,
}
/// A single finding produced by a validator.
#[derive(Debug, Clone)]
pub struct ValidationMessage {
    pub severity: ValidationSeverity,
    pub message: String,
    /// "module::function" location string, when available.
    pub location: Option<String>,
    pub rule_id: String,
}
/// Mapping between a local function and its SciPy counterpart.
#[derive(Debug, Clone)]
pub struct CompatibilityChecker {
    pub scipy_function: String,
    pub parameter_mapping: HashMap<String, String>,
    pub return_type_mapping: HashMap<String, String>,
    pub known_differences: Vec<CompatibilityDifference>,
}
/// A documented, intentional or accidental divergence from SciPy.
#[derive(Debug, Clone)]
pub struct CompatibilityDifference {
    pub category: DifferenceCategory,
    pub description: String,
    pub justification: String,
    pub workaround: Option<String>,
}
/// Why an API diverges from its SciPy counterpart.
#[derive(Debug, Clone, Copy)]
pub enum DifferenceCategory {
    Improvement,
    RustConstraint,
    Performance,
    Safety,
    Unintentional,
}
/// Expected performance envelope for a function.
#[derive(Debug, Clone)]
pub struct PerformanceBenchmark {
    pub name: String,
    pub expected_complexity: ComplexityClass,
    pub memory_usage: MemoryUsagePattern,
    pub scalability: ScalabilityRequirement,
}
/// Asymptotic time-complexity class.
#[derive(Debug, Clone, Copy)]
pub enum ComplexityClass {
    Constant,
    Logarithmic,
    Linear,
    LogLinear,
    Quadratic,
    Cubic,
    Exponential,
}
/// How memory usage scales with input size.
#[derive(Debug, Clone, Copy)]
pub enum MemoryUsagePattern {
    Constant,
    Linear,
    Quadratic,
    Streaming,
    OutOfCore,
}
/// Scalability targets for large inputs.
#[derive(Debug, Clone)]
pub struct ScalabilityRequirement {
    pub maxdatasize: usize,
    /// Target parallel efficiency (fraction of ideal speedup).
    pub parallel_efficiency: f64,
    /// Target SIMD acceleration factor.
    pub simd_acceleration: f64,
}
/// Template for a class of errors and how to recover from them.
#[derive(Debug, Clone)]
pub struct ErrorPattern {
    pub category: ErrorCategory,
    pub message_template: String,
    pub recovery_suggestions: Vec<String>,
    pub related_errors: Vec<String>,
}
/// Broad classification of error conditions.
#[derive(Debug, Clone, Copy)]
pub enum ErrorCategory {
    InvalidInput,
    Numerical,
    Memory,
    Convergence,
    DimensionMismatch,
    NotImplemented,
    Internal,
}
/// Aggregated validation results for one function.
#[derive(Debug)]
pub struct ValidationReport {
    pub function_name: String,
    /// Per-rule results keyed by rule id.
    pub results: HashMap<String, ValidationResult>,
    /// Worst status across all recorded results.
    pub overall_status: ValidationStatus,
    pub summary: ValidationSummary,
}
/// Overall pass/fail state of a report, from best to worst.
#[derive(Debug, Clone, Copy)]
pub enum ValidationStatus {
    Passed,
    PassedWithWarnings,
    Failed,
    Critical,
}
/// Counts of results per severity bucket.
#[derive(Debug, Clone)]
pub struct ValidationSummary {
    pub total_rules: usize,
    pub passed: usize,
    pub warnings: usize,
    pub errors: usize,
    pub critical: usize,
}
impl APIValidationFramework {
    /// Creates a framework pre-populated with the default rule set.
    pub fn new() -> Self {
        let mut framework = Self {
            validation_rules: HashMap::new(),
            compatibility_checkers: HashMap::new(),
            performance_benchmarks: HashMap::new(),
            error_patterns: HashMap::new(),
        };
        framework.initialize_default_rules();
        framework
    }

    /// Registers the built-in rules: naming, error handling, documentation,
    /// SciPy compatibility, and performance characteristics.
    fn initialize_default_rules(&mut self) {
        self.add_validation_rule(ValidationRule {
            id: "param_naming_consistency".to_string(),
            description: "Parameter names should follow consistent snake_case conventions"
                .to_string(),
            category: ValidationCategory::ParameterNaming,
            severity: ValidationSeverity::Warning,
        });
        self.add_validation_rule(ValidationRule {
            id: "error_handling_consistency".to_string(),
            description: "Functions should return Result<T, StatsError> for consistency"
                .to_string(),
            category: ValidationCategory::ErrorHandling,
            severity: ValidationSeverity::Error,
        });
        self.add_validation_rule(ValidationRule {
            id: "documentation_completeness".to_string(),
            description: "All public functions should have complete documentation".to_string(),
            category: ValidationCategory::Documentation,
            severity: ValidationSeverity::Warning,
        });
        self.add_validation_rule(ValidationRule {
            id: "scipy_compatibility".to_string(),
            description: "Functions should maintain SciPy compatibility where possible".to_string(),
            category: ValidationCategory::ScipyCompatibility,
            severity: ValidationSeverity::Info,
        });
        self.add_validation_rule(ValidationRule {
            id: "performance_characteristics".to_string(),
            description: "Functions should document performance characteristics".to_string(),
            category: ValidationCategory::Performance,
            severity: ValidationSeverity::Info,
        });
    }

    /// Registers a rule under its category's debug-formatted name.
    pub fn add_validation_rule(&mut self, rule: ValidationRule) {
        let category_key = format!("{:?}", rule.category);
        self.validation_rules
            .entry(category_key)
            .or_default()
            .push(rule);
    }

    /// Applies every registered rule to `signature` and collects the
    /// results into a report keyed by rule id.
    pub fn validate_api(&self, signature: &APISignature) -> ValidationReport {
        let mut report = ValidationReport::new(signature.function_name.clone());
        for rules in self.validation_rules.values() {
            for rule in rules {
                let result = self.apply_validation_rule(rule, signature);
                report.add_result(rule.id.clone(), result);
            }
        }
        report
    }

    /// Dispatches to the validator for the rule's category. Categories
    /// without a dedicated validator trivially pass.
    fn apply_validation_rule(
        &self,
        rule: &ValidationRule,
        signature: &APISignature,
    ) -> ValidationResult {
        match rule.category {
            ValidationCategory::ParameterNaming => self.validate_parameter_naming(signature),
            ValidationCategory::ErrorHandling => self.validate_error_handling(signature),
            ValidationCategory::Documentation => self.validate_documentation(signature),
            ValidationCategory::ScipyCompatibility => self.validate_scipy_compatibility(signature),
            ValidationCategory::Performance => self.validate_performance(signature),
            _ => ValidationResult {
                passed: true,
                messages: vec![],
                suggested_fixes: vec![],
                related_rules: vec![],
            },
        }
    }

    /// Warns on parameter names containing uppercase letters or hyphens
    /// (i.e. not snake_case).
    fn validate_parameter_naming(&self, signature: &APISignature) -> ValidationResult {
        let mut messages = Vec::new();
        let mut suggested_fixes = Vec::new();
        for param in &signature.parameters {
            if param.name.contains(char::is_uppercase) || param.name.contains('-') {
                messages.push(ValidationMessage {
                    severity: ValidationSeverity::Warning,
                    message: format!("Parameter '{}' should use snake_case naming", param.name),
                    location: Some(format!(
                        "{}::{}",
                        signature.module_path, signature.function_name
                    )),
                    rule_id: "param_naming_consistency".to_string(),
                });
                suggested_fixes.push(format!("Rename parameter '{}' to snake_case", param.name));
            }
        }
        ValidationResult {
            passed: messages.is_empty(),
            messages,
            suggested_fixes,
            related_rules: vec!["return_type_consistency".to_string()],
        }
    }

    /// Requires a `Result`-wrapped return and the standard `StatsError`
    /// error type.
    fn validate_error_handling(&self, signature: &APISignature) -> ValidationResult {
        let mut messages = Vec::new();
        let mut suggested_fixes = Vec::new();
        if !signature.return_type.result_wrapped {
            messages.push(ValidationMessage {
                severity: ValidationSeverity::Error,
                message: "Function should return Result<T, StatsError> for consistency".to_string(),
                location: Some(format!(
                    "{}::{}",
                    signature.module_path, signature.function_name
                )),
                rule_id: "error_handling_consistency".to_string(),
            });
            suggested_fixes.push("Wrap return type in Result<T, StatsError>".to_string());
        }
        if let Some(error_type) = &signature.return_type.error_type {
            if error_type != "StatsError" {
                messages.push(ValidationMessage {
                    severity: ValidationSeverity::Warning,
                    message: format!("Non-standard error type '{}' used", error_type),
                    location: Some(format!(
                        "{}::{}",
                        signature.module_path, signature.function_name
                    )),
                    rule_id: "error_handling_consistency".to_string(),
                });
                suggested_fixes.push("Use StatsError for consistency".to_string());
            }
        }
        ValidationResult {
            passed: messages.is_empty(),
            messages,
            suggested_fixes,
            related_rules: vec!["documentation_completeness".to_string()],
        }
    }

    /// Checks for a doc comment (Warning) and examples (Info).
    /// Info-only findings do not fail the result.
    fn validate_documentation(&self, signature: &APISignature) -> ValidationResult {
        let mut messages = Vec::new();
        let mut suggested_fixes = Vec::new();
        if !signature.documentation.has_doc_comment {
            messages.push(ValidationMessage {
                severity: ValidationSeverity::Warning,
                message: "Function lacks documentation comment".to_string(),
                location: Some(format!(
                    "{}::{}",
                    signature.module_path, signature.function_name
                )),
                rule_id: "documentation_completeness".to_string(),
            });
            suggested_fixes.push("Add comprehensive doc comment".to_string());
        }
        if !signature.documentation.has_examples {
            messages.push(ValidationMessage {
                severity: ValidationSeverity::Info,
                message: "Function lacks usage examples".to_string(),
                location: Some(format!(
                    "{}::{}",
                    signature.module_path, signature.function_name
                )),
                rule_id: "documentation_completeness".to_string(),
            });
            suggested_fixes.push("Add usage examples in # Examples section".to_string());
        }
        ValidationResult {
            passed: messages
                .iter()
                .all(|m| matches!(m.severity, ValidationSeverity::Info)),
            messages,
            suggested_fixes,
            related_rules: vec!["scipy_compatibility".to_string()],
        }
    }

    /// Suggests documenting SciPy compatibility when SciPy-style parameter
    /// names are present but no compatibility note exists. Never fails.
    fn validate_scipy_compatibility(&self, signature: &APISignature) -> ValidationResult {
        let mut messages = Vec::new();
        let mut suggested_fixes = Vec::new();
        let scipy_standard_params = [
            "axis",
            "ddof",
            "keepdims",
            "out",
            "dtype",
            "method",
            "alternative",
        ];
        let has_scipy_params = signature
            .parameters
            .iter()
            .any(|p| scipy_standard_params.contains(&p.name.as_str()));
        if has_scipy_params && signature.documentation.scipy_compatibility.is_none() {
            messages.push(ValidationMessage {
                severity: ValidationSeverity::Info,
                message: "Consider documenting SciPy compatibility status".to_string(),
                location: Some(format!(
                    "{}::{}",
                    signature.module_path, signature.function_name
                )),
                rule_id: "scipy_compatibility".to_string(),
            });
            suggested_fixes.push("Add SciPy compatibility note in documentation".to_string());
        }
        ValidationResult {
            // Advisory only: this check always passes.
            passed: true,
            messages,
            suggested_fixes,
            related_rules: vec!["documentation_completeness".to_string()],
        }
    }

    /// Suggests documenting time complexity when missing. Never fails.
    fn validate_performance(&self, signature: &APISignature) -> ValidationResult {
        let mut messages = Vec::new();
        let mut suggested_fixes = Vec::new();
        if signature.performance.time_complexity.is_none() {
            messages.push(ValidationMessage {
                severity: ValidationSeverity::Info,
                message: "Consider documenting time complexity".to_string(),
                location: Some(format!(
                    "{}::{}",
                    signature.module_path, signature.function_name
                )),
                rule_id: "performance_characteristics".to_string(),
            });
            suggested_fixes.push("Add time complexity documentation".to_string());
        }
        ValidationResult {
            // Advisory only: this check always passes.
            passed: true,
            messages,
            suggested_fixes,
            related_rules: vec![],
        }
    }
}
impl ValidationReport {
    /// Creates an empty report for the named function with all counters at
    /// zero and an initial `Passed` status.
    pub fn new(_functionname: String) -> Self {
        Self {
            function_name: _functionname,
            results: HashMap::new(),
            overall_status: ValidationStatus::Passed,
            summary: ValidationSummary {
                total_rules: 0,
                passed: 0,
                warnings: 0,
                errors: 0,
                critical: 0,
            },
        }
    }

    /// Records one rule's result and escalates the overall status based on
    /// the most severe message attached to a failing result.
    pub fn add_result(&mut self, ruleid: String, result: ValidationResult) {
        self.summary.total_rules += 1;
        if result.passed {
            self.summary.passed += 1;
        } else {
            // Severity ordering comes from the derived `Ord` on
            // `ValidationSeverity` (Info < Warning < Error < Critical).
            let max_severity = result
                .messages
                .iter()
                .map(|m| m.severity)
                .max()
                .unwrap_or(ValidationSeverity::Info);
            match max_severity {
                // NOTE(review): a failing result whose worst message is Info
                // (or which has no messages at all) is counted in no summary
                // bucket and leaves the overall status unchanged — confirm
                // this is intended.
                ValidationSeverity::Info => {}
                ValidationSeverity::Warning => {
                    self.summary.warnings += 1;
                    if matches!(self.overall_status, ValidationStatus::Passed) {
                        self.overall_status = ValidationStatus::PassedWithWarnings;
                    }
                }
                ValidationSeverity::Error => {
                    self.summary.errors += 1;
                    // Errors fail the report unless it is already Critical.
                    if !matches!(self.overall_status, ValidationStatus::Critical) {
                        self.overall_status = ValidationStatus::Failed;
                    }
                }
                ValidationSeverity::Critical => {
                    self.summary.critical += 1;
                    self.overall_status = ValidationStatus::Critical;
                }
            }
        }
        self.results.insert(ruleid, result);
    }

    /// Renders a human-readable plain-text report. Only failing rules are
    /// listed in detail; iteration order follows the HashMap and is
    /// therefore unspecified.
    pub fn generate_report(&self) -> String {
        let mut report = String::new();
        report.push_str(&format!(
            "API Validation Report for {}\n",
            self.function_name
        ));
        report.push_str(&format!("Status: {:?}\n", self.overall_status));
        report.push_str(&format!(
            "Summary: {} passed, {} warnings, {} errors, {} critical\n\n",
            self.summary.passed, self.summary.warnings, self.summary.errors, self.summary.critical
        ));
        for (rule_id, result) in &self.results {
            if !result.passed {
                report.push_str(&format!("Rule: {}\n", rule_id));
                for message in &result.messages {
                    report.push_str(&format!("  {:?}: {}\n", message.severity, message.message));
                }
                if !result.suggested_fixes.is_empty() {
                    report.push_str("  Suggestions:\n");
                    for fix in &result.suggested_fixes {
                        report.push_str(&format!("    - {}\n", fix));
                    }
                }
                report.push('\n');
            }
        }
        report
    }
}
// Delegates to `new()`, which installs the default rule set.
impl Default for APIValidationFramework {
    fn default() -> Self {
        Self::new()
    }
}