quantrs2_ml/
continuous_rl.rs

//! Quantum Reinforcement Learning with Continuous Actions
//!
//! This module extends quantum reinforcement learning to support continuous action spaces,
//! implementing quantum adaptations of DDPG and SAC (with TD3-style twin Q-networks).
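//!
//! # Example
//!
//! A minimal training sketch. The `Adam` constructor call and the module paths are
//! assumptions about this crate's API, so the block is marked `ignore` rather than
//! compiled as a doc-test:
//!
//! ```ignore
//! use quantrs2_ml::autodiff::optimizers::Adam;
//! use quantrs2_ml::continuous_rl::{ContinuousEnvironment, PendulumEnvironment, QuantumDDPG};
//!
//! let mut env = PendulumEnvironment::new();
//! let mut agent = QuantumDDPG::new(
//!     env.state_dim(),
//!     env.action_dim(),
//!     env.action_bounds(),
//!     4,      // number of qubits
//!     10_000, // replay buffer capacity
//! )?;
//!
//! let mut actor_opt = Adam::new(0.001);
//! let mut critic_opt = Adam::new(0.001);
//! let rewards = agent.train(&mut env, 100, &mut actor_opt, &mut critic_opt)?;
//! ```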

use crate::autodiff::optimizers::Optimizer;
use crate::error::{MLError, Result};
use crate::optimization::OptimizationMethod;
use crate::qnn::{QNNLayerType, QuantumNeuralNetwork};
use ndarray::{Array1, Array2, ArrayView1};
use quantrs2_circuit::builder::{Circuit, Simulator};
use quantrs2_core::gate::{
    single::{RotationX, RotationY, RotationZ},
    GateOp,
};
use quantrs2_sim::statevector::StateVectorSimulator;
use std::collections::{HashMap, VecDeque};
use std::f64::consts::PI;

/// Continuous action environment trait
pub trait ContinuousEnvironment {
    /// Gets the current state
    fn state(&self) -> Array1<f64>;

    /// Gets the action space bounds (min, max) for each dimension
    fn action_bounds(&self) -> Vec<(f64, f64)>;

    /// Takes a continuous action and returns the next state, reward, and done flag
    fn step(&mut self, action: Array1<f64>) -> Result<(Array1<f64>, f64, bool)>;

    /// Resets the environment and returns the initial state
    fn reset(&mut self) -> Array1<f64>;

    /// Gets the state dimension
    fn state_dim(&self) -> usize;

    /// Gets the action dimension
    fn action_dim(&self) -> usize;
}

/// Experience replay buffer for continuous RL
#[derive(Debug, Clone)]
pub struct ReplayBuffer {
    /// Maximum buffer size
    capacity: usize,

    /// Buffer storage
    buffer: VecDeque<Experience>,
}

/// Single experience tuple
#[derive(Debug, Clone)]
pub struct Experience {
    pub state: Array1<f64>,
    pub action: Array1<f64>,
    pub reward: f64,
    pub next_state: Array1<f64>,
    pub done: bool,
}

impl ReplayBuffer {
    /// Create new replay buffer
    pub fn new(capacity: usize) -> Self {
        Self {
            capacity,
            buffer: VecDeque::with_capacity(capacity),
        }
    }

    /// Add experience to buffer, evicting the oldest entry when full
    pub fn push(&mut self, exp: Experience) {
        if self.buffer.len() >= self.capacity {
            self.buffer.pop_front();
        }
        self.buffer.push_back(exp);
    }

    /// Sample a random batch from the buffer (uniformly, with replacement)
    pub fn sample(&self, batch_size: usize) -> Result<Vec<Experience>> {
        if self.buffer.len() < batch_size {
            return Err(MLError::ModelCreationError(
                "Not enough experiences in buffer".to_string(),
            ));
        }

        let mut batch = Vec::new();
        let mut rng = fastrand::Rng::new();

        for _ in 0..batch_size {
            let idx = rng.usize(0..self.buffer.len());
            batch.push(self.buffer[idx].clone());
        }

        Ok(batch)
    }

    /// Get the number of stored experiences
    pub fn len(&self) -> usize {
        self.buffer.len()
    }
}

/// Quantum actor network for continuous actions
pub struct QuantumActor {
    /// Quantum neural network
    qnn: QuantumNeuralNetwork,

    /// Action bounds
    action_bounds: Vec<(f64, f64)>,

    /// State dimension
    state_dim: usize,

    /// Action dimension
    action_dim: usize,
}

impl QuantumActor {
    /// Create new quantum actor
    pub fn new(
        state_dim: usize,
        action_dim: usize,
        action_bounds: Vec<(f64, f64)>,
        num_qubits: usize,
    ) -> Result<Self> {
        let layers = vec![
            QNNLayerType::EncodingLayer {
                num_features: state_dim,
            },
            QNNLayerType::VariationalLayer {
                num_params: num_qubits * 3,
            },
            QNNLayerType::EntanglementLayer {
                connectivity: "circular".to_string(),
            },
            QNNLayerType::VariationalLayer {
                num_params: num_qubits * 3,
            },
            QNNLayerType::MeasurementLayer {
                measurement_basis: "Pauli-Z".to_string(),
            },
        ];

        let qnn = QuantumNeuralNetwork::new(layers, num_qubits, state_dim, action_dim)?;

        Ok(Self {
            qnn,
            action_bounds,
            state_dim,
            action_dim,
        })
    }

    /// Get action from state
    pub fn get_action(&self, state: &Array1<f64>, add_noise: bool) -> Result<Array1<f64>> {
        // Placeholder - would encode `state` into the quantum circuit and generate actions
        let raw_actions = self.extract_continuous_actions_placeholder()?;

        // Apply bounds and noise
        let mut actions = Array1::zeros(self.action_dim);
        for i in 0..self.action_dim {
            let (min_val, max_val) = self.action_bounds[i];

            // Map quantum output from [-1, 1] to the action range
            actions[i] = min_val + (max_val - min_val) * (raw_actions[i] + 1.0) / 2.0;

            // Add exploration noise if requested
            if add_noise {
                let noise = 0.1 * (max_val - min_val) * (2.0 * rand::random::<f64>() - 1.0);
                actions[i] = (actions[i] + noise).clamp(min_val, max_val);
            }
        }

        Ok(actions)
    }

    /// Extract continuous actions from quantum state (placeholder)
    fn extract_continuous_actions_placeholder(&self) -> Result<Array1<f64>> {
        // Placeholder - would measure expectation values
        let mut actions = Array1::zeros(self.action_dim);

        for i in 0..self.action_dim {
            // Simulate measurement of Pauli-Z on different qubits
            actions[i] = 2.0 * rand::random::<f64>() - 1.0; // [-1, 1]
        }

        Ok(actions)
    }
}

/// Quantum critic network for value estimation
pub struct QuantumCritic {
    /// Quantum neural network
    qnn: QuantumNeuralNetwork,

    /// Input dimension (state + action)
    input_dim: usize,
}

impl QuantumCritic {
    /// Create new quantum critic
    pub fn new(state_dim: usize, action_dim: usize, num_qubits: usize) -> Result<Self> {
        let input_dim = state_dim + action_dim;

        let layers = vec![
            QNNLayerType::EncodingLayer {
                num_features: input_dim,
            },
            QNNLayerType::VariationalLayer {
                num_params: num_qubits * 3,
            },
            QNNLayerType::EntanglementLayer {
                connectivity: "full".to_string(),
            },
            QNNLayerType::VariationalLayer {
                num_params: num_qubits * 3,
            },
            QNNLayerType::MeasurementLayer {
                measurement_basis: "computational".to_string(),
            },
        ];

        let qnn = QuantumNeuralNetwork::new(
            layers, num_qubits, input_dim, 1, // Q-value output
        )?;

        Ok(Self { qnn, input_dim })
    }

    /// Estimate Q-value for state-action pair
    pub fn get_q_value(&self, state: &Array1<f64>, action: &Array1<f64>) -> Result<f64> {
        // Concatenate state and action
        let mut input = Array1::zeros(self.input_dim);
        for i in 0..state.len() {
            input[i] = state[i];
        }
        for i in 0..action.len() {
            input[state.len() + i] = action[i];
        }

        // Placeholder - would feed `input` through the quantum circuit to estimate the Q-value
        Ok(0.5 + 0.5 * (2.0 * rand::random::<f64>() - 1.0))
    }
}

/// Quantum Deep Deterministic Policy Gradient (QDDPG)
pub struct QuantumDDPG {
    /// Actor network
    actor: QuantumActor,

    /// Critic network
    critic: QuantumCritic,

    /// Target actor network
    target_actor: QuantumActor,

    /// Target critic network
    target_critic: QuantumCritic,

    /// Replay buffer
    replay_buffer: ReplayBuffer,

    /// Discount factor
    gamma: f64,

    /// Soft update coefficient
    tau: f64,

    /// Batch size
    batch_size: usize,
}

impl QuantumDDPG {
    /// Create new QDDPG agent
    pub fn new(
        state_dim: usize,
        action_dim: usize,
        action_bounds: Vec<(f64, f64)>,
        num_qubits: usize,
        buffer_capacity: usize,
    ) -> Result<Self> {
        let actor = QuantumActor::new(state_dim, action_dim, action_bounds.clone(), num_qubits)?;
        let critic = QuantumCritic::new(state_dim, action_dim, num_qubits)?;

        // Separate target networks (freshly initialized; a full implementation would
        // copy the online network parameters here)
        let target_actor = QuantumActor::new(state_dim, action_dim, action_bounds, num_qubits)?;
        let target_critic = QuantumCritic::new(state_dim, action_dim, num_qubits)?;

        Ok(Self {
            actor,
            critic,
            target_actor,
            target_critic,
            replay_buffer: ReplayBuffer::new(buffer_capacity),
            gamma: 0.99,
            tau: 0.005,
            batch_size: 64,
        })
    }

    /// Get action for state (adds exploration noise when `training` is true)
    pub fn get_action(&self, state: &Array1<f64>, training: bool) -> Result<Array1<f64>> {
        self.actor.get_action(state, training)
    }

    /// Store experience in replay buffer
    pub fn store_experience(&mut self, exp: Experience) {
        self.replay_buffer.push(exp);
    }

    /// Update networks
    pub fn update(
        &mut self,
        actor_optimizer: &mut dyn Optimizer,
        critic_optimizer: &mut dyn Optimizer,
    ) -> Result<()> {
        if self.replay_buffer.len() < self.batch_size {
            return Ok(());
        }

        // Sample batch
        let batch = self.replay_buffer.sample(self.batch_size)?;

        // Update critic
        self.update_critic(&batch, critic_optimizer)?;

        // Update actor
        self.update_actor(&batch, actor_optimizer)?;

        // Soft update target networks
        self.soft_update()?;

        Ok(())
    }

    /// Update critic network
    fn update_critic(&mut self, batch: &[Experience], optimizer: &mut dyn Optimizer) -> Result<()> {
        // Compute TD targets: y = r + gamma * Q'(s', mu'(s')), or y = r for terminal states
        let mut target_q_values = Vec::new();

        for exp in batch {
            let target_action = self.target_actor.get_action(&exp.next_state, false)?;
            let target_q = self
                .target_critic
                .get_q_value(&exp.next_state, &target_action)?;
            let y = exp.reward + if exp.done { 0.0 } else { self.gamma * target_q };
            target_q_values.push(y);
        }

        // Placeholder - would minimize the error between predicted Q-values and
        // `target_q_values`, updating the critic parameters via `optimizer`

        Ok(())
    }

    /// Update actor network
    fn update_actor(&mut self, batch: &[Experience], optimizer: &mut dyn Optimizer) -> Result<()> {
        // Compute policy loss: the negative mean Q-value of the actor's actions
        let mut policy_loss = 0.0;

        for exp in batch {
            let action = self.actor.get_action(&exp.state, false)?;
            let q_value = self.critic.get_q_value(&exp.state, &action)?;
            policy_loss -= q_value; // Maximize Q-value
        }

        policy_loss /= batch.len() as f64;

        // Placeholder - would compute gradients of `policy_loss` with respect to the
        // actor parameters and apply them via `optimizer`

        Ok(())
    }

    /// Soft update target networks
    fn soft_update(&mut self) -> Result<()> {
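        // Polyak averaging, as implemented below:
        //     theta_target <- tau * theta_online + (1 - tau) * theta_target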
        // Update target actor parameters
        for i in 0..self.actor.qnn.parameters.len() {
            self.target_actor.qnn.parameters[i] = self.tau * self.actor.qnn.parameters[i]
                + (1.0 - self.tau) * self.target_actor.qnn.parameters[i];
        }

        // Update target critic parameters
        for i in 0..self.critic.qnn.parameters.len() {
            self.target_critic.qnn.parameters[i] = self.tau * self.critic.qnn.parameters[i]
                + (1.0 - self.tau) * self.target_critic.qnn.parameters[i];
        }

        Ok(())
    }

    /// Train on environment
    pub fn train(
        &mut self,
        env: &mut dyn ContinuousEnvironment,
        episodes: usize,
        actor_optimizer: &mut dyn Optimizer,
        critic_optimizer: &mut dyn Optimizer,
    ) -> Result<Vec<f64>> {
        let mut episode_rewards = Vec::new();

        for episode in 0..episodes {
            let mut state = env.reset();
            let mut episode_reward = 0.0;
            let mut done = false;

            while !done {
                // Get action
                let action = self.get_action(&state, true)?;

                // Step environment
                let (next_state, reward, is_done) = env.step(action.clone())?;

                // Store experience
                self.store_experience(Experience {
                    state: state.clone(),
                    action,
                    reward,
                    next_state: next_state.clone(),
                    done: is_done,
                });

                // Update networks
                self.update(actor_optimizer, critic_optimizer)?;

                state = next_state;
                episode_reward += reward;
                done = is_done;
            }

            episode_rewards.push(episode_reward);

            if episode % 10 == 0 {
                println!("Episode {}: Reward = {:.2}", episode, episode_reward);
            }
        }

        Ok(episode_rewards)
    }
}

/// Quantum Soft Actor-Critic (QSAC)
pub struct QuantumSAC {
    /// Actor network
    actor: QuantumActor,

    /// Two Q-networks for stability
    q1: QuantumCritic,
    q2: QuantumCritic,

    /// Target Q-networks
    target_q1: QuantumCritic,
    target_q2: QuantumCritic,

    /// Temperature parameter for entropy
    alpha: f64,

    /// Replay buffer
    replay_buffer: ReplayBuffer,

    /// Hyperparameters
    gamma: f64,
    tau: f64,
    batch_size: usize,
}

impl QuantumSAC {
    /// Create new QSAC agent
    pub fn new(
        state_dim: usize,
        action_dim: usize,
        action_bounds: Vec<(f64, f64)>,
        num_qubits: usize,
        buffer_capacity: usize,
    ) -> Result<Self> {
        let actor = QuantumActor::new(state_dim, action_dim, action_bounds, num_qubits)?;

        let q1 = QuantumCritic::new(state_dim, action_dim, num_qubits)?;
        let q2 = QuantumCritic::new(state_dim, action_dim, num_qubits)?;

        let target_q1 = QuantumCritic::new(state_dim, action_dim, num_qubits)?;
        let target_q2 = QuantumCritic::new(state_dim, action_dim, num_qubits)?;

        Ok(Self {
            actor,
            q1,
            q2,
            target_q1,
            target_q2,
            alpha: 0.2,
            replay_buffer: ReplayBuffer::new(buffer_capacity),
            gamma: 0.99,
            tau: 0.005,
            batch_size: 64,
        })
    }

    /// Get action from the stochastic policy
    pub fn get_action(&self, state: &Array1<f64>, _training: bool) -> Result<Array1<f64>> {
        // SAC uses a stochastic policy even during evaluation, so `_training` is ignored
        self.actor.get_action(state, true)
    }

    /// Compute log probability of action (for entropy)
    fn log_prob(&self, state: &Array1<f64>, action: &Array1<f64>) -> Result<f64> {
        // Placeholder - would compute the actual log probability under the policy
        Ok(-0.5 * action.mapv(|a| a * a).sum())
    }
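
    // Note: a full QSAC update is not implemented here. It would regress `q1` and
    // `q2` toward the soft Bellman target
    //     y = r + gamma * (min(Q1'(s', a'), Q2'(s', a')) - alpha * log pi(a' | s')),
    // with a' sampled from the current actor, mirroring `QuantumDDPG::update_critic`
    // but with an entropy bonus weighted by `alpha`.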
}

/// Pendulum environment for continuous control
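///
/// The constants below (g = 10, m = l = 1, dt = 0.05, torque bounded to [-2, 2],
/// angular velocity clamped to [-8, 8]) and the quadratic cost mirror the classic
/// Gym pendulum swing-up benchmark.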
pub struct PendulumEnvironment {
    /// Angle (radians)
    theta: f64,

    /// Angular velocity
    theta_dot: f64,

    /// Time step
    dt: f64,

    /// Maximum steps per episode
    max_steps: usize,

    /// Current step
    current_step: usize,
}

impl PendulumEnvironment {
    /// Create new pendulum environment
    pub fn new() -> Self {
        Self {
            theta: 0.0,
            theta_dot: 0.0,
            dt: 0.05,
            max_steps: 200,
            current_step: 0,
        }
    }
}

impl ContinuousEnvironment for PendulumEnvironment {
    fn state(&self) -> Array1<f64> {
        Array1::from_vec(vec![self.theta.cos(), self.theta.sin(), self.theta_dot])
    }

    fn action_bounds(&self) -> Vec<(f64, f64)> {
        vec![(-2.0, 2.0)] // Torque bounds
    }

    fn step(&mut self, action: Array1<f64>) -> Result<(Array1<f64>, f64, bool)> {
        let torque = action[0].clamp(-2.0, 2.0);

        // Physics simulation
        let g = 10.0;
        let m = 1.0;
        let l = 1.0;

        // Update dynamics
        let theta_acc = -3.0 * g / (2.0 * l) * self.theta.sin() + 3.0 * torque / (m * l * l);
        self.theta_dot += theta_acc * self.dt;
        self.theta_dot = self.theta_dot.clamp(-8.0, 8.0);
        self.theta += self.theta_dot * self.dt;

        // Normalize angle to [-pi, pi] (rem_euclid keeps negative angles in range)
        self.theta = (self.theta + PI).rem_euclid(2.0 * PI) - PI;

        // Compute reward (penalize angle, velocity, and control effort)
        let reward = -(self.theta.powi(2) + 0.1 * self.theta_dot.powi(2) + 0.001 * torque.powi(2));

        self.current_step += 1;
        let done = self.current_step >= self.max_steps;

        Ok((self.state(), reward, done))
    }

    fn reset(&mut self) -> Array1<f64> {
        self.theta = PI * (2.0 * rand::random::<f64>() - 1.0);
        self.theta_dot = 2.0 * rand::random::<f64>() - 1.0;
        self.current_step = 0;
        self.state()
    }

    fn state_dim(&self) -> usize {
        3
    }

    fn action_dim(&self) -> usize {
        1
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::autodiff::optimizers::Adam;

    #[test]
    fn test_replay_buffer() {
        let mut buffer = ReplayBuffer::new(100);

        for i in 0..150 {
            let exp = Experience {
                state: Array1::zeros(4),
                action: Array1::zeros(2),
                reward: i as f64,
                next_state: Array1::zeros(4),
                done: false,
            };
            buffer.push(exp);
        }

        assert_eq!(buffer.len(), 100);

        let batch = buffer.sample(10).unwrap();
        assert_eq!(batch.len(), 10);
    }

    #[test]
    fn test_pendulum_environment() {
        let mut env = PendulumEnvironment::new();
        let state = env.reset();
        assert_eq!(state.len(), 3);

        let action = Array1::from_vec(vec![1.0]);
        let (next_state, reward, done) = env.step(action).unwrap();

        assert_eq!(next_state.len(), 3);
        assert!(reward <= 0.0); // Reward should be negative
        assert!(!done); // Not done after one step
    }

    #[test]
    fn test_quantum_actor() {
        let actor = QuantumActor::new(
            3, // state_dim
            1, // action_dim
            vec![(-2.0, 2.0)],
            4, // num_qubits
        )
        .unwrap();

        let state = Array1::from_vec(vec![1.0, 0.0, 0.5]);
        let action = actor.get_action(&state, false).unwrap();

        assert_eq!(action.len(), 1);
        assert!(action[0] >= -2.0 && action[0] <= 2.0);
    }

    #[test]
    fn test_quantum_critic() {
        let critic = QuantumCritic::new(3, 1, 4).unwrap();

        let state = Array1::from_vec(vec![1.0, 0.0, 0.5]);
        let action = Array1::from_vec(vec![1.5]);

        let q_value = critic.get_q_value(&state, &action).unwrap();
        assert!(q_value.is_finite());
    }
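
    // Smoke test sketch for the DDPG agent: construction should succeed and the
    // deterministic action should respect the declared bounds (dimensions mirror
    // test_quantum_actor above).
    #[test]
    fn test_quantum_ddpg_action() {
        let agent = QuantumDDPG::new(
            3, // state_dim
            1, // action_dim
            vec![(-2.0, 2.0)],
            4,    // num_qubits
            1000, // buffer capacity
        )
        .unwrap();

        let state = Array1::from_vec(vec![1.0, 0.0, 0.5]);
        let action = agent.get_action(&state, false).unwrap();

        assert_eq!(action.len(), 1);
        assert!(action[0] >= -2.0 && action[0] <= 2.0);
    }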
}