use crate::{
    error::QuantRS2Result, gate::multi::*, gate::single::*, gate::GateOp, qubit::QubitId,
    variational::VariationalOptimizer,
};
use ndarray::{Array1, Array2};
use rand::{rngs::StdRng, Rng, SeedableRng};
use serde::{Deserialize, Serialize};
use std::collections::VecDeque;

/// Configuration for a quantum reinforcement learning agent.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantumRLConfig {
    pub state_qubits: usize,
    pub action_qubits: usize,
    pub value_qubits: usize,
    pub learning_rate: f64,
    pub discount_factor: f64,
    pub exploration_rate: f64,
    pub exploration_decay: f64,
    pub min_exploration_rate: f64,
    pub replay_buffer_size: usize,
    pub batch_size: usize,
    pub circuit_depth: usize,
    pub use_quantum_advantage: bool,
    pub random_seed: Option<u64>,
}

impl Default for QuantumRLConfig {
    fn default() -> Self {
        Self {
            state_qubits: 4,
            action_qubits: 2,
            value_qubits: 3,
            learning_rate: 0.01,
            discount_factor: 0.99,
            exploration_rate: 1.0,
            exploration_decay: 0.995,
            min_exploration_rate: 0.01,
            replay_buffer_size: 10000,
            batch_size: 32,
            circuit_depth: 6,
            use_quantum_advantage: true,
            random_seed: None,
        }
    }
}

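/// A single state-action-reward transition used for experience replay.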
#[derive(Debug, Clone)]
pub struct Experience {
    pub state: Array1<f64>,
    pub action: usize,
    pub reward: f64,
    pub next_state: Array1<f64>,
    pub done: bool,
}

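/// Fixed-capacity replay buffer that samples past experiences uniformly at random.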
pub struct ReplayBuffer {
    buffer: VecDeque<Experience>,
    max_size: usize,
    rng: StdRng,
}

impl ReplayBuffer {
    pub fn new(max_size: usize, seed: Option<u64>) -> Self {
        let rng = match seed {
            Some(s) => StdRng::seed_from_u64(s),
            // No seed provided: fall back to a fixed seed for reproducibility.
            None => StdRng::from_seed([0; 32]),
        };

        Self {
            buffer: VecDeque::with_capacity(max_size),
            max_size,
            rng,
        }
    }

    pub fn add(&mut self, experience: Experience) {
        if self.buffer.len() >= self.max_size {
            self.buffer.pop_front();
        }
        self.buffer.push_back(experience);
    }

    pub fn sample(&mut self, batch_size: usize) -> Vec<Experience> {
        let mut samples = Vec::new();
        let buffer_size = self.buffer.len();

        if buffer_size < batch_size {
            return self.buffer.iter().cloned().collect();
        }

        for _ in 0..batch_size {
            let idx = self.rng.random_range(0..buffer_size);
            samples.push(self.buffer[idx].clone());
        }

        samples
    }

    pub fn size(&self) -> usize {
        self.buffer.len()
    }

    pub fn can_sample(&self, batch_size: usize) -> bool {
        self.buffer.len() >= batch_size
    }
}

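/// Quantum deep Q-network agent combining a value network, a periodically
/// synchronized target network, and a policy network trained from a replay buffer.
///
/// Illustrative usage sketch (the environment loop and variables are assumptions):
///
/// ```ignore
/// let mut agent = QuantumDQN::new(QuantumRLConfig::default())?;
/// let action = agent.select_action(&state)?;
/// agent.store_experience(Experience { state, action, reward, next_state, done });
/// let metrics = agent.train()?;
/// ```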
pub struct QuantumDQN {
    config: QuantumRLConfig,
    q_network: QuantumValueNetwork,
    target_q_network: QuantumValueNetwork,
    policy_network: QuantumPolicyNetwork,
    replay_buffer: ReplayBuffer,
    training_steps: usize,
    episodes: usize,
    current_exploration_rate: f64,
    rng: StdRng,
}

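/// Variational circuit plus parameters that estimate Q-values for state-action pairs.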
pub struct QuantumValueNetwork {
    circuit: QuantumValueCircuit,
    parameters: Array1<f64>,
    optimizer: VariationalOptimizer,
}

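/// Variational circuit plus parameters that produce action probabilities from a state.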
pub struct QuantumPolicyNetwork {
    circuit: QuantumPolicyCircuit,
    parameters: Array1<f64>,
    optimizer: VariationalOptimizer,
}

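/// Layout of the value-estimation circuit: state-encoding qubits plus value qubits.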
#[derive(Debug, Clone)]
pub struct QuantumValueCircuit {
    state_qubits: usize,
    value_qubits: usize,
    depth: usize,
    total_qubits: usize,
}

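/// Layout of the policy circuit: state-encoding qubits plus action-readout qubits.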
#[derive(Debug, Clone)]
pub struct QuantumPolicyCircuit {
    state_qubits: usize,
    action_qubits: usize,
    depth: usize,
    total_qubits: usize,
}

impl QuantumDQN {
    pub fn new(config: QuantumRLConfig) -> QuantRS2Result<Self> {
        let rng = match config.random_seed {
            Some(seed) => StdRng::seed_from_u64(seed),
            // No seed provided: fall back to a fixed seed for reproducibility.
            None => StdRng::from_seed([0; 32]),
        };

        let q_network = QuantumValueNetwork::new(&config)?;
        let mut target_q_network = QuantumValueNetwork::new(&config)?;

        // Start the target network with the same parameters as the Q-network.
        target_q_network.parameters = q_network.parameters.clone();

        let policy_network = QuantumPolicyNetwork::new(&config)?;

        let replay_buffer = ReplayBuffer::new(config.replay_buffer_size, config.random_seed);

        Ok(Self {
            config: config.clone(),
            q_network,
            target_q_network,
            policy_network,
            replay_buffer,
            training_steps: 0,
            episodes: 0,
            current_exploration_rate: config.exploration_rate,
            rng,
        })
    }

    /// Epsilon-greedy action selection: explore with probability
    /// `current_exploration_rate`, otherwise query the policy network.
    pub fn select_action(&mut self, state: &Array1<f64>) -> QuantRS2Result<usize> {
        if self.rng.random::<f64>() < self.current_exploration_rate {
            let num_actions = 1 << self.config.action_qubits;
            Ok(self.rng.random_range(0..num_actions))
        } else {
            self.policy_network.get_best_action(state)
        }
    }

    pub fn store_experience(&mut self, experience: Experience) {
        self.replay_buffer.add(experience);
    }

    pub fn train(&mut self) -> QuantRS2Result<TrainingMetrics> {
        if !self.replay_buffer.can_sample(self.config.batch_size) {
            return Ok(TrainingMetrics::default());
        }

        // Sample a mini-batch of past experiences and unpack it into arrays.
        let experiences = self.replay_buffer.sample(self.config.batch_size);

        let (states, actions, rewards, next_states, dones) =
            self.prepare_training_data(&experiences);

        // Bellman targets computed from the target network.
        let target_q_values = self.compute_target_q_values(&next_states, &rewards, &dones)?;

        let q_loss = self.train_q_network(&states, &actions, &target_q_values)?;

        let policy_loss = self.train_policy_network(&states)?;

        // Synchronize the target network every 100 training steps.
        if self.training_steps % 100 == 0 {
            self.update_target_network();
        }

        self.update_exploration_rate();

        self.training_steps += 1;

        Ok(TrainingMetrics {
            q_loss,
            policy_loss,
            exploration_rate: self.current_exploration_rate,
            training_steps: self.training_steps,
        })
    }

    fn update_target_network(&mut self) {
        self.target_q_network.parameters = self.q_network.parameters.clone();
    }

    fn update_exploration_rate(&mut self) {
        self.current_exploration_rate = (self.current_exploration_rate
            * self.config.exploration_decay)
            .max(self.config.min_exploration_rate);
    }

    fn prepare_training_data(
        &self,
        experiences: &[Experience],
    ) -> (
        Array2<f64>,
        Array1<usize>,
        Array1<f64>,
        Array2<f64>,
        Array1<bool>,
    ) {
        let batch_size = experiences.len();
        let state_dim = experiences[0].state.len();

        let mut states = Array2::zeros((batch_size, state_dim));
        let mut actions = Array1::zeros(batch_size);
        let mut rewards = Array1::zeros(batch_size);
        let mut next_states = Array2::zeros((batch_size, state_dim));
        let mut dones = Array1::from_elem(batch_size, false);

        for (i, exp) in experiences.iter().enumerate() {
            states.row_mut(i).assign(&exp.state);
            actions[i] = exp.action;
            rewards[i] = exp.reward;
            next_states.row_mut(i).assign(&exp.next_state);
            dones[i] = exp.done;
        }

        (states, actions, rewards, next_states, dones)
    }

    fn compute_target_q_values(
        &self,
        next_states: &Array2<f64>,
        rewards: &Array1<f64>,
        dones: &Array1<bool>,
    ) -> QuantRS2Result<Array1<f64>> {
        let batch_size = next_states.nrows();
        let mut target_q_values = Array1::zeros(batch_size);

        for i in 0..batch_size {
            if dones[i] {
                target_q_values[i] = rewards[i];
            } else {
                let next_state = next_states.row(i).to_owned();
                let max_next_q = self.target_q_network.get_max_q_value(&next_state)?;
                target_q_values[i] = rewards[i] + self.config.discount_factor * max_next_q;
            }
        }

        Ok(target_q_values)
    }

    fn train_q_network(
        &mut self,
        states: &Array2<f64>,
        actions: &Array1<usize>,
        target_q_values: &Array1<f64>,
    ) -> QuantRS2Result<f64> {
        let batch_size = states.nrows();
        let mut total_loss = 0.0;

        for i in 0..batch_size {
            let state = states.row(i).to_owned();
            let action = actions[i];
            let target = target_q_values[i];

            let current_q = self.q_network.get_q_value(&state, action)?;

            let loss = (current_q - target).powi(2);
            total_loss += loss;

            let gradients = self.q_network.compute_gradients(&state, action, target)?;
            self.q_network
                .update_parameters(&gradients, self.config.learning_rate)?;
        }

        Ok(total_loss / batch_size as f64)
    }

    fn train_policy_network(&mut self, states: &Array2<f64>) -> QuantRS2Result<f64> {
        let batch_size = states.nrows();
        let mut total_loss = 0.0;

        for i in 0..batch_size {
            let state = states.row(i).to_owned();

            let policy_loss = self
                .policy_network
                .compute_policy_loss(&state, &self.q_network)?;
            total_loss += policy_loss;

            let gradients = self
                .policy_network
                .compute_policy_gradients(&state, &self.q_network)?;
            self.policy_network
                .update_parameters(&gradients, self.config.learning_rate)?;
        }

        Ok(total_loss / batch_size as f64)
    }

    pub fn end_episode(&mut self, _total_reward: f64) {
        self.episodes += 1;
    }

    pub fn get_statistics(&self) -> QLearningStats {
        QLearningStats {
            episodes: self.episodes,
            training_steps: self.training_steps,
            exploration_rate: self.current_exploration_rate,
            replay_buffer_size: self.replay_buffer.size(),
        }
    }
}

impl QuantumValueNetwork {
    fn new(config: &QuantumRLConfig) -> QuantRS2Result<Self> {
        let circuit = QuantumValueCircuit::new(
            config.state_qubits,
            config.value_qubits,
            config.circuit_depth,
        )?;

        let num_parameters = circuit.get_parameter_count();
        let mut parameters = Array1::zeros(num_parameters);

        let mut rng = match config.random_seed {
            Some(seed) => StdRng::seed_from_u64(seed),
            None => StdRng::from_seed([0; 32]),
        };

        for param in parameters.iter_mut() {
            *param = rng.random_range(-std::f64::consts::PI..std::f64::consts::PI);
        }

        let optimizer = VariationalOptimizer::new(0.01, 0.9);

        Ok(Self {
            circuit,
            parameters,
            optimizer,
        })
    }

    fn get_q_value(&self, state: &Array1<f64>, action: usize) -> QuantRS2Result<f64> {
        self.circuit
            .evaluate_q_value(state, action, &self.parameters)
    }

    fn get_max_q_value(&self, state: &Array1<f64>) -> QuantRS2Result<f64> {
        let num_actions = 1 << self.circuit.get_action_qubits();
        let mut max_q = f64::NEG_INFINITY;

        for action in 0..num_actions {
            let q_value = self.get_q_value(state, action)?;
            max_q = max_q.max(q_value);
        }

        Ok(max_q)
    }

    fn compute_gradients(
        &self,
        state: &Array1<f64>,
        action: usize,
        target: f64,
    ) -> QuantRS2Result<Array1<f64>> {
        self.circuit
            .compute_parameter_gradients(state, action, target, &self.parameters)
    }

    fn update_parameters(
        &mut self,
        gradients: &Array1<f64>,
        learning_rate: f64,
    ) -> QuantRS2Result<()> {
        for (param, &grad) in self.parameters.iter_mut().zip(gradients.iter()) {
            *param -= learning_rate * grad;
        }
        Ok(())
    }
}

impl QuantumPolicyNetwork {
    fn new(config: &QuantumRLConfig) -> QuantRS2Result<Self> {
        let circuit = QuantumPolicyCircuit::new(
            config.state_qubits,
            config.action_qubits,
            config.circuit_depth,
        )?;

        let num_parameters = circuit.get_parameter_count();
        let mut parameters = Array1::zeros(num_parameters);

        let mut rng = match config.random_seed {
            Some(seed) => StdRng::seed_from_u64(seed),
            None => StdRng::from_seed([0; 32]),
        };

        for param in parameters.iter_mut() {
            *param = rng.random_range(-std::f64::consts::PI..std::f64::consts::PI);
        }

        let optimizer = VariationalOptimizer::new(0.01, 0.9);

        Ok(Self {
            circuit,
            parameters,
            optimizer,
        })
    }

    fn get_best_action(&self, state: &Array1<f64>) -> QuantRS2Result<usize> {
        self.circuit.get_best_action(state, &self.parameters)
    }

    fn compute_policy_loss(
        &self,
        state: &Array1<f64>,
        q_network: &QuantumValueNetwork,
    ) -> QuantRS2Result<f64> {
        let action_probs = self
            .circuit
            .get_action_probabilities(state, &self.parameters)?;
        let num_actions = action_probs.len();

        let mut expected_q = 0.0;
        for action in 0..num_actions {
            let q_value = q_network.get_q_value(state, action)?;
            expected_q += action_probs[action] * q_value;
        }

        Ok(-expected_q)
    }

    fn compute_policy_gradients(
        &self,
        state: &Array1<f64>,
        q_network: &QuantumValueNetwork,
    ) -> QuantRS2Result<Array1<f64>> {
        self.circuit
            .compute_policy_gradients(state, q_network, &self.parameters)
    }

    fn update_parameters(
        &mut self,
        gradients: &Array1<f64>,
        learning_rate: f64,
    ) -> QuantRS2Result<()> {
        for (param, &grad) in self.parameters.iter_mut().zip(gradients.iter()) {
            *param -= learning_rate * grad;
        }
        Ok(())
    }
}

impl QuantumValueCircuit {
    fn new(state_qubits: usize, value_qubits: usize, depth: usize) -> QuantRS2Result<Self> {
        let total_qubits = state_qubits + value_qubits;

        Ok(Self {
            state_qubits,
            value_qubits,
            depth,
            total_qubits,
        })
    }

    fn get_parameter_count(&self) -> usize {
        let rotations_per_layer = self.get_total_qubits() * 3;
        let entangling_per_layer = self.get_total_qubits();
        self.depth * (rotations_per_layer + entangling_per_layer)
    }

    fn get_total_qubits(&self) -> usize {
        self.state_qubits + self.value_qubits
    }

    fn get_action_qubits(&self) -> usize {
        // Assumes two action qubits, matching the default QuantumRLConfig.
        2
    }

    fn evaluate_q_value(
        &self,
        state: &Array1<f64>,
        action: usize,
        parameters: &Array1<f64>,
    ) -> QuantRS2Result<f64> {
        let mut gates = Vec::new();

        // Encode the classical state as Y-rotations on the state qubits.
        for i in 0..self.state_qubits {
            let state_value = if i < state.len() { state[i] } else { 0.0 };
            gates.push(Box::new(RotationY {
                target: QubitId(i as u32),
                theta: state_value * std::f64::consts::PI,
            }) as Box<dyn GateOp>);
        }

        // Encode the action in binary on the qubits after the state register.
        for i in 0..2 {
            if (action >> i) & 1 == 1 {
                gates.push(Box::new(PauliX {
                    target: QubitId((self.state_qubits + i) as u32),
                }) as Box<dyn GateOp>);
            }
        }

        // Variational layers: per-qubit X/Y/Z rotations followed by a CRZ entangling chain.
        let mut param_idx = 0;
        for _layer in 0..self.depth {
            for qubit in 0..self.get_total_qubits() {
                if param_idx + 2 < parameters.len() {
                    gates.push(Box::new(RotationX {
                        target: QubitId(qubit as u32),
                        theta: parameters[param_idx],
                    }) as Box<dyn GateOp>);
                    param_idx += 1;

                    gates.push(Box::new(RotationY {
                        target: QubitId(qubit as u32),
                        theta: parameters[param_idx],
                    }) as Box<dyn GateOp>);
                    param_idx += 1;

                    gates.push(Box::new(RotationZ {
                        target: QubitId(qubit as u32),
                        theta: parameters[param_idx],
                    }) as Box<dyn GateOp>);
                    param_idx += 1;
                }
            }

            for qubit in 0..self.get_total_qubits() - 1 {
                if param_idx < parameters.len() {
                    gates.push(Box::new(CRZ {
                        control: QubitId(qubit as u32),
                        target: QubitId((qubit + 1) as u32),
                        theta: parameters[param_idx],
                    }) as Box<dyn GateOp>);
                    param_idx += 1;
                }
            }
        }

        let q_value = self.simulate_circuit_expectation(&gates)?;

        Ok(q_value)
    }

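    /// Placeholder expectation value: hashes the gate matrices into a value in
    /// [-1.0, 1.0] instead of running a full state-vector simulation.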
    fn simulate_circuit_expectation(&self, gates: &[Box<dyn GateOp>]) -> QuantRS2Result<f64> {
        let mut hash_value = 0u64;

        for gate in gates {
            if let Ok(matrix) = gate.matrix() {
                for complex in &matrix {
                    hash_value = hash_value.wrapping_add((complex.re * 1000.0) as u64);
                    hash_value = hash_value.wrapping_add((complex.im * 1000.0) as u64);
                }
            }
        }

        let expectation = (hash_value % 2000) as f64 / 1000.0 - 1.0;
        Ok(expectation)
    }

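    /// Gradients via symmetric parameter shifts of pi/2, scaled by the chain-rule
    /// factor of the squared-error loss against `target`.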
    fn compute_parameter_gradients(
        &self,
        state: &Array1<f64>,
        action: usize,
        target: f64,
        parameters: &Array1<f64>,
    ) -> QuantRS2Result<Array1<f64>> {
        let mut gradients = Array1::zeros(parameters.len());
        let shift = std::f64::consts::PI / 2.0;

        for i in 0..parameters.len() {
            let mut params_plus = parameters.clone();
            params_plus[i] += shift;
            let q_plus = self.evaluate_q_value(state, action, &params_plus)?;

            let mut params_minus = parameters.clone();
            params_minus[i] -= shift;
            let q_minus = self.evaluate_q_value(state, action, &params_minus)?;

            let current_q = self.evaluate_q_value(state, action, parameters)?;
            let loss_gradient = 2.0 * (current_q - target);
            gradients[i] = loss_gradient * (q_plus - q_minus) / 2.0;
        }

        Ok(gradients)
    }
}

impl QuantumPolicyCircuit {
    fn new(state_qubits: usize, action_qubits: usize, depth: usize) -> QuantRS2Result<Self> {
        let total_qubits = state_qubits + action_qubits;

        Ok(Self {
            state_qubits,
            action_qubits,
            depth,
            total_qubits,
        })
    }

    fn get_parameter_count(&self) -> usize {
        let total_qubits = self.state_qubits + self.action_qubits;
        let rotations_per_layer = total_qubits * 3;
        let entangling_per_layer = total_qubits;
        self.depth * (rotations_per_layer + entangling_per_layer)
    }

    fn get_best_action(
        &self,
        state: &Array1<f64>,
        parameters: &Array1<f64>,
    ) -> QuantRS2Result<usize> {
        let action_probs = self.get_action_probabilities(state, parameters)?;

        let mut best_action = 0;
        let mut best_prob = action_probs[0];

        for (action, &prob) in action_probs.iter().enumerate() {
            if prob > best_prob {
                best_prob = prob;
                best_action = action;
            }
        }

        Ok(best_action)
    }

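    /// Heuristic action distribution: a uniform base probability perturbed by a
    /// sinusoid of the state and parameter sums, then renormalized. Stands in for
    /// measuring the action qubits of the policy circuit.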
    fn get_action_probabilities(
        &self,
        state: &Array1<f64>,
        parameters: &Array1<f64>,
    ) -> QuantRS2Result<Vec<f64>> {
        let num_actions = 1 << self.action_qubits;
        let mut probabilities = vec![0.0; num_actions];

        let base_prob = 1.0 / num_actions as f64;

        for action in 0..num_actions {
            let state_hash = state.iter().sum::<f64>();
            let param_hash = parameters.iter().take(10).sum::<f64>();
            let variation = 0.1 * ((state_hash + param_hash + action as f64).sin());

            probabilities[action] = base_prob + variation;
        }

        // Normalize to a valid probability distribution.
        let sum: f64 = probabilities.iter().sum();
        for prob in &mut probabilities {
            *prob /= sum;
        }

        Ok(probabilities)
    }

    fn compute_policy_gradients(
        &self,
        state: &Array1<f64>,
        q_network: &QuantumValueNetwork,
        parameters: &Array1<f64>,
    ) -> QuantRS2Result<Array1<f64>> {
        let mut gradients = Array1::zeros(parameters.len());
        let shift = std::f64::consts::PI / 2.0;

        for i in 0..parameters.len() {
            let mut params_plus = parameters.clone();
            params_plus[i] += shift;
            let loss_plus = self.compute_policy_loss_with_params(state, q_network, &params_plus)?;

            let mut params_minus = parameters.clone();
            params_minus[i] -= shift;
            let loss_minus =
                self.compute_policy_loss_with_params(state, q_network, &params_minus)?;

            gradients[i] = (loss_plus - loss_minus) / 2.0;
        }

        Ok(gradients)
    }

    fn compute_policy_loss_with_params(
        &self,
        state: &Array1<f64>,
        q_network: &QuantumValueNetwork,
        parameters: &Array1<f64>,
    ) -> QuantRS2Result<f64> {
        let action_probs = self.get_action_probabilities(state, parameters)?;
        let num_actions = action_probs.len();

        let mut expected_q = 0.0;
        for action in 0..num_actions {
            let q_value = q_network.get_q_value(state, action)?;
            expected_q += action_probs[action] * q_value;
        }

        Ok(-expected_q)
    }
}

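/// Losses and bookkeeping reported after each training step.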
#[derive(Debug, Clone, Default)]
pub struct TrainingMetrics {
    pub q_loss: f64,
    pub policy_loss: f64,
    pub exploration_rate: f64,
    pub training_steps: usize,
}

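/// Summary statistics of an agent's training so far.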
#[derive(Debug, Clone)]
pub struct QLearningStats {
    pub episodes: usize,
    pub training_steps: usize,
    pub exploration_rate: f64,
    pub replay_buffer_size: usize,
}

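/// Quantum actor-critic agent: a quantum policy network (actor) updated with
/// gradients scaled by the TD error of a quantum value network (critic).
///
/// Illustrative usage sketch (environment interaction is an assumption):
///
/// ```ignore
/// let mut agent = QuantumActorCritic::new(QuantumRLConfig::default())?;
/// let action = agent.select_action(&state)?;
/// agent.update(&state, action, reward, &next_state, done)?;
/// ```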
pub struct QuantumActorCritic {
    config: QuantumRLConfig,
    actor: QuantumPolicyNetwork,
    critic: QuantumValueNetwork,
    metrics: TrainingMetrics,
}

impl QuantumActorCritic {
    pub fn new(config: QuantumRLConfig) -> QuantRS2Result<Self> {
        let actor = QuantumPolicyNetwork::new(&config)?;
        let critic = QuantumValueNetwork::new(&config)?;

        Ok(Self {
            config,
            actor,
            critic,
            metrics: TrainingMetrics::default(),
        })
    }

    pub fn update(
        &mut self,
        state: &Array1<f64>,
        _action: usize,
        reward: f64,
        next_state: &Array1<f64>,
        done: bool,
    ) -> QuantRS2Result<()> {
        // Use the Q-value of action 0 as the critic's state-value estimate.
        let current_value = self.critic.get_q_value(state, 0)?;
        let next_value = if done {
            0.0
        } else {
            self.critic.get_max_q_value(next_state)?
        };

        let target_value = reward + self.config.discount_factor * next_value;
        let td_error = target_value - current_value;

        // Update the critic towards the TD target.
        let critic_gradients = self.critic.compute_gradients(state, 0, target_value)?;
        self.critic
            .update_parameters(&critic_gradients, self.config.learning_rate)?;

        // Update the actor with policy gradients scaled by the TD error.
        let actor_gradients = self.actor.compute_policy_gradients(state, &self.critic)?;
        let scaled_gradients = actor_gradients * td_error;
        self.actor
            .update_parameters(&scaled_gradients, self.config.learning_rate)?;

        self.metrics.q_loss = td_error.abs();
        self.metrics.policy_loss = -td_error;

        Ok(())
    }

    pub fn select_action(&self, state: &Array1<f64>) -> QuantRS2Result<usize> {
        self.actor.get_best_action(state)
    }

    pub fn get_metrics(&self) -> &TrainingMetrics {
        &self.metrics
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_quantum_dqn_creation() {
        let config = QuantumRLConfig::default();
        let agent = QuantumDQN::new(config).unwrap();

        let stats = agent.get_statistics();
        assert_eq!(stats.episodes, 0);
        assert_eq!(stats.training_steps, 0);
    }

    #[test]
    fn test_replay_buffer() {
        let mut buffer = ReplayBuffer::new(10, Some(42));

        let experience = Experience {
            state: Array1::from_vec(vec![1.0, 0.0, -1.0]),
            action: 1,
            reward: 1.0,
            next_state: Array1::from_vec(vec![0.0, 1.0, 0.0]),
            done: false,
        };

        buffer.add(experience);
        assert_eq!(buffer.size(), 1);

        let samples = buffer.sample(1);
        assert_eq!(samples.len(), 1);
    }

    #[test]
    fn test_quantum_value_circuit() {
        let circuit = QuantumValueCircuit::new(3, 2, 4).unwrap();
        let param_count = circuit.get_parameter_count();
        assert!(param_count > 0);

        let state = Array1::from_vec(vec![0.5, -0.5, 0.0]);
        let parameters = Array1::zeros(param_count);

        let q_value = circuit.evaluate_q_value(&state, 1, &parameters).unwrap();
        assert!(q_value.is_finite());
    }

    #[test]
    fn test_quantum_actor_critic() {
        let config = QuantumRLConfig::default();
        let mut agent = QuantumActorCritic::new(config).unwrap();

        let state = Array1::from_vec(vec![0.5, -0.5]);
        let next_state = Array1::from_vec(vec![0.0, 1.0]);

        let action = agent.select_action(&state).unwrap();
        // Two action qubits give four possible actions.
        assert!(action < 4);

        agent
            .update(&state, action, 1.0, &next_state, false)
            .unwrap();

        let metrics = agent.get_metrics();
        assert!(metrics.q_loss >= 0.0);
    }

    #[test]
    fn test_quantum_rl_config_default() {
        let config = QuantumRLConfig::default();
        assert_eq!(config.state_qubits, 4);
        assert_eq!(config.action_qubits, 2);
        assert!(config.learning_rate > 0.0);
        assert!(config.discount_factor < 1.0);
    }
}