use super::types::*;
use crate::error::InterpolateResult;
use scirs2_core::numeric::Float;
use std::collections::{HashMap, VecDeque};
use std::fmt::Debug;
use std::time::Instant;
/// Engine that adjusts interpolation parameters toward a set of accuracy targets.
///
/// It combines an optimization strategy, the targets to check against, an error
/// prediction model, and a bounded history of past optimization results.
#[derive(Debug)]
pub struct AccuracyOptimizationEngine<F: Float + Debug> {
    /// Strategy applied when the predicted accuracy misses the targets.
    strategy: AccuracyOptimizationStrategy,
    /// Thresholds the predicted error is checked against before optimizing.
    targets: AccuracyTargets<F>,
    /// Model used to predict the error of a method on a given data profile.
    error_predictor: ErrorPredictionModel<F>,
    /// Most recent optimization results (the engine keeps at most 100, dropping the oldest).
    optimization_history: VecDeque<AccuracyOptimizationResult>,
}
/// How the engine trades accuracy against computational cost when the current
/// parameters are predicted to miss the accuracy targets.
#[derive(Debug, Clone)]
pub enum AccuracyOptimizationStrategy {
    /// Push parameters toward the most accurate settings (e.g. less smoothing,
    /// higher B-spline degree, tolerance 1e-12 for generic methods).
    MaximizeAccuracy,
    /// Pick settings from the data's noise level and smoothness
    /// (the strategy installed by `AccuracyOptimizationEngine::new`).
    BalancedAccuracy,
    /// Favor speed: loose tolerance (1e-4) and few iterations (50).
    MinimumAccuracy,
    /// Choose one of the strategies above from the data's noise level and size.
    Adaptive,
    /// Blend the `MaximizeAccuracy` and `MinimumAccuracy` parameter sets using
    /// these weights.
    Custom {
        /// Weight given to the accuracy-maximizing parameter set.
        accuracy_weight: f64,
        /// Weight given to the performance-oriented parameter set.
        performance_weight: f64,
    },
}
/// Accuracy thresholds that a predicted interpolation error must satisfy.
#[derive(Debug, Clone)]
pub struct AccuracyTargets<F: Float> {
    /// Optional upper bound on the predicted error.
    pub target_absolute_error: Option<F>,
    /// Optional relative-error bound.
    ///
    /// NOTE(review): the engine compares this directly against the predicted
    /// (absolute) error, so it currently acts as a second absolute bound.
    pub target_relative_error: Option<F>,
    /// Hard upper bound on the predicted error (default `1e-6`).
    pub max_acceptable_error: F,
    /// Desired confidence level (default `0.95`).
    ///
    /// NOTE(review): not read by the optimization logic in this module.
    pub confidence_level: F,
}
/// Heuristic model that predicts interpolation error and tracks how well its
/// past predictions matched observed errors.
#[derive(Debug)]
pub struct ErrorPredictionModel<F: Float> {
    /// Named model parameters.
    ///
    /// NOTE(review): initialized empty and not read anywhere in this module.
    prediction_params: HashMap<String, F>,
    /// Observed prediction/actual error pairs (the model keeps at most 1000).
    error_history: VecDeque<ErrorRecord<F>>,
    /// Self-assessed model accuracy. It starts at 0.8 and is re-estimated from
    /// recent records, with a floor of 0.1.
    model_accuracy: F,
}
/// One observation pairing a predicted error with the error actually measured.
#[derive(Debug, Clone)]
pub struct ErrorRecord<F: Float> {
    /// Error the model predicted.
    pub predicted_error: F,
    /// Error that was actually observed.
    pub actual_error: F,
    /// Free-form description of the data; the engine writes `"size:<n>,dim:<d>"`.
    pub data_characteristics: String,
    /// Interpolation method the observation refers to.
    pub method: InterpolationMethodType,
    /// When the record was created.
    pub timestamp: Instant,
}
/// Outcome of a single `AccuracyOptimizationEngine::optimize_accuracy` call.
#[derive(Debug, Clone)]
pub struct AccuracyOptimizationResult {
    /// Method whose parameters were optimized.
    pub method: InterpolationMethodType,
    /// Parameters to use (the input parameters unchanged if the targets were already met).
    pub adjusted_parameters: HashMap<String, f64>,
    /// Change in predicted accuracy reported by the optimizer; `0.0` when the
    /// input parameters already met the targets.
    pub accuracy_improvement: f64,
    /// Estimated relative computational cost of the adjusted parameters;
    /// `0.0` when no adjustment was made.
    pub performance_impact: f64,
    /// Whether the optimization was considered successful (always `true` when
    /// the targets were already met).
    pub success: bool,
    /// Time at which the optimization call started.
    pub timestamp: Instant,
}
impl<F: Float + Debug + std::ops::AddAssign> AccuracyOptimizationEngine<F> {
    /// Creates an engine with the `BalancedAccuracy` strategy, default targets,
    /// a fresh error model and an empty history.
    ///
    /// # Errors
    ///
    /// Propagates any error from constructing the error prediction model.
    pub fn new() -> InterpolateResult<Self> {
        Ok(Self {
            strategy: AccuracyOptimizationStrategy::BalancedAccuracy,
            targets: AccuracyTargets::default(),
            error_predictor: ErrorPredictionModel::new()?,
            optimization_history: VecDeque::new(),
        })
    }

