quantrs2_core/qml/reinforcement_learning.rs

//! Quantum Reinforcement Learning Algorithms
//!
//! This module implements quantum reinforcement learning algorithms that leverage
//! quantum advantage for policy optimization, value function approximation, and
//! exploration strategies in reinforcement learning tasks.
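//!
//! # Example
//!
//! A minimal usage sketch for [`QuantumDQN`]; the environment interaction
//! (state vectors, reward, episode termination) is illustrative, and the
//! import path may differ depending on how this module is re-exported.
//!
//! ```ignore
//! use quantrs2_core::qml::reinforcement_learning::{Experience, QuantumDQN, QuantumRLConfig};
//! use scirs2_core::ndarray::Array1;
//!
//! let config = QuantumRLConfig::default();
//! let mut agent = QuantumDQN::new(config)?;
//!
//! // One interaction step with a (hypothetical) environment.
//! let state = Array1::from_vec(vec![0.5, -0.5, 0.0, 1.0]);
//! let action = agent.select_action(&state)?;
//! let (next_state, reward, done) = (Array1::from_vec(vec![0.0, 1.0, 0.5, -0.5]), 1.0, false);
//!
//! agent.store_experience(Experience { state, action, reward, next_state, done });
//! let metrics = agent.train()?;
//! println!("Q loss: {}, epsilon: {}", metrics.q_loss, metrics.exploration_rate);
//! agent.end_episode(reward);
//! ```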

use crate::{
    error::QuantRS2Result, gate::multi::*, gate::single::*, gate::GateOp, qubit::QubitId,
    variational::VariationalOptimizer,
};
use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::random::{rngs::StdRng, Rng, SeedableRng};
use serde::{Deserialize, Serialize};
use std::collections::VecDeque;

/// Configuration for quantum reinforcement learning
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantumRLConfig {
    /// Number of qubits for state representation
    pub state_qubits: usize,
    /// Number of qubits for action representation
    pub action_qubits: usize,
    /// Number of qubits for value function
    pub value_qubits: usize,
    /// Learning rate for policy optimization
    pub learning_rate: f64,
    /// Discount factor (gamma)
    pub discount_factor: f64,
    /// Exploration rate (epsilon for epsilon-greedy)
    pub exploration_rate: f64,
    /// Exploration decay rate
    pub exploration_decay: f64,
    /// Minimum exploration rate
    pub min_exploration_rate: f64,
    /// Replay buffer size
    pub replay_buffer_size: usize,
    /// Batch size for training
    pub batch_size: usize,
    /// Number of circuit layers
    pub circuit_depth: usize,
    /// Whether to use quantum advantage techniques
    pub use_quantum_advantage: bool,
    /// Random seed for reproducibility
    pub random_seed: Option<u64>,
}

impl Default for QuantumRLConfig {
    fn default() -> Self {
        Self {
            state_qubits: 4,
            action_qubits: 2,
            value_qubits: 3,
            learning_rate: 0.01,
            discount_factor: 0.99,
            exploration_rate: 1.0,
            exploration_decay: 0.995,
            min_exploration_rate: 0.01,
            replay_buffer_size: 10000,
            batch_size: 32,
            circuit_depth: 6,
            use_quantum_advantage: true,
            random_seed: None,
        }
    }
}

/// Experience tuple for replay buffer
#[derive(Debug, Clone)]
pub struct Experience {
    /// Current state
    pub state: Array1<f64>,
    /// Action taken
    pub action: usize,
    /// Reward received
    pub reward: f64,
    /// Next state
    pub next_state: Array1<f64>,
    /// Whether episode ended
    pub done: bool,
}

/// Replay buffer for experience storage
pub struct ReplayBuffer {
    /// Storage for experiences
    buffer: VecDeque<Experience>,
    /// Maximum buffer size
    max_size: usize,
    /// Random number generator
    rng: StdRng,
}

impl ReplayBuffer {
    /// Create a new replay buffer
    pub fn new(max_size: usize, seed: Option<u64>) -> Self {
        let rng = match seed {
            Some(s) => StdRng::seed_from_u64(s),
            None => StdRng::from_seed([0; 32]), // Use fixed seed for StdRng
        };

        Self {
            buffer: VecDeque::with_capacity(max_size),
            max_size,
            rng,
        }
    }

    /// Add experience to buffer
    pub fn add(&mut self, experience: Experience) {
        if self.buffer.len() >= self.max_size {
            self.buffer.pop_front();
        }
        self.buffer.push_back(experience);
    }

    /// Sample a batch of experiences
    pub fn sample(&mut self, batch_size: usize) -> Vec<Experience> {
        let mut samples = Vec::new();
        let buffer_size = self.buffer.len();

        if buffer_size < batch_size {
            return self.buffer.iter().cloned().collect();
        }

        for _ in 0..batch_size {
            let idx = self.rng.random_range(0..buffer_size);
            samples.push(self.buffer[idx].clone());
        }

        samples
    }

    /// Get current buffer size
    pub fn size(&self) -> usize {
        self.buffer.len()
    }

    /// Check if buffer has enough samples
    pub fn can_sample(&self, batch_size: usize) -> bool {
        self.buffer.len() >= batch_size
    }
}

/// Quantum Deep Q-Network (QDQN) agent
pub struct QuantumDQN {
    /// Configuration
    config: QuantumRLConfig,
    /// Q-network (quantum circuit for value function)
    q_network: QuantumValueNetwork,
    /// Target Q-network for stable training
    target_q_network: QuantumValueNetwork,
    /// Policy network for action selection
    policy_network: QuantumPolicyNetwork,
    /// Replay buffer
    replay_buffer: ReplayBuffer,
    /// Training step counter
    training_steps: usize,
    /// Episode counter
    episodes: usize,
    /// Current exploration rate
    current_exploration_rate: f64,
    /// Random number generator
    rng: StdRng,
}

/// Quantum value network for Q-function approximation
pub struct QuantumValueNetwork {
    /// Quantum circuit for value estimation
    circuit: QuantumValueCircuit,
    /// Variational parameters
    parameters: Array1<f64>,
    /// Optimizer for parameter updates
    optimizer: VariationalOptimizer,
}

/// Quantum policy network for action selection
pub struct QuantumPolicyNetwork {
    /// Quantum circuit for policy
    circuit: QuantumPolicyCircuit,
    /// Variational parameters
    parameters: Array1<f64>,
    /// Optimizer for parameter updates
    optimizer: VariationalOptimizer,
}

/// Quantum circuit for value function approximation
#[derive(Debug, Clone)]
pub struct QuantumValueCircuit {
    /// Number of state qubits
    state_qubits: usize,
    /// Number of value qubits
    value_qubits: usize,
    /// Circuit depth
    depth: usize,
    /// Total number of qubits
    total_qubits: usize,
}

/// Quantum circuit for policy network
#[derive(Debug, Clone)]
pub struct QuantumPolicyCircuit {
    /// Number of state qubits
    state_qubits: usize,
    /// Number of action qubits
    action_qubits: usize,
    /// Circuit depth
    depth: usize,
    /// Total number of qubits
    total_qubits: usize,
}

impl QuantumDQN {
    /// Create a new Quantum DQN agent
    pub fn new(config: QuantumRLConfig) -> QuantRS2Result<Self> {
        let rng = match config.random_seed {
            Some(seed) => StdRng::seed_from_u64(seed),
            None => StdRng::from_seed([0; 32]), // Use fixed seed for StdRng
        };

        // Create Q-network
        let q_network = QuantumValueNetwork::new(&config)?;
        let mut target_q_network = QuantumValueNetwork::new(&config)?;

        // Initialize target network with same parameters
        target_q_network
            .parameters
            .clone_from(&q_network.parameters);

        // Create policy network
        let policy_network = QuantumPolicyNetwork::new(&config)?;

        // Create replay buffer
        let replay_buffer = ReplayBuffer::new(config.replay_buffer_size, config.random_seed);

        Ok(Self {
            config: config.clone(),
            q_network,
            target_q_network,
            policy_network,
            replay_buffer,
            training_steps: 0,
            episodes: 0,
            current_exploration_rate: config.exploration_rate,
            rng,
        })
    }

    /// Select action using epsilon-greedy policy with quantum enhancement
    pub fn select_action(&mut self, state: &Array1<f64>) -> QuantRS2Result<usize> {
        // Epsilon-greedy exploration
        if self.rng.random::<f64>() < self.current_exploration_rate {
            // Random action
            let num_actions = 1 << self.config.action_qubits;
            Ok(self.rng.random_range(0..num_actions))
        } else {
            // Greedy action using quantum policy network
            self.policy_network.get_best_action(state)
        }
    }

    /// Store experience in replay buffer
    pub fn store_experience(&mut self, experience: Experience) {
        self.replay_buffer.add(experience);
    }

    /// Train the agent using quantum advantage techniques
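    ///
    /// One call performs a single optimization step: a batch is sampled from the
    /// replay buffer, target values `y = r + γ · max_a Q_target(s', a)` are computed
    /// (or `y = r` for terminal transitions), the Q-network and policy network are
    /// updated from parameter-shift gradients, the target network is synchronized
    /// every 100 steps, and the exploration rate is decayed. Default metrics are
    /// returned until the buffer holds at least `batch_size` experiences.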
    pub fn train(&mut self) -> QuantRS2Result<TrainingMetrics> {
        if !self.replay_buffer.can_sample(self.config.batch_size) {
            return Ok(TrainingMetrics::default());
        }

        // Sample batch from replay buffer
        let experiences = self.replay_buffer.sample(self.config.batch_size);

        // Prepare training data
        let (states, actions, rewards, next_states, dones) =
            self.prepare_training_data(&experiences);

        // Compute target Q-values using quantum advantage
        let target_q_values = self.compute_target_q_values(&next_states, &rewards, &dones)?;

        // Train Q-network
        let q_loss = self.train_q_network(&states, &actions, &target_q_values)?;

        // Train policy network
        let policy_loss = self.train_policy_network(&states)?;

        // Update target network periodically
        if self.training_steps % 100 == 0 {
            self.update_target_network();
        }

        // Update exploration rate
        self.update_exploration_rate();

        self.training_steps += 1;

        Ok(TrainingMetrics {
            q_loss,
            policy_loss,
            exploration_rate: self.current_exploration_rate,
            training_steps: self.training_steps,
        })
    }

    /// Update target network parameters
    fn update_target_network(&mut self) {
        self.target_q_network.parameters = self.q_network.parameters.clone();
    }

    /// Update exploration rate with decay
    fn update_exploration_rate(&mut self) {
        self.current_exploration_rate = (self.current_exploration_rate
            * self.config.exploration_decay)
            .max(self.config.min_exploration_rate);
    }

    /// Prepare training data from experiences
    fn prepare_training_data(
        &self,
        experiences: &[Experience],
    ) -> (
        Array2<f64>,
        Array1<usize>,
        Array1<f64>,
        Array2<f64>,
        Array1<bool>,
    ) {
        let batch_size = experiences.len();
        let state_dim = experiences[0].state.len();

        let mut states = Array2::zeros((batch_size, state_dim));
        let mut actions = Array1::zeros(batch_size);
        let mut rewards = Array1::zeros(batch_size);
        let mut next_states = Array2::zeros((batch_size, state_dim));
        let mut dones = Array1::from_elem(batch_size, false);

        for (i, exp) in experiences.iter().enumerate() {
            states.row_mut(i).assign(&exp.state);
            actions[i] = exp.action;
            rewards[i] = exp.reward;
            next_states.row_mut(i).assign(&exp.next_state);
            dones[i] = exp.done;
        }

        (states, actions, rewards, next_states, dones)
    }

    /// Compute target Q-values using quantum advantage
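    ///
    /// For each transition the target follows the standard Q-learning backup:
    /// `y_i = r_i` if the episode terminated, otherwise
    /// `y_i = r_i + γ · max_a Q_target(s'_i, a)`,
    /// where the maximum is taken over the discrete action set.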
    fn compute_target_q_values(
        &self,
        next_states: &Array2<f64>,
        rewards: &Array1<f64>,
        dones: &Array1<bool>,
    ) -> QuantRS2Result<Array1<f64>> {
        let batch_size = next_states.nrows();
        let mut target_q_values = Array1::zeros(batch_size);

        for i in 0..batch_size {
            if dones[i] {
                target_q_values[i] = rewards[i];
            } else {
                let next_state = next_states.row(i).to_owned();
                let max_next_q = self.target_q_network.get_max_q_value(&next_state)?;
                target_q_values[i] = self.config.discount_factor.mul_add(max_next_q, rewards[i]);
            }
        }

        Ok(target_q_values)
    }

    /// Train Q-network using quantum gradients
    fn train_q_network(
        &mut self,
        states: &Array2<f64>,
        actions: &Array1<usize>,
        target_q_values: &Array1<f64>,
    ) -> QuantRS2Result<f64> {
        let batch_size = states.nrows();
        let mut total_loss = 0.0;

        for i in 0..batch_size {
            let state = states.row(i).to_owned();
            let action = actions[i];
            let target = target_q_values[i];

            // Compute current Q-value
            let current_q = self.q_network.get_q_value(&state, action)?;

            // Compute loss (squared error)
            let loss = (current_q - target).powi(2);
            total_loss += loss;

            // Compute quantum gradients and update parameters
            let gradients = self.q_network.compute_gradients(&state, action, target)?;
            self.q_network
                .update_parameters(&gradients, self.config.learning_rate)?;
        }

        Ok(total_loss / batch_size as f64)
    }

    /// Train policy network using quantum policy gradients
    fn train_policy_network(&mut self, states: &Array2<f64>) -> QuantRS2Result<f64> {
        let batch_size = states.nrows();
        let mut total_loss = 0.0;

        for i in 0..batch_size {
            let state = states.row(i).to_owned();

            // Compute policy loss using quantum advantage
            let policy_loss = self
                .policy_network
                .compute_policy_loss(&state, &self.q_network)?;
            total_loss += policy_loss;

            // Update policy parameters
            let gradients = self
                .policy_network
                .compute_policy_gradients(&state, &self.q_network)?;
            self.policy_network
                .update_parameters(&gradients, self.config.learning_rate)?;
        }

        Ok(total_loss / batch_size as f64)
    }

    /// End episode and update statistics
    pub const fn end_episode(&mut self, _total_reward: f64) {
        self.episodes += 1;
    }

    /// Get training statistics
    pub fn get_statistics(&self) -> QLearningStats {
        QLearningStats {
            episodes: self.episodes,
            training_steps: self.training_steps,
            exploration_rate: self.current_exploration_rate,
            replay_buffer_size: self.replay_buffer.size(),
        }
    }
}

impl QuantumValueNetwork {
    /// Create a new quantum value network
    fn new(config: &QuantumRLConfig) -> QuantRS2Result<Self> {
        let circuit = QuantumValueCircuit::new(
            config.state_qubits,
            config.value_qubits,
            config.circuit_depth,
        )?;

        let num_parameters = circuit.get_parameter_count();
        let mut parameters = Array1::zeros(num_parameters);

        // Initialize parameters randomly
        let mut rng = match config.random_seed {
            Some(seed) => StdRng::seed_from_u64(seed),
            None => StdRng::from_seed([0; 32]),
        };

        for param in &mut parameters {
            *param = rng.random_range(-std::f64::consts::PI..std::f64::consts::PI);
        }

        let optimizer = VariationalOptimizer::new(0.01, 0.9);

        Ok(Self {
            circuit,
            parameters,
            optimizer,
        })
    }

    /// Get Q-value for a specific state-action pair
    fn get_q_value(&self, state: &Array1<f64>, action: usize) -> QuantRS2Result<f64> {
        self.circuit
            .evaluate_q_value(state, action, &self.parameters)
    }

    /// Get maximum Q-value for a state over all actions
    fn get_max_q_value(&self, state: &Array1<f64>) -> QuantRS2Result<f64> {
        let num_actions = 1 << self.circuit.get_action_qubits();
        let mut max_q = f64::NEG_INFINITY;

        for action in 0..num_actions {
            let q_value = self.get_q_value(state, action)?;
            max_q = max_q.max(q_value);
        }

        Ok(max_q)
    }

    /// Compute gradients using quantum parameter-shift rule
    fn compute_gradients(
        &self,
        state: &Array1<f64>,
        action: usize,
        target: f64,
    ) -> QuantRS2Result<Array1<f64>> {
        self.circuit
            .compute_parameter_gradients(state, action, target, &self.parameters)
    }

    /// Update parameters using gradients
    fn update_parameters(
        &mut self,
        gradients: &Array1<f64>,
        learning_rate: f64,
    ) -> QuantRS2Result<()> {
        for (param, &grad) in self.parameters.iter_mut().zip(gradients.iter()) {
            *param -= learning_rate * grad;
        }
        Ok(())
    }
}

impl QuantumPolicyNetwork {
    /// Create a new quantum policy network
    fn new(config: &QuantumRLConfig) -> QuantRS2Result<Self> {
        let circuit = QuantumPolicyCircuit::new(
            config.state_qubits,
            config.action_qubits,
            config.circuit_depth,
        )?;

        let num_parameters = circuit.get_parameter_count();
        let mut parameters = Array1::zeros(num_parameters);

        // Initialize parameters randomly
        let mut rng = match config.random_seed {
            Some(seed) => StdRng::seed_from_u64(seed),
            None => StdRng::from_seed([0; 32]),
        };

        for param in &mut parameters {
            *param = rng.random_range(-std::f64::consts::PI..std::f64::consts::PI);
        }

        let optimizer = VariationalOptimizer::new(0.01, 0.9);

        Ok(Self {
            circuit,
            parameters,
            optimizer,
        })
    }

    /// Get best action for a state
    fn get_best_action(&self, state: &Array1<f64>) -> QuantRS2Result<usize> {
        self.circuit.get_best_action(state, &self.parameters)
    }

    /// Compute policy loss
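    ///
    /// The loss is the negated expected Q-value under the current policy,
    /// `L(θ) = -Σ_a π_θ(a|s) · Q(s, a)`,
    /// so that gradient descent on `L` increases the expected return.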
    fn compute_policy_loss(
        &self,
        state: &Array1<f64>,
        q_network: &QuantumValueNetwork,
    ) -> QuantRS2Result<f64> {
        // Use expected Q-value as policy loss (to maximize)
        let action_probs = self
            .circuit
            .get_action_probabilities(state, &self.parameters)?;
        let num_actions = action_probs.len();

        let mut expected_q = 0.0;
        for action in 0..num_actions {
            let q_value = q_network.get_q_value(state, action)?;
            expected_q += action_probs[action] * q_value;
        }

        // Negative because we want to maximize (minimize negative)
        Ok(-expected_q)
    }

    /// Compute policy gradients
    fn compute_policy_gradients(
        &self,
        state: &Array1<f64>,
        q_network: &QuantumValueNetwork,
    ) -> QuantRS2Result<Array1<f64>> {
        self.circuit
            .compute_policy_gradients(state, q_network, &self.parameters)
    }

    /// Update parameters
    fn update_parameters(
        &mut self,
        gradients: &Array1<f64>,
        learning_rate: f64,
    ) -> QuantRS2Result<()> {
        for (param, &grad) in self.parameters.iter_mut().zip(gradients.iter()) {
            *param -= learning_rate * grad;
        }
        Ok(())
    }
}

impl QuantumValueCircuit {
    /// Create a new quantum value circuit
    const fn new(state_qubits: usize, value_qubits: usize, depth: usize) -> QuantRS2Result<Self> {
        let total_qubits = state_qubits + value_qubits;

        Ok(Self {
            state_qubits,
            value_qubits,
            depth,
            total_qubits,
        })
    }

    /// Get number of parameters in the circuit
    const fn get_parameter_count(&self) -> usize {
        // Each layer has rotation gates on each qubit (3 parameters each) plus entangling gates
        let rotations_per_layer = self.get_total_qubits() * 3;
        let entangling_per_layer = self.get_total_qubits(); // Simplified estimate
        self.depth * (rotations_per_layer + entangling_per_layer)
    }

    /// Get total number of qubits
    const fn get_total_qubits(&self) -> usize {
        self.state_qubits + self.value_qubits
    }

    /// Get number of action qubits (for external interface)
    const fn get_action_qubits(&self) -> usize {
        // This is a bit of a hack - in a real implementation,
        // the value circuit wouldn't directly know about actions
        2 // Default action qubits
    }

    /// Evaluate Q-value using quantum circuit
    fn evaluate_q_value(
        &self,
        state: &Array1<f64>,
        action: usize,
        parameters: &Array1<f64>,
    ) -> QuantRS2Result<f64> {
        // Encode state into quantum circuit
        let mut gates = Vec::new();

        // State encoding
        for i in 0..self.state_qubits {
            let state_value = if i < state.len() { state[i] } else { 0.0 };
            gates.push(Box::new(RotationY {
                target: QubitId(i as u32),
                theta: state_value * std::f64::consts::PI,
            }) as Box<dyn GateOp>);
        }

        // Action encoding (simple binary encoding)
        for i in 0..2 {
            // Assuming 2 action qubits
            if (action >> i) & 1 == 1 {
                gates.push(Box::new(PauliX {
                    target: QubitId((self.state_qubits + i) as u32),
                }) as Box<dyn GateOp>);
            }
        }

        // Variational circuit layers
        let mut param_idx = 0;
        for _layer in 0..self.depth {
            // Rotation layer
            for qubit in 0..self.get_total_qubits() {
                if param_idx + 2 < parameters.len() {
                    gates.push(Box::new(RotationX {
                        target: QubitId(qubit as u32),
                        theta: parameters[param_idx],
                    }) as Box<dyn GateOp>);
                    param_idx += 1;

                    gates.push(Box::new(RotationY {
                        target: QubitId(qubit as u32),
                        theta: parameters[param_idx],
                    }) as Box<dyn GateOp>);
                    param_idx += 1;

                    gates.push(Box::new(RotationZ {
                        target: QubitId(qubit as u32),
                        theta: parameters[param_idx],
                    }) as Box<dyn GateOp>);
                    param_idx += 1;
                }
            }

            // Entangling layer
            for qubit in 0..self.get_total_qubits() - 1 {
                if param_idx < parameters.len() {
                    gates.push(Box::new(CRZ {
                        control: QubitId(qubit as u32),
                        target: QubitId((qubit + 1) as u32),
                        theta: parameters[param_idx],
                    }) as Box<dyn GateOp>);
                    param_idx += 1;
                }
            }
        }

        // Simplified evaluation: return a mock Q-value
        // In a real implementation, this would involve quantum simulation
        let q_value = self.simulate_circuit_expectation(&gates)?;

        Ok(q_value)
    }

    /// Simulate circuit and return expectation value
    fn simulate_circuit_expectation(&self, gates: &[Box<dyn GateOp>]) -> QuantRS2Result<f64> {
        // Simplified simulation: compute a hash-based mock expectation
        let mut hash_value = 0u64;

        for gate in gates {
            // Simple hash of gate parameters
            if let Ok(matrix) = gate.matrix() {
                for complex in &matrix {
                    hash_value = hash_value.wrapping_add((complex.re * 1000.0) as u64);
                    hash_value = hash_value.wrapping_add((complex.im * 1000.0) as u64);
                }
            }
        }

        // Convert to expectation value in [-1, 1]
        let expectation = (hash_value % 2000) as f64 / 1000.0 - 1.0;
        Ok(expectation)
    }

    /// Compute parameter gradients using parameter-shift rule
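    ///
    /// Each parameter is shifted by ±π/2 and the gradient of the Q-value is
    /// estimated as `∂Q/∂θ_i ≈ (Q(θ_i + π/2) - Q(θ_i - π/2)) / 2`. The chain rule
    /// for the squared-error loss `(Q - y)²` contributes the factor `2 · (Q - y)`,
    /// giving `∂L/∂θ_i = 2 (Q - y) · ∂Q/∂θ_i`.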
    fn compute_parameter_gradients(
        &self,
        state: &Array1<f64>,
        action: usize,
        target: f64,
        parameters: &Array1<f64>,
    ) -> QuantRS2Result<Array1<f64>> {
        let mut gradients = Array1::zeros(parameters.len());
        let shift = std::f64::consts::PI / 2.0;

        for i in 0..parameters.len() {
            // Forward shift
            let mut params_plus = parameters.clone();
            params_plus[i] += shift;
            let q_plus = self.evaluate_q_value(state, action, &params_plus)?;

            // Backward shift
            let mut params_minus = parameters.clone();
            params_minus[i] -= shift;
            let q_minus = self.evaluate_q_value(state, action, &params_minus)?;

            // Parameter-shift rule gradient
            let current_q = self.evaluate_q_value(state, action, parameters)?;
            let loss_gradient = 2.0 * (current_q - target); // d/dθ (q - target)²

            gradients[i] = loss_gradient * (q_plus - q_minus) / 2.0;
        }

        Ok(gradients)
    }
}

impl QuantumPolicyCircuit {
    /// Create a new quantum policy circuit
    const fn new(state_qubits: usize, action_qubits: usize, depth: usize) -> QuantRS2Result<Self> {
        let total_qubits = state_qubits + action_qubits;

        Ok(Self {
            state_qubits,
            action_qubits,
            depth,
            total_qubits,
        })
    }

    /// Get number of parameters
    const fn get_parameter_count(&self) -> usize {
        let total_qubits = self.state_qubits + self.action_qubits;
        let rotations_per_layer = total_qubits * 3;
        let entangling_per_layer = total_qubits;
        self.depth * (rotations_per_layer + entangling_per_layer)
    }

    /// Get best action for a state
    fn get_best_action(
        &self,
        state: &Array1<f64>,
        parameters: &Array1<f64>,
    ) -> QuantRS2Result<usize> {
        let action_probs = self.get_action_probabilities(state, parameters)?;

        // Find action with highest probability
        let mut best_action = 0;
        let mut best_prob = action_probs[0];

        for (action, &prob) in action_probs.iter().enumerate() {
            if prob > best_prob {
                best_prob = prob;
                best_action = action;
            }
        }

        Ok(best_action)
    }

    /// Get action probabilities
    fn get_action_probabilities(
        &self,
        state: &Array1<f64>,
        parameters: &Array1<f64>,
    ) -> QuantRS2Result<Vec<f64>> {
        let num_actions = 1 << self.action_qubits;
        let mut probabilities = vec![0.0; num_actions];

        // Simplified: uniform distribution with slight variations based on state and parameters
        let base_prob = 1.0 / num_actions as f64;

        for action in 0..num_actions {
            // Add state and parameter-dependent variation
            let state_hash = state.iter().sum::<f64>();
            let param_hash = parameters.iter().take(10).sum::<f64>();
            let variation = 0.1 * ((state_hash + param_hash + action as f64).sin());

            probabilities[action] = base_prob + variation;
        }

        // Normalize probabilities
        let sum: f64 = probabilities.iter().sum();
        for prob in &mut probabilities {
            *prob /= sum;
        }

        Ok(probabilities)
    }

    /// Compute policy gradients
    fn compute_policy_gradients(
        &self,
        state: &Array1<f64>,
        q_network: &QuantumValueNetwork,
        parameters: &Array1<f64>,
    ) -> QuantRS2Result<Array1<f64>> {
        let mut gradients = Array1::zeros(parameters.len());
        let shift = std::f64::consts::PI / 2.0;

        for i in 0..parameters.len() {
            // Forward shift
            let mut params_plus = parameters.clone();
            params_plus[i] += shift;
            let loss_plus = self.compute_policy_loss_with_params(state, q_network, &params_plus)?;

            // Backward shift
            let mut params_minus = parameters.clone();
            params_minus[i] -= shift;
            let loss_minus =
                self.compute_policy_loss_with_params(state, q_network, &params_minus)?;

            // Parameter-shift rule
            gradients[i] = (loss_plus - loss_minus) / 2.0;
        }

        Ok(gradients)
    }

    /// Compute policy loss with specific parameters
    fn compute_policy_loss_with_params(
        &self,
        state: &Array1<f64>,
        q_network: &QuantumValueNetwork,
        parameters: &Array1<f64>,
    ) -> QuantRS2Result<f64> {
        let action_probs = self.get_action_probabilities(state, parameters)?;
        let num_actions = action_probs.len();

        let mut expected_q = 0.0;
        for action in 0..num_actions {
            let q_value = q_network.get_q_value(state, action)?;
            expected_q += action_probs[action] * q_value;
        }

        Ok(-expected_q) // Negative to maximize
    }
}

/// Training metrics for quantum RL
#[derive(Debug, Clone, Default)]
pub struct TrainingMetrics {
    /// Q-network loss
    pub q_loss: f64,
    /// Policy network loss
    pub policy_loss: f64,
    /// Current exploration rate
    pub exploration_rate: f64,
    /// Number of training steps
    pub training_steps: usize,
}

/// Q-learning statistics
#[derive(Debug, Clone)]
pub struct QLearningStats {
    /// Number of episodes completed
    pub episodes: usize,
    /// Number of training steps
    pub training_steps: usize,
    /// Current exploration rate
    pub exploration_rate: f64,
    /// Current replay buffer size
    pub replay_buffer_size: usize,
}

/// Quantum Actor-Critic agent
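///
/// # Example
///
/// An illustrative update step; the state vectors, reward, and `done` flag
/// stand in for an environment and are not part of this module.
///
/// ```ignore
/// use quantrs2_core::qml::reinforcement_learning::{QuantumActorCritic, QuantumRLConfig};
/// use scirs2_core::ndarray::Array1;
///
/// let mut agent = QuantumActorCritic::new(QuantumRLConfig::default())?;
/// let state = Array1::from_vec(vec![0.5, -0.5]);
/// let next_state = Array1::from_vec(vec![0.0, 1.0]);
///
/// let action = agent.select_action(&state)?;
/// agent.update(&state, action, 1.0, &next_state, false)?;
/// println!("|TD error|: {}", agent.get_metrics().q_loss);
/// ```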
pub struct QuantumActorCritic {
    /// Configuration
    config: QuantumRLConfig,
    /// Actor network (policy)
    actor: QuantumPolicyNetwork,
    /// Critic network (value function)
    critic: QuantumValueNetwork,
    /// Training metrics
    metrics: TrainingMetrics,
}

impl QuantumActorCritic {
    /// Create a new Quantum Actor-Critic agent
    pub fn new(config: QuantumRLConfig) -> QuantRS2Result<Self> {
        let actor = QuantumPolicyNetwork::new(&config)?;
        let critic = QuantumValueNetwork::new(&config)?;

        Ok(Self {
            config,
            actor,
            critic,
            metrics: TrainingMetrics::default(),
        })
    }

    /// Update networks using actor-critic algorithm
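    ///
    /// Computes the temporal-difference error
    /// `δ = r + γ · V(s') - V(s)`
    /// (with `V(s') = 0` for terminal transitions), moves the critic toward the
    /// TD target, and scales the actor's policy gradient by `δ` as an advantage
    /// estimate.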
    pub fn update(
        &mut self,
        state: &Array1<f64>,
        _action: usize,
        reward: f64,
        next_state: &Array1<f64>,
        done: bool,
    ) -> QuantRS2Result<()> {
        // Compute TD error
        let current_value = self.critic.get_q_value(state, 0)?; // Use first action for state value
        let next_value = if done {
            0.0
        } else {
            self.critic.get_max_q_value(next_state)?
        };

        let target_value = self.config.discount_factor.mul_add(next_value, reward);
        let td_error = target_value - current_value;

        // Update critic
        let critic_gradients = self.critic.compute_gradients(state, 0, target_value)?;
        self.critic
            .update_parameters(&critic_gradients, self.config.learning_rate)?;

        // Update actor using policy gradient scaled by TD error
        let actor_gradients = self.actor.compute_policy_gradients(state, &self.critic)?;
        let scaled_gradients = actor_gradients * td_error; // Scale by advantage
        self.actor
            .update_parameters(&scaled_gradients, self.config.learning_rate)?;

        // Update metrics
        self.metrics.q_loss = td_error.abs();
        self.metrics.policy_loss = -td_error; // Negative because we want to maximize

        Ok(())
    }

    /// Select action using current policy
    pub fn select_action(&self, state: &Array1<f64>) -> QuantRS2Result<usize> {
        self.actor.get_best_action(state)
    }

    /// Get training metrics
    pub const fn get_metrics(&self) -> &TrainingMetrics {
        &self.metrics
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_quantum_dqn_creation() {
        let config = QuantumRLConfig::default();
        let agent = QuantumDQN::new(config).expect("Failed to create QuantumDQN agent");

        let stats = agent.get_statistics();
        assert_eq!(stats.episodes, 0);
        assert_eq!(stats.training_steps, 0);
    }

    #[test]
    fn test_replay_buffer() {
        let mut buffer = ReplayBuffer::new(10, Some(42));

        let experience = Experience {
            state: Array1::from_vec(vec![1.0, 0.0, -1.0]),
            action: 1,
            reward: 1.0,
            next_state: Array1::from_vec(vec![0.0, 1.0, 0.0]),
            done: false,
        };

        buffer.add(experience);
        assert_eq!(buffer.size(), 1);

        let samples = buffer.sample(1);
        assert_eq!(samples.len(), 1);
    }

    #[test]
    fn test_quantum_value_circuit() {
        let circuit =
            QuantumValueCircuit::new(3, 2, 4).expect("Failed to create QuantumValueCircuit");
        let param_count = circuit.get_parameter_count();
        assert!(param_count > 0);

        let state = Array1::from_vec(vec![0.5, -0.5, 0.0]);
        let parameters = Array1::zeros(param_count);

        let q_value = circuit
            .evaluate_q_value(&state, 1, &parameters)
            .expect("Failed to evaluate Q-value");
        assert!(q_value.is_finite());
    }

    #[test]
    fn test_quantum_actor_critic() {
        let config = QuantumRLConfig::default();
        let mut agent =
            QuantumActorCritic::new(config).expect("Failed to create QuantumActorCritic agent");

        let state = Array1::from_vec(vec![0.5, -0.5]);
        let next_state = Array1::from_vec(vec![0.0, 1.0]);

        let action = agent
            .select_action(&state)
            .expect("Failed to select action");
        assert!(action < 4); // 2^2 actions for 2 action qubits

        agent
            .update(&state, action, 1.0, &next_state, false)
            .expect("Failed to update agent");

        let metrics = agent.get_metrics();
        assert!(metrics.q_loss >= 0.0);
    }

    #[test]
    fn test_quantum_rl_config_default() {
        let config = QuantumRLConfig::default();
        assert_eq!(config.state_qubits, 4);
        assert_eq!(config.action_qubits, 2);
        assert!(config.learning_rate > 0.0);
        assert!(config.discount_factor < 1.0);
    }
}