1use crate::error::{MLError, Result};
8use crate::qnn::QuantumNeuralNetwork;
9use quantrs2_circuit::prelude::Circuit;
10use scirs2_core::ndarray::{Array1, Array2};
11use scirs2_core::random::prelude::*;
12use std::collections::HashMap;
13
14pub trait Environment {
16 fn state(&self) -> Array1<f64>;
18
19 fn num_actions(&self) -> usize;
21
22 fn step(&mut self, action: usize) -> Result<(Array1<f64>, f64, bool)>;
24
25 fn reset(&mut self) -> Array1<f64>;
27}
28
29pub trait QuantumAgent {
31 fn get_action(&self, state: &Array1<f64>) -> Result<usize>;
33
34 fn update(
36 &mut self,
37 state: &Array1<f64>,
38 action: usize,
39 reward: f64,
40 next_state: &Array1<f64>,
41 done: bool,
42 ) -> Result<()>;
43
44 fn train(&mut self, env: &mut dyn Environment, episodes: usize) -> Result<f64>;
46
47 fn evaluate(&self, env: &mut dyn Environment, episodes: usize) -> Result<f64>;
49}
50
/// Reinforcement-learning algorithm families that a `ReinforcementLearning` agent can be
/// configured with.
///
/// NOTE(review): the agent currently stores this value but no visible code branches on it.
// The all-caps variant names are established public API; renaming them would break callers,
// so the acronym lint is silenced instead.
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ReinforcementLearningType {
    /// Q-learning.
    QLearning,

    /// SARSA.
    SARSA,

    /// Deep Q-network.
    DQN,

    /// Policy-gradient methods.
    PolicyGradient,

    /// Quantum Approximate Optimization Algorithm.
    QAOA,
}
69
70#[derive(Debug, Clone)]
72pub struct ReinforcementLearning {
73 rl_type: ReinforcementLearningType,
75
76 qnn: QuantumNeuralNetwork,
78
79 learning_rate: f64,
81
82 discount_factor: f64,
84
85 exploration_rate: f64,
87
88 state_dim: usize,
90
91 action_dim: usize,
93}
94
95impl ReinforcementLearning {
96 pub fn new() -> Result<Self> {
101 let layers = vec![
105 crate::qnn::QNNLayerType::EncodingLayer { num_features: 4 },
106 crate::qnn::QNNLayerType::VariationalLayer { num_params: 16 },
107 crate::qnn::QNNLayerType::EntanglementLayer {
108 connectivity: "full".to_string(),
109 },
110 crate::qnn::QNNLayerType::VariationalLayer { num_params: 16 },
111 crate::qnn::QNNLayerType::MeasurementLayer {
112 measurement_basis: "computational".to_string(),
113 },
114 ];
115
116 let qnn = QuantumNeuralNetwork::new(
117 layers, 8, 4, 2, )?;
121
122 Ok(ReinforcementLearning {
123 rl_type: ReinforcementLearningType::QLearning,
124 qnn,
125 learning_rate: 0.01,
126 discount_factor: 0.95,
127 exploration_rate: 0.1,
128 state_dim: 4,
129 action_dim: 2,
130 })
131 }
132
133 pub fn with_algorithm(mut self, rl_type: ReinforcementLearningType) -> Self {
135 self.rl_type = rl_type;
136 self
137 }
138
139 pub fn with_state_dimension(mut self, state_dim: usize) -> Self {
141 self.state_dim = state_dim;
142 self
143 }
144
145 pub fn with_action_dimension(mut self, action_dim: usize) -> Self {
147 self.action_dim = action_dim;
148 self
149 }
150
151 pub fn with_learning_rate(mut self, learning_rate: f64) -> Self {
153 self.learning_rate = learning_rate;
154 self
155 }
156
157 pub fn with_discount_factor(mut self, discount_factor: f64) -> Self {
159 self.discount_factor = discount_factor;
160 self
161 }
162
163 pub fn with_exploration_rate(mut self, exploration_rate: f64) -> Self {
165 self.exploration_rate = exploration_rate;
166 self
167 }
168
169 fn encode_state(&self, state: &Array1<f64>) -> Result<Circuit<8>> {
171 let mut circuit = Circuit::<8>::new();
175
176 for i in 0..state.len().min(8) {
177 circuit.ry(i, state[i] * std::f64::consts::PI)?;
178 }
179
180 Ok(circuit)
181 }
182
183 fn get_q_values(&self, state: &Array1<f64>) -> Result<Array1<f64>> {
185 let mut q_values = Array1::zeros(self.action_dim);
189
190 for i in 0..self.action_dim {
191 q_values[i] = 0.5 + 0.5 * thread_rng().random::<f64>();
192 }
193
194 Ok(q_values)
195 }
196}
197
198impl QuantumAgent for ReinforcementLearning {
199 fn get_action(&self, state: &Array1<f64>) -> Result<usize> {
200 if thread_rng().random::<f64>() < self.exploration_rate {
202 Ok(fastrand::usize(0..self.action_dim))
204 } else {
205 let q_values = self.get_q_values(state)?;
207 let mut best_action = 0;
208 let mut best_value = q_values[0];
209
210 for i in 1..self.action_dim {
211 if q_values[i] > best_value {
212 best_value = q_values[i];
213 best_action = i;
214 }
215 }
216
217 Ok(best_action)
218 }
219 }
220
221 fn update(
222 &mut self,
223 _state: &Array1<f64>,
224 _action: usize,
225 _reward: f64,
226 _next_state: &Array1<f64>,
227 _done: bool,
228 ) -> Result<()> {
229 Ok(())
233 }
234
235 fn train(&mut self, env: &mut dyn Environment, episodes: usize) -> Result<f64> {
236 let mut total_reward = 0.0;
237
238 for _ in 0..episodes {
239 let mut state = env.reset();
240 let mut episode_reward = 0.0;
241 let mut done = false;
242
243 while !done {
244 let action = self.get_action(&state)?;
245 let (next_state, reward, is_done) = env.step(action)?;
246
247 self.update(&state, action, reward, &next_state, is_done)?;
248
249 state = next_state;
250 episode_reward += reward;
251 done = is_done;
252 }
253
254 total_reward += episode_reward;
255 }
256
257 Ok(total_reward / episodes as f64)
258 }
259
260 fn evaluate(&self, env: &mut dyn Environment, episodes: usize) -> Result<f64> {
261 let mut total_reward = 0.0;
262
263 for _ in 0..episodes {
264 let mut state = env.reset();
265 let mut episode_reward = 0.0;
266 let mut done = false;
267
268 while !done {
269 let action = self.get_action(&state)?;
270 let (next_state, reward, is_done) = env.step(action)?;
271
272 state = next_state;
273 episode_reward += reward;
274 done = is_done;
275 }
276
277 total_reward += episode_reward;
278 }
279
280 Ok(total_reward / episodes as f64)
281 }
282}
283
/// A rectangular grid world.
///
/// The agent starts at `(0, 0)` and moves between cells. Obstacle cells cannot be
/// entered, and an episode ends when the agent reaches the goal cell.
#[derive(Debug, Clone)]
pub struct GridWorldEnvironment {
    /// Number of columns; valid x coordinates are `0..width`.
    width: usize,

