1use anyhow::{anyhow, Result};
33use scirs2_core::ndarray_ext::{Array1, Array2};
34use scirs2_core::random::{Random, Rng};
35use serde::{Deserialize, Serialize};
36use std::collections::{HashMap, VecDeque};
37use std::sync::Arc;
38use tokio::sync::{Mutex, RwLock};
39use tracing::{debug, info};
40
/// Supported reinforcement-learning and multi-armed-bandit algorithms.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum RLAlgorithm {
    /// Tabular Q-learning (off-policy TD control).
    QLearning,
    /// Deep Q-Network: Q-function approximated by a neural network.
    DQN,
    /// SARSA — NOTE(review): shares the tabular Q-learning update in this file.
    SARSA,
    /// Actor-critic — NOTE(review): no dedicated implementation here; action
    /// selection falls back to epsilon-greedy.
    ActorCritic,
    /// REINFORCE — NOTE(review): fallback to epsilon-greedy only.
    REINFORCE,
    /// Proximal Policy Optimization — NOTE(review): fallback only.
    PPO,
    /// Upper-Confidence-Bound bandit.
    UCB,
    /// Thompson-sampling bandit.
    ThompsonSampling,
    /// Plain epsilon-greedy bandit.
    EpsilonGreedy,
}
63
/// Observable system metrics forming the RL state.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct State {
    /// Throughput metric (units defined by the caller).
    pub throughput: f64,
    /// Observed latency in milliseconds.
    pub latency_ms: f64,
    /// CPU utilization — presumably in [0, 1]; confirm against callers.
    pub cpu_utilization: f64,
    /// Memory utilization — presumably in [0, 1]; confirm against callers.
    pub memory_utilization: f64,
    /// Number of queued items.
    pub queue_depth: usize,
    /// Error rate — presumably in [0, 1]; confirm against callers.
    pub error_rate: f64,
    /// Extra caller-supplied features appended to the state vector.
    pub features: Vec<f64>,
}
82
83impl State {
84 pub fn to_vector(&self) -> Vec<f64> {
86 let mut vec = vec![
87 self.throughput,
88 self.latency_ms,
89 self.cpu_utilization,
90 self.memory_utilization,
91 self.queue_depth as f64,
92 self.error_rate,
93 ];
94 vec.extend(&self.features);
95 vec
96 }
97
98 pub fn dimension(&self) -> usize {
100 6 + self.features.len()
101 }
102}
103
/// An action chosen by the agent.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Action {
    /// Index into a discrete action space of size `n_actions`.
    Discrete(usize),
    /// Continuous action vector — NOTE(review): not produced by any
    /// algorithm in this file.
    Continuous(Vec<f64>),
}
112
113impl Action {
114 pub fn as_index(&self) -> Option<usize> {
116 match self {
117 Action::Discrete(idx) => Some(*idx),
118 _ => None,
119 }
120 }
121
122 pub fn as_vector(&self) -> Option<&[f64]> {
124 match self {
125 Action::Continuous(vec) => Some(vec),
126 _ => None,
127 }
128 }
129}
130
/// One (state, action, reward, next-state) transition kept in the replay buffer.
#[derive(Debug, Clone)]
pub struct Experience {
    /// State the action was taken from.
    pub state: State,
    /// Action taken.
    pub action: Action,
    /// Immediate reward received.
    pub reward: f64,
    /// Resulting state.
    pub next_state: State,
    /// Episode-termination flag (always `false` as stored by `learn`).
    pub done: bool,
}
145
/// Hyperparameters controlling the agent's algorithm and training loop.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RLConfig {
    /// Which learning / bandit algorithm to run.
    pub algorithm: RLAlgorithm,
    /// Step size for Q-table / network updates.
    pub learning_rate: f64,
    /// Discount factor (gamma) applied to bootstrapped future value.
    pub discount_factor: f64,
    /// Initial exploration rate for epsilon-greedy policies.
    pub epsilon: f64,
    /// Multiplicative decay applied to epsilon after every `learn` call.
    pub epsilon_decay: f64,
    /// Lower bound epsilon never decays below.
    pub epsilon_min: f64,
    /// Maximum number of transitions retained for experience replay.
    pub replay_buffer_size: usize,
    /// Minibatch size for DQN updates.
    pub batch_size: usize,
    /// Number of DQN updates between target-network syncs.
    pub target_update_freq: usize,
    /// Size of the discrete action space.
    pub n_actions: usize,
    /// Hidden-layer widths for the Q-network.
    pub hidden_units: Vec<usize>,
    /// NOTE(review): declared but never read in this file — confirm whether
    /// prioritized replay is implemented elsewhere.
    pub prioritized_replay: bool,
    /// Exploration coefficient `c` in the UCB bonus term.
    pub ucb_c: f64,
}
176
177impl Default for RLConfig {
178 fn default() -> Self {
179 Self {
180 algorithm: RLAlgorithm::DQN,
181 learning_rate: 0.001,
182 discount_factor: 0.99,
183 epsilon: 1.0,
184 epsilon_decay: 0.995,
185 epsilon_min: 0.01,
186 replay_buffer_size: 10000,
187 batch_size: 32,
188 target_update_freq: 100,
189 n_actions: 10,
190 hidden_units: vec![64, 64],
191 prioritized_replay: false,
192 ucb_c: 2.0,
193 }
194 }
195}
196
/// Tabular Q-function: discretized state key -> per-action Q-values
/// (each vector has length `n_actions`).
type QTable = HashMap<String, Vec<f64>>;
199
200#[derive(Debug, Clone)]
202pub struct NeuralNetwork {
203 pub weights: Vec<Array2<f64>>,
205 pub biases: Vec<Array1<f64>>,
207}
208
209impl NeuralNetwork {
210 pub fn new(
212 input_dim: usize,
213 hidden_dims: &[usize],
214 output_dim: usize,
215 rng: &mut Random,
216 ) -> Self {
217 let mut weights = Vec::new();
218 let mut biases = Vec::new();
219
220 let mut dims = vec![input_dim];
221 dims.extend(hidden_dims);
222 dims.push(output_dim);
223
224 for i in 0..dims.len() - 1 {
225 let w = Self::init_weights(dims[i], dims[i + 1], rng);
226 let b = Array1::zeros(dims[i + 1]);
227 weights.push(w);
228 biases.push(b);
229 }
230
231 Self { weights, biases }
232 }
233
234 fn init_weights(input_dim: usize, output_dim: usize, rng: &mut Random) -> Array2<f64> {
236 let scale = (2.0 / (input_dim + output_dim) as f64).sqrt();
237 let values: Vec<f64> = (0..input_dim * output_dim)
238 .map(|_| (rng.random::<f64>() * 2.0 - 1.0) * scale)
239 .collect();
240 Array2::from_shape_vec((input_dim, output_dim), values)
241 .expect("shape and vector length match")
242 }
243
244 pub fn forward(&self, input: &Array1<f64>) -> Array1<f64> {
246 let mut activation = input.clone();
247
248 for (w, b) in self.weights.iter().zip(&self.biases) {
249 activation = activation.dot(w) + b;
251
252 if w != self
254 .weights
255 .last()
256 .expect("collection validated to be non-empty")
257 {
258 activation.mapv_inplace(|x| x.max(0.0));
259 }
260 }
261
262 activation
263 }
264
265 pub fn update(&mut self, gradient_scale: f64, learning_rate: f64) {
267 for w in &mut self.weights {
268 w.mapv_inplace(|x| x - learning_rate * gradient_scale);
269 }
270 }
271}
272
/// Running statistics tracked by the agent.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RLStats {
    /// Total `learn` calls across all episodes.
    pub total_steps: u64,
    /// Episodes completed via `end_episode`.
    pub total_episodes: u64,
    /// Running mean of per-episode reward.
    pub avg_reward_per_episode: f64,
    /// Current exploration rate.
    pub current_epsilon: f64,
    /// Sum of all rewards observed so far.
    pub total_reward: f64,
    /// NOTE(review): never updated anywhere in this file — confirm intended.
    pub avg_q_value: f64,
    /// Running mean of the DQN minibatch loss.
    pub avg_loss: f64,
}
291
292impl Default for RLStats {
293 fn default() -> Self {
294 Self {
295 total_steps: 0,
296 total_episodes: 0,
297 avg_reward_per_episode: 0.0,
298 current_epsilon: 1.0,
299 total_reward: 0.0,
300 avg_q_value: 0.0,
301 avg_loss: 0.0,
302 }
303 }
304}
305
/// Reinforcement-learning agent whose shared state is wrapped in async
/// locks so selection and learning can be driven from async tasks.
pub struct RLAgent {
    /// Algorithm selection and hyperparameters.
    config: RLConfig,
    /// Tabular Q-function (Q-learning / SARSA path).
    q_table: Arc<RwLock<QTable>>,
    /// Online Q-network (DQN family); `None` until `initialize_networks`.
    q_network: Arc<RwLock<Option<NeuralNetwork>>>,
    /// Periodically-synced copy of the Q-network used for TD targets.
    target_network: Arc<RwLock<Option<NeuralNetwork>>>,
    /// Bounded FIFO of past transitions for experience replay.
    replay_buffer: Arc<RwLock<VecDeque<Experience>>>,
    /// Per-action pull counts (bandit algorithms).
    action_counts: Arc<RwLock<Vec<u64>>>,
    /// Per-action cumulative rewards (bandit algorithms).
    action_rewards: Arc<RwLock<Vec<f64>>>,
    /// Running statistics, including the live epsilon value.
    stats: Arc<RwLock<RLStats>>,
    /// Shared RNG used for exploration and sampling.
    #[allow(clippy::arc_with_non_send_sync)]
    rng: Arc<Mutex<Random>>,
    /// Reward accumulated during the current episode.
    episode_reward: Arc<RwLock<f64>>,
    /// Number of DQN updates performed; drives target-network sync cadence.
    update_counter: Arc<RwLock<usize>>,
}
331
332impl RLAgent {
333 #[allow(clippy::arc_with_non_send_sync)]
335 pub fn new(config: RLConfig) -> Result<Self> {
336 let action_counts = vec![0u64; config.n_actions];
337 let action_rewards = vec![0.0; config.n_actions];
338 let buffer_size = config.replay_buffer_size;
339 let epsilon = config.epsilon;
340
341 Ok(Self {
342 config,
343 q_table: Arc::new(RwLock::new(HashMap::new())),
344 q_network: Arc::new(RwLock::new(None)),
345 target_network: Arc::new(RwLock::new(None)),
346 replay_buffer: Arc::new(RwLock::new(VecDeque::with_capacity(buffer_size))),
347 action_counts: Arc::new(RwLock::new(action_counts)),
348 action_rewards: Arc::new(RwLock::new(action_rewards)),
349 stats: Arc::new(RwLock::new(RLStats {
350 current_epsilon: epsilon,
351 ..Default::default()
352 })),
353 rng: Arc::new(Mutex::new(Random::default())),
354 episode_reward: Arc::new(RwLock::new(0.0)),
355 update_counter: Arc::new(RwLock::new(0)),
356 })
357 }
358
359 pub async fn initialize_networks(&mut self, state_dim: usize) -> Result<()> {
361 if matches!(
362 self.config.algorithm,
363 RLAlgorithm::DQN | RLAlgorithm::ActorCritic | RLAlgorithm::PPO
364 ) {
365 let mut rng = self.rng.lock().await;
366
367 let q_net = NeuralNetwork::new(
368 state_dim,
369 &self.config.hidden_units,
370 self.config.n_actions,
371 &mut rng,
372 );
373
374 let target_net = NeuralNetwork::new(
375 state_dim,
376 &self.config.hidden_units,
377 self.config.n_actions,
378 &mut rng,
379 );
380
381 *self.q_network.write().await = Some(q_net);
382 *self.target_network.write().await = Some(target_net);
383
384 info!(
385 "Initialized neural networks with state_dim={}, n_actions={}",
386 state_dim, self.config.n_actions
387 );
388 }
389
390 Ok(())
391 }
392
393 pub async fn select_action(&self, state: &State) -> Result<Action> {
395 match self.config.algorithm {
396 RLAlgorithm::QLearning | RLAlgorithm::SARSA => {
397 self.select_action_q_learning(state).await
398 }
399 RLAlgorithm::DQN => self.select_action_dqn(state).await,
400 RLAlgorithm::UCB => self.select_action_ucb().await,
401 RLAlgorithm::ThompsonSampling => self.select_action_thompson().await,
402 RLAlgorithm::EpsilonGreedy => self.select_action_epsilon_greedy().await,
403 _ => {
404 self.select_action_epsilon_greedy().await
406 }
407 }
408 }
409
410 async fn select_action_q_learning(&self, state: &State) -> Result<Action> {
412 let stats = self.stats.read().await;
413 let epsilon = stats.current_epsilon;
414 drop(stats);
415
416 let mut rng = self.rng.lock().await;
417
418 if rng.random::<f64>() < epsilon {
419 let action_idx = rng.random_range(0..self.config.n_actions);
421 Ok(Action::Discrete(action_idx))
422 } else {
423 let state_key = self.state_to_key(state);
425 let q_table = self.q_table.read().await;
426
427 let q_values = q_table
428 .get(&state_key)
429 .cloned()
430 .unwrap_or_else(|| vec![0.0; self.config.n_actions]);
431
432 let best_action = q_values
433 .iter()
434 .enumerate()
435 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
436 .map(|(idx, _)| idx)
437 .unwrap_or(0);
438
439 Ok(Action::Discrete(best_action))
440 }
441 }
442
443 async fn select_action_dqn(&self, state: &State) -> Result<Action> {
445 let stats = self.stats.read().await;
446 let epsilon = stats.current_epsilon;
447 drop(stats);
448
449 let mut rng = self.rng.lock().await;
450
451 if rng.random::<f64>() < epsilon {
452 let action_idx = rng.random_range(0..self.config.n_actions);
453 Ok(Action::Discrete(action_idx))
454 } else {
455 drop(rng);
456
457 let q_network = self.q_network.read().await;
458 if let Some(ref network) = *q_network {
459 let state_vec = Array1::from_vec(state.to_vector());
460 let q_values = network.forward(&state_vec);
461
462 let best_action = q_values
463 .iter()
464 .enumerate()
465 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
466 .map(|(idx, _)| idx)
467 .unwrap_or(0);
468
469 Ok(Action::Discrete(best_action))
470 } else {
471 Err(anyhow!("Q-network not initialized"))
472 }
473 }
474 }
475
476 async fn select_action_ucb(&self) -> Result<Action> {
478 let action_counts = self.action_counts.read().await;
479 let action_rewards = self.action_rewards.read().await;
480 let stats = self.stats.read().await;
481 let total_steps = stats.total_steps;
482
483 let mut ucb_values = Vec::with_capacity(self.config.n_actions);
484
485 for i in 0..self.config.n_actions {
486 let count = action_counts[i];
487 let avg_reward = if count > 0 {
488 action_rewards[i] / count as f64
489 } else {
490 f64::INFINITY };
492
493 let exploration_bonus = if count > 0 {
494 self.config.ucb_c * ((total_steps as f64).ln() / count as f64).sqrt()
495 } else {
496 f64::INFINITY
497 };
498
499 ucb_values.push(avg_reward + exploration_bonus);
500 }
501
502 let best_action = ucb_values
503 .iter()
504 .enumerate()
505 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
506 .map(|(idx, _)| idx)
507 .unwrap_or(0);
508
509 Ok(Action::Discrete(best_action))
510 }
511
512 async fn select_action_thompson(&self) -> Result<Action> {
514 let action_counts = self.action_counts.read().await;
515 let action_rewards = self.action_rewards.read().await;
516 let mut rng = self.rng.lock().await;
517
518 let mut sampled_values = Vec::with_capacity(self.config.n_actions);
519
520 for i in 0..self.config.n_actions {
521 let count = action_counts[i];
522 let sum_reward = action_rewards[i];
523
524 let alpha = sum_reward + 1.0;
526 let beta = (count as f64 - sum_reward).max(0.0) + 1.0;
527
528 let sample = rng.random::<f64>().powf(1.0 / alpha)
530 * (1.0 - rng.random::<f64>()).powf(1.0 / beta);
531 sampled_values.push(sample);
532 }
533
534 let best_action = sampled_values
535 .iter()
536 .enumerate()
537 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
538 .map(|(idx, _)| idx)
539 .unwrap_or(0);
540
541 Ok(Action::Discrete(best_action))
542 }
543
544 async fn select_action_epsilon_greedy(&self) -> Result<Action> {
546 let stats = self.stats.read().await;
547 let epsilon = stats.current_epsilon;
548 drop(stats);
549
550 let mut rng = self.rng.lock().await;
551
552 if rng.random::<f64>() < epsilon {
553 let action_idx = rng.random_range(0..self.config.n_actions);
554 Ok(Action::Discrete(action_idx))
555 } else {
556 drop(rng);
557
558 let action_counts = self.action_counts.read().await;
559 let action_rewards = self.action_rewards.read().await;
560
561 let best_action = (0..self.config.n_actions)
562 .max_by(|&a, &b| {
563 let avg_a = if action_counts[a] > 0 {
564 action_rewards[a] / action_counts[a] as f64
565 } else {
566 0.0
567 };
568 let avg_b = if action_counts[b] > 0 {
569 action_rewards[b] / action_counts[b] as f64
570 } else {
571 0.0
572 };
573 avg_a
574 .partial_cmp(&avg_b)
575 .unwrap_or(std::cmp::Ordering::Equal)
576 })
577 .unwrap_or(0);
578
579 Ok(Action::Discrete(best_action))
580 }
581 }
582
583 pub async fn learn(
585 &mut self,
586 state: &State,
587 action: Action,
588 reward: f64,
589 next_state: &State,
590 ) -> Result<()> {
591 let experience = Experience {
593 state: state.clone(),
594 action: action.clone(),
595 reward,
596 next_state: next_state.clone(),
597 done: false,
598 };
599
600 let mut replay_buffer = self.replay_buffer.write().await;
601 replay_buffer.push_back(experience);
602
603 if replay_buffer.len() > self.config.replay_buffer_size {
604 replay_buffer.pop_front();
605 }
606 drop(replay_buffer);
607
608 *self.episode_reward.write().await += reward;
610 let mut stats = self.stats.write().await;
611 stats.total_steps += 1;
612 stats.total_reward += reward;
613
614 if let Action::Discrete(idx) = action {
616 let mut counts = self.action_counts.write().await;
617 let mut rewards = self.action_rewards.write().await;
618 counts[idx] += 1;
619 rewards[idx] += reward;
620 }
621
622 match self.config.algorithm {
624 RLAlgorithm::QLearning | RLAlgorithm::SARSA => {
625 drop(stats);
626 self.update_q_learning(state, &action, reward, next_state)
627 .await?;
628 }
629 RLAlgorithm::DQN => {
630 drop(stats);
631 self.update_dqn().await?;
632 }
633 _ => {
634 }
636 }
637
638 let mut stats = self.stats.write().await;
640 stats.current_epsilon =
641 (stats.current_epsilon * self.config.epsilon_decay).max(self.config.epsilon_min);
642
643 Ok(())
644 }
645
    /// Tabular TD(0) update:
    /// Q(s,a) += lr * (r + gamma * max_a' Q(s',a') - Q(s,a)).
    /// Continuous actions are silently ignored (tabular backend only).
    async fn update_q_learning(
        &self,
        state: &State,
        action: &Action,
        reward: f64,
        next_state: &State,
    ) -> Result<()> {
        if let Action::Discrete(action_idx) = action {
            let state_key = self.state_to_key(state);
            let next_state_key = self.state_to_key(next_state);

            let mut q_table = self.q_table.write().await;

            // Read (and lazily create) the next state's row inside its own
            // scope so the mutable borrow ends before the current state's
            // row is borrowed below.
            let max_next_q = {
                let next_q_values = q_table
                    .entry(next_state_key)
                    .or_insert_with(|| vec![0.0; self.config.n_actions]);
                next_q_values
                    .iter()
                    .copied()
                    .fold(f64::NEG_INFINITY, f64::max)
            };

            let q_values = q_table
                .entry(state_key.clone())
                .or_insert_with(|| vec![0.0; self.config.n_actions]);

            // Temporal-difference error toward the bootstrapped target.
            let current_q = q_values[*action_idx];
            let td_target = reward + self.config.discount_factor * max_next_q;
            let td_error = td_target - current_q;

            q_values[*action_idx] += self.config.learning_rate * td_error;

            debug!(
                "Q-learning update: state={}, action={}, Q={:.4}",
                state_key, action_idx, q_values[*action_idx]
            );
        }

        Ok(())
    }
691
    /// One DQN training step: sample a random minibatch from the replay
    /// buffer, accumulate squared TD-error against the target network,
    /// nudge the online network, and periodically hard-sync the target.
    async fn update_dqn(&self) -> Result<()> {
        let replay_buffer = self.replay_buffer.read().await;

        // Not enough experience collected yet to form a full batch.
        if replay_buffer.len() < self.config.batch_size {
            return Ok(());
        }

        // Sample batch indices with replacement.
        let batch_indices: Vec<usize> = {
            let mut rng = self.rng.lock().await;
            (0..self.config.batch_size)
                .map(|_| rng.random_range(0..replay_buffer.len()))
                .collect()
        };

        let batch: Vec<Experience> = batch_indices
            .iter()
            .map(|&i| replay_buffer[i].clone())
            .collect();
        drop(replay_buffer);

        let mut total_loss = 0.0;

        let q_network = self.q_network.read().await;
        let target_network = self.target_network.read().await;

        if let (Some(ref q_net), Some(ref target_net)) = (&*q_network, &*target_network) {
            for exp in &batch {
                let state_vec = Array1::from_vec(exp.state.to_vector());
                let next_state_vec = Array1::from_vec(exp.next_state.to_vector());

                let q_values = q_net.forward(&state_vec);
                // TD targets come from the (frozen) target network.
                let next_q_values = target_net.forward(&next_state_vec);

                let max_next_q = next_q_values
                    .iter()
                    .copied()
                    .fold(f64::NEG_INFINITY, f64::max);

                if let Action::Discrete(action_idx) = exp.action {
                    let td_target = exp.reward + self.config.discount_factor * max_next_q;
                    let td_error = td_target - q_values[action_idx];
                    total_loss += td_error * td_error;
                }
            }
        }
        // Release read guards before taking the write lock below.
        drop(q_network);
        drop(target_network);

        // NOTE(review): this scales all weights by the mean loss rather than
        // per-weight gradients — see NeuralNetwork::update; confirm intended.
        let mut q_network = self.q_network.write().await;
        if let Some(ref mut network) = *q_network {
            let gradient_scale = total_loss / self.config.batch_size as f64;
            network.update(gradient_scale, self.config.learning_rate);
        }
        drop(q_network);

        // Hard-sync the target network every `target_update_freq` updates.
        let mut counter = self.update_counter.write().await;
        *counter += 1;

        if *counter % self.config.target_update_freq == 0 {
            let q_net = self.q_network.read().await;
            if let Some(ref network) = *q_net {
                *self.target_network.write().await = Some(network.clone());
                debug!("Updated target network at step {}", *counter);
            }
        }

        // Incremental running mean of the loss; `learn` incremented
        // total_steps before calling, so it is >= 1 here (no underflow).
        let mut stats = self.stats.write().await;
        stats.avg_loss = (stats.avg_loss * (stats.total_steps - 1) as f64 + total_loss)
            / stats.total_steps as f64;

        Ok(())
    }
771
772 pub async fn end_episode(&mut self) -> Result<()> {
774 let episode_reward = *self.episode_reward.read().await;
775 *self.episode_reward.write().await = 0.0;
776
777 let mut stats = self.stats.write().await;
778 stats.total_episodes += 1;
779 stats.avg_reward_per_episode =
780 (stats.avg_reward_per_episode * (stats.total_episodes - 1) as f64 + episode_reward)
781 / stats.total_episodes as f64;
782
783 info!(
784 "Episode {} complete: reward={:.2}, avg_reward={:.2}",
785 stats.total_episodes, episode_reward, stats.avg_reward_per_episode
786 );
787
788 Ok(())
789 }
790
791 fn state_to_key(&self, state: &State) -> String {
793 format!(
795 "{:.0}_{:.0}_{:.2}_{:.2}_{}_{ :.2}",
796 (state.throughput / 1000.0).round(),
797 (state.latency_ms / 10.0).round(),
798 (state.cpu_utilization * 10.0).round() / 10.0,
799 (state.memory_utilization * 10.0).round() / 10.0,
800 state.queue_depth / 100,
801 (state.error_rate * 100.0).round() / 100.0,
802 )
803 }
804
    /// Returns a snapshot (clone) of the agent's running statistics.
    pub async fn get_stats(&self) -> RLStats {
        self.stats.read().await.clone()
    }

    /// Current exploration rate.
    pub async fn get_epsilon(&self) -> f64 {
        self.stats.read().await.current_epsilon
    }

    /// Overrides the exploration rate, clamped into [0, 1].
    pub async fn set_epsilon(&mut self, epsilon: f64) {
        self.stats.write().await.current_epsilon = epsilon.clamp(0.0, 1.0);
    }
819
820 pub async fn export_policy(&self) -> Result<String> {
822 let policy = match self.config.algorithm {
823 RLAlgorithm::QLearning | RLAlgorithm::SARSA => {
824 let q_table = self.q_table.read().await;
825 serde_json::json!({
826 "algorithm": "Q-Learning",
827 "q_table": q_table.iter().take(10).collect::<HashMap<_, _>>(), })
829 }
830 _ => {
831 let stats = self.get_stats().await;
832 serde_json::json!({
833 "algorithm": format!("{:?}", self.config.algorithm),
834 "stats": stats,
835 })
836 }
837 };
838
839 Ok(serde_json::to_string_pretty(&policy)?)
840 }
841}
842
#[cfg(test)]
mod tests {
    use super::*;

