use crate::{
    error::{QuantRS2Error, QuantRS2Result},
    gate::GateOp,
    qubit::QubitId,
    register::Register,
    variational::{DiffMode, VariationalCircuit, VariationalGate},
};
use ndarray::{Array1, Array2};
use num_complex::Complex64;
use rayon::prelude::*;
use rustc_hash::FxHashMap;
use std::sync::{Arc, Mutex};

extern crate scirs2_optimize;
use scirs2_optimize::unconstrained::{minimize, Method, OptimizeResult, Options};

extern crate scirs2_linalg;

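/// Optimizer for variational quantum circuits.
///
/// Gradient-based and stochastic methods (gradient descent, momentum, Adam,
/// RMSprop, natural gradient, SPSA, QNSPSA) use a custom update loop, while
/// quasi-Newton and direct-search methods (BFGS, L-BFGS, conjugate gradient,
/// Nelder-Mead, Powell) are delegated to SciRS2's `minimize`.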
pub struct VariationalQuantumOptimizer {
    /// Optimization method to use.
    method: OptimizationMethod,
    /// Convergence and execution settings.
    config: OptimizationConfig,
    /// Recorded parameters, losses, gradient norms, and timings.
    history: OptimizationHistory,
    /// Cached Fisher matrix (natural-gradient and QNSPSA methods only).
    fisher_cache: Option<FisherCache>,
}

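/// Optimization methods supported by `VariationalQuantumOptimizer`.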
#[derive(Debug, Clone)]
pub enum OptimizationMethod {
    /// Plain gradient descent with a fixed learning rate.
    GradientDescent { learning_rate: f64 },
    /// Gradient descent with momentum.
    Momentum { learning_rate: f64, momentum: f64 },
    /// Adam optimizer.
    Adam {
        learning_rate: f64,
        beta1: f64,
        beta2: f64,
        epsilon: f64,
    },
    /// RMSprop optimizer.
    RMSprop {
        learning_rate: f64,
        decay_rate: f64,
        epsilon: f64,
    },
    /// Natural gradient descent using a regularized Fisher matrix.
    NaturalGradient {
        learning_rate: f64,
        regularization: f64,
    },
    /// BFGS quasi-Newton method (SciRS2 backend).
    BFGS,
    /// Limited-memory BFGS (SciRS2 backend).
    LBFGS { memory_size: usize },
    /// Conjugate gradient (currently mapped to BFGS in the SciRS2 backend).
    ConjugateGradient,
    /// Nelder-Mead simplex method (SciRS2 backend).
    NelderMead,
    /// Powell's method (SciRS2 backend).
    Powell,
    /// Simultaneous perturbation stochastic approximation.
    SPSA {
        /// Initial step size for the update schedule.
        a: f64,
        /// Perturbation magnitude used for gradient estimation.
        c: f64,
        /// Decay exponent for the step-size schedule.
        alpha: f64,
        /// Decay exponent for the perturbation schedule (not used by the
        /// current update rule).
        gamma: f64,
    },
    /// Quantum-natural SPSA.
    QNSPSA {
        learning_rate: f64,
        regularization: f64,
        spsa_epsilon: f64,
    },
}

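/// Configuration controlling convergence criteria, parallelism, callbacks,
/// and gradient clipping.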
#[derive(Clone)]
pub struct OptimizationConfig {
    /// Maximum number of iterations.
    pub max_iterations: usize,
    /// Convergence tolerance on the loss value.
    pub f_tol: f64,
    /// Convergence tolerance on the gradient norm.
    pub g_tol: f64,
    /// Convergence tolerance on parameter changes.
    pub x_tol: f64,
    /// Compute per-parameter gradients in parallel with rayon.
    pub parallel_gradients: bool,
    /// Optional batch size (not used by the built-in update rules).
    pub batch_size: Option<usize>,
    /// Optional RNG seed for stochastic methods.
    pub seed: Option<u64>,
    /// Optional callback invoked each iteration with (parameters, loss).
    pub callback: Option<Arc<dyn Fn(&[f64], f64) + Send + Sync>>,
    /// Early-stopping patience in iterations without improvement.
    pub patience: Option<usize>,
    /// Optional maximum gradient norm for clipping.
    pub grad_clip: Option<f64>,
}

impl Default for OptimizationConfig {
    fn default() -> Self {
        Self {
            max_iterations: 100,
            f_tol: 1e-8,
            g_tol: 1e-8,
            x_tol: 1e-8,
            parallel_gradients: true,
            batch_size: None,
            seed: None,
            callback: None,
            patience: None,
            grad_clip: None,
        }
    }
}

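/// Per-iteration record of parameters, losses, gradient norms, and timings.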
#[derive(Debug, Clone)]
pub struct OptimizationHistory {
    /// Parameter values at each iteration.
    pub parameters: Vec<Vec<f64>>,
    /// Loss value at each iteration.
    pub loss_values: Vec<f64>,
    /// Gradient norm at each iteration.
    pub gradient_norms: Vec<f64>,
    /// Wall-clock time per iteration in milliseconds.
    pub iteration_times: Vec<f64>,
    /// Total number of iterations performed.
    pub total_iterations: usize,
    /// Whether the optimization converged.
    pub converged: bool,
}

impl OptimizationHistory {
    fn new() -> Self {
        Self {
            parameters: Vec::new(),
            loss_values: Vec::new(),
            gradient_norms: Vec::new(),
            iteration_times: Vec::new(),
            total_iterations: 0,
            converged: false,
        }
    }
}

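/// Cache for the approximate inverse Fisher matrix, reused while the
/// parameters stay within `threshold` of the values it was computed at.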
struct FisherCache {
    /// Cached inverse Fisher matrix.
    matrix: Arc<Mutex<Option<Array2<f64>>>>,
    /// Parameter values at which the cached matrix was computed.
    params: Arc<Mutex<Option<Vec<f64>>>>,
    /// Maximum parameter distance for which the cache is considered valid.
    threshold: f64,
}

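/// Mutable per-run state for the custom update rules (momentum velocities,
/// Adam and RMSprop moment estimates, iteration counter).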
struct OptimizerState {
    /// Momentum velocities, keyed by parameter name.
    momentum: FxHashMap<String, f64>,
    /// Adam first-moment estimates.
    adam_m: FxHashMap<String, f64>,
    /// Adam second-moment estimates.
    adam_v: FxHashMap<String, f64>,
    /// RMSprop running averages of squared gradients.
    rms_avg: FxHashMap<String, f64>,
    /// Current iteration counter.
    iteration: usize,
}

impl VariationalQuantumOptimizer {
    /// Creates a new optimizer with the given method and configuration.
    pub fn new(method: OptimizationMethod, config: OptimizationConfig) -> Self {
        let fisher_cache = match &method {
            OptimizationMethod::NaturalGradient { .. } | OptimizationMethod::QNSPSA { .. } => {
                Some(FisherCache {
                    matrix: Arc::new(Mutex::new(None)),
                    params: Arc::new(Mutex::new(None)),
                    threshold: 1e-3,
                })
            }
            _ => None,
        };

        Self {
            method,
            config,
            history: OptimizationHistory::new(),
            fisher_cache,
        }
    }

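    /// Optimizes the circuit parameters against `cost_fn`.
    ///
    /// Quasi-Newton and direct-search methods are delegated to SciRS2;
    /// all other methods run the custom iterative loop.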
    pub fn optimize(
        &mut self,
        circuit: &mut VariationalCircuit,
        cost_fn: impl Fn(&VariationalCircuit) -> QuantRS2Result<f64> + Send + Sync + 'static,
    ) -> QuantRS2Result<OptimizationResult> {
        let cost_fn = Arc::new(cost_fn);

        match &self.method {
            OptimizationMethod::BFGS
            | OptimizationMethod::LBFGS { .. }
            | OptimizationMethod::ConjugateGradient
            | OptimizationMethod::NelderMead
            | OptimizationMethod::Powell => self.optimize_with_scirs2(circuit, cost_fn),
            _ => self.optimize_custom(circuit, cost_fn),
        }
    }

    fn optimize_with_scirs2(
        &mut self,
        circuit: &mut VariationalCircuit,
        cost_fn: Arc<dyn Fn(&VariationalCircuit) -> QuantRS2Result<f64> + Send + Sync>,
    ) -> QuantRS2Result<OptimizationResult> {
        let param_names = circuit.parameter_names();
        let initial_params: Vec<f64> = param_names
            .iter()
            .map(|name| circuit.get_parameters().get(name).copied().unwrap_or(0.0))
            .collect();

        let circuit_clone = Arc::new(Mutex::new(circuit.clone()));
        let param_names_clone = param_names.clone();

        let objective = move |params: &ndarray::ArrayView1<f64>| -> f64 {
            let params_slice = params.as_slice().unwrap();
            let mut param_map = FxHashMap::default();
            for (name, &value) in param_names_clone.iter().zip(params_slice) {
                param_map.insert(name.clone(), value);
            }

            let mut circuit = circuit_clone.lock().unwrap();
            if circuit.set_parameters(&param_map).is_err() {
                return f64::INFINITY;
            }

            match cost_fn(&*circuit) {
                Ok(loss) => loss,
                Err(_) => f64::INFINITY,
            }
        };

        let method = match &self.method {
            OptimizationMethod::BFGS => Method::BFGS,
            OptimizationMethod::LBFGS { memory_size: _ } => Method::LBFGS,
            // Conjugate gradient falls back to BFGS in this backend.
            OptimizationMethod::ConjugateGradient => Method::BFGS,
            OptimizationMethod::NelderMead => Method::NelderMead,
            OptimizationMethod::Powell => Method::Powell,
            _ => unreachable!(),
        };

        let options = Options {
            max_iter: self.config.max_iterations,
            ftol: self.config.f_tol,
            gtol: self.config.g_tol,
            xtol: self.config.x_tol,
            ..Default::default()
        };

        let start_time = std::time::Instant::now();
        let result = minimize(objective, &initial_params, method, Some(options))
            .map_err(|e| QuantRS2Error::InvalidInput(format!("Optimization failed: {:?}", e)))?;

        let mut final_params = FxHashMap::default();
        for (name, &value) in param_names.iter().zip(result.x.as_slice().unwrap()) {
            final_params.insert(name.clone(), value);
        }
        circuit.set_parameters(&final_params)?;

        self.history.parameters.push(result.x.to_vec());
        self.history.loss_values.push(result.fun);
        self.history.total_iterations = result.iterations;
        self.history.converged = result.success;

        Ok(OptimizationResult {
            optimal_parameters: final_params,
            final_loss: result.fun,
            iterations: result.iterations,
            converged: result.success,
            optimization_time: start_time.elapsed().as_secs_f64(),
            history: self.history.clone(),
        })
    }

    fn optimize_custom(
        &mut self,
        circuit: &mut VariationalCircuit,
        cost_fn: Arc<dyn Fn(&VariationalCircuit) -> QuantRS2Result<f64> + Send + Sync>,
    ) -> QuantRS2Result<OptimizationResult> {
        let mut state = OptimizerState {
            momentum: FxHashMap::default(),
            adam_m: FxHashMap::default(),
            adam_v: FxHashMap::default(),
            rms_avg: FxHashMap::default(),
            iteration: 0,
        };

        let param_names = circuit.parameter_names();
        let start_time = std::time::Instant::now();
        let mut best_loss = f64::INFINITY;
        let mut patience_counter = 0;

        for iter in 0..self.config.max_iterations {
            let iter_start = std::time::Instant::now();

            let loss = cost_fn(circuit)?;

            // Early stopping: count iterations without sufficient improvement.
            if loss < best_loss - self.config.f_tol {
                best_loss = loss;
                patience_counter = 0;
            } else if let Some(patience) = self.config.patience {
                patience_counter += 1;
                if patience_counter >= patience {
                    self.history.converged = true;
                    break;
                }
            }

            let gradients = self.compute_gradients(circuit, &cost_fn)?;

            let gradients = if let Some(max_norm) = self.config.grad_clip {
                self.clip_gradients(gradients, max_norm)
            } else {
                gradients
            };

            self.update_parameters(circuit, &gradients, &mut state)?;

            let current_params: Vec<f64> = param_names
                .iter()
                .map(|name| circuit.get_parameters().get(name).copied().unwrap_or(0.0))
                .collect();

            let grad_norm = gradients.values().map(|g| g * g).sum::<f64>().sqrt();

            self.history.parameters.push(current_params);
            self.history.loss_values.push(loss);
            self.history.gradient_norms.push(grad_norm);
            self.history
                .iteration_times
                .push(iter_start.elapsed().as_secs_f64() * 1000.0);
            self.history.total_iterations = iter + 1;

            if let Some(callback) = &self.config.callback {
                let params: Vec<f64> = param_names
                    .iter()
                    .map(|name| circuit.get_parameters().get(name).copied().unwrap_or(0.0))
                    .collect();
                callback(&params, loss);
            }

            if grad_norm < self.config.g_tol {
                self.history.converged = true;
                break;
            }

            state.iteration += 1;
        }

        let final_params = circuit.get_parameters();
        let final_loss = cost_fn(circuit)?;

        Ok(OptimizationResult {
            optimal_parameters: final_params,
            final_loss,
            iterations: self.history.total_iterations,
            converged: self.history.converged,
            optimization_time: start_time.elapsed().as_secs_f64(),
            history: self.history.clone(),
        })
    }

    fn compute_gradients(
        &self,
        circuit: &VariationalCircuit,
        cost_fn: &Arc<dyn Fn(&VariationalCircuit) -> QuantRS2Result<f64> + Send + Sync>,
    ) -> QuantRS2Result<FxHashMap<String, f64>> {
        let param_names = circuit.parameter_names();

        if self.config.parallel_gradients {
            let gradients: Vec<(String, f64)> = param_names
                .par_iter()
                .map(|param_name| {
                    let grad = self
                        .compute_single_gradient(circuit, param_name, cost_fn)
                        .unwrap_or(0.0);
                    (param_name.clone(), grad)
                })
                .collect();

            Ok(gradients.into_iter().collect())
        } else {
            let mut gradients = FxHashMap::default();
            for param_name in &param_names {
                let grad = self.compute_single_gradient(circuit, param_name, cost_fn)?;
                gradients.insert(param_name.clone(), grad);
            }
            Ok(gradients)
        }
    }

    fn compute_single_gradient(
        &self,
        circuit: &VariationalCircuit,
        param_name: &str,
        cost_fn: &Arc<dyn Fn(&VariationalCircuit) -> QuantRS2Result<f64> + Send + Sync>,
    ) -> QuantRS2Result<f64> {
        match &self.method {
            OptimizationMethod::SPSA { c, .. } => {
                self.spsa_gradient(circuit, param_name, cost_fn, *c)
            }
            _ => self.parameter_shift_gradient(circuit, param_name, cost_fn),
        }
    }

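    /// Parameter-shift gradient: evaluates the cost at `theta ± pi/2` and
    /// returns `(f(theta + pi/2) - f(theta - pi/2)) / 2`.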
    fn parameter_shift_gradient(
        &self,
        circuit: &VariationalCircuit,
        param_name: &str,
        cost_fn: &Arc<dyn Fn(&VariationalCircuit) -> QuantRS2Result<f64> + Send + Sync>,
    ) -> QuantRS2Result<f64> {
        let current_params = circuit.get_parameters();
        let current_value = *current_params.get(param_name).ok_or_else(|| {
            QuantRS2Error::InvalidInput(format!("Parameter {} not found", param_name))
        })?;

        let mut circuit_plus = circuit.clone();
        let mut params_plus = current_params.clone();
        params_plus.insert(
            param_name.to_string(),
            current_value + std::f64::consts::PI / 2.0,
        );
        circuit_plus.set_parameters(&params_plus)?;
        let loss_plus = cost_fn(&circuit_plus)?;

        let mut circuit_minus = circuit.clone();
        let mut params_minus = current_params.clone();
        params_minus.insert(
            param_name.to_string(),
            current_value - std::f64::consts::PI / 2.0,
        );
        circuit_minus.set_parameters(&params_minus)?;
        let loss_minus = cost_fn(&circuit_minus)?;

        Ok((loss_plus - loss_minus) / 2.0)
    }

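    /// SPSA-style gradient estimate for a single parameter using a random
    /// `±epsilon` perturbation.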
    fn spsa_gradient(
        &self,
        circuit: &VariationalCircuit,
        param_name: &str,
        cost_fn: &Arc<dyn Fn(&VariationalCircuit) -> QuantRS2Result<f64> + Send + Sync>,
        epsilon: f64,
    ) -> QuantRS2Result<f64> {
        use rand::{rngs::StdRng, Rng, SeedableRng};

        let mut rng = if let Some(seed) = self.config.seed {
            StdRng::seed_from_u64(seed)
        } else {
            StdRng::from_seed(rand::thread_rng().gen())
        };

        let current_params = circuit.get_parameters();
        let perturbation = if rng.gen::<bool>() { epsilon } else { -epsilon };

        let mut circuit_plus = circuit.clone();
        let mut params_plus = current_params.clone();
        for (name, value) in params_plus.iter_mut() {
            if name == param_name {
                *value += perturbation;
            }
        }
        circuit_plus.set_parameters(&params_plus)?;
        let loss_plus = cost_fn(&circuit_plus)?;

        let mut circuit_minus = circuit.clone();
        let mut params_minus = current_params.clone();
        for (name, value) in params_minus.iter_mut() {
            if name == param_name {
                *value -= perturbation;
            }
        }
        circuit_minus.set_parameters(&params_minus)?;
        let loss_minus = cost_fn(&circuit_minus)?;

        Ok((loss_plus - loss_minus) / (2.0 * perturbation))
    }

    fn clip_gradients(
        &self,
        mut gradients: FxHashMap<String, f64>,
        max_norm: f64,
    ) -> FxHashMap<String, f64> {
        let norm = gradients.values().map(|g| g * g).sum::<f64>().sqrt();

        if norm > max_norm {
            let scale = max_norm / norm;
            for grad in gradients.values_mut() {
                *grad *= scale;
            }
        }

        gradients
    }

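    /// Applies one parameter update step according to the configured method.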
    fn update_parameters(
        &mut self,
        circuit: &mut VariationalCircuit,
        gradients: &FxHashMap<String, f64>,
        state: &mut OptimizerState,
    ) -> QuantRS2Result<()> {
        let mut new_params = circuit.get_parameters();

        match &self.method {
            OptimizationMethod::GradientDescent { learning_rate } => {
                for (param_name, &grad) in gradients {
                    if let Some(value) = new_params.get_mut(param_name) {
                        *value -= learning_rate * grad;
                    }
                }
            }
            OptimizationMethod::Momentum {
                learning_rate,
                momentum,
            } => {
                for (param_name, &grad) in gradients {
                    let velocity = state.momentum.entry(param_name.clone()).or_insert(0.0);
                    *velocity = momentum * *velocity - learning_rate * grad;

                    if let Some(value) = new_params.get_mut(param_name) {
                        *value += *velocity;
                    }
                }
            }
            OptimizationMethod::Adam {
                learning_rate,
                beta1,
                beta2,
                epsilon,
            } => {
                // Bias-corrected learning rate for the current time step.
                let t = state.iteration as f64 + 1.0;
                let lr_t = learning_rate * (1.0 - beta2.powf(t)).sqrt() / (1.0 - beta1.powf(t));

                for (param_name, &grad) in gradients {
                    let m = state.adam_m.entry(param_name.clone()).or_insert(0.0);
                    let v = state.adam_v.entry(param_name.clone()).or_insert(0.0);

                    *m = beta1 * *m + (1.0 - beta1) * grad;
                    *v = beta2 * *v + (1.0 - beta2) * grad * grad;

                    if let Some(value) = new_params.get_mut(param_name) {
                        *value -= lr_t * *m / (v.sqrt() + epsilon);
                    }
                }
            }
            OptimizationMethod::RMSprop {
                learning_rate,
                decay_rate,
                epsilon,
            } => {
                for (param_name, &grad) in gradients {
                    let avg = state.rms_avg.entry(param_name.clone()).or_insert(0.0);
                    *avg = decay_rate * *avg + (1.0 - decay_rate) * grad * grad;

                    if let Some(value) = new_params.get_mut(param_name) {
                        *value -= learning_rate * grad / (avg.sqrt() + epsilon);
                    }
                }
            }
            OptimizationMethod::NaturalGradient {
                learning_rate,
                regularization,
            } => {
                let fisher_inv =
                    self.compute_fisher_inverse(circuit, gradients, *regularization)?;
                let natural_grad = self.apply_fisher_inverse(&fisher_inv, gradients);

                for (param_name, &nat_grad) in &natural_grad {
                    if let Some(value) = new_params.get_mut(param_name) {
                        *value -= learning_rate * nat_grad;
                    }
                }
            }
            OptimizationMethod::SPSA { a, alpha, .. } => {
                // Decaying SPSA step size: a_k = a / (k + 1)^alpha.
                let ak = a / (state.iteration as f64 + 1.0).powf(*alpha);

                for (param_name, &grad) in gradients {
                    if let Some(value) = new_params.get_mut(param_name) {
                        *value -= ak * grad;
                    }
                }
            }
            OptimizationMethod::QNSPSA {
                learning_rate,
                regularization,
                ..
            } => {
                let fisher_inv =
                    self.compute_fisher_inverse(circuit, gradients, *regularization)?;
                let natural_grad = self.apply_fisher_inverse(&fisher_inv, gradients);

                for (param_name, &nat_grad) in &natural_grad {
                    if let Some(value) = new_params.get_mut(param_name) {
                        *value -= learning_rate * nat_grad;
                    }
                }
            }
            _ => {
                return Err(QuantRS2Error::InvalidInput(
                    "Invalid optimization method".to_string(),
                ));
            }
        }

        circuit.set_parameters(&new_params)
    }

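    /// Builds a regularized Fisher matrix from the outer product of gradients
    /// and returns its inverse (computed analytically for 1x1 and 2x2).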
    fn compute_fisher_inverse(
        &self,
        circuit: &VariationalCircuit,
        gradients: &FxHashMap<String, f64>,
        regularization: f64,
    ) -> QuantRS2Result<Array2<f64>> {
        let param_names: Vec<_> = gradients.keys().cloned().collect();
        let n_params = param_names.len();

        // Reuse the cached inverse if the parameters have not moved far.
        if let Some(cache) = &self.fisher_cache {
            if let Some(cached_matrix) = cache.matrix.lock().unwrap().as_ref() {
                if let Some(cached_params) = cache.params.lock().unwrap().as_ref() {
                    let current_params: Vec<f64> = param_names
                        .iter()
                        .map(|name| circuit.get_parameters().get(name).copied().unwrap_or(0.0))
                        .collect();

                    let diff_norm: f64 = current_params
                        .iter()
                        .zip(cached_params.iter())
                        .map(|(a, b)| (a - b).powi(2))
                        .sum::<f64>()
                        .sqrt();

                    if diff_norm < cache.threshold {
                        return Ok(cached_matrix.clone());
                    }
                }
            }
        }

        // Approximate the Fisher information by the outer product of gradients.
        let mut fisher = Array2::zeros((n_params, n_params));

        for i in 0..n_params {
            for j in i..n_params {
                let value = gradients[&param_names[i]] * gradients[&param_names[j]];
                fisher[[i, j]] = value;
                fisher[[j, i]] = value;
            }
        }

        // Tikhonov regularization keeps the matrix invertible.
        for i in 0..n_params {
            fisher[[i, i]] += regularization;
        }

        let n = fisher.nrows();
        let mut fisher_inv = Array2::eye(n);

        if n == 1 {
            fisher_inv[[0, 0]] = 1.0 / fisher[[0, 0]];
        } else if n == 2 {
            let det = fisher[[0, 0]] * fisher[[1, 1]] - fisher[[0, 1]] * fisher[[1, 0]];
            if det.abs() < 1e-10 {
                return Err(QuantRS2Error::InvalidInput(
                    "Fisher matrix is singular".to_string(),
                ));
            }
            fisher_inv[[0, 0]] = fisher[[1, 1]] / det;
            fisher_inv[[0, 1]] = -fisher[[0, 1]] / det;
            fisher_inv[[1, 0]] = -fisher[[1, 0]] / det;
            fisher_inv[[1, 1]] = fisher[[0, 0]] / det;
        } else {
            // Larger matrices are not inverted here; the identity fallback
            // reduces the natural-gradient step to a plain gradient step.
        }

        if let Some(cache) = &self.fisher_cache {
            let current_params: Vec<f64> = param_names
                .iter()
                .map(|name| circuit.get_parameters().get(name).copied().unwrap_or(0.0))
                .collect();

            *cache.matrix.lock().unwrap() = Some(fisher_inv.clone());
            *cache.params.lock().unwrap() = Some(current_params);
        }

        Ok(fisher_inv)
    }

    fn apply_fisher_inverse(
        &self,
        fisher_inv: &Array2<f64>,
        gradients: &FxHashMap<String, f64>,
    ) -> FxHashMap<String, f64> {
        let param_names: Vec<_> = gradients.keys().cloned().collect();
        let grad_vec: Vec<f64> = param_names.iter().map(|name| gradients[name]).collect();

        let grad_array = Array1::from_vec(grad_vec);
        let natural_grad = fisher_inv.dot(&grad_array);

        let mut result = FxHashMap::default();
        for (i, name) in param_names.iter().enumerate() {
            result.insert(name.clone(), natural_grad[i]);
        }

        result
    }
}

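/// Result of a variational optimization run.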
#[derive(Debug, Clone)]
pub struct OptimizationResult {
    /// Optimized parameter values keyed by name.
    pub optimal_parameters: FxHashMap<String, f64>,
    /// Final loss value.
    pub final_loss: f64,
    /// Number of iterations performed.
    pub iterations: usize,
    /// Whether the optimization converged.
    pub converged: bool,
    /// Total optimization time in seconds.
    pub optimization_time: f64,
    /// Full optimization history.
    pub history: OptimizationHistory,
}

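/// Creates an L-BFGS optimizer with tight tolerances and gradient clipping,
/// suitable for VQE-style problems.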
pub fn create_vqe_optimizer() -> VariationalQuantumOptimizer {
    let config = OptimizationConfig {
        max_iterations: 200,
        f_tol: 1e-10,
        g_tol: 1e-10,
        parallel_gradients: true,
        grad_clip: Some(1.0),
        ..Default::default()
    };

    VariationalQuantumOptimizer::new(OptimizationMethod::LBFGS { memory_size: 10 }, config)
}

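/// Creates a BFGS optimizer with default tolerances for QAOA-style problems.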
pub fn create_qaoa_optimizer() -> VariationalQuantumOptimizer {
    let config = OptimizationConfig {
        max_iterations: 100,
        parallel_gradients: true,
        ..Default::default()
    };

    VariationalQuantumOptimizer::new(OptimizationMethod::BFGS, config)
}

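/// Creates a natural-gradient optimizer with the given learning rate and a
/// fixed Fisher-matrix regularization of 1e-4.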
pub fn create_natural_gradient_optimizer(learning_rate: f64) -> VariationalQuantumOptimizer {
    let config = OptimizationConfig {
        max_iterations: 100,
        parallel_gradients: true,
        ..Default::default()
    };

    VariationalQuantumOptimizer::new(
        OptimizationMethod::NaturalGradient {
            learning_rate,
            regularization: 1e-4,
        },
        config,
    )
}

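/// Creates an SPSA optimizer with the standard decay exponents
/// (alpha = 0.602, gamma = 0.101) and a fixed seed for reproducibility.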
pub fn create_spsa_optimizer() -> VariationalQuantumOptimizer {
    let config = OptimizationConfig {
        max_iterations: 500,
        seed: Some(42),
        ..Default::default()
    };

    VariationalQuantumOptimizer::new(
        OptimizationMethod::SPSA {
            a: 0.1,
            c: 0.1,
            alpha: 0.602,
            gamma: 0.101,
        },
        config,
    )
}

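/// Optimizer that enforces parameter constraints by adding quadratic penalty
/// terms to the cost function.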
pub struct ConstrainedVariationalOptimizer {
    /// Underlying unconstrained optimizer.
    base_optimizer: VariationalQuantumOptimizer,
    /// Constraints enforced via quadratic penalties.
    constraints: Vec<Constraint>,
}

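/// A single equality or inequality constraint on the circuit parameters.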
#[derive(Clone)]
pub struct Constraint {
    /// Function evaluating the constraint from the parameter map.
    pub function: Arc<dyn Fn(&FxHashMap<String, f64>) -> f64 + Send + Sync>,
    /// Whether this is an equality or inequality constraint.
    pub constraint_type: ConstraintType,
    /// Target value for the constraint.
    pub value: f64,
}

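/// Whether a constraint requires `f(params) == value` or `f(params) <= value`.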
#[derive(Debug, Clone, Copy)]
pub enum ConstraintType {
    /// Equality constraint: `f(params) == value`.
    Eq,
    /// Inequality constraint: `f(params) <= value`.
    Ineq,
}

impl ConstrainedVariationalOptimizer {
    /// Wraps an existing optimizer with constraint handling.
    pub fn new(base_optimizer: VariationalQuantumOptimizer) -> Self {
        Self {
            base_optimizer,
            constraints: Vec::new(),
        }
    }

    /// Adds an equality constraint `constraint_fn(params) == value`.
    pub fn add_equality_constraint(
        &mut self,
        constraint_fn: impl Fn(&FxHashMap<String, f64>) -> f64 + Send + Sync + 'static,
        value: f64,
    ) {
        self.constraints.push(Constraint {
            function: Arc::new(constraint_fn),
            constraint_type: ConstraintType::Eq,
            value,
        });
    }

    /// Adds an inequality constraint `constraint_fn(params) <= value`.
    pub fn add_inequality_constraint(
        &mut self,
        constraint_fn: impl Fn(&FxHashMap<String, f64>) -> f64 + Send + Sync + 'static,
        value: f64,
    ) {
        self.constraints.push(Constraint {
            function: Arc::new(constraint_fn),
            constraint_type: ConstraintType::Ineq,
            value,
        });
    }

    /// Optimizes the circuit, adding quadratic penalty terms for any violated
    /// constraints to the cost function.
    pub fn optimize(
        &mut self,
        circuit: &mut VariationalCircuit,
        cost_fn: impl Fn(&VariationalCircuit) -> QuantRS2Result<f64> + Send + Sync + 'static,
    ) -> QuantRS2Result<OptimizationResult> {
        if self.constraints.is_empty() {
            return self.base_optimizer.optimize(circuit, cost_fn);
        }

        let cost_fn = Arc::new(cost_fn);
        let constraints = self.constraints.clone();
        let penalty_weight = 1000.0;

        let penalized_cost = move |circuit: &VariationalCircuit| -> QuantRS2Result<f64> {
            let base_cost = cost_fn(circuit)?;
            let params = circuit.get_parameters();

            let mut penalty = 0.0;
            for constraint in &constraints {
                let constraint_value = (constraint.function)(&params);
                match constraint.constraint_type {
                    ConstraintType::Eq => {
                        penalty += penalty_weight * (constraint_value - constraint.value).powi(2);
                    }
                    ConstraintType::Ineq => {
                        if constraint_value > constraint.value {
                            penalty +=
                                penalty_weight * (constraint_value - constraint.value).powi(2);
                        }
                    }
                }
            }

            Ok(base_cost + penalty)
        };

        self.base_optimizer.optimize(circuit, penalized_cost)
    }
}

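/// Random-search hyperparameter optimizer: samples each hyperparameter
/// uniformly from its range and runs an inner variational optimization
/// for every trial.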
pub struct HyperparameterOptimizer {
    /// Search ranges `(min, max)` for each hyperparameter.
    search_space: FxHashMap<String, (f64, f64)>,
    /// Number of random trials to run.
    n_trials: usize,
    /// Optimization method used for the inner variational optimization.
    inner_method: OptimizationMethod,
}

impl HyperparameterOptimizer {
    /// Creates a random-search hyperparameter optimizer with `n_trials` trials.
    pub fn new(n_trials: usize) -> Self {
        Self {
            search_space: FxHashMap::default(),
            n_trials,
            inner_method: OptimizationMethod::BFGS,
        }
    }

    /// Adds a hyperparameter with a uniform search range `[min_value, max_value)`.
    pub fn add_hyperparameter(&mut self, name: String, min_value: f64, max_value: f64) {
        self.search_space.insert(name, (min_value, max_value));
    }

    /// Runs the random search, returning the best hyperparameters found along
    /// with the results of every trial.
    pub fn optimize(
        &self,
        circuit_builder: impl Fn(&FxHashMap<String, f64>) -> VariationalCircuit + Send + Sync,
        cost_fn: impl Fn(&VariationalCircuit) -> QuantRS2Result<f64> + Send + Sync + Clone + 'static,
    ) -> QuantRS2Result<HyperparameterResult> {
        use rand::{rngs::StdRng, Rng, SeedableRng};

        let mut rng = StdRng::from_seed(rand::thread_rng().gen());
        let mut best_hyperparams = FxHashMap::default();
        let mut best_loss = f64::INFINITY;
        let mut all_trials = Vec::new();

        for _trial in 0..self.n_trials {
            // Sample each hyperparameter uniformly from its search range.
            let mut hyperparams = FxHashMap::default();
            for (name, &(min_val, max_val)) in &self.search_space {
                let value = rng.gen_range(min_val..max_val);
                hyperparams.insert(name.clone(), value);
            }

            let mut circuit = circuit_builder(&hyperparams);

            let config = OptimizationConfig {
                max_iterations: 50,
                ..Default::default()
            };

            let mut optimizer =
                VariationalQuantumOptimizer::new(self.inner_method.clone(), config);

            let result = optimizer.optimize(&mut circuit, cost_fn.clone())?;

            all_trials.push(HyperparameterTrial {
                hyperparameters: hyperparams.clone(),
                final_loss: result.final_loss,
                optimal_parameters: result.optimal_parameters,
            });

            if result.final_loss < best_loss {
                best_loss = result.final_loss;
                best_hyperparams = hyperparams;
            }
        }

        Ok(HyperparameterResult {
            best_hyperparameters: best_hyperparams,
            best_loss,
            all_trials,
        })
    }
}

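/// Result of a hyperparameter search.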
#[derive(Debug, Clone)]
pub struct HyperparameterResult {
    /// Hyperparameters of the best trial.
    pub best_hyperparameters: FxHashMap<String, f64>,
    /// Loss achieved by the best trial.
    pub best_loss: f64,
    /// All trials, in the order they were run.
    pub all_trials: Vec<HyperparameterTrial>,
}

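/// A single hyperparameter trial with its optimized loss and parameters.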
#[derive(Debug, Clone)]
pub struct HyperparameterTrial {
    /// Hyperparameters used in this trial.
    pub hyperparameters: FxHashMap<String, f64>,
    /// Final loss of the inner optimization.
    pub final_loss: f64,
    /// Optimized circuit parameters for this trial.
    pub optimal_parameters: FxHashMap<String, f64>,
}

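// Clone is implemented here so the gradient routines can evaluate
// parameter-shifted copies of a circuit without mutating the original.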
impl Clone for VariationalCircuit {
    fn clone(&self) -> Self {
        Self {
            gates: self.gates.clone(),
            param_map: self.param_map.clone(),
            num_qubits: self.num_qubits,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::qubit::QubitId;
    use crate::variational::{DiffMode, VariationalGate};

    #[test]
    fn test_gradient_descent_optimizer() {
        let mut circuit = VariationalCircuit::new(1);
        circuit.add_gate(VariationalGate::rx(QubitId(0), "theta".to_string(), 0.0));

        let config = OptimizationConfig {
            max_iterations: 10,
            ..Default::default()
        };

        let mut optimizer = VariationalQuantumOptimizer::new(
            OptimizationMethod::GradientDescent { learning_rate: 0.1 },
            config,
        );

        // Minimize (theta - 1)^2, whose optimum is theta = 1.
        let cost_fn = |circuit: &VariationalCircuit| -> QuantRS2Result<f64> {
            let theta = circuit
                .get_parameters()
                .get("theta")
                .copied()
                .unwrap_or(0.0);
            Ok((theta - 1.0).powi(2))
        };

        let result = optimizer.optimize(&mut circuit, cost_fn).unwrap();

        assert!(result.converged || result.iterations == 10);
        assert!((result.optimal_parameters["theta"] - 1.0).abs() < 0.1);
    }

    #[test]
    fn test_adam_optimizer() {
        let mut circuit = VariationalCircuit::new(2);
        circuit.add_gate(VariationalGate::ry(QubitId(0), "alpha".to_string(), 0.5));
        circuit.add_gate(VariationalGate::rz(QubitId(1), "beta".to_string(), 0.5));

        let config = OptimizationConfig {
            max_iterations: 100,
            f_tol: 1e-6,
            g_tol: 1e-6,
            ..Default::default()
        };

        let mut optimizer = VariationalQuantumOptimizer::new(
            OptimizationMethod::Adam {
                learning_rate: 0.1,
                beta1: 0.9,
                beta2: 0.999,
                epsilon: 1e-8,
            },
            config,
        );

        // Minimize alpha^2 + beta^2, whose optimum is the origin.
        let cost_fn = |circuit: &VariationalCircuit| -> QuantRS2Result<f64> {
            let params = circuit.get_parameters();
            let alpha = params.get("alpha").copied().unwrap_or(0.0);
            let beta = params.get("beta").copied().unwrap_or(0.0);
            Ok(alpha.powi(2) + beta.powi(2))
        };

        let result = optimizer.optimize(&mut circuit, cost_fn).unwrap();

        assert!(result.optimal_parameters["alpha"].abs() < 0.1);
        assert!(result.optimal_parameters["beta"].abs() < 0.1);
    }

    #[test]
    fn test_constrained_optimization() {
        let mut circuit = VariationalCircuit::new(1);
        circuit.add_gate(VariationalGate::rx(QubitId(0), "x".to_string(), 2.0));

        let base_optimizer =
            VariationalQuantumOptimizer::new(OptimizationMethod::BFGS, Default::default());

        let mut constrained_opt = ConstrainedVariationalOptimizer::new(base_optimizer);

        // Require x >= 1, expressed as 1 - x <= 0.
        constrained_opt
            .add_inequality_constraint(|params| 1.0 - params.get("x").copied().unwrap_or(0.0), 0.0);

        // Minimize x^2 subject to x >= 1; the constrained optimum is x = 1.
        let cost_fn = |circuit: &VariationalCircuit| -> QuantRS2Result<f64> {
            let x = circuit.get_parameters().get("x").copied().unwrap_or(0.0);
            Ok(x.powi(2))
        };

        let result = constrained_opt.optimize(&mut circuit, cost_fn).unwrap();

        assert!((result.optimal_parameters["x"] - 1.0).abs() < 0.1);
    }
}