    /// Replaces the strategy used by [`Self::optimize_accuracy`].
    pub fn set_strategy(&mut self, strategy: AccuracyOptimizationStrategy) {
        self.strategy = strategy;
    }

    /// Replaces the accuracy targets checked by [`Self::optimize_accuracy`].
    pub fn set_targets(&mut self, targets: AccuracyTargets<F>) {
        self.targets = targets;
    }

    /// Adjusts `current_parameters` so that `method` is predicted to meet the
    /// accuracy targets on data described by `data_profile`.
    ///
    /// If the current parameters are already predicted to meet the targets, they
    /// are returned unchanged with `success: true` and nothing is recorded.
    /// Otherwise the configured strategy proposes new parameters. The result then
    /// reports:
    ///
    /// * `accuracy_improvement` — the reduction in predicted error (positive means
    ///   the new parameters are predicted to be more accurate);
    /// * `performance_impact` — the estimated relative cost of the change;
    /// * `success` — whether the predicted error actually decreased.
    ///
    /// Strategy-produced results are appended to the history, which keeps only
    /// the 100 most recent entries.
    ///
    /// # Errors
    ///
    /// Propagates errors from accuracy prediction and from the strategy helpers.
    pub fn optimize_accuracy(
        &mut self,
        method: InterpolationMethodType,
        data_profile: &DataProfile<F>,
        current_parameters: &HashMap<String, f64>,
    ) -> InterpolateResult<AccuracyOptimizationResult> {
        let start_time = Instant::now();
        let baseline = self.predict_accuracy(method, data_profile, current_parameters)?;
        if self.meets_accuracy_targets(&baseline)? {
            return Ok(AccuracyOptimizationResult {
                method,
                adjusted_parameters: current_parameters.clone(),
                accuracy_improvement: 0.0,
                performance_impact: 0.0,
                success: true,
                timestamp: start_time,
            });
        }

        let optimized_params = match &self.strategy {
            AccuracyOptimizationStrategy::MaximizeAccuracy => {
                self.maximize_accuracy_optimization(method, data_profile, current_parameters)?
            }
            AccuracyOptimizationStrategy::BalancedAccuracy => {
                self.balanced_optimization(method, data_profile, current_parameters)?
            }
            AccuracyOptimizationStrategy::MinimumAccuracy => {
                self.minimum_accuracy_optimization(method, data_profile, current_parameters)?
            }
            AccuracyOptimizationStrategy::Adaptive => {
                self.adaptive_optimization(method, data_profile, current_parameters)?
            }
            AccuracyOptimizationStrategy::Custom {
                accuracy_weight,
                performance_weight,
            } => self.custom_weighted_optimization(
                method,
                data_profile,
                current_parameters,
                *accuracy_weight,
                *performance_weight,
            )?,
        };

        let optimized = self.predict_accuracy(method, data_profile, &optimized_params)?;
        // `predicted_accuracy` carries a predicted *error* (lower is better; it is
        // compared against error thresholds in `meets_accuracy_targets`), so an
        // improvement is the amount by which the error went down.
        let baseline_error = baseline.predicted_accuracy.to_f64().unwrap_or(0.0);
        let optimized_error = optimized.predicted_accuracy.to_f64().unwrap_or(0.0);
        let accuracy_improvement = baseline_error - optimized_error;
        let performance_impact =
            self.estimate_performance_impact(&optimized_params, current_parameters);

        let result = AccuracyOptimizationResult {
            method,
            adjusted_parameters: optimized_params,
            accuracy_improvement,
            performance_impact,
            success: accuracy_improvement > 0.0,
            timestamp: start_time,
        };
        self.optimization_history.push_back(result.clone());
        // Bound the history so long-running engines do not grow without limit.
        if self.optimization_history.len() > 100 {
            self.optimization_history.pop_front();
        }
        Ok(result)
    }

