use crate::error::StatsResult;
use scirs2_core::ndarray::{Array1, ArrayView1};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Runs numerical-stability checks on scalar-valued functions of a 1-D `f64`
/// array and keeps the latest result for each analysed function.
#[derive(Debug)]
pub struct NumericalStabilityAnalyzer {
    // Thresholds and switches that drive every check.
    config: StabilityConfig,
    // Latest analysis result per function, keyed by the name passed to
    // `analyze_function`; analysing the same name again overwrites the entry.
    analysis_results: HashMap<String, StabilityAnalysisResult>,
}
/// Tunable thresholds and switches for [`NumericalStabilityAnalyzer`].
///
/// The stock values are defined by this type's `Default` implementation.
#[derive(Debug, Clone)]
pub struct StabilityConfig {
    /// Magnitude threshold for treating values as (near-)zero; used by the
    /// magnitude-ratio condition estimate and by the cancellation check.
    pub zero_tolerance: f64,
    /// Result magnitudes below this receive the largest precision-loss
    /// estimate in the precision heuristic.
    pub precision_tolerance: f64,
    /// Condition-number threshold. NOTE(review): presumably the largest value
    /// still treated as well-conditioned; confirm the analyzer consults it.
    pub max_condition_number: f64,
    /// Upper bound on how many input elements are perturbed (one at a time)
    /// when estimating error propagation.
    pub perturbation_tests: usize,
    /// Relative size of each perturbation; scaled by `max(|x_i|, 1)`.
    pub perturbation_magnitude: f64,
    /// Whether to run the infinity / NaN / zero / large / small edge-case inputs.
    pub test_extreme_values: bool,
    /// NOTE(review): not read by the analysis code in this file. Presumably it
    /// is reserved for singular-input tests; confirm before relying on it.
    pub test_singular_cases: bool,
}
/// Complete outcome of analysing one function with [`NumericalStabilityAnalyzer`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StabilityAnalysisResult {
    /// Name the caller supplied for the analysed function.
    pub function_name: String,
    /// Coarse grade derived from `stability_score`.
    pub stability_grade: StabilityGrade,
    /// Condition estimate of the test input.
    pub condition_analysis: ConditionNumberAnalysis,
    /// Sensitivity of the output to small input perturbations.
    pub error_propagation: ErrorPropagationAnalysis,
    /// Behaviour on extreme / degenerate inputs.
    pub edge_case_robustness: EdgeCaseRobustness,
    /// Heuristic precision-loss and overflow-risk assessment.
    pub precision_analysis: PrecisionAnalysis,
    /// Suggested improvements triggered by the findings above.
    pub recommendations: Vec<StabilityRecommendation>,
    /// Aggregate score; higher is more stable.
    pub stability_score: f64,
}
/// Coarse stability grade assigned from the numeric stability score
/// (the analyzer uses cut-offs of 90 / 75 / 60 / 40).
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum StabilityGrade {
    /// Highest grade.
    Excellent,
    Good,
    Acceptable,
    Poor,
    /// Lowest grade.
    Unstable,
}
/// Cheap conditioning estimate of an input vector.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConditionNumberAnalysis {
    /// Ratio `max|x| / min|x|`; infinite when `min|x|` does not exceed the
    /// zero tolerance.
    pub condition_number: f64,
    /// Bucket the condition number falls into.
    pub conditioning_class: ConditioningClass,
    /// `log10(condition_number)`, floored at 0: rough count of decimal digits
    /// that may be lost.
    pub accuracy_loss_digits: f64,
    /// `condition_number / 1e16`. NOTE(review): 1e16 appears to approximate the
    /// reciprocal of f64 machine epsilon; confirm the intended scale.
    pub input_sensitivity: f64,
}
/// Conditioning bucket, from best to worst. With the default configuration the
/// boundaries are condition numbers of 1e12, 1e14 and 1e16.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum ConditioningClass {
    WellConditioned,
    ModeratelyConditioned,
    PoorlyConditioned,
    NearlySingular,
}
/// Measured sensitivity of a function to single-element input perturbations.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ErrorPropagationAnalysis {
    /// Largest observed `|f(perturbed) - f(original)|`.
    pub forward_error_bound: f64,
    /// Largest absolute perturbation applied to an input element.
    pub backward_error_bound: f64,
    /// `forward_error_bound / backward_error_bound`, or 1.0 when no
    /// perturbation produced a usable result.
    pub error_amplification: f64,
    /// `1 / (1 + error_amplification)`: 1.0 for no amplification, towards 0
    /// for large amplification.
    pub rounding_error_stability: f64,
}
/// Outcome of feeding extreme / degenerate inputs to a function.
///
/// All flags stay `false` when extreme-value testing is disabled.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EdgeCaseRobustness {
    /// `Ok` with a non-NaN result on an input containing `+inf`.
    pub handles_infinity: bool,
    /// `Ok` with a non-infinite result (NaN or finite) on an input containing NaN.
    pub handles_nan: bool,
    /// `Ok` on an all-zero input.
    pub handles_zero: bool,
    /// `Ok` on `[1e100, 1e200, 1e300]`.
    pub handles_large_values: bool,
    /// `Ok` on `[1e-100, 1e-200, 1e-300]`.
    pub handles_small_values: bool,
    /// Fraction of edge-case tests passed; 1.0 when none were run.
    pub edge_case_success_rate: f64,
}
/// Heuristic precision-loss and range-risk assessment.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PrecisionAnalysis {
    /// Estimated bits lost. This is a magnitude-based heuristic on the result,
    /// not a measurement.
    pub precision_loss_bits: f64,
    /// `1 - precision_loss_bits / 64`.
    pub relative_precision: f64,
    /// Potential cancellation sites; populated when an input is near zero.
    pub cancellation_errors: Vec<CancellationError>,
    /// Risk level derived from the largest input magnitude.
    pub overflow_underflow_risk: OverflowRisk,
}
/// A location where catastrophic cancellation may occur.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CancellationError {
    /// Free-form description of where the risk was detected.
    pub location: String,
    /// Estimated precision loss (same units as `PrecisionAnalysis::precision_loss_bits`).
    pub precision_loss: f64,
    /// Human-readable mitigation advice.
    pub mitigation: String,
}
/// Overflow / underflow risk level, from lowest to highest.
///
/// NOTE(review): the analyzer in this file only produces `Low`, `Moderate` and
/// `High`; `None` and `Certain` are never produced here.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum OverflowRisk {
    None,
    Low,
    Moderate,
    High,
    Certain,
}
/// One actionable suggestion produced by the analyzer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StabilityRecommendation {
    /// Area the recommendation addresses.
    pub recommendation_type: RecommendationType,
    /// What problem was detected.
    pub description: String,
    /// What to do about it.
    pub suggestion: String,
    /// Urgency of the recommendation.
    pub priority: RecommendationPriority,
    /// Expected benefit. NOTE(review): the analyzer sets values between 20 and
    /// 35; presumably score points or percent. Confirm the unit.
    pub expected_improvement: f64,
}
/// Category of a [`StabilityRecommendation`].
///
/// `ErrorHandling` is not produced by the analyzer in this file.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum RecommendationType {
    Algorithm,
    Numerical,
    InputValidation,
    Precision,
    ErrorHandling,
}
/// Urgency of a recommendation.
///
/// The derived ordering follows declaration order, so `Critical < High <
/// Medium < Low`: an ascending sort puts the most urgent items first. Do not
/// reorder the variants.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord)]
pub enum RecommendationPriority {
    Critical,
    High,
    Medium,
    Low,
}
impl Default for StabilityConfig {
fn default() -> Self {
Self {
zero_tolerance: 1e-15,
precision_tolerance: 1e-12,
max_condition_number: 1e12,
perturbation_tests: 100,
perturbation_magnitude: 1e-10,
test_extreme_values: true,
test_singular_cases: true,
}
}
}
impl NumericalStabilityAnalyzer {
    /// Creates an analyzer that uses `config` and has no recorded results yet.
    pub fn new(config: StabilityConfig) -> Self {
        Self {
            config,
            analysis_results: HashMap::new(),
        }
    }

