//! Quantum reinforcement learning: environment and agent traits, a
//! QNN-backed agent with epsilon-greedy action selection, and a simple
//! grid-world environment for experimentation.

use crate::error::{MLError, Result};
use crate::qnn::QuantumNeuralNetwork;
use ndarray::Array1;
use quantrs2_circuit::prelude::Circuit;

/// An environment a reinforcement learning agent can interact with.
pub trait Environment {
    /// Returns the current state as a feature vector.
    fn state(&self) -> Array1<f64>;

    /// Returns the number of discrete actions available.
    fn num_actions(&self) -> usize;

    /// Takes an action and returns `(next_state, reward, done)`.
    fn step(&mut self, action: usize) -> Result<(Array1<f64>, f64, bool)>;

    /// Resets the environment and returns the initial state.
    fn reset(&mut self) -> Array1<f64>;
}

/// An agent that learns a policy using quantum resources.
pub trait QuantumAgent {
    /// Selects an action for the given state.
    fn get_action(&self, state: &Array1<f64>) -> Result<usize>;

    /// Updates the agent from a single transition.
    fn update(
        &mut self,
        state: &Array1<f64>,
        action: usize,
        reward: f64,
        next_state: &Array1<f64>,
        done: bool,
    ) -> Result<()>;

    /// Trains for the given number of episodes and returns the mean episode reward.
    fn train(&mut self, env: &mut dyn Environment, episodes: usize) -> Result<f64>;

    /// Evaluates the current policy and returns the mean episode reward.
    fn evaluate(&self, env: &mut dyn Environment, episodes: usize) -> Result<f64>;
}

/// Supported reinforcement learning algorithms.
#[derive(Debug, Clone, Copy)]
pub enum ReinforcementLearningType {
    /// Q-learning.
    QLearning,

    /// SARSA (on-policy temporal-difference learning).
    SARSA,

    /// Deep Q-network style learning.
    DQN,

    /// Policy gradient methods.
    PolicyGradient,

    /// QAOA-based approach.
    QAOA,
}

/// A reinforcement learning agent backed by a quantum neural network.
#[derive(Debug, Clone)]
pub struct ReinforcementLearning {
    /// Which reinforcement learning algorithm to use.
    rl_type: ReinforcementLearningType,

    /// Quantum neural network used as the function approximator.
    qnn: QuantumNeuralNetwork,

    /// Learning rate for parameter updates.
    learning_rate: f64,

    /// Discount factor (gamma) for future rewards.
    discount_factor: f64,

    /// Exploration rate (epsilon) for epsilon-greedy action selection.
    exploration_rate: f64,

    /// Dimensionality of the state vector.
    state_dim: usize,

    /// Number of discrete actions.
    action_dim: usize,
}

impl ReinforcementLearning {
    /// Creates an agent with a default QNN architecture, Q-learning, a
    /// 4-dimensional state space, and 2 actions.
    ///
    /// Panics if the default QNN cannot be constructed.
    pub fn new() -> Self {
        let layers = vec![
            crate::qnn::QNNLayerType::EncodingLayer { num_features: 4 },
            crate::qnn::QNNLayerType::VariationalLayer { num_params: 16 },
            crate::qnn::QNNLayerType::EntanglementLayer {
                connectivity: "full".to_string(),
            },
            crate::qnn::QNNLayerType::VariationalLayer { num_params: 16 },
            crate::qnn::QNNLayerType::MeasurementLayer {
                measurement_basis: "computational".to_string(),
            },
        ];

        let qnn = QuantumNeuralNetwork::new(layers, 8, 4, 2).unwrap();

        ReinforcementLearning {
            rl_type: ReinforcementLearningType::QLearning,
            qnn,
            learning_rate: 0.01,
            discount_factor: 0.95,
            exploration_rate: 0.1,
            state_dim: 4,
            action_dim: 2,
        }
    }

    /// Sets the reinforcement learning algorithm.
    pub fn with_algorithm(mut self, rl_type: ReinforcementLearningType) -> Self {
        self.rl_type = rl_type;
        self
    }

    /// Sets the state dimension.
    pub fn with_state_dimension(mut self, state_dim: usize) -> Self {
        self.state_dim = state_dim;
        self
    }

    /// Sets the action dimension.
    pub fn with_action_dimension(mut self, action_dim: usize) -> Self {
        self.action_dim = action_dim;
        self
    }

    /// Sets the learning rate.
    pub fn with_learning_rate(mut self, learning_rate: f64) -> Self {
        self.learning_rate = learning_rate;
        self
    }

    /// Sets the discount factor.
    pub fn with_discount_factor(mut self, discount_factor: f64) -> Self {
        self.discount_factor = discount_factor;
        self
    }

    /// Sets the exploration rate for epsilon-greedy action selection.
    pub fn with_exploration_rate(mut self, exploration_rate: f64) -> Self {
        self.exploration_rate = exploration_rate;
        self
    }

    /// Encodes a classical state vector into a quantum circuit via angle
    /// encoding (one RY rotation per feature, up to 8 qubits).
    fn encode_state(&self, state: &Array1<f64>) -> Result<Circuit<8>> {
        let mut circuit = Circuit::<8>::new();

        for i in 0..state.len().min(8) {
            circuit.ry(i, state[i] * std::f64::consts::PI)?;
        }

        Ok(circuit)
    }

    /// Estimates Q-values for the given state.
    ///
    /// Note: this is currently a placeholder that returns random values in
    /// [0.5, 1.0); a full implementation would run the circuit produced by
    /// `encode_state` through the QNN.
    fn get_q_values(&self, _state: &Array1<f64>) -> Result<Array1<f64>> {
        let mut q_values = Array1::zeros(self.action_dim);

        for i in 0..self.action_dim {
            q_values[i] = 0.5 + 0.5 * rand::random::<f64>();
        }

        Ok(q_values)
    }
}

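// A conventional `Default` implementation is sketched here for convenience,
// delegating to `new()`; this assumes no other `Default` impl exists for this
// type elsewhere in the crate.
impl Default for ReinforcementLearning {
    fn default() -> Self {
        Self::new()
    }
}
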
impl QuantumAgent for ReinforcementLearning {
    fn get_action(&self, state: &Array1<f64>) -> Result<usize> {
        if rand::random::<f64>() < self.exploration_rate {
            // Explore: pick a uniformly random action.
            Ok(fastrand::usize(0..self.action_dim))
        } else {
            // Exploit: pick the action with the highest estimated Q-value.
            let q_values = self.get_q_values(state)?;
            let mut best_action = 0;
            let mut best_value = q_values[0];

            for i in 1..self.action_dim {
                if q_values[i] > best_value {
                    best_value = q_values[i];
                    best_action = i;
                }
            }

            Ok(best_action)
        }
    }

    fn update(
        &mut self,
        _state: &Array1<f64>,
        _action: usize,
        _reward: f64,
        _next_state: &Array1<f64>,
        _done: bool,
    ) -> Result<()> {
        // Placeholder: parameter updates for the selected algorithm are not
        // implemented yet, so transitions are currently ignored.
        Ok(())
    }

    fn train(&mut self, env: &mut dyn Environment, episodes: usize) -> Result<f64> {
        let mut total_reward = 0.0;

        for _ in 0..episodes {
            let mut state = env.reset();
            let mut episode_reward = 0.0;
            let mut done = false;

            while !done {
                let action = self.get_action(&state)?;
                let (next_state, reward, is_done) = env.step(action)?;

                self.update(&state, action, reward, &next_state, is_done)?;

                state = next_state;
                episode_reward += reward;
                done = is_done;
            }

            total_reward += episode_reward;
        }

        Ok(total_reward / episodes as f64)
    }

    fn evaluate(&self, env: &mut dyn Environment, episodes: usize) -> Result<f64> {
        let mut total_reward = 0.0;

        for _ in 0..episodes {
            let mut state = env.reset();
            let mut episode_reward = 0.0;
            let mut done = false;

            while !done {
                let action = self.get_action(&state)?;
                let (next_state, reward, is_done) = env.step(action)?;

                state = next_state;
                episode_reward += reward;
                done = is_done;
            }

            total_reward += episode_reward;
        }

        Ok(total_reward / episodes as f64)
    }
}

/// A simple grid-world environment for testing reinforcement learning agents.
pub struct GridWorldEnvironment {
    /// Grid width in cells.
    width: usize,

    /// Grid height in cells.
    height: usize,

    /// Current agent position as `(x, y)`.
    position: (usize, usize),

    /// Goal position as `(x, y)`.
    goal: (usize, usize),

    /// Cells the agent cannot enter.
    obstacles: Vec<(usize, usize)>,
}

impl GridWorldEnvironment {
    /// Creates a grid world with the agent at `(0, 0)` and the goal in the
    /// opposite corner.
    pub fn new(width: usize, height: usize) -> Self {
        GridWorldEnvironment {
            width,
            height,
            position: (0, 0),
            goal: (width - 1, height - 1),
            obstacles: Vec::new(),
        }
    }

    /// Sets the goal position, clamped to the grid bounds.
    pub fn with_goal(mut self, x: usize, y: usize) -> Self {
        self.goal = (x.min(self.width - 1), y.min(self.height - 1));
        self
    }

    /// Sets the obstacle cells.
    pub fn with_obstacles(mut self, obstacles: Vec<(usize, usize)>) -> Self {
        self.obstacles = obstacles;
        self
    }

    /// Returns `true` if `(x, y)` is an obstacle.
    pub fn is_obstacle(&self, x: usize, y: usize) -> bool {
        self.obstacles.contains(&(x, y))
    }

    /// Returns `true` if `(x, y)` is the goal.
    pub fn is_goal(&self, x: usize, y: usize) -> bool {
        (x, y) == self.goal
    }
}

impl Environment for GridWorldEnvironment {
    fn state(&self) -> Array1<f64> {
        // State is the agent and goal positions, normalized to [0, 1).
        let mut state = Array1::zeros(4);

        state[0] = self.position.0 as f64 / self.width as f64;
        state[1] = self.position.1 as f64 / self.height as f64;

        state[2] = self.goal.0 as f64 / self.width as f64;
        state[3] = self.goal.1 as f64 / self.height as f64;

        state
    }

    fn num_actions(&self) -> usize {
        4 // one move per grid direction: -y, +x, +y, -x (see `step`)
    }

    fn step(&mut self, action: usize) -> Result<(Array1<f64>, f64, bool)> {
        let (x, y) = self.position;
        let (new_x, new_y) = match action {
            0 => (x, y.saturating_sub(1)), // move in -y
            1 => (x + 1, y),               // move in +x
            2 => (x, y + 1),               // move in +y
            3 => (x.saturating_sub(1), y), // move in -x
            _ => {
                return Err(MLError::InvalidParameter(format!(
                    "Invalid action: {}",
                    action
                )))
            }
        };

        // Keep the agent inside the grid.
        let new_x = new_x.min(self.width - 1);
        let new_y = new_y.min(self.height - 1);

        // Moving into an obstacle is penalized and the position is unchanged.
        if self.obstacles.contains(&(new_x, new_y)) {
            let reward = -1.0;
            let done = false;
            return Ok((self.state(), reward, done));
        }

        self.position = (new_x, new_y);

        // Reaching the goal ends the episode with a large reward; every other
        // step carries a small penalty to encourage short paths.
        let done = (new_x, new_y) == self.goal;
        let reward = if done { 10.0 } else { -0.1 };

        Ok((self.state(), reward, done))
    }

390
391 fn reset(&mut self) -> Array1<f64> {
392 self.position = (0, 0);
393 self.state()
394 }
395}
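
// A minimal smoke-test sketch showing how the pieces fit together: a small
// grid world, an agent configured through the builder methods above, and a
// short training/evaluation run. The episode counts and hyperparameters are
// illustrative assumptions, not tuned values; the test only checks that the
// loops run and return a finite mean reward.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn grid_world_agent_trains_and_evaluates() {
        let mut env = GridWorldEnvironment::new(4, 4)
            .with_goal(3, 3)
            .with_obstacles(vec![(1, 1), (2, 2)]);

        let mut agent = ReinforcementLearning::new()
            .with_algorithm(ReinforcementLearningType::QLearning)
            .with_state_dimension(4)
            .with_action_dimension(env.num_actions())
            .with_learning_rate(0.05)
            .with_discount_factor(0.9)
            .with_exploration_rate(0.2);

        let train_reward = agent.train(&mut env, 5).expect("training should succeed");
        assert!(train_reward.is_finite());

        let eval_reward = agent.evaluate(&mut env, 5).expect("evaluation should succeed");
        assert!(eval_reward.is_finite());
    }
}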