1use crate::error::{QuantRS2Error, QuantRS2Result};
9use crate::gate::GateOp;
10use crate::qubit::QubitId;
11use scirs2_core::ndarray::{Array1, Array2};
12use scirs2_core::random::{thread_rng, Rng};
13use std::collections::HashMap;
14use std::sync::{Arc, RwLock};
15
/// A single rewrite step that the optimizer can pick for a circuit.
///
/// The variants only carry gate positions; how a position is interpreted is
/// up to whoever applies the action (that code is not part of this type).
/// The per-variant descriptions below are inferred from the variant names.
///
/// The type is `Copy` and hashable so it can be used as part of a Q-table key.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum OptimizationAction {
    /// Merge single-qubit gates, anchored at `gate_index`.
    MergeSingleQubitGates { gate_index: usize },
    /// Cancel a gate/inverse pair, anchored at `gate_index`.
    CancelInversePairs { gate_index: usize },
    /// Commute the two gates found at the given positions.
    CommuteGates {
        gate1_index: usize,
        gate2_index: usize,
    },
    /// Replace the gate sequence delimited by `start_index` and `end_index`.
    ReplaceSequence {
        start_index: usize,
        end_index: usize,
    },
    /// Optimize the two-qubit gate at `gate_index`.
    OptimizeTwoQubitGate { gate_index: usize },
    /// Leave the circuit untouched.
    NoOp,
}
38
/// Summary metrics describing a circuit at one point during optimization.
///
/// All fields are plain public values; nothing here enforces a range on them.
#[derive(Clone, Debug)]
pub struct CircuitState {
    /// Circuit depth (number of gate layers).
    pub depth: usize,
    /// Total number of gates.
    pub gate_count: usize,
    /// Number of gates acting on exactly two qubits.
    pub two_qubit_count: usize,
    /// Fidelity figure for the whole circuit.
    pub fidelity: f64,
    /// Number of qubits in the register.
    pub qubit_count: usize,
    /// Connectivity density figure for the circuit.
    pub connectivity_density: f64,
    /// Entanglement figure for the circuit.
    pub entanglement_measure: f64,
}
57
58impl CircuitState {
59 pub fn to_features(&self) -> Vec<f64> {
61 vec![
62 self.depth as f64 / 100.0, self.gate_count as f64 / 1000.0,
64 self.two_qubit_count as f64 / 500.0,
65 self.fidelity,
66 self.qubit_count as f64 / 50.0,
67 self.connectivity_density,
68 self.entanglement_measure,
69 ]
70 }
71
72 pub fn from_circuit(gates: &[Box<dyn GateOp>], num_qubits: usize) -> Self {
74 let mut depth_map: HashMap<QubitId, usize> = HashMap::new();
75 let mut two_qubit_count = 0;
76 let mut connectivity_edges = 0;
77
78 for gate in gates {
79 let qubits = gate.qubits();
80
81 if qubits.len() == 2 {
82 two_qubit_count += 1;
83 connectivity_edges += 1;
84 }
85
86 let max_depth = qubits
88 .iter()
89 .map(|q| *depth_map.get(q).unwrap_or(&0))
90 .max()
91 .unwrap_or(0);
92
93 for qubit in qubits {
94 depth_map.insert(qubit, max_depth + 1);
95 }
96 }
97
98 let depth = depth_map.values().max().copied().unwrap_or(0);
99 let gate_count = gates.len();
100
101 let fidelity = 0.9999_f64.powi(gate_count as i32 - two_qubit_count as i32)
103 * 0.99_f64.powi(two_qubit_count as i32);
104
105 let max_edges = num_qubits * (num_qubits - 1) / 2;
107 let connectivity_density = if max_edges > 0 {
108 connectivity_edges as f64 / max_edges as f64
109 } else {
110 0.0
111 };
112
113 let entanglement_measure = (two_qubit_count as f64 / num_qubits as f64).min(1.0);
115
116 Self {
117 depth,
118 gate_count,
119 two_qubit_count,
120 fidelity,
121 qubit_count: num_qubits,
122 connectivity_density,
123 entanglement_measure,
124 }
125 }
126}
127
128pub struct QLearningOptimizer {
130 q_table: Arc<RwLock<HashMap<(Vec<u8>, OptimizationAction), f64>>>,
132 learning_rate: f64,
134 discount_factor: f64,
136 epsilon: f64,
138 epsilon_decay: f64,
140 min_epsilon: f64,
142 episodes: Arc<RwLock<usize>>,
144 performance_history: Arc<RwLock<Vec<OptimizationEpisode>>>,
146}
147
/// Outcome of one finished optimization episode.
#[derive(Clone, Debug)]
pub struct OptimizationEpisode {
    /// Circuit depth before the episode.
    pub initial_depth: usize,
    /// Circuit depth after the episode.
    pub final_depth: usize,
    /// Gate count before the episode.
    pub initial_gate_count: usize,
    /// Gate count after the episode.
    pub final_gate_count: usize,
    /// Reward recorded for the episode.
    pub reward: f64,
    /// Number of steps taken during the episode.
    pub steps_taken: usize,
}
158
159impl QLearningOptimizer {
160 pub fn new(learning_rate: f64, discount_factor: f64, initial_epsilon: f64) -> Self {
167 Self {
168 q_table: Arc::new(RwLock::new(HashMap::new())),
169 learning_rate,
170 discount_factor,
171 epsilon: initial_epsilon,
172 epsilon_decay: 0.995,
173 min_epsilon: 0.01,
174 episodes: Arc::new(RwLock::new(0)),
175 performance_history: Arc::new(RwLock::new(Vec::new())),
176 }
177 }
178
179 pub fn choose_action(
185 &self,
186 state: &CircuitState,
187 available_actions: &[OptimizationAction],
188 ) -> OptimizationAction {
189 if available_actions.is_empty() {
190 return OptimizationAction::NoOp;
191 }
192
193 let mut rng = thread_rng();
194
195 if rng.gen::<f64>() < self.epsilon {
197 available_actions[rng.gen_range(0..available_actions.len())]
199 } else {
200 self.get_best_action(state, available_actions)
202 }
203 }
204
205 fn get_best_action(
207 &self,
208 state: &CircuitState,
209 available_actions: &[OptimizationAction],
210 ) -> OptimizationAction {
211 let state_key = self.state_to_key(state);
212 let q_table = self.q_table.read().unwrap();
213
214 let mut best_action = available_actions[0];
215 let mut best_q_value = f64::NEG_INFINITY;
216
217 for &action in available_actions {
218 let q_value = *q_table.get(&(state_key.clone(), action)).unwrap_or(&0.0);
219 if q_value > best_q_value {
220 best_q_value = q_value;
221 best_action = action;
222 }
223 }
224
225 best_action
226 }
227
228 pub fn update_q_value(
237 &mut self,
238 state: &CircuitState,
239 action: OptimizationAction,
240 reward: f64,
241 next_state: &CircuitState,
242 next_actions: &[OptimizationAction],
243 ) {
244 let state_key = self.state_to_key(state);
245 let next_state_key = self.state_to_key(next_state);
246
247 let q_table = self.q_table.read().unwrap();
249 let max_next_q = if !next_actions.is_empty() {
250 next_actions
251 .iter()
252 .map(|&a| *q_table.get(&(next_state_key.clone(), a)).unwrap_or(&0.0))
253 .fold(f64::NEG_INFINITY, f64::max)
254 } else {
255 0.0
256 };
257 drop(q_table);
258
259 let mut q_table = self.q_table.write().unwrap();
262 let current_q = *q_table.get(&(state_key.clone(), action)).unwrap_or(&0.0);
263 let new_q = current_q
264 + self.learning_rate * (reward + self.discount_factor * max_next_q - current_q);
265 q_table.insert((state_key, action), new_q);
266 }
267
268 pub fn calculate_reward(&self, old_state: &CircuitState, new_state: &CircuitState) -> f64 {
275 let mut reward = 0.0;
276
277 let depth_improvement = old_state.depth as f64 - new_state.depth as f64;
279 reward += depth_improvement * 2.0;
280
281 let gate_improvement = old_state.gate_count as f64 - new_state.gate_count as f64;
283 reward += gate_improvement * 1.0;
284
285 let two_qubit_improvement =
287 old_state.two_qubit_count as f64 - new_state.two_qubit_count as f64;
288 reward += two_qubit_improvement * 3.0;
289
290 let fidelity_change = new_state.fidelity - old_state.fidelity;
292 reward += fidelity_change * 100.0; if reward == 0.0 {
296 reward = -0.1;
297 }
298
299 reward
300 }
301
302 pub fn finish_episode(&mut self, episode: OptimizationEpisode) {
304 self.epsilon = (self.epsilon * self.epsilon_decay).max(self.min_epsilon);
306
307 {
309 let mut episodes = self.episodes.write().unwrap();
310 *episodes += 1;
311
312 let mut history = self.performance_history.write().unwrap();
313 history.push(episode);
314
315 if history.len() > 1000 {
317 let len = history.len();
318 history.drain(0..len - 1000);
319 }
320 }
321 }
322
323 pub fn get_statistics(&self) -> OptimizationStatistics {
325 let history = self.performance_history.read().unwrap();
326
327 if history.is_empty() {
328 return OptimizationStatistics {
329 total_episodes: 0,
330 average_depth_improvement: 0.0,
331 average_gate_reduction: 0.0,
332 average_reward: 0.0,
333 current_epsilon: self.epsilon,
334 q_table_size: self.q_table.read().unwrap().len(),
335 };
336 }
337
338 let total_episodes = history.len();
339 let avg_depth_improvement: f64 = history
340 .iter()
341 .map(|e| (e.initial_depth - e.final_depth) as f64)
342 .sum::<f64>()
343 / total_episodes as f64;
344
345 let avg_gate_reduction: f64 = history
346 .iter()
347 .map(|e| (e.initial_gate_count - e.final_gate_count) as f64)
348 .sum::<f64>()
349 / total_episodes as f64;
350
351 let avg_reward: f64 = history.iter().map(|e| e.reward).sum::<f64>() / total_episodes as f64;
352
353 OptimizationStatistics {
354 total_episodes,
355 average_depth_improvement: avg_depth_improvement,
356 average_gate_reduction: avg_gate_reduction,
357 average_reward: avg_reward,
358 current_epsilon: self.epsilon,
359 q_table_size: self.q_table.read().unwrap().len(),
360 }
361 }
362
363 fn state_to_key(&self, state: &CircuitState) -> Vec<u8> {
365 let features = state.to_features();
367 features
368 .iter()
369 .map(|&f| ((f * 10.0).round() as i32).clamp(0, 255) as u8)
370 .collect()
371 }
372
373 pub fn save_q_table(&self, path: &str) -> QuantRS2Result<()> {
375 Ok(())
378 }
379
380 pub fn load_q_table(&mut self, path: &str) -> QuantRS2Result<()> {
382 Ok(())
385 }
386}
387
/// Snapshot of the optimizer's training progress.
#[derive(Clone, Debug)]
pub struct OptimizationStatistics {
    /// Number of episodes covered by the snapshot.
    pub total_episodes: usize,
    /// Mean of (initial depth - final depth) per episode.
    pub average_depth_improvement: f64,
    /// Mean of (initial gate count - final gate count) per episode.
    pub average_gate_reduction: f64,
    /// Mean episode reward.
    pub average_reward: f64,
    /// Exploration rate at the time of the snapshot.
    pub current_epsilon: f64,
    /// Number of entries stored in the Q-table.
    pub q_table_size: usize,
}
398
399impl Default for QLearningOptimizer {
400 fn default() -> Self {
401 Self::new(0.1, 0.95, 0.3)
402 }
403}
404
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a five-qubit state; only the fields the tests vary are parameters.
    fn make_state(
        depth: usize,
        gate_count: usize,
        two_qubit_count: usize,
        fidelity: f64,
    ) -> CircuitState {
        CircuitState {
            depth,
            gate_count,
            two_qubit_count,
            fidelity,
            qubit_count: 5,
            connectivity_density: 0.6,
            entanglement_measure: 0.8,
        }
    }

    /// The state shared by most tests as the starting point.
    fn reference_state() -> CircuitState {
        make_state(10, 50, 15, 0.95)
    }

    /// Builds an episode from (initial, final) depth and gate-count pairs.
    fn make_episode(
        depth: (usize, usize),
        gates: (usize, usize),
        reward: f64,
        steps_taken: usize,
    ) -> OptimizationEpisode {
        OptimizationEpisode {
            initial_depth: depth.0,
            final_depth: depth.1,
            initial_gate_count: gates.0,
            final_gate_count: gates.1,
            reward,
            steps_taken,
        }
    }

    #[test]
    fn test_circuit_state_creation() {
        let features = reference_state().to_features();

        assert_eq!(features.len(), 7);
        // Every feature of the reference state stays inside the normalized band.
        assert!(features.iter().all(|f| (0.0..=1.1).contains(f)));
    }

    #[test]
    fn test_q_learning_optimizer_creation() {
        let optimizer = QLearningOptimizer::new(0.1, 0.95, 0.3);

        assert_eq!(optimizer.learning_rate, 0.1);
        assert_eq!(optimizer.discount_factor, 0.95);
        assert_eq!(optimizer.epsilon, 0.3);
    }

    #[test]
    fn test_action_selection() {
        // Epsilon of zero disables exploration.
        let optimizer = QLearningOptimizer::new(0.1, 0.95, 0.0);
        let state = reference_state();
        let actions = vec![
            OptimizationAction::MergeSingleQubitGates { gate_index: 0 },
            OptimizationAction::CancelInversePairs { gate_index: 1 },
        ];

        let action = optimizer.choose_action(&state, &actions);

        assert!(actions.contains(&action));
    }

    #[test]
    fn test_reward_calculation() {
        let optimizer = QLearningOptimizer::new(0.1, 0.95, 0.3);
        let old_state = reference_state();
        let new_state = make_state(8, 45, 12, 0.96);

        let reward = optimizer.calculate_reward(&old_state, &new_state);

        // Every tracked metric improved, so the reward must be positive.
        assert!(reward > 0.0);
    }

    #[test]
    fn test_q_value_update() {
        let mut optimizer = QLearningOptimizer::new(0.1, 0.95, 0.3);
        let state = reference_state();
        let next_state = make_state(9, 48, 15, 0.95);
        let action = OptimizationAction::MergeSingleQubitGates { gate_index: 0 };

        optimizer.update_q_value(&state, action, 5.0, &next_state, &[]);

        let q_table = optimizer.q_table.read().unwrap();
        assert!(!q_table.is_empty());
    }

    #[test]
    fn test_epsilon_decay() {
        let mut optimizer = QLearningOptimizer::new(0.1, 0.95, 0.5);
        let initial_epsilon = optimizer.epsilon;

        optimizer.finish_episode(make_episode((10, 8), (50, 45), 10.0, 5));

        assert!(optimizer.epsilon < initial_epsilon);
        assert!(optimizer.epsilon >= optimizer.min_epsilon);
    }

    #[test]
    fn test_statistics() {
        let mut optimizer = QLearningOptimizer::new(0.1, 0.95, 0.3);

        optimizer.finish_episode(make_episode((10, 8), (50, 45), 10.0, 5));
        optimizer.finish_episode(make_episode((12, 9), (60, 52), 15.0, 7));

        let stats = optimizer.get_statistics();
        assert_eq!(stats.total_episodes, 2);
        assert!(stats.average_depth_improvement > 0.0);
        assert!(stats.average_gate_reduction > 0.0);
    }
}