use scirs2_stats::distributions::StudentT;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, VecDeque};
use std::sync::Arc;
use std::time::{Duration, Instant, SystemTime};
use tokio::sync::RwLock;
use tracing::info;
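/// Statistical method used to flag an individual metric sample as anomalous.
/// `Ensemble` names the voting scheme that combines the other methods (see
/// `AnomalyDetectionConfig::ensemble_threshold`) rather than a detector itself.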
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum AnomalyDetectionMethod {
ZScore,
IQR,
MAD,
ModifiedZScore,
ExponentialSmoothing,
Ensemble,
}
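/// Severity bucket assigned to a detected anomaly, ordered from `Low` to `Critical`.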
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum AnomalySeverity {
Low,
Medium,
High,
Critical,
}
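/// A single detected anomaly: the observed and expected values, plus the
/// detector that produced the strongest deviation score.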
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Anomaly {
pub metric: String,
pub node_id: u64,
pub value: f64,
pub expected: f64,
pub deviation_score: f64,
pub method: AnomalyDetectionMethod,
pub severity: AnomalySeverity,
pub timestamp: SystemTime,
pub confidence: f64,
pub suggested_action: String,
}
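/// Aggregated failure-risk assessment for a node, built from weighted risk factors.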
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FailurePrediction {
pub node_id: u64,
pub probability: f64,
pub time_to_failure: Option<Duration>,
pub risk_factors: Vec<RiskFactor>,
pub confidence: f64,
pub recommendations: Vec<String>,
pub timestamp: SystemTime,
}
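/// One weighted input to a failure prediction, with its normalized score and trend.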
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RiskFactor {
pub name: String,
pub score: f64,
pub trend: TrendDirection,
pub description: String,
}
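/// Direction of a metric's recent movement, derived from its least-squares
/// slope and coefficient of variation.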
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum TrendDirection {
Increasing,
Decreasing,
Stable,
Volatile,
}
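/// Point forecast for a metric plus a confidence interval at `confidence_level`.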
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LoadPrediction {
pub predicted_load: f64,
pub lower_bound: f64,
pub upper_bound: f64,
pub horizon: Duration,
pub confidence_level: f64,
pub seasonality: Option<SeasonalityInfo>,
pub trend: f64,
pub timestamp: SystemTime,
}
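/// Detected periodic structure in a metric, estimated via autocorrelation.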
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SeasonalityInfo {
pub period_seconds: u64,
pub amplitude: f64,
pub phase: f64,
pub strength: f64,
}
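/// A single observed sample in a metric's history.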
#[derive(Debug, Clone)]
struct TimeSeriesPoint {
value: f64,
timestamp: Instant,
}
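/// Per-node statistical optimizer that tracks metric time series and runs
/// anomaly detection, failure prediction, and load forecasting over them.
///
/// A minimal usage sketch (marked `ignore` since the import path depends on
/// where this module sits in the crate):
///
/// ```ignore
/// let optimizer = MLClusterOptimizer::new(1);
/// for sample in [10.0, 11.0, 10.5] {
///     optimizer.record_metric("latency_ms", sample).await;
/// }
/// // Detection only fires once at least `min_samples` points are recorded.
/// let anomalies = optimizer.detect_anomalies().await;
/// let forecast = optimizer.predict_load("latency_ms", Duration::from_secs(600)).await;
/// ```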
pub struct MLClusterOptimizer {
node_id: u64,
metrics_history: Arc<RwLock<HashMap<String, VecDeque<TimeSeriesPoint>>>>,
anomaly_config: AnomalyDetectionConfig,
failure_config: FailurePredictionConfig,
load_config: LoadPredictionConfig,
anomalies: Arc<RwLock<Vec<Anomaly>>>,
failure_predictions: Arc<RwLock<HashMap<u64, FailurePrediction>>>,
training_state: Arc<RwLock<TrainingState>>,
enabled: Arc<RwLock<bool>>,
}
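/// Tuning knobs for anomaly detection: per-method thresholds, the minimum
/// sample count before detection runs, history retention, and the fraction of
/// methods that must agree (`ensemble_threshold`) before an anomaly is reported.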
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AnomalyDetectionConfig {
pub zscore_threshold: f64,
pub iqr_multiplier: f64,
pub mad_threshold: f64,
pub min_samples: usize,
pub max_history: usize,
pub methods: Vec<AnomalyDetectionMethod>,
pub ensemble_threshold: f64,
}
impl Default for AnomalyDetectionConfig {
fn default() -> Self {
Self {
zscore_threshold: 3.0,
iqr_multiplier: 1.5,
mad_threshold: 3.5,
min_samples: 30,
max_history: 1000,
methods: vec![
AnomalyDetectionMethod::ZScore,
AnomalyDetectionMethod::IQR,
AnomalyDetectionMethod::MAD,
],
ensemble_threshold: 0.5,
}
}
}
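/// Tuning knobs for failure prediction: probability thresholds, windowing,
/// and the per-factor weights used to combine risk scores.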
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FailurePredictionConfig {
pub warning_threshold: f64,
pub critical_threshold: f64,
pub lookback_window: Duration,
pub prediction_horizon: Duration,
pub min_samples: usize,
pub risk_weights: HashMap<String, f64>,
}
impl Default for FailurePredictionConfig {
fn default() -> Self {
let mut risk_weights = HashMap::new();
risk_weights.insert("latency_trend".to_string(), 0.3);
risk_weights.insert("error_rate".to_string(), 0.25);
risk_weights.insert("memory_pressure".to_string(), 0.2);
risk_weights.insert("cpu_saturation".to_string(), 0.15);
risk_weights.insert("network_issues".to_string(), 0.1);
Self {
warning_threshold: 0.3,
critical_threshold: 0.7,
lookback_window: Duration::from_secs(3600),
prediction_horizon: Duration::from_secs(1800),
min_samples: 60,
risk_weights,
}
}
}
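/// Smoothing parameters (`alpha` for level, `beta` for trend, `gamma` reserved
/// for a seasonal component that is currently a placeholder) and interval
/// settings for load forecasting.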
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LoadPredictionConfig {
pub alpha: f64,
pub beta: f64,
pub gamma: f64,
pub seasonal_period: usize,
pub confidence_level: f64,
pub min_samples: usize,
}
impl Default for LoadPredictionConfig {
fn default() -> Self {
Self {
alpha: 0.3,
beta: 0.1,
gamma: 0.1,
seasonal_period: 24,
confidence_level: 0.95,
min_samples: 48,
}
}
}
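/// Internal bookkeeping for how much data the optimizer has processed.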
#[derive(Debug, Clone, Default)]
struct TrainingState {
last_training: Option<Instant>,
samples_processed: u64,
#[allow(dead_code)]
accuracy_metrics: HashMap<String, f64>,
is_trained: bool,
}
impl MLClusterOptimizer {
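/// Creates an optimizer for `node_id` with default configurations.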
pub fn new(node_id: u64) -> Self {
Self {
node_id,
metrics_history: Arc::new(RwLock::new(HashMap::new())),
anomaly_config: AnomalyDetectionConfig::default(),
failure_config: FailurePredictionConfig::default(),
load_config: LoadPredictionConfig::default(),
anomalies: Arc::new(RwLock::new(Vec::new())),
failure_predictions: Arc::new(RwLock::new(HashMap::new())),
training_state: Arc::new(RwLock::new(TrainingState::default())),
enabled: Arc::new(RwLock::new(true)),
}
}
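/// Creates an optimizer with explicit detection, failure, and load configurations.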
pub fn with_config(
node_id: u64,
anomaly_config: AnomalyDetectionConfig,
failure_config: FailurePredictionConfig,
load_config: LoadPredictionConfig,
) -> Self {
Self {
node_id,
metrics_history: Arc::new(RwLock::new(HashMap::new())),
anomaly_config,
failure_config,
load_config,
anomalies: Arc::new(RwLock::new(Vec::new())),
failure_predictions: Arc::new(RwLock::new(HashMap::new())),
training_state: Arc::new(RwLock::new(TrainingState::default())),
enabled: Arc::new(RwLock::new(true)),
}
}
pub async fn enable(&self) {
let mut enabled = self.enabled.write().await;
*enabled = true;
info!("ML cluster optimizer enabled for node {}", self.node_id);
}
pub async fn disable(&self) {
let mut enabled = self.enabled.write().await;
*enabled = false;
info!("ML cluster optimizer disabled for node {}", self.node_id);
}
pub async fn is_enabled(&self) -> bool {
*self.enabled.read().await
}
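/// Records one sample for `metric`, trimming history to `max_history` points.
/// No-op while the optimizer is disabled.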
pub async fn record_metric(&self, metric: &str, value: f64) {
if !*self.enabled.read().await {
return;
}
let mut history = self.metrics_history.write().await;
let series = history
.entry(metric.to_string())
.or_insert_with(VecDeque::new);
series.push_back(TimeSeriesPoint {
value,
timestamp: Instant::now(),
});
while series.len() > self.anomaly_config.max_history {
series.pop_front();
}
let mut state = self.training_state.write().await;
state.samples_processed += 1;
}
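/// Runs every configured detection method over each tracked metric's latest
/// sample and reports an anomaly when at least `ensemble_threshold` of the
/// methods agree. Detected anomalies are also appended to the internal log,
/// capped at the most recent 1000.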
pub async fn detect_anomalies(&self) -> Vec<Anomaly> {
if !*self.enabled.read().await {
return Vec::new();
}
let history = self.metrics_history.read().await;
let mut detected = Vec::new();
for (metric, series) in history.iter() {
if series.len() < self.anomaly_config.min_samples {
continue;
}
let values: Vec<f64> = series.iter().map(|p| p.value).collect();
let latest = *values.last().expect("collection validated to be non-empty");
let mut votes = 0;
let mut total_methods = 0;
let mut best_score = 0.0;
let mut best_method = AnomalyDetectionMethod::ZScore;
for method in &self.anomaly_config.methods {
let (is_anomaly, score) = match method {
AnomalyDetectionMethod::ZScore => self.detect_zscore_anomaly(&values, latest),
AnomalyDetectionMethod::IQR => self.detect_iqr_anomaly(&values, latest),
AnomalyDetectionMethod::MAD => self.detect_mad_anomaly(&values, latest),
AnomalyDetectionMethod::ModifiedZScore => {
self.detect_modified_zscore_anomaly(&values, latest)
}
AnomalyDetectionMethod::ExponentialSmoothing => {
self.detect_exponential_smoothing_anomaly(&values, latest)
}
// `Ensemble` is the voting scheme itself, not a base detector; skip it
// without counting it toward the vote denominator.
AnomalyDetectionMethod::Ensemble => continue,
};
total_methods += 1;
if is_anomaly {
votes += 1;
if score > best_score {
best_score = score;
best_method = *method;
}
}
}
if total_methods == 0 {
continue;
}
let vote_ratio = votes as f64 / total_methods as f64;
if vote_ratio >= self.anomaly_config.ensemble_threshold {
let mean = self.calculate_mean(&values);
let severity = self.classify_severity(best_score);
let confidence = vote_ratio;
detected.push(Anomaly {
metric: metric.clone(),
node_id: self.node_id,
value: latest,
expected: mean,
deviation_score: best_score,
method: best_method,
severity,
timestamp: SystemTime::now(),
confidence,
suggested_action: self.suggest_action(&severity, metric),
});
}
}
if !detected.is_empty() {
let mut anomalies = self.anomalies.write().await;
anomalies.extend(detected.clone());
let len = anomalies.len();
if len > 1000 {
anomalies.drain(0..len - 1000);
}
}
detected
}
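/// Flags `latest` when its z-score against the series mean exceeds
/// `zscore_threshold` standard deviations.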
fn detect_zscore_anomaly(&self, values: &[f64], latest: f64) -> (bool, f64) {
let mean = self.calculate_mean(values);
let std_dev = self.calculate_std_dev(values, mean);
if std_dev == 0.0 {
return (false, 0.0);
}
let zscore = (latest - mean).abs() / std_dev;
(zscore > self.anomaly_config.zscore_threshold, zscore)
}
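/// Tukey's fences: flags `latest` when it falls more than `iqr_multiplier`
/// interquartile ranges outside [Q1, Q3]; the score is the overshoot in IQR units.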
fn detect_iqr_anomaly(&self, values: &[f64], latest: f64) -> (bool, f64) {
let mut sorted = values.to_vec();
sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
let q1 = self.percentile(&sorted, 25.0);
let q3 = self.percentile(&sorted, 75.0);
let iqr = q3 - q1;
if iqr == 0.0 {
return (false, 0.0);
}
let lower = q1 - self.anomaly_config.iqr_multiplier * iqr;
let upper = q3 + self.anomaly_config.iqr_multiplier * iqr;
let is_anomaly = latest < lower || latest > upper;
let score = if is_anomaly {
if latest < lower {
(lower - latest) / iqr
} else {
(latest - upper) / iqr
}
} else {
0.0
};
(is_anomaly, score)
}
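/// Median-absolute-deviation test. The 0.6745 factor (the 75th percentile of
/// the standard normal) scales MAD into standard-deviation units, making the
/// score comparable to a z-score.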
fn detect_mad_anomaly(&self, values: &[f64], latest: f64) -> (bool, f64) {
let median = self.calculate_median(values);
let deviations: Vec<f64> = values.iter().map(|&v| (v - median).abs()).collect();
let mad = self.calculate_median(&deviations);
if mad == 0.0 {
return (false, 0.0);
}
let score = 0.6745 * (latest - median).abs() / mad;
(score > self.anomaly_config.mad_threshold, score)
}
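/// Iglewicz-Hoaglin modified z-score: the same statistic as the MAD test, but
/// compared against `zscore_threshold` rather than `mad_threshold`.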
fn detect_modified_zscore_anomaly(&self, values: &[f64], latest: f64) -> (bool, f64) {
let median = self.calculate_median(values);
let deviations: Vec<f64> = values.iter().map(|&v| (v - median).abs()).collect();
let mad = self.calculate_median(&deviations);
if mad == 0.0 {
return (false, 0.0);
}
let modified_zscore = 0.6745 * (latest - median) / mad;
let score = modified_zscore.abs();
(score > self.anomaly_config.zscore_threshold, score)
}
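/// Smooths the series with simple exponential smoothing (reusing the load
/// forecaster's `alpha`) and flags `latest` when its prediction error is an
/// outlier relative to the historical error distribution.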
fn detect_exponential_smoothing_anomaly(&self, values: &[f64], latest: f64) -> (bool, f64) {
if values.len() < 2 {
return (false, 0.0);
}
let alpha = self.load_config.alpha;
let mut smoothed = values[0];
let mut errors = Vec::new();
for &value in values.iter().skip(1) {
let error = (value - smoothed).abs();
errors.push(error);
smoothed = alpha * value + (1.0 - alpha) * smoothed;
}
let prediction_error = (latest - smoothed).abs();
let mean_error = self.calculate_mean(&errors);
let std_error = self.calculate_std_dev(&errors, mean_error);
if std_error == 0.0 {
return (false, 0.0);
}
let score = (prediction_error - mean_error) / std_error;
(score > self.anomaly_config.zscore_threshold, score.abs())
}
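/// Combines the configured risk factors into a weighted failure probability
/// for `node_id`. Returns `None` while disabled or when no factor has enough
/// samples; otherwise caches and returns the prediction.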
pub async fn predict_failure(&self, node_id: u64) -> Option<FailurePrediction> {
if !*self.enabled.read().await {
return None;
}
let history = self.metrics_history.read().await;
let mut risk_factors = Vec::new();
let mut total_risk = 0.0;
let mut total_weight = 0.0;
for (factor_name, weight) in &self.failure_config.risk_weights {
let metric_name = self.factor_to_metric(factor_name);
if let Some(series) = history.get(&metric_name) {
if series.len() >= self.failure_config.min_samples {
let values: Vec<f64> = series.iter().map(|p| p.value).collect();
let (score, trend) = self.analyze_risk_factor(&values);
let description = self.describe_risk_factor(factor_name, score, &trend);
risk_factors.push(RiskFactor {
name: factor_name.clone(),
score,
trend,
description,
});
total_risk += score * weight;
total_weight += weight;
}
}
}
if total_weight == 0.0 {
return None;
}
let probability = total_risk / total_weight;
// Confidence is the fraction of the configured risk weight that is actually
// backed by data, not the average weight per factor.
let configured_weight: f64 = self.failure_config.risk_weights.values().sum();
let confidence = (total_weight / configured_weight).min(1.0);
let time_to_failure = if probability > self.failure_config.warning_threshold {
Some(self.estimate_time_to_failure(probability))
} else {
None
};
let recommendations = self.generate_failure_recommendations(probability, &risk_factors);
let prediction = FailurePrediction {
node_id,
probability,
time_to_failure,
risk_factors,
confidence,
recommendations,
timestamp: SystemTime::now(),
};
let mut predictions = self.failure_predictions.write().await;
predictions.insert(node_id, prediction.clone());
Some(prediction)
}
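/// Forecasts `metric` at `horizon` ahead by extrapolating the smoothed level
/// and trend, with a t-distribution interval sized from the series' standard
/// deviation. Returns `None` while disabled or with fewer than `min_samples`.
///
/// A sketch (ignored: import paths depend on the crate layout):
///
/// ```ignore
/// if let Some(p) = optimizer.predict_load("cpu_load", Duration::from_secs(3600)).await {
///     assert!(p.lower_bound <= p.predicted_load && p.predicted_load <= p.upper_bound);
/// }
/// ```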
pub async fn predict_load(&self, metric: &str, horizon: Duration) -> Option<LoadPrediction> {
if !*self.enabled.read().await {
return None;
}
let history = self.metrics_history.read().await;
let series = history.get(metric)?;
if series.len() < self.load_config.min_samples {
return None;
}
let values: Vec<f64> = series.iter().map(|p| p.value).collect();
let (level, trend, _seasonal) = self.holt_winters_decompose(&values);
let avg_interval = if series.len() > 1 {
let first = series
.front()
.expect("series should not be empty when len > 1")
.timestamp;
let last = series
.back()
.expect("series should not be empty when len > 1")
.timestamp;
last.duration_since(first).as_secs_f64() / (series.len() - 1) as f64
} else {
1.0
};
let steps_ahead = (horizon.as_secs_f64() / avg_interval).ceil() as usize;
let predicted = level + trend * steps_ahead as f64;
let std_dev = self.calculate_std_dev(&values, self.calculate_mean(&values));
let t_value = self.get_t_value(values.len() - 1, self.load_config.confidence_level);
let margin = t_value * std_dev * (1.0 + 1.0 / values.len() as f64).sqrt();
let seasonality = self.detect_seasonality(&values);
Some(LoadPrediction {
predicted_load: predicted,
lower_bound: predicted - margin,
upper_bound: predicted + margin,
horizon,
confidence_level: self.load_config.confidence_level,
seasonality,
trend,
timestamp: SystemTime::now(),
})
}
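/// Despite the name, this is Holt's linear (double exponential) smoothing: it
/// fits level and trend only and returns a zeroed placeholder for the seasonal
/// component.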
fn holt_winters_decompose(&self, values: &[f64]) -> (f64, f64, Vec<f64>) {
let alpha = self.load_config.alpha;
let beta = self.load_config.beta;
let mut level = values[0];
let mut trend = if values.len() > 1 {
values[1] - values[0]
} else {
0.0
};
let seasonal = vec![0.0; self.load_config.seasonal_period];
for &value in values.iter().skip(1) {
let prev_level = level;
level = alpha * value + (1.0 - alpha) * (prev_level + trend);
trend = beta * (level - prev_level) + (1.0 - beta) * trend;
}
(level, trend, seasonal)
}
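/// Scans autocorrelation at lags up to twice `seasonal_period` and reports the
/// strongest lag when its correlation exceeds 0.3. Note that `period_seconds`
/// assumes one sample per minute.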
fn detect_seasonality(&self, values: &[f64]) -> Option<SeasonalityInfo> {
if values.len() < self.load_config.seasonal_period * 2 {
return None;
}
let mean = self.calculate_mean(values);
let variance = self.calculate_variance(values, mean);
if variance == 0.0 {
return None;
}
let period = self.load_config.seasonal_period;
let mut max_autocorr = 0.0;
let mut best_lag = 0;
for lag in 1..=period * 2 {
if lag >= values.len() {
break;
}
let mut sum = 0.0;
for i in lag..values.len() {
sum += (values[i] - mean) * (values[i - lag] - mean);
}
let autocorr = sum / ((values.len() - lag) as f64 * variance);
if autocorr > max_autocorr {
max_autocorr = autocorr;
best_lag = lag;
}
}
if max_autocorr > 0.3 {
Some(SeasonalityInfo {
period_seconds: best_lag as u64 * 60,
amplitude: variance.sqrt() * max_autocorr,
phase: 0.0,
strength: max_autocorr,
})
} else {
None
}
}
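/// Returns up to `limit` of the most recently detected anomalies, newest first.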
pub async fn get_recent_anomalies(&self, limit: usize) -> Vec<Anomaly> {
let anomalies = self.anomalies.read().await;
anomalies.iter().rev().take(limit).cloned().collect()
}
pub async fn get_failure_predictions(&self) -> HashMap<u64, FailurePrediction> {
self.failure_predictions.read().await.clone()
}
pub async fn reset(&self) {
let mut history = self.metrics_history.write().await;
history.clear();
let mut anomalies = self.anomalies.write().await;
anomalies.clear();
let mut predictions = self.failure_predictions.write().await;
predictions.clear();
let mut state = self.training_state.write().await;
*state = TrainingState::default();
info!("Reset ML optimizer for node {}", self.node_id);
}
pub async fn get_statistics(&self) -> MLOptimizerStatistics {
let history = self.metrics_history.read().await;
let anomalies = self.anomalies.read().await;
let predictions = self.failure_predictions.read().await;
let state = self.training_state.read().await;
MLOptimizerStatistics {
node_id: self.node_id,
metrics_tracked: history.len(),
total_samples: state.samples_processed,
anomalies_detected: anomalies.len(),
active_predictions: predictions.len(),
is_trained: state.is_trained,
last_training: state.last_training.map(|t| t.elapsed()),
}
}
fn calculate_mean(&self, values: &[f64]) -> f64 {
if values.is_empty() {
return 0.0;
}
values.iter().sum::<f64>() / values.len() as f64
}
fn calculate_std_dev(&self, values: &[f64], mean: f64) -> f64 {
if values.len() < 2 {
return 0.0;
}
let variance =
values.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / (values.len() - 1) as f64;
variance.sqrt()
}
fn calculate_variance(&self, values: &[f64], mean: f64) -> f64 {
if values.len() < 2 {
return 0.0;
}
values.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / (values.len() - 1) as f64
}
fn calculate_median(&self, values: &[f64]) -> f64 {
if values.is_empty() {
return 0.0;
}
let mut sorted = values.to_vec();
sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
let mid = sorted.len() / 2;
if sorted.len() % 2 == 0 {
(sorted[mid - 1] + sorted[mid]) / 2.0
} else {
sorted[mid]
}
}
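/// Nearest-rank percentile over an already sorted slice.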
fn percentile(&self, sorted_values: &[f64], p: f64) -> f64 {
if sorted_values.is_empty() {
return 0.0;
}
let idx = ((p / 100.0) * (sorted_values.len() - 1) as f64).round() as usize;
sorted_values[idx.min(sorted_values.len() - 1)]
}
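/// Buckets a deviation score into a severity: above 5 is Critical, above 4 is
/// High, above 3 is Medium, otherwise Low.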
fn classify_severity(&self, score: f64) -> AnomalySeverity {
if score > 5.0 {
AnomalySeverity::Critical
} else if score > 4.0 {
AnomalySeverity::High
} else if score > 3.0 {
AnomalySeverity::Medium
} else {
AnomalySeverity::Low
}
}
fn suggest_action(&self, severity: &AnomalySeverity, metric: &str) -> String {
match severity {
AnomalySeverity::Critical => {
format!("CRITICAL: Immediate investigation required for {}. Consider automatic failover.", metric)
}
AnomalySeverity::High => {
format!(
"HIGH: Investigate {} anomaly urgently. Consider scaling or load balancing.",
metric
)
}
AnomalySeverity::Medium => {
format!(
"MEDIUM: Monitor {} closely. May require attention soon.",
metric
)
}
AnomalySeverity::Low => {
format!("LOW: Minor deviation in {}. Continue monitoring.", metric)
}
}
}
fn factor_to_metric(&self, factor: &str) -> String {
match factor {
"latency_trend" => "latency_ms".to_string(),
"error_rate" => "error_rate".to_string(),
"memory_pressure" => "memory_usage_bytes".to_string(),
"cpu_saturation" => "cpu_usage_percent".to_string(),
"network_issues" => "network_errors".to_string(),
_ => factor.to_string(),
}
}
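/// Scores a risk factor from the latest value's deviation, amplified when the
/// trend is rising or volatile, and normalized into [0, 1].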
fn analyze_risk_factor(&self, values: &[f64]) -> (f64, TrendDirection) {
let mean = self.calculate_mean(values);
let std_dev = self.calculate_std_dev(values, mean);
let latest = *values.last().expect("collection validated to be non-empty");
let trend = self.calculate_trend(values);
let trend_dir = if trend.abs() < std_dev * 0.1 {
TrendDirection::Stable
} else if trend > 0.0 {
TrendDirection::Increasing
} else {
TrendDirection::Decreasing
};
let cv = if mean != 0.0 { std_dev / mean } else { 0.0 };
let trend_dir = if cv > 0.5 {
TrendDirection::Volatile
} else {
trend_dir
};
let deviation = if std_dev != 0.0 {
((latest - mean) / std_dev).abs()
} else {
0.0
};
let trend_factor = match trend_dir {
TrendDirection::Increasing => 1.5,
TrendDirection::Volatile => 1.3,
TrendDirection::Stable => 1.0,
TrendDirection::Decreasing => 0.8,
};
let score = (deviation * trend_factor / 5.0).min(1.0);
(score, trend_dir)
}
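/// Least-squares slope of the series against its sample index.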
fn calculate_trend(&self, values: &[f64]) -> f64 {
if values.len() < 2 {
return 0.0;
}
let n = values.len() as f64;
let mut sum_x = 0.0;
let mut sum_y = 0.0;
let mut sum_xy = 0.0;
let mut sum_x2 = 0.0;
for (i, &y) in values.iter().enumerate() {
let x = i as f64;
sum_x += x;
sum_y += y;
sum_xy += x * y;
sum_x2 += x * x;
}
let denominator = n * sum_x2 - sum_x * sum_x;
if denominator == 0.0 {
0.0
} else {
(n * sum_xy - sum_x * sum_y) / denominator
}
}
fn describe_risk_factor(&self, factor: &str, score: f64, trend: &TrendDirection) -> String {
let severity = if score > 0.7 {
"critical"
} else if score > 0.5 {
"elevated"
} else if score > 0.3 {
"moderate"
} else {
"normal"
};
let trend_desc = match trend {
TrendDirection::Increasing => "and increasing",
TrendDirection::Decreasing => "and decreasing",
TrendDirection::Stable => "and stable",
TrendDirection::Volatile => "with high volatility",
};
format!("{} is {} {}", factor, severity, trend_desc)
}
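/// Heuristic estimate: linearly shrinks a 24-hour horizon as the failure
/// probability approaches 1.0.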
fn estimate_time_to_failure(&self, probability: f64) -> Duration {
let base_hours = 24.0 * (1.0 - probability);
Duration::from_secs((base_hours * 3600.0) as u64)
}
fn generate_failure_recommendations(
&self,
probability: f64,
risk_factors: &[RiskFactor],
) -> Vec<String> {
let mut recommendations = Vec::new();
if probability > self.failure_config.critical_threshold {
recommendations
.push("CRITICAL: Consider immediate failover to backup node".to_string());
recommendations.push("Prepare disaster recovery procedures".to_string());
} else if probability > self.failure_config.warning_threshold {
recommendations.push("Schedule maintenance window soon".to_string());
recommendations.push("Increase monitoring frequency".to_string());
}
for factor in risk_factors {
if factor.score > 0.5 {
match factor.name.as_str() {
"latency_trend" => {
recommendations
.push("Consider adding more nodes to distribute load".to_string());
}
"error_rate" => {
recommendations
.push("Review error logs and fix underlying issues".to_string());
}
"memory_pressure" => {
recommendations
.push("Consider increasing memory or optimizing queries".to_string());
}
"cpu_saturation" => {
recommendations.push(
"Consider scaling horizontally or optimizing workload".to_string(),
);
}
"network_issues" => {
recommendations
.push("Check network connectivity and bandwidth".to_string());
}
_ => {}
}
}
}
recommendations
}
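/// Two-sided Student-t critical value for `df` degrees of freedom, found by
/// bisecting the CDF; falls back to normal-quantile approximations when the
/// distribution cannot be constructed.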
fn get_t_value(&self, df: usize, confidence: f64) -> f64 {
if df == 0 {
return 1.96;
}
let df = df as f64;
let alpha = (1.0 - confidence) / 2.0;
if let Ok(dist) = StudentT::new(0.0, 1.0, df) {
let target = 1.0 - alpha;
let mut low = 0.0;
let mut high = 10.0;
for _ in 0..50 {
let mid = (low + high) / 2.0;
let cdf = dist.cdf(mid);
if cdf < target {
low = mid;
} else {
high = mid;
}
}
(low + high) / 2.0
} else {
match confidence {
c if c >= 0.99 => 2.576 + 2.0 / df,
c if c >= 0.95 => 1.96 + 1.0 / df,
c if c >= 0.90 => 1.645 + 0.5 / df,
_ => 1.96,
}
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MLOptimizerStatistics {
pub node_id: u64,
pub metrics_tracked: usize,
pub total_samples: u64,
pub anomalies_detected: usize,
pub active_predictions: usize,
pub is_trained: bool,
pub last_training: Option<Duration>,
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_ml_optimizer_creation() {
let optimizer = MLClusterOptimizer::new(1);
assert!(optimizer.is_enabled().await);
}
#[tokio::test]
async fn test_record_metrics() {
let optimizer = MLClusterOptimizer::new(1);
for i in 0..100 {
optimizer.record_metric("test_metric", i as f64).await;
}
let stats = optimizer.get_statistics().await;
assert_eq!(stats.total_samples, 100);
assert_eq!(stats.metrics_tracked, 1);
}
#[tokio::test]
async fn test_anomaly_detection_zscore() {
let optimizer = MLClusterOptimizer::new(1);
for i in 0..50 {
optimizer
.record_metric("latency", 100.0 + (i as f64 % 10.0))
.await;
}
optimizer.record_metric("latency", 500.0).await;
let anomalies = optimizer.detect_anomalies().await;
assert!(!anomalies.is_empty());
assert!(anomalies[0].deviation_score > 3.0);
}
#[tokio::test]
async fn test_anomaly_severity_classification() {
let optimizer = MLClusterOptimizer::new(1);
assert_eq!(optimizer.classify_severity(3.5), AnomalySeverity::Medium);
assert_eq!(optimizer.classify_severity(4.5), AnomalySeverity::High);
assert_eq!(optimizer.classify_severity(5.5), AnomalySeverity::Critical);
assert_eq!(optimizer.classify_severity(2.5), AnomalySeverity::Low);
}
#[tokio::test]
async fn test_failure_prediction() {
let optimizer = MLClusterOptimizer::new(1);
for i in 0..100 {
optimizer
.record_metric("latency_ms", 100.0 + i as f64 * 2.0)
.await;
optimizer
.record_metric("error_rate", 0.01 + i as f64 * 0.001)
.await;
optimizer
.record_metric("memory_usage_bytes", 1000000.0 + i as f64 * 10000.0)
.await;
}
let prediction = optimizer.predict_failure(1).await;
assert!(prediction.is_some());
let prediction = prediction.unwrap();
assert!(prediction.probability >= 0.0 && prediction.probability <= 1.0);
assert!(!prediction.risk_factors.is_empty());
}
#[tokio::test]
async fn test_load_prediction() {
let optimizer = MLClusterOptimizer::new(1);
for i in 0..100 {
let load = 50.0 + i as f64 * 0.5 + (i as f64 * 0.1).sin() * 5.0;
optimizer.record_metric("cpu_load", load).await;
}
let prediction = optimizer
.predict_load("cpu_load", Duration::from_secs(3600))
.await;
assert!(prediction.is_some());
let prediction = prediction.unwrap();
assert!(prediction.predicted_load > 0.0);
assert!(prediction.upper_bound >= prediction.predicted_load);
assert!(prediction.lower_bound <= prediction.predicted_load);
}
#[tokio::test]
async fn test_trend_calculation() {
let optimizer = MLClusterOptimizer::new(1);
let values: Vec<f64> = (0..50).map(|i| i as f64 * 2.0).collect();
let trend = optimizer.calculate_trend(&values);
assert!(trend > 1.9 && trend < 2.1);
let values: Vec<f64> = (0..50).map(|i| 100.0 - i as f64 * 2.0).collect();
let trend = optimizer.calculate_trend(&values);
assert!(trend < -1.9 && trend > -2.1);
}
#[tokio::test]
async fn test_ensemble_detection() {
let config = AnomalyDetectionConfig {
ensemble_threshold: 0.6,
..Default::default()
};
let optimizer = MLClusterOptimizer::with_config(
1,
config,
FailurePredictionConfig::default(),
LoadPredictionConfig::default(),
);
for i in 0..50 {
optimizer
.record_metric("metric", 100.0 + (i % 5) as f64)
.await;
}
optimizer.record_metric("metric", 1000.0).await;
let anomalies = optimizer.detect_anomalies().await;
assert!(!anomalies.is_empty());
}
#[tokio::test]
async fn test_enable_disable() {
let optimizer = MLClusterOptimizer::new(1);
assert!(optimizer.is_enabled().await);
optimizer.disable().await;
assert!(!optimizer.is_enabled().await);
optimizer.record_metric("test", 100.0).await;
let stats = optimizer.get_statistics().await;
assert_eq!(stats.total_samples, 0);
optimizer.enable().await;
assert!(optimizer.is_enabled().await);
}
#[tokio::test]
async fn test_reset() {
let optimizer = MLClusterOptimizer::new(1);
for i in 0..50 {
optimizer.record_metric("test", i as f64).await;
}
let stats = optimizer.get_statistics().await;
assert_eq!(stats.total_samples, 50);
optimizer.reset().await;
let stats = optimizer.get_statistics().await;
assert_eq!(stats.total_samples, 0);
assert_eq!(stats.metrics_tracked, 0);
}
#[tokio::test]
async fn test_get_recent_anomalies() {
let optimizer = MLClusterOptimizer::new(1);
for _ in 0..50 {
optimizer.record_metric("test", 100.0).await;
}
optimizer.record_metric("test", 1000.0).await;
optimizer.detect_anomalies().await;
let recent = optimizer.get_recent_anomalies(10).await;
let _detected = !recent.is_empty();
}
#[tokio::test]
async fn test_median_calculation() {
let optimizer = MLClusterOptimizer::new(1);
let values = vec![1.0, 3.0, 5.0, 7.0, 9.0];
assert_eq!(optimizer.calculate_median(&values), 5.0);
let values = vec![1.0, 3.0, 5.0, 7.0];
assert_eq!(optimizer.calculate_median(&values), 4.0);
}
#[tokio::test]
async fn test_mad_anomaly_detection() {
let optimizer = MLClusterOptimizer::new(1);
let values: Vec<f64> = (0..50).map(|_| 100.0).collect();
let (is_anomaly, _score) = optimizer.detect_mad_anomaly(&values, 100.0);
assert!(!is_anomaly);
let values: Vec<f64> = (0..50).map(|i| 100.0 + (i % 5) as f64).collect();
let (is_anomaly, score) = optimizer.detect_mad_anomaly(&values, 200.0);
assert!(is_anomaly);
assert!(score > 0.0);
}
#[tokio::test]
async fn test_risk_factor_analysis() {
let optimizer = MLClusterOptimizer::new(1);
let values: Vec<f64> = (0..50).map(|i| 100.0 + (i as f64).powf(2.0)).collect();
let (score, trend) = optimizer.analyze_risk_factor(&values);
assert!(score > 0.0);
assert!(matches!(
trend,
TrendDirection::Increasing | TrendDirection::Volatile
));
}
#[tokio::test]
async fn test_seasonality_detection() {
let optimizer = MLClusterOptimizer::new(1);
let values: Vec<f64> = (0..100)
.map(|i| 100.0 + 20.0 * (i as f64 * std::f64::consts::PI / 12.0).sin())
.collect();
let seasonality = optimizer.detect_seasonality(&values);
let _ = seasonality;
}
#[tokio::test]
async fn test_holt_winters_decomposition() {
let optimizer = MLClusterOptimizer::new(1);
let values: Vec<f64> = (0..50).map(|i| 100.0 + i as f64).collect();
let (level, trend, _seasonal) = optimizer.holt_winters_decompose(&values);
assert!(level > 100.0);
assert!(trend > 0.0);
}
#[tokio::test]
async fn test_failure_recommendations() {
let optimizer = MLClusterOptimizer::new(1);
let risk_factors = vec![RiskFactor {
name: "latency_trend".to_string(),
score: 0.8,
trend: TrendDirection::Increasing,
description: "High latency".to_string(),
}];
let recommendations = optimizer.generate_failure_recommendations(0.8, &risk_factors);
assert!(!recommendations.is_empty());
}
#[tokio::test]
async fn test_t_value_calculation() {
let optimizer = MLClusterOptimizer::new(1);
let t_95 = optimizer.get_t_value(30, 0.95);
assert!(t_95 > 1.5 && t_95 < 2.5);
let t_99 = optimizer.get_t_value(30, 0.99);
assert!(t_99 > t_95);
}
#[tokio::test]
async fn test_exponential_smoothing_anomaly() {
let optimizer = MLClusterOptimizer::new(1);
let values: Vec<f64> = (0..50).map(|i| 100.0 + (i % 5) as f64).collect();
let (is_anomaly, score) = optimizer.detect_exponential_smoothing_anomaly(&values, 200.0);
assert!(is_anomaly);
assert!(score > 0.0);
}
#[tokio::test]
async fn test_iqr_anomaly_detection() {
let optimizer = MLClusterOptimizer::new(1);
let values: Vec<f64> = (0..50).map(|i| 100.0 + (i % 10) as f64).collect();
let (is_anomaly, _) = optimizer.detect_iqr_anomaly(&values, 105.0);
assert!(!is_anomaly);
let (is_anomaly, score) = optimizer.detect_iqr_anomaly(&values, 200.0);
assert!(is_anomaly);
assert!(score > 0.0);
}
#[tokio::test]
async fn test_statistics() {
let optimizer = MLClusterOptimizer::new(42);
let stats = optimizer.get_statistics().await;
assert_eq!(stats.node_id, 42);
assert_eq!(stats.total_samples, 0);
assert!(!stats.is_trained);
}
}