    /// Runs every stability check on `function` evaluated at `testdata`.
    ///
    /// The checks are:
    /// - a magnitude-ratio condition estimate of `testdata`;
    /// - single-element perturbation tests;
    /// - edge-case inputs (infinity, NaN, zeros, very large and very small values);
    /// - a heuristic precision-loss estimate.
    ///
    /// The combined result is stored under `function_name`, replacing any
    /// earlier result with that name, and is also returned.
    ///
    /// # Errors
    ///
    /// Propagates the error if `function(testdata)` itself fails. Failures on
    /// perturbed or edge-case inputs are not errors; they lower the score.
    pub fn analyze_function<F>(
        &mut self,
        function_name: &str,
        function: F,
        testdata: &ArrayView1<f64>,
    ) -> StatsResult<StabilityAnalysisResult>
    where
        F: Fn(&ArrayView1<f64>) -> StatsResult<f64>,
    {
        let condition_analysis = self.analyze_condition_number(testdata)?;
        let error_propagation = self.analyze_error_propagation(&function, testdata)?;
        let edge_case_robustness = self.test_edge_case_robustness(&function)?;
        let precision_analysis = self.analyze_precision_loss(&function, testdata)?;
        let recommendations = self.generate_recommendations(
            &condition_analysis,
            &error_propagation,
            &edge_case_robustness,
            &precision_analysis,
        );
        let stability_score = self.calculate_stability_score(
            &condition_analysis,
            &error_propagation,
            &edge_case_robustness,
            &precision_analysis,
        );
        let stability_grade = self.grade_stability(stability_score);
        let result = StabilityAnalysisResult {
            function_name: function_name.to_string(),
            stability_grade,
            condition_analysis,
            error_propagation,
            edge_case_robustness,
            precision_analysis,
            recommendations,
            stability_score,
        };
        self.analysis_results
            .insert(function_name.to_string(), result.clone());
        Ok(result)
    }

