// oxirs_stream/reinforcement_learning.rs

//! # Reinforcement Learning for Stream Processing Optimization
//!
//! This module provides reinforcement learning agents that can automatically optimize
//! stream processing parameters in real-time, learning from system feedback to improve
//! performance metrics like throughput, latency, and resource utilization.
//!
//! ## Features
//! - Q-Learning and Deep Q-Networks (DQN) for discrete action spaces
//! - Policy gradient methods (REINFORCE, Actor-Critic) for continuous actions
//! - Multi-armed bandit algorithms for hyperparameter tuning
//! - Experience replay for stable learning
//! - Adaptive exploration strategies (ε-greedy, UCB, Thompson sampling)
//! - Reward shaping for complex optimization objectives
//!
//! ## Example Usage
//! ```rust,ignore
//! use oxirs_stream::reinforcement_learning::{RLAgent, RLConfig, RLAlgorithm};
//!
//! let config = RLConfig {
//!     algorithm: RLAlgorithm::DQN,
//!     learning_rate: 0.001,
//!     discount_factor: 0.99,
//!     ..Default::default()
//! };
//!
//! let mut agent = RLAgent::new(config)?;
//! let action = agent.select_action(&state).await?;
//! let reward = execute_action(action);
//! agent.learn(&state, action, reward, &next_state).await?;
//! ```

use anyhow::{anyhow, Result};
use scirs2_core::ndarray_ext::{Array1, Array2};
use scirs2_core::random::{Random, Rng};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, VecDeque};
use std::sync::Arc;
use tokio::sync::{Mutex, RwLock};
use tracing::{debug, info};

41/// Reinforcement learning algorithm types
42#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
43pub enum RLAlgorithm {
44    /// Q-Learning (discrete actions)
45    QLearning,
46    /// Deep Q-Network
47    DQN,
48    /// SARSA (on-policy TD)
49    SARSA,
50    /// Actor-Critic
51    ActorCritic,
52    /// REINFORCE (policy gradient)
53    REINFORCE,
54    /// Proximal Policy Optimization
55    PPO,
56    /// Multi-Armed Bandit (UCB)
57    UCB,
58    /// Thompson Sampling
59    ThompsonSampling,
60    /// Epsilon-Greedy Bandit
61    EpsilonGreedy,
62}
63
64/// State representation for stream processing
65#[derive(Debug, Clone, Serialize, Deserialize)]
66pub struct State {
67    /// Current throughput (events/second)
68    pub throughput: f64,
69    /// Current latency (milliseconds)
70    pub latency_ms: f64,
71    /// CPU utilization (0-1)
72    pub cpu_utilization: f64,
73    /// Memory utilization (0-1)
74    pub memory_utilization: f64,
75    /// Queue depth
76    pub queue_depth: usize,
77    /// Error rate (0-1)
78    pub error_rate: f64,
79    /// Additional features
80    pub features: Vec<f64>,
81}
82
83impl State {
84    /// Convert state to feature vector
85    pub fn to_vector(&self) -> Vec<f64> {
86        let mut vec = vec![
87            self.throughput,
88            self.latency_ms,
89            self.cpu_utilization,
90            self.memory_utilization,
91            self.queue_depth as f64,
92            self.error_rate,
93        ];
94        vec.extend(&self.features);
95        vec
96    }
97
98    /// Get state dimension
99    pub fn dimension(&self) -> usize {
100        6 + self.features.len()
101    }
102}
103
104/// Action representation (can be discrete or continuous)
105#[derive(Debug, Clone, Serialize, Deserialize)]
106pub enum Action {
107    /// Discrete action (index)
108    Discrete(usize),
109    /// Continuous action (vector)
110    Continuous(Vec<f64>),
111}
112
113impl Action {
114    /// Get action as index (for discrete actions)
115    pub fn as_index(&self) -> Option<usize> {
116        match self {
117            Action::Discrete(idx) => Some(*idx),
118            _ => None,
119        }
120    }
121
122    /// Get action as vector (for continuous actions)
123    pub fn as_vector(&self) -> Option<&[f64]> {
124        match self {
125            Action::Continuous(vec) => Some(vec),
126            _ => None,
127        }
128    }
129}
130
131/// Experience tuple for replay buffer
132#[derive(Debug, Clone)]
133pub struct Experience {
134    /// Current state
135    pub state: State,
136    /// Action taken
137    pub action: Action,
138    /// Reward received
139    pub reward: f64,
140    /// Next state
141    pub next_state: State,
142    /// Whether episode terminated
143    pub done: bool,
144}
145
146/// RL agent configuration
147#[derive(Debug, Clone, Serialize, Deserialize)]
148pub struct RLConfig {
149    /// RL algorithm to use
150    pub algorithm: RLAlgorithm,
151    /// Learning rate
152    pub learning_rate: f64,
153    /// Discount factor (gamma)
154    pub discount_factor: f64,
155    /// Exploration rate (epsilon) for ε-greedy
156    pub epsilon: f64,
157    /// Epsilon decay rate
158    pub epsilon_decay: f64,
159    /// Minimum epsilon
160    pub epsilon_min: f64,
161    /// Experience replay buffer size
162    pub replay_buffer_size: usize,
163    /// Batch size for learning
164    pub batch_size: usize,
165    /// Target network update frequency (for DQN)
166    pub target_update_freq: usize,
167    /// Number of discrete actions (if applicable)
168    pub n_actions: usize,
169    /// Number of hidden units in neural network
170    pub hidden_units: Vec<usize>,
171    /// Enable prioritized experience replay
172    pub prioritized_replay: bool,
173    /// UCB exploration constant
174    pub ucb_c: f64,
175}
176
177impl Default for RLConfig {
178    fn default() -> Self {
179        Self {
180            algorithm: RLAlgorithm::DQN,
181            learning_rate: 0.001,
182            discount_factor: 0.99,
183            epsilon: 1.0,
184            epsilon_decay: 0.995,
185            epsilon_min: 0.01,
186            replay_buffer_size: 10000,
187            batch_size: 32,
188            target_update_freq: 100,
189            n_actions: 10,
190            hidden_units: vec![64, 64],
191            prioritized_replay: false,
192            ucb_c: 2.0,
193        }
194    }
195}
196
/// Tabular Q-function: maps a discretized state key to per-action Q-values.
type QTable = HashMap<String, Vec<f64>>;