    /// A fixed, plausible state shared by the tests below.
    fn create_test_state() -> State {
        State {
            throughput: 10000.0,
            latency_ms: 5.0,
            cpu_utilization: 0.5,
            memory_utilization: 0.6,
            queue_depth: 100,
            error_rate: 0.01,
            features: vec![],
        }
    }

    /// Construction with the default config must succeed.
    #[tokio::test]
    async fn test_rl_agent_creation() {
        let config = RLConfig::default();
        let agent = RLAgent::new(config);
        assert!(agent.is_ok());
    }

    /// Q-learning selection always yields a valid discrete action index.
    #[tokio::test]
    async fn test_q_learning_action_selection() {
        let config = RLConfig {
            algorithm: RLAlgorithm::QLearning,
            n_actions: 5,
            ..Default::default()
        };

        let agent = RLAgent::new(config).unwrap();
        let state = create_test_state();

        let action = agent.select_action(&state).await;
        assert!(action.is_ok());

        if let Action::Discrete(idx) = action.unwrap() {
            assert!(idx < 5);
        }
    }

    /// DQN networks can be initialized and then used for selection.
    #[tokio::test]
    async fn test_dqn_initialization() {
        let config = RLConfig {
            algorithm: RLAlgorithm::DQN,
            n_actions: 10,
            hidden_units: vec![32, 32],
            ..Default::default()
        };

        let mut agent = RLAgent::new(config).unwrap();
        let state = create_test_state();

        agent.initialize_networks(state.dimension()).await.unwrap();

        let action = agent.select_action(&state).await;
        assert!(action.is_ok());
    }

