use crate::error::{MLError, Result};
use crate::qnn::QuantumNeuralNetwork;
use quantrs2_circuit::prelude::Circuit;
use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::random::prelude::*;
use std::collections::HashMap;

/// An environment a reinforcement learning agent can interact with.
pub trait Environment {
    /// Returns the current state as a feature vector.
    fn state(&self) -> Array1<f64>;

    /// Returns the number of discrete actions available.
    fn num_actions(&self) -> usize;

    /// Applies `action` and returns `(next_state, reward, done)`.
    fn step(&mut self, action: usize) -> Result<(Array1<f64>, f64, bool)>;

    /// Resets the environment and returns the initial state.
    fn reset(&mut self) -> Array1<f64>;
}

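/// A reinforcement learning agent backed by a quantum model.
///
/// The snippet below is an illustrative sketch of the interaction loop a
/// `QuantumAgent` is expected to support; it mirrors what `train` does
/// internally and is marked `ignore` because `env` and `agent` are
/// placeholders, not values defined in this crate.
///
/// ```ignore
/// let mut state = env.reset();
/// let mut done = false;
/// while !done {
///     let action = agent.get_action(&state)?;
///     let (next_state, reward, is_done) = env.step(action)?;
///     agent.update(&state, action, reward, &next_state, is_done)?;
///     state = next_state;
///     done = is_done;
/// }
/// ```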
pub trait QuantumAgent {
    /// Selects an action for the given state.
    fn get_action(&self, state: &Array1<f64>) -> Result<usize>;

    /// Updates the agent from a single `(state, action, reward, next_state, done)` transition.
    fn update(
        &mut self,
        state: &Array1<f64>,
        action: usize,
        reward: f64,
        next_state: &Array1<f64>,
        done: bool,
    ) -> Result<()>;

    /// Trains against `env` for the given number of episodes and returns the
    /// average episode reward.
    fn train(&mut self, env: &mut dyn Environment, episodes: usize) -> Result<f64>;

    /// Runs the current policy against `env` for the given number of episodes
    /// and returns the average episode reward, without updating the agent.
    fn evaluate(&self, env: &mut dyn Environment, episodes: usize) -> Result<f64>;
}

/// Supported quantum reinforcement learning algorithms.
#[derive(Debug, Clone, Copy)]
pub enum ReinforcementLearningType {
    /// Quantum Q-learning.
    QLearning,

    /// Quantum SARSA.
    SARSA,

    /// Quantum deep Q-network.
    DQN,

    /// Quantum policy gradient.
    PolicyGradient,

    /// QAOA-based reinforcement learning.
    QAOA,
}

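/// Quantum reinforcement learning agent built around a small quantum neural
/// network and configured through the builder-style methods below.
///
/// The example is an illustrative sketch of that configuration (marked
/// `ignore` so it is not compiled as a doctest; imports and error handling
/// are omitted).
///
/// ```ignore
/// let agent = ReinforcementLearning::new()?
///     .with_algorithm(ReinforcementLearningType::QLearning)
///     .with_state_dimension(4)
///     .with_action_dimension(2)
///     .with_learning_rate(0.01)
///     .with_discount_factor(0.95)
///     .with_exploration_rate(0.1);
/// ```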
#[derive(Debug, Clone)]
pub struct ReinforcementLearning {
    /// Selected reinforcement learning algorithm.
    rl_type: ReinforcementLearningType,

    /// Quantum neural network used as the function approximator.
    qnn: QuantumNeuralNetwork,

    /// Learning rate for parameter updates.
    learning_rate: f64,

    /// Discount factor applied to future rewards.
    discount_factor: f64,

    /// Probability of taking a random (exploratory) action.
    exploration_rate: f64,

    /// Dimensionality of the state vector.
    state_dim: usize,

    /// Number of discrete actions.
    action_dim: usize,
}

impl ReinforcementLearning {
    /// Creates an agent with default settings: Q-learning, a 4-dimensional
    /// state, 2 actions, and an 8-qubit variational network.
    pub fn new() -> Result<Self> {
        let layers = vec![
            crate::qnn::QNNLayerType::EncodingLayer { num_features: 4 },
            crate::qnn::QNNLayerType::VariationalLayer { num_params: 16 },
            crate::qnn::QNNLayerType::EntanglementLayer {
                connectivity: "full".to_string(),
            },
            crate::qnn::QNNLayerType::VariationalLayer { num_params: 16 },
            crate::qnn::QNNLayerType::MeasurementLayer {
                measurement_basis: "computational".to_string(),
            },
        ];

        let qnn = QuantumNeuralNetwork::new(layers, 8, 4, 2)?;

        Ok(ReinforcementLearning {
            rl_type: ReinforcementLearningType::QLearning,
            qnn,
            learning_rate: 0.01,
            discount_factor: 0.95,
            exploration_rate: 0.1,
            state_dim: 4,
            action_dim: 2,
        })
    }

    /// Sets the reinforcement learning algorithm.
    pub fn with_algorithm(mut self, rl_type: ReinforcementLearningType) -> Self {
        self.rl_type = rl_type;
        self
    }

    /// Sets the dimensionality of the state vector.
    pub fn with_state_dimension(mut self, state_dim: usize) -> Self {
        self.state_dim = state_dim;
        self
    }

    /// Sets the number of discrete actions.
    pub fn with_action_dimension(mut self, action_dim: usize) -> Self {
        self.action_dim = action_dim;
        self
    }

    /// Sets the learning rate.
    pub fn with_learning_rate(mut self, learning_rate: f64) -> Self {
        self.learning_rate = learning_rate;
        self
    }

    /// Sets the discount factor.
    pub fn with_discount_factor(mut self, discount_factor: f64) -> Self {
        self.discount_factor = discount_factor;
        self
    }

    /// Sets the exploration (epsilon) rate for epsilon-greedy action selection.
    pub fn with_exploration_rate(mut self, exploration_rate: f64) -> Self {
        self.exploration_rate = exploration_rate;
        self
    }

    /// Encodes a classical state vector into an 8-qubit circuit using RY
    /// angle encoding (one rotation per feature, up to 8 features).
    fn encode_state(&self, state: &Array1<f64>) -> Result<Circuit<8>> {
        let mut circuit = Circuit::<8>::new();

        for i in 0..state.len().min(8) {
            circuit.ry(i, state[i] * std::f64::consts::PI)?;
        }

        Ok(circuit)
    }

    /// Estimates Q-values for each action in the given state.
    ///
    /// Placeholder implementation: the state is not yet fed through the
    /// quantum network (see `encode_state`); random values are returned so
    /// that the surrounding control flow can be exercised.
    fn get_q_values(&self, _state: &Array1<f64>) -> Result<Array1<f64>> {
        let mut q_values = Array1::zeros(self.action_dim);

        for i in 0..self.action_dim {
            q_values[i] = 0.5 + 0.5 * thread_rng().gen::<f64>();
        }

        Ok(q_values)
    }
}

impl QuantumAgent for ReinforcementLearning {
    fn get_action(&self, state: &Array1<f64>) -> Result<usize> {
        // Epsilon-greedy: explore with probability `exploration_rate`,
        // otherwise take the action with the highest estimated Q-value.
        if thread_rng().gen::<f64>() < self.exploration_rate {
            Ok(fastrand::usize(0..self.action_dim))
        } else {
            let q_values = self.get_q_values(state)?;
            let mut best_action = 0;
            let mut best_value = q_values[0];

            for i in 1..self.action_dim {
                if q_values[i] > best_value {
                    best_value = q_values[i];
                    best_action = i;
                }
            }

            Ok(best_action)
        }
    }

    fn update(
        &mut self,
        _state: &Array1<f64>,
        _action: usize,
        _reward: f64,
        _next_state: &Array1<f64>,
        _done: bool,
    ) -> Result<()> {
        // Placeholder: a full implementation would update the QNN parameters
        // according to the selected algorithm (Q-learning, SARSA, DQN, ...).
        Ok(())
    }

    fn train(&mut self, env: &mut dyn Environment, episodes: usize) -> Result<f64> {
        let mut total_reward = 0.0;

        for _ in 0..episodes {
            let mut state = env.reset();
            let mut episode_reward = 0.0;
            let mut done = false;

            // Roll out one episode, updating the agent after every transition.
            while !done {
                let action = self.get_action(&state)?;
                let (next_state, reward, is_done) = env.step(action)?;

                self.update(&state, action, reward, &next_state, is_done)?;

                state = next_state;
                episode_reward += reward;
                done = is_done;
            }

            total_reward += episode_reward;
        }

        // Average episode reward over the training run.
        Ok(total_reward / episodes as f64)
    }

    fn evaluate(&self, env: &mut dyn Environment, episodes: usize) -> Result<f64> {
        let mut total_reward = 0.0;

        for _ in 0..episodes {
            let mut state = env.reset();
            let mut episode_reward = 0.0;
            let mut done = false;

            // Same rollout as `train`, but without updating the agent.
            while !done {
                let action = self.get_action(&state)?;
                let (next_state, reward, is_done) = env.step(action)?;

                state = next_state;
                episode_reward += reward;
                done = is_done;
            }

            total_reward += episode_reward;
        }

        Ok(total_reward / episodes as f64)
    }
}

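/// A simple grid world: the agent starts at `(0, 0)` and must reach a goal
/// cell while avoiding obstacles.
///
/// Illustrative configuration (an `ignore`d sketch, not a doctest):
///
/// ```ignore
/// let env = GridWorldEnvironment::new(5, 5)
///     .with_goal(4, 4)
///     .with_obstacles(vec![(1, 1), (2, 3)]);
/// ```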
pub struct GridWorldEnvironment {
    /// Grid width in cells.
    width: usize,

    /// Grid height in cells.
    height: usize,

    /// Current agent position as `(x, y)`.
    position: (usize, usize),

    /// Goal position as `(x, y)`.
    goal: (usize, usize),

    /// Cells the agent is not allowed to enter.
    obstacles: Vec<(usize, usize)>,
}

impl GridWorldEnvironment {
    /// Creates a grid of the given size with the agent at `(0, 0)` and the
    /// goal in the opposite corner.
    pub fn new(width: usize, height: usize) -> Self {
        GridWorldEnvironment {
            width,
            height,
            position: (0, 0),
            goal: (width - 1, height - 1),
            obstacles: Vec::new(),
        }
    }

    /// Sets the goal cell, clamped to the grid bounds.
    pub fn with_goal(mut self, x: usize, y: usize) -> Self {
        self.goal = (x.min(self.width - 1), y.min(self.height - 1));
        self
    }

    /// Sets the obstacle cells.
    pub fn with_obstacles(mut self, obstacles: Vec<(usize, usize)>) -> Self {
        self.obstacles = obstacles;
        self
    }

    /// Returns `true` if `(x, y)` is an obstacle.
    pub fn is_obstacle(&self, x: usize, y: usize) -> bool {
        self.obstacles.contains(&(x, y))
    }

    /// Returns `true` if `(x, y)` is the goal cell.
    pub fn is_goal(&self, x: usize, y: usize) -> bool {
        (x, y) == self.goal
    }
}

impl Environment for GridWorldEnvironment {
    fn state(&self) -> Array1<f64> {
        let mut state = Array1::zeros(4);

        // Current position, normalized by the grid size.
        state[0] = self.position.0 as f64 / self.width as f64;
        state[1] = self.position.1 as f64 / self.height as f64;

        // Goal position, normalized by the grid size.
        state[2] = self.goal.0 as f64 / self.width as f64;
        state[3] = self.goal.1 as f64 / self.height as f64;

        state
    }

    fn num_actions(&self) -> usize {
        // Four moves: decrease y, increase x, increase y, decrease x.
        4
    }

    fn step(&mut self, action: usize) -> Result<(Array1<f64>, f64, bool)> {
        let (x, y) = self.position;
        let (new_x, new_y) = match action {
            0 => (x, y.saturating_sub(1)), // y - 1
            1 => (x + 1, y),               // x + 1
            2 => (x, y + 1),               // y + 1
            3 => (x.saturating_sub(1), y), // x - 1
            _ => {
                return Err(MLError::InvalidParameter(format!(
                    "Invalid action: {}",
                    action
                )))
            }
        };

        // Clamp the move to the grid bounds.
        let new_x = new_x.min(self.width - 1);
        let new_y = new_y.min(self.height - 1);

        if self.obstacles.contains(&(new_x, new_y)) {
            // Hitting an obstacle is penalized and the agent stays in place.
            let reward = -1.0;
            let done = false;
            return Ok((self.state(), reward, done));
        }

        self.position = (new_x, new_y);

        let reward = if (new_x, new_y) == self.goal {
            10.0 // goal reward
        } else {
            -0.1 // small per-step penalty
        };

        let done = (new_x, new_y) == self.goal;

        Ok((self.state(), reward, done))
    }

    fn reset(&mut self) -> Array1<f64> {
        self.position = (0, 0);
        self.state()
    }
}
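
// Illustrative smoke tests (a minimal sketch, not an exhaustive test suite)
// showing how the agent and the grid world defined above fit together. They
// only check coarse properties, since `get_q_values` is currently a random
// placeholder; grid sizes are chosen so episodes terminate.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn grid_world_step_reaches_goal() {
        // On a 2x1 grid the goal defaults to (1, 0); moving right (action 1)
        // from the start reaches it immediately.
        let mut env = GridWorldEnvironment::new(2, 1);
        let _ = env.reset();
        let (_, reward, done) = env.step(1).unwrap();
        assert!(done);
        assert!((reward - 10.0).abs() < 1e-9);
    }

    #[test]
    fn agent_trains_on_small_grid_world() {
        // The grid world exposes four actions, so the agent's action
        // dimension is widened to match; with a nonzero exploration rate the
        // random-walk episodes terminate with probability one.
        let mut env = GridWorldEnvironment::new(2, 2);
        let mut agent = ReinforcementLearning::new()
            .unwrap()
            .with_action_dimension(4)
            .with_exploration_rate(0.5);

        let avg_reward = agent.train(&mut env, 3).unwrap();
        assert!(avg_reward.is_finite());
    }
}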