    /// Predicts the error of `method` on `data_profile` with `parameters` by
    /// delegating to the internal error prediction model.
    ///
    /// # Errors
    ///
    /// Propagates errors from the error prediction model.
    pub fn predict_accuracy(
        &self,
        method: InterpolationMethodType,
        data_profile: &DataProfile<F>,
        parameters: &HashMap<String, f64>,
    ) -> InterpolateResult<AccuracyPrediction<F>> {
        self.error_predictor
            .predict_accuracy(method, data_profile, parameters)
    }

    /// Records an observed error for `method` and refreshes the error model.
    ///
    /// # Errors
    ///
    /// Propagates errors from recording the observation or updating the model.
    pub fn update_error_model(
        &mut self,
        method: InterpolationMethodType,
        data_profile: &DataProfile<F>,
        predicted_error: F,
        actual_error: F,
    ) -> InterpolateResult<()> {
        let error_record = ErrorRecord {
            predicted_error,
            actual_error,
            data_characteristics: format!(
                "size:{},dim:{}",
                data_profile.size, data_profile.dimensionality
            ),
            method,
            timestamp: Instant::now(),
        };
        self.error_predictor.add_error_record(error_record)?;
        self.error_predictor.update_model()?;
        Ok(())
    }

    /// Returns the most recent optimization results, oldest first.
    pub fn get_optimization_history(&self) -> &VecDeque<AccuracyOptimizationResult> {
        &self.optimization_history
    }

    /// Returns the accuracy targets currently in effect.
    pub fn get_targets(&self) -> &AccuracyTargets<F> {
        &self.targets
    }

    /// Returns `true` when the predicted error exceeds none of the configured
    /// limits.
    ///
    /// NOTE(review): `target_relative_error` is compared against the absolute
    /// predicted error; a true relative check would need a reference magnitude.
    fn meets_accuracy_targets(
        &self,
        prediction: &AccuracyPrediction<F>,
    ) -> InterpolateResult<bool> {
        let predicted_error = prediction.predicted_accuracy;
        if predicted_error > self.targets.max_acceptable_error {
            return Ok(false);
        }
        if let Some(target_abs) = self.targets.target_absolute_error {
            if predicted_error > target_abs {
                return Ok(false);
            }
        }
        if let Some(target_rel) = self.targets.target_relative_error {
            if predicted_error > target_rel {
                return Ok(false);
            }
        }
        Ok(true)
    }

    /// Pushes parameters toward maximum accuracy for `method`.
    ///
    /// It reduces spline smoothing tenfold, raises the B-spline degree (capped at
    /// 5) and re-derives the RBF shape parameter. Only parameters already present
    /// are changed for those three methods. Other methods get a tight tolerance
    /// and a large iteration budget.
    fn maximize_accuracy_optimization(
        &self,
        method: InterpolationMethodType,
        data_profile: &DataProfile<F>,
        current_parameters: &HashMap<String, f64>,
    ) -> InterpolateResult<HashMap<String, f64>> {
        let mut optimized = current_parameters.clone();
        match method {
            InterpolationMethodType::CubicSpline => {
                if let Some(smoothing) = optimized.get_mut("smoothing") {
                    *smoothing *= 0.1;
                }
            }
            InterpolationMethodType::BSpline => {
                if let Some(degree) = optimized.get_mut("degree") {
                    *degree = (*degree + 1.0).min(5.0);
                }
            }
            InterpolationMethodType::RadialBasisFunction => {
                if let Some(shape) = optimized.get_mut("shape_parameter") {
                    *shape = self.optimize_rbf_shape_parameter(data_profile);
                }
            }
            _ => {
                optimized.insert("tolerance".to_string(), 1e-12);
                optimized.insert("max_iterations".to_string(), 1000.0);
            }
        }
        Ok(optimized)
    }