    /// UCB selection succeeds even before any rewards are observed.
    #[tokio::test]
    async fn test_ucb_action_selection() {
        let config = RLConfig {
            algorithm: RLAlgorithm::UCB,
            n_actions: 5,
            ..Default::default()
        };

        let agent = RLAgent::new(config).unwrap();
        let action = agent.select_action_ucb().await;
        assert!(action.is_ok());
    }

    /// A single learn step updates the step and reward counters.
    #[tokio::test]
    async fn test_learning_update() {
        let config = RLConfig {
            algorithm: RLAlgorithm::QLearning,
            n_actions: 3,
            ..Default::default()
        };

        let mut agent = RLAgent::new(config).unwrap();
        let state = create_test_state();
        let action = Action::Discrete(1);
        let reward = 1.0;
        let next_state = create_test_state();

        let result = agent.learn(&state, action, reward, &next_state).await;
        assert!(result.is_ok());

        let stats = agent.get_stats().await;
        assert_eq!(stats.total_steps, 1);
        assert_eq!(stats.total_reward, 1.0);
    }

    /// Epsilon decays multiplicatively per learn call but never drops
    /// below the configured minimum.
    #[tokio::test]
    async fn test_epsilon_decay() {
        let config = RLConfig {
            epsilon: 1.0,
            epsilon_decay: 0.9,
            epsilon_min: 0.1,
            ..Default::default()
        };

        let mut agent = RLAgent::new(config).unwrap();
        let initial_epsilon = agent.get_epsilon().await;

        let state = create_test_state();
        for _ in 0..10 {
            agent
                .learn(&state, Action::Discrete(0), 0.0, &state)
                .await
                .unwrap();
        }

        let final_epsilon = agent.get_epsilon().await;
        assert!(final_epsilon < initial_epsilon);
        assert!(final_epsilon >= 0.1);
    }