    /// Estimates conditioning as the ratio of the largest to the smallest
    /// input magnitude.
    ///
    /// The class boundaries are `max_condition_number`, 100× that value and
    /// 10 000× that value. With the default configuration they are exactly
    /// 1e12, 1e14 and 1e16.
    fn analyze_condition_number(
        &self,
        data: &ArrayView1<f64>,
    ) -> StatsResult<ConditionNumberAnalysis> {
        // `f64::max` / `f64::min` ignore NaN operands, so NaN inputs are skipped.
        let max_abs = data.iter().fold(0.0f64, |acc, &x| acc.max(x.abs()));
        let min_abs = data.iter().fold(f64::INFINITY, |acc, &x| acc.min(x.abs()));
        let condition_number = if min_abs > self.config.zero_tolerance {
            max_abs / min_abs
        } else {
            // A (near-)zero magnitude makes the ratio unbounded.
            f64::INFINITY
        };
        // Previously these thresholds were hard-coded and the config field was
        // never read. Deriving them from the config keeps the defaults
        // identical and makes the threshold actually configurable.
        let well_conditioned_limit = self.config.max_condition_number;
        let conditioning_class = if condition_number < well_conditioned_limit {
            ConditioningClass::WellConditioned
        } else if condition_number < well_conditioned_limit * 1e2 {
            ConditioningClass::ModeratelyConditioned
        } else if condition_number < well_conditioned_limit * 1e4 {
            ConditioningClass::PoorlyConditioned
        } else {
            ConditioningClass::NearlySingular
        };
        // Decimal digits potentially lost; floored at 0 (also absorbs log10(0) = -inf).
        let accuracy_loss_digits = condition_number.log10().max(0.0);
        // Fixed scale of 1e16, roughly the reciprocal of f64 machine epsilon.
        let input_sensitivity = condition_number / 1e16;
        Ok(ConditionNumberAnalysis {
            condition_number,
            conditioning_class,
            accuracy_loss_digits,
            input_sensitivity,
        })
    }

    /// Perturbs each of the first `min(perturbation_tests, data.len())`
    /// elements in turn and records how far the output moves.
    ///
    /// Each perturbation is `perturbation_magnitude * max(|x_i|, 1)`.
    /// Perturbed evaluations that return `Err` are skipped.
    ///
    /// # Errors
    ///
    /// Propagates the error from the unperturbed evaluation `function(data)`.
    fn analyze_error_propagation<F>(
        &self,
        function: &F,
        data: &ArrayView1<f64>,
    ) -> StatsResult<ErrorPropagationAnalysis>
    where
        F: Fn(&ArrayView1<f64>) -> StatsResult<f64>,
    {
        let reference_result = function(data)?;
        // Reuse one working copy: perturb an element, evaluate, then restore it
        // exactly. This avoids allocating a fresh array for every test.
        let mut perturbed = data.to_owned();
        let mut forward_error_bound = 0.0f64;
        let mut backward_error_bound = 0.0f64;
        for i in 0..self.config.perturbation_tests.min(data.len()) {
            let original = perturbed[i];
            let perturbation = self.config.perturbation_magnitude * original.abs().max(1.0);
            perturbed[i] = original + perturbation;
            if let Ok(perturbed_result) = function(&perturbed.view()) {
                forward_error_bound =
                    forward_error_bound.max((perturbed_result - reference_result).abs());
                backward_error_bound = backward_error_bound.max(perturbation.abs());
            }
            perturbed[i] = original;
        }
        let error_amplification = if backward_error_bound > 0.0 {
            forward_error_bound / backward_error_bound
        } else {
            // No usable perturbation: assume neutral amplification.
            1.0
        };
        let rounding_error_stability = 1.0 / (1.0 + error_amplification);
        Ok(ErrorPropagationAnalysis {
            forward_error_bound,
            backward_error_bound,
            error_amplification,
            rounding_error_stability,
        })
    }

