1use anyhow::{anyhow, Result};
33use scirs2_core::ndarray_ext::{Array1, Array2};
34use scirs2_core::random::{Random, Rng};
35use serde::{Deserialize, Serialize};
36use std::collections::{HashMap, VecDeque};
37use std::sync::Arc;
38use tokio::sync::{Mutex, RwLock};
39use tracing::{debug, info};
40
/// Reinforcement-learning algorithm driving action selection and updates.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum RLAlgorithm {
    /// Tabular Q-learning over discretized states.
    QLearning,
    /// Deep Q-Network: neural value function with replay buffer and target net.
    DQN,
    /// SARSA; currently routed through the same tabular update as QLearning.
    SARSA,
    /// Actor-critic (network init only; no dedicated policy yet).
    ActorCritic,
    /// REINFORCE (no dedicated policy yet; falls back to epsilon-greedy).
    REINFORCE,
    /// Proximal Policy Optimization (network init only; no dedicated policy yet).
    PPO,
    /// Upper Confidence Bound bandit.
    UCB,
    /// Thompson-sampling bandit (approximate Beta draws).
    ThompsonSampling,
    /// Plain epsilon-greedy bandit over empirical mean rewards.
    EpsilonGreedy,
}
63
64#[derive(Debug, Clone, Serialize, Deserialize)]
/// Snapshot of system performance metrics used as the RL observation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct State {
    /// Observed throughput (units per time; discretized by /1000 in the Q-table key).
    pub throughput: f64,
    /// Observed latency in milliseconds.
    pub latency_ms: f64,
    /// CPU utilization — presumably a fraction in [0, 1]; not enforced.
    pub cpu_utilization: f64,
    /// Memory utilization — presumably a fraction in [0, 1]; not enforced.
    pub memory_utilization: f64,
    /// Number of queued items.
    pub queue_depth: usize,
    /// Error rate — presumably a fraction in [0, 1]; not enforced.
    pub error_rate: f64,
    /// Additional caller-defined features appended to the state vector.
    pub features: Vec<f64>,
}
82
83impl State {
84 pub fn to_vector(&self) -> Vec<f64> {
86 let mut vec = vec![
87 self.throughput,
88 self.latency_ms,
89 self.cpu_utilization,
90 self.memory_utilization,
91 self.queue_depth as f64,
92 self.error_rate,
93 ];
94 vec.extend(&self.features);
95 vec
96 }
97
98 pub fn dimension(&self) -> usize {
100 6 + self.features.len()
101 }
102}
103
/// An action chosen by the agent: either a discrete index or a vector of
/// continuous parameters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Action {
    /// Index into the discrete action set (0..n_actions).
    Discrete(usize),
    /// Continuous action parameters.
    Continuous(Vec<f64>),
}
112
113impl Action {
114 pub fn as_index(&self) -> Option<usize> {
116 match self {
117 Action::Discrete(idx) => Some(*idx),
118 _ => None,
119 }
120 }
121
122 pub fn as_vector(&self) -> Option<&[f64]> {
124 match self {
125 Action::Continuous(vec) => Some(vec),
126 _ => None,
127 }
128 }
129}
130
/// One transition `(s, a, r, s', done)` stored in the replay buffer.
#[derive(Debug, Clone)]
pub struct Experience {
    /// State before taking the action.
    pub state: State,
    /// Action taken.
    pub action: Action,
    /// Reward received for the transition.
    pub reward: f64,
    /// Resulting state.
    pub next_state: State,
    /// Terminal flag; `learn` always records `false` (episode boundaries are
    /// signalled via `end_episode` instead).
    pub done: bool,
}
145
/// Hyperparameters for [`RLAgent`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RLConfig {
    /// Which learning algorithm to run.
    pub algorithm: RLAlgorithm,
    /// Step size for Q/gradient updates.
    pub learning_rate: f64,
    /// Discount factor gamma for future rewards.
    pub discount_factor: f64,
    /// Initial exploration rate.
    pub epsilon: f64,
    /// Multiplicative epsilon decay applied once per `learn` call.
    pub epsilon_decay: f64,
    /// Floor for the decayed epsilon.
    pub epsilon_min: f64,
    /// Maximum number of transitions kept in the replay buffer.
    pub replay_buffer_size: usize,
    /// Minibatch size for DQN updates.
    pub batch_size: usize,
    /// DQN gradient steps between target-network syncs.
    pub target_update_freq: usize,
    /// Size of the discrete action space.
    pub n_actions: usize,
    /// Hidden layer widths for the neural algorithms.
    pub hidden_units: Vec<usize>,
    /// Prioritized replay flag — currently unused by the sampling code.
    pub prioritized_replay: bool,
    /// Exploration constant `c` in the UCB1 bonus term.
    pub ucb_c: f64,
}
176
impl Default for RLConfig {
    /// Conventional starting point: DQN with a 64x64 network, gamma 0.99,
    /// and epsilon annealed from 1.0 down to 0.01.
    fn default() -> Self {
        Self {
            algorithm: RLAlgorithm::DQN,
            learning_rate: 0.001,
            discount_factor: 0.99,
            epsilon: 1.0,
            epsilon_decay: 0.995,
            epsilon_min: 0.01,
            replay_buffer_size: 10000,
            batch_size: 32,
            target_update_freq: 100,
            n_actions: 10,
            hidden_units: vec![64, 64],
            prioritized_replay: false,
            ucb_c: 2.0,
        }
    }
}
196
/// Tabular value store: discretized state key (see `state_to_key`) mapped to
/// one Q-value per action.
type QTable = HashMap<String, Vec<f64>>;
199
/// Minimal fully-connected network used as the DQN value function.
/// Layer i maps activations through `weights[i]` and `biases[i]`.
#[derive(Debug, Clone)]
pub struct NeuralNetwork {
    /// One weight matrix per layer, shape (fan_in, fan_out).
    pub weights: Vec<Array2<f64>>,
    /// One bias vector per layer, length fan_out.
    pub biases: Vec<Array1<f64>>,
}
208
209impl NeuralNetwork {
210 pub fn new(
212 input_dim: usize,
213 hidden_dims: &[usize],
214 output_dim: usize,
215 rng: &mut Random,
216 ) -> Self {
217 let mut weights = Vec::new();
218 let mut biases = Vec::new();
219
220 let mut dims = vec![input_dim];
221 dims.extend(hidden_dims);
222 dims.push(output_dim);
223
224 for i in 0..dims.len() - 1 {
225 let w = Self::init_weights(dims[i], dims[i + 1], rng);
226 let b = Array1::zeros(dims[i + 1]);
227 weights.push(w);
228 biases.push(b);
229 }
230
231 Self { weights, biases }
232 }
233
234 fn init_weights(input_dim: usize, output_dim: usize, rng: &mut Random) -> Array2<f64> {
236 let scale = (2.0 / (input_dim + output_dim) as f64).sqrt();
237 let values: Vec<f64> = (0..input_dim * output_dim)
238 .map(|_| (rng.random::<f64>() * 2.0 - 1.0) * scale)
239 .collect();
240 Array2::from_shape_vec((input_dim, output_dim), values).unwrap()
241 }
242
243 pub fn forward(&self, input: &Array1<f64>) -> Array1<f64> {
245 let mut activation = input.clone();
246
247 for (w, b) in self.weights.iter().zip(&self.biases) {
248 activation = activation.dot(w) + b;
250
251 if w != self.weights.last().unwrap() {
253 activation.mapv_inplace(|x| x.max(0.0));
254 }
255 }
256
257 activation
258 }
259
260 pub fn update(&mut self, gradient_scale: f64, learning_rate: f64) {
262 for w in &mut self.weights {
263 w.mapv_inplace(|x| x - learning_rate * gradient_scale);
264 }
265 }
266}
267
/// Aggregate training statistics reported by [`RLAgent`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RLStats {
    /// Total `learn` calls processed.
    pub total_steps: u64,
    /// Episodes closed via `end_episode`.
    pub total_episodes: u64,
    /// Running mean of per-episode rewards.
    pub avg_reward_per_episode: f64,
    /// Current exploration rate after decay.
    pub current_epsilon: f64,
    /// Sum of all rewards seen.
    pub total_reward: f64,
    /// Mean Q-value — never updated by the current code.
    pub avg_q_value: f64,
    /// Running mean of the DQN batch loss.
    pub avg_loss: f64,
}
286
impl Default for RLStats {
    /// Zeroed counters with full exploration (`current_epsilon = 1.0`);
    /// a manual impl because the epsilon default is non-zero.
    fn default() -> Self {
        Self {
            total_steps: 0,
            total_episodes: 0,
            avg_reward_per_episode: 0.0,
            current_epsilon: 1.0,
            total_reward: 0.0,
            avg_q_value: 0.0,
            avg_loss: 0.0,
        }
    }
}
300
/// Reinforcement-learning agent exposing tabular, deep, and bandit
/// algorithms behind one async API. All mutable state lives behind
/// `Arc`-wrapped tokio locks.
pub struct RLAgent {
    /// Hyperparameters and algorithm selection.
    config: RLConfig,
    /// Tabular Q-values keyed by discretized state (QLearning/SARSA).
    q_table: Arc<RwLock<QTable>>,
    /// Online value network; `None` until `initialize_networks` runs.
    q_network: Arc<RwLock<Option<NeuralNetwork>>>,
    /// Periodically-synced copy of `q_network` used for TD targets.
    target_network: Arc<RwLock<Option<NeuralNetwork>>>,
    /// Bounded FIFO of past transitions for DQN minibatch training.
    replay_buffer: Arc<RwLock<VecDeque<Experience>>>,
    /// Per-action pull counts (bandit bookkeeping).
    action_counts: Arc<RwLock<Vec<u64>>>,
    /// Per-action cumulative reward (bandit bookkeeping).
    action_rewards: Arc<RwLock<Vec<f64>>>,
    /// Aggregate training statistics.
    stats: Arc<RwLock<RLStats>>,
    /// Shared RNG for exploration and weight init.
    #[allow(clippy::arc_with_non_send_sync)]
    rng: Arc<Mutex<Random>>,
    /// Reward accumulated since the last `end_episode`.
    episode_reward: Arc<RwLock<f64>>,
    /// DQN gradient steps taken; drives target-network sync cadence.
    update_counter: Arc<RwLock<usize>>,
}
326
327impl RLAgent {
328 #[allow(clippy::arc_with_non_send_sync)]
330 pub fn new(config: RLConfig) -> Result<Self> {
331 let action_counts = vec![0u64; config.n_actions];
332 let action_rewards = vec![0.0; config.n_actions];
333 let buffer_size = config.replay_buffer_size;
334 let epsilon = config.epsilon;
335
336 Ok(Self {
337 config,
338 q_table: Arc::new(RwLock::new(HashMap::new())),
339 q_network: Arc::new(RwLock::new(None)),
340 target_network: Arc::new(RwLock::new(None)),
341 replay_buffer: Arc::new(RwLock::new(VecDeque::with_capacity(buffer_size))),
342 action_counts: Arc::new(RwLock::new(action_counts)),
343 action_rewards: Arc::new(RwLock::new(action_rewards)),
344 stats: Arc::new(RwLock::new(RLStats {
345 current_epsilon: epsilon,
346 ..Default::default()
347 })),
348 rng: Arc::new(Mutex::new(Random::default())),
349 episode_reward: Arc::new(RwLock::new(0.0)),
350 update_counter: Arc::new(RwLock::new(0)),
351 })
352 }
353
354 pub async fn initialize_networks(&mut self, state_dim: usize) -> Result<()> {
356 if matches!(
357 self.config.algorithm,
358 RLAlgorithm::DQN | RLAlgorithm::ActorCritic | RLAlgorithm::PPO
359 ) {
360 let mut rng = self.rng.lock().await;
361
362 let q_net = NeuralNetwork::new(
363 state_dim,
364 &self.config.hidden_units,
365 self.config.n_actions,
366 &mut rng,
367 );
368
369 let target_net = NeuralNetwork::new(
370 state_dim,
371 &self.config.hidden_units,
372 self.config.n_actions,
373 &mut rng,
374 );
375
376 *self.q_network.write().await = Some(q_net);
377 *self.target_network.write().await = Some(target_net);
378
379 info!(
380 "Initialized neural networks with state_dim={}, n_actions={}",
381 state_dim, self.config.n_actions
382 );
383 }
384
385 Ok(())
386 }
387
388 pub async fn select_action(&self, state: &State) -> Result<Action> {
390 match self.config.algorithm {
391 RLAlgorithm::QLearning | RLAlgorithm::SARSA => {
392 self.select_action_q_learning(state).await
393 }
394 RLAlgorithm::DQN => self.select_action_dqn(state).await,
395 RLAlgorithm::UCB => self.select_action_ucb().await,
396 RLAlgorithm::ThompsonSampling => self.select_action_thompson().await,
397 RLAlgorithm::EpsilonGreedy => self.select_action_epsilon_greedy().await,
398 _ => {
399 self.select_action_epsilon_greedy().await
401 }
402 }
403 }
404
405 async fn select_action_q_learning(&self, state: &State) -> Result<Action> {
407 let stats = self.stats.read().await;
408 let epsilon = stats.current_epsilon;
409 drop(stats);
410
411 let mut rng = self.rng.lock().await;
412
413 if rng.random::<f64>() < epsilon {
414 let action_idx = rng.random_range(0..self.config.n_actions);
416 Ok(Action::Discrete(action_idx))
417 } else {
418 let state_key = self.state_to_key(state);
420 let q_table = self.q_table.read().await;
421
422 let q_values = q_table
423 .get(&state_key)
424 .cloned()
425 .unwrap_or_else(|| vec![0.0; self.config.n_actions]);
426
427 let best_action = q_values
428 .iter()
429 .enumerate()
430 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
431 .map(|(idx, _)| idx)
432 .unwrap_or(0);
433
434 Ok(Action::Discrete(best_action))
435 }
436 }
437
438 async fn select_action_dqn(&self, state: &State) -> Result<Action> {
440 let stats = self.stats.read().await;
441 let epsilon = stats.current_epsilon;
442 drop(stats);
443
444 let mut rng = self.rng.lock().await;
445
446 if rng.random::<f64>() < epsilon {
447 let action_idx = rng.random_range(0..self.config.n_actions);
448 Ok(Action::Discrete(action_idx))
449 } else {
450 drop(rng);
451
452 let q_network = self.q_network.read().await;
453 if let Some(ref network) = *q_network {
454 let state_vec = Array1::from_vec(state.to_vector());
455 let q_values = network.forward(&state_vec);
456
457 let best_action = q_values
458 .iter()
459 .enumerate()
460 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
461 .map(|(idx, _)| idx)
462 .unwrap_or(0);
463
464 Ok(Action::Discrete(best_action))
465 } else {
466 Err(anyhow!("Q-network not initialized"))
467 }
468 }
469 }
470
471 async fn select_action_ucb(&self) -> Result<Action> {
473 let action_counts = self.action_counts.read().await;
474 let action_rewards = self.action_rewards.read().await;
475 let stats = self.stats.read().await;
476 let total_steps = stats.total_steps;
477
478 let mut ucb_values = Vec::with_capacity(self.config.n_actions);
479
480 for i in 0..self.config.n_actions {
481 let count = action_counts[i];
482 let avg_reward = if count > 0 {
483 action_rewards[i] / count as f64
484 } else {
485 f64::INFINITY };
487
488 let exploration_bonus = if count > 0 {
489 self.config.ucb_c * ((total_steps as f64).ln() / count as f64).sqrt()
490 } else {
491 f64::INFINITY
492 };
493
494 ucb_values.push(avg_reward + exploration_bonus);
495 }
496
497 let best_action = ucb_values
498 .iter()
499 .enumerate()
500 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
501 .map(|(idx, _)| idx)
502 .unwrap_or(0);
503
504 Ok(Action::Discrete(best_action))
505 }
506
507 async fn select_action_thompson(&self) -> Result<Action> {
509 let action_counts = self.action_counts.read().await;
510 let action_rewards = self.action_rewards.read().await;
511 let mut rng = self.rng.lock().await;
512
513 let mut sampled_values = Vec::with_capacity(self.config.n_actions);
514
515 for i in 0..self.config.n_actions {
516 let count = action_counts[i];
517 let sum_reward = action_rewards[i];
518
519 let alpha = sum_reward + 1.0;
521 let beta = (count as f64 - sum_reward).max(0.0) + 1.0;
522
523 let sample = rng.random::<f64>().powf(1.0 / alpha)
525 * (1.0 - rng.random::<f64>()).powf(1.0 / beta);
526 sampled_values.push(sample);
527 }
528
529 let best_action = sampled_values
530 .iter()
531 .enumerate()
532 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
533 .map(|(idx, _)| idx)
534 .unwrap_or(0);
535
536 Ok(Action::Discrete(best_action))
537 }
538
539 async fn select_action_epsilon_greedy(&self) -> Result<Action> {
541 let stats = self.stats.read().await;
542 let epsilon = stats.current_epsilon;
543 drop(stats);
544
545 let mut rng = self.rng.lock().await;
546
547 if rng.random::<f64>() < epsilon {
548 let action_idx = rng.random_range(0..self.config.n_actions);
549 Ok(Action::Discrete(action_idx))
550 } else {
551 drop(rng);
552
553 let action_counts = self.action_counts.read().await;
554 let action_rewards = self.action_rewards.read().await;
555
556 let best_action = (0..self.config.n_actions)
557 .max_by(|&a, &b| {
558 let avg_a = if action_counts[a] > 0 {
559 action_rewards[a] / action_counts[a] as f64
560 } else {
561 0.0
562 };
563 let avg_b = if action_counts[b] > 0 {
564 action_rewards[b] / action_counts[b] as f64
565 } else {
566 0.0
567 };
568 avg_a
569 .partial_cmp(&avg_b)
570 .unwrap_or(std::cmp::Ordering::Equal)
571 })
572 .unwrap_or(0);
573
574 Ok(Action::Discrete(best_action))
575 }
576 }
577
578 pub async fn learn(
580 &mut self,
581 state: &State,
582 action: Action,
583 reward: f64,
584 next_state: &State,
585 ) -> Result<()> {
586 let experience = Experience {
588 state: state.clone(),
589 action: action.clone(),
590 reward,
591 next_state: next_state.clone(),
592 done: false,
593 };
594
595 let mut replay_buffer = self.replay_buffer.write().await;
596 replay_buffer.push_back(experience);
597
598 if replay_buffer.len() > self.config.replay_buffer_size {
599 replay_buffer.pop_front();
600 }
601 drop(replay_buffer);
602
603 *self.episode_reward.write().await += reward;
605 let mut stats = self.stats.write().await;
606 stats.total_steps += 1;
607 stats.total_reward += reward;
608
609 if let Action::Discrete(idx) = action {
611 let mut counts = self.action_counts.write().await;
612 let mut rewards = self.action_rewards.write().await;
613 counts[idx] += 1;
614 rewards[idx] += reward;
615 }
616
617 match self.config.algorithm {
619 RLAlgorithm::QLearning | RLAlgorithm::SARSA => {
620 drop(stats);
621 self.update_q_learning(state, &action, reward, next_state)
622 .await?;
623 }
624 RLAlgorithm::DQN => {
625 drop(stats);
626 self.update_dqn().await?;
627 }
628 _ => {
629 }
631 }
632
633 let mut stats = self.stats.write().await;
635 stats.current_epsilon =
636 (stats.current_epsilon * self.config.epsilon_decay).max(self.config.epsilon_min);
637
638 Ok(())
639 }
640
641 async fn update_q_learning(
643 &self,
644 state: &State,
645 action: &Action,
646 reward: f64,
647 next_state: &State,
648 ) -> Result<()> {
649 if let Action::Discrete(action_idx) = action {
650 let state_key = self.state_to_key(state);
651 let next_state_key = self.state_to_key(next_state);
652
653 let mut q_table = self.q_table.write().await;
654
655 let max_next_q = {
657 let next_q_values = q_table
658 .entry(next_state_key)
659 .or_insert_with(|| vec![0.0; self.config.n_actions]);
660 next_q_values
661 .iter()
662 .copied()
663 .fold(f64::NEG_INFINITY, f64::max)
664 };
665
666 let q_values = q_table
668 .entry(state_key.clone())
669 .or_insert_with(|| vec![0.0; self.config.n_actions]);
670
671 let current_q = q_values[*action_idx];
673 let td_target = reward + self.config.discount_factor * max_next_q;
674 let td_error = td_target - current_q;
675
676 q_values[*action_idx] += self.config.learning_rate * td_error;
677
678 debug!(
679 "Q-learning update: state={}, action={}, Q={:.4}",
680 state_key, action_idx, q_values[*action_idx]
681 );
682 }
683
684 Ok(())
685 }
686
    /// One DQN training step over a random minibatch from the replay buffer.
    ///
    /// Computes the summed squared TD error of the batch, applies the
    /// placeholder `NeuralNetwork::update` scaled by the mean loss, and
    /// refreshes the target network every `target_update_freq` steps.
    /// No-op until the buffer holds at least one full batch.
    async fn update_dqn(&self) -> Result<()> {
        let replay_buffer = self.replay_buffer.read().await;

        if replay_buffer.len() < self.config.batch_size {
            return Ok(()); }

        // Sample batch indices with replacement.
        let batch_indices: Vec<usize> = {
            let mut rng = self.rng.lock().await;
            (0..self.config.batch_size)
                .map(|_| rng.random_range(0..replay_buffer.len()))
                .collect()
        };

        let batch: Vec<Experience> = batch_indices
            .iter()
            .map(|&i| replay_buffer[i].clone())
            .collect();
        drop(replay_buffer);

        let mut total_loss = 0.0;

        let q_network = self.q_network.read().await;
        let target_network = self.target_network.read().await;

        if let (Some(ref q_net), Some(ref target_net)) = (&*q_network, &*target_network) {
            for exp in &batch {
                let state_vec = Array1::from_vec(exp.state.to_vector());
                let next_state_vec = Array1::from_vec(exp.next_state.to_vector());

                let q_values = q_net.forward(&state_vec);
                let next_q_values = target_net.forward(&next_state_vec);

                // Bellman target uses the frozen target network's best value.
                let max_next_q = next_q_values
                    .iter()
                    .copied()
                    .fold(f64::NEG_INFINITY, f64::max);

                if let Action::Discrete(action_idx) = exp.action {
                    let td_target = exp.reward + self.config.discount_factor * max_next_q;
                    let td_error = td_target - q_values[action_idx];
                    total_loss += td_error * td_error;
                }
            }
        }
        // Release the read guards before taking the write lock below.
        drop(q_network);
        drop(target_network);

        let mut q_network = self.q_network.write().await;
        if let Some(ref mut network) = *q_network {
            let gradient_scale = total_loss / self.config.batch_size as f64;
            network.update(gradient_scale, self.config.learning_rate);
        }
        drop(q_network);

        let mut counter = self.update_counter.write().await;
        *counter += 1;

        // Periodically sync the target network with the online network.
        if *counter % self.config.target_update_freq == 0 {
            let q_net = self.q_network.read().await;
            if let Some(ref network) = *q_net {
                *self.target_network.write().await = Some(network.clone());
                debug!("Updated target network at step {}", *counter);
            }
        }

        // Running mean of the batch loss. `learn` increments total_steps
        // before calling us, so total_steps >= 1 here and the subtraction
        // cannot underflow on that path.
        let mut stats = self.stats.write().await;
        stats.avg_loss = (stats.avg_loss * (stats.total_steps - 1) as f64 + total_loss)
            / stats.total_steps as f64;

        Ok(())
    }
766
767 pub async fn end_episode(&mut self) -> Result<()> {
769 let episode_reward = *self.episode_reward.read().await;
770 *self.episode_reward.write().await = 0.0;
771
772 let mut stats = self.stats.write().await;
773 stats.total_episodes += 1;
774 stats.avg_reward_per_episode =
775 (stats.avg_reward_per_episode * (stats.total_episodes - 1) as f64 + episode_reward)
776 / stats.total_episodes as f64;
777
778 info!(
779 "Episode {} complete: reward={:.2}, avg_reward={:.2}",
780 stats.total_episodes, episode_reward, stats.avg_reward_per_episode
781 );
782
783 Ok(())
784 }
785
786 fn state_to_key(&self, state: &State) -> String {
788 format!(
790 "{:.0}_{:.0}_{:.2}_{:.2}_{}_{ :.2}",
791 (state.throughput / 1000.0).round(),
792 (state.latency_ms / 10.0).round(),
793 (state.cpu_utilization * 10.0).round() / 10.0,
794 (state.memory_utilization * 10.0).round() / 10.0,
795 state.queue_depth / 100,
796 (state.error_rate * 100.0).round() / 100.0,
797 )
798 }
799
    /// Returns a snapshot of the aggregate training statistics.
    pub async fn get_stats(&self) -> RLStats {
        self.stats.read().await.clone()
    }
804
    /// Current exploration rate (after any decay applied by `learn`).
    pub async fn get_epsilon(&self) -> f64 {
        self.stats.read().await.current_epsilon
    }
809
    /// Overrides the exploration rate, clamped to [0.0, 1.0].
    pub async fn set_epsilon(&mut self, epsilon: f64) {
        self.stats.write().await.current_epsilon = epsilon.clamp(0.0, 1.0);
    }
814
    /// Serializes a JSON summary of the learned policy.
    ///
    /// Tabular algorithms export up to 10 Q-table entries (HashMap
    /// iteration order, i.e. unspecified); everything else exports the
    /// aggregate stats.
    pub async fn export_policy(&self) -> Result<String> {
        let policy = match self.config.algorithm {
            RLAlgorithm::QLearning | RLAlgorithm::SARSA => {
                let q_table = self.q_table.read().await;
                // NOTE(review): SARSA is also labeled "Q-Learning" here —
                // confirm whether consumers rely on this string.
                serde_json::json!({
                    "algorithm": "Q-Learning",
                    "q_table": q_table.iter().take(10).collect::<HashMap<_, _>>(), })
            }
            _ => {
                let stats = self.get_stats().await;
                serde_json::json!({
                    "algorithm": format!("{:?}", self.config.algorithm),
                    "stats": stats,
                })
            }
        };

        Ok(serde_json::to_string_pretty(&policy)?)
    }
836}
837
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixed, representative observation reused by the tests below.
    fn create_test_state() -> State {
        State {
            throughput: 10000.0,
            latency_ms: 5.0,
            cpu_utilization: 0.5,
            memory_utilization: 0.6,
            queue_depth: 100,
            error_rate: 0.01,
            features: vec![],
        }
    }

    // Construction with default hyperparameters should never fail.
    #[tokio::test]
    async fn test_rl_agent_creation() {
        let config = RLConfig::default();
        let agent = RLAgent::new(config);
        assert!(agent.is_ok());
    }

    // Q-learning must return a discrete action inside [0, n_actions).
    #[tokio::test]
    async fn test_q_learning_action_selection() {
        let config = RLConfig {
            algorithm: RLAlgorithm::QLearning,
            n_actions: 5,
            ..Default::default()
        };

        let agent = RLAgent::new(config).unwrap();
        let state = create_test_state();

        let action = agent.select_action(&state).await;
        assert!(action.is_ok());

        if let Action::Discrete(idx) = action.unwrap() {
            assert!(idx < 5);
        }
    }

    // After network initialization, DQN action selection should succeed.
    #[tokio::test]
    async fn test_dqn_initialization() {
        let config = RLConfig {
            algorithm: RLAlgorithm::DQN,
            n_actions: 10,
            hidden_units: vec![32, 32],
            ..Default::default()
        };

        let mut agent = RLAgent::new(config).unwrap();
        let state = create_test_state();

        agent.initialize_networks(state.dimension()).await.unwrap();

        let action = agent.select_action(&state).await;
        assert!(action.is_ok());
    }

    // UCB needs no prior data to produce an action (unplayed arms score INF).
    #[tokio::test]
    async fn test_ucb_action_selection() {
        let config = RLConfig {
            algorithm: RLAlgorithm::UCB,
            n_actions: 5,
            ..Default::default()
        };

        let agent = RLAgent::new(config).unwrap();
        let action = agent.select_action_ucb().await;
        assert!(action.is_ok());
    }

    // A single learn() call bumps the step and reward counters.
    #[tokio::test]
    async fn test_learning_update() {
        let config = RLConfig {
            algorithm: RLAlgorithm::QLearning,
            n_actions: 3,
            ..Default::default()
        };

        let mut agent = RLAgent::new(config).unwrap();
        let state = create_test_state();
        let action = Action::Discrete(1);
        let reward = 1.0;
        let next_state = create_test_state();

        let result = agent.learn(&state, action, reward, &next_state).await;
        assert!(result.is_ok());

        let stats = agent.get_stats().await;
        assert_eq!(stats.total_steps, 1);
        assert_eq!(stats.total_reward, 1.0);
    }

    // Epsilon decays multiplicatively per step but never drops below epsilon_min.
    #[tokio::test]
    async fn test_epsilon_decay() {
        let config = RLConfig {
            epsilon: 1.0,
            epsilon_decay: 0.9,
            epsilon_min: 0.1,
            ..Default::default()
        };

        let mut agent = RLAgent::new(config).unwrap();
        let initial_epsilon = agent.get_epsilon().await;

        let state = create_test_state();
        for _ in 0..10 {
            agent
                .learn(&state, Action::Discrete(0), 0.0, &state)
                .await
                .unwrap();
        }

        let final_epsilon = agent.get_epsilon().await;
        assert!(final_epsilon < initial_epsilon);
        assert!(final_epsilon >= 0.1);
    }

    // end_episode() folds the accumulated reward into the episode average.
    #[tokio::test]
    async fn test_episode_management() {
        let config = RLConfig::default();
        let mut agent = RLAgent::new(config).unwrap();

        let state = create_test_state();
        agent
            .learn(&state, Action::Discrete(0), 1.0, &state)
            .await
            .unwrap();
        agent
            .learn(&state, Action::Discrete(1), 2.0, &state)
            .await
            .unwrap();

        agent.end_episode().await.unwrap();

        let stats = agent.get_stats().await;
        assert_eq!(stats.total_episodes, 1);
        assert!(stats.avg_reward_per_episode > 0.0);
    }

    // The replay buffer is bounded: oldest entries are evicted FIFO.
    #[tokio::test]
    async fn test_replay_buffer() {
        let config = RLConfig {
            replay_buffer_size: 5,
            ..Default::default()
        };

        let mut agent = RLAgent::new(config).unwrap();
        let state = create_test_state();

        for i in 0..10 {
            agent
                .learn(&state, Action::Discrete(0), i as f64, &state)
                .await
                .unwrap();
        }

        let buffer = agent.replay_buffer.read().await;
        assert_eq!(buffer.len(), 5); }

    // Exported policy JSON always carries an "algorithm" field.
    #[tokio::test]
    async fn test_export_policy() {
        let config = RLConfig {
            algorithm: RLAlgorithm::QLearning,
            ..Default::default()
        };

        let mut agent = RLAgent::new(config).unwrap();
        let state = create_test_state();

        agent
            .learn(&state, Action::Discrete(0), 1.0, &state)
            .await
            .unwrap();

        let export = agent.export_policy().await;
        assert!(export.is_ok());
        assert!(export.unwrap().contains("algorithm"));
    }

    // Thompson sampling works with no prior observations.
    #[tokio::test]
    async fn test_thompson_sampling() {
        let config = RLConfig {
            algorithm: RLAlgorithm::ThompsonSampling,
            n_actions: 5,
            ..Default::default()
        };

        let agent = RLAgent::new(config).unwrap();
        let action = agent.select_action_thompson().await;
        assert!(action.is_ok());
    }

    // Steps and episodes accumulate correctly across several episodes.
    #[tokio::test]
    async fn test_multiple_episodes() {
        let config = RLConfig {
            algorithm: RLAlgorithm::QLearning,
            n_actions: 3,
            ..Default::default()
        };

        let mut agent = RLAgent::new(config).unwrap();
        let state = create_test_state();

        for episode in 0..5 {
            for _ in 0..10 {
                let action = agent.select_action(&state).await.unwrap();
                let reward = if episode % 2 == 0 { 1.0 } else { -1.0 };
                agent.learn(&state, action, reward, &state).await.unwrap();
            }
            agent.end_episode().await.unwrap();
        }

        let stats = agent.get_stats().await;
        assert_eq!(stats.total_episodes, 5);
        assert_eq!(stats.total_steps, 50);
    }
}