200/// Neural network weights (simplified)
201#[derive(Debug, Clone)]
202pub struct NeuralNetwork {
203    /// Layer weights
204    pub weights: Vec<Array2<f64>>,
205    /// Layer biases
206    pub biases: Vec<Array1<f64>>,
207}
208
209impl NeuralNetwork {
210    /// Create a new neural network
211    pub fn new(
212        input_dim: usize,
213        hidden_dims: &[usize],
214        output_dim: usize,
215        rng: &mut Random,
216    ) -> Self {
217        let mut weights = Vec::new();
218        let mut biases = Vec::new();
219
220        let mut dims = vec![input_dim];
221        dims.extend(hidden_dims);
222        dims.push(output_dim);
223
224        for i in 0..dims.len() - 1 {
225            let w = Self::init_weights(dims[i], dims[i + 1], rng);
226            let b = Array1::zeros(dims[i + 1]);
227            weights.push(w);
228            biases.push(b);
229        }
230
231        Self { weights, biases }
232    }
233
234    /// Initialize weights with Xavier/Glorot initialization
235    fn init_weights(input_dim: usize, output_dim: usize, rng: &mut Random) -> Array2<f64> {
236        let scale = (2.0 / (input_dim + output_dim) as f64).sqrt();
237        let values: Vec<f64> = (0..input_dim * output_dim)
238            .map(|_| (rng.random::<f64>() * 2.0 - 1.0) * scale)
239            .collect();
240        Array2::from_shape_vec((input_dim, output_dim), values).unwrap()
241    }
242
243    /// Forward pass
244    pub fn forward(&self, input: &Array1<f64>) -> Array1<f64> {
245        let mut activation = input.clone();
246
247        for (w, b) in self.weights.iter().zip(&self.biases) {
248            // Linear transform
249            activation = activation.dot(w) + b;
250
251            // ReLU activation (except last layer)
252            if w != self.weights.last().unwrap() {
253                activation.mapv_inplace(|x| x.max(0.0));
254            }
255        }
256
257        activation
258    }
259
260    /// Update weights (simplified gradient descent)
261    pub fn update(&mut self, gradient_scale: f64, learning_rate: f64) {
262        for w in &mut self.weights {
263            w.mapv_inplace(|x| x - learning_rate * gradient_scale);
264        }
265    }
266}
267
268/// RL Agent statistics
269#[derive(Debug, Clone, Serialize, Deserialize)]
270pub struct RLStats {
271    /// Total steps taken
272    pub total_steps: u64,
273    /// Total episodes completed
274    pub total_episodes: u64,
275    /// Average reward per episode
276    pub avg_reward_per_episode: f64,
277    /// Current epsilon (exploration rate)
278    pub current_epsilon: f64,
279    /// Total reward accumulated
280    pub total_reward: f64,
281    /// Average Q-value
282    pub avg_q_value: f64,
283    /// Loss (for neural network methods)
284    pub avg_loss: f64,
285}
286
287impl Default for RLStats {
288    fn default() -> Self {
289        Self {
290            total_steps: 0,
291            total_episodes: 0,
292            avg_reward_per_episode: 0.0,
293            current_epsilon: 1.0,
294            total_reward: 0.0,
295            avg_q_value: 0.0,
296            avg_loss: 0.0,
297        }
298    }
299}
300
301/// Main RL Agent for stream optimization
302pub struct RLAgent {
303    config: RLConfig,
304    /// Q-table for tabular methods
305    q_table: Arc<RwLock<QTable>>,
306    /// Q-network for DQN
307    q_network: Arc<RwLock<Option<NeuralNetwork>>>,
308    /// Target network for DQN
309    target_network: Arc<RwLock<Option<NeuralNetwork>>>,
310    /// Experience replay buffer
311    replay_buffer: Arc<RwLock<VecDeque<Experience>>>,
312    /// Action counts (for bandits and UCB)
313    action_counts: Arc<RwLock<Vec<u64>>>,
314    /// Action rewards (for bandits)
315    action_rewards: Arc<RwLock<Vec<f64>>>,
316    /// Statistics
317    stats: Arc<RwLock<RLStats>>,
318    /// Random number generator
319    #[allow(clippy::arc_with_non_send_sync)]
320    rng: Arc<Mutex<Random>>,
321    /// Current episode reward
322    episode_reward: Arc<RwLock<f64>>,
323    /// Update counter
324    update_counter: Arc<RwLock<usize>>,
325}
326
327impl RLAgent {
328    /// Create a new RL agent
329    #[allow(clippy::arc_with_non_send_sync)]
330    pub fn new(config: RLConfig) -> Result<Self> {
331        let action_counts = vec![0u64; config.n_actions];
332        let action_rewards = vec![0.0; config.n_actions];
333        let buffer_size = config.replay_buffer_size;
334        let epsilon = config.epsilon;
335
336        Ok(Self {
337            config,
338            q_table: Arc::new(RwLock::new(HashMap::new())),
339            q_network: Arc::new(RwLock::new(None)),
340            target_network: Arc::new(RwLock::new(None)),
341            replay_buffer: Arc::new(RwLock::new(VecDeque::with_capacity(buffer_size))),
342            action_counts: Arc::new(RwLock::new(action_counts)),
343            action_rewards: Arc::new(RwLock::new(action_rewards)),
344            stats: Arc::new(RwLock::new(RLStats {
345                current_epsilon: epsilon,
346                ..Default::default()
347            })),
348            rng: Arc::new(Mutex::new(Random::default())),
349            episode_reward: Arc::new(RwLock::new(0.0)),
350            update_counter: Arc::new(RwLock::new(0)),
351        })
352    }
353
354    /// Initialize neural networks (for DQN/Actor-Critic)
355    pub async fn initialize_networks(&mut self, state_dim: usize) -> Result<()> {
356        if matches!(
357            self.config.algorithm,
358            RLAlgorithm::DQN | RLAlgorithm::ActorCritic | RLAlgorithm::PPO
359        ) {
360            let mut rng = self.rng.lock().await;
361
362            let q_net = NeuralNetwork::new(
363                state_dim,
364                &self.config.hidden_units,
365                self.config.n_actions,
366                &mut rng,
367            );
368
369            let target_net = NeuralNetwork::new(
370                state_dim,
371                &self.config.hidden_units,
372                self.config.n_actions,
373                &mut rng,
374            );
375
376            *self.q_network.write().await = Some(q_net);
377            *self.target_network.write().await = Some(target_net);
378
379            info!(
380                "Initialized neural networks with state_dim={}, n_actions={}",
381                state_dim, self.config.n_actions
382            );
383        }
384
385        Ok(())
386    }
387
388    /// Select an action given current state
389    pub async fn select_action(&self, state: &State) -> Result<Action> {
390        match self.config.algorithm {
391            RLAlgorithm::QLearning | RLAlgorithm::SARSA => {
392                self.select_action_q_learning(state).await
393            }
394            RLAlgorithm::DQN => self.select_action_dqn(state).await,
395            RLAlgorithm::UCB => self.select_action_ucb().await,
396            RLAlgorithm::ThompsonSampling => self.select_action_thompson().await,
397            RLAlgorithm::EpsilonGreedy => self.select_action_epsilon_greedy().await,
398            _ => {
399                // Default to ε-greedy
400                self.select_action_epsilon_greedy().await
401            }
402        }
403    }
404
405    /// ε-greedy action selection for Q-learning
406    async fn select_action_q_learning(&self, state: &State) -> Result<Action> {
407        let stats = self.stats.read().await;
408        let epsilon = stats.current_epsilon;
409        drop(stats);
410
411        let mut rng = self.rng.lock().await;
412
413        if rng.random::<f64>() < epsilon {
414            // Explore: random action
415            let action_idx = rng.random_range(0..self.config.n_actions);
416            Ok(Action::Discrete(action_idx))
417        } else {
418            // Exploit: best action from Q-table
419            let state_key = self.state_to_key(state);
420            let q_table = self.q_table.read().await;
421
422            let q_values = q_table
423                .get(&state_key)
424                .cloned()
425                .unwrap_or_else(|| vec![0.0; self.config.n_actions]);
426
427            let best_action = q_values
428                .iter()
429                .enumerate()
430                .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
431                .map(|(idx, _)| idx)
432                .unwrap_or(0);
433
434            Ok(Action::Discrete(best_action))
435        }
436    }
437
438    /// Action selection for DQN
439    async fn select_action_dqn(&self, state: &State) -> Result<Action> {
440        let stats = self.stats.read().await;
441        let epsilon = stats.current_epsilon;
442        drop(stats);
443
444        let mut rng = self.rng.lock().await;
445
446        if rng.random::<f64>() < epsilon {
447            let action_idx = rng.random_range(0..self.config.n_actions);
448            Ok(Action::Discrete(action_idx))
449        } else {
450            drop(rng);
451
452            let q_network = self.q_network.read().await;
453            if let Some(ref network) = *q_network {
454                let state_vec = Array1::from_vec(state.to_vector());
455                let q_values = network.forward(&state_vec);
456
457                let best_action = q_values
458                    .iter()
459                    .enumerate()
460                    .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
461                    .map(|(idx, _)| idx)
462                    .unwrap_or(0);
463
464                Ok(Action::Discrete(best_action))
465            } else {
466                Err(anyhow!("Q-network not initialized"))
467            }
468        }
469    }
470
471    /// UCB action selection
472    async fn select_action_ucb(&self) -> Result<Action> {
473        let action_counts = self.action_counts.read().await;
474        let action_rewards = self.action_rewards.read().await;
475        let stats = self.stats.read().await;
476        let total_steps = stats.total_steps;
477
478        let mut ucb_values = Vec::with_capacity(self.config.n_actions);
479
480        for i in 0..self.config.n_actions {
481            let count = action_counts[i];
482            let avg_reward = if count > 0 {
483                action_rewards[i] / count as f64
484            } else {
485                f64::INFINITY // Prioritize unexplored actions
486            };
487
488            let exploration_bonus = if count > 0 {
489                self.config.ucb_c * ((total_steps as f64).ln() / count as f64).sqrt()
490            } else {
491                f64::INFINITY
492            };
493
494            ucb_values.push(avg_reward + exploration_bonus);
495        }
496
497        let best_action = ucb_values
498            .iter()
499            .enumerate()
500            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
501            .map(|(idx, _)| idx)
502            .unwrap_or(0);
503
504        Ok(Action::Discrete(best_action))
505    }
506
507    /// Thompson sampling action selection
508    async fn select_action_thompson(&self) -> Result<Action> {
509        let action_counts = self.action_counts.read().await;
510        let action_rewards = self.action_rewards.read().await;
511        let mut rng = self.rng.lock().await;
512
513        let mut sampled_values = Vec::with_capacity(self.config.n_actions);
514
515        for i in 0..self.config.n_actions {
516            let count = action_counts[i];
517            let sum_reward = action_rewards[i];
518
519            // Beta distribution sampling (simplified)
520            let alpha = sum_reward + 1.0;
521            let beta = (count as f64 - sum_reward).max(0.0) + 1.0;
522
523            // Simplified beta sampling
524            let sample = rng.random::<f64>().powf(1.0 / alpha)
525                * (1.0 - rng.random::<f64>()).powf(1.0 / beta);
526            sampled_values.push(sample);
527        }
528
529        let best_action = sampled_values
530            .iter()
531            .enumerate()
532            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
533            .map(|(idx, _)| idx)
534            .unwrap_or(0);
535
536        Ok(Action::Discrete(best_action))
537    }
538
539    /// Epsilon-greedy bandit action selection
540    async fn select_action_epsilon_greedy(&self) -> Result<Action> {
541        let stats = self.stats.read().await;
542        let epsilon = stats.current_epsilon;
543        drop(stats);
544
545        let mut rng = self.rng.lock().await;
546
547        if rng.random::<f64>() < epsilon {
548            let action_idx = rng.random_range(0..self.config.n_actions);
549            Ok(Action::Discrete(action_idx))
550        } else {
551            drop(rng);
552
553            let action_counts = self.action_counts.read().await;
554            let action_rewards = self.action_rewards.read().await;
555
556            let best_action = (0..self.config.n_actions)
557                .max_by(|&a, &b| {
558                    let avg_a = if action_counts[a] > 0 {
559                        action_rewards[a] / action_counts[a] as f64
560                    } else {
561                        0.0
562                    };
563                    let avg_b = if action_counts[b] > 0 {
564                        action_rewards[b] / action_counts[b] as f64
565                    } else {
566                        0.0
567                    };
568                    avg_a
569                        .partial_cmp(&avg_b)
570                        .unwrap_or(std::cmp::Ordering::Equal)
571                })
572                .unwrap_or(0);
573
574            Ok(Action::Discrete(best_action))
575        }
576    }
577
578    /// Learn from experience (update model)
579    pub async fn learn(
580        &mut self,
581        state: &State,
582        action: Action,
583        reward: f64,
584        next_state: &State,
585    ) -> Result<()> {
586        // Add to replay buffer
587        let experience = Experience {
588            state: state.clone(),
589            action: action.clone(),
590            reward,
591            next_state: next_state.clone(),
592            done: false,
593        };
594
595        let mut replay_buffer = self.replay_buffer.write().await;
596        replay_buffer.push_back(experience);
597
598        if replay_buffer.len() > self.config.replay_buffer_size {
599            replay_buffer.pop_front();
600        }
601        drop(replay_buffer);
602
603        // Update statistics
604        *self.episode_reward.write().await += reward;
605        let mut stats = self.stats.write().await;
606        stats.total_steps += 1;
607        stats.total_reward += reward;
608
609        // Update action counts for bandits
610        if let Action::Discrete(idx) = action {
611            let mut counts = self.action_counts.write().await;
612            let mut rewards = self.action_rewards.write().await;
613            counts[idx] += 1;
614            rewards[idx] += reward;
615        }
616
617        // Perform learning update
618        match self.config.algorithm {
619            RLAlgorithm::QLearning | RLAlgorithm::SARSA => {
620                drop(stats);
621                self.update_q_learning(state, &action, reward, next_state)
622                    .await?;
623            }
624            RLAlgorithm::DQN => {
625                drop(stats);
626                self.update_dqn().await?;
627            }
628            _ => {
629                // Bandits don't need explicit update beyond counting
630            }
631        }
632
633        // Decay epsilon
634        let mut stats = self.stats.write().await;
635        stats.current_epsilon =
636            (stats.current_epsilon * self.config.epsilon_decay).max(self.config.epsilon_min);
637
638        Ok(())
639    }
640
641    /// Q-learning update
642    async fn update_q_learning(
643        &self,
644        state: &State,
645        action: &Action,
646        reward: f64,
647        next_state: &State,
648    ) -> Result<()> {
649        if let Action::Discrete(action_idx) = action {
650            let state_key = self.state_to_key(state);
651            let next_state_key = self.state_to_key(next_state);
652
653            let mut q_table = self.q_table.write().await;
654
655            // Get max next Q value first
656            let max_next_q = {
657                let next_q_values = q_table
658                    .entry(next_state_key)
659                    .or_insert_with(|| vec![0.0; self.config.n_actions]);
660                next_q_values
661                    .iter()
662                    .copied()
663                    .fold(f64::NEG_INFINITY, f64::max)
664            };
665
666            // Now update current Q value
667            let q_values = q_table
668                .entry(state_key.clone())
669                .or_insert_with(|| vec![0.0; self.config.n_actions]);
670
671            // Q-learning update: Q(s,a) <- Q(s,a) + α[r + γ*max(Q(s',a')) - Q(s,a)]
672            let current_q = q_values[*action_idx];
673            let td_target = reward + self.config.discount_factor * max_next_q;
674            let td_error = td_target - current_q;
675
676            q_values[*action_idx] += self.config.learning_rate * td_error;
677
678            debug!(
679                "Q-learning update: state={}, action={}, Q={:.4}",
680                state_key, action_idx, q_values[*action_idx]
681            );
682        }
683
684        Ok(())
685    }
686
687    /// DQN update
688    async fn update_dqn(&self) -> Result<()> {
689        let replay_buffer = self.replay_buffer.read().await;
690
691        if replay_buffer.len() < self.config.batch_size {
692            return Ok(()); // Not enough samples
693        }
694
695        // Sample random batch
696        let batch_indices: Vec<usize> = {
697            let mut rng = self.rng.lock().await;
698            (0..self.config.batch_size)
699                .map(|_| rng.random_range(0..replay_buffer.len()))
700                .collect()
701        };
702
703        // Clone the batch experiences before dropping the lock
704        let batch: Vec<Experience> = batch_indices
705            .iter()
706            .map(|&i| replay_buffer[i].clone())
707            .collect();
708        drop(replay_buffer);
709
710        // Compute TD errors and update network
711        let mut total_loss = 0.0;
712
713        let q_network = self.q_network.read().await;
714        let target_network = self.target_network.read().await;
715
716        if let (Some(ref q_net), Some(ref target_net)) = (&*q_network, &*target_network) {
717            for exp in &batch {
718                let state_vec = Array1::from_vec(exp.state.to_vector());
719                let next_state_vec = Array1::from_vec(exp.next_state.to_vector());
720
721                let q_values = q_net.forward(&state_vec);
722                let next_q_values = target_net.forward(&next_state_vec);
723
724                let max_next_q = next_q_values
725                    .iter()
726                    .copied()
727                    .fold(f64::NEG_INFINITY, f64::max);
728
729                if let Action::Discrete(action_idx) = exp.action {
730                    let td_target = exp.reward + self.config.discount_factor * max_next_q;
731                    let td_error = td_target - q_values[action_idx];
732                    total_loss += td_error * td_error;
733                }
734            }
735        }
736        drop(q_network);
737        drop(target_network);
738
739        // Update network (simplified)
740        let mut q_network = self.q_network.write().await;
741        if let Some(ref mut network) = *q_network {
742            let gradient_scale = total_loss / self.config.batch_size as f64;
743            network.update(gradient_scale, self.config.learning_rate);
744        }
745        drop(q_network);
746
747        // Update target network periodically
748        let mut counter = self.update_counter.write().await;
749        *counter += 1;
750
751        if *counter % self.config.target_update_freq == 0 {
752            let q_net = self.q_network.read().await;
753            if let Some(ref network) = *q_net {
754                *self.target_network.write().await = Some(network.clone());
755                debug!("Updated target network at step {}", *counter);
756            }
757        }
758
759        // Update stats
760        let mut stats = self.stats.write().await;
761        stats.avg_loss = (stats.avg_loss * (stats.total_steps - 1) as f64 + total_loss)
762            / stats.total_steps as f64;
763
764        Ok(())
765    }
766
767    /// End episode and record statistics
768    pub async fn end_episode(&mut self) -> Result<()> {
769        let episode_reward = *self.episode_reward.read().await;
770        *self.episode_reward.write().await = 0.0;
771
772        let mut stats = self.stats.write().await;
773        stats.total_episodes += 1;
774        stats.avg_reward_per_episode =
775            (stats.avg_reward_per_episode * (stats.total_episodes - 1) as f64 + episode_reward)
776                / stats.total_episodes as f64;
777
778        info!(
779            "Episode {} complete: reward={:.2}, avg_reward={:.2}",
780            stats.total_episodes, episode_reward, stats.avg_reward_per_episode
781        );
782
783        Ok(())
784    }
785
786    /// Convert state to string key for Q-table
787    fn state_to_key(&self, state: &State) -> String {
788        // Discretize continuous state for tabular methods
789        format!(
790            "{:.0}_{:.0}_{:.2}_{:.2}_{}_{ :.2}",
791            (state.throughput / 1000.0).round(),
792            (state.latency_ms / 10.0).round(),
793            (state.cpu_utilization * 10.0).round() / 10.0,
794            (state.memory_utilization * 10.0).round() / 10.0,
795            state.queue_depth / 100,
796            (state.error_rate * 100.0).round() / 100.0,
797        )
798    }
799
800    /// Get RL statistics
801    pub async fn get_stats(&self) -> RLStats {
802        self.stats.read().await.clone()
803    }
804
805    /// Get current epsilon
806    pub async fn get_epsilon(&self) -> f64 {
807        self.stats.read().await.current_epsilon
808    }
809
810    /// Set epsilon (for exploration control)
811    pub async fn set_epsilon(&mut self, epsilon: f64) {
812        self.stats.write().await.current_epsilon = epsilon.clamp(0.0, 1.0);
813    }
814
815    /// Export policy for deployment
816    pub async fn export_policy(&self) -> Result<String> {
817        let policy = match self.config.algorithm {
818            RLAlgorithm::QLearning | RLAlgorithm::SARSA => {
819                let q_table = self.q_table.read().await;
820                serde_json::json!({
821                    "algorithm": "Q-Learning",
822                    "q_table": q_table.iter().take(10).collect::<HashMap<_, _>>(), // Sample
823                })
824            }
825            _ => {
826                let stats = self.get_stats().await;
827                serde_json::json!({
828                    "algorithm": format!("{:?}", self.config.algorithm),
829                    "stats": stats,
830                })
831            }
832        };
833
834        Ok(serde_json::to_string_pretty(&policy)?)
835    }
836}
837
838#[cfg(test)]
839mod tests {
840    use super::*;
841
842    fn create_test_state() -> State {
843        State {
844            throughput: 10000.0,
845            latency_ms: 5.0,
846            cpu_utilization: 0.5,
847            memory_utilization: 0.6,
848            queue_depth: 100,
849            error_rate: 0.01,
850            features: vec![],
851        }
852    }
853
854    #[tokio::test]
855    async fn test_rl_agent_creation() {
856        let config = RLConfig::default();
857        let agent = RLAgent::new(config);
858        assert!(agent.is_ok());
859    }
860
861    #[tokio::test]
862    async fn test_q_learning_action_selection() {
863        let config = RLConfig {
864            algorithm: RLAlgorithm::QLearning,
865            n_actions: 5,
866            ..Default::default()
867        };
868
869        let agent = RLAgent::new(config).unwrap();
870        let state = create_test_state();
871
872        let action = agent.select_action(&state).await;
873        assert!(action.is_ok());
874
875        if let Action::Discrete(idx) = action.unwrap() {
876            assert!(idx < 5);
877        }
878    }
879
880    #[tokio::test]
881    async fn test_dqn_initialization() {
882        let config = RLConfig {
883            algorithm: RLAlgorithm::DQN,
884            n_actions: 10,
885            hidden_units: vec![32, 32],
886            ..Default::default()
887        };
888
889        let mut agent = RLAgent::new(config).unwrap();
890        let state = create_test_state();
891
892        agent.initialize_networks(state.dimension()).await.unwrap();
893
894        let action = agent.select_action(&state).await;
895        assert!(action.is_ok());
896    }
897
898    #[tokio::test]
899    async fn test_ucb_action_selection() {
900        let config = RLConfig {
901            algorithm: RLAlgorithm::UCB,
902            n_actions: 5,
903            ..Default::default()
904        };
905
906        let agent = RLAgent::new(config).unwrap();
907        let action = agent.select_action_ucb().await;
908        assert!(action.is_ok());
909    }
910
911    #[tokio::test]
912    async fn test_learning_update() {
913        let config = RLConfig {
914            algorithm: RLAlgorithm::QLearning,
915            n_actions: 3,
916            ..Default::default()
917        };
918
919        let mut agent = RLAgent::new(config).unwrap();
920        let state = create_test_state();
921        let action = Action::Discrete(1);
922        let reward = 1.0;
923        let next_state = create_test_state();
924
925        let result = agent.learn(&state, action, reward, &next_state).await;
926        assert!(result.is_ok());
927
928        let stats = agent.get_stats().await;
929        assert_eq!(stats.total_steps, 1);
930        assert_eq!(stats.total_reward, 1.0);
931    }
932
933    #[tokio::test]
934    async fn test_epsilon_decay() {
935        let config = RLConfig {
936            epsilon: 1.0,
937            epsilon_decay: 0.9,
938            epsilon_min: 0.1,
939            ..Default::default()
940        };
941
942        let mut agent = RLAgent::new(config).unwrap();
943        let initial_epsilon = agent.get_epsilon().await;
944
945        let state = create_test_state();
946        for _ in 0..10 {
947            agent
948                .learn(&state, Action::Discrete(0), 0.0, &state)
949                .await
950                .unwrap();
951        }
952
953        let final_epsilon = agent.get_epsilon().await;
954        assert!(final_epsilon < initial_epsilon);
955        assert!(final_epsilon >= 0.1);
956    }
957
958    #[tokio::test]
959    async fn test_episode_management() {
960        let config = RLConfig::default();
961        let mut agent = RLAgent::new(config).unwrap();
962
963        let state = create_test_state();
964        agent
965            .learn(&state, Action::Discrete(0), 1.0, &state)
966            .await
967            .unwrap();
968        agent
969            .learn(&state, Action::Discrete(1), 2.0, &state)
970            .await
971            .unwrap();
972
973        agent.end_episode().await.unwrap();
974
975        let stats = agent.get_stats().await;
976        assert_eq!(stats.total_episodes, 1);
977        assert!(stats.avg_reward_per_episode > 0.0);
978    }
979
980    #[tokio::test]
981    async fn test_replay_buffer() {
982        let config = RLConfig {
983            replay_buffer_size: 5,
984            ..Default::default()
985        };
986
987        let mut agent = RLAgent::new(config).unwrap();
988        let state = create_test_state();
989
990        for i in 0..10 {
991            agent
992                .learn(&state, Action::Discrete(0), i as f64, &state)
993                .await
994                .unwrap();
995        }
996
997        let buffer = agent.replay_buffer.read().await;
998        assert_eq!(buffer.len(), 5); // Should not exceed buffer size
999    }
1000
1001    #[tokio::test]
1002    async fn test_export_policy() {
1003        let config = RLConfig {
1004            algorithm: RLAlgorithm::QLearning,
1005            ..Default::default()
1006        };
1007
1008        let mut agent = RLAgent::new(config).unwrap();
1009        let state = create_test_state();
1010
1011        agent
1012            .learn(&state, Action::Discrete(0), 1.0, &state)
1013            .await
1014            .unwrap();
1015
1016        let export = agent.export_policy().await;
1017        assert!(export.is_ok());
1018        assert!(export.unwrap().contains("algorithm"));
1019    }
1020
1021    #[tokio::test]
1022    async fn test_thompson_sampling() {
1023        let config = RLConfig {
1024            algorithm: RLAlgorithm::ThompsonSampling,
1025            n_actions: 5,
1026            ..Default::default()
1027        };
1028
1029        let agent = RLAgent::new(config).unwrap();
1030        let action = agent.select_action_thompson().await;
1031        assert!(action.is_ok());
1032    }
1033
1034    #[tokio::test]
1035    async fn test_multiple_episodes() {
1036        let config = RLConfig {
1037            algorithm: RLAlgorithm::QLearning,
1038            n_actions: 3,
1039            ..Default::default()
1040        };
1041
1042        let mut agent = RLAgent::new(config).unwrap();
1043        let state = create_test_state();
1044
1045        for episode in 0..5 {
1046            for _ in 0..10 {
1047                let action = agent.select_action(&state).await.unwrap();
1048                let reward = if episode % 2 == 0 { 1.0 } else { -1.0 };
1049                agent.learn(&state, action, reward, &state).await.unwrap();
1050            }
1051            agent.end_episode().await.unwrap();
1052        }
1053
1054        let stats = agent.get_stats().await;
1055        assert_eq!(stats.total_episodes, 5);
1056        assert_eq!(stats.total_steps, 50);
1057    }
1058}