use anyhow::{anyhow, Result};
use scirs2_core::metrics::MetricsRegistry;
use scirs2_core::ndarray_ext::Array1;
use scirs2_core::profiling::Profiler;
use scirs2_core::random::{rngs::StdRng, Random};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::{debug, info};
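/// Consensus-protocol parameters the optimizer may tune. The set mirrors the
/// knobs of a Raft-style protocol: election timing, heartbeat cadence, log
/// batching, snapshot frequency, and append pipelining.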
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ConsensusParameter {
ElectionTimeout,
HeartbeatInterval,
LogBatchSize,
SnapshotThreshold,
MaxInflightAppends,
}
impl ConsensusParameter {
pub fn range(&self) -> (f64, f64) {
match self {
ConsensusParameter::ElectionTimeout => (150.0, 300.0),
ConsensusParameter::HeartbeatInterval => (50.0, 150.0),
ConsensusParameter::LogBatchSize => (10.0, 1000.0),
ConsensusParameter::SnapshotThreshold => (1000.0, 10000.0),
ConsensusParameter::MaxInflightAppends => (1.0, 100.0),
}
}
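    /// Maps a concrete parameter value to a bin index in `[0, num_bins)`.
    /// The value is clamped to the parameter's range first, so out-of-range
    /// inputs land in the first or last bin.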
pub fn discretize(&self, value: f64, num_bins: usize) -> usize {
let (min_val, max_val) = self.range();
let bin_size = (max_val - min_val) / num_bins as f64;
let normalized = (value - min_val).max(0.0).min(max_val - min_val);
(normalized / bin_size).floor().min((num_bins - 1) as f64) as usize
}
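    /// Inverse of `discretize`: maps a bin index back to a concrete value at
    /// the midpoint of that bin.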
pub fn from_action(&self, action_idx: usize, num_bins: usize) -> f64 {
let (min_val, max_val) = self.range();
let bin_size = (max_val - min_val) / num_bins as f64;
min_val + (action_idx as f64 + 0.5) * bin_size
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClusterState {
pub node_count: usize,
pub avg_throughput: f64,
pub avg_latency: f64,
pub election_frequency: f64,
pub replication_lag: f64,
pub cpu_utilization: f64,
pub network_congestion: f64,
}
impl ClusterState {
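    /// Normalizes raw metrics into a feature vector with components roughly
    /// in `[0, 1]`. The divisors encode fixed scale assumptions (up to ~100
    /// nodes, ~10k ops/s throughput, ~1s latency); values beyond those scales
    /// produce features above 1.0 and are clamped only during discretization.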
pub fn to_features(&self) -> Array1<f64> {
Array1::from_vec(vec![
            self.node_count as f64 / 100.0,
            self.avg_throughput / 10000.0,
            self.avg_latency / 1000.0,
            self.election_frequency / 10.0,
            self.replication_lag / 1000.0,
            self.cpu_utilization,
            self.network_congestion,
        ])
}
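    /// Collapses the feature vector into a single table index by binning each
    /// feature and treating the bins as digits of a mixed-radix number in
    /// base `num_bins_per_feature`. With 7 features and 5 bins this gives
    /// 5^7 = 78,125 distinct states.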
pub fn discretize(&self, num_bins_per_feature: usize) -> usize {
let features = self.to_features();
let mut state_idx = 0;
let mut multiplier = 1;
for &feature in features.iter() {
let bin = (feature * num_bins_per_feature as f64)
.floor()
.min((num_bins_per_feature - 1) as f64) as usize;
state_idx += bin * multiplier;
multiplier *= num_bins_per_feature;
}
state_idx
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PerformanceReward {
pub throughput_score: f64,
pub latency_score: f64,
pub consistency_score: f64,
pub efficiency_score: f64,
pub total_reward: f64,
}
impl PerformanceReward {
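    /// Derives a weighted multi-objective reward from an observed state.
    /// Each component is normalized to roughly `[0, 1]` before weighting:
    /// higher throughput, lower latency, fewer elections, and more work per
    /// unit of CPU all raise the total reward.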
pub fn from_state(state: &ClusterState, weights: &RewardWeights) -> Self {
let latency_score = weights.latency_weight * (1.0 - (state.avg_latency / 1000.0).min(1.0));
let throughput_score =
weights.throughput_weight * (state.avg_throughput / 10000.0).min(1.0);
let consistency_score =
weights.consistency_weight * (1.0 - (state.election_frequency / 10.0).min(1.0));
let efficiency_score = weights.efficiency_weight
* (1.0 - state.cpu_utilization)
* (state.avg_throughput / 10000.0).min(1.0);
let total_reward = throughput_score + latency_score + consistency_score + efficiency_score;
Self {
throughput_score,
latency_score,
consistency_score,
efficiency_score,
total_reward,
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RewardWeights {
pub throughput_weight: f64,
pub latency_weight: f64,
pub consistency_weight: f64,
pub efficiency_weight: f64,
}
impl Default for RewardWeights {
fn default() -> Self {
Self {
throughput_weight: 0.3,
latency_weight: 0.3,
consistency_weight: 0.25,
efficiency_weight: 0.15,
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum RLAlgorithm {
QLearning,
SARSA,
PolicyGradient,
ActorCritic,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RLOptimizerConfig {
pub algorithm: RLAlgorithm,
pub learning_rate: f64,
pub discount_factor: f64,
pub epsilon: f64,
pub epsilon_decay: f64,
pub min_epsilon: f64,
pub num_state_bins: usize,
pub num_action_bins: usize,
pub reward_weights: RewardWeights,
pub use_experience_replay: bool,
pub replay_buffer_size: usize,
pub replay_batch_size: usize,
}
impl Default for RLOptimizerConfig {
fn default() -> Self {
Self {
algorithm: RLAlgorithm::QLearning,
learning_rate: 0.1,
discount_factor: 0.95,
epsilon: 0.3,
epsilon_decay: 0.995,
min_epsilon: 0.01,
num_state_bins: 5,
num_action_bins: 10,
reward_weights: RewardWeights::default(),
use_experience_replay: true,
replay_buffer_size: 10000,
replay_batch_size: 32,
}
}
}
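/// A single `(state, action, reward, next_state)` transition retained for
/// experience replay.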
#[derive(Debug, Clone)]
struct Experience {
state: ClusterState,
action: HashMap<ConsensusParameter, usize>,
reward: f64,
next_state: ClusterState,
    #[allow(dead_code)]
    done: bool,
}
type ActionRepr = Vec<(ConsensusParameter, usize)>;
fn action_to_repr(action: &HashMap<ConsensusParameter, usize>) -> ActionRepr {
let mut vec: Vec<_> = action.iter().map(|(&k, &v)| (k, v)).collect();
vec.sort_by_key(|&(param, _)| param as u8);
vec
}
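/// Tabular Q-learning optimizer for online consensus-parameter tuning. All
/// mutable state sits behind `Arc<RwLock<..>>`, so one instance can be shared
/// across tasks.
///
/// A minimal training-loop sketch; `observe_cluster` is a hypothetical
/// metrics source the caller must supply:
///
/// ```ignore
/// let optimizer = RLConsensusOptimizer::new(RLOptimizerConfig::default())?;
/// let weights = RewardWeights::default();
/// let mut state = observe_cluster().await; // hypothetical helper
/// let mut episode_reward = 0.0;
/// for _step in 0..1_000 {
///     // Epsilon-greedy choice of one bin per parameter.
///     let action = optimizer.select_action(&state).await?;
///     // Apply the decoded parameters to the cluster, then observe the result.
///     let next_state = observe_cluster().await;
///     let reward = PerformanceReward::from_state(&next_state, &weights);
///     optimizer
///         .store_experience(state.clone(), action.clone(), reward.total_reward, next_state.clone(), false)
///         .await?;
///     optimizer
///         .update_q_value(&state, &action, reward.total_reward, &next_state)
///         .await?;
///     episode_reward += reward.total_reward;
///     state = next_state;
/// }
/// // Decays epsilon and runs one experience-replay pass.
/// optimizer.complete_episode(episode_reward).await?;
/// ```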
pub struct RLConsensusOptimizer {
config: RLOptimizerConfig,
q_table: Arc<RwLock<HashMap<(usize, ActionRepr), f64>>>,
replay_buffer: Arc<RwLock<Vec<Experience>>>,
epsilon: Arc<RwLock<f64>>,
rng: Arc<RwLock<Random<StdRng>>>,
    #[allow(dead_code)]
    metrics: Arc<MetricsRegistry>,
profiler: Arc<Profiler>,
episode_count: Arc<RwLock<usize>>,
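    /// Reserved counter; nothing in this module increments it yet.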
total_steps: Arc<RwLock<usize>>,
best_reward: Arc<RwLock<f64>>,
}
impl RLConsensusOptimizer {
pub fn new(config: RLOptimizerConfig) -> Result<Self> {
let metrics = Arc::new(MetricsRegistry::new());
let profiler = Arc::new(Profiler::new());
let epsilon = config.epsilon;
Ok(Self {
config,
q_table: Arc::new(RwLock::new(HashMap::new())),
replay_buffer: Arc::new(RwLock::new(Vec::new())),
epsilon: Arc::new(RwLock::new(epsilon)),
rng: Arc::new(RwLock::new(Random::seed(42))),
metrics,
profiler,
episode_count: Arc::new(RwLock::new(0)),
total_steps: Arc::new(RwLock::new(0)),
best_reward: Arc::new(RwLock::new(f64::NEG_INFINITY)),
})
}
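    /// Epsilon-greedy action selection: with probability `epsilon` a uniformly
    /// random joint action is chosen (exploration); otherwise the best known
    /// action for the discretized state is used (exploitation).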
pub async fn select_action(
&self,
state: &ClusterState,
) -> Result<HashMap<ConsensusParameter, usize>> {
        let epsilon = *self.epsilon.read().await;
        // Sample the exploration coin in its own scope so the RNG guard is
        // released before branching: `greedy_action` re-acquires the RNG lock
        // internally, and holding the write guard across that call would
        // deadlock tokio's (non-reentrant) RwLock.
        let explore = {
            let mut rng = self.rng.write().await;
            rng.gen_range(0.0..1.0) < epsilon
        };
        if explore {
            let mut rng = self.rng.write().await;
            self.random_action(&mut rng).await
        } else {
            self.greedy_action(state).await
        }
}
async fn random_action(
&self,
rng: &mut Random<StdRng>,
) -> Result<HashMap<ConsensusParameter, usize>> {
let mut action = HashMap::new();
for param in &[
ConsensusParameter::ElectionTimeout,
ConsensusParameter::HeartbeatInterval,
ConsensusParameter::LogBatchSize,
ConsensusParameter::SnapshotThreshold,
ConsensusParameter::MaxInflightAppends,
] {
let action_idx = (rng.gen_range(0.0..1.0) * self.config.num_action_bins as f64)
.floor()
.min((self.config.num_action_bins - 1) as f64)
as usize;
action.insert(*param, action_idx);
}
Ok(action)
}
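    /// Approximate greedy selection. The joint action space is combinatorial
    /// (`num_action_bins^5` = 100,000 joint actions with the defaults), so
    /// rather than enumerating it this samples 100 random candidates and keeps
    /// the one with the highest Q-value; unseen entries default to 0.0.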
async fn greedy_action(
&self,
state: &ClusterState,
) -> Result<HashMap<ConsensusParameter, usize>> {
let state_idx = state.discretize(self.config.num_state_bins);
let q_table = self.q_table.read().await;
let mut best_action = None;
let mut best_q_value = f64::NEG_INFINITY;
let mut rng = self.rng.write().await;
for _ in 0..100 {
let action = self.random_action(&mut rng).await?;
let action_repr = action_to_repr(&action);
let q_value = *q_table.get(&(state_idx, action_repr)).unwrap_or(&0.0);
if q_value > best_q_value {
best_q_value = q_value;
best_action = Some(action);
}
}
best_action.ok_or_else(|| anyhow!("No action found"))
}
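    /// Applies the standard Q-learning temporal-difference update:
    ///
    /// `Q(s, a) <- Q(s, a) + alpha * (r + gamma * max_a' Q(s', a') - Q(s, a))`
    ///
    /// with `alpha` the learning rate and `gamma` the discount factor. Note
    /// that this rule runs regardless of `config.algorithm`; the other
    /// `RLAlgorithm` variants are not yet implemented here.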
pub async fn update_q_value(
&self,
state: &ClusterState,
action: &HashMap<ConsensusParameter, usize>,
reward: f64,
next_state: &ClusterState,
) -> Result<()> {
let state_idx = state.discretize(self.config.num_state_bins);
let next_state_idx = next_state.discretize(self.config.num_state_bins);
let mut q_table = self.q_table.write().await;
let action_repr = action_to_repr(action);
let current_q = *q_table
.get(&(state_idx, action_repr.clone()))
.unwrap_or(&0.0);
let max_next_q = self.max_q_value(&q_table, next_state_idx).await;
let td_target = reward + self.config.discount_factor * max_next_q;
let td_error = td_target - current_q;
let new_q = current_q + self.config.learning_rate * td_error;
q_table.insert((state_idx, action_repr), new_q);
debug!(
"Q-learning update: Q({}, {:?}) = {:.4} (was {:.4})",
state_idx, action, new_q, current_q
);
Ok(())
}
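    /// Maximum Q-value recorded for `state_idx` over all seen actions. The
    /// fold starts at 0.0 because unvisited actions implicitly carry Q = 0,
    /// and with a combinatorial action space some action is always unvisited.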
async fn max_q_value(
&self,
q_table: &HashMap<(usize, ActionRepr), f64>,
state_idx: usize,
) -> f64 {
q_table
.iter()
.filter(|((s, _), _)| *s == state_idx)
.map(|(_, &q)| q)
.fold(0.0, f64::max)
}
pub async fn store_experience(
&self,
state: ClusterState,
action: HashMap<ConsensusParameter, usize>,
reward: f64,
next_state: ClusterState,
done: bool,
) -> Result<()> {
if !self.config.use_experience_replay {
return Ok(());
}
let mut buffer = self.replay_buffer.write().await;
let experience = Experience {
state,
action,
reward,
next_state,
done,
};
buffer.push(experience);
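        // Evict the oldest experience once the buffer exceeds capacity.
        // `Vec::remove(0)` is O(n); a `VecDeque` would make this O(1) if
        // eviction ever shows up in profiles.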
if buffer.len() > self.config.replay_buffer_size {
buffer.remove(0);
}
Ok(())
}
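    /// Replays `replay_batch_size` experiences sampled uniformly at random
    /// (with replacement) through the Q-learning update. Replay decorrelates
    /// consecutive transitions and lets each observation be reused.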
pub async fn replay_train(&self) -> Result<()> {
if !self.config.use_experience_replay {
return Ok(());
}
let buffer = self.replay_buffer.read().await;
if buffer.len() < self.config.replay_batch_size {
            return Ok(());
        }
let mut rng = self.rng.write().await;
let mut batch_indices = Vec::new();
for _ in 0..self.config.replay_batch_size {
let idx = (rng.gen_range(0.0..1.0) * buffer.len() as f64).floor() as usize;
batch_indices.push(idx);
}
drop(rng);
for idx in batch_indices {
let exp = &buffer[idx];
self.update_q_value(&exp.state, &exp.action, exp.reward, &exp.next_state)
.await?;
}
Ok(())
}
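    /// Multiplicatively decays the exploration rate, flooring at `min_epsilon`.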
pub async fn decay_epsilon(&self) -> Result<()> {
let mut epsilon = self.epsilon.write().await;
*epsilon = (*epsilon * self.config.epsilon_decay).max(self.config.min_epsilon);
debug!("Epsilon decayed to {:.4}", *epsilon);
Ok(())
}
pub async fn complete_episode(&self, total_reward: f64) -> Result<()> {
let mut episode_count = self.episode_count.write().await;
*episode_count += 1;
let mut best_reward = self.best_reward.write().await;
if total_reward > *best_reward {
*best_reward = total_reward;
info!(
"New best reward: {:.4} (episode {})",
total_reward, *episode_count
);
}
self.decay_epsilon().await?;
self.replay_train().await?;
info!(
"Episode {} completed: reward={:.4}, best={:.4}",
*episode_count, total_reward, *best_reward
);
Ok(())
}
pub async fn get_optimized_parameters(
&self,
state: &ClusterState,
) -> Result<HashMap<ConsensusParameter, f64>> {
let action = self.greedy_action(state).await?;
let mut parameters = HashMap::new();
for (param, action_idx) in action {
let value = param.from_action(action_idx, self.config.num_action_bins);
parameters.insert(param, value);
}
Ok(parameters)
}
pub fn get_performance_report(&self) -> String {
self.profiler.get_report()
}
pub async fn get_statistics(&self) -> RLStatistics {
RLStatistics {
episode_count: *self.episode_count.read().await,
total_steps: *self.total_steps.read().await,
best_reward: *self.best_reward.read().await,
current_epsilon: *self.epsilon.read().await,
q_table_size: self.q_table.read().await.len(),
replay_buffer_size: self.replay_buffer.read().await.len(),
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RLStatistics {
pub episode_count: usize,
pub total_steps: usize,
pub best_reward: f64,
pub current_epsilon: f64,
pub q_table_size: usize,
pub replay_buffer_size: usize,
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_consensus_parameter_range() {
assert_eq!(ConsensusParameter::ElectionTimeout.range(), (150.0, 300.0));
assert_eq!(ConsensusParameter::HeartbeatInterval.range(), (50.0, 150.0));
}
#[test]
fn test_consensus_parameter_discretize() {
let param = ConsensusParameter::ElectionTimeout;
assert_eq!(param.discretize(150.0, 10), 0);
assert_eq!(param.discretize(225.0, 10), 5);
assert_eq!(param.discretize(300.0, 10), 9);
}
#[test]
fn test_consensus_parameter_from_action() {
let param = ConsensusParameter::ElectionTimeout;
let value = param.from_action(5, 10);
assert!((value - 232.5).abs() < 1.0);
}
#[test]
fn test_cluster_state_to_features() {
let state = ClusterState {
node_count: 10,
avg_throughput: 5000.0,
avg_latency: 50.0,
election_frequency: 2.0,
replication_lag: 100.0,
cpu_utilization: 0.5,
network_congestion: 0.3,
};
let features = state.to_features();
assert_eq!(features.len(), 7);
        assert!((features[0] - 0.1).abs() < 1e-6);
        assert!((features[1] - 0.5).abs() < 1e-6);
    }
#[test]
fn test_cluster_state_discretize() {
let state = ClusterState {
node_count: 10,
avg_throughput: 5000.0,
avg_latency: 50.0,
election_frequency: 2.0,
replication_lag: 100.0,
cpu_utilization: 0.5,
network_congestion: 0.3,
};
        let state_idx = state.discretize(5);
        // 7 features with 5 bins each, so the index must fit within 5^7 states.
        assert!(state_idx < 5usize.pow(7));
    }
#[test]
fn test_performance_reward_calculation() {
let state = ClusterState {
node_count: 10,
avg_throughput: 8000.0,
avg_latency: 20.0,
election_frequency: 1.0,
replication_lag: 50.0,
cpu_utilization: 0.4,
network_congestion: 0.2,
};
let weights = RewardWeights::default();
let reward = PerformanceReward::from_state(&state, &weights);
assert!(reward.total_reward > 0.0);
assert!(reward.throughput_score > 0.0);
assert!(reward.latency_score > 0.0);
assert!(reward.consistency_score > 0.0);
assert!(reward.efficiency_score > 0.0);
}
#[tokio::test]
async fn test_rl_optimizer_creation() {
let config = RLOptimizerConfig::default();
let optimizer = RLConsensusOptimizer::new(config);
assert!(optimizer.is_ok());
}
#[tokio::test]
async fn test_rl_optimizer_random_action() {
let config = RLOptimizerConfig::default();
let optimizer = RLConsensusOptimizer::new(config).unwrap();
let mut rng = optimizer.rng.write().await;
let action = optimizer.random_action(&mut rng).await.unwrap();
        assert_eq!(action.len(), 5);
        for param in &[
ConsensusParameter::ElectionTimeout,
ConsensusParameter::HeartbeatInterval,
ConsensusParameter::LogBatchSize,
ConsensusParameter::SnapshotThreshold,
ConsensusParameter::MaxInflightAppends,
] {
assert!(action.contains_key(param));
}
}
#[tokio::test]
async fn test_rl_optimizer_q_learning_update() {
let config = RLOptimizerConfig::default();
let optimizer = RLConsensusOptimizer::new(config).unwrap();
let state = ClusterState {
node_count: 10,
avg_throughput: 5000.0,
avg_latency: 50.0,
election_frequency: 2.0,
replication_lag: 100.0,
cpu_utilization: 0.5,
network_congestion: 0.3,
};
let mut action = HashMap::new();
action.insert(ConsensusParameter::ElectionTimeout, 5);
action.insert(ConsensusParameter::HeartbeatInterval, 3);
action.insert(ConsensusParameter::LogBatchSize, 7);
action.insert(ConsensusParameter::SnapshotThreshold, 4);
action.insert(ConsensusParameter::MaxInflightAppends, 6);
let next_state = state.clone();
let result = optimizer
.update_q_value(&state, &action, 1.0, &next_state)
.await;
assert!(result.is_ok());
}
#[tokio::test]
async fn test_rl_optimizer_experience_replay() {
        let config = RLOptimizerConfig {
            use_experience_replay: true,
            replay_buffer_size: 100,
            ..RLOptimizerConfig::default()
        };
let optimizer = RLConsensusOptimizer::new(config).unwrap();
let state = ClusterState {
node_count: 10,
avg_throughput: 5000.0,
avg_latency: 50.0,
election_frequency: 2.0,
replication_lag: 100.0,
cpu_utilization: 0.5,
network_congestion: 0.3,
};
let mut action = HashMap::new();
action.insert(ConsensusParameter::ElectionTimeout, 5);
optimizer
.store_experience(state.clone(), action, 1.0, state.clone(), false)
.await
.unwrap();
let buffer_size = optimizer.replay_buffer.read().await.len();
assert_eq!(buffer_size, 1);
}
#[tokio::test]
async fn test_rl_optimizer_epsilon_decay() {
let config = RLOptimizerConfig::default();
let optimizer = RLConsensusOptimizer::new(config).unwrap();
let initial_epsilon = *optimizer.epsilon.read().await;
optimizer.decay_epsilon().await.unwrap();
let decayed_epsilon = *optimizer.epsilon.read().await;
assert!(decayed_epsilon < initial_epsilon);
assert!(decayed_epsilon >= optimizer.config.min_epsilon);
}
#[tokio::test]
async fn test_rl_optimizer_get_statistics() {
let config = RLOptimizerConfig::default();
let optimizer = RLConsensusOptimizer::new(config).unwrap();
let stats = optimizer.get_statistics().await;
assert_eq!(stats.episode_count, 0);
assert_eq!(stats.total_steps, 0);
assert_eq!(stats.q_table_size, 0);
assert_eq!(stats.replay_buffer_size, 0);
}
}