quantrs2_ml/continuous_rl.rs

//! Quantum Reinforcement Learning with Continuous Actions
//!
//! This module extends quantum reinforcement learning to support continuous action spaces,
//! implementing algorithms like DDPG, TD3, and SAC adapted for quantum circuits.
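//!
//! A minimal usage sketch is shown below (marked `ignore` because the quantum
//! circuit evaluation in this module is still a placeholder; the module path
//! `quantrs2_ml::continuous_rl` is assumed):
//!
//! ```ignore
//! use quantrs2_ml::continuous_rl::{
//!     ContinuousEnvironment, Experience, PendulumEnvironment, QuantumDDPG,
//! };
//!
//! // Pendulum task: 3-dimensional state, 1-dimensional torque in [-2, 2].
//! let mut env = PendulumEnvironment::new();
//! let mut agent = QuantumDDPG::new(
//!     env.state_dim(),
//!     env.action_dim(),
//!     env.action_bounds(),
//!     4,      // number of qubits
//!     10_000, // replay buffer capacity
//! )?;
//!
//! // Collect one transition and store it in the replay buffer.
//! let state = env.reset();
//! let action = agent.get_action(&state, true)?;
//! let (next_state, reward, done) = env.step(action.clone())?;
//! agent.store_experience(Experience { state, action, reward, next_state, done });
//! ```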

use crate::autodiff::optimizers::Optimizer;
use crate::error::{MLError, Result};
use crate::optimization::OptimizationMethod;
use crate::qnn::{QNNLayerType, QuantumNeuralNetwork};
use quantrs2_circuit::builder::{Circuit, Simulator};
use quantrs2_core::gate::{
    single::{RotationX, RotationY, RotationZ},
    GateOp,
};
use quantrs2_sim::statevector::StateVectorSimulator;
use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
use scirs2_core::random::prelude::*;
use std::collections::{HashMap, VecDeque};
use std::f64::consts::PI;

/// Continuous action environment trait
pub trait ContinuousEnvironment {
    /// Gets the current state
    fn state(&self) -> Array1<f64>;

    /// Gets the action space bounds (min, max) for each dimension
    fn action_bounds(&self) -> Vec<(f64, f64)>;

    /// Takes a continuous action and returns the next state, reward, and done flag
    fn step(&mut self, action: Array1<f64>) -> Result<(Array1<f64>, f64, bool)>;

    /// Resets the environment
    fn reset(&mut self) -> Array1<f64>;

    /// Get state dimension
    fn state_dim(&self) -> usize;

    /// Get action dimension
    fn action_dim(&self) -> usize;
}

/// Experience replay buffer for continuous RL
#[derive(Debug, Clone)]
pub struct ReplayBuffer {
    /// Maximum buffer size
    capacity: usize,

    /// Buffer storage
    buffer: VecDeque<Experience>,
}

/// Single experience tuple
#[derive(Debug, Clone)]
pub struct Experience {
    pub state: Array1<f64>,
    pub action: Array1<f64>,
    pub reward: f64,
    pub next_state: Array1<f64>,
    pub done: bool,
}

impl ReplayBuffer {
    /// Create new replay buffer
    pub fn new(capacity: usize) -> Self {
        Self {
            capacity,
            buffer: VecDeque::with_capacity(capacity),
        }
    }

    /// Add experience to buffer
    pub fn push(&mut self, exp: Experience) {
        if self.buffer.len() >= self.capacity {
            self.buffer.pop_front();
        }
        self.buffer.push_back(exp);
    }

    /// Sample batch from buffer
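    ///
    /// Indices are drawn uniformly and independently, so a batch may contain
    /// the same experience more than once (sampling with replacement).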
    pub fn sample(&self, batch_size: usize) -> Result<Vec<Experience>> {
        if self.buffer.len() < batch_size {
            return Err(MLError::ModelCreationError(
                "Not enough experiences in buffer".to_string(),
            ));
        }

        let mut batch = Vec::with_capacity(batch_size);
        let mut rng = thread_rng();

        for _ in 0..batch_size {
            let idx = rng.gen_range(0..self.buffer.len());
            batch.push(self.buffer[idx].clone());
        }

        Ok(batch)
    }

    /// Get buffer size
    pub fn len(&self) -> usize {
        self.buffer.len()
    }
}

/// Quantum actor network for continuous actions
pub struct QuantumActor {
    /// Quantum neural network
    qnn: QuantumNeuralNetwork,

    /// Action bounds
    action_bounds: Vec<(f64, f64)>,

    /// State dimension
    state_dim: usize,

    /// Action dimension
    action_dim: usize,
}

impl QuantumActor {
    /// Create new quantum actor
    pub fn new(
        state_dim: usize,
        action_dim: usize,
        action_bounds: Vec<(f64, f64)>,
        num_qubits: usize,
    ) -> Result<Self> {
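        // Hardware-efficient ansatz: encode the state, apply two variational
        // layers separated by circular entanglement, and read the action
        // components out of Pauli-Z measurements.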
        let layers = vec![
            QNNLayerType::EncodingLayer {
                num_features: state_dim,
            },
            QNNLayerType::VariationalLayer {
                num_params: num_qubits * 3,
            },
            QNNLayerType::EntanglementLayer {
                connectivity: "circular".to_string(),
            },
            QNNLayerType::VariationalLayer {
                num_params: num_qubits * 3,
            },
            QNNLayerType::MeasurementLayer {
                measurement_basis: "Pauli-Z".to_string(),
            },
        ];

        let qnn = QuantumNeuralNetwork::new(layers, num_qubits, state_dim, action_dim)?;

        Ok(Self {
            qnn,
            action_bounds,
            state_dim,
            action_dim,
        })
    }

    /// Get action from state
    pub fn get_action(&self, state: &Array1<f64>, add_noise: bool) -> Result<Array1<f64>> {
        // Placeholder - would use quantum circuit to generate actions
        let raw_actions = self.extract_continuous_actions_placeholder()?;

        // Apply bounds and noise
        let mut actions = Array1::zeros(self.action_dim);
        for i in 0..self.action_dim {
            let (min_val, max_val) = self.action_bounds[i];

            // Map quantum output to action range
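            // raw_actions[i] lies in [-1, 1]; the affine map below sends
            // -1 -> min_val and +1 -> max_val.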
            actions[i] = min_val + (max_val - min_val) * (raw_actions[i] + 1.0) / 2.0;

            // Add exploration noise if requested
            if add_noise {
                let noise = 0.1 * (max_val - min_val) * (2.0 * thread_rng().gen::<f64>() - 1.0);
                actions[i] = (actions[i] + noise).clamp(min_val, max_val);
            }
        }

        Ok(actions)
    }

    /// Extract continuous actions from quantum state (placeholder)
    fn extract_continuous_actions_placeholder(&self) -> Result<Array1<f64>> {
        // Placeholder - would measure expectation values
        let mut actions = Array1::zeros(self.action_dim);

        for i in 0..self.action_dim {
            // Simulate measurement of Pauli-Z on different qubits
            actions[i] = 2.0 * thread_rng().gen::<f64>() - 1.0; // [-1, 1]
        }

        Ok(actions)
    }
}

/// Quantum critic network for value estimation
pub struct QuantumCritic {
    /// Quantum neural network
    qnn: QuantumNeuralNetwork,

    /// Input dimension (state + action)
    input_dim: usize,
}

impl QuantumCritic {
    /// Create new quantum critic
    pub fn new(state_dim: usize, action_dim: usize, num_qubits: usize) -> Result<Self> {
        let input_dim = state_dim + action_dim;

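        // Same ansatz shape as the actor, but the concatenated (state, action)
        // vector is encoded, entanglement is all-to-all, and a single
        // computational-basis readout provides the Q-value.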
        let layers = vec![
            QNNLayerType::EncodingLayer {
                num_features: input_dim,
            },
            QNNLayerType::VariationalLayer {
                num_params: num_qubits * 3,
            },
            QNNLayerType::EntanglementLayer {
                connectivity: "full".to_string(),
            },
            QNNLayerType::VariationalLayer {
                num_params: num_qubits * 3,
            },
            QNNLayerType::MeasurementLayer {
                measurement_basis: "computational".to_string(),
            },
        ];

        let qnn = QuantumNeuralNetwork::new(
            layers, num_qubits, input_dim, 1, // Q-value output
        )?;

        Ok(Self { qnn, input_dim })
    }

    /// Estimate Q-value for state-action pair
    pub fn get_q_value(&self, state: &Array1<f64>, action: &Array1<f64>) -> Result<f64> {
        // Concatenate state and action
        let mut input = Array1::zeros(self.input_dim);
        for i in 0..state.len() {
            input[i] = state[i];
        }
        for i in 0..action.len() {
            input[state.len() + i] = action[i];
        }

        // Placeholder - would use quantum circuit to estimate Q-value
        Ok(0.5 + 0.5 * (2.0 * thread_rng().gen::<f64>() - 1.0))
    }
}

/// Quantum Deep Deterministic Policy Gradient (QDDPG)
pub struct QuantumDDPG {
    /// Actor network
    actor: QuantumActor,

    /// Critic network
    critic: QuantumCritic,

    /// Target actor network
    target_actor: QuantumActor,

    /// Target critic network
    target_critic: QuantumCritic,

    /// Replay buffer
    replay_buffer: ReplayBuffer,

    /// Discount factor
    gamma: f64,

    /// Soft update coefficient
    tau: f64,

    /// Batch size
    batch_size: usize,
}

impl QuantumDDPG {
    /// Create new QDDPG agent
    pub fn new(
        state_dim: usize,
        action_dim: usize,
        action_bounds: Vec<(f64, f64)>,
        num_qubits: usize,
        buffer_capacity: usize,
    ) -> Result<Self> {
        let actor = QuantumActor::new(state_dim, action_dim, action_bounds.clone(), num_qubits)?;
        let critic = QuantumCritic::new(state_dim, action_dim, num_qubits)?;

        // Initialize targets as copies of the online networks so the soft updates track them
        let mut target_actor =
            QuantumActor::new(state_dim, action_dim, action_bounds, num_qubits)?;
        let mut target_critic = QuantumCritic::new(state_dim, action_dim, num_qubits)?;
        target_actor.qnn.parameters = actor.qnn.parameters.clone();
        target_critic.qnn.parameters = critic.qnn.parameters.clone();

        Ok(Self {
            actor,
            critic,
            target_actor,
            target_critic,
            replay_buffer: ReplayBuffer::new(buffer_capacity),
            gamma: 0.99,
            tau: 0.005,
            batch_size: 64,
        })
    }

    /// Get action for state
    pub fn get_action(&self, state: &Array1<f64>, training: bool) -> Result<Array1<f64>> {
        self.actor.get_action(state, training)
    }

    /// Store experience in replay buffer
    pub fn store_experience(&mut self, exp: Experience) {
        self.replay_buffer.push(exp);
    }

    /// Update networks
    pub fn update(
        &mut self,
        actor_optimizer: &mut dyn Optimizer,
        critic_optimizer: &mut dyn Optimizer,
    ) -> Result<()> {
        if self.replay_buffer.len() < self.batch_size {
            return Ok(());
        }

        // Sample batch
        let batch = self.replay_buffer.sample(self.batch_size)?;

        // Update critic
        self.update_critic(&batch, critic_optimizer)?;

        // Update actor
        self.update_actor(&batch, actor_optimizer)?;

        // Soft update target networks
        self.soft_update()?;

        Ok(())
    }

    /// Update critic network
    fn update_critic(&mut self, batch: &[Experience], optimizer: &mut dyn Optimizer) -> Result<()> {
        // Compute target Q-values
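        // Bellman target: y = r + gamma * (1 - done) * Q_target(s', mu_target(s')).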
        let mut target_q_values = Vec::new();

        for exp in batch {
            let target_action = self.target_actor.get_action(&exp.next_state, false)?;
            let target_q = self
                .target_critic
                .get_q_value(&exp.next_state, &target_action)?;
            let y = exp.reward + if exp.done { 0.0 } else { self.gamma * target_q };
            target_q_values.push(y);
        }

        // Placeholder - would compute loss and update parameters

        Ok(())
    }

    /// Update actor network
    fn update_actor(&mut self, batch: &[Experience], optimizer: &mut dyn Optimizer) -> Result<()> {
        // Compute policy gradient
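        // Deterministic policy gradient: maximize the mean Q(s, mu(s)) over the
        // batch, i.e. minimize its negative.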
        let mut policy_loss = 0.0;

        for exp in batch {
            let action = self.actor.get_action(&exp.state, false)?;
            let q_value = self.critic.get_q_value(&exp.state, &action)?;
            policy_loss -= q_value; // Maximize Q-value
        }

        policy_loss /= batch.len() as f64;

        // Placeholder - would compute gradients and update

        Ok(())
    }

    /// Soft update target networks
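    ///
    /// Polyak averaging: theta_target <- tau * theta + (1 - tau) * theta_target.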
    fn soft_update(&mut self) -> Result<()> {
        // Update target actor parameters
        for i in 0..self.actor.qnn.parameters.len() {
            self.target_actor.qnn.parameters[i] = self.tau * self.actor.qnn.parameters[i]
                + (1.0 - self.tau) * self.target_actor.qnn.parameters[i];
        }

        // Update target critic parameters
        for i in 0..self.critic.qnn.parameters.len() {
            self.target_critic.qnn.parameters[i] = self.tau * self.critic.qnn.parameters[i]
                + (1.0 - self.tau) * self.target_critic.qnn.parameters[i];
        }

        Ok(())
    }

    /// Train on environment
    pub fn train(
        &mut self,
        env: &mut dyn ContinuousEnvironment,
        episodes: usize,
        actor_optimizer: &mut dyn Optimizer,
        critic_optimizer: &mut dyn Optimizer,
    ) -> Result<Vec<f64>> {
        let mut episode_rewards = Vec::new();

        for episode in 0..episodes {
            let mut state = env.reset();
            let mut episode_reward = 0.0;
            let mut done = false;

            while !done {
                // Get action
                let action = self.get_action(&state, true)?;

                // Step environment
                let (next_state, reward, is_done) = env.step(action.clone())?;

                // Store experience
                self.store_experience(Experience {
                    state: state.clone(),
                    action,
                    reward,
                    next_state: next_state.clone(),
                    done: is_done,
                });

                // Update networks
                self.update(actor_optimizer, critic_optimizer)?;

                state = next_state;
                episode_reward += reward;
                done = is_done;
            }

            episode_rewards.push(episode_reward);

            if episode % 10 == 0 {
                println!("Episode {}: Reward = {:.2}", episode, episode_reward);
            }
        }

        Ok(episode_rewards)
    }
}

/// Quantum Soft Actor-Critic (QSAC)
pub struct QuantumSAC {
    /// Actor network
    actor: QuantumActor,

    /// Two Q-networks for stability
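    /// (as in TD3/SAC, the minimum of the two estimates is intended to serve
    /// as the target value to reduce overestimation bias)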
    q1: QuantumCritic,
    q2: QuantumCritic,

    /// Target Q-networks
    target_q1: QuantumCritic,
    target_q2: QuantumCritic,

    /// Temperature parameter for entropy
    alpha: f64,

    /// Replay buffer
    replay_buffer: ReplayBuffer,

    /// Hyperparameters
    gamma: f64,
    tau: f64,
    batch_size: usize,
}

impl QuantumSAC {
    /// Create new QSAC agent
    pub fn new(
        state_dim: usize,
        action_dim: usize,
        action_bounds: Vec<(f64, f64)>,
        num_qubits: usize,
        buffer_capacity: usize,
    ) -> Result<Self> {
        let actor = QuantumActor::new(state_dim, action_dim, action_bounds, num_qubits)?;

        let q1 = QuantumCritic::new(state_dim, action_dim, num_qubits)?;
        let q2 = QuantumCritic::new(state_dim, action_dim, num_qubits)?;

        let target_q1 = QuantumCritic::new(state_dim, action_dim, num_qubits)?;
        let target_q2 = QuantumCritic::new(state_dim, action_dim, num_qubits)?;

        Ok(Self {
            actor,
            q1,
            q2,
            target_q1,
            target_q2,
            alpha: 0.2,
            replay_buffer: ReplayBuffer::new(buffer_capacity),
            gamma: 0.99,
            tau: 0.005,
            batch_size: 64,
        })
    }

    /// Get action with entropy regularization
    pub fn get_action(&self, state: &Array1<f64>, training: bool) -> Result<Array1<f64>> {
        // SAC uses stochastic policy even during evaluation
        self.actor.get_action(state, true)
    }

    /// Compute log probability of action (for entropy)
    fn log_prob(&self, state: &Array1<f64>, action: &Array1<f64>) -> Result<f64> {
        // Placeholder - would compute actual log probability
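        // The expression below is the log-density of a standard normal in the
        // action, up to an additive constant.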
        Ok(-0.5 * action.mapv(|a| a * a).sum())
    }
}

/// Pendulum environment for continuous control
pub struct PendulumEnvironment {
    /// Angle (radians)
    theta: f64,

    /// Angular velocity
    theta_dot: f64,

    /// Time step
    dt: f64,

    /// Maximum steps per episode
    max_steps: usize,

    /// Current step
    current_step: usize,
}

impl PendulumEnvironment {
    /// Create new pendulum environment
    pub fn new() -> Self {
        Self {
            theta: 0.0,
            theta_dot: 0.0,
            dt: 0.05,
            max_steps: 200,
            current_step: 0,
        }
    }
}

impl ContinuousEnvironment for PendulumEnvironment {
    fn state(&self) -> Array1<f64> {
        Array1::from_vec(vec![self.theta.cos(), self.theta.sin(), self.theta_dot])
    }

    fn action_bounds(&self) -> Vec<(f64, f64)> {
        vec![(-2.0, 2.0)] // Torque bounds
    }

    fn step(&mut self, action: Array1<f64>) -> Result<(Array1<f64>, f64, bool)> {
        let torque = action[0].clamp(-2.0, 2.0);

        // Physics simulation
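        // Simple pendulum dynamics: theta_ddot = -(3g / 2l) * sin(theta) + 3u / (m * l^2),
        // integrated with explicit Euler steps of size dt.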
        let g = 10.0;
        let m = 1.0;
        let l = 1.0;

        // Update dynamics
        let theta_acc = -3.0 * g / (2.0 * l) * self.theta.sin() + 3.0 * torque / (m * l * l);
        self.theta_dot += theta_acc * self.dt;
        self.theta_dot = self.theta_dot.clamp(-8.0, 8.0);
        self.theta += self.theta_dot * self.dt;

        // Normalize angle to [-pi, pi]; rem_euclid keeps the result in range even
        // when theta drifts below -pi, which `%` would not.
        self.theta = (self.theta + PI).rem_euclid(2.0 * PI) - PI;

        // Compute reward (penalize angle and velocity)
        let reward = -(self.theta.powi(2) + 0.1 * self.theta_dot.powi(2) + 0.001 * torque.powi(2));

        self.current_step += 1;
        let done = self.current_step >= self.max_steps;

        Ok((self.state(), reward, done))
    }

    fn reset(&mut self) -> Array1<f64> {
        self.theta = PI * (2.0 * thread_rng().gen::<f64>() - 1.0);
        self.theta_dot = 2.0 * thread_rng().gen::<f64>() - 1.0;
        self.current_step = 0;
        self.state()
    }

    fn state_dim(&self) -> usize {
        3
    }

    fn action_dim(&self) -> usize {
        1
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::autodiff::optimizers::Adam;

    #[test]
    fn test_replay_buffer() {
        let mut buffer = ReplayBuffer::new(100);

        for i in 0..150 {
            let exp = Experience {
                state: Array1::zeros(4),
                action: Array1::zeros(2),
                reward: i as f64,
                next_state: Array1::zeros(4),
                done: false,
            };
            buffer.push(exp);
        }

        assert_eq!(buffer.len(), 100);

        let batch = buffer.sample(10).unwrap();
        assert_eq!(batch.len(), 10);
    }

    #[test]
    fn test_pendulum_environment() {
        let mut env = PendulumEnvironment::new();
        let state = env.reset();
        assert_eq!(state.len(), 3);

        let action = Array1::from_vec(vec![1.0]);
        let (next_state, reward, done) = env.step(action).unwrap();

        assert_eq!(next_state.len(), 3);
        assert!(reward <= 0.0); // Reward should be negative
        assert!(!done); // Not done after one step
    }

    #[test]
    fn test_quantum_actor() {
        let actor = QuantumActor::new(
            3, // state_dim
            1, // action_dim
            vec![(-2.0, 2.0)],
            4, // num_qubits
        )
        .unwrap();

        let state = Array1::from_vec(vec![1.0, 0.0, 0.5]);
        let action = actor.get_action(&state, false).unwrap();

        assert_eq!(action.len(), 1);
        assert!(action[0] >= -2.0 && action[0] <= 2.0);
    }

    #[test]
    fn test_quantum_critic() {
        let critic = QuantumCritic::new(3, 1, 4).unwrap();

        let state = Array1::from_vec(vec![1.0, 0.0, 0.5]);
        let action = Array1::from_vec(vec![1.5]);

        let q_value = critic.get_q_value(&state, &action).unwrap();
        assert!(q_value.is_finite());
    }
}