    /// end_episode increments the episode counter and folds the
    /// accumulated reward into the running average.
    #[tokio::test]
    async fn test_episode_management() {
        let config = RLConfig::default();
        let mut agent = RLAgent::new(config).unwrap();

        let state = create_test_state();
        agent
            .learn(&state, Action::Discrete(0), 1.0, &state)
            .await
            .unwrap();
        agent
            .learn(&state, Action::Discrete(1), 2.0, &state)
            .await
            .unwrap();

        agent.end_episode().await.unwrap();

        let stats = agent.get_stats().await;
        assert_eq!(stats.total_episodes, 1);
        assert!(stats.avg_reward_per_episode > 0.0);
    }

    /// The replay buffer is bounded: oldest entries are evicted once the
    /// configured capacity is exceeded.
    #[tokio::test]
    async fn test_replay_buffer() {
        let config = RLConfig {
            replay_buffer_size: 5,
            ..Default::default()
        };

        let mut agent = RLAgent::new(config).unwrap();
        let state = create_test_state();

        for i in 0..10 {
            agent
                .learn(&state, Action::Discrete(0), i as f64, &state)
                .await
                .unwrap();
        }

        let buffer = agent.replay_buffer.read().await;
        assert_eq!(buffer.len(), 5);
    }

    /// Exported policy JSON always carries an "algorithm" field.
    #[tokio::test]
    async fn test_export_policy() {
        let config = RLConfig {
            algorithm: RLAlgorithm::QLearning,
            ..Default::default()
        };

        let mut agent = RLAgent::new(config).unwrap();
        let state = create_test_state();

        agent
            .learn(&state, Action::Discrete(0), 1.0, &state)
            .await
            .unwrap();

        let export = agent.export_policy().await;
        assert!(export.is_ok());
        assert!(export.unwrap().contains("algorithm"));
    }

    /// Thompson sampling selection succeeds on a fresh agent.
    #[tokio::test]
    async fn test_thompson_sampling() {
        let config = RLConfig {
            algorithm: RLAlgorithm::ThompsonSampling,
            n_actions: 5,
            ..Default::default()
        };

        let agent = RLAgent::new(config).unwrap();
        let action = agent.select_action_thompson().await;
        assert!(action.is_ok());
    }

    /// Several select/learn/end_episode cycles keep the counters consistent.
    #[tokio::test]
    async fn test_multiple_episodes() {
        let config = RLConfig {
            algorithm: RLAlgorithm::QLearning,
            n_actions: 3,
            ..Default::default()
        };

        let mut agent = RLAgent::new(config).unwrap();
        let state = create_test_state();

        for episode in 0..5 {
            for _ in 0..10 {
                let action = agent.select_action(&state).await.unwrap();
                let reward = if episode % 2 == 0 { 1.0 } else { -1.0 };
                agent.learn(&state, action, reward, &state).await.unwrap();
            }
            agent.end_episode().await.unwrap();
        }

        let stats = agent.get_stats().await;
        assert_eq!(stats.total_episodes, 5);
        assert_eq!(stats.total_steps, 50);
    }
}