    /// Chooses moderate settings from the data's noise level and smoothness.
    ///
    /// Noisy data gets more spline smoothing, and only very smooth data gets a
    /// cubic B-spline.
    fn balanced_optimization(
        &self,
        method: InterpolationMethodType,
        data_profile: &DataProfile<F>,
        current_parameters: &HashMap<String, f64>,
    ) -> InterpolateResult<HashMap<String, f64>> {
        let mut optimized = current_parameters.clone();
        let noise_level = data_profile.noise_level.to_f64().unwrap_or(0.1);
        let smoothness = data_profile.smoothness.to_f64().unwrap_or(0.5);
        match method {
            InterpolationMethodType::CubicSpline => {
                let smoothing_factor = if noise_level > 0.1 {
                    noise_level * 0.5
                } else {
                    0.01
                };
                optimized.insert("smoothing".to_string(), smoothing_factor);
            }
            InterpolationMethodType::BSpline => {
                let degree = if smoothness > 0.8 { 3.0 } else { 2.0 };
                optimized.insert("degree".to_string(), degree);
            }
            _ => {
                optimized.insert("tolerance".to_string(), 1e-8);
                optimized.insert("max_iterations".to_string(), 100.0);
            }
        }
        Ok(optimized)
    }

    /// Favors speed: loose tolerance and a small iteration budget, whatever the method.
    fn minimum_accuracy_optimization(
        &self,
        _method: InterpolationMethodType,
        _data_profile: &DataProfile<F>,
        current_parameters: &HashMap<String, f64>,
    ) -> InterpolateResult<HashMap<String, f64>> {
        let mut optimized = current_parameters.clone();
        optimized.insert("tolerance".to_string(), 1e-4);
        optimized.insert("max_iterations".to_string(), 50.0);
        Ok(optimized)
    }

    /// Chooses a strategy from the data.
    ///
    /// Noisy data (noise > 0.2) uses the balanced strategy, large data sets
    /// (> 10 000 points) use the performance strategy, and everything else
    /// maximizes accuracy.
    fn adaptive_optimization(
        &self,
        method: InterpolationMethodType,
        data_profile: &DataProfile<F>,
        current_parameters: &HashMap<String, f64>,
    ) -> InterpolateResult<HashMap<String, f64>> {
        let noise_level = data_profile.noise_level.to_f64().unwrap_or(0.1);
        let data_size = data_profile.size;
        if noise_level > 0.2 {
            self.balanced_optimization(method, data_profile, current_parameters)
        } else if data_size > 10000 {
            self.minimum_accuracy_optimization(method, data_profile, current_parameters)
        } else {
            self.maximize_accuracy_optimization(method, data_profile, current_parameters)
        }
    }

    /// Blends the accuracy-maximizing and performance-oriented parameter sets.
    ///
    /// Negative or NaN weights count as zero. The remaining weights are
    /// normalized to sum to one, so a parameter that neither strategy changes
    /// keeps its value. If both weights are zero or their sum is not finite, the
    /// two sets are weighted equally. A parameter produced by only one of the
    /// strategies is taken from that strategy as-is, instead of being dropped.
    fn custom_weighted_optimization(
        &self,
        method: InterpolationMethodType,
        data_profile: &DataProfile<F>,
        current_parameters: &HashMap<String, f64>,
        accuracy_weight: f64,
        performance_weight: f64,
    ) -> InterpolateResult<HashMap<String, f64>> {
        let accuracy_params =
            self.maximize_accuracy_optimization(method, data_profile, current_parameters)?;
        let performance_params =
            self.minimum_accuracy_optimization(method, data_profile, current_parameters)?;

        // `f64::max` ignores NaN, so NaN and negative weights both become 0.0.
        let acc_w = accuracy_weight.max(0.0);
        let perf_w = performance_weight.max(0.0);
        let total = acc_w + perf_w;
        let (acc_w, perf_w) = if total.is_finite() && total > 0.0 {
            (acc_w / total, perf_w / total)
        } else {
            (0.5, 0.5)
        };

        let mut optimized =
            HashMap::with_capacity(accuracy_params.len() + performance_params.len());
        for key in accuracy_params.keys().chain(performance_params.keys()) {
            if optimized.contains_key(key) {
                continue;
            }
            let value = match (accuracy_params.get(key), performance_params.get(key)) {
                (Some(&acc), Some(&perf)) => acc_w * acc + perf_w * perf,
                (Some(&only), None) | (None, Some(&only)) => only,
                // Unreachable: every key comes from one of the two maps.
                (None, None) => continue,
            };
            optimized.insert(key.clone(), value);
        }
        Ok(optimized)
    }

