use super::neural_network::NeuralNetwork;
use super::pattern_memory::OptimizationStrategy;
use crate::error::SparseResult;
use scirs2_core::random::{Rng, RngExt};
use std::collections::VecDeque;
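/// Reinforcement-learning algorithms the agent can be configured with.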
#[derive(Debug, Clone, Copy)]
pub enum RLAlgorithm {
DQN,
PolicyGradient,
ActorCritic,
PPO,
SAC,
}
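/// Reinforcement-learning agent that selects sparse-matrix optimization strategies.
/// Depending on the configured algorithm it maintains a Q-network (with an optional
/// target network for DQN) and/or separate policy and value networks for
/// actor-critic style algorithms.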
#[derive(Debug)]
#[allow(dead_code)]
pub(crate) struct RLAgent {
pub q_network: NeuralNetwork,
pub target_network: Option<NeuralNetwork>,
pub policy_network: Option<NeuralNetwork>,
pub value_network: Option<NeuralNetwork>,
pub algorithm: RLAlgorithm,
pub epsilon: f64,
pub learningrate: f64,
}
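/// A single state/action/reward transition stored for replay-based training.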
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub(crate) struct Experience {
pub state: Vec<f64>,
pub action: OptimizationStrategy,
pub reward: f64,
pub next_state: Vec<f64>,
pub done: bool,
pub timestamp: u64,
}
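/// Fixed-capacity replay buffer with optional per-experience priority weights.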
#[derive(Debug)]
pub(crate) struct ExperienceBuffer {
pub buffer: VecDeque<Experience>,
pub capacity: usize,
pub priority_weights: Vec<f64>,
}
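/// Performance measurements collected after executing an optimization strategy;
/// used to compute rewards for the agent.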
#[derive(Debug, Clone)]
pub struct PerformanceMetrics {
#[allow(dead_code)]
pub executiontime: f64,
#[allow(dead_code)]
pub cache_efficiency: f64,
#[allow(dead_code)]
pub simd_utilization: f64,
#[allow(dead_code)]
pub parallel_efficiency: f64,
#[allow(dead_code)]
pub memory_bandwidth: f64,
pub strategy_used: OptimizationStrategy,
}
impl RLAgent {
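    /// Creates an agent for the given state/action dimensions. DQN gets a target
    /// network cloned from the Q-network; ActorCritic, PPO, and SAC additionally
    /// get policy and value networks.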
pub fn new(
state_size: usize,
action_size: usize,
algorithm: RLAlgorithm,
learning_rate: f64,
epsilon: f64,
) -> Self {
let q_network = NeuralNetwork::new(state_size, 3, 64, action_size, 4);
let target_network = match algorithm {
RLAlgorithm::DQN => Some(q_network.clone()),
_ => None,
};
let (policy_network, value_network) = match algorithm {
RLAlgorithm::ActorCritic | RLAlgorithm::PPO | RLAlgorithm::SAC => {
let policy = NeuralNetwork::new(state_size, 2, 32, action_size, 4);
let value = NeuralNetwork::new(state_size, 2, 32, 1, 4);
(Some(policy), Some(value))
}
_ => (None, None),
};
Self {
q_network,
target_network,
policy_network,
value_network,
algorithm,
epsilon,
learningrate: learning_rate,
}
}
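    /// Selects a strategy for the given state. DQN uses epsilon-greedy
    /// exploration; all other algorithms act greedily on their policy.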
pub fn select_action(&self, state: &[f64]) -> OptimizationStrategy {
let mut rng = scirs2_core::random::thread_rng();
if matches!(self.algorithm, RLAlgorithm::DQN) && rng.random::<f64>() < self.epsilon {
self.random_action()
} else {
self.greedy_action(state)
}
}
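    /// Picks the highest-scoring action from the Q-network (DQN) or the policy
    /// network (policy-based algorithms), falling back to a random action when
    /// no policy network is available.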
fn greedy_action(&self, state: &[f64]) -> OptimizationStrategy {
match self.algorithm {
RLAlgorithm::DQN => {
let q_values = self.q_network.forward(state);
let best_action_idx = q_values
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.partial_cmp(b).expect("Operation failed"))
.map(|(idx, _)| idx)
.unwrap_or(0);
self.idx_to_strategy(best_action_idx)
}
RLAlgorithm::PolicyGradient
| RLAlgorithm::ActorCritic
| RLAlgorithm::PPO
| RLAlgorithm::SAC => {
if let Some(ref policy_network) = self.policy_network {
let action_probs = policy_network.forward(state);
let best_action_idx = action_probs
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.partial_cmp(b).expect("Operation failed"))
.map(|(idx, _)| idx)
.unwrap_or(0);
self.idx_to_strategy(best_action_idx)
} else {
self.random_action()
}
}
}
}
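    /// Samples one of the nine optimization strategies uniformly at random.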
fn random_action(&self) -> OptimizationStrategy {
let mut rng = scirs2_core::random::thread_rng();
let strategies = [
OptimizationStrategy::RowWiseCache,
OptimizationStrategy::ColumnWiseLocality,
OptimizationStrategy::BlockStructured,
OptimizationStrategy::DiagonalOptimized,
OptimizationStrategy::Hierarchical,
OptimizationStrategy::StreamingCompute,
OptimizationStrategy::SIMDVectorized,
OptimizationStrategy::ParallelWorkStealing,
OptimizationStrategy::AdaptiveHybrid,
];
strategies[rng.random_range(0..strategies.len())]
}
fn idx_to_strategy(&self, idx: usize) -> OptimizationStrategy {
match idx % 9 {
0 => OptimizationStrategy::RowWiseCache,
1 => OptimizationStrategy::ColumnWiseLocality,
2 => OptimizationStrategy::BlockStructured,
3 => OptimizationStrategy::DiagonalOptimized,
4 => OptimizationStrategy::Hierarchical,
5 => OptimizationStrategy::StreamingCompute,
6 => OptimizationStrategy::SIMDVectorized,
7 => OptimizationStrategy::ParallelWorkStealing,
_ => OptimizationStrategy::AdaptiveHybrid,
}
}
fn strategy_to_idx(&self, strategy: OptimizationStrategy) -> usize {
Self::strategy_to_idx_static(strategy)
}
fn strategy_to_idx_static(strategy: OptimizationStrategy) -> usize {
match strategy {
OptimizationStrategy::RowWiseCache => 0,
OptimizationStrategy::ColumnWiseLocality => 1,
OptimizationStrategy::BlockStructured => 2,
OptimizationStrategy::DiagonalOptimized => 3,
OptimizationStrategy::Hierarchical => 4,
OptimizationStrategy::StreamingCompute => 5,
OptimizationStrategy::SIMDVectorized => 6,
OptimizationStrategy::ParallelWorkStealing => 7,
OptimizationStrategy::AdaptiveHybrid => 8,
}
}
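    /// Trains the agent on a batch of experiences, dispatching to the update
    /// rule of the configured algorithm.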
pub fn train(&mut self, experiences: &[Experience]) -> SparseResult<()> {
if experiences.is_empty() {
return Ok(());
}
match self.algorithm {
RLAlgorithm::DQN => self.train_dqn(experiences),
RLAlgorithm::PolicyGradient => self.train_policy_gradient(experiences),
RLAlgorithm::ActorCritic => self.train_actor_critic(experiences),
RLAlgorithm::PPO => self.train_ppo(experiences),
RLAlgorithm::SAC => self.train_sac(experiences),
}
}
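    /// DQN update: regresses the Q-value of the taken action toward the TD target
    /// `reward + 0.99 * max_a Q_target(next_state, a)`; the 0.99 discount factor
    /// is fixed here.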
fn train_dqn(&mut self, experiences: &[Experience]) -> SparseResult<()> {
for experience in experiences {
let current_q_values = self.q_network.forward(&experience.state);
let action_idx = self.strategy_to_idx(experience.action);
let target = if experience.done {
experience.reward
} else if let Some(ref target_network) = self.target_network {
let next_q_values = target_network.forward(&experience.next_state);
let max_next_q = next_q_values
.iter()
.fold(f64::NEG_INFINITY, |a, &b| a.max(b));
                experience.reward + 0.99 * max_next_q
            } else {
experience.reward
};
let mut target_q_values = current_q_values;
if action_idx < target_q_values.len() {
target_q_values[action_idx] = target;
}
let (_, cache) = self.q_network.forward_with_cache(&experience.state);
let gradients =
self.q_network
.compute_gradients(&experience.state, &target_q_values, &cache);
self.q_network.update_weights(&gradients, self.learningrate);
}
Ok(())
}
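    /// Simple policy-gradient update: nudges the probability of the taken action
    /// in proportion to the observed reward.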
fn train_policy_gradient(&mut self, experiences: &[Experience]) -> SparseResult<()> {
let learning_rate = self.learningrate;
if let Some(ref mut policy_network) = self.policy_network {
for experience in experiences {
let action_probs = policy_network.forward(&experience.state);
let action_idx = Self::strategy_to_idx_static(experience.action);
let mut target_probs = action_probs;
if action_idx < target_probs.len() {
target_probs[action_idx] += learning_rate * experience.reward;
}
let (_, cache) = policy_network.forward_with_cache(&experience.state);
let gradients =
policy_network.compute_gradients(&experience.state, &target_probs, &cache);
policy_network.update_weights(&gradients, learning_rate);
}
}
Ok(())
}
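    /// Actor-critic update: the critic (value network) is regressed toward the TD
    /// target, and the actor (policy network) is nudged by the resulting
    /// advantage estimate.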
fn train_actor_critic(&mut self, experiences: &[Experience]) -> SparseResult<()> {
let learning_rate = self.learningrate;
for experience in experiences {
if let Some(ref mut value_network) = self.value_network {
let current_value = value_network.forward(&experience.state)[0];
let target_value = if experience.done {
experience.reward
} else {
let next_value = value_network.forward(&experience.next_state)[0];
experience.reward + 0.99 * next_value
};
let (_, cache) = value_network.forward_with_cache(&experience.state);
let gradients =
value_network.compute_gradients(&experience.state, &[target_value], &cache);
value_network.update_weights(&gradients, learning_rate);
if let Some(ref mut policy_network) = self.policy_network {
let advantage = target_value - current_value;
let action_probs = policy_network.forward(&experience.state);
let action_idx = Self::strategy_to_idx_static(experience.action);
let mut target_probs = action_probs;
if action_idx < target_probs.len() {
target_probs[action_idx] += learning_rate * advantage;
}
let (_, cache) = policy_network.forward_with_cache(&experience.state);
let gradients =
policy_network.compute_gradients(&experience.state, &target_probs, &cache);
policy_network.update_weights(&gradients, learning_rate);
}
}
}
Ok(())
}
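    /// PPO currently falls back to the plain actor-critic update; no clipped
    /// surrogate objective is implemented yet.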
fn train_ppo(&mut self, experiences: &[Experience]) -> SparseResult<()> {
self.train_actor_critic(experiences) }
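    /// SAC currently falls back to the plain actor-critic update; no entropy
    /// regularization is implemented yet.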
fn train_sac(&mut self, experiences: &[Experience]) -> SparseResult<()> {
self.train_actor_critic(experiences) }
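    /// Copies the Q-network parameters into the target network (hard update).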
pub fn update_target_network(&mut self) {
if let Some(ref mut target_network) = self.target_network {
let params = self.q_network.get_parameters();
            target_network.set_parameters(&params);
}
}
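    /// Multiplicatively decays the exploration rate, clamped to a floor of 0.01.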
pub fn decay_epsilon(&mut self, decay_rate: f64) {
self.epsilon *= decay_rate;
        self.epsilon = self.epsilon.max(0.01);
    }
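    /// Estimates the value of a state: the maximum Q-value for DQN, or the value
    /// network's output otherwise (0.0 if no value network exists).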
pub fn estimate_value(&self, state: &[f64]) -> f64 {
match self.algorithm {
RLAlgorithm::DQN => {
let q_values = self.q_network.forward(state);
q_values.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b))
}
_ => {
if let Some(ref value_network) = self.value_network {
value_network.forward(state)[0]
} else {
0.0
}
}
}
}
}
impl ExperienceBuffer {
pub fn new(capacity: usize) -> Self {
Self {
buffer: VecDeque::new(),
capacity,
priority_weights: Vec::new(),
}
}
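    /// Adds an experience, evicting the oldest entry once capacity is reached;
    /// new entries start with a priority weight of 1.0.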
pub fn add(&mut self, experience: Experience) {
if self.buffer.len() >= self.capacity {
self.buffer.pop_front();
if !self.priority_weights.is_empty() {
self.priority_weights.remove(0);
}
}
self.buffer.push_back(experience);
        self.priority_weights.push(1.0);
    }
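    /// Samples up to `batch_size` experiences uniformly at random (with replacement).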
pub fn sample(&self, batch_size: usize) -> Vec<Experience> {
let mut rng = scirs2_core::random::thread_rng();
let mut batch = Vec::new();
for _ in 0..batch_size.min(self.buffer.len()) {
let idx = rng.random_range(0..self.buffer.len());
if let Some(experience) = self.buffer.get(idx) {
batch.push(experience.clone());
}
}
batch
}
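    /// Samples experiences with probability proportional to their priority weights
    /// (roulette-wheel selection), falling back to uniform sampling when no
    /// weights are present.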
pub fn sample_prioritized(&self, batch_size: usize) -> Vec<Experience> {
if self.priority_weights.is_empty() {
return self.sample(batch_size);
}
let mut rng = scirs2_core::random::thread_rng();
let mut batch = Vec::new();
let total_weight: f64 = self.priority_weights.iter().sum();
for _ in 0..batch_size.min(self.buffer.len()) {
let mut weight_sum = 0.0;
let target = rng.random::<f64>() * total_weight;
for (idx, &weight) in self.priority_weights.iter().enumerate() {
weight_sum += weight;
if weight_sum >= target {
if let Some(experience) = self.buffer.get(idx) {
batch.push(experience.clone());
break;
}
}
}
}
batch
}
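    /// Sets the priority weight of the experience at `idx`, clamped to a minimum of 0.01.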
pub fn update_priority(&mut self, idx: usize, priority: f64) {
if idx < self.priority_weights.len() {
            self.priority_weights[idx] = priority.max(0.01);
        }
}
pub fn len(&self) -> usize {
self.buffer.len()
}
pub fn is_empty(&self) -> bool {
self.buffer.is_empty()
}
pub fn clear(&mut self) {
self.buffer.clear();
self.priority_weights.clear();
}
}
impl PerformanceMetrics {
pub fn new(
execution_time: f64,
cache_efficiency: f64,
simd_utilization: f64,
parallel_efficiency: f64,
memory_bandwidth: f64,
strategy_used: OptimizationStrategy,
) -> Self {
Self {
executiontime: execution_time,
cache_efficiency,
simd_utilization,
parallel_efficiency,
memory_bandwidth,
strategy_used,
}
}
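    /// Reward relative to a baseline: the relative execution-time improvement
    /// weighted by 10, plus the mean of cache/SIMD/parallel efficiency weighted by 5.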
pub fn compute_reward(&self, baseline_time: f64) -> f64 {
let time_improvement = (baseline_time - self.executiontime) / baseline_time;
let efficiency_score =
(self.cache_efficiency + self.simd_utilization + self.parallel_efficiency) / 3.0;
time_improvement * 10.0 + efficiency_score * 5.0
}
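    /// Scalar performance score: the average of an inverse-execution-time score
    /// and the mean of the four efficiency metrics.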
pub fn performance_score(&self) -> f64 {
        let time_score = 1.0 / (1.0 + self.executiontime);
        let efficiency_score = (self.cache_efficiency
+ self.simd_utilization
+ self.parallel_efficiency
+ self.memory_bandwidth)
/ 4.0;
(time_score + efficiency_score) / 2.0
}
}