use crate::error::{MLError, Result};
use crate::qnn::QuantumNeuralNetwork;
use quantrs2_circuit::prelude::Circuit;
use scirs2_core::ndarray::Array1;
use scirs2_core::random::prelude::*;

/// An environment that a reinforcement-learning agent can interact with.
pub trait Environment {
    /// Returns the current state as a feature vector.
    fn state(&self) -> Array1<f64>;

    /// Returns the number of discrete actions available to the agent.
    fn num_actions(&self) -> usize;

    /// Applies `action`, returning the next state, the reward, and whether the episode is done.
    fn step(&mut self, action: usize) -> Result<(Array1<f64>, f64, bool)>;

    /// Resets the environment and returns the initial state.
    fn reset(&mut self) -> Array1<f64>;
}

/// An agent that selects actions using a quantum model and learns from experience.
pub trait QuantumAgent {
    /// Chooses an action for the given state.
    fn get_action(&self, state: &Array1<f64>) -> Result<usize>;

    /// Updates the agent from a single transition.
    fn update(
        &mut self,
        state: &Array1<f64>,
        action: usize,
        reward: f64,
        next_state: &Array1<f64>,
        done: bool,
    ) -> Result<()>;

    /// Trains the agent for `episodes` episodes, returning the average episode reward.
    fn train(&mut self, env: &mut dyn Environment, episodes: usize) -> Result<f64>;

    /// Evaluates the agent for `episodes` episodes, returning the average episode reward.
    fn evaluate(&self, env: &mut dyn Environment, episodes: usize) -> Result<f64>;
}

/// The reinforcement-learning algorithm used by the agent.
#[derive(Debug, Clone, Copy)]
pub enum ReinforcementLearningType {
    /// Quantum Q-learning.
    QLearning,

    /// Quantum SARSA.
    SARSA,

    /// Quantum deep Q-network.
    DQN,

    /// Quantum policy gradient.
    PolicyGradient,

    /// QAOA-based control.
    QAOA,
}

/// A quantum reinforcement-learning agent backed by a quantum neural network.
#[derive(Debug, Clone)]
pub struct ReinforcementLearning {
    /// Algorithm used for learning.
    rl_type: ReinforcementLearningType,

    /// Quantum neural network approximating the value function or policy.
    qnn: QuantumNeuralNetwork,

    /// Step size for parameter updates.
    learning_rate: f64,

    /// Discount factor for future rewards.
    discount_factor: f64,

    /// Probability of taking a random action (epsilon-greedy exploration).
    exploration_rate: f64,

    /// Dimension of the state vector.
    state_dim: usize,

    /// Number of discrete actions.
    action_dim: usize,
}

impl ReinforcementLearning {
    /// Creates an agent with default hyperparameters and a small default QNN.
    pub fn new() -> Self {
        // Default QNN architecture: an encoding layer, two variational layers with
        // full entanglement in between, and a computational-basis measurement.
        let layers = vec![
            crate::qnn::QNNLayerType::EncodingLayer { num_features: 4 },
            crate::qnn::QNNLayerType::VariationalLayer { num_params: 16 },
            crate::qnn::QNNLayerType::EntanglementLayer {
                connectivity: "full".to_string(),
            },
            crate::qnn::QNNLayerType::VariationalLayer { num_params: 16 },
            crate::qnn::QNNLayerType::MeasurementLayer {
                measurement_basis: "computational".to_string(),
            },
        ];

        let qnn = QuantumNeuralNetwork::new(
            layers,
            8, // matches the Circuit::<8> used by encode_state
            4, // matches the default state_dim
            2, // matches the default action_dim
        )
        .expect("default QNN configuration should be valid");

        ReinforcementLearning {
            rl_type: ReinforcementLearningType::QLearning,
            qnn,
            learning_rate: 0.01,
            discount_factor: 0.95,
            exploration_rate: 0.1,
            state_dim: 4,
            action_dim: 2,
        }
    }

    /// Sets the reinforcement-learning algorithm.
    pub fn with_algorithm(mut self, rl_type: ReinforcementLearningType) -> Self {
        self.rl_type = rl_type;
        self
    }

    /// Sets the dimension of the state vector.
    pub fn with_state_dimension(mut self, state_dim: usize) -> Self {
        self.state_dim = state_dim;
        self
    }

    /// Sets the number of discrete actions.
    pub fn with_action_dimension(mut self, action_dim: usize) -> Self {
        self.action_dim = action_dim;
        self
    }

    /// Sets the learning rate.
    pub fn with_learning_rate(mut self, learning_rate: f64) -> Self {
        self.learning_rate = learning_rate;
        self
    }

    /// Sets the discount factor.
    pub fn with_discount_factor(mut self, discount_factor: f64) -> Self {
        self.discount_factor = discount_factor;
        self
    }

    /// Sets the exploration rate for epsilon-greedy action selection.
    pub fn with_exploration_rate(mut self, exploration_rate: f64) -> Self {
        self.exploration_rate = exploration_rate;
        self
    }

    /// Encodes a classical state vector into a quantum circuit via RY rotations.
    fn encode_state(&self, state: &Array1<f64>) -> Result<Circuit<8>> {
        let mut circuit = Circuit::<8>::new();

        // Angle encoding: one RY rotation per feature, up to the 8 available qubits.
        for i in 0..state.len().min(8) {
            circuit.ry(i, state[i] * std::f64::consts::PI)?;
        }

        Ok(circuit)
    }

    /// Estimates Q-values for every action in the given state.
    fn get_q_values(&self, state: &Array1<f64>) -> Result<Array1<f64>> {
        // Encode the state; a full implementation would run this circuit through
        // the QNN and derive Q-values from measurement statistics. The values
        // below are random placeholders.
        let _circuit = self.encode_state(state)?;

        let mut q_values = Array1::zeros(self.action_dim);
        for i in 0..self.action_dim {
            q_values[i] = 0.5 + 0.5 * thread_rng().gen::<f64>();
        }

        Ok(q_values)
    }
}
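
// Convenience impl: `Default` delegates to `new()` so the agent can also be
// constructed with `ReinforcementLearning::default()`.
impl Default for ReinforcementLearning {
    fn default() -> Self {
        Self::new()
    }
}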

impl QuantumAgent for ReinforcementLearning {
    fn get_action(&self, state: &Array1<f64>) -> Result<usize> {
        if thread_rng().gen::<f64>() < self.exploration_rate {
            // Explore: pick a uniformly random action.
            Ok(fastrand::usize(0..self.action_dim))
        } else {
            // Exploit: pick the action with the highest estimated Q-value.
            let q_values = self.get_q_values(state)?;
            let mut best_action = 0;
            let mut best_value = q_values[0];

            for i in 1..self.action_dim {
                if q_values[i] > best_value {
                    best_value = q_values[i];
                    best_action = i;
                }
            }

            Ok(best_action)
        }
    }

    fn update(
        &mut self,
        _state: &Array1<f64>,
        _action: usize,
        _reward: f64,
        _next_state: &Array1<f64>,
        _done: bool,
    ) -> Result<()> {
        // Placeholder: a full implementation would update the QNN parameters from
        // this transition according to the configured algorithm (e.g. Q-learning).
        Ok(())
    }

    fn train(&mut self, env: &mut dyn Environment, episodes: usize) -> Result<f64> {
        let mut total_reward = 0.0;

        for _ in 0..episodes {
            let mut state = env.reset();
            let mut episode_reward = 0.0;
            let mut done = false;

            while !done {
                let action = self.get_action(&state)?;
                let (next_state, reward, is_done) = env.step(action)?;

                self.update(&state, action, reward, &next_state, is_done)?;

                state = next_state;
                episode_reward += reward;
                done = is_done;
            }

            total_reward += episode_reward;
        }

        Ok(total_reward / episodes as f64)
    }

    fn evaluate(&self, env: &mut dyn Environment, episodes: usize) -> Result<f64> {
        let mut total_reward = 0.0;

        for _ in 0..episodes {
            let mut state = env.reset();
            let mut episode_reward = 0.0;
            let mut done = false;

            while !done {
                let action = self.get_action(&state)?;
                let (next_state, reward, is_done) = env.step(action)?;

                state = next_state;
                episode_reward += reward;
                done = is_done;
            }

            total_reward += episode_reward;
        }

        Ok(total_reward / episodes as f64)
    }
}

/// A simple grid-world environment for testing reinforcement-learning agents.
pub struct GridWorldEnvironment {
    /// Grid width in cells.
    width: usize,

    /// Grid height in cells.
    height: usize,

    /// Current agent position as `(x, y)`.
    position: (usize, usize),

    /// Goal position as `(x, y)`.
    goal: (usize, usize),

    /// Cells the agent cannot enter.
    obstacles: Vec<(usize, usize)>,
}

impl GridWorldEnvironment {
    /// Creates a grid world of the given size with the agent at `(0, 0)` and the
    /// goal in the opposite corner.
    pub fn new(width: usize, height: usize) -> Self {
        GridWorldEnvironment {
            width,
            height,
            position: (0, 0),
            goal: (width - 1, height - 1),
            obstacles: Vec::new(),
        }
    }

    /// Sets the goal position, clamped to the grid bounds.
    pub fn with_goal(mut self, x: usize, y: usize) -> Self {
        self.goal = (x.min(self.width - 1), y.min(self.height - 1));
        self
    }

    /// Sets the obstacle cells.
    pub fn with_obstacles(mut self, obstacles: Vec<(usize, usize)>) -> Self {
        self.obstacles = obstacles;
        self
    }

    /// Returns `true` if `(x, y)` is an obstacle.
    pub fn is_obstacle(&self, x: usize, y: usize) -> bool {
        self.obstacles.contains(&(x, y))
    }

    /// Returns `true` if `(x, y)` is the goal.
    pub fn is_goal(&self, x: usize, y: usize) -> bool {
        (x, y) == self.goal
    }
}

impl Environment for GridWorldEnvironment {
    fn state(&self) -> Array1<f64> {
        let mut state = Array1::zeros(4);

        // Agent position, normalized to [0, 1).
        state[0] = self.position.0 as f64 / self.width as f64;
        state[1] = self.position.1 as f64 / self.height as f64;

        // Goal position, normalized to [0, 1).
        state[2] = self.goal.0 as f64 / self.width as f64;
        state[3] = self.goal.1 as f64 / self.height as f64;

        state
    }

    fn num_actions(&self) -> usize {
        4 // one action per movement direction (±x, ±y)
    }

    fn step(&mut self, action: usize) -> Result<(Array1<f64>, f64, bool)> {
        let (x, y) = self.position;

        let (new_x, new_y) = match action {
            0 => (x, y.saturating_sub(1)), // decrease y
            1 => (x + 1, y),               // increase x
            2 => (x, y + 1),               // increase y
            3 => (x.saturating_sub(1), y), // decrease x
            _ => {
                return Err(MLError::InvalidParameter(format!(
                    "Invalid action: {}",
                    action
                )))
            }
        };

        // Clamp the move to the grid bounds.
        let new_x = new_x.min(self.width - 1);
        let new_y = new_y.min(self.height - 1);

        // Moving into an obstacle keeps the agent in place and incurs a penalty.
        if self.obstacles.contains(&(new_x, new_y)) {
            let reward = -1.0;
            let done = false;
            return Ok((self.state(), reward, done));
        }

        self.position = (new_x, new_y);

        let reward = if (new_x, new_y) == self.goal {
            10.0 // reached the goal
        } else {
            -0.1 // small step penalty to encourage short paths
        };

        let done = (new_x, new_y) == self.goal;

        Ok((self.state(), reward, done))
    }

    fn reset(&mut self) -> Array1<f64> {
        self.position = (0, 0);
        self.state()
    }
}
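
// A minimal end-to-end sketch of how the pieces fit together: the builder-style
// agent configuration, the `GridWorldEnvironment`, and the `QuantumAgent`
// training/evaluation loop. Grid size, episode counts, and hyperparameters are
// illustrative only.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn grid_world_agent_trains_and_evaluates() {
        let mut env = GridWorldEnvironment::new(4, 4)
            .with_goal(3, 3)
            .with_obstacles(vec![(1, 1), (2, 2)]);

        let mut agent = ReinforcementLearning::new()
            .with_algorithm(ReinforcementLearningType::QLearning)
            .with_state_dimension(4)
            .with_action_dimension(env.num_actions())
            .with_learning_rate(0.05)
            .with_discount_factor(0.9)
            .with_exploration_rate(0.2);

        // Train briefly, then run a short evaluation; both return the average
        // reward per episode.
        let train_reward = agent.train(&mut env, 10).unwrap();
        let eval_reward = agent.evaluate(&mut env, 5).unwrap();

        assert!(train_reward.is_finite());
        assert!(eval_reward.is_finite());
    }
}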