    /// Feeds fixed extreme inputs to `function`; runs only when
    /// `test_extreme_values` is enabled.
    ///
    /// A test passes when the call returns `Ok`. For the infinity input the
    /// result must also not be NaN; for the NaN input the result must also not
    /// be infinite. Never returns `Err`.
    fn test_edge_case_robustness<F>(&self, function: &F) -> StatsResult<EdgeCaseRobustness>
    where
        F: Fn(&ArrayView1<f64>) -> StatsResult<f64>,
    {
        let mut handles_infinity = false;
        let mut handles_nan = false;
        let mut handles_zero = false;
        let mut handles_large_values = false;
        let mut handles_small_values = false;
        let mut tests_passed = 0usize;
        let mut total_tests = 0usize;
        if self.config.test_extreme_values {
            let eval = |values: Vec<f64>| function(&Array1::from_vec(values).view());
            // Evaluation order is preserved in case `function` has side effects.
            handles_infinity =
                matches!(eval(vec![f64::INFINITY, 1.0, 2.0]), Ok(r) if !r.is_nan());
            handles_nan = matches!(eval(vec![f64::NAN, 1.0, 2.0]), Ok(r) if !r.is_infinite());
            handles_zero = eval(vec![0.0, 0.0, 0.0]).is_ok();
            handles_large_values = eval(vec![1e100, 1e200, 1e300]).is_ok();
            handles_small_values = eval(vec![1e-100, 1e-200, 1e-300]).is_ok();
            let outcomes = [
                handles_infinity,
                handles_nan,
                handles_zero,
                handles_large_values,
                handles_small_values,
            ];
            total_tests = outcomes.len();
            tests_passed = outcomes.iter().filter(|&&passed| passed).count();
        }
        let edge_case_success_rate = if total_tests > 0 {
            tests_passed as f64 / total_tests as f64
        } else {
            1.0
        };
        Ok(EdgeCaseRobustness {
            handles_infinity,
            handles_nan,
            handles_zero,
            handles_large_values,
            handles_small_values,
            edge_case_success_rate,
        })
    }

    /// Heuristic precision assessment based on the result magnitude and the
    /// input range.
    ///
    /// The loss estimate is 16 bits for results below `precision_tolerance`,
    /// 8 bits below 1e-10, and 2 bits otherwise. NOTE: if
    /// `precision_tolerance >= 1e-10`, the 8-bit branch is unreachable.
    ///
    /// # Errors
    ///
    /// Propagates the error from `function(data)`.
    fn analyze_precision_loss<F>(
        &self,
        function: &F,
        data: &ArrayView1<f64>,
    ) -> StatsResult<PrecisionAnalysis>
    where
        F: Fn(&ArrayView1<f64>) -> StatsResult<f64>,
    {
        let result = function(data)?;
        let magnitude = result.abs();
        let precision_loss_bits = if magnitude < self.config.precision_tolerance {
            16.0
        } else if magnitude < 1e-10 {
            8.0
        } else {
            2.0
        };
        let relative_precision = 1.0 - (precision_loss_bits / 64.0);
        let mut cancellation_errors = Vec::new();
        // Near-zero inputs are treated as a potential cancellation hazard.
        if data.iter().any(|&x| x.abs() < self.config.zero_tolerance) {
            cancellation_errors.push(CancellationError {
                location: "inputdata".to_string(),
                precision_loss: precision_loss_bits,
                mitigation: "Use higher precision arithmetic or alternative algorithm".to_string(),
            });
        }
        let max_val = data.iter().fold(0.0f64, |acc, &x| acc.max(x.abs()));
        let overflow_underflow_risk = if max_val > 1e100 {
            OverflowRisk::High
        } else if max_val > 1e50 || max_val < 1e-100 {
            // Very large values risk overflow; all-tiny values risk underflow.
            OverflowRisk::Moderate
        } else {
            OverflowRisk::Low
        };
        Ok(PrecisionAnalysis {
            precision_loss_bits,
            relative_precision,
            cancellation_errors,
            overflow_underflow_risk,
        })
    }

