quantrs2_core/qml/reinforcement_learning.rs

//! Quantum Reinforcement Learning Algorithms
//!
//! This module implements quantum reinforcement learning algorithms that leverage
//! quantum advantage for policy optimization, value function approximation, and
//! exploration strategies in reinforcement learning tasks.
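//!
//! # Example
//!
//! A minimal usage sketch. The surrounding environment loop (which would supply
//! `reward`, `next_state`, and `done`) and the exact crate path are assumptions,
//! not part of this module:
//!
//! ```ignore
//! use quantrs2_core::qml::reinforcement_learning::{Experience, QuantumDQN, QuantumRLConfig};
//!
//! let mut agent = QuantumDQN::new(QuantumRLConfig::default()).unwrap();
//! let state = ndarray::Array1::from_vec(vec![0.1, -0.2, 0.3, 0.0]);
//! let action = agent.select_action(&state).unwrap();
//! // Reward, next state, and termination flag would come from an external environment.
//! agent.store_experience(Experience {
//!     state,
//!     action,
//!     reward: 1.0,
//!     next_state: ndarray::Array1::zeros(4),
//!     done: false,
//! });
//! let _metrics = agent.train().unwrap();
//! ```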

use crate::{
    error::QuantRS2Result, gate::multi::*, gate::single::*, gate::GateOp, qubit::QubitId,
    variational::VariationalOptimizer,
};
use ndarray::{Array1, Array2};
use rand::{rngs::StdRng, Rng, SeedableRng};
use serde::{Deserialize, Serialize};
use std::collections::VecDeque;

/// Configuration for quantum reinforcement learning
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantumRLConfig {
    /// Number of qubits for state representation
    pub state_qubits: usize,
    /// Number of qubits for action representation
    pub action_qubits: usize,
    /// Number of qubits for value function
    pub value_qubits: usize,
    /// Learning rate for policy optimization
    pub learning_rate: f64,
    /// Discount factor (gamma)
    pub discount_factor: f64,
    /// Exploration rate (epsilon for epsilon-greedy)
    pub exploration_rate: f64,
    /// Exploration decay rate
    pub exploration_decay: f64,
    /// Minimum exploration rate
    pub min_exploration_rate: f64,
    /// Replay buffer size
    pub replay_buffer_size: usize,
    /// Batch size for training
    pub batch_size: usize,
    /// Number of circuit layers
    pub circuit_depth: usize,
    /// Whether to use quantum advantage techniques
    pub use_quantum_advantage: bool,
    /// Random seed for reproducibility
    pub random_seed: Option<u64>,
}

impl Default for QuantumRLConfig {
    fn default() -> Self {
        Self {
            state_qubits: 4,
            action_qubits: 2,
            value_qubits: 3,
            learning_rate: 0.01,
            discount_factor: 0.99,
            exploration_rate: 1.0,
            exploration_decay: 0.995,
            min_exploration_rate: 0.01,
            replay_buffer_size: 10000,
            batch_size: 32,
            circuit_depth: 6,
            use_quantum_advantage: true,
            random_seed: None,
        }
    }
}

/// Experience tuple for replay buffer
#[derive(Debug, Clone)]
pub struct Experience {
    /// Current state
    pub state: Array1<f64>,
    /// Action taken
    pub action: usize,
    /// Reward received
    pub reward: f64,
    /// Next state
    pub next_state: Array1<f64>,
    /// Whether episode ended
    pub done: bool,
}

/// Replay buffer for experience storage
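///
/// Bounded FIFO storage: once `max_size` entries are stored, adding a new experience
/// evicts the oldest one.
///
/// A minimal sketch (values are illustrative only):
///
/// ```ignore
/// let mut buffer = ReplayBuffer::new(1000, Some(42));
/// buffer.add(Experience {
///     state: ndarray::Array1::zeros(4),
///     action: 0,
///     reward: 0.0,
///     next_state: ndarray::Array1::zeros(4),
///     done: false,
/// });
/// if buffer.can_sample(32) {
///     let _batch = buffer.sample(32);
/// }
/// ```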
pub struct ReplayBuffer {
    /// Storage for experiences
    buffer: VecDeque<Experience>,
    /// Maximum buffer size
    max_size: usize,
    /// Random number generator
    rng: StdRng,
}

impl ReplayBuffer {
    /// Create a new replay buffer
    pub fn new(max_size: usize, seed: Option<u64>) -> Self {
        let rng = match seed {
            Some(s) => StdRng::seed_from_u64(s),
            None => StdRng::from_seed([0; 32]), // Use fixed seed for StdRng
        };

        Self {
            buffer: VecDeque::with_capacity(max_size),
            max_size,
            rng,
        }
    }

    /// Add experience to buffer
    pub fn add(&mut self, experience: Experience) {
        if self.buffer.len() >= self.max_size {
            self.buffer.pop_front();
        }
        self.buffer.push_back(experience);
    }

    /// Sample a batch of experiences
    pub fn sample(&mut self, batch_size: usize) -> Vec<Experience> {
        let mut samples = Vec::new();
        let buffer_size = self.buffer.len();

        if buffer_size < batch_size {
            return self.buffer.iter().cloned().collect();
        }

        for _ in 0..batch_size {
            let idx = self.rng.random_range(0..buffer_size);
            samples.push(self.buffer[idx].clone());
        }

        samples
    }

    /// Get current buffer size
    pub fn size(&self) -> usize {
        self.buffer.len()
    }

    /// Check if buffer has enough samples
    pub fn can_sample(&self, batch_size: usize) -> bool {
        self.buffer.len() >= batch_size
    }
}

/// Quantum Deep Q-Network (QDQN) agent
pub struct QuantumDQN {
    /// Configuration
    config: QuantumRLConfig,
    /// Q-network (quantum circuit for value function)
    q_network: QuantumValueNetwork,
    /// Target Q-network for stable training
    target_q_network: QuantumValueNetwork,
    /// Policy network for action selection
    policy_network: QuantumPolicyNetwork,
    /// Replay buffer
    replay_buffer: ReplayBuffer,
    /// Training step counter
    training_steps: usize,
    /// Episode counter
    episodes: usize,
    /// Current exploration rate
    current_exploration_rate: f64,
    /// Random number generator
    rng: StdRng,
}

/// Quantum value network for Q-function approximation
pub struct QuantumValueNetwork {
    /// Quantum circuit for value estimation
    circuit: QuantumValueCircuit,
    /// Variational parameters
    parameters: Array1<f64>,
    /// Optimizer for parameter updates
    optimizer: VariationalOptimizer,
}

/// Quantum policy network for action selection
pub struct QuantumPolicyNetwork {
    /// Quantum circuit for policy
    circuit: QuantumPolicyCircuit,
    /// Variational parameters
    parameters: Array1<f64>,
    /// Optimizer for parameter updates
    optimizer: VariationalOptimizer,
}

/// Quantum circuit for value function approximation
#[derive(Debug, Clone)]
pub struct QuantumValueCircuit {
    /// Number of state qubits
    state_qubits: usize,
    /// Number of value qubits
    value_qubits: usize,
    /// Circuit depth
    depth: usize,
    /// Total number of qubits
    total_qubits: usize,
}

/// Quantum circuit for policy network
#[derive(Debug, Clone)]
pub struct QuantumPolicyCircuit {
    /// Number of state qubits
    state_qubits: usize,
    /// Number of action qubits
    action_qubits: usize,
    /// Circuit depth
    depth: usize,
    /// Total number of qubits
    total_qubits: usize,
}

impl QuantumDQN {
    /// Create a new Quantum DQN agent
    pub fn new(config: QuantumRLConfig) -> QuantRS2Result<Self> {
        let rng = match config.random_seed {
            Some(seed) => StdRng::seed_from_u64(seed),
            None => StdRng::from_seed([0; 32]), // Use fixed seed for StdRng
        };

        // Create Q-network
        let q_network = QuantumValueNetwork::new(&config)?;
        let mut target_q_network = QuantumValueNetwork::new(&config)?;

        // Initialize target network with same parameters
        target_q_network.parameters = q_network.parameters.clone();

        // Create policy network
        let policy_network = QuantumPolicyNetwork::new(&config)?;

        // Create replay buffer
        let replay_buffer = ReplayBuffer::new(config.replay_buffer_size, config.random_seed);

        Ok(Self {
            config: config.clone(),
            q_network,
            target_q_network,
            policy_network,
            replay_buffer,
            training_steps: 0,
            episodes: 0,
            current_exploration_rate: config.exploration_rate,
            rng,
        })
    }

    /// Select action using epsilon-greedy policy with quantum enhancement
    pub fn select_action(&mut self, state: &Array1<f64>) -> QuantRS2Result<usize> {
        // Epsilon-greedy exploration
        if self.rng.random::<f64>() < self.current_exploration_rate {
            // Random action
            let num_actions = 1 << self.config.action_qubits;
            Ok(self.rng.random_range(0..num_actions))
        } else {
            // Greedy action using quantum policy network
            self.policy_network.get_best_action(state)
        }
    }

    /// Store experience in replay buffer
    pub fn store_experience(&mut self, experience: Experience) {
        self.replay_buffer.add(experience);
    }

    /// Train the agent using quantum advantage techniques
    pub fn train(&mut self) -> QuantRS2Result<TrainingMetrics> {
        if !self.replay_buffer.can_sample(self.config.batch_size) {
            return Ok(TrainingMetrics::default());
        }

        // Sample batch from replay buffer
        let experiences = self.replay_buffer.sample(self.config.batch_size);

        // Prepare training data
        let (states, actions, rewards, next_states, dones) =
            self.prepare_training_data(&experiences);

        // Compute target Q-values using quantum advantage
        let target_q_values = self.compute_target_q_values(&next_states, &rewards, &dones)?;

        // Train Q-network
        let q_loss = self.train_q_network(&states, &actions, &target_q_values)?;

        // Train policy network
        let policy_loss = self.train_policy_network(&states)?;

        // Update target network periodically
        if self.training_steps % 100 == 0 {
            self.update_target_network();
        }

        // Update exploration rate
        self.update_exploration_rate();

        self.training_steps += 1;

        Ok(TrainingMetrics {
            q_loss,
            policy_loss,
            exploration_rate: self.current_exploration_rate,
            training_steps: self.training_steps,
        })
    }

    /// Update target network parameters
    fn update_target_network(&mut self) {
        self.target_q_network.parameters = self.q_network.parameters.clone();
    }

    /// Update exploration rate with decay
    fn update_exploration_rate(&mut self) {
        self.current_exploration_rate = (self.current_exploration_rate
            * self.config.exploration_decay)
            .max(self.config.min_exploration_rate);
    }

    /// Prepare training data from experiences
    fn prepare_training_data(
        &self,
        experiences: &[Experience],
    ) -> (
        Array2<f64>,
        Array1<usize>,
        Array1<f64>,
        Array2<f64>,
        Array1<bool>,
    ) {
        let batch_size = experiences.len();
        let state_dim = experiences[0].state.len();

        let mut states = Array2::zeros((batch_size, state_dim));
        let mut actions = Array1::zeros(batch_size);
        let mut rewards = Array1::zeros(batch_size);
        let mut next_states = Array2::zeros((batch_size, state_dim));
        let mut dones = Array1::from_elem(batch_size, false);

        for (i, exp) in experiences.iter().enumerate() {
            states.row_mut(i).assign(&exp.state);
            actions[i] = exp.action;
            rewards[i] = exp.reward;
            next_states.row_mut(i).assign(&exp.next_state);
            dones[i] = exp.done;
        }

        (states, actions, rewards, next_states, dones)
    }

    /// Compute target Q-values using quantum advantage
    fn compute_target_q_values(
        &self,
        next_states: &Array2<f64>,
        rewards: &Array1<f64>,
        dones: &Array1<bool>,
    ) -> QuantRS2Result<Array1<f64>> {
        let batch_size = next_states.nrows();
        let mut target_q_values = Array1::zeros(batch_size);

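        // Standard DQN Bellman target, evaluated with the frozen target network:
        //   y_i = r_i                                    if the episode terminated
        //   y_i = r_i + gamma * max_a Q_target(s'_i, a)  otherwise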
        for i in 0..batch_size {
            if dones[i] {
                target_q_values[i] = rewards[i];
            } else {
                let next_state = next_states.row(i).to_owned();
                let max_next_q = self.target_q_network.get_max_q_value(&next_state)?;
                target_q_values[i] = rewards[i] + self.config.discount_factor * max_next_q;
            }
        }

        Ok(target_q_values)
    }

    /// Train Q-network using quantum gradients
    fn train_q_network(
        &mut self,
        states: &Array2<f64>,
        actions: &Array1<usize>,
        target_q_values: &Array1<f64>,
    ) -> QuantRS2Result<f64> {
        let batch_size = states.nrows();
        let mut total_loss = 0.0;

        for i in 0..batch_size {
            let state = states.row(i).to_owned();
            let action = actions[i];
            let target = target_q_values[i];

            // Compute current Q-value
            let current_q = self.q_network.get_q_value(&state, action)?;

            // Compute loss (squared error)
            let loss = (current_q - target).powi(2);
            total_loss += loss;

            // Compute quantum gradients and update parameters
            let gradients = self.q_network.compute_gradients(&state, action, target)?;
            self.q_network
                .update_parameters(&gradients, self.config.learning_rate)?;
        }

        Ok(total_loss / batch_size as f64)
    }

    /// Train policy network using quantum policy gradients
    fn train_policy_network(&mut self, states: &Array2<f64>) -> QuantRS2Result<f64> {
        let batch_size = states.nrows();
        let mut total_loss = 0.0;

        for i in 0..batch_size {
            let state = states.row(i).to_owned();

            // Compute policy loss using quantum advantage
            let policy_loss = self
                .policy_network
                .compute_policy_loss(&state, &self.q_network)?;
            total_loss += policy_loss;

            // Update policy parameters
            let gradients = self
                .policy_network
                .compute_policy_gradients(&state, &self.q_network)?;
            self.policy_network
                .update_parameters(&gradients, self.config.learning_rate)?;
        }

        Ok(total_loss / batch_size as f64)
    }

    /// End episode and update statistics
    pub fn end_episode(&mut self, _total_reward: f64) {
        self.episodes += 1;
    }

    /// Get training statistics
    pub fn get_statistics(&self) -> QLearningStats {
        QLearningStats {
            episodes: self.episodes,
            training_steps: self.training_steps,
            exploration_rate: self.current_exploration_rate,
            replay_buffer_size: self.replay_buffer.size(),
        }
    }
}

impl QuantumValueNetwork {
    /// Create a new quantum value network
    fn new(config: &QuantumRLConfig) -> QuantRS2Result<Self> {
        let circuit = QuantumValueCircuit::new(
            config.state_qubits,
            config.value_qubits,
            config.circuit_depth,
        )?;

        let num_parameters = circuit.get_parameter_count();
        let mut parameters = Array1::zeros(num_parameters);

        // Initialize parameters randomly
        let mut rng = match config.random_seed {
            Some(seed) => StdRng::seed_from_u64(seed),
            None => StdRng::from_seed([0; 32]),
        };

        for param in parameters.iter_mut() {
            *param = rng.random_range(-std::f64::consts::PI..std::f64::consts::PI);
        }

        let optimizer = VariationalOptimizer::new(0.01, 0.9);

        Ok(Self {
            circuit,
            parameters,
            optimizer,
        })
    }

    /// Get Q-value for a specific state-action pair
    fn get_q_value(&self, state: &Array1<f64>, action: usize) -> QuantRS2Result<f64> {
        self.circuit
            .evaluate_q_value(state, action, &self.parameters)
    }

    /// Get maximum Q-value for a state over all actions
    fn get_max_q_value(&self, state: &Array1<f64>) -> QuantRS2Result<f64> {
        let num_actions = 1 << self.circuit.get_action_qubits();
        let mut max_q = f64::NEG_INFINITY;

        for action in 0..num_actions {
            let q_value = self.get_q_value(state, action)?;
            max_q = max_q.max(q_value);
        }

        Ok(max_q)
    }

    /// Compute gradients using quantum parameter-shift rule
    fn compute_gradients(
        &self,
        state: &Array1<f64>,
        action: usize,
        target: f64,
    ) -> QuantRS2Result<Array1<f64>> {
        self.circuit
            .compute_parameter_gradients(state, action, target, &self.parameters)
    }

    /// Update parameters using gradients
    fn update_parameters(
        &mut self,
        gradients: &Array1<f64>,
        learning_rate: f64,
    ) -> QuantRS2Result<()> {
        for (param, &grad) in self.parameters.iter_mut().zip(gradients.iter()) {
            *param -= learning_rate * grad;
        }
        Ok(())
    }
}

impl QuantumPolicyNetwork {
    /// Create a new quantum policy network
    fn new(config: &QuantumRLConfig) -> QuantRS2Result<Self> {
        let circuit = QuantumPolicyCircuit::new(
            config.state_qubits,
            config.action_qubits,
            config.circuit_depth,
        )?;

        let num_parameters = circuit.get_parameter_count();
        let mut parameters = Array1::zeros(num_parameters);

        // Initialize parameters randomly
        let mut rng = match config.random_seed {
            Some(seed) => StdRng::seed_from_u64(seed),
            None => StdRng::from_seed([0; 32]),
        };

        for param in parameters.iter_mut() {
            *param = rng.random_range(-std::f64::consts::PI..std::f64::consts::PI);
        }

        let optimizer = VariationalOptimizer::new(0.01, 0.9);

        Ok(Self {
            circuit,
            parameters,
            optimizer,
        })
    }

    /// Get best action for a state
    fn get_best_action(&self, state: &Array1<f64>) -> QuantRS2Result<usize> {
        self.circuit.get_best_action(state, &self.parameters)
    }

    /// Compute policy loss
    fn compute_policy_loss(
        &self,
        state: &Array1<f64>,
        q_network: &QuantumValueNetwork,
    ) -> QuantRS2Result<f64> {
        // Use expected Q-value as policy loss (to maximize)
        let action_probs = self
            .circuit
            .get_action_probabilities(state, &self.parameters)?;
        let num_actions = action_probs.len();

        let mut expected_q = 0.0;
        for action in 0..num_actions {
            let q_value = q_network.get_q_value(state, action)?;
            expected_q += action_probs[action] * q_value;
        }

        // Negative because we want to maximize (minimize negative)
        Ok(-expected_q)
    }

    /// Compute policy gradients
    fn compute_policy_gradients(
        &self,
        state: &Array1<f64>,
        q_network: &QuantumValueNetwork,
    ) -> QuantRS2Result<Array1<f64>> {
        self.circuit
            .compute_policy_gradients(state, q_network, &self.parameters)
    }

    /// Update parameters
    fn update_parameters(
        &mut self,
        gradients: &Array1<f64>,
        learning_rate: f64,
    ) -> QuantRS2Result<()> {
        for (param, &grad) in self.parameters.iter_mut().zip(gradients.iter()) {
            *param -= learning_rate * grad;
        }
        Ok(())
    }
}

impl QuantumValueCircuit {
    /// Create a new quantum value circuit
    fn new(state_qubits: usize, value_qubits: usize, depth: usize) -> QuantRS2Result<Self> {
        let total_qubits = state_qubits + value_qubits;

        Ok(Self {
            state_qubits,
            value_qubits,
            depth,
            total_qubits,
        })
    }

    /// Get number of parameters in the circuit
    fn get_parameter_count(&self) -> usize {
        // Each layer has rotation gates on each qubit (3 parameters each) plus entangling gates
        let rotations_per_layer = self.get_total_qubits() * 3;
        let entangling_per_layer = self.get_total_qubits(); // Simplified estimate
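        // Note: this is a deliberate over-count. The entangling layer in
        // evaluate_q_value applies total_qubits - 1 CRZ gates per layer, and that
        // method bounds-checks param_idx before reading, so surplus entries are unused.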
        self.depth * (rotations_per_layer + entangling_per_layer)
    }

    /// Get total number of qubits
    fn get_total_qubits(&self) -> usize {
        self.state_qubits + self.value_qubits
    }

    /// Get number of action qubits (for external interface)
    fn get_action_qubits(&self) -> usize {
        // This is a bit of a hack - in a real implementation,
        // the value circuit wouldn't directly know about actions
        2 // Default action qubits
    }

    /// Evaluate Q-value using quantum circuit
    fn evaluate_q_value(
        &self,
        state: &Array1<f64>,
        action: usize,
        parameters: &Array1<f64>,
    ) -> QuantRS2Result<f64> {
        // Encode state into quantum circuit
        let mut gates = Vec::new();

        // State encoding
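        // Angle encoding: each state component x_i is mapped to an RY(x_i * PI) rotation
        // on its own qubit (components beyond the state vector length default to 0.0).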
        for i in 0..self.state_qubits {
            let state_value = if i < state.len() { state[i] } else { 0.0 };
            gates.push(Box::new(RotationY {
                target: QubitId(i as u32),
                theta: state_value * std::f64::consts::PI,
            }) as Box<dyn GateOp>);
        }

        // Action encoding (simple binary encoding)
        for i in 0..2 {
            // Assuming 2 action qubits
            if (action >> i) & 1 == 1 {
                gates.push(Box::new(PauliX {
                    target: QubitId((self.state_qubits + i) as u32),
                }) as Box<dyn GateOp>);
            }
        }

        // Variational circuit layers
        let mut param_idx = 0;
        for _layer in 0..self.depth {
            // Rotation layer
            for qubit in 0..self.get_total_qubits() {
                if param_idx + 2 < parameters.len() {
                    gates.push(Box::new(RotationX {
                        target: QubitId(qubit as u32),
                        theta: parameters[param_idx],
                    }) as Box<dyn GateOp>);
                    param_idx += 1;

                    gates.push(Box::new(RotationY {
                        target: QubitId(qubit as u32),
                        theta: parameters[param_idx],
                    }) as Box<dyn GateOp>);
                    param_idx += 1;

                    gates.push(Box::new(RotationZ {
                        target: QubitId(qubit as u32),
                        theta: parameters[param_idx],
                    }) as Box<dyn GateOp>);
                    param_idx += 1;
                }
            }

            // Entangling layer
            for qubit in 0..self.get_total_qubits() - 1 {
                if param_idx < parameters.len() {
                    gates.push(Box::new(CRZ {
                        control: QubitId(qubit as u32),
                        target: QubitId((qubit + 1) as u32),
                        theta: parameters[param_idx],
                    }) as Box<dyn GateOp>);
                    param_idx += 1;
                }
            }
        }

        // Simplified evaluation: return a mock Q-value
        // In a real implementation, this would involve quantum simulation
        let q_value = self.simulate_circuit_expectation(&gates)?;

        Ok(q_value)
    }

    /// Simulate circuit and return expectation value
    fn simulate_circuit_expectation(&self, gates: &[Box<dyn GateOp>]) -> QuantRS2Result<f64> {
        // Simplified simulation: compute a hash-based mock expectation
        let mut hash_value = 0u64;

        for gate in gates {
            // Simple hash of gate parameters
            if let Ok(matrix) = gate.matrix() {
                for complex in &matrix {
                    hash_value = hash_value.wrapping_add((complex.re * 1000.0) as u64);
                    hash_value = hash_value.wrapping_add((complex.im * 1000.0) as u64);
                }
            }
        }

        // Convert to expectation value in [-1, 1]
        let expectation = (hash_value % 2000) as f64 / 1000.0 - 1.0;
        Ok(expectation)
    }

    /// Compute parameter gradients using parameter-shift rule
    fn compute_parameter_gradients(
        &self,
        state: &Array1<f64>,
        action: usize,
        target: f64,
        parameters: &Array1<f64>,
    ) -> QuantRS2Result<Array1<f64>> {
        let mut gradients = Array1::zeros(parameters.len());
        let shift = std::f64::consts::PI / 2.0;

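        // Parameter-shift rule: dQ/dtheta_i is estimated as
        //   (Q(theta_i + pi/2) - Q(theta_i - pi/2)) / 2,
        // combined with the chain rule for the squared-error loss,
        //   dL/dtheta_i = 2 * (Q(theta) - target) * dQ/dtheta_i.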
        for i in 0..parameters.len() {
            // Forward shift
            let mut params_plus = parameters.clone();
            params_plus[i] += shift;
            let q_plus = self.evaluate_q_value(state, action, &params_plus)?;

            // Backward shift
            let mut params_minus = parameters.clone();
            params_minus[i] -= shift;
            let q_minus = self.evaluate_q_value(state, action, &params_minus)?;

            // Parameter-shift rule gradient
            let current_q = self.evaluate_q_value(state, action, parameters)?;
            let loss_gradient = 2.0 * (current_q - target); // d/dθ (q - target)²

            gradients[i] = loss_gradient * (q_plus - q_minus) / 2.0;
        }

        Ok(gradients)
    }
}

impl QuantumPolicyCircuit {
    /// Create a new quantum policy circuit
    fn new(state_qubits: usize, action_qubits: usize, depth: usize) -> QuantRS2Result<Self> {
        let total_qubits = state_qubits + action_qubits;

        Ok(Self {
            state_qubits,
            action_qubits,
            depth,
            total_qubits,
        })
    }

    /// Get number of parameters
    fn get_parameter_count(&self) -> usize {
        let total_qubits = self.state_qubits + self.action_qubits;
        let rotations_per_layer = total_qubits * 3;
        let entangling_per_layer = total_qubits;
        self.depth * (rotations_per_layer + entangling_per_layer)
    }

    /// Get best action for a state
    fn get_best_action(
        &self,
        state: &Array1<f64>,
        parameters: &Array1<f64>,
    ) -> QuantRS2Result<usize> {
        let action_probs = self.get_action_probabilities(state, parameters)?;

        // Find action with highest probability
        let mut best_action = 0;
        let mut best_prob = action_probs[0];

        for (action, &prob) in action_probs.iter().enumerate() {
            if prob > best_prob {
                best_prob = prob;
                best_action = action;
            }
        }

        Ok(best_action)
    }

    /// Get action probabilities
    fn get_action_probabilities(
        &self,
        state: &Array1<f64>,
        parameters: &Array1<f64>,
    ) -> QuantRS2Result<Vec<f64>> {
        let num_actions = 1 << self.action_qubits;
        let mut probabilities = vec![0.0; num_actions];

        // Simplified: uniform distribution with slight variations based on state and parameters
        let base_prob = 1.0 / num_actions as f64;

        for action in 0..num_actions {
            // Add state and parameter-dependent variation
            let state_hash = state.iter().sum::<f64>();
            let param_hash = parameters.iter().take(10).sum::<f64>();
            let variation = 0.1 * ((state_hash + param_hash + action as f64).sin());

            probabilities[action] = base_prob + variation;
        }

        // Normalize probabilities
        let sum: f64 = probabilities.iter().sum();
        for prob in &mut probabilities {
            *prob /= sum;
        }

        Ok(probabilities)
    }

    /// Compute policy gradients
    fn compute_policy_gradients(
        &self,
        state: &Array1<f64>,
        q_network: &QuantumValueNetwork,
        parameters: &Array1<f64>,
    ) -> QuantRS2Result<Array1<f64>> {
        let mut gradients = Array1::zeros(parameters.len());
        let shift = std::f64::consts::PI / 2.0;

        for i in 0..parameters.len() {
            // Forward shift
            let mut params_plus = parameters.clone();
            params_plus[i] += shift;
            let loss_plus = self.compute_policy_loss_with_params(state, q_network, &params_plus)?;

            // Backward shift
            let mut params_minus = parameters.clone();
            params_minus[i] -= shift;
            let loss_minus =
                self.compute_policy_loss_with_params(state, q_network, &params_minus)?;

            // Parameter-shift rule
            gradients[i] = (loss_plus - loss_minus) / 2.0;
        }

        Ok(gradients)
    }

    /// Compute policy loss with specific parameters
    fn compute_policy_loss_with_params(
        &self,
        state: &Array1<f64>,
        q_network: &QuantumValueNetwork,
        parameters: &Array1<f64>,
    ) -> QuantRS2Result<f64> {
        let action_probs = self.get_action_probabilities(state, parameters)?;
        let num_actions = action_probs.len();

        let mut expected_q = 0.0;
        for action in 0..num_actions {
            let q_value = q_network.get_q_value(state, action)?;
            expected_q += action_probs[action] * q_value;
        }

        Ok(-expected_q) // Negative to maximize
    }
}

/// Training metrics for quantum RL
#[derive(Debug, Clone, Default)]
pub struct TrainingMetrics {
    /// Q-network loss
    pub q_loss: f64,
    /// Policy network loss
    pub policy_loss: f64,
    /// Current exploration rate
    pub exploration_rate: f64,
    /// Number of training steps
    pub training_steps: usize,
}

/// Q-learning statistics
#[derive(Debug, Clone)]
pub struct QLearningStats {
    /// Number of episodes completed
    pub episodes: usize,
    /// Number of training steps
    pub training_steps: usize,
    /// Current exploration rate
    pub exploration_rate: f64,
    /// Current replay buffer size
    pub replay_buffer_size: usize,
}

/// Quantum Actor-Critic agent
pub struct QuantumActorCritic {
    /// Configuration
    config: QuantumRLConfig,
    /// Actor network (policy)
    actor: QuantumPolicyNetwork,
    /// Critic network (value function)
    critic: QuantumValueNetwork,
    /// Training metrics
    metrics: TrainingMetrics,
}

impl QuantumActorCritic {
    /// Create a new Quantum Actor-Critic agent
    pub fn new(config: QuantumRLConfig) -> QuantRS2Result<Self> {
        let actor = QuantumPolicyNetwork::new(&config)?;
        let critic = QuantumValueNetwork::new(&config)?;

        Ok(Self {
            config,
            actor,
            critic,
            metrics: TrainingMetrics::default(),
        })
    }

    /// Update networks using actor-critic algorithm
    pub fn update(
        &mut self,
        state: &Array1<f64>,
        _action: usize,
        reward: f64,
        next_state: &Array1<f64>,
        done: bool,
    ) -> QuantRS2Result<()> {
        // Compute TD error
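        //   delta = r + gamma * V(s') - V(s),
        // where V(s) is approximated by Q(s, action 0) and V(s') by max_a Q(s', a);
        // delta then scales the actor's policy gradient below as an advantage estimate.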
        let current_value = self.critic.get_q_value(state, 0)?; // Use first action for state value
        let next_value = if done {
            0.0
        } else {
            self.critic.get_max_q_value(next_state)?
        };

        let target_value = reward + self.config.discount_factor * next_value;
        let td_error = target_value - current_value;

        // Update critic
        let critic_gradients = self.critic.compute_gradients(state, 0, target_value)?;
        self.critic
            .update_parameters(&critic_gradients, self.config.learning_rate)?;

        // Update actor using policy gradient scaled by TD error
        let actor_gradients = self.actor.compute_policy_gradients(state, &self.critic)?;
        let scaled_gradients = actor_gradients * td_error; // Scale by advantage
        self.actor
            .update_parameters(&scaled_gradients, self.config.learning_rate)?;

        // Update metrics
        self.metrics.q_loss = td_error.abs();
        self.metrics.policy_loss = -td_error; // Negative because we want to maximize

        Ok(())
    }

    /// Select action using current policy
    pub fn select_action(&self, state: &Array1<f64>) -> QuantRS2Result<usize> {
        self.actor.get_best_action(state)
    }

    /// Get training metrics
    pub fn get_metrics(&self) -> &TrainingMetrics {
        &self.metrics
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_quantum_dqn_creation() {
        let config = QuantumRLConfig::default();
        let agent = QuantumDQN::new(config).unwrap();

        let stats = agent.get_statistics();
        assert_eq!(stats.episodes, 0);
        assert_eq!(stats.training_steps, 0);
    }

    #[test]
    fn test_replay_buffer() {
        let mut buffer = ReplayBuffer::new(10, Some(42));

        let experience = Experience {
            state: Array1::from_vec(vec![1.0, 0.0, -1.0]),
            action: 1,
            reward: 1.0,
            next_state: Array1::from_vec(vec![0.0, 1.0, 0.0]),
            done: false,
        };

        buffer.add(experience);
        assert_eq!(buffer.size(), 1);

        let samples = buffer.sample(1);
        assert_eq!(samples.len(), 1);
    }
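
    // A minimal additional check of the replay buffer's bounded/FIFO behavior
    // (capacity and eviction only; sampled contents are not inspected).
    #[test]
    fn test_replay_buffer_eviction() {
        let mut buffer = ReplayBuffer::new(2, Some(7));

        for i in 0..5 {
            buffer.add(Experience {
                state: Array1::from_vec(vec![i as f64]),
                action: i,
                reward: i as f64,
                next_state: Array1::from_vec(vec![i as f64 + 1.0]),
                done: false,
            });
        }

        // Capacity stays capped at max_size; the oldest entries have been evicted.
        assert_eq!(buffer.size(), 2);
        assert!(buffer.can_sample(2));
        assert!(!buffer.can_sample(3));
    }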

    #[test]
    fn test_quantum_value_circuit() {
        let circuit = QuantumValueCircuit::new(3, 2, 4).unwrap();
        let param_count = circuit.get_parameter_count();
        assert!(param_count > 0);

        let state = Array1::from_vec(vec![0.5, -0.5, 0.0]);
        let parameters = Array1::zeros(param_count);

        let q_value = circuit.evaluate_q_value(&state, 1, &parameters).unwrap();
        assert!(q_value.is_finite());
    }

    #[test]
    fn test_quantum_actor_critic() {
        let config = QuantumRLConfig::default();
        let mut agent = QuantumActorCritic::new(config).unwrap();

        let state = Array1::from_vec(vec![0.5, -0.5]);
        let next_state = Array1::from_vec(vec![0.0, 1.0]);

        let action = agent.select_action(&state).unwrap();
        assert!(action < 4); // 2^2 actions for 2 action qubits

        agent
            .update(&state, action, 1.0, &next_state, false)
            .unwrap();

        let metrics = agent.get_metrics();
        assert!(metrics.q_loss >= 0.0);
    }

    #[test]
    fn test_quantum_rl_config_default() {
        let config = QuantumRLConfig::default();
        assert_eq!(config.state_qubits, 4);
        assert_eq!(config.action_qubits, 2);
        assert!(config.learning_rate > 0.0);
        assert!(config.discount_factor < 1.0);
    }
}