    /// Number of rows; valid y coordinates are `0..height`.
    height: usize,

    /// Current agent position as `(x, y)`.
    position: (usize, usize),

    /// Goal cell as `(x, y)`.
    goal: (usize, usize),

    /// Cells the agent is not allowed to enter.
    obstacles: Vec<(usize, usize)>,
}

impl GridWorldEnvironment {
    /// Creates a `width` × `height` grid.
    ///
    /// The agent starts at `(0, 0)`, the goal is at `(width - 1, height - 1)`, and there
    /// are no obstacles.
    ///
    /// # Panics
    ///
    /// Panics if `width` or `height` is zero. Without this check, `width - 1` would wrap to
    /// `usize::MAX` in release builds and produce an unreachable goal.
    pub fn new(width: usize, height: usize) -> Self {
        assert!(
            width > 0 && height > 0,
            "GridWorldEnvironment requires non-zero dimensions, got {}x{}",
            width,
            height
        );

        GridWorldEnvironment {
            width,
            height,
            position: (0, 0),
            goal: (width - 1, height - 1),
            obstacles: Vec::new(),
        }
    }

    /// Moves the goal to `(x, y)`, clamping each coordinate into the grid.
    #[must_use]
    pub fn with_goal(mut self, x: usize, y: usize) -> Self {
        self.goal = (x.min(self.width - 1), y.min(self.height - 1));
        self
    }

    /// Replaces the obstacle list with `obstacles`.
    #[must_use]
    pub fn with_obstacles(mut self, obstacles: Vec<(usize, usize)>) -> Self {
        self.obstacles = obstacles;
        self
    }

    /// Returns `true` if `(x, y)` is one of the obstacle cells.
    pub fn is_obstacle(&self, x: usize, y: usize) -> bool {
        self.obstacles.contains(&(x, y))
    }

    /// Returns `true` if `(x, y)` is the goal cell.
    pub fn is_goal(&self, x: usize, y: usize) -> bool {
        (x, y) == self.goal
    }
}
336
337impl Environment for GridWorldEnvironment {
338 fn state(&self) -> Array1<f64> {
339 let mut state = Array1::zeros(4);
340
341 state[0] = self.position.0 as f64 / self.width as f64;
343 state[1] = self.position.1 as f64 / self.height as f64;
344
345 state[2] = self.goal.0 as f64 / self.width as f64;
347 state[3] = self.goal.1 as f64 / self.height as f64;
348
349 state
350 }
351
352 fn num_actions(&self) -> usize {
353 4 }
355
356 fn step(&mut self, action: usize) -> Result<(Array1<f64>, f64, bool)> {
357 let (x, y) = self.position;
359 let (new_x, new_y) = match action {
360 0 => (x, y.saturating_sub(1)), 1 => (x + 1, y), 2 => (x, y + 1), 3 => (x.saturating_sub(1), y), _ => {
365 return Err(MLError::InvalidParameter(format!(
366 "Invalid action: {}",
367 action
368 )))
369 }
370 };
371
372 let new_x = new_x.min(self.width - 1);
374 let new_y = new_y.min(self.height - 1);
375
376 if self.obstacles.contains(&(new_x, new_y)) {
378 let reward = -1.0;
380 let done = false;
381 return Ok((self.state(), reward, done));
382 }
383
384 self.position = (new_x, new_y);
386
387 let reward = if (new_x, new_y) == self.goal {
389 10.0 } else {
391 -0.1 };
393
394 let done = (new_x, new_y) == self.goal;
396
397 Ok((self.state(), reward, done))
398 }
399
400 fn reset(&mut self) -> Array1<f64> {
401 self.position = (0, 0);
402 self.state()
403 }
404}