    /// Turns the individual findings into recommendations.
    ///
    /// A recommendation is emitted for each of:
    /// - poor or near-singular conditioning;
    /// - error amplification above 100;
    /// - an edge-case success rate below 80%;
    /// - estimated precision loss above 10 bits.
    fn generate_recommendations(
        &self,
        condition_analysis: &ConditionNumberAnalysis,
        error_propagation: &ErrorPropagationAnalysis,
        edge_case_robustness: &EdgeCaseRobustness,
        precision_analysis: &PrecisionAnalysis,
    ) -> Vec<StabilityRecommendation> {
        let mut recommendations = Vec::new();
        if matches!(
            condition_analysis.conditioning_class,
            ConditioningClass::PoorlyConditioned | ConditioningClass::NearlySingular
        ) {
            recommendations.push(StabilityRecommendation {
                recommendation_type: RecommendationType::Algorithm,
                description: "Poor conditioning detected".to_string(),
                suggestion: "Consider using regularization or alternative algorithms for ill-conditioned problems".to_string(),
                priority: RecommendationPriority::High,
                expected_improvement: 30.0,
            });
        }
        if error_propagation.error_amplification > 100.0 {
            recommendations.push(StabilityRecommendation {
                recommendation_type: RecommendationType::Numerical,
                description: "High error amplification detected".to_string(),
                suggestion: "Implement error analysis and use more stable numerical methods"
                    .to_string(),
                priority: RecommendationPriority::High,
                expected_improvement: 25.0,
            });
        }
        if edge_case_robustness.edge_case_success_rate < 0.8 {
            recommendations.push(StabilityRecommendation {
                recommendation_type: RecommendationType::InputValidation,
                description: "Poor edge case handling".to_string(),
                suggestion:
                    "Improve input validation and add special case handling for extreme values"
                        .to_string(),
                priority: RecommendationPriority::Medium,
                expected_improvement: 20.0,
            });
        }
        if precision_analysis.precision_loss_bits > 10.0 {
            recommendations.push(StabilityRecommendation {
                recommendation_type: RecommendationType::Precision,
                description: "Significant precision loss detected".to_string(),
                suggestion:
                    "Consider using higher precision arithmetic or numerically stable algorithms"
                        .to_string(),
                priority: RecommendationPriority::High,
                expected_improvement: 35.0,
            });
        }
        recommendations
    }

    /// Combines the findings into a score in `[0, 100]`; higher is better.
    ///
    /// Penalties:
    /// - conditioning class: up to 40;
    /// - error amplification: `5 * log10(amplification)`, within `[0, 30]`;
    /// - edge-case failures: up to 20;
    /// - precision loss: `30 * bits / 64`.
    fn calculate_stability_score(
        &self,
        condition_analysis: &ConditionNumberAnalysis,
        error_propagation: &ErrorPropagationAnalysis,
        edge_case_robustness: &EdgeCaseRobustness,
        precision_analysis: &PrecisionAnalysis,
    ) -> f64 {
        let mut score = 100.0;
        score -= match condition_analysis.conditioning_class {
            ConditioningClass::WellConditioned => 0.0,
            ConditioningClass::ModeratelyConditioned => 10.0,
            ConditioningClass::PoorlyConditioned => 25.0,
            ConditioningClass::NearlySingular => 40.0,
        };
        // Amplification < 1 gives a negative log10, and amplification == 0
        // (e.g. a constant function) gives -inf. Neither may *raise* the score,
        // so floor the penalty at 0. `max` before `min` also maps NaN to 0.
        let amplification_penalty = (error_propagation.error_amplification.log10() * 5.0)
            .max(0.0)
            .min(30.0);
        score -= amplification_penalty;
        score -= (1.0 - edge_case_robustness.edge_case_success_rate) * 20.0;
        score -= (precision_analysis.precision_loss_bits / 64.0) * 30.0;
        score.max(0.0).min(100.0)
    }

    /// Maps a score to a grade using cut-offs of 90 / 75 / 60 / 40.
    fn grade_stability(&self, score: f64) -> StabilityGrade {
        if score >= 90.0 {
            StabilityGrade::Excellent
        } else if score >= 75.0 {
            StabilityGrade::Good
        } else if score >= 60.0 {
            StabilityGrade::Acceptable
        } else if score >= 40.0 {
            StabilityGrade::Poor
        } else {
            StabilityGrade::Unstable
        }
    }

