
quantrs2_core/variational_optimization.rs

//! Enhanced variational parameter optimization using SciRS2
//!
//! This module provides advanced optimization techniques for variational quantum algorithms,
//! leveraging SciRS2's optimization capabilities, including:
//! - Gradient-based methods (BFGS, L-BFGS, Conjugate Gradient)
//! - Gradient-free methods (Nelder-Mead, Powell, COBYLA)
//! - Stochastic optimization (SPSA, Adam, RMSprop)
//! - Natural gradient descent for quantum circuits
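//!
//! # Example
//!
//! A minimal sketch of the intended workflow, mirroring the gradient-descent test
//! at the bottom of this file (not compiled as a doctest):
//!
//! ```ignore
//! use quantrs2_core::qubit::QubitId;
//! use quantrs2_core::variational::{VariationalCircuit, VariationalGate};
//! use quantrs2_core::variational_optimization::{
//!     OptimizationConfig, OptimizationMethod, VariationalQuantumOptimizer,
//! };
//!
//! // One-qubit ansatz with a single trainable rotation angle.
//! let mut circuit = VariationalCircuit::new(1);
//! circuit.add_gate(VariationalGate::rx(QubitId(0), "theta".to_string(), 0.0));
//!
//! let mut optimizer = VariationalQuantumOptimizer::new(
//!     OptimizationMethod::GradientDescent { learning_rate: 0.1 },
//!     OptimizationConfig::default(),
//! );
//!
//! // Any `Fn(&VariationalCircuit) -> QuantRS2Result<f64>` can serve as the cost function.
//! let result = optimizer
//!     .optimize(&mut circuit, |c| {
//!         let theta = c.get_parameters().get("theta").copied().unwrap_or(0.0);
//!         Ok((theta - 1.0).powi(2))
//!     })
//!     .expect("optimization should succeed");
//! println!("final loss = {}", result.final_loss);
//! ```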

use crate::{
    error::{QuantRS2Error, QuantRS2Result},
    variational::VariationalCircuit,
};
use scirs2_core::ndarray::{Array1, Array2};
// use scirs2_core::parallel_ops::*;
use crate::optimization_stubs::{minimize, Method, OptimizeResult, Options};
use crate::parallel_ops_stubs::*;
// use scirs2_core::optimization::{minimize, Method, OptimizeResult, Options};
use rustc_hash::FxHashMap;
use std::sync::{Arc, Mutex};

// Import SciRS2 optimization
// extern crate scirs2_optimize;
// use scirs2_optimize::unconstrained::{minimize, Method, Options};

// Import SciRS2 linear algebra for natural gradient (SciRS2 POLICY)
// extern crate scirs2_linalg; // available via Cargo.toml workspace dependency

/// Advanced optimizer for variational quantum circuits
pub struct VariationalQuantumOptimizer {
    /// Optimization method
    method: OptimizationMethod,
    /// Configuration
    config: OptimizationConfig,
    /// History of optimization
    history: OptimizationHistory,
    /// Fisher information matrix cache
    fisher_cache: Option<FisherCache>,
}

/// Optimization methods available
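///
/// A typical construction (hyperparameter values taken from the Adam test at the
/// bottom of this file; adjust for your problem):
///
/// ```ignore
/// let method = OptimizationMethod::Adam {
///     learning_rate: 0.1,
///     beta1: 0.9,
///     beta2: 0.999,
///     epsilon: 1e-8,
/// };
/// ```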
#[derive(Debug, Clone)]
pub enum OptimizationMethod {
    /// Standard gradient descent
    GradientDescent { learning_rate: f64 },
    /// Momentum-based gradient descent
    Momentum { learning_rate: f64, momentum: f64 },
    /// Adam optimizer
    Adam {
        learning_rate: f64,
        beta1: f64,
        beta2: f64,
        epsilon: f64,
    },
    /// RMSprop optimizer
    RMSprop {
        learning_rate: f64,
        decay_rate: f64,
        epsilon: f64,
    },
    /// Natural gradient descent
    NaturalGradient {
        learning_rate: f64,
        regularization: f64,
    },
    /// SciRS2 BFGS method
    BFGS,
    /// SciRS2 L-BFGS method
    LBFGS { memory_size: usize },
    /// SciRS2 Conjugate Gradient
    ConjugateGradient,
    /// SciRS2 Nelder-Mead simplex
    NelderMead,
    /// SciRS2 Powell's method
    Powell,
    /// Simultaneous Perturbation Stochastic Approximation
    SPSA {
        a: f64,
        c: f64,
        alpha: f64,
        gamma: f64,
    },
    /// Quantum Natural SPSA
    QNSPSA {
        learning_rate: f64,
        regularization: f64,
        spsa_epsilon: f64,
    },
}

/// Configuration for optimization
#[derive(Clone)]
pub struct OptimizationConfig {
    /// Maximum iterations
    pub max_iterations: usize,
    /// Function tolerance
    pub f_tol: f64,
    /// Gradient tolerance
    pub g_tol: f64,
    /// Parameter tolerance
    pub x_tol: f64,
    /// Enable parallel gradient computation
    pub parallel_gradients: bool,
    /// Batch size for stochastic methods
    pub batch_size: Option<usize>,
    /// Random seed
    pub seed: Option<u64>,
    /// Callback function after each iteration
    pub callback: Option<Arc<dyn Fn(&[f64], f64) + Send + Sync>>,
    /// Early stopping patience
    pub patience: Option<usize>,
    /// Gradient clipping value
    pub grad_clip: Option<f64>,
}

impl Default for OptimizationConfig {
    fn default() -> Self {
        Self {
            max_iterations: 100,
            f_tol: 1e-8,
            g_tol: 1e-8,
            x_tol: 1e-8,
            parallel_gradients: true,
            batch_size: None,
            seed: None,
            callback: None,
            patience: None,
            grad_clip: None,
        }
    }
}

/// Optimization history tracking
#[derive(Debug, Clone)]
pub struct OptimizationHistory {
    /// Parameter values at each iteration
    pub parameters: Vec<Vec<f64>>,
    /// Loss values
    pub loss_values: Vec<f64>,
    /// Gradient norms
    pub gradient_norms: Vec<f64>,
    /// Iteration times (ms)
    pub iteration_times: Vec<f64>,
    /// Total iterations
    pub total_iterations: usize,
    /// Converged flag
    pub converged: bool,
}

impl OptimizationHistory {
    const fn new() -> Self {
        Self {
            parameters: Vec::new(),
            loss_values: Vec::new(),
            gradient_norms: Vec::new(),
            iteration_times: Vec::new(),
            total_iterations: 0,
            converged: false,
        }
    }
}

/// Fisher information matrix cache
struct FisherCache {
    /// Cached Fisher matrix
    matrix: Arc<Mutex<Option<Array2<f64>>>>,
    /// Parameters for cached matrix
    params: Arc<Mutex<Option<Vec<f64>>>>,
    /// Cache validity threshold
    threshold: f64,
}

/// Optimizer state for stateful methods
struct OptimizerState {
    /// Momentum vectors
    momentum: FxHashMap<String, f64>,
    /// Adam first moment
    adam_m: FxHashMap<String, f64>,
    /// Adam second moment
    adam_v: FxHashMap<String, f64>,
    /// RMSprop moving average
    rms_avg: FxHashMap<String, f64>,
    /// Iteration counter
    iteration: usize,
}

impl VariationalQuantumOptimizer {
    /// Create a new optimizer
    pub fn new(method: OptimizationMethod, config: OptimizationConfig) -> Self {
        let fisher_cache = match &method {
            OptimizationMethod::NaturalGradient { .. } | OptimizationMethod::QNSPSA { .. } => {
                Some(FisherCache {
                    matrix: Arc::new(Mutex::new(None)),
                    params: Arc::new(Mutex::new(None)),
                    threshold: 1e-3,
                })
            }
            _ => None,
        };

        Self {
            method,
            config,
            history: OptimizationHistory::new(),
            fisher_cache,
        }
    }

    /// Optimize a variational circuit
    pub fn optimize(
        &mut self,
        circuit: &mut VariationalCircuit,
        cost_fn: impl Fn(&VariationalCircuit) -> QuantRS2Result<f64> + Send + Sync + 'static,
    ) -> QuantRS2Result<OptimizationResult> {
        let cost_fn = Arc::new(cost_fn);

        match &self.method {
            OptimizationMethod::BFGS
            | OptimizationMethod::LBFGS { .. }
            | OptimizationMethod::ConjugateGradient
            | OptimizationMethod::NelderMead
            | OptimizationMethod::Powell => self.optimize_with_scirs2(circuit, cost_fn),
            _ => self.optimize_custom(circuit, cost_fn),
        }
    }

    /// Optimize using SciRS2 methods
    fn optimize_with_scirs2(
        &mut self,
        circuit: &mut VariationalCircuit,
        cost_fn: Arc<dyn Fn(&VariationalCircuit) -> QuantRS2Result<f64> + Send + Sync>,
    ) -> QuantRS2Result<OptimizationResult> {
        let param_names = circuit.parameter_names();
        let initial_params: Vec<f64> = param_names
            .iter()
            .map(|name| circuit.get_parameters().get(name).copied().unwrap_or(0.0))
            .collect();

        let circuit_clone = Arc::new(Mutex::new(circuit.clone()));
        let param_names_clone = param_names.clone();

        // Create objective function for SciRS2
        let objective = move |params: &scirs2_core::ndarray::ArrayView1<f64>| -> f64 {
            let params_slice = match params.as_slice() {
                Some(slice) => slice,
                None => return f64::INFINITY, // Non-contiguous array - return infinity to skip
            };
            let mut param_map = FxHashMap::default();
            for (name, &value) in param_names_clone.iter().zip(params_slice) {
                param_map.insert(name.clone(), value);
            }

            let mut circuit = circuit_clone.lock().unwrap_or_else(|e| e.into_inner());
            if circuit.set_parameters(&param_map).is_err() {
                return f64::INFINITY;
            }

            cost_fn(&*circuit).unwrap_or(f64::INFINITY)
        };

        // Set up SciRS2 method
        let method = match &self.method {
            OptimizationMethod::BFGS | OptimizationMethod::ConjugateGradient => Method::BFGS, // CG uses BFGS fallback
            OptimizationMethod::LBFGS { memory_size: _ } => Method::LBFGS,
            OptimizationMethod::NelderMead => Method::NelderMead,
            OptimizationMethod::Powell => Method::Powell,
            _ => unreachable!(),
        };

        // Configure options
        let options = Options {
            max_iter: self.config.max_iterations,
            ftol: self.config.f_tol,
            gtol: self.config.g_tol,
            xtol: self.config.x_tol,
            ..Default::default()
        };

        // Run optimization
        let start_time = std::time::Instant::now();
        let initial_array = scirs2_core::ndarray::Array1::from_vec(initial_params);
        let result = minimize(objective, &initial_array, method, Some(options))
            .map_err(|e| QuantRS2Error::InvalidInput(format!("Optimization failed: {e:?}")))?;

        // Update circuit with optimal parameters
        let mut final_params = FxHashMap::default();
        let result_slice = result.x.as_slice().ok_or_else(|| {
            QuantRS2Error::RuntimeError("Optimization result array is non-contiguous".to_string())
        })?;
        for (name, &value) in param_names.iter().zip(result_slice) {
            final_params.insert(name.clone(), value);
        }
        circuit.set_parameters(&final_params)?;

        // Update history
        self.history.parameters.push(result.x.to_vec());
        self.history.loss_values.push(result.fun);
        self.history.total_iterations = result.iterations;
        self.history.converged = result.success;

        Ok(OptimizationResult {
            optimal_parameters: final_params,
            final_loss: result.fun,
            iterations: result.iterations,
            converged: result.success,
            optimization_time: start_time.elapsed().as_secs_f64(),
            history: self.history.clone(),
        })
    }

    /// Optimize using custom methods
    fn optimize_custom(
        &mut self,
        circuit: &mut VariationalCircuit,
        cost_fn: Arc<dyn Fn(&VariationalCircuit) -> QuantRS2Result<f64> + Send + Sync>,
    ) -> QuantRS2Result<OptimizationResult> {
        let mut state = OptimizerState {
            momentum: FxHashMap::default(),
            adam_m: FxHashMap::default(),
            adam_v: FxHashMap::default(),
            rms_avg: FxHashMap::default(),
            iteration: 0,
        };

        let param_names = circuit.parameter_names();
        let start_time = std::time::Instant::now();
        let mut best_loss = f64::INFINITY;
        let mut patience_counter = 0;

        for iter in 0..self.config.max_iterations {
            let iter_start = std::time::Instant::now();

            // Compute loss
            let loss = cost_fn(circuit)?;

            // Check for improvement
            if loss < best_loss - self.config.f_tol {
                best_loss = loss;
                patience_counter = 0;
            } else if let Some(patience) = self.config.patience {
                patience_counter += 1;
                if patience_counter >= patience {
                    self.history.converged = true;
                    break;
                }
            }

            // Compute gradients
            let gradients = self.compute_gradients(circuit, &cost_fn)?;

            // Clip gradients if requested
            let gradients = if let Some(max_norm) = self.config.grad_clip {
                self.clip_gradients(gradients, max_norm)
            } else {
                gradients
            };

            // Update parameters based on method
            self.update_parameters(circuit, &gradients, &mut state)?;

            // Update history
            let current_params: Vec<f64> = param_names
                .iter()
                .map(|name| circuit.get_parameters().get(name).copied().unwrap_or(0.0))
                .collect();

            let grad_norm = gradients.values().map(|g| g * g).sum::<f64>().sqrt();

            self.history.parameters.push(current_params);
            self.history.loss_values.push(loss);
            self.history.gradient_norms.push(grad_norm);
            self.history
                .iteration_times
                .push(iter_start.elapsed().as_secs_f64() * 1000.0);
            self.history.total_iterations = iter + 1;

            // Callback
            if let Some(callback) = &self.config.callback {
                let params: Vec<f64> = param_names
                    .iter()
                    .map(|name| circuit.get_parameters().get(name).copied().unwrap_or(0.0))
                    .collect();
                callback(&params, loss);
            }

            // Check convergence
            if grad_norm < self.config.g_tol {
                self.history.converged = true;
                break;
            }

            state.iteration += 1;
        }

        let final_params = circuit.get_parameters();
        let final_loss = cost_fn(circuit)?;

        Ok(OptimizationResult {
            optimal_parameters: final_params,
            final_loss,
            iterations: self.history.total_iterations,
            converged: self.history.converged,
            optimization_time: start_time.elapsed().as_secs_f64(),
            history: self.history.clone(),
        })
    }

    /// Compute gradients for all parameters
    fn compute_gradients(
        &self,
        circuit: &VariationalCircuit,
        cost_fn: &Arc<dyn Fn(&VariationalCircuit) -> QuantRS2Result<f64> + Send + Sync>,
    ) -> QuantRS2Result<FxHashMap<String, f64>> {
        let param_names = circuit.parameter_names();

        if self.config.parallel_gradients {
            // Parallel gradient computation
            let gradients: Vec<(String, f64)> = param_names
                .par_iter()
                .map(|param_name| {
                    let grad = self
                        .compute_single_gradient(circuit, param_name, cost_fn)
                        .unwrap_or(0.0);
                    (param_name.clone(), grad)
                })
                .collect();

            Ok(gradients.into_iter().collect())
        } else {
            // Sequential gradient computation
            let mut gradients = FxHashMap::default();
            for param_name in &param_names {
                let grad = self.compute_single_gradient(circuit, param_name, cost_fn)?;
                gradients.insert(param_name.clone(), grad);
            }
            Ok(gradients)
        }
    }

    /// Compute gradient for a single parameter
    fn compute_single_gradient(
        &self,
        circuit: &VariationalCircuit,
        param_name: &str,
        cost_fn: &Arc<dyn Fn(&VariationalCircuit) -> QuantRS2Result<f64> + Send + Sync>,
    ) -> QuantRS2Result<f64> {
        match &self.method {
            OptimizationMethod::SPSA { c, .. } => {
                // SPSA gradient approximation
                self.spsa_gradient(circuit, param_name, cost_fn, *c)
            }
            _ => {
                // Parameter shift rule
                self.parameter_shift_gradient(circuit, param_name, cost_fn)
            }
        }
    }

    /// Parameter shift rule gradient
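    ///
    /// Standard two-term parameter-shift rule: for a gate generated as
    /// `exp(-i θ P / 2)` with `P² = I`, the exact derivative is
    /// `dC/dθ = [C(θ + π/2) - C(θ - π/2)] / 2`, which is what the two extra
    /// cost-function evaluations below compute.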
    fn parameter_shift_gradient(
        &self,
        circuit: &VariationalCircuit,
        param_name: &str,
        cost_fn: &Arc<dyn Fn(&VariationalCircuit) -> QuantRS2Result<f64> + Send + Sync>,
    ) -> QuantRS2Result<f64> {
        let current_params = circuit.get_parameters();
        let current_value = *current_params.get(param_name).ok_or_else(|| {
            QuantRS2Error::InvalidInput(format!("Parameter {param_name} not found"))
        })?;

        // Shift parameter by +π/2
        let mut circuit_plus = circuit.clone();
        let mut params_plus = current_params.clone();
        params_plus.insert(
            param_name.to_string(),
            current_value + std::f64::consts::PI / 2.0,
        );
        circuit_plus.set_parameters(&params_plus)?;
        let loss_plus = cost_fn(&circuit_plus)?;

        // Shift parameter by -π/2
        let mut circuit_minus = circuit.clone();
        let mut params_minus = current_params;
        params_minus.insert(
            param_name.to_string(),
            current_value - std::f64::consts::PI / 2.0,
        );
        circuit_minus.set_parameters(&params_minus)?;
        let loss_minus = cost_fn(&circuit_minus)?;

        Ok((loss_plus - loss_minus) / 2.0)
    }

    /// SPSA gradient approximation
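    ///
    /// Note: this variant perturbs only the single named parameter by ±`epsilon`
    /// and returns the central difference `[C(θ+ε) - C(θ-ε)] / (2ε)`; canonical
    /// SPSA perturbs all parameters simultaneously with one random sign each.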
    fn spsa_gradient(
        &self,
        circuit: &VariationalCircuit,
        param_name: &str,
        cost_fn: &Arc<dyn Fn(&VariationalCircuit) -> QuantRS2Result<f64> + Send + Sync>,
        epsilon: f64,
    ) -> QuantRS2Result<f64> {
        use scirs2_core::random::prelude::*;

        let mut rng = if let Some(seed) = self.config.seed {
            StdRng::seed_from_u64(seed)
        } else {
            StdRng::from_seed(thread_rng().random())
        };

        let current_params = circuit.get_parameters();
        let perturbation = if rng.random::<bool>() {
            epsilon
        } else {
            -epsilon
        };

        // Positive perturbation
        let mut circuit_plus = circuit.clone();
        let mut params_plus = current_params.clone();
        for (name, value) in &mut params_plus {
            if name == param_name {
                *value += perturbation;
            }
        }
        circuit_plus.set_parameters(&params_plus)?;
        let loss_plus = cost_fn(&circuit_plus)?;

        // Negative perturbation
        let mut circuit_minus = circuit.clone();
        let mut params_minus = current_params;
        for (name, value) in &mut params_minus {
            if name == param_name {
                *value -= perturbation;
            }
        }
        circuit_minus.set_parameters(&params_minus)?;
        let loss_minus = cost_fn(&circuit_minus)?;

        Ok((loss_plus - loss_minus) / (2.0 * perturbation))
    }

    /// Clip gradients by norm
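    ///
    /// If the Euclidean norm of the gradient vector exceeds `max_norm`, every
    /// component is scaled by `max_norm / ‖g‖`; otherwise the gradients are
    /// returned unchanged.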
    fn clip_gradients(
        &self,
        mut gradients: FxHashMap<String, f64>,
        max_norm: f64,
    ) -> FxHashMap<String, f64> {
        let norm = gradients.values().map(|g| g * g).sum::<f64>().sqrt();

        if norm > max_norm {
            let scale = max_norm / norm;
            for grad in gradients.values_mut() {
                *grad *= scale;
            }
        }

        gradients
    }

    /// Update parameters based on optimization method
    fn update_parameters(
        &self,
        circuit: &mut VariationalCircuit,
        gradients: &FxHashMap<String, f64>,
        state: &mut OptimizerState,
    ) -> QuantRS2Result<()> {
        let mut new_params = circuit.get_parameters();

        match &self.method {
            OptimizationMethod::GradientDescent { learning_rate } => {
                // Simple gradient descent
                for (param_name, &grad) in gradients {
                    if let Some(value) = new_params.get_mut(param_name) {
                        *value -= learning_rate * grad;
                    }
                }
            }
            OptimizationMethod::Momentum {
                learning_rate,
                momentum,
            } => {
                // Momentum-based gradient descent
                for (param_name, &grad) in gradients {
                    let velocity = state.momentum.entry(param_name.clone()).or_insert(0.0);
                    *velocity = momentum * *velocity - learning_rate * grad;

                    if let Some(value) = new_params.get_mut(param_name) {
                        *value += *velocity;
                    }
                }
            }
            OptimizationMethod::Adam {
                learning_rate,
                beta1,
                beta2,
                epsilon,
            } => {
                // Adam optimizer
                let t = state.iteration as f64 + 1.0;
                let lr_t = learning_rate * (1.0 - beta2.powf(t)).sqrt() / (1.0 - beta1.powf(t));
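                // Standard Adam update with bias-corrected step size:
                //   m ← β1·m + (1 - β1)·g,   v ← β2·v + (1 - β2)·g²
                //   θ ← θ - lr_t · m / (√v + ε),  lr_t = α·√(1 - β2^t) / (1 - β1^t)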

                for (param_name, &grad) in gradients {
                    let m = state.adam_m.entry(param_name.clone()).or_insert(0.0);
                    let v = state.adam_v.entry(param_name.clone()).or_insert(0.0);

                    *m = (1.0 - beta1).mul_add(grad, beta1 * *m);
                    *v = ((1.0 - beta2) * grad).mul_add(grad, beta2 * *v);

                    if let Some(value) = new_params.get_mut(param_name) {
                        *value -= lr_t * *m / (v.sqrt() + epsilon);
                    }
                }
            }
            OptimizationMethod::RMSprop {
                learning_rate,
                decay_rate,
                epsilon,
            } => {
                // RMSprop optimizer
                for (param_name, &grad) in gradients {
                    let avg = state.rms_avg.entry(param_name.clone()).or_insert(0.0);
                    *avg = ((1.0 - decay_rate) * grad).mul_add(grad, decay_rate * *avg);

                    if let Some(value) = new_params.get_mut(param_name) {
                        *value -= learning_rate * grad / (avg.sqrt() + epsilon);
                    }
                }
            }
            OptimizationMethod::NaturalGradient {
                learning_rate,
                regularization,
            } => {
                // Natural gradient descent
                let fisher_inv =
                    self.compute_fisher_inverse(circuit, gradients, *regularization)?;
                let natural_grad = self.apply_fisher_inverse(&fisher_inv, gradients);

                for (param_name, &nat_grad) in &natural_grad {
                    if let Some(value) = new_params.get_mut(param_name) {
                        *value -= learning_rate * nat_grad;
                    }
                }
            }
            OptimizationMethod::SPSA { a, alpha, .. } => {
                // SPSA parameter update
                let ak = a / (state.iteration as f64 + 1.0).powf(*alpha);

                for (param_name, &grad) in gradients {
                    if let Some(value) = new_params.get_mut(param_name) {
                        *value -= ak * grad;
                    }
                }
            }
            OptimizationMethod::QNSPSA {
                learning_rate,
                regularization,
                ..
            } => {
                // Quantum Natural SPSA
                let fisher_inv =
                    self.compute_fisher_inverse(circuit, gradients, *regularization)?;
                let natural_grad = self.apply_fisher_inverse(&fisher_inv, gradients);

                for (param_name, &nat_grad) in &natural_grad {
                    if let Some(value) = new_params.get_mut(param_name) {
                        *value -= learning_rate * nat_grad;
                    }
                }
            }
            _ => {
                // Should not reach here for SciRS2 methods
                return Err(QuantRS2Error::InvalidInput(
                    "Invalid optimization method".to_string(),
                ));
            }
        }

        circuit.set_parameters(&new_params)
    }

    /// Compute Fisher information matrix inverse
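    ///
    /// The Fisher matrix is approximated here by the regularized gradient outer
    /// product `F ≈ g·gᵀ + λI` (see the loop below); the natural-gradient update
    /// then uses `F⁻¹ g`. A full quantum Fisher information matrix would instead
    /// be built from quantum state overlaps.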
    fn compute_fisher_inverse(
        &self,
        circuit: &VariationalCircuit,
        gradients: &FxHashMap<String, f64>,
        regularization: f64,
    ) -> QuantRS2Result<Array2<f64>> {
        let param_names: Vec<_> = gradients.keys().cloned().collect();
        let n_params = param_names.len();

        // Check cache
        if let Some(cache) = &self.fisher_cache {
            let cached_matrix_opt = cache
                .matrix
                .lock()
                .map_err(|e| QuantRS2Error::RuntimeError(format!("Lock poisoned: {e}")))?;
            if let Some(cached_matrix) = cached_matrix_opt.as_ref() {
                let cached_params_opt = cache
                    .params
                    .lock()
                    .map_err(|e| QuantRS2Error::RuntimeError(format!("Lock poisoned: {e}")))?;
                if let Some(cached_params) = cached_params_opt.as_ref() {
                    let current_params: Vec<f64> = param_names
                        .iter()
                        .map(|name| circuit.get_parameters().get(name).copied().unwrap_or(0.0))
                        .collect();

                    let diff_norm: f64 = current_params
                        .iter()
                        .zip(cached_params.iter())
                        .map(|(a, b)| (a - b).powi(2))
                        .sum::<f64>()
                        .sqrt();

                    if diff_norm < cache.threshold {
                        return Ok(cached_matrix.clone());
                    }
                }
            }
        }

        // Compute Fisher information matrix
        let mut fisher = Array2::zeros((n_params, n_params));

        // Simplified Fisher matrix computation
        // In practice, this would involve quantum state overlaps
        for i in 0..n_params {
            for j in i..n_params {
                // Approximation: use gradient outer product
                let value = gradients[&param_names[i]] * gradients[&param_names[j]];
                fisher[[i, j]] = value;
                fisher[[j, i]] = value;
            }
        }

        // Add regularization
        for i in 0..n_params {
            fisher[[i, i]] += regularization;
        }

        // Compute the inverse via `invert_fisher_matrix`, which uses
        // scirs2_linalg::inv (SciRS2 POLICY compliant).
        let fisher_inv = Self::invert_fisher_matrix(&fisher)?;

        // Update cache
        if let Some(cache) = &self.fisher_cache {
            let current_params: Vec<f64> = param_names
                .iter()
                .map(|name| circuit.get_parameters().get(name).copied().unwrap_or(0.0))
                .collect();

            *cache
                .matrix
                .lock()
                .map_err(|e| QuantRS2Error::RuntimeError(format!("Lock poisoned: {e}")))? =
                Some(fisher_inv.clone());
            *cache
                .params
                .lock()
                .map_err(|e| QuantRS2Error::RuntimeError(format!("Lock poisoned: {e}")))? =
                Some(current_params);
        }

        Ok(fisher_inv)
    }

    /// Apply Fisher information matrix inverse to gradients
    fn apply_fisher_inverse(
        &self,
        fisher_inv: &Array2<f64>,
        gradients: &FxHashMap<String, f64>,
    ) -> FxHashMap<String, f64> {
        let param_names: Vec<_> = gradients.keys().cloned().collect();
        let grad_vec: Vec<f64> = param_names.iter().map(|name| gradients[name]).collect();

        let grad_array = Array1::from_vec(grad_vec);
        let natural_grad = fisher_inv.dot(&grad_array);

        let mut result = FxHashMap::default();
        for (i, name) in param_names.iter().enumerate() {
            result.insert(name.clone(), natural_grad[i]);
        }

        result
    }

    /// Invert the Fisher information matrix using scirs2_linalg.
    ///
    /// Uses `scirs2_linalg::inv` for general n×n matrices; the degenerate 0×0 and
    /// 1×1 cases are handled directly to avoid the full LU path. Returns
    /// `QuantRS2Error::InvalidInput` if the 1×1 matrix is singular (|value| < 1e-14)
    /// or if the general inversion fails.
    fn invert_fisher_matrix(fisher: &Array2<f64>) -> QuantRS2Result<Array2<f64>> {
        let n = fisher.nrows();
        if n == 0 {
            return Ok(Array2::zeros((0, 0)));
        }
        if n == 1 {
            let val = fisher[[0, 0]];
            if val.abs() < 1e-14 {
                return Err(QuantRS2Error::InvalidInput(
                    "Fisher matrix is singular (1×1 zero)".to_string(),
                ));
            }
            let mut inv_mat = Array2::zeros((1, 1));
            inv_mat[[0, 0]] = 1.0 / val;
            return Ok(inv_mat);
        }

        // Use scirs2_linalg::inv for general n×n inversion (SciRS2 POLICY)
        scirs2_linalg::inv(&fisher.view(), None).map_err(|e| {
            QuantRS2Error::InvalidInput(format!("Fisher matrix inversion failed: {e:?}"))
        })
    }
}

/// Optimization result
#[derive(Debug, Clone)]
pub struct OptimizationResult {
    /// Optimal parameters
    pub optimal_parameters: FxHashMap<String, f64>,
    /// Final loss value
    pub final_loss: f64,
    /// Number of iterations
    pub iterations: usize,
    /// Whether optimization converged
    pub converged: bool,
    /// Total optimization time (seconds)
    pub optimization_time: f64,
    /// Full optimization history
    pub history: OptimizationHistory,
}

/// Create optimized VQE optimizer
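///
/// Returns an L-BFGS optimizer with tight tolerances and gradient clipping.
/// A usage sketch (the circuit and cost function are assumed to exist elsewhere):
///
/// ```ignore
/// let mut optimizer = create_vqe_optimizer();
/// let result = optimizer.optimize(&mut circuit, cost_fn)?;
/// println!("VQE energy: {}", result.final_loss);
/// ```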
pub fn create_vqe_optimizer() -> VariationalQuantumOptimizer {
    let config = OptimizationConfig {
        max_iterations: 200,
        f_tol: 1e-10,
        g_tol: 1e-10,
        parallel_gradients: true,
        grad_clip: Some(1.0),
        ..Default::default()
    };

    VariationalQuantumOptimizer::new(OptimizationMethod::LBFGS { memory_size: 10 }, config)
}

/// Create optimized QAOA optimizer
pub fn create_qaoa_optimizer() -> VariationalQuantumOptimizer {
    let config = OptimizationConfig {
        max_iterations: 100,
        parallel_gradients: true,
        ..Default::default()
    };

    VariationalQuantumOptimizer::new(OptimizationMethod::BFGS, config)
}

/// Create natural gradient optimizer
pub fn create_natural_gradient_optimizer(learning_rate: f64) -> VariationalQuantumOptimizer {
    let config = OptimizationConfig {
        max_iterations: 100,
        parallel_gradients: true,
        ..Default::default()
    };

    VariationalQuantumOptimizer::new(
        OptimizationMethod::NaturalGradient {
            learning_rate,
            regularization: 1e-4,
        },
        config,
    )
}

/// Create SPSA optimizer for noisy quantum devices
pub fn create_spsa_optimizer() -> VariationalQuantumOptimizer {
    let config = OptimizationConfig {
        max_iterations: 500,
        seed: Some(42),
        ..Default::default()
    };

    VariationalQuantumOptimizer::new(
        OptimizationMethod::SPSA {
            a: 0.1,
            c: 0.1,
            alpha: 0.602,
            gamma: 0.101,
        },
        config,
    )
}

/// Constrained optimization for variational circuits
pub struct ConstrainedVariationalOptimizer {
    /// Base optimizer
    base_optimizer: VariationalQuantumOptimizer,
    /// Constraints
    constraints: Vec<Constraint>,
}

/// Constraint for optimization
#[derive(Clone)]
pub struct Constraint {
    /// Constraint function
    pub function: Arc<dyn Fn(&FxHashMap<String, f64>) -> f64 + Send + Sync>,
    /// Constraint type
    pub constraint_type: ConstraintType,
    /// Constraint value
    pub value: f64,
}

/// Constraint type
#[derive(Debug, Clone, Copy)]
pub enum ConstraintType {
    /// Equality constraint
    Eq,
    /// Inequality constraint
    Ineq,
}

impl ConstrainedVariationalOptimizer {
    /// Create a new constrained optimizer
    pub const fn new(base_optimizer: VariationalQuantumOptimizer) -> Self {
        Self {
            base_optimizer,
            constraints: Vec::new(),
        }
    }

    /// Add an equality constraint
    pub fn add_equality_constraint(
        &mut self,
        constraint_fn: impl Fn(&FxHashMap<String, f64>) -> f64 + Send + Sync + 'static,
        value: f64,
    ) {
        self.constraints.push(Constraint {
            function: Arc::new(constraint_fn),
            constraint_type: ConstraintType::Eq,
            value,
        });
    }

    /// Add an inequality constraint
    pub fn add_inequality_constraint(
        &mut self,
        constraint_fn: impl Fn(&FxHashMap<String, f64>) -> f64 + Send + Sync + 'static,
        value: f64,
    ) {
        self.constraints.push(Constraint {
            function: Arc::new(constraint_fn),
            constraint_type: ConstraintType::Ineq,
            value,
        });
    }

    /// Optimize with constraints
    pub fn optimize(
        &mut self,
        circuit: &mut VariationalCircuit,
        cost_fn: impl Fn(&VariationalCircuit) -> QuantRS2Result<f64> + Send + Sync + 'static,
    ) -> QuantRS2Result<OptimizationResult> {
        if self.constraints.is_empty() {
            return self.base_optimizer.optimize(circuit, cost_fn);
        }

        // For constrained optimization, use penalty method
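        // Quadratic penalty: minimize C(θ) + w·Σ (gᵢ(θ) - cᵢ)² over equality
        // constraints, adding the same term for an inequality constraint only
        // when gᵢ(θ) > cᵢ. The weight w is `penalty_weight` below.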
        let cost_fn = Arc::new(cost_fn);
        let constraints = self.constraints.clone();
        let penalty_weight = 1000.0;

        let penalized_cost = move |circuit: &VariationalCircuit| -> QuantRS2Result<f64> {
            let base_cost = cost_fn(circuit)?;
            let params = circuit.get_parameters();

            let mut penalty = 0.0;
            for constraint in &constraints {
                let constraint_value = (constraint.function)(&params);
                match constraint.constraint_type {
                    ConstraintType::Eq => {
                        penalty += penalty_weight * (constraint_value - constraint.value).powi(2);
                    }
                    ConstraintType::Ineq => {
                        if constraint_value > constraint.value {
                            penalty +=
                                penalty_weight * (constraint_value - constraint.value).powi(2);
                        }
                    }
                }
            }

            Ok(base_cost + penalty)
        };

        self.base_optimizer.optimize(circuit, penalized_cost)
    }
}

/// Hyperparameter optimization for variational circuits
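///
/// Performs a simple random search: each trial samples hyperparameters uniformly
/// from the registered ranges, builds a circuit, and runs an inner optimization.
/// A sketch (`build_ansatz` and `cost_fn` are hypothetical user-provided callables):
///
/// ```ignore
/// let mut hpo = HyperparameterOptimizer::new(20);
/// hpo.add_hyperparameter("rotation_scale".to_string(), 0.0, 1.0);
/// let result = hpo.optimize(build_ansatz, cost_fn)?;
/// println!("best loss: {}", result.best_loss);
/// ```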
pub struct HyperparameterOptimizer {
    /// Search space for hyperparameters
    search_space: FxHashMap<String, (f64, f64)>,
    /// Number of trials
    n_trials: usize,
    /// Optimization method for inner loop
    inner_method: OptimizationMethod,
}

impl HyperparameterOptimizer {
    /// Create a new hyperparameter optimizer
    pub fn new(n_trials: usize) -> Self {
        Self {
            search_space: FxHashMap::default(),
            n_trials,
            inner_method: OptimizationMethod::BFGS,
        }
    }

    /// Add a hyperparameter to search
    pub fn add_hyperparameter(&mut self, name: String, min_value: f64, max_value: f64) {
        self.search_space.insert(name, (min_value, max_value));
    }

    /// Optimize hyperparameters
    pub fn optimize(
        &self,
        circuit_builder: impl Fn(&FxHashMap<String, f64>) -> VariationalCircuit + Send + Sync,
        cost_fn: impl Fn(&VariationalCircuit) -> QuantRS2Result<f64> + Send + Sync + Clone + 'static,
    ) -> QuantRS2Result<HyperparameterResult> {
        use scirs2_core::random::prelude::*;

        let mut rng = StdRng::from_seed(thread_rng().random());
        let mut best_hyperparams = FxHashMap::default();
        let mut best_loss = f64::INFINITY;
        let mut all_trials = Vec::new();

        for _trial in 0..self.n_trials {
            // Sample hyperparameters
            let mut hyperparams = FxHashMap::default();
            for (name, &(min_val, max_val)) in &self.search_space {
                let value = rng.random_range(min_val..max_val);
                hyperparams.insert(name.clone(), value);
            }

            // Build circuit with hyperparameters
            let mut circuit = circuit_builder(&hyperparams);

            // Optimize circuit
            let config = OptimizationConfig {
                max_iterations: 50,
                ..Default::default()
            };

            let mut optimizer = VariationalQuantumOptimizer::new(self.inner_method.clone(), config);

            let result = optimizer.optimize(&mut circuit, cost_fn.clone())?;

            all_trials.push(HyperparameterTrial {
                hyperparameters: hyperparams.clone(),
                final_loss: result.final_loss,
                optimal_parameters: result.optimal_parameters,
            });

            if result.final_loss < best_loss {
                best_loss = result.final_loss;
                best_hyperparams = hyperparams;
            }
        }

        Ok(HyperparameterResult {
            best_hyperparameters: best_hyperparams,
            best_loss,
            all_trials,
        })
    }
}

/// Hyperparameter optimization result
#[derive(Debug, Clone)]
pub struct HyperparameterResult {
    /// Best hyperparameters found
    pub best_hyperparameters: FxHashMap<String, f64>,
    /// Best loss achieved
    pub best_loss: f64,
    /// All trials
    pub all_trials: Vec<HyperparameterTrial>,
}

/// Single hyperparameter trial
#[derive(Debug, Clone)]
pub struct HyperparameterTrial {
    /// Hyperparameters used
    pub hyperparameters: FxHashMap<String, f64>,
    /// Final loss achieved
    pub final_loss: f64,
    /// Optimal variational parameters
    pub optimal_parameters: FxHashMap<String, f64>,
}

// Clone implementation for VariationalCircuit
impl Clone for VariationalCircuit {
    fn clone(&self) -> Self {
        Self {
            gates: self.gates.clone(),
            param_map: self.param_map.clone(),
            num_qubits: self.num_qubits,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::qubit::QubitId;
    use crate::variational::VariationalGate;

    #[test]
    fn test_gradient_descent_optimizer() {
        let mut circuit = VariationalCircuit::new(1);
        circuit.add_gate(VariationalGate::rx(QubitId(0), "theta".to_string(), 0.0));

        let config = OptimizationConfig {
            max_iterations: 10,
            ..Default::default()
        };

        let mut optimizer = VariationalQuantumOptimizer::new(
            OptimizationMethod::GradientDescent { learning_rate: 0.1 },
            config,
        );

        // Simple cost function
        let cost_fn = |circuit: &VariationalCircuit| -> QuantRS2Result<f64> {
            let theta = circuit
                .get_parameters()
                .get("theta")
                .copied()
                .unwrap_or(0.0);
            Ok((theta - 1.0).powi(2))
        };

        let result = optimizer
            .optimize(&mut circuit, cost_fn)
            .expect("Optimization should succeed");

        assert!(result.converged || result.iterations == 10);
        assert!((result.optimal_parameters["theta"] - 1.0).abs() < 0.1);
    }

    #[test]
    fn test_adam_optimizer() {
        let mut circuit = VariationalCircuit::new(2);
        circuit.add_gate(VariationalGate::ry(QubitId(0), "alpha".to_string(), 0.5));
        circuit.add_gate(VariationalGate::rz(QubitId(1), "beta".to_string(), 0.5));

        let config = OptimizationConfig {
            max_iterations: 100,
            f_tol: 1e-6,
            g_tol: 1e-6,
            ..Default::default()
        };

        let mut optimizer = VariationalQuantumOptimizer::new(
            OptimizationMethod::Adam {
                learning_rate: 0.1,
                beta1: 0.9,
                beta2: 0.999,
                epsilon: 1e-8,
            },
            config,
        );

        // Cost function with multiple parameters
        let cost_fn = |circuit: &VariationalCircuit| -> QuantRS2Result<f64> {
            let params = circuit.get_parameters();
            let alpha = params.get("alpha").copied().unwrap_or(0.0);
            let beta = params.get("beta").copied().unwrap_or(0.0);
            Ok(alpha.powi(2) + beta.powi(2))
        };

        let result = optimizer
            .optimize(&mut circuit, cost_fn)
            .expect("Optimization should succeed");

        assert!(result.optimal_parameters["alpha"].abs() < 0.1);
        assert!(result.optimal_parameters["beta"].abs() < 0.1);
    }

    #[test]
    fn test_constrained_optimization() {
        let mut circuit = VariationalCircuit::new(1);
        circuit.add_gate(VariationalGate::rx(QubitId(0), "x".to_string(), 2.0));

        let base_optimizer =
            VariationalQuantumOptimizer::new(OptimizationMethod::BFGS, Default::default());

        let mut constrained_opt = ConstrainedVariationalOptimizer::new(base_optimizer);

        // Add constraint: x >= 1.0
        constrained_opt
            .add_inequality_constraint(|params| 1.0 - params.get("x").copied().unwrap_or(0.0), 0.0);

        // Minimize x^2
        let cost_fn = |circuit: &VariationalCircuit| -> QuantRS2Result<f64> {
            let x = circuit.get_parameters().get("x").copied().unwrap_or(0.0);
            Ok(x.powi(2))
        };

        let result = constrained_opt
            .optimize(&mut circuit, cost_fn)
            .expect("Constrained optimization should succeed");

        let optimized_x = result.optimal_parameters["x"];
        assert!(optimized_x >= 1.0 - 1e-6);
        assert!(optimized_x <= 2.0 + 1e-6);
    }
}