use crate::{
    error::QuantRS2Result, gate::multi::*, gate::single::*, gate::GateOp, qubit::QubitId,
    variational::VariationalOptimizer,
};
use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::random::{rngs::StdRng, Rng, SeedableRng};
use serde::{Deserialize, Serialize};
use std::collections::VecDeque;

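/// Hyperparameters for the quantum reinforcement-learning agents in this module.
///
/// Example (sketch): override a few fields and keep the rest of the defaults.
///
/// ```ignore
/// let config = QuantumRLConfig {
///     state_qubits: 6,
///     random_seed: Some(42),
///     ..QuantumRLConfig::default()
/// };
/// ```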
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantumRLConfig {
    /// Number of qubits used to encode the environment state.
    pub state_qubits: usize,
    /// Number of qubits used to encode the action register (2^n discrete actions).
    pub action_qubits: usize,
    /// Number of qubits used by the value-estimation circuit.
    pub value_qubits: usize,
    /// Step size for parameter updates.
    pub learning_rate: f64,
    /// Discount factor applied to future rewards.
    pub discount_factor: f64,
    /// Initial epsilon for epsilon-greedy exploration.
    pub exploration_rate: f64,
    /// Multiplicative decay applied to the exploration rate after each training step.
    pub exploration_decay: f64,
    /// Lower bound on the exploration rate.
    pub min_exploration_rate: f64,
    /// Maximum number of experiences kept in the replay buffer.
    pub replay_buffer_size: usize,
    /// Number of experiences sampled per training step.
    pub batch_size: usize,
    /// Number of variational layers in each circuit.
    pub circuit_depth: usize,
    /// Flag that toggles quantum-advantage-oriented behavior.
    pub use_quantum_advantage: bool,
    /// Optional seed for reproducible runs.
    pub random_seed: Option<u64>,
}

impl Default for QuantumRLConfig {
    fn default() -> Self {
        Self {
            state_qubits: 4,
            action_qubits: 2,
            value_qubits: 3,
            learning_rate: 0.01,
            discount_factor: 0.99,
            exploration_rate: 1.0,
            exploration_decay: 0.995,
            min_exploration_rate: 0.01,
            replay_buffer_size: 10000,
            batch_size: 32,
            circuit_depth: 6,
            use_quantum_advantage: true,
            random_seed: None,
        }
    }
}

/// A single transition observed while interacting with the environment.
#[derive(Debug, Clone)]
pub struct Experience {
    pub state: Array1<f64>,
    pub action: usize,
    pub reward: f64,
    pub next_state: Array1<f64>,
    pub done: bool,
}

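/// Fixed-capacity experience replay buffer with uniform, with-replacement sampling.
///
/// Example (sketch; `experience` is any `Experience` value):
///
/// ```ignore
/// let mut buffer = ReplayBuffer::new(1_000, Some(7));
/// buffer.add(experience);
/// if buffer.can_sample(32) {
///     let batch = buffer.sample(32);
/// }
/// ```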
pub struct ReplayBuffer {
    buffer: VecDeque<Experience>,
    max_size: usize,
    rng: StdRng,
}

impl ReplayBuffer {
    pub fn new(max_size: usize, seed: Option<u64>) -> Self {
        // Unseeded buffers fall back to a fixed all-zero seed, keeping runs deterministic.
        let rng = match seed {
            Some(s) => StdRng::seed_from_u64(s),
            None => StdRng::from_seed([0; 32]),
        };

        Self {
            buffer: VecDeque::with_capacity(max_size),
            max_size,
            rng,
        }
    }

    /// Adds an experience, evicting the oldest entry once the buffer is full.
    pub fn add(&mut self, experience: Experience) {
        if self.buffer.len() >= self.max_size {
            self.buffer.pop_front();
        }
        self.buffer.push_back(experience);
    }

    /// Samples `batch_size` experiences uniformly with replacement. If the buffer
    /// holds fewer than `batch_size` entries, all of them are returned instead.
    pub fn sample(&mut self, batch_size: usize) -> Vec<Experience> {
        let mut samples = Vec::new();
        let buffer_size = self.buffer.len();

        if buffer_size < batch_size {
            return self.buffer.iter().cloned().collect();
        }

        for _ in 0..batch_size {
            let idx = self.rng.random_range(0..buffer_size);
            samples.push(self.buffer[idx].clone());
        }

        samples
    }

    pub fn size(&self) -> usize {
        self.buffer.len()
    }

    pub fn can_sample(&self, batch_size: usize) -> bool {
        self.buffer.len() >= batch_size
    }
}

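/// Quantum Deep Q-Network agent: an online Q-network, a periodically synchronized
/// target network, a quantum policy network, and an experience replay buffer.
///
/// Training-loop sketch (`env` is a hypothetical environment with `reset` and
/// `step`, not part of this crate):
///
/// ```ignore
/// let mut agent = QuantumDQN::new(QuantumRLConfig::default())?;
/// for _ in 0..num_episodes {
///     let mut state = env.reset();
///     let mut total_reward = 0.0;
///     loop {
///         let action = agent.select_action(&state)?;
///         let (next_state, reward, done) = env.step(action);
///         agent.store_experience(Experience {
///             state: state.clone(),
///             action,
///             reward,
///             next_state: next_state.clone(),
///             done,
///         });
///         let _metrics = agent.train()?;
///         total_reward += reward;
///         state = next_state;
///         if done {
///             agent.end_episode(total_reward);
///             break;
///         }
///     }
/// }
/// ```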
pub struct QuantumDQN {
    config: QuantumRLConfig,
    /// Online Q-network being trained.
    q_network: QuantumValueNetwork,
    /// Periodically synchronized copy used to compute stable targets.
    target_q_network: QuantumValueNetwork,
    /// Policy network used for greedy action selection.
    policy_network: QuantumPolicyNetwork,
    replay_buffer: ReplayBuffer,
    training_steps: usize,
    episodes: usize,
    current_exploration_rate: f64,
    rng: StdRng,
}

/// Variational network that estimates Q-values for state-action pairs.
pub struct QuantumValueNetwork {
    circuit: QuantumValueCircuit,
    parameters: Array1<f64>,
    optimizer: VariationalOptimizer,
}

/// Variational network that produces a probability distribution over actions.
pub struct QuantumPolicyNetwork {
    circuit: QuantumPolicyCircuit,
    parameters: Array1<f64>,
    optimizer: VariationalOptimizer,
}

#[derive(Debug, Clone)]
pub struct QuantumValueCircuit {
    state_qubits: usize,
    value_qubits: usize,
    depth: usize,
    total_qubits: usize,
}

#[derive(Debug, Clone)]
pub struct QuantumPolicyCircuit {
    state_qubits: usize,
    action_qubits: usize,
    depth: usize,
    total_qubits: usize,
}

impl QuantumDQN {
    pub fn new(config: QuantumRLConfig) -> QuantRS2Result<Self> {
        // Unseeded agents fall back to a fixed all-zero seed, keeping runs deterministic.
        let rng = match config.random_seed {
            Some(seed) => StdRng::seed_from_u64(seed),
            None => StdRng::from_seed([0; 32]),
        };

        let q_network = QuantumValueNetwork::new(&config)?;
        let mut target_q_network = QuantumValueNetwork::new(&config)?;

        // Start the target network with the same parameters as the online network.
        target_q_network
            .parameters
            .clone_from(&q_network.parameters);

        let policy_network = QuantumPolicyNetwork::new(&config)?;

        let replay_buffer = ReplayBuffer::new(config.replay_buffer_size, config.random_seed);

        Ok(Self {
            config: config.clone(),
            q_network,
            target_q_network,
            policy_network,
            replay_buffer,
            training_steps: 0,
            episodes: 0,
            current_exploration_rate: config.exploration_rate,
            rng,
        })
    }

    /// Epsilon-greedy action selection: explore with probability
    /// `current_exploration_rate`, otherwise follow the policy network.
    pub fn select_action(&mut self, state: &Array1<f64>) -> QuantRS2Result<usize> {
        if self.rng.random::<f64>() < self.current_exploration_rate {
            let num_actions = 1 << self.config.action_qubits;
            Ok(self.rng.random_range(0..num_actions))
        } else {
            self.policy_network.get_best_action(state)
        }
    }

    pub fn store_experience(&mut self, experience: Experience) {
        self.replay_buffer.add(experience);
    }

    /// Runs one training step over a sampled mini-batch. Returns default
    /// metrics if the replay buffer cannot yet provide a full batch.
    pub fn train(&mut self) -> QuantRS2Result<TrainingMetrics> {
        if !self.replay_buffer.can_sample(self.config.batch_size) {
            return Ok(TrainingMetrics::default());
        }

        let experiences = self.replay_buffer.sample(self.config.batch_size);

        let (states, actions, rewards, next_states, dones) =
            self.prepare_training_data(&experiences);

        let target_q_values = self.compute_target_q_values(&next_states, &rewards, &dones)?;

        let q_loss = self.train_q_network(&states, &actions, &target_q_values)?;

        let policy_loss = self.train_policy_network(&states)?;

        // Periodically synchronize the target network with the online network.
        if self.training_steps % 100 == 0 {
            self.update_target_network();
        }

        self.update_exploration_rate();

        self.training_steps += 1;

        Ok(TrainingMetrics {
            q_loss,
            policy_loss,
            exploration_rate: self.current_exploration_rate,
            training_steps: self.training_steps,
        })
    }

    fn update_target_network(&mut self) {
        self.target_q_network.parameters = self.q_network.parameters.clone();
    }

    fn update_exploration_rate(&mut self) {
        self.current_exploration_rate = (self.current_exploration_rate
            * self.config.exploration_decay)
            .max(self.config.min_exploration_rate);
    }

    fn prepare_training_data(
        &self,
        experiences: &[Experience],
    ) -> (
        Array2<f64>,
        Array1<usize>,
        Array1<f64>,
        Array2<f64>,
        Array1<bool>,
    ) {
        let batch_size = experiences.len();
        let state_dim = experiences[0].state.len();

        let mut states = Array2::zeros((batch_size, state_dim));
        let mut actions = Array1::zeros(batch_size);
        let mut rewards = Array1::zeros(batch_size);
        let mut next_states = Array2::zeros((batch_size, state_dim));
        let mut dones = Array1::from_elem(batch_size, false);

        for (i, exp) in experiences.iter().enumerate() {
            states.row_mut(i).assign(&exp.state);
            actions[i] = exp.action;
            rewards[i] = exp.reward;
            next_states.row_mut(i).assign(&exp.next_state);
            dones[i] = exp.done;
        }

        (states, actions, rewards, next_states, dones)
    }

    /// Bellman targets: y_i = r_i for terminal transitions, otherwise
    /// y_i = r_i + gamma * max_a Q_target(s'_i, a).
    fn compute_target_q_values(
        &self,
        next_states: &Array2<f64>,
        rewards: &Array1<f64>,
        dones: &Array1<bool>,
    ) -> QuantRS2Result<Array1<f64>> {
        let batch_size = next_states.nrows();
        let mut target_q_values = Array1::zeros(batch_size);

        for i in 0..batch_size {
            if dones[i] {
                target_q_values[i] = rewards[i];
            } else {
                let next_state = next_states.row(i).to_owned();
                let max_next_q = self.target_q_network.get_max_q_value(&next_state)?;
                target_q_values[i] = self.config.discount_factor.mul_add(max_next_q, rewards[i]);
            }
        }

        Ok(target_q_values)
    }

    fn train_q_network(
        &mut self,
        states: &Array2<f64>,
        actions: &Array1<usize>,
        target_q_values: &Array1<f64>,
    ) -> QuantRS2Result<f64> {
        let batch_size = states.nrows();
        let mut total_loss = 0.0;

        for i in 0..batch_size {
            let state = states.row(i).to_owned();
            let action = actions[i];
            let target = target_q_values[i];

            let current_q = self.q_network.get_q_value(&state, action)?;

            // Squared TD error for this sample.
            let loss = (current_q - target).powi(2);
            total_loss += loss;

            let gradients = self.q_network.compute_gradients(&state, action, target)?;
            self.q_network
                .update_parameters(&gradients, self.config.learning_rate)?;
        }

        Ok(total_loss / batch_size as f64)
    }

    fn train_policy_network(&mut self, states: &Array2<f64>) -> QuantRS2Result<f64> {
        let batch_size = states.nrows();
        let mut total_loss = 0.0;

        for i in 0..batch_size {
            let state = states.row(i).to_owned();

            let policy_loss = self
                .policy_network
                .compute_policy_loss(&state, &self.q_network)?;
            total_loss += policy_loss;

            let gradients = self
                .policy_network
                .compute_policy_gradients(&state, &self.q_network)?;
            self.policy_network
                .update_parameters(&gradients, self.config.learning_rate)?;
        }

        Ok(total_loss / batch_size as f64)
    }

    pub const fn end_episode(&mut self, _total_reward: f64) {
        self.episodes += 1;
    }

    pub fn get_statistics(&self) -> QLearningStats {
        QLearningStats {
            episodes: self.episodes,
            training_steps: self.training_steps,
            exploration_rate: self.current_exploration_rate,
            replay_buffer_size: self.replay_buffer.size(),
        }
    }
}

impl QuantumValueNetwork {
    fn new(config: &QuantumRLConfig) -> QuantRS2Result<Self> {
        let circuit = QuantumValueCircuit::new(
            config.state_qubits,
            config.value_qubits,
            config.circuit_depth,
        )?;

        let num_parameters = circuit.get_parameter_count();
        let mut parameters = Array1::zeros(num_parameters);

        let mut rng = match config.random_seed {
            Some(seed) => StdRng::seed_from_u64(seed),
            None => StdRng::from_seed([0; 32]),
        };

        // Initialize variational parameters uniformly in [-pi, pi).
        for param in &mut parameters {
            *param = rng.random_range(-std::f64::consts::PI..std::f64::consts::PI);
        }

        let optimizer = VariationalOptimizer::new(0.01, 0.9);

        Ok(Self {
            circuit,
            parameters,
            optimizer,
        })
    }

    fn get_q_value(&self, state: &Array1<f64>, action: usize) -> QuantRS2Result<f64> {
        self.circuit
            .evaluate_q_value(state, action, &self.parameters)
    }

    fn get_max_q_value(&self, state: &Array1<f64>) -> QuantRS2Result<f64> {
        let num_actions = 1 << self.circuit.get_action_qubits();
        let mut max_q = f64::NEG_INFINITY;

        for action in 0..num_actions {
            let q_value = self.get_q_value(state, action)?;
            max_q = max_q.max(q_value);
        }

        Ok(max_q)
    }

    fn compute_gradients(
        &self,
        state: &Array1<f64>,
        action: usize,
        target: f64,
    ) -> QuantRS2Result<Array1<f64>> {
        self.circuit
            .compute_parameter_gradients(state, action, target, &self.parameters)
    }

    fn update_parameters(
        &mut self,
        gradients: &Array1<f64>,
        learning_rate: f64,
    ) -> QuantRS2Result<()> {
        // Plain gradient-descent step on each variational parameter.
        for (param, &grad) in self.parameters.iter_mut().zip(gradients.iter()) {
            *param -= learning_rate * grad;
        }
        Ok(())
    }
}

impl QuantumPolicyNetwork {
    fn new(config: &QuantumRLConfig) -> QuantRS2Result<Self> {
        let circuit = QuantumPolicyCircuit::new(
            config.state_qubits,
            config.action_qubits,
            config.circuit_depth,
        )?;

        let num_parameters = circuit.get_parameter_count();
        let mut parameters = Array1::zeros(num_parameters);

        let mut rng = match config.random_seed {
            Some(seed) => StdRng::seed_from_u64(seed),
            None => StdRng::from_seed([0; 32]),
        };

        // Initialize variational parameters uniformly in [-pi, pi).
        for param in &mut parameters {
            *param = rng.random_range(-std::f64::consts::PI..std::f64::consts::PI);
        }

        let optimizer = VariationalOptimizer::new(0.01, 0.9);

        Ok(Self {
            circuit,
            parameters,
            optimizer,
        })
    }

    fn get_best_action(&self, state: &Array1<f64>) -> QuantRS2Result<usize> {
        self.circuit.get_best_action(state, &self.parameters)
    }

    /// The policy loss is the negative expected Q-value under the policy's
    /// action distribution, so minimizing it maximizes expected value.
    fn compute_policy_loss(
        &self,
        state: &Array1<f64>,
        q_network: &QuantumValueNetwork,
    ) -> QuantRS2Result<f64> {
        let action_probs = self
            .circuit
            .get_action_probabilities(state, &self.parameters)?;
        let num_actions = action_probs.len();

        let mut expected_q = 0.0;
        for action in 0..num_actions {
            let q_value = q_network.get_q_value(state, action)?;
            expected_q += action_probs[action] * q_value;
        }

        Ok(-expected_q)
    }

    fn compute_policy_gradients(
        &self,
        state: &Array1<f64>,
        q_network: &QuantumValueNetwork,
    ) -> QuantRS2Result<Array1<f64>> {
        self.circuit
            .compute_policy_gradients(state, q_network, &self.parameters)
    }

    fn update_parameters(
        &mut self,
        gradients: &Array1<f64>,
        learning_rate: f64,
    ) -> QuantRS2Result<()> {
        for (param, &grad) in self.parameters.iter_mut().zip(gradients.iter()) {
            *param -= learning_rate * grad;
        }
        Ok(())
    }
}

impl QuantumValueCircuit {
    const fn new(state_qubits: usize, value_qubits: usize, depth: usize) -> QuantRS2Result<Self> {
        let total_qubits = state_qubits + value_qubits;

        Ok(Self {
            state_qubits,
            value_qubits,
            depth,
            total_qubits,
        })
    }

    const fn get_parameter_count(&self) -> usize {
        // Three single-qubit rotations plus one entangling parameter per qubit per layer.
        let rotations_per_layer = self.get_total_qubits() * 3;
        let entangling_per_layer = self.get_total_qubits();
        self.depth * (rotations_per_layer + entangling_per_layer)
    }

    const fn get_total_qubits(&self) -> usize {
        self.state_qubits + self.value_qubits
    }

    const fn get_action_qubits(&self) -> usize {
        // The action register is fixed at two qubits (four discrete actions).
        2
    }

    fn evaluate_q_value(
        &self,
        state: &Array1<f64>,
        action: usize,
        parameters: &Array1<f64>,
    ) -> QuantRS2Result<f64> {
        let mut gates = Vec::new();

        // Encode the classical state via Y-rotations on the state register.
        for i in 0..self.state_qubits {
            let state_value = if i < state.len() { state[i] } else { 0.0 };
            gates.push(Box::new(RotationY {
                target: QubitId(i as u32),
                theta: state_value * std::f64::consts::PI,
            }) as Box<dyn GateOp>);
        }

        // Encode the action bits as X gates on the qubits following the state register.
        for i in 0..2 {
            if (action >> i) & 1 == 1 {
                gates.push(Box::new(PauliX {
                    target: QubitId((self.state_qubits + i) as u32),
                }) as Box<dyn GateOp>);
            }
        }

        // Variational layers: parameterized single-qubit rotations followed by a CRZ chain.
        let mut param_idx = 0;
        for _layer in 0..self.depth {
            for qubit in 0..self.get_total_qubits() {
                if param_idx + 2 < parameters.len() {
                    gates.push(Box::new(RotationX {
                        target: QubitId(qubit as u32),
                        theta: parameters[param_idx],
                    }) as Box<dyn GateOp>);
                    param_idx += 1;

                    gates.push(Box::new(RotationY {
                        target: QubitId(qubit as u32),
                        theta: parameters[param_idx],
                    }) as Box<dyn GateOp>);
                    param_idx += 1;

                    gates.push(Box::new(RotationZ {
                        target: QubitId(qubit as u32),
                        theta: parameters[param_idx],
                    }) as Box<dyn GateOp>);
                    param_idx += 1;
                }
            }

            for qubit in 0..self.get_total_qubits() - 1 {
                if param_idx < parameters.len() {
                    gates.push(Box::new(CRZ {
                        control: QubitId(qubit as u32),
                        target: QubitId((qubit + 1) as u32),
                        theta: parameters[param_idx],
                    }) as Box<dyn GateOp>);
                    param_idx += 1;
                }
            }
        }

        let q_value = self.simulate_circuit_expectation(&gates)?;

        Ok(q_value)
    }

    /// Placeholder expectation value: hashes the gate matrices into a value in
    /// [-1, 1) instead of running a full statevector simulation.
    fn simulate_circuit_expectation(&self, gates: &[Box<dyn GateOp>]) -> QuantRS2Result<f64> {
        let mut hash_value = 0u64;

        for gate in gates {
            if let Ok(matrix) = gate.matrix() {
                for complex in &matrix {
                    hash_value = hash_value.wrapping_add((complex.re * 1000.0) as u64);
                    hash_value = hash_value.wrapping_add((complex.im * 1000.0) as u64);
                }
            }
        }

        let expectation = (hash_value % 2000) as f64 / 1000.0 - 1.0;
        Ok(expectation)
    }

    /// Parameter-shift-style gradients of the squared TD error: for each
    /// parameter, dL/dθ ≈ 2 (Q - target) · (Q(θ + π/2) - Q(θ - π/2)) / 2.
    fn compute_parameter_gradients(
        &self,
        state: &Array1<f64>,
        action: usize,
        target: f64,
        parameters: &Array1<f64>,
    ) -> QuantRS2Result<Array1<f64>> {
        let mut gradients = Array1::zeros(parameters.len());
        let shift = std::f64::consts::PI / 2.0;

        for i in 0..parameters.len() {
            let mut params_plus = parameters.clone();
            params_plus[i] += shift;
            let q_plus = self.evaluate_q_value(state, action, &params_plus)?;

            let mut params_minus = parameters.clone();
            params_minus[i] -= shift;
            let q_minus = self.evaluate_q_value(state, action, &params_minus)?;

            let current_q = self.evaluate_q_value(state, action, parameters)?;
            // Chain rule: outer loss gradient times the shifted-circuit estimate of dQ/dθ.
            let loss_gradient = 2.0 * (current_q - target);
            gradients[i] = loss_gradient * (q_plus - q_minus) / 2.0;
        }

        Ok(gradients)
    }
}

impl QuantumPolicyCircuit {
    const fn new(state_qubits: usize, action_qubits: usize, depth: usize) -> QuantRS2Result<Self> {
        let total_qubits = state_qubits + action_qubits;

        Ok(Self {
            state_qubits,
            action_qubits,
            depth,
            total_qubits,
        })
    }

    const fn get_parameter_count(&self) -> usize {
        let total_qubits = self.state_qubits + self.action_qubits;
        let rotations_per_layer = total_qubits * 3;
        let entangling_per_layer = total_qubits;
        self.depth * (rotations_per_layer + entangling_per_layer)
    }

    fn get_best_action(
        &self,
        state: &Array1<f64>,
        parameters: &Array1<f64>,
    ) -> QuantRS2Result<usize> {
        let action_probs = self.get_action_probabilities(state, parameters)?;

        // Greedy choice: pick the action with the highest probability.
        let mut best_action = 0;
        let mut best_prob = action_probs[0];

        for (action, &prob) in action_probs.iter().enumerate() {
            if prob > best_prob {
                best_prob = prob;
                best_action = action;
            }
        }

        Ok(best_action)
    }

    /// Placeholder distribution: perturbs a uniform distribution with a
    /// deterministic function of the state and parameters, then renormalizes.
    fn get_action_probabilities(
        &self,
        state: &Array1<f64>,
        parameters: &Array1<f64>,
    ) -> QuantRS2Result<Vec<f64>> {
        let num_actions = 1 << self.action_qubits;
        let mut probabilities = vec![0.0; num_actions];

        let base_prob = 1.0 / num_actions as f64;

        for action in 0..num_actions {
            let state_hash = state.iter().sum::<f64>();
            let param_hash = parameters.iter().take(10).sum::<f64>();
            let variation = 0.1 * ((state_hash + param_hash + action as f64).sin());

            probabilities[action] = base_prob + variation;
        }

        // Renormalize so the probabilities sum to one.
        let sum: f64 = probabilities.iter().sum();
        for prob in &mut probabilities {
            *prob /= sum;
        }

        Ok(probabilities)
    }

    /// Central-difference gradients of the policy loss with a π/2 shift per parameter.
    fn compute_policy_gradients(
        &self,
        state: &Array1<f64>,
        q_network: &QuantumValueNetwork,
        parameters: &Array1<f64>,
    ) -> QuantRS2Result<Array1<f64>> {
        let mut gradients = Array1::zeros(parameters.len());
        let shift = std::f64::consts::PI / 2.0;

        for i in 0..parameters.len() {
            let mut params_plus = parameters.clone();
            params_plus[i] += shift;
            let loss_plus = self.compute_policy_loss_with_params(state, q_network, &params_plus)?;

            let mut params_minus = parameters.clone();
            params_minus[i] -= shift;
            let loss_minus =
                self.compute_policy_loss_with_params(state, q_network, &params_minus)?;

            gradients[i] = (loss_plus - loss_minus) / 2.0;
        }

        Ok(gradients)
    }

    fn compute_policy_loss_with_params(
        &self,
        state: &Array1<f64>,
        q_network: &QuantumValueNetwork,
        parameters: &Array1<f64>,
    ) -> QuantRS2Result<f64> {
        let action_probs = self.get_action_probabilities(state, parameters)?;
        let num_actions = action_probs.len();

        let mut expected_q = 0.0;
        for action in 0..num_actions {
            let q_value = q_network.get_q_value(state, action)?;
            expected_q += action_probs[action] * q_value;
        }

        Ok(-expected_q)
    }
}

#[derive(Debug, Clone, Default)]
pub struct TrainingMetrics {
    pub q_loss: f64,
    pub policy_loss: f64,
    pub exploration_rate: f64,
    pub training_steps: usize,
}

#[derive(Debug, Clone)]
pub struct QLearningStats {
    pub episodes: usize,
    pub training_steps: usize,
    pub exploration_rate: f64,
    pub replay_buffer_size: usize,
}

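/// Quantum actor-critic agent: the critic's TD error scales the actor's
/// policy-gradient updates.
///
/// Update-loop sketch (`env` is a hypothetical environment, not part of this crate):
///
/// ```ignore
/// let mut agent = QuantumActorCritic::new(QuantumRLConfig::default())?;
/// let mut state = env.reset();
/// loop {
///     let action = agent.select_action(&state)?;
///     let (next_state, reward, done) = env.step(action);
///     agent.update(&state, action, reward, &next_state, done)?;
///     state = next_state;
///     if done {
///         break;
///     }
/// }
/// ```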
pub struct QuantumActorCritic {
    config: QuantumRLConfig,
    actor: QuantumPolicyNetwork,
    critic: QuantumValueNetwork,
    metrics: TrainingMetrics,
}

impl QuantumActorCritic {
    pub fn new(config: QuantumRLConfig) -> QuantRS2Result<Self> {
        let actor = QuantumPolicyNetwork::new(&config)?;
        let critic = QuantumValueNetwork::new(&config)?;

        Ok(Self {
            config,
            actor,
            critic,
            metrics: TrainingMetrics::default(),
        })
    }

    pub fn update(
        &mut self,
        state: &Array1<f64>,
        _action: usize,
        reward: f64,
        next_state: &Array1<f64>,
        done: bool,
    ) -> QuantRS2Result<()> {
        // The critic's Q-value for action 0 serves as the state-value estimate here.
        let current_value = self.critic.get_q_value(state, 0)?;
        let next_value = if done {
            0.0
        } else {
            self.critic.get_max_q_value(next_state)?
        };

        let target_value = self.config.discount_factor.mul_add(next_value, reward);
        let td_error = target_value - current_value;

        // Critic update: move the value estimate toward the TD target.
        let critic_gradients = self.critic.compute_gradients(state, 0, target_value)?;
        self.critic
            .update_parameters(&critic_gradients, self.config.learning_rate)?;

        // Actor update: scale the policy gradients by the TD error (advantage estimate).
        let actor_gradients = self.actor.compute_policy_gradients(state, &self.critic)?;
        let scaled_gradients = actor_gradients * td_error;
        self.actor
            .update_parameters(&scaled_gradients, self.config.learning_rate)?;

        self.metrics.q_loss = td_error.abs();
        self.metrics.policy_loss = -td_error;

        Ok(())
    }

    pub fn select_action(&self, state: &Array1<f64>) -> QuantRS2Result<usize> {
        self.actor.get_best_action(state)
    }

    pub const fn get_metrics(&self) -> &TrainingMetrics {
        &self.metrics
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_quantum_dqn_creation() {
        let config = QuantumRLConfig::default();
        let agent = QuantumDQN::new(config).expect("Failed to create QuantumDQN agent");

        let stats = agent.get_statistics();
        assert_eq!(stats.episodes, 0);
        assert_eq!(stats.training_steps, 0);
    }

    #[test]
    fn test_replay_buffer() {
        let mut buffer = ReplayBuffer::new(10, Some(42));

        let experience = Experience {
            state: Array1::from_vec(vec![1.0, 0.0, -1.0]),
            action: 1,
            reward: 1.0,
            next_state: Array1::from_vec(vec![0.0, 1.0, 0.0]),
            done: false,
        };

        buffer.add(experience);
        assert_eq!(buffer.size(), 1);

        let samples = buffer.sample(1);
        assert_eq!(samples.len(), 1);
    }

    #[test]
    fn test_quantum_value_circuit() {
        let circuit =
            QuantumValueCircuit::new(3, 2, 4).expect("Failed to create QuantumValueCircuit");
        let param_count = circuit.get_parameter_count();
        assert!(param_count > 0);

        let state = Array1::from_vec(vec![0.5, -0.5, 0.0]);
        let parameters = Array1::zeros(param_count);

        let q_value = circuit
            .evaluate_q_value(&state, 1, &parameters)
            .expect("Failed to evaluate Q-value");
        assert!(q_value.is_finite());
    }

    #[test]
    fn test_quantum_actor_critic() {
        let config = QuantumRLConfig::default();
        let mut agent =
            QuantumActorCritic::new(config).expect("Failed to create QuantumActorCritic agent");

        let state = Array1::from_vec(vec![0.5, -0.5]);
        let next_state = Array1::from_vec(vec![0.0, 1.0]);

        let action = agent
            .select_action(&state)
            .expect("Failed to select action");
        // Two action qubits in the default config give four discrete actions.
        assert!(action < 4);

        agent
            .update(&state, action, 1.0, &next_state, false)
            .expect("Failed to update agent");

        let metrics = agent.get_metrics();
        assert!(metrics.q_loss >= 0.0);
    }

    #[test]
    fn test_quantum_rl_config_default() {
        let config = QuantumRLConfig::default();
        assert_eq!(config.state_qubits, 4);
        assert_eq!(config.action_qubits, 2);
        assert!(config.learning_rate > 0.0);
        assert!(config.discount_factor < 1.0);
    }
}