// File: oxirs_stream/reinforcement_learning.rs
//! # Reinforcement Learning for Stream Processing Optimization
//!
//! This module provides reinforcement learning agents that can automatically optimize
//! stream processing parameters in real-time, learning from system feedback to improve
//! performance metrics like throughput, latency, and resource utilization.
//!
//! ## Features
//! - Q-Learning and Deep Q-Networks (DQN) for discrete action spaces
//! - Policy gradient methods (REINFORCE, Actor-Critic) for continuous actions
//! - Multi-armed bandit algorithms for hyperparameter tuning
//! - Experience replay for stable learning
//! - Adaptive exploration strategies (ε-greedy, UCB, Thompson sampling)
//! - Reward shaping for complex optimization objectives
//!
//! ## Example Usage
//! ```rust,ignore
//! use oxirs_stream::reinforcement_learning::{RLAgent, RLConfig, RLAlgorithm};
//!
//! let config = RLConfig {
//!     algorithm: RLAlgorithm::DQN,
//!     learning_rate: 0.001,
//!     discount_factor: 0.99,
//!     ..Default::default()
//! };
//!
//! let mut agent = RLAgent::new(config)?;
//! let action = agent.select_action(&state).await?;
//! let reward = execute_action(action);
//! agent.learn(&state, action, reward, &next_state).await?;
//! ```
32use anyhow::{anyhow, Result};
33use scirs2_core::ndarray_ext::{Array1, Array2};
34use scirs2_core::random::{Random, Rng};
35use serde::{Deserialize, Serialize};
36use std::collections::{HashMap, VecDeque};
37use std::sync::Arc;
38use tokio::sync::{Mutex, RwLock};
39use tracing::{debug, info};
40
/// Reinforcement learning algorithm types selectable via [`RLConfig`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum RLAlgorithm {
    /// Q-Learning (tabular, off-policy TD; discrete actions)
    QLearning,
    /// Deep Q-Network (neural approximation of the Q-function)
    DQN,
    /// SARSA (tabular, on-policy TD)
    SARSA,
    /// Actor-Critic (policy gradient with a learned value baseline)
    ActorCritic,
    /// REINFORCE (Monte-Carlo policy gradient)
    REINFORCE,
    /// Proximal Policy Optimization
    PPO,
    /// Multi-Armed Bandit with Upper Confidence Bound exploration
    UCB,
    /// Thompson Sampling bandit
    ThompsonSampling,
    /// Epsilon-Greedy Bandit
    EpsilonGreedy,
}
63
/// State representation for stream processing.
///
/// The six named metrics plus `features` form the observation fed to the
/// agent; see [`State::to_vector`] for the flattened layout.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct State {
    /// Current throughput (events/second)
    pub throughput: f64,
    /// Current latency (milliseconds)
    pub latency_ms: f64,
    /// CPU utilization (0-1)
    pub cpu_utilization: f64,
    /// Memory utilization (0-1)
    pub memory_utilization: f64,
    /// Queue depth (number of pending items)
    pub queue_depth: usize,
    /// Error rate (0-1)
    pub error_rate: f64,
    /// Additional domain-specific features appended after the core metrics
    pub features: Vec<f64>,
}
82
83impl State {
84    /// Convert state to feature vector
85    pub fn to_vector(&self) -> Vec<f64> {
86        let mut vec = vec![
87            self.throughput,
88            self.latency_ms,
89            self.cpu_utilization,
90            self.memory_utilization,
91            self.queue_depth as f64,
92            self.error_rate,
93        ];
94        vec.extend(&self.features);
95        vec
96    }
97
98    /// Get state dimension
99    pub fn dimension(&self) -> usize {
100        6 + self.features.len()
101    }
102}
103
/// Action representation (can be discrete or continuous).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Action {
    /// Discrete action, identified by its index in `[0, n_actions)`
    Discrete(usize),
    /// Continuous action, represented as a parameter vector
    Continuous(Vec<f64>),
}
112
113impl Action {
114    /// Get action as index (for discrete actions)
115    pub fn as_index(&self) -> Option<usize> {
116        match self {
117            Action::Discrete(idx) => Some(*idx),
118            _ => None,
119        }
120    }
121
122    /// Get action as vector (for continuous actions)
123    pub fn as_vector(&self) -> Option<&[f64]> {
124        match self {
125            Action::Continuous(vec) => Some(vec),
126            _ => None,
127        }
128    }
129}
130
/// Experience tuple `(s, a, r, s', done)` stored in the replay buffer.
#[derive(Debug, Clone)]
pub struct Experience {
    /// State the action was taken in
    pub state: State,
    /// Action taken
    pub action: Action,
    /// Reward received after taking the action
    pub reward: f64,
    /// State observed after the action
    pub next_state: State,
    /// Whether the episode terminated at this transition
    pub done: bool,
}
145
/// RL agent configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RLConfig {
    /// RL algorithm to use
    pub algorithm: RLAlgorithm,
    /// Learning rate (step size for value updates)
    pub learning_rate: f64,
    /// Discount factor (gamma) for future rewards
    pub discount_factor: f64,
    /// Initial exploration rate (epsilon) for ε-greedy
    pub epsilon: f64,
    /// Multiplicative epsilon decay applied after each learning step
    pub epsilon_decay: f64,
    /// Lower bound on epsilon after decay
    pub epsilon_min: f64,
    /// Experience replay buffer capacity (oldest entries evicted first)
    pub replay_buffer_size: usize,
    /// Batch size sampled from the replay buffer per learning update
    pub batch_size: usize,
    /// Target network update frequency, in updates (for DQN)
    pub target_update_freq: usize,
    /// Number of discrete actions (if applicable)
    pub n_actions: usize,
    /// Hidden-layer widths of the neural network
    pub hidden_units: Vec<usize>,
    /// Enable prioritized experience replay
    pub prioritized_replay: bool,
    /// UCB exploration constant (larger = more exploration)
    pub ucb_c: f64,
}
176
177impl Default for RLConfig {
178    fn default() -> Self {
179        Self {
180            algorithm: RLAlgorithm::DQN,
181            learning_rate: 0.001,
182            discount_factor: 0.99,
183            epsilon: 1.0,
184            epsilon_decay: 0.995,
185            epsilon_min: 0.01,
186            replay_buffer_size: 10000,
187            batch_size: 32,
188            target_update_freq: 100,
189            n_actions: 10,
190            hidden_units: vec![64, 64],
191            prioritized_replay: false,
192            ucb_c: 2.0,
193        }
194    }
195}
196
/// Q-table for tabular RL: maps a discretized state key (see
/// `RLAgent::state_to_key`) to one Q-value per discrete action.
type QTable = HashMap<String, Vec<f64>>;
199
/// Neural network weights (simplified fully-connected MLP).
///
/// `weights[i]` and `biases[i]` describe layer `i`; layer counts and shapes
/// are fixed at construction in [`NeuralNetwork::new`].
#[derive(Debug, Clone)]
pub struct NeuralNetwork {
    /// Per-layer weight matrices, shape `(in_dim, out_dim)`
    pub weights: Vec<Array2<f64>>,
    /// Per-layer bias vectors, length `out_dim`
    pub biases: Vec<Array1<f64>>,
}
208
209impl NeuralNetwork {
210    /// Create a new neural network
211    pub fn new(
212        input_dim: usize,
213        hidden_dims: &[usize],
214        output_dim: usize,
215        rng: &mut Random,
216    ) -> Self {
217        let mut weights = Vec::new();
218        let mut biases = Vec::new();
219
220        let mut dims = vec![input_dim];
221        dims.extend(hidden_dims);
222        dims.push(output_dim);
223
224        for i in 0..dims.len() - 1 {
225            let w = Self::init_weights(dims[i], dims[i + 1], rng);
226            let b = Array1::zeros(dims[i + 1]);
227            weights.push(w);
228            biases.push(b);
229        }
230
231        Self { weights, biases }
232    }
233
234    /// Initialize weights with Xavier/Glorot initialization
235    fn init_weights(input_dim: usize, output_dim: usize, rng: &mut Random) -> Array2<f64> {
236        let scale = (2.0 / (input_dim + output_dim) as f64).sqrt();
237        let values: Vec<f64> = (0..input_dim * output_dim)
238            .map(|_| (rng.random::<f64>() * 2.0 - 1.0) * scale)
239            .collect();
240        Array2::from_shape_vec((input_dim, output_dim), values)
241            .expect("shape and vector length match")
242    }
243
244    /// Forward pass
245    pub fn forward(&self, input: &Array1<f64>) -> Array1<f64> {
246        let mut activation = input.clone();
247
248        for (w, b) in self.weights.iter().zip(&self.biases) {
249            // Linear transform
250            activation = activation.dot(w) + b;
251
252            // ReLU activation (except last layer)
253            if w != self
254                .weights
255                .last()
256                .expect("collection validated to be non-empty")
257            {
258                activation.mapv_inplace(|x| x.max(0.0));
259            }
260        }
261
262        activation
263    }
264
265    /// Update weights (simplified gradient descent)
266    pub fn update(&mut self, gradient_scale: f64, learning_rate: f64) {
267        for w in &mut self.weights {
268            w.mapv_inplace(|x| x - learning_rate * gradient_scale);
269        }
270    }
271}
272
/// RL Agent statistics, updated as the agent learns.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RLStats {
    /// Total learning steps taken (one per `learn` call)
    pub total_steps: u64,
    /// Total episodes completed (one per `end_episode` call)
    pub total_episodes: u64,
    /// Running average reward per episode
    pub avg_reward_per_episode: f64,
    /// Current epsilon (exploration rate), decayed each step
    pub current_epsilon: f64,
    /// Total reward accumulated across all steps
    pub total_reward: f64,
    /// Average Q-value
    pub avg_q_value: f64,
    /// Running average loss (for neural network methods)
    pub avg_loss: f64,
}
291
292impl Default for RLStats {
293    fn default() -> Self {
294        Self {
295            total_steps: 0,
296            total_episodes: 0,
297            avg_reward_per_episode: 0.0,
298            current_epsilon: 1.0,
299            total_reward: 0.0,
300            avg_q_value: 0.0,
301            avg_loss: 0.0,
302        }
303    }
304}
305
/// Main RL Agent for stream optimization.
///
/// All mutable state is behind `tokio` async locks so the agent can be
/// shared across tasks; each field is locked independently.
pub struct RLAgent {
    config: RLConfig,
    /// Q-table for tabular methods (Q-Learning, SARSA)
    q_table: Arc<RwLock<QTable>>,
    /// Online Q-network for DQN (populated by `initialize_networks`)
    q_network: Arc<RwLock<Option<NeuralNetwork>>>,
    /// Target network for DQN, refreshed every `target_update_freq` updates
    target_network: Arc<RwLock<Option<NeuralNetwork>>>,
    /// Experience replay buffer (bounded FIFO of transitions)
    replay_buffer: Arc<RwLock<VecDeque<Experience>>>,
    /// Per-action selection counts (for bandits and UCB)
    action_counts: Arc<RwLock<Vec<u64>>>,
    /// Per-action cumulative rewards (for bandits)
    action_rewards: Arc<RwLock<Vec<f64>>>,
    /// Statistics
    stats: Arc<RwLock<RLStats>>,
    /// Random number generator (Mutex because `Random` is not Sync)
    #[allow(clippy::arc_with_non_send_sync)]
    rng: Arc<Mutex<Random>>,
    /// Reward accumulated in the current episode
    episode_reward: Arc<RwLock<f64>>,
    /// Count of DQN updates, used for target-network refresh timing
    update_counter: Arc<RwLock<usize>>,
}
331
332impl RLAgent {
333    /// Create a new RL agent
334    #[allow(clippy::arc_with_non_send_sync)]
335    pub fn new(config: RLConfig) -> Result<Self> {
336        let action_counts = vec![0u64; config.n_actions];
337        let action_rewards = vec![0.0; config.n_actions];
338        let buffer_size = config.replay_buffer_size;
339        let epsilon = config.epsilon;
340
341        Ok(Self {
342            config,
343            q_table: Arc::new(RwLock::new(HashMap::new())),
344            q_network: Arc::new(RwLock::new(None)),
345            target_network: Arc::new(RwLock::new(None)),
346            replay_buffer: Arc::new(RwLock::new(VecDeque::with_capacity(buffer_size))),
347            action_counts: Arc::new(RwLock::new(action_counts)),
348            action_rewards: Arc::new(RwLock::new(action_rewards)),
349            stats: Arc::new(RwLock::new(RLStats {
350                current_epsilon: epsilon,
351                ..Default::default()
352            })),
353            rng: Arc::new(Mutex::new(Random::default())),
354            episode_reward: Arc::new(RwLock::new(0.0)),
355            update_counter: Arc::new(RwLock::new(0)),
356        })
357    }
358
359    /// Initialize neural networks (for DQN/Actor-Critic)
360    pub async fn initialize_networks(&mut self, state_dim: usize) -> Result<()> {
361        if matches!(
362            self.config.algorithm,
363            RLAlgorithm::DQN | RLAlgorithm::ActorCritic | RLAlgorithm::PPO
364        ) {
365            let mut rng = self.rng.lock().await;
366
367            let q_net = NeuralNetwork::new(
368                state_dim,
369                &self.config.hidden_units,
370                self.config.n_actions,
371                &mut rng,
372            );
373
374            let target_net = NeuralNetwork::new(
375                state_dim,
376                &self.config.hidden_units,
377                self.config.n_actions,
378                &mut rng,
379            );
380
381            *self.q_network.write().await = Some(q_net);
382            *self.target_network.write().await = Some(target_net);
383
384            info!(
385                "Initialized neural networks with state_dim={}, n_actions={}",
386                state_dim, self.config.n_actions
387            );
388        }
389
390        Ok(())
391    }
392
393    /// Select an action given current state
394    pub async fn select_action(&self, state: &State) -> Result<Action> {
395        match self.config.algorithm {
396            RLAlgorithm::QLearning | RLAlgorithm::SARSA => {
397                self.select_action_q_learning(state).await
398            }
399            RLAlgorithm::DQN => self.select_action_dqn(state).await,
400            RLAlgorithm::UCB => self.select_action_ucb().await,
401            RLAlgorithm::ThompsonSampling => self.select_action_thompson().await,
402            RLAlgorithm::EpsilonGreedy => self.select_action_epsilon_greedy().await,
403            _ => {
404                // Default to ε-greedy
405                self.select_action_epsilon_greedy().await
406            }
407        }
408    }
409
410    /// ε-greedy action selection for Q-learning
411    async fn select_action_q_learning(&self, state: &State) -> Result<Action> {
412        let stats = self.stats.read().await;
413        let epsilon = stats.current_epsilon;
414        drop(stats);
415
416        let mut rng = self.rng.lock().await;
417
418        if rng.random::<f64>() < epsilon {
419            // Explore: random action
420            let action_idx = rng.random_range(0..self.config.n_actions);
421            Ok(Action::Discrete(action_idx))
422        } else {
423            // Exploit: best action from Q-table
424            let state_key = self.state_to_key(state);
425            let q_table = self.q_table.read().await;
426
427            let q_values = q_table
428                .get(&state_key)
429                .cloned()
430                .unwrap_or_else(|| vec![0.0; self.config.n_actions]);
431
432            let best_action = q_values
433                .iter()
434                .enumerate()
435                .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
436                .map(|(idx, _)| idx)
437                .unwrap_or(0);
438
439            Ok(Action::Discrete(best_action))
440        }
441    }
442
443    /// Action selection for DQN
444    async fn select_action_dqn(&self, state: &State) -> Result<Action> {
445        let stats = self.stats.read().await;
446        let epsilon = stats.current_epsilon;
447        drop(stats);
448
449        let mut rng = self.rng.lock().await;
450
451        if rng.random::<f64>() < epsilon {
452            let action_idx = rng.random_range(0..self.config.n_actions);
453            Ok(Action::Discrete(action_idx))
454        } else {
455            drop(rng);
456
457            let q_network = self.q_network.read().await;
458            if let Some(ref network) = *q_network {
459                let state_vec = Array1::from_vec(state.to_vector());
460                let q_values = network.forward(&state_vec);
461
462                let best_action = q_values
463                    .iter()
464                    .enumerate()
465                    .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
466                    .map(|(idx, _)| idx)
467                    .unwrap_or(0);
468
469                Ok(Action::Discrete(best_action))
470            } else {
471                Err(anyhow!("Q-network not initialized"))
472            }
473        }
474    }
475
476    /// UCB action selection
477    async fn select_action_ucb(&self) -> Result<Action> {
478        let action_counts = self.action_counts.read().await;
479        let action_rewards = self.action_rewards.read().await;
480        let stats = self.stats.read().await;
481        let total_steps = stats.total_steps;
482
483        let mut ucb_values = Vec::with_capacity(self.config.n_actions);
484
485        for i in 0..self.config.n_actions {
486            let count = action_counts[i];
487            let avg_reward = if count > 0 {
488                action_rewards[i] / count as f64
489            } else {
490                f64::INFINITY // Prioritize unexplored actions
491            };
492
493            let exploration_bonus = if count > 0 {
494                self.config.ucb_c * ((total_steps as f64).ln() / count as f64).sqrt()
495            } else {
496                f64::INFINITY
497            };
498
499            ucb_values.push(avg_reward + exploration_bonus);
500        }
501
502        let best_action = ucb_values
503            .iter()
504            .enumerate()
505            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
506            .map(|(idx, _)| idx)
507            .unwrap_or(0);
508
509        Ok(Action::Discrete(best_action))
510    }
511
512    /// Thompson sampling action selection
513    async fn select_action_thompson(&self) -> Result<Action> {
514        let action_counts = self.action_counts.read().await;
515        let action_rewards = self.action_rewards.read().await;
516        let mut rng = self.rng.lock().await;
517
518        let mut sampled_values = Vec::with_capacity(self.config.n_actions);
519
520        for i in 0..self.config.n_actions {
521            let count = action_counts[i];
522            let sum_reward = action_rewards[i];
523
524            // Beta distribution sampling (simplified)
525            let alpha = sum_reward + 1.0;
526            let beta = (count as f64 - sum_reward).max(0.0) + 1.0;
527
528            // Simplified beta sampling
529            let sample = rng.random::<f64>().powf(1.0 / alpha)
530                * (1.0 - rng.random::<f64>()).powf(1.0 / beta);
531            sampled_values.push(sample);
532        }
533
534        let best_action = sampled_values
535            .iter()
536            .enumerate()
537            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
538            .map(|(idx, _)| idx)
539            .unwrap_or(0);
540
541        Ok(Action::Discrete(best_action))
542    }
543
544    /// Epsilon-greedy bandit action selection
545    async fn select_action_epsilon_greedy(&self) -> Result<Action> {
546        let stats = self.stats.read().await;
547        let epsilon = stats.current_epsilon;
548        drop(stats);
549
550        let mut rng = self.rng.lock().await;
551
552        if rng.random::<f64>() < epsilon {
553            let action_idx = rng.random_range(0..self.config.n_actions);
554            Ok(Action::Discrete(action_idx))
555        } else {
556            drop(rng);
557
558            let action_counts = self.action_counts.read().await;
559            let action_rewards = self.action_rewards.read().await;
560
561            let best_action = (0..self.config.n_actions)
562                .max_by(|&a, &b| {
563                    let avg_a = if action_counts[a] > 0 {
564                        action_rewards[a] / action_counts[a] as f64
565                    } else {
566                        0.0
567                    };
568                    let avg_b = if action_counts[b] > 0 {
569                        action_rewards[b] / action_counts[b] as f64
570                    } else {
571                        0.0
572                    };
573                    avg_a
574                        .partial_cmp(&avg_b)
575                        .unwrap_or(std::cmp::Ordering::Equal)
576                })
577                .unwrap_or(0);
578
579            Ok(Action::Discrete(best_action))
580        }
581    }
582
583    /// Learn from experience (update model)
584    pub async fn learn(
585        &mut self,
586        state: &State,
587        action: Action,
588        reward: f64,
589        next_state: &State,
590    ) -> Result<()> {
591        // Add to replay buffer
592        let experience = Experience {
593            state: state.clone(),
594            action: action.clone(),
595            reward,
596            next_state: next_state.clone(),
597            done: false,
598        };
599
600        let mut replay_buffer = self.replay_buffer.write().await;
601        replay_buffer.push_back(experience);
602
603        if replay_buffer.len() > self.config.replay_buffer_size {
604            replay_buffer.pop_front();
605        }
606        drop(replay_buffer);
607
608        // Update statistics
609        *self.episode_reward.write().await += reward;
610        let mut stats = self.stats.write().await;
611        stats.total_steps += 1;
612        stats.total_reward += reward;
613
614        // Update action counts for bandits
615        if let Action::Discrete(idx) = action {
616            let mut counts = self.action_counts.write().await;
617            let mut rewards = self.action_rewards.write().await;
618            counts[idx] += 1;
619            rewards[idx] += reward;
620        }
621
622        // Perform learning update
623        match self.config.algorithm {
624            RLAlgorithm::QLearning | RLAlgorithm::SARSA => {
625                drop(stats);
626                self.update_q_learning(state, &action, reward, next_state)
627                    .await?;
628            }
629            RLAlgorithm::DQN => {
630                drop(stats);
631                self.update_dqn().await?;
632            }
633            _ => {
634                // Bandits don't need explicit update beyond counting
635            }
636        }
637
638        // Decay epsilon
639        let mut stats = self.stats.write().await;
640        stats.current_epsilon =
641            (stats.current_epsilon * self.config.epsilon_decay).max(self.config.epsilon_min);
642
643        Ok(())
644    }
645
646    /// Q-learning update
647    async fn update_q_learning(
648        &self,
649        state: &State,
650        action: &Action,
651        reward: f64,
652        next_state: &State,
653    ) -> Result<()> {
654        if let Action::Discrete(action_idx) = action {
655            let state_key = self.state_to_key(state);
656            let next_state_key = self.state_to_key(next_state);
657
658            let mut q_table = self.q_table.write().await;
659
660            // Get max next Q value first
661            let max_next_q = {
662                let next_q_values = q_table
663                    .entry(next_state_key)
664                    .or_insert_with(|| vec![0.0; self.config.n_actions]);
665                next_q_values
666                    .iter()
667                    .copied()
668                    .fold(f64::NEG_INFINITY, f64::max)
669            };
670
671            // Now update current Q value
672            let q_values = q_table
673                .entry(state_key.clone())
674                .or_insert_with(|| vec![0.0; self.config.n_actions]);
675
676            // Q-learning update: Q(s,a) <- Q(s,a) + α[r + γ*max(Q(s',a')) - Q(s,a)]
677            let current_q = q_values[*action_idx];
678            let td_target = reward + self.config.discount_factor * max_next_q;
679            let td_error = td_target - current_q;
680
681            q_values[*action_idx] += self.config.learning_rate * td_error;
682
683            debug!(
684                "Q-learning update: state={}, action={}, Q={:.4}",
685                state_key, action_idx, q_values[*action_idx]
686            );
687        }
688
689        Ok(())
690    }
691
692    /// DQN update
693    async fn update_dqn(&self) -> Result<()> {
694        let replay_buffer = self.replay_buffer.read().await;
695
696        if replay_buffer.len() < self.config.batch_size {
697            return Ok(()); // Not enough samples
698        }
699
700        // Sample random batch
701        let batch_indices: Vec<usize> = {
702            let mut rng = self.rng.lock().await;
703            (0..self.config.batch_size)
704                .map(|_| rng.random_range(0..replay_buffer.len()))
705                .collect()
706        };
707
708        // Clone the batch experiences before dropping the lock
709        let batch: Vec<Experience> = batch_indices
710            .iter()
711            .map(|&i| replay_buffer[i].clone())
712            .collect();
713        drop(replay_buffer);
714
715        // Compute TD errors and update network
716        let mut total_loss = 0.0;
717
718        let q_network = self.q_network.read().await;
719        let target_network = self.target_network.read().await;
720
721        if let (Some(ref q_net), Some(ref target_net)) = (&*q_network, &*target_network) {
722            for exp in &batch {
723                let state_vec = Array1::from_vec(exp.state.to_vector());
724                let next_state_vec = Array1::from_vec(exp.next_state.to_vector());
725
726                let q_values = q_net.forward(&state_vec);
727                let next_q_values = target_net.forward(&next_state_vec);
728
729                let max_next_q = next_q_values
730                    .iter()
731                    .copied()
732                    .fold(f64::NEG_INFINITY, f64::max);
733
734                if let Action::Discrete(action_idx) = exp.action {
735                    let td_target = exp.reward + self.config.discount_factor * max_next_q;
736                    let td_error = td_target - q_values[action_idx];
737                    total_loss += td_error * td_error;
738                }
739            }
740        }
741        drop(q_network);
742        drop(target_network);
743
744        // Update network (simplified)
745        let mut q_network = self.q_network.write().await;
746        if let Some(ref mut network) = *q_network {
747            let gradient_scale = total_loss / self.config.batch_size as f64;
748            network.update(gradient_scale, self.config.learning_rate);
749        }
750        drop(q_network);
751
752        // Update target network periodically
753        let mut counter = self.update_counter.write().await;
754        *counter += 1;
755
756        if *counter % self.config.target_update_freq == 0 {
757            let q_net = self.q_network.read().await;
758            if let Some(ref network) = *q_net {
759                *self.target_network.write().await = Some(network.clone());
760                debug!("Updated target network at step {}", *counter);
761            }
762        }
763
764        // Update stats
765        let mut stats = self.stats.write().await;
766        stats.avg_loss = (stats.avg_loss * (stats.total_steps - 1) as f64 + total_loss)
767            / stats.total_steps as f64;
768
769        Ok(())
770    }
771
772    /// End episode and record statistics
773    pub async fn end_episode(&mut self) -> Result<()> {
774        let episode_reward = *self.episode_reward.read().await;
775        *self.episode_reward.write().await = 0.0;
776
777        let mut stats = self.stats.write().await;
778        stats.total_episodes += 1;
779        stats.avg_reward_per_episode =
780            (stats.avg_reward_per_episode * (stats.total_episodes - 1) as f64 + episode_reward)
781                / stats.total_episodes as f64;
782
783        info!(
784            "Episode {} complete: reward={:.2}, avg_reward={:.2}",
785            stats.total_episodes, episode_reward, stats.avg_reward_per_episode
786        );
787
788        Ok(())
789    }
790
791    /// Convert state to string key for Q-table
792    fn state_to_key(&self, state: &State) -> String {
793        // Discretize continuous state for tabular methods
794        format!(
795            "{:.0}_{:.0}_{:.2}_{:.2}_{}_{ :.2}",
796            (state.throughput / 1000.0).round(),
797            (state.latency_ms / 10.0).round(),
798            (state.cpu_utilization * 10.0).round() / 10.0,
799            (state.memory_utilization * 10.0).round() / 10.0,
800            state.queue_depth / 100,
801            (state.error_rate * 100.0).round() / 100.0,
802        )
803    }
804
805    /// Get RL statistics
806    pub async fn get_stats(&self) -> RLStats {
807        self.stats.read().await.clone()
808    }
809
810    /// Get current epsilon
811    pub async fn get_epsilon(&self) -> f64 {
812        self.stats.read().await.current_epsilon
813    }
814
815    /// Set epsilon (for exploration control)
816    pub async fn set_epsilon(&mut self, epsilon: f64) {
817        self.stats.write().await.current_epsilon = epsilon.clamp(0.0, 1.0);
818    }
819
820    /// Export policy for deployment
821    pub async fn export_policy(&self) -> Result<String> {
822        let policy = match self.config.algorithm {
823            RLAlgorithm::QLearning | RLAlgorithm::SARSA => {
824                let q_table = self.q_table.read().await;
825                serde_json::json!({
826                    "algorithm": "Q-Learning",
827                    "q_table": q_table.iter().take(10).collect::<HashMap<_, _>>(), // Sample
828                })
829            }
830            _ => {
831                let stats = self.get_stats().await;
832                serde_json::json!({
833                    "algorithm": format!("{:?}", self.config.algorithm),
834                    "stats": stats,
835                })
836            }
837        };
838
839        Ok(serde_json::to_string_pretty(&policy)?)
840    }
841}
842
843#[cfg(test)]
844mod tests {
845    use super::*;
846
    /// Build a fixed, representative stream-processing state for tests.
    fn create_test_state() -> State {
        State {
            throughput: 10000.0,
            latency_ms: 5.0,
            cpu_utilization: 0.5,
            memory_utilization: 0.6,
            queue_depth: 100,
            error_rate: 0.01,
            features: vec![],
        }
    }
858
    /// Agent construction with default config should succeed.
    #[tokio::test]
    async fn test_rl_agent_creation() {
        let config = RLConfig::default();
        let agent = RLAgent::new(config);
        assert!(agent.is_ok());
    }
865
    /// Q-learning selection should return a discrete action within range.
    #[tokio::test]
    async fn test_q_learning_action_selection() {
        let config = RLConfig {
            algorithm: RLAlgorithm::QLearning,
            n_actions: 5,
            ..Default::default()
        };

        let agent = RLAgent::new(config).unwrap();
        let state = create_test_state();

        let action = agent.select_action(&state).await;
        assert!(action.is_ok());

        // Any discrete action must be a valid index into the 5-action space.
        if let Action::Discrete(idx) = action.unwrap() {
            assert!(idx < 5);
        }
    }
884
    /// After network initialization, DQN action selection should succeed.
    #[tokio::test]
    async fn test_dqn_initialization() {
        let config = RLConfig {
            algorithm: RLAlgorithm::DQN,
            n_actions: 10,
            hidden_units: vec![32, 32],
            ..Default::default()
        };

        let mut agent = RLAgent::new(config).unwrap();
        let state = create_test_state();

        // DQN requires explicit network initialization before selection.
        agent.initialize_networks(state.dimension()).await.unwrap();

        let action = agent.select_action(&state).await;
        assert!(action.is_ok());
    }
902
    /// UCB selection should succeed even with no prior action statistics.
    #[tokio::test]
    async fn test_ucb_action_selection() {
        let config = RLConfig {
            algorithm: RLAlgorithm::UCB,
            n_actions: 5,
            ..Default::default()
        };

        let agent = RLAgent::new(config).unwrap();
        let action = agent.select_action_ucb().await;
        assert!(action.is_ok());
    }
915
    /// A single learn() call should record one step and its reward.
    #[tokio::test]
    async fn test_learning_update() {
        let config = RLConfig {
            algorithm: RLAlgorithm::QLearning,
            n_actions: 3,
            ..Default::default()
        };

        let mut agent = RLAgent::new(config).unwrap();
        let state = create_test_state();
        let action = Action::Discrete(1);
        let reward = 1.0;
        let next_state = create_test_state();

        let result = agent.learn(&state, action, reward, &next_state).await;
        assert!(result.is_ok());

        let stats = agent.get_stats().await;
        assert_eq!(stats.total_steps, 1);
        assert_eq!(stats.total_reward, 1.0);
    }
937
938    #[tokio::test]
939    async fn test_epsilon_decay() {
940        let config = RLConfig {
941            epsilon: 1.0,
942            epsilon_decay: 0.9,
943            epsilon_min: 0.1,
944            ..Default::default()
945        };
946
947        let mut agent = RLAgent::new(config).unwrap();
948        let initial_epsilon = agent.get_epsilon().await;
949
950        let state = create_test_state();
951        for _ in 0..10 {
952            agent
953                .learn(&state, Action::Discrete(0), 0.0, &state)
954                .await
955                .unwrap();
956        }
957
958        let final_epsilon = agent.get_epsilon().await;
959        assert!(final_epsilon < initial_epsilon);
960        assert!(final_epsilon >= 0.1);
961    }
962
963    #[tokio::test]
964    async fn test_episode_management() {
965        let config = RLConfig::default();
966        let mut agent = RLAgent::new(config).unwrap();
967
968        let state = create_test_state();
969        agent
970            .learn(&state, Action::Discrete(0), 1.0, &state)
971            .await
972            .unwrap();
973        agent
974            .learn(&state, Action::Discrete(1), 2.0, &state)
975            .await
976            .unwrap();
977
978        agent.end_episode().await.unwrap();
979
980        let stats = agent.get_stats().await;
981        assert_eq!(stats.total_episodes, 1);
982        assert!(stats.avg_reward_per_episode > 0.0);
983    }
984
985    #[tokio::test]
986    async fn test_replay_buffer() {
987        let config = RLConfig {
988            replay_buffer_size: 5,
989            ..Default::default()
990        };
991
992        let mut agent = RLAgent::new(config).unwrap();
993        let state = create_test_state();
994
995        for i in 0..10 {
996            agent
997                .learn(&state, Action::Discrete(0), i as f64, &state)
998                .await
999                .unwrap();
1000        }
1001
1002        let buffer = agent.replay_buffer.read().await;
1003        assert_eq!(buffer.len(), 5); // Should not exceed buffer size
1004    }
1005
1006    #[tokio::test]
1007    async fn test_export_policy() {
1008        let config = RLConfig {
1009            algorithm: RLAlgorithm::QLearning,
1010            ..Default::default()
1011        };
1012
1013        let mut agent = RLAgent::new(config).unwrap();
1014        let state = create_test_state();
1015
1016        agent
1017            .learn(&state, Action::Discrete(0), 1.0, &state)
1018            .await
1019            .unwrap();
1020
1021        let export = agent.export_policy().await;
1022        assert!(export.is_ok());
1023        assert!(export.unwrap().contains("algorithm"));
1024    }
1025
1026    #[tokio::test]
1027    async fn test_thompson_sampling() {
1028        let config = RLConfig {
1029            algorithm: RLAlgorithm::ThompsonSampling,
1030            n_actions: 5,
1031            ..Default::default()
1032        };
1033
1034        let agent = RLAgent::new(config).unwrap();
1035        let action = agent.select_action_thompson().await;
1036        assert!(action.is_ok());
1037    }
1038
1039    #[tokio::test]
1040    async fn test_multiple_episodes() {
1041        let config = RLConfig {
1042            algorithm: RLAlgorithm::QLearning,
1043            n_actions: 3,
1044            ..Default::default()
1045        };
1046
1047        let mut agent = RLAgent::new(config).unwrap();
1048        let state = create_test_state();
1049
1050        for episode in 0..5 {
1051            for _ in 0..10 {
1052                let action = agent.select_action(&state).await.unwrap();
1053                let reward = if episode % 2 == 0 { 1.0 } else { -1.0 };
1054                agent.learn(&state, action, reward, &state).await.unwrap();
1055            }
1056            agent.end_episode().await.unwrap();
1057        }
1058
1059        let stats = agent.get_stats().await;
1060        assert_eq!(stats.total_episodes, 5);
1061        assert_eq!(stats.total_steps, 50);
1062    }
1063}