    /// Heuristic RBF shape parameter: the inverse of a typical spacing, where the
    /// spacing is the value-range width divided by `sqrt(size)`.
    ///
    /// Falls back to `1.0` when that spacing is zero or not finite (degenerate
    /// range). An empty profile is treated as a single point.
    fn optimize_rbf_shape_parameter(&self, data_profile: &DataProfile<F>) -> f64 {
        let span = (data_profile.value_range.1 - data_profile.value_range.0)
            .to_f64()
            .unwrap_or(1.0)
            .abs();
        let point_count = (data_profile.size as f64).max(1.0);
        let typical_distance = span / point_count.sqrt();
        if typical_distance.is_finite() && typical_distance > 0.0 {
            1.0 / typical_distance
        } else {
            1.0
        }
    }

    /// Estimates the relative computational cost of `optimized_params` compared
    /// with `current_params`. Positive values mean slower.
    ///
    /// Only parameters present in both maps contribute:
    /// * a tighter `tolerance` adds `0.1` per order of magnitude;
    /// * `max_iterations` adds `0.05` per unit of relative change;
    /// * `degree` adds `0.15` per degree.
    ///
    /// The total is clamped to `[-0.5, 2.0]`. Contributions that are undefined
    /// (non-positive baselines, NaN inputs) are ignored, so they cannot turn the
    /// estimate into NaN.
    fn estimate_performance_impact(
        &self,
        optimized_params: &HashMap<String, f64>,
        current_params: &HashMap<String, f64>,
    ) -> f64 {
        let mut impact = 0.0;

        if let (Some(&opt_tol), Some(&cur_tol)) = (
            optimized_params.get("tolerance"),
            current_params.get("tolerance"),
        ) {
            // A zero optimized tolerance gives +inf here, which the clamp caps at 2.0.
            if cur_tol > 0.0 && opt_tol >= 0.0 && opt_tol < cur_tol {
                impact += (cur_tol / opt_tol).log10() * 0.1;
            }
        }

        if let (Some(&opt_iter), Some(&cur_iter)) = (
            optimized_params.get("max_iterations"),
            current_params.get("max_iterations"),
        ) {
            if cur_iter > 0.0 {
                impact += (opt_iter / cur_iter - 1.0) * 0.05;
            }
        }

        if let (Some(&opt_deg), Some(&cur_deg)) =
            (optimized_params.get("degree"), current_params.get("degree"))
        {
            impact += (opt_deg - cur_deg) * 0.15;
        }

        if impact.is_nan() {
            0.0
        } else {
            impact.clamp(-0.5, 2.0)
        }
    }
}
impl<F: Float> Default for AccuracyTargets<F> {
    /// Default targets: only the hard error bound `1e-6` is active, with a
    /// 95% confidence level. No absolute or relative targets are set.
    fn default() -> Self {
        let constant = |value: f64| F::from(value).expect("Failed to convert constant to float");
        Self {
            max_acceptable_error: constant(1e-6),
            confidence_level: constant(0.95),
            target_absolute_error: None,
            target_relative_error: None,
        }
    }
}
impl<F: Float + std::ops::AddAssign> ErrorPredictionModel<F> {
    /// Creates a model with no history and an initial self-assessed accuracy of 0.8.
    ///
    /// # Errors
    ///
    /// Currently infallible; the `Result` keeps the constructor signature uniform.
    pub fn new() -> InterpolateResult<Self> {
        Ok(Self {
            prediction_params: HashMap::new(),
            error_history: VecDeque::new(),
            model_accuracy: F::from(0.8).expect("Failed to convert constant to float"),
        })
    }