    /// Summarises every stored analysis result.
    ///
    /// `function_results` is sorted by function name so the report is
    /// reproducible; `average_score` is 0.0 when nothing has been analysed.
    pub fn generate_stability_report(&self) -> StabilityReport {
        let mut results: Vec<StabilityAnalysisResult> =
            self.analysis_results.values().cloned().collect();
        // HashMap iteration order is unspecified; sort for deterministic output.
        results.sort_by(|a, b| a.function_name.cmp(&b.function_name));
        let total_functions = results.len();
        let count_grade =
            |grade: StabilityGrade| results.iter().filter(|r| r.stability_grade == grade).count();
        let excellent_count = count_grade(StabilityGrade::Excellent);
        let good_count = count_grade(StabilityGrade::Good);
        let acceptable_count = count_grade(StabilityGrade::Acceptable);
        let poor_count = count_grade(StabilityGrade::Poor);
        let unstable_count = count_grade(StabilityGrade::Unstable);
        let average_score = if total_functions > 0 {
            results.iter().map(|r| r.stability_score).sum::<f64>() / total_functions as f64
        } else {
            0.0
        };
        StabilityReport {
            total_functions,
            excellent_count,
            good_count,
            acceptable_count,
            poor_count,
            unstable_count,
            average_score,
            function_results: results,
            generated_at: chrono::Utc::now(),
        }
    }
}

impl Default for NumericalStabilityAnalyzer {
    /// Equivalent to `NumericalStabilityAnalyzer::new(StabilityConfig::default())`.
    fn default() -> Self {
        Self::new(StabilityConfig::default())
    }
}
/// Aggregate summary over every function analysed by a [`NumericalStabilityAnalyzer`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StabilityReport {
    /// Number of distinct function names with a stored result.
    pub total_functions: usize,
    /// Count of results graded `Excellent`.
    pub excellent_count: usize,
    /// Count of results graded `Good`.
    pub good_count: usize,
    /// Count of results graded `Acceptable`.
    pub acceptable_count: usize,
    /// Count of results graded `Poor`.
    pub poor_count: usize,
    /// Count of results graded `Unstable`.
    pub unstable_count: usize,
    /// Mean `stability_score` across results; 0.0 when there are none.
    pub average_score: f64,
    /// Clones of every stored analysis result.
    pub function_results: Vec<StabilityAnalysisResult>,
    /// UTC timestamp taken when the report was built.
    pub generated_at: chrono::DateTime<chrono::Utc>,
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::descriptive::mean;

    /// The default analyzer carries the default tolerances.
    #[test]
    fn test_stability_analyzer_creation() {
        let default_analyzer = NumericalStabilityAnalyzer::default();
        let cfg = &default_analyzer.config;
        assert_eq!(cfg.zero_tolerance, 1e-15);
        assert_eq!(cfg.precision_tolerance, 1e-12);
    }

    /// A small, strictly positive sample is well conditioned.
    #[test]
    fn test_condition_number_analysis() {
        let default_analyzer = NumericalStabilityAnalyzer::default();
        let sample = Array1::from(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let analysis = default_analyzer
            .analyze_condition_number(&sample.view())
            .expect("Operation failed");
        assert_eq!(
            analysis.conditioning_class,
            ConditioningClass::WellConditioned
        );
        assert!(analysis.condition_number > 0.0);
    }

    /// Each score band maps to the expected grade.
    #[test]
    fn test_stability_grading() {
        let default_analyzer = NumericalStabilityAnalyzer::default();
        let expectations = [
            (95.0, StabilityGrade::Excellent),
            (80.0, StabilityGrade::Good),
            (65.0, StabilityGrade::Acceptable),
            (45.0, StabilityGrade::Poor),
            (20.0, StabilityGrade::Unstable),
        ];
        for &(score, expected) in &expectations {
            assert_eq!(default_analyzer.grade_stability(score), expected);
        }
    }

    /// The arithmetic mean of a benign sample should grade well.
    #[test]
    fn test_mean_stability_analysis() {
        let mut stateful_analyzer = NumericalStabilityAnalyzer::default();
        let sample = Array1::from(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let outcome = stateful_analyzer
            .analyze_function("mean", |x| mean(x), &sample.view())
            .expect("Operation failed");
        assert_eq!(outcome.function_name, "mean");
        let graded_well = matches!(
            outcome.stability_grade,
            StabilityGrade::Excellent | StabilityGrade::Good
        );
        assert!(graded_well);
        assert!(outcome.stability_score > 50.0);
    }
}