quantrs2_core/variational_optimization.rs

//! Enhanced variational parameter optimization using SciRS2
//!
//! This module provides advanced optimization techniques for variational quantum
//! algorithms, leveraging SciRS2's optimization capabilities, including:
//! - Gradient-based methods (BFGS, L-BFGS, Conjugate Gradient)
//! - Gradient-free methods (Nelder-Mead, Powell, COBYLA)
//! - Stochastic optimization (SPSA, Adam, RMSprop)
//! - Natural gradient descent for quantum circuits
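//!
//! # Example
//!
//! The sketch below illustrates the intended workflow. The crate paths and the
//! toy cost function are illustrative only (a real VQE cost would evaluate an
//! expectation value on a simulator or device), so the example is marked `ignore`.
//!
//! ```ignore
//! use quantrs2_core::error::QuantRS2Result;
//! use quantrs2_core::qubit::QubitId;
//! use quantrs2_core::variational::{VariationalCircuit, VariationalGate};
//! use quantrs2_core::variational_optimization::create_vqe_optimizer;
//!
//! let mut circuit = VariationalCircuit::new(1);
//! circuit.add_gate(VariationalGate::rx(QubitId(0), "theta".to_string(), 0.0));
//!
//! // Stand-in cost; a real cost function would evaluate ⟨ψ(θ)|H|ψ(θ)⟩ on a backend.
//! let cost_fn = |c: &VariationalCircuit| -> QuantRS2Result<f64> {
//!     let theta = c.get_parameters().get("theta").copied().unwrap_or(0.0);
//!     Ok((theta - 1.0).powi(2))
//! };
//!
//! let mut optimizer = create_vqe_optimizer();
//! let result = optimizer.optimize(&mut circuit, cost_fn)?;
//! println!("final loss = {:.3e}", result.final_loss);
//! ```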

use crate::{
    error::{QuantRS2Error, QuantRS2Result},
    variational::VariationalCircuit,
};
use ndarray::{Array1, Array2};
use rayon::prelude::*;
use rustc_hash::FxHashMap;
use std::sync::{Arc, Mutex};

// Import SciRS2 optimization
extern crate scirs2_optimize;
use scirs2_optimize::unconstrained::{minimize, Method, Options};

// Import SciRS2 linear algebra for natural gradient
extern crate scirs2_linalg;

/// Advanced optimizer for variational quantum circuits
pub struct VariationalQuantumOptimizer {
    /// Optimization method
    method: OptimizationMethod,
    /// Configuration
    config: OptimizationConfig,
    /// History of optimization
    history: OptimizationHistory,
    /// Fisher information matrix cache
    fisher_cache: Option<FisherCache>,
}

/// Optimization methods available
#[derive(Debug, Clone)]
pub enum OptimizationMethod {
    /// Standard gradient descent
    GradientDescent { learning_rate: f64 },
    /// Momentum-based gradient descent
    Momentum { learning_rate: f64, momentum: f64 },
    /// Adam optimizer
    Adam {
        learning_rate: f64,
        beta1: f64,
        beta2: f64,
        epsilon: f64,
    },
    /// RMSprop optimizer
    RMSprop {
        learning_rate: f64,
        decay_rate: f64,
        epsilon: f64,
    },
    /// Natural gradient descent
    NaturalGradient {
        learning_rate: f64,
        regularization: f64,
    },
    /// SciRS2 BFGS method
    BFGS,
    /// SciRS2 L-BFGS method
    LBFGS { memory_size: usize },
    /// SciRS2 Conjugate Gradient
    ConjugateGradient,
    /// SciRS2 Nelder-Mead simplex
    NelderMead,
    /// SciRS2 Powell's method
    Powell,
    /// Simultaneous Perturbation Stochastic Approximation
    SPSA {
        a: f64,
        c: f64,
        alpha: f64,
        gamma: f64,
    },
    /// Quantum Natural SPSA
    QNSPSA {
        learning_rate: f64,
        regularization: f64,
        spsa_epsilon: f64,
    },
}

/// Configuration for optimization
#[derive(Clone)]
pub struct OptimizationConfig {
    /// Maximum iterations
    pub max_iterations: usize,
    /// Function tolerance
    pub f_tol: f64,
    /// Gradient tolerance
    pub g_tol: f64,
    /// Parameter tolerance
    pub x_tol: f64,
    /// Enable parallel gradient computation
    pub parallel_gradients: bool,
    /// Batch size for stochastic methods
    pub batch_size: Option<usize>,
    /// Random seed
    pub seed: Option<u64>,
    /// Callback invoked after each iteration with (parameters, loss)
    pub callback: Option<Arc<dyn Fn(&[f64], f64) + Send + Sync>>,
    /// Early stopping patience
    pub patience: Option<usize>,
    /// Maximum gradient norm for clipping
    pub grad_clip: Option<f64>,
}

impl Default for OptimizationConfig {
    fn default() -> Self {
        Self {
            max_iterations: 100,
            f_tol: 1e-8,
            g_tol: 1e-8,
            x_tol: 1e-8,
            parallel_gradients: true,
            batch_size: None,
            seed: None,
            callback: None,
            patience: None,
            grad_clip: None,
        }
    }
}
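
// Illustrative sketch: building a configuration on top of the defaults with a
// logging callback and gradient clipping (the values here are arbitrary, not
// recommendations):
//
//     let config = OptimizationConfig {
//         max_iterations: 300,
//         grad_clip: Some(1.0),
//         callback: Some(Arc::new(|params: &[f64], loss: f64| {
//             println!("loss = {:.6} over {} parameters", loss, params.len());
//         })),
//         ..Default::default()
//     };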

/// Optimization history tracking
#[derive(Debug, Clone)]
pub struct OptimizationHistory {
    /// Parameter values at each iteration
    pub parameters: Vec<Vec<f64>>,
    /// Loss values
    pub loss_values: Vec<f64>,
    /// Gradient norms
    pub gradient_norms: Vec<f64>,
    /// Iteration times (ms)
    pub iteration_times: Vec<f64>,
    /// Total iterations
    pub total_iterations: usize,
    /// Converged flag
    pub converged: bool,
}

impl OptimizationHistory {
    fn new() -> Self {
        Self {
            parameters: Vec::new(),
            loss_values: Vec::new(),
            gradient_norms: Vec::new(),
            iteration_times: Vec::new(),
            total_iterations: 0,
            converged: false,
        }
    }
}

/// Fisher information matrix cache
struct FisherCache {
    /// Cached Fisher matrix inverse
    matrix: Arc<Mutex<Option<Array2<f64>>>>,
    /// Parameters at which the cached matrix was computed
    params: Arc<Mutex<Option<Vec<f64>>>>,
    /// Cache validity threshold
    threshold: f64,
}

/// Optimizer state for stateful methods
struct OptimizerState {
    /// Momentum velocities (per parameter)
    momentum: FxHashMap<String, f64>,
    /// Adam first moment
    adam_m: FxHashMap<String, f64>,
    /// Adam second moment
    adam_v: FxHashMap<String, f64>,
    /// RMSprop moving average
    rms_avg: FxHashMap<String, f64>,
    /// Iteration counter
    iteration: usize,
}

impl VariationalQuantumOptimizer {
    /// Create a new optimizer
    pub fn new(method: OptimizationMethod, config: OptimizationConfig) -> Self {
        let fisher_cache = match &method {
            OptimizationMethod::NaturalGradient { .. } | OptimizationMethod::QNSPSA { .. } => {
                Some(FisherCache {
                    matrix: Arc::new(Mutex::new(None)),
                    params: Arc::new(Mutex::new(None)),
                    threshold: 1e-3,
                })
            }
            _ => None,
        };

        Self {
            method,
            config,
            history: OptimizationHistory::new(),
            fisher_cache,
        }
    }

    /// Optimize a variational circuit
    pub fn optimize(
        &mut self,
        circuit: &mut VariationalCircuit,
        cost_fn: impl Fn(&VariationalCircuit) -> QuantRS2Result<f64> + Send + Sync + 'static,
    ) -> QuantRS2Result<OptimizationResult> {
        let cost_fn = Arc::new(cost_fn);

        match &self.method {
            OptimizationMethod::BFGS
            | OptimizationMethod::LBFGS { .. }
            | OptimizationMethod::ConjugateGradient
            | OptimizationMethod::NelderMead
            | OptimizationMethod::Powell => self.optimize_with_scirs2(circuit, cost_fn),
            _ => self.optimize_custom(circuit, cost_fn),
        }
    }

    /// Optimize using SciRS2 methods
    fn optimize_with_scirs2(
        &mut self,
        circuit: &mut VariationalCircuit,
        cost_fn: Arc<dyn Fn(&VariationalCircuit) -> QuantRS2Result<f64> + Send + Sync>,
    ) -> QuantRS2Result<OptimizationResult> {
        let param_names = circuit.parameter_names();
        let initial_params: Vec<f64> = param_names
            .iter()
            .map(|name| circuit.get_parameters().get(name).copied().unwrap_or(0.0))
            .collect();

        let circuit_clone = Arc::new(Mutex::new(circuit.clone()));
        let param_names_clone = param_names.clone();

        // Create objective function for SciRS2
        let objective = move |params: &ndarray::ArrayView1<f64>| -> f64 {
            let params_slice = params.as_slice().unwrap();
            let mut param_map = FxHashMap::default();
            for (name, &value) in param_names_clone.iter().zip(params_slice) {
                param_map.insert(name.clone(), value);
            }

            let mut circuit = circuit_clone.lock().unwrap();
            if circuit.set_parameters(&param_map).is_err() {
                return f64::INFINITY;
            }

            match cost_fn(&*circuit) {
                Ok(loss) => loss,
                Err(_) => f64::INFINITY,
            }
        };

        // Set up SciRS2 method
        let method = match &self.method {
            OptimizationMethod::BFGS => Method::BFGS,
            OptimizationMethod::LBFGS { memory_size: _ } => Method::LBFGS,
            OptimizationMethod::ConjugateGradient => Method::BFGS, // Use BFGS as fallback
            OptimizationMethod::NelderMead => Method::NelderMead,
            OptimizationMethod::Powell => Method::Powell,
            _ => unreachable!(),
        };

        // Configure options
        let options = Options {
            max_iter: self.config.max_iterations,
            ftol: self.config.f_tol,
            gtol: self.config.g_tol,
            xtol: self.config.x_tol,
            ..Default::default()
        };

        // Run optimization
        let start_time = std::time::Instant::now();
        let result = minimize(objective, &initial_params, method, Some(options))
            .map_err(|e| QuantRS2Error::InvalidInput(format!("Optimization failed: {:?}", e)))?;

        // Update circuit with optimal parameters
        let mut final_params = FxHashMap::default();
        for (name, &value) in param_names.iter().zip(result.x.as_slice().unwrap()) {
            final_params.insert(name.clone(), value);
        }
        circuit.set_parameters(&final_params)?;

        // Update history
        self.history.parameters.push(result.x.to_vec());
        self.history.loss_values.push(result.fun);
        self.history.total_iterations = result.iterations;
        self.history.converged = result.success;

        Ok(OptimizationResult {
            optimal_parameters: final_params,
            final_loss: result.fun,
            iterations: result.iterations,
            converged: result.success,
            optimization_time: start_time.elapsed().as_secs_f64(),
            history: self.history.clone(),
        })
    }

    /// Optimize using custom methods
    fn optimize_custom(
        &mut self,
        circuit: &mut VariationalCircuit,
        cost_fn: Arc<dyn Fn(&VariationalCircuit) -> QuantRS2Result<f64> + Send + Sync>,
    ) -> QuantRS2Result<OptimizationResult> {
        let mut state = OptimizerState {
            momentum: FxHashMap::default(),
            adam_m: FxHashMap::default(),
            adam_v: FxHashMap::default(),
            rms_avg: FxHashMap::default(),
            iteration: 0,
        };

        let param_names = circuit.parameter_names();
        let start_time = std::time::Instant::now();
        let mut best_loss = f64::INFINITY;
        let mut patience_counter = 0;

        for iter in 0..self.config.max_iterations {
            let iter_start = std::time::Instant::now();

            // Compute loss
            let loss = cost_fn(circuit)?;

            // Check for improvement
            if loss < best_loss - self.config.f_tol {
                best_loss = loss;
                patience_counter = 0;
            } else if let Some(patience) = self.config.patience {
                patience_counter += 1;
                if patience_counter >= patience {
                    self.history.converged = true;
                    break;
                }
            }

            // Compute gradients
            let gradients = self.compute_gradients(circuit, &cost_fn)?;

            // Clip gradients if requested
            let gradients = if let Some(max_norm) = self.config.grad_clip {
                self.clip_gradients(gradients, max_norm)
            } else {
                gradients
            };

            // Update parameters based on method
            self.update_parameters(circuit, &gradients, &mut state)?;

            // Update history
            let current_params: Vec<f64> = param_names
                .iter()
                .map(|name| circuit.get_parameters().get(name).copied().unwrap_or(0.0))
                .collect();

            let grad_norm = gradients.values().map(|g| g * g).sum::<f64>().sqrt();

            self.history.parameters.push(current_params);
            self.history.loss_values.push(loss);
            self.history.gradient_norms.push(grad_norm);
            self.history
                .iteration_times
                .push(iter_start.elapsed().as_secs_f64() * 1000.0);
            self.history.total_iterations = iter + 1;

            // Callback
            if let Some(callback) = &self.config.callback {
                let params: Vec<f64> = param_names
                    .iter()
                    .map(|name| circuit.get_parameters().get(name).copied().unwrap_or(0.0))
                    .collect();
                callback(&params, loss);
            }

            // Check convergence
            if grad_norm < self.config.g_tol {
                self.history.converged = true;
                break;
            }

            state.iteration += 1;
        }

        let final_params = circuit.get_parameters();
        let final_loss = cost_fn(circuit)?;

        Ok(OptimizationResult {
            optimal_parameters: final_params,
            final_loss,
            iterations: self.history.total_iterations,
            converged: self.history.converged,
            optimization_time: start_time.elapsed().as_secs_f64(),
            history: self.history.clone(),
        })
    }

    /// Compute gradients for all parameters
    fn compute_gradients(
        &self,
        circuit: &VariationalCircuit,
        cost_fn: &Arc<dyn Fn(&VariationalCircuit) -> QuantRS2Result<f64> + Send + Sync>,
    ) -> QuantRS2Result<FxHashMap<String, f64>> {
        let param_names = circuit.parameter_names();

        if self.config.parallel_gradients {
            // Parallel gradient computation
            let gradients: Vec<(String, f64)> = param_names
                .par_iter()
                .map(|param_name| {
                    let grad = self
                        .compute_single_gradient(circuit, param_name, cost_fn)
                        .unwrap_or(0.0);
                    (param_name.clone(), grad)
                })
                .collect();

            Ok(gradients.into_iter().collect())
        } else {
            // Sequential gradient computation
            let mut gradients = FxHashMap::default();
            for param_name in &param_names {
                let grad = self.compute_single_gradient(circuit, param_name, cost_fn)?;
                gradients.insert(param_name.clone(), grad);
            }
            Ok(gradients)
        }
    }

    /// Compute gradient for a single parameter
    fn compute_single_gradient(
        &self,
        circuit: &VariationalCircuit,
        param_name: &str,
        cost_fn: &Arc<dyn Fn(&VariationalCircuit) -> QuantRS2Result<f64> + Send + Sync>,
    ) -> QuantRS2Result<f64> {
        match &self.method {
            OptimizationMethod::SPSA { c, .. } => {
                // SPSA gradient approximation
                self.spsa_gradient(circuit, param_name, cost_fn, *c)
            }
            _ => {
                // Parameter shift rule
                self.parameter_shift_gradient(circuit, param_name, cost_fn)
            }
        }
    }

    /// Parameter shift rule gradient
    fn parameter_shift_gradient(
        &self,
        circuit: &VariationalCircuit,
        param_name: &str,
        cost_fn: &Arc<dyn Fn(&VariationalCircuit) -> QuantRS2Result<f64> + Send + Sync>,
    ) -> QuantRS2Result<f64> {
        let current_params = circuit.get_parameters();
        let current_value = *current_params.get(param_name).ok_or_else(|| {
            QuantRS2Error::InvalidInput(format!("Parameter {} not found", param_name))
        })?;

        // Shift parameter by +π/2
        let mut circuit_plus = circuit.clone();
        let mut params_plus = current_params.clone();
        params_plus.insert(
            param_name.to_string(),
            current_value + std::f64::consts::PI / 2.0,
        );
        circuit_plus.set_parameters(&params_plus)?;
        let loss_plus = cost_fn(&circuit_plus)?;

        // Shift parameter by -π/2
        let mut circuit_minus = circuit.clone();
        let mut params_minus = current_params.clone();
        params_minus.insert(
            param_name.to_string(),
            current_value - std::f64::consts::PI / 2.0,
        );
        circuit_minus.set_parameters(&params_minus)?;
        let loss_minus = cost_fn(&circuit_minus)?;

        Ok((loss_plus - loss_minus) / 2.0)
    }
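
    // Note: this implements the standard two-term parameter-shift rule for
    // rotation gates whose generator has eigenvalues ±1/2 (RX, RY, RZ):
    //
    //     dL/dθ = [L(θ + π/2) − L(θ − π/2)] / 2
    //
    // Gates with a different generator spectrum would need a different shift
    // or a multi-term rule, which is not handled here.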

    /// SPSA gradient approximation
    fn spsa_gradient(
        &self,
        circuit: &VariationalCircuit,
        param_name: &str,
        cost_fn: &Arc<dyn Fn(&VariationalCircuit) -> QuantRS2Result<f64> + Send + Sync>,
        epsilon: f64,
    ) -> QuantRS2Result<f64> {
        use rand::{rngs::StdRng, Rng, SeedableRng};

        let mut rng = if let Some(seed) = self.config.seed {
            StdRng::seed_from_u64(seed)
        } else {
            StdRng::from_seed(rand::rng().random())
        };

        let current_params = circuit.get_parameters();
        let perturbation = if rng.random::<bool>() {
            epsilon
        } else {
            -epsilon
        };

        // Positive perturbation
        let mut circuit_plus = circuit.clone();
        let mut params_plus = current_params.clone();
        for (name, value) in params_plus.iter_mut() {
            if name == param_name {
                *value += perturbation;
            }
        }
        circuit_plus.set_parameters(&params_plus)?;
        let loss_plus = cost_fn(&circuit_plus)?;

        // Negative perturbation
        let mut circuit_minus = circuit.clone();
        let mut params_minus = current_params.clone();
        for (name, value) in params_minus.iter_mut() {
            if name == param_name {
                *value -= perturbation;
            }
        }
        circuit_minus.set_parameters(&params_minus)?;
        let loss_minus = cost_fn(&circuit_minus)?;

        Ok((loss_plus - loss_minus) / (2.0 * perturbation))
    }
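
    // Note: canonical SPSA perturbs *all* parameters simultaneously with a
    // single random ±c sign vector and reuses the same two cost evaluations
    // for every coordinate. The per-parameter variant above reduces to a
    // randomized central difference and costs two evaluations per parameter.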

    /// Clip gradients by norm
    fn clip_gradients(
        &self,
        mut gradients: FxHashMap<String, f64>,
        max_norm: f64,
    ) -> FxHashMap<String, f64> {
        let norm = gradients.values().map(|g| g * g).sum::<f64>().sqrt();

        if norm > max_norm {
            let scale = max_norm / norm;
            for grad in gradients.values_mut() {
                *grad *= scale;
            }
        }

        gradients
    }

    /// Update parameters based on optimization method
    fn update_parameters(
        &mut self,
        circuit: &mut VariationalCircuit,
        gradients: &FxHashMap<String, f64>,
        state: &mut OptimizerState,
    ) -> QuantRS2Result<()> {
        let mut new_params = circuit.get_parameters();

        match &self.method {
            OptimizationMethod::GradientDescent { learning_rate } => {
                // Simple gradient descent
                for (param_name, &grad) in gradients {
                    if let Some(value) = new_params.get_mut(param_name) {
                        *value -= learning_rate * grad;
                    }
                }
            }
            OptimizationMethod::Momentum {
                learning_rate,
                momentum,
            } => {
                // Momentum-based gradient descent
                for (param_name, &grad) in gradients {
                    let velocity = state.momentum.entry(param_name.clone()).or_insert(0.0);
                    *velocity = momentum * *velocity - learning_rate * grad;

                    if let Some(value) = new_params.get_mut(param_name) {
                        *value += *velocity;
                    }
                }
            }
            OptimizationMethod::Adam {
                learning_rate,
                beta1,
                beta2,
                epsilon,
            } => {
                // Adam optimizer
                let t = state.iteration as f64 + 1.0;
                let lr_t = learning_rate * (1.0 - beta2.powf(t)).sqrt() / (1.0 - beta1.powf(t));

                for (param_name, &grad) in gradients {
                    let m = state.adam_m.entry(param_name.clone()).or_insert(0.0);
                    let v = state.adam_v.entry(param_name.clone()).or_insert(0.0);

                    *m = beta1 * *m + (1.0 - beta1) * grad;
                    *v = beta2 * *v + (1.0 - beta2) * grad * grad;

                    if let Some(value) = new_params.get_mut(param_name) {
                        *value -= lr_t * *m / (v.sqrt() + epsilon);
                    }
                }
            }
            OptimizationMethod::RMSprop {
                learning_rate,
                decay_rate,
                epsilon,
            } => {
                // RMSprop optimizer
                for (param_name, &grad) in gradients {
                    let avg = state.rms_avg.entry(param_name.clone()).or_insert(0.0);
                    *avg = decay_rate * *avg + (1.0 - decay_rate) * grad * grad;

                    if let Some(value) = new_params.get_mut(param_name) {
                        *value -= learning_rate * grad / (avg.sqrt() + epsilon);
                    }
                }
            }
            OptimizationMethod::NaturalGradient {
                learning_rate,
                regularization,
            } => {
                // Natural gradient descent
                let fisher_inv =
                    self.compute_fisher_inverse(circuit, gradients, *regularization)?;
                let natural_grad = self.apply_fisher_inverse(&fisher_inv, gradients);

                for (param_name, &nat_grad) in &natural_grad {
                    if let Some(value) = new_params.get_mut(param_name) {
                        *value -= learning_rate * nat_grad;
                    }
                }
            }
            OptimizationMethod::SPSA { a, alpha, .. } => {
                // SPSA parameter update
                let ak = a / (state.iteration as f64 + 1.0).powf(*alpha);

                for (param_name, &grad) in gradients {
                    if let Some(value) = new_params.get_mut(param_name) {
                        *value -= ak * grad;
                    }
                }
            }
            OptimizationMethod::QNSPSA {
                learning_rate,
                regularization,
                ..
            } => {
                // Quantum Natural SPSA
                let fisher_inv =
                    self.compute_fisher_inverse(circuit, gradients, *regularization)?;
                let natural_grad = self.apply_fisher_inverse(&fisher_inv, gradients);

                for (param_name, &nat_grad) in &natural_grad {
                    if let Some(value) = new_params.get_mut(param_name) {
                        *value -= learning_rate * nat_grad;
                    }
                }
            }
            _ => {
                // SciRS2 methods are dispatched to optimize_with_scirs2 and should never reach here
                return Err(QuantRS2Error::InvalidInput(
                    "Invalid optimization method".to_string(),
                ));
            }
        }

        circuit.set_parameters(&new_params)
    }
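
    // For reference, the Adam branch above applies the usual bias-corrected
    // update with the correction folded into the step size:
    //
    //     m_t = β1·m_{t−1} + (1 − β1)·g_t
    //     v_t = β2·v_{t−1} + (1 − β2)·g_t²
    //     α_t = α·√(1 − β2^t) / (1 − β1^t)
    //     θ_t = θ_{t−1} − α_t·m_t / (√v_t + ε)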

    /// Compute Fisher information matrix inverse
    fn compute_fisher_inverse(
        &self,
        circuit: &VariationalCircuit,
        gradients: &FxHashMap<String, f64>,
        regularization: f64,
    ) -> QuantRS2Result<Array2<f64>> {
        let param_names: Vec<_> = gradients.keys().cloned().collect();
        let n_params = param_names.len();

        // Check cache
        if let Some(cache) = &self.fisher_cache {
            if let Some(cached_matrix) = cache.matrix.lock().unwrap().as_ref() {
                if let Some(cached_params) = cache.params.lock().unwrap().as_ref() {
                    let current_params: Vec<f64> = param_names
                        .iter()
                        .map(|name| circuit.get_parameters().get(name).copied().unwrap_or(0.0))
                        .collect();

                    let diff_norm: f64 = current_params
                        .iter()
                        .zip(cached_params.iter())
                        .map(|(a, b)| (a - b).powi(2))
                        .sum::<f64>()
                        .sqrt();

                    if diff_norm < cache.threshold {
                        return Ok(cached_matrix.clone());
                    }
                }
            }
        }

        // Compute Fisher information matrix
        let mut fisher = Array2::zeros((n_params, n_params));

        // Simplified Fisher matrix computation
        // In practice, this would involve quantum state overlaps
        for i in 0..n_params {
            for j in i..n_params {
                // Approximation: use gradient outer product
                let value = gradients[&param_names[i]] * gradients[&param_names[j]];
                fisher[[i, j]] = value;
                fisher[[j, i]] = value;
            }
        }

        // Add regularization
        for i in 0..n_params {
            fisher[[i, i]] += regularization;
        }

        // Compute the inverse
        // TODO: Use ndarray-linalg when trait import issues are resolved
        let n = fisher.nrows();
        let mut fisher_inv = Array2::eye(n);

        // Placeholder inversion: closed form for the 1x1 and 2x2 cases,
        // identity otherwise. In practice, this should use a proper numerical solver.
        if n == 1 {
            fisher_inv[[0, 0]] = 1.0 / fisher[[0, 0]];
        } else if n == 2 {
            let det = fisher[[0, 0]] * fisher[[1, 1]] - fisher[[0, 1]] * fisher[[1, 0]];
            if det.abs() < 1e-10 {
                return Err(QuantRS2Error::InvalidInput(
                    "Fisher matrix is singular".to_string(),
                ));
            }
            fisher_inv[[0, 0]] = fisher[[1, 1]] / det;
            fisher_inv[[0, 1]] = -fisher[[0, 1]] / det;
            fisher_inv[[1, 0]] = -fisher[[1, 0]] / det;
            fisher_inv[[1, 1]] = fisher[[0, 0]] / det;
        } else {
            // For larger matrices, return identity as placeholder
            // TODO: Implement proper inversion
        }

        // Update cache
        if let Some(cache) = &self.fisher_cache {
            let current_params: Vec<f64> = param_names
                .iter()
                .map(|name| circuit.get_parameters().get(name).copied().unwrap_or(0.0))
                .collect();

            *cache.matrix.lock().unwrap() = Some(fisher_inv.clone());
            *cache.params.lock().unwrap() = Some(current_params);
        }

        Ok(fisher_inv)
    }
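
    // Note: the gradient outer product above is the empirical-Fisher
    // approximation F ≈ g·gᵀ + λI rather than the quantum Fisher information
    // (which requires state overlaps), and the explicit inverse is exact only
    // for the 1×1 and 2×2 cases handled above.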

    /// Apply Fisher information matrix inverse to gradients
    fn apply_fisher_inverse(
        &self,
        fisher_inv: &Array2<f64>,
        gradients: &FxHashMap<String, f64>,
    ) -> FxHashMap<String, f64> {
        let param_names: Vec<_> = gradients.keys().cloned().collect();
        let grad_vec: Vec<f64> = param_names.iter().map(|name| gradients[name]).collect();

        let grad_array = Array1::from_vec(grad_vec);
        let natural_grad = fisher_inv.dot(&grad_array);

        let mut result = FxHashMap::default();
        for (i, name) in param_names.iter().enumerate() {
            result.insert(name.clone(), natural_grad[i]);
        }

        result
    }
}

/// Optimization result
#[derive(Debug, Clone)]
pub struct OptimizationResult {
    /// Optimal parameters
    pub optimal_parameters: FxHashMap<String, f64>,
    /// Final loss value
    pub final_loss: f64,
    /// Number of iterations
    pub iterations: usize,
    /// Whether optimization converged
    pub converged: bool,
    /// Total optimization time (seconds)
    pub optimization_time: f64,
    /// Full optimization history
    pub history: OptimizationHistory,
}

/// Create an optimizer preconfigured for VQE
pub fn create_vqe_optimizer() -> VariationalQuantumOptimizer {
    let config = OptimizationConfig {
        max_iterations: 200,
        f_tol: 1e-10,
        g_tol: 1e-10,
        parallel_gradients: true,
        grad_clip: Some(1.0),
        ..Default::default()
    };

    VariationalQuantumOptimizer::new(OptimizationMethod::LBFGS { memory_size: 10 }, config)
}

/// Create an optimizer preconfigured for QAOA
pub fn create_qaoa_optimizer() -> VariationalQuantumOptimizer {
    let config = OptimizationConfig {
        max_iterations: 100,
        parallel_gradients: true,
        ..Default::default()
    };

    VariationalQuantumOptimizer::new(OptimizationMethod::BFGS, config)
}

/// Create a natural gradient descent optimizer
pub fn create_natural_gradient_optimizer(learning_rate: f64) -> VariationalQuantumOptimizer {
    let config = OptimizationConfig {
        max_iterations: 100,
        parallel_gradients: true,
        ..Default::default()
    };

    VariationalQuantumOptimizer::new(
        OptimizationMethod::NaturalGradient {
            learning_rate,
            regularization: 1e-4,
        },
        config,
    )
}

/// Create an SPSA optimizer for noisy quantum devices
pub fn create_spsa_optimizer() -> VariationalQuantumOptimizer {
    let config = OptimizationConfig {
        max_iterations: 500,
        seed: Some(42),
        ..Default::default()
    };

    VariationalQuantumOptimizer::new(
        OptimizationMethod::SPSA {
            a: 0.1,
            c: 0.1,
            alpha: 0.602,
            gamma: 0.101,
        },
        config,
    )
}
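
// The exponents above (alpha = 0.602, gamma = 0.101) are Spall's commonly
// recommended values for the SPSA gain sequences a_k = a/(k+1)^alpha and
// c_k = c/(k+1)^gamma. Note that this implementation decays the step size a_k
// but currently uses a fixed perturbation size `c`; `a` and `c` usually need
// problem-specific tuning.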

/// Constrained optimization for variational circuits
pub struct ConstrainedVariationalOptimizer {
    /// Base optimizer
    base_optimizer: VariationalQuantumOptimizer,
    /// Constraints
    constraints: Vec<Constraint>,
}

/// Constraint for optimization
#[derive(Clone)]
pub struct Constraint {
    /// Constraint function
    pub function: Arc<dyn Fn(&FxHashMap<String, f64>) -> f64 + Send + Sync>,
    /// Constraint type
    pub constraint_type: ConstraintType,
    /// Constraint value
    pub value: f64,
}

/// Constraint type
#[derive(Debug, Clone, Copy)]
pub enum ConstraintType {
    /// Equality constraint: function(params) == value
    Eq,
    /// Inequality constraint: function(params) <= value
    Ineq,
}

impl ConstrainedVariationalOptimizer {
    /// Create a new constrained optimizer
    pub fn new(base_optimizer: VariationalQuantumOptimizer) -> Self {
        Self {
            base_optimizer,
            constraints: Vec::new(),
        }
    }

    /// Add an equality constraint: constraint_fn(params) == value
    pub fn add_equality_constraint(
        &mut self,
        constraint_fn: impl Fn(&FxHashMap<String, f64>) -> f64 + Send + Sync + 'static,
        value: f64,
    ) {
        self.constraints.push(Constraint {
            function: Arc::new(constraint_fn),
            constraint_type: ConstraintType::Eq,
            value,
        });
    }

    /// Add an inequality constraint: constraint_fn(params) <= value
    pub fn add_inequality_constraint(
        &mut self,
        constraint_fn: impl Fn(&FxHashMap<String, f64>) -> f64 + Send + Sync + 'static,
        value: f64,
    ) {
        self.constraints.push(Constraint {
            function: Arc::new(constraint_fn),
            constraint_type: ConstraintType::Ineq,
            value,
        });
    }

    /// Optimize with constraints
    pub fn optimize(
        &mut self,
        circuit: &mut VariationalCircuit,
        cost_fn: impl Fn(&VariationalCircuit) -> QuantRS2Result<f64> + Send + Sync + 'static,
    ) -> QuantRS2Result<OptimizationResult> {
        if self.constraints.is_empty() {
            return self.base_optimizer.optimize(circuit, cost_fn);
        }

        // For constrained optimization, use a quadratic penalty method
        let cost_fn = Arc::new(cost_fn);
        let constraints = self.constraints.clone();
        let penalty_weight = 1000.0;

        let penalized_cost = move |circuit: &VariationalCircuit| -> QuantRS2Result<f64> {
            let base_cost = cost_fn(circuit)?;
            let params = circuit.get_parameters();

            let mut penalty = 0.0;
            for constraint in &constraints {
                let constraint_value = (constraint.function)(&params);
                match constraint.constraint_type {
                    ConstraintType::Eq => {
                        penalty += penalty_weight * (constraint_value - constraint.value).powi(2);
                    }
                    ConstraintType::Ineq => {
                        if constraint_value > constraint.value {
                            penalty +=
                                penalty_weight * (constraint_value - constraint.value).powi(2);
                        }
                    }
                }
            }

            Ok(base_cost + penalty)
        };

        self.base_optimizer.optimize(circuit, penalized_cost)
    }
}
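
// Illustrative usage sketch (the parameter names are placeholders; `circuit`
// and `cost_fn` come from the surrounding application code):
//
//     let base = VariationalQuantumOptimizer::new(OptimizationMethod::BFGS, Default::default());
//     let mut opt = ConstrainedVariationalOptimizer::new(base);
//     // Soft-enforce theta0 + theta1 == PI via the quadratic penalty.
//     opt.add_equality_constraint(
//         |params| {
//             params.get("theta0").copied().unwrap_or(0.0)
//                 + params.get("theta1").copied().unwrap_or(0.0)
//         },
//         std::f64::consts::PI,
//     );
//     let result = opt.optimize(&mut circuit, cost_fn)?;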

/// Hyperparameter optimization for variational circuits
pub struct HyperparameterOptimizer {
    /// Search space for hyperparameters
    search_space: FxHashMap<String, (f64, f64)>,
    /// Number of trials
    n_trials: usize,
    /// Optimization method for inner loop
    inner_method: OptimizationMethod,
}

impl HyperparameterOptimizer {
    /// Create a new hyperparameter optimizer
    pub fn new(n_trials: usize) -> Self {
        Self {
            search_space: FxHashMap::default(),
            n_trials,
            inner_method: OptimizationMethod::BFGS,
        }
    }

    /// Add a hyperparameter to search
    pub fn add_hyperparameter(&mut self, name: String, min_value: f64, max_value: f64) {
        self.search_space.insert(name, (min_value, max_value));
    }

    /// Optimize hyperparameters by random search over the registered ranges
    pub fn optimize(
        &self,
        circuit_builder: impl Fn(&FxHashMap<String, f64>) -> VariationalCircuit + Send + Sync,
        cost_fn: impl Fn(&VariationalCircuit) -> QuantRS2Result<f64> + Send + Sync + Clone + 'static,
    ) -> QuantRS2Result<HyperparameterResult> {
        use rand::{rngs::StdRng, Rng, SeedableRng};

        let mut rng = StdRng::from_seed(rand::rng().random());
        let mut best_hyperparams = FxHashMap::default();
        let mut best_loss = f64::INFINITY;
        let mut all_trials = Vec::new();

        for _trial in 0..self.n_trials {
            // Sample hyperparameters
            let mut hyperparams = FxHashMap::default();
            for (name, &(min_val, max_val)) in &self.search_space {
                let value = rng.random_range(min_val..max_val);
                hyperparams.insert(name.clone(), value);
            }

            // Build circuit with hyperparameters
            let mut circuit = circuit_builder(&hyperparams);

            // Optimize circuit
            let config = OptimizationConfig {
                max_iterations: 50,
                ..Default::default()
            };

            let mut optimizer = VariationalQuantumOptimizer::new(self.inner_method.clone(), config);

            let result = optimizer.optimize(&mut circuit, cost_fn.clone())?;

            all_trials.push(HyperparameterTrial {
                hyperparameters: hyperparams.clone(),
                final_loss: result.final_loss,
                optimal_parameters: result.optimal_parameters,
            });

            if result.final_loss < best_loss {
                best_loss = result.final_loss;
                best_hyperparams = hyperparams;
            }
        }

        Ok(HyperparameterResult {
            best_hyperparameters: best_hyperparams,
            best_loss,
            all_trials,
        })
    }
}
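
// Illustrative usage sketch (the hyperparameter name, its range, and the
// `build_ansatz` helper are placeholders for user code that maps sampled
// hyperparameters to a circuit):
//
//     let mut tuner = HyperparameterOptimizer::new(20);
//     tuner.add_hyperparameter("initial_angle".to_string(), -1.0, 1.0);
//     let result = tuner.optimize(
//         |hp| build_ansatz(hp["initial_angle"]),
//         cost_fn,
//     )?;
//     println!("best loss: {:.3e}", result.best_loss);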

/// Hyperparameter optimization result
#[derive(Debug, Clone)]
pub struct HyperparameterResult {
    /// Best hyperparameters found
    pub best_hyperparameters: FxHashMap<String, f64>,
    /// Best loss achieved
    pub best_loss: f64,
    /// All trials
    pub all_trials: Vec<HyperparameterTrial>,
}

/// Single hyperparameter trial
#[derive(Debug, Clone)]
pub struct HyperparameterTrial {
    /// Hyperparameters used
    pub hyperparameters: FxHashMap<String, f64>,
    /// Final loss achieved
    pub final_loss: f64,
    /// Optimal variational parameters
    pub optimal_parameters: FxHashMap<String, f64>,
}

// Clone implementation for VariationalCircuit
impl Clone for VariationalCircuit {
    fn clone(&self) -> Self {
        Self {
            gates: self.gates.clone(),
            param_map: self.param_map.clone(),
            num_qubits: self.num_qubits,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::qubit::QubitId;
    use crate::variational::VariationalGate;

    #[test]
    fn test_gradient_descent_optimizer() {
        let mut circuit = VariationalCircuit::new(1);
        circuit.add_gate(VariationalGate::rx(QubitId(0), "theta".to_string(), 0.0));

        let config = OptimizationConfig {
            max_iterations: 10,
            ..Default::default()
        };

        let mut optimizer = VariationalQuantumOptimizer::new(
            OptimizationMethod::GradientDescent { learning_rate: 0.1 },
            config,
        );

        // Simple cost function
        let cost_fn = |circuit: &VariationalCircuit| -> QuantRS2Result<f64> {
            let theta = circuit
                .get_parameters()
                .get("theta")
                .copied()
                .unwrap_or(0.0);
            Ok((theta - 1.0).powi(2))
        };

        let result = optimizer.optimize(&mut circuit, cost_fn).unwrap();

        assert!(result.converged || result.iterations == 10);
        assert!((result.optimal_parameters["theta"] - 1.0).abs() < 0.1);
    }

    #[test]
    fn test_adam_optimizer() {
        let mut circuit = VariationalCircuit::new(2);
        circuit.add_gate(VariationalGate::ry(QubitId(0), "alpha".to_string(), 0.5));
        circuit.add_gate(VariationalGate::rz(QubitId(1), "beta".to_string(), 0.5));

        let config = OptimizationConfig {
            max_iterations: 100,
            f_tol: 1e-6,
            g_tol: 1e-6,
            ..Default::default()
        };

        let mut optimizer = VariationalQuantumOptimizer::new(
            OptimizationMethod::Adam {
                learning_rate: 0.1,
                beta1: 0.9,
                beta2: 0.999,
                epsilon: 1e-8,
            },
            config,
        );

        // Cost function with multiple parameters
        let cost_fn = |circuit: &VariationalCircuit| -> QuantRS2Result<f64> {
            let params = circuit.get_parameters();
            let alpha = params.get("alpha").copied().unwrap_or(0.0);
            let beta = params.get("beta").copied().unwrap_or(0.0);
            Ok(alpha.powi(2) + beta.powi(2))
        };

        let result = optimizer.optimize(&mut circuit, cost_fn).unwrap();

        assert!(result.optimal_parameters["alpha"].abs() < 0.1);
        assert!(result.optimal_parameters["beta"].abs() < 0.1);
    }

    #[test]
    fn test_constrained_optimization() {
        let mut circuit = VariationalCircuit::new(1);
        circuit.add_gate(VariationalGate::rx(QubitId(0), "x".to_string(), 2.0));

        let base_optimizer =
            VariationalQuantumOptimizer::new(OptimizationMethod::BFGS, Default::default());

        let mut constrained_opt = ConstrainedVariationalOptimizer::new(base_optimizer);

        // Add constraint: x >= 1.0
        constrained_opt
            .add_inequality_constraint(|params| 1.0 - params.get("x").copied().unwrap_or(0.0), 0.0);

        // Minimize x^2
        let cost_fn = |circuit: &VariationalCircuit| -> QuantRS2Result<f64> {
            let x = circuit.get_parameters().get("x").copied().unwrap_or(0.0);
            Ok(x.powi(2))
        };

        let result = constrained_opt.optimize(&mut circuit, cost_fn).unwrap();

        // Should converge to x ≈ 1.0
        assert!((result.optimal_parameters["x"] - 1.0).abs() < 0.1);
    }
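
    // Additional illustrative test in the same style as the ones above: the
    // momentum optimizer on the same quadratic cost should also drive `theta`
    // towards its minimum at 1.0 (the tolerance is deliberately loose).
    #[test]
    fn test_momentum_optimizer() {
        let mut circuit = VariationalCircuit::new(1);
        circuit.add_gate(VariationalGate::rx(QubitId(0), "theta".to_string(), 0.0));

        let config = OptimizationConfig {
            max_iterations: 50,
            ..Default::default()
        };

        let mut optimizer = VariationalQuantumOptimizer::new(
            OptimizationMethod::Momentum {
                learning_rate: 0.05,
                momentum: 0.9,
            },
            config,
        );

        let cost_fn = |circuit: &VariationalCircuit| -> QuantRS2Result<f64> {
            let theta = circuit
                .get_parameters()
                .get("theta")
                .copied()
                .unwrap_or(0.0);
            Ok((theta - 1.0).powi(2))
        };

        let result = optimizer.optimize(&mut circuit, cost_fn).unwrap();

        assert!((result.optimal_parameters["theta"] - 1.0).abs() < 0.2);
    }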
}