    /// Predicts the interpolation error of `method` on data described by `data_profile`.
    ///
    /// The estimate is `1 - base_accuracy(method) + 0.5 * noise - size_bonus`,
    /// where the 0.05 bonus applies above 1000 points. It is floored at `1e-12`,
    /// and the returned confidence interval is ±20% of the floored value. The
    /// interval is built from the floored value so it stays positive and ordered
    /// even when the raw estimate is negative. `_parameters` is currently ignored.
    ///
    /// # Errors
    ///
    /// Currently infallible.
    pub fn predict_accuracy(
        &self,
        method: InterpolationMethodType,
        data_profile: &DataProfile<F>,
        _parameters: &HashMap<String, f64>,
    ) -> InterpolateResult<AccuracyPrediction<F>> {
        let base_accuracy = self.get_base_accuracy(method);
        let noise_penalty = data_profile.noise_level.to_f64().unwrap_or(0.1) * 0.5;
        let size_bonus = if data_profile.size > 1000 { 0.05 } else { 0.0 };
        let raw_error = F::from(1.0 - base_accuracy + noise_penalty - size_bonus)
            .expect("Failed to convert to float");
        // Very accurate methods on large, clean data give a negative raw estimate
        // (e.g. Kriging: 1 - 0.98 - 0.05); floor it to a tiny positive error.
        let predicted_error =
            raw_error.max(F::from(1e-12).expect("Failed to convert constant to float"));
        Ok(AccuracyPrediction {
            predicted_accuracy: predicted_error,
            confidence_interval: (
                predicted_error * F::from(0.8).expect("Failed to convert constant to float"),
                predicted_error * F::from(1.2).expect("Failed to convert constant to float"),
            ),
            prediction_confidence: self.model_accuracy,
            accuracy_factors: vec![
                AccuracyFactor {
                    name: "Method capability".to_string(),
                    impact: F::from(base_accuracy - 0.5).expect("Failed to convert to float"),
                    confidence: F::from(0.9).expect("Failed to convert constant to float"),
                    mitigations: vec!["Consider higher-order methods".to_string()],
                },
                AccuracyFactor {
                    name: "Data noise level".to_string(),
                    impact: F::from(-noise_penalty).expect("Failed to convert to float"),
                    confidence: F::from(0.8).expect("Failed to convert constant to float"),
                    mitigations: vec![
                        "Apply data smoothing".to_string(),
                        "Use robust methods".to_string(),
                    ],
                },
            ],
        })
    }

    /// Appends an observation, keeping only the 1000 most recent records.
    ///
    /// # Errors
    ///
    /// Currently infallible.
    pub fn add_error_record(&mut self, record: ErrorRecord<F>) -> InterpolateResult<()> {
        self.error_history.push_back(record);
        if self.error_history.len() > 1000 {
            self.error_history.pop_front();
        }
        Ok(())
    }

    /// Re-estimates `model_accuracy` as `1 - mean relative deviation` over the 50
    /// most recent records, floored at 0.1.
    ///
    /// Nothing changes until at least 10 records exist. Records with a zero (or
    /// non-finite) relative deviation denominator are skipped, so they cannot
    /// collapse the estimate to its floor.
    ///
    /// # Errors
    ///
    /// Currently infallible.
    pub fn update_model(&mut self) -> InterpolateResult<()> {
        if self.error_history.len() < 10 {
            return Ok(());
        }
        let mut total_error = F::zero();
        let mut count = 0usize;
        for record in self.error_history.iter().rev().take(50) {
            if let Some(relative_error) = Self::relative_deviation(record) {
                total_error += relative_error;
                count += 1;
            }
        }
        if count > 0 {
            let avg_relative_error =
                total_error / F::from(count).expect("Failed to convert to float");
            self.model_accuracy = (F::one() - avg_relative_error)
                .max(F::from(0.1).expect("Failed to convert constant to float"));
        }
        Ok(())
    }

    /// Computes `|predicted - actual| / |actual|` for one record.
    ///
    /// Returns `None` when the value is undefined (zero actual error) or not finite.
    fn relative_deviation(record: &ErrorRecord<F>) -> Option<F> {
        let scale = record.actual_error.abs();
        if scale <= F::zero() {
            return None;
        }
        let deviation = (record.predicted_error - record.actual_error).abs() / scale;
        if deviation.is_finite() {
            Some(deviation)
        } else {
            None
        }
    }

    /// Heuristic baseline accuracy (0..1) of each interpolation method.
    fn get_base_accuracy(&self, method: InterpolationMethodType) -> f64 {
        match method {
            InterpolationMethodType::Linear => 0.7,
            InterpolationMethodType::CubicSpline => 0.9,
            InterpolationMethodType::BSpline => 0.92,
            InterpolationMethodType::RadialBasisFunction => 0.95,
            InterpolationMethodType::Kriging => 0.98,
            InterpolationMethodType::Polynomial => 0.85,
            InterpolationMethodType::PchipInterpolation => 0.88,
            InterpolationMethodType::AkimaSpline => 0.87,
            InterpolationMethodType::ThinPlateSpline => 0.93,
            InterpolationMethodType::NaturalNeighbor => 0.86,
            InterpolationMethodType::ShepardsMethod => 0.75,
            InterpolationMethodType::QuantumInspired => 0.99,
        }
    }

    /// Returns the model's current self-assessed accuracy.
    pub fn get_model_accuracy(&self) -> F {
        self.model_accuracy
    }

    /// Summarizes prediction quality over the recorded history.
    ///
    /// Keys are `mean_absolute_error`, `mean_relative_error`, `model_accuracy` and
    /// `sample_count`. The map is empty when there is no history.
    /// `mean_relative_error` averages only records with a defined relative
    /// deviation, and is `0.0` when there are none.
    pub fn get_prediction_statistics(&self) -> HashMap<String, f64> {
        let mut stats = HashMap::new();
        if self.error_history.is_empty() {
            return stats;
        }
        let count = self.error_history.len();
        let mut total_abs_error = F::zero();
        let mut total_rel_error = F::zero();
        let mut rel_count = 0usize;
        for record in &self.error_history {
            total_abs_error += (record.predicted_error - record.actual_error).abs();
            if let Some(rel_error) = Self::relative_deviation(record) {
                total_rel_error += rel_error;
                rel_count += 1;
            }
        }
        let mean_absolute_error = (total_abs_error
            / F::from(count).expect("Failed to convert to float"))
        .to_f64()
        .unwrap_or(0.0);
        let mean_relative_error = if rel_count > 0 {
            (total_rel_error / F::from(rel_count).expect("Failed to convert to float"))
                .to_f64()
                .unwrap_or(0.0)
        } else {
            0.0
        };
        stats.insert("mean_absolute_error".to_string(), mean_absolute_error);
        stats.insert("mean_relative_error".to_string(), mean_relative_error);
        stats.insert(
            "model_accuracy".to_string(),
            self.model_accuracy.to_f64().unwrap_or(0.0),
        );
        stats.insert("sample_count".to_string(), count as f64);
        stats
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Default targets expose only the hard bound and confidence level.
    #[test]
    fn test_accuracy_targets_default() {
        let AccuracyTargets {
            target_absolute_error,
            max_acceptable_error,
            confidence_level,
            ..
        } = AccuracyTargets::<f64>::default();
        assert_eq!(max_acceptable_error, 1e-6);
        assert_eq!(confidence_level, 0.95);
        assert_eq!(target_absolute_error, None);
    }

    /// A fresh model starts at 0.8 accuracy with no recorded history.
    #[test]
    fn test_error_prediction_model_creation() {
        let model = ErrorPredictionModel::<f64>::new().expect("Operation failed");
        let history_is_empty = model.error_history.is_empty();
        assert_eq!(model.model_accuracy, 0.8);
        assert!(history_is_empty);
    }

    /// A fresh engine uses the balanced strategy and has no history.
    #[test]
    fn test_accuracy_optimization_engine_creation() {
        let engine = AccuracyOptimizationEngine::<f64>::new().expect("Operation failed");
        let uses_balanced = matches!(
            engine.strategy,
            AccuracyOptimizationStrategy::BalancedAccuracy
        );
        assert!(uses_balanced);
        assert_eq!(engine.optimization_history.